jaxsim 0.1.dev401__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (89)
  1. jaxsim/__init__.py +5 -6
  2. jaxsim/_version.py +2 -2
  3. jaxsim/api/__init__.py +3 -0
  4. jaxsim/api/com.py +240 -0
  5. jaxsim/api/common.py +216 -0
  6. jaxsim/api/contact.py +271 -0
  7. jaxsim/api/data.py +821 -0
  8. jaxsim/api/joint.py +189 -0
  9. jaxsim/api/kin_dyn_parameters.py +777 -0
  10. jaxsim/api/link.py +361 -0
  11. jaxsim/api/model.py +1633 -0
  12. jaxsim/api/ode.py +295 -0
  13. jaxsim/api/ode_data.py +694 -0
  14. jaxsim/api/references.py +421 -0
  15. jaxsim/integrators/__init__.py +2 -0
  16. jaxsim/integrators/common.py +594 -0
  17. jaxsim/integrators/fixed_step.py +102 -0
  18. jaxsim/integrators/variable_step.py +610 -0
  19. jaxsim/math/__init__.py +11 -0
  20. jaxsim/math/adjoint.py +24 -2
  21. jaxsim/math/joint_model.py +335 -0
  22. jaxsim/math/quaternion.py +44 -3
  23. jaxsim/math/rotation.py +4 -4
  24. jaxsim/math/transform.py +92 -0
  25. jaxsim/mujoco/__init__.py +3 -0
  26. jaxsim/mujoco/__main__.py +192 -0
  27. jaxsim/mujoco/loaders.py +615 -0
  28. jaxsim/mujoco/model.py +414 -0
  29. jaxsim/mujoco/visualizer.py +176 -0
  30. jaxsim/parsers/descriptions/collision.py +14 -0
  31. jaxsim/parsers/descriptions/link.py +13 -2
  32. jaxsim/parsers/kinematic_graph.py +8 -3
  33. jaxsim/parsers/rod/parser.py +54 -38
  34. jaxsim/parsers/rod/utils.py +7 -8
  35. jaxsim/rbda/__init__.py +7 -0
  36. jaxsim/rbda/aba.py +295 -0
  37. jaxsim/rbda/collidable_points.py +142 -0
  38. jaxsim/{physics/algos → rbda}/crba.py +43 -42
  39. jaxsim/rbda/forward_kinematics.py +113 -0
  40. jaxsim/rbda/jacobian.py +201 -0
  41. jaxsim/rbda/rnea.py +237 -0
  42. jaxsim/rbda/soft_contacts.py +296 -0
  43. jaxsim/rbda/utils.py +152 -0
  44. jaxsim/terrain/__init__.py +2 -0
  45. jaxsim/{physics/algos → terrain}/terrain.py +4 -6
  46. jaxsim/typing.py +30 -30
  47. jaxsim/utils/__init__.py +1 -4
  48. jaxsim/utils/hashless.py +18 -0
  49. jaxsim/utils/jaxsim_dataclass.py +281 -31
  50. {jaxsim-0.1.dev401.dist-info → jaxsim-0.2.0.dist-info}/LICENSE +1 -1
  51. jaxsim-0.2.0.dist-info/METADATA +237 -0
  52. jaxsim-0.2.0.dist-info/RECORD +64 -0
  53. {jaxsim-0.1.dev401.dist-info → jaxsim-0.2.0.dist-info}/WHEEL +1 -1
  54. jaxsim/high_level/__init__.py +0 -2
  55. jaxsim/high_level/common.py +0 -11
  56. jaxsim/high_level/joint.py +0 -148
  57. jaxsim/high_level/link.py +0 -259
  58. jaxsim/high_level/model.py +0 -1695
  59. jaxsim/math/conv.py +0 -114
  60. jaxsim/math/joint.py +0 -101
  61. jaxsim/math/plucker.py +0 -100
  62. jaxsim/physics/__init__.py +0 -12
  63. jaxsim/physics/algos/__init__.py +0 -0
  64. jaxsim/physics/algos/aba.py +0 -256
  65. jaxsim/physics/algos/aba_motors.py +0 -284
  66. jaxsim/physics/algos/forward_kinematics.py +0 -79
  67. jaxsim/physics/algos/jacobian.py +0 -98
  68. jaxsim/physics/algos/rnea.py +0 -180
  69. jaxsim/physics/algos/rnea_motors.py +0 -196
  70. jaxsim/physics/algos/soft_contacts.py +0 -454
  71. jaxsim/physics/algos/utils.py +0 -69
  72. jaxsim/physics/model/__init__.py +0 -0
  73. jaxsim/physics/model/ground_contact.py +0 -55
  74. jaxsim/physics/model/physics_model.py +0 -358
  75. jaxsim/physics/model/physics_model_state.py +0 -174
  76. jaxsim/simulation/__init__.py +0 -4
  77. jaxsim/simulation/integrators.py +0 -452
  78. jaxsim/simulation/ode.py +0 -290
  79. jaxsim/simulation/ode_data.py +0 -53
  80. jaxsim/simulation/ode_integration.py +0 -125
  81. jaxsim/simulation/simulator.py +0 -544
  82. jaxsim/simulation/simulator_callbacks.py +0 -53
  83. jaxsim/simulation/utils.py +0 -15
  84. jaxsim/sixd/__init__.py +0 -2
  85. jaxsim/utils/oop.py +0 -532
  86. jaxsim/utils/vmappable.py +0 -117
  87. jaxsim-0.1.dev401.dist-info/METADATA +0 -167
  88. jaxsim-0.1.dev401.dist-info/RECORD +0 -64
  89. {jaxsim-0.1.dev401.dist-info → jaxsim-0.2.0.dist-info}/top_level.txt +0 -0
