jaxsim 0.2.dev101__py3-none-any.whl → 0.2.dev166__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jaxsim/_version.py +2 -2
- jaxsim/api/__init__.py +1 -0
- jaxsim/api/contact.py +194 -0
- jaxsim/api/data.py +951 -0
- jaxsim/api/joint.py +148 -0
- jaxsim/api/link.py +262 -0
- jaxsim/api/model.py +1099 -0
- jaxsim/api/ode.py +280 -0
- jaxsim/integrators/__init__.py +2 -0
- jaxsim/integrators/common.py +508 -0
- jaxsim/integrators/fixed_step.py +158 -0
- jaxsim/mujoco/__init__.py +1 -1
- jaxsim/mujoco/loaders.py +30 -18
- jaxsim/mujoco/visualizer.py +3 -1
- jaxsim/physics/algos/soft_contacts.py +97 -28
- jaxsim/physics/model/physics_model.py +30 -0
- jaxsim/physics/model/physics_model_state.py +110 -11
- jaxsim/simulation/ode_data.py +43 -0
- {jaxsim-0.2.dev101.dist-info → jaxsim-0.2.dev166.dist-info}/METADATA +2 -1
- {jaxsim-0.2.dev101.dist-info → jaxsim-0.2.dev166.dist-info}/RECORD +23 -13
- {jaxsim-0.2.dev101.dist-info → jaxsim-0.2.dev166.dist-info}/LICENSE +0 -0
- {jaxsim-0.2.dev101.dist-info → jaxsim-0.2.dev166.dist-info}/WHEEL +0 -0
- {jaxsim-0.2.dev101.dist-info → jaxsim-0.2.dev166.dist-info}/top_level.txt +0 -0
jaxsim/api/model.py
ADDED
@@ -0,0 +1,1099 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
3
|
+
import dataclasses
|
4
|
+
import functools
|
5
|
+
import pathlib
|
6
|
+
from typing import Any
|
7
|
+
|
8
|
+
import jax
|
9
|
+
import jax.numpy as jnp
|
10
|
+
import jax_dataclasses
|
11
|
+
import jaxlie
|
12
|
+
import rod
|
13
|
+
from jax_dataclasses import Static
|
14
|
+
|
15
|
+
import jaxsim.api as js
|
16
|
+
import jaxsim.physics.algos.aba
|
17
|
+
import jaxsim.physics.algos.crba
|
18
|
+
import jaxsim.physics.algos.forward_kinematics
|
19
|
+
import jaxsim.physics.algos.rnea
|
20
|
+
import jaxsim.physics.model.physics_model
|
21
|
+
import jaxsim.typing as jtp
|
22
|
+
from jaxsim.high_level.common import VelRepr
|
23
|
+
from jaxsim.physics.algos.terrain import FlatTerrain, Terrain
|
24
|
+
from jaxsim.utils import JaxsimDataclass, Mutability
|
25
|
+
|
26
|
+
|
27
|
+
@jax_dataclasses.pytree_dataclass
class JaxSimModel(JaxsimDataclass):
    """
    Robot model used by JaxSim, describing both its kinematics and its dynamics.

    Instances can be created either with `build_from_model_description` (from an
    SDF/URDF resource or a rod model) or with `build` (from a physics model).
    """

    # Name of the model.
    model_name: Static[str]

    # Physics model holding the kinematic/dynamic description of the robot.
    physics_model: jaxsim.physics.model.physics_model.PhysicsModel = dataclasses.field(
        repr=False
    )

    # Terrain model (a flat terrain unless specified otherwise).
    terrain: Static[Terrain] = dataclasses.field(default=FlatTerrain(), repr=False)

    # The resource the model was built from, when known.
    built_from: Static[str | pathlib.Path | rod.Model | None] = dataclasses.field(
        repr=False, default=None
    )

    # Cached number of links, filled by `__post_init__`.
    _number_of_links: Static[int] = dataclasses.field(
        init=False, repr=False, default=None
    )

    # Cached number of joints, filled by `__post_init__`.
    _number_of_joints: Static[int] = dataclasses.field(
        init=False, repr=False, default=None
    )

    def __post_init__(self):

        # The counts are stored as Static attributes so that they can be used to
        # size `jax.vmap` and `jax.lax.scan` loops over all links and joints.
        description = self.physics_model.description

        with self.mutable_context(
            mutability=Mutability.MUTABLE_NO_VALIDATION,
            restore_after_exception=False,
        ):
            self._number_of_links = len(description.links_dict)
            self._number_of_joints = len(description.joints_dict)

    # ========================
    # Initialization and state
    # ========================

    @staticmethod
    def build_from_model_description(
        model_description: str | pathlib.Path | rod.Model,
        model_name: str | None = None,
        gravity: jtp.Array = jaxsim.physics.default_gravity(),
        is_urdf: bool | None = None,
        considered_joints: list[str] | None = None,
    ) -> JaxSimModel:
        """
        Create a model starting from a model description.

        Args:
            model_description:
                The path to an SDF/URDF file, a string with its content, or an
                already parsed/built rod model.
            model_name:
                Optional name overriding the one found in the description.
            gravity: The 3D gravity vector.
            is_urdf:
                Whether the description is a URDF (True) or an SDF (False). It is
                inferred automatically when the description is a path to a file.
            considered_joints:
                The joints to keep in the model. When None, all joints are kept.

        Returns:
            The resulting JaxSimModel.
        """

        import jaxsim.parsers.rod

        # Turn the input resource into the intermediate model description.
        description = jaxsim.parsers.rod.build_model_description(
            model_description=model_description, is_urdf=is_urdf
        )

        # When only a subset of joints is kept, lump together the links connected
        # by the removed joints (which are assigned a zero position).
        if considered_joints is not None:
            description = description.reduce(considered_joints=considered_joints)

        physics_model = jaxsim.physics.model.physics_model.PhysicsModel.build_from(
            model_description=description, gravity=gravity
        )

        model = JaxSimModel.build(physics_model=physics_model, model_name=model_name)

        # Remember the resource the model originates from.
        with model.mutable_context(mutability=Mutability.MUTABLE_NO_VALIDATION):
            model.built_from = model_description

        return model

    @staticmethod
    def build(
        physics_model: jaxsim.physics.model.physics_model.PhysicsModel,
        model_name: str | None = None,
    ) -> JaxSimModel:
        """
        Create a model starting from a physics model.

        Args:
            physics_model: The physics model.
            model_name:
                Optional name overriding the one of the physics model description.

        Returns:
            The resulting JaxSimModel.
        """

        # Without an explicit name, reuse the one stored in the model description.
        if model_name is None:
            model_name = physics_model.description.name

        return JaxSimModel(physics_model=physics_model, model_name=model_name)  # noqa

    # ==========
    # Properties
    # ==========

    def name(self) -> str:
        """
        Get the name of the model.

        Returns:
            The model name.
        """

        name = self.model_name
        return name

    def number_of_links(self) -> jtp.Int:
        """
        Get the number of links of the model.

        Returns:
            The number of links of the model.

        Note:
            The base link is included in the count and its index is always 0.
        """

        count = self._number_of_links
        return count

    def number_of_joints(self) -> jtp.Int:
        """
        Get the number of joints of the model.

        Returns:
            The number of joints of the model.
        """

        count = self._number_of_joints
        return count

    # =================
    # Base link methods
    # =================

    def floating_base(self) -> bool:
        """
        Tell whether the model is floating-base.

        Returns:
            True for a floating-base model, False otherwise.
        """

        is_floating = self.physics_model.is_floating_base
        return is_floating

    def base_link(self) -> str:
        """
        Get the name of the base link.

        Returns:
            The base link name.
        """

        root = self.physics_model.description.root
        return root.name

    # =====================
    # Joint-related methods
    # =====================

    def dofs(self) -> int:
        """
        Get the number of degrees of freedom of the model.

        Returns:
            The number of degrees of freedom of the model.

        Note:
            Multi-DoF joints are not yet supported, therefore this always matches
            the number of joints. This could change in the future.
        """

        joints = self.physics_model.description.joints_dict
        return len(joints)

    def joint_names(self) -> tuple[str, ...]:
        """
        Get the names of the joints of the model.

        Returns:
            The joint names, in the order of the joints dictionary.
        """

        joints = self.physics_model.description.joints_dict
        return tuple(joints.keys())

    # ====================
    # Link-related methods
    # ====================

    def link_names(self) -> tuple[str, ...]:
        """
        Get the names of the links of the model.

        Returns:
            The link names, in the order of the links dictionary.
        """

        links = self.physics_model.description.links_dict
        return tuple(links.keys())
