jaxsim 0.2.dev188__py3-none-any.whl → 0.2.dev364__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jaxsim/__init__.py +3 -4
- jaxsim/_version.py +2 -2
- jaxsim/api/__init__.py +3 -1
- jaxsim/api/com.py +240 -0
- jaxsim/api/common.py +13 -2
- jaxsim/api/contact.py +120 -43
- jaxsim/api/data.py +112 -71
- jaxsim/api/joint.py +77 -36
- jaxsim/api/kin_dyn_parameters.py +777 -0
- jaxsim/api/link.py +150 -75
- jaxsim/api/model.py +542 -269
- jaxsim/api/ode.py +88 -72
- jaxsim/api/ode_data.py +694 -0
- jaxsim/api/references.py +12 -11
- jaxsim/integrators/__init__.py +2 -2
- jaxsim/integrators/common.py +110 -24
- jaxsim/integrators/fixed_step.py +11 -67
- jaxsim/integrators/variable_step.py +610 -0
- jaxsim/math/__init__.py +11 -0
- jaxsim/math/adjoint.py +24 -2
- jaxsim/math/joint_model.py +335 -0
- jaxsim/math/quaternion.py +44 -3
- jaxsim/math/rotation.py +4 -4
- jaxsim/math/transform.py +93 -0
- jaxsim/parsers/descriptions/collision.py +14 -0
- jaxsim/parsers/descriptions/link.py +13 -2
- jaxsim/parsers/kinematic_graph.py +5 -0
- jaxsim/parsers/rod/utils.py +7 -8
- jaxsim/rbda/__init__.py +7 -0
- jaxsim/rbda/aba.py +295 -0
- jaxsim/rbda/collidable_points.py +142 -0
- jaxsim/{physics/algos → rbda}/crba.py +43 -42
- jaxsim/rbda/forward_kinematics.py +113 -0
- jaxsim/rbda/jacobian.py +201 -0
- jaxsim/rbda/rnea.py +237 -0
- jaxsim/rbda/soft_contacts.py +296 -0
- jaxsim/rbda/utils.py +152 -0
- jaxsim/terrain/__init__.py +2 -0
- jaxsim/{physics/algos → terrain}/terrain.py +4 -6
- jaxsim/utils/__init__.py +1 -4
- jaxsim/utils/hashless.py +18 -0
- jaxsim/utils/jaxsim_dataclass.py +281 -30
- {jaxsim-0.2.dev188.dist-info → jaxsim-0.2.dev364.dist-info}/METADATA +4 -6
- jaxsim-0.2.dev364.dist-info/RECORD +64 -0
- {jaxsim-0.2.dev188.dist-info → jaxsim-0.2.dev364.dist-info}/WHEEL +1 -1
- jaxsim/high_level/__init__.py +0 -2
- jaxsim/high_level/common.py +0 -11
- jaxsim/high_level/joint.py +0 -148
- jaxsim/high_level/link.py +0 -259
- jaxsim/high_level/model.py +0 -1686
- jaxsim/math/conv.py +0 -114
- jaxsim/math/joint.py +0 -102
- jaxsim/math/plucker.py +0 -100
- jaxsim/physics/__init__.py +0 -12
- jaxsim/physics/algos/__init__.py +0 -0
- jaxsim/physics/algos/aba.py +0 -254
- jaxsim/physics/algos/aba_motors.py +0 -284
- jaxsim/physics/algos/forward_kinematics.py +0 -79
- jaxsim/physics/algos/jacobian.py +0 -98
- jaxsim/physics/algos/rnea.py +0 -180
- jaxsim/physics/algos/rnea_motors.py +0 -196
- jaxsim/physics/algos/soft_contacts.py +0 -523
- jaxsim/physics/algos/utils.py +0 -69
- jaxsim/physics/model/__init__.py +0 -0
- jaxsim/physics/model/ground_contact.py +0 -55
- jaxsim/physics/model/physics_model.py +0 -388
- jaxsim/physics/model/physics_model_state.py +0 -283
- jaxsim/simulation/__init__.py +0 -4
- jaxsim/simulation/integrators.py +0 -393
- jaxsim/simulation/ode.py +0 -290
- jaxsim/simulation/ode_data.py +0 -96
- jaxsim/simulation/ode_integration.py +0 -62
- jaxsim/simulation/simulator.py +0 -543
- jaxsim/simulation/simulator_callbacks.py +0 -79
- jaxsim/simulation/utils.py +0 -15
- jaxsim/sixd/__init__.py +0 -2
- jaxsim/utils/oop.py +0 -536
- jaxsim/utils/vmappable.py +0 -117
- jaxsim-0.2.dev188.dist-info/RECORD +0 -81
- {jaxsim-0.2.dev188.dist-info → jaxsim-0.2.dev364.dist-info}/LICENSE +0 -0
- {jaxsim-0.2.dev188.dist-info → jaxsim-0.2.dev364.dist-info}/top_level.txt +0 -0
jaxsim/api/model.py
CHANGED
@@ -13,15 +13,11 @@ import rod
|
|
13
13
|
from jax_dataclasses import Static
|
14
14
|
|
15
15
|
import jaxsim.api as js
|
16
|
-
import jaxsim.
|
17
|
-
import jaxsim.physics.algos.crba
|
18
|
-
import jaxsim.physics.algos.forward_kinematics
|
19
|
-
import jaxsim.physics.algos.rnea
|
20
|
-
import jaxsim.physics.model.physics_model
|
16
|
+
import jaxsim.parsers.descriptions
|
21
17
|
import jaxsim.typing as jtp
|
22
|
-
from jaxsim.
|
23
|
-
|
24
|
-
from
|
18
|
+
from jaxsim.utils import HashlessObject, JaxsimDataclass, Mutability
|
19
|
+
|
20
|
+
from .common import VelRepr
|
25
21
|
|
26
22
|
|
27
23
|
@jax_dataclasses.pytree_dataclass
|
@@ -32,35 +28,22 @@ class JaxSimModel(JaxsimDataclass):
|
|
32
28
|
|
33
29
|
model_name: Static[str]
|
34
30
|
|
35
|
-
|
36
|
-
repr=False
|
31
|
+
terrain: Static[jaxsim.terrain.Terrain] = dataclasses.field(
|
32
|
+
default=jaxsim.terrain.FlatTerrain(), repr=False, compare=False, hash=False
|
37
33
|
)
|
38
34
|
|
39
|
-
terrain: Static[Terrain] = dataclasses.field(default=FlatTerrain(), repr=False)
|
40
|
-
|
41
35
|
built_from: Static[str | pathlib.Path | rod.Model | None] = dataclasses.field(
|
42
|
-
repr=False,
|
36
|
+
default=None, repr=False, compare=False, hash=False
|
43
37
|
)
|
44
38
|
|
45
|
-
|
46
|
-
|
47
|
-
)
|
39
|
+
description: Static[
|
40
|
+
HashlessObject[jaxsim.parsers.descriptions.ModelDescription | None]
|
41
|
+
] = dataclasses.field(default=None, repr=False, compare=False, hash=False)
|
48
42
|
|
49
|
-
|
50
|
-
|
43
|
+
kin_dyn_parameters: js.kin_dyn_parameters.KynDynParameters | None = (
|
44
|
+
dataclasses.field(default=None, repr=False, compare=False, hash=False)
|
51
45
|
)
|
52
46
|
|
53
|
-
def __post_init__(self):
|
54
|
-
|
55
|
-
# These attributes are Static so that we can use `jax.vmap` and `jax.lax.scan`
|
56
|
-
# over the all links and joints
|
57
|
-
with self.mutable_context(
|
58
|
-
mutability=Mutability.MUTABLE_NO_VALIDATION,
|
59
|
-
restore_after_exception=False,
|
60
|
-
):
|
61
|
-
self._number_of_links = len(self.physics_model.description.links_dict)
|
62
|
-
self._number_of_joints = len(self.physics_model.description.joints_dict)
|
63
|
-
|
64
47
|
# ========================
|
65
48
|
# Initialization and state
|
66
49
|
# ========================
|
@@ -69,7 +52,6 @@ class JaxSimModel(JaxsimDataclass):
|
|
69
52
|
def build_from_model_description(
|
70
53
|
model_description: str | pathlib.Path | rod.Model,
|
71
54
|
model_name: str | None = None,
|
72
|
-
gravity: jtp.Array = jaxsim.physics.default_gravity(),
|
73
55
|
is_urdf: bool | None = None,
|
74
56
|
considered_joints: list[str] | None = None,
|
75
57
|
) -> JaxSimModel:
|
@@ -83,7 +65,6 @@ class JaxSimModel(JaxsimDataclass):
|
|
83
65
|
model_name:
|
84
66
|
The optional name of the model that overrides the one in
|
85
67
|
the description.
|
86
|
-
gravity: The 3D gravity vector.
|
87
68
|
is_urdf:
|
88
69
|
Whether the model description is a URDF or an SDF. This is
|
89
70
|
automatically inferred if the model description is a path to a file.
|
@@ -109,13 +90,10 @@ class JaxSimModel(JaxsimDataclass):
|
|
109
90
|
considered_joints=considered_joints
|
110
91
|
)
|
111
92
|
|
112
|
-
# Create the physics model from the model description
|
113
|
-
physics_model = jaxsim.physics.model.physics_model.PhysicsModel.build_from(
|
114
|
-
model_description=intermediate_description, gravity=gravity
|
115
|
-
)
|
116
|
-
|
117
93
|
# Build the model
|
118
|
-
model = JaxSimModel.build(
|
94
|
+
model = JaxSimModel.build(
|
95
|
+
model_description=intermediate_description, model_name=model_name
|
96
|
+
)
|
119
97
|
|
120
98
|
# Store the origin of the model, in case downstream logic needs it
|
121
99
|
with model.mutable_context(mutability=Mutability.MUTABLE_NO_VALIDATION):
|
@@ -125,14 +103,16 @@ class JaxSimModel(JaxsimDataclass):
|
|
125
103
|
|
126
104
|
@staticmethod
|
127
105
|
def build(
|
128
|
-
|
106
|
+
model_description: jaxsim.parsers.descriptions.ModelDescription,
|
129
107
|
model_name: str | None = None,
|
130
108
|
) -> JaxSimModel:
|
131
109
|
"""
|
132
|
-
Build a Model object from
|
110
|
+
Build a Model object from an intermediate model description.
|
133
111
|
|
134
112
|
Args:
|
135
|
-
|
113
|
+
model_description:
|
114
|
+
The intermediate model description defining the kinematics and dynamics
|
115
|
+
of the model.
|
136
116
|
model_name:
|
137
117
|
The optional name of the model overriding the physics model name.
|
138
118
|
|
@@ -141,12 +121,16 @@ class JaxSimModel(JaxsimDataclass):
|
|
141
121
|
"""
|
142
122
|
|
143
123
|
# Set the model name (if not provided, use the one from the model description)
|
144
|
-
model_name =
|
145
|
-
model_name if model_name is not None else physics_model.description.name
|
146
|
-
)
|
124
|
+
model_name = model_name if model_name is not None else model_description.name
|
147
125
|
|
148
126
|
# Build the model
|
149
|
-
model = JaxSimModel(
|
127
|
+
model = JaxSimModel(
|
128
|
+
model_name=model_name,
|
129
|
+
description=HashlessObject(obj=model_description),
|
130
|
+
kin_dyn_parameters=js.kin_dyn_parameters.KynDynParameters.build(
|
131
|
+
model_description=model_description
|
132
|
+
),
|
133
|
+
)
|
150
134
|
|
151
135
|
return model
|
152
136
|
|
@@ -175,7 +159,7 @@ class JaxSimModel(JaxsimDataclass):
|
|
175
159
|
The base link is included in the count and its index is always 0.
|
176
160
|
"""
|
177
161
|
|
178
|
-
return self.
|
162
|
+
return self.kin_dyn_parameters.number_of_links()
|
179
163
|
|
180
164
|
def number_of_joints(self) -> jtp.Int:
|
181
165
|
"""
|
@@ -185,7 +169,7 @@ class JaxSimModel(JaxsimDataclass):
|
|
185
169
|
The number of joints in the model.
|
186
170
|
"""
|
187
171
|
|
188
|
-
return self.
|
172
|
+
return self.kin_dyn_parameters.number_of_joints()
|
189
173
|
|
190
174
|
# =================
|
191
175
|
# Base link methods
|
@@ -199,7 +183,7 @@ class JaxSimModel(JaxsimDataclass):
|
|
199
183
|
True if the model is floating-base, False otherwise.
|
200
184
|
"""
|
201
185
|
|
202
|
-
return self.
|
186
|
+
return bool(self.kin_dyn_parameters.joint_model.joint_dofs[0] == 6)
|
203
187
|
|
204
188
|
def base_link(self) -> str:
|
205
189
|
"""
|
@@ -207,9 +191,12 @@ class JaxSimModel(JaxsimDataclass):
|
|
207
191
|
|
208
192
|
Returns:
|
209
193
|
The name of the base link.
|
194
|
+
|
195
|
+
Note:
|
196
|
+
By default, the base link is the root of the kinematic tree.
|
210
197
|
"""
|
211
198
|
|
212
|
-
return self.
|
199
|
+
return self.link_names()[0]
|
213
200
|
|
214
201
|
# =====================
|
215
202
|
# Joint-related methods
|
@@ -227,7 +214,7 @@ class JaxSimModel(JaxsimDataclass):
|
|
227
214
|
the number of joints. In the future, this could be different.
|
228
215
|
"""
|
229
216
|
|
230
|
-
return
|
217
|
+
return int(sum(self.kin_dyn_parameters.joint_model.joint_dofs[1:]))
|
231
218
|
|
232
219
|
def joint_names(self) -> tuple[str, ...]:
|
233
220
|
"""
|
@@ -237,7 +224,7 @@ class JaxSimModel(JaxsimDataclass):
|
|
237
224
|
The names of the joints in the model.
|
238
225
|
"""
|
239
226
|
|
240
|
-
return
|
227
|
+
return self.kin_dyn_parameters.joint_model.joint_names[1:]
|
241
228
|
|
242
229
|
# ====================
|
243
230
|
# Link-related methods
|
@@ -251,7 +238,7 @@ class JaxSimModel(JaxsimDataclass):
|
|
251
238
|
The names of the links in the model.
|
252
239
|
"""
|
253
240
|
|
254
|
-
return
|
241
|
+
return self.kin_dyn_parameters.link_names
|
255
242
|
|
256
243
|
|
257
244
|
# =====================
|
@@ -273,25 +260,17 @@ def reduce(model: JaxSimModel, considered_joints: tuple[str, ...]) -> JaxSimMode
|
|
273
260
|
return a copy of the input model.
|
274
261
|
"""
|
275
262
|
|
276
|
-
if len(considered_joints) == 0:
|
277
|
-
return model.copy()
|
278
|
-
|
279
263
|
# Reduce the model description.
|
280
264
|
# If considered_joints contains joints not existing in the model, the method
|
281
265
|
# will raise an exception.
|
282
|
-
reduced_intermediate_description = model.
|
266
|
+
reduced_intermediate_description = model.description.obj.reduce(
|
283
267
|
considered_joints=list(considered_joints)
|
284
268
|
)
|
285
269
|
|
286
|
-
# Create the physics model from the reduced model description
|
287
|
-
physics_model = jaxsim.physics.model.physics_model.PhysicsModel.build_from(
|
288
|
-
model_description=reduced_intermediate_description,
|
289
|
-
gravity=model.physics_model.gravity[0:3],
|
290
|
-
)
|
291
|
-
|
292
270
|
# Build the reduced model
|
293
271
|
reduced_model = JaxSimModel.build(
|
294
|
-
|
272
|
+
model_description=reduced_intermediate_description,
|
273
|
+
model_name=model.name(),
|
295
274
|
)
|
296
275
|
|
297
276
|
# Store the origin of the model, in case downstream logic needs it
|
@@ -327,43 +306,21 @@ def total_mass(model: JaxSimModel) -> jtp.Float:
|
|
327
306
|
)
|
328
307
|
|
329
308
|
|
330
|
-
# ==============
|
331
|
-
# Center of mass
|
332
|
-
# ==============
|
333
|
-
|
334
|
-
|
335
309
|
@jax.jit
|
336
|
-
def
|
310
|
+
def link_spatial_inertia_matrices(model: JaxSimModel) -> jtp.Array:
|
337
311
|
"""
|
338
|
-
Compute the
|
312
|
+
Compute the spatial 6D inertia matrices of all links of the model.
|
339
313
|
|
340
314
|
Args:
|
341
315
|
model: The model to consider.
|
342
|
-
data: The data of the considered model.
|
343
316
|
|
344
317
|
Returns:
|
345
|
-
|
318
|
+
A 3D array containing the stacked spatial 6D inertia matrices of the links.
|
346
319
|
"""
|
347
320
|
|
348
|
-
|
349
|
-
|
350
|
-
|
351
|
-
W_H_B = data.base_transform()
|
352
|
-
B_H_W = jaxlie.SE3.from_matrix(W_H_B).inverse().as_matrix()
|
353
|
-
|
354
|
-
def B_p̃_LCoM(i) -> jtp.Vector:
|
355
|
-
m = js.link.mass(model=model, link_index=i)
|
356
|
-
L_p_LCoM = js.link.com_position(
|
357
|
-
model=model, data=data, link_index=i, in_link_frame=True
|
358
|
-
)
|
359
|
-
return m * B_H_W @ W_H_L[i] @ jnp.hstack([L_p_LCoM, 1])
|
360
|
-
|
361
|
-
com_links = jax.vmap(B_p̃_LCoM)(jnp.arange(model.number_of_links()))
|
362
|
-
|
363
|
-
B_p̃_CoM = (1 / m) * com_links.sum(axis=0)
|
364
|
-
B_p̃_CoM = B_p̃_CoM.at[3].set(1)
|
365
|
-
|
366
|
-
return (W_H_B @ B_p̃_CoM)[0:3].astype(float)
|
321
|
+
return jax.vmap(js.kin_dyn_parameters.LinkParameters.spatial_inertia)(
|
322
|
+
model.kin_dyn_parameters.link_parameters
|
323
|
+
)
|
367
324
|
|
368
325
|
|
369
326
|
# ==============================
|
@@ -385,10 +342,11 @@ def forward_kinematics(model: JaxSimModel, data: js.data.JaxSimModelData) -> jtp
|
|
385
342
|
The first axis is the link index.
|
386
343
|
"""
|
387
344
|
|
388
|
-
W_H_LL = jaxsim.
|
389
|
-
model=model
|
390
|
-
|
391
|
-
|
345
|
+
W_H_LL = jaxsim.rbda.forward_kinematics_model(
|
346
|
+
model=model,
|
347
|
+
base_position=data.base_position(),
|
348
|
+
base_quaternion=data.base_orientation(dcm=False),
|
349
|
+
joint_positions=data.joint_positions(model=model),
|
392
350
|
)
|
393
351
|
|
394
352
|
return jnp.atleast_3d(W_H_LL).astype(float)
|
@@ -424,51 +382,64 @@ def generalized_free_floating_jacobian(
|
|
424
382
|
output_vel_repr if output_vel_repr is not None else data.velocity_representation
|
425
383
|
)
|
426
384
|
|
427
|
-
#
|
428
|
-
|
429
|
-
|
430
|
-
|
431
|
-
|
385
|
+
# Compute the doubly-left free-floating full jacobian.
|
386
|
+
B_J_full_WX_B, _ = jaxsim.rbda.jacobian_full_doubly_left(
|
387
|
+
model=model,
|
388
|
+
joint_positions=data.joint_positions(),
|
389
|
+
)
|
390
|
+
|
391
|
+
# Update the input velocity representation such that `J_WL_I @ I_ν`.
|
392
|
+
match data.velocity_representation:
|
432
393
|
case VelRepr.Inertial:
|
433
|
-
|
394
|
+
W_H_B = data.base_transform()
|
395
|
+
B_X_W = jaxlie.SE3.from_matrix(W_H_B).inverse().adjoint()
|
396
|
+
B_J_full_WX_I = B_J_full_WX_W = B_J_full_WX_B @ jax.scipy.linalg.block_diag(
|
397
|
+
B_X_W, jnp.eye(model.dofs())
|
398
|
+
)
|
434
399
|
|
435
400
|
case VelRepr.Body:
|
436
|
-
|
437
|
-
def to_output(W_J_WL: jtp.Matrix) -> jtp.Matrix:
|
438
|
-
W_H_B = data.base_transform()
|
439
|
-
B_X_W = jaxlie.SE3.from_matrix(W_H_B).inverse().adjoint()
|
440
|
-
return B_X_W @ W_J_WL
|
401
|
+
B_J_full_WX_I = B_J_full_WX_B
|
441
402
|
|
442
403
|
case VelRepr.Mixed:
|
404
|
+
W_R_B = data.base_orientation(dcm=True)
|
405
|
+
BW_H_B = jnp.eye(4).at[0:3, 0:3].set(W_R_B)
|
406
|
+
B_X_BW = jaxlie.SE3.from_matrix(BW_H_B).inverse().adjoint()
|
407
|
+
B_J_full_WX_I = B_J_full_WX_BW = (
|
408
|
+
B_J_full_WX_B
|
409
|
+
@ jax.scipy.linalg.block_diag(B_X_BW, jnp.eye(model.dofs()))
|
410
|
+
)
|
411
|
+
|
412
|
+
case _:
|
413
|
+
raise ValueError(data.velocity_representation)
|
414
|
+
|
415
|
+
# Update the output velocity representation such that `O_v_WL_I = O_J_WL_I @ I_ν`.
|
416
|
+
match output_vel_repr:
|
417
|
+
case VelRepr.Inertial:
|
418
|
+
W_H_B = data.base_transform()
|
419
|
+
W_X_B = jaxlie.SE3.from_matrix(W_H_B).adjoint()
|
420
|
+
O_J_full_WX_I = W_J_full_WX_I = W_X_B @ B_J_full_WX_I
|
443
421
|
|
444
|
-
|
445
|
-
|
446
|
-
|
447
|
-
|
448
|
-
|
422
|
+
case VelRepr.Body:
|
423
|
+
O_J_full_WX_I = B_J_full_WX_I
|
424
|
+
|
425
|
+
case VelRepr.Mixed:
|
426
|
+
W_R_B = data.base_orientation(dcm=True)
|
427
|
+
BW_H_B = jnp.eye(4).at[0:3, 0:3].set(W_R_B)
|
428
|
+
BW_X_B = jaxlie.SE3.from_matrix(BW_H_B).adjoint()
|
429
|
+
O_J_full_WX_I = BW_J_full_WX_I = BW_X_B @ B_J_full_WX_I
|
449
430
|
|
450
431
|
case _:
|
451
432
|
raise ValueError(output_vel_repr)
|
452
433
|
|
453
|
-
|
454
|
-
|
455
|
-
|
456
|
-
|
457
|
-
|
458
|
-
# This is necessary because for example the body-fixed free-floating jacobian
|
459
|
-
# of a link is L_J_WL, but here being inside model we need B_J_WL.
|
460
|
-
J_free_floating = jax.vmap(
|
461
|
-
lambda i: to_output(
|
462
|
-
W_J_WL=js.link.jacobian(
|
463
|
-
model=model,
|
464
|
-
data=data,
|
465
|
-
link_index=i,
|
466
|
-
output_vel_repr=VelRepr.Inertial,
|
467
|
-
)
|
434
|
+
κ_bool = model.kin_dyn_parameters.support_body_array_bool
|
435
|
+
|
436
|
+
O_J_WL_I = jax.vmap(
|
437
|
+
lambda κ: jnp.where(
|
438
|
+
jnp.hstack([jnp.ones(5), κ]), O_J_full_WX_I, jnp.zeros_like(O_J_full_WX_I)
|
468
439
|
)
|
469
|
-
)(
|
440
|
+
)(κ_bool)
|
470
441
|
|
471
|
-
return
|
442
|
+
return O_J_WL_I
|
472
443
|
|
473
444
|
|
474
445
|
@functools.partial(jax.jit, static_argnames=["prefer_aba"])
|
@@ -477,7 +448,7 @@ def forward_dynamics(
|
|
477
448
|
data: js.data.JaxSimModelData,
|
478
449
|
*,
|
479
450
|
joint_forces: jtp.VectorLike | None = None,
|
480
|
-
|
451
|
+
link_forces: jtp.MatrixLike | None = None,
|
481
452
|
prefer_aba: float = True,
|
482
453
|
) -> tuple[jtp.Vector, jtp.Vector]:
|
483
454
|
"""
|
@@ -488,8 +459,8 @@ def forward_dynamics(
|
|
488
459
|
data: The data of the considered model.
|
489
460
|
joint_forces:
|
490
461
|
The joint forces to consider as a vector of shape `(dofs,)`.
|
491
|
-
|
492
|
-
The
|
462
|
+
link_forces:
|
463
|
+
The link 6D forces consider as a matrix of shape `(nL, 6)`.
|
493
464
|
The frame in which they are expressed must be `data.velocity_representation`.
|
494
465
|
prefer_aba: Whether to prefer the ABA algorithm over the CRB one.
|
495
466
|
|
@@ -505,7 +476,7 @@ def forward_dynamics(
|
|
505
476
|
model=model,
|
506
477
|
data=data,
|
507
478
|
joint_forces=joint_forces,
|
508
|
-
|
479
|
+
link_forces=link_forces,
|
509
480
|
)
|
510
481
|
|
511
482
|
|
@@ -515,7 +486,7 @@ def forward_dynamics_aba(
|
|
515
486
|
data: js.data.JaxSimModelData,
|
516
487
|
*,
|
517
488
|
joint_forces: jtp.VectorLike | None = None,
|
518
|
-
|
489
|
+
link_forces: jtp.MatrixLike | None = None,
|
519
490
|
) -> tuple[jtp.Vector, jtp.Vector]:
|
520
491
|
"""
|
521
492
|
Compute the forward dynamics of the model with the ABA algorithm.
|
@@ -525,8 +496,8 @@ def forward_dynamics_aba(
|
|
525
496
|
data: The data of the considered model.
|
526
497
|
joint_forces:
|
527
498
|
The joint forces to consider as a vector of shape `(dofs,)`.
|
528
|
-
|
529
|
-
The
|
499
|
+
link_forces:
|
500
|
+
The link 6D forces to consider as a matrix of shape `(nL, 6)`.
|
530
501
|
The frame in which they are expressed must be `data.velocity_representation`.
|
531
502
|
|
532
503
|
Returns:
|
@@ -535,63 +506,112 @@ def forward_dynamics_aba(
|
|
535
506
|
considered joint forces and external forces.
|
536
507
|
"""
|
537
508
|
|
538
|
-
#
|
509
|
+
# ============
|
510
|
+
# Prepare data
|
511
|
+
# ============
|
512
|
+
|
513
|
+
# Build joint forces, if not provided.
|
539
514
|
τ = (
|
540
|
-
joint_forces
|
515
|
+
jnp.atleast_1d(joint_forces.squeeze())
|
541
516
|
if joint_forces is not None
|
542
517
|
else jnp.zeros_like(data.joint_positions())
|
543
518
|
)
|
544
519
|
|
545
|
-
# Build
|
546
|
-
|
547
|
-
|
548
|
-
if
|
520
|
+
# Build link forces, if not provided.
|
521
|
+
f_L = (
|
522
|
+
jnp.atleast_2d(link_forces.squeeze())
|
523
|
+
if link_forces is not None
|
549
524
|
else jnp.zeros((model.number_of_links(), 6))
|
550
525
|
)
|
551
526
|
|
552
|
-
#
|
553
|
-
|
554
|
-
model=model
|
555
|
-
|
556
|
-
|
557
|
-
|
558
|
-
|
559
|
-
f_ext=f_ext,
|
527
|
+
# Create a references object that simplifies converting among representations.
|
528
|
+
references = js.references.JaxSimModelReferences.build(
|
529
|
+
model=model,
|
530
|
+
joint_force_references=τ,
|
531
|
+
link_forces=f_L,
|
532
|
+
data=data,
|
533
|
+
velocity_representation=data.velocity_representation,
|
560
534
|
)
|
561
535
|
|
562
|
-
|
563
|
-
|
536
|
+
# Extract the link and joint serializations.
|
537
|
+
link_names = model.link_names()
|
538
|
+
joint_names = model.joint_names()
|
539
|
+
|
540
|
+
# Extract the state in inertial-fixed representation.
|
541
|
+
with data.switch_velocity_representation(VelRepr.Inertial):
|
542
|
+
W_p_B = data.base_position()
|
543
|
+
W_v_WB = data.base_velocity()
|
544
|
+
W_Q_B = data.base_orientation(dcm=False)
|
545
|
+
s = data.joint_positions(model=model, joint_names=joint_names)
|
546
|
+
ṡ = data.joint_velocities(model=model, joint_names=joint_names)
|
564
547
|
|
565
|
-
|
566
|
-
|
548
|
+
# Extract the inputs in inertial-fixed representation.
|
549
|
+
with references.switch_velocity_representation(VelRepr.Inertial):
|
550
|
+
W_f_L = references.link_forces(model=model, data=data, link_names=link_names)
|
551
|
+
τ = references.joint_force_references(model=model, joint_names=joint_names)
|
552
|
+
|
553
|
+
# ========================
|
554
|
+
# Compute forward dynamics
|
555
|
+
# ========================
|
556
|
+
|
557
|
+
W_v̇_WB, s̈ = jaxsim.rbda.aba(
|
558
|
+
model=model,
|
559
|
+
base_position=W_p_B,
|
560
|
+
base_quaternion=W_Q_B,
|
561
|
+
joint_positions=s,
|
562
|
+
base_linear_velocity=W_v_WB[0:3],
|
563
|
+
base_angular_velocity=W_v_WB[3:6],
|
564
|
+
joint_velocities=ṡ,
|
565
|
+
joint_forces=τ,
|
566
|
+
link_forces=W_f_L,
|
567
|
+
standard_gravity=data.standard_gravity(),
|
568
|
+
)
|
569
|
+
|
570
|
+
# =============
|
571
|
+
# Adjust output
|
572
|
+
# =============
|
567
573
|
|
568
|
-
|
574
|
+
def to_active(
|
575
|
+
W_v̇_WB: jtp.Vector, W_H_C: jtp.Matrix, W_v_WB: jtp.Vector, W_v_WC: jtp.Vector
|
576
|
+
) -> jtp.Vector:
|
577
|
+
"""
|
578
|
+
Helper to convert the inertial-fixed apparent base acceleration W_v̇_WB to
|
579
|
+
another representation C_v̇_WB expressed in a generic frame C.
|
580
|
+
"""
|
569
581
|
|
570
|
-
|
571
|
-
|
582
|
+
from jaxsim.math import Cross
|
583
|
+
|
584
|
+
# In Mixed representation, we need to include a cross product in ℝ⁶.
|
585
|
+
# In Inertial and Body representations, the cross product is always zero.
|
586
|
+
C_X_W = jaxlie.SE3.from_matrix(W_H_C).inverse().adjoint()
|
587
|
+
return C_X_W @ (W_v̇_WB - Cross.vx(W_v_WC) @ W_v_WB)
|
572
588
|
|
573
589
|
match data.velocity_representation:
|
574
590
|
case VelRepr.Inertial:
|
591
|
+
# In this case C=W
|
575
592
|
W_H_C = W_H_W = jnp.eye(4)
|
576
|
-
|
593
|
+
W_v_WC = W_v_WW = jnp.zeros(6)
|
577
594
|
|
578
595
|
case VelRepr.Body:
|
596
|
+
# In this case C=B
|
579
597
|
W_H_C = W_H_B = data.base_transform()
|
580
|
-
|
598
|
+
W_v_WC = W_v_WB
|
581
599
|
|
582
600
|
case VelRepr.Mixed:
|
601
|
+
# In this case C=B[W]
|
583
602
|
W_H_B = data.base_transform()
|
584
603
|
W_H_C = W_H_BW = W_H_B.at[0:3, 0:3].set(jnp.eye(3))
|
585
|
-
|
604
|
+
W_ṗ_B = data.base_velocity()[0:3]
|
605
|
+
W_v_WC = W_v_W_BW = jnp.zeros(6).at[0:3].set(W_ṗ_B)
|
586
606
|
|
587
607
|
case _:
|
588
608
|
raise ValueError(data.velocity_representation)
|
589
609
|
|
590
|
-
# We need to convert the derivative of the base
|
610
|
+
# We need to convert the derivative of the base velocity to the active
|
591
611
|
# representation. In Mixed representation, this conversion is not a plain
|
592
612
|
# transformation with just X, but it also involves a cross product in ℝ⁶.
|
593
613
|
C_v̇_WB = to_active(
|
594
|
-
|
614
|
+
W_v̇_WB=W_v̇_WB,
|
595
615
|
W_H_C=W_H_C,
|
596
616
|
W_v_WB=jnp.hstack(
|
597
617
|
[
|
@@ -599,13 +619,16 @@ def forward_dynamics_aba(
|
|
599
619
|
data.state.physics_model.base_angular_velocity,
|
600
620
|
]
|
601
621
|
),
|
602
|
-
|
622
|
+
W_v_WC=W_v_WC,
|
603
623
|
)
|
604
624
|
|
605
|
-
#
|
606
|
-
|
625
|
+
# The ABA algorithm already returns a zero base 6D acceleration for
|
626
|
+
# fixed-based models. However, the to_active function introduces an
|
627
|
+
# additional acceleration component in Mixed representation.
|
628
|
+
# Here below we make sure that the base acceleration is zero.
|
629
|
+
C_v̇_WB = C_v̇_WB if model.floating_base() else jnp.zeros(6)
|
607
630
|
|
608
|
-
return C_v̇_WB, s
|
631
|
+
return C_v̇_WB.astype(float), s̈.astype(float)
|
609
632
|
|
610
633
|
|
611
634
|
@jax.jit
|
@@ -614,7 +637,7 @@ def forward_dynamics_crb(
|
|
614
637
|
data: js.data.JaxSimModelData,
|
615
638
|
*,
|
616
639
|
joint_forces: jtp.VectorLike | None = None,
|
617
|
-
|
640
|
+
link_forces: jtp.MatrixLike | None = None,
|
618
641
|
) -> tuple[jtp.Vector, jtp.Vector]:
|
619
642
|
"""
|
620
643
|
Compute the forward dynamics of the model with the CRB algorithm.
|
@@ -624,8 +647,8 @@ def forward_dynamics_crb(
|
|
624
647
|
data: The data of the considered model.
|
625
648
|
joint_forces:
|
626
649
|
The joint forces to consider as a vector of shape `(dofs,)`.
|
627
|
-
|
628
|
-
The
|
650
|
+
link_forces:
|
651
|
+
The link 6D forces to consider as a matrix of shape `(nL, 6)`.
|
629
652
|
The frame in which they are expressed must be `data.velocity_representation`.
|
630
653
|
|
631
654
|
Returns:
|
@@ -638,6 +661,10 @@ def forward_dynamics_crb(
|
|
638
661
|
models with a large number of degrees of freedom.
|
639
662
|
"""
|
640
663
|
|
664
|
+
# ============
|
665
|
+
# Prepare data
|
666
|
+
# ============
|
667
|
+
|
641
668
|
# Build joint torques if not provided
|
642
669
|
τ = (
|
643
670
|
jnp.atleast_1d(joint_forces)
|
@@ -647,8 +674,8 @@ def forward_dynamics_crb(
|
|
647
674
|
|
648
675
|
# Build external forces if not provided
|
649
676
|
f = (
|
650
|
-
jnp.atleast_2d(
|
651
|
-
if
|
677
|
+
jnp.atleast_2d(link_forces)
|
678
|
+
if link_forces is not None
|
652
679
|
else jnp.zeros(shape=(model.number_of_links(), 6))
|
653
680
|
)
|
654
681
|
|
@@ -660,6 +687,10 @@ def forward_dynamics_crb(
|
|
660
687
|
|
661
688
|
# TODO: invert the Mss block exploiting sparsity defined by the parent array λ(i)
|
662
689
|
|
690
|
+
# ========================
|
691
|
+
# Compute forward dynamics
|
692
|
+
# ========================
|
693
|
+
|
663
694
|
if model.floating_base():
|
664
695
|
# l: number of links.
|
665
696
|
# g: generalized coordinates, 6 + number of joints.
|
@@ -675,6 +706,10 @@ def forward_dynamics_crb(
|
|
675
706
|
v̇_WB = jnp.zeros(6)
|
676
707
|
ν̇ = jnp.hstack([v̇_WB, s̈.squeeze()])
|
677
708
|
|
709
|
+
# =============
|
710
|
+
# Adjust output
|
711
|
+
# =============
|
712
|
+
|
678
713
|
# Extract the base acceleration in the active representation.
|
679
714
|
# Note that this is an apparent acceleration (relevant in Mixed representation),
|
680
715
|
# therefore it cannot be always expressed in different frames with just a
|
@@ -702,9 +737,9 @@ def free_floating_mass_matrix(
|
|
702
737
|
The free-floating mass matrix of the model.
|
703
738
|
"""
|
704
739
|
|
705
|
-
M_body = jaxsim.
|
706
|
-
model=model
|
707
|
-
|
740
|
+
M_body = jaxsim.rbda.crba(
|
741
|
+
model=model,
|
742
|
+
joint_positions=data.state.physics_model.joint_positions,
|
708
743
|
)
|
709
744
|
|
710
745
|
match data.velocity_representation:
|
@@ -712,29 +747,17 @@ def free_floating_mass_matrix(
|
|
712
747
|
return M_body
|
713
748
|
|
714
749
|
case VelRepr.Inertial:
|
715
|
-
zero_6n = jnp.zeros(shape=(6, model.dofs()))
|
716
|
-
B_X_W = jaxlie.SE3.from_matrix(data.base_transform()).inverse().adjoint()
|
717
750
|
|
718
|
-
|
719
|
-
|
720
|
-
jnp.block([B_X_W, zero_6n]),
|
721
|
-
jnp.block([zero_6n.T, jnp.eye(model.dofs())]),
|
722
|
-
]
|
723
|
-
)
|
751
|
+
B_X_W = jaxlie.SE3.from_matrix(data.base_transform()).inverse().adjoint()
|
752
|
+
invT = jax.scipy.linalg.block_diag(B_X_W, jnp.eye(model.dofs()))
|
724
753
|
|
725
754
|
return invT.T @ M_body @ invT
|
726
755
|
|
727
756
|
case VelRepr.Mixed:
|
728
|
-
|
729
|
-
|
730
|
-
|
731
|
-
|
732
|
-
invT = jnp.vstack(
|
733
|
-
[
|
734
|
-
jnp.block([BW_X_W, zero_6n]),
|
735
|
-
jnp.block([zero_6n.T, jnp.eye(model.dofs())]),
|
736
|
-
]
|
737
|
-
)
|
757
|
+
|
758
|
+
BW_H_B = data.base_transform().at[0:3, 3].set(jnp.zeros(3))
|
759
|
+
B_X_BW = jaxlie.SE3.from_matrix(BW_H_B).inverse().adjoint()
|
760
|
+
invT = jax.scipy.linalg.block_diag(B_X_BW, jnp.eye(model.dofs()))
|
738
761
|
|
739
762
|
return invT.T @ M_body @ invT
|
740
763
|
|
@@ -747,9 +770,9 @@ def inverse_dynamics(
|
|
747
770
|
model: JaxSimModel,
|
748
771
|
data: js.data.JaxSimModelData,
|
749
772
|
*,
|
750
|
-
joint_accelerations: jtp.
|
751
|
-
base_acceleration: jtp.
|
752
|
-
|
773
|
+
joint_accelerations: jtp.VectorLike | None = None,
|
774
|
+
base_acceleration: jtp.VectorLike | None = None,
|
775
|
+
link_forces: jtp.MatrixLike | None = None,
|
753
776
|
) -> tuple[jtp.Vector, jtp.Vector]:
|
754
777
|
"""
|
755
778
|
Compute inverse dynamics with the RNEA algorithm.
|
@@ -761,8 +784,8 @@ def inverse_dynamics(
|
|
761
784
|
The joint accelerations to consider as a vector of shape `(dofs,)`.
|
762
785
|
base_acceleration:
|
763
786
|
The base acceleration to consider as a vector of shape `(6,)`.
|
764
|
-
|
765
|
-
The
|
787
|
+
link_forces:
|
788
|
+
The link 6D forces to consider as a matrix of shape `(nL, 6)`.
|
766
789
|
The frame in which they are expressed must be `data.velocity_representation`.
|
767
790
|
|
768
791
|
Returns:
|
@@ -771,49 +794,62 @@ def inverse_dynamics(
|
|
771
794
|
to obtain the considered joint accelerations.
|
772
795
|
"""
|
773
796
|
|
774
|
-
#
|
775
|
-
|
776
|
-
|
797
|
+
# ============
|
798
|
+
# Prepare data
|
799
|
+
# ============
|
800
|
+
|
801
|
+
# Build joint accelerations, if not provided.
|
802
|
+
s̈ = (
|
803
|
+
jnp.atleast_1d(jnp.array(joint_accelerations).squeeze())
|
777
804
|
if joint_accelerations is not None
|
778
805
|
else jnp.zeros_like(data.joint_positions())
|
779
806
|
)
|
780
807
|
|
781
|
-
# Build base acceleration if not provided
|
782
|
-
|
783
|
-
|
808
|
+
# Build base acceleration, if not provided.
|
809
|
+
v̇_WB = (
|
810
|
+
jnp.array(base_acceleration).squeeze()
|
811
|
+
if base_acceleration is not None
|
812
|
+
else jnp.zeros(6)
|
784
813
|
)
|
785
814
|
|
786
|
-
|
787
|
-
|
788
|
-
|
815
|
+
# Build link forces, if not provided.
|
816
|
+
f_L = (
|
817
|
+
jnp.atleast_2d(jnp.array(link_forces).squeeze())
|
818
|
+
if link_forces is not None
|
789
819
|
else jnp.zeros(shape=(model.number_of_links(), 6))
|
790
820
|
)
|
791
821
|
|
792
|
-
def to_inertial(C_v̇_WB, W_H_C, C_v_WB,
|
822
|
+
def to_inertial(C_v̇_WB, W_H_C, C_v_WB, W_v_WC):
|
823
|
+
"""
|
824
|
+
Helper to convert the active representation of the base acceleration C_v̇_WB
|
825
|
+
expressed in a generic frame C to the inertial-fixed representation W_v̇_WB.
|
826
|
+
"""
|
827
|
+
|
828
|
+
from jaxsim.math import Cross
|
829
|
+
|
793
830
|
W_X_C = jaxlie.SE3.from_matrix(W_H_C).adjoint()
|
794
831
|
C_X_W = jaxlie.SE3.from_matrix(W_H_C).inverse().adjoint()
|
832
|
+
C_v_WC = C_X_W @ W_v_WC
|
795
833
|
|
796
|
-
|
797
|
-
|
798
|
-
|
799
|
-
from jaxsim.math.cross import Cross
|
800
|
-
|
801
|
-
C_v_WC = C_X_W @ jnp.hstack([W_vl_WC, jnp.zeros(3)])
|
802
|
-
return W_X_C @ (C_v̇_WB + Cross.vx(C_v_WC) @ C_v_WB)
|
834
|
+
# In Mixed representation, we need to include a cross product in ℝ⁶.
|
835
|
+
# In Inertial and Body representations, the cross product is always zero.
|
836
|
+
return W_X_C @ (C_v̇_WB + Cross.vx(C_v_WC) @ C_v_WB)
|
803
837
|
|
804
838
|
match data.velocity_representation:
|
805
839
|
case VelRepr.Inertial:
|
806
840
|
W_H_C = W_H_W = jnp.eye(4)
|
807
|
-
|
841
|
+
W_v_WC = W_v_WW = jnp.zeros(6)
|
808
842
|
|
809
843
|
case VelRepr.Body:
|
810
844
|
W_H_C = W_H_B = data.base_transform()
|
811
|
-
|
845
|
+
with data.switch_velocity_representation(VelRepr.Inertial):
|
846
|
+
W_v_WC = W_v_WB = data.base_velocity()
|
812
847
|
|
813
848
|
case VelRepr.Mixed:
|
814
849
|
W_H_B = data.base_transform()
|
815
850
|
W_H_C = W_H_BW = W_H_B.at[0:3, 0:3].set(jnp.eye(3))
|
816
|
-
|
851
|
+
W_ṗ_B = data.base_velocity()[0:3]
|
852
|
+
W_v_WC = W_v_W_BW = jnp.zeros(6).at[0:3].set(W_ṗ_B)
|
817
853
|
|
818
854
|
case _:
|
819
855
|
raise ValueError(data.velocity_representation)
|
@@ -822,35 +858,60 @@ def inverse_dynamics(
|
|
822
858
|
# representation. In Mixed representation, this conversion is not a plain
|
823
859
|
# transformation with just X, but it also involves a cross product in ℝ⁶.
|
824
860
|
W_v̇_WB = to_inertial(
|
825
|
-
C_v̇_WB=
|
861
|
+
C_v̇_WB=v̇_WB,
|
826
862
|
W_H_C=W_H_C,
|
827
863
|
C_v_WB=data.base_velocity(),
|
828
|
-
|
864
|
+
W_v_WC=W_v_WC,
|
829
865
|
)
|
830
866
|
|
867
|
+
# Create a references object that simplifies converting among representations.
|
831
868
|
references = js.references.JaxSimModelReferences.build(
|
832
869
|
model=model,
|
833
870
|
data=data,
|
834
|
-
link_forces=
|
871
|
+
link_forces=f_L,
|
835
872
|
velocity_representation=data.velocity_representation,
|
836
873
|
)
|
837
874
|
|
838
|
-
#
|
875
|
+
# Extract the link and joint serializations.
|
876
|
+
link_names = model.link_names()
|
877
|
+
joint_names = model.joint_names()
|
878
|
+
|
879
|
+
# Extract the state in inertial-fixed representation.
|
880
|
+
with data.switch_velocity_representation(VelRepr.Inertial):
|
881
|
+
W_p_B = data.base_position()
|
882
|
+
W_v_WB = data.base_velocity()
|
883
|
+
W_Q_B = data.base_orientation(dcm=False)
|
884
|
+
s = data.joint_positions(model=model, joint_names=joint_names)
|
885
|
+
ṡ = data.joint_velocities(model=model, joint_names=joint_names)
|
886
|
+
|
887
|
+
# Extract the inputs in inertial-fixed representation.
|
839
888
|
with references.switch_velocity_representation(VelRepr.Inertial):
|
840
|
-
|
841
|
-
|
842
|
-
|
843
|
-
|
844
|
-
|
845
|
-
qdd=joint_accelerations,
|
846
|
-
a0fb=W_v̇_WB,
|
847
|
-
f_ext=references.link_forces(model=model, data=data),
|
848
|
-
)
|
889
|
+
W_f_L = references.link_forces(model=model, data=data, link_names=link_names)
|
890
|
+
|
891
|
+
# ========================
|
892
|
+
# Compute inverse dynamics
|
893
|
+
# ========================
|
849
894
|
|
850
|
-
|
851
|
-
|
895
|
+
W_f_B, τ = jaxsim.rbda.rnea(
|
896
|
+
model=model,
|
897
|
+
base_position=W_p_B,
|
898
|
+
base_quaternion=W_Q_B,
|
899
|
+
joint_positions=s,
|
900
|
+
base_linear_velocity=W_v_WB[0:3],
|
901
|
+
base_angular_velocity=W_v_WB[3:6],
|
902
|
+
joint_velocities=ṡ,
|
903
|
+
base_linear_acceleration=W_v̇_WB[0:3],
|
904
|
+
base_angular_acceleration=W_v̇_WB[3:6],
|
905
|
+
joint_accelerations=s̈,
|
906
|
+
link_forces=W_f_L,
|
907
|
+
standard_gravity=data.standard_gravity(),
|
908
|
+
)
|
852
909
|
|
853
|
-
#
|
910
|
+
# =============
|
911
|
+
# Adjust output
|
912
|
+
# =============
|
913
|
+
|
914
|
+
# Express W_f_B in the active representation.
|
854
915
|
f_B = js.data.JaxSimModelData.inertial_to_other_representation(
|
855
916
|
array=W_f_B,
|
856
917
|
other_representation=data.velocity_representation,
|
@@ -905,7 +966,7 @@ def free_floating_gravity_forces(
|
|
905
966
|
# Set zero inputs:
|
906
967
|
joint_accelerations=jnp.atleast_1d(jnp.zeros(model.dofs())),
|
907
968
|
base_acceleration=jnp.zeros(6),
|
908
|
-
|
969
|
+
link_forces=jnp.zeros(shape=(model.number_of_links(), 6)),
|
909
970
|
)
|
910
971
|
).astype(float)
|
911
972
|
|
@@ -948,18 +1009,20 @@ def free_floating_bias_forces(
|
|
948
1009
|
data.state.physics_model.joint_positions
|
949
1010
|
)
|
950
1011
|
|
951
|
-
data_rnea.state.physics_model.base_linear_velocity = (
|
952
|
-
data.state.physics_model.base_linear_velocity
|
953
|
-
)
|
954
|
-
|
955
|
-
data_rnea.state.physics_model.base_angular_velocity = (
|
956
|
-
data.state.physics_model.base_angular_velocity
|
957
|
-
)
|
958
|
-
|
959
1012
|
data_rnea.state.physics_model.joint_velocities = (
|
960
1013
|
data.state.physics_model.joint_velocities
|
961
1014
|
)
|
962
1015
|
|
1016
|
+
# Make sure that base velocity is zero for fixed-base model.
|
1017
|
+
if model.floating_base():
|
1018
|
+
data_rnea.state.physics_model.base_linear_velocity = (
|
1019
|
+
data.state.physics_model.base_linear_velocity
|
1020
|
+
)
|
1021
|
+
|
1022
|
+
data_rnea.state.physics_model.base_angular_velocity = (
|
1023
|
+
data.state.physics_model.base_angular_velocity
|
1024
|
+
)
|
1025
|
+
|
963
1026
|
return jnp.hstack(
|
964
1027
|
inverse_dynamics(
|
965
1028
|
model=model,
|
@@ -967,7 +1030,7 @@ def free_floating_bias_forces(
|
|
967
1030
|
# Set zero inputs:
|
968
1031
|
joint_accelerations=jnp.atleast_1d(jnp.zeros(model.dofs())),
|
969
1032
|
base_acceleration=jnp.zeros(6),
|
970
|
-
|
1033
|
+
link_forces=jnp.zeros(shape=(model.number_of_links(), 6)),
|
971
1034
|
)
|
972
1035
|
).astype(float)
|
973
1036
|
|
@@ -977,6 +1040,24 @@ def free_floating_bias_forces(
|
|
977
1040
|
# ==========================
|
978
1041
|
|
979
1042
|
|
1043
|
+
@jax.jit
|
1044
|
+
def locked_spatial_inertia(
|
1045
|
+
model: JaxSimModel, data: js.data.JaxSimModelData
|
1046
|
+
) -> jtp.Matrix:
|
1047
|
+
"""
|
1048
|
+
Compute the locked 6D inertia matrix of the model.
|
1049
|
+
|
1050
|
+
Args:
|
1051
|
+
model: The model to consider.
|
1052
|
+
data: The data of the considered model.
|
1053
|
+
|
1054
|
+
Returns:
|
1055
|
+
The locked 6D inertia matrix of the model.
|
1056
|
+
"""
|
1057
|
+
|
1058
|
+
return total_momentum_jacobian(model=model, data=data)[:, 0:6]
|
1059
|
+
|
1060
|
+
|
980
1061
|
@jax.jit
|
981
1062
|
def total_momentum(model: JaxSimModel, data: js.data.JaxSimModelData) -> jtp.Vector:
|
982
1063
|
"""
|
@@ -987,34 +1068,221 @@ def total_momentum(model: JaxSimModel, data: js.data.JaxSimModelData) -> jtp.Vec
|
|
987
1068
|
data: The data of the considered model.
|
988
1069
|
|
989
1070
|
Returns:
|
990
|
-
The total momentum of the model.
|
1071
|
+
The total momentum of the model in the active velocity representation.
|
991
1072
|
"""
|
992
1073
|
|
993
|
-
|
994
|
-
|
995
|
-
# floating-base momentum.
|
996
|
-
with data.switch_velocity_representation(velocity_representation=VelRepr.Body):
|
997
|
-
B_ν = data.generalized_velocity()
|
998
|
-
M_B = free_floating_mass_matrix(model=model, data=data)
|
1074
|
+
ν = data.generalized_velocity()
|
1075
|
+
Jh = total_momentum_jacobian(model=model, data=data)
|
999
1076
|
|
1000
|
-
|
1001
|
-
B_h = M_B[0:6, :] @ B_ν
|
1077
|
+
return Jh @ ν
|
1002
1078
|
|
1003
|
-
# Compute the 6D transformation matrix
|
1004
|
-
W_H_B = data.base_transform()
|
1005
|
-
B_X_W: jtp.Array = jaxlie.SE3.from_matrix(W_H_B).inverse().adjoint()
|
1006
1079
|
|
1007
|
-
|
1008
|
-
|
1009
|
-
|
1080
|
+
@functools.partial(jax.jit, static_argnames=["output_vel_repr"])
|
1081
|
+
def total_momentum_jacobian(
|
1082
|
+
model: JaxSimModel,
|
1083
|
+
data: js.data.JaxSimModelData,
|
1084
|
+
*,
|
1085
|
+
output_vel_repr: VelRepr | None = None,
|
1086
|
+
) -> jtp.Matrix:
|
1087
|
+
"""
|
1088
|
+
Compute the jacobian of the total momentum.
|
1010
1089
|
|
1011
|
-
|
1012
|
-
|
1013
|
-
|
1014
|
-
|
1015
|
-
|
1016
|
-
|
1017
|
-
|
1090
|
+
Args:
|
1091
|
+
model: The model to consider.
|
1092
|
+
data: The data of the considered model.
|
1093
|
+
output_vel_repr: The output velocity representation of the jacobian.
|
1094
|
+
|
1095
|
+
Returns:
|
1096
|
+
The jacobian of the total momentum of the model in the active representation.
|
1097
|
+
"""
|
1098
|
+
|
1099
|
+
output_vel_repr = (
|
1100
|
+
output_vel_repr if output_vel_repr is not None else data.velocity_representation
|
1101
|
+
)
|
1102
|
+
|
1103
|
+
if output_vel_repr is data.velocity_representation:
|
1104
|
+
return free_floating_mass_matrix(model=model, data=data)[0:6]
|
1105
|
+
|
1106
|
+
with data.switch_velocity_representation(VelRepr.Body):
|
1107
|
+
B_Jh_B = free_floating_mass_matrix(model=model, data=data)[0:6]
|
1108
|
+
|
1109
|
+
match data.velocity_representation:
|
1110
|
+
case VelRepr.Body:
|
1111
|
+
B_Jh = B_Jh_B
|
1112
|
+
|
1113
|
+
case VelRepr.Inertial:
|
1114
|
+
B_X_W = jaxlie.SE3.from_matrix(data.base_transform()).inverse().adjoint()
|
1115
|
+
B_Jh = B_Jh_B @ jax.scipy.linalg.block_diag(B_X_W, jnp.eye(model.dofs()))
|
1116
|
+
|
1117
|
+
case VelRepr.Mixed:
|
1118
|
+
BW_H_B = data.base_transform().at[0:3, 3].set(jnp.zeros(3))
|
1119
|
+
B_X_BW = jaxlie.SE3.from_matrix(BW_H_B).inverse().adjoint()
|
1120
|
+
B_Jh = B_Jh_B @ jax.scipy.linalg.block_diag(B_X_BW, jnp.eye(model.dofs()))
|
1121
|
+
|
1122
|
+
case _:
|
1123
|
+
raise ValueError(data.velocity_representation)
|
1124
|
+
|
1125
|
+
match output_vel_repr:
|
1126
|
+
case VelRepr.Body:
|
1127
|
+
return B_Jh
|
1128
|
+
|
1129
|
+
case VelRepr.Inertial:
|
1130
|
+
W_H_B = data.base_transform()
|
1131
|
+
B_Xv_W = jaxlie.SE3.from_matrix(W_H_B).inverse().adjoint()
|
1132
|
+
W_Xf_B = B_Xv_W.T
|
1133
|
+
W_Jh = W_Xf_B @ B_Jh
|
1134
|
+
return W_Jh
|
1135
|
+
|
1136
|
+
case VelRepr.Mixed:
|
1137
|
+
BW_H_B = data.base_transform().at[0:3, 3].set(jnp.zeros(3))
|
1138
|
+
B_Xv_BW = jaxlie.SE3.from_matrix(BW_H_B).inverse().adjoint()
|
1139
|
+
BW_Xf_B = B_Xv_BW.T
|
1140
|
+
BW_Jh = BW_Xf_B @ B_Jh
|
1141
|
+
return BW_Jh
|
1142
|
+
|
1143
|
+
case _:
|
1144
|
+
raise ValueError(output_vel_repr)
|
1145
|
+
|
1146
|
+
|
1147
|
+
@jax.jit
|
1148
|
+
def average_velocity(model: JaxSimModel, data: js.data.JaxSimModelData) -> jtp.Vector:
|
1149
|
+
"""
|
1150
|
+
Compute the average velocity of the model.
|
1151
|
+
|
1152
|
+
Args:
|
1153
|
+
model: The model to consider.
|
1154
|
+
data: The data of the considered model.
|
1155
|
+
|
1156
|
+
Returns:
|
1157
|
+
The average velocity of the model computed in the base frame and expressed
|
1158
|
+
in the active representation.
|
1159
|
+
"""
|
1160
|
+
|
1161
|
+
ν = data.generalized_velocity()
|
1162
|
+
J = average_velocity_jacobian(model=model, data=data)
|
1163
|
+
|
1164
|
+
return J @ ν
|
1165
|
+
|
1166
|
+
|
1167
|
+
@functools.partial(jax.jit, static_argnames=["output_vel_repr"])
|
1168
|
+
def average_velocity_jacobian(
|
1169
|
+
model: JaxSimModel,
|
1170
|
+
data: js.data.JaxSimModelData,
|
1171
|
+
*,
|
1172
|
+
output_vel_repr: VelRepr | None = None,
|
1173
|
+
) -> jtp.Matrix:
|
1174
|
+
"""
|
1175
|
+
Compute the Jacobian of the average velocity of the model.
|
1176
|
+
|
1177
|
+
Args:
|
1178
|
+
model: The model to consider.
|
1179
|
+
data: The data of the considered model.
|
1180
|
+
output_vel_repr: The output velocity representation of the jacobian.
|
1181
|
+
|
1182
|
+
Returns:
|
1183
|
+
The Jacobian of the average centroidal velocity of the model in the desired
|
1184
|
+
representation.
|
1185
|
+
"""
|
1186
|
+
|
1187
|
+
output_vel_repr = (
|
1188
|
+
output_vel_repr if output_vel_repr is not None else data.velocity_representation
|
1189
|
+
)
|
1190
|
+
|
1191
|
+
# Depending on the velocity representation, the frame G is either G[W] or G[B].
|
1192
|
+
G_J = js.com.average_centroidal_velocity_jacobian(model=model, data=data)
|
1193
|
+
|
1194
|
+
match output_vel_repr:
|
1195
|
+
|
1196
|
+
case VelRepr.Inertial:
|
1197
|
+
|
1198
|
+
GW_J = G_J
|
1199
|
+
W_p_CoM = js.com.com_position(model=model, data=data)
|
1200
|
+
|
1201
|
+
W_H_GW = jnp.eye(4).at[0:3, 3].set(W_p_CoM)
|
1202
|
+
W_X_GW = jaxlie.SE3.from_matrix(W_H_GW).adjoint()
|
1203
|
+
|
1204
|
+
return W_X_GW @ GW_J
|
1205
|
+
|
1206
|
+
case VelRepr.Body:
|
1207
|
+
|
1208
|
+
GB_J = G_J
|
1209
|
+
W_p_B = data.base_position()
|
1210
|
+
W_p_CoM = js.com.com_position(model=model, data=data)
|
1211
|
+
B_R_W = data.base_orientation(dcm=True).transpose()
|
1212
|
+
|
1213
|
+
B_H_GB = jnp.eye(4).at[0:3, 3].set(B_R_W @ (W_p_CoM - W_p_B))
|
1214
|
+
B_X_GB = jaxlie.SE3.from_matrix(B_H_GB).adjoint()
|
1215
|
+
|
1216
|
+
return B_X_GB @ GB_J
|
1217
|
+
|
1218
|
+
case VelRepr.Mixed:
|
1219
|
+
|
1220
|
+
GW_J = G_J
|
1221
|
+
W_p_B = data.base_position()
|
1222
|
+
W_p_CoM = js.com.com_position(model=model, data=data)
|
1223
|
+
|
1224
|
+
BW_H_GW = jnp.eye(4).at[0:3, 3].set(W_p_CoM - W_p_B)
|
1225
|
+
BW_X_GW = jaxlie.SE3.from_matrix(BW_H_GW).adjoint()
|
1226
|
+
|
1227
|
+
return BW_X_GW @ GW_J
|
1228
|
+
|
1229
|
+
|
1230
|
+
# ========================
|
1231
|
+
# Other dynamic quantities
|
1232
|
+
# ========================
|
1233
|
+
|
1234
|
+
|
1235
|
+
@jax.jit
|
1236
|
+
def link_contact_forces(
|
1237
|
+
model: js.model.JaxSimModel, data: js.data.JaxSimModelData
|
1238
|
+
) -> jtp.Matrix:
|
1239
|
+
"""
|
1240
|
+
Compute the 6D contact forces of all links of the model.
|
1241
|
+
|
1242
|
+
Args:
|
1243
|
+
model: The model to consider.
|
1244
|
+
data: The data of the considered model.
|
1245
|
+
|
1246
|
+
Returns:
|
1247
|
+
A (nL, 6) array containing the stacked 6D contact forces of the links,
|
1248
|
+
expressed in the frame corresponding to the active representation.
|
1249
|
+
"""
|
1250
|
+
|
1251
|
+
# Compute the 6D forces applied to each collidable point expressed in the
|
1252
|
+
# inertial frame.
|
1253
|
+
with data.switch_velocity_representation(VelRepr.Inertial):
|
1254
|
+
W_f_Ci = js.contact.collidable_point_forces(model=model, data=data)
|
1255
|
+
|
1256
|
+
# Construct the vector defining the parent link index of each collidable point.
|
1257
|
+
# We use this vector to sum the 6D forces of all collidable points rigidly
|
1258
|
+
# attached to the same link.
|
1259
|
+
parent_link_index_of_collidable_points = jnp.array(
|
1260
|
+
model.kin_dyn_parameters.contact_parameters.body, dtype=int
|
1261
|
+
)
|
1262
|
+
|
1263
|
+
# Sum the forces of all collidable points rigidly attached to a body.
|
1264
|
+
# Since the contact forces W_f_Ci are expressed in the world frame,
|
1265
|
+
# we don't need any coordinate transformation.
|
1266
|
+
W_f_Li = jax.vmap(
|
1267
|
+
lambda nc: (
|
1268
|
+
jnp.vstack(
|
1269
|
+
jnp.equal(parent_link_index_of_collidable_points, nc).astype(int)
|
1270
|
+
)
|
1271
|
+
* W_f_Ci
|
1272
|
+
).sum(axis=0)
|
1273
|
+
)(jnp.arange(model.number_of_links()))
|
1274
|
+
|
1275
|
+
# Convert the 6D forces to the active representation.
|
1276
|
+
f_Li = jax.vmap(
|
1277
|
+
lambda W_f_L: data.inertial_to_other_representation(
|
1278
|
+
array=W_f_L,
|
1279
|
+
other_representation=data.velocity_representation,
|
1280
|
+
transform=data.base_transform(),
|
1281
|
+
is_force=True,
|
1282
|
+
)
|
1283
|
+
)(W_f_Li)
|
1284
|
+
|
1285
|
+
return f_Li
|
1018
1286
|
|
1019
1287
|
|
1020
1288
|
# ======
|
@@ -1077,7 +1345,7 @@ def potential_energy(model: JaxSimModel, data: js.data.JaxSimModelData) -> jtp.F
|
|
1077
1345
|
|
1078
1346
|
m = total_mass(model=model)
|
1079
1347
|
gravity = data.gravity.squeeze()
|
1080
|
-
W_p̃_CoM = jnp.hstack([com_position(model=model, data=data), 1])
|
1348
|
+
W_p̃_CoM = jnp.hstack([js.com.com_position(model=model, data=data), 1])
|
1081
1349
|
|
1082
1350
|
U = -jnp.hstack([gravity, 0]) @ (m * W_p̃_CoM)
|
1083
1351
|
return U.squeeze().astype(float)
|
@@ -1097,7 +1365,8 @@ def step(
|
|
1097
1365
|
integrator: jaxsim.integrators.Integrator,
|
1098
1366
|
integrator_state: dict[str, Any] | None = None,
|
1099
1367
|
joint_forces: jtp.VectorLike | None = None,
|
1100
|
-
|
1368
|
+
link_forces: jtp.MatrixLike | None = None,
|
1369
|
+
**kwargs,
|
1101
1370
|
) -> tuple[js.data.JaxSimModelData, dict[str, Any]]:
|
1102
1371
|
"""
|
1103
1372
|
Perform a simulation step.
|
@@ -1109,15 +1378,17 @@ def step(
|
|
1109
1378
|
integrator: The integrator to use.
|
1110
1379
|
integrator_state: The state of the integrator.
|
1111
1380
|
joint_forces: The joint forces to consider.
|
1112
|
-
|
1113
|
-
The
|
1381
|
+
link_forces:
|
1382
|
+
The link 6D forces to consider.
|
1114
1383
|
The frame in which they are expressed must be `data.velocity_representation`.
|
1384
|
+
kwargs: Additional kwargs to pass to the integrator.
|
1115
1385
|
|
1116
1386
|
Returns:
|
1117
1387
|
A tuple containing the new data of the model
|
1118
1388
|
and the new state of the integrator.
|
1119
1389
|
"""
|
1120
1390
|
|
1391
|
+
integrator_kwargs = kwargs if kwargs is not None else dict()
|
1121
1392
|
integrator_state = integrator_state if integrator_state is not None else dict()
|
1122
1393
|
|
1123
1394
|
# Extract the initial resources.
|
@@ -1128,10 +1399,12 @@ def step(
|
|
1128
1399
|
# Step the dynamics forward.
|
1129
1400
|
state_xf, integrator_state_xf = integrator.step(
|
1130
1401
|
x0=state_x0,
|
1131
|
-
t0=jnp.array(t0_ns
|
1402
|
+
t0=jnp.array(t0_ns / 1e9).astype(float),
|
1132
1403
|
dt=dt,
|
1133
1404
|
params=integrator_state_x0,
|
1134
|
-
**
|
1405
|
+
**(
|
1406
|
+
dict(joint_forces=joint_forces, link_forces=link_forces) | integrator_kwargs
|
1407
|
+
),
|
1135
1408
|
)
|
1136
1409
|
|
1137
1410
|
return (
|