jaxsim 0.6.1.dev13__py3-none-any.whl → 0.6.2.dev102__py3-none-any.whl
This diff compares the content of two publicly released versions of the package. It is provided for informational purposes only and reflects the changes between the versions as they appear in their respective public registries.
- jaxsim/__init__.py +1 -1
- jaxsim/_version.py +2 -2
- jaxsim/api/__init__.py +3 -1
- jaxsim/api/actuation_model.py +96 -0
- jaxsim/api/com.py +8 -8
- jaxsim/api/contact.py +15 -255
- jaxsim/api/contact_model.py +101 -0
- jaxsim/api/data.py +258 -556
- jaxsim/api/frame.py +7 -7
- jaxsim/api/integrators.py +76 -0
- jaxsim/api/kin_dyn_parameters.py +41 -58
- jaxsim/api/link.py +7 -7
- jaxsim/api/model.py +190 -453
- jaxsim/api/ode.py +34 -338
- jaxsim/api/references.py +2 -2
- jaxsim/exceptions.py +2 -2
- jaxsim/math/__init__.py +4 -3
- jaxsim/math/joint_model.py +17 -107
- jaxsim/mujoco/model.py +1 -1
- jaxsim/mujoco/utils.py +2 -2
- jaxsim/parsers/kinematic_graph.py +1 -3
- jaxsim/rbda/aba.py +7 -4
- jaxsim/rbda/collidable_points.py +7 -98
- jaxsim/rbda/contacts/__init__.py +2 -10
- jaxsim/rbda/contacts/common.py +0 -138
- jaxsim/rbda/contacts/relaxed_rigid.py +154 -9
- jaxsim/rbda/crba.py +5 -2
- jaxsim/rbda/forward_kinematics.py +37 -12
- jaxsim/rbda/jacobian.py +15 -6
- jaxsim/rbda/rnea.py +7 -4
- jaxsim/rbda/utils.py +3 -3
- jaxsim/utils/jaxsim_dataclass.py +5 -1
- {jaxsim-0.6.1.dev13.dist-info → jaxsim-0.6.2.dev102.dist-info}/METADATA +7 -9
- jaxsim-0.6.2.dev102.dist-info/RECORD +69 -0
- jaxsim/api/ode_data.py +0 -401
- jaxsim/integrators/__init__.py +0 -2
- jaxsim/integrators/common.py +0 -592
- jaxsim/integrators/fixed_step.py +0 -153
- jaxsim/integrators/variable_step.py +0 -706
- jaxsim/rbda/contacts/rigid.py +0 -462
- jaxsim/rbda/contacts/soft.py +0 -480
- jaxsim/rbda/contacts/visco_elastic.py +0 -1066
- jaxsim-0.6.1.dev13.dist-info/RECORD +0 -74
- {jaxsim-0.6.1.dev13.dist-info → jaxsim-0.6.2.dev102.dist-info}/LICENSE +0 -0
- {jaxsim-0.6.1.dev13.dist-info → jaxsim-0.6.2.dev102.dist-info}/WHEEL +0 -0
- {jaxsim-0.6.1.dev13.dist-info → jaxsim-0.6.2.dev102.dist-info}/top_level.txt +0 -0
```diff
--- a/jaxsim/integrators/variable_step.py
+++ /dev/null
@@ -1,706 +0,0 @@
-import dataclasses
-import functools
-from typing import Any, ClassVar, Generic
-
-try:
-    from typing import Self
-except ImportError:
-    from typing_extensions import Self
-
-import jax
-import jax.flatten_util
-import jax.numpy as jnp
-import jax_dataclasses
-from jax_dataclasses import Static
-
-import jaxsim.utils.tracing
-from jaxsim import typing as jtp
-
-from .common import (
-    ExplicitRungeKutta,
-    ExplicitRungeKuttaSO3Mixin,
-    NextState,
-    PyTreeType,
-    State,
-    StateDerivative,
-    SystemDynamics,
-    Time,
-    TimeStep,
-)
-
-# For robot dynamics, the following default tolerances are already pretty accurate.
-# Users can either decrease them and pay the price of smaller Δt, or increase
-# them and pay the price of less accurate dynamics.
-RTOL_DEFAULT = 0.000_100  # 0.01%
-ATOL_DEFAULT = 0.000_010  # 10μ
-
-# Default parameters of Embedded RK schemes.
-SAFETY_DEFAULT = 0.9
-BETA_MIN_DEFAULT = 1.0 / 10
-BETA_MAX_DEFAULT = 2.5
-MAX_STEP_REJECTIONS_DEFAULT = 5
-
-
-# =================
-# Utility functions
-# =================
-
-
-@functools.partial(jax.jit, static_argnames=["f"])
-def estimate_step_size(
-    x0: jtp.PyTree,
-    t0: Time,
-    f: SystemDynamics,
-    order: jtp.IntLike,
-    rtol: jtp.FloatLike = RTOL_DEFAULT,
-    atol: jtp.FloatLike = ATOL_DEFAULT,
-) -> tuple[jtp.Float, jtp.PyTree]:
-    r"""
-    Compute the initial step size to warm-start variable-step integrators.
-
-    Args:
-        x0: The initial state.
-        t0: The initial time.
-        f: The state derivative function :math:`f(x, t)`.
-        order:
-            The order :math:`p` of an integrator with truncation error
-            :math:`\mathcal{O}(\Delta t^{p+1})`.
-        rtol: The relative tolerance to scale the state.
-        atol: The absolute tolerance to scale the state.
-
-    Returns:
-        A tuple containing the computed initial step size
-        and the state derivative :math:`\dot{x} = f(x_0, t_0)`.
-
-    Note:
-        Interested readers could find implementation details in:
-
-        Solving Ordinary Differential Equations I: Nonstiff Problems, Sec. II.4.
-        E. Hairer, S. P. Norsett G. Wanner.
-    """
-
-    # Helper to flatten a pytree to a 1D vector.
-    def flatten(pytree) -> jax.Array:
-        return jax.flatten_util.ravel_pytree(pytree=pytree)[0]
-
-    # Compute the state derivative at the initial state.
-    ẋ0 = f(x0, t0)[0]
-
-    # Compute the scaling factors of the initial state and its derivative.
-    compute_scale = lambda x: atol + jnp.abs(x) * rtol
-    scale0 = jax.tree.map(compute_scale, x0)
-    scale1 = jax.tree.map(compute_scale, ẋ0)
-
-    # Scale the initial state and its derivative.
-    scale_pytree = lambda x, scale: jnp.abs(x) / scale
-    x0_scaled = jax.tree.map(scale_pytree, x0, scale0)
-    ẋ0_scaled = jax.tree.map(scale_pytree, ẋ0, scale1)
-
-    # Get the maximum of the scaled pytrees.
-    d0 = jnp.linalg.norm(flatten(x0_scaled), ord=jnp.inf)
-    d1 = jnp.linalg.norm(flatten(ẋ0_scaled), ord=jnp.inf)
-
-    # Compute the first guess of the initial step size.
-    h0 = jnp.where(jnp.minimum(d0, d1) <= 1e-5, 1e-6, 0.01 * d0 / d1)
-
-    # Compute the next state (explicit Euler step) and its derivative.
-    x1 = jax.tree.map(lambda x0, ẋ0: x0 + h0 * ẋ0, x0, ẋ0)
-    ẋ1 = f(x1, t0 + h0)[0]
-
-    # Compute the scaling factor of the state derivatives.
-    compute_scale_2 = lambda ẋ0, ẋ1: atol + jnp.maximum(jnp.abs(ẋ0), jnp.abs(ẋ1)) * rtol
-    scale2 = jax.tree.map(compute_scale_2, ẋ0, ẋ1)
-
-    # Scale the difference of the state derivatives.
-    scale_ẋ_difference = lambda ẋ0, ẋ1, scale: jnp.abs((ẋ0 - ẋ1) / scale)
-    ẋ_difference_scaled = jax.tree.map(scale_ẋ_difference, ẋ0, ẋ1, scale2)
-
-    # Get the maximum of the scaled derivatives difference.
-    d2 = jnp.linalg.norm(flatten(ẋ_difference_scaled), ord=jnp.inf) / h0
-
-    # Compute the second guess of the initial step size.
-    h1 = jnp.where(
-        jnp.maximum(d1, d2) <= 1e-15,
-        jnp.maximum(1e-6, h0 * 1e-3),
-        (0.01 / jnp.maximum(d1, d2)) ** (1.0 / (order + 1.0)),
-    )
-
-    # Propose the final guess of the initial step size.
-    # Also return the state derivative computed at the initial state since
-    # likely it is a quantity that needs to be computed again later.
-    return jnp.array(jnp.minimum(100.0 * h0, h1), dtype=float), ẋ0
-
-
-@jax.jit
-def compute_pytree_scale(
-    x1: jtp.PyTree,
-    x2: jtp.PyTree | None = None,
-    rtol: jtp.FloatLike = RTOL_DEFAULT,
-    atol: jtp.FloatLike = ATOL_DEFAULT,
-) -> jtp.PyTree:
-    """
-    Compute the component-wise state scale factors to scale dynamical states.
-
-    Args:
-        x1: The first state (often the initial state).
-        x2: The optional second state (often the final state).
-        rtol: The relative tolerance to scale the state.
-        atol: The absolute tolerance to scale the state.
-
-    Returns:
-        A pytree with the same structure of the state containing the scaling factors.
-    """
-
-    # Consider a zero second pytree, if not given.
-    x2 = jax.tree.map(jnp.zeros_like, x1) if x2 is None else x2
-
-    # Compute the scaling factors of the initial state and its derivative.
-    compute_scale = lambda l1, l2: atol + jnp.maximum(jnp.abs(l1), jnp.abs(l2)) * rtol
-    scale = jax.tree.map(compute_scale, x1, x2)
-
-    return scale
-
-
-@jax.jit
-def local_error_estimation(
-    xf: jtp.PyTree,
-    xf_estimate: jtp.PyTree | None = None,
-    x0: jtp.PyTree | None = None,
-    rtol: jtp.FloatLike = RTOL_DEFAULT,
-    atol: jtp.FloatLike = ATOL_DEFAULT,
-    norm_ord: jtp.IntLike | jtp.FloatLike = jnp.inf,
-) -> jtp.Float:
-    """
-    Estimate the local integration error, often used in Embedded RK schemes.
-
-    Args:
-        xf: The final state, often computed with the most accurate integrator.
-        xf_estimate:
-            The estimated final state, often computed with the less accurate integrator.
-            If missing, it is initialized to zero.
-        x0:
-            The initial state to compute the scaling factors. If missing, it is
-            initialized to zero.
-        rtol: The relative tolerance to scale the state.
-        atol: The absolute tolerance to scale the state.
-        norm_ord:
-            The norm to use to compute the error. Default is the infinity norm.
-
-    Returns:
-        The estimated local integration error.
-    """
-
-    # Helper to flatten a pytree to a 1D vector.
-    def flatten(pytree) -> jax.Array:
-        return jax.flatten_util.ravel_pytree(pytree=pytree)[0]
-
-    # Compute the scale considering the initial and final states.
-    scale = compute_pytree_scale(x1=xf, x2=x0, rtol=rtol, atol=atol)
-
-    # Consider a zero estimated final state, if not given.
-    xf_estimate = (
-        jax.tree.map(jnp.zeros_like, xf) if xf_estimate is None else xf_estimate
-    )
-
-    # Estimate the error.
-    estimate_error = lambda l, l̂, sc: jnp.abs(l - l̂) / sc
-    error_estimate = jax.tree.map(estimate_error, xf, xf_estimate, scale)
-
-    # Return the highest element of the error estimate.
-    return jnp.linalg.norm(flatten(error_estimate), ord=norm_ord)
-
-
-# ================================
-# Embedded Runge-Kutta integrators
-# ================================
-
-
-@jax_dataclasses.pytree_dataclass
-class EmbeddedRungeKutta(ExplicitRungeKutta[PyTreeType], Generic[PyTreeType]):
-    """
-    An Embedded Runge-Kutta integrator.
-
-    This class implements a general-purpose Embedded Runge-Kutta integrator
-    that can be used to solve ordinary differential equations with adaptive
-    step sizes.
-
-    The integrator is based on an Explicit Runge-Kutta method, and it uses
-    two different solutions to estimate the local integration error. The
-    error is then used to adapt the step size to reach a desired accuracy.
-    """
-
-    AfterInitKey: ClassVar[str] = "after_init"
-    InitializingKey: ClassVar[str] = "initializing"
-
-    # Define the row of the integration output corresponding to the solution estimate.
-    # This is the row of b.T that produces the state used e.g. by embedded methods to
-    # implement the adaptive timestep logic.
-    row_index_of_solution_estimate: ClassVar[int | None] = None
-
-    # Bounds of the adaptive Δt.
-    dt_max: Static[jtp.FloatLike] = jnp.inf
-    dt_min: Static[jtp.FloatLike] = -jnp.inf
-
-    # Tolerances used to scale the two states corresponding to the high-order solution
-    # and the low-order estimate during the computation of the local integration error.
-    rtol: Static[jtp.FloatLike] = RTOL_DEFAULT
-    atol: Static[jtp.FloatLike] = ATOL_DEFAULT
-
-    # Parameters of the adaptive timestep logic.
-    # Refer to Eq. (4.13) pag. 168 of Hairer93.
-    safety: Static[jtp.FloatLike] = SAFETY_DEFAULT
-    beta_max: Static[jtp.FloatLike] = BETA_MAX_DEFAULT
-    beta_min: Static[jtp.FloatLike] = BETA_MIN_DEFAULT
-
-    # Maximum number of rejected steps when the Δt needs to be reduced.
-    max_step_rejections: Static[jtp.IntLike] = MAX_STEP_REJECTIONS_DEFAULT
-
-    index_of_fsal: jtp.IntLike | None = None
-    fsal_enabled_if_supported: bool = False
-
-    def init(
-        self,
-        x0: State,
-        t0: Time,
-        dt: TimeStep,
-        **kwargs,
-    ) -> dict[str, Any]:
-        """
-        Initialize the integrator and get the metadata.
-
-        Args:
-            x0: The initial state of the system.
-            t0: The initial time of the system.
-            dt: The time step of the integration.
-            **kwargs: Additional parameters.
-
-        Returns:
-            The metadata of the integrator to be passed to the first step.
-        """
-
-        if jaxsim.utils.tracing(var=jnp.zeros(0)):
-            raise RuntimeError("This method cannot be used within a JIT context")
-
-        with self.editable(validate=False) as integrator:
-
-            # Inject this key to signal that the integrator is initializing.
-            # This is used to allocate the arrays of the metadata dictionary,
-            # that are then filled with NaNs.
-            metadata = {EmbeddedRungeKutta.InitializingKey: jnp.array(True)}
-
-            # Run a dummy call of the integrator.
-            # It is used only to get the metadata so that we know the structure
-            # of the corresponding pytree.
-            _ = integrator(
-                x0,
-                jnp.array(t0, dtype=float),
-                jnp.array(dt, dtype=float),
-                **(kwargs | {"metadata": metadata}),
-            )
-
-        # Remove the injected key.
-        _ = metadata.pop(EmbeddedRungeKutta.InitializingKey)
-
-        # Make sure that all leafs of the dictionary are JAX arrays.
-        # Also, since these are dummy parameters, set them all to NaN.
-        metadata_after_init = jax.tree.map(
-            lambda l: jnp.nan * jnp.zeros_like(l), metadata
-        )
-
-        return metadata_after_init
-
-    def __call__(
-        self, x0: State, t0: Time, dt: TimeStep, **kwargs
-    ) -> tuple[NextState, dict[str, Any]]:
-        """
-        Integrate the system for a single step.
-        """
-
-        # This method is called differently in three stages:
-        #
-        # 1. During initialization, to allocate a dummy metadata dictionary.
-        #    The metadata is a dictionary of float JAX arrays, that are initialized
-        #    with the right shape and filled with NaNs.
-        # 2. During the first step, this method operates on the Nan-filled
-        #    `metadata` argument, and it populates with the actual metadata.
-        # 3. After the first step, this method operates on the actual metadata.
-        #
-        # In particular, we store the following information in the metadata:
-        # - The first attempt of the step size, `dt0`. This is either estimated during
-        #   phase 2, or taken from the previous step during phase 3.
-        # - For integrators that support FSAL, the derivative at the initial state
-        #   computed during the previous step. This can be done because FSAL integrators
-        #   evaluate the dynamics at the final state of the previous step, that matches
-        #   the initial state of the current step.
-        #
-        metadata = kwargs.pop("metadata", {})
-
-        integrator_init = jnp.array(
-            metadata.get(self.InitializingKey, False), dtype=bool
-        )
-
-        # Close f over optional kwargs.
-        f = lambda x, t: self.dynamics(x=x, t=t, **kwargs)
-
-        # Define the final time.
-        tf = t0 + dt
-
-        # Initialize solution orders.
-        p = self.order_of_solution
-        p̂ = self.order_of_solution_estimate
-        q = jnp.minimum(p, p̂)
-
-        # The value of dt0 is NaN (or, at least, it should be) only after initialization
-        # and before the first step.
-        metadata["dt0"], metadata["dxdt0"] = jax.lax.cond(
-            pred=("dt0" in metadata) & ~jnp.isnan(metadata.get("dt0", 0.0)).any(),
-            true_fun=lambda metadata: (
-                metadata.get("dt0", jnp.array(0.0, dtype=float)),
-                metadata.get("dxdt0", f(x0, t0)[0]),
-            ),
-            false_fun=lambda aux: estimate_step_size(
-                x0=x0, t0=t0, f=f, order=p, atol=self.atol, rtol=self.rtol
-            ),
-            operand=metadata,
-        )
-
-        # Clip the estimated initial step size to the given bounds, if necessary.
-        metadata["dt0"] = jnp.clip(
-            metadata["dt0"],
-            jnp.minimum(self.dt_min, metadata["dt0"]),
-            jnp.minimum(self.dt_max, metadata["dt0"]),
-        )
-
-        # =========================================================
-        # While loop to reach tf from t0 using an adaptive timestep
-        # =========================================================
-
-        # Initialize the carry of the while loop.
-        Carry = tuple[Any, ...]
-        carry0: Carry = (
-            x0,
-            jnp.array(t0).astype(float),
-            metadata,
-            jnp.array(0, dtype=int),
-            jnp.array(False).astype(bool),
-        )
-
-        def while_loop_cond(carry: Carry) -> jtp.Bool:
-            _, _, _, _, break_loop = carry
-            return jnp.logical_not(break_loop)
-
-        # Each loop is an integration step with variable Δt.
-        # Depending on the integration error, the step could be discarded and the
-        # while body ran again from the same (x0, t0) but with a smaller Δt.
-        # We run these loops until the final time tf is reached.
-        def while_loop_body(carry: Carry) -> Carry:
-
-            # Unpack the carry.
-            x0, t0, metadata, discarded_steps, _ = carry
-
-            # Take care of the final adaptive step.
-            # We want the final Δt to let us reach tf exactly.
-            # Then we can exit the while loop.
-            Δt0 = metadata["dt0"]
-            Δt0 = jnp.where(t0 + Δt0 < tf, Δt0, tf - t0)
-            break_loop = jnp.where(t0 + Δt0 < tf, False, True)
-
-            # Run the underlying explicit RK integrator.
-            # The output z contains multiple solutions (depending on the rows of b.T).
-            with self.editable(validate=True) as integrator:
-                z, aux_dict = integrator._compute_next_state(
-                    x0=x0, t0=t0, dt=Δt0, **kwargs
-                )
-                metadata_next = aux_dict["metadata"]
-
-            # Extract the high-order solution xf and the low-order estimate x̂f.
-            xf = jax.tree.map(lambda l: l[self.row_index_of_solution], z)
-            x̂f = jax.tree.map(lambda l: l[self.row_index_of_solution_estimate], z)
-
-            # Calculate the local integration error.
-            local_error = local_error_estimation(
-                x0=x0, xf=xf, xf_estimate=x̂f, rtol=self.rtol, atol=self.atol
-            )
-
-            # Shrink the Δt every time by the safety factor (even when accepted).
-            # The β parameters define the bounds of the timestep update factor.
-            safety = jnp.clip(self.safety, 0.0, 1.0)
-            β_min = jnp.maximum(0.0, self.beta_min)
-            β_max = jnp.maximum(β_min, self.beta_max)
-
-            # Compute the next Δt from the desired integration error.
-            # The computed integration step is accepted if error <= 1.0,
-            # otherwise it is rejected.
-            #
-            # In case of rejection, Δt_next is always smaller than Δt0.
-            # In case of acceptance, Δt_next could either be larger than Δt0,
-            # or slightly smaller than Δt0 depending on the safety factor.
-            Δt_next = Δt0 * jnp.clip(
-                safety * jnp.power(1 / local_error, 1 / (q + 1)),
-                β_min,
-                β_max,
-            )
-
-            def accept_step():
-                # Use Δt_next in the next while loop.
-                # If it is the last one, and Δt0 was clipped, return the initial Δt0.
-                metadata_next_accepted = metadata_next | dict(
-                    dt0=jnp.clip(
-                        jax.lax.select(
-                            pred=break_loop,
-                            on_true=metadata["dt0"],
-                            on_false=Δt_next,
-                        ),
-                        self.dt_min,
-                        self.dt_max,
-                    )
-                )
-
-                # Start the next while loop from the final state.
-                x0_next = xf
-
-                # Advance the starting time of the next adaptive step.
-                t0_next = t0 + Δt0
-
-                # Signal that the final time has been reached.
-                break_loop_next = t0 + Δt0 >= tf
-
-                return (
-                    x0_next,
-                    t0_next,
-                    break_loop_next,
-                    metadata_next_accepted,
-                    jnp.array(0, dtype=int),
-                )
-
-            def reject_step():
-                # Get back the original metadata.
-                metadata_next_rejected = metadata
-
-                # This time, with a reduced Δt.
-                metadata_next_rejected["dt0"] = jnp.clip(
-                    Δt_next, self.dt_min, self.dt_max
-                )
-
-                return (
-                    x0,
-                    t0,
-                    False,
-                    metadata_next_rejected,
-                    discarded_steps + 1,
-                )
-
-            # Decide whether to accept or reject the step.
-            (
-                x0_next,
-                t0_next,
-                break_loop,
-                metadata_next,
-                discarded_steps,
-            ) = jax.lax.cond(
-                pred=(discarded_steps >= self.max_step_rejections)
-                | (local_error <= 1.0)
-                | (Δt_next < self.dt_min)
-                | integrator_init,
-                true_fun=accept_step,
-                false_fun=reject_step,
-            )
-
-            return (
-                x0_next,
-                t0_next,
-                metadata_next,
-                discarded_steps,
-                break_loop,
-            )
-
-        # Integrate with adaptive step until tf is reached.
-        (
-            xf,
-            tf,
-            metadata_tf,
-            _,
-            _,
-        ) = jax.lax.while_loop(
-            cond_fun=while_loop_cond,
-            body_fun=while_loop_body,
-            init_val=carry0,
-        )
-
-        return xf, {"metadata": metadata_tf}
-
-    @property
-    def order_of_solution(self) -> int:
-        """
-        The order of the solution.
-        """
-        return self.order_of_bT_rows[self.row_index_of_solution]
-
-    @property
-    def order_of_solution_estimate(self) -> int:
-        """
-        The order of the solution estimate.
-        """
-        return self.order_of_bT_rows[self.row_index_of_solution_estimate]
-
-    @classmethod
-    def build(
-        cls: type[Self],
-        *,
-        dynamics: SystemDynamics[State, StateDerivative],
-        fsal_enabled_if_supported: jtp.BoolLike = True,
-        dt_max: jtp.FloatLike = jnp.inf,
-        dt_min: jtp.FloatLike = -jnp.inf,
-        rtol: jtp.FloatLike = RTOL_DEFAULT,
-        atol: jtp.FloatLike = ATOL_DEFAULT,
-        safety: jtp.FloatLike = SAFETY_DEFAULT,
-        beta_max: jtp.FloatLike = BETA_MAX_DEFAULT,
-        beta_min: jtp.FloatLike = BETA_MIN_DEFAULT,
-        max_step_rejections: jtp.IntLike = MAX_STEP_REJECTIONS_DEFAULT,
-        **kwargs,
-    ) -> Self:
-        """
-        Build an Embedded Runge-Kutta integrator.
-
-        Args:
-            dynamics: The system dynamics function.
-            fsal_enabled_if_supported:
-                Whether to enable the FSAL property if supported by the integrator.
-            dt_max: The maximum step size.
-            dt_min: The minimum step size.
-            rtol: The relative tolerance.
-            atol: The absolute tolerance.
-            safety: The safety factor to shrink the step size.
-            beta_max: The maximum factor to increase the step size.
-            beta_min: The minimum factor to increase the step size.
-            max_step_rejections: The maximum number of step rejections.
-            **kwargs: Additional parameters.
-        """
-
-        b = cls.__dataclass_fields__["b"].default_factory()
-
-        # Check that b.T has enough rows based on the configured index of the
-        # solution estimate. This is necessary for embedded methods.
-        if (
-            cls.row_index_of_solution_estimate is not None
-            and cls.row_index_of_solution_estimate >= b.T.shape[0]
-        ):
-            msg = "The index of the solution estimate ({}-th row of `b.T`) "
-            msg += "is out of range ({})."
-            raise ValueError(
-                msg.format(cls.row_index_of_solution_estimate, b.T.shape[0])
-            )
-
-        integrator = super().build(
-            # Integrator:
-            dynamics=dynamics,
-            # ExplicitRungeKutta:
-            fsal_enabled_if_supported=bool(fsal_enabled_if_supported),
-            # EmbeddedRungeKutta:
-            dt_max=float(dt_max),
-            dt_min=float(dt_min),
-            rtol=float(rtol),
-            atol=float(atol),
-            safety=float(safety),
-            beta_max=float(beta_max),
-            beta_min=float(beta_min),
-            max_step_rejections=int(max_step_rejections),
-            **kwargs,
-        )
-
-        return integrator
-
-
-@jax_dataclasses.pytree_dataclass
-class HeunEulerSO3(EmbeddedRungeKutta[PyTreeType], ExplicitRungeKuttaSO3Mixin):
-    """
-    The Heun-Euler integrator for SO(3) dynamics.
-    """
-
-    A: jtp.Matrix = dataclasses.field(
-        default_factory=lambda: jnp.array(
-            [
-                [0, 0],
-                [1, 0],
-            ]
-        ).astype(float),
-        compare=False,
-    )
-
-    b: jtp.Matrix = dataclasses.field(
-        default_factory=lambda: (
-            jnp.atleast_2d(
-                jnp.array(
-                    [
-                        [1 / 2, 1 / 2],
-                        [1, 0],
-                    ]
-                ),
-            )
-            .astype(float)
-            .transpose()
-        ),
-        compare=False,
-    )
-
-    c: jtp.Vector = dataclasses.field(
-        default_factory=lambda: jnp.array(
-            [0, 1],
-        ).astype(float),
-        compare=False,
-    )
-
-    row_index_of_solution: ClassVar[int] = 0
-    row_index_of_solution_estimate: ClassVar[int | None] = 1
-
-    order_of_bT_rows: ClassVar[tuple[int, ...]] = (2, 1)
-
-    index_of_fsal: jtp.IntLike | None = None
-    fsal_enabled_if_supported: bool = False
-
-
-@jax_dataclasses.pytree_dataclass
-class BogackiShampineSO3(EmbeddedRungeKutta[PyTreeType], ExplicitRungeKuttaSO3Mixin):
-    """
-    The Bogacki-Shampine integrator for SO(3) dynamics.
-    """
-
-    A: jtp.Matrix = dataclasses.field(
-        default_factory=lambda: jnp.array(
-            [
-                [0, 0, 0, 0],
-                [1 / 2, 0, 0, 0],
-                [0, 3 / 4, 0, 0],
-                [2 / 9, 1 / 3, 4 / 9, 0],
-            ]
-        ).astype(float),
-        compare=False,
-    )
-
-    b: jtp.Matrix = dataclasses.field(
-        default_factory=lambda: (
-            jnp.atleast_2d(
-                jnp.array(
-                    [
-                        [2 / 9, 1 / 3, 4 / 9, 0],
-                        [7 / 24, 1 / 4, 1 / 3, 1 / 8],
-                    ]
-                ),
-            )
-            .astype(float)
-            .transpose()
-        ),
-        compare=False,
-    )
-
-    c: jtp.Vector = dataclasses.field(
-        default_factory=lambda: jnp.array(
-            [0, 1 / 2, 3 / 4, 1],
-        ).astype(float),
-        compare=False,
-    )
-
-    row_index_of_solution: ClassVar[int] = 0
-    row_index_of_solution_estimate: ClassVar[int | None] = 1
-
-    order_of_bT_rows: ClassVar[tuple[int, ...]] = (3, 2)
```