jaxsim 0.2.dev191__py3-none-any.whl → 0.6.1.dev2__py3-none-any.whl

This diff shows the differences between two publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the package versions as they appear in their respective public registries.
Files changed (109)
  1. jaxsim/__init__.py +73 -22
  2. jaxsim/_version.py +2 -2
  3. jaxsim/api/__init__.py +13 -1
  4. jaxsim/api/com.py +423 -0
  5. jaxsim/api/common.py +48 -19
  6. jaxsim/api/contact.py +604 -52
  7. jaxsim/api/data.py +308 -163
  8. jaxsim/api/frame.py +471 -0
  9. jaxsim/api/joint.py +166 -37
  10. jaxsim/api/kin_dyn_parameters.py +901 -0
  11. jaxsim/api/link.py +277 -78
  12. jaxsim/api/model.py +1572 -362
  13. jaxsim/api/ode.py +324 -133
  14. jaxsim/api/ode_data.py +401 -0
  15. jaxsim/api/references.py +216 -80
  16. jaxsim/exceptions.py +80 -0
  17. jaxsim/integrators/__init__.py +2 -2
  18. jaxsim/integrators/common.py +191 -107
  19. jaxsim/integrators/fixed_step.py +97 -102
  20. jaxsim/integrators/variable_step.py +706 -0
  21. jaxsim/logging.py +1 -2
  22. jaxsim/math/__init__.py +13 -0
  23. jaxsim/math/adjoint.py +57 -22
  24. jaxsim/math/cross.py +16 -7
  25. jaxsim/math/inertia.py +10 -8
  26. jaxsim/math/joint_model.py +289 -0
  27. jaxsim/math/quaternion.py +54 -20
  28. jaxsim/math/rotation.py +27 -21
  29. jaxsim/math/skew.py +16 -5
  30. jaxsim/math/transform.py +102 -0
  31. jaxsim/math/utils.py +31 -0
  32. jaxsim/mujoco/__init__.py +2 -1
  33. jaxsim/mujoco/loaders.py +216 -29
  34. jaxsim/mujoco/model.py +163 -33
  35. jaxsim/mujoco/utils.py +228 -0
  36. jaxsim/mujoco/visualizer.py +107 -22
  37. jaxsim/parsers/__init__.py +0 -1
  38. jaxsim/parsers/descriptions/__init__.py +8 -2
  39. jaxsim/parsers/descriptions/collision.py +83 -26
  40. jaxsim/parsers/descriptions/joint.py +80 -87
  41. jaxsim/parsers/descriptions/link.py +58 -31
  42. jaxsim/parsers/descriptions/model.py +101 -68
  43. jaxsim/parsers/kinematic_graph.py +606 -229
  44. jaxsim/parsers/rod/meshes.py +104 -0
  45. jaxsim/parsers/rod/parser.py +125 -82
  46. jaxsim/parsers/rod/utils.py +127 -82
  47. jaxsim/rbda/__init__.py +11 -0
  48. jaxsim/rbda/aba.py +289 -0
  49. jaxsim/rbda/collidable_points.py +156 -0
  50. jaxsim/rbda/contacts/__init__.py +13 -0
  51. jaxsim/rbda/contacts/common.py +313 -0
  52. jaxsim/rbda/contacts/relaxed_rigid.py +605 -0
  53. jaxsim/rbda/contacts/rigid.py +462 -0
  54. jaxsim/rbda/contacts/soft.py +480 -0
  55. jaxsim/rbda/contacts/visco_elastic.py +1066 -0
  56. jaxsim/rbda/crba.py +167 -0
  57. jaxsim/rbda/forward_kinematics.py +117 -0
  58. jaxsim/rbda/jacobian.py +330 -0
  59. jaxsim/rbda/rnea.py +235 -0
  60. jaxsim/rbda/utils.py +160 -0
  61. jaxsim/terrain/__init__.py +2 -0
  62. jaxsim/terrain/terrain.py +238 -0
  63. jaxsim/typing.py +24 -24
  64. jaxsim/utils/__init__.py +1 -4
  65. jaxsim/utils/jaxsim_dataclass.py +289 -34
  66. jaxsim/utils/tracing.py +5 -11
  67. jaxsim/utils/wrappers.py +159 -0
  68. {jaxsim-0.2.dev191.dist-info → jaxsim-0.6.1.dev2.dist-info}/LICENSE +1 -1
  69. jaxsim-0.6.1.dev2.dist-info/METADATA +465 -0
  70. jaxsim-0.6.1.dev2.dist-info/RECORD +74 -0
  71. {jaxsim-0.2.dev191.dist-info → jaxsim-0.6.1.dev2.dist-info}/WHEEL +1 -1
  72. jaxsim/high_level/__init__.py +0 -2
  73. jaxsim/high_level/common.py +0 -11
  74. jaxsim/high_level/joint.py +0 -148
  75. jaxsim/high_level/link.py +0 -259
  76. jaxsim/high_level/model.py +0 -1686
  77. jaxsim/math/conv.py +0 -114
  78. jaxsim/math/joint.py +0 -102
  79. jaxsim/math/plucker.py +0 -100
  80. jaxsim/physics/__init__.py +0 -12
  81. jaxsim/physics/algos/__init__.py +0 -0
  82. jaxsim/physics/algos/aba.py +0 -254
  83. jaxsim/physics/algos/aba_motors.py +0 -284
  84. jaxsim/physics/algos/crba.py +0 -154
  85. jaxsim/physics/algos/forward_kinematics.py +0 -79
  86. jaxsim/physics/algos/jacobian.py +0 -98
  87. jaxsim/physics/algos/rnea.py +0 -180
  88. jaxsim/physics/algos/rnea_motors.py +0 -196
  89. jaxsim/physics/algos/soft_contacts.py +0 -523
  90. jaxsim/physics/algos/terrain.py +0 -78
  91. jaxsim/physics/algos/utils.py +0 -69
  92. jaxsim/physics/model/__init__.py +0 -0
  93. jaxsim/physics/model/ground_contact.py +0 -53
  94. jaxsim/physics/model/physics_model.py +0 -388
  95. jaxsim/physics/model/physics_model_state.py +0 -283
  96. jaxsim/simulation/__init__.py +0 -4
  97. jaxsim/simulation/integrators.py +0 -393
  98. jaxsim/simulation/ode.py +0 -290
  99. jaxsim/simulation/ode_data.py +0 -96
  100. jaxsim/simulation/ode_integration.py +0 -62
  101. jaxsim/simulation/simulator.py +0 -543
  102. jaxsim/simulation/simulator_callbacks.py +0 -79
  103. jaxsim/simulation/utils.py +0 -15
  104. jaxsim/sixd/__init__.py +0 -2
  105. jaxsim/utils/oop.py +0 -536
  106. jaxsim/utils/vmappable.py +0 -117
  107. jaxsim-0.2.dev191.dist-info/METADATA +0 -184
  108. jaxsim-0.2.dev191.dist-info/RECORD +0 -81
  109. {jaxsim-0.2.dev191.dist-info → jaxsim-0.6.1.dev2.dist-info}/top_level.txt +0 -0
jaxsim/api/ode.py CHANGED
@@ -2,64 +2,99 @@ from typing import Any, Protocol
2
2
 
3
3
  import jax
4
4
  import jax.numpy as jnp
5
- import jaxlie
6
5
 
7
- import jaxsim.physics.algos.soft_contacts
6
+ import jaxsim.api as js
7
+ import jaxsim.rbda
8
8
  import jaxsim.typing as jtp
9
- from jaxsim import VelRepr, integrators
10
- from jaxsim.integrators.common import Time
11
- from jaxsim.math.quaternion import Quaternion
12
- from jaxsim.physics.algos.soft_contacts import SoftContactsState
13
- from jaxsim.physics.model.physics_model_state import PhysicsModelState
14
- from jaxsim.simulation.ode_data import ODEState
9
+ from jaxsim.integrators import Time
10
+ from jaxsim.math import Quaternion
11
+ from jaxsim.rbda import contacts
15
12
 
16
- from . import contact as Contact
17
- from . import data as Data
18
- from . import model as Model
13
+ from .common import VelRepr
14
+ from .ode_data import ODEState
19
15
 
20
16
 
21
17
  class SystemDynamicsFromModelAndData(Protocol):
18
+ """
19
+ Protocol defining the signature of a function computing the system dynamics
20
+ given a model and data object.
21
+ """
22
+
22
23
  def __call__(
23
24
  self,
24
- model: Model.JaxSimModel,
25
- data: Data.JaxSimModelData,
25
+ model: js.model.JaxSimModel,
26
+ data: js.data.JaxSimModelData,
26
27
  **kwargs: dict[str, Any],
27
- ) -> tuple[ODEState, dict[str, Any]]: ...
28
+ ) -> tuple[ODEState, dict[str, Any]]:
29
+ """
30
+ Compute the system dynamics given a model and data object.
31
+
32
+ Args:
33
+ model: The model to consider.
34
+ data: The data of the considered model.
35
+ **kwargs: Additional keyword arguments.
36
+
37
+ Returns:
38
+ A tuple with an `ODEState` object storing in each of its attributes the
39
+ corresponding derivative, and the dictionary of auxiliary data returned
40
+ by the system dynamics evaluation.
41
+ """
42
+
43
+ pass
28
44
 
29
45
 
30
46
  def wrap_system_dynamics_for_integration(
31
- model: Model.JaxSimModel,
32
- data: Data.JaxSimModelData,
33
47
  *,
34
48
  system_dynamics: SystemDynamicsFromModelAndData,
35
- **kwargs,
49
+ **kwargs: dict[str, Any],
36
50
  ) -> jaxsim.integrators.common.SystemDynamics[ODEState, ODEState]:
37
51
  """
38
- Wrap generic system dynamics operating on `JaxSimModel` and `JaxSimModelData`
39
- for integration with `jaxsim.integrators`.
52
+ Wrap the system dynamics considered by JaxSim integrators in a generic
53
+ `f(x, t, **u, **parameters)` function.
40
54
 
41
55
  Args:
42
- model: The model to consider.
43
- data: The data of the considered model.
44
56
  system_dynamics: The system dynamics to wrap.
45
57
  **kwargs: Additional kwargs to close over the system dynamics.
46
58
 
47
59
  Returns:
48
- The system dynamics closed over the model, the data, and the additional kwargs.
60
+ The system dynamics closed over the additional kwargs to be used by
61
+ JaxSim integrators.
49
62
  """
50
63
 
51
- # We allow to close `system_dynamics` over additional kwargs.
52
- kwargs_closed = kwargs
53
-
54
- def f(x: ODEState, t: Time, **kwargs) -> tuple[ODEState, dict[str, Any]]:
55
-
56
- # Close f over the `data` parameter.
57
- with data.editable(validate=True) as data_rw:
64
+ # Close `system_dynamics` over additional kwargs.
65
+ # Similarly to what is done in `jaxsim.api.model.step`, to be future-proof, we use the
66
+ # following logic to allow the caller to close over arguments having the same name
67
+ # as the ones used in the `wrap_system_dynamics_for_integration` function.
68
+ kwargs = kwargs.copy() if kwargs is not None else {}
69
+ colliding_system_dynamics_kwargs = kwargs.pop("system_dynamics_kwargs", {})
70
+ system_dynamics_kwargs = kwargs | colliding_system_dynamics_kwargs
71
+
72
+ # Remove `model` and `data` for backward compatibility.
73
+ # It's no longer necessary to close over them at this stage, as this is always
74
+ # done in `jaxsim.api.model.step`.
75
+ # We can remove the following lines in a few releases.
76
+ _ = system_dynamics_kwargs.pop("data", None)
77
+ _ = system_dynamics_kwargs.pop("model", None)
78
+
79
+ # Create the function with the signature expected by our generic integrators.
80
+ # Note that our system dynamics is time independent.
81
+ def f(x: ODEState, t: Time, **kwargs_f) -> tuple[ODEState, dict[str, Any]]:
82
+
83
+ # Get the data and model objects from the kwargs.
84
+ data_f = kwargs_f.pop("data")
85
+ model_f = kwargs_f.pop("model")
86
+
87
+ # Update the state and time stored inside data.
88
+ with data_f.editable(validate=True) as data_rw:
58
89
  data_rw.state = x
59
- data_rw.time_ns = jnp.array(t * 1e9).astype(jnp.uint64)
60
90
 
61
- # Close f over the `model` parameter.
62
- return system_dynamics(model=model, data=data_rw, **kwargs_closed | kwargs)
91
+ # Evaluate the system dynamics, allowing to override the kwargs originally
92
+ # passed when the closure was created.
93
+ return system_dynamics(
94
+ model=model_f,
95
+ data=data_rw,
96
+ **(system_dynamics_kwargs | kwargs_f),
97
+ )
63
98
 
64
99
  f: jaxsim.integrators.common.SystemDynamics[ODEState, ODEState]
65
100
  return f
@@ -71,105 +106,217 @@ def wrap_system_dynamics_for_integration(
71
106
 
72
107
 
73
108
  @jax.jit
109
+ @js.common.named_scope
74
110
  def system_velocity_dynamics(
75
- model: Model.JaxSimModel,
76
- data: Data.JaxSimModelData,
111
+ model: js.model.JaxSimModel,
112
+ data: js.data.JaxSimModelData,
77
113
  *,
78
- joint_forces: jtp.Vector | None = None,
79
- external_forces: jtp.Vector | None = None,
80
- ) -> tuple[jtp.Vector, jtp.Vector, jtp.Matrix, dict[str, Any]]:
114
+ link_forces: jtp.Vector | None = None,
115
+ joint_force_references: jtp.Vector | None = None,
116
+ ) -> tuple[jtp.Vector, jtp.Vector, dict[str, Any]]:
81
117
  """
82
118
  Compute the dynamics of the system velocity.
83
119
 
84
120
  Args:
85
121
  model: The model to consider.
86
122
  data: The data of the considered model.
87
- joint_forces: The joint forces to apply.
88
- external_forces: The external forces to apply to the links.
123
+ link_forces:
124
+ The 6D forces to apply to the links expressed in the frame corresponding to
125
+ the velocity representation of `data`.
126
+ joint_force_references: The joint force references to apply.
89
127
 
90
128
  Returns:
91
129
  A tuple containing the derivative of the base 6D velocity in inertial-fixed
92
- representation, the derivative of the joint velocities, the derivative of
93
- the material deformation, and the dictionary of auxiliary data returned by
94
- the system dynamics evaluation.
130
+ representation, the derivative of the joint velocities, and auxiliary data
131
+ returned by the system dynamics evaluation.
95
132
  """
96
133
 