jaxsim/api/ode.py ADDED
@@ -0,0 +1,295 @@
1
+ from typing import Any, Protocol
2
+
3
+ import jax
4
+ import jax.numpy as jnp
5
+
6
+ import jaxsim.api as js
7
+ import jaxsim.rbda
8
+ import jaxsim.typing as jtp
9
+ from jaxsim.integrators import Time
10
+ from jaxsim.math import Quaternion
11
+
12
+ from .common import VelRepr
13
+ from .ode_data import ODEState
14
+
15
+
16
class SystemDynamicsFromModelAndData(Protocol):
    """
    Structural type of a callable that evaluates the system dynamics.

    The callable receives the model and its data (plus optional keyword
    arguments). It returns a tuple made of an `ODEState` and a dictionary of
    auxiliary data. For example, `system_dynamics` in this module returns an
    `ODEState` whose attributes hold the state derivative.
    """

    def __call__(
        self,
        model: js.model.JaxSimModel,
        data: js.data.JaxSimModelData,
        # Per PEP 484, the annotation of `**kwargs` types each keyword *value*,
        # not the whole mapping; values may be of any type.
        **kwargs: Any,
    ) -> tuple[ODEState, dict[str, Any]]: ...
23
+
24
+
25
def wrap_system_dynamics_for_integration(
    model: js.model.JaxSimModel,
    data: js.data.JaxSimModelData,
    *,
    system_dynamics: SystemDynamicsFromModelAndData,
    **kwargs,
) -> jaxsim.integrators.common.SystemDynamics[ODEState, ODEState]:
    """
    Adapt dynamics written for `JaxSimModel`/`JaxSimModelData` to the
    `(x, t) -> (ẋ, aux)` signature expected by `jaxsim.integrators`.

    Args:
        model: The model to consider.
        data: The data of the considered model.
        system_dynamics: The system dynamics to wrap.
        **kwargs: Extra keyword arguments bound into every call of the dynamics.

    Returns:
        A function of the state `x` and time `t` that evaluates `system_dynamics`
        on private copies of the model and data. Callers may override the model,
        the data, or any bound kwarg by passing them as keyword arguments.
    """

    # Snapshot of the kwargs to bind; the caller's dict is not referenced.
    default_kwargs = dict(kwargs)

    # Private copies owned by the closure. The state stored in the copied data
    # is reset to zero because it is overwritten by `x` on every evaluation.
    default_model = model.copy()
    default_data = data.copy().replace(
        state=js.ode_data.ODEState.zero(model=default_model)
    )

    def f(x: ODEState, t: Time, **kwargs_f) -> tuple[ODEState, dict[str, Any]]:
        # `model` and `data` kwargs, if given, replace the closed-over objects
        # and are not forwarded as regular kwargs.
        data_eval = kwargs_f.pop("data", default_data)
        model_eval = kwargs_f.pop("model", default_model)

        # Store the integrator's state and time (converted to nanoseconds)
        # inside the data object passed to the dynamics.
        with data_eval.editable(validate=True) as data_rw:
            data_rw.state = x
            data_rw.time_ns = jnp.array(t * 1e9).astype(data_rw.time_ns.dtype)

        # Call-time kwargs take precedence over the bound ones.
        call_kwargs = {**default_kwargs, **kwargs_f}

        return system_dynamics(model=model_eval, data=data_rw, **call_kwargs)

    f: jaxsim.integrators.common.SystemDynamics[ODEState, ODEState]
    return f
