jaxsim 0.2.dev191__py3-none-any.whl → 0.2.dev364__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jaxsim/__init__.py +3 -4
- jaxsim/_version.py +2 -2
- jaxsim/api/__init__.py +3 -1
- jaxsim/api/com.py +240 -0
- jaxsim/api/common.py +13 -2
- jaxsim/api/contact.py +120 -43
- jaxsim/api/data.py +112 -71
- jaxsim/api/joint.py +77 -36
- jaxsim/api/kin_dyn_parameters.py +777 -0
- jaxsim/api/link.py +150 -75
- jaxsim/api/model.py +542 -269
- jaxsim/api/ode.py +86 -74
- jaxsim/api/ode_data.py +694 -0
- jaxsim/api/references.py +12 -11
- jaxsim/integrators/__init__.py +2 -2
- jaxsim/integrators/common.py +110 -24
- jaxsim/integrators/fixed_step.py +11 -67
- jaxsim/integrators/variable_step.py +610 -0
- jaxsim/math/__init__.py +11 -0
- jaxsim/math/adjoint.py +24 -2
- jaxsim/math/joint_model.py +335 -0
- jaxsim/math/quaternion.py +44 -3
- jaxsim/math/rotation.py +4 -4
- jaxsim/math/transform.py +93 -0
- jaxsim/parsers/descriptions/link.py +2 -2
- jaxsim/parsers/rod/utils.py +7 -8
- jaxsim/rbda/__init__.py +7 -0
- jaxsim/rbda/aba.py +295 -0
- jaxsim/rbda/collidable_points.py +142 -0
- jaxsim/{physics/algos → rbda}/crba.py +43 -42
- jaxsim/rbda/forward_kinematics.py +113 -0
- jaxsim/rbda/jacobian.py +201 -0
- jaxsim/rbda/rnea.py +237 -0
- jaxsim/rbda/soft_contacts.py +296 -0
- jaxsim/rbda/utils.py +152 -0
- jaxsim/terrain/__init__.py +2 -0
- jaxsim/utils/__init__.py +1 -4
- jaxsim/utils/hashless.py +18 -0
- jaxsim/utils/jaxsim_dataclass.py +281 -30
- {jaxsim-0.2.dev191.dist-info → jaxsim-0.2.dev364.dist-info}/METADATA +4 -6
- jaxsim-0.2.dev364.dist-info/RECORD +64 -0
- {jaxsim-0.2.dev191.dist-info → jaxsim-0.2.dev364.dist-info}/WHEEL +1 -1
- jaxsim/high_level/__init__.py +0 -2
- jaxsim/high_level/common.py +0 -11
- jaxsim/high_level/joint.py +0 -148
- jaxsim/high_level/link.py +0 -259
- jaxsim/high_level/model.py +0 -1686
- jaxsim/math/conv.py +0 -114
- jaxsim/math/joint.py +0 -102
- jaxsim/math/plucker.py +0 -100
- jaxsim/physics/__init__.py +0 -12
- jaxsim/physics/algos/__init__.py +0 -0
- jaxsim/physics/algos/aba.py +0 -254
- jaxsim/physics/algos/aba_motors.py +0 -284
- jaxsim/physics/algos/forward_kinematics.py +0 -79
- jaxsim/physics/algos/jacobian.py +0 -98
- jaxsim/physics/algos/rnea.py +0 -180
- jaxsim/physics/algos/rnea_motors.py +0 -196
- jaxsim/physics/algos/soft_contacts.py +0 -523
- jaxsim/physics/algos/utils.py +0 -69
- jaxsim/physics/model/__init__.py +0 -0
- jaxsim/physics/model/ground_contact.py +0 -53
- jaxsim/physics/model/physics_model.py +0 -388
- jaxsim/physics/model/physics_model_state.py +0 -283
- jaxsim/simulation/__init__.py +0 -4
- jaxsim/simulation/integrators.py +0 -393
- jaxsim/simulation/ode.py +0 -290
- jaxsim/simulation/ode_data.py +0 -96
- jaxsim/simulation/ode_integration.py +0 -62
- jaxsim/simulation/simulator.py +0 -543
- jaxsim/simulation/simulator_callbacks.py +0 -79
- jaxsim/simulation/utils.py +0 -15
- jaxsim/sixd/__init__.py +0 -2
- jaxsim/utils/oop.py +0 -536
- jaxsim/utils/vmappable.py +0 -117
- jaxsim-0.2.dev191.dist-info/RECORD +0 -81
- /jaxsim/{physics/algos → terrain}/terrain.py +0 -0
- {jaxsim-0.2.dev191.dist-info → jaxsim-0.2.dev364.dist-info}/LICENSE +0 -0
- {jaxsim-0.2.dev191.dist-info → jaxsim-0.2.dev364.dist-info}/top_level.txt +0 -0
jaxsim/api/references.py
CHANGED
@@ -7,10 +7,11 @@ import jax.numpy as jnp
|
|
7
7
|
import jax_dataclasses
|
8
8
|
|
9
9
|
import jaxsim.api as js
|
10
|
-
import jaxsim.physics.model.physics_model_state
|
11
10
|
import jaxsim.typing as jtp
|
12
|
-
from jaxsim import
|
13
|
-
|
11
|
+
from jaxsim.utils.tracing import not_tracing
|
12
|
+
|
13
|
+
from .common import VelRepr
|
14
|
+
from .ode_data import ODEInput
|
14
15
|
|
15
16
|
try:
|
16
17
|
from typing import Self
|
@@ -95,7 +96,7 @@ class JaxSimModelReferences(js.common.ModelDataWithVelocityRepresentation):
|
|
95
96
|
|
96
97
|
# Create a zero references object.
|
97
98
|
references = JaxSimModelReferences(
|
98
|
-
input=ODEInput.zero(
|
99
|
+
input=ODEInput.zero(model=model),
|
99
100
|
velocity_representation=velocity_representation,
|
100
101
|
)
|
101
102
|
|
@@ -132,7 +133,7 @@ class JaxSimModelReferences(js.common.ModelDataWithVelocityRepresentation):
|
|
132
133
|
valid = True
|
133
134
|
|
134
135
|
if model is not None:
|
135
|
-
valid = valid and self.input.valid(
|
136
|
+
valid = valid and self.input.valid(model=model)
|
136
137
|
|
137
138
|
return valid
|
138
139
|
|
@@ -188,7 +189,7 @@ class JaxSimModelReferences(js.common.ModelDataWithVelocityRepresentation):
|
|
188
189
|
|
189
190
|
# If we have the model, we can extract the link names, if not provided.
|
190
191
|
link_names = link_names if link_names is not None else model.link_names()
|
191
|
-
link_idxs =
|
192
|
+
link_idxs = js.link.names_to_idxs(link_names=link_names, model=model)
|
192
193
|
|
193
194
|
# In inertial-fixed representation, we already have the link forces.
|
194
195
|
if self.velocity_representation is VelRepr.Inertial:
|
@@ -198,7 +199,7 @@ class JaxSimModelReferences(js.common.ModelDataWithVelocityRepresentation):
|
|
198
199
|
msg = "Missing model data to use a representation different from {}"
|
199
200
|
raise ValueError(msg.format(VelRepr.Inertial.name))
|
200
201
|
|
201
|
-
if not data.valid(model=model):
|
202
|
+
if not_tracing(self.input.physics_model.f_ext) and not data.valid(model=model):
|
202
203
|
raise ValueError("The provided data is not valid for the model")
|
203
204
|
|
204
205
|
# Helper function to convert a single 6D force to the active representation.
|
@@ -252,7 +253,7 @@ class JaxSimModelReferences(js.common.ModelDataWithVelocityRepresentation):
|
|
252
253
|
|
253
254
|
return self.input.physics_model.tau
|
254
255
|
|
255
|
-
if not self.valid(model=model):
|
256
|
+
if not_tracing(self.input.physics_model.tau) and not self.valid(model=model):
|
256
257
|
msg = "The actuation object is not compatible with the provided model"
|
257
258
|
raise ValueError(msg)
|
258
259
|
|
@@ -303,7 +304,7 @@ class JaxSimModelReferences(js.common.ModelDataWithVelocityRepresentation):
|
|
303
304
|
if model is None:
|
304
305
|
return replace(forces=forces)
|
305
306
|
|
306
|
-
if not self.valid(model=model):
|
307
|
+
if not_tracing(forces) and not self.valid(model=model):
|
307
308
|
msg = "The references object is not compatible with the provided model"
|
308
309
|
raise ValueError(msg)
|
309
310
|
|
@@ -379,7 +380,7 @@ class JaxSimModelReferences(js.common.ModelDataWithVelocityRepresentation):
|
|
379
380
|
|
380
381
|
# If we have the model, we can extract the link names if not provided.
|
381
382
|
link_names = link_names if link_names is not None else model.link_names()
|
382
|
-
link_idxs =
|
383
|
+
link_idxs = js.link.names_to_idxs(link_names=link_names, model=model)
|
383
384
|
|
384
385
|
# Compute the bias depending on whether we either set or add the link forces.
|
385
386
|
W_f0_L = (
|
@@ -401,7 +402,7 @@ class JaxSimModelReferences(js.common.ModelDataWithVelocityRepresentation):
|
|
401
402
|
msg = "Missing model data to use a representation different from {}"
|
402
403
|
raise ValueError(msg.format(VelRepr.Inertial.name))
|
403
404
|
|
404
|
-
if not data.valid(model=model):
|
405
|
+
if not_tracing(forces) and not data.valid(model=model):
|
405
406
|
raise ValueError("The provided data is not valid for the model")
|
406
407
|
|
407
408
|
# Helper function to convert a single 6D force to the inertial representation.
|
jaxsim/integrators/__init__.py
CHANGED
@@ -1,2 +1,2 @@
|
|
1
|
-
from . import fixed_step
|
2
|
-
from .common import Integrator, Time, TimeStep
|
1
|
+
from . import fixed_step, variable_step
|
2
|
+
from .common import Integrator, SystemDynamics, Time, TimeStep
|
jaxsim/integrators/common.py
CHANGED
@@ -5,9 +5,12 @@ from typing import Any, ClassVar, Generic, Protocol, Type, TypeVar
|
|
5
5
|
import jax
|
6
6
|
import jax.numpy as jnp
|
7
7
|
import jax_dataclasses
|
8
|
+
import jaxlie
|
8
9
|
from jax_dataclasses import Static
|
9
10
|
|
11
|
+
import jaxsim.api as js
|
10
12
|
import jaxsim.typing as jtp
|
13
|
+
from jaxsim.math import Quaternion
|
11
14
|
from jaxsim.utils.jaxsim_dataclass import JaxsimDataclass, Mutability
|
12
15
|
|
13
16
|
try:
|
@@ -46,6 +49,9 @@ class SystemDynamics(Protocol[State, StateDerivative]):
|
|
46
49
|
@jax_dataclasses.pytree_dataclass
|
47
50
|
class Integrator(JaxsimDataclass, abc.ABC, Generic[State, StateDerivative]):
|
48
51
|
|
52
|
+
AfterInitKey: ClassVar[str] = "after_init"
|
53
|
+
InitializingKey: ClassVar[str] = "initializing"
|
54
|
+
|
49
55
|
AuxDictDynamicsKey: ClassVar[str] = "aux_dict_dynamics"
|
50
56
|
|
51
57
|
dynamics: Static[SystemDynamics[State, StateDerivative]] = dataclasses.field(
|
@@ -58,7 +64,10 @@ class Integrator(JaxsimDataclass, abc.ABC, Generic[State, StateDerivative]):
|
|
58
64
|
|
59
65
|
@classmethod
|
60
66
|
def build(
|
61
|
-
cls: Type[Self],
|
67
|
+
cls: Type[Self],
|
68
|
+
*,
|
69
|
+
dynamics: SystemDynamics[State, StateDerivative],
|
70
|
+
**kwargs,
|
62
71
|
) -> Self:
|
63
72
|
"""
|
64
73
|
Build the integrator object.
|
@@ -102,9 +111,9 @@ class Integrator(JaxsimDataclass, abc.ABC, Generic[State, StateDerivative]):
|
|
102
111
|
with integrator.mutable_context(mutability=Mutability.MUTABLE):
|
103
112
|
xf = integrator(x0, t0, dt, **kwargs)
|
104
113
|
|
105
|
-
|
106
|
-
|
107
|
-
|
114
|
+
return xf, integrator.params | {
|
115
|
+
Integrator.AfterInitKey: jnp.array(False).astype(bool)
|
116
|
+
}
|
108
117
|
|
109
118
|
@abc.abstractmethod
|
110
119
|
def __call__(self, x0: State, t0: Time, dt: TimeStep, **kwargs) -> NextState:
|
@@ -116,7 +125,7 @@ class Integrator(JaxsimDataclass, abc.ABC, Generic[State, StateDerivative]):
|
|
116
125
|
t0: Time,
|
117
126
|
dt: TimeStep,
|
118
127
|
*,
|
119
|
-
|
128
|
+
include_dynamics_aux_dict: bool = False,
|
120
129
|
**kwargs,
|
121
130
|
) -> dict[str, Any]:
|
122
131
|
"""
|
@@ -126,7 +135,6 @@ class Integrator(JaxsimDataclass, abc.ABC, Generic[State, StateDerivative]):
|
|
126
135
|
x0: The initial state of the system.
|
127
136
|
t0: The initial time of the system.
|
128
137
|
dt: The time step of the integration.
|
129
|
-
key: An optional random key to initialize the integrator.
|
130
138
|
|
131
139
|
Returns:
|
132
140
|
The auxiliary dictionary of the integrator.
|
@@ -141,17 +149,43 @@ class Integrator(JaxsimDataclass, abc.ABC, Generic[State, StateDerivative]):
|
|
141
149
|
the first step will be wrong.
|
142
150
|
"""
|
143
151
|
|
144
|
-
_, aux_dict_dynamics = self.dynamics(x0, t0)
|
145
|
-
|
146
152
|
with self.editable(validate=False) as integrator:
|
153
|
+
|
154
|
+
# Initialize the integrator parameters.
|
155
|
+
# For initialization purpose, the integrators can check if the
|
156
|
+
# `Integrator.InitializingKey` is present in their parameters.
|
157
|
+
# The AfterInitKey is used in the first step after initialization.
|
158
|
+
integrator.params = {
|
159
|
+
Integrator.InitializingKey: jnp.array(True),
|
160
|
+
Integrator.AfterInitKey: jnp.array(False),
|
161
|
+
}
|
162
|
+
|
163
|
+
# Run a dummy call of the integrator.
|
164
|
+
# It is used only to get the params so that we know the structure
|
165
|
+
# of the corresponding pytree.
|
147
166
|
_ = integrator(x0, t0, dt, **kwargs)
|
148
|
-
aux_dict_step = integrator.params
|
149
167
|
|
150
|
-
|
151
|
-
|
152
|
-
raise KeyError(msg.format(Integrator.AuxDictDynamicsKey))
|
168
|
+
# Remove the injected key.
|
169
|
+
_ = integrator.params.pop(Integrator.InitializingKey)
|
153
170
|
|
154
|
-
|
171
|
+
# Make sure that all leafs of the dictionary are JAX arrays.
|
172
|
+
# Also, since these are dummy parameters, set them all to zero.
|
173
|
+
params_after_init = jax.tree_util.tree_map(
|
174
|
+
lambda l: jnp.zeros_like(l), integrator.params
|
175
|
+
)
|
176
|
+
|
177
|
+
# Mark the next step as first step after initialization.
|
178
|
+
params_after_init = params_after_init | {
|
179
|
+
Integrator.AfterInitKey: jnp.array(True)
|
180
|
+
}
|
181
|
+
|
182
|
+
# Store the zero parameters in the integrator.
|
183
|
+
# When the integrator is stepped, this is used to check if the passed
|
184
|
+
# parameters are valid.
|
185
|
+
with self.mutable_context(mutability=Mutability.MUTABLE_NO_VALIDATION):
|
186
|
+
self.params = params_after_init
|
187
|
+
|
188
|
+
return params_after_init
|
155
189
|
|
156
190
|
|
157
191
|
@jax_dataclasses.pytree_dataclass
|
@@ -209,20 +243,10 @@ class ExplicitRungeKutta(Integrator[PyTreeType, PyTreeType], Generic[PyTreeType]
|
|
209
243
|
The integrator object.
|
210
244
|
"""
|
211
245
|
|
212
|
-
# Adjust the shape of the tableau coefficients.
|
213
|
-
c = jnp.atleast_1d(cls.c.squeeze())
|
214
|
-
b = jnp.atleast_2d(jnp.vstack(cls.b.squeeze()))
|
215
|
-
A = jnp.atleast_2d(cls.A.squeeze())
|
216
|
-
|
217
246
|
# Check validity of the Butcher tableau.
|
218
|
-
if not ExplicitRungeKutta.butcher_tableau_is_valid(A=A, b=b, c=c):
|
247
|
+
if not ExplicitRungeKutta.butcher_tableau_is_valid(A=cls.A, b=cls.b, c=cls.c):
|
219
248
|
raise ValueError("The Butcher tableau of this class is not valid.")
|
220
249
|
|
221
|
-
# Store the adjusted shapes of the tableau coefficients.
|
222
|
-
cls.c = c
|
223
|
-
cls.b = b
|
224
|
-
cls.A = A
|
225
|
-
|
226
250
|
# Check that b.T has enough rows based on the configured index of the solution.
|
227
251
|
if cls.row_index_of_solution >= cls.b.T.shape[0]:
|
228
252
|
msg = "The index of the solution ({}-th row of `b.T`) is out of range ({})."
|
@@ -506,3 +530,65 @@ class ExplicitRungeKutta(Integrator[PyTreeType, PyTreeType], Generic[PyTreeType]
|
|
506
530
|
# possibly intermediate kᵢ derivative).
|
507
531
|
# Note that if multiple rows match (it should not), we return the first match.
|
508
532
|
return True, int(jnp.where(rows_of_A_with_fsal == True)[0].tolist()[0])
|
533
|
+
|
534
|
+
|
535
|
+
class ExplicitRungeKuttaSO3Mixin:
|
536
|
+
"""
|
537
|
+
Mixin class to apply over explicit RK integrators defined on
|
538
|
+
`PyTreeType = ODEState` to integrate the quaternion on SO(3).
|
539
|
+
"""
|
540
|
+
|
541
|
+
@classmethod
|
542
|
+
def integrate_rk_stage(
|
543
|
+
cls, x0: js.ode_data.ODEState, t0: Time, dt: TimeStep, k: js.ode_data.ODEState
|
544
|
+
) -> js.ode_data.ODEState:
|
545
|
+
|
546
|
+
op = lambda x0_leaf, k_leaf: x0_leaf + dt * k_leaf
|
547
|
+
xf: js.ode_data.ODEState = jax.tree_util.tree_map(op, x0, k)
|
548
|
+
|
549
|
+
W_Q_B_t0 = x0.physics_model.base_quaternion
|
550
|
+
W_ω_WB_t0 = x0.physics_model.base_angular_velocity
|
551
|
+
|
552
|
+
return xf.replace(
|
553
|
+
physics_model=xf.physics_model.replace(
|
554
|
+
base_quaternion=Quaternion.integration(
|
555
|
+
quaternion=W_Q_B_t0,
|
556
|
+
dt=dt,
|
557
|
+
omega=W_ω_WB_t0,
|
558
|
+
omega_in_body_fixed=False,
|
559
|
+
),
|
560
|
+
)
|
561
|
+
)
|
562
|
+
|
563
|
+
@classmethod
|
564
|
+
def post_process_state(
|
565
|
+
cls, x0: js.ode_data.ODEState, t0: Time, xf: js.ode_data.ODEState, dt: TimeStep
|
566
|
+
) -> js.ode_data.ODEState:
|
567
|
+
|
568
|
+
# Indices to convert quaternions between serializations.
|
569
|
+
to_xyzw = jnp.array([1, 2, 3, 0])
|
570
|
+
to_wxyz = jnp.array([3, 0, 1, 2])
|
571
|
+
|
572
|
+
# Get the initial quaternion.
|
573
|
+
W_Q_B_t0 = jaxlie.SO3.from_quaternion_xyzw(
|
574
|
+
xyzw=x0.physics_model.base_quaternion[to_xyzw]
|
575
|
+
)
|
576
|
+
|
577
|
+
# Get the final angular velocity.
|
578
|
+
# This is already computed by averaging the kᵢ in RK-based schemes.
|
579
|
+
# Therefore, by using the ω at tf, we obtain a RK scheme operating
|
580
|
+
# on the SO(3) manifold.
|
581
|
+
W_ω_WB_tf = xf.physics_model.base_angular_velocity
|
582
|
+
|
583
|
+
# Integrate the quaternion on SO(3).
|
584
|
+
# Note that we left-multiply with the exponential map since the angular
|
585
|
+
# velocity is expressed in the inertial frame.
|
586
|
+
W_Q_B_tf = jaxlie.SO3.exp(tangent=dt * W_ω_WB_tf) @ W_Q_B_t0
|
587
|
+
|
588
|
+
# Replace the quaternion in the final state.
|
589
|
+
return xf.replace(
|
590
|
+
physics_model=xf.physics_model.replace(
|
591
|
+
base_quaternion=W_Q_B_tf.as_quaternion_xyzw()[to_wxyz]
|
592
|
+
),
|
593
|
+
validate=True,
|
594
|
+
)
|
jaxsim/integrators/fixed_step.py
CHANGED
@@ -3,14 +3,12 @@ from typing import ClassVar, Generic
|
|
3
3
|
import jax
|
4
4
|
import jax.numpy as jnp
|
5
5
|
import jax_dataclasses
|
6
|
-
import jaxlie
|
7
6
|
|
8
|
-
|
7
|
+
import jaxsim.api as js
|
9
8
|
|
10
|
-
from .common import ExplicitRungeKutta,
|
11
|
-
|
12
|
-
ODEStateDerivative = ODEState
|
9
|
+
from .common import ExplicitRungeKutta, ExplicitRungeKuttaSO3Mixin, PyTreeType
|
13
10
|
|
11
|
+
ODEStateDerivative = js.ode_data.ODEState
|
14
12
|
|
15
13
|
# =====================================================
|
16
14
|
# Explicit Runge-Kutta integrators operating on PyTrees
|
@@ -20,37 +18,23 @@ ODEStateDerivative = ODEState
|
|
20
18
|
@jax_dataclasses.pytree_dataclass
|
21
19
|
class ForwardEuler(ExplicitRungeKutta[PyTreeType], Generic[PyTreeType]):
|
22
20
|
|
23
|
-
A: ClassVar[jax.typing.ArrayLike] = jnp.
|
24
|
-
[
|
25
|
-
[0],
|
26
|
-
]
|
27
|
-
).astype(float)
|
21
|
+
A: ClassVar[jax.typing.ArrayLike] = jnp.atleast_2d(0).astype(float)
|
28
22
|
|
29
|
-
b: ClassVar[jax.typing.ArrayLike] = (
|
30
|
-
jnp.array(
|
31
|
-
[
|
32
|
-
[1],
|
33
|
-
]
|
34
|
-
)
|
35
|
-
.astype(float)
|
36
|
-
.transpose()
|
37
|
-
)
|
23
|
+
b: ClassVar[jax.typing.ArrayLike] = jnp.atleast_2d(1).astype(float).transpose()
|
38
24
|
|
39
|
-
c: ClassVar[jax.typing.ArrayLike] = jnp.
|
40
|
-
[0],
|
41
|
-
).astype(float)
|
25
|
+
c: ClassVar[jax.typing.ArrayLike] = jnp.atleast_1d(0).astype(float)
|
42
26
|
|
43
27
|
row_index_of_solution: ClassVar[int] = 0
|
44
28
|
order_of_bT_rows: ClassVar[tuple[int, ...]] = (1,)
|
45
29
|
|
46
30
|
|
47
31
|
@jax_dataclasses.pytree_dataclass
|
48
|
-
class
|
32
|
+
class Heun2(ExplicitRungeKutta[PyTreeType], Generic[PyTreeType]):
|
49
33
|
|
50
34
|
A: ClassVar[jax.typing.ArrayLike] = jnp.array(
|
51
35
|
[
|
52
36
|
[0, 0],
|
53
|
-
[1
|
37
|
+
[1, 0],
|
54
38
|
]
|
55
39
|
).astype(float)
|
56
40
|
|
@@ -103,56 +87,16 @@ class RungeKutta4(ExplicitRungeKutta[PyTreeType], Generic[PyTreeType]):
|
|
103
87
|
# ===============================================================================
|
104
88
|
|
105
89
|
|
106
|
-
class ExplicitRungeKuttaSO3Mixin:
|
107
|
-
"""
|
108
|
-
Mixin class to apply over explicit RK integrators defined on
|
109
|
-
`PyTreeType = ODEState` to integrate the quaternion on SO(3).
|
110
|
-
"""
|
111
|
-
|
112
|
-
@classmethod
|
113
|
-
def post_process_state(
|
114
|
-
cls, x0: ODEState, t0: Time, xf: ODEState, dt: TimeStep
|
115
|
-
) -> ODEState:
|
116
|
-
|
117
|
-
# Indices to convert quaternions between serializations.
|
118
|
-
to_xyzw = jnp.array([1, 2, 3, 0])
|
119
|
-
to_wxyz = jnp.array([3, 0, 1, 2])
|
120
|
-
|
121
|
-
# Get the initial quaternion.
|
122
|
-
W_Q_B_t0 = jaxlie.SO3.from_quaternion_xyzw(
|
123
|
-
xyzw=x0.physics_model.base_quaternion[to_xyzw]
|
124
|
-
)
|
125
|
-
|
126
|
-
# Get the final angular velocity.
|
127
|
-
# This is already computed by averaging the kᵢ in RK-based schemes.
|
128
|
-
# Therefore, by using the ω at tf, we obtain a RK scheme operating
|
129
|
-
# on the SO(3) manifold.
|
130
|
-
W_ω_WB_tf = xf.physics_model.base_angular_velocity
|
131
|
-
|
132
|
-
# Integrate the quaternion on SO(3).
|
133
|
-
# Note that we left-multiply with the exponential map since the angular
|
134
|
-
# velocity is expressed in the inertial frame.
|
135
|
-
W_Q_B_tf = jaxlie.SO3.exp(tangent=dt * W_ω_WB_tf) @ W_Q_B_t0
|
136
|
-
|
137
|
-
# Replace the quaternion in the final state.
|
138
|
-
return xf.replace(
|
139
|
-
physics_model=xf.physics_model.replace(
|
140
|
-
base_quaternion=W_Q_B_tf.as_quaternion_xyzw()[to_wxyz]
|
141
|
-
),
|
142
|
-
validate=True,
|
143
|
-
)
|
144
|
-
|
145
|
-
|
146
90
|
@jax_dataclasses.pytree_dataclass
|
147
|
-
class ForwardEulerSO3(ExplicitRungeKuttaSO3Mixin,
|
91
|
+
class ForwardEulerSO3(ExplicitRungeKuttaSO3Mixin, ForwardEuler[js.ode_data.ODEState]):
|
148
92
|
pass
|
149
93
|
|
150
94
|
|
151
95
|
@jax_dataclasses.pytree_dataclass
|
152
|
-
class
|
96
|
+
class Heun2SO3(ExplicitRungeKuttaSO3Mixin, Heun2[js.ode_data.ODEState]):
|
153
97
|
pass
|
154
98
|
|
155
99
|
|
156
100
|
@jax_dataclasses.pytree_dataclass
|
157
|
-
class RungeKutta4SO3(ExplicitRungeKuttaSO3Mixin, RungeKutta4[ODEState]):
|
101
|
+
class RungeKutta4SO3(ExplicitRungeKuttaSO3Mixin, RungeKutta4[js.ode_data.ODEState]):
|
158
102
|
pass
|