jaxsim 0.1.dev401__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (89)
  1. jaxsim/__init__.py +5 -6
  2. jaxsim/_version.py +2 -2
  3. jaxsim/api/__init__.py +3 -0
  4. jaxsim/api/com.py +240 -0
  5. jaxsim/api/common.py +216 -0
  6. jaxsim/api/contact.py +271 -0
  7. jaxsim/api/data.py +821 -0
  8. jaxsim/api/joint.py +189 -0
  9. jaxsim/api/kin_dyn_parameters.py +777 -0
  10. jaxsim/api/link.py +361 -0
  11. jaxsim/api/model.py +1633 -0
  12. jaxsim/api/ode.py +295 -0
  13. jaxsim/api/ode_data.py +694 -0
  14. jaxsim/api/references.py +421 -0
  15. jaxsim/integrators/__init__.py +2 -0
  16. jaxsim/integrators/common.py +594 -0
  17. jaxsim/integrators/fixed_step.py +102 -0
  18. jaxsim/integrators/variable_step.py +610 -0
  19. jaxsim/math/__init__.py +11 -0
  20. jaxsim/math/adjoint.py +24 -2
  21. jaxsim/math/joint_model.py +335 -0
  22. jaxsim/math/quaternion.py +44 -3
  23. jaxsim/math/rotation.py +4 -4
  24. jaxsim/math/transform.py +92 -0
  25. jaxsim/mujoco/__init__.py +3 -0
  26. jaxsim/mujoco/__main__.py +192 -0
  27. jaxsim/mujoco/loaders.py +615 -0
  28. jaxsim/mujoco/model.py +414 -0
  29. jaxsim/mujoco/visualizer.py +176 -0
  30. jaxsim/parsers/descriptions/collision.py +14 -0
  31. jaxsim/parsers/descriptions/link.py +13 -2
  32. jaxsim/parsers/kinematic_graph.py +8 -3
  33. jaxsim/parsers/rod/parser.py +54 -38
  34. jaxsim/parsers/rod/utils.py +7 -8
  35. jaxsim/rbda/__init__.py +7 -0
  36. jaxsim/rbda/aba.py +295 -0
  37. jaxsim/rbda/collidable_points.py +142 -0
  38. jaxsim/{physics/algos → rbda}/crba.py +43 -42
  39. jaxsim/rbda/forward_kinematics.py +113 -0
  40. jaxsim/rbda/jacobian.py +201 -0
  41. jaxsim/rbda/rnea.py +237 -0
  42. jaxsim/rbda/soft_contacts.py +296 -0
  43. jaxsim/rbda/utils.py +152 -0
  44. jaxsim/terrain/__init__.py +2 -0
  45. jaxsim/{physics/algos → terrain}/terrain.py +4 -6
  46. jaxsim/typing.py +30 -30
  47. jaxsim/utils/__init__.py +1 -4
  48. jaxsim/utils/hashless.py +18 -0
  49. jaxsim/utils/jaxsim_dataclass.py +281 -31
  50. {jaxsim-0.1.dev401.dist-info → jaxsim-0.2.0.dist-info}/LICENSE +1 -1
  51. jaxsim-0.2.0.dist-info/METADATA +237 -0
  52. jaxsim-0.2.0.dist-info/RECORD +64 -0
  53. {jaxsim-0.1.dev401.dist-info → jaxsim-0.2.0.dist-info}/WHEEL +1 -1
  54. jaxsim/high_level/__init__.py +0 -2
  55. jaxsim/high_level/common.py +0 -11
  56. jaxsim/high_level/joint.py +0 -148
  57. jaxsim/high_level/link.py +0 -259
  58. jaxsim/high_level/model.py +0 -1695
  59. jaxsim/math/conv.py +0 -114
  60. jaxsim/math/joint.py +0 -101
  61. jaxsim/math/plucker.py +0 -100
  62. jaxsim/physics/__init__.py +0 -12
  63. jaxsim/physics/algos/__init__.py +0 -0
  64. jaxsim/physics/algos/aba.py +0 -256
  65. jaxsim/physics/algos/aba_motors.py +0 -284
  66. jaxsim/physics/algos/forward_kinematics.py +0 -79
  67. jaxsim/physics/algos/jacobian.py +0 -98
  68. jaxsim/physics/algos/rnea.py +0 -180
  69. jaxsim/physics/algos/rnea_motors.py +0 -196
  70. jaxsim/physics/algos/soft_contacts.py +0 -454
  71. jaxsim/physics/algos/utils.py +0 -69
  72. jaxsim/physics/model/__init__.py +0 -0
  73. jaxsim/physics/model/ground_contact.py +0 -55
  74. jaxsim/physics/model/physics_model.py +0 -358
  75. jaxsim/physics/model/physics_model_state.py +0 -174
  76. jaxsim/simulation/__init__.py +0 -4
  77. jaxsim/simulation/integrators.py +0 -452
  78. jaxsim/simulation/ode.py +0 -290
  79. jaxsim/simulation/ode_data.py +0 -53
  80. jaxsim/simulation/ode_integration.py +0 -125
  81. jaxsim/simulation/simulator.py +0 -544
  82. jaxsim/simulation/simulator_callbacks.py +0 -53
  83. jaxsim/simulation/utils.py +0 -15
  84. jaxsim/sixd/__init__.py +0 -2
  85. jaxsim/utils/oop.py +0 -532
  86. jaxsim/utils/vmappable.py +0 -117
  87. jaxsim-0.1.dev401.dist-info/METADATA +0 -167
  88. jaxsim-0.1.dev401.dist-info/RECORD +0 -64
  89. {jaxsim-0.1.dev401.dist-info → jaxsim-0.2.0.dist-info}/top_level.txt +0 -0