|
254
|
+
|
255
|
+
|
256
|
+
# =====================
|
257
|
+
# Model post-processing
|
258
|
+
# =====================
|
259
|
+
|
260
|
+
|
261
|
+
def reduce(model: JaxSimModel, considered_joints: tuple[str, ...]) -> JaxSimModel:
    """
    Reduce the model by lumping together the links connected by removed joints.

    Args:
        model: The model to reduce.
        considered_joints: The sequence of joints to consider.

    Returns:
        The reduced model. Its name, terrain, and `built_from` resource are taken
        from the input model.

    Note:
        If considered_joints contains joints not existing in the model, the method
        will raise an exception. If considered_joints is empty, the method will
        return a copy of the input model.
    """

    if len(considered_joints) == 0:
        return model.copy()

    # Reduce the model description.
    # If considered_joints contains joints not existing in the model, the method
    # will raise an exception.
    reduced_intermediate_description = model.physics_model.description.reduce(
        considered_joints=list(considered_joints)
    )

    # Create the physics model from the reduced model description, reusing the
    # first three entries of the gravity stored in the original physics model.
    physics_model = jaxsim.physics.model.physics_model.PhysicsModel.build_from(
        model_description=reduced_intermediate_description,
        gravity=model.physics_model.gravity[0:3],
    )

    # Build the reduced model
    reduced_model = JaxSimModel.build(
        physics_model=physics_model, model_name=model.name()
    )

    # `JaxSimModel.build` only receives the physics model and the name. Without
    # the following, the reduced model would fall back to the default terrain and
    # lose the reference to the original resource. Propagate both from the input
    # model, consistently with the branch above that returns a full copy.
    with reduced_model.mutable_context(mutability=Mutability.MUTABLE_NO_VALIDATION):
        reduced_model.terrain = model.terrain
        reduced_model.built_from = model.built_from

    return reduced_model
