jaxsim 0.6.2.dev2__py3-none-any.whl → 0.6.2.dev102__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (46) hide show
  1. jaxsim/__init__.py +1 -1
  2. jaxsim/_version.py +2 -2
  3. jaxsim/api/__init__.py +3 -1
  4. jaxsim/api/actuation_model.py +96 -0
  5. jaxsim/api/com.py +8 -8
  6. jaxsim/api/contact.py +15 -255
  7. jaxsim/api/contact_model.py +101 -0
  8. jaxsim/api/data.py +258 -556
  9. jaxsim/api/frame.py +7 -7
  10. jaxsim/api/integrators.py +76 -0
  11. jaxsim/api/kin_dyn_parameters.py +41 -58
  12. jaxsim/api/link.py +7 -7
  13. jaxsim/api/model.py +190 -453
  14. jaxsim/api/ode.py +34 -338
  15. jaxsim/api/references.py +2 -2
  16. jaxsim/exceptions.py +2 -2
  17. jaxsim/math/__init__.py +4 -3
  18. jaxsim/math/joint_model.py +17 -107
  19. jaxsim/mujoco/model.py +1 -1
  20. jaxsim/mujoco/utils.py +2 -2
  21. jaxsim/parsers/kinematic_graph.py +1 -3
  22. jaxsim/rbda/aba.py +7 -4
  23. jaxsim/rbda/collidable_points.py +7 -98
  24. jaxsim/rbda/contacts/__init__.py +2 -10
  25. jaxsim/rbda/contacts/common.py +0 -138
  26. jaxsim/rbda/contacts/relaxed_rigid.py +154 -9
  27. jaxsim/rbda/crba.py +5 -2
  28. jaxsim/rbda/forward_kinematics.py +37 -12
  29. jaxsim/rbda/jacobian.py +15 -6
  30. jaxsim/rbda/rnea.py +7 -4
  31. jaxsim/rbda/utils.py +3 -3
  32. jaxsim/utils/jaxsim_dataclass.py +5 -1
  33. {jaxsim-0.6.2.dev2.dist-info → jaxsim-0.6.2.dev102.dist-info}/METADATA +7 -9
  34. jaxsim-0.6.2.dev102.dist-info/RECORD +69 -0
  35. jaxsim/api/ode_data.py +0 -401
  36. jaxsim/integrators/__init__.py +0 -2
  37. jaxsim/integrators/common.py +0 -592
  38. jaxsim/integrators/fixed_step.py +0 -153
  39. jaxsim/integrators/variable_step.py +0 -706
  40. jaxsim/rbda/contacts/rigid.py +0 -462
  41. jaxsim/rbda/contacts/soft.py +0 -480
  42. jaxsim/rbda/contacts/visco_elastic.py +0 -1066
  43. jaxsim-0.6.2.dev2.dist-info/RECORD +0 -74
  44. {jaxsim-0.6.2.dev2.dist-info → jaxsim-0.6.2.dev102.dist-info}/LICENSE +0 -0
  45. {jaxsim-0.6.2.dev2.dist-info → jaxsim-0.6.2.dev102.dist-info}/WHEEL +0 -0
  46. {jaxsim-0.6.2.dev2.dist-info → jaxsim-0.6.2.dev102.dist-info}/top_level.txt +0 -0
@@ -1,592 +0,0 @@
1
- import abc
2
- import dataclasses
3
- from typing import Any, ClassVar, Generic, Protocol, TypeVar
4
-
5
- import jax
6
- import jax.numpy as jnp
7
- import jax_dataclasses
8
- from jax_dataclasses import Static
9
-
10
- import jaxsim.api as js
11
- import jaxsim.math
12
- import jaxsim.typing as jtp
13
- from jaxsim import exceptions, logging
14
- from jaxsim.utils.jaxsim_dataclass import JaxsimDataclass, Mutability
15
-
16
- try:
17
- from typing import override
18
- except ImportError:
19
- from typing_extensions import override
20
-
21
- try:
22
- from typing import Self
23
- except ImportError:
24
- from typing_extensions import Self
25
-
26
-
27
- # =============
28
- # Generic types
29
- # =============
30
-
31
- Time = jtp.FloatLike
32
- TimeStep = jtp.FloatLike
33
- State = NextState = TypeVar("State")
34
- StateDerivative = TypeVar("StateDerivative")
35
- PyTreeType = TypeVar("PyTreeType", bound=jtp.PyTree)
36
-
37
-
38
- class SystemDynamics(Protocol[State, StateDerivative]):
39
- """
40
- Protocol defining the system dynamics.
41
- """
42
-
43
- def __call__(
44
- self, x: State, t: Time, **kwargs
45
- ) -> tuple[StateDerivative, dict[str, Any]]:
46
- """
47
- Compute the state derivative of the system.
48
-
49
- Args:
50
- x: The state of the system.
51
- t: The time of the system.
52
- **kwargs: Additional keyword arguments.
53
-
54
- Returns:
55
- The state derivative of the system and the auxiliary dictionary.
56
- """
57
- pass
58
-
59
-
60
- # =======================
61
- # Base integrator classes
62
- # =======================
63
-
64
-
65
- @jax_dataclasses.pytree_dataclass
66
- class Integrator(JaxsimDataclass, abc.ABC, Generic[State, StateDerivative]):
67
- """
68
- Factory class for integrators.
69
- """
70
-
71
- dynamics: Static[SystemDynamics[State, StateDerivative]] = dataclasses.field(
72
- repr=False, hash=False, compare=False, kw_only=True
73
- )
74
-
75
- @classmethod
76
- def build(
77
- cls: type[Self],
78
- *,
79
- dynamics: SystemDynamics[State, StateDerivative],
80
- **kwargs,
81
- ) -> Self:
82
- """
83
- Build the integrator object.
84
-
85
- Args:
86
- dynamics: The system dynamics.
87
- **kwargs: Additional keyword arguments to build the integrator.
88
-
89
- Returns:
90
- The integrator object.
91
- """
92
-
93
- return cls(dynamics=dynamics, **kwargs)
94
-
95
- def step(
96
- self,
97
- x0: State,
98
- t0: Time,
99
- dt: TimeStep,
100
- *,
101
- metadata: dict[str, Any] | None = None,
102
- **kwargs,
103
- ) -> tuple[NextState, dict[str, Any]]:
104
- """
105
- Perform a single integration step.
106
-
107
- Args:
108
- x0: The initial state of the system.
109
- t0: The initial time of the system.
110
- dt: The time step of the integration.
111
- metadata: The state auxiliary dictionary of the integrator.
112
- **kwargs: Additional keyword arguments.
113
-
114
- Returns:
115
- The final state of the system and the updated auxiliary dictionary.
116
- """
117
-
118
- metadata = metadata if metadata is not None else {}
119
-
120
- with self.mutable_context(mutability=Mutability.MUTABLE) as integrator:
121
- xf, metadata_step = integrator(x0, t0, dt, **kwargs)
122
-
123
- return (
124
- xf,
125
- metadata | metadata_step,
126
- )
127
-
128
- @abc.abstractmethod
129
- def __call__(
130
- self, x0: State, t0: Time, dt: TimeStep, **kwargs
131
- ) -> tuple[NextState, dict[str, Any]]:
132
- """
133
- Perform a single integration step.
134
- """
135
- pass
136
-
137
- def init(
138
- self,
139
- x0: State,
140
- t0: Time,
141
- dt: TimeStep,
142
- *,
143
- include_dynamics_aux_dict: bool = False,
144
- **kwargs,
145
- ) -> dict[str, Any]:
146
- """
147
- Initialize the integrator. This method is deprecated.
148
- """
149
-
150
- logging.warning(
151
- "The 'init' method has been deprecated. There is no need to call it."
152
- )
153
-
154
- return {}
155
-
156
-
157
- @jax_dataclasses.pytree_dataclass
158
- class ExplicitRungeKutta(Integrator[PyTreeType, PyTreeType], Generic[PyTreeType]):
159
- """
160
- Base class for explicit Runge-Kutta integrators.
161
-
162
- Attributes:
163
- A: The Runge-Kutta matrix.
164
- b: The weights coefficients.
165
- c: The nodes coefficients.
166
- order_of_bT_rows: The order of the solution.
167
- row_index_of_solution: The row of the integration output corresponding to the final solution.
168
- fsal_enabled_if_supported: Whether to enable the FSAL property, if supported.
169
- index_of_fsal: The index of the intermediate derivative to be used as the first derivative of the next iteration.
170
- """
171
-
172
- # The Runge-Kutta matrix.
173
- A: jtp.Matrix
174
-
175
- # The weights coefficients.
176
- # Note that in practice we typically use its transpose `b.transpose()`.
177
- b: jtp.Matrix
178
-
179
- # The nodes coefficients.
180
- c: jtp.Vector
181
-
182
- # Define the order of the solution.
183
- # It should have as many elements as the number of rows of `b.transpose()`.
184
- order_of_bT_rows: ClassVar[tuple[int, ...]]
185
-
186
- # Define the row of the integration output corresponding to the final solution.
187
- # This is the row of b.T that produces the final state.
188
- row_index_of_solution: ClassVar[int]
189
-
190
- # Attributes of FSAL (first-same-as-last) property.
191
- fsal_enabled_if_supported: Static[bool] = dataclasses.field(repr=False)
192
- index_of_fsal: Static[jtp.IntLike | None] = dataclasses.field(repr=False)
193
-
194
- @property
195
- def has_fsal(self) -> bool:
196
- """
197
- Check if the integrator supports the FSAL property.
198
- """
199
- return self.fsal_enabled_if_supported and self.index_of_fsal is not None
200
-
201
- @property
202
- def order(self) -> int:
203
- """
204
- Return the order of the integrator.
205
- """
206
- return self.order_of_bT_rows[self.row_index_of_solution]
207
-
208
- @override
209
- @classmethod
210
- def build(
211
- cls: type[Self],
212
- *,
213
- dynamics: SystemDynamics[State, StateDerivative],
214
- fsal_enabled_if_supported: jtp.BoolLike = True,
215
- **kwargs,
216
- ) -> Self:
217
- """
218
- Build the integrator object.
219
-
220
- Args:
221
- dynamics: The system dynamics.
222
- fsal_enabled_if_supported:
223
- Whether to enable the FSAL property, if supported.
224
- **kwargs: Additional keyword arguments to build the integrator.
225
-
226
- Returns:
227
- The integrator object.
228
- """
229
- A = cls.__dataclass_fields__["A"].default_factory()
230
- b = cls.__dataclass_fields__["b"].default_factory()
231
- c = cls.__dataclass_fields__["c"].default_factory()
232
-
233
- # Check validity of the Butcher tableau.
234
- if not ExplicitRungeKutta.butcher_tableau_is_valid(A=A, b=b, c=c):
235
- raise ValueError("The Butcher tableau of this class is not valid.")
236
-
237
- # Check that b.T has enough rows based on the configured index of the solution.
238
- if cls.row_index_of_solution >= b.T.shape[0]:
239
- msg = "The index of the solution ({}-th row of `b.T`) is out of range ({})."
240
- raise ValueError(msg.format(cls.row_index_of_solution, b.T.shape[0]))
241
-
242
- # Check that the tuple containing the order of the b.T rows matches the number
243
- # of the b.T rows.
244
- if len(cls.order_of_bT_rows) != b.T.shape[0]:
245
- msg = "Wrong size of 'order_of_bT_rows' ({}), should be {}."
246
- raise ValueError(msg.format(len(cls.order_of_bT_rows), b.T.shape[0]))
247
-
248
- # Check if the Butcher tableau supports FSAL (first-same-as-last).
249
- # If it does, store the index of the intermediate derivative to be used as the
250
- # first derivative of the next iteration.
251
- has_fsal, index_of_fsal = ( # noqa: F841
252
- ExplicitRungeKutta.butcher_tableau_supports_fsal(
253
- A=A, b=b, c=c, index_of_solution=cls.row_index_of_solution
254
- )
255
- )
256
-
257
- # Build the integrator object.
258
- integrator = super().build(
259
- dynamics=dynamics,
260
- index_of_fsal=index_of_fsal,
261
- fsal_enabled_if_supported=bool(fsal_enabled_if_supported),
262
- **kwargs,
263
- )
264
-
265
- return integrator
266
-
267
- def __call__(
268
- self, x0: State, t0: Time, dt: TimeStep, **kwargs
269
- ) -> tuple[NextState, dict[str, Any]]:
270
- """
271
- Perform a single integration step.
272
- """
273
-
274
- # Here z is a batched state with as many batch elements as b.T rows.
275
- # Note that z has multiple batches only if b.T has more than one row,
276
- # e.g. in Butcher tableau of embedded schemes.
277
- z, aux_dict = self._compute_next_state(x0=x0, t0=t0, dt=dt, **kwargs)
278
-
279
- # The next state is the batch element located at the configured index of solution.
280
- next_state = jax.tree.map(lambda l: l[self.row_index_of_solution], z)
281
-
282
- return next_state, aux_dict
283
-
284
- @classmethod
285
- def integrate_rk_stage(
286
- cls, x0: State, t0: Time, dt: TimeStep, k: StateDerivative
287
- ) -> NextState:
288
- """
289
- Integrate a single stage of the Runge-Kutta method.
290
-
291
- Args:
292
- x0: The initial state of the system.
293
- t0: The initial time of the system.
294
- dt:
295
- The time step of the RK integration scheme. Note that this is
296
- not the stage timestep, as it depends on the `A` matrix used
297
- to compute the `k` argument.
298
- k:
299
- The RK state derivative of the current stage, weighted with
300
- the `A` matrix.
301
-
302
- Returns:
303
- The state at the next stage of the integration.
304
-
305
- Note:
306
- In the most generic case, `k` could be an arbitrary composition
307
- of the kᵢ derivatives, depending on the RK matrix A.
308
-
309
- Note:
310
- Overriding this method allows users to use different classes
311
- defining `State` and `StateDerivative`. Be aware that the
312
- timestep `dt` is not the stage timestep, therefore the map
313
- used to convert the state derivative must be time-independent.
314
- """
315
-
316
- op = lambda x0_leaf, k_leaf: x0_leaf + dt * k_leaf
317
- return jax.tree.map(op, x0, k)
318
-
319
- @classmethod
320
- def post_process_state(
321
- cls, x0: State, t0: Time, xf: NextState, dt: TimeStep
322
- ) -> NextState:
323
- r"""
324
- Post-process the integrated state at :math:`t_f = t_0 + \Delta t`.
325
-
326
- Args:
327
- x0: The initial state of the system.
328
- t0: The initial time of the system.
329
- xf: The final state of the system obtain through the integration.
330
- dt: The time step used for the integration.
331
-
332
- Returns:
333
- The post-processed integrated state.
334
- """
335
-
336
- return xf
337
-
338
- def _compute_next_state(
339
- self, x0: State, t0: Time, dt: TimeStep, **kwargs
340
- ) -> tuple[NextState, dict[str, Any]]:
341
- """
342
- Compute the next state of the system, returning all the output states.
343
-
344
- Args:
345
- x0: The initial state of the system.
346
- t0: The initial time of the system.
347
- dt: The time step of the integration.
348
- **kwargs: Additional keyword arguments.
349
-
350
- Returns:
351
- A batched state with as many batch elements as `b.T` rows.
352
- """
353
-
354
- # Call variables with better symbols.
355
- Δt = dt
356
- c = self.c
357
- b = self.b
358
- A = self.A
359
-
360
- # Extract metadata from the kwargs.
361
- metadata = kwargs.pop("metadata", {})
362
-
363
- # Close f over optional kwargs.
364
- f = lambda x, t: self.dynamics(x=x, t=t, **kwargs)
365
-
366
- # Initialize the carry of the for loop with the stacked kᵢ vectors.
367
- carry0 = jax.tree.map(
368
- lambda l: jnp.zeros((c.size, *l.shape), dtype=l.dtype), x0
369
- )
370
-
371
- # Closure on metadata to either evaluate the dynamics at the initial state
372
- # or to use the previous state derivative (only integrators supporting FSAL).
373
- def get_ẋ0_and_aux_dict() -> tuple[StateDerivative, dict[str, Any]]:
374
- ẋ0, aux_dict = f(x0, t0)
375
- return metadata.get("dxdt0", ẋ0), aux_dict
376
-
377
- # We use a `jax.lax.scan` to compile the `f` function only once.
378
- # Otherwise, if we compute e.g. for RK4 sequentially, the jit-compiled code
379
- # would include 4 repetitions of the `f` logic, making everything extremely slow.
380
- def scan_body(
381
- carry: jax.Array, i: int | jax.Array
382
- ) -> tuple[jax.Array, dict[str, Any]]:
383
- """
384
- Compute the kᵢ derivative of the Runge-Kutta stage.
385
- """
386
-
387
- # Unpack the carry, i.e. the stacked kᵢ vectors.
388
- K = carry
389
-
390
- # Define the computation of the Runge-Kutta stage.
391
- def compute_ki() -> tuple[jax.Array, dict[str, Any]]:
392
-
393
- # Compute ∑ⱼ aᵢⱼ kⱼ.
394
- op_sum_ak = lambda k: jnp.einsum("s,s...->...", A[i], k)
395
- sum_ak = jax.tree.map(op_sum_ak, K)
396
-
397
- # Compute the next state for the kᵢ evaluation.
398
- # Note that this is not a Δt integration since aᵢⱼ could be fractional.
399
- xi = self.integrate_rk_stage(x0, t0, Δt, sum_ak)
400
-
401
- # Compute the next time for the kᵢ evaluation.
402
- ti = t0 + c[i] * Δt
403
-
404
- # Evaluate the dynamics.
405
- ki, aux_dict = f(xi, ti)
406
- return ki, aux_dict
407
-
408
- # This selector enables FSAL property in the first iteration (i=0).
409
- ki, aux_dict = jax.lax.cond(
410
- pred=jnp.logical_and(i == 0, self.has_fsal),
411
- true_fun=get_ẋ0_and_aux_dict,
412
- false_fun=compute_ki,
413
- )
414
-
415
- # Store the kᵢ derivative in K.
416
- op = lambda l_k, l_ki: l_k.at[i].set(l_ki)
417
- K = jax.tree.map(op, K, ki)
418
-
419
- carry = K
420
- return carry, aux_dict
421
-
422
- # Compute the state derivatives kᵢ.
423
- K, aux_dict = jax.lax.scan(
424
- f=scan_body,
425
- init=carry0,
426
- xs=jnp.arange(c.size),
427
- )
428
-
429
- # Update the FSAL property for the next iteration.
430
- if self.has_fsal:
431
- # Store the first derivative of the next step in the metadata.
432
- metadata["dxdt0"] = jax.tree.map(lambda l: l[self.index_of_fsal], K)
433
-
434
- # Compute the output state.
435
- # Note that z contains as many new states as the rows of `b.T`.
436
- op = lambda x0, k: x0 + Δt * jnp.einsum("zs,s...->z...", b.T, k)
437
- z = jax.tree.map(op, x0, K)
438
-
439
- # Transform the final state of the integration.
440
- # This allows to inject custom logic, if needed.
441
- z_transformed = jax.vmap(
442
- lambda xf: self.post_process_state(x0=x0, t0=t0, xf=xf, dt=dt)
443
- )(z)
444
-
445
- return z_transformed, aux_dict | {"metadata": metadata}
446
-
447
- @staticmethod
448
- def butcher_tableau_is_valid(
449
- A: jtp.Matrix, b: jtp.Matrix, c: jtp.Vector
450
- ) -> jtp.Bool:
451
- """
452
- Check if the Butcher tableau is valid.
453
-
454
- Args:
455
- A: The Runge-Kutta matrix.
456
- b: The weights coefficients.
457
- c: The nodes coefficients.
458
-
459
- Returns:
460
- `True` if the Butcher tableau is valid, `False` otherwise.
461
- """
462
-
463
- valid = True
464
- valid = valid and A.ndim == 2
465
- valid = valid and b.ndim == 2
466
- valid = valid and c.ndim == 1
467
- valid = valid and b.T.shape[0] <= 2
468
- valid = valid and A.shape[0] == A.shape[1]
469
- valid = valid and A.shape == (c.size, b.T.shape[1])
470
- valid = valid and bool(jnp.all(b.T.sum(axis=1) == 1))
471
-
472
- return valid
473
-
474
- @staticmethod
475
- def butcher_tableau_is_explicit(A: jtp.Matrix) -> jtp.Bool:
476
- """
477
- Check if the Butcher tableau corresponds to an explicit integration scheme.
478
-
479
- Args:
480
- A: The Runge-Kutta matrix.
481
-
482
- Returns:
483
- `True` if the Butcher tableau is explicit, `False` otherwise.
484
- """
485
-
486
- return jnp.allclose(A, jnp.tril(A, k=-1))
487
-
488
- @staticmethod
489
- def butcher_tableau_supports_fsal(
490
- A: jtp.Matrix,
491
- b: jtp.Matrix,
492
- c: jtp.Vector,
493
- index_of_solution: jtp.IntLike = 0,
494
- ) -> tuple[bool, int | None]:
495
- """
496
- Check if the Butcher tableau supports the FSAL (first-same-as-last) property.
497
-
498
- Args:
499
- A: The Runge-Kutta matrix.
500
- b: The weights coefficients.
501
- c: The nodes coefficients.
502
- index_of_solution:
503
- The index of the row of `b.T` corresponding to the solution.
504
-
505
- Returns:
506
- A tuple containing a boolean indicating whether the Butcher tableau supports
507
- FSAL, and the index i of the intermediate kᵢ derivative corresponding to the
508
- initial derivative `f(x0, t0)` of the next step.
509
- """
510
-
511
- if not ExplicitRungeKutta.butcher_tableau_is_valid(A=A, b=b, c=c):
512
- raise ValueError("The Butcher tableau is not valid.")
513
-
514
- if not ExplicitRungeKutta.butcher_tableau_is_explicit(A=A):
515
- return False, None
516
-
517
- if index_of_solution >= b.T.shape[0]:
518
- msg = "The index of the solution (i-th row of `b.T`) is out of range."
519
- raise ValueError(msg)
520
-
521
- if c[0] != 0:
522
- return False, None
523
-
524
- # Find all the rows of A where c = 1 (therefore at t=tf). The Butcher tableau
525
- # supports FSAL if any of these rows (there might be more rows with c=1) matches
526
- # the rows of b.T corresponding to the next state (marked by `index_of_solution`).
527
- # This last condition means that the last kᵢ derivative is computed at (tf, xf),
528
- # that corresponds to the (t0, x0) pair of the next integration call.
529
- rows_of_A_with_fsal = (A == b.T[None, index_of_solution]).all(axis=1)
530
- rows_of_A_with_fsal = jnp.logical_and(rows_of_A_with_fsal, (c == 1))
531
-
532
- # If there is no match, it means that the Butcher tableau does not support FSAL.
533
- if not rows_of_A_with_fsal.any():
534
- return False, None
535
-
536
- # Return the index of the row of A providing the fsal derivative (that is the
537
- # possibly intermediate kᵢ derivative).
538
- # Note that if multiple rows match (it should not), we return the first match.
539
- return True, int(jnp.where(rows_of_A_with_fsal)[0].tolist()[0])
540
-
541
-
542
- class ExplicitRungeKuttaSO3Mixin:
543
- """
544
- Mixin class to apply over explicit RK integrators defined on
545
- `PyTreeType = ODEState` to integrate the quaternion on SO(3).
546
- """
547
-
548
- @classmethod
549
- def post_process_state(
550
- cls, x0: js.ode_data.ODEState, t0: Time, xf: js.ode_data.ODEState, dt: TimeStep
551
- ) -> js.ode_data.ODEState:
552
- r"""
553
- Post-process the integrated state at :math:`t_f = t_0 + \Delta t` so that the
554
- quaternion is normalized.
555
-
556
- Args:
557
- x0: The initial state of the system.
558
- t0: The initial time of the system.
559
- xf: The final state of the system obtain through the integration.
560
- dt: The time step used for the integration.
561
- """
562
-
563
- # Extract the initial base quaternion.
564
- W_Q_B_t0 = x0.physics_model.base_quaternion
565
-
566
- # We assume that the initial quaternion is already unary.
567
- exceptions.raise_runtime_error_if(
568
- condition=~jnp.allclose(W_Q_B_t0.dot(W_Q_B_t0), 1.0),
569
- msg="The SO(3) integrator received a quaternion at t0 that is not unary.",
570
- )
571
-
572
- # Get the angular velocity ω to integrate the quaternion.
573
- # This velocity ω[t0] is computed in the previous timestep by averaging the kᵢ
574
- # corresponding to the active RK-based scheme. Therefore, by using the ω[t0],
575
- # we obtain an explicit RK scheme operating on the SO(3) manifold.
576
- # Note that the current integrator is not a semi-implicit scheme, therefore
577
- # using the final ω[tf] would be not correct.
578
- W_ω_WB_t0 = x0.physics_model.base_angular_velocity
579
-
580
- # Integrate the quaternion on SO(3).
581
- W_Q_B_tf = jaxsim.math.Quaternion.integration(
582
- quaternion=W_Q_B_t0,
583
- dt=dt,
584
- omega=W_ω_WB_t0,
585
- omega_in_body_fixed=False,
586
- )
587
-
588
- # Replace the quaternion in the final state.
589
- return xf.replace(
590
- physics_model=xf.physics_model.replace(base_quaternion=W_Q_B_tf),
591
- validate=True,
592
- )