jaxsim 0.2.dev101__py3-none-any.whl → 0.2.dev166__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jaxsim/_version.py +2 -2
- jaxsim/api/__init__.py +1 -0
- jaxsim/api/contact.py +194 -0
- jaxsim/api/data.py +951 -0
- jaxsim/api/joint.py +148 -0
- jaxsim/api/link.py +262 -0
- jaxsim/api/model.py +1099 -0
- jaxsim/api/ode.py +280 -0
- jaxsim/integrators/__init__.py +2 -0
- jaxsim/integrators/common.py +508 -0
- jaxsim/integrators/fixed_step.py +158 -0
- jaxsim/mujoco/__init__.py +1 -1
- jaxsim/mujoco/loaders.py +30 -18
- jaxsim/mujoco/visualizer.py +3 -1
- jaxsim/physics/algos/soft_contacts.py +97 -28
- jaxsim/physics/model/physics_model.py +30 -0
- jaxsim/physics/model/physics_model_state.py +110 -11
- jaxsim/simulation/ode_data.py +43 -0
- {jaxsim-0.2.dev101.dist-info → jaxsim-0.2.dev166.dist-info}/METADATA +2 -1
- {jaxsim-0.2.dev101.dist-info → jaxsim-0.2.dev166.dist-info}/RECORD +23 -13
- {jaxsim-0.2.dev101.dist-info → jaxsim-0.2.dev166.dist-info}/LICENSE +0 -0
- {jaxsim-0.2.dev101.dist-info → jaxsim-0.2.dev166.dist-info}/WHEEL +0 -0
- {jaxsim-0.2.dev101.dist-info → jaxsim-0.2.dev166.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,508 @@
|
|
1
|
+
import abc
|
2
|
+
import dataclasses
|
3
|
+
from typing import Any, ClassVar, Generic, Protocol, Type, TypeVar
|
4
|
+
|
5
|
+
import jax
|
6
|
+
import jax.numpy as jnp
|
7
|
+
import jax_dataclasses
|
8
|
+
from jax_dataclasses import Static
|
9
|
+
|
10
|
+
import jaxsim.typing as jtp
|
11
|
+
from jaxsim.utils.jaxsim_dataclass import JaxsimDataclass, Mutability
|
12
|
+
|
13
|
+
try:
|
14
|
+
from typing import override
|
15
|
+
except ImportError:
|
16
|
+
from typing_extensions import override
|
17
|
+
|
18
|
+
try:
|
19
|
+
from typing import Self
|
20
|
+
except ImportError:
|
21
|
+
from typing_extensions import Self
|
22
|
+
|
23
|
+
|
24
|
+
# =============
# Generic types
# =============

# Time instants and time steps accept any JAX array-like value.
Time = TimeStep = jax.typing.ArrayLike

# Type variable of the integrated state. The next state has the same type of the
# current state, therefore `NextState` is just an alias of the same type variable.
State = TypeVar("State")
NextState = State

# Type variable of the derivative of the state.
StateDerivative = TypeVar("StateDerivative")

# Type variable used by integrators operating directly on generic PyTrees.
PyTreeType = TypeVar("PyTreeType", bound=jtp.PyTree)
|
33
|
+
|
34
|
+
|
35
|
+
class SystemDynamics(Protocol[State, StateDerivative]):
    """
    Protocol of the callables describing the dynamics of a system.

    A system dynamics is called with a state `x`, a time `t` and optional keyword
    arguments, and returns a tuple containing the derivative of the state and an
    auxiliary dictionary with additional data.
    """

    def __call__(
        self,
        x: State,
        t: Time,
        **kwargs,
    ) -> tuple[StateDerivative, dict[str, Any]]: ...
|
39
|
+
|
40
|
+
|
41
|
+
# =======================
|
42
|
+
# Base integrator classes
|
43
|
+
# =======================
|
44
|
+
|
45
|
+
|
46
|
+
@jax_dataclasses.pytree_dataclass
class Integrator(JaxsimDataclass, abc.ABC, Generic[State, StateDerivative]):
    """
    Base class of the integrators.

    An integrator advances by one time step the state of a system whose dynamics is
    described by a callable following the `SystemDynamics` protocol. Data that has
    to be carried between consecutive steps is stored in the `params` dictionary,
    which is created by `init` and returned (updated) by `step`.
    """

    # Key of the `params` dictionary under which `init` stores the auxiliary
    # dictionary returned by the system dynamics.
    AuxDictDynamicsKey: ClassVar[str] = "aux_dict_dynamics"

    # The system dynamics, annotated as `Static` pytree metadata.
    dynamics: Static[SystemDynamics[State, StateDerivative]] = dataclasses.field(
        repr=False, hash=False, compare=False, kw_only=True
    )

    # The auxiliary dictionary of the integrator.
    params: dict[str, Any] = dataclasses.field(
        default_factory=dict, repr=False, hash=False, compare=False, kw_only=True
    )

    @classmethod
    def build(
        cls: Type[Self], *, dynamics: SystemDynamics[State, StateDerivative], **kwargs
    ) -> Self:
        """
        Build the integrator object.

        Args:
            dynamics: The system dynamics.
            **kwargs: Additional keyword arguments to build the integrator.

        Returns:
            The integrator object.
        """

        return cls(dynamics=dynamics, **kwargs)  # noqa

    def step(
        self,
        x0: State,
        t0: Time,
        dt: TimeStep,
        *,
        params: dict[str, Any],
        **kwargs,
    ) -> tuple[State, dict[str, Any]]:
        """
        Perform a single integration step.

        Args:
            x0: The initial state of the system.
            t0: The initial time of the system.
            dt: The time step of the integration.
            params:
                The auxiliary dictionary of the integrator, as returned by `init`
                or by a previous call to `step`.
            **kwargs: Additional keyword arguments.

        Returns:
            The final state of the system and the updated auxiliary dictionary.

        Raises:
            KeyError:
                If the auxiliary dictionary does not contain the key storing the
                auxiliary data of the dynamics (i.e. it was not created by `init`).
        """

        # Operate on a copy of the integrator carrying the given parameters.
        with self.editable(validate=False) as integrator:
            integrator.params = params

        # The integrator is allowed to update its params while integrating.
        with integrator.mutable_context(mutability=Mutability.MUTABLE):
            xf = integrator(x0, t0, dt, **kwargs)

        # Use an explicit check instead of `assert`, which is stripped under `-O`.
        if Integrator.AuxDictDynamicsKey not in integrator.params:
            msg = "The key '{}' is missing from the integrator params. Use `init`."
            raise KeyError(msg.format(Integrator.AuxDictDynamicsKey))

        return xf, integrator.params

    @abc.abstractmethod
    def __call__(self, x0: State, t0: Time, dt: TimeStep, **kwargs) -> NextState:
        """
        Integrate the state `x0` from `t0` to `t0 + dt` and return the next state.
        """
        pass

    def init(
        self,
        x0: State,
        t0: Time,
        dt: TimeStep,
        *,
        key: jax.Array | None = None,
        **kwargs,
    ) -> dict[str, Any]:
        """
        Initialize the integrator.

        Args:
            x0: The initial state of the system.
            t0: The initial time of the system.
            dt: The time step of the integration.
            key:
                An optional random key to initialize the integrator.
                It is currently not used by this base implementation.
            **kwargs: Additional keyword arguments forwarded to `__call__`.

        Returns:
            The auxiliary dictionary of the integrator.

        Raises:
            KeyError:
                If the `__call__` method stores in the params a key equal to
                `Integrator.AuxDictDynamicsKey`, which is reserved.

        Note:
            This method should have the same signature as the inherited `__call__`
            method, including additional kwargs.

        Note:
            If the integrator supports FSAL, the pair `(x0, t0)` must match the real
            initial state and time of the system, otherwise the initial derivative of
            the first step will be wrong.
        """

        # Evaluate the dynamics once to get the structure of its auxiliary data.
        _, aux_dict_dynamics = self.dynamics(x0, t0)

        # Run a dummy step on a copy of the integrator to populate its params.
        with self.editable(validate=False) as integrator:
            _ = integrator(x0, t0, dt, **kwargs)
            aux_dict_step = integrator.params

        # The key is reserved: if the step created it, the merge below would
        # silently override the auxiliary data of the dynamics.
        if Integrator.AuxDictDynamicsKey in aux_dict_step:
            msg = "You cannot create a key '{}' in the __call__ method."
            raise KeyError(msg.format(Integrator.AuxDictDynamicsKey))

        return {Integrator.AuxDictDynamicsKey: aux_dict_dynamics} | aux_dict_step
|
155
|
+
|
156
|
+
|
157
|
+
@jax_dataclasses.pytree_dataclass
class ExplicitRungeKutta(Integrator[PyTreeType, PyTreeType], Generic[PyTreeType]):
    """
    Base class of explicit Runge-Kutta integrators defined by a Butcher tableau.

    Subclasses define the class variables `A`, `b`, `c`, `order_of_bT_rows` and
    `row_index_of_solution`. Embedded schemes can define more than one row of
    `b.T`; the row selected by `row_index_of_solution` produces the next state.
    """

    # The Runge-Kutta matrix.
    A: ClassVar[jax.typing.ArrayLike]

    # The weights coefficients.
    # Note that in practice we typically use its transpose `b.transpose()`.
    b: ClassVar[jax.typing.ArrayLike]

    # The nodes coefficients.
    c: ClassVar[jax.typing.ArrayLike]

    # Define the order of the solution.
    # It should have as many elements as the number of rows of `b.transpose()`.
    order_of_bT_rows: ClassVar[tuple[int, ...]]

    # Define the row of the integration output corresponding to the final solution.
    # This is the row of b.T that produces the final state.
    row_index_of_solution: ClassVar[int]

    # Attributes of FSAL (first-same-as-last) property.
    fsal_enabled_if_supported: Static[bool] = dataclasses.field(repr=False)
    index_of_fsal: Static[jtp.IntLike | None] = dataclasses.field(repr=False)

    @property
    def has_fsal(self) -> bool:
        """Whether FSAL is both supported by the tableau and enabled."""
        return self.fsal_enabled_if_supported and self.index_of_fsal is not None

    @property
    def order(self) -> int:
        """The order of the row of `b.T` used to compute the solution."""
        return self.order_of_bT_rows[self.row_index_of_solution]

    @override
    @classmethod
    def build(
        cls: Type[Self],
        *,
        dynamics: SystemDynamics[State, StateDerivative],
        fsal_enabled_if_supported: jtp.BoolLike = True,
        **kwargs,
    ) -> Self:
        """
        Build the integrator object.

        Args:
            dynamics: The system dynamics.
            fsal_enabled_if_supported:
                Whether to enable the FSAL property, if supported.
            **kwargs: Additional keyword arguments to build the integrator.

        Returns:
            The integrator object.

        Raises:
            ValueError: If the Butcher tableau or its metadata is not valid.
        """

        # Adjust the shape of the tableau coefficients.
        c = jnp.atleast_1d(cls.c.squeeze())
        b = jnp.atleast_2d(jnp.vstack(cls.b.squeeze()))
        A = jnp.atleast_2d(cls.A.squeeze())

        # Check validity of the Butcher tableau.
        if not ExplicitRungeKutta.butcher_tableau_is_valid(A=A, b=b, c=c):
            raise ValueError("The Butcher tableau of this class is not valid.")

        # Store the adjusted shapes of the tableau coefficients.
        # Note that this overwrites the class variables of `cls`.
        cls.c = c
        cls.b = b
        cls.A = A

        # Check that b.T has enough rows based on the configured index of the solution.
        if cls.row_index_of_solution >= cls.b.T.shape[0]:
            msg = "The index of the solution ({}-th row of `b.T`) is out of range ({})."
            raise ValueError(msg.format(cls.row_index_of_solution, cls.b.T.shape[0]))

        # Check that the tuple containing the order of the b.T rows matches the number
        # of the b.T rows.
        if len(cls.order_of_bT_rows) != cls.b.T.shape[0]:
            msg = "Wrong size of 'order_of_bT_rows' ({}), should be {}."
            raise ValueError(msg.format(len(cls.order_of_bT_rows), cls.b.T.shape[0]))

        # Check if the Butcher tableau supports FSAL (first-same-as-last).
        # If it does, store the index of the intermediate derivative to be used as the
        # first derivative of the next iteration (None if FSAL is not supported).
        _, index_of_fsal = ExplicitRungeKutta.butcher_tableau_supports_fsal(
            A=cls.A, b=cls.b, c=cls.c, index_of_solution=cls.row_index_of_solution
        )

        # Build the integrator object.
        integrator = super().build(
            dynamics=dynamics,
            index_of_fsal=index_of_fsal,
            fsal_enabled_if_supported=bool(fsal_enabled_if_supported),
            **kwargs,
        )

        return integrator

    def __call__(self, x0: State, t0: Time, dt: TimeStep, **kwargs) -> NextState:
        """
        Integrate the state `x0` from `t0` to `t0 + dt` and return the next state.
        """

        # Here z is a batched state with as many batch elements as b.T rows.
        # Note that z has multiple batches only if b.T has more than one row,
        # e.g. in Butcher tableau of embedded schemes.
        z = self._compute_next_state(x0=x0, t0=t0, dt=dt, **kwargs)

        # The next state is the batch element located at the configured index of solution.
        return jax.tree_util.tree_map(lambda l: l[self.row_index_of_solution], z)

    @classmethod
    def integrate_rk_stage(
        cls, x0: State, t0: Time, dt: TimeStep, k: StateDerivative
    ) -> NextState:
        """
        Integrate a single stage of the Runge-Kutta method.

        Args:
            x0: The initial state of the system.
            t0: The initial time of the system.
            dt:
                The time step of the RK integration scheme. Note that this is
                not the stage timestep, as it depends on the `A` matrix used
                to compute the `k` argument.
            k:
                The RK state derivative of the current stage, weighted with
                the `A` matrix.

        Returns:
            The state at the next stage of the integration.

        Note:
            In the most generic case, `k` could be an arbitrary composition
            of the kᵢ derivatives, depending on the RK matrix A.

        Note:
            Overriding this method allows users to use different classes
            defining `State` and `StateDerivative`. Be aware that the
            timestep `dt` is not the stage timestep, therefore the map
            used to convert the state derivative must be time-independent.
        """

        op = lambda x0_leaf, k_leaf: x0_leaf + dt * k_leaf
        return jax.tree_util.tree_map(op, x0, k)

    @classmethod
    def post_process_state(
        cls, x0: State, t0: Time, xf: NextState, dt: TimeStep
    ) -> NextState:
        r"""
        Post-process the integrated state at :math:`t_f = t_0 + \Delta t`.

        Args:
            x0: The initial state of the system.
            t0: The initial time of the system.
            xf: The final state of the system obtained through the integration.
            dt: The time step used for the integration.

        Returns:
            The post-processed integrated state.
        """

        return xf

    def _compute_next_state(
        self, x0: State, t0: Time, dt: TimeStep, **kwargs
    ) -> NextState:
        """
        Compute the next state of the system, returning all the output states.

        Args:
            x0: The initial state of the system.
            t0: The initial time of the system.
            dt: The time step of the integration.
            **kwargs: Additional keyword arguments.

        Returns:
            A batched state with as many batch elements as `b.T` rows.
        """

        # Call variables with better symbols.
        Δt = dt
        c = self.c
        b = self.b
        A = self.A

        # Close f over optional kwargs.
        f = lambda x, t: self.dynamics(x=x, t=t, **kwargs)

        # Initialize the carry of the for loop with the stacked kᵢ vectors.
        carry0 = jax.tree_util.tree_map(
            lambda l: jnp.repeat(jnp.zeros_like(l)[jnp.newaxis, ...], c.size, axis=0),
            x0,
        )

        # Apply FSAL property by passing ẋ0 = f(x0, t0) from the previous iteration.
        # The dynamics is evaluated only when the derivative is not available:
        # passing f(x0, t0) as default of `dict.get` would always evaluate it.
        def get_ẋ0() -> StateDerivative:
            if "dxdt0" in self.params:
                return self.params["dxdt0"]
            return f(x0, t0)[0]

        # We use a `jax.lax.scan` to compile the `f` function only once.
        # Otherwise, if we compute e.g. for RK4 sequentially, the jit-compiled code
        # would include 4 repetitions of the `f` logic, making everything extremely slow.
        def scan_body(carry: jax.Array, i: int | jax.Array) -> tuple[jax.Array, None]:
            """Compute the i-th RK stage and store it in the stacked kᵢ vectors."""

            # Unpack the carry, i.e. the stacked kᵢ vectors.
            K = carry

            # Define the computation of the Runge-Kutta stage.
            def compute_ki() -> jax.Array:

                # Compute ∑ⱼ aᵢⱼ kⱼ
                op_sum_ak = lambda k: jnp.einsum("s,s...->...", A[i], k)
                sum_ak = jax.tree_util.tree_map(op_sum_ak, K)

                # Compute the next state for the kᵢ evaluation.
                # Note that this is not a Δt integration since aᵢⱼ could be fractional.
                xi = self.integrate_rk_stage(x0, t0, Δt, sum_ak)

                # Compute the next time for the kᵢ evaluation.
                ti = t0 + c[i] * Δt

                # This is kᵢ = f(xᵢ, tᵢ).
                return f(xi, ti)[0]

            # This selector enables FSAL property in the first iteration (i=0).
            # When FSAL is disabled, the selector would never pick `get_ẋ0`,
            # therefore we avoid tracing that branch altogether.
            if self.has_fsal:
                ki = jax.lax.cond(
                    pred=(i == 0),
                    true_fun=get_ẋ0,
                    false_fun=compute_ki,
                )
            else:
                ki = compute_ki()

            # Store the kᵢ derivative in K.
            op = lambda l_k, l_ki: l_k.at[i].set(l_ki)
            K = jax.tree_util.tree_map(op, K, ki)

            carry = K
            return carry, None

        # Compute the state derivatives kᵢ.
        K, _ = jax.lax.scan(
            f=scan_body,
            init=carry0,
            xs=jnp.arange(c.size),
        )

        # Update the FSAL property for the next iteration.
        # Note: `jax.tree_map` is deprecated, use the `jax.tree_util` variant.
        if self.has_fsal:
            self.params["dxdt0"] = jax.tree_util.tree_map(
                lambda l: l[self.index_of_fsal], K
            )

        # Compute the output state.
        # Note that z contains as many new states as the rows of `b.T`.
        op = lambda x0_leaf, k: x0_leaf + Δt * jnp.einsum("zs,s...->z...", b.T, k)
        z = jax.tree_util.tree_map(op, x0, K)

        # Transform the final state of the integration.
        # This allows to inject custom logic, if needed.
        z_transformed = jax.vmap(
            lambda xf: self.post_process_state(x0=x0, t0=t0, xf=xf, dt=dt)
        )(z)

        return z_transformed

    @staticmethod
    def butcher_tableau_is_valid(
        A: jax.typing.ArrayLike, b: jax.typing.ArrayLike, c: jax.typing.ArrayLike
    ) -> jtp.Bool:
        """
        Check if the Butcher tableau is valid.

        Args:
            A: The Runge-Kutta matrix.
            b: The weights coefficients.
            c: The nodes coefficients.

        Returns:
            `True` if the Butcher tableau is valid, `False` otherwise.
        """

        # Accept any array-like input, e.g. nested lists.
        A, b, c = jnp.asarray(A), jnp.asarray(b), jnp.asarray(c)

        valid = True
        valid = valid and A.ndim == 2
        valid = valid and b.ndim == 2
        valid = valid and c.ndim == 1
        valid = valid and b.T.shape[0] <= 2
        valid = valid and A.shape[0] == A.shape[1]
        valid = valid and A.shape == (c.size, b.T.shape[1])

        # The weights are floating-point numbers (possibly float32), therefore
        # their sum is compared with a tolerance instead of exact equality.
        valid = valid and bool(jnp.allclose(b.T.sum(axis=1), 1.0))

        return valid

    @staticmethod
    def butcher_tableau_is_explicit(A: jax.typing.ArrayLike) -> jtp.Bool:
        """
        Check if the Butcher tableau corresponds to an explicit integration scheme.

        Args:
            A: The Runge-Kutta matrix.

        Returns:
            `True` if the Butcher tableau is explicit, `False` otherwise.
        """

        # Explicit schemes have a strictly lower-triangular RK matrix.
        return jnp.allclose(A, jnp.tril(A, k=-1))

    @staticmethod
    def butcher_tableau_supports_fsal(
        A: jax.typing.ArrayLike,
        b: jax.typing.ArrayLike,
        c: jax.typing.ArrayLike,
        index_of_solution: jtp.IntLike = 0,
    ) -> tuple[bool, int | None]:
        """
        Check if the Butcher tableau supports the FSAL (first-same-as-last) property.

        Args:
            A: The Runge-Kutta matrix.
            b: The weights coefficients.
            c: The nodes coefficients.
            index_of_solution:
                The index of the row of `b.T` corresponding to the solution.

        Returns:
            A tuple containing a boolean indicating whether the Butcher tableau supports
            FSAL, and the index i of the intermediate kᵢ derivative corresponding to the
            initial derivative `f(x0, t0)` of the next step (`None` if not supported).

        Raises:
            ValueError:
                If the Butcher tableau is not valid or the index of the solution
                is out of range.
        """

        if not ExplicitRungeKutta.butcher_tableau_is_valid(A=A, b=b, c=c):
            raise ValueError("The Butcher tableau is not valid.")

        A, b, c = jnp.asarray(A), jnp.asarray(b), jnp.asarray(c)

        # Always return a tuple, callers unpack the result.
        if not ExplicitRungeKutta.butcher_tableau_is_explicit(A=A):
            return False, None

        if index_of_solution >= b.T.shape[0]:
            msg = "The index of the solution (i-th row of `b.T`) is out of range."
            raise ValueError(msg)

        if c[0] != 0:
            return False, None

        # Find all the rows of A where c = 1 (therefore at t=tf). The Butcher tableau
        # supports FSAL if any of these rows (there might be more rows with c=1) matches
        # the rows of b.T corresponding to the next state (marked by `index_of_solution`).
        # This last condition means that the last kᵢ derivative is computed at (tf, xf),
        # that corresponds to the (t0, x0) pair of the next integration call.
        rows_of_A_with_fsal = (A == b.T[None, index_of_solution]).all(axis=1)
        rows_of_A_with_fsal = jnp.logical_and(rows_of_A_with_fsal, (c == 1))

        # If there is no match, it means that the Butcher tableau does not support FSAL.
        if not rows_of_A_with_fsal.any():
            return False, None

        # Return the index of the row of A providing the fsal derivative (that is the
        # possibly intermediate kᵢ derivative).
        # Note that if multiple rows match (it should not), we return the first match.
        return True, int(jnp.flatnonzero(rows_of_A_with_fsal)[0])
|
@@ -0,0 +1,158 @@
|
|
1
|
+
from typing import ClassVar, Generic
|
2
|
+
|
3
|
+
import jax
|
4
|
+
import jax.numpy as jnp
|
5
|
+
import jax_dataclasses
|
6
|
+
import jaxlie
|
7
|
+
|
8
|
+
from jaxsim.simulation.ode_data import ODEState
|
9
|
+
|
10
|
+
from .common import ExplicitRungeKutta, PyTreeType, Time, TimeStep
|
11
|
+
|
12
|
+
# The derivative of an `ODEState` is represented by the same `ODEState` class.
ODEStateDerivative = ODEState
|
13
|
+
|
14
|
+
|
15
|
+
# =====================================================
|
16
|
+
# Explicit Runge-Kutta integrators operating on PyTrees
|
17
|
+
# =====================================================
|
18
|
+
|
19
|
+
|
20
|
+
@jax_dataclasses.pytree_dataclass
class ForwardEuler(ExplicitRungeKutta[PyTreeType], Generic[PyTreeType]):
    """
    Forward (explicit) Euler method, a single-stage 1st-order Runge-Kutta scheme.

    x₁ = x₀ + Δt f(x₀, t₀)
    """

    # Single stage: the RK matrix is the 1x1 zero matrix.
    A: ClassVar[jax.typing.ArrayLike] = jnp.zeros(shape=(1, 1), dtype=float)

    # Single unit weight, stored as a (1, 1) column.
    b: ClassVar[jax.typing.ArrayLike] = jnp.ones(shape=(1, 1), dtype=float)

    # The only stage is evaluated at t₀.
    c: ClassVar[jax.typing.ArrayLike] = jnp.zeros(shape=(1,), dtype=float)

    row_index_of_solution: ClassVar[int] = 0
    order_of_bT_rows: ClassVar[tuple[int, ...]] = (1,)
|
45
|
+
|
46
|
+
|
47
|
+
@jax_dataclasses.pytree_dataclass
class Heun(ExplicitRungeKutta[PyTreeType], Generic[PyTreeType]):
    """
    Heun's method (explicit trapezoidal rule), a 2nd-order Runge-Kutta scheme.

    k₁ = f(x₀, t₀)
    k₂ = f(x₀ + Δt k₁, t₀ + Δt)
    x₁ = x₀ + Δt (k₁ + k₂) / 2
    """

    # The second stage is evaluated at t₀ + Δt (c₂ = 1), therefore it must use a
    # full Euler predictor x₀ + Δt k₁ (a₂₁ = 1), so that ∑ⱼ aᵢⱼ = cᵢ holds.
    # Using a₂₁ = 1/2 would evaluate f at an inconsistent (state, time) pair
    # and degrade the scheme below 2nd order.
    A: ClassVar[jax.typing.ArrayLike] = jnp.array(
        [
            [0, 0],
            [1, 0],
        ]
    ).astype(float)

    # The two stages are averaged (stored as a (2, 1) column).
    b: ClassVar[jax.typing.ArrayLike] = (
        jnp.atleast_2d(
            jnp.array([1 / 2, 1 / 2]),
        )
        .astype(float)
        .transpose()
    )

    c: ClassVar[jax.typing.ArrayLike] = jnp.array(
        [0, 1],
    ).astype(float)

    row_index_of_solution: ClassVar[int] = 0
    order_of_bT_rows: ClassVar[tuple[int, ...]] = (2,)
|
71
|
+
|
72
|
+
|
73
|
+
@jax_dataclasses.pytree_dataclass
class RungeKutta4(ExplicitRungeKutta[PyTreeType], Generic[PyTreeType]):
    """
    Classic 4th-order Runge-Kutta method.
    """

    # Each stage only depends on the previous one: a₂₁ = a₃₂ = 1/2, a₄₃ = 1.
    A: ClassVar[jax.typing.ArrayLike] = jnp.array(
        [
            [0.0, 0.0, 0.0, 0.0],
            [0.5, 0.0, 0.0, 0.0],
            [0.0, 0.5, 0.0, 0.0],
            [0.0, 0.0, 1.0, 0.0],
        ],
        dtype=float,
    )

    # Weights stored as a (4, 1) column vector.
    b: ClassVar[jax.typing.ArrayLike] = jnp.array(
        [[1 / 6], [1 / 3], [1 / 3], [1 / 6]],
        dtype=float,
    )

    # Stage times as fractions of Δt.
    c: ClassVar[jax.typing.ArrayLike] = jnp.array([0.0, 0.5, 0.5, 1.0], dtype=float)

    row_index_of_solution: ClassVar[int] = 0
    order_of_bT_rows: ClassVar[tuple[int, ...]] = (4,)
|
99
|
+
|
100
|
+
|
101
|
+
# ===============================================================================
|
102
|
+
# Explicit Runge-Kutta integrators operating on ODEState and integrating on SO(3)
|
103
|
+
# ===============================================================================
|
104
|
+
|
105
|
+
|
106
|
+
class ExplicitRungeKuttaSO3Mixin:
    """
    Mixin class to apply over explicit RK integrators defined on
    `PyTreeType = ODEState` to integrate the quaternion on SO(3).

    It overrides `post_process_state`: the base quaternion produced by the RK
    scheme is replaced by the one obtained by integrating the final base angular
    velocity on the SO(3) manifold.
    """

    @classmethod
    def post_process_state(
        cls, x0: ODEState, t0: Time, xf: ODEState, dt: TimeStep
    ) -> ODEState:
        """
        Replace the base quaternion of `xf` with the one integrated on SO(3).

        Args:
            x0: The initial state of the system.
            t0: The initial time of the system.
            xf: The final state of the system obtained through the integration.
            dt: The time step used for the integration.

        Returns:
            The final state with the base quaternion integrated on SO(3).
        """

        # Permutations between the (w, x, y, z) serialization of the state and the
        # (x, y, z, w) serialization used by `jaxlie`.
        wxyz_to_xyzw = jnp.array([1, 2, 3, 0])
        xyzw_to_wxyz = jnp.array([3, 0, 1, 2])

        # Orientation of the base at the initial time.
        quaternion_t0_xyzw = x0.physics_model.base_quaternion[wxyz_to_xyzw]
        rotation_t0 = jaxlie.SO3.from_quaternion_xyzw(xyzw=quaternion_t0_xyzw)

        # The base angular velocity at tf already results from the weighted
        # average of the kᵢ of the RK scheme, therefore integrating it on SO(3)
        # yields a RK scheme operating on the manifold.
        omega_tf = xf.physics_model.base_angular_velocity

        # The angular velocity is expressed in the inertial frame, hence the
        # exponential map multiplies the initial orientation from the left.
        rotation_tf = jaxlie.SO3.exp(tangent=dt * omega_tf) @ rotation_t0
        quaternion_tf_wxyz = rotation_tf.as_quaternion_xyzw()[xyzw_to_wxyz]

        # Replace the quaternion in the final state.
        physics_model_tf = xf.physics_model.replace(base_quaternion=quaternion_tf_wxyz)

        return xf.replace(physics_model=physics_model_tf, validate=True)
|
144
|
+
|
145
|
+
|
146
|
+
@jax_dataclasses.pytree_dataclass
class ForwardEulerSO3(ExplicitRungeKuttaSO3Mixin, ForwardEuler[ODEState]):
    """
    Forward Euler integrator on `ODEState` integrating the base quaternion on SO(3).

    Note: this class previously inherited from `Heun`, making it a 2-stage scheme
    instead of the forward Euler scheme its name advertises.
    """
|
149
|
+
|
150
|
+
|
151
|
+
@jax_dataclasses.pytree_dataclass
class HeunSO3(ExplicitRungeKuttaSO3Mixin, Heun[ODEState]):
    """
    Heun integrator on `ODEState` integrating the base quaternion on SO(3).
    """
|
154
|
+
|
155
|
+
|
156
|
+
@jax_dataclasses.pytree_dataclass
class RungeKutta4SO3(ExplicitRungeKuttaSO3Mixin, RungeKutta4[ODEState]):
    """
    RK4 integrator on `ODEState` integrating the base quaternion on SO(3).
    """
|
jaxsim/mujoco/__init__.py
CHANGED