jaxsim 0.2.dev191__py3-none-any.whl → 0.6.1.dev2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (109)
  1. jaxsim/__init__.py +73 -22
  2. jaxsim/_version.py +2 -2
  3. jaxsim/api/__init__.py +13 -1
  4. jaxsim/api/com.py +423 -0
  5. jaxsim/api/common.py +48 -19
  6. jaxsim/api/contact.py +604 -52
  7. jaxsim/api/data.py +308 -163
  8. jaxsim/api/frame.py +471 -0
  9. jaxsim/api/joint.py +166 -37
  10. jaxsim/api/kin_dyn_parameters.py +901 -0
  11. jaxsim/api/link.py +277 -78
  12. jaxsim/api/model.py +1572 -362
  13. jaxsim/api/ode.py +324 -133
  14. jaxsim/api/ode_data.py +401 -0
  15. jaxsim/api/references.py +216 -80
  16. jaxsim/exceptions.py +80 -0
  17. jaxsim/integrators/__init__.py +2 -2
  18. jaxsim/integrators/common.py +191 -107
  19. jaxsim/integrators/fixed_step.py +97 -102
  20. jaxsim/integrators/variable_step.py +706 -0
  21. jaxsim/logging.py +1 -2
  22. jaxsim/math/__init__.py +13 -0
  23. jaxsim/math/adjoint.py +57 -22
  24. jaxsim/math/cross.py +16 -7
  25. jaxsim/math/inertia.py +10 -8
  26. jaxsim/math/joint_model.py +289 -0
  27. jaxsim/math/quaternion.py +54 -20
  28. jaxsim/math/rotation.py +27 -21
  29. jaxsim/math/skew.py +16 -5
  30. jaxsim/math/transform.py +102 -0
  31. jaxsim/math/utils.py +31 -0
  32. jaxsim/mujoco/__init__.py +2 -1
  33. jaxsim/mujoco/loaders.py +216 -29
  34. jaxsim/mujoco/model.py +163 -33
  35. jaxsim/mujoco/utils.py +228 -0
  36. jaxsim/mujoco/visualizer.py +107 -22
  37. jaxsim/parsers/__init__.py +0 -1
  38. jaxsim/parsers/descriptions/__init__.py +8 -2
  39. jaxsim/parsers/descriptions/collision.py +83 -26
  40. jaxsim/parsers/descriptions/joint.py +80 -87
  41. jaxsim/parsers/descriptions/link.py +58 -31
  42. jaxsim/parsers/descriptions/model.py +101 -68
  43. jaxsim/parsers/kinematic_graph.py +606 -229
  44. jaxsim/parsers/rod/meshes.py +104 -0
  45. jaxsim/parsers/rod/parser.py +125 -82
  46. jaxsim/parsers/rod/utils.py +127 -82
  47. jaxsim/rbda/__init__.py +11 -0
  48. jaxsim/rbda/aba.py +289 -0
  49. jaxsim/rbda/collidable_points.py +156 -0
  50. jaxsim/rbda/contacts/__init__.py +13 -0
  51. jaxsim/rbda/contacts/common.py +313 -0
  52. jaxsim/rbda/contacts/relaxed_rigid.py +605 -0
  53. jaxsim/rbda/contacts/rigid.py +462 -0
  54. jaxsim/rbda/contacts/soft.py +480 -0
  55. jaxsim/rbda/contacts/visco_elastic.py +1066 -0
  56. jaxsim/rbda/crba.py +167 -0
  57. jaxsim/rbda/forward_kinematics.py +117 -0
  58. jaxsim/rbda/jacobian.py +330 -0
  59. jaxsim/rbda/rnea.py +235 -0
  60. jaxsim/rbda/utils.py +160 -0
  61. jaxsim/terrain/__init__.py +2 -0
  62. jaxsim/terrain/terrain.py +238 -0
  63. jaxsim/typing.py +24 -24
  64. jaxsim/utils/__init__.py +1 -4
  65. jaxsim/utils/jaxsim_dataclass.py +289 -34
  66. jaxsim/utils/tracing.py +5 -11
  67. jaxsim/utils/wrappers.py +159 -0
  68. {jaxsim-0.2.dev191.dist-info → jaxsim-0.6.1.dev2.dist-info}/LICENSE +1 -1
  69. jaxsim-0.6.1.dev2.dist-info/METADATA +465 -0
  70. jaxsim-0.6.1.dev2.dist-info/RECORD +74 -0
  71. {jaxsim-0.2.dev191.dist-info → jaxsim-0.6.1.dev2.dist-info}/WHEEL +1 -1
  72. jaxsim/high_level/__init__.py +0 -2
  73. jaxsim/high_level/common.py +0 -11
  74. jaxsim/high_level/joint.py +0 -148
  75. jaxsim/high_level/link.py +0 -259
  76. jaxsim/high_level/model.py +0 -1686
  77. jaxsim/math/conv.py +0 -114
  78. jaxsim/math/joint.py +0 -102
  79. jaxsim/math/plucker.py +0 -100
  80. jaxsim/physics/__init__.py +0 -12
  81. jaxsim/physics/algos/__init__.py +0 -0
  82. jaxsim/physics/algos/aba.py +0 -254
  83. jaxsim/physics/algos/aba_motors.py +0 -284
  84. jaxsim/physics/algos/crba.py +0 -154
  85. jaxsim/physics/algos/forward_kinematics.py +0 -79
  86. jaxsim/physics/algos/jacobian.py +0 -98
  87. jaxsim/physics/algos/rnea.py +0 -180
  88. jaxsim/physics/algos/rnea_motors.py +0 -196
  89. jaxsim/physics/algos/soft_contacts.py +0 -523
  90. jaxsim/physics/algos/terrain.py +0 -78
  91. jaxsim/physics/algos/utils.py +0 -69
  92. jaxsim/physics/model/__init__.py +0 -0
  93. jaxsim/physics/model/ground_contact.py +0 -53
  94. jaxsim/physics/model/physics_model.py +0 -388
  95. jaxsim/physics/model/physics_model_state.py +0 -283
  96. jaxsim/simulation/__init__.py +0 -4
  97. jaxsim/simulation/integrators.py +0 -393
  98. jaxsim/simulation/ode.py +0 -290
  99. jaxsim/simulation/ode_data.py +0 -96
  100. jaxsim/simulation/ode_integration.py +0 -62
  101. jaxsim/simulation/simulator.py +0 -543
  102. jaxsim/simulation/simulator_callbacks.py +0 -79
  103. jaxsim/simulation/utils.py +0 -15
  104. jaxsim/sixd/__init__.py +0 -2
  105. jaxsim/utils/oop.py +0 -536
  106. jaxsim/utils/vmappable.py +0 -117
  107. jaxsim-0.2.dev191.dist-info/METADATA +0 -184
  108. jaxsim-0.2.dev191.dist-info/RECORD +0 -81
  109. {jaxsim-0.2.dev191.dist-info → jaxsim-0.6.1.dev2.dist-info}/top_level.txt +0 -0
@@ -1,13 +1,16 @@
1
1
  import abc
2
2
  import dataclasses
3
- from typing import Any, ClassVar, Generic, Protocol, Type, TypeVar
3
+ from typing import Any, ClassVar, Generic, Protocol, TypeVar
4
4
 
5
5
  import jax
6
6
  import jax.numpy as jnp
7
7
  import jax_dataclasses
8
8
  from jax_dataclasses import Static
9
9
 
10
+ import jaxsim.api as js
11
+ import jaxsim.math
10
12
  import jaxsim.typing as jtp
13
+ from jaxsim import exceptions, logging
11
14
  from jaxsim.utils.jaxsim_dataclass import JaxsimDataclass, Mutability
12
15
 
13
16
  try:
@@ -25,17 +28,33 @@ except ImportError:
25
28
  # Generic types
26
29
  # =============
27
30
 
28
- Time = jax.typing.ArrayLike
29
- TimeStep = jax.typing.ArrayLike
31
+ Time = jtp.FloatLike
32
+ TimeStep = jtp.FloatLike
30
33
  State = NextState = TypeVar("State")
31
34
  StateDerivative = TypeVar("StateDerivative")
32
35
  PyTreeType = TypeVar("PyTreeType", bound=jtp.PyTree)
33
36
 
34
37
 
35
38
  class SystemDynamics(Protocol[State, StateDerivative]):
39
+ """
40
+ Protocol defining the system dynamics.
41
+ """
42
+
36
43
  def __call__(
37
44
  self, x: State, t: Time, **kwargs
38
- ) -> tuple[StateDerivative, dict[str, Any]]: ...
45
+ ) -> tuple[StateDerivative, dict[str, Any]]:
46
+ """
47
+ Compute the state derivative of the system.
48
+
49
+ Args:
50
+ x: The state of the system.
51
+ t: The time of the system.
52
+ **kwargs: Additional keyword arguments.
53
+
54
+ Returns:
55
+ The state derivative of the system and the auxiliary dictionary.
56
+ """
57
+ pass
39
58
 
40
59
 
41
60
  # =======================
@@ -45,20 +64,20 @@ class SystemDynamics(Protocol[State, StateDerivative]):
45
64
 
46
65
  @jax_dataclasses.pytree_dataclass
47
66
  class Integrator(JaxsimDataclass, abc.ABC, Generic[State, StateDerivative]):
48
-
49
- AuxDictDynamicsKey: ClassVar[str] = "aux_dict_dynamics"
67
+ """
68
+ Factory class for integrators.
69
+ """
50
70
 
51
71
  dynamics: Static[SystemDynamics[State, StateDerivative]] = dataclasses.field(
52
72
  repr=False, hash=False, compare=False, kw_only=True
53
73
  )
54
74
 
55
- params: dict[str, Any] = dataclasses.field(
56
- default_factory=dict, repr=False, hash=False, compare=False, kw_only=True
57
- )
58
-
59
75
  @classmethod
60
76
  def build(
61
- cls: Type[Self], *, dynamics: SystemDynamics[State, StateDerivative], **kwargs
77
+ cls: type[Self],
78
+ *,
79
+ dynamics: SystemDynamics[State, StateDerivative],
80
+ **kwargs,
62
81
  ) -> Self:
63
82
  """
64
83
  Build the integrator object.
@@ -71,7 +90,7 @@ class Integrator(JaxsimDataclass, abc.ABC, Generic[State, StateDerivative]):
71
90
  The integrator object.
72
91
  """
73
92
 
74
- return cls(dynamics=dynamics, **kwargs) # noqa
93
+ return cls(dynamics=dynamics, **kwargs)
75
94
 
76
95
  def step(
77
96
  self,
@@ -79,9 +98,9 @@ class Integrator(JaxsimDataclass, abc.ABC, Generic[State, StateDerivative]):
79
98
  t0: Time,
80
99
  dt: TimeStep,
81
100
  *,
82
- params: dict[str, Any],
101
+ metadata: dict[str, Any] | None = None,
83
102
  **kwargs,
84
- ) -> tuple[State, dict[str, Any]]:
103
+ ) -> tuple[NextState, dict[str, Any]]:
85
104
  """
86
105
  Perform a single integration step.
87
106
 
@@ -89,25 +108,30 @@ class Integrator(JaxsimDataclass, abc.ABC, Generic[State, StateDerivative]):
89
108
  x0: The initial state of the system.
90
109
  t0: The initial time of the system.
91
110
  dt: The time step of the integration.
92
- params: The auxiliary dictionary of the integrator.
111
+ metadata: The state auxiliary dictionary of the integrator.
93
112
  **kwargs: Additional keyword arguments.
94
113
 
95
114
  Returns:
96
115
  The final state of the system and the updated auxiliary dictionary.
97
116
  """
98
117
 
99
- with self.editable(validate=False) as integrator:
100
- integrator.params = params
101
-
102
- with integrator.mutable_context(mutability=Mutability.MUTABLE):
103
- xf = integrator(x0, t0, dt, **kwargs)
118
+ metadata = metadata if metadata is not None else {}
104
119
 
105
- assert Integrator.AuxDictDynamicsKey in integrator.params
120
+ with self.mutable_context(mutability=Mutability.MUTABLE) as integrator:
121
+ xf, metadata_step = integrator(x0, t0, dt, **kwargs)
106
122
 
107
- return xf, integrator.params
123
+ return (
124
+ xf,
125
+ metadata | metadata_step,
126
+ )
108
127
 
109
128
  @abc.abstractmethod
110
- def __call__(self, x0: State, t0: Time, dt: TimeStep, **kwargs) -> NextState:
129
+ def __call__(
130
+ self, x0: State, t0: Time, dt: TimeStep, **kwargs
131
+ ) -> tuple[NextState, dict[str, Any]]:
132
+ """
133
+ Perform a single integration step.
134
+ """
111
135
  pass
112
136
 
113
137
  def init(
@@ -116,56 +140,44 @@ class Integrator(JaxsimDataclass, abc.ABC, Generic[State, StateDerivative]):
116
140
  t0: Time,
117
141
  dt: TimeStep,
118
142
  *,
119
- key: jax.Array | None = None,
143
+ include_dynamics_aux_dict: bool = False,
120
144
  **kwargs,
121
145
  ) -> dict[str, Any]:
122
146
  """
123
- Initialize the integrator.
124
-
125
- Args:
126
- x0: The initial state of the system.
127
- t0: The initial time of the system.
128
- dt: The time step of the integration.
129
- key: An optional random key to initialize the integrator.
130
-
131
- Returns:
132
- The auxiliary dictionary of the integrator.
133
-
134
- Note:
135
- This method should have the same signature as the inherited `__call__`
136
- method, including additional kwargs.
137
-
138
- Note:
139
- If the integrator supports FSAL, the pair `(x0, t0)` must match the real
140
- initial state and time of the system, otherwise the initial derivative of
141
- the first step will be wrong.
147
+ Initialize the integrator. This method is deprecated.
142
148
  """
143
149
 
144
- _, aux_dict_dynamics = self.dynamics(x0, t0)
145
-
146
- with self.editable(validate=False) as integrator:
147
- _ = integrator(x0, t0, dt, **kwargs)
148
- aux_dict_step = integrator.params
149
-
150
- if Integrator.AuxDictDynamicsKey in aux_dict_dynamics:
151
- msg = "You cannot create a key '{}' in the __call__ method."
152
- raise KeyError(msg.format(Integrator.AuxDictDynamicsKey))
150
+ logging.warning(
151
+ "The 'init' method has been deprecated. There is no need to call it."
152
+ )
153
153
 
154
- return {Integrator.AuxDictDynamicsKey: aux_dict_dynamics} | aux_dict_step
154
+ return {}
155
155
 
156
156
 
157
157
  @jax_dataclasses.pytree_dataclass
158
158
  class ExplicitRungeKutta(Integrator[PyTreeType, PyTreeType], Generic[PyTreeType]):
159
+ """
160
+ Base class for explicit Runge-Kutta integrators.
161
+
162
+ Attributes:
163
+ A: The Runge-Kutta matrix.
164
+ b: The weights coefficients.
165
+ c: The nodes coefficients.
166
+ order_of_bT_rows: The order of the solution.
167
+ row_index_of_solution: The row of the integration output corresponding to the final solution.
168
+ fsal_enabled_if_supported: Whether to enable the FSAL property, if supported.
169
+ index_of_fsal: The index of the intermediate derivative to be used as the first derivative of the next iteration.
170
+ """
159
171
 
160
172
  # The Runge-Kutta matrix.
161
- A: ClassVar[jax.typing.ArrayLike]
173
+ A: jtp.Matrix
162
174
 
163
175
  # The weights coefficients.
164
176
  # Note that in practice we typically use its transpose `b.transpose()`.
165
- b: ClassVar[jax.typing.ArrayLike]
177
+ b: jtp.Matrix
166
178
 
167
179
  # The nodes coefficients.
168
- c: ClassVar[jax.typing.ArrayLike]
180
+ c: jtp.Vector
169
181
 
170
182
  # Define the order of the solution.
171
183
  # It should have as many elements as the number of rows of `b.transpose()`.
@@ -181,16 +193,22 @@ class ExplicitRungeKutta(Integrator[PyTreeType, PyTreeType], Generic[PyTreeType]
181
193
 
182
194
  @property
183
195
  def has_fsal(self) -> bool:
196
+ """
197
+ Check if the integrator supports the FSAL property.
198
+ """
184
199
  return self.fsal_enabled_if_supported and self.index_of_fsal is not None
185
200
 
186
201
  @property
187
202
  def order(self) -> int:
203
+ """
204
+ Return the order of the integrator.
205
+ """
188
206
  return self.order_of_bT_rows[self.row_index_of_solution]
189
207
 
190
208
  @override
191
209
  @classmethod
192
210
  def build(
193
- cls: Type[Self],
211
+ cls: type[Self],
194
212
  *,
195
213
  dynamics: SystemDynamics[State, StateDerivative],
196
214
  fsal_enabled_if_supported: jtp.BoolLike = True,
@@ -208,37 +226,32 @@ class ExplicitRungeKutta(Integrator[PyTreeType, PyTreeType], Generic[PyTreeType]
208
226
  Returns:
209
227
  The integrator object.
210
228
  """
211
-
212
- # Adjust the shape of the tableau coefficients.
213
- c = jnp.atleast_1d(cls.c.squeeze())
214
- b = jnp.atleast_2d(jnp.vstack(cls.b.squeeze()))
215
- A = jnp.atleast_2d(cls.A.squeeze())
229
+ A = cls.__dataclass_fields__["A"].default_factory()
230
+ b = cls.__dataclass_fields__["b"].default_factory()
231
+ c = cls.__dataclass_fields__["c"].default_factory()
216
232
 
217
233
  # Check validity of the Butcher tableau.
218
234
  if not ExplicitRungeKutta.butcher_tableau_is_valid(A=A, b=b, c=c):
219
235
  raise ValueError("The Butcher tableau of this class is not valid.")
220
236
 
221
- # Store the adjusted shapes of the tableau coefficients.
222
- cls.c = c
223
- cls.b = b
224
- cls.A = A
225
-
226
237
  # Check that b.T has enough rows based on the configured index of the solution.
227
- if cls.row_index_of_solution >= cls.b.T.shape[0]:
238
+ if cls.row_index_of_solution >= b.T.shape[0]:
228
239
  msg = "The index of the solution ({}-th row of `b.T`) is out of range ({})."
229
- raise ValueError(msg.format(cls.row_index_of_solution, cls.b.T.shape[0]))
240
+ raise ValueError(msg.format(cls.row_index_of_solution, b.T.shape[0]))
230
241
 
231
242
  # Check that the tuple containing the order of the b.T rows matches the number
232
243
  # of the b.T rows.
233
- if len(cls.order_of_bT_rows) != cls.b.T.shape[0]:
244
+ if len(cls.order_of_bT_rows) != b.T.shape[0]:
234
245
  msg = "Wrong size of 'order_of_bT_rows' ({}), should be {}."
235
- raise ValueError(msg.format(len(cls.order_of_bT_rows), cls.b.T.shape[0]))
246
+ raise ValueError(msg.format(len(cls.order_of_bT_rows), b.T.shape[0]))
236
247
 
237
248
  # Check if the Butcher tableau supports FSAL (first-same-as-last).
238
249
  # If it does, store the index of the intermediate derivative to be used as the
239
250
  # first derivative of the next iteration.
240
- has_fsal, index_of_fsal = ExplicitRungeKutta.butcher_tableau_supports_fsal(
241
- A=cls.A, b=cls.b, c=cls.c, index_of_solution=cls.row_index_of_solution
251
+ has_fsal, index_of_fsal = ( # noqa: F841
252
+ ExplicitRungeKutta.butcher_tableau_supports_fsal(
253
+ A=A, b=b, c=c, index_of_solution=cls.row_index_of_solution
254
+ )
242
255
  )
243
256
 
244
257
  # Build the integrator object.
@@ -251,15 +264,22 @@ class ExplicitRungeKutta(Integrator[PyTreeType, PyTreeType], Generic[PyTreeType]
251
264
 
252
265
  return integrator
253
266
 
254
- def __call__(self, x0: State, t0: Time, dt: TimeStep, **kwargs) -> NextState:
267
+ def __call__(
268
+ self, x0: State, t0: Time, dt: TimeStep, **kwargs
269
+ ) -> tuple[NextState, dict[str, Any]]:
270
+ """
271
+ Perform a single integration step.
272
+ """
255
273
 
256
274
  # Here z is a batched state with as many batch elements as b.T rows.
257
275
  # Note that z has multiple batches only if b.T has more than one row,
258
276
  # e.g. in Butcher tableau of embedded schemes.
259
- z = self._compute_next_state(x0=x0, t0=t0, dt=dt, **kwargs)
277
+ z, aux_dict = self._compute_next_state(x0=x0, t0=t0, dt=dt, **kwargs)
260
278
 
261
279
  # The next state is the batch element located at the configured index of solution.
262
- return jax.tree_util.tree_map(lambda l: l[self.row_index_of_solution], z)
280
+ next_state = jax.tree.map(lambda l: l[self.row_index_of_solution], z)
281
+
282
+ return next_state, aux_dict
263
283
 
264
284
  @classmethod
265
285
  def integrate_rk_stage(
@@ -294,13 +314,13 @@ class ExplicitRungeKutta(Integrator[PyTreeType, PyTreeType], Generic[PyTreeType]
294
314
  """
295
315
 
296
316
  op = lambda x0_leaf, k_leaf: x0_leaf + dt * k_leaf
297
- return jax.tree_util.tree_map(op, x0, k)
317
+ return jax.tree.map(op, x0, k)
298
318
 
299
319
  @classmethod
300
320
  def post_process_state(
301
321
  cls, x0: State, t0: Time, xf: NextState, dt: TimeStep
302
322
  ) -> NextState:
303
- """
323
+ r"""
304
324
  Post-process the integrated state at :math:`t_f = t_0 + \Delta t`.
305
325
 
306
326
  Args:
@@ -317,7 +337,7 @@ class ExplicitRungeKutta(Integrator[PyTreeType, PyTreeType], Generic[PyTreeType]
317
337
 
318
338
  def _compute_next_state(
319
339
  self, x0: State, t0: Time, dt: TimeStep, **kwargs
320
- ) -> NextState:
340
+ ) -> tuple[NextState, dict[str, Any]]:
321
341
  """
322
342
  Compute the next state of the system, returning all the output states.
323
343
 
@@ -337,33 +357,42 @@ class ExplicitRungeKutta(Integrator[PyTreeType, PyTreeType], Generic[PyTreeType]
337
357
  b = self.b
338
358
  A = self.A
339
359
 
360
+ # Extract metadata from the kwargs.
361
+ metadata = kwargs.pop("metadata", {})
362
+
340
363
  # Close f over optional kwargs.
341
364
  f = lambda x, t: self.dynamics(x=x, t=t, **kwargs)
342
365
 
343
366
  # Initialize the carry of the for loop with the stacked kᵢ vectors.
344
- carry0 = jax.tree_util.tree_map(
345
- lambda l: jnp.repeat(jnp.zeros_like(l)[jnp.newaxis, ...], c.size, axis=0),
346
- x0,
367
+ carry0 = jax.tree.map(
368
+ lambda l: jnp.zeros((c.size, *l.shape), dtype=l.dtype), x0
347
369
  )
348
370
 
349
- # Apply FSAL property by passing ẋ0 = f(x0, t0) from the previous iteration.
350
- get_ẋ0 = lambda: self.params.get("dxdt0", f(x0, t0)[0])
371
+ # Closure on metadata to either evaluate the dynamics at the initial state
372
+ # or to use the previous state derivative (only integrators supporting FSAL).
373
+ def get_ẋ0_and_aux_dict() -> tuple[StateDerivative, dict[str, Any]]:
374
+ ẋ0, aux_dict = f(x0, t0)
375
+ return metadata.get("dxdt0", ẋ0), aux_dict
351
376
 
352
377
  # We use a `jax.lax.scan` to compile the `f` function only once.
353
378
  # Otherwise, if we compute e.g. for RK4 sequentially, the jit-compiled code
354
379
  # would include 4 repetitions of the `f` logic, making everything extremely slow.
355
- def scan_body(carry: jax.Array, i: int | jax.Array) -> tuple[jax.Array, None]:
356
- """"""
380
+ def scan_body(
381
+ carry: jax.Array, i: int | jax.Array
382
+ ) -> tuple[jax.Array, dict[str, Any]]:
383
+ """
384
+ Compute the kᵢ derivative of the Runge-Kutta stage.
385
+ """
357
386
 
358
387
  # Unpack the carry, i.e. the stacked kᵢ vectors.
359
388
  K = carry
360
389
 
361
390
  # Define the computation of the Runge-Kutta stage.
362
- def compute_ki() -> jax.Array:
391
+ def compute_ki() -> tuple[jax.Array, dict[str, Any]]:
363
392
 
364
- # Compute ∑ⱼ aᵢⱼ k
393
+ # Compute ∑ⱼ aᵢⱼ kⱼ.
365
394
  op_sum_ak = lambda k: jnp.einsum("s,s...->...", A[i], k)
366
- sum_ak = jax.tree_util.tree_map(op_sum_ak, K)
395
+ sum_ak = jax.tree.map(op_sum_ak, K)
367
396
 
368
397
  # Compute the next state for the kᵢ evaluation.
369
398
  # Note that this is not a Δt integration since aᵢⱼ could be fractional.
@@ -372,25 +401,26 @@ class ExplicitRungeKutta(Integrator[PyTreeType, PyTreeType], Generic[PyTreeType]
372
401
  # Compute the next time for the kᵢ evaluation.
373
402
  ti = t0 + c[i] * Δt
374
403
 
375
- # This is kᵢ = f(xᵢ, tᵢ).
376
- return f(xi, ti)[0]
404
+ # Evaluate the dynamics.
405
+ ki, aux_dict = f(xi, ti)
406
+ return ki, aux_dict
377
407
 
378
408
  # This selector enables FSAL property in the first iteration (i=0).
379
- ki = jax.lax.cond(
409
+ ki, aux_dict = jax.lax.cond(
380
410
  pred=jnp.logical_and(i == 0, self.has_fsal),
381
- true_fun=get_ẋ0,
411
+ true_fun=get_ẋ0_and_aux_dict,
382
412
  false_fun=compute_ki,
383
413
  )
384
414
 
385
415
  # Store the kᵢ derivative in K.
386
416
  op = lambda l_k, l_ki: l_k.at[i].set(l_ki)
387
- K = jax.tree_util.tree_map(op, K, ki)
417
+ K = jax.tree.map(op, K, ki)
388
418
 
389
419
  carry = K
390
- return carry, None
420
+ return carry, aux_dict
391
421
 
392
422
  # Compute the state derivatives kᵢ.
393
- K, _ = jax.lax.scan(
423
+ K, aux_dict = jax.lax.scan(
394
424
  f=scan_body,
395
425
  init=carry0,
396
426
  xs=jnp.arange(c.size),
@@ -398,12 +428,13 @@ class ExplicitRungeKutta(Integrator[PyTreeType, PyTreeType], Generic[PyTreeType]
398
428
 
399
429
  # Update the FSAL property for the next iteration.
400
430
  if self.has_fsal:
401
- self.params["dxdt0"] = jax.tree_map(lambda l: l[self.index_of_fsal], K)
431
+ # Store the first derivative of the next step in the metadata.
432
+ metadata["dxdt0"] = jax.tree.map(lambda l: l[self.index_of_fsal], K)
402
433
 
403
434
  # Compute the output state.
404
435
  # Note that z contains as many new states as the rows of `b.T`.
405
436
  op = lambda x0, k: x0 + Δt * jnp.einsum("zs,s...->z...", b.T, k)
406
- z = jax.tree_util.tree_map(op, x0, K)
437
+ z = jax.tree.map(op, x0, K)
407
438
 
408
439
  # Transform the final state of the integration.
409
440
  # This allows to inject custom logic, if needed.
@@ -411,11 +442,11 @@ class ExplicitRungeKutta(Integrator[PyTreeType, PyTreeType], Generic[PyTreeType]
411
442
  lambda xf: self.post_process_state(x0=x0, t0=t0, xf=xf, dt=dt)
412
443
  )(z)
413
444
 
414
- return z_transformed
445
+ return z_transformed, aux_dict | {"metadata": metadata}
415
446
 
416
447
  @staticmethod
417
448
  def butcher_tableau_is_valid(
418
- A: jax.typing.ArrayLike, b: jax.typing.ArrayLike, c: jax.typing.ArrayLike
449
+ A: jtp.Matrix, b: jtp.Matrix, c: jtp.Vector
419
450
  ) -> jtp.Bool:
420
451
  """
421
452
  Check if the Butcher tableau is valid.
@@ -441,7 +472,7 @@ class ExplicitRungeKutta(Integrator[PyTreeType, PyTreeType], Generic[PyTreeType]
441
472
  return valid
442
473
 
443
474
  @staticmethod
444
- def butcher_tableau_is_explicit(A: jax.typing.ArrayLike) -> jtp.Bool:
475
+ def butcher_tableau_is_explicit(A: jtp.Matrix) -> jtp.Bool:
445
476
  """
446
477
  Check if the Butcher tableau corresponds to an explicit integration scheme.
447
478
 
@@ -456,11 +487,11 @@ class ExplicitRungeKutta(Integrator[PyTreeType, PyTreeType], Generic[PyTreeType]
456
487
 
457
488
  @staticmethod
458
489
  def butcher_tableau_supports_fsal(
459
- A: jax.typing.ArrayLike,
460
- b: jax.typing.ArrayLike,
461
- c: jax.typing.ArrayLike,
490
+ A: jtp.Matrix,
491
+ b: jtp.Matrix,
492
+ c: jtp.Vector,
462
493
  index_of_solution: jtp.IntLike = 0,
463
- ) -> [bool, int | None]:
494
+ ) -> tuple[bool, int | None]:
464
495
  """
465
496
  Check if the Butcher tableau supports the FSAL (first-same-as-last) property.
466
497
 
@@ -481,7 +512,7 @@ class ExplicitRungeKutta(Integrator[PyTreeType, PyTreeType], Generic[PyTreeType]
481
512
  raise ValueError("The Butcher tableau is not valid.")
482
513
 
483
514
  if not ExplicitRungeKutta.butcher_tableau_is_explicit(A=A):
484
- return False
515
+ return False, None
485
516
 
486
517
  if index_of_solution >= b.T.shape[0]:
487
518
  msg = "The index of the solution (i-th row of `b.T`) is out of range."
@@ -505,4 +536,57 @@ class ExplicitRungeKutta(Integrator[PyTreeType, PyTreeType], Generic[PyTreeType]
505
536
  # Return the index of the row of A providing the fsal derivative (that is the
506
537
  # possibly intermediate kᵢ derivative).
507
538
  # Note that if multiple rows match (it should not), we return the first match.
508
- return True, int(jnp.where(rows_of_A_with_fsal == True)[0].tolist()[0])
539
+ return True, int(jnp.where(rows_of_A_with_fsal)[0].tolist()[0])
540
+
541
+
542
+ class ExplicitRungeKuttaSO3Mixin:
543
+ """
544
+ Mixin class to apply over explicit RK integrators defined on
545
+ `PyTreeType = ODEState` to integrate the quaternion on SO(3).
546
+ """
547
+
548
+ @classmethod
549
+ def post_process_state(
550
+ cls, x0: js.ode_data.ODEState, t0: Time, xf: js.ode_data.ODEState, dt: TimeStep
551
+ ) -> js.ode_data.ODEState:
552
+ r"""
553
+ Post-process the integrated state at :math:`t_f = t_0 + \Delta t` so that the
554
+ quaternion is normalized.
555
+
556
+ Args:
557
+ x0: The initial state of the system.
558
+ t0: The initial time of the system.
559
+ xf: The final state of the system obtain through the integration.
560
+ dt: The time step used for the integration.
561
+ """
562
+
563
+ # Extract the initial base quaternion.
564
+ W_Q_B_t0 = x0.physics_model.base_quaternion
565
+
566
+ # We assume that the initial quaternion is already unary.
567
+ exceptions.raise_runtime_error_if(
568
+ condition=~jnp.allclose(W_Q_B_t0.dot(W_Q_B_t0), 1.0),
569
+ msg="The SO(3) integrator received a quaternion at t0 that is not unary.",
570
+ )
571
+
572
+ # Get the angular velocity ω to integrate the quaternion.
573
+ # This velocity ω[t0] is computed in the previous timestep by averaging the kᵢ
574
+ # corresponding to the active RK-based scheme. Therefore, by using the ω[t0],
575
+ # we obtain an explicit RK scheme operating on the SO(3) manifold.
576
+ # Note that the current integrator is not a semi-implicit scheme, therefore
577
+ # using the final ω[tf] would be not correct.
578
+ W_ω_WB_t0 = x0.physics_model.base_angular_velocity
579
+
580
+ # Integrate the quaternion on SO(3).
581
+ W_Q_B_tf = jaxsim.math.Quaternion.integration(
582
+ quaternion=W_Q_B_t0,
583
+ dt=dt,
584
+ omega=W_ω_WB_t0,
585
+ omega_in_body_fixed=False,
586
+ )
587
+
588
+ # Replace the quaternion in the final state.
589
+ return xf.replace(
590
+ physics_model=xf.physics_model.replace(base_quaternion=W_Q_B_tf),
591
+ validate=True,
592
+ )