jaxsim 0.2.dev191__py3-none-any.whl → 0.6.1.dev2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jaxsim/__init__.py +73 -22
- jaxsim/_version.py +2 -2
- jaxsim/api/__init__.py +13 -1
- jaxsim/api/com.py +423 -0
- jaxsim/api/common.py +48 -19
- jaxsim/api/contact.py +604 -52
- jaxsim/api/data.py +308 -163
- jaxsim/api/frame.py +471 -0
- jaxsim/api/joint.py +166 -37
- jaxsim/api/kin_dyn_parameters.py +901 -0
- jaxsim/api/link.py +277 -78
- jaxsim/api/model.py +1572 -362
- jaxsim/api/ode.py +324 -133
- jaxsim/api/ode_data.py +401 -0
- jaxsim/api/references.py +216 -80
- jaxsim/exceptions.py +80 -0
- jaxsim/integrators/__init__.py +2 -2
- jaxsim/integrators/common.py +191 -107
- jaxsim/integrators/fixed_step.py +97 -102
- jaxsim/integrators/variable_step.py +706 -0
- jaxsim/logging.py +1 -2
- jaxsim/math/__init__.py +13 -0
- jaxsim/math/adjoint.py +57 -22
- jaxsim/math/cross.py +16 -7
- jaxsim/math/inertia.py +10 -8
- jaxsim/math/joint_model.py +289 -0
- jaxsim/math/quaternion.py +54 -20
- jaxsim/math/rotation.py +27 -21
- jaxsim/math/skew.py +16 -5
- jaxsim/math/transform.py +102 -0
- jaxsim/math/utils.py +31 -0
- jaxsim/mujoco/__init__.py +2 -1
- jaxsim/mujoco/loaders.py +216 -29
- jaxsim/mujoco/model.py +163 -33
- jaxsim/mujoco/utils.py +228 -0
- jaxsim/mujoco/visualizer.py +107 -22
- jaxsim/parsers/__init__.py +0 -1
- jaxsim/parsers/descriptions/__init__.py +8 -2
- jaxsim/parsers/descriptions/collision.py +83 -26
- jaxsim/parsers/descriptions/joint.py +80 -87
- jaxsim/parsers/descriptions/link.py +58 -31
- jaxsim/parsers/descriptions/model.py +101 -68
- jaxsim/parsers/kinematic_graph.py +606 -229
- jaxsim/parsers/rod/meshes.py +104 -0
- jaxsim/parsers/rod/parser.py +125 -82
- jaxsim/parsers/rod/utils.py +127 -82
- jaxsim/rbda/__init__.py +11 -0
- jaxsim/rbda/aba.py +289 -0
- jaxsim/rbda/collidable_points.py +156 -0
- jaxsim/rbda/contacts/__init__.py +13 -0
- jaxsim/rbda/contacts/common.py +313 -0
- jaxsim/rbda/contacts/relaxed_rigid.py +605 -0
- jaxsim/rbda/contacts/rigid.py +462 -0
- jaxsim/rbda/contacts/soft.py +480 -0
- jaxsim/rbda/contacts/visco_elastic.py +1066 -0
- jaxsim/rbda/crba.py +167 -0
- jaxsim/rbda/forward_kinematics.py +117 -0
- jaxsim/rbda/jacobian.py +330 -0
- jaxsim/rbda/rnea.py +235 -0
- jaxsim/rbda/utils.py +160 -0
- jaxsim/terrain/__init__.py +2 -0
- jaxsim/terrain/terrain.py +238 -0
- jaxsim/typing.py +24 -24
- jaxsim/utils/__init__.py +1 -4
- jaxsim/utils/jaxsim_dataclass.py +289 -34
- jaxsim/utils/tracing.py +5 -11
- jaxsim/utils/wrappers.py +159 -0
- {jaxsim-0.2.dev191.dist-info → jaxsim-0.6.1.dev2.dist-info}/LICENSE +1 -1
- jaxsim-0.6.1.dev2.dist-info/METADATA +465 -0
- jaxsim-0.6.1.dev2.dist-info/RECORD +74 -0
- {jaxsim-0.2.dev191.dist-info → jaxsim-0.6.1.dev2.dist-info}/WHEEL +1 -1
- jaxsim/high_level/__init__.py +0 -2
- jaxsim/high_level/common.py +0 -11
- jaxsim/high_level/joint.py +0 -148
- jaxsim/high_level/link.py +0 -259
- jaxsim/high_level/model.py +0 -1686
- jaxsim/math/conv.py +0 -114
- jaxsim/math/joint.py +0 -102
- jaxsim/math/plucker.py +0 -100
- jaxsim/physics/__init__.py +0 -12
- jaxsim/physics/algos/__init__.py +0 -0
- jaxsim/physics/algos/aba.py +0 -254
- jaxsim/physics/algos/aba_motors.py +0 -284
- jaxsim/physics/algos/crba.py +0 -154
- jaxsim/physics/algos/forward_kinematics.py +0 -79
- jaxsim/physics/algos/jacobian.py +0 -98
- jaxsim/physics/algos/rnea.py +0 -180
- jaxsim/physics/algos/rnea_motors.py +0 -196
- jaxsim/physics/algos/soft_contacts.py +0 -523
- jaxsim/physics/algos/terrain.py +0 -78
- jaxsim/physics/algos/utils.py +0 -69
- jaxsim/physics/model/__init__.py +0 -0
- jaxsim/physics/model/ground_contact.py +0 -53
- jaxsim/physics/model/physics_model.py +0 -388
- jaxsim/physics/model/physics_model_state.py +0 -283
- jaxsim/simulation/__init__.py +0 -4
- jaxsim/simulation/integrators.py +0 -393
- jaxsim/simulation/ode.py +0 -290
- jaxsim/simulation/ode_data.py +0 -96
- jaxsim/simulation/ode_integration.py +0 -62
- jaxsim/simulation/simulator.py +0 -543
- jaxsim/simulation/simulator_callbacks.py +0 -79
- jaxsim/simulation/utils.py +0 -15
- jaxsim/sixd/__init__.py +0 -2
- jaxsim/utils/oop.py +0 -536
- jaxsim/utils/vmappable.py +0 -117
- jaxsim-0.2.dev191.dist-info/METADATA +0 -184
- jaxsim-0.2.dev191.dist-info/RECORD +0 -81
- {jaxsim-0.2.dev191.dist-info → jaxsim-0.6.1.dev2.dist-info}/top_level.txt +0 -0
jaxsim/utils/jaxsim_dataclass.py
CHANGED
@@ -1,8 +1,9 @@
|
|
1
1
|
import abc
|
2
2
|
import contextlib
|
3
|
-
import copy
|
4
3
|
import dataclasses
|
5
|
-
|
4
|
+
import functools
|
5
|
+
from collections.abc import Callable, Iterator, Sequence
|
6
|
+
from typing import Any, ClassVar
|
6
7
|
|
7
8
|
import jax.flatten_util
|
8
9
|
import jax_dataclasses
|
@@ -19,51 +20,224 @@ except ImportError:
|
|
19
20
|
|
20
21
|
@jax_dataclasses.pytree_dataclass
|
21
22
|
class JaxsimDataclass(abc.ABC):
|
22
|
-
""""""
|
23
|
+
"""Class extending `jax_dataclasses.pytree_dataclass` instances with utilities."""
|
23
24
|
|
24
25
|
# This attribute is set by jax_dataclasses
|
25
26
|
__mutability__: ClassVar[Mutability] = Mutability.FROZEN
|
26
27
|
|
27
28
|
@contextlib.contextmanager
def editable(self: Self, validate: bool = True) -> Iterator[Self]:
    """
    Yield a mutable copy of this object for the duration of the context.

    Args:
        validate:
            If True, the copy is made mutable with validation enabled
            (`Mutability.MUTABLE`), otherwise validation is disabled
            (`Mutability.MUTABLE_NO_VALIDATION`).

    Yields:
        A mutable copy of the object.
    """

    if validate:
        target_mutability = Mutability.MUTABLE
    else:
        target_mutability = Mutability.MUTABLE_NO_VALIDATION

    # The edits are applied to a copy, obtained before entering the context.
    working_copy = self.copy()

    with working_copy.mutable_context(mutability=target_mutability) as obj:
        yield obj
|
37
50
|
|
38
51
|
@contextlib.contextmanager
def mutable_context(
    self: Self,
    mutability: Mutability = Mutability.MUTABLE,
    restore_after_exception: bool = True,
) -> Iterator[Self]:
    """
    Temporarily switch the mutability of this object, operating in place.

    Args:
        mutability: The mutability applied while the context is active.
        restore_after_exception:
            Whether the fields of the object are rolled back to their initial
            values if an exception is raised within the context.

    Yields:
        The object itself, with the requested mutability.

    Raises:
        ValueError:
            If `mutability` is not `Mutability.MUTABLE_NO_VALIDATION` and, when
            the context body completes, the PyTree structure or the shapes,
            dtypes, or weak types of its leaves differ from the initial ones.

    Note:
        The mutability that the object had before entering the context is
        always restored when the context exits.
    """

    # Backup used to roll back the fields if the context fails.
    if restore_after_exception:
        backup = self.copy()

    previous_mutability = self.mutability()

    # Properties that the context must preserve, in the order in which they are
    # verified. Each entry pairs the error template with a function extracting
    # the property from the current state of the object.
    probes = (
        (
            "Pytree structure has changed from {} to {}",
            lambda: jax.tree_util.tree_structure(tree=self),
        ),
        (
            "Leaves shapes have changed from {} to {}",
            lambda: JaxsimDataclass.get_leaf_shapes(tree=self),
        ),
        (
            "Leaves dtypes have changed from {} to {}",
            lambda: JaxsimDataclass.get_leaf_dtypes(tree=self),
        ),
        (
            "Leaves weak types have changed from {} to {}",
            lambda: JaxsimDataclass.get_leaf_weak_types(tree=self),
        ),
    )

    # Snapshot taken before handing the object over to the caller.
    snapshots = [probe() for _, probe in probes]

    try:
        self.set_mutability(mutability=mutability)
        yield self

        if mutability is not Mutability.MUTABLE_NO_VALIDATION:
            for (template, probe), before in zip(probes, snapshots):
                after = probe()
                if before != after:
                    raise ValueError(template.format(before, after))

    except Exception:
        if restore_after_exception:
            # Disable validation so that every field can be written back.
            self.set_mutability(mutability=Mutability.MUTABLE_NO_VALIDATION)
            for field in dataclasses.fields(backup):
                setattr(self, field.name, getattr(backup, field.name))
        raise

    finally:
        # Executed on both the success and the failure paths.
        self.set_mutability(previous_mutability)
