jaxsim 0.1rc0__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (89):
  1. jaxsim/__init__.py +5 -6
  2. jaxsim/_version.py +2 -2
  3. jaxsim/api/__init__.py +3 -0
  4. jaxsim/api/com.py +240 -0
  5. jaxsim/api/common.py +216 -0
  6. jaxsim/api/contact.py +271 -0
  7. jaxsim/api/data.py +821 -0
  8. jaxsim/api/joint.py +189 -0
  9. jaxsim/api/kin_dyn_parameters.py +777 -0
  10. jaxsim/api/link.py +361 -0
  11. jaxsim/api/model.py +1633 -0
  12. jaxsim/api/ode.py +295 -0
  13. jaxsim/api/ode_data.py +694 -0
  14. jaxsim/api/references.py +421 -0
  15. jaxsim/integrators/__init__.py +2 -0
  16. jaxsim/integrators/common.py +594 -0
  17. jaxsim/integrators/fixed_step.py +102 -0
  18. jaxsim/integrators/variable_step.py +610 -0
  19. jaxsim/math/__init__.py +11 -0
  20. jaxsim/math/adjoint.py +24 -2
  21. jaxsim/math/joint_model.py +335 -0
  22. jaxsim/math/quaternion.py +44 -3
  23. jaxsim/math/rotation.py +4 -4
  24. jaxsim/math/transform.py +92 -0
  25. jaxsim/mujoco/__init__.py +3 -0
  26. jaxsim/mujoco/__main__.py +192 -0
  27. jaxsim/mujoco/loaders.py +615 -0
  28. jaxsim/mujoco/model.py +414 -0
  29. jaxsim/mujoco/visualizer.py +176 -0
  30. jaxsim/parsers/descriptions/collision.py +14 -0
  31. jaxsim/parsers/descriptions/link.py +13 -2
  32. jaxsim/parsers/kinematic_graph.py +8 -3
  33. jaxsim/parsers/rod/parser.py +54 -38
  34. jaxsim/parsers/rod/utils.py +7 -8
  35. jaxsim/rbda/__init__.py +7 -0
  36. jaxsim/rbda/aba.py +295 -0
  37. jaxsim/rbda/collidable_points.py +142 -0
  38. jaxsim/{physics/algos → rbda}/crba.py +43 -42
  39. jaxsim/rbda/forward_kinematics.py +113 -0
  40. jaxsim/rbda/jacobian.py +201 -0
  41. jaxsim/rbda/rnea.py +237 -0
  42. jaxsim/rbda/soft_contacts.py +296 -0
  43. jaxsim/rbda/utils.py +152 -0
  44. jaxsim/terrain/__init__.py +2 -0
  45. jaxsim/{physics/algos → terrain}/terrain.py +4 -6
  46. jaxsim/typing.py +30 -30
  47. jaxsim/utils/__init__.py +1 -4
  48. jaxsim/utils/hashless.py +18 -0
  49. jaxsim/utils/jaxsim_dataclass.py +281 -31
  50. {jaxsim-0.1rc0.dist-info → jaxsim-0.2.0.dist-info}/LICENSE +1 -1
  51. jaxsim-0.2.0.dist-info/METADATA +237 -0
  52. jaxsim-0.2.0.dist-info/RECORD +64 -0
  53. {jaxsim-0.1rc0.dist-info → jaxsim-0.2.0.dist-info}/WHEEL +1 -1
  54. jaxsim/high_level/__init__.py +0 -2
  55. jaxsim/high_level/common.py +0 -11
  56. jaxsim/high_level/joint.py +0 -148
  57. jaxsim/high_level/link.py +0 -259
  58. jaxsim/high_level/model.py +0 -1695
  59. jaxsim/math/conv.py +0 -114
  60. jaxsim/math/joint.py +0 -101
  61. jaxsim/math/plucker.py +0 -100
  62. jaxsim/physics/__init__.py +0 -12
  63. jaxsim/physics/algos/__init__.py +0 -0
  64. jaxsim/physics/algos/aba.py +0 -256
  65. jaxsim/physics/algos/aba_motors.py +0 -284
  66. jaxsim/physics/algos/forward_kinematics.py +0 -79
  67. jaxsim/physics/algos/jacobian.py +0 -98
  68. jaxsim/physics/algos/rnea.py +0 -180
  69. jaxsim/physics/algos/rnea_motors.py +0 -196
  70. jaxsim/physics/algos/soft_contacts.py +0 -454
  71. jaxsim/physics/algos/utils.py +0 -69
  72. jaxsim/physics/model/__init__.py +0 -0
  73. jaxsim/physics/model/ground_contact.py +0 -55
  74. jaxsim/physics/model/physics_model.py +0 -358
  75. jaxsim/physics/model/physics_model_state.py +0 -174
  76. jaxsim/simulation/__init__.py +0 -4
  77. jaxsim/simulation/integrators.py +0 -452
  78. jaxsim/simulation/ode.py +0 -290
  79. jaxsim/simulation/ode_data.py +0 -53
  80. jaxsim/simulation/ode_integration.py +0 -125
  81. jaxsim/simulation/simulator.py +0 -544
  82. jaxsim/simulation/simulator_callbacks.py +0 -53
  83. jaxsim/simulation/utils.py +0 -15
  84. jaxsim/sixd/__init__.py +0 -2
  85. jaxsim/utils/oop.py +0 -532
  86. jaxsim/utils/vmappable.py +0 -117
  87. jaxsim-0.1rc0.dist-info/METADATA +0 -167
  88. jaxsim-0.1rc0.dist-info/RECORD +0 -64
  89. {jaxsim-0.1rc0.dist-info → jaxsim-0.2.0.dist-info}/top_level.txt +0 -0
