jaxsim 0.1.dev401__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (89)
  1. jaxsim/__init__.py +5 -6
  2. jaxsim/_version.py +2 -2
  3. jaxsim/api/__init__.py +3 -0
  4. jaxsim/api/com.py +240 -0
  5. jaxsim/api/common.py +216 -0
  6. jaxsim/api/contact.py +271 -0
  7. jaxsim/api/data.py +821 -0
  8. jaxsim/api/joint.py +189 -0
  9. jaxsim/api/kin_dyn_parameters.py +777 -0
  10. jaxsim/api/link.py +361 -0
  11. jaxsim/api/model.py +1633 -0
  12. jaxsim/api/ode.py +295 -0
  13. jaxsim/api/ode_data.py +694 -0
  14. jaxsim/api/references.py +421 -0
  15. jaxsim/integrators/__init__.py +2 -0
  16. jaxsim/integrators/common.py +594 -0
  17. jaxsim/integrators/fixed_step.py +102 -0
  18. jaxsim/integrators/variable_step.py +610 -0
  19. jaxsim/math/__init__.py +11 -0
  20. jaxsim/math/adjoint.py +24 -2
  21. jaxsim/math/joint_model.py +335 -0
  22. jaxsim/math/quaternion.py +44 -3
  23. jaxsim/math/rotation.py +4 -4
  24. jaxsim/math/transform.py +92 -0
  25. jaxsim/mujoco/__init__.py +3 -0
  26. jaxsim/mujoco/__main__.py +192 -0
  27. jaxsim/mujoco/loaders.py +615 -0
  28. jaxsim/mujoco/model.py +414 -0
  29. jaxsim/mujoco/visualizer.py +176 -0
  30. jaxsim/parsers/descriptions/collision.py +14 -0
  31. jaxsim/parsers/descriptions/link.py +13 -2
  32. jaxsim/parsers/kinematic_graph.py +8 -3
  33. jaxsim/parsers/rod/parser.py +54 -38
  34. jaxsim/parsers/rod/utils.py +7 -8
  35. jaxsim/rbda/__init__.py +7 -0
  36. jaxsim/rbda/aba.py +295 -0
  37. jaxsim/rbda/collidable_points.py +142 -0
  38. jaxsim/{physics/algos → rbda}/crba.py +43 -42
  39. jaxsim/rbda/forward_kinematics.py +113 -0
  40. jaxsim/rbda/jacobian.py +201 -0
  41. jaxsim/rbda/rnea.py +237 -0
  42. jaxsim/rbda/soft_contacts.py +296 -0
  43. jaxsim/rbda/utils.py +152 -0
  44. jaxsim/terrain/__init__.py +2 -0
  45. jaxsim/{physics/algos → terrain}/terrain.py +4 -6
  46. jaxsim/typing.py +30 -30
  47. jaxsim/utils/__init__.py +1 -4
  48. jaxsim/utils/hashless.py +18 -0
  49. jaxsim/utils/jaxsim_dataclass.py +281 -31
  50. {jaxsim-0.1.dev401.dist-info → jaxsim-0.2.0.dist-info}/LICENSE +1 -1
  51. jaxsim-0.2.0.dist-info/METADATA +237 -0
  52. jaxsim-0.2.0.dist-info/RECORD +64 -0
  53. {jaxsim-0.1.dev401.dist-info → jaxsim-0.2.0.dist-info}/WHEEL +1 -1
  54. jaxsim/high_level/__init__.py +0 -2
  55. jaxsim/high_level/common.py +0 -11
  56. jaxsim/high_level/joint.py +0 -148
  57. jaxsim/high_level/link.py +0 -259
  58. jaxsim/high_level/model.py +0 -1695
  59. jaxsim/math/conv.py +0 -114
  60. jaxsim/math/joint.py +0 -101
  61. jaxsim/math/plucker.py +0 -100
  62. jaxsim/physics/__init__.py +0 -12
  63. jaxsim/physics/algos/__init__.py +0 -0
  64. jaxsim/physics/algos/aba.py +0 -256
  65. jaxsim/physics/algos/aba_motors.py +0 -284
  66. jaxsim/physics/algos/forward_kinematics.py +0 -79
  67. jaxsim/physics/algos/jacobian.py +0 -98
  68. jaxsim/physics/algos/rnea.py +0 -180
  69. jaxsim/physics/algos/rnea_motors.py +0 -196
  70. jaxsim/physics/algos/soft_contacts.py +0 -454
  71. jaxsim/physics/algos/utils.py +0 -69
  72. jaxsim/physics/model/__init__.py +0 -0
  73. jaxsim/physics/model/ground_contact.py +0 -55
  74. jaxsim/physics/model/physics_model.py +0 -358
  75. jaxsim/physics/model/physics_model_state.py +0 -174
  76. jaxsim/simulation/__init__.py +0 -4
  77. jaxsim/simulation/integrators.py +0 -452
  78. jaxsim/simulation/ode.py +0 -290
  79. jaxsim/simulation/ode_data.py +0 -53
  80. jaxsim/simulation/ode_integration.py +0 -125
  81. jaxsim/simulation/simulator.py +0 -544
  82. jaxsim/simulation/simulator_callbacks.py +0 -53
  83. jaxsim/simulation/utils.py +0 -15
  84. jaxsim/sixd/__init__.py +0 -2
  85. jaxsim/utils/oop.py +0 -532
  86. jaxsim/utils/vmappable.py +0 -117
  87. jaxsim-0.1.dev401.dist-info/METADATA +0 -167
  88. jaxsim-0.1.dev401.dist-info/RECORD +0 -64
  89. {jaxsim-0.1.dev401.dist-info → jaxsim-0.2.0.dist-info}/top_level.txt +0 -0
