jaxsim 0.2.dev191__py3-none-any.whl → 0.6.1.dev2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jaxsim/__init__.py +73 -22
- jaxsim/_version.py +2 -2
- jaxsim/api/__init__.py +13 -1
- jaxsim/api/com.py +423 -0
- jaxsim/api/common.py +48 -19
- jaxsim/api/contact.py +604 -52
- jaxsim/api/data.py +308 -163
- jaxsim/api/frame.py +471 -0
- jaxsim/api/joint.py +166 -37
- jaxsim/api/kin_dyn_parameters.py +901 -0
- jaxsim/api/link.py +277 -78
- jaxsim/api/model.py +1572 -362
- jaxsim/api/ode.py +324 -133
- jaxsim/api/ode_data.py +401 -0
- jaxsim/api/references.py +216 -80
- jaxsim/exceptions.py +80 -0
- jaxsim/integrators/__init__.py +2 -2
- jaxsim/integrators/common.py +191 -107
- jaxsim/integrators/fixed_step.py +97 -102
- jaxsim/integrators/variable_step.py +706 -0
- jaxsim/logging.py +1 -2
- jaxsim/math/__init__.py +13 -0
- jaxsim/math/adjoint.py +57 -22
- jaxsim/math/cross.py +16 -7
- jaxsim/math/inertia.py +10 -8
- jaxsim/math/joint_model.py +289 -0
- jaxsim/math/quaternion.py +54 -20
- jaxsim/math/rotation.py +27 -21
- jaxsim/math/skew.py +16 -5
- jaxsim/math/transform.py +102 -0
- jaxsim/math/utils.py +31 -0
- jaxsim/mujoco/__init__.py +2 -1
- jaxsim/mujoco/loaders.py +216 -29
- jaxsim/mujoco/model.py +163 -33
- jaxsim/mujoco/utils.py +228 -0
- jaxsim/mujoco/visualizer.py +107 -22
- jaxsim/parsers/__init__.py +0 -1
- jaxsim/parsers/descriptions/__init__.py +8 -2
- jaxsim/parsers/descriptions/collision.py +83 -26
- jaxsim/parsers/descriptions/joint.py +80 -87
- jaxsim/parsers/descriptions/link.py +58 -31
- jaxsim/parsers/descriptions/model.py +101 -68
- jaxsim/parsers/kinematic_graph.py +606 -229
- jaxsim/parsers/rod/meshes.py +104 -0
- jaxsim/parsers/rod/parser.py +125 -82
- jaxsim/parsers/rod/utils.py +127 -82
- jaxsim/rbda/__init__.py +11 -0
- jaxsim/rbda/aba.py +289 -0
- jaxsim/rbda/collidable_points.py +156 -0
- jaxsim/rbda/contacts/__init__.py +13 -0
- jaxsim/rbda/contacts/common.py +313 -0
- jaxsim/rbda/contacts/relaxed_rigid.py +605 -0
- jaxsim/rbda/contacts/rigid.py +462 -0
- jaxsim/rbda/contacts/soft.py +480 -0
- jaxsim/rbda/contacts/visco_elastic.py +1066 -0
- jaxsim/rbda/crba.py +167 -0
- jaxsim/rbda/forward_kinematics.py +117 -0
- jaxsim/rbda/jacobian.py +330 -0
- jaxsim/rbda/rnea.py +235 -0
- jaxsim/rbda/utils.py +160 -0
- jaxsim/terrain/__init__.py +2 -0
- jaxsim/terrain/terrain.py +238 -0
- jaxsim/typing.py +24 -24
- jaxsim/utils/__init__.py +1 -4
- jaxsim/utils/jaxsim_dataclass.py +289 -34
- jaxsim/utils/tracing.py +5 -11
- jaxsim/utils/wrappers.py +159 -0
- {jaxsim-0.2.dev191.dist-info → jaxsim-0.6.1.dev2.dist-info}/LICENSE +1 -1
- jaxsim-0.6.1.dev2.dist-info/METADATA +465 -0
- jaxsim-0.6.1.dev2.dist-info/RECORD +74 -0
- {jaxsim-0.2.dev191.dist-info → jaxsim-0.6.1.dev2.dist-info}/WHEEL +1 -1
- jaxsim/high_level/__init__.py +0 -2
- jaxsim/high_level/common.py +0 -11
- jaxsim/high_level/joint.py +0 -148
- jaxsim/high_level/link.py +0 -259
- jaxsim/high_level/model.py +0 -1686
- jaxsim/math/conv.py +0 -114
- jaxsim/math/joint.py +0 -102
- jaxsim/math/plucker.py +0 -100
- jaxsim/physics/__init__.py +0 -12
- jaxsim/physics/algos/__init__.py +0 -0
- jaxsim/physics/algos/aba.py +0 -254
- jaxsim/physics/algos/aba_motors.py +0 -284
- jaxsim/physics/algos/crba.py +0 -154
- jaxsim/physics/algos/forward_kinematics.py +0 -79
- jaxsim/physics/algos/jacobian.py +0 -98
- jaxsim/physics/algos/rnea.py +0 -180
- jaxsim/physics/algos/rnea_motors.py +0 -196
- jaxsim/physics/algos/soft_contacts.py +0 -523
- jaxsim/physics/algos/terrain.py +0 -78
- jaxsim/physics/algos/utils.py +0 -69
- jaxsim/physics/model/__init__.py +0 -0
- jaxsim/physics/model/ground_contact.py +0 -53
- jaxsim/physics/model/physics_model.py +0 -388
- jaxsim/physics/model/physics_model_state.py +0 -283
- jaxsim/simulation/__init__.py +0 -4
- jaxsim/simulation/integrators.py +0 -393
- jaxsim/simulation/ode.py +0 -290
- jaxsim/simulation/ode_data.py +0 -96
- jaxsim/simulation/ode_integration.py +0 -62
- jaxsim/simulation/simulator.py +0 -543
- jaxsim/simulation/simulator_callbacks.py +0 -79
- jaxsim/simulation/utils.py +0 -15
- jaxsim/sixd/__init__.py +0 -2
- jaxsim/utils/oop.py +0 -536
- jaxsim/utils/vmappable.py +0 -117
- jaxsim-0.2.dev191.dist-info/METADATA +0 -184
- jaxsim-0.2.dev191.dist-info/RECORD +0 -81
- {jaxsim-0.2.dev191.dist-info → jaxsim-0.6.1.dev2.dist-info}/top_level.txt +0 -0
jaxsim/logging.py
CHANGED
@@ -1,6 +1,5 @@
|
|
1
1
|
import enum
|
2
2
|
import logging
|
3
|
-
from typing import Union
|
4
3
|
|
5
4
|
import coloredlogs
|
6
5
|
|
@@ -20,7 +19,7 @@ def _logger() -> logging.Logger:
|
|
20
19
|
return logging.getLogger(name=LOGGER_NAME)
|
21
20
|
|
22
21
|
|
23
|
-
def set_logging_level(level:
|
22
|
+
def set_logging_level(level: int | LoggingLevel = LoggingLevel.WARNING):
|
24
23
|
if isinstance(level, int):
|
25
24
|
level = LoggingLevel(level)
|
26
25
|
|
jaxsim/math/__init__.py
CHANGED
@@ -0,0 +1,13 @@
|
|
1
|
+
# Define the default standard gravity constant.
|
2
|
+
StandardGravity = 9.81
|
3
|
+
|
4
|
+
from .adjoint import Adjoint
|
5
|
+
from .cross import Cross
|
6
|
+
from .inertia import Inertia
|
7
|
+
from .quaternion import Quaternion
|
8
|
+
from .rotation import Rotation
|
9
|
+
from .skew import Skew
|
10
|
+
from .transform import Transform
|
11
|
+
from .utils import safe_norm
|
12
|
+
|
13
|
+
from .joint_model import JointModel, supported_joint_motion # isort:skip
|
jaxsim/math/adjoint.py
CHANGED
@@ -1,17 +1,20 @@
|
|
1
1
|
import jax.numpy as jnp
|
2
|
+
import jaxlie
|
2
3
|
|
3
4
|
import jaxsim.typing as jtp
|
4
|
-
from jaxsim.sixd import so3
|
5
5
|
|
6
|
-
from .quaternion import Quaternion
|
7
6
|
from .skew import Skew
|
8
7
|
|
9
8
|
|
10
9
|
class Adjoint:
|
10
|
+
"""
|
11
|
+
A utility class for adjoint matrix operations.
|
12
|
+
"""
|
13
|
+
|
11
14
|
@staticmethod
|
12
15
|
def from_quaternion_and_translation(
|
13
|
-
quaternion: jtp.Vector
|
14
|
-
translation: jtp.Vector =
|
16
|
+
quaternion: jtp.Vector | None = None,
|
17
|
+
translation: jtp.Vector | None = None,
|
15
18
|
inverse: bool = False,
|
16
19
|
normalize_quaternion: bool = False,
|
17
20
|
) -> jtp.Matrix:
|
@@ -19,8 +22,8 @@ class Adjoint:
|
|
19
22
|
Create an adjoint matrix from a quaternion and a translation.
|
20
23
|
|
21
24
|
Args:
|
22
|
-
quaternion (jtp.Vector): A quaternion vector (4D) representing orientation.
|
23
|
-
translation (jtp.Vector): A translation vector (3D).
|
25
|
+
quaternion (jtp.Vector): A quaternion vector (4D) representing orientation. Default is [1, 0, 0, 0].
|
26
|
+
translation (jtp.Vector): A translation vector (3D). Default is [0, 0, 0].
|
24
27
|
inverse (bool): Whether to compute the inverse adjoint. Default is False.
|
25
28
|
normalize_quaternion (bool): Whether to normalize the quaternion before creating the adjoint.
|
26
29
|
Default is False.
|
@@ -28,33 +31,59 @@ class Adjoint:
|
|
28
31
|
Returns:
|
29
32
|
jtp.Matrix: The adjoint matrix.
|
30
33
|
"""
|
34
|
+
quaternion = quaternion if quaternion is not None else jnp.array([1.0, 0, 0, 0])
|
35
|
+
translation = translation if translation is not None else jnp.zeros(3)
|
31
36
|
assert quaternion.size == 4
|
32
37
|
assert translation.size == 3
|
33
38
|
|
34
|
-
Q_sixd =
|
39
|
+
Q_sixd = jaxlie.SO3(wxyz=quaternion)
|
35
40
|
Q_sixd = Q_sixd if not normalize_quaternion else Q_sixd.normalize()
|
36
41
|
|
37
42
|
return Adjoint.from_rotation_and_translation(
|
38
43
|
rotation=Q_sixd.as_matrix(), translation=translation, inverse=inverse
|
39
44
|
)
|
40
45
|
|
46
|
+
@staticmethod
def from_transform(transform: jtp.MatrixLike, inverse: bool = False) -> jtp.Matrix:
    """
    Build the adjoint matrix associated with a homogeneous transformation.

    Args:
        transform:
            A 4x4 transformation matrix. Leading batch dimensions are
            flattened internally and restored on the returned value.
        inverse: Whether to return the adjoint of the inverse transform.

    Returns:
        The 6x6 adjoint matrix, with shape ``transform.shape[:-2] + (6, 6)``.
    """

    # Collapse any leading batch dimensions into a stack of 4x4 matrices.
    stacked_transforms = jnp.reshape(transform, (-1, 4, 4))
    se3 = jaxlie.SE3.from_matrix(matrix=stacked_transforms)

    # Optionally switch to the inverse transform before taking the adjoint.
    if inverse:
        se3 = se3.inverse()

    # Restore the batch dimensions of the input on the 6x6 result.
    output_shape = transform.shape[:-2] + (6, 6)
    return se3.adjoint().reshape(output_shape)
|
66
|
+
|
41
67
|
@staticmethod
|
42
68
|
def from_rotation_and_translation(
|
43
|
-
rotation: jtp.Matrix =
|
44
|
-
translation: jtp.Vector =
|
69
|
+
rotation: jtp.Matrix | None = None,
|
70
|
+
translation: jtp.Vector | None = None,
|
45
71
|
inverse: bool = False,
|
46
72
|
) -> jtp.Matrix:
|
47
73
|
"""
|
48
74
|
Create an adjoint matrix from a rotation matrix and a translation vector.
|
49
75
|
|
50
76
|
Args:
|
51
|
-
rotation (jtp.Matrix): A 3x3 rotation matrix.
|
52
|
-
translation (jtp.Vector): A translation vector (3D).
|
77
|
+
rotation (jtp.Matrix): A 3x3 rotation matrix. Default is identity.
|
78
|
+
translation (jtp.Vector): A translation vector (3D). Default is [0, 0, 0].
|
53
79
|
inverse (bool): Whether to compute the inverse adjoint. Default is False.
|
54
80
|
|
55
81
|
Returns:
|
56
82
|
jtp.Matrix: The adjoint matrix.
|
57
83
|
"""
|
84
|
+
rotation = rotation if rotation is not None else jnp.eye(3)
|
85
|
+
translation = translation if translation is not None else jnp.zeros(3)
|
86
|
+
|
58
87
|
assert rotation.shape == (3, 3)
|
59
88
|
assert translation.size == 3
|
60
89
|
|
@@ -62,14 +91,14 @@ class Adjoint:
|
|
62
91
|
A_o_B = translation.squeeze()
|
63
92
|
|
64
93
|
if not inverse:
|
65
|
-
X = A_X_B = jnp.vstack(
|
94
|
+
X = A_X_B = jnp.vstack( # noqa: F841
|
66
95
|
[
|
67
96
|
jnp.block([A_R_B, Skew.wedge(A_o_B) @ A_R_B]),
|
68
97
|
jnp.block([jnp.zeros(shape=(3, 3)), A_R_B]),
|
69
98
|
]
|
70
99
|
)
|
71
100
|
else:
|
72
|
-
X = B_X_A = jnp.vstack(
|
101
|
+
X = B_X_A = jnp.vstack( # noqa: F841
|
73
102
|
[
|
74
103
|
jnp.block([A_R_B.T, -A_R_B.T @ Skew.wedge(A_o_B)]),
|
75
104
|
jnp.block([jnp.zeros(shape=(3, 3)), A_R_B.T]),
|
@@ -84,7 +113,7 @@ class Adjoint:
|
|
84
113
|
Convert an adjoint matrix to a transformation matrix.
|
85
114
|
|
86
115
|
Args:
|
87
|
-
adjoint
|
116
|
+
adjoint: The adjoint matrix (6x6).
|
88
117
|
|
89
118
|
Returns:
|
90
119
|
jtp.Matrix: The transformation matrix (4x4).
|
@@ -110,17 +139,23 @@ class Adjoint:
|
|
110
139
|
Compute the inverse of an adjoint matrix.
|
111
140
|
|
112
141
|
Args:
|
113
|
-
adjoint
|
142
|
+
adjoint: The adjoint matrix.
|
114
143
|
|
115
144
|
Returns:
|
116
145
|
jtp.Matrix: The inverse adjoint matrix.
|
117
146
|
"""
|
118
|
-
A_X_B = adjoint
|
119
|
-
A_H_B = Adjoint.to_transform(adjoint=A_X_B)
|
147
|
+
A_X_B = adjoint.reshape(-1, 6, 6)
|
120
148
|
|
121
|
-
|
122
|
-
|
149
|
+
A_R_B_T = jnp.swapaxes(A_X_B[..., 0:3, 0:3], -2, -1)
|
150
|
+
A_T_B = A_X_B[..., 0:3, 3:6]
|
123
151
|
|
124
|
-
return
|
125
|
-
|
126
|
-
|
152
|
+
return jnp.concatenate(
|
153
|
+
[
|
154
|
+
jnp.concatenate(
|
155
|
+
[A_R_B_T, -A_R_B_T @ A_T_B @ A_R_B_T],
|
156
|
+
axis=-1,
|
157
|
+
),
|
158
|
+
jnp.concatenate([jnp.zeros_like(A_R_B_T), A_R_B_T], axis=-1),
|
159
|
+
],
|
160
|
+
axis=-2,
|
161
|
+
).reshape(adjoint.shape)
|
jaxsim/math/cross.py
CHANGED
@@ -6,13 +6,17 @@ from .skew import Skew
|
|
6
6
|
|
7
7
|
|
8
8
|
class Cross:
|
9
|
+
"""
|
10
|
+
A utility class for cross product matrix operations.
|
11
|
+
"""
|
12
|
+
|
9
13
|
@staticmethod
|
10
14
|
def vx(velocity_sixd: jtp.Vector) -> jtp.Matrix:
|
11
15
|
"""
|
12
16
|
Compute the cross product matrix for 6D velocities.
|
13
17
|
|
14
18
|
Args:
|
15
|
-
velocity_sixd
|
19
|
+
velocity_sixd: A 6D velocity vector [v, ω].
|
16
20
|
|
17
21
|
Returns:
|
18
22
|
jtp.Matrix: The cross product matrix (6x6).
|
@@ -20,13 +24,18 @@ class Cross:
|
|
20
24
|
Raises:
|
21
25
|
ValueError: If the input vector does not have a size of 6.
|
22
26
|
"""
|
23
|
-
|
27
|
+
velocity_sixd = velocity_sixd.reshape(-1, 6)
|
28
|
+
|
29
|
+
v, ω = jnp.split(velocity_sixd, 2, axis=-1)
|
24
30
|
|
25
|
-
v_cross = jnp.
|
31
|
+
v_cross = jnp.concatenate(
|
26
32
|
[
|
27
|
-
jnp.
|
28
|
-
|
29
|
-
|
33
|
+
jnp.concatenate(
|
34
|
+
[Skew.wedge(ω), jnp.zeros((ω.shape[0], 3, 3)).squeeze()], axis=-2
|
35
|
+
),
|
36
|
+
jnp.concatenate([Skew.wedge(v), Skew.wedge(ω)], axis=-2),
|
37
|
+
],
|
38
|
+
axis=-1,
|
30
39
|
)
|
31
40
|
|
32
41
|
return v_cross
|
@@ -37,7 +46,7 @@ class Cross:
|
|
37
46
|
Compute the negative transpose of the cross product matrix for 6D velocities.
|
38
47
|
|
39
48
|
Args:
|
40
|
-
velocity_sixd
|
49
|
+
velocity_sixd: A 6D velocity vector [v, ω].
|
41
50
|
|
42
51
|
Returns:
|
43
52
|
jtp.Matrix: The negative transpose of the cross product matrix (6x6).
|
jaxsim/math/inertia.py
CHANGED
@@ -1,5 +1,3 @@
|
|
1
|
-
from typing import Tuple
|
2
|
-
|
3
1
|
import jax.numpy as jnp
|
4
2
|
|
5
3
|
import jaxsim.typing as jtp
|
@@ -8,15 +6,19 @@ from .skew import Skew
|
|
8
6
|
|
9
7
|
|
10
8
|
class Inertia:
|
9
|
+
"""
|
10
|
+
A utility class for inertia matrix operations.
|
11
|
+
"""
|
12
|
+
|
11
13
|
@staticmethod
|
12
14
|
def to_sixd(mass: jtp.Float, com: jtp.Vector, I: jtp.Matrix) -> jtp.Matrix:
|
13
15
|
"""
|
14
16
|
Convert mass, center of mass, and inertia matrix to a 6x6 inertia matrix.
|
15
17
|
|
16
18
|
Args:
|
17
|
-
mass
|
18
|
-
com
|
19
|
-
I
|
19
|
+
mass: The mass of the body.
|
20
|
+
com: The center of mass position (3D).
|
21
|
+
I: The 3x3 inertia matrix.
|
20
22
|
|
21
23
|
Returns:
|
22
24
|
jtp.Matrix: The 6x6 inertia matrix.
|
@@ -39,15 +41,15 @@ class Inertia:
|
|
39
41
|
return M
|
40
42
|
|
41
43
|
@staticmethod
|
42
|
-
def to_params(M: jtp.Matrix) ->
|
44
|
+
def to_params(M: jtp.Matrix) -> tuple[jtp.Float, jtp.Vector, jtp.Matrix]:
|
43
45
|
"""
|
44
46
|
Convert a 6x6 inertia matrix to mass, center of mass, and inertia matrix.
|
45
47
|
|
46
48
|
Args:
|
47
|
-
M
|
49
|
+
M: The 6x6 inertia matrix.
|
48
50
|
|
49
51
|
Returns:
|
50
|
-
|
52
|
+
tuple[jtp.Float, jtp.Vector, jtp.Matrix]: A tuple containing mass, center of mass (3D), and inertia matrix (3x3).
|
51
53
|
|
52
54
|
Raises:
|
53
55
|
ValueError: If the input matrix M has an unexpected shape.
|
@@ -0,0 +1,289 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
3
|
+
import jax
|
4
|
+
import jax.numpy as jnp
|
5
|
+
import jax_dataclasses
|
6
|
+
import jaxlie
|
7
|
+
from jax_dataclasses import Static
|
8
|
+
|
9
|
+
import jaxsim.typing as jtp
|
10
|
+
from jaxsim.parsers.descriptions import JointGenericAxis, JointType, ModelDescription
|
11
|
+
from jaxsim.parsers.kinematic_graph import KinematicGraphTransforms
|
12
|
+
|
13
|
+
from .rotation import Rotation
|
14
|
+
from .transform import Transform
|
15
|
+
|
16
|
+
|
17
|
+
@jax_dataclasses.pytree_dataclass
class JointModel:
    """
    Class describing the joint kinematics of a robot model.

    Attributes:
        λ_H_pre:
            The homogeneous transformation between the parent link and
            the predecessor frame of each joint.
        suc_H_i:
            The homogeneous transformation between the successor frame and
            the child link of each joint.
        joint_dofs: The number of DoFs of each joint.
        joint_names: The names of each joint.
        joint_types: The types of each joint.
        joint_axis:
            The axis of each joint of the model description. Unlike the other
            attributes, it has no leading entry for the dummy world-to-base
            joint, so the axis of joint ``i >= 1`` is stored at ``i - 1``.

    Note:
        Due to the presence of the static attributes, this class needs to be created
        already in a vectorized form. In other words, it cannot be created using vmap.
    """

    λ_H_pre: jtp.Array
    suc_H_i: jtp.Array

    joint_dofs: Static[tuple[int, ...]]
    joint_names: Static[tuple[str, ...]]
    joint_types: Static[tuple[int, ...]]
    joint_axis: Static[tuple[JointGenericAxis, ...]]

    @staticmethod
    def build(description: ModelDescription) -> JointModel:
        """
        Build the joint model of a model description.

        Args:
            description: The model description to consider.

        Returns:
            The joint model of the considered model description.
        """

        # The link index is equal to its body index: [0, number_of_bodies - 1].
        ordered_links = sorted(
            description.links_dict.values(),
            key=lambda link: link.index,
        )

        # Note: the joint index is equal to its child link index, therefore it
        # starts from 1.
        ordered_joints = sorted(
            description.joints_dict.values(),
            key=lambda joint: joint.index,
        )

        # Allocate the parent-to-predecessor and successor-to-child transforms.
        # The extra leading entry is the dummy joint between world and base.
        λ_H_pre = jnp.zeros(shape=(1 + len(ordered_joints), 4, 4), dtype=float)
        suc_H_i = jnp.zeros(shape=(1 + len(ordered_joints), 4, 4), dtype=float)

        # Initialize an identity parent-to-predecessor transform for the joint
        # between the world frame W and the base link B.
        λ_H_pre = λ_H_pre.at[0].set(jnp.eye(4))

        # Initialize the successor-to-child transform of the joint between the
        # world frame W and the base link B.
        # We store here the optional transform between the root frame of the model
        # and the base link frame (this is needed only if the pose of the link frame
        # w.r.t. the implicit __model__ SDF frame is not the identity).
        suc_H_i = suc_H_i.at[0].set(ordered_links[0].pose)

        # Create the object to compute forward kinematics.
        fk = KinematicGraphTransforms(graph=description)

        # Compute the parent-to-predecessor and successor-to-child transforms for
        # each joint belonging to the model.
        # Note that the joint indices start from i=1 given our joint model,
        # therefore the entries at index 0 are not updated.
        for joint in ordered_joints:
            λ_H_pre = λ_H_pre.at[joint.index].set(
                fk.relative_transform(relative_to=joint.parent.name, name=joint.name)
            )
            suc_H_i = suc_H_i.at[joint.index].set(
                fk.relative_transform(relative_to=joint.name, name=joint.child.name)
            )

        # Define the DoFs of the base link.
        base_dofs = 0 if description.fixed_base else 6

        # We always add a dummy fixed joint between world and base.
        # TODO: Port floating-base support also at this level, not only in RBDAs.
        return JointModel(
            λ_H_pre=λ_H_pre,
            suc_H_i=suc_H_i,
            # Static attributes
            joint_dofs=tuple([base_dofs] + [1 for _ in ordered_joints]),
            joint_names=tuple(["world_to_base"] + [j.name for j in ordered_joints]),
            joint_types=tuple([JointType.Fixed] + [j.jtype for j in ordered_joints]),
            # No entry is prepended here for the world-to-base joint.
            joint_axis=tuple(JointGenericAxis(axis=j.axis) for j in ordered_joints),
        )

    def parent_H_child(
        self, joint_index: jtp.IntLike, joint_position: jtp.VectorLike
    ) -> tuple[jtp.Matrix, jtp.Array]:
        r"""
        Compute the homogeneous transformation between the parent link and
        the child link of a joint, and the corresponding motion subspace.

        Args:
            joint_index: The index of the joint.
            joint_position: The position of the joint.

        Returns:
            A tuple containing the homogeneous transformation
            :math:`{}^{\lambda(i)} \mathbf{H}_i(s)`
            and the motion subspace :math:`\mathbf{S}(s)`.
        """

        i = joint_index
        s = joint_position

        # Get the components of the joint model.
        λ_Hi_pre = self.parent_H_predecessor(joint_index=i)
        pre_Hi_suc, S = self.predecessor_H_successor(joint_index=i, joint_position=s)
        suc_Hi_i = self.successor_H_child(joint_index=i)

        # Compose all the transforms.
        return λ_Hi_pre @ pre_Hi_suc @ suc_Hi_i, S

    # NOTE(review): under `jax.jit` the `joint_index` argument is traced, while
    # `predecessor_H_successor` uses it to index static Python tuples. Confirm
    # whether `joint_index` should be marked as a static argument here.
    @jax.jit
    def child_H_parent(
        self, joint_index: jtp.IntLike, joint_position: jtp.VectorLike
    ) -> tuple[jtp.Matrix, jtp.Array]:
        r"""
        Compute the homogeneous transformation between the child link and
        the parent link of a joint, and the corresponding motion subspace.

        Args:
            joint_index: The index of the joint.
            joint_position: The position of the joint.

        Returns:
            A tuple containing the homogeneous transformation
            :math:`{}^{i} \mathbf{H}_{\lambda(i)}(s)`
            and the motion subspace :math:`\mathbf{S}(s)`.
        """

        λ_Hi_i, S = self.parent_H_child(
            joint_index=joint_index, joint_position=joint_position
        )

        # The child-to-parent transform is the inverse of the parent-to-child one.
        i_Hi_λ = Transform.inverse(λ_Hi_i)

        return i_Hi_λ, S

    def parent_H_predecessor(self, joint_index: jtp.IntLike) -> jtp.Matrix:
        r"""
        Return the homogeneous transformation between the parent link and
        the predecessor frame of a joint.

        Args:
            joint_index: The index of the joint.

        Returns:
            The homogeneous transformation
            :math:`{}^{\lambda(i)} \mathbf{H}_{\text{pre}(i)}`.
        """

        return self.λ_H_pre[joint_index]

    def predecessor_H_successor(
        self, joint_index: jtp.IntLike, joint_position: jtp.VectorLike
    ) -> tuple[jtp.Matrix, jtp.Array]:
        r"""
        Compute the homogeneous transformation between the predecessor and
        the successor frame of a joint, and the corresponding motion subspace.

        Args:
            joint_index:
                The index of the joint, where 0 is the dummy world-to-base joint
                and the joints of the model start from 1.
            joint_position: The position of the joint.

        Returns:
            A tuple containing the homogeneous transformation
            :math:`{}^{\text{pre}(i)} \mathbf{H}_{\text{suc}(i)}(s)`
            and the motion subspace :math:`\mathbf{S}(s)`.
        """

        # `joint_axis` has one entry per model joint and no entry for the dummy
        # world-to-base joint, hence it is shifted by one with respect to
        # `joint_types`, `λ_H_pre`, and `suc_H_i`.
        if joint_index > 0:
            axis = self.joint_axis[joint_index - 1].axis
        else:
            # The world-to-base joint is fixed, so its axis does not affect the
            # result. A placeholder keeps the call valid for models without joints.
            axis = jnp.zeros(3)

        pre_H_suc, S = supported_joint_motion(
            self.joint_types[joint_index],
            joint_position,
            axis,
        )

        return pre_H_suc, S

    def successor_H_child(self, joint_index: jtp.IntLike) -> jtp.Matrix:
        r"""
        Return the homogeneous transformation between the successor frame and
        the child link of a joint.

        Args:
            joint_index: The index of the joint.

        Returns:
            The homogeneous transformation
            :math:`{}^{\text{suc}(i)} \mathbf{H}_i`.
        """

        return self.suc_H_i[joint_index]
|
224
|
+
|
225
|
+
|
226
|
+
@jax.jit
def supported_joint_motion(
    joint_type: jtp.IntLike,
    joint_position: jtp.VectorLike,
    joint_axis: jtp.VectorLike | None = None,
    /,
) -> tuple[jtp.Matrix, jtp.Array]:
    """
    Compute the homogeneous transformation and motion subspace of a joint.

    Args:
        joint_type: The type of the joint.
        joint_position: The position of the joint.
        joint_axis:
            The optional 3D axis of rotation or translation of the joint.
            It can be omitted only for fixed joints: revolute and prismatic
            joints evaluated without an axis produce NaN outputs.

    Returns:
        A tuple containing the homogeneous transformation and the motion subspace.
    """

    # Prepare the joint position
    s = jnp.array(joint_position).astype(float)

    # Prepare the joint axis once for all branches. `jax.lax.switch` traces every
    # branch regardless of the selected joint type, so the axis must be a valid
    # array even when the caller omits it. A NaN axis keeps the fixed-joint
    # branch usable while making a missing axis evident in the other branches.
    axis = (
        jnp.full(3, jnp.nan)
        if joint_axis is None
        else jnp.array(joint_axis).astype(float).squeeze()
    )

    def compute_F() -> tuple[jaxlie.SE3, jtp.Array]:

        # A fixed joint has an identity transform and a null motion subspace.
        return jaxlie.SE3.identity(), jnp.zeros(shape=(6, 1))

    def compute_R() -> tuple[jaxlie.SE3, jtp.Array]:

        # Rotation of `s` around the joint axis, with no translation.
        pre_H_suc = jaxlie.SE3.from_matrix(
            matrix=jnp.eye(4).at[:3, :3].set(Rotation.from_axis_angle(vector=s * axis))
        )

        # The axis fills the last three rows; the result is a column vector
        # with the same shape as the one returned for fixed joints.
        S = jnp.vstack(jnp.hstack([jnp.zeros(3), axis]))

        return pre_H_suc, S

    def compute_P() -> tuple[jaxlie.SE3, jtp.Array]:

        # Translation of `s` along the joint axis, with no rotation.
        pre_H_suc = jaxlie.SE3.from_rotation_and_translation(
            rotation=jaxlie.SO3.identity(),
            translation=jnp.array(s * axis),
        )

        # The axis fills the first three rows of the column vector.
        S = jnp.vstack(jnp.hstack([axis, jnp.zeros(3)]))

        return pre_H_suc, S

    # The position of each branch must match the integer value of its joint type.
    pre_H_suc, S = jax.lax.switch(
        index=joint_type,
        branches=(
            compute_F,  # JointType.Fixed
            compute_R,  # JointType.Revolute
            compute_P,  # JointType.Prismatic
        ),
    )

    return pre_H_suc.as_matrix(), S
|