jaxsim 0.2.dev191__py3-none-any.whl → 0.6.1.dev2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jaxsim/__init__.py +73 -22
- jaxsim/_version.py +2 -2
- jaxsim/api/__init__.py +13 -1
- jaxsim/api/com.py +423 -0
- jaxsim/api/common.py +48 -19
- jaxsim/api/contact.py +604 -52
- jaxsim/api/data.py +308 -163
- jaxsim/api/frame.py +471 -0
- jaxsim/api/joint.py +166 -37
- jaxsim/api/kin_dyn_parameters.py +901 -0
- jaxsim/api/link.py +277 -78
- jaxsim/api/model.py +1572 -362
- jaxsim/api/ode.py +324 -133
- jaxsim/api/ode_data.py +401 -0
- jaxsim/api/references.py +216 -80
- jaxsim/exceptions.py +80 -0
- jaxsim/integrators/__init__.py +2 -2
- jaxsim/integrators/common.py +191 -107
- jaxsim/integrators/fixed_step.py +97 -102
- jaxsim/integrators/variable_step.py +706 -0
- jaxsim/logging.py +1 -2
- jaxsim/math/__init__.py +13 -0
- jaxsim/math/adjoint.py +57 -22
- jaxsim/math/cross.py +16 -7
- jaxsim/math/inertia.py +10 -8
- jaxsim/math/joint_model.py +289 -0
- jaxsim/math/quaternion.py +54 -20
- jaxsim/math/rotation.py +27 -21
- jaxsim/math/skew.py +16 -5
- jaxsim/math/transform.py +102 -0
- jaxsim/math/utils.py +31 -0
- jaxsim/mujoco/__init__.py +2 -1
- jaxsim/mujoco/loaders.py +216 -29
- jaxsim/mujoco/model.py +163 -33
- jaxsim/mujoco/utils.py +228 -0
- jaxsim/mujoco/visualizer.py +107 -22
- jaxsim/parsers/__init__.py +0 -1
- jaxsim/parsers/descriptions/__init__.py +8 -2
- jaxsim/parsers/descriptions/collision.py +83 -26
- jaxsim/parsers/descriptions/joint.py +80 -87
- jaxsim/parsers/descriptions/link.py +58 -31
- jaxsim/parsers/descriptions/model.py +101 -68
- jaxsim/parsers/kinematic_graph.py +606 -229
- jaxsim/parsers/rod/meshes.py +104 -0
- jaxsim/parsers/rod/parser.py +125 -82
- jaxsim/parsers/rod/utils.py +127 -82
- jaxsim/rbda/__init__.py +11 -0
- jaxsim/rbda/aba.py +289 -0
- jaxsim/rbda/collidable_points.py +156 -0
- jaxsim/rbda/contacts/__init__.py +13 -0
- jaxsim/rbda/contacts/common.py +313 -0
- jaxsim/rbda/contacts/relaxed_rigid.py +605 -0
- jaxsim/rbda/contacts/rigid.py +462 -0
- jaxsim/rbda/contacts/soft.py +480 -0
- jaxsim/rbda/contacts/visco_elastic.py +1066 -0
- jaxsim/rbda/crba.py +167 -0
- jaxsim/rbda/forward_kinematics.py +117 -0
- jaxsim/rbda/jacobian.py +330 -0
- jaxsim/rbda/rnea.py +235 -0
- jaxsim/rbda/utils.py +160 -0
- jaxsim/terrain/__init__.py +2 -0
- jaxsim/terrain/terrain.py +238 -0
- jaxsim/typing.py +24 -24
- jaxsim/utils/__init__.py +1 -4
- jaxsim/utils/jaxsim_dataclass.py +289 -34
- jaxsim/utils/tracing.py +5 -11
- jaxsim/utils/wrappers.py +159 -0
- {jaxsim-0.2.dev191.dist-info → jaxsim-0.6.1.dev2.dist-info}/LICENSE +1 -1
- jaxsim-0.6.1.dev2.dist-info/METADATA +465 -0
- jaxsim-0.6.1.dev2.dist-info/RECORD +74 -0
- {jaxsim-0.2.dev191.dist-info → jaxsim-0.6.1.dev2.dist-info}/WHEEL +1 -1
- jaxsim/high_level/__init__.py +0 -2
- jaxsim/high_level/common.py +0 -11
- jaxsim/high_level/joint.py +0 -148
- jaxsim/high_level/link.py +0 -259
- jaxsim/high_level/model.py +0 -1686
- jaxsim/math/conv.py +0 -114
- jaxsim/math/joint.py +0 -102
- jaxsim/math/plucker.py +0 -100
- jaxsim/physics/__init__.py +0 -12
- jaxsim/physics/algos/__init__.py +0 -0
- jaxsim/physics/algos/aba.py +0 -254
- jaxsim/physics/algos/aba_motors.py +0 -284
- jaxsim/physics/algos/crba.py +0 -154
- jaxsim/physics/algos/forward_kinematics.py +0 -79
- jaxsim/physics/algos/jacobian.py +0 -98
- jaxsim/physics/algos/rnea.py +0 -180
- jaxsim/physics/algos/rnea_motors.py +0 -196
- jaxsim/physics/algos/soft_contacts.py +0 -523
- jaxsim/physics/algos/terrain.py +0 -78
- jaxsim/physics/algos/utils.py +0 -69
- jaxsim/physics/model/__init__.py +0 -0
- jaxsim/physics/model/ground_contact.py +0 -53
- jaxsim/physics/model/physics_model.py +0 -388
- jaxsim/physics/model/physics_model_state.py +0 -283
- jaxsim/simulation/__init__.py +0 -4
- jaxsim/simulation/integrators.py +0 -393
- jaxsim/simulation/ode.py +0 -290
- jaxsim/simulation/ode_data.py +0 -96
- jaxsim/simulation/ode_integration.py +0 -62
- jaxsim/simulation/simulator.py +0 -543
- jaxsim/simulation/simulator_callbacks.py +0 -79
- jaxsim/simulation/utils.py +0 -15
- jaxsim/sixd/__init__.py +0 -2
- jaxsim/utils/oop.py +0 -536
- jaxsim/utils/vmappable.py +0 -117
- jaxsim-0.2.dev191.dist-info/METADATA +0 -184
- jaxsim-0.2.dev191.dist-info/RECORD +0 -81
- {jaxsim-0.2.dev191.dist-info → jaxsim-0.6.1.dev2.dist-info}/top_level.txt +0 -0
@@ -1,284 +0,0 @@
|
|
1
|
-
from typing import Tuple
|
2
|
-
|
3
|
-
import jax
|
4
|
-
import jax.numpy as jnp
|
5
|
-
import numpy as np
|
6
|
-
|
7
|
-
import jaxsim.typing as jtp
|
8
|
-
from jaxsim.math.adjoint import Adjoint
|
9
|
-
from jaxsim.math.cross import Cross
|
10
|
-
from jaxsim.physics.model.physics_model import PhysicsModel
|
11
|
-
|
12
|
-
from . import utils
|
13
|
-
|
14
|
-
|
15
|
-
def aba(
    model: PhysicsModel,
    xfb: jtp.Vector,
    q: jtp.Vector,
    qd: jtp.Vector,
    tau: jtp.Vector,
    f_ext: jtp.Matrix | None = None,
) -> Tuple[jtp.Vector, jtp.Vector]:
    """
    Articulated Body Algorithm (ABA) algorithm with motor dynamics for forward dynamics.

    Args:
        model: The physics model, also providing the joint motor parameters
            (gear ratios, rotor inertias and viscous frictions).
        xfb: The base state. The code reads the quaternion from ``xfb[0:4]``,
            the position from ``xfb[4:7]``, and builds the 6D base velocity
            by stacking ``xfb[10:13]`` on top of ``xfb[7:10]``.
        q: The joint positions.
        qd: The joint velocities.
        tau: The joint torques.
        f_ext: The optional external forces, one entry per link.

    Returns:
        A tuple ``(W_a_WB, qdd)``:
        - ``W_a_WB`` is the base acceleration converted from the body-fixed
          to the inertial-fixed representation, with gravity added. For
          fixed-base models it is a zero column vector.
        - ``qdd`` is the joint accelerations as a column vector, with shape
          ``(0, 1)`` when the model has no joints.
    """

    x_fb, q, qd, _, tau, f_ext = utils.process_inputs(
        physics_model=model, xfb=xfb, q=q, qd=qd, tau=tau, f_ext=f_ext
    )

    # Extract data from the physics model
    pre_X_λi = model.tree_transforms
    M = model.spatial_inertias
    i_X_pre = model.joint_transforms(q=q)
    S = model.motion_subspaces(q=q)
    λ = model.parent_array()

    # Extract motor parameters from the physics model.
    # NOTE(review): Γ is broadcast against S[1:] below, so it presumably has one
    # entry per joint. However, the loops index Γ, IM and K̅ᵥ with the link index
    # `i` rather than the joint index `ii = i - 1`. TODO confirm the alignment.
    Γ = jnp.array([*model._joint_motor_gear_ratio.values()])
    # The list of 6x6 rotor inertias is repeated model.NB times before stacking.
    IM = jnp.array(
        [jnp.eye(6) * m for m in [*model._joint_motor_inertia.values()]] * model.NB
    )
    K̅ᵥ = Γ.T * jnp.array([*model._joint_motor_viscous_friction.values()]) * Γ
    # Motor motion subspaces: joint subspaces scaled by the gear ratios,
    # the base entry (index 0) is kept unchanged.
    m_S = jnp.concatenate([S[:1], S[1:] * Γ[:, None, None]], axis=0)

    # Initialize buffers
    v = jnp.array([jnp.zeros([6, 1])] * model.NB)
    MA = jnp.array([jnp.zeros([6, 6])] * model.NB)
    pA = jnp.array([jnp.zeros([6, 1])] * model.NB)
    c = jnp.array([jnp.zeros([6, 1])] * model.NB)
    i_X_λi = jnp.zeros_like(i_X_pre)

    m_v = jnp.array([jnp.zeros([6, 1])] * model.NB)
    m_c = jnp.array([jnp.zeros([6, 1])] * model.NB)
    pR = jnp.array([jnp.zeros([6, 1])] * model.NB)

    # Base pose B_X_W and velocity
    base_quat = jnp.vstack(x_fb[0:4])
    base_pos = jnp.vstack(x_fb[4:7])
    base_vel = jnp.vstack(jnp.hstack([x_fb[10:13], x_fb[7:10]]))

    # 6D transform of base velocity
    B_X_W = Adjoint.from_quaternion_and_translation(
        quaternion=base_quat,
        translation=base_pos,
        inverse=True,
        normalize_quaternion=True,
    )
    i_X_λi = i_X_λi.at[0].set(B_X_W)

    # Transforms link -> base
    i_X_0 = jnp.zeros_like(pre_X_λi)
    i_X_0 = i_X_0.at[0].set(jnp.eye(6))

    # Initialize base quantities
    if model.is_floating_base:
        # Base velocity v₀
        v_0 = B_X_W @ base_vel
        v = v.at[0].set(v_0)

        # AB inertia (Mᴬ) and AB bias forces (pᴬ)
        MA_0 = M[0]
        MA = MA.at[0].set(MA_0)
        pA_0 = Cross.vx_star(v[0]) @ MA_0 @ v[0] - Adjoint.inverse(
            B_X_W
        ).T @ jnp.vstack(f_ext[0])
        pA = pA.at[0].set(pA_0)

    Pass1Carry = Tuple[
        jtp.MatrixJax,
        jtp.MatrixJax,
        jtp.MatrixJax,
        jtp.MatrixJax,
        jtp.MatrixJax,
        jtp.MatrixJax,
        jtp.MatrixJax,
        jtp.MatrixJax,
        jtp.MatrixJax,
    ]

    pass_1_carry = (i_X_λi, v, c, m_v, m_c, MA, pA, pR, i_X_0)

    # Pass 1: propagate velocities and initialize the articulated quantities.
    def loop_body_pass1(carry: Pass1Carry, i: jtp.Int) -> Tuple[Pass1Carry, None]:
        ii = i - 1
        i_X_λi, v, c, m_v, m_c, MA, pA, pR, i_X_0 = carry

        # Compute parent-to-child transform
        i_X_λi_i = i_X_pre[i] @ pre_X_λi[i]
        i_X_λi = i_X_λi.at[i].set(i_X_λi_i)

        # Propagate link velocity. The (qd.size != 0) factor zeroes the joint
        # contribution for models without joints.
        vJ = S[i] * qd[ii] * (qd.size != 0)
        m_vJ = m_S[i] * qd[ii] * (qd.size != 0)

        v_i = i_X_λi[i] @ v[λ[i]] + vJ
        v = v.at[i].set(v_i)

        m_v_i = i_X_λi[i] @ v[λ[i]] + m_vJ
        m_v = m_v.at[i].set(m_v_i)

        c_i = Cross.vx(v[i]) @ vJ
        c = c.at[i].set(c_i)
        m_c_i = Cross.vx(m_v[i]) @ m_vJ
        m_c = m_c.at[i].set(m_c_i)

        # Initialize articulated-body inertia
        MA_i = jnp.array(M[i])
        MA = MA.at[i].set(MA_i)

        # Initialize articulated-body bias forces
        i_X_0_i = i_X_λi[i] @ i_X_0[model.parent[i]]
        i_X_0 = i_X_0.at[i].set(i_X_0_i)
        i_Xf_W = Adjoint.inverse(i_X_0[i] @ B_X_W).T

        pA_i = Cross.vx_star(v[i]) @ M[i] @ v[i] - i_Xf_W @ jnp.vstack(f_ext[i])
        pA = pA.at[i].set(pA_i)

        pR_i = Cross.vx_star(m_v[i]) @ IM[i] @ m_v[i] - K̅ᵥ[i] * m_v[i]
        pR = pR.at[i].set(pR_i)

        return (i_X_λi, v, c, m_v, m_c, MA, pA, pR, i_X_0), None

    (i_X_λi, v, c, m_v, m_c, MA, pA, pR, i_X_0), _ = jax.lax.scan(
        f=loop_body_pass1,
        init=pass_1_carry,
        xs=np.arange(start=1, stop=model.NB),
    )

    U = jnp.zeros_like(S)
    m_U = jnp.zeros_like(S)
    d = jnp.zeros(shape=(model.NB, 1))
    u = jnp.zeros(shape=(model.NB, 1))
    m_u = jnp.zeros(shape=(model.NB, 1))

    Pass2Carry = Tuple[
        jtp.MatrixJax,
        jtp.MatrixJax,
        jtp.MatrixJax,
        jtp.MatrixJax,
        jtp.MatrixJax,
        jtp.MatrixJax,
        jtp.MatrixJax,
    ]

    pass_2_carry = (U, m_U, d, u, m_u, MA, pA)

    # Pass 2: leaves-to-root accumulation of articulated inertias and biases.
    def loop_body_pass2(carry: Pass2Carry, i: jtp.Int) -> Tuple[Pass2Carry, None]:
        ii = i - 1
        U, m_U, d, u, m_u, MA, pA = carry

        # Compute intermediate results
        u_i = tau[ii] - S[i].T @ pA[i] if tau.size != 0 else -S[i].T @ pA[i]
        u = u.at[i].set(u_i.squeeze())

        # NOTE(review): this flag is True when the gear ratio equals 1, which
        # reads as the opposite of its name. Behavior kept as-is; TODO confirm.
        has_motors = jnp.allclose(Γ[i], 1.0)

        m_u_i = (
            tau[ii] / Γ[i] * has_motors - m_S[i].T @ pR[i]
            if tau.size != 0
            else -m_S[i].T @ pR[i]
        )
        m_u = m_u.at[i].set(m_u_i.squeeze())

        U_i = MA[i] @ S[i]
        U = U.at[i].set(U_i)

        m_U_i = IM[i] @ m_S[i]
        m_U = m_U.at[i].set(m_U_i)

        d_i = S[i].T @ MA[i] @ S[i] + m_S[i].T @ IM[i] @ m_S[i]
        d = d.at[i].set(d_i.squeeze())

        # Compute the articulated-body inertia and bias forces of this link.
        # Ma is the full 6x6 articulated inertia of link i.
        Ma = MA[i] + IM[i] - U[i] / d[i] @ U[i].T - m_U[i] / d[i] @ m_U[i].T
        pa = (
            pA[i]
            + pR[i]
            # Fixed: Ma is already the 6x6 matrix of link i, so it multiplies
            # c[i] directly. The previous `Ma[i]` selected a single row.
            + Ma @ c[i]
            + IM[i] @ m_c[i]
            + U[i] / d[i] * u[i]
            + m_U[i] / d[i] * m_u[i]
        )

        # Propagate them to the parent, handling the base link
        def propagate(
            MA_pA: Tuple[jtp.MatrixJax, jtp.MatrixJax]
        ) -> Tuple[jtp.MatrixJax, jtp.MatrixJax]:
            MA, pA = MA_pA

            MA_λi = MA[λ[i]] + i_X_λi[i].T @ Ma @ i_X_λi[i]
            MA = MA.at[λ[i]].set(MA_λi)

            pA_λi = pA[λ[i]] + i_X_λi[i].T @ pa
            pA = pA.at[λ[i]].set(pA_λi)

            return MA, pA

        # Skip the propagation into the base of fixed-base models.
        MA, pA = jax.lax.cond(
            pred=jnp.array([λ[i] != 0, model.is_floating_base]).any(),
            true_fun=propagate,
            false_fun=lambda MA_pA: MA_pA,
            operand=(MA, pA),
        )

        return (U, m_U, d, u, m_u, MA, pA), None

    (U, m_U, d, u, m_u, MA, pA), _ = jax.lax.scan(
        f=loop_body_pass2,
        init=pass_2_carry,
        xs=np.flip(np.arange(start=1, stop=model.NB)),
    )

    if model.is_floating_base:
        a0 = jnp.linalg.solve(-MA[0], pA[0])
    else:
        a0 = -B_X_W @ jnp.vstack(model.gravity)

    a = jnp.zeros_like(S)
    a = a.at[0].set(a0)
    qdd = jnp.zeros_like(q)

    Pass3Carry = Tuple[jtp.MatrixJax, jtp.VectorJax]
    pass_3_carry = (a, qdd)

    # Pass 3: root-to-leaves computation of accelerations.
    def loop_body_pass3(carry: Pass3Carry, i: jtp.Int) -> Tuple[Pass3Carry, None]:
        ii = i - 1
        a, qdd = carry

        # Propagate link accelerations
        a_i = i_X_λi[i] @ a[λ[i]] + c[i]

        # Compute joint accelerations
        qdd_ii = (u[i] + m_u[i] - (U[i].T + m_U[i].T) @ a_i) / d[i]
        qdd = qdd.at[ii].set(qdd_ii.squeeze()) if qdd.size != 0 else qdd

        a_i = a_i + S[i] * qdd[ii] if qdd.size != 0 else a_i
        a = a.at[i].set(a_i)

        return (a, qdd), None

    (a, qdd), _ = jax.lax.scan(
        f=loop_body_pass3,
        init=pass_3_carry,
        xs=np.arange(1, model.NB),
    )

    # Handle 1 DoF models
    qdd = jnp.atleast_1d(qdd.squeeze())
    qdd = jnp.vstack(qdd) if qdd.size > 0 else jnp.empty(shape=(0, 1))

    # Get the resulting base acceleration (w/o gravity) in body-fixed representation
    B_a_WB = a[0]

    # Convert the base acceleration to inertial-fixed representation, and add gravity
    W_a_WB = jnp.vstack(
        jnp.linalg.solve(B_X_W, B_a_WB) + jnp.vstack(model.gravity)
        if model.is_floating_base
        else jnp.zeros(6)
    )

    return W_a_WB, qdd
