jaxsim 0.6.2.dev2__py3-none-any.whl → 0.6.2.dev102__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (46)
  1. jaxsim/__init__.py +1 -1
  2. jaxsim/_version.py +2 -2
  3. jaxsim/api/__init__.py +3 -1
  4. jaxsim/api/actuation_model.py +96 -0
  5. jaxsim/api/com.py +8 -8
  6. jaxsim/api/contact.py +15 -255
  7. jaxsim/api/contact_model.py +101 -0
  8. jaxsim/api/data.py +258 -556
  9. jaxsim/api/frame.py +7 -7
  10. jaxsim/api/integrators.py +76 -0
  11. jaxsim/api/kin_dyn_parameters.py +41 -58
  12. jaxsim/api/link.py +7 -7
  13. jaxsim/api/model.py +190 -453
  14. jaxsim/api/ode.py +34 -338
  15. jaxsim/api/references.py +2 -2
  16. jaxsim/exceptions.py +2 -2
  17. jaxsim/math/__init__.py +4 -3
  18. jaxsim/math/joint_model.py +17 -107
  19. jaxsim/mujoco/model.py +1 -1
  20. jaxsim/mujoco/utils.py +2 -2
  21. jaxsim/parsers/kinematic_graph.py +1 -3
  22. jaxsim/rbda/aba.py +7 -4
  23. jaxsim/rbda/collidable_points.py +7 -98
  24. jaxsim/rbda/contacts/__init__.py +2 -10
  25. jaxsim/rbda/contacts/common.py +0 -138
  26. jaxsim/rbda/contacts/relaxed_rigid.py +154 -9
  27. jaxsim/rbda/crba.py +5 -2
  28. jaxsim/rbda/forward_kinematics.py +37 -12
  29. jaxsim/rbda/jacobian.py +15 -6
  30. jaxsim/rbda/rnea.py +7 -4
  31. jaxsim/rbda/utils.py +3 -3
  32. jaxsim/utils/jaxsim_dataclass.py +5 -1
  33. {jaxsim-0.6.2.dev2.dist-info → jaxsim-0.6.2.dev102.dist-info}/METADATA +7 -9
  34. jaxsim-0.6.2.dev102.dist-info/RECORD +69 -0
  35. jaxsim/api/ode_data.py +0 -401
  36. jaxsim/integrators/__init__.py +0 -2
  37. jaxsim/integrators/common.py +0 -592
  38. jaxsim/integrators/fixed_step.py +0 -153
  39. jaxsim/integrators/variable_step.py +0 -706
  40. jaxsim/rbda/contacts/rigid.py +0 -462
  41. jaxsim/rbda/contacts/soft.py +0 -480
  42. jaxsim/rbda/contacts/visco_elastic.py +0 -1066
  43. jaxsim-0.6.2.dev2.dist-info/RECORD +0 -74
  44. {jaxsim-0.6.2.dev2.dist-info → jaxsim-0.6.2.dev102.dist-info}/LICENSE +0 -0
  45. {jaxsim-0.6.2.dev2.dist-info → jaxsim-0.6.2.dev102.dist-info}/WHEEL +0 -0
  46. {jaxsim-0.6.2.dev2.dist-info → jaxsim-0.6.2.dev102.dist-info}/top_level.txt +0 -0
@@ -1,706 +0,0 @@
1
- import dataclasses
2
- import functools
3
- from typing import Any, ClassVar, Generic
4
-
5
- try:
6
- from typing import Self
7
- except ImportError:
8
- from typing_extensions import Self
9
-
10
- import jax
11
- import jax.flatten_util
12
- import jax.numpy as jnp
13
- import jax_dataclasses
14
- from jax_dataclasses import Static
15
-
16
- import jaxsim.utils.tracing
17
- from jaxsim import typing as jtp
18
-
19
- from .common import (
20
- ExplicitRungeKutta,
21
- ExplicitRungeKuttaSO3Mixin,
22
- NextState,
23
- PyTreeType,
24
- State,
25
- StateDerivative,
26
- SystemDynamics,
27
- Time,
28
- TimeStep,
29
- )
30
-
31
# Default integration tolerances. For robot dynamics they already yield fairly
# accurate results: tightening them buys accuracy at the price of a smaller Δt,
# loosening them allows larger steps at the price of less accurate dynamics.
RTOL_DEFAULT = 1e-4  # relative tolerance (0.01%)
ATOL_DEFAULT = 1e-5  # absolute tolerance (10μ)

# Default parameters of the adaptive step-size logic of Embedded RK schemes.
SAFETY_DEFAULT = 0.9  # safety factor multiplying the proposed step-size ratio
BETA_MIN_DEFAULT = 0.1  # lower bound of the step-size update factor
BETA_MAX_DEFAULT = 2.5  # upper bound of the step-size update factor
MAX_STEP_REJECTIONS_DEFAULT = 5  # rejections tolerated before a step is accepted
42
-
43
-
44
- # =================
45
- # Utility functions
46
- # =================
47
-
48
-
49
@functools.partial(jax.jit, static_argnames=["f"])
def estimate_step_size(
    x0: jtp.PyTree,
    t0: Time,
    f: SystemDynamics,
    order: jtp.IntLike,
    rtol: jtp.FloatLike = RTOL_DEFAULT,
    atol: jtp.FloatLike = ATOL_DEFAULT,
) -> tuple[jtp.Float, jtp.PyTree]:
    r"""
    Estimate a starting step size to warm-start variable-step integrators.

    A first guess :math:`h_0` is built from the scaled magnitudes of the
    initial state and of its derivative. It is then refined by taking one
    explicit Euler step of length :math:`h_0` and measuring how much the state
    derivative changes over it.

    Args:
        x0: The initial state.
        t0: The initial time.
        f:
            The state derivative function :math:`f(x, t)`. The first element
            of its output is used as the state derivative.
        order:
            The order :math:`p` of an integrator with truncation error
            :math:`\mathcal{O}(\Delta t^{p+1})`.
        rtol: The relative tolerance to scale the state.
        atol: The absolute tolerance to scale the state.

    Returns:
        A tuple containing the proposed initial step size and the state
        derivative :math:`\dot{x} = f(x_0, t_0)`, returned so that callers can
        reuse it instead of evaluating the dynamics again.

    Note:
        Interested readers could find implementation details in:

        Solving Ordinary Differential Equations I: Nonstiff Problems, Sec. II.4.
        E. Hairer, S. P. Norsett G. Wanner.
    """

    def inf_norm(tree: jtp.PyTree) -> jax.Array:
        # Largest entry of the pytree once all its leaves are flattened to 1D.
        flat_tree = jax.flatten_util.ravel_pytree(tree)[0]
        return jnp.linalg.norm(flat_tree, ord=jnp.inf)

    def self_scaled(tree: jtp.PyTree) -> jtp.PyTree:
        # |v| / (atol + |v| * rtol), each leaf scaled by a tolerance built from itself.
        return jax.tree.map(lambda v: jnp.abs(v) / (atol + jnp.abs(v) * rtol), tree)

    def scaled_difference(a: jax.Array, b: jax.Array) -> jax.Array:
        # |a - b| scaled by a tolerance built from the larger of |a| and |b|.
        return jnp.abs((a - b) / (atol + jnp.maximum(jnp.abs(a), jnp.abs(b)) * rtol))

    # State derivative at the initial state.
    xdot_start = f(x0, t0)[0]

    # Scaled magnitudes of the initial state and of its derivative.
    d0 = inf_norm(self_scaled(x0))
    d1 = inf_norm(self_scaled(xdot_start))

    # First guess, falling back to a tiny step when either magnitude is negligible.
    h0 = jnp.where(jnp.minimum(d0, d1) <= 1e-5, 1e-6, 0.01 * d0 / d1)

    # Take one explicit Euler step of length h0 and evaluate the derivative there.
    x_euler = jax.tree.map(lambda x, xdot: x + h0 * xdot, x0, xdot_start)
    xdot_euler = f(x_euler, t0 + h0)[0]

    # Rate of change of the scaled derivative over the Euler step.
    d2 = inf_norm(jax.tree.map(scaled_difference, xdot_start, xdot_euler)) / h0

    # Second guess, based on the largest of the two derivative estimates.
    d_max = jnp.maximum(d1, d2)
    h1 = jnp.where(
        d_max <= 1e-15,
        jnp.maximum(1e-6, h0 * 1e-3),
        (0.01 / d_max) ** (1.0 / (order + 1.0)),
    )

    # Never propose more than a hundredfold increase of the first guess.
    proposed = jnp.minimum(100.0 * h0, h1)

    return jnp.array(proposed, dtype=float), xdot_start
132
-
133
-
134
@jax.jit
def compute_pytree_scale(
    x1: jtp.PyTree,
    x2: jtp.PyTree | None = None,
    rtol: jtp.FloatLike = RTOL_DEFAULT,
    atol: jtp.FloatLike = ATOL_DEFAULT,
) -> jtp.PyTree:
    """
    Build the component-wise tolerance used to scale dynamical states.

    Every leaf of the result is ``atol + max(|x1|, |x2|) * rtol``, computed
    element-wise on the matching leaves of the two states.

    Args:
        x1: The first state (often the initial state).
        x2:
            The optional second state (often the final state). When omitted,
            an all-zero pytree shaped like ``x1`` is used, so the scale
            depends on ``x1`` alone.
        rtol: The relative tolerance to scale the state.
        atol: The absolute tolerance to scale the state.

    Returns:
        A pytree with the same structure of the state containing the scaling factors.
    """

    if x2 is None:
        x2 = jax.tree.map(jnp.zeros_like, x1)

    def leaf_scale(leaf_1: jax.Array, leaf_2: jax.Array) -> jax.Array:
        largest = jnp.maximum(jnp.abs(leaf_1), jnp.abs(leaf_2))
        return atol + largest * rtol

    return jax.tree.map(leaf_scale, x1, x2)
162
-
163
-
164
# `norm_ord` must be static: `jnp.linalg.norm` selects the norm with Python-level
# branching on `ord`, which fails if it is passed explicitly and traced by `jit`.
@functools.partial(jax.jit, static_argnames=["norm_ord"])
def local_error_estimation(
    xf: jtp.PyTree,
    xf_estimate: jtp.PyTree | None = None,
    x0: jtp.PyTree | None = None,
    rtol: jtp.FloatLike = RTOL_DEFAULT,
    atol: jtp.FloatLike = ATOL_DEFAULT,
    norm_ord: jtp.IntLike | jtp.FloatLike = jnp.inf,
) -> jtp.Float:
    """
    Estimate the local integration error, often used in Embedded RK schemes.

    The error is the norm of ``|xf - xf_estimate| / scale``, where the scale is
    computed by ``compute_pytree_scale`` from ``xf`` and ``x0``.

    Args:
        xf: The final state, often computed with the most accurate integrator.
        xf_estimate:
            The estimated final state, often computed with the less accurate integrator.
            If missing, it is initialized to zero.
        x0:
            The initial state to compute the scaling factors. If missing, it is
            initialized to zero.
        rtol: The relative tolerance to scale the state.
        atol: The absolute tolerance to scale the state.
        norm_ord:
            The norm to use to compute the error. Default is the infinity norm.
            It is a static (hashable) argument of the jitted function.

    Returns:
        The estimated local integration error.
    """

    # Helper to flatten a pytree to a 1D vector.
    def flatten(pytree) -> jax.Array:
        return jax.flatten_util.ravel_pytree(pytree)[0]

    # Compute the scale considering the initial and final states.
    scale = compute_pytree_scale(x1=xf, x2=x0, rtol=rtol, atol=atol)

    # Consider a zero estimated final state, if not given.
    if xf_estimate is None:
        xf_estimate = jax.tree.map(jnp.zeros_like, xf)

    # Scaled component-wise distance between the two final states.
    error_estimate = jax.tree.map(
        lambda leaf, leaf_estimate, sc: jnp.abs(leaf - leaf_estimate) / sc,
        xf,
        xf_estimate,
        scale,
    )

    # Reduce the scaled error to a scalar with the requested norm.
    return jnp.linalg.norm(flatten(error_estimate), ord=norm_ord)
211
-
212
-
213
- # ================================
214
- # Embedded Runge-Kutta integrators
215
- # ================================
216
-
217
-
218
@jax_dataclasses.pytree_dataclass
class EmbeddedRungeKutta(ExplicitRungeKutta[PyTreeType], Generic[PyTreeType]):
    """
    An Embedded Runge-Kutta integrator.

    This class implements a general-purpose Embedded Runge-Kutta integrator
    that can be used to solve ordinary differential equations with adaptive
    step sizes.

    The integrator is based on an Explicit Runge-Kutta method, and it uses
    two different solutions to estimate the local integration error. The
    error is then used to adapt the step size to reach a desired accuracy.
    """

    AfterInitKey: ClassVar[str] = "after_init"
    InitializingKey: ClassVar[str] = "initializing"

    # Define the row of the integration output corresponding to the solution estimate.
    # This is the row of b.T that produces the state used e.g. by embedded methods to
    # implement the adaptive timestep logic.
    row_index_of_solution_estimate: ClassVar[int | None] = None

    # Bounds of the adaptive Δt.
    dt_max: Static[jtp.FloatLike] = jnp.inf
    dt_min: Static[jtp.FloatLike] = -jnp.inf

    # Tolerances used to scale the two states corresponding to the high-order solution
    # and the low-order estimate during the computation of the local integration error.
    rtol: Static[jtp.FloatLike] = RTOL_DEFAULT
    atol: Static[jtp.FloatLike] = ATOL_DEFAULT

    # Parameters of the adaptive timestep logic.
    # Refer to Eq. (4.13) pag. 168 of Hairer93.
    safety: Static[jtp.FloatLike] = SAFETY_DEFAULT
    beta_max: Static[jtp.FloatLike] = BETA_MAX_DEFAULT
    beta_min: Static[jtp.FloatLike] = BETA_MIN_DEFAULT

    # Maximum number of rejected steps when the Δt needs to be reduced.
    max_step_rejections: Static[jtp.IntLike] = MAX_STEP_REJECTIONS_DEFAULT

    # FSAL configuration. Presumably consumed by the ExplicitRungeKutta base
    # class (not visible here); subclasses may override it.
    index_of_fsal: jtp.IntLike | None = None
    fsal_enabled_if_supported: bool = False

    def init(
        self,
        x0: State,
        t0: Time,
        dt: TimeStep,
        **kwargs,
    ) -> dict[str, Any]:
        """
        Initialize the integrator and get the metadata.

        The integrator is called once only to discover the structure of its
        metadata. All the returned entries are filled with NaN, so that the
        first actual step estimates its own initial step size.

        Args:
            x0: The initial state of the system.
            t0: The initial time of the system.
            dt: The time step of the integration.
            **kwargs: Additional parameters.

        Returns:
            The metadata of the integrator to be passed to the first step.

        Raises:
            RuntimeError: If called within a JIT context.
        """

        if jaxsim.utils.tracing(var=jnp.zeros(0)):
            raise RuntimeError("This method cannot be used within a JIT context")

        with self.editable(validate=False) as integrator:

            # Inject this key to signal that the integrator is initializing.
            # `__call__` stores its `dt0` and `dxdt0` entries into this same
            # dict object, which is how their structure is retrieved below.
            metadata = {EmbeddedRungeKutta.InitializingKey: jnp.array(True)}

            # Run a dummy call of the integrator.
            # It is used only to get the metadata so that we know the structure
            # of the corresponding pytree.
            _ = integrator(
                x0,
                jnp.array(t0, dtype=float),
                jnp.array(dt, dtype=float),
                **(kwargs | {"metadata": metadata}),
            )

        # Remove the injected key.
        _ = metadata.pop(EmbeddedRungeKutta.InitializingKey)

        # Make sure that all leaves of the dictionary are JAX arrays.
        # Also, since these are dummy parameters, set them all to NaN.
        metadata_after_init = jax.tree.map(
            lambda l: jnp.nan * jnp.zeros_like(l), metadata
        )

        return metadata_after_init

    def __call__(
        self, x0: State, t0: Time, dt: TimeStep, **kwargs
    ) -> tuple[NextState, dict[str, Any]]:
        """
        Integrate the system for a single step.

        The interval ``[t0, t0 + dt]`` is covered by one or more adaptive
        sub-steps, each accepted or rejected based on the local integration error.

        Args:
            x0: The initial state of the system.
            t0: The initial time of the system.
            dt: The time step of the integration.
            **kwargs:
                Additional parameters forwarded to the system dynamics and to
                the underlying explicit RK step. The optional ``metadata``
                entry carries the adaptive-step information between calls
                (see ``init``).

        Returns:
            A tuple with the final state and a dictionary holding the updated
            metadata under the ``"metadata"`` key.
        """

        # This method is called differently in three stages:
        #
        # 1. During initialization, to allocate a dummy metadata dictionary.
        #    The metadata is a dictionary of float JAX arrays, that are initialized
        #    with the right shape and filled with NaNs.
        # 2. During the first step, this method operates on the NaN-filled
        #    `metadata` argument, and it populates it with the actual metadata.
        # 3. After the first step, this method operates on the actual metadata.
        #
        # In particular, we store the following information in the metadata:
        # - The first attempt of the step size, `dt0`. This is either estimated during
        #   phase 2, or taken from the previous step during phase 3.
        # - For integrators that support FSAL, the derivative at the initial state
        #   computed during the previous step. This can be done because FSAL integrators
        #   evaluate the dynamics at the final state of the previous step, that matches
        #   the initial state of the current step.
        #
        # NOTE: the caller's `metadata` dict is updated in place with the `dt0`
        # and `dxdt0` entries below; `init` relies on this behavior.
        metadata = kwargs.pop("metadata", {})

        integrator_init = jnp.array(
            metadata.get(self.InitializingKey, False), dtype=bool
        )

        # Close f over optional kwargs.
        f = lambda x, t: self.dynamics(x=x, t=t, **kwargs)

        # Define the final time.
        tf = t0 + dt

        # Initialize solution orders.
        p = self.order_of_solution
        p_estimate = self.order_of_solution_estimate
        q = jnp.minimum(p, p_estimate)

        def reuse_cached(cached: dict[str, Any]) -> tuple[Any, Any]:
            # Reuse the step size and derivative from the previous call.
            # The dynamics are traced only if no derivative is cached.
            dt0 = cached.get("dt0", jnp.array(0.0, dtype=float))
            dxdt0 = cached["dxdt0"] if "dxdt0" in cached else f(x0, t0)[0]
            return dt0, dxdt0

        # The value of dt0 is NaN (or, at least, it should be) only after initialization
        # and before the first step. In that case, a new dt0 is estimated.
        metadata["dt0"], metadata["dxdt0"] = jax.lax.cond(
            pred=("dt0" in metadata) & ~jnp.isnan(metadata.get("dt0", 0.0)).any(),
            true_fun=reuse_cached,
            false_fun=lambda _: estimate_step_size(
                x0=x0, t0=t0, f=f, order=p, atol=self.atol, rtol=self.rtol
            ),
            operand=metadata,
        )

        # Bound the initial step size from above. This is exactly what
        # `jnp.clip(dt0, min(dt_min, dt0), min(dt_max, dt0))` evaluates to.
        # NOTE(review): `dt_min` is deliberately not applied to this initial
        # guess (only to the step sizes proposed below) — confirm this is intended.
        metadata["dt0"] = jnp.minimum(self.dt_max, metadata["dt0"])

        # Parameters of the step-size update, constant across sub-steps.
        # The Δt is shrunk every time by the safety factor (even when accepted),
        # and the β parameters bound the timestep update factor.
        safety = jnp.clip(self.safety, 0.0, 1.0)
        β_min = jnp.maximum(0.0, self.beta_min)
        β_max = jnp.maximum(β_min, self.beta_max)

        # =========================================================
        # While loop to reach tf from t0 using an adaptive timestep
        # =========================================================

        # Carry: (state, time, metadata, consecutive rejections, break flag).
        Carry = tuple[Any, ...]
        carry0: Carry = (
            x0,
            jnp.array(t0).astype(float),
            metadata,
            jnp.array(0, dtype=int),
            jnp.array(False).astype(bool),
        )

        def while_loop_cond(carry: Carry) -> jtp.Bool:
            _, _, _, _, break_loop = carry
            return jnp.logical_not(break_loop)

        # Each loop is an integration step with variable Δt.
        # Depending on the integration error, the step could be discarded and the
        # while body ran again from the same (x0, t0) but with a smaller Δt.
        # We run these loops until the final time tf is reached.
        def while_loop_body(carry: Carry) -> Carry:

            # Unpack the carry.
            x0, t0, metadata, discarded_steps, _ = carry

            # Take care of the final adaptive step.
            # We want the final Δt to let us reach tf exactly.
            # Then we can exit the while loop.
            Δt0 = metadata["dt0"]
            Δt0 = jnp.where(t0 + Δt0 < tf, Δt0, tf - t0)
            is_last_step = jnp.where(t0 + Δt0 < tf, False, True)

            # Run the underlying explicit RK integrator.
            # The output z contains multiple solutions (depending on the rows of b.T).
            with self.editable(validate=True) as integrator:
                z, aux_dict = integrator._compute_next_state(
                    x0=x0, t0=t0, dt=Δt0, **kwargs
                )
                metadata_rk = aux_dict["metadata"]

            # Extract the high-order solution xf and the low-order estimate.
            xf = jax.tree.map(lambda l: l[self.row_index_of_solution], z)
            xf_estimate = jax.tree.map(
                lambda l: l[self.row_index_of_solution_estimate], z
            )

            # Calculate the local integration error.
            local_error = local_error_estimation(
                x0=x0, xf=xf, xf_estimate=xf_estimate, rtol=self.rtol, atol=self.atol
            )

            # Compute the next Δt from the desired integration error.
            # The computed integration step is accepted if error <= 1.0,
            # otherwise it is rejected.
            #
            # In case of rejection, Δt_next is always smaller than Δt0.
            # In case of acceptance, Δt_next could either be larger than Δt0,
            # or slightly smaller than Δt0 depending on the safety factor.
            Δt_next = Δt0 * jnp.clip(
                safety * jnp.power(1 / local_error, 1 / (q + 1)),
                β_min,
                β_max,
            )

            def accept_step():
                # Use Δt_next in the next while loop.
                # If it is the last one, and Δt0 was clipped, return the initial Δt0.
                metadata_accepted = metadata_rk | dict(
                    dt0=jnp.clip(
                        jax.lax.select(
                            pred=is_last_step,
                            on_true=metadata["dt0"],
                            on_false=Δt_next,
                        ),
                        self.dt_min,
                        self.dt_max,
                    )
                )

                return (
                    # Start the next while loop from the final state.
                    xf,
                    # Advance the starting time of the next adaptive step.
                    t0 + Δt0,
                    # Signal whether the final time has been reached.
                    t0 + Δt0 >= tf,
                    metadata_accepted,
                    jnp.array(0, dtype=int),
                )

            def reject_step():
                # Get back the original metadata, this time with a reduced Δt.
                # A new dict is built rather than mutating `metadata`, which is
                # also read by `accept_step`.
                metadata_rejected = metadata | dict(
                    dt0=jnp.clip(Δt_next, self.dt_min, self.dt_max)
                )

                return (
                    x0,
                    t0,
                    False,
                    metadata_rejected,
                    discarded_steps + 1,
                )

            # Decide whether to accept or reject the step.
            (
                x_next,
                t_next,
                done,
                metadata_out,
                rejections,
            ) = jax.lax.cond(
                pred=(discarded_steps >= self.max_step_rejections)
                | (local_error <= 1.0)
                | (Δt_next < self.dt_min)
                | integrator_init,
                true_fun=accept_step,
                false_fun=reject_step,
            )

            return (
                x_next,
                t_next,
                metadata_out,
                rejections,
                done,
            )

        # Integrate with adaptive step until tf is reached.
        (
            xf,
            tf,
            metadata_tf,
            _,
            _,
        ) = jax.lax.while_loop(
            cond_fun=while_loop_cond,
            body_fun=while_loop_body,
            init_val=carry0,
        )

        return xf, {"metadata": metadata_tf}

    @property
    def order_of_solution(self) -> int:
        """
        The order of the solution.
        """
        return self.order_of_bT_rows[self.row_index_of_solution]

    @property
    def order_of_solution_estimate(self) -> int:
        """
        The order of the solution estimate.
        """
        return self.order_of_bT_rows[self.row_index_of_solution_estimate]

    @classmethod
    def build(
        cls: type[Self],
        *,
        dynamics: SystemDynamics[State, StateDerivative],
        fsal_enabled_if_supported: jtp.BoolLike = True,
        dt_max: jtp.FloatLike = jnp.inf,
        dt_min: jtp.FloatLike = -jnp.inf,
        rtol: jtp.FloatLike = RTOL_DEFAULT,
        atol: jtp.FloatLike = ATOL_DEFAULT,
        safety: jtp.FloatLike = SAFETY_DEFAULT,
        beta_max: jtp.FloatLike = BETA_MAX_DEFAULT,
        beta_min: jtp.FloatLike = BETA_MIN_DEFAULT,
        max_step_rejections: jtp.IntLike = MAX_STEP_REJECTIONS_DEFAULT,
        **kwargs,
    ) -> Self:
        """
        Build an Embedded Runge-Kutta integrator.

        Args:
            dynamics: The system dynamics function.
            fsal_enabled_if_supported:
                Whether to enable the FSAL property if supported by the integrator.
            dt_max: The maximum step size.
            dt_min: The minimum step size.
            rtol: The relative tolerance.
            atol: The absolute tolerance.
            safety: The safety factor to shrink the step size.
            beta_max: The upper bound of the step-size update factor.
            beta_min:
                The lower bound of the step-size update factor, i.e. the
                strongest allowed reduction of the step size.
            max_step_rejections: The maximum number of step rejections.
            **kwargs: Additional parameters.

        Returns:
            The built integrator.

        Raises:
            ValueError:
                If `row_index_of_solution_estimate` does not index a row of `b.T`.
        """

        b = cls.__dataclass_fields__["b"].default_factory()

        # Check that b.T has enough rows based on the configured index of the
        # solution estimate. This is necessary for embedded methods.
        if (
            cls.row_index_of_solution_estimate is not None
            and cls.row_index_of_solution_estimate >= b.T.shape[0]
        ):
            msg = "The index of the solution estimate ({}-th row of `b.T`) "
            msg += "is out of range ({})."
            raise ValueError(
                msg.format(cls.row_index_of_solution_estimate, b.T.shape[0])
            )

        integrator = super().build(
            # Integrator:
            dynamics=dynamics,
            # ExplicitRungeKutta:
            fsal_enabled_if_supported=bool(fsal_enabled_if_supported),
            # EmbeddedRungeKutta:
            dt_max=float(dt_max),
            dt_min=float(dt_min),
            rtol=float(rtol),
            atol=float(atol),
            safety=float(safety),
            beta_max=float(beta_max),
            beta_min=float(beta_min),
            max_step_rejections=int(max_step_rejections),
            **kwargs,
        )

        return integrator
612
-
613
-
614
@jax_dataclasses.pytree_dataclass
class HeunEulerSO3(EmbeddedRungeKutta[PyTreeType], ExplicitRungeKuttaSO3Mixin):
    """
    Heun-Euler embedded integrator of order 2(1) for SO(3) dynamics.

    The second-order Heun solution (row 0 of ``b.T``) is the propagated
    state. The first-order explicit Euler solution (row 1 of ``b.T``) is only
    used to estimate the local integration error.
    """

    # Runge-Kutta matrix: stage i combines the derivatives of stages j < i.
    A: jtp.Matrix = dataclasses.field(
        default_factory=lambda: jnp.array([[0, 0], [1, 0]], dtype=float),
        compare=False,
    )

    # Weights stored column-wise: each row of b.T yields one solution.
    b: jtp.Matrix = dataclasses.field(
        default_factory=lambda: jnp.array(
            [
                [1 / 2, 1 / 2],  # 2nd-order solution (Heun)
                [1, 0],  # 1st-order estimate (explicit Euler)
            ],
            dtype=float,
        ).T,
        compare=False,
    )

    # Nodes: fractions of Δt at which each stage is evaluated.
    c: jtp.Vector = dataclasses.field(
        default_factory=lambda: jnp.array([0, 1], dtype=float),
        compare=False,
    )

    row_index_of_solution: ClassVar[int] = 0
    row_index_of_solution_estimate: ClassVar[int | None] = 1

    order_of_bT_rows: ClassVar[tuple[int, ...]] = (2, 1)

    index_of_fsal: jtp.IntLike | None = None
    fsal_enabled_if_supported: bool = False
660
-
661
-
662
@jax_dataclasses.pytree_dataclass
class BogackiShampineSO3(EmbeddedRungeKutta[PyTreeType], ExplicitRungeKuttaSO3Mixin):
    """
    Bogacki-Shampine embedded integrator of order 3(2) for SO(3) dynamics.

    The third-order solution (row 0 of ``b.T``) is the propagated state. The
    second-order solution (row 1 of ``b.T``) is only used to estimate the
    local integration error.
    """

    # Runge-Kutta matrix: stage i combines the derivatives of stages j < i.
    A: jtp.Matrix = dataclasses.field(
        default_factory=lambda: jnp.array(
            [
                [0, 0, 0, 0],
                [1 / 2, 0, 0, 0],
                [0, 3 / 4, 0, 0],
                [2 / 9, 1 / 3, 4 / 9, 0],
            ],
            dtype=float,
        ),
        compare=False,
    )

    # Weights stored column-wise: each row of b.T yields one solution.
    b: jtp.Matrix = dataclasses.field(
        default_factory=lambda: jnp.array(
            [
                [2 / 9, 1 / 3, 4 / 9, 0],  # 3rd-order solution
                [7 / 24, 1 / 4, 1 / 3, 1 / 8],  # 2nd-order estimate
            ],
            dtype=float,
        ).T,
        compare=False,
    )

    # Nodes: fractions of Δt at which each stage is evaluated.
    c: jtp.Vector = dataclasses.field(
        default_factory=lambda: jnp.array([0, 1 / 2, 3 / 4, 1], dtype=float),
        compare=False,
    )

    row_index_of_solution: ClassVar[int] = 0
    row_index_of_solution_estimate: ClassVar[int | None] = 1

    order_of_bT_rows: ClassVar[tuple[int, ...]] = (3, 2)