diff --git a/autogalaxy/profiles/mass/total/jax_utils.py b/autogalaxy/profiles/mass/total/jax_utils.py index 94c2f139..d04691ea 100644 --- a/autogalaxy/profiles/mass/total/jax_utils.py +++ b/autogalaxy/profiles/mass/total/jax_utils.py @@ -32,7 +32,7 @@ def omega(eiphi, slope, factor, n_terms=20, xp=np): be sufficient most of the time) """ - from functools import partial + from jax.tree_util import Partial as partial import jax scan = jax.jit(jax.lax.scan, static_argnames=("length", "reverse", "unroll"))