jaxsim 0.2.dev191__py3-none-any.whl → 0.6.1.dev2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (109)
  1. jaxsim/__init__.py +73 -22
  2. jaxsim/_version.py +2 -2
  3. jaxsim/api/__init__.py +13 -1
  4. jaxsim/api/com.py +423 -0
  5. jaxsim/api/common.py +48 -19
  6. jaxsim/api/contact.py +604 -52
  7. jaxsim/api/data.py +308 -163
  8. jaxsim/api/frame.py +471 -0
  9. jaxsim/api/joint.py +166 -37
  10. jaxsim/api/kin_dyn_parameters.py +901 -0
  11. jaxsim/api/link.py +277 -78
  12. jaxsim/api/model.py +1572 -362
  13. jaxsim/api/ode.py +324 -133
  14. jaxsim/api/ode_data.py +401 -0
  15. jaxsim/api/references.py +216 -80
  16. jaxsim/exceptions.py +80 -0
  17. jaxsim/integrators/__init__.py +2 -2
  18. jaxsim/integrators/common.py +191 -107
  19. jaxsim/integrators/fixed_step.py +97 -102
  20. jaxsim/integrators/variable_step.py +706 -0
  21. jaxsim/logging.py +1 -2
  22. jaxsim/math/__init__.py +13 -0
  23. jaxsim/math/adjoint.py +57 -22
  24. jaxsim/math/cross.py +16 -7
  25. jaxsim/math/inertia.py +10 -8
  26. jaxsim/math/joint_model.py +289 -0
  27. jaxsim/math/quaternion.py +54 -20
  28. jaxsim/math/rotation.py +27 -21
  29. jaxsim/math/skew.py +16 -5
  30. jaxsim/math/transform.py +102 -0
  31. jaxsim/math/utils.py +31 -0
  32. jaxsim/mujoco/__init__.py +2 -1
  33. jaxsim/mujoco/loaders.py +216 -29
  34. jaxsim/mujoco/model.py +163 -33
  35. jaxsim/mujoco/utils.py +228 -0
  36. jaxsim/mujoco/visualizer.py +107 -22
  37. jaxsim/parsers/__init__.py +0 -1
  38. jaxsim/parsers/descriptions/__init__.py +8 -2
  39. jaxsim/parsers/descriptions/collision.py +83 -26
  40. jaxsim/parsers/descriptions/joint.py +80 -87
  41. jaxsim/parsers/descriptions/link.py +58 -31
  42. jaxsim/parsers/descriptions/model.py +101 -68
  43. jaxsim/parsers/kinematic_graph.py +606 -229
  44. jaxsim/parsers/rod/meshes.py +104 -0
  45. jaxsim/parsers/rod/parser.py +125 -82
  46. jaxsim/parsers/rod/utils.py +127 -82
  47. jaxsim/rbda/__init__.py +11 -0
  48. jaxsim/rbda/aba.py +289 -0
  49. jaxsim/rbda/collidable_points.py +156 -0
  50. jaxsim/rbda/contacts/__init__.py +13 -0
  51. jaxsim/rbda/contacts/common.py +313 -0
  52. jaxsim/rbda/contacts/relaxed_rigid.py +605 -0
  53. jaxsim/rbda/contacts/rigid.py +462 -0
  54. jaxsim/rbda/contacts/soft.py +480 -0
  55. jaxsim/rbda/contacts/visco_elastic.py +1066 -0
  56. jaxsim/rbda/crba.py +167 -0
  57. jaxsim/rbda/forward_kinematics.py +117 -0
  58. jaxsim/rbda/jacobian.py +330 -0
  59. jaxsim/rbda/rnea.py +235 -0
  60. jaxsim/rbda/utils.py +160 -0
  61. jaxsim/terrain/__init__.py +2 -0
  62. jaxsim/terrain/terrain.py +238 -0
  63. jaxsim/typing.py +24 -24
  64. jaxsim/utils/__init__.py +1 -4
  65. jaxsim/utils/jaxsim_dataclass.py +289 -34
  66. jaxsim/utils/tracing.py +5 -11
  67. jaxsim/utils/wrappers.py +159 -0
  68. {jaxsim-0.2.dev191.dist-info → jaxsim-0.6.1.dev2.dist-info}/LICENSE +1 -1
  69. jaxsim-0.6.1.dev2.dist-info/METADATA +465 -0
  70. jaxsim-0.6.1.dev2.dist-info/RECORD +74 -0
  71. {jaxsim-0.2.dev191.dist-info → jaxsim-0.6.1.dev2.dist-info}/WHEEL +1 -1
  72. jaxsim/high_level/__init__.py +0 -2
  73. jaxsim/high_level/common.py +0 -11
  74. jaxsim/high_level/joint.py +0 -148
  75. jaxsim/high_level/link.py +0 -259
  76. jaxsim/high_level/model.py +0 -1686
  77. jaxsim/math/conv.py +0 -114
  78. jaxsim/math/joint.py +0 -102
  79. jaxsim/math/plucker.py +0 -100
  80. jaxsim/physics/__init__.py +0 -12
  81. jaxsim/physics/algos/__init__.py +0 -0
  82. jaxsim/physics/algos/aba.py +0 -254
  83. jaxsim/physics/algos/aba_motors.py +0 -284
  84. jaxsim/physics/algos/crba.py +0 -154
  85. jaxsim/physics/algos/forward_kinematics.py +0 -79
  86. jaxsim/physics/algos/jacobian.py +0 -98
  87. jaxsim/physics/algos/rnea.py +0 -180
  88. jaxsim/physics/algos/rnea_motors.py +0 -196
  89. jaxsim/physics/algos/soft_contacts.py +0 -523
  90. jaxsim/physics/algos/terrain.py +0 -78
  91. jaxsim/physics/algos/utils.py +0 -69
  92. jaxsim/physics/model/__init__.py +0 -0
  93. jaxsim/physics/model/ground_contact.py +0 -53
  94. jaxsim/physics/model/physics_model.py +0 -388
  95. jaxsim/physics/model/physics_model_state.py +0 -283
  96. jaxsim/simulation/__init__.py +0 -4
  97. jaxsim/simulation/integrators.py +0 -393
  98. jaxsim/simulation/ode.py +0 -290
  99. jaxsim/simulation/ode_data.py +0 -96
  100. jaxsim/simulation/ode_integration.py +0 -62
  101. jaxsim/simulation/simulator.py +0 -543
  102. jaxsim/simulation/simulator_callbacks.py +0 -79
  103. jaxsim/simulation/utils.py +0 -15
  104. jaxsim/sixd/__init__.py +0 -2
  105. jaxsim/utils/oop.py +0 -536
  106. jaxsim/utils/vmappable.py +0 -117
  107. jaxsim-0.2.dev191.dist-info/METADATA +0 -184
  108. jaxsim-0.2.dev191.dist-info/RECORD +0 -81
  109. {jaxsim-0.2.dev191.dist-info → jaxsim-0.6.1.dev2.dist-info}/top_level.txt +0 -0
