jaxsim 0.3.1.dev62__py3-none-any.whl → 0.3.1.dev94__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jaxsim/__init__.py +5 -5
- jaxsim/_version.py +2 -2
- jaxsim/api/com.py +3 -4
- jaxsim/api/common.py +11 -11
- jaxsim/api/contact.py +11 -3
- jaxsim/api/data.py +3 -6
- jaxsim/api/frame.py +9 -10
- jaxsim/api/kin_dyn_parameters.py +25 -28
- jaxsim/api/link.py +12 -12
- jaxsim/api/model.py +47 -43
- jaxsim/api/ode.py +19 -12
- jaxsim/api/ode_data.py +11 -11
- jaxsim/integrators/common.py +19 -29
- jaxsim/integrators/fixed_step.py +10 -10
- jaxsim/integrators/variable_step.py +13 -13
- jaxsim/math/__init__.py +2 -1
- jaxsim/math/joint_model.py +2 -1
- jaxsim/math/quaternion.py +3 -9
- jaxsim/math/transform.py +2 -2
- jaxsim/mujoco/loaders.py +5 -5
- jaxsim/mujoco/model.py +6 -6
- jaxsim/mujoco/visualizer.py +3 -0
- jaxsim/parsers/__init__.py +0 -1
- jaxsim/parsers/descriptions/joint.py +1 -1
- jaxsim/parsers/descriptions/link.py +3 -4
- jaxsim/parsers/descriptions/model.py +1 -1
- jaxsim/parsers/kinematic_graph.py +38 -39
- jaxsim/parsers/rod/parser.py +14 -14
- jaxsim/parsers/rod/utils.py +9 -11
- jaxsim/rbda/aba.py +6 -12
- jaxsim/rbda/collidable_points.py +8 -7
- jaxsim/rbda/contacts/soft.py +29 -27
- jaxsim/rbda/crba.py +3 -3
- jaxsim/rbda/forward_kinematics.py +1 -1
- jaxsim/rbda/jacobian.py +8 -8
- jaxsim/rbda/rnea.py +3 -3
- jaxsim/rbda/utils.py +1 -1
- jaxsim/terrain/terrain.py +100 -22
- jaxsim/typing.py +14 -22
- jaxsim/utils/jaxsim_dataclass.py +4 -4
- jaxsim/utils/wrappers.py +5 -1
- {jaxsim-0.3.1.dev62.dist-info → jaxsim-0.3.1.dev94.dist-info}/METADATA +1 -1
- jaxsim-0.3.1.dev94.dist-info/RECORD +68 -0
- {jaxsim-0.3.1.dev62.dist-info → jaxsim-0.3.1.dev94.dist-info}/WHEEL +1 -1
- jaxsim-0.3.1.dev62.dist-info/RECORD +0 -68
- {jaxsim-0.3.1.dev62.dist-info → jaxsim-0.3.1.dev94.dist-info}/LICENSE +0 -0
- {jaxsim-0.3.1.dev62.dist-info → jaxsim-0.3.1.dev94.dist-info}/top_level.txt +0 -0
jaxsim/api/model.py
CHANGED
@@ -9,14 +9,14 @@ from typing import Any, Sequence
|
|
9
9
|
import jax
|
10
10
|
import jax.numpy as jnp
|
11
11
|
import jax_dataclasses
|
12
|
-
import jaxlie
|
13
12
|
import rod
|
14
13
|
from jax_dataclasses import Static
|
15
14
|
|
16
15
|
import jaxsim.api as js
|
17
|
-
import jaxsim.
|
16
|
+
import jaxsim.terrain
|
18
17
|
import jaxsim.typing as jtp
|
19
|
-
from jaxsim.math import Cross
|
18
|
+
from jaxsim.math import Adjoint, Cross
|
19
|
+
from jaxsim.parsers.descriptions import ModelDescription
|
20
20
|
from jaxsim.utils import JaxsimDataclass, Mutability, wrappers
|
21
21
|
|
22
22
|
from .common import VelRepr
|
@@ -46,12 +46,12 @@ class JaxSimModel(JaxsimDataclass):
|
|
46
46
|
default=None, repr=False
|
47
47
|
)
|
48
48
|
|
49
|
-
_description: Static[
|
50
|
-
|
51
|
-
|
49
|
+
_description: Static[wrappers.HashlessObject[ModelDescription | None]] = (
|
50
|
+
dataclasses.field(default=None, repr=False)
|
51
|
+
)
|
52
52
|
|
53
53
|
@property
|
54
|
-
def description(self) ->
|
54
|
+
def description(self) -> ModelDescription:
|
55
55
|
return self._description.get()
|
56
56
|
|
57
57
|
def __eq__(self, other: JaxSimModel) -> bool:
|
@@ -116,7 +116,7 @@ class JaxSimModel(JaxsimDataclass):
|
|
116
116
|
import jaxsim.parsers.rod
|
117
117
|
|
118
118
|
# Parse the input resource (either a path to file or a string with the URDF/SDF)
|
119
|
-
# and build the -intermediate- model description
|
119
|
+
# and build the -intermediate- model description.
|
120
120
|
intermediate_description = jaxsim.parsers.rod.build_model_description(
|
121
121
|
model_description=model_description, is_urdf=is_urdf
|
122
122
|
)
|
@@ -128,7 +128,7 @@ class JaxSimModel(JaxsimDataclass):
|
|
128
128
|
considered_joints=considered_joints
|
129
129
|
)
|
130
130
|
|
131
|
-
# Build the model
|
131
|
+
# Build the model.
|
132
132
|
model = JaxSimModel.build(
|
133
133
|
model_description=intermediate_description,
|
134
134
|
model_name=model_name,
|
@@ -136,7 +136,7 @@ class JaxSimModel(JaxsimDataclass):
|
|
136
136
|
contact_model=contact_model,
|
137
137
|
)
|
138
138
|
|
139
|
-
# Store the origin of the model, in case downstream logic needs it
|
139
|
+
# Store the origin of the model, in case downstream logic needs it.
|
140
140
|
with model.mutable_context(mutability=Mutability.MUTABLE_NO_VALIDATION):
|
141
141
|
model.built_from = model_description
|
142
142
|
|
@@ -144,7 +144,7 @@ class JaxSimModel(JaxsimDataclass):
|
|
144
144
|
|
145
145
|
@staticmethod
|
146
146
|
def build(
|
147
|
-
model_description:
|
147
|
+
model_description: ModelDescription,
|
148
148
|
model_name: str | None = None,
|
149
149
|
*,
|
150
150
|
terrain: jaxsim.terrain.Terrain | None = None,
|
@@ -169,14 +169,14 @@ class JaxSimModel(JaxsimDataclass):
|
|
169
169
|
"""
|
170
170
|
from jaxsim.rbda.contacts.soft import SoftContacts
|
171
171
|
|
172
|
-
# Set the model name (if not provided, use the one from the model description)
|
172
|
+
# Set the model name (if not provided, use the one from the model description).
|
173
173
|
model_name = model_name if model_name is not None else model_description.name
|
174
174
|
|
175
|
-
# Set the terrain (if not provided, use the default flat terrain)
|
175
|
+
# Set the terrain (if not provided, use the default flat terrain).
|
176
176
|
terrain = terrain or JaxSimModel.__dataclass_fields__["terrain"].default
|
177
177
|
contact_model = contact_model or SoftContacts(terrain=terrain)
|
178
178
|
|
179
|
-
# Build the model
|
179
|
+
# Build the model.
|
180
180
|
model = JaxSimModel(
|
181
181
|
model_name=model_name,
|
182
182
|
_description=wrappers.HashlessObject(obj=model_description),
|
@@ -361,7 +361,7 @@ def reduce(
|
|
361
361
|
considered_joints=list(considered_joints)
|
362
362
|
)
|
363
363
|
|
364
|
-
# Build the reduced model
|
364
|
+
# Build the reduced model.
|
365
365
|
reduced_model = JaxSimModel.build(
|
366
366
|
model_description=reduced_intermediate_description,
|
367
367
|
model_name=model.name(),
|
@@ -369,7 +369,7 @@ def reduce(
|
|
369
369
|
contact_model=model.contact_model,
|
370
370
|
)
|
371
371
|
|
372
|
-
# Store the origin of the model, in case downstream logic needs it
|
372
|
+
# Store the origin of the model, in case downstream logic needs it.
|
373
373
|
with reduced_model.mutable_context(mutability=Mutability.MUTABLE_NO_VALIDATION):
|
374
374
|
reduced_model.built_from = model.built_from
|
375
375
|
|
@@ -493,7 +493,7 @@ def generalized_free_floating_jacobian(
|
|
493
493
|
case VelRepr.Inertial:
|
494
494
|
|
495
495
|
W_H_B = data.base_transform()
|
496
|
-
B_X_W =
|
496
|
+
B_X_W = Adjoint.from_transform(transform=W_H_B, inverse=True)
|
497
497
|
|
498
498
|
B_J_full_WX_I = B_J_full_WX_W = B_J_full_WX_B @ jax.scipy.linalg.block_diag(
|
499
499
|
B_X_W, jnp.eye(model.dofs())
|
@@ -507,7 +507,7 @@ def generalized_free_floating_jacobian(
|
|
507
507
|
|
508
508
|
W_R_B = data.base_orientation(dcm=True)
|
509
509
|
BW_H_B = jnp.eye(4).at[0:3, 0:3].set(W_R_B)
|
510
|
-
B_X_BW =
|
510
|
+
B_X_BW = Adjoint.from_transform(transform=BW_H_B, inverse=True)
|
511
511
|
|
512
512
|
B_J_full_WX_I = B_J_full_WX_BW = (
|
513
513
|
B_J_full_WX_B
|
@@ -715,7 +715,7 @@ def forward_dynamics_aba(
|
|
715
715
|
|
716
716
|
# In Mixed representation, we need to include a cross product in ℝ⁶.
|
717
717
|
# In Inertial and Body representations, the cross product is always zero.
|
718
|
-
C_X_W =
|
718
|
+
C_X_W = Adjoint.from_transform(transform=W_H_C, inverse=True)
|
719
719
|
return C_X_W @ (W_v̇_WB - Cross.vx(W_v_WC) @ W_v_WB)
|
720
720
|
|
721
721
|
match data.velocity_representation:
|
@@ -797,21 +797,21 @@ def forward_dynamics_crb(
|
|
797
797
|
# Prepare data
|
798
798
|
# ============
|
799
799
|
|
800
|
-
# Build joint torques if not provided
|
800
|
+
# Build joint torques if not provided.
|
801
801
|
τ = (
|
802
802
|
jnp.atleast_1d(joint_forces)
|
803
803
|
if joint_forces is not None
|
804
804
|
else jnp.zeros_like(data.joint_positions())
|
805
805
|
)
|
806
806
|
|
807
|
-
# Build external forces if not provided
|
807
|
+
# Build external forces if not provided.
|
808
808
|
f = (
|
809
809
|
jnp.atleast_2d(link_forces)
|
810
810
|
if link_forces is not None
|
811
811
|
else jnp.zeros(shape=(model.number_of_links(), 6))
|
812
812
|
)
|
813
813
|
|
814
|
-
# Compute terms of the floating-base EoM
|
814
|
+
# Compute terms of the floating-base EoM.
|
815
815
|
M = free_floating_mass_matrix(model=model, data=data)
|
816
816
|
h = free_floating_bias_forces(model=model, data=data)
|
817
817
|
S = jnp.block([jnp.zeros(shape=(model.dofs(), 6)), jnp.eye(model.dofs())]).T
|
@@ -848,7 +848,7 @@ def forward_dynamics_crb(
|
|
848
848
|
# 6D transformation X.
|
849
849
|
v̇_WB = ν̇[0:6].squeeze().astype(float)
|
850
850
|
|
851
|
-
# Extract the joint accelerations
|
851
|
+
# Extract the joint accelerations.
|
852
852
|
s̈ = jnp.atleast_1d(ν̇[6:].squeeze()).astype(float)
|
853
853
|
|
854
854
|
return v̇_WB, s̈
|
@@ -880,7 +880,9 @@ def free_floating_mass_matrix(
|
|
880
880
|
|
881
881
|
case VelRepr.Inertial:
|
882
882
|
|
883
|
-
B_X_W =
|
883
|
+
B_X_W = Adjoint.from_transform(
|
884
|
+
transform=data.base_transform(), inverse=True
|
885
|
+
)
|
884
886
|
invT = jax.scipy.linalg.block_diag(B_X_W, jnp.eye(model.dofs()))
|
885
887
|
|
886
888
|
return invT.T @ M_body @ invT
|
@@ -888,7 +890,7 @@ def free_floating_mass_matrix(
|
|
888
890
|
case VelRepr.Mixed:
|
889
891
|
|
890
892
|
BW_H_B = data.base_transform().at[0:3, 3].set(jnp.zeros(3))
|
891
|
-
B_X_BW =
|
893
|
+
B_X_BW = Adjoint.from_transform(transform=BW_H_B, inverse=True)
|
892
894
|
invT = jax.scipy.linalg.block_diag(B_X_BW, jnp.eye(model.dofs()))
|
893
895
|
|
894
896
|
return invT.T @ M_body @ invT
|
@@ -1077,8 +1079,8 @@ def inverse_dynamics(
|
|
1077
1079
|
expressed in a generic frame C to the inertial-fixed representation W_v̇_WB.
|
1078
1080
|
"""
|
1079
1081
|
|
1080
|
-
W_X_C =
|
1081
|
-
C_X_W =
|
1082
|
+
W_X_C = Adjoint.from_transform(transform=W_H_C)
|
1083
|
+
C_X_W = Adjoint.from_transform(transform=W_H_C, inverse=True)
|
1082
1084
|
C_v_WC = C_X_W @ W_v_WC
|
1083
1085
|
|
1084
1086
|
# In Mixed representation, we need to include a cross product in ℝ⁶.
|
@@ -1187,12 +1189,12 @@ def free_floating_gravity_forces(
|
|
1187
1189
|
The free-floating gravity forces of the model.
|
1188
1190
|
"""
|
1189
1191
|
|
1190
|
-
# Build a zeroed state
|
1192
|
+
# Build a zeroed state.
|
1191
1193
|
data_rnea = js.data.JaxSimModelData.zero(
|
1192
1194
|
model=model, velocity_representation=data.velocity_representation
|
1193
1195
|
)
|
1194
1196
|
|
1195
|
-
# Set just the generalized position
|
1197
|
+
# Set just the generalized position.
|
1196
1198
|
with data_rnea.mutable_context(
|
1197
1199
|
mutability=Mutability.MUTABLE, restore_after_exception=False
|
1198
1200
|
):
|
@@ -1237,12 +1239,12 @@ def free_floating_bias_forces(
|
|
1237
1239
|
The free-floating bias forces of the model.
|
1238
1240
|
"""
|
1239
1241
|
|
1240
|
-
# Build a zeroed state
|
1242
|
+
# Build a zeroed state.
|
1241
1243
|
data_rnea = js.data.JaxSimModelData.zero(
|
1242
1244
|
model=model, velocity_representation=data.velocity_representation
|
1243
1245
|
)
|
1244
1246
|
|
1245
|
-
# Set the generalized position and generalized velocity
|
1247
|
+
# Set the generalized position and generalized velocity.
|
1246
1248
|
with data_rnea.mutable_context(
|
1247
1249
|
mutability=Mutability.MUTABLE, restore_after_exception=False
|
1248
1250
|
):
|
@@ -1361,12 +1363,14 @@ def total_momentum_jacobian(
|
|
1361
1363
|
B_Jh = B_Jh_B
|
1362
1364
|
|
1363
1365
|
case VelRepr.Inertial:
|
1364
|
-
B_X_W =
|
1366
|
+
B_X_W = Adjoint.from_transform(
|
1367
|
+
transform=data.base_transform(), inverse=True
|
1368
|
+
)
|
1365
1369
|
B_Jh = B_Jh_B @ jax.scipy.linalg.block_diag(B_X_W, jnp.eye(model.dofs()))
|
1366
1370
|
|
1367
1371
|
case VelRepr.Mixed:
|
1368
1372
|
BW_H_B = data.base_transform().at[0:3, 3].set(jnp.zeros(3))
|
1369
|
-
B_X_BW =
|
1373
|
+
B_X_BW = Adjoint.from_transform(transform=BW_H_B, inverse=True)
|
1370
1374
|
B_Jh = B_Jh_B @ jax.scipy.linalg.block_diag(B_X_BW, jnp.eye(model.dofs()))
|
1371
1375
|
|
1372
1376
|
case _:
|
@@ -1378,14 +1382,14 @@ def total_momentum_jacobian(
|
|
1378
1382
|
|
1379
1383
|
case VelRepr.Inertial:
|
1380
1384
|
W_H_B = data.base_transform()
|
1381
|
-
B_Xv_W =
|
1385
|
+
B_Xv_W = Adjoint.from_transform(transform=W_H_B, inverse=True)
|
1382
1386
|
W_Xf_B = B_Xv_W.T
|
1383
1387
|
W_Jh = W_Xf_B @ B_Jh
|
1384
1388
|
return W_Jh
|
1385
1389
|
|
1386
1390
|
case VelRepr.Mixed:
|
1387
1391
|
BW_H_B = data.base_transform().at[0:3, 3].set(jnp.zeros(3))
|
1388
|
-
B_Xv_BW =
|
1392
|
+
B_Xv_BW = Adjoint.from_transform(transform=BW_H_B, inverse=True)
|
1389
1393
|
BW_Xf_B = B_Xv_BW.T
|
1390
1394
|
BW_Jh = BW_Xf_B @ B_Jh
|
1391
1395
|
return BW_Jh
|
@@ -1449,7 +1453,7 @@ def average_velocity_jacobian(
|
|
1449
1453
|
W_p_CoM = js.com.com_position(model=model, data=data)
|
1450
1454
|
|
1451
1455
|
W_H_GW = jnp.eye(4).at[0:3, 3].set(W_p_CoM)
|
1452
|
-
W_X_GW =
|
1456
|
+
W_X_GW = Adjoint.from_transform(transform=W_H_GW)
|
1453
1457
|
|
1454
1458
|
return W_X_GW @ GW_J
|
1455
1459
|
|
@@ -1461,7 +1465,7 @@ def average_velocity_jacobian(
|
|
1461
1465
|
B_R_W = data.base_orientation(dcm=True).transpose()
|
1462
1466
|
|
1463
1467
|
B_H_GB = jnp.eye(4).at[0:3, 3].set(B_R_W @ (W_p_CoM - W_p_B))
|
1464
|
-
B_X_GB =
|
1468
|
+
B_X_GB = Adjoint.from_transform(transform=B_H_GB)
|
1465
1469
|
|
1466
1470
|
return B_X_GB @ GB_J
|
1467
1471
|
|
@@ -1472,7 +1476,7 @@ def average_velocity_jacobian(
|
|
1472
1476
|
W_p_CoM = js.com.com_position(model=model, data=data)
|
1473
1477
|
|
1474
1478
|
BW_H_GW = jnp.eye(4).at[0:3, 3].set(W_p_CoM - W_p_B)
|
1475
|
-
BW_X_GW =
|
1479
|
+
BW_X_GW = Adjoint.from_transform(transform=BW_H_GW)
|
1476
1480
|
|
1477
1481
|
return BW_X_GW @ GW_J
|
1478
1482
|
|
@@ -1518,8 +1522,8 @@ def link_bias_accelerations(
|
|
1518
1522
|
expressed in a generic frame C to the inertial-fixed representation W_v̇_WB.
|
1519
1523
|
"""
|
1520
1524
|
|
1521
|
-
W_X_C =
|
1522
|
-
C_X_W =
|
1525
|
+
W_X_C = Adjoint.from_transform(transform=W_H_C)
|
1526
|
+
C_X_W = Adjoint.from_transform(transform=W_H_C, inverse=True)
|
1523
1527
|
|
1524
1528
|
# In Mixed representation, we need to include a cross product in ℝ⁶.
|
1525
1529
|
# In Inertial and Body representations, the cross product is always zero.
|
@@ -1606,7 +1610,7 @@ def link_bias_accelerations(
|
|
1606
1610
|
# not remove gravity during the propagation.
|
1607
1611
|
|
1608
1612
|
# Initialize the loop.
|
1609
|
-
Carry = tuple[jtp.
|
1613
|
+
Carry = tuple[jtp.Matrix, jtp.Matrix]
|
1610
1614
|
carry0: Carry = (L_v_WL, L_v̇_WL)
|
1611
1615
|
|
1612
1616
|
def propagate_accelerations(carry: Carry, i: jtp.Int) -> tuple[Carry, None]:
|
@@ -1832,8 +1836,8 @@ def step(
|
|
1832
1836
|
integrator_state: The state of the integrator.
|
1833
1837
|
joint_forces: The joint forces to consider.
|
1834
1838
|
link_forces:
|
1835
|
-
The
|
1836
|
-
|
1839
|
+
The 6D forces to apply to the links expressed in the frame corresponding to
|
1840
|
+
the velocity representation of `data`.
|
1837
1841
|
kwargs: Additional kwargs to pass to the integrator.
|
1838
1842
|
|
1839
1843
|
Returns:
|
jaxsim/api/ode.py
CHANGED
@@ -96,7 +96,9 @@ def system_velocity_dynamics(
|
|
96
96
|
model: The model to consider.
|
97
97
|
data: The data of the considered model.
|
98
98
|
joint_forces: The joint forces to apply.
|
99
|
-
link_forces:
|
99
|
+
link_forces:
|
100
|
+
The 6D forces to apply to the links expressed in the frame corresponding to
|
101
|
+
the velocity representation of `data`.
|
100
102
|
|
101
103
|
Returns:
|
102
104
|
A tuple containing the derivative of the base 6D velocity in inertial-fixed
|
@@ -105,14 +107,16 @@ def system_velocity_dynamics(
|
|
105
107
|
the system dynamics evaluation.
|
106
108
|
"""
|
107
109
|
|
108
|
-
# Build joint torques if not provided
|
110
|
+
# Build joint torques if not provided.
|
109
111
|
τ = (
|
110
112
|
jnp.atleast_1d(joint_forces.squeeze())
|
111
113
|
if joint_forces is not None
|
112
114
|
else jnp.zeros_like(data.joint_positions())
|
113
115
|
).astype(float)
|
114
116
|
|
115
|
-
# Build link forces if not provided
|
117
|
+
# Build link forces if not provided.
|
118
|
+
# These forces are expressed in the frame corresponding to the velocity
|
119
|
+
# representation of data.
|
116
120
|
O_f_L = (
|
117
121
|
jnp.atleast_2d(link_forces.squeeze())
|
118
122
|
if link_forces is not None
|
@@ -127,16 +131,17 @@ def system_velocity_dynamics(
|
|
127
131
|
# with the terrain.
|
128
132
|
W_f_Li_terrain = jnp.zeros_like(O_f_L).astype(float)
|
129
133
|
|
130
|
-
#
|
131
|
-
|
132
|
-
W_f_Ci = None
|
134
|
+
# Import privately the soft contacts classes.
|
135
|
+
from jaxsim.rbda.contacts.soft import SoftContactsState
|
133
136
|
|
134
137
|
# Initialize the derivative of the tangential deformation ṁ ∈ ℝ^{n_c × 3}.
|
138
|
+
assert isinstance(data.state.contact, SoftContactsState)
|
135
139
|
ṁ = jnp.zeros_like(data.state.contact.tangential_deformation).astype(float)
|
136
140
|
|
137
141
|
if len(model.kin_dyn_parameters.contact_parameters.body) > 0:
|
138
|
-
|
139
|
-
#
|
142
|
+
|
143
|
+
# Compute the 6D forces W_f ∈ ℝ^{n_c × 6} applied to each collidable point
|
144
|
+
# and the corresponding material deformation rates.
|
140
145
|
with data.switch_velocity_representation(VelRepr.Inertial):
|
141
146
|
W_f_Ci, ṁ = js.contact.collidable_point_dynamics(model=model, data=data)
|
142
147
|
|
@@ -178,7 +183,7 @@ def system_velocity_dynamics(
|
|
178
183
|
model.kin_dyn_parameters.joint_parameters.friction_viscous
|
179
184
|
).astype(float)
|
180
185
|
|
181
|
-
# Compute the joint friction torque
|
186
|
+
# Compute the joint friction torque.
|
182
187
|
τ_friction = -(
|
183
188
|
jnp.diag(kc) @ jnp.sign(data.state.physics_model.joint_velocities)
|
184
189
|
+ jnp.diag(kv) @ data.state.physics_model.joint_velocities
|
@@ -188,7 +193,7 @@ def system_velocity_dynamics(
|
|
188
193
|
# Compute forward dynamics
|
189
194
|
# ========================
|
190
195
|
|
191
|
-
# Compute the total joint forces
|
196
|
+
# Compute the total joint forces.
|
192
197
|
τ_total = τ + τ_friction + τ_position_limit
|
193
198
|
|
194
199
|
references = js.references.JaxSimModelReferences.build(
|
@@ -202,7 +207,7 @@ def system_velocity_dynamics(
|
|
202
207
|
with references.switch_velocity_representation(VelRepr.Inertial):
|
203
208
|
W_f_L = references.link_forces(model=model, data=data)
|
204
209
|
|
205
|
-
# Compute the total external 6D forces applied to the links
|
210
|
+
# Compute the total external 6D forces applied to the links.
|
206
211
|
W_f_L_total = W_f_L + W_f_Li_terrain
|
207
212
|
|
208
213
|
# - Joint accelerations: s̈ ∈ ℝⁿ
|
@@ -273,7 +278,9 @@ def system_dynamics(
|
|
273
278
|
model: The model to consider.
|
274
279
|
data: The data of the considered model.
|
275
280
|
joint_forces: The joint forces to apply.
|
276
|
-
link_forces:
|
281
|
+
link_forces:
|
282
|
+
The 6D forces to apply to the links expressed in the frame corresponding to
|
283
|
+
the velocity representation of `data`.
|
277
284
|
baumgarte_quaternion_regularization:
|
278
285
|
The Baumgarte regularization coefficient used to adjust the norm of the
|
279
286
|
quaternion (only used in integrators not operating on the SO(3) manifold).
|
jaxsim/api/ode_data.py
CHANGED
@@ -31,8 +31,8 @@ class ODEInput(JaxsimDataclass):
|
|
31
31
|
@staticmethod
|
32
32
|
def build_from_jaxsim_model(
|
33
33
|
model: js.model.JaxSimModel | None = None,
|
34
|
-
joint_forces: jtp.
|
35
|
-
link_forces: jtp.
|
34
|
+
joint_forces: jtp.VectorLike | None = None,
|
35
|
+
link_forces: jtp.MatrixLike | None = None,
|
36
36
|
) -> ODEInput:
|
37
37
|
"""
|
38
38
|
Build an `ODEInput` from a `JaxSimModel`.
|
@@ -160,7 +160,7 @@ class ODEState(JaxsimDataclass):
|
|
160
160
|
`JaxSimModel` and initialized to zero.
|
161
161
|
"""
|
162
162
|
|
163
|
-
# Get the contact model from the `JaxSimModel
|
163
|
+
# Get the contact model from the `JaxSimModel`.
|
164
164
|
match model.contact_model:
|
165
165
|
case SoftContacts():
|
166
166
|
contact = SoftContactsState.build_from_jaxsim_model(
|
@@ -212,7 +212,7 @@ class ODEState(JaxsimDataclass):
|
|
212
212
|
else PhysicsModelState.zero(model=model)
|
213
213
|
)
|
214
214
|
|
215
|
-
# Get the contact model from the `JaxSimModel
|
215
|
+
# Get the contact model from the `JaxSimModel`.
|
216
216
|
match contact:
|
217
217
|
case SoftContactsState():
|
218
218
|
pass
|
@@ -423,7 +423,7 @@ class PhysicsModelState(JaxsimDataclass):
|
|
423
423
|
base_angular_velocity=jnp.array(base_angular_velocity, dtype=float),
|
424
424
|
)
|
425
425
|
|
426
|
-
# assert state.valid(physics_model)
|
426
|
+
# TODO (diegoferigo): assert state.valid(physics_model)
|
427
427
|
return physics_model_state
|
428
428
|
|
429
429
|
@staticmethod
|
@@ -501,14 +501,14 @@ class PhysicsModelInput(JaxsimDataclass):
|
|
501
501
|
f_ext: The matrix of external forces applied to the links.
|
502
502
|
"""
|
503
503
|
|
504
|
-
tau: jtp.
|
505
|
-
f_ext: jtp.
|
504
|
+
tau: jtp.Vector
|
505
|
+
f_ext: jtp.Matrix
|
506
506
|
|
507
507
|
@staticmethod
|
508
508
|
def build_from_jaxsim_model(
|
509
509
|
model: js.model.JaxSimModel | None = None,
|
510
|
-
joint_forces: jtp.
|
511
|
-
link_forces: jtp.
|
510
|
+
joint_forces: jtp.VectorLike | None = None,
|
511
|
+
link_forces: jtp.MatrixLike | None = None,
|
512
512
|
) -> PhysicsModelInput:
|
513
513
|
"""
|
514
514
|
Build a `PhysicsModelInput` from a `JaxSimModel`.
|
@@ -535,8 +535,8 @@ class PhysicsModelInput(JaxsimDataclass):
|
|
535
535
|
|
536
536
|
@staticmethod
|
537
537
|
def build(
|
538
|
-
joint_forces: jtp.
|
539
|
-
link_forces: jtp.
|
538
|
+
joint_forces: jtp.VectorLike | None = None,
|
539
|
+
link_forces: jtp.MatrixLike | None = None,
|
540
540
|
number_of_dofs: jtp.Int | None = None,
|
541
541
|
number_of_links: jtp.Int | None = None,
|
542
542
|
) -> PhysicsModelInput:
|
jaxsim/integrators/common.py
CHANGED
@@ -10,7 +10,6 @@ from jax_dataclasses import Static
|
|
10
10
|
|
11
11
|
import jaxsim.api as js
|
12
12
|
import jaxsim.typing as jtp
|
13
|
-
from jaxsim.math import Quaternion
|
14
13
|
from jaxsim.utils.jaxsim_dataclass import JaxsimDataclass, Mutability
|
15
14
|
|
16
15
|
try:
|
@@ -28,8 +27,8 @@ except ImportError:
|
|
28
27
|
# Generic types
|
29
28
|
# =============
|
30
29
|
|
31
|
-
Time =
|
32
|
-
TimeStep =
|
30
|
+
Time = jtp.FloatLike
|
31
|
+
TimeStep = jtp.FloatLike
|
33
32
|
State = NextState = TypeVar("State")
|
34
33
|
StateDerivative = TypeVar("StateDerivative")
|
35
34
|
PyTreeType = TypeVar("PyTreeType", bound=jtp.PyTree)
|
@@ -80,7 +79,7 @@ class Integrator(JaxsimDataclass, abc.ABC, Generic[State, StateDerivative]):
|
|
80
79
|
The integrator object.
|
81
80
|
"""
|
82
81
|
|
83
|
-
return cls(dynamics=dynamics, **kwargs)
|
82
|
+
return cls(dynamics=dynamics, **kwargs)
|
84
83
|
|
85
84
|
def step(
|
86
85
|
self,
|
@@ -192,14 +191,14 @@ class Integrator(JaxsimDataclass, abc.ABC, Generic[State, StateDerivative]):
|
|
192
191
|
class ExplicitRungeKutta(Integrator[PyTreeType, PyTreeType], Generic[PyTreeType]):
|
193
192
|
|
194
193
|
# The Runge-Kutta matrix.
|
195
|
-
A: ClassVar[
|
194
|
+
A: ClassVar[jtp.Matrix]
|
196
195
|
|
197
196
|
# The weights coefficients.
|
198
197
|
# Note that in practice we typically use its transpose `b.transpose()`.
|
199
|
-
b: ClassVar[
|
198
|
+
b: ClassVar[jtp.Matrix]
|
200
199
|
|
201
200
|
# The nodes coefficients.
|
202
|
-
c: ClassVar[
|
201
|
+
c: ClassVar[jtp.Vector]
|
203
202
|
|
204
203
|
# Define the order of the solution.
|
205
204
|
# It should have as many elements as the number of rows of `b.transpose()`.
|
@@ -385,7 +384,7 @@ class ExplicitRungeKutta(Integrator[PyTreeType, PyTreeType], Generic[PyTreeType]
|
|
385
384
|
# Define the computation of the Runge-Kutta stage.
|
386
385
|
def compute_ki() -> jax.Array:
|
387
386
|
|
388
|
-
# Compute ∑ⱼ aᵢⱼ k
|
387
|
+
# Compute ∑ⱼ aᵢⱼ kⱼ.
|
389
388
|
op_sum_ak = lambda k: jnp.einsum("s,s...->...", A[i], k)
|
390
389
|
sum_ak = jax.tree_util.tree_map(op_sum_ak, K)
|
391
390
|
|
@@ -441,7 +440,7 @@ class ExplicitRungeKutta(Integrator[PyTreeType, PyTreeType], Generic[PyTreeType]
|
|
441
440
|
|
442
441
|
@staticmethod
|
443
442
|
def butcher_tableau_is_valid(
|
444
|
-
A:
|
443
|
+
A: jtp.Matrix, b: jtp.Matrix, c: jtp.Vector
|
445
444
|
) -> jtp.Bool:
|
446
445
|
"""
|
447
446
|
Check if the Butcher tableau is valid.
|
@@ -467,7 +466,7 @@ class ExplicitRungeKutta(Integrator[PyTreeType, PyTreeType], Generic[PyTreeType]
|
|
467
466
|
return valid
|
468
467
|
|
469
468
|
@staticmethod
|
470
|
-
def butcher_tableau_is_explicit(A:
|
469
|
+
def butcher_tableau_is_explicit(A: jtp.Matrix) -> jtp.Bool:
|
471
470
|
"""
|
472
471
|
Check if the Butcher tableau corresponds to an explicit integration scheme.
|
473
472
|
|
@@ -482,9 +481,9 @@ class ExplicitRungeKutta(Integrator[PyTreeType, PyTreeType], Generic[PyTreeType]
|
|
482
481
|
|
483
482
|
@staticmethod
|
484
483
|
def butcher_tableau_supports_fsal(
|
485
|
-
A:
|
486
|
-
b:
|
487
|
-
c:
|
484
|
+
A: jtp.Matrix,
|
485
|
+
b: jtp.Matrix,
|
486
|
+
c: jtp.Vector,
|
488
487
|
index_of_solution: jtp.IntLike = 0,
|
489
488
|
) -> [bool, int | None]:
|
490
489
|
"""
|
@@ -548,17 +547,11 @@ class ExplicitRungeKuttaSO3Mixin:
|
|
548
547
|
op = lambda x0_leaf, k_leaf: x0_leaf + dt * k_leaf
|
549
548
|
xf: js.ode_data.ODEState = jax.tree_util.tree_map(op, x0, k)
|
550
549
|
|
551
|
-
|
552
|
-
W_ω_WB_t0 = x0.physics_model.base_angular_velocity
|
550
|
+
W_Q_B_tf = xf.physics_model.base_quaternion
|
553
551
|
|
554
552
|
return xf.replace(
|
555
553
|
physics_model=xf.physics_model.replace(
|
556
|
-
base_quaternion=
|
557
|
-
quaternion=W_Q_B_t0,
|
558
|
-
dt=dt,
|
559
|
-
omega=W_ω_WB_t0,
|
560
|
-
omega_in_body_fixed=False,
|
561
|
-
),
|
554
|
+
base_quaternion=W_Q_B_tf / jnp.linalg.norm(W_Q_B_tf)
|
562
555
|
)
|
563
556
|
)
|
564
557
|
|
@@ -569,10 +562,9 @@ class ExplicitRungeKuttaSO3Mixin:
|
|
569
562
|
|
570
563
|
# Indices to convert quaternions between serializations.
|
571
564
|
to_xyzw = jnp.array([1, 2, 3, 0])
|
572
|
-
to_wxyz = jnp.array([3, 0, 1, 2])
|
573
565
|
|
574
|
-
# Get the initial
|
575
|
-
|
566
|
+
# Get the initial rotation.
|
567
|
+
W_R_B_t0 = jaxlie.SO3.from_quaternion_xyzw(
|
576
568
|
xyzw=x0.physics_model.base_quaternion[to_xyzw]
|
577
569
|
)
|
578
570
|
|
@@ -582,15 +574,13 @@ class ExplicitRungeKuttaSO3Mixin:
|
|
582
574
|
# on the SO(3) manifold.
|
583
575
|
W_ω_WB_tf = xf.physics_model.base_angular_velocity
|
584
576
|
|
585
|
-
# Integrate the
|
577
|
+
# Integrate the orientation on SO(3).
|
586
578
|
# Note that we left-multiply with the exponential map since the angular
|
587
579
|
# velocity is expressed in the inertial frame.
|
588
|
-
|
580
|
+
W_R_B_tf = jaxlie.SO3.exp(tangent=dt * W_ω_WB_tf) @ W_R_B_t0
|
589
581
|
|
590
582
|
# Replace the quaternion in the final state.
|
591
583
|
return xf.replace(
|
592
|
-
physics_model=xf.physics_model.replace(
|
593
|
-
base_quaternion=W_Q_B_tf.as_quaternion_xyzw()[to_wxyz]
|
594
|
-
),
|
584
|
+
physics_model=xf.physics_model.replace(base_quaternion=W_R_B_tf.wxyz),
|
595
585
|
validate=True,
|
596
586
|
)
|
jaxsim/integrators/fixed_step.py
CHANGED
@@ -1,10 +1,10 @@
|
|
1
1
|
from typing import ClassVar, Generic
|
2
2
|
|
3
|
-
import jax
|
4
3
|
import jax.numpy as jnp
|
5
4
|
import jax_dataclasses
|
6
5
|
|
7
6
|
import jaxsim.api as js
|
7
|
+
import jaxsim.typing as jtp
|
8
8
|
|
9
9
|
from .common import ExplicitRungeKutta, ExplicitRungeKuttaSO3Mixin, PyTreeType
|
10
10
|
|
@@ -18,11 +18,11 @@ ODEStateDerivative = js.ode_data.ODEState
|
|
18
18
|
@jax_dataclasses.pytree_dataclass
|
19
19
|
class ForwardEuler(ExplicitRungeKutta[PyTreeType], Generic[PyTreeType]):
|
20
20
|
|
21
|
-
A: ClassVar[
|
21
|
+
A: ClassVar[jtp.Matrix] = jnp.atleast_2d(0).astype(float)
|
22
22
|
|
23
|
-
b: ClassVar[
|
23
|
+
b: ClassVar[jtp.Matrix] = jnp.atleast_2d(1).astype(float).transpose()
|
24
24
|
|
25
|
-
c: ClassVar[
|
25
|
+
c: ClassVar[jtp.Vector] = jnp.atleast_1d(0).astype(float)
|
26
26
|
|
27
27
|
row_index_of_solution: ClassVar[int] = 0
|
28
28
|
order_of_bT_rows: ClassVar[tuple[int, ...]] = (1,)
|
@@ -31,14 +31,14 @@ class ForwardEuler(ExplicitRungeKutta[PyTreeType], Generic[PyTreeType]):
|
|
31
31
|
@jax_dataclasses.pytree_dataclass
|
32
32
|
class Heun2(ExplicitRungeKutta[PyTreeType], Generic[PyTreeType]):
|
33
33
|
|
34
|
-
A: ClassVar[
|
34
|
+
A: ClassVar[jtp.Matrix] = jnp.array(
|
35
35
|
[
|
36
36
|
[0, 0],
|
37
37
|
[1, 0],
|
38
38
|
]
|
39
39
|
).astype(float)
|
40
40
|
|
41
|
-
b: ClassVar[
|
41
|
+
b: ClassVar[jtp.Matrix] = (
|
42
42
|
jnp.atleast_2d(
|
43
43
|
jnp.array([1 / 2, 1 / 2]),
|
44
44
|
)
|
@@ -46,7 +46,7 @@ class Heun2(ExplicitRungeKutta[PyTreeType], Generic[PyTreeType]):
|
|
46
46
|
.transpose()
|
47
47
|
)
|
48
48
|
|
49
|
-
c: ClassVar[
|
49
|
+
c: ClassVar[jtp.Vector] = jnp.array(
|
50
50
|
[0, 1],
|
51
51
|
).astype(float)
|
52
52
|
|
@@ -57,7 +57,7 @@ class Heun2(ExplicitRungeKutta[PyTreeType], Generic[PyTreeType]):
|
|
57
57
|
@jax_dataclasses.pytree_dataclass
|
58
58
|
class RungeKutta4(ExplicitRungeKutta[PyTreeType], Generic[PyTreeType]):
|
59
59
|
|
60
|
-
A: ClassVar[
|
60
|
+
A: ClassVar[jtp.Matrix] = jnp.array(
|
61
61
|
[
|
62
62
|
[0, 0, 0, 0],
|
63
63
|
[1 / 2, 0, 0, 0],
|
@@ -66,7 +66,7 @@ class RungeKutta4(ExplicitRungeKutta[PyTreeType], Generic[PyTreeType]):
|
|
66
66
|
]
|
67
67
|
).astype(float)
|
68
68
|
|
69
|
-
b: ClassVar[
|
69
|
+
b: ClassVar[jtp.Matrix] = (
|
70
70
|
jnp.atleast_2d(
|
71
71
|
jnp.array([1 / 6, 1 / 3, 1 / 3, 1 / 6]),
|
72
72
|
)
|
@@ -74,7 +74,7 @@ class RungeKutta4(ExplicitRungeKutta[PyTreeType], Generic[PyTreeType]):
|
|
74
74
|
.transpose()
|
75
75
|
)
|
76
76
|
|
77
|
-
c: ClassVar[
|
77
|
+
c: ClassVar[jtp.Vector] = jnp.array(
|
78
78
|
[0, 1 / 2, 1 / 2, 1],
|
79
79
|
).astype(float)
|
80
80
|
|