|
123
|
+
|
124
|
+
@staticmethod
|
125
|
+
def get_leaf_shapes(tree: jtp.PyTree) -> tuple[tuple[int, ...] | None]:
|
126
|
+
"""
|
127
|
+
Get the leaf shapes of a PyTree.
|
128
|
+
|
129
|
+
Args:
|
130
|
+
tree: The PyTree to consider.
|
131
|
+
|
132
|
+
Returns:
|
133
|
+
A tuple containing the leaf shapes of the PyTree or `None` is the leaf is
|
134
|
+
not a numpy-like array.
|
135
|
+
"""
|
136
|
+
|
137
|
+
return tuple(
|
138
|
+
map(
|
139
|
+
lambda leaf: getattr(leaf, "shape", None),
|
140
|
+
jax.tree_util.tree_leaves(tree),
|
141
|
+
)
|
142
|
+
)
|
143
|
+
|
144
|
+
@staticmethod
|
145
|
+
def get_leaf_dtypes(tree: jtp.PyTree) -> tuple:
|
146
|
+
"""
|
147
|
+
Get the leaf dtypes of a PyTree.
|
148
|
+
|
149
|
+
Args:
|
150
|
+
tree: The PyTree to consider.
|
151
|
+
|
152
|
+
Returns:
|
153
|
+
A tuple containing the leaf dtypes of the PyTree or `None` is the leaf is
|
154
|
+
not a numpy-like array.
|
155
|
+
"""
|
156
|
+
|
157
|
+
return tuple(
|
158
|
+
map(
|
159
|
+
lambda leaf: getattr(leaf, "dtype", None),
|
160
|
+
jax.tree_util.tree_leaves(tree),
|
161
|
+
)
|
162
|
+
)
|
163
|
+
|
164
|
+
@staticmethod
|
165
|
+
def get_leaf_weak_types(tree: jtp.PyTree) -> tuple[bool, ...]:
|
166
|
+
"""
|
167
|
+
Get the leaf weak types of a PyTree.
|
168
|
+
|
169
|
+
Args:
|
170
|
+
tree: The PyTree to consider.
|
171
|
+
|
172
|
+
Returns:
|
173
|
+
A tuple marking whether the leaf contains a JAX array with weak type.
|
174
|
+
"""
|
175
|
+
|
176
|
+
return tuple(
|
177
|
+
map(
|
178
|
+
lambda leaf: getattr(leaf, "weak_type", None),
|
179
|
+
jax.tree_util.tree_leaves(tree),
|
180
|
+
)
|
181
|
+
)
|
182
|
+
|
183
|
+
@staticmethod
def check_compatibility(*trees: Sequence[Any]) -> None:
    """
    Verify that all PyTrees match the first one in structure, shape, and dtype.

    Args:
        *trees: The PyTrees to compare. The first one is used as reference.

    Raises:
        ValueError: If the PyTrees have incompatible structures, shapes, or dtypes.
    """

    reference, others = trees[0], trees[1:]

    # The properties are verified in this order, and the first mismatch raises.
    checks = (
        (jax.tree_util.tree_structure, "Pytrees have incompatible structures."),
        (JaxsimDataclass.get_leaf_shapes, "Pytrees have incompatible shapes."),
        (JaxsimDataclass.get_leaf_dtypes, "Pytrees have incompatible dtypes."),
    )

    for extract, message in checks:
        expected = extract(reference)

        if not all(extract(tree) == expected for tree in others):
            raise ValueError(message)
