jaxsim 0.2.dev191__py3-none-any.whl → 0.2.dev364__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (79)
  1. jaxsim/__init__.py +3 -4
  2. jaxsim/_version.py +2 -2
  3. jaxsim/api/__init__.py +3 -1
  4. jaxsim/api/com.py +240 -0
  5. jaxsim/api/common.py +13 -2
  6. jaxsim/api/contact.py +120 -43
  7. jaxsim/api/data.py +112 -71
  8. jaxsim/api/joint.py +77 -36
  9. jaxsim/api/kin_dyn_parameters.py +777 -0
  10. jaxsim/api/link.py +150 -75
  11. jaxsim/api/model.py +542 -269
  12. jaxsim/api/ode.py +86 -74
  13. jaxsim/api/ode_data.py +694 -0
  14. jaxsim/api/references.py +12 -11
  15. jaxsim/integrators/__init__.py +2 -2
  16. jaxsim/integrators/common.py +110 -24
  17. jaxsim/integrators/fixed_step.py +11 -67
  18. jaxsim/integrators/variable_step.py +610 -0
  19. jaxsim/math/__init__.py +11 -0
  20. jaxsim/math/adjoint.py +24 -2
  21. jaxsim/math/joint_model.py +335 -0
  22. jaxsim/math/quaternion.py +44 -3
  23. jaxsim/math/rotation.py +4 -4
  24. jaxsim/math/transform.py +93 -0
  25. jaxsim/parsers/descriptions/link.py +2 -2
  26. jaxsim/parsers/rod/utils.py +7 -8
  27. jaxsim/rbda/__init__.py +7 -0
  28. jaxsim/rbda/aba.py +295 -0
  29. jaxsim/rbda/collidable_points.py +142 -0
  30. jaxsim/{physics/algos → rbda}/crba.py +43 -42
  31. jaxsim/rbda/forward_kinematics.py +113 -0
  32. jaxsim/rbda/jacobian.py +201 -0
  33. jaxsim/rbda/rnea.py +237 -0
  34. jaxsim/rbda/soft_contacts.py +296 -0
  35. jaxsim/rbda/utils.py +152 -0
  36. jaxsim/terrain/__init__.py +2 -0
  37. jaxsim/utils/__init__.py +1 -4
  38. jaxsim/utils/hashless.py +18 -0
  39. jaxsim/utils/jaxsim_dataclass.py +281 -30
  40. {jaxsim-0.2.dev191.dist-info → jaxsim-0.2.dev364.dist-info}/METADATA +4 -6
  41. jaxsim-0.2.dev364.dist-info/RECORD +64 -0
  42. {jaxsim-0.2.dev191.dist-info → jaxsim-0.2.dev364.dist-info}/WHEEL +1 -1
  43. jaxsim/high_level/__init__.py +0 -2
  44. jaxsim/high_level/common.py +0 -11
  45. jaxsim/high_level/joint.py +0 -148
  46. jaxsim/high_level/link.py +0 -259
  47. jaxsim/high_level/model.py +0 -1686
  48. jaxsim/math/conv.py +0 -114
  49. jaxsim/math/joint.py +0 -102
  50. jaxsim/math/plucker.py +0 -100
  51. jaxsim/physics/__init__.py +0 -12
  52. jaxsim/physics/algos/__init__.py +0 -0
  53. jaxsim/physics/algos/aba.py +0 -254
  54. jaxsim/physics/algos/aba_motors.py +0 -284
  55. jaxsim/physics/algos/forward_kinematics.py +0 -79
  56. jaxsim/physics/algos/jacobian.py +0 -98
  57. jaxsim/physics/algos/rnea.py +0 -180
  58. jaxsim/physics/algos/rnea_motors.py +0 -196
  59. jaxsim/physics/algos/soft_contacts.py +0 -523
  60. jaxsim/physics/algos/utils.py +0 -69
  61. jaxsim/physics/model/__init__.py +0 -0
  62. jaxsim/physics/model/ground_contact.py +0 -53
  63. jaxsim/physics/model/physics_model.py +0 -388
  64. jaxsim/physics/model/physics_model_state.py +0 -283
  65. jaxsim/simulation/__init__.py +0 -4
  66. jaxsim/simulation/integrators.py +0 -393
  67. jaxsim/simulation/ode.py +0 -290
  68. jaxsim/simulation/ode_data.py +0 -96
  69. jaxsim/simulation/ode_integration.py +0 -62
  70. jaxsim/simulation/simulator.py +0 -543
  71. jaxsim/simulation/simulator_callbacks.py +0 -79
  72. jaxsim/simulation/utils.py +0 -15
  73. jaxsim/sixd/__init__.py +0 -2
  74. jaxsim/utils/oop.py +0 -536
  75. jaxsim/utils/vmappable.py +0 -117
  76. jaxsim-0.2.dev191.dist-info/RECORD +0 -81
  77. /jaxsim/{physics/algos → terrain}/terrain.py +0 -0
  78. {jaxsim-0.2.dev191.dist-info → jaxsim-0.2.dev364.dist-info}/LICENSE +0 -0
  79. {jaxsim-0.2.dev191.dist-info → jaxsim-0.2.dev364.dist-info}/top_level.txt +0 -0
