jaxsim 0.2.dev101__py3-none-any.whl → 0.2.dev166__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
jaxsim/api/model.py ADDED
@@ -0,0 +1,1099 @@
1
+ from __future__ import annotations
2
+
3
+ import dataclasses
4
+ import functools
5
+ import pathlib
6
+ from typing import Any
7
+
8
+ import jax
9
+ import jax.numpy as jnp
10
+ import jax_dataclasses
11
+ import jaxlie
12
+ import rod
13
+ from jax_dataclasses import Static
14
+
15
+ import jaxsim.api as js
16
+ import jaxsim.physics.algos.aba
17
+ import jaxsim.physics.algos.crba
18
+ import jaxsim.physics.algos.forward_kinematics
19
+ import jaxsim.physics.algos.rnea
20
+ import jaxsim.physics.model.physics_model
21
+ import jaxsim.typing as jtp
22
+ from jaxsim.high_level.common import VelRepr
23
+ from jaxsim.physics.algos.terrain import FlatTerrain, Terrain
24
+ from jaxsim.utils import JaxsimDataclass, Mutability
25
+
26
+
27
@jax_dataclasses.pytree_dataclass
class JaxSimModel(JaxsimDataclass):
    """
    A robot model describing its kinematics and dynamics for JaxSim.

    Attributes:
        model_name: The name of the model.
        physics_model: The physics model wrapped by this object.
        terrain: The terrain associated with the model (flat by default).
        built_from:
            The model description the model was built from, when created with
            `build_from_model_description`; None otherwise.
    """

    model_name: Static[str]

    physics_model: jaxsim.physics.model.physics_model.PhysicsModel = dataclasses.field(
        repr=False
    )

    terrain: Static[Terrain] = dataclasses.field(default=FlatTerrain(), repr=False)

    built_from: Static[str | pathlib.Path | rod.Model | None] = dataclasses.field(
        repr=False, default=None
    )

    # Cached counts, filled by __post_init__ (they are not constructor arguments).
    _number_of_links: Static[int] = dataclasses.field(
        init=False, repr=False, default=None
    )

    _number_of_joints: Static[int] = dataclasses.field(
        init=False, repr=False, default=None
    )

    def __post_init__(self):
        # Keep the link/joint counts as Static Python ints so that they can size
        # `jax.vmap` and `jax.lax.scan` iterations over all links and joints.
        description = self.physics_model.description

        with self.mutable_context(
            mutability=Mutability.MUTABLE_NO_VALIDATION,
            restore_after_exception=False,
        ):
            self._number_of_links = len(description.links_dict)
            self._number_of_joints = len(description.joints_dict)

    # ========================
    # Initialization and state
    # ========================

    @staticmethod
    def build_from_model_description(
        model_description: str | pathlib.Path | rod.Model,
        model_name: str | None = None,
        gravity: jtp.Array = jaxsim.physics.default_gravity(),
        is_urdf: bool | None = None,
        considered_joints: list[str] | None = None,
    ) -> JaxSimModel:
        """
        Create a model from an SDF/URDF description.

        Args:
            model_description:
                A path to an SDF/URDF file, a string with the SDF/URDF content,
                or an already parsed/built rod model.
            model_name:
                If given, the name of the model, overriding the one stored in
                the description.
            gravity: The 3D gravity vector.
            is_urdf:
                Whether the description is a URDF (True) or an SDF (False). It is
                inferred automatically when the description is a file path.
            considered_joints:
                The joints to keep. When None, all joints are kept; otherwise
                the links connected by the discarded joints are lumped together.

        Returns:
            The new model, whose `built_from` attribute stores `model_description`.
        """

        import jaxsim.parsers.rod

        # Parse the resource into the intermediate model description.
        description = jaxsim.parsers.rod.build_model_description(
            model_description=model_description, is_urdf=is_urdf
        )

        # Lumping assigns a zero position to every joint that is not considered.
        if considered_joints is not None:
            description = description.reduce(considered_joints=considered_joints)

        physics_model = jaxsim.physics.model.physics_model.PhysicsModel.build_from(
            model_description=description, gravity=gravity
        )

        model = JaxSimModel.build(physics_model=physics_model, model_name=model_name)

        # Remember the resource the model was built from.
        with model.mutable_context(mutability=Mutability.MUTABLE_NO_VALIDATION):
            model.built_from = model_description

        return model

    @staticmethod
    def build(
        physics_model: jaxsim.physics.model.physics_model.PhysicsModel,
        model_name: str | None = None,
    ) -> JaxSimModel:
        """
        Create a model wrapping an existing physics model.

        Args:
            physics_model: The physics model to wrap.
            model_name:
                If given, the model name; otherwise the name stored in the
                physics model description is used.

        Returns:
            The new model.
        """

        if model_name is None:
            model_name = physics_model.description.name

        return JaxSimModel(physics_model=physics_model, model_name=model_name)  # noqa

    # ==========
    # Properties
    # ==========

    def name(self) -> str:
        """
        Get the model name.

        Returns:
            The name of the model.
        """

        return self.model_name

    def number_of_links(self) -> jtp.Int:
        """
        Get how many links the model has.

        Returns:
            The number of links, base link included (the base link has index 0).
        """

        return self._number_of_links

    def number_of_joints(self) -> jtp.Int:
        """
        Get how many joints the model has.

        Returns:
            The number of joints of the model.
        """

        return self._number_of_joints

    # =================
    # Base link methods
    # =================

    def floating_base(self) -> bool:
        """
        Check whether the base of the model is floating.

        Returns:
            True for floating-base models, False for fixed-base ones.
        """

        return self.physics_model.is_floating_base

    def base_link(self) -> str:
        """
        Get the name of the base link.

        Returns:
            The name of the root link of the model description.
        """

        return self.physics_model.description.root.name

    # =====================
    # Joint-related methods
    # =====================

    def dofs(self) -> int:
        """
        Get the number of degrees of freedom of the model.

        Returns:
            The number of degrees of freedom.

        Note:
            Multi-DoF joints are not yet supported, so this currently equals the
            number of joints; this could change in the future.
        """

        joints = self.physics_model.description.joints_dict
        return len(joints)

    def joint_names(self) -> tuple[str, ...]:
        """
        Get the joint names.

        Returns:
            The names of the joints, in the order of the model description.
        """

        joints = self.physics_model.description.joints_dict
        return tuple(joints)

    # ====================
    # Link-related methods
    # ====================

    def link_names(self) -> tuple[str, ...]:
        """
        Get the link names.

        Returns:
            The names of the links, in the order of the model description.
        """

        links = self.physics_model.description.links_dict
        return tuple(links)