77
+
78
+
79
+ # ==================================
80
+ # Functions defining system dynamics
81
+ # ==================================
82
+
83
+
84
@jax.jit
def system_velocity_dynamics(
    model: js.model.JaxSimModel,
    data: js.data.JaxSimModelData,
    *,
    joint_forces: jtp.Vector | None = None,
    link_forces: jtp.Vector | None = None,
) -> tuple[jtp.Vector, jtp.Vector, jtp.Matrix, dict[str, Any]]:
    """
    Compute the dynamics of the system velocity.

    Args:
        model: The model to consider.
        data: The data of the considered model.
        joint_forces: The joint forces to apply. Defaults to zero.
        link_forces:
            The 6D forces to apply to the links, one row per link. They are
            passed to the forward dynamics computed in inertial-fixed
            representation. Defaults to zero.

    Returns:
        A tuple containing the derivative of the base 6D velocity in inertial-fixed
        representation, the derivative of the joint velocities, the derivative of
        the tangential deformation of the soft contacts (zeros when the model has
        no collidable points), and the dictionary of auxiliary data returned by
        the system dynamics evaluation.
    """

    # Build joint torques if not provided.
    # `jnp.asarray` also accepts plain sequences, not only arrays.
    τ = (
        jnp.atleast_1d(jnp.asarray(joint_forces).squeeze())
        if joint_forces is not None
        else jnp.zeros_like(data.joint_positions())
    ).astype(float)

    # Build link forces if not provided.
    W_f_L = (
        jnp.atleast_2d(jnp.asarray(link_forces).squeeze())
        if link_forces is not None
        else jnp.zeros((model.number_of_links(), 6))
    ).astype(float)

    # ======================
    # Compute contact forces
    # ======================

    # Initialize the 6D forces W_f ∈ ℝ^{n_L × 6} applied to links due to contact
    # with the terrain.
    W_f_Li_terrain = jnp.zeros_like(W_f_L).astype(float)

    # Initialize the derivative of the tangential deformation ṁ ∈ ℝ^{n_c × 3}.
    ṁ = jnp.zeros_like(data.state.soft_contacts.tangential_deformation).astype(float)

    if len(model.kin_dyn_parameters.contact_parameters.body) > 0:
        # Compute the 6D forces applied to each collidable point (expressed in the
        # world frame) and the corresponding material deformation rates.
        with data.switch_velocity_representation(VelRepr.Inertial):
            W_f_Ci, ṁ = js.contact.collidable_point_dynamics(model=model, data=data)

        # Parent link index of each collidable point.
        parent_link_index_of_collidable_points = jnp.array(
            model.kin_dyn_parameters.contact_parameters.body, dtype=int
        )

        # Sum the forces of all collidable points rigidly attached to the same link.
        # Since the contact forces W_f_Ci are expressed in the world frame,
        # no coordinate transformation is needed. Links without collidable points
        # receive a zero force.
        W_f_Li_terrain = jax.ops.segment_sum(
            W_f_Ci,
            parent_link_index_of_collidable_points,
            num_segments=model.number_of_links(),
        )

    # ====================
    # Enforce joint limits
    # ====================

    # TODO: enforce joint limits. This is currently a zero placeholder.
    τ_position_limit = jnp.zeros_like(τ).astype(float)

    # ====================
    # Joint friction model
    # ====================

    τ_friction = jnp.zeros_like(τ).astype(float)

    if model.dofs() > 0:
        # Static (Coulomb) and viscous joint friction parameters, one per joint.
        kc = jnp.array(
            model.kin_dyn_parameters.joint_parameters.friction_static
        ).astype(float)
        kv = jnp.array(
            model.kin_dyn_parameters.joint_parameters.friction_viscous
        ).astype(float)

        ṡ = data.state.physics_model.joint_velocities

        # Coulomb friction opposes the direction of motion, so it depends on the
        # sign of the joint *velocity*. Viscous friction is linear in the velocity.
        # Elementwise products are equivalent to diag(k) @ v without building
        # an n×n matrix.
        τ_friction = -(kc * jnp.sign(ṡ) + kv * ṡ)

    # ========================
    # Compute forward dynamics
    # ========================

    # Compute the total joint forces.
    τ_total = τ + τ_friction + τ_position_limit

    # Compute the total external 6D forces applied to the links.
    W_f_L_total = W_f_L + W_f_Li_terrain

    # - Joint accelerations: s̈ ∈ ℝⁿ
    # - Base inertial-fixed acceleration: W_v̇_WB = (W_p̈_B, W_ω̇_B) ∈ ℝ⁶
    with data.switch_velocity_representation(velocity_representation=VelRepr.Inertial):
        W_v̇_WB, s̈ = js.model.forward_dynamics_aba(
            model=model,
            data=data,
            joint_forces=τ_total,
            link_forces=W_f_L_total,
        )

    return W_v̇_WB, s̈, ṁ, dict()
