fix(jax): route concrete kwarg attrs to aux_data in pytree flatten#1221
Merged
fix(jax): route concrete kwarg attrs to aux_data in pytree flatten#1221
Conversation
`_build_instance_pytree_funcs` previously classified attributes from a single model at registration time and captured a fixed name list in the flatten closure. Two problems: 1. `model.direct_argument_names` / `direct_instance_tuples` miss concrete non-Prior kwargs (e.g. `af.Model(Galaxy, pixelization=<Pixelization>)`), which are stored directly on `model.__dict__` but not typed as Prior. These slipped into `children` and became JAX tracers under JIT, so `isinstance(x, Pixelization)` inside `FitImaging` returned False and `fit.inversion` silently resolved to None. 2. Classes using `**kwargs` (like Galaxy) produce instances with different attribute sets per model (lens: bulge/mass/shear; source: pixelization) while sharing one `cls`. JAX registers flatten per-class, so the first model's captured names fail `getattr` on later instances. Fix: maintain a shared `_CLASS_FIELD_CLASSIFIERS[cls]` populated across all `register_model` walks. `flatten` now iterates `vars(instance)` at call time and consults the classifier, defaulting unknown attrs to constant (aux_data) — safer than tracing an unknown object. Regression test `test_register_model_keeps_kwarg_constants_static` exercises the Pixelization-style pattern: a concrete kwarg must remain identity-equal to the original object under JIT and pass `isinstance`. Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
Collaborator
Author
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
Fixes two bugs in
autofit.jax._build_instance_pytree_funcsthat surfaced when migratingjax_profiling/imaging/pixelization.pyto pass a pytreeModelInstanceinto JIT/vmap directly (follow-up to PyAutoLabs/autolens_workspace_developer#10, the same work that produced PyAutoFit#1220 for MGE):Galaxy(pixelization=<concrete Pixelization>)attribute is stored directly onmodel.__dict__but is not typed asPrior/AbstractPriorModel. The old classification path (direct_argument_names/direct_instance_tuples) never saw it, so it was routed intochildreninstead ofaux_data. Underjax.jitthis turned the pixelization into a tracer, causingisinstance(x, aa.Pixelization)insideFitImagingto silently returnFalseandfit.inversionto resolve toNone.**kwargsclasses shared a broken flatten.Galaxy.__init__(self, redshift, **kwargs)produces instances with different attribute sets per-model (lens:bulge/mass/shear; source:pixelization). JAX registers one flatten per class, so the first model's captured name list blew up withAttributeError: 'Galaxy' object has no attribute 'bulge'when flattening the source instance.API Changes
None — internal changes only.
enable_pytrees()andregister_model(model)signatures and semantics are unchanged; these are pure correctness fixes to the internal flatten/unflatten closures registered withjax.tree_util.Test Plan
test_register_model_keeps_kwarg_constants_static— exercises the exactGalaxy(pixelization=...)pattern: a concrete kwarg must stay identity-equal to the original object under JIT and passisinstance.test_register_model_keeps_constants_static(Galaxy-redshift-style constant) still passes.pytest test_autofit/ -x→ 1224 passed.jax_profiling/imaging/pixelization.pyruns cleanly, both correctness assertions pass (step-by-steplog_evidencematches reference withinrtol=1e-4; vmap batch of 3 matches single-JIT result).Full API Changes (for automation & release notes)
Removed
None.
Added
None.
Renamed
None.
Changed Signature
None.
Changed Behaviour
autofit.jax._build_instance_pytree_funcsis now instance-driven.flatteniteratesvars(instance)at call time against a shared_CLASS_FIELD_CLASSIFIERS[cls]dict populated across allregister_modelwalks (rather than closing over a fixed name list captured from a single model at registration time). Attributes unknown to the classifier default toaux_data(safer than tracing an unknown object).Migration
No migration required. Callers of
autofit.jax.enable_pytrees/autofit.jax.register_modelsee only a behavioural fix — patterns that previously silently produced broken JAX traces now work correctly.🤖 Generated with Claude Code