jaxsim/simulation/ode.py DELETED
@@ -1,290 +0,0 @@
1
- from typing import Any, Dict, Optional, Tuple
2
-
3
- import jax
4
- import jax.numpy as jnp
5
- import numpy as np
6
-
7
- import jaxsim.typing as jtp
8
- from jaxsim.physics import algos
9
- from jaxsim.physics.algos.soft_contacts import (
10
- SoftContacts,
11
- SoftContactsParams,
12
- collidable_points_pos_vel,
13
- )
14
- from jaxsim.physics.algos.terrain import FlatTerrain, Terrain
15
- from jaxsim.physics.model.physics_model import PhysicsModel
16
-
17
- from . import ode_data
18
-
19
-
20
def compute_contact_forces(
    physics_model: PhysicsModel,
    ode_state: ode_data.ODEState,
    soft_contacts_params: SoftContactsParams = SoftContactsParams(),
    terrain: Terrain = FlatTerrain(),
) -> Tuple[jtp.Matrix, jtp.Matrix, jtp.Matrix]:
    """
    Compute the soft-contact forces acting on the collidable points of the model.

    The collidable-point kinematics is evaluated from the current ODE state. The
    soft-contact model is applied to every point at once through ``jax.vmap``, and
    the point forces are finally summed per link.

    Args:
        physics_model: The physics model to consider.
        ode_state: The state of the ODE corresponding to the physics model.
        soft_contacts_params: The parameters of the soft contacts model.
        terrain: The terrain model.

    Returns:
        A tuple containing:
        - The contact forces expressed in the world frame acting on the model's
          links. Each link's entry is the sum of the forces of its collidable points;
          the array has the shape of the ``f_ext`` of a zero ``ODEInput``.
        - The derivative of the tangential deformation of the terrain dynamics,
          with one column per collidable point.
        - The contact forces expressed in the world frame acting on the model's
          collidable points, with one row per point.
    """

    # Position and linear mixed velocity of all the model's collidable points.
    W_p_Ci, W_v_Ci = collidable_points_pos_vel(
        model=physics_model,
        q=ode_state.physics_model.joint_positions,
        qd=ode_state.physics_model.joint_velocities,
        xfb=ode_state.physics_model.xfb(),
    )

    soft_contacts = SoftContacts(parameters=soft_contacts_params, terrain=terrain)

    # The inputs are transposed so that vmap maps over their columns, i.e. it
    # evaluates the contact model once per collidable point.
    per_point_forces, per_point_deformation_dot = jax.vmap(
        soft_contacts.contact_model
    )(W_p_Ci.T, W_v_Ci.T, ode_state.soft_contacts.tangential_deformation.T)

    # Go back to the column-per-point layout.
    forces_columns = per_point_forces.T
    deformation_dot_columns = per_point_deformation_dot.T

    # All-zero array with the shape of the per-link external forces.
    link_forces = jnp.zeros_like(
        ode_data.ODEInput.zero(physics_model).physics_model.f_ext
    )

    # For each link owning at least one collidable point, store the sum of the
    # forces of the points attached to it.
    owner_links = physics_model.gc.body
    for link_index in map(int, set(owner_links)):
        belongs_to_link = owner_links == link_index
        link_forces = link_forces.at[link_index, :].set(
            jnp.sum(forces_columns[:, belongs_to_link], axis=1)
        )

    return link_forces, deformation_dot_columns, forces_columns.T
