jaxsim 0.2.dev188__py3-none-any.whl → 0.6.1.dev2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jaxsim/__init__.py +73 -22
- jaxsim/_version.py +2 -2
- jaxsim/api/__init__.py +13 -1
- jaxsim/api/com.py +423 -0
- jaxsim/api/common.py +48 -19
- jaxsim/api/contact.py +604 -52
- jaxsim/api/data.py +308 -163
- jaxsim/api/frame.py +471 -0
- jaxsim/api/joint.py +166 -37
- jaxsim/api/kin_dyn_parameters.py +901 -0
- jaxsim/api/link.py +277 -78
- jaxsim/api/model.py +1572 -362
- jaxsim/api/ode.py +324 -129
- jaxsim/api/ode_data.py +401 -0
- jaxsim/api/references.py +216 -80
- jaxsim/exceptions.py +80 -0
- jaxsim/integrators/__init__.py +2 -2
- jaxsim/integrators/common.py +191 -107
- jaxsim/integrators/fixed_step.py +97 -102
- jaxsim/integrators/variable_step.py +706 -0
- jaxsim/logging.py +1 -2
- jaxsim/math/__init__.py +13 -0
- jaxsim/math/adjoint.py +57 -22
- jaxsim/math/cross.py +16 -7
- jaxsim/math/inertia.py +10 -8
- jaxsim/math/joint_model.py +289 -0
- jaxsim/math/quaternion.py +54 -20
- jaxsim/math/rotation.py +27 -21
- jaxsim/math/skew.py +16 -5
- jaxsim/math/transform.py +102 -0
- jaxsim/math/utils.py +31 -0
- jaxsim/mujoco/__init__.py +2 -1
- jaxsim/mujoco/loaders.py +216 -29
- jaxsim/mujoco/model.py +163 -33
- jaxsim/mujoco/utils.py +228 -0
- jaxsim/mujoco/visualizer.py +107 -22
- jaxsim/parsers/__init__.py +0 -1
- jaxsim/parsers/descriptions/__init__.py +8 -2
- jaxsim/parsers/descriptions/collision.py +87 -16
- jaxsim/parsers/descriptions/joint.py +80 -87
- jaxsim/parsers/descriptions/link.py +62 -24
- jaxsim/parsers/descriptions/model.py +101 -68
- jaxsim/parsers/kinematic_graph.py +607 -225
- jaxsim/parsers/rod/meshes.py +104 -0
- jaxsim/parsers/rod/parser.py +125 -82
- jaxsim/parsers/rod/utils.py +127 -82
- jaxsim/rbda/__init__.py +11 -0
- jaxsim/rbda/aba.py +289 -0
- jaxsim/rbda/collidable_points.py +156 -0
- jaxsim/rbda/contacts/__init__.py +13 -0
- jaxsim/rbda/contacts/common.py +313 -0
- jaxsim/rbda/contacts/relaxed_rigid.py +605 -0
- jaxsim/rbda/contacts/rigid.py +462 -0
- jaxsim/rbda/contacts/soft.py +480 -0
- jaxsim/rbda/contacts/visco_elastic.py +1066 -0
- jaxsim/rbda/crba.py +167 -0
- jaxsim/rbda/forward_kinematics.py +117 -0
- jaxsim/rbda/jacobian.py +330 -0
- jaxsim/rbda/rnea.py +235 -0
- jaxsim/rbda/utils.py +160 -0
- jaxsim/terrain/__init__.py +2 -0
- jaxsim/terrain/terrain.py +238 -0
- jaxsim/typing.py +24 -24
- jaxsim/utils/__init__.py +1 -4
- jaxsim/utils/jaxsim_dataclass.py +289 -34
- jaxsim/utils/tracing.py +5 -11
- jaxsim/utils/wrappers.py +159 -0
- {jaxsim-0.2.dev188.dist-info → jaxsim-0.6.1.dev2.dist-info}/LICENSE +1 -1
- jaxsim-0.6.1.dev2.dist-info/METADATA +465 -0
- jaxsim-0.6.1.dev2.dist-info/RECORD +74 -0
- {jaxsim-0.2.dev188.dist-info → jaxsim-0.6.1.dev2.dist-info}/WHEEL +1 -1
- jaxsim/high_level/__init__.py +0 -2
- jaxsim/high_level/common.py +0 -11
- jaxsim/high_level/joint.py +0 -148
- jaxsim/high_level/link.py +0 -259
- jaxsim/high_level/model.py +0 -1686
- jaxsim/math/conv.py +0 -114
- jaxsim/math/joint.py +0 -102
- jaxsim/math/plucker.py +0 -100
- jaxsim/physics/__init__.py +0 -12
- jaxsim/physics/algos/__init__.py +0 -0
- jaxsim/physics/algos/aba.py +0 -254
- jaxsim/physics/algos/aba_motors.py +0 -284
- jaxsim/physics/algos/crba.py +0 -154
- jaxsim/physics/algos/forward_kinematics.py +0 -79
- jaxsim/physics/algos/jacobian.py +0 -98
- jaxsim/physics/algos/rnea.py +0 -180
- jaxsim/physics/algos/rnea_motors.py +0 -196
- jaxsim/physics/algos/soft_contacts.py +0 -523
- jaxsim/physics/algos/terrain.py +0 -80
- jaxsim/physics/algos/utils.py +0 -69
- jaxsim/physics/model/__init__.py +0 -0
- jaxsim/physics/model/ground_contact.py +0 -55
- jaxsim/physics/model/physics_model.py +0 -388
- jaxsim/physics/model/physics_model_state.py +0 -283
- jaxsim/simulation/__init__.py +0 -4
- jaxsim/simulation/integrators.py +0 -393
- jaxsim/simulation/ode.py +0 -290
- jaxsim/simulation/ode_data.py +0 -96
- jaxsim/simulation/ode_integration.py +0 -62
- jaxsim/simulation/simulator.py +0 -543
- jaxsim/simulation/simulator_callbacks.py +0 -79
- jaxsim/simulation/utils.py +0 -15
- jaxsim/sixd/__init__.py +0 -2
- jaxsim/utils/oop.py +0 -536
- jaxsim/utils/vmappable.py +0 -117
- jaxsim-0.2.dev188.dist-info/METADATA +0 -184
- jaxsim-0.2.dev188.dist-info/RECORD +0 -81
- {jaxsim-0.2.dev188.dist-info → jaxsim-0.6.1.dev2.dist-info}/top_level.txt +0 -0
jaxsim/parsers/rod/utils.py
CHANGED
@@ -1,15 +1,24 @@
|
|
1
1
|
import os
|
2
|
-
|
2
|
+
import pathlib
|
3
|
+
from collections.abc import Callable
|
4
|
+
from typing import TypeVar
|
3
5
|
|
4
|
-
import jax.numpy as jnp
|
5
6
|
import numpy as np
|
6
7
|
import numpy.typing as npt
|
7
8
|
import rod
|
9
|
+
import trimesh
|
10
|
+
from rod.utils.resolve_uris import resolve_local_uri
|
8
11
|
|
12
|
+
import jaxsim.typing as jtp
|
13
|
+
from jaxsim import logging
|
14
|
+
from jaxsim.math import Adjoint, Inertia
|
9
15
|
from jaxsim.parsers import descriptions
|
16
|
+
from jaxsim.parsers.rod import meshes
|
10
17
|
|
18
|
+
MeshMappingMethod = TypeVar("MeshMappingMethod", bound=Callable[..., npt.NDArray])
|
11
19
|
|
12
|
-
|
20
|
+
|
21
|
+
def from_sdf_inertial(inertial: rod.Inertial) -> jtp.Matrix:
|
13
22
|
"""
|
14
23
|
Extract the 6D inertia matrix from an SDF inertial element.
|
15
24
|
|
@@ -20,13 +29,10 @@ def from_sdf_inertial(inertial: rod.Inertial) -> npt.NDArray:
|
|
20
29
|
The 6D inertia matrix of the link expressed in the link frame.
|
21
30
|
"""
|
22
31
|
|
23
|
-
|
24
|
-
from jaxsim.sixd import se3
|
25
|
-
|
26
|
-
# Extract the "mass" element
|
32
|
+
# Extract the "mass" element.
|
27
33
|
m = inertial.mass
|
28
34
|
|
29
|
-
# Extract the "inertia" element
|
35
|
+
# Extract the "inertia" element.
|
30
36
|
inertia_element = inertial.inertia
|
31
37
|
|
32
38
|
ixx = inertia_element.ixx
|
@@ -36,7 +42,7 @@ def from_sdf_inertial(inertial: rod.Inertial) -> npt.NDArray:
|
|
36
42
|
ixz = inertia_element.ixz if inertia_element.ixz is not None else 0.0
|
37
43
|
iyz = inertia_element.iyz if inertia_element.iyz is not None else 0.0
|
38
44
|
|
39
|
-
# Build the 3x3 inertia matrix expressed in the CoM
|
45
|
+
# Build the 3x3 inertia matrix expressed in the CoM.
|
40
46
|
I_CoM = np.array(
|
41
47
|
[
|
42
48
|
[ixx, ixy, ixz],
|
@@ -45,73 +51,52 @@ def from_sdf_inertial(inertial: rod.Inertial) -> npt.NDArray:
|
|
45
51
|
]
|
46
52
|
)
|
47
53
|
|
48
|
-
# Build the 6x6 generalized inertia at the CoM
|
54
|
+
# Build the 6x6 generalized inertia at the CoM.
|
49
55
|
M_CoM = Inertia.to_sixd(mass=m, com=np.zeros(3), I=I_CoM)
|
50
56
|
|
51
|
-
# Compute the transform from the inertial frame (CoM) to the link frame
|
57
|
+
# Compute the transform from the inertial frame (CoM) to the link frame.
|
52
58
|
L_H_CoM = inertial.pose.transform() if inertial.pose is not None else np.eye(4)
|
53
59
|
|
54
|
-
# We need its inverse
|
55
|
-
|
56
|
-
CoM_X_L: npt.NDArray = CoM_H_L.adjoint()
|
60
|
+
# We need its inverse.
|
61
|
+
CoM_X_L = Adjoint.from_transform(transform=L_H_CoM, inverse=True)
|
57
62
|
|
58
|
-
# Express the CoM inertia matrix in the link frame L
|
63
|
+
# Express the CoM inertia matrix in the link frame L.
|
59
64
|
M_L = CoM_X_L.T @ M_CoM @ CoM_X_L
|
60
65
|
|
61
|
-
return
|
66
|
+
return M_L.astype(dtype=float)
|
62
67
|
|
63
68
|
|
64
|
-
def
|
65
|
-
axis: rod.Axis, type: str
|
66
|
-
) -> Union[descriptions.JointType, descriptions.JointDescriptor]:
|
69
|
+
def joint_to_joint_type(joint: rod.Joint) -> int:
|
67
70
|
"""
|
68
|
-
|
71
|
+
Extract the joint type from an SDF joint.
|
69
72
|
|
70
73
|
Args:
|
71
|
-
|
72
|
-
type: The SDF joint type.
|
74
|
+
joint: The parsed SDF joint.
|
73
75
|
|
74
76
|
Returns:
|
75
|
-
The corresponding joint type
|
77
|
+
The integer corresponding to the joint type.
|
76
78
|
"""
|
77
79
|
|
78
|
-
|
79
|
-
|
80
|
+
axis = joint.axis
|
81
|
+
joint_type = joint.type
|
82
|
+
|
83
|
+
if joint_type == "fixed":
|
84
|
+
return descriptions.JointType.Fixed
|
80
85
|
|
81
86
|
if not (axis.xyz is not None and axis.xyz.xyz is not None):
|
82
87
|
raise ValueError("Failed to read axis xyz data")
|
83
88
|
|
84
|
-
|
85
|
-
|
86
|
-
|
87
|
-
return descriptions.JointType.Rx
|
88
|
-
|
89
|
-
if np.allclose(axis_xyz, [0, 1, 0]) and type in {"revolute", "continuous"}:
|
90
|
-
return descriptions.JointType.Ry
|
91
|
-
|
92
|
-
if np.allclose(axis_xyz, [0, 0, 1]) and type in {"revolute", "continuous"}:
|
93
|
-
return descriptions.JointType.Rz
|
89
|
+
# Make sure that the axis is a unary vector.
|
90
|
+
axis_xyz = np.array(axis.xyz.xyz).astype(float)
|
91
|
+
axis_xyz = axis_xyz / np.linalg.norm(axis_xyz)
|
94
92
|
|
95
|
-
if
|
96
|
-
return descriptions.JointType.
|
93
|
+
if joint_type in {"revolute", "continuous"}:
|
94
|
+
return descriptions.JointType.Revolute
|
97
95
|
|
98
|
-
if
|
99
|
-
return descriptions.JointType.
|
96
|
+
if joint_type == "prismatic":
|
97
|
+
return descriptions.JointType.Prismatic
|
100
98
|
|
101
|
-
|
102
|
-
return descriptions.JointType.Pz
|
103
|
-
|
104
|
-
if type == "revolute":
|
105
|
-
return descriptions.JointGenericAxis(
|
106
|
-
code=descriptions.JointType.R, axis=np.array(axis_xyz, dtype=float)
|
107
|
-
)
|
108
|
-
|
109
|
-
if type == "prismatic":
|
110
|
-
return descriptions.JointGenericAxis(
|
111
|
-
code=descriptions.JointType.P, axis=np.array(axis_xyz, dtype=float)
|
112
|
-
)
|
113
|
-
|
114
|
-
raise ValueError("Joint not supported", axis_xyz, type)
|
99
|
+
raise ValueError("Joint not supported", axis_xyz, joint_type)
|
115
100
|
|
116
101
|
|
117
102
|
def create_box_collision(
|
@@ -132,22 +117,19 @@ def create_box_collision(
|
|
132
117
|
|
133
118
|
center = np.array([x / 2, y / 2, z / 2])
|
134
119
|
|
135
|
-
|
136
|
-
|
137
|
-
|
138
|
-
|
139
|
-
|
140
|
-
|
141
|
-
|
142
|
-
|
143
|
-
np.array([x, 0, z]),
|
144
|
-
np.array([x, y, z]),
|
145
|
-
np.array([0, y, z]),
|
146
|
-
]
|
147
|
-
)
|
148
|
-
- center
|
120
|
+
# Define the bottom corners.
|
121
|
+
bottom_corners = np.array([[0, 0, 0], [x, 0, 0], [x, y, 0], [0, y, 0]])
|
122
|
+
|
123
|
+
# Conditionally add the top corners based on the environment variable.
|
124
|
+
top_corners = (
|
125
|
+
np.array([[0, 0, z], [x, 0, z], [x, y, z], [0, y, z]])
|
126
|
+
if not os.environ.get("JAXSIM_COLLISION_USE_BOTTOM_ONLY", "0")
|
127
|
+
else []
|
149
128
|
)
|
150
129
|
|
130
|
+
# Combine and shift by the center
|
131
|
+
box_corners = np.vstack([bottom_corners, *top_corners]) - center
|
132
|
+
|
151
133
|
H = collision.pose.transform() if collision.pose is not None else np.eye(4)
|
152
134
|
|
153
135
|
center_wrt_link = (H @ np.hstack([center, 1.0]))[0:-1]
|
@@ -158,7 +140,7 @@ def create_box_collision(
|
|
158
140
|
collidable_points = [
|
159
141
|
descriptions.CollidablePoint(
|
160
142
|
parent_link=link_description,
|
161
|
-
position=corner,
|
143
|
+
position=np.array(corner),
|
162
144
|
enabled=True,
|
163
145
|
)
|
164
146
|
for corner in box_corners_wrt_link.T
|
@@ -185,25 +167,33 @@ def create_sphere_collision(
|
|
185
167
|
|
186
168
|
# From https://stackoverflow.com/a/26127012
|
187
169
|
def fibonacci_sphere(samples: int) -> npt.NDArray:
|
188
|
-
|
189
|
-
phi = np.pi * (3.0 - np.sqrt(5.0))
|
190
|
-
|
191
|
-
|
192
|
-
|
193
|
-
|
194
|
-
|
195
|
-
|
196
|
-
|
197
|
-
|
198
|
-
|
170
|
+
# Get the golden ratio in radians.
|
171
|
+
phi = np.pi * (3.0 - np.sqrt(5.0))
|
172
|
+
|
173
|
+
# Generate the points.
|
174
|
+
points = [
|
175
|
+
np.array(
|
176
|
+
[
|
177
|
+
np.cos(phi * i)
|
178
|
+
* np.sqrt(1 - (y := 1 - 2 * i / (samples - 1)) ** 2),
|
179
|
+
y,
|
180
|
+
np.sin(phi * i) * np.sqrt(1 - y**2),
|
181
|
+
]
|
182
|
+
)
|
183
|
+
for i in range(samples)
|
184
|
+
]
|
199
185
|
|
200
|
-
|
186
|
+
# Filter to keep only the bottom half if required.
|
187
|
+
if os.environ.get("JAXSIM_COLLISION_USE_BOTTOM_ONLY", "0"):
|
188
|
+
# Keep only the points with z <= 0.
|
189
|
+
points = [point for point in points if point[2] <= 0]
|
201
190
|
|
202
191
|
return np.vstack(points)
|
203
192
|
|
204
193
|
r = collision.geometry.sphere.radius
|
194
|
+
|
205
195
|
sphere_points = r * fibonacci_sphere(
|
206
|
-
samples=int(os.getenv(key="JAXSIM_COLLISION_SPHERE_POINTS", default="
|
196
|
+
samples=int(os.getenv(key="JAXSIM_COLLISION_SPHERE_POINTS", default="50"))
|
207
197
|
)
|
208
198
|
|
209
199
|
H = collision.pose.transform() if collision.pose is not None else np.eye(4)
|
@@ -217,7 +207,7 @@ def create_sphere_collision(
|
|
217
207
|
collidable_points = [
|
218
208
|
descriptions.CollidablePoint(
|
219
209
|
parent_link=link_description,
|
220
|
-
position=point,
|
210
|
+
position=np.array(point),
|
221
211
|
enabled=True,
|
222
212
|
)
|
223
213
|
for point in sphere_points_wrt_link.T
|
@@ -226,3 +216,58 @@ def create_sphere_collision(
|
|
226
216
|
return descriptions.SphereCollision(
|
227
217
|
collidable_points=collidable_points, center=center_wrt_link
|
228
218
|
)
|
219
|
+
|
220
|
+
|
221
|
+
def create_mesh_collision(
|
222
|
+
collision: rod.Collision,
|
223
|
+
link_description: descriptions.LinkDescription,
|
224
|
+
method: MeshMappingMethod = None,
|
225
|
+
) -> descriptions.MeshCollision:
|
226
|
+
"""
|
227
|
+
Create a mesh collision from an SDF collision element.
|
228
|
+
|
229
|
+
Args:
|
230
|
+
collision: The SDF collision element.
|
231
|
+
link_description: The link description.
|
232
|
+
method: The method to use for mesh wrapping.
|
233
|
+
|
234
|
+
Returns:
|
235
|
+
The mesh collision description.
|
236
|
+
"""
|
237
|
+
|
238
|
+
file = pathlib.Path(resolve_local_uri(uri=collision.geometry.mesh.uri))
|
239
|
+
file_type = file.suffix.replace(".", "")
|
240
|
+
mesh = trimesh.load_mesh(file, file_type=file_type)
|
241
|
+
|
242
|
+
if mesh.is_empty:
|
243
|
+
raise RuntimeError(f"Failed to process '{file}' with trimesh")
|
244
|
+
|
245
|
+
mesh.apply_scale(collision.geometry.mesh.scale)
|
246
|
+
logging.info(
|
247
|
+
msg=f"Loading mesh {collision.geometry.mesh.uri} with scale {collision.geometry.mesh.scale}, file type '{file_type}'"
|
248
|
+
)
|
249
|
+
|
250
|
+
if method is None:
|
251
|
+
method = meshes.VertexExtraction()
|
252
|
+
logging.debug("Using default Vertex Extraction method for mesh wrapping")
|
253
|
+
else:
|
254
|
+
logging.debug(f"Using method {method} for mesh wrapping")
|
255
|
+
|
256
|
+
points = method(mesh=mesh)
|
257
|
+
logging.debug(f"Extracted {len(points)} points from mesh")
|
258
|
+
|
259
|
+
W_H_L = collision.pose.transform() if collision.pose is not None else np.eye(4)
|
260
|
+
|
261
|
+
# Extract translation from transformation matrix
|
262
|
+
W_p_L = W_H_L[:3, 3]
|
263
|
+
mesh_points_wrt_link = points @ W_H_L[:3, :3].T + W_p_L
|
264
|
+
collidable_points = [
|
265
|
+
descriptions.CollidablePoint(
|
266
|
+
parent_link=link_description,
|
267
|
+
position=point,
|
268
|
+
enabled=True,
|
269
|
+
)
|
270
|
+
for point in mesh_points_wrt_link
|
271
|
+
]
|
272
|
+
|
273
|
+
return descriptions.MeshCollision(collidable_points=collidable_points, center=W_p_L)
|
jaxsim/rbda/__init__.py
ADDED
@@ -0,0 +1,11 @@
|
|
1
|
+
from . import contacts
|
2
|
+
from .aba import aba
|
3
|
+
from .collidable_points import collidable_points_pos_vel
|
4
|
+
from .crba import crba
|
5
|
+
from .forward_kinematics import forward_kinematics, forward_kinematics_model
|
6
|
+
from .jacobian import (
|
7
|
+
jacobian,
|
8
|
+
jacobian_derivative_full_doubly_left,
|
9
|
+
jacobian_full_doubly_left,
|
10
|
+
)
|
11
|
+
from .rnea import rnea
|
jaxsim/rbda/aba.py
ADDED
@@ -0,0 +1,289 @@
|
|
1
|
+
import jax
|
2
|
+
import jax.numpy as jnp
|
3
|
+
import jaxlie
|
4
|
+
|
5
|
+
import jaxsim.api as js
|
6
|
+
import jaxsim.typing as jtp
|
7
|
+
from jaxsim.math import Adjoint, Cross, StandardGravity
|
8
|
+
|
9
|
+
from . import utils
|
10
|
+
|
11
|
+
|
12
|
+
def aba(
|
13
|
+
model: js.model.JaxSimModel,
|
14
|
+
*,
|
15
|
+
base_position: jtp.VectorLike,
|
16
|
+
base_quaternion: jtp.VectorLike,
|
17
|
+
joint_positions: jtp.VectorLike,
|
18
|
+
base_linear_velocity: jtp.VectorLike,
|
19
|
+
base_angular_velocity: jtp.VectorLike,
|
20
|
+
joint_velocities: jtp.VectorLike,
|
21
|
+
joint_forces: jtp.VectorLike | None = None,
|
22
|
+
link_forces: jtp.MatrixLike | None = None,
|
23
|
+
standard_gravity: jtp.FloatLike = StandardGravity,
|
24
|
+
) -> tuple[jtp.Vector, jtp.Vector]:
|
25
|
+
"""
|
26
|
+
Compute forward dynamics using the Articulated Body Algorithm (ABA).
|
27
|
+
|
28
|
+
Args:
|
29
|
+
model: The model to consider.
|
30
|
+
base_position: The position of the base link.
|
31
|
+
base_quaternion: The quaternion of the base link.
|
32
|
+
joint_positions: The positions of the joints.
|
33
|
+
base_linear_velocity:
|
34
|
+
The linear velocity of the base link in inertial-fixed representation.
|
35
|
+
base_angular_velocity:
|
36
|
+
The angular velocity of the base link in inertial-fixed representation.
|
37
|
+
joint_velocities: The velocities of the joints.
|
38
|
+
joint_forces: The forces applied to the joints.
|
39
|
+
link_forces:
|
40
|
+
The forces applied to the links expressed in the world frame.
|
41
|
+
standard_gravity: The standard gravity constant.
|
42
|
+
|
43
|
+
Returns:
|
44
|
+
A tuple containing the base acceleration in inertial-fixed representation
|
45
|
+
and the joint accelerations that result from the applications of the given
|
46
|
+
joint and link forces.
|
47
|
+
|
48
|
+
Note:
|
49
|
+
The algorithm expects a quaternion with unit norm.
|
50
|
+
"""
|
51
|
+
|
52
|
+
W_p_B, W_Q_B, s, W_v_WB, ṡ, _, _, τ, W_f, W_g = utils.process_inputs(
|
53
|
+
model=model,
|
54
|
+
base_position=base_position,
|
55
|
+
base_quaternion=base_quaternion,
|
56
|
+
joint_positions=joint_positions,
|
57
|
+
base_linear_velocity=base_linear_velocity,
|
58
|
+
base_angular_velocity=base_angular_velocity,
|
59
|
+
joint_velocities=joint_velocities,
|
60
|
+
base_linear_acceleration=None,
|
61
|
+
base_angular_acceleration=None,
|
62
|
+
joint_accelerations=None,
|
63
|
+
joint_forces=joint_forces,
|
64
|
+
link_forces=link_forces,
|
65
|
+
standard_gravity=standard_gravity,
|
66
|
+
)
|
67
|
+
|
68
|
+
W_g = jnp.atleast_2d(W_g).T
|
69
|
+
W_v_WB = jnp.atleast_2d(W_v_WB).T
|
70
|
+
|
71
|
+
# Get the 6D spatial inertia matrices of all links.
|
72
|
+
M = js.model.link_spatial_inertia_matrices(model=model)
|
73
|
+
|
74
|
+
# Get the parent array λ(i).
|
75
|
+
# Note: λ(0) must not be used, it's initialized to -1.
|
76
|
+
λ = model.kin_dyn_parameters.parent_array
|
77
|
+
|
78
|
+
# Compute the base transform.
|
79
|
+
W_H_B = jaxlie.SE3.from_rotation_and_translation(
|
80
|
+
rotation=jaxlie.SO3(wxyz=W_Q_B),
|
81
|
+
translation=W_p_B,
|
82
|
+
)
|
83
|
+
|
84
|
+
# Compute 6D transforms of the base velocity.
|
85
|
+
W_X_B = W_H_B.adjoint()
|
86
|
+
B_X_W = W_H_B.inverse().adjoint()
|
87
|
+
|
88
|
+
# Compute the parent-to-child adjoints and the motion subspaces of the joints.
|
89
|
+
# These transforms define the relative kinematics of the entire model, including
|
90
|
+
# the base transform for both floating-base and fixed-base models.
|
91
|
+
i_X_λi, S = model.kin_dyn_parameters.joint_transforms_and_motion_subspaces(
|
92
|
+
joint_positions=s, base_transform=W_H_B.as_matrix()
|
93
|
+
)
|
94
|
+
|
95
|
+
# Allocate buffers.
|
96
|
+
v = jnp.zeros(shape=(model.number_of_links(), 6, 1))
|
97
|
+
c = jnp.zeros(shape=(model.number_of_links(), 6, 1))
|
98
|
+
pA = jnp.zeros(shape=(model.number_of_links(), 6, 1))
|
99
|
+
MA = jnp.zeros(shape=(model.number_of_links(), 6, 6))
|
100
|
+
|
101
|
+
# Allocate the buffer of transforms link -> base.
|
102
|
+
i_X_0 = jnp.zeros(shape=(model.number_of_links(), 6, 6))
|
103
|
+
i_X_0 = i_X_0.at[0].set(jnp.eye(6))
|
104
|
+
|
105
|
+
# Initialize base quantities.
|
106
|
+
if model.floating_base():
|
107
|
+
|
108
|
+
# Base velocity v₀ in body-fixed representation.
|
109
|
+
v_0 = B_X_W @ W_v_WB
|
110
|
+
v = v.at[0].set(v_0)
|
111
|
+
|
112
|
+
# Initialize the articulated-body inertia (Mᴬ) of base link.
|
113
|
+
MA_0 = M[0]
|
114
|
+
MA = MA.at[0].set(MA_0)
|
115
|
+
|
116
|
+
# Initialize the articulated-body bias force (pᴬ) of the base link.
|
117
|
+
pA_0 = Cross.vx_star(v[0]) @ MA[0] @ v[0] - W_X_B.T @ jnp.vstack(W_f[0])
|
118
|
+
pA = pA.at[0].set(pA_0)
|
119
|
+
|
120
|
+
# ======
|
121
|
+
# Pass 1
|
122
|
+
# ======
|
123
|
+
|
124
|
+
Pass1Carry = tuple[jtp.Matrix, jtp.Matrix, jtp.Matrix, jtp.Matrix, jtp.Matrix]
|
125
|
+
pass_1_carry: Pass1Carry = (v, c, MA, pA, i_X_0)
|
126
|
+
|
127
|
+
# Propagate kinematics and initialize AB inertia and AB bias forces.
|
128
|
+
def loop_body_pass1(carry: Pass1Carry, i: jtp.Int) -> tuple[Pass1Carry, None]:
|
129
|
+
|
130
|
+
ii = i - 1
|
131
|
+
v, c, MA, pA, i_X_0 = carry
|
132
|
+
|
133
|
+
# Project the joint velocity into its motion subspace.
|
134
|
+
vJ = S[i] * ṡ[ii]
|
135
|
+
|
136
|
+
# Propagate the link velocity.
|
137
|
+
v_i = i_X_λi[i] @ v[λ[i]] + vJ
|
138
|
+
v = v.at[i].set(v_i)
|
139
|
+
|
140
|
+
c_i = Cross.vx(v[i]) @ vJ
|
141
|
+
c = c.at[i].set(c_i)
|
142
|
+
|
143
|
+
# Initialize the articulated-body inertia.
|
144
|
+
MA_i = jnp.array(M[i])
|
145
|
+
MA = MA.at[i].set(MA_i)
|
146
|
+
|
147
|
+
# Compute the link-to-base transform.
|
148
|
+
i_Xi_0 = i_X_λi[i] @ i_X_0[λ[i]]
|
149
|
+
i_X_0 = i_X_0.at[i].set(i_Xi_0)
|
150
|
+
|
151
|
+
# Compute link-to-world transform for the 6D force.
|
152
|
+
i_Xf_W = Adjoint.inverse(i_X_0[i] @ B_X_W).T
|
153
|
+
|
154
|
+
# Initialize articulated-body bias force.
|
155
|
+
pA_i = Cross.vx_star(v[i]) @ M[i] @ v[i] - i_Xf_W @ jnp.vstack(W_f[i])
|
156
|
+
pA = pA.at[i].set(pA_i)
|
157
|
+
|
158
|
+
return (v, c, MA, pA, i_X_0), None
|
159
|
+
|
160
|
+
(v, c, MA, pA, i_X_0), _ = (
|
161
|
+
jax.lax.scan(
|
162
|
+
f=loop_body_pass1,
|
163
|
+
init=pass_1_carry,
|
164
|
+
xs=jnp.arange(start=1, stop=model.number_of_links()),
|
165
|
+
)
|
166
|
+
if model.number_of_links() > 1
|
167
|
+
else [(v, c, MA, pA, i_X_0), None]
|
168
|
+
)
|
169
|
+
|
170
|
+
# ======
|
171
|
+
# Pass 2
|
172
|
+
# ======
|
173
|
+
|
174
|
+
U = jnp.zeros_like(S)
|
175
|
+
d = jnp.zeros(shape=(model.number_of_links(), 1))
|
176
|
+
u = jnp.zeros(shape=(model.number_of_links(), 1))
|
177
|
+
|
178
|
+
Pass2Carry = tuple[jtp.Matrix, jtp.Matrix, jtp.Matrix, jtp.Matrix, jtp.Matrix]
|
179
|
+
pass_2_carry: Pass2Carry = (U, d, u, MA, pA)
|
180
|
+
|
181
|
+
def loop_body_pass2(carry: Pass2Carry, i: jtp.Int) -> tuple[Pass2Carry, None]:
|
182
|
+
|
183
|
+
ii = i - 1
|
184
|
+
U, d, u, MA, pA = carry
|
185
|
+
|
186
|
+
U_i = MA[i] @ S[i]
|
187
|
+
U = U.at[i].set(U_i)
|
188
|
+
|
189
|
+
d_i = S[i].T @ U[i]
|
190
|
+
d = d.at[i].set(d_i.squeeze())
|
191
|
+
|
192
|
+
u_i = τ[ii] - S[i].T @ pA[i]
|
193
|
+
u = u.at[i].set(u_i.squeeze())
|
194
|
+
|
195
|
+
# Compute the articulated-body inertia and bias force of this link.
|
196
|
+
Ma = MA[i] - U[i] / d[i] @ U[i].T
|
197
|
+
pa = pA[i] + Ma @ c[i] + U[i] * (u[i] / d[i])
|
198
|
+
|
199
|
+
# Propagate them to the parent, handling the base link.
|
200
|
+
def propagate(
|
201
|
+
MA_pA: tuple[jtp.Matrix, jtp.Matrix]
|
202
|
+
) -> tuple[jtp.Matrix, jtp.Matrix]:
|
203
|
+
|
204
|
+
MA, pA = MA_pA
|
205
|
+
|
206
|
+
MA_λi = MA[λ[i]] + i_X_λi[i].T @ Ma @ i_X_λi[i]
|
207
|
+
MA = MA.at[λ[i]].set(MA_λi)
|
208
|
+
|
209
|
+
pA_λi = pA[λ[i]] + i_X_λi[i].T @ pa
|
210
|
+
pA = pA.at[λ[i]].set(pA_λi)
|
211
|
+
|
212
|
+
return MA, pA
|
213
|
+
|
214
|
+
MA, pA = jax.lax.cond(
|
215
|
+
pred=jnp.logical_or(λ[i] != 0, model.floating_base()),
|
216
|
+
true_fun=propagate,
|
217
|
+
false_fun=lambda MA_pA: MA_pA,
|
218
|
+
operand=(MA, pA),
|
219
|
+
)
|
220
|
+
|
221
|
+
return (U, d, u, MA, pA), None
|
222
|
+
|
223
|
+
(U, d, u, MA, pA), _ = (
|
224
|
+
jax.lax.scan(
|
225
|
+
f=loop_body_pass2,
|
226
|
+
init=pass_2_carry,
|
227
|
+
xs=jnp.flip(jnp.arange(start=1, stop=model.number_of_links())),
|
228
|
+
)
|
229
|
+
if model.number_of_links() > 1
|
230
|
+
else [(U, d, u, MA, pA), None]
|
231
|
+
)
|
232
|
+
|
233
|
+
# ======
|
234
|
+
# Pass 3
|
235
|
+
# ======
|
236
|
+
|
237
|
+
if model.floating_base():
|
238
|
+
a0 = jnp.linalg.solve(-MA[0], pA[0])
|
239
|
+
else:
|
240
|
+
a0 = -B_X_W @ W_g
|
241
|
+
|
242
|
+
s̈ = jnp.zeros_like(s)
|
243
|
+
a = jnp.zeros_like(v).at[0].set(a0)
|
244
|
+
|
245
|
+
Pass3Carry = tuple[jtp.Matrix, jtp.Vector]
|
246
|
+
pass_3_carry = (a, s̈)
|
247
|
+
|
248
|
+
def loop_body_pass3(carry: Pass3Carry, i: jtp.Int) -> tuple[Pass3Carry, None]:
|
249
|
+
|
250
|
+
ii = i - 1
|
251
|
+
a, s̈ = carry
|
252
|
+
|
253
|
+
# Propagate the link acceleration.
|
254
|
+
a_i = i_X_λi[i] @ a[λ[i]] + c[i]
|
255
|
+
|
256
|
+
# Compute the joint acceleration.
|
257
|
+
s̈_ii = (u[i] - U[i].T @ a_i) / d[i]
|
258
|
+
s̈ = s̈.at[ii].set(s̈_ii.squeeze())
|
259
|
+
|
260
|
+
# Sum the joint acceleration to the parent link acceleration.
|
261
|
+
a_i = a_i + S[i] * s̈[ii]
|
262
|
+
a = a.at[i].set(a_i)
|
263
|
+
|
264
|
+
return (a, s̈), None
|
265
|
+
|
266
|
+
(a, s̈), _ = (
|
267
|
+
jax.lax.scan(
|
268
|
+
f=loop_body_pass3,
|
269
|
+
init=pass_3_carry,
|
270
|
+
xs=jnp.arange(1, model.number_of_links()),
|
271
|
+
)
|
272
|
+
if model.number_of_links() > 1
|
273
|
+
else [(a, s̈), None]
|
274
|
+
)
|
275
|
+
|
276
|
+
# ==============
|
277
|
+
# Adjust outputs
|
278
|
+
# ==============
|
279
|
+
|
280
|
+
# TODO: remove vstack and shape=(6, 1)?
|
281
|
+
if model.floating_base():
|
282
|
+
# Convert the base acceleration to inertial-fixed representation,
|
283
|
+
# and add gravity.
|
284
|
+
B_a_WB = a[0]
|
285
|
+
W_a_WB = W_X_B @ B_a_WB + W_g
|
286
|
+
else:
|
287
|
+
W_a_WB = jnp.zeros(6)
|
288
|
+
|
289
|
+
return W_a_WB.squeeze(), jnp.atleast_1d(s̈.squeeze())
|