jaxsim 0.2.dev56__py3-none-any.whl → 0.2.dev77__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,4 +1,5 @@
1
- from typing import Any, Callable, Dict, Tuple, Union
1
+ import enum
2
+ from typing import Any, Callable
2
3
 
3
4
  import jax
4
5
  import jax.numpy as jnp
@@ -19,30 +20,39 @@ State = jtp.PyTree
19
20
  StateDerivative = jtp.PyTree
20
21
 
21
22
  StateDerivativeCallable = Callable[
22
- [State, Time], Tuple[StateDerivative, Dict[str, Any]]
23
+ [State, Time], tuple[StateDerivative, dict[str, Any]]
23
24
  ]
24
25
 
25
26
 
27
+ class IntegratorType(enum.IntEnum):
28
+ RungeKutta4 = enum.auto()
29
+ EulerForward = enum.auto()
30
+ EulerSemiImplicit = enum.auto()
31
+ EulerSemiImplicitManifold = enum.auto()
32
+
33
+
26
34
  # =======================
27
35
  # Single-step integration
28
36
  # =======================
29
37
 
30
38
 
31
- def odeint_euler_one_step(
39
+ def integrator_fixed_single_step(
32
40
  dx_dt: StateDerivativeCallable,
33
- x0: State,
41
+ x0: State | ODEState,
34
42
  t0: Time,
35
43
  tf: Time,
44
+ integrator_type: IntegratorType,
36
45
  num_sub_steps: int = 1,
37
- ) -> Tuple[State, Dict[str, Any]]:
46
+ ) -> tuple[State | ODEState, dict[str, Any]]:
38
47
  """
39
- Forward Euler integrator.
48
+ Advance a state vector by integrating a sytem dynamics with a fixed-step integrator.
40
49
 
41
50
  Args:
42
51
  dx_dt: Callable that computes the state derivative.
43
52
  x0: Initial state.
44
53
  t0: Initial time.
45
54
  tf: Final time.
55
+ integrator_type: Integrator type.
46
56
  num_sub_steps: Number of sub-steps to break the integration into.
47
57
 
48
58
  Returns:
@@ -55,10 +65,14 @@ def odeint_euler_one_step(
55
65
  sub_step_dt = dt / num_sub_steps
56
66
 
57
67
  # Initialize the carry
58
- Carry = Tuple[State, Time]
68
+ Carry = tuple[State | ODEState, Time]
59
69
  carry_init: Carry = (x0, t0)
60
70
 
61
- def body_fun(carry: Carry, xs: None) -> Tuple[Carry, None]:
71
+ def forward_euler_body_fun(carry: Carry, xs: None) -> tuple[Carry, None]:
72
+ """
73
+ Forward Euler integrator.
74
+ """
75
+
62
76
  # Unpack the carry
63
77
  x_t0, t0 = carry
64
78
 
@@ -78,48 +92,11 @@ def odeint_euler_one_step(
78
92
 
79
93
  return carry, None
80
94
 
81
- # Integrate over the given horizon
82
- (x_tf, _), _ = jax.lax.scan(
83
- f=body_fun, init=carry_init, xs=None, length=num_sub_steps
84
- )
95
+ def rk4_body_fun(carry: Carry, xs: None) -> tuple[Carry, None]:
96
+ """
97
+ Runge-Kutta 4 integrator.
98
+ """
85
99
 
86
- # Compute the aux dictionary at t0
87
- _, aux_t0 = dx_dt(x0, t0)
88
-
89
- return x_tf, aux_t0
90
-
91
-
92
- def odeint_rk4_one_step(
93
- dx_dt: StateDerivativeCallable,
94
- x0: State,
95
- t0: Time,
96
- tf: Time,
97
- num_sub_steps: int = 1,
98
- ) -> Tuple[State, Dict[str, Any]]:
99
- """
100
- Runge-Kutta 4 integrator.
101
-
102
- Args:
103
- dx_dt: Callable that computes the state derivative.
104
- x0: Initial state.
105
- t0: Initial time.
106
- tf: Final time.
107
- num_sub_steps: Number of sub-steps to break the integration into.
108
-
109
- Returns:
110
- The final state and a dictionary including auxiliary data at t0.
111
- """
112
-
113
- # Compute the sub-step size.
114
- # We break dt in configurable sub-steps.
115
- dt = tf - t0
116
- sub_step_dt = dt / num_sub_steps
117
-
118
- # Initialize the carry
119
- Carry = Tuple[State, Time]
120
- carry_init: Carry = (x0, t0)
121
-
122
- def body_fun(carry: Carry, xs: None) -> Tuple[Carry, None]:
123
100
  # Unpack the carry
124
101
  x_t0, t0 = carry
125
102
 
@@ -148,49 +125,11 @@ def odeint_rk4_one_step(
148
125
 
149
126
  return carry, None
150
127
 
151
- # Integrate over the given horizon
152
- (x_tf, _), _ = jax.lax.scan(
153
- f=body_fun, init=carry_init, xs=None, length=num_sub_steps
154
- )
155
-
156
- # Compute the aux dictionary at t0
157
- _, aux_t0 = dx_dt(x0, t0)
158
-
159
- return x_tf, aux_t0
160
-
161
-
162
- def odeint_euler_semi_implicit_one_step(
163
- dx_dt: StateDerivativeCallable,
164
- x0: ODEState,
165
- t0: Time,
166
- tf: Time,
167
- num_sub_steps: int = 1,
168
- ) -> Tuple[ODEState, Dict[str, Any]]:
169
- """
170
- Semi-implicit Euler integrator.
171
-
172
- Args:
173
- dx_dt: Callable that computes the state derivative.
174
- x0: Initial state as ODEState object.
175
- t0: Initial time.
176
- tf: Final time.
177
- num_sub_steps: Number of sub-steps to break the integration into.
178
-
179
- Returns:
180
- A tuple having as first element the final state as ODEState object,
181
- and as second element a dictionary including auxiliary data at t0.
182
- """
128
+ def semi_implicit_euler_body_fun(carry: Carry, xs: None) -> tuple[Carry, None]:
129
+ """
130
+ Semi-implicit Euler integrator.
131
+ """
183
132
 
184
- # Compute the sub-step size.
185
- # We break dt in configurable sub-steps.
186
- dt = tf - t0
187
- sub_step_dt = dt / num_sub_steps
188
-
189
- # Initialize the carry
190
- Carry = Tuple[ODEState, Time]
191
- carry_init: Carry = (x0, t0)
192
-
193
- def body_fun(carry: Carry, xs: None) -> Tuple[Carry, None]:
194
133
  # Unpack the carry
195
134
  x_t0, t0 = carry
196
135
 
@@ -218,6 +157,7 @@ def odeint_euler_semi_implicit_one_step(
218
157
  # 2. Compute the derivative of the generalized position
219
158
  # 3. Integrate the implicit velocities
220
159
  # 4. Integrate the remaining state
160
+ # 5. Outside the loop: integrate the quaternion on SO(3) manifold
221
161
 
222
162
  # ----------------------------------------------------------------
223
163
  # 1. Integrate the accelerations obtaining the implicit velocities
@@ -254,13 +194,27 @@ def odeint_euler_semi_implicit_one_step(
254
194
  BW_vl_WB = (BW_Xv_W @ W_v_WB)[0:3]
255
195
 
256
196
  # Compute the derivative of the generalized position
257
- d_pos_tf = jnp.hstack([BW_vl_WB, W_Qd_B, vel_tf[6:]])
197
+ d_pos_tf = (
198
+ jnp.hstack([BW_vl_WB, vel_tf[6:]])
199
+ if integrator_type is IntegratorType.EulerSemiImplicitManifold
200
+ else jnp.hstack([BW_vl_WB, W_Qd_B, vel_tf[6:]])
201
+ )
258
202
 
259
203
  # ------------------------------------
260
204
  # 3. Integrate the implicit velocities
261
205
  # ------------------------------------
262
206
 
263
207
  pos_tf = pos_t0 + sub_step_dt * d_pos_tf
208
+ joint_positions = (
209
+ pos_tf[3:]
210
+ if integrator_type is IntegratorType.EulerSemiImplicitManifold
211
+ else pos_tf[7:]
212
+ )
213
+ base_quaternion = (
214
+ jnp.zeros_like(x_t0.base_quaternion)
215
+ if integrator_type is IntegratorType.EulerSemiImplicitManifold
216
+ else pos_tf[3:7]
217
+ )
264
218
 
265
219
  # ---------------------------------
266
220
  # 4. Integrate the remaining state
@@ -275,8 +229,8 @@ def odeint_euler_semi_implicit_one_step(
275
229
  x_tf = ODEState(
276
230
  physics_model=PhysicsModelState(
277
231
  base_position=pos_tf[0:3],
278
- base_quaternion=pos_tf[3:7],
279
- joint_positions=pos_tf[7:],
232
+ base_quaternion=base_quaternion,
233
+ joint_positions=joint_positions,
280
234
  base_linear_velocity=vel_tf[0:3],
281
235
  base_angular_velocity=vel_tf[3:6],
282
236
  joint_velocities=vel_tf[6:],
@@ -294,176 +248,43 @@ def odeint_euler_semi_implicit_one_step(
294
248
 
295
249
  return carry, None
296
250
 
297
- # Integrate over the given horizon
298
- (x_tf, _), _ = jax.lax.scan(
299
- f=body_fun, init=carry_init, xs=None, length=num_sub_steps
300
- )
301
-
302
- # Compute the aux dictionary at t0
303
- _, aux_t0 = dx_dt(x0, t0)
304
-
305
- return x_tf, aux_t0
306
-
307
-
308
- def odeint_euler_semi_implicit_manifold_one_step(
309
- dx_dt: StateDerivativeCallable,
310
- x0: ODEState,
311
- t0: Time,
312
- tf: Time,
313
- num_sub_steps: int = 1,
314
- ) -> Tuple[ODEState, Dict[str, Any]]:
315
- """
316
- Semi-implicit Euler integrator with quaternion integration on SO(3).
317
-
318
- Args:
319
- dx_dt: Callable that computes the state derivative.
320
- x0: Initial state as ODEState object.
321
- t0: Initial time.
322
- tf: Final time.
323
- num_sub_steps: Number of sub-steps to break the integration into.
324
-
325
- Returns:
326
- A tuple having as first element the final state as ODEState object,
327
- and as second element a dictionary including auxiliary data at t0.
328
- """
329
-
330
- # Compute the sub-step size.
331
- # We break dt in configurable sub-steps.
332
- dt = tf - t0
333
- sub_step_dt = dt / num_sub_steps
334
-
335
- # Integrate the quaternion on its manifold using the new angular velocity
336
-
337
- # Initialize the carry
338
- Carry = Tuple[ODEState, Time]
339
- carry_init: Carry = (x0, t0)
340
-
341
- def body_fun(carry: Carry, xs: None) -> Tuple[Carry, None]:
342
- # Unpack the carry
343
- x_t0, t0 = carry
344
-
345
- # Compute the state derivative.
346
- # We only keep the quantities related to the acceleration and discard those
347
- # related to the velocity since we are going to use those implicitly integrated
348
- # from the accelerations.
349
- StateDerivative = ODEState
350
- dxdt_t0: StateDerivative = dx_dt(x_t0, t0)[0]
351
-
352
- # Extract the initial position ∈ ℝ⁷⁺ⁿ and initial velocity ∈ ℝ⁶⁺ⁿ.
353
- # This integrator, contrarily to most of the other ones, is not generic.
354
- # It expects to operate on an x object of class ODEState.
355
- pos_t0 = x_t0.physics_model.position()
356
- vel_t0 = x_t0.physics_model.velocity()
357
-
358
- # Extract the velocity derivative
359
- d_vel_dt = dxdt_t0.physics_model.velocity()
360
-
361
- # =============================================
362
- # Perform semi-implicit Euler integration [1-4]
363
- # =============================================
251
+ _integrator_registry = {
252
+ IntegratorType.RungeKutta4: rk4_body_fun,
253
+ IntegratorType.EulerForward: forward_euler_body_fun,
254
+ IntegratorType.EulerSemiImplicit: semi_implicit_euler_body_fun,
255
+ IntegratorType.EulerSemiImplicitManifold: semi_implicit_euler_body_fun,
256
+ }
364
257
 
365
- # 1. Integrate the accelerations obtaining the implicit velocities
366
- # 2. Compute the derivative of the generalized position (w/o quaternion)
367
- # 3. Integrate the implicit velocities (w/o quaternion)
368
- # 4. Integrate the remaining state
369
- # 5. Outside the loop: integrate the quaternion on SO(3) manifold
370
-
371
- # ----------------------------------------------------------------
372
- # 1. Integrate the accelerations obtaining the implicit velocities
373
- # ----------------------------------------------------------------
374
-
375
- vel_tf = vel_t0 + sub_step_dt * d_vel_dt
376
-
377
- # ----------------------------------------------------------------------
378
- # 2. Compute the derivative of the generalized position (w/o quaternion)
379
- # ----------------------------------------------------------------------
380
-
381
- # Compute the transform of the mixed base frame at t0
382
- W_H_BW = jnp.vstack(
383
- [
384
- jnp.block([jnp.eye(3), jnp.vstack(x_t0.physics_model.base_position)]),
385
- jnp.array([0, 0, 0, 1]),
386
- ]
387
- )
388
-
389
- # The derivative W_ṗ_B of the base position is the linear component of the
390
- # mixed velocity B[W]_v_WB. We need to compute it from the velocity in
391
- # inertial-fixed representation W_vl_WB.
392
- W_v_WB = vel_tf[0:6]
393
- BW_Xv_W = se3.SE3.from_matrix(W_H_BW).inverse().adjoint()
394
- BW_vl_WB = (BW_Xv_W @ W_v_WB)[0:3]
395
-
396
- # Compute the derivative of the generalized position excluding the quaternion
397
- pos_no_quat_t0 = jnp.hstack([pos_t0[0:3], pos_t0[7:]])
398
- d_pos_no_quat_tf = jnp.hstack([BW_vl_WB, vel_tf[6:]])
399
-
400
- # -----------------------------------------------------
401
- # 3. Integrate the implicit velocities (w/o quaternion)
402
- # -----------------------------------------------------
403
-
404
- pos_no_quat_tf = pos_no_quat_t0 + sub_step_dt * d_pos_no_quat_tf
405
-
406
- # ---------------------------------
407
- # 4. Integrate the remaining state
408
- # ---------------------------------
409
-
410
- # Integrate the derivative of the tangential material deformation
411
- m = x_t0.soft_contacts.tangential_deformation
412
- ṁ = dxdt_t0.soft_contacts.tangential_deformation
413
- tangential_deformation_tf = m + sub_step_dt * ṁ
414
-
415
- # Pack the new state into an ODEState object.
416
- # We store a zero quaternion as placeholder, it will be replaced later.
417
- x_tf = ODEState(
418
- physics_model=PhysicsModelState(
419
- base_position=pos_no_quat_tf[0:3],
420
- base_quaternion=jnp.zeros_like(x_t0.physics_model.base_quaternion),
421
- joint_positions=pos_no_quat_tf[3:],
422
- base_linear_velocity=vel_tf[0:3],
423
- base_angular_velocity=vel_tf[3:6],
424
- joint_velocities=vel_tf[6:],
425
- ),
426
- soft_contacts=SoftContactsState(
427
- tangential_deformation=tangential_deformation_tf
428
- ),
429
- )
430
-
431
- # Update the time
432
- tf = t0 + sub_step_dt
433
-
434
- # Pack the carry
435
- carry = (x_tf, tf)
436
-
437
- return carry, None
258
+ # Get the body function for the selected integrator
259
+ body_fun = _integrator_registry[integrator_type]
438
260
 
439
261
  # Integrate over the given horizon
440
- (x_no_quat_tf, _), _ = jax.lax.scan(
262
+ (x_tf, _), _ = jax.lax.scan(
441
263
  f=body_fun, init=carry_init, xs=None, length=num_sub_steps
442
264
  )
443
265
 
444
- # ---------------------------------------------
445
- # 5. Integrate the quaternion on SO(3) manifold
446
- # ---------------------------------------------
447
-
448
- # Indices to convert quaternions between serializations
449
- to_xyzw = jnp.array([1, 2, 3, 0])
450
- to_wxyz = jnp.array([3, 0, 1, 2])
266
+ if integrator_type is IntegratorType.EulerSemiImplicitManifold:
267
+ # Indices to convert quaternions between serializations
268
+ to_xyzw = jnp.array([1, 2, 3, 0])
269
+ to_wxyz = jnp.array([3, 0, 1, 2])
451
270
 
452
- # Get the initial quaternion and the implicitly integrated angular velocity
453
- W_ω_WB_tf = x_no_quat_tf.physics_model.base_angular_velocity
454
- W_Q_B_t0 = so3.SO3.from_quaternion_xyzw(x0.physics_model.base_quaternion[to_xyzw])
271
+ # Get the initial quaternion and the implicitly integrated angular velocity
272
+ W_ω_WB_tf = x_tf.physics_model.base_angular_velocity
273
+ W_Q_B_t0 = so3.SO3.from_quaternion_xyzw(
274
+ x0.physics_model.base_quaternion[to_xyzw]
275
+ )
455
276
 
456
- # Integrate the quaternion on its manifold using the implicit angular velocity,
457
- # transformed in body-fixed representation since jaxlie uses this convention
458
- B_R_W = W_Q_B_t0.inverse().as_matrix()
459
- W_Q_B_tf = W_Q_B_t0 @ so3.SO3.exp(tangent=dt * B_R_W @ W_ω_WB_tf)
277
+ # Integrate the quaternion on its manifold using the implicit angular velocity,
278
+ # transformed in body-fixed representation since jaxlie uses this convention
279
+ B_R_W = W_Q_B_t0.inverse().as_matrix()
280
+ W_Q_B_tf = W_Q_B_t0 @ so3.SO3.exp(tangent=dt * B_R_W @ W_ω_WB_tf)
460
281
 
461
- # Store the quaternion in the final state
462
- x_tf = x_no_quat_tf.replace(
463
- physics_model=x_no_quat_tf.physics_model.replace(
464
- base_quaternion=W_Q_B_tf.as_quaternion_xyzw()[to_wxyz]
282
+ # Store the quaternion in the final state
283
+ x_tf = x_tf.replace(
284
+ physics_model=x_tf.physics_model.replace(
285
+ base_quaternion=W_Q_B_tf.as_quaternion_xyzw()[to_wxyz]
286
+ )
465
287
  )
466
- )
467
288
 
468
289
  # Compute the aux dictionary at t0
469
290
  _, aux_t0 = dx_dt(x0, t0)
@@ -477,10 +298,10 @@ def odeint_euler_semi_implicit_manifold_one_step(
477
298
 
478
299
 
479
300
  def integrate_single_step_over_horizon(
480
- integrator_single_step: Callable[[Time, Time, State], Tuple[State, Dict[str, Any]]],
301
+ integrator_single_step: Callable[[Time, Time, State], tuple[State, dict[str, Any]]],
481
302
  t: TimeHorizon,
482
303
  x0: State,
483
- ) -> Tuple[State, Dict[str, Any]]:
304
+ ) -> tuple[State, dict[str, Any]]:
484
305
  """