jaxsim/math/joint_model.py ADDED
@@ -0,0 +1,335 @@
1
+ from __future__ import annotations
2
+
3
+ import functools
4
+
5
+ import jax
6
+ import jax.numpy as jnp
7
+ import jax_dataclasses
8
+ import jaxlie
9
+ from jax_dataclasses import Static
10
+
11
+ import jaxsim.typing as jtp
12
+ from jaxsim.parsers.descriptions import (
13
+ JointDescriptor,
14
+ JointGenericAxis,
15
+ JointType,
16
+ ModelDescription,
17
+ )
18
+
19
+ from .rotation import Rotation
20
+
21
+
22
@jax_dataclasses.pytree_dataclass
class JointModel:
    """
    Class describing the joint kinematics of a robot model.

    Attributes:
        λ_H_pre:
            The homogeneous transformation between the parent link and
            the predecessor frame of each joint.
        suc_H_i:
            The homogeneous transformation between the successor frame and
            the child link of each joint.
        joint_dofs: The number of DoFs of each joint.
        joint_names: The names of each joint.
        joint_types: The types of each joint.

    Note:
        Index 0 of every attribute describes the artificial joint connecting the
        world frame to the base link ("world_to_base"); the joints of the model
        start from index 1.

    Note:
        Due to the presence of the static attributes, this class needs to be created
        already in a vectorized form. In other words, it cannot be created using vmap.

    Note:
        The methods that depend on ``joint_types`` (a static Python tuple) need a
        concrete ``joint_index`` (a Python int or a concrete array), because the
        index is used to select the joint type at trace time.
    """

    λ_H_pre: jax.Array
    suc_H_i: jax.Array

    joint_dofs: Static[tuple[int, ...]]
    joint_names: Static[tuple[str, ...]]
    joint_types: Static[tuple[JointType | JointDescriptor, ...]]

    @staticmethod
    def build(description: ModelDescription) -> JointModel:
        """
        Build the joint model of a model description.

        Args:
            description: The model description to consider.

        Returns:
            The joint model of the considered model description.
        """

        # The link index is equal to its body index: [0, number_of_bodies - 1].
        ordered_links = sorted(
            description.links_dict.values(),
            key=lambda link: link.index,
        )

        # Note: the joint index is equal to its child link index, therefore it
        # starts from 1.
        ordered_joints = sorted(
            description.joints_dict.values(),
            key=lambda joint: joint.index,
        )

        # Allocate the parent-to-predecessor and successor-to-child transforms.
        # The extra entry at index 0 is reserved to the world-to-base joint.
        λ_H_pre = jnp.zeros(shape=(1 + len(ordered_joints), 4, 4), dtype=float)
        suc_H_i = jnp.zeros(shape=(1 + len(ordered_joints), 4, 4), dtype=float)

        # Initialize an identity parent-to-predecessor transform for the joint
        # between the world frame W and the base link B.
        λ_H_pre = λ_H_pre.at[0].set(jnp.eye(4))

        # Initialize the successor-to-child transform of the joint between the
        # world frame W and the base link B.
        # We store here the optional transform between the root frame of the model
        # and the base link frame (this is needed only if the pose of the link frame
        # w.r.t. the implicit __model__ SDF frame is not the identity).
        suc_H_i = suc_H_i.at[0].set(ordered_links[0].pose)

        # Compute the parent-to-predecessor and successor-to-child transforms for
        # each joint belonging to the model.
        # Note that the joint indices start from i=1 given our joint model,
        # therefore the entries at index 0 are not updated.
        for joint in ordered_joints:
            λ_H_pre = λ_H_pre.at[joint.index].set(
                description.relative_transform(
                    relative_to=joint.parent.name,
                    name=joint.name,
                )
            )
            suc_H_i = suc_H_i.at[joint.index].set(
                description.relative_transform(
                    relative_to=joint.name, name=joint.child.name
                )
            )

        # Define the DoFs of the base link.
        base_dofs = 0 if description.fixed_base else 6

        # We always add a dummy fixed joint between world and base.
        # TODO: Port floating-base support also at this level, not only in RBDAs.
        return JointModel(
            λ_H_pre=λ_H_pre,
            suc_H_i=suc_H_i,
            # Static attributes
            joint_dofs=tuple([base_dofs] + [1] * len(ordered_joints)),
            joint_names=tuple(["world_to_base"] + [j.name for j in ordered_joints]),
            joint_types=tuple([JointType.F] + [j.jtype for j in ordered_joints]),
        )

    def parent_H_child(
        self, joint_index: jtp.IntLike, joint_position: jtp.VectorLike
    ) -> tuple[jtp.Matrix, jtp.Array]:
        r"""
        Compute the homogeneous transformation between the parent link and
        the child link of a joint, and the corresponding motion subspace.

        Args:
            joint_index:
                The index of the joint. It must be concrete (not traced) since
                it selects the joint type from the static ``joint_types`` tuple.
            joint_position: The position of the joint.

        Returns:
            A tuple containing the homogeneous transformation
            :math:`{}^{\lambda(i)} \mathbf{H}_i(s)`
            and the motion subspace :math:`\mathbf{S}(s)`.
        """

        i = joint_index
        s = joint_position

        # Get the components of the joint model.
        λ_Hi_pre = self.parent_H_predecessor(joint_index=i)
        pre_Hi_suc, S = self.predecessor_H_successor(joint_index=i, joint_position=s)
        suc_Hi_i = self.successor_H_child(joint_index=i)

        # Compose all the transforms.
        return λ_Hi_pre @ pre_Hi_suc @ suc_Hi_i, S

    def child_H_parent(
        self, joint_index: jtp.IntLike, joint_position: jtp.VectorLike
    ) -> tuple[jtp.Matrix, jtp.Array]:
        r"""
        Compute the homogeneous transformation between the child link and
        the parent link of a joint, and the corresponding motion subspace.

        Args:
            joint_index:
                The index of the joint. It must be concrete (not traced) since
                it selects the joint type from the static ``joint_types`` tuple.
            joint_position: The position of the joint.

        Returns:
            A tuple containing the homogeneous transformation
            :math:`{}^{i} \mathbf{H}_{\lambda(i)}(s)`
            and the motion subspace :math:`\mathbf{S}(s)`.
        """

        # Note: this method is intentionally not decorated with jax.jit. Jitting it
        # would turn joint_index into a tracer, and a tracer cannot be used to
        # index the static joint_types tuple in predecessor_H_successor.
        λ_Hi_i, S = self.parent_H_child(
            joint_index=joint_index, joint_position=joint_position
        )

        i_Hi_λ = jaxlie.SE3.from_matrix(λ_Hi_i).inverse().as_matrix()

        return i_Hi_λ, S

    def parent_H_predecessor(self, joint_index: jtp.IntLike) -> jtp.Matrix:
        r"""
        Return the homogeneous transformation between the parent link and
        the predecessor frame of a joint.

        Args:
            joint_index: The index of the joint.

        Returns:
            The homogeneous transformation
            :math:`{}^{\lambda(i)} \mathbf{H}_{\text{pre}(i)}`.
        """

        return self.λ_H_pre[joint_index]

    def predecessor_H_successor(
        self, joint_index: jtp.IntLike, joint_position: jtp.VectorLike
    ) -> tuple[jtp.Matrix, jtp.Array]:
        r"""
        Compute the homogeneous transformation between the predecessor and
        the successor frame of a joint, and the corresponding motion subspace.

        Args:
            joint_index:
                The index of the joint. It must be concrete (not traced) since
                it indexes the static ``joint_types`` tuple.
            joint_position: The position of the joint.

        Returns:
            A tuple containing the homogeneous transformation
            :math:`{}^{\text{pre}(i)} \mathbf{H}_{\text{suc}(i)}(s)`
            and the motion subspace :math:`\mathbf{S}(s)`.
        """

        pre_H_suc, S = supported_joint_motion(
            joint_type=self.joint_types[joint_index],
            joint_position=joint_position,
        )

        return pre_H_suc, S

    def successor_H_child(self, joint_index: jtp.IntLike) -> jtp.Matrix:
        r"""
        Return the homogeneous transformation between the successor frame and
        the child link of a joint.

        Args:
            joint_index: The index of the joint.

        Returns:
            The homogeneous transformation
            :math:`{}^{\text{suc}(i)} \mathbf{H}_i`.
        """

        return self.suc_H_i[joint_index]
