jaxsim 0.2.dev188__py3-none-any.whl → 0.6.1.dev2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jaxsim/__init__.py +73 -22
- jaxsim/_version.py +2 -2
- jaxsim/api/__init__.py +13 -1
- jaxsim/api/com.py +423 -0
- jaxsim/api/common.py +48 -19
- jaxsim/api/contact.py +604 -52
- jaxsim/api/data.py +308 -163
- jaxsim/api/frame.py +471 -0
- jaxsim/api/joint.py +166 -37
- jaxsim/api/kin_dyn_parameters.py +901 -0
- jaxsim/api/link.py +277 -78
- jaxsim/api/model.py +1572 -362
- jaxsim/api/ode.py +324 -129
- jaxsim/api/ode_data.py +401 -0
- jaxsim/api/references.py +216 -80
- jaxsim/exceptions.py +80 -0
- jaxsim/integrators/__init__.py +2 -2
- jaxsim/integrators/common.py +191 -107
- jaxsim/integrators/fixed_step.py +97 -102
- jaxsim/integrators/variable_step.py +706 -0
- jaxsim/logging.py +1 -2
- jaxsim/math/__init__.py +13 -0
- jaxsim/math/adjoint.py +57 -22
- jaxsim/math/cross.py +16 -7
- jaxsim/math/inertia.py +10 -8
- jaxsim/math/joint_model.py +289 -0
- jaxsim/math/quaternion.py +54 -20
- jaxsim/math/rotation.py +27 -21
- jaxsim/math/skew.py +16 -5
- jaxsim/math/transform.py +102 -0
- jaxsim/math/utils.py +31 -0
- jaxsim/mujoco/__init__.py +2 -1
- jaxsim/mujoco/loaders.py +216 -29
- jaxsim/mujoco/model.py +163 -33
- jaxsim/mujoco/utils.py +228 -0
- jaxsim/mujoco/visualizer.py +107 -22
- jaxsim/parsers/__init__.py +0 -1
- jaxsim/parsers/descriptions/__init__.py +8 -2
- jaxsim/parsers/descriptions/collision.py +87 -16
- jaxsim/parsers/descriptions/joint.py +80 -87
- jaxsim/parsers/descriptions/link.py +62 -24
- jaxsim/parsers/descriptions/model.py +101 -68
- jaxsim/parsers/kinematic_graph.py +607 -225
- jaxsim/parsers/rod/meshes.py +104 -0
- jaxsim/parsers/rod/parser.py +125 -82
- jaxsim/parsers/rod/utils.py +127 -82
- jaxsim/rbda/__init__.py +11 -0
- jaxsim/rbda/aba.py +289 -0
- jaxsim/rbda/collidable_points.py +156 -0
- jaxsim/rbda/contacts/__init__.py +13 -0
- jaxsim/rbda/contacts/common.py +313 -0
- jaxsim/rbda/contacts/relaxed_rigid.py +605 -0
- jaxsim/rbda/contacts/rigid.py +462 -0
- jaxsim/rbda/contacts/soft.py +480 -0
- jaxsim/rbda/contacts/visco_elastic.py +1066 -0
- jaxsim/rbda/crba.py +167 -0
- jaxsim/rbda/forward_kinematics.py +117 -0
- jaxsim/rbda/jacobian.py +330 -0
- jaxsim/rbda/rnea.py +235 -0
- jaxsim/rbda/utils.py +160 -0
- jaxsim/terrain/__init__.py +2 -0
- jaxsim/terrain/terrain.py +238 -0
- jaxsim/typing.py +24 -24
- jaxsim/utils/__init__.py +1 -4
- jaxsim/utils/jaxsim_dataclass.py +289 -34
- jaxsim/utils/tracing.py +5 -11
- jaxsim/utils/wrappers.py +159 -0
- {jaxsim-0.2.dev188.dist-info → jaxsim-0.6.1.dev2.dist-info}/LICENSE +1 -1
- jaxsim-0.6.1.dev2.dist-info/METADATA +465 -0
- jaxsim-0.6.1.dev2.dist-info/RECORD +74 -0
- {jaxsim-0.2.dev188.dist-info → jaxsim-0.6.1.dev2.dist-info}/WHEEL +1 -1
- jaxsim/high_level/__init__.py +0 -2
- jaxsim/high_level/common.py +0 -11
- jaxsim/high_level/joint.py +0 -148
- jaxsim/high_level/link.py +0 -259
- jaxsim/high_level/model.py +0 -1686
- jaxsim/math/conv.py +0 -114
- jaxsim/math/joint.py +0 -102
- jaxsim/math/plucker.py +0 -100
- jaxsim/physics/__init__.py +0 -12
- jaxsim/physics/algos/__init__.py +0 -0
- jaxsim/physics/algos/aba.py +0 -254
- jaxsim/physics/algos/aba_motors.py +0 -284
- jaxsim/physics/algos/crba.py +0 -154
- jaxsim/physics/algos/forward_kinematics.py +0 -79
- jaxsim/physics/algos/jacobian.py +0 -98
- jaxsim/physics/algos/rnea.py +0 -180
- jaxsim/physics/algos/rnea_motors.py +0 -196
- jaxsim/physics/algos/soft_contacts.py +0 -523
- jaxsim/physics/algos/terrain.py +0 -80
- jaxsim/physics/algos/utils.py +0 -69
- jaxsim/physics/model/__init__.py +0 -0
- jaxsim/physics/model/ground_contact.py +0 -55
- jaxsim/physics/model/physics_model.py +0 -388
- jaxsim/physics/model/physics_model_state.py +0 -283
- jaxsim/simulation/__init__.py +0 -4
- jaxsim/simulation/integrators.py +0 -393
- jaxsim/simulation/ode.py +0 -290
- jaxsim/simulation/ode_data.py +0 -96
- jaxsim/simulation/ode_integration.py +0 -62
- jaxsim/simulation/simulator.py +0 -543
- jaxsim/simulation/simulator_callbacks.py +0 -79
- jaxsim/simulation/utils.py +0 -15
- jaxsim/sixd/__init__.py +0 -2
- jaxsim/utils/oop.py +0 -536
- jaxsim/utils/vmappable.py +0 -117
- jaxsim-0.2.dev188.dist-info/METADATA +0 -184
- jaxsim-0.2.dev188.dist-info/RECORD +0 -81
- {jaxsim-0.2.dev188.dist-info → jaxsim-0.6.1.dev2.dist-info}/top_level.txt +0 -0
jaxsim/math/conv.py
DELETED
@@ -1,114 +0,0 @@
|
|
1
|
-
import jax.numpy as jnp
|
2
|
-
|
3
|
-
import jaxsim.typing as jtp
|
4
|
-
|
5
|
-
from .skew import Skew
|
6
|
-
|
7
|
-
|
8
|
-
class Convert:
    """
    Conversions between 3D quantities and their 6D (spatial) counterparts.

    Every method accepts either a single vector (a 1D array, or an array that
    squeezes to 1D) or a 2D array whose columns are vectors, and returns a 2D
    array with one column per vector.
    """

    @staticmethod
    def _to_columns(array, name: str):
        """Squeeze `array` and return it as a 2D array of column vectors."""

        array = jnp.asarray(array).squeeze()

        # A single vector of shape (k,) becomes a (k, 1) column.
        if array.ndim == 1:
            array = jnp.vstack(array)

        if array.ndim != 2:
            raise ValueError(
                f"'{name}' must be a vector or a 2D matrix, got shape {array.shape}"
            )

        return array

    @staticmethod
    def coordinates_tf(X: jtp.Matrix, p: jtp.Matrix) -> jtp.Matrix:
        """
        Transform coordinates from one frame to another using a transformation matrix.

        Args:
            X (jtp.Matrix): The transformation matrix, either a 4x4 homogeneous
                transform or a 6x6 Plücker (adjoint) matrix.
            p (jtp.Matrix): The coordinates to be transformed (3xN, or a single
                3D vector).

        Returns:
            jtp.Matrix: Transformed coordinates (3xN).

        Raises:
            ValueError: If p does not have shape (3, N), or if X is neither
                4x4 nor 6x6.

        Note:
            A 4x4 matrix ``H`` is applied as ``H[0:3, 0:3] @ p + H[0:3, 3]``.
            For a 6x6 matrix with rotation block ``R = X[0:3, 0:3]``, the
            translation is recovered as ``r = -vee(R.T @ X[0:3, 3:6])`` and the
            result is ``R @ (p - r)``.
        """

        X = jnp.asarray(X).squeeze()
        p = Convert._to_columns(p, name="p")

        if p.shape[0] != 3:
            raise ValueError(f"'p' must have shape (3, N), got {p.shape}")

        if X.shape == (4, 4):
            # Homogeneous transform: rotate, then add the translation column.
            # Previously a 4x4 input was fed to the 6x6 path, whose slice
            # X[0:3, 3:6] is only 3x1 and produced silently wrong results.
            R = X[0:3, 0:3]
            t = X[0:3, 3:4]
            return R @ p + t

        if X.shape == (6, 6):
            R = X[0:3, 0:3]

            # Recover the translation from the upper-right block -R @ wedge(r).
            # As a (3, 1) column it broadcasts over all the N points.
            r = jnp.reshape(-Skew.vee(R.T @ X[0:3, 3:6]), (3, 1))

            return R @ (p - r)

        raise ValueError(f"'X' must be a 4x4 or 6x6 matrix, got shape {X.shape}")

    @staticmethod
    def velocities_threed(v_6d: jtp.Matrix, p: jtp.Matrix) -> jtp.Matrix:
        """
        Compute 3D velocities based on 6D velocities and positions.

        Args:
            v_6d (jtp.Matrix): The 6D velocities (6xN, or a single 6D vector
                shared by all the points). Rows 0:3 are the linear part and
                rows 3:6 the angular part.
            p (jtp.Matrix): The positions (3xN, or a single 3D vector shared by
                all the velocities).

        Returns:
            jtp.Matrix: 3D velocities ``v[0:3] + v[3:6] x p`` (3xN).

        Raises:
            ValueError: If the input matrices have incompatible shapes.
        """

        v = Convert._to_columns(v_6d, name="v_6d")
        p = Convert._to_columns(p, name="p")

        rows_v, cols_v = v.shape
        rows_p, cols_p = p.shape

        if rows_v != 6:
            raise ValueError(f"'v_6d' must have shape (6, N), got {v.shape}")

        # jnp.cross would silently accept 2-component vectors, so require 3D.
        if rows_p != 3:
            raise ValueError(f"'p' must have shape (3, N), got {p.shape}")

        # A single column on either side broadcasts against the other one.
        if cols_v != cols_p and 1 not in (cols_v, cols_p):
            raise ValueError(
                f"Incompatible number of columns: v_6d has {cols_v}, p has {cols_p}"
            )

        return v[0:3, :] + jnp.cross(v[3:6, :], p, axis=0)

    @staticmethod
    def forces_sixd(f_3d: jtp.Matrix, p: jtp.Matrix) -> jtp.Matrix:
        """
        Compute 6D forces based on 3D forces and positions.

        Args:
            f_3d (jtp.Matrix): The 3D forces (3xN).
            p (jtp.Matrix): The application points (3xN).

        Returns:
            jtp.Matrix: 6D forces (6xN), stacking ``f`` on top of ``p x f``.

        Raises:
            ValueError: If the input matrices have incompatible shapes.
        """

        fp = Convert._to_columns(f_3d, name="f_3d")
        p = Convert._to_columns(p, name="p")

        rows_fp, cols_fp = fp.shape
        rows_p, cols_p = p.shape

        # Number of columns must match: one application point per force.
        if cols_p != cols_fp:
            raise ValueError(
                f"Incompatible number of columns: f_3d has {cols_fp}, p has {cols_p}"
            )

        if rows_fp != 3:
            raise ValueError(f"'f_3d' must have shape (3, N), got {fp.shape}")

        # jnp.cross would silently accept 2-component vectors, so require 3D.
        if rows_p != 3:
            raise ValueError(f"'p' must have shape (3, N), got {p.shape}")

        return jnp.vstack([fp, jnp.cross(p, fp, axis=0)])
