jaxsim 0.2.dev188__py3-none-any.whl → 0.6.1.dev2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (109)
  1. jaxsim/__init__.py +73 -22
  2. jaxsim/_version.py +2 -2
  3. jaxsim/api/__init__.py +13 -1
  4. jaxsim/api/com.py +423 -0
  5. jaxsim/api/common.py +48 -19
  6. jaxsim/api/contact.py +604 -52
  7. jaxsim/api/data.py +308 -163
  8. jaxsim/api/frame.py +471 -0
  9. jaxsim/api/joint.py +166 -37
  10. jaxsim/api/kin_dyn_parameters.py +901 -0
  11. jaxsim/api/link.py +277 -78
  12. jaxsim/api/model.py +1572 -362
  13. jaxsim/api/ode.py +324 -129
  14. jaxsim/api/ode_data.py +401 -0
  15. jaxsim/api/references.py +216 -80
  16. jaxsim/exceptions.py +80 -0
  17. jaxsim/integrators/__init__.py +2 -2
  18. jaxsim/integrators/common.py +191 -107
  19. jaxsim/integrators/fixed_step.py +97 -102
  20. jaxsim/integrators/variable_step.py +706 -0
  21. jaxsim/logging.py +1 -2
  22. jaxsim/math/__init__.py +13 -0
  23. jaxsim/math/adjoint.py +57 -22
  24. jaxsim/math/cross.py +16 -7
  25. jaxsim/math/inertia.py +10 -8
  26. jaxsim/math/joint_model.py +289 -0
  27. jaxsim/math/quaternion.py +54 -20
  28. jaxsim/math/rotation.py +27 -21
  29. jaxsim/math/skew.py +16 -5
  30. jaxsim/math/transform.py +102 -0
  31. jaxsim/math/utils.py +31 -0
  32. jaxsim/mujoco/__init__.py +2 -1
  33. jaxsim/mujoco/loaders.py +216 -29
  34. jaxsim/mujoco/model.py +163 -33
  35. jaxsim/mujoco/utils.py +228 -0
  36. jaxsim/mujoco/visualizer.py +107 -22
  37. jaxsim/parsers/__init__.py +0 -1
  38. jaxsim/parsers/descriptions/__init__.py +8 -2
  39. jaxsim/parsers/descriptions/collision.py +87 -16
  40. jaxsim/parsers/descriptions/joint.py +80 -87
  41. jaxsim/parsers/descriptions/link.py +62 -24
  42. jaxsim/parsers/descriptions/model.py +101 -68
  43. jaxsim/parsers/kinematic_graph.py +607 -225
  44. jaxsim/parsers/rod/meshes.py +104 -0
  45. jaxsim/parsers/rod/parser.py +125 -82
  46. jaxsim/parsers/rod/utils.py +127 -82
  47. jaxsim/rbda/__init__.py +11 -0
  48. jaxsim/rbda/aba.py +289 -0
  49. jaxsim/rbda/collidable_points.py +156 -0
  50. jaxsim/rbda/contacts/__init__.py +13 -0
  51. jaxsim/rbda/contacts/common.py +313 -0
  52. jaxsim/rbda/contacts/relaxed_rigid.py +605 -0
  53. jaxsim/rbda/contacts/rigid.py +462 -0
  54. jaxsim/rbda/contacts/soft.py +480 -0
  55. jaxsim/rbda/contacts/visco_elastic.py +1066 -0
  56. jaxsim/rbda/crba.py +167 -0
  57. jaxsim/rbda/forward_kinematics.py +117 -0
  58. jaxsim/rbda/jacobian.py +330 -0
  59. jaxsim/rbda/rnea.py +235 -0
  60. jaxsim/rbda/utils.py +160 -0
  61. jaxsim/terrain/__init__.py +2 -0
  62. jaxsim/terrain/terrain.py +238 -0
  63. jaxsim/typing.py +24 -24
  64. jaxsim/utils/__init__.py +1 -4
  65. jaxsim/utils/jaxsim_dataclass.py +289 -34
  66. jaxsim/utils/tracing.py +5 -11
  67. jaxsim/utils/wrappers.py +159 -0
  68. {jaxsim-0.2.dev188.dist-info → jaxsim-0.6.1.dev2.dist-info}/LICENSE +1 -1
  69. jaxsim-0.6.1.dev2.dist-info/METADATA +465 -0
  70. jaxsim-0.6.1.dev2.dist-info/RECORD +74 -0
  71. {jaxsim-0.2.dev188.dist-info → jaxsim-0.6.1.dev2.dist-info}/WHEEL +1 -1
  72. jaxsim/high_level/__init__.py +0 -2
  73. jaxsim/high_level/common.py +0 -11
  74. jaxsim/high_level/joint.py +0 -148
  75. jaxsim/high_level/link.py +0 -259
  76. jaxsim/high_level/model.py +0 -1686
  77. jaxsim/math/conv.py +0 -114
  78. jaxsim/math/joint.py +0 -102
  79. jaxsim/math/plucker.py +0 -100
  80. jaxsim/physics/__init__.py +0 -12
  81. jaxsim/physics/algos/__init__.py +0 -0
  82. jaxsim/physics/algos/aba.py +0 -254
  83. jaxsim/physics/algos/aba_motors.py +0 -284
  84. jaxsim/physics/algos/crba.py +0 -154
  85. jaxsim/physics/algos/forward_kinematics.py +0 -79
  86. jaxsim/physics/algos/jacobian.py +0 -98
  87. jaxsim/physics/algos/rnea.py +0 -180
  88. jaxsim/physics/algos/rnea_motors.py +0 -196
  89. jaxsim/physics/algos/soft_contacts.py +0 -523
  90. jaxsim/physics/algos/terrain.py +0 -80
  91. jaxsim/physics/algos/utils.py +0 -69
  92. jaxsim/physics/model/__init__.py +0 -0
  93. jaxsim/physics/model/ground_contact.py +0 -55
  94. jaxsim/physics/model/physics_model.py +0 -388
  95. jaxsim/physics/model/physics_model_state.py +0 -283
  96. jaxsim/simulation/__init__.py +0 -4
  97. jaxsim/simulation/integrators.py +0 -393
  98. jaxsim/simulation/ode.py +0 -290
  99. jaxsim/simulation/ode_data.py +0 -96
  100. jaxsim/simulation/ode_integration.py +0 -62
  101. jaxsim/simulation/simulator.py +0 -543
  102. jaxsim/simulation/simulator_callbacks.py +0 -79
  103. jaxsim/simulation/utils.py +0 -15
  104. jaxsim/sixd/__init__.py +0 -2
  105. jaxsim/utils/oop.py +0 -536
  106. jaxsim/utils/vmappable.py +0 -117
  107. jaxsim-0.2.dev188.dist-info/METADATA +0 -184
  108. jaxsim-0.2.dev188.dist-info/RECORD +0 -81
  109. {jaxsim-0.2.dev188.dist-info → jaxsim-0.6.1.dev2.dist-info}/top_level.txt +0 -0