|
297
|
+
|
298
|
+
|
299
|
+
# ===================
|
300
|
+
# Inertial properties
|
301
|
+
# ===================
|
302
|
+
|
303
|
+
|
304
|
+
@jax.jit
def total_mass(model: JaxSimModel) -> jtp.Float:
    """
    Sum the masses of all the links of the model.

    Args:
        model: The model to consider.

    Returns:
        The total mass of the model.
    """

    def mass_of_link(link_index):
        return js.link.mass(model=model, link_index=link_index)

    # Vectorize the per-link mass over all link indices.
    link_masses = jax.vmap(mass_of_link)(jnp.arange(model.number_of_links()))

    return link_masses.sum().astype(float)
|
323
|
+
|
324
|
+
|
325
|
+
# ==============
|
326
|
+
# Center of mass
|
327
|
+
# ==============
|
328
|
+
|
329
|
+
|
330
|
+
@jax.jit
def com_position(model: JaxSimModel, data: js.data.JaxSimModelData) -> jtp.Vector:
    """
    Compute the position of the center of mass (CoM) of the whole model.

    The CoM is computed as the mass-weighted average of the link CoMs, first
    expressed in the base frame and then transformed to the world frame.

    Args:
        model: The model to consider.
        data: The data of the considered model.

    Returns:
        The position of the center of mass of the model w.r.t. the world frame.
    """

    # Poses of all links and of the base, all w.r.t. the world frame.
    W_H_L = forward_kinematics(model=model, data=data)
    W_H_B = data.base_transform()
    B_H_W = jaxlie.SE3.from_matrix(W_H_B).inverse().as_matrix()

    def weighted_link_com_in_base(link_index) -> jtp.Vector:
        # Homogeneous CoM position of the link, expressed in the base frame and
        # multiplied by the link mass.
        link_mass = js.link.mass(model=model, link_index=link_index)
        L_p_LCoM = js.link.com_position(
            model=model, data=data, link_index=link_index, in_link_frame=True
        )
        L_p_LCoM_homogeneous = jnp.hstack([L_p_LCoM, 1])
        return link_mass * B_H_W @ W_H_L[link_index] @ L_p_LCoM_homogeneous

    weighted_coms = jax.vmap(weighted_link_com_in_base)(
        jnp.arange(model.number_of_links())
    )

    model_mass = total_mass(model=model)

    # Mass-weighted average, then restore the homogeneous coordinate.
    B_p_CoM_homogeneous = (1 / model_mass) * weighted_coms.sum(axis=0)
    B_p_CoM_homogeneous = B_p_CoM_homogeneous.at[3].set(1)

    W_p_CoM_homogeneous = W_H_B @ B_p_CoM_homogeneous
    return W_p_CoM_homogeneous[0:3].astype(float)
|
362
|
+
|
363
|
+
|
364
|
+
# ==============================
|
365
|
+
# Rigid Body Dynamics Algorithms
|
366
|
+
# ==============================
|
367
|
+
|
368
|
+
|
369
|
+
@jax.jit
def forward_kinematics(model: JaxSimModel, data: js.data.JaxSimModelData) -> jtp.Array:
    """
    Compute the world-to-link SE(3) transforms of all the links of the model.

    Args:
        model: The model to consider.
        data: The data of the considered model.

    Returns:
        A (nL, 4, 4) array with the stacked SE(3) transforms of the links, indexed
        by link along the first axis.
    """

    pm_state = data.state.physics_model

    W_H_LL = jaxsim.physics.algos.forward_kinematics.forward_kinematics_model(
        model=model.physics_model,
        q=pm_state.joint_positions,
        xfb=pm_state.xfb(),
    )

    W_H_LL = jnp.atleast_3d(W_H_LL)
    return W_H_LL.astype(float)
