Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 6 additions & 4 deletions autofit/jax/pytrees.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ def register_model(model) -> bool:
return False

from autofit.mapper.prior.abstract import Prior
from autofit.mapper.prior.tuple_prior import TuplePrior
from autofit.mapper.prior_model.prior_model import Model
from autofit.mapper.prior_model.collection import Collection
from autofit.mapper.prior_model.abstract import AbstractPriorModel
Expand All @@ -83,7 +84,7 @@ def _walk(node):
cls = node.cls
classifier = _CLASS_FIELD_CLASSIFIERS.setdefault(cls, {})
for name, value in node.items():
is_dynamic = isinstance(value, (Prior, AbstractPriorModel))
is_dynamic = isinstance(value, (Prior, AbstractPriorModel, TuplePrior))
# setdefault: earliest classification wins. Different models
# sharing the same cls (e.g. lens vs source Galaxy) may
# declare different attribute sets; we accumulate them all.
Expand Down Expand Up @@ -123,9 +124,10 @@ def _build_instance_pytree_funcs(cls):
Each attribute is classified as:

* **Dynamic** (prior-derived): the corresponding ``Model`` attribute was
a ``Prior`` or ``AbstractPriorModel``. Resolved to concrete numbers
(or nested instances) per sampled point, so it becomes a JAX child leaf
and gets traced under ``jax.jit``.
a ``Prior``, ``AbstractPriorModel``, or ``TuplePrior``. Resolved to
concrete numbers (or nested instances, or a Python tuple of scalars in
the ``TuplePrior`` case) per sampled point, so it becomes a JAX child
leaf (or tuple of leaves) and gets traced under ``jax.jit``.
* **Constant**: everything else — a fixed ``redshift=0.5``, or a concrete
non-prior kwarg like ``Galaxy(pixelization=<Pixelization>)``. Goes into
``aux_data`` so it stays as the original Python object inside a trace.
Expand Down
44 changes: 44 additions & 0 deletions test_autofit/jax/test_enable_pytrees.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,50 @@ def use_marker_isinstance(inst):
assert float(result) == pytest.approx(2.0 * instance.scale)


def test_register_model_traces_tuple_prior_attributes():
"""``TuplePrior``-backed attributes must be routed into JAX children so
gradients flow through paired priors like ``centre=(x, y)`` and
``ell_comps=(e1, e2)``.

Mirrors real-world MGE / Isothermal / ExternalShear usage where the
paired priors are the majority of the free parameters. Prior to the
fix, ``TuplePrior`` failed the ``(Prior, AbstractPriorModel)``
isinstance check in ``register_model``, so the resolved tuple was
frozen in ``aux_data`` and ``jax.value_and_grad`` returned gradients
only for the non-tuple attributes.
"""
class Twin:
def __init__(self, centre, amplitude):
self.centre = centre
self.amplitude = amplitude

model = af.Model(
Twin,
centre=af.TuplePrior(
centre_0=af.GaussianPrior(mean=0.5, sigma=1.0),
centre_1=af.GaussianPrior(mean=-0.5, sigma=1.0),
),
amplitude=af.GaussianPrior(mean=1.0, sigma=1.0),
)
register_model(model)
instance = model.instance_from_prior_medians()
params_tree = jax.tree_util.tree_map(jnp.asarray, instance)

leaves = jax.tree_util.tree_leaves(params_tree)
assert len(leaves) == 3 # centre[0], centre[1], amplitude

def loss(inst):
cx, cy = inst.centre
return cx * cx + cy * cy + inst.amplitude

_, grad = jax.value_and_grad(loss)(params_tree)
flat_grad = jnp.concatenate(
[jnp.asarray(l).ravel() for l in jax.tree_util.tree_leaves(grad)]
)
assert jnp.all(jnp.isfinite(flat_grad))
assert flat_grad.size == 3


def test_enable_pytrees_idempotent():
assert enable_pytrees() is True
assert enable_pytrees() is True
Expand Down
Loading