jaxsim 0.3.1.dev62__py3-none-any.whl → 0.3.1.dev94__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jaxsim/__init__.py +5 -5
- jaxsim/_version.py +2 -2
- jaxsim/api/com.py +3 -4
- jaxsim/api/common.py +11 -11
- jaxsim/api/contact.py +11 -3
- jaxsim/api/data.py +3 -6
- jaxsim/api/frame.py +9 -10
- jaxsim/api/kin_dyn_parameters.py +25 -28
- jaxsim/api/link.py +12 -12
- jaxsim/api/model.py +47 -43
- jaxsim/api/ode.py +19 -12
- jaxsim/api/ode_data.py +11 -11
- jaxsim/integrators/common.py +19 -29
- jaxsim/integrators/fixed_step.py +10 -10
- jaxsim/integrators/variable_step.py +13 -13
- jaxsim/math/__init__.py +2 -1
- jaxsim/math/joint_model.py +2 -1
- jaxsim/math/quaternion.py +3 -9
- jaxsim/math/transform.py +2 -2
- jaxsim/mujoco/loaders.py +5 -5
- jaxsim/mujoco/model.py +6 -6
- jaxsim/mujoco/visualizer.py +3 -0
- jaxsim/parsers/__init__.py +0 -1
- jaxsim/parsers/descriptions/joint.py +1 -1
- jaxsim/parsers/descriptions/link.py +3 -4
- jaxsim/parsers/descriptions/model.py +1 -1
- jaxsim/parsers/kinematic_graph.py +38 -39
- jaxsim/parsers/rod/parser.py +14 -14
- jaxsim/parsers/rod/utils.py +9 -11
- jaxsim/rbda/aba.py +6 -12
- jaxsim/rbda/collidable_points.py +8 -7
- jaxsim/rbda/contacts/soft.py +29 -27
- jaxsim/rbda/crba.py +3 -3
- jaxsim/rbda/forward_kinematics.py +1 -1
- jaxsim/rbda/jacobian.py +8 -8
- jaxsim/rbda/rnea.py +3 -3
- jaxsim/rbda/utils.py +1 -1
- jaxsim/terrain/terrain.py +100 -22
- jaxsim/typing.py +14 -22
- jaxsim/utils/jaxsim_dataclass.py +4 -4
- jaxsim/utils/wrappers.py +5 -1
- {jaxsim-0.3.1.dev62.dist-info → jaxsim-0.3.1.dev94.dist-info}/METADATA +1 -1
- jaxsim-0.3.1.dev94.dist-info/RECORD +68 -0
- {jaxsim-0.3.1.dev62.dist-info → jaxsim-0.3.1.dev94.dist-info}/WHEEL +1 -1
- jaxsim-0.3.1.dev62.dist-info/RECORD +0 -68
- {jaxsim-0.3.1.dev62.dist-info → jaxsim-0.3.1.dev94.dist-info}/LICENSE +0 -0
- {jaxsim-0.3.1.dev62.dist-info → jaxsim-0.3.1.dev94.dist-info}/top_level.txt +0 -0
jaxsim/parsers/rod/utils.py
CHANGED
@@ -1,12 +1,11 @@
|
|
1
1
|
import os
|
2
2
|
|
3
|
-
import jaxlie
|
4
3
|
import numpy as np
|
5
4
|
import numpy.typing as npt
|
6
5
|
import rod
|
7
6
|
|
8
7
|
import jaxsim.typing as jtp
|
9
|
-
from jaxsim.math import Inertia
|
8
|
+
from jaxsim.math import Adjoint, Inertia
|
10
9
|
from jaxsim.parsers import descriptions
|
11
10
|
|
12
11
|
|
@@ -21,10 +20,10 @@ def from_sdf_inertial(inertial: rod.Inertial) -> jtp.Matrix:
|
|
21
20
|
The 6D inertia matrix of the link expressed in the link frame.
|
22
21
|
"""
|
23
22
|
|
24
|
-
# Extract the "mass" element
|
23
|
+
# Extract the "mass" element.
|
25
24
|
m = inertial.mass
|
26
25
|
|
27
|
-
# Extract the "inertia" element
|
26
|
+
# Extract the "inertia" element.
|
28
27
|
inertia_element = inertial.inertia
|
29
28
|
|
30
29
|
ixx = inertia_element.ixx
|
@@ -34,7 +33,7 @@ def from_sdf_inertial(inertial: rod.Inertial) -> jtp.Matrix:
|
|
34
33
|
ixz = inertia_element.ixz if inertia_element.ixz is not None else 0.0
|
35
34
|
iyz = inertia_element.iyz if inertia_element.iyz is not None else 0.0
|
36
35
|
|
37
|
-
# Build the 3x3 inertia matrix expressed in the CoM
|
36
|
+
# Build the 3x3 inertia matrix expressed in the CoM.
|
38
37
|
I_CoM = np.array(
|
39
38
|
[
|
40
39
|
[ixx, ixy, ixz],
|
@@ -43,17 +42,16 @@ def from_sdf_inertial(inertial: rod.Inertial) -> jtp.Matrix:
|
|
43
42
|
]
|
44
43
|
)
|
45
44
|
|
46
|
-
# Build the 6x6 generalized inertia at the CoM
|
45
|
+
# Build the 6x6 generalized inertia at the CoM.
|
47
46
|
M_CoM = Inertia.to_sixd(mass=m, com=np.zeros(3), I=I_CoM)
|
48
47
|
|
49
|
-
# Compute the transform from the inertial frame (CoM) to the link frame
|
48
|
+
# Compute the transform from the inertial frame (CoM) to the link frame.
|
50
49
|
L_H_CoM = inertial.pose.transform() if inertial.pose is not None else np.eye(4)
|
51
50
|
|
52
|
-
# We need its inverse
|
53
|
-
|
54
|
-
CoM_X_L = CoM_H_L.adjoint()
|
51
|
+
# We need its inverse.
|
52
|
+
CoM_X_L = Adjoint.from_transform(transform=L_H_CoM, inverse=True)
|
55
53
|
|
56
|
-
# Express the CoM inertia matrix in the link frame L
|
54
|
+
# Express the CoM inertia matrix in the link frame L.
|
57
55
|
M_L = CoM_X_L.T @ M_CoM @ CoM_X_L
|
58
56
|
|
59
57
|
return M_L.astype(dtype=float)
|
jaxsim/rbda/aba.py
CHANGED
@@ -102,7 +102,7 @@ def aba(
|
|
102
102
|
i_X_0 = jnp.zeros(shape=(model.number_of_links(), 6, 6))
|
103
103
|
i_X_0 = i_X_0.at[0].set(jnp.eye(6))
|
104
104
|
|
105
|
-
# Initialize base quantities
|
105
|
+
# Initialize base quantities.
|
106
106
|
if model.floating_base():
|
107
107
|
|
108
108
|
# Base velocity v₀ in body-fixed representation.
|
@@ -121,10 +121,7 @@ def aba(
|
|
121
121
|
# Pass 1
|
122
122
|
# ======
|
123
123
|
|
124
|
-
Pass1Carry = tuple[
|
125
|
-
jtp.MatrixJax, jtp.MatrixJax, jtp.MatrixJax, jtp.MatrixJax, jtp.MatrixJax
|
126
|
-
]
|
127
|
-
|
124
|
+
Pass1Carry = tuple[jtp.Matrix, jtp.Matrix, jtp.Matrix, jtp.Matrix, jtp.Matrix]
|
128
125
|
pass_1_carry: Pass1Carry = (v, c, MA, pA, i_X_0)
|
129
126
|
|
130
127
|
# Propagate kinematics and initialize AB inertia and AB bias forces.
|
@@ -178,10 +175,7 @@ def aba(
|
|
178
175
|
d = jnp.zeros(shape=(model.number_of_links(), 1))
|
179
176
|
u = jnp.zeros(shape=(model.number_of_links(), 1))
|
180
177
|
|
181
|
-
Pass2Carry = tuple[
|
182
|
-
jtp.MatrixJax, jtp.MatrixJax, jtp.MatrixJax, jtp.MatrixJax, jtp.MatrixJax
|
183
|
-
]
|
184
|
-
|
178
|
+
Pass2Carry = tuple[jtp.Matrix, jtp.Matrix, jtp.Matrix, jtp.Matrix, jtp.Matrix]
|
185
179
|
pass_2_carry: Pass2Carry = (U, d, u, MA, pA)
|
186
180
|
|
187
181
|
def loop_body_pass2(carry: Pass2Carry, i: jtp.Int) -> tuple[Pass2Carry, None]:
|
@@ -204,8 +198,8 @@ def aba(
|
|
204
198
|
|
205
199
|
# Propagate them to the parent, handling the base link.
|
206
200
|
def propagate(
|
207
|
-
MA_pA: tuple[jtp.
|
208
|
-
) -> tuple[jtp.
|
201
|
+
MA_pA: tuple[jtp.Matrix, jtp.Matrix]
|
202
|
+
) -> tuple[jtp.Matrix, jtp.Matrix]:
|
209
203
|
|
210
204
|
MA, pA = MA_pA
|
211
205
|
|
@@ -248,7 +242,7 @@ def aba(
|
|
248
242
|
s̈ = jnp.zeros_like(s)
|
249
243
|
a = jnp.zeros_like(v).at[0].set(a0)
|
250
244
|
|
251
|
-
Pass3Carry = tuple[jtp.
|
245
|
+
Pass3Carry = tuple[jtp.Matrix, jtp.Vector]
|
252
246
|
pass_3_carry = (a, s̈)
|
253
247
|
|
254
248
|
def loop_body_pass3(carry: Pass3Carry, i: jtp.Int) -> tuple[Pass3Carry, None]:
|
jaxsim/rbda/collidable_points.py
CHANGED
@@ -80,7 +80,7 @@ def collidable_points_pos_vel(
|
|
80
80
|
# Propagate kinematics
|
81
81
|
# ====================
|
82
82
|
|
83
|
-
PropagateTransformsCarry = tuple[jtp.
|
83
|
+
PropagateTransformsCarry = tuple[jtp.Matrix, jtp.Matrix]
|
84
84
|
propagate_transforms_carry: PropagateTransformsCarry = (W_X_i, W_v_Wi)
|
85
85
|
|
86
86
|
def propagate_kinematics(
|
@@ -97,7 +97,7 @@ def collidable_points_pos_vel(
|
|
97
97
|
W_Xi_i = W_X_i[λ[i]] @ λi_X_i
|
98
98
|
W_X_i = W_X_i.at[i].set(W_Xi_i)
|
99
99
|
|
100
|
-
# Propagate the 6D velocity
|
100
|
+
# Propagate the 6D velocity.
|
101
101
|
W_vi_Wi = W_v_Wi[λ[i]] + W_X_i[i] @ (S[i] * ṡ[ii]).squeeze()
|
102
102
|
W_v_Wi = W_v_Wi.at[i].set(W_vi_Wi)
|
103
103
|
|
@@ -118,14 +118,15 @@ def collidable_points_pos_vel(
|
|
118
118
|
# ==================================================
|
119
119
|
|
120
120
|
def process_point_kinematics(
|
121
|
-
Li_p_C: jtp.
|
122
|
-
) -> tuple[jtp.
|
123
|
-
|
121
|
+
Li_p_C: jtp.Vector, parent_body: jtp.Int
|
122
|
+
) -> tuple[jtp.Vector, jtp.Vector]:
|
123
|
+
|
124
|
+
# Compute the position of the collidable point.
|
124
125
|
W_p_Ci = (
|
125
126
|
Adjoint.to_transform(adjoint=W_X_i[parent_body]) @ jnp.hstack([Li_p_C, 1])
|
126
127
|
)[0:3]
|
127
128
|
|
128
|
-
# Compute the linear part of the mixed velocity Ci[W]_v_{W,Ci}
|
129
|
+
# Compute the linear part of the mixed velocity Ci[W]_v_{W,Ci}.
|
129
130
|
CW_vl_WCi = (
|
130
131
|
jnp.block([jnp.eye(3), -Skew.wedge(vector=W_p_Ci).squeeze()])
|
131
132
|
@ W_v_Wi[parent_body].squeeze()
|
@@ -133,7 +134,7 @@ def collidable_points_pos_vel(
|
|
133
134
|
|
134
135
|
return W_p_Ci, CW_vl_WCi
|
135
136
|
|
136
|
-
# Process all the collidable points in parallel
|
137
|
+
# Process all the collidable points in parallel.
|
137
138
|
W_p_Ci, CW_vl_WC = jax.vmap(process_point_kinematics)(
|
138
139
|
model.kin_dyn_parameters.contact_parameters.point,
|
139
140
|
jnp.array(model.kin_dyn_parameters.contact_parameters.body),
|
jaxsim/rbda/contacts/soft.py
CHANGED
@@ -105,24 +105,24 @@ class SoftContactsParams(ContactsParams):
|
|
105
105
|
- ξ < 1.0: under-damped
|
106
106
|
"""
|
107
107
|
|
108
|
-
# Use symbols for input parameters
|
108
|
+
# Use symbols for input parameters.
|
109
109
|
ξ = damping_ratio
|
110
110
|
δ_max = max_penetration
|
111
111
|
μc = static_friction_coefficient
|
112
112
|
|
113
|
-
# Compute the total mass of the model
|
113
|
+
# Compute the total mass of the model.
|
114
114
|
m = jnp.array(model.kin_dyn_parameters.link_parameters.mass).sum()
|
115
115
|
|
116
|
-
# Rename the standard gravity
|
116
|
+
# Rename the standard gravity.
|
117
117
|
g = standard_gravity
|
118
118
|
|
119
|
-
# Compute the average support force on each collidable point
|
119
|
+
# Compute the average support force on each collidable point.
|
120
120
|
f_average = m * g / number_of_active_collidable_points_steady_state
|
121
121
|
|
122
|
-
# Compute the stiffness to get the desired steady-state penetration
|
122
|
+
# Compute the stiffness to get the desired steady-state penetration.
|
123
123
|
K = f_average / jnp.power(δ_max, 3 / 2)
|
124
124
|
|
125
|
-
# Compute the damping using the damping ratio
|
125
|
+
# Compute the damping using the damping ratio.
|
126
126
|
critical_damping = 2 * jnp.sqrt(K * m)
|
127
127
|
D = ξ * critical_damping
|
128
128
|
|
@@ -151,14 +151,16 @@ class SoftContacts(ContactModel):
|
|
151
151
|
default_factory=SoftContactsParams
|
152
152
|
)
|
153
153
|
|
154
|
-
terrain: Terrain = dataclasses.field(
|
154
|
+
terrain: jax_dataclasses.Static[Terrain] = dataclasses.field(
|
155
|
+
default_factory=FlatTerrain
|
156
|
+
)
|
155
157
|
|
156
158
|
def compute_contact_forces(
|
157
159
|
self,
|
158
160
|
position: jtp.Vector,
|
159
161
|
velocity: jtp.Vector,
|
160
162
|
tangential_deformation: jtp.Vector,
|
161
|
-
) -> tuple[jtp.Vector, tuple[jtp.Vector
|
163
|
+
) -> tuple[jtp.Vector, tuple[jtp.Vector]]:
|
162
164
|
"""
|
163
165
|
Compute the contact forces and material deformation rate.
|
164
166
|
|
@@ -188,18 +190,18 @@ class SoftContacts(ContactModel):
|
|
188
190
|
# Normal force computation
|
189
191
|
# ========================
|
190
192
|
|
191
|
-
# Unpack the position of the collidable point
|
193
|
+
# Unpack the position of the collidable point.
|
192
194
|
px, py, pz = W_p_C = position.squeeze()
|
193
195
|
vx, vy, vz = W_ṗ_C = velocity.squeeze()
|
194
196
|
|
195
|
-
# Compute the terrain normal and the contact depth
|
197
|
+
# Compute the terrain normal and the contact depth.
|
196
198
|
n̂ = self.terrain.normal(x=px, y=py).squeeze()
|
197
199
|
h = jnp.array([0, 0, self.terrain.height(x=px, y=py) - pz])
|
198
200
|
|
199
|
-
# Compute the penetration depth normal to the terrain
|
201
|
+
# Compute the penetration depth normal to the terrain.
|
200
202
|
δ = jnp.maximum(0.0, jnp.dot(h, n̂))
|
201
203
|
|
202
|
-
# Compute the penetration normal velocity
|
204
|
+
# Compute the penetration normal velocity.
|
203
205
|
δ̇ = -jnp.dot(W_ṗ_C, n̂)
|
204
206
|
|
205
207
|
# Non-linear spring-damper model.
|
@@ -210,10 +212,10 @@ class SoftContacts(ContactModel):
|
|
210
212
|
on_false=jnp.array(0.0),
|
211
213
|
)
|
212
214
|
|
213
|
-
# Prevent negative normal forces that might occur when δ̇ is largely negative
|
215
|
+
# Prevent negative normal forces that might occur when δ̇ is largely negative.
|
214
216
|
force_normal_mag = jnp.maximum(0.0, force_normal_mag)
|
215
217
|
|
216
|
-
# Compute the 3D linear force in C[W] frame
|
218
|
+
# Compute the 3D linear force in C[W] frame.
|
217
219
|
force_normal = force_normal_mag * n̂
|
218
220
|
|
219
221
|
# ====================================
|
@@ -230,11 +232,11 @@ class SoftContacts(ContactModel):
|
|
230
232
|
)
|
231
233
|
|
232
234
|
def with_no_friction():
|
233
|
-
# Compute 6D mixed force in C[W]
|
235
|
+
# Compute 6D mixed force in C[W].
|
234
236
|
CW_f_lin = force_normal
|
235
237
|
CW_f = jnp.hstack([force_normal, jnp.zeros_like(CW_f_lin)])
|
236
238
|
|
237
|
-
# Compute lin-ang 6D forces (inertial representation)
|
239
|
+
# Compute lin-ang 6D forces (inertial representation).
|
238
240
|
W_f = W_Xf_CW @ CW_f
|
239
241
|
|
240
242
|
return W_f, (ṁ,)
|
@@ -258,32 +260,32 @@ class SoftContacts(ContactModel):
|
|
258
260
|
return jnp.zeros(6), (ṁ,)
|
259
261
|
|
260
262
|
def below_terrain():
|
261
|
-
# Decompose the velocity in normal and tangential components
|
263
|
+
# Decompose the velocity in normal and tangential components.
|
262
264
|
v_normal = jnp.dot(W_ṗ_C, n̂) * n̂
|
263
265
|
v_tangential = W_ṗ_C - v_normal
|
264
266
|
|
265
|
-
# Compute the tangential force. If inside the friction cone, the contact
|
267
|
+
# Compute the tangential force. If inside the friction cone, the contact.
|
266
268
|
f_tangential = -jnp.sqrt(δ + 1e-12) * (K * m + D * v_tangential)
|
267
269
|
|
268
270
|
def sticking_contact():
|
269
|
-
# Sum the normal and tangential forces, and create the 6D force
|
271
|
+
# Sum the normal and tangential forces, and create the 6D force.
|
270
272
|
CW_f_stick = force_normal + f_tangential
|
271
273
|
CW_f = jnp.hstack([CW_f_stick, jnp.zeros(3)])
|
272
274
|
|
273
|
-
# In this case the 3D material deformation is the tangential velocity
|
275
|
+
# In this case the 3D material deformation is the tangential velocity.
|
274
276
|
ṁ = v_tangential
|
275
277
|
|
276
278
|
# Return the 6D force in the contact frame and
|
277
|
-
# the deformation derivative
|
279
|
+
# the deformation derivative.
|
278
280
|
return CW_f, ṁ
|
279
281
|
|
280
282
|
def slipping_contact():
|
281
|
-
# Project the force to the friction cone boundary
|
283
|
+
# Project the force to the friction cone boundary.
|
282
284
|
f_tangential_projected = (μ * force_normal_mag) * (
|
283
285
|
f_tangential / jnp.maximum(jnp.linalg.norm(f_tangential), 1e-9)
|
284
286
|
)
|
285
287
|
|
286
|
-
# Sum the normal and tangential forces, and create the 6D force
|
288
|
+
# Sum the normal and tangential forces, and create the 6D force.
|
287
289
|
CW_f_slip = force_normal + f_tangential_projected
|
288
290
|
CW_f = jnp.hstack([CW_f_slip, jnp.zeros(3)])
|
289
291
|
|
@@ -297,7 +299,7 @@ class SoftContacts(ContactModel):
|
|
297
299
|
ṁ = (f_tangential_projected - α * m) / β
|
298
300
|
|
299
301
|
# Return the 6D force in the contact frame and
|
300
|
-
# the deformation derivative
|
302
|
+
# the deformation derivative.
|
301
303
|
return CW_f, ṁ
|
302
304
|
|
303
305
|
CW_f, ṁ = jax.lax.cond(
|
@@ -307,10 +309,10 @@ class SoftContacts(ContactModel):
|
|
307
309
|
operand=None,
|
308
310
|
)
|
309
311
|
|
310
|
-
# Express the 6D force in the world frame
|
312
|
+
# Express the 6D force in the world frame.
|
311
313
|
W_f = W_Xf_CW @ CW_f
|
312
314
|
|
313
|
-
# Return the 6D force in the world frame and the deformation derivative
|
315
|
+
# Return the 6D force in the world frame and the deformation derivative.
|
314
316
|
return W_f, (ṁ,)
|
315
317
|
|
316
318
|
# (W_f, (ṁ,))
|
@@ -321,7 +323,7 @@ class SoftContacts(ContactModel):
|
|
321
323
|
operand=None,
|
322
324
|
)
|
323
325
|
|
324
|
-
# (W_f, m
|
326
|
+
# (W_f, (ṁ,))
|
325
327
|
return jax.lax.cond(
|
326
328
|
pred=(μ == 0.0),
|
327
329
|
true_fun=lambda _: with_no_friction(),
|
jaxsim/rbda/crba.py
CHANGED
@@ -45,7 +45,7 @@ def crba(model: js.model.JaxSimModel, *, joint_positions: jtp.Vector) -> jtp.Mat
|
|
45
45
|
# Propagate kinematics
|
46
46
|
# ====================
|
47
47
|
|
48
|
-
ForwardPassCarry = tuple[jtp.
|
48
|
+
ForwardPassCarry = tuple[jtp.Matrix]
|
49
49
|
forward_pass_carry: ForwardPassCarry = (i_X_0,)
|
50
50
|
|
51
51
|
def propagate_kinematics(
|
@@ -71,7 +71,7 @@ def crba(model: js.model.JaxSimModel, *, joint_positions: jtp.Vector) -> jtp.Mat
|
|
71
71
|
|
72
72
|
M = jnp.zeros(shape=(6 + model.dofs(), 6 + model.dofs()))
|
73
73
|
|
74
|
-
BackwardPassCarry = tuple[jtp.
|
74
|
+
BackwardPassCarry = tuple[jtp.Matrix, jtp.Matrix]
|
75
75
|
backward_pass_carry: BackwardPassCarry = (Mc, M)
|
76
76
|
|
77
77
|
def backward_pass(
|
@@ -90,7 +90,7 @@ def crba(model: js.model.JaxSimModel, *, joint_positions: jtp.Vector) -> jtp.Mat
|
|
90
90
|
|
91
91
|
j = i
|
92
92
|
|
93
|
-
CarryInnerFn = tuple[jtp.Int, jtp.
|
93
|
+
CarryInnerFn = tuple[jtp.Int, jtp.Matrix, jtp.Matrix]
|
94
94
|
carry_inner_fn = (j, Fi, M)
|
95
95
|
|
96
96
|
def while_loop_body(carry: CarryInnerFn) -> CarryInnerFn:
|
@@ -61,7 +61,7 @@ def forward_kinematics_model(
|
|
61
61
|
# Propagate the kinematics
|
62
62
|
# ========================
|
63
63
|
|
64
|
-
PropagateKinematicsCarry = tuple[jtp.
|
64
|
+
PropagateKinematicsCarry = tuple[jtp.Matrix]
|
65
65
|
propagate_kinematics_carry: PropagateKinematicsCarry = (W_X_i,)
|
66
66
|
|
67
67
|
def propagate_kinematics(
|
jaxsim/rbda/jacobian.py
CHANGED
@@ -50,7 +50,7 @@ def jacobian(
|
|
50
50
|
# Propagate kinematics
|
51
51
|
# ====================
|
52
52
|
|
53
|
-
PropagateKinematicsCarry = tuple[jtp.
|
53
|
+
PropagateKinematicsCarry = tuple[jtp.Matrix]
|
54
54
|
propagate_kinematics_carry: PropagateKinematicsCarry = (i_X_0,)
|
55
55
|
|
56
56
|
def propagate_kinematics(
|
@@ -86,9 +86,9 @@ def jacobian(
|
|
86
86
|
# Checking if j ∈ κ(i) is equivalent to: κ_bool(j) is True.
|
87
87
|
κ_bool = model.kin_dyn_parameters.support_body_array_bool[link_index]
|
88
88
|
|
89
|
-
def compute_jacobian(J: jtp.
|
89
|
+
def compute_jacobian(J: jtp.Matrix, i: jtp.Int) -> tuple[jtp.Matrix, None]:
|
90
90
|
|
91
|
-
def update_jacobian(J: jtp.
|
91
|
+
def update_jacobian(J: jtp.Matrix, i: jtp.Int) -> jtp.Matrix:
|
92
92
|
|
93
93
|
ii = i - 1
|
94
94
|
|
@@ -155,16 +155,16 @@ def jacobian_full_doubly_left(
|
|
155
155
|
B_X_i = jnp.zeros(shape=(model.number_of_links(), 6, 6))
|
156
156
|
B_X_i = B_X_i.at[0].set(jnp.eye(6))
|
157
157
|
|
158
|
-
#
|
159
|
-
# Compute doubly-left Jacobian
|
160
|
-
#
|
158
|
+
# =================================
|
159
|
+
# Compute doubly-left full Jacobian
|
160
|
+
# =================================
|
161
161
|
|
162
162
|
# Allocate the Jacobian matrix.
|
163
163
|
# The Jbb section of the doubly-left Jacobian is an identity matrix.
|
164
164
|
J = jnp.zeros(shape=(6, 6 + model.dofs()))
|
165
165
|
J = J.at[0:6, 0:6].set(jnp.eye(6))
|
166
166
|
|
167
|
-
ComputeFullJacobianCarry = tuple[jtp.
|
167
|
+
ComputeFullJacobianCarry = tuple[jtp.Matrix, jtp.Matrix]
|
168
168
|
compute_full_jacobian_carry: ComputeFullJacobianCarry = (B_X_i, J)
|
169
169
|
|
170
170
|
def compute_full_jacobian(
|
@@ -261,7 +261,7 @@ def jacobian_derivative_full_doubly_left(
|
|
261
261
|
J̇ = jnp.zeros(shape=(6, 6 + model.dofs()))
|
262
262
|
|
263
263
|
ComputeFullJacobianDerivativeCarry = tuple[
|
264
|
-
jtp.
|
264
|
+
jtp.Matrix, jtp.Matrix, jtp.Matrix, jtp.Matrix
|
265
265
|
]
|
266
266
|
|
267
267
|
compute_full_jacobian_derivative_carry: ComputeFullJacobianDerivativeCarry = (
|
jaxsim/rbda/rnea.py
CHANGED
@@ -132,7 +132,7 @@ def rnea(
|
|
132
132
|
# Pass 1
|
133
133
|
# ======
|
134
134
|
|
135
|
-
ForwardPassCarry = Tuple[jtp.
|
135
|
+
ForwardPassCarry = Tuple[jtp.Matrix, jtp.Matrix, jtp.Matrix, jtp.Matrix]
|
136
136
|
forward_pass_carry: ForwardPassCarry = (v, a, i_X_0, f)
|
137
137
|
|
138
138
|
def forward_pass(
|
@@ -186,7 +186,7 @@ def rnea(
|
|
186
186
|
|
187
187
|
τ = jnp.zeros_like(s)
|
188
188
|
|
189
|
-
BackwardPassCarry = Tuple[jtp.
|
189
|
+
BackwardPassCarry = Tuple[jtp.Vector, jtp.Matrix]
|
190
190
|
backward_pass_carry: BackwardPassCarry = (τ, f)
|
191
191
|
|
192
192
|
def backward_pass(
|
@@ -201,7 +201,7 @@ def rnea(
|
|
201
201
|
τ = τ.at[ii].set(τ_i.squeeze())
|
202
202
|
|
203
203
|
# Propagate the force to the parent link.
|
204
|
-
def update_f(f: jtp.
|
204
|
+
def update_f(f: jtp.Matrix) -> jtp.Matrix:
|
205
205
|
|
206
206
|
f_λi = f[λ[i]] + i_X_λi[i].T @ f[i]
|
207
207
|
f = f.at[λ[i]].set(f_λi)
|
jaxsim/rbda/utils.py
CHANGED
@@ -19,7 +19,7 @@ def process_inputs(
|
|
19
19
|
joint_accelerations: jtp.VectorLike | None = None,
|
20
20
|
joint_forces: jtp.VectorLike | None = None,
|
21
21
|
link_forces: jtp.MatrixLike | None = None,
|
22
|
-
standard_gravity: jtp.
|
22
|
+
standard_gravity: jtp.ScalarLike | None = None,
|
23
23
|
) -> tuple[
|
24
24
|
jtp.Vector,
|
25
25
|
jtp.Vector,
|
jaxsim/terrain/terrain.py
CHANGED
@@ -1,4 +1,7 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
1
3
|
import abc
|
4
|
+
import dataclasses
|
2
5
|
|
3
6
|
import jax.numpy as jnp
|
4
7
|
import jax_dataclasses
|
@@ -7,22 +10,23 @@ import jaxsim.typing as jtp
|
|
7
10
|
|
8
11
|
|
9
12
|
class Terrain(abc.ABC):
|
13
|
+
|
10
14
|
delta = 0.010
|
11
15
|
|
12
16
|
@abc.abstractmethod
|
13
|
-
def height(self, x:
|
17
|
+
def height(self, x: jtp.FloatLike, y: jtp.FloatLike) -> jtp.Float:
|
14
18
|
pass
|
15
19
|
|
16
|
-
def normal(self, x:
|
20
|
+
def normal(self, x: jtp.FloatLike, y: jtp.FloatLike) -> jtp.Vector:
|
17
21
|
"""
|
18
22
|
Compute the normal vector of the terrain at a specific (x, y) location.
|
19
23
|
|
20
24
|
Args:
|
21
|
-
x
|
22
|
-
y
|
25
|
+
x: The x-coordinate of the location.
|
26
|
+
y: The y-coordinate of the location.
|
23
27
|
|
24
28
|
Returns:
|
25
|
-
|
29
|
+
The normal vector of the terrain surface at the specified location.
|
26
30
|
"""
|
27
31
|
|
28
32
|
# https://stackoverflow.com/a/5282364
|
@@ -40,43 +44,117 @@ class Terrain(abc.ABC):
|
|
40
44
|
|
41
45
|
@jax_dataclasses.pytree_dataclass
|
42
46
|
class FlatTerrain(Terrain):
|
43
|
-
|
44
|
-
|
47
|
+
|
48
|
+
z: float = dataclasses.field(default=0.0, kw_only=True)
|
49
|
+
|
50
|
+
@staticmethod
|
51
|
+
def build(height: jtp.FloatLike) -> FlatTerrain:
|
52
|
+
|
53
|
+
return FlatTerrain(z=float(height))
|
54
|
+
|
55
|
+
def height(self, x: jtp.FloatLike, y: jtp.FloatLike) -> jtp.Float:
|
56
|
+
|
57
|
+
return jnp.array(self.z, dtype=float)
|
58
|
+
|
59
|
+
def __hash__(self) -> int:
|
60
|
+
|
61
|
+
return hash(self.z)
|
62
|
+
|
63
|
+
def __eq__(self, other: FlatTerrain) -> bool:
|
64
|
+
|
65
|
+
if not isinstance(other, FlatTerrain):
|
66
|
+
return False
|
67
|
+
|
68
|
+
return self.z == other.z
|
45
69
|
|
46
70
|
|
47
71
|
@jax_dataclasses.pytree_dataclass
|
48
|
-
class PlaneTerrain(
|
49
|
-
|
72
|
+
class PlaneTerrain(FlatTerrain):
|
73
|
+
|
74
|
+
plane_normal: tuple[float, float, float] = jax_dataclasses.field(
|
75
|
+
default=(0.0, 0.0, 0.0), kw_only=True
|
76
|
+
)
|
50
77
|
|
51
78
|
@staticmethod
|
52
|
-
def build(
|
79
|
+
def build(
|
80
|
+
plane_normal: jtp.VectorLike, plane_height_over_origin: jtp.FloatLike = 0.0
|
81
|
+
) -> PlaneTerrain:
|
53
82
|
"""
|
54
83
|
Create a PlaneTerrain instance with a specified plane normal vector.
|
55
84
|
|
56
85
|
Args:
|
57
|
-
plane_normal
|
86
|
+
plane_normal: The normal vector of the terrain plane.
|
87
|
+
plane_height_over_origin: The height of the plane over the origin.
|
58
88
|
|
59
89
|
Returns:
|
60
90
|
PlaneTerrain: A PlaneTerrain instance.
|
61
91
|
"""
|
62
|
-
if not isinstance(plane_normal, list):
|
63
|
-
raise TypeError(
|
64
|
-
f"Expected a list for the plane normal vector, got: {type(plane_normal)}."
|
65
|
-
)
|
66
92
|
|
67
|
-
|
93
|
+
plane_normal = jnp.array(plane_normal, dtype=float)
|
94
|
+
plane_height_over_origin = jnp.array(plane_height_over_origin, dtype=float)
|
95
|
+
|
96
|
+
if plane_normal.shape != (3,):
|
97
|
+
msg = "Expected a 3D vector for the plane normal, got '{}'."
|
98
|
+
raise ValueError(msg.format(plane_normal.shape))
|
68
99
|
|
69
|
-
|
100
|
+
# Make sure that the plane normal is a unit vector.
|
101
|
+
plane_normal = plane_normal / jnp.linalg.norm(plane_normal)
|
102
|
+
|
103
|
+
return PlaneTerrain(
|
104
|
+
z=float(plane_height_over_origin),
|
105
|
+
plane_normal=tuple(plane_normal.tolist()),
|
106
|
+
)
|
107
|
+
|
108
|
+
def height(self, x: jtp.FloatLike, y: jtp.FloatLike) -> jtp.Float:
|
70
109
|
"""
|
71
110
|
Compute the height of the terrain at a specific (x, y) location on a plane.
|
72
111
|
|
73
112
|
Args:
|
74
|
-
x
|
75
|
-
y
|
113
|
+
x: The x-coordinate of the location.
|
114
|
+
y: The y-coordinate of the location.
|
76
115
|
|
77
116
|
Returns:
|
78
|
-
|
117
|
+
The height of the terrain at the specified location on the plane.
|
79
118
|
"""
|
80
119
|
|
81
|
-
|
82
|
-
|
120
|
+
# Equation of the plane: A x + B y + C z + D = 0
|
121
|
+
# Normal vector coordinates: (A, B, C)
|
122
|
+
# The height over the origin: -D/C
|
123
|
+
|
124
|
+
# Get the plane equation coefficients from the terrain normal.
|
125
|
+
A, B, C = self.plane_normal
|
126
|
+
|
127
|
+
# Compute the final coefficient D considering the terrain height.
|
128
|
+
D = -C * self.z
|
129
|
+
|
130
|
+
# Invert the plane equation to get the height at the given (x, y) coordinates.
|
131
|
+
return jnp.array(-(A * x + B * y + D) / C).astype(float)
|
132
|
+
|
133
|
+
def __hash__(self) -> int:
|
134
|
+
|
135
|
+
from jaxsim.utils.wrappers import HashedNumpyArray
|
136
|
+
|
137
|
+
return hash(
|
138
|
+
(
|
139
|
+
hash(self.z),
|
140
|
+
HashedNumpyArray.hash_of_array(
|
141
|
+
array=jnp.array(self.plane_normal, dtype=float)
|
142
|
+
),
|
143
|
+
)
|
144
|
+
)
|
145
|
+
|
146
|
+
def __eq__(self, other: PlaneTerrain) -> bool:
|
147
|
+
|
148
|
+
if not isinstance(other, PlaneTerrain):
|
149
|
+
return False
|
150
|
+
|
151
|
+
if not (
|
152
|
+
jnp.allclose(self.z, other.z)
|
153
|
+
and jnp.allclose(
|
154
|
+
jnp.array(self.plane_normal, dtype=float),
|
155
|
+
jnp.array(other.plane_normal, dtype=float),
|
156
|
+
)
|
157
|
+
):
|
158
|
+
return False
|
159
|
+
|
160
|
+
return True
|
jaxsim/typing.py
CHANGED
@@ -7,14 +7,14 @@ import jax
|
|
7
7
|
# JAX types
|
8
8
|
# =========
|
9
9
|
|
10
|
-
|
11
|
-
|
12
|
-
|
13
|
-
|
10
|
+
Array = jax.Array
|
11
|
+
Scalar = Array
|
12
|
+
Vector = Array
|
13
|
+
Matrix = Array
|
14
14
|
|
15
|
-
|
16
|
-
|
17
|
-
|
15
|
+
Int = Scalar
|
16
|
+
Bool = Scalar
|
17
|
+
Float = Scalar
|
18
18
|
|
19
19
|
PyTree = (
|
20
20
|
dict[Hashable, "PyTree"] | list["PyTree"] | tuple["PyTree"] | None | jax.Array | Any
|
@@ -24,19 +24,11 @@ PyTree = (
|
|
24
24
|
# Mixed JAX / NumPy types
|
25
25
|
# =======================
|
26
26
|
|
27
|
-
|
28
|
-
|
29
|
-
|
30
|
-
|
27
|
+
ArrayLike = jax.typing.ArrayLike | tuple
|
28
|
+
ScalarLike = int | float | Scalar | ArrayLike
|
29
|
+
VectorLike = Vector | ArrayLike | tuple
|
30
|
+
MatrixLike = Matrix | ArrayLike
|
31
31
|
|
32
|
-
|
33
|
-
|
34
|
-
|
35
|
-
|
36
|
-
ScalarLike = Scalar | int | float
|
37
|
-
ArrayLike = Array
|
38
|
-
VectorLike = Vector
|
39
|
-
MatrixLike = Matrix
|
40
|
-
IntLike = Int
|
41
|
-
BoolLike = Bool
|
42
|
-
FloatLike = Float
|
32
|
+
IntLike = int | Int | jax.typing.ArrayLike
|
33
|
+
BoolLike = bool | Bool | jax.typing.ArrayLike
|
34
|
+
FloatLike = float | Float | jax.typing.ArrayLike
|