jaxsim 0.4.3.dev17__py3-none-any.whl → 0.4.3.dev18__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
jaxsim/_version.py CHANGED
@@ -12,5 +12,5 @@ __version__: str
12
12
  __version_tuple__: VERSION_TUPLE
13
13
  version_tuple: VERSION_TUPLE
14
14
 
15
- __version__ = version = '0.4.3.dev17'
16
- __version_tuple__ = version_tuple = (0, 4, 3, 'dev17')
15
+ __version__ = version = '0.4.3.dev18'
16
+ __version_tuple__ = version_tuple = (0, 4, 3, 'dev18')
jaxsim/api/contact.py CHANGED
@@ -114,19 +114,24 @@ def collidable_point_forces(
114
114
 
115
115
  @jax.jit
116
116
  def collidable_point_dynamics(
117
- model: js.model.JaxSimModel, data: js.data.JaxSimModelData
118
- ) -> tuple[jtp.Matrix, jtp.Matrix]:
117
+ model: js.model.JaxSimModel,
118
+ data: js.data.JaxSimModelData,
119
+ link_forces: jtp.MatrixLike | None = None,
120
+ ) -> tuple[jtp.Matrix, dict[str, jtp.Array]]:
119
121
  r"""
120
- Compute the 6D force applied to each collidable point and the corresponding
121
- material deformation rate.
122
+ Compute the 6D force applied to each collidable point.
122
123
 
123
124
  Args:
124
125
  model: The model to consider.
125
126
  data: The data of the considered model.
127
+ link_forces:
128
+ The 6D external forces to apply to the links expressed in the same
129
+ representation of data.
126
130
 
127
131
  Returns:
128
- The 6D force applied to each collidable point and the corresponding
129
- material deformation rate.
132
+ The 6D force applied to each collidable point and additional data based on the contact model configured:
133
+ - Soft: the material deformation rate.
134
+ - Rigid: nothing.
130
135
 
131
136
  Note:
132
137
  The material deformation rate is always returned in the mixed frame
@@ -138,7 +143,8 @@ def collidable_point_dynamics(
138
143
  # all collidable points belonging to the robot.
139
144
  W_p_Ci, W_ṗ_Ci = js.contact.collidable_point_kinematics(model=model, data=data)
140
145
 
141
- # Import privately the soft contacts classes.
146
+ # Import privately the contacts classes.
147
+ from jaxsim.rbda.contacts.rigid import RigidContacts, RigidContactsState
142
148
  from jaxsim.rbda.contacts.soft import SoftContacts, SoftContactsState
143
149
 
144
150
  # Build the soft contact model.
@@ -161,6 +167,28 @@ def collidable_point_dynamics(
161
167
  W_f_Ci, (CW_ṁ,) = jax.vmap(soft_contacts.compute_contact_forces)(
162
168
  W_p_Ci, W_ṗ_Ci, data.state.contact.tangential_deformation
163
169
  )
170
+ aux_data = dict(m_dot=CW_ṁ)
171
+
172
+ case RigidContacts():
173
+ assert isinstance(model.contact_model, RigidContacts)
174
+ assert isinstance(data.state.contact, RigidContactsState)
175
+
176
+ # Build the contact model.
177
+ rigid_contacts = RigidContacts(
178
+ parameters=data.contacts_params, terrain=model.terrain
179
+ )
180
+
181
+ # Compute the 6D force expressed in the inertial frame and applied to each
182
+ # collidable point.
183
+ W_f_Ci, _ = rigid_contacts.compute_contact_forces(
184
+ position=W_p_Ci,
185
+ velocity=W_ṗ_Ci,
186
+ model=model,
187
+ data=data,
188
+ link_forces=link_forces,
189
+ )
190
+
191
+ aux_data = dict()
164
192
 
165
193
  case _:
166
194
  raise ValueError(f"Invalid contact model {model.contact_model}")
@@ -175,7 +203,7 @@ def collidable_point_dynamics(
175
203
  )
176
204
  )(W_f_Ci)
177
205
 
178
- return f_Ci, CW_ṁ
206
+ return f_Ci, aux_data
179
207
 
180
208
 
181
209
  @functools.partial(jax.jit, static_argnames=["link_names"])
jaxsim/api/model.py CHANGED
@@ -14,6 +14,7 @@ import rod
14
14
  from jax_dataclasses import Static
15
15
 
16
16
  import jaxsim.api as js
17
+ import jaxsim.exceptions
17
18
  import jaxsim.terrain
18
19
  import jaxsim.typing as jtp
19
20
  from jaxsim.math import Adjoint, Cross
@@ -1890,6 +1891,8 @@ def step(
1890
1891
  and the new state of the integrator.
1891
1892
  """
1892
1893
 
1894
+ from jaxsim.rbda.contacts.rigid import RigidContacts
1895
+
1893
1896
  # Extract the integrator kwargs.
1894
1897
  # The following logic allows using integrators having kwargs colliding with the
1895
1898
  # kwargs of this step function.