73
-
74
-
75
def dx_dt(
    x: ode_data.ODEState,
    t: jtp.Float | None,
    physics_model: PhysicsModel,
    soft_contacts_params: SoftContactsParams = SoftContactsParams(),
    ode_input: ode_data.ODEInput | None = None,
    terrain: Terrain = FlatTerrain(),
) -> Tuple[ode_data.ODEState, Dict[str, Any]]:
    """
    Compute the state derivative of the ODE corresponding to the physics model.

    The derivative accounts for the ODE input, the soft-contact forces with the
    terrain, the joint-limit torques and the joint friction torques. The forward
    dynamics is computed with the ABA algorithm.

    Args:
        x: The state of the ODE.
        t: The current time. It does not enter the computation; it is only checked
            to be a scalar when given as a NumPy array.
        physics_model: The physics model to consider.
        soft_contacts_params: The parameters of the soft contacts model.
        ode_input: The input of the ODE. A zero input is used when None.
        terrain: The terrain model.

    Returns:
        A tuple containing:
        - The state derivative of the ODE, stored in an ``ODEState`` so that its
          pytree structure matches the one of ``x``.
        - A dictionary containing auxiliary information, with keys
          ``model_acceleration``, ``ode_input``, ``ode_input_real``,
          ``contact_forces_links``, ``contact_forces_points`` (None when the model
          has no collidable points) and ``tangential_deformation_dot``.

    Raises:
        ValueError: If ``t`` is a NumPy array with more than one element.
    """

    if t is not None and isinstance(t, np.ndarray) and t.size != 1:
        raise ValueError(t.size)

    # Initialize arguments
    ode_state = x
    ode_input = (
        ode_input
        if ode_input is not None
        else ode_data.ODEInput.zero(physics_model=physics_model)
    )

    # ======================
    # Compute contact forces
    # ======================

    # Defaults kept when the model has no collidable points.
    contact_forces_points = None
    contact_forces_links = jnp.zeros_like(ode_input.physics_model.f_ext)
    tangential_deformation_dot = jnp.zeros_like(
        ode_state.soft_contacts.tangential_deformation
    )

    if physics_model.gc.body.size > 0:
        (
            contact_forces_links,
            tangential_deformation_dot,
            contact_forces_points,
        ) = compute_contact_forces(
            physics_model=physics_model,
            soft_contacts_params=soft_contacts_params,
            ode_state=ode_state,
            terrain=terrain,
        )

    # =====================
    # Joint position limits
    # =====================

    if physics_model.dofs() > 0:
        # Get the joint position limits
        s_min, s_max = jnp.array(
            [j.position_limit for j in physics_model.description.joints_dict.values()]
        ).T

        # Gain of the joint-limit enforcement.
        # NOTE(review): these `_joint_limit_damper` parameters multiply a position
        # error below, so they act as a stiffness; confirm this is the intended gain.
        k_damper = jnp.array(list(physics_model._joint_limit_damper.values()))

        # Restoring torques, non-zero only outside the [s_min, s_max] range.
        s = ode_state.physics_model.joint_positions
        tau_min = jnp.where(s <= s_min, k_damper * (s_min - s), 0)
        tau_max = jnp.where(s >= s_max, k_damper * (s_max - s), 0)
        tau_limit = tau_max + tau_min

    else:
        tau_limit = jnp.zeros_like(ode_input.physics_model.tau)

    # ==============
    # Joint friction
    # ==============

    if physics_model.dofs() > 0:
        # Static (Coulomb) and viscous joint friction parameters
        kc = jnp.array(list(physics_model._joint_friction_static.values()))
        kv = jnp.array(list(physics_model._joint_friction_viscous.values()))

        # Coulomb friction opposes the direction of motion, so it depends on the
        # sign of the joint velocities, not of the joint positions. Viscous
        # friction is proportional to the joint velocities.
        qd = ode_state.physics_model.joint_velocities
        tau_friction = -(jnp.diag(kc) @ jnp.sign(qd) + jnp.diag(kv) @ qd)

    else:
        tau_friction = jnp.zeros_like(ode_input.physics_model.tau)

    # ========================
    # Compute forward dynamics
    # ========================

    # Compute the total forces applied to the bodies
    total_forces = ode_input.physics_model.f_ext + contact_forces_links

    # Compute the joint torques to actuate
    tau = ode_input.physics_model.tau + tau_friction + tau_limit

    # Compute forward dynamics with the ABA algorithm
    W_a_WB, qdd = algos.aba.aba(
        model=physics_model,
        xfb=ode_state.physics_model.xfb(),
        q=ode_state.physics_model.joint_positions,
        qd=ode_state.physics_model.joint_velocities,
        tau=tau,
        f_ext=total_forces,
    )

    # =========================================
    # Compute the state derivative of base link
    # =========================================

    if not physics_model.is_floating_base:
        W_Qd_B = jnp.zeros(4)
        BW_v_WB = jnp.zeros(3)

    else:
        from jaxsim.math.conv import Convert
        from jaxsim.math.quaternion import Quaternion

        W_Qd_B = Quaternion.derivative(
            quaternion=ode_state.physics_model.base_quaternion,
            omega=ode_state.physics_model.base_angular_velocity,
            omega_in_body_fixed=False,
        ).squeeze()

        # Compute linear component of mixed velocity
        BW_v_WB = Convert.velocities_threed(
            v_6d=jnp.hstack(
                [
                    ode_state.physics_model.base_linear_velocity,
                    ode_state.physics_model.base_angular_velocity,
                ]
            ),
            p=ode_state.physics_model.base_position,
        ).squeeze()

    # Derivative of xfb (floating-base state), laid out as
    # [quaternion (0:4), position (4:7), base acceleration (7:13)].
    xd_fb = jnp.hstack([W_Qd_B, BW_v_WB, W_a_WB.squeeze()]).squeeze()

    # =====================================
    # Build the full derivative of ODEState
    # =====================================

    def fix_one_dof(vector: jtp.Vector) -> jtp.Vector | None:
        """Fix the shape of computed quantities for models with just 1 DoF."""

        if vector is None:
            return None

        return jnp.array([vector]) if vector.shape == () else vector

    # Fill the PhysicsModelState object included in the input ODEState to store the
    # returned PhysicsModelState derivative
    physics_model_state_derivative = ode_state.physics_model.replace(
        joint_positions=fix_one_dof(ode_state.physics_model.joint_velocities.squeeze()),
        joint_velocities=fix_one_dof(qdd.squeeze()),
        base_quaternion=xd_fb.squeeze()[0:4],
        base_position=xd_fb.squeeze()[4:7],
        base_angular_velocity=xd_fb.squeeze()[10:13],
        base_linear_velocity=xd_fb.squeeze()[7:10],
    )

    # Fill the SoftContactsState object included in the input ODEState to store the
    # returned SoftContactsState derivative
    soft_contacts_state_derivative = ode_state.soft_contacts.replace(
        tangential_deformation=tangential_deformation_dot.squeeze(),
    )

    # We store the state derivative using the ODEState class so that the pytree
    # structure remains consistent, allowing to use our generic pytree integrators
    state_derivative = ode_data.ODEState(
        physics_model=physics_model_state_derivative,
        soft_contacts=soft_contacts_state_derivative,
    )

    # ===============================
    # Build auxiliary data and return
    # ===============================

    # Real ODEInput containing the real joint forces that have been actuated and
    # the total external forces (= original external forces + terrain + limits)
    ode_input_real = ode_data.ODEInput(
        physics_model=ode_data.PhysicsModelInput(tau=tau, f_ext=total_forces)
    )

    # Pack the inertial-fixed floating-base acceleration
    W_nud_WB = jnp.hstack([W_a_WB.squeeze(), qdd.squeeze()])

    # Build the auxiliary data
    aux_dict = dict(
        model_acceleration=W_nud_WB,
        ode_input=ode_input,
        ode_input_real=ode_input_real,
        contact_forces_links=contact_forces_links,
        contact_forces_points=contact_forces_points,
        tangential_deformation_dot=tangential_deformation_dot,
    )

    # Return the state derivative as a generic PyTree, and the dict with auxiliary info
    return state_derivative, aux_dict