254
+
255
+
256
+ # =====================
257
+ # Model post-processing
258
+ # =====================
259
+
260
+
261
def reduce(model: JaxSimModel, considered_joints: tuple[str, ...]) -> JaxSimModel:
    """
    Reduce the model by lumping together the links connected by removed joints.

    Args:
        model: The model to reduce.
        considered_joints: The sequence of joints to consider.

    Returns:
        The reduced model. It keeps the name and the terrain of the input model.

    Note:
        If considered_joints contains joints not existing in the model, the method
        will raise an exception. If considered_joints is empty, the method will
        return a copy of the input model.
    """

    if len(considered_joints) == 0:
        return model.copy()

    # Reduce the model description.
    # If considered_joints contains joints not existing in the model, the method
    # will raise an exception.
    reduced_intermediate_description = model.physics_model.description.reduce(
        considered_joints=list(considered_joints)
    )

    # Create the physics model from the reduced model description, reusing the
    # linear part of the gravity of the original physics model.
    physics_model = jaxsim.physics.model.physics_model.PhysicsModel.build_from(
        model_description=reduced_intermediate_description,
        gravity=model.physics_model.gravity[0:3],
    )

    # Build the reduced model
    reduced_model = JaxSimModel.build(
        physics_model=physics_model, model_name=model.name()
    )

    # JaxSimModel.build always installs the default terrain: propagate the terrain
    # of the input model so that reducing a model does not silently change it.
    with reduced_model.mutable_context(mutability=Mutability.MUTABLE_NO_VALIDATION):
        reduced_model.terrain = model.terrain

    return reduced_model
297
+
298
+
299
+ # ===================
300
+ # Inertial properties
301
+ # ===================
302
+
303
+
304
@jax.jit
def total_mass(model: JaxSimModel) -> jtp.Float:
    """
    Sum the masses of all the links of the model.

    Args:
        model: The model whose mass is computed.

    Returns:
        The total mass of the model, as a float.
    """

    def link_mass(link_index):
        return js.link.mass(model=model, link_index=link_index)

    all_link_indices = jnp.arange(model.number_of_links())
    link_masses = jax.vmap(link_mass)(all_link_indices)

    return link_masses.sum().astype(float)
