jaxsim 0.2.dev108__py3-none-any.whl → 0.2.dev166__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,508 @@
1
+ import abc
2
+ import dataclasses
3
+ from typing import Any, ClassVar, Generic, Protocol, Type, TypeVar
4
+
5
+ import jax
6
+ import jax.numpy as jnp
7
+ import jax_dataclasses
8
+ from jax_dataclasses import Static
9
+
10
+ import jaxsim.typing as jtp
11
+ from jaxsim.utils.jaxsim_dataclass import JaxsimDataclass, Mutability
12
+
13
+ try:
14
+ from typing import override
15
+ except ImportError:
16
+ from typing_extensions import override
17
+
18
+ try:
19
+ from typing import Self
20
+ except ImportError:
21
+ from typing_extensions import Self
22
+
23
+
24
+ # =============
25
+ # Generic types
26
+ # =============
27
+
28
# Time instants and time steps accept any JAX array-like value.
Time = TimeStep = jax.typing.ArrayLike

# Type variables of the integrators. A step maps a state to a state of the
# same type, therefore `NextState` is just another name for `State`.
State = TypeVar("State")
NextState = State
StateDerivative = TypeVar("StateDerivative")

# Generic pytree type, used when the state and its derivative share one type.
PyTreeType = TypeVar("PyTreeType", bound=jtp.PyTree)
33
+
34
+
35
class SystemDynamics(Protocol[State, StateDerivative]):
    """
    Structural interface of the system dynamics consumed by the integrators.
    """

    def __call__(
        self, x: State, t: Time, **kwargs
    ) -> tuple[StateDerivative, dict[str, Any]]:
        """
        Evaluate the dynamics of the system.

        Args:
            x: The state of the system.
            t: The time at which the dynamics is evaluated.
            **kwargs: Additional keyword arguments forwarded by the integrator.

        Returns:
            A tuple with the state derivative and an auxiliary dictionary.
        """
39
+
40
+
41
+ # =======================
42
+ # Base integrator classes
43
+ # =======================
44
+
45
+
46
@jax_dataclasses.pytree_dataclass
class Integrator(JaxsimDataclass, abc.ABC, Generic[State, StateDerivative]):
    """
    Abstract base class of the integrators.

    An integrator advances the state `x0` of a system from time `t0` to time
    `t0 + dt` by evaluating its `dynamics`. Data that has to be carried from one
    step to the next is stored in the `params` dictionary, which is created by
    `init` and threaded through the calls to `step`.
    """

    # Reserved key of `params` under which `init` stores the auxiliary
    # dictionary returned by the dynamics.
    AuxDictDynamicsKey: ClassVar[str] = "aux_dict_dynamics"

    # The system dynamics. Static, and excluded from repr, hash, and comparison.
    dynamics: Static[SystemDynamics[State, StateDerivative]] = dataclasses.field(
        repr=False, hash=False, compare=False, kw_only=True
    )

    # The parameters (auxiliary data) of the integrator.
    params: dict[str, Any] = dataclasses.field(
        default_factory=dict, repr=False, hash=False, compare=False, kw_only=True
    )

    @classmethod
    def build(
        cls: Type[Self], *, dynamics: SystemDynamics[State, StateDerivative], **kwargs
    ) -> Self:
        """
        Build the integrator object.

        Args:
            dynamics: The system dynamics.
            **kwargs: Additional keyword arguments to build the integrator.

        Returns:
            The integrator object.
        """

        return cls(dynamics=dynamics, **kwargs)  # noqa

    def step(
        self,
        x0: State,
        t0: Time,
        dt: TimeStep,
        *,
        params: dict[str, Any],
        **kwargs,
    ) -> tuple[State, dict[str, Any]]:
        """
        Perform a single integration step.

        Args:
            x0: The initial state of the system.
            t0: The initial time of the system.
            dt: The time step of the integration.
            params:
                The auxiliary dictionary of the integrator, as returned by `init`
                or by a previous call to `step`. It is not modified in place.
            **kwargs: Additional keyword arguments.

        Returns:
            The final state of the system and the updated auxiliary dictionary.

        Raises:
            KeyError:
                If the auxiliary dictionary lacks the `AuxDictDynamicsKey` entry,
                e.g. because it was not created with `init`.
        """

        with self.editable(validate=False) as integrator:
            # Operate on a shallow copy: `__call__` may update the parameters in
            # place, and those updates must not leak into the caller's dict.
            integrator.params = dict(params)

        with integrator.mutable_context(mutability=Mutability.MUTABLE):
            xf = integrator(x0, t0, dt, **kwargs)

        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if Integrator.AuxDictDynamicsKey not in integrator.params:
            msg = "Missing key '{}' in the integrator params, create them with `init`."
            raise KeyError(msg.format(Integrator.AuxDictDynamicsKey))

        return xf, integrator.params

    @abc.abstractmethod
    def __call__(self, x0: State, t0: Time, dt: TimeStep, **kwargs) -> NextState:
        """
        Integrate the state `x0` from `t0` to `t0 + dt`.

        Implementations may read and update `self.params`.
        """

        pass

    def init(
        self,
        x0: State,
        t0: Time,
        dt: TimeStep,
        *,
        key: jax.Array | None = None,
        **kwargs,
    ) -> dict[str, Any]:
        """
        Initialize the integrator.

        Args:
            x0: The initial state of the system.
            t0: The initial time of the system.
            dt: The time step of the integration.
            key: An optional random key to initialize the integrator.

        Returns:
            The auxiliary dictionary of the integrator.

        Raises:
            KeyError:
                If the reserved `AuxDictDynamicsKey` is already present either in
                the auxiliary dictionary of the dynamics or in the parameters
                produced by `__call__`.

        Note:
            This method should have the same signature as the inherited `__call__`
            method, including additional kwargs.

        Note:
            If the integrator supports FSAL, the pair `(x0, t0)` must match the real
            initial state and time of the system, otherwise the initial derivative of
            the first step will be wrong.
        """

        _, aux_dict_dynamics = self.dynamics(x0, t0)

        # Run a trial step on a copy of the integrator to collect the parameters
        # that its `__call__` method creates.
        with self.editable(validate=False) as integrator:
            _ = integrator(x0, t0, dt, **kwargs)
            aux_dict_step = integrator.params

        # The key is reserved. In `aux_dict_step` it would be particularly harmful,
        # since the merge below would let it silently replace the auxiliary
        # dictionary of the dynamics.
        for aux_dict in (aux_dict_dynamics, aux_dict_step):
            if Integrator.AuxDictDynamicsKey in aux_dict:
                msg = "You cannot create a key '{}' in the __call__ method."
                raise KeyError(msg.format(Integrator.AuxDictDynamicsKey))

        return {Integrator.AuxDictDynamicsKey: aux_dict_dynamics} | aux_dict_step
155
+
156
+
157
@jax_dataclasses.pytree_dataclass
class ExplicitRungeKutta(Integrator[PyTreeType, PyTreeType], Generic[PyTreeType]):
    """
    Base class of the explicit Runge-Kutta integrators.

    Concrete integrators define their Butcher tableau with the class variables
    `A`, `b`, and `c`, and configure `order_of_bT_rows` and `row_index_of_solution`.
    The tableau is validated and reshaped by `build`.
    """

    # The Runge-Kutta matrix.
    A: ClassVar[jax.typing.ArrayLike]

    # The weights coefficients.
    # Note that in practice we typically use its transpose `b.transpose()`.
    b: ClassVar[jax.typing.ArrayLike]

    # The nodes coefficients.
    c: ClassVar[jax.typing.ArrayLike]

    # The order of the solution produced by each row of `b.transpose()`.
    # It should have as many elements as the number of rows of `b.transpose()`.
    order_of_bT_rows: ClassVar[tuple[int, ...]]

    # The row of `b.T` producing the final state returned by `__call__`.
    row_index_of_solution: ClassVar[int]

    # Attributes of the FSAL (first-same-as-last) property.
    fsal_enabled_if_supported: Static[bool] = dataclasses.field(repr=False)
    index_of_fsal: Static[jtp.IntLike | None] = dataclasses.field(repr=False)

    @property
    def has_fsal(self) -> bool:
        """Whether the FSAL property is both supported by the tableau and enabled."""

        return self.fsal_enabled_if_supported and self.index_of_fsal is not None

    @property
    def order(self) -> int:
        """The order of the solution selected by `row_index_of_solution`."""

        return self.order_of_bT_rows[self.row_index_of_solution]

    @override
    @classmethod
    def build(
        cls: Type[Self],
        *,
        dynamics: SystemDynamics[State, StateDerivative],
        fsal_enabled_if_supported: jtp.BoolLike = True,
        **kwargs,
    ) -> Self:
        """
        Build the integrator object.

        Args:
            dynamics: The system dynamics.
            fsal_enabled_if_supported:
                Whether to enable the FSAL property, if supported.
            **kwargs: Additional keyword arguments to build the integrator.

        Returns:
            The integrator object.

        Raises:
            ValueError: If the Butcher tableau or its configuration is not valid.
        """

        # Adjust the shape of the tableau coefficients.
        c = jnp.atleast_1d(cls.c.squeeze())
        b = jnp.atleast_2d(jnp.vstack(cls.b.squeeze()))
        A = jnp.atleast_2d(cls.A.squeeze())

        # Check validity of the Butcher tableau.
        if not ExplicitRungeKutta.butcher_tableau_is_valid(A=A, b=b, c=c):
            raise ValueError("The Butcher tableau of this class is not valid.")

        # Store the adjusted shapes of the tableau coefficients.
        cls.c = c
        cls.b = b
        cls.A = A

        # Check that b.T has enough rows based on the configured index of the solution.
        if cls.row_index_of_solution >= cls.b.T.shape[0]:
            msg = "The index of the solution ({}-th row of `b.T`) is out of range ({})."
            raise ValueError(msg.format(cls.row_index_of_solution, cls.b.T.shape[0]))

        # Check that the tuple containing the order of the b.T rows matches the number
        # of the b.T rows.
        if len(cls.order_of_bT_rows) != cls.b.T.shape[0]:
            msg = "Wrong size of 'order_of_bT_rows' ({}), should be {}."
            raise ValueError(msg.format(len(cls.order_of_bT_rows), cls.b.T.shape[0]))

        # Check if the Butcher tableau supports FSAL (first-same-as-last).
        # If it does, store the index of the intermediate derivative to be used as the
        # first derivative of the next iteration.
        _, index_of_fsal = ExplicitRungeKutta.butcher_tableau_supports_fsal(
            A=cls.A, b=cls.b, c=cls.c, index_of_solution=cls.row_index_of_solution
        )

        # Build the integrator object.
        integrator = super().build(
            dynamics=dynamics,
            index_of_fsal=index_of_fsal,
            fsal_enabled_if_supported=bool(fsal_enabled_if_supported),
            **kwargs,
        )

        return integrator

    @override
    def init(
        self,
        x0: State,
        t0: Time,
        dt: TimeStep,
        *,
        key: jax.Array | None = None,
        **kwargs,
    ) -> dict[str, Any]:
        """
        Initialize the integrator.

        Args:
            x0: The initial state of the system.
            t0: The initial time of the system.
            dt: The time step of the integration.
            key: An optional random key to initialize the integrator.
            **kwargs: Additional keyword arguments, forwarded to the dynamics.

        Returns:
            The auxiliary dictionary of the integrator.

        Note:
            The base implementation runs a trial step, which caches as FSAL
            derivative the one computed at the end of that step. That value is
            replaced here with the derivative at `(x0, t0)`, so that the first call
            to `step` from `(x0, t0)` starts from the correct derivative.
        """

        params = super().init(x0, t0, dt, key=key, **kwargs)

        if self.has_fsal:
            # Same evaluation the first stage performs when no cache is available.
            params["dxdt0"] = self.dynamics(x=x0, t=t0, **kwargs)[0]

        return params

    def __call__(self, x0: State, t0: Time, dt: TimeStep, **kwargs) -> NextState:
        """
        Integrate `x0` from `t0` to `t0 + dt` and return the selected solution.

        Args:
            x0: The initial state of the system.
            t0: The initial time of the system.
            dt: The time step of the integration.
            **kwargs: Additional keyword arguments forwarded to the dynamics.

        Returns:
            The state produced by the `row_index_of_solution` row of `b.T`.
        """

        # Here z is a batched state with as many batch elements as b.T rows.
        # Note that z has multiple batches only if b.T has more than one row,
        # e.g. in Butcher tableau of embedded schemes.
        z = self._compute_next_state(x0=x0, t0=t0, dt=dt, **kwargs)

        # The next state is the batch element located at the configured index of solution.
        return jax.tree_util.tree_map(lambda l: l[self.row_index_of_solution], z)

    @classmethod
    def integrate_rk_stage(
        cls, x0: State, t0: Time, dt: TimeStep, k: StateDerivative
    ) -> NextState:
        """
        Integrate a single stage of the Runge-Kutta method.

        Args:
            x0: The initial state of the system.
            t0: The initial time of the system.
            dt:
                The time step of the RK integration scheme. Note that this is
                not the stage timestep, as it depends on the `A` matrix used
                to compute the `k` argument.
            k:
                The RK state derivative of the current stage, weighted with
                the `A` matrix.

        Returns:
            The state at the next stage of the integration.

        Note:
            In the most generic case, `k` could be an arbitrary composition
            of the kᵢ derivatives, depending on the RK matrix A.

        Note:
            Overriding this method allows users to use different classes
            defining `State` and `StateDerivative`. Be aware that the
            timestep `dt` is not the stage timestep, therefore the map
            used to convert the state derivative must be time-independent.
        """

        op = lambda x0_leaf, k_leaf: x0_leaf + dt * k_leaf
        return jax.tree_util.tree_map(op, x0, k)

    @classmethod
    def post_process_state(
        cls, x0: State, t0: Time, xf: NextState, dt: TimeStep
    ) -> NextState:
        r"""
        Post-process the integrated state at :math:`t_f = t_0 + \Delta t`.

        Args:
            x0: The initial state of the system.
            t0: The initial time of the system.
            xf: The final state of the system obtained through the integration.
            dt: The time step used for the integration.

        Returns:
            The post-processed integrated state.
        """

        return xf

    def _compute_next_state(
        self, x0: State, t0: Time, dt: TimeStep, **kwargs
    ) -> NextState:
        """
        Compute the next state of the system, returning all the output states.

        Args:
            x0: The initial state of the system.
            t0: The initial time of the system.
            dt: The time step of the integration.
            **kwargs: Additional keyword arguments.

        Returns:
            A batched state with as many batch elements as `b.T` rows.
        """

        # Call variables with better symbols.
        Δt = dt
        c = self.c
        b = self.b
        A = self.A

        # Close f over optional kwargs.
        f = lambda x, t: self.dynamics(x=x, t=t, **kwargs)

        # Initialize the carry of the for loop with the stacked kᵢ vectors.
        carry0 = jax.tree_util.tree_map(
            lambda l: jnp.repeat(jnp.zeros_like(l)[jnp.newaxis, ...], c.size, axis=0),
            x0,
        )

        # Apply FSAL property by passing ẋ0 = f(x0, t0) from the previous iteration.
        # The dynamics is evaluated only if no cached derivative is available
        # (`dict.get` would always evaluate, and trace, its default argument).
        def get_ẋ0() -> StateDerivative:
            if "dxdt0" in self.params:
                return self.params["dxdt0"]
            return f(x0, t0)[0]

        # We use a `jax.lax.scan` to compile the `f` function only once.
        # Otherwise, if we compute e.g. for RK4 sequentially, the jit-compiled code
        # would include 4 repetitions of the `f` logic, making everything extremely slow.
        def scan_body(carry: jax.Array, i: int | jax.Array) -> tuple[jax.Array, None]:
            """Compute the i-th stage derivative kᵢ and store it in the stacked carry."""

            # Unpack the carry, i.e. the stacked kᵢ vectors.
            K = carry

            # Define the computation of the Runge-Kutta stage.
            def compute_ki() -> jax.Array:

                # Compute ∑ⱼ aᵢⱼ kⱼ
                op_sum_ak = lambda k: jnp.einsum("s,s...->...", A[i], k)
                sum_ak = jax.tree_util.tree_map(op_sum_ak, K)

                # Compute the next state for the kᵢ evaluation.
                # Note that this is not a Δt integration since aᵢⱼ could be fractional.
                xi = self.integrate_rk_stage(x0, t0, Δt, sum_ak)

                # Compute the next time for the kᵢ evaluation.
                ti = t0 + c[i] * Δt

                # This is kᵢ = f(xᵢ, tᵢ).
                return f(xi, ti)[0]

            # With FSAL, the first stage (i=0) reuses the derivative of the previous
            # step. `has_fsal` is static: without FSAL the cond is skipped entirely,
            # so that the dynamics is traced only once.
            if self.has_fsal:
                ki = jax.lax.cond(
                    pred=i == 0,
                    true_fun=get_ẋ0,
                    false_fun=compute_ki,
                )
            else:
                ki = compute_ki()

            # Store the kᵢ derivative in K.
            op = lambda l_k, l_ki: l_k.at[i].set(l_ki)
            K = jax.tree_util.tree_map(op, K, ki)

            carry = K
            return carry, None

        # Compute the state derivatives kᵢ.
        K, _ = jax.lax.scan(
            f=scan_body,
            init=carry0,
            xs=jnp.arange(c.size),
        )

        # Update the FSAL property for the next iteration.
        if self.has_fsal:
            self.params["dxdt0"] = jax.tree_util.tree_map(
                lambda l: l[self.index_of_fsal], K
            )

        # Compute the output state.
        # Note that z contains as many new states as the rows of `b.T`.
        op = lambda x0, k: x0 + Δt * jnp.einsum("zs,s...->z...", b.T, k)
        z = jax.tree_util.tree_map(op, x0, K)

        # Transform the final state of the integration.
        # This allows to inject custom logic, if needed.
        z_transformed = jax.vmap(
            lambda xf: self.post_process_state(x0=x0, t0=t0, xf=xf, dt=dt)
        )(z)

        return z_transformed

    @staticmethod
    def butcher_tableau_is_valid(
        A: jax.typing.ArrayLike, b: jax.typing.ArrayLike, c: jax.typing.ArrayLike
    ) -> jtp.Bool:
        """
        Check if the Butcher tableau is valid.

        Args:
            A: The Runge-Kutta matrix.
            b: The weights coefficients.
            c: The nodes coefficients.

        Returns:
            `True` if the Butcher tableau is valid, `False` otherwise.
        """

        # Accept any array-like input (e.g. nested lists).
        A, b, c = jnp.asarray(A), jnp.asarray(b), jnp.asarray(c)

        valid = True
        valid = valid and A.ndim == 2
        valid = valid and b.ndim == 2
        valid = valid and c.ndim == 1
        valid = valid and b.T.shape[0] <= 2
        valid = valid and A.shape[0] == A.shape[1]
        valid = valid and A.shape == (c.size, b.T.shape[1])

        # The weights of each row must sum to one. Compare with a tolerance, since
        # sums of fractional floating-point weights are rarely exactly 1.
        valid = valid and bool(jnp.allclose(b.T.sum(axis=1), 1.0))

        return valid

    @staticmethod
    def butcher_tableau_is_explicit(A: jax.typing.ArrayLike) -> jtp.Bool:
        """
        Check if the Butcher tableau corresponds to an explicit integration scheme.

        Args:
            A: The Runge-Kutta matrix.

        Returns:
            `True` if the Butcher tableau is explicit, `False` otherwise.
        """

        return jnp.allclose(A, jnp.tril(A, k=-1))

    @staticmethod
    def butcher_tableau_supports_fsal(
        A: jax.typing.ArrayLike,
        b: jax.typing.ArrayLike,
        c: jax.typing.ArrayLike,
        index_of_solution: jtp.IntLike = 0,
    ) -> tuple[bool, int | None]:
        """
        Check if the Butcher tableau supports the FSAL (first-same-as-last) property.

        Args:
            A: The Runge-Kutta matrix.
            b: The weights coefficients.
            c: The nodes coefficients.
            index_of_solution:
                The index of the row of `b.T` corresponding to the solution.

        Returns:
            A tuple containing a boolean indicating whether the Butcher tableau supports
            FSAL, and the index i of the intermediate kᵢ derivative corresponding to the
            initial derivative `f(x0, t0)` of the next step (`None` if unsupported).

        Raises:
            ValueError:
                If the tableau is not valid or `index_of_solution` is out of range.
        """

        # Accept any array-like input (e.g. nested lists).
        A, b, c = jnp.asarray(A), jnp.asarray(b), jnp.asarray(c)

        if not ExplicitRungeKutta.butcher_tableau_is_valid(A=A, b=b, c=c):
            raise ValueError("The Butcher tableau is not valid.")

        # Always return a pair: callers unpack the result into two values.
        if not ExplicitRungeKutta.butcher_tableau_is_explicit(A=A):
            return False, None

        if index_of_solution >= b.T.shape[0]:
            msg = "The index of the solution (i-th row of `b.T`) is out of range."
            raise ValueError(msg)

        if c[0] != 0:
            return False, None

        # Find all the rows of A where c = 1 (therefore at t=tf). The Butcher tableau
        # supports FSAL if any of these rows (there might be more rows with c=1) matches
        # the rows of b.T corresponding to the next state (marked by `index_of_solution`).
        # This last condition means that the last kᵢ derivative is computed at (tf, xf),
        # that corresponds to the (t0, x0) pair of the next integration call.
        rows_of_A_with_fsal = (A == b.T[None, index_of_solution]).all(axis=1)
        rows_of_A_with_fsal = jnp.logical_and(rows_of_A_with_fsal, (c == 1))

        # If there is no match, it means that the Butcher tableau does not support FSAL.
        if not rows_of_A_with_fsal.any():
            return False, None

        # Return the index of the row of A providing the fsal derivative (that is the
        # possibly intermediate kᵢ derivative).
        # Note that if multiple rows match (it should not), we return the first match.
        return True, int(jnp.where(rows_of_A_with_fsal)[0][0])
@@ -0,0 +1,158 @@
1
+ from typing import ClassVar, Generic
2
+
3
+ import jax
4
+ import jax.numpy as jnp
5
+ import jax_dataclasses
6
+ import jaxlie
7
+
8
+ from jaxsim.simulation.ode_data import ODEState
9
+
10
+ from .common import ExplicitRungeKutta, PyTreeType, Time, TimeStep
11
+
12
ODEStateDerivative = ODEState  # derivatives reuse the ODEState pytree type
13
+
14
+
15
+ # =====================================================
16
+ # Explicit Runge-Kutta integrators operating on PyTrees
17
+ # =====================================================
18
+
19
+
20
@jax_dataclasses.pytree_dataclass
class ForwardEuler(ExplicitRungeKutta[PyTreeType], Generic[PyTreeType]):
    """
    Forward (explicit) Euler scheme: a single stage evaluated at `(x0, t0)`.
    """

    # Single-stage Butcher tableau: a₁₁ = 0, b₁ = 1, c₁ = 0.
    A: ClassVar[jax.typing.ArrayLike] = jnp.array([[0.0]])

    b: ClassVar[jax.typing.ArrayLike] = jnp.array([[1.0]]).T

    c: ClassVar[jax.typing.ArrayLike] = jnp.array([0.0])

    # One solution row, of first order.
    row_index_of_solution: ClassVar[int] = 0
    order_of_bT_rows: ClassVar[tuple[int, ...]] = (1,)
45
+
46
+
47
@jax_dataclasses.pytree_dataclass
class Heun(ExplicitRungeKutta[PyTreeType], Generic[PyTreeType]):
    """
    Heun's method (explicit trapezoidal rule), a second-order explicit RK scheme.

    The second stage is evaluated at `(x0 + Δt k₁, t0 + Δt)`. Therefore its RK
    matrix entry is `a₂₁ = 1`, consistently with the node `c₂ = 1`. A value of
    `1/2` would make the scheme inconsistent and only first-order accurate.
    """

    A: ClassVar[jax.typing.ArrayLike] = jnp.array(
        [
            [0, 0],
            [1, 0],
        ]
    ).astype(float)

    b: ClassVar[jax.typing.ArrayLike] = (
        jnp.atleast_2d(
            jnp.array([1 / 2, 1 / 2]),
        )
        .astype(float)
        .transpose()
    )

    c: ClassVar[jax.typing.ArrayLike] = jnp.array(
        [0, 1],
    ).astype(float)

    # One solution row, of second order.
    row_index_of_solution: ClassVar[int] = 0
    order_of_bT_rows: ClassVar[tuple[int, ...]] = (2,)
71
+
72
+
73
@jax_dataclasses.pytree_dataclass
class RungeKutta4(ExplicitRungeKutta[PyTreeType], Generic[PyTreeType]):
    """
    Classic fourth-order Runge-Kutta scheme with four stages.
    """

    # Stage coupling: each stage depends only on the preceding one.
    A: ClassVar[jax.typing.ArrayLike] = jnp.array(
        [
            [0.0, 0.0, 0.0, 0.0],
            [0.5, 0.0, 0.0, 0.0],
            [0.0, 0.5, 0.0, 0.0],
            [0.0, 0.0, 1.0, 0.0],
        ],
        dtype=float,
    )

    # Weights stored as a column, i.e. `b.T` has a single row.
    b: ClassVar[jax.typing.ArrayLike] = jnp.array(
        [[1 / 6, 1 / 3, 1 / 3, 1 / 6]], dtype=float
    ).T

    # Stage times, as fractions of the step.
    c: ClassVar[jax.typing.ArrayLike] = jnp.array([0.0, 0.5, 0.5, 1.0], dtype=float)

    # One solution row, of fourth order.
    row_index_of_solution: ClassVar[int] = 0
    order_of_bT_rows: ClassVar[tuple[int, ...]] = (4,)
