jaxsim 0.4.3.dev68__py3-none-any.whl → 0.4.3.dev77__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
jaxsim/__init__.py CHANGED
@@ -20,6 +20,11 @@ def _jnp_options() -> None:
20
20
  if jnp.empty(0, dtype=float).dtype != jnp.empty(0, dtype=np.float64).dtype:
21
21
  logging.warning("Failed to enable 64bit precision in JAX")
22
22
 
23
+ else:
24
+ logging.warning(
25
+ "Using 32bit precision in JaxSim is still experimental, please avoid to use variable step integrators."
26
+ )
27
+
23
28
 
24
29
  def _np_options() -> None:
25
30
  import numpy as np
jaxsim/_version.py CHANGED
@@ -12,5 +12,5 @@ __version__: str
12
12
  __version_tuple__: VERSION_TUPLE
13
13
  version_tuple: VERSION_TUPLE
14
14
 
15
- __version__ = version = '0.4.3.dev68'
16
- __version_tuple__ = version_tuple = (0, 4, 3, 'dev68')
15
+ __version__ = version = '0.4.3.dev77'
16
+ __version_tuple__ = version_tuple = (0, 4, 3, 'dev77')
jaxsim/api/contact.py CHANGED
@@ -117,6 +117,7 @@ def collidable_point_dynamics(
117
117
  model: js.model.JaxSimModel,
118
118
  data: js.data.JaxSimModelData,
119
119
  link_forces: jtp.MatrixLike | None = None,
120
+ joint_force_references: jtp.VectorLike | None = None,
120
121
  ) -> tuple[jtp.Matrix, dict[str, jtp.Array]]:
121
122
  r"""
122
123
  Compute the 6D force applied to each collidable point.
@@ -127,11 +128,14 @@ def collidable_point_dynamics(
127
128
  link_forces:
128
129
  The 6D external forces to apply to the links expressed in the same
129
130
  representation of data.
131
+ joint_force_references:
132
+ The joint force references to apply to the joints.
130
133
 
131
134
  Returns:
132
135
  The 6D force applied to each collidable point and additional data based on the contact model configured:
133
136
  - Soft: the material deformation rate.
134
- - Rigid: nothing.
137
+ - Rigid: no additional data.
138
+ - RelaxedRigid: no additional data.
135
139
 
136
140
  Note:
137
141
  The material deformation rate is always returned in the mixed frame
@@ -144,6 +148,10 @@ def collidable_point_dynamics(
144
148
  W_p_Ci, W_ṗ_Ci = js.contact.collidable_point_kinematics(model=model, data=data)
145
149
 
146
150
  # Import privately the contacts classes.
151
+ from jaxsim.rbda.contacts.relaxed_rigid import (
152
+ RelaxedRigidContacts,
153
+ RelaxedRigidContactsState,
154
+ )
147
155
  from jaxsim.rbda.contacts.rigid import RigidContacts, RigidContactsState
148
156
  from jaxsim.rbda.contacts.soft import SoftContacts, SoftContactsState
149
157
 
@@ -186,6 +194,29 @@ def collidable_point_dynamics(
186
194
  model=model,
187
195
  data=data,
188
196
  link_forces=link_forces,
197
+ joint_force_references=joint_force_references,
198
+ )
199
+
200
+ aux_data = dict()
201
+
202
+ case RelaxedRigidContacts():
203
+ assert isinstance(model.contact_model, RelaxedRigidContacts)
204
+ assert isinstance(data.state.contact, RelaxedRigidContactsState)
205
+
206
+ # Build the contact model.
207
+ relaxed_rigid_contacts = RelaxedRigidContacts(
208
+ parameters=data.contacts_params, terrain=model.terrain
209
+ )
210
+
211
+ # Compute the 6D force expressed in the inertial frame and applied to each
212
+ # collidable point.
213
+ W_f_Ci, _ = relaxed_rigid_contacts.compute_contact_forces(
214
+ position=W_p_Ci,
215
+ velocity=W_ṗ_Ci,
216
+ model=model,
217
+ data=data,
218
+ link_forces=link_forces,
219
+ joint_force_references=joint_force_references,
189
220
  )
190
221
 
191
222
  aux_data = dict()
jaxsim/api/data.py CHANGED
@@ -6,10 +6,11 @@ from collections.abc import Sequence
6
6
 
7
7
  import jax
8
8
  import jax.numpy as jnp
9
+ import jax.scipy.spatial.transform
9
10
  import jax_dataclasses
10
- import jaxlie
11
11
 
12
12
  import jaxsim.api as js
13
+ import jaxsim.math
13
14
  import jaxsim.rbda
14
15
  import jaxsim.typing as jtp
15
16
  from jaxsim.rbda.contacts.soft import SoftContacts
@@ -39,7 +40,9 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
39
40
  contacts_params: jaxsim.rbda.ContactsParams = dataclasses.field(repr=False)
40
41
 
41
42
  time_ns: jtp.Int = dataclasses.field(
42
- default_factory=lambda: jnp.array(0, dtype=jnp.uint64)
43
+ default_factory=lambda: jnp.array(
44
+ 0, dtype=jnp.uint64 if jax.config.read("jax_enable_x64") else jnp.uint32
45
+ ),
43
46
  )
44
47
 
45
48
  def __hash__(self) -> int:
@@ -172,9 +175,14 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
172
175
  )
173
176
 
174
177
  time_ns = (
175
- jnp.array(time * 1e9, dtype=jnp.uint64)
178
+ jnp.array(
179
+ time * 1e9,
180
+ dtype=jnp.uint64 if jax.config.read("jax_enable_x64") else jnp.uint32,
181
+ )
176
182
  if time is not None
177
- else jnp.array(0, dtype=jnp.uint64)
183
+ else jnp.array(
184
+ 0, dtype=jnp.uint64 if jax.config.read("jax_enable_x64") else jnp.uint32
185
+ )
178
186
  )
179
187
 
180
188
  if isinstance(model.contact_model, SoftContacts):
@@ -188,10 +196,9 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
188
196
  else:
189
197
  contacts_params = model.contact_model.parameters
190
198
 
191
- W_H_B = jaxlie.SE3.from_rotation_and_translation(
192
- translation=base_position,
193
- rotation=jaxlie.SO3(wxyz=base_quaternion),
194
- ).as_matrix()
199
+ W_H_B = jaxsim.math.Transform.from_quaternion_and_translation(
200
+ translation=base_position, quaternion=base_quaternion
201
+ )
195
202
 
196
203
  v_WB = JaxSimModelData.other_representation_to_inertial(
197
204
  array=jnp.hstack([base_linear_velocity, base_angular_velocity]),
@@ -377,7 +384,9 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
377
384
  on_false=W_Q_B / jnp.linalg.norm(W_Q_B),
378
385
  )
379
386
 
380
- return (W_Q_B if not dcm else jaxlie.SO3(wxyz=W_Q_B).as_matrix()).astype(float)
387
+ return (W_Q_B if not dcm else jaxsim.math.Quaternion.to_dcm(W_Q_B)).astype(
388
+ float
389
+ )
381
390
 
382
391
  @jax.jit
383
392
  def base_transform(self) -> jtp.Matrix:
@@ -586,16 +595,18 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
586
595
  The updated `JaxSimModelData` object.
587
596
  """
588
597
 
589
- base_quaternion = jnp.array(base_quaternion)
598
+ W_Q_B = jnp.array(base_quaternion, dtype=float)
599
+
600
+ W_Q_B = jax.lax.select(
601
+ pred=jnp.allclose(jnp.linalg.norm(W_Q_B), 1.0, atol=1e-6, rtol=0.0),
602
+ on_true=W_Q_B,
603
+ on_false=W_Q_B / jnp.linalg.norm(W_Q_B),
604
+ )
590
605
 
