jaxsim 0.2.dev191__py3-none-any.whl → 0.6.1.dev2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jaxsim/__init__.py +73 -22
- jaxsim/_version.py +2 -2
- jaxsim/api/__init__.py +13 -1
- jaxsim/api/com.py +423 -0
- jaxsim/api/common.py +48 -19
- jaxsim/api/contact.py +604 -52
- jaxsim/api/data.py +308 -163
- jaxsim/api/frame.py +471 -0
- jaxsim/api/joint.py +166 -37
- jaxsim/api/kin_dyn_parameters.py +901 -0
- jaxsim/api/link.py +277 -78
- jaxsim/api/model.py +1572 -362
- jaxsim/api/ode.py +324 -133
- jaxsim/api/ode_data.py +401 -0
- jaxsim/api/references.py +216 -80
- jaxsim/exceptions.py +80 -0
- jaxsim/integrators/__init__.py +2 -2
- jaxsim/integrators/common.py +191 -107
- jaxsim/integrators/fixed_step.py +97 -102
- jaxsim/integrators/variable_step.py +706 -0
- jaxsim/logging.py +1 -2
- jaxsim/math/__init__.py +13 -0
- jaxsim/math/adjoint.py +57 -22
- jaxsim/math/cross.py +16 -7
- jaxsim/math/inertia.py +10 -8
- jaxsim/math/joint_model.py +289 -0
- jaxsim/math/quaternion.py +54 -20
- jaxsim/math/rotation.py +27 -21
- jaxsim/math/skew.py +16 -5
- jaxsim/math/transform.py +102 -0
- jaxsim/math/utils.py +31 -0
- jaxsim/mujoco/__init__.py +2 -1
- jaxsim/mujoco/loaders.py +216 -29
- jaxsim/mujoco/model.py +163 -33
- jaxsim/mujoco/utils.py +228 -0
- jaxsim/mujoco/visualizer.py +107 -22
- jaxsim/parsers/__init__.py +0 -1
- jaxsim/parsers/descriptions/__init__.py +8 -2
- jaxsim/parsers/descriptions/collision.py +83 -26
- jaxsim/parsers/descriptions/joint.py +80 -87
- jaxsim/parsers/descriptions/link.py +58 -31
- jaxsim/parsers/descriptions/model.py +101 -68
- jaxsim/parsers/kinematic_graph.py +606 -229
- jaxsim/parsers/rod/meshes.py +104 -0
- jaxsim/parsers/rod/parser.py +125 -82
- jaxsim/parsers/rod/utils.py +127 -82
- jaxsim/rbda/__init__.py +11 -0
- jaxsim/rbda/aba.py +289 -0
- jaxsim/rbda/collidable_points.py +156 -0
- jaxsim/rbda/contacts/__init__.py +13 -0
- jaxsim/rbda/contacts/common.py +313 -0
- jaxsim/rbda/contacts/relaxed_rigid.py +605 -0
- jaxsim/rbda/contacts/rigid.py +462 -0
- jaxsim/rbda/contacts/soft.py +480 -0
- jaxsim/rbda/contacts/visco_elastic.py +1066 -0
- jaxsim/rbda/crba.py +167 -0
- jaxsim/rbda/forward_kinematics.py +117 -0
- jaxsim/rbda/jacobian.py +330 -0
- jaxsim/rbda/rnea.py +235 -0
- jaxsim/rbda/utils.py +160 -0
- jaxsim/terrain/__init__.py +2 -0
- jaxsim/terrain/terrain.py +238 -0
- jaxsim/typing.py +24 -24
- jaxsim/utils/__init__.py +1 -4
- jaxsim/utils/jaxsim_dataclass.py +289 -34
- jaxsim/utils/tracing.py +5 -11
- jaxsim/utils/wrappers.py +159 -0
- {jaxsim-0.2.dev191.dist-info → jaxsim-0.6.1.dev2.dist-info}/LICENSE +1 -1
- jaxsim-0.6.1.dev2.dist-info/METADATA +465 -0
- jaxsim-0.6.1.dev2.dist-info/RECORD +74 -0
- {jaxsim-0.2.dev191.dist-info → jaxsim-0.6.1.dev2.dist-info}/WHEEL +1 -1
- jaxsim/high_level/__init__.py +0 -2
- jaxsim/high_level/common.py +0 -11
- jaxsim/high_level/joint.py +0 -148
- jaxsim/high_level/link.py +0 -259
- jaxsim/high_level/model.py +0 -1686
- jaxsim/math/conv.py +0 -114
- jaxsim/math/joint.py +0 -102
- jaxsim/math/plucker.py +0 -100
- jaxsim/physics/__init__.py +0 -12
- jaxsim/physics/algos/__init__.py +0 -0
- jaxsim/physics/algos/aba.py +0 -254
- jaxsim/physics/algos/aba_motors.py +0 -284
- jaxsim/physics/algos/crba.py +0 -154
- jaxsim/physics/algos/forward_kinematics.py +0 -79
- jaxsim/physics/algos/jacobian.py +0 -98
- jaxsim/physics/algos/rnea.py +0 -180
- jaxsim/physics/algos/rnea_motors.py +0 -196
- jaxsim/physics/algos/soft_contacts.py +0 -523
- jaxsim/physics/algos/terrain.py +0 -78
- jaxsim/physics/algos/utils.py +0 -69
- jaxsim/physics/model/__init__.py +0 -0
- jaxsim/physics/model/ground_contact.py +0 -53
- jaxsim/physics/model/physics_model.py +0 -388
- jaxsim/physics/model/physics_model_state.py +0 -283
- jaxsim/simulation/__init__.py +0 -4
- jaxsim/simulation/integrators.py +0 -393
- jaxsim/simulation/ode.py +0 -290
- jaxsim/simulation/ode_data.py +0 -96
- jaxsim/simulation/ode_integration.py +0 -62
- jaxsim/simulation/simulator.py +0 -543
- jaxsim/simulation/simulator_callbacks.py +0 -79
- jaxsim/simulation/utils.py +0 -15
- jaxsim/sixd/__init__.py +0 -2
- jaxsim/utils/oop.py +0 -536
- jaxsim/utils/vmappable.py +0 -117
- jaxsim-0.2.dev191.dist-info/METADATA +0 -184
- jaxsim-0.2.dev191.dist-info/RECORD +0 -81
- {jaxsim-0.2.dev191.dist-info → jaxsim-0.6.1.dev2.dist-info}/top_level.txt +0 -0
jaxsim/integrators/common.py
CHANGED
@@ -1,13 +1,16 @@
|
|
1
1
|
import abc
|
2
2
|
import dataclasses
|
3
|
-
from typing import Any, ClassVar, Generic, Protocol,
|
3
|
+
from typing import Any, ClassVar, Generic, Protocol, TypeVar
|
4
4
|
|
5
5
|
import jax
|
6
6
|
import jax.numpy as jnp
|
7
7
|
import jax_dataclasses
|
8
8
|
from jax_dataclasses import Static
|
9
9
|
|
10
|
+
import jaxsim.api as js
|
11
|
+
import jaxsim.math
|
10
12
|
import jaxsim.typing as jtp
|
13
|
+
from jaxsim import exceptions, logging
|
11
14
|
from jaxsim.utils.jaxsim_dataclass import JaxsimDataclass, Mutability
|
12
15
|
|
13
16
|
try:
|
@@ -25,17 +28,33 @@ except ImportError:
|
|
25
28
|
# Generic types
|
26
29
|
# =============
|
27
30
|
|
28
|
-
Time =
|
29
|
-
TimeStep =
|
31
|
+
Time = jtp.FloatLike
|
32
|
+
TimeStep = jtp.FloatLike
|
30
33
|
State = NextState = TypeVar("State")
|
31
34
|
StateDerivative = TypeVar("StateDerivative")
|
32
35
|
PyTreeType = TypeVar("PyTreeType", bound=jtp.PyTree)
|
33
36
|
|
34
37
|
|
35
38
|
class SystemDynamics(Protocol[State, StateDerivative]):
|
39
|
+
"""
|
40
|
+
Protocol defining the system dynamics.
|
41
|
+
"""
|
42
|
+
|
36
43
|
def __call__(
|
37
44
|
self, x: State, t: Time, **kwargs
|
38
|
-
) -> tuple[StateDerivative, dict[str, Any]]:
|
45
|
+
) -> tuple[StateDerivative, dict[str, Any]]:
|
46
|
+
"""
|
47
|
+
Compute the state derivative of the system.
|
48
|
+
|
49
|
+
Args:
|
50
|
+
x: The state of the system.
|
51
|
+
t: The time of the system.
|
52
|
+
**kwargs: Additional keyword arguments.
|
53
|
+
|
54
|
+
Returns:
|
55
|
+
The state derivative of the system and the auxiliary dictionary.
|
56
|
+
"""
|
57
|
+
pass
|
39
58
|
|
40
59
|
|
41
60
|
# =======================
|
@@ -45,20 +64,20 @@ class SystemDynamics(Protocol[State, StateDerivative]):
|
|
45
64
|
|
46
65
|
@jax_dataclasses.pytree_dataclass
|
47
66
|
class Integrator(JaxsimDataclass, abc.ABC, Generic[State, StateDerivative]):
|
48
|
-
|
49
|
-
|
67
|
+
"""
|
68
|
+
Factory class for integrators.
|
69
|
+
"""
|
50
70
|
|
51
71
|
dynamics: Static[SystemDynamics[State, StateDerivative]] = dataclasses.field(
|
52
72
|
repr=False, hash=False, compare=False, kw_only=True
|
53
73
|
)
|
54
74
|
|
55
|
-
params: dict[str, Any] = dataclasses.field(
|
56
|
-
default_factory=dict, repr=False, hash=False, compare=False, kw_only=True
|
57
|
-
)
|
58
|
-
|
59
75
|
@classmethod
|
60
76
|
def build(
|
61
|
-
cls:
|
77
|
+
cls: type[Self],
|
78
|
+
*,
|
79
|
+
dynamics: SystemDynamics[State, StateDerivative],
|
80
|
+
**kwargs,
|
62
81
|
) -> Self:
|
63
82
|
"""
|
64
83
|
Build the integrator object.
|
@@ -71,7 +90,7 @@ class Integrator(JaxsimDataclass, abc.ABC, Generic[State, StateDerivative]):
|
|
71
90
|
The integrator object.
|
72
91
|
"""
|
73
92
|
|
74
|
-
return cls(dynamics=dynamics, **kwargs)
|
93
|
+
return cls(dynamics=dynamics, **kwargs)
|
75
94
|
|
76
95
|
def step(
|
77
96
|
self,
|
@@ -79,9 +98,9 @@ class Integrator(JaxsimDataclass, abc.ABC, Generic[State, StateDerivative]):
|
|
79
98
|
t0: Time,
|
80
99
|
dt: TimeStep,
|
81
100
|
*,
|
82
|
-
|
101
|
+
metadata: dict[str, Any] | None = None,
|
83
102
|
**kwargs,
|
84
|
-
) -> tuple[
|
103
|
+
) -> tuple[NextState, dict[str, Any]]:
|
85
104
|
"""
|
86
105
|
Perform a single integration step.
|
87
106
|
|
@@ -89,25 +108,30 @@ class Integrator(JaxsimDataclass, abc.ABC, Generic[State, StateDerivative]):
|
|
89
108
|
x0: The initial state of the system.
|
90
109
|
t0: The initial time of the system.
|
91
110
|
dt: The time step of the integration.
|
92
|
-
|
111
|
+
metadata: The state auxiliary dictionary of the integrator.
|
93
112
|
**kwargs: Additional keyword arguments.
|
94
113
|
|
95
114
|
Returns:
|
96
115
|
The final state of the system and the updated auxiliary dictionary.
|
97
116
|
"""
|
98
117
|
|
99
|
-
|
100
|
-
integrator.params = params
|
101
|
-
|
102
|
-
with integrator.mutable_context(mutability=Mutability.MUTABLE):
|
103
|
-
xf = integrator(x0, t0, dt, **kwargs)
|
118
|
+
metadata = metadata if metadata is not None else {}
|
104
119
|
|
105
|
-
|
120
|
+
with self.mutable_context(mutability=Mutability.MUTABLE) as integrator:
|
121
|
+
xf, metadata_step = integrator(x0, t0, dt, **kwargs)
|
106
122
|
|
107
|
-
return
|
123
|
+
return (
|
124
|
+
xf,
|
125
|
+
metadata | metadata_step,
|
126
|
+
)
|
108
127
|
|
109
128
|
@abc.abstractmethod
|
110
|
-
def __call__(
|
129
|
+
def __call__(
|
130
|
+
self, x0: State, t0: Time, dt: TimeStep, **kwargs
|
131
|
+
) -> tuple[NextState, dict[str, Any]]:
|
132
|
+
"""
|
133
|
+
Perform a single integration step.
|
134
|
+
"""
|
111
135
|
pass
|
112
136
|
|
113
137
|
def init(
|
@@ -116,56 +140,44 @@ class Integrator(JaxsimDataclass, abc.ABC, Generic[State, StateDerivative]):
|
|
116
140
|
t0: Time,
|
117
141
|
dt: TimeStep,
|
118
142
|
*,
|
119
|
-
|
143
|
+
include_dynamics_aux_dict: bool = False,
|
120
144
|
**kwargs,
|
121
145
|
) -> dict[str, Any]:
|
122
146
|
"""
|
123
|
-
Initialize the integrator.
|
124
|
-
|
125
|
-
Args:
|
126
|
-
x0: The initial state of the system.
|
127
|
-
t0: The initial time of the system.
|
128
|
-
dt: The time step of the integration.
|
129
|
-
key: An optional random key to initialize the integrator.
|
130
|
-
|
131
|
-
Returns:
|
132
|
-
The auxiliary dictionary of the integrator.
|
133
|
-
|
134
|
-
Note:
|
135
|
-
This method should have the same signature as the inherited `__call__`
|
136
|
-
method, including additional kwargs.
|
137
|
-
|
138
|
-
Note:
|
139
|
-
If the integrator supports FSAL, the pair `(x0, t0)` must match the real
|
140
|
-
initial state and time of the system, otherwise the initial derivative of
|
141
|
-
the first step will be wrong.
|
147
|
+
Initialize the integrator. This method is deprecated.
|
142
148
|
"""
|
143
149
|
|
144
|
-
|
145
|
-
|
146
|
-
|
147
|
-
_ = integrator(x0, t0, dt, **kwargs)
|
148
|
-
aux_dict_step = integrator.params
|
149
|
-
|
150
|
-
if Integrator.AuxDictDynamicsKey in aux_dict_dynamics:
|
151
|
-
msg = "You cannot create a key '{}' in the __call__ method."
|
152
|
-
raise KeyError(msg.format(Integrator.AuxDictDynamicsKey))
|
150
|
+
logging.warning(
|
151
|
+
"The 'init' method has been deprecated. There is no need to call it."
|
152
|
+
)
|
153
153
|
|
154
|
-
return {
|
154
|
+
return {}
|
155
155
|
|
156
156
|
|
157
157
|
@jax_dataclasses.pytree_dataclass
|
158
158
|
class ExplicitRungeKutta(Integrator[PyTreeType, PyTreeType], Generic[PyTreeType]):
|
159
|
+
"""
|
160
|
+
Base class for explicit Runge-Kutta integrators.
|
161
|
+
|
162
|
+
Attributes:
|
163
|
+
A: The Runge-Kutta matrix.
|
164
|
+
b: The weights coefficients.
|
165
|
+
c: The nodes coefficients.
|
166
|
+
order_of_bT_rows: The order of the solution.
|
167
|
+
row_index_of_solution: The row of the integration output corresponding to the final solution.
|
168
|
+
fsal_enabled_if_supported: Whether to enable the FSAL property, if supported.
|
169
|
+
index_of_fsal: The index of the intermediate derivative to be used as the first derivative of the next iteration.
|
170
|
+
"""
|
159
171
|
|
160
172
|
# The Runge-Kutta matrix.
|
161
|
-
A:
|
173
|
+
A: jtp.Matrix
|
162
174
|
|
163
175
|
# The weights coefficients.
|
164
176
|
# Note that in practice we typically use its transpose `b.transpose()`.
|
165
|
-
b:
|
177
|
+
b: jtp.Matrix
|
166
178
|
|
167
179
|
# The nodes coefficients.
|
168
|
-
c:
|
180
|
+
c: jtp.Vector
|
169
181
|
|
170
182
|
# Define the order of the solution.
|
171
183
|
# It should have as many elements as the number of rows of `b.transpose()`.
|
@@ -181,16 +193,22 @@ class ExplicitRungeKutta(Integrator[PyTreeType, PyTreeType], Generic[PyTreeType]
|
|
181
193
|
|
182
194
|
@property
|
183
195
|
def has_fsal(self) -> bool:
|
196
|
+
"""
|
197
|
+
Check if the integrator supports the FSAL property.
|
198
|
+
"""
|
184
199
|
return self.fsal_enabled_if_supported and self.index_of_fsal is not None
|
185
200
|
|
186
201
|
@property
|
187
202
|
def order(self) -> int:
|
203
|
+
"""
|
204
|
+
Return the order of the integrator.
|
205
|
+
"""
|
188
206
|
return self.order_of_bT_rows[self.row_index_of_solution]
|
189
207
|
|
190
208
|
@override
|
191
209
|
@classmethod
|
192
210
|
def build(
|
193
|
-
cls:
|
211
|
+
cls: type[Self],
|
194
212
|
*,
|
195
213
|
dynamics: SystemDynamics[State, StateDerivative],
|
196
214
|
fsal_enabled_if_supported: jtp.BoolLike = True,
|
@@ -208,37 +226,32 @@ class ExplicitRungeKutta(Integrator[PyTreeType, PyTreeType], Generic[PyTreeType]
|
|
208
226
|
Returns:
|
209
227
|
The integrator object.
|
210
228
|
"""
|
211
|
-
|
212
|
-
|
213
|
-
c =
|
214
|
-
b = jnp.atleast_2d(jnp.vstack(cls.b.squeeze()))
|
215
|
-
A = jnp.atleast_2d(cls.A.squeeze())
|
229
|
+
A = cls.__dataclass_fields__["A"].default_factory()
|
230
|
+
b = cls.__dataclass_fields__["b"].default_factory()
|
231
|
+
c = cls.__dataclass_fields__["c"].default_factory()
|
216
232
|
|
217
233
|
# Check validity of the Butcher tableau.
|
218
234
|
if not ExplicitRungeKutta.butcher_tableau_is_valid(A=A, b=b, c=c):
|
219
235
|
raise ValueError("The Butcher tableau of this class is not valid.")
|
220
236
|
|
221
|
-
# Store the adjusted shapes of the tableau coefficients.
|
222
|
-
cls.c = c
|
223
|
-
cls.b = b
|
224
|
-
cls.A = A
|
225
|
-
|
226
237
|
# Check that b.T has enough rows based on the configured index of the solution.
|
227
|
-
if cls.row_index_of_solution >=
|
238
|
+
if cls.row_index_of_solution >= b.T.shape[0]:
|
228
239
|
msg = "The index of the solution ({}-th row of `b.T`) is out of range ({})."
|
229
|
-
raise ValueError(msg.format(cls.row_index_of_solution,
|
240
|
+
raise ValueError(msg.format(cls.row_index_of_solution, b.T.shape[0]))
|
230
241
|
|
231
242
|
# Check that the tuple containing the order of the b.T rows matches the number
|
232
243
|
# of the b.T rows.
|
233
|
-
if len(cls.order_of_bT_rows) !=
|
244
|
+
if len(cls.order_of_bT_rows) != b.T.shape[0]:
|
234
245
|
msg = "Wrong size of 'order_of_bT_rows' ({}), should be {}."
|
235
|
-
raise ValueError(msg.format(len(cls.order_of_bT_rows),
|
246
|
+
raise ValueError(msg.format(len(cls.order_of_bT_rows), b.T.shape[0]))
|
236
247
|
|
237
248
|
# Check if the Butcher tableau supports FSAL (first-same-as-last).
|
238
249
|
# If it does, store the index of the intermediate derivative to be used as the
|
239
250
|
# first derivative of the next iteration.
|
240
|
-
has_fsal, index_of_fsal =
|
241
|
-
|
251
|
+
has_fsal, index_of_fsal = ( # noqa: F841
|
252
|
+
ExplicitRungeKutta.butcher_tableau_supports_fsal(
|
253
|
+
A=A, b=b, c=c, index_of_solution=cls.row_index_of_solution
|
254
|
+
)
|
242
255
|
)
|
243
256
|
|
244
257
|
# Build the integrator object.
|
@@ -251,15 +264,22 @@ class ExplicitRungeKutta(Integrator[PyTreeType, PyTreeType], Generic[PyTreeType]
|
|
251
264
|
|
252
265
|
return integrator
|
253
266
|
|
254
|
-
def __call__(
|
267
|
+
def __call__(
|
268
|
+
self, x0: State, t0: Time, dt: TimeStep, **kwargs
|
269
|
+
) -> tuple[NextState, dict[str, Any]]:
|
270
|
+
"""
|
271
|
+
Perform a single integration step.
|
272
|
+
"""
|
255
273
|
|
256
274
|
# Here z is a batched state with as many batch elements as b.T rows.
|
257
275
|
# Note that z has multiple batches only if b.T has more than one row,
|
258
276
|
# e.g. in Butcher tableau of embedded schemes.
|
259
|
-
z = self._compute_next_state(x0=x0, t0=t0, dt=dt, **kwargs)
|
277
|
+
z, aux_dict = self._compute_next_state(x0=x0, t0=t0, dt=dt, **kwargs)
|
260
278
|
|
261
279
|
# The next state is the batch element located at the configured index of solution.
|
262
|
-
|
280
|
+
next_state = jax.tree.map(lambda l: l[self.row_index_of_solution], z)
|
281
|
+
|
282
|
+
return next_state, aux_dict
|
263
283
|
|
264
284
|
@classmethod
|
265
285
|
def integrate_rk_stage(
|
@@ -294,13 +314,13 @@ class ExplicitRungeKutta(Integrator[PyTreeType, PyTreeType], Generic[PyTreeType]
|
|
294
314
|
"""
|
295
315
|
|
296
316
|
op = lambda x0_leaf, k_leaf: x0_leaf + dt * k_leaf
|
297
|
-
return jax.
|
317
|
+
return jax.tree.map(op, x0, k)
|
298
318
|
|
299
319
|
@classmethod
|
300
320
|
def post_process_state(
|
301
321
|
cls, x0: State, t0: Time, xf: NextState, dt: TimeStep
|
302
322
|
) -> NextState:
|
303
|
-
"""
|
323
|
+
r"""
|
304
324
|
Post-process the integrated state at :math:`t_f = t_0 + \Delta t`.
|
305
325
|
|
306
326
|
Args:
|
@@ -317,7 +337,7 @@ class ExplicitRungeKutta(Integrator[PyTreeType, PyTreeType], Generic[PyTreeType]
|
|
317
337
|
|
318
338
|
def _compute_next_state(
|
319
339
|
self, x0: State, t0: Time, dt: TimeStep, **kwargs
|
320
|
-
) -> NextState:
|
340
|
+
) -> tuple[NextState, dict[str, Any]]:
|
321
341
|
"""
|
322
342
|
Compute the next state of the system, returning all the output states.
|
323
343
|
|
@@ -337,33 +357,42 @@ class ExplicitRungeKutta(Integrator[PyTreeType, PyTreeType], Generic[PyTreeType]
|
|
337
357
|
b = self.b
|
338
358
|
A = self.A
|
339
359
|
|
360
|
+
# Extract metadata from the kwargs.
|
361
|
+
metadata = kwargs.pop("metadata", {})
|
362
|
+
|
340
363
|
# Close f over optional kwargs.
|
341
364
|
f = lambda x, t: self.dynamics(x=x, t=t, **kwargs)
|
342
365
|
|
343
366
|
# Initialize the carry of the for loop with the stacked kᵢ vectors.
|
344
|
-
carry0 = jax.
|
345
|
-
lambda l: jnp.
|
346
|
-
x0,
|
367
|
+
carry0 = jax.tree.map(
|
368
|
+
lambda l: jnp.zeros((c.size, *l.shape), dtype=l.dtype), x0
|
347
369
|
)
|
348
370
|
|
349
|
-
#
|
350
|
-
|
371
|
+
# Closure on metadata to either evaluate the dynamics at the initial state
|
372
|
+
# or to use the previous state derivative (only integrators supporting FSAL).
|
373
|
+
def get_ẋ0_and_aux_dict() -> tuple[StateDerivative, dict[str, Any]]:
|
374
|
+
ẋ0, aux_dict = f(x0, t0)
|
375
|
+
return metadata.get("dxdt0", ẋ0), aux_dict
|
351
376
|
|
352
377
|
# We use a `jax.lax.scan` to compile the `f` function only once.
|
353
378
|
# Otherwise, if we compute e.g. for RK4 sequentially, the jit-compiled code
|
354
379
|
# would include 4 repetitions of the `f` logic, making everything extremely slow.
|
355
|
-
def scan_body(
|
356
|
-
|
380
|
+
def scan_body(
|
381
|
+
carry: jax.Array, i: int | jax.Array
|
382
|
+
) -> tuple[jax.Array, dict[str, Any]]:
|
383
|
+
"""
|
384
|
+
Compute the kᵢ derivative of the Runge-Kutta stage.
|
385
|
+
"""
|
357
386
|
|
358
387
|
# Unpack the carry, i.e. the stacked kᵢ vectors.
|
359
388
|
K = carry
|
360
389
|
|
361
390
|
# Define the computation of the Runge-Kutta stage.
|
362
|
-
def compute_ki() -> jax.Array:
|
391
|
+
def compute_ki() -> tuple[jax.Array, dict[str, Any]]:
|
363
392
|
|
364
|
-
# Compute ∑ⱼ aᵢⱼ k
|
393
|
+
# Compute ∑ⱼ aᵢⱼ kⱼ.
|
365
394
|
op_sum_ak = lambda k: jnp.einsum("s,s...->...", A[i], k)
|
366
|
-
sum_ak = jax.
|
395
|
+
sum_ak = jax.tree.map(op_sum_ak, K)
|
367
396
|
|
368
397
|
# Compute the next state for the kᵢ evaluation.
|
369
398
|
# Note that this is not a Δt integration since aᵢⱼ could be fractional.
|
@@ -372,25 +401,26 @@ class ExplicitRungeKutta(Integrator[PyTreeType, PyTreeType], Generic[PyTreeType]
|
|
372
401
|
# Compute the next time for the kᵢ evaluation.
|
373
402
|
ti = t0 + c[i] * Δt
|
374
403
|
|
375
|
-
#
|
376
|
-
|
404
|
+
# Evaluate the dynamics.
|
405
|
+
ki, aux_dict = f(xi, ti)
|
406
|
+
return ki, aux_dict
|
377
407
|
|
378
408
|
# This selector enables FSAL property in the first iteration (i=0).
|
379
|
-
ki = jax.lax.cond(
|
409
|
+
ki, aux_dict = jax.lax.cond(
|
380
410
|
pred=jnp.logical_and(i == 0, self.has_fsal),
|
381
|
-
true_fun=get_ẋ
|
411
|
+
true_fun=get_ẋ0_and_aux_dict,
|
382
412
|
false_fun=compute_ki,
|
383
413
|
)
|
384
414
|
|
385
415
|
# Store the kᵢ derivative in K.
|
386
416
|
op = lambda l_k, l_ki: l_k.at[i].set(l_ki)
|
387
|
-
K = jax.
|
417
|
+
K = jax.tree.map(op, K, ki)
|
388
418
|
|
389
419
|
carry = K
|
390
|
-
return carry,
|
420
|
+
return carry, aux_dict
|
391
421
|
|
392
422
|
# Compute the state derivatives kᵢ.
|
393
|
-
K,
|
423
|
+
K, aux_dict = jax.lax.scan(
|
394
424
|
f=scan_body,
|
395
425
|
init=carry0,
|
396
426
|
xs=jnp.arange(c.size),
|
@@ -398,12 +428,13 @@ class ExplicitRungeKutta(Integrator[PyTreeType, PyTreeType], Generic[PyTreeType]
|
|
398
428
|
|
399
429
|
# Update the FSAL property for the next iteration.
|
400
430
|
if self.has_fsal:
|
401
|
-
|
431
|
+
# Store the first derivative of the next step in the metadata.
|
432
|
+
metadata["dxdt0"] = jax.tree.map(lambda l: l[self.index_of_fsal], K)
|
402
433
|
|
403
434
|
# Compute the output state.
|
404
435
|
# Note that z contains as many new states as the rows of `b.T`.
|
405
436
|
op = lambda x0, k: x0 + Δt * jnp.einsum("zs,s...->z...", b.T, k)
|
406
|
-
z = jax.
|
437
|
+
z = jax.tree.map(op, x0, K)
|
407
438
|
|
408
439
|
# Transform the final state of the integration.
|
409
440
|
# This allows injecting custom logic, if needed.
|
@@ -411,11 +442,11 @@ class ExplicitRungeKutta(Integrator[PyTreeType, PyTreeType], Generic[PyTreeType]
|
|
411
442
|
lambda xf: self.post_process_state(x0=x0, t0=t0, xf=xf, dt=dt)
|
412
443
|
)(z)
|
413
444
|
|
414
|
-
return z_transformed
|
445
|
+
return z_transformed, aux_dict | {"metadata": metadata}
|
415
446
|
|
416
447
|
@staticmethod
|
417
448
|
def butcher_tableau_is_valid(
|
418
|
-
A:
|
449
|
+
A: jtp.Matrix, b: jtp.Matrix, c: jtp.Vector
|
419
450
|
) -> jtp.Bool:
|
420
451
|
"""
|
421
452
|
Check if the Butcher tableau is valid.
|
@@ -441,7 +472,7 @@ class ExplicitRungeKutta(Integrator[PyTreeType, PyTreeType], Generic[PyTreeType]
|
|
441
472
|
return valid
|
442
473
|
|
443
474
|
@staticmethod
|
444
|
-
def butcher_tableau_is_explicit(A:
|
475
|
+
def butcher_tableau_is_explicit(A: jtp.Matrix) -> jtp.Bool:
|
445
476
|
"""
|
446
477
|
Check if the Butcher tableau corresponds to an explicit integration scheme.
|
447
478
|
|
@@ -456,11 +487,11 @@ class ExplicitRungeKutta(Integrator[PyTreeType, PyTreeType], Generic[PyTreeType]
|
|
456
487
|
|
457
488
|
@staticmethod
|
458
489
|
def butcher_tableau_supports_fsal(
|
459
|
-
A:
|
460
|
-
b:
|
461
|
-
c:
|
490
|
+
A: jtp.Matrix,
|
491
|
+
b: jtp.Matrix,
|
492
|
+
c: jtp.Vector,
|
462
493
|
index_of_solution: jtp.IntLike = 0,
|
463
|
-
) -> [bool, int | None]:
|
494
|
+
) -> tuple[bool, int | None]:
|
464
495
|
"""
|
465
496
|
Check if the Butcher tableau supports the FSAL (first-same-as-last) property.
|
466
497
|
|
@@ -481,7 +512,7 @@ class ExplicitRungeKutta(Integrator[PyTreeType, PyTreeType], Generic[PyTreeType]
|
|
481
512
|
raise ValueError("The Butcher tableau is not valid.")
|
482
513
|
|
483
514
|
if not ExplicitRungeKutta.butcher_tableau_is_explicit(A=A):
|
484
|
-
return False
|
515
|
+
return False, None
|
485
516
|
|
486
517
|
if index_of_solution >= b.T.shape[0]:
|
487
518
|
msg = "The index of the solution (i-th row of `b.T`) is out of range."
|
@@ -505,4 +536,57 @@ class ExplicitRungeKutta(Integrator[PyTreeType, PyTreeType], Generic[PyTreeType]
|
|
505
536
|
# Return the index of the row of A providing the fsal derivative (that is the
|
506
537
|
# possibly intermediate kᵢ derivative).
|
507
538
|
# Note that if multiple rows match (it should not), we return the first match.
|
508
|
-
return True, int(jnp.where(rows_of_A_with_fsal
|
539
|
+
return True, int(jnp.where(rows_of_A_with_fsal)[0].tolist()[0])
|
540
|
+
|
541
|
+
|
542
|
+
class ExplicitRungeKuttaSO3Mixin:
|
543
|
+
"""
|
544
|
+
Mixin class to apply over explicit RK integrators defined on
|
545
|
+
`PyTreeType = ODEState` to integrate the quaternion on SO(3).
|
546
|
+
"""
|
547
|
+
|
548
|
+
@classmethod
|
549
|
+
def post_process_state(
|
550
|
+
cls, x0: js.ode_data.ODEState, t0: Time, xf: js.ode_data.ODEState, dt: TimeStep
|
551
|
+
) -> js.ode_data.ODEState:
|
552
|
+
r"""
|
553
|
+
Post-process the integrated state at :math:`t_f = t_0 + \Delta t` so that the
|
554
|
+
quaternion is normalized.
|
555
|
+
|
556
|
+
Args:
|
557
|
+
x0: The initial state of the system.
|
558
|
+
t0: The initial time of the system.
|
559
|
+
xf: The final state of the system obtained through the integration.
|
560
|
+
dt: The time step used for the integration.
|
561
|
+
"""
|
562
|
+
|
563
|
+
# Extract the initial base quaternion.
|
564
|
+
W_Q_B_t0 = x0.physics_model.base_quaternion
|
565
|
+
|
566
|
+
# We assume that the initial quaternion is already normalized (unit norm).
|
567
|
+
exceptions.raise_runtime_error_if(
|
568
|
+
condition=~jnp.allclose(W_Q_B_t0.dot(W_Q_B_t0), 1.0),
|
569
|
+
msg="The SO(3) integrator received a quaternion at t0 that is not unary.",
|
570
|
+
)
|
571
|
+
|
572
|
+
# Get the angular velocity ω to integrate the quaternion.
|
573
|
+
# This velocity ω[t0] is computed in the previous timestep by averaging the kᵢ
|
574
|
+
# corresponding to the active RK-based scheme. Therefore, by using the ω[t0],
|
575
|
+
# we obtain an explicit RK scheme operating on the SO(3) manifold.
|
576
|
+
# Note that the current integrator is not a semi-implicit scheme, therefore
|
577
|
+
# using the final ω[tf] would not be correct.
|
578
|
+
W_ω_WB_t0 = x0.physics_model.base_angular_velocity
|
579
|
+
|
580
|
+
# Integrate the quaternion on SO(3).
|
581
|
+
W_Q_B_tf = jaxsim.math.Quaternion.integration(
|
582
|
+
quaternion=W_Q_B_t0,
|
583
|
+
dt=dt,
|
584
|
+
omega=W_ω_WB_t0,
|
585
|
+
omega_in_body_fixed=False,
|
586
|
+
)
|
587
|
+
|
588
|
+
# Replace the quaternion in the final state.
|
589
|
+
return xf.replace(
|
590
|
+
physics_model=xf.physics_model.replace(base_quaternion=W_Q_B_tf),
|
591
|
+
validate=True,
|
592
|
+
)
|