jaxsim/simulation/ode_data.py DELETED
@@ -1,53 +0,0 @@
1
- import jax.flatten_util
2
- import jax_dataclasses
3
-
4
- import jaxsim.typing as jtp
5
- from jaxsim.physics.algos.soft_contacts import SoftContactsState
6
- from jaxsim.physics.model.physics_model import PhysicsModel
7
- from jaxsim.physics.model.physics_model_state import (
8
- PhysicsModelInput,
9
- PhysicsModelState,
10
- )
11
- from jaxsim.utils import JaxsimDataclass
12
-
13
-
14
@jax_dataclasses.pytree_dataclass
class ODEInput(JaxsimDataclass):
    """Input of the ODE, wrapping the input of the physics model."""

    # Input of the physics model (joint torques `tau` and external forces `f_ext`).
    physics_model: PhysicsModelInput

    @staticmethod
    def zero(physics_model: PhysicsModel) -> "ODEInput":
        """
        Build an all-zero ODE input.

        Args:
            physics_model: The physics model the input refers to.

        Returns:
            An ODEInput wrapping ``PhysicsModelInput.zero(physics_model)``.
        """
        zero_model_input = PhysicsModelInput.zero(physics_model=physics_model)
        return ODEInput(physics_model=zero_model_input)

    def valid(self, physics_model: PhysicsModel) -> bool:
        """Return whether the wrapped physics-model input is valid for the model."""
        model_input = self.physics_model
        return model_input.valid(physics_model=physics_model)
