jaxsim 0.2.dev191__py3-none-any.whl → 0.2.dev364__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jaxsim/__init__.py +3 -4
- jaxsim/_version.py +2 -2
- jaxsim/api/__init__.py +3 -1
- jaxsim/api/com.py +240 -0
- jaxsim/api/common.py +13 -2
- jaxsim/api/contact.py +120 -43
- jaxsim/api/data.py +112 -71
- jaxsim/api/joint.py +77 -36
- jaxsim/api/kin_dyn_parameters.py +777 -0
- jaxsim/api/link.py +150 -75
- jaxsim/api/model.py +542 -269
- jaxsim/api/ode.py +86 -74
- jaxsim/api/ode_data.py +694 -0
- jaxsim/api/references.py +12 -11
- jaxsim/integrators/__init__.py +2 -2
- jaxsim/integrators/common.py +110 -24
- jaxsim/integrators/fixed_step.py +11 -67
- jaxsim/integrators/variable_step.py +610 -0
- jaxsim/math/__init__.py +11 -0
- jaxsim/math/adjoint.py +24 -2
- jaxsim/math/joint_model.py +335 -0
- jaxsim/math/quaternion.py +44 -3
- jaxsim/math/rotation.py +4 -4
- jaxsim/math/transform.py +93 -0
- jaxsim/parsers/descriptions/link.py +2 -2
- jaxsim/parsers/rod/utils.py +7 -8
- jaxsim/rbda/__init__.py +7 -0
- jaxsim/rbda/aba.py +295 -0
- jaxsim/rbda/collidable_points.py +142 -0
- jaxsim/{physics/algos → rbda}/crba.py +43 -42
- jaxsim/rbda/forward_kinematics.py +113 -0
- jaxsim/rbda/jacobian.py +201 -0
- jaxsim/rbda/rnea.py +237 -0
- jaxsim/rbda/soft_contacts.py +296 -0
- jaxsim/rbda/utils.py +152 -0
- jaxsim/terrain/__init__.py +2 -0
- jaxsim/utils/__init__.py +1 -4
- jaxsim/utils/hashless.py +18 -0
- jaxsim/utils/jaxsim_dataclass.py +281 -30
- {jaxsim-0.2.dev191.dist-info → jaxsim-0.2.dev364.dist-info}/METADATA +4 -6
- jaxsim-0.2.dev364.dist-info/RECORD +64 -0
- {jaxsim-0.2.dev191.dist-info → jaxsim-0.2.dev364.dist-info}/WHEEL +1 -1
- jaxsim/high_level/__init__.py +0 -2
- jaxsim/high_level/common.py +0 -11
- jaxsim/high_level/joint.py +0 -148
- jaxsim/high_level/link.py +0 -259
- jaxsim/high_level/model.py +0 -1686
- jaxsim/math/conv.py +0 -114
- jaxsim/math/joint.py +0 -102
- jaxsim/math/plucker.py +0 -100
- jaxsim/physics/__init__.py +0 -12
- jaxsim/physics/algos/__init__.py +0 -0
- jaxsim/physics/algos/aba.py +0 -254
- jaxsim/physics/algos/aba_motors.py +0 -284
- jaxsim/physics/algos/forward_kinematics.py +0 -79
- jaxsim/physics/algos/jacobian.py +0 -98
- jaxsim/physics/algos/rnea.py +0 -180
- jaxsim/physics/algos/rnea_motors.py +0 -196
- jaxsim/physics/algos/soft_contacts.py +0 -523
- jaxsim/physics/algos/utils.py +0 -69
- jaxsim/physics/model/__init__.py +0 -0
- jaxsim/physics/model/ground_contact.py +0 -53
- jaxsim/physics/model/physics_model.py +0 -388
- jaxsim/physics/model/physics_model_state.py +0 -283
- jaxsim/simulation/__init__.py +0 -4
- jaxsim/simulation/integrators.py +0 -393
- jaxsim/simulation/ode.py +0 -290
- jaxsim/simulation/ode_data.py +0 -96
- jaxsim/simulation/ode_integration.py +0 -62
- jaxsim/simulation/simulator.py +0 -543
- jaxsim/simulation/simulator_callbacks.py +0 -79
- jaxsim/simulation/utils.py +0 -15
- jaxsim/sixd/__init__.py +0 -2
- jaxsim/utils/oop.py +0 -536
- jaxsim/utils/vmappable.py +0 -117
- jaxsim-0.2.dev191.dist-info/RECORD +0 -81
- /jaxsim/{physics/algos → terrain}/terrain.py +0 -0
- {jaxsim-0.2.dev191.dist-info → jaxsim-0.2.dev364.dist-info}/LICENSE +0 -0
- {jaxsim-0.2.dev191.dist-info → jaxsim-0.2.dev364.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,335 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
3
|
+
import functools
|
4
|
+
|
5
|
+
import jax
|
6
|
+
import jax.numpy as jnp
|
7
|
+
import jax_dataclasses
|
8
|
+
import jaxlie
|
9
|
+
from jax_dataclasses import Static
|
10
|
+
|
11
|
+
import jaxsim.typing as jtp
|
12
|
+
from jaxsim.parsers.descriptions import (
|
13
|
+
JointDescriptor,
|
14
|
+
JointGenericAxis,
|
15
|
+
JointType,
|
16
|
+
ModelDescription,
|
17
|
+
)
|
18
|
+
|
19
|
+
from .rotation import Rotation
|
20
|
+
|
21
|
+
|
22
|
+
@jax_dataclasses.pytree_dataclass
class JointModel:
    """
    Kinematic description of the joints of a robot model.

    The transform of each joint is split in three parts: a constant transform
    from the parent link to the predecessor frame, a position-dependent transform
    from the predecessor frame to the successor frame, and a constant transform
    from the successor frame to the child link.

    Attributes:
        λ_H_pre:
            The stacked parent-link-to-predecessor-frame homogeneous transforms,
            one 4x4 matrix per joint.
        suc_H_i:
            The stacked successor-frame-to-child-link homogeneous transforms,
            one 4x4 matrix per joint.
        joint_dofs: The number of DoFs of each joint.
        joint_names: The names of each joint.
        joint_types: The types of each joint.

    Note:
        The entry at index 0 of every attribute describes the dummy joint placed
        between the world and the base link.

    Note:
        Since some attributes are static, instances have to be created directly
        in vectorized form: they cannot be built through vmap.
    """

    λ_H_pre: jax.Array
    suc_H_i: jax.Array

    joint_dofs: Static[tuple[int, ...]]
    joint_names: Static[tuple[str, ...]]
    joint_types: Static[tuple[JointType | JointDescriptor, ...]]

    @staticmethod
    def build(description: ModelDescription) -> JointModel:
        """
        Create the joint model corresponding to a model description.

        Args:
            description: The model description to process.

        Returns:
            The joint model of the given model description.
        """

        # Links sorted by index: the first one is the base link.
        links_by_index = sorted(
            description.links_dict.values(), key=lambda link: link.index
        )

        # Joints sorted by index. A joint shares its index with its child link,
        # hence the first real joint has index 1.
        joints_by_index = sorted(
            description.joints_dict.values(), key=lambda joint: joint.index
        )

        # One slot per joint, plus the leading slot of the world-to-base joint.
        stack_shape = (1 + len(joints_by_index), 4, 4)

        # The world-to-base joint has an identity parent-to-predecessor transform.
        parent_to_pre = jnp.zeros(shape=stack_shape, dtype=float).at[0].set(jnp.eye(4))

        # The successor-to-child transform of the world-to-base joint stores the
        # optional pose of the base link w.r.t. the root frame of the model (it
        # matters only when the link pose w.r.t. the implicit __model__ SDF frame
        # is not the identity).
        suc_to_child = (
            jnp.zeros(shape=stack_shape, dtype=float)
            .at[0]
            .set(links_by_index[0].pose)
        )

        # Fill the slots of the real joints. Their indices start from 1, so the
        # entries at index 0 initialized above are left untouched.
        for joint in joints_by_index:
            parent_H_pre = description.relative_transform(
                relative_to=joint.parent.name,
                name=joint.name,
            )
            suc_H_child = description.relative_transform(
                relative_to=joint.name, name=joint.child.name
            )

            parent_to_pre = parent_to_pre.at[joint.index].set(parent_H_pre)
            suc_to_child = suc_to_child.at[joint.index].set(suc_H_child)

        # DoFs attributed to the world-to-base joint.
        world_to_base_dofs = 0 if description.fixed_base else 6

        # A dummy fixed joint between world and base is always prepended.
        # TODO: Port floating-base support also at this level, not only in RBDAs.
        return JointModel(
            λ_H_pre=parent_to_pre,
            suc_H_i=suc_to_child,
            # Static attributes
            joint_dofs=(world_to_base_dofs,) + (1,) * len(joints_by_index),
            joint_names=("world_to_base", *(j.name for j in joints_by_index)),
            joint_types=(JointType.F, *(j.jtype for j in joints_by_index)),
        )

    def parent_H_child(
        self, joint_index: jtp.IntLike, joint_position: jtp.VectorLike
    ) -> tuple[jtp.Matrix, jtp.Array]:
        r"""
        Compute the parent-link-to-child-link transform of a joint together with
        its motion subspace.

        Args:
            joint_index: The index of the joint.
            joint_position: The position of the joint.

        Returns:
            A tuple containing the homogeneous transformation
            :math:`{}^{\lambda(i)} \mathbf{H}_i(s)`
            and the motion subspace :math:`\mathbf{S}(s)`.
        """

        # Gather the three pieces of the joint model.
        parent_H_pre = self.parent_H_predecessor(joint_index=joint_index)
        pre_H_suc, motion_subspace = self.predecessor_H_successor(
            joint_index=joint_index, joint_position=joint_position
        )
        suc_H_child = self.successor_H_child(joint_index=joint_index)

        # Chain them from the parent link to the child link.
        parent_H_child = parent_H_pre @ pre_H_suc @ suc_H_child

        return parent_H_child, motion_subspace

    @jax.jit
    def child_H_parent(
        self, joint_index: jtp.IntLike, joint_position: jtp.VectorLike
    ) -> tuple[jtp.Matrix, jtp.Array]:
        r"""
        Compute the child-link-to-parent-link transform of a joint together with
        its motion subspace.

        Args:
            joint_index: The index of the joint.
            joint_position: The position of the joint.

        Returns:
            A tuple containing the homogeneous transformation
            :math:`{}^{i} \mathbf{H}_{\lambda(i)}(s)`
            and the motion subspace :math:`\mathbf{S}(s)`.
        """

        parent_H_child, motion_subspace = self.parent_H_child(
            joint_index=joint_index, joint_position=joint_position
        )

        # Invert on SE(3); the motion subspace is forwarded unchanged.
        parent_pose = jaxlie.SE3.from_matrix(parent_H_child)
        child_H_parent = parent_pose.inverse().as_matrix()

        return child_H_parent, motion_subspace

    def parent_H_predecessor(self, joint_index: jtp.IntLike) -> jtp.Matrix:
        r"""
        Get the constant transform from the parent link to the predecessor frame
        of a joint.

        Args:
            joint_index: The index of the joint.

        Returns:
            The homogeneous transformation
            :math:`{}^{\lambda(i)} \mathbf{H}_{\text{pre}(i)}`.
        """

        parent_to_pre = self.λ_H_pre
        return parent_to_pre[joint_index]

    def predecessor_H_successor(
        self, joint_index: jtp.IntLike, joint_position: jtp.VectorLike
    ) -> tuple[jtp.Matrix, jtp.Array]:
        r"""
        Compute the position-dependent transform from the predecessor frame to
        the successor frame of a joint together with its motion subspace.

        Args:
            joint_index: The index of the joint.
            joint_position: The position of the joint.

        Returns:
            A tuple containing the homogeneous transformation
            :math:`{}^{\text{pre}(i)} \mathbf{H}_{\text{suc}(i)}(s)`
            and the motion subspace :math:`\mathbf{S}(s)`.
        """

        # The joint type is a static attribute, looked up by index.
        selected_type = self.joint_types[joint_index]

        transform, motion_subspace = supported_joint_motion(
            joint_type=selected_type,
            joint_position=joint_position,
        )

        return transform, motion_subspace

    def successor_H_child(self, joint_index: jtp.IntLike) -> jtp.Matrix:
        r"""
        Get the constant transform from the successor frame to the child link
        of a joint.

        Args:
            joint_index: The index of the joint.

        Returns:
            The homogeneous transformation
            :math:`{}^{\text{suc}(i)} \mathbf{H}_i`.
        """

        suc_to_child = self.suc_H_i
        return suc_to_child[joint_index]
|
228
|
+
|
229
|
+
|
230
|
+
@functools.partial(jax.jit, static_argnames=["joint_type"])
def supported_joint_motion(
    joint_type: JointType | JointDescriptor, joint_position: jtp.VectorLike
) -> tuple[jtp.Matrix, jtp.Array]:
    """
    Compute the predecessor-to-successor transform and the motion subspace of
    a supported joint type.

    Args:
        joint_type:
            The type of the joint, either a plain type code or a joint
            descriptor exposing its type code.
        joint_position: The position of the joint.

    Returns:
        A tuple containing the 4x4 homogeneous transformation and the 6x1
        motion subspace of the joint.

    Raises:
        ValueError:
            If the joint type is neither a type code nor a joint descriptor,
            if it is a fixed joint, or if its code is not supported.
    """

    def unit_subspace(component: int) -> jtp.Array:
        # Column with a single unit entry at the given row.
        return jnp.vstack([1.0 if row == component else 0 for row in range(6)])

    # Resolve the type code of the joint.
    if not isinstance(joint_type, (JointType, JointDescriptor)):
        raise ValueError(joint_type)

    code = joint_type if isinstance(joint_type, JointType) else joint_type.code

    # Prepare the joint position
    s = jnp.array(joint_position).astype(float)

    if code == JointType.R:
        # Revolute joint about a generic axis: the descriptor is expected to
        # be a JointGenericAxis exposing the axis.
        axis = joint_type.axis

        transform = jaxlie.SE3.from_rotation(
            rotation=jaxlie.SO3.from_matrix(
                Rotation.from_axis_angle(vector=s * axis)
            )
        )
        subspace = jnp.vstack(jnp.hstack([jnp.zeros(3), axis.squeeze()]))

    elif code == JointType.P:
        # Prismatic joint along a generic axis: the descriptor is expected to
        # be a JointGenericAxis exposing the axis.
        axis = joint_type.axis

        transform = jaxlie.SE3.from_rotation_and_translation(
            rotation=jaxlie.SO3.identity(),
            translation=jnp.array(s * axis),
        )
        subspace = jnp.vstack(jnp.hstack([axis.squeeze(), jnp.zeros(3)]))

    elif code == JointType.F:
        raise ValueError("Fixed joints shouldn't be here")

    elif code == JointType.Rx:
        transform = jaxlie.SE3.from_rotation(
            rotation=jaxlie.SO3.from_x_radians(theta=s)
        )
        subspace = unit_subspace(3)

    elif code == JointType.Ry:
        transform = jaxlie.SE3.from_rotation(
            rotation=jaxlie.SO3.from_y_radians(theta=s)
        )
        subspace = unit_subspace(4)

    elif code == JointType.Rz:
        transform = jaxlie.SE3.from_rotation(
            rotation=jaxlie.SO3.from_z_radians(theta=s)
        )
        subspace = unit_subspace(5)

    elif code == JointType.Px:
        transform = jaxlie.SE3.from_rotation_and_translation(
            rotation=jaxlie.SO3.identity(),
            translation=jnp.array([s, 0.0, 0.0]),
        )
        subspace = unit_subspace(0)

    elif code == JointType.Py:
        transform = jaxlie.SE3.from_rotation_and_translation(
            rotation=jaxlie.SO3.identity(),
            translation=jnp.array([0.0, s, 0.0]),
        )
        subspace = unit_subspace(1)

    elif code == JointType.Pz:
        transform = jaxlie.SE3.from_rotation_and_translation(
            rotation=jaxlie.SO3.identity(),
            translation=jnp.array([0.0, 0.0, s]),
        )
        subspace = unit_subspace(2)

    else:
        raise ValueError(joint_type)

    return transform.as_matrix(), subspace
|
jaxsim/math/quaternion.py
CHANGED
@@ -1,8 +1,8 @@
|
|
1
1
|
import jax.lax
|
2
2
|
import jax.numpy as jnp
|
3
|
+
import jaxlie
|
3
4
|
|
4
5
|
import jaxsim.typing as jtp
|
5
|
-
from jaxsim.sixd import so3
|
6
6
|
|
7
7
|
|
8
8
|
class Quaternion:
|
@@ -43,7 +43,7 @@ class Quaternion:
|
|
43
43
|
Returns:
|
44
44
|
jtp.Matrix: Direction cosine matrix (DCM).
|
45
45
|
"""
|
46
|
-
return
|
46
|
+
return jaxlie.SO3.from_quaternion_xyzw(
|
47
47
|
xyzw=Quaternion.to_xyzw(quaternion)
|
48
48
|
).as_matrix()
|
49
49
|
|
@@ -59,7 +59,7 @@ class Quaternion:
|
|
59
59
|
jtp.Vector: Quaternion in XYZW representation.
|
60
60
|
"""
|
61
61
|
return Quaternion.to_wxyz(
|
62
|
-
xyzw=
|
62
|
+
xyzw=jaxlie.SO3.from_matrix(matrix=dcm).as_quaternion_xyzw()
|
63
63
|
)
|
64
64
|
|
65
65
|
@staticmethod
|
@@ -133,3 +133,44 @@ class Quaternion:
|
|
133
133
|
)
|
134
134
|
|
135
135
|
return jnp.vstack(qd)
|
136
|
+
|
137
|
+
@staticmethod
def integration(
    quaternion: jtp.VectorLike,
    dt: jtp.FloatLike,
    omega: jtp.VectorLike,
    omega_in_body_fixed: jtp.BoolLike = False,
) -> jtp.Vector:
    """
    Integrate a quaternion on SO(3) over a time step given an angular velocity.

    Args:
        quaternion: The quaternion to integrate.
        dt: The time step.
        omega: The angular velocity vector.
        omega_in_body_fixed:
            Whether the angular velocity is in body-fixed representation
            as opposed to the default inertial-fixed representation.

    Returns:
        The integrated quaternion.
    """

    angular_velocity = jnp.array(omega).squeeze().astype(float)
    wxyz_t0 = jnp.array(quaternion).squeeze().astype(float)

    # Initial orientation as a SO(3) object.
    rotation_t0 = jaxlie.SO3.from_quaternion_xyzw(
        xyzw=Quaternion.to_xyzw(wxyz=wxyz_t0)
    )

    # Rotation increment obtained from the exponential map.
    increment = jaxlie.SO3.exp(tangent=dt * angular_velocity)

    # A body-fixed velocity composes the increment on the right, an
    # inertial-fixed velocity composes it on the left.
    wxyz_body_fixed = Quaternion.to_wxyz(
        xyzw=(rotation_t0 @ increment).as_quaternion_xyzw()
    )
    wxyz_inertial_fixed = Quaternion.to_wxyz(
        xyzw=(increment @ rotation_t0).as_quaternion_xyzw()
    )

    # Both candidates are computed, the flag selects the one to return.
    return jax.lax.select(
        pred=omega_in_body_fixed,
        on_true=wxyz_body_fixed,
        on_false=wxyz_inertial_fixed,
    )
