jaxsim 0.3.1.dev64__py3-none-any.whl → 0.3.1.dev94__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (47)
  1. jaxsim/__init__.py +5 -5
  2. jaxsim/_version.py +2 -2
  3. jaxsim/api/com.py +3 -4
  4. jaxsim/api/common.py +11 -11
  5. jaxsim/api/contact.py +11 -3
  6. jaxsim/api/data.py +3 -6
  7. jaxsim/api/frame.py +9 -10
  8. jaxsim/api/kin_dyn_parameters.py +25 -28
  9. jaxsim/api/link.py +12 -12
  10. jaxsim/api/model.py +47 -43
  11. jaxsim/api/ode.py +19 -12
  12. jaxsim/api/ode_data.py +11 -11
  13. jaxsim/integrators/common.py +17 -20
  14. jaxsim/integrators/fixed_step.py +10 -10
  15. jaxsim/integrators/variable_step.py +13 -13
  16. jaxsim/math/__init__.py +2 -1
  17. jaxsim/math/joint_model.py +2 -1
  18. jaxsim/math/quaternion.py +3 -9
  19. jaxsim/math/transform.py +2 -2
  20. jaxsim/mujoco/loaders.py +5 -5
  21. jaxsim/mujoco/model.py +6 -6
  22. jaxsim/mujoco/visualizer.py +3 -0
  23. jaxsim/parsers/__init__.py +0 -1
  24. jaxsim/parsers/descriptions/joint.py +1 -1
  25. jaxsim/parsers/descriptions/link.py +3 -4
  26. jaxsim/parsers/descriptions/model.py +1 -1
  27. jaxsim/parsers/kinematic_graph.py +38 -39
  28. jaxsim/parsers/rod/parser.py +14 -14
  29. jaxsim/parsers/rod/utils.py +9 -11
  30. jaxsim/rbda/aba.py +6 -12
  31. jaxsim/rbda/collidable_points.py +8 -7
  32. jaxsim/rbda/contacts/soft.py +29 -27
  33. jaxsim/rbda/crba.py +3 -3
  34. jaxsim/rbda/forward_kinematics.py +1 -1
  35. jaxsim/rbda/jacobian.py +8 -8
  36. jaxsim/rbda/rnea.py +3 -3
  37. jaxsim/rbda/utils.py +1 -1
  38. jaxsim/terrain/terrain.py +100 -22
  39. jaxsim/typing.py +14 -22
  40. jaxsim/utils/jaxsim_dataclass.py +4 -4
  41. jaxsim/utils/wrappers.py +5 -1
  42. {jaxsim-0.3.1.dev64.dist-info → jaxsim-0.3.1.dev94.dist-info}/METADATA +1 -1
  43. jaxsim-0.3.1.dev94.dist-info/RECORD +68 -0
  44. jaxsim-0.3.1.dev64.dist-info/RECORD +0 -68
  45. {jaxsim-0.3.1.dev64.dist-info → jaxsim-0.3.1.dev94.dist-info}/LICENSE +0 -0
  46. {jaxsim-0.3.1.dev64.dist-info → jaxsim-0.3.1.dev94.dist-info}/WHEEL +0 -0
  47. {jaxsim-0.3.1.dev64.dist-info → jaxsim-0.3.1.dev94.dist-info}/top_level.txt +0 -0
jaxsim/__init__.py CHANGED
@@ -33,17 +33,17 @@ def _is_editable() -> bool:
33
33
  import pathlib
34
34
  import site
35
35
 
36
- # Get the ModuleSpec of jaxsim
36
+ # Get the ModuleSpec of jaxsim.
37
37
  jaxsim_spec = importlib.util.find_spec(name="jaxsim")
38
38
 
39
39
  # This can be None. If it's None, assume non-editable installation.
40
40
  if jaxsim_spec.origin is None:
41
41
  return False
42
42
 
43
- # Get the folder containing the jaxsim package
43
+ # Get the folder containing the jaxsim package.
44
44
  jaxsim_package_dir = str(pathlib.Path(jaxsim_spec.origin).parent.parent)
45
45
 
46
- # The installation is editable if the package dir is not in any {site|dist}-packages
46
+ # The installation is editable if the package dir is not in any {site|dist}-packages.
47
47
  return jaxsim_package_dir not in site.getsitepackages()
48
48
 
49
49
 
@@ -82,10 +82,10 @@ def _get_default_logging_level(env_var: str) -> logging.LoggingLevel:
82
82
  logging.configure(level=_get_default_logging_level(env_var="JAXSIM_LOGGING_LEVEL"))
83
83
 
84
84
 
85
- # Configure JAX
85
+ # Configure JAX.
86
86
  _jnp_options()
87
87
 
88
- # Initialize the numpy print options
88
+ # Initialize the numpy print options.
89
89
  _np_options()
90
90
 
91
91
  del _jnp_options
jaxsim/_version.py CHANGED
@@ -12,5 +12,5 @@ __version__: str
12
12
  __version_tuple__: VERSION_TUPLE
13
13
  version_tuple: VERSION_TUPLE
14
14
 
15
- __version__ = version = '0.3.1.dev64'
16
- __version_tuple__ = version_tuple = (0, 3, 1, 'dev64')
15
+ __version__ = version = '0.3.1.dev94'
16
+ __version_tuple__ = version_tuple = (0, 3, 1, 'dev94')
jaxsim/api/com.py CHANGED
@@ -1,6 +1,5 @@
1
1
  import jax
2
2
  import jax.numpy as jnp
3
- import jaxlie
4
3
 
5
4
  import jaxsim.api as js
6
5
  import jaxsim.math
@@ -28,7 +27,7 @@ def com_position(
28
27
 
29
28
  W_H_L = js.model.forward_kinematics(model=model, data=data)
30
29
  W_H_B = data.base_transform()
31
- B_H_W = jaxlie.SE3.from_matrix(W_H_B).inverse().as_matrix()
30
+ B_H_W = jaxsim.math.Transform.inverse(transform=W_H_B)
32
31
 
33
32
  def B_p̃_LCoM(i) -> jtp.Vector:
34
33
  m = js.link.mass(model=model, link_index=i)
@@ -179,9 +178,9 @@ def locked_centroidal_spatial_inertia(
179
178
  case _:
180
179
  raise ValueError(data.velocity_representation)
181
180
 
182
- B_H_G = jaxlie.SE3.from_matrix(jaxsim.math.Transform.inverse(W_H_B) @ W_H_G)
181
+ B_H_G = jaxsim.math.Transform.inverse(W_H_B) @ W_H_G
183
182
 
184
- B_Xv_G = B_H_G.adjoint()
183
+ B_Xv_G = jaxsim.math.Adjoint.from_transform(transform=B_H_G)
185
184
  G_Xf_B = B_Xv_G.transpose()
186
185
 
187
186
  return G_Xf_B @ B_Mbb_B @ B_Xv_G
jaxsim/api/common.py CHANGED
@@ -8,10 +8,10 @@ from typing import ContextManager
8
8
  import jax
9
9
  import jax.numpy as jnp
10
10
  import jax_dataclasses
11
- import jaxlie
12
11
  from jax_dataclasses import Static
13
12
 
14
13
  import jaxsim.typing as jtp
14
+ from jaxsim.math import Adjoint
15
15
  from jaxsim.utils import JaxsimDataclass, Mutability
16
16
 
17
17
  try:
@@ -59,7 +59,7 @@ class ModelDataWithVelocityRepresentation(JaxsimDataclass, abc.ABC):
59
59
 
60
60
  try:
61
61
 
62
- # First, we replace the velocity representation
62
+ # First, we replace the velocity representation.
63
63
  with self.mutable_context(
64
64
  mutability=Mutability.MUTABLE_NO_VALIDATION,
65
65
  restore_after_exception=True,
@@ -97,7 +97,7 @@ class ModelDataWithVelocityRepresentation(JaxsimDataclass, abc.ABC):
97
97
  array: The 6D quantity to convert.
98
98
  other_representation: The representation to convert to.
99
99
  transform:
100
- The `math:W \mathbf{H}_O` transform, where `math:O` is the
100
+ The :math:`W \mathbf{H}_O` transform, where :math:`O` is the
101
101
  reference frame of the other representation.
102
102
  is_force: Whether the quantity is a 6D force or a 6D velocity.
103
103
 
@@ -122,11 +122,11 @@ class ModelDataWithVelocityRepresentation(JaxsimDataclass, abc.ABC):
122
122
  case VelRepr.Body:
123
123
 
124
124
  if not is_force:
125
- O_Xv_W = jaxlie.SE3.from_matrix(W_H_O).inverse().adjoint()
125
+ O_Xv_W = Adjoint.from_transform(transform=W_H_O, inverse=True)
126
126
  O_array = O_Xv_W @ W_array
127
127
 
128
128
  else:
129
- O_Xf_W = jaxlie.SE3.from_matrix(W_H_O).adjoint().T
129
+ O_Xf_W = Adjoint.from_transform(transform=W_H_O).T
130
130
  O_array = O_Xf_W @ W_array
131
131
 
132
132
  return O_array
@@ -136,11 +136,11 @@ class ModelDataWithVelocityRepresentation(JaxsimDataclass, abc.ABC):
136
136
  W_H_OW = jnp.eye(4).at[0:3, 3].set(W_p_O)
137
137
 
138
138
  if not is_force:
139
- OW_Xv_W = jaxlie.SE3.from_matrix(W_H_OW).inverse().adjoint()
139
+ OW_Xv_W = Adjoint.from_transform(transform=W_H_OW, inverse=True)
140
140
  OW_array = OW_Xv_W @ W_array
141
141
 
142
142
  else:
143
- OW_Xf_W = jaxlie.SE3.from_matrix(W_H_OW).adjoint().transpose()
143
+ OW_Xf_W = Adjoint.from_transform(transform=W_H_OW).T
144
144
  OW_array = OW_Xf_W @ W_array
145
145
 
146
146
  return OW_array
@@ -190,11 +190,11 @@ class ModelDataWithVelocityRepresentation(JaxsimDataclass, abc.ABC):
190
190
  O_array = array
191
191
 
192
192
  if not is_force:
193
- W_Xv_O: jtp.Array = jaxlie.SE3.from_matrix(W_H_O).adjoint()
193
+ W_Xv_O: jtp.Array = Adjoint.from_transform(W_H_O)
194
194
  W_array = W_Xv_O @ O_array
195
195
 
196
196
  else:
197
- W_Xf_O = jaxlie.SE3.from_matrix(W_H_O).inverse().adjoint().T
197
+ W_Xf_O = Adjoint.from_transform(transform=W_H_O, inverse=True).T
198
198
  W_array = W_Xf_O @ O_array
199
199
 
200
200
  return W_array
@@ -205,11 +205,11 @@ class ModelDataWithVelocityRepresentation(JaxsimDataclass, abc.ABC):
205
205
  W_H_OW = jnp.eye(4).at[0:3, 3].set(W_p_O)
206
206
 
207
207
  if not is_force:
208
- W_Xv_BW: jtp.Array = jaxlie.SE3.from_matrix(W_H_OW).adjoint()
208
+ W_Xv_BW: jtp.Array = Adjoint.from_transform(W_H_OW)
209
209
  W_array = W_Xv_BW @ BW_array
210
210
 
211
211
  else:
212
- W_Xf_BW = jaxlie.SE3.from_matrix(W_H_OW).inverse().adjoint().T
212
+ W_Xf_BW = Adjoint.from_transform(transform=W_H_OW, inverse=True).T
213
213
  W_array = W_Xf_BW @ BW_array
214
214
 
215
215
  return W_array
jaxsim/api/contact.py CHANGED
@@ -8,7 +8,7 @@ import jax.numpy as jnp
8
8
  import jaxsim.api as js
9
9
  import jaxsim.terrain
10
10
  import jaxsim.typing as jtp
11
- from jaxsim.rbda.contacts.soft import SoftContacts, SoftContactsParams
11
+ from jaxsim.rbda.contacts.soft import SoftContactsParams
12
12
 
13
13
  from .common import VelRepr
14
14
 
@@ -137,9 +137,17 @@ def collidable_point_dynamics(
137
137
  # all collidable points belonging to the robot.
138
138
  W_p_Ci, W_ṗ_Ci = js.contact.collidable_point_kinematics(model=model, data=data)
139
139
 
140
+ # Import privately the soft contacts classes.
141
+ from jaxsim.rbda.contacts.soft import SoftContacts, SoftContactsState
142
+
140
143
  # Build the soft contact model.
141
144
  match model.contact_model:
142
- case s if isinstance(s, SoftContacts):
145
+
146
+ case SoftContacts():
147
+
148
+ assert isinstance(model.contact_model, SoftContacts)
149
+ assert isinstance(data.state.contact, SoftContactsState)
150
+
143
151
  # Build the contact model.
144
152
  soft_contacts = SoftContacts(
145
153
  parameters=data.contacts_params, terrain=model.terrain
@@ -337,7 +345,7 @@ def jacobian(
337
345
  The output velocity representation of the free-floating jacobian.
338
346
 
339
347
  Returns:
340
- The stacked 6×(6+n) free-floating jacobians of the frames associated to the
348
+ The stacked :math:`6 \times (6+n)` free-floating jacobians of the frames associated to the
341
349
  collidable points.
342
350
 
343
351
  Note:
jaxsim/api/data.py CHANGED
@@ -8,7 +8,6 @@ import jax
8
8
  import jax.numpy as jnp
9
9
  import jax_dataclasses
10
10
  import jaxlie
11
- import numpy as np
12
11
 
13
12
  import jaxsim.api as js
14
13
  import jaxsim.rbda
@@ -390,7 +389,7 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
390
389
  ).astype(float)
391
390
 
392
391
  @jax.jit
393
- def base_transform(self) -> jtp.MatrixJax:
392
+ def base_transform(self) -> jtp.Matrix:
394
393
  """
395
394
  Get the base transform.
396
395
 
@@ -625,9 +624,7 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
625
624
 
626
625
  W_p_B = base_pose[0:3, 3]
627
626
 
628
- to_wxyz = np.array([3, 0, 1, 2])
629
- W_R_B: jaxlie.SO3 = jaxlie.SO3.from_matrix(base_pose[0:3, 0:3]) # noqa
630
- W_Q_B = W_R_B.as_quaternion_xyzw()[to_wxyz]
627
+ W_Q_B = jaxsim.math.Quaternion.from_dcm(dcm=base_pose[0:3, 0:3])
631
628
 
632
629
  return self.reset_base_position(base_position=W_p_B).reset_base_quaternion(
633
630
  base_quaternion=W_Q_B
@@ -815,7 +812,7 @@ def random_model_data(
815
812
 
816
813
  physics_model_state.base_quaternion = jaxlie.SO3.from_rpy_radians(
817
814
  *jax.random.uniform(key=k2, shape=(3,), minval=0, maxval=2 * jnp.pi)
818
- ).as_quaternion_xyzw()[np.array([3, 0, 1, 2])]
815
+ ).wxyz
819
816
 
820
817
  if model.number_of_joints() > 0:
821
818
  physics_model_state.joint_positions = js.joint.random_joint_positions(
jaxsim/api/frame.py CHANGED
@@ -3,12 +3,11 @@ from typing import Sequence
3
3
 
4
4
  import jax
5
5
  import jax.numpy as jnp
6
- import jaxlie
7
6
 
8
7
  import jaxsim.api as js
9
- import jaxsim.math
10
8
  import jaxsim.typing as jtp
11
9
  from jaxsim import exceptions
10
+ from jaxsim.math import Adjoint, Transform
12
11
 
13
12
  from .common import VelRepr
14
13
 
@@ -189,7 +188,7 @@ def jacobian(
189
188
  frame_index: jtp.IntLike,
190
189
  output_vel_repr: VelRepr | None = None,
191
190
  ) -> jtp.Matrix:
192
- """
191
+ r"""
193
192
  Compute the free-floating jacobian of the frame.
194
193
 
195
194
  Args:
@@ -200,7 +199,7 @@ def jacobian(
200
199
  The output velocity representation of the free-floating jacobian.
201
200
 
202
201
  Returns:
203
- The 6×(6+n) free-floating jacobian of the frame.
202
+ The :math:`6 \times (6+n)` free-floating jacobian of the frame.
204
203
 
205
204
  Note:
206
205
  The input representation of the free-floating jacobian is the active
@@ -228,29 +227,29 @@ def jacobian(
228
227
  model=model, data=data, link_index=L, output_vel_repr=VelRepr.Body
229
228
  )
230
229
 
231
- # Adjust the output representation
230
+ # Adjust the output representation.
232
231
  match output_vel_repr:
233
232
  case VelRepr.Inertial:
234
233
  W_H_L = js.link.transform(model=model, data=data, link_index=L)
235
- W_X_L = jaxlie.SE3.from_matrix(W_H_L).adjoint()
234
+ W_X_L = Adjoint.from_transform(transform=W_H_L)
236
235
  W_J_WL = W_X_L @ L_J_WL
237
236
  O_J_WL_I = W_J_WL
238
237
 
239
238
  case VelRepr.Body:
240
239
  W_H_L = js.link.transform(model=model, data=data, link_index=L)
241
240
  W_H_F = transform(model=model, data=data, frame_index=frame_index)
242
- F_H_L = jaxsim.math.Transform.inverse(W_H_F) @ W_H_L
243
- F_X_L = jaxlie.SE3.from_matrix(F_H_L).adjoint()
241
+ F_H_L = Transform.inverse(W_H_F) @ W_H_L
242
+ F_X_L = Adjoint.from_transform(transform=F_H_L)
244
243
  F_J_WL = F_X_L @ L_J_WL
245
244
  O_J_WL_I = F_J_WL
246
245
 
247
246
  case VelRepr.Mixed:
248
247
  W_H_L = js.link.transform(model=model, data=data, link_index=L)
249
248
  W_H_F = transform(model=model, data=data, frame_index=frame_index)
250
- F_H_L = jaxsim.math.Transform.inverse(W_H_F) @ W_H_L
249
+ F_H_L = Transform.inverse(W_H_F) @ W_H_L
251
250
  FW_H_F = W_H_F.at[0:3, 3].set(jnp.zeros(3))
252
251
  FW_H_L = FW_H_F @ F_H_L
253
- FW_X_L = jaxlie.SE3.from_matrix(FW_H_L).adjoint()
252
+ FW_X_L = Adjoint.from_transform(transform=FW_H_L)
254
253
  FW_J_WL = FW_X_L @ L_J_WL
255
254
  O_J_WL_I = FW_J_WL
256
255
 
jaxsim/api/kin_dyn_parameters.py CHANGED
@@ -5,11 +5,10 @@ import dataclasses
5
5
  import jax.lax
6
6
  import jax.numpy as jnp
7
7
  import jax_dataclasses
8
- import jaxlie
9
8
  from jax_dataclasses import Static
10
9
 
11
10
  import jaxsim.typing as jtp
12
- from jaxsim.math import Inertia, JointModel, supported_joint_motion
11
+ from jaxsim.math import Adjoint, Inertia, JointModel, supported_joint_motion
13
12
  from jaxsim.parsers.descriptions import JointDescription, ModelDescription
14
13
  from jaxsim.utils import HashedNumpyArray, JaxsimDataclass
15
14
 
@@ -168,7 +167,7 @@ class KynDynParameters(JaxsimDataclass):
168
167
  for link in ordered_links
169
168
  if link.parent is not None
170
169
  }
171
- parent_array = jnp.array([-1] + list(parent_array_dict.values()), dtype=int)
170
+ parent_array = jnp.array([-1, *list(parent_array_dict.values())], dtype=int)
172
171
 
173
172
  # Instead of building the support parent array κ(i) for each link of the model,
174
173
  # that has a variable length depending on the number of links connecting the
@@ -432,11 +431,9 @@ class KynDynParameters(JaxsimDataclass):
432
431
  # Compute the overall transforms from the parent to the child of each joint by
433
432
  # composing all the components of our joint model.
434
433
  i_X_λ = jax.vmap(
435
- lambda λ_Hi_pre, pre_Hi_suc, suc_Hi_i: jaxlie.SE3.from_matrix(
436
- λ_Hi_pre @ pre_Hi_suc @ suc_Hi_i
434
+ lambda λ_Hi_pre, pre_Hi_suc, suc_Hi_i: Adjoint.from_transform(
435
+ transform=λ_Hi_pre @ pre_Hi_suc @ suc_Hi_i, inverse=True
437
436
  )
438
- .inverse()
439
- .adjoint()
440
437
  )(λ_H_pre, pre_H_suc, suc_H_i)
441
438
 
442
439
  return i_X_λ, S
@@ -466,12 +463,12 @@ class KynDynParameters(JaxsimDataclass):
466
463
  def set_link_inertia(
467
464
  self, link_index: int, inertia: jtp.MatrixLike
468
465
  ) -> KynDynParameters:
469
- """
466
+ r"""
470
467
  Set the inertia tensor of a link.
471
468
 
472
469
  Args:
473
470
  link_index: The index of the link.
474
- inertia: The 3×3 inertia tensor of the link.
471
+ inertia: The :math:`3 \times 3` inertia tensor of the link.
475
472
 
476
473
  Returns:
477
474
  The updated kinematic and dynamic parameters of the model.
@@ -569,7 +566,7 @@ class LinkParameters(JaxsimDataclass):
569
566
  index: The index of the link.
570
567
  mass: The mass of the link.
571
568
  inertia_elements:
572
- The unique elements of the 3×3 inertia tensor of the link.
569
+ The unique elements of the :math:`3 \times 3` inertia tensor of the link.
573
570
  center_of_mass:
574
571
  The translation :math:`{}^L \mathbf{p}_{\text{CoM}}` between the origin
575
572
  of the link frame and the link's center of mass, expressed in the
@@ -588,12 +585,12 @@ class LinkParameters(JaxsimDataclass):
588
585
 
589
586
  @staticmethod
590
587
  def build_from_spatial_inertia(index: jtp.IntLike, M: jtp.Matrix) -> LinkParameters:
591
- """
592
- Build a LinkParameters object from a 6×6 spatial inertia matrix.
588
+ r"""
589
+ Build a LinkParameters object from a :math:`6 \times 6` spatial inertia matrix.
593
590
 
594
591
  Args:
595
592
  index: The index of the link.
596
- M: The 6×6 spatial inertia matrix of the link.
593
+ M: The :math:`6 \times 6` spatial inertia matrix of the link.
597
594
 
598
595
  Returns:
599
596
  The LinkParameters object.
@@ -616,13 +613,13 @@ class LinkParameters(JaxsimDataclass):
616
613
  def build_from_inertial_parameters(
617
614
  index: jtp.IntLike, m: jtp.FloatLike, I: jtp.MatrixLike, c: jtp.VectorLike
618
615
  ) -> LinkParameters:
619
- """
616
+ r"""
620
617
  Build a LinkParameters object from the inertial parameters of a link.
621
618
 
622
619
  Args:
623
620
  index: The index of the link.
624
621
  m: The mass of the link.
625
- I: The 3×3 inertia tensor of the link.
622
+ I: The :math:`3 \times 3` inertia tensor of the link.
626
623
  c: The translation between the link frame and the link's center of mass.
627
624
 
628
625
  Returns:
@@ -676,14 +673,14 @@ class LinkParameters(JaxsimDataclass):
676
673
 
677
674
  @staticmethod
678
675
  def inertia_tensor(params: LinkParameters) -> jtp.Matrix:
679
- """
680
- Return the 3×3 inertia tensor of a link.
676
+ r"""
677
+ Return the :math:`3 \times 3` inertia tensor of a link.
681
678
 
682
679
  Args:
683
680
  params: The link parameters.
684
681
 
685
682
  Returns:
686
- The 3×3 inertia tensor of the link.
683
+ The :math:`3 \times 3` inertia tensor of the link.
687
684
  """
688
685
 
689
686
  return LinkParameters.unflatten_inertia_tensor(
@@ -692,14 +689,14 @@ class LinkParameters(JaxsimDataclass):
692
689
 
693
690
  @staticmethod
694
691
  def spatial_inertia(params: LinkParameters) -> jtp.Matrix:
695
- """
696
- Return the 6×6 spatial inertia matrix of a link.
692
+ r"""
693
+ Return the :math:`6 \times 6` spatial inertia matrix of a link.
697
694
 
698
695
  Args:
699
696
  params: The link parameters.
700
697
 
701
698
  Returns:
702
- The 6×6 spatial inertia matrix of the link.
699
+ The :math:`6 \times 6` spatial inertia matrix of the link.
703
700
  """
704
701
 
705
702
  return Inertia.to_sixd(
@@ -710,11 +707,11 @@ class LinkParameters(JaxsimDataclass):
710
707
 
711
708
  @staticmethod
712
709
  def flatten_inertia_tensor(I: jtp.Matrix) -> jtp.Vector:
713
- """
714
- Flatten a 3×3 inertia tensor into a vector of unique elements.
710
+ r"""
711
+ Flatten a :math:`3 \times 3` inertia tensor into a vector of unique elements.
715
712
 
716
713
  Args:
717
- I: The 3×3 inertia tensor.
714
+ I: The :math:`3 \times 3` inertia tensor.
718
715
 
719
716
  Returns:
720
717
  The vector of unique elements of the inertia tensor.
@@ -724,14 +721,14 @@ class LinkParameters(JaxsimDataclass):
724
721
 
725
722
  @staticmethod
726
723
  def unflatten_inertia_tensor(inertia_elements: jtp.Vector) -> jtp.Matrix:
727
- """
728
- Unflatten a vector of unique elements into a 3×3 inertia tensor.
724
+ r"""
725
+ Unflatten a vector of unique elements into a :math:`3 \times 3` inertia tensor.
729
726
 
730
727
  Args:
731
728
  inertia_elements: The vector of unique elements of the inertia tensor.
732
729
 
733
730
  Returns:
734
- The 3×3 inertia tensor.
731
+ The :math:`3 \times 3` inertia tensor.
735
732
  """
736
733
 
737
734
  I = jnp.zeros([3, 3]).at[jnp.triu_indices(3)].set(inertia_elements.squeeze())
@@ -792,7 +789,7 @@ class ContactParameters(JaxsimDataclass):
792
789
  )
793
790
 
794
791
  # Build the ContactParameters object.
795
- cp = ContactParameters(point=points, body=link_index_of_points) # noqa
792
+ cp = ContactParameters(point=points, body=link_index_of_points)
796
793
 
797
794
  assert cp.point.shape[1] == 3, cp.point.shape[1]
798
795
  assert cp.point.shape[0] == len(cp.body), cp.point.shape[0]
jaxsim/api/link.py CHANGED
@@ -4,12 +4,12 @@ from typing import Sequence
4
4
  import jax
5
5
  import jax.numpy as jnp
6
6
  import jax.scipy.linalg
7
- import jaxlie
8
7
 
9
8
  import jaxsim.api as js
10
9
  import jaxsim.rbda
11
10
  import jaxsim.typing as jtp
12
11
  from jaxsim import exceptions
12
+ from jaxsim.math import Adjoint
13
13
 
14
14
  from .common import VelRepr
15
15
 
@@ -134,7 +134,7 @@ def mass(model: js.model.JaxSimModel, *, link_index: jtp.IntLike) -> jtp.Float:
134
134
  def spatial_inertia(
135
135
  model: js.model.JaxSimModel, *, link_index: jtp.IntLike
136
136
  ) -> jtp.Matrix:
137
- """
137
+ r"""
138
138
  Compute the 6D spatial inertial of the link.
139
139
 
140
140
  Args:
@@ -142,7 +142,7 @@ def spatial_inertia(
142
142
  link_index: The index of the link.
143
143
 
144
144
  Returns:
145
- The 6×6 matrix representing the spatial inertia of the link expressed in
145
+ The :math:`6 \times 6` matrix representing the spatial inertia of the link expressed in
146
146
  the link frame (body-fixed representation).
147
147
  """
148
148
 
@@ -243,7 +243,7 @@ def jacobian(
243
243
  link_index: jtp.IntLike,
244
244
  output_vel_repr: VelRepr | None = None,
245
245
  ) -> jtp.Matrix:
246
- """
246
+ r"""
247
247
  Compute the free-floating jacobian of the link.
248
248
 
249
249
  Args:
@@ -254,7 +254,7 @@ def jacobian(
254
254
  The output velocity representation of the free-floating jacobian.
255
255
 
256
256
  Returns:
257
- The 6×(6+n) free-floating jacobian of the link.
257
+ The :math:`6 \times (6+n)` free-floating jacobian of the link.
258
258
 
259
259
  Note:
260
260
  The input representation of the free-floating jacobian is the active
@@ -287,7 +287,7 @@ def jacobian(
287
287
  match data.velocity_representation:
288
288
  case VelRepr.Inertial:
289
289
  W_H_B = data.base_transform()
290
- B_X_W = jaxlie.SE3.from_matrix(W_H_B).inverse().adjoint()
290
+ B_X_W = Adjoint.from_transform(transform=W_H_B, inverse=True)
291
291
  B_J_WL_I = B_J_WL_W = B_J_WL_B @ jax.scipy.linalg.block_diag(
292
292
  B_X_W, jnp.eye(model.dofs())
293
293
  )
@@ -298,7 +298,7 @@ def jacobian(
298
298
  case VelRepr.Mixed:
299
299
  W_R_B = data.base_orientation(dcm=True)
300
300
  BW_H_B = jnp.eye(4).at[0:3, 0:3].set(W_R_B)
301
- B_X_BW = jaxlie.SE3.from_matrix(BW_H_B).inverse().adjoint()
301
+ B_X_BW = Adjoint.from_transform(transform=BW_H_B, inverse=True)
302
302
  B_J_WL_I = B_J_WL_BW = B_J_WL_B @ jax.scipy.linalg.block_diag(
303
303
  B_X_BW, jnp.eye(model.dofs())
304
304
  )
@@ -312,11 +312,11 @@ def jacobian(
312
312
  match output_vel_repr:
313
313
  case VelRepr.Inertial:
314
314
  W_H_B = data.base_transform()
315
- W_X_B = jaxlie.SE3.from_matrix(W_H_B).adjoint()
315
+ W_X_B = Adjoint.from_transform(transform=W_H_B)
316
316
  O_J_WL_I = W_J_WL_I = W_X_B @ B_J_WL_I
317
317
 
318
318
  case VelRepr.Body:
319
- L_X_B = jaxlie.SE3.from_matrix(B_H_L).inverse().adjoint()
319
+ L_X_B = Adjoint.from_transform(transform=B_H_L, inverse=True)
320
320
  L_J_WL_I = L_X_B @ B_J_WL_I
321
321
  O_J_WL_I = L_J_WL_I
322
322
 
@@ -325,7 +325,7 @@ def jacobian(
325
325
  W_H_L = W_H_B @ B_H_L
326
326
  LW_H_L = W_H_L.at[0:3, 3].set(jnp.zeros(3))
327
327
  LW_H_B = LW_H_L @ jaxsim.math.Transform.inverse(B_H_L)
328
- LW_X_B = jaxlie.SE3.from_matrix(LW_H_B).adjoint()
328
+ LW_X_B = Adjoint.from_transform(transform=LW_H_B)
329
329
  LW_J_WL_I = LW_X_B @ B_J_WL_I
330
330
  O_J_WL_I = LW_J_WL_I
331
331
 
@@ -393,7 +393,7 @@ def jacobian_derivative(
393
393
  link_index: jtp.IntLike,
394
394
  output_vel_repr: VelRepr | None = None,
395
395
  ) -> jtp.Matrix:
396
- """
396
+ r"""
397
397
  Compute the derivative of the free-floating jacobian of the link.
398
398
 
399
399
  Args:
@@ -404,7 +404,7 @@ def jacobian_derivative(
404
404
  The output velocity representation of the free-floating jacobian derivative.
405
405
 
406
406
  Returns:
407
- The derivative of the 6×(6+n) free-floating jacobian of the link.
407
+ The derivative of the :math:`6 \times (6+n)` free-floating jacobian of the link.
408
408
 
409
409
  Note:
410
410
  The input representation of the free-floating jacobian derivative is the active