jaxsim 0.3.1.dev4__py3-none-any.whl → 0.3.1.dev17__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jaxsim/_version.py +2 -2
- jaxsim/api/data.py +4 -2
- jaxsim/api/frame.py +30 -26
- jaxsim/api/kin_dyn_parameters.py +90 -7
- jaxsim/api/model.py +26 -13
- jaxsim/api/ode_data.py +11 -7
- jaxsim/parsers/descriptions/joint.py +46 -12
- jaxsim/parsers/descriptions/link.py +22 -5
- jaxsim/parsers/descriptions/model.py +13 -3
- jaxsim/parsers/kinematic_graph.py +19 -7
- jaxsim/rbda/soft_contacts.py +5 -3
- jaxsim/utils/wrappers.py +69 -2
- {jaxsim-0.3.1.dev4.dist-info → jaxsim-0.3.1.dev17.dist-info}/METADATA +1 -1
- {jaxsim-0.3.1.dev4.dist-info → jaxsim-0.3.1.dev17.dist-info}/RECORD +17 -17
- {jaxsim-0.3.1.dev4.dist-info → jaxsim-0.3.1.dev17.dist-info}/LICENSE +0 -0
- {jaxsim-0.3.1.dev4.dist-info → jaxsim-0.3.1.dev17.dist-info}/WHEEL +0 -0
- {jaxsim-0.3.1.dev4.dist-info → jaxsim-0.3.1.dev17.dist-info}/top_level.txt +0 -0
jaxsim/_version.py
CHANGED
@@ -12,5 +12,5 @@ __version__: str
|
|
12
12
|
__version_tuple__: VERSION_TUPLE
|
13
13
|
version_tuple: VERSION_TUPLE
|
14
14
|
|
15
|
-
__version__ = version = '0.3.1.
|
16
|
-
__version_tuple__ = version_tuple = (0, 3, 1, '
|
15
|
+
__version__ = version = '0.3.1.dev17'
|
16
|
+
__version_tuple__ = version_tuple = (0, 3, 1, 'dev17')
|
jaxsim/api/data.py
CHANGED
@@ -45,12 +45,14 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
|
|
45
45
|
|
46
46
|
def __hash__(self) -> int:
|
47
47
|
|
48
|
+
from jaxsim.utils.wrappers import HashedNumpyArray
|
49
|
+
|
48
50
|
return hash(
|
49
51
|
(
|
50
52
|
hash(self.state),
|
51
|
-
|
53
|
+
HashedNumpyArray.hash_of_array(self.gravity),
|
52
54
|
hash(self.soft_contacts_params),
|
53
|
-
|
55
|
+
HashedNumpyArray.hash_of_array(self.time_ns),
|
54
56
|
)
|
55
57
|
)
|
56
58
|
|
jaxsim/api/frame.py
CHANGED
@@ -17,7 +17,9 @@ from .common import VelRepr
|
|
17
17
|
# =======================
|
18
18
|
|
19
19
|
|
20
|
-
def idx_of_parent_link(
|
20
|
+
def idx_of_parent_link(
|
21
|
+
model: js.model.JaxSimModel, *, frame_idx: jtp.IntLike
|
22
|
+
) -> jtp.Int:
|
21
23
|
"""
|
22
24
|
Get the index of the link to which the frame is rigidly attached.
|
23
25
|
|
@@ -29,17 +31,13 @@ def idx_of_parent_link(model: js.model.JaxSimModel, *, frame_idx: jtp.IntLike) -
|
|
29
31
|
The index of the frame's parent link.
|
30
32
|
"""
|
31
33
|
|
32
|
-
|
33
|
-
|
34
|
+
return model.kin_dyn_parameters.frame_parameters.body[
|
35
|
+
frame_idx - model.number_of_links()
|
36
|
+
]
|
34
37
|
|
35
|
-
# Extract the indices of the frame and the link it is attached to.
|
36
|
-
F = ir.frames[frame_idx - model.number_of_links()]
|
37
|
-
L = ir.links_dict[F.parent.name].index
|
38
38
|
|
39
|
-
|
40
|
-
|
41
|
-
|
42
|
-
def name_to_idx(model: js.model.JaxSimModel, *, frame_name: str) -> int:
|
39
|
+
@functools.partial(jax.jit, static_argnames="frame_name")
|
40
|
+
def name_to_idx(model: js.model.JaxSimModel, *, frame_name: str) -> jtp.Int:
|
43
41
|
"""
|
44
42
|
Convert the name of a frame to its index.
|
45
43
|
|
@@ -51,13 +49,19 @@ def name_to_idx(model: js.model.JaxSimModel, *, frame_name: str) -> int:
|
|
51
49
|
The index of the frame.
|
52
50
|
"""
|
53
51
|
|
54
|
-
|
52
|
+
if frame_name in model.kin_dyn_parameters.frame_parameters.name:
|
53
|
+
return (
|
54
|
+
jnp.array(
|
55
|
+
np.argwhere(
|
56
|
+
np.array(model.kin_dyn_parameters.frame_parameters.name)
|
57
|
+
== frame_name
|
58
|
+
)
|
59
|
+
)
|
60
|
+
.squeeze()
|
61
|
+
.astype(int)
|
62
|
+
) + model.number_of_links()
|
55
63
|
|
56
|
-
|
57
|
-
idx_in_list = np.argwhere(frame_names == frame_name)
|
58
|
-
return int(idx_in_list.squeeze().tolist()) + model.number_of_links()
|
59
|
-
|
60
|
-
return -1
|
64
|
+
return jnp.array(-1).astype(int)
|
61
65
|
|
62
66
|
|
63
67
|
def idx_to_name(model: js.model.JaxSimModel, *, frame_index: jtp.IntLike) -> str:
|
@@ -72,7 +76,9 @@ def idx_to_name(model: js.model.JaxSimModel, *, frame_index: jtp.IntLike) -> str
|
|
72
76
|
The name of the frame.
|
73
77
|
"""
|
74
78
|
|
75
|
-
return model.
|
79
|
+
return model.kin_dyn_parameters.frame_parameters.name[
|
80
|
+
frame_index - model.number_of_links()
|
81
|
+
]
|
76
82
|
|
77
83
|
|
78
84
|
@functools.partial(jax.jit, static_argnames=["frame_names"])
|
@@ -91,7 +97,7 @@ def names_to_idxs(
|
|
91
97
|
"""
|
92
98
|
|
93
99
|
return jnp.array(
|
94
|
-
[name_to_idx(model=model, frame_name=
|
100
|
+
[name_to_idx(model=model, frame_name=name) for name in frame_names]
|
95
101
|
).astype(int)
|
96
102
|
|
97
103
|
|
@@ -109,10 +115,7 @@ def idxs_to_names(
|
|
109
115
|
The names of the frames.
|
110
116
|
"""
|
111
117
|
|
112
|
-
return tuple(
|
113
|
-
idx_to_name(model=model, frame_index=frame_index)
|
114
|
-
for frame_index in frame_indices
|
115
|
-
)
|
118
|
+
return tuple(idx_to_name(model=model, frame_index=idx) for idx in frame_indices)
|
116
119
|
|
117
120
|
|
118
121
|
# ==========
|
@@ -120,7 +123,7 @@ def idxs_to_names(
|
|
120
123
|
# ==========
|
121
124
|
|
122
125
|
|
123
|
-
@
|
126
|
+
@jax.jit
|
124
127
|
def transform(
|
125
128
|
model: js.model.JaxSimModel,
|
126
129
|
data: js.data.JaxSimModelData,
|
@@ -144,14 +147,15 @@ def transform(
|
|
144
147
|
W_H_L = js.link.transform(model=model, data=data, link_index=L)
|
145
148
|
|
146
149
|
# Get the static frame pose wrt the parent link.
|
147
|
-
|
148
|
-
|
150
|
+
L_H_F = model.kin_dyn_parameters.frame_parameters.transform[
|
151
|
+
frame_index - model.number_of_links()
|
152
|
+
]
|
149
153
|
|
150
154
|
# Combine the transforms computing the frame pose.
|
151
155
|
return W_H_L @ L_H_F
|
152
156
|
|
153
157
|
|
154
|
-
@functools.partial(jax.jit, static_argnames=["
|
158
|
+
@functools.partial(jax.jit, static_argnames=["output_vel_repr"])
|
155
159
|
def jacobian(
|
156
160
|
model: js.model.JaxSimModel,
|
157
161
|
data: js.data.JaxSimModelData,
|
jaxsim/api/kin_dyn_parameters.py
CHANGED
@@ -6,7 +6,6 @@ import jax.lax
|
|
6
6
|
import jax.numpy as jnp
|
7
7
|
import jax_dataclasses
|
8
8
|
import jaxlie
|
9
|
-
import numpy as np
|
10
9
|
from jax_dataclasses import Static
|
11
10
|
|
12
11
|
import jaxsim.typing as jtp
|
@@ -15,7 +14,7 @@ from jaxsim.parsers.descriptions import JointDescription, ModelDescription
|
|
15
14
|
from jaxsim.utils import HashedNumpyArray, JaxsimDataclass
|
16
15
|
|
17
16
|
|
18
|
-
@jax_dataclasses.pytree_dataclass
|
17
|
+
@jax_dataclasses.pytree_dataclass(eq=False, unsafe_hash=False)
|
19
18
|
class KynDynParameters(JaxsimDataclass):
|
20
19
|
r"""
|
21
20
|
Class storing the kinematic and dynamic parameters of a model.
|
@@ -26,6 +25,7 @@ class KynDynParameters(JaxsimDataclass):
|
|
26
25
|
support_body_array_bool:
|
27
26
|
The boolean support parent array :math:`\kappa_{b}(i)` of the model.
|
28
27
|
link_parameters: The parameters of the links.
|
28
|
+
frame_parameters: The parameters of the frames.
|
29
29
|
contact_parameters: The parameters of the collidable points.
|
30
30
|
joint_model: The joint model of the model.
|
31
31
|
joint_parameters: The parameters of the joints.
|
@@ -42,6 +42,9 @@ class KynDynParameters(JaxsimDataclass):
|
|
42
42
|
# Contacts
|
43
43
|
contact_parameters: ContactParameters
|
44
44
|
|
45
|
+
# Frames
|
46
|
+
frame_parameters: FrameParameters
|
47
|
+
|
45
48
|
# Joints
|
46
49
|
joint_model: JointModel
|
47
50
|
joint_parameters: JointParameters | None
|
@@ -141,6 +144,19 @@ class KynDynParameters(JaxsimDataclass):
|
|
141
144
|
model_description=model_description
|
142
145
|
)
|
143
146
|
|
147
|
+
# =================
|
148
|
+
# Frames properties
|
149
|
+
# =================
|
150
|
+
|
151
|
+
# Create the object storing the parameters of frames.
|
152
|
+
# Note that, contrarily to LinkParameters and JointsParameters, this object
|
153
|
+
# is not created with vmap. This is because the "name" attribute of the object
|
154
|
+
# must be Static for JIT-related reasons, and tree_map would not consider it
|
155
|
+
# as a leaf.
|
156
|
+
frame_parameters = FrameParameters.build_from(
|
157
|
+
model_description=model_description
|
158
|
+
)
|
159
|
+
|
144
160
|
# ===============
|
145
161
|
# Tree properties
|
146
162
|
# ===============
|
@@ -206,6 +222,7 @@ class KynDynParameters(JaxsimDataclass):
|
|
206
222
|
joint_model=joint_model,
|
207
223
|
joint_parameters=joint_parameters,
|
208
224
|
contact_parameters=contact_parameters,
|
225
|
+
frame_parameters=frame_parameters,
|
209
226
|
)
|
210
227
|
|
211
228
|
def __eq__(self, other: KynDynParameters) -> bool:
|
@@ -221,7 +238,8 @@ class KynDynParameters(JaxsimDataclass):
|
|
221
238
|
(
|
222
239
|
hash(self.number_of_links()),
|
223
240
|
hash(self.number_of_joints()),
|
224
|
-
hash(
|
241
|
+
hash(self.frame_parameters.name),
|
242
|
+
hash(tuple(self.frame_parameters.body.tolist())),
|
225
243
|
hash(self._parent_array),
|
226
244
|
hash(self._support_body_array_bool),
|
227
245
|
)
|
@@ -730,7 +748,7 @@ class ContactParameters(JaxsimDataclass):
|
|
730
748
|
A tuple of integers representing, for each collidable point, the index of
|
731
749
|
the body (link) to which it is rigidly attached to.
|
732
750
|
point:
|
733
|
-
The
|
751
|
+
The translations between the link frame and the collidable point, expressed
|
734
752
|
in the coordinates of the parent link frame.
|
735
753
|
|
736
754
|
Note:
|
@@ -773,10 +791,75 @@ class ContactParameters(JaxsimDataclass):
|
|
773
791
|
links_dict[cp.parent_link.name].index for cp in collidable_points
|
774
792
|
)
|
775
793
|
|
776
|
-
# Build the
|
794
|
+
# Build the ContactParameters object.
|
777
795
|
cp = ContactParameters(point=points, body=link_index_of_points) # noqa
|
778
796
|
|
779
|
-
assert cp.point.shape[1] == 3
|
780
|
-
assert cp.point.shape[0] == len(cp.body)
|
797
|
+
assert cp.point.shape[1] == 3, cp.point.shape[1]
|
798
|
+
assert cp.point.shape[0] == len(cp.body), cp.point.shape[0]
|
781
799
|
|
782
800
|
return cp
|
801
|
+
|
802
|
+
|
803
|
+
@jax_dataclasses.pytree_dataclass
|
804
|
+
class FrameParameters(JaxsimDataclass):
|
805
|
+
"""
|
806
|
+
Class storing the frame parameters of a model.
|
807
|
+
|
808
|
+
Attributes:
|
809
|
+
name: A tuple of strings defining the frame names.
|
810
|
+
body:
|
811
|
+
A vector of integers representing, for each frame, the index of
|
812
|
+
the body (link) to which it is rigidly attached to.
|
813
|
+
transform: The transforms of the frames w.r.t. their parent link.
|
814
|
+
|
815
|
+
Note:
|
816
|
+
Contrarily to LinkParameters and JointParameters, this class is not meant
|
817
|
+
to be created with vmap. This is because the `name` attribute must be `Static`.
|
818
|
+
"""
|
819
|
+
|
820
|
+
name: Static[tuple[str, ...]] = dataclasses.field(default_factory=tuple)
|
821
|
+
|
822
|
+
body: jtp.Vector = dataclasses.field(default_factory=lambda: jnp.array([]))
|
823
|
+
|
824
|
+
transform: jtp.Array = dataclasses.field(default_factory=lambda: jnp.array([]))
|
825
|
+
|
826
|
+
@staticmethod
|
827
|
+
def build_from(model_description: ModelDescription) -> FrameParameters:
|
828
|
+
"""
|
829
|
+
Build a FrameParameters object from a model description.
|
830
|
+
|
831
|
+
Args:
|
832
|
+
model_description: The model description to consider.
|
833
|
+
|
834
|
+
Returns:
|
835
|
+
The FrameParameters object.
|
836
|
+
"""
|
837
|
+
|
838
|
+
if len(model_description.frames) == 0:
|
839
|
+
return FrameParameters()
|
840
|
+
|
841
|
+
# Extract the frame names.
|
842
|
+
names = tuple(frame.name for frame in model_description.frames)
|
843
|
+
|
844
|
+
# For each frame, extract the index of the link to which it is attached to.
|
845
|
+
parent_link_index_of_frames = tuple(
|
846
|
+
model_description.links_dict[frame.parent.name].index
|
847
|
+
for frame in model_description.frames
|
848
|
+
)
|
849
|
+
|
850
|
+
# For each frame, extract the transform w.r.t. its parent link.
|
851
|
+
transforms = jnp.atleast_3d(
|
852
|
+
jnp.stack([frame.pose for frame in model_description.frames])
|
853
|
+
)
|
854
|
+
|
855
|
+
# Build the FrameParameters object.
|
856
|
+
fp = FrameParameters(
|
857
|
+
name=names,
|
858
|
+
transform=transforms.astype(float),
|
859
|
+
body=jnp.array(parent_link_index_of_frames).astype(int),
|
860
|
+
)
|
861
|
+
|
862
|
+
assert fp.transform.shape[1:] == (4, 4), fp.transform.shape[1:]
|
863
|
+
assert fp.transform.shape[0] == len(fp.body), fp.transform.shape[0]
|
864
|
+
|
865
|
+
return fp
|
jaxsim/api/model.py
CHANGED
@@ -17,12 +17,12 @@ import jaxsim.api as js
|
|
17
17
|
import jaxsim.parsers.descriptions
|
18
18
|
import jaxsim.typing as jtp
|
19
19
|
from jaxsim.math import Cross
|
20
|
-
from jaxsim.utils import JaxsimDataclass, Mutability
|
20
|
+
from jaxsim.utils import JaxsimDataclass, Mutability, wrappers
|
21
21
|
|
22
22
|
from .common import VelRepr
|
23
23
|
|
24
24
|
|
25
|
-
@jax_dataclasses.pytree_dataclass
|
25
|
+
@jax_dataclasses.pytree_dataclass(eq=False, unsafe_hash=False)
|
26
26
|
class JaxSimModel(JaxsimDataclass):
|
27
27
|
"""
|
28
28
|
The JaxSim model defining the kinematics and dynamics of a robot.
|
@@ -31,34 +31,43 @@ class JaxSimModel(JaxsimDataclass):
|
|
31
31
|
model_name: Static[str]
|
32
32
|
|
33
33
|
terrain: Static[jaxsim.terrain.Terrain] = dataclasses.field(
|
34
|
-
default=jaxsim.terrain.FlatTerrain(), repr=False
|
34
|
+
default=jaxsim.terrain.FlatTerrain(), repr=False
|
35
35
|
)
|
36
36
|
|
37
37
|
kin_dyn_parameters: js.kin_dyn_parameters.KynDynParameters | None = (
|
38
|
-
dataclasses.field(default=None, repr=False
|
38
|
+
dataclasses.field(default=None, repr=False)
|
39
39
|
)
|
40
40
|
|
41
41
|
built_from: Static[str | pathlib.Path | rod.Model | None] = dataclasses.field(
|
42
|
-
default=None, repr=False
|
42
|
+
default=None, repr=False
|
43
43
|
)
|
44
44
|
|
45
|
-
|
46
|
-
|
47
|
-
)
|
45
|
+
_description: Static[
|
46
|
+
wrappers.HashlessObject[jaxsim.parsers.descriptions.ModelDescription | None]
|
47
|
+
] = dataclasses.field(default=None, repr=False)
|
48
|
+
|
49
|
+
@property
|
50
|
+
def description(self) -> jaxsim.parsers.descriptions.ModelDescription:
|
51
|
+
return self._description.get()
|
48
52
|
|
49
53
|
def __eq__(self, other: JaxSimModel) -> bool:
|
50
54
|
|
51
55
|
if not isinstance(other, JaxSimModel):
|
52
56
|
return False
|
53
57
|
|
54
|
-
|
58
|
+
if self.model_name != other.model_name:
|
59
|
+
return False
|
60
|
+
|
61
|
+
if self.kin_dyn_parameters != other.kin_dyn_parameters:
|
62
|
+
return False
|
63
|
+
|
64
|
+
return True
|
55
65
|
|
56
66
|
def __hash__(self) -> int:
|
57
67
|
|
58
68
|
return hash(
|
59
69
|
(
|
60
70
|
hash(self.model_name),
|
61
|
-
hash(self.description),
|
62
71
|
hash(self.kin_dyn_parameters),
|
63
72
|
)
|
64
73
|
)
|
@@ -152,10 +161,10 @@ class JaxSimModel(JaxsimDataclass):
|
|
152
161
|
# Set the model name (if not provided, use the one from the model description)
|
153
162
|
model_name = model_name if model_name is not None else model_description.name
|
154
163
|
|
155
|
-
# Build the model
|
164
|
+
# Build the model.
|
156
165
|
model = JaxSimModel(
|
157
166
|
model_name=model_name,
|
158
|
-
|
167
|
+
_description=wrappers.HashlessObject(obj=model_description),
|
159
168
|
kin_dyn_parameters=js.kin_dyn_parameters.KynDynParameters.build(
|
160
169
|
model_description=model_description
|
161
170
|
),
|
@@ -270,6 +279,10 @@ class JaxSimModel(JaxsimDataclass):
|
|
270
279
|
|
271
280
|
return self.kin_dyn_parameters.link_names
|
272
281
|
|
282
|
+
# =====================
|
283
|
+
# Frame-related methods
|
284
|
+
# =====================
|
285
|
+
|
273
286
|
def frame_names(self) -> tuple[str, ...]:
|
274
287
|
"""
|
275
288
|
Return the names of the links in the model.
|
@@ -278,7 +291,7 @@ class JaxSimModel(JaxsimDataclass):
|
|
278
291
|
The names of the links in the model.
|
279
292
|
"""
|
280
293
|
|
281
|
-
return
|
294
|
+
return self.kin_dyn_parameters.frame_parameters.name
|
282
295
|
|
283
296
|
|
284
297
|
# =====================
|
jaxsim/api/ode_data.py
CHANGED
@@ -283,12 +283,16 @@ class PhysicsModelState(JaxsimDataclass):
|
|
283
283
|
|
284
284
|
def __hash__(self) -> int:
|
285
285
|
|
286
|
+
from jaxsim.utils.wrappers import HashedNumpyArray
|
287
|
+
|
286
288
|
return hash(
|
287
289
|
(
|
288
|
-
|
289
|
-
|
290
|
-
|
291
|
-
|
290
|
+
HashedNumpyArray.hash_of_array(self.joint_positions),
|
291
|
+
HashedNumpyArray.hash_of_array(self.joint_velocities),
|
292
|
+
HashedNumpyArray.hash_of_array(self.base_position),
|
293
|
+
HashedNumpyArray.hash_of_array(self.base_quaternion),
|
294
|
+
HashedNumpyArray.hash_of_array(self.base_linear_velocity),
|
295
|
+
HashedNumpyArray.hash_of_array(self.base_angular_velocity),
|
292
296
|
)
|
293
297
|
)
|
294
298
|
|
@@ -613,9 +617,9 @@ class SoftContactsState(JaxsimDataclass):
|
|
613
617
|
|
614
618
|
def __hash__(self) -> int:
|
615
619
|
|
616
|
-
|
617
|
-
|
618
|
-
)
|
620
|
+
from jaxsim.utils.wrappers import HashedNumpyArray
|
621
|
+
|
622
|
+
return HashedNumpyArray.hash_of_array(self.tangential_deformation)
|
619
623
|
|
620
624
|
def __eq__(self, other: SoftContactsState) -> bool:
|
621
625
|
|
@@ -41,7 +41,7 @@ class JointGenericAxis:
|
|
41
41
|
return hash(self) == hash(other)
|
42
42
|
|
43
43
|
|
44
|
-
@jax_dataclasses.pytree_dataclass
|
44
|
+
@jax_dataclasses.pytree_dataclass(eq=False, unsafe_hash=False)
|
45
45
|
class JointDescription(JaxsimDataclass):
|
46
46
|
"""
|
47
47
|
In-memory description of a robot link.
|
@@ -95,25 +95,59 @@ class JointDescription(JaxsimDataclass):
|
|
95
95
|
norm_of_axis = np.linalg.norm(self.axis)
|
96
96
|
self.axis = self.axis / norm_of_axis
|
97
97
|
|
98
|
+
def __eq__(self, other: JointDescription) -> bool:
|
99
|
+
|
100
|
+
if not isinstance(other, JointDescription):
|
101
|
+
return False
|
102
|
+
|
103
|
+
if not (
|
104
|
+
self.name == other.name
|
105
|
+
and self.jtype == other.jtype
|
106
|
+
and self.child == other.child
|
107
|
+
and self.parent == other.parent
|
108
|
+
and self.index == other.index
|
109
|
+
and all(
|
110
|
+
np.allclose(getattr(self, attr), getattr(other, attr))
|
111
|
+
for attr in [
|
112
|
+
"axis",
|
113
|
+
"pose",
|
114
|
+
"friction_static",
|
115
|
+
"friction_viscous",
|
116
|
+
"position_limit_damper",
|
117
|
+
"position_limit_spring",
|
118
|
+
"position_limit",
|
119
|
+
"initial_position",
|
120
|
+
"motor_inertia",
|
121
|
+
"motor_viscous_friction",
|
122
|
+
"motor_gear_ratio",
|
123
|
+
]
|
124
|
+
),
|
125
|
+
):
|
126
|
+
return False
|
127
|
+
|
128
|
+
return True
|
129
|
+
|
98
130
|
def __hash__(self) -> int:
|
99
131
|
|
132
|
+
from jaxsim.utils.wrappers import HashedNumpyArray
|
133
|
+
|
100
134
|
return hash(
|
101
135
|
(
|
102
136
|
hash(self.name),
|
103
|
-
|
104
|
-
|
137
|
+
HashedNumpyArray.hash_of_array(self.axis),
|
138
|
+
HashedNumpyArray.hash_of_array(self.pose),
|
105
139
|
hash(int(self.jtype)),
|
106
140
|
hash(self.child),
|
107
141
|
hash(self.parent),
|
108
142
|
hash(int(self.index)) if self.index is not None else 0,
|
109
|
-
|
110
|
-
|
111
|
-
|
112
|
-
|
113
|
-
|
114
|
-
|
115
|
-
|
116
|
-
|
117
|
-
|
143
|
+
HashedNumpyArray.hash_of_array(self.friction_static),
|
144
|
+
HashedNumpyArray.hash_of_array(self.friction_viscous),
|
145
|
+
HashedNumpyArray.hash_of_array(self.position_limit_damper),
|
146
|
+
HashedNumpyArray.hash_of_array(self.position_limit_spring),
|
147
|
+
HashedNumpyArray.hash_of_array(self.position_limit),
|
148
|
+
HashedNumpyArray.hash_of_array(self.initial_position),
|
149
|
+
HashedNumpyArray.hash_of_array(self.motor_inertia),
|
150
|
+
HashedNumpyArray.hash_of_array(self.motor_viscous_friction),
|
151
|
+
HashedNumpyArray.hash_of_array(self.motor_gear_ratio),
|
118
152
|
),
|
119
153
|
)
|
@@ -12,7 +12,7 @@ import jaxsim.typing as jtp
|
|
12
12
|
from jaxsim.utils import JaxsimDataclass
|
13
13
|
|
14
14
|
|
15
|
-
@jax_dataclasses.pytree_dataclass
|
15
|
+
@jax_dataclasses.pytree_dataclass(eq=False, unsafe_hash=False)
|
16
16
|
class LinkDescription(JaxsimDataclass):
|
17
17
|
"""
|
18
18
|
In-memory description of a robot link.
|
@@ -31,7 +31,7 @@ class LinkDescription(JaxsimDataclass):
|
|
31
31
|
mass: float = dataclasses.field(repr=False)
|
32
32
|
inertia: jtp.Matrix = dataclasses.field(repr=False)
|
33
33
|
index: int | None = None
|
34
|
-
parent: LinkDescription = dataclasses.field(default=None, repr=False)
|
34
|
+
parent: LinkDescription | None = dataclasses.field(default=None, repr=False)
|
35
35
|
pose: jtp.Matrix = dataclasses.field(default_factory=lambda: jnp.eye(4), repr=False)
|
36
36
|
|
37
37
|
children: Static[tuple[LinkDescription]] = dataclasses.field(
|
@@ -40,13 +40,15 @@ class LinkDescription(JaxsimDataclass):
|
|
40
40
|
|
41
41
|
def __hash__(self) -> int:
|
42
42
|
|
43
|
+
from jaxsim.utils.wrappers import HashedNumpyArray
|
44
|
+
|
43
45
|
return hash(
|
44
46
|
(
|
45
47
|
hash(self.name),
|
46
48
|
hash(float(self.mass)),
|
47
|
-
|
49
|
+
HashedNumpyArray.hash_of_array(self.inertia),
|
48
50
|
hash(int(self.index)) if self.index is not None else 0,
|
49
|
-
|
51
|
+
HashedNumpyArray.hash_of_array(self.pose),
|
50
52
|
hash(tuple(self.children)),
|
51
53
|
# Here only using the name to prevent circular recursion:
|
52
54
|
hash(self.parent.name) if self.parent is not None else 0,
|
@@ -58,7 +60,22 @@ class LinkDescription(JaxsimDataclass):
|
|
58
60
|
if not isinstance(other, LinkDescription):
|
59
61
|
return False
|
60
62
|
|
61
|
-
|
63
|
+
if not (
|
64
|
+
self.name == other.name
|
65
|
+
and np.allclose(self.mass, other.mass)
|
66
|
+
and np.allclose(self.inertia, other.inertia)
|
67
|
+
and self.index == other.index
|
68
|
+
and np.allclose(self.pose, other.pose)
|
69
|
+
and self.children == other.children
|
70
|
+
and (
|
71
|
+
(self.parent is not None and self.parent.name == other.parent.name)
|
72
|
+
if self.parent is not None
|
73
|
+
else other.parent is None
|
74
|
+
),
|
75
|
+
):
|
76
|
+
return False
|
77
|
+
|
78
|
+
return True
|
62
79
|
|
63
80
|
@property
|
64
81
|
def name_and_index(self) -> str:
|
@@ -12,7 +12,7 @@ from .joint import JointDescription
|
|
12
12
|
from .link import LinkDescription
|
13
13
|
|
14
14
|
|
15
|
-
@dataclasses.dataclass(frozen=True)
|
15
|
+
@dataclasses.dataclass(frozen=True, eq=False, unsafe_hash=False)
|
16
16
|
class ModelDescription(KinematicGraph):
|
17
17
|
"""
|
18
18
|
Intermediate representation representing the kinematic graph of a robot model.
|
@@ -28,7 +28,7 @@ class ModelDescription(KinematicGraph):
|
|
28
28
|
fixed_base: bool = True
|
29
29
|
|
30
30
|
collision_shapes: tuple[CollisionShape, ...] = dataclasses.field(
|
31
|
-
default_factory=list, repr=False
|
31
|
+
default_factory=list, repr=False
|
32
32
|
)
|
33
33
|
|
34
34
|
@staticmethod
|
@@ -249,7 +249,17 @@ class ModelDescription(KinematicGraph):
|
|
249
249
|
if not isinstance(other, ModelDescription):
|
250
250
|
return False
|
251
251
|
|
252
|
-
|
252
|
+
if not (
|
253
|
+
self.name == other.name
|
254
|
+
and self.fixed_base == other.fixed_base
|
255
|
+
and self.root == other.root
|
256
|
+
and self.joints == other.joints
|
257
|
+
and self.frames == other.frames
|
258
|
+
and self.root_pose == other.root_pose
|
259
|
+
):
|
260
|
+
return False
|
261
|
+
|
262
|
+
return True
|
253
263
|
|
254
264
|
def __hash__(self) -> int:
|
255
265
|
|
@@ -3,7 +3,7 @@ from __future__ import annotations
|
|
3
3
|
import copy
|
4
4
|
import dataclasses
|
5
5
|
import functools
|
6
|
-
from typing import Any, Callable, Iterable,
|
6
|
+
from typing import Any, Callable, Iterable, Sequence
|
7
7
|
|
8
8
|
import numpy as np
|
9
9
|
import numpy.typing as npt
|
@@ -15,7 +15,8 @@ from jaxsim.utils import Mutability
|
|
15
15
|
from . import descriptions
|
16
16
|
|
17
17
|
|
18
|
-
|
18
|
+
@dataclasses.dataclass
|
19
|
+
class RootPose:
|
19
20
|
"""
|
20
21
|
Represents the root pose in a kinematic graph.
|
21
22
|
|
@@ -28,15 +29,20 @@ class RootPose(NamedTuple):
|
|
28
29
|
The root link of the kinematic graph is the base link.
|
29
30
|
"""
|
30
31
|
|
31
|
-
root_position: npt.NDArray = np.zeros(3)
|
32
|
-
|
32
|
+
root_position: npt.NDArray = dataclasses.field(default_factory=lambda: np.zeros(3))
|
33
|
+
|
34
|
+
root_quaternion: npt.NDArray = dataclasses.field(
|
35
|
+
default_factory=lambda: np.array([1.0, 0, 0, 0])
|
36
|
+
)
|
33
37
|
|
34
38
|
def __hash__(self) -> int:
|
35
39
|
|
40
|
+
from jaxsim.utils.wrappers import HashedNumpyArray
|
41
|
+
|
36
42
|
return hash(
|
37
43
|
(
|
38
|
-
|
39
|
-
|
44
|
+
HashedNumpyArray.hash_of_array(self.root_position),
|
45
|
+
HashedNumpyArray.hash_of_array(self.root_quaternion),
|
40
46
|
)
|
41
47
|
)
|
42
48
|
|
@@ -45,7 +51,13 @@ class RootPose(NamedTuple):
|
|
45
51
|
if not isinstance(other, RootPose):
|
46
52
|
return False
|
47
53
|
|
48
|
-
|
54
|
+
if not np.allclose(self.root_position, other.root_position):
|
55
|
+
return False
|
56
|
+
|
57
|
+
if not np.allclose(self.root_quaternion, other.root_quaternion):
|
58
|
+
return False
|
59
|
+
|
60
|
+
return True
|
49
61
|
|
50
62
|
|
51
63
|
@dataclasses.dataclass(frozen=True)
|
jaxsim/rbda/soft_contacts.py
CHANGED
@@ -31,11 +31,13 @@ class SoftContactsParams(JaxsimDataclass):
|
|
31
31
|
|
32
32
|
def __hash__(self) -> int:
|
33
33
|
|
34
|
+
from jaxsim.utils.wrappers import HashedNumpyArray
|
35
|
+
|
34
36
|
return hash(
|
35
37
|
(
|
36
|
-
|
37
|
-
|
38
|
-
|
38
|
+
HashedNumpyArray.hash_of_array(self.K),
|
39
|
+
HashedNumpyArray.hash_of_array(self.D),
|
40
|
+
HashedNumpyArray.hash_of_array(self.mu),
|
39
41
|
)
|
40
42
|
)
|
41
43
|
|
jaxsim/utils/wrappers.py
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
from __future__ import annotations
|
2
2
|
|
3
3
|
import dataclasses
|
4
|
-
from typing import Generic, TypeVar
|
4
|
+
from typing import Callable, Generic, TypeVar
|
5
5
|
|
6
6
|
import jax
|
7
7
|
import jax_dataclasses
|
@@ -40,6 +40,33 @@ class HashlessObject(Generic[T]):
|
|
40
40
|
return hash(self) == hash(other)
|
41
41
|
|
42
42
|
|
43
|
+
@dataclasses.dataclass
|
44
|
+
class CustomHashedObject(Generic[T]):
|
45
|
+
"""
|
46
|
+
A class that wraps an object and computes its hash with a custom hash function.
|
47
|
+
"""
|
48
|
+
|
49
|
+
obj: T
|
50
|
+
|
51
|
+
hash_function: Callable[[T], int] = dataclasses.field(default=lambda obj: hash(obj))
|
52
|
+
|
53
|
+
def get(self: CustomHashedObject[T]) -> T:
|
54
|
+
return self.obj
|
55
|
+
|
56
|
+
def __hash__(self) -> int:
|
57
|
+
|
58
|
+
return self.hash_function(self.obj)
|
59
|
+
|
60
|
+
def __eq__(self, other: CustomHashedObject[T]) -> bool:
|
61
|
+
|
62
|
+
if not isinstance(other, CustomHashedObject) and isinstance(
|
63
|
+
other.get(), type(self.get())
|
64
|
+
):
|
65
|
+
return False
|
66
|
+
|
67
|
+
return hash(self) == hash(other)
|
68
|
+
|
69
|
+
|
43
70
|
@jax_dataclasses.pytree_dataclass
|
44
71
|
class HashedNumpyArray:
|
45
72
|
"""
|
@@ -56,6 +83,10 @@ class HashedNumpyArray:
|
|
56
83
|
|
57
84
|
array: jax.Array | npt.NDArray
|
58
85
|
|
86
|
+
precision: float | None = dataclasses.field(
|
87
|
+
default=1e-9, repr=False, compare=False, hash=False
|
88
|
+
)
|
89
|
+
|
59
90
|
large_array: jax_dataclasses.Static[bool] = dataclasses.field(
|
60
91
|
default=False, repr=False, compare=False, hash=False
|
61
92
|
)
|
@@ -65,7 +96,9 @@ class HashedNumpyArray:
|
|
65
96
|
|
66
97
|
def __hash__(self) -> int:
|
67
98
|
|
68
|
-
return
|
99
|
+
return HashedNumpyArray.hash_of_array(
|
100
|
+
array=self.array, precision=self.precision
|
101
|
+
)
|
69
102
|
|
70
103
|
def __eq__(self, other: HashedNumpyArray) -> bool:
|
71
104
|
|
@@ -76,3 +109,37 @@ class HashedNumpyArray:
|
|
76
109
|
return np.array_equal(self.array, other.array)
|
77
110
|
|
78
111
|
return hash(self) == hash(other)
|
112
|
+
|
113
|
+
@staticmethod
|
114
|
+
def hash_of_array(
|
115
|
+
array: jax.Array | npt.NDArray, precision: float | None = 1e-9
|
116
|
+
) -> int:
|
117
|
+
"""
|
118
|
+
Calculate the hash of a NumPy array.
|
119
|
+
|
120
|
+
Args:
|
121
|
+
array: The array to hash.
|
122
|
+
precision: Optionally limit the precision over which the hash is computed.
|
123
|
+
|
124
|
+
Returns:
|
125
|
+
The hash of the array.
|
126
|
+
"""
|
127
|
+
|
128
|
+
array = np.array(array).flatten()
|
129
|
+
|
130
|
+
array = np.where(array == np.nan, hash(np.nan), array)
|
131
|
+
array = np.where(array == np.inf, hash(np.inf), array)
|
132
|
+
array = np.where(array == -np.inf, hash(-np.inf), array)
|
133
|
+
|
134
|
+
if precision is not None:
|
135
|
+
|
136
|
+
integer1 = (array * precision).astype(int)
|
137
|
+
integer2 = (array - integer1 / precision).astype(int)
|
138
|
+
|
139
|
+
decimal_array = ((array - integer1 * 1e9 - integer2) / precision).astype(
|
140
|
+
int
|
141
|
+
)
|
142
|
+
|
143
|
+
array = np.hstack([integer1, integer2, decimal_array]).astype(int)
|
144
|
+
|
145
|
+
return hash(tuple(array.tolist()))
|
@@ -1,19 +1,19 @@
|
|
1
1
|
jaxsim/__init__.py,sha256=xzuTuZrgKdWLqqDzbvqzm2cJrEtAbepOeUqDu7ByVek,2621
|
2
|
-
jaxsim/_version.py,sha256=
|
2
|
+
jaxsim/_version.py,sha256=EQQfkY5WXMHFjdRnYAQqABGWC0VK4dlpuNh_wr1KxYA,426
|
3
3
|
jaxsim/logging.py,sha256=c4zhwBKf9eAYAHVp62kTEllqdsZgh0K-kPKVy8L3elU,1584
|
4
4
|
jaxsim/typing.py,sha256=cl7HHQCeP3mHmtF6EuQZcCjGvDmc_AryMWntP_lRBGg,722
|
5
5
|
jaxsim/api/__init__.py,sha256=8eV22t2S3UwNyCg8karPetG1dmX1VDBXkyv28_FwNQA,210
|
6
6
|
jaxsim/api/com.py,sha256=Yof6otFi-mLWAs1rqjmeNJTOWIH9gn7BdU5EIjiL6Ts,13481
|
7
7
|
jaxsim/api/common.py,sha256=bqQ__pIQZbh-j8rkoHUkYHAgGiJnDzjHG-q4Ny0OOYQ,6646
|
8
8
|
jaxsim/api/contact.py,sha256=79kcdq7C1_kWgxd1QWBabBhIPkwWEVLk-Fiz9kh-4so,12800
|
9
|
-
jaxsim/api/data.py,sha256=
|
10
|
-
jaxsim/api/frame.py,sha256=
|
9
|
+
jaxsim/api/data.py,sha256=fkVDBV1tODRYIaRb2N15l34InAcnzNygMGG1KFiIU2w,27307
|
10
|
+
jaxsim/api/frame.py,sha256=vSbFHL4WtKPySxunNoZLlM_aDuJXZtf8CSBKku63BAs,6178
|
11
11
|
jaxsim/api/joint.py,sha256=-5DogPg4g4mmLckyVIVNjwv-Rxz0IWS7_md9nDlhPWA,4581
|
12
|
-
jaxsim/api/kin_dyn_parameters.py,sha256=
|
12
|
+
jaxsim/api/kin_dyn_parameters.py,sha256=AEpDg9kihbKUN9PA8pNrAruSuWFUC-k_GGxtlcdcDiQ,29215
|
13
13
|
jaxsim/api/link.py,sha256=MdMWaMpM5Dj5JHK8uwHZ4zR4Fjq3R4asi2sGTxk1OAs,16647
|
14
|
-
jaxsim/api/model.py,sha256=
|
14
|
+
jaxsim/api/model.py,sha256=iuNYsn4xIfX36smmZpwM2O5eftT7ioDQtb6mSUqWu6Q,59759
|
15
15
|
jaxsim/api/ode.py,sha256=luTQJsIXUtCp_81dR42X7WrMvwrXtYbyJiqss29v7zA,10786
|
16
|
-
jaxsim/api/ode_data.py,sha256=
|
16
|
+
jaxsim/api/ode_data.py,sha256=FxUIV5qDNOg_OiOXWs3UrhDgKhGmTKcbHqgr4NX5bv0,23290
|
17
17
|
jaxsim/api/references.py,sha256=UA6kSQVBoq-bXSo99EOELf-_MD5MTy2zS0GtG3wQ410,16618
|
18
18
|
jaxsim/integrators/__init__.py,sha256=hxvOD-VK_mmd6v31wtC-nb28AYve1gLuZCNLV9wS-Kg,103
|
19
19
|
jaxsim/integrators/common.py,sha256=9HXRVFo95Mpt6RcVhBrOfvOO7mDxqbkXeg_lKUibEFY,20693
|
@@ -34,12 +34,12 @@ jaxsim/mujoco/loaders.py,sha256=7rjpeJ6_GuitlCty-ZkLhTILQ0GmsFzDMgve-7Gkkh4,2098
|
|
34
34
|
jaxsim/mujoco/model.py,sha256=1KVRjSLOTCuHt53apBPQTnFYJRknlVoKLQaxWsNK8qc,13494
|
35
35
|
jaxsim/mujoco/visualizer.py,sha256=PXgQzwetS9mRJYHBknDMLsQ9152FdrSvZuT9xE_dfIQ,5069
|
36
36
|
jaxsim/parsers/__init__.py,sha256=sonYi-bBWAoB04kp1mxT4uIORxjb7SdZ0ukGPmVx98Y,44
|
37
|
-
jaxsim/parsers/kinematic_graph.py,sha256=
|
37
|
+
jaxsim/parsers/kinematic_graph.py,sha256=1d0JAc3LrGTymaqO9exRsb33-o0Vtgc3cUvNP1YI-0Q,35083
|
38
38
|
jaxsim/parsers/descriptions/__init__.py,sha256=PbIlunVfb59pB5jSX97YVpMAANRZPRkJ0X-hS14rzv4,221
|
39
39
|
jaxsim/parsers/descriptions/collision.py,sha256=BQeIG-TKi4SVny23w6riDrQ5itC6VRwEMBX6HgAXHxA,3973
|
40
|
-
jaxsim/parsers/descriptions/joint.py,sha256=
|
41
|
-
jaxsim/parsers/descriptions/link.py,sha256=
|
42
|
-
jaxsim/parsers/descriptions/model.py,sha256=
|
40
|
+
jaxsim/parsers/descriptions/joint.py,sha256=7qUabpldRKwpGYQLCtQyMKiY47hB78J80DIuzI6bGLc,5186
|
41
|
+
jaxsim/parsers/descriptions/link.py,sha256=s0NXGOqmDknX0DYof31TGjVLLUHC9kSwzlGYLcCc03A,3710
|
42
|
+
jaxsim/parsers/descriptions/model.py,sha256=vfubtW68CUdgcbCHPcgKy0_BxzKQhhM8ycbCE-dF7Vk,9827
|
43
43
|
jaxsim/parsers/rod/__init__.py,sha256=G2vqlLajBLUc4gyzXwsEI2Wsi4TMOIF9bLDFeT6KrGU,92
|
44
44
|
jaxsim/parsers/rod/parser.py,sha256=4COuhkAYv4-GIpCqvkXEJWpDEQczEkBM3KwpqX48Rek,13514
|
45
45
|
jaxsim/parsers/rod/utils.py,sha256=KSjgy6WsmTrD5HZEA2x8hOBSRU4bUGOOHzxKkeFO5r8,5721
|
@@ -50,16 +50,16 @@ jaxsim/rbda/crba.py,sha256=awsWEQXLE0UPEXIcZCVsAqBEPjyahMNzY9ux6nE1l-s,4739
|
|
50
50
|
jaxsim/rbda/forward_kinematics.py,sha256=94W7TUXvZjMb-99CyYR8pObuxIYYX9B_dtRZqsNcThs,3418
|
51
51
|
jaxsim/rbda/jacobian.py,sha256=M79bGir-2w_iJ2GurYhOGgMfJnp7ZMOCW6AeeWKK8iM,10745
|
52
52
|
jaxsim/rbda/rnea.py,sha256=DjwkvXQVUSUclM3Uy3UPZ2tao91R5dGd4o7TsS2qObI,7650
|
53
|
-
jaxsim/rbda/soft_contacts.py,sha256=
|
53
|
+
jaxsim/rbda/soft_contacts.py,sha256=0hx9JT4R1X2PPjhZ1EDizBR1gGoCFCtKYu86SeuIvvA,11269
|
54
54
|
jaxsim/rbda/utils.py,sha256=zpbFM2Iq8cntku0BFVu9nfEqZhInCWi9D2INT6MFEI8,5003
|
55
55
|
jaxsim/terrain/__init__.py,sha256=f7lVX-iNpH_wkkjef9Qpjh19TTAUOQw76EiLYJDVizc,78
|
56
56
|
jaxsim/terrain/terrain.py,sha256=UXQCt7TCkq6GkM8bOZu44pNTpf-FZWiKN6VE4kb4kFk,2342
|
57
57
|
jaxsim/utils/__init__.py,sha256=Y5zyoRevl3EMVQadhZ4EtSwTEkDt2vcnFoRhPJjKTZ0,215
|
58
58
|
jaxsim/utils/jaxsim_dataclass.py,sha256=h26timZ_XrBL_Q_oymv-DkQd-EcUiHn8QexAaZXBY9c,11396
|
59
59
|
jaxsim/utils/tracing.py,sha256=KDMoyVPlu2NJvFkhtZwq5AkqMMgajt3munvJom-vEjQ,650
|
60
|
-
jaxsim/utils/wrappers.py,sha256=
|
61
|
-
jaxsim-0.3.1.
|
62
|
-
jaxsim-0.3.1.
|
63
|
-
jaxsim-0.3.1.
|
64
|
-
jaxsim-0.3.1.
|
65
|
-
jaxsim-0.3.1.
|
60
|
+
jaxsim/utils/wrappers.py,sha256=QIJitSoljrKR_U4T3ewCJPT3DTh-tPZsRsg0t_MH93E,3896
|
61
|
+
jaxsim-0.3.1.dev17.dist-info/LICENSE,sha256=eaYdFmdeMbiIoIiPzEK0MjP1S9wtFXjXNR5er49uLR0,1546
|
62
|
+
jaxsim-0.3.1.dev17.dist-info/METADATA,sha256=zRsMl96hDJt919NgrEuxkhye1S8X20bi_nWdPzJiptU,9739
|
63
|
+
jaxsim-0.3.1.dev17.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
|
64
|
+
jaxsim-0.3.1.dev17.dist-info/top_level.txt,sha256=LxGMA8FLtXjQ6oI7N5gd_R_oSUHxpXxUEOfT1xS_ni0,7
|
65
|
+
jaxsim-0.3.1.dev17.dist-info/RECORD,,
|
File without changes
|
File without changes
|
File without changes
|