|
390
|
+
|
391
|
+
|
392
|
+
@jax.jit
|
393
|
+
def generalized_free_floating_jacobian(
|
394
|
+
model: JaxSimModel,
|
395
|
+
data: js.data.JaxSimModelData,
|
396
|
+
*,
|
397
|
+
output_vel_repr: VelRepr | None = None,
|
398
|
+
) -> jtp.Matrix:
|
399
|
+
"""
|
400
|
+
Compute the free-floating jacobians of all links.
|
401
|
+
|
402
|
+
Args:
|
403
|
+
model: The model to consider.
|
404
|
+
data: The data of the considered model.
|
405
|
+
output_vel_repr:
|
406
|
+
The output velocity representation of the free-floating jacobians.
|
407
|
+
|
408
|
+
Returns:
|
409
|
+
The (nL, 6, 6+dofs) array containing the stacked free-floating
|
410
|
+
jacobians of the links. The first axis is the link index.
|
411
|
+
"""
|
412
|
+
|
413
|
+
if output_vel_repr is None:
|
414
|
+
output_vel_repr = data.velocity_representation
|
415
|
+
|
416
|
+
# The body frame of the Link.jacobian method is the link frame L.
|
417
|
+
# In this method, we want instead to use the base link B as body frame.
|
418
|
+
# Therefore, we always get the link jacobian having Inertial as output
|
419
|
+
# representation, and then we convert it to the desired output representation.
|
420
|
+
match output_vel_repr:
|
421
|
+
case VelRepr.Inertial:
|
422
|
+
to_output = lambda J: J
|
423
|
+
|
424
|
+
case VelRepr.Body:
|
425
|
+
|
426
|
+
def to_output(W_J_Wi):
|
427
|
+
W_H_B = data.base_transform()
|
428
|
+
B_X_W = jaxlie.SE3.from_matrix(W_H_B).inverse().adjoint()
|
429
|
+
return B_X_W @ W_J_Wi
|
430
|
+
|
431
|
+
case VelRepr.Mixed:
|
432
|
+
|
433
|
+
def to_output(W_J_Wi):
|
434
|
+
W_H_B = data.base_transform()
|
435
|
+
W_H_BW = jnp.array(W_H_B).at[0:3, 0:3].set(jnp.eye(3))
|
436
|
+
BW_X_W = jaxlie.SE3.from_matrix(W_H_BW).inverse().adjoint()
|
437
|
+
return BW_X_W @ W_J_Wi
|
438
|
+
|
439
|
+
case _:
|
440
|
+
raise ValueError(output_vel_repr)
|
441
|
+
|
442
|
+
# Get the link jacobians in Inertial representation and convert them to the
|
443
|
+
# target output representation in which the body frame is the base link B
|
444
|
+
J_free_floating = jax.vmap(
|
445
|
+
lambda i: to_output(js.link.jacobian(model=model, data=data, link_index=i))
|
446
|
+
)(jnp.arange(model.number_of_links()))
|
447
|
+
|
448
|
+
return J_free_floating
|
449
|
+
|
450
|
+
|
451
|
+
@functools.partial(jax.jit, static_argnames=["prefer_aba"])
def forward_dynamics(
    model: JaxSimModel,
    data: js.data.JaxSimModelData,
    *,
    joint_forces: jtp.VectorLike | None = None,
    external_forces: jtp.MatrixLike | None = None,
    prefer_aba: bool = True,
) -> tuple[jtp.Vector, jtp.Vector]:
    """
    Compute the forward dynamics of the model.

    Args:
        model: The model to consider.
        data: The data of the considered model.
        joint_forces:
            The joint forces to consider as a vector of shape `(dofs,)`.
        external_forces:
            The external forces to consider as a matrix of shape `(nL, 6)`.
        prefer_aba:
            Whether to use the ABA algorithm (`forward_dynamics_aba`) instead of
            the CRB-based one (`forward_dynamics_crb`). This is a static argument
            of the jitted function, therefore changing it triggers a new
            compilation.

    Returns:
        A tuple containing the 6D acceleration in the active representation of the
        base link and the joint accelerations resulting from the application of the
        considered joint forces and external forces.
    """

    forward_dynamics_fn = forward_dynamics_aba if prefer_aba else forward_dynamics_crb

    return forward_dynamics_fn(
        model=model,
        data=data,
        joint_forces=joint_forces,
        external_forces=external_forces,
    )
|
486
|
+
|
487
|
+
|
488
|
+
@jax.jit
def forward_dynamics_aba(
    model: JaxSimModel,
    data: js.data.JaxSimModelData,
    *,
    joint_forces: jtp.VectorLike | None = None,
    external_forces: jtp.MatrixLike | None = None,
) -> tuple[jtp.Vector, jtp.Vector]:
    """
    Compute the forward dynamics of the model using the Articulated Body Algorithm.

    Args:
        model: The model to consider.
        data: The data of the considered model.
        joint_forces:
            The joint forces to consider as a vector of shape `(dofs,)`.
            Zero joint forces are used when not provided.
        external_forces:
            The external forces to consider as a matrix of shape `(nL, 6)`.
            Zero external forces are used when not provided.

    Returns:
        A tuple with the 6D base acceleration, expressed in the active velocity
        representation of `data`, and the joint accelerations produced by the
        given joint forces and external forces.
    """

    # Fall back to zero inputs when they are not provided.
    if joint_forces is None:
        joint_forces = jnp.zeros_like(data.joint_positions())

    if external_forces is None:
        external_forces = jnp.zeros((model.number_of_links(), 6))

    pm_state = data.state.physics_model

    # The base acceleration returned by ABA is handled as inertial-fixed (W_vd_WB).
    W_vd_WB, joint_accelerations = jaxsim.physics.algos.aba.aba(
        model=model.physics_model,
        xfb=pm_state.xfb(),
        q=pm_state.joint_positions,
        qd=pm_state.joint_velocities,
        tau=joint_forces,
        f_ext=external_forces,
    )

    active_repr = data.velocity_representation

    # Pose of the frame C associated with the active representation, and linear
    # velocity of its origin (the latter is only used in Mixed representation).
    if active_repr == VelRepr.Inertial:
        W_H_C = jnp.eye(4)
        W_vl_WC = jnp.zeros(3)
    elif active_repr == VelRepr.Body:
        W_H_C = data.base_transform()
        W_vl_WC = data.base_velocity()[0:3]
    elif active_repr == VelRepr.Mixed:
        # B[W]: origin of the base frame B, orientation of the world frame W.
        W_H_C = data.base_transform().at[0:3, 0:3].set(jnp.eye(3))
        W_vl_WC = data.base_velocity()[0:3]
    else:
        raise ValueError(active_repr)

    C_X_W = jaxlie.SE3.from_matrix(W_H_C).inverse().adjoint()
    W_vd_WB = W_vd_WB.squeeze()

    if active_repr == VelRepr.Mixed:
        # In Mixed representation the base acceleration is an apparent acceleration:
        # besides the change of frame, converting it involves a cross product in ℝ⁶.
        from jaxsim.math.cross import Cross

        W_v_WB = jnp.hstack(
            [pm_state.base_linear_velocity, pm_state.base_angular_velocity]
        )
        W_v_WC = jnp.hstack([W_vl_WC, jnp.zeros(3)])
        C_vd_WB = C_X_W @ (W_vd_WB - Cross.vx(W_v_WC) @ W_v_WB)
    else:
        C_vd_WB = C_X_W @ W_vd_WB

    return C_vd_WB, jnp.atleast_1d(joint_accelerations.squeeze())
|
584
|
+
|
585
|
+
|
586
|
+
@jax.jit
def forward_dynamics_crb(
    model: JaxSimModel,
    data: js.data.JaxSimModelData,
    *,
    joint_forces: jtp.VectorLike | None = None,
    external_forces: jtp.MatrixLike | None = None,
) -> tuple[jtp.Vector, jtp.Vector]:
    """
    Compute the forward dynamics of the model with the CRB algorithm.

    The floating-base equations of motion M ν̇ + h = S τ + Jᵀ f_ext are solved
    for ν̇ (only for the joint block when the model is fixed-base).

    Args:
        model: The model to consider.
        data: The data of the considered model.
        joint_forces:
            The joint forces to consider as a vector of shape `(dofs,)`.
        external_forces:
            The external forces to consider as a matrix of shape `(nL, 6)`.

    Returns:
        A tuple containing the 6D acceleration in the active representation of the
        base link and the joint accelerations resulting from the application of the
        considered joint forces and external forces.

    Note:
        Compared to ABA, this method could be significantly slower, especially for
        models with a large number of degrees of freedom.
    """

    # Build the inputs (zero when not provided) as flat vectors. jnp.asarray
    # also accepts the array-like inputs allowed by the type hints.
    tau = (
        jnp.asarray(joint_forces)
        if joint_forces is not None
        else jnp.zeros_like(data.joint_positions())
    )
    # Handle models with zero and one DoFs: tau has shape (dofs,).
    tau = jnp.atleast_1d(tau.squeeze())

    f_ext_matrix = (
        jnp.asarray(external_forces)
        if external_forces is not None
        else jnp.zeros(shape=(model.number_of_links(), 6))
    )
    # Shape (6 nL,), matching the rows of the stacked link jacobians.
    f_ext = f_ext_matrix.flatten()

    # Compute terms of the floating-base EoM
    M = free_floating_mass_matrix(model=model, data=data)
    h = jnp.ravel(free_floating_bias_forces(model=model, data=data))
    J_links = generalized_free_floating_jacobian(model=model, data=data)
    J = J_links.reshape(-1, J_links.shape[-1])  # (6 nL, 6 + dofs)
    S = jnp.block([jnp.zeros(shape=(model.dofs(), 6)), jnp.eye(model.dofs())]).T

    rhs = S @ tau - h + J.T @ f_ext

    # Solve the linear systems instead of forming explicit inverses
    # (cheaper and numerically more accurate).
    # TODO: invert the Mss block exploiting sparsity defined by the parent array λ(i)
    if model.floating_base():
        nu_dot = jnp.linalg.solve(M, rhs)
    else:
        s_ddot = jnp.linalg.solve(M[6:, 6:], rhs[6:])
        nu_dot = jnp.hstack([jnp.zeros(6), s_ddot])

    # Extract the base acceleration in the active representation.
    # Note that this is an apparent acceleration (relevant in Mixed representation),
    # therefore it cannot be always expressed in different frames with just a
    # 6D transformation X.
    W_vd_WB = nu_dot[0:6].astype(float)

    # Extract the joint accelerations
    joint_accelerations = jnp.atleast_1d(nu_dot[6:]).astype(float)

    return W_vd_WB, joint_accelerations
|
658
|
+
|
659
|
+
|
660
|
+
@jax.jit
|
661
|
+
def free_floating_mass_matrix(
|
662
|
+
model: JaxSimModel, data: js.data.JaxSimModelData
|
663
|
+
) -> jtp.Matrix:
|
664
|
+
"""
|
665
|
+
Compute the free-floating mass matrix of the model with the CRBA algorithm.
|
666
|
+
|
667
|
+
Args:
|
668
|
+
model: The model to consider.
|
669
|
+
data: The data of the considered model.
|
670
|
+
|
671
|
+
Returns:
|
672
|
+
The free-floating mass matrix of the model.
|
673
|
+
"""
|
674
|
+
|
675
|
+
M_body = jaxsim.physics.algos.crba.crba(
|
676
|
+
model=model.physics_model,
|
677
|
+
q=data.state.physics_model.joint_positions,
|
678
|
+
)
|
679
|
+
|
680
|
+
match data.velocity_representation:
|
681
|
+
case VelRepr.Body:
|
682
|
+
return M_body
|
683
|
+
|
684
|
+
case VelRepr.Inertial:
|
685
|
+
zero_6n = jnp.zeros(shape=(6, model.dofs()))
|
686
|
+
B_X_W = jaxlie.SE3.from_matrix(data.base_transform()).inverse().adjoint()
|
687
|
+
|
688
|
+
invT = jnp.vstack(
|
689
|
+
[
|
690
|
+
jnp.block([B_X_W, zero_6n]),
|
691
|
+
jnp.block([zero_6n.T, jnp.eye(model.dofs())]),
|
692
|
+
]
|
693
|
+
)
|
694
|
+
|
695
|
+
return invT.T @ M_body @ invT
|
696
|
+
|
697
|
+
case VelRepr.Mixed:
|
698
|
+
zero_6n = jnp.zeros(shape=(6, model.dofs()))
|
699
|
+
W_H_BW = data.base_transform().at[0:3, 3].set(jnp.zeros(3))
|
700
|
+
BW_X_W = jaxlie.SE3.from_matrix(W_H_BW).inverse().adjoint()
|
701
|
+
|
702
|
+
invT = jnp.vstack(
|
703
|
+
[
|
704
|
+
jnp.block([BW_X_W, zero_6n]),
|
705
|
+
jnp.block([zero_6n.T, jnp.eye(model.dofs())]),
|
706
|
+
]
|
707
|
+
)
|
708
|
+
|
709
|
+
return invT.T @ M_body @ invT
|
710
|
+
|
711
|
+
case _:
|
712
|
+
raise ValueError(data.velocity_representation)
|
713
|
+
|
714
|
+
|
715
|
+
@jax.jit
def inverse_dynamics(
    model: JaxSimModel,
    data: js.data.JaxSimModelData,
    *,
    joint_accelerations: jtp.Vector | None = None,
    base_acceleration: jtp.Vector | None = None,
    external_forces: jtp.Matrix | None = None,
) -> tuple[jtp.Vector, jtp.Vector]:
    """
    Compute the inverse dynamics of the model with the RNEA algorithm.

    Args:
        model: The model to consider.
        data: The data of the considered model.
        joint_accelerations:
            The joint accelerations to consider as a vector of shape `(dofs,)`.
            Zero accelerations are used when not provided.
        base_acceleration:
            The base acceleration to consider as a vector of shape `(6,)`, in the
            active velocity representation. Zero when not provided.
        external_forces:
            The external forces to consider as a matrix of shape `(nL, 6)`.
            Zero external forces are used when not provided.

    Returns:
        A tuple with the 6D base force, expressed in the active velocity
        representation, producing the given base acceleration, and the joint
        forces producing the given joint accelerations.
    """

    # Fall back to zero inputs when they are not provided.
    if joint_accelerations is None:
        joint_accelerations = jnp.zeros_like(data.joint_positions())

    if base_acceleration is None:
        base_acceleration = jnp.zeros(6)

    if external_forces is None:
        external_forces = jnp.zeros(shape=(model.number_of_links(), 6))

    active_repr = data.velocity_representation

    # Pose of the frame C associated with the active representation, and linear
    # velocity of its origin (the latter is only used in Mixed representation).
    if active_repr == VelRepr.Inertial:
        W_H_C = jnp.eye(4)
        W_vl_WC = jnp.zeros(3)
    elif active_repr == VelRepr.Body:
        W_H_C = data.base_transform()
        W_vl_WC = data.base_velocity()[0:3]
    elif active_repr == VelRepr.Mixed:
        # B[W]: origin of the base frame B, orientation of the world frame W.
        W_H_C = data.base_transform().at[0:3, 0:3].set(jnp.eye(3))
        W_vl_WC = data.base_velocity()[0:3]
    else:
        raise ValueError(active_repr)

    W_X_C = jaxlie.SE3.from_matrix(W_H_C).adjoint()

    # Convert the base acceleration to the Inertial representation. In Mixed
    # representation this is not a plain transformation with X: it also involves
    # a cross product in ℝ⁶.
    if active_repr == VelRepr.Mixed:
        from jaxsim.math.cross import Cross

        C_X_W = jaxlie.SE3.from_matrix(W_H_C).inverse().adjoint()
        C_v_WC = C_X_W @ jnp.hstack([W_vl_WC, jnp.zeros(3)])
        C_v_WB = data.base_velocity()
        W_vd_WB = W_X_C @ (base_acceleration + Cross.vx(C_v_WC) @ C_v_WB)
    else:
        W_vd_WB = W_X_C @ base_acceleration

    pm_state = data.state.physics_model

    W_f_B, joint_forces = jaxsim.physics.algos.rnea.rnea(
        model=model.physics_model,
        xfb=pm_state.xfb(),
        q=pm_state.joint_positions,
        qd=pm_state.joint_velocities,
        qdd=joint_accelerations,
        a0fb=W_vd_WB,
        f_ext=external_forces,
    )

    joint_forces = jnp.atleast_1d(joint_forces.squeeze())

    # Express the base force in the active representation.
    f_B = js.data.JaxSimModelData.inertial_to_other_representation(
        array=W_f_B,
        other_representation=active_repr,
        transform=data.base_transform(),
        is_force=True,
    ).squeeze()

    return f_B.astype(float), joint_forces.astype(float)