|
jaxsim/math/rotation.py
CHANGED
@@ -2,9 +2,9 @@ from typing import Tuple
|
|
2
2
|
|
3
3
|
import jax
|
4
4
|
import jax.numpy as jnp
|
5
|
+
import jaxlie
|
5
6
|
|
6
7
|
import jaxsim.typing as jtp
|
7
|
-
from jaxsim.sixd import so3
|
8
8
|
|
9
9
|
from .skew import Skew
|
10
10
|
|
@@ -21,7 +21,7 @@ class Rotation:
|
|
21
21
|
Returns:
|
22
22
|
jtp.Matrix: 3D rotation matrix.
|
23
23
|
"""
|
24
|
-
return
|
24
|
+
return jaxlie.SO3.from_x_radians(theta=theta).as_matrix()
|
25
25
|
|
26
26
|
@staticmethod
|
27
27
|
def y(theta: jtp.Float) -> jtp.Matrix:
|
@@ -34,7 +34,7 @@ class Rotation:
|
|
34
34
|
Returns:
|
35
35
|
jtp.Matrix: 3D rotation matrix.
|
36
36
|
"""
|
37
|
-
return
|
37
|
+
return jaxlie.SO3.from_y_radians(theta=theta).as_matrix()
|
38
38
|
|
39
39
|
@staticmethod
|
40
40
|
def z(theta: jtp.Float) -> jtp.Matrix:
|
@@ -47,7 +47,7 @@ class Rotation:
|
|
47
47
|
Returns:
|
48
48
|
jtp.Matrix: 3D rotation matrix.
|
49
49
|
"""
|
50
|
-
return
|
50
|
+
return jaxlie.SO3.from_z_radians(theta=theta).as_matrix()
|
51
51
|
|
52
52
|
@staticmethod
|
53
53
|
def from_axis_angle(vector: jtp.Vector) -> jtp.Matrix:
|
jaxsim/math/transform.py
ADDED
@@ -0,0 +1,93 @@
|
|
1
|
+
import jax
|
2
|
+
import jax.numpy as jnp
|
3
|
+
import jaxlie
|
4
|
+
|
5
|
+
import jaxsim.typing as jtp
|
6
|
+
|
7
|
+
from .quaternion import Quaternion
|
8
|
+
|
9
|
+
|
10
|
+
class Transform:
    """
    Helpers to build and invert 4x4 homogeneous transformation matrices.

    All the methods delegate the SE(3) operations to jaxlie and return plain
    matrices.
    """

    @staticmethod
    def from_quaternion_and_translation(
        quaternion: jtp.VectorLike = jnp.array([1.0, 0, 0, 0]),
        translation: jtp.VectorLike = jnp.zeros(3),
        inverse: jtp.BoolLike = False,
        normalize_quaternion: jtp.BoolLike = False,
    ) -> jtp.Matrix:
        """
        Create a transformation matrix from a quaternion and a translation.

        Args:
            quaternion:
                A 4D vector representing a SO(3) orientation. It is converted
                with `Quaternion.to_xyzw` before being passed to jaxlie.
            translation: A 3D vector representing a translation.
            inverse: Whether to compute the inverse transformation.
            normalize_quaternion:
                Whether to normalize the quaternion before creating the transformation.

        Returns:
            The 4x4 transformation matrix representing the SE(3) transformation.
        """

        W_Q_B = jnp.array(quaternion).astype(float)
        W_p_B = jnp.array(translation).astype(float)

        assert W_p_B.size == 3
        assert W_Q_B.size == 4

        A_R_B = jaxlie.SO3.from_quaternion_xyzw(xyzw=Quaternion.to_xyzw(W_Q_B))

        if normalize_quaternion:
            A_R_B = A_R_B.normalize()

        A_H_B = jaxlie.SE3.from_rotation_and_translation(
            rotation=A_R_B, translation=W_p_B
        )

        if inverse:
            A_H_B = A_H_B.inverse()

        return A_H_B.as_matrix()

    @staticmethod
    def from_rotation_and_translation(
        rotation: jtp.MatrixLike,
        translation: jtp.VectorLike,
        inverse: jtp.BoolLike = False,
    ) -> jtp.Matrix:
        """
        Create a transformation matrix from a rotation matrix and a translation vector.

        Args:
            rotation: A 3x3 rotation matrix representing a SO(3) orientation.
            translation: A 3D vector representing a translation.
            inverse: Whether to compute the inverse transformation.

        Returns:
            The 4x4 transformation matrix representing the SE(3) transformation.
        """

        A_R_B = jnp.array(rotation).astype(float)
        W_p_B = jnp.array(translation).astype(float)

        assert W_p_B.size == 3
        assert A_R_B.shape == (3, 3)

        # jaxlie expects the rotation as a SO3 object and not as a raw 3x3
        # matrix, therefore the matrix is converted before building the SE3.
        A_H_B = jaxlie.SE3.from_rotation_and_translation(
            rotation=jaxlie.SO3.from_matrix(A_R_B), translation=W_p_B
        )

        if inverse:
            A_H_B = A_H_B.inverse()

        return A_H_B.as_matrix()

    @staticmethod
    def inverse(transform: jtp.MatrixLike) -> jtp.Matrix:
        """
        Compute the inverse transformation matrix.

        Args:
            transform: A 4x4 transformation matrix.

        Returns:
            The 4x4 inverse transformation matrix.
        """

        A_H_B = jnp.array(transform).astype(float)
        assert A_H_B.shape == (4, 4)

        # The inversion is performed on the SE(3) object built from the matrix.
        return jaxlie.SE3.from_matrix(matrix=A_H_B).inverse().as_matrix()
|
@@ -3,10 +3,10 @@ from typing import List
|
|
3
3
|
|
4
4
|
import jax.numpy as jnp
|
5
5
|
import jax_dataclasses
|
6
|
+
import jaxlie
|
6
7
|
from jax_dataclasses import Static
|
7
8
|
|
8
9
|
import jaxsim.typing as jtp
|
9
|
-
from jaxsim.sixd import se3
|
10
10
|
from jaxsim.utils import JaxsimDataclass
|
11
11
|
|
12
12
|
|
@@ -78,7 +78,7 @@ class LinkDescription(JaxsimDataclass):
|
|
78
78
|
I_removed = link.inertia
|
79
79
|
|
80
80
|
# Create the SE3 object. Note the inverse.
|
81
|
-
r_H_l =
|
81
|
+
r_H_l = jaxlie.SE3.from_matrix(lumped_H_removed).inverse()
|
82
82
|
r_X_l = r_H_l.adjoint()
|
83
83
|
|
84
84
|
# Move the inertia
|
jaxsim/parsers/rod/utils.py
CHANGED
@@ -1,15 +1,17 @@
|
|
1
1
|
import os
|
2
2
|
from typing import Union
|
3
3
|
|
4
|
-
import
|
4
|
+
import jaxlie
|
5
5
|
import numpy as np
|
6
6
|
import numpy.typing as npt
|
7
7
|
import rod
|
8
8
|
|
9
|
+
import jaxsim.typing as jtp
|
10
|
+
from jaxsim.math.inertia import Inertia
|
9
11
|
from jaxsim.parsers import descriptions
|
10
12
|
|
11
13
|
|
12
|
-
def from_sdf_inertial(inertial: rod.Inertial) ->
|
14
|
+
def from_sdf_inertial(inertial: rod.Inertial) -> jtp.Matrix:
|
13
15
|
"""
|
14
16
|
Extract the 6D inertia matrix from an SDF inertial element.
|
15
17
|
|
@@ -20,9 +22,6 @@ def from_sdf_inertial(inertial: rod.Inertial) -> npt.NDArray:
|
|
20
22
|
The 6D inertia matrix of the link expressed in the link frame.
|
21
23
|
"""
|
22
24
|
|
23
|
-
from jaxsim.math.inertia import Inertia
|
24
|
-
from jaxsim.sixd import se3
|
25
|
-
|
26
25
|
# Extract the "mass" element
|
27
26
|
m = inertial.mass
|
28
27
|
|
@@ -52,13 +51,13 @@ def from_sdf_inertial(inertial: rod.Inertial) -> npt.NDArray:
|
|
52
51
|
L_H_CoM = inertial.pose.transform() if inertial.pose is not None else np.eye(4)
|
53
52
|
|
54
53
|
# We need its inverse
|
55
|
-
CoM_H_L =
|
56
|
-
CoM_X_L
|
54
|
+
CoM_H_L = jaxlie.SE3.from_matrix(matrix=L_H_CoM).inverse()
|
55
|
+
CoM_X_L = CoM_H_L.adjoint()
|
57
56
|
|
58
57
|
# Express the CoM inertia matrix in the link frame L
|
59
58
|
M_L = CoM_X_L.T @ M_CoM @ CoM_X_L
|
60
59
|
|
61
|
-
return
|
60
|
+
return M_L.astype(dtype=float)
|
62
61
|
|
63
62
|
|
64
63
|
def axis_to_jtype(
|
jaxsim/rbda/__init__.py
ADDED
@@ -0,0 +1,7 @@
|
|
1
|
+
from .aba import aba
|
2
|
+
from .collidable_points import collidable_points_pos_vel
|
3
|
+
from .crba import crba
|
4
|
+
from .forward_kinematics import forward_kinematics, forward_kinematics_model
|
5
|
+
from .jacobian import jacobian, jacobian_full_doubly_left
|
6
|
+
from .rnea import rnea
|
7
|
+
from .soft_contacts import SoftContacts, SoftContactsParams
|