jaxsim 0.2.dev191__py3-none-any.whl → 0.6.1.dev5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (109)
  1. jaxsim/__init__.py +73 -22
  2. jaxsim/_version.py +2 -2
  3. jaxsim/api/__init__.py +13 -1
  4. jaxsim/api/com.py +423 -0
  5. jaxsim/api/common.py +48 -19
  6. jaxsim/api/contact.py +604 -52
  7. jaxsim/api/data.py +308 -163
  8. jaxsim/api/frame.py +471 -0
  9. jaxsim/api/joint.py +166 -37
  10. jaxsim/api/kin_dyn_parameters.py +901 -0
  11. jaxsim/api/link.py +277 -78
  12. jaxsim/api/model.py +1572 -362
  13. jaxsim/api/ode.py +324 -133
  14. jaxsim/api/ode_data.py +401 -0
  15. jaxsim/api/references.py +216 -80
  16. jaxsim/exceptions.py +80 -0
  17. jaxsim/integrators/__init__.py +2 -2
  18. jaxsim/integrators/common.py +191 -107
  19. jaxsim/integrators/fixed_step.py +97 -102
  20. jaxsim/integrators/variable_step.py +706 -0
  21. jaxsim/logging.py +1 -2
  22. jaxsim/math/__init__.py +13 -0
  23. jaxsim/math/adjoint.py +64 -30
  24. jaxsim/math/cross.py +18 -9
  25. jaxsim/math/inertia.py +11 -9
  26. jaxsim/math/joint_model.py +289 -0
  27. jaxsim/math/quaternion.py +59 -25
  28. jaxsim/math/rotation.py +30 -24
  29. jaxsim/math/skew.py +18 -7
  30. jaxsim/math/transform.py +102 -0
  31. jaxsim/math/utils.py +31 -0
  32. jaxsim/mujoco/__init__.py +2 -1
  33. jaxsim/mujoco/loaders.py +216 -29
  34. jaxsim/mujoco/model.py +163 -33
  35. jaxsim/mujoco/utils.py +228 -0
  36. jaxsim/mujoco/visualizer.py +107 -22
  37. jaxsim/parsers/__init__.py +0 -1
  38. jaxsim/parsers/descriptions/__init__.py +8 -2
  39. jaxsim/parsers/descriptions/collision.py +83 -26
  40. jaxsim/parsers/descriptions/joint.py +80 -87
  41. jaxsim/parsers/descriptions/link.py +58 -31
  42. jaxsim/parsers/descriptions/model.py +101 -68
  43. jaxsim/parsers/kinematic_graph.py +606 -229
  44. jaxsim/parsers/rod/meshes.py +104 -0
  45. jaxsim/parsers/rod/parser.py +125 -82
  46. jaxsim/parsers/rod/utils.py +127 -82
  47. jaxsim/rbda/__init__.py +11 -0
  48. jaxsim/rbda/aba.py +289 -0
  49. jaxsim/rbda/collidable_points.py +156 -0
  50. jaxsim/rbda/contacts/__init__.py +13 -0
  51. jaxsim/rbda/contacts/common.py +313 -0
  52. jaxsim/rbda/contacts/relaxed_rigid.py +605 -0
  53. jaxsim/rbda/contacts/rigid.py +462 -0
  54. jaxsim/rbda/contacts/soft.py +480 -0
  55. jaxsim/rbda/contacts/visco_elastic.py +1066 -0
  56. jaxsim/rbda/crba.py +167 -0
  57. jaxsim/rbda/forward_kinematics.py +117 -0
  58. jaxsim/rbda/jacobian.py +330 -0
  59. jaxsim/rbda/rnea.py +235 -0
  60. jaxsim/rbda/utils.py +160 -0
  61. jaxsim/terrain/__init__.py +2 -0
  62. jaxsim/terrain/terrain.py +238 -0
  63. jaxsim/typing.py +24 -24
  64. jaxsim/utils/__init__.py +1 -4
  65. jaxsim/utils/jaxsim_dataclass.py +289 -34
  66. jaxsim/utils/tracing.py +5 -11
  67. jaxsim/utils/wrappers.py +159 -0
  68. {jaxsim-0.2.dev191.dist-info → jaxsim-0.6.1.dev5.dist-info}/LICENSE +1 -1
  69. jaxsim-0.6.1.dev5.dist-info/METADATA +465 -0
  70. jaxsim-0.6.1.dev5.dist-info/RECORD +74 -0
  71. {jaxsim-0.2.dev191.dist-info → jaxsim-0.6.1.dev5.dist-info}/WHEEL +1 -1
  72. jaxsim/high_level/__init__.py +0 -2
  73. jaxsim/high_level/common.py +0 -11
  74. jaxsim/high_level/joint.py +0 -148
  75. jaxsim/high_level/link.py +0 -259
  76. jaxsim/high_level/model.py +0 -1686
  77. jaxsim/math/conv.py +0 -114
  78. jaxsim/math/joint.py +0 -102
  79. jaxsim/math/plucker.py +0 -100
  80. jaxsim/physics/__init__.py +0 -12
  81. jaxsim/physics/algos/__init__.py +0 -0
  82. jaxsim/physics/algos/aba.py +0 -254
  83. jaxsim/physics/algos/aba_motors.py +0 -284
  84. jaxsim/physics/algos/crba.py +0 -154
  85. jaxsim/physics/algos/forward_kinematics.py +0 -79
  86. jaxsim/physics/algos/jacobian.py +0 -98
  87. jaxsim/physics/algos/rnea.py +0 -180
  88. jaxsim/physics/algos/rnea_motors.py +0 -196
  89. jaxsim/physics/algos/soft_contacts.py +0 -523
  90. jaxsim/physics/algos/terrain.py +0 -78
  91. jaxsim/physics/algos/utils.py +0 -69
  92. jaxsim/physics/model/__init__.py +0 -0
  93. jaxsim/physics/model/ground_contact.py +0 -53
  94. jaxsim/physics/model/physics_model.py +0 -388
  95. jaxsim/physics/model/physics_model_state.py +0 -283
  96. jaxsim/simulation/__init__.py +0 -4
  97. jaxsim/simulation/integrators.py +0 -393
  98. jaxsim/simulation/ode.py +0 -290
  99. jaxsim/simulation/ode_data.py +0 -96
  100. jaxsim/simulation/ode_integration.py +0 -62
  101. jaxsim/simulation/simulator.py +0 -543
  102. jaxsim/simulation/simulator_callbacks.py +0 -79
  103. jaxsim/simulation/utils.py +0 -15
  104. jaxsim/sixd/__init__.py +0 -2
  105. jaxsim/utils/oop.py +0 -536
  106. jaxsim/utils/vmappable.py +0 -117
  107. jaxsim-0.2.dev191.dist-info/METADATA +0 -184
  108. jaxsim-0.2.dev191.dist-info/RECORD +0 -81
  109. {jaxsim-0.2.dev191.dist-info → jaxsim-0.6.1.dev5.dist-info}/top_level.txt +0 -0
jaxsim/api/frame.py ADDED
@@ -0,0 +1,471 @@
1
+ import functools
2
+ from collections.abc import Sequence
3
+
4
+ import jax
5
+ import jax.numpy as jnp
6
+
7
+ import jaxsim.api as js
8
+ import jaxsim.typing as jtp
9
+ from jaxsim import exceptions
10
+ from jaxsim.math import Adjoint, Cross, Transform
11
+
12
+ from .common import VelRepr
13
+
14
+ # =======================
15
+ # Index-related functions
16
+ # =======================
17
+
18
+
19
@jax.jit
@js.common.named_scope
def idx_of_parent_link(
    model: js.model.JaxSimModel, *, frame_index: jtp.IntLike
) -> jtp.Int:
    """
    Return the index of the link that rigidly supports the given frame.

    Frame indices are numbered after the link indices: valid frame indices
    lie in ``[number_of_links, number_of_links + number_of_frames)``.

    Args:
        model: The model to consider.
        frame_index: The index of the frame.

    Returns:
        The index of the frame's parent link.
    """

    num_links = model.number_of_links()
    num_frames = len(model.frame_names())

    # Reject indices that do not address one of the model's frames.
    out_of_range = jnp.array(
        [frame_index < num_links, frame_index >= num_links + num_frames]
    ).any()
    exceptions.raise_value_error_if(
        condition=out_of_range,
        msg="Invalid frame index '{idx}'",
        idx=frame_index,
    )

    # Parent link of every frame, stored in the same order as the frames.
    parent_links = jnp.array(model.kin_dyn_parameters.frame_parameters.body)

    # Shift the frame index back into the frame-local range before lookup.
    return parent_links[frame_index - num_links]
47
+
48
+
49
@functools.partial(jax.jit, static_argnames="frame_name")
@js.common.named_scope
def name_to_idx(model: js.model.JaxSimModel, *, frame_name: str) -> jtp.Int:
    """
    Look up the index of a frame from its name.

    Args:
        model: The model to consider.
        frame_name: The name of the frame.

    Returns:
        The index of the frame. Frame indices are offset by the number of
        links, so the first frame has index ``number_of_links``.

    Raises:
        ValueError: If the model has no frame with the given name.
    """

    if frame_name not in model.frame_names():
        raise ValueError(f"Frame '{frame_name}' not found in the model.")

    link_offset = model.number_of_links()
    position = model.kin_dyn_parameters.frame_parameters.name.index(frame_name)

    return jnp.array(link_offset + position).astype(int).squeeze()
74
+
75
+
76
def idx_to_name(model: js.model.JaxSimModel, *, frame_index: jtp.IntLike) -> str:
    """
    Look up the name of a frame from its index.

    Args:
        model: The model to consider.
        frame_index:
            The index of the frame, in the range
            ``[number_of_links, number_of_links + number_of_frames)``.

    Returns:
        The name of the frame.
    """

    first_frame_idx = model.number_of_links()
    end_frame_idx = first_frame_idx + len(model.frame_names())

    exceptions.raise_value_error_if(
        condition=jnp.array(
            [frame_index < first_frame_idx, frame_index >= end_frame_idx]
        ).any(),
        msg="Invalid frame index '{idx}'",
        idx=frame_index,
    )

    # Frame names are stored in frame-local order, starting from zero.
    frame_names = model.kin_dyn_parameters.frame_parameters.name
    return frame_names[frame_index - first_frame_idx]
100
+
101
+
102
@js.common.named_scope
def names_to_idxs(
    model: js.model.JaxSimModel, *, frame_names: Sequence[str]
) -> jax.Array:
    """
    Convert a sequence of frame names to their corresponding indices.

    Args:
        model: The model to consider.
        frame_names:
            The names of the frames. Any sequence is accepted (e.g. a list or
            a tuple); it is converted to a tuple before reaching the jitted
            implementation.

    Returns:
        The indices of the frames.

    Raises:
        ValueError: If one of the names is not a frame of the model.
    """

    # `frame_names` is a static argument of the jitted implementation, and
    # jax.jit requires static arguments to be hashable. A plain list (valid
    # for the `Sequence[str]` annotation) would be rejected, hence the
    # conversion to a tuple here.
    return _names_to_idxs(model=model, frame_names=tuple(frame_names))


@functools.partial(jax.jit, static_argnames=["frame_names"])
def _names_to_idxs(
    model: js.model.JaxSimModel, *, frame_names: tuple[str, ...]
) -> jax.Array:
    """Jitted body of `names_to_idxs`, operating on a hashable tuple of names."""

    return jnp.array(
        [name_to_idx(model=model, frame_name=name) for name in frame_names]
    ).astype(int)
121
+
122
+
123
def idxs_to_names(
    model: js.model.JaxSimModel, *, frame_indices: Sequence[jtp.IntLike]
) -> tuple[str, ...]:
    """
    Map each frame index of a sequence to the corresponding frame name.

    Args:
        model: The model to consider.
        frame_indices: The indices of the frames.

    Returns:
        The names of the frames, in the same order as the given indices.
    """

    names = [idx_to_name(model=model, frame_index=index) for index in frame_indices]
    return tuple(names)
138
+
139
+
140
+ # ==========
141
+ # Frame APIs
142
+ # ==========
143
+
144
+
145
@jax.jit
@js.common.named_scope
def transform(
    model: js.model.JaxSimModel,
    data: js.data.JaxSimModelData,
    *,
    frame_index: jtp.IntLike,
) -> jtp.Matrix:
    """
    Compute the SE(3) transform W_H_F from the world frame to the given frame.

    Args:
        model: The model to consider.
        data: The data of the considered model.
        frame_index: The index of the frame for which the transform is requested.

    Returns:
        The 4x4 homogeneous matrix representing the transform.
    """

    num_links = model.number_of_links()
    num_frames = len(model.frame_names())

    invalid_index = jnp.array(
        [frame_index < num_links, frame_index >= num_links + num_frames]
    ).any()
    exceptions.raise_value_error_if(
        condition=invalid_index,
        msg="Invalid frame index '{idx}'",
        idx=frame_index,
    )

    # Pose of the parent link in the world frame.
    parent_link = idx_of_parent_link(model=model, frame_index=frame_index)
    W_H_L = js.link.transform(model=model, data=data, link_index=parent_link)

    # Static pose of the frame relative to its parent link.
    frame_transforms = model.kin_dyn_parameters.frame_parameters.transform
    L_H_F = frame_transforms[frame_index - num_links]

    # Chain the two poses: W_H_F = W_H_L @ L_H_F.
    return W_H_L @ L_H_F
185
+
186
+
187
@functools.partial(jax.jit, static_argnames=["output_vel_repr"])
@js.common.named_scope
def velocity(
    model: js.model.JaxSimModel,
    data: js.data.JaxSimModelData,
    *,
    frame_index: jtp.IntLike,
    output_vel_repr: VelRepr | None = None,
) -> jtp.Vector:
    """
    Compute the 6D velocity of the frame.

    Args:
        model: The model to consider.
        data: The data of the considered model.
        frame_index: The index of the frame.
        output_vel_repr:
            The output velocity representation of the frame velocity.
            Defaults to the active velocity representation of `data`.

    Returns:
        The 6D velocity of the frame in the specified velocity representation.
    """

    num_links = model.number_of_links()
    num_frames = model.number_of_frames()

    exceptions.raise_value_error_if(
        condition=jnp.array(
            [frame_index < num_links, frame_index >= num_links + num_frames]
        ).any(),
        msg="Invalid frame index '{idx}'",
        idx=frame_index,
    )

    if output_vel_repr is None:
        output_vel_repr = data.velocity_representation

    # The frame jacobian maps the generalized velocity, expressed in the active
    # representation of `data`, to the frame velocity in `output_vel_repr`.
    O_J_WF_I = jacobian(
        model=model,
        data=data,
        frame_index=frame_index,
        output_vel_repr=output_vel_repr,
    )

    return O_J_WF_I @ data.generalized_velocity()
236
+
237
+
238
@functools.partial(jax.jit, static_argnames=["output_vel_repr"])
@js.common.named_scope
def jacobian(
    model: js.model.JaxSimModel,
    data: js.data.JaxSimModelData,
    *,
    frame_index: jtp.IntLike,
    output_vel_repr: VelRepr | None = None,
) -> jtp.Matrix:
    r"""
    Compute the free-floating jacobian of the frame.

    The body-fixed jacobian of the parent link is re-expressed in the requested
    output frame through the adjoint of the transform from the parent link.

    Args:
        model: The model to consider.
        data: The data of the considered model.
        frame_index: The index of the frame.
        output_vel_repr:
            The output velocity representation of the free-floating jacobian.
            Defaults to the active velocity representation of `data`.

    Returns:
        The :math:`6 \times (6+n)` free-floating jacobian of the frame.

    Raises:
        ValueError: If the output velocity representation is not supported.

    Note:
        The input representation of the free-floating jacobian is the active
        velocity representation.
    """

    num_links = model.number_of_links()
    num_frames = model.number_of_frames()

    exceptions.raise_value_error_if(
        condition=jnp.array(
            [frame_index < num_links, frame_index >= num_links + num_frames]
        ).any(),
        msg="Invalid frame index '{idx}'",
        idx=frame_index,
    )

    if output_vel_repr is None:
        output_vel_repr = data.velocity_representation

    parent_link = idx_of_parent_link(model=model, frame_index=frame_index)

    # Body-fixed jacobian of the parent link, L_J_WL.
    L_J_WL = js.link.jacobian(
        model=model, data=data, link_index=parent_link, output_vel_repr=VelRepr.Body
    )

    # Transform O_H_L from the parent link frame L to the frame O in which the
    # output velocity is expressed.
    if output_vel_repr == VelRepr.Inertial:
        # O = W, the world frame.
        O_H_L = js.link.transform(model=model, data=data, link_index=parent_link)

    elif output_vel_repr == VelRepr.Body:
        # O = F, the frame itself.
        W_H_L = js.link.transform(model=model, data=data, link_index=parent_link)
        W_H_F = transform(model=model, data=data, frame_index=frame_index)
        O_H_L = Transform.inverse(W_H_F) @ W_H_L

    elif output_vel_repr == VelRepr.Mixed:
        # O = F[W]: FW_H_F keeps the rotation of W_H_F with a zero translation,
        # i.e. a frame located at the origin of F and oriented like W.
        W_H_L = js.link.transform(model=model, data=data, link_index=parent_link)
        W_H_F = transform(model=model, data=data, frame_index=frame_index)
        F_H_L = Transform.inverse(W_H_F) @ W_H_L
        FW_H_F = W_H_F.at[0:3, 3].set(jnp.zeros(3))
        O_H_L = FW_H_F @ F_H_L

    else:
        raise ValueError(output_vel_repr)

    # Re-express the link jacobian in the output frame.
    O_X_L = Adjoint.from_transform(transform=O_H_L)
    return O_X_L @ L_J_WL
316
+
317
+
318
@functools.partial(jax.jit, static_argnames=["output_vel_repr"])
@js.common.named_scope
def jacobian_derivative(
    model: js.model.JaxSimModel,
    data: js.data.JaxSimModelData,
    *,
    frame_index: jtp.IntLike,
    output_vel_repr: VelRepr | None = None,
) -> jtp.Matrix:
    r"""
    Compute the derivative of the free-floating jacobian of the frame.

    The derivative is assembled by the product rule from the parent link's
    jacobian and jacobian derivative, both computed with inertial input and
    output representations:

        O_J̇_WF_I = O_Ẋ_W @ W_J_WL_W @ T
                 + O_X_W @ W_J̇_WL_W @ T
                 + O_X_W @ W_J_WL_W @ Ṫ

    Here T (with derivative Ṫ) converts the generalized velocity from the
    active representation of `data` to the inertial one. O_X_W (with
    derivative O_Ẋ_W) converts the output from the inertial representation to
    `output_vel_repr`.

    Args:
        model: The model to consider.
        data: The data of the considered model.
        frame_index: The index of the frame.
        output_vel_repr:
            The output velocity representation of the free-floating jacobian derivative.
            Defaults to the active velocity representation of `data`.

    Returns:
        The derivative of the :math:`6 \times (6+n)` free-floating jacobian of the frame.

    Raises:
        ValueError:
            If the active velocity representation of `data` or the requested
            output representation is not one of Inertial, Body, Mixed.

    Note:
        The input representation of the free-floating jacobian derivative is the active
        velocity representation.
    """

    n_l = model.number_of_links()
    n_f = len(model.frame_names())

    exceptions.raise_value_error_if(
        condition=jnp.array([frame_index < n_l, frame_index >= n_l + n_f]).any(),
        msg="Invalid frame index '{idx}'",
        idx=frame_index,
    )

    output_vel_repr = (
        output_vel_repr if output_vel_repr is not None else data.velocity_representation
    )

    # Get the index of the parent link.
    L = idx_of_parent_link(model=model, frame_index=frame_index)

    # Temporarily switch `data` to the inertial representation so that both
    # quantities below also take an inertial generalized velocity as input.
    with data.switch_velocity_representation(VelRepr.Inertial):
        # Compute the Jacobian of the parent link in inertial representation.
        W_J_WL_W = js.link.jacobian(
            model=model,
            data=data,
            link_index=L,
            output_vel_repr=VelRepr.Inertial,
        )

        # Compute the Jacobian derivative of the parent link in inertial representation.
        W_J̇_WL_W = js.link.jacobian_derivative(
            model=model,
            data=data,
            link_index=L,
            output_vel_repr=VelRepr.Inertial,
        )

    # =====================================================
    # Compute quantities to adjust the input representation
    # =====================================================

    # T = diag(X, I_n): X acts on the 6 base velocity coordinates, and the
    # joint velocities are left unchanged.
    def compute_T(model: js.model.JaxSimModel, X: jtp.Matrix) -> jtp.Matrix:
        In = jnp.eye(model.dofs())
        T = jax.scipy.linalg.block_diag(X, In)
        return T

    # Ṫ = diag(Ẋ, 0_n): the identity block on the joints is constant.
    def compute_Ṫ(model: js.model.JaxSimModel, Ẋ: jtp.Matrix) -> jtp.Matrix:
        On = jnp.zeros(shape=(model.dofs(), model.dofs()))
        Ṫ = jax.scipy.linalg.block_diag(Ẋ, On)
        return Ṫ

    # Compute the operator to change the representation of ν, and its
    # time derivative.
    match data.velocity_representation:
        case VelRepr.Inertial:
            # Input already inertial: identity operator, zero derivative.
            W_H_W = jnp.eye(4)
            W_X_W = Adjoint.from_transform(transform=W_H_W)
            W_Ẋ_W = jnp.zeros((6, 6))

            T = compute_T(model=model, X=W_X_W)
            Ṫ = compute_Ṫ(model=model, Ẋ=W_Ẋ_W)

        case VelRepr.Body:
            # Body-fixed base velocity: W_Ẋ_B = W_X_B @ (B_v_WB ×).
            W_H_B = data.base_transform()
            W_X_B = Adjoint.from_transform(transform=W_H_B)
            B_v_WB = data.base_velocity()
            B_vx_WB = Cross.vx(B_v_WB)
            W_Ẋ_B = W_X_B @ B_vx_WB

            T = compute_T(model=model, X=W_X_B)
            Ṫ = compute_Ṫ(model=model, Ẋ=W_Ẋ_B)

        case VelRepr.Mixed:
            # Mixed frame B[W]: base origin, world orientation (rotation of
            # W_H_B replaced by the identity). Only its linear velocity is
            # used for the derivative (angular part zeroed).
            W_H_B = data.base_transform()
            W_H_BW = W_H_B.at[0:3, 0:3].set(jnp.eye(3))
            W_X_BW = Adjoint.from_transform(transform=W_H_BW)
            BW_v_WB = data.base_velocity()
            BW_v_W_BW = BW_v_WB.at[3:6].set(jnp.zeros(3))
            BW_vx_W_BW = Cross.vx(BW_v_W_BW)
            W_Ẋ_BW = W_X_BW @ BW_vx_W_BW

            T = compute_T(model=model, X=W_X_BW)
            Ṫ = compute_Ṫ(model=model, Ẋ=W_Ẋ_BW)

        case _:
            raise ValueError(data.velocity_representation)

    # =====================================================
    # Compute quantities to adjust the output representation
    # =====================================================

    match output_vel_repr:
        case VelRepr.Inertial:
            # Output already inertial: identity operator, zero derivative.
            O_X_W = W_X_W = Adjoint.from_transform(transform=jnp.eye(4))
            O_Ẋ_W = W_Ẋ_W = jnp.zeros((6, 6))

        case VelRepr.Body:
            W_H_F = transform(model=model, data=data, frame_index=frame_index)
            O_X_W = F_X_W = Adjoint.from_transform(transform=W_H_F, inverse=True)
            with data.switch_velocity_representation(VelRepr.Inertial):
                W_nu = data.generalized_velocity()
            # F is rigidly attached to L, so the parent link's inertial jacobian
            # also yields the inertial-fixed velocity of F.
            W_v_WF = W_J_WL_W @ W_nu
            W_vx_WF = Cross.vx(W_v_WF)
            # d/dt(F_X_W) = -F_X_W @ (W_v_WF ×)
            O_Ẋ_W = F_Ẋ_W = -F_X_W @ W_vx_WF  # noqa: F841

        case VelRepr.Mixed:
            # Mixed frame F[W]: frame origin, world orientation.
            W_H_F = transform(model=model, data=data, frame_index=frame_index)
            W_H_FW = W_H_F.at[0:3, 0:3].set(jnp.eye(3))
            FW_H_W = Transform.inverse(W_H_FW)
            O_X_W = FW_X_W = Adjoint.from_transform(transform=FW_H_W)
            # Mixed-output jacobian and mixed generalized velocity must be
            # consistent, hence both are evaluated in the Mixed context.
            with data.switch_velocity_representation(VelRepr.Mixed):
                FW_J_WF_FW = jacobian(
                    model=model,
                    data=data,
                    frame_index=frame_index,
                    output_vel_repr=VelRepr.Mixed,
                )
                FW_v_WF = FW_J_WF_FW @ data.generalized_velocity()
            # Velocity of F[W]: linear part of the mixed frame velocity,
            # zero angular part.
            W_v_W_FW = jnp.zeros(6).at[0:3].set(FW_v_WF[0:3])
            W_vx_W_FW = Cross.vx(W_v_W_FW)
            O_Ẋ_W = FW_Ẋ_W = -FW_X_W @ W_vx_W_FW  # noqa: F841

        case _:
            raise ValueError(output_vel_repr)

    # Product rule: derivative of O_X_W @ W_J_WL_W @ T.
    O_J̇_WF_I = jnp.zeros(shape=(6, 6 + model.dofs()))
    O_J̇_WF_I += O_Ẋ_W @ W_J_WL_W @ T
    O_J̇_WF_I += O_X_W @ W_J̇_WL_W @ T
    O_J̇_WF_I += O_X_W @ W_J_WL_W @ Ṫ

    return O_J̇_WF_I