jaxsim 0.2.dev191__py3-none-any.whl → 0.6.1.dev5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (109) hide show
  1. jaxsim/__init__.py +73 -22
  2. jaxsim/_version.py +2 -2
  3. jaxsim/api/__init__.py +13 -1
  4. jaxsim/api/com.py +423 -0
  5. jaxsim/api/common.py +48 -19
  6. jaxsim/api/contact.py +604 -52
  7. jaxsim/api/data.py +308 -163
  8. jaxsim/api/frame.py +471 -0
  9. jaxsim/api/joint.py +166 -37
  10. jaxsim/api/kin_dyn_parameters.py +901 -0
  11. jaxsim/api/link.py +277 -78
  12. jaxsim/api/model.py +1572 -362
  13. jaxsim/api/ode.py +324 -133
  14. jaxsim/api/ode_data.py +401 -0
  15. jaxsim/api/references.py +216 -80
  16. jaxsim/exceptions.py +80 -0
  17. jaxsim/integrators/__init__.py +2 -2
  18. jaxsim/integrators/common.py +191 -107
  19. jaxsim/integrators/fixed_step.py +97 -102
  20. jaxsim/integrators/variable_step.py +706 -0
  21. jaxsim/logging.py +1 -2
  22. jaxsim/math/__init__.py +13 -0
  23. jaxsim/math/adjoint.py +64 -30
  24. jaxsim/math/cross.py +18 -9
  25. jaxsim/math/inertia.py +11 -9
  26. jaxsim/math/joint_model.py +289 -0
  27. jaxsim/math/quaternion.py +59 -25
  28. jaxsim/math/rotation.py +30 -24
  29. jaxsim/math/skew.py +18 -7
  30. jaxsim/math/transform.py +102 -0
  31. jaxsim/math/utils.py +31 -0
  32. jaxsim/mujoco/__init__.py +2 -1
  33. jaxsim/mujoco/loaders.py +216 -29
  34. jaxsim/mujoco/model.py +163 -33
  35. jaxsim/mujoco/utils.py +228 -0
  36. jaxsim/mujoco/visualizer.py +107 -22
  37. jaxsim/parsers/__init__.py +0 -1
  38. jaxsim/parsers/descriptions/__init__.py +8 -2
  39. jaxsim/parsers/descriptions/collision.py +83 -26
  40. jaxsim/parsers/descriptions/joint.py +80 -87
  41. jaxsim/parsers/descriptions/link.py +58 -31
  42. jaxsim/parsers/descriptions/model.py +101 -68
  43. jaxsim/parsers/kinematic_graph.py +606 -229
  44. jaxsim/parsers/rod/meshes.py +104 -0
  45. jaxsim/parsers/rod/parser.py +125 -82
  46. jaxsim/parsers/rod/utils.py +127 -82
  47. jaxsim/rbda/__init__.py +11 -0
  48. jaxsim/rbda/aba.py +289 -0
  49. jaxsim/rbda/collidable_points.py +156 -0
  50. jaxsim/rbda/contacts/__init__.py +13 -0
  51. jaxsim/rbda/contacts/common.py +313 -0
  52. jaxsim/rbda/contacts/relaxed_rigid.py +605 -0
  53. jaxsim/rbda/contacts/rigid.py +462 -0
  54. jaxsim/rbda/contacts/soft.py +480 -0
  55. jaxsim/rbda/contacts/visco_elastic.py +1066 -0
  56. jaxsim/rbda/crba.py +167 -0
  57. jaxsim/rbda/forward_kinematics.py +117 -0
  58. jaxsim/rbda/jacobian.py +330 -0
  59. jaxsim/rbda/rnea.py +235 -0
  60. jaxsim/rbda/utils.py +160 -0
  61. jaxsim/terrain/__init__.py +2 -0
  62. jaxsim/terrain/terrain.py +238 -0
  63. jaxsim/typing.py +24 -24
  64. jaxsim/utils/__init__.py +1 -4
  65. jaxsim/utils/jaxsim_dataclass.py +289 -34
  66. jaxsim/utils/tracing.py +5 -11
  67. jaxsim/utils/wrappers.py +159 -0
  68. {jaxsim-0.2.dev191.dist-info → jaxsim-0.6.1.dev5.dist-info}/LICENSE +1 -1
  69. jaxsim-0.6.1.dev5.dist-info/METADATA +465 -0
  70. jaxsim-0.6.1.dev5.dist-info/RECORD +74 -0
  71. {jaxsim-0.2.dev191.dist-info → jaxsim-0.6.1.dev5.dist-info}/WHEEL +1 -1
  72. jaxsim/high_level/__init__.py +0 -2
  73. jaxsim/high_level/common.py +0 -11
  74. jaxsim/high_level/joint.py +0 -148
  75. jaxsim/high_level/link.py +0 -259
  76. jaxsim/high_level/model.py +0 -1686
  77. jaxsim/math/conv.py +0 -114
  78. jaxsim/math/joint.py +0 -102
  79. jaxsim/math/plucker.py +0 -100
  80. jaxsim/physics/__init__.py +0 -12
  81. jaxsim/physics/algos/__init__.py +0 -0
  82. jaxsim/physics/algos/aba.py +0 -254
  83. jaxsim/physics/algos/aba_motors.py +0 -284
  84. jaxsim/physics/algos/crba.py +0 -154
  85. jaxsim/physics/algos/forward_kinematics.py +0 -79
  86. jaxsim/physics/algos/jacobian.py +0 -98
  87. jaxsim/physics/algos/rnea.py +0 -180
  88. jaxsim/physics/algos/rnea_motors.py +0 -196
  89. jaxsim/physics/algos/soft_contacts.py +0 -523
  90. jaxsim/physics/algos/terrain.py +0 -78
  91. jaxsim/physics/algos/utils.py +0 -69
  92. jaxsim/physics/model/__init__.py +0 -0
  93. jaxsim/physics/model/ground_contact.py +0 -53
  94. jaxsim/physics/model/physics_model.py +0 -388
  95. jaxsim/physics/model/physics_model_state.py +0 -283
  96. jaxsim/simulation/__init__.py +0 -4
  97. jaxsim/simulation/integrators.py +0 -393
  98. jaxsim/simulation/ode.py +0 -290
  99. jaxsim/simulation/ode_data.py +0 -96
  100. jaxsim/simulation/ode_integration.py +0 -62
  101. jaxsim/simulation/simulator.py +0 -543
  102. jaxsim/simulation/simulator_callbacks.py +0 -79
  103. jaxsim/simulation/utils.py +0 -15
  104. jaxsim/sixd/__init__.py +0 -2
  105. jaxsim/utils/oop.py +0 -536
  106. jaxsim/utils/vmappable.py +0 -117
  107. jaxsim-0.2.dev191.dist-info/METADATA +0 -184
  108. jaxsim-0.2.dev191.dist-info/RECORD +0 -81
  109. {jaxsim-0.2.dev191.dist-info → jaxsim-0.6.1.dev5.dist-info}/top_level.txt +0 -0