@@ -1,452 +0,0 @@
1
- from typing import Any, Callable, Dict, Tuple, Union
2
-
3
- import jax
4
- import jax.numpy as jnp
5
- from jax.tree_util import tree_map
6
-
7
- import jaxsim.typing as jtp
8
- from jaxsim.physics.algos.soft_contacts import SoftContactsState
9
- from jaxsim.physics.model.physics_model_state import PhysicsModelState
10
- from jaxsim.simulation.ode_data import ODEState
11
-
12
# Generic PyTree types for the integrated state and for its time derivative.
State = jtp.PyTree
StateDerivative = jtp.PyTree

# A single time instant, and the vector of time instants spanning a horizon.
Time = float
TimeHorizon = jtp.Vector

# Dynamics callable: (state, time) -> (state derivative, auxiliary data dict).
StateDerivativeCallable = Callable[[State, Time], Tuple[StateDerivative, Dict[str, Any]]]
21
-
22
-
23
- # =======================
24
- # Single-step integration
25
- # =======================
26
-
27
-
28
def odeint_euler_one_step(
    dx_dt: StateDerivativeCallable,
    x0: State,
    t0: Time,
    tf: Time,
    num_sub_steps: int = 1,
) -> Tuple[State, Dict[str, Any]]:
    """
    Advance a state from t0 to tf with the explicit (forward) Euler scheme.

    The interval [t0, tf] is split into ``num_sub_steps`` equal sub-intervals
    and one Euler update ``x <- x + h * dx/dt`` is applied on each of them.

    Args:
        dx_dt: Dynamics callable mapping ``(x, t)`` to ``(dx/dt, aux)``.
        x0: State at time t0.
        t0: Initial time of the interval.
        tf: Final time of the interval.
        num_sub_steps: Number of equal sub-steps used to cover [t0, tf].

    Returns:
        A tuple with the state at tf and the auxiliary dictionary returned by
        ``dx_dt`` evaluated at ``(x0, t0)``.
    """

    # Length of each sub-step.
    h = (tf - t0) / num_sub_steps

    def euler_update(state_and_time, _unused):
        x, t = state_and_time

        # Only the derivative is needed here; aux is discarded.
        xdot, _aux = dx_dt(x, t)

        # Leaf-wise explicit Euler update.
        x_next = jax.tree_util.tree_map(
            lambda leaf, leaf_dot: leaf + h * leaf_dot, x, xdot
        )

        return (x_next, t + h), None

    (x_final, _t_final), _ = jax.lax.scan(
        f=euler_update, init=(x0, t0), xs=None, length=num_sub_steps
    )

    # The auxiliary data is reported at the beginning of the interval.
    _, aux_t0 = dx_dt(x0, t0)

    return x_final, aux_t0