|
jaxsim/physics/algos/crba.py
DELETED
@@ -1,154 +0,0 @@
|
|
1
|
-
from typing import Tuple
|
2
|
-
|
3
|
-
import jax
|
4
|
-
import jax.numpy as jnp
|
5
|
-
import numpy as np
|
6
|
-
|
7
|
-
import jaxsim.typing as jtp
|
8
|
-
from jaxsim.physics.model.physics_model import PhysicsModel
|
9
|
-
|
10
|
-
from . import utils
|
11
|
-
|
12
|
-
|
13
|
-
def crba(model: PhysicsModel, q: jtp.Vector) -> jtp.Matrix:
    """
    Compute the mass matrix of an articulated body with the Composite Rigid-Body Algorithm (CRBA).

    Args:
        model (PhysicsModel): The physics model of the articulated body or robot.
        q (jtp.Vector): Joint positions (Generalized coordinates).

    Returns:
        jtp.Matrix: The (6 + n) x (6 + n) mass matrix, with n = ``model.dofs()``.
        The top-left 6x6 block holds the locked inertia of the whole system.
    """

    _, q, _, _, _, _ = utils.process_inputs(
        physics_model=model, xfb=None, q=q, qd=None, tau=None, f_ext=None
    )

    pre_X_λi = model.tree_transforms
    composite_inertias = model.spatial_inertias
    S = model.motion_subspaces(q=q)
    i_X_pre = model.joint_transforms(q=q)

    # Parent array i -> λ(i); λ(0) is initialized to -1 and must never be read.
    λ = model.parent

    link_indices = np.arange(start=1, stop=model.NB)
    reversed_link_indices = np.flip(link_indices)

    # ---------------------------------------------------------------
    # Forward pass: parent-to-child and base-to-link 6D transforms
    # ---------------------------------------------------------------
    def forward_step(carry, link):
        i_X_λi, i_X_0 = carry
        i_X_λi = i_X_λi.at[link].set(i_X_pre[link] @ pre_X_λi[link])
        i_X_0 = i_X_0.at[link].set(i_X_λi[link] @ i_X_0[λ[link]])
        return (i_X_λi, i_X_0), None

    (i_X_λi, i_X_0), _ = jax.lax.scan(
        f=forward_step,
        init=(
            jnp.zeros_like(pre_X_λi),
            jnp.zeros_like(pre_X_λi).at[0].set(jnp.eye(6)),
        ),
        xs=link_indices,
    )

    # ---------------------------------------------------------------
    # Backward pass: composite inertias and mass matrix entries
    # ---------------------------------------------------------------
    size = 6 + model.dofs()

    def backward_step(carry, link):
        Mc, M = carry

        # Row/column of the joint of this link inside M.
        col = link - 1 + 6

        # Accumulate the composite inertia of this link into its parent.
        parent = λ[link]
        Mc = Mc.at[parent].set(
            Mc[parent] + i_X_λi[link].T @ Mc[link] @ i_X_λi[link]
        )

        force = Mc[link] @ S[link]
        M = M.at[col, col].set((S[link].T @ force).squeeze())

        # One step up the kinematic chain: express the force in the parent frame
        # and fill the symmetric coupling entries with the parent's joint.
        def climb(state):
            j, force, M = state
            force = i_X_λi[j].T @ force
            j = λ[j]
            M_ij = (force.T @ S[j]).squeeze()
            M = M.at[col, j - 1 + 6].set(M_ij)
            M = M.at[j - 1 + 6, col].set(M_ij)
            return j, force, M

        # Emulate a while loop with a fixed-length scan: since parents have lower
        # indices (BFS ordering), visiting k in decreasing order reaches every
        # ancestor j exactly when k == j. The climb stops at links whose parent is 0.
        def maybe_climb(state, k):
            j = state[0]
            state = jax.lax.cond(
                pred=(k == j) & (λ[j] > 0),
                true_fun=climb,
                false_fun=lambda s: s,
                operand=state,
            )
            return state, None

        (j, force, M), _ = jax.lax.scan(
            f=maybe_climb,
            init=(link, force, M),
            xs=reversed_link_indices,
        )

        # Base-joint coupling block.
        force = i_X_0[j].T @ force
        M = M.at[0:6, col].set(force.squeeze())
        M = M.at[col, 0:6].set(force.squeeze())

        return (Mc, M), None

    (composite_inertias, M), _ = jax.lax.scan(
        f=backward_step,
        init=(composite_inertias, jnp.zeros(shape=(size, size))),
        xs=reversed_link_indices,
    )

    # Store the locked 6D rigid-body inertia matrix Mbb ∈ ℝ⁶ˣ⁶
    return M.at[0:6, 0:6].set(composite_inertias[0])
