jaxsim 0.2.dev188__py3-none-any.whl → 0.2.dev364__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (81)
  1. jaxsim/__init__.py +3 -4
  2. jaxsim/_version.py +2 -2
  3. jaxsim/api/__init__.py +3 -1
  4. jaxsim/api/com.py +240 -0
  5. jaxsim/api/common.py +13 -2
  6. jaxsim/api/contact.py +120 -43
  7. jaxsim/api/data.py +112 -71
  8. jaxsim/api/joint.py +77 -36
  9. jaxsim/api/kin_dyn_parameters.py +777 -0
  10. jaxsim/api/link.py +150 -75
  11. jaxsim/api/model.py +542 -269
  12. jaxsim/api/ode.py +88 -72
  13. jaxsim/api/ode_data.py +694 -0
  14. jaxsim/api/references.py +12 -11
  15. jaxsim/integrators/__init__.py +2 -2
  16. jaxsim/integrators/common.py +110 -24
  17. jaxsim/integrators/fixed_step.py +11 -67
  18. jaxsim/integrators/variable_step.py +610 -0
  19. jaxsim/math/__init__.py +11 -0
  20. jaxsim/math/adjoint.py +24 -2
  21. jaxsim/math/joint_model.py +335 -0
  22. jaxsim/math/quaternion.py +44 -3
  23. jaxsim/math/rotation.py +4 -4
  24. jaxsim/math/transform.py +93 -0
  25. jaxsim/parsers/descriptions/collision.py +14 -0
  26. jaxsim/parsers/descriptions/link.py +13 -2
  27. jaxsim/parsers/kinematic_graph.py +5 -0
  28. jaxsim/parsers/rod/utils.py +7 -8
  29. jaxsim/rbda/__init__.py +7 -0
  30. jaxsim/rbda/aba.py +295 -0
  31. jaxsim/rbda/collidable_points.py +142 -0
  32. jaxsim/{physics/algos → rbda}/crba.py +43 -42
  33. jaxsim/rbda/forward_kinematics.py +113 -0
  34. jaxsim/rbda/jacobian.py +201 -0
  35. jaxsim/rbda/rnea.py +237 -0
  36. jaxsim/rbda/soft_contacts.py +296 -0
  37. jaxsim/rbda/utils.py +152 -0
  38. jaxsim/terrain/__init__.py +2 -0
  39. jaxsim/{physics/algos → terrain}/terrain.py +4 -6
  40. jaxsim/utils/__init__.py +1 -4
  41. jaxsim/utils/hashless.py +18 -0
  42. jaxsim/utils/jaxsim_dataclass.py +281 -30
  43. {jaxsim-0.2.dev188.dist-info → jaxsim-0.2.dev364.dist-info}/METADATA +4 -6
  44. jaxsim-0.2.dev364.dist-info/RECORD +64 -0
  45. {jaxsim-0.2.dev188.dist-info → jaxsim-0.2.dev364.dist-info}/WHEEL +1 -1
  46. jaxsim/high_level/__init__.py +0 -2
  47. jaxsim/high_level/common.py +0 -11
  48. jaxsim/high_level/joint.py +0 -148
  49. jaxsim/high_level/link.py +0 -259
  50. jaxsim/high_level/model.py +0 -1686
  51. jaxsim/math/conv.py +0 -114
  52. jaxsim/math/joint.py +0 -102
  53. jaxsim/math/plucker.py +0 -100
  54. jaxsim/physics/__init__.py +0 -12
  55. jaxsim/physics/algos/__init__.py +0 -0
  56. jaxsim/physics/algos/aba.py +0 -254
  57. jaxsim/physics/algos/aba_motors.py +0 -284
  58. jaxsim/physics/algos/forward_kinematics.py +0 -79
  59. jaxsim/physics/algos/jacobian.py +0 -98
  60. jaxsim/physics/algos/rnea.py +0 -180
  61. jaxsim/physics/algos/rnea_motors.py +0 -196
  62. jaxsim/physics/algos/soft_contacts.py +0 -523
  63. jaxsim/physics/algos/utils.py +0 -69
  64. jaxsim/physics/model/__init__.py +0 -0
  65. jaxsim/physics/model/ground_contact.py +0 -55
  66. jaxsim/physics/model/physics_model.py +0 -388
  67. jaxsim/physics/model/physics_model_state.py +0 -283
  68. jaxsim/simulation/__init__.py +0 -4
  69. jaxsim/simulation/integrators.py +0 -393
  70. jaxsim/simulation/ode.py +0 -290
  71. jaxsim/simulation/ode_data.py +0 -96
  72. jaxsim/simulation/ode_integration.py +0 -62
  73. jaxsim/simulation/simulator.py +0 -543
  74. jaxsim/simulation/simulator_callbacks.py +0 -79
  75. jaxsim/simulation/utils.py +0 -15
  76. jaxsim/sixd/__init__.py +0 -2
  77. jaxsim/utils/oop.py +0 -536
  78. jaxsim/utils/vmappable.py +0 -117
  79. jaxsim-0.2.dev188.dist-info/RECORD +0 -81
  80. {jaxsim-0.2.dev188.dist-info → jaxsim-0.2.dev364.dist-info}/LICENSE +0 -0
  81. {jaxsim-0.2.dev188.dist-info → jaxsim-0.2.dev364.dist-info}/top_level.txt +0 -0
jaxsim/api/link.py CHANGED
@@ -4,20 +4,21 @@ from typing import Sequence
4
4
  import jax
5
5
  import jax.numpy as jnp
6
6
  import jaxlie
7
+ import numpy as np
7
8
 
8
- import jaxsim.physics.algos.jacobian
9
+ import jaxsim.api as js
10
+ import jaxsim.rbda
9
11
  import jaxsim.typing as jtp
10
- from jaxsim.high_level.common import VelRepr
11
12
 
12
- from . import data as Data
13
- from . import model as Model
13
+ from .common import VelRepr
14
14
 
15
15
  # =======================
16
16
  # Index-related functions
17
17
  # =======================
18
18
 
19
19
 
20
- def name_to_idx(model: Model.JaxSimModel, *, link_name: str) -> jtp.Int:
20
+ @functools.partial(jax.jit, static_argnames="link_name")
21
+ def name_to_idx(model: js.model.JaxSimModel, *, link_name: str) -> jtp.Int:
21
22
  """
22
23
  Convert the name of a link to its index.
23
24
 
@@ -29,12 +30,18 @@ def name_to_idx(model: Model.JaxSimModel, *, link_name: str) -> jtp.Int:
29
30
  The index of the link.
30
31
  """
31
32
 
32
- return jnp.array(
33
- model.physics_model.description.links_dict[link_name].index, dtype=int
34
- )
33
+ if link_name in model.kin_dyn_parameters.link_names:
34
+ return (
35
+ jnp.array(
36
+ np.argwhere(np.array(model.kin_dyn_parameters.link_names) == link_name)
37
+ )
38
+ .squeeze()
39
+ .astype(int)
40
+ )
41
+ return jnp.array(-1).astype(int)
35
42
 
36
43
 
37
- def idx_to_name(model: Model.JaxSimModel, *, link_index: jtp.IntLike) -> str:
44
+ def idx_to_name(model: js.model.JaxSimModel, *, link_index: jtp.IntLike) -> str:
38
45
  """
39
46
  Convert the index of a link to its name.
40
47
 
@@ -46,11 +53,13 @@ def idx_to_name(model: Model.JaxSimModel, *, link_index: jtp.IntLike) -> str:
46
53
  The name of the link.
47
54
  """
48
55
 
49
- d = {l.index: l.name for l in model.physics_model.description.links_dict.values()}
50
- return d[link_index]
56
+ return model.kin_dyn_parameters.link_names[link_index]
51
57
 
52
58
 
53
- def names_to_idxs(model: Model.JaxSimModel, *, link_names: Sequence[str]) -> jax.Array:
59
+ @functools.partial(jax.jit, static_argnames="link_names")
60
+ def names_to_idxs(
61
+ model: js.model.JaxSimModel, *, link_names: Sequence[str]
62
+ ) -> jax.Array:
54
63
  """
55
64
  Convert a sequence of link names to their corresponding indices.
56
65
 
@@ -63,13 +72,12 @@ def names_to_idxs(model: Model.JaxSimModel, *, link_names: Sequence[str]) -> jax
63
72
  """
64
73
 
65
74
  return jnp.array(
66
- [model.physics_model.description.links_dict[name].index for name in link_names],
67
- dtype=int,
68
- )
75
+ [name_to_idx(model=model, link_name=name) for name in link_names],
76
+ ).astype(int)
69
77
 
70
78
 
71
79
  def idxs_to_names(
72
- model: Model.JaxSimModel, *, link_indices: Sequence[jtp.IntLike] | jtp.VectorLike
80
+ model: js.model.JaxSimModel, *, link_indices: Sequence[jtp.IntLike] | jtp.VectorLike
73
81
  ) -> tuple[str, ...]:
74
82
  """
75
83
  Convert a sequence of link indices to their corresponding names.
@@ -82,8 +90,7 @@ def idxs_to_names(
82
90
  The names of the links.
83
91
  """
84
92
 
85
- d = {l.index: l.name for l in model.physics_model.description.links_dict.values()}
86
- return tuple(d[i] for i in link_indices)
93
+ return tuple(idx_to_name(model=model, link_index=idx) for idx in link_indices)
87
94
 
88
95
 
89
96
  # =========
@@ -91,21 +98,51 @@ def idxs_to_names(
91
98
  # =========
92
99
 
93
100
 
94
- def mass(model: Model.JaxSimModel, *, link_index: jtp.IntLike) -> jtp.Float:
95
- """"""
101
+ @jax.jit
102
+ def mass(model: js.model.JaxSimModel, *, link_index: jtp.IntLike) -> jtp.Float:
103
+ """
104
+ Return the mass of the link.
105
+
106
+ Args:
107
+ model: The model to consider.
108
+ link_index: The index of the link.
96
109
 
97
- return model.physics_model._link_masses[link_index].astype(float)
110
+ Returns:
111
+ The mass of the link.
112
+ """
98
113
 
114
+ return model.kin_dyn_parameters.link_parameters.mass[link_index].astype(float)
99
115
 
100
- def spatial_inertia(model: Model.JaxSimModel, *, link_index: jtp.IntLike) -> jtp.Matrix:
101
- """"""
102
116
 
103
- return model.physics_model._link_spatial_inertias[link_index]
117
+ @jax.jit
118
+ def spatial_inertia(
119
+ model: js.model.JaxSimModel, *, link_index: jtp.IntLike
120
+ ) -> jtp.Matrix:
121
+ """
122
+ Compute the 6D spatial inertia of the link.
123
+
124
+ Args:
125
+ model: The model to consider.
126
+ link_index: The index of the link.
127
+
128
+ Returns:
129
+ The 6×6 matrix representing the spatial inertia of the link expressed in
130
+ the link frame (body-fixed representation).
131
+ """
132
+
133
+ link_parameters = jax.tree_util.tree_map(
134
+ lambda l: l[link_index], model.kin_dyn_parameters.link_parameters
135
+ )
136
+
137
+ return js.kin_dyn_parameters.LinkParameters.spatial_inertia(link_parameters)
104
138
 
105
139
 
106
140
  @jax.jit
107
141
  def transform(
108
- model: Model.JaxSimModel, data: Data.JaxSimModelData, *, link_index: jtp.IntLike
142
+ model: js.model.JaxSimModel,
143
+ data: js.data.JaxSimModelData,
144
+ *,
145
+ link_index: jtp.IntLike,
109
146
  ) -> jtp.Matrix:
110
147
  """
111
148
  Compute the SE(3) transform from the world frame to the link frame.
@@ -119,13 +156,13 @@ def transform(
119
156
  The 4x4 matrix representing the transform.
120
157
  """
121
158
 
122
- return Model.forward_kinematics(model=model, data=data)[link_index]
159
+ return js.model.forward_kinematics(model=model, data=data)[link_index]
123
160
 
124
161
 
125
162
  @jax.jit
126
163
  def com_position(
127
- model: Model.JaxSimModel,
128
- data: Data.JaxSimModelData,
164
+ model: js.model.JaxSimModel,
165
+ data: js.data.JaxSimModelData,
129
166
  *,
130
167
  link_index: jtp.IntLike,
131
168
  in_link_frame: jtp.BoolLike = True,
@@ -168,8 +205,8 @@ def com_position(
168
205
 
169
206
  @functools.partial(jax.jit, static_argnames=["output_vel_repr"])
170
207
  def jacobian(
171
- model: Model.JaxSimModel,
172
- data: Data.JaxSimModelData,
208
+ model: js.model.JaxSimModel,
209
+ data: js.data.JaxSimModelData,
173
210
  *,
174
211
  link_index: jtp.IntLike,
175
212
  output_vel_repr: VelRepr | None = None,
@@ -185,78 +222,116 @@ def jacobian(
185
222
  The output velocity representation of the free-floating jacobian.
186
223
 
187
224
  Returns:
188
- The 6x(6+dofs) free-floating jacobian of the link.
225
+ The 6×(6+n) free-floating jacobian of the link.
189
226
 
190
227
  Note:
191
228
  The input representation of the free-floating jacobian is the active
192
229
  velocity representation.
193
230
  """
194
231
 
195
- if output_vel_repr is None:
196
- output_vel_repr = data.velocity_representation
197
-
198
- # Compute the doubly left-trivialized free-floating jacobian
199
- L_J_WL_B = jaxsim.physics.algos.jacobian.jacobian(
200
- model=model.physics_model,
201
- body_index=link_index,
202
- q=data.joint_positions(),
232
+ output_vel_repr = (
233
+ output_vel_repr if output_vel_repr is not None else data.velocity_representation
203
234
  )
204
235
 
205
- match data.velocity_representation:
236
+ # Compute the doubly-left free-floating full jacobian.
237
+ B_J_full_WX_B, B_H_Li = jaxsim.rbda.jacobian_full_doubly_left(
238
+ model=model,
239
+ joint_positions=data.joint_positions(),
240
+ )
206
241
 
207
- case VelRepr.Body:
208
- L_J_WL_target = L_J_WL_B
242
+ # Compute the actual doubly-left free-floating jacobian of the link.
243
+ κ = model.kin_dyn_parameters.support_body_array_bool[link_index]
244
+ B_J_WL_B = jnp.hstack([jnp.ones(5), κ]) * B_J_full_WX_B
209
245
 
246
+ # Adjust the input representation such that `J_WL_I @ I_ν`.
247
+ match data.velocity_representation:
210
248
  case VelRepr.Inertial:
211
- dofs = model.dofs()
212
249
  W_H_B = data.base_transform()
213
-
214
250
  B_X_W = jaxlie.SE3.from_matrix(W_H_B).inverse().adjoint()
215
- zero_6n = jnp.zeros(shape=(6, dofs))
216
-
217
- B_T_W = jnp.vstack(
218
- [
219
- jnp.block([B_X_W, zero_6n]),
220
- jnp.block([zero_6n.T, jnp.eye(dofs)]),
221
- ]
251
+ B_J_WL_I = B_J_WL_W = B_J_WL_B @ jax.scipy.linalg.block_diag(
252
+ B_X_W, jnp.eye(model.dofs())
222
253
  )
223
254
 
224
- L_J_WL_target = L_J_WL_B @ B_T_W
255
+ case VelRepr.Body:
256
+ B_J_WL_I = B_J_WL_B
225
257
 
226
258
  case VelRepr.Mixed:
227
- dofs = model.dofs()
228
- W_H_B = data.base_transform()
229
- BW_H_B = jnp.array(W_H_B).at[0:3, 3].set(jnp.zeros(3))
230
-
259
+ W_R_B = data.base_orientation(dcm=True)
260
+ BW_H_B = jnp.eye(4).at[0:3, 0:3].set(W_R_B)
231
261
  B_X_BW = jaxlie.SE3.from_matrix(BW_H_B).inverse().adjoint()
232
- zero_6n = jnp.zeros(shape=(6, dofs))
233
-
234
- B_T_BW = jnp.vstack(
235
- [
236
- jnp.block([B_X_BW, zero_6n]),
237
- jnp.block([zero_6n.T, jnp.eye(dofs)]),
238
- ]
262
+ B_J_WL_I = B_J_WL_BW = B_J_WL_B @ jax.scipy.linalg.block_diag(
263
+ B_X_BW, jnp.eye(model.dofs())
239
264
  )
240
265
 
241
- L_J_WL_target = L_J_WL_B @ B_T_BW
242
-
243
266
  case _:
244
267
  raise ValueError(data.velocity_representation)
245
268
 
246
- match output_vel_repr:
247
- case VelRepr.Body:
248
- return L_J_WL_target
269
+ B_H_L = B_H_Li[link_index]
249
270
 
271
+ # Adjust the output representation such that `O_v_WL_I = O_J_WL_I @ I_ν`.
272
+ match output_vel_repr:
250
273
  case VelRepr.Inertial:
251
- W_H_L = transform(model=model, data=data, link_index=link_index)
252
- W_X_L = jaxlie.SE3.from_matrix(W_H_L).adjoint()
253
- return W_X_L @ L_J_WL_target
274
+ W_H_B = data.base_transform()
275
+ W_X_B = jaxlie.SE3.from_matrix(W_H_B).adjoint()
276
+ O_J_WL_I = W_J_WL_I = W_X_B @ B_J_WL_I
277
+
278
+ case VelRepr.Body:
279
+ L_X_B = jaxlie.SE3.from_matrix(B_H_L).inverse().adjoint()
280
+ L_J_WL_I = L_X_B @ B_J_WL_I
281
+ O_J_WL_I = L_J_WL_I
254
282
 
255
283
  case VelRepr.Mixed:
256
- W_H_L = transform(model=model, data=data, link_index=link_index)
257
- LW_H_L = jnp.array(W_H_L).at[0:3, 3].set(jnp.zeros(3))
258
- LW_X_L = jaxlie.SE3.from_matrix(LW_H_L).adjoint()
259
- return LW_X_L @ L_J_WL_target
284
+ W_H_B = data.base_transform()
285
+ W_H_L = W_H_B @ B_H_L
286
+ LW_H_L = W_H_L.at[0:3, 3].set(jnp.zeros(3))
287
+ LW_H_B = LW_H_L @ jaxsim.math.Transform.inverse(B_H_L)
288
+ LW_X_B = jaxlie.SE3.from_matrix(LW_H_B).adjoint()
289
+ LW_J_WL_I = LW_X_B @ B_J_WL_I
290
+ O_J_WL_I = LW_J_WL_I
260
291
 
261
292
  case _:
262
293
  raise ValueError(output_vel_repr)
294
+
295
+ return O_J_WL_I
296
+
297
+
298
+ @functools.partial(jax.jit, static_argnames=["output_vel_repr"])
299
+ def velocity(
300
+ model: js.model.JaxSimModel,
301
+ data: js.data.JaxSimModelData,
302
+ *,
303
+ link_index: jtp.IntLike,
304
+ output_vel_repr: VelRepr | None = None,
305
+ ) -> jtp.Vector:
306
+ """
307
+ Compute the 6D velocity of the link.
308
+
309
+ Args:
310
+ model: The model to consider.
311
+ data: The data of the considered model.
312
+ link_index: The index of the link.
313
+ output_vel_repr:
314
+ The output velocity representation of the link velocity.
315
+
316
+ Returns:
317
+ The 6D velocity of the link in the specified velocity representation.
318
+ """
319
+
320
+ output_vel_repr = (
321
+ output_vel_repr if output_vel_repr is not None else data.velocity_representation
322
+ )
323
+
324
+ # Get the link jacobian having I as input representation (taken from data)
325
+ # and O as output representation, specified by the user (or taken from data).
326
+ O_J_WL_I = jacobian(
327
+ model=model,
328
+ data=data,
329
+ link_index=link_index,
330
+ output_vel_repr=output_vel_repr,
331
+ )
332
+
333
+ # Get the generalized velocity in the input velocity representation.
334
+ I_ν = data.generalized_velocity()
335
+
336
+ # Compute the link velocity in the output velocity representation.
337
+ return O_J_WL_I @ I_ν