|
jaxsim/math/joint.py
DELETED
@@ -1,102 +0,0 @@
|
|
1
|
-
from typing import Tuple, Union
|
2
|
-
|
3
|
-
import jax.numpy as jnp
|
4
|
-
|
5
|
-
import jaxsim.typing as jtp
|
6
|
-
from jaxsim.parsers.descriptions import JointDescriptor, JointGenericAxis, JointType
|
7
|
-
|
8
|
-
from .adjoint import Adjoint
|
9
|
-
from .rotation import Rotation
|
10
|
-
|
11
|
-
|
12
|
-
def jcalc(
|
13
|
-
jtyp: Union[JointType, JointDescriptor], q: jtp.Float
|
14
|
-
) -> Tuple[jtp.Matrix, jtp.Vector]:
|
15
|
-
"""
|
16
|
-
Compute the spatial transformation matrix and motion subspace vector for a joint.
|
17
|
-
|
18
|
-
Args:
|
19
|
-
jtyp (Union[JointType, JointDescriptor]): The type or descriptor of the joint.
|
20
|
-
q (jtp.Float): The joint configuration parameter.
|
21
|
-
|
22
|
-
Returns:
|
23
|
-
Tuple[jtp.Matrix, jtp.Vector]: A tuple containing the spatial transformation matrix (6x6) and the motion subspace vector (6x1).
|
24
|
-
|
25
|
-
Raises:
|
26
|
-
ValueError: If the joint type or descriptor is not recognized.
|
27
|
-
"""
|
28
|
-
if isinstance(jtyp, JointType):
|
29
|
-
code = jtyp
|
30
|
-
elif isinstance(jtyp, JointDescriptor):
|
31
|
-
code = jtyp.code
|
32
|
-
else:
|
33
|
-
raise ValueError(jtyp)
|
34
|
-
|
35
|
-
match code:
|
36
|
-
case JointType.F:
|
37
|
-
raise ValueError("Fixed joints shouldn't be here")
|
38
|
-
|
39
|
-
case JointType.R:
|
40
|
-
jtyp: JointGenericAxis
|
41
|
-
|
42
|
-
Xj = Adjoint.from_rotation_and_translation(
|
43
|
-
rotation=Rotation.from_axis_angle(vector=q * jtyp.axis), inverse=True
|
44
|
-
)
|
45
|
-
|
46
|
-
S = jnp.vstack(jnp.hstack([jnp.zeros(3), jtyp.axis.squeeze()]))
|
47
|
-
|
48
|
-
case JointType.P:
|
49
|
-
jtyp: JointGenericAxis
|
50
|
-
|
51
|
-
Xj = Adjoint.from_rotation_and_translation(
|
52
|
-
translation=jnp.array(q * jtyp.axis), inverse=True
|
53
|
-
)
|
54
|
-
|
55
|
-
S = jnp.vstack(jnp.hstack([jtyp.axis.squeeze(), jnp.zeros(3)]))
|
56
|
-
|
57
|
-
case JointType.Rx:
|
58
|
-
Xj = Adjoint.from_rotation_and_translation(
|
59
|
-
rotation=Rotation.x(theta=q), inverse=True
|
60
|
-
)
|
61
|
-
|
62
|
-
S = jnp.vstack([0, 0, 0, 1.0, 0, 0])
|
63
|
-
|
64
|
-
case JointType.Ry:
|
65
|
-
Xj = Adjoint.from_rotation_and_translation(
|
66
|
-
rotation=Rotation.y(theta=q), inverse=True
|
67
|
-
)
|
68
|
-
|
69
|
-
S = jnp.vstack([0, 0, 0, 0, 1.0, 0])
|
70
|
-
|
71
|
-
case JointType.Rz:
|
72
|
-
Xj = Adjoint.from_rotation_and_translation(
|
73
|
-
rotation=Rotation.z(theta=q), inverse=True
|
74
|
-
)
|
75
|
-
|
76
|
-
S = jnp.vstack([0, 0, 0, 0, 0, 1.0])
|
77
|
-
|
78
|
-
case JointType.Px:
|
79
|
-
Xj = Adjoint.from_rotation_and_translation(
|
80
|
-
translation=jnp.array([q, 0.0, 0.0]), inverse=True
|
81
|
-
)
|
82
|
-
|
83
|
-
S = jnp.vstack([1.0, 0, 0, 0, 0, 0])
|
84
|
-
|
85
|
-
case JointType.Py:
|
86
|
-
Xj = Adjoint.from_rotation_and_translation(
|
87
|
-
translation=jnp.array([0.0, q, 0.0]), inverse=True
|
88
|
-
)
|
89
|
-
|
90
|
-
S = jnp.vstack([0, 1.0, 0, 0, 0, 0])
|
91
|
-
|
92
|
-
case JointType.Pz:
|
93
|
-
Xj = Adjoint.from_rotation_and_translation(
|
94
|
-
translation=jnp.array([0.0, 0.0, q]), inverse=True
|
95
|
-
)
|
96
|
-
|
97
|
-
S = jnp.vstack([0, 0, 1.0, 0, 0, 0])
|
98
|
-
|
99
|
-
case _:
|
100
|
-
raise ValueError(code)
|
101
|
-
|
102
|
-
return Xj, S
|
jaxsim/math/plucker.py
DELETED
@@ -1,100 +0,0 @@
|
|
1
|
-
from typing import Tuple
|
2
|
-
|
3
|
-
import jax.numpy as jnp
|
4
|
-
|
5
|
-
import jaxsim.typing as jtp
|
6
|
-
|
7
|
-
from .skew import Skew
|
8
|
-
|
9
|
-
|
10
|
-
class Plucker:
    """
    Conversions between 6x6 Plücker matrices and rotation/translation data.

    All 6x6 matrices handled here have the block structure ``[[R, B], [0, R]]``.
    This is consistent with 6D vectors ordered as [linear; angular].

    Two different translation conventions are used:

    * ``from_rot_and_trans`` / ``to_rot_and_trans`` use ``B = -R @ wedge(t)``.
    * ``from_transform`` / ``to_transform`` use ``B = wedge(p) @ R``, where
      ``p`` is the translation column of a 4x4 homogeneous transform.
    """

    @staticmethod
    def from_rot_and_trans(dcm: jtp.Matrix, translation: jtp.Vector) -> jtp.Matrix:
        """
        Build a Plücker matrix from a rotation matrix and a translation vector.

        Args:
            dcm: A 3x3 rotation matrix ``R``.
            translation: A 3D translation vector ``t``.

        Returns:
            The 6x6 matrix ``[[R, -R @ wedge(t)], [0, R]]``.

        Note:
            ``to_rot_and_trans`` recovers ``(R, t)`` from the result. For a
            proper rotation ``R``, the result equals ``from_transform`` applied
            to the homogeneous transform with rotation ``R`` and translation
            ``-R @ t``.
        """

        upper_right = -dcm @ Skew.wedge(vector=translation)
        lower_left = jnp.zeros(shape=(3, 3))

        return jnp.vstack(
            [
                jnp.hstack([dcm, upper_right]),
                jnp.hstack([lower_left, dcm]),
            ]
        )

    @staticmethod
    def to_rot_and_trans(adjoint: jtp.Matrix) -> Tuple[jtp.Matrix, jtp.Vector]:
        """
        Extract the rotation matrix and translation vector of a Plücker matrix.

        This is the inverse of ``from_rot_and_trans``.

        Args:
            adjoint: A 6x6 Plücker matrix ``[[R, -R @ wedge(t)], [0, R]]``.

        Returns:
            The pair ``(R, t)``, with ``R`` the upper-left 3x3 block and ``t``
            computed as ``-vee(R.T @ upper_right_block)``.
        """

        rotation = adjoint[0:3, 0:3]
        translation = -Skew.vee(rotation.T @ adjoint[0:3, 3:6])

        return rotation, translation

    @staticmethod
    def from_transform(transform: jtp.Matrix) -> jtp.Matrix:
        """
        Build a Plücker matrix from a homogeneous transformation matrix.

        Args:
            transform: A 4x4 homogeneous transform with rotation ``R`` in its
                upper-left 3x3 block and translation ``p`` in its last column.

        Returns:
            The 6x6 matrix ``[[R, wedge(p) @ R], [0, R]]``.
        """

        rotation = transform[0:3, 0:3]
        position = transform[0:3, 3]

        upper_right = Skew.wedge(vector=position) @ rotation
        lower_left = jnp.zeros(shape=(3, 3))

        return jnp.vstack(
            [
                jnp.hstack([rotation, upper_right]),
                jnp.hstack([lower_left, rotation]),
            ]
        )

    @staticmethod
    def to_transform(adjoint: jtp.Matrix) -> jtp.Matrix:
        """
        Build the homogeneous transformation matrix of a Plücker matrix.

        This is the inverse of ``from_transform``.

        Args:
            adjoint: A 6x6 Plücker matrix ``[[R, wedge(p) @ R], [0, R]]``.

        Returns:
            The 4x4 homogeneous transform ``[[R, p], [0, 0, 0, 1]]``, where
            ``p = vee(upper_right_block @ R.T)``.
        """

        rotation = adjoint[0:3, 0:3]
        position = Skew.vee(matrix=adjoint[0:3, 3:6] @ rotation.T)

        top_rows = jnp.hstack([rotation, position])
        bottom_row = [0, 0, 0, 1]

        return jnp.vstack([top_rows, bottom_row])
