jaxsim 0.2.dev65__py3-none-any.whl → 0.2.dev77__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jaxsim/_version.py +2 -2
- jaxsim/simulation/integrators.py +90 -343
- jaxsim/simulation/ode_integration.py +3 -16
- {jaxsim-0.2.dev65.dist-info → jaxsim-0.2.dev77.dist-info}/METADATA +1 -1
- {jaxsim-0.2.dev65.dist-info → jaxsim-0.2.dev77.dist-info}/RECORD +8 -8
- {jaxsim-0.2.dev65.dist-info → jaxsim-0.2.dev77.dist-info}/LICENSE +0 -0
- {jaxsim-0.2.dev65.dist-info → jaxsim-0.2.dev77.dist-info}/WHEEL +0 -0
- {jaxsim-0.2.dev65.dist-info → jaxsim-0.2.dev77.dist-info}/top_level.txt +0 -0
jaxsim/_version.py
CHANGED
@@ -12,5 +12,5 @@ __version__: str
|
|
12
12
|
__version_tuple__: VERSION_TUPLE
|
13
13
|
version_tuple: VERSION_TUPLE
|
14
14
|
|
15
|
-
__version__ = version = '0.2.dev65'
|
16
|
-
__version_tuple__ = version_tuple = (0, 2, 'dev65')
|
15
|
+
__version__ = version = '0.2.dev77'
|
16
|
+
__version_tuple__ = version_tuple = (0, 2, 'dev77')
|
jaxsim/simulation/integrators.py
CHANGED
@@ -1,4 +1,5 @@
|
|
1
|
-
|
1
|
+
import enum
|
2
|
+
from typing import Any, Callable
|
2
3
|
|
3
4
|
import jax
|
4
5
|
import jax.numpy as jnp
|
@@ -19,30 +20,39 @@ State = jtp.PyTree
|
|
19
20
|
StateDerivative = jtp.PyTree
|
20
21
|
|
21
22
|
StateDerivativeCallable = Callable[
|
22
|
-
[State, Time],
|
23
|
+
[State, Time], tuple[StateDerivative, dict[str, Any]]
|
23
24
|
]
|
24
25
|
|
25
26
|
|
27
|
+
class IntegratorType(enum.IntEnum):
|
28
|
+
RungeKutta4 = enum.auto()
|
29
|
+
EulerForward = enum.auto()
|
30
|
+
EulerSemiImplicit = enum.auto()
|
31
|
+
EulerSemiImplicitManifold = enum.auto()
|
32
|
+
|
33
|
+
|
26
34
|
# =======================
|
27
35
|
# Single-step integration
|
28
36
|
# =======================
|
29
37
|
|
30
38
|
|
31
|
-
def odeint_euler_one_step(
|
39
|
+
def integrator_fixed_single_step(
|
32
40
|
dx_dt: StateDerivativeCallable,
|
33
|
-
x0: State,
|
41
|
+
x0: State | ODEState,
|
34
42
|
t0: Time,
|
35
43
|
tf: Time,
|
44
|
+
integrator_type: IntegratorType,
|
36
45
|
num_sub_steps: int = 1,
|
37
|
-
) ->
|
46
|
+
) -> tuple[State | ODEState, dict[str, Any]]:
|
38
47
|
"""
|
39
|
-
|
48
|
+
Advance a state vector by integrating a sytem dynamics with a fixed-step integrator.
|
40
49
|
|
41
50
|
Args:
|
42
51
|
dx_dt: Callable that computes the state derivative.
|
43
52
|
x0: Initial state.
|
44
53
|
t0: Initial time.
|
45
54
|
tf: Final time.
|
55
|
+
integrator_type: Integrator type.
|
46
56
|
num_sub_steps: Number of sub-steps to break the integration into.
|
47
57
|
|
48
58
|
Returns:
|
@@ -55,10 +65,14 @@ def odeint_euler_one_step(
|
|
55
65
|
sub_step_dt = dt / num_sub_steps
|
56
66
|
|
57
67
|
# Initialize the carry
|
58
|
-
Carry =
|
68
|
+
Carry = tuple[State | ODEState, Time]
|
59
69
|
carry_init: Carry = (x0, t0)
|
60
70
|
|
61
|
-
def
|
71
|
+
def forward_euler_body_fun(carry: Carry, xs: None) -> tuple[Carry, None]:
|
72
|
+
"""
|
73
|
+
Forward Euler integrator.
|
74
|
+
"""
|
75
|
+
|
62
76
|
# Unpack the carry
|
63
77
|
x_t0, t0 = carry
|
64
78
|
|
@@ -78,48 +92,11 @@ def odeint_euler_one_step(
|
|
78
92
|
|
79
93
|
return carry, None
|
80
94
|
|
81
|
-
|
82
|
-
|
83
|
-
|
84
|
-
|
95
|
+
def rk4_body_fun(carry: Carry, xs: None) -> tuple[Carry, None]:
|
96
|
+
"""
|
97
|
+
Runge-Kutta 4 integrator.
|
98
|
+
"""
|
85
99
|
|
86
|
-
# Compute the aux dictionary at t0
|
87
|
-
_, aux_t0 = dx_dt(x0, t0)
|
88
|
-
|
89
|
-
return x_tf, aux_t0
|
90
|
-
|
91
|
-
|
92
|
-
def odeint_rk4_one_step(
|
93
|
-
dx_dt: StateDerivativeCallable,
|
94
|
-
x0: State,
|
95
|
-
t0: Time,
|
96
|
-
tf: Time,
|
97
|
-
num_sub_steps: int = 1,
|
98
|
-
) -> Tuple[State, Dict[str, Any]]:
|
99
|
-
"""
|
100
|
-
Runge-Kutta 4 integrator.
|
101
|
-
|
102
|
-
Args:
|
103
|
-
dx_dt: Callable that computes the state derivative.
|
104
|
-
x0: Initial state.
|
105
|
-
t0: Initial time.
|
106
|
-
tf: Final time.
|
107
|
-
num_sub_steps: Number of sub-steps to break the integration into.
|
108
|
-
|
109
|
-
Returns:
|
110
|
-
The final state and a dictionary including auxiliary data at t0.
|
111
|
-
"""
|
112
|
-
|
113
|
-
# Compute the sub-step size.
|
114
|
-
# We break dt in configurable sub-steps.
|
115
|
-
dt = tf - t0
|
116
|
-
sub_step_dt = dt / num_sub_steps
|
117
|
-
|
118
|
-
# Initialize the carry
|
119
|
-
Carry = Tuple[State, Time]
|
120
|
-
carry_init: Carry = (x0, t0)
|
121
|
-
|
122
|
-
def body_fun(carry: Carry, xs: None) -> Tuple[Carry, None]:
|
123
100
|
# Unpack the carry
|
124
101
|
x_t0, t0 = carry
|
125
102
|
|
@@ -148,49 +125,11 @@ def odeint_rk4_one_step(
|
|
148
125
|
|
149
126
|
return carry, None
|
150
127
|
|
151
|
-
|
152
|
-
|
153
|
-
|
154
|
-
|
155
|
-
|
156
|
-
# Compute the aux dictionary at t0
|
157
|
-
_, aux_t0 = dx_dt(x0, t0)
|
158
|
-
|
159
|
-
return x_tf, aux_t0
|
160
|
-
|
161
|
-
|
162
|
-
def odeint_euler_semi_implicit_one_step(
|
163
|
-
dx_dt: StateDerivativeCallable,
|
164
|
-
x0: ODEState,
|
165
|
-
t0: Time,
|
166
|
-
tf: Time,
|
167
|
-
num_sub_steps: int = 1,
|
168
|
-
) -> Tuple[ODEState, Dict[str, Any]]:
|
169
|
-
"""
|
170
|
-
Semi-implicit Euler integrator.
|
171
|
-
|
172
|
-
Args:
|
173
|
-
dx_dt: Callable that computes the state derivative.
|
174
|
-
x0: Initial state as ODEState object.
|
175
|
-
t0: Initial time.
|
176
|
-
tf: Final time.
|
177
|
-
num_sub_steps: Number of sub-steps to break the integration into.
|
178
|
-
|
179
|
-
Returns:
|
180
|
-
A tuple having as first element the final state as ODEState object,
|
181
|
-
and as second element a dictionary including auxiliary data at t0.
|
182
|
-
"""
|
128
|
+
def semi_implicit_euler_body_fun(carry: Carry, xs: None) -> tuple[Carry, None]:
|
129
|
+
"""
|
130
|
+
Semi-implicit Euler integrator.
|
131
|
+
"""
|
183
132
|
|
184
|
-
# Compute the sub-step size.
|
185
|
-
# We break dt in configurable sub-steps.
|
186
|
-
dt = tf - t0
|
187
|
-
sub_step_dt = dt / num_sub_steps
|
188
|
-
|
189
|
-
# Initialize the carry
|
190
|
-
Carry = Tuple[ODEState, Time]
|
191
|
-
carry_init: Carry = (x0, t0)
|
192
|
-
|
193
|
-
def body_fun(carry: Carry, xs: None) -> Tuple[Carry, None]:
|
194
133
|
# Unpack the carry
|
195
134
|
x_t0, t0 = carry
|
196
135
|
|
@@ -218,6 +157,7 @@ def odeint_euler_semi_implicit_one_step(
|
|
218
157
|
# 2. Compute the derivative of the generalized position
|
219
158
|
# 3. Integrate the implicit velocities
|
220
159
|
# 4. Integrate the remaining state
|
160
|
+
# 5. Outside the loop: integrate the quaternion on SO(3) manifold
|
221
161
|
|
222
162
|
# ----------------------------------------------------------------
|
223
163
|
# 1. Integrate the accelerations obtaining the implicit velocities
|
@@ -254,13 +194,27 @@ def odeint_euler_semi_implicit_one_step(
|
|
254
194
|
BW_vl_WB = (BW_Xv_W @ W_v_WB)[0:3]
|
255
195
|
|
256
196
|
# Compute the derivative of the generalized position
|
257
|
-
d_pos_tf =
|
197
|
+
d_pos_tf = (
|
198
|
+
jnp.hstack([BW_vl_WB, vel_tf[6:]])
|
199
|
+
if integrator_type is IntegratorType.EulerSemiImplicitManifold
|
200
|
+
else jnp.hstack([BW_vl_WB, W_Qd_B, vel_tf[6:]])
|
201
|
+
)
|
258
202
|
|
259
203
|
# ------------------------------------
|
260
204
|
# 3. Integrate the implicit velocities
|
261
205
|
# ------------------------------------
|
262
206
|
|
263
207
|
pos_tf = pos_t0 + sub_step_dt * d_pos_tf
|
208
|
+
joint_positions = (
|
209
|
+
pos_tf[3:]
|
210
|
+
if integrator_type is IntegratorType.EulerSemiImplicitManifold
|
211
|
+
else pos_tf[7:]
|
212
|
+
)
|
213
|
+
base_quaternion = (
|
214
|
+
jnp.zeros_like(x_t0.base_quaternion)
|
215
|
+
if integrator_type is IntegratorType.EulerSemiImplicitManifold
|
216
|
+
else pos_tf[3:7]
|
217
|
+
)
|
264
218
|
|
265
219
|
# ---------------------------------
|
266
220
|
# 4. Integrate the remaining state
|
@@ -275,8 +229,8 @@ def odeint_euler_semi_implicit_one_step(
|
|
275
229
|
x_tf = ODEState(
|
276
230
|
physics_model=PhysicsModelState(
|
277
231
|
base_position=pos_tf[0:3],
|
278
|
-
base_quaternion=pos_tf[3:7],
|
279
|
-
joint_positions=pos_tf[7:],
|
232
|
+
base_quaternion=base_quaternion,
|
233
|
+
joint_positions=joint_positions,
|
280
234
|
base_linear_velocity=vel_tf[0:3],
|
281
235
|
base_angular_velocity=vel_tf[3:6],
|
282
236
|
joint_velocities=vel_tf[6:],
|
@@ -294,176 +248,43 @@ def odeint_euler_semi_implicit_one_step(
|
|
294
248
|
|
295
249
|
return carry, None
|
296
250
|
|
297
|
-
|
298
|
-
|
299
|
-
|
300
|
-
|
301
|
-
|
302
|
-
|
303
|
-
_, aux_t0 = dx_dt(x0, t0)
|
304
|
-
|
305
|
-
return x_tf, aux_t0
|
306
|
-
|
307
|
-
|
308
|
-
def odeint_euler_semi_implicit_manifold_one_step(
|
309
|
-
dx_dt: StateDerivativeCallable,
|
310
|
-
x0: ODEState,
|
311
|
-
t0: Time,
|
312
|
-
tf: Time,
|
313
|
-
num_sub_steps: int = 1,
|
314
|
-
) -> Tuple[ODEState, Dict[str, Any]]:
|
315
|
-
"""
|
316
|
-
Semi-implicit Euler integrator with quaternion integration on SO(3).
|
317
|
-
|
318
|
-
Args:
|
319
|
-
dx_dt: Callable that computes the state derivative.
|
320
|
-
x0: Initial state as ODEState object.
|
321
|
-
t0: Initial time.
|
322
|
-
tf: Final time.
|
323
|
-
num_sub_steps: Number of sub-steps to break the integration into.
|
324
|
-
|
325
|
-
Returns:
|
326
|
-
A tuple having as first element the final state as ODEState object,
|
327
|
-
and as second element a dictionary including auxiliary data at t0.
|
328
|
-
"""
|
329
|
-
|
330
|
-
# Compute the sub-step size.
|
331
|
-
# We break dt in configurable sub-steps.
|
332
|
-
dt = tf - t0
|
333
|
-
sub_step_dt = dt / num_sub_steps
|
334
|
-
|
335
|
-
# Integrate the quaternion on its manifold using the new angular velocity
|
336
|
-
|
337
|
-
# Initialize the carry
|
338
|
-
Carry = Tuple[ODEState, Time]
|
339
|
-
carry_init: Carry = (x0, t0)
|
340
|
-
|
341
|
-
def body_fun(carry: Carry, xs: None) -> Tuple[Carry, None]:
|
342
|
-
# Unpack the carry
|
343
|
-
x_t0, t0 = carry
|
344
|
-
|
345
|
-
# Compute the state derivative.
|
346
|
-
# We only keep the quantities related to the acceleration and discard those
|
347
|
-
# related to the velocity since we are going to use those implicitly integrated
|
348
|
-
# from the accelerations.
|
349
|
-
StateDerivative = ODEState
|
350
|
-
dxdt_t0: StateDerivative = dx_dt(x_t0, t0)[0]
|
351
|
-
|
352
|
-
# Extract the initial position ∈ ℝ⁷⁺ⁿ and initial velocity ∈ ℝ⁶⁺ⁿ.
|
353
|
-
# This integrator, contrarily to most of the other ones, is not generic.
|
354
|
-
# It expects to operate on an x object of class ODEState.
|
355
|
-
pos_t0 = x_t0.physics_model.position()
|
356
|
-
vel_t0 = x_t0.physics_model.velocity()
|
357
|
-
|
358
|
-
# Extract the velocity derivative
|
359
|
-
d_vel_dt = dxdt_t0.physics_model.velocity()
|
360
|
-
|
361
|
-
# =============================================
|
362
|
-
# Perform semi-implicit Euler integration [1-4]
|
363
|
-
# =============================================
|
251
|
+
_integrator_registry = {
|
252
|
+
IntegratorType.RungeKutta4: rk4_body_fun,
|
253
|
+
IntegratorType.EulerForward: forward_euler_body_fun,
|
254
|
+
IntegratorType.EulerSemiImplicit: semi_implicit_euler_body_fun,
|
255
|
+
IntegratorType.EulerSemiImplicitManifold: semi_implicit_euler_body_fun,
|
256
|
+
}
|
364
257
|
|
365
|
-
|
366
|
-
|
367
|
-
# 3. Integrate the implicit velocities (w/o quaternion)
|
368
|
-
# 4. Integrate the remaining state
|
369
|
-
# 5. Outside the loop: integrate the quaternion on SO(3) manifold
|
370
|
-
|
371
|
-
# ----------------------------------------------------------------
|
372
|
-
# 1. Integrate the accelerations obtaining the implicit velocities
|
373
|
-
# ----------------------------------------------------------------
|
374
|
-
|
375
|
-
vel_tf = vel_t0 + sub_step_dt * d_vel_dt
|
376
|
-
|
377
|
-
# ----------------------------------------------------------------------
|
378
|
-
# 2. Compute the derivative of the generalized position (w/o quaternion)
|
379
|
-
# ----------------------------------------------------------------------
|
380
|
-
|
381
|
-
# Compute the transform of the mixed base frame at t0
|
382
|
-
W_H_BW = jnp.vstack(
|
383
|
-
[
|
384
|
-
jnp.block([jnp.eye(3), jnp.vstack(x_t0.physics_model.base_position)]),
|
385
|
-
jnp.array([0, 0, 0, 1]),
|
386
|
-
]
|
387
|
-
)
|
388
|
-
|
389
|
-
# The derivative W_ṗ_B of the base position is the linear component of the
|
390
|
-
# mixed velocity B[W]_v_WB. We need to compute it from the velocity in
|
391
|
-
# inertial-fixed representation W_vl_WB.
|
392
|
-
W_v_WB = vel_tf[0:6]
|
393
|
-
BW_Xv_W = se3.SE3.from_matrix(W_H_BW).inverse().adjoint()
|
394
|
-
BW_vl_WB = (BW_Xv_W @ W_v_WB)[0:3]
|
395
|
-
|
396
|
-
# Compute the derivative of the generalized position excluding the quaternion
|
397
|
-
pos_no_quat_t0 = jnp.hstack([pos_t0[0:3], pos_t0[7:]])
|
398
|
-
d_pos_no_quat_tf = jnp.hstack([BW_vl_WB, vel_tf[6:]])
|
399
|
-
|
400
|
-
# -----------------------------------------------------
|
401
|
-
# 3. Integrate the implicit velocities (w/o quaternion)
|
402
|
-
# -----------------------------------------------------
|
403
|
-
|
404
|
-
pos_no_quat_tf = pos_no_quat_t0 + sub_step_dt * d_pos_no_quat_tf
|
405
|
-
|
406
|
-
# ---------------------------------
|
407
|
-
# 4. Integrate the remaining state
|
408
|
-
# ---------------------------------
|
409
|
-
|
410
|
-
# Integrate the derivative of the tangential material deformation
|
411
|
-
m = x_t0.soft_contacts.tangential_deformation
|
412
|
-
ṁ = dxdt_t0.soft_contacts.tangential_deformation
|
413
|
-
tangential_deformation_tf = m + sub_step_dt * ṁ
|
414
|
-
|
415
|
-
# Pack the new state into an ODEState object.
|
416
|
-
# We store a zero quaternion as placeholder, it will be replaced later.
|
417
|
-
x_tf = ODEState(
|
418
|
-
physics_model=PhysicsModelState(
|
419
|
-
base_position=pos_no_quat_tf[0:3],
|
420
|
-
base_quaternion=jnp.zeros_like(x_t0.physics_model.base_quaternion),
|
421
|
-
joint_positions=pos_no_quat_tf[3:],
|
422
|
-
base_linear_velocity=vel_tf[0:3],
|
423
|
-
base_angular_velocity=vel_tf[3:6],
|
424
|
-
joint_velocities=vel_tf[6:],
|
425
|
-
),
|
426
|
-
soft_contacts=SoftContactsState(
|
427
|
-
tangential_deformation=tangential_deformation_tf
|
428
|
-
),
|
429
|
-
)
|
430
|
-
|
431
|
-
# Update the time
|
432
|
-
tf = t0 + sub_step_dt
|
433
|
-
|
434
|
-
# Pack the carry
|
435
|
-
carry = (x_tf, tf)
|
436
|
-
|
437
|
-
return carry, None
|
258
|
+
# Get the body function for the selected integrator
|
259
|
+
body_fun = _integrator_registry[integrator_type]
|
438
260
|
|
439
261
|
# Integrate over the given horizon
|
440
|
-
(
|
262
|
+
(x_tf, _), _ = jax.lax.scan(
|
441
263
|
f=body_fun, init=carry_init, xs=None, length=num_sub_steps
|
442
264
|
)
|
443
265
|
|
444
|
-
|
445
|
-
|
446
|
-
|
447
|
-
|
448
|
-
# Indices to convert quaternions between serializations
|
449
|
-
to_xyzw = jnp.array([1, 2, 3, 0])
|
450
|
-
to_wxyz = jnp.array([3, 0, 1, 2])
|
266
|
+
if integrator_type is IntegratorType.EulerSemiImplicitManifold:
|
267
|
+
# Indices to convert quaternions between serializations
|
268
|
+
to_xyzw = jnp.array([1, 2, 3, 0])
|
269
|
+
to_wxyz = jnp.array([3, 0, 1, 2])
|
451
270
|
|
452
|
-
|
453
|
-
|
454
|
-
|
271
|
+
# Get the initial quaternion and the implicitly integrated angular velocity
|
272
|
+
W_ω_WB_tf = x_tf.physics_model.base_angular_velocity
|
273
|
+
W_Q_B_t0 = so3.SO3.from_quaternion_xyzw(
|
274
|
+
x0.physics_model.base_quaternion[to_xyzw]
|
275
|
+
)
|
455
276
|
|
456
|
-
|
457
|
-
|
458
|
-
|
459
|
-
|
277
|
+
# Integrate the quaternion on its manifold using the implicit angular velocity,
|
278
|
+
# transformed in body-fixed representation since jaxlie uses this convention
|
279
|
+
B_R_W = W_Q_B_t0.inverse().as_matrix()
|
280
|
+
W_Q_B_tf = W_Q_B_t0 @ so3.SO3.exp(tangent=dt * B_R_W @ W_ω_WB_tf)
|
460
281
|
|
461
|
-
|
462
|
-
|
463
|
-
|
464
|
-
|
282
|
+
# Store the quaternion in the final state
|
283
|
+
x_tf = x_tf.replace(
|
284
|
+
physics_model=x_tf.physics_model.replace(
|
285
|
+
base_quaternion=W_Q_B_tf.as_quaternion_xyzw()[to_wxyz]
|
286
|
+
)
|
465
287
|
)
|
466
|
-
)
|
467
288
|
|
468
289
|
# Compute the aux dictionary at t0
|
469
290
|
_, aux_t0 = dx_dt(x0, t0)
|
@@ -477,10 +298,10 @@ def odeint_euler_semi_implicit_manifold_one_step(
|
|
477
298
|
|
478
299
|
|
479
300
|
def integrate_single_step_over_horizon(
|
480
|
-
integrator_single_step: Callable[[Time, Time, State],
|
301
|
+
integrator_single_step: Callable[[Time, Time, State], tuple[State, dict[str, Any]]],
|
481
302
|
t: TimeHorizon,
|
482
303
|
x0: State,
|
483
|
-
) ->
|
304
|
+
) -> tuple[State, dict[str, Any]]:
|
484
305
|
"""
|
485
306
|
Integrate a single-step integrator over a given horizon.
|
486
307
|
|
@@ -496,7 +317,7 @@ def integrate_single_step_over_horizon(
|
|
496
317
|
# Initialize the carry
|
497
318
|
carry_init = (x0, t)
|
498
319
|
|
499
|
-
def body_fun(carry:
|
320
|
+
def body_fun(carry: tuple, idx: int) -> tuple[tuple, jtp.PyTree]:
|
500
321
|
# Unpack the carry
|
501
322
|
x_t0, horizon = carry
|
502
323
|
|
@@ -526,96 +347,17 @@ def integrate_single_step_over_horizon(
|
|
526
347
|
# ===================================================================
|
527
348
|
|
528
349
|
|
529
|
-
def odeint_euler(
|
530
|
-
func,
|
531
|
-
y0: State,
|
532
|
-
t: TimeHorizon,
|
533
|
-
*args,
|
534
|
-
num_sub_steps: int = 1,
|
535
|
-
return_aux: bool = False
|
536
|
-
) -> Union[State, Tuple[State, Dict[str, Any]]]:
|
537
|
-
"""
|
538
|
-
Integrate a system of ODEs using the Euler method.
|
539
|
-
|
540
|
-
Args:
|
541
|
-
func: A function that computes the time-derivative of the state.
|
542
|
-
y0: The initial state.
|
543
|
-
t: The vector of time instants of the integration horizon.
|
544
|
-
*args: Additional arguments to be passed to the function func.
|
545
|
-
num_sub_steps: The number of sub-steps to be performed within each integration step.
|
546
|
-
return_aux: Whether to return the auxiliary data produced by the integrator.
|
547
|
-
|
548
|
-
Returns:
|
549
|
-
The state of the system at the end of the integration horizon, and optionally
|
550
|
-
the auxiliary data produced by the integrator.
|
551
|
-
"""
|
552
|
-
|
553
|
-
# Close func over additional inputs and parameters
|
554
|
-
dx_dt_closure_aux = lambda x, ts: func(x, ts, *args)
|
555
|
-
|
556
|
-
# Close one-step integration over its arguments
|
557
|
-
integrator_single_step = lambda t0, tf, x0: odeint_euler_one_step(
|
558
|
-
dx_dt=dx_dt_closure_aux, x0=x0, t0=t0, tf=tf, num_sub_steps=num_sub_steps
|
559
|
-
)
|
560
|
-
|
561
|
-
# Integrate the state and compute optional auxiliary data over the horizon
|
562
|
-
out, aux = integrate_single_step_over_horizon(
|
563
|
-
integrator_single_step=integrator_single_step, t=t, x0=y0
|
564
|
-
)
|
565
|
-
|
566
|
-
return (out, aux) if return_aux else out
|
567
|
-
|
568
|
-
|
569
|
-
def odeint_euler_semi_implicit(
|
570
|
-
func,
|
571
|
-
y0: ODEState,
|
572
|
-
t: TimeHorizon,
|
573
|
-
*args,
|
574
|
-
num_sub_steps: int = 1,
|
575
|
-
return_aux: bool = False
|
576
|
-
) -> Union[ODEState, Tuple[ODEState, Dict[str, Any]]]:
|
577
|
-
"""
|
578
|
-
Integrate a system of ODEs using the Semi-Implicit Euler method.
|
579
|
-
|
580
|
-
Args:
|
581
|
-
func: A function that computes the time-derivative of the state.
|
582
|
-
y0: The initial state as ODEState object.
|
583
|
-
t: The vector of time instants of the integration horizon.
|
584
|
-
*args: Additional arguments to be passed to the function func.
|
585
|
-
num_sub_steps: The number of sub-steps to be performed within each integration step.
|
586
|
-
return_aux: Whether to return the auxiliary data produced by the integrator.
|
587
|
-
|
588
|
-
Returns:
|
589
|
-
The state of the system at the end of the integration horizon as ODEState object,
|
590
|
-
and optionally the auxiliary data produced by the integrator.
|
591
|
-
"""
|
592
|
-
|
593
|
-
# Close func over additional inputs and parameters
|
594
|
-
dx_dt_closure_aux = lambda x, ts: func(x, ts, *args)
|
595
|
-
|
596
|
-
# Close one-step integration over its arguments
|
597
|
-
integrator_single_step = lambda t0, tf, x0: odeint_euler_semi_implicit_one_step(
|
598
|
-
dx_dt=dx_dt_closure_aux, x0=x0, t0=t0, tf=tf, num_sub_steps=num_sub_steps
|
599
|
-
)
|
600
|
-
|
601
|
-
# Integrate the state and compute optional auxiliary data over the horizon
|
602
|
-
out, aux = integrate_single_step_over_horizon(
|
603
|
-
integrator_single_step=integrator_single_step, t=t, x0=y0
|
604
|
-
)
|
605
|
-
|
606
|
-
return (out, aux) if return_aux else out
|
607
|
-
|
608
|
-
|
609
|
-
def odeint_rk4(
|
350
|
+
def odeint(
|
610
351
|
func,
|
611
352
|
y0: State,
|
612
353
|
t: TimeHorizon,
|
613
354
|
*args,
|
614
355
|
num_sub_steps: int = 1,
|
615
|
-
return_aux: bool = False
|
616
|
-
|
356
|
+
return_aux: bool = False,
|
357
|
+
integrator_type: IntegratorType = None,
|
358
|
+
):
|
617
359
|
"""
|
618
|
-
Integrate a system of ODEs
|
360
|
+
Integrate a system of ODEs with a fixed-step integrator.
|
619
361
|
|
620
362
|
Args:
|
621
363
|
func: A function that computes the time-derivative of the state.
|
@@ -634,8 +376,13 @@ def odeint_rk4(
|
|
634
376
|
dx_dt_closure = lambda x, ts: func(x, ts, *args)
|
635
377
|
|
636
378
|
# Close one-step integration over its arguments
|
637
|
-
integrator_single_step = lambda t0, tf, x0:
|
638
|
-
dx_dt=dx_dt_closure,
|
379
|
+
integrator_single_step = lambda t0, tf, x0: integrator_fixed_single_step(
|
380
|
+
dx_dt=dx_dt_closure,
|
381
|
+
x0=x0,
|
382
|
+
t0=t0,
|
383
|
+
tf=tf,
|
384
|
+
num_sub_steps=num_sub_steps,
|
385
|
+
integrator_type=integrator_type,
|
639
386
|
)
|
640
387
|
|
641
388
|
# Integrate the state and compute optional auxiliary data over the horizon
|
@@ -10,21 +10,7 @@ from jaxsim.physics.algos.soft_contacts import SoftContactsParams
|
|
10
10
|
from jaxsim.physics.algos.terrain import FlatTerrain, Terrain
|
11
11
|
from jaxsim.physics.model.physics_model import PhysicsModel
|
12
12
|
from jaxsim.simulation import integrators, ode
|
13
|
-
|
14
|
-
|
15
|
-
class IntegratorType(enum.IntEnum):
|
16
|
-
RungeKutta4 = enum.auto()
|
17
|
-
EulerForward = enum.auto()
|
18
|
-
EulerSemiImplicit = enum.auto()
|
19
|
-
EulerSemiImplicitManifold = enum.auto()
|
20
|
-
|
21
|
-
|
22
|
-
_integrator_registry = {
|
23
|
-
IntegratorType.RungeKutta4: integrators.odeint_rk4,
|
24
|
-
IntegratorType.EulerForward: integrators.odeint_euler,
|
25
|
-
IntegratorType.EulerSemiImplicit: integrators.odeint_euler_semi_implicit,
|
26
|
-
IntegratorType.EulerSemiImplicitManifold: integrators.odeint_euler_semi_implicit_manifold_one_step,
|
27
|
-
}
|
13
|
+
from jaxsim.simulation.integrators import IntegratorType
|
28
14
|
|
29
15
|
|
30
16
|
@jax.jit
|
@@ -62,12 +48,13 @@ def ode_integration_fixed_step(
|
|
62
48
|
)
|
63
49
|
|
64
50
|
# Integrate over the horizon
|
65
|
-
out =
|
51
|
+
out = integrators.odeint(
|
66
52
|
func=dx_dt_closure,
|
67
53
|
y0=x0,
|
68
54
|
t=t,
|
69
55
|
num_sub_steps=num_sub_steps,
|
70
56
|
return_aux=return_aux,
|
57
|
+
integrator_type=integrator_type,
|
71
58
|
)
|
72
59
|
|
73
60
|
# Return output pytree and, optionally, the aux dict
|
@@ -1,5 +1,5 @@
|
|
1
1
|
jaxsim/__init__.py,sha256=LJhCG4rsmCrTKTocwRIvllPQeYTxDn-VFn6NjPngn4s,1877
|
2
|
-
jaxsim/_version.py,sha256=
|
2
|
+
jaxsim/_version.py,sha256=ea7HN3P3BTsbrEcL4jflFM4NRajTns0RZu9M7F_bdGw,421
|
3
3
|
jaxsim/logging.py,sha256=c4zhwBKf9eAYAHVp62kTEllqdsZgh0K-kPKVy8L3elU,1584
|
4
4
|
jaxsim/typing.py,sha256=ErTscpEljFyrhPCisZnLEUt6FWLAuEAh-72Teb8Nz98,626
|
5
5
|
jaxsim/high_level/__init__.py,sha256=aWYBCsYmEO76Qt4GEi91Hye_ifGFLvc_bpy9OQplz2o,69
|
@@ -49,10 +49,10 @@ jaxsim/physics/model/ground_contact.py,sha256=mva-yDzYHREmgUu8jGJmIAsf66_SF6ZISm
|
|
49
49
|
jaxsim/physics/model/physics_model.py,sha256=kVTIaJQrxALzyWjWrDLnwDOcxmzaPGSpUOS8BCq-g6M,13249
|
50
50
|
jaxsim/physics/model/physics_model_state.py,sha256=LTC-uqUCP1-7-mLHMa6aY4xfBYWuHIextxDH0EEqEmE,5729
|
51
51
|
jaxsim/simulation/__init__.py,sha256=WOWkzq7rMGa4xWvjNqTYtD0Nl4yLQtULGW1xU7hD9m0,182
|
52
|
-
jaxsim/simulation/integrators.py,sha256=
|
52
|
+
jaxsim/simulation/integrators.py,sha256=WIlL7xi4UocSlWg4Qms8-6puqRYnK5A4r7TJUNPg5g0,13022
|
53
53
|
jaxsim/simulation/ode.py,sha256=ntq_iQPIw3SHj64CZWD2mHAKmt05ZgRpw2UwyTxHDOQ,10380
|
54
54
|
jaxsim/simulation/ode_data.py,sha256=spzHU5LnOL6mJPuuhho-J61koT-bcTRonqMMkiPo3M4,1750
|
55
|
-
jaxsim/simulation/ode_integration.py,sha256=
|
55
|
+
jaxsim/simulation/ode_integration.py,sha256=VDprQYoHEE_iI7ia1Mm3RyYl-LRvHU8dJEvRoGA4TFA,1947
|
56
56
|
jaxsim/simulation/simulator.py,sha256=qCI5QG0WKkBC5GNqauSvI7rSlGD7CLttTzCgLED7iJM,18123
|
57
57
|
jaxsim/simulation/simulator_callbacks.py,sha256=QWdY7dilmjrxeieWCB6RQ-cWpwLuUOK8fYWXpnnBcyU,2217
|
58
58
|
jaxsim/simulation/utils.py,sha256=YdNA1mYGBAE7xVA-Dw7_OoBEuh0J8RS2X0RPQZf4c5E,329
|
@@ -62,8 +62,8 @@ jaxsim/utils/jaxsim_dataclass.py,sha256=FbjfEoCoYC_F-M3wUggXiEhQ7MMS-V_ciYQca-uS
|
|
62
62
|
jaxsim/utils/oop.py,sha256=LQhBXkSOD0zgYNJLO7Bl0FPRg-LvtvPzxyQa1WFP0rM,22616
|
63
63
|
jaxsim/utils/tracing.py,sha256=KDMoyVPlu2NJvFkhtZwq5AkqMMgajt3munvJom-vEjQ,650
|
64
64
|
jaxsim/utils/vmappable.py,sha256=NqGL9nGFRI5OorCfnjXsjR_yXigzDxL0lW1YhQ_nMTY,3655
|
65
|
-
jaxsim-0.2.
|
66
|
-
jaxsim-0.2.
|
67
|
-
jaxsim-0.2.
|
68
|
-
jaxsim-0.2.
|
69
|
-
jaxsim-0.2.
|
65
|
+
jaxsim-0.2.dev77.dist-info/LICENSE,sha256=EsU2z6_sWW4Zduzq3goVWjZoCZVKQsM4H_y0o7oRA7Q,1547
|
66
|
+
jaxsim-0.2.dev77.dist-info/METADATA,sha256=zOGN2gEgqKAs7iRkr9rjGcSsZWUS90wnoekgcZCgaKs,7486
|
67
|
+
jaxsim-0.2.dev77.dist-info/WHEEL,sha256=oiQVh_5PnQM0E3gPdiz09WCNmwiHDMaGer_elqB3coM,92
|
68
|
+
jaxsim-0.2.dev77.dist-info/top_level.txt,sha256=LxGMA8FLtXjQ6oI7N5gd_R_oSUHxpXxUEOfT1xS_ni0,7
|
69
|
+
jaxsim-0.2.dev77.dist-info/RECORD,,
|
File without changes
|
File without changes
|
File without changes
|