jaxsim 0.2.1.dev80__py3-none-any.whl → 0.2.1.dev98__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jaxsim/_version.py +2 -2
- jaxsim/api/data.py +19 -1
- jaxsim/api/frame.py +4 -4
- jaxsim/api/kin_dyn_parameters.py +21 -18
- jaxsim/api/model.py +26 -7
- jaxsim/api/ode_data.py +31 -0
- jaxsim/math/joint_model.py +25 -18
- jaxsim/parsers/descriptions/joint.py +3 -1
- jaxsim/rbda/soft_contacts.py +17 -0
- jaxsim/utils/__init__.py +1 -1
- jaxsim/utils/wrappers.py +78 -0
- {jaxsim-0.2.1.dev80.dist-info → jaxsim-0.2.1.dev98.dist-info}/METADATA +1 -1
- {jaxsim-0.2.1.dev80.dist-info → jaxsim-0.2.1.dev98.dist-info}/RECORD +16 -16
- jaxsim/utils/hashless.py +0 -18
- {jaxsim-0.2.1.dev80.dist-info → jaxsim-0.2.1.dev98.dist-info}/LICENSE +0 -0
- {jaxsim-0.2.1.dev80.dist-info → jaxsim-0.2.1.dev98.dist-info}/WHEEL +0 -0
- {jaxsim-0.2.1.dev80.dist-info → jaxsim-0.2.1.dev98.dist-info}/top_level.txt +0 -0
jaxsim/_version.py
CHANGED
@@ -12,5 +12,5 @@ __version__: str
|
|
12
12
|
__version_tuple__: VERSION_TUPLE
|
13
13
|
version_tuple: VERSION_TUPLE
|
14
14
|
|
15
|
-
__version__ = version = '0.2.1.
|
16
|
-
__version_tuple__ = version_tuple = (0, 2, 1, '
|
15
|
+
__version__ = version = '0.2.1.dev98'
|
16
|
+
__version_tuple__ = version_tuple = (0, 2, 1, 'dev98')
|
jaxsim/api/data.py
CHANGED
@@ -30,7 +30,7 @@ except ImportError:
|
|
30
30
|
@jax_dataclasses.pytree_dataclass
|
31
31
|
class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
|
32
32
|
"""
|
33
|
-
Class containing the
|
33
|
+
Class containing the data of a `JaxSimModel` object.
|
34
34
|
"""
|
35
35
|
|
36
36
|
state: ODEState
|
@@ -43,6 +43,24 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
|
|
43
43
|
default_factory=lambda: jnp.array(0, dtype=jnp.uint64)
|
44
44
|
)
|
45
45
|
|
46
|
+
def __hash__(self) -> int:
|
47
|
+
|
48
|
+
return hash(
|
49
|
+
(
|
50
|
+
hash(self.state),
|
51
|
+
hash(tuple(self.gravity.flatten().tolist())),
|
52
|
+
hash(self.soft_contacts_params),
|
53
|
+
hash(jnp.atleast_1d(self.time_ns).flatten().tolist()),
|
54
|
+
)
|
55
|
+
)
|
56
|
+
|
57
|
+
def __eq__(self, other: JaxSimModelData) -> bool:
|
58
|
+
|
59
|
+
if not isinstance(other, JaxSimModelData):
|
60
|
+
return False
|
61
|
+
|
62
|
+
return hash(self) == hash(other)
|
63
|
+
|
46
64
|
def valid(self, model: js.model.JaxSimModel | None = None) -> bool:
|
47
65
|
"""
|
48
66
|
Check if the current state is valid for the given model.
|
jaxsim/api/frame.py
CHANGED
@@ -30,7 +30,7 @@ def idx_of_parent_link(model: js.model.JaxSimModel, *, frame_idx: jtp.IntLike) -
|
|
30
30
|
"""
|
31
31
|
|
32
32
|
# Get the intermediate representation parsed from the model description.
|
33
|
-
ir = model.description
|
33
|
+
ir = model.description
|
34
34
|
|
35
35
|
# Extract the indices of the frame and the link it is attached to.
|
36
36
|
F = ir.frames[frame_idx - model.number_of_links()]
|
@@ -51,7 +51,7 @@ def name_to_idx(model: js.model.JaxSimModel, *, frame_name: str) -> int:
|
|
51
51
|
The index of the frame.
|
52
52
|
"""
|
53
53
|
|
54
|
-
frame_names = np.array([frame.name for frame in model.description.
|
54
|
+
frame_names = np.array([frame.name for frame in model.description.frames])
|
55
55
|
|
56
56
|
if frame_name in frame_names:
|
57
57
|
idx_in_list = np.argwhere(frame_names == frame_name)
|
@@ -72,7 +72,7 @@ def idx_to_name(model: js.model.JaxSimModel, *, frame_index: jtp.IntLike) -> str
|
|
72
72
|
The name of the frame.
|
73
73
|
"""
|
74
74
|
|
75
|
-
return model.description.
|
75
|
+
return model.description.frames[frame_index - model.number_of_links()].name
|
76
76
|
|
77
77
|
|
78
78
|
@functools.partial(jax.jit, static_argnames=["frame_names"])
|
@@ -144,7 +144,7 @@ def transform(
|
|
144
144
|
W_H_L = js.link.transform(model=model, data=data, link_index=L)
|
145
145
|
|
146
146
|
# Get the static frame pose wrt the parent link.
|
147
|
-
frame = model.description.
|
147
|
+
frame = model.description.frames[frame_index - model.number_of_links()]
|
148
148
|
L_H_F = frame.pose
|
149
149
|
|
150
150
|
# Combine the transforms computing the frame pose.
|
jaxsim/api/kin_dyn_parameters.py
CHANGED
@@ -11,7 +11,7 @@ from jax_dataclasses import Static
|
|
11
11
|
import jaxsim.typing as jtp
|
12
12
|
from jaxsim.math import Inertia, JointModel, supported_joint_motion
|
13
13
|
from jaxsim.parsers.descriptions import JointDescription, ModelDescription
|
14
|
-
from jaxsim.utils import JaxsimDataclass
|
14
|
+
from jaxsim.utils import HashedNumpyArray, JaxsimDataclass
|
15
15
|
|
16
16
|
|
17
17
|
@jax_dataclasses.pytree_dataclass
|
@@ -32,8 +32,8 @@ class KynDynParameters(JaxsimDataclass):
|
|
32
32
|
|
33
33
|
# Static
|
34
34
|
link_names: Static[tuple[str]]
|
35
|
-
|
36
|
-
|
35
|
+
_parent_array: Static[HashedNumpyArray]
|
36
|
+
_support_body_array_bool: Static[HashedNumpyArray]
|
37
37
|
|
38
38
|
# Links
|
39
39
|
link_parameters: LinkParameters
|
@@ -45,6 +45,14 @@ class KynDynParameters(JaxsimDataclass):
|
|
45
45
|
joint_model: JointModel
|
46
46
|
joint_parameters: JointParameters | None
|
47
47
|
|
48
|
+
@property
|
49
|
+
def parent_array(self) -> jtp.Vector:
|
50
|
+
return self._parent_array.get()
|
51
|
+
|
52
|
+
@property
|
53
|
+
def support_body_array_bool(self) -> jtp.Matrix:
|
54
|
+
return self._support_body_array_bool.get()
|
55
|
+
|
48
56
|
@staticmethod
|
49
57
|
def build(model_description: ModelDescription) -> KynDynParameters:
|
50
58
|
"""
|
@@ -191,8 +199,8 @@ class KynDynParameters(JaxsimDataclass):
|
|
191
199
|
|
192
200
|
return KynDynParameters(
|
193
201
|
link_names=tuple(l.name for l in ordered_links),
|
194
|
-
|
195
|
-
|
202
|
+
_parent_array=HashedNumpyArray(array=parent_array),
|
203
|
+
_support_body_array_bool=HashedNumpyArray(array=support_body_array_bool),
|
196
204
|
link_parameters=link_parameters,
|
197
205
|
joint_model=joint_model,
|
198
206
|
joint_parameters=joint_parameters,
|
@@ -204,23 +212,18 @@ class KynDynParameters(JaxsimDataclass):
|
|
204
212
|
if not isinstance(other, KynDynParameters):
|
205
213
|
return False
|
206
214
|
|
207
|
-
|
208
|
-
equal = equal and self.number_of_links() == other.number_of_links()
|
209
|
-
equal = equal and self.number_of_joints() == other.number_of_joints()
|
210
|
-
equal = equal and jnp.allclose(self.parent_array, other.parent_array)
|
211
|
-
|
212
|
-
return equal
|
215
|
+
return hash(self) == hash(other)
|
213
216
|
|
214
217
|
def __hash__(self) -> int:
|
215
218
|
|
216
|
-
|
217
|
-
|
218
|
-
|
219
|
-
|
219
|
+
return hash(
|
220
|
+
(
|
221
|
+
hash(self.number_of_links()),
|
222
|
+
hash(self.number_of_joints()),
|
223
|
+
hash(tuple(jnp.atleast_1d(self.parent_array).flatten().tolist())),
|
224
|
+
)
|
220
225
|
)
|
221
226
|
|
222
|
-
return hash(h)
|
223
|
-
|
224
227
|
# =============================
|
225
228
|
# Helpers to extract parameters
|
226
229
|
# =============================
|
@@ -388,7 +391,7 @@ class KynDynParameters(JaxsimDataclass):
|
|
388
391
|
pre_H_suc_J, S_J = jax.vmap(supported_joint_motion)(
|
389
392
|
jnp.array(self.joint_model.joint_types[1:]).astype(int),
|
390
393
|
jnp.array(joint_positions),
|
391
|
-
jnp.array(self.joint_model.joint_axis),
|
394
|
+
jnp.array([j.axis for j in self.joint_model.joint_axis]),
|
392
395
|
)
|
393
396
|
|
394
397
|
# Extract the transforms and motion subspaces of the joints.
|
jaxsim/api/model.py
CHANGED
@@ -32,18 +32,37 @@ class JaxSimModel(JaxsimDataclass):
|
|
32
32
|
terrain: Static[jaxsim.terrain.Terrain] = dataclasses.field(
|
33
33
|
default=jaxsim.terrain.FlatTerrain(), repr=False, compare=False, hash=False
|
34
34
|
)
|
35
|
+
kin_dyn_parameters: js.kin_dyn_parameters.KynDynParameters | None = (
|
36
|
+
dataclasses.field(default=None, repr=False, compare=False, hash=False)
|
37
|
+
)
|
35
38
|
|
36
39
|
built_from: Static[str | pathlib.Path | rod.Model | None] = dataclasses.field(
|
37
40
|
default=None, repr=False, compare=False, hash=False
|
38
41
|
)
|
39
42
|
|
40
|
-
|
43
|
+
_description: Static[
|
41
44
|
HashlessObject[jaxsim.parsers.descriptions.ModelDescription | None]
|
42
45
|
] = dataclasses.field(default=None, repr=False, compare=False, hash=False)
|
43
46
|
|
44
|
-
|
45
|
-
|
46
|
-
|
47
|
+
@property
|
48
|
+
def description(self) -> jaxsim.parsers.descriptions.ModelDescription:
|
49
|
+
return self._description.get()
|
50
|
+
|
51
|
+
def __eq__(self, other: JaxSimModel) -> bool:
|
52
|
+
|
53
|
+
if not isinstance(other, JaxSimModel):
|
54
|
+
return False
|
55
|
+
|
56
|
+
return hash(self) == hash(other)
|
57
|
+
|
58
|
+
def __hash__(self) -> int:
|
59
|
+
|
60
|
+
return hash(
|
61
|
+
(
|
62
|
+
hash(self.model_name),
|
63
|
+
hash(self.kin_dyn_parameters),
|
64
|
+
)
|
65
|
+
)
|
47
66
|
|
48
67
|
# ========================
|
49
68
|
# Initialization and state
|
@@ -137,7 +156,7 @@ class JaxSimModel(JaxsimDataclass):
|
|
137
156
|
# Build the model
|
138
157
|
model = JaxSimModel(
|
139
158
|
model_name=model_name,
|
140
|
-
|
159
|
+
_description=HashlessObject(obj=model_description),
|
141
160
|
kin_dyn_parameters=js.kin_dyn_parameters.KynDynParameters.build(
|
142
161
|
model_description=model_description
|
143
162
|
),
|
@@ -260,7 +279,7 @@ class JaxSimModel(JaxsimDataclass):
|
|
260
279
|
The names of the links in the model.
|
261
280
|
"""
|
262
281
|
|
263
|
-
return tuple(
|
282
|
+
return tuple(frame.name for frame in self.description.frames)
|
264
283
|
|
265
284
|
|
266
285
|
# =====================
|
@@ -297,7 +316,7 @@ def reduce(
|
|
297
316
|
|
298
317
|
# Copy the model description with a deep copy of the joints.
|
299
318
|
intermediate_description = dataclasses.replace(
|
300
|
-
model.description
|
319
|
+
model.description, joints=copy.deepcopy(model.description.joints)
|
301
320
|
)
|
302
321
|
|
303
322
|
# Update the initial position of the joints.
|
jaxsim/api/ode_data.py
CHANGED
@@ -281,6 +281,24 @@ class PhysicsModelState(JaxsimDataclass):
|
|
281
281
|
default_factory=lambda: jnp.zeros(3)
|
282
282
|
)
|
283
283
|
|
284
|
+
def __hash__(self) -> int:
|
285
|
+
|
286
|
+
return hash(
|
287
|
+
(
|
288
|
+
hash(tuple(jnp.atleast_1d(self.joint_positions.flatten().tolist()))),
|
289
|
+
hash(tuple(jnp.atleast_1d(self.joint_velocities.flatten().tolist()))),
|
290
|
+
hash(tuple(self.base_position.flatten().tolist())),
|
291
|
+
hash(tuple(self.base_quaternion.flatten().tolist())),
|
292
|
+
)
|
293
|
+
)
|
294
|
+
|
295
|
+
def __eq__(self, other: PhysicsModelState) -> bool:
|
296
|
+
|
297
|
+
if not isinstance(other, PhysicsModelState):
|
298
|
+
return False
|
299
|
+
|
300
|
+
return hash(self) == hash(other)
|
301
|
+
|
284
302
|
@staticmethod
|
285
303
|
def build_from_jaxsim_model(
|
286
304
|
model: js.model.JaxSimModel | None = None,
|
@@ -593,6 +611,19 @@ class SoftContactsState(JaxsimDataclass):
|
|
593
611
|
|
594
612
|
tangential_deformation: jtp.Matrix
|
595
613
|
|
614
|
+
def __hash__(self) -> int:
|
615
|
+
|
616
|
+
return hash(
|
617
|
+
tuple(jnp.atleast_1d(self.tangential_deformation.flatten()).tolist())
|
618
|
+
)
|
619
|
+
|
620
|
+
def __eq__(self, other: SoftContactsState) -> bool:
|
621
|
+
|
622
|
+
if not isinstance(other, SoftContactsState):
|
623
|
+
return False
|
624
|
+
|
625
|
+
return hash(self) == hash(other)
|
626
|
+
|
596
627
|
@staticmethod
|
597
628
|
def build_from_jaxsim_model(
|
598
629
|
model: js.model.JaxSimModel | None = None,
|
jaxsim/math/joint_model.py
CHANGED
@@ -39,7 +39,7 @@ class JointModel:
|
|
39
39
|
|
40
40
|
joint_dofs: Static[tuple[int, ...]]
|
41
41
|
joint_names: Static[tuple[str, ...]]
|
42
|
-
joint_types: Static[tuple[
|
42
|
+
joint_types: Static[tuple[int, ...]]
|
43
43
|
joint_axis: Static[tuple[JointGenericAxis, ...]]
|
44
44
|
|
45
45
|
@staticmethod
|
@@ -109,7 +109,7 @@ class JointModel:
|
|
109
109
|
joint_dofs=tuple([base_dofs] + [int(1) for _ in ordered_joints]),
|
110
110
|
joint_names=tuple(["world_to_base"] + [j.name for j in ordered_joints]),
|
111
111
|
joint_types=tuple([JointType.Fixed] + [j.jtype for j in ordered_joints]),
|
112
|
-
joint_axis=tuple(
|
112
|
+
joint_axis=tuple(JointGenericAxis(axis=j.axis) for j in ordered_joints),
|
113
113
|
)
|
114
114
|
|
115
115
|
def parent_H_child(
|
@@ -201,7 +201,7 @@ class JointModel:
|
|
201
201
|
pre_H_suc, S = supported_joint_motion(
|
202
202
|
self.joint_types[joint_index],
|
203
203
|
joint_position,
|
204
|
-
self.joint_axis[joint_index],
|
204
|
+
self.joint_axis[joint_index].axis,
|
205
205
|
)
|
206
206
|
|
207
207
|
return pre_H_suc, S
|
@@ -224,9 +224,9 @@ class JointModel:
|
|
224
224
|
|
225
225
|
@jax.jit
|
226
226
|
def supported_joint_motion(
|
227
|
-
joint_type:
|
227
|
+
joint_type: jtp.IntLike,
|
228
228
|
joint_position: jtp.VectorLike,
|
229
|
-
joint_axis:
|
229
|
+
joint_axis: jtp.VectorLike | None = None,
|
230
230
|
/,
|
231
231
|
) -> tuple[jtp.Matrix, jtp.Array]:
|
232
232
|
"""
|
@@ -234,8 +234,8 @@ def supported_joint_motion(
|
|
234
234
|
|
235
235
|
Args:
|
236
236
|
joint_type: The type of the joint.
|
237
|
-
joint_axis: The axis of rotation or translation of the joint.
|
238
237
|
joint_position: The position of the joint.
|
238
|
+
joint_axis: The optional 3D axis of rotation or translation of the joint.
|
239
239
|
|
240
240
|
Returns:
|
241
241
|
A tuple containing the homogeneous transformation and the motion subspace.
|
@@ -244,26 +244,33 @@ def supported_joint_motion(
|
|
244
244
|
# Prepare the joint position
|
245
245
|
s = jnp.array(joint_position).astype(float)
|
246
246
|
|
247
|
-
def compute_F():
|
247
|
+
def compute_F() -> tuple[jtp.Matrix, jtp.Array]:
|
248
248
|
return jaxlie.SE3.identity(), jnp.zeros(shape=(6, 1))
|
249
249
|
|
250
|
-
def compute_R():
|
250
|
+
def compute_R() -> tuple[jtp.Matrix, jtp.Array]:
|
251
|
+
|
252
|
+
# Get the additional argument specifying the joint axis.
|
253
|
+
# This is a metadata required by only some joint types.
|
254
|
+
axis = jnp.array(joint_axis).astype(float).squeeze()
|
255
|
+
|
251
256
|
pre_H_suc = jaxlie.SE3.from_rotation(
|
252
|
-
rotation=jaxlie.SO3.from_matrix(
|
253
|
-
Rotation.from_axis_angle(vector=s * joint_axis)
|
254
|
-
)
|
257
|
+
rotation=jaxlie.SO3.from_matrix(Rotation.from_axis_angle(vector=s * axis))
|
255
258
|
)
|
256
259
|
|
257
|
-
S = jnp.vstack(jnp.hstack([jnp.zeros(3),
|
260
|
+
S = jnp.vstack(jnp.hstack([jnp.zeros(3), axis]))
|
261
|
+
|
258
262
|
return pre_H_suc, S
|
259
263
|
|
260
|
-
def compute_P():
|
261
|
-
|
262
|
-
|
263
|
-
|
264
|
-
)
|
264
|
+
def compute_P() -> tuple[jtp.Matrix, jtp.Array]:
|
265
|
+
|
266
|
+
# Get the additional argument specifying the joint axis.
|
267
|
+
# This is a metadata required by only some joint types.
|
268
|
+
axis = jnp.array(joint_axis).astype(float).squeeze()
|
269
|
+
|
270
|
+
pre_H_suc = jaxlie.SE3.from_translation(translation=jnp.array(s * axis))
|
271
|
+
|
272
|
+
S = jnp.vstack(jnp.hstack([axis, jnp.zeros(3)]))
|
265
273
|
|
266
|
-
S = jnp.vstack(jnp.hstack([joint_axis.squeeze(), jnp.zeros(3)]))
|
267
274
|
return pre_H_suc, S
|
268
275
|
|
269
276
|
pre_H_suc, S = jax.lax.switch(
|
@@ -30,9 +30,11 @@ class JointGenericAxis:
|
|
30
30
|
axis: jtp.Vector
|
31
31
|
|
32
32
|
def __hash__(self) -> int:
|
33
|
-
|
33
|
+
|
34
|
+
return hash(tuple(self.axis.tolist()))
|
34
35
|
|
35
36
|
def __eq__(self, other: JointGenericAxis) -> bool:
|
37
|
+
|
36
38
|
if not isinstance(other, JointGenericAxis):
|
37
39
|
return False
|
38
40
|
|
jaxsim/rbda/soft_contacts.py
CHANGED
@@ -29,6 +29,23 @@ class SoftContactsParams(JaxsimDataclass):
|
|
29
29
|
default_factory=lambda: jnp.array(0.5, dtype=float)
|
30
30
|
)
|
31
31
|
|
32
|
+
def __hash__(self) -> int:
|
33
|
+
|
34
|
+
return hash(
|
35
|
+
(
|
36
|
+
hash(tuple(jnp.atleast_1d(self.K).flatten().tolist())),
|
37
|
+
hash(tuple(jnp.atleast_1d(self.D).flatten().tolist())),
|
38
|
+
hash(tuple(jnp.atleast_1d(self.mu).flatten().tolist())),
|
39
|
+
)
|
40
|
+
)
|
41
|
+
|
42
|
+
def __eq__(self, other: SoftContactsParams) -> bool:
|
43
|
+
|
44
|
+
if not isinstance(other, SoftContactsParams):
|
45
|
+
return NotImplemented
|
46
|
+
|
47
|
+
return hash(self) == hash(other)
|
48
|
+
|
32
49
|
@staticmethod
|
33
50
|
def build(
|
34
51
|
K: jtp.FloatLike = 1e6, D: jtp.FloatLike = 2_000, mu: jtp.FloatLike = 0.5
|
jaxsim/utils/__init__.py
CHANGED
jaxsim/utils/wrappers.py
ADDED
@@ -0,0 +1,78 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
3
|
+
import dataclasses
|
4
|
+
from typing import Generic, TypeVar
|
5
|
+
|
6
|
+
import jax
|
7
|
+
import jax_dataclasses
|
8
|
+
import numpy as np
|
9
|
+
import numpy.typing as npt
|
10
|
+
|
11
|
+
T = TypeVar("T")
|
12
|
+
|
13
|
+
|
14
|
+
@dataclasses.dataclass
|
15
|
+
class HashlessObject(Generic[T]):
|
16
|
+
"""
|
17
|
+
A class that wraps an object and makes it hashless.
|
18
|
+
|
19
|
+
This is useful for creating particular JAX pytrees.
|
20
|
+
For example, to create a pytree with a static leaf that is ignored
|
21
|
+
by JAX when it compares two instances to trigger a JIT recompilation.
|
22
|
+
"""
|
23
|
+
|
24
|
+
obj: T
|
25
|
+
|
26
|
+
def get(self: HashlessObject[T]) -> T:
|
27
|
+
return self.obj
|
28
|
+
|
29
|
+
def __hash__(self) -> int:
|
30
|
+
|
31
|
+
return 0
|
32
|
+
|
33
|
+
def __eq__(self, other: HashlessObject[T]) -> bool:
|
34
|
+
|
35
|
+
if not isinstance(other, HashlessObject) and isinstance(
|
36
|
+
other.get(), type(self.get())
|
37
|
+
):
|
38
|
+
return False
|
39
|
+
|
40
|
+
return hash(self) == hash(other)
|
41
|
+
|
42
|
+
|
43
|
+
@jax_dataclasses.pytree_dataclass
|
44
|
+
class HashedNumpyArray:
|
45
|
+
"""
|
46
|
+
A class that wraps a numpy array and makes it hashable.
|
47
|
+
|
48
|
+
This is useful for creating particular JAX pytrees.
|
49
|
+
For example, to create a pytree with a plain NumPy or JAX NumPy array as static leaf.
|
50
|
+
|
51
|
+
Note:
|
52
|
+
Calculating with the wrapper class the hash of a very large array can be
|
53
|
+
very expensive. If the array is large and only the equality operator is needed,
|
54
|
+
set `large_array=True` to use a faster comparison method.
|
55
|
+
"""
|
56
|
+
|
57
|
+
array: jax.Array | npt.NDArray
|
58
|
+
|
59
|
+
large_array: jax_dataclasses.Static[bool] = dataclasses.field(
|
60
|
+
default=False, repr=False, compare=False, hash=False
|
61
|
+
)
|
62
|
+
|
63
|
+
def get(self) -> jax.Array | npt.NDArray:
|
64
|
+
return self.array
|
65
|
+
|
66
|
+
def __hash__(self) -> int:
|
67
|
+
|
68
|
+
return hash(tuple(np.atleast_1d(self.array).flatten().tolist()))
|
69
|
+
|
70
|
+
def __eq__(self, other: HashedNumpyArray) -> bool:
|
71
|
+
|
72
|
+
if not isinstance(other, HashedNumpyArray):
|
73
|
+
return False
|
74
|
+
|
75
|
+
if self.large_array:
|
76
|
+
return np.array_equal(self.array, other.array)
|
77
|
+
|
78
|
+
return hash(self) == hash(other)
|
@@ -1,19 +1,19 @@
|
|
1
1
|
jaxsim/__init__.py,sha256=OcrfoYS1DGcmAGqu2AqlCTiUVxcpi-IsVwcr_16x74Q,1789
|
2
|
-
jaxsim/_version.py,sha256=
|
2
|
+
jaxsim/_version.py,sha256=PlgYXFRQrTcDBWrHgW3TWsyD0WOKMdyNQ1dtp2gm-oU,426
|
3
3
|
jaxsim/logging.py,sha256=c4zhwBKf9eAYAHVp62kTEllqdsZgh0K-kPKVy8L3elU,1584
|
4
4
|
jaxsim/typing.py,sha256=MeuOCQtLAr-sPkvB_sU8FtwGNRirz1auCwIgRC-QZl8,646
|
5
5
|
jaxsim/api/__init__.py,sha256=8eV22t2S3UwNyCg8karPetG1dmX1VDBXkyv28_FwNQA,210
|
6
6
|
jaxsim/api/com.py,sha256=Yof6otFi-mLWAs1rqjmeNJTOWIH9gn7BdU5EIjiL6Ts,13481
|
7
7
|
jaxsim/api/common.py,sha256=DV-WZG28sikXopNv458aYvpLjmiAtFr5LRscOwXusuk,6640
|
8
8
|
jaxsim/api/contact.py,sha256=Cvr-EfQtHP3nymtWdo-9WWU24Bkta-2Pp3nKsdjo6uc,12778
|
9
|
-
jaxsim/api/data.py,sha256=
|
10
|
-
jaxsim/api/frame.py,sha256=
|
9
|
+
jaxsim/api/data.py,sha256=xfKJz6Rw0YTk-EHCGiT8BFQrs_ggOz01lRi1Qh1mb28,27256
|
10
|
+
jaxsim/api/frame.py,sha256=0YXOrGmx3cSQqa4_Ky-n6zyup3I3xvXNEgub-Bc5xUw,6222
|
11
11
|
jaxsim/api/joint.py,sha256=-5DogPg4g4mmLckyVIVNjwv-Rxz0IWS7_md9nDlhPWA,4581
|
12
|
-
jaxsim/api/kin_dyn_parameters.py,sha256=
|
12
|
+
jaxsim/api/kin_dyn_parameters.py,sha256=zMca7OmCsCWK_cavLTSZSeYh9Qu1-409cdsyWvWPAUQ,26090
|
13
13
|
jaxsim/api/link.py,sha256=rypTwkMf9HJ5UuAtHRJh0LqqdJWcLKTtTjWcjduEsF0,9842
|
14
|
-
jaxsim/api/model.py,sha256=
|
14
|
+
jaxsim/api/model.py,sha256=Wwg3Wp9jm2Ah7wjvzou7oZYdZk2iTlfzSidp6GNwfJ0,54263
|
15
15
|
jaxsim/api/ode.py,sha256=6l-6i2YHagsQvR8Ac-_fmO6P0hBVT6NkHhwXnrdITEg,9785
|
16
|
-
jaxsim/api/ode_data.py,sha256=
|
16
|
+
jaxsim/api/ode_data.py,sha256=D6FzMkvY_qNuoFEImyp7sxAk-0pJOd3oZeSr9bBTcLk,23089
|
17
17
|
jaxsim/api/references.py,sha256=Lvskf17r619KKxwCJP7hAAty2kaXgDXJX1uKqoDIDgo,15483
|
18
18
|
jaxsim/integrators/__init__.py,sha256=hxvOD-VK_mmd6v31wtC-nb28AYve1gLuZCNLV9wS-Kg,103
|
19
19
|
jaxsim/integrators/common.py,sha256=9HXRVFo95Mpt6RcVhBrOfvOO7mDxqbkXeg_lKUibEFY,20693
|
@@ -23,7 +23,7 @@ jaxsim/math/__init__.py,sha256=inJ9nRFkqstuGa8OyFkfWVudo5U9Ug4WgDBuKva8AIA,337
|
|
23
23
|
jaxsim/math/adjoint.py,sha256=DT21izjVW497GRrgNfx8tv0ZeWW5QncWMGMhI0acUNw,4425
|
24
24
|
jaxsim/math/cross.py,sha256=U7yEx_l75mSy5g6O-jsjBztApvxC3WaV4MpkS5tThu4,1330
|
25
25
|
jaxsim/math/inertia.py,sha256=UAB7ym4gXFanejcs_ovZMpteHCc6poWYmt-mLmd5hhk,1640
|
26
|
-
jaxsim/math/joint_model.py,sha256=
|
26
|
+
jaxsim/math/joint_model.py,sha256=lGZxwuGqIXYeF2dYC5I248-mRNJLvb86iomq7yLEBmE,9909
|
27
27
|
jaxsim/math/quaternion.py,sha256=X9b8jHf0QemKUjIZSnXRJc3DdMr42CBhBy_mi9_X_AM,5068
|
28
28
|
jaxsim/math/rotation.py,sha256=Z90daUjGpuNEVLfWB3SVtM9EtwAIaneVj9A9UpWXqhA,2182
|
29
29
|
jaxsim/math/skew.py,sha256=oOGSSR8PUGROl6IJFlrmu6K3gPH-u16hUPfKIkcVv9o,1177
|
@@ -37,7 +37,7 @@ jaxsim/parsers/__init__.py,sha256=sonYi-bBWAoB04kp1mxT4uIORxjb7SdZ0ukGPmVx98Y,44
|
|
37
37
|
jaxsim/parsers/kinematic_graph.py,sha256=zFt7x7pPGJar36Azukdi1eI_sa1kMWD3B8kZqcHx6iw,33934
|
38
38
|
jaxsim/parsers/descriptions/__init__.py,sha256=PbIlunVfb59pB5jSX97YVpMAANRZPRkJ0X-hS14rzv4,221
|
39
39
|
jaxsim/parsers/descriptions/collision.py,sha256=HUWwuRgI9KznY29FFw1_zU3bGigDEezrcPOJSxSJGNU,3382
|
40
|
-
jaxsim/parsers/descriptions/joint.py,sha256=
|
40
|
+
jaxsim/parsers/descriptions/joint.py,sha256=lRnYMmjpASpz0Ueuqzwnj5Ze4yLRgPTx66H0_kbQnNI,3042
|
41
41
|
jaxsim/parsers/descriptions/link.py,sha256=GC-6ZgRZuRVpcRo1sY6YaR8lkCHkR4DvHNs2Ydw_tn4,2887
|
42
42
|
jaxsim/parsers/descriptions/model.py,sha256=uO5xOJtViihVPnSSsmfQJvCh45ANyi9KYAzLOhH0R8g,8993
|
43
43
|
jaxsim/parsers/rod/__init__.py,sha256=G2vqlLajBLUc4gyzXwsEI2Wsi4TMOIF9bLDFeT6KrGU,92
|
@@ -50,16 +50,16 @@ jaxsim/rbda/crba.py,sha256=GodskOZjtrSlbQAqxRv1un_706O7BaJK-U2qa18vJk8,4741
|
|
50
50
|
jaxsim/rbda/forward_kinematics.py,sha256=OHugNU7C0UxYAW0o1rqH1ZgniSwurz6L1T1MJxfxq08,3418
|
51
51
|
jaxsim/rbda/jacobian.py,sha256=9LGGy9ya5m5U0mBmV1NFH5XYZpEMYbx74qnYBvZs7Ok,6360
|
52
52
|
jaxsim/rbda/rnea.py,sha256=DjwkvXQVUSUclM3Uy3UPZ2tao91R5dGd4o7TsS2qObI,7650
|
53
|
-
jaxsim/rbda/soft_contacts.py,sha256=
|
53
|
+
jaxsim/rbda/soft_contacts.py,sha256=52zJOF31hFpqoaOednTvi8j_UxhRcdGNjzOPb2v2MPc,11257
|
54
54
|
jaxsim/rbda/utils.py,sha256=zpbFM2Iq8cntku0BFVu9nfEqZhInCWi9D2INT6MFEI8,5003
|
55
55
|
jaxsim/terrain/__init__.py,sha256=f7lVX-iNpH_wkkjef9Qpjh19TTAUOQw76EiLYJDVizc,78
|
56
56
|
jaxsim/terrain/terrain.py,sha256=q0xkWqEShVq-p1j2abTLZq8sEhjyJwquxQKm80PaHhM,2161
|
57
|
-
jaxsim/utils/__init__.py,sha256=
|
58
|
-
jaxsim/utils/hashless.py,sha256=bFIwKeo9KiWwsY8QM55duEGGQOyyJ4jQyPcuqTLEp5k,297
|
57
|
+
jaxsim/utils/__init__.py,sha256=Y5zyoRevl3EMVQadhZ4EtSwTEkDt2vcnFoRhPJjKTZ0,215
|
59
58
|
jaxsim/utils/jaxsim_dataclass.py,sha256=h26timZ_XrBL_Q_oymv-DkQd-EcUiHn8QexAaZXBY9c,11396
|
60
59
|
jaxsim/utils/tracing.py,sha256=KDMoyVPlu2NJvFkhtZwq5AkqMMgajt3munvJom-vEjQ,650
|
61
|
-
jaxsim
|
62
|
-
jaxsim-0.2.1.
|
63
|
-
jaxsim-0.2.1.
|
64
|
-
jaxsim-0.2.1.
|
65
|
-
jaxsim-0.2.1.
|
60
|
+
jaxsim/utils/wrappers.py,sha256=EJMcblYKUjxw9HJShVf81Ig3pHUJno6Dx6h-RnY--wM,2040
|
61
|
+
jaxsim-0.2.1.dev98.dist-info/LICENSE,sha256=eaYdFmdeMbiIoIiPzEK0MjP1S9wtFXjXNR5er49uLR0,1546
|
62
|
+
jaxsim-0.2.1.dev98.dist-info/METADATA,sha256=_aAfq6LooqjTyZY6utw1p0NrkwVZr10NC9Gw-MqCR1Y,9744
|
63
|
+
jaxsim-0.2.1.dev98.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
|
64
|
+
jaxsim-0.2.1.dev98.dist-info/top_level.txt,sha256=LxGMA8FLtXjQ6oI7N5gd_R_oSUHxpXxUEOfT1xS_ni0,7
|
65
|
+
jaxsim-0.2.1.dev98.dist-info/RECORD,,
|
jaxsim/utils/hashless.py
DELETED
@@ -1,18 +0,0 @@
|
|
1
|
-
from __future__ import annotations
|
2
|
-
|
3
|
-
import dataclasses
|
4
|
-
from typing import Generic, TypeVar
|
5
|
-
|
6
|
-
T = TypeVar("T")
|
7
|
-
|
8
|
-
|
9
|
-
@dataclasses.dataclass
|
10
|
-
class HashlessObject(Generic[T]):
|
11
|
-
|
12
|
-
obj: T
|
13
|
-
|
14
|
-
def get(self: HashlessObject[T]) -> T:
|
15
|
-
return self.obj
|
16
|
-
|
17
|
-
def __hash__(self) -> int:
|
18
|
-
return 0
|
File without changes
|
File without changes
|
File without changes
|