97
- # Build joint torques if not provided
98
- τ = (
99
- jnp.atleast_1d(joint_forces.squeeze())
100
- if joint_forces is not None
101
- else jnp.zeros_like(data.joint_positions())
102
- ).astype(float)
103
-
104
- # Build external forces if not provided
105
- f_ext = (
106
- jnp.atleast_2d(external_forces.squeeze())
107
- if external_forces is not None
134
+ # Build link forces if not provided.
135
+ # These forces are expressed in the frame corresponding to the velocity
136
+ # representation of data.
137
+ O_f_L = (
138
+ jnp.atleast_2d(link_forces.squeeze())
139
+ if link_forces is not None
108
140
  else jnp.zeros((model.number_of_links(), 6))
109
141
  ).astype(float)
110
142
 
143
+ # We expect that the 6D forces included in the `link_forces` argument are expressed
144
+ # in the frame corresponding to the velocity representation of `data`.
145
+ references = js.references.JaxSimModelReferences.build(
146
+ model=model,
147
+ link_forces=O_f_L,
148
+ joint_force_references=joint_force_references,
149
+ data=data,
150
+ velocity_representation=data.velocity_representation,
151
+ )
152
+
111
153
  # ======================
112
154
  # Compute contact forces
113
155
  # ======================
114
156
 
115
157
  # Initialize the 6D forces W_f ∈ ℝ^{n_L × 6} applied to links due to contact
116
158
  # with the terrain.
117
- W_f_Li_terrain = jnp.zeros_like(f_ext).astype(float)
118
-
119
- # Initialize the 6D contact forces W_f ∈ ℝ^{n_c × 3} applied to collidable points,
120
- # expressed in the world frame.
121
- W_f_Ci = None
122
-
123
- # Initialize the derivative of the tangential deformation ṁ ∈ ℝ^{n_c × 3}.
124
- ṁ = jnp.zeros_like(data.state.soft_contacts.tangential_deformation).astype(float)
125
-
126
- if len(model.physics_model.gc.body) > 0:
127
- # Compute the position and linear velocities (mixed representation) of
128
- # all collidable points belonging to the robot.
129
- W_p_Ci, W_ṗ_Ci = Contact.collidable_point_kinematics(model=model, data=data)
130
-
131
- # Compute the 3D forces applied to each collidable point.
132
- W_f_Ci, ṁ = jax.vmap(
133
- lambda p, ṗ, m: jaxsim.physics.algos.soft_contacts.SoftContacts(
134
- parameters=data.soft_contacts_params, terrain=model.terrain
135
- ).contact_model(position=p, velocity=ṗ, tangential_deformation=m)
136
- )(W_p_Ci, W_ṗ_Ci, data.state.soft_contacts.tangential_deformation.T)
137
-
138
- # Sum the forces of all collidable points rigidly attached to a body.
139
- # Since the contact forces W_f_Ci are expressed in the world frame,
140
- # we don't need any coordinate transformation.
141
- W_f_Li_terrain = jax.vmap(
142
- lambda nc: (
143
- jnp.vstack(
144
- jnp.equal(
145
- np.array(model.physics_model.gc.body, dtype=int), nc
146
- ).astype(int)
147
- )
148
- * W_f_Ci
149
- ).sum(axis=0)
150
- )(jnp.arange(model.number_of_links()))
159
+ W_f_L_terrain = jnp.zeros_like(O_f_L).astype(float)
160
+
161
+ # Initialize a dictionary of auxiliary data.
162
+ # This dictionary is used to store additional data computed by the contact model.
163
+ aux_data = {}
164
+
165
+ if len(model.kin_dyn_parameters.contact_parameters.body) > 0:
166
+
167
+ with (
168
+ data.switch_velocity_representation(VelRepr.Inertial),
169
+ references.switch_velocity_representation(VelRepr.Inertial),
170
+ ):
171
+
172
+ # Compute the 6D forces W_f ∈ ℝ^{n_c × 6} applied to each collidable point
173
+ # along with contact-specific auxiliary states.
174
+ W_f_C, aux_data = js.contact.collidable_point_dynamics(
175
+ model=model,
176
+ data=data,
177
+ link_forces=references.link_forces(model=model, data=data),
178
+ joint_force_references=references.joint_force_references(model=model),
179
+ )
180
+
181
+ # Compute the 6D forces applied to the links equivalent to the forces applied
182
+ # to the frames associated to the collidable points.
183
+ W_f_L_terrain = model.contact_model.link_forces_from_contact_forces(
184
+ model=model,
185
+ data=data,
186
+ contact_forces=W_f_C,
187
+ )
188
+
189
+ # ===========================
190
+ # Compute system acceleration
191
+ # ===========================
192
+
193
+ # Compute the total link forces.
194
+ with (
195
+ data.switch_velocity_representation(VelRepr.Inertial),
196
+ references.switch_velocity_representation(VelRepr.Inertial),
197
+ ):
198
+
199
+ # Sum the contact forces just computed with the link forces applied by the user.
200
+ references = references.apply_link_forces(
201
+ model=model,
202
+ data=data,
203
+ forces=W_f_L_terrain,
204
+ additive=True,
205
+ )
206
+
207
+ # Get the link forces in inertial-fixed representation.
208
+ f_L_total = references.link_forces(model=model, data=data)
209
+
210
+ # Compute the system acceleration in inertial-fixed representation.
211
+ # This representation is useful for integration purpose.
212
+ W_v̇_WB, s̈ = system_acceleration(
213
+ model=model,
214
+ data=data,
215
+ joint_force_references=joint_force_references,
216
+ link_forces=f_L_total,
217
+ )
218
+
219
+ return W_v̇_WB, s̈, aux_data
220
+
221
+
222
+ def system_acceleration(
223
+ model: js.model.JaxSimModel,
224
+ data: js.data.JaxSimModelData,
225
+ *,
226
+ link_forces: jtp.MatrixLike | None = None,
227
+ joint_force_references: jtp.VectorLike | None = None,
228
+ ) -> tuple[jtp.Vector, jtp.Vector]:
229
+ """
230
+ Compute the system acceleration in the active representation.
231
+
232
+ Args:
233
+ model: The model to consider.
234
+ data: The data of the considered model.
235
+ link_forces:
236
+ The 6D forces to apply to the links expressed in the same
237
+ velocity representation as `data`.
238
+ joint_force_references: The joint force references to apply.
239
+
240
+ Returns:
241
+ A tuple containing the base 6D acceleration in the active representation
242
+ and the joint accelerations.
243
+ """
244
+
245
+ # ====================
246
+ # Validate input data
247
+ # ====================
248
+
249
+ # Build link forces if not provided.
250
+ f_L = (
251
+ jnp.atleast_2d(link_forces.squeeze())
252
+ if link_forces is not None
253
+ else jnp.zeros((model.number_of_links(), 6))
254
+ ).astype(float)
255
+
256
+ # Build joint torques if not provided.
257
+ τ_references = (
258
+ jnp.atleast_1d(joint_force_references.squeeze())
259
+ if joint_force_references is not None
260
+ else jnp.zeros_like(data.joint_positions())
261
+ ).astype(float)
151
262
 
152
263
  # ====================
153
264
  # Enforce joint limits
154
265
  # ====================
155
266
 
156
- # TODO: enforce joint limits
157
- τ_position_limit = jnp.zeros_like(τ).astype(float)
267
+ τ_position_limit = jnp.zeros_like(τ_references).astype(float)
268
+
269
+ if model.dofs() > 0:
270
+
271
+ # Stiffness and damper parameters for the joint position limits.
272
+ k_j = jnp.array(
273
+ model.kin_dyn_parameters.joint_parameters.position_limit_spring
274
+ ).astype(float)
275
+ d_j = jnp.array(
276
+ model.kin_dyn_parameters.joint_parameters.position_limit_damper
277
+ ).astype(float)
278
+
279
+ # Compute the joint position limit violations.
280
+ lower_violation = jnp.clip(
281
+ data.state.physics_model.joint_positions
282
+ - model.kin_dyn_parameters.joint_parameters.position_limits_min,
283
+ max=0.0,
284
+ )
285
+
286
+ upper_violation = jnp.clip(
287
+ data.state.physics_model.joint_positions
288
+ - model.kin_dyn_parameters.joint_parameters.position_limits_max,
289
+ min=0.0,
290
+ )
291
+
292
+ # Compute the joint position limit torque.
293
+ τ_position_limit -= jnp.diag(k_j) @ (lower_violation + upper_violation)
294
+
295
+ τ_position_limit -= (
296
+ jnp.positive(τ_position_limit)
297
+ * jnp.diag(d_j)
298
+ @ data.state.physics_model.joint_velocities
299
+ )
158
300
 
159
301
  # ====================
160
302
  # Joint friction model
161
303
  # ====================
162
304
 
163
- τ_friction = jnp.zeros_like(τ).astype(float)
305
+ τ_friction = jnp.zeros_like(τ_references).astype(float)
164
306
 
165
307
  if model.dofs() > 0:
166
- # Static and viscous joint friction parameters
167
- kc = jnp.array(list(model.physics_model._joint_friction_static.values()))
168
- kv = jnp.array(list(model.physics_model._joint_friction_viscous.values()))
169
308
 
170
- # Compute the joint friction torque
309
+ # Static and viscous joint friction parameters
310
+ kc = jnp.array(
311
+ model.kin_dyn_parameters.joint_parameters.friction_static
312
+ ).astype(float)
313
+ kv = jnp.array(
314
+ model.kin_dyn_parameters.joint_parameters.friction_viscous
315
+ ).astype(float)
316
+
317
+ # Compute the joint friction torque.
171
318
  τ_friction = -(
172
- jnp.diag(kc) @ jnp.sign(data.state.physics_model.joint_positions)
319
+ jnp.diag(kc) @ jnp.sign(data.state.physics_model.joint_velocities)
173
320
  + jnp.diag(kv) @ data.state.physics_model.joint_velocities
174
321
  )
175
322
 
@@ -177,28 +324,40 @@ def system_velocity_dynamics(
177
324
  # Compute forward dynamics
178
325
  # ========================
179
326
 
180
- # Compute the total joint forces
181
- τ_total = τ + τ_friction + τ_position_limit
327
+ # Compute the total joint forces.
328
+ τ_total = τ_references + τ_friction + τ_position_limit
182
329
 
183
- # Compute the total external 6D forces applied to the links
184
- W_f_L_total = f_ext + W_f_Li_terrain
330
+ # Store the link forces in a references object.
331
+ references = js.references.JaxSimModelReferences.build(
332
+ model=model,
333
+ data=data,
334
+ velocity_representation=data.velocity_representation,
335
+ link_forces=f_L,
336
+ )
185
337
 
338
+ # Compute forward dynamics.
339
+ #
186
340
  # - Joint accelerations: s̈ ∈ ℝⁿ
187
- # - Base inertial-fixed acceleration: W_v̇_WB = (W_p̈_B, W_ω̇_B) ∈ ℝ⁶
188
- with data.switch_velocity_representation(velocity_representation=VelRepr.Inertial):
189
- W_v̇_WB, s̈ = Model.forward_dynamics_aba(
190
- model=model,
191
- data=data,
192
- joint_forces=τ_total,
193
- external_forces=W_f_L_total,
194
- )
341
+ # - Base acceleration: v̇_WB ∈ ℝ⁶
342
+ #
343
+ # Note that ABA returns the base acceleration in the velocity representation
344
+ # stored in the `data` object.
345
+ v̇_WB, s̈ = js.model.forward_dynamics_aba(
346
+ model=model,
347
+ data=data,
348
+ joint_forces=τ_total,
349
+ link_forces=references.link_forces(model=model, data=data),
350
+ )
195
351
 
196
- return W_v̇_WB, s̈, ṁ.T, dict()
352
+ return v̇_WB, s̈
197
353
 
198
354
 
199
355
  @jax.jit
356
+ @js.common.named_scope
200
357
  def system_position_dynamics(
201
- model: Model.JaxSimModel, data: Data.JaxSimModelData
358
+ model: js.model.JaxSimModel,
359
+ data: js.data.JaxSimModelData,
360
+ baumgarte_quaternion_regularization: jtp.FloatLike = 1.0,
202
361
  ) -> tuple[jtp.Vector, jtp.Vector, jtp.Vector]:
203
362
  """
204
363
  Compute the dynamics of the system position.
@@ -206,14 +365,16 @@ def system_position_dynamics(
206
365
  Args:
207
366
  model: The model to consider.
208
367
  data: The data of the considered model.
368
+ baumgarte_quaternion_regularization:
369
+ The Baumgarte regularization coefficient for adjusting the quaternion norm.
209
370
 
210
371
  Returns:
211
372
  A tuple containing the derivative of the base position, the derivative of the
212
373
  base quaternion, and the derivative of the joint positions.
213
374
  """
214
375
 
215
- ṡ = data.state.physics_model.joint_velocities
216
- W_Q_B = data.state.physics_model.base_quaternion
376
+ ṡ = data.joint_velocities(model=model)
377
+ W_Q_B = data.base_orientation(dcm=False)
217
378
 
218
379
  with data.switch_velocity_representation(velocity_representation=VelRepr.Mixed):
219
380
  W_ṗ_B = data.base_velocity()[0:3]
@@ -225,18 +386,21 @@ def system_position_dynamics(
225
386
  quaternion=W_Q_B,
226
387
  omega=W_ω_WB,
227
388
  omega_in_body_fixed=False,
389
+ K=baumgarte_quaternion_regularization,
228
390
  ).squeeze()
229
391
 
230
392
  return W_ṗ_B, W_Q̇_B, ṡ
231
393
 
232
394
 
233
395
  @jax.jit
396
+ @js.common.named_scope
234
397
  def system_dynamics(
235
- model: Model.JaxSimModel,
236
- data: Data.JaxSimModelData,
398
+ model: js.model.JaxSimModel,
399
+ data: js.data.JaxSimModelData,
237
400
  *,
238
- joint_forces: jtp.Vector | None = None,
239
- external_forces: jtp.Vector | None = None,
401
+ link_forces: jtp.Vector | None = None,
402
+ joint_force_references: jtp.Vector | None = None,
403
+ baumgarte_quaternion_regularization: jtp.FloatLike = 1.0,
240
404
  ) -> tuple[ODEState, dict[str, Any]]:
241
405
  """
242
406
  Compute the dynamics of the system.
@@ -244,8 +408,13 @@ def system_dynamics(
244
408
  Args:
245
409
  model: The model to consider.
246
410
  data: The data of the considered model.
247
- joint_forces: The joint forces to apply.
248
- external_forces: The external forces to apply to the links.
411
+ link_forces:
412
+ The 6D forces to apply to the links expressed in the frame corresponding to
413
+ the velocity representation of `data`.
414
+ joint_force_references: The joint force references to apply.
415
+ baumgarte_quaternion_regularization:
416
+ The Baumgarte regularization coefficient used to adjust the norm of the
417
+ quaternion (only used in integrators not operating on the SO(3) manifold).
249
418
 
250
419
  Returns:
251
420
  A tuple with an `ODEState` object storing in each of its attributes the
@@ -254,31 +423,53 @@ def system_dynamics(
254
423
  """
255
424
 
256
425
  # Compute the accelerations and the material deformation rate.
257
- W_v̇_WB, s̈, ṁ, aux_dict = system_velocity_dynamics(
426
+ W_v̇_WB, s̈, aux_dict = system_velocity_dynamics(
258
427
  model=model,
259
428
  data=data,
260
- joint_forces=joint_forces,
261
- external_forces=external_forces,
429
+ joint_force_references=joint_force_references,
430
+ link_forces=link_forces,
262
431
  )
263
432
 
433
+ # Initialize the dictionary storing the derivative of the additional state variables
434
+ # that extend the state vector of the integrated ODE system.
435
+ extended_ode_state = {}
436
+
437
+ match model.contact_model:
438
+
439
+ case contacts.SoftContacts():
440
+ extended_ode_state["tangential_deformation"] = aux_dict["m_dot"]
441
+
442
+ case contacts.ViscoElasticContacts():
443
+
444
+ extended_ode_state["tangential_deformation"] = jnp.zeros_like(
445
+ data.state.extended["tangential_deformation"]
446
+ )
447
+
448
+ case contacts.RigidContacts() | contacts.RelaxedRigidContacts():
449
+ pass
450
+
451
+ case _:
452
+ raise ValueError(f"Invalid contact model: {model.contact_model}")
453
+
264
454
  # Extract the velocities.
265
- W_ṗ_B, W_Q̇_B, ṡ = system_position_dynamics(model=model, data=data)
455
+ W_ṗ_B, W_Q̇_B, ṡ = system_position_dynamics(
456
+ model=model,
457
+ data=data,
458
+ baumgarte_quaternion_regularization=baumgarte_quaternion_regularization,
459
+ )
266
460
 
267
461
  # Create an ODEState object populated with the derivative of each leaf.
268
462
  # Our integrators, operating on generic pytrees, will be able to handle it
269
463
  # automatically as state derivative.
270
- ode_state_derivative = ODEState.build(
271
- physics_model_state=PhysicsModelState.build(
272
- base_position=W_ṗ_B,
273
- base_quaternion=W_Q̇_B,
274
- joint_positions=ṡ,
275
- base_linear_velocity=W_v̇_WB[0:3],
276
- base_angular_velocity=W_v̇_WB[3:6],
277
- joint_velocities=s̈,
278
- ),
279
- soft_contacts_state=SoftContactsState.build(
280
- tangential_deformation=ṁ,
281
- ),
464
+ ode_state_derivative = ODEState.build_from_jaxsim_model(
465
+ model=model,
466
+ base_position=W_ṗ_B,
467
+ base_quaternion=W_Q̇_B,
468
+ joint_positions=ṡ,
469
+ base_linear_velocity=W_v̇_WB[0:3],
470
+ base_angular_velocity=W_v̇_WB[3:6],
471
+ joint_velocities=s̈,
472
+ **extended_ode_state,
282
473
  )
283
474
 
284
475
  return ode_state_derivative, aux_dict