TransformerNUFFT: add chunk_size knob to cap nufftax gather buffer #330
Merged
nufftax's spread-step gather buffer scales as 2 * N_vis * nspread^2 * dtype_size. At 5M+ visibilities (nspread=14, complex64) the single gather allocation hits 15.7 GB, exceeding A100 80 GB headroom once the other intermediates are accounted for. This blocks ALMA-high-class visibility counts on GPU.

Adds an opt-in `chunk_size: Optional[int] = None` parameter to TransformerNUFFT.__init__. When set, _forward_native and image_from split the visibility axis into fixed-size chunks and iterate:

- JAX path: jax.lax.scan with dynamic_slice over a padded uv axis. Output for forward is concatenated chunks (trimmed to K); adjoint accumulates per-chunk images additively into a fixed (N_y, N_x) buffer. Using lax.scan rather than a Python for loop keeps the compiled HLO graph bounded regardless of n_chunks, which is critical for JIT compile time at large N_vis / chunk_size ratios.
- NumPy path: plain Python loop with np.concatenate / additive image accumulator. Same algorithm, no JIT concerns.

Default chunk_size=None preserves the existing one-shot behaviour for small-N callers (SMA-class datasets at 190 visibilities), so there is no per-call overhead regression for existing code.

transform_mapping_matrix is intentionally NOT chunked in this PR. The sparse-operator runtime path never calls it per-likelihood (the W-Tilde sparse formalism replaces it), and the batched (n_src, N_vis, nspread^2) gather buffer needs a separate chunk-size analysis with n_src factored in. Flagged as a follow-up.

5 new tests assert:

- chunk_size <= 0 raises ValueError
- chunked forward / adjoint match unchunked at rtol=1e-6 (numpy)
- chunked forward / adjoint match unchunked at rtol=1e-6 (JAX)
- jax.jit of chunked forward traces cleanly via lax.scan

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
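
The NumPy-path algorithm described above (slice the visibility axis, concatenate forward outputs, sum adjoint images) can be illustrated with a minimal sketch. This is not the PR's code: the function names `nudft_forward`, `nudft_adjoint`, `chunk_slices`, `forward_chunked` and `adjoint_chunked` are hypothetical, and a direct (slow) non-uniform DFT stands in for the real nufftax calls so that the example is self-contained. It also validates `chunk_size` per call, whereas the PR describes it as a `TransformerNUFFT.__init__` parameter.

```python
import numpy as np


def nudft_forward(image, uv):
    """Direct non-uniform DFT, image (N_y, N_x) -> K visibilities.

    Hypothetical stand-in for the NUFFT; uv[:, 0] = u and uv[:, 1] = v,
    both in radians per pixel.
    """
    n_y, n_x = image.shape
    y = np.arange(n_y) - n_y // 2
    x = np.arange(n_x) - n_x // 2
    phase_y = np.exp(-1j * np.outer(uv[:, 1], y))  # (K, N_y)
    phase_x = np.exp(-1j * np.outer(uv[:, 0], x))  # (K, N_x)
    return np.einsum("ky,yx,kx->k", phase_y, image, phase_x)


def nudft_adjoint(vis, uv, shape):
    """Adjoint of nudft_forward: K visibilities -> image (N_y, N_x)."""
    n_y, n_x = shape
    y = np.arange(n_y) - n_y // 2
    x = np.arange(n_x) - n_x // 2
    phase_y = np.exp(1j * np.outer(uv[:, 1], y))
    phase_x = np.exp(1j * np.outer(uv[:, 0], x))
    return np.einsum("ky,k,kx->yx", phase_y, vis, phase_x)


def chunk_slices(n_vis, chunk_size):
    """Slices covering range(n_vis) in pieces of at most chunk_size."""
    if chunk_size <= 0:
        raise ValueError("chunk_size must be a positive integer")
    return [slice(s, min(s + chunk_size, n_vis))
            for s in range(0, n_vis, chunk_size)]


def forward_chunked(image, uv, chunk_size=None):
    """Forward transform; chunk outputs are concatenated along K."""
    if chunk_size is None:  # one-shot path, unchanged behaviour
        return nudft_forward(image, uv)
    parts = [nudft_forward(image, uv[s])
             for s in chunk_slices(len(uv), chunk_size)]
    return np.concatenate(parts)


def adjoint_chunked(vis, uv, shape, chunk_size=None):
    """Adjoint transform; per-chunk images are summed into one buffer."""
    if chunk_size is None:
        return nudft_adjoint(vis, uv, shape)
    out = np.zeros(shape, dtype=np.complex128)
    for s in chunk_slices(len(uv), chunk_size):
        out += nudft_adjoint(vis[s], uv[s], shape)
    return out
```

Chunking is exact here because the forward transform is independent per visibility and the adjoint is a sum over visibilities, so only the peak intermediate size changes. The JAX path in the PR replaces the Python loop with `jax.lax.scan` over a padded uv axis, which this sketch does not attempt to reproduce.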
Summary
- Adds an opt-in `chunk_size: Optional[int] = None` to `TransformerNUFFT.__init__`. When set, splits the visibility axis into fixed-size chunks during both forward (`_forward_native`) and adjoint (`image_from`) NUFFT calls.
- JAX path uses `jax.lax.scan` with `dynamic_slice` over a padded uv axis, which keeps the compiled HLO graph bounded regardless of `n_chunks`. NumPy path uses a plain Python loop. Default `chunk_size=None` preserves the existing one-shot behaviour exactly (no regression for SMA-class callers).
- 5 new tests in `test_transformer.py`.

Motivation

nufftax's spread-step gather buffer is `2 × N_vis × nspread² × dtype_size`. At 5M visibilities with default `eps=1e-6` (`nspread=14`) and complex64 that's 15.7 GB for a single intermediate; with adjacent JAX intermediates an A100 80 GB tips over. This is the upstream blocker preventing ALMA-high-class visibility counts on GPU, affecting both the interferometer simulator (forward NUFFT) and the `apply_sparse_operator` setup path (adjoint NUFFT).

With `chunk_size = 1_000_000` the per-call gather buffer drops to ~3 GB regardless of total visibility count.

Scope notes
- `transform_mapping_matrix` is intentionally NOT chunked in this PR. The sparse-operator runtime path never calls it per-likelihood (the W-Tilde sparse formalism replaces it; see PR #329, "interferometer: enable sparse_operator for nufftax TransformerNUFFT", + autolens_profiling#22). The batched form is `(n_src, N_vis, nspread²)` and needs a separate chunk-size analysis with `n_src` factored in. Flagged as a follow-up.
- [...] `interp_2d_impl`, but the autoarray-side chunking is sufficient to unblock the immediate work.

Test plan

- `pytest test_autoarray/` (828 passed locally: 5 new tests on top of 823 baseline)
- `jax.jit` smoke trace via `lax.scan`
- [...] `transformer_chunk_size` per instrument (sma=None, alma=None, alma_high=1_000_000); once that lands, re-run the 4 blocked alma_high SLURM submits

🤖 Generated with Claude Code
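
As a quick check of the buffer sizes quoted under Motivation, the formula `2 × N_vis × nspread² × dtype_size` can be evaluated directly. The helper names `gather_buffer_bytes` and `per_call_vis` below are hypothetical and exist only for this arithmetic; the sketch assumes the formula as stated in this PR and that chunking caps the per-call visibility count at `chunk_size`.

```python
import numpy as np


def gather_buffer_bytes(n_vis, nspread, dtype=np.complex64):
    """Bytes implied by 2 * N_vis * nspread**2 * dtype_size."""
    return 2 * n_vis * nspread**2 * np.dtype(dtype).itemsize


def per_call_vis(n_vis, chunk_size=None):
    """Visibilities handled per NUFFT call; None means one-shot."""
    return n_vis if chunk_size is None else min(n_vis, chunk_size)


n_vis, nspread = 5_000_000, 14

# One-shot: the whole visibility axis goes through a single call.
full = gather_buffer_bytes(per_call_vis(n_vis), nspread)

# Chunked: each call sees at most chunk_size visibilities.
chunked = gather_buffer_bytes(per_call_vis(n_vis, 1_000_000), nspread)

print(f"one-shot: {full / 1e9:.2f} GB")    # one-shot: 15.68 GB
print(f"chunked:  {chunked / 1e9:.2f} GB")  # chunked:  3.14 GB
```

The chunked figure depends only on `chunk_size` and `nspread`, which is why the per-call buffer stays near 3 GB however large the total visibility count grows.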