Skip to content

TransformerNUFFT: add chunk_size knob to cap nufftax gather buffer#330

Merged
Jammy2211 merged 1 commit into
mainfrom
feature/nufft-chunking
May 22, 2026
Merged

TransformerNUFFT: add chunk_size knob to cap nufftax gather buffer#330
Jammy2211 merged 1 commit into
mainfrom
feature/nufft-chunking

Conversation

@Jammy2211
Copy link
Copy Markdown
Collaborator

Summary

  • Add chunk_size: Optional[int] = None to TransformerNUFFT.__init__. When set, splits the visibility axis into fixed-size chunks during both forward (_forward_native) and adjoint (image_from) NUFFT calls.
  • JAX path uses jax.lax.scan with dynamic_slice over a padded uv axis — keeps the compiled HLO graph bounded regardless of n_chunks. NumPy path uses a plain Python loop. Default chunk_size=None preserves the existing one-shot behaviour exactly (no regression for SMA-class callers).
  • 5 new parity + JIT-smoke tests in test_transformer.py.

Motivation

nufftax's spread-step gather buffer is 2 × N_vis × nspread² × dtype_size. At 5M visibilities with default eps=1e-6 (nspread=14) and complex64 that's 15.7 GB for a single intermediate; with adjacent JAX intermediates an A100 80 GB tips over. This is the upstream blocker preventing ALMA-high-class visibility counts on GPU — both the interferometer simulator (forward NUFFT) and the apply_sparse_operator setup path (adjoint NUFFT).

With chunk_size = 1_000_000 the per-call gather buffer drops to ~3 GB regardless of total visibility count.

Scope notes

  • transform_mapping_matrix is intentionally NOT chunked in this PR. The sparse-operator runtime path never calls it per-likelihood (the W-Tilde sparse formalism replaces it — see PR interferometer: enable sparse_operator for nufftax TransformerNUFFT #329 + autolens_profiling#22). The batched form is (n_src, N_vis, nspread²) and needs a separate chunk-size analysis with n_src factored in. Flagged as a follow-up.
  • Future upstream PR to nufftax itself could move chunking inside interp_2d_impl, but the autoarray-side chunking is sufficient to unblock the immediate work.

Test plan

  • pytest test_autoarray/ (828 passed locally — 5 new tests on top of 823 baseline)
  • New tests: chunk_size validation; numpy parity (forward + adjoint) at rtol=1e-6; JAX parity at rtol=1e-6; jax.jit smoke trace via lax.scan
  • Follow-up: autolens_profiling INSTRUMENTS dict wiring transformer_chunk_size per instrument (sma=None, alma=None, alma_high=1_000_000) — once that lands, re-run the 4 blocked alma_high SLURM submits

🤖 Generated with Claude Code

nufftax's spread-step gather buffer scales as 2 * N_vis * nspread^2 *
dtype_size. At 5M+ visibilities (nspread=14, complex64) the single
gather allocation hits 15.7 GB, exceeding A100 80 GB headroom once
the other intermediates are accounted for. This blocks ALMA-high-class
visibility counts on GPU.

Adds an opt-in `chunk_size: Optional[int] = None` parameter to
TransformerNUFFT.__init__. When set, _forward_native and image_from
split the visibility axis into fixed-size chunks and iterate:

- JAX path: jax.lax.scan with dynamic_slice over a padded uv axis.
  Output for forward is concatenated chunks (trimmed to K); adjoint
  accumulates per-chunk images additively into a fixed (N_y, N_x)
  buffer. Using lax.scan rather than a Python for loop keeps the
  compiled HLO graph bounded regardless of n_chunks — critical for
  JIT compile time at large N_vis / chunk_size ratios.

- NumPy path: plain Python loop with np.concatenate / additive image
  accumulator. Same algorithm, no JIT concerns.

Default chunk_size=None preserves the existing one-shot behaviour for
small-N callers (SMA-class datasets at 190 visibilities), so there is
no per-call overhead regression for existing code.

transform_mapping_matrix is intentionally NOT chunked in this PR. The
sparse-operator runtime path never calls it per-likelihood (the W-Tilde
sparse formalism replaces it), and the batched (n_src, N_vis, nspread^2)
gather buffer needs a separate chunk-size analysis with n_src
factored in. Flagged as a follow-up.

5 new tests assert:
- chunk_size <= 0 raises ValueError
- chunked forward / adjoint match unchunked at rtol=1e-6 (numpy)
- chunked forward / adjoint match unchunked at rtol=1e-6 (JAX)
- jax.jit of chunked forward traces cleanly via lax.scan

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
@Jammy2211 Jammy2211 merged commit b2a3a75 into main May 22, 2026
6 checks passed
@Jammy2211 Jammy2211 deleted the feature/nufft-chunking branch May 22, 2026 17:44
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant