jaxsim 0.3.1.dev4__py3-none-any.whl → 0.3.1.dev21__py3-none-any.whl

This diff shows the changes between two publicly available versions of this package, as released to one of the supported registries. It is provided for informational purposes only and reflects each version exactly as it appears in its respective public registry.
jaxsim/_version.py CHANGED
@@ -12,5 +12,5 @@ __version__: str
12
12
  __version_tuple__: VERSION_TUPLE
13
13
  version_tuple: VERSION_TUPLE
14
14
 
15
- __version__ = version = '0.3.1.dev4'
16
- __version_tuple__ = version_tuple = (0, 3, 1, 'dev4')
15
+ __version__ = version = '0.3.1.dev21'
16
+ __version_tuple__ = version_tuple = (0, 3, 1, 'dev21')
jaxsim/api/data.py CHANGED
@@ -45,12 +45,14 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
45
45
 
46
46
  def __hash__(self) -> int:
47
47
 
48
+ from jaxsim.utils.wrappers import HashedNumpyArray
49
+
48
50
  return hash(
49
51
  (
50
52
  hash(self.state),
51
- hash(tuple(self.gravity.flatten().tolist())),
53
+ HashedNumpyArray.hash_of_array(self.gravity),
52
54
  hash(self.soft_contacts_params),
53
- hash(jnp.atleast_1d(self.time_ns).flatten().tolist()),
55
+ HashedNumpyArray.hash_of_array(self.time_ns),
54
56
  )
55
57
  )
56
58
 
jaxsim/api/frame.py CHANGED
@@ -17,7 +17,9 @@ from .common import VelRepr
17
17
  # =======================
18
18
 
19
19
 
20
- def idx_of_parent_link(model: js.model.JaxSimModel, *, frame_idx: jtp.IntLike) -> int:
20
+ def idx_of_parent_link(
21
+ model: js.model.JaxSimModel, *, frame_idx: jtp.IntLike
22
+ ) -> jtp.Int:
21
23
  """
22
24
  Get the index of the link to which the frame is rigidly attached.
23
25
 
@@ -29,17 +31,13 @@ def idx_of_parent_link(model: js.model.JaxSimModel, *, frame_idx: jtp.IntLike) -
29
31
  The index of the frame's parent link.
30
32
  """
31
33
 
32
- # Get the intermediate representation parsed from the model description.
33
- ir = model.description
34
+ return model.kin_dyn_parameters.frame_parameters.body[
35
+ frame_idx - model.number_of_links()
36
+ ]
34
37
 
35
- # Extract the indices of the frame and the link it is attached to.
36
- F = ir.frames[frame_idx - model.number_of_links()]
37
- L = ir.links_dict[F.parent.name].index
38
38
 
39
- return int(L)
40
-
41
-
42
- def name_to_idx(model: js.model.JaxSimModel, *, frame_name: str) -> int:
39
+ @functools.partial(jax.jit, static_argnames="frame_name")
40
+ def name_to_idx(model: js.model.JaxSimModel, *, frame_name: str) -> jtp.Int:
43
41
  """
44
42
  Convert the name of a frame to its index.
45
43
 
@@ -51,13 +49,19 @@ def name_to_idx(model: js.model.JaxSimModel, *, frame_name: str) -> int:
51
49
  The index of the frame.
52
50
  """
53
51
 
54
- frame_names = np.array([frame.name for frame in model.description.frames])
52
+ if frame_name in model.kin_dyn_parameters.frame_parameters.name:
53
+ return (
54
+ jnp.array(
55
+ np.argwhere(
56
+ np.array(model.kin_dyn_parameters.frame_parameters.name)
57
+ == frame_name
58
+ )
59
+ )
60
+ .squeeze()
61
+ .astype(int)
62
+ ) + model.number_of_links()
55
63
 
56
- if frame_name in frame_names:
57
- idx_in_list = np.argwhere(frame_names == frame_name)
58
- return int(idx_in_list.squeeze().tolist()) + model.number_of_links()
59
-
60
- return -1
64
+ return jnp.array(-1).astype(int)
61
65
 
62
66
 
63
67
  def idx_to_name(model: js.model.JaxSimModel, *, frame_index: jtp.IntLike) -> str:
@@ -72,7 +76,9 @@ def idx_to_name(model: js.model.JaxSimModel, *, frame_index: jtp.IntLike) -> str
72
76
  The name of the frame.
73
77
  """
74
78
 
75
- return model.description.frames[frame_index - model.number_of_links()].name
79
+ return model.kin_dyn_parameters.frame_parameters.name[
80
+ frame_index - model.number_of_links()
81
+ ]
76
82
 
77
83
 
78
84
  @functools.partial(jax.jit, static_argnames=["frame_names"])
@@ -91,7 +97,7 @@ def names_to_idxs(
91
97
  """
92
98
 
93
99
  return jnp.array(
94
- [name_to_idx(model=model, frame_name=frame_name) for frame_name in frame_names]
100
+ [name_to_idx(model=model, frame_name=name) for name in frame_names]
95
101
  ).astype(int)
96
102
 
97
103
 
@@ -109,10 +115,7 @@ def idxs_to_names(
109
115
  The names of the frames.
110
116
  """
111
117
 
112
- return tuple(
113
- idx_to_name(model=model, frame_index=frame_index)
114
- for frame_index in frame_indices
115
- )
118
+ return tuple(idx_to_name(model=model, frame_index=idx) for idx in frame_indices)
116
119
 
117
120
 
118
121
  # ==========
@@ -120,7 +123,7 @@ def idxs_to_names(
120
123
  # ==========
121
124
 
122
125
 
123
- @functools.partial(jax.jit, static_argnames=["frame_index"])
126
+ @jax.jit
124
127
  def transform(
125
128
  model: js.model.JaxSimModel,
126
129
  data: js.data.JaxSimModelData,
@@ -144,14 +147,15 @@ def transform(
144
147
  W_H_L = js.link.transform(model=model, data=data, link_index=L)
145
148
 
146
149
  # Get the static frame pose wrt the parent link.
147
- frame = model.description.frames[frame_index - model.number_of_links()]
148
- L_H_F = frame.pose
150
+ L_H_F = model.kin_dyn_parameters.frame_parameters.transform[
151
+ frame_index - model.number_of_links()
152
+ ]
149
153
 
150
154
  # Combine the transforms computing the frame pose.
151
155
  return W_H_L @ L_H_F
152
156
 
153
157
 
154
- @functools.partial(jax.jit, static_argnames=["frame_index", "output_vel_repr"])
158
+ @functools.partial(jax.jit, static_argnames=["output_vel_repr"])
155
159
  def jacobian(
156
160
  model: js.model.JaxSimModel,
157
161
  data: js.data.JaxSimModelData,
@@ -6,7 +6,6 @@ import jax.lax
6
6
  import jax.numpy as jnp
7
7
  import jax_dataclasses
8
8
  import jaxlie
9
- import numpy as np
10
9
  from jax_dataclasses import Static
11
10
 
12
11
  import jaxsim.typing as jtp
@@ -15,7 +14,7 @@ from jaxsim.parsers.descriptions import JointDescription, ModelDescription
15
14
  from jaxsim.utils import HashedNumpyArray, JaxsimDataclass
16
15
 
17
16
 
18
- @jax_dataclasses.pytree_dataclass
17
+ @jax_dataclasses.pytree_dataclass(eq=False, unsafe_hash=False)
19
18
  class KynDynParameters(JaxsimDataclass):
20
19
  r"""
21
20
  Class storing the kinematic and dynamic parameters of a model.
@@ -26,6 +25,7 @@ class KynDynParameters(JaxsimDataclass):
26
25
  support_body_array_bool:
27
26
  The boolean support parent array :math:`\kappa_{b}(i)` of the model.
28
27
  link_parameters: The parameters of the links.
28
+ frame_parameters: The parameters of the frames.
29
29
  contact_parameters: The parameters of the collidable points.
30
30
  joint_model: The joint model of the model.
31
31
  joint_parameters: The parameters of the joints.
@@ -42,6 +42,9 @@ class KynDynParameters(JaxsimDataclass):
42
42
  # Contacts
43
43
  contact_parameters: ContactParameters
44
44
 
45
+ # Frames
46
+ frame_parameters: FrameParameters
47
+
45
48
  # Joints
46
49
  joint_model: JointModel
47
50
  joint_parameters: JointParameters | None
@@ -141,6 +144,19 @@ class KynDynParameters(JaxsimDataclass):
141
144
  model_description=model_description
142
145
  )
143
146
 
147
+ # =================
148
+ # Frames properties
149
+ # =================
150
+
151
+ # Create the object storing the parameters of frames.
152
+ # Note that, contrarily to LinkParameters and JointsParameters, this object
153
+ # is not created with vmap. This is because the "name" attribute of the object
154
+ # must be Static for JIT-related reasons, and tree_map would not consider it
155
+ # as a leaf.
156
+ frame_parameters = FrameParameters.build_from(
157
+ model_description=model_description
158
+ )
159
+
144
160
  # ===============
145
161
  # Tree properties
146
162
  # ===============
@@ -206,6 +222,7 @@ class KynDynParameters(JaxsimDataclass):
206
222
  joint_model=joint_model,
207
223
  joint_parameters=joint_parameters,
208
224
  contact_parameters=contact_parameters,
225
+ frame_parameters=frame_parameters,
209
226
  )
210
227
 
211
228
  def __eq__(self, other: KynDynParameters) -> bool:
@@ -221,7 +238,8 @@ class KynDynParameters(JaxsimDataclass):
221
238
  (
222
239
  hash(self.number_of_links()),
223
240
  hash(self.number_of_joints()),
224
- hash(tuple(np.atleast_1d(self.parent_array).flatten().tolist())),
241
+ hash(self.frame_parameters.name),
242
+ hash(tuple(self.frame_parameters.body.tolist())),
225
243
  hash(self._parent_array),
226
244
  hash(self._support_body_array_bool),
227
245
  )
@@ -730,7 +748,7 @@ class ContactParameters(JaxsimDataclass):
730
748
  A tuple of integers representing, for each collidable point, the index of
731
749
  the body (link) to which it is rigidly attached to.
732
750
  point:
733
- The translation between the link frame and the collidable point, expressed
751
+ The translations between the link frame and the collidable point, expressed
734
752
  in the coordinates of the parent link frame.
735
753
 
736
754
  Note:
@@ -773,10 +791,75 @@ class ContactParameters(JaxsimDataclass):
773
791
  links_dict[cp.parent_link.name].index for cp in collidable_points
774
792
  )
775
793
 
776
- # Build the GroundContact object.
794
+ # Build the ContactParameters object.
777
795
  cp = ContactParameters(point=points, body=link_index_of_points) # noqa
778
796
 
779
- assert cp.point.shape[1] == 3
780
- assert cp.point.shape[0] == len(cp.body)
797
+ assert cp.point.shape[1] == 3, cp.point.shape[1]
798
+ assert cp.point.shape[0] == len(cp.body), cp.point.shape[0]
781
799
 
782
800
  return cp
801
+
802
+
803
+ @jax_dataclasses.pytree_dataclass
804
+ class FrameParameters(JaxsimDataclass):
805
+ """
806
+ Class storing the frame parameters of a model.
807
+
808
+ Attributes:
809
+ name: A tuple of strings defining the frame names.
810
+ body:
811
+ A vector of integers representing, for each frame, the index of
812
+ the body (link) to which it is rigidly attached to.
813
+ transform: The transforms of the frames w.r.t. their parent link.
814
+
815
+ Note:
816
+ Contrarily to LinkParameters and JointParameters, this class is not meant
817
+ to be created with vmap. This is because the `name` attribute must be `Static`.
818
+ """
819
+
820
+ name: Static[tuple[str, ...]] = dataclasses.field(default_factory=tuple)
821
+
822
+ body: jtp.Vector = dataclasses.field(default_factory=lambda: jnp.array([]))
823
+
824
+ transform: jtp.Array = dataclasses.field(default_factory=lambda: jnp.array([]))
825
+
826
+ @staticmethod
827
+ def build_from(model_description: ModelDescription) -> FrameParameters:
828
+ """
829
+ Build a FrameParameters object from a model description.
830
+
831
+ Args:
832
+ model_description: The model description to consider.
833
+
834
+ Returns:
835
+ The FrameParameters object.
836
+ """
837
+
838
+ if len(model_description.frames) == 0:
839
+ return FrameParameters()
840
+
841
+ # Extract the frame names.
842
+ names = tuple(frame.name for frame in model_description.frames)
843
+
844
+ # For each frame, extract the index of the link to which it is attached to.
845
+ parent_link_index_of_frames = tuple(
846
+ model_description.links_dict[frame.parent.name].index
847
+ for frame in model_description.frames
848
+ )
849
+
850
+ # For each frame, extract the transform w.r.t. its parent link.
851
+ transforms = jnp.atleast_3d(
852
+ jnp.stack([frame.pose for frame in model_description.frames])
853
+ )
854
+
855
+ # Build the FrameParameters object.
856
+ fp = FrameParameters(
857
+ name=names,
858
+ transform=transforms.astype(float),
859
+ body=jnp.array(parent_link_index_of_frames).astype(int),
860
+ )
861
+
862
+ assert fp.transform.shape[1:] == (4, 4), fp.transform.shape[1:]
863
+ assert fp.transform.shape[0] == len(fp.body), fp.transform.shape[0]
864
+
865
+ return fp
jaxsim/api/model.py CHANGED
@@ -17,12 +17,12 @@ import jaxsim.api as js
17
17
  import jaxsim.parsers.descriptions
18
18
  import jaxsim.typing as jtp
19
19
  from jaxsim.math import Cross
20
- from jaxsim.utils import JaxsimDataclass, Mutability
20
+ from jaxsim.utils import JaxsimDataclass, Mutability, wrappers
21
21
 
22
22
  from .common import VelRepr
23
23
 
24
24
 
25
- @jax_dataclasses.pytree_dataclass
25
+ @jax_dataclasses.pytree_dataclass(eq=False, unsafe_hash=False)
26
26
  class JaxSimModel(JaxsimDataclass):
27
27
  """
28
28
  The JaxSim model defining the kinematics and dynamics of a robot.
@@ -31,34 +31,43 @@ class JaxSimModel(JaxsimDataclass):
31
31
  model_name: Static[str]
32
32
 
33
33
  terrain: Static[jaxsim.terrain.Terrain] = dataclasses.field(
34
- default=jaxsim.terrain.FlatTerrain(), repr=False, compare=False, hash=False
34
+ default=jaxsim.terrain.FlatTerrain(), repr=False
35
35
  )
36
36
 
37
37
  kin_dyn_parameters: js.kin_dyn_parameters.KynDynParameters | None = (
38
- dataclasses.field(default=None, repr=False, compare=False, hash=False)
38
+ dataclasses.field(default=None, repr=False)
39
39
  )
40
40
 
41
41
  built_from: Static[str | pathlib.Path | rod.Model | None] = dataclasses.field(
42
- default=None, repr=False, compare=False, hash=False
42
+ default=None, repr=False
43
43
  )
44
44
 
45
- description: Static[jaxsim.parsers.descriptions.ModelDescription | None] = (
46
- dataclasses.field(default=None, repr=False, compare=False, hash=False)
47
- )
45
+ _description: Static[
46
+ wrappers.HashlessObject[jaxsim.parsers.descriptions.ModelDescription | None]
47
+ ] = dataclasses.field(default=None, repr=False)
48
+
49
+ @property
50
+ def description(self) -> jaxsim.parsers.descriptions.ModelDescription:
51
+ return self._description.get()
48
52
 
49
53
  def __eq__(self, other: JaxSimModel) -> bool:
50
54
 
51
55
  if not isinstance(other, JaxSimModel):
52
56
  return False
53
57
 
54
- return hash(self) == hash(other)
58
+ if self.model_name != other.model_name:
59
+ return False
60
+
61
+ if self.kin_dyn_parameters != other.kin_dyn_parameters:
62
+ return False
63
+
64
+ return True
55
65
 
56
66
  def __hash__(self) -> int:
57
67
 
58
68
  return hash(
59
69
  (
60
70
  hash(self.model_name),
61
- hash(self.description),
62
71
  hash(self.kin_dyn_parameters),
63
72
  )
64
73
  )
@@ -152,10 +161,10 @@ class JaxSimModel(JaxsimDataclass):
152
161
  # Set the model name (if not provided, use the one from the model description)
153
162
  model_name = model_name if model_name is not None else model_description.name
154
163
 
155
- # Build the model
164
+ # Build the model.
156
165
  model = JaxSimModel(
157
166
  model_name=model_name,
158
- description=model_description,
167
+ _description=wrappers.HashlessObject(obj=model_description),
159
168
  kin_dyn_parameters=js.kin_dyn_parameters.KynDynParameters.build(
160
169
  model_description=model_description
161
170
  ),
@@ -270,6 +279,10 @@ class JaxSimModel(JaxsimDataclass):
270
279
 
271
280
  return self.kin_dyn_parameters.link_names
272
281
 
282
+ # =====================
283
+ # Frame-related methods
284
+ # =====================
285
+
273
286
  def frame_names(self) -> tuple[str, ...]:
274
287
  """
275
288
  Return the names of the links in the model.
@@ -278,7 +291,7 @@ class JaxSimModel(JaxsimDataclass):
278
291
  The names of the links in the model.