323
+
324
+
325
+ # ==============
326
+ # Center of mass
327
+ # ==============
328
+
329
+
330
@jax.jit
def com_position(model: JaxSimModel, data: js.data.JaxSimModelData) -> jtp.Vector:
    """
    Locate the center of mass (CoM) of the whole model.

    The mass-weighted CoMs of all links are averaged in the base frame (using
    homogeneous coordinates), and the result is mapped back to the world frame.

    Args:
        model: The model to consider.
        data: The data of the considered model.

    Returns:
        The 3D position of the CoM of the model w.r.t. the world frame.
    """

    W_H_B = data.base_transform()
    B_H_W = jaxlie.SE3.from_matrix(W_H_B).inverse().as_matrix()
    W_H_L = forward_kinematics(model=model, data=data)
    m_total = total_mass(model=model)

    def weighted_link_com(link_index) -> jtp.Vector:
        # Homogeneous CoM of the link expressed in the base frame, times its mass.
        m_link = js.link.mass(model=model, link_index=link_index)
        L_p_LCoM = js.link.com_position(
            model=model, data=data, link_index=link_index, in_link_frame=True
        )
        L_p̃_LCoM = jnp.hstack([L_p_LCoM, 1])
        return m_link * B_H_W @ W_H_L[link_index] @ L_p̃_LCoM

    link_indices = jnp.arange(model.number_of_links())
    weighted_coms = jax.vmap(weighted_link_com)(link_indices)

    # Restore the homogeneous coordinate, which was scaled by the averaging.
    B_p̃_CoM = ((1 / m_total) * weighted_coms.sum(axis=0)).at[3].set(1)

    W_p̃_CoM = W_H_B @ B_p̃_CoM
    return W_p̃_CoM[0:3].astype(float)
362
+
363
+
364
+ # ==============================
365
+ # Rigid Body Dynamics Algorithms
366
+ # ==============================
367
+
368
+
369
@jax.jit
def forward_kinematics(model: JaxSimModel, data: js.data.JaxSimModelData) -> jtp.Array:
    """
    Compute the world-to-link SE(3) transforms of all the links.

    Args:
        model: The model to consider.
        data: The data of the considered model.

    Returns:
        An array of shape (nL, 4, 4) stacking the homogeneous transforms of the
        links, indexed by link along the first axis.
    """

    physics_state = data.state.physics_model

    W_H_L_stacked = jaxsim.physics.algos.forward_kinematics.forward_kinematics_model(
        model=model.physics_model,
        q=physics_state.joint_positions,
        xfb=physics_state.xfb(),
    )

    return jnp.atleast_3d(W_H_L_stacked).astype(float)
390
+
391
+
392
@jax.jit
def generalized_free_floating_jacobian(
    model: JaxSimModel,
    data: js.data.JaxSimModelData,
    *,
    output_vel_repr: VelRepr | None = None,
) -> jtp.Matrix:
    """
    Compute the free-floating jacobians of all the links.

    Args:
        model: The model to consider.
        data: The data of the considered model.
        output_vel_repr:
            The velocity representation of the returned jacobians. Defaults to
            the active representation of `data`.

    Returns:
        An array of shape (nL, 6, 6+dofs) stacking the free-floating jacobians
        of the links, indexed by link along the first axis.
    """

    if output_vel_repr is None:
        output_vel_repr = data.velocity_representation

    # `js.link.jacobian` uses the link frame L as body frame, while here the body
    # frame is the base link B. Each link jacobian is therefore fetched in
    # Inertial representation and then converted to the requested one.
    if output_vel_repr == VelRepr.Inertial:

        def to_output(W_J_WL):
            return W_J_WL

    elif output_vel_repr == VelRepr.Body:

        def to_output(W_J_WL):
            W_H_B = data.base_transform()
            B_X_W = jaxlie.SE3.from_matrix(W_H_B).inverse().adjoint()
            return B_X_W @ W_J_WL

    elif output_vel_repr == VelRepr.Mixed:

        def to_output(W_J_WL):
            # B[W]: origin of B, orientation of W.
            W_H_BW = jnp.array(data.base_transform()).at[0:3, 0:3].set(jnp.eye(3))
            BW_X_W = jaxlie.SE3.from_matrix(W_H_BW).inverse().adjoint()
            return BW_X_W @ W_J_WL

    else:
        raise ValueError(output_vel_repr)

    def link_jacobian(link_index):
        W_J_WL = js.link.jacobian(model=model, data=data, link_index=link_index)
        return to_output(W_J_WL)

    return jax.vmap(link_jacobian)(jnp.arange(model.number_of_links()))