|
@@ -1,79 +0,0 @@
|
|
1
|
-
from typing import Tuple
|
2
|
-
|
3
|
-
import jax
|
4
|
-
import jax.numpy as jnp
|
5
|
-
import numpy as np
|
6
|
-
|
7
|
-
import jaxsim.typing as jtp
|
8
|
-
from jaxsim.math.adjoint import Adjoint
|
9
|
-
from jaxsim.physics.model.physics_model import PhysicsModel
|
10
|
-
|
11
|
-
from . import utils
|
12
|
-
|
13
|
-
|
14
|
-
def forward_kinematics_model(
    model: PhysicsModel, q: jtp.Vector, xfb: jtp.Vector
) -> jtp.Array:
    """
    Compute the world transforms of all the links of an articulated body or robot.

    Args:
        model (PhysicsModel): The physics model of the articulated body or robot.
        q (jtp.Vector): Joint positions (Generalized coordinates).
        xfb (jtp.Vector): The base pose vector; the code reads the quaternion from
            ``xfb[0:4]`` and the translation from ``xfb[4:7]``.

    Returns:
        jtp.Array: The stack (one entry per link) of transforms obtained by
        converting each link-to-world 6D adjoint with ``Adjoint.to_transform``.
    """

    x_fb, q, _, _, _, _ = utils.process_inputs(
        physics_model=model, xfb=xfb, q=q, qd=None, tau=None, f_ext=None
    )

    # Parent array i -> λ(i); λ(0) is initialized to -1 and must never be read.
    λ = model.parent

    joint_X = model.joint_transforms(q=q)
    tree_X = model.tree_transforms

    W_X_base = Adjoint.from_quaternion_and_translation(
        quaternion=x_fb[0:4], translation=x_fb[4:7]
    )

    # 6D velocity transforms from every link frame to the world frame.
    W_X_links = jnp.zeros(shape=[model.NB, 6, 6]).at[0].set(W_X_base)

    def chain_step(W_X, link):
        # Child-to-parent transform is the inverse of the parent-to-child one.
        λi_X_i = Adjoint.inverse(joint_X[link] @ tree_X[link])
        return W_X.at[link].set(W_X[λ[link]] @ λi_X_i), None

    W_X_links, _ = jax.lax.scan(
        f=chain_step,
        init=W_X_links,
        xs=np.arange(start=1, stop=model.NB),
    )

    transforms = [Adjoint.to_transform(adjoint=X) for X in list(W_X_links)]
    return jnp.stack(transforms)
|
74
|
-
|
75
|
-
|
76
|
-
def forward_kinematics(
    model: PhysicsModel, body_index: jtp.Int, q: jtp.Vector, xfb: jtp.Vector
) -> jtp.Matrix:
    """
    Return the world transform of a single link.

    Args:
        model (PhysicsModel): The physics model of the articulated body or robot.
        body_index (jtp.Int): The index of the link of interest.
        q (jtp.Vector): Joint positions (Generalized coordinates).
        xfb (jtp.Vector): The base pose vector.

    Returns:
        jtp.Matrix: The entry ``body_index`` of ``forward_kinematics_model``.
    """

    all_link_transforms = forward_kinematics_model(model=model, q=q, xfb=xfb)
    return all_link_transforms[body_index]
|
jaxsim/physics/algos/jacobian.py
DELETED
@@ -1,98 +0,0 @@
|
|
1
|
-
from typing import Tuple
|
2
|
-
|
3
|
-
import jax
|
4
|
-
import jax.numpy as jnp
|
5
|
-
import numpy as np
|
6
|
-
|
7
|
-
import jaxsim.typing as jtp
|
8
|
-
from jaxsim.math.adjoint import Adjoint
|
9
|
-
from jaxsim.physics.model.physics_model import PhysicsModel
|
10
|
-
|
11
|
-
from . import utils
|
12
|
-
|
13
|
-
|
14
|
-
def jacobian(model: PhysicsModel, body_index: jtp.Int, q: jtp.Vector) -> jtp.Matrix:
    """
    Compute the doubly-left Jacobian matrix of a link of an articulated body or robot.

    Args:
        model (PhysicsModel): The physics model of the articulated body or robot.
        body_index (jtp.Int): The index of the link for which to compute the Jacobian matrix.
        q (jtp.Vector): Joint positions (Generalized coordinates).

    Returns:
        jtp.Matrix: The 6 x (6 + n) Jacobian of the link, with n = ``model.dofs()``.
        The columns are filled as follows:
        - The first six columns hold the base-to-link adjoint of the link.
        - A joint column is filled only when
          ``model.support_body_array_bool`` flags the joint's link as
          supporting ``body_index``.
        - All other columns stay zero.
    """

    _, q, _, _, _, _ = utils.process_inputs(physics_model=model, q=q)

    S = model.motion_subspaces(q=q)
    i_X_pre = model.joint_transforms(q=q)
    pre_X_λi = model.tree_transforms

    # Parent array i -> λ(i); λ(0) is initialized to -1 and must never be read.
    λ = model.parent

    link_indices = np.arange(start=1, stop=model.NB)

    # Base (0) to link (i) adjoints. Reading i_X_0[λ(i)] is safe because links
    # are indexed with BFS, so every parent is processed before its children.
    def base_to_link(i_X_0, link):
        i_X_λi = i_X_pre[link] @ pre_X_λi[link]
        return i_X_0.at[link].set(i_X_λi @ i_X_0[λ[link]]), None

    i_X_0, _ = jax.lax.scan(
        f=base_to_link,
        init=jnp.zeros_like(i_X_pre).at[0].set(jnp.eye(6)),
        xs=link_indices,
    )

    L_X_0 = i_X_0[body_index]

    # JIT-friendly support set: κ_bool[j] is True iff j ∈ κ(body_index).
    κ_bool = model.support_body_array_bool(body_index=body_index)

    def fill_column(J, link):
        column = (L_X_0 @ Adjoint.inverse(i_X_0[link]) @ S[link]).squeeze()
        J_with_column = J.at[0:6, 6 + link - 1].set(column)
        return jax.lax.select(pred=κ_bool[link], on_true=J_with_column, on_false=J), None

    J, _ = jax.lax.scan(
        f=fill_column,
        init=jnp.zeros(shape=(6, 6 + model.dofs())).at[0:6, 0:6].set(L_X_0),
        xs=link_indices,
    )

    return J
|