@@ -1,8 +1,9 @@
1
1
  import abc
2
2
  import contextlib
3
- import copy
4
3
  import dataclasses
5
- from typing import ClassVar, Generator
4
+ import functools
5
+ from collections.abc import Callable, Iterator, Sequence
6
+ from typing import Any, ClassVar
6
7
 
7
8
  import jax.flatten_util
8
9
  import jax_dataclasses
@@ -19,51 +20,224 @@ except ImportError:
19
20
 
20
21
  @jax_dataclasses.pytree_dataclass
21
22
  class JaxsimDataclass(abc.ABC):
22
- """"""
23
+ """Class extending `jax_dataclasses.pytree_dataclass` instances with utilities."""
23
24
 
24
25
  # This attribute is set by jax_dataclasses
25
26
  __mutability__: ClassVar[Mutability] = Mutability.FROZEN
26
27
 
27
28
  @contextlib.contextmanager
28
- def editable(self: Self, validate: bool = True) -> Generator[Self, None, None]:
29
- """"""
29
+ def editable(self: Self, validate: bool = True) -> Iterator[Self]:
30
+ """
31
+ Context manager to operate on a mutable copy of the object.
32
+
33
+ Args:
34
+ validate: Whether to validate the output PyTree upon exiting the context.
35
+
36
+ Yields:
37
+ A mutable copy of the object.
38
+
39
+ Note:
40
+ This context manager is useful to operate on an r/w copy of a PyTree making
41
+ sure that the output object does not trigger JIT recompilations.
42
+ """
30
43
 
