File tree Expand file tree Collapse file tree
Expand file tree Collapse file tree Original file line number Diff line number Diff line change 55from abc import ABC
66from abc import abstractmethod
77import jax .numpy as jnp
8+ from jax ._src .tree_util import register_pytree_node
9+ from jax import Array
810
911from autoconf .fitsable import output_to_fits
1012
11- from autoarray .numpy_wrapper import register_pytree_node , Array
12-
1313from typing import TYPE_CHECKING
1414
1515if TYPE_CHECKING :
Original file line number Diff line number Diff line change 44
55from autoconf import cached_property
66
7- from autoarray . numpy_wrapper import register_pytree_node_class
7+ from jax . _src . tree_util import register_pytree_node_class
88from typing import TYPE_CHECKING
99
1010if TYPE_CHECKING :
Load Diff This file was deleted.
Original file line number Diff line number Diff line change 11import numpy as np
22import jax .numpy as jnp
33import jax
4+ from jax ._src .tree_util import register_pytree_node_class
45from typing import Union
56
67from autoconf import conf
1112
1213from autoarray .operators .over_sampling import over_sample_util
1314
14- from autoarray .numpy_wrapper import register_pytree_node_class
1515
1616
1717@register_pytree_node_class
Original file line number Diff line number Diff line change 22
33import numpy as np
44import jax .numpy as jnp
5+ from jax ._src .tree_util import register_pytree_node_class
56
67from autoarray .structures .triangles .abstract import HEIGHT_FACTOR
78from autoarray .structures .triangles .abstract import AbstractTriangles
89from autoarray .structures .triangles .array import ArrayTriangles
9- from autoarray .numpy_wrapper import register_pytree_node_class
1010
1111
1212@register_pytree_node_class
Original file line number Diff line number Diff line change 11from abc import ABC , abstractmethod
2+ from jax ._src .tree_util import register_pytree_node_class
23from typing import List , Tuple
34
45import numpy as np
56
6- from autoarray .numpy_wrapper import register_pytree_node_class
7-
87
98class Shape (ABC ):
109 """
Original file line number Diff line number Diff line change @@ -29,8 +29,6 @@ dependencies = [
2929 " astropy>=5.0,<=6.1.2" ,
3030 " decorator>=4.0.0" ,
3131 " dill>=0.3.1.1" ,
32- " jax==0.4.28" ,
33- " jaxlib==0.4.28" ,
3432 " jaxnnls==1.0.1" ,
3533 " matplotlib>=3.7.0" ,
3634 " scipy<=1.14.0" ,
Original file line number Diff line number Diff line change 1- import jax
21import jax .numpy as jnp
32
43def pytest_configure ():
54 _ = jnp .sum (jnp .array ([0.0 ])) # Force backend init
65
7- jax .config .update ("jax_enable_x64" , True )
8-
96import os
107from os import path
118import pytest
Original file line number Diff line number Diff line change 1- from autoarray . numpy_wrapper import np
1+ from autoconf . jax_wrapper import np
22from autoarray .structures .triangles .array import ArrayTriangles
33from autoarray .structures .triangles .coordinate_array import CoordinateArrayTriangles
44
You can’t perform that action at this time.
0 commit comments