449
+
450
+
451
@functools.partial(jax.jit, static_argnames=["prefer_aba"])
def forward_dynamics(
    model: JaxSimModel,
    data: js.data.JaxSimModelData,
    *,
    joint_forces: jtp.VectorLike | None = None,
    external_forces: jtp.MatrixLike | None = None,
    prefer_aba: bool = True,
) -> tuple[jtp.Vector, jtp.Vector]:
    """
    Compute the forward dynamics of the model.

    Args:
        model: The model to consider.
        data: The data of the considered model.
        joint_forces:
            The joint forces to consider as a vector of shape `(dofs,)`.
        external_forces:
            The external forces to consider as a matrix of shape `(nL, 6)`.
        prefer_aba:
            Whether to use the ABA algorithm (True) or the CRB one (False).
            It is a static argument: changing it triggers a re-compilation.

    Returns:
        A tuple containing the 6D acceleration in the active representation of the
        base link and the joint accelerations resulting from the application of the
        considered joint forces and external forces.
    """

    forward_dynamics_fn = forward_dynamics_aba if prefer_aba else forward_dynamics_crb

    return forward_dynamics_fn(
        model=model,
        data=data,
        joint_forces=joint_forces,
        external_forces=external_forces,
    )
486
+
487
+
488
@jax.jit
def forward_dynamics_aba(
    model: JaxSimModel,
    data: js.data.JaxSimModelData,
    *,
    joint_forces: jtp.VectorLike | None = None,
    external_forces: jtp.MatrixLike | None = None,
) -> tuple[jtp.Vector, jtp.Vector]:
    """
    Run the forward dynamics with the Articulated Body Algorithm (ABA).

    Args:
        model: The model to consider.
        data: The data of the considered model.
        joint_forces:
            The joint forces as a vector of shape `(dofs,)`. Zero when omitted.
        external_forces:
            The external forces as a matrix of shape `(nL, 6)`. Zero when omitted.

    Returns:
        A tuple with the base 6D acceleration, in the active velocity
        representation of `data`, and the joint accelerations produced by the
        given joint and external forces.
    """

    if joint_forces is None:
        joint_forces = jnp.zeros_like(data.joint_positions())

    if external_forces is None:
        external_forces = jnp.zeros((model.number_of_links(), 6))

    physics_state = data.state.physics_model

    # The base acceleration returned by ABA is treated as inertial-fixed below.
    W_v̇_WB, s̈ = jaxsim.physics.algos.aba.aba(
        model=model.physics_model,
        xfb=physics_state.xfb(),
        q=physics_state.joint_positions,
        qd=physics_state.joint_velocities,
        tau=joint_forces,
        f_ext=external_forces,
    )

    representation = data.velocity_representation

    # Frame C of the active representation. W_vl_WC is consumed only in Mixed.
    if representation == VelRepr.Inertial:
        W_H_C = jnp.eye(4)
        W_vl_WC = jnp.zeros(3)
    elif representation == VelRepr.Body:
        W_H_C = data.base_transform()
        W_vl_WC = data.base_velocity()[0:3]
    elif representation == VelRepr.Mixed:
        W_H_C = data.base_transform().at[0:3, 0:3].set(jnp.eye(3))
        W_vl_WC = data.base_velocity()[0:3]
    else:
        raise ValueError(representation)

    C_X_W = jaxlie.SE3.from_matrix(W_H_C).inverse().adjoint()
    W_v̇_WB = W_v̇_WB.squeeze()

    if representation == VelRepr.Mixed:
        # The Mixed acceleration is apparent: converting it is not a plain 6D
        # transformation X, it also involves a cross product in ℝ⁶.
        from jaxsim.math.cross import Cross

        W_v_WB = jnp.hstack(
            [physics_state.base_linear_velocity, physics_state.base_angular_velocity]
        )
        W_v_WC = jnp.hstack([W_vl_WC, jnp.zeros(3)])
        C_v̇_WB = C_X_W @ (W_v̇_WB - Cross.vx(W_v_WC) @ W_v_WB)
    else:
        C_v̇_WB = C_X_W @ W_v̇_WB

    return C_v̇_WB, jnp.atleast_1d(s̈.squeeze())