jaxsim/api/link.py CHANGED
@@ -1,23 +1,26 @@
1
1
  import functools
2
- from typing import Sequence
2
+ from collections.abc import Sequence
3
3
 
4
4
  import jax
5
5
  import jax.numpy as jnp
6
- import jaxlie
6
+ import jax.scipy.linalg
7
+ import numpy as np
7
8
 
8
- import jaxsim.physics.algos.jacobian
9
+ import jaxsim.api as js
10
+ import jaxsim.rbda
9
11
  import jaxsim.typing as jtp
10
- from jaxsim.high_level.common import VelRepr
12
+ from jaxsim import exceptions
13
+ from jaxsim.math import Adjoint
11
14
 
12
- from . import data as Data
13
- from . import model as Model
15
+ from .common import VelRepr
14
16
 
15
17
  # =======================
16
18
  # Index-related functions
17
19
  # =======================
18
20
 
19
21
 
20
- def name_to_idx(model: Model.JaxSimModel, *, link_name: str) -> jtp.Int:
22
+ @functools.partial(jax.jit, static_argnames="link_name")
23
+ def name_to_idx(model: js.model.JaxSimModel, *, link_name: str) -> jtp.Int:
21
24
  """
22
25
  Convert the name of a link to its index.
23
26
 
@@ -29,12 +32,17 @@ def name_to_idx(model: Model.JaxSimModel, *, link_name: str) -> jtp.Int:
29
32
  The index of the link.
30
33
  """
