jaxsim 0.2.dev191__py3-none-any.whl → 0.6.1.dev2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (109)
  1. jaxsim/__init__.py +73 -22
  2. jaxsim/_version.py +2 -2
  3. jaxsim/api/__init__.py +13 -1
  4. jaxsim/api/com.py +423 -0
  5. jaxsim/api/common.py +48 -19
  6. jaxsim/api/contact.py +604 -52
  7. jaxsim/api/data.py +308 -163
  8. jaxsim/api/frame.py +471 -0
  9. jaxsim/api/joint.py +166 -37
  10. jaxsim/api/kin_dyn_parameters.py +901 -0
  11. jaxsim/api/link.py +277 -78
  12. jaxsim/api/model.py +1572 -362
  13. jaxsim/api/ode.py +324 -133
  14. jaxsim/api/ode_data.py +401 -0
  15. jaxsim/api/references.py +216 -80
  16. jaxsim/exceptions.py +80 -0
  17. jaxsim/integrators/__init__.py +2 -2
  18. jaxsim/integrators/common.py +191 -107
  19. jaxsim/integrators/fixed_step.py +97 -102
  20. jaxsim/integrators/variable_step.py +706 -0
  21. jaxsim/logging.py +1 -2
  22. jaxsim/math/__init__.py +13 -0
  23. jaxsim/math/adjoint.py +57 -22
  24. jaxsim/math/cross.py +16 -7
  25. jaxsim/math/inertia.py +10 -8
  26. jaxsim/math/joint_model.py +289 -0
  27. jaxsim/math/quaternion.py +54 -20
  28. jaxsim/math/rotation.py +27 -21
  29. jaxsim/math/skew.py +16 -5
  30. jaxsim/math/transform.py +102 -0
  31. jaxsim/math/utils.py +31 -0
  32. jaxsim/mujoco/__init__.py +2 -1
  33. jaxsim/mujoco/loaders.py +216 -29
  34. jaxsim/mujoco/model.py +163 -33
  35. jaxsim/mujoco/utils.py +228 -0
  36. jaxsim/mujoco/visualizer.py +107 -22
  37. jaxsim/parsers/__init__.py +0 -1
  38. jaxsim/parsers/descriptions/__init__.py +8 -2
  39. jaxsim/parsers/descriptions/collision.py +83 -26
  40. jaxsim/parsers/descriptions/joint.py +80 -87
  41. jaxsim/parsers/descriptions/link.py +58 -31
  42. jaxsim/parsers/descriptions/model.py +101 -68
  43. jaxsim/parsers/kinematic_graph.py +606 -229
  44. jaxsim/parsers/rod/meshes.py +104 -0
  45. jaxsim/parsers/rod/parser.py +125 -82
  46. jaxsim/parsers/rod/utils.py +127 -82
  47. jaxsim/rbda/__init__.py +11 -0
  48. jaxsim/rbda/aba.py +289 -0
  49. jaxsim/rbda/collidable_points.py +156 -0
  50. jaxsim/rbda/contacts/__init__.py +13 -0
  51. jaxsim/rbda/contacts/common.py +313 -0
  52. jaxsim/rbda/contacts/relaxed_rigid.py +605 -0
  53. jaxsim/rbda/contacts/rigid.py +462 -0
  54. jaxsim/rbda/contacts/soft.py +480 -0
  55. jaxsim/rbda/contacts/visco_elastic.py +1066 -0
  56. jaxsim/rbda/crba.py +167 -0
  57. jaxsim/rbda/forward_kinematics.py +117 -0
  58. jaxsim/rbda/jacobian.py +330 -0
  59. jaxsim/rbda/rnea.py +235 -0
  60. jaxsim/rbda/utils.py +160 -0
  61. jaxsim/terrain/__init__.py +2 -0
  62. jaxsim/terrain/terrain.py +238 -0
  63. jaxsim/typing.py +24 -24
  64. jaxsim/utils/__init__.py +1 -4
  65. jaxsim/utils/jaxsim_dataclass.py +289 -34
  66. jaxsim/utils/tracing.py +5 -11
  67. jaxsim/utils/wrappers.py +159 -0
  68. {jaxsim-0.2.dev191.dist-info → jaxsim-0.6.1.dev2.dist-info}/LICENSE +1 -1
  69. jaxsim-0.6.1.dev2.dist-info/METADATA +465 -0
  70. jaxsim-0.6.1.dev2.dist-info/RECORD +74 -0
  71. {jaxsim-0.2.dev191.dist-info → jaxsim-0.6.1.dev2.dist-info}/WHEEL +1 -1
  72. jaxsim/high_level/__init__.py +0 -2
  73. jaxsim/high_level/common.py +0 -11
  74. jaxsim/high_level/joint.py +0 -148
  75. jaxsim/high_level/link.py +0 -259
  76. jaxsim/high_level/model.py +0 -1686
  77. jaxsim/math/conv.py +0 -114
  78. jaxsim/math/joint.py +0 -102
  79. jaxsim/math/plucker.py +0 -100
  80. jaxsim/physics/__init__.py +0 -12
  81. jaxsim/physics/algos/__init__.py +0 -0
  82. jaxsim/physics/algos/aba.py +0 -254
  83. jaxsim/physics/algos/aba_motors.py +0 -284
  84. jaxsim/physics/algos/crba.py +0 -154
  85. jaxsim/physics/algos/forward_kinematics.py +0 -79
  86. jaxsim/physics/algos/jacobian.py +0 -98
  87. jaxsim/physics/algos/rnea.py +0 -180
  88. jaxsim/physics/algos/rnea_motors.py +0 -196
  89. jaxsim/physics/algos/soft_contacts.py +0 -523
  90. jaxsim/physics/algos/terrain.py +0 -78
  91. jaxsim/physics/algos/utils.py +0 -69
  92. jaxsim/physics/model/__init__.py +0 -0
  93. jaxsim/physics/model/ground_contact.py +0 -53
  94. jaxsim/physics/model/physics_model.py +0 -388
  95. jaxsim/physics/model/physics_model_state.py +0 -283
  96. jaxsim/simulation/__init__.py +0 -4
  97. jaxsim/simulation/integrators.py +0 -393
  98. jaxsim/simulation/ode.py +0 -290
  99. jaxsim/simulation/ode_data.py +0 -96
  100. jaxsim/simulation/ode_integration.py +0 -62
  101. jaxsim/simulation/simulator.py +0 -543
  102. jaxsim/simulation/simulator_callbacks.py +0 -79
  103. jaxsim/simulation/utils.py +0 -15
  104. jaxsim/sixd/__init__.py +0 -2
  105. jaxsim/utils/oop.py +0 -536
  106. jaxsim/utils/vmappable.py +0 -117
  107. jaxsim-0.2.dev191.dist-info/METADATA +0 -184
  108. jaxsim-0.2.dev191.dist-info/RECORD +0 -81
  109. {jaxsim-0.2.dev191.dist-info → jaxsim-0.6.1.dev2.dist-info}/top_level.txt +0 -0
jaxsim/simulation/ode.py DELETED
@@ -1,290 +0,0 @@
1
- from typing import Any, Dict, Tuple
2
-
3
- import jax
4
- import jax.numpy as jnp
5
- import numpy as np
6
-
7
- import jaxsim.typing as jtp
8
- from jaxsim.physics import algos
9
- from jaxsim.physics.algos.soft_contacts import (
10
- SoftContacts,
11
- SoftContactsParams,
12
- collidable_points_pos_vel,
13
- )
14
- from jaxsim.physics.algos.terrain import FlatTerrain, Terrain
15
- from jaxsim.physics.model.physics_model import PhysicsModel
16
-
17
- from . import ode_data
18
-
19
-
20
- def compute_contact_forces(
21
- physics_model: PhysicsModel,
22
- ode_state: ode_data.ODEState,
23
- soft_contacts_params: SoftContactsParams = SoftContactsParams(),
24
- terrain: Terrain = FlatTerrain(),
25
- ) -> Tuple[jtp.Matrix, jtp.Matrix, jtp.Matrix]:
26
- """
27
- Compute the contact forces acting on the collidable points of the model.
28
-
29
- Args:
30
- physics_model: The physics model to consider.
31
- ode_state: The state of the ODE corresponding to the physics model.
32
- soft_contacts_params: The parameters of the soft contacts model.
33
- terrain: The terrain model.
34
-
35
- Returns:
36
- A tuple containing:
37
- - The contact forces expressed in the world frame acting on the model's links.
38
- - The derivative of the tangential deformation of the terrain dynamics.
39
- - The contact forces expressed in the world frame acting on the model's collidable points.
40
- """
41
-
42
- # Compute position and linear mixed velocity of all model's collidable points
43
- # collidable_points_kinematics
44
- pos_cp, vel_cp = collidable_points_pos_vel(
45
- model=physics_model,
46
- q=ode_state.physics_model.joint_positions,
47
- qd=ode_state.physics_model.joint_velocities,
48
- xfb=ode_state.physics_model.xfb(),
49
- )
50
-
51
- # Compute the forces acting on the collidable points due to contact with
52
- # the compliant ground surface. Apply vmap to process all points together.
53
- contact_forces_points, tangential_deformation_dot = jax.vmap(
54
- SoftContacts(parameters=soft_contacts_params, terrain=terrain).contact_model
55
- )(pos_cp.T, vel_cp.T, ode_state.soft_contacts.tangential_deformation.T)
56
-
57
- contact_forces_points = contact_forces_points.T
58
- tangential_deformation_dot = tangential_deformation_dot.T
59
-
60
- # Initialize the contact forces, one per body
61
- contact_forces_links = jnp.zeros_like(
62
- ode_data.ODEInput.zero(physics_model).physics_model.f_ext
63
- )
64
-
65
- # Combine the contact forces of all collidable points belonging to the same body
66
- for body_idx in set(physics_model.gc.body):
67
- body_idx = int(body_idx)
68
- contact_forces_links = contact_forces_links.at[body_idx, :].set(
69
- jnp.sum(contact_forces_points[:, physics_model.gc.body == body_idx], axis=1)
70
- )
71
-
72
- return contact_forces_links, tangential_deformation_dot, contact_forces_points.T
73
-
74
-
75
- def dx_dt(
76
- x: ode_data.ODEState,
77
- t: jtp.Float | None,
78
- physics_model: PhysicsModel,
79
- soft_contacts_params: SoftContactsParams = SoftContactsParams(),
80
- ode_input: ode_data.ODEInput | None = None,
81
- terrain: Terrain = FlatTerrain(),
82
- ) -> Tuple[ode_data.ODEState, Dict[str, Any]]:
83
- """
84
- Compute the state derivative of the ODE corresponding to the physics model.
85
-
86
- Args:
87
- x: The state of the ODE.
88
- t: The current time.
89
- physics_model: The physics model to consider.
90
- soft_contacts_params: The parameters of the soft contacts model.
91
- ode_input: The input of the ODE.
92
- terrain: The terrain model.
93
-
94
- Returns:
95
- A tuple containing:
96
- - The state derivative of the ODE.
97
- - A dictionary containing auxiliary information.
98
- """
99
-
100
- if t is not None and isinstance(t, np.ndarray) and t.size != 1:
101
- raise ValueError(t.size)
102
-
103
- # Initialize arguments
104
- ode_state = x
105
- ode_input = (
106
- ode_input
107
- if ode_input is not None
108
- else ode_data.ODEInput.zero(physics_model=physics_model)
109
- )
110
-
111
- # ======================
112
- # Compute contact forces
113
- # ======================
114
-
115
- # Initialize the collidable points contact forces
116
- contact_forces_points = None
117
-
118
- # Initialize the contact forces, one per body
119
- contact_forces_links = jnp.zeros_like(ode_input.physics_model.f_ext)
120
-
121
- # Initialize the derivative of the tangential deformation
122
- tangential_deformation_dot = jnp.zeros_like(
123
- ode_state.soft_contacts.tangential_deformation
124
- )
125
-
126
- if len(physics_model.gc.body) > 0:
127
- (
128
- contact_forces_links,
129
- tangential_deformation_dot,
130
- contact_forces_points,
131
- ) = compute_contact_forces(
132
- physics_model=physics_model,
133
- soft_contacts_params=soft_contacts_params,
134
- ode_state=ode_state,
135
- terrain=terrain,
136
- )
137
-
138
- # =====================
139
- # Joint position limits
140
- # =====================
141
-
142
- if physics_model.dofs() > 0:
143
- # Get the joint position limits
144
- s_min, s_max = jnp.array(
145
- [j.position_limit for j in physics_model.description.joints_dict.values()]
146
- ).T
147
-
148
- # Get the spring/damper parameters of joint limits enforcement
149
- k_damper = jnp.array(list(physics_model._joint_limit_damper.values()))
150
-
151
- # Compute the joint torques that enforce joint limits
152
- s = ode_state.physics_model.joint_positions
153
- tau_min = jnp.where(s <= s_min, k_damper * (s_min - s), 0)
154
- tau_max = jnp.where(s >= s_max, k_damper * (s_max - s), 0)
155
- tau_limit = tau_max + tau_min
156
-
157
- else:
158
- tau_limit = jnp.zeros_like(ode_input.physics_model.tau)
159
-
160
- # ==============
161
- # Joint friction
162
- # ==============
163
-
164
- if physics_model.dofs() > 0:
165
- # Static and viscous joint friction parameters
166
- kc = jnp.array(list(physics_model._joint_friction_static.values()))
167
- kv = jnp.array(list(physics_model._joint_friction_viscous.values()))
168
-
169
- # Compute the joint friction torque
170
- tau_friction = -(
171
- jnp.diag(kc) @ jnp.sign(ode_state.physics_model.joint_positions)
172
- + jnp.diag(kv) @ ode_state.physics_model.joint_velocities
173
- )
174
-
175
- else:
176
- tau_friction = jnp.zeros_like(ode_input.physics_model.tau)
177
-
178
- # ========================
179
- # Compute forward dynamics
180
- # ========================
181
-
182
- # Compute the total forces applied to the bodies
183
- total_forces = ode_input.physics_model.f_ext + contact_forces_links
184
-
185
- # Compute the joint torques to actuate
186
- tau = ode_input.physics_model.tau + tau_friction + tau_limit
187
-
188
- # Compute forward dynamics with the ABA algorithm
189
- W_a_WB, qdd = algos.aba.aba(
190
- model=physics_model,
191
- xfb=ode_state.physics_model.xfb(),
192
- q=ode_state.physics_model.joint_positions,
193
- qd=ode_state.physics_model.joint_velocities,
194
- tau=tau,
195
- f_ext=total_forces,
196
- )
197
-
198
- # =========================================
199
- # Compute the state derivative of base link
200
- # =========================================
201
-
202
- if not physics_model.is_floating_base:
203
- W_Qd_B = jnp.zeros(4)
204
- BW_v_WB = jnp.zeros(3)
205
-
206
- else:
207
- from jaxsim.math.conv import Convert
208
- from jaxsim.math.quaternion import Quaternion
209
-
210
- W_Qd_B = Quaternion.derivative(
211
- quaternion=ode_state.physics_model.base_quaternion,
212
- omega=ode_state.physics_model.base_angular_velocity,
213
- omega_in_body_fixed=False,
214
- ).squeeze()
215
-
216
- # Compute linear component of mixed velocity
217
- BW_v_WB = Convert.velocities_threed(
218
- v_6d=jnp.hstack(
219
- [
220
- ode_state.physics_model.base_linear_velocity,
221
- ode_state.physics_model.base_angular_velocity,
222
- ]
223
- ),
224
- p=ode_state.physics_model.base_position,
225
- ).squeeze()
226
-
227
- # Derivative of xfb (floating-base state)
228
- xd_fb = jnp.hstack([W_Qd_B, BW_v_WB, W_a_WB.squeeze()]).squeeze()
229
-
230
- # =====================================
231
- # Build the full derivative of ODEState
232
- # =====================================
233
-
234
- def fix_one_dof(vector: jtp.Vector) -> jtp.Vector | None:
235
- """Fix the shape of computed quantities for models with just 1 DoF."""
236
-
237
- if vector is None:
238
- return None
239
-
240
- return jnp.array([vector]) if vector.shape == () else vector
241
-
242
- # Fill the PhysicsModelState object included in the input ODEState to store the
243
- # returned PhysicsModelState derivative
244
- physics_model_state_derivative = ode_state.physics_model.replace(
245
- joint_positions=fix_one_dof(ode_state.physics_model.joint_velocities.squeeze()),
246
- joint_velocities=fix_one_dof(qdd.squeeze()),
247
- base_quaternion=xd_fb.squeeze()[0:4],
248
- base_position=xd_fb.squeeze()[4:7],
249
- base_angular_velocity=xd_fb.squeeze()[10:13],
250
- base_linear_velocity=xd_fb.squeeze()[7:10],
251
- )
252
-
253
- # Fill the SoftContactsState object included in the input ODEState to store the
254
- # returned SoftContactsState derivative
255
- soft_contacts_state_derivative = ode_state.soft_contacts.replace(
256
- tangential_deformation=tangential_deformation_dot.squeeze(),
257
- )
258
-
259
- # We store the state derivative using the ODEState class so that the pytree
260
- # structure remains consistent, allowing to use our generic pytree integrators
261
- state_derivative = ode_data.ODEState(
262
- physics_model=physics_model_state_derivative,
263
- soft_contacts=soft_contacts_state_derivative,
264
- )
265
-
266
- # ===============================
267
- # Build auxiliary data and return
268
- # ===============================
269
-
270
- # Real ODEInput containing the real joint forces that have been actuated and
271
- # the total external forces (= original external forces + terrain + limits)
272
- ode_input_real = ode_data.ODEInput(
273
- physics_model=ode_data.PhysicsModelInput(tau=tau, f_ext=total_forces)
274
- )
275
-
276
- # Pack the inertial-fixed floating-base acceleration
277
- W_nud_WB = jnp.hstack([W_a_WB.squeeze(), qdd.squeeze()])
278
-
279
- # Build the auxiliary data
280
- aux_dict = {
281
- "model_acceleration": W_nud_WB,
282
- "ode_input": ode_input,
283
- "ode_input_real": ode_input_real,
284
- "contact_forces_links": contact_forces_links,
285
- "contact_forces_points": contact_forces_points,
286
- "tangential_deformation_dot": tangential_deformation_dot,
287
- }
288
-
289
- # Return the state derivative as a generic PyTree, and the dict with auxiliary info
290
- return state_derivative, aux_dict
jaxsim/simulation/ode_data.py DELETED
@@ -1,96 +0,0 @@
1
- import jax.flatten_util
2
- import jax_dataclasses
3
-
4
- import jaxsim.typing as jtp
5
- from jaxsim.physics.algos.soft_contacts import SoftContactsState
6
- from jaxsim.physics.model.physics_model import PhysicsModel
7
- from jaxsim.physics.model.physics_model_state import (
8
- PhysicsModelInput,
9
- PhysicsModelState,
10
- )
11
- from jaxsim.utils import JaxsimDataclass
12
-
13
-
14
- @jax_dataclasses.pytree_dataclass
15
- class ODEInput(JaxsimDataclass):
16
- """"""
17
-
18
- physics_model: PhysicsModelInput
19
-
20
- @staticmethod
21
- def build(
22
- physics_model_input: PhysicsModelInput | None = None,
23
- physics_model: PhysicsModel | None = None,
24
- ) -> "ODEInput":
25
- """"""
26
-
27
- physics_model_input = (
28
- physics_model_input
29
- if physics_model_input is not None
30
- else PhysicsModelInput.zero(physics_model=physics_model)
31
- )
32
-
33
- return ODEInput(physics_model=physics_model_input)
34
-
35
- @staticmethod
36
- def zero(physics_model: PhysicsModel) -> "ODEInput":
37
- return ODEInput(
38
- physics_model=PhysicsModelInput.zero(physics_model=physics_model)
39
- )
40
-
41
- def valid(self, physics_model: PhysicsModel) -> bool:
42
- return self.physics_model.valid(physics_model=physics_model)
43
-
44
-
45
- @jax_dataclasses.pytree_dataclass
46
- class ODEState(JaxsimDataclass):
47
- """"""
48
-
49
- physics_model: PhysicsModelState
50
- soft_contacts: SoftContactsState
51
-
52
- @staticmethod
53
- def build(
54
- physics_model_state: PhysicsModelState | None = None,
55
- soft_contacts_state: SoftContactsState | None = None,
56
- physics_model: PhysicsModel | None = None,
57
- ) -> "ODEState":
58
- """"""
59
-
60
- physics_model_state = (
61
- physics_model_state
62
- if physics_model_state is not None
63
- else PhysicsModelState.zero(physics_model=physics_model)
64
- )
65
-
66
- soft_contacts_state = (
67
- soft_contacts_state
68
- if soft_contacts_state is not None
69
- else SoftContactsState.zero(physics_model=physics_model)
70
- )
71
-
72
- return ODEState(
73
- physics_model=physics_model_state, soft_contacts=soft_contacts_state
74
- )
75
-
76
- @staticmethod
77
- def deserialize(data: jtp.VectorJax, physics_model: PhysicsModel) -> "ODEState":
78
- dummy_object = ODEState.zero(physics_model=physics_model)
79
- _, unflatten_data = jax.flatten_util.ravel_pytree(dummy_object)
80
-
81
- return unflatten_data(data)
82
-
83
- @staticmethod
84
- def zero(physics_model: PhysicsModel) -> "ODEState":
85
- model_state = ODEState(
86
- physics_model=PhysicsModelState.zero(physics_model=physics_model),
87
- soft_contacts=SoftContactsState.zero(physics_model=physics_model),
88
- )
89
-
90
- assert model_state.valid(physics_model)
91
- return model_state
92
-
93
- def valid(self, physics_model: PhysicsModel) -> bool:
94
- return self.physics_model.valid(
95
- physics_model=physics_model
96
- ) and self.soft_contacts.valid(physics_model=physics_model)
jaxsim/simulation/ode_integration.py DELETED
@@ -1,62 +0,0 @@
1
- import enum
2
- import functools
3
- from typing import Any, Dict, Tuple, Union
4
-
5
- import jax.flatten_util
6
- from jax.experimental.ode import odeint
7
-
8
- import jaxsim.typing as jtp
9
- from jaxsim.physics.algos.soft_contacts import SoftContactsParams
10
- from jaxsim.physics.algos.terrain import FlatTerrain, Terrain
11
- from jaxsim.physics.model.physics_model import PhysicsModel
12
- from jaxsim.simulation import integrators, ode
13
- from jaxsim.simulation.integrators import IntegratorType
14
-
15
-
16
- @jax.jit
17
- def ode_integration_rk4_adaptive(
18
- x0: jtp.Array,
19
- t: integrators.TimeHorizon,
20
- physics_model: PhysicsModel,
21
- *args,
22
- **kwargs,
23
- ) -> jtp.Array:
24
- # Close function over its inputs and parameters
25
- dx_dt_closure = lambda x, ts: ode.dx_dt(x, ts, physics_model, *args)
26
-
27
- return odeint(dx_dt_closure, x0, t, **kwargs)
28
-
29
-
30
- @functools.partial(
31
- jax.jit, static_argnames=["num_sub_steps", "integrator_type", "return_aux"]
32
- )
33
- def ode_integration_fixed_step(
34
- x0: ode.ode_data.ODEState,
35
- t: integrators.TimeHorizon,
36
- physics_model: PhysicsModel,
37
- integrator_type: IntegratorType,
38
- soft_contacts_params: SoftContactsParams = SoftContactsParams(),
39
- terrain: Terrain = FlatTerrain(),
40
- ode_input: ode.ode_data.ODEInput | None = None,
41
- *args,
42
- num_sub_steps: int = 1,
43
- return_aux: bool = False,
44
- ) -> Union[ode.ode_data.ODEState, Tuple[ode.ode_data.ODEState, Dict]]:
45
- # Close func over additional inputs and parameters
46
- dx_dt_closure = lambda x, ts: ode.dx_dt(
47
- x, ts, physics_model, soft_contacts_params, ode_input, terrain, *args
48
- )
49
-
50
- # Integrate over the horizon
51
- out = integrators.odeint(
52
- func=dx_dt_closure,
53
- y0=x0,
54
- t=t,
55
- num_sub_steps=num_sub_steps,
56
- return_aux=return_aux,
57
- integrator_type=integrator_type,
58
- )
59
-
60
- # Return output pytree and, optionally, the aux dict
61
- state = out if not return_aux else out[0]
62
- return (state, out[1]) if return_aux else state