584
+
585
+
586
@jax.jit
def forward_dynamics_crb(
    model: JaxSimModel,
    data: js.data.JaxSimModelData,
    *,
    joint_forces: jtp.VectorLike | None = None,
    external_forces: jtp.MatrixLike | None = None,
) -> tuple[jtp.Vector, jtp.Vector]:
    """
    Compute the forward dynamics of the model with the CRB algorithm.

    The floating-base equations of motion M ν̇ + h = S τ + Jᵀ f_ext are solved
    for ν̇. For fixed-base models only the joint block is solved and the base
    acceleration is zero.

    Args:
        model: The model to consider.
        data: The data of the considered model.
        joint_forces:
            The joint forces to consider as a vector of shape `(dofs,)`.
        external_forces:
            The external forces to consider as a matrix of shape `(nL, 6)`.

    Returns:
        A tuple containing the 6D acceleration in the active representation of the
        base link and the joint accelerations resulting from the application of the
        considered joint forces and external forces.

    Note:
        Compared to ABA, this method could be significantly slower, especially for
        models with a large number of degrees of freedom.
    """

    # Build joint torques if not provided (accept any array-like input).
    τ = (
        jnp.asarray(joint_forces)
        if joint_forces is not None
        else jnp.zeros_like(data.joint_positions())
    )

    # Build external forces if not provided (accept any array-like input).
    external_forces = (
        jnp.asarray(external_forces)
        if external_forces is not None
        else jnp.zeros(shape=(model.number_of_links(), 6))
    )

    # Handle models with zero and one DoFs: τ becomes a (dofs, 1) column.
    τ = jnp.atleast_1d(τ.squeeze())
    τ = jnp.vstack(τ) if τ.size > 0 else jnp.empty(shape=(0, 1))

    # Compute terms of the floating-base EoM (all in the active representation).
    M = free_floating_mass_matrix(model=model, data=data)
    h = jnp.vstack(free_floating_bias_forces(model=model, data=data))
    J = jnp.vstack(generalized_free_floating_jacobian(model=model, data=data))
    f_ext = jnp.vstack(external_forces.flatten())
    S = jnp.block([jnp.zeros(shape=(model.dofs(), 6)), jnp.eye(model.dofs())]).T

    # Solve the linear systems instead of forming the explicit inverse, which is
    # both cheaper and numerically more accurate.
    # TODO: invert the Mss block exploiting sparsity defined by the parent array λ(i)
    if model.floating_base():
        ν̇ = jnp.linalg.solve(M, (S @ τ) - h + J.T @ f_ext)
    else:
        v̇_WB = jnp.zeros(6)
        s̈ = jnp.linalg.solve(
            M[6:, 6:], (S @ τ)[6:] - h[6:] + J[:, 6:].T @ f_ext
        )
        ν̇ = jnp.hstack([v̇_WB, s̈.squeeze()])

    # Extract the base acceleration in the active representation.
    # Note that this is an apparent acceleration (relevant in Mixed representation),
    # therefore it cannot be always expressed in different frames with just a
    # 6D transformation X.
    v̇_WB = ν̇[0:6].squeeze().astype(float)

    # Extract the joint accelerations
    s̈ = jnp.atleast_1d(ν̇[6:].squeeze()).astype(float)

    return v̇_WB, s̈
658
+
659
+
660
@jax.jit
def free_floating_mass_matrix(
    model: JaxSimModel, data: js.data.JaxSimModelData
) -> jtp.Matrix:
    """
    Compute the free-floating mass matrix with the CRBA algorithm.

    CRBA yields the matrix in Body representation. For another representation C
    it is transformed as T⁻ᵀ M_B T⁻¹, with T⁻¹ = blkdiag(B_X_C, I).

    Args:
        model: The model to consider.
        data: The data of the considered model.

    Returns:
        The free-floating mass matrix in the active velocity representation of
        `data`.
    """

    M_body = jaxsim.physics.algos.crba.crba(
        model=model.physics_model,
        q=data.state.physics_model.joint_positions,
    )

    representation = data.velocity_representation

    if representation == VelRepr.Body:
        return M_body

    if representation == VelRepr.Inertial:
        B_X_C = jaxlie.SE3.from_matrix(data.base_transform()).inverse().adjoint()
    elif representation == VelRepr.Mixed:
        # Zeroing the translation of W_H_B gives BW_H_B (B[W] shares its origin
        # with B), whose inverse adjoint is B_X_BW.
        BW_H_B = data.base_transform().at[0:3, 3].set(jnp.zeros(3))
        B_X_C = jaxlie.SE3.from_matrix(BW_H_B).inverse().adjoint()
    else:
        raise ValueError(representation)

    n = model.dofs()
    zero_6n = jnp.zeros(shape=(6, n))

    invT = jnp.vstack(
        [
            jnp.block([B_X_C, zero_6n]),
            jnp.block([zero_6n.T, jnp.eye(n)]),
        ]
    )

    return invT.T @ M_body @ invT
713
+
714
+
715
@jax.jit
def inverse_dynamics(
    model: JaxSimModel,
    data: js.data.JaxSimModelData,
    *,
    joint_accelerations: jtp.Vector | None = None,
    base_acceleration: jtp.Vector | None = None,
    external_forces: jtp.Matrix | None = None,
) -> tuple[jtp.Vector, jtp.Vector]:
    """
    Run the inverse dynamics with the RNEA algorithm.

    Args:
        model: The model to consider.
        data: The data of the considered model.
        joint_accelerations:
            The joint accelerations as a vector of shape `(dofs,)`. Zero when
            omitted.
        base_acceleration:
            The base acceleration, in the active representation, as a vector of
            shape `(6,)`. Zero when omitted.
        external_forces:
            The external forces as a matrix of shape `(nL, 6)`. Zero when omitted.

    Returns:
        A tuple with the 6D base force (in the active representation) and the
        joint forces that produce the given accelerations.
    """

    if joint_accelerations is None:
        joint_accelerations = jnp.zeros_like(data.joint_positions())

    if base_acceleration is None:
        base_acceleration = jnp.zeros(6)

    if external_forces is None:
        external_forces = jnp.zeros(shape=(model.number_of_links(), 6))

    representation = data.velocity_representation

    # Frame C of the active representation. W_vl_WC is consumed only in Mixed.
    if representation == VelRepr.Inertial:
        W_H_C = jnp.eye(4)
        W_vl_WC = jnp.zeros(3)
    elif representation == VelRepr.Body:
        W_H_C = data.base_transform()
        W_vl_WC = data.base_velocity()[0:3]
    elif representation == VelRepr.Mixed:
        W_H_C = data.base_transform().at[0:3, 0:3].set(jnp.eye(3))
        W_vl_WC = data.base_velocity()[0:3]
    else:
        raise ValueError(representation)

    # Convert the base acceleration to Inertial representation, as used by RNEA.
    W_X_C = jaxlie.SE3.from_matrix(W_H_C).adjoint()

    if representation == VelRepr.Mixed:
        # The Mixed acceleration is apparent: converting it is not a plain 6D
        # transformation X, it also involves a cross product in ℝ⁶.
        from jaxsim.math.cross import Cross

        C_X_W = jaxlie.SE3.from_matrix(W_H_C).inverse().adjoint()
        C_v_WC = C_X_W @ jnp.hstack([W_vl_WC, jnp.zeros(3)])
        C_v_WB = data.base_velocity()
        W_v̇_WB = W_X_C @ (base_acceleration + Cross.vx(C_v_WC) @ C_v_WB)
    else:
        W_v̇_WB = W_X_C @ base_acceleration

    physics_state = data.state.physics_model

    W_f_B, τ = jaxsim.physics.algos.rnea.rnea(
        model=model.physics_model,
        xfb=physics_state.xfb(),
        q=physics_state.joint_positions,
        qd=physics_state.joint_velocities,
        qdd=joint_accelerations,
        a0fb=W_v̇_WB,
        f_ext=external_forces,
    )

    # Express the base force in the active representation.
    f_B = js.data.JaxSimModelData.inertial_to_other_representation(
        array=W_f_B,
        other_representation=representation,
        transform=data.base_transform(),
        is_force=True,
    ).squeeze()

    τ = jnp.atleast_1d(τ.squeeze())

    return f_B.astype(float), τ.astype(float)
