jaxsim 0.2.dev191__py3-none-any.whl → 0.2.dev366__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jaxsim/__init__.py +3 -4
- jaxsim/_version.py +2 -2
- jaxsim/api/__init__.py +3 -1
- jaxsim/api/com.py +240 -0
- jaxsim/api/common.py +13 -2
- jaxsim/api/contact.py +120 -43
- jaxsim/api/data.py +112 -71
- jaxsim/api/joint.py +77 -36
- jaxsim/api/kin_dyn_parameters.py +777 -0
- jaxsim/api/link.py +150 -75
- jaxsim/api/model.py +542 -269
- jaxsim/api/ode.py +86 -74
- jaxsim/api/ode_data.py +694 -0
- jaxsim/api/references.py +12 -11
- jaxsim/integrators/__init__.py +2 -2
- jaxsim/integrators/common.py +110 -24
- jaxsim/integrators/fixed_step.py +11 -67
- jaxsim/integrators/variable_step.py +610 -0
- jaxsim/math/__init__.py +11 -0
- jaxsim/math/adjoint.py +24 -2
- jaxsim/math/joint_model.py +335 -0
- jaxsim/math/quaternion.py +44 -3
- jaxsim/math/rotation.py +4 -4
- jaxsim/math/transform.py +93 -0
- jaxsim/parsers/descriptions/link.py +2 -2
- jaxsim/parsers/rod/utils.py +7 -8
- jaxsim/rbda/__init__.py +7 -0
- jaxsim/rbda/aba.py +295 -0
- jaxsim/rbda/collidable_points.py +142 -0
- jaxsim/{physics/algos → rbda}/crba.py +43 -42
- jaxsim/rbda/forward_kinematics.py +113 -0
- jaxsim/rbda/jacobian.py +201 -0
- jaxsim/rbda/rnea.py +237 -0
- jaxsim/rbda/soft_contacts.py +296 -0
- jaxsim/rbda/utils.py +152 -0
- jaxsim/terrain/__init__.py +2 -0
- jaxsim/utils/__init__.py +1 -4
- jaxsim/utils/hashless.py +18 -0
- jaxsim/utils/jaxsim_dataclass.py +281 -30
- {jaxsim-0.2.dev191.dist-info → jaxsim-0.2.dev366.dist-info}/METADATA +4 -6
- jaxsim-0.2.dev366.dist-info/RECORD +64 -0
- {jaxsim-0.2.dev191.dist-info → jaxsim-0.2.dev366.dist-info}/WHEEL +1 -1
- jaxsim/high_level/__init__.py +0 -2
- jaxsim/high_level/common.py +0 -11
- jaxsim/high_level/joint.py +0 -148
- jaxsim/high_level/link.py +0 -259
- jaxsim/high_level/model.py +0 -1686
- jaxsim/math/conv.py +0 -114
- jaxsim/math/joint.py +0 -102
- jaxsim/math/plucker.py +0 -100
- jaxsim/physics/__init__.py +0 -12
- jaxsim/physics/algos/__init__.py +0 -0
- jaxsim/physics/algos/aba.py +0 -254
- jaxsim/physics/algos/aba_motors.py +0 -284
- jaxsim/physics/algos/forward_kinematics.py +0 -79
- jaxsim/physics/algos/jacobian.py +0 -98
- jaxsim/physics/algos/rnea.py +0 -180
- jaxsim/physics/algos/rnea_motors.py +0 -196
- jaxsim/physics/algos/soft_contacts.py +0 -523
- jaxsim/physics/algos/utils.py +0 -69
- jaxsim/physics/model/__init__.py +0 -0
- jaxsim/physics/model/ground_contact.py +0 -53
- jaxsim/physics/model/physics_model.py +0 -388
- jaxsim/physics/model/physics_model_state.py +0 -283
- jaxsim/simulation/__init__.py +0 -4
- jaxsim/simulation/integrators.py +0 -393
- jaxsim/simulation/ode.py +0 -290
- jaxsim/simulation/ode_data.py +0 -96
- jaxsim/simulation/ode_integration.py +0 -62
- jaxsim/simulation/simulator.py +0 -543
- jaxsim/simulation/simulator_callbacks.py +0 -79
- jaxsim/simulation/utils.py +0 -15
- jaxsim/sixd/__init__.py +0 -2
- jaxsim/utils/oop.py +0 -536
- jaxsim/utils/vmappable.py +0 -117
- jaxsim-0.2.dev191.dist-info/RECORD +0 -81
- /jaxsim/{physics/algos → terrain}/terrain.py +0 -0
- {jaxsim-0.2.dev191.dist-info → jaxsim-0.2.dev366.dist-info}/LICENSE +0 -0
- {jaxsim-0.2.dev191.dist-info → jaxsim-0.2.dev366.dist-info}/top_level.txt +0 -0
@@ -1,196 +0,0 @@
|
|
1
|
-
from typing import Tuple
|
2
|
-
|
3
|
-
import jax
|
4
|
-
import jax.numpy as jnp
|
5
|
-
import numpy as np
|
6
|
-
|
7
|
-
import jaxsim.typing as jtp
|
8
|
-
from jaxsim.math.adjoint import Adjoint
|
9
|
-
from jaxsim.math.cross import Cross
|
10
|
-
from jaxsim.physics.model.physics_model import PhysicsModel
|
11
|
-
|
12
|
-
from . import utils
|
13
|
-
|
14
|
-
|
15
|
-
def rnea(
    model: PhysicsModel,
    xfb: jtp.Vector,
    q: jtp.Vector,
    qd: jtp.Vector,
    qdd: jtp.Vector,
    a0fb: jtp.Vector = jnp.zeros(6),
    f_ext: jtp.Matrix | None = None,
) -> Tuple[jtp.Vector, jtp.Vector]:
    """
    Recursive Newton-Euler Algorithm (RNEA) for inverse dynamics, including the
    contribution of joint motors (gear ratio and rotor inertia).

    Args:
        model: The physics model. Besides its kinematic and inertial data, the
            per-joint motor dictionaries ``_joint_motor_gear_ratio`` and
            ``_joint_motor_inertia`` are read.
        xfb: The base state. ``xfb[0:4]`` is the base quaternion and ``xfb[4:7]``
            the base position. For floating-base models the 6D base velocity is
            assembled as ``[xfb[10:13], xfb[7:10]]``.
        q: The joint positions.
        qd: The joint velocities.
        qdd: The joint accelerations.
        a0fb: The 6D base acceleration. It is combined with gravity and then
            transformed to the base frame. Only used for floating-base models.
        f_ext: The external 6D forces applied to each link, expressed in the
            world frame (they are mapped to each link frame below). ``None`` is
            forwarded to ``utils.process_inputs``, presumably meaning "no
            external forces" — confirm there.

    Returns:
        A tuple ``(W_f0, tau)``:
        ``W_f0`` is the 6D force acting on the base, expressed in the world frame
        as a column vector. ``tau`` is the column vector of joint torques, with
        shape ``(0, 1)`` for models without joints.

    Raises:
        ValueError: If ``a0fb`` does not have 6 elements.
    """

    xfb, q, qd, qdd, _, f_ext = utils.process_inputs(
        physics_model=model, xfb=xfb, q=q, qd=qd, qdd=qdd, f_ext=f_ext
    )

    a0fb = a0fb.squeeze()
    gravity = model.gravity.squeeze()

    if a0fb.shape[0] != 6:
        raise ValueError(a0fb.shape)

    M = model.spatial_inertias
    pre_X_λi = model.tree_transforms
    i_X_pre = model.joint_transforms(q=q)
    S = model.motion_subspaces(q=q)
    i_X_λi = jnp.zeros_like(pre_X_λi)

    # Per-joint motor parameters. Both arrays hold one entry per joint, so link i
    # maps to entry ii = i - 1 (the base, link 0, has no motor).
    Γ = jnp.array([*model._joint_motor_gear_ratio.values()])
    IM = jnp.array([*model._joint_motor_inertia.values()])

    # Motor-side motion subspaces: the joint subspace scaled by the gear ratio.
    # The base subspace S[0] is left untouched.
    m_S = jnp.concatenate([S[:1], S[1:] * Γ[:, None, None]], axis=0)

    i_X_0 = jnp.zeros_like(pre_X_λi)
    i_X_0 = i_X_0.at[0].set(jnp.eye(6))

    # Parent array mapping: i -> λ(i).
    # Exception: λ(0) must not be used, it's initialized to -1.
    λ = model.parent_array()

    # Link-side and motor-side 6D velocities, accelerations and forces.
    v = jnp.array([jnp.zeros([6, 1])] * model.NB)
    a = jnp.array([jnp.zeros([6, 1])] * model.NB)
    f = jnp.array([jnp.zeros([6, 1])] * model.NB)

    v_m = jnp.array([jnp.zeros([6, 1])] * model.NB)
    a_m = jnp.array([jnp.zeros([6, 1])] * model.NB)
    f_m = jnp.array([jnp.zeros([6, 1])] * model.NB)

    # 6D transform of base velocity
    B_X_W = Adjoint.from_quaternion_and_translation(
        quaternion=xfb[0:4],
        translation=xfb[4:7],
        inverse=True,
        normalize_quaternion=True,
    )
    i_X_λi = i_X_λi.at[0].set(B_X_W)

    # Gravity is accounted for as a fictitious upward acceleration of the base.
    a_0 = -B_X_W @ jnp.vstack(gravity)
    a = a.at[0].set(a_0)

    if model.is_floating_base:
        W_v_WB = jnp.vstack(jnp.hstack([xfb[10:13], xfb[7:10]]))

        v_0 = B_X_W @ W_v_WB
        v = v.at[0].set(v_0)

        a_0 = B_X_W @ (jnp.vstack(a0fb) - jnp.vstack(gravity))
        a = a.at[0].set(a_0)

        f_0 = (
            M[0] @ a[0]
            + Cross.vx_star(v[0]) @ M[0] @ v[0]
            - Adjoint.inverse(B_X_W).T @ jnp.vstack(f_ext[0])
        )
        f = f.at[0].set(f_0)

    ForwardPassCarry = Tuple[
        jtp.MatrixJax,
        jtp.MatrixJax,
        jtp.MatrixJax,
        jtp.MatrixJax,
        jtp.MatrixJax,
        jtp.MatrixJax,
        jtp.MatrixJax,
        jtp.MatrixJax,
    ]
    forward_pass_carry = (i_X_λi, v, v_m, a, a_m, i_X_0, f, f_m)

    def forward_pass(
        carry: ForwardPassCarry, i: jtp.Int
    ) -> Tuple[ForwardPassCarry, None]:
        # Link i is driven by joint ii.
        ii = i - 1
        i_X_λi, v, v_m, a, a_m, i_X_0, f, f_m = carry

        vJ = S[i] * qd[ii]
        vJ_m = m_S[i] * qd[ii]

        i_X_λi_i = i_X_pre[i] @ pre_X_λi[i]
        i_X_λi = i_X_λi.at[i].set(i_X_λi_i)

        v_i = i_X_λi[i] @ v[λ[i]] + vJ
        v = v.at[i].set(v_i)

        v_i_m = i_X_λi[i] @ v_m[λ[i]] + vJ_m
        v_m = v_m.at[i].set(v_i_m)

        a_i = i_X_λi[i] @ a[λ[i]] + S[i] * qdd[ii] + Cross.vx(v[i]) @ vJ
        a = a.at[i].set(a_i)

        a_i_m = i_X_λi[i] @ a_m[λ[i]] + m_S[i] * qdd[ii] + Cross.vx(v_m[i]) @ vJ_m
        a_m = a_m.at[i].set(a_i_m)

        i_X_0_i = i_X_λi[i] @ i_X_0[λ[i]]
        i_X_0 = i_X_0.at[i].set(i_X_0_i)

        # Force transform from the world frame to the frame of link i.
        i_Xf_W = Adjoint.inverse(i_X_0[i] @ B_X_W).T

        f_i = (
            M[i] @ a[i]
            + Cross.vx_star(v[i]) @ M[i] @ v[i]
            - i_Xf_W @ jnp.vstack(f_ext[i])
        )
        f = f.at[i].set(f_i)

        # Motor-side force. The rotor inertia is a per-joint scalar applied
        # isotropically. IM has one entry per joint (like Γ), so it must be
        # indexed with the joint index ii, not the link index i. Using i is
        # off by one, and JAX silently clamps the out-of-range last index.
        f_i_m = IM[ii] * a_m[i] + Cross.vx_star(v_m[i]) * IM[ii] @ v_m[i]
        f_m = f_m.at[i].set(f_i_m)

        return (i_X_λi, v, v_m, a, a_m, i_X_0, f, f_m), None

    (i_X_λi, v, v_m, a, a_m, i_X_0, f, f_m), _ = jax.lax.scan(
        f=forward_pass,
        init=forward_pass_carry,
        xs=np.arange(start=1, stop=model.NB),
    )

    tau = jnp.zeros_like(q)

    BackwardPassCarry = Tuple[jtp.MatrixJax, jtp.MatrixJax, jtp.MatrixJax]
    backward_pass_carry = (tau, f, f_m)

    def backward_pass(
        carry: BackwardPassCarry, i: jtp.Int
    ) -> Tuple[BackwardPassCarry, None]:
        ii = i - 1
        tau, f, f_m = carry

        # Joint torque = link-side projection + motor-side projection.
        # NOTE: motor viscous friction is not included in this model.
        value = S[i].T @ f[i] + m_S[i].T @ f_m[i]
        tau = tau.at[ii].set(value.squeeze())

        def update_f(ffm: Tuple[jtp.MatrixJax, jtp.MatrixJax]) -> jtp.MatrixJax:
            f, f_m = ffm
            f_λi = f[λ[i]] + i_X_λi[i].T @ f[i]
            f = f.at[λ[i]].set(f_λi)

            f_m_λi = f_m[λ[i]] + i_X_λi[i].T @ f_m[i]
            f_m = f_m.at[λ[i]].set(f_m_λi)
            return f, f_m

        # Propagate forces to the parent. The base (λ[i] == 0) only accumulates
        # forces when it is floating.
        f, f_m = jax.lax.cond(
            pred=jnp.array([λ[i] != 0, model.is_floating_base]).any(),
            true_fun=update_f,
            false_fun=lambda ffm: ffm,
            operand=(f, f_m),
        )

        return (tau, f, f_m), None

    (tau, f, f_m), _ = jax.lax.scan(
        f=backward_pass,
        init=backward_pass_carry,
        xs=np.flip(np.arange(start=1, stop=model.NB)),
    )

    # Handle 1 DoF models
    tau = jnp.atleast_1d(tau.squeeze())
    tau = jnp.vstack(tau) if tau.size > 0 else jnp.empty(shape=(0, 1))

    # Express the base 6D force in the world frame
    W_f0 = B_X_W.T @ jnp.vstack(f[0])

    return W_f0, tau
|