31
44
  mutability = (
32
45
  Mutability.MUTABLE if validate else Mutability.MUTABLE_NO_VALIDATION
33
46
  )
34
47
 
35
- with JaxsimDataclass.mutable_context(self.copy(), mutability=mutability) as obj:
48
+ with self.copy().mutable_context(mutability=mutability) as obj:
36
49
  yield obj
37
50
 
38
51
  @contextlib.contextmanager
39
52
  def mutable_context(
40
- self: Self, mutability: Mutability, restore_after_exception: bool = True
41
- ) -> Generator[Self, None, None]:
42
- """"""
53
+ self: Self,
54
+ mutability: Mutability = Mutability.MUTABLE,
55
+ restore_after_exception: bool = True,
56
+ ) -> Iterator[Self]:
57
+ """
58
+ Context manager to temporarily change the mutability of the object.
59
+
60
+ Args:
61
+ mutability: The mutability to set.
62
+ restore_after_exception:
63
+ Whether to restore the original object in case of an exception
64
+ occurring within the context.
65
+
66
+ Yields:
67
+ The object with the new mutability.
68
+
69
+ Note:
70
+ This context manager is useful to operate in place on a PyTree without
71
+ the need to make a copy while optionally keeping active the checks on
72
+ the PyTree structure, shapes, and dtypes.
73
+ """
43
74
 
44
75
  if restore_after_exception:
45
76
  self_copy = self.copy()
46
77
 
47
- original_mutability = self._mutability()
78
+ original_mutability = self.mutability()
48
79
 
49
- def restore_self():
50
- self._set_mutability(mutability=Mutability.MUTABLE)
80
+ original_dtypes = JaxsimDataclass.get_leaf_dtypes(tree=self)
81
+ original_shapes = JaxsimDataclass.get_leaf_shapes(tree=self)
82
+ original_weak_types = JaxsimDataclass.get_leaf_weak_types(tree=self)
83
+ original_structure = jax.tree_util.tree_structure(tree=self)
84
+
85
+ def restore_self() -> None:
86
+ self.set_mutability(mutability=Mutability.MUTABLE_NO_VALIDATION)
51
87
  for f in dataclasses.fields(self_copy):
52
88
  setattr(self, f.name, getattr(self_copy, f.name))
53
89
 
54
90
  try:
55
- self._set_mutability(mutability)
91
+ self.set_mutability(mutability=mutability)
56
92
  yield self
93
+
94
+ if mutability is not Mutability.MUTABLE_NO_VALIDATION:
95
+ new_structure = jax.tree_util.tree_structure(tree=self)
96
+ if original_structure != new_structure:
97
+ msg = "Pytree structure has changed from {} to {}"
98
+ raise ValueError(msg.format(original_structure, new_structure))
99
+
100
+ new_shapes = JaxsimDataclass.get_leaf_shapes(tree=self)
101
+ if original_shapes != new_shapes:
102
+ msg = "Leaves shapes have changed from {} to {}"
103
+ raise ValueError(msg.format(original_shapes, new_shapes))
104
+
105
+ new_dtypes = JaxsimDataclass.get_leaf_dtypes(tree=self)
106
+ if original_dtypes != new_dtypes:
107
+ msg = "Leaves dtypes have changed from {} to {}"
108
+ raise ValueError(msg.format(original_dtypes, new_dtypes))
109
+
110
+ new_weak_types = JaxsimDataclass.get_leaf_weak_types(tree=self)
111
+ if original_weak_types != new_weak_types:
112
+ msg = "Leaves weak types have changed from {} to {}"
113
+ raise ValueError(msg.format(original_weak_types, new_weak_types))
114
+
57
115
  except Exception as e:
58
116
  if restore_after_exception:
59
117
  restore_self()
60
- self._set_mutability(original_mutability)
118
+ self.set_mutability(original_mutability)
61
119
  raise e
