jaxsim 0.6.1.dev13__py3-none-any.whl → 0.6.2.dev102__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jaxsim/__init__.py +1 -1
- jaxsim/_version.py +2 -2
- jaxsim/api/__init__.py +3 -1
- jaxsim/api/actuation_model.py +96 -0
- jaxsim/api/com.py +8 -8
- jaxsim/api/contact.py +15 -255
- jaxsim/api/contact_model.py +101 -0
- jaxsim/api/data.py +258 -556
- jaxsim/api/frame.py +7 -7
- jaxsim/api/integrators.py +76 -0
- jaxsim/api/kin_dyn_parameters.py +41 -58
- jaxsim/api/link.py +7 -7
- jaxsim/api/model.py +190 -453
- jaxsim/api/ode.py +34 -338
- jaxsim/api/references.py +2 -2
- jaxsim/exceptions.py +2 -2
- jaxsim/math/__init__.py +4 -3
- jaxsim/math/joint_model.py +17 -107
- jaxsim/mujoco/model.py +1 -1
- jaxsim/mujoco/utils.py +2 -2
- jaxsim/parsers/kinematic_graph.py +1 -3
- jaxsim/rbda/aba.py +7 -4
- jaxsim/rbda/collidable_points.py +7 -98
- jaxsim/rbda/contacts/__init__.py +2 -10
- jaxsim/rbda/contacts/common.py +0 -138
- jaxsim/rbda/contacts/relaxed_rigid.py +154 -9
- jaxsim/rbda/crba.py +5 -2
- jaxsim/rbda/forward_kinematics.py +37 -12
- jaxsim/rbda/jacobian.py +15 -6
- jaxsim/rbda/rnea.py +7 -4
- jaxsim/rbda/utils.py +3 -3
- jaxsim/utils/jaxsim_dataclass.py +5 -1
- {jaxsim-0.6.1.dev13.dist-info → jaxsim-0.6.2.dev102.dist-info}/METADATA +7 -9
- jaxsim-0.6.2.dev102.dist-info/RECORD +69 -0
- jaxsim/api/ode_data.py +0 -401
- jaxsim/integrators/__init__.py +0 -2
- jaxsim/integrators/common.py +0 -592
- jaxsim/integrators/fixed_step.py +0 -153
- jaxsim/integrators/variable_step.py +0 -706
- jaxsim/rbda/contacts/rigid.py +0 -462
- jaxsim/rbda/contacts/soft.py +0 -480
- jaxsim/rbda/contacts/visco_elastic.py +0 -1066
- jaxsim-0.6.1.dev13.dist-info/RECORD +0 -74
- {jaxsim-0.6.1.dev13.dist-info → jaxsim-0.6.2.dev102.dist-info}/LICENSE +0 -0
- {jaxsim-0.6.1.dev13.dist-info → jaxsim-0.6.2.dev102.dist-info}/WHEEL +0 -0
- {jaxsim-0.6.1.dev13.dist-info → jaxsim-0.6.2.dev102.dist-info}/top_level.txt +0 -0
jaxsim/integrators/common.py
DELETED
@@ -1,592 +0,0 @@
|
|
1
|
-
import abc
|
2
|
-
import dataclasses
|
3
|
-
from typing import Any, ClassVar, Generic, Protocol, TypeVar
|
4
|
-
|
5
|
-
import jax
|
6
|
-
import jax.numpy as jnp
|
7
|
-
import jax_dataclasses
|
8
|
-
from jax_dataclasses import Static
|
9
|
-
|
10
|
-
import jaxsim.api as js
|
11
|
-
import jaxsim.math
|
12
|
-
import jaxsim.typing as jtp
|
13
|
-
from jaxsim import exceptions, logging
|
14
|
-
from jaxsim.utils.jaxsim_dataclass import JaxsimDataclass, Mutability
|
15
|
-
|
16
|
-
try:
|
17
|
-
from typing import override
|
18
|
-
except ImportError:
|
19
|
-
from typing_extensions import override
|
20
|
-
|
21
|
-
try:
|
22
|
-
from typing import Self
|
23
|
-
except ImportError:
|
24
|
-
from typing_extensions import Self
|
25
|
-
|
26
|
-
|
27
|
-
# =============
# Generic types
# =============