@@ -1901,12 +1904,12 @@ def step(
1901
1904
 
1902
1905
  # Extract the initial resources.
1903
1906
  t0_ns = data.time_ns
1904
- state_x0 = data.state
1907
+ state_t0 = data.state
1905
1908
  integrator_state_x0 = integrator_state
1906
1909
 
1907
1910
  # Step the dynamics forward.
1908
- state_xf, integrator_state_xf = integrator.step(
1909
- x0=state_x0,
1911
+ state_tf, integrator_state_tf = integrator.step(
1912
+ x0=state_t0,
1910
1913
  t0=jnp.array(t0_ns / 1e9).astype(float),
1911
1914
  dt=dt,
1912
1915
  params=integrator_state_x0,
@@ -1928,11 +1931,61 @@ def step(
1928
1931
  ),
1929
1932
  )
1930
1933
 
1931
- return (
1934
+ data_tf = (
1932
1935
  # Store the new state of the model and the new time.
1933
1936
  data.replace(
1934
- state=state_xf,
1937
+ state=state_tf,
1935
1938
  time_ns=t0_ns + jnp.array(dt * 1e9).astype(jnp.uint64),
1936
- ),
1937
- integrator_state_xf,
1939
+ )
1940
+ )
1941
+
1942
+ # Post process the simulation state, if needed.
1943
+ match model.contact_model:
1944
+
1945
+ # Rigid contact models use an impact model that produces a discontinuous model velocity.
1946
+ # Hence here we need to reset the velocity after each impact to guarantee that
1947
+ # the linear velocity of the active collidable points is zero.
1948
+ case RigidContacts():
1949
+ # Raise runtime error for not supported case in which Rigid contacts and Baumgarte stabilization
1950
+ # enabled are used with ForwardEuler integrator.
1951
+ jaxsim.exceptions.raise_runtime_error_if(
1952
+ condition=jnp.logical_and(
1953
+ isinstance(
1954
+ integrator,
1955
+ jaxsim.integrators.fixed_step.ForwardEuler
1956
+ | jaxsim.integrators.fixed_step.ForwardEulerSO3,
1957
+ ),
1958
+ jnp.array(
1959
+ [data_tf.contacts_params.K, data_tf.contacts_params.D]
1960
+ ).any(),
1961
+ ),
1962
+ msg="Baumgarte stabilization is not supported with ForwardEuler integrators",
1963
+ )
1964
+
1965
+ with data_tf.switch_velocity_representation(VelRepr.Mixed):
1966
+ W_p_C = js.contact.collidable_point_positions(model, data_tf)
1967
+ M = js.model.free_floating_mass_matrix(model, data_tf)
1968
+ J_WC = js.contact.jacobian(model, data_tf)
1969
+ px, py, _ = W_p_C.T
1970
+ terrain_height = jax.vmap(model.terrain.height)(px, py)
1971
+ inactive_collidable_points, _ = RigidContacts.detect_contacts(
1972
+ W_p_C=W_p_C,
1973
+ terrain_height=terrain_height,
1974
+ )
1975
+ BW_nu_post_impact = RigidContacts.compute_impact_velocity(
1976
+ data=data_tf,
1977
+ inactive_collidable_points=inactive_collidable_points,
1978
+ M=M,
1979
+ J_WC=J_WC,
1980
+ )
1981
+ data_tf = data_tf.reset_base_velocity(BW_nu_post_impact[0:6])
1982
+ data_tf = data_tf.reset_joint_velocities(BW_nu_post_impact[6:])
1983
+ # Restore the input velocity representation.
1984
+ data_tf = data_tf.replace(
1985
+ velocity_representation=data.velocity_representation, validate=False
1986
+ )
1987
+
1988
+ return (
1989
+ data_tf,
1990
+ integrator_state_tf,
1938
1991
  )
jaxsim/api/ode.py CHANGED
@@ -50,7 +50,7 @@ def wrap_system_dynamics_for_integration(
50
50
  # The wrapped dynamics will hold a reference of this object.
51
51
  model_closed = model.copy()
52
52
  data_closed = data.copy().replace(
53
- state=js.ode_data.ODEState.zero(model=model_closed)
53
+ state=js.ode_data.ODEState.zero(model=model_closed, data=data)
54
54
  )
55
55
 
56
56
  def f(x: ODEState, t: Time, **kwargs_f) -> tuple[ODEState, dict[str, Any]]:
@@ -88,7 +88,7 @@ def system_velocity_dynamics(
88
88
  *,
89
89
  joint_forces: jtp.Vector | None = None,
90
90
  link_forces: jtp.Vector | None = None,
91
- ) -> tuple[jtp.Vector, jtp.Vector, jtp.Matrix, dict[str, Any]]:
91
+ ) -> tuple[jtp.Vector, jtp.Vector, dict[str, Any]]:
92
92
  """
93
93
  Compute the dynamics of the system velocity.
94
94
 
@@ -102,18 +102,10 @@ def system_velocity_dynamics(
102
102
 
103
103
  Returns:
104
104
  A tuple containing the derivative of the base 6D velocity in inertial-fixed
105
- representation, the derivative of the joint velocities, the derivative of
106
- the material deformation, and the dictionary of auxiliary data returned by
107
- the system dynamics evaluation.
105
+ representation, the derivative of the joint velocities, and auxiliary data
106
+ returned by the system dynamics evaluation.
108
107
  """
109
108
 
110
- # Build joint torques if not provided.
111
- τ = (
112
- jnp.atleast_1d(joint_forces.squeeze())
113
- if joint_forces is not None
114
- else jnp.zeros_like(data.joint_positions())
115
- ).astype(float)
116
-
117
109
  # Build link forces if not provided.
118
110
  # These forces are expressed in the frame corresponding to the velocity
119
111
  # representation of data.
@@ -123,6 +115,15 @@ def system_velocity_dynamics(
123
115
  else jnp.zeros((model.number_of_links(), 6))
124
116
  ).astype(float)
125
117
 
118
+ # We expect that the 6D forces included in the `link_forces` argument are expressed
119
+ # in the frame corresponding to the velocity representation of `data`.
120
+ references = js.references.JaxSimModelReferences.build(
121
+ model=model,
122
+ link_forces=O_f_L,
123
+ data=data,
124
+ velocity_representation=data.velocity_representation,
125
+ )
126
+
126
127
  # ======================
127
128
  # Compute contact forces
128
129
  # ======================
@@ -131,19 +132,17 @@ def system_velocity_dynamics(
131
132
  # with the terrain.
132
133
  W_f_Li_terrain = jnp.zeros_like(O_f_L).astype(float)
133
134
 
134
- # Import privately the soft contacts classes.
135
- from jaxsim.rbda.contacts.soft import SoftContactsState
136
-
137
- # Initialize the derivative of the tangential deformation ṁ ∈ ℝ^{n_c × 3}.
138
- assert isinstance(data.state.contact, SoftContactsState)
139
- ṁ = jnp.zeros_like(data.state.contact.tangential_deformation).astype(float)
140
-
135
+ aux_data = {}
141
136
  if len(model.kin_dyn_parameters.contact_parameters.body) > 0:
142
137
 
143
138
  # Compute the 6D forces W_f ∈ ℝ^{n_c × 6} applied to each collidable point
144
- # and the corresponding material deformation rates.
139
+ # along with contact-specific auxiliary states.
145
140
  with data.switch_velocity_representation(VelRepr.Inertial):
146
- W_f_Ci, = js.contact.collidable_point_dynamics(model=model, data=data)
141
+ W_f_Ci, aux_data = js.contact.collidable_point_dynamics(
142
+ model=model,
143
+ data=data,
144
+ link_forces=references.link_forces(model=model, data=data),
145
+ )
147
146
 
148
147
  # Construct the vector defining the parent link index of each collidable point.
149
148
  # We use this vector to sum the 6D forces of all collidable points rigidly
@@ -161,6 +160,74 @@ def system_velocity_dynamics(
161
160
 
162
161
  W_f_Li_terrain = mask.T @ W_f_Ci
163
162
 
163
+ # ===========================
164
+ # Compute system acceleration
165
+ # ===========================
166
+
167
+ # Compute the total link forces
168
+ with (
169
+ data.switch_velocity_representation(VelRepr.Inertial),
170
+ references.switch_velocity_representation(VelRepr.Inertial),
171
+ ):
172
+ references = references.apply_link_forces(
173
+ model=model,
174
+ data=data,
175
+ forces=W_f_Li_terrain,
176
+ additive=True,
177
+ )
178
+ # Get the link forces in the data representation
179
+ with references.switch_velocity_representation(data.velocity_representation):
180
+ f_L_total = references.link_forces(model=model, data=data)
181
+
182
+ # The following method always returns the inertial-fixed acceleration, and expects
183
+ # the link_forces expressed in the inertial frame.
184
+ W_v̇_WB, s̈ = system_acceleration(
185
+ model=model, data=data, joint_forces=joint_forces, link_forces=f_L_total
186
+ )
187
+
188
+ return W_v̇_WB, s̈, aux_data
189
+
190
+
191
+ def system_acceleration(
192
+ model: js.model.JaxSimModel,
193
+ data: js.data.JaxSimModelData,
194
+ *,
195
+ joint_forces: jtp.VectorLike | None = None,
196
+ link_forces: jtp.MatrixLike | None = None,
197
+ ) -> tuple[jtp.Vector, jtp.Vector]:
198
+ """
199
+ Compute the system acceleration in inertial-fixed representation.
200
+
201
+ Args:
202
+ model: The model to consider.
203
+ data: The data of the considered model.
204
+ joint_forces: The joint forces to apply.
205
+ link_forces:
206
+ The 6D forces to apply to the links expressed in the same representation of data.
207
+
208
+ Returns:
209
+ A tuple containing the base 6D acceleration in inertial-fixed representation
210
+ and the joint accelerations.
211
+ """
212
+
213
+ # ====================
214
+ # Validate input data
215
+ # ====================
216
+
217
+ # Build link forces if not provided.
218
+ f_L = (
219
+ jnp.atleast_2d(link_forces.squeeze())
220
+ if link_forces is not None
221
+ else jnp.zeros((model.number_of_links(), 6))
222
+ ).astype(float)
223
+
224
+ # Build joint torques if not provided.
225
+ τ = (
226
+ jnp.atleast_1d(joint_forces.squeeze())
227
+ if joint_forces is not None
228
+ else jnp.zeros_like(data.joint_positions())
229
+ ).astype(float)
230
+
164
231
  # ====================
165
232
  # Enforce joint limits
166
233
  # ====================
@@ -198,29 +265,25 @@ def system_velocity_dynamics(
198
265
 
199
266
  references = js.references.JaxSimModelReferences.build(
200
267
  model=model,
201
- joint_force_references=τ_total,
202
- link_forces=O_f_L,
203
268
  data=data,
204
269
  velocity_representation=data.velocity_representation,
270
+ joint_force_references=τ_total,
271
+ link_forces=f_L,
205
272
  )
206
273
 
207
- with references.switch_velocity_representation(VelRepr.Inertial):
208
- W_f_L = references.link_forces(model=model, data=data)
209
-
210
- # Compute the total external 6D forces applied to the links.
211
- W_f_L_total = W_f_L + W_f_Li_terrain
212
-
213
274
  # - Joint accelerations: s̈ ∈ ℝⁿ
214
275
  # - Base inertial-fixed acceleration: W_v̇_WB = (W_p̈_B, W_ω̇_B) ∈ ℝ⁶
215
- with data.switch_velocity_representation(velocity_representation=VelRepr.Inertial):
276
+ with (
277
+ data.switch_velocity_representation(velocity_representation=VelRepr.Inertial),
278
+ references.switch_velocity_representation(VelRepr.Inertial),
279
+ ):
216
280
  W_v̇_WB, s̈ = js.model.forward_dynamics_aba(
217
281
  model=model,
218
282
  data=data,
219
- joint_forces=τ_total,
220
- link_forces=W_f_L_total,
283
+ joint_forces=references.joint_force_references(),
284
+ link_forces=references.link_forces(),
221
285
  )
222
-
223
- return W_v̇_WB, s̈, ṁ, dict()
286
+ return W_v̇_WB, s̈
224
287
 
225
288
 
226
289
  @jax.jit
@@ -291,14 +354,29 @@ def system_dynamics(
291
354
  by the system dynamics evaluation.
292
355
  """
293
356
 
357
+ from jaxsim.rbda.contacts.rigid import RigidContacts
358
+ from jaxsim.rbda.contacts.soft import SoftContacts
359
+
294
360
  # Compute the accelerations and the material deformation rate.
295
- W_v̇_WB, s̈, ṁ, aux_dict = system_velocity_dynamics(
361
+ W_v̇_WB, s̈, aux_dict = system_velocity_dynamics(
296
362
  model=model,
297
363
  data=data,
298
364
  joint_forces=joint_forces,
299
365
  link_forces=link_forces,
300
366
  )
301
367
 
368
+ ode_state_kwargs = {}
369
+
370
+ match model.contact_model:
371
+ case SoftContacts():
372
+ ode_state_kwargs["tangential_deformation"] = aux_dict["m_dot"]
373
+
374
+ case RigidContacts():
375
+ pass
376
+
377
+ case _:
378
+ raise ValueError("Unable to determine contact state class prefix.")
379
+
302
380
  # Extract the velocities.
303
381
  W_ṗ_B, W_Q̇_B, ṡ = system_position_dynamics(
304
382
  model=model,
@@ -317,7 +395,7 @@ def system_dynamics(
317
395
  base_linear_velocity=W_v̇_WB[0:3],
318
396
  base_angular_velocity=W_v̇_WB[3:6],
319
397
  joint_velocities=s̈,
320
- tangential_deformation=ṁ,
398
+ **ode_state_kwargs,
321
399
  )
322
400
 
323
401
  return ode_state_derivative, aux_dict
jaxsim/api/ode_data.py CHANGED
@@ -6,6 +6,7 @@ import jax_dataclasses
6
6
  import jaxsim.api as js
7
7
  import jaxsim.typing as jtp
8
8
  from jaxsim.rbda import ContactsState
9
+ from jaxsim.rbda.contacts.rigid import RigidContacts, RigidContactsState
9
10
  from jaxsim.rbda.contacts.soft import SoftContacts, SoftContactsState
10
11
  from jaxsim.utils import JaxsimDataclass
11
12
 
@@ -133,7 +134,7 @@ class ODEState(JaxsimDataclass):
133
134
  base_quaternion: jtp.Vector | None = None,
134
135
  base_linear_velocity: jtp.Vector | None = None,
135
136
  base_angular_velocity: jtp.Vector | None = None,
136
- tangential_deformation: jtp.Matrix | None = None,
137
+ **kwargs,
137
138
  ) -> ODEState:
138
139
  """
139
140
  Build an `ODEState` from a `JaxSimModel`.
@@ -148,9 +149,7 @@ class ODEState(JaxsimDataclass):
148
149
  The linear velocity of the base link in inertial-fixed representation.
149
150
  base_angular_velocity:
150
151
  The angular velocity of the base link in inertial-fixed representation.
151
- tangential_deformation:
152
- The matrix of 3D tangential material deformations corresponding to
153
- each collidable point.
152
+ kwargs: Additional arguments needed to build the contact state.
154
153
 
155
154
  Returns:
156
155
  The `ODEState` built from the `JaxSimModel`.
@@ -163,6 +162,7 @@ class ODEState(JaxsimDataclass):
163
162
  # Get the contact model from the `JaxSimModel`.
164
163
  match model.contact_model:
165
164
  case SoftContacts():
165
+ tangential_deformation = kwargs.get("tangential_deformation", None)
166
166
  contact = SoftContactsState.build_from_jaxsim_model(
167
167
  model=model,
168
168
  **(
@@ -171,6 +171,8 @@ class ODEState(JaxsimDataclass):
171
171
  else dict()
172
172
  ),
173
173
  )
174
+ case RigidContacts():
175
+ contact = RigidContactsState.build()
174
176
  case _:
175
177
  raise ValueError("Unable to determine contact state class prefix.")
176
178
 
@@ -214,7 +216,7 @@ class ODEState(JaxsimDataclass):
214
216
 
215
217
  # Get the contact model from the `JaxSimModel`.
216
218
  match contact:
217
- case SoftContactsState():
219
+ case SoftContactsState() | RigidContactsState():
218
220
  pass
219
221
  case None:
220
222
  contact = SoftContactsState.zero(model=model)
@@ -224,7 +226,7 @@ class ODEState(JaxsimDataclass):
224
226
  return ODEState(physics_model=physics_model_state, contact=contact)
225
227
 
226
228
  @staticmethod
227
- def zero(model: js.model.JaxSimModel) -> ODEState:
229
+ def zero(model: js.model.JaxSimModel, data: js.data.JaxSimModelData) -> ODEState:
228
230
  """
229
231
  Build a zero `ODEState` from a `JaxSimModel`.
230
232
 
@@ -235,7 +237,9 @@ class ODEState(JaxsimDataclass):
235
237
  A zero `ODEState` instance.
236
238
  """
237
239
 
238
- model_state = ODEState.build(model=model)
240
+ model_state = ODEState.build(
241
+ model=model, contact=data.state.contact.zero(model=model)
242
+ )
239
243
 
240
244
  return model_state
241
245
 
@@ -109,11 +109,14 @@ class Integrator(JaxsimDataclass, abc.ABC, Generic[State, StateDerivative]):
109
109
  integrator.params = params
110
110
 
111
111
  with integrator.mutable_context(mutability=Mutability.MUTABLE):
112
- xf = integrator(x0, t0, dt, **kwargs)
112
+ xf, aux_dict = integrator(x0, t0, dt, **kwargs)
113
113
 
114
- return xf, integrator.params | {
115
- Integrator.AfterInitKey: jnp.array(False).astype(bool)
116
- }
114
+ return (
115
+ xf,
116
+ integrator.params
117
+ | {Integrator.AfterInitKey: jnp.array(False).astype(bool)}
118
+ | aux_dict,
119
+ )
117
120
 
118
121
  @abc.abstractmethod
119
122
  def __call__(self, x0: State, t0: Time, dt: TimeStep, **kwargs) -> NextState:
@@ -277,15 +280,19 @@ class ExplicitRungeKutta(Integrator[PyTreeType, PyTreeType], Generic[PyTreeType]
277
280
 
278
281
  return integrator
279
282
 
280
- def __call__(self, x0: State, t0: Time, dt: TimeStep, **kwargs) -> NextState:
283
+ def __call__(
284
+ self, x0: State, t0: Time, dt: TimeStep, **kwargs
285
+ ) -> tuple[NextState, dict[str, Any]]:
281
286
 
282
287
  # Here z is a batched state with as many batch elements as b.T rows.
283
288
  # Note that z has multiple batches only if b.T has more than one row,
284
289
  # e.g. in Butcher tableau of embedded schemes.
285
- z = self._compute_next_state(x0=x0, t0=t0, dt=dt, **kwargs)
290
+ z, aux_dict = self._compute_next_state(x0=x0, t0=t0, dt=dt, **kwargs)
286
291
 
287
292
  # The next state is the batch element located at the configured index of solution.
288
- return jax.tree_util.tree_map(lambda l: l[self.row_index_of_solution], z)
293
+ next_state = jax.tree_util.tree_map(lambda l: l[self.row_index_of_solution], z)
294
+
295
+ return next_state, aux_dict
289
296
 
290
297
  @classmethod
291
298
  def integrate_rk_stage(
@@ -343,7 +350,7 @@ class ExplicitRungeKutta(Integrator[PyTreeType, PyTreeType], Generic[PyTreeType]
343
350
 
344
351
  def _compute_next_state(
345
352
  self, x0: State, t0: Time, dt: TimeStep, **kwargs
346
- ) -> NextState:
353
+ ) -> tuple[NextState, dict[str, Any]]:
347
354
  """
348
355
  Compute the next state of the system, returning all the output states.
349
356
 
@@ -373,19 +380,21 @@ class ExplicitRungeKutta(Integrator[PyTreeType, PyTreeType], Generic[PyTreeType]
373
380
  )
374
381
 
375
382
  # Apply FSAL property by passing ẋ0 = f(x0, t0) from the previous iteration.
376
- get_ẋ0 = lambda: self.params.get("dxdt0", f(x0, t0)[0])
383
+ get_ẋ0_and_aux_dict = lambda: self.params.get("dxdt0", f(x0, t0))
377
384
 
378
385
  # We use a `jax.lax.scan` to compile the `f` function only once.
379
386
  # Otherwise, if we compute e.g. for RK4 sequentially, the jit-compiled code
380
387
  # would include 4 repetitions of the `f` logic, making everything extremely slow.
381
- def scan_body(carry: jax.Array, i: int | jax.Array) -> tuple[jax.Array, None]:
388
+ def scan_body(
389
+ carry: jax.Array, i: int | jax.Array
390
+ ) -> tuple[jax.Array, dict[str, Any]]:
382
391
  """"""
383
392
 
384
393
  # Unpack the carry, i.e. the stacked kᵢ vectors.
385
394
  K = carry
386
395
 
387
396
  # Define the computation of the Runge-Kutta stage.
388
- def compute_ki() -> jax.Array:
397
+ def compute_ki() -> tuple[jax.Array, dict[str, Any]]:
389
398
 
390
399
  # Compute ∑ⱼ aᵢⱼ kⱼ.
391
400
  op_sum_ak = lambda k: jnp.einsum("s,s...->...", A[i], k)
@@ -398,13 +407,13 @@ class ExplicitRungeKutta(Integrator[PyTreeType, PyTreeType], Generic[PyTreeType]
398
407
  # Compute the next time for the kᵢ evaluation.
399
408
  ti = t0 + c[i] * Δt
400
409
 
401
- # This is k = f(xᵢ, tᵢ).
402
- return f(xi, ti)[0]
410
+ # This is kᵢ, aux_dict = f(xᵢ, tᵢ).
411
+ return f(xi, ti)
403
412
 
404
413
  # This selector enables FSAL property in the first iteration (i=0).
405
- ki = jax.lax.cond(
414
+ ki, aux_dict = jax.lax.cond(
406
415
  pred=jnp.logical_and(i == 0, self.has_fsal),
407
- true_fun=get_ẋ0,
416
+ true_fun=get_ẋ0_and_aux_dict,
408
417
  false_fun=compute_ki,
409
418
  )
410
419
 
@@ -413,10 +422,10 @@ class ExplicitRungeKutta(Integrator[PyTreeType, PyTreeType], Generic[PyTreeType]
413
422
  K = jax.tree_util.tree_map(op, K, ki)
414
423
 
415
424
  carry = K
416
- return carry, None
425
+ return carry, aux_dict
417
426
 
418
427
  # Compute the state derivatives kᵢ.
419
- K, _ = jax.lax.scan(
428
+ K, aux_dict = jax.lax.scan(
420
429
  f=scan_body,
421
430
  init=carry0,
422
431
  xs=jnp.arange(c.size),
@@ -439,7 +448,7 @@ class ExplicitRungeKutta(Integrator[PyTreeType, PyTreeType], Generic[PyTreeType]
439
448
  lambda xf: self.post_process_state(x0=x0, t0=t0, xf=xf, dt=dt)
440
449
  )(z)
441
450
 
442
- return z_transformed
451
+ return z_transformed, aux_dict
443
452
 
444
453
  @staticmethod
445
454
  def butcher_tableau_is_valid(
@@ -5,6 +5,7 @@ from typing import Any
5
5
 
6
6
  import jaxsim.terrain
7
7
  import jaxsim.typing as jtp
8
+ from jaxsim.utils import JaxsimDataclass
8
9
 
9
10
 
10
11
  class ContactsState(abc.ABC):
@@ -42,7 +43,7 @@ class ContactsState(abc.ABC):
42
43
  pass
43
44
 
44
45
 
45
- class ContactsParams(abc.ABC):
46
+ class ContactsParams(JaxsimDataclass):
46
47
  """
47
48
  Abstract class representing the parameters of a contact model.
48
49
  """
@@ -67,7 +68,7 @@ class ContactsParams(abc.ABC):
67
68
  pass
68
69
 
69
70
 
70
- class ContactModel(abc.ABC):
71
+ class ContactModel(JaxsimDataclass):
71
72
  """
72
73
  Abstract class representing a contact model.
73
74
 
@@ -0,0 +1,478 @@
1
+ from __future__ import annotations
2
+
3
+ import dataclasses
4
+ from typing import Any
5
+
6
+ import jax
7
+ import jax.numpy as jnp
8
+ import jax_dataclasses
9
+
10
+ import jaxsim.api as js
11
+ import jaxsim.typing as jtp
12
+ from jaxsim import math
13
+ from jaxsim.api.common import ModelDataWithVelocityRepresentation, VelRepr
14
+ from jaxsim.terrain import FlatTerrain, Terrain
15
+
16
+ from .common import ContactModel, ContactsParams, ContactsState
17
+
18
+
19
+ @jax_dataclasses.pytree_dataclass
20
+ class RigidContactsParams(ContactsParams):
21
+ """Parameters of the rigid contacts model."""
22
+
23
+ # Static friction coefficient
24
+ mu: jtp.Float = dataclasses.field(
25
+ default_factory=lambda: jnp.array(0.5, dtype=float)
26
+ )
27
+
28
+ # Baumgarte proportional term
29
+ K: jtp.Float = dataclasses.field(
30
+ default_factory=lambda: jnp.array(0.0, dtype=float)
31
+ )
32
+
33
+ # Baumgarte derivative term
34
+ D: jtp.Float = dataclasses.field(
35
+ default_factory=lambda: jnp.array(0.0, dtype=float)
36
+ )
37
+
38
+ def __hash__(self) -> int:
39
+ from jaxsim.utils.wrappers import HashedNumpyArray
40
+
41
+ return hash(
42
+ (
43
+ HashedNumpyArray.hash_of_array(self.mu),
44
+ HashedNumpyArray.hash_of_array(self.K),
45
+ HashedNumpyArray.hash_of_array(self.D),
46
+ )
47
+ )
48
+
49
+ def __eq__(self, other: RigidContactsParams) -> bool:
50
+ return hash(self) == hash(other)
51
+
52
+ @classmethod
53
+ def build(
54
+ cls,
55
+ mu: jtp.FloatLike | None = None,
56
+ K: jtp.FloatLike | None = None,
57
+ D: jtp.FloatLike | None = None,
58
+ ) -> RigidContactsParams:
59
+ """Create a `RigidContactParams` instance"""
60
+ return RigidContactsParams(
61
+ mu=mu or cls.__dataclass_fields__["mu"].default,
62
+ K=K or cls.__dataclass_fields__["K"].default,
63
+ D=D or cls.__dataclass_fields__["D"].default,
64
+ )
65
+
66
+ def valid(self) -> bool:
67
+ return bool(
68
+ jnp.all(self.mu >= 0.0)
69
+ and jnp.all(self.K >= 0.0)
70
+ and jnp.all(self.D >= 0.0)
71
+ )
72
+
73
+
74
+ @jax_dataclasses.pytree_dataclass
75
+ class RigidContactsState(ContactsState):
76
+ """Class storing the state of the rigid contacts model."""
77
+
78
+ def __eq__(self, other: RigidContactsState) -> bool:
79
+ return hash(self) == hash(other)
80
+
81
+ @staticmethod
82
+ def build(**kwargs) -> RigidContactsState:
83
+ """Create a `RigidContactsState` instance"""
84
+
85
+ return RigidContactsState()
86
+
87
+ @staticmethod
88
+ def zero(**kwargs) -> RigidContactsState:
89
+ """Build a zero `RigidContactsState` instance from a `JaxSimModel`."""
90
+ return RigidContactsState.build()
91
+
92
+ def valid(self, **kwargs) -> bool:
93
+ return True
94
+
95
+
96
+ @jax_dataclasses.pytree_dataclass
97
+ class RigidContacts(ContactModel):
98
+ """Rigid contacts model."""
99
+
100
+ parameters: RigidContactsParams = dataclasses.field(
101
+ default_factory=RigidContactsParams
102
+ )
103
+
104
+ terrain: jax_dataclasses.Static[Terrain] = dataclasses.field(
105
+ default_factory=FlatTerrain
106
+ )
107
+
108
+ @staticmethod
109
+ def detect_contacts(
110
+ W_p_C: jtp.ArrayLike,
111
+ terrain_height: jtp.ArrayLike,
112
+ ) -> tuple[jtp.Vector, jtp.Vector]:
113
+ """
114
+ Detect contacts between the collidable points and the terrain.
115
+
116
+ Args:
117
+ W_p_C: The position of the collidable points.
118
+ terrain_height: The height of the terrain at the collidable point position.
119
+
120
+ Returns:
121
+ A tuple containing the activation state of the collidable points and the contact penetration depth h.
122
+ """
123
+
124
+ # TODO: reduce code duplication with js.contact.in_contact
125
+ def detect_contact(
126
+ W_p_C: jtp.ArrayLike,
127
+ terrain_height: jtp.FloatLike,
128
+ ) -> tuple[jtp.Bool, jtp.Float]:
129
+ """
130
+ Detect contacts between the collidable points and the terrain.
131
+ """
132
+
133
+ # Unpack the position of the collidable point.
134
+ _, _, pz = W_p_C.squeeze()
135
+
136
+ inactive = pz > terrain_height
137
+
138
+ # Compute contact penetration depth
139
+ h = jnp.maximum(0.0, terrain_height - pz)
140
+
141
+ return inactive, h
142
+
143
+ inactive_collidable_points, h = jax.vmap(detect_contact)(W_p_C, terrain_height)
144
+
145
+ return inactive_collidable_points, h
146
+
147
+ @staticmethod
148
+ def compute_impact_velocity(
149
+ inactive_collidable_points: jtp.ArrayLike,
150
+ M: jtp.MatrixLike,
151
+ J_WC: jtp.MatrixLike,
152
+ data: js.data.JaxSimModelData,
153
+ ) -> jtp.Vector:
154
+ """Returns the new velocity of the system after a potential impact.
155
+
156
+ Args:
157
+ inactive_collidable_points: The activation state of the collidable points.
158
+ M: The mass matrix of the system.
159
+ J_WC: The Jacobian matrix of the collidable points.
160
+ data: The `JaxSimModelData` instance.
161
+ """
162
+
163
+ def impact_velocity(
164
+ inactive_collidable_points: jtp.ArrayLike,
165
+ nu_pre: jtp.ArrayLike,
166
+ M: jtp.MatrixLike,
167
+ J_WC: jtp.MatrixLike,
168
+ data: js.data.JaxSimModelData,
169
+ ):
170
+ # Compute system velocity after impact maintaining zero linear velocity of active points
171
+ with data.switch_velocity_representation(VelRepr.Mixed):
172
+ sl = jnp.s_[:, 0:3, :]
173
+ Jl_WC = J_WC[sl]
174
+ # Zero out the jacobian rows of inactive points
175
+ Jl_WC = jnp.vstack(
176
+ jnp.where(
177
+ inactive_collidable_points[:, jnp.newaxis, jnp.newaxis],
178
+ jnp.zeros_like(Jl_WC),
179
+ Jl_WC,
180
+ )
181
+ )
182
+
183
+ A = jnp.vstack(
184
+ [
185
+ jnp.hstack([M, -Jl_WC.T]),
186
+ jnp.hstack(
187
+ [Jl_WC, jnp.zeros((Jl_WC.shape[0], Jl_WC.shape[0]))]
188
+ ),
189
+ ]
190
+ )
191
+ b = jnp.hstack([M @ nu_pre, jnp.zeros(Jl_WC.shape[0])])
192
+ x = jnp.linalg.lstsq(A, b)[0]
193
+ nu_post = x[0 : M.shape[0]]
194
+
195
+ return nu_post
196
+
197
+ with data.switch_velocity_representation(VelRepr.Mixed):
198
+ BW_ν_pre_impact = data.generalized_velocity()
199
+
200
+ BW_ν_post_impact = impact_velocity(
201
+ data=data,
202
+ inactive_collidable_points=inactive_collidable_points,
203
+ nu_pre=BW_ν_pre_impact,
204
+ M=M,
205
+ J_WC=J_WC,
206
+ )
207
+
208
+ return BW_ν_post_impact
209
+
210
def compute_contact_forces(
    self,
    position: jtp.Vector,
    velocity: jtp.Vector,
    model: js.model.JaxSimModel,
    data: js.data.JaxSimModelData,
    link_forces: jtp.MatrixLike | None = None,
    regularization_term: jtp.FloatLike = 1e-6,
) -> tuple[jtp.Vector, tuple[Any, ...]]:
    """
    Compute the contact forces.

    The linear contact forces are obtained by solving a QP whose Hessian is the
    (regularized) Delassus matrix and whose linear term is the free linear
    acceleration of the collidable points corrected by a Baumgarte
    stabilization term, subject to friction-pyramid, non-negativity and
    inactivity constraints.

    Args:
        position:
            The positions of the collidable points, one row per point.
            Columns 0 and 1 are used to query the terrain height.
        velocity:
            The linear velocities of the collidable points, one row per point.
            Column 2 is used as the rate of the normal penetration quantity.
        model: The `JaxSimModel` instance.
        data: The `JaxSimModelData` instance.
        link_forces:
            Optional `(n_links, 6)` matrix of external forces acting on the links,
            expressed in the same representation of data.
        regularization_term:
            The regularization term to add to the diagonal of the Delassus
            matrix for better numerical conditioning.

    Returns:
        A tuple whose first element is the matrix of 6D forces applied to the
        collidable points (one row per point), expressed in inertial-fixed
        representation, and whose second element is an empty tuple (this
        contact model produces no additional data).
    """

    # Import qpax just in this method
    import qpax

    link_forces = (
        link_forces
        if link_forces is not None
        else jnp.zeros((model.number_of_links(), 6))
    )

    # Compute kin-dyn quantities used in the contact model
    with data.switch_velocity_representation(VelRepr.Mixed):
        M = js.model.free_floating_mass_matrix(model=model, data=data)
        J_WC = js.contact.jacobian(model=model, data=data)
        W_H_C = js.contact.transforms(model=model, data=data)

    terrain_height = jax.vmap(self.terrain.height)(position[:, 0], position[:, 1])
    n_collidable_points = model.kin_dyn_parameters.contact_parameters.point.shape[0]

    # Compute the activation state of the collidable points
    inactive_collidable_points, h = RigidContacts.detect_contacts(
        W_p_C=position,
        terrain_height=terrain_height,
    )

    delassus_matrix = RigidContacts._delassus_matrix(M=M, J_WC=J_WC)

    # Add regularization for better numerical conditioning
    delassus_matrix = delassus_matrix + regularization_term * jnp.eye(
        delassus_matrix.shape[0]
    )

    references = js.references.JaxSimModelReferences.build(
        model=model,
        data=data,
        velocity_representation=data.velocity_representation,
        link_forces=link_forces,
    )

    with references.switch_velocity_representation(VelRepr.Mixed):
        BW_ν̇_free = RigidContacts._compute_mixed_nu_dot_free(
            model, data, references=references
        )

    free_contact_acc = RigidContacts._linear_acceleration_of_collidable_points(
        model,
        data,
        BW_ν̇_free,
    ).flatten()

    # Compute stabilization term.
    # `velocity[:, 2]` is already one-dimensional (one entry per point) and
    # must NOT be squeezed: with a single collidable point, squeezing would
    # produce a 0-d array that the per-point `jax.vmap` inside the Baumgarte
    # computation cannot map over.
    ḣ = velocity[:, 2]
    baumgarte_term = RigidContacts._compute_baumgarte_stabilization_term(
        inactive_collidable_points=inactive_collidable_points,
        h=h,
        ḣ=ḣ,
        K=self.parameters.K,
        D=self.parameters.D,
    ).flatten()

    free_contact_acc -= baumgarte_term

    # Setup optimization problem: min 0.5 f'Qf + q'f  s.t.  G f <= h_bounds
    Q = delassus_matrix
    q = free_contact_acc
    G = RigidContacts._compute_ineq_constraint_matrix(
        inactive_collidable_points=inactive_collidable_points, mu=self.parameters.mu
    )
    h_bounds = RigidContacts._compute_ineq_bounds(
        n_collidable_points=n_collidable_points
    )
    # No equality constraints.
    A = jnp.zeros((0, 3 * n_collidable_points))
    b = jnp.zeros((0,))

    # Solve the optimization problem
    solution, *_ = qpax.solve_qp(Q=Q, q=q, A=A, b=b, G=G, h=h_bounds)

    f_C_lin = solution.reshape(-1, 3)

    # Transform linear contact forces to 6D (zero torque part)
    CW_f_C = jnp.hstack(
        (
            f_C_lin,
            jnp.zeros((f_C_lin.shape[0], 3)),
        )
    )

    # Transform the contact forces to inertial-fixed representation
    W_f_C = jax.vmap(
        lambda CW_f_C, W_H_C: ModelDataWithVelocityRepresentation.other_representation_to_inertial(
            array=CW_f_C,
            transform=W_H_C,
            other_representation=VelRepr.Mixed,
            is_force=True,
        ),
    )(
        CW_f_C,
        W_H_C,
    )

    return W_f_C, ()
337
+
338
@staticmethod
def _delassus_matrix(
    M: jtp.MatrixLike,
    J_WC: jtp.MatrixLike,
) -> jtp.Matrix:
    """
    Compute the Delassus matrix restricted to the linear contact directions.

    Args:
        M: The free-floating mass matrix.
        J_WC:
            The contact Jacobians indexed as `(point, 6D row, column)`; only
            the first three (linear) rows of every point are used.

    Returns:
        The matrix `J_lin @ pinv(M) @ J_lin.T`, where `J_lin` stacks the
        three linear rows of each point, one point after the other.
    """

    # Stack the linear rows of all points: (n_points, 3, n) -> (3 * n_points, n).
    linear_rows = J_WC[:, 0:3, :]
    J_lin = jnp.concatenate(linear_rows, axis=0)

    # Moore-Penrose pseudo-inverse of the mass matrix.
    M_inv = jnp.linalg.pinv(M)

    return J_lin @ M_inv @ J_lin.T
348
+
349
@staticmethod
def _compute_ineq_constraint_matrix(
    inactive_collidable_points: jtp.Vector, mu: jtp.FloatLike
) -> jtp.Matrix:
    """
    Build the block-diagonal inequality matrix `G` of the contact QP.

    Every collidable point contributes a 6x3 block acting on its linear force
    `(fx, fy, fz)`. With a zero right-hand side the rows read:
      - rows 0-3: the four faces of the friction pyramid
        (`±fx - mu*fz <= 0`, `±fy - mu*fz <= 0`);
      - row 4: non-negativity of the normal force (`-fz <= 0`);
      - row 5: `c*fz <= 0`, with `c` the point's inactivity flag, which
        together with row 4 forces `fz = 0` on inactive points.

    Args:
        inactive_collidable_points: One inactivity flag per collidable point.
        mu: The friction coefficient.

    Returns:
        The block-diagonal matrix with one 6x3 block per collidable point.
    """

    def point_block(is_inactive) -> jtp.Matrix:
        # Same literal layout for every point; only the last row depends on it.
        return jnp.array(
            [
                [1, 0, -mu],
                [0, 1, -mu],
                [-1, 0, -mu],
                [0, -1, -mu],
                [0, 0, -1],
                [0, 0, is_inactive],
            ]
        )

    per_point_blocks = jax.vmap(point_block)(inactive_collidable_points)
    return jax.scipy.linalg.block_diag(*per_point_blocks)
377
+
378
@staticmethod
def _compute_ineq_bounds(n_collidable_points: jtp.FloatLike) -> jtp.Vector:
    """
    Return the right-hand side `h` of the QP inequalities `G f <= h`.

    There are six inequality rows per collidable point, matching the 6x3
    per-point blocks built by `_compute_ineq_constraint_matrix`, and every
    bound is zero.

    Args:
        n_collidable_points: The number of collidable points.

    Returns:
        A zero vector of length `6 * n_collidable_points`.
    """

    rows_per_point = 6
    return jnp.zeros(rows_per_point * n_collidable_points)
382
+
383
@staticmethod
def _compute_mixed_nu_dot_free(
    model: js.model.JaxSimModel,
    data: js.data.JaxSimModelData,
    references: js.references.JaxSimModelReferences | None = None,
) -> jtp.Array:
    """
    Compute the generalized acceleration in mixed representation.

    The acceleration is obtained from `js.ode.system_acceleration` driven only
    by the joint forces and link forces stored in `references`. The caller
    uses it as the "free" acceleration of the rigid contact QP.
    NOTE(review): whether `system_acceleration` also includes contact forces
    cannot be told from here; confirm.

    Args:
        model: The model to consider.
        data: The data of the considered model.
        references:
            The references providing joint and link forces. When `None`,
            zero references are used.

    Returns:
        The stacked vector `[BW_v̇_WB, s̈]`: the mixed base acceleration
        followed by the joint accelerations.
    """

    if references is None:
        references = js.references.JaxSimModelReferences.zero(model=model, data=data)

    with (
        data.switch_velocity_representation(VelRepr.Mixed),
        references.switch_velocity_representation(VelRepr.Mixed),
    ):
        # In mixed representation the base velocity is [W_ṗ_B; W_ω_WB].
        W_ṗ_B, W_ω_WB = jnp.split(data.base_velocity(), 2)

        # The conversion below assumes the base acceleration returned here is
        # inertial-fixed (as the `W_` prefix states) — TODO confirm.
        W_v̇_WB, s̈ = js.ode.system_acceleration(
            model=model,
            data=data,
            joint_forces=references.joint_force_references(model=model),
            link_forces=references.link_forces(model=model, data=data),
        )

    # Frame BW: origin at the base, orientation of the world frame.
    W_H_BW = data.base_transform().at[0:3, 0:3].set(jnp.eye(3))
    BW_X_W = math.Adjoint.from_transform(W_H_BW, inverse=True)

    # BW_v̇ = BW_X_W @ W_v̇ - [W_ṗ_B × W_ω_WB; 0]
    adjoint_rate_term = jnp.zeros(6).at[0:3].set(jnp.cross(W_ṗ_B, W_ω_WB))
    BW_v̇_WB = BW_X_W @ W_v̇_WB - adjoint_rate_term

    return jnp.hstack([BW_v̇_WB, s̈])
419
+
420
@staticmethod
def _linear_acceleration_of_collidable_points(
    model: js.model.JaxSimModel,
    data: js.data.JaxSimModelData,
    mixed_nu_dot: jtp.ArrayLike,
) -> jtp.Matrix:
    """
    Compute the linear acceleration of every collidable point.

    Evaluates `a = J̇ ν + J ν̇` using mixed-representation contact Jacobians,
    their derivatives, and the mixed generalized velocity.

    Args:
        model: The model to consider.
        data: The data of the considered model.
        mixed_nu_dot: The generalized acceleration in mixed representation.

    Returns:
        The linear part (first three components) of each point's 6D
        acceleration, one row per point. Because of the final `squeeze`, a
        single collidable point yields a vector of shape `(3,)`.
    """

    with data.switch_velocity_representation(VelRepr.Mixed):
        J_stack = jnp.vstack(
            js.contact.jacobian(
                model=model,
                data=data,
                output_vel_repr=VelRepr.Mixed,
            )
        )
        J̇_stack = jnp.vstack(
            js.contact.jacobian_derivative(
                model=model,
                data=data,
                output_vel_repr=VelRepr.Mixed,
            )
        )
        BW_ν = data.generalized_velocity()

    # One 6D acceleration per collidable point.
    accelerations_6d = (J̇_stack @ BW_ν + J_stack @ mixed_nu_dot).reshape(-1, 6)

    return accelerations_6d[:, 0:3].squeeze()
445
+
446
@staticmethod
def _compute_baumgarte_stabilization_term(
    inactive_collidable_points: jtp.ArrayLike,
    h: jtp.ArrayLike,
    ḣ: jtp.ArrayLike,
    K: jtp.FloatLike,
    D: jtp.FloatLike,
) -> jtp.Array:
    """
    Compute the per-point Baumgarte stabilization term.

    For each collidable point the term is a 3D vector whose only non-zero
    entry is the third one, equal to `K * h + D * ḣ` for active points and
    zero for inactive points.

    Args:
        inactive_collidable_points: One inactivity flag per point.
        h: One value per point, combined with the stiffness-like gain `K`.
        ḣ: One value per point, combined with the damping-like gain `D`.
        K: Gain multiplying `h`.
        D: Gain multiplying `ḣ`.

    Returns:
        A matrix with one 3D stabilization vector per point.
    """

    def single_point(is_inactive, h_i, ḣ_i) -> jtp.Array:
        normal_component = jnp.where(is_inactive, 0.0, K * h_i + D * ḣ_i)
        return jnp.zeros(shape=(3,)).at[2].set(normal_component)

    return jax.vmap(single_point)(inactive_collidable_points, h, ḣ)
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: jaxsim
3
- Version: 0.4.3.dev17
3
+ Version: 0.4.3.dev18
4
4
  Summary: A differentiable physics engine and multibody dynamics library for control and robot learning.
5
5
  Author-email: Diego Ferigo <dgferigo@gmail.com>
6
6
  Maintainer-email: Diego Ferigo <dgferigo@gmail.com>, Filippo Luca Ferretti <filippo.ferretti@iit.it>
@@ -65,6 +65,7 @@ Requires-Dist: jaxlib>=0.4.13
65
65
  Requires-Dist: jaxlie>=1.3.0
66
66
  Requires-Dist: jax-dataclasses>=1.4.0
67
67
  Requires-Dist: pptree
68
+ Requires-Dist: qpax
68
69
  Requires-Dist: rod>=0.3.0
69
70
  Requires-Dist: typing-extensions; python_version < "3.12"
70
71
  Provides-Extra: all
@@ -1,23 +1,23 @@
1
1
  jaxsim/__init__.py,sha256=ixsS4dYMPex2wOUUp_rkPnwrPhYzkRh1xO_YuMj3Cr4,2626
2
- jaxsim/_version.py,sha256=fl1A3KE70N9IF6eJ3t04zIPwnzaTgzXV-7IhkEmXc24,426
2
+ jaxsim/_version.py,sha256=SFJGfO84uy3oOc6jDmWWev6VuJIqHL3tI8_OvaYfdsA,426
3
3
  jaxsim/exceptions.py,sha256=8_h8iqL8DgNR754dR8SZiQ7361GR5V1sUk3ZuZCHw1Q,2069
4
4
  jaxsim/logging.py,sha256=STI-D_upXZYX-ZezLrlJJ0UlD5YspST0vZ_DcIwkzO4,1553
5
5
  jaxsim/typing.py,sha256=IbFx3UkEXi-cm7UBqMPi58rJAFV_HbZ9E_K4JwfNvVM,753
6
6
  jaxsim/api/__init__.py,sha256=8eV22t2S3UwNyCg8karPetG1dmX1VDBXkyv28_FwNQA,210
7
7
  jaxsim/api/com.py,sha256=m-p3EJDhpnMTlXKplfbZE_aH9NqX_VyLlAE3vUhc6l4,13642
8
8
  jaxsim/api/common.py,sha256=SNgxq42r6eF_-aPszvOjUYkGwXOzz4hKmhDwEUkscFQ,6650
9
- jaxsim/api/contact.py,sha256=Ff3rOYm-o4Atxipzd4VvsnqFVg1-lfV1zFPXJU3MZKM,19847
9
+ jaxsim/api/contact.py,sha256=HyEAjF7BySDDOlRahN0l7V15IPB0HPXuoM0twamuEW0,20913
10
10
  jaxsim/api/data.py,sha256=CUh9lvhVk3_clNQ26BUBGpjvFSsK_PrVWVMEWpMdHRM,27206
11
11
  jaxsim/api/frame.py,sha256=KS8A5wRfjxhe9NgcVo2QA516iP5zky7UVnWxG7nTa7c,12911
12
12
  jaxsim/api/joint.py,sha256=L81bQe-noPT6_54KOSF7KBjRmEPAS433ULn2EcXI8vI,5115
13
13
  jaxsim/api/kin_dyn_parameters.py,sha256=CcfSg5Mc8qb1mZeMQ4AK_ffZIsK5yOl7tu397pFhcDA,29369
14
14
  jaxsim/api/link.py,sha256=qPRtc8qqMRjZxUCZYXJMygbB6huDXBfIT1b1b8Durkw,18631
15
- jaxsim/api/model.py,sha256=gU44Cvn5ceYr37jsgWC6rMiRYT7rI0D7I1jf7v-yLLw,63121
16
- jaxsim/api/ode.py,sha256=NnLTBvpaT4kXnbjAghXIzLv9DTMJ8bele2iOlUQDv3Q,11028
17
- jaxsim/api/ode_data.py,sha256=9YZX-SK_KJtoIqG-zYWZsQInb2NA_LtxDn-jtLqm_3U,19759
15
+ jaxsim/api/model.py,sha256=HXoqCtQ3KStGoxhgvFm8P_Sc-lbEM4l5No2MoHzNlOk,65558
16
+ jaxsim/api/ode.py,sha256=Vb2sN4zwpXnaJDD9-ziz2qvfmfa4jvIQ0fONbBIRGmU,13368
17
+ jaxsim/api/ode_data.py,sha256=U7F6TL6bENAxpQQl4PupPoDG7d7VfTTFqDAs3xwu6Hs,20003
18
18
  jaxsim/api/references.py,sha256=XOVKuQXRmjPoP-T5JWGSbqIGX5DzOkeGafqRpj0ZQEM,20771
19
19
  jaxsim/integrators/__init__.py,sha256=hxvOD-VK_mmd6v31wtC-nb28AYve1gLuZCNLV9wS-Kg,103
20
- jaxsim/integrators/common.py,sha256=Hg_HaFDrlDjG1zmtSTRPuw0ds8wcaLkt5PBFyKhkZWg,20143
20
+ jaxsim/integrators/common.py,sha256=ntjflaV3qWaFH_E65pAGZ6QipdnFsgQDasKtIKpxTe4,20432
21
21
  jaxsim/integrators/fixed_step.py,sha256=KpjRd6hHtapxDoo6D1kyDrVDSHnke2TepI5grFH7_bM,2693
22
22
  jaxsim/integrators/variable_step.py,sha256=5StkFh9oQba34zlkIoXG2fUN78gbxkHePWbrpQ-QZOI,21274
23
23
  jaxsim/math/__init__.py,sha256=8oPITEoGwgRcOeG8KxtqxPQ8b5uku1HNRMokpCoi9Tc,352
@@ -53,7 +53,8 @@ jaxsim/rbda/jacobian.py,sha256=p0EV_8cLzLVV-93VKznT7VPuRj8W7h7rQWkPlWJXfCA,11023
53
53
  jaxsim/rbda/rnea.py,sha256=CLfqs9XFVaD-hvkLABshDAfdw5bm_AMV3UVAQ_IvURQ,7542
54
54
  jaxsim/rbda/utils.py,sha256=eeT21Y4DiiyhrdF0lUE_VvRuwru5-rR7yOlOlWzCCWE,5381
55
55
  jaxsim/rbda/contacts/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
56
- jaxsim/rbda/contacts/common.py,sha256=iMKLP30Qft9eGTiHo2iY-UoACJjg1JphA9_pW8wRdjc,2410
56
+ jaxsim/rbda/contacts/common.py,sha256=VwAs742futAmLnDgbaOuLzNDBFiKDfYItdEZ4UcFgzE,2467
57
+ jaxsim/rbda/contacts/rigid.py,sha256=8Vbnxng-ERZ5ka_eZGIBuhBDr2PNjc7m-Or255AfEw4,15862
57
58
  jaxsim/rbda/contacts/soft.py,sha256=_wvb5iZDjGcVg6rNQelN4LZN7qSC2NIp0HdKvZmlGfk,15647
58
59
  jaxsim/terrain/__init__.py,sha256=f7lVX-iNpH_wkkjef9Qpjh19TTAUOQw76EiLYJDVizc,78
59
60
  jaxsim/terrain/terrain.py,sha256=ctyNANIFSM3tZmamprjaEDcWgUSP0oNJbmT1zw9RjPs,4565
@@ -61,8 +62,8 @@ jaxsim/utils/__init__.py,sha256=Y5zyoRevl3EMVQadhZ4EtSwTEkDt2vcnFoRhPJjKTZ0,215
61
62
  jaxsim/utils/jaxsim_dataclass.py,sha256=5xJbY0G8d7C0OTNIW9T4vQxiDak6TGZT9gpNOvRykFI,11373
62
63
  jaxsim/utils/tracing.py,sha256=KDMoyVPlu2NJvFkhtZwq5AkqMMgajt3munvJom-vEjQ,650
63
64
  jaxsim/utils/wrappers.py,sha256=JhLUh1g8iU-lhjbuZRfkscPZhYlLCOorVM2Xl3ulRBI,4054
64
- jaxsim-0.4.3.dev17.dist-info/LICENSE,sha256=eaYdFmdeMbiIoIiPzEK0MjP1S9wtFXjXNR5er49uLR0,1546
65
- jaxsim-0.4.3.dev17.dist-info/METADATA,sha256=st7RonsEdp_bRhPTTjJRZVj4GISB2k_gL4T7v3tBqIw,17227
66
- jaxsim-0.4.3.dev17.dist-info/WHEEL,sha256=cVxcB9AmuTcXqmwrtPhNK88dr7IR_b6qagTj0UvIEbY,91
67
- jaxsim-0.4.3.dev17.dist-info/top_level.txt,sha256=LxGMA8FLtXjQ6oI7N5gd_R_oSUHxpXxUEOfT1xS_ni0,7
68
- jaxsim-0.4.3.dev17.dist-info/RECORD,,
65
+ jaxsim-0.4.3.dev18.dist-info/LICENSE,sha256=eaYdFmdeMbiIoIiPzEK0MjP1S9wtFXjXNR5er49uLR0,1546
66
+ jaxsim-0.4.3.dev18.dist-info/METADATA,sha256=aLpRkfa9CC7GVzXMKX3LY5DkCHEmOr4CE-u3Vbt5fx8,17247
67
+ jaxsim-0.4.3.dev18.dist-info/WHEEL,sha256=cVxcB9AmuTcXqmwrtPhNK88dr7IR_b6qagTj0UvIEbY,91
68
+ jaxsim-0.4.3.dev18.dist-info/top_level.txt,sha256=LxGMA8FLtXjQ6oI7N5gd_R_oSUHxpXxUEOfT1xS_ni0,7
69
+ jaxsim-0.4.3.dev18.dist-info/RECORD,,