|
64
230
|
|
65
231
|
def is_mutable(self, validate: bool = False) -> bool:
|
66
|
-
"""
|
232
|
+
"""
|
233
|
+
Check whether the object is mutable.
|
234
|
+
|
235
|
+
Args:
|
236
|
+
validate: Additionally checks if the object also has validation enabled.
|
237
|
+
|
238
|
+
Returns:
|
239
|
+
True if the object is mutable, False otherwise.
|
240
|
+
"""
|
67
241
|
|
68
242
|
return (
|
69
243
|
self.__mutability__ is Mutability.MUTABLE
|
@@ -71,39 +245,120 @@ class JaxsimDataclass(abc.ABC):
|
|
71
245
|
else self.__mutability__ is Mutability.MUTABLE_NO_VALIDATION
|
72
246
|
)
|
73
247
|
|
74
|
-
def
|
75
|
-
|
76
|
-
|
77
|
-
else:
|
78
|
-
mutability = (
|
79
|
-
Mutability.MUTABLE if validate else Mutability.MUTABLE_NO_VALIDATION
|
80
|
-
)
|
248
|
+
def mutability(self) -> Mutability:
    """
    Return the current mutability of the object.

    Returns:
        The value stored in the `__mutability__` attribute of the object.
    """

    current_mutability = self.__mutability__
    return current_mutability
|
86
257
|
|
87
|
-
def
|
258
|
+
def set_mutability(self, mutability: Mutability) -> None:
    """
    Change the mutability of the object in place.

    Args:
        mutability: The mutability to apply.
    """

    # Delegate to the helper of jax_dataclasses, passing a new empty set as
    # the `visited` argument on every call.
    mark_mutable = jax_dataclasses._copy_and_mutate._mark_mutable
    mark_mutable(self, mutable=mutability, visited=set())
|
91
269
|
|
92
270
|
def mutable(self: Self, mutable: bool = True, validate: bool = False) -> Self:
    """
    Set the mutability of the object in place and return the object itself.

    Args:
        mutable:
            Whether the object should become mutable. If False, the object
            is frozen and `validate` is ignored.
        validate: Whether the mutable object should have validation enabled.

    Returns:
        The same object, with the updated mutability.
    """

    if not mutable:
        target_mutability = Mutability.FROZEN
    elif validate:
        target_mutability = Mutability.MUTABLE
    else:
        target_mutability = Mutability.MUTABLE_NO_VALIDATION

    self.set_mutability(mutability=target_mutability)
    return self
|
95
291
|
|
96
292
|
def copy(self: Self) -> Self:
    """
    Create a copy of the object.

    Returns:
        A new object built by mapping the identity function over the PyTree,
        having the same mutability as the original object.
    """

    def identity(leaf):
        return leaf

    # Rebuild the PyTree with tree_map to obtain the copy.
    duplicate = jax.tree.map(identity, self)

    # Propagate the mutability of the original object to the copy.
    duplicate.set_mutability(mutability=self.mutability())

    return duplicate
|
100
308
|
|
101
309
|
def replace(self: Self, validate: bool = True, **kwargs) -> Self:
    """
    Build a new object in which the specified fields are replaced.

    Args:
        validate:
            Whether to check that the new object is compatible with the
            original one in PyTree structure, leaf shapes, and leaf dtypes.
        **kwargs: The fields to replace, forwarded to `dataclasses.replace`.

    Returns:
        The object created by `dataclasses.replace`, having the same mutability
        as the original object.

    Raises:
        ValueError: If `validate` is True and the compatibility check fails.
    """

    updated = dataclasses.replace(self, **kwargs)

    if validate:
        JaxsimDataclass.check_compatibility(self, updated)

    # Propagate the mutability of the original object to the new one.
    updated.set_mutability(mutability=self.mutability())

    return updated
|
107
331
|
|
108
|
-
def flatten(self) -> jtp.
|
109
|
-
|
332
|
+
def flatten(self) -> jtp.Vector:
    """
    Flatten the object into a 1D vector.

    Returns:
        The result of applying the function returned by `flatten_fn` to the
        object itself.
    """

    flatten_pytree = self.flatten_fn()
    return flatten_pytree(self)
|
341
|
+
|
342
|
+
@classmethod
def flatten_fn(cls: type[Self]) -> Callable[[Self], jtp.Vector]:
    """
    Build a function that flattens a PyTree into a 1D vector.

    Returns:
        A function taking a PyTree and returning the first element of the
        output of `jax.flatten_util.ravel_pytree`, i.e. the flattened vector.
    """

    def flatten_pytree(pytree):
        flat_vector, _ = jax.flatten_util.ravel_pytree(pytree)
        return flat_vector

    return flatten_pytree
|
352
|
+
|
353
|
+
def unflatten_fn(self: Self) -> Callable[[jtp.Vector], Self]:
    """
    Build a function that converts a 1D vector back into the object.

    Returns:
        The second element of the output of `jax.flatten_util.ravel_pytree`
        applied to this object, i.e. the function that unflattens a vector.

    Notes:
        The unflatten function is created from an existing instance of the
        PyTree, which is why this is an instance method.
    """

    _, unravel = jax.flatten_util.ravel_pytree(self)
    return unravel
|
jaxsim/utils/tracing.py
CHANGED
@@ -6,20 +6,14 @@ import jax.interpreters.partial_eval
|
|
6
6
|
|
7
7
|
|
8
8
|
def tracing(var: Any) -> bool | jax.Array:
|
9
|
-
"""
|
9
|
+
"""Return True if the variable is being traced by JAX, False otherwise."""
|
10
10
|
|
11
|
-
return
|
12
|
-
|
13
|
-
|
14
|
-
for t in (
|
15
|
-
jax._src.core.Tracer,
|
16
|
-
jax.interpreters.partial_eval.DynamicJaxprTracer,
|
17
|
-
)
|
18
|
-
]
|
19
|
-
).any()
|
11
|
+
return isinstance(
|
12
|
+
var, jax._src.core.Tracer | jax.interpreters.partial_eval.DynamicJaxprTracer
|
13
|
+
)
|
20
14
|
|
21
15
|
|
22
16
|
def not_tracing(var: Any) -> bool | jax.Array:
    """Return True only if `tracing(var)` is exactly False, False otherwise."""

    is_traced = tracing(var)
    return is_traced is False
|
jaxsim/utils/wrappers.py
ADDED
@@ -0,0 +1,159 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
3
|
+
import dataclasses
|
4
|
+
from collections.abc import Callable
|
5
|
+
from typing import Generic, TypeVar
|
6
|
+
|
7
|
+
import jax
|
8
|
+
import jax_dataclasses
|
9
|
+
import numpy as np
|
10
|
+
import numpy.typing as npt
|
11
|
+
|
12
|
+
T = TypeVar("T")
|
13
|
+
|
14
|
+
|
15
|
+
@dataclasses.dataclass
class HashlessObject(Generic[T]):
    """
    A class that wraps an object and makes it hashless.

    The hash of the wrapper is constant and does not depend on the wrapped
    object, and any two wrappers compare equal.

    This is useful for creating particular JAX pytrees.
    For example, to create a pytree with a static leaf that is ignored
    by JAX when it compares two instances to trigger a JIT recompilation.
    """

    # The wrapped object.
    obj: T

    def get(self: HashlessObject[T]) -> T:
        """
        Get the wrapped object.

        Returns:
            The wrapped object.
        """
        return self.obj

    def __hash__(self) -> int:
        """Return a constant hash, independent of the wrapped object."""

        return 0

    def __eq__(self, other: object) -> bool:
        """
        Compare the wrapper with another object.

        Args:
            other: The object to compare with.

        Returns:
            False if `other` is not a `HashlessObject`, otherwise the result of
            comparing the hashes of the two wrappers.
        """

        # Objects of other types never compare equal. This check must not
        # access `other.get()`, which generic objects do not provide.
        if not isinstance(other, HashlessObject):
            return False

        return hash(self) == hash(other)
|
45
|
+
|
46
|
+
|
47
|
+
@dataclasses.dataclass
class CustomHashedObject(Generic[T]):
    """
    A class that wraps an object and computes its hash with a custom hash function.

    Two wrappers compare equal when their hashes are equal.
    """

    # The wrapped object.
    obj: T

    # The function computing the hash of the wrapped object.
    hash_function: Callable[[T], int] = hash

    def get(self: CustomHashedObject[T]) -> T:
        """
        Get the wrapped object.

        Returns:
            The wrapped object.
        """
        return self.obj

    def __hash__(self) -> int:
        """Return the hash of the wrapped object computed by `hash_function`."""

        return self.hash_function(self.obj)

    def __eq__(self, other: object) -> bool:
        """
        Compare the wrapper with another object.

        Args:
            other: The object to compare with.

        Returns:
            False if `other` is not a `CustomHashedObject`, otherwise whether
            the hashes of the two wrappers are equal.
        """

        # Objects of other types never compare equal. This check must not
        # access `other.get()`, which generic objects do not provide.
        if not isinstance(other, CustomHashedObject):
            return False

        return hash(self) == hash(other)
|
75
|
+
|
76
|
+
|
77
|
+
@jax_dataclasses.pytree_dataclass
class HashedNumpyArray:
    """
    A class that wraps a numpy array and makes it hashable.

    This is useful for creating particular JAX pytrees.
    For example, to create a pytree with a plain NumPy or JAX NumPy array as static leaf.

    Note:
        Calculating with the wrapper class the hash of a very large array can be
        very expensive. If the array is large and only the equality operator is needed,
        set `large_array=True` to use a faster comparison method.
    """

    # The wrapped array.
    array: jax.Array | npt.NDArray

    # Precision used to compute the hash, and absolute tolerance of the
    # comparison performed when `large_array` is enabled.
    precision: float | None = dataclasses.field(
        default=1e-9, repr=False, compare=False, hash=False
    )

    # Whether the equality operator compares the arrays with `np.allclose`
    # instead of comparing the hashes.
    large_array: jax_dataclasses.Static[bool] = dataclasses.field(
        default=False, repr=False, compare=False, hash=False
    )

    def get(self) -> jax.Array | npt.NDArray:
        """
        Get the wrapped array.

        Returns:
            The wrapped array.
        """
        return self.array

    def __hash__(self) -> int:
        """Return the hash of the wrapped array computed by `hash_of_array`."""

        return HashedNumpyArray.hash_of_array(
            array=self.array, precision=self.precision
        )

    def __eq__(self, other: HashedNumpyArray) -> bool:
        """
        Compare the wrapper with another object.

        Args:
            other: The object to compare with.

        Returns:
            False if `other` is not a `HashedNumpyArray`. Otherwise, the result
            of `np.allclose` on the two arrays if `large_array` is enabled, or
            whether the hashes of the two wrappers are equal.
        """

        if not isinstance(other, HashedNumpyArray):
            return False

        if self.large_array:
            tolerance = {} if self.precision is None else dict(atol=self.precision)

            # Convert the NumPy boolean to a plain Python bool.
            return bool(np.allclose(self.array, other.array, **tolerance))

        return hash(self) == hash(other)

    @staticmethod
    def hash_of_array(
        array: jax.Array | npt.NDArray, precision: float | None = 1e-9
    ) -> int:
        """
        Calculate the hash of a NumPy array.

        Args:
            array: The array to hash.
            precision: Optionally limit the precision over which the hash is computed.

        Returns:
            The hash of the array.
        """

        array = np.array(array).flatten()

        # NaN never compares equal to anything, itself included, therefore the
        # NaN entries must be detected with `np.isnan`. The check is restricted
        # to inexact dtypes since `np.isnan` is not defined for all dtypes.
        if np.issubdtype(array.dtype, np.inexact):
            array = np.where(np.isnan(array), hash(np.nan), array)

        array = np.where(array == np.inf, hash(np.inf), array)
        array = np.where(array == -np.inf, hash(-np.inf), array)

        if precision is not None:

            # Split each value into integer chunks so that the hash is computed
            # on integers instead of floats.
            integer1 = (array * precision).astype(int)
            integer2 = (array - integer1 / precision).astype(int)

            # NOTE(review): the 1e9 factor matches `1 / precision` only for the
            # default precision -- confirm whether this is intended.
            decimal_array = ((array - integer1 * 1e9 - integer2) / precision).astype(
                int
            )

            array = np.hstack([integer1, integer2, decimal_array]).astype(int)

        return hash(tuple(array.tolist()))
|