|
jaxsim/physics/__init__.py
DELETED
jaxsim/physics/algos/__init__.py
DELETED
File without changes
|
jaxsim/physics/algos/aba.py
DELETED
@@ -1,254 +0,0 @@
|
|
1
|
-
from typing import Tuple
|
2
|
-
|
3
|
-
import jax
|
4
|
-
import jax.numpy as jnp
|
5
|
-
import numpy as np
|
6
|
-
|
7
|
-
import jaxsim.typing as jtp
|
8
|
-
from jaxsim.math.adjoint import Adjoint
|
9
|
-
from jaxsim.math.cross import Cross
|
10
|
-
from jaxsim.physics.model.physics_model import PhysicsModel
|
11
|
-
|
12
|
-
from . import utils
|
13
|
-
|
14
|
-
|
15
|
-
def aba(
    model: PhysicsModel,
    xfb: jtp.Vector,
    q: jtp.Vector,
    qd: jtp.Vector,
    tau: jtp.Vector,
    f_ext: jtp.Matrix | None = None,
) -> Tuple[jtp.Vector, jtp.Vector]:
    """
    Articulated Body Algorithm (ABA) algorithm for forward dynamics.

    Args:
        model: The physics model of the articulated body or robot.
        xfb: The floating base state vector. Entries 0:4 hold the base
            quaternion and 4:7 the base position; entries 7:13 are read as
            the base velocity.
        q: Joint positions (Generalized coordinates).
        qd: Joint velocities.
        tau: Joint torques or forces.
        f_ext: External forces and torques acting on each link. Defaults to None.
            All inputs are first normalized by `utils.process_inputs`.

    Returns:
        A tuple containing the resulting base acceleration (in inertial-fixed
        representation, a 6x1 column) and the joint accelerations (an nx1
        column, or a (0, 1) array for models without joints). For fixed-base
        models the returned base acceleration is all zeros.

    Note:
        The ABA algorithm is used to compute the accelerations of the links in an articulated body or robot system given
        inputs such as joint positions, velocities, torques, and external forces. The algorithm involves multiple passes
        to calculate intermediate quantities required for simulating the motion of the robot:
        pass 1 (root to leaves) propagates velocities and initializes the
        articulated-body quantities; pass 2 (leaves to root) accumulates
        articulated-body inertias and bias forces into the parents; pass 3
        (root to leaves) computes link and joint accelerations.

        Link index ``i`` runs over 1..NB-1 in every pass, and link ``i`` is
        driven by joint ``ii = i - 1`` (see the indexing of `qd`, `tau` and
        `qdd`).
    """

    x_fb, q, qd, _, tau, f_ext = utils.process_inputs(
        physics_model=model, xfb=xfb, q=q, qd=qd, tau=tau, f_ext=f_ext
    )

    # Extract data from the physics model
    pre_X_λi = model.tree_transforms
    M = model.spatial_inertias
    i_X_pre = model.joint_transforms(q=q)
    S = model.motion_subspaces(q=q)
    λ = model.parent_array()

    # Initialize buffers (one 6x1 vector or 6x6 matrix per link)
    v = jnp.array([jnp.zeros([6, 1])] * model.NB)
    MA = jnp.array([jnp.zeros([6, 6])] * model.NB)
    pA = jnp.array([jnp.zeros([6, 1])] * model.NB)
    c = jnp.array([jnp.zeros([6, 1])] * model.NB)
    i_X_λi = jnp.zeros_like(i_X_pre)

    # Base pose B_X_W and velocity.
    # The base velocity is assembled as [x_fb[10:13]; x_fb[7:10]].
    # NOTE(review): presumably 7:10 is angular and 10:13 linear, reordered into
    # the [linear; angular] layout of the 6D vectors — confirm in process_inputs.
    base_quat = jnp.vstack(x_fb[0:4])
    base_pos = jnp.vstack(x_fb[4:7])
    base_vel = jnp.vstack(jnp.hstack([x_fb[10:13], x_fb[7:10]]))

    # 6D transform of base velocity
    B_X_W = Adjoint.from_quaternion_and_translation(
        quaternion=base_quat,
        translation=base_pos,
        inverse=True,
        normalize_quaternion=True,
    )
    # The parent of the base link (index 0) is the world frame
    i_X_λi = i_X_λi.at[0].set(B_X_W)

    # Transforms link -> base
    i_X_0 = jnp.zeros_like(pre_X_λi)
    i_X_0 = i_X_0.at[0].set(jnp.eye(6))

    # Initialize base quantities (for a fixed base, v[0], MA[0] and pA[0]
    # stay zero and are never propagated into: see the cond in pass 2)
    if model.is_floating_base:
        # Base velocity v₀
        v_0 = B_X_W @ base_vel
        v = v.at[0].set(v_0)

        # AB inertia (Mᴬ) and AB bias forces (pᴬ)
        MA_0 = M[0]
        MA = MA.at[0].set(MA_0)
        # The base external force is mapped from the world frame with the
        # transpose of the inverse of B_X_W (force transform)
        pA_0 = Cross.vx_star(v[0]) @ MA_0 @ v[0] - Adjoint.inverse(
            B_X_W
        ).T @ jnp.vstack(f_ext[0])
        pA = pA.at[0].set(pA_0)

    Pass1Carry = Tuple[
        jtp.MatrixJax,
        jtp.MatrixJax,
        jtp.MatrixJax,
        jtp.MatrixJax,
        jtp.MatrixJax,
        jtp.MatrixJax,
    ]

    pass_1_carry = (i_X_λi, v, c, MA, pA, i_X_0)

    # Pass 1
    def loop_body_pass1(carry: Pass1Carry, i: jtp.Int) -> Tuple[Pass1Carry, None]:
        # Index of the joint driving link i
        ii = i - 1
        i_X_λi, v, c, MA, pA, i_X_0 = carry

        # Compute parent-to-child transform
        i_X_λi_i = i_X_pre[i] @ pre_X_λi[i]
        i_X_λi = i_X_λi.at[i].set(i_X_λi_i)

        # Propagate link velocity (models without joints have an empty qd,
        # hence the zero joint velocity contribution)
        vJ = S[i] * qd[ii] if qd.size != 0 else S[i] * 0

        v_i = i_X_λi[i] @ v[λ[i]] + vJ
        v = v.at[i].set(v_i)

        # Velocity-product acceleration term of link i
        c_i = Cross.vx(v[i]) @ vJ
        c = c.at[i].set(c_i)

        # Initialize articulated-body inertia
        MA_i = jnp.array(M[i])
        MA = MA.at[i].set(MA_i)

        # Initialize articulated-body bias forces
        # NOTE(review): this uses model.parent[i] while the rest of the
        # function uses λ = model.parent_array(); presumably equivalent — verify.
        i_X_0_i = i_X_λi[i] @ i_X_0[model.parent[i]]
        i_X_0 = i_X_0.at[i].set(i_X_0_i)
        # Force transform mapping f_ext[i] from the world frame to link i
        i_Xf_W = Adjoint.inverse(i_X_0[i] @ B_X_W).T

        # M[i] equals MA[i] here, since MA[i] was just initialized to M[i]
        pA_i = Cross.vx_star(v[i]) @ M[i] @ v[i] - i_Xf_W @ jnp.vstack(f_ext[i])
        pA = pA.at[i].set(pA_i)

        return (i_X_λi, v, c, MA, pA, i_X_0), None

    # Root-to-leaves sweep over the links 1..NB-1
    (i_X_λi, v, c, MA, pA, i_X_0), _ = jax.lax.scan(
        f=loop_body_pass1,
        init=pass_1_carry,
        xs=np.arange(start=1, stop=model.NB),
    )

    U = jnp.zeros_like(S)
    d = jnp.zeros(shape=(model.NB, 1))
    u = jnp.zeros(shape=(model.NB, 1))

    Pass2Carry = Tuple[
        jtp.MatrixJax,
        jtp.MatrixJax,
        jtp.MatrixJax,
        jtp.MatrixJax,
        jtp.MatrixJax,
    ]

    pass_2_carry = (U, d, u, MA, pA)

    # Pass 2
    def loop_body_pass2(carry: Pass2Carry, i: jtp.Int) -> Tuple[Pass2Carry, None]:
        # Index of the joint driving link i
        ii = i - 1
        U, d, u, MA, pA = carry

        # Compute intermediate results
        U_i = MA[i] @ S[i]
        U = U.at[i].set(U_i)

        d_i = S[i].T @ U[i]
        d = d.at[i].set(d_i.squeeze())

        # Models without joints have an empty tau: no applied joint torque
        u_i = tau[ii] - S[i].T @ pA[i] if tau.size != 0 else -S[i].T @ pA[i]
        u = u.at[i].set(u_i.squeeze())

        # Compute the articulated-body inertia and bias forces of this link
        Ma = MA[i] - U[i] / d[i] @ U[i].T
        pa = pA[i] + Ma @ c[i] + U[i] * u[i] / d[i]

        # Propagate them to the parent, handling the base link
        def propagate(
            MA_pA: Tuple[jtp.MatrixJax, jtp.MatrixJax]
        ) -> Tuple[jtp.MatrixJax, jtp.MatrixJax]:
            MA, pA = MA_pA

            MA_λi = MA[λ[i]] + i_X_λi[i].T @ Ma @ i_X_λi[i]
            MA = MA.at[λ[i]].set(MA_λi)

            pA_λi = pA[λ[i]] + i_X_λi[i].T @ pa
            pA = pA.at[λ[i]].set(pA_λi)

            return MA, pA

        # Skip the propagation only when the parent is the base link (index 0)
        # of a fixed-base model; otherwise accumulate into the parent.
        MA, pA = jax.lax.cond(
            pred=jnp.array([λ[i] != 0, model.is_floating_base]).any(),
            true_fun=propagate,
            false_fun=lambda MA_pA: MA_pA,
            operand=(MA, pA),
        )

        return (U, d, u, MA, pA), None

    # Leaves-to-root sweep over the links NB-1..1
    (U, d, u, MA, pA), _ = jax.lax.scan(
        f=loop_body_pass2,
        init=pass_2_carry,
        xs=np.flip(np.arange(start=1, stop=model.NB)),
    )

    if model.is_floating_base:
        # Solve MA[0] @ a0 + pA[0] = 0 for the base acceleration; gravity is
        # not included here and is added only to the returned base acceleration
        a0 = jnp.linalg.solve(-MA[0], pA[0])
    else:
        # Fixed base: the base acceleration is set to minus gravity (mapped
        # with B_X_W), so its effect reaches the links through pass 3
        a0 = -B_X_W @ jnp.vstack(model.gravity)

    a = jnp.zeros_like(S)
    a = a.at[0].set(a0)
    qdd = jnp.zeros_like(q)

    Pass3Carry = Tuple[jtp.MatrixJax, jtp.VectorJax]
    pass_3_carry = (a, qdd)

    # Pass 3
    def loop_body_pass3(carry: Pass3Carry, i: jtp.Int) -> Tuple[Pass3Carry, None]:
        # Index of the joint driving link i
        ii = i - 1
        a, qdd = carry

        # Propagate link accelerations
        a_i = i_X_λi[i] @ a[λ[i]] + c[i]

        # Compute joint accelerations (skipped for models without joints)
        qdd_ii = (u[i] - U[i].T @ a_i) / d[i]
        qdd = qdd.at[i - 1].set(qdd_ii.squeeze()) if qdd.size != 0 else qdd

        a_i = a_i + S[i] * qdd[ii] if qdd.size != 0 else a_i
        a = a.at[i].set(a_i)

        return (a, qdd), None

    # Root-to-leaves sweep over the links 1..NB-1
    (a, qdd), _ = jax.lax.scan(
        f=loop_body_pass3,
        init=pass_3_carry,
        xs=np.arange(1, model.NB),
    )

    # Handle 1 DoF models
    qdd = jnp.atleast_1d(qdd.squeeze())
    qdd = jnp.vstack(qdd) if qdd.size > 0 else jnp.empty(shape=(0, 1))

    # Get the resulting base acceleration (w/o gravity) in body-fixed representation
    B_a_WB = a[0]

    # Convert the base acceleration to inertial-fixed representation, and add gravity
    # (fixed-base models always return a zero base acceleration)
    W_a_WB = jnp.vstack(
        jnp.linalg.solve(B_X_W, B_a_WB) + jnp.vstack(model.gravity)
        if model.is_floating_base
        else jnp.zeros(6)
    )

    return W_a_WB, qdd
|