211
+
212
+
213
@jax.jit
def system_position_dynamics(
    model: js.model.JaxSimModel, data: js.data.JaxSimModelData
) -> tuple[jtp.Vector, jtp.Vector, jtp.Vector]:
    """
    Compute the time derivative of the generalized position of the system.

    Args:
        model: The model to consider.
        data: The data of the considered model.

    Returns:
        A tuple with, in order, the derivative of the base position, the
        derivative of the base quaternion, and the derivative of the joint
        positions (i.e. the joint velocities).
    """

    joint_velocities = data.joint_velocities(model=model)
    base_quaternion = data.base_orientation(dcm=False)

    # The linear part of the mixed-representation base velocity is used as the
    # derivative of the base position.
    with data.switch_velocity_representation(velocity_representation=VelRepr.Mixed):
        mixed_base_velocity = data.base_velocity()

    # The angular part of the inertial-fixed base velocity feeds the quaternion
    # kinematics, hence `omega_in_body_fixed=False` below.
    with data.switch_velocity_representation(velocity_representation=VelRepr.Inertial):
        inertial_base_velocity = data.base_velocity()

    base_position_rate = mixed_base_velocity[0:3]
    world_angular_velocity = inertial_base_velocity[3:6]

    quaternion_rate = Quaternion.derivative(
        quaternion=base_quaternion,
        omega=world_angular_velocity,
        omega_in_body_fixed=False,
    ).squeeze()

    return base_position_rate, quaternion_rate, joint_velocities
245
+
246
+
247
@jax.jit
def system_dynamics(
    model: js.model.JaxSimModel,
    data: js.data.JaxSimModelData,
    *,
    joint_forces: jtp.Vector | None = None,
    link_forces: jtp.Vector | None = None,
) -> tuple[ODEState, dict[str, Any]]:
    """
    Evaluate the full system dynamics.

    Args:
        model: The model to consider.
        data: The data of the considered model.
        joint_forces: The joint forces to apply.
        link_forces: The 6D forces to apply to the links.

    Returns:
        A pair made of an `ODEState` whose attributes hold the derivative of the
        corresponding state component, and the dictionary of auxiliary data
        produced by the velocity dynamics.
    """

    # Velocity-level part: base acceleration (inertial-fixed), joint
    # accelerations and the rate of the tangential contact deformation.
    base_acceleration, joint_accelerations, deformation_rate, aux_dict = (
        system_velocity_dynamics(
            model=model,
            data=data,
            joint_forces=joint_forces,
            link_forces=link_forces,
        )
    )

    # Position-level part: derivatives of base position, base quaternion and
    # joint positions.
    base_position_rate, quaternion_rate, joint_position_rate = (
        system_position_dynamics(model=model, data=data)
    )

    # Each leaf of the resulting ODEState stores a derivative, so generic pytree
    # integrators can consume it directly as the state derivative.
    derivative_leaves = dict(
        base_position=base_position_rate,
        base_quaternion=quaternion_rate,
        joint_positions=joint_position_rate,
        base_linear_velocity=base_acceleration[0:3],
        base_angular_velocity=base_acceleration[3:6],
        joint_velocities=joint_accelerations,
        tangential_deformation=deformation_rate,
    )

    ode_state_derivative = ODEState.build_from_jaxsim_model(
        model=model, **derivative_leaves
    )

    return ode_state_derivative, aux_dict