jaxsim 0.1rc0__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (89)
  1. jaxsim/__init__.py +5 -6
  2. jaxsim/_version.py +2 -2
  3. jaxsim/api/__init__.py +3 -0
  4. jaxsim/api/com.py +240 -0
  5. jaxsim/api/common.py +216 -0
  6. jaxsim/api/contact.py +271 -0
  7. jaxsim/api/data.py +821 -0
  8. jaxsim/api/joint.py +189 -0
  9. jaxsim/api/kin_dyn_parameters.py +777 -0
  10. jaxsim/api/link.py +361 -0
  11. jaxsim/api/model.py +1633 -0
  12. jaxsim/api/ode.py +295 -0
  13. jaxsim/api/ode_data.py +694 -0
  14. jaxsim/api/references.py +421 -0
  15. jaxsim/integrators/__init__.py +2 -0
  16. jaxsim/integrators/common.py +594 -0
  17. jaxsim/integrators/fixed_step.py +102 -0
  18. jaxsim/integrators/variable_step.py +610 -0
  19. jaxsim/math/__init__.py +11 -0
  20. jaxsim/math/adjoint.py +24 -2
  21. jaxsim/math/joint_model.py +335 -0
  22. jaxsim/math/quaternion.py +44 -3
  23. jaxsim/math/rotation.py +4 -4
  24. jaxsim/math/transform.py +92 -0
  25. jaxsim/mujoco/__init__.py +3 -0
  26. jaxsim/mujoco/__main__.py +192 -0
  27. jaxsim/mujoco/loaders.py +615 -0
  28. jaxsim/mujoco/model.py +414 -0
  29. jaxsim/mujoco/visualizer.py +176 -0
  30. jaxsim/parsers/descriptions/collision.py +14 -0
  31. jaxsim/parsers/descriptions/link.py +13 -2
  32. jaxsim/parsers/kinematic_graph.py +8 -3
  33. jaxsim/parsers/rod/parser.py +54 -38
  34. jaxsim/parsers/rod/utils.py +7 -8
  35. jaxsim/rbda/__init__.py +7 -0
  36. jaxsim/rbda/aba.py +295 -0
  37. jaxsim/rbda/collidable_points.py +142 -0
  38. jaxsim/{physics/algos → rbda}/crba.py +43 -42
  39. jaxsim/rbda/forward_kinematics.py +113 -0
  40. jaxsim/rbda/jacobian.py +201 -0
  41. jaxsim/rbda/rnea.py +237 -0
  42. jaxsim/rbda/soft_contacts.py +296 -0
  43. jaxsim/rbda/utils.py +152 -0
  44. jaxsim/terrain/__init__.py +2 -0
  45. jaxsim/{physics/algos → terrain}/terrain.py +4 -6
  46. jaxsim/typing.py +30 -30
  47. jaxsim/utils/__init__.py +1 -4
  48. jaxsim/utils/hashless.py +18 -0
  49. jaxsim/utils/jaxsim_dataclass.py +281 -31
  50. {jaxsim-0.1rc0.dist-info → jaxsim-0.2.0.dist-info}/LICENSE +1 -1
  51. jaxsim-0.2.0.dist-info/METADATA +237 -0
  52. jaxsim-0.2.0.dist-info/RECORD +64 -0
  53. {jaxsim-0.1rc0.dist-info → jaxsim-0.2.0.dist-info}/WHEEL +1 -1
  54. jaxsim/high_level/__init__.py +0 -2
  55. jaxsim/high_level/common.py +0 -11
  56. jaxsim/high_level/joint.py +0 -148
  57. jaxsim/high_level/link.py +0 -259
  58. jaxsim/high_level/model.py +0 -1695
  59. jaxsim/math/conv.py +0 -114
  60. jaxsim/math/joint.py +0 -101
  61. jaxsim/math/plucker.py +0 -100
  62. jaxsim/physics/__init__.py +0 -12
  63. jaxsim/physics/algos/__init__.py +0 -0
  64. jaxsim/physics/algos/aba.py +0 -256
  65. jaxsim/physics/algos/aba_motors.py +0 -284
  66. jaxsim/physics/algos/forward_kinematics.py +0 -79
  67. jaxsim/physics/algos/jacobian.py +0 -98
  68. jaxsim/physics/algos/rnea.py +0 -180
  69. jaxsim/physics/algos/rnea_motors.py +0 -196
  70. jaxsim/physics/algos/soft_contacts.py +0 -454
  71. jaxsim/physics/algos/utils.py +0 -69
  72. jaxsim/physics/model/__init__.py +0 -0
  73. jaxsim/physics/model/ground_contact.py +0 -55
  74. jaxsim/physics/model/physics_model.py +0 -358
  75. jaxsim/physics/model/physics_model_state.py +0 -174
  76. jaxsim/simulation/__init__.py +0 -4
  77. jaxsim/simulation/integrators.py +0 -452
  78. jaxsim/simulation/ode.py +0 -290
  79. jaxsim/simulation/ode_data.py +0 -53
  80. jaxsim/simulation/ode_integration.py +0 -125
  81. jaxsim/simulation/simulator.py +0 -544
  82. jaxsim/simulation/simulator_callbacks.py +0 -53
  83. jaxsim/simulation/utils.py +0 -15
  84. jaxsim/sixd/__init__.py +0 -2
  85. jaxsim/utils/oop.py +0 -532
  86. jaxsim/utils/vmappable.py +0 -117
  87. jaxsim-0.1rc0.dist-info/METADATA +0 -167
  88. jaxsim-0.1rc0.dist-info/RECORD +0 -64
  89. {jaxsim-0.1rc0.dist-info → jaxsim-0.2.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,335 @@
1
+ from __future__ import annotations
2
+
3
+ import functools
4
+
5
+ import jax
6
+ import jax.numpy as jnp
7
+ import jax_dataclasses
8
+ import jaxlie
9
+ from jax_dataclasses import Static
10
+
11
+ import jaxsim.typing as jtp
12
+ from jaxsim.parsers.descriptions import (
13
+ JointDescriptor,
14
+ JointGenericAxis,
15
+ JointType,
16
+ ModelDescription,
17
+ )
18
+
19
+ from .rotation import Rotation
20
+
21
+
22
+ @jax_dataclasses.pytree_dataclass
23
+ class JointModel:
24
+ """
25
+ Class describing the joint kinematics of a robot model.
26
+
27
+ Attributes:
28
+ λ_H_pre:
29
+ The homogeneous transformation between the parent link and
30
+ the predecessor frame of each joint.
31
+ suc_H_i:
32
+ The homogeneous transformation between the successor frame and
33
+ the child link of each joint.
34
+ joint_dofs: The number of DoFs of each joint.
35
+ joint_names: The names of each joint.
36
+ joint_types: The types of each joint.
37
+
38
+ Note:
39
+ Due to the presence of the static attributes, this class needs to be created
40
+ already in a vectorized form. In other words, it cannot be created using vmap.
41
+ """
42
+
43
+ λ_H_pre: jax.Array
44
+ suc_H_i: jax.Array
45
+
46
+ joint_dofs: Static[tuple[int, ...]]
47
+ joint_names: Static[tuple[str, ...]]
48
+ joint_types: Static[tuple[JointType | JointDescriptor, ...]]
49
+
50
+ @staticmethod
51
+ def build(description: ModelDescription) -> JointModel:
52
+ """
53
+ Build the joint model of a model description.
54
+
55
+ Args:
56
+ description: The model description to consider.
57
+
58
+ Returns:
59
+ The joint model of the considered model description.
60
+ """
61
+
62
+ # The link index is equal to its body index: [0, number_of_bodies - 1].
63
+ ordered_links = sorted(
64
+ list(description.links_dict.values()),
65
+ key=lambda l: l.index,
66
+ )
67
+
68
+ # Note: the joint index is equal to its child link index, therefore it
69
+ # starts from 1.
70
+ ordered_joints = sorted(
71
+ list(description.joints_dict.values()),
72
+ key=lambda j: j.index,
73
+ )
74
+
75
+ # Allocate the parent-to-predecessor and successor-to-child transforms.
76
+ λ_H_pre = jnp.zeros(shape=(1 + len(ordered_joints), 4, 4), dtype=float)
77
+ suc_H_i = jnp.zeros(shape=(1 + len(ordered_joints), 4, 4), dtype=float)
78
+
79
+ # Initialize an identity parent-to-predecessor transform for the joint
80
+ # between the world frame W and the base link B.
81
+ λ_H_pre = λ_H_pre.at[0].set(jnp.eye(4))
82
+
83
+ # Initialize the successor-to-child transform of the joint between the
84
+ # world frame W and the base link B.
85
+ # We store here the optional transform between the root frame of the model
86
+ # and the base link frame (this is needed only if the pose of the link frame
87
+ # w.r.t. the implicit __model__ SDF frame is not the identity).
88
+ suc_H_i = suc_H_i.at[0].set(ordered_links[0].pose)
89
+
90
+ # Compute the parent-to-predecessor and successor-to-child transforms for
91
+ # each joint belonging to the model.
92
+ # Note that the joint indices start from i=1 given our joint model,
93
+ # therefore the entries at index 0 are not updated.
94
+ for joint in ordered_joints:
95
+ λ_H_pre = λ_H_pre.at[joint.index].set(
96
+ description.relative_transform(
97
+ relative_to=joint.parent.name,
98
+ name=joint.name,
99
+ )
100
+ )
101
+ suc_H_i = suc_H_i.at[joint.index].set(
102
+ description.relative_transform(
103
+ relative_to=joint.name, name=joint.child.name
104
+ )
105
+ )
106
+
107
+ # Define the DoFs of the base link.
108
+ base_dofs = 0 if description.fixed_base else 6
109
+
110
+ # We always add a dummy fixed joint between world and base.
111
+ # TODO: Port floating-base support also at this level, not only in RBDAs.
112
+ return JointModel(
113
+ λ_H_pre=λ_H_pre,
114
+ suc_H_i=suc_H_i,
115
+ # Static attributes
116
+ joint_dofs=tuple([base_dofs] + [int(1) for _ in ordered_joints]),
117
+ joint_names=tuple(["world_to_base"] + [j.name for j in ordered_joints]),
118
+ joint_types=tuple([JointType.F] + [j.jtype for j in ordered_joints]),
119
+ )
120
+
121
+ def parent_H_child(
122
+ self, joint_index: jtp.IntLike, joint_position: jtp.VectorLike
123
+ ) -> tuple[jtp.Matrix, jtp.Array]:
124
+ r"""
125
+ Compute the homogeneous transformation between the parent link and
126
+ the child link of a joint, and the corresponding motion subspace.
127
+
128
+ Args:
129
+ joint_index: The index of the joint.
130
+ joint_position: The position of the joint.
131
+
132
+ Returns:
133
+ A tuple containing the homogeneous transformation
134
+ :math:`{}^{\lambda(i)} \mathbf{H}_i(s)`
135
+ and the motion subspace :math:`\mathbf{S}(s)`.
136
+ """
137
+
138
+ i = joint_index
139
+ s = joint_position
140
+
141
+ # Get the components of the joint model.
142
+ λ_Hi_pre = self.parent_H_predecessor(joint_index=i)
143
+ pre_Hi_suc, S = self.predecessor_H_successor(joint_index=i, joint_position=s)
144
+ suc_Hi_i = self.successor_H_child(joint_index=i)
145
+
146
+ # Compose all the transforms.
147
+ return λ_Hi_pre @ pre_Hi_suc @ suc_Hi_i, S
148
+
149
+ @jax.jit
150
+ def child_H_parent(
151
+ self, joint_index: jtp.IntLike, joint_position: jtp.VectorLike
152
+ ) -> tuple[jtp.Matrix, jtp.Array]:
153
+ r"""
154
+ Compute the homogeneous transformation between the child link and
155
+ the parent link of a joint, and the corresponding motion subspace.
156
+
157
+ Args:
158
+ joint_index: The index of the joint.
159
+ joint_position: The position of the joint.
160
+
161
+ Returns:
162
+ A tuple containing the homogeneous transformation
163
+ :math:`{}^{i} \mathbf{H}_{\lambda(i)}(s)`
164
+ and the motion subspace :math:`\mathbf{S}(s)`.
165
+ """
166
+
167
+ λ_Hi_i, S = self.parent_H_child(
168
+ joint_index=joint_index, joint_position=joint_position
169
+ )
170
+
171
+ i_Hi_λ = jaxlie.SE3.from_matrix(λ_Hi_i).inverse().as_matrix()
172
+
173
+ return i_Hi_λ, S
174
+
175
+ def parent_H_predecessor(self, joint_index: jtp.IntLike) -> jtp.Matrix:
176
+ r"""
177
+ Return the homogeneous transformation between the parent link and
178
+ the predecessor frame of a joint.
179
+
180
+ Args:
181
+ joint_index: The index of the joint.
182
+
183
+ Returns:
184
+ The homogeneous transformation
185
+ :math:`{}^{\lambda(i)} \mathbf{H}_{\text{pre}(i)}`.
186
+ """
187
+
188
+ return self.λ_H_pre[joint_index]
189
+
190
+ def predecessor_H_successor(
191
+ self, joint_index: jtp.IntLike, joint_position: jtp.VectorLike
192
+ ) -> tuple[jtp.Matrix, jtp.Array]:
193
+ r"""
194
+ Compute the homogeneous transformation between the predecessor and
195
+ the successor frame of a joint, and the corresponding motion subspace.
196
+
197
+ Args:
198
+ joint_index: The index of the joint.
199
+ joint_position: The position of the joint.
200
+
201
+ Returns:
202
+ A tuple containing the homogeneous transformation
203
+ :math:`{}^{\text{pre}(i)} \mathbf{H}_{\text{suc}(i)}(s)`
204
+ and the motion subspace :math:`\mathbf{S}(s)`.
205
+ """
206
+
207
+ pre_H_suc, S = supported_joint_motion(
208
+ joint_type=self.joint_types[joint_index],
209
+ joint_position=joint_position,
210
+ )
211
+
212
+ return pre_H_suc, S
213
+
214
+ def successor_H_child(self, joint_index: jtp.IntLike) -> jtp.Matrix:
215
+ r"""
216
+ Return the homogeneous transformation between the successor frame and
217
+ the child link of a joint.
218
+
219
+ Args:
220
+ joint_index: The index of the joint.
221
+
222
+ Returns:
223
+ The homogeneous transformation
224
+ :math:`{}^{\text{suc}(i)} \mathbf{H}_i`.
225
+ """
226
+
227
+ return self.suc_H_i[joint_index]
228
+
229
+
230
+ @functools.partial(jax.jit, static_argnames=["joint_type"])
231
+ def supported_joint_motion(
232
+ joint_type: JointType | JointDescriptor, joint_position: jtp.VectorLike
233
+ ) -> tuple[jtp.Matrix, jtp.Array]:
234
+ """
235
+ Compute the homogeneous transformation and motion subspace of a joint.
236
+
237
+ Args:
238
+ joint_type: The type of the joint.
239
+ joint_position: The position of the joint.
240
+
241
+ Returns:
242
+ A tuple containing the homogeneous transformation and the motion subspace.
243
+ """
244
+
245
+ if isinstance(joint_type, JointType):
246
+ code = joint_type
247
+ elif isinstance(joint_type, JointDescriptor):
248
+ code = joint_type.code
249
+ else:
250
+ raise ValueError(joint_type)
251
+
252
+ # Prepare the joint position
253
+ s = jnp.array(joint_position).astype(float)
254
+
255
+ match code:
256
+
257
+ case JointType.R:
258
+ joint_type: JointGenericAxis
259
+
260
+ pre_H_suc = jaxlie.SE3.from_rotation(
261
+ rotation=jaxlie.SO3.from_matrix(
262
+ Rotation.from_axis_angle(vector=s * joint_type.axis)
263
+ )
264
+ )
265
+
266
+ S = jnp.vstack(jnp.hstack([jnp.zeros(3), joint_type.axis.squeeze()]))
267
+
268
+ case JointType.P:
269
+ joint_type: JointGenericAxis
270
+
271
+ pre_H_suc = jaxlie.SE3.from_rotation_and_translation(
272
+ rotation=jaxlie.SO3.identity(),
273
+ translation=jnp.array(s * joint_type.axis),
274
+ )
275
+
276
+ S = jnp.vstack(jnp.hstack([joint_type.axis.squeeze(), jnp.zeros(3)]))
277
+
278
+ case JointType.F:
279
+ raise ValueError("Fixed joints shouldn't be here")
280
+
281
+ case JointType.Rx:
282
+
283
+ pre_H_suc = jaxlie.SE3.from_rotation(
284
+ rotation=jaxlie.SO3.from_x_radians(theta=s)
285
+ )
286
+
287
+ S = jnp.vstack([0, 0, 0, 1.0, 0, 0])
288
+
289
+ case JointType.Ry:
290
+
291
+ pre_H_suc = jaxlie.SE3.from_rotation(
292
+ rotation=jaxlie.SO3.from_y_radians(theta=s)
293
+ )
294
+
295
+ S = jnp.vstack([0, 0, 0, 0, 1.0, 0])
296
+
297
+ case JointType.Rz:
298
+
299
+ pre_H_suc = jaxlie.SE3.from_rotation(
300
+ rotation=jaxlie.SO3.from_z_radians(theta=s)
301
+ )
302
+
303
+ S = jnp.vstack([0, 0, 0, 0, 0, 1.0])
304
+
305
+ case JointType.Px:
306
+
307
+ pre_H_suc = jaxlie.SE3.from_rotation_and_translation(
308
+ rotation=jaxlie.SO3.identity(),
309
+ translation=jnp.array([s, 0.0, 0.0]),
310
+ )
311
+
312
+ S = jnp.vstack([1.0, 0, 0, 0, 0, 0])
313
+
314
+ case JointType.Py:
315
+
316
+ pre_H_suc = jaxlie.SE3.from_rotation_and_translation(
317
+ rotation=jaxlie.SO3.identity(),
318
+ translation=jnp.array([0.0, s, 0.0]),
319
+ )
320
+
321
+ S = jnp.vstack([0, 1.0, 0, 0, 0, 0])
322
+
323
+ case JointType.Pz:
324
+
325
+ pre_H_suc = jaxlie.SE3.from_rotation_and_translation(
326
+ rotation=jaxlie.SO3.identity(),
327
+ translation=jnp.array([0.0, 0.0, s]),
328
+ )
329
+
330
+ S = jnp.vstack([0, 0, 1.0, 0, 0, 0])
331
+
332
+ case _:
333
+ raise ValueError(joint_type)
334
+
335
+ return pre_H_suc.as_matrix(), S
jaxsim/math/quaternion.py CHANGED
@@ -1,8 +1,8 @@
1
1
  import jax.lax
2
2
  import jax.numpy as jnp
3
+ import jaxlie
3
4
 
4
5
  import jaxsim.typing as jtp
5
- from jaxsim.sixd import so3
6
6
 
7
7
 
8
8
  class Quaternion:
@@ -43,7 +43,7 @@ class Quaternion:
43
43
  Returns:
44
44
  jtp.Matrix: Direction cosine matrix (DCM).
45
45
  """
46
- return so3.SO3.from_quaternion_xyzw(
46
+ return jaxlie.SO3.from_quaternion_xyzw(
47
47
  xyzw=Quaternion.to_xyzw(quaternion)
48
48
  ).as_matrix()
49
49
 
@@ -59,7 +59,7 @@ class Quaternion:
59
59
  jtp.Vector: Quaternion in WXYZ representation.
60
60
  """
61
61
  return Quaternion.to_wxyz(
62
- xyzw=so3.SO3.from_matrix(matrix=dcm).as_quaternion_xyzw()
62
+ xyzw=jaxlie.SO3.from_matrix(matrix=dcm).as_quaternion_xyzw()
63
63
  )
64
64
 
65
65
  @staticmethod
@@ -133,3 +133,44 @@ class Quaternion:
133
133
  )
134
134
 
135
135
  return jnp.vstack(qd)
136
+
137
+ @staticmethod
138
+ def integration(
139
+ quaternion: jtp.VectorLike,
140
+ dt: jtp.FloatLike,
141
+ omega: jtp.VectorLike,
142
+ omega_in_body_fixed: jtp.BoolLike = False,
143
+ ) -> jtp.Vector:
144
+ """
145
+ Integrate a quaternion in SO(3) given an angular velocity.
146
+
147
+ Args:
148
+ quaternion: The quaternion to integrate.
149
+ dt: The time step.
150
+ omega: The angular velocity vector.
151
+ omega_in_body_fixed:
152
+ Whether the angular velocity is in body-fixed representation
153
+ as opposed to the default inertial-fixed representation.
154
+
155
+ Returns:
156
+ The integrated quaternion.
157
+ """
158
+
159
+ ω_AB = jnp.array(omega).squeeze().astype(float)
160
+ A_Q_B = jnp.array(quaternion).squeeze().astype(float)
161
+
162
+ # Build the initial SO(3) quaternion.
163
+ W_Q_B_t0 = jaxlie.SO3.from_quaternion_xyzw(xyzw=Quaternion.to_xyzw(wxyz=A_Q_B))
164
+
165
+ # Integrate the quaternion on the manifold.
166
+ W_Q_B_tf = jax.lax.select(
167
+ pred=omega_in_body_fixed,
168
+ on_true=Quaternion.to_wxyz(
169
+ xyzw=(W_Q_B_t0 @ jaxlie.SO3.exp(tangent=dt * ω_AB)).as_quaternion_xyzw()
170
+ ),
171
+ on_false=Quaternion.to_wxyz(
172
+ xyzw=(jaxlie.SO3.exp(tangent=dt * ω_AB) @ W_Q_B_t0).as_quaternion_xyzw()
173
+ ),
174
+ )
175
+
176
+ return W_Q_B_tf
jaxsim/math/rotation.py CHANGED
@@ -2,9 +2,9 @@ from typing import Tuple
2
2
 
3
3
  import jax
4
4
  import jax.numpy as jnp
5
+ import jaxlie
5
6
 
6
7
  import jaxsim.typing as jtp
7
- from jaxsim.sixd import so3
8
8
 
9
9
  from .skew import Skew
10
10
 
@@ -21,7 +21,7 @@ class Rotation:
21
21
  Returns:
22
22
  jtp.Matrix: 3D rotation matrix.
23
23
  """
24
- return so3.SO3.from_x_radians(theta=theta).as_matrix()
24
+ return jaxlie.SO3.from_x_radians(theta=theta).as_matrix()
25
25
 
26
26
  @staticmethod
27
27
  def y(theta: jtp.Float) -> jtp.Matrix:
@@ -34,7 +34,7 @@ class Rotation:
34
34
  Returns:
35
35
  jtp.Matrix: 3D rotation matrix.
36
36
  """
37
- return so3.SO3.from_y_radians(theta=theta).as_matrix()
37
+ return jaxlie.SO3.from_y_radians(theta=theta).as_matrix()
38
38
 
39
39
  @staticmethod
40
40
  def z(theta: jtp.Float) -> jtp.Matrix:
@@ -47,7 +47,7 @@ class Rotation:
47
47
  Returns:
48
48
  jtp.Matrix: 3D rotation matrix.
