jaxsim 0.1.dev401__py3-none-any.whl → 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jaxsim/__init__.py +5 -6
- jaxsim/_version.py +2 -2
- jaxsim/api/__init__.py +3 -0
- jaxsim/api/com.py +240 -0
- jaxsim/api/common.py +216 -0
- jaxsim/api/contact.py +271 -0
- jaxsim/api/data.py +821 -0
- jaxsim/api/joint.py +189 -0
- jaxsim/api/kin_dyn_parameters.py +777 -0
- jaxsim/api/link.py +361 -0
- jaxsim/api/model.py +1633 -0
- jaxsim/api/ode.py +295 -0
- jaxsim/api/ode_data.py +694 -0
- jaxsim/api/references.py +421 -0
- jaxsim/integrators/__init__.py +2 -0
- jaxsim/integrators/common.py +594 -0
- jaxsim/integrators/fixed_step.py +102 -0
- jaxsim/integrators/variable_step.py +610 -0
- jaxsim/math/__init__.py +11 -0
- jaxsim/math/adjoint.py +24 -2
- jaxsim/math/joint_model.py +335 -0
- jaxsim/math/quaternion.py +44 -3
- jaxsim/math/rotation.py +4 -4
- jaxsim/math/transform.py +92 -0
- jaxsim/mujoco/__init__.py +3 -0
- jaxsim/mujoco/__main__.py +192 -0
- jaxsim/mujoco/loaders.py +615 -0
- jaxsim/mujoco/model.py +414 -0
- jaxsim/mujoco/visualizer.py +176 -0
- jaxsim/parsers/descriptions/collision.py +14 -0
- jaxsim/parsers/descriptions/link.py +13 -2
- jaxsim/parsers/kinematic_graph.py +8 -3
- jaxsim/parsers/rod/parser.py +54 -38
- jaxsim/parsers/rod/utils.py +7 -8
- jaxsim/rbda/__init__.py +7 -0
- jaxsim/rbda/aba.py +295 -0
- jaxsim/rbda/collidable_points.py +142 -0
- jaxsim/{physics/algos → rbda}/crba.py +43 -42
- jaxsim/rbda/forward_kinematics.py +113 -0
- jaxsim/rbda/jacobian.py +201 -0
- jaxsim/rbda/rnea.py +237 -0
- jaxsim/rbda/soft_contacts.py +296 -0
- jaxsim/rbda/utils.py +152 -0
- jaxsim/terrain/__init__.py +2 -0
- jaxsim/{physics/algos → terrain}/terrain.py +4 -6
- jaxsim/typing.py +30 -30
- jaxsim/utils/__init__.py +1 -4
- jaxsim/utils/hashless.py +18 -0
- jaxsim/utils/jaxsim_dataclass.py +281 -31
- {jaxsim-0.1.dev401.dist-info → jaxsim-0.2.0.dist-info}/LICENSE +1 -1
- jaxsim-0.2.0.dist-info/METADATA +237 -0
- jaxsim-0.2.0.dist-info/RECORD +64 -0
- {jaxsim-0.1.dev401.dist-info → jaxsim-0.2.0.dist-info}/WHEEL +1 -1
- jaxsim/high_level/__init__.py +0 -2
- jaxsim/high_level/common.py +0 -11
- jaxsim/high_level/joint.py +0 -148
- jaxsim/high_level/link.py +0 -259
- jaxsim/high_level/model.py +0 -1695
- jaxsim/math/conv.py +0 -114
- jaxsim/math/joint.py +0 -101
- jaxsim/math/plucker.py +0 -100
- jaxsim/physics/__init__.py +0 -12
- jaxsim/physics/algos/__init__.py +0 -0
- jaxsim/physics/algos/aba.py +0 -256
- jaxsim/physics/algos/aba_motors.py +0 -284
- jaxsim/physics/algos/forward_kinematics.py +0 -79
- jaxsim/physics/algos/jacobian.py +0 -98
- jaxsim/physics/algos/rnea.py +0 -180
- jaxsim/physics/algos/rnea_motors.py +0 -196
- jaxsim/physics/algos/soft_contacts.py +0 -454
- jaxsim/physics/algos/utils.py +0 -69
- jaxsim/physics/model/__init__.py +0 -0
- jaxsim/physics/model/ground_contact.py +0 -55
- jaxsim/physics/model/physics_model.py +0 -358
- jaxsim/physics/model/physics_model_state.py +0 -174
- jaxsim/simulation/__init__.py +0 -4
- jaxsim/simulation/integrators.py +0 -452
- jaxsim/simulation/ode.py +0 -290
- jaxsim/simulation/ode_data.py +0 -53
- jaxsim/simulation/ode_integration.py +0 -125
- jaxsim/simulation/simulator.py +0 -544
- jaxsim/simulation/simulator_callbacks.py +0 -53
- jaxsim/simulation/utils.py +0 -15
- jaxsim/sixd/__init__.py +0 -2
- jaxsim/utils/oop.py +0 -532
- jaxsim/utils/vmappable.py +0 -117
- jaxsim-0.1.dev401.dist-info/METADATA +0 -167
- jaxsim-0.1.dev401.dist-info/RECORD +0 -64
- {jaxsim-0.1.dev401.dist-info → jaxsim-0.2.0.dist-info}/top_level.txt +0 -0
jaxsim/parsers/rod/parser.py
CHANGED
@@ -10,7 +10,7 @@ from jaxsim import logging
|
|
10
10
|
from jaxsim.math.quaternion import Quaternion
|
11
11
|
from jaxsim.parsers import descriptions, kinematic_graph
|
12
12
|
|
13
|
-
from . import utils
|
13
|
+
from . import utils
|
14
14
|
|
15
15
|
|
16
16
|
class SDFData(NamedTuple):
|
@@ -135,11 +135,13 @@ def extract_model_data(
|
|
135
135
|
parent=world_link,
|
136
136
|
child=links_dict[j.child],
|
137
137
|
jtype=utils.axis_to_jtype(axis=j.axis, type=j.type),
|
138
|
-
axis=
|
139
|
-
|
140
|
-
|
141
|
-
|
142
|
-
|
138
|
+
axis=(
|
139
|
+
np.array(j.axis.xyz.xyz)
|
140
|
+
if j.axis is not None
|
141
|
+
and j.axis.xyz is not None
|
142
|
+
and j.axis.xyz.xyz is not None
|
143
|
+
else None
|
144
|
+
),
|
143
145
|
pose=j.pose.transform() if j.pose is not None else np.eye(4),
|
144
146
|
)
|
145
147
|
for j in sdf_model.joints()
|
@@ -200,41 +202,55 @@ def extract_model_data(
|
|
200
202
|
parent=links_dict[j.parent],
|
201
203
|
child=links_dict[j.child],
|
202
204
|
jtype=utils.axis_to_jtype(axis=j.axis, type=j.type),
|
203
|
-
axis=
|
204
|
-
|
205
|
-
|
206
|
-
|
207
|
-
|
205
|
+
axis=(
|
206
|
+
np.array(j.axis.xyz.xyz)
|
207
|
+
if j.axis is not None
|
208
|
+
and j.axis.xyz is not None
|
209
|
+
and j.axis.xyz.xyz is not None
|
210
|
+
else None
|
211
|
+
),
|
208
212
|
pose=j.pose.transform() if j.pose is not None else np.eye(4),
|
209
213
|
initial_position=0.0,
|
210
214
|
position_limit=(
|
211
|
-
|
212
|
-
|
213
|
-
|
214
|
-
|
215
|
-
|
216
|
-
|
215
|
+
(
|
216
|
+
float(j.axis.limit.lower)
|
217
|
+
if j.axis is not None and j.axis.limit is not None
|
218
|
+
else np.finfo(float).min
|
219
|
+
),
|
220
|
+
(
|
221
|
+
float(j.axis.limit.upper)
|
222
|
+
if j.axis is not None and j.axis.limit is not None
|
223
|
+
else np.finfo(float).max
|
224
|
+
),
|
225
|
+
),
|
226
|
+
friction_static=(
|
227
|
+
j.axis.dynamics.friction
|
228
|
+
if j.axis is not None
|
229
|
+
and j.axis.dynamics is not None
|
230
|
+
and j.axis.dynamics.friction is not None
|
231
|
+
else 0.0
|
232
|
+
),
|
233
|
+
friction_viscous=(
|
234
|
+
j.axis.dynamics.damping
|
235
|
+
if j.axis is not None
|
236
|
+
and j.axis.dynamics is not None
|
237
|
+
and j.axis.dynamics.damping is not None
|
238
|
+
else 0.0
|
239
|
+
),
|
240
|
+
position_limit_damper=(
|
241
|
+
j.axis.limit.dissipation
|
242
|
+
if j.axis is not None
|
243
|
+
and j.axis.limit is not None
|
244
|
+
and j.axis.limit.dissipation is not None
|
245
|
+
else 0.0
|
246
|
+
),
|
247
|
+
position_limit_spring=(
|
248
|
+
j.axis.limit.stiffness
|
249
|
+
if j.axis is not None
|
250
|
+
and j.axis.limit is not None
|
251
|
+
and j.axis.limit.stiffness is not None
|
252
|
+
else 0.0
|
217
253
|
),
|
218
|
-
friction_static=j.axis.dynamics.friction
|
219
|
-
if j.axis is not None
|
220
|
-
and j.axis.dynamics is not None
|
221
|
-
and j.axis.dynamics.friction is not None
|
222
|
-
else 0.0,
|
223
|
-
friction_viscous=j.axis.dynamics.damping
|
224
|
-
if j.axis is not None
|
225
|
-
and j.axis.dynamics is not None
|
226
|
-
and j.axis.dynamics.damping is not None
|
227
|
-
else 0.0,
|
228
|
-
position_limit_damper=j.axis.limit.dissipation
|
229
|
-
if j.axis is not None
|
230
|
-
and j.axis.limit is not None
|
231
|
-
and j.axis.limit.dissipation is not None
|
232
|
-
else 0.0,
|
233
|
-
position_limit_spring=j.axis.limit.stiffness
|
234
|
-
if j.axis is not None
|
235
|
-
and j.axis.limit is not None
|
236
|
-
and j.axis.limit.stiffness is not None
|
237
|
-
else 0.0,
|
238
254
|
)
|
239
255
|
for j in sdf_model.joints()
|
240
256
|
if j.type in {"revolute", "prismatic", "fixed"}
|
@@ -341,6 +357,6 @@ def build_model_description(
|
|
341
357
|
)
|
342
358
|
|
343
359
|
# Store the parsed SDF tree as extra info
|
344
|
-
model = dataclasses.replace(model, extra_info=
|
360
|
+
model = dataclasses.replace(model, extra_info={"sdf_model": sdf_data.sdf_model})
|
345
361
|
|
346
362
|
return model
|
jaxsim/parsers/rod/utils.py
CHANGED
@@ -1,15 +1,17 @@
|
|
1
1
|
import os
|
2
2
|
from typing import Union
|
3
3
|
|
4
|
-
import
|
4
|
+
import jaxlie
|
5
5
|
import numpy as np
|
6
6
|
import numpy.typing as npt
|
7
7
|
import rod
|
8
8
|
|
9
|
+
import jaxsim.typing as jtp
|
10
|
+
from jaxsim.math.inertia import Inertia
|
9
11
|
from jaxsim.parsers import descriptions
|
10
12
|
|
11
13
|
|
12
|
-
def from_sdf_inertial(inertial: rod.Inertial) ->
|
14
|
+
def from_sdf_inertial(inertial: rod.Inertial) -> jtp.Matrix:
|
13
15
|
"""
|
14
16
|
Extract the 6D inertia matrix from an SDF inertial element.
|
15
17
|
|
@@ -20,9 +22,6 @@ def from_sdf_inertial(inertial: rod.Inertial) -> npt.NDArray:
|
|
20
22
|
The 6D inertia matrix of the link expressed in the link frame.
|
21
23
|
"""
|
22
24
|
|
23
|
-
from jaxsim.math.inertia import Inertia
|
24
|
-
from jaxsim.sixd import se3
|
25
|
-
|
26
25
|
# Extract the "mass" element
|
27
26
|
m = inertial.mass
|
28
27
|
|
@@ -52,13 +51,13 @@ def from_sdf_inertial(inertial: rod.Inertial) -> npt.NDArray:
|
|
52
51
|
L_H_CoM = inertial.pose.transform() if inertial.pose is not None else np.eye(4)
|
53
52
|
|
54
53
|
# We need its inverse
|
55
|
-
CoM_H_L =
|
56
|
-
CoM_X_L
|
54
|
+
CoM_H_L = jaxlie.SE3.from_matrix(matrix=L_H_CoM).inverse()
|
55
|
+
CoM_X_L = CoM_H_L.adjoint()
|
57
56
|
|
58
57
|
# Express the CoM inertia matrix in the link frame L
|
59
58
|
M_L = CoM_X_L.T @ M_CoM @ CoM_X_L
|
60
59
|
|
61
|
-
return
|
60
|
+
return M_L.astype(dtype=float)
|
62
61
|
|
63
62
|
|
64
63
|
def axis_to_jtype(
|
jaxsim/rbda/__init__.py
ADDED
@@ -0,0 +1,7 @@
|
|
1
|
+
from .aba import aba
|
2
|
+
from .collidable_points import collidable_points_pos_vel
|
3
|
+
from .crba import crba
|
4
|
+
from .forward_kinematics import forward_kinematics, forward_kinematics_model
|
5
|
+
from .jacobian import jacobian, jacobian_full_doubly_left
|
6
|
+
from .rnea import rnea
|
7
|
+
from .soft_contacts import SoftContacts, SoftContactsParams
|
jaxsim/rbda/aba.py
ADDED
@@ -0,0 +1,295 @@
|
|
1
|
+
import jax
|
2
|
+
import jax.numpy as jnp
|
3
|
+
import jaxlie
|
4
|
+
|
5
|
+
import jaxsim.api as js
|
6
|
+
import jaxsim.typing as jtp
|
7
|
+
from jaxsim.math import Adjoint, Cross, Quaternion, StandardGravity
|
8
|
+
|
9
|
+
from . import utils
|
10
|
+
|
11
|
+
|
12
|
+
def aba(
    model: js.model.JaxSimModel,
    *,
    base_position: jtp.VectorLike,
    base_quaternion: jtp.VectorLike,
    joint_positions: jtp.VectorLike,
    base_linear_velocity: jtp.VectorLike,
    base_angular_velocity: jtp.VectorLike,
    joint_velocities: jtp.VectorLike,
    joint_forces: jtp.VectorLike | None = None,
    link_forces: jtp.MatrixLike | None = None,
    standard_gravity: jtp.FloatLike = StandardGravity,
) -> tuple[jtp.Vector, jtp.Vector]:
    """
    Compute forward dynamics using the Articulated Body Algorithm (ABA).

    The algorithm runs three passes over the kinematic tree:
    1. a forward pass propagating link velocities, velocity-product
       accelerations and initializing articulated-body inertias/bias forces;
    2. a backward pass accumulating articulated-body quantities into parents;
    3. a forward pass computing link and joint accelerations.

    Args:
        model: The model to consider.
        base_position: The position of the base link.
        base_quaternion: The quaternion of the base link.
        joint_positions: The positions of the joints.
        base_linear_velocity:
            The linear velocity of the base link in inertial-fixed representation.
        base_angular_velocity:
            The angular velocity of the base link in inertial-fixed representation.
        joint_velocities: The velocities of the joints.
        joint_forces: The forces applied to the joints.
        link_forces:
            The forces applied to the links expressed in the world frame.
        standard_gravity: The standard gravity constant.

    Returns:
        A tuple containing the base acceleration in inertial-fixed representation
        and the joint accelerations that result from the application of the given
        joint and link forces. The base acceleration is a 6D vector, and it is
        all zeros for fixed-base models. The joint accelerations are returned
        as an (at least) 1D array.

    Note:
        The algorithm expects a quaternion with unit norm.
    """

    # Normalize/convert all inputs. The three accelerations are not needed by
    # forward dynamics and are therefore passed as None and discarded.
    W_p_B, W_Q_B, s, W_v_WB, ṡ, _, _, τ, W_f, W_g = utils.process_inputs(
        model=model,
        base_position=base_position,
        base_quaternion=base_quaternion,
        joint_positions=joint_positions,
        base_linear_velocity=base_linear_velocity,
        base_angular_velocity=base_angular_velocity,
        joint_velocities=joint_velocities,
        base_linear_acceleration=None,
        base_angular_acceleration=None,
        joint_accelerations=None,
        joint_forces=joint_forces,
        link_forces=link_forces,
        standard_gravity=standard_gravity,
    )

    # Work with 6D column vectors (shape (6, 1)) throughout the algorithm.
    W_g = jnp.atleast_2d(W_g).T
    W_v_WB = jnp.atleast_2d(W_v_WB).T

    # Get the 6D spatial inertia matrices of all links.
    M = js.model.link_spatial_inertia_matrices(model=model)

    # Get the parent array λ(i).
    # Note: λ(0) must not be used, it's initialized to -1.
    λ = model.kin_dyn_parameters.parent_array

    # Compute the base transform.
    W_H_B = jaxlie.SE3.from_rotation_and_translation(
        rotation=jaxlie.SO3.from_quaternion_xyzw(xyzw=Quaternion.to_xyzw(wxyz=W_Q_B)),
        translation=W_p_B,
    )

    # Compute 6D transforms of the base velocity.
    W_X_B = W_H_B.adjoint()
    B_X_W = W_H_B.inverse().adjoint()

    # Compute the parent-to-child adjoints and the motion subspaces of the joints.
    # These transforms define the relative kinematics of the entire model, including
    # the base transform for both floating-base and fixed-base models.
    i_X_λi, S = model.kin_dyn_parameters.joint_transforms_and_motion_subspaces(
        joint_positions=s, base_transform=W_H_B.as_matrix()
    )

    # Allocate buffers.
    # v: link velocities (body-fixed), c: velocity-product accelerations,
    # pA: articulated-body bias forces, MA: articulated-body inertias.
    v = jnp.zeros(shape=(model.number_of_links(), 6, 1))
    c = jnp.zeros(shape=(model.number_of_links(), 6, 1))
    pA = jnp.zeros(shape=(model.number_of_links(), 6, 1))
    MA = jnp.zeros(shape=(model.number_of_links(), 6, 6))

    # Allocate the buffer of transforms link -> base.
    # i_X_0[i] composes i_X_λi along the chain from link 0 to link i.
    i_X_0 = jnp.zeros(shape=(model.number_of_links(), 6, 6))
    i_X_0 = i_X_0.at[0].set(jnp.eye(6))

    # Initialize base quantities.
    # For fixed-base models v[0], MA[0] and pA[0] stay zero: the base does not
    # move and, in pass 2, nothing is propagated into it.
    if model.floating_base():

        # Base velocity v₀ in body-fixed representation.
        v_0 = B_X_W @ W_v_WB
        v = v.at[0].set(v_0)

        # Initialize the articulated-body inertia (Mᴬ) of base link.
        MA_0 = M[0]
        MA = MA.at[0].set(MA_0)

        # Initialize the articulated-body bias force (pᴬ) of the base link.
        # W_X_B.T maps the world-frame 6D force applied to the base into the
        # body frame.
        pA_0 = Cross.vx_star(v[0]) @ MA[0] @ v[0] - W_X_B.T @ jnp.vstack(W_f[0])
        pA = pA.at[0].set(pA_0)

    # ======
    # Pass 1
    # ======

    Pass1Carry = tuple[
        jtp.MatrixJax, jtp.MatrixJax, jtp.MatrixJax, jtp.MatrixJax, jtp.MatrixJax
    ]

    pass_1_carry: Pass1Carry = (v, c, MA, pA, i_X_0)

    # Propagate kinematics and initialize AB inertia and AB bias forces.
    # The scan visits links in increasing index order and reads v[λ[i]] and
    # i_X_0[λ[i]], so it relies on parents having a lower index than their
    # children (λ(i) < i) — presumably guaranteed by the parent array ordering.
    def loop_body_pass1(carry: Pass1Carry, i: jtp.Int) -> tuple[Pass1Carry, None]:

        # Link i is the child of joint ii = i - 1 (joint arrays have no base entry).
        ii = i - 1
        v, c, MA, pA, i_X_0 = carry

        # Project the joint velocity into its motion subspace.
        vJ = S[i] * ṡ[ii]

        # Propagate the link velocity.
        v_i = i_X_λi[i] @ v[λ[i]] + vJ
        v = v.at[i].set(v_i)

        # Velocity-product (bias) acceleration of the link.
        c_i = Cross.vx(v[i]) @ vJ
        c = c.at[i].set(c_i)

        # Initialize the articulated-body inertia.
        MA_i = jnp.array(M[i])
        MA = MA.at[i].set(MA_i)

        # Compute the link-to-base transform.
        i_Xi_0 = i_X_λi[i] @ i_X_0[λ[i]]
        i_X_0 = i_X_0.at[i].set(i_Xi_0)

        # Compute link-to-world transform for the 6D force.
        # i_X_0[i] @ B_X_W is the motion transform world -> link i; the
        # transpose of its inverse is the corresponding force transform.
        i_Xf_W = Adjoint.inverse(i_X_0[i] @ B_X_W).T

        # Initialize articulated-body bias force.
        pA_i = Cross.vx_star(v[i]) @ M[i] @ v[i] - i_Xf_W @ jnp.vstack(W_f[i])
        pA = pA.at[i].set(pA_i)

        return (v, c, MA, pA, i_X_0), None

    # jax.lax.scan is skipped for single-link models (empty range of links).
    (v, c, MA, pA, i_X_0), _ = (
        jax.lax.scan(
            f=loop_body_pass1,
            init=pass_1_carry,
            xs=jnp.arange(start=1, stop=model.number_of_links()),
        )
        if model.number_of_links() > 1
        else [(v, c, MA, pA, i_X_0), None]
    )

    # ======
    # Pass 2
    # ======

    # d and u hold one scalar per link: each joint is treated as single-DoF
    # (S[i] is a 6x1 column, so S[i].T @ U[i] is 1x1).
    U = jnp.zeros_like(S)
    d = jnp.zeros(shape=(model.number_of_links(), 1))
    u = jnp.zeros(shape=(model.number_of_links(), 1))

    Pass2Carry = tuple[
        jtp.MatrixJax, jtp.MatrixJax, jtp.MatrixJax, jtp.MatrixJax, jtp.MatrixJax
    ]

    pass_2_carry: Pass2Carry = (U, d, u, MA, pA)

    # Backward pass: visit links from the leaves towards the base.
    def loop_body_pass2(carry: Pass2Carry, i: jtp.Int) -> tuple[Pass2Carry, None]:

        ii = i - 1
        U, d, u, MA, pA = carry

        U_i = MA[i] @ S[i]
        U = U.at[i].set(U_i)

        d_i = S[i].T @ U[i]
        d = d.at[i].set(d_i.squeeze())

        u_i = τ[ii] - S[i].T @ pA[i]
        u = u.at[i].set(u_i.squeeze())

        # Compute the articulated-body inertia and bias force of this link.
        # Note: `/` and `@` share precedence and associate left-to-right, so the
        # first line evaluates (U[i] / d[i]) @ U[i].T.
        Ma = MA[i] - U[i] / d[i] @ U[i].T
        pa = pA[i] + Ma @ c[i] + U[i] * (u[i] / d[i])

        # Propagate them to the parent, handling the base link.
        def propagate(
            MA_pA: tuple[jtp.MatrixJax, jtp.MatrixJax]
        ) -> tuple[jtp.MatrixJax, jtp.MatrixJax]:

            MA, pA = MA_pA

            MA_λi = MA[λ[i]] + i_X_λi[i].T @ Ma @ i_X_λi[i]
            MA = MA.at[λ[i]].set(MA_λi)

            pA_λi = pA[λ[i]] + i_X_λi[i].T @ pa
            pA = pA.at[λ[i]].set(pA_λi)

            return MA, pA

        # Skip the propagation into the base (λ[i] == 0) for fixed-base models.
        MA, pA = jax.lax.cond(
            pred=jnp.logical_or(λ[i] != 0, model.floating_base()),
            true_fun=propagate,
            false_fun=lambda MA_pA: MA_pA,
            operand=(MA, pA),
        )

        return (U, d, u, MA, pA), None

    (U, d, u, MA, pA), _ = (
        jax.lax.scan(
            f=loop_body_pass2,
            init=pass_2_carry,
            xs=jnp.flip(jnp.arange(start=1, stop=model.number_of_links())),
        )
        if model.number_of_links() > 1
        else [(U, d, u, MA, pA), None]
    )

    # ======
    # Pass 3
    # ======

    # Base acceleration (body-fixed).
    # - Floating base: solve MA[0] a0 = -pA[0]. Gravity is not part of the bias
    #   forces here; it is added back to the base acceleration at the end.
    # - Fixed base: gravity enters as the fictitious base acceleration -g.
    if model.floating_base():
        a0 = jnp.linalg.solve(-MA[0], pA[0])
    else:
        a0 = -B_X_W @ W_g

    s̈ = jnp.zeros_like(s)
    a = jnp.zeros_like(v).at[0].set(a0)

    Pass3Carry = tuple[jtp.MatrixJax, jtp.VectorJax]
    pass_3_carry = (a, s̈)

    # Forward pass: compute link and joint accelerations from the base outwards.
    def loop_body_pass3(carry: Pass3Carry, i: jtp.Int) -> tuple[Pass3Carry, None]:

        ii = i - 1
        a, s̈ = carry

        # Propagate the link acceleration.
        a_i = i_X_λi[i] @ a[λ[i]] + c[i]

        # Compute the joint acceleration.
        s̈_ii = (u[i] - U[i].T @ a_i) / d[i]
        s̈ = s̈.at[ii].set(s̈_ii.squeeze())

        # Sum the joint acceleration to the parent link acceleration.
        a_i = a_i + S[i] * s̈[ii]
        a = a.at[i].set(a_i)

        return (a, s̈), None

    (a, s̈), _ = (
        jax.lax.scan(
            f=loop_body_pass3,
            init=pass_3_carry,
            xs=jnp.arange(1, model.number_of_links()),
        )
        if model.number_of_links() > 1
        else [(a, s̈), None]
    )

    # ==============
    # Adjust outputs
    # ==============

    # TODO: remove vstack and shape=(6, 1)?
    if model.floating_base():
        # Convert the base acceleration to inertial-fixed representation,
        # and add gravity.
        B_a_WB = a[0]
        W_a_WB = W_X_B @ B_a_WB + W_g
    else:
        W_a_WB = jnp.zeros(6)

    # Squeeze the (6, 1) column to a flat 6D vector; keep joint accelerations
    # at least 1D even for single-joint models.
    return W_a_WB.squeeze(), jnp.atleast_1d(s̈.squeeze())
|
@@ -0,0 +1,142 @@
|
|
1
|
+
import jax
|
2
|
+
import jax.numpy as jnp
|
3
|
+
import jaxlie
|
4
|
+
|
5
|
+
import jaxsim.api as js
|
6
|
+
import jaxsim.typing as jtp
|
7
|
+
from jaxsim.math import Adjoint, Quaternion, Skew
|
8
|
+
|
9
|
+
from . import utils
|
10
|
+
|
11
|
+
|
12
|
+
def collidable_points_pos_vel(
    model: js.model.JaxSimModel,
    *,
    base_position: jtp.Vector,
    base_quaternion: jtp.Vector,
    joint_positions: jtp.Vector,
    base_linear_velocity: jtp.Vector,
    base_angular_velocity: jtp.Vector,
    joint_velocities: jtp.Vector,
) -> tuple[jtp.Matrix, jtp.Matrix]:
    """
    Compute the position and linear velocity of collidable points in the world frame.

    Args:
        model: The model to consider.
        base_position: The position of the base link.
        base_quaternion: The quaternion of the base link.
        joint_positions: The positions of the joints.
        base_linear_velocity:
            The linear velocity of the base link in inertial-fixed representation.
        base_angular_velocity:
            The angular velocity of the base link in inertial-fixed representation.
        joint_velocities: The velocities of the joints.

    Returns:
        A tuple containing the position and linear velocity of collidable points,
        each with shape (n_c, 3) where n_c is the number of collidable points.
        If the model has no collidable points, both arrays have shape (0, 3).

    Note:
        The returned velocity of each point is the linear part of its mixed 6D
        velocity Ci[W]_v_{W,Ci}.
    """

    # Models without collidable points: return empty arrays with the same
    # trailing dimension as the non-empty case, so callers can index [:, k]
    # or map over the leading axis uniformly.
    if len(model.kin_dyn_parameters.contact_parameters.body) == 0:
        return (
            jnp.empty(shape=(0, 3)).astype(float),
            jnp.empty(shape=(0, 3)).astype(float),
        )

    W_p_B, W_Q_B, s, W_v_WB, ṡ, _, _, _, _, _ = utils.process_inputs(
        model=model,
        base_position=base_position,
        base_quaternion=base_quaternion,
        joint_positions=joint_positions,
        base_linear_velocity=base_linear_velocity,
        base_angular_velocity=base_angular_velocity,
        joint_velocities=joint_velocities,
    )

    # Get the parent array λ(i).
    # Note: λ(0) must not be used, it's initialized to -1.
    λ = model.kin_dyn_parameters.parent_array

    # Compute the base transform.
    W_H_B = jaxlie.SE3.from_rotation_and_translation(
        rotation=jaxlie.SO3.from_quaternion_xyzw(xyzw=Quaternion.to_xyzw(wxyz=W_Q_B)),
        translation=W_p_B,
    )

    # Compute the parent-to-child adjoints and the motion subspaces of the joints.
    # These transforms define the relative kinematics of the entire model, including
    # the base transform for both floating-base and fixed-base models.
    i_X_λi, S = model.kin_dyn_parameters.joint_transforms_and_motion_subspaces(
        joint_positions=s, base_transform=W_H_B.as_matrix()
    )

    # Allocate buffer of transforms world -> link and initialize the base pose.
    # Element 0 of i_X_λi encodes the base transform, so its inverse gives the
    # world pose of the base link.
    W_X_i = jnp.zeros(shape=(model.number_of_links(), 6, 6))
    W_X_i = W_X_i.at[0].set(Adjoint.inverse(i_X_λi[0]))

    # Allocate buffer of 6D inertial-fixed velocities and initialize the base velocity.
    W_v_Wi = jnp.zeros(shape=(model.number_of_links(), 6))
    W_v_Wi = W_v_Wi.at[0].set(W_v_WB)

    # ====================
    # Propagate kinematics
    # ====================

    PropagateTransformsCarry = tuple[jtp.MatrixJax, jtp.Matrix]
    propagate_transforms_carry: PropagateTransformsCarry = (W_X_i, W_v_Wi)

    # Visits links in increasing index order and reads the parent's entries,
    # so it relies on parents having a lower index than their children.
    def propagate_kinematics(
        carry: PropagateTransformsCarry, i: jtp.Int
    ) -> tuple[PropagateTransformsCarry, None]:

        # Link i is the child of joint ii = i - 1.
        ii = i - 1
        W_X_i, W_v_Wi = carry

        # Compute the parent to child 6D transform.
        λi_X_i = Adjoint.inverse(adjoint=i_X_λi[i])

        # Compute the world to child 6D transform.
        W_Xi_i = W_X_i[λ[i]] @ λi_X_i
        W_X_i = W_X_i.at[i].set(W_Xi_i)

        # Propagate the 6D velocity: parent velocity plus the joint velocity
        # contribution expressed in the inertial-fixed representation.
        W_vi_Wi = W_v_Wi[λ[i]] + W_X_i[i] @ (S[i] * ṡ[ii]).squeeze()
        W_v_Wi = W_v_Wi.at[i].set(W_vi_Wi)

        return (W_X_i, W_v_Wi), None

    (W_X_i, W_v_Wi), _ = (
        jax.lax.scan(
            f=propagate_kinematics,
            init=propagate_transforms_carry,
            xs=jnp.arange(start=1, stop=model.number_of_links()),
        )
        if model.number_of_links() > 1
        else [(W_X_i, W_v_Wi), None]
    )

    # ==================================================
    # Compute position and velocity of collidable points
    # ==================================================

    def process_point_kinematics(
        Li_p_C: jtp.VectorJax, parent_body: jtp.Int
    ) -> tuple[jtp.VectorJax, jtp.VectorJax]:
        # Compute the position of the collidable point
        W_p_Ci = (
            Adjoint.to_transform(adjoint=W_X_i[parent_body]) @ jnp.hstack([Li_p_C, 1])
        )[0:3]

        # Compute the linear part of the mixed velocity Ci[W]_v_{W,Ci}:
        # v_lin - p × ω, written as [I, -skew(p)] @ W_v_{W,i}.
        CW_vl_WCi = (
            jnp.block([jnp.eye(3), -Skew.wedge(vector=W_p_Ci).squeeze()])
            @ W_v_Wi[parent_body].squeeze()
        )

        return W_p_Ci, CW_vl_WCi

    # Process all the collidable points in parallel
    W_p_Ci, CW_vl_WC = jax.vmap(process_point_kinematics)(
        model.kin_dyn_parameters.contact_parameters.point,
        jnp.array(model.kin_dyn_parameters.contact_parameters.body),
    )

    return W_p_Ci, CW_vl_WC
|