823
+
824
+
825
@jax.jit
def free_floating_gravity_forces(
    model: JaxSimModel, data: js.data.JaxSimModelData
) -> jtp.Vector:
    r"""
    Compute the free-floating gravity forces :math:`g(\mathbf{q})` of the model.

    The forces are obtained by running inverse dynamics on a state having the
    generalized position of `data`, zero velocity, and zero accelerations and
    external forces.

    Args:
        model: The model to consider.
        data: The data of the considered model.

    Returns:
        The vector stacking the 6D base force and the joint forces, expressed in
        the active velocity representation of `data`.
    """

    # Build a zeroed state
    data_rnea = js.data.JaxSimModelData.zero(model=model)

    # Set just the generalized position
    with data_rnea.mutable_context(
        mutability=Mutability.MUTABLE, restore_after_exception=False
    ):

        data_rnea.state.physics_model.base_position = (
            data.state.physics_model.base_position
        )

        data_rnea.state.physics_model.base_quaternion = (
            data.state.physics_model.base_quaternion
        )

        data_rnea.state.physics_model.joint_positions = (
            data.state.physics_model.joint_positions
        )

    # The zeroed data cannot know the representation of `data`: switch to it, so
    # that the base force is expressed in the same representation as the other
    # quantities (e.g. mass matrix and jacobians) computed from `data`.
    with data_rnea.switch_velocity_representation(
        velocity_representation=data.velocity_representation
    ):
        f_B, τ = inverse_dynamics(
            model=model,
            data=data_rnea,
            # Set zero inputs:
            joint_accelerations=jnp.zeros(model.dofs()),
            base_acceleration=jnp.zeros(6),
            external_forces=jnp.zeros(shape=(model.number_of_links(), 6)),
        )

    return jnp.hstack([f_B, τ]).astype(float)
870
+
871
+
872
@jax.jit
def free_floating_bias_forces(
    model: JaxSimModel, data: js.data.JaxSimModelData
) -> jtp.Vector:
    r"""
    Compute the free-floating bias forces :math:`h(\mathbf{q}, \boldsymbol{\nu})`
    of the model.

    The forces are obtained by running inverse dynamics on a state having the
    generalized position and velocity of `data`, and zero accelerations and
    external forces.

    Args:
        model: The model to consider.
        data: The data of the considered model.

    Returns:
        The vector stacking the 6D base force and the joint forces, expressed in
        the active velocity representation of `data`.
    """

    # Build a zeroed state
    data_rnea = js.data.JaxSimModelData.zero(model=model)

    # Set the generalized position and generalized velocity
    with data_rnea.mutable_context(
        mutability=Mutability.MUTABLE, restore_after_exception=False
    ):

        data_rnea.state.physics_model.base_position = (
            data.state.physics_model.base_position
        )

        data_rnea.state.physics_model.base_quaternion = (
            data.state.physics_model.base_quaternion
        )

        data_rnea.state.physics_model.joint_positions = (
            data.state.physics_model.joint_positions
        )

        data_rnea.state.physics_model.base_linear_velocity = (
            data.state.physics_model.base_linear_velocity
        )

        data_rnea.state.physics_model.base_angular_velocity = (
            data.state.physics_model.base_angular_velocity
        )

        data_rnea.state.physics_model.joint_velocities = (
            data.state.physics_model.joint_velocities
        )

    # The zeroed data cannot know the representation of `data`: switch to it.
    # This matters twice: the zero base acceleration is interpreted in that
    # representation (in Mixed it is an apparent acceleration), and the base force
    # is returned in it, consistently with the mass matrix and the jacobians.
    with data_rnea.switch_velocity_representation(
        velocity_representation=data.velocity_representation
    ):
        f_B, τ = inverse_dynamics(
            model=model,
            data=data_rnea,
            # Set zero inputs:
            joint_accelerations=jnp.zeros(model.dofs()),
            base_acceleration=jnp.zeros(6),
            external_forces=jnp.zeros(shape=(model.number_of_links(), 6)),
        )

    return jnp.hstack([f_B, τ]).astype(float)
930
+
931
+
932
+ # ==========================
933
+ # Other kinematic quantities
934
+ # ==========================
935
+
936
+
937
@jax.jit
def total_momentum(model: JaxSimModel, data: js.data.JaxSimModelData) -> jtp.Vector:
    """
    Compute the total 6D momentum of the model.

    Args:
        model: The model to consider.
        data: The data of the considered model.

    Returns:
        The total momentum, expressed in the active velocity representation of
        `data`.
    """

    # In body-fixed representation, the first six rows of the mass matrix are the
    # jacobian of the floating-base momentum.
    with data.switch_velocity_representation(velocity_representation=VelRepr.Body):
        B_M = free_floating_mass_matrix(model=model, data=data)
        B_ν = data.generalized_velocity()

    B_h = B_M[0:6, :] @ B_ν

    W_H_B = data.base_transform()
    B_X_W: jtp.Array = jaxlie.SE3.from_matrix(W_H_B).inverse().adjoint()

    # Momentum coordinates transform like 6D forces: W_h = B_X_Wᵀ B_h.
    W_h = B_X_W.T @ B_h

    active_h = js.data.JaxSimModelData.inertial_to_other_representation(
        array=W_h,
        other_representation=data.velocity_representation,
        transform=W_H_B,
        is_force=True,
    )

    return active_h.astype(float)
