jaxsim 0.2.dev191__py3-none-any.whl → 0.6.1.dev2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jaxsim/__init__.py +73 -22
- jaxsim/_version.py +2 -2
- jaxsim/api/__init__.py +13 -1
- jaxsim/api/com.py +423 -0
- jaxsim/api/common.py +48 -19
- jaxsim/api/contact.py +604 -52
- jaxsim/api/data.py +308 -163
- jaxsim/api/frame.py +471 -0
- jaxsim/api/joint.py +166 -37
- jaxsim/api/kin_dyn_parameters.py +901 -0
- jaxsim/api/link.py +277 -78
- jaxsim/api/model.py +1572 -362
- jaxsim/api/ode.py +324 -133
- jaxsim/api/ode_data.py +401 -0
- jaxsim/api/references.py +216 -80
- jaxsim/exceptions.py +80 -0
- jaxsim/integrators/__init__.py +2 -2
- jaxsim/integrators/common.py +191 -107
- jaxsim/integrators/fixed_step.py +97 -102
- jaxsim/integrators/variable_step.py +706 -0
- jaxsim/logging.py +1 -2
- jaxsim/math/__init__.py +13 -0
- jaxsim/math/adjoint.py +57 -22
- jaxsim/math/cross.py +16 -7
- jaxsim/math/inertia.py +10 -8
- jaxsim/math/joint_model.py +289 -0
- jaxsim/math/quaternion.py +54 -20
- jaxsim/math/rotation.py +27 -21
- jaxsim/math/skew.py +16 -5
- jaxsim/math/transform.py +102 -0
- jaxsim/math/utils.py +31 -0
- jaxsim/mujoco/__init__.py +2 -1
- jaxsim/mujoco/loaders.py +216 -29
- jaxsim/mujoco/model.py +163 -33
- jaxsim/mujoco/utils.py +228 -0
- jaxsim/mujoco/visualizer.py +107 -22
- jaxsim/parsers/__init__.py +0 -1
- jaxsim/parsers/descriptions/__init__.py +8 -2
- jaxsim/parsers/descriptions/collision.py +83 -26
- jaxsim/parsers/descriptions/joint.py +80 -87
- jaxsim/parsers/descriptions/link.py +58 -31
- jaxsim/parsers/descriptions/model.py +101 -68
- jaxsim/parsers/kinematic_graph.py +606 -229
- jaxsim/parsers/rod/meshes.py +104 -0
- jaxsim/parsers/rod/parser.py +125 -82
- jaxsim/parsers/rod/utils.py +127 -82
- jaxsim/rbda/__init__.py +11 -0
- jaxsim/rbda/aba.py +289 -0
- jaxsim/rbda/collidable_points.py +156 -0
- jaxsim/rbda/contacts/__init__.py +13 -0
- jaxsim/rbda/contacts/common.py +313 -0
- jaxsim/rbda/contacts/relaxed_rigid.py +605 -0
- jaxsim/rbda/contacts/rigid.py +462 -0
- jaxsim/rbda/contacts/soft.py +480 -0
- jaxsim/rbda/contacts/visco_elastic.py +1066 -0
- jaxsim/rbda/crba.py +167 -0
- jaxsim/rbda/forward_kinematics.py +117 -0
- jaxsim/rbda/jacobian.py +330 -0
- jaxsim/rbda/rnea.py +235 -0
- jaxsim/rbda/utils.py +160 -0
- jaxsim/terrain/__init__.py +2 -0
- jaxsim/terrain/terrain.py +238 -0
- jaxsim/typing.py +24 -24
- jaxsim/utils/__init__.py +1 -4
- jaxsim/utils/jaxsim_dataclass.py +289 -34
- jaxsim/utils/tracing.py +5 -11
- jaxsim/utils/wrappers.py +159 -0
- {jaxsim-0.2.dev191.dist-info → jaxsim-0.6.1.dev2.dist-info}/LICENSE +1 -1
- jaxsim-0.6.1.dev2.dist-info/METADATA +465 -0
- jaxsim-0.6.1.dev2.dist-info/RECORD +74 -0
- {jaxsim-0.2.dev191.dist-info → jaxsim-0.6.1.dev2.dist-info}/WHEEL +1 -1
- jaxsim/high_level/__init__.py +0 -2
- jaxsim/high_level/common.py +0 -11
- jaxsim/high_level/joint.py +0 -148
- jaxsim/high_level/link.py +0 -259
- jaxsim/high_level/model.py +0 -1686
- jaxsim/math/conv.py +0 -114
- jaxsim/math/joint.py +0 -102
- jaxsim/math/plucker.py +0 -100
- jaxsim/physics/__init__.py +0 -12
- jaxsim/physics/algos/__init__.py +0 -0
- jaxsim/physics/algos/aba.py +0 -254
- jaxsim/physics/algos/aba_motors.py +0 -284
- jaxsim/physics/algos/crba.py +0 -154
- jaxsim/physics/algos/forward_kinematics.py +0 -79
- jaxsim/physics/algos/jacobian.py +0 -98
- jaxsim/physics/algos/rnea.py +0 -180
- jaxsim/physics/algos/rnea_motors.py +0 -196
- jaxsim/physics/algos/soft_contacts.py +0 -523
- jaxsim/physics/algos/terrain.py +0 -78
- jaxsim/physics/algos/utils.py +0 -69
- jaxsim/physics/model/__init__.py +0 -0
- jaxsim/physics/model/ground_contact.py +0 -53
- jaxsim/physics/model/physics_model.py +0 -388
- jaxsim/physics/model/physics_model_state.py +0 -283
- jaxsim/simulation/__init__.py +0 -4
- jaxsim/simulation/integrators.py +0 -393
- jaxsim/simulation/ode.py +0 -290
- jaxsim/simulation/ode_data.py +0 -96
- jaxsim/simulation/ode_integration.py +0 -62
- jaxsim/simulation/simulator.py +0 -543
- jaxsim/simulation/simulator_callbacks.py +0 -79
- jaxsim/simulation/utils.py +0 -15
- jaxsim/sixd/__init__.py +0 -2
- jaxsim/utils/oop.py +0 -536
- jaxsim/utils/vmappable.py +0 -117
- jaxsim-0.2.dev191.dist-info/METADATA +0 -184
- jaxsim-0.2.dev191.dist-info/RECORD +0 -81
- {jaxsim-0.2.dev191.dist-info → jaxsim-0.6.1.dev2.dist-info}/top_level.txt +0 -0
@@ -1,283 +0,0 @@
|
|
1
|
-
from typing import Union
|
2
|
-
|
3
|
-
import jax.numpy as jnp
|
4
|
-
import jax_dataclasses
|
5
|
-
|
6
|
-
import jaxsim.physics.model.physics_model
|
7
|
-
import jaxsim.typing as jtp
|
8
|
-
from jaxsim.utils import JaxsimDataclass
|
9
|
-
|
10
|
-
|
11
|
-
@jax_dataclasses.pytree_dataclass
|
12
|
-
class PhysicsModelState(JaxsimDataclass):
|
13
|
-
"""
|
14
|
-
A class representing the state of a physics model.
|
15
|
-
|
16
|
-
This class stores the joint positions, joint velocities, and the base state (position, orientation, linear velocity,
|
17
|
-
and angular velocity) of a physics model.
|
18
|
-
|
19
|
-
Attributes:
|
20
|
-
joint_positions (jtp.Vector): An array representing the joint positions.
|
21
|
-
joint_velocities (jtp.Vector): An array representing the joint velocities.
|
22
|
-
base_position (jtp.Vector): An array representing the base position (default: zeros).
|
23
|
-
base_quaternion (jtp.Vector): An array representing the base quaternion (default: [1.0, 0, 0, 0]).
|
24
|
-
base_linear_velocity (jtp.Vector): An array representing the base linear velocity (default: zeros).
|
25
|
-
base_angular_velocity (jtp.Vector): An array representing the base angular velocity (default: zeros).
|
26
|
-
"""
|
27
|
-
|
28
|
-
# Joint state
|
29
|
-
joint_positions: jtp.Vector
|
30
|
-
joint_velocities: jtp.Vector
|
31
|
-
|
32
|
-
# Base state
|
33
|
-
base_position: jtp.Vector = jax_dataclasses.field(
|
34
|
-
default_factory=lambda: jnp.zeros(3)
|
35
|
-
)
|
36
|
-
base_quaternion: jtp.Vector = jax_dataclasses.field(
|
37
|
-
default_factory=lambda: jnp.array([1.0, 0, 0, 0])
|
38
|
-
)
|
39
|
-
base_linear_velocity: jtp.Vector = jax_dataclasses.field(
|
40
|
-
default_factory=lambda: jnp.zeros(3)
|
41
|
-
)
|
42
|
-
base_angular_velocity: jtp.Vector = jax_dataclasses.field(
|
43
|
-
default_factory=lambda: jnp.zeros(3)
|
44
|
-
)
|
45
|
-
|
46
|
-
@staticmethod
|
47
|
-
def build(
|
48
|
-
joint_positions: jtp.Vector | None = None,
|
49
|
-
joint_velocities: jtp.Vector | None = None,
|
50
|
-
base_position: jtp.Vector | None = None,
|
51
|
-
base_quaternion: jtp.Vector | None = None,
|
52
|
-
base_linear_velocity: jtp.Vector | None = None,
|
53
|
-
base_angular_velocity: jtp.Vector | None = None,
|
54
|
-
number_of_dofs: jtp.Int | None = None,
|
55
|
-
) -> "PhysicsModelState":
|
56
|
-
""""""
|
57
|
-
|
58
|
-
joint_positions = (
|
59
|
-
joint_positions
|
60
|
-
if joint_positions is not None
|
61
|
-
else jnp.zeros(number_of_dofs)
|
62
|
-
)
|
63
|
-
|
64
|
-
joint_velocities = (
|
65
|
-
joint_velocities
|
66
|
-
if joint_velocities is not None
|
67
|
-
else jnp.zeros(number_of_dofs)
|
68
|
-
)
|
69
|
-
|
70
|
-
base_position = base_position if base_position is not None else jnp.zeros(3)
|
71
|
-
|
72
|
-
base_quaternion = (
|
73
|
-
base_quaternion
|
74
|
-
if base_quaternion is not None
|
75
|
-
else jnp.array([1.0, 0, 0, 0])
|
76
|
-
)
|
77
|
-
|
78
|
-
base_linear_velocity = (
|
79
|
-
base_linear_velocity if base_linear_velocity is not None else jnp.zeros(3)
|
80
|
-
)
|
81
|
-
|
82
|
-
base_angular_velocity = (
|
83
|
-
base_angular_velocity if base_angular_velocity is not None else jnp.zeros(3)
|
84
|
-
)
|
85
|
-
|
86
|
-
physics_model_state = PhysicsModelState(
|
87
|
-
joint_positions=jnp.array(joint_positions, dtype=float),
|
88
|
-
joint_velocities=jnp.array(joint_velocities, dtype=float),
|
89
|
-
base_position=jnp.array(base_position, dtype=float),
|
90
|
-
base_quaternion=jnp.array(base_quaternion, dtype=float),
|
91
|
-
base_linear_velocity=jnp.array(base_linear_velocity, dtype=float),
|
92
|
-
base_angular_velocity=jnp.array(base_angular_velocity, dtype=float),
|
93
|
-
)
|
94
|
-
|
95
|
-
return physics_model_state
|
96
|
-
|
97
|
-
@staticmethod
|
98
|
-
def build_from_physics_model(
|
99
|
-
joint_positions: jtp.Vector | None = None,
|
100
|
-
joint_velocities: jtp.Vector | None = None,
|
101
|
-
base_position: jtp.Vector | None = None,
|
102
|
-
base_quaternion: jtp.Vector | None = None,
|
103
|
-
base_linear_velocity: jtp.Vector | None = None,
|
104
|
-
base_angular_velocity: jtp.Vector | None = None,
|
105
|
-
physics_model: Union[
|
106
|
-
"jaxsim.physics.model.physics_model.PhysicsModel", None
|
107
|
-
] = None,
|
108
|
-
) -> "PhysicsModelState":
|
109
|
-
""""""
|
110
|
-
|
111
|
-
return PhysicsModelState.build(
|
112
|
-
joint_positions=joint_positions,
|
113
|
-
joint_velocities=joint_velocities,
|
114
|
-
base_position=base_position,
|
115
|
-
base_quaternion=base_quaternion,
|
116
|
-
base_linear_velocity=base_linear_velocity,
|
117
|
-
base_angular_velocity=base_angular_velocity,
|
118
|
-
number_of_dofs=physics_model.dofs(),
|
119
|
-
)
|
120
|
-
|
121
|
-
@staticmethod
|
122
|
-
def zero(
|
123
|
-
physics_model: "jaxsim.physics.model.physics_model.PhysicsModel",
|
124
|
-
) -> "PhysicsModelState":
|
125
|
-
return PhysicsModelState.build_from_physics_model(physics_model=physics_model)
|
126
|
-
|
127
|
-
def position(self) -> jtp.Vector:
|
128
|
-
return jnp.hstack(
|
129
|
-
[self.base_position, self.base_quaternion, self.joint_positions]
|
130
|
-
)
|
131
|
-
|
132
|
-
def velocity(self) -> jtp.Vector:
|
133
|
-
# W_v_WB: inertial-fixed representation of the base velocity
|
134
|
-
return jnp.hstack(
|
135
|
-
[
|
136
|
-
self.base_linear_velocity,
|
137
|
-
self.base_angular_velocity,
|
138
|
-
self.joint_velocities,
|
139
|
-
]
|
140
|
-
)
|
141
|
-
|
142
|
-
def xfb(self) -> jtp.Vector:
|
143
|
-
return jnp.hstack(
|
144
|
-
[
|
145
|
-
self.base_quaternion,
|
146
|
-
self.base_position,
|
147
|
-
self.base_angular_velocity,
|
148
|
-
self.base_linear_velocity,
|
149
|
-
]
|
150
|
-
)
|
151
|
-
|
152
|
-
def valid(
|
153
|
-
self, physics_model: "jaxsim.physics.model.physics_model.PhysicsModel"
|
154
|
-
) -> bool:
|
155
|
-
from jaxsim.simulation.utils import check_valid_shape
|
156
|
-
|
157
|
-
valid = True
|
158
|
-
|
159
|
-
valid = check_valid_shape(
|
160
|
-
what="joint_positions",
|
161
|
-
shape=self.joint_positions.shape,
|
162
|
-
expected_shape=(physics_model.dofs(),),
|
163
|
-
valid=valid,
|
164
|
-
)
|
165
|
-
|
166
|
-
valid = check_valid_shape(
|
167
|
-
what="joint_velocities",
|
168
|
-
shape=self.joint_velocities.shape,
|
169
|
-
expected_shape=(physics_model.dofs(),),
|
170
|
-
valid=valid,
|
171
|
-
)
|
172
|
-
|
173
|
-
valid = check_valid_shape(
|
174
|
-
what="base_position",
|
175
|
-
shape=self.base_position.shape,
|
176
|
-
expected_shape=(3,),
|
177
|
-
valid=valid,
|
178
|
-
)
|
179
|
-
|
180
|
-
valid = check_valid_shape(
|
181
|
-
what="base_quaternion",
|
182
|
-
shape=self.base_quaternion.shape,
|
183
|
-
expected_shape=(4,),
|
184
|
-
valid=valid,
|
185
|
-
)
|
186
|
-
|
187
|
-
valid = check_valid_shape(
|
188
|
-
what="base_linear_velocity",
|
189
|
-
shape=self.base_linear_velocity.shape,
|
190
|
-
expected_shape=(3,),
|
191
|
-
valid=valid,
|
192
|
-
)
|
193
|
-
|
194
|
-
valid = check_valid_shape(
|
195
|
-
what="base_angular_velocity",
|
196
|
-
shape=self.base_angular_velocity.shape,
|
197
|
-
expected_shape=(3,),
|
198
|
-
valid=valid,
|
199
|
-
)
|
200
|
-
|
201
|
-
return valid
|
202
|
-
|
203
|
-
|
204
|
-
@jax_dataclasses.pytree_dataclass
|
205
|
-
class PhysicsModelInput(JaxsimDataclass):
|
206
|
-
"""
|
207
|
-
A class representing the input to a physics model.
|
208
|
-
|
209
|
-
This class stores the joint torques and external forces acting on the bodies of a physics model.
|
210
|
-
|
211
|
-
Attributes:
|
212
|
-
tau: An array representing the joint torques.
|
213
|
-
f_ext: A matrix representing the external forces acting on the bodies of the physics model.
|
214
|
-
"""
|
215
|
-
|
216
|
-
tau: jtp.VectorJax
|
217
|
-
f_ext: jtp.MatrixJax
|
218
|
-
|
219
|
-
@staticmethod
|
220
|
-
def build(
|
221
|
-
tau: jtp.VectorJax | None = None,
|
222
|
-
f_ext: jtp.MatrixJax | None = None,
|
223
|
-
number_of_dofs: jtp.Int | None = None,
|
224
|
-
number_of_links: jtp.Int | None = None,
|
225
|
-
) -> "PhysicsModelInput":
|
226
|
-
""""""
|
227
|
-
|
228
|
-
tau = tau if tau is not None else jnp.zeros(number_of_dofs)
|
229
|
-
f_ext = f_ext if f_ext is not None else jnp.zeros(shape=(number_of_links, 6))
|
230
|
-
|
231
|
-
return PhysicsModelInput(
|
232
|
-
tau=jnp.array(tau, dtype=float), f_ext=jnp.array(f_ext, dtype=float)
|
233
|
-
)
|
234
|
-
|
235
|
-
@staticmethod
|
236
|
-
def build_from_physics_model(
|
237
|
-
tau: jtp.VectorJax | None = None,
|
238
|
-
f_ext: jtp.MatrixJax | None = None,
|
239
|
-
physics_model: Union[
|
240
|
-
"jaxsim.physics.model.physics_model.PhysicsModel", None
|
241
|
-
] = None,
|
242
|
-
) -> "PhysicsModelInput":
|
243
|
-
return PhysicsModelInput.build(
|
244
|
-
tau=tau,
|
245
|
-
f_ext=f_ext,
|
246
|
-
number_of_dofs=physics_model.dofs(),
|
247
|
-
number_of_links=physics_model.NB,
|
248
|
-
)
|
249
|
-
|
250
|
-
@staticmethod
|
251
|
-
def zero(
|
252
|
-
physics_model: "jaxsim.physics.model.physics_model.PhysicsModel",
|
253
|
-
) -> "PhysicsModelInput":
|
254
|
-
return PhysicsModelInput.build_from_physics_model(physics_model=physics_model)
|
255
|
-
|
256
|
-
def replace(self, validate: bool = True, **kwargs) -> "PhysicsModelInput":
|
257
|
-
with jax_dataclasses.copy_and_mutate(self, validate=validate) as updated_input:
|
258
|
-
_ = [updated_input.__setattr__(k, v) for k, v in kwargs.items()]
|
259
|
-
|
260
|
-
return updated_input
|
261
|
-
|
262
|
-
def valid(
|
263
|
-
self, physics_model: "jaxsim.physics.model.physics_model.PhysicsModel"
|
264
|
-
) -> bool:
|
265
|
-
from jaxsim.simulation.utils import check_valid_shape
|
266
|
-
|
267
|
-
valid = True
|
268
|
-
|
269
|
-
valid = check_valid_shape(
|
270
|
-
what="tau",
|
271
|
-
shape=self.tau.shape,
|
272
|
-
expected_shape=(physics_model.dofs(),),
|
273
|
-
valid=valid,
|
274
|
-
)
|
275
|
-
|
276
|
-
valid = check_valid_shape(
|
277
|
-
what="f_ext",
|
278
|
-
shape=self.f_ext.shape,
|
279
|
-
expected_shape=(physics_model.NB, 6),
|
280
|
-
valid=valid,
|
281
|
-
)
|
282
|
-
|
283
|
-
return valid
|
jaxsim/simulation/__init__.py
DELETED
jaxsim/simulation/integrators.py
DELETED
@@ -1,393 +0,0 @@
|
|
1
|
-
import enum
|
2
|
-
from typing import Any, Callable
|
3
|
-
|
4
|
-
import jax
|
5
|
-
import jax.numpy as jnp
|
6
|
-
from jax.tree_util import tree_map
|
7
|
-
|
8
|
-
import jaxsim.typing as jtp
|
9
|
-
from jaxsim.math.quaternion import Quaternion
|
10
|
-
from jaxsim.physics.algos.soft_contacts import SoftContactsState
|
11
|
-
from jaxsim.physics.model.physics_model_state import PhysicsModelState
|
12
|
-
from jaxsim.simulation.ode_data import ODEState
|
13
|
-
from jaxsim.sixd import se3, so3
|
14
|
-
|
15
|
-
Time = jtp.FloatLike
|
16
|
-
TimeStep = jtp.FloatLike
|
17
|
-
TimeHorizon = jtp.VectorLike
|
18
|
-
|
19
|
-
State = jtp.PyTree
|
20
|
-
StateDerivative = jtp.PyTree
|
21
|
-
|
22
|
-
StateDerivativeCallable = Callable[
|
23
|
-
[State, Time], tuple[StateDerivative, dict[str, Any]]
|
24
|
-
]
|
25
|
-
|
26
|
-
|
27
|
-
class IntegratorType(enum.IntEnum):
|
28
|
-
RungeKutta4 = enum.auto()
|
29
|
-
EulerForward = enum.auto()
|
30
|
-
EulerSemiImplicit = enum.auto()
|
31
|
-
EulerSemiImplicitManifold = enum.auto()
|
32
|
-
|
33
|
-
|
34
|
-
# =======================
|
35
|
-
# Single-step integration
|
36
|
-
# =======================
|
37
|
-
|
38
|
-
|
39
|
-
def integrator_fixed_single_step(
|
40
|
-
dx_dt: StateDerivativeCallable,
|
41
|
-
x0: State | ODEState,
|
42
|
-
t0: Time,
|
43
|
-
tf: Time,
|
44
|
-
integrator_type: IntegratorType,
|
45
|
-
num_sub_steps: int = 1,
|
46
|
-
) -> tuple[State | ODEState, dict[str, Any]]:
|
47
|
-
"""
|
48
|
-
Advance a state vector by integrating a sytem dynamics with a fixed-step integrator.
|
49
|
-
|
50
|
-
Args:
|
51
|
-
dx_dt: Callable that computes the state derivative.
|
52
|
-
x0: Initial state.
|
53
|
-
t0: Initial time.
|
54
|
-
tf: Final time.
|
55
|
-
integrator_type: Integrator type.
|
56
|
-
num_sub_steps: Number of sub-steps to break the integration into.
|
57
|
-
|
58
|
-
Returns:
|
59
|
-
The final state and a dictionary including auxiliary data at t0.
|
60
|
-
"""
|
61
|
-
|
62
|
-
# Compute the sub-step size.
|
63
|
-
# We break dt in configurable sub-steps.
|
64
|
-
dt = tf - t0
|
65
|
-
sub_step_dt = dt / num_sub_steps
|
66
|
-
|
67
|
-
# Initialize the carry
|
68
|
-
Carry = tuple[State | ODEState, Time]
|
69
|
-
carry_init: Carry = (x0, t0)
|
70
|
-
|
71
|
-
def forward_euler_body_fun(carry: Carry, xs: None) -> tuple[Carry, None]:
|
72
|
-
"""
|
73
|
-
Forward Euler integrator.
|
74
|
-
"""
|
75
|
-
|
76
|
-
# Unpack the carry
|
77
|
-
x_t0, t0 = carry
|
78
|
-
|
79
|
-
# Compute the state derivative
|
80
|
-
dxdt_t0, _ = dx_dt(x_t0, t0)
|
81
|
-
|
82
|
-
# Integrate the dynamics
|
83
|
-
x_tf = jax.tree_util.tree_map(
|
84
|
-
lambda x, dxdt: x + sub_step_dt * dxdt, x_t0, dxdt_t0
|
85
|
-
)
|
86
|
-
|
87
|
-
# Update the time
|
88
|
-
tf = t0 + sub_step_dt
|
89
|
-
|
90
|
-
# Pack the carry
|
91
|
-
carry = (x_tf, tf)
|
92
|
-
|
93
|
-
return carry, None
|
94
|
-
|
95
|
-
def rk4_body_fun(carry: Carry, xs: None) -> tuple[Carry, None]:
|
96
|
-
"""
|
97
|
-
Runge-Kutta 4 integrator.
|
98
|
-
"""
|
99
|
-
|
100
|
-
# Unpack the carry
|
101
|
-
x_t0, t0 = carry
|
102
|
-
|
103
|
-
# Helper to forward the state to compute k2 and k3 at midpoint and k4 at final
|
104
|
-
euler_mid = lambda x, dxdt: x + (0.5 * sub_step_dt) * dxdt
|
105
|
-
euler_fin = lambda x, dxdt: x + sub_step_dt * dxdt
|
106
|
-
|
107
|
-
# Compute the RK4 slopes
|
108
|
-
k1, _ = dx_dt(x_t0, t0)
|
109
|
-
k2, _ = dx_dt(tree_map(euler_mid, x_t0, k1), t0 + 0.5 * sub_step_dt)
|
110
|
-
k3, _ = dx_dt(tree_map(euler_mid, x_t0, k2), t0 + 0.5 * sub_step_dt)
|
111
|
-
k4, _ = dx_dt(tree_map(euler_fin, x_t0, k3), t0 + sub_step_dt)
|
112
|
-
|
113
|
-
# Average the slopes and compute the RK4 state derivative
|
114
|
-
average = lambda k1, k2, k3, k4: (k1 + 2 * k2 + 2 * k3 + k4) / 6
|
115
|
-
dxdt = jax.tree_util.tree_map(average, k1, k2, k3, k4)
|
116
|
-
|
117
|
-
# Integrate the dynamics
|
118
|
-
x_tf = jax.tree_util.tree_map(euler_fin, x_t0, dxdt)
|
119
|
-
|
120
|
-
# Update the time
|
121
|
-
tf = t0 + sub_step_dt
|
122
|
-
|
123
|
-
# Pack the carry
|
124
|
-
carry = (x_tf, tf)
|
125
|
-
|
126
|
-
return carry, None
|
127
|
-
|
128
|
-
def semi_implicit_euler_body_fun(carry: Carry, xs: None) -> tuple[Carry, None]:
|
129
|
-
"""
|
130
|
-
Semi-implicit Euler integrator.
|
131
|
-
"""
|
132
|
-
|
133
|
-
# Unpack the carry
|
134
|
-
x_t0, t0 = carry
|
135
|
-
|
136
|
-
# Compute the state derivative.
|
137
|
-
# We only keep the quantities related to the acceleration and discard those
|
138
|
-
# related to the velocity since we are going to use those implicitly integrated
|
139
|
-
# from the accelerations.
|
140
|
-
StateDerivative = ODEState
|
141
|
-
dxdt_t0: StateDerivative = dx_dt(x_t0, t0)[0]
|
142
|
-
|
143
|
-
# Extract the initial position ∈ ℝ⁷⁺ⁿ and initial velocity ∈ ℝ⁶⁺ⁿ.
|
144
|
-
# This integrator, contrarily to most of the other ones, is not generic.
|
145
|
-
# It expects to operate on an x object of class ODEState.
|
146
|
-
pos_t0 = x_t0.physics_model.position()
|
147
|
-
vel_t0 = x_t0.physics_model.velocity()
|
148
|
-
|
149
|
-
# Extract the velocity derivative
|
150
|
-
d_vel_dt = dxdt_t0.physics_model.velocity()
|
151
|
-
|
152
|
-
# =============================================
|
153
|
-
# Perform semi-implicit Euler integration [1-4]
|
154
|
-
# =============================================
|
155
|
-
|
156
|
-
# 1. Integrate the accelerations obtaining the implicit velocities
|
157
|
-
# 2. Compute the derivative of the generalized position
|
158
|
-
# 3. Integrate the implicit velocities
|
159
|
-
# 4. Integrate the remaining state
|
160
|
-
# 5. Outside the loop: integrate the quaternion on SO(3) manifold
|
161
|
-
|
162
|
-
# ----------------------------------------------------------------
|
163
|
-
# 1. Integrate the accelerations obtaining the implicit velocities
|
164
|
-
# ----------------------------------------------------------------
|
165
|
-
|
166
|
-
vel_tf = vel_t0 + sub_step_dt * d_vel_dt
|
167
|
-
|
168
|
-
# -----------------------------------------------------
|
169
|
-
# 2. Compute the derivative of the generalized position
|
170
|
-
# -----------------------------------------------------
|
171
|
-
|
172
|
-
# Extract the implicit angular velocity and the initial base quaternion
|
173
|
-
W_ω_WB = vel_tf[3:6]
|
174
|
-
W_Q_B = x_t0.physics_model.base_quaternion
|
175
|
-
|
176
|
-
# Compute the quaternion derivative and the base position derivative
|
177
|
-
W_Qd_B = Quaternion.derivative(
|
178
|
-
quaternion=W_Q_B, omega=W_ω_WB, omega_in_body_fixed=False
|
179
|
-
).squeeze()
|
180
|
-
|
181
|
-
# Compute the transform of the mixed base frame at t0
|
182
|
-
W_H_BW = jnp.vstack(
|
183
|
-
[
|
184
|
-
jnp.block([jnp.eye(3), jnp.vstack(x_t0.physics_model.base_position)]),
|
185
|
-
jnp.array([0, 0, 0, 1]),
|
186
|
-
]
|
187
|
-
)
|
188
|
-
|
189
|
-
# The derivative W_ṗ_B of the base position is the linear component of the
|
190
|
-
# mixed velocity B[W]_v_WB. We need to compute it from the velocity in
|
191
|
-
# inertial-fixed representation W_vl_WB.
|
192
|
-
W_v_WB = vel_tf[0:6]
|
193
|
-
BW_Xv_W = se3.SE3.from_matrix(W_H_BW).inverse().adjoint()
|
194
|
-
BW_vl_WB = (BW_Xv_W @ W_v_WB)[0:3]
|
195
|
-
|
196
|
-
# Compute the derivative of the generalized position
|
197
|
-
d_pos_tf = (
|
198
|
-
jnp.hstack([BW_vl_WB, vel_tf[6:]])
|
199
|
-
if integrator_type is IntegratorType.EulerSemiImplicitManifold
|
200
|
-
else jnp.hstack([BW_vl_WB, W_Qd_B, vel_tf[6:]])
|
201
|
-
)
|
202
|
-
|
203
|
-
# ------------------------------------
|
204
|
-
# 3. Integrate the implicit velocities
|
205
|
-
# ------------------------------------
|
206
|
-
|
207
|
-
pos_tf = pos_t0 + sub_step_dt * d_pos_tf
|
208
|
-
joint_positions = (
|
209
|
-
pos_tf[3:]
|
210
|
-
if integrator_type is IntegratorType.EulerSemiImplicitManifold
|
211
|
-
else pos_tf[7:]
|
212
|
-
)
|
213
|
-
base_quaternion = (
|
214
|
-
jnp.zeros_like(x_t0.base_quaternion)
|
215
|
-
if integrator_type is IntegratorType.EulerSemiImplicitManifold
|
216
|
-
else pos_tf[3:7]
|
217
|
-
)
|
218
|
-
|
219
|
-
# ---------------------------------
|
220
|
-
# 4. Integrate the remaining state
|
221
|
-
# ---------------------------------
|
222
|
-
|
223
|
-
# Integrate the derivative of the tangential material deformation
|
224
|
-
m = x_t0.soft_contacts.tangential_deformation
|
225
|
-
ṁ = dxdt_t0.soft_contacts.tangential_deformation
|
226
|
-
tangential_deformation_tf = m + sub_step_dt * ṁ
|
227
|
-
|
228
|
-
# Pack the new state into an ODEState object
|
229
|
-
x_tf = ODEState(
|
230
|
-
physics_model=PhysicsModelState(
|
231
|
-
base_position=pos_tf[0:3],
|
232
|
-
base_quaternion=base_quaternion,
|
233
|
-
joint_positions=joint_positions,
|
234
|
-
base_linear_velocity=vel_tf[0:3],
|
235
|
-
base_angular_velocity=vel_tf[3:6],
|
236
|
-
joint_velocities=vel_tf[6:],
|
237
|
-
),
|
238
|
-
soft_contacts=SoftContactsState(
|
239
|
-
tangential_deformation=tangential_deformation_tf
|
240
|
-
),
|
241
|
-
)
|
242
|
-
|
243
|
-
# Update the time
|
244
|
-
tf = t0 + sub_step_dt
|
245
|
-
|
246
|
-
# Pack the carry
|
247
|
-
carry = (x_tf, tf)
|
248
|
-
|
249
|
-
return carry, None
|
250
|
-
|
251
|
-
_integrator_registry = {
|
252
|
-
IntegratorType.RungeKutta4: rk4_body_fun,
|
253
|
-
IntegratorType.EulerForward: forward_euler_body_fun,
|
254
|
-
IntegratorType.EulerSemiImplicit: semi_implicit_euler_body_fun,
|
255
|
-
IntegratorType.EulerSemiImplicitManifold: semi_implicit_euler_body_fun,
|
256
|
-
}
|
257
|
-
|
258
|
-
# Get the body function for the selected integrator
|
259
|
-
body_fun = _integrator_registry[integrator_type]
|
260
|
-
|
261
|
-
# Integrate over the given horizon
|
262
|
-
(x_tf, _), _ = jax.lax.scan(
|
263
|
-
f=body_fun, init=carry_init, xs=None, length=num_sub_steps
|
264
|
-
)
|
265
|
-
|
266
|
-
if integrator_type is IntegratorType.EulerSemiImplicitManifold:
|
267
|
-
# Indices to convert quaternions between serializations
|
268
|
-
to_xyzw = jnp.array([1, 2, 3, 0])
|
269
|
-
to_wxyz = jnp.array([3, 0, 1, 2])
|
270
|
-
|
271
|
-
# Get the initial quaternion and the implicitly integrated angular velocity
|
272
|
-
W_ω_WB_tf = x_tf.physics_model.base_angular_velocity
|
273
|
-
W_Q_B_t0 = so3.SO3.from_quaternion_xyzw(
|
274
|
-
x0.physics_model.base_quaternion[to_xyzw]
|
275
|
-
)
|
276
|
-
|
277
|
-
# Integrate the quaternion on its manifold using the implicit angular velocity,
|
278
|
-
# transformed in body-fixed representation since jaxlie uses this convention
|
279
|
-
B_R_W = W_Q_B_t0.inverse().as_matrix()
|
280
|
-
W_Q_B_tf = W_Q_B_t0 @ so3.SO3.exp(tangent=dt * B_R_W @ W_ω_WB_tf)
|
281
|
-
|
282
|
-
# Store the quaternion in the final state
|
283
|
-
x_tf = x_tf.replace(
|
284
|
-
physics_model=x_tf.physics_model.replace(
|
285
|
-
base_quaternion=W_Q_B_tf.as_quaternion_xyzw()[to_wxyz]
|
286
|
-
)
|
287
|
-
)
|
288
|
-
|
289
|
-
# Compute the aux dictionary at t0
|
290
|
-
_, aux_t0 = dx_dt(x0, t0)
|
291
|
-
|
292
|
-
return x_tf, aux_t0
|
293
|
-
|
294
|
-
|
295
|
-
# ===============================
|
296
|
-
# Adapter: single step -> horizon
|
297
|
-
# ===============================
|
298
|
-
|
299
|
-
|
300
|
-
def integrate_single_step_over_horizon(
|
301
|
-
integrator_single_step: Callable[[Time, Time, State], tuple[State, dict[str, Any]]],
|
302
|
-
t: TimeHorizon,
|
303
|
-
x0: State,
|
304
|
-
) -> tuple[State, dict[str, Any]]:
|
305
|
-
"""
|
306
|
-
Integrate a single-step integrator over a given horizon.
|
307
|
-
|
308
|
-
Args:
|
309
|
-
integrator_single_step: A single-step integrator.
|
310
|
-
t: The vector of time instants of the integration horizon.
|
311
|
-
x0: The initial state of the integration horizon.
|
312
|
-
|
313
|
-
Returns:
|
314
|
-
The final state and auxiliary data produced by the integrator.
|
315
|
-
"""
|
316
|
-
|
317
|
-
# Initialize the carry
|
318
|
-
carry_init = (x0, t)
|
319
|
-
|
320
|
-
def body_fun(carry: tuple, idx: int) -> tuple[tuple, jtp.PyTree]:
|
321
|
-
# Unpack the carry
|
322
|
-
x_t0, horizon = carry
|
323
|
-
|
324
|
-
# Get the integration interval
|
325
|
-
t0 = horizon[idx]
|
326
|
-
tf = horizon[idx + 1]
|
327
|
-
|
328
|
-
# Perform a single-step integration of the ODE
|
329
|
-
x_tf, aux_t0 = integrator_single_step(t0, tf, x_t0)
|
330
|
-
|
331
|
-
# Prepare returned data
|
332
|
-
out = (x_t0, aux_t0)
|
333
|
-
carry = (x_tf, horizon)
|
334
|
-
|
335
|
-
return carry, out
|
336
|
-
|
337
|
-
# Integrate over the given horizon
|
338
|
-
_, (x_horizon, aux_horizon) = jax.lax.scan(
|
339
|
-
f=body_fun, init=carry_init, xs=jnp.arange(start=0, stop=len(t), dtype=int)
|
340
|
-
)
|
341
|
-
|
342
|
-
return x_horizon, aux_horizon
|
343
|
-
|
344
|
-
|
345
|
-
# ===================================================================
|
346
|
-
# Integration over horizon (same APIs of jax.experimental.ode.odeint)
|
347
|
-
# ===================================================================
|
348
|
-
|
349
|
-
|
350
|
-
def odeint(
|
351
|
-
func,
|
352
|
-
y0: State,
|
353
|
-
t: TimeHorizon,
|
354
|
-
*args,
|
355
|
-
num_sub_steps: int = 1,
|
356
|
-
return_aux: bool = False,
|
357
|
-
integrator_type: IntegratorType = None,
|
358
|
-
):
|
359
|
-
"""
|
360
|
-
Integrate a system of ODEs with a fixed-step integrator.
|
361
|
-
|
362
|
-
Args:
|
363
|
-
func: A function that computes the time-derivative of the state.
|
364
|
-
y0: The initial state.
|
365
|
-
t: The vector of time instants of the integration horizon.
|
366
|
-
*args: Additional arguments to be passed to the function func.
|
367
|
-
num_sub_steps: The number of sub-steps to be performed within each integration step.
|
368
|
-
return_aux: Whether to return the auxiliary data produced by the integrator.
|
369
|
-
|
370
|
-
Returns:
|
371
|
-
The state of the system at the end of the integration horizon, and optionally
|
372
|
-
the auxiliary data produced by the integrator.
|
373
|
-
"""
|
374
|
-
|
375
|
-
# Close func over additional inputs and parameters
|
376
|
-
dx_dt_closure = lambda x, ts: func(x, ts, *args)
|
377
|
-
|
378
|
-
# Close one-step integration over its arguments
|
379
|
-
integrator_single_step = lambda t0, tf, x0: integrator_fixed_single_step(
|
380
|
-
dx_dt=dx_dt_closure,
|
381
|
-
x0=x0,
|
382
|
-
t0=t0,
|
383
|
-
tf=tf,
|
384
|
-
num_sub_steps=num_sub_steps,
|
385
|
-
integrator_type=integrator_type,
|
386
|
-
)
|
387
|
-
|
388
|
-
# Integrate the state and compute optional auxiliary data over the horizon
|
389
|
-
out, aux = integrate_single_step_over_horizon(
|
390
|
-
integrator_single_step=integrator_single_step, t=t, x0=y0
|
391
|
-
)
|
392
|
-
|
393
|
-
return (out, aux) if return_aux else out
|