99
+
100
+
101
+ # ===============================================================================
102
+ # Explicit Runge-Kutta integrators operating on ODEState and integrating on SO(3)
103
+ # ===============================================================================
104
+
105
+
106
class ExplicitRungeKuttaSO3Mixin:
    """
    Mixin for explicit RK integrators operating on `ODEState` that integrates
    the base quaternion on SO(3).

    It overrides `post_process_state`, therefore it must be listed before the
    integrator class among the bases of the derived class.
    """

    @classmethod
    def post_process_state(
        cls, x0: ODEState, t0: Time, xf: ODEState, dt: TimeStep
    ) -> ODEState:
        """
        Replace the base quaternion of the integrated state with one integrated
        on SO(3).

        Args:
            x0: The initial state of the system.
            t0: The initial time of the system (unused).
            xf: The final state obtained by the RK scheme.
            dt: The time step used for the integration.

        Returns:
            A copy of `xf` whose base quaternion is obtained by applying the
            exponential map of `dt` times the final base angular velocity to the
            initial base quaternion.
        """

        # Permutations between the wxyz serialization of `base_quaternion`
        # and the xyzw serialization used by the jaxlie conversion methods.
        wxyz_to_xyzw = jnp.array([1, 2, 3, 0])
        xyzw_to_wxyz = jnp.array([3, 0, 1, 2])

        initial_xyzw = x0.physics_model.base_quaternion[wxyz_to_xyzw]
        W_Q_B_t0 = jaxlie.SO3.from_quaternion_xyzw(xyzw=initial_xyzw)

        # The angular velocity at tf, as produced by the RK scheme.
        final_ω = xf.physics_model.base_angular_velocity

        # The angular velocity is expressed in the inertial frame, so the
        # incremental rotation multiplies the initial orientation from the left.
        Δrot = jaxlie.SO3.exp(tangent=dt * final_ω)
        W_Q_B_tf = Δrot @ W_Q_B_t0

        new_physics_model = xf.physics_model.replace(
            base_quaternion=W_Q_B_tf.as_quaternion_xyzw()[xyzw_to_wxyz]
        )

        return xf.replace(physics_model=new_physics_model, validate=True)
144
+
145
+
146
@jax_dataclasses.pytree_dataclass
class ForwardEulerSO3(ExplicitRungeKuttaSO3Mixin, ForwardEuler[ODEState]):
    """
    Forward Euler integrator on `ODEState` integrating the base quaternion on SO(3).

    It derives from `ForwardEuler`. The previous base was `Heun`, which silently
    turned this class into a Heun integrator.
    """
149
+
150
+
151
@jax_dataclasses.pytree_dataclass
class HeunSO3(ExplicitRungeKuttaSO3Mixin, Heun[ODEState]):
    """Heun integrator on `ODEState` integrating the base quaternion on SO(3)."""
154
+
155
+
156
@jax_dataclasses.pytree_dataclass
class RungeKutta4SO3(ExplicitRungeKuttaSO3Mixin, RungeKutta4[ODEState]):
    """RK4 integrator on `ODEState` integrating the base quaternion on SO(3)."""