26
-
27
-
28
@jax_dataclasses.pytree_dataclass
class ODEState(JaxsimDataclass):
    """State of the ODE: the physics-model state plus the soft-contacts state."""

    # State of the physics model.
    physics_model: PhysicsModelState
    # State of the soft-contacts model.
    soft_contacts: SoftContactsState

    @staticmethod
    def deserialize(data: jtp.VectorJax, physics_model: PhysicsModel) -> "ODEState":
        """
        Rebuild an ODEState from its flattened representation.

        The pytree structure is taken from ``ODEState.zero(physics_model)``, and
        ``data`` is unflattened with the function returned by
        ``jax.flatten_util.ravel_pytree``.

        Args:
            data: The flat vector to unflatten.
            physics_model: The physics model defining the state structure.

        Returns:
            The ODEState corresponding to ``data``.
        """
        template = ODEState.zero(physics_model=physics_model)
        unravel = jax.flatten_util.ravel_pytree(template)[1]
        return unravel(data)

    @staticmethod
    def zero(physics_model: PhysicsModel) -> "ODEState":
        """
        Build an all-zero ODE state for the given physics model.

        Args:
            physics_model: The physics model the state refers to.

        Returns:
            An ODEState built from the zero states of both sub-models.
        """
        zero_state = ODEState(
            physics_model=PhysicsModelState.zero(physics_model=physics_model),
            soft_contacts=SoftContactsState.zero(physics_model=physics_model),
        )

        # Internal sanity check (skipped when running under `python -O`).
        assert zero_state.valid(physics_model)
        return zero_state

    def valid(self, physics_model: PhysicsModel) -> bool:
        """Return whether both the physics-model and soft-contacts states are valid."""
        physics_model_ok = self.physics_model.valid(physics_model=physics_model)
        return physics_model_ok and self.soft_contacts.valid(
            physics_model=physics_model
        )