87
-
88
-
89
def odeint_rk4_one_step(
    dx_dt: StateDerivativeCallable,
    x0: State,
    t0: Time,
    tf: Time,
    num_sub_steps: int = 1,
) -> Tuple[State, Dict[str, Any]]:
    """
    Advance a state from t0 to tf with the classical 4th-order Runge-Kutta scheme.

    The interval [t0, tf] is split into ``num_sub_steps`` equal sub-intervals
    and one RK4 update is applied on each of them.

    Args:
        dx_dt: Dynamics callable mapping ``(x, t)`` to ``(dx/dt, aux)``.
        x0: State at time t0.
        t0: Initial time of the interval.
        tf: Final time of the interval.
        num_sub_steps: Number of equal sub-steps used to cover [t0, tf].

    Returns:
        A tuple with the state at tf and the auxiliary dictionary returned by
        ``dx_dt`` evaluated at ``(x0, t0)``.
    """

    # Length of each sub-step, and half of it for the midpoint slopes.
    h = (tf - t0) / num_sub_steps
    half_h = 0.5 * h

    def rk4_update(state_and_time, _unused):
        x, t = state_and_time

        def predict(slope, step):
            # Leaf-wise Euler prediction: x + step * slope.
            return tree_map(lambda leaf, leaf_dot: leaf + step * leaf_dot, x, slope)

        # Slopes at the start, twice at the midpoint, and at the end.
        k1, _ = dx_dt(x, t)
        k2, _ = dx_dt(predict(k1, half_h), t + half_h)
        k3, _ = dx_dt(predict(k2, half_h), t + half_h)
        k4, _ = dx_dt(predict(k3, h), t + h)

        # Weighted RK4 average of the four slopes.
        slope = tree_map(
            lambda a, b, c, d: (a + 2 * b + 2 * c + d) / 6, k1, k2, k3, k4
        )

        return (predict(slope, h), t + h), None

    (x_final, _t_final), _ = jax.lax.scan(
        f=rk4_update, init=(x0, t0), xs=None, length=num_sub_steps
    )

    # The auxiliary data is reported at the beginning of the interval.
    _, aux_t0 = dx_dt(x0, t0)

    return x_final, aux_t0