|
823
|
+
|
824
|
+
|
825
|
+
@jax.jit
def free_floating_gravity_forces(
    model: JaxSimModel, data: js.data.JaxSimModelData
) -> jtp.Vector:
    r"""
    Compute the free-floating gravity forces :math:`g(\mathbf{q})` of the model.

    Args:
        model: The model to consider.
        data: The data of the considered model.

    Returns:
        The free-floating gravity forces of the model: the 6D base force, expressed
        in the active velocity representation of `data`, stacked with the joint
        forces. They are obtained from inverse dynamics at the configuration of
        `data` with zero velocities, accelerations, and external forces.
    """

    # Start from a copy of `data`, which keeps its configuration and its active
    # velocity representation, and zero all the velocities. A state built with
    # `JaxSimModelData.zero(model=model)` would not receive the representation of
    # `data`. The returned base force would then not be expressed in the
    # representation of `data`.
    data_rnea = data.copy()

    with data_rnea.mutable_context(
        mutability=Mutability.MUTABLE, restore_after_exception=False
    ):

        data_rnea.state.physics_model.base_linear_velocity = jnp.zeros_like(
            data.state.physics_model.base_linear_velocity
        )

        data_rnea.state.physics_model.base_angular_velocity = jnp.zeros_like(
            data.state.physics_model.base_angular_velocity
        )

        data_rnea.state.physics_model.joint_velocities = jnp.zeros_like(
            data.state.physics_model.joint_velocities
        )

    return jnp.hstack(
        inverse_dynamics(
            model=model,
            data=data_rnea,
            # Set zero inputs:
            joint_accelerations=jnp.atleast_1d(jnp.zeros(model.dofs())),
            base_acceleration=jnp.zeros(6),
            external_forces=jnp.zeros(shape=(model.number_of_links(), 6)),
        )
    ).astype(float)
|
870
|
+
|
871
|
+
|
872
|
+
@jax.jit
def free_floating_bias_forces(
    model: JaxSimModel, data: js.data.JaxSimModelData
) -> jtp.Vector:
    r"""
    Compute the free-floating bias forces :math:`h(\mathbf{q}, \boldsymbol{\nu})`
    of the model.

    The forces are obtained by running inverse dynamics on an auxiliary data
    object that shares the generalized position, the generalized velocity and
    the gravity of ``data``, driven by zero joint/base accelerations and zero
    external forces.

    Args:
        model: The model to consider.
        data: The data of the considered model.

    Returns:
        The free-floating bias forces of the model, expressed in the active
        velocity representation of ``data``.
    """

    # Build a zeroed state.
    data_rnea = js.data.JaxSimModelData.zero(model=model)

    # Copy the generalized position, the generalized velocity and the gravity.
    with data_rnea.mutable_context(
        mutability=Mutability.MUTABLE, restore_after_exception=False
    ):
        data_rnea.gravity = data.gravity

        data_rnea.state.physics_model.base_position = (
            data.state.physics_model.base_position
        )

        data_rnea.state.physics_model.base_quaternion = (
            data.state.physics_model.base_quaternion
        )

        data_rnea.state.physics_model.joint_positions = (
            data.state.physics_model.joint_positions
        )

        data_rnea.state.physics_model.base_linear_velocity = (
            data.state.physics_model.base_linear_velocity
        )

        data_rnea.state.physics_model.base_angular_velocity = (
            data.state.physics_model.base_angular_velocity
        )

        data_rnea.state.physics_model.joint_velocities = (
            data.state.physics_model.joint_velocities
        )

    # With a nonzero velocity, a "zero base acceleration" means different things
    # in different velocity representations. Evaluate inverse dynamics in the
    # representation of `data` so the result is h(q, ν) in the caller's
    # representation rather than in the default one picked by `zero`.
    with data_rnea.switch_velocity_representation(
        velocity_representation=data.velocity_representation
    ):
        forces = jnp.hstack(
            inverse_dynamics(
                model=model,
                data=data_rnea,
                # Set zero inputs:
                joint_accelerations=jnp.atleast_1d(jnp.zeros(model.dofs())),
                base_acceleration=jnp.zeros(6),
                external_forces=jnp.zeros(shape=(model.number_of_links(), 6)),
            )
        )

    return forces.astype(float)
|
930
|
+
|
931
|
+
|
932
|
+
# ==========================
|
933
|
+
# Other kinematic quantities
|
934
|
+
# ==========================
|
935
|
+
|
936
|
+
|
937
|
+
@jax.jit
def total_momentum(model: JaxSimModel, data: js.data.JaxSimModelData) -> jtp.Vector:
    """
    Compute the total momentum of the model.

    Args:
        model: The model to consider.
        data: The data of the considered model.

    Returns:
        The total 6D momentum of the model, expressed in the active velocity
        representation of ``data``.
    """

    # Evaluate everything in body-fixed representation: there, the top six rows
    # of the free-floating mass matrix form the Jacobian of the floating-base
    # momentum.
    with data.switch_velocity_representation(velocity_representation=VelRepr.Body):
        mass_matrix_body = free_floating_mass_matrix(model=model, data=data)
        nu_body = data.generalized_velocity()

    # Momentum expressed in the base frame.
    momentum_base = mass_matrix_body[:6] @ nu_body

    # Momentum coordinates change like 6D forces, so they are brought to the
    # inertial frame with the transpose of the adjoint of the inverse base pose.
    base_pose = data.base_transform()
    adjoint_inv_base = jaxlie.SE3.from_matrix(base_pose).inverse().adjoint()
    momentum_inertial = adjoint_inv_base.T @ momentum_base

    # Finally express the result in the representation currently active in `data`.
    momentum = js.data.JaxSimModelData.inertial_to_other_representation(
        array=momentum_inertial,
        other_representation=data.velocity_representation,
        transform=base_pose,
        is_force=True,
    )

    return momentum.astype(float)
