jaxsim 0.4.1.dev24__py3-none-any.whl → 0.4.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
jaxsim/api/model.py CHANGED
@@ -301,10 +301,10 @@ class JaxSimModel(JaxsimDataclass):
301
301
 
302
302
  def frame_names(self) -> tuple[str, ...]:
303
303
  """
304
- Return the names of the links in the model.
304
+ Return the names of the frames in the model.
305
305
 
306
306
  Returns:
307
- The names of the links in the model.
307
+ The names of the frames in the model.
308
308
  """
309
309
 
310
310
  return self.kin_dyn_parameters.frame_parameters.name
@@ -495,8 +495,9 @@ def generalized_free_floating_jacobian(
495
495
  W_H_B = data.base_transform()
496
496
  B_X_W = Adjoint.from_transform(transform=W_H_B, inverse=True)
497
497
 
498
- B_J_full_WX_I = B_J_full_WX_W = B_J_full_WX_B @ jax.scipy.linalg.block_diag(
499
- B_X_W, jnp.eye(model.dofs())
498
+ B_J_full_WX_I = B_J_full_WX_W = ( # noqa: F841
499
+ B_J_full_WX_B
500
+ @ jax.scipy.linalg.block_diag(B_X_W, jnp.eye(model.dofs()))
500
501
  )
501
502
 
502
503
  case VelRepr.Body:
@@ -509,7 +510,7 @@ def generalized_free_floating_jacobian(
509
510
  BW_H_B = jnp.eye(4).at[0:3, 0:3].set(W_R_B)
510
511
  B_X_BW = Adjoint.from_transform(transform=BW_H_B, inverse=True)
511
512
 
512
- B_J_full_WX_I = B_J_full_WX_BW = (
513
+ B_J_full_WX_I = B_J_full_WX_BW = ( # noqa: F841
513
514
  B_J_full_WX_B
514
515
  @ jax.scipy.linalg.block_diag(B_X_BW, jnp.eye(model.dofs()))
515
516
  )
@@ -542,11 +543,13 @@ def generalized_free_floating_jacobian(
542
543
  W_H_B = data.base_transform()
543
544
  W_X_B = jaxsim.math.Adjoint.from_transform(W_H_B)
544
545
 
545
- O_J_WL_I = W_J_WL_I = jax.vmap(lambda B_J_WL_I: W_X_B @ B_J_WL_I)(B_J_WL_I)
546
+ O_J_WL_I = W_J_WL_I = jax.vmap( # noqa: F841
547
+ lambda B_J_WL_I: W_X_B @ B_J_WL_I
548
+ )(B_J_WL_I)
546
549
 
547
550
  case VelRepr.Body:
548
551
 
549
- O_J_WL_I = L_J_WL_I = jax.vmap(
552
+ O_J_WL_I = L_J_WL_I = jax.vmap( # noqa: F841
550
553
  lambda B_H_L, B_J_WL_I: jaxsim.math.Adjoint.from_transform(
551
554
  B_H_L, inverse=True
552
555
  )
@@ -565,7 +568,7 @@ def generalized_free_floating_jacobian(
565
568
  lambda LW_H_L, B_H_L: LW_H_L @ jaxsim.math.Transform.inverse(B_H_L)
566
569
  )(LW_H_L, B_H_L)
567
570
 
568
- O_J_WL_I = LW_J_WL_I = jax.vmap(
571
+ O_J_WL_I = LW_J_WL_I = jax.vmap( # noqa: F841
569
572
  lambda LW_H_B, B_J_WL_I: jaxsim.math.Adjoint.from_transform(LW_H_B)
570
573
  @ B_J_WL_I
571
574
  )(LW_H_B, B_J_WL_I)
@@ -576,6 +579,41 @@ def generalized_free_floating_jacobian(
576
579
  return O_J_WL_I
577
580
 
578
581
 
582
+ @functools.partial(jax.jit, static_argnames=["output_vel_repr"])
583
+ def generalized_free_floating_jacobian_derivative(
584
+ model: JaxSimModel,
585
+ data: js.data.JaxSimModelData,
586
+ *,
587
+ output_vel_repr: VelRepr | None = None,
588
+ ) -> jtp.Matrix:
589
+ """
590
+ Compute the free-floating jacobian derivatives of all links.
591
+
592
+ Args:
593
+ model: The model to consider.
594
+ data: The data of the considered model.
595
+ output_vel_repr:
596
+ The output velocity representation of the free-floating jacobian derivatives.
597
+
598
+ Returns:
599
+ The `(nL, 6, 6+dofs)` array containing the stacked free-floating
600
+ jacobian derivatives of the links. The first axis is the link index.
601
+ """
602
+
603
+ output_vel_repr = (
604
+ output_vel_repr if output_vel_repr is not None else data.velocity_representation
605
+ )
606
+
607
+ O_J̇_WL_I = jax.vmap(
608
+ lambda model, data, link_idxs, output_vel_repr: js.link.jacobian_derivative(
609
+ model, data, link_index=link_idxs, output_vel_repr=output_vel_repr
610
+ ),
611
+ in_axes=(None, None, 0, None),
612
+ )(model, data, jnp.arange(model.number_of_links()), output_vel_repr)
613
+
614
+ return O_J̇_WL_I
615
+
616
+
579
617
  @functools.partial(jax.jit, static_argnames=["prefer_aba"])
580
618
  def forward_dynamics(
581
619
  model: JaxSimModel,
@@ -721,8 +759,8 @@ def forward_dynamics_aba(
721
759
  match data.velocity_representation:
722
760
  case VelRepr.Inertial:
723
761
  # In this case C=W
724
- W_H_C = W_H_W = jnp.eye(4)
725
- W_v_WC = W_v_WW = jnp.zeros(6)
762
+ W_H_C = W_H_W = jnp.eye(4) # noqa: F841
763
+ W_v_WC = W_v_WW = jnp.zeros(6) # noqa: F841
726
764
 
727
765
  case VelRepr.Body:
728
766
  # In this case C=B
@@ -732,9 +770,9 @@ def forward_dynamics_aba(
732
770
  case VelRepr.Mixed:
733
771
  # In this case C=B[W]
734
772
  W_H_B = data.base_transform()
735
- W_H_C = W_H_BW = W_H_B.at[0:3, 0:3].set(jnp.eye(3))
773
+ W_H_C = W_H_BW = W_H_B.at[0:3, 0:3].set(jnp.eye(3)) # noqa: F841
736
774
  W_ṗ_B = data.base_velocity()[0:3]
737
- W_v_WC = W_v_W_BW = jnp.zeros(6).at[0:3].set(W_ṗ_B)
775
+ W_v_WC = W_v_W_BW = jnp.zeros(6).at[0:3].set(W_ṗ_B) # noqa: F841
738
776
 
739
777
  case _:
740
778
  raise ValueError(data.velocity_representation)
@@ -1089,8 +1127,8 @@ def inverse_dynamics(
1089
1127
 
1090
1128
  match data.velocity_representation:
1091
1129
  case VelRepr.Inertial:
1092
- W_H_C = W_H_W = jnp.eye(4)
1093
- W_v_WC = W_v_WW = jnp.zeros(6)
1130
+ W_H_C = W_H_W = jnp.eye(4) # noqa: F841
1131
+ W_v_WC = W_v_WW = jnp.zeros(6) # noqa: F841
1094
1132
 
1095
1133
  case VelRepr.Body:
1096
1134
  W_H_C = W_H_B = data.base_transform()
@@ -1099,9 +1137,9 @@ def inverse_dynamics(
1099
1137
 
1100
1138
  case VelRepr.Mixed:
1101
1139
  W_H_B = data.base_transform()
1102
- W_H_C = W_H_BW = W_H_B.at[0:3, 0:3].set(jnp.eye(3))
1140
+ W_H_C = W_H_BW = W_H_B.at[0:3, 0:3].set(jnp.eye(3)) # noqa: F841
1103
1141
  W_ṗ_B = data.base_velocity()[0:3]
1104
- W_v_WC = W_v_W_BW = jnp.zeros(6).at[0:3].set(W_ṗ_B)
1142
+ W_v_WC = W_v_W_BW = jnp.zeros(6).at[0:3].set(W_ṗ_B) # noqa: F841
1105
1143
 
1106
1144
  case _:
1107
1145
  raise ValueError(data.velocity_representation)
@@ -1536,15 +1574,15 @@ def link_bias_accelerations(
1536
1574
  # a simple C_X_W 6D transform.
1537
1575
  match data.velocity_representation:
1538
1576
  case VelRepr.Inertial:
1539
- W_H_C = W_H_W = jnp.eye(4)
1540
- W_v_WC = W_v_WW = jnp.zeros(6)
1577
+ W_H_C = W_H_W = jnp.eye(4) # noqa: F841
1578
+ W_v_WC = W_v_WW = jnp.zeros(6) # noqa: F841
1541
1579
  with data.switch_velocity_representation(VelRepr.Inertial):
1542
1580
  C_v_WB = W_v_WB = data.base_velocity()
1543
1581
 
1544
1582
  case VelRepr.Body:
1545
1583
  W_H_C = W_H_B
1546
1584
  with data.switch_velocity_representation(VelRepr.Inertial):
1547
- W_v_WC = W_v_WB = data.base_velocity()
1585
+ W_v_WC = W_v_WB = data.base_velocity() # noqa: F841
1548
1586
  with data.switch_velocity_representation(VelRepr.Body):
1549
1587
  C_v_WB = B_v_WB = data.base_velocity()
1550
1588
 
@@ -1555,9 +1593,9 @@ def link_bias_accelerations(
1555
1593
  W_ṗ_B = data.base_velocity()[0:3]
1556
1594
  BW_v_W_BW = jnp.zeros(6).at[0:3].set(W_ṗ_B)
1557
1595
  W_X_BW = jaxsim.math.Adjoint.from_transform(transform=W_H_BW)
1558
- W_v_WC = W_v_W_BW = W_X_BW @ BW_v_W_BW
1596
+ W_v_WC = W_v_W_BW = W_X_BW @ BW_v_W_BW # noqa: F841
1559
1597
  with data.switch_velocity_representation(VelRepr.Mixed):
1560
- C_v_WB = BW_v_WB = data.base_velocity()
1598
+ C_v_WB = BW_v_WB = data.base_velocity() # noqa: F841
1561
1599
 
1562
1600
  case _:
1563
1601
  raise ValueError(data.velocity_representation)
@@ -1665,8 +1703,12 @@ def link_bias_accelerations(
1665
1703
 
1666
1704
  match data.velocity_representation:
1667
1705
  case VelRepr.Body:
1668
- C_H_L = L_H_L = jnp.stack([jnp.eye(4)] * model.number_of_links())
1669
- L_v_CL = L_v_LL = jnp.zeros(shape=(model.number_of_links(), 6))
1706
+ C_H_L = L_H_L = jnp.stack( # noqa: F841
1707
+ [jnp.eye(4)] * model.number_of_links()
1708
+ )
1709
+ L_v_CL = L_v_LL = jnp.zeros( # noqa: F841
1710
+ shape=(model.number_of_links(), 6)
1711
+ )
1670
1712
 
1671
1713
  case VelRepr.Inertial:
1672
1714
  C_H_L = W_H_L = js.model.forward_kinematics(model=model, data=data)
@@ -1676,7 +1718,9 @@ def link_bias_accelerations(
1676
1718
  W_H_L = js.model.forward_kinematics(model=model, data=data)
1677
1719
  LW_H_L = jax.vmap(lambda W_H_L: W_H_L.at[0:3, 3].set(jnp.zeros(3)))(W_H_L)
1678
1720
  C_H_L = LW_H_L
1679
- L_v_CL = L_v_LW_L = jax.vmap(lambda v: v.at[0:3].set(jnp.zeros(3)))(L_v_WL)
1721
+ L_v_CL = L_v_LW_L = jax.vmap( # noqa: F841
1722
+ lambda v: v.at[0:3].set(jnp.zeros(3))
1723
+ )(L_v_WL)
1680
1724
 
1681
1725
  case _:
1682
1726
  raise ValueError(data.velocity_representation)
jaxsim/api/references.py CHANGED
@@ -8,6 +8,7 @@ import jax_dataclasses
8
8
 
9
9
  import jaxsim.api as js
10
10
  import jaxsim.typing as jtp
11
+ from jaxsim import exceptions
11
12
  from jaxsim.utils.tracing import not_tracing
12
13
 
13
14
  from .common import VelRepr
@@ -30,6 +31,7 @@ class JaxSimModelReferences(js.common.ModelDataWithVelocityRepresentation):
30
31
  @staticmethod
31
32
  def zero(
32
33
  model: js.model.JaxSimModel,
34
+ data: js.data.JaxSimModelData | None = None,
33
35
  velocity_representation: VelRepr = VelRepr.Inertial,
34
36
  ) -> JaxSimModelReferences:
35
37
  """
@@ -37,6 +39,9 @@ class JaxSimModelReferences(js.common.ModelDataWithVelocityRepresentation):
37
39
 
38
40
  Args:
39
41
  model: The model for which to create the zero references.
42
+ data:
43
+ The data of the model, only needed if the velocity representation is
44
+ not inertial-fixed.
40
45
  velocity_representation: The velocity representation to use.
41
46
 
42
47
  Returns:
@@ -44,7 +49,7 @@ class JaxSimModelReferences(js.common.ModelDataWithVelocityRepresentation):
44
49
  """
45
50
 
46
51
  return JaxSimModelReferences.build(
47
- model=model, velocity_representation=velocity_representation
52
+ model=model, data=data, velocity_representation=velocity_representation
48
53
  )
49
54
 
50
55
  @staticmethod
@@ -441,3 +446,104 @@ class JaxSimModelReferences(js.common.ModelDataWithVelocityRepresentation):
441
446
  return replace(
442
447
  forces=self.input.physics_model.f_ext.at[link_idxs, :].set(W_f0_L + W_f_L)
443
448
  )
449
+
450
+ def apply_frame_forces(
451
+ self,
452
+ forces: jtp.MatrixLike,
453
+ model: js.model.JaxSimModel,
454
+ data: js.data.JaxSimModelData,
455
+ frame_names: tuple[str, ...] | str | None = None,
456
+ additive: bool = False,
457
+ ) -> Self:
458
+ """
459
+ Apply the frame forces.
460
+
461
+ Args:
462
+ forces: The frame 6D forces in the active representation.
463
+ model:
464
+ The model to consider, only needed if a frame serialization different
465
+ from the implicit one is used.
466
+ data:
467
+ The data of the considered model, only needed if the velocity
468
+ representation is not inertial-fixed.
469
+ frame_names: The names of the frames corresponding to the forces.
470
+ additive:
471
+ Whether to add the forces to the existing ones instead of replacing them.
472
+
473
+ Returns:
474
+ A new `JaxSimModelReferences` object with the given frame forces.
475
+
476
+ Note:
477
+ The frame forces must be expressed in the active representation.
478
+ Then, we always convert and store forces in inertial-fixed representation.
479
+ """
480
+
481
+ f_F = jnp.atleast_2d(forces).astype(float)
482
+
483
+ # If we have the model, we can extract the frame names if not provided.
484
+ frame_names = frame_names if frame_names is not None else model.frame_names()
485
+
486
+ # Make sure that the frame names are a tuple if they are provided by the user.
487
+ frame_names = (frame_names,) if isinstance(frame_names, str) else frame_names
488
+
489
+ if len(frame_names) != f_F.shape[0]:
490
+ msg = "The number of frame names ({}) must match the number of forces ({})"
491
+ raise ValueError(msg.format(len(frame_names), f_F.shape[0]))
492
+
493
+ # Extract the frame indices.
494
+ frame_idxs = js.frame.names_to_idxs(frame_names=frame_names, model=model)
495
+ parent_link_idxs = jax.vmap(js.frame.idx_of_parent_link, in_axes=(None,))(
496
+ model, frame_index=frame_idxs
497
+ )
498
+
499
+ exceptions.raise_value_error_if(
500
+ condition=jnp.logical_not(data.valid(model=model)),
501
+ msg="The provided data is not valid for the model",
502
+ )
503
+ W_H_Fi = jax.vmap(
504
+ lambda frame_idx: js.frame.transform(
505
+ model=model, data=data, frame_index=frame_idx
506
+ )
507
+ )(frame_idxs)
508
+
509
+ # Helper function to convert a single 6D force to the inertial representation
510
+ # considering as body the frame (i.e. L_f_F and LW_f_F).
511
+ def to_inertial(f_F: jtp.MatrixLike, W_H_F: jtp.MatrixLike) -> jtp.Matrix:
512
+ return JaxSimModelReferences.other_representation_to_inertial(
513
+ array=f_F,
514
+ other_representation=self.velocity_representation,
515
+ transform=W_H_F,
516
+ is_force=True,
517
+ )
518
+
519
+ match self.velocity_representation:
520
+ case VelRepr.Inertial:
521
+ W_f_F = f_F
522
+
523
+ case VelRepr.Body | VelRepr.Mixed:
524
+ W_f_F = jax.vmap(to_inertial)(f_F, W_H_Fi)
525
+
526
+ case _:
527
+ raise ValueError("Invalid velocity representation.")
528
+
529
+ # Sum the forces on the parent links.
530
+ mask = parent_link_idxs[:, jnp.newaxis] == jnp.arange(model.number_of_links())
531
+ W_f_L = mask.T @ W_f_F
532
+
533
+ with self.switch_velocity_representation(
534
+ velocity_representation=VelRepr.Inertial
535
+ ):
536
+ references = self.apply_link_forces(
537
+ model=model,
538
+ data=data,
539
+ link_names=js.link.idxs_to_names(
540
+ model=model, link_indices=parent_link_idxs
541
+ ),
542
+ forces=W_f_L,
543
+ additive=additive,
544
+ )
545
+
546
+ with references.switch_velocity_representation(
547
+ velocity_representation=self.velocity_representation
548
+ ):
549
+ return references
@@ -261,8 +261,10 @@ class ExplicitRungeKutta(Integrator[PyTreeType, PyTreeType], Generic[PyTreeType]
261
261
  # Check if the Butcher tableau supports FSAL (first-same-as-last).
262
262
  # If it does, store the index of the intermediate derivative to be used as the
263
263
  # first derivative of the next iteration.
264
- has_fsal, index_of_fsal = ExplicitRungeKutta.butcher_tableau_supports_fsal(
265
- A=cls.A, b=cls.b, c=cls.c, index_of_solution=cls.row_index_of_solution
264
+ has_fsal, index_of_fsal = ( # noqa: F841
265
+ ExplicitRungeKutta.butcher_tableau_supports_fsal(
266
+ A=cls.A, b=cls.b, c=cls.c, index_of_solution=cls.row_index_of_solution
267
+ )
266
268
  )
267
269
 
268
270
  # Build the integrator object.
jaxsim/math/adjoint.py CHANGED
@@ -3,7 +3,6 @@ import jaxlie
3
3
 
4
4
  import jaxsim.typing as jtp
5
5
 
6
- from .quaternion import Quaternion
7
6
  from .skew import Skew
8
7
 
9
8
 
@@ -31,7 +30,7 @@ class Adjoint:
31
30
  assert quaternion.size == 4
32
31
  assert translation.size == 3
33
32
 
34
- Q_sixd = jaxlie.SO3.from_quaternion_xyzw(xyzw=Quaternion.to_xyzw(quaternion))
33
+ Q_sixd = jaxlie.SO3(wxyz=quaternion)
35
34
  Q_sixd = Q_sixd if not normalize_quaternion else Q_sixd.normalize()
36
35
 
37
36
  return Adjoint.from_rotation_and_translation(
@@ -84,14 +83,14 @@ class Adjoint:
84
83
  A_o_B = translation.squeeze()
85
84
 
86
85
  if not inverse:
87
- X = A_X_B = jnp.vstack(
86
+ X = A_X_B = jnp.vstack( # noqa: F841
88
87
  [
89
88
  jnp.block([A_R_B, Skew.wedge(A_o_B) @ A_R_B]),
90
89
  jnp.block([jnp.zeros(shape=(3, 3)), A_R_B]),
91
90
  ]
92
91
  )
93
92
  else:
94
- X = B_X_A = jnp.vstack(
93
+ X = B_X_A = jnp.vstack( # noqa: F841
95
94
  [
96
95
  jnp.block([A_R_B.T, -A_R_B.T @ Skew.wedge(A_o_B)]),
97
96
  jnp.block([jnp.zeros(shape=(3, 3)), A_R_B.T]),
@@ -254,8 +254,8 @@ def supported_joint_motion(
254
254
  # This is a metadata required by only some joint types.
255
255
  axis = jnp.array(joint_axis).astype(float).squeeze()
256
256
 
257
- pre_H_suc = jaxlie.SE3.from_rotation(
258
- rotation=jaxlie.SO3.from_matrix(Rotation.from_axis_angle(vector=s * axis))
257
+ pre_H_suc = jaxlie.SE3.from_matrix(
258
+ matrix=jnp.eye(4).at[:3, :3].set(Rotation.from_axis_angle(vector=s * axis))
259
259
  )
260
260
 
261
261
  S = jnp.vstack(jnp.hstack([jnp.zeros(3), axis]))
jaxsim/math/quaternion.py CHANGED
@@ -43,9 +43,7 @@ class Quaternion:
43
43
  Returns:
44
44
  jtp.Matrix: Direction cosine matrix (DCM).
45
45
  """
46
- return jaxlie.SO3.from_quaternion_xyzw(
47
- xyzw=Quaternion.to_xyzw(quaternion)
48
- ).as_matrix()
46
+ return jaxlie.SO3(wxyz=quaternion).as_matrix()
49
47
 
50
48
  @staticmethod
51
49
  def from_dcm(dcm: jtp.Matrix) -> jtp.Vector:
@@ -158,7 +156,7 @@ class Quaternion:
158
156
  A_Q_B = jnp.array(quaternion).squeeze().astype(float)
159
157
 
160
158
  # Build the initial SO(3) quaternion.
161
- W_Q_B_t0 = jaxlie.SO3.from_quaternion_xyzw(xyzw=Quaternion.to_xyzw(wxyz=A_Q_B))
159
+ W_Q_B_t0 = jaxlie.SO3(wxyz=A_Q_B)
162
160
 
163
161
  # Integrate the quaternion on the manifold.
164
162
  W_Q_B_tf = jax.lax.select(
jaxsim/math/transform.py CHANGED
@@ -3,8 +3,6 @@ import jaxlie
3
3
 
4
4
  import jaxsim.typing as jtp
5
5
 
6
- from .quaternion import Quaternion
7
-
8
6
 
9
7
  class Transform:
10
8
 
@@ -35,7 +33,7 @@ class Transform:
35
33
  assert W_p_B.size == 3
36
34
  assert W_Q_B.size == 4
37
35
 
38
- A_R_B = jaxlie.SO3.from_quaternion_xyzw(xyzw=Quaternion.to_xyzw(W_Q_B))
36
+ A_R_B = jaxlie.SO3(wxyz=W_Q_B)
39
37
  A_R_B = A_R_B if not normalize_quaternion else A_R_B.normalize()
40
38
 
41
39
  A_H_B = jaxlie.SE3.from_rotation_and_translation(
jaxsim/mujoco/loaders.py CHANGED
@@ -532,7 +532,12 @@ class UrdfToMjcf:
532
532
  model_name: str | None = None,
533
533
  plane_normal: tuple[float, float, float] = (0, 0, 1),
534
534
  heightmap: bool | None = None,
535
- cameras: list[dict[str, str]] | dict[str, str] | None = None,
535
+ cameras: (
536
+ MujocoCamera
537
+ | Sequence[MujocoCamera]
538
+ | dict[str, str]
539
+ | Sequence[dict[str, str]]
540
+ ) = (),
536
541
  ) -> tuple[str, dict[str, Any]]:
537
542
  """
538
543
  Converts a URDF file to a Mujoco MJCF string.
@@ -574,7 +579,12 @@ class SdfToMjcf:
574
579
  model_name: str | None = None,
575
580
  plane_normal: tuple[float, float, float] = (0, 0, 1),
576
581
  heightmap: bool | None = None,
577
- cameras: list[dict[str, str]] | dict[str, str] | None = None,
582
+ cameras: (
583
+ MujocoCamera
584
+ | Sequence[MujocoCamera]
585
+ | dict[str, str]
586
+ | Sequence[dict[str, str]]
587
+ ) = (),
578
588
  ) -> tuple[str, dict[str, Any]]:
579
589
  """
580
590
  Converts a SDF file to a Mujoco MJCF string.
jaxsim/mujoco/model.py CHANGED
@@ -238,9 +238,9 @@ class MujocoModelHelper:
238
238
  raise ValueError("The orientation is not a valid element of SO(3)")
239
239
 
240
240
  W_Q_B = (
241
- Rotation.from_matrix(orientation).as_quat(canonical=True)[
242
- np.array([3, 0, 1, 2])
243
- ]
241
+ Rotation.from_matrix(orientation).as_quat(
242
+ canonical=True, scalar_first=False
243
+ )
244
244
  if dcm
245
245
  else orientation
246
246
  )
@@ -394,8 +394,8 @@ class MujocoModelHelper:
394
394
  if dcm:
395
395
  return R
396
396
 
397
- q_xyzw = Rotation.from_matrix(R).as_quat(canonical=True)
398
- return q_xyzw[[3, 0, 1, 2]]
397
+ q_xyzw = Rotation.from_matrix(R).as_quat(canonical=True, scalar_first=False)
398
+ return q_xyzw
399
399
 
400
400
  # ===============
401
401
  # Private methods
jaxsim/rbda/aba.py CHANGED
@@ -4,7 +4,7 @@ import jaxlie
4
4
 
5
5
  import jaxsim.api as js
6
6
  import jaxsim.typing as jtp
7
- from jaxsim.math import Adjoint, Cross, Quaternion, StandardGravity
7
+ from jaxsim.math import Adjoint, Cross, StandardGravity
8
8
 
9
9
  from . import utils
10
10
 
@@ -77,7 +77,7 @@ def aba(
77
77
 
78
78
  # Compute the base transform.
79
79
  W_H_B = jaxlie.SE3.from_rotation_and_translation(
80
- rotation=jaxlie.SO3.from_quaternion_xyzw(xyzw=Quaternion.to_xyzw(wxyz=W_Q_B)),
80
+ rotation=jaxlie.SO3(wxyz=W_Q_B),
81
81
  translation=W_p_B,
82
82
  )
83
83
 
@@ -4,7 +4,7 @@ import jaxlie
4
4
 
5
5
  import jaxsim.api as js
6
6
  import jaxsim.typing as jtp
7
- from jaxsim.math import Adjoint, Quaternion, Skew
7
+ from jaxsim.math import Adjoint, Skew
8
8
 
9
9
  from . import utils
10
10
 
@@ -57,7 +57,7 @@ def collidable_points_pos_vel(
57
57
 
58
58
  # Compute the base transform.
59
59
  W_H_B = jaxlie.SE3.from_rotation_and_translation(
60
- rotation=jaxlie.SO3.from_quaternion_xyzw(xyzw=Quaternion.to_xyzw(wxyz=W_Q_B)),
60
+ rotation=jaxlie.SO3(wxyz=W_Q_B),
61
61
  translation=W_p_B,
62
62
  )
63
63
 
@@ -192,7 +192,7 @@ class SoftContacts(ContactModel):
192
192
 
193
193
  # Unpack the position of the collidable point.
194
194
  px, py, pz = W_p_C = position.squeeze()
195
- vx, vy, vz = W_ṗ_C = velocity.squeeze()
195
+ W_ṗ_C = velocity.squeeze()
196
196
 
197
197
  # Compute the terrain normal and the contact depth.
198
198
  n̂ = self.terrain.normal(x=px, y=py).squeeze()
jaxsim/rbda/crba.py CHANGED
@@ -59,10 +59,14 @@ def crba(model: js.model.JaxSimModel, *, joint_positions: jtp.Vector) -> jtp.Mat
59
59
 
60
60
  return (i_X_0,), None
61
61
 
62
- (i_X_0,), _ = jax.lax.scan(
63
- f=propagate_kinematics,
64
- init=forward_pass_carry,
65
- xs=jnp.arange(start=1, stop=model.number_of_links()),
62
+ (i_X_0,), _ = (
63
+ jax.lax.scan(
64
+ f=propagate_kinematics,
65
+ init=forward_pass_carry,
66
+ xs=jnp.arange(start=1, stop=model.number_of_links()),
67
+ )
68
+ if model.number_of_links() > 1
69
+ else [(i_X_0,), None]
66
70
  )
67
71
 
68
72
  # ===================
@@ -128,10 +132,14 @@ def crba(model: js.model.JaxSimModel, *, joint_positions: jtp.Vector) -> jtp.Mat
128
132
  operand=carry,
129
133
  )
130
134
 
131
- (j, Fi, M), _ = jax.lax.scan(
132
- f=inner_fn,
133
- init=carry_inner_fn,
134
- xs=jnp.flip(jnp.arange(start=1, stop=model.number_of_links())),
135
+ (j, Fi, M), _ = (
136
+ jax.lax.scan(
137
+ f=inner_fn,
138
+ init=carry_inner_fn,
139
+ xs=jnp.flip(jnp.arange(start=1, stop=model.number_of_links())),
140
+ )
141
+ if model.number_of_links() > 1
142
+ else [(j, Fi, M), None]
135
143
  )
136
144
 
137
145
  Fi = i_X_0[j].T @ Fi
@@ -143,10 +151,14 @@ def crba(model: js.model.JaxSimModel, *, joint_positions: jtp.Vector) -> jtp.Mat
143
151
 
144
152
  # This scan performs the backward pass to compute Mbj, Mjb and Mjj, that
145
153
  # also includes a fake while loop implemented with a scan and two cond.
146
- (Mc, M), _ = jax.lax.scan(
147
- f=backward_pass,
148
- init=backward_pass_carry,
149
- xs=jnp.flip(jnp.arange(start=1, stop=model.number_of_links())),
154
+ (Mc, M), _ = (
155
+ jax.lax.scan(
156
+ f=backward_pass,
157
+ init=backward_pass_carry,
158
+ xs=jnp.flip(jnp.arange(start=1, stop=model.number_of_links())),
159
+ )
160
+ if model.number_of_links() > 1
161
+ else [(Mc, M), None]
150
162
  )
151
163
 
152
164
  # Store the locked 6D rigid-body inertia matrix Mbb ∈ ℝ⁶ˣ⁶.
@@ -4,7 +4,7 @@ import jaxlie
4
4
 
5
5
  import jaxsim.api as js
6
6
  import jaxsim.typing as jtp
7
- from jaxsim.math import Adjoint, Quaternion
7
+ from jaxsim.math import Adjoint
8
8
 
9
9
  from . import utils
10
10
 
@@ -42,7 +42,7 @@ def forward_kinematics_model(
42
42
 
43
43
  # Compute the base transform.
44
44
  W_H_B = jaxlie.SE3.from_rotation_and_translation(
45
- rotation=jaxlie.SO3.from_quaternion_xyzw(xyzw=Quaternion.to_xyzw(wxyz=W_Q_B)),
45
+ rotation=jaxlie.SO3(wxyz=W_Q_B),
46
46
  translation=W_p_B,
47
47
  )
48
48
 
@@ -75,10 +75,14 @@ def forward_kinematics_model(
75
75
 
76
76
  return (W_X_i,), None
77
77
 
78
- (W_X_i,), _ = jax.lax.scan(
79
- f=propagate_kinematics,
80
- init=propagate_kinematics_carry,
81
- xs=jnp.arange(start=1, stop=model.number_of_links()),
78
+ (W_X_i,), _ = (
79
+ jax.lax.scan(
80
+ f=propagate_kinematics,
81
+ init=propagate_kinematics_carry,
82
+ xs=jnp.arange(start=1, stop=model.number_of_links()),
83
+ )
84
+ if model.number_of_links() > 1
85
+ else [(W_X_i,), None]
82
86
  )
83
87
 
84
88
  return jax.vmap(Adjoint.to_transform)(W_X_i)