+
229
+
230
+ @functools.partial(jax.jit, static_argnames=["joint_type"])
231
+ def supported_joint_motion(
232
+ joint_type: JointType | JointDescriptor, joint_position: jtp.VectorLike
233
+ ) -> tuple[jtp.Matrix, jtp.Array]:
234
+ """
235
+ Compute the homogeneous transformation and motion subspace of a joint.
236
+
237
+ Args:
238
+ joint_type: The type of the joint.
239
+ joint_position: The position of the joint.
240
+
241
+ Returns:
242
+ A tuple containing the homogeneous transformation and the motion subspace.
243
+ """
244
+
245
+ if isinstance(joint_type, JointType):
246
+ code = joint_type
247
+ elif isinstance(joint_type, JointDescriptor):
248
+ code = joint_type.code
249
+ else:
250
+ raise ValueError(joint_type)
251
+
252
+ # Prepare the joint position
253
+ s = jnp.array(joint_position).astype(float)
254
+
255
+ match code:
256
+
257
+ case JointType.R:
258
+ joint_type: JointGenericAxis
259
+
260
+ pre_H_suc = jaxlie.SE3.from_rotation(
261
+ rotation=jaxlie.SO3.from_matrix(
262
+ Rotation.from_axis_angle(vector=s * joint_type.axis)
263
+ )
264
+ )
265
+
266
+ S = jnp.vstack(jnp.hstack([jnp.zeros(3), joint_type.axis.squeeze()]))
267
+
268
+ case JointType.P:
269
+ joint_type: JointGenericAxis
270
+
271
+ pre_H_suc = jaxlie.SE3.from_rotation_and_translation(
272
+ rotation=jaxlie.SO3.identity(),
273
+ translation=jnp.array(s * joint_type.axis),
274
+ )
275
+
276
+ S = jnp.vstack(jnp.hstack([joint_type.axis.squeeze(), jnp.zeros(3)]))
277
+
278
+ case JointType.F:
279
+ raise ValueError("Fixed joints shouldn't be here")
280
+
281
+ case JointType.Rx:
282
+
283
+ pre_H_suc = jaxlie.SE3.from_rotation(
284
+ rotation=jaxlie.SO3.from_x_radians(theta=s)
285
+ )
286
+
287
+ S = jnp.vstack([0, 0, 0, 1.0, 0, 0])
288
+
289
+ case JointType.Ry:
290
+
291
+ pre_H_suc = jaxlie.SE3.from_rotation(
292
+ rotation=jaxlie.SO3.from_y_radians(theta=s)
293
+ )
294
+
295
+ S = jnp.vstack([0, 0, 0, 0, 1.0, 0])
296
+
297
+ case JointType.Rz:
298
+
299
+ pre_H_suc = jaxlie.SE3.from_rotation(
300
+ rotation=jaxlie.SO3.from_z_radians(theta=s)
301
+ )
302
+
303
+ S = jnp.vstack([0, 0, 0, 0, 0, 1.0])
304
+
305
+ case JointType.Px:
306
+
307
+ pre_H_suc = jaxlie.SE3.from_rotation_and_translation(
308
+ rotation=jaxlie.SO3.identity(),
309
+ translation=jnp.array([s, 0.0, 0.0]),
310
+ )
311
+
312
+ S = jnp.vstack([1.0, 0, 0, 0, 0, 0])
313
+
314
+ case JointType.Py:
315
+
316
+ pre_H_suc = jaxlie.SE3.from_rotation_and_translation(
317
+ rotation=jaxlie.SO3.identity(),
318
+ translation=jnp.array([0.0, s, 0.0]),
319
+ )
320
+
321
+ S = jnp.vstack([0, 1.0, 0, 0, 0, 0])
322
+
323
+ case JointType.Pz:
324
+
325
+ pre_H_suc = jaxlie.SE3.from_rotation_and_translation(
326
+ rotation=jaxlie.SO3.identity(),
327
+ translation=jnp.array([0.0, 0.0, s]),
328
+ )
329
+
330
+ S = jnp.vstack([0, 0, 1.0, 0, 0, 0])
331
+
332
+ case _:
333
+ raise ValueError(joint_type)
334
+
335
+ return pre_H_suc.as_matrix(), S
jaxsim/math/quaternion.py CHANGED
@@ -1,8 +1,8 @@
1
1
  import jax.lax