# Scalar aliases used for the integration time and the integration step.
Time = jtp.FloatLike
TimeStep = jtp.FloatLike

# Type variables describing the integrated state and its time derivative.
# `NextState` is bound to the very same type variable object as `State`:
# integrating a state produces another object of the same type.
State = TypeVar("State")
NextState = State
StateDerivative = TypeVar("StateDerivative")

# Type variable for integrators whose state and derivative share one pytree type.
PyTreeType = TypeVar("PyTreeType", bound=jtp.PyTree)
|
36
|
-
|
37
|
-
|
38
|
-
class SystemDynamics(Protocol[State, StateDerivative]):
    """
    Structural interface of the vector field advanced by the integrators.

    Any callable whose signature matches `__call__` can be used as dynamics.
    """

    def __call__(
        self, x: State, t: Time, **kwargs
    ) -> tuple[StateDerivative, dict[str, Any]]:
        """
        Evaluate the vector field at a given state and time.

        Args:
            x: The state at which the dynamics is evaluated.
            t: The time at which the dynamics is evaluated.
            **kwargs: Extra keyword arguments forwarded by the integrator.

        Returns:
            A tuple with the state derivative and an auxiliary dictionary.
        """
        ...
|
58
|
-
|
59
|
-
|
60
|
-
# =======================
|
61
|
-
# Base integrator classes
|
62
|
-
# =======================
|
63
|
-
|
64
|
-
|
65
|
-
@jax_dataclasses.pytree_dataclass
class Integrator(JaxsimDataclass, abc.ABC, Generic[State, StateDerivative]):
    """
    Abstract base class of all the integrators.

    Concrete integrators are created through the `build` factory and advance
    the state of the system by calling `step`.
    """

    # The vector field integrated by this object. It is a static, keyword-only
    # field that is excluded from repr, hashing, and comparisons.
    dynamics: Static[SystemDynamics[State, StateDerivative]] = dataclasses.field(
        repr=False, hash=False, compare=False, kw_only=True
    )

    @classmethod
    def build(
        cls: type[Self],
        *,
        dynamics: SystemDynamics[State, StateDerivative],
        **kwargs,
    ) -> Self:
        """
        Create an integrator instance.

        Args:
            dynamics: The vector field to integrate.
            **kwargs: Extra fields forwarded to the dataclass constructor.

        Returns:
            The newly created integrator.
        """

        instance = cls(dynamics=dynamics, **kwargs)
        return instance

    def step(
        self,
        x0: State,
        t0: Time,
        dt: TimeStep,
        *,
        metadata: dict[str, Any] | None = None,
        **kwargs,
    ) -> tuple[NextState, dict[str, Any]]:
        """
        Advance the system by one integration step.

        Args:
            x0: The state at the beginning of the step.
            t0: The time at the beginning of the step.
            dt: The length of the integration step.
            metadata: The state auxiliary dictionary of the integrator.
            **kwargs: Extra keyword arguments forwarded to `__call__`.

        Returns:
            The state at the end of the step, and the input `metadata` merged
            with (and overridden by) the dictionary returned by `__call__`.
        """

        if metadata is None:
            metadata = {}

        # `__call__` runs on a mutable view of this integrator.
        # NOTE(review): `metadata` is not forwarded to `__call__`; it is only
        # merged into the returned dictionary. Confirm this is intended.
        with self.mutable_context(mutability=Mutability.MUTABLE) as mutable_self:
            next_state, step_aux = mutable_self(x0, t0, dt, **kwargs)

        merged_metadata = metadata | step_aux
        return next_state, merged_metadata

    @abc.abstractmethod
    def __call__(
        self, x0: State, t0: Time, dt: TimeStep, **kwargs
    ) -> tuple[NextState, dict[str, Any]]:
        """
        Integrate the dynamics from `t0` to `t0 + dt`, starting from `x0`.

        Concrete integrators must implement this method.
        """
        ...

    def init(
        self,
        x0: State,
        t0: Time,
        dt: TimeStep,
        *,
        include_dynamics_aux_dict: bool = False,
        **kwargs,
    ) -> dict[str, Any]:
        """
        Deprecated no-op, kept for backward compatibility.

        It only logs a deprecation warning; all the arguments are ignored.

        Returns:
            An empty dictionary.
        """

        deprecation_message = (
            "The 'init' method has been deprecated. There is no need to call it."
        )
        logging.warning(deprecation_message)

        return {}
|
155
|
-
|
156
|
-
|
157
|
-
@jax_dataclasses.pytree_dataclass
class ExplicitRungeKutta(Integrator[PyTreeType, PyTreeType], Generic[PyTreeType]):
    """
    Base class for explicit Runge-Kutta integrators.

    Concrete subclasses provide the Butcher tableau through the `default_factory`
    of the `A`, `b`, and `c` fields. They also set the `order_of_bT_rows` and
    `row_index_of_solution` class variables. Instances should be created with
    `build`, which validates the tableau.

    Attributes:
        A: The Runge-Kutta matrix.
        b: The weights coefficients.
        c: The nodes coefficients.
        order_of_bT_rows: The order of the solution associated to each row of `b.T`.
        row_index_of_solution: The row of the integration output corresponding to the final solution.
        fsal_enabled_if_supported: Whether to enable the FSAL property, if supported.
        index_of_fsal: The index of the intermediate derivative to be used as the first derivative of the next iteration.
    """

    # The Runge-Kutta matrix.
    A: jtp.Matrix

    # The weights coefficients.
    # Note that in practice we typically use its transpose `b.transpose()`.
    b: jtp.Matrix

    # The nodes coefficients.
    c: jtp.Vector

    # Define the order of the solution.
    # It should have as many elements as the number of rows of `b.transpose()`.
    order_of_bT_rows: ClassVar[tuple[int, ...]]

    # Define the row of the integration output corresponding to the final solution.
    # This is the row of b.T that produces the final state.
    row_index_of_solution: ClassVar[int]

    # Attributes of FSAL (first-same-as-last) property.
    fsal_enabled_if_supported: Static[bool] = dataclasses.field(repr=False)
    index_of_fsal: Static[jtp.IntLike | None] = dataclasses.field(repr=False)

    @property
    def has_fsal(self) -> bool:
        """
        Check if the integrator supports the FSAL property.

        Returns:
            Truthy only if FSAL is enabled and the Butcher tableau supports it.
        """
        return self.fsal_enabled_if_supported and self.index_of_fsal is not None

    @property
    def order(self) -> int:
        """
        Return the order of the integrator.

        Returns:
            The entry of `order_of_bT_rows` selected by `row_index_of_solution`.
        """
        return self.order_of_bT_rows[self.row_index_of_solution]

    @override
    @classmethod
    def build(
        cls: type[Self],
        *,
        dynamics: SystemDynamics[State, StateDerivative],
        fsal_enabled_if_supported: jtp.BoolLike = True,
        **kwargs,
    ) -> Self:
        """
        Build the integrator object.

        Args:
            dynamics: The system dynamics.
            fsal_enabled_if_supported:
                Whether to enable the FSAL property, if supported.
            **kwargs: Additional keyword arguments to build the integrator.

        Returns:
            The integrator object.

        Raises:
            ValueError: If the Butcher tableau of the class is not valid or not
                explicit, or if it is inconsistent with `row_index_of_solution`
                or `order_of_bT_rows`.
        """

        # The Butcher tableau is provided by the default factories of the subclass.
        A = cls.__dataclass_fields__["A"].default_factory()
        b = cls.__dataclass_fields__["b"].default_factory()
        c = cls.__dataclass_fields__["c"].default_factory()

        # Check validity of the Butcher tableau.
        if not ExplicitRungeKutta.butcher_tableau_is_valid(A=A, b=b, c=c):
            raise ValueError("The Butcher tableau of this class is not valid.")

        # `_compute_next_state` evaluates the stages sequentially, and the kᵢ that
        # are not yet computed are still zero. Any coefficient on or above the
        # diagonal of A would therefore be silently ignored, so only strictly
        # lower-triangular (explicit) tableaux produce correct results.
        if not ExplicitRungeKutta.butcher_tableau_is_explicit(A=A):
            raise ValueError("The Butcher tableau of this class is not explicit.")

        # Check that b.T has enough rows based on the configured index of the solution.
        if cls.row_index_of_solution >= b.T.shape[0]:
            msg = "The index of the solution ({}-th row of `b.T`) is out of range ({})."
            raise ValueError(msg.format(cls.row_index_of_solution, b.T.shape[0]))

        # Check that the tuple containing the order of the b.T rows matches the number
        # of the b.T rows.
        if len(cls.order_of_bT_rows) != b.T.shape[0]:
            msg = "Wrong size of 'order_of_bT_rows' ({}), should be {}."
            raise ValueError(msg.format(len(cls.order_of_bT_rows), b.T.shape[0]))

        # Check if the Butcher tableau supports FSAL (first-same-as-last).
        # If it does, store the index of the intermediate derivative to be used as the
        # first derivative of the next iteration (None if FSAL is not supported).
        _, index_of_fsal = ExplicitRungeKutta.butcher_tableau_supports_fsal(
            A=A, b=b, c=c, index_of_solution=cls.row_index_of_solution
        )

        # Build the integrator object.
        integrator = super().build(
            dynamics=dynamics,
            index_of_fsal=index_of_fsal,
            fsal_enabled_if_supported=bool(fsal_enabled_if_supported),
            **kwargs,
        )

        return integrator

    def __call__(
        self, x0: State, t0: Time, dt: TimeStep, **kwargs
    ) -> tuple[NextState, dict[str, Any]]:
        """
        Perform a single integration step.

        Args:
            x0: The initial state of the system.
            t0: The initial time of the system.
            dt: The time step of the integration.
            **kwargs: Additional keyword arguments forwarded to the dynamics.

        Returns:
            The state selected by `row_index_of_solution` and the auxiliary
            dictionary produced by `_compute_next_state`.
        """

        # Here z is a batched state with as many batch elements as b.T rows.
        # Note that z has multiple batches only if b.T has more than one row,
        # e.g. in Butcher tableau of embedded schemes.
        z, aux_dict = self._compute_next_state(x0=x0, t0=t0, dt=dt, **kwargs)

        # The next state is the batch element located at the configured index of solution.
        next_state = jax.tree.map(lambda l: l[self.row_index_of_solution], z)

        return next_state, aux_dict

    @classmethod
    def integrate_rk_stage(
        cls, x0: State, t0: Time, dt: TimeStep, k: StateDerivative
    ) -> NextState:
        """
        Integrate a single stage of the Runge-Kutta method.

        Args:
            x0: The initial state of the system.
            t0: The initial time of the system.
            dt:
                The time step of the RK integration scheme. Note that this is
                not the stage timestep, as it depends on the `A` matrix used
                to compute the `k` argument.
            k:
                The RK state derivative of the current stage, weighted with
                the `A` matrix.

        Returns:
            The state at the next stage of the integration.

        Note:
            In the most generic case, `k` could be an arbitrary composition
            of the kᵢ derivatives, depending on the RK matrix A.

        Note:
            Overriding this method allows users to use different classes
            defining `State` and `StateDerivative`. Be aware that the
            timestep `dt` is not the stage timestep, therefore the map
            used to convert the state derivative must be time-independent.
        """

        op = lambda x0_leaf, k_leaf: x0_leaf + dt * k_leaf
        return jax.tree.map(op, x0, k)

    @classmethod
    def post_process_state(
        cls, x0: State, t0: Time, xf: NextState, dt: TimeStep
    ) -> NextState:
        r"""
        Post-process the integrated state at :math:`t_f = t_0 + \Delta t`.

        The base implementation returns `xf` unchanged; subclasses and mixins
        can override it to inject custom logic.

        Args:
            x0: The initial state of the system.
            t0: The initial time of the system.
            xf: The final state of the system obtained through the integration.
            dt: The time step used for the integration.

        Returns:
            The post-processed integrated state.
        """

        return xf

    def _compute_next_state(
        self, x0: State, t0: Time, dt: TimeStep, **kwargs
    ) -> tuple[NextState, dict[str, Any]]:
        """
        Compute the next state of the system, returning all the output states.

        Args:
            x0: The initial state of the system.
            t0: The initial time of the system.
            dt: The time step of the integration.
            **kwargs:
                Additional keyword arguments forwarded to the dynamics. The
                optional `metadata` entry is consumed here and is not forwarded.

        Returns:
            A batched state with as many batch elements as `b.T` rows, and the
            auxiliary dictionary of the dynamics (stacked along the stages by
            `jax.lax.scan`) extended with the `metadata` entry.
        """

        # Call variables with better symbols.
        Δt = dt
        c = self.c
        b = self.b
        A = self.A

        # Extract metadata from the kwargs. A shallow copy is taken so that storing
        # the FSAL derivative below never mutates the dictionary of the caller.
        metadata = dict(kwargs.pop("metadata", None) or {})

        # Close f over optional kwargs.
        f = lambda x, t: self.dynamics(x=x, t=t, **kwargs)

        # Initialize the carry of the for loop with the stacked kᵢ vectors.
        carry0 = jax.tree.map(
            lambda l: jnp.zeros((c.size, *l.shape), dtype=l.dtype), x0
        )

        # Closure on metadata to either evaluate the dynamics at the initial state
        # or to use the previous state derivative (only integrators supporting FSAL).
        # The dynamics is always evaluated at (x0, t0) to obtain the aux dictionary.
        def get_ẋ0_and_aux_dict() -> tuple[StateDerivative, dict[str, Any]]:
            ẋ0, aux_dict = f(x0, t0)
            return metadata.get("dxdt0", ẋ0), aux_dict

        # We use a `jax.lax.scan` to compile the `f` function only once.
        # Otherwise, if we compute e.g. for RK4 sequentially, the jit-compiled code
        # would include 4 repetitions of the `f` logic, making everything extremely slow.
        def scan_body(
            carry: jax.Array, i: int | jax.Array
        ) -> tuple[jax.Array, dict[str, Any]]:
            """
            Compute the kᵢ derivative of the Runge-Kutta stage.
            """

            # Unpack the carry, i.e. the stacked kᵢ vectors.
            K = carry

            # Define the computation of the Runge-Kutta stage.
            def compute_ki() -> tuple[jax.Array, dict[str, Any]]:

                # Compute ∑ⱼ aᵢⱼ kⱼ.
                op_sum_ak = lambda k: jnp.einsum("s,s...->...", A[i], k)
                sum_ak = jax.tree.map(op_sum_ak, K)

                # Compute the next state for the kᵢ evaluation.
                # Note that this is not a Δt integration since aᵢⱼ could be fractional.
                xi = self.integrate_rk_stage(x0, t0, Δt, sum_ak)

                # Compute the next time for the kᵢ evaluation.
                ti = t0 + c[i] * Δt

                # Evaluate the dynamics.
                ki, aux_dict = f(xi, ti)
                return ki, aux_dict

            # This selector enables FSAL property in the first iteration (i=0).
            ki, aux_dict = jax.lax.cond(
                pred=jnp.logical_and(i == 0, self.has_fsal),
                true_fun=get_ẋ0_and_aux_dict,
                false_fun=compute_ki,
            )

            # Store the kᵢ derivative in K.
            op = lambda l_k, l_ki: l_k.at[i].set(l_ki)
            K = jax.tree.map(op, K, ki)

            carry = K
            return carry, aux_dict

        # Compute the state derivatives kᵢ.
        K, aux_dict = jax.lax.scan(
            f=scan_body,
            init=carry0,
            xs=jnp.arange(c.size),
        )

        # Update the FSAL property for the next iteration.
        if self.has_fsal:
            # Store the first derivative of the next step in the metadata.
            metadata["dxdt0"] = jax.tree.map(lambda l: l[self.index_of_fsal], K)

        # Compute the output state.
        # Note that z contains as many new states as the rows of `b.T`.
        op = lambda x0, k: x0 + Δt * jnp.einsum("zs,s...->z...", b.T, k)
        z = jax.tree.map(op, x0, K)

        # Transform the final state of the integration.
        # This allows to inject custom logic, if needed.
        z_transformed = jax.vmap(
            lambda xf: self.post_process_state(x0=x0, t0=t0, xf=xf, dt=dt)
        )(z)

        return z_transformed, aux_dict | {"metadata": metadata}

    @staticmethod
    def butcher_tableau_is_valid(
        A: jtp.Matrix, b: jtp.Matrix, c: jtp.Vector
    ) -> jtp.Bool:
        """
        Check if the Butcher tableau is valid.

        The tableau is valid when `A` is square with one row per node in `c`,
        `b.T` has at most two rows (solution and optional embedded estimate)
        with as many columns as `A`, and the weights of every row of `b.T` sum
        to one up to floating-point tolerance.

        Args:
            A: The Runge-Kutta matrix.
            b: The weights coefficients.
            c: The nodes coefficients.

        Returns:
            `True` if the Butcher tableau is valid, `False` otherwise.
        """

        # The `and` chain short-circuits, so shape attributes are only accessed
        # once the number of dimensions has been verified.
        valid = True
        valid = valid and A.ndim == 2
        valid = valid and b.ndim == 2
        valid = valid and c.ndim == 1
        valid = valid and b.T.shape[0] <= 2
        valid = valid and A.shape[0] == A.shape[1]
        valid = valid and A.shape == (c.size, b.T.shape[1])
        # Compare with a tolerance: rational weights like 1/6 or 1/3 are not
        # exactly representable and may not sum exactly to 1.
        valid = valid and bool(jnp.allclose(b.T.sum(axis=1), 1.0))

        return valid

    @staticmethod
    def butcher_tableau_is_explicit(A: jtp.Matrix) -> jtp.Bool:
        """
        Check if the Butcher tableau corresponds to an explicit integration scheme.

        Args:
            A: The Runge-Kutta matrix.

        Returns:
            `True` if `A` is strictly lower triangular, `False` otherwise.
        """

        return jnp.allclose(A, jnp.tril(A, k=-1))

    @staticmethod
    def butcher_tableau_supports_fsal(
        A: jtp.Matrix,
        b: jtp.Matrix,
        c: jtp.Vector,
        index_of_solution: jtp.IntLike = 0,
    ) -> tuple[bool, int | None]:
        """
        Check if the Butcher tableau supports the FSAL (first-same-as-last) property.

        Args:
            A: The Runge-Kutta matrix.
            b: The weights coefficients.
            c: The nodes coefficients.
            index_of_solution:
                The index of the row of `b.T` corresponding to the solution.

        Returns:
            A tuple containing a boolean indicating whether the Butcher tableau supports
            FSAL, and the index i of the intermediate kᵢ derivative corresponding to the
            initial derivative `f(x0, t0)` of the next step.

        Raises:
            ValueError: If the tableau is not valid or `index_of_solution` is out
                of range.
        """

        if not ExplicitRungeKutta.butcher_tableau_is_valid(A=A, b=b, c=c):
            raise ValueError("The Butcher tableau is not valid.")

        if not ExplicitRungeKutta.butcher_tableau_is_explicit(A=A):
            return False, None

        if index_of_solution >= b.T.shape[0]:
            msg = "The index of the solution (i-th row of `b.T`) is out of range."
            raise ValueError(msg)

        if c[0] != 0:
            return False, None

        # Find all the rows of A where c = 1 (therefore at t=tf). The Butcher tableau
        # supports FSAL if any of these rows (there might be more rows with c=1) matches
        # the rows of b.T corresponding to the next state (marked by `index_of_solution`).
        # This last condition means that the last kᵢ derivative is computed at (tf, xf),
        # that corresponds to the (t0, x0) pair of the next integration call.
        rows_of_A_with_fsal = (A == b.T[None, index_of_solution]).all(axis=1)
        rows_of_A_with_fsal = jnp.logical_and(rows_of_A_with_fsal, (c == 1))

        # If there is no match, it means that the Butcher tableau does not support FSAL.
        if not rows_of_A_with_fsal.any():
            return False, None

        # Return the index of the row of A providing the fsal derivative (that is the
        # possibly intermediate kᵢ derivative).
        # Note that if multiple rows match (it should not), we return the first match.
        return True, int(jnp.where(rows_of_A_with_fsal)[0].tolist()[0])