120
+
62
121
  finally:
63
- self._set_mutability(original_mutability)
122
+ self.set_mutability(original_mutability)
123
+
124
+ @staticmethod
125
+ def get_leaf_shapes(tree: jtp.PyTree) -> tuple[tuple[int, ...] | None]:
126
+ """
127
+ Get the leaf shapes of a PyTree.
128
+
129
+ Args:
130
+ tree: The PyTree to consider.
131
+
132
+ Returns:
133
+ A tuple containing the leaf shapes of the PyTree or `None` is the leaf is
134
+ not a numpy-like array.
135
+ """
136
+
137
+ return tuple(
138
+ map(
139
+ lambda leaf: getattr(leaf, "shape", None),
140
+ jax.tree_util.tree_leaves(tree),
141
+ )
142
+ )
143
+
144
+ @staticmethod
145
+ def get_leaf_dtypes(tree: jtp.PyTree) -> tuple:
146
+ """
147
+ Get the leaf dtypes of a PyTree.
148
+
149
+ Args:
150
+ tree: The PyTree to consider.
151
+
152
+ Returns:
153
+ A tuple containing the leaf dtypes of the PyTree or `None` is the leaf is
154
+ not a numpy-like array.
155
+ """
156
+
157
+ return tuple(
158
+ map(
159
+ lambda leaf: getattr(leaf, "dtype", None),
160
+ jax.tree_util.tree_leaves(tree),
161
+ )
162
+ )
163
+
164
+ @staticmethod
165
+ def get_leaf_weak_types(tree: jtp.PyTree) -> tuple[bool, ...]:
166
+ """
167
+ Get the leaf weak types of a PyTree.
168
+
169
+ Args:
170
+ tree: The PyTree to consider.
171
+
172
+ Returns:
173
+ A tuple marking whether the leaf contains a JAX array with weak type.
174
+ """
175
+
176
+ return tuple(
177
+ map(
178
+ lambda leaf: getattr(leaf, "weak_type", None),
179
+ jax.tree_util.tree_leaves(tree),
180
+ )
181
+ )
182
+
183
+ @staticmethod
184
+ def check_compatibility(*trees: Sequence[Any]) -> None:
185
+ """
186
+ Check whether the PyTrees are compatible in structure, shape, and dtype.
187
+
188
+ Args:
189
+ *trees: The PyTrees to compare.
190
+
191
+ Raises:
192
+ ValueError: If the PyTrees have incompatible structures, shapes, or dtypes.
193
+ """
194
+
195
+ target_structure = jax.tree_util.tree_structure(trees[0])
196
+
197
+ compatible_structure = functools.reduce(
198
+ lambda compatible, tree: compatible
199
+ and jax.tree_util.tree_structure(tree) == target_structure,
200
+ trees[1:],
201
+ True,
202
+ )
203
+
204
+ if not compatible_structure:
205
+ raise ValueError("Pytrees have incompatible structures.")
206
+
207
+ target_shapes = JaxsimDataclass.get_leaf_shapes(trees[0])
208
+
209
+ compatible_shapes = functools.reduce(
210
+ lambda compatible, tree: compatible
211
+ and JaxsimDataclass.get_leaf_shapes(tree) == target_shapes,
212
+ trees[1:],
213
+ True,
214
+ )
215
+
216
+ if not compatible_shapes:
217
+ raise ValueError("Pytrees have incompatible shapes.")
218
+
219
+ target_dtypes = JaxsimDataclass.get_leaf_dtypes(trees[0])
220
+
221
+ compatible_dtypes = functools.reduce(
222
+ lambda compatible, tree: compatible
223
+ and JaxsimDataclass.get_leaf_dtypes(tree) == target_dtypes,
224
+ trees[1:],
225
+ True,
226
+ )
227
+
228
+ if not compatible_dtypes:
229
+ raise ValueError("Pytrees have incompatible dtypes.")
64
230
 
65
231
  def is_mutable(self, validate: bool = False) -> bool:
66
- """"""
232
+ """
233
+ Check whether the object is mutable.
234
+
235
+ Args:
236
+ validate: Additionally checks if the object also has validation enabled.
237
+
238
+ Returns:
239
+ True if the object is mutable, False otherwise.
240
+ """
67
241
 
68
242
  return (
69
243
  self.__mutability__ is Mutability.MUTABLE
@@ -71,39 +245,120 @@ class JaxsimDataclass(abc.ABC):
71
245
  else self.__mutability__ is Mutability.MUTABLE_NO_VALIDATION
72
246
  )
73
247
 
74
- def set_mutability(self, mutable: bool = True, validate: bool = False) -> None:
75
- if not mutable:
76
- mutability = Mutability.FROZEN
77
- else:
78
- mutability = (
79
- Mutability.MUTABLE if validate else Mutability.MUTABLE_NO_VALIDATION
80
- )
248
+ def mutability(self) -> Mutability:
249
+ """
250
+ Get the mutability type of the object.
81
251
 
82
- self._set_mutability(mutability=mutability)
252
+ Returns:
253
+ The mutability type of the object.
254
+ """
83
255
 
84
- def _mutability(self) -> Mutability:
85
256
  return self.__mutability__
86
257
 
87
- def _set_mutability(self, mutability: Mutability) -> None:
258
+ def set_mutability(self, mutability: Mutability) -> None:
259
+ """
260
+ Set the mutability of the object in-place.
261
+
262
+ Args:
263
+ mutability: The desired mutability type.
264
+ """
265
+
88
266
  jax_dataclasses._copy_and_mutate._mark_mutable(
89
267
  self, mutable=mutability, visited=set()
90
268
  )
91
269
 
92
270
  def mutable(self: Self, mutable: bool = True, validate: bool = False) -> Self:
93
- self.set_mutability(mutable=mutable, validate=validate)
271
+ """
272
+ Return a mutable reference of the object.
273
+
274
+ Args:
275
+ mutable: Whether to make the object mutable.
276
+ validate: Whether to enable validation on the object.
277
+
278
+ Returns:
279
+ A mutable reference of the object.
280
+ """
281
+
282
+ if mutable:
283
+ mutability = (
284
+ Mutability.MUTABLE if validate else Mutability.MUTABLE_NO_VALIDATION
285
+ )
286
+ else:
287
+ mutability = Mutability.FROZEN
288
+
289
+ self.set_mutability(mutability=mutability)
94
290
  return self
95
291
 
96
292
  def copy(self: Self) -> Self:
97
- obj = jax.tree_util.tree_map(lambda leaf: leaf, self)
98
- obj._set_mutability(mutability=self._mutability())
293
+ """
294
+ Return a copy of the object.
295
+
296
+ Returns:
297
+ A copy of the object.
298
+ """
299
+
300
+ # Make a copy calling tree_map.
301
+ obj = jax.tree.map(lambda leaf: leaf, self)
302
+
303
+ # Make sure that the copied object and all the copied leaves have the same
304
+ # mutability of the original object.
305
+ obj.set_mutability(mutability=self.mutability())
306
+
99
307
  return obj
100
308
 
101
309
  def replace(self: Self, validate: bool = True, **kwargs) -> Self:
102
- with self.editable(validate=validate) as obj:
103
- _ = [obj.__setattr__(k, copy.copy(v)) for k, v in kwargs.items()]
310
+ """
311
+ Return a new object replacing in-place the specified fields with new values.
312
+
313
+ Args:
314
+ validate: Whether to validate that the new fields do not alter the PyTree.
315
+ **kwargs: The fields to replace.
316
+
317
+ Returns:
318
+ A reference of the object with the specified fields replaced.
319
+ """
320
+
321
+ # Use the dataclasses replace method.
322
+ obj = dataclasses.replace(self, **kwargs)
323
+
324
+ if validate:
325
+ JaxsimDataclass.check_compatibility(self, obj)
326
+
327
+ # Make sure that all the new leaves have the same mutability of the object.
328
+ obj.set_mutability(mutability=self.mutability())
104
329
 
105
- obj._set_mutability(mutability=self._mutability())
106
330
  return obj
107
331
 