31
34
 
32
- return jnp.array(
33
- model.physics_model.description.links_dict[link_name].index, dtype=int
35
+ if link_name not in model.link_names():
36
+ raise ValueError(f"Link '{link_name}' not found in the model.")
37
+
38
+ return (
39
+ jnp.array(model.kin_dyn_parameters.link_names.index(link_name))
40
+ .astype(int)
41
+ .squeeze()
34
42
  )
35
43
 
36
44
 
37
- def idx_to_name(model: Model.JaxSimModel, *, link_index: jtp.IntLike) -> str:
45
+ def idx_to_name(model: js.model.JaxSimModel, *, link_index: jtp.IntLike) -> str:
38
46
  """
39
47
  Convert the index of a link to its name.
40
48
 
@@ -46,11 +54,19 @@ def idx_to_name(model: Model.JaxSimModel, *, link_index: jtp.IntLike) -> str:
46
54
  The name of the link.
47
55
  """
48
56
 
49
- d = {l.index: l.name for l in model.physics_model.description.links_dict.values()}
50
- return d[link_index]
57
+ exceptions.raise_value_error_if(
58
+ condition=link_index < 0,
59
+ msg="Invalid link index '{idx}'",
60
+ idx=link_index,
61
+ )
62
+
63
+ return model.kin_dyn_parameters.link_names[link_index]
51
64
 
52
65
 
53
- def names_to_idxs(model: Model.JaxSimModel, *, link_names: Sequence[str]) -> jax.Array:
66
+ @functools.partial(jax.jit, static_argnames="link_names")
67
+ def names_to_idxs(
68
+ model: js.model.JaxSimModel, *, link_names: Sequence[str]
69
+ ) -> jax.Array:
54
70
  """
55
71
  Convert a sequence of link names to their corresponding indices.
56
72
 
@@ -63,13 +79,12 @@ def names_to_idxs(model: Model.JaxSimModel, *, link_names: Sequence[str]) -> jax
63
79
  """
64
80
 
65
81
  return jnp.array(
66
- [model.physics_model.description.links_dict[name].index for name in link_names],
67
- dtype=int,
68
- )
82
+ [name_to_idx(model=model, link_name=name) for name in link_names],
83
+ ).astype(int)
69
84
 
70
85
 
71
86
  def idxs_to_names(
72
- model: Model.JaxSimModel, *, link_indices: Sequence[jtp.IntLike] | jtp.VectorLike
87
+ model: js.model.JaxSimModel, *, link_indices: Sequence[jtp.IntLike] | jtp.VectorLike
73
88
  ) -> tuple[str, ...]:
74
89
  """
75
90
  Convert a sequence of link indices to their corresponding names.
@@ -82,8 +97,7 @@ def idxs_to_names(
82
97
  The names of the links.
83
98
  """
84
99
 
85
- d = {l.index: l.name for l in model.physics_model.description.links_dict.values()}
86
- return tuple(d[i] for i in link_indices)
100
+ return tuple(np.array(model.kin_dyn_parameters.link_names)[list(link_indices)])
87
101
 
88
102
 
89
103
  # =========
@@ -91,21 +105,67 @@ def idxs_to_names(
91
105
  # =========
92
106
 
93
107
 
94
- def mass(model: Model.JaxSimModel, *, link_index: jtp.IntLike) -> jtp.Float:
95
- """"""
108
+ @jax.jit
109
+ def mass(model: js.model.JaxSimModel, *, link_index: jtp.IntLike) -> jtp.Float:
110
+ """
111
+ Return the mass of the link.
96
112
 
97
- return model.physics_model._link_masses[link_index].astype(float)
113
+ Args:
114
+ model: The model to consider.
115
+ link_index: The index of the link.
98
116
 
117
+ Returns:
118
+ The mass of the link.
119
+ """
99
120
 
100
- def spatial_inertia(model: Model.JaxSimModel, *, link_index: jtp.IntLike) -> jtp.Matrix:
101
- """"""
121
+ exceptions.raise_value_error_if(
122
+ condition=jnp.array(
123
+ [link_index < 0, link_index >= model.number_of_links()]
124
+ ).any(),
125
+ msg="Invalid link index '{idx}'",
126
+ idx=link_index,
127
+ )
102
128
 
103
- return model.physics_model._link_spatial_inertias[link_index]
129
+ return model.kin_dyn_parameters.link_parameters.mass[link_index].astype(float)
130
+
131
+
132
+ @jax.jit
133
+ def spatial_inertia(
134
+ model: js.model.JaxSimModel, *, link_index: jtp.IntLike
135
+ ) -> jtp.Matrix:
136
+ r"""
137
+ Compute the 6D spatial inertial of the link.
138
+
139
+ Args:
140
+ model: The model to consider.
141
+ link_index: The index of the link.
142
+
143
+ Returns:
144
+ The :math:`6 \times 6` matrix representing the spatial inertia of the link expressed in
145
+ the link frame (body-fixed representation).
146
+ """
147
+
148
+ exceptions.raise_value_error_if(
149
+ condition=jnp.array(
150
+ [link_index < 0, link_index >= model.number_of_links()]
151
+ ).any(),
152
+ msg="Invalid link index '{idx}'",
153
+ idx=link_index,
154
+ )
155
+
156
+ link_parameters = jax.tree.map(
157
+ lambda l: l[link_index], model.kin_dyn_parameters.link_parameters
158
+ )
159
+
160
+ return js.kin_dyn_parameters.LinkParameters.spatial_inertia(link_parameters)
104
161
 
105
162
 
106
163
  @jax.jit
107
164
  def transform(
108
- model: Model.JaxSimModel, data: Data.JaxSimModelData, *, link_index: jtp.IntLike
165
+ model: js.model.JaxSimModel,
166
+ data: js.data.JaxSimModelData,
167
+ *,
168
+ link_index: jtp.IntLike,
109
169
  ) -> jtp.Matrix:
110
170
  """
111
171
  Compute the SE(3) transform from the world frame to the link frame.
@@ -119,13 +179,21 @@ def transform(
119
179
  The 4x4 matrix representing the transform.
120
180
  """
121
181
 
122
- return Model.forward_kinematics(model=model, data=data)[link_index]
182
+ exceptions.raise_value_error_if(
183
+ condition=jnp.array(
184
+ [link_index < 0, link_index >= model.number_of_links()]
185
+ ).any(),
186
+ msg="Invalid link index '{idx}'",
187
+ idx=link_index,
188
+ )
189
+
190
+ return js.model.forward_kinematics(model=model, data=data)[link_index]
123
191
 
124
192
 
125
193
  @jax.jit
126
194
  def com_position(
127
- model: Model.JaxSimModel,
128
- data: Data.JaxSimModelData,
195
+ model: js.model.JaxSimModel,
196
+ data: js.data.JaxSimModelData,
129
197
  *,
130
198
  link_index: jtp.IntLike,
131
199
  in_link_frame: jtp.BoolLike = True,
@@ -168,13 +236,13 @@ def com_position(
168
236
 
169
237
  @functools.partial(jax.jit, static_argnames=["output_vel_repr"])
170
238
  def jacobian(
171
- model: Model.JaxSimModel,
172
- data: Data.JaxSimModelData,
239
+ model: js.model.JaxSimModel,
240
+ data: js.data.JaxSimModelData,
173
241
  *,
174
242
  link_index: jtp.IntLike,
175
243
  output_vel_repr: VelRepr | None = None,
176
244
  ) -> jtp.Matrix:
177
- """
245
+ r"""
178
246
  Compute the free-floating jacobian of the link.
179
247
 
180
248
  Args:
@@ -185,78 +253,209 @@ def jacobian(
185
253
  The output velocity representation of the free-floating jacobian.
186
254
 
187
255
  Returns:
188
- The 6x(6+dofs) free-floating jacobian of the link.
256
+ The :math:`6 \times (6+n)` free-floating jacobian of the link.
189
257
 
190
258
  Note:
191
259
  The input representation of the free-floating jacobian is the active
192
260
  velocity representation.
193
261
  """
194
262
 
195
- if output_vel_repr is None:
196
- output_vel_repr = data.velocity_representation
263
+ exceptions.raise_value_error_if(
264
+ condition=jnp.array(
265
+ [link_index < 0, link_index >= model.number_of_links()]
266
+ ).any(),
267
+ msg="Invalid link index '{idx}'",
268
+ idx=link_index,
269
+ )
197
270
 
198
- # Compute the doubly left-trivialized free-floating jacobian
199
- L_J_WL_B = jaxsim.physics.algos.jacobian.jacobian(
200
- model=model.physics_model,
201
- body_index=link_index,
202
- q=data.joint_positions(),
271
+ output_vel_repr = (
272
+ output_vel_repr if output_vel_repr is not None else data.velocity_representation
203
273
  )
204
274
 
205
- match data.velocity_representation:
275
+ # Compute the doubly-left free-floating full jacobian.
276
+ B_J_full_WX_B, B_H_Li = jaxsim.rbda.jacobian_full_doubly_left(
277
+ model=model,
278
+ joint_positions=data.joint_positions(),
279
+ )
206
280
 
207
- case VelRepr.Body:
208
- L_J_WL_target = L_J_WL_B
281
+ # Compute the actual doubly-left free-floating jacobian of the link.
282
+ κb = model.kin_dyn_parameters.support_body_array_bool[link_index]
283
+ B_J_WL_B = jnp.hstack([jnp.ones(5), κb]) * B_J_full_WX_B
209
284
 
285
+ # Adjust the input representation such that `J_WL_I @ I_ν`.
286
+ match data.velocity_representation:
210
287
  case VelRepr.Inertial:
211
- dofs = model.dofs()
212
288
  W_H_B = data.base_transform()
213
-
214
- B_X_W = jaxlie.SE3.from_matrix(W_H_B).inverse().adjoint()
215
- zero_6n = jnp.zeros(shape=(6, dofs))
216
-
217
- B_T_W = jnp.vstack(
218
- [
219
- jnp.block([B_X_W, zero_6n]),
220
- jnp.block([zero_6n.T, jnp.eye(dofs)]),
221
- ]
289
+ B_X_W = Adjoint.from_transform(transform=W_H_B, inverse=True)
290
+ B_J_WL_I = B_J_WL_W = B_J_WL_B @ jax.scipy.linalg.block_diag( # noqa: F841
291
+ B_X_W, jnp.eye(model.dofs())
222
292
  )
223
293
 
224
- L_J_WL_target = L_J_WL_B @ B_T_W
294
+ case VelRepr.Body:
295
+ B_J_WL_I = B_J_WL_B
225
296
 
226
297
  case VelRepr.Mixed:
227
- dofs = model.dofs()
228
- W_H_B = data.base_transform()
229
- BW_H_B = jnp.array(W_H_B).at[0:3, 3].set(jnp.zeros(3))
230
-
231
- B_X_BW = jaxlie.SE3.from_matrix(BW_H_B).inverse().adjoint()
232
- zero_6n = jnp.zeros(shape=(6, dofs))
233
-
234
- B_T_BW = jnp.vstack(
235
- [
236
- jnp.block([B_X_BW, zero_6n]),
237
- jnp.block([zero_6n.T, jnp.eye(dofs)]),
238
- ]
298
+ W_R_B = data.base_orientation(dcm=True)
299
+ BW_H_B = jnp.eye(4).at[0:3, 0:3].set(W_R_B)
300
+ B_X_BW = Adjoint.from_transform(transform=BW_H_B, inverse=True)
301
+ B_J_WL_I = B_J_WL_BW = B_J_WL_B @ jax.scipy.linalg.block_diag( # noqa: F841
302
+ B_X_BW, jnp.eye(model.dofs())
239
303
  )
240
304
 
241
- L_J_WL_target = L_J_WL_B @ B_T_BW
242
-
243
305
  case _:
244
306
  raise ValueError(data.velocity_representation)
245
307
 
246
- match output_vel_repr:
247
- case VelRepr.Body:
248
- return L_J_WL_target
308
+ B_H_L = B_H_Li[link_index]
249
309
 
310
+ # Adjust the output representation such that `O_v_WL_I = O_J_WL_I @ I_ν`.
311
+ match output_vel_repr:
250
312
  case VelRepr.Inertial:
251
- W_H_L = transform(model=model, data=data, link_index=link_index)
252
- W_X_L = jaxlie.SE3.from_matrix(W_H_L).adjoint()
253
- return W_X_L @ L_J_WL_target
313
+ W_H_B = data.base_transform()
314
+ W_X_B = Adjoint.from_transform(transform=W_H_B)
315
+ O_J_WL_I = W_J_WL_I = W_X_B @ B_J_WL_I # noqa: F841
316
+
317
+ case VelRepr.Body:
318
+ L_X_B = Adjoint.from_transform(transform=B_H_L, inverse=True)
319
+ L_J_WL_I = L_X_B @ B_J_WL_I
320
+ O_J_WL_I = L_J_WL_I
254
321
 
255
322
  case VelRepr.Mixed:
256
- W_H_L = transform(model=model, data=data, link_index=link_index)
257
- LW_H_L = jnp.array(W_H_L).at[0:3, 3].set(jnp.zeros(3))
258
- LW_X_L = jaxlie.SE3.from_matrix(LW_H_L).adjoint()
259
- return LW_X_L @ L_J_WL_target
323
+ W_H_B = data.base_transform()
324
+ W_H_L = W_H_B @ B_H_L
325
+ LW_H_L = W_H_L.at[0:3, 3].set(jnp.zeros(3))
326
+ LW_H_B = LW_H_L @ jaxsim.math.Transform.inverse(B_H_L)
327
+ LW_X_B = Adjoint.from_transform(transform=LW_H_B)
328
+ LW_J_WL_I = LW_X_B @ B_J_WL_I
329
+ O_J_WL_I = LW_J_WL_I
260
330
 
261
331
  case _:
262
332
  raise ValueError(output_vel_repr)
333
+
334
+ return O_J_WL_I
335
+
336
+
337
+ @functools.partial(jax.jit, static_argnames=["output_vel_repr"])
338
+ def velocity(
339
+ model: js.model.JaxSimModel,
340
+ data: js.data.JaxSimModelData,
341
+ *,
342
+ link_index: jtp.IntLike,
343
+ output_vel_repr: VelRepr | None = None,
344
+ ) -> jtp.Vector:
345
+ """
346
+ Compute the 6D velocity of the link.
347
+
348
+ Args:
349
+ model: The model to consider.
350
+ data: The data of the considered model.
351
+ link_index: The index of the link.
352
+ output_vel_repr:
353
+ The output velocity representation of the link velocity.
354
+
355
+ Returns:
356
+ The 6D velocity of the link in the specified velocity representation.
357
+ """
358
+
359
+ exceptions.raise_value_error_if(
360
+ condition=jnp.array(
361
+ [link_index < 0, link_index >= model.number_of_links()]
362
+ ).any(),
363
+ msg="Invalid link index '{idx}'",
364
+ idx=link_index,
365
+ )
366
+
367
+ output_vel_repr = (
368
+ output_vel_repr if output_vel_repr is not None else data.velocity_representation
369
+ )
370
+
371
+ # Get the link jacobian having I as input representation (taken from data)
372
+ # and O as output representation, specified by the user (or taken from data).
373
+ O_J_WL_I = jacobian(
374
+ model=model,
375
+ data=data,
376
+ link_index=link_index,
377
+ output_vel_repr=output_vel_repr,
378
+ )
379
+
380
+ # Get the generalized velocity in the input velocity representation.
381
+ I_ν = data.generalized_velocity()
382
+
383
+ # Compute the link velocity in the output velocity representation.
384
+ return O_J_WL_I @ I_ν
385
+
386
+
387
+ @functools.partial(jax.jit, static_argnames=["output_vel_repr"])
388
+ def jacobian_derivative(
389
+ model: js.model.JaxSimModel,
390
+ data: js.data.JaxSimModelData,
391
+ *,
392
+ link_index: jtp.IntLike,
393
+ output_vel_repr: VelRepr | None = None,
394
+ ) -> jtp.Matrix:
395
+ r"""
396
+ Compute the derivative of the free-floating jacobian of the link.
397
+
398
+ Args:
399
+ model: The model to consider.
400
+ data: The data of the considered model.
401
+ link_index: The index of the link.
402
+ output_vel_repr:
403
+ The output velocity representation of the free-floating jacobian derivative.
404
+
405
+ Returns:
406
+ The derivative of the :math:`6 \times (6+n)` free-floating jacobian of the link.
407
+
408
+ Note:
409
+ The input representation of the free-floating jacobian derivative is the active
410
+ velocity representation.
411
+ """
412
+
413
+ exceptions.raise_value_error_if(
414
+ condition=jnp.array(
415
+ [link_index < 0, link_index >= model.number_of_links()]
416
+ ).any(),
417
+ msg="Invalid link index '{idx}'",
418
+ idx=link_index,
419
+ )
420
+
421
+ output_vel_repr = (
422
+ output_vel_repr if output_vel_repr is not None else data.velocity_representation
423
+ )
424
+
425
+ O_J̇_WL_I = js.model.generalized_free_floating_jacobian_derivative(
426
+ model=model, data=data, output_vel_repr=output_vel_repr
427
+ )[link_index]
428
+
429
+ return O_J̇_WL_I
430
+
431
+
432
+ @jax.jit
433
+ def bias_acceleration(
434
+ model: js.model.JaxSimModel,
435
+ data: js.data.JaxSimModelData,
436
+ *,
437
+ link_index: jtp.IntLike,
438
+ ) -> jtp.Vector:
439
+ """
440
+ Compute the bias acceleration of the link.
441
+
442
+ Args:
443
+ model: The model to consider.
444
+ data: The data of the considered model.
445
+ link_index: The index of the link.
446
+
447
+ Returns:
448
+ The 6D bias acceleration of the link.
449
+ """
450
+
451
+ exceptions.raise_value_error_if(
452
+ condition=jnp.array(
453
+ [link_index < 0, link_index >= model.number_of_links()]
454
+ ).any(),
455
+ msg="Invalid link index '{idx}'",
456
+ idx=link_index,
457
+ )
458
+
459
+ # Compute the bias acceleration of all links in the active representation.
460
+ O_v̇_WL = js.model.link_bias_accelerations(model=model, data=data)[link_index]
461
+ return O_v̇_WL