157
-
158
-
159
def odeint_euler_semi_implicit_one_step(
    dx_dt: StateDerivativeCallable,
    x0: ODEState,
    t0: Time,
    tf: Time,
    num_sub_steps: int = 1,
) -> Tuple[ODEState, Dict[str, Any]]:
    """
    Semi-implicit Euler integrator.

    On each sub-step, the generalized velocity is first integrated with the
    acceleration evaluated at the beginning of the sub-step. The generalized
    position is then integrated using the *updated* velocity. This applies to
    the base position, the base quaternion, and the joint positions alike.
    The soft-contacts tangential deformation is integrated with explicit Euler.

    Args:
        dx_dt: Callable that computes the state derivative, returning a tuple
            ``(dx/dt, aux)``.
        x0: Initial state.
        t0: Initial time.
        tf: Final time.
        num_sub_steps: Number of equal sub-steps to break the integration into.

    Returns:
        The final state and a dictionary including auxiliary data at t0.
    """

    # Compute the sub-step size: dt is broken into configurable sub-steps.
    dt = tf - t0
    sub_step_dt = dt / num_sub_steps

    # Initialize the carry.
    Carry = Tuple[ODEState, Time]
    carry_init: Carry = (x0, t0)

    def quaternion_derivative(W_Q_B: jtp.Vector, W_omega_WB: jtp.Vector) -> jtp.Vector:
        from jaxsim.math.quaternion import Quaternion

        # The angular velocity is passed as not body-fixed
        # (omega_in_body_fixed=False).
        return Quaternion.derivative(
            quaternion=W_Q_B, omega=W_omega_WB, omega_in_body_fixed=False
        ).squeeze()

    def inertial_to_3d_mixed(
        W_v_lin_WB: jtp.Vector, W_v_ang_WB: jtp.Vector, W_pos_B: jtp.Vector
    ) -> jtp.Vector:
        from jaxsim.math.conv import Convert

        # Compute the linear component of the mixed velocity BW_v_WB from the
        # 6D base velocity [linear; angular] and the base position.
        return Convert.velocities_threed(
            v_6d=jnp.hstack([W_v_lin_WB, W_v_ang_WB]), p=W_pos_B.squeeze()
        ).squeeze()

    def body_fun(carry: Carry, xs: None) -> Tuple[Carry, None]:
        # Unpack the carry.
        x_t0, t0 = carry

        # Extract the initial generalized position and velocity. The code below
        # assumes these layouts, matching the slicing used to rebuild the state:
        #   position: [base position (3), base quaternion (4), joint positions]
        #   velocity: [base linear (3), base angular (3), joint velocities]
        pos_t0 = x_t0.physics_model.position()
        vel_t0 = x_t0.physics_model.velocity()

        # Compute the state derivative (auxiliary data is discarded here).
        dxdt_t0: ODEState = dx_dt(x_t0, t0)[0]

        # Extract the velocity derivative.
        d_vel_dt = dxdt_t0.physics_model.velocity()

        # Perform semi-implicit Euler integration [1-4].

        # 1. Integrate the velocities.
        vel_tf = vel_t0 + sub_step_dt * d_vel_dt

        # 2. Compute the quaternion derivative and the base position derivative.
        #    Both use the *updated* velocity vel_tf, as the joint positions do in
        #    step 3. Using the velocity at t0 for the base position would mix
        #    explicit and semi-implicit Euler within the same step.
        W_Qd_B = quaternion_derivative(
            W_Q_B=x_t0.physics_model.base_quaternion, W_omega_WB=vel_tf[3:6]
        )
        BW_v_WB = inertial_to_3d_mixed(
            W_pos_B=x_t0.physics_model.base_position,
            W_v_lin_WB=vel_tf[0:3],
            W_v_ang_WB=vel_tf[3:6],
        )

        # 3. Compute the derivative of the position.
        posd_tf = jnp.hstack([BW_v_WB, W_Qd_B, vel_tf[6:]])

        # 4. Integrate the positions.
        # NOTE(review): the quaternion is not re-normalized after this update.
        pos_tf = pos_t0 + sub_step_dt * posd_tf

        # Integrate the remaining state with explicit Euler.
        u = x_t0.soft_contacts.tangential_deformation
        ud = dxdt_t0.soft_contacts.tangential_deformation
        tangential_deformation_tf = u + sub_step_dt * ud

        x_tf = ODEState(
            physics_model=PhysicsModelState(
                base_position=pos_tf[0:3],
                base_quaternion=pos_tf[3:7],
                joint_positions=pos_tf[7:],
                base_linear_velocity=vel_tf[0:3],
                base_angular_velocity=vel_tf[3:6],
                joint_velocities=vel_tf[6:],
            ),
            soft_contacts=SoftContactsState(
                tangential_deformation=tangential_deformation_tf
            ),
        )

        # Update the time.
        tf = t0 + sub_step_dt

        # Pack the carry.
        carry = (x_tf, tf)

        return carry, None

    # Integrate over the given horizon.
    (x_tf, _), _ = jax.lax.scan(
        f=body_fun, init=carry_init, xs=None, length=num_sub_steps
    )

    # Compute the aux dictionary at t0.
    _, aux_t0 = dx_dt(x0, t0)

    return x_tf, aux_t0
278
-
279
-
280
- # ===============================
281
- # Adapter: single step -> horizon
282
- # ===============================
283
-
284
-
285
def integrate_single_step_over_horizon(
    integrator_single_step: Callable[[Time, Time, State], Tuple[State, Dict[str, Any]]],
    t: TimeHorizon,
    x0: State,
) -> Tuple[State, Dict[str, Any]]:
    """
    Integrate a single-step integrator over a given horizon.

    The single-step integrator is called on each consecutive interval
    [t[i], t[i+1]], and the state at the beginning of every interval is
    collected. For the last index the interval degenerates to
    [t[-1], t[-1]] (a zero-length step), so that the output contains exactly
    one entry per element of ``t``.

    Args:
        integrator_single_step: A single-step integrator with signature
            ``(t0, tf, x0) -> (x_tf, aux_t0)``.
        t: The vector of time instants of the integration horizon.
        x0: The initial state, i.e. the state at ``t[0]``.

    Returns:
        A tuple ``(x_horizon, aux_horizon)``. Every leaf is stacked along a new
        leading axis of length ``len(t)``: ``x_horizon`` holds the state at
        each ``t[i]``, and ``aux_horizon`` holds the auxiliary data returned by
        the integrator for the interval starting at ``t[i]``.
    """

    # Index of the last time instant. The upper end of the final interval is
    # clamped to it explicitly instead of relying on how JAX handles
    # out-of-bounds indexing.
    last_idx = len(t) - 1

    # Initialize the carry.
    carry_init = (x0, t)

    def body_fun(carry: Tuple, idx: int) -> Tuple[Tuple, jtp.PyTree]:
        # Unpack the carry.
        x_t0, horizon = carry

        # Get the integration interval. The final iteration yields tf == t0.
        t0 = horizon[idx]
        tf = horizon[jnp.minimum(idx + 1, last_idx)]

        # Perform a single-step integration of the ODE.
        x_tf, aux_t0 = integrator_single_step(t0, tf, x_t0)

        # Emit the state at the beginning of the interval and its aux data.
        out = (x_t0, aux_t0)
        carry = (x_tf, horizon)

        return carry, out

    # Integrate over the given horizon, stacking one output per time instant.
    _, (x_horizon, aux_horizon) = jax.lax.scan(
        f=body_fun, init=carry_init, xs=jnp.arange(start=0, stop=len(t))
    )

    return x_horizon, aux_horizon
328
-
329
-
330
- # ===================================================================
331
- # Integration over horizon (same APIs of jax.experimental.ode.odeint)
332
- # ===================================================================
333
-
334
-
335
def odeint_euler(
    func,
    y0: State,
    t: TimeHorizon,
    *args,
    num_sub_steps: int = 1,
    return_aux: bool = False
) -> Union[State, Tuple[State, Dict[str, Any]]]:
    """
    Integrate a system of ODEs using the Euler method.

    The API mirrors ``jax.experimental.ode.odeint``: the returned states are
    given at every time instant in ``t``, not only at the end of the horizon.

    Args:
        func: A function with signature ``func(x, t, *args) -> (dx/dt, aux)``
            that computes the time-derivative of the state and auxiliary data.
        y0: The initial state, at time ``t[0]``.
        t: The vector of time instants of the integration horizon.
        *args: Additional arguments to be passed to the function func.
        num_sub_steps: The number of sub-steps to be performed within each
            integration step.
        return_aux: Whether to return the auxiliary data produced by the integrator.

    Returns:
        The states of the system at each time instant in ``t``, with every leaf
        stacked along a new leading axis of length ``len(t)``. If
        ``return_aux`` is True, a tuple is returned that also contains the
        auxiliary data computed at the beginning of each integration interval,
        stacked the same way.
    """

    def dx_dt_closure_aux(x, ts):
        # Close func over the additional positional arguments.
        return func(x, ts, *args)

    def integrator_single_step(t0, tf, x0):
        # Close the one-step integrator over the dynamics and sub-step count.
        return odeint_euler_one_step(
            dx_dt=dx_dt_closure_aux, x0=x0, t0=t0, tf=tf, num_sub_steps=num_sub_steps
        )

    # Integrate the state and compute the auxiliary data over the horizon.
    out, aux = integrate_single_step_over_horizon(
        integrator_single_step=integrator_single_step, t=t, x0=y0
    )

    return (out, aux) if return_aux else out
