jaxsim 0.2.dev188__py3-none-any.whl → 0.6.1.dev2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jaxsim/__init__.py +73 -22
- jaxsim/_version.py +2 -2
- jaxsim/api/__init__.py +13 -1
- jaxsim/api/com.py +423 -0
- jaxsim/api/common.py +48 -19
- jaxsim/api/contact.py +604 -52
- jaxsim/api/data.py +308 -163
- jaxsim/api/frame.py +471 -0
- jaxsim/api/joint.py +166 -37
- jaxsim/api/kin_dyn_parameters.py +901 -0
- jaxsim/api/link.py +277 -78
- jaxsim/api/model.py +1572 -362
- jaxsim/api/ode.py +324 -129
- jaxsim/api/ode_data.py +401 -0
- jaxsim/api/references.py +216 -80
- jaxsim/exceptions.py +80 -0
- jaxsim/integrators/__init__.py +2 -2
- jaxsim/integrators/common.py +191 -107
- jaxsim/integrators/fixed_step.py +97 -102
- jaxsim/integrators/variable_step.py +706 -0
- jaxsim/logging.py +1 -2
- jaxsim/math/__init__.py +13 -0
- jaxsim/math/adjoint.py +57 -22
- jaxsim/math/cross.py +16 -7
- jaxsim/math/inertia.py +10 -8
- jaxsim/math/joint_model.py +289 -0
- jaxsim/math/quaternion.py +54 -20
- jaxsim/math/rotation.py +27 -21
- jaxsim/math/skew.py +16 -5
- jaxsim/math/transform.py +102 -0
- jaxsim/math/utils.py +31 -0
- jaxsim/mujoco/__init__.py +2 -1
- jaxsim/mujoco/loaders.py +216 -29
- jaxsim/mujoco/model.py +163 -33
- jaxsim/mujoco/utils.py +228 -0
- jaxsim/mujoco/visualizer.py +107 -22
- jaxsim/parsers/__init__.py +0 -1
- jaxsim/parsers/descriptions/__init__.py +8 -2
- jaxsim/parsers/descriptions/collision.py +87 -16
- jaxsim/parsers/descriptions/joint.py +80 -87
- jaxsim/parsers/descriptions/link.py +62 -24
- jaxsim/parsers/descriptions/model.py +101 -68
- jaxsim/parsers/kinematic_graph.py +607 -225
- jaxsim/parsers/rod/meshes.py +104 -0
- jaxsim/parsers/rod/parser.py +125 -82
- jaxsim/parsers/rod/utils.py +127 -82
- jaxsim/rbda/__init__.py +11 -0
- jaxsim/rbda/aba.py +289 -0
- jaxsim/rbda/collidable_points.py +156 -0
- jaxsim/rbda/contacts/__init__.py +13 -0
- jaxsim/rbda/contacts/common.py +313 -0
- jaxsim/rbda/contacts/relaxed_rigid.py +605 -0
- jaxsim/rbda/contacts/rigid.py +462 -0
- jaxsim/rbda/contacts/soft.py +480 -0
- jaxsim/rbda/contacts/visco_elastic.py +1066 -0
- jaxsim/rbda/crba.py +167 -0
- jaxsim/rbda/forward_kinematics.py +117 -0
- jaxsim/rbda/jacobian.py +330 -0
- jaxsim/rbda/rnea.py +235 -0
- jaxsim/rbda/utils.py +160 -0
- jaxsim/terrain/__init__.py +2 -0
- jaxsim/terrain/terrain.py +238 -0
- jaxsim/typing.py +24 -24
- jaxsim/utils/__init__.py +1 -4
- jaxsim/utils/jaxsim_dataclass.py +289 -34
- jaxsim/utils/tracing.py +5 -11
- jaxsim/utils/wrappers.py +159 -0
- {jaxsim-0.2.dev188.dist-info → jaxsim-0.6.1.dev2.dist-info}/LICENSE +1 -1
- jaxsim-0.6.1.dev2.dist-info/METADATA +465 -0
- jaxsim-0.6.1.dev2.dist-info/RECORD +74 -0
- {jaxsim-0.2.dev188.dist-info → jaxsim-0.6.1.dev2.dist-info}/WHEEL +1 -1
- jaxsim/high_level/__init__.py +0 -2
- jaxsim/high_level/common.py +0 -11
- jaxsim/high_level/joint.py +0 -148
- jaxsim/high_level/link.py +0 -259
- jaxsim/high_level/model.py +0 -1686
- jaxsim/math/conv.py +0 -114
- jaxsim/math/joint.py +0 -102
- jaxsim/math/plucker.py +0 -100
- jaxsim/physics/__init__.py +0 -12
- jaxsim/physics/algos/__init__.py +0 -0
- jaxsim/physics/algos/aba.py +0 -254
- jaxsim/physics/algos/aba_motors.py +0 -284
- jaxsim/physics/algos/crba.py +0 -154
- jaxsim/physics/algos/forward_kinematics.py +0 -79
- jaxsim/physics/algos/jacobian.py +0 -98
- jaxsim/physics/algos/rnea.py +0 -180
- jaxsim/physics/algos/rnea_motors.py +0 -196
- jaxsim/physics/algos/soft_contacts.py +0 -523
- jaxsim/physics/algos/terrain.py +0 -80
- jaxsim/physics/algos/utils.py +0 -69
- jaxsim/physics/model/__init__.py +0 -0
- jaxsim/physics/model/ground_contact.py +0 -55
- jaxsim/physics/model/physics_model.py +0 -388
- jaxsim/physics/model/physics_model_state.py +0 -283
- jaxsim/simulation/__init__.py +0 -4
- jaxsim/simulation/integrators.py +0 -393
- jaxsim/simulation/ode.py +0 -290
- jaxsim/simulation/ode_data.py +0 -96
- jaxsim/simulation/ode_integration.py +0 -62
- jaxsim/simulation/simulator.py +0 -543
- jaxsim/simulation/simulator_callbacks.py +0 -79
- jaxsim/simulation/utils.py +0 -15
- jaxsim/sixd/__init__.py +0 -2
- jaxsim/utils/oop.py +0 -536
- jaxsim/utils/vmappable.py +0 -117
- jaxsim-0.2.dev188.dist-info/METADATA +0 -184
- jaxsim-0.2.dev188.dist-info/RECORD +0 -81
- {jaxsim-0.2.dev188.dist-info → jaxsim-0.6.1.dev2.dist-info}/top_level.txt +0 -0
jaxsim/physics/algos/rnea.py
DELETED
@@ -1,180 +0,0 @@
|
|
1
|
-
from typing import Tuple
|
2
|
-
|
3
|
-
import jax
|
4
|
-
import jax.numpy as jnp
|
5
|
-
import numpy as np
|
6
|
-
|
7
|
-
import jaxsim.typing as jtp
|
8
|
-
from jaxsim.math.adjoint import Adjoint
|
9
|
-
from jaxsim.math.cross import Cross
|
10
|
-
from jaxsim.physics.model.physics_model import PhysicsModel
|
11
|
-
|
12
|
-
from . import utils
|
13
|
-
|
14
|
-
|
15
|
-
def rnea(
    model: PhysicsModel,
    xfb: jtp.Vector,
    q: jtp.Vector,
    qd: jtp.Vector,
    qdd: jtp.Vector,
    a0fb: jtp.Vector = jnp.zeros(6),
    f_ext: jtp.Matrix | None = None,
) -> Tuple[jtp.Vector, jtp.Vector]:
    """
    Compute inverse dynamics with the Recursive Newton-Euler Algorithm (RNEA).

    Given the base state, the joint positions/velocities/accelerations and the
    optional external forces, compute the joint generalized forces and the base
    6D force that realize the requested motion.

    Args:
        model: The physics model providing spatial inertias, tree/joint
            transforms, motion subspaces and the parent array.
        xfb: The base state. Slice [0:4] is read as the base orientation
            quaternion and [4:7] as the base position. For floating-base models,
            slices [7:10] and [10:13] are also read and stacked as
            ``[xfb[10:13], xfb[7:10]]`` to form the world-frame 6D base velocity.
        q: Joint positions.
        qd: Joint velocities.
        qdd: Joint accelerations.
        a0fb: The 6D base acceleration, only used for floating-base models.
        f_ext: Optional external 6D forces, one entry per body, normalized by
            ``utils.process_inputs``.

    Returns:
        A tuple ``(W_f0, tau)``: the base 6D force expressed in the world frame,
        and the joint generalized forces as a column vector (shape ``(0, 1)``
        when the model has no joints).

    Raises:
        ValueError: If ``a0fb`` does not have 6 elements.
    """

    xfb, q, qd, qdd, _, f_ext = utils.process_inputs(
        physics_model=model, xfb=xfb, q=q, qd=qd, qdd=qdd, f_ext=f_ext
    )

    a0fb = a0fb.squeeze()
    g = model.gravity.squeeze()

    if a0fb.shape[0] != 6:
        raise ValueError(a0fb.shape)

    # Body-indexed model quantities.
    M = model.spatial_inertias
    pre_X_λi = model.tree_transforms
    i_X_pre = model.joint_transforms(q=q)
    S = model.motion_subspaces(q=q)

    # Parent array i -> λ(i); λ(0) is initialized to -1 and must never be read.
    λ = model.parent_array()

    num_bodies = model.NB
    g_col = jnp.vstack(g)

    # Transform from world coordinates to base coordinates.
    B_X_W = Adjoint.from_quaternion_and_translation(
        quaternion=xfb[0:4],
        translation=xfb[4:7],
        inverse=True,
        normalize_quaternion=True,
    )

    # Per-body buffers: parent-to-child transforms, 0-to-body transforms,
    # and 6D velocities / accelerations / forces stored as (6, 1) columns.
    X_up = jnp.zeros_like(pre_X_λi).at[0].set(B_X_W)
    X_from_0 = jnp.zeros_like(pre_X_λi).at[0].set(jnp.eye(6))
    vel = jnp.zeros((num_bodies, 6, 1))
    acc = jnp.zeros((num_bodies, 6, 1))
    frc = jnp.zeros((num_bodies, 6, 1))

    if not model.is_floating_base:
        # Fixed base: gravity is modeled as a fictitious base acceleration.
        acc = acc.at[0].set(-B_X_W @ g_col)
    else:
        W_v_WB = jnp.vstack(jnp.hstack([xfb[10:13], xfb[7:10]]))
        vel = vel.at[0].set(B_X_W @ W_v_WB)
        acc = acc.at[0].set(B_X_W @ (jnp.vstack(a0fb) - g_col))
        frc = frc.at[0].set(
            M[0] @ acc[0]
            + Cross.vx_star(vel[0]) @ M[0] @ vel[0]
            - Adjoint.inverse(B_X_W).T @ jnp.vstack(f_ext[0])
        )

    def forward_step(carry, body):
        X_up, vel, acc, X_from_0, frc = carry
        joint = body - 1  # joint-indexed arrays are shifted by one w.r.t. bodies
        parent = λ[body]

        vJ = S[body] * qd[joint]
        X_up = X_up.at[body].set(i_X_pre[body] @ pre_X_λi[body])

        vel = vel.at[body].set(X_up[body] @ vel[parent] + vJ)
        acc = acc.at[body].set(
            X_up[body] @ acc[parent]
            + S[body] * qdd[joint]
            + Cross.vx(vel[body]) @ vJ
        )

        X_from_0 = X_from_0.at[body].set(X_up[body] @ X_from_0[parent])
        # Maps world-frame external forces to the body frame.
        body_Xf_W = Adjoint.inverse(X_from_0[body] @ B_X_W).T

        frc = frc.at[body].set(
            M[body] @ acc[body]
            + Cross.vx_star(vel[body]) @ M[body] @ vel[body]
            - body_Xf_W @ jnp.vstack(f_ext[body])
        )

        return (X_up, vel, acc, X_from_0, frc), None

    bodies = np.arange(start=1, stop=num_bodies)

    (X_up, vel, acc, X_from_0, frc), _ = jax.lax.scan(
        f=forward_step,
        init=(X_up, vel, acc, X_from_0, frc),
        xs=bodies,
    )

    def backward_step(carry, body):
        tau, frc = carry
        parent = λ[body]

        tau = tau.at[body - 1].set((S[body].T @ frc[body]).squeeze())

        def accumulate_into_parent(frc):
            return frc.at[parent].set(frc[parent] + X_up[body].T @ frc[body])

        # For fixed-base models, the forces of bodies attached to the base
        # are not propagated to it.
        frc = jax.lax.cond(
            pred=jnp.logical_or(parent != 0, model.is_floating_base),
            true_fun=accumulate_into_parent,
            false_fun=lambda frc: frc,
            operand=frc,
        )

        return (tau, frc), None

    (tau, frc), _ = jax.lax.scan(
        f=backward_step,
        init=(jnp.zeros_like(q), frc),
        xs=bodies[::-1],
    )

    # Always return a column vector, also for 0-DoF and 1-DoF models.
    tau = jnp.atleast_1d(tau.squeeze())
    tau = jnp.empty(shape=(0, 1)) if tau.size == 0 else jnp.vstack(tau)

    # Express the base 6D force in the world frame.
    W_f0 = B_X_W.T @ jnp.vstack(frc[0])

    return W_f0, tau