jaxsim/simulation/ode_integration.py DELETED
@@ -1,125 +0,0 @@
1
- import enum
2
- import functools
3
- from typing import Any, Dict, Tuple, Union
4
-
5
- import jax.flatten_util
6
- from jax.experimental.ode import odeint
7
-
8
- import jaxsim.typing as jtp
9
- from jaxsim.physics.algos.soft_contacts import SoftContactsParams
10
- from jaxsim.physics.algos.terrain import FlatTerrain, Terrain
11
- from jaxsim.physics.model.physics_model import PhysicsModel
12
- from jaxsim.simulation import integrators, ode
13
-
14
-
15
class IntegratorType(enum.IntEnum):
    """Integrator kinds: Runge-Kutta 4, forward Euler and semi-implicit Euler."""

    # Explicit values, identical to what `enum.auto()` generates in an IntEnum.
    RungeKutta4 = 1
    EulerForward = 2
    EulerSemiImplicit = 3
19
-
20
-
21
@functools.partial(jax.jit, static_argnames=["rtol", "atol", "mxstep", "hmax"])
def ode_integration_rk4_adaptive(
    x0: jtp.Array,
    t: integrators.TimeHorizon,
    physics_model: PhysicsModel,
    *args,
    **kwargs,
) -> jtp.Array:
    """
    Integrate the ODE of the physics model with JAX's adaptive-step ``odeint``.

    Args:
        x0: The initial state of the ODE. It is passed as-is to ``ode.dx_dt``,
            so presumably an ``ODEState`` pytree (TODO confirm against callers).
        t: The time instants at which the solution is returned.
        physics_model: The physics model to consider.
        *args: Extra positional arguments forwarded to ``ode.dx_dt`` after
            ``physics_model`` (``soft_contacts_params``, ``ode_input``, ``terrain``).
        **kwargs: Options forwarded to ``jax.experimental.ode.odeint``
            (``rtol``, ``atol``, ``mxstep``, ``hmax``). They are static for
            ``jax.jit`` because ``odeint`` requires them to be hashable, which
            traced values are not.

    Returns:
        The state of the ODE at each instant of ``t``.
    """

    # `odeint` requires `func(y, t)` to return only a pytree with the same
    # structure as `y`, while `dx_dt` returns `(derivative, aux_dict)`.
    def dx_dt_closure(x, ts):
        state_derivative, _ = ode.dx_dt(x, ts, physics_model, *args)
        return state_derivative

    return odeint(dx_dt_closure, x0, t, **kwargs)
33
-
34
-
35
@functools.partial(jax.jit, static_argnames=["num_sub_steps", "return_aux"])
def ode_integration_euler(
    x0: ode.ode_data.ODEState,
    t: integrators.TimeHorizon,
    physics_model: PhysicsModel,
    soft_contacts_params: SoftContactsParams = SoftContactsParams(),
    terrain: Terrain = FlatTerrain(),
    ode_input: ode.ode_data.ODEInput | None = None,
    *args,
    num_sub_steps: int = 1,
    return_aux: bool = False,
) -> Union[ode.ode_data.ODEState, Tuple[ode.ode_data.ODEState, Dict[str, Any]]]:
    """
    Integrate the ODE of the physics model with ``integrators.odeint_euler``.

    Args:
        x0: The initial state of the ODE.
        t: The time horizon of the integration.
        physics_model: The physics model to consider.
        soft_contacts_params: The parameters of the soft contacts model.
        terrain: The terrain model.
        ode_input: The input of the ODE (None lets ``dx_dt`` use a zero input).
        *args: Extra positional arguments forwarded to ``ode.dx_dt``.
        num_sub_steps: Number of integration sub-steps (static under jit).
        return_aux: Whether to also return the auxiliary dict (static under jit).

    Returns:
        The integrated state, or ``(state, aux)`` when ``return_aux`` is True.
    """

    def dynamics(x, ts):
        # `dx_dt` expects (soft_contacts_params, ode_input, terrain), an order
        # different from the one of this function's parameters.
        return ode.dx_dt(
            x,
            ts,
            physics_model,
            soft_contacts_params,
            ode_input,
            terrain,
            *args,
        )

    solution = integrators.odeint_euler(
        func=dynamics,
        y0=x0,
        t=t,
        num_sub_steps=num_sub_steps,
        return_aux=return_aux,
    )

    if not return_aux:
        return solution

    # Forward the integrated state and the aux dict.
    return solution[0], solution[1]
64
-
65
-
66
@functools.partial(jax.jit, static_argnames=["num_sub_steps", "return_aux"])
def ode_integration_euler_semi_implicit(
    x0: ode.ode_data.ODEState,
    t: integrators.TimeHorizon,
    physics_model: PhysicsModel,
    soft_contacts_params: SoftContactsParams = SoftContactsParams(),
    terrain: Terrain = FlatTerrain(),
    ode_input: ode.ode_data.ODEInput | None = None,
    *args,
    num_sub_steps: int = 1,
    return_aux: bool = False,
) -> Union[ode.ode_data.ODEState, Tuple[ode.ode_data.ODEState, Dict[str, Any]]]:
    """
    Integrate the ODE of the physics model with ``integrators.odeint_euler_semi_implicit``.

    Args:
        x0: The initial state of the ODE.
        t: The time horizon of the integration.
        physics_model: The physics model to consider.
        soft_contacts_params: The parameters of the soft contacts model.
        terrain: The terrain model.
        ode_input: The input of the ODE (None lets ``dx_dt`` use a zero input).
        *args: Extra positional arguments forwarded to ``ode.dx_dt``.
        num_sub_steps: Number of integration sub-steps (static under jit).
        return_aux: Whether to also return the auxiliary dict (static under jit).

    Returns:
        The integrated state, or ``(state, aux)`` when ``return_aux`` is True.
    """

    # Arguments of `dx_dt` following the physics model, in the order it expects.
    extra_dx_dt_args = (soft_contacts_params, ode_input, terrain, *args)

    def closed_dx_dt(x, ts):
        return ode.dx_dt(x, ts, physics_model, *extra_dx_dt_args)

    result = integrators.odeint_euler_semi_implicit(
        func=closed_dx_dt,
        y0=x0,
        t=t,
        num_sub_steps=num_sub_steps,
        return_aux=return_aux,
    )

    return (result[0], result[1]) if return_aux else result
95
-
96
-
97
@functools.partial(jax.jit, static_argnames=["num_sub_steps", "return_aux"])
def ode_integration_rk4(
    x0: ode.ode_data.ODEState,
    t: integrators.TimeHorizon,
    physics_model: PhysicsModel,
    soft_contacts_params: SoftContactsParams = SoftContactsParams(),
    terrain: Terrain = FlatTerrain(),
    ode_input: ode.ode_data.ODEInput | None = None,
    *args,
    num_sub_steps: int = 1,
    return_aux: bool = False,
) -> Union[ode.ode_data.ODEState, Tuple[ode.ode_data.ODEState, Dict[str, Any]]]:
    """
    Integrate the ODE of the physics model with ``integrators.odeint_rk4``.

    The signature mirrors the Euler-based integration functions of this module.

    Args:
        x0: The initial state of the ODE.
        t: The time horizon of the integration.
        physics_model: The physics model to consider.
        soft_contacts_params: The parameters of the soft contacts model.
        terrain: The terrain model.
        ode_input: The input of the ODE (None lets ``dx_dt`` use a zero input).
        *args: Extra positional arguments forwarded to ``ode.dx_dt``.
        num_sub_steps: Number of integration sub-steps (static under jit).
        return_aux: Whether to also return the auxiliary dict (static under jit).

    Returns:
        The integrated state, or ``(state, aux)`` when ``return_aux`` is True.
    """

    # Close `dx_dt` over the additional inputs and parameters. Note that `dx_dt`
    # takes (soft_contacts_params, ode_input, terrain) in this order.
    def dx_dt_closure(x, ts):
        return ode.dx_dt(
            x, ts, physics_model, soft_contacts_params, ode_input, terrain, *args
        )

    # Integrate over the horizon
    out = integrators.odeint_rk4(
        func=dx_dt_closure,
        y0=x0,
        t=t,
        num_sub_steps=num_sub_steps,
        return_aux=return_aux,
    )

    # Return output pytree and, optionally, the aux dict
    return (out[0], out[1]) if return_aux else out