jaxsim/logging.py CHANGED
@@ -1,6 +1,5 @@
1
1
  import enum
2
2
  import logging
3
- from typing import Union
4
3
 
5
4
  import coloredlogs
6
5
 
@@ -20,7 +19,7 @@ def _logger() -> logging.Logger:
20
19
  return logging.getLogger(name=LOGGER_NAME)
21
20
 
22
21
 
23
- def set_logging_level(level: Union[int, LoggingLevel] = LoggingLevel.WARNING):
22
+ def set_logging_level(level: int | LoggingLevel = LoggingLevel.WARNING):
24
23
  if isinstance(level, int):
25
24
  level = LoggingLevel(level)
26
25
 
jaxsim/math/__init__.py CHANGED
@@ -0,0 +1,13 @@
1
+ # Define the default standard gravity constant.
2
+ StandardGravity = 9.81
3
+
4
+ from .adjoint import Adjoint
5
+ from .cross import Cross
6
+ from .inertia import Inertia
7
+ from .quaternion import Quaternion
8
+ from .rotation import Rotation
9
+ from .skew import Skew
10
+ from .transform import Transform
11
+ from .utils import safe_norm
12
+
13
+ from .joint_model import JointModel, supported_joint_motion # isort:skip
jaxsim/math/adjoint.py CHANGED
@@ -1,17 +1,20 @@
1
1
  import jax.numpy as jnp
2
+ import jaxlie
2
3
 
3
4
  import jaxsim.typing as jtp
4
- from jaxsim.sixd import so3
5
5
 
6
- from .quaternion import Quaternion
7
6
  from .skew import Skew
8
7
 
9
8
 
10
9
  class Adjoint:
10
+ """
11
+ A utility class for adjoint matrix operations.
12
+ """
13
+
11
14
  @staticmethod
12
15
  def from_quaternion_and_translation(
13
- quaternion: jtp.Vector = jnp.array([1.0, 0, 0, 0]),
14
- translation: jtp.Vector = jnp.zeros(3),
16
+ quaternion: jtp.Vector | None = None,
17
+ translation: jtp.Vector | None = None,
15
18
  inverse: bool = False,
16
19
  normalize_quaternion: bool = False,
17
20
  ) -> jtp.Matrix:
@@ -19,42 +22,67 @@ class Adjoint:
19
22
  Create an adjoint matrix from a quaternion and a translation.
20
23
 
21
24
  Args:
22
- quaternion (jtp.Vector): A quaternion vector (4D) representing orientation.
23
- translation (jtp.Vector): A translation vector (3D).
24
- inverse (bool): Whether to compute the inverse adjoint. Default is False.
25
- normalize_quaternion (bool): Whether to normalize the quaternion before creating the adjoint.
26
- Default is False.
25
+ quaternion: A quaternion vector (4D) representing orientation.
26
+ translation: A translation vector (3D).
27
+ inverse: Whether to compute the inverse adjoint.
28
+ normalize_quaternion: Whether to normalize the quaternion before creating the adjoint.
27
29
 
28
30
  Returns:
29
- jtp.Matrix: The adjoint matrix.
31
+ The adjoint matrix.
30
32
  """
33
+ quaternion = quaternion if quaternion is not None else jnp.array([1.0, 0, 0, 0])
34
+ translation = translation if translation is not None else jnp.zeros(3)
31
35
  assert quaternion.size == 4
32
36
  assert translation.size == 3
33
37
 
34
- Q_sixd = so3.SO3.from_quaternion_xyzw(xyzw=Quaternion.to_xyzw(quaternion))
38
+ Q_sixd = jaxlie.SO3(wxyz=quaternion)
35
39
  Q_sixd = Q_sixd if not normalize_quaternion else Q_sixd.normalize()
36
40
 
37
41
  return Adjoint.from_rotation_and_translation(
38
42
  rotation=Q_sixd.as_matrix(), translation=translation, inverse=inverse
39
43
  )
40
44
 
45
@staticmethod
def from_transform(transform: jtp.MatrixLike, inverse: bool = False) -> jtp.Matrix:
    """
    Build the adjoint matrix associated with a homogeneous transformation.

    Args:
        transform:
            A 4x4 transformation matrix. Any leading dimensions are flattened
            into a single batch axis for the conversion and restored afterwards.
        inverse: Whether to return the adjoint of the inverse transformation.

    Returns:
        The 6x6 adjoint matrix, with shape ``transform.shape[:-2] + (6, 6)``.
    """

    # Collapse all the leading dimensions into a single batch axis.
    stacked_transforms = jnp.reshape(transform, (-1, 4, 4))

    pose = jaxlie.SE3.from_matrix(matrix=stacked_transforms)

    if inverse:
        pose = pose.inverse()

    # Give the result the same leading dimensions as the input.
    leading_shape = transform.shape[:-2]

    return pose.adjoint().reshape(leading_shape + (6, 6))
65
+
41
66
  @staticmethod
42
67
  def from_rotation_and_translation(
43
- rotation: jtp.Matrix = jnp.eye(3),
44
- translation: jtp.Vector = jnp.zeros(3),
68
+ rotation: jtp.Matrix | None = None,
69
+ translation: jtp.Vector | None = None,
45
70
  inverse: bool = False,
46
71
  ) -> jtp.Matrix:
47
72
  """