jaxsim/api/link.py ADDED
@@ -0,0 +1,361 @@
1
+ import functools
2
+ from typing import Sequence
3
+
4
+ import jax
5
+ import jax.numpy as jnp
6
+ import jaxlie
7
+ import numpy as np
8
+
9
+ import jaxsim.api as js
10
+ import jaxsim.rbda
11
+ import jaxsim.typing as jtp
12
+
13
+ from .common import VelRepr
14
+
15
+ # =======================
16
+ # Index-related functions
17
+ # =======================
18
+
19
+
20
@functools.partial(jax.jit, static_argnames="link_name")
def name_to_idx(model: js.model.JaxSimModel, *, link_name: str) -> jtp.Int:
    """
    Look up the index of a link given its name.

    Args:
        model: The model whose links are searched.
        link_name: The name of the link to look up.

    Returns:
        The index of the link, or -1 if the model has no link with that name.
    """

    all_link_names = model.kin_dyn_parameters.link_names

    # Unknown names map to the sentinel index -1.
    if link_name not in all_link_names:
        return jnp.array(-1).astype(int)

    # Positions of the entries equal to `link_name`, squeezed to a scalar
    # when the name occurs exactly once.
    matches = np.argwhere(np.array(all_link_names) == link_name)
    return jnp.array(matches).squeeze().astype(int)
42
+
43
+
44
def idx_to_name(model: js.model.JaxSimModel, *, link_index: jtp.IntLike) -> str:
    """
    Look up the name of a link given its index.

    Args:
        model: The model whose links are considered.
        link_index: The index of the link.

    Returns:
        The name stored at position `link_index` of the model's link names.
    """

    # NOTE(review): if `link_names` is a plain Python sequence, negative
    # indices (e.g. the -1 sentinel returned by `name_to_idx`) wrap around
    # instead of failing — confirm whether callers rely on this.
    link_names = model.kin_dyn_parameters.link_names
    selected_name = link_names[link_index]
    return selected_name
57
+
58
+
59
def names_to_idxs(
    model: js.model.JaxSimModel, *, link_names: Sequence[str]
) -> jax.Array:
    """
    Convert a sequence of link names to their corresponding indices.

    Args:
        model: The model to consider.
        link_names:
            The names of the links. Any sequence is accepted, including
            unhashable ones such as lists.

    Returns:
        The indices of the links, with -1 for every name that is not a link
        of the model.
    """

    # Deliberately not wrapped in `jax.jit`: a static `link_names` argument
    # must be hashable, so jitting here would reject lists despite the
    # `Sequence[str]` annotation. Each lookup is still compiled through the
    # jitted `name_to_idx`.
    return jnp.array(
        [name_to_idx(model=model, link_name=name) for name in link_names],
    ).astype(int)
77
+
78
+
79
def idxs_to_names(
    model: js.model.JaxSimModel, *, link_indices: Sequence[jtp.IntLike] | jtp.VectorLike
) -> tuple[str, ...]:
    """
    Map a sequence of link indices to the corresponding link names.

    Args:
        model: The model whose links are considered.
        link_indices: The indices of the links to look up.

    Returns:
        A tuple with the link names, in the same order as `link_indices`.
    """

    resolved_names = [
        idx_to_name(model=model, link_index=index) for index in link_indices
    ]
    return tuple(resolved_names)