2
2
  import jax.numpy as jnp
3
+ import jaxlie
3
4
 
4
5
  import jaxsim.typing as jtp
5
- from jaxsim.sixd import so3
6
6
 
7
7
 
8
8
  class Quaternion:
@@ -43,7 +43,7 @@ class Quaternion:
43
43
  Returns:
44
44
  jtp.Matrix: Direction cosine matrix (DCM).
45
45
  """
46
- return so3.SO3.from_quaternion_xyzw(
46
+ return jaxlie.SO3.from_quaternion_xyzw(
47
47
  xyzw=Quaternion.to_xyzw(quaternion)
48
48
  ).as_matrix()
49
49
 
@@ -59,7 +59,7 @@ class Quaternion:
59
59
  jtp.Vector: Quaternion in XYZW representation.
60
60
  """
61
61
  return Quaternion.to_wxyz(
62
- xyzw=so3.SO3.from_matrix(matrix=dcm).as_quaternion_xyzw()
62
+ xyzw=jaxlie.SO3.from_matrix(matrix=dcm).as_quaternion_xyzw()
63
63
  )
64
64
 
65
65
  @staticmethod
@@ -133,3 +133,44 @@ class Quaternion:
133
133
  )
134
134
 
135
135
  return jnp.vstack(qd)