48
73
  Create an adjoint matrix from a rotation matrix and a translation vector.
49
74
 
50
75
  Args:
51
- rotation (jtp.Matrix): A 3x3 rotation matrix.
52
- translation (jtp.Vector): A translation vector (3D).
53
- inverse (bool): Whether to compute the inverse adjoint. Default is False.
76
+ rotation: A 3x3 rotation matrix.
77
+ translation: A translation vector (3D).
78
+ inverse: Whether to compute the inverse adjoint. Default is False.
54
79
 
55
80
  Returns:
56
- jtp.Matrix: The adjoint matrix.
81
+ The adjoint matrix.
57
82
  """
83
+ rotation = rotation if rotation is not None else jnp.eye(3)
84
+ translation = translation if translation is not None else jnp.zeros(3)
85
+
58
86
  assert rotation.shape == (3, 3)
59
87
  assert translation.size == 3
60
88
 
@@ -62,14 +90,14 @@ class Adjoint:
62
90
  A_o_B = translation.squeeze()
63
91
 
64
92
  if not inverse:
65
- X = A_X_B = jnp.vstack(
93
+ X = A_X_B = jnp.vstack( # noqa: F841
66
94
  [
67
95
  jnp.block([A_R_B, Skew.wedge(A_o_B) @ A_R_B]),
68
96
  jnp.block([jnp.zeros(shape=(3, 3)), A_R_B]),
69
97
  ]
70
98
  )
71
99
  else:
72
- X = B_X_A = jnp.vstack(
100
+ X = B_X_A = jnp.vstack( # noqa: F841
73
101
  [
74
102
  jnp.block([A_R_B.T, -A_R_B.T @ Skew.wedge(A_o_B)]),
75
103
  jnp.block([jnp.zeros(shape=(3, 3)), A_R_B.T]),
@@ -84,10 +112,10 @@ class Adjoint:
84
112
  Convert an adjoint matrix to a transformation matrix.
85
113
 
86
114
  Args:
87
- adjoint (jtp.Matrix): The adjoint matrix (6x6).
115
+ adjoint: The adjoint matrix (6x6).
88
116
 
89
117
  Returns:
90
- jtp.Matrix: The transformation matrix (4x4).
118
+ The transformation matrix (4x4).
91
119
  """
92
120
  X = adjoint.squeeze()
93
121
  assert X.shape == (6, 6)
@@ -110,17 +138,23 @@ class Adjoint:
110
138
  Compute the inverse of an adjoint matrix.
111
139
 
112
140
  Args:
113
- adjoint (jtp.Matrix): The adjoint matrix.
141
+ adjoint: The adjoint matrix.
114
142
 
115
143
  Returns:
116
- jtp.Matrix: The inverse adjoint matrix.
144
+ The inverse adjoint matrix.
117
145
  """
118
- A_X_B = adjoint
119
- A_H_B = Adjoint.to_transform(adjoint=A_X_B)
146
+ A_X_B = adjoint.reshape(-1, 6, 6)
120
147
 
121
- A_R_B = A_H_B[0:3, 0:3]
122
- A_o_B = A_H_B[0:3, 3]
148
+ A_R_B_T = jnp.swapaxes(A_X_B[..., 0:3, 0:3], -2, -1)
149
+ A_T_B = A_X_B[..., 0:3, 3:6]
123
150
 
124
- return Adjoint.from_rotation_and_translation(
125
- rotation=A_R_B, translation=A_o_B, inverse=True
126
- )
151
+ return jnp.concatenate(
152
+ [
153
+ jnp.concatenate(
154
+ [A_R_B_T, -A_R_B_T @ A_T_B @ A_R_B_T],
155
+ axis=-1,
156
+ ),
157
+ jnp.concatenate([jnp.zeros_like(A_R_B_T), A_R_B_T], axis=-1),
158
+ ],
159
+ axis=-2,
160
+ ).reshape(adjoint.shape)
jaxsim/math/cross.py CHANGED
@@ -6,27 +6,36 @@ from .skew import Skew
6
6
 
7
7
 
8
8
  class Cross:
9
+ """
10
+ A utility class for cross product matrix operations.
11
+ """
12
+
9
13
  @staticmethod
10
14
  def vx(velocity_sixd: jtp.Vector) -> jtp.Matrix:
11
15
  """
12
16
  Compute the cross product matrix for 6D velocities.
13
17
 
14
18
  Args:
15
- velocity_sixd (jtp.Vector): A 6D velocity vector [v, ω].
19
+ velocity_sixd: A 6D velocity vector [v, ω].
16
20
 
17
21
  Returns:
18
- jtp.Matrix: The cross product matrix (6x6).
22
+ The cross product matrix (6x6).
19
23
 
20
24
  Raises:
21
25
  ValueError: If the input vector does not have a size of 6.
22
26
  """
23
- v, ω = jnp.split(velocity_sixd.squeeze(), 2)
27
+ velocity_sixd = velocity_sixd.reshape(-1, 6)
28
+
29
+ v, ω = jnp.split(velocity_sixd, 2, axis=-1)
24
30
 
25
- v_cross = jnp.vstack(
31
+ v_cross = jnp.concatenate(
26
32
  [
27
- jnp.block([Skew.wedge(vector=ω), Skew.wedge(vector=v)]),
28
- jnp.block([jnp.zeros(shape=(3, 3)), Skew.wedge(vector=ω)]),
29
- ]
33
+ jnp.concatenate(
34
+ [Skew.wedge(ω), jnp.zeros((ω.shape[0], 3, 3)).squeeze()], axis=-2
35
+ ),
36
+ jnp.concatenate([Skew.wedge(v), Skew.wedge(ω)], axis=-2),
37
+ ],
38
+ axis=-1,
30
39
  )
31
40
 
32
41
  return v_cross
@@ -37,10 +46,10 @@ class Cross:
37
46
  Compute the negative transpose of the cross product matrix for 6D velocities.
38
47
 
39
48
  Args:
40
- velocity_sixd (jtp.Vector): A 6D velocity vector [v, ω].
49
+ velocity_sixd: A 6D velocity vector [v, ω].
41
50
 
42
51
  Returns:
43
- jtp.Matrix: The negative transpose of the cross product matrix (6x6).
52
+ The negative transpose of the cross product matrix (6x6).
44
53
 
45
54
  Raises:
46
55
  ValueError: If the input vector does not have a size of 6.
jaxsim/math/inertia.py CHANGED
@@ -1,5 +1,3 @@
1
- from typing import Tuple
2
-
3
1
  import jax.numpy as jnp
4
2
 
5
3
  import jaxsim.typing as jtp
@@ -8,18 +6,22 @@ from .skew import Skew
8
6
 
9
7
 
10
8
  class Inertia:
9
+ """
10
+ A utility class for inertia matrix operations.
11
+ """
12
+
11
13
  @staticmethod
12
14
  def to_sixd(mass: jtp.Float, com: jtp.Vector, I: jtp.Matrix) -> jtp.Matrix:
13
15
  """
14
16
  Convert mass, center of mass, and inertia matrix to a 6x6 inertia matrix.
15
17
 
16
18
  Args:
17
- mass (jtp.Float): The mass of the body.
18
- com (jtp.Vector): The center of mass position (3D).
19
- I (jtp.Matrix): The 3x3 inertia matrix.
19
+ mass: The mass of the body.
20
+ com: The center of mass position (3D).
21
+ I: The 3x3 inertia matrix.
20
22
 
21
23
  Returns:
22
- jtp.Matrix: The 6x6 inertia matrix.
24
+ The 6x6 inertia matrix.
23
25
 
24
26
  Raises:
25
27
  ValueError: If the shape of the inertia matrix I is not (3, 3).
@@ -39,15 +41,15 @@ class Inertia:
39
41
  return M
40
42
 
41
43
  @staticmethod
42
- def to_params(M: jtp.Matrix) -> Tuple[jtp.Float, jtp.Vector, jtp.Matrix]:
44
+ def to_params(M: jtp.Matrix) -> tuple[jtp.Float, jtp.Vector, jtp.Matrix]:
43
45
  """
44
46
  Convert a 6x6 inertia matrix to mass, center of mass, and inertia matrix.
45
47
 
46
48
  Args:
47
- M (jtp.Matrix): The 6x6 inertia matrix.
49
+ M: The 6x6 inertia matrix.
48
50
 
49
51
  Returns:
50
- Tuple[jtp.Float, jtp.Vector, jtp.Matrix]: A tuple containing mass, center of mass (3D), and inertia matrix (3x3).
52
+ A tuple containing mass, center of mass (3D), and inertia matrix (3x3).
51
53
 
52
54
  Raises:
53
55
  ValueError: If the input matrix M has an unexpected shape.
jaxsim/math/joint_model.py ADDED
@@ -0,0 +1,289 @@
1
+ from __future__ import annotations
2
+
3
+ import jax
4
+ import jax.numpy as jnp
5
+ import jax_dataclasses
6
+ import jaxlie
7
+ from jax_dataclasses import Static
8
+
9
+ import jaxsim.typing as jtp
10
+ from jaxsim.parsers.descriptions import JointGenericAxis, JointType, ModelDescription
11
+ from jaxsim.parsers.kinematic_graph import KinematicGraphTransforms
12
+
13
+ from .rotation import Rotation
14
+ from .transform import Transform
15
+
16
+
17
@jax_dataclasses.pytree_dataclass
class JointModel:
    """
    Kinematic description of the joints of a robot model.

    Attributes:
        λ_H_pre:
            The stacked homogeneous transformations between the parent link
            and the predecessor frame of each joint.
        suc_H_i:
            The stacked homogeneous transformations between the successor
            frame and the child link of each joint.
        joint_dofs: The number of DoFs of each joint.
        joint_names: The names of each joint.
        joint_types: The types of each joint.
        joint_axis: The axis of each joint of the model description.

    Note:
        Because of the static attributes, instances must be created directly in
        vectorized form. In other words, they cannot be created using vmap.
    """

    λ_H_pre: jtp.Array
    suc_H_i: jtp.Array

    joint_dofs: Static[tuple[int, ...]]
    joint_names: Static[tuple[str, ...]]
    joint_types: Static[tuple[int, ...]]
    joint_axis: Static[tuple[JointGenericAxis, ...]]

    @staticmethod
    def build(description: ModelDescription) -> JointModel:
        """
        Create the joint model corresponding to a model description.

        Args:
            description: The model description to consider.

        Returns:
            The joint model of the considered model description.
        """

        # Links sorted by index. The link index is equal to its body index,
        # i.e. it belongs to [0, number_of_bodies - 1].
        links = sorted(description.links_dict.values(), key=lambda link: link.index)

        # Joints sorted by index. The joint index is equal to the index of its
        # child link, therefore it starts from 1.
        joints = sorted(
            description.joints_dict.values(), key=lambda joint: joint.index
        )

        # Allocate the parent-to-predecessor and successor-to-child transforms.
        # The extra entry at index 0 is the joint between world W and base B.
        stack_shape = (1 + len(joints), 4, 4)
        λ_H_pre = jnp.zeros(shape=stack_shape, dtype=float)
        suc_H_i = jnp.zeros(shape=stack_shape, dtype=float)

        # The parent-to-predecessor transform of the world-to-base joint is
        # the identity.
        λ_H_pre = λ_H_pre.at[0].set(jnp.eye(4))

        # The successor-to-child transform of the world-to-base joint stores the
        # optional transform between the root frame of the model and the base
        # link frame (needed only if the pose of the link frame w.r.t. the
        # implicit __model__ SDF frame is not the identity).
        suc_H_i = suc_H_i.at[0].set(links[0].pose)

        # Object used to compute the relative transforms between frames.
        fk = KinematicGraphTransforms(graph=description)

        # Fill the transforms of each joint belonging to the model.
        # Joint indices start from i=1, so the entries at index 0 set above
        # are left untouched.
        for joint in joints:
            parent_H_joint = fk.relative_transform(
                relative_to=joint.parent.name, name=joint.name
            )
            λ_H_pre = λ_H_pre.at[joint.index].set(parent_H_joint)

            joint_H_child = fk.relative_transform(
                relative_to=joint.name, name=joint.child.name
            )
            suc_H_i = suc_H_i.at[joint.index].set(joint_H_child)

        # DoFs of the base link.
        base_dofs = 0 if description.fixed_base else 6

        # A dummy fixed joint between world and base is always prepended to the
        # dofs, names, and types. Note that joint_axis has no such leading entry.
        # TODO: Port floating-base support also at this level, not only in RBDAs.
        return JointModel(
            λ_H_pre=λ_H_pre,
            suc_H_i=suc_H_i,
            # Static attributes
            joint_dofs=(base_dofs, *(1 for _ in joints)),
            joint_names=("world_to_base", *(joint.name for joint in joints)),
            joint_types=(JointType.Fixed, *(joint.jtype for joint in joints)),
            joint_axis=tuple(JointGenericAxis(axis=joint.axis) for joint in joints),
        )

    def parent_H_child(
        self, joint_index: jtp.IntLike, joint_position: jtp.VectorLike
    ) -> tuple[jtp.Matrix, jtp.Array]:
        r"""
        Compute the homogeneous transformation between the parent link and
        the child link of a joint, together with its motion subspace.

        Args:
            joint_index: The index of the joint.
            joint_position: The position of the joint.

        Returns:
            A tuple containing the homogeneous transformation
            :math:`{}^{\lambda(i)} \mathbf{H}_i(s)`
            and the motion subspace :math:`\mathbf{S}(s)`.
        """

        # Fixed transform: parent link -> predecessor frame.
        parent_H_pre = self.parent_H_predecessor(joint_index=joint_index)

        # Position-dependent transform: predecessor frame -> successor frame.
        pre_H_suc, motion_subspace = self.predecessor_H_successor(
            joint_index=joint_index, joint_position=joint_position
        )

        # Fixed transform: successor frame -> child link.
        suc_H_child = self.successor_H_child(joint_index=joint_index)

        # Compose all the transforms.
        parent_H_child = parent_H_pre @ pre_H_suc @ suc_H_child

        return parent_H_child, motion_subspace

    @jax.jit
    def child_H_parent(
        self, joint_index: jtp.IntLike, joint_position: jtp.VectorLike
    ) -> tuple[jtp.Matrix, jtp.Array]:
        r"""
        Compute the homogeneous transformation between the child link and
        the parent link of a joint, together with its motion subspace.

        Args:
            joint_index: The index of the joint.
            joint_position: The position of the joint.

        Returns:
            A tuple containing the homogeneous transformation
            :math:`{}^{i} \mathbf{H}_{\lambda(i)}(s)`
            and the motion subspace :math:`\mathbf{S}(s)`.
        """

        parent_H_child, motion_subspace = self.parent_H_child(
            joint_index=joint_index, joint_position=joint_position
        )

        # The requested transform is the inverse of the parent-to-child one.
        return Transform.inverse(parent_H_child), motion_subspace

    def parent_H_predecessor(self, joint_index: jtp.IntLike) -> jtp.Matrix:
        r"""
        Return the homogeneous transformation between the parent link and
        the predecessor frame of a joint.

        Args:
            joint_index: The index of the joint.

        Returns:
            The homogeneous transformation
            :math:`{}^{\lambda(i)} \mathbf{H}_{\text{pre}(i)}`.
        """

        return self.λ_H_pre[joint_index]

    def predecessor_H_successor(
        self, joint_index: jtp.IntLike, joint_position: jtp.VectorLike
    ) -> tuple[jtp.Matrix, jtp.Array]:
        r"""
        Compute the homogeneous transformation between the predecessor and
        the successor frame of a joint, together with its motion subspace.

        Args:
            joint_index: The index of the joint.
            joint_position: The position of the joint.

        Returns:
            A tuple containing the homogeneous transformation
            :math:`{}^{\text{pre}(i)} \mathbf{H}_{\text{suc}(i)}(s)`
            and the motion subspace :math:`\mathbf{S}(s)`.
        """

        # NOTE(review): `build` prepends a world-to-base entry to `joint_types`
        # but not to `joint_axis`, so the same index addresses entries that are
        # offset by one between the two tuples -- confirm the intended indexing.
        joint_type = self.joint_types[joint_index]
        joint_axis = self.joint_axis[joint_index].axis

        return supported_joint_motion(joint_type, joint_position, joint_axis)

    def successor_H_child(self, joint_index: jtp.IntLike) -> jtp.Matrix:
        r"""
        Return the homogeneous transformation between the successor frame and
        the child link of a joint.

        Args:
            joint_index: The index of the joint.

        Returns:
            The homogeneous transformation
            :math:`{}^{\text{suc}(i)} \mathbf{H}_i`.
        """

        return self.suc_H_i[joint_index]
224
+
225
+
226
@jax.jit
def supported_joint_motion(
    joint_type: jtp.IntLike,
    joint_position: jtp.VectorLike,
    joint_axis: jtp.VectorLike | None = None,
    /,
) -> tuple[jtp.Matrix, jtp.Array]:
    """
    Compute the homogeneous transformation and motion subspace of a joint.

    The joint type is used as the index of a `jax.lax.switch` over the
    supported joint kinds, ordered as: fixed, revolute, prismatic.

    Args:
        joint_type: The type of the joint.
        joint_position: The position of the joint.
        joint_axis: The optional 3D axis of rotation or translation of the joint.

    Returns:
        A tuple containing the homogeneous transformation and the motion subspace.
    """

    # Joint position as a floating-point array.
    position = jnp.array(joint_position).astype(float)

    def _axis() -> jtp.Array:
        # The joint axis is metadata required only by some joint types, hence
        # it is converted inside the branches that use it.
        # NOTE(review): the `None` default is converted here as-is -- confirm
        # that callers always pass an axis.
        return jnp.array(joint_axis).astype(float).squeeze()

    def _fixed() -> tuple[jtp.Matrix, jtp.Array]:
        # Identity transform and zero motion subspace.
        return jaxlie.SE3.identity(), jnp.zeros(shape=(6, 1))

    def _revolute() -> tuple[jtp.Matrix, jtp.Array]:
        axis = _axis()

        # Homogeneous matrix with the rotation about the axis and no translation.
        rotation = Rotation.from_axis_angle(vector=position * axis)
        homogeneous = jnp.eye(4).at[:3, :3].set(rotation)
        transform = jaxlie.SE3.from_matrix(matrix=homogeneous)

        # Motion subspace with the axis in the last three components.
        subspace = jnp.vstack(jnp.hstack([jnp.zeros(3), axis]))

        return transform, subspace

    def _prismatic() -> tuple[jtp.Matrix, jtp.Array]:
        axis = _axis()

        # Pure translation along the axis.
        transform = jaxlie.SE3.from_rotation_and_translation(
            rotation=jaxlie.SO3.identity(),
            translation=jnp.array(position * axis),
        )

        # Motion subspace with the axis in the first three components.
        subspace = jnp.vstack(jnp.hstack([axis, jnp.zeros(3)]))

        return transform, subspace

    transform, subspace = jax.lax.switch(
        index=joint_type,
        branches=(
            _fixed,  # JointType.Fixed
            _revolute,  # JointType.Revolute
            _prismatic,  # JointType.Prismatic
        ),
    )

    return transform.as_matrix(), subspace