|
jaxsim/physics/algos/rnea_motors.py
DELETED
@@ -1,196 +0,0 @@
|
|
1
|
-
from typing import Tuple
|
2
|
-
|
3
|
-
import jax
|
4
|
-
import jax.numpy as jnp
|
5
|
-
import numpy as np
|
6
|
-
|
7
|
-
import jaxsim.typing as jtp
|
8
|
-
from jaxsim.math.adjoint import Adjoint
|
9
|
-
from jaxsim.math.cross import Cross
|
10
|
-
from jaxsim.physics.model.physics_model import PhysicsModel
|
11
|
-
|
12
|
-
from . import utils
|
13
|
-
|
14
|
-
|
15
|
-
def rnea(
    model: PhysicsModel,
    xfb: jtp.Vector,
    q: jtp.Vector,
    qd: jtp.Vector,
    qdd: jtp.Vector,
    a0fb: jtp.Vector = jnp.zeros(6),
    f_ext: jtp.Matrix | None = None,
) -> Tuple[jtp.Vector, jtp.Vector]:
    """
    Recursive Newton-Euler Algorithm (RNEA) for inverse dynamics, with motor terms.

    Besides the rigid-body recursion, this variant propagates a second set of
    velocities, accelerations and forces (``v_m``, ``a_m``, ``f_m``) built from
    the motion subspaces scaled by the joint motor gear ratios and from the
    joint motor inertias. Their projection on the scaled subspaces is added to
    the joint generalized forces.

    Args:
        model: The physics model providing spatial inertias, transforms, motion
            subspaces, the parent array and the joint motor parameters.
        xfb: The base state. Slice [0:4] is read as the base orientation
            quaternion and [4:7] as the base position. For floating-base models,
            slices [7:10] and [10:13] are stacked as ``[xfb[10:13], xfb[7:10]]``
            to form the world-frame 6D base velocity.
        q: Joint positions.
        qd: Joint velocities.
        qdd: Joint accelerations.
        a0fb: The 6D base acceleration, only used for floating-base models.
        f_ext: Optional external 6D forces, one entry per body, normalized by
            ``utils.process_inputs``.

    Returns:
        A tuple ``(W_f0, tau)``: the base 6D force expressed in the world frame,
        and the joint generalized forces as a column vector (shape ``(0, 1)``
        when the model has no joints).

    Raises:
        ValueError: If ``a0fb`` does not have 6 elements.
    """

    xfb, q, qd, qdd, _, f_ext = utils.process_inputs(
        physics_model=model, xfb=xfb, q=q, qd=qd, qdd=qdd, f_ext=f_ext
    )

    a0fb = a0fb.squeeze()
    gravity = model.gravity.squeeze()

    if a0fb.shape[0] != 6:
        raise ValueError(a0fb.shape)

    M = model.spatial_inertias
    pre_X_λi = model.tree_transforms
    i_X_pre = model.joint_transforms(q=q)
    S = model.motion_subspaces(q=q)
    i_X_λi = jnp.zeros_like(pre_X_λi)

    # Per-joint motor parameters. These arrays are indexed by JOINT (ii = i - 1),
    # not by body i. Γ must have one entry per joint for the m_S broadcast below.
    # Assumes the motor-inertia dict has the same keys/order as the gear-ratio
    # dict (one entry per joint) — TODO confirm.
    Γ = jnp.array([*model._joint_motor_gear_ratio.values()])
    IM = jnp.array([*model._joint_motor_inertia.values()])

    # NOTE: motor viscous friction (model._joint_motor_viscous_friction) is not
    # included in the returned joint forces.

    # Motion subspaces scaled by the gear ratios; the base entry is left as-is.
    m_S = jnp.concatenate([S[:1], S[1:] * Γ[:, None, None]], axis=0)

    i_X_0 = jnp.zeros_like(pre_X_λi)
    i_X_0 = i_X_0.at[0].set(jnp.eye(6))

    # Parent array mapping: i -> λ(i).
    # Exception: λ(0) must not be used, it's initialized to -1.
    λ = model.parent_array()

    v = jnp.array([jnp.zeros([6, 1])] * model.NB)
    a = jnp.array([jnp.zeros([6, 1])] * model.NB)
    f = jnp.array([jnp.zeros([6, 1])] * model.NB)

    # Motor-side quantities, propagated with the scaled subspaces m_S.
    v_m = jnp.array([jnp.zeros([6, 1])] * model.NB)
    a_m = jnp.array([jnp.zeros([6, 1])] * model.NB)
    f_m = jnp.array([jnp.zeros([6, 1])] * model.NB)

    # Transform from world coordinates to base coordinates.
    B_X_W = Adjoint.from_quaternion_and_translation(
        quaternion=xfb[0:4],
        translation=xfb[4:7],
        inverse=True,
        normalize_quaternion=True,
    )
    i_X_λi = i_X_λi.at[0].set(B_X_W)

    # Gravity modeled as a fictitious base acceleration (overwritten below for
    # floating-base models, where the base acceleration is given explicitly).
    a_0 = -B_X_W @ jnp.vstack(gravity)
    a = a.at[0].set(a_0)

    if model.is_floating_base:
        W_v_WB = jnp.vstack(jnp.hstack([xfb[10:13], xfb[7:10]]))

        v_0 = B_X_W @ W_v_WB
        v = v.at[0].set(v_0)

        a_0 = B_X_W @ (jnp.vstack(a0fb) - jnp.vstack(gravity))
        a = a.at[0].set(a_0)

        f_0 = (
            M[0] @ a[0]
            + Cross.vx_star(v[0]) @ M[0] @ v[0]
            - Adjoint.inverse(B_X_W).T @ jnp.vstack(f_ext[0])
        )
        f = f.at[0].set(f_0)

    ForwardPassCarry = Tuple[
        jtp.MatrixJax,
        jtp.MatrixJax,
        jtp.MatrixJax,
        jtp.MatrixJax,
        jtp.MatrixJax,
        jtp.MatrixJax,
        jtp.MatrixJax,
        jtp.MatrixJax,
    ]
    forward_pass_carry = (i_X_λi, v, v_m, a, a_m, i_X_0, f, f_m)

    def forward_pass(
        carry: ForwardPassCarry, i: jtp.Int
    ) -> Tuple[ForwardPassCarry, None]:
        ii = i - 1  # joint index associated with body i
        i_X_λi, v, v_m, a, a_m, i_X_0, f, f_m = carry

        vJ = S[i] * qd[ii]
        vJ_m = m_S[i] * qd[ii]

        i_X_λi_i = i_X_pre[i] @ pre_X_λi[i]
        i_X_λi = i_X_λi.at[i].set(i_X_λi_i)

        v_i = i_X_λi[i] @ v[λ[i]] + vJ
        v = v.at[i].set(v_i)

        v_i_m = i_X_λi[i] @ v_m[λ[i]] + vJ_m
        v_m = v_m.at[i].set(v_i_m)

        a_i = i_X_λi[i] @ a[λ[i]] + S[i] * qdd[ii] + Cross.vx(v[i]) @ vJ
        a = a.at[i].set(a_i)

        a_i_m = i_X_λi[i] @ a_m[λ[i]] + m_S[i] * qdd[ii] + Cross.vx(v_m[i]) @ vJ_m
        a_m = a_m.at[i].set(a_i_m)

        i_X_0_i = i_X_λi[i] @ i_X_0[λ[i]]
        i_X_0 = i_X_0.at[i].set(i_X_0_i)
        # Maps world-frame external forces to the frame of body i.
        i_Xf_W = Adjoint.inverse(i_X_0[i] @ B_X_W).T

        f_i = (
            M[i] @ a[i]
            + Cross.vx_star(v[i]) @ M[i] @ v[i]
            - i_Xf_W @ jnp.vstack(f_ext[i])
        )
        f = f.at[i].set(f_i)

        # IM is joint-indexed: use ii (not i), consistently with qd/qdd/Γ.
        # Indexing with i was off by one and, for the last body, read past the
        # end of IM (silently clamped by JAX).
        f_i_m = IM[ii] * a_m[i] + Cross.vx_star(v_m[i]) * IM[ii] @ v_m[i]
        f_m = f_m.at[i].set(f_i_m)

        return (i_X_λi, v, v_m, a, a_m, i_X_0, f, f_m), None

    (i_X_λi, v, v_m, a, a_m, i_X_0, f, f_m), _ = jax.lax.scan(
        f=forward_pass,
        init=forward_pass_carry,
        xs=np.arange(start=1, stop=model.NB),
    )

    tau = jnp.zeros_like(q)

    BackwardPassCarry = Tuple[jtp.MatrixJax, jtp.MatrixJax, jtp.MatrixJax]
    backward_pass_carry = (tau, f, f_m)

    def backward_pass(
        carry: BackwardPassCarry, i: jtp.Int
    ) -> Tuple[BackwardPassCarry, None]:
        ii = i - 1
        tau, f, f_m = carry

        value = S[i].T @ f[i] + m_S[i].T @ f_m[i]
        tau = tau.at[ii].set(value.squeeze())

        def update_f(
            ffm: Tuple[jtp.MatrixJax, jtp.MatrixJax],
        ) -> Tuple[jtp.MatrixJax, jtp.MatrixJax]:
            f, f_m = ffm
            f_λi = f[λ[i]] + i_X_λi[i].T @ f[i]
            f = f.at[λ[i]].set(f_λi)

            f_m_λi = f_m[λ[i]] + i_X_λi[i].T @ f_m[i]
            f_m = f_m.at[λ[i]].set(f_m_λi)
            return f, f_m

        # For fixed-base models, forces of bodies attached to the base are not
        # propagated to it.
        f, f_m = jax.lax.cond(
            pred=jnp.array([λ[i] != 0, model.is_floating_base]).any(),
            true_fun=update_f,
            false_fun=lambda f: f,
            operand=(f, f_m),
        )

        return (tau, f, f_m), None

    (tau, f, f_m), _ = jax.lax.scan(
        f=backward_pass,
        init=backward_pass_carry,
        xs=np.flip(np.arange(start=1, stop=model.NB)),
    )

    # Handle 0-DoF and 1-DoF models: always return a column vector.
    tau = jnp.atleast_1d(tau.squeeze())
    tau = jnp.vstack(tau) if tau.size > 0 else jnp.empty(shape=(0, 1))

    # Express the base 6D force in the world frame.
    W_f0 = B_X_W.T @ jnp.vstack(f[0])

    return W_f0, tau
|