jaxsim 0.1rc0__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (89) hide show
  1. jaxsim/__init__.py +5 -6
  2. jaxsim/_version.py +2 -2
  3. jaxsim/api/__init__.py +3 -0
  4. jaxsim/api/com.py +240 -0
  5. jaxsim/api/common.py +216 -0
  6. jaxsim/api/contact.py +271 -0
  7. jaxsim/api/data.py +821 -0
  8. jaxsim/api/joint.py +189 -0
  9. jaxsim/api/kin_dyn_parameters.py +777 -0
  10. jaxsim/api/link.py +361 -0
  11. jaxsim/api/model.py +1633 -0
  12. jaxsim/api/ode.py +295 -0
  13. jaxsim/api/ode_data.py +694 -0
  14. jaxsim/api/references.py +421 -0
  15. jaxsim/integrators/__init__.py +2 -0
  16. jaxsim/integrators/common.py +594 -0
  17. jaxsim/integrators/fixed_step.py +102 -0
  18. jaxsim/integrators/variable_step.py +610 -0
  19. jaxsim/math/__init__.py +11 -0
  20. jaxsim/math/adjoint.py +24 -2
  21. jaxsim/math/joint_model.py +335 -0
  22. jaxsim/math/quaternion.py +44 -3
  23. jaxsim/math/rotation.py +4 -4
  24. jaxsim/math/transform.py +92 -0
  25. jaxsim/mujoco/__init__.py +3 -0
  26. jaxsim/mujoco/__main__.py +192 -0
  27. jaxsim/mujoco/loaders.py +615 -0
  28. jaxsim/mujoco/model.py +414 -0
  29. jaxsim/mujoco/visualizer.py +176 -0
  30. jaxsim/parsers/descriptions/collision.py +14 -0
  31. jaxsim/parsers/descriptions/link.py +13 -2
  32. jaxsim/parsers/kinematic_graph.py +8 -3
  33. jaxsim/parsers/rod/parser.py +54 -38
  34. jaxsim/parsers/rod/utils.py +7 -8
  35. jaxsim/rbda/__init__.py +7 -0
  36. jaxsim/rbda/aba.py +295 -0
  37. jaxsim/rbda/collidable_points.py +142 -0
  38. jaxsim/{physics/algos → rbda}/crba.py +43 -42
  39. jaxsim/rbda/forward_kinematics.py +113 -0
  40. jaxsim/rbda/jacobian.py +201 -0
  41. jaxsim/rbda/rnea.py +237 -0
  42. jaxsim/rbda/soft_contacts.py +296 -0
  43. jaxsim/rbda/utils.py +152 -0
  44. jaxsim/terrain/__init__.py +2 -0
  45. jaxsim/{physics/algos → terrain}/terrain.py +4 -6
  46. jaxsim/typing.py +30 -30
  47. jaxsim/utils/__init__.py +1 -4
  48. jaxsim/utils/hashless.py +18 -0
  49. jaxsim/utils/jaxsim_dataclass.py +281 -31
  50. {jaxsim-0.1rc0.dist-info → jaxsim-0.2.0.dist-info}/LICENSE +1 -1
  51. jaxsim-0.2.0.dist-info/METADATA +237 -0
  52. jaxsim-0.2.0.dist-info/RECORD +64 -0
  53. {jaxsim-0.1rc0.dist-info → jaxsim-0.2.0.dist-info}/WHEEL +1 -1
  54. jaxsim/high_level/__init__.py +0 -2
  55. jaxsim/high_level/common.py +0 -11
  56. jaxsim/high_level/joint.py +0 -148
  57. jaxsim/high_level/link.py +0 -259
  58. jaxsim/high_level/model.py +0 -1695
  59. jaxsim/math/conv.py +0 -114
  60. jaxsim/math/joint.py +0 -101
  61. jaxsim/math/plucker.py +0 -100
  62. jaxsim/physics/__init__.py +0 -12
  63. jaxsim/physics/algos/__init__.py +0 -0
  64. jaxsim/physics/algos/aba.py +0 -256
  65. jaxsim/physics/algos/aba_motors.py +0 -284
  66. jaxsim/physics/algos/forward_kinematics.py +0 -79
  67. jaxsim/physics/algos/jacobian.py +0 -98
  68. jaxsim/physics/algos/rnea.py +0 -180
  69. jaxsim/physics/algos/rnea_motors.py +0 -196
  70. jaxsim/physics/algos/soft_contacts.py +0 -454
  71. jaxsim/physics/algos/utils.py +0 -69
  72. jaxsim/physics/model/__init__.py +0 -0
  73. jaxsim/physics/model/ground_contact.py +0 -55
  74. jaxsim/physics/model/physics_model.py +0 -358
  75. jaxsim/physics/model/physics_model_state.py +0 -174
  76. jaxsim/simulation/__init__.py +0 -4
  77. jaxsim/simulation/integrators.py +0 -452
  78. jaxsim/simulation/ode.py +0 -290
  79. jaxsim/simulation/ode_data.py +0 -53
  80. jaxsim/simulation/ode_integration.py +0 -125
  81. jaxsim/simulation/simulator.py +0 -544
  82. jaxsim/simulation/simulator_callbacks.py +0 -53
  83. jaxsim/simulation/utils.py +0 -15
  84. jaxsim/sixd/__init__.py +0 -2
  85. jaxsim/utils/oop.py +0 -532
  86. jaxsim/utils/vmappable.py +0 -117
  87. jaxsim-0.1rc0.dist-info/METADATA +0 -167
  88. jaxsim-0.1rc0.dist-info/RECORD +0 -64
  89. {jaxsim-0.1rc0.dist-info → jaxsim-0.2.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,594 @@
1
+ import abc
2
+ import dataclasses
3
+ from typing import Any, ClassVar, Generic, Protocol, Type, TypeVar
4
+
5
+ import jax
6
+ import jax.numpy as jnp
7
+ import jax_dataclasses
8
+ import jaxlie
9
+ from jax_dataclasses import Static
10
+
11
+ import jaxsim.api as js
12
+ import jaxsim.typing as jtp
13
+ from jaxsim.math import Quaternion
14
+ from jaxsim.utils.jaxsim_dataclass import JaxsimDataclass, Mutability
15
+
16
+ try:
17
+ from typing import override
18
+ except ImportError:
19
+ from typing_extensions import override
20
+
21
+ try:
22
+ from typing import Self
23
+ except ImportError:
24
+ from typing_extensions import Self
25
+
26
+
27
+ # =============
28
+ # Generic types
29
+ # =============
30
+
31
# Time instants and time steps are anything JAX can interpret as an array.
Time = jax.typing.ArrayLike
TimeStep = jax.typing.ArrayLike

# Generic type of the integrated state.
# `NextState` is the very same type variable exposed under a second name, used
# in signatures to mark the state produced by an integration step.
State = TypeVar("State")
NextState = State

# Generic type of the time derivative of the state.
StateDerivative = TypeVar("StateDerivative")

# Generic type of states that are PyTrees.
PyTreeType = TypeVar("PyTreeType", bound=jtp.PyTree)
36
+
37
+
38
class SystemDynamics(Protocol[State, StateDerivative]):
    """
    Structural type of the callables accepted as system dynamics.

    Any callable taking a state `x`, a time `t` and optional keyword arguments,
    and returning a `(state derivative, dictionary)` pair satisfies this protocol.
    """

    def __call__(
        self,
        x: State,
        t: Time,
        **kwargs,
    ) -> tuple[StateDerivative, dict[str, Any]]:
        """
        Evaluate the system dynamics.

        Args:
            x: The state of the system.
            t: The time at which the dynamics is evaluated.
            **kwargs: Additional keyword arguments.

        Returns:
            A tuple containing the state derivative and a dictionary.
        """
        ...
42
+
43
+
44
+ # =======================
45
+ # Base integrator classes
46
+ # =======================
47
+
48
+
49
@jax_dataclasses.pytree_dataclass
class Integrator(JaxsimDataclass, abc.ABC, Generic[State, StateDerivative]):
    """
    Abstract base class of the integrators.

    An integrator wraps a `dynamics` callable and advances a state by calling
    the object itself (`__call__` is implemented by the subclasses).
    The `params` dictionary is the auxiliary storage of the integrator: it is
    created by `init`, passed by the caller to `step`, and returned updated.
    """

    # Key of the parameters set to True in the dictionary returned by `init`
    # and set to False in the dictionary returned by `step`.
    AfterInitKey: ClassVar[str] = "after_init"

    # Key of the parameters that is present only during the dummy call of the
    # integrator performed by `init`.
    InitializingKey: ClassVar[str] = "initializing"

    # NOTE(review): this key is not used in the visible part of this module.
    # Presumably it is the key under which the auxiliary dictionary returned by
    # the dynamics is stored -- confirm against the other integrators.
    AuxDictDynamicsKey: ClassVar[str] = "aux_dict_dynamics"

    # The system dynamics to integrate (static, keyword-only field).
    dynamics: Static[SystemDynamics[State, StateDerivative]] = dataclasses.field(
        repr=False, hash=False, compare=False, kw_only=True
    )

    # The auxiliary dictionary of the integrator (keyword-only, empty by default).
    params: dict[str, Any] = dataclasses.field(
        default_factory=dict, repr=False, hash=False, compare=False, kw_only=True
    )

    @classmethod
    def build(
        cls: Type[Self],
        *,
        dynamics: SystemDynamics[State, StateDerivative],
        **kwargs,
    ) -> Self:
        """
        Build the integrator object.

        Args:
            dynamics: The system dynamics.
            **kwargs: Additional keyword arguments to build the integrator.

        Returns:
            The integrator object.
        """

        return cls(dynamics=dynamics, **kwargs)  # noqa

    def step(
        self,
        x0: State,
        t0: Time,
        dt: TimeStep,
        *,
        params: dict[str, Any],
        **kwargs,
    ) -> tuple[State, dict[str, Any]]:
        """
        Perform a single integration step.

        Args:
            x0: The initial state of the system.
            t0: The initial time of the system.
            dt: The time step of the integration.
            params: The auxiliary dictionary of the integrator.
            **kwargs: Additional keyword arguments forwarded to `__call__`.

        Returns:
            The final state of the system and the updated auxiliary dictionary,
            in which `Integrator.AfterInitKey` is set to False.
        """

        # Work on an editable copy of the integrator holding the passed parameters.
        with self.editable(validate=False) as integrator:
            integrator.params = params

        # The call is executed in a mutable context so that the integrator
        # can update its own parameters.
        with integrator.mutable_context(mutability=Mutability.MUTABLE):
            xf = integrator(x0, t0, dt, **kwargs)

        return xf, integrator.params | {
            Integrator.AfterInitKey: jnp.array(False).astype(bool)
        }

    @abc.abstractmethod
    def __call__(self, x0: State, t0: Time, dt: TimeStep, **kwargs) -> NextState:
        """
        Integrate the system from `t0` over the time step `dt`.

        Args:
            x0: The initial state of the system.
            t0: The initial time of the system.
            dt: The time step of the integration.
            **kwargs: Additional keyword arguments.

        Returns:
            The state of the system after the integration.
        """
        pass

    def init(
        self,
        x0: State,
        t0: Time,
        dt: TimeStep,
        *,
        include_dynamics_aux_dict: bool = False,
        **kwargs,
    ) -> dict[str, Any]:
        """
        Initialize the integrator.

        Args:
            x0: The initial state of the system.
            t0: The initial time of the system.
            dt: The time step of the integration.
            include_dynamics_aux_dict:
                Not used by this implementation of the method.
            **kwargs:
                Additional keyword arguments forwarded to the dummy call of
                the integrator.

        Returns:
            The auxiliary dictionary of the integrator.

        Note:
            This method should have the same signature as the inherited `__call__`
            method, including additional kwargs.

        Note:
            If the integrator supports FSAL, the pair `(x0, t0)` must match the real
            initial state and time of the system, otherwise the initial derivative of
            the first step will be wrong.
        """

        with self.editable(validate=False) as integrator:

            # Initialize the integrator parameters.
            # For initialization purpose, the integrators can check if the
            # `Integrator.InitializingKey` is present in their parameters.
            # The AfterInitKey is used in the first step after initialization.
            integrator.params = {
                Integrator.InitializingKey: jnp.array(True),
                Integrator.AfterInitKey: jnp.array(False),
            }

            # Run a dummy call of the integrator.
            # It is used only to get the params so that we know the structure
            # of the corresponding pytree.
            _ = integrator(x0, t0, dt, **kwargs)

        # Remove the injected key.
        _ = integrator.params.pop(Integrator.InitializingKey)

        # Make sure that all leaves of the dictionary are JAX arrays.
        # Also, since these are dummy parameters, set them all to zero.
        params_after_init = jax.tree_util.tree_map(
            lambda l: jnp.zeros_like(l), integrator.params
        )

        # Mark the next step as first step after initialization.
        params_after_init = params_after_init | {
            Integrator.AfterInitKey: jnp.array(True)
        }

        # Store the zero parameters in the integrator.
        # When the integrator is stepped, this is used to check if the passed
        # parameters are valid.
        with self.mutable_context(mutability=Mutability.MUTABLE_NO_VALIDATION):
            self.params = params_after_init

        return params_after_init
189
+
190
+
191
@jax_dataclasses.pytree_dataclass
class ExplicitRungeKutta(Integrator[PyTreeType, PyTreeType], Generic[PyTreeType]):
    """
    Base class of the explicit Runge-Kutta integrators operating on PyTrees.

    Subclasses define the integration scheme by providing the Butcher tableau
    through the class attributes `A`, `b`, `c`, together with the
    `order_of_bT_rows` and `row_index_of_solution` attributes.
    Instances should be created with `build`, which validates the tableau and
    detects whether it supports the FSAL (first-same-as-last) property.
    """

    # The Runge-Kutta matrix.
    A: ClassVar[jax.typing.ArrayLike]

    # The weights coefficients.
    # Note that in practice we typically use its transpose `b.transpose()`.
    b: ClassVar[jax.typing.ArrayLike]

    # The nodes coefficients.
    c: ClassVar[jax.typing.ArrayLike]

    # Define the order of the solution.
    # It should have as many elements as the number of rows of `b.transpose()`.
    order_of_bT_rows: ClassVar[tuple[int, ...]]

    # Define the row of the integration output corresponding to the final solution.
    # This is the row of b.T that produces the final state.
    row_index_of_solution: ClassVar[int]

    # Attributes of FSAL (first-same-as-last) property.
    fsal_enabled_if_supported: Static[bool] = dataclasses.field(repr=False)
    index_of_fsal: Static[jtp.IntLike | None] = dataclasses.field(repr=False)

    @property
    def has_fsal(self) -> bool:
        """
        Whether the FSAL property is both enabled and supported by the tableau.
        """
        return self.fsal_enabled_if_supported and self.index_of_fsal is not None

    @property
    def order(self) -> int:
        """
        The order of the row of `b.T` that produces the final solution.
        """
        return self.order_of_bT_rows[self.row_index_of_solution]

    @override
    @classmethod
    def build(
        cls: Type[Self],
        *,
        dynamics: SystemDynamics[State, StateDerivative],
        fsal_enabled_if_supported: jtp.BoolLike = True,
        **kwargs,
    ) -> Self:
        """
        Build the integrator object.

        Args:
            dynamics: The system dynamics.
            fsal_enabled_if_supported:
                Whether to enable the FSAL property, if supported.
            **kwargs: Additional keyword arguments to build the integrator.

        Returns:
            The integrator object.

        Raises:
            ValueError:
                If the Butcher tableau of the class is not valid, if the index
                of the solution is out of range, or if `order_of_bT_rows` does
                not have one element per row of `b.T`.
        """

        # Check validity of the Butcher tableau.
        if not ExplicitRungeKutta.butcher_tableau_is_valid(A=cls.A, b=cls.b, c=cls.c):
            raise ValueError("The Butcher tableau of this class is not valid.")

        # Check that b.T has enough rows based on the configured index of the solution.
        if cls.row_index_of_solution >= cls.b.T.shape[0]:
            msg = "The index of the solution ({}-th row of `b.T`) is out of range ({})."
            raise ValueError(msg.format(cls.row_index_of_solution, cls.b.T.shape[0]))

        # Check that the tuple containing the order of the b.T rows matches the number
        # of the b.T rows.
        if len(cls.order_of_bT_rows) != cls.b.T.shape[0]:
            msg = "Wrong size of 'order_of_bT_rows' ({}), should be {}."
            raise ValueError(msg.format(len(cls.order_of_bT_rows), cls.b.T.shape[0]))

        # Check if the Butcher tableau supports FSAL (first-same-as-last).
        # If it does, store the index of the intermediate derivative to be used as the
        # first derivative of the next iteration. The index is None when FSAL is not
        # supported, therefore the returned boolean flag is redundant here.
        _, index_of_fsal = ExplicitRungeKutta.butcher_tableau_supports_fsal(
            A=cls.A, b=cls.b, c=cls.c, index_of_solution=cls.row_index_of_solution
        )

        # Build the integrator object.
        integrator = super().build(
            dynamics=dynamics,
            index_of_fsal=index_of_fsal,
            fsal_enabled_if_supported=bool(fsal_enabled_if_supported),
            **kwargs,
        )

        return integrator

    def __call__(self, x0: State, t0: Time, dt: TimeStep, **kwargs) -> NextState:
        """
        Integrate the system from `t0` over the time step `dt`.

        Args:
            x0: The initial state of the system.
            t0: The initial time of the system.
            dt: The time step of the integration.
            **kwargs: Additional keyword arguments passed to the dynamics.

        Returns:
            The state located at the row `row_index_of_solution` of the output.
        """

        # Here z is a batched state with as many batch elements as b.T rows.
        # Note that z has multiple batches only if b.T has more than one row,
        # e.g. in Butcher tableau of embedded schemes.
        z = self._compute_next_state(x0=x0, t0=t0, dt=dt, **kwargs)

        # The next state is the batch element located at the configured index of solution.
        return jax.tree_util.tree_map(
            lambda leaf: leaf[self.row_index_of_solution], z
        )

    @classmethod
    def integrate_rk_stage(
        cls, x0: State, t0: Time, dt: TimeStep, k: StateDerivative
    ) -> NextState:
        """
        Integrate a single stage of the Runge-Kutta method.

        Args:
            x0: The initial state of the system.
            t0: The initial time of the system.
            dt:
                The time step of the RK integration scheme. Note that this is
                not the stage timestep, as it depends on the `A` matrix used
                to compute the `k` argument.
            k:
                The RK state derivative of the current stage, weighted with
                the `A` matrix.

        Returns:
            The state at the next stage of the integration.

        Note:
            In the most generic case, `k` could be an arbitrary composition
            of the kᵢ derivatives, depending on the RK matrix A.

        Note:
            Overriding this method allows users to use different classes
            defining `State` and `StateDerivative`. Be aware that the
            timestep `dt` is not the stage timestep, therefore the map
            used to convert the state derivative must be time-independent.
        """

        op = lambda x0_leaf, k_leaf: x0_leaf + dt * k_leaf
        return jax.tree_util.tree_map(op, x0, k)

    @classmethod
    def post_process_state(
        cls, x0: State, t0: Time, xf: NextState, dt: TimeStep
    ) -> NextState:
        r"""
        Post-process the integrated state at :math:`t_f = t_0 + \Delta t`.

        Args:
            x0: The initial state of the system.
            t0: The initial time of the system.
            xf: The final state of the system obtained through the integration.
            dt: The time step used for the integration.

        Returns:
            The post-processed integrated state. This default implementation
            returns `xf` unchanged.
        """

        return xf

    def _compute_next_state(
        self, x0: State, t0: Time, dt: TimeStep, **kwargs
    ) -> NextState:
        """
        Compute the next state of the system, returning all the output states.

        Args:
            x0: The initial state of the system.
            t0: The initial time of the system.
            dt: The time step of the integration.
            **kwargs: Additional keyword arguments.

        Returns:
            A batched state with as many batch elements as `b.T` rows.

        Note:
            When FSAL is active, this method stores the derivative to reuse in
            the next call under the "dxdt0" key of `self.params`.
        """

        # Call variables with better symbols.
        Δt = dt
        c = self.c
        b = self.b
        A = self.A

        # Close f over optional kwargs.
        f = lambda x, t: self.dynamics(x=x, t=t, **kwargs)

        # Initialize the carry of the for loop with the stacked kᵢ vectors.
        carry0 = jax.tree_util.tree_map(
            lambda leaf: jnp.repeat(
                jnp.zeros_like(leaf)[jnp.newaxis, ...], c.size, axis=0
            ),
            x0,
        )

        # Apply FSAL property by passing the derivative f(x0, t0) stored by the
        # previous iteration. The dynamics is evaluated only if no derivative has
        # been stored (`dict.get(key, default)` would evaluate its default eagerly).
        def get_dxdt0():
            if "dxdt0" in self.params:
                return self.params["dxdt0"]

            return f(x0, t0)[0]

        # We use a `jax.lax.scan` to compile the `f` function only once.
        # Otherwise, if we compute e.g. for RK4 sequentially, the jit-compiled code
        # would include 4 repetitions of the `f` logic, making everything extremely slow.
        def scan_body(carry: jax.Array, i: int | jax.Array) -> tuple[jax.Array, None]:
            """
            Compute the i-th derivative kᵢ and store it in the carry.
            """

            # Unpack the carry, i.e. the stacked kᵢ vectors.
            K = carry

            # Define the computation of the Runge-Kutta stage.
            def compute_ki() -> jax.Array:

                # Compute ∑ⱼ aᵢⱼ kⱼ
                op_sum_ak = lambda k: jnp.einsum("s,s...->...", A[i], k)
                sum_ak = jax.tree_util.tree_map(op_sum_ak, K)

                # Compute the next state for the kᵢ evaluation.
                # Note that this is not a Δt integration since aᵢⱼ could be fractional.
                xi = self.integrate_rk_stage(x0, t0, Δt, sum_ak)

                # Compute the next time for the kᵢ evaluation.
                ti = t0 + c[i] * Δt

                # This is kᵢ = f(xᵢ, tᵢ).
                return f(xi, ti)[0]

            # This selector enables FSAL property in the first iteration (i=0).
            ki = jax.lax.cond(
                pred=jnp.logical_and(i == 0, self.has_fsal),
                true_fun=get_dxdt0,
                false_fun=compute_ki,
            )

            # Store the kᵢ derivative in K.
            op = lambda l_k, l_ki: l_k.at[i].set(l_ki)
            K = jax.tree_util.tree_map(op, K, ki)

            carry = K
            return carry, None

        # Compute the state derivatives kᵢ.
        K, _ = jax.lax.scan(
            f=scan_body,
            init=carry0,
            xs=jnp.arange(c.size),
        )

        # Update the FSAL property for the next iteration.
        if self.has_fsal:
            self.params["dxdt0"] = jax.tree_util.tree_map(
                lambda leaf: leaf[self.index_of_fsal], K
            )

        # Compute the output state.
        # Note that z contains as many new states as the rows of `b.T`.
        op = lambda x0_leaf, k_leaf: x0_leaf + Δt * jnp.einsum(
            "zs,s...->z...", b.T, k_leaf
        )
        z = jax.tree_util.tree_map(op, x0, K)

        # Transform the final state of the integration.
        # This allows to inject custom logic, if needed.
        z_transformed = jax.vmap(
            lambda xf: self.post_process_state(x0=x0, t0=t0, xf=xf, dt=dt)
        )(z)

        return z_transformed

    @staticmethod
    def butcher_tableau_is_valid(
        A: jax.typing.ArrayLike, b: jax.typing.ArrayLike, c: jax.typing.ArrayLike
    ) -> jtp.Bool:
        """
        Check if the Butcher tableau is valid.

        Args:
            A: The Runge-Kutta matrix.
            b: The weights coefficients.
            c: The nodes coefficients.

        Returns:
            `True` if the Butcher tableau is valid, `False` otherwise.
        """

        valid = True
        valid = valid and A.ndim == 2
        valid = valid and b.ndim == 2
        valid = valid and c.ndim == 1
        # At most two rows of b.T are accepted.
        valid = valid and b.T.shape[0] <= 2
        valid = valid and A.shape[0] == A.shape[1]
        valid = valid and A.shape == (c.size, b.T.shape[1])
        # The weights of each row of b.T must sum to one.
        valid = valid and bool(jnp.all(b.T.sum(axis=1) == 1))

        return valid

    @staticmethod
    def butcher_tableau_is_explicit(A: jax.typing.ArrayLike) -> jtp.Bool:
        """
        Check if the Butcher tableau corresponds to an explicit integration scheme.

        Args:
            A: The Runge-Kutta matrix.

        Returns:
            `True` if the Butcher tableau is explicit, `False` otherwise.
        """

        # The scheme is explicit if A is strictly lower triangular.
        return jnp.allclose(A, jnp.tril(A, k=-1))

    @staticmethod
    def butcher_tableau_supports_fsal(
        A: jax.typing.ArrayLike,
        b: jax.typing.ArrayLike,
        c: jax.typing.ArrayLike,
        index_of_solution: jtp.IntLike = 0,
    ) -> tuple[bool, int | None]:
        """
        Check if the Butcher tableau supports the FSAL (first-same-as-last) property.

        Args:
            A: The Runge-Kutta matrix.
            b: The weights coefficients.
            c: The nodes coefficients.
            index_of_solution:
                The index of the row of `b.T` corresponding to the solution.

        Returns:
            A tuple containing a boolean indicating whether the Butcher tableau supports
            FSAL, and the index i of the intermediate kᵢ derivative corresponding to the
            initial derivative `f(x0, t0)` of the next step. The index is `None` when
            FSAL is not supported.

        Raises:
            ValueError:
                If the Butcher tableau is not valid, or if the index of the
                solution is out of range.
        """

        if not ExplicitRungeKutta.butcher_tableau_is_valid(A=A, b=b, c=c):
            raise ValueError("The Butcher tableau is not valid.")

        # Always return a pair, since callers unpack the output of this method.
        if not ExplicitRungeKutta.butcher_tableau_is_explicit(A=A):
            return False, None

        if index_of_solution >= b.T.shape[0]:
            msg = "The index of the solution (i-th row of `b.T`) is out of range."
            raise ValueError(msg)

        if c[0] != 0:
            return False, None

        # Find all the rows of A where c = 1 (therefore at t=tf). The Butcher tableau
        # supports FSAL if any of these rows (there might be more rows with c=1) matches
        # the rows of b.T corresponding to the next state (marked by `index_of_solution`).
        # This last condition means that the last kᵢ derivative is computed at (tf, xf),
        # that corresponds to the (t0, x0) pair of the next integration call.
        rows_of_A_with_fsal = (A == b.T[None, index_of_solution]).all(axis=1)
        rows_of_A_with_fsal = jnp.logical_and(rows_of_A_with_fsal, (c == 1))

        # If there is no match, it means that the Butcher tableau does not support FSAL.
        if not rows_of_A_with_fsal.any():
            return False, None

        # Return the index of the row of A providing the fsal derivative (that is the
        # possibly intermediate kᵢ derivative).
        # Note that if multiple rows match (it should not), we return the first match.
        return True, int(jnp.where(rows_of_A_with_fsal)[0].tolist()[0])
533
+
534
+
535
class ExplicitRungeKuttaSO3Mixin:
    """
    Mixin for the explicit RK integrators defined on `PyTreeType = ODEState`
    that integrates the base quaternion on SO(3).

    It overrides the `integrate_rk_stage` and `post_process_state` hooks.
    """

    @classmethod
    def integrate_rk_stage(
        cls, x0: js.ode_data.ODEState, t0: Time, dt: TimeStep, k: js.ode_data.ODEState
    ) -> js.ode_data.ODEState:
        """
        Integrate a single RK stage, handling the base quaternion separately.

        Args:
            x0: The initial state of the system.
            t0: The initial time of the system.
            dt: The time step of the RK integration scheme.
            k: The RK state derivative of the current stage.

        Returns:
            The state of the stage, whose base quaternion is computed from the
            base quaternion and the base angular velocity of `x0`.
        """

        def explicit_update(leaf_x0, leaf_k):
            return leaf_x0 + dt * leaf_k

        # Plain leaf-wise update of the whole state.
        x_stage: js.ode_data.ODEState = jax.tree_util.tree_map(explicit_update, x0, k)

        # Integrate the quaternion starting from the quantities of the initial state.
        # The angular velocity is not treated as expressed in the body-fixed frame.
        stage_quaternion = Quaternion.integration(
            quaternion=x0.physics_model.base_quaternion,
            dt=dt,
            omega=x0.physics_model.base_angular_velocity,
            omega_in_body_fixed=False,
        )

        # Overwrite the quaternion obtained with the leaf-wise update.
        stage_physics_model = x_stage.physics_model.replace(
            base_quaternion=stage_quaternion,
        )

        return x_stage.replace(physics_model=stage_physics_model)

    @classmethod
    def post_process_state(
        cls, x0: js.ode_data.ODEState, t0: Time, xf: js.ode_data.ODEState, dt: TimeStep
    ) -> js.ode_data.ODEState:
        """
        Replace the base quaternion of the final state with the one integrated
        on SO(3).

        Args:
            x0: The initial state of the system.
            t0: The initial time of the system.
            xf: The final state of the system obtained through the integration.
            dt: The time step used for the integration.

        Returns:
            The final state with the updated base quaternion.
        """

        # Permutations converting the quaternion between the two serializations.
        idx_to_xyzw = jnp.array([1, 2, 3, 0])
        idx_to_wxyz = jnp.array([3, 0, 1, 2])

        # Rotation corresponding to the quaternion of the initial state.
        initial_quaternion_xyzw = x0.physics_model.base_quaternion[idx_to_xyzw]
        W_R_B_t0 = jaxlie.SO3.from_quaternion_xyzw(xyzw=initial_quaternion_xyzw)

        # Angular velocity of the final state.
        # In RK-based schemes it already results from averaging the kᵢ, therefore
        # using it here yields a RK scheme operating on the SO(3) manifold.
        final_angular_velocity = xf.physics_model.base_angular_velocity

        # The exponential map is left-multiplied since the angular velocity is
        # expressed in the inertial frame.
        W_R_B_tf = jaxlie.SO3.exp(tangent=dt * final_angular_velocity) @ W_R_B_t0

        # Convert back the quaternion to the serialization used by the state.
        final_quaternion = W_R_B_tf.as_quaternion_xyzw()[idx_to_wxyz]

        return xf.replace(
            physics_model=xf.physics_model.replace(base_quaternion=final_quaternion),
            validate=True,
        )
@@ -0,0 +1,102 @@
1
+ from typing import ClassVar, Generic
2
+
3
+ import jax
4
+ import jax.numpy as jnp
5
+ import jax_dataclasses
6
+
7
+ import jaxsim.api as js
8
+
9
+ from .common import ExplicitRungeKutta, ExplicitRungeKuttaSO3Mixin, PyTreeType
10
+
11
# The derivative of an `ODEState` is represented with the `ODEState` type itself.
# NOTE(review): this alias is not referenced in the visible part of this module.
ODEStateDerivative = js.ode_data.ODEState
12
+
13
+ # =====================================================
14
+ # Explicit Runge-Kutta integrators operating on PyTrees
15
+ # =====================================================
16
+
17
+
18
@jax_dataclasses.pytree_dataclass
class ForwardEuler(ExplicitRungeKutta[PyTreeType], Generic[PyTreeType]):
    """
    Forward Euler integrator: single-stage explicit scheme of order 1.
    """

    # Butcher tableau: all the coefficients are stored as 1x1 (or 1-element) arrays.
    A: ClassVar[jax.typing.ArrayLike] = jnp.array([[0]], dtype=float)
    b: ClassVar[jax.typing.ArrayLike] = jnp.array([[1]], dtype=float)
    c: ClassVar[jax.typing.ArrayLike] = jnp.array([0], dtype=float)

    # `b.T` has a single row, of order 1, that provides the solution.
    order_of_bT_rows: ClassVar[tuple[int, ...]] = (1,)
    row_index_of_solution: ClassVar[int] = 0
29
+
30
+
31
@jax_dataclasses.pytree_dataclass
class Heun2(ExplicitRungeKutta[PyTreeType], Generic[PyTreeType]):
    """
    Heun integrator: two-stage explicit scheme of order 2.
    """

    # The Runge-Kutta matrix.
    A: ClassVar[jax.typing.ArrayLike] = jnp.array([[0, 0], [1, 0]], dtype=float)

    # The weights coefficients, stored as a column so that `b.T` has a single row.
    b: ClassVar[jax.typing.ArrayLike] = jnp.array([[1 / 2], [1 / 2]], dtype=float)

    # The nodes coefficients.
    c: ClassVar[jax.typing.ArrayLike] = jnp.array([0, 1], dtype=float)

    # `b.T` has a single row, of order 2, that provides the solution.
    order_of_bT_rows: ClassVar[tuple[int, ...]] = (2,)
    row_index_of_solution: ClassVar[int] = 0
55
+
56
+
57
@jax_dataclasses.pytree_dataclass
class RungeKutta4(ExplicitRungeKutta[PyTreeType], Generic[PyTreeType]):
    """
    Runge-Kutta integrator: four-stage explicit scheme of order 4.
    """

    # The Runge-Kutta matrix.
    A: ClassVar[jax.typing.ArrayLike] = jnp.array(
        [
            [0, 0, 0, 0],
            [1 / 2, 0, 0, 0],
            [0, 1 / 2, 0, 0],
            [0, 0, 1, 0],
        ],
        dtype=float,
    )

    # The weights coefficients, stored as a column so that `b.T` has a single row.
    b: ClassVar[jax.typing.ArrayLike] = jnp.array(
        [
            [1 / 6],
            [1 / 3],
            [1 / 3],
            [1 / 6],
        ],
        dtype=float,
    )

    # The nodes coefficients.
    c: ClassVar[jax.typing.ArrayLike] = jnp.array([0, 1 / 2, 1 / 2, 1], dtype=float)

    # `b.T` has a single row, of order 4, that provides the solution.
    order_of_bT_rows: ClassVar[tuple[int, ...]] = (4,)
    row_index_of_solution: ClassVar[int] = 0
83
+
84
+
85
+ # ===============================================================================
86
+ # Explicit Runge-Kutta integrators operating on ODEState and integrating on SO(3)
87
+ # ===============================================================================
88
+
89
+
90
@jax_dataclasses.pytree_dataclass
class ForwardEulerSO3(ExplicitRungeKuttaSO3Mixin, ForwardEuler[js.ode_data.ODEState]):
    """
    Forward Euler integrator on `ODEState` integrating the quaternion on SO(3).
    """
93
+
94
+
95
@jax_dataclasses.pytree_dataclass
class Heun2SO3(ExplicitRungeKuttaSO3Mixin, Heun2[js.ode_data.ODEState]):
    """
    Heun integrator on `ODEState` integrating the quaternion on SO(3).
    """
98
+
99
+
100
@jax_dataclasses.pytree_dataclass
class RungeKutta4SO3(ExplicitRungeKuttaSO3Mixin, RungeKutta4[js.ode_data.ODEState]):
    """
    Runge-Kutta 4 integrator on `ODEState` integrating the quaternion on SO(3).
    """