jaxsim 0.4.3.dev139__py3-none-any.whl → 0.4.3.dev155__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
jaxsim/_version.py CHANGED
@@ -12,5 +12,5 @@ __version__: str
12
12
  __version_tuple__: VERSION_TUPLE
13
13
  version_tuple: VERSION_TUPLE
14
14
 
15
- __version__ = version = '0.4.3.dev139'
16
- __version_tuple__ = version_tuple = (0, 4, 3, 'dev139')
15
+ __version__ = version = '0.4.3.dev155'
16
+ __version_tuple__ = version_tuple = (0, 4, 3, 'dev155')
jaxsim/api/contact.py CHANGED
@@ -144,7 +144,8 @@ def collidable_point_dynamics(
144
144
  The joint force references to apply to the joints.
145
145
 
146
146
  Returns:
147
- The 6D force applied to each collidable point and additional data based on the contact model configured:
147
+ The 6D force applied to each collidable point and additional data based
148
+ on the contact model configured:
148
149
  - Soft: the material deformation rate.
149
150
  - Rigid: no additional data.
150
151
  - QuasiRigid: no additional data.
@@ -156,21 +157,13 @@ def collidable_point_dynamics(
156
157
  """
157
158
 
158
159
  # Import privately the contacts classes.
159
- from jaxsim.rbda.contacts import (
160
- RelaxedRigidContacts,
161
- RelaxedRigidContactsState,
162
- RigidContacts,
163
- RigidContactsState,
164
- SoftContacts,
165
- SoftContactsState,
166
- )
160
+ from jaxsim.rbda.contacts import RelaxedRigidContacts, RigidContacts, SoftContacts
167
161
 
168
162
  # Build the soft contact model.
169
163
  match model.contact_model:
170
164
 
171
165
  case SoftContacts():
172
166
  assert isinstance(model.contact_model, SoftContacts)
173
- assert isinstance(data.state.contact, SoftContactsState)
174
167
 
175
168
  # Compute the 6D force expressed in the inertial frame and applied to each
176
169
  # collidable point, and the corresponding material deformation rate.
@@ -187,7 +180,6 @@ def collidable_point_dynamics(
187
180
 
188
181
  case RigidContacts():
189
182
  assert isinstance(model.contact_model, RigidContacts)
190
- assert isinstance(data.state.contact, RigidContactsState)
191
183
 
192
184
  # Compute the 6D force expressed in the inertial frame and applied to each
193
185
  # collidable point.
@@ -203,7 +195,6 @@ def collidable_point_dynamics(
203
195
 
204
196
  case RelaxedRigidContacts():
205
197
  assert isinstance(model.contact_model, RelaxedRigidContacts)
206
- assert isinstance(data.state.contact, RelaxedRigidContactsState)
207
198
 
208
199
  # Compute the 6D force expressed in the inertial frame and applied to each
209
200
  # collidable point.
jaxsim/api/data.py CHANGED
@@ -13,7 +13,6 @@ import jaxsim.api as js
13
13
  import jaxsim.math
14
14
  import jaxsim.rbda
15
15
  import jaxsim.typing as jtp
16
- from jaxsim.rbda.contacts import SoftContacts
17
16
  from jaxsim.utils import Mutability
18
17
  from jaxsim.utils.tracing import not_tracing
19
18
 
@@ -107,17 +106,17 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
107
106
  @staticmethod
108
107
  def build(
109
108
  model: js.model.JaxSimModel,
110
- base_position: jtp.Vector | None = None,
111
- base_quaternion: jtp.Vector | None = None,
112
- joint_positions: jtp.Vector | None = None,
113
- base_linear_velocity: jtp.Vector | None = None,
114
- base_angular_velocity: jtp.Vector | None = None,
115
- joint_velocities: jtp.Vector | None = None,
109
+ base_position: jtp.VectorLike | None = None,
110
+ base_quaternion: jtp.VectorLike | None = None,
111
+ joint_positions: jtp.VectorLike | None = None,
112
+ base_linear_velocity: jtp.VectorLike | None = None,
113
+ base_angular_velocity: jtp.VectorLike | None = None,
114
+ joint_velocities: jtp.VectorLike | None = None,
116
115
  standard_gravity: jtp.FloatLike = jaxsim.math.StandardGravity,
117
- contact: jaxsim.rbda.contacts.ContactsState | None = None,
118
116
  contacts_params: jaxsim.rbda.contacts.ContactsParams | None = None,
119
117
  velocity_representation: VelRepr = VelRepr.Inertial,
120
118
  time: jtp.FloatLike | None = None,
119
+ extended_ode_state: dict[str, jtp.PyTree] | None = None,
121
120
  ) -> JaxSimModelData:
122
121
  """
123
122
  Create a `JaxSimModelData` object with the given state.
@@ -133,56 +132,73 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
133
132
  The base angular velocity in the selected representation.
134
133
  joint_velocities: The joint velocities.
135
134
  standard_gravity: The standard gravity constant.
136
- contact: The state of the soft contacts.
137
135
  contacts_params: The parameters of the soft contacts.
138
136
  velocity_representation: The velocity representation to use.
139
137
  time: The time at which the state is created.
138
+ extended_ode_state:
139
+ Additional user-defined state variables that are not part of the
140
+ standard `ODEState` object. Useful to extend the system dynamics
141
+ considered by default in JaxSim.
140
142
 
141
143
  Returns:
142
- A `JaxSimModelData` object with the given state.
144
+ A `JaxSimModelData` initialized with the given state.
143
145
  """
144
146
 
145
147
  base_position = jnp.array(
146
- base_position if base_position is not None else jnp.zeros(3)
148
+ base_position if base_position is not None else jnp.zeros(3),
149
+ dtype=float,
147
150
  ).squeeze()
148
151
 
149
152
  base_quaternion = jnp.array(
150
- base_quaternion
151
- if base_quaternion is not None
152
- else jnp.array([1.0, 0, 0, 0])
153
+ (
154
+ base_quaternion
155
+ if base_quaternion is not None
156
+ else jnp.array([1.0, 0, 0, 0])
157
+ ),
158
+ dtype=float,
153
159
  ).squeeze()
154
160
 
155
161
  base_linear_velocity = jnp.array(
156
- base_linear_velocity if base_linear_velocity is not None else jnp.zeros(3)
162
+ base_linear_velocity if base_linear_velocity is not None else jnp.zeros(3),
163
+ dtype=float,
157
164
  ).squeeze()
158
165
 
159
166
  base_angular_velocity = jnp.array(
160
- base_angular_velocity if base_angular_velocity is not None else jnp.zeros(3)
167
+ (
168
+ base_angular_velocity
169
+ if base_angular_velocity is not None
170
+ else jnp.zeros(3)
171
+ ),
172
+ dtype=float,
161
173
  ).squeeze()
162
174
 
163
175
  gravity = jnp.zeros(3).at[2].set(-standard_gravity)
164
176
 
165
177
  joint_positions = jnp.atleast_1d(
166
- joint_positions.squeeze()
167
- if joint_positions is not None
168
- else jnp.zeros(model.dofs())
178
+ jnp.array(
179
+ (
180
+ joint_positions
181
+ if joint_positions is not None
182
+ else jnp.zeros(model.dofs())
183
+ ),
184
+ dtype=float,
185
+ ).squeeze()
169
186
  )
170
187
 
171
188
  joint_velocities = jnp.atleast_1d(
172
- joint_velocities.squeeze()
173
- if joint_velocities is not None
174
- else jnp.zeros(model.dofs())
189
+ jnp.array(
190
+ (
191
+ joint_velocities
192
+ if joint_velocities is not None
193
+ else jnp.zeros(model.dofs())
194
+ ),
195
+ dtype=float,
196
+ ).squeeze()
175
197
  )
176
198
 
177
- time_ns = (
178
- jnp.array(
179
- time * 1e9,
180
- dtype=jnp.uint64 if jax.config.read("jax_enable_x64") else jnp.uint32,
181
- )
182
- if time is not None
183
- else jnp.array(
184
- 0, dtype=jnp.uint64 if jax.config.read("jax_enable_x64") else jnp.uint32
185
- )
199
+ time_ns = jnp.array(
200
+ time * 1e9 if time is not None else 0.0,
201
+ dtype=jnp.uint64 if jax.config.read("jax_enable_x64") else jnp.uint32,
186
202
  )
187
203
 
188
204
  W_H_B = jaxsim.math.Transform.from_quaternion_and_translation(
@@ -194,21 +210,22 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
194
210
  other_representation=velocity_representation,
195
211
  transform=W_H_B,
196
212
  is_force=False,
197
- )
213
+ ).astype(float)
198
214
 
199
215
  ode_state = ODEState.build_from_jaxsim_model(
200
216
  model=model,
201
- base_position=base_position.astype(float),
202
- base_quaternion=base_quaternion.astype(float),
203
- joint_positions=joint_positions.astype(float),
204
- base_linear_velocity=v_WB[0:3].astype(float),
205
- base_angular_velocity=v_WB[3:6].astype(float),
206
- joint_velocities=joint_velocities.astype(float),
207
- tangential_deformation=(
208
- contact.tangential_deformation
209
- if contact is not None and isinstance(model.contact_model, SoftContacts)
210
- else None
211
- ),
217
+ base_position=base_position,
218
+ base_quaternion=base_quaternion,
219
+ joint_positions=joint_positions,
220
+ base_linear_velocity=v_WB[0:3],
221
+ base_angular_velocity=v_WB[3:6],
222
+ joint_velocities=joint_velocities,
223
+ # Unpack all the additional ODE states. If the contact model requires an
224
+ # additional state that is not explicitly passed to this builder, ODEState
225
+ # automatically populates that state with zeroed variables.
226
+ # This is not true for any other custom state that the user might want to
227
+ # pass to the integrator.
228
+ **(extended_ode_state if extended_ode_state else {}),
212
229
  )
213
230
 
214
231
  if not ode_state.valid(model=model):
@@ -220,13 +237,14 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
220
237
  contacts_params = js.contact.estimate_good_soft_contacts_parameters(
221
238
  model=model, standard_gravity=standard_gravity
222
239
  )
240
+
223
241
  else:
224
242
  contacts_params = model.contact_model.parameters
225
243
 
226
244
  return JaxSimModelData(
227
245
  time_ns=time_ns,
228
246
  state=ode_state,
229
- gravity=gravity.astype(float),
247
+ gravity=gravity,
230
248
  contacts_params=contacts_params,
231
249
  velocity_representation=velocity_representation,
232
250
  )
jaxsim/api/model.py CHANGED
@@ -33,7 +33,7 @@ class JaxSimModel(JaxsimDataclass):
33
33
  model_name: Static[str]
34
34
 
35
35
  terrain: Static[jaxsim.terrain.Terrain] = dataclasses.field(
36
- default=jaxsim.terrain.FlatTerrain(), repr=False
36
+ default_factory=jaxsim.terrain.FlatTerrain.build, repr=False
37
37
  )
38
38
 
39
39
  contact_model: jaxsim.rbda.contacts.ContactModel | None = dataclasses.field(
@@ -101,13 +101,14 @@ class JaxSimModel(JaxsimDataclass):
101
101
  A path to an SDF/URDF file, a string containing
102
102
  its content, or a pre-parsed/pre-built rod model.
103
103
  model_name:
104
- The optional name of the model that overrides the one in
105
- the description.
106
- terrain:
107
- The optional terrain to consider.
104
+ The name of the model. If not specified, it is read from the description.
105
+ terrain: The terrain to consider (the default is a flat infinite plane).
106
+ contact_model:
107
+ The contact model to consider.
108
+ If not specified, a soft contacts model is used.
108
109
  is_urdf:
109
- The optional flag to force the model description to be parsed as a
110
- URDF or a SDF. This is otherwise automatically inferred.
110
+ The optional flag to force the model description to be parsed as a URDF.
111
+ This is usually automatically inferred.
111
112
  considered_joints:
112
113
  The list of joints to consider. If None, all joints are considered.
113
114
 
@@ -120,7 +121,7 @@ class JaxSimModel(JaxsimDataclass):
120
121
  # Parse the input resource (either a path to file or a string with the URDF/SDF)
121
122
  # and build the -intermediate- model description.
122
123
  intermediate_description = jaxsim.parsers.rod.build_model_description(
123
- model_description=model_description
124
+ model_description=model_description, is_urdf=is_urdf
124
125
  )
125
126
 
126
127
  # Lump links together if not all joints are considered.
@@ -160,11 +161,11 @@ class JaxSimModel(JaxsimDataclass):
160
161
  The intermediate model description defining the kinematics and dynamics
161
162
  of the model.
162
163
  model_name:
163
- The optional name of the model overriding the physics model name.
164
- terrain:
165
- The optional terrain to consider.
164
+ The name of the model. If not specified, it is read from the description.
165
+ terrain: The terrain to consider (the default is a flat infinite plane).
166
166
  contact_model:
167
- The optional contact model to consider. If None, the soft contact model is used.
167
+ The contact model to consider.
168
+ If not specified, a soft contacts model is used.
168
169
 
169
170
  Returns:
170
171
  The built Model object.
@@ -173,21 +174,31 @@ class JaxSimModel(JaxsimDataclass):
173
174
  # Set the model name (if not provided, use the one from the model description).
174
175
  model_name = model_name if model_name is not None else model_description.name
175
176
 
176
- # Set the terrain (if not provided, use the default flat terrain).
177
- terrain = terrain or JaxSimModel.__dataclass_fields__["terrain"].default
178
- contact_model = contact_model or jaxsim.rbda.contacts.SoftContacts(
179
- terrain=terrain
177
+ # Consider the default terrain (a flat infinite plane) if not specified.
178
+ terrain = (
179
+ terrain or JaxSimModel.__dataclass_fields__["terrain"].default_factory()
180
+ )
181
+
182
+ # Create the default contact model.
183
+ # It will be populated with an initial estimation of good parameters.
184
+ # While these might not be the best, they are a good starting point.
185
+ contact_model = contact_model or jaxsim.rbda.contacts.SoftContacts.build(
186
+ terrain=terrain, parameters=None
180
187
  )
181
188
 
182
189
  # Build the model.
183
190
  model = JaxSimModel(
184
191
  model_name=model_name,
185
- _description=wrappers.HashlessObject(obj=model_description),
186
192
  kin_dyn_parameters=js.kin_dyn_parameters.KynDynParameters.build(
187
193
  model_description=model_description
188
194
  ),
189
195
  terrain=terrain,
190
196
  contact_model=contact_model,
197
+ # The following is wrapped as hashless since it's a static argument, and we
198
+ # don't want to trigger recompilation if it changes. All relevant parameters
199
+ # needed to compute kinematics and dynamics quantities are stored in the
200
+ # kin_dyn_parameters attribute.
201
+ _description=wrappers.HashlessObject(obj=model_description),
191
202
  )
192
203
 
193
204
  return model
@@ -1907,8 +1918,8 @@ def step(
1907
1918
  dt: jtp.FloatLike,
1908
1919
  integrator: jaxsim.integrators.Integrator,
1909
1920
  integrator_state: dict[str, Any] | None = None,
1910
- joint_forces: jtp.VectorLike | None = None,
1911
1921
  link_forces: jtp.MatrixLike | None = None,
1922
+ joint_force_references: jtp.VectorLike | None = None,
1912
1923
  **kwargs,
1913
1924
  ) -> tuple[js.data.JaxSimModelData, dict[str, Any]]:
1914
1925
  """
@@ -1920,10 +1931,10 @@ def step(
1920
1931
  dt: The time step to consider.
1921
1932
  integrator: The integrator to use.
1922
1933
  integrator_state: The state of the integrator.
1923
- joint_forces: The joint forces to consider.
1924
1934
  link_forces:
1925
1935
  The 6D forces to apply to the links expressed in the frame corresponding to
1926
1936
  the velocity representation of `data`.
1937
+ joint_force_references: The joint force references to consider.
1927
1938
  kwargs: Additional kwargs to pass to the integrator.
1928
1939
 
1929
1940
  Returns:
@@ -1953,7 +1964,7 @@ def step(
1953
1964
  params=integrator_state_x0,
1954
1965
  # Always inject the current (model, data) pair into the system dynamics
1955
1966
  # considered by the integrator, and include the input variables represented
1956
- # by the pair (joint_forces, link_forces).
1967
+ # by the pair (joint_force_references, link_forces).
1957
1968
  # Note that the wrapper of the system dynamics will override (state_x0, t0)
1958
1969
  # inside the passed data even if it is not strictly needed. This logic is
1959
1970
  # necessary to re-use the jit-compiled step function of compatible pytrees
@@ -1962,7 +1973,7 @@ def step(
1962
1973
  dict(
1963
1974
  model=model,
1964
1975
  data=data,
1965
- joint_forces=joint_forces,
1976
+ joint_force_references=joint_force_references,
1966
1977
  link_forces=link_forces,
1967
1978
  )
1968
1979
  | integrator_kwargs
jaxsim/api/ode.py CHANGED
@@ -86,8 +86,8 @@ def system_velocity_dynamics(
86
86
  model: js.model.JaxSimModel,
87
87
  data: js.data.JaxSimModelData,
88
88
  *,
89
- joint_forces: jtp.Vector | None = None,
90
89
  link_forces: jtp.Vector | None = None,
90
+ joint_force_references: jtp.Vector | None = None,
91
91
  ) -> tuple[jtp.Vector, jtp.Vector, dict[str, Any]]:
92
92
  """
93
93
  Compute the dynamics of the system velocity.
@@ -95,10 +95,10 @@ def system_velocity_dynamics(
95
95
  Args:
96
96
  model: The model to consider.
97
97
  data: The data of the considered model.
98
- joint_forces: The joint force references to apply.
99
98
  link_forces:
100
99
  The 6D forces to apply to the links expressed in the frame corresponding to
101
100
  the velocity representation of `data`.
101
+ joint_force_references: The joint force references to apply.
102
102
 
103
103
  Returns:
104
104
  A tuple containing the derivative of the base 6D velocity in inertial-fixed
@@ -120,7 +120,7 @@ def system_velocity_dynamics(
120
120
  references = js.references.JaxSimModelReferences.build(
121
121
  model=model,
122
122
  link_forces=O_f_L,
123
- joint_force_references=joint_forces,
123
+ joint_force_references=joint_force_references,
124
124
  data=data,
125
125
  velocity_representation=data.velocity_representation,
126
126
  )
@@ -192,7 +192,10 @@ def system_velocity_dynamics(
192
192
  f_L_total = references.link_forces(model=model, data=data)
193
193
 
194
194
  v̇_WB, s̈ = system_acceleration(
195
- model=model, data=data, joint_forces=joint_forces, link_forces=f_L_total
195
+ model=model,
196
+ data=data,
197
+ joint_force_references=joint_force_references,
198
+ link_forces=f_L_total,
196
199
  )
197
200
 
198
201
  return v̇_WB, s̈, aux_data
@@ -202,8 +205,8 @@ def system_acceleration(
202
205
  model: js.model.JaxSimModel,
203
206
  data: js.data.JaxSimModelData,
204
207
  *,
205
- joint_forces: jtp.VectorLike | None = None,
206
208
  link_forces: jtp.MatrixLike | None = None,
209
+ joint_force_references: jtp.VectorLike | None = None,
207
210
  ) -> tuple[jtp.Vector, jtp.Vector]:
208
211
  """
209
212
  Compute the system acceleration in the active representation.
@@ -211,12 +214,13 @@ def system_acceleration(
211
214
  Args:
212
215
  model: The model to consider.
213
216
  data: The data of the considered model.
214
- joint_forces: The joint forces to apply.
215
217
  link_forces:
216
- The 6D forces to apply to the links expressed in the same representation of data.
218
+ The 6D forces to apply to the links expressed in the same
219
+ velocity representation of data.
220
+ joint_force_references: The joint force references to apply.
217
221
 
218
222
  Returns:
219
- A tuple containing the base 6D acceleration in in the active representation
223
+ A tuple containing the base 6D acceleration in the active representation
220
224
  and the joint accelerations.
221
225
  """
222
226
 
@@ -232,9 +236,9 @@ def system_acceleration(
232
236
  ).astype(float)
233
237
 
234
238
  # Build joint torques if not provided.
235
- τ = (
236
- jnp.atleast_1d(joint_forces.squeeze())
237
- if joint_forces is not None
239
+ τ_references = (
240
+ jnp.atleast_1d(joint_force_references.squeeze())
241
+ if joint_force_references is not None
238
242
  else jnp.zeros_like(data.joint_positions())
239
243
  ).astype(float)
240
244
 
@@ -243,15 +247,16 @@ def system_acceleration(
243
247
  # ====================
244
248
 
245
249
  # TODO: enforce joint limits
246
- τ_position_limit = jnp.zeros_like(τ).astype(float)
250
+ τ_position_limit = jnp.zeros_like(τ_references).astype(float)
247
251
 
248
252
  # ====================
249
253
  # Joint friction model
250
254
  # ====================
251
255
 
252
- τ_friction = jnp.zeros_like(τ).astype(float)
256
+ τ_friction = jnp.zeros_like(τ_references).astype(float)
253
257
 
254
258
  if model.dofs() > 0:
259
+
255
260
  # Static and viscous joint friction parameters
256
261
  kc = jnp.array(
257
262
  model.kin_dyn_parameters.joint_parameters.friction_static
@@ -271,22 +276,27 @@ def system_acceleration(
271
276
  # ========================
272
277
 
273
278
  # Compute the total joint forces.
274
- τ_total = τ + τ_friction + τ_position_limit
279
+ τ_total = τ_references + τ_friction + τ_position_limit
275
280
 
281
+ # Store the link forces in a references object.
276
282
  references = js.references.JaxSimModelReferences.build(
277
283
  model=model,
278
284
  data=data,
279
285
  velocity_representation=data.velocity_representation,
280
- joint_force_references=τ_total,
281
286
  link_forces=f_L,
282
287
  )
283
288
 
289
+ # Compute forward dynamics.
290
+ #
284
291
  # - Joint accelerations: s̈ ∈ ℝⁿ
285
292
  # - Base acceleration: v̇_WB ∈ ℝ⁶
293
+ #
294
+ # Note that ABA returns the base acceleration in the velocity representation
295
+ # stored in the `data` object.
286
296
  v̇_WB, s̈ = js.model.forward_dynamics_aba(
287
297
  model=model,
288
298
  data=data,
289
- joint_forces=references.joint_force_references(model=model),
299
+ joint_forces=τ_total,
290
300
  link_forces=references.link_forces(model=model, data=data),
291
301
  )
292
302
 
@@ -337,8 +347,8 @@ def system_dynamics(
337
347
  model: js.model.JaxSimModel,
338
348
  data: js.data.JaxSimModelData,
339
349
  *,
340
- joint_forces: jtp.Vector | None = None,
341
350
  link_forces: jtp.Vector | None = None,
351
+ joint_force_references: jtp.Vector | None = None,
342
352
  baumgarte_quaternion_regularization: jtp.FloatLike = 1.0,
343
353
  ) -> tuple[ODEState, dict[str, Any]]:
344
354
  """
@@ -347,10 +357,10 @@ def system_dynamics(
347
357
  Args:
348
358
  model: The model to consider.
349
359
  data: The data of the considered model.
350
- joint_forces: The joint forces to apply.
351
360
  link_forces:
352
361
  The 6D forces to apply to the links expressed in the frame corresponding to
353
362
  the velocity representation of `data`.
363
+ joint_force_references: The joint force references to apply.
354
364
  baumgarte_quaternion_regularization:
355
365
  The Baumgarte regularization coefficient used to adjust the norm of the
356
366
  quaternion (only used in integrators not operating on the SO(3) manifold).
@@ -360,29 +370,31 @@ def system_dynamics(
360
370
  corresponding derivative, and the dictionary of auxiliary data returned
361
371
  by the system dynamics evaluation.
362
372
  """
363
- from jaxsim.rbda.contacts.relaxed_rigid import RelaxedRigidContacts
364
- from jaxsim.rbda.contacts.rigid import RigidContacts
365
- from jaxsim.rbda.contacts.soft import SoftContacts
373
+
374
+ from jaxsim.rbda.contacts import RelaxedRigidContacts, RigidContacts, SoftContacts
366
375
 
367
376
  # Compute the accelerations and the material deformation rate.
368
377
  W_v̇_WB, s̈, aux_dict = system_velocity_dynamics(
369
378
  model=model,
370
379
  data=data,
371
- joint_forces=joint_forces,
380
+ joint_force_references=joint_force_references,
372
381
  link_forces=link_forces,
373
382
  )
374
383
 
375
- ode_state_kwargs = {}
384
+ # Initialize the dictionary storing the derivative of the additional state variables
385
+ # that extend the state vector of the integrated ODE system.
386
+ extended_ode_state = {}
376
387
 
377
388
  match model.contact_model:
389
+
378
390
  case SoftContacts():
379
- ode_state_kwargs["tangential_deformation"] = aux_dict["m_dot"]
391
+ extended_ode_state["tangential_deformation"] = aux_dict["m_dot"]
380
392
 
381
393
  case RigidContacts() | RelaxedRigidContacts():
382
394
  pass
383
395
 
384
396
  case _:
385
- raise ValueError("Unable to determine contact state class prefix.")
397
+ raise ValueError(f"Invalid contact model {model.contact_model}")
386
398
 
387
399
  # Extract the velocities.
388
400
  W_ṗ_B, W_Q̇_B, ṡ = system_position_dynamics(
@@ -402,7 +414,7 @@ def system_dynamics(
402
414
  base_linear_velocity=W_v̇_WB[0:3],
403
415
  base_angular_velocity=W_v̇_WB[3:6],
404
416
  joint_velocities=s̈,
405
- **ode_state_kwargs,
417
+ **extended_ode_state,
406
418
  )
407
419
 
408
420
  return ode_state_derivative, aux_dict