From 5fd0f434cd6bbb4daf1b1e5c86782b6dbc13343e Mon Sep 17 00:00:00 2001 From: Jammy2211 Date: Fri, 17 Apr 2026 13:39:57 +0100 Subject: [PATCH] fix(jax): trace TuplePrior paired priors as pytree children MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit register_model classified attributes as dynamic (traced) vs constant (aux_data) via `isinstance(value, (Prior, AbstractPriorModel))`. TuplePrior is a ModelObject — neither a Prior nor an AbstractPriorModel — so paired-prior attributes like `centre=(x, y)` and `ell_comps=(e1, e2)` were routed to aux_data and frozen inside JIT, invisible to jax.value_and_grad. Impact: a 15-free-parameter MGE lens model produced only 3 JAX leaves (einstein_radius, gamma_1, gamma_2). The 12 tuple-valued priors could not receive gradients, and downstream NNLS backward passes produced NaN. Fix: extend the isinstance tuple with TuplePrior so the resolved Python tuple ends up in children. JAX's built-in tuple pytree handling recurses into the scalar leaves — no new registration needed for TuplePrior itself. Adds test_register_model_traces_tuple_prior_attributes as a regression test covering a TuplePrior-backed attribute through jax.value_and_grad. 
--- autofit/jax/pytrees.py | 10 +++--- test_autofit/jax/test_enable_pytrees.py | 44 +++++++++++++++++++++++++ 2 files changed, 50 insertions(+), 4 deletions(-) diff --git a/autofit/jax/pytrees.py b/autofit/jax/pytrees.py index 177d90e31..cd9cdfde1 100644 --- a/autofit/jax/pytrees.py +++ b/autofit/jax/pytrees.py @@ -74,6 +74,7 @@ def register_model(model) -> bool: return False from autofit.mapper.prior.abstract import Prior + from autofit.mapper.prior.tuple_prior import TuplePrior from autofit.mapper.prior_model.prior_model import Model from autofit.mapper.prior_model.collection import Collection from autofit.mapper.prior_model.abstract import AbstractPriorModel @@ -83,7 +84,7 @@ def _walk(node): cls = node.cls classifier = _CLASS_FIELD_CLASSIFIERS.setdefault(cls, {}) for name, value in node.items(): - is_dynamic = isinstance(value, (Prior, AbstractPriorModel)) + is_dynamic = isinstance(value, (Prior, AbstractPriorModel, TuplePrior)) # setdefault: earliest classification wins. Different models # sharing the same cls (e.g. lens vs source Galaxy) may # declare different attribute sets; we accumulate them all. @@ -123,9 +124,10 @@ def _build_instance_pytree_funcs(cls): Each attribute is classified as: * **Dynamic** (prior-derived): the corresponding ``Model`` attribute was - a ``Prior`` or ``AbstractPriorModel``. Resolved to concrete numbers - (or nested instances) per sampled point, so it becomes a JAX child leaf - and gets traced under ``jax.jit``. + a ``Prior``, ``AbstractPriorModel``, or ``TuplePrior``. Resolved to + concrete numbers (or nested instances, or a Python tuple of scalars in + the ``TuplePrior`` case) per sampled point, so it becomes a JAX child + leaf (or tuple of leaves) and gets traced under ``jax.jit``. * **Constant**: everything else — a fixed ``redshift=0.5``, or a concrete non-prior kwarg like ``Galaxy(pixelization=)``. Goes into ``aux_data`` so it stays as the original Python object inside a trace. 
diff --git a/test_autofit/jax/test_enable_pytrees.py b/test_autofit/jax/test_enable_pytrees.py index 6d5200b96..076f2dd3d 100644 --- a/test_autofit/jax/test_enable_pytrees.py +++ b/test_autofit/jax/test_enable_pytrees.py @@ -125,6 +125,50 @@ def use_marker_isinstance(inst): assert float(result) == pytest.approx(2.0 * instance.scale) +def test_register_model_traces_tuple_prior_attributes(): + """``TuplePrior``-backed attributes must be routed into JAX children so + gradients flow through paired priors like ``centre=(x, y)`` and + ``ell_comps=(e1, e2)``. + + Mirrors real-world MGE / Isothermal / ExternalShear usage where the + paired priors are the majority of the free parameters. Prior to the + fix, ``TuplePrior`` failed the ``(Prior, AbstractPriorModel)`` + isinstance check in ``register_model``, so the resolved tuple was + frozen in ``aux_data`` and ``jax.value_and_grad`` returned gradients + only for the non-tuple attributes. + """ + class Twin: + def __init__(self, centre, amplitude): + self.centre = centre + self.amplitude = amplitude + + model = af.Model( + Twin, + centre=af.TuplePrior( + centre_0=af.GaussianPrior(mean=0.5, sigma=1.0), + centre_1=af.GaussianPrior(mean=-0.5, sigma=1.0), + ), + amplitude=af.GaussianPrior(mean=1.0, sigma=1.0), + ) + register_model(model) + instance = model.instance_from_prior_medians() + params_tree = jax.tree_util.tree_map(jnp.asarray, instance) + + leaves = jax.tree_util.tree_leaves(params_tree) + assert len(leaves) == 3 # centre[0], centre[1], amplitude + + def loss(inst): + cx, cy = inst.centre + return cx * cx + cy * cy + inst.amplitude + + _, grad = jax.value_and_grad(loss)(params_tree) + flat_grad = jnp.concatenate( + [jnp.asarray(l).ravel() for l in jax.tree_util.tree_leaves(grad)] + ) + assert jnp.all(jnp.isfinite(flat_grad)) + assert flat_grad.size == 3 + + def test_enable_pytrees_idempotent(): assert enable_pytrees() is True assert enable_pytrees() is True