jaxsim 0.4.3.dev31__py3-none-any.whl → 0.4.3.dev64__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
jaxsim/_version.py CHANGED
@@ -12,5 +12,5 @@ __version__: str
12
12
  __version_tuple__: VERSION_TUPLE
13
13
  version_tuple: VERSION_TUPLE
14
14
 
15
- __version__ = version = '0.4.3.dev31'
16
- __version_tuple__ = version_tuple = (0, 4, 3, 'dev31')
15
+ __version__ = version = '0.4.3.dev64'
16
+ __version_tuple__ = version_tuple = (0, 4, 3, 'dev64')
jaxsim/api/contact.py CHANGED
@@ -131,7 +131,8 @@ def collidable_point_dynamics(
131
131
  Returns:
132
132
  The 6D force applied to each collidable point and additional data based on the contact model configured:
133
133
  - Soft: the material deformation rate.
134
- - Rigid: nothing.
134
+ - Rigid: no additional data.
135
+ - QuasiRigid: no additional data.
135
136
 
136
137
  Note:
137
138
  The material deformation rate is always returned in the mixed frame
@@ -144,6 +145,10 @@ def collidable_point_dynamics(
144
145
  W_p_Ci, W_ṗ_Ci = js.contact.collidable_point_kinematics(model=model, data=data)
145
146
 
146
147
  # Import privately the contacts classes.
148
+ from jaxsim.rbda.contacts.relaxed_rigid import (
149
+ RelaxedRigidContacts,
150
+ RelaxedRigidContactsState,
151
+ )
147
152
  from jaxsim.rbda.contacts.rigid import RigidContacts, RigidContactsState
148
153
  from jaxsim.rbda.contacts.soft import SoftContacts, SoftContactsState
149
154
 
@@ -190,6 +195,27 @@ def collidable_point_dynamics(
190
195
 
191
196
  aux_data = dict()
192
197
 
198
+ case RelaxedRigidContacts():
199
+ assert isinstance(model.contact_model, RelaxedRigidContacts)
200
+ assert isinstance(data.state.contact, RelaxedRigidContactsState)
201
+
202
+ # Build the contact model.
203
+ relaxed_rigid_contacts = RelaxedRigidContacts(
204
+ parameters=data.contacts_params, terrain=model.terrain
205
+ )
206
+
207
+ # Compute the 6D force expressed in the inertial frame and applied to each
208
+ # collidable point.
209
+ W_f_Ci, _ = relaxed_rigid_contacts.compute_contact_forces(
210
+ position=W_p_Ci,
211
+ velocity=W_ṗ_Ci,
212
+ model=model,
213
+ data=data,
214
+ link_forces=link_forces,
215
+ )
216
+
217
+ aux_data = dict()
218
+
193
219
  case _:
194
220
  raise ValueError(f"Invalid contact model {model.contact_model}")
195
221
 
jaxsim/api/data.py CHANGED
@@ -593,16 +593,18 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
593
593
  The updated `JaxSimModelData` object.
594
594
  """
595
595
 
596
- base_quaternion = jnp.array(base_quaternion)
596
+ W_Q_B = jnp.array(base_quaternion, dtype=float)
597
+
598
+ W_Q_B = jax.lax.select(
599
+ pred=jnp.allclose(jnp.linalg.norm(W_Q_B), 1.0, atol=1e-6, rtol=0.0),
600
+ on_true=W_Q_B,
601
+ on_false=W_Q_B / jnp.linalg.norm(W_Q_B),
602
+ )
597
603
 
598
604
  return self.replace(
599
605
  validate=True,
600
606
  state=self.state.replace(
601
- physics_model=self.state.physics_model.replace(
602
- base_quaternion=jnp.atleast_1d(base_quaternion.squeeze()).astype(
603
- float
604
- )
605
- )
607
+ physics_model=self.state.physics_model.replace(base_quaternion=W_Q_B)
606
608
  ),
607
609
  )
608
610
 
@@ -744,6 +746,13 @@ def random_model_data(
744
746
  jtp.FloatLike | Sequence[jtp.FloatLike],
745
747
  jtp.FloatLike | Sequence[jtp.FloatLike],
746
748
  ] = ((-1, -1, 0.5), 1.0),
749
+ joint_pos_bounds: (
750
+ tuple[
751
+ jtp.FloatLike | Sequence[jtp.FloatLike],
752
+ jtp.FloatLike | Sequence[jtp.FloatLike],
753
+ ]
754
+ | None
755
+ ) = None,
747
756
  base_vel_lin_bounds: tuple[
748
757
  jtp.FloatLike | Sequence[jtp.FloatLike],
749
758
  jtp.FloatLike | Sequence[jtp.FloatLike],
@@ -769,6 +778,8 @@ def random_model_data(
769
778
  key: The random key.
770
779
  velocity_representation: The velocity representation to use.
771
780
  base_pos_bounds: The bounds for the base position.
781
+ joint_pos_bounds:
782
+ The bounds for the joint positions (reading the joint limits if None).
772
783
  base_vel_lin_bounds: The bounds for the base linear velocity.
773
784
  base_vel_ang_bounds: The bounds for the base angular velocity.
774
785
  joint_vel_bounds: The bounds for the joint velocities.
@@ -813,8 +824,19 @@ def random_model_data(
813
824
  ).wxyz
814
825
 
815
826
  if model.number_of_joints() > 0:
816
- physics_model_state.joint_positions = js.joint.random_joint_positions(
817
- model=model, key=k3
827
+
828
+ s_min, s_max = (
829
+ jnp.array(joint_pos_bounds, dtype=float)
830
+ if joint_pos_bounds is not None
831
+ else (None, None)
832
+ )
833
+
834
+ physics_model_state.joint_positions = (
835
+ js.joint.random_joint_positions(model=model, key=k3)
836
+ if (s_min is None or s_max is None)
837
+ else jax.random.uniform(
838
+ key=k3, shape=(model.dofs(),), minval=s_min, maxval=s_max
839
+ )
818
840
  )
819
841
 
820
842
  physics_model_state.joint_velocities = jax.random.uniform(
jaxsim/api/joint.py CHANGED
@@ -180,17 +180,77 @@ def random_joint_positions(
180
180
 
181
181
  Args:
182
182
  model: The model to consider.
183
- joint_names: The names of the joints.
184
- key: The random key.
183
+ joint_names: The names of the considered joints (all if None).
184
+ key: The random key (initialized from seed 0 if None).
185
+
186
+ Note:
187
+ If the joint range or revolute joints is larger than 2π, their joint positions
188
+ will be sampled from an interval of size 2π.
185
189
 
186
190
  Returns:
187
191
  The random joint positions.
188
192
  """
189
193
 
194
+ # Consider the key corresponding to a zero seed if it was not passed.
190
195
  key = key if key is not None else jax.random.PRNGKey(seed=0)
191
196
 
197
+ # Get the joint limits parsed from the model description.
192
198
  s_min, s_max = position_limits(model=model, joint_names=joint_names)
193
199
 
200
+ # Get the joint indices.
201
+ # Note that it will trigger an exception if the given `joint_names` are not valid.
202
+ joint_names = joint_names if joint_names is not None else model.joint_names()
203
+ joint_indices = names_to_idxs(model=model, joint_names=joint_names)
204
+
205
+ from jaxsim.parsers.descriptions.joint import JointType
206
+
207
+ # Filter for revolute joints.
208
+ is_revolute = jnp.where(
209
+ jnp.array(model.kin_dyn_parameters.joint_model.joint_types[1:])[joint_indices]
210
+ == JointType.Revolute,
211
+ True,
212
+ False,
213
+ )
214
+
215
+ # Shorthand for π.
216
+ π = jnp.pi
217
+
218
+ # Filter for revolute with full range (or continuous).
219
+ is_revolute_full_range = jnp.logical_and(is_revolute, s_max - s_min >= 2 * π)
220
+
221
+ # Clip the lower limit to -π if the joint range is larger than [-π, π].
222
+ s_min = jnp.where(
223
+ jnp.logical_and(
224
+ is_revolute_full_range, jnp.logical_and(s_min <= -π, s_max >= π)
225
+ ),
226
+ -π,
227
+ s_min,
228
+ )
229
+
230
+ # Clip the upper limit to +π if the joint range is larger than [-π, π].
231
+ s_max = jnp.where(
232
+ jnp.logical_and(
233
+ is_revolute_full_range, jnp.logical_and(s_min <= -π, s_max >= π)
234
+ ),
235
+ π,
236
+ s_max,
237
+ )
238
+
239
+ # Shift the lower limit if the upper limit is smaller than +π.
240
+ s_min = jnp.where(
241
+ jnp.logical_and(is_revolute_full_range, s_max < π),
242
+ s_max - 2 * π,
243
+ s_min,
244
+ )
245
+
246
+ # Shift the upper limit if the lower limit is larger than -π.
247
+ s_max = jnp.where(
248
+ jnp.logical_and(is_revolute_full_range, s_min > -π),
249
+ s_min + 2 * π,
250
+ s_max,
251
+ )
252
+
253
+ # Sample the joint positions.
194
254
  s_random = jax.random.uniform(
195
255
  minval=s_min,
196
256
  maxval=s_max,
jaxsim/api/model.py CHANGED
@@ -1935,7 +1935,7 @@ def step(
1935
1935
  tf_ns = jnp.where(tf_ns >= t0_ns, tf_ns, jnp.array(0, dtype=t0_ns.dtype))
1936
1936
 
1937
1937
  jax.lax.cond(
1938
- pred=tf_ns >= t0_ns,
1938
+ pred=tf_ns < t0_ns,
1939
1939
  true_fun=lambda: jax.debug.print(
1940
1940
  "The simulation time overflowed, resetting simulation time to 0."
1941
1941
  ),
jaxsim/api/ode.py CHANGED
@@ -175,17 +175,15 @@ def system_velocity_dynamics(
175
175
  forces=W_f_Li_terrain,
176
176
  additive=True,
177
177
  )
178
- # Get the link forces in the data representation
179
- with references.switch_velocity_representation(data.velocity_representation):
178
+
179
+ # Get the link forces in inertial representation
180
180
  f_L_total = references.link_forces(model=model, data=data)
181
181
 
182
- # The following method always returns the inertial-fixed acceleration, and expects
183
- # the link_forces expressed in the inertial frame.
184
- W_v̇_WB, s̈ = system_acceleration(
185
- model=model, data=data, joint_forces=joint_forces, link_forces=f_L_total
186
- )
182
+ v̇_WB, = system_acceleration(
183
+ model=model, data=data, joint_forces=joint_forces, link_forces=f_L_total
184
+ )
187
185
 
188
- return W_v̇_WB, s̈, aux_data
186
+ return v̇_WB, s̈, aux_data
189
187
 
190
188
 
191
189
  def system_acceleration(
@@ -196,7 +194,7 @@ def system_acceleration(
196
194
  link_forces: jtp.MatrixLike | None = None,
197
195
  ) -> tuple[jtp.Vector, jtp.Vector]:
198
196
  """
199
- Compute the system acceleration in inertial-fixed representation.
197
+ Compute the system acceleration in the active representation.
200
198
 
201
199
  Args:
202
200
  model: The model to consider.
@@ -206,7 +204,7 @@ def system_acceleration(
206
204
  The 6D forces to apply to the links expressed in the same representation of data.
207
205
 
208
206
  Returns:
209
- A tuple containing the base 6D acceleration in inertial-fixed representation
207
+ A tuple containing the base 6D acceleration in in the active representation
210
208
  and the joint accelerations.
211
209
  """
212
210
 
@@ -272,18 +270,15 @@ def system_acceleration(
272
270
  )
273
271
 
274
272
  # - Joint accelerations: s̈ ∈ ℝⁿ
275
- # - Base inertial-fixed acceleration: W_v̇_WB = (W_p̈_B, W_ω̇_B) ∈ ℝ⁶
276
- with (
277
- data.switch_velocity_representation(velocity_representation=VelRepr.Inertial),
278
- references.switch_velocity_representation(VelRepr.Inertial),
279
- ):
280
- W_v̇_WB, = js.model.forward_dynamics_aba(
281
- model=model,
282
- data=data,
283
- joint_forces=references.joint_force_references(),
284
- link_forces=references.link_forces(),
285
- )
286
- return W_v̇_WB, s̈
273
+ # - Base acceleration: v̇_WB ∈ ℝ⁶
274
+ v̇_WB, s̈ = js.model.forward_dynamics_aba(
275
+ model=model,
276
+ data=data,
277
+ joint_forces=references.joint_force_references(model=model),
278
+ link_forces=references.link_forces(model=model, data=data),
279
+ )
280
+
281
+ return v̇_WB,
287
282
 
288
283
 
289
284
  @jax.jit
@@ -353,7 +348,7 @@ def system_dynamics(
353
348
  corresponding derivative, and the dictionary of auxiliary data returned
354
349
  by the system dynamics evaluation.
355
350
  """
356
-
351
+ from jaxsim.rbda.contacts.relaxed_rigid import RelaxedRigidContacts
357
352
  from jaxsim.rbda.contacts.rigid import RigidContacts
358
353
  from jaxsim.rbda.contacts.soft import SoftContacts
359
354
 
@@ -371,7 +366,7 @@ def system_dynamics(
371
366
  case SoftContacts():
372
367
  ode_state_kwargs["tangential_deformation"] = aux_dict["m_dot"]
373
368
 
374
- case RigidContacts():
369
+ case RigidContacts() | RelaxedRigidContacts():
375
370
  pass
376
371
 
377
372
  case _:
jaxsim/api/ode_data.py CHANGED
@@ -6,6 +6,10 @@ import jax_dataclasses
6
6
  import jaxsim.api as js
7
7
  import jaxsim.typing as jtp
8
8
  from jaxsim.rbda import ContactsState
9
+ from jaxsim.rbda.contacts.relaxed_rigid import (
10
+ RelaxedRigidContacts,
11
+ RelaxedRigidContactsState,
12
+ )
9
13
  from jaxsim.rbda.contacts.rigid import RigidContacts, RigidContactsState
10
14
  from jaxsim.rbda.contacts.soft import SoftContacts, SoftContactsState
11
15
  from jaxsim.utils import JaxsimDataclass
@@ -173,6 +177,10 @@ class ODEState(JaxsimDataclass):
173
177
  )
174
178
  case RigidContacts():
175
179
  contact = RigidContactsState.build()
180
+
181
+ case RelaxedRigidContacts():
182
+ contact = RelaxedRigidContactsState.build()
183
+
176
184
  case _:
177
185
  raise ValueError("Unable to determine contact state class prefix.")
178
186
 
@@ -216,7 +224,9 @@ class ODEState(JaxsimDataclass):
216
224
 
217
225
  # Get the contact model from the `JaxSimModel`.
218
226
  match contact:
219
- case SoftContactsState() | RigidContactsState():
227
+ case (
228
+ SoftContactsState() | RigidContactsState() | RelaxedRigidContactsState()
229
+ ):
220
230
  pass
221
231
  case None:
222
232
  contact = SoftContactsState.zero(model=model)
@@ -497,7 +497,7 @@ class ExplicitRungeKutta(Integrator[PyTreeType, PyTreeType], Generic[PyTreeType]
497
497
  b: jtp.Matrix,
498
498
  c: jtp.Vector,
499
499
  index_of_solution: jtp.IntLike = 0,
500
- ) -> [bool, int | None]:
500
+ ) -> tuple[bool, int | None]:
501
501
  """
502
502
  Check if the Butcher tableau supports the FSAL (first-same-as-last) property.
503
503
 
jaxsim/math/inertia.py CHANGED
@@ -45,7 +45,7 @@ class Inertia:
45
45
  M (jtp.Matrix): The 6x6 inertia matrix.
46
46
 
47
47
  Returns:
48
- Tuple[jtp.Float, jtp.Vector, jtp.Matrix]: A tuple containing mass, center of mass (3D), and inertia matrix (3x3).
48
+ tuple[jtp.Float, jtp.Vector, jtp.Matrix]: A tuple containing mass, center of mass (3D), and inertia matrix (3x3).
49
49
 
50
50
  Raises:
51
51
  ValueError: If the input matrix M has an unexpected shape.
jaxsim/mujoco/loaders.py CHANGED
@@ -211,7 +211,7 @@ class RodModelToMjcf:
211
211
  joints_dict = {j.name: j for j in rod_model.joints()}
212
212
 
213
213
  # Convert all the joints not considered to fixed joints.
214
- for joint_name in set(j.name for j in rod_model.joints()) - considered_joints:
214
+ for joint_name in {j.name for j in rod_model.joints()} - considered_joints:
215
215
  joints_dict[joint_name].type = "fixed"
216
216
 
217
217
  # Convert the ROD model to URDF.
@@ -289,10 +289,10 @@ class RodModelToMjcf:
289
289
  mj_model = mj.MjModel.from_xml_string(xml=urdf_string, assets=assets)
290
290
 
291
291
  # Get the joint names.
292
- mj_joint_names = set(
292
+ mj_joint_names = {
293
293
  mj.mj_id2name(mj_model, mj.mjtObj.mjOBJ_JOINT, idx)
294
294
  for idx in range(mj_model.njnt)
295
- )
295
+ }
296
296
 
297
297
  # Check that the Mujoco model only has the considered joints.
298
298
  if mj_joint_names != considered_joints:
@@ -394,7 +394,7 @@ class KinematicGraph(Sequence[LinkDescription]):
394
394
  return copy.deepcopy(self)
395
395
 
396
396
  # Check if all considered joints are part of the full kinematic graph
397
- if len(set(considered_joints) - set(j.name for j in full_graph.joints)) != 0:
397
+ if len(set(considered_joints) - {j.name for j in full_graph.joints}) != 0:
398
398
  extra_j = set(considered_joints) - {j.name for j in full_graph.joints}
399
399
  msg = f"Not all joints to consider are part of the graph ({{{extra_j}}})"
400
400
  raise ValueError(msg)
@@ -536,8 +536,8 @@ class KinematicGraph(Sequence[LinkDescription]):
536
536
  root_link_name=full_graph.root.name,
537
537
  )
538
538
 
539
- assert set(f.name for f in self.frames).isdisjoint(
540
- set(f.name for f in unconnected_frames + reduced_frames)
539
+ assert {f.name for f in self.frames}.isdisjoint(
540
+ {f.name for f in unconnected_frames + reduced_frames}
541
541
  )
542
542
 
543
543
  for link in unconnected_links:
@@ -0,0 +1,384 @@
1
+ from __future__ import annotations
2
+
3
+ import dataclasses
4
+ from typing import Any
5
+
6
+ import jax
7
+ import jax.numpy as jnp
8
+ import jax_dataclasses
9
+ import jaxopt
10
+
11
+ import jaxsim.api as js
12
+ import jaxsim.typing as jtp
13
+ from jaxsim.api.common import VelRepr
14
+ from jaxsim.math import Adjoint
15
+ from jaxsim.terrain.terrain import FlatTerrain, Terrain
16
+
17
+ from .common import ContactModel, ContactsParams, ContactsState
18
+
19
+
20
@jax_dataclasses.pytree_dataclass
class RelaxedRigidContactsParams(ContactsParams):
    """Parameters of the relaxed rigid contacts model."""

    # Time constant
    time_constant: jtp.Float = dataclasses.field(
        default_factory=lambda: jnp.array(0.01, dtype=float)
    )

    # Adimensional damping coefficient
    damping_coefficient: jtp.Float = dataclasses.field(
        default_factory=lambda: jnp.array(1.0, dtype=float)
    )

    # Minimum impedance
    d_min: jtp.Float = dataclasses.field(
        default_factory=lambda: jnp.array(0.9, dtype=float)
    )

    # Maximum impedance
    d_max: jtp.Float = dataclasses.field(
        default_factory=lambda: jnp.array(0.95, dtype=float)
    )

    # Width
    width: jtp.Float = dataclasses.field(
        default_factory=lambda: jnp.array(0.0001, dtype=float)
    )

    # Midpoint
    midpoint: jtp.Float = dataclasses.field(
        default_factory=lambda: jnp.array(0.1, dtype=float)
    )

    # Power exponent
    power: jtp.Float = dataclasses.field(
        default_factory=lambda: jnp.array(1.0, dtype=float)
    )

    # Stiffness (a negative value is used directly as a spring constant by the model)
    stiffness: jtp.Float = dataclasses.field(
        default_factory=lambda: jnp.array(0.0, dtype=float)
    )

    # Damping (a negative value is used directly as a damper constant by the model)
    damping: jtp.Float = dataclasses.field(
        default_factory=lambda: jnp.array(0.0, dtype=float)
    )

    # Friction coefficient
    mu: jtp.Float = dataclasses.field(
        default_factory=lambda: jnp.array(0.5, dtype=float)
    )

    # Maximum number of iterations
    max_iterations: jtp.Int = dataclasses.field(
        default_factory=lambda: jnp.array(50, dtype=int)
    )

    # Solver tolerance
    tolerance: jtp.Float = dataclasses.field(
        default_factory=lambda: jnp.array(1e-6, dtype=float)
    )

    def __hash__(self) -> int:
        from jaxsim.utils.wrappers import HashedNumpyArray

        # Hash the values of all the parameters, each wrapped in `HashedNumpyArray`.
        return hash(
            (
                HashedNumpyArray(self.time_constant),
                HashedNumpyArray(self.damping_coefficient),
                HashedNumpyArray(self.d_min),
                HashedNumpyArray(self.d_max),
                HashedNumpyArray(self.width),
                HashedNumpyArray(self.midpoint),
                HashedNumpyArray(self.power),
                HashedNumpyArray(self.stiffness),
                HashedNumpyArray(self.damping),
                HashedNumpyArray(self.mu),
                HashedNumpyArray(self.max_iterations),
                HashedNumpyArray(self.tolerance),
            )
        )

    def __eq__(self, other: RelaxedRigidContactsParams) -> bool:

        if not isinstance(other, RelaxedRigidContactsParams):
            return False

        return hash(self) == hash(other)

    @classmethod
    def build(
        cls,
        time_constant: jtp.FloatLike | None = None,
        damping_coefficient: jtp.FloatLike | None = None,
        d_min: jtp.FloatLike | None = None,
        d_max: jtp.FloatLike | None = None,
        width: jtp.FloatLike | None = None,
        midpoint: jtp.FloatLike | None = None,
        power: jtp.FloatLike | None = None,
        stiffness: jtp.FloatLike | None = None,
        damping: jtp.FloatLike | None = None,
        mu: jtp.FloatLike | None = None,
        max_iterations: jtp.IntLike | None = None,
        tolerance: jtp.FloatLike | None = None,
    ) -> RelaxedRigidContactsParams:
        """
        Create a `RelaxedRigidContactsParams` instance.

        Every argument left to `None` takes the default value of the corresponding
        field.

        Args:
            time_constant: The time constant.
            damping_coefficient: The adimensional damping coefficient.
            d_min: The minimum impedance.
            d_max: The maximum impedance.
            width: The width.
            midpoint: The midpoint.
            power: The power exponent.
            stiffness: The stiffness.
            damping: The damping.
            mu: The friction coefficient.
            max_iterations: The maximum number of solver iterations.
            tolerance: The solver tolerance.

        Returns:
            A `RelaxedRigidContactsParams` instance whose values are arrays with the
            same dtype as the corresponding defaults.
        """

        provided = {
            "time_constant": time_constant,
            "damping_coefficient": damping_coefficient,
            "d_min": d_min,
            "d_max": d_max,
            "width": width,
            "midpoint": midpoint,
            "power": power,
            "stiffness": stiffness,
            "damping": damping,
            "mu": mu,
            "max_iterations": max_iterations,
            "tolerance": tolerance,
        }

        # The fields only define a `default_factory` (their `.default` is
        # `dataclasses.MISSING`), so the defaults are read from an instance built
        # with the factories.
        defaults = cls()

        return cls(
            **{
                name: jnp.array(
                    value if value is not None else getattr(defaults, name),
                    dtype=getattr(defaults, name).dtype,
                )
                for name, value in provided.items()
            }
        )

    def valid(self) -> bool:
        """
        Check whether the parameters are valid.

        Returns:
            True if all the parameters are within their admissible ranges.
            Stiffness and damping are not checked because negative values are
            accepted.
        """

        return bool(
            jnp.all(self.time_constant >= 0.0)
            and jnp.all(self.damping_coefficient > 0.0)
            and jnp.all(self.d_min >= 0.0)
            and jnp.all(self.d_max <= 1.0)
            and jnp.all(self.d_min <= self.d_max)
            and jnp.all(self.width >= 0.0)
            and jnp.all(self.midpoint >= 0.0)
            and jnp.all(self.power >= 0.0)
            and jnp.all(self.mu >= 0.0)
            and jnp.all(self.max_iterations > 0)
            and jnp.all(self.tolerance > 0.0)
        )
149
+
150
+
151
@jax_dataclasses.pytree_dataclass
class RelaxedRigidContactsState(ContactsState):
    """
    State of the relaxed rigid contacts model.

    This class declares no fields of its own.
    """

    def __eq__(self, other: RelaxedRigidContactsState) -> bool:
        # Equality is defined as equality of the hashes.
        own_hash = hash(self)
        return own_hash == hash(other)

    @staticmethod
    def build() -> RelaxedRigidContactsState:
        """Create a `RelaxedRigidContactsState` instance"""

        state = RelaxedRigidContactsState()
        return state

    @staticmethod
    def zero(model: js.model.JaxSimModel) -> RelaxedRigidContactsState:
        """
        Build a zero `RelaxedRigidContactsState` instance from a `JaxSimModel`.

        Args:
            model: The model to consider (not used).

        Returns:
            The instance created by `build`.
        """

        return RelaxedRigidContactsState.build()

    def valid(self, model: js.model.JaxSimModel) -> bool:
        """
        Check whether the state is valid for the given model.

        Args:
            model: The model to consider (not used).

        Returns:
            Always True.
        """

        del model
        return True
171
+
172
+
173
@jax_dataclasses.pytree_dataclass
class RelaxedRigidContacts(ContactModel):
    """Relaxed rigid contacts model."""

    # The parameters of the contact model.
    parameters: RelaxedRigidContactsParams = dataclasses.field(
        default_factory=RelaxedRigidContactsParams
    )

    # The terrain, stored as a static field; its `normal` is used for contact detection.
    terrain: jax_dataclasses.Static[Terrain] = dataclasses.field(
        default_factory=FlatTerrain
    )

    def compute_contact_forces(
        self,
        position: jtp.Vector,
        velocity: jtp.Vector,
        model: js.model.JaxSimModel,
        data: js.data.JaxSimModelData,
        link_forces: jtp.MatrixLike | None = None,
    ) -> tuple[jtp.Vector, tuple[Any, ...]]:
        """
        Compute the contact forces of the collidable points.

        The 3D linear contact forces are obtained with L-BFGS by minimizing the
        squared residual ``||(G + R) f + b||²``. The terms are:

        - ``G = J M⁻¹ Jᵀ``, built from the linear contact Jacobians of the
          penetrating points.
        - ``R`` is the regularization matrix.
        - ``b`` is the free linear acceleration of the collidable points minus
          their reference acceleration.

        Args:
            position: The positions of the collidable points, one row per point.
            velocity: The linear velocities of the collidable points, one row per point.
            model: The model to consider.
            data: The data of the considered model.
            link_forces:
                Optional 6D forces applied to the links. They are forwarded to the
                model references built with `data.velocity_representation`.
                Zeros if None.

        Returns:
            A tuple containing two elements:

            - The 6D forces applied to each collidable point, expressed in the
              inertial frame.
            - ``(None,)``, since this model has no auxiliary data.
        """

        link_forces = (
            link_forces
            if link_forces is not None
            else jnp.zeros((model.number_of_links(), 6))
        )

        references = js.references.JaxSimModelReferences.build(
            model=model,
            data=data,
            velocity_representation=data.velocity_representation,
            link_forces=link_forces,
        )

        def _detect_contact(x: jtp.Array, y: jtp.Array, z: jtp.Array) -> jtp.Array:
            # Project the height of the point over the terrain onto the terrain normal.
            # A negative value is treated as penetration by the masks below.
            x, y, z = jax.tree.map(jnp.squeeze, (x, y, z))

            # NOTE(review): the normal is read from `self.terrain` while the height is
            # read from `model.terrain`. They coincide when the model is built from
            # `model.terrain` (as in `collidable_point_dynamics`); confirm which one
            # is intended for the general case.
            n̂ = self.terrain.normal(x=x, y=y).squeeze()
            h = jnp.array([0, 0, z - model.terrain.height(x=x, y=y)])

            return jnp.dot(h, n̂)

        # Compute the activation state of the collidable points.
        δ = jax.vmap(_detect_contact)(*position.T)

        # The kinematic and dynamic quantities are computed in mixed representation.
        with (
            references.switch_velocity_representation(VelRepr.Mixed),
            data.switch_velocity_representation(VelRepr.Mixed),
        ):
            M = js.model.free_floating_mass_matrix(model=model, data=data)

            # Linear part of the contact Jacobians, zeroed for non-penetrating points.
            Jl_WC = jnp.vstack(
                jax.vmap(lambda J, height: J * (height < 0))(
                    js.contact.jacobian(model=model, data=data)[:, :3, :], δ
                )
            )
            W_H_C = js.contact.transforms(model=model, data=data)

            # Generalized acceleration without contact forces.
            BW_ν̇_free = jnp.hstack(
                js.ode.system_acceleration(
                    model=model,
                    data=data,
                    link_forces=references.link_forces(model=model, data=data),
                )
            )
            BW_ν = data.generalized_velocity()
            J̇_WC = jnp.vstack(
                jax.vmap(lambda J̇, height: J̇ * (height < 0))(
                    js.contact.jacobian_derivative(model=model, data=data)[:, :3], δ
                ),
            )

        a_ref, R, K, D = self._regularizers(
            model=model,
            penetration=δ,
            velocity=velocity,
            parameters=self.parameters,
        )

        # G = J M⁻¹ Jᵀ, computed without explicitly inverting M.
        G = Jl_WC @ jnp.linalg.lstsq(M, Jl_WC.T)[0]
        CW_al_free_WC = Jl_WC @ BW_ν̇_free + J̇_WC @ BW_ν

        # Calculate quantities for the linear optimization problem.
        A = G + R
        b = CW_al_free_WC - a_ref

        def objective(x: jtp.Array) -> jtp.Array:
            return jnp.sum(jnp.square(A @ x + b))

        # Compute the 3D linear force in C[W] frame.
        opt = jaxopt.LBFGS(
            fun=objective,
            maxiter=self.parameters.max_iterations,
            tol=self.parameters.tolerance,
            maxls=30,
            history_size=10,
            max_stepsize=100.0,
        )

        # Initial guess: stiffness times the penetration (placed along z) plus
        # damping times the velocity of each collidable point.
        init_params = (
            K[:, jnp.newaxis] * jnp.zeros_like(position).at[:, 2].set(δ)
            + D[:, jnp.newaxis] * velocity
        ).flatten()

        CW_f_Ci = opt.run(init_params=init_params).params.reshape(-1, 3)

        def mixed_to_inertial(W_H_C: jax.Array, CW_fl: jax.Array) -> jax.Array:
            # C[W] has the origin of C and the orientation of W. Convert the 6D
            # force [CW_fl, 0] from C[W] to the inertial frame W.
            W_Xf_CW = Adjoint.from_transform(
                W_H_C.at[0:3, 0:3].set(jnp.eye(3)),
                inverse=True,
            ).T
            return W_Xf_CW @ jnp.hstack([CW_fl, jnp.zeros(3)])

        W_f_C = jax.vmap(mixed_to_inertial)(W_H_C, CW_f_Ci)

        return W_f_C, (None,)

    @staticmethod
    def _regularizers(
        model: js.model.JaxSimModel,
        penetration: jtp.Array,
        velocity: jtp.Array,
        parameters: RelaxedRigidContactsParams,
    ) -> tuple:
        """
        Compute the reference acceleration and the regularization terms.

        Args:
            model: The jaxsim model.
            penetration: The penetration of the collidable points.
            velocity: The velocity of the collidable points.
            parameters: The parameters of the relaxed rigid contacts model.

        Returns:
            A tuple containing the reference acceleration, the regularization matrix,
            the stiffness, and the damping. The terms of the collidable points that
            are not penetrating (``penetration >= 0``) are zero.
        """

        # Read the parameters by name, so that the code does not depend on the
        # declaration order of the dataclass fields.
        Ω = parameters.time_constant
        ζ = parameters.damping_coefficient
        ξ_min = parameters.d_min
        ξ_max = parameters.d_max
        width = parameters.width
        mid = parameters.midpoint
        p = parameters.power
        K = parameters.stiffness
        D = parameters.damping
        μ = parameters.mu

        # The spatial inertia matrices of the links.
        M_L = js.model.link_spatial_inertia_matrices(model=model)

        def _imp_aref(
            penetration: jtp.Array,
            velocity: jtp.Array,
        ) -> tuple[jtp.Array, jtp.Array, jtp.Array, jtp.Array]:
            """
            Calculate impedance and offset acceleration in constraint frame.

            Args:
                penetration: penetration in constraint frame
                velocity: velocity in constraint frame

            Returns:
                imp: computed impedance
                a_ref: offset acceleration in constraint frame
                K: computed stiffness
                D: computed damping
            """

            # NOTE(review): these formulas appear to follow MuJoCo's solimp/solref
            # conventions; confirm against the MuJoCo documentation.
            position = jnp.zeros(shape=(3,)).at[2].set(penetration)

            # Power-law profile of the normalized distance, split at the midpoint.
            imp_x = jnp.abs(position) / width
            imp_a = (1.0 / jnp.power(mid, p - 1)) * jnp.power(imp_x, p)
            imp_b = 1 - (1.0 / jnp.power(1 - mid, p - 1)) * jnp.power(1 - imp_x, p)
            imp_y = jnp.where(imp_x < mid, imp_a, imp_b)

            # Map the profile to [ξ_min, ξ_max], saturating to ξ_max beyond the width.
            imp = jnp.clip(ξ_min + imp_y * (ξ_max - ξ_min), ξ_min, ξ_max)
            imp = jnp.atleast_1d(jnp.where(imp_x > 1.0, ξ_max, imp))

            # When passing negative values, K and D represent a spring and damper, respectively.
            K_f = jnp.where(K < 0, -K / ξ_max**2, 1 / (ξ_max * Ω * ζ) ** 2)
            D_f = jnp.where(D < 0, -D / ξ_max, 2 / (ξ_max * Ω))

            a_ref = -jnp.atleast_1d(D_f * velocity + K_f * imp * position)

            return imp, a_ref, jnp.atleast_1d(K_f), jnp.atleast_1d(D_f)

        def _compute_row(
            *,
            link_idx: jtp.Int,
            penetration: jtp.Array,
            velocity: jtp.Array,
        ) -> tuple[jtp.Array, jtp.Array, jtp.Array, jtp.Array]:

            # Compute the reference acceleration.
            ξ, a_ref, K_f, D_f = _imp_aref(
                penetration=penetration,
                velocity=velocity,
            )

            # Compute the regularization terms.
            R = (
                (2 * μ**2 * (1 - ξ) / (ξ + 1e-12))
                * (1 + μ**2)
                @ jnp.linalg.inv(M_L[link_idx, :3, :3])
            )

            # Zero all the terms of the points that are not penetrating.
            return jax.tree.map(
                lambda x: x * (penetration < 0), (a_ref, R, K_f, D_f)
            )

        # Compute the terms of each collidable point and stack them in flat vectors.
        a_ref, R, K, D = jax.tree.map(
            jnp.concatenate,
            (
                *jax.vmap(_compute_row)(
                    link_idx=jnp.array(
                        model.kin_dyn_parameters.contact_parameters.body
                    ),
                    penetration=penetration,
                    velocity=velocity,
                ),
            ),
        )

        return a_ref, jnp.diag(R), K, D
@@ -9,7 +9,6 @@ import jax_dataclasses
9
9
 
10
10
  import jaxsim.api as js
11
11
  import jaxsim.typing as jtp
12
- from jaxsim import math
13
12
  from jaxsim.api.common import ModelDataWithVelocityRepresentation, VelRepr
14
13
  from jaxsim.terrain import FlatTerrain, Terrain
15
14
 
@@ -272,9 +271,17 @@ class RigidContacts(ContactModel):
272
271
  link_forces=link_forces,
273
272
  )
274
273
 
275
- with references.switch_velocity_representation(VelRepr.Mixed):
276
- BW_ν̇_free = RigidContacts._compute_mixed_nu_dot_free(
277
- model, data, references=references
274
+ with (
275
+ references.switch_velocity_representation(VelRepr.Mixed),
276
+ data.switch_velocity_representation(VelRepr.Mixed),
277
+ ):
278
+ BW_ν̇_free = jnp.hstack(
279
+ js.ode.system_acceleration(
280
+ model=model,
281
+ data=data,
282
+ joint_forces=references.joint_force_references(model=model),
283
+ link_forces=references.link_forces(model=model, data=data),
284
+ )
278
285
  )
279
286
 
280
287
  free_contact_acc = RigidContacts._linear_acceleration_of_collidable_points(
@@ -380,43 +387,6 @@ class RigidContacts(ContactModel):
380
387
  n_constraints = 6 * n_collidable_points
381
388
  return jnp.zeros(shape=(n_constraints,))
382
389
 
383
- @staticmethod
384
- def _compute_mixed_nu_dot_free(
385
- model: js.model.JaxSimModel,
386
- data: js.data.JaxSimModelData,
387
- references: js.references.JaxSimModelReferences | None = None,
388
- ) -> jtp.Array:
389
- references = (
390
- references
391
- if references is not None
392
- else js.references.JaxSimModelReferences.zero(model=model, data=data)
393
- )
394
-
395
- with (
396
- data.switch_velocity_representation(VelRepr.Mixed),
397
- references.switch_velocity_representation(VelRepr.Mixed),
398
- ):
399
- BW_v_WB = data.base_velocity()
400
- W_ṗ_B, W_ω_WB = jnp.split(BW_v_WB, 2)
401
- W_v̇_WB, s̈ = js.ode.system_acceleration(
402
- model=model,
403
- data=data,
404
- joint_forces=references.joint_force_references(model=model),
405
- link_forces=references.link_forces(model=model, data=data),
406
- )
407
-
408
- # Convert the inertial-fixed base acceleration to a mixed base acceleration.
409
- W_H_B = data.base_transform()
410
- W_H_BW = W_H_B.at[0:3, 0:3].set(jnp.eye(3))
411
- BW_X_W = math.Adjoint.from_transform(W_H_BW, inverse=True)
412
- term1 = BW_X_W @ W_v̇_WB
413
- term2 = jnp.zeros(6).at[0:3].set(jnp.cross(W_ṗ_B, W_ω_WB))
414
- BW_v̇_WB = term1 - term2
415
-
416
- BW_ν̇ = jnp.hstack([BW_v̇_WB, s̈])
417
-
418
- return BW_ν̇
419
-
420
390
  @staticmethod
421
391
  def _linear_acceleration_of_collidable_points(
422
392
  model: js.model.JaxSimModel,
jaxsim/terrain/terrain.py CHANGED
@@ -46,66 +46,82 @@ class Terrain(abc.ABC):
46
46
  @jax_dataclasses.pytree_dataclass
47
47
  class FlatTerrain(Terrain):
48
48
 
49
- z: float = dataclasses.field(default=0.0, kw_only=True)
49
+ _height: float = dataclasses.field(default=0.0, kw_only=True)
50
50
 
51
51
  @staticmethod
52
52
  def build(height: jtp.FloatLike) -> FlatTerrain:
53
53
 
54
- return FlatTerrain(z=float(height))
54
+ return FlatTerrain(_height=float(height))
55
55
 
56
56
  def height(self, x: jtp.FloatLike, y: jtp.FloatLike) -> jtp.Float:
57
57
 
58
- return jnp.array(self.z, dtype=float)
58
+ return jnp.array(self._height, dtype=float)
59
+
60
+ def normal(self, x: jtp.FloatLike, y: jtp.FloatLike) -> jtp.Vector:
61
+
62
+ return jnp.array([0.0, 0.0, 1.0], dtype=float)
59
63
 
60
64
  def __hash__(self) -> int:
61
65
 
62
- return hash(self.z)
66
+ return hash(self._height)
63
67
 
64
68
  def __eq__(self, other: FlatTerrain) -> bool:
65
69
 
66
70
  if not isinstance(other, FlatTerrain):
67
71
  return False
68
72
 
69
- return self.z == other.z
73
+ return self._height == other._height
70
74
 
71
75
 
72
76
  @jax_dataclasses.pytree_dataclass
73
77
  class PlaneTerrain(FlatTerrain):
74
78
 
75
- plane_normal: tuple[float, float, float] = jax_dataclasses.field(
79
+ _normal: tuple[float, float, float] = jax_dataclasses.field(
76
80
  default=(0.0, 0.0, 1.0), kw_only=True
77
81
  )
78
82
 
79
83
  @staticmethod
80
- def build(
81
- plane_normal: jtp.VectorLike, plane_height_over_origin: jtp.FloatLike = 0.0
82
- ) -> PlaneTerrain:
84
+ def build(height: jtp.FloatLike = 0.0, *, normal: jtp.VectorLike) -> PlaneTerrain:
83
85
  """
84
86
  Create a PlaneTerrain instance with a specified plane normal vector.
85
87
 
86
88
  Args:
87
- plane_normal: The normal vector of the terrain plane.
88
- plane_height_over_origin: The height of the plane over the origin.
89
+ normal: The normal vector of the terrain plane.
90
+ height: The height of the plane over the origin.
89
91
 
90
92
  Returns:
91
93
  PlaneTerrain: A PlaneTerrain instance.
92
94
  """
93
95
 
94
- plane_normal = jnp.array(plane_normal, dtype=float)
95
- plane_height_over_origin = jnp.array(plane_height_over_origin, dtype=float)
96
+ normal = jnp.array(normal, dtype=float)
97
+ height = jnp.array(height, dtype=float)
96
98
 
97
- if plane_normal.shape != (3,):
99
+ if normal.shape != (3,):
98
100
  msg = "Expected a 3D vector for the plane normal, got '{}'."
99
- raise ValueError(msg.format(plane_normal.shape))
101
+ raise ValueError(msg.format(normal.shape))
100
102
 
101
103
  # Make sure that the plane normal is a unit vector.
102
- plane_normal = plane_normal / jnp.linalg.norm(plane_normal)
104
+ normal = normal / jnp.linalg.norm(normal)
103
105
 
104
106
  return PlaneTerrain(
105
- z=float(plane_height_over_origin),
106
- plane_normal=tuple(plane_normal.tolist()),
107
+ _height=height.item(),
108
+ _normal=tuple(normal.tolist()),
107
109
  )
108
110
 
111
+ def normal(self, x: jtp.FloatLike, y: jtp.FloatLike) -> jtp.Vector:
112
+ """
113
+ Compute the normal vector of the terrain at a specific (x, y) location.
114
+
115
+ Args:
116
+ x: The x-coordinate of the location.
117
+ y: The y-coordinate of the location.
118
+
119
+ Returns:
120
+ The normal vector of the terrain surface at the specified location.
121
+ """
122
+
123
+ return jnp.array(self._normal, dtype=float)
124
+
109
125
  def height(self, x: jtp.FloatLike, y: jtp.FloatLike) -> jtp.Float:
110
126
  """
111
127
  Compute the height of the terrain at a specific (x, y) location on a plane.
@@ -123,10 +139,10 @@ class PlaneTerrain(FlatTerrain):
123
139
  # The height over the origin: -D/C
124
140
 
125
141
  # Get the plane equation coefficients from the terrain normal.
126
- A, B, C = self.plane_normal
142
+ A, B, C = self._normal
127
143
 
128
144
  # Compute the final coefficient D considering the terrain height.
129
- D = -C * self.z
145
+ D = -C * self._height
130
146
 
131
147
  # Invert the plane equation to get the height at the given (x, y) coordinates.
132
148
  return jnp.array(-(A * x + B * y + D) / C).astype(float)
@@ -137,9 +153,9 @@ class PlaneTerrain(FlatTerrain):
137
153
 
138
154
  return hash(
139
155
  (
140
- hash(self.z),
156
+ hash(self._height),
141
157
  HashedNumpyArray.hash_of_array(
142
- array=jnp.array(self.plane_normal, dtype=float)
158
+ array=jnp.array(self._normal, dtype=float)
143
159
  ),
144
160
  )
145
161
  )
@@ -150,10 +166,10 @@ class PlaneTerrain(FlatTerrain):
150
166
  return False
151
167
 
152
168
  if not (
153
- np.allclose(self.z, other.z)
169
+ np.allclose(self._height, other._height)
154
170
  and np.allclose(
155
- np.array(self.plane_normal, dtype=float),
156
- np.array(other.plane_normal, dtype=float),
171
+ np.array(self._normal, dtype=float),
172
+ np.array(other._normal, dtype=float),
157
173
  )
158
174
  ):
159
175
  return False
jaxsim/typing.py CHANGED
@@ -16,7 +16,7 @@ Int = Scalar
16
16
  Bool = Scalar
17
17
  Float = Scalar
18
18
 
19
- PyTree = (
19
+ PyTree: object = (
20
20
  dict[Hashable, TypeVar("PyTree")]
21
21
  | list[TypeVar("PyTree")]
22
22
  | tuple[TypeVar("PyTree")]
@@ -135,9 +135,10 @@ class JaxsimDataclass(abc.ABC):
135
135
  """
136
136
 
137
137
  return tuple(
138
- leaf.shape if hasattr(leaf, "shape") else None
139
- for leaf in jax.tree_util.tree_leaves(tree)
140
- if hasattr(leaf, "shape")
138
+ map(
139
+ lambda leaf: getattr(leaf, "shape", None),
140
+ jax.tree_util.tree_leaves(tree),
141
+ )
141
142
  )
142
143
 
143
144
  @staticmethod
@@ -154,9 +155,10 @@ class JaxsimDataclass(abc.ABC):
154
155
  """
155
156
 
156
157
  return tuple(
157
- leaf.dtype if hasattr(leaf, "dtype") else None
158
- for leaf in jax.tree_util.tree_leaves(tree)
159
- if hasattr(leaf, "dtype")
158
+ map(
159
+ lambda leaf: getattr(leaf, "dtype", None),
160
+ jax.tree_util.tree_leaves(tree),
161
+ )
160
162
  )
161
163
 
162
164
  @staticmethod
@@ -172,9 +174,10 @@ class JaxsimDataclass(abc.ABC):
172
174
  """
173
175
 
174
176
  return tuple(
175
- leaf.weak_type if hasattr(leaf, "weak_type") else False
176
- for leaf in jax.tree_util.tree_leaves(tree)
177
- if hasattr(leaf, "weak_type")
177
+ map(
178
+ lambda leaf: getattr(leaf, "weak_type", None),
179
+ jax.tree_util.tree_leaves(tree),
180
+ )
178
181
  )
179
182
 
180
183
  @staticmethod
jaxsim/utils/wrappers.py CHANGED
@@ -110,7 +110,7 @@ class HashedNumpyArray:
110
110
  return np.allclose(
111
111
  self.array,
112
112
  other.array,
113
- **({dict(atol=self.precision)} if self.precision is not None else {}),
113
+ **(dict(atol=self.precision) if self.precision is not None else {}),
114
114
  )
115
115
 
116
116
  return hash(self) == hash(other)
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: jaxsim
3
- Version: 0.4.3.dev31
3
+ Version: 0.4.3.dev64
4
4
  Summary: A differentiable physics engine and multibody dynamics library for control and robot learning.
5
5
  Author-email: Diego Ferigo <dgferigo@gmail.com>
6
6
  Maintainer-email: Diego Ferigo <dgferigo@gmail.com>, Filippo Luca Ferretti <filippo.ferretti@iit.it>
@@ -61,6 +61,7 @@ Description-Content-Type: text/markdown
61
61
  License-File: LICENSE
62
62
  Requires-Dist: coloredlogs
63
63
  Requires-Dist: jax>=0.4.13
64
+ Requires-Dist: jaxopt>=0.8.0
64
65
  Requires-Dist: jaxlib>=0.4.13
65
66
  Requires-Dist: jaxlie>=1.3.0
66
67
  Requires-Dist: jax-dataclasses>=1.4.0
@@ -1,29 +1,29 @@
1
1
  jaxsim/__init__.py,sha256=bSbpggIz5aG6QuGZLa0V2EfHjAOeucMxi-vIYxzLmN8,2788
2
- jaxsim/_version.py,sha256=J3LwBUFVVgKt400akbPS_UHFGRkurLsRFlpfwhbZGfc,426
2
+ jaxsim/_version.py,sha256=lLNskxtfHW1HqvnLRuhux3LlK89fMiZFUWknSYopw7k,426
3
3
  jaxsim/exceptions.py,sha256=8_h8iqL8DgNR754dR8SZiQ7361GR5V1sUk3ZuZCHw1Q,2069
4
4
  jaxsim/logging.py,sha256=STI-D_upXZYX-ZezLrlJJ0UlD5YspST0vZ_DcIwkzO4,1553
5
- jaxsim/typing.py,sha256=IbFx3UkEXi-cm7UBqMPi58rJAFV_HbZ9E_K4JwfNvVM,753
5
+ jaxsim/typing.py,sha256=2HXy9hgazPXjofi1vLQ09ZubPtgVmg80U9NKmZ6NYiI,761
6
6
  jaxsim/api/__init__.py,sha256=8eV22t2S3UwNyCg8karPetG1dmX1VDBXkyv28_FwNQA,210
7
7
  jaxsim/api/com.py,sha256=m-p3EJDhpnMTlXKplfbZE_aH9NqX_VyLlAE3vUhc6l4,13642
8
8
  jaxsim/api/common.py,sha256=SNgxq42r6eF_-aPszvOjUYkGwXOzz4hKmhDwEUkscFQ,6650
9
- jaxsim/api/contact.py,sha256=HyEAjF7BySDDOlRahN0l7V15IPB0HPXuoM0twamuEW0,20913
10
- jaxsim/api/data.py,sha256=p44Q7pNQEWQz2KgwNrk47_m0D-vbVtT8xuy4fLD8T58,27465
9
+ jaxsim/api/contact.py,sha256=C_PgMjWYYiqpA7Oz3IxHeFgrp855-xG6AQr6Ze98CtI,21863
10
+ jaxsim/api/data.py,sha256=mFUw2mj8AIXduW6HnkGN7eooZHfJhwnWbtYZfLF6gk4,28206
11
11
  jaxsim/api/frame.py,sha256=KS8A5wRfjxhe9NgcVo2QA516iP5zky7UVnWxG7nTa7c,12911
12
- jaxsim/api/joint.py,sha256=L81bQe-noPT6_54KOSF7KBjRmEPAS433ULn2EcXI8vI,5115
12
+ jaxsim/api/joint.py,sha256=lksT1Doxz2jknHyhb4ls20z6f6dofpZSzBJtVacZXAE,7129
13
13
  jaxsim/api/kin_dyn_parameters.py,sha256=CcfSg5Mc8qb1mZeMQ4AK_ffZIsK5yOl7tu397pFhcDA,29369
14
14
  jaxsim/api/link.py,sha256=qPRtc8qqMRjZxUCZYXJMygbB6huDXBfIT1b1b8Durkw,18631
15
- jaxsim/api/model.py,sha256=Tq0CBjmjipGT02q4LgvVv2PH2uEFFHrDr_OiaiQNIXQ,65872
16
- jaxsim/api/ode.py,sha256=Vb2sN4zwpXnaJDD9-ziz2qvfmfa4jvIQ0fONbBIRGmU,13368
17
- jaxsim/api/ode_data.py,sha256=U7F6TL6bENAxpQQl4PupPoDG7d7VfTTFqDAs3xwu6Hs,20003
15
+ jaxsim/api/model.py,sha256=K0q8-j-04f6B3MEXsctDGtWiuWlN3HbDrsS7zoPYStk,65871
16
+ jaxsim/api/ode.py,sha256=VuOLvCFoyGLmhNf2vFP5BI9BAPz78V_RW5tJ4hrizsw,13041
17
+ jaxsim/api/ode_data.py,sha256=7RSoBhfCJdP6P9InQbDwdBVpClPMMuetewI-6AWm-_0,20276
18
18
  jaxsim/api/references.py,sha256=XOVKuQXRmjPoP-T5JWGSbqIGX5DzOkeGafqRpj0ZQEM,20771
19
19
  jaxsim/integrators/__init__.py,sha256=hxvOD-VK_mmd6v31wtC-nb28AYve1gLuZCNLV9wS-Kg,103
20
- jaxsim/integrators/common.py,sha256=ntjflaV3qWaFH_E65pAGZ6QipdnFsgQDasKtIKpxTe4,20432
20
+ jaxsim/integrators/common.py,sha256=XIrJVJDO0ldaZ93WgoGNlFoRvazsRJTpO3DrK9kIXqM,20437
21
21
  jaxsim/integrators/fixed_step.py,sha256=KpjRd6hHtapxDoo6D1kyDrVDSHnke2TepI5grFH7_bM,2693
22
22
  jaxsim/integrators/variable_step.py,sha256=5StkFh9oQba34zlkIoXG2fUN78gbxkHePWbrpQ-QZOI,21274
23
23
  jaxsim/math/__init__.py,sha256=8oPITEoGwgRcOeG8KxtqxPQ8b5uku1HNRMokpCoi9Tc,352
24
24
  jaxsim/math/adjoint.py,sha256=o1FCipkGwPtMbN2gFNIyUV8ADF3TX5fxElpTEXK0bIs,4377
25
25
  jaxsim/math/cross.py,sha256=U7yEx_l75mSy5g6O-jsjBztApvxC3WaV4MpkS5tThu4,1330
26
- jaxsim/math/inertia.py,sha256=_hNpoeyEpAGr9ExDQJjckbjhk39luJFF-jv0SKqefnQ,1614
26
+ jaxsim/math/inertia.py,sha256=01hz6wMFreN2jBA0rVoBS1YMVh77KvwuzXSOpI3pxNk,1614
27
27
  jaxsim/math/joint_model.py,sha256=EzAveaG5B6ZnCFNUzN30KEQUVesd83lfWXJarYR-kUw,9989
28
28
  jaxsim/math/quaternion.py,sha256=_WA7W3iv7px83sWO1V1n0-J78hqAlO4SL1-jofE-UZ4,4754
29
29
  jaxsim/math/rotation.py,sha256=k-nwT79zmWrys3NNAB-lGWxat7Kqm_6JnFRoimJ8rBg,2156
@@ -31,11 +31,11 @@ jaxsim/math/skew.py,sha256=oOGSSR8PUGROl6IJFlrmu6K3gPH-u16hUPfKIkcVv9o,1177
31
31
  jaxsim/math/transform.py,sha256=KXzQgOnCfAtbXCwxhplpJ3F0JT3oEyeLVby1_uRAryQ,2892
32
32
  jaxsim/mujoco/__init__.py,sha256=Zo5GAlN1DYKvX8s1hu1j6HntKIbBMLB9Puv9ouaNAZ8,158
33
33
  jaxsim/mujoco/__main__.py,sha256=GBmB7J-zj75ZnFyuAAmpSOpbxi_HhHhWJeot3ljGDJY,5291
34
- jaxsim/mujoco/loaders.py,sha256=XB-fgXuWMTFiaand5MZlLFQ5__Sh8MK5CJsxIU34MBk,25328
34
+ jaxsim/mujoco/loaders.py,sha256=_8Af_5Yo0-lWHE-46BBMcrqSJnDNxr3peyc519DExtA,25322
35
35
  jaxsim/mujoco/model.py,sha256=AQksXemXWACJ3yvefV2G5HLwwBU9ISoJrOD1wlxdY5w,16386
36
36
  jaxsim/mujoco/visualizer.py,sha256=T1vU-w4NKSmgEkZ0FqVcGmIvYrYO0len2UBSsU4MOZ0,6978
37
37
  jaxsim/parsers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
38
- jaxsim/parsers/kinematic_graph.py,sha256=KijMWKyhTLKSNUmOOk4sYQMgPh_OkA_brncL7gBRHaY,34757
38
+ jaxsim/parsers/kinematic_graph.py,sha256=wT2bgaCS8VQJTHy2H9sENkVPDOiMkRikxEF1t_WaahQ,34748
39
39
  jaxsim/parsers/descriptions/__init__.py,sha256=PbIlunVfb59pB5jSX97YVpMAANRZPRkJ0X-hS14rzv4,221
40
40
  jaxsim/parsers/descriptions/collision.py,sha256=BQeIG-TKi4SVny23w6riDrQ5itC6VRwEMBX6HgAXHxA,3973
41
41
  jaxsim/parsers/descriptions/joint.py,sha256=VSb6C0FBBKMqwrHBKfc-Bbn4rl_J0RzUxMQlhIEvOPM,5185
@@ -54,16 +54,17 @@ jaxsim/rbda/rnea.py,sha256=CLfqs9XFVaD-hvkLABshDAfdw5bm_AMV3UVAQ_IvURQ,7542
54
54
  jaxsim/rbda/utils.py,sha256=eeT21Y4DiiyhrdF0lUE_VvRuwru5-rR7yOlOlWzCCWE,5381
55
55
  jaxsim/rbda/contacts/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
56
56
  jaxsim/rbda/contacts/common.py,sha256=VwAs742futAmLnDgbaOuLzNDBFiKDfYItdEZ4UcFgzE,2467
57
- jaxsim/rbda/contacts/rigid.py,sha256=8Vbnxng-ERZ5ka_eZGIBuhBDr2PNjc7m-Or255AfEw4,15862
57
+ jaxsim/rbda/contacts/relaxed_rigid.py,sha256=9YkPLbK6Kk0wPkuj47r7NBqY2tARyJsiCbrvDlOWHSI,12700
58
+ jaxsim/rbda/contacts/rigid.py,sha256=fbZk7sC6YOnTs_tzQRfsyBpHyT22XF-wB-EvOSZmhos,14746
58
59
  jaxsim/rbda/contacts/soft.py,sha256=_wvb5iZDjGcVg6rNQelN4LZN7qSC2NIp0HdKvZmlGfk,15647
59
60
  jaxsim/terrain/__init__.py,sha256=f7lVX-iNpH_wkkjef9Qpjh19TTAUOQw76EiLYJDVizc,78
60
- jaxsim/terrain/terrain.py,sha256=ctyNANIFSM3tZmamprjaEDcWgUSP0oNJbmT1zw9RjPs,4565
61
+ jaxsim/terrain/terrain.py,sha256=xUQg47yGxIOcTkLPbnO3sruEGBhoCd16j1evTGlmNjI,5010
61
62
  jaxsim/utils/__init__.py,sha256=Y5zyoRevl3EMVQadhZ4EtSwTEkDt2vcnFoRhPJjKTZ0,215
62
- jaxsim/utils/jaxsim_dataclass.py,sha256=5xJbY0G8d7C0OTNIW9T4vQxiDak6TGZT9gpNOvRykFI,11373
63
+ jaxsim/utils/jaxsim_dataclass.py,sha256=FSiUvdnq4Y1T9Jaa_mw4ZBQJe8H7deLr3Kupxtlh4iI,11322
63
64
  jaxsim/utils/tracing.py,sha256=KDMoyVPlu2NJvFkhtZwq5AkqMMgajt3munvJom-vEjQ,650
64
- jaxsim/utils/wrappers.py,sha256=JhLUh1g8iU-lhjbuZRfkscPZhYlLCOorVM2Xl3ulRBI,4054
65
- jaxsim-0.4.3.dev31.dist-info/LICENSE,sha256=eaYdFmdeMbiIoIiPzEK0MjP1S9wtFXjXNR5er49uLR0,1546
66
- jaxsim-0.4.3.dev31.dist-info/METADATA,sha256=6dwVfUvgQY8Mgd7wghQuOITOPPqAviKo7rkiB7kq3IE,17247
67
- jaxsim-0.4.3.dev31.dist-info/WHEEL,sha256=cVxcB9AmuTcXqmwrtPhNK88dr7IR_b6qagTj0UvIEbY,91
68
- jaxsim-0.4.3.dev31.dist-info/top_level.txt,sha256=LxGMA8FLtXjQ6oI7N5gd_R_oSUHxpXxUEOfT1xS_ni0,7
69
- jaxsim-0.4.3.dev31.dist-info/RECORD,,
65
+ jaxsim/utils/wrappers.py,sha256=Fh82ZcaFi5fUnByyFLnmumaobsu1hJIvFdopUVzJ1ps,4052
66
+ jaxsim-0.4.3.dev64.dist-info/LICENSE,sha256=eaYdFmdeMbiIoIiPzEK0MjP1S9wtFXjXNR5er49uLR0,1546
67
+ jaxsim-0.4.3.dev64.dist-info/METADATA,sha256=0-JS1eJjFMSaMzwqbCSpWYU2GcrZkxT1LBDo7lhWICo,17276
68
+ jaxsim-0.4.3.dev64.dist-info/WHEEL,sha256=cVxcB9AmuTcXqmwrtPhNK88dr7IR_b6qagTj0UvIEbY,91
69
+ jaxsim-0.4.3.dev64.dist-info/top_level.txt,sha256=LxGMA8FLtXjQ6oI7N5gd_R_oSUHxpXxUEOfT1xS_ni0,7
70
+ jaxsim-0.4.3.dev64.dist-info/RECORD,,