94
+
95
+
96
+ # =========
97
+ # Link APIs
98
+ # =========
99
+
100
+
101
@jax.jit
def mass(model: js.model.JaxSimModel, *, link_index: jtp.IntLike) -> jtp.Float:
    """
    Get the mass of a single link.

    Args:
        model: The model whose link parameters are read.
        link_index: The index of the link.

    Returns:
        The mass of the selected link, cast to float.
    """

    link_masses = model.kin_dyn_parameters.link_parameters.mass
    selected_mass = link_masses[link_index]
    return selected_mass.astype(float)
115
+
116
+
117
@jax.jit
def spatial_inertia(
    model: js.model.JaxSimModel, *, link_index: jtp.IntLike
) -> jtp.Matrix:
    """
    Compute the 6D spatial inertia of the link.

    Args:
        model: The model to consider.
        link_index: The index of the link.

    Returns:
        The 6×6 matrix representing the spatial inertia of the link expressed in
        the link frame (body-fixed representation).
    """

    def pick_link(leaf):
        # Every leaf of the stored link parameters is indexed at `link_index`.
        return leaf[link_index]

    single_link_parameters = jax.tree_util.tree_map(
        pick_link, model.kin_dyn_parameters.link_parameters
    )

    return js.kin_dyn_parameters.LinkParameters.spatial_inertia(
        single_link_parameters
    )
138
+
139
+
140
@jax.jit
def transform(
    model: js.model.JaxSimModel,
    data: js.data.JaxSimModelData,
    *,
    link_index: jtp.IntLike,
) -> jtp.Matrix:
    """
    Get the SE(3) transform from the world frame to the frame of one link.

    Args:
        model: The model to consider.
        data: The data of the considered model.
        link_index: The index of the link.

    Returns:
        The 4x4 matrix representing the transform.
    """

    # Run forward kinematics for the whole model, then keep the requested entry.
    W_H_all_links = js.model.forward_kinematics(model=model, data=data)
    W_H_L = W_H_all_links[link_index]
    return W_H_L
160
+
161
+
162
@jax.jit
def com_position(
    model: js.model.JaxSimModel,
    data: js.data.JaxSimModelData,
    *,
    link_index: jtp.IntLike,
    in_link_frame: jtp.BoolLike = True,
) -> jtp.Vector:
    """
    Compute the position of the center of mass (CoM) of a link.

    Args:
        model: The model to consider.
        data: The data of the considered model.
        link_index: The index of the link.
        in_link_frame:
            Whether to return the position in the link frame (True) or in the
            world frame (False).

    Returns:
        The 3D position of the center of mass of the link.
    """

    from jaxsim.math.inertia import Inertia

    # Extract the CoM position (in link coordinates) from the spatial inertia.
    M_L = spatial_inertia(model=model, link_index=link_index)
    _, L_p_CoM_raw, _ = Inertia.to_params(M=M_L)
    L_p_CoM = L_p_CoM_raw.squeeze()

    # Express the same point in the world frame via homogeneous coordinates.
    W_H_L = transform(link_index=link_index, model=model, data=data)
    W_p_CoM_homogeneous = W_H_L @ jnp.hstack([L_p_CoM, 1])
    W_p_CoM = W_p_CoM_homogeneous[0:3].squeeze()

    # `in_link_frame` is not a static argument of the jitted function, so both
    # candidates are computed and the choice is made with `lax.select`.
    return jax.lax.select(
        pred=in_link_frame,
        on_true=L_p_CoM,
        on_false=W_p_CoM,
    )
204
+
205
+
206
@functools.partial(jax.jit, static_argnames=["output_vel_repr"])
def jacobian(
    model: js.model.JaxSimModel,
    data: js.data.JaxSimModelData,
    *,
    link_index: jtp.IntLike,
    output_vel_repr: VelRepr | None = None,
) -> jtp.Matrix:
    """
    Compute the free-floating jacobian of a link.

    Args:
        model: The model to consider.
        data: The data of the considered model.
        link_index: The index of the link.
        output_vel_repr:
            The output velocity representation of the free-floating jacobian.
            Defaults to the active velocity representation of `data`.

    Returns:
        The 6×(6+n) free-floating jacobian of the link.

    Raises:
        ValueError: If the input or output velocity representation is unknown.

    Note:
        The input representation of the free-floating jacobian is the active
        velocity representation.
    """

    if output_vel_repr is None:
        output_vel_repr = data.velocity_representation

    # Doubly-left free-floating jacobian of the full tree, together with the
    # per-link transforms `B_H_Li` used below for the output conversion.
    B_J_full_WX_B, B_H_Li = jaxsim.rbda.jacobian_full_doubly_left(
        model=model,
        joint_positions=data.joint_positions(),
    )

    # Keep only the columns of the joints supporting the link.
    # NOTE(review): 5 ones are stacked with one row of `support_body_array_bool`
    # to obtain 6+n columns, which presumably means the row has one entry per
    # link (base included, its entry completing the 6 base columns) — confirm.
    support_row = model.kin_dyn_parameters.support_body_array_bool[link_index]
    B_J_WL_B = jnp.hstack([jnp.ones(5), support_row]) * B_J_full_WX_B

    # Convert the input representation so that `J_WL_I @ I_ν` holds.
    input_vel_repr = data.velocity_representation

    if input_vel_repr == VelRepr.Body:
        B_J_WL_I = B_J_WL_B

    elif input_vel_repr == VelRepr.Inertial:
        B_X_W = jaxlie.SE3.from_matrix(data.base_transform()).inverse().adjoint()
        B_J_WL_I = B_J_WL_B @ jax.scipy.linalg.block_diag(
            B_X_W, jnp.eye(model.dofs())
        )

    elif input_vel_repr == VelRepr.Mixed:
        W_R_B = data.base_orientation(dcm=True)
        BW_H_B = jnp.eye(4).at[0:3, 0:3].set(W_R_B)
        B_X_BW = jaxlie.SE3.from_matrix(BW_H_B).inverse().adjoint()
        B_J_WL_I = B_J_WL_B @ jax.scipy.linalg.block_diag(
            B_X_BW, jnp.eye(model.dofs())
        )

    else:
        raise ValueError(input_vel_repr)

    B_H_L = B_H_Li[link_index]

    # Convert the output representation so that `O_v_WL_I = O_J_WL_I @ I_ν`.
    if output_vel_repr == VelRepr.Body:
        L_X_B = jaxlie.SE3.from_matrix(B_H_L).inverse().adjoint()
        return L_X_B @ B_J_WL_I

    if output_vel_repr == VelRepr.Inertial:
        W_X_B = jaxlie.SE3.from_matrix(data.base_transform()).adjoint()
        return W_X_B @ B_J_WL_I

    if output_vel_repr == VelRepr.Mixed:
        W_H_L = data.base_transform() @ B_H_L
        # Same orientation as the link frame, origin moved to the world origin.
        LW_H_L = W_H_L.at[0:3, 3].set(jnp.zeros(3))
        LW_H_B = LW_H_L @ jaxsim.math.Transform.inverse(B_H_L)
        LW_X_B = jaxlie.SE3.from_matrix(LW_H_B).adjoint()
        return LW_X_B @ B_J_WL_I

    raise ValueError(output_vel_repr)
296
+
297
+
298
@functools.partial(jax.jit, static_argnames=["output_vel_repr"])
def velocity(
    model: js.model.JaxSimModel,
    data: js.data.JaxSimModelData,
    *,
    link_index: jtp.IntLike,
    output_vel_repr: VelRepr | None = None,
) -> jtp.Vector:
    """
    Compute the 6D velocity of a link.

    Args:
        model: The model to consider.
        data: The data of the considered model.
        link_index: The index of the link.
        output_vel_repr:
            The output velocity representation of the link velocity.
            Defaults to the active velocity representation of `data`.

    Returns:
        The 6D velocity of the link in the specified velocity representation.
    """

    if output_vel_repr is None:
        output_vel_repr = data.velocity_representation

    # Jacobian mapping the generalized velocity (in the active representation
    # of `data`) to the link velocity in the requested output representation.
    O_J_WL_I = jacobian(
        model=model,
        data=data,
        link_index=link_index,
        output_vel_repr=output_vel_repr,
    )

    I_nu = data.generalized_velocity()

    O_v_WL = O_J_WL_I @ I_nu
    return O_v_WL
338
+
339
+
340
@jax.jit
def bias_acceleration(
    model: js.model.JaxSimModel,
    data: js.data.JaxSimModelData,
    *,
    link_index: jtp.IntLike,
) -> jtp.Vector:
    """
    Compute the bias acceleration of a link.

    Args:
        model: The model to consider.
        data: The data of the considered model.
        link_index: The index of the link.

    Returns:
        The 6D bias acceleration of the link.
    """

    # Bias accelerations of every link, in the active representation of `data`.
    all_links_bias_acc = js.model.link_bias_accelerations(model=model, data=data)
    return all_links_bias_acc[link_index]