jaxsim 0.4.2__py3-none-any.whl → 0.4.2.dev12__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
jaxsim/_version.py CHANGED
@@ -12,5 +12,5 @@ __version__: str
12
12
  __version_tuple__: VERSION_TUPLE
13
13
  version_tuple: VERSION_TUPLE
14
14
 
15
- __version__ = version = '0.4.2'
16
- __version_tuple__ = version_tuple = (0, 4, 2)
15
+ __version__ = version = '0.4.2.dev12'
16
+ __version_tuple__ = version_tuple = (0, 4, 2, 'dev12')
jaxsim/api/com.py CHANGED
@@ -137,9 +137,9 @@ def centroidal_momentum_jacobian(
137
137
 
138
138
  match data.velocity_representation:
139
139
  case VelRepr.Inertial | VelRepr.Mixed:
140
- W_H_G = W_H_GW = jnp.eye(4).at[0:3, 3].set(W_p_CoM) # noqa: F841
140
+ W_H_G = W_H_GW = jnp.eye(4).at[0:3, 3].set(W_p_CoM)
141
141
  case VelRepr.Body:
142
- W_H_G = W_H_GB = W_H_B.at[0:3, 3].set(W_p_CoM) # noqa: F841
142
+ W_H_G = W_H_GB = W_H_B.at[0:3, 3].set(W_p_CoM)
143
143
  case _:
144
144
  raise ValueError(data.velocity_representation)
145
145
 
@@ -172,9 +172,9 @@ def locked_centroidal_spatial_inertia(
172
172
 
173
173
  match data.velocity_representation:
174
174
  case VelRepr.Inertial | VelRepr.Mixed:
175
- W_H_G = W_H_GW = jnp.eye(4).at[0:3, 3].set(W_p_CoM) # noqa: F841
175
+ W_H_G = W_H_GW = jnp.eye(4).at[0:3, 3].set(W_p_CoM)
176
176
  case VelRepr.Body:
177
- W_H_G = W_H_GB = W_H_B.at[0:3, 3].set(W_p_CoM) # noqa: F841
177
+ W_H_G = W_H_GB = W_H_B.at[0:3, 3].set(W_p_CoM)
178
178
  case _:
179
179
  raise ValueError(data.velocity_representation)
180
180
 
@@ -290,14 +290,14 @@ def bias_acceleration(
290
290
 
291
291
  case VelRepr.Inertial:
292
292
 
293
- C_v̇_WL = W_v̇_bias_WL = v̇_bias_WL # noqa: F841
294
- C_v_WC = W_v_WW = jnp.zeros(6) # noqa: F841
293
+ C_v̇_WL = W_v̇_bias_WL = v̇_bias_WL
294
+ C_v_WC = W_v_WW = jnp.zeros(6)
295
295
 
296
- L_H_C = L_H_W = jax.vmap( # noqa: F841
296
+ L_H_C = L_H_W = jax.vmap(
297
297
  lambda W_H_L: jaxsim.math.Transform.inverse(W_H_L)
298
298
  )(W_H_L)
299
299
 
300
- L_v_LC = L_v_LW = jax.vmap( # noqa: F841
300
+ L_v_LC = L_v_LW = jax.vmap(
301
301
  lambda i: -js.link.velocity(
302
302
  model=model, data=data, link_index=i, output_vel_repr=VelRepr.Body
303
303
  )
@@ -314,9 +314,9 @@ def bias_acceleration(
314
314
 
315
315
  case VelRepr.Mixed:
316
316
 
317
- C_v̇_WL = LW_v̇_bias_WL = v̇_bias_WL # noqa: F841
317
+ C_v̇_WL = LW_v̇_bias_WL = v̇_bias_WL
318
318
 
319
- C_v_WC = LW_v_W_LW = jax.vmap( # noqa: F841
319
+ C_v_WC = LW_v_W_LW = jax.vmap(
320
320
  lambda i: js.link.velocity(
321
321
  model=model, data=data, link_index=i, output_vel_repr=VelRepr.Mixed
322
322
  )
@@ -324,13 +324,13 @@ def bias_acceleration(
324
324
  .set(jnp.zeros(3))
325
325
  )(jnp.arange(model.number_of_links()))
326
326
 
327
- L_H_C = L_H_LW = jax.vmap( # noqa: F841
327
+ L_H_C = L_H_LW = jax.vmap(
328
328
  lambda W_H_L: jaxsim.math.Transform.inverse(
329
329
  W_H_L.at[0:3, 3].set(jnp.zeros(3))
330
330
  )
331
331
  )(W_H_L)
332
332
 
333
- L_v_LC = L_v_L_LW = jax.vmap( # noqa: F841
333
+ L_v_LC = L_v_L_LW = jax.vmap(
334
334
  lambda i: -js.link.velocity(
335
335
  model=model, data=data, link_index=i, output_vel_repr=VelRepr.Body
336
336
  )
jaxsim/api/contact.py CHANGED
@@ -8,7 +8,6 @@ import jax.numpy as jnp
8
8
  import jaxsim.api as js
9
9
  import jaxsim.terrain
10
10
  import jaxsim.typing as jtp
11
- from jaxsim.math import Adjoint, Cross, Transform
12
11
  from jaxsim.rbda.contacts.soft import SoftContactsParams
13
12
 
14
13
  from .common import VelRepr
@@ -412,170 +411,3 @@ def jacobian(
412
411
  raise ValueError(output_vel_repr)
413
412
 
414
413
  return O_J_WC
415
-
416
-
417
- @functools.partial(jax.jit, static_argnames=["output_vel_repr"])
418
- def jacobian_derivative(
419
- model: js.model.JaxSimModel,
420
- data: js.data.JaxSimModelData,
421
- *,
422
- output_vel_repr: VelRepr | None = None,
423
- ) -> jtp.Matrix:
424
- r"""
425
- Compute the derivative of the free-floating jacobian of the contact points.
426
-
427
- Args:
428
- model: The model to consider.
429
- data: The data of the considered model.
430
- output_vel_repr:
431
- The output velocity representation of the free-floating jacobian derivative.
432
-
433
- Returns:
434
- The derivative of the :math:`6 \times (6+n)` free-floating jacobian of the contact points.
435
-
436
- Note:
437
- The input representation of the free-floating jacobian derivative is the active
438
- velocity representation.
439
- """
440
-
441
- output_vel_repr = (
442
- output_vel_repr if output_vel_repr is not None else data.velocity_representation
443
- )
444
-
445
- # Get the index of the parent link and the position of the collidable point.
446
- parent_link_idxs = jnp.array(model.kin_dyn_parameters.contact_parameters.body)
447
- L_p_Ci = jnp.array(model.kin_dyn_parameters.contact_parameters.point)
448
- contact_idxs = jnp.arange(L_p_Ci.shape[0])
449
-
450
- # Get the transforms of all the parent links.
451
- W_H_Li = js.model.forward_kinematics(model=model, data=data)
452
-
453
- # =====================================================
454
- # Compute quantities to adjust the input representation
455
- # =====================================================
456
-
457
- def compute_T(model: js.model.JaxSimModel, X: jtp.Matrix) -> jtp.Matrix:
458
- In = jnp.eye(model.dofs())
459
- T = jax.scipy.linalg.block_diag(X, In)
460
- return T
461
-
462
- def compute_Ṫ(model: js.model.JaxSimModel, Ẋ: jtp.Matrix) -> jtp.Matrix:
463
- On = jnp.zeros(shape=(model.dofs(), model.dofs()))
464
- Ṫ = jax.scipy.linalg.block_diag(Ẋ, On)
465
- return Ṫ
466
-
467
- # Compute the operator to change the representation of ν, and its
468
- # time derivative.
469
- match data.velocity_representation:
470
- case VelRepr.Inertial:
471
- W_H_W = jnp.eye(4)
472
- W_X_W = Adjoint.from_transform(transform=W_H_W)
473
- W_Ẋ_W = jnp.zeros((6, 6))
474
-
475
- T = compute_T(model=model, X=W_X_W)
476
- Ṫ = compute_Ṫ(model=model, Ẋ=W_Ẋ_W)
477
-
478
- case VelRepr.Body:
479
- W_H_B = data.base_transform()
480
- W_X_B = Adjoint.from_transform(transform=W_H_B)
481
- B_v_WB = data.base_velocity()
482
- B_vx_WB = Cross.vx(B_v_WB)
483
- W_Ẋ_B = W_X_B @ B_vx_WB
484
-
485
- T = compute_T(model=model, X=W_X_B)
486
- Ṫ = compute_Ṫ(model=model, Ẋ=W_Ẋ_B)
487
-
488
- case VelRepr.Mixed:
489
- W_H_B = data.base_transform()
490
- W_H_BW = W_H_B.at[0:3, 0:3].set(jnp.eye(3))
491
- W_X_BW = Adjoint.from_transform(transform=W_H_BW)
492
- BW_v_WB = data.base_velocity()
493
- BW_v_W_BW = BW_v_WB.at[3:6].set(jnp.zeros(3))
494
- BW_vx_W_BW = Cross.vx(BW_v_W_BW)
495
- W_Ẋ_BW = W_X_BW @ BW_vx_W_BW
496
-
497
- T = compute_T(model=model, X=W_X_BW)
498
- Ṫ = compute_Ṫ(model=model, Ẋ=W_Ẋ_BW)
499
-
500
- case _:
501
- raise ValueError(data.velocity_representation)
502
-
503
- # =====================================================
504
- # Compute quantities to adjust the output representation
505
- # =====================================================
506
-
507
- with data.switch_velocity_representation(VelRepr.Inertial):
508
- # Compute the Jacobian of the parent link in inertial representation.
509
- W_J_WL_W = js.model.generalized_free_floating_jacobian(
510
- model=model,
511
- data=data,
512
- output_vel_repr=VelRepr.Inertial,
513
- )
514
- # Compute the Jacobian derivative of the parent link in inertial representation.
515
- W_J̇_WL_W = js.model.generalized_free_floating_jacobian_derivative(
516
- model=model,
517
- data=data,
518
- output_vel_repr=VelRepr.Inertial,
519
- )
520
-
521
- # Get the Jacobian of the collidable points in the mixed representation.
522
- with data.switch_velocity_representation(VelRepr.Mixed):
523
- CW_J_WC_BW = jacobian(
524
- model=model,
525
- data=data,
526
- output_vel_repr=VelRepr.Mixed,
527
- )
528
-
529
- def compute_O_J̇_WC_I(
530
- L_p_C: jtp.Vector,
531
- contact_idx: jtp.Int,
532
- CW_J_WC_BW: jtp.Matrix,
533
- W_H_L: jtp.Matrix,
534
- ) -> jtp.Matrix:
535
-
536
- parent_link_idx = parent_link_idxs[contact_idx]
537
-
538
- match output_vel_repr:
539
- case VelRepr.Inertial:
540
- O_X_W = W_X_W = Adjoint.from_transform( # noqa: F841
541
- transform=jnp.eye(4)
542
- )
543
- O_Ẋ_W = W_Ẋ_W = jnp.zeros((6, 6)) # noqa: F841
544
-
545
- case VelRepr.Body:
546
- L_H_C = Transform.from_rotation_and_translation(translation=L_p_C)
547
- W_H_C = W_H_L[parent_link_idx] @ L_H_C
548
- O_X_W = C_X_W = Adjoint.from_transform(transform=W_H_C, inverse=True)
549
- with data.switch_velocity_representation(VelRepr.Inertial):
550
- W_nu = data.generalized_velocity()
551
- W_v_WC = W_J_WL_W[parent_link_idx] @ W_nu
552
- W_vx_WC = Cross.vx(W_v_WC)
553
- O_Ẋ_W = C_Ẋ_W = -C_X_W @ W_vx_WC # noqa: F841
554
-
555
- case VelRepr.Mixed:
556
- L_H_C = Transform.from_rotation_and_translation(translation=L_p_C)
557
- W_H_C = W_H_L[parent_link_idx] @ L_H_C
558
- W_H_CW = W_H_C.at[0:3, 0:3].set(jnp.eye(3))
559
- CW_H_W = Transform.inverse(W_H_CW)
560
- O_X_W = CW_X_W = Adjoint.from_transform(transform=CW_H_W)
561
- with data.switch_velocity_representation(VelRepr.Mixed):
562
- CW_v_WC = CW_J_WC_BW @ data.generalized_velocity()
563
- W_v_W_CW = jnp.zeros(6).at[0:3].set(CW_v_WC[0:3])
564
- W_vx_W_CW = Cross.vx(W_v_W_CW)
565
- O_Ẋ_W = CW_Ẋ_W = -CW_X_W @ W_vx_W_CW # noqa: F841
566
-
567
- case _:
568
- raise ValueError(output_vel_repr)
569
-
570
- O_J̇_WC_I = jnp.zeros(shape=(6, 6 + model.dofs()))
571
- O_J̇_WC_I += O_Ẋ_W @ W_J_WL_W[parent_link_idx] @ T
572
- O_J̇_WC_I += O_X_W @ W_J̇_WL_W[parent_link_idx] @ T
573
- O_J̇_WC_I += O_X_W @ W_J_WL_W[parent_link_idx] @ Ṫ
574
-
575
- return O_J̇_WC_I
576
-
577
- O_J̇_WC = jax.vmap(compute_O_J̇_WC_I, in_axes=(0, 0, 0, None))(
578
- L_p_Ci, contact_idxs, CW_J_WC_BW, W_H_Li
579
- )
580
-
581
- return O_J̇_WC
jaxsim/api/data.py CHANGED
@@ -12,6 +12,7 @@ import jaxlie
12
12
  import jaxsim.api as js
13
13
  import jaxsim.rbda
14
14
  import jaxsim.typing as jtp
15
+ from jaxsim.math import Quaternion
15
16
  from jaxsim.rbda.contacts.soft import SoftContacts
16
17
  from jaxsim.utils import Mutability
17
18
  from jaxsim.utils.tracing import not_tracing
@@ -190,7 +191,9 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
190
191
 
191
192
  W_H_B = jaxlie.SE3.from_rotation_and_translation(
192
193
  translation=base_position,
193
- rotation=jaxlie.SO3(wxyz=base_quaternion),
194
+ rotation=jaxlie.SO3.from_quaternion_xyzw(
195
+ base_quaternion[jnp.array([1, 2, 3, 0])]
196
+ ),
194
197
  ).as_matrix()
195
198
 
196
199
  v_WB = JaxSimModelData.other_representation_to_inertial(
@@ -377,7 +380,13 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
377
380
  on_false=W_Q_B / jnp.linalg.norm(W_Q_B),
378
381
  )
379
382
 
380
- return (W_Q_B if not dcm else jaxlie.SO3(wxyz=W_Q_B).as_matrix()).astype(float)
383
+ return (
384
+ W_Q_B
385
+ if not dcm
386
+ else jaxlie.SO3.from_quaternion_xyzw(
387
+ Quaternion.to_xyzw(wxyz=W_Q_B)
388
+ ).as_matrix()
389
+ ).astype(float)
381
390
 
382
391
  @jax.jit
383
392
  def base_transform(self) -> jtp.Matrix:
jaxsim/api/frame.py CHANGED
@@ -384,7 +384,7 @@ def jacobian_derivative(
384
384
  W_nu = data.generalized_velocity()
385
385
  W_v_WF = W_J_WL_W @ W_nu
386
386
  W_vx_WF = Cross.vx(W_v_WF)
387
- O_Ẋ_W = F_Ẋ_W = -F_X_W @ W_vx_WF # noqa: F841
387
+ O_Ẋ_W = F_Ẋ_W = -F_X_W @ W_vx_WF
388
388
 
389
389
  case VelRepr.Mixed:
390
390
  W_H_F = transform(model=model, data=data, frame_index=frame_index)
@@ -401,7 +401,7 @@ def jacobian_derivative(
401
401
  FW_v_WF = FW_J_WF_FW @ data.generalized_velocity()
402
402
  W_v_W_FW = jnp.zeros(6).at[0:3].set(FW_v_WF[0:3])
403
403
  W_vx_W_FW = Cross.vx(W_v_W_FW)
404
- O_Ẋ_W = FW_Ẋ_W = -FW_X_W @ W_vx_W_FW # noqa: F841
404
+ O_Ẋ_W = FW_Ẋ_W = -FW_X_W @ W_vx_W_FW
405
405
 
406
406
  case _:
407
407
  raise ValueError(output_vel_repr)
jaxsim/api/link.py CHANGED
@@ -288,7 +288,7 @@ def jacobian(
288
288
  case VelRepr.Inertial:
289
289
  W_H_B = data.base_transform()
290
290
  B_X_W = Adjoint.from_transform(transform=W_H_B, inverse=True)
291
- B_J_WL_I = B_J_WL_W = B_J_WL_B @ jax.scipy.linalg.block_diag( # noqa: F841
291
+ B_J_WL_I = B_J_WL_W = B_J_WL_B @ jax.scipy.linalg.block_diag(
292
292
  B_X_W, jnp.eye(model.dofs())
293
293
  )
294
294
 
@@ -299,7 +299,7 @@ def jacobian(
299
299
  W_R_B = data.base_orientation(dcm=True)
300
300
  BW_H_B = jnp.eye(4).at[0:3, 0:3].set(W_R_B)
301
301
  B_X_BW = Adjoint.from_transform(transform=BW_H_B, inverse=True)
302
- B_J_WL_I = B_J_WL_BW = B_J_WL_B @ jax.scipy.linalg.block_diag( # noqa: F841
302
+ B_J_WL_I = B_J_WL_BW = B_J_WL_B @ jax.scipy.linalg.block_diag(
303
303
  B_X_BW, jnp.eye(model.dofs())
304
304
  )
305
305
 
@@ -313,7 +313,7 @@ def jacobian(
313
313
  case VelRepr.Inertial:
314
314
  W_H_B = data.base_transform()
315
315
  W_X_B = Adjoint.from_transform(transform=W_H_B)
316
- O_J_WL_I = W_J_WL_I = W_X_B @ B_J_WL_I # noqa: F841
316
+ O_J_WL_I = W_J_WL_I = W_X_B @ B_J_WL_I
317
317
 
318
318
  case VelRepr.Body:
319
319
  L_X_B = Adjoint.from_transform(transform=B_H_L, inverse=True)
@@ -505,7 +505,7 @@ def jacobian_derivative(
505
505
  with data.switch_velocity_representation(VelRepr.Body):
506
506
  B_v_WB = data.base_velocity()
507
507
 
508
- O_Ẋ_B = W_Ẋ_B = W_X_B @ jaxsim.math.Cross.vx(B_v_WB) # noqa: F841
508
+ O_Ẋ_B = W_Ẋ_B = W_X_B @ jaxsim.math.Cross.vx(B_v_WB)
509
509
 
510
510
  case VelRepr.Body:
511
511
 
@@ -519,9 +519,7 @@ def jacobian_derivative(
519
519
  B_v_WB = data.base_velocity()
520
520
  L_v_WL = js.link.velocity(model=model, data=data, link_index=link_index)
521
521
 
522
- O_Ẋ_B = L_Ẋ_B = -L_X_B @ jaxsim.math.Cross.vx( # noqa: F841
523
- B_X_L @ L_v_WL - B_v_WB
524
- )
522
+ O_Ẋ_B = L_Ẋ_B = -L_X_B @ jaxsim.math.Cross.vx(B_X_L @ L_v_WL - B_v_WB)
525
523
 
526
524
  case VelRepr.Mixed:
527
525
 
@@ -546,9 +544,8 @@ def jacobian_derivative(
546
544
  LW_v_LW_L = LW_v_WL - LW_v_W_LW
547
545
  LW_v_B_LW = LW_v_WL - LW_X_B @ B_v_WB - LW_v_LW_L
548
546
 
549
- O_Ẋ_B = LW_Ẋ_B = -LW_X_B @ jaxsim.math.Cross.vx( # noqa: F841
550
- B_X_LW @ LW_v_B_LW
551
- )
547
+ O_Ẋ_B = LW_Ẋ_B = -LW_X_B @ jaxsim.math.Cross.vx(B_X_LW @ LW_v_B_LW)
548
+
552
549
  case _:
553
550
  raise ValueError(output_vel_repr)
554
551
 
jaxsim/api/model.py CHANGED
@@ -495,9 +495,8 @@ def generalized_free_floating_jacobian(
495
495
  W_H_B = data.base_transform()
496
496
  B_X_W = Adjoint.from_transform(transform=W_H_B, inverse=True)
497
497
 
498
- B_J_full_WX_I = B_J_full_WX_W = ( # noqa: F841
499
- B_J_full_WX_B
500
- @ jax.scipy.linalg.block_diag(B_X_W, jnp.eye(model.dofs()))
498
+ B_J_full_WX_I = B_J_full_WX_W = B_J_full_WX_B @ jax.scipy.linalg.block_diag(
499
+ B_X_W, jnp.eye(model.dofs())
501
500
  )
502
501
 
503
502
  case VelRepr.Body:
@@ -510,7 +509,7 @@ def generalized_free_floating_jacobian(
510
509
  BW_H_B = jnp.eye(4).at[0:3, 0:3].set(W_R_B)
511
510
  B_X_BW = Adjoint.from_transform(transform=BW_H_B, inverse=True)
512
511
 
513
- B_J_full_WX_I = B_J_full_WX_BW = ( # noqa: F841
512
+ B_J_full_WX_I = B_J_full_WX_BW = (
514
513
  B_J_full_WX_B
515
514
  @ jax.scipy.linalg.block_diag(B_X_BW, jnp.eye(model.dofs()))
516
515
  )
@@ -543,13 +542,11 @@ def generalized_free_floating_jacobian(
543
542
  W_H_B = data.base_transform()
544
543
  W_X_B = jaxsim.math.Adjoint.from_transform(W_H_B)
545
544
 
546
- O_J_WL_I = W_J_WL_I = jax.vmap( # noqa: F841
547
- lambda B_J_WL_I: W_X_B @ B_J_WL_I
548
- )(B_J_WL_I)
545
+ O_J_WL_I = W_J_WL_I = jax.vmap(lambda B_J_WL_I: W_X_B @ B_J_WL_I)(B_J_WL_I)
549
546
 
550
547
  case VelRepr.Body:
551
548
 
552
- O_J_WL_I = L_J_WL_I = jax.vmap( # noqa: F841
549
+ O_J_WL_I = L_J_WL_I = jax.vmap(
553
550
  lambda B_H_L, B_J_WL_I: jaxsim.math.Adjoint.from_transform(
554
551
  B_H_L, inverse=True
555
552
  )
@@ -568,7 +565,7 @@ def generalized_free_floating_jacobian(
568
565
  lambda LW_H_L, B_H_L: LW_H_L @ jaxsim.math.Transform.inverse(B_H_L)
569
566
  )(LW_H_L, B_H_L)
570
567
 
571
- O_J_WL_I = LW_J_WL_I = jax.vmap( # noqa: F841
568
+ O_J_WL_I = LW_J_WL_I = jax.vmap(
572
569
  lambda LW_H_B, B_J_WL_I: jaxsim.math.Adjoint.from_transform(LW_H_B)
573
570
  @ B_J_WL_I
574
571
  )(LW_H_B, B_J_WL_I)
@@ -579,41 +576,6 @@ def generalized_free_floating_jacobian(
579
576
  return O_J_WL_I
580
577
 
581
578
 
582
- @functools.partial(jax.jit, static_argnames=["output_vel_repr"])
583
- def generalized_free_floating_jacobian_derivative(
584
- model: JaxSimModel,
585
- data: js.data.JaxSimModelData,
586
- *,
587
- output_vel_repr: VelRepr | None = None,
588
- ) -> jtp.Matrix:
589
- """
590
- Compute the free-floating jacobian derivatives of all links.
591
-
592
- Args:
593
- model: The model to consider.
594
- data: The data of the considered model.
595
- output_vel_repr:
596
- The output velocity representation of the free-floating jacobian derivatives.
597
-
598
- Returns:
599
- The `(nL, 6, 6+dofs)` array containing the stacked free-floating
600
- jacobian derivatives of the links. The first axis is the link index.
601
- """
602
-
603
- output_vel_repr = (
604
- output_vel_repr if output_vel_repr is not None else data.velocity_representation
605
- )
606
-
607
- O_J̇_WL_I = jax.vmap(
608
- lambda model, data, link_idxs, output_vel_repr: js.link.jacobian_derivative(
609
- model, data, link_index=link_idxs, output_vel_repr=output_vel_repr
610
- ),
611
- in_axes=(None, None, 0, None),
612
- )(model, data, jnp.arange(model.number_of_links()), output_vel_repr)
613
-
614
- return O_J̇_WL_I
615
-
616
-
617
579
  @functools.partial(jax.jit, static_argnames=["prefer_aba"])
618
580
  def forward_dynamics(
619
581
  model: JaxSimModel,
@@ -759,8 +721,8 @@ def forward_dynamics_aba(
759
721
  match data.velocity_representation:
760
722
  case VelRepr.Inertial:
761
723
  # In this case C=W
762
- W_H_C = W_H_W = jnp.eye(4) # noqa: F841
763
- W_v_WC = W_v_WW = jnp.zeros(6) # noqa: F841
724
+ W_H_C = W_H_W = jnp.eye(4)
725
+ W_v_WC = W_v_WW = jnp.zeros(6)
764
726
 
765
727
  case VelRepr.Body:
766
728
  # In this case C=B
@@ -770,9 +732,9 @@ def forward_dynamics_aba(
770
732
  case VelRepr.Mixed:
771
733
  # In this case C=B[W]
772
734
  W_H_B = data.base_transform()
773
- W_H_C = W_H_BW = W_H_B.at[0:3, 0:3].set(jnp.eye(3)) # noqa: F841
735
+ W_H_C = W_H_BW = W_H_B.at[0:3, 0:3].set(jnp.eye(3))
774
736
  W_ṗ_B = data.base_velocity()[0:3]
775
- W_v_WC = W_v_W_BW = jnp.zeros(6).at[0:3].set(W_ṗ_B) # noqa: F841
737
+ W_v_WC = W_v_W_BW = jnp.zeros(6).at[0:3].set(W_ṗ_B)
776
738
 
777
739
  case _:
778
740
  raise ValueError(data.velocity_representation)
@@ -1127,8 +1089,8 @@ def inverse_dynamics(
1127
1089
 
1128
1090
  match data.velocity_representation:
1129
1091
  case VelRepr.Inertial:
1130
- W_H_C = W_H_W = jnp.eye(4) # noqa: F841
1131
- W_v_WC = W_v_WW = jnp.zeros(6) # noqa: F841
1092
+ W_H_C = W_H_W = jnp.eye(4)
1093
+ W_v_WC = W_v_WW = jnp.zeros(6)
1132
1094
 
1133
1095
  case VelRepr.Body:
1134
1096
  W_H_C = W_H_B = data.base_transform()
@@ -1137,9 +1099,9 @@ def inverse_dynamics(
1137
1099
 
1138
1100
  case VelRepr.Mixed:
1139
1101
  W_H_B = data.base_transform()
1140
- W_H_C = W_H_BW = W_H_B.at[0:3, 0:3].set(jnp.eye(3)) # noqa: F841
1102
+ W_H_C = W_H_BW = W_H_B.at[0:3, 0:3].set(jnp.eye(3))
1141
1103
  W_ṗ_B = data.base_velocity()[0:3]
1142
- W_v_WC = W_v_W_BW = jnp.zeros(6).at[0:3].set(W_ṗ_B) # noqa: F841
1104
+ W_v_WC = W_v_W_BW = jnp.zeros(6).at[0:3].set(W_ṗ_B)
1143
1105
 
1144
1106
  case _:
1145
1107
  raise ValueError(data.velocity_representation)
@@ -1574,15 +1536,15 @@ def link_bias_accelerations(
1574
1536
  # a simple C_X_W 6D transform.
1575
1537
  match data.velocity_representation:
1576
1538
  case VelRepr.Inertial:
1577
- W_H_C = W_H_W = jnp.eye(4) # noqa: F841
1578
- W_v_WC = W_v_WW = jnp.zeros(6) # noqa: F841
1539
+ W_H_C = W_H_W = jnp.eye(4)
1540
+ W_v_WC = W_v_WW = jnp.zeros(6)
1579
1541
  with data.switch_velocity_representation(VelRepr.Inertial):
1580
1542
  C_v_WB = W_v_WB = data.base_velocity()
1581
1543
 
1582
1544
  case VelRepr.Body:
1583
1545
  W_H_C = W_H_B
1584
1546
  with data.switch_velocity_representation(VelRepr.Inertial):
1585
- W_v_WC = W_v_WB = data.base_velocity() # noqa: F841
1547
+ W_v_WC = W_v_WB = data.base_velocity()
1586
1548
  with data.switch_velocity_representation(VelRepr.Body):
1587
1549
  C_v_WB = B_v_WB = data.base_velocity()
1588
1550
 
@@ -1593,9 +1555,9 @@ def link_bias_accelerations(
1593
1555
  W_ṗ_B = data.base_velocity()[0:3]
1594
1556
  BW_v_W_BW = jnp.zeros(6).at[0:3].set(W_ṗ_B)
1595
1557
  W_X_BW = jaxsim.math.Adjoint.from_transform(transform=W_H_BW)
1596
- W_v_WC = W_v_W_BW = W_X_BW @ BW_v_W_BW # noqa: F841
1558
+ W_v_WC = W_v_W_BW = W_X_BW @ BW_v_W_BW
1597
1559
  with data.switch_velocity_representation(VelRepr.Mixed):
1598
- C_v_WB = BW_v_WB = data.base_velocity() # noqa: F841
1560
+ C_v_WB = BW_v_WB = data.base_velocity()
1599
1561
 
1600
1562
  case _:
1601
1563
  raise ValueError(data.velocity_representation)
@@ -1703,12 +1665,8 @@ def link_bias_accelerations(
1703
1665
 
1704
1666
  match data.velocity_representation:
1705
1667
  case VelRepr.Body:
1706
- C_H_L = L_H_L = jnp.stack( # noqa: F841
1707
- [jnp.eye(4)] * model.number_of_links()
1708
- )
1709
- L_v_CL = L_v_LL = jnp.zeros( # noqa: F841
1710
- shape=(model.number_of_links(), 6)
1711
- )
1668
+ C_H_L = L_H_L = jnp.stack([jnp.eye(4)] * model.number_of_links())
1669
+ L_v_CL = L_v_LL = jnp.zeros(shape=(model.number_of_links(), 6))
1712
1670
 
1713
1671
  case VelRepr.Inertial:
1714
1672
  C_H_L = W_H_L = js.model.forward_kinematics(model=model, data=data)
@@ -1718,9 +1676,7 @@ def link_bias_accelerations(
1718
1676
  W_H_L = js.model.forward_kinematics(model=model, data=data)
1719
1677
  LW_H_L = jax.vmap(lambda W_H_L: W_H_L.at[0:3, 3].set(jnp.zeros(3)))(W_H_L)
1720
1678
  C_H_L = LW_H_L
1721
- L_v_CL = L_v_LW_L = jax.vmap( # noqa: F841
1722
- lambda v: v.at[0:3].set(jnp.zeros(3))
1723
- )(L_v_WL)
1679
+ L_v_CL = L_v_LW_L = jax.vmap(lambda v: v.at[0:3].set(jnp.zeros(3)))(L_v_WL)
1724
1680
 
1725
1681
  case _:
1726
1682
  raise ValueError(data.velocity_representation)
jaxsim/api/references.py CHANGED
@@ -8,7 +8,6 @@ import jax_dataclasses
8
8
 
9
9
  import jaxsim.api as js
10
10
  import jaxsim.typing as jtp
11
- from jaxsim import exceptions
12
11
  from jaxsim.utils.tracing import not_tracing
13
12
 
14
13
  from .common import VelRepr
@@ -31,7 +30,6 @@ class JaxSimModelReferences(js.common.ModelDataWithVelocityRepresentation):
31
30
  @staticmethod
32
31
  def zero(
33
32
  model: js.model.JaxSimModel,
34
- data: js.data.JaxSimModelData | None = None,
35
33
  velocity_representation: VelRepr = VelRepr.Inertial,
36
34
  ) -> JaxSimModelReferences:
37
35
  """
@@ -39,9 +37,6 @@ class JaxSimModelReferences(js.common.ModelDataWithVelocityRepresentation):
39
37
 
40
38
  Args:
41
39
  model: The model for which to create the zero references.
42
- data:
43
- The data of the model, only needed if the velocity representation is
44
- not inertial-fixed.
45
40
  velocity_representation: The velocity representation to use.
46
41
 
47
42
  Returns:
@@ -49,7 +44,7 @@ class JaxSimModelReferences(js.common.ModelDataWithVelocityRepresentation):
49
44
  """
50
45
 
51
46
  return JaxSimModelReferences.build(
52
- model=model, data=data, velocity_representation=velocity_representation
47
+ model=model, velocity_representation=velocity_representation
53
48
  )
54
49
 
55
50
  @staticmethod
@@ -446,104 +441,3 @@ class JaxSimModelReferences(js.common.ModelDataWithVelocityRepresentation):
446
441
  return replace(
447
442
  forces=self.input.physics_model.f_ext.at[link_idxs, :].set(W_f0_L + W_f_L)
448
443
  )
449
-
450
- def apply_frame_forces(
451
- self,
452
- forces: jtp.MatrixLike,
453
- model: js.model.JaxSimModel,
454
- data: js.data.JaxSimModelData,
455
- frame_names: tuple[str, ...] | str | None = None,
456
- additive: bool = False,
457
- ) -> Self:
458
- """
459
- Apply the frame forces.
460
-
461
- Args:
462
- forces: The frame 6D forces in the active representation.
463
- model:
464
- The model to consider, only needed if a frame serialization different
465
- from the implicit one is used.
466
- data:
467
- The data of the considered model, only needed if the velocity
468
- representation is not inertial-fixed.
469
- frame_names: The names of the frames corresponding to the forces.
470
- additive:
471
- Whether to add the forces to the existing ones instead of replacing them.
472
-
473
- Returns:
474
- A new `JaxSimModelReferences` object with the given frame forces.
475
-
476
- Note:
477
- The frame forces must be expressed in the active representation.
478
- Then, we always convert and store forces in inertial-fixed representation.
479
- """
480
-
481
- f_F = jnp.atleast_2d(forces).astype(float)
482
-
483
- # If we have the model, we can extract the frame names if not provided.
484
- frame_names = frame_names if frame_names is not None else model.frame_names()
485
-
486
- # Make sure that the frame names are a tuple if they are provided by the user.
487
- frame_names = (frame_names,) if isinstance(frame_names, str) else frame_names
488
-
489
- if len(frame_names) != f_F.shape[0]:
490
- msg = "The number of frame names ({}) must match the number of forces ({})"
491
- raise ValueError(msg.format(len(frame_names), f_F.shape[0]))
492
-
493
- # Extract the frame indices.
494
- frame_idxs = js.frame.names_to_idxs(frame_names=frame_names, model=model)
495
- parent_link_idxs = jax.vmap(js.frame.idx_of_parent_link, in_axes=(None,))(
496
- model, frame_index=frame_idxs
497
- )
498
-
499
- exceptions.raise_value_error_if(
500
- condition=jnp.logical_not(data.valid(model=model)),
501
- msg="The provided data is not valid for the model",
502
- )
503
- W_H_Fi = jax.vmap(
504
- lambda frame_idx: js.frame.transform(
505
- model=model, data=data, frame_index=frame_idx
506
- )
507
- )(frame_idxs)
508
-
509
- # Helper function to convert a single 6D force to the inertial representation
510
- # considering as body the frame (i.e. L_f_F and LW_f_F).
511
- def to_inertial(f_F: jtp.MatrixLike, W_H_F: jtp.MatrixLike) -> jtp.Matrix:
512
- return JaxSimModelReferences.other_representation_to_inertial(
513
- array=f_F,
514
- other_representation=self.velocity_representation,
515
- transform=W_H_F,
516
- is_force=True,
517
- )
518
-
519
- match self.velocity_representation:
520
- case VelRepr.Inertial:
521
- W_f_F = f_F
522
-
523
- case VelRepr.Body | VelRepr.Mixed:
524
- W_f_F = jax.vmap(to_inertial)(f_F, W_H_Fi)
525
-
526
- case _:
527
- raise ValueError("Invalid velocity representation.")
528
-
529
- # Sum the forces on the parent links.
530
- mask = parent_link_idxs[:, jnp.newaxis] == jnp.arange(model.number_of_links())
531
- W_f_L = mask.T @ W_f_F
532
-
533
- with self.switch_velocity_representation(
534
- velocity_representation=VelRepr.Inertial
535
- ):
536
- references = self.apply_link_forces(
537
- model=model,
538
- data=data,
539
- link_names=js.link.idxs_to_names(
540
- model=model, link_indices=parent_link_idxs
541
- ),
542
- forces=W_f_L,
543
- additive=additive,
544
- )
545
-
546
- with references.switch_velocity_representation(
547
- velocity_representation=self.velocity_representation
548
- ):
549
- return references
@@ -261,10 +261,8 @@ class ExplicitRungeKutta(Integrator[PyTreeType, PyTreeType], Generic[PyTreeType]
261
261
  # Check if the Butcher tableau supports FSAL (first-same-as-last).
262
262
  # If it does, store the index of the intermediate derivative to be used as the
263
263
  # first derivative of the next iteration.
264
- has_fsal, index_of_fsal = ( # noqa: F841
265
- ExplicitRungeKutta.butcher_tableau_supports_fsal(
266
- A=cls.A, b=cls.b, c=cls.c, index_of_solution=cls.row_index_of_solution
267
- )
264
+ has_fsal, index_of_fsal = ExplicitRungeKutta.butcher_tableau_supports_fsal(
265
+ A=cls.A, b=cls.b, c=cls.c, index_of_solution=cls.row_index_of_solution
268
266
  )
269
267
 
270
268
  # Build the integrator object.
jaxsim/math/adjoint.py CHANGED
@@ -3,6 +3,7 @@ import jaxlie
3
3
 
4
4
  import jaxsim.typing as jtp
5
5
 
6
+ from .quaternion import Quaternion
6
7
  from .skew import Skew
7
8
 
8
9
 
@@ -30,7 +31,7 @@ class Adjoint:
30
31
  assert quaternion.size == 4
31
32
  assert translation.size == 3
32
33
 
33
- Q_sixd = jaxlie.SO3(wxyz=quaternion)
34
+ Q_sixd = jaxlie.SO3.from_quaternion_xyzw(xyzw=Quaternion.to_xyzw(quaternion))
34
35
  Q_sixd = Q_sixd if not normalize_quaternion else Q_sixd.normalize()
35
36
 
36
37
  return Adjoint.from_rotation_and_translation(
@@ -83,14 +84,14 @@ class Adjoint:
83
84
  A_o_B = translation.squeeze()
84
85
 
85
86
  if not inverse:
86
- X = A_X_B = jnp.vstack( # noqa: F841
87
+ X = A_X_B = jnp.vstack(
87
88
  [
88
89
  jnp.block([A_R_B, Skew.wedge(A_o_B) @ A_R_B]),
89
90
  jnp.block([jnp.zeros(shape=(3, 3)), A_R_B]),
90
91
  ]
91
92
  )
92
93
  else:
93
- X = B_X_A = jnp.vstack( # noqa: F841
94
+ X = B_X_A = jnp.vstack(
94
95
  [
95
96
  jnp.block([A_R_B.T, -A_R_B.T @ Skew.wedge(A_o_B)]),
96
97
  jnp.block([jnp.zeros(shape=(3, 3)), A_R_B.T]),
@@ -254,8 +254,8 @@ def supported_joint_motion(
254
254
  # This is a metadata required by only some joint types.
255
255
  axis = jnp.array(joint_axis).astype(float).squeeze()
256
256
 
257
- pre_H_suc = jaxlie.SE3.from_matrix(
258
- matrix=jnp.eye(4).at[:3, :3].set(Rotation.from_axis_angle(vector=s * axis))
257
+ pre_H_suc = jaxlie.SE3.from_rotation(
258
+ rotation=jaxlie.SO3.from_matrix(Rotation.from_axis_angle(vector=s * axis))
259
259
  )
260
260
 
261
261
  S = jnp.vstack(jnp.hstack([jnp.zeros(3), axis]))
jaxsim/math/quaternion.py CHANGED
@@ -43,7 +43,9 @@ class Quaternion:
43
43
  Returns:
44
44
  jtp.Matrix: Direction cosine matrix (DCM).
45
45
  """
46
- return jaxlie.SO3(wxyz=quaternion).as_matrix()
46
+ return jaxlie.SO3.from_quaternion_xyzw(
47
+ xyzw=Quaternion.to_xyzw(quaternion)
48
+ ).as_matrix()
47
49
 
48
50
  @staticmethod
49
51
  def from_dcm(dcm: jtp.Matrix) -> jtp.Vector:
@@ -156,7 +158,7 @@ class Quaternion:
156
158
  A_Q_B = jnp.array(quaternion).squeeze().astype(float)
157
159
 
158
160
  # Build the initial SO(3) quaternion.
159
- W_Q_B_t0 = jaxlie.SO3(wxyz=A_Q_B)
161
+ W_Q_B_t0 = jaxlie.SO3.from_quaternion_xyzw(xyzw=Quaternion.to_xyzw(wxyz=A_Q_B))
160
162
 
161
163
  # Integrate the quaternion on the manifold.
162
164
  W_Q_B_tf = jax.lax.select(
jaxsim/math/transform.py CHANGED
@@ -3,6 +3,8 @@ import jaxlie
3
3
 
4
4
  import jaxsim.typing as jtp
5
5
 
6
+ from .quaternion import Quaternion
7
+
6
8
 
7
9
  class Transform:
8
10
 
@@ -33,7 +35,7 @@ class Transform:
33
35
  assert W_p_B.size == 3
34
36
  assert W_Q_B.size == 4
35
37
 
36
- A_R_B = jaxlie.SO3(wxyz=W_Q_B)
38
+ A_R_B = jaxlie.SO3.from_quaternion_xyzw(xyzw=Quaternion.to_xyzw(W_Q_B))
37
39
  A_R_B = A_R_B if not normalize_quaternion else A_R_B.normalize()
38
40
 
39
41
  A_H_B = jaxlie.SE3.from_rotation_and_translation(
jaxsim/mujoco/loaders.py CHANGED
@@ -532,12 +532,7 @@ class UrdfToMjcf:
532
532
  model_name: str | None = None,
533
533
  plane_normal: tuple[float, float, float] = (0, 0, 1),
534
534
  heightmap: bool | None = None,
535
- cameras: (
536
- MujocoCamera
537
- | Sequence[MujocoCamera]
538
- | dict[str, str]
539
- | Sequence[dict[str, str]]
540
- ) = (),
535
+ cameras: list[dict[str, str]] | dict[str, str] | None = None,
541
536
  ) -> tuple[str, dict[str, Any]]:
542
537
  """
543
538
  Converts a URDF file to a Mujoco MJCF string.
@@ -579,12 +574,7 @@ class SdfToMjcf:
579
574
  model_name: str | None = None,
580
575
  plane_normal: tuple[float, float, float] = (0, 0, 1),
581
576
  heightmap: bool | None = None,
582
- cameras: (
583
- MujocoCamera
584
- | Sequence[MujocoCamera]
585
- | dict[str, str]
586
- | Sequence[dict[str, str]]
587
- ) = (),
577
+ cameras: list[dict[str, str]] | dict[str, str] | None = None,
588
578
  ) -> tuple[str, dict[str, Any]]:
589
579
  """
590
580
  Converts a SDF file to a Mujoco MJCF string.
jaxsim/mujoco/model.py CHANGED
@@ -238,9 +238,9 @@ class MujocoModelHelper:
238
238
  raise ValueError("The orientation is not a valid element of SO(3)")
239
239
 
240
240
  W_Q_B = (
241
- Rotation.from_matrix(orientation).as_quat(
242
- canonical=True, scalar_first=False
243
- )
241
+ Rotation.from_matrix(orientation).as_quat(canonical=True)[
242
+ np.array([3, 0, 1, 2])
243
+ ]
244
244
  if dcm
245
245
  else orientation
246
246
  )
@@ -394,8 +394,8 @@ class MujocoModelHelper:
394
394
  if dcm:
395
395
  return R
396
396
 
397
- q_xyzw = Rotation.from_matrix(R).as_quat(canonical=True, scalar_first=False)
398
- return q_xyzw
397
+ q_xyzw = Rotation.from_matrix(R).as_quat(canonical=True)
398
+ return q_xyzw[[3, 0, 1, 2]]
399
399
 
400
400
  # ===============
401
401
  # Private methods
jaxsim/rbda/aba.py CHANGED
@@ -4,7 +4,7 @@ import jaxlie
4
4
 
5
5
  import jaxsim.api as js
6
6
  import jaxsim.typing as jtp
7
- from jaxsim.math import Adjoint, Cross, StandardGravity
7
+ from jaxsim.math import Adjoint, Cross, Quaternion, StandardGravity
8
8
 
9
9
  from . import utils
10
10
 
@@ -77,7 +77,7 @@ def aba(
77
77
 
78
78
  # Compute the base transform.
79
79
  W_H_B = jaxlie.SE3.from_rotation_and_translation(
80
- rotation=jaxlie.SO3(wxyz=W_Q_B),
80
+ rotation=jaxlie.SO3.from_quaternion_xyzw(xyzw=Quaternion.to_xyzw(wxyz=W_Q_B)),
81
81
  translation=W_p_B,
82
82
  )
83
83
 
@@ -4,7 +4,7 @@ import jaxlie
4
4
 
5
5
  import jaxsim.api as js
6
6
  import jaxsim.typing as jtp
7
- from jaxsim.math import Adjoint, Skew
7
+ from jaxsim.math import Adjoint, Quaternion, Skew
8
8
 
9
9
  from . import utils
10
10
 
@@ -57,7 +57,7 @@ def collidable_points_pos_vel(
57
57
 
58
58
  # Compute the base transform.
59
59
  W_H_B = jaxlie.SE3.from_rotation_and_translation(
60
- rotation=jaxlie.SO3(wxyz=W_Q_B),
60
+ rotation=jaxlie.SO3.from_quaternion_xyzw(xyzw=Quaternion.to_xyzw(wxyz=W_Q_B)),
61
61
  translation=W_p_B,
62
62
  )
63
63
 
@@ -192,7 +192,7 @@ class SoftContacts(ContactModel):
192
192
 
193
193
  # Unpack the position of the collidable point.
194
194
  px, py, pz = W_p_C = position.squeeze()
195
- W_ṗ_C = velocity.squeeze()
195
+ vx, vy, vz = W_ṗ_C = velocity.squeeze()
196
196
 
197
197
  # Compute the terrain normal and the contact depth.
198
198
  n̂ = self.terrain.normal(x=px, y=py).squeeze()
jaxsim/rbda/crba.py CHANGED
@@ -59,14 +59,10 @@ def crba(model: js.model.JaxSimModel, *, joint_positions: jtp.Vector) -> jtp.Mat
59
59
 
60
60
  return (i_X_0,), None
61
61
 
62
- (i_X_0,), _ = (
63
- jax.lax.scan(
64
- f=propagate_kinematics,
65
- init=forward_pass_carry,
66
- xs=jnp.arange(start=1, stop=model.number_of_links()),
67
- )
68
- if model.number_of_links() > 1
69
- else [(i_X_0,), None]
62
+ (i_X_0,), _ = jax.lax.scan(
63
+ f=propagate_kinematics,
64
+ init=forward_pass_carry,
65
+ xs=jnp.arange(start=1, stop=model.number_of_links()),
70
66
  )
71
67
 
72
68
  # ===================
@@ -132,14 +128,10 @@ def crba(model: js.model.JaxSimModel, *, joint_positions: jtp.Vector) -> jtp.Mat
132
128
  operand=carry,
133
129
  )
134
130
 
135
- (j, Fi, M), _ = (
136
- jax.lax.scan(
137
- f=inner_fn,
138
- init=carry_inner_fn,
139
- xs=jnp.flip(jnp.arange(start=1, stop=model.number_of_links())),
140
- )
141
- if model.number_of_links() > 1
142
- else [(j, Fi, M), None]
131
+ (j, Fi, M), _ = jax.lax.scan(
132
+ f=inner_fn,
133
+ init=carry_inner_fn,
134
+ xs=jnp.flip(jnp.arange(start=1, stop=model.number_of_links())),
143
135
  )
144
136
 
145
137
  Fi = i_X_0[j].T @ Fi
@@ -151,14 +143,10 @@ def crba(model: js.model.JaxSimModel, *, joint_positions: jtp.Vector) -> jtp.Mat
151
143
 
152
144
  # This scan performs the backward pass to compute Mbj, Mjb and Mjj, that
153
145
  # also includes a fake while loop implemented with a scan and two cond.
154
- (Mc, M), _ = (
155
- jax.lax.scan(
156
- f=backward_pass,
157
- init=backward_pass_carry,
158
- xs=jnp.flip(jnp.arange(start=1, stop=model.number_of_links())),
159
- )
160
- if model.number_of_links() > 1
161
- else [(Mc, M), None]
146
+ (Mc, M), _ = jax.lax.scan(
147
+ f=backward_pass,
148
+ init=backward_pass_carry,
149
+ xs=jnp.flip(jnp.arange(start=1, stop=model.number_of_links())),
162
150
  )
163
151
 
164
152
  # Store the locked 6D rigid-body inertia matrix Mbb ∈ ℝ⁶ˣ⁶.
@@ -4,7 +4,7 @@ import jaxlie
4
4
 
5
5
  import jaxsim.api as js
6
6
  import jaxsim.typing as jtp
7
- from jaxsim.math import Adjoint
7
+ from jaxsim.math import Adjoint, Quaternion
8
8
 
9
9
  from . import utils
10
10
 
@@ -42,7 +42,7 @@ def forward_kinematics_model(
42
42
 
43
43
  # Compute the base transform.
44
44
  W_H_B = jaxlie.SE3.from_rotation_and_translation(
45
- rotation=jaxlie.SO3(wxyz=W_Q_B),
45
+ rotation=jaxlie.SO3.from_quaternion_xyzw(xyzw=Quaternion.to_xyzw(wxyz=W_Q_B)),
46
46
  translation=W_p_B,
47
47
  )
48
48
 
@@ -75,14 +75,10 @@ def forward_kinematics_model(
75
75
 
76
76
  return (W_X_i,), None
77
77
 
78
- (W_X_i,), _ = (
79
- jax.lax.scan(
80
- f=propagate_kinematics,
81
- init=propagate_kinematics_carry,
82
- xs=jnp.arange(start=1, stop=model.number_of_links()),
83
- )
84
- if model.number_of_links() > 1
85
- else [(W_X_i,), None]
78
+ (W_X_i,), _ = jax.lax.scan(
79
+ f=propagate_kinematics,
80
+ init=propagate_kinematics_carry,
81
+ xs=jnp.arange(start=1, stop=model.number_of_links()),
86
82
  )
87
83
 
88
84
  return jax.vmap(Adjoint.to_transform)(W_X_i)
jaxsim/rbda/jacobian.py CHANGED
@@ -67,14 +67,10 @@ def jacobian(
67
67
 
68
68
  return (i_X_0,), None
69
69
 
70
- (i_X_0,), _ = (
71
- jax.lax.scan(
72
- f=propagate_kinematics,
73
- init=propagate_kinematics_carry,
74
- xs=np.arange(start=1, stop=model.number_of_links()),
75
- )
76
- if model.number_of_links() > 1
77
- else [(i_X_0,), None]
70
+ (i_X_0,), _ = jax.lax.scan(
71
+ f=propagate_kinematics,
72
+ init=propagate_kinematics_carry,
73
+ xs=np.arange(start=1, stop=model.number_of_links()),
78
74
  )
79
75
 
80
76
  # ============================
@@ -109,14 +105,10 @@ def jacobian(
109
105
 
110
106
  return J, None
111
107
 
112
- L_J_WL_B, _ = (
113
- jax.lax.scan(
114
- f=compute_jacobian,
115
- init=J,
116
- xs=np.arange(start=1, stop=model.number_of_links()),
117
- )
118
- if model.number_of_links() > 1
119
- else [J, None]
108
+ L_J_WL_B, _ = jax.lax.scan(
109
+ f=compute_jacobian,
110
+ init=J,
111
+ xs=np.arange(start=1, stop=model.number_of_links()),
120
112
  )
121
113
 
122
114
  return L_J_WL_B
@@ -192,14 +184,10 @@ def jacobian_full_doubly_left(
192
184
 
193
185
  return (B_X_i, J), None
194
186
 
195
- (B_X_i, J), _ = (
196
- jax.lax.scan(
197
- f=compute_full_jacobian,
198
- init=compute_full_jacobian_carry,
199
- xs=np.arange(start=1, stop=model.number_of_links()),
200
- )
201
- if model.number_of_links() > 1
202
- else [(B_X_i, J), None]
187
+ (B_X_i, J), _ = jax.lax.scan(
188
+ f=compute_full_jacobian,
189
+ init=compute_full_jacobian_carry,
190
+ xs=np.arange(start=1, stop=model.number_of_links()),
203
191
  )
204
192
 
205
193
  # Convert adjoints to SE(3) transforms.
jaxsim/rbda/rnea.py CHANGED
@@ -6,7 +6,7 @@ import jaxlie
6
6
 
7
7
  import jaxsim.api as js
8
8
  import jaxsim.typing as jtp
9
- from jaxsim.math import Adjoint, Cross, StandardGravity
9
+ from jaxsim.math import Adjoint, Cross, Quaternion, StandardGravity
10
10
 
11
11
  from . import utils
12
12
 
@@ -82,7 +82,7 @@ def rnea(
82
82
 
83
83
  # Compute the base transform.
84
84
  W_H_B = jaxlie.SE3.from_rotation_and_translation(
85
- rotation=jaxlie.SO3(wxyz=W_Q_B),
85
+ rotation=jaxlie.SO3.from_quaternion_xyzw(xyzw=Quaternion.to_xyzw(wxyz=W_Q_B)),
86
86
  translation=W_p_B,
87
87
  )
88
88
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: jaxsim
3
- Version: 0.4.2
3
+ Version: 0.4.2.dev12
4
4
  Summary: A differentiable physics engine and multibody dynamics library for control and robot learning.
5
5
  Author-email: Diego Ferigo <dgferigo@gmail.com>
6
6
  Maintainer-email: Diego Ferigo <dgferigo@gmail.com>, Filippo Luca Ferretti <filippo.ferretti@iit.it>
@@ -60,29 +60,29 @@ Requires-Python: >=3.10
60
60
  Description-Content-Type: text/markdown
61
61
  License-File: LICENSE
62
62
  Requires-Dist: coloredlogs
63
- Requires-Dist: jax>=0.4.13
64
- Requires-Dist: jaxlib>=0.4.13
65
- Requires-Dist: jaxlie>=1.3.0
66
- Requires-Dist: jax-dataclasses>=1.4.0
63
+ Requires-Dist: jax >=0.4.13
64
+ Requires-Dist: jaxlib >=0.4.13
65
+ Requires-Dist: jaxlie >=1.3.0
66
+ Requires-Dist: jax-dataclasses >=1.4.0
67
67
  Requires-Dist: pptree
68
- Requires-Dist: rod>=0.3.0
69
- Requires-Dist: typing-extensions; python_version < "3.12"
68
+ Requires-Dist: rod >=0.3.0
69
+ Requires-Dist: typing-extensions ; python_version < "3.12"
70
70
  Provides-Extra: all
71
- Requires-Dist: jaxsim[style,testing,viz]; extra == "all"
71
+ Requires-Dist: jaxsim[style,testing,viz] ; extra == 'all'
72
72
  Provides-Extra: style
73
- Requires-Dist: black[jupyter]~=24.0; extra == "style"
74
- Requires-Dist: isort; extra == "style"
75
- Requires-Dist: pre-commit; extra == "style"
73
+ Requires-Dist: black[jupyter] ~=24.0 ; extra == 'style'
74
+ Requires-Dist: isort ; extra == 'style'
75
+ Requires-Dist: pre-commit ; extra == 'style'
76
76
  Provides-Extra: testing
77
- Requires-Dist: idyntree>=12.2.1; extra == "testing"
78
- Requires-Dist: pytest>=6.0; extra == "testing"
79
- Requires-Dist: pytest-icdiff; extra == "testing"
80
- Requires-Dist: robot-descriptions; extra == "testing"
77
+ Requires-Dist: idyntree >=12.2.1 ; extra == 'testing'
78
+ Requires-Dist: pytest >=6.0 ; extra == 'testing'
79
+ Requires-Dist: pytest-icdiff ; extra == 'testing'
80
+ Requires-Dist: robot-descriptions ; extra == 'testing'
81
81
  Provides-Extra: viz
82
- Requires-Dist: lxml; extra == "viz"
83
- Requires-Dist: mediapy; extra == "viz"
84
- Requires-Dist: mujoco>=3.0.0; extra == "viz"
85
- Requires-Dist: scipy>=1.14.0; extra == "viz"
82
+ Requires-Dist: lxml ; extra == 'viz'
83
+ Requires-Dist: mediapy ; extra == 'viz'
84
+ Requires-Dist: mujoco >=3.0.0 ; extra == 'viz'
85
+ Requires-Dist: scipy >=1.14.0 ; extra == 'viz'
86
86
 
87
87
  # JaxSim
88
88
 
@@ -90,18 +90,6 @@ JaxSim is a **differentiable physics engine** and **multibody dynamics library**
90
90
 
91
91
  Its design facilitates research and accelerates prototyping in the intersection of robotics and artificial intelligence.
92
92
 
93
- <div align="center">
94
- <br/>
95
- <table>
96
- <tr>
97
- <th><img src="https://github.com/user-attachments/assets/115b1c1c-6ae5-4c59-92e0-1be13ba954db" width="250"></th>
98
- <th><img src="https://github.com/user-attachments/assets/f9661fae-9a85-41dd-9a58-218758ec8c9c" width="250"></th>
99
- <th><img src="https://github.com/user-attachments/assets/ae8adadf-3bca-47b8-97ca-3a9273633d60" width="250"></th>
100
- </tr>
101
- </table>
102
- <br/>
103
- </div>
104
-
105
93
  ## Features
106
94
 
107
95
  - Physics engine in reduced coordinates supporting fixed-base and floating-base robots.
@@ -1,38 +1,38 @@
1
1
  jaxsim/__init__.py,sha256=ixsS4dYMPex2wOUUp_rkPnwrPhYzkRh1xO_YuMj3Cr4,2626
2
- jaxsim/_version.py,sha256=McNH31cVzymi4jtwoAHwNiyVAdDW8uY0z3IcBioCvQI,411
2
+ jaxsim/_version.py,sha256=RAx3CWTs9fxrkfMFaLp1UMf3oZxbRvQW5m-ab2iNDcg,426
3
3
  jaxsim/exceptions.py,sha256=8_h8iqL8DgNR754dR8SZiQ7361GR5V1sUk3ZuZCHw1Q,2069
4
4
  jaxsim/logging.py,sha256=c4zhwBKf9eAYAHVp62kTEllqdsZgh0K-kPKVy8L3elU,1584
5
5
  jaxsim/typing.py,sha256=IbFx3UkEXi-cm7UBqMPi58rJAFV_HbZ9E_K4JwfNvVM,753
6
6
  jaxsim/api/__init__.py,sha256=8eV22t2S3UwNyCg8karPetG1dmX1VDBXkyv28_FwNQA,210
7
- jaxsim/api/com.py,sha256=m-p3EJDhpnMTlXKplfbZE_aH9NqX_VyLlAE3vUhc6l4,13642
7
+ jaxsim/api/com.py,sha256=6TnYCvjmsJ2KLuw3NtZb0pay7ZwGKe9MKphYeQdjpQ0,13474
8
8
  jaxsim/api/common.py,sha256=Ubi6uAw3o6qbdU0TFGzUyHg98EnoMzrnlihrvrs95Sk,6653
9
- jaxsim/api/contact.py,sha256=yBLfjT01BxEZ1lnC0WBSJZwCXK9DnW_DBJxmo9arStE,19855
10
- jaxsim/api/data.py,sha256=T6m7-NteWrm-K3491Yo76xvlWtFCKqTzEE7Ughcasi8,27197
11
- jaxsim/api/frame.py,sha256=Lsx9OIao_UZOQ6ibB_rramCRiYQbQv-M8C1QdoQdgcA,12902
9
+ jaxsim/api/contact.py,sha256=EcOx_T94gZT3igtebmW9FJDpZYPEf-RwKfFN18JjOWM,13364
10
+ jaxsim/api/data.py,sha256=-xx4b11thP8oJEXB4xtgrh3RTY2-BxrT38b7s_GrzjA,27420
11
+ jaxsim/api/frame.py,sha256=yQmhh8fckXnqzs7dQvojOzbuSanNGLwUTWUQDXbVtF4,12874
12
12
  jaxsim/api/joint.py,sha256=Pvg_It2iYA-jAQ2nOlFZxwmITiozO_f46G13BdQtHQ0,5106
13
13
  jaxsim/api/kin_dyn_parameters.py,sha256=CcfSg5Mc8qb1mZeMQ4AK_ffZIsK5yOl7tu397pFhcDA,29369
14
- jaxsim/api/link.py,sha256=GlnY7LMne-siFyg9J49IZGhiPQzS9Uk6rzQ0jI8cD_E,18622
15
- jaxsim/api/model.py,sha256=EdSjpKXd4N72wYjg5o0wGKFxjVMyrXg6LnlPEi3JqnU,63094
14
+ jaxsim/api/link.py,sha256=hn7fbxaebHeXnvwEG9jZiWwzRcfdS8m-18LVsIG3S24,18479
15
+ jaxsim/api/model.py,sha256=PMTSz00AIVopwiJ3zGBoYPTtkLH_beJCcQsX9wBE38I,61502
16
16
  jaxsim/api/ode.py,sha256=NnLTBvpaT4kXnbjAghXIzLv9DTMJ8bele2iOlUQDv3Q,11028
17
17
  jaxsim/api/ode_data.py,sha256=9YZX-SK_KJtoIqG-zYWZsQInb2NA_LtxDn-jtLqm_3U,19759
18
- jaxsim/api/references.py,sha256=XOVKuQXRmjPoP-T5JWGSbqIGX5DzOkeGafqRpj0ZQEM,20771
18
+ jaxsim/api/references.py,sha256=UA6kSQVBoq-bXSo99EOELf-_MD5MTy2zS0GtG3wQ410,16618
19
19
  jaxsim/integrators/__init__.py,sha256=hxvOD-VK_mmd6v31wtC-nb28AYve1gLuZCNLV9wS-Kg,103
20
- jaxsim/integrators/common.py,sha256=GqiyKTrAozuR6RuvVWdPF7locZQAXSEDY2AjTKpFGYM,20149
20
+ jaxsim/integrators/common.py,sha256=iwFykYZxdchqJcmcx8MFWEVijS5Hx9wCNKLKAJdF4gE,20103
21
21
  jaxsim/integrators/fixed_step.py,sha256=KpjRd6hHtapxDoo6D1kyDrVDSHnke2TepI5grFH7_bM,2693
22
22
  jaxsim/integrators/variable_step.py,sha256=0FCmAZIFnhvQxVbAzNfZgCWN1yMRTGVdBm9UwwaXI1o,21280
23
23
  jaxsim/math/__init__.py,sha256=8oPITEoGwgRcOeG8KxtqxPQ8b5uku1HNRMokpCoi9Tc,352
24
- jaxsim/math/adjoint.py,sha256=o1FCipkGwPtMbN2gFNIyUV8ADF3TX5fxElpTEXK0bIs,4377
24
+ jaxsim/math/adjoint.py,sha256=DT21izjVW497GRrgNfx8tv0ZeWW5QncWMGMhI0acUNw,4425
25
25
  jaxsim/math/cross.py,sha256=U7yEx_l75mSy5g6O-jsjBztApvxC3WaV4MpkS5tThu4,1330
26
26
  jaxsim/math/inertia.py,sha256=UAB7ym4gXFanejcs_ovZMpteHCc6poWYmt-mLmd5hhk,1640
27
- jaxsim/math/joint_model.py,sha256=VZ3hRCgb0gsyI1wN1UdHkmaRMBxjAYQK1i_3WIvkdUA,9994
28
- jaxsim/math/quaternion.py,sha256=_WA7W3iv7px83sWO1V1n0-J78hqAlO4SL1-jofE-UZ4,4754
27
+ jaxsim/math/joint_model.py,sha256=cVD9G8tBCsYtXC-r2BkVYO8Jg_km_BhJr7dezh3x6Rw,9995
28
+ jaxsim/math/quaternion.py,sha256=A05m7syBTIpl3SrsB7F76NNbExtUyJAdvHMavjNManI,4863
29
29
  jaxsim/math/rotation.py,sha256=Z90daUjGpuNEVLfWB3SVtM9EtwAIaneVj9A9UpWXqhA,2182
30
30
  jaxsim/math/skew.py,sha256=oOGSSR8PUGROl6IJFlrmu6K3gPH-u16hUPfKIkcVv9o,1177
31
- jaxsim/math/transform.py,sha256=KXzQgOnCfAtbXCwxhplpJ3F0JT3oEyeLVby1_uRAryQ,2892
31
+ jaxsim/math/transform.py,sha256=_5kSnfkS6_vxvjxdw50KeXMjvW8e1OGaumUlk1iGJgc,2969
32
32
  jaxsim/mujoco/__init__.py,sha256=Zo5GAlN1DYKvX8s1hu1j6HntKIbBMLB9Puv9ouaNAZ8,158
33
33
  jaxsim/mujoco/__main__.py,sha256=GBmB7J-zj75ZnFyuAAmpSOpbxi_HhHhWJeot3ljGDJY,5291
34
- jaxsim/mujoco/loaders.py,sha256=t9kWyvTJoMChBs8WMaZEPH1Y_smp_A9go6XT6rHcPEU,25301
35
- jaxsim/mujoco/model.py,sha256=ZqqHBDnB-y8ueHydD0Ujbg4ALyLUpB_6r_9r0sENQvI,16359
34
+ jaxsim/mujoco/loaders.py,sha256=He55jmkC5wQpMhEIDHOXXbqgWNjJ2fx16wOTStp_3PA,25111
35
+ jaxsim/mujoco/model.py,sha256=EwUPg9BsNv1B7TdDfjZCpC022lDR16AyIAajPJGH7NU,16357
36
36
  jaxsim/mujoco/visualizer.py,sha256=XvMzGSHM-xnOSYl1Vk6bPe6j6ylQmJeLOgxHgL6I1nw,6966
37
37
  jaxsim/parsers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
38
38
  jaxsim/parsers/kinematic_graph.py,sha256=88d0EmndVJWdcyFJsW25S78Z8F04cUt08RQMyoil1Xw,34734
@@ -45,24 +45,24 @@ jaxsim/parsers/rod/__init__.py,sha256=G2vqlLajBLUc4gyzXwsEI2Wsi4TMOIF9bLDFeT6KrG
45
45
  jaxsim/parsers/rod/parser.py,sha256=B8fnnL3LFNfCNTTFhX_OeQZhTlRgwPFCNKcUVL94-rY,13528
46
46
  jaxsim/parsers/rod/utils.py,sha256=5DsF3OeePZGidOJ5GiFSZx-51uIdnFvMW9EK6SgOW6Q,5698
47
47
  jaxsim/rbda/__init__.py,sha256=H7DhXpxkPOi9lpUvg31IMHFfRafke1UoJLc5GQIdyhA,387
48
- jaxsim/rbda/aba.py,sha256=w7ciyxB0IsmueatT0C7PcBQEl9dyiH9oqJgIi3xeTUE,8983
49
- jaxsim/rbda/collidable_points.py,sha256=Rmf1DhflhOTYh9mDalv0agS0CGSbmfoOybwP2KzKuJ0,4883
50
- jaxsim/rbda/crba.py,sha256=zJSiHKRvNU98z2tT9prrWR4VU9wIZQWFwEut7mua6as,5044
51
- jaxsim/rbda/forward_kinematics.py,sha256=2GmEoWsrioVl_SAbKRKfhOLz57pY4aR81PKRdulqStA,3458
52
- jaxsim/rbda/jacobian.py,sha256=p0EV_8cLzLVV-93VKznT7VPuRj8W7h7rQWkPlWJXfCA,11023
53
- jaxsim/rbda/rnea.py,sha256=LGXD6s3NigaVy4-WxoROjnbKLZcUoyFmS9UNu_4ldjo,7568
48
+ jaxsim/rbda/aba.py,sha256=IyeeCOF5nD-WSkRT5nMYtLuC0RWiyJQHlcyWDjQqliQ,9041
49
+ jaxsim/rbda/collidable_points.py,sha256=fQBZonoiLSSgHNpsa4mwe5wsA1j7jb2b_0D9z_oqKWo,4941
50
+ jaxsim/rbda/crba.py,sha256=NhtZO48OUKKor7ddY7mB7h7a6idrmOyf0Vy4p7UCCgI,4724
51
+ jaxsim/rbda/forward_kinematics.py,sha256=OEQYovnLKsWphUKhigmWa_384LwZW3Csp0MKufw4e1M,3415
52
+ jaxsim/rbda/jacobian.py,sha256=I6mrlkk7Cpq3CE7k_tajOHCbT6vf2pW6vMS0TKNCnng,10725
53
+ jaxsim/rbda/rnea.py,sha256=UrhcL93fp3pAKlGxOPS6X47L0ferH50bcSMzG55t4zY,7626
54
54
  jaxsim/rbda/utils.py,sha256=eeT21Y4DiiyhrdF0lUE_VvRuwru5-rR7yOlOlWzCCWE,5381
55
55
  jaxsim/rbda/contacts/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
56
56
  jaxsim/rbda/contacts/common.py,sha256=iMKLP30Qft9eGTiHo2iY-UoACJjg1JphA9_pW8wRdjc,2410
57
- jaxsim/rbda/contacts/soft.py,sha256=_wvb5iZDjGcVg6rNQelN4LZN7qSC2NIp0HdKvZmlGfk,15647
57
+ jaxsim/rbda/contacts/soft.py,sha256=3cDynim_tIgcbzRuqpHN82v4ELlxxK6lR-PG0haSK7Q,15660
58
58
  jaxsim/terrain/__init__.py,sha256=f7lVX-iNpH_wkkjef9Qpjh19TTAUOQw76EiLYJDVizc,78
59
59
  jaxsim/terrain/terrain.py,sha256=ctyNANIFSM3tZmamprjaEDcWgUSP0oNJbmT1zw9RjPs,4565
60
60
  jaxsim/utils/__init__.py,sha256=Y5zyoRevl3EMVQadhZ4EtSwTEkDt2vcnFoRhPJjKTZ0,215
61
61
  jaxsim/utils/jaxsim_dataclass.py,sha256=fLl1tY3DDb3lpIhG6BPqA5W34hM84oFzL-5cuz8k-68,11379
62
62
  jaxsim/utils/tracing.py,sha256=KDMoyVPlu2NJvFkhtZwq5AkqMMgajt3munvJom-vEjQ,650
63
63
  jaxsim/utils/wrappers.py,sha256=GOJQCJc5zwzoEGZB62wnWWGvUUQlXvDxz_A2Q-hFv7c,4027
64
- jaxsim-0.4.2.dist-info/LICENSE,sha256=eaYdFmdeMbiIoIiPzEK0MjP1S9wtFXjXNR5er49uLR0,1546
65
- jaxsim-0.4.2.dist-info/METADATA,sha256=7k6l7OO00B6czQi7O1MPCpjcbL1qxeSwy19khRT4FXc,17221
66
- jaxsim-0.4.2.dist-info/WHEEL,sha256=nCVcAvsfA9TDtwGwhYaRrlPhTLV9m-Ga6mdyDtuwK18,91
67
- jaxsim-0.4.2.dist-info/top_level.txt,sha256=LxGMA8FLtXjQ6oI7N5gd_R_oSUHxpXxUEOfT1xS_ni0,7
68
- jaxsim-0.4.2.dist-info/RECORD,,
64
+ jaxsim-0.4.2.dev12.dist-info/LICENSE,sha256=eaYdFmdeMbiIoIiPzEK0MjP1S9wtFXjXNR5er49uLR0,1546
65
+ jaxsim-0.4.2.dev12.dist-info/METADATA,sha256=bxPCWzVduanZPJ4ZuYkXTKkRCyyu-rjDSRd64c4oKpw,16826
66
+ jaxsim-0.4.2.dev12.dist-info/WHEEL,sha256=-oYQCr74JF3a37z2nRlQays_SX2MqOANoqVjBBAP2yE,91
67
+ jaxsim-0.4.2.dev12.dist-info/top_level.txt,sha256=LxGMA8FLtXjQ6oI7N5gd_R_oSUHxpXxUEOfT1xS_ni0,7
68
+ jaxsim-0.4.2.dev12.dist-info/RECORD,,
@@ -1,5 +1,5 @@
1
1
  Wheel-Version: 1.0
2
- Generator: setuptools (73.0.0)
2
+ Generator: setuptools (71.0.3)
3
3
  Root-Is-Purelib: true
4
4
  Tag: py3-none-any
5
5