49
49
  """
50
- return so3.SO3.from_z_radians(theta=theta).as_matrix()
50
+ return jaxlie.SO3.from_z_radians(theta=theta).as_matrix()
51
51
 
52
52
  @staticmethod
53
53
  def from_axis_angle(vector: jtp.Vector) -> jtp.Matrix:
@@ -0,0 +1,92 @@
1
+ import jax.numpy as jnp
2
+ import jaxlie
3
+
4
+ import jaxsim.typing as jtp
5
+
6
+ from .quaternion import Quaternion
7
+
8
+
9
+ class Transform:
10
+
11
+ @staticmethod
12
+ def from_quaternion_and_translation(
13
+ quaternion: jtp.VectorLike = jnp.array([1.0, 0, 0, 0]),
14
+ translation: jtp.VectorLike = jnp.zeros(3),
15
+ inverse: jtp.BoolLike = False,
16
+ normalize_quaternion: jtp.BoolLike = False,
17
+ ) -> jtp.Matrix:
18
+ """
19
+ Create a transformation matrix from a quaternion and a translation.
20
+
21
+ Args:
22
+ quaternion: A 4D vector representing a SO(3) orientation.
23
+ translation: A 3D vector representing a translation.
24
+ inverse: Whether to compute the inverse transformation.
25
+ normalize_quaternion:
26
+ Whether to normalize the quaternion before creating the transformation.
27
+
28
+ Returns:
29
+ The 4x4 transformation matrix representing the SE(3) transformation.
30
+ """
31
+
32
+ W_Q_B = jnp.array(quaternion).astype(float)
33
+ W_p_B = jnp.array(translation).astype(float)
34
+
35
+ assert W_p_B.size == 3
36
+ assert W_Q_B.size == 4
37
+
38
+ A_R_B = jaxlie.SO3.from_quaternion_xyzw(xyzw=Quaternion.to_xyzw(W_Q_B))
39
+ A_R_B = A_R_B if not normalize_quaternion else A_R_B.normalize()
40
+
41
+ A_H_B = jaxlie.SE3.from_rotation_and_translation(
42
+ rotation=A_R_B, translation=W_p_B
43
+ )
44
+
45
+ return A_H_B.as_matrix() if not inverse else A_H_B.inverse().as_matrix()
46
+
47
+ @staticmethod
48
+ def from_rotation_and_translation(
49
+ rotation: jtp.MatrixLike,
50
+ translation: jtp.VectorLike,
51
+ inverse: jtp.BoolLike = False,
52
+ ) -> jtp.Matrix:
53
+ """
54
+ Create a transformation matrix from a rotation matrix and a translation vector.
55
+
56
+ Args:
57
+ rotation: A 3x3 rotation matrix representing a SO(3) orientation.
58
+ translation: A 3D vector representing a translation.
59
+ inverse: Whether to compute the inverse transformation.
60
+
61
+ Returns:
62
+ The 4x4 transformation matrix representing the SE(3) transformation.
63
+ """
64
+
65
+ A_R_B = jnp.array(rotation).astype(float)
66
+ W_p_B = jnp.array(translation).astype(float)
67
+
68
+ assert W_p_B.size == 3
69
+ assert A_R_B.shape == (3, 3)
70
+
71
+ A_H_B = jaxlie.SE3.from_rotation_and_translation(
72
+ rotation=A_R_B, translation=W_p_B
73
+ )
74
+
75
+ return A_H_B.as_matrix() if not inverse else A_H_B.inverse().as_matrix()
76
+
77
+ @staticmethod
78
+ def inverse(transform: jtp.MatrixLike) -> jtp.Matrix:
79
+ """
80
+ Compute the inverse transformation matrix.
81
+
82
+ Args:
83
+ transform: A 4x4 transformation matrix.
84
+
85
+ Returns:
86
+ The 4x4 inverse transformation matrix.
87
+ """
88
+
89
+ A_H_B = jnp.array(transform).astype(float)
90
+ assert A_H_B.shape == (4, 4)
91
+
92
+ return jaxlie.SE3.from_matrix(matrix=A_H_B).inverse().as_matrix()
@@ -0,0 +1,3 @@
1
+ from .loaders import RodModelToMjcf, SdfToMjcf, UrdfToMjcf
2
+ from .model import MujocoModelHelper
3
+ from .visualizer import MujocoVideoRecorder, MujocoVisualizer