diff --git a/autoarray/__init__.py b/autoarray/__init__.py index 8c2319272..9a5d2369e 100644 --- a/autoarray/__init__.py +++ b/autoarray/__init__.py @@ -60,6 +60,7 @@ from .mask.mask_2d import Mask2D from .operators.transformer import TransformerDFT from .operators.transformer import TransformerNUFFT +from .operators.transformer import TransformerNUFFTPyNUFFT from .operators.over_sampling.decorator import over_sample from .operators.contour import Grid2DContour from .layout.layout import Layout1D diff --git a/autoarray/dataset/interferometer/dataset.py b/autoarray/dataset/interferometer/dataset.py index 0126e13bb..894b979c3 100644 --- a/autoarray/dataset/interferometer/dataset.py +++ b/autoarray/dataset/interferometer/dataset.py @@ -251,6 +251,29 @@ def apply_sparse_operator( enabling efficient pixelized source reconstruction via the sparse linear algebra formalism. """ + if isinstance(self.transformer, TransformerNUFFT): + raise NotImplementedError( + "\n--------------------\n" + "`apply_sparse_operator` is not yet supported with the default " + "`TransformerNUFFT` (nufftax-backed) transformer.\n\n" + "The sparse-operator path consumes the dirty image returned by " + "`transformer.image_from(use_adjoint_scaling=True)` together with " + "the NUFFT precision operator; their relative scale matters. The " + "new `TransformerNUFFT` returns the strict mathematical adjoint " + "(matching `TransformerDFT`), whereas the legacy pynufft adjoint " + "applies an internal Kaiser-Bessel kernel deconvolution. The two " + "scales differ by a non-constant factor, so feeding the new " + "dirty image into the existing sparse-operator solver would " + "silently give wrong answers.\n\n" + "Workarounds:\n" + " - Build the dataset with `transformer_class=TransformerDFT` " + "(the JAX-likelihood scripts do this today), or\n" + " - Build the dataset with " + "`transformer_class=TransformerNUFFTPyNUFFT` to keep the legacy " + "pynufft adjoint scale (requires `pip install pynufft`).\n" + "----------------------" + ) + if nufft_precision_operator is None: logger.info( diff --git a/autoarray/operators/transformer.py b/autoarray/operators/transformer.py index 582f6c810..3e9e80cd0 100644 --- a/autoarray/operators/transformer.py +++ b/autoarray/operators/transformer.py @@ -23,16 +23,36 @@ class NUFFTPlaceholder: from autoarray.operators import transformer_util +try: + import nufftax as _nufftax +except ModuleNotFoundError: + _nufftax = None + + def pynufft_exception(): raise ModuleNotFoundError( "\n--------------------\n" - "You are attempting to perform interferometer analysis.\n\n" + "You are attempting to perform interferometer analysis with the legacy " + "pynufft-backed `TransformerNUFFTPyNUFFT`.\n\n" "However, the optional library PyNUFFT (https://github.com/jyhmiinlin/pynufft) is not installed.\n\n" "Install it via the command `pip install pynufft==2022.2.2`.\n\n" "----------------------" ) +def nufftax_exception(): + raise ModuleNotFoundError( + "\n--------------------\n" + "You are attempting to perform interferometer analysis with the default " + "JAX-native `TransformerNUFFT`.\n\n" + "However, the optional library nufftax (https://github.com/GragasLab/nufftax) is not installed.\n\n" + "Install it via the command `pip install nufftax`.\n\n" + "If you want to use the legacy pynufft backend instead, pass " + "`transformer_class=TransformerNUFFTPyNUFFT` and install pynufft.\n\n" + "----------------------" + ) + + class TransformerDFT: def __init__( self, @@ -175,13 +195,18 @@ def transform_mapping_matrix(self, mapping_matrix: np.ndarray, xp=np) -> np.ndar ) -class TransformerNUFFT(NUFFT_cpu): +class TransformerNUFFTPyNUFFT(NUFFT_cpu): def __init__( self, uv_wavelengths: np.ndarray, real_space_mask: Mask2D, xp=np, **kwargs ): """ Performs the Non-Uniform Fast Fourier Transform (NUFFT) for interferometric image reconstruction. + Legacy pynufft-backed transformer. The default `TransformerNUFFT` is now backed by `nufftax` + (JAX-native, differentiable, ~zero gridding error) — this class is retained so users who depend + on pynufft's specific gridding behaviour can opt in by passing + `transformer_class=TransformerNUFFTPyNUFFT`. + This transformer uses the PyNUFFT library to efficiently compute the Fourier transform of an image defined on a regular real-space grid to a set of non-uniform uv-plane (Fourier space) coordinates, as is typical in radio interferometry. @@ -226,7 +251,7 @@ def __init__( if isinstance(self, NUFFTPlaceholder): pynufft_exception() - super(TransformerNUFFT, self).__init__() + super(TransformerNUFFTPyNUFFT, self).__init__() self.uv_wavelengths = uv_wavelengths self.real_space_mask = real_space_mask @@ -469,3 +494,212 @@ def transform_mapping_matrix(self, mapping_matrix: np.ndarray, xp=np) -> np.ndar ) return transformed_mapping_matrix + + +class TransformerNUFFT: + def __init__( + self, + uv_wavelengths: np.ndarray, + real_space_mask: Mask2D, + eps: float = 1e-12, + xp=np, + **kwargs, + ): + """ + JAX-native Non-Uniform FFT for image -> visibilities, backed by `nufftax`. + + This is the default `TransformerNUFFT` in PyAutoArray. It uses the + `nufftax` library (https://github.com/GragasLab/nufftax), a pure-JAX + NUFFT implementation that supports `jax.jit`, `jax.grad`, and + `jax.vmap`. It replaces the legacy `TransformerNUFFTPyNUFFT` (which + wraps the non-differentiable `pynufft` library) as the default backend. + + Convention recipe (matches `TransformerDFT` to ~1e-13 relative across + odd/even/non-square image sizes): + + image_flipped = image[::-1, :] + x = 2 * pi * u_lambda * pixel_scale_rad + y = 2 * pi * v_lambda * pixel_scale_rad + offset_x = 0.5 if N_x is even else 0.0 + offset_y = 0.5 if N_y is even else 0.0 + shift = exp(-i * (offset_x * x + offset_y * y)) + visibilities = nufftax.nufft2d2(x, y, image_flipped, eps, -1) * shift + + The `shift` factor is the half-pixel correction between autoarray's + grid centre at index `(N - 1) / 2` and nufftax's mode-0 at index + `N // 2`; pynufft applies this internally, nufftax does not. + + Parameters + ---------- + uv_wavelengths + The (u, v) coordinates of the measured visibilities in wavelengths, + shape `(n_vis, 2)`. + real_space_mask + The 2D mask defining the real-space image grid. + eps + Requested NUFFT precision passed to nufftax. Defaults to `1e-12` + (effectively machine precision); relax to `1e-9` or `1e-6` for + faster execution if marginal accuracy is acceptable. + xp + Accepted for signature compatibility with the legacy class; not + stored. The active backend is selected per-call via the `xp` + argument to `visibilities_from` / `image_from`. + + Attributes + ---------- + grid + The real-space pixel grid in radians (computed from the mask). + total_visibilities + Number of measured visibilities. + total_image_pixels + Number of unmasked pixels in the image grid. + adjoint_scaling + Scaling factor available for callers who want to apply an + optional normalisation to the adjoint output. Provided for + parity with the legacy class. + """ + from astropy import units + + if _nufftax is None: + nufftax_exception() + + self.uv_wavelengths = uv_wavelengths.astype("float") + self.real_space_mask = real_space_mask + self.grid = Grid2D.from_mask(mask=self.real_space_mask).in_radians + self.eps = eps + self.native_index_for_slim_index = copy.copy( + real_space_mask.derive_indexes.native_for_slim.astype("int") + ) + + pixel_scale_rad = self.grid.pixel_scales[0] * units.arcsec.to(units.rad) + # nufft2d2 frequency arguments: + # x is paired with the column-axis mode (image x) + # y is paired with the row-axis mode (image y) + # Both must lie in [-pi, pi); the 2*pi*Δ_rad scaling makes uv_lambda + # land in that range for any sane uv-coverage. + self._x = 2.0 * np.pi * self.uv_wavelengths[:, 0] * pixel_scale_rad + self._y = 2.0 * np.pi * self.uv_wavelengths[:, 1] * pixel_scale_rad + + n_y, n_x = self.real_space_mask.shape_native + offset_x = 0.5 if n_x % 2 == 0 else 0.0 + offset_y = 0.5 if n_y % 2 == 0 else 0.0 + self._shift = np.exp(-1j * (offset_x * self._x + offset_y * self._y)) + + self.total_visibilities = uv_wavelengths.shape[0] + self.total_image_pixels = real_space_mask.pixels_in_mask + self.adjoint_scaling = (2.0 * n_y) * (2.0 * n_x) + + def _forward_native(self, image_native_2d, xp=np): + """Run nufft2d2 on a 2D native-shape image array, returning visibilities.""" + if xp.__name__.startswith("jax"): + import jax.numpy as jnp + + img = jnp.asarray(image_native_2d)[::-1, :].astype(jnp.complex128) + x = jnp.asarray(self._x) + y = jnp.asarray(self._y) + shift = jnp.asarray(self._shift) + return _nufftax.nufft2d2(x, y, img, self.eps, -1) * shift + + img = image_native_2d[::-1, :].astype(np.complex128) + out = _nufftax.nufft2d2(self._x, self._y, img, self.eps, -1) * self._shift + return np.array(np.asarray(out)) + + def visibilities_from(self, image, xp=np) -> Visibilities: + """ + Forward NUFFT: real-space image -> visibilities at the configured uv points. + + For numpy callers (`xp=np`) the result is materialised back to numpy + before being wrapped in `Visibilities`. For JAX callers (`xp=jnp`) + the result stays as a `jax.Array` so it can flow through `jax.jit` + / `jax.grad` / `jax.vmap` without device round-trips. + """ + if xp.__name__.startswith("jax"): + import jax.numpy as jnp + + image_native = jnp.zeros(image.mask.shape, dtype=image.dtype) + image_native = image_native.at[image.mask.slim_to_native_tuple].set( + image.array + ) + else: + image_native = image.native.array + + return Visibilities(visibilities=self._forward_native(image_native, xp=xp)) + + def image_from( + self, + visibilities: Visibilities, + use_adjoint_scaling: bool = False, + xp=np, + ) -> Array2D: + """ + Adjoint NUFFT: visibilities -> real-space (dirty) image. + + Implemented as `nufftax.nufft2d1` with `conj(shift)` applied to the + visibilities and a final row-flip to return to autoarray's native + orientation. The real part is taken to discard imaginary residue, + matching the legacy class' behaviour. + + Note that this is the **mathematical adjoint** of `visibilities_from`, + with no kernel deconvolution applied. The dirty image therefore + differs in absolute scale from the legacy `TransformerNUFFTPyNUFFT` + adjoint (which applies pynufft's internal IFFT and kernel + deconvolution). The structure of the dirty image is the same, and + the values match `TransformerDFT.image_from` exactly. + + **Scale-sensitive callers**: `Interferometer.apply_sparse_operator` + consumes the dirty-image scale together with a precision operator; it + is currently incompatible with this class and raises + `NotImplementedError`. Use `TransformerDFT` or + `TransformerNUFFTPyNUFFT` if you need the sparse-operator path. + """ + n_y, n_x = self.real_space_mask.shape_native + n_modes = (n_x, n_y) # nufftax wants (n1, n2) = (N_x, N_y) + + if xp.__name__.startswith("jax"): + import jax.numpy as jnp + + x = jnp.asarray(self._x) + y = jnp.asarray(self._y) + shift_conj = jnp.asarray(np.conj(self._shift)) + c = jnp.asarray(visibilities.array) * shift_conj + f = _nufftax.nufft2d1(x, y, c, n_modes, self.eps, +1) + image = jnp.real(f)[::-1, :] + else: + c = visibilities.array * np.conj(self._shift) + f = _nufftax.nufft2d1(self._x, self._y, c, n_modes, self.eps, +1) + image = np.array(np.asarray(f)[::-1, :].real) + + if use_adjoint_scaling: + image = image * self.adjoint_scaling + + return Array2D(values=image, mask=self.real_space_mask) + + def transform_mapping_matrix(self, mapping_matrix, xp=np): + """ + Apply the forward NUFFT to each column of a mapping matrix. + + Each column is scattered back to the native 2D image grid using the + mask's `slim_to_native_tuple`, then passed through `_forward_native`. + """ + n_uv = self.uv_wavelengths.shape[0] + n_src = mapping_matrix.shape[1] + slim_to_native = self.real_space_mask.slim_to_native_tuple + native_shape = self.real_space_mask.shape_native + + if xp.__name__.startswith("jax"): + import jax.numpy as jnp + + out = jnp.zeros((n_uv, n_src), dtype=jnp.complex128) + for k in range(n_src): + image_2d = jnp.zeros(native_shape, dtype=mapping_matrix.dtype) + image_2d = image_2d.at[slim_to_native].set(mapping_matrix[:, k]) + vis = self._forward_native(image_2d, xp=xp) + out = out.at[:, k].set(vis) + return out + + out = np.zeros((n_uv, n_src), dtype=np.complex128) + for k in range(n_src): + image_2d = np.zeros(native_shape, dtype=mapping_matrix.dtype) + image_2d[slim_to_native] = mapping_matrix[:, k] + out[:, k] = self._forward_native(image_2d, xp=xp) + return out diff --git a/autoarray/type.py b/autoarray/type.py index a2bee4832..de2d070d6 100644 --- a/autoarray/type.py +++ b/autoarray/type.py @@ -33,8 +33,9 @@ from autoarray.operators.transformer import TransformerDFT from autoarray.operators.transformer import TransformerNUFFT +from autoarray.operators.transformer import TransformerNUFFTPyNUFFT -Transformer = Union[TransformerDFT, TransformerNUFFT] +Transformer = Union[TransformerDFT, TransformerNUFFT, TransformerNUFFTPyNUFFT] from autoarray.layout.region import Region1D diff --git a/pyproject.toml b/pyproject.toml index 19cf2a441..da3397b53 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -56,11 +56,12 @@ jax = ["autoconf[jax]"] optional = [ "autoarray[jax]", "numba", + "nufftax", "pynufft", "tensorflow-probability==0.25.0" ] test = ["pytest"] -dev = ["pytest", "black", "numba", "pynufft==2022.2.2"] +dev = ["pytest", "black", "numba", "nufftax", "pynufft==2022.2.2"] [tool.pytest.ini_options] testpaths = ["test_autoarray"] diff --git a/test_autoarray/operators/test_transformer.py b/test_autoarray/operators/test_transformer.py index 13a072f3b..66f377cd6 100644 --- a/test_autoarray/operators/test_transformer.py +++ b/test_autoarray/operators/test_transformer.py @@ -70,6 +70,29 @@ def test__nufft__visibilities_from__all_ones_image__first_visibility_matches_exp visibilities_nufft = transformer_nufft.visibilities_from(image=image.native) + # nufftax-backed forward NUFFT: matches the analytic DFT to machine precision. + # For an all-ones image the visibility at any uv is N_y * N_x = 25. + assert visibilities_nufft[0] == pytest.approx(25.0 + 0.0j, 1.0e-7) + + +def test__nufft_pynufft__visibilities_from__all_ones_image__first_visibility_matches_expected(): + + uv_wavelengths = np.array([[0.2, 1.0], [0.5, 1.1], [0.8, 1.2]]) + real_space_mask = aa.Mask2D.all_false(shape_native=(5, 5), pixel_scales=0.005) + + image = aa.Array2D.ones( + shape_native=(5, 5), + pixel_scales=0.005, + ) + + transformer_nufft = aa.TransformerNUFFTPyNUFFT( + uv_wavelengths=uv_wavelengths, real_space_mask=real_space_mask + ) + + visibilities_nufft = transformer_nufft.visibilities_from(image=image.native) + + # Legacy pynufft has a small gridding-kernel error at N=5; expected value + # encodes that error and is retained for backwards compatibility. assert visibilities_nufft[0] == pytest.approx(25.02317617953263 + 0.0j, 1.0e-7) @@ -84,6 +107,24 @@ def test__nufft__image_from__visibilities_7__first_three_image_pixels_match_expe image = transformer.image_from(visibilities=visibilities_7) + # nufftax adjoint matches `TransformerDFT.image_from` exactly (no kernel + # deconvolution applied; this is the mathematical adjoint of the forward). + assert image[0:3] == pytest.approx([-1.49022481, -0.22395855, -0.45588535], 1.0e-4) + + +def test__nufft_pynufft__image_from__visibilities_7__first_three_image_pixels_match_expected( + visibilities_7, uv_wavelengths_7x2, mask_2d_7x7 +): + + transformer = aa.TransformerNUFFTPyNUFFT( + uv_wavelengths=uv_wavelengths_7x2, + real_space_mask=mask_2d_7x7, + ) + + image = transformer.image_from(visibilities=visibilities_7) + + # Legacy pynufft adjoint includes internal kernel deconvolution and IFFT + # normalisation; expected values reflect that behaviour. assert image[0:3] == pytest.approx([0.00726546, 0.01149121, 0.01421022], 1.0e-4) @@ -102,6 +143,26 @@ def test__nufft__transform_mapping_matrix__ones_mapping_matrix__first_element_ma mapping_matrix=mapping_matrix ) + # nufftax-backed forward over a mapping matrix column reduces to the + # all-ones forward NUFFT case; equals N_y * N_x = 25 exactly. + assert transformed_mapping_matrix_nufft[0, 0] == pytest.approx(25.0 + 0.0j, 1.0e-4) + + +def test__nufft_pynufft__transform_mapping_matrix__ones_mapping_matrix__first_element_matches_expected(): + uv_wavelengths = np.array([[0.2, 1.0], [0.5, 1.1], [0.8, 1.2]]) + + mapping_matrix = np.ones(shape=(25, 3)) + + real_space_mask = aa.Mask2D.all_false(shape_native=(5, 5), pixel_scales=0.005) + + transformer_nufft = aa.TransformerNUFFTPyNUFFT( + uv_wavelengths=uv_wavelengths, real_space_mask=real_space_mask + ) + + transformed_mapping_matrix_nufft = transformer_nufft.transform_mapping_matrix( + mapping_matrix=mapping_matrix + ) + assert transformed_mapping_matrix_nufft[0, 0] == pytest.approx( 25.02317 + 0.0j, 1.0e-4 )