975
+
976
+
977
+ # ======
978
+ # Energy
979
+ # ======
980
+
981
+
982
@jax.jit
def mechanical_energy(model: JaxSimModel, data: js.data.JaxSimModelData) -> jtp.Float:
    """
    Compute the mechanical energy, i.e. the kinetic plus the potential energy.

    Args:
        model: The model to consider.
        data: The data of the considered model.

    Returns:
        The mechanical energy of the model, as a float.
    """

    energy = kinetic_energy(model=model, data=data) + potential_energy(
        model=model, data=data
    )

    return energy.astype(float)
999
+
1000
+
1001
@jax.jit
def kinetic_energy(model: JaxSimModel, data: js.data.JaxSimModelData) -> jtp.Float:
    """
    Compute the kinetic energy ½ νᵀ M ν of the model.

    Both the generalized velocity and the mass matrix are taken in body-fixed
    representation.

    Args:
        model: The model to consider.
        data: The data of the considered model.

    Returns:
        The kinetic energy of the model, as a float.
    """

    with data.switch_velocity_representation(velocity_representation=VelRepr.Body):
        ν_body = data.generalized_velocity()
        M_body = free_floating_mass_matrix(model=model, data=data)

    energy = 0.5 * ν_body.T @ M_body @ ν_body

    return energy.squeeze().astype(float)
1020
+
1021
+
1022
@jax.jit
def potential_energy(model: JaxSimModel, data: js.data.JaxSimModelData) -> jtp.Float:
    """
    Compute the gravitational potential energy -m gᵀ p_CoM of the model.

    Args:
        model: The model to consider.
        data: The data of the considered model.

    Returns:
        The potential energy of the model, as a float.
    """

    mass = total_mass(model=model)
    g = data.gravity.squeeze()

    # Homogeneous coordinates: the trailing 0 of g̃ cancels the trailing 1 of p̃.
    g̃ = jnp.hstack([g, 0])
    W_p̃_CoM = jnp.hstack([com_position(model=model, data=data), 1])

    energy = -g̃ @ (mass * W_p̃_CoM)

    return energy.squeeze().astype(float)
1041
+
1042
+
1043
+ # ==========
1044
+ # Simulation
1045
+ # ==========
1046
+
1047
+
1048
@functools.partial(jax.jit, static_argnames=["integrator"])
def step(
    model: JaxSimModel,
    data: js.data.JaxSimModelData,
    *,
    dt: jtp.FloatLike,
    integrator: jaxsim.integrators.Integrator,
    integrator_state: dict[str, Any] | None = None,
    joint_forces: jtp.VectorLike | None = None,
    external_forces: jtp.MatrixLike | None = None,
) -> tuple[js.data.JaxSimModelData, dict[str, Any]]:
    """
    Perform a simulation step.

    Args:
        model: The model to consider.
        data: The data of the considered model.
        dt: The time step to consider, in seconds.
        integrator: The integrator to use (static argument).
        integrator_state: The state of the integrator. Empty when omitted.
        joint_forces: The joint forces to consider.
        external_forces: The external forces to consider.

    Returns:
        A tuple containing the new data of the model
        and the new state of the integrator.
    """

    integrator_state = integrator_state if integrator_state is not None else dict()

    # Extract the initial resources.
    t0_ns = data.time_ns
    state_x0 = data.state
    integrator_state_x0 = integrator_state

    # The model time is stored in nanoseconds, the integrator works in seconds.
    t0 = jnp.array(t0_ns / 1e9).astype(float)

    # Step the dynamics forward.
    state_xf, integrator_state_xf = integrator.step(
        x0=state_x0,
        t0=t0,
        dt=dt,
        params=integrator_state_x0,
        joint_forces=joint_forces,
        external_forces=external_forces,
    )

    # Round (instead of truncating) the time step to integer nanoseconds, so that
    # a product like dt*1e9 slightly below an integer does not lose 1 ns per step.
    dt_ns = jnp.array(jnp.round(dt * 1e9)).astype(jnp.uint64)

    return (
        # Store the new state of the model and the new time.
        data.replace(state=state_xf, time_ns=t0_ns + dt_ns),
        integrator_state_xf,
    )