136
+
137
+ @staticmethod
138
+ def integration(
139
+ quaternion: jtp.VectorLike,
140
+ dt: jtp.FloatLike,
141
+ omega: jtp.VectorLike,
142
+ omega_in_body_fixed: jtp.BoolLike = False,
143
+ ) -> jtp.Vector:
144
+ """
145
+ Integrate a quaternion in SO(3) given an angular velocity.
146
+
147
+ Args:
148
+ quaternion: The quaternion to integrate.
149
+ dt: The time step.
150
+ omega: The angular velocity vector.
151
+ omega_in_body_fixed:
152
+ Whether the angular velocity is in body-fixed representation
153
+ as opposed to the default inertial-fixed representation.
154
+
155
+ Returns:
156
+ The integrated quaternion.
157
+ """
158
+
159
+ ω_AB = jnp.array(omega).squeeze().astype(float)
160
+ A_Q_B = jnp.array(quaternion).squeeze().astype(float)
161
+
162
+ # Build the initial SO(3) quaternion.
163
+ W_Q_B_t0 = jaxlie.SO3.from_quaternion_xyzw(xyzw=Quaternion.to_xyzw(wxyz=A_Q_B))
164
+
165
+ # Integrate the quaternion on the manifold.
166
+ W_Q_B_tf = jax.lax.select(
167
+ pred=omega_in_body_fixed,
168
+ on_true=Quaternion.to_wxyz(
169
+ xyzw=(W_Q_B_t0 @ jaxlie.SO3.exp(tangent=dt * ω_AB)).as_quaternion_xyzw()
170
+ ),
171
+ on_false=Quaternion.to_wxyz(
172
+ xyzw=(jaxlie.SO3.exp(tangent=dt * ω_AB) @ W_Q_B_t0).as_quaternion_xyzw()
173
+ ),
174
+ )
175
+
176
+ return W_Q_B_tf
jaxsim/math/rotation.py CHANGED
@@ -2,9 +2,9 @@ from typing import Tuple
2
2
 
3
3
  import jax
4
4
  import jax.numpy as jnp
5
+ import jaxlie
5
6
 
6
7
  import jaxsim.typing as jtp
7
- from jaxsim.sixd import so3
8
8
 
9
9
  from .skew import Skew
10
10
 
@@ -21,7 +21,7 @@ class Rotation:
21
21
  Returns:
22
22
  jtp.Matrix: 3D rotation matrix.
23
23
  """
24
- return so3.SO3.from_x_radians(theta=theta).as_matrix()
24
+ return jaxlie.SO3.from_x_radians(theta=theta).as_matrix()
25
25
 
26
26
  @staticmethod
27
27
  def y(theta: jtp.Float) -> jtp.Matrix:
@@ -34,7 +34,7 @@ class Rotation:
34
34
  Returns:
35
35
  jtp.Matrix: 3D rotation matrix.
36
36
  """
37
- return so3.SO3.from_y_radians(theta=theta).as_matrix()
37
+ return jaxlie.SO3.from_y_radians(theta=theta).as_matrix()
38
38
 
39
39
  @staticmethod
40
40
  def z(theta: jtp.Float) -> jtp.Matrix:
@@ -47,7 +47,7 @@ class Rotation:
47
47
  Returns:
48
48
  jtp.Matrix: 3D rotation matrix.
49
49
  """
50
- return so3.SO3.from_z_radians(theta=theta).as_matrix()
50
+ return jaxlie.SO3.from_z_radians(theta=theta).as_matrix()
51
51
 
52
52
  @staticmethod
53
53
  def from_axis_angle(vector: jtp.Vector) -> jtp.Matrix:
jaxsim/math/transform.py ADDED
@@ -0,0 +1,93 @@
1
+ import jax
2
+ import jax.numpy as jnp
3
+ import jaxlie
4
+
5
+ import jaxsim.typing as jtp
6
+
7
+ from .quaternion import Quaternion
8
+
9
+
10
class Transform:
    """Static helpers to build and invert SE(3) homogeneous transformation matrices."""

    @staticmethod
    def from_quaternion_and_translation(
        quaternion: jtp.VectorLike = jnp.array([1.0, 0, 0, 0]),
        translation: jtp.VectorLike = jnp.zeros(3),
        inverse: jtp.BoolLike = False,
        normalize_quaternion: jtp.BoolLike = False,
    ) -> jtp.Matrix:
        """
        Create a transformation matrix from a quaternion and a translation.

        Args:
            quaternion: A 4D vector representing a SO(3) orientation, in WXYZ format.
            translation: A 3D vector representing a translation.
            inverse:
                Whether to compute the inverse transformation. It is evaluated
                with Python truthiness, therefore it must not be a traced value.
            normalize_quaternion:
                Whether to normalize the quaternion before creating the transformation.
                Like ``inverse``, it must not be a traced value.

        Returns:
            The 4x4 transformation matrix representing the SE(3) transformation.
        """

        W_Q_B = jnp.array(quaternion).astype(float)
        W_p_B = jnp.array(translation).astype(float)

        assert W_p_B.size == 3
        assert W_Q_B.size == 4

        # jaxlie expects the XYZW serialization of the quaternion.
        A_R_B = jaxlie.SO3.from_quaternion_xyzw(xyzw=Quaternion.to_xyzw(W_Q_B))
        A_R_B = A_R_B if not normalize_quaternion else A_R_B.normalize()

        A_H_B = jaxlie.SE3.from_rotation_and_translation(
            rotation=A_R_B, translation=W_p_B
        )

        return A_H_B.as_matrix() if not inverse else A_H_B.inverse().as_matrix()

    @staticmethod
    def from_rotation_and_translation(
        rotation: jtp.MatrixLike,
        translation: jtp.VectorLike,
        inverse: jtp.BoolLike = False,
    ) -> jtp.Matrix:
        """
        Create a transformation matrix from a rotation matrix and a translation vector.

        Args:
            rotation: A 3x3 rotation matrix representing a SO(3) orientation.
            translation: A 3D vector representing a translation.
            inverse:
                Whether to compute the inverse transformation. It is evaluated
                with Python truthiness, therefore it must not be a traced value.

        Returns:
            The 4x4 transformation matrix representing the SE(3) transformation.
        """

        A_R_B = jnp.array(rotation).astype(float)
        W_p_B = jnp.array(translation).astype(float)

        assert W_p_B.size == 3
        assert A_R_B.shape == (3, 3)

        # jaxlie.SE3.from_rotation_and_translation expects a jaxlie.SO3 object,
        # not a raw matrix: convert the rotation matrix first.
        A_H_B = jaxlie.SE3.from_rotation_and_translation(
            rotation=jaxlie.SO3.from_matrix(A_R_B), translation=W_p_B
        )

        return A_H_B.as_matrix() if not inverse else A_H_B.inverse().as_matrix()

    @staticmethod
    def inverse(transform: jtp.MatrixLike) -> jtp.Matrix:
        """
        Compute the inverse transformation matrix.

        Args:
            transform: A 4x4 transformation matrix.

        Returns:
            The 4x4 inverse transformation matrix.
        """

        A_H_B = jnp.array(transform).astype(float)
        assert A_H_B.shape == (4, 4)

        return jaxlie.SE3.from_matrix(matrix=A_H_B).inverse().as_matrix()
jaxsim/parsers/descriptions/link.py CHANGED
@@ -3,10 +3,10 @@ from typing import List
3
3
 
4
4
  import jax.numpy as jnp
5
5
  import jax_dataclasses
6
+ import jaxlie
6
7
  from jax_dataclasses import Static
7
8
 
8
9
  import jaxsim.typing as jtp
9
- from jaxsim.sixd import se3
10
10
  from jaxsim.utils import JaxsimDataclass
11
11
 
12
12
 
@@ -78,7 +78,7 @@ class LinkDescription(JaxsimDataclass):
78
78
  I_removed = link.inertia
79
79
 
80
80
  # Create the SE3 object. Note the inverse.
81
- r_H_l = se3.SE3.from_matrix(lumped_H_removed).inverse()
81
+ r_H_l = jaxlie.SE3.from_matrix(lumped_H_removed).inverse()
82
82
  r_X_l = r_H_l.adjoint()
83
83
 
84
84
  # Move the inertia
jaxsim/parsers/rod/utils.py CHANGED
@@ -1,15 +1,17 @@
1
1
  import os
2
2
  from typing import Union
3
3
 
4
- import jax.numpy as jnp
4
+ import jaxlie
5
5
  import numpy as np
6
6
  import numpy.typing as npt
7
7
  import rod
8
8
 
9
+ import jaxsim.typing as jtp
10
+ from jaxsim.math.inertia import Inertia
9
11
  from jaxsim.parsers import descriptions
10
12
 
11
13
 
12
- def from_sdf_inertial(inertial: rod.Inertial) -> npt.NDArray:
14
+ def from_sdf_inertial(inertial: rod.Inertial) -> jtp.Matrix:
13
15
  """
14
16
  Extract the 6D inertia matrix from an SDF inertial element.
15
17
 
@@ -20,9 +22,6 @@ def from_sdf_inertial(inertial: rod.Inertial) -> npt.NDArray:
20
22
  The 6D inertia matrix of the link expressed in the link frame.
21
23
  """
22
24
 
23
- from jaxsim.math.inertia import Inertia
24
- from jaxsim.sixd import se3
25
-
26
25
  # Extract the "mass" element
27
26
  m = inertial.mass
28
27
 
@@ -52,13 +51,13 @@ def from_sdf_inertial(inertial: rod.Inertial) -> npt.NDArray:
52
51
  L_H_CoM = inertial.pose.transform() if inertial.pose is not None else np.eye(4)
53
52
 
54
53
  # We need its inverse
55
- CoM_H_L = se3.SE3.from_matrix(matrix=L_H_CoM).inverse()
56
- CoM_X_L: npt.NDArray = CoM_H_L.adjoint()
54
+ CoM_H_L = jaxlie.SE3.from_matrix(matrix=L_H_CoM).inverse()
55
+ CoM_X_L = CoM_H_L.adjoint()
57
56
 
58
57
  # Express the CoM inertia matrix in the link frame L
59
58
  M_L = CoM_X_L.T @ M_CoM @ CoM_X_L
60
59
 
61
- return jnp.array(M_L)
60
+ return M_L.astype(dtype=float)
62
61
 
63
62
 
64
63
  def axis_to_jtype(
jaxsim/rbda/__init__.py ADDED
@@ -0,0 +1,7 @@
1
+ from .aba import aba
2
+ from .collidable_points import collidable_points_pos_vel
3
+ from .crba import crba
4
+ from .forward_kinematics import forward_kinematics, forward_kinematics_model
5
+ from .jacobian import jacobian, jacobian_full_doubly_left
6
+ from .rnea import rnea
7
+ from .soft_contacts import SoftContacts, SoftContactsParams