jaxsim 0.6.1.dev13__py3-none-any.whl → 0.6.2.dev102__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jaxsim/__init__.py +1 -1
- jaxsim/_version.py +2 -2
- jaxsim/api/__init__.py +3 -1
- jaxsim/api/actuation_model.py +96 -0
- jaxsim/api/com.py +8 -8
- jaxsim/api/contact.py +15 -255
- jaxsim/api/contact_model.py +101 -0
- jaxsim/api/data.py +258 -556
- jaxsim/api/frame.py +7 -7
- jaxsim/api/integrators.py +76 -0
- jaxsim/api/kin_dyn_parameters.py +41 -58
- jaxsim/api/link.py +7 -7
- jaxsim/api/model.py +190 -453
- jaxsim/api/ode.py +34 -338
- jaxsim/api/references.py +2 -2
- jaxsim/exceptions.py +2 -2
- jaxsim/math/__init__.py +4 -3
- jaxsim/math/joint_model.py +17 -107
- jaxsim/mujoco/model.py +1 -1
- jaxsim/mujoco/utils.py +2 -2
- jaxsim/parsers/kinematic_graph.py +1 -3
- jaxsim/rbda/aba.py +7 -4
- jaxsim/rbda/collidable_points.py +7 -98
- jaxsim/rbda/contacts/__init__.py +2 -10
- jaxsim/rbda/contacts/common.py +0 -138
- jaxsim/rbda/contacts/relaxed_rigid.py +154 -9
- jaxsim/rbda/crba.py +5 -2
- jaxsim/rbda/forward_kinematics.py +37 -12
- jaxsim/rbda/jacobian.py +15 -6
- jaxsim/rbda/rnea.py +7 -4
- jaxsim/rbda/utils.py +3 -3
- jaxsim/utils/jaxsim_dataclass.py +5 -1
- {jaxsim-0.6.1.dev13.dist-info → jaxsim-0.6.2.dev102.dist-info}/METADATA +7 -9
- jaxsim-0.6.2.dev102.dist-info/RECORD +69 -0
- jaxsim/api/ode_data.py +0 -401
- jaxsim/integrators/__init__.py +0 -2
- jaxsim/integrators/common.py +0 -592
- jaxsim/integrators/fixed_step.py +0 -153
- jaxsim/integrators/variable_step.py +0 -706
- jaxsim/rbda/contacts/rigid.py +0 -462
- jaxsim/rbda/contacts/soft.py +0 -480
- jaxsim/rbda/contacts/visco_elastic.py +0 -1066
- jaxsim-0.6.1.dev13.dist-info/RECORD +0 -74
- {jaxsim-0.6.1.dev13.dist-info → jaxsim-0.6.2.dev102.dist-info}/LICENSE +0 -0
- {jaxsim-0.6.1.dev13.dist-info → jaxsim-0.6.2.dev102.dist-info}/WHEEL +0 -0
- {jaxsim-0.6.1.dev13.dist-info → jaxsim-0.6.2.dev102.dist-info}/top_level.txt +0 -0
jaxsim/api/model.py
CHANGED
@@ -5,7 +5,6 @@ import dataclasses
|
|
5
5
|
import functools
|
6
6
|
import pathlib
|
7
7
|
from collections.abc import Sequence
|
8
|
-
from typing import Any
|
9
8
|
|
10
9
|
import jax
|
11
10
|
import jax.numpy as jnp
|
@@ -30,21 +29,28 @@ class JaxSimModel(JaxsimDataclass):
|
|
30
29
|
The JaxSim model defining the kinematics and dynamics of a robot.
|
31
30
|
"""
|
32
31
|
|
32
|
+
# link_spatial_inertial_matrices, motion_subspaces
|
33
|
+
|
33
34
|
model_name: Static[str]
|
34
35
|
|
35
|
-
time_step:
|
36
|
-
|
36
|
+
time_step: float = dataclasses.field(
|
37
|
+
default=0.001,
|
37
38
|
)
|
38
39
|
|
39
40
|
terrain: Static[jaxsim.terrain.Terrain] = dataclasses.field(
|
40
41
|
default_factory=jaxsim.terrain.FlatTerrain.build, repr=False
|
41
42
|
)
|
42
43
|
|
43
|
-
|
44
|
+
gravity: Static[float] = jaxsim.math.STANDARD_GRAVITY
|
45
|
+
|
44
46
|
contact_model: Static[jaxsim.rbda.contacts.ContactModel | None] = dataclasses.field(
|
45
47
|
default=None, repr=False
|
46
48
|
)
|
47
49
|
|
50
|
+
contacts_params: Static[jaxsim.rbda.contacts.ContactsParams] = dataclasses.field(
|
51
|
+
default=None, repr=False
|
52
|
+
)
|
53
|
+
|
48
54
|
kin_dyn_parameters: js.kin_dyn_parameters.KinDynParameters | None = (
|
49
55
|
dataclasses.field(default=None, repr=False)
|
50
56
|
)
|
@@ -53,10 +59,6 @@ class JaxSimModel(JaxsimDataclass):
|
|
53
59
|
default=None, repr=False
|
54
60
|
)
|
55
61
|
|
56
|
-
integrator: Static[jaxsim.integrators.Integrator | None] = dataclasses.field(
|
57
|
-
default=None, repr=False
|
58
|
-
)
|
59
|
-
|
60
62
|
_description: Static[wrappers.HashlessObject[ModelDescription | None]] = (
|
61
63
|
dataclasses.field(default=None, repr=False)
|
62
64
|
)
|
@@ -89,7 +91,7 @@ class JaxSimModel(JaxsimDataclass):
|
|
89
91
|
return hash(
|
90
92
|
(
|
91
93
|
hash(self.model_name),
|
92
|
-
hash(
|
94
|
+
hash(self.time_step),
|
93
95
|
hash(self.kin_dyn_parameters),
|
94
96
|
hash(self.contact_model),
|
95
97
|
)
|
@@ -106,11 +108,9 @@ class JaxSimModel(JaxsimDataclass):
|
|
106
108
|
*,
|
107
109
|
model_name: str | None = None,
|
108
110
|
time_step: jtp.FloatLike | None = None,
|
109
|
-
integrator: (
|
110
|
-
jaxsim.integrators.Integrator | type[jaxsim.integrators.Integrator] | None
|
111
|
-
) = None,
|
112
111
|
terrain: jaxsim.terrain.Terrain | None = None,
|
113
112
|
contact_model: jaxsim.rbda.contacts.ContactModel | None = None,
|
113
|
+
contact_params: jaxsim.rbda.contacts.ContactsParams | None = None,
|
114
114
|
is_urdf: bool | None = None,
|
115
115
|
considered_joints: Sequence[str] | None = None,
|
116
116
|
) -> JaxSimModel:
|
@@ -130,10 +130,7 @@ class JaxSimModel(JaxsimDataclass):
|
|
130
130
|
contact_model:
|
131
131
|
The contact model to consider.
|
132
132
|
If not specified, a soft contacts model is used.
|
133
|
-
|
134
|
-
The integrator to use. If not specified, a default one is used.
|
135
|
-
This argument can either be a pre-built integrator instance or one
|
136
|
-
of the integrator classes defined in JaxSim.
|
133
|
+
contact_params: The parameters of the contact model.
|
137
134
|
is_urdf:
|
138
135
|
The optional flag to force the model description to be parsed as a URDF.
|
139
136
|
This is usually automatically inferred.
|
@@ -164,9 +161,9 @@ class JaxSimModel(JaxsimDataclass):
|
|
164
161
|
model_description=intermediate_description,
|
165
162
|
model_name=model_name,
|
166
163
|
time_step=time_step,
|
167
|
-
integrator=integrator,
|
168
164
|
terrain=terrain,
|
169
165
|
contact_model=contact_model,
|
166
|
+
contacts_params=contact_params,
|
170
167
|
)
|
171
168
|
|
172
169
|
# Store the origin of the model, in case downstream logic needs it.
|
@@ -182,11 +179,10 @@ class JaxSimModel(JaxsimDataclass):
|
|
182
179
|
*,
|
183
180
|
model_name: str | None = None,
|
184
181
|
time_step: jtp.FloatLike | None = None,
|
185
|
-
integrator: (
|
186
|
-
jaxsim.integrators.Integrator | type[jaxsim.integrators.Integrator] | None
|
187
|
-
) = None,
|
188
182
|
terrain: jaxsim.terrain.Terrain | None = None,
|
189
183
|
contact_model: jaxsim.rbda.contacts.ContactModel | None = None,
|
184
|
+
contacts_params: jaxsim.rbda.contacts.ContactsParams | None = None,
|
185
|
+
gravity: jtp.FloatLike = jaxsim.math.STANDARD_GRAVITY,
|
190
186
|
) -> JaxSimModel:
|
191
187
|
"""
|
192
188
|
Build a Model object from an intermediate model description.
|
@@ -202,13 +198,11 @@ class JaxSimModel(JaxsimDataclass):
|
|
202
198
|
manually overridden in the function that steps the simulation.
|
203
199
|
terrain: The terrain to consider (the default is a flat infinite plane).
|
204
200
|
The optional name of the model overriding the physics model name.
|
205
|
-
integrator:
|
206
|
-
The integrator to use. If not specified, a default one is used.
|
207
|
-
This argument can either be a pre-built integrator instance or one
|
208
|
-
of the integrator classes defined in JaxSim.
|
209
201
|
contact_model:
|
210
202
|
The contact model to consider.
|
211
203
|
If not specified, a soft contacts model is used.
|
204
|
+
contacts_params: The parameters of the soft contacts.
|
205
|
+
gravity: The gravity constant.
|
212
206
|
|
213
207
|
Returns:
|
214
208
|
The built Model object.
|
@@ -228,7 +222,7 @@ class JaxSimModel(JaxsimDataclass):
|
|
228
222
|
time_step = (
|
229
223
|
time_step
|
230
224
|
if time_step is not None
|
231
|
-
else JaxSimModel.__dataclass_fields__["time_step"].
|
225
|
+
else JaxSimModel.__dataclass_fields__["time_step"].default
|
232
226
|
)
|
233
227
|
|
234
228
|
# Create the default contact model.
|
@@ -237,39 +231,11 @@ class JaxSimModel(JaxsimDataclass):
|
|
237
231
|
contact_model = (
|
238
232
|
contact_model
|
239
233
|
if contact_model is not None
|
240
|
-
else jaxsim.rbda.contacts.
|
234
|
+
else jaxsim.rbda.contacts.RelaxedRigidContacts.build()
|
241
235
|
)
|
242
236
|
|
243
|
-
|
244
|
-
|
245
|
-
|
246
|
-
# If None, build a default integrator.
|
247
|
-
case None:
|
248
|
-
|
249
|
-
integrator = jaxsim.integrators.fixed_step.Heun2SO3.build(
|
250
|
-
dynamics=js.ode.wrap_system_dynamics_for_integration(
|
251
|
-
system_dynamics=js.ode.system_dynamics
|
252
|
-
)
|
253
|
-
)
|
254
|
-
|
255
|
-
# If it's a pre-built integrator (also a custom one from the user)
|
256
|
-
# just use it as is.
|
257
|
-
case _ if isinstance(integrator, jaxsim.integrators.Integrator):
|
258
|
-
pass
|
259
|
-
|
260
|
-
# If an integrator class is passed, assume that it is a JaxSim integrator
|
261
|
-
# and build it with the default system dynamics.
|
262
|
-
case _ if issubclass(integrator, jaxsim.integrators.Integrator):
|
263
|
-
|
264
|
-
integrator_cls = integrator
|
265
|
-
integrator = integrator_cls.build(
|
266
|
-
dynamics=js.ode.wrap_system_dynamics_for_integration(
|
267
|
-
system_dynamics=js.ode.system_dynamics
|
268
|
-
)
|
269
|
-
)
|
270
|
-
|
271
|
-
case _:
|
272
|
-
raise ValueError(f"Invalid integrator: {integrator}")
|
237
|
+
if contacts_params is None:
|
238
|
+
contacts_params = contact_model._parameters_class()
|
273
239
|
|
274
240
|
# Build the model.
|
275
241
|
model = cls(
|
@@ -280,7 +246,8 @@ class JaxSimModel(JaxsimDataclass):
|
|
280
246
|
time_step=time_step,
|
281
247
|
terrain=terrain,
|
282
248
|
contact_model=contact_model,
|
283
|
-
|
249
|
+
contacts_params=contacts_params,
|
250
|
+
gravity=gravity,
|
284
251
|
# The following is wrapped as hashless since it's a static argument, and we
|
285
252
|
# don't want to trigger recompilation if it changes. All relevant parameters
|
286
253
|
# needed to compute kinematics and dynamics quantities are stored in the
|
@@ -350,7 +317,7 @@ class JaxSimModel(JaxsimDataclass):
|
|
350
317
|
True if the model is floating-base, False otherwise.
|
351
318
|
"""
|
352
319
|
|
353
|
-
return
|
320
|
+
return self.kin_dyn_parameters.joint_model.joint_dofs[0] == 6
|
354
321
|
|
355
322
|
def base_link(self) -> str:
|
356
323
|
"""
|
@@ -381,7 +348,7 @@ class JaxSimModel(JaxsimDataclass):
|
|
381
348
|
the number of joints. In the future, this could be different.
|
382
349
|
"""
|
383
350
|
|
384
|
-
return
|
351
|
+
return sum(self.kin_dyn_parameters.joint_model.joint_dofs[1:])
|
385
352
|
|
386
353
|
def joint_names(self) -> tuple[str, ...]:
|
387
354
|
"""
|
@@ -464,7 +431,7 @@ def reduce(
|
|
464
431
|
for joint_name in set(model.joint_names()) - set(considered_joints):
|
465
432
|
j = intermediate_description.joints_dict[joint_name]
|
466
433
|
with j.mutable_context():
|
467
|
-
j.initial_position =
|
434
|
+
j.initial_position = locked_joint_positions.get(joint_name, 0.0)
|
468
435
|
|
469
436
|
# Reduce the model description.
|
470
437
|
# If `considered_joints` contains joints not existing in the model,
|
@@ -480,7 +447,6 @@ def reduce(
|
|
480
447
|
time_step=model.time_step,
|
481
448
|
terrain=model.terrain,
|
482
449
|
contact_model=model.contact_model,
|
483
|
-
integrator=model.integrator,
|
484
450
|
)
|
485
451
|
|
486
452
|
# Store the origin of the model, in case downstream logic needs it.
|
@@ -534,31 +500,6 @@ def link_spatial_inertia_matrices(model: JaxSimModel) -> jtp.Array:
|
|
534
500
|
# ==============================
|
535
501
|
|
536
502
|
|
537
|
-
@jax.jit
|
538
|
-
@js.common.named_scope
|
539
|
-
def forward_kinematics(model: JaxSimModel, data: js.data.JaxSimModelData) -> jtp.Array:
|
540
|
-
"""
|
541
|
-
Compute the SE(3) transforms from the world frame to the frames of all links.
|
542
|
-
|
543
|
-
Args:
|
544
|
-
model: The model to consider.
|
545
|
-
data: The data of the considered model.
|
546
|
-
|
547
|
-
Returns:
|
548
|
-
A (nL, 4, 4) array containing the stacked SE(3) transforms of the links.
|
549
|
-
The first axis is the link index.
|
550
|
-
"""
|
551
|
-
|
552
|
-
W_H_LL = jaxsim.rbda.forward_kinematics_model(
|
553
|
-
model=model,
|
554
|
-
base_position=data.base_position(),
|
555
|
-
base_quaternion=data.base_orientation(dcm=False),
|
556
|
-
joint_positions=data.joint_positions(model=model),
|
557
|
-
)
|
558
|
-
|
559
|
-
return jnp.atleast_3d(W_H_LL).astype(float)
|
560
|
-
|
561
|
-
|
562
503
|
@functools.partial(jax.jit, static_argnames=["output_vel_repr"])
|
563
504
|
def generalized_free_floating_jacobian(
|
564
505
|
model: JaxSimModel,
|
@@ -592,7 +533,7 @@ def generalized_free_floating_jacobian(
|
|
592
533
|
# Compute the doubly-left free-floating full jacobian.
|
593
534
|
B_J_full_WX_B, B_H_L = jaxsim.rbda.jacobian_full_doubly_left(
|
594
535
|
model=model,
|
595
|
-
joint_positions=data.joint_positions
|
536
|
+
joint_positions=data.joint_positions,
|
596
537
|
)
|
597
538
|
|
598
539
|
# ======================================================================
|
@@ -603,7 +544,7 @@ def generalized_free_floating_jacobian(
|
|
603
544
|
|
604
545
|
case VelRepr.Inertial:
|
605
546
|
|
606
|
-
W_H_B = data.
|
547
|
+
W_H_B = data._base_transform
|
607
548
|
B_X_W = Adjoint.from_transform(transform=W_H_B, inverse=True)
|
608
549
|
|
609
550
|
B_J_full_WX_I = B_J_full_WX_W = ( # noqa: F841
|
@@ -617,7 +558,7 @@ def generalized_free_floating_jacobian(
|
|
617
558
|
|
618
559
|
case VelRepr.Mixed:
|
619
560
|
|
620
|
-
W_R_B = data.base_orientation
|
561
|
+
W_R_B = jaxsim.math.Quaternion.to_dcm(data.base_orientation)
|
621
562
|
BW_H_B = jnp.eye(4).at[0:3, 0:3].set(W_R_B)
|
622
563
|
B_X_BW = Adjoint.from_transform(transform=BW_H_B, inverse=True)
|
623
564
|
|
@@ -651,7 +592,7 @@ def generalized_free_floating_jacobian(
|
|
651
592
|
|
652
593
|
case VelRepr.Inertial:
|
653
594
|
|
654
|
-
W_H_B = data.
|
595
|
+
W_H_B = data._base_transform
|
655
596
|
W_X_B = jaxsim.math.Adjoint.from_transform(W_H_B)
|
656
597
|
|
657
598
|
O_J_WL_I = W_J_WL_I = jax.vmap( # noqa: F841
|
@@ -669,7 +610,7 @@ def generalized_free_floating_jacobian(
|
|
669
610
|
|
670
611
|
case VelRepr.Mixed:
|
671
612
|
|
672
|
-
W_H_B = data.
|
613
|
+
W_H_B = data._base_transform
|
673
614
|
|
674
615
|
LW_H_L = jax.vmap(
|
675
616
|
lambda B_H_L: (W_H_B @ B_H_L).at[0:3, 3].set(jnp.zeros(3))
|
@@ -718,15 +659,15 @@ def generalized_free_floating_jacobian_derivative(
|
|
718
659
|
# Compute the derivative of the doubly-left free-floating full jacobian.
|
719
660
|
B_J̇_full_WX_B, B_H_L = jaxsim.rbda.jacobian_derivative_full_doubly_left(
|
720
661
|
model=model,
|
721
|
-
joint_positions=data.joint_positions
|
722
|
-
joint_velocities=data.joint_velocities
|
662
|
+
joint_positions=data.joint_positions,
|
663
|
+
joint_velocities=data.joint_velocities,
|
723
664
|
)
|
724
665
|
|
725
666
|
# The derivative of the equation to change the input and output representations
|
726
667
|
# of the Jacobian derivative needs the computation of the plain link Jacobian.
|
727
668
|
B_J_full_WL_B, _ = jaxsim.rbda.jacobian_full_doubly_left(
|
728
669
|
model=model,
|
729
|
-
joint_positions=data.joint_positions
|
670
|
+
joint_positions=data.joint_positions,
|
730
671
|
)
|
731
672
|
|
732
673
|
# Compute the actual doubly-left free-floating jacobian derivative of the link
|
@@ -734,7 +675,7 @@ def generalized_free_floating_jacobian_derivative(
|
|
734
675
|
κb = model.kin_dyn_parameters.support_body_array_bool
|
735
676
|
|
736
677
|
# Compute the base transform.
|
737
|
-
W_H_B = data.
|
678
|
+
W_H_B = data._base_transform
|
738
679
|
|
739
680
|
# We add the 5 columns of ones to the Jacobian derivative to account for the
|
740
681
|
# base velocity and acceleration (5 + number of links = 6 + number of joints).
|
@@ -758,7 +699,7 @@ def generalized_free_floating_jacobian_derivative(
|
|
758
699
|
|
759
700
|
B_X_W = jaxsim.math.Adjoint.from_transform(transform=W_H_B, inverse=True)
|
760
701
|
|
761
|
-
W_v_WB = data.base_velocity
|
702
|
+
W_v_WB = data.base_velocity
|
762
703
|
B_Ẋ_W = -B_X_W @ jaxsim.math.Cross.vx(W_v_WB)
|
763
704
|
|
764
705
|
# Compute the operator to change the representation of ν, and its
|
@@ -784,7 +725,7 @@ def generalized_free_floating_jacobian_derivative(
|
|
784
725
|
BW_H_B = W_H_B.at[0:3, 3].set(jnp.zeros(3))
|
785
726
|
B_X_BW = jaxsim.math.Adjoint.from_transform(transform=BW_H_B, inverse=True)
|
786
727
|
|
787
|
-
BW_v_WB = data.base_velocity
|
728
|
+
BW_v_WB = data.base_velocity
|
788
729
|
BW_v_W_BW = BW_v_WB.at[3:6].set(jnp.zeros(3))
|
789
730
|
|
790
731
|
BW_v_BW_B = BW_v_WB - BW_v_W_BW
|
@@ -809,7 +750,7 @@ def generalized_free_floating_jacobian_derivative(
|
|
809
750
|
O_X_B = W_X_B = jaxsim.math.Adjoint.from_transform(transform=W_H_B)
|
810
751
|
|
811
752
|
with data.switch_velocity_representation(VelRepr.Body):
|
812
|
-
B_v_WB = data.base_velocity
|
753
|
+
B_v_WB = data.base_velocity
|
813
754
|
|
814
755
|
O_Ẋ_B = W_Ẋ_B = W_X_B @ jaxsim.math.Cross.vx(B_v_WB) # noqa: F841
|
815
756
|
|
@@ -822,9 +763,9 @@ def generalized_free_floating_jacobian_derivative(
|
|
822
763
|
B_X_L = jaxsim.math.Adjoint.inverse(adjoint=L_X_B)
|
823
764
|
|
824
765
|
with data.switch_velocity_representation(VelRepr.Body):
|
825
|
-
B_v_WB = data.base_velocity
|
766
|
+
B_v_WB = data.base_velocity
|
826
767
|
L_v_WL = jnp.einsum(
|
827
|
-
"b6j,j->b6", L_X_B @ B_J_WL_B, data.generalized_velocity
|
768
|
+
"b6j,j->b6", L_X_B @ B_J_WL_B, data.generalized_velocity
|
828
769
|
)
|
829
770
|
|
830
771
|
O_Ẋ_B = L_Ẋ_B = -L_X_B @ jaxsim.math.Cross.vx( # noqa: F841
|
@@ -842,7 +783,7 @@ def generalized_free_floating_jacobian_derivative(
|
|
842
783
|
B_X_LW = jaxsim.math.Adjoint.inverse(adjoint=LW_X_B)
|
843
784
|
|
844
785
|
with data.switch_velocity_representation(VelRepr.Body):
|
845
|
-
B_v_WB = data.base_velocity
|
786
|
+
B_v_WB = data.base_velocity
|
846
787
|
|
847
788
|
with data.switch_velocity_representation(VelRepr.Mixed):
|
848
789
|
BW_H_B = W_H_B.at[0:3, 3].set(jnp.zeros(3))
|
@@ -852,7 +793,7 @@ def generalized_free_floating_jacobian_derivative(
|
|
852
793
|
LW_X_B,
|
853
794
|
B_J_WL_B
|
854
795
|
@ jax.scipy.linalg.block_diag(B_X_BW, jnp.eye(model.dofs()))
|
855
|
-
@ data.generalized_velocity
|
796
|
+
@ data.generalized_velocity,
|
856
797
|
)
|
857
798
|
|
858
799
|
LW_v_W_LW = LW_v_WL.at[:, 3:6].set(jnp.zeros_like(LW_v_WL[:, 3:6]))
|
@@ -954,7 +895,7 @@ def forward_dynamics_aba(
|
|
954
895
|
τ = (
|
955
896
|
jnp.atleast_1d(joint_forces.squeeze())
|
956
897
|
if joint_forces is not None
|
957
|
-
else jnp.zeros_like(data.joint_positions
|
898
|
+
else jnp.zeros_like(data.joint_positions)
|
958
899
|
)
|
959
900
|
|
960
901
|
# Build link forces, if not provided.
|
@@ -973,22 +914,17 @@ def forward_dynamics_aba(
|
|
973
914
|
velocity_representation=data.velocity_representation,
|
974
915
|
)
|
975
916
|
|
976
|
-
# Extract the link and joint serializations.
|
977
|
-
link_names = model.link_names()
|
978
|
-
joint_names = model.joint_names()
|
979
|
-
|
980
917
|
# Extract the state in inertial-fixed representation.
|
981
918
|
with data.switch_velocity_representation(VelRepr.Inertial):
|
982
|
-
W_p_B = data.base_position
|
983
|
-
W_v_WB = data.base_velocity
|
984
|
-
W_Q_B = data.base_orientation
|
985
|
-
s = data.joint_positions
|
986
|
-
ṡ = data.joint_velocities
|
919
|
+
W_p_B = data.base_position
|
920
|
+
W_v_WB = data.base_velocity
|
921
|
+
W_Q_B = data.base_orientation
|
922
|
+
s = data.joint_positions
|
923
|
+
ṡ = data.joint_velocities
|
987
924
|
|
988
925
|
# Extract the inputs in inertial-fixed representation.
|
989
|
-
|
990
|
-
|
991
|
-
τ = references.joint_force_references(model=model, joint_names=joint_names)
|
926
|
+
W_f_L = references._link_forces
|
927
|
+
τ = references._joint_force_references
|
992
928
|
|
993
929
|
# ========================
|
994
930
|
# Compute forward dynamics
|
@@ -1004,7 +940,7 @@ def forward_dynamics_aba(
|
|
1004
940
|
joint_velocities=ṡ,
|
1005
941
|
joint_forces=τ,
|
1006
942
|
link_forces=W_f_L,
|
1007
|
-
standard_gravity=
|
943
|
+
standard_gravity=model.gravity,
|
1008
944
|
)
|
1009
945
|
|
1010
946
|
# =============
|
@@ -1032,14 +968,14 @@ def forward_dynamics_aba(
|
|
1032
968
|
|
1033
969
|
case VelRepr.Body:
|
1034
970
|
# In this case C=B
|
1035
|
-
W_H_C = W_H_B = data.
|
971
|
+
W_H_C = W_H_B = data._base_transform
|
1036
972
|
W_v_WC = W_v_WB
|
1037
973
|
|
1038
974
|
case VelRepr.Mixed:
|
1039
975
|
# In this case C=B[W]
|
1040
|
-
W_H_B = data.
|
976
|
+
W_H_B = data._base_transform
|
1041
977
|
W_H_C = W_H_BW = W_H_B.at[0:3, 0:3].set(jnp.eye(3)) # noqa: F841
|
1042
|
-
W_ṗ_B = data.base_velocity
|
978
|
+
W_ṗ_B = data.base_velocity[0:3]
|
1043
979
|
W_v_WC = W_v_W_BW = jnp.zeros(6).at[0:3].set(W_ṗ_B) # noqa: F841
|
1044
980
|
|
1045
981
|
case _:
|
@@ -1103,7 +1039,7 @@ def forward_dynamics_crb(
|
|
1103
1039
|
τ = (
|
1104
1040
|
jnp.atleast_1d(joint_forces)
|
1105
1041
|
if joint_forces is not None
|
1106
|
-
else jnp.zeros_like(data.joint_positions
|
1042
|
+
else jnp.zeros_like(data.joint_positions)
|
1107
1043
|
)
|
1108
1044
|
|
1109
1045
|
# Build external forces if not provided.
|
@@ -1174,7 +1110,7 @@ def free_floating_mass_matrix(
|
|
1174
1110
|
|
1175
1111
|
M_body = jaxsim.rbda.crba(
|
1176
1112
|
model=model,
|
1177
|
-
joint_positions=data.
|
1113
|
+
joint_positions=data.joint_positions,
|
1178
1114
|
)
|
1179
1115
|
|
1180
1116
|
match data.velocity_representation:
|
@@ -1183,16 +1119,14 @@ def free_floating_mass_matrix(
|
|
1183
1119
|
|
1184
1120
|
case VelRepr.Inertial:
|
1185
1121
|
|
1186
|
-
B_X_W = Adjoint.from_transform(
|
1187
|
-
transform=data.base_transform(), inverse=True
|
1188
|
-
)
|
1122
|
+
B_X_W = Adjoint.from_transform(transform=data._base_transform, inverse=True)
|
1189
1123
|
invT = jax.scipy.linalg.block_diag(B_X_W, jnp.eye(model.dofs()))
|
1190
1124
|
|
1191
1125
|
return invT.T @ M_body @ invT
|
1192
1126
|
|
1193
1127
|
case VelRepr.Mixed:
|
1194
1128
|
|
1195
|
-
BW_H_B = data.
|
1129
|
+
BW_H_B = data._base_transform.at[0:3, 3].set(jnp.zeros(3))
|
1196
1130
|
B_X_BW = Adjoint.from_transform(transform=BW_H_B, inverse=True)
|
1197
1131
|
invT = jax.scipy.linalg.block_diag(B_X_BW, jnp.eye(model.dofs()))
|
1198
1132
|
|
@@ -1228,7 +1162,7 @@ def free_floating_coriolis_matrix(
|
|
1228
1162
|
# to the active representation stored in data.
|
1229
1163
|
with data.switch_velocity_representation(VelRepr.Body):
|
1230
1164
|
|
1231
|
-
B_ν = data.generalized_velocity
|
1165
|
+
B_ν = data.generalized_velocity
|
1232
1166
|
|
1233
1167
|
# Doubly-left free-floating Jacobian.
|
1234
1168
|
L_J_WL_B = generalized_free_floating_jacobian(model=model, data=data)
|
@@ -1275,12 +1209,12 @@ def free_floating_coriolis_matrix(
|
|
1275
1209
|
case VelRepr.Inertial:
|
1276
1210
|
|
1277
1211
|
n = model.dofs()
|
1278
|
-
W_H_B = data.
|
1212
|
+
W_H_B = data._base_transform
|
1279
1213
|
B_X_W = jaxsim.math.Adjoint.from_transform(W_H_B, inverse=True)
|
1280
1214
|
B_T_W = jax.scipy.linalg.block_diag(B_X_W, jnp.eye(n))
|
1281
1215
|
|
1282
1216
|
with data.switch_velocity_representation(VelRepr.Inertial):
|
1283
|
-
W_v_WB = data.base_velocity
|
1217
|
+
W_v_WB = data.base_velocity
|
1284
1218
|
B_Ẋ_W = -B_X_W @ jaxsim.math.Cross.vx(W_v_WB)
|
1285
1219
|
|
1286
1220
|
B_Ṫ_W = jax.scipy.linalg.block_diag(B_Ẋ_W, jnp.zeros(shape=(n, n)))
|
@@ -1295,12 +1229,12 @@ def free_floating_coriolis_matrix(
|
|
1295
1229
|
case VelRepr.Mixed:
|
1296
1230
|
|
1297
1231
|
n = model.dofs()
|
1298
|
-
BW_H_B = data.
|
1232
|
+
BW_H_B = data._base_transform.at[0:3, 3].set(jnp.zeros(3))
|
1299
1233
|
B_X_BW = jaxsim.math.Adjoint.from_transform(transform=BW_H_B, inverse=True)
|
1300
1234
|
B_T_BW = jax.scipy.linalg.block_diag(B_X_BW, jnp.eye(n))
|
1301
1235
|
|
1302
1236
|
with data.switch_velocity_representation(VelRepr.Mixed):
|
1303
|
-
BW_v_WB = data.base_velocity
|
1237
|
+
BW_v_WB = data.base_velocity
|
1304
1238
|
BW_v_W_BW = BW_v_WB.at[3:6].set(jnp.zeros(3))
|
1305
1239
|
|
1306
1240
|
BW_v_BW_B = BW_v_WB - BW_v_W_BW
|
@@ -1357,7 +1291,7 @@ def inverse_dynamics(
|
|
1357
1291
|
s̈ = (
|
1358
1292
|
jnp.atleast_1d(jnp.array(joint_accelerations).squeeze())
|
1359
1293
|
if joint_accelerations is not None
|
1360
|
-
else jnp.zeros_like(data.joint_positions
|
1294
|
+
else jnp.zeros_like(data.joint_positions)
|
1361
1295
|
)
|
1362
1296
|
|
1363
1297
|
# Build base acceleration, if not provided.
|
@@ -1394,14 +1328,14 @@ def inverse_dynamics(
|
|
1394
1328
|
W_v_WC = W_v_WW = jnp.zeros(6) # noqa: F841
|
1395
1329
|
|
1396
1330
|
case VelRepr.Body:
|
1397
|
-
W_H_C = W_H_B = data.
|
1331
|
+
W_H_C = W_H_B = data._base_transform
|
1398
1332
|
with data.switch_velocity_representation(VelRepr.Inertial):
|
1399
|
-
W_v_WC = W_v_WB = data.base_velocity
|
1333
|
+
W_v_WC = W_v_WB = data.base_velocity
|
1400
1334
|
|
1401
1335
|
case VelRepr.Mixed:
|
1402
|
-
W_H_B = data.
|
1336
|
+
W_H_B = data._base_transform
|
1403
1337
|
W_H_C = W_H_BW = W_H_B.at[0:3, 0:3].set(jnp.eye(3)) # noqa: F841
|
1404
|
-
W_ṗ_B = data.base_velocity
|
1338
|
+
W_ṗ_B = data.base_velocity[0:3]
|
1405
1339
|
W_v_WC = W_v_W_BW = jnp.zeros(6).at[0:3].set(W_ṗ_B) # noqa: F841
|
1406
1340
|
|
1407
1341
|
case _:
|
@@ -1413,7 +1347,7 @@ def inverse_dynamics(
|
|
1413
1347
|
W_v̇_WB = to_inertial(
|
1414
1348
|
C_v̇_WB=v̇_WB,
|
1415
1349
|
W_H_C=W_H_C,
|
1416
|
-
C_v_WB=data.base_velocity
|
1350
|
+
C_v_WB=data.base_velocity,
|
1417
1351
|
W_v_WC=W_v_WC,
|
1418
1352
|
)
|
1419
1353
|
|
@@ -1425,21 +1359,16 @@ def inverse_dynamics(
|
|
1425
1359
|
velocity_representation=data.velocity_representation,
|
1426
1360
|
)
|
1427
1361
|
|
1428
|
-
# Extract the link and joint serializations.
|
1429
|
-
link_names = model.link_names()
|
1430
|
-
joint_names = model.joint_names()
|
1431
|
-
|
1432
1362
|
# Extract the state in inertial-fixed representation.
|
1433
1363
|
with data.switch_velocity_representation(VelRepr.Inertial):
|
1434
|
-
W_p_B = data.base_position
|
1435
|
-
W_v_WB = data.base_velocity
|
1436
|
-
W_Q_B = data.
|
1437
|
-
s = data.joint_positions
|
1438
|
-
ṡ = data.joint_velocities
|
1364
|
+
W_p_B = data.base_position
|
1365
|
+
W_v_WB = data.base_velocity
|
1366
|
+
W_Q_B = data.base_quaternion
|
1367
|
+
s = data.joint_positions
|
1368
|
+
ṡ = data.joint_velocities
|
1439
1369
|
|
1440
1370
|
# Extract the inputs in inertial-fixed representation.
|
1441
|
-
|
1442
|
-
W_f_L = references.link_forces(model=model, data=data, link_names=link_names)
|
1371
|
+
W_f_L = references._link_forces
|
1443
1372
|
|
1444
1373
|
# ========================
|
1445
1374
|
# Compute inverse dynamics
|
@@ -1457,7 +1386,7 @@ def inverse_dynamics(
|
|
1457
1386
|
base_angular_acceleration=W_v̇_WB[3:6],
|
1458
1387
|
joint_accelerations=s̈,
|
1459
1388
|
link_forces=W_f_L,
|
1460
|
-
standard_gravity=
|
1389
|
+
standard_gravity=model.gravity,
|
1461
1390
|
)
|
1462
1391
|
|
1463
1392
|
# =============
|
@@ -1468,7 +1397,7 @@ def inverse_dynamics(
|
|
1468
1397
|
f_B = js.data.JaxSimModelData.inertial_to_other_representation(
|
1469
1398
|
array=W_f_B,
|
1470
1399
|
other_representation=data.velocity_representation,
|
1471
|
-
transform=data.
|
1400
|
+
transform=data._base_transform,
|
1472
1401
|
is_force=True,
|
1473
1402
|
).squeeze()
|
1474
1403
|
|
@@ -1497,21 +1426,12 @@ def free_floating_gravity_forces(
|
|
1497
1426
|
)
|
1498
1427
|
|
1499
1428
|
# Set just the generalized position.
|
1500
|
-
|
1501
|
-
|
1502
|
-
|
1503
|
-
|
1504
|
-
|
1505
|
-
|
1506
|
-
)
|
1507
|
-
|
1508
|
-
data_rnea.state.physics_model.base_quaternion = (
|
1509
|
-
data.state.physics_model.base_quaternion
|
1510
|
-
)
|
1511
|
-
|
1512
|
-
data_rnea.state.physics_model.joint_positions = (
|
1513
|
-
data.state.physics_model.joint_positions
|
1514
|
-
)
|
1429
|
+
data_rnea = data_rnea.replace(
|
1430
|
+
model=model,
|
1431
|
+
base_position=data.base_position,
|
1432
|
+
base_quaternion=data.base_quaternion,
|
1433
|
+
joint_positions=data.joint_positions,
|
1434
|
+
)
|
1515
1435
|
|
1516
1436
|
return jnp.hstack(
|
1517
1437
|
inverse_dynamics(
|
@@ -1548,35 +1468,20 @@ def free_floating_bias_forces(
|
|
1548
1468
|
)
|
1549
1469
|
|
1550
1470
|
# Set the generalized position and generalized velocity.
|
1551
|
-
|
1552
|
-
|
1553
|
-
|
1554
|
-
|
1555
|
-
|
1556
|
-
|
1557
|
-
|
1558
|
-
|
1559
|
-
|
1560
|
-
|
1561
|
-
|
1562
|
-
|
1563
|
-
|
1564
|
-
|
1565
|
-
)
|
1566
|
-
|
1567
|
-
data_rnea.state.physics_model.joint_velocities = (
|
1568
|
-
data.state.physics_model.joint_velocities
|
1569
|
-
)
|
1570
|
-
|
1571
|
-
# Make sure that base velocity is zero for fixed-base model.
|
1572
|
-
if model.floating_base():
|
1573
|
-
data_rnea.state.physics_model.base_linear_velocity = (
|
1574
|
-
data.state.physics_model.base_linear_velocity
|
1575
|
-
)
|
1576
|
-
|
1577
|
-
data_rnea.state.physics_model.base_angular_velocity = (
|
1578
|
-
data.state.physics_model.base_angular_velocity
|
1579
|
-
)
|
1471
|
+
base_linear_velocity, base_angular_velocity = None, None
|
1472
|
+
if model.floating_base():
|
1473
|
+
base_velocity = data.base_velocity
|
1474
|
+
base_linear_velocity = base_velocity[:3]
|
1475
|
+
base_angular_velocity = base_velocity[3:]
|
1476
|
+
data_rnea = data_rnea.replace(
|
1477
|
+
model=model,
|
1478
|
+
base_position=data.base_position,
|
1479
|
+
base_quaternion=data.base_quaternion,
|
1480
|
+
joint_positions=data.joint_positions,
|
1481
|
+
joint_velocities=data.joint_velocities,
|
1482
|
+
base_linear_velocity=base_linear_velocity,
|
1483
|
+
base_angular_velocity=base_angular_velocity,
|
1484
|
+
)
|
1580
1485
|
|
1581
1486
|
return jnp.hstack(
|
1582
1487
|
inverse_dynamics(
|
@@ -1628,7 +1533,7 @@ def total_momentum(model: JaxSimModel, data: js.data.JaxSimModelData) -> jtp.Vec
|
|
1628
1533
|
The total momentum of the model in the active velocity representation.
|
1629
1534
|
"""
|
1630
1535
|
|
1631
|
-
ν = data.generalized_velocity
|
1536
|
+
ν = data.generalized_velocity
|
1632
1537
|
Jh = total_momentum_jacobian(model=model, data=data)
|
1633
1538
|
|
1634
1539
|
return Jh @ ν
|
@@ -1668,13 +1573,11 @@ def total_momentum_jacobian(
|
|
1668
1573
|
B_Jh = B_Jh_B
|
1669
1574
|
|
1670
1575
|
case VelRepr.Inertial:
|
1671
|
-
B_X_W = Adjoint.from_transform(
|
1672
|
-
transform=data.base_transform(), inverse=True
|
1673
|
-
)
|
1576
|
+
B_X_W = Adjoint.from_transform(transform=data._base_transform, inverse=True)
|
1674
1577
|
B_Jh = B_Jh_B @ jax.scipy.linalg.block_diag(B_X_W, jnp.eye(model.dofs()))
|
1675
1578
|
|
1676
1579
|
case VelRepr.Mixed:
|
1677
|
-
BW_H_B = data.
|
1580
|
+
BW_H_B = data._base_transform.at[0:3, 3].set(jnp.zeros(3))
|
1678
1581
|
B_X_BW = Adjoint.from_transform(transform=BW_H_B, inverse=True)
|
1679
1582
|
B_Jh = B_Jh_B @ jax.scipy.linalg.block_diag(B_X_BW, jnp.eye(model.dofs()))
|
1680
1583
|
|
@@ -1686,14 +1589,14 @@ def total_momentum_jacobian(
|
|
1686
1589
|
return B_Jh
|
1687
1590
|
|
1688
1591
|
case VelRepr.Inertial:
|
1689
|
-
W_H_B = data.
|
1592
|
+
W_H_B = data._base_transform
|
1690
1593
|
B_Xv_W = Adjoint.from_transform(transform=W_H_B, inverse=True)
|
1691
1594
|
W_Xf_B = B_Xv_W.T
|
1692
1595
|
W_Jh = W_Xf_B @ B_Jh
|
1693
1596
|
return W_Jh
|
1694
1597
|
|
1695
1598
|
case VelRepr.Mixed:
|
1696
|
-
BW_H_B = data.
|
1599
|
+
BW_H_B = data._base_transform.at[0:3, 3].set(jnp.zeros(3))
|
1697
1600
|
B_Xv_BW = Adjoint.from_transform(transform=BW_H_B, inverse=True)
|
1698
1601
|
BW_Xf_B = B_Xv_BW.T
|
1699
1602
|
BW_Jh = BW_Xf_B @ B_Jh
|
@@ -1718,7 +1621,7 @@ def average_velocity(model: JaxSimModel, data: js.data.JaxSimModelData) -> jtp.V
|
|
1718
1621
|
in the active representation.
|
1719
1622
|
"""
|
1720
1623
|
|
1721
|
-
ν = data.generalized_velocity
|
1624
|
+
ν = data.generalized_velocity
|
1722
1625
|
J = average_velocity_jacobian(model=model, data=data)
|
1723
1626
|
|
1724
1627
|
return J @ ν
|
@@ -1766,9 +1669,9 @@ def average_velocity_jacobian(
|
|
1766
1669
|
case VelRepr.Body:
|
1767
1670
|
|
1768
1671
|
GB_J = G_J
|
1769
|
-
W_p_B = data.base_position
|
1672
|
+
W_p_B = data.base_position
|
1770
1673
|
W_p_CoM = js.com.com_position(model=model, data=data)
|
1771
|
-
B_R_W = data.base_orientation
|
1674
|
+
B_R_W = jaxsim.math.Quaternion.to_dcm(data.base_orientation).transpose()
|
1772
1675
|
|
1773
1676
|
B_H_GB = jnp.eye(4).at[0:3, 3].set(B_R_W @ (W_p_CoM - W_p_B))
|
1774
1677
|
B_X_GB = Adjoint.from_transform(transform=B_H_GB)
|
@@ -1778,7 +1681,7 @@ def average_velocity_jacobian(
|
|
1778
1681
|
case VelRepr.Mixed:
|
1779
1682
|
|
1780
1683
|
GW_J = G_J
|
1781
|
-
W_p_B = data.base_position
|
1684
|
+
W_p_B = data.base_position
|
1782
1685
|
W_p_CoM = js.com.com_position(model=model, data=data)
|
1783
1686
|
|
1784
1687
|
BW_H_GW = jnp.eye(4).at[0:3, 3].set(W_p_CoM - W_p_B)
|
@@ -1819,7 +1722,7 @@ def link_bias_accelerations(
|
|
1819
1722
|
# ================================================
|
1820
1723
|
|
1821
1724
|
# Compute the base transform.
|
1822
|
-
W_H_B = data.
|
1725
|
+
W_H_B = data._base_transform
|
1823
1726
|
|
1824
1727
|
def other_representation_to_inertial(
|
1825
1728
|
C_v̇_WB: jtp.Vector, C_v_WB: jtp.Vector, W_H_C: jtp.Matrix, W_v_WC: jtp.Vector
|
@@ -1846,25 +1749,25 @@ def link_bias_accelerations(
|
|
1846
1749
|
W_H_C = W_H_W = jnp.eye(4) # noqa: F841
|
1847
1750
|
W_v_WC = W_v_WW = jnp.zeros(6) # noqa: F841
|
1848
1751
|
with data.switch_velocity_representation(VelRepr.Inertial):
|
1849
|
-
C_v_WB = W_v_WB = data.base_velocity
|
1752
|
+
C_v_WB = W_v_WB = data.base_velocity
|
1850
1753
|
|
1851
1754
|
case VelRepr.Body:
|
1852
1755
|
W_H_C = W_H_B
|
1853
1756
|
with data.switch_velocity_representation(VelRepr.Inertial):
|
1854
|
-
W_v_WC = W_v_WB = data.base_velocity
|
1757
|
+
W_v_WC = W_v_WB = data.base_velocity # noqa: F841
|
1855
1758
|
with data.switch_velocity_representation(VelRepr.Body):
|
1856
|
-
C_v_WB = B_v_WB = data.base_velocity
|
1759
|
+
C_v_WB = B_v_WB = data.base_velocity
|
1857
1760
|
|
1858
1761
|
case VelRepr.Mixed:
|
1859
1762
|
W_H_BW = W_H_B.at[0:3, 0:3].set(jnp.eye(3))
|
1860
1763
|
W_H_C = W_H_BW
|
1861
1764
|
with data.switch_velocity_representation(VelRepr.Mixed):
|
1862
|
-
W_ṗ_B = data.base_velocity
|
1765
|
+
W_ṗ_B = data.base_velocity[0:3]
|
1863
1766
|
BW_v_W_BW = jnp.zeros(6).at[0:3].set(W_ṗ_B)
|
1864
1767
|
W_X_BW = jaxsim.math.Adjoint.from_transform(transform=W_H_BW)
|
1865
1768
|
W_v_WC = W_v_W_BW = W_X_BW @ BW_v_W_BW # noqa: F841
|
1866
1769
|
with data.switch_velocity_representation(VelRepr.Mixed):
|
1867
|
-
C_v_WB = BW_v_WB = data.base_velocity
|
1770
|
+
C_v_WB = BW_v_WB = data.base_velocity # noqa: F841
|
1868
1771
|
|
1869
1772
|
case _:
|
1870
1773
|
raise ValueError(data.velocity_representation)
|
@@ -1888,20 +1791,23 @@ def link_bias_accelerations(
|
|
1888
1791
|
# Compute the parent-to-child adjoints and the motion subspaces of the joints.
|
1889
1792
|
# These transforms define the relative kinematics of the entire model, including
|
1890
1793
|
# the base transform for both floating-base and fixed-base models.
|
1891
|
-
i_X_λi
|
1892
|
-
joint_positions=data.joint_positions
|
1794
|
+
i_X_λi = model.kin_dyn_parameters.joint_transforms(
|
1795
|
+
joint_positions=data.joint_positions, base_transform=W_H_B
|
1893
1796
|
)
|
1894
1797
|
|
1798
|
+
# Extract the joint motion subspaces.
|
1799
|
+
S = model.kin_dyn_parameters.motion_subspaces
|
1800
|
+
|
1895
1801
|
# Allocate the buffer to store the body-fixed link velocities.
|
1896
1802
|
L_v_WL = jnp.zeros(shape=(model.number_of_links(), 6))
|
1897
1803
|
|
1898
1804
|
# Store the base velocity.
|
1899
1805
|
with data.switch_velocity_representation(VelRepr.Body):
|
1900
|
-
B_v_WB = data.base_velocity
|
1806
|
+
B_v_WB = data.base_velocity
|
1901
1807
|
L_v_WL = L_v_WL.at[0].set(B_v_WB)
|
1902
1808
|
|
1903
1809
|
# Get the joint velocities.
|
1904
|
-
ṡ = data.joint_velocities
|
1810
|
+
ṡ = data.joint_velocities
|
1905
1811
|
|
1906
1812
|
# Allocate the buffer to store the body-fixed link accelerations,
|
1907
1813
|
# and initialize the base acceleration.
|
@@ -1980,11 +1886,11 @@ def link_bias_accelerations(
|
|
1980
1886
|
)
|
1981
1887
|
|
1982
1888
|
case VelRepr.Inertial:
|
1983
|
-
C_H_L = W_H_L =
|
1889
|
+
C_H_L = W_H_L = data._link_transforms
|
1984
1890
|
L_v_CL = L_v_WL
|
1985
1891
|
|
1986
1892
|
case VelRepr.Mixed:
|
1987
|
-
W_H_L =
|
1893
|
+
W_H_L = data._link_transforms
|
1988
1894
|
LW_H_L = jax.vmap(lambda W_H_L: W_H_L.at[0:3, 3].set(jnp.zeros(3)))(W_H_L)
|
1989
1895
|
C_H_L = LW_H_L
|
1990
1896
|
L_v_CL = L_v_LW_L = jax.vmap( # noqa: F841
|
@@ -2002,77 +1908,6 @@ def link_bias_accelerations(
|
|
2002
1908
|
return O_v̇_WL
|
2003
1909
|
|
2004
1910
|
|
2005
|
-
@jax.jit
|
2006
|
-
@js.common.named_scope
|
2007
|
-
def link_contact_forces(
|
2008
|
-
model: js.model.JaxSimModel,
|
2009
|
-
data: js.data.JaxSimModelData,
|
2010
|
-
*,
|
2011
|
-
link_forces: jtp.MatrixLike | None = None,
|
2012
|
-
joint_force_references: jtp.VectorLike | None = None,
|
2013
|
-
**kwargs,
|
2014
|
-
) -> jtp.Matrix:
|
2015
|
-
"""
|
2016
|
-
Compute the 6D contact forces of all links of the model.
|
2017
|
-
|
2018
|
-
Args:
|
2019
|
-
model: The model to consider.
|
2020
|
-
data: The data of the considered model.
|
2021
|
-
link_forces:
|
2022
|
-
The 6D external forces to apply to the links expressed in the same
|
2023
|
-
representation of data.
|
2024
|
-
joint_force_references:
|
2025
|
-
The joint force references to apply to the joints.
|
2026
|
-
kwargs: Additional keyword arguments to pass to the active contact model..
|
2027
|
-
|
2028
|
-
Returns:
|
2029
|
-
A `(nL, 6)` array containing the stacked 6D contact forces of the links,
|
2030
|
-
expressed in the frame corresponding to the active representation.
|
2031
|
-
"""
|
2032
|
-
|
2033
|
-
# Note: the following code should be kept in sync with the function
|
2034
|
-
# `jaxsim.api.ode.system_velocity_dynamics`. We cannot merge them since
|
2035
|
-
# there we need to get also aux_data.
|
2036
|
-
|
2037
|
-
# Build link forces if not provided.
|
2038
|
-
# These forces are expressed in the frame corresponding to the velocity
|
2039
|
-
# representation of data.
|
2040
|
-
O_f_L = (
|
2041
|
-
jnp.atleast_2d(link_forces.squeeze())
|
2042
|
-
if link_forces is not None
|
2043
|
-
else jnp.zeros((model.number_of_links(), 6))
|
2044
|
-
).astype(float)
|
2045
|
-
|
2046
|
-
# Build joint force references if not provided.
|
2047
|
-
joint_force_references = (
|
2048
|
-
jnp.atleast_1d(joint_force_references)
|
2049
|
-
if joint_force_references is not None
|
2050
|
-
else jnp.zeros(model.dofs())
|
2051
|
-
)
|
2052
|
-
|
2053
|
-
# We expect that the 6D forces included in the `link_forces` argument are expressed
|
2054
|
-
# in the frame corresponding to the velocity representation of `data`.
|
2055
|
-
input_references = js.references.JaxSimModelReferences.build(
|
2056
|
-
model=model,
|
2057
|
-
data=data,
|
2058
|
-
velocity_representation=data.velocity_representation,
|
2059
|
-
link_forces=O_f_L,
|
2060
|
-
joint_force_references=joint_force_references,
|
2061
|
-
)
|
2062
|
-
|
2063
|
-
# Compute the 6D forces applied to the links equivalent to the forces applied
|
2064
|
-
# to the frames associated to the collidable points.
|
2065
|
-
f_L, _ = model.contact_model.compute_link_contact_forces(
|
2066
|
-
model=model,
|
2067
|
-
data=data,
|
2068
|
-
link_forces=input_references.link_forces(model=model, data=data),
|
2069
|
-
joint_force_references=input_references.joint_force_references(),
|
2070
|
-
**kwargs,
|
2071
|
-
)
|
2072
|
-
|
2073
|
-
return f_L
|
2074
|
-
|
2075
|
-
|
2076
1911
|
# ======
|
2077
1912
|
# Energy
|
2078
1913
|
# ======
|
@@ -2113,7 +1948,7 @@ def kinetic_energy(model: JaxSimModel, data: js.data.JaxSimModelData) -> jtp.Flo
|
|
2113
1948
|
"""
|
2114
1949
|
|
2115
1950
|
with data.switch_velocity_representation(velocity_representation=VelRepr.Body):
|
2116
|
-
B_ν = data.generalized_velocity
|
1951
|
+
B_ν = data.generalized_velocity
|
2117
1952
|
M_B = free_floating_mass_matrix(model=model, data=data)
|
2118
1953
|
|
2119
1954
|
K = 0.5 * B_ν.T @ M_B @ B_ν
|
@@ -2135,11 +1970,8 @@ def potential_energy(model: JaxSimModel, data: js.data.JaxSimModelData) -> jtp.F
|
|
2135
1970
|
"""
|
2136
1971
|
|
2137
1972
|
m = total_mass(model=model)
|
2138
|
-
gravity = data.gravity.squeeze()
|
2139
1973
|
W_p̃_CoM = jnp.hstack([js.com.com_position(model=model, data=data), 1])
|
2140
|
-
|
2141
|
-
U = -jnp.hstack([gravity, 0]) @ (m * W_p̃_CoM)
|
2142
|
-
return U.squeeze().astype(float)
|
1974
|
+
return jnp.sum((m * W_p̃_CoM)[2] * model.gravity)
|
2143
1975
|
|
2144
1976
|
|
2145
1977
|
# ==========
|
@@ -2153,34 +1985,22 @@ def step(
|
|
2153
1985
|
model: JaxSimModel,
|
2154
1986
|
data: js.data.JaxSimModelData,
|
2155
1987
|
*,
|
2156
|
-
t0: jtp.FloatLike = 0.0,
|
2157
|
-
dt: jtp.FloatLike | None = None,
|
2158
|
-
integrator: jaxsim.integrators.Integrator | None = None,
|
2159
|
-
integrator_metadata: dict[str, Any] | None = None,
|
2160
1988
|
link_forces: jtp.MatrixLike | None = None,
|
2161
1989
|
joint_force_references: jtp.VectorLike | None = None,
|
2162
|
-
|
2163
|
-
) -> tuple[js.data.JaxSimModelData, dict[str, Any]]:
|
1990
|
+
) -> js.data.JaxSimModelData:
|
2164
1991
|
"""
|
2165
1992
|
Perform a simulation step.
|
2166
1993
|
|
2167
1994
|
Args:
|
2168
1995
|
model: The model to consider.
|
2169
1996
|
data: The data of the considered model.
|
2170
|
-
integrator: The integrator to use.
|
2171
|
-
integrator_metadata: The metadata of the integrator, if needed.
|
2172
|
-
t0: The initial time to consider. Only relevant for time-dependent dynamics.
|
2173
1997
|
dt: The time step to consider. If not specified, it is read from the model.
|
2174
1998
|
link_forces:
|
2175
|
-
The 6D forces to apply to the links expressed in
|
2176
|
-
the velocity representation of `data`.
|
1999
|
+
The 6D forces to apply to the links expressed in same representation of data.
|
2177
2000
|
joint_force_references: The joint force references to consider.
|
2178
|
-
kwargs: Additional kwargs to pass to the integrator.
|
2179
2001
|
|
2180
2002
|
Returns:
|
2181
|
-
|
2182
|
-
data computed during the step. If the integrator has metadata, the dictionary
|
2183
|
-
will contain the new metadata stored in the `integrator_metadata` key.
|
2003
|
+
The new data of the model after the simulation step.
|
2184
2004
|
|
2185
2005
|
Note:
|
2186
2006
|
In order to reduce the occurrences of frame conversions performed internally,
|
@@ -2188,167 +2008,84 @@ def step(
|
|
2188
2008
|
particularly useful for automatically differentiated logic.
|
2189
2009
|
"""
|
2190
2010
|
|
2191
|
-
#
|
2192
|
-
#
|
2193
|
-
# kwargs of this step function.
|
2194
|
-
kwargs = kwargs if kwargs is not None else {}
|
2195
|
-
integrator_kwargs = kwargs.pop("integrator_kwargs", {})
|
2196
|
-
integrator_kwargs = kwargs | integrator_kwargs
|
2197
|
-
|
2198
|
-
# Extract the integrator and the optional metadata.
|
2199
|
-
integrator_metadata_t0 = integrator_metadata
|
2200
|
-
integrator = integrator if integrator is not None else model.integrator
|
2201
|
-
|
2202
|
-
# Initialize the time-related variables.
|
2203
|
-
state_t0 = data.state
|
2204
|
-
t0 = jnp.array(t0, dtype=float)
|
2205
|
-
dt = jnp.array(dt if dt is not None else model.time_step).astype(float)
|
2206
|
-
|
2207
|
-
# The visco-elastic contacts operate at best with their own integrator.
|
2208
|
-
# They can be used with Euler-like integrators, paying the price of ignoring
|
2209
|
-
# some of the benefits of continuous-time integration on the system position.
|
2210
|
-
# Furthermore, the requirement to know the Δt used by the integrator is not
|
2211
|
-
# compatible with high-order integrators, that use advanced RK stages to evaluate
|
2212
|
-
# the dynamics at intermediate times.
|
2213
|
-
module = jaxsim.rbda.contacts.visco_elastic.step.__module__
|
2214
|
-
name = jaxsim.rbda.contacts.visco_elastic.step.__name__
|
2215
|
-
msg = "You need to use the custom '{}.{}' function with this contact model."
|
2216
|
-
jaxsim.exceptions.raise_runtime_error_if(
|
2217
|
-
condition=(
|
2218
|
-
isinstance(model.contact_model, jaxsim.rbda.contacts.ViscoElasticContacts)
|
2219
|
-
& (
|
2220
|
-
~jnp.allclose(dt, model.time_step)
|
2221
|
-
| ~int(
|
2222
|
-
isinstance(integrator, jaxsim.integrators.fixed_step.ForwardEuler)
|
2223
|
-
)
|
2224
|
-
)
|
2225
|
-
),
|
2226
|
-
msg=msg.format(module, name),
|
2227
|
-
)
|
2011
|
+
# TODO: some contact models here may want to perform a dynamic filtering of
|
2012
|
+
# the enabled collidable points
|
2228
2013
|
|
2229
|
-
#
|
2230
|
-
|
2231
|
-
|
2014
|
+
# Extract the inputs
|
2015
|
+
O_f_L_external = jnp.atleast_2d(
|
2016
|
+
jnp.array(link_forces, dtype=float).squeeze()
|
2017
|
+
if link_forces is not None
|
2018
|
+
else jnp.zeros((model.number_of_links(), 6))
|
2019
|
+
)
|
2232
2020
|
|
2233
|
-
#
|
2234
|
-
|
2021
|
+
# Get the external forces in inertial-fixed representation.
|
2022
|
+
W_f_L_external = jax.vmap(
|
2023
|
+
lambda f_L, W_H_L: js.data.JaxSimModelData.other_representation_to_inertial(
|
2024
|
+
f_L,
|
2025
|
+
other_representation=data.velocity_representation,
|
2026
|
+
transform=W_H_L,
|
2027
|
+
is_force=True,
|
2028
|
+
)
|
2029
|
+
)(O_f_L_external, data._link_transforms)
|
2235
2030
|
|
2236
|
-
|
2237
|
-
|
2238
|
-
|
2239
|
-
|
2240
|
-
model=model,
|
2241
|
-
data=data,
|
2242
|
-
velocity_representation=data.velocity_representation,
|
2243
|
-
link_forces=link_forces,
|
2244
|
-
joint_force_references=joint_force_references,
|
2031
|
+
τ_references = jnp.atleast_1d(
|
2032
|
+
jnp.array(joint_force_references, dtype=float).squeeze()
|
2033
|
+
if joint_force_references is not None
|
2034
|
+
else jnp.zeros(model.dofs())
|
2245
2035
|
)
|
2246
2036
|
|
2247
|
-
#
|
2248
|
-
#
|
2249
|
-
#
|
2037
|
+
# ================================
|
2038
|
+
# Compute the total joint torques
|
2039
|
+
# ================================
|
2250
2040
|
|
2251
|
-
|
2252
|
-
|
2253
|
-
|
2254
|
-
f_L = references.link_forces(model=model, data=data)
|
2255
|
-
τ_references = references.joint_force_references(model=model)
|
2256
|
-
|
2257
|
-
# Step the dynamics forward.
|
2258
|
-
state_tf, integrator_metadata_tf = integrator.step(
|
2259
|
-
x0=state_t0,
|
2260
|
-
t0=t0,
|
2261
|
-
dt=dt,
|
2262
|
-
metadata=integrator_metadata_t0,
|
2263
|
-
# Always inject the current (model, data) pair into the system dynamics
|
2264
|
-
# considered by the integrator, and include the input variables represented
|
2265
|
-
# by the pair (f_L, τ_references).
|
2266
|
-
# Note that the wrapper of the system dynamics will override (state_x0, t0)
|
2267
|
-
# inside the passed data even if it is not strictly needed. This logic is
|
2268
|
-
# necessary to reuse the jit-compiled step function of compatible pytrees
|
2269
|
-
# of model and data produced e.g. by parameterized applications.
|
2270
|
-
**(
|
2271
|
-
dict(
|
2272
|
-
model=model,
|
2273
|
-
data=data,
|
2274
|
-
link_forces=f_L,
|
2275
|
-
joint_force_references=τ_references,
|
2276
|
-
)
|
2277
|
-
| integrator_kwargs
|
2278
|
-
),
|
2041
|
+
τ_total = js.actuation_model.compute_resultant_torques(
|
2042
|
+
model, data, joint_force_references=τ_references
|
2279
2043
|
)
|
2280
2044
|
|
2281
|
-
#
|
2282
|
-
|
2045
|
+
# ======================
|
2046
|
+
# Compute contact forces
|
2047
|
+
# ======================
|
2283
2048
|
|
2284
|
-
|
2285
|
-
# Phase 3: post-step
|
2286
|
-
# ==================
|
2049
|
+
W_f_L_terrain = jnp.zeros_like(W_f_L_external)
|
2287
2050
|
|
2288
|
-
|
2289
|
-
match model.contact_model:
|
2051
|
+
if len(model.kin_dyn_parameters.contact_parameters.body) > 0:
|
2290
2052
|
|
2291
|
-
#
|
2292
|
-
#
|
2293
|
-
|
2294
|
-
|
2053
|
+
# Compute the 6D forces W_f ∈ ℝ^{n_L × 6} applied to links due to contact
|
2054
|
+
# with the terrain.
|
2055
|
+
W_f_L_terrain = js.contact_model.link_contact_forces(
|
2056
|
+
model=model,
|
2057
|
+
data=data,
|
2058
|
+
link_forces=W_f_L_external,
|
2059
|
+
joint_torques=τ_total,
|
2060
|
+
)
|
2295
2061
|
|
2296
|
-
|
2297
|
-
|
2298
|
-
|
2299
|
-
condition=isinstance(
|
2300
|
-
integrator,
|
2301
|
-
jaxsim.integrators.fixed_step.ForwardEuler
|
2302
|
-
| jaxsim.integrators.fixed_step.ForwardEulerSO3,
|
2303
|
-
)
|
2304
|
-
& ((data_tf.contacts_params.K > 0) | (data_tf.contacts_params.D > 0)),
|
2305
|
-
msg="Baumgarte stabilization is not supported with ForwardEuler integrators",
|
2306
|
-
)
|
2062
|
+
# ==============================
|
2063
|
+
# Compute the total link forces
|
2064
|
+
# ==============================
|
2307
2065
|
|
2308
|
-
|
2309
|
-
indices_of_enabled_collidable_points = (
|
2310
|
-
model.kin_dyn_parameters.contact_parameters.indices_of_enabled_collidable_points
|
2311
|
-
)
|
2066
|
+
W_f_L_total = W_f_L_external + W_f_L_terrain
|
2312
2067
|
|
2313
|
-
|
2314
|
-
|
2315
|
-
|
2316
|
-
|
2317
|
-
# Compute the penetration depth of the collidable points.
|
2318
|
-
δ, *_ = jax.vmap(
|
2319
|
-
jaxsim.rbda.contacts.common.compute_penetration_data,
|
2320
|
-
in_axes=(0, 0, None),
|
2321
|
-
)(W_p_C, jnp.zeros_like(W_p_C), model.terrain)
|
2322
|
-
|
2323
|
-
with data_tf.switch_velocity_representation(VelRepr.Mixed):
|
2324
|
-
J_WC = js.contact.jacobian(model, data_tf)[
|
2325
|
-
indices_of_enabled_collidable_points
|
2326
|
-
]
|
2327
|
-
M = js.model.free_floating_mass_matrix(model, data_tf)
|
2328
|
-
BW_ν_pre_impact = data_tf.generalized_velocity()
|
2329
|
-
|
2330
|
-
# Compute the impact velocity.
|
2331
|
-
# It may be discontinuous in case new contacts are made.
|
2332
|
-
BW_ν_post_impact = (
|
2333
|
-
jaxsim.rbda.contacts.RigidContacts.compute_impact_velocity(
|
2334
|
-
generalized_velocity=BW_ν_pre_impact,
|
2335
|
-
inactive_collidable_points=(δ <= 0),
|
2336
|
-
M=M,
|
2337
|
-
J_WC=J_WC,
|
2338
|
-
)
|
2339
|
-
)
|
2068
|
+
# ===============================
|
2069
|
+
# Compute the system acceleration
|
2070
|
+
# ===============================
|
2340
2071
|
|
2341
|
-
|
2342
|
-
|
2343
|
-
|
2072
|
+
with data.switch_velocity_representation(jaxsim.VelRepr.Inertial):
|
2073
|
+
W_v̇_WB, s̈ = js.ode.system_acceleration(
|
2074
|
+
model=model,
|
2075
|
+
data=data,
|
2076
|
+
link_forces=W_f_L_total,
|
2077
|
+
joint_torques=τ_total,
|
2078
|
+
)
|
2344
2079
|
|
2345
|
-
#
|
2346
|
-
|
2347
|
-
|
2348
|
-
)
|
2080
|
+
# =============================
|
2081
|
+
# Advance the simulation state
|
2082
|
+
# =============================
|
2349
2083
|
|
2350
|
-
|
2351
|
-
|
2352
|
-
|
2353
|
-
|
2084
|
+
data_tf = js.integrators.semi_implicit_euler_integration(
|
2085
|
+
model=model,
|
2086
|
+
data=data,
|
2087
|
+
base_acceleration_inertial=W_v̇_WB,
|
2088
|
+
joint_accelerations=s̈,
|
2354
2089
|
)
|
2090
|
+
|
2091
|
+
return data_tf
|