jaxsim 0.2.dev188__py3-none-any.whl → 0.6.1.dev2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (109) hide show
  1. jaxsim/__init__.py +73 -22
  2. jaxsim/_version.py +2 -2
  3. jaxsim/api/__init__.py +13 -1
  4. jaxsim/api/com.py +423 -0
  5. jaxsim/api/common.py +48 -19
  6. jaxsim/api/contact.py +604 -52
  7. jaxsim/api/data.py +308 -163
  8. jaxsim/api/frame.py +471 -0
  9. jaxsim/api/joint.py +166 -37
  10. jaxsim/api/kin_dyn_parameters.py +901 -0
  11. jaxsim/api/link.py +277 -78
  12. jaxsim/api/model.py +1572 -362
  13. jaxsim/api/ode.py +324 -129
  14. jaxsim/api/ode_data.py +401 -0
  15. jaxsim/api/references.py +216 -80
  16. jaxsim/exceptions.py +80 -0
  17. jaxsim/integrators/__init__.py +2 -2
  18. jaxsim/integrators/common.py +191 -107
  19. jaxsim/integrators/fixed_step.py +97 -102
  20. jaxsim/integrators/variable_step.py +706 -0
  21. jaxsim/logging.py +1 -2
  22. jaxsim/math/__init__.py +13 -0
  23. jaxsim/math/adjoint.py +57 -22
  24. jaxsim/math/cross.py +16 -7
  25. jaxsim/math/inertia.py +10 -8
  26. jaxsim/math/joint_model.py +289 -0
  27. jaxsim/math/quaternion.py +54 -20
  28. jaxsim/math/rotation.py +27 -21
  29. jaxsim/math/skew.py +16 -5
  30. jaxsim/math/transform.py +102 -0
  31. jaxsim/math/utils.py +31 -0
  32. jaxsim/mujoco/__init__.py +2 -1
  33. jaxsim/mujoco/loaders.py +216 -29
  34. jaxsim/mujoco/model.py +163 -33
  35. jaxsim/mujoco/utils.py +228 -0
  36. jaxsim/mujoco/visualizer.py +107 -22
  37. jaxsim/parsers/__init__.py +0 -1
  38. jaxsim/parsers/descriptions/__init__.py +8 -2
  39. jaxsim/parsers/descriptions/collision.py +87 -16
  40. jaxsim/parsers/descriptions/joint.py +80 -87
  41. jaxsim/parsers/descriptions/link.py +62 -24
  42. jaxsim/parsers/descriptions/model.py +101 -68
  43. jaxsim/parsers/kinematic_graph.py +607 -225
  44. jaxsim/parsers/rod/meshes.py +104 -0
  45. jaxsim/parsers/rod/parser.py +125 -82
  46. jaxsim/parsers/rod/utils.py +127 -82
  47. jaxsim/rbda/__init__.py +11 -0
  48. jaxsim/rbda/aba.py +289 -0
  49. jaxsim/rbda/collidable_points.py +156 -0
  50. jaxsim/rbda/contacts/__init__.py +13 -0
  51. jaxsim/rbda/contacts/common.py +313 -0
  52. jaxsim/rbda/contacts/relaxed_rigid.py +605 -0
  53. jaxsim/rbda/contacts/rigid.py +462 -0
  54. jaxsim/rbda/contacts/soft.py +480 -0
  55. jaxsim/rbda/contacts/visco_elastic.py +1066 -0
  56. jaxsim/rbda/crba.py +167 -0
  57. jaxsim/rbda/forward_kinematics.py +117 -0
  58. jaxsim/rbda/jacobian.py +330 -0
  59. jaxsim/rbda/rnea.py +235 -0
  60. jaxsim/rbda/utils.py +160 -0
  61. jaxsim/terrain/__init__.py +2 -0
  62. jaxsim/terrain/terrain.py +238 -0
  63. jaxsim/typing.py +24 -24
  64. jaxsim/utils/__init__.py +1 -4
  65. jaxsim/utils/jaxsim_dataclass.py +289 -34
  66. jaxsim/utils/tracing.py +5 -11
  67. jaxsim/utils/wrappers.py +159 -0
  68. {jaxsim-0.2.dev188.dist-info → jaxsim-0.6.1.dev2.dist-info}/LICENSE +1 -1
  69. jaxsim-0.6.1.dev2.dist-info/METADATA +465 -0
  70. jaxsim-0.6.1.dev2.dist-info/RECORD +74 -0
  71. {jaxsim-0.2.dev188.dist-info → jaxsim-0.6.1.dev2.dist-info}/WHEEL +1 -1
  72. jaxsim/high_level/__init__.py +0 -2
  73. jaxsim/high_level/common.py +0 -11
  74. jaxsim/high_level/joint.py +0 -148
  75. jaxsim/high_level/link.py +0 -259
  76. jaxsim/high_level/model.py +0 -1686
  77. jaxsim/math/conv.py +0 -114
  78. jaxsim/math/joint.py +0 -102
  79. jaxsim/math/plucker.py +0 -100
  80. jaxsim/physics/__init__.py +0 -12
  81. jaxsim/physics/algos/__init__.py +0 -0
  82. jaxsim/physics/algos/aba.py +0 -254
  83. jaxsim/physics/algos/aba_motors.py +0 -284
  84. jaxsim/physics/algos/crba.py +0 -154
  85. jaxsim/physics/algos/forward_kinematics.py +0 -79
  86. jaxsim/physics/algos/jacobian.py +0 -98
  87. jaxsim/physics/algos/rnea.py +0 -180
  88. jaxsim/physics/algos/rnea_motors.py +0 -196
  89. jaxsim/physics/algos/soft_contacts.py +0 -523
  90. jaxsim/physics/algos/terrain.py +0 -80
  91. jaxsim/physics/algos/utils.py +0 -69
  92. jaxsim/physics/model/__init__.py +0 -0
  93. jaxsim/physics/model/ground_contact.py +0 -55
  94. jaxsim/physics/model/physics_model.py +0 -388
  95. jaxsim/physics/model/physics_model_state.py +0 -283
  96. jaxsim/simulation/__init__.py +0 -4
  97. jaxsim/simulation/integrators.py +0 -393
  98. jaxsim/simulation/ode.py +0 -290
  99. jaxsim/simulation/ode_data.py +0 -96
  100. jaxsim/simulation/ode_integration.py +0 -62
  101. jaxsim/simulation/simulator.py +0 -543
  102. jaxsim/simulation/simulator_callbacks.py +0 -79
  103. jaxsim/simulation/utils.py +0 -15
  104. jaxsim/sixd/__init__.py +0 -2
  105. jaxsim/utils/oop.py +0 -536
  106. jaxsim/utils/vmappable.py +0 -117
  107. jaxsim-0.2.dev188.dist-info/METADATA +0 -184
  108. jaxsim-0.2.dev188.dist-info/RECORD +0 -81
  109. {jaxsim-0.2.dev188.dist-info → jaxsim-0.6.1.dev2.dist-info}/top_level.txt +0 -0
jaxsim/api/ode.py CHANGED
@@ -2,64 +2,99 @@ from typing import Any, Protocol
2
2
 
3
3
  import jax
4
4
  import jax.numpy as jnp
5
- import jaxlie
6
5
 
7
- import jaxsim.physics.algos.soft_contacts
6
+ import jaxsim.api as js
7
+ import jaxsim.rbda
8
8
  import jaxsim.typing as jtp
9
- from jaxsim import VelRepr, integrators
10
- from jaxsim.integrators.common import Time
11
- from jaxsim.math.quaternion import Quaternion
12
- from jaxsim.physics.algos.soft_contacts import SoftContactsState
13
- from jaxsim.physics.model.physics_model_state import PhysicsModelState
14
- from jaxsim.simulation.ode_data import ODEState
9
+ from jaxsim.integrators import Time
10
+ from jaxsim.math import Quaternion
11
+ from jaxsim.rbda import contacts
15
12
 
16
- from . import contact as Contact
17
- from . import data as Data
18
- from . import model as Model
13
+ from .common import VelRepr
14
+ from .ode_data import ODEState
19
15
 
20
16
 
21
17
  class SystemDynamicsFromModelAndData(Protocol):
18
+ """
19
+ Protocol defining the signature of a function computing the system dynamics
20
+ given a model and data object.
21
+ """
22
+
22
23
  def __call__(
23
24
  self,
24
- model: Model.JaxSimModel,
25
- data: Data.JaxSimModelData,
25
+ model: js.model.JaxSimModel,
26
+ data: js.data.JaxSimModelData,
26
27
  **kwargs: dict[str, Any],
27
- ) -> tuple[ODEState, dict[str, Any]]: ...
28
+ ) -> tuple[ODEState, dict[str, Any]]:
29
+ """
30
+ Compute the system dynamics given a model and data object.
31
+
32
+ Args:
33
+ model: The model to consider.
34
+ data: The data of the considered model.
35
+ **kwargs: Additional keyword arguments.
36
+
37
+ Returns:
38
+ A tuple with an `ODEState` object storing in each of its attributes the
39
+ corresponding derivative, and the dictionary of auxiliary data returned
40
+ by the system dynamics evaluation.
41
+ """
42
+
43
+ pass
28
44
 
29
45
 
30
46
  def wrap_system_dynamics_for_integration(
31
- model: Model.JaxSimModel,
32
- data: Data.JaxSimModelData,
33
47
  *,
34
48
  system_dynamics: SystemDynamicsFromModelAndData,
35
- **kwargs,
49
+ **kwargs: dict[str, Any],
36
50
  ) -> jaxsim.integrators.common.SystemDynamics[ODEState, ODEState]:
37
51
  """
38
- Wrap generic system dynamics operating on `JaxSimModel` and `JaxSimModelData`
39
- for integration with `jaxsim.integrators`.
52
+ Wrap the system dynamics considered by JaxSim integrators in a generic
53
+ `f(x, t, **u, **parameters)` function.
40
54
 
41
55
  Args:
42
- model: The model to consider.
43
- data: The data of the considered model.
44
56
  system_dynamics: The system dynamics to wrap.
45
57
  **kwargs: Additional kwargs to close over the system dynamics.
46
58
 
47
59
  Returns:
48
- The system dynamics closed over the model, the data, and the additional kwargs.
60
+ The system dynamics closed over the additional kwargs to be used by
61
+ JaxSim integrators.
49
62
  """
50
63
 
51
- # We allow to close `system_dynamics` over additional kwargs.
52
- kwargs_closed = kwargs
53
-
54
- def f(x: ODEState, t: Time, **kwargs) -> tuple[ODEState, dict[str, Any]]:
55
-
56
- # Close f over the `data` parameter.
57
- with data.editable(validate=True) as data_rw:
64
+ # Close `system_dynamics` over additional kwargs.
65
+ # Similarly to what done in `jaxsim.api.model.step`, to be future-proof, we use the
66
+ # following logic to allow the caller to close over arguments having the same name
67
+ # of the ones used in the `wrap_system_dynamics_for_integration` function.
68
+ kwargs = kwargs.copy() if kwargs is not None else {}
69
+ colliding_system_dynamics_kwargs = kwargs.pop("system_dynamics_kwargs", {})
70
+ system_dynamics_kwargs = kwargs | colliding_system_dynamics_kwargs
71
+
72
+ # Remove `model` and `data` for backward compatibility.
73
+ # It's no longer necessary to close over them at this stage, as this is always
74
+ # done in `jaxsim.api.model.step`.
75
+ # We can remove the following lines in a few releases.
76
+ _ = system_dynamics_kwargs.pop("data", None)
77
+ _ = system_dynamics_kwargs.pop("model", None)
78
+
79
+ # Create the function with the signature expected by our generic integrators.
80
+ # Note that our system dynamics is time independent.
81
+ def f(x: ODEState, t: Time, **kwargs_f) -> tuple[ODEState, dict[str, Any]]:
82
+
83
+ # Get the data and model objects from the kwargs.
84
+ data_f = kwargs_f.pop("data")
85
+ model_f = kwargs_f.pop("model")
86
+
87
+ # Update the state and time stored inside data.
88
+ with data_f.editable(validate=True) as data_rw:
58
89
  data_rw.state = x
59
- data_rw.time_ns = jnp.array(t * 1e9).astype(jnp.uint64)
60
90
 
61
- # Close f over the `model` parameter.
62
- return system_dynamics(model=model, data=data_rw, **kwargs_closed | kwargs)
91
+ # Evaluate the system dynamics, allowing to override the kwargs originally
92
+ # passed when the closure was created.
93
+ return system_dynamics(
94
+ model=model_f,
95
+ data=data_rw,
96
+ **(system_dynamics_kwargs | kwargs_f),
97
+ )
63
98
 
64
99
  f: jaxsim.integrators.common.SystemDynamics[ODEState, ODEState]
65
100
  return f
@@ -71,101 +106,217 @@ def wrap_system_dynamics_for_integration(
71
106
 
72
107
 
73
108
  @jax.jit
109
+ @js.common.named_scope
74
110
  def system_velocity_dynamics(
75
- model: Model.JaxSimModel,
76
- data: Data.JaxSimModelData,
111
+ model: js.model.JaxSimModel,
112
+ data: js.data.JaxSimModelData,
77
113
  *,
78
- joint_forces: jtp.Vector | None = None,
79
- external_forces: jtp.Vector | None = None,
80
- ) -> tuple[jtp.Vector, jtp.Vector, jtp.Matrix, dict[str, Any]]:
114
+ link_forces: jtp.Vector | None = None,
115
+ joint_force_references: jtp.Vector | None = None,
116
+ ) -> tuple[jtp.Vector, jtp.Vector, dict[str, Any]]:
81
117
  """
82
118
  Compute the dynamics of the system velocity.
83
119
 
84
120
  Args:
85
121
  model: The model to consider.
86
122
  data: The data of the considered model.
87
- joint_forces: The joint forces to apply.
88
- external_forces: The external forces to apply to the links.
123
+ link_forces:
124
+ The 6D forces to apply to the links expressed in the frame corresponding to
125
+ the velocity representation of `data`.
126
+ joint_force_references: The joint force references to apply.
89
127
 
90
128
  Returns:
91
129
  A tuple containing the derivative of the base 6D velocity in inertial-fixed
92
- representation, the derivative of the joint velocities, the derivative of
93
- the material deformation, and the dictionary of auxiliary data returned by
94
- the system dynamics evalutation.
130
+ representation, the derivative of the joint velocities, and auxiliary data
131
+ returned by the system dynamics evaluation.
95
132
  """
96
133
 
97
- # Build joint torques if not provided
98
- τ = (
99
- jnp.atleast_1d(joint_forces.squeeze())
100
- if joint_forces is not None
101
- else jnp.zeros_like(data.joint_positions())
102
- ).astype(float)
103
-
104
- # Build external forces if not provided
105
- f_ext = (
106
- jnp.atleast_2d(external_forces.squeeze())
107
- if external_forces is not None
134
+ # Build link forces if not provided.
135
+ # These forces are expressed in the frame corresponding to the velocity
136
+ # representation of data.
137
+ O_f_L = (
138
+ jnp.atleast_2d(link_forces.squeeze())
139
+ if link_forces is not None
108
140
  else jnp.zeros((model.number_of_links(), 6))
109
141
  ).astype(float)
110
142
 
143
+ # We expect that the 6D forces included in the `link_forces` argument are expressed
144
+ # in the frame corresponding to the velocity representation of `data`.
145
+ references = js.references.JaxSimModelReferences.build(
146
+ model=model,
147
+ link_forces=O_f_L,
148
+ joint_force_references=joint_force_references,
149
+ data=data,
150
+ velocity_representation=data.velocity_representation,
151
+ )
152
+
111
153
  # ======================
112
154
  # Compute contact forces
113
155
  # ======================
114
156
 
115
157
  # Initialize the 6D forces W_f ∈ ℝ^{n_L × 6} applied to links due to contact
116
158
  # with the terrain.
117
- W_f_Li_terrain = jnp.zeros_like(f_ext).astype(float)
118
-
119
- # Initialize the 6D contact forces W_f ∈ ℝ^{n_c × 3} applied to collidable points,
120
- # expressed in the world frame.
121
- W_f_Ci = None
122
-
123
- # Initialize the derivative of the tangential deformation ṁ ∈ ℝ^{n_c × 3}.
124
- ṁ = jnp.zeros_like(data.state.soft_contacts.tangential_deformation).astype(float)
125
-
126
- if model.physics_model.gc.body.size > 0:
127
- # Compute the position and linear velocities (mixed representation) of
128
- # all collidable points belonging to the robot.
129
- W_p_Ci, W_ṗ_Ci = Contact.collidable_point_kinematics(model=model, data=data)
130
-
131
- # Compute the 3D forces applied to each collidable point.
132
- W_f_Ci, = jax.vmap(
133
- lambda p, ṗ, m: jaxsim.physics.algos.soft_contacts.SoftContacts(
134
- parameters=data.soft_contacts_params, terrain=model.terrain
135
- ).contact_model(position=p, velocity=ṗ, tangential_deformation=m)
136
- )(W_p_Ci, W_ṗ_Ci, data.state.soft_contacts.tangential_deformation.T)
137
-
138
- # Sum the forces of all collidable points rigidly attached to a body.
139
- # Since the contact forces W_f_Ci are expressed in the world frame,
140
- # we don't need any coordinate transformation.
141
- W_f_Li_terrain = jax.vmap(
142
- lambda nc: (
143
- jnp.vstack(jnp.equal(model.physics_model.gc.body, nc).astype(int))
144
- * W_f_Ci
145
- ).sum(axis=0)
146
- )(jnp.arange(model.number_of_links()))
159
+ W_f_L_terrain = jnp.zeros_like(O_f_L).astype(float)
160
+
161
+ # Initialize a dictionary of auxiliary data.
162
+ # This dictionary is used to store additional data computed by the contact model.
163
+ aux_data = {}
164
+
165
+ if len(model.kin_dyn_parameters.contact_parameters.body) > 0:
166
+
167
+ with (
168
+ data.switch_velocity_representation(VelRepr.Inertial),
169
+ references.switch_velocity_representation(VelRepr.Inertial),
170
+ ):
171
+
172
+ # Compute the 6D forces W_f ∈ ℝ^{n_c × 6} applied to each collidable point
173
+ # along with contact-specific auxiliary states.
174
+ W_f_C, aux_data = js.contact.collidable_point_dynamics(
175
+ model=model,
176
+ data=data,
177
+ link_forces=references.link_forces(model=model, data=data),
178
+ joint_force_references=references.joint_force_references(model=model),
179
+ )
180
+
181
+ # Compute the 6D forces applied to the links equivalent to the forces applied
182
+ # to the frames associated to the collidable points.
183
+ W_f_L_terrain = model.contact_model.link_forces_from_contact_forces(
184
+ model=model,
185
+ data=data,
186
+ contact_forces=W_f_C,
187
+ )
188
+
189
+ # ===========================
190
+ # Compute system acceleration
191
+ # ===========================
192
+
193
+ # Compute the total link forces.
194
+ with (
195
+ data.switch_velocity_representation(VelRepr.Inertial),
196
+ references.switch_velocity_representation(VelRepr.Inertial),
197
+ ):
198
+
199
+ # Sum the contact forces just computed with the link forces applied by the user.
200
+ references = references.apply_link_forces(
201
+ model=model,
202
+ data=data,
203
+ forces=W_f_L_terrain,
204
+ additive=True,
205
+ )
206
+
207
+ # Get the link forces in inertial-fixed representation.
208
+ f_L_total = references.link_forces(model=model, data=data)
209
+
210
+ # Compute the system acceleration in inertial-fixed representation.
211
+ # This representation is useful for integration purpose.
212
+ W_v̇_WB, s̈ = system_acceleration(
213
+ model=model,
214
+ data=data,
215
+ joint_force_references=joint_force_references,
216
+ link_forces=f_L_total,
217
+ )
218
+
219
+ return W_v̇_WB, s̈, aux_data
220
+
221
+
222
+ def system_acceleration(
223
+ model: js.model.JaxSimModel,
224
+ data: js.data.JaxSimModelData,
225
+ *,
226
+ link_forces: jtp.MatrixLike | None = None,
227
+ joint_force_references: jtp.VectorLike | None = None,
228
+ ) -> tuple[jtp.Vector, jtp.Vector]:
229
+ """
230
+ Compute the system acceleration in the active representation.
231
+
232
+ Args:
233
+ model: The model to consider.
234
+ data: The data of the considered model.
235
+ link_forces:
236
+ The 6D forces to apply to the links expressed in the same
237
+ velocity representation of data.
238
+ joint_force_references: The joint force references to apply.
239
+
240
+ Returns:
241
+ A tuple containing the base 6D acceleration in the active representation
242
+ and the joint accelerations.
243
+ """
244
+
245
+ # ====================
246
+ # Validate input data
247
+ # ====================
248
+
249
+ # Build link forces if not provided.
250
+ f_L = (
251
+ jnp.atleast_2d(link_forces.squeeze())
252
+ if link_forces is not None
253
+ else jnp.zeros((model.number_of_links(), 6))
254
+ ).astype(float)
255
+
256
+ # Build joint torques if not provided.
257
+ τ_references = (
258
+ jnp.atleast_1d(joint_force_references.squeeze())
259
+ if joint_force_references is not None
260
+ else jnp.zeros_like(data.joint_positions())
261
+ ).astype(float)
147
262
 
148
263
  # ====================
149
264
  # Enforce joint limits
150
265
  # ====================
151
266
 
152
- # TODO: enforce joint limits
153
- τ_position_limit = jnp.zeros_like(τ).astype(float)
267
+ τ_position_limit = jnp.zeros_like(τ_references).astype(float)
268
+
269
+ if model.dofs() > 0:
270
+
271
+ # Stiffness and damper parameters for the joint position limits.
272
+ k_j = jnp.array(
273
+ model.kin_dyn_parameters.joint_parameters.position_limit_spring
274
+ ).astype(float)
275
+ d_j = jnp.array(
276
+ model.kin_dyn_parameters.joint_parameters.position_limit_damper
277
+ ).astype(float)
278
+
279
+ # Compute the joint position limit violations.
280
+ lower_violation = jnp.clip(
281
+ data.state.physics_model.joint_positions
282
+ - model.kin_dyn_parameters.joint_parameters.position_limits_min,
283
+ max=0.0,
284
+ )
285
+
286
+ upper_violation = jnp.clip(
287
+ data.state.physics_model.joint_positions
288
+ - model.kin_dyn_parameters.joint_parameters.position_limits_max,
289
+ min=0.0,
290
+ )
291
+
292
+ # Compute the joint position limit torque.
293
+ τ_position_limit -= jnp.diag(k_j) @ (lower_violation + upper_violation)
294
+
295
+ τ_position_limit -= (
296
+ jnp.positive(τ_position_limit)
297
+ * jnp.diag(d_j)
298
+ @ data.state.physics_model.joint_velocities
299
+ )
154
300
 
155
301
  # ====================
156
302
  # Joint friction model
157
303
  # ====================
158
304
 
159
- τ_friction = jnp.zeros_like(τ).astype(float)
305
+ τ_friction = jnp.zeros_like(τ_references).astype(float)
160
306
 
161
307
  if model.dofs() > 0:
162
- # Static and viscous joint friction parameters
163
- kc = jnp.array(list(model.physics_model._joint_friction_static.values()))
164
- kv = jnp.array(list(model.physics_model._joint_friction_viscous.values()))
165
308
 
166
- # Compute the joint friction torque
309
+ # Static and viscous joint friction parameters
310
+ kc = jnp.array(
311
+ model.kin_dyn_parameters.joint_parameters.friction_static
312
+ ).astype(float)
313
+ kv = jnp.array(
314
+ model.kin_dyn_parameters.joint_parameters.friction_viscous
315
+ ).astype(float)
316
+
317
+ # Compute the joint friction torque.
167
318
  τ_friction = -(
168
- jnp.diag(kc) @ jnp.sign(data.state.physics_model.joint_positions)
319
+ jnp.diag(kc) @ jnp.sign(data.state.physics_model.joint_velocities)
169
320
  + jnp.diag(kv) @ data.state.physics_model.joint_velocities
170
321
  )
171
322
 
@@ -173,28 +324,40 @@ def system_velocity_dynamics(
173
324
  # Compute forward dynamics
174
325
  # ========================
175
326
 
176
- # Compute the total joint forces
177
- τ_total = τ + τ_friction + τ_position_limit
327
+ # Compute the total joint forces.
328
+ τ_total = τ_references + τ_friction + τ_position_limit
178
329
 
179
- # Compute the total external 6D forces applied to the links
180
- W_f_L_total = f_ext + W_f_Li_terrain
330
+ # Store the link forces in a references object.
331
+ references = js.references.JaxSimModelReferences.build(
332
+ model=model,
333
+ data=data,
334
+ velocity_representation=data.velocity_representation,
335
+ link_forces=f_L,
336
+ )
181
337
 
338
+ # Compute forward dynamics.
339
+ #
182
340
  # - Joint accelerations: s̈ ∈ ℝⁿ
183
- # - Base inertial-fixed acceleration: W_v̇_WB = (W_p̈_B, W_ω̇_B) ∈ ℝ⁶
184
- with data.switch_velocity_representation(velocity_representation=VelRepr.Inertial):
185
- W_v̇_WB, = Model.forward_dynamics_aba(
186
- model=model,
187
- data=data,
188
- joint_forces=τ_total,
189
- external_forces=W_f_L_total,
190
- )
341
+ # - Base acceleration: v̇_WB ∈ ℝ⁶
342
+ #
343
+ # Note that ABA returns the base acceleration in the velocity representation
344
+ # stored in the `data` object.
345
+ v̇_WB, s̈ = js.model.forward_dynamics_aba(
346
+ model=model,
347
+ data=data,
348
+ joint_forces=τ_total,
349
+ link_forces=references.link_forces(model=model, data=data),
350
+ )
191
351
 
192
- return W_v̇_WB, s̈, ṁ.T, dict()
352
+ return v̇_WB, s̈
193
353
 
194
354
 
195
355
  @jax.jit
356
+ @js.common.named_scope
196
357
  def system_position_dynamics(
197
- model: Model.JaxSimModel, data: Data.JaxSimModelData
358
+ model: js.model.JaxSimModel,
359
+ data: js.data.JaxSimModelData,
360
+ baumgarte_quaternion_regularization: jtp.FloatLike = 1.0,
198
361
  ) -> tuple[jtp.Vector, jtp.Vector, jtp.Vector]:
199
362
  """
200
363
  Compute the dynamics of the system position.
@@ -202,14 +365,16 @@ def system_position_dynamics(
202
365
  Args:
203
366
  model: The model to consider.
204
367
  data: The data of the considered model.
368
+ baumgarte_quaternion_regularization:
369
+ The Baumgarte regularization coefficient for adjusting the quaternion norm.
205
370
 
206
371
  Returns:
207
372
  A tuple containing the derivative of the base position, the derivative of the
208
373
  base quaternion, and the derivative of the joint positions.
209
374
  """
210
375
 
211
- ṡ = data.state.physics_model.joint_velocities
212
- W_Q_B = data.state.physics_model.base_quaternion
376
+ ṡ = data.joint_velocities(model=model)
377
+ W_Q_B = data.base_orientation(dcm=False)
213
378
 
214
379
  with data.switch_velocity_representation(velocity_representation=VelRepr.Mixed):
215
380
  W_ṗ_B = data.base_velocity()[0:3]
@@ -221,18 +386,21 @@ def system_position_dynamics(
221
386
  quaternion=W_Q_B,
222
387
  omega=W_ω_WB,
223
388
  omega_in_body_fixed=False,
389
+ K=baumgarte_quaternion_regularization,
224
390
  ).squeeze()
225
391
 
226
392
  return W_ṗ_B, W_Q̇_B, ṡ
227
393
 
228
394
 
229
395
  @jax.jit
396
+ @js.common.named_scope
230
397
  def system_dynamics(
231
- model: Model.JaxSimModel,
232
- data: Data.JaxSimModelData,
398
+ model: js.model.JaxSimModel,
399
+ data: js.data.JaxSimModelData,
233
400
  *,
234
- joint_forces: jtp.Vector | None = None,
235
- external_forces: jtp.Vector | None = None,
401
+ link_forces: jtp.Vector | None = None,
402
+ joint_force_references: jtp.Vector | None = None,
403
+ baumgarte_quaternion_regularization: jtp.FloatLike = 1.0,
236
404
  ) -> tuple[ODEState, dict[str, Any]]:
237
405
  """
238
406
  Compute the dynamics of the system.
@@ -240,8 +408,13 @@ def system_dynamics(
240
408
  Args:
241
409
  model: The model to consider.
242
410
  data: The data of the considered model.
243
- joint_forces: The joint forces to apply.
244
- external_forces: The external forces to apply to the links.
411
+ link_forces:
412
+ The 6D forces to apply to the links expressed in the frame corresponding to
413
+ the velocity representation of `data`.
414
+ joint_force_references: The joint force references to apply.
415
+ baumgarte_quaternion_regularization:
416
+ The Baumgarte regularization coefficient used to adjust the norm of the
417
+ quaternion (only used in integrators not operating on the SO(3) manifold).
245
418
 
246
419
  Returns:
247
420
  A tuple with an `ODEState` object storing in each of its attributes the
@@ -250,31 +423,53 @@ def system_dynamics(
250
423
  """
251
424
 
252
425
  # Compute the accelerations and the material deformation rate.
253
- W_v̇_WB, s̈, ṁ, aux_dict = system_velocity_dynamics(
426
+ W_v̇_WB, s̈, aux_dict = system_velocity_dynamics(
254
427
  model=model,
255
428
  data=data,
256
- joint_forces=joint_forces,
257
- external_forces=external_forces,
429
+ joint_force_references=joint_force_references,
430
+ link_forces=link_forces,
258
431
  )
259
432
 
433
+ # Initialize the dictionary storing the derivative of the additional state variables
434
+ # that extend the state vector of the integrated ODE system.
435
+ extended_ode_state = {}
436
+
437
+ match model.contact_model:
438
+
439
+ case contacts.SoftContacts():
440
+ extended_ode_state["tangential_deformation"] = aux_dict["m_dot"]
441
+
442
+ case contacts.ViscoElasticContacts():
443
+
444
+ extended_ode_state["tangential_deformation"] = jnp.zeros_like(
445
+ data.state.extended["tangential_deformation"]
446
+ )
447
+
448
+ case contacts.RigidContacts() | contacts.RelaxedRigidContacts():
449
+ pass
450
+
451
+ case _:
452
+ raise ValueError(f"Invalid contact model: {model.contact_model}")
453
+
260
454
  # Extract the velocities.
261
- W_ṗ_B, W_Q̇_B, ṡ = system_position_dynamics(model=model, data=data)
455
+ W_ṗ_B, W_Q̇_B, ṡ = system_position_dynamics(
456
+ model=model,
457
+ data=data,
458
+ baumgarte_quaternion_regularization=baumgarte_quaternion_regularization,
459
+ )
262
460
 
263
461
  # Create an ODEState object populated with the derivative of each leaf.
264
462
  # Our integrators, operating on generic pytrees, will be able to handle it
265
463
  # automatically as state derivative.
266
- ode_state_derivative = ODEState.build(
267
- physics_model_state=PhysicsModelState.build(
268
- base_position=W_ṗ_B,
269
- base_quaternion=W_Q̇_B,
270
- joint_positions=ṡ,
271
- base_linear_velocity=W_v̇_WB[0:3],
272
- base_angular_velocity=W_v̇_WB[3:6],
273
- joint_velocities=s̈,
274
- ),
275
- soft_contacts_state=SoftContactsState.build(
276
- tangential_deformation=ṁ,
277
- ),
464
+ ode_state_derivative = ODEState.build_from_jaxsim_model(
465
+ model=model,
466
+ base_position=W_ṗ_B,
467
+ base_quaternion=W_Q̇_B,
468
+ joint_positions=ṡ,
469
+ base_linear_velocity=W_v̇_WB[0:3],
470
+ base_angular_velocity=W_v̇_WB[3:6],
471
+ joint_velocities=s̈,
472
+ **extended_ode_state,
278
473
  )
279
474
 
280
475
  return ode_state_derivative, aux_dict