|
540
|
-
|
541
|
-
|
542
|
-
class ExplicitRungeKuttaSO3Mixin:
    """
    Mixin for explicit RK integrators whose state is a `js.ode_data.ODEState`.

    It overrides `post_process_state` so that the base quaternion of the final
    state is obtained by integrating on SO(3), and is not taken from the plain
    Runge-Kutta combination of the stage derivatives.
    """

    @classmethod
    def post_process_state(
        cls, x0: js.ode_data.ODEState, t0: Time, xf: js.ode_data.ODEState, dt: TimeStep
    ) -> js.ode_data.ODEState:
        r"""
        Replace the base quaternion of the state at :math:`t_f = t_0 + \Delta t`
        with a unit quaternion integrated on SO(3).

        Args:
            x0: The initial state of the system.
            t0: The initial time of the system (unused).
            xf: The final state of the system obtained through the integration.
            dt: The time step used for the integration.

        Returns:
            A copy of `xf` whose base quaternion is the SO(3)-integrated one.
        """

        initial_physics_model = x0.physics_model
        quaternion_t0 = initial_physics_model.base_quaternion

        # The quaternion at t0 is required to be unit-norm already.
        squared_norm_t0 = quaternion_t0.dot(quaternion_t0)
        exceptions.raise_runtime_error_if(
            condition=~jnp.allclose(squared_norm_t0, 1.0),
            msg="The SO(3) integrator received a quaternion at t0 that is not unary.",
        )

        # Use the base angular velocity at t0 (inertial-fixed representation).
        # That velocity results from the previous step of the active RK scheme,
        # so the resulting update is an explicit scheme on the SO(3) manifold.
        # Since this is not a semi-implicit scheme, the final velocity at tf must
        # not be used here.
        omega_t0 = initial_physics_model.base_angular_velocity

        quaternion_tf = jaxsim.math.Quaternion.integration(
            quaternion=quaternion_t0,
            dt=dt,
            omega=omega_t0,
            omega_in_body_fixed=False,
        )

        # Swap the integrated quaternion into the final state.
        final_physics_model = xf.physics_model.replace(base_quaternion=quaternion_tf)
        return xf.replace(physics_model=final_physics_model, validate=True)
|