279
292
  """
280
293
 
281
- return tuple(frame.name for frame in self.description.frames)
294
+ return self.kin_dyn_parameters.frame_parameters.name
282
295
 
283
296
 
284
297
  # =====================
jaxsim/api/ode_data.py CHANGED
@@ -283,12 +283,16 @@ class PhysicsModelState(JaxsimDataclass):
283
283
 
284
284
  def __hash__(self) -> int:
285
285
 
286
+ from jaxsim.utils.wrappers import HashedNumpyArray
287
+
286
288
  return hash(
287
289
  (
288
- hash(tuple(jnp.atleast_1d(self.joint_positions.flatten().tolist()))),
289
- hash(tuple(jnp.atleast_1d(self.joint_velocities.flatten().tolist()))),
290
- hash(tuple(self.base_position.flatten().tolist())),
291
- hash(tuple(self.base_quaternion.flatten().tolist())),
290
+ HashedNumpyArray.hash_of_array(self.joint_positions),
291
+ HashedNumpyArray.hash_of_array(self.joint_velocities),
292
+ HashedNumpyArray.hash_of_array(self.base_position),
293
+ HashedNumpyArray.hash_of_array(self.base_quaternion),
294
+ HashedNumpyArray.hash_of_array(self.base_linear_velocity),
295
+ HashedNumpyArray.hash_of_array(self.base_angular_velocity),
292
296
  )
293
297
  )
294
298
 
@@ -613,9 +617,9 @@ class SoftContactsState(JaxsimDataclass):
613
617
 
614
618
  def __hash__(self) -> int:
615
619
 
616
- return hash(
617
- tuple(jnp.atleast_1d(self.tangential_deformation.flatten()).tolist())
618
- )
620
+ from jaxsim.utils.wrappers import HashedNumpyArray
621
+
622
+ return HashedNumpyArray.hash_of_array(self.tangential_deformation)
619
623
 
620
624
  def __eq__(self, other: SoftContactsState) -> bool:
621
625
 
jaxsim/exceptions.py ADDED
@@ -0,0 +1,63 @@
1
+ import jax
2
+
3
+
4
+ def raise_if(
5
+ condition: bool | jax.Array, exception: type, msg: str, *args, **kwargs
6
+ ) -> None:
7
+ """
8
+ Raise a host-side exception if a condition is met. Useful in jit-compiled functions.
9
+
10
+ Args:
11
+ condition:
12
+ The boolean condition of the evaluated expression that triggers
13
+ the exception during runtime.
14
+ exception: The type of exception to raise.
15
+ msg:
16
+ The message to display when the exception is raised. The message can be a
17
+ format string (fmt), whose fields are filled with the args and kwargs.
18
+ """
19
+
20
+ # Check early that the format string is well-formed.
21
+ try:
22
+ _ = msg.format(*args, **kwargs)
23
+ except Exception as e:
24
+ msg = "Error in formatting exception message with args={} and kwargs={}"
25
+ raise ValueError(msg.format(args, kwargs)) from e
26
+
27
+ def _raise_exception(condition: bool, *args, **kwargs) -> None:
28
+ """The function called by the JAX callback."""
29
+
30
+ if condition:
31
+ raise exception(msg.format(*args, **kwargs))
32
+
33
+ def _callback(args, kwargs) -> None:
34
+ """The function that calls the JAX callback, executed only when needed."""
35
+
36
+ jax.debug.callback(_raise_exception, condition, *args, **kwargs)
37
+
38
+ # Since running a callable on the host is expensive, we prevent its execution
39
+ # if the condition is False with a low-level conditional expression.
40
+ def _run_callback_only_if_condition_is_true(*args, **kwargs) -> None:
41
+ return jax.lax.cond(
42
+ condition,
43
+ _callback,
44
+ lambda args, kwargs: None,
45
+ args,
46
+ kwargs,
47
+ )
48
+
49
+ return _run_callback_only_if_condition_is_true(*args, **kwargs)
50
+
51
+
52
+ def raise_runtime_error_if(
53
+ condition: bool | jax.Array, msg: str, *args, **kwargs
54
+ ) -> None:
55
+
56
+ return raise_if(condition, RuntimeError, msg, *args, **kwargs)
57
+
58
+
59
+ def raise_value_error_if(
60
+ condition: bool | jax.Array, msg: str, *args, **kwargs
61
+ ) -> None:
62
+
63
+ return raise_if(condition, ValueError, msg, *args, **kwargs)
@@ -41,7 +41,7 @@ class JointGenericAxis:
41
41
  return hash(self) == hash(other)
42
42
 
43
43
 
44
- @jax_dataclasses.pytree_dataclass
44
+ @jax_dataclasses.pytree_dataclass(eq=False, unsafe_hash=False)
45
45
  class JointDescription(JaxsimDataclass):
46
46
  """
47
47
  In-memory description of a robot link.
@@ -95,25 +95,59 @@ class JointDescription(JaxsimDataclass):
95
95
  norm_of_axis = np.linalg.norm(self.axis)
96
96
  self.axis = self.axis / norm_of_axis
97
97
 
98
+ def __eq__(self, other: JointDescription) -> bool:
99
+
100
+ if not isinstance(other, JointDescription):
101
+ return False
102
+
103
+ if not (
104
+ self.name == other.name
105
+ and self.jtype == other.jtype
106
+ and self.child == other.child
107
+ and self.parent == other.parent
108
+ and self.index == other.index
109
+ and all(
110
+ np.allclose(getattr(self, attr), getattr(other, attr))
111
+ for attr in [
112
+ "axis",
113
+ "pose",
114
+ "friction_static",
115
+ "friction_viscous",
116
+ "position_limit_damper",
117
+ "position_limit_spring",
118
+ "position_limit",
119
+ "initial_position",
120
+ "motor_inertia",
121
+ "motor_viscous_friction",
122
+ "motor_gear_ratio",
123
+ ]
124
+ ),
125
+ ):
126
+ return False
127
+
128
+ return True
129
+
98
130
  def __hash__(self) -> int:
99
131
 
132
+ from jaxsim.utils.wrappers import HashedNumpyArray
133
+
100
134
  return hash(
101
135
  (
102
136
  hash(self.name),
103
- hash(tuple(self.axis.tolist())),
104
- hash(tuple(self.pose.flatten().tolist())),
137
+ HashedNumpyArray.hash_of_array(self.axis),
138
+ HashedNumpyArray.hash_of_array(self.pose),
105
139
  hash(int(self.jtype)),
106
140
  hash(self.child),
107
141
  hash(self.parent),
108
142
  hash(int(self.index)) if self.index is not None else 0,
109
- hash(float(self.friction_static)),
110
- hash(float(self.friction_viscous)),
111
- hash(float(self.position_limit_damper)),
112
- hash(float(self.position_limit_spring)),
113
- hash((float(el) for el in self.position_limit)),
114
- hash(tuple(np.atleast_1d(self.initial_position).tolist())),
115
- hash(float(self.motor_inertia)),
116
- hash(float(self.motor_viscous_friction)),
117
- hash(float(self.motor_gear_ratio)),
143
+ HashedNumpyArray.hash_of_array(self.friction_static),
144
+ HashedNumpyArray.hash_of_array(self.friction_viscous),
145
+ HashedNumpyArray.hash_of_array(self.position_limit_damper),
146
+ HashedNumpyArray.hash_of_array(self.position_limit_spring),
147
+ HashedNumpyArray.hash_of_array(self.position_limit),
148
+ HashedNumpyArray.hash_of_array(self.initial_position),
149
+ HashedNumpyArray.hash_of_array(self.motor_inertia),
150
+ HashedNumpyArray.hash_of_array(self.motor_viscous_friction),
151
+ HashedNumpyArray.hash_of_array(self.motor_gear_ratio),
118
152
  ),
119
153
  )
@@ -12,7 +12,7 @@ import jaxsim.typing as jtp
12
12
  from jaxsim.utils import JaxsimDataclass
13
13
 
14
14
 
15
- @jax_dataclasses.pytree_dataclass
15
+ @jax_dataclasses.pytree_dataclass(eq=False, unsafe_hash=False)
16
16
  class LinkDescription(JaxsimDataclass):
17
17
  """
18
18
  In-memory description of a robot link.
@@ -31,7 +31,7 @@ class LinkDescription(JaxsimDataclass):
31
31
  mass: float = dataclasses.field(repr=False)
32
32
  inertia: jtp.Matrix = dataclasses.field(repr=False)
33
33
  index: int | None = None
34
- parent: LinkDescription = dataclasses.field(default=None, repr=False)
34
+ parent: LinkDescription | None = dataclasses.field(default=None, repr=False)
35
35
  pose: jtp.Matrix = dataclasses.field(default_factory=lambda: jnp.eye(4), repr=False)
36
36
 
37
37
  children: Static[tuple[LinkDescription]] = dataclasses.field(
@@ -40,13 +40,15 @@ class LinkDescription(JaxsimDataclass):
40
40
 
41
41
  def __hash__(self) -> int:
42
42
 
43
+ from jaxsim.utils.wrappers import HashedNumpyArray
44
+
43
45
  return hash(
44
46
  (
45
47
  hash(self.name),
46
48
  hash(float(self.mass)),
47
- hash(tuple(np.atleast_1d(self.inertia).flatten().tolist())),
49
+ HashedNumpyArray.hash_of_array(self.inertia),
48
50
  hash(int(self.index)) if self.index is not None else 0,
49
- hash(tuple(np.atleast_1d(self.pose).flatten().tolist())),
51
+ HashedNumpyArray.hash_of_array(self.pose),
50
52
  hash(tuple(self.children)),
51
53
  # Here only using the name to prevent circular recursion:
52
54
  hash(self.parent.name) if self.parent is not None else 0,
@@ -58,7 +60,22 @@ class LinkDescription(JaxsimDataclass):
58
60
  if not isinstance(other, LinkDescription):
59
61
  return False
60
62
 
61
- return hash(self) == hash(other)
63
+ if not (
64
+ self.name == other.name
65
+ and np.allclose(self.mass, other.mass)
66
+ and np.allclose(self.inertia, other.inertia)
67
+ and self.index == other.index
68
+ and np.allclose(self.pose, other.pose)
69
+ and self.children == other.children
70
+ and (
71
+ (self.parent is not None and self.parent.name == other.parent.name)
72
+ if self.parent is not None
73
+ else other.parent is None
74
+ ),
75
+ ):
76
+ return False
77
+
78
+ return True
62
79
 
63
80
  @property
64
81
  def name_and_index(self) -> str:
@@ -12,7 +12,7 @@ from .joint import JointDescription
12
12
  from .link import LinkDescription
13
13
 
14
14
 
15
- @dataclasses.dataclass(frozen=True)
15
+ @dataclasses.dataclass(frozen=True, eq=False, unsafe_hash=False)
16
16
  class ModelDescription(KinematicGraph):
17
17
  """
18
18
  Intermediate representation representing the kinematic graph of a robot model.
@@ -28,7 +28,7 @@ class ModelDescription(KinematicGraph):
28
28
  fixed_base: bool = True
29
29
 
30
30
  collision_shapes: tuple[CollisionShape, ...] = dataclasses.field(
31
- default_factory=list, repr=False, hash=False
31
+ default_factory=list, repr=False
32
32
  )
33
33
 
34
34
  @staticmethod
@@ -249,7 +249,17 @@ class ModelDescription(KinematicGraph):
249
249
  if not isinstance(other, ModelDescription):
250
250
  return False
251
251
 
252
- return hash(self) == hash(other)
252
+ if not (
253
+ self.name == other.name
254
+ and self.fixed_base == other.fixed_base
255
+ and self.root == other.root
256
+ and self.joints == other.joints
257
+ and self.frames == other.frames
258
+ and self.root_pose == other.root_pose
259
+ ):
260
+ return False
261
+
262
+ return True
253
263
 
254
264
  def __hash__(self) -> int:
255
265
 
@@ -3,7 +3,7 @@ from __future__ import annotations
3
3
  import copy
4
4
  import dataclasses
5
5
  import functools
6
- from typing import Any, Callable, Iterable, NamedTuple, Sequence
6
+ from typing import Any, Callable, Iterable, Sequence
7
7
 
8
8
  import numpy as np
9
9
  import numpy.typing as npt
@@ -15,7 +15,8 @@ from jaxsim.utils import Mutability
15
15
  from . import descriptions
16
16
 
17
17
 
18
- class RootPose(NamedTuple):
18
+ @dataclasses.dataclass
19
+ class RootPose:
19
20
  """
20
21
  Represents the root pose in a kinematic graph.
21
22
 
@@ -28,15 +29,20 @@ class RootPose(NamedTuple):
28
29
  The root link of the kinematic graph is the base link.
29
30
  """
30
31
 
31
- root_position: npt.NDArray = np.zeros(3)
32
- root_quaternion: npt.NDArray = np.array([1.0, 0, 0, 0])
32
+ root_position: npt.NDArray = dataclasses.field(default_factory=lambda: np.zeros(3))
33
+
34
+ root_quaternion: npt.NDArray = dataclasses.field(
35
+ default_factory=lambda: np.array([1.0, 0, 0, 0])
36
+ )
33
37
 
34
38
  def __hash__(self) -> int:
35
39
 
40
+ from jaxsim.utils.wrappers import HashedNumpyArray
41
+
36
42
  return hash(
37
43
  (
38
- hash(tuple(self.root_position.tolist())),
39
- hash(tuple(self.root_quaternion.tolist())),
44
+ HashedNumpyArray.hash_of_array(self.root_position),
45
+ HashedNumpyArray.hash_of_array(self.root_quaternion),
40
46
  )
41
47
  )
42
48
 
@@ -45,7 +51,13 @@ class RootPose(NamedTuple):
45
51
  if not isinstance(other, RootPose):
46
52
  return False
47
53
 
48
- return hash(self) == hash(other)
54
+ if not np.allclose(self.root_position, other.root_position):
55
+ return False
56
+
57
+ if not np.allclose(self.root_quaternion, other.root_quaternion):
58
+ return False
59
+
60
+ return True
49
61
 
50
62
 
51
63
  @dataclasses.dataclass(frozen=True)
@@ -31,11 +31,13 @@ class SoftContactsParams(JaxsimDataclass):
31
31
 
32
32
  def __hash__(self) -> int:
33
33
 
34
+ from jaxsim.utils.wrappers import HashedNumpyArray
35
+
34
36
  return hash(
35
37
  (
36
- hash(tuple(jnp.atleast_1d(self.K).flatten().tolist())),
37
- hash(tuple(jnp.atleast_1d(self.D).flatten().tolist())),
38
- hash(tuple(jnp.atleast_1d(self.mu).flatten().tolist())),
38
+ HashedNumpyArray.hash_of_array(self.K),
39
+ HashedNumpyArray.hash_of_array(self.D),
40
+ HashedNumpyArray.hash_of_array(self.mu),
39
41
  )
40
42
  )
41
43
 
jaxsim/utils/wrappers.py CHANGED
@@ -1,7 +1,7 @@
1
1
  from __future__ import annotations
2
2
 
3
3
  import dataclasses
4
- from typing import Generic, TypeVar
4
+ from typing import Callable, Generic, TypeVar
5
5
 
6
6
  import jax
7
7
  import jax_dataclasses
@@ -40,6 +40,33 @@ class HashlessObject(Generic[T]):
40
40
  return hash(self) == hash(other)
41
41
 
42
42
 
43
+ @dataclasses.dataclass
44
+ class CustomHashedObject(Generic[T]):
45
+ """
46
+ A class that wraps an object and computes its hash with a custom hash function.
47
+ """
48
+
49
+ obj: T
50
+
51
+ hash_function: Callable[[T], int] = dataclasses.field(default=lambda obj: hash(obj))
52
+
53
+ def get(self: CustomHashedObject[T]) -> T:
54
+ return self.obj
55
+
56
+ def __hash__(self) -> int:
57
+
58
+ return self.hash_function(self.obj)
59
+
60
+ def __eq__(self, other: CustomHashedObject[T]) -> bool:
61
+
62
+ if not isinstance(other, CustomHashedObject) and isinstance(
63
+ other.get(), type(self.get())
64
+ ):
65
+ return False
66
+
67
+ return hash(self) == hash(other)
68
+
69
+
43
70
  @jax_dataclasses.pytree_dataclass
44
71
  class HashedNumpyArray:
45
72
  """
@@ -56,6 +83,10 @@ class HashedNumpyArray:
56
83
 
57
84
  array: jax.Array | npt.NDArray
58
85
 
86
+ precision: float | None = dataclasses.field(
87
+ default=1e-9, repr=False, compare=False, hash=False
88
+ )
89
+
59
90
  large_array: jax_dataclasses.Static[bool] = dataclasses.field(
60
91
  default=False, repr=False, compare=False, hash=False
61
92
  )
@@ -65,7 +96,9 @@ class HashedNumpyArray:
65
96
 
66
97
  def __hash__(self) -> int:
67
98
 
68
- return hash(tuple(np.atleast_1d(self.array).flatten().tolist()))
99
+ return HashedNumpyArray.hash_of_array(
100
+ array=self.array, precision=self.precision
101
+ )
69
102
 
70
103
  def __eq__(self, other: HashedNumpyArray) -> bool:
71
104
 
@@ -76,3 +109,37 @@ class HashedNumpyArray:
76
109
  return np.array_equal(self.array, other.array)
77
110
 
78
111
  return hash(self) == hash(other)
112
+
113
+ @staticmethod
114
+ def hash_of_array(
115
+ array: jax.Array | npt.NDArray, precision: float | None = 1e-9
116
+ ) -> int:
117
+ """
118
+ Calculate the hash of a NumPy array.
119
+
120
+ Args:
121
+ array: The array to hash.
122
+ precision: Optionally limit the precision over which the hash is computed.
123
+
124
+ Returns:
125
+ The hash of the array.
126
+ """
127
+
128
+ array = np.array(array).flatten()
129
+
130
+ array = np.where(array == np.nan, hash(np.nan), array)
131
+ array = np.where(array == np.inf, hash(np.inf), array)
132
+ array = np.where(array == -np.inf, hash(-np.inf), array)
133
+
134
+ if precision is not None:
135
+
136
+ integer1 = (array * precision).astype(int)
137
+ integer2 = (array - integer1 / precision).astype(int)
138
+
139
+ decimal_array = ((array - integer1 * 1e9 - integer2) / precision).astype(
140
+ int
141
+ )
142
+
143
+ array = np.hstack([integer1, integer2, decimal_array]).astype(int)
144
+
145
+ return hash(tuple(array.tolist()))
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: jaxsim
3
- Version: 0.3.1.dev4
3
+ Version: 0.3.1.dev21
4
4
  Home-page: https://github.com/ami-iit/jaxsim
5
5
  Author: Diego Ferigo
6
6
  Author-email: diego.ferigo@iit.it
@@ -1,19 +1,20 @@
1
1
  jaxsim/__init__.py,sha256=xzuTuZrgKdWLqqDzbvqzm2cJrEtAbepOeUqDu7ByVek,2621
2
- jaxsim/_version.py,sha256=VGLCHIu949eId2wWA41OJLwZQvXRoNWfGQ_EuSkWWtQ,424
2
+ jaxsim/_version.py,sha256=fSekabX0ZEIHwkdp0Sa0iQ2H7hhPzvja3dhu7EFiX4I,426
3
+ jaxsim/exceptions.py,sha256=8_h8iqL8DgNR754dR8SZiQ7361GR5V1sUk3ZuZCHw1Q,2069
3
4
  jaxsim/logging.py,sha256=c4zhwBKf9eAYAHVp62kTEllqdsZgh0K-kPKVy8L3elU,1584
4
5
  jaxsim/typing.py,sha256=cl7HHQCeP3mHmtF6EuQZcCjGvDmc_AryMWntP_lRBGg,722
5
6
  jaxsim/api/__init__.py,sha256=8eV22t2S3UwNyCg8karPetG1dmX1VDBXkyv28_FwNQA,210
6
7
  jaxsim/api/com.py,sha256=Yof6otFi-mLWAs1rqjmeNJTOWIH9gn7BdU5EIjiL6Ts,13481
7
8
  jaxsim/api/common.py,sha256=bqQ__pIQZbh-j8rkoHUkYHAgGiJnDzjHG-q4Ny0OOYQ,6646
8
9
  jaxsim/api/contact.py,sha256=79kcdq7C1_kWgxd1QWBabBhIPkwWEVLk-Fiz9kh-4so,12800
9
- jaxsim/api/data.py,sha256=xfKJz6Rw0YTk-EHCGiT8BFQrs_ggOz01lRi1Qh1mb28,27256
10
- jaxsim/api/frame.py,sha256=0YXOrGmx3cSQqa4_Ky-n6zyup3I3xvXNEgub-Bc5xUw,6222
10
+ jaxsim/api/data.py,sha256=fkVDBV1tODRYIaRb2N15l34InAcnzNygMGG1KFiIU2w,27307
11
+ jaxsim/api/frame.py,sha256=vSbFHL4WtKPySxunNoZLlM_aDuJXZtf8CSBKku63BAs,6178
11
12
  jaxsim/api/joint.py,sha256=-5DogPg4g4mmLckyVIVNjwv-Rxz0IWS7_md9nDlhPWA,4581
12
- jaxsim/api/kin_dyn_parameters.py,sha256=b1e96I8hKU5fh4StLObdVcDpr_6ZglrgD3SRyrqTu18,26203
13
+ jaxsim/api/kin_dyn_parameters.py,sha256=AEpDg9kihbKUN9PA8pNrAruSuWFUC-k_GGxtlcdcDiQ,29215
13
14
  jaxsim/api/link.py,sha256=MdMWaMpM5Dj5JHK8uwHZ4zR4Fjq3R4asi2sGTxk1OAs,16647
14
- jaxsim/api/model.py,sha256=sCx9CcP23A1I_ae4UqTq4Fpq5u0aDki72CqgnR1H50w,59465
15
+ jaxsim/api/model.py,sha256=iuNYsn4xIfX36smmZpwM2O5eftT7ioDQtb6mSUqWu6Q,59759
15
16
  jaxsim/api/ode.py,sha256=luTQJsIXUtCp_81dR42X7WrMvwrXtYbyJiqss29v7zA,10786
16
- jaxsim/api/ode_data.py,sha256=D6FzMkvY_qNuoFEImyp7sxAk-0pJOd3oZeSr9bBTcLk,23089
17
+ jaxsim/api/ode_data.py,sha256=FxUIV5qDNOg_OiOXWs3UrhDgKhGmTKcbHqgr4NX5bv0,23290
17
18
  jaxsim/api/references.py,sha256=UA6kSQVBoq-bXSo99EOELf-_MD5MTy2zS0GtG3wQ410,16618
18
19
  jaxsim/integrators/__init__.py,sha256=hxvOD-VK_mmd6v31wtC-nb28AYve1gLuZCNLV9wS-Kg,103
19
20
  jaxsim/integrators/common.py,sha256=9HXRVFo95Mpt6RcVhBrOfvOO7mDxqbkXeg_lKUibEFY,20693
@@ -34,12 +35,12 @@ jaxsim/mujoco/loaders.py,sha256=7rjpeJ6_GuitlCty-ZkLhTILQ0GmsFzDMgve-7Gkkh4,2098
34
35
  jaxsim/mujoco/model.py,sha256=1KVRjSLOTCuHt53apBPQTnFYJRknlVoKLQaxWsNK8qc,13494
35
36
  jaxsim/mujoco/visualizer.py,sha256=PXgQzwetS9mRJYHBknDMLsQ9152FdrSvZuT9xE_dfIQ,5069
36
37
  jaxsim/parsers/__init__.py,sha256=sonYi-bBWAoB04kp1mxT4uIORxjb7SdZ0ukGPmVx98Y,44
37
- jaxsim/parsers/kinematic_graph.py,sha256=WdIxntWfxXf67x90oM5KHHXFrSITMwVahqWgcOjYFzc,34730
38
+ jaxsim/parsers/kinematic_graph.py,sha256=1d0JAc3LrGTymaqO9exRsb33-o0Vtgc3cUvNP1YI-0Q,35083
38
39
  jaxsim/parsers/descriptions/__init__.py,sha256=PbIlunVfb59pB5jSX97YVpMAANRZPRkJ0X-hS14rzv4,221
39
40
  jaxsim/parsers/descriptions/collision.py,sha256=BQeIG-TKi4SVny23w6riDrQ5itC6VRwEMBX6HgAXHxA,3973
40
- jaxsim/parsers/descriptions/joint.py,sha256=z_nYSS0fdkcaerjUlPX0U1Vn1ArBT0u_XdKjqxG3HcY,3959
41
- jaxsim/parsers/descriptions/link.py,sha256=QvEE7J6iMQLibpLqlcBV428UA7NMpFFXJwe35GYnjAY,3124
42
- jaxsim/parsers/descriptions/model.py,sha256=V9nSyCK3mo7680WYMDEx1MTfdDTJzbCGPqAp3qA2XRE,9511
41
+ jaxsim/parsers/descriptions/joint.py,sha256=7qUabpldRKwpGYQLCtQyMKiY47hB78J80DIuzI6bGLc,5186
42
+ jaxsim/parsers/descriptions/link.py,sha256=s0NXGOqmDknX0DYof31TGjVLLUHC9kSwzlGYLcCc03A,3710
43
+ jaxsim/parsers/descriptions/model.py,sha256=vfubtW68CUdgcbCHPcgKy0_BxzKQhhM8ycbCE-dF7Vk,9827
43
44
  jaxsim/parsers/rod/__init__.py,sha256=G2vqlLajBLUc4gyzXwsEI2Wsi4TMOIF9bLDFeT6KrGU,92
44
45
  jaxsim/parsers/rod/parser.py,sha256=4COuhkAYv4-GIpCqvkXEJWpDEQczEkBM3KwpqX48Rek,13514
45
46
  jaxsim/parsers/rod/utils.py,sha256=KSjgy6WsmTrD5HZEA2x8hOBSRU4bUGOOHzxKkeFO5r8,5721
@@ -50,16 +51,16 @@ jaxsim/rbda/crba.py,sha256=awsWEQXLE0UPEXIcZCVsAqBEPjyahMNzY9ux6nE1l-s,4739
50
51
  jaxsim/rbda/forward_kinematics.py,sha256=94W7TUXvZjMb-99CyYR8pObuxIYYX9B_dtRZqsNcThs,3418
51
52
  jaxsim/rbda/jacobian.py,sha256=M79bGir-2w_iJ2GurYhOGgMfJnp7ZMOCW6AeeWKK8iM,10745
52
53
  jaxsim/rbda/rnea.py,sha256=DjwkvXQVUSUclM3Uy3UPZ2tao91R5dGd4o7TsS2qObI,7650
53
- jaxsim/rbda/soft_contacts.py,sha256=52zJOF31hFpqoaOednTvi8j_UxhRcdGNjzOPb2v2MPc,11257
54
+ jaxsim/rbda/soft_contacts.py,sha256=0hx9JT4R1X2PPjhZ1EDizBR1gGoCFCtKYu86SeuIvvA,11269
54
55
  jaxsim/rbda/utils.py,sha256=zpbFM2Iq8cntku0BFVu9nfEqZhInCWi9D2INT6MFEI8,5003
55
56
  jaxsim/terrain/__init__.py,sha256=f7lVX-iNpH_wkkjef9Qpjh19TTAUOQw76EiLYJDVizc,78
56
57
  jaxsim/terrain/terrain.py,sha256=UXQCt7TCkq6GkM8bOZu44pNTpf-FZWiKN6VE4kb4kFk,2342
57
58
  jaxsim/utils/__init__.py,sha256=Y5zyoRevl3EMVQadhZ4EtSwTEkDt2vcnFoRhPJjKTZ0,215
58
59
  jaxsim/utils/jaxsim_dataclass.py,sha256=h26timZ_XrBL_Q_oymv-DkQd-EcUiHn8QexAaZXBY9c,11396
59
60
  jaxsim/utils/tracing.py,sha256=KDMoyVPlu2NJvFkhtZwq5AkqMMgajt3munvJom-vEjQ,650
60
- jaxsim/utils/wrappers.py,sha256=EJMcblYKUjxw9HJShVf81Ig3pHUJno6Dx6h-RnY--wM,2040
61
- jaxsim-0.3.1.dev4.dist-info/LICENSE,sha256=eaYdFmdeMbiIoIiPzEK0MjP1S9wtFXjXNR5er49uLR0,1546
62
- jaxsim-0.3.1.dev4.dist-info/METADATA,sha256=VBJIUI8eyYPeIFYpycB-meBbGxcStaW9lqbUyxoxkaM,9738
63
- jaxsim-0.3.1.dev4.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
64
- jaxsim-0.3.1.dev4.dist-info/top_level.txt,sha256=LxGMA8FLtXjQ6oI7N5gd_R_oSUHxpXxUEOfT1xS_ni0,7
65
- jaxsim-0.3.1.dev4.dist-info/RECORD,,
61
+ jaxsim/utils/wrappers.py,sha256=QIJitSoljrKR_U4T3ewCJPT3DTh-tPZsRsg0t_MH93E,3896
62
+ jaxsim-0.3.1.dev21.dist-info/LICENSE,sha256=eaYdFmdeMbiIoIiPzEK0MjP1S9wtFXjXNR5er49uLR0,1546
63
+ jaxsim-0.3.1.dev21.dist-info/METADATA,sha256=wtxQdWa5FFEqYdZx81i-VgNk7DKBY6YMQAXn5_1ctMY,9739
64
+ jaxsim-0.3.1.dev21.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
65
+ jaxsim-0.3.1.dev21.dist-info/top_level.txt,sha256=LxGMA8FLtXjQ6oI7N5gd_R_oSUHxpXxUEOfT1xS_ni0,7
66
+ jaxsim-0.3.1.dev21.dist-info/RECORD,,