|
975
|
+
|
976
|
+
|
977
|
+
# ======
|
978
|
+
# Energy
|
979
|
+
# ======
|
980
|
+
|
981
|
+
|
982
|
+
@jax.jit
def mechanical_energy(model: JaxSimModel, data: js.data.JaxSimModelData) -> jtp.Float:
    """
    Compute the mechanical energy of the model.

    The mechanical energy is the sum of the values returned by
    ``kinetic_energy`` and ``potential_energy`` for the same model and data.

    Args:
        model: The model to consider.
        data: The data of the considered model.

    Returns:
        The mechanical energy of the model.
    """

    kinetic = kinetic_energy(model=model, data=data)
    potential = potential_energy(model=model, data=data)

    total = kinetic + potential
    return total.astype(float)
|
999
|
+
|
1000
|
+
|
1001
|
+
@jax.jit
def kinetic_energy(model: JaxSimModel, data: js.data.JaxSimModelData) -> jtp.Float:
    """
    Compute the kinetic energy of the model.

    The generalized velocity and the mass matrix are both evaluated in
    body-fixed representation, so the result does not depend on the velocity
    representation currently active in ``data``.

    Args:
        model: The model to consider.
        data: The data of the considered model.

    Returns:
        The kinetic energy of the model.
    """

    with data.switch_velocity_representation(velocity_representation=VelRepr.Body):
        mass_matrix = free_floating_mass_matrix(model=model, data=data)
        nu = data.generalized_velocity()

    # Quadratic form 1/2 νᵀ M ν.
    energy = 0.5 * nu.T @ mass_matrix @ nu

    return energy.squeeze().astype(float)
|
1020
|
+
|
1021
|
+
|
1022
|
+
@jax.jit
def potential_energy(model: JaxSimModel, data: js.data.JaxSimModelData) -> jtp.Float:
    """
    Compute the potential energy of the model.

    The value is computed as minus the dot product between the gravity vector of
    ``data`` and the CoM position scaled by the total mass, using homogeneous
    coordinates (the extra component of the gravity vector is zero).

    Args:
        model: The model to consider.
        data: The data of the considered model.

    Returns:
        The potential energy of the model.
    """

    mass = total_mass(model=model)
    gravity_h = jnp.hstack([data.gravity.squeeze(), 0])
    com_h = jnp.hstack([com_position(model=model, data=data), 1])

    energy = -gravity_h @ (mass * com_h)

    return energy.squeeze().astype(float)
|
1041
|
+
|
1042
|
+
|
1043
|
+
# ==========
|
1044
|
+
# Simulation
|
1045
|
+
# ==========
|
1046
|
+
|
1047
|
+
|
1048
|
+
@functools.partial(jax.jit, static_argnames=["integrator"])
def step(
    model: JaxSimModel,
    data: js.data.JaxSimModelData,
    *,
    dt: jtp.FloatLike,
    integrator: jaxsim.integrators.Integrator,
    integrator_state: dict[str, Any] | None = None,
    joint_forces: jtp.Vector | None = None,
    external_forces: jtp.Vector | None = None,
) -> tuple[js.data.JaxSimModelData, dict[str, Any]]:
    """
    Perform a simulation step.

    Args:
        model: The model to consider.
        data: The data of the considered model.
        dt: The time step to consider, in seconds.
        integrator: The integrator to use.
        integrator_state: The state of the integrator (an empty dict if None).
        joint_forces: The joint forces to consider.
        external_forces: The external forces to consider.

    Returns:
        A tuple containing the new data of the model
        and the new state of the integrator.
    """

    integrator_state = integrator_state if integrator_state is not None else dict()

    # Extract the initial resources.
    t0_ns = data.time_ns
    state_x0 = data.state
    integrator_state_x0 = integrator_state

    # `data.time_ns` is in nanoseconds while `dt` is in seconds (it is converted
    # to ns below), so the initial time handed to the integrator must be
    # converted ns -> s by dividing by 1e9.
    t0 = jnp.array(t0_ns / 1e9).astype(float)

    # Step the dynamics forward.
    state_xf, integrator_state_xf = integrator.step(
        x0=state_x0,
        t0=t0,
        dt=dt,
        params=integrator_state_x0,
        joint_forces=joint_forces,
        external_forces=external_forces,
    )

    # Round (instead of truncating) to the closest integer number of ns, so a
    # product like dt * 1e9 == 999999.9999... advances time by 1000000 ns.
    dt_ns = jnp.array(dt * 1e9).round().astype(jnp.uint64)

    return (
        # Store the new state of the model and the new time.
        data.replace(
            state=state_xf,
            time_ns=t0_ns + dt_ns,
        ),
        integrator_state_xf,
    )
|