591
606
  return self.replace(
592
607
  validate=True,
593
608
  state=self.state.replace(
594
- physics_model=self.state.physics_model.replace(
595
- base_quaternion=jnp.atleast_1d(base_quaternion.squeeze()).astype(
596
- float
597
- )
598
- )
609
+ physics_model=self.state.physics_model.replace(base_quaternion=W_Q_B)
599
610
  ),
600
611
  )
601
612
 
@@ -728,6 +739,7 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
728
739
  )
729
740
 
730
741
 
742
+ @functools.partial(jax.jit, static_argnames=["velocity_representation", "base_rpy_seq"])
731
743
  def random_model_data(
732
744
  model: js.model.JaxSimModel,
733
745
  *,
@@ -737,6 +749,18 @@ def random_model_data(
737
749
  jtp.FloatLike | Sequence[jtp.FloatLike],
738
750
  jtp.FloatLike | Sequence[jtp.FloatLike],
739
751
  ] = ((-1, -1, 0.5), 1.0),
752
+ base_rpy_bounds: tuple[
753
+ jtp.FloatLike | Sequence[jtp.FloatLike],
754
+ jtp.FloatLike | Sequence[jtp.FloatLike],
755
+ ] = (-jnp.pi, jnp.pi),
756
+ base_rpy_seq: str = "XYZ",
757
+ joint_pos_bounds: (
758
+ tuple[
759
+ jtp.FloatLike | Sequence[jtp.FloatLike],
760
+ jtp.FloatLike | Sequence[jtp.FloatLike],
761
+ ]
762
+ | None
763
+ ) = None,
740
764
  base_vel_lin_bounds: tuple[
741
765
  jtp.FloatLike | Sequence[jtp.FloatLike],
742
766
  jtp.FloatLike | Sequence[jtp.FloatLike],
@@ -762,6 +786,12 @@ def random_model_data(
762
786
  key: The random key.
763
787
  velocity_representation: The velocity representation to use.
764
788
  base_pos_bounds: The bounds for the base position.
789
+ base_rpy_bounds:
790
+ The bounds for the euler angles used to build the base orientation.
791
+ base_rpy_seq:
792
+ The sequence of axes for rotation (using `Rotation` from scipy).
793
+ joint_pos_bounds:
794
+ The bounds for the joint positions (reading the joint limits if None).
765
795
  base_vel_lin_bounds: The bounds for the base linear velocity.
766
796
  base_vel_ang_bounds: The bounds for the base angular velocity.
767
797
  joint_vel_bounds: The bounds for the joint velocities.
@@ -776,6 +806,8 @@ def random_model_data(
776
806
 
777
807
  p_min = jnp.array(base_pos_bounds[0], dtype=float)
778
808
  p_max = jnp.array(base_pos_bounds[1], dtype=float)
809
+ rpy_min = jnp.array(base_rpy_bounds[0], dtype=float)
810
+ rpy_max = jnp.array(base_rpy_bounds[1], dtype=float)
779
811
  v_min = jnp.array(base_vel_lin_bounds[0], dtype=float)
780
812
  v_max = jnp.array(base_vel_lin_bounds[1], dtype=float)
781
813
  ω_min = jnp.array(base_vel_ang_bounds[0], dtype=float)
@@ -801,13 +833,29 @@ def random_model_data(
801
833
  key=k1, shape=(3,), minval=p_min, maxval=p_max
802
834
  )
803
835
 
804
- physics_model_state.base_quaternion = jaxlie.SO3.from_rpy_radians(
805
- *jax.random.uniform(key=k2, shape=(3,), minval=0, maxval=2 * jnp.pi)
806
- ).wxyz
836
+ physics_model_state.base_quaternion = jaxsim.math.Quaternion.to_wxyz(
837
+ xyzw=jax.scipy.spatial.transform.Rotation.from_euler(
838
+ seq=base_rpy_seq,
839
+ angles=jax.random.uniform(
840
+ key=k2, shape=(3,), minval=rpy_min, maxval=rpy_max
841
+ ),
842
+ ).as_quat()
843
+ )
807
844
 
808
845
  if model.number_of_joints() > 0:
809
- physics_model_state.joint_positions = js.joint.random_joint_positions(
810
- model=model, key=k3
846
+
847
+ s_min, s_max = (
848
+ jnp.array(joint_pos_bounds, dtype=float)
849
+ if joint_pos_bounds is not None
850
+ else (None, None)
851
+ )
852
+
853
+ physics_model_state.joint_positions = (
854
+ js.joint.random_joint_positions(model=model, key=k3)
855
+ if (s_min is None or s_max is None)
856
+ else jax.random.uniform(
857
+ key=k3, shape=(model.dofs(),), minval=s_min, maxval=s_max
858
+ )
811
859
  )
812
860
 
813
861
  physics_model_state.joint_velocities = jax.random.uniform(
jaxsim/api/joint.py CHANGED
@@ -180,17 +180,77 @@ def random_joint_positions(
180
180
 
181
181
  Args:
182
182
  model: The model to consider.
183
- joint_names: The names of the joints.
184
- key: The random key.
183
+ joint_names: The names of the considered joints (all if None).
184
+ key: The random key (initialized from seed 0 if None).
185
+
186
+ Note:
187
+ If the joint range of revolute joints is larger than 2π, their joint positions
188
+ will be sampled from an interval of size 2π.
185
189
 
186
190
  Returns:
187
191
  The random joint positions.
188
192
  """
189
193
 
194
+ # Consider the key corresponding to a zero seed if it was not passed.
190
195
  key = key if key is not None else jax.random.PRNGKey(seed=0)
191
196
 
197
+ # Get the joint limits parsed from the model description.
192
198
  s_min, s_max = position_limits(model=model, joint_names=joint_names)
193
199
 
200
+ # Get the joint indices.
201
+ # Note that it will trigger an exception if the given `joint_names` are not valid.
202
+ joint_names = joint_names if joint_names is not None else model.joint_names()
203
+ joint_indices = names_to_idxs(model=model, joint_names=joint_names)
204
+
205
+ from jaxsim.parsers.descriptions.joint import JointType
206
+
207
+ # Filter for revolute joints.
208
+ is_revolute = jnp.where(
209
+ jnp.array(model.kin_dyn_parameters.joint_model.joint_types[1:])[joint_indices]
210
+ == JointType.Revolute,
211
+ True,
212
+ False,
213
+ )
214
+
215
+ # Shorthand for π.
216
+ π = jnp.pi
217
+
218
+ # Filter for revolute with full range (or continuous).
219
+ is_revolute_full_range = jnp.logical_and(is_revolute, s_max - s_min >= 2 * π)
220
+
221
+ # Clip the lower limit to -π if the joint range is larger than [-π, π].
222
+ s_min = jnp.where(
223
+ jnp.logical_and(
224
+ is_revolute_full_range, jnp.logical_and(s_min <= -π, s_max >= π)
225
+ ),
226
+ -π,
227
+ s_min,
228
+ )
229
+
230
+ # Clip the upper limit to +π if the joint range is larger than [-π, π].
231
+ s_max = jnp.where(
232
+ jnp.logical_and(
233
+ is_revolute_full_range, jnp.logical_and(s_min <= -π, s_max >= π)
234
+ ),
235
+ π,
236
+ s_max,
237
+ )
238
+
239
+ # Shift the lower limit if the upper limit is smaller than +π.
240
+ s_min = jnp.where(
241
+ jnp.logical_and(is_revolute_full_range, s_max < π),
242
+ s_max - 2 * π,
243
+ s_min,
244
+ )
245
+
246
+ # Shift the upper limit if the lower limit is larger than -π.
247
+ s_max = jnp.where(
248
+ jnp.logical_and(is_revolute_full_range, s_min > -π),
249
+ s_min + 2 * π,
250
+ s_max,
251
+ )
252
+
253
+ # Sample the joint positions.
194
254
  s_random = jax.random.uniform(
195
255
  minval=s_min,
196
256
  maxval=s_max,
jaxsim/api/model.py CHANGED
@@ -1747,14 +1747,18 @@ def link_contact_forces(
1747
1747
  data: The data of the considered model.
1748
1748
 
1749
1749
  Returns:
1750
- A (nL, 6) array containing the stacked 6D contact forces of the links,
1750
+ A `(nL, 6)` array containing the stacked 6D contact forces of the links,
1751
1751
  expressed in the frame corresponding to the active representation.
1752
1752
  """
1753
1753
 
1754
+ # Note: the following code should be kept in sync with the function
1755
+ # `jaxsim.api.ode.system_velocity_dynamics`. We cannot merge them since
1756
+ # there we need to get also aux_data.
1757
+
1754
1758
  # Compute the 6D forces applied to each collidable point expressed in the
1755
1759
  # inertial frame.
1756
1760
  with data.switch_velocity_representation(VelRepr.Inertial):
1757
- W_f_Ci = js.contact.collidable_point_forces(model=model, data=data)
1761
+ W_f_C = js.contact.collidable_point_forces(model=model, data=data)
1758
1762
 
1759
1763
  # Construct the vector defining the parent link index of each collidable point.
1760
1764
  # We use this vector to sum the 6D forces of all collidable points rigidly
@@ -1763,29 +1767,28 @@ def link_contact_forces(
1763
1767
  model.kin_dyn_parameters.contact_parameters.body, dtype=int
1764
1768
  )
1765
1769
 
1770
+ # Create the mask that associates each collidable point with its parent link.
1771
+ # We use this mask to sum the collidable points to the right link.
1772
+ mask = parent_link_index_of_collidable_points[:, jnp.newaxis] == jnp.arange(
1773
+ model.number_of_links()
1774
+ )
1775
+
1766
1776
  # Sum the forces of all collidable points rigidly attached to a body.
1767
- # Since the contact forces W_f_Ci are expressed in the world frame,
1777
+ # Since the contact forces W_f_C are expressed in the world frame,
1768
1778
  # we don't need any coordinate transformation.
1769
- W_f_Li = jax.vmap(
1770
- lambda nc: (
1771
- jnp.vstack(
1772
- jnp.equal(parent_link_index_of_collidable_points, nc).astype(int)
1773
- )
1774
- * W_f_Ci
1775
- ).sum(axis=0)
1776
- )(jnp.arange(model.number_of_links()))
1777
-
1778
- # Convert the 6D forces to the active representation.
1779
- f_Li = jax.vmap(
1780
- lambda W_f_L: data.inertial_to_other_representation(
1781
- array=W_f_L,
1782
- other_representation=data.velocity_representation,
1783
- transform=data.base_transform(),
1784
- is_force=True,
1785
- )
1786
- )(W_f_Li)
1779
+ W_f_L = mask.T @ W_f_C
1787
1780
 
1788
- return f_Li
1781
+ # Create a references object to store the link forces.
1782
+ references = js.references.JaxSimModelReferences.build(
1783
+ model=model, link_forces=W_f_L, velocity_representation=VelRepr.Inertial
1784
+ )
1785
+
1786
+ # Use the references object to convert the link forces to the velocity
1787
+ # representation of data.
1788
+ with references.switch_velocity_representation(data.velocity_representation):
1789
+ f_L = references.link_forces(model=model, data=data)
1790
+
1791
+ return f_L
1789
1792
 
1790
1793
 
1791
1794
  # ======
@@ -1931,11 +1934,22 @@ def step(
1931
1934
  ),
1932
1935
  )
1933
1936
 
1937
+ tf_ns = t0_ns + jnp.array(dt * 1e9, dtype=t0_ns.dtype)
1938
+ tf_ns = jnp.where(tf_ns >= t0_ns, tf_ns, jnp.array(0, dtype=t0_ns.dtype))
1939
+
1940
+ jax.lax.cond(
1941
+ pred=tf_ns < t0_ns,
1942
+ true_fun=lambda: jax.debug.print(
1943
+ "The simulation time overflowed, resetting simulation time to 0."
1944
+ ),
1945
+ false_fun=lambda: None,
1946
+ )
1947
+
1934
1948
  data_tf = (
1935
1949
  # Store the new state of the model and the new time.
1936
1950
  data.replace(
1937
1951
  state=state_tf,
1938
- time_ns=t0_ns + jnp.array(dt * 1e9).astype(jnp.uint64),
1952
+ time_ns=tf_ns,
1939
1953
  )
1940
1954
  )
1941
1955
 
jaxsim/api/ode.py CHANGED
@@ -95,7 +95,7 @@ def system_velocity_dynamics(
95
95
  Args:
96
96
  model: The model to consider.
97
97
  data: The data of the considered model.
98
- joint_forces: The joint forces to apply.
98
+ joint_forces: The joint force references to apply.
99
99
  link_forces:
100
100
  The 6D forces to apply to the links expressed in the frame corresponding to
101
101
  the velocity representation of `data`.
@@ -120,6 +120,7 @@ def system_velocity_dynamics(
120
120
  references = js.references.JaxSimModelReferences.build(
121
121
  model=model,
122
122
  link_forces=O_f_L,
123
+ joint_force_references=joint_forces,
123
124
  data=data,
124
125
  velocity_representation=data.velocity_representation,
125
126
  )
@@ -132,9 +133,16 @@ def system_velocity_dynamics(
132
133
  # with the terrain.
133
134
  W_f_Li_terrain = jnp.zeros_like(O_f_L).astype(float)
134
135
 
136
+ # Initialize a dictionary of auxiliary data.
137
+ # This dictionary is used to store additional data computed by the contact model.
135
138
  aux_data = {}
139
+
136
140
  if len(model.kin_dyn_parameters.contact_parameters.body) > 0:
137
141
 
142
+ # Note: the following code should be kept in sync with the function
143
+ # `jaxsim.api.model.link_contact_forces`. We cannot merge them since
144
+ # here we need to get also aux_data.
145
+
138
146
  # Compute the 6D forces W_f ∈ ℝ^{n_c × 6} applied to each collidable point
139
147
  # along with contact-specific auxiliary states.
140
148
  with data.switch_velocity_representation(VelRepr.Inertial):
@@ -142,6 +150,7 @@ def system_velocity_dynamics(
142
150
  model=model,
143
151
  data=data,
144
152
  link_forces=references.link_forces(model=model, data=data),
153
+ joint_force_references=references.joint_force_references(model=model),
145
154
  )
146
155
 
147
156
  # Construct the vector defining the parent link index of each collidable point.
@@ -175,17 +184,15 @@ def system_velocity_dynamics(
175
184
  forces=W_f_Li_terrain,
176
185
  additive=True,
177
186
  )
178
- # Get the link forces in the data representation
179
- with references.switch_velocity_representation(data.velocity_representation):
187
+
188
+ # Get the link forces in inertial representation
180
189
  f_L_total = references.link_forces(model=model, data=data)
181
190
 
182
- # The following method always returns the inertial-fixed acceleration, and expects
183
- # the link_forces expressed in the inertial frame.
184
- W_v̇_WB, s̈ = system_acceleration(
185
- model=model, data=data, joint_forces=joint_forces, link_forces=f_L_total
186
- )
191
+ v̇_WB, s̈ = system_acceleration(
192
+ model=model, data=data, joint_forces=joint_forces, link_forces=f_L_total
193
+ )
187
194
 
188
- return W_v̇_WB, s̈, aux_data
195
+ return v̇_WB, s̈, aux_data
189
196
 
190
197
 
191
198
  def system_acceleration(
@@ -196,7 +203,7 @@ def system_acceleration(
196
203
  link_forces: jtp.MatrixLike | None = None,
197
204
  ) -> tuple[jtp.Vector, jtp.Vector]:
198
205
  """
199
- Compute the system acceleration in inertial-fixed representation.
206
+ Compute the system acceleration in the active representation.
200
207
 
201
208
  Args:
202
209
  model: The model to consider.
@@ -206,7 +213,7 @@ def system_acceleration(
206
213
  The 6D forces to apply to the links expressed in the same representation of data.
207
214
 
208
215
  Returns:
209
- A tuple containing the base 6D acceleration in inertial-fixed representation
216
+ A tuple containing the base 6D acceleration in the active representation
210
217
  and the joint accelerations.
211
218
  """
212
219
 
@@ -272,18 +279,15 @@ def system_acceleration(
272
279
  )
273
280
 
274
281
  # - Joint accelerations: s̈ ∈ ℝⁿ
275
- # - Base inertial-fixed acceleration: W_v̇_WB = (W_p̈_B, W_ω̇_B) ∈ ℝ⁶
276
- with (
277
- data.switch_velocity_representation(velocity_representation=VelRepr.Inertial),
278
- references.switch_velocity_representation(VelRepr.Inertial),
279
- ):
280
- W_v̇_WB, s̈ = js.model.forward_dynamics_aba(
281
- model=model,
282
- data=data,
283
- joint_forces=references.joint_force_references(),
284
- link_forces=references.link_forces(),
285
- )
286
- return W_v̇_WB, s̈
282
+ # - Base acceleration: v̇_WB ∈ ℝ⁶
283
+ v̇_WB, s̈ = js.model.forward_dynamics_aba(
284
+ model=model,
285
+ data=data,
286
+ joint_forces=references.joint_force_references(model=model),
287
+ link_forces=references.link_forces(model=model, data=data),
288
+ )
289
+
290
+ return v̇_WB, s̈
287
291
 
288
292
 
289
293
  @jax.jit
@@ -353,7 +357,7 @@ def system_dynamics(
353
357
  corresponding derivative, and the dictionary of auxiliary data returned
354
358
  by the system dynamics evaluation.
355
359
  """
356
-
360
+ from jaxsim.rbda.contacts.relaxed_rigid import RelaxedRigidContacts
357
361
  from jaxsim.rbda.contacts.rigid import RigidContacts
358
362
  from jaxsim.rbda.contacts.soft import SoftContacts
359
363
 
@@ -371,7 +375,7 @@ def system_dynamics(
371
375
  case SoftContacts():
372
376
  ode_state_kwargs["tangential_deformation"] = aux_dict["m_dot"]
373
377
 
374
- case RigidContacts():
378
+ case RigidContacts() | RelaxedRigidContacts():
375
379
  pass
376
380
 
377
381
  case _:
jaxsim/api/ode_data.py CHANGED
@@ -6,6 +6,10 @@ import jax_dataclasses
6
6
  import jaxsim.api as js
7
7
  import jaxsim.typing as jtp
8
8
  from jaxsim.rbda import ContactsState
9
+ from jaxsim.rbda.contacts.relaxed_rigid import (
10
+ RelaxedRigidContacts,
11
+ RelaxedRigidContactsState,
12
+ )
9
13
  from jaxsim.rbda.contacts.rigid import RigidContacts, RigidContactsState
10
14
  from jaxsim.rbda.contacts.soft import SoftContacts, SoftContactsState
11
15
  from jaxsim.utils import JaxsimDataclass
@@ -173,6 +177,10 @@ class ODEState(JaxsimDataclass):
173
177
  )
174
178
  case RigidContacts():
175
179
  contact = RigidContactsState.build()
180
+
181
+ case RelaxedRigidContacts():
182
+ contact = RelaxedRigidContactsState.build()
183
+
176
184
  case _:
177
185
  raise ValueError("Unable to determine contact state class prefix.")
178
186
 
@@ -216,7 +224,9 @@ class ODEState(JaxsimDataclass):
216
224
 
217
225
  # Get the contact model from the `JaxSimModel`.
218
226
  match contact:
219
- case SoftContactsState() | RigidContactsState():
227
+ case (
228
+ SoftContactsState() | RigidContactsState() | RelaxedRigidContactsState()
229
+ ):
220
230
  pass
221
231
  case None:
222
232
  contact = SoftContactsState.zero(model=model)
@@ -497,7 +497,7 @@ class ExplicitRungeKutta(Integrator[PyTreeType, PyTreeType], Generic[PyTreeType]
497
497
  b: jtp.Matrix,
498
498
  c: jtp.Vector,
499
499
  index_of_solution: jtp.IntLike = 0,
500
- ) -> [bool, int | None]:
500
+ ) -> tuple[bool, int | None]:
501
501
  """
502
502
  Check if the Butcher tableau supports the FSAL (first-same-as-last) property.
503
503
 
jaxsim/math/inertia.py CHANGED
@@ -45,7 +45,7 @@ class Inertia:
45
45
  M (jtp.Matrix): The 6x6 inertia matrix.
46
46
 
47
47
  Returns:
48
- Tuple[jtp.Float, jtp.Vector, jtp.Matrix]: A tuple containing mass, center of mass (3D), and inertia matrix (3x3).
48
+ tuple[jtp.Float, jtp.Vector, jtp.Matrix]: A tuple containing mass, center of mass (3D), and inertia matrix (3x3).
49
49
 
50
50
  Raises:
51
51
  ValueError: If the input matrix M has an unexpected shape.
jaxsim/mujoco/loaders.py CHANGED
@@ -211,7 +211,7 @@ class RodModelToMjcf:
211
211
  joints_dict = {j.name: j for j in rod_model.joints()}
212
212
 
213
213
  # Convert all the joints not considered to fixed joints.
214
- for joint_name in set(j.name for j in rod_model.joints()) - considered_joints:
214
+ for joint_name in {j.name for j in rod_model.joints()} - considered_joints:
215
215
  joints_dict[joint_name].type = "fixed"
216
216
 
217
217
  # Convert the ROD model to URDF.
@@ -289,10 +289,10 @@ class RodModelToMjcf:
289
289
  mj_model = mj.MjModel.from_xml_string(xml=urdf_string, assets=assets)
290
290
 
291
291
  # Get the joint names.
292
- mj_joint_names = set(
292
+ mj_joint_names = {
293
293
  mj.mj_id2name(mj_model, mj.mjtObj.mjOBJ_JOINT, idx)
294
294
  for idx in range(mj_model.njnt)
295
- )
295
+ }
296
296
 
297
297
  # Check that the Mujoco model only has the considered joints.
298
298
  if mj_joint_names != considered_joints:
@@ -394,7 +394,7 @@ class KinematicGraph(Sequence[LinkDescription]):
394
394
  return copy.deepcopy(self)
395
395
 
396
396
  # Check if all considered joints are part of the full kinematic graph
397
- if len(set(considered_joints) - set(j.name for j in full_graph.joints)) != 0:
397
+ if len(set(considered_joints) - {j.name for j in full_graph.joints}) != 0:
398
398
  extra_j = set(considered_joints) - {j.name for j in full_graph.joints}
399
399
  msg = f"Not all joints to consider are part of the graph ({{{extra_j}}})"
400
400
  raise ValueError(msg)
@@ -536,8 +536,8 @@ class KinematicGraph(Sequence[LinkDescription]):
536
536
  root_link_name=full_graph.root.name,
537
537
  )
538
538
 
539
- assert set(f.name for f in self.frames).isdisjoint(
540
- set(f.name for f in unconnected_frames + reduced_frames)
539
+ assert {f.name for f in self.frames}.isdisjoint(
540
+ {f.name for f in unconnected_frames + reduced_frames}
541
541
  )
542
542
 
543
543
  for link in unconnected_links: