jaxsim 0.2.dev65__tar.gz → 0.2.dev77__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (113)
  1. {jaxsim-0.2.dev65 → jaxsim-0.2.dev77}/PKG-INFO +1 -1
  2. {jaxsim-0.2.dev65 → jaxsim-0.2.dev77}/src/jaxsim/_version.py +2 -2
  3. jaxsim-0.2.dev77/src/jaxsim/simulation/integrators.py +393 -0
  4. {jaxsim-0.2.dev65 → jaxsim-0.2.dev77}/src/jaxsim/simulation/ode_integration.py +3 -16
  5. {jaxsim-0.2.dev65 → jaxsim-0.2.dev77}/src/jaxsim.egg-info/PKG-INFO +1 -1
  6. jaxsim-0.2.dev65/src/jaxsim/simulation/integrators.py +0 -646
  7. {jaxsim-0.2.dev65 → jaxsim-0.2.dev77}/.devcontainer/Dockerfile +0 -0
  8. {jaxsim-0.2.dev65 → jaxsim-0.2.dev77}/.devcontainer/devcontainer.json +0 -0
  9. {jaxsim-0.2.dev65 → jaxsim-0.2.dev77}/.github/workflows/ci_cd.yml +0 -0
  10. {jaxsim-0.2.dev65 → jaxsim-0.2.dev77}/.github/workflows/read_the_docs.yml +0 -0
  11. {jaxsim-0.2.dev65 → jaxsim-0.2.dev77}/.github/workflows/style.yml +0 -0
  12. {jaxsim-0.2.dev65 → jaxsim-0.2.dev77}/.gitignore +0 -0
  13. {jaxsim-0.2.dev65 → jaxsim-0.2.dev77}/.readthedocs.yaml +0 -0
  14. {jaxsim-0.2.dev65 → jaxsim-0.2.dev77}/LICENSE +0 -0
  15. {jaxsim-0.2.dev65 → jaxsim-0.2.dev77}/README.md +0 -0
  16. {jaxsim-0.2.dev65 → jaxsim-0.2.dev77}/docs/Makefile +0 -0
  17. {jaxsim-0.2.dev65 → jaxsim-0.2.dev77}/docs/conf.py +0 -0
  18. {jaxsim-0.2.dev65 → jaxsim-0.2.dev77}/docs/guide/install.rst +0 -0
  19. {jaxsim-0.2.dev65 → jaxsim-0.2.dev77}/docs/index.rst +0 -0
  20. {jaxsim-0.2.dev65 → jaxsim-0.2.dev77}/docs/make.bat +0 -0
  21. {jaxsim-0.2.dev65 → jaxsim-0.2.dev77}/docs/modules/high_level.rst +0 -0
  22. {jaxsim-0.2.dev65 → jaxsim-0.2.dev77}/docs/modules/math.rst +0 -0
  23. {jaxsim-0.2.dev65 → jaxsim-0.2.dev77}/docs/modules/parsers.rst +0 -0
  24. {jaxsim-0.2.dev65 → jaxsim-0.2.dev77}/docs/modules/physics.rst +0 -0
  25. {jaxsim-0.2.dev65 → jaxsim-0.2.dev77}/docs/modules/simulation.rst +0 -0
  26. {jaxsim-0.2.dev65 → jaxsim-0.2.dev77}/docs/modules/typing.rst +0 -0
  27. {jaxsim-0.2.dev65 → jaxsim-0.2.dev77}/docs/modules/utils.rst +0 -0
  28. {jaxsim-0.2.dev65 → jaxsim-0.2.dev77}/environment.yml +0 -0
  29. {jaxsim-0.2.dev65 → jaxsim-0.2.dev77}/examples/.gitattributes +0 -0
  30. {jaxsim-0.2.dev65 → jaxsim-0.2.dev77}/examples/.gitignore +0 -0
  31. {jaxsim-0.2.dev65 → jaxsim-0.2.dev77}/examples/PD_controller.ipynb +0 -0
  32. {jaxsim-0.2.dev65 → jaxsim-0.2.dev77}/examples/Parallel_computing.ipynb +0 -0
  33. {jaxsim-0.2.dev65 → jaxsim-0.2.dev77}/examples/README.md +0 -0
  34. {jaxsim-0.2.dev65 → jaxsim-0.2.dev77}/examples/assets/cartpole.urdf +0 -0
  35. {jaxsim-0.2.dev65 → jaxsim-0.2.dev77}/examples/pixi.lock +0 -0
  36. {jaxsim-0.2.dev65 → jaxsim-0.2.dev77}/examples/pixi.toml +0 -0
  37. {jaxsim-0.2.dev65 → jaxsim-0.2.dev77}/pyproject.toml +0 -0
  38. {jaxsim-0.2.dev65 → jaxsim-0.2.dev77}/setup.cfg +0 -0
  39. {jaxsim-0.2.dev65 → jaxsim-0.2.dev77}/setup.py +0 -0
  40. {jaxsim-0.2.dev65 → jaxsim-0.2.dev77}/src/jaxsim/__init__.py +0 -0
  41. {jaxsim-0.2.dev65 → jaxsim-0.2.dev77}/src/jaxsim/high_level/__init__.py +0 -0
  42. {jaxsim-0.2.dev65 → jaxsim-0.2.dev77}/src/jaxsim/high_level/common.py +0 -0
  43. {jaxsim-0.2.dev65 → jaxsim-0.2.dev77}/src/jaxsim/high_level/joint.py +0 -0
  44. {jaxsim-0.2.dev65 → jaxsim-0.2.dev77}/src/jaxsim/high_level/link.py +0 -0
  45. {jaxsim-0.2.dev65 → jaxsim-0.2.dev77}/src/jaxsim/high_level/model.py +0 -0
  46. {jaxsim-0.2.dev65 → jaxsim-0.2.dev77}/src/jaxsim/logging.py +0 -0
  47. {jaxsim-0.2.dev65 → jaxsim-0.2.dev77}/src/jaxsim/math/__init__.py +0 -0
  48. {jaxsim-0.2.dev65 → jaxsim-0.2.dev77}/src/jaxsim/math/adjoint.py +0 -0
  49. {jaxsim-0.2.dev65 → jaxsim-0.2.dev77}/src/jaxsim/math/conv.py +0 -0
  50. {jaxsim-0.2.dev65 → jaxsim-0.2.dev77}/src/jaxsim/math/cross.py +0 -0
  51. {jaxsim-0.2.dev65 → jaxsim-0.2.dev77}/src/jaxsim/math/inertia.py +0 -0
  52. {jaxsim-0.2.dev65 → jaxsim-0.2.dev77}/src/jaxsim/math/joint.py +0 -0
  53. {jaxsim-0.2.dev65 → jaxsim-0.2.dev77}/src/jaxsim/math/plucker.py +0 -0
  54. {jaxsim-0.2.dev65 → jaxsim-0.2.dev77}/src/jaxsim/math/quaternion.py +0 -0
  55. {jaxsim-0.2.dev65 → jaxsim-0.2.dev77}/src/jaxsim/math/rotation.py +0 -0
  56. {jaxsim-0.2.dev65 → jaxsim-0.2.dev77}/src/jaxsim/math/skew.py +0 -0
  57. {jaxsim-0.2.dev65 → jaxsim-0.2.dev77}/src/jaxsim/mujoco/__init__.py +0 -0
  58. {jaxsim-0.2.dev65 → jaxsim-0.2.dev77}/src/jaxsim/mujoco/__main__.py +0 -0
  59. {jaxsim-0.2.dev65 → jaxsim-0.2.dev77}/src/jaxsim/mujoco/loaders.py +0 -0
  60. {jaxsim-0.2.dev65 → jaxsim-0.2.dev77}/src/jaxsim/mujoco/model.py +0 -0
  61. {jaxsim-0.2.dev65 → jaxsim-0.2.dev77}/src/jaxsim/mujoco/visualizer.py +0 -0
  62. {jaxsim-0.2.dev65 → jaxsim-0.2.dev77}/src/jaxsim/parsers/__init__.py +0 -0
  63. {jaxsim-0.2.dev65 → jaxsim-0.2.dev77}/src/jaxsim/parsers/descriptions/__init__.py +0 -0
  64. {jaxsim-0.2.dev65 → jaxsim-0.2.dev77}/src/jaxsim/parsers/descriptions/collision.py +0 -0
  65. {jaxsim-0.2.dev65 → jaxsim-0.2.dev77}/src/jaxsim/parsers/descriptions/joint.py +0 -0
  66. {jaxsim-0.2.dev65 → jaxsim-0.2.dev77}/src/jaxsim/parsers/descriptions/link.py +0 -0
  67. {jaxsim-0.2.dev65 → jaxsim-0.2.dev77}/src/jaxsim/parsers/descriptions/model.py +0 -0
  68. {jaxsim-0.2.dev65 → jaxsim-0.2.dev77}/src/jaxsim/parsers/kinematic_graph.py +0 -0
  69. {jaxsim-0.2.dev65 → jaxsim-0.2.dev77}/src/jaxsim/parsers/rod/__init__.py +0 -0
  70. {jaxsim-0.2.dev65 → jaxsim-0.2.dev77}/src/jaxsim/parsers/rod/parser.py +0 -0
  71. {jaxsim-0.2.dev65 → jaxsim-0.2.dev77}/src/jaxsim/parsers/rod/utils.py +0 -0
  72. {jaxsim-0.2.dev65 → jaxsim-0.2.dev77}/src/jaxsim/physics/__init__.py +0 -0
  73. {jaxsim-0.2.dev65 → jaxsim-0.2.dev77}/src/jaxsim/physics/algos/__init__.py +0 -0
  74. {jaxsim-0.2.dev65 → jaxsim-0.2.dev77}/src/jaxsim/physics/algos/aba.py +0 -0
  75. {jaxsim-0.2.dev65 → jaxsim-0.2.dev77}/src/jaxsim/physics/algos/aba_motors.py +0 -0
  76. {jaxsim-0.2.dev65 → jaxsim-0.2.dev77}/src/jaxsim/physics/algos/crba.py +0 -0
  77. {jaxsim-0.2.dev65 → jaxsim-0.2.dev77}/src/jaxsim/physics/algos/forward_kinematics.py +0 -0
  78. {jaxsim-0.2.dev65 → jaxsim-0.2.dev77}/src/jaxsim/physics/algos/jacobian.py +0 -0
  79. {jaxsim-0.2.dev65 → jaxsim-0.2.dev77}/src/jaxsim/physics/algos/rnea.py +0 -0
  80. {jaxsim-0.2.dev65 → jaxsim-0.2.dev77}/src/jaxsim/physics/algos/rnea_motors.py +0 -0
  81. {jaxsim-0.2.dev65 → jaxsim-0.2.dev77}/src/jaxsim/physics/algos/soft_contacts.py +0 -0
  82. {jaxsim-0.2.dev65 → jaxsim-0.2.dev77}/src/jaxsim/physics/algos/terrain.py +0 -0
  83. {jaxsim-0.2.dev65 → jaxsim-0.2.dev77}/src/jaxsim/physics/algos/utils.py +0 -0
  84. {jaxsim-0.2.dev65 → jaxsim-0.2.dev77}/src/jaxsim/physics/model/__init__.py +0 -0
  85. {jaxsim-0.2.dev65 → jaxsim-0.2.dev77}/src/jaxsim/physics/model/ground_contact.py +0 -0
  86. {jaxsim-0.2.dev65 → jaxsim-0.2.dev77}/src/jaxsim/physics/model/physics_model.py +0 -0
  87. {jaxsim-0.2.dev65 → jaxsim-0.2.dev77}/src/jaxsim/physics/model/physics_model_state.py +0 -0
  88. {jaxsim-0.2.dev65 → jaxsim-0.2.dev77}/src/jaxsim/simulation/__init__.py +0 -0
  89. {jaxsim-0.2.dev65 → jaxsim-0.2.dev77}/src/jaxsim/simulation/ode.py +0 -0
  90. {jaxsim-0.2.dev65 → jaxsim-0.2.dev77}/src/jaxsim/simulation/ode_data.py +0 -0
  91. {jaxsim-0.2.dev65 → jaxsim-0.2.dev77}/src/jaxsim/simulation/simulator.py +0 -0
  92. {jaxsim-0.2.dev65 → jaxsim-0.2.dev77}/src/jaxsim/simulation/simulator_callbacks.py +0 -0
  93. {jaxsim-0.2.dev65 → jaxsim-0.2.dev77}/src/jaxsim/simulation/utils.py +0 -0
  94. {jaxsim-0.2.dev65 → jaxsim-0.2.dev77}/src/jaxsim/sixd/__init__.py +0 -0
  95. {jaxsim-0.2.dev65 → jaxsim-0.2.dev77}/src/jaxsim/typing.py +0 -0
  96. {jaxsim-0.2.dev65 → jaxsim-0.2.dev77}/src/jaxsim/utils/__init__.py +0 -0
  97. {jaxsim-0.2.dev65 → jaxsim-0.2.dev77}/src/jaxsim/utils/jaxsim_dataclass.py +0 -0
  98. {jaxsim-0.2.dev65 → jaxsim-0.2.dev77}/src/jaxsim/utils/oop.py +0 -0
  99. {jaxsim-0.2.dev65 → jaxsim-0.2.dev77}/src/jaxsim/utils/tracing.py +0 -0
  100. {jaxsim-0.2.dev65 → jaxsim-0.2.dev77}/src/jaxsim/utils/vmappable.py +0 -0
  101. {jaxsim-0.2.dev65 → jaxsim-0.2.dev77}/src/jaxsim.egg-info/SOURCES.txt +0 -0
  102. {jaxsim-0.2.dev65 → jaxsim-0.2.dev77}/src/jaxsim.egg-info/dependency_links.txt +0 -0
  103. {jaxsim-0.2.dev65 → jaxsim-0.2.dev77}/src/jaxsim.egg-info/not-zip-safe +0 -0
  104. {jaxsim-0.2.dev65 → jaxsim-0.2.dev77}/src/jaxsim.egg-info/requires.txt +0 -0
  105. {jaxsim-0.2.dev65 → jaxsim-0.2.dev77}/src/jaxsim.egg-info/top_level.txt +0 -0
  106. {jaxsim-0.2.dev65 → jaxsim-0.2.dev77}/tests/__init__.py +0 -0
  107. {jaxsim-0.2.dev65 → jaxsim-0.2.dev77}/tests/test_ad_physics.py +0 -0
  108. {jaxsim-0.2.dev65 → jaxsim-0.2.dev77}/tests/test_eom.py +0 -0
  109. {jaxsim-0.2.dev65 → jaxsim-0.2.dev77}/tests/test_forward_dynamics.py +0 -0
  110. {jaxsim-0.2.dev65 → jaxsim-0.2.dev77}/tests/test_jax_oop.py +0 -0
  111. {jaxsim-0.2.dev65 → jaxsim-0.2.dev77}/tests/utils_idyntree.py +0 -0
  112. {jaxsim-0.2.dev65 → jaxsim-0.2.dev77}/tests/utils_models.py +0 -0
  113. {jaxsim-0.2.dev65 → jaxsim-0.2.dev77}/tests/utils_rng.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: jaxsim
3
- Version: 0.2.dev65
3
+ Version: 0.2.dev77
4
4
  Summary: A physics engine in reduced coordinates implemented with JAX.
5
5
  Home-page: https://github.com/ami-iit/jaxsim
6
6
  Author: Diego Ferigo
@@ -12,5 +12,5 @@ __version__: str
12
12
  __version_tuple__: VERSION_TUPLE
13
13
  version_tuple: VERSION_TUPLE
14
14
 
15
- __version__ = version = '0.2.dev65'
16
- __version_tuple__ = version_tuple = (0, 2, 'dev65')
15
+ __version__ = version = '0.2.dev77'
16
+ __version_tuple__ = version_tuple = (0, 2, 'dev77')
@@ -0,0 +1,393 @@
1
+ import enum
2
+ from typing import Any, Callable
3
+
4
+ import jax
5
+ import jax.numpy as jnp
6
+ from jax.tree_util import tree_map
7
+
8
+ import jaxsim.typing as jtp
9
+ from jaxsim.math.quaternion import Quaternion
10
+ from jaxsim.physics.algos.soft_contacts import SoftContactsState
11
+ from jaxsim.physics.model.physics_model_state import PhysicsModelState
12
+ from jaxsim.simulation.ode_data import ODEState
13
+ from jaxsim.sixd import se3, so3
14
+
15
+ Time = jtp.FloatLike
16
+ TimeStep = jtp.FloatLike
17
+ TimeHorizon = jtp.VectorLike
18
+
19
+ State = jtp.PyTree
20
+ StateDerivative = jtp.PyTree
21
+
22
+ StateDerivativeCallable = Callable[
23
+ [State, Time], tuple[StateDerivative, dict[str, Any]]
24
+ ]
25
+
26
+
27
+ class IntegratorType(enum.IntEnum):
28
+ RungeKutta4 = enum.auto()
29
+ EulerForward = enum.auto()
30
+ EulerSemiImplicit = enum.auto()
31
+ EulerSemiImplicitManifold = enum.auto()
32
+
33
+
34
+ # =======================
35
+ # Single-step integration
36
+ # =======================
37
+
38
+
39
+ def integrator_fixed_single_step(
40
+ dx_dt: StateDerivativeCallable,
41
+ x0: State | ODEState,
42
+ t0: Time,
43
+ tf: Time,
44
+ integrator_type: IntegratorType,
45
+ num_sub_steps: int = 1,
46
+ ) -> tuple[State | ODEState, dict[str, Any]]:
47
+ """
48
+ Advance a state vector by integrating the system dynamics with a fixed-step integrator.
49
+
50
+ Args:
51
+ dx_dt: Callable that computes the state derivative.
52
+ x0: Initial state.
53
+ t0: Initial time.
54
+ tf: Final time.
55
+ integrator_type: Integrator type.
56
+ num_sub_steps: Number of sub-steps to break the integration into.
57
+
58
+ Returns:
59
+ The final state and a dictionary including auxiliary data at t0.
60
+ """
61
+
62
+ # Compute the sub-step size.
63
+ # We break dt in configurable sub-steps.
64
+ dt = tf - t0
65
+ sub_step_dt = dt / num_sub_steps
66
+
67
+ # Initialize the carry
68
+ Carry = tuple[State | ODEState, Time]
69
+ carry_init: Carry = (x0, t0)
70
+
71
+ def forward_euler_body_fun(carry: Carry, xs: None) -> tuple[Carry, None]:
72
+ """
73
+ Forward Euler integrator.
74
+ """
75
+
76
+ # Unpack the carry
77
+ x_t0, t0 = carry
78
+
79
+ # Compute the state derivative
80
+ dxdt_t0, _ = dx_dt(x_t0, t0)
81
+
82
+ # Integrate the dynamics
83
+ x_tf = jax.tree_util.tree_map(
84
+ lambda x, dxdt: x + sub_step_dt * dxdt, x_t0, dxdt_t0
85
+ )
86
+
87
+ # Update the time
88
+ tf = t0 + sub_step_dt
89
+
90
+ # Pack the carry
91
+ carry = (x_tf, tf)
92
+
93
+ return carry, None
94
+
95
+ def rk4_body_fun(carry: Carry, xs: None) -> tuple[Carry, None]:
96
+ """
97
+ Runge-Kutta 4 integrator.
98
+ """
99
+
100
+ # Unpack the carry
101
+ x_t0, t0 = carry
102
+
103
+ # Helper to forward the state to compute k2 and k3 at midpoint and k4 at final
104
+ euler_mid = lambda x, dxdt: x + (0.5 * sub_step_dt) * dxdt
105
+ euler_fin = lambda x, dxdt: x + sub_step_dt * dxdt
106
+
107
+ # Compute the RK4 slopes
108
+ k1, _ = dx_dt(x_t0, t0)
109
+ k2, _ = dx_dt(tree_map(euler_mid, x_t0, k1), t0 + 0.5 * sub_step_dt)
110
+ k3, _ = dx_dt(tree_map(euler_mid, x_t0, k2), t0 + 0.5 * sub_step_dt)
111
+ k4, _ = dx_dt(tree_map(euler_fin, x_t0, k3), t0 + sub_step_dt)
112
+
113
+ # Average the slopes and compute the RK4 state derivative
114
+ average = lambda k1, k2, k3, k4: (k1 + 2 * k2 + 2 * k3 + k4) / 6
115
+ dxdt = jax.tree_util.tree_map(average, k1, k2, k3, k4)
116
+
117
+ # Integrate the dynamics
118
+ x_tf = jax.tree_util.tree_map(euler_fin, x_t0, dxdt)
119
+
120
+ # Update the time
121
+ tf = t0 + sub_step_dt
122
+
123
+ # Pack the carry
124
+ carry = (x_tf, tf)
125
+
126
+ return carry, None
127
+
128
+ def semi_implicit_euler_body_fun(carry: Carry, xs: None) -> tuple[Carry, None]:
129
+ """
130
+ Semi-implicit Euler integrator.
131
+ """
132
+
133
+ # Unpack the carry
134
+ x_t0, t0 = carry
135
+
136
+ # Compute the state derivative.
137
+ # We only keep the quantities related to the acceleration and discard those
138
+ # related to the velocity since we are going to use those implicitly integrated
139
+ # from the accelerations.
140
+ StateDerivative = ODEState
141
+ dxdt_t0: StateDerivative = dx_dt(x_t0, t0)[0]
142
+
143
+ # Extract the initial position ∈ ℝ⁷⁺ⁿ and initial velocity ∈ ℝ⁶⁺ⁿ.
144
+ # This integrator, unlike most of the other ones, is not generic.
145
+ # It expects to operate on an x object of class ODEState.
146
+ pos_t0 = x_t0.physics_model.position()
147
+ vel_t0 = x_t0.physics_model.velocity()
148
+
149
+ # Extract the velocity derivative
150
+ d_vel_dt = dxdt_t0.physics_model.velocity()
151
+
152
+ # =============================================
153
+ # Perform semi-implicit Euler integration [1-4]
154
+ # =============================================
155
+
156
+ # 1. Integrate the accelerations obtaining the implicit velocities
157
+ # 2. Compute the derivative of the generalized position
158
+ # 3. Integrate the implicit velocities
159
+ # 4. Integrate the remaining state
160
+ # 5. Outside the loop: integrate the quaternion on SO(3) manifold
161
+
162
+ # ----------------------------------------------------------------
163
+ # 1. Integrate the accelerations obtaining the implicit velocities
164
+ # ----------------------------------------------------------------
165
+
166
+ vel_tf = vel_t0 + sub_step_dt * d_vel_dt
167
+
168
+ # -----------------------------------------------------
169
+ # 2. Compute the derivative of the generalized position
170
+ # -----------------------------------------------------
171
+
172
+ # Extract the implicit angular velocity and the initial base quaternion
173
+ W_ω_WB = vel_tf[3:6]
174
+ W_Q_B = x_t0.physics_model.base_quaternion
175
+
176
+ # Compute the quaternion derivative and the base position derivative
177
+ W_Qd_B = Quaternion.derivative(
178
+ quaternion=W_Q_B, omega=W_ω_WB, omega_in_body_fixed=False
179
+ ).squeeze()
180
+
181
+ # Compute the transform of the mixed base frame at t0
182
+ W_H_BW = jnp.vstack(
183
+ [
184
+ jnp.block([jnp.eye(3), jnp.vstack(x_t0.physics_model.base_position)]),
185
+ jnp.array([0, 0, 0, 1]),
186
+ ]
187
+ )
188
+
189
+ # The derivative W_ṗ_B of the base position is the linear component of the
190
+ # mixed velocity B[W]_v_WB. We need to compute it from the velocity in
191
+ # inertial-fixed representation W_vl_WB.
192
+ W_v_WB = vel_tf[0:6]
193
+ BW_Xv_W = se3.SE3.from_matrix(W_H_BW).inverse().adjoint()
194
+ BW_vl_WB = (BW_Xv_W @ W_v_WB)[0:3]
195
+
196
+ # Compute the derivative of the generalized position
197
+ d_pos_tf = (
198
+ jnp.hstack([BW_vl_WB, vel_tf[6:]])
199
+ if integrator_type is IntegratorType.EulerSemiImplicitManifold
200
+ else jnp.hstack([BW_vl_WB, W_Qd_B, vel_tf[6:]])
201
+ )
202
+
203
+ # ------------------------------------
204
+ # 3. Integrate the implicit velocities
205
+ # ------------------------------------
206
+
207
+ pos_tf = pos_t0 + sub_step_dt * d_pos_tf
208
+ joint_positions = (
209
+ pos_tf[3:]
210
+ if integrator_type is IntegratorType.EulerSemiImplicitManifold
211
+ else pos_tf[7:]
212
+ )
213
+ base_quaternion = (
214
+ jnp.zeros_like(x_t0.base_quaternion)
215
+ if integrator_type is IntegratorType.EulerSemiImplicitManifold
216
+ else pos_tf[3:7]
217
+ )
218
+
219
+ # ---------------------------------
220
+ # 4. Integrate the remaining state
221
+ # ---------------------------------
222
+
223
+ # Integrate the derivative of the tangential material deformation
224
+ m = x_t0.soft_contacts.tangential_deformation
225
+ ṁ = dxdt_t0.soft_contacts.tangential_deformation
226
+ tangential_deformation_tf = m + sub_step_dt * ṁ
227
+
228
+ # Pack the new state into an ODEState object
229
+ x_tf = ODEState(
230
+ physics_model=PhysicsModelState(
231
+ base_position=pos_tf[0:3],
232
+ base_quaternion=base_quaternion,
233
+ joint_positions=joint_positions,
234
+ base_linear_velocity=vel_tf[0:3],
235
+ base_angular_velocity=vel_tf[3:6],
236
+ joint_velocities=vel_tf[6:],
237
+ ),
238
+ soft_contacts=SoftContactsState(
239
+ tangential_deformation=tangential_deformation_tf
240
+ ),
241
+ )
242
+
243
+ # Update the time
244
+ tf = t0 + sub_step_dt
245
+
246
+ # Pack the carry
247
+ carry = (x_tf, tf)
248
+
249
+ return carry, None
250
+
251
+ _integrator_registry = {
252
+ IntegratorType.RungeKutta4: rk4_body_fun,
253
+ IntegratorType.EulerForward: forward_euler_body_fun,
254
+ IntegratorType.EulerSemiImplicit: semi_implicit_euler_body_fun,
255
+ IntegratorType.EulerSemiImplicitManifold: semi_implicit_euler_body_fun,
256
+ }
257
+
258
+ # Get the body function for the selected integrator
259
+ body_fun = _integrator_registry[integrator_type]
260
+
261
+ # Integrate over the given horizon
262
+ (x_tf, _), _ = jax.lax.scan(
263
+ f=body_fun, init=carry_init, xs=None, length=num_sub_steps
264
+ )
265
+
266
+ if integrator_type is IntegratorType.EulerSemiImplicitManifold:
267
+ # Indices to convert quaternions between serializations
268
+ to_xyzw = jnp.array([1, 2, 3, 0])
269
+ to_wxyz = jnp.array([3, 0, 1, 2])
270
+
271
+ # Get the initial quaternion and the implicitly integrated angular velocity
272
+ W_ω_WB_tf = x_tf.physics_model.base_angular_velocity
273
+ W_Q_B_t0 = so3.SO3.from_quaternion_xyzw(
274
+ x0.physics_model.base_quaternion[to_xyzw]
275
+ )
276
+
277
+ # Integrate the quaternion on its manifold using the implicit angular velocity,
278
+ # transformed in body-fixed representation since jaxlie uses this convention
279
+ B_R_W = W_Q_B_t0.inverse().as_matrix()
280
+ W_Q_B_tf = W_Q_B_t0 @ so3.SO3.exp(tangent=dt * B_R_W @ W_ω_WB_tf)
281
+
282
+ # Store the quaternion in the final state
283
+ x_tf = x_tf.replace(
284
+ physics_model=x_tf.physics_model.replace(
285
+ base_quaternion=W_Q_B_tf.as_quaternion_xyzw()[to_wxyz]
286
+ )
287
+ )
288
+
289
+ # Compute the aux dictionary at t0
290
+ _, aux_t0 = dx_dt(x0, t0)
291
+
292
+ return x_tf, aux_t0
293
+
294
+
295
+ # ===============================
296
+ # Adapter: single step -> horizon
297
+ # ===============================
298
+
299
+
300
+ def integrate_single_step_over_horizon(
301
+ integrator_single_step: Callable[[Time, Time, State], tuple[State, dict[str, Any]]],
302
+ t: TimeHorizon,
303
+ x0: State,
304
+ ) -> tuple[State, dict[str, Any]]:
305
+ """
306
+ Integrate a single-step integrator over a given horizon.
307
+
308
+ Args:
309
+ integrator_single_step: A single-step integrator.
310
+ t: The vector of time instants of the integration horizon.
311
+ x0: The initial state of the integration horizon.
312
+
313
+ Returns:
314
+ The state at each time instant of the horizon and the corresponding auxiliary data.
315
+ """
316
+
317
+ # Initialize the carry
318
+ carry_init = (x0, t)
319
+
320
+ def body_fun(carry: tuple, idx: int) -> tuple[tuple, jtp.PyTree]:
321
+ # Unpack the carry
322
+ x_t0, horizon = carry
323
+
324
+ # Get the integration interval
325
+ t0 = horizon[idx]
326
+ tf = horizon[idx + 1]
327
+
328
+ # Perform a single-step integration of the ODE
329
+ x_tf, aux_t0 = integrator_single_step(t0, tf, x_t0)
330
+
331
+ # Prepare returned data
332
+ out = (x_t0, aux_t0)
333
+ carry = (x_tf, horizon)
334
+
335
+ return carry, out
336
+
337
+ # Integrate over the given horizon
338
+ _, (x_horizon, aux_horizon) = jax.lax.scan(
339
+ f=body_fun, init=carry_init, xs=jnp.arange(start=0, stop=len(t), dtype=int)
340
+ )
341
+
342
+ return x_horizon, aux_horizon
343
+
344
+
345
+ # ===================================================================
346
+ # Integration over horizon (same APIs of jax.experimental.ode.odeint)
347
+ # ===================================================================
348
+
349
+
350
+ def odeint(
351
+ func,
352
+ y0: State,
353
+ t: TimeHorizon,
354
+ *args,
355
+ num_sub_steps: int = 1,
356
+ return_aux: bool = False,
357
+ integrator_type: IntegratorType = None,
358
+ ):
359
+ """
360
+ Integrate a system of ODEs with a fixed-step integrator.
361
+
362
+ Args:
363
+ func: A function that computes the time-derivative of the state.
364
+ y0: The initial state.
365
+ t: The vector of time instants of the integration horizon.
366
+ *args: Additional arguments to be passed to the function func.
367
+ num_sub_steps: The number of sub-steps to be performed within each integration step.
368
+ return_aux: Whether to return the auxiliary data produced by the integrator.
369
+
370
+ Returns:
371
+ The state of the system at each time instant of the integration horizon, and optionally
371
+ the auxiliary data produced by the integrator.
373
+ """
374
+
375
+ # Close func over additional inputs and parameters
376
+ dx_dt_closure = lambda x, ts: func(x, ts, *args)
377
+
378
+ # Close one-step integration over its arguments
379
+ integrator_single_step = lambda t0, tf, x0: integrator_fixed_single_step(
380
+ dx_dt=dx_dt_closure,
381
+ x0=x0,
382
+ t0=t0,
383
+ tf=tf,
384
+ num_sub_steps=num_sub_steps,
385
+ integrator_type=integrator_type,
386
+ )
387
+
388
+ # Integrate the state and compute optional auxiliary data over the horizon
389
+ out, aux = integrate_single_step_over_horizon(
390
+ integrator_single_step=integrator_single_step, t=t, x0=y0
391
+ )
392
+
393
+ return (out, aux) if return_aux else out
@@ -10,21 +10,7 @@ from jaxsim.physics.algos.soft_contacts import SoftContactsParams
10
10
  from jaxsim.physics.algos.terrain import FlatTerrain, Terrain
11
11
  from jaxsim.physics.model.physics_model import PhysicsModel
12
12
  from jaxsim.simulation import integrators, ode
13
-
14
-
15
- class IntegratorType(enum.IntEnum):
16
- RungeKutta4 = enum.auto()
17
- EulerForward = enum.auto()
18
- EulerSemiImplicit = enum.auto()
19
- EulerSemiImplicitManifold = enum.auto()
20
-
21
-
22
- _integrator_registry = {
23
- IntegratorType.RungeKutta4: integrators.odeint_rk4,
24
- IntegratorType.EulerForward: integrators.odeint_euler,
25
- IntegratorType.EulerSemiImplicit: integrators.odeint_euler_semi_implicit,
26
- IntegratorType.EulerSemiImplicitManifold: integrators.odeint_euler_semi_implicit_manifold_one_step,
27
- }
13
+ from jaxsim.simulation.integrators import IntegratorType
28
14
 
29
15
 
30
16
  @jax.jit
@@ -62,12 +48,13 @@ def ode_integration_fixed_step(
62
48
  )
63
49
 
64
50
  # Integrate over the horizon
65
- out = _integrator_registry[integrator_type](
51
+ out = integrators.odeint(
66
52
  func=dx_dt_closure,
67
53
  y0=x0,
68
54
  t=t,
69
55
  num_sub_steps=num_sub_steps,
70
56
  return_aux=return_aux,
57
+ integrator_type=integrator_type,
71
58
  )
72
59
 
73
60
  # Return output pytree and, optionally, the aux dict
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: jaxsim
3
- Version: 0.2.dev65
3
+ Version: 0.2.dev77
4
4
  Summary: A physics engine in reduced coordinates implemented with JAX.
5
5
  Home-page: https://github.com/ami-iit/jaxsim
6
6
  Author: Diego Ferigo