373
-
374
-
375
def odeint_euler_semi_implicit(
    func,
    y0: State,
    t: TimeHorizon,
    *args,
    num_sub_steps: int = 1,
    return_aux: bool = False
) -> Union[State, Tuple[State, Dict[str, Any]]]:
    """
    Integrate a system of ODEs using the Semi-Implicit Euler method.

    The API mirrors ``jax.experimental.ode.odeint``: the returned states are
    given at every time instant in ``t``, not only at the end of the horizon.

    Args:
        func: A function with signature ``func(x, t, *args) -> (dx/dt, aux)``
            that computes the time-derivative of the state and auxiliary data.
        y0: The initial state, at time ``t[0]``. It must be an ``ODEState``,
            as required by the semi-implicit single-step integrator.
        t: The vector of time instants of the integration horizon.
        *args: Additional arguments to be passed to the function func.
        num_sub_steps: The number of sub-steps to be performed within each
            integration step.
        return_aux: Whether to return the auxiliary data produced by the integrator.

    Returns:
        The states of the system at each time instant in ``t``, with every leaf
        stacked along a new leading axis of length ``len(t)``. If
        ``return_aux`` is True, a tuple is returned that also contains the
        auxiliary data computed at the beginning of each integration interval,
        stacked the same way.
    """

    def dx_dt_closure_aux(x, ts):
        # Close func over the additional positional arguments.
        return func(x, ts, *args)

    def integrator_single_step(t0, tf, x0):
        # Close the one-step integrator over the dynamics and sub-step count.
        return odeint_euler_semi_implicit_one_step(
            dx_dt=dx_dt_closure_aux, x0=x0, t0=t0, tf=tf, num_sub_steps=num_sub_steps
        )

    # Integrate the state and compute the auxiliary data over the horizon.
    out, aux = integrate_single_step_over_horizon(
        integrator_single_step=integrator_single_step, t=t, x0=y0
    )

    return (out, aux) if return_aux else out
413
-
414
-
415
def odeint_rk4(
    func,
    y0: State,
    t: TimeHorizon,
    *args,
    num_sub_steps: int = 1,
    return_aux: bool = False
) -> Union[State, Tuple[State, Dict[str, Any]]]:
    """
    Integrate a system of ODEs using the Runge-Kutta 4 method.

    The API mirrors ``jax.experimental.ode.odeint``: the returned states are
    given at every time instant in ``t``, not only at the end of the horizon.

    Args:
        func: A function with signature ``func(x, t, *args) -> (dx/dt, aux)``
            that computes the time-derivative of the state and auxiliary data.
        y0: The initial state, at time ``t[0]``.
        t: The vector of time instants of the integration horizon.
        *args: Additional arguments to be passed to the function func.
        num_sub_steps: The number of sub-steps to be performed within each
            integration step.
        return_aux: Whether to return the auxiliary data produced by the integrator.

    Returns:
        The states of the system at each time instant in ``t``, with every leaf
        stacked along a new leading axis of length ``len(t)``. If
        ``return_aux`` is True, a tuple is returned that also contains the
        auxiliary data computed at the beginning of each integration interval,
        stacked the same way.
    """

    def dx_dt_closure(x, ts):
        # Close func over the additional positional arguments.
        return func(x, ts, *args)

    def integrator_single_step(t0, tf, x0):
        # Close the one-step integrator over the dynamics and sub-step count.
        return odeint_rk4_one_step(
            dx_dt=dx_dt_closure, x0=x0, t0=t0, tf=tf, num_sub_steps=num_sub_steps
        )

    # Integrate the state and compute the auxiliary data over the horizon.
    out, aux = integrate_single_step_over_horizon(
        integrator_single_step=integrator_single_step, t=t, x0=y0
    )

    return (out, aux) if return_aux else out