485
306
  Integrate a single-step integrator over a given horizon.
486
307
 
@@ -496,7 +317,7 @@ def integrate_single_step_over_horizon(
496
317
  # Initialize the carry
497
318
  carry_init = (x0, t)
498
319
 
499
- def body_fun(carry: Tuple, idx: int) -> Tuple[Tuple, jtp.PyTree]:
320
+ def body_fun(carry: tuple, idx: int) -> tuple[tuple, jtp.PyTree]:
500
321
  # Unpack the carry
501
322
  x_t0, horizon = carry
502
323
 
@@ -526,96 +347,17 @@ def integrate_single_step_over_horizon(
526
347
  # ===================================================================
527
348
 
528
349
 
529
- def odeint_euler(
530
- func,
531
- y0: State,
532
- t: TimeHorizon,
533
- *args,
534
- num_sub_steps: int = 1,
535
- return_aux: bool = False
536
- ) -> Union[State, Tuple[State, Dict[str, Any]]]:
537
- """
538
- Integrate a system of ODEs using the Euler method.
539
-
540
- Args:
541
- func: A function that computes the time-derivative of the state.
542
- y0: The initial state.
543
- t: The vector of time instants of the integration horizon.
544
- *args: Additional arguments to be passed to the function func.
545
- num_sub_steps: The number of sub-steps to be performed within each integration step.
546
- return_aux: Whether to return the auxiliary data produced by the integrator.
547
-
548
- Returns:
549
- The state of the system at the end of the integration horizon, and optionally
550
- the auxiliary data produced by the integrator.
551
- """
552
-
553
- # Close func over additional inputs and parameters
554
- dx_dt_closure_aux = lambda x, ts: func(x, ts, *args)
555
-
556
- # Close one-step integration over its arguments
557
- integrator_single_step = lambda t0, tf, x0: odeint_euler_one_step(
558
- dx_dt=dx_dt_closure_aux, x0=x0, t0=t0, tf=tf, num_sub_steps=num_sub_steps
559
- )
560
-
561
- # Integrate the state and compute optional auxiliary data over the horizon
562
- out, aux = integrate_single_step_over_horizon(
563
- integrator_single_step=integrator_single_step, t=t, x0=y0
564
- )
565
-
566
- return (out, aux) if return_aux else out
567
-
568
-
569
- def odeint_euler_semi_implicit(
570
- func,
571
- y0: ODEState,
572
- t: TimeHorizon,
573
- *args,
574
- num_sub_steps: int = 1,
575
- return_aux: bool = False
576
- ) -> Union[ODEState, Tuple[ODEState, Dict[str, Any]]]:
577
- """
578
- Integrate a system of ODEs using the Semi-Implicit Euler method.
579
-
580
- Args:
581
- func: A function that computes the time-derivative of the state.
582
- y0: The initial state as ODEState object.
583
- t: The vector of time instants of the integration horizon.
584
- *args: Additional arguments to be passed to the function func.
585
- num_sub_steps: The number of sub-steps to be performed within each integration step.
586
- return_aux: Whether to return the auxiliary data produced by the integrator.
587
-
588
- Returns:
589
- The state of the system at the end of the integration horizon as ODEState object,
590
- and optionally the auxiliary data produced by the integrator.
591
- """
592
-
593
- # Close func over additional inputs and parameters
594
- dx_dt_closure_aux = lambda x, ts: func(x, ts, *args)
595
-
596
- # Close one-step integration over its arguments
597
- integrator_single_step = lambda t0, tf, x0: odeint_euler_semi_implicit_one_step(
598
- dx_dt=dx_dt_closure_aux, x0=x0, t0=t0, tf=tf, num_sub_steps=num_sub_steps
599
- )
600
-
601
- # Integrate the state and compute optional auxiliary data over the horizon
602
- out, aux = integrate_single_step_over_horizon(
603
- integrator_single_step=integrator_single_step, t=t, x0=y0
604
- )
605
-
606
- return (out, aux) if return_aux else out
607
-
608
-
609
- def odeint_rk4(
350
+ def odeint(
610
351
  func,
611
352
  y0: State,
612
353
  t: TimeHorizon,
613
354
  *args,
614
355
  num_sub_steps: int = 1,
615
- return_aux: bool = False
616
- ) -> Union[State, Tuple[State, Dict[str, Any]]]:
356
+ return_aux: bool = False,
357
+ integrator_type: IntegratorType = None,
358
+ ):
617
359
  """
618
- Integrate a system of ODEs using the Runge-Kutta 4 method.
360
+ Integrate a system of ODEs with a fixed-step integrator.
619
361
 
620
362
  Args:
621
363
  func: A function that computes the time-derivative of the state.
@@ -634,8 +376,13 @@ def odeint_rk4(
634
376
  dx_dt_closure = lambda x, ts: func(x, ts, *args)
635
377
 
636
378
  # Close one-step integration over its arguments
637
- integrator_single_step = lambda t0, tf, x0: odeint_rk4_one_step(
638
- dx_dt=dx_dt_closure, x0=x0, t0=t0, tf=tf, num_sub_steps=num_sub_steps
379
+ integrator_single_step = lambda t0, tf, x0: integrator_fixed_single_step(
380
+ dx_dt=dx_dt_closure,
381
+ x0=x0,
382
+ t0=t0,
383
+ tf=tf,
384
+ num_sub_steps=num_sub_steps,
385
+ integrator_type=integrator_type,
639
386
  )
640
387
 
641
388
  # Integrate the state and compute optional auxiliary data over the horizon
@@ -10,21 +10,7 @@ from jaxsim.physics.algos.soft_contacts import SoftContactsParams
10
10
  from jaxsim.physics.algos.terrain import FlatTerrain, Terrain
11
11
  from jaxsim.physics.model.physics_model import PhysicsModel
12
12
  from jaxsim.simulation import integrators, ode
13
-
14
-
15
- class IntegratorType(enum.IntEnum):
16
- RungeKutta4 = enum.auto()
17
- EulerForward = enum.auto()
18
- EulerSemiImplicit = enum.auto()
19
- EulerSemiImplicitManifold = enum.auto()
20
-
21
-
22
- _integrator_registry = {
23
- IntegratorType.RungeKutta4: integrators.odeint_rk4,
24
- IntegratorType.EulerForward: integrators.odeint_euler,
25
- IntegratorType.EulerSemiImplicit: integrators.odeint_euler_semi_implicit,
26
- IntegratorType.EulerSemiImplicitManifold: integrators.odeint_euler_semi_implicit_manifold_one_step,
27
- }
13
+ from jaxsim.simulation.integrators import IntegratorType
28
14
 
29
15
 
30
16
  @jax.jit
@@ -62,12 +48,13 @@ def ode_integration_fixed_step(
62
48
  )
63
49
 
64
50
  # Integrate over the horizon
65
- out = _integrator_registry[integrator_type](
51
+ out = integrators.odeint(
66
52
  func=dx_dt_closure,
67
53
  y0=x0,
68
54
  t=t,
69
55
  num_sub_steps=num_sub_steps,
70
56
  return_aux=return_aux,
57
+ integrator_type=integrator_type,
71
58
  )
72
59
 
73
60
  # Return output pytree and, optionally, the aux dict
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: jaxsim
3
- Version: 0.2.dev56
3
+ Version: 0.2.dev77
4
4
  Summary: A physics engine in reduced coordinates implemented with JAX.
5
5
  Home-page: https://github.com/ami-iit/jaxsim
6
6
  Author: Diego Ferigo
@@ -45,6 +45,8 @@ Requires-Dist: pytest >=6.0 ; extra == 'all'
45
45
  Requires-Dist: pytest-forked ; extra == 'all'
46
46
  Requires-Dist: pytest-icdiff ; extra == 'all'
47
47
  Requires-Dist: robot-descriptions ; extra == 'all'
48
+ Requires-Dist: mediapy ; extra == 'all'
49
+ Requires-Dist: mujoco >=3.0.0 ; extra == 'all'
48
50
  Provides-Extra: style
49
51
  Requires-Dist: black[jupyter] ; extra == 'style'
50
52
  Requires-Dist: isort ; extra == 'style'
@@ -54,6 +56,9 @@ Requires-Dist: pytest >=6.0 ; extra == 'testing'
54
56
  Requires-Dist: pytest-forked ; extra == 'testing'
55
57
  Requires-Dist: pytest-icdiff ; extra == 'testing'
56
58
  Requires-Dist: robot-descriptions ; extra == 'testing'
59
+ Provides-Extra: viz
60
+ Requires-Dist: mediapy ; extra == 'viz'
61
+ Requires-Dist: mujoco >=3.0.0 ; extra == 'viz'
57
62
 
58
63
  # JAXsim
59
64
 
@@ -1,5 +1,5 @@
1
1
  jaxsim/__init__.py,sha256=LJhCG4rsmCrTKTocwRIvllPQeYTxDn-VFn6NjPngn4s,1877
2
- jaxsim/_version.py,sha256=oVeXz3qbT5dJihvLTHmporVSLumZxaoSdyqkd-ndrOc,421
2
+ jaxsim/_version.py,sha256=ea7HN3P3BTsbrEcL4jflFM4NRajTns0RZu9M7F_bdGw,421
3
3
  jaxsim/logging.py,sha256=c4zhwBKf9eAYAHVp62kTEllqdsZgh0K-kPKVy8L3elU,1584
4
4
  jaxsim/typing.py,sha256=ErTscpEljFyrhPCisZnLEUt6FWLAuEAh-72Teb8Nz98,626
5
5
  jaxsim/high_level/__init__.py,sha256=aWYBCsYmEO76Qt4GEi91Hye_ifGFLvc_bpy9OQplz2o,69
@@ -17,6 +17,11 @@ jaxsim/math/plucker.py,sha256=44NvKVbcZoG8ivFN1BeXxDpuSFdEre1Q6ZXvhnmIiPY,2282
17
17
  jaxsim/math/quaternion.py,sha256=ToyRnAWU0JvKSSSX2vaJeSw2lMa5BGU72DjtonOUw0k,3685
18
18
  jaxsim/math/rotation.py,sha256=MHOnrpS5Sf4rszhOpZ8w7qXFkEl7UMltYimqqsuYuuU,2187
19
19
  jaxsim/math/skew.py,sha256=oOGSSR8PUGROl6IJFlrmu6K3gPH-u16hUPfKIkcVv9o,1177
20
+ jaxsim/mujoco/__init__.py,sha256=1g0HCsI_h1sxVPY41JNItdTUynnUskHQZ3SpUaD_N5k,137
21
+ jaxsim/mujoco/__main__.py,sha256=GBmB7J-zj75ZnFyuAAmpSOpbxi_HhHhWJeot3ljGDJY,5291
22
+ jaxsim/mujoco/loaders.py,sha256=Cd9ee_hjKMHoVaHiJj8JADqpEVmkyYkqSTP1gBeXvFs,16047
23
+ jaxsim/mujoco/model.py,sha256=0kG2GERxjVFqWZ1K3352rgUNfchB4kRtIrsvv4pS4oc,10766
24
+ jaxsim/mujoco/visualizer.py,sha256=Z80qXqIxn5EDPbCk8pPc0Q1CdxV6fUHpA6fF8-rJFtg,4138
20
25
  jaxsim/parsers/__init__.py,sha256=sonYi-bBWAoB04kp1mxT4uIORxjb7SdZ0ukGPmVx98Y,44
21
26
  jaxsim/parsers/kinematic_graph.py,sha256=5wQnbzu8JE0bbnLRxK4ZsD_gQ9kbBpYbhUSzCMiNWko,23610
22
27
  jaxsim/parsers/descriptions/__init__.py,sha256=EbTfnrK3oCxA3pNv--YUwllJ6uICENvFgAdRbYtS9ts,238
@@ -44,10 +49,10 @@ jaxsim/physics/model/ground_contact.py,sha256=mva-yDzYHREmgUu8jGJmIAsf66_SF6ZISm
44
49
  jaxsim/physics/model/physics_model.py,sha256=kVTIaJQrxALzyWjWrDLnwDOcxmzaPGSpUOS8BCq-g6M,13249
45
50
  jaxsim/physics/model/physics_model_state.py,sha256=LTC-uqUCP1-7-mLHMa6aY4xfBYWuHIextxDH0EEqEmE,5729
46
51
  jaxsim/simulation/__init__.py,sha256=WOWkzq7rMGa4xWvjNqTYtD0Nl4yLQtULGW1xU7hD9m0,182
47
- jaxsim/simulation/integrators.py,sha256=w9hhQULUFyyyw3p5sFn1YPax1TmLJBbzfQw1iEAvq60,21884
52
+ jaxsim/simulation/integrators.py,sha256=WIlL7xi4UocSlWg4Qms8-6puqRYnK5A4r7TJUNPg5g0,13022
48
53
  jaxsim/simulation/ode.py,sha256=ntq_iQPIw3SHj64CZWD2mHAKmt05ZgRpw2UwyTxHDOQ,10380
49
54
  jaxsim/simulation/ode_data.py,sha256=spzHU5LnOL6mJPuuhho-J61koT-bcTRonqMMkiPo3M4,1750
50
- jaxsim/simulation/ode_integration.py,sha256=DdA09pQeP53AwYssWON4v9xcMU3LxlVdmW5_76e7Kv4,2373
55
+ jaxsim/simulation/ode_integration.py,sha256=VDprQYoHEE_iI7ia1Mm3RyYl-LRvHU8dJEvRoGA4TFA,1947
51
56
  jaxsim/simulation/simulator.py,sha256=qCI5QG0WKkBC5GNqauSvI7rSlGD7CLttTzCgLED7iJM,18123
52
57
  jaxsim/simulation/simulator_callbacks.py,sha256=QWdY7dilmjrxeieWCB6RQ-cWpwLuUOK8fYWXpnnBcyU,2217
53
58
  jaxsim/simulation/utils.py,sha256=YdNA1mYGBAE7xVA-Dw7_OoBEuh0J8RS2X0RPQZf4c5E,329
@@ -57,8 +62,8 @@ jaxsim/utils/jaxsim_dataclass.py,sha256=FbjfEoCoYC_F-M3wUggXiEhQ7MMS-V_ciYQca-uS
57
62
  jaxsim/utils/oop.py,sha256=LQhBXkSOD0zgYNJLO7Bl0FPRg-LvtvPzxyQa1WFP0rM,22616
58
63
  jaxsim/utils/tracing.py,sha256=KDMoyVPlu2NJvFkhtZwq5AkqMMgajt3munvJom-vEjQ,650
59
64
  jaxsim/utils/vmappable.py,sha256=NqGL9nGFRI5OorCfnjXsjR_yXigzDxL0lW1YhQ_nMTY,3655
60
- jaxsim-0.2.dev56.dist-info/LICENSE,sha256=EsU2z6_sWW4Zduzq3goVWjZoCZVKQsM4H_y0o7oRA7Q,1547
61
- jaxsim-0.2.dev56.dist-info/METADATA,sha256=e8-yDBqWNndop9IWGzRDb54bTbhtx1HldRGisKtZdl0,7292
62
- jaxsim-0.2.dev56.dist-info/WHEEL,sha256=oiQVh_5PnQM0E3gPdiz09WCNmwiHDMaGer_elqB3coM,92
63
- jaxsim-0.2.dev56.dist-info/top_level.txt,sha256=LxGMA8FLtXjQ6oI7N5gd_R_oSUHxpXxUEOfT1xS_ni0,7
64
- jaxsim-0.2.dev56.dist-info/RECORD,,
65
+ jaxsim-0.2.dev77.dist-info/LICENSE,sha256=EsU2z6_sWW4Zduzq3goVWjZoCZVKQsM4H_y0o7oRA7Q,1547
66
+ jaxsim-0.2.dev77.dist-info/METADATA,sha256=zOGN2gEgqKAs7iRkr9rjGcSsZWUS90wnoekgcZCgaKs,7486
67
+ jaxsim-0.2.dev77.dist-info/WHEEL,sha256=oiQVh_5PnQM0E3gPdiz09WCNmwiHDMaGer_elqB3coM,92
68
+ jaxsim-0.2.dev77.dist-info/top_level.txt,sha256=LxGMA8FLtXjQ6oI7N5gd_R_oSUHxpXxUEOfT1xS_ni0,7
69
+ jaxsim-0.2.dev77.dist-info/RECORD,,