108
- def flatten(self) -> jtp.VectorJax:
109
- return jax.flatten_util.ravel_pytree(self)[0]
332
+ def flatten(self) -> jtp.Vector:
333
+ """
334
+ Flatten the object into a 1D vector.
335
+
336
+ Returns:
337
+ A 1D vector containing the flattened object.
338
+ """
339
+
340
+ return self.flatten_fn()(self)
341
+
342
+ @classmethod
343
+ def flatten_fn(cls: type[Self]) -> Callable[[Self], jtp.Vector]:
344
+ """
345
+ Return a function to flatten the object into a 1D vector.
346
+
347
+ Returns:
348
+ A function to flatten the object into a 1D vector.
349
+ """
350
+
351
+ return lambda pytree: jax.flatten_util.ravel_pytree(pytree)[0]
352
+
353
+ def unflatten_fn(self: Self) -> Callable[[jtp.Vector], Self]:
354
+ """
355
+ Return a function to unflatten a 1D vector into the object.
356
+
357
+ Returns:
358
+ A function to unflatten a 1D vector into the object.
359
+
360
+ Notes:
361
+ Due to JAX internals, the function to unflatten a PyTree needs to be
362
+ created from an existing instance of the PyTree.
363
+ """
364
+ return jax.flatten_util.ravel_pytree(self)[1]
jaxsim/utils/tracing.py CHANGED
@@ -6,20 +6,14 @@ import jax.interpreters.partial_eval
6
6
 
7
7
 
8
8
  def tracing(var: Any) -> bool | jax.Array:
9
- """Returns True if the variable is being traced by JAX, False otherwise."""
9
+ """Return True if the variable is being traced by JAX, False otherwise."""
10
10
 
11
- return jax.numpy.array(
12
- [
13
- isinstance(var, t)
14
- for t in (
15
- jax._src.core.Tracer,
16
- jax.interpreters.partial_eval.DynamicJaxprTracer,
17
- )
18
- ]
19
- ).any()
11
+ return isinstance(
12
+ var, jax._src.core.Tracer | jax.interpreters.partial_eval.DynamicJaxprTracer
13
+ )
20
14
 
21
15
 
22
16
def not_tracing(var: Any) -> bool | jax.Array:
    """
    Return True if the variable is not being traced by JAX, False otherwise.

    This is the logical negation of `tracing`.
    """

    return not tracing(var)
@@ -0,0 +1,159 @@
1
+ from __future__ import annotations
2
+
3
+ import dataclasses
4
+ from collections.abc import Callable
5
+ from typing import Generic, TypeVar
6
+
7
+ import jax
8
+ import jax_dataclasses
9
+ import numpy as np
10
+ import numpy.typing as npt
11
+
12
T = TypeVar("T")


@dataclasses.dataclass
class HashlessObject(Generic[T]):
    """
    A class that wraps an object and makes it hashless.

    This is useful for creating particular JAX pytrees.
    For example, to create a pytree with a static leaf that is ignored
    by JAX when it compares two instances to trigger a JIT recompilation.

    All instances share the same constant hash (0). Two instances compare equal
    when the object wrapped by the right-hand operand is an instance of the type
    of the object wrapped by the left-hand operand, regardless of their values.
    """

    obj: T

    def get(self: HashlessObject[T]) -> T:
        """
        Get the wrapped object.
        """
        return self.obj

    def __hash__(self) -> int:

        # Constant hash: the wrapped value never contributes to it.
        return 0

    def __eq__(self, other: HashlessObject[T]) -> bool:

        # The whole condition is negated so that non-HashlessObject operands
        # short-circuit to False instead of having `.get()` called on them.
        if not (
            isinstance(other, HashlessObject)
            and isinstance(other.get(), type(self.get()))
        ):
            return False

        return hash(self) == hash(other)
45
+
46
+
47
@dataclasses.dataclass
class CustomHashedObject(Generic[T]):
    """
    A class that wraps an object and computes its hash with a custom hash function.

    Two instances compare equal when the object wrapped by the right-hand operand
    is an instance of the type of the object wrapped by the left-hand operand,
    and their custom hashes match.
    """

    obj: T

    hash_function: Callable[[T], int] = hash

    def get(self: CustomHashedObject[T]) -> T:
        """
        Get the wrapped object.
        """
        return self.obj

    def __hash__(self) -> int:

        return self.hash_function(self.obj)

    def __eq__(self, other: CustomHashedObject[T]) -> bool:

        # The whole condition is negated so that non-CustomHashedObject operands
        # short-circuit to False instead of having `.get()` called on them.
        if not (
            isinstance(other, CustomHashedObject)
            and isinstance(other.get(), type(self.get()))
        ):
            return False

        return hash(self) == hash(other)
75
+
76
+
77
@jax_dataclasses.pytree_dataclass
class HashedNumpyArray:
    """
    A class that wraps a numpy array and makes it hashable.

    This is useful for creating particular JAX pytrees.
    For example, to create a pytree with a plain NumPy or JAX NumPy array as static leaf.

    Two instances are equal only if their arrays have the same shape. Beyond
    that, they are compared through their hashes or, when `large_array=True`,
    through `np.allclose`.

    Note:
        Calculating with the wrapper class the hash of a very large array can be
        very expensive. If the array is large and only the equality operator is needed,
        set `large_array=True` to use a faster comparison method.
    """

    array: jax.Array | npt.NDArray

    precision: float | None = dataclasses.field(
        default=1e-9, repr=False, compare=False, hash=False
    )

    large_array: jax_dataclasses.Static[bool] = dataclasses.field(
        default=False, repr=False, compare=False, hash=False
    )

    def get(self) -> jax.Array | npt.NDArray:
        """
        Get the wrapped array.
        """
        return self.array

    def __hash__(self) -> int:

        return HashedNumpyArray.hash_of_array(
            array=self.array, precision=self.precision
        )

    def __eq__(self, other: HashedNumpyArray) -> bool:

        if not isinstance(other, HashedNumpyArray):
            return False

        # Different shapes are never equal. Without this check `np.allclose`
        # would broadcast (e.g. (3,) vs (1,)), and the hash would ignore the
        # shape because it is computed over the flattened array.
        if np.shape(self.array) != np.shape(other.array):
            return False

        if self.large_array:
            tolerance = {} if self.precision is None else {"atol": self.precision}
            return bool(np.allclose(self.array, other.array, **tolerance))

        return hash(self) == hash(other)

    @staticmethod
    def hash_of_array(
        array: jax.Array | npt.NDArray, precision: float | None = 1e-9
    ) -> int:
        """
        Calculate the hash of a NumPy array.

        Args:
            array: The array to hash.
            precision: Optionally limit the precision over which the hash is computed.

        Returns:
            The hash of the array.

        Note:
            The hash depends only on the flattened values; the shape is ignored.
            In floating-point arrays, NaN, +inf, and -inf are replaced by the
            Python hashes of `np.nan`, `np.inf`, and `-np.inf` before hashing.
        """

        array = np.array(array).flatten()

        # `x == np.nan` is always False, so NaNs must be found with `np.isnan`.
        # Only floating-point arrays can hold non-finite values.
        if np.issubdtype(array.dtype, np.floating):
            array = np.where(np.isnan(array), hash(np.nan), array)
            array = np.where(np.isposinf(array), hash(np.inf), array)
            array = np.where(np.isneginf(array), hash(-np.inf), array)

        if precision is not None:

            # Decompose each value into three integer parts, all consistently
            # scaled by `precision`:
            # - integer1 is the coarse part;
            # - integer2 is the integer remainder;
            # - decimal_array is the fraction in units of `precision`.
            integer1 = (array * precision).astype(int)
            integer2 = (array - integer1 / precision).astype(int)

            decimal_array = (
                (array - integer1 / precision - integer2) / precision
            ).astype(int)

            array = np.hstack([integer1, integer2, decimal_array]).astype(int)

        return hash(tuple(array.tolist()))
@@ -1,6 +1,6 @@
1
1
  BSD 3-Clause License
2
2
 
3
- Copyright (c) 2022, Artificial and Mechanical Intelligence
3
+ Copyright (c) 2022, Artificial and Mechanical Intelligence
4
4
  All rights reserved.
5
5
 
6
6
  Redistribution and use in source and binary forms, with or without