jaxsim 0.6.2.dev281__py3-none-any.whl → 0.6.2.dev296__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jaxsim/_version.py +2 -2
- jaxsim/api/kin_dyn_parameters.py +286 -6
- jaxsim/api/model.py +384 -29
- jaxsim/math/joint_model.py +2 -1
- jaxsim/mujoco/utils.py +1 -5
- {jaxsim-0.6.2.dev281.dist-info → jaxsim-0.6.2.dev296.dist-info}/METADATA +1 -1
- {jaxsim-0.6.2.dev281.dist-info → jaxsim-0.6.2.dev296.dist-info}/RECORD +10 -10
- {jaxsim-0.6.2.dev281.dist-info → jaxsim-0.6.2.dev296.dist-info}/WHEEL +1 -1
- {jaxsim-0.6.2.dev281.dist-info → jaxsim-0.6.2.dev296.dist-info}/licenses/LICENSE +0 -0
- {jaxsim-0.6.2.dev281.dist-info → jaxsim-0.6.2.dev296.dist-info}/top_level.txt +0 -0
jaxsim/_version.py
CHANGED
@@ -17,5 +17,5 @@ __version__: str
|
|
17
17
|
__version_tuple__: VERSION_TUPLE
|
18
18
|
version_tuple: VERSION_TUPLE
|
19
19
|
|
20
|
-
__version__ = version = '0.6.2.
|
21
|
-
__version_tuple__ = version_tuple = (0, 6, 2, '
|
20
|
+
__version__ = version = '0.6.2.dev296'
|
21
|
+
__version_tuple__ = version_tuple = (0, 6, 2, 'dev296')
|
jaxsim/api/kin_dyn_parameters.py
CHANGED
@@ -1,6 +1,7 @@
|
|
1
1
|
from __future__ import annotations
|
2
2
|
|
3
3
|
import dataclasses
|
4
|
+
from typing import ClassVar
|
4
5
|
|
5
6
|
import jax.lax
|
6
7
|
import jax.numpy as jnp
|
@@ -9,8 +10,10 @@ import numpy as np
|
|
9
10
|
import numpy.typing as npt
|
10
11
|
from jax_dataclasses import Static
|
11
12
|
|
13
|
+
import jaxsim
|
12
14
|
import jaxsim.typing as jtp
|
13
|
-
from jaxsim.math import
|
15
|
+
from jaxsim.math import Inertia, JointModel, supported_joint_motion
|
16
|
+
from jaxsim.math.adjoint import Adjoint
|
14
17
|
from jaxsim.parsers.descriptions import JointDescription, JointType, ModelDescription
|
15
18
|
from jaxsim.utils import HashedNumpyArray, JaxsimDataclass
|
16
19
|
|
@@ -30,6 +33,7 @@ class KinDynParameters(JaxsimDataclass):
|
|
30
33
|
contact_parameters: The parameters of the collidable points.
|
31
34
|
joint_model: The joint model of the model.
|
32
35
|
joint_parameters: The parameters of the joints.
|
36
|
+
hw_link_metadata: The hardware parameters of the model links.
|
33
37
|
"""
|
34
38
|
|
35
39
|
# Static
|
@@ -51,6 +55,9 @@ class KinDynParameters(JaxsimDataclass):
|
|
51
55
|
joint_model: JointModel
|
52
56
|
joint_parameters: JointParameters | None
|
53
57
|
|
58
|
+
# Model hardware parameters
|
59
|
+
hw_link_metadata: HwLinkMetadata | None = dataclasses.field(default=None)
|
60
|
+
|
54
61
|
@property
|
55
62
|
def motion_subspaces(self) -> jtp.Matrix:
|
56
63
|
r"""
|
@@ -197,7 +204,6 @@ class KinDynParameters(JaxsimDataclass):
|
|
197
204
|
carry0 = κb, link_index
|
198
205
|
|
199
206
|
def scan_body(carry: tuple, i: jtp.Int) -> tuple[tuple, None]:
|
200
|
-
|
201
207
|
κb, active_link_index = carry
|
202
208
|
|
203
209
|
κb, active_link_index = jax.lax.cond(
|
@@ -224,7 +230,6 @@ class KinDynParameters(JaxsimDataclass):
|
|
224
230
|
)
|
225
231
|
|
226
232
|
def motion_subspace(joint_type: int, axis: npt.ArrayLike) -> npt.ArrayLike:
|
227
|
-
|
228
233
|
S = {
|
229
234
|
JointType.Fixed: np.zeros(shape=(6, 1)),
|
230
235
|
JointType.Revolute: np.vstack(np.hstack([np.zeros(3), axis.axis])),
|
@@ -265,14 +270,12 @@ class KinDynParameters(JaxsimDataclass):
|
|
265
270
|
)
|
266
271
|
|
267
272
|
def __eq__(self, other: KinDynParameters) -> bool:
|
268
|
-
|
269
273
|
if not isinstance(other, KinDynParameters):
|
270
274
|
return False
|
271
275
|
|
272
276
|
return hash(self) == hash(other)
|
273
277
|
|
274
278
|
def __hash__(self) -> int:
|
275
|
-
|
276
279
|
return hash(
|
277
280
|
(
|
278
281
|
hash(self.number_of_links()),
|
@@ -671,7 +674,11 @@ class LinkParameters(JaxsimDataclass):
|
|
671
674
|
|
672
675
|
return (
|
673
676
|
jnp.hstack(
|
674
|
-
[
|
677
|
+
[
|
678
|
+
params.mass,
|
679
|
+
params.center_of_mass.squeeze(),
|
680
|
+
params.inertia_elements,
|
681
|
+
]
|
675
682
|
)
|
676
683
|
.squeeze()
|
677
684
|
.astype(float)
|
@@ -882,3 +889,276 @@ class FrameParameters(JaxsimDataclass):
|
|
882
889
|
assert fp.transform.shape[0] == len(fp.body), fp.transform.shape[0]
|
883
890
|
|
884
891
|
return fp
|
892
|
+
|
893
|
+
|
894
|
+
@dataclasses.dataclass(frozen=True)
class LinkParametrizableShape:
    """
    Integer codes of the link geometries handled by the hardware parametrization.

    The non-negative codes double as branch indices for the ``jax.lax.switch``
    calls in ``HwLinkMetadata`` (branch order: box, cylinder, sphere), so their
    values must stay aligned with that order. The negative code marks links
    whose geometry cannot be parametrized.
    """

    # Sentinel code for links that are excluded from the parametrization.
    Unsupported: ClassVar[int] = -1

    # Switch branch indices; keep them consistent with the branch order.
    Box: ClassVar[int] = 0
    Cylinder: ClassVar[int] = 1
    Sphere: ClassVar[int] = 2
|
904
|
+
|
905
|
+
|
906
|
+
@jax_dataclasses.pytree_dataclass
class HwLinkMetadata(JaxsimDataclass):
    """
    Class storing the hardware parameters of a link.

    Attributes:
        shape: The shape of the link, encoded as in ``LinkParametrizableShape``:
            0 = box, 1 = cylinder, 2 = sphere, -1 = unsupported.
        dims: The dimensions of the link.
            box: [lx, ly, lz], cylinder: [r, l, 0], sphere: [r, 0, 0].
        density: The density of the link.
        L_H_G: The homogeneous transformation matrix from the link frame to the CoM frame G.
        L_H_vis: The homogeneous transformation matrix from the link frame to the visual frame.
        L_H_pre_mask: The mask indicating the link's child joint indices.
        L_H_pre: The homogeneous transforms for child joints.
    """

    shape: jtp.Vector
    dims: jtp.Vector
    density: jtp.Float
    L_H_G: jtp.Matrix
    L_H_vis: jtp.Matrix
    L_H_pre_mask: jtp.Vector
    L_H_pre: jtp.Matrix

    @staticmethod
    def compute_mass_and_inertia(
        hw_link_metadata: HwLinkMetadata,
    ) -> tuple[jtp.Float, jtp.Matrix]:
        """
        Compute the mass and inertia of a hardware link based on its metadata.

        This function calculates the mass and inertia tensor of a hardware link
        using its shape, dimensions, and density. The computation dispatches
        with ``jax.lax.switch`` to a shape-specific method.

        Args:
            hw_link_metadata: Metadata describing the hardware link,
                including its shape, dimensions, and density.

        Returns:
            tuple: A tuple containing:
                - mass: The computed mass of the hardware link.
                - inertia: The computed inertia tensor of the hardware link.

        Note:
            ``jax.lax.switch`` clamps out-of-range indices, so a link whose shape
            is ``LinkParametrizableShape.Unsupported`` (-1) is evaluated with the
            box formula. The result for such links is not meaningful.
        """

        # Branch order must match the LinkParametrizableShape codes (0, 1, 2).
        mass, inertia = jax.lax.switch(
            hw_link_metadata.shape,
            [
                HwLinkMetadata._box,
                HwLinkMetadata._cylinder,
                HwLinkMetadata._sphere,
            ],
            hw_link_metadata.dims,
            hw_link_metadata.density,
        )
        return mass, inertia

    @staticmethod
    def _box(dims, density) -> tuple[jtp.Float, jtp.Matrix]:
        """Mass and CoM inertia of a solid box with ``dims = [lx, ly, lz]``."""

        lx, ly, lz = dims

        mass = density * lx * ly * lz

        inertia = jnp.diag(
            jnp.array(
                [
                    mass * (ly**2 + lz**2) / 12,
                    mass * (lx**2 + lz**2) / 12,
                    mass * (lx**2 + ly**2) / 12,
                ]
            )
        )
        return mass, inertia

    @staticmethod
    def _cylinder(dims, density) -> tuple[jtp.Float, jtp.Matrix]:
        """Mass and CoM inertia of a solid cylinder with ``dims = [r, l, _]`` (axis along z)."""

        r, length, _ = dims

        mass = density * (jnp.pi * r**2 * length)

        inertia = jnp.diag(
            jnp.array(
                [
                    mass * (3 * r**2 + length**2) / 12,
                    mass * (3 * r**2 + length**2) / 12,
                    mass * (r**2) / 2,
                ]
            )
        )

        return mass, inertia

    @staticmethod
    def _sphere(dims, density) -> tuple[jtp.Float, jtp.Matrix]:
        """Mass and CoM inertia of a solid sphere with ``dims = [r, _, _]``."""

        r = dims[0]

        mass = density * (4 / 3 * jnp.pi * r**3)

        inertia = jnp.eye(3) * (2 / 5 * mass * r**2)

        return mass, inertia

    @staticmethod
    def _convert_scaling_to_3d_vector(
        shape: jtp.Int, scaling_factors: jtp.Vector
    ) -> jtp.Vector:
        """
        Convert scaling factors for specific shape dimensions into a 3D scaling vector.

        Args:
            shape: The shape of the link (e.g., box, sphere, cylinder).
            scaling_factors: The scaling factors for the shape dimensions.

        Returns:
            A 3D scaling vector to apply to position vectors.

        Note:
            With ``scaling_factors = [s0, s1, s2]`` the 3D scale vector is:
            - Box: [s0, s1, s2]
            - Cylinder: [s0, s0, s1] (radius on x/y, length on z)
            - Sphere: [s0, s0, s0]
            Out-of-range shape codes are clamped by ``jax.lax.switch``.
        """

        return jax.lax.switch(
            shape,
            branches=[
                # Box
                lambda: scaling_factors,
                # Cylinder
                lambda: jnp.array(
                    [
                        scaling_factors[0],
                        scaling_factors[0],
                        scaling_factors[1],
                    ]
                ),
                # Sphere
                lambda: jnp.array(
                    [
                        scaling_factors[0],
                        scaling_factors[0],
                        scaling_factors[0],
                    ]
                ),
            ],
        )

    @staticmethod
    def compute_inertia_link(I_com, mass, L_H_G) -> jtp.Matrix:
        """
        Rotate a CoM inertia tensor into the orientation of the link frame.

        Args:
            I_com: The 3x3 inertia tensor expressed in the CoM frame G.
            mass: Not used by the computation; accepted for signature compatibility.
            L_H_G: The homogeneous transform from the link frame to the CoM frame G.
                Only its rotation block is used.

        Returns:
            The inertia tensor about the CoM, expressed with the orientation of
            the link frame (``L_R_G @ I_com @ L_R_G.T``).
        """

        L_R_G = L_H_G[:3, :3]
        return L_R_G @ I_com @ L_R_G.T

    @staticmethod
    def apply_scaling(
        hw_metadata: HwLinkMetadata, scaling_factors: ScalingFactors
    ) -> HwLinkMetadata:
        """
        Apply scaling to the hardware parameters and return a new HwLinkMetadata object.

        The link, visual and child-joint transforms are first expressed in the
        CoM frame G. Their translations are then scaled by the 3D scaling vector
        of the link shape, and the transforms are mapped back to the scaled link
        frame. Child-joint transforms are scaled only where ``L_H_pre_mask`` is
        True. Metadata of links with ``LinkParametrizableShape.Unsupported``
        shape is returned unchanged.

        Args:
            hw_metadata: the original HwLinkMetadata object.
            scaling_factors: the scaling factors to apply.

        Returns:
            A new HwLinkMetadata object with updated parameters.
        """

        def unsupported_case(hw_metadata, scaling_factors):
            # Return the metadata unchanged for unsupported links
            return hw_metadata

        def supported_case(hw_metadata, scaling_factors):
            # ==================================
            # Update the kinematics of the link
            # ==================================

            # Get the nominal transforms
            L_H_G = hw_metadata.L_H_G
            L_H_vis = hw_metadata.L_H_vis
            L_H_pre_array = hw_metadata.L_H_pre
            L_H_pre_mask = hw_metadata.L_H_pre_mask

            # Compute the 3D scaling vector
            scale_vector = HwLinkMetadata._convert_scaling_to_3d_vector(
                hw_metadata.shape, scaling_factors.dims
            )

            # Express the transforms in the G frame
            G_H_L = jaxsim.math.Transform.inverse(L_H_G)
            G_H_vis = G_H_L @ L_H_vis
            G_H_pre_array = jax.vmap(lambda L_H_pre: G_H_L @ L_H_pre)(L_H_pre_array)

            # Apply the scaling to the position vectors
            G_H_L_scaled = G_H_L.at[:3, 3].set(scale_vector * G_H_L[:3, 3])
            G_H_vis_scaled = G_H_vis.at[:3, 3].set(scale_vector * G_H_vis[:3, 3])

            # Scale only the child-joint transforms selected by the mask
            G_H_pre_array_scaled = jax.vmap(
                lambda G_H_pre, mask: jnp.where(
                    # Expand mask for broadcasting against the 4x4 transform
                    mask[..., None, None],
                    G_H_pre.at[:3, 3].set(scale_vector * G_H_pre[:3, 3]),
                    G_H_pre,
                )
            )(G_H_pre_array, L_H_pre_mask)

            # Get back to the link frame
            L_H_G_scaled = jaxsim.math.Transform.inverse(G_H_L_scaled)
            L_H_vis_scaled = L_H_G_scaled @ G_H_vis_scaled
            L_H_pre_array_scaled = jax.vmap(
                lambda G_H_pre_scaled: L_H_G_scaled @ G_H_pre_scaled
            )(G_H_pre_array_scaled)

            # ==========================================
            # Update the shape parameters and the density
            # ==========================================

            updated_dims = hw_metadata.dims * scaling_factors.dims
            updated_density = hw_metadata.density * scaling_factors.density

            return hw_metadata.replace(
                dims=updated_dims,
                density=updated_density,
                L_H_G=L_H_G_scaled,
                L_H_vis=L_H_vis_scaled,
                L_H_pre=L_H_pre_array_scaled,
            )

        # Use jax.lax.cond to handle unsupported links
        return jax.lax.cond(
            hw_metadata.shape == LinkParametrizableShape.Unsupported,
            lambda: unsupported_case(hw_metadata, scaling_factors),
            lambda: supported_case(hw_metadata, scaling_factors),
        )
|
1151
|
+
|
1152
|
+
|
1153
|
+
@jax_dataclasses.pytree_dataclass
class ScalingFactors(JaxsimDataclass):
    """
    Multiplicative factors used to rescale the hardware parameters of a link.

    Attributes:
        dims: Element-wise factors multiplying ``HwLinkMetadata.dims``.
            Index 0 scales lx (box) or the radius (cylinder, sphere);
            index 1 scales ly (box) or the length (cylinder);
            index 2 scales lz (box).
        density: Factor multiplying the link density.
    """

    # Multipliers of the shape dimensions.
    dims: jtp.Vector

    # Multiplier of the density.
    density: jtp.Float
|
jaxsim/api/model.py
CHANGED
@@ -11,14 +11,26 @@ import jax
|
|
11
11
|
import jax.numpy as jnp
|
12
12
|
import jax_dataclasses
|
13
13
|
import rod
|
14
|
+
import rod.urdf
|
14
15
|
from jax_dataclasses import Static
|
16
|
+
from rod.urdf.exporter import UrdfExporter
|
15
17
|
|
16
18
|
import jaxsim.api as js
|
17
19
|
import jaxsim.exceptions
|
18
20
|
import jaxsim.terrain
|
19
21
|
import jaxsim.typing as jtp
|
22
|
+
from jaxsim import logging
|
23
|
+
from jaxsim.api.kin_dyn_parameters import (
|
24
|
+
HwLinkMetadata,
|
25
|
+
KinDynParameters,
|
26
|
+
LinkParameters,
|
27
|
+
LinkParametrizableShape,
|
28
|
+
ScalingFactors,
|
29
|
+
)
|
20
30
|
from jaxsim.math import Adjoint, Cross
|
21
31
|
from jaxsim.parsers.descriptions import ModelDescription
|
32
|
+
from jaxsim.parsers.descriptions.joint import JointDescription
|
33
|
+
from jaxsim.parsers.descriptions.link import LinkDescription
|
22
34
|
from jaxsim.utils import JaxsimDataclass, Mutability, wrappers
|
23
35
|
|
24
36
|
from .common import VelRepr
|
@@ -86,7 +98,6 @@ class JaxSimModel(JaxsimDataclass):
|
|
86
98
|
return self._description.get()
|
87
99
|
|
88
100
|
def __eq__(self, other: JaxSimModel) -> bool:
|
89
|
-
|
90
101
|
if not isinstance(other, JaxSimModel):
|
91
102
|
return False
|
92
103
|
|
@@ -102,7 +113,6 @@ class JaxSimModel(JaxsimDataclass):
|
|
102
113
|
return True
|
103
114
|
|
104
115
|
def __hash__(self) -> int:
|
105
|
-
|
106
116
|
return hash(
|
107
117
|
(
|
108
118
|
hash(self.model_name),
|
@@ -194,6 +204,12 @@ class JaxSimModel(JaxsimDataclass):
|
|
194
204
|
with model.mutable_context(mutability=Mutability.MUTABLE_NO_VALIDATION):
|
195
205
|
model.built_from = model_description
|
196
206
|
|
207
|
+
# Compute the hw parametrization metadata of the model
|
208
|
+
# TODO: move the building of the metadata to KinDynParameters.build()
|
209
|
+
# and use the model_description instead of model.built_from.
|
210
|
+
with model.mutable_context(mutability=Mutability.MUTABLE_NO_VALIDATION):
|
211
|
+
model.kin_dyn_parameters.hw_link_metadata = model.compute_hw_link_metadata()
|
212
|
+
|
197
213
|
return model
|
198
214
|
|
199
215
|
@classmethod
|
@@ -297,6 +313,274 @@ class JaxSimModel(JaxsimDataclass):
|
|
297
313
|
|
298
314
|
return model
|
299
315
|
|
316
|
+
def compute_hw_link_metadata(self) -> HwLinkMetadata:
|
317
|
+
"""
|
318
|
+
Compute the parametric metadata of the links in the model.
|
319
|
+
|
320
|
+
Returns:
|
321
|
+
An instance of HwLinkMetadata containing the metadata of all links.
|
322
|
+
"""
|
323
|
+
model_description = self.description
|
324
|
+
|
325
|
+
# Get ordered links and joints from the model description
|
326
|
+
ordered_links: list[LinkDescription] = sorted(
|
327
|
+
list(model_description.links_dict.values()),
|
328
|
+
key=lambda l: l.index,
|
329
|
+
)
|
330
|
+
ordered_joints: list[JointDescription] = sorted(
|
331
|
+
list(model_description.joints_dict.values()),
|
332
|
+
key=lambda j: j.index,
|
333
|
+
)
|
334
|
+
|
335
|
+
# Ensure the model was built from a valid source
|
336
|
+
rod_model = None
|
337
|
+
match self.built_from:
|
338
|
+
case str() | pathlib.Path():
|
339
|
+
rod_model = rod.Sdf.load(sdf=self.built_from).models()[0]
|
340
|
+
assert rod_model.name == self.name()
|
341
|
+
case rod.Model():
|
342
|
+
rod_model = self.built_from
|
343
|
+
case _:
|
344
|
+
logging.debug(
|
345
|
+
f"Invalid type for model.built_from ({type(self.built_from)})."
|
346
|
+
"Skipping for hardware parametrization."
|
347
|
+
)
|
348
|
+
return HwLinkMetadata(
|
349
|
+
shape=jnp.array([]),
|
350
|
+
dims=jnp.array([]),
|
351
|
+
density=jnp.array([]),
|
352
|
+
L_H_G=jnp.array([]),
|
353
|
+
L_H_vis=jnp.array([]),
|
354
|
+
L_H_pre_mask=jnp.array([]),
|
355
|
+
L_H_pre=jnp.array([]),
|
356
|
+
)
|
357
|
+
|
358
|
+
# Use URDF frame convention for consistent pose representation
|
359
|
+
rod_model.switch_frame_convention(
|
360
|
+
frame_convention=rod.FrameConvention.Urdf, explicit_frames=True
|
361
|
+
)
|
362
|
+
|
363
|
+
rod_links_dict = {}
|
364
|
+
|
365
|
+
# Filter links that support parameterization
|
366
|
+
for rod_link in rod_model.links():
|
367
|
+
if len(rod_link.visuals()) != 1:
|
368
|
+
logging.debug(
|
369
|
+
f"Skipping link '{rod_link.name}' for hardware parametrization due to multiple visuals."
|
370
|
+
)
|
371
|
+
continue
|
372
|
+
|
373
|
+
if not isinstance(
|
374
|
+
rod_link.visual.geometry.geometry(), (rod.Box, rod.Sphere, rod.Cylinder)
|
375
|
+
):
|
376
|
+
logging.debug(
|
377
|
+
f"Skipping link '{rod_link.name}' for hardware parametrization due to unsupported geometry."
|
378
|
+
)
|
379
|
+
continue
|
380
|
+
|
381
|
+
rod_links_dict[rod_link.name] = rod_link
|
382
|
+
|
383
|
+
# Initialize lists to collect metadata for all links
|
384
|
+
shapes = []
|
385
|
+
dims = []
|
386
|
+
densities = []
|
387
|
+
L_H_Gs = []
|
388
|
+
L_H_vises = []
|
389
|
+
L_H_pre_masks = []
|
390
|
+
L_H_pres = []
|
391
|
+
|
392
|
+
# Process each link
|
393
|
+
for link_description in ordered_links:
|
394
|
+
link_name = link_description.name
|
395
|
+
|
396
|
+
if link_name not in self.link_names():
|
397
|
+
logging.debug(
|
398
|
+
f"Skipping link '{link_name}' for hardware parametrization as it is not part of the JaxSim model."
|
399
|
+
)
|
400
|
+
continue
|
401
|
+
|
402
|
+
if link_name not in rod_links_dict:
|
403
|
+
logging.debug(
|
404
|
+
f"Skipping link '{link_name}' for hardware parametrization as it is not part of the ROD model."
|
405
|
+
)
|
406
|
+
continue
|
407
|
+
|
408
|
+
rod_link = rod_links_dict[link_name]
|
409
|
+
link_index = int(js.link.name_to_idx(model=self, link_name=link_name))
|
410
|
+
|
411
|
+
# Get child joints for the link
|
412
|
+
child_joints_indices = [
|
413
|
+
js.joint.name_to_idx(model=self, joint_name=j.name)
|
414
|
+
for j in ordered_joints
|
415
|
+
if j.parent.name == link_name
|
416
|
+
]
|
417
|
+
|
418
|
+
# Skip unsupported links
|
419
|
+
if not jnp.allclose(
|
420
|
+
self.kin_dyn_parameters.joint_model.suc_H_i[link_index], jnp.eye(4)
|
421
|
+
):
|
422
|
+
logging.debug(
|
423
|
+
f"Skipping link '{link_name}' for hardware parametrization due to unsupported suc_H_link."
|
424
|
+
)
|
425
|
+
continue
|
426
|
+
|
427
|
+
# Compute density and dimensions
|
428
|
+
mass = float(self.kin_dyn_parameters.link_parameters.mass[link_index])
|
429
|
+
geometry = rod_link.visual.geometry.geometry()
|
430
|
+
if isinstance(geometry, rod.Box):
|
431
|
+
lx, ly, lz = geometry.size
|
432
|
+
density = mass / (lx * ly * lz)
|
433
|
+
dims.append([lx, ly, lz])
|
434
|
+
shapes.append(LinkParametrizableShape.Box)
|
435
|
+
elif isinstance(geometry, rod.Sphere):
|
436
|
+
r = geometry.radius
|
437
|
+
density = mass / (4 / 3 * jnp.pi * r**3)
|
438
|
+
dims.append([r, 0, 0])
|
439
|
+
shapes.append(LinkParametrizableShape.Sphere)
|
440
|
+
elif isinstance(geometry, rod.Cylinder):
|
441
|
+
r, l = geometry.radius, geometry.length
|
442
|
+
density = mass / (jnp.pi * r**2 * l)
|
443
|
+
dims.append([r, l, 0])
|
444
|
+
shapes.append(LinkParametrizableShape.Cylinder)
|
445
|
+
else:
|
446
|
+
logging.debug(
|
447
|
+
f"Skipping link '{link_name}' for hardware parametrization due to unsupported geometry."
|
448
|
+
)
|
449
|
+
continue
|
450
|
+
|
451
|
+
densities.append(density)
|
452
|
+
L_H_Gs.append(rod_link.inertial.pose.transform())
|
453
|
+
L_H_vises.append(rod_link.visual.pose.transform())
|
454
|
+
L_H_pre_masks.append(
|
455
|
+
[
|
456
|
+
int(joint_index in child_joints_indices)
|
457
|
+
for joint_index in range(self.number_of_joints())
|
458
|
+
]
|
459
|
+
)
|
460
|
+
L_H_pres.append(
|
461
|
+
[
|
462
|
+
(
|
463
|
+
self.kin_dyn_parameters.joint_model.λ_H_pre[joint_index + 1]
|
464
|
+
if joint_index in child_joints_indices
|
465
|
+
else jnp.eye(4)
|
466
|
+
)
|
467
|
+
for joint_index in range(self.number_of_joints())
|
468
|
+
]
|
469
|
+
)
|
470
|
+
|
471
|
+
# Stack collected data into JAX arrays
|
472
|
+
return HwLinkMetadata(
|
473
|
+
shape=jnp.array(shapes, dtype=int),
|
474
|
+
dims=jnp.array(dims, dtype=float),
|
475
|
+
density=jnp.array(densities, dtype=float),
|
476
|
+
L_H_G=jnp.array(L_H_Gs, dtype=float),
|
477
|
+
L_H_vis=jnp.array(L_H_vises, dtype=float),
|
478
|
+
L_H_pre_mask=jnp.array(L_H_pre_masks, dtype=bool),
|
479
|
+
L_H_pre=jnp.array(L_H_pres, dtype=float),
|
480
|
+
)
|
481
|
+
|
482
|
+
def export_updated_model(self) -> str:
    """
    Export the JaxSim model to URDF with the current hardware parameters.

    For every link of the model, the mass, CoM position and inertia are taken
    from the current link parameters. For links with a supported shape, the
    visual geometry, the visual pose and the poses of the child joints are
    also taken from the hardware metadata.

    Returns:
        The URDF string of the updated model.

    Raises:
        RuntimeError: If called while tracing a JIT-compiled function.
        ValueError: If ``built_from`` is not a ROD model, a string, or a path,
            or if the hardware metadata does not contain one entry per link.

    Note:
        This method is not meant to be used in JIT-compiled functions.
    """

    import numpy as np

    if isinstance(jnp.zeros(0), jax.core.Tracer):
        raise RuntimeError("This method cannot be used in JIT-compiled functions")

    # Ensure `built_from` is a ROD model and create `rod_model_output`
    if isinstance(self.built_from, rod.Model):
        rod_model_output = copy.deepcopy(self.built_from)
    elif isinstance(self.built_from, (str, pathlib.Path)):
        rod_model_output = rod.Sdf.load(sdf=self.built_from).models()[0]
    else:
        raise ValueError(
            "The JaxSim model must be built from a valid ROD model source"
        )

    link_names = self.link_names()
    hw_metadata = self.kin_dyn_parameters.hw_link_metadata

    # The metadata is indexed by link index below: refuse to export data that
    # is missing or not aligned with the links of the model.
    if hw_metadata is None or len(hw_metadata.shape) != len(link_names):
        raise ValueError(
            "The hardware metadata of the model must contain one entry per link"
        )

    # Switch to URDF frame convention for easier mapping
    rod_model_output.switch_frame_convention(
        frame_convention=rod.FrameConvention.Urdf,
        explicit_frames=True,
        attach_frames_to_links=True,
    )

    # Get links and joints from the ROD model
    links_dict = {link.name: link for link in rod_model_output.links()}
    joints_dict = {joint.name: joint for joint in rod_model_output.joints()}

    link_parameters = self.kin_dyn_parameters.link_parameters

    # Iterate over the hardware metadata to update the ROD model
    for link_index, link_name in enumerate(link_names):
        rod_link = links_dict.get(link_name)
        if rod_link is None:
            continue

        # Update mass and inertia
        mass = float(link_parameters.mass[link_index])
        center_of_mass = np.array(link_parameters.center_of_mass[link_index])
        inertia_tensor = LinkParameters.unflatten_inertia_tensor(
            link_parameters.inertia_elements[link_index]
        )

        rod_link.inertial.mass = mass
        L_H_COM = np.eye(4)
        L_H_COM[0:3, 3] = center_of_mass
        rod_link.inertial.pose = rod.Pose.from_transform(
            transform=L_H_COM,
            relative_to=rod_link.inertial.pose.relative_to,
        )
        rod_link.inertial.inertia = rod.Inertia.from_inertia_tensor(
            inertia_tensor=inertia_tensor, validate=True
        )

        # Update visual shape
        shape = hw_metadata.shape[link_index]
        dims = hw_metadata.dims[link_index]
        if shape == LinkParametrizableShape.Box:
            rod_link.visual.geometry.box.size = dims.tolist()
        elif shape == LinkParametrizableShape.Sphere:
            rod_link.visual.geometry.sphere.radius = float(dims[0])
        elif shape == LinkParametrizableShape.Cylinder:
            rod_link.visual.geometry.cylinder.radius = float(dims[0])
            rod_link.visual.geometry.cylinder.length = float(dims[1])
        else:
            logging.debug(f"Skipping unsupported shape for link '{link_name}'")
            continue

        # Update visual pose
        rod_link.visual.pose = rod.Pose.from_transform(
            transform=np.array(hw_metadata.L_H_vis[link_index]),
            relative_to=rod_link.visual.pose.relative_to,
        )

        # Update the poses of the child joints selected by the mask
        for joint_index in range(self.number_of_joints()):
            if not hw_metadata.L_H_pre_mask[link_index, joint_index]:
                continue

            joint_name = js.joint.idx_to_name(model=self, joint_index=joint_index)
            if joint_name in joints_dict:
                joints_dict[joint_name].pose = rod.Pose.from_transform(
                    transform=np.array(hw_metadata.L_H_pre[link_index, joint_index]),
                    relative_to=joints_dict[joint_name].pose.relative_to,
                )

    # Export the URDF string.
    return UrdfExporter(pretty=True).to_urdf_string(sdf=rod_model_output)
|
583
|
+
|
300
584
|
# ==========
|
301
585
|
# Properties
|
302
586
|
# ==========
|
@@ -585,9 +869,7 @@ def generalized_free_floating_jacobian(
|
|
585
869
|
# ======================================================================
|
586
870
|
|
587
871
|
match data.velocity_representation:
|
588
|
-
|
589
872
|
case VelRepr.Inertial:
|
590
|
-
|
591
873
|
W_H_B = data._base_transform
|
592
874
|
B_X_W = Adjoint.from_transform(transform=W_H_B, inverse=True)
|
593
875
|
|
@@ -597,11 +879,9 @@ def generalized_free_floating_jacobian(
|
|
597
879
|
)
|
598
880
|
|
599
881
|
case VelRepr.Body:
|
600
|
-
|
601
882
|
B_J_full_WX_I = B_J_full_WX_B
|
602
883
|
|
603
884
|
case VelRepr.Mixed:
|
604
|
-
|
605
885
|
W_R_B = jaxsim.math.Quaternion.to_dcm(data.base_orientation)
|
606
886
|
BW_H_B = jnp.eye(4).at[0:3, 0:3].set(W_R_B)
|
607
887
|
B_X_BW = Adjoint.from_transform(transform=BW_H_B, inverse=True)
|
@@ -633,9 +913,7 @@ def generalized_free_floating_jacobian(
|
|
633
913
|
# =======================================================================
|
634
914
|
|
635
915
|
match output_vel_repr:
|
636
|
-
|
637
916
|
case VelRepr.Inertial:
|
638
|
-
|
639
917
|
W_H_B = data._base_transform
|
640
918
|
W_X_B = jaxsim.math.Adjoint.from_transform(W_H_B)
|
641
919
|
|
@@ -644,7 +922,6 @@ def generalized_free_floating_jacobian(
|
|
644
922
|
)(B_J_WL_I)
|
645
923
|
|
646
924
|
case VelRepr.Body:
|
647
|
-
|
648
925
|
O_J_WL_I = L_J_WL_I = jax.vmap( # noqa: F841
|
649
926
|
lambda B_H_L, B_J_WL_I: jaxsim.math.Adjoint.from_transform(
|
650
927
|
B_H_L, inverse=True
|
@@ -653,7 +930,6 @@ def generalized_free_floating_jacobian(
|
|
653
930
|
)(B_H_L, B_J_WL_I)
|
654
931
|
|
655
932
|
case VelRepr.Mixed:
|
656
|
-
|
657
933
|
W_H_B = data._base_transform
|
658
934
|
|
659
935
|
LW_H_L = jax.vmap(
|
@@ -738,9 +1014,7 @@ def generalized_free_floating_jacobian_derivative(
|
|
738
1014
|
On = jnp.zeros(shape=(model.dofs(), model.dofs()))
|
739
1015
|
|
740
1016
|
match data.velocity_representation:
|
741
|
-
|
742
1017
|
case VelRepr.Inertial:
|
743
|
-
|
744
1018
|
B_X_W = jaxsim.math.Adjoint.from_transform(transform=W_H_B, inverse=True)
|
745
1019
|
|
746
1020
|
W_v_WB = data.base_velocity
|
@@ -752,7 +1026,6 @@ def generalized_free_floating_jacobian_derivative(
|
|
752
1026
|
Ṫ = jax.scipy.linalg.block_diag(B_Ẋ_W, On)
|
753
1027
|
|
754
1028
|
case VelRepr.Body:
|
755
|
-
|
756
1029
|
B_X_B = jaxsim.math.Adjoint.from_rotation_and_translation(
|
757
1030
|
translation=jnp.zeros(3), rotation=jnp.eye(3)
|
758
1031
|
)
|
@@ -765,7 +1038,6 @@ def generalized_free_floating_jacobian_derivative(
|
|
765
1038
|
Ṫ = jax.scipy.linalg.block_diag(B_Ẋ_B, On)
|
766
1039
|
|
767
1040
|
case VelRepr.Mixed:
|
768
|
-
|
769
1041
|
BW_H_B = W_H_B.at[0:3, 3].set(jnp.zeros(3))
|
770
1042
|
B_X_BW = jaxsim.math.Adjoint.from_transform(transform=BW_H_B, inverse=True)
|
771
1043
|
|
@@ -788,9 +1060,7 @@ def generalized_free_floating_jacobian_derivative(
|
|
788
1060
|
# ======================================================
|
789
1061
|
|
790
1062
|
match output_vel_repr:
|
791
|
-
|
792
1063
|
case VelRepr.Inertial:
|
793
|
-
|
794
1064
|
O_X_B = W_X_B = jaxsim.math.Adjoint.from_transform(transform=W_H_B)
|
795
1065
|
|
796
1066
|
with data.switch_velocity_representation(VelRepr.Body):
|
@@ -799,7 +1069,6 @@ def generalized_free_floating_jacobian_derivative(
|
|
799
1069
|
O_Ẋ_B = W_Ẋ_B = W_X_B @ jaxsim.math.Cross.vx(B_v_WB) # noqa: F841
|
800
1070
|
|
801
1071
|
case VelRepr.Body:
|
802
|
-
|
803
1072
|
O_X_B = L_X_B = jaxsim.math.Adjoint.from_transform(
|
804
1073
|
transform=B_H_L, inverse=True
|
805
1074
|
)
|
@@ -817,7 +1086,6 @@ def generalized_free_floating_jacobian_derivative(
|
|
817
1086
|
)
|
818
1087
|
|
819
1088
|
case VelRepr.Mixed:
|
820
|
-
|
821
1089
|
W_H_L = W_H_B @ B_H_L
|
822
1090
|
LW_H_L = W_H_L.at[:, 0:3, 3].set(jnp.zeros_like(W_H_L[:, 0:3, 3]))
|
823
1091
|
LW_H_B = LW_H_L @ jaxsim.math.Transform.inverse(B_H_L)
|
@@ -1190,14 +1458,12 @@ def free_floating_mass_matrix(
|
|
1190
1458
|
return M_body
|
1191
1459
|
|
1192
1460
|
case VelRepr.Inertial:
|
1193
|
-
|
1194
1461
|
B_X_W = Adjoint.from_transform(transform=data._base_transform, inverse=True)
|
1195
1462
|
invT = jax.scipy.linalg.block_diag(B_X_W, jnp.eye(model.dofs()))
|
1196
1463
|
|
1197
1464
|
return invT.T @ M_body @ invT
|
1198
1465
|
|
1199
1466
|
case VelRepr.Mixed:
|
1200
|
-
|
1201
1467
|
BW_H_B = data._base_transform.at[0:3, 3].set(jnp.zeros(3))
|
1202
1468
|
B_X_BW = Adjoint.from_transform(transform=BW_H_B, inverse=True)
|
1203
1469
|
invT = jax.scipy.linalg.block_diag(B_X_BW, jnp.eye(model.dofs()))
|
@@ -1233,7 +1499,6 @@ def free_floating_coriolis_matrix(
|
|
1233
1499
|
# The Coriolis matrix computed in this representation is converted later
|
1234
1500
|
# to the active representation stored in data.
|
1235
1501
|
with data.switch_velocity_representation(VelRepr.Body):
|
1236
|
-
|
1237
1502
|
B_ν = data.generalized_velocity
|
1238
1503
|
|
1239
1504
|
# Doubly-left free-floating Jacobian.
|
@@ -1251,7 +1516,6 @@ def free_floating_coriolis_matrix(
|
|
1251
1516
|
|
1252
1517
|
# Compute the contribution of each link to the Coriolis matrix.
|
1253
1518
|
def compute_link_contribution(M, v, J, J̇) -> jtp.Array:
|
1254
|
-
|
1255
1519
|
return J.T @ ((Cross.vx_star(v) @ M + M @ Cross.vx(v)) @ J + M @ J̇)
|
1256
1520
|
|
1257
1521
|
C_B_links = jax.vmap(compute_link_contribution)(
|
@@ -1274,12 +1538,10 @@ def free_floating_coriolis_matrix(
|
|
1274
1538
|
# Adjust the representation of the Coriolis matrix.
|
1275
1539
|
# Refer to https://github.com/traversaro/traversaro-phd-thesis, Section 3.6.
|
1276
1540
|
match data.velocity_representation:
|
1277
|
-
|
1278
1541
|
case VelRepr.Body:
|
1279
1542
|
return C_B
|
1280
1543
|
|
1281
1544
|
case VelRepr.Inertial:
|
1282
|
-
|
1283
1545
|
n = model.dofs()
|
1284
1546
|
W_H_B = data._base_transform
|
1285
1547
|
B_X_W = jaxsim.math.Adjoint.from_transform(W_H_B, inverse=True)
|
@@ -1299,7 +1561,6 @@ def free_floating_coriolis_matrix(
|
|
1299
1561
|
return C
|
1300
1562
|
|
1301
1563
|
case VelRepr.Mixed:
|
1302
|
-
|
1303
1564
|
n = model.dofs()
|
1304
1565
|
BW_H_B = data._base_transform.at[0:3, 3].set(jnp.zeros(3))
|
1305
1566
|
B_X_BW = jaxsim.math.Adjoint.from_transform(transform=BW_H_B, inverse=True)
|
@@ -1720,9 +1981,7 @@ def average_velocity_jacobian(
|
|
1720
1981
|
G_J = js.com.average_centroidal_velocity_jacobian(model=model, data=data)
|
1721
1982
|
|
1722
1983
|
match output_vel_repr:
|
1723
|
-
|
1724
1984
|
case VelRepr.Inertial:
|
1725
|
-
|
1726
1985
|
GW_J = G_J
|
1727
1986
|
W_p_CoM = js.com.com_position(model=model, data=data)
|
1728
1987
|
|
@@ -1732,7 +1991,6 @@ def average_velocity_jacobian(
|
|
1732
1991
|
return W_X_GW @ GW_J
|
1733
1992
|
|
1734
1993
|
case VelRepr.Body:
|
1735
|
-
|
1736
1994
|
GB_J = G_J
|
1737
1995
|
W_p_B = data.base_position
|
1738
1996
|
W_p_CoM = js.com.com_position(model=model, data=data)
|
@@ -1744,7 +2002,6 @@ def average_velocity_jacobian(
|
|
1744
2002
|
return B_X_GB @ GB_J
|
1745
2003
|
|
1746
2004
|
case VelRepr.Mixed:
|
1747
|
-
|
1748
2005
|
GW_J = G_J
|
1749
2006
|
W_p_B = data.base_position
|
1750
2007
|
W_p_CoM = js.com.com_position(model=model, data=data)
|
@@ -2039,6 +2296,104 @@ def potential_energy(model: JaxSimModel, data: js.data.JaxSimModelData) -> jtp.F
|
|
2039
2296
|
return jnp.sum((m * W_p̃_CoM)[2] * model.gravity)
|
2040
2297
|
|
2041
2298
|
|
2299
|
+
# ===================
|
2300
|
+
# Hw parametrization
|
2301
|
+
# ===================
|
2302
|
+
|
2303
|
+
|
2304
|
+
@jax.jit
|
2305
|
+
@js.common.named_scope
|
2306
|
+
def update_hw_parameters(
|
2307
|
+
model: JaxSimModel, scaling_factors: ScalingFactors
|
2308
|
+
) -> JaxSimModel:
|
2309
|
+
"""
|
2310
|
+
Update the hardware parameters of the model by scaling the parameters of the links.
|
2311
|
+
|
2312
|
+
This function applies scaling factors to the hardware metadata of the links,
|
2313
|
+
updating their shape, dimensions, density, and other related parameters. It
|
2314
|
+
recalculates the mass and inertia tensors of the links based on the updated
|
2315
|
+
metadata and adjusts the joint model transforms accordingly.
|
2316
|
+
|
2317
|
+
Args:
|
2318
|
+
model: The JaxSimModel object to update.
|
2319
|
+
scaling_factors: A ScalingFactors object containing scaling factors for
|
2320
|
+
dimensions and density of the links.
|
2321
|
+
|
2322
|
+
Returns:
|
2323
|
+
The updated JaxSimModel object with modified hardware parameters.
|
2324
|
+
"""
|
2325
|
+
kin_dyn_params: KinDynParameters = model.kin_dyn_parameters
|
2326
|
+
link_parameters: LinkParameters = kin_dyn_params.link_parameters
|
2327
|
+
hw_link_metadata: HwLinkMetadata = kin_dyn_params.hw_link_metadata
|
2328
|
+
|
2329
|
+
# Apply scaling to hw_link_metadata using vmap
|
2330
|
+
updated_hw_link_metadata = jax.vmap(HwLinkMetadata.apply_scaling)(
|
2331
|
+
hw_link_metadata, scaling_factors
|
2332
|
+
)
|
2333
|
+
|
2334
|
+
# Compute mass and inertia once and unpack the results
|
2335
|
+
m_updated, I_com_updated = jax.vmap(HwLinkMetadata.compute_mass_and_inertia)(
|
2336
|
+
updated_hw_link_metadata
|
2337
|
+
)
|
2338
|
+
|
2339
|
+
# Rotate the inertia tensor at CoM with the link orientation, and store
|
2340
|
+
# it in KynDynParameters.
|
2341
|
+
I_L_updated = jax.vmap(
|
2342
|
+
lambda metadata, I_com: metadata.L_H_G[:3, :3]
|
2343
|
+
@ I_com
|
2344
|
+
@ metadata.L_H_G[:3, :3].T
|
2345
|
+
)(updated_hw_link_metadata, I_com_updated)
|
2346
|
+
|
2347
|
+
# Update link parameters
|
2348
|
+
updated_link_parameters = link_parameters.replace(
|
2349
|
+
mass=m_updated,
|
2350
|
+
inertia_elements=jax.vmap(LinkParameters.flatten_inertia_tensor)(I_L_updated),
|
2351
|
+
center_of_mass=jax.vmap(lambda metadata: metadata.L_H_G[:3, 3])(
|
2352
|
+
updated_hw_link_metadata
|
2353
|
+
),
|
2354
|
+
)
|
2355
|
+
|
2356
|
+
# Update joint model transforms (λ_H_pre)
|
2357
|
+
def update_λ_H_pre(joint_index):
|
2358
|
+
# Extract the transforms and masks for the current joint index across all links
|
2359
|
+
L_H_pre_for_joint = updated_hw_link_metadata.L_H_pre[:, joint_index]
|
2360
|
+
L_H_pre_mask_for_joint = updated_hw_link_metadata.L_H_pre_mask[:, joint_index]
|
2361
|
+
|
2362
|
+
# Use the mask to select the first valid transform or fall back to the original
|
2363
|
+
valid_transforms = jnp.where(
|
2364
|
+
L_H_pre_mask_for_joint[:, None, None], # Expand mask for broadcasting
|
2365
|
+
L_H_pre_for_joint, # Use the transform if the mask is True
|
2366
|
+
jnp.zeros_like(L_H_pre_for_joint), # Otherwise, use a zero matrix
|
2367
|
+
)
|
2368
|
+
|
2369
|
+
# Sum the valid transforms (only one will be non-zero due to the mask)
|
2370
|
+
selected_transform = jnp.sum(valid_transforms, axis=0)
|
2371
|
+
|
2372
|
+
# If no valid transform exists, fall back to the original λ_H_pre
|
2373
|
+
return jax.lax.cond(
|
2374
|
+
jnp.any(L_H_pre_mask_for_joint),
|
2375
|
+
lambda: selected_transform,
|
2376
|
+
lambda: kin_dyn_params.joint_model.λ_H_pre[joint_index],
|
2377
|
+
)
|
2378
|
+
|
2379
|
+
# Apply the update function to all joint indices
|
2380
|
+
updated_λ_H_pre = jax.vmap(update_λ_H_pre)(
|
2381
|
+
jnp.arange(kin_dyn_params.number_of_joints() + 1)
|
2382
|
+
)
|
2383
|
+
# Replace the joint model with the updated transforms
|
2384
|
+
updated_joint_model = kin_dyn_params.joint_model.replace(λ_H_pre=updated_λ_H_pre)
|
2385
|
+
|
2386
|
+
# Replace the kin_dyn_parameters with updated values
|
2387
|
+
updated_kin_dyn_params = kin_dyn_params.replace(
|
2388
|
+
link_parameters=updated_link_parameters,
|
2389
|
+
hw_link_metadata=updated_hw_link_metadata,
|
2390
|
+
joint_model=updated_joint_model,
|
2391
|
+
)
|
2392
|
+
|
2393
|
+
# Return the updated model
|
2394
|
+
return model.replace(kin_dyn_parameters=updated_kin_dyn_params)
|
2395
|
+
|
2396
|
+
|
2042
2397
|
# ==========
|
2043
2398
|
# Simulation
|
2044
2399
|
# ==========
|
jaxsim/math/joint_model.py
CHANGED
@@ -10,10 +10,11 @@ import jaxsim.typing as jtp
|
|
10
10
|
from jaxsim.math import Rotation
|
11
11
|
from jaxsim.parsers.descriptions import JointGenericAxis, JointType, ModelDescription
|
12
12
|
from jaxsim.parsers.kinematic_graph import KinematicGraphTransforms
|
13
|
+
from jaxsim.utils.jaxsim_dataclass import JaxsimDataclass
|
13
14
|
|
14
15
|
|
15
16
|
@jax_dataclasses.pytree_dataclass
|
16
|
-
class JointModel:
|
17
|
+
class JointModel(JaxsimDataclass):
|
17
18
|
"""
|
18
19
|
Class describing the joint kinematics of a robot model.
|
19
20
|
|
jaxsim/mujoco/utils.py
CHANGED
@@ -71,11 +71,7 @@ def mujoco_data_from_jaxsim(
|
|
71
71
|
|
72
72
|
model_helper.set_joint_positions(
|
73
73
|
joint_names=list(jaxsim_model.joint_names()),
|
74
|
-
positions=np.array(
|
75
|
-
jaxsim_data.joint_positions(
|
76
|
-
model=jaxsim_model, joint_names=jaxsim_model.joint_names()
|
77
|
-
)
|
78
|
-
),
|
74
|
+
positions=np.array(jaxsim_data.joint_positions),
|
79
75
|
)
|
80
76
|
|
81
77
|
# Updating these joints is not necessary after the first time.
|
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.4
|
2
2
|
Name: jaxsim
|
3
|
-
Version: 0.6.2.
|
3
|
+
Version: 0.6.2.dev296
|
4
4
|
Summary: A differentiable physics engine and multibody dynamics library for control and robot learning.
|
5
5
|
Author-email: Diego Ferigo <dgferigo@gmail.com>, Filippo Luca Ferretti <filippoluca.ferretti@outlook.com>
|
6
6
|
Maintainer-email: Filippo Luca Ferretti <filippo.ferretti@iit.it>, Alessandro Croci <alessandro.croci@iit.it>
|
@@ -1,5 +1,5 @@
|
|
1
1
|
jaxsim/__init__.py,sha256=EKeysKN-7UswwJLCl7n6qIBKQIVUtYsCMYu_tCoFn7g,3628
|
2
|
-
jaxsim/_version.py,sha256=
|
2
|
+
jaxsim/_version.py,sha256=Bzgp_Dut_GzvkRH6pABuIcblUVxBbl7RquWxbbiRawk,528
|
3
3
|
jaxsim/exceptions.py,sha256=MQ3LRMfVMX2-g3qYj7mUVNV9OLlIA48TANJegbcQyXI,2641
|
4
4
|
jaxsim/logging.py,sha256=STI-D_upXZYX-ZezLrlJJ0UlD5YspST0vZ_DcIwkzO4,1553
|
5
5
|
jaxsim/typing.py,sha256=7msl8t5Jt09RNYfKdPJtpjLfWurldcycDappb045Eso,761
|
@@ -12,16 +12,16 @@ jaxsim/api/data.py,sha256=9pxug2gFIZPwqi9kNYXhEziA5IQBB9KNNwIfPfc_kAU,23249
|
|
12
12
|
jaxsim/api/frame.py,sha256=4wg6GsyBQgYhSvc-ry_31JsaL66sZt3TtgwjB7NrHmk,14583
|
13
13
|
jaxsim/api/integrators.py,sha256=sHdTWw2Z-Q7jggA8zRkA1KYYd4HNIozXPwNvFwt0YBs,9011
|
14
14
|
jaxsim/api/joint.py,sha256=AnqlNWmBOay-gsoo0y4AbfFQ2OCJm-8T1E0IMhZeLoY,7457
|
15
|
-
jaxsim/api/kin_dyn_parameters.py,sha256=
|
15
|
+
jaxsim/api/kin_dyn_parameters.py,sha256=r1rOjKfe7Qg7xBWyCNtlAi_Wg1WTzToQrUdbZ7PO10Q,39073
|
16
16
|
jaxsim/api/link.py,sha256=bSZOYJDY9HJMgY25VzevTTsdFZTJc6yRRpslc5FhGHE,12782
|
17
|
-
jaxsim/api/model.py,sha256=
|
17
|
+
jaxsim/api/model.py,sha256=ppdriJBHWJ-qXey9Vjqnd7IjqdEE0R6W9ZG9Y7KuS2s,85460
|
18
18
|
jaxsim/api/ode.py,sha256=fp20_LK9lXw2DfNkQgrfJmtd_iBXDNzZkOn0u5Pm8Qw,6190
|
19
19
|
jaxsim/api/references.py,sha256=-vd50y3v-jkXAsILS432etIKV6e2EUE2oOeLHuUrfuQ,20399
|
20
20
|
jaxsim/math/__init__.py,sha256=dNozvtm8WsB7nxw4uK29yQQKPcDUEczr2zcHoZfMItc,383
|
21
21
|
jaxsim/math/adjoint.py,sha256=Pb0WAiAoN1ge8j_dPcdK307jmC5LzD1-DeUj9Z_NXkI,4667
|
22
22
|
jaxsim/math/cross.py,sha256=AM4HauuuT09q2TN42qvdXhJ9LvtCh0e7ZyLjP-7sANs,1498
|
23
23
|
jaxsim/math/inertia.py,sha256=T-iAjPYSD_72R0ZG8GDJhe5i3Jc3ojhlbBRSscTdCKg,1577
|
24
|
-
jaxsim/math/joint_model.py,sha256=
|
24
|
+
jaxsim/math/joint_model.py,sha256=vBnwXSsw2LCb2Tr5wl2iCo0KvLqcibBbeKcsoH5r9tk,6990
|
25
25
|
jaxsim/math/quaternion.py,sha256=MSaZywzJDxs2te1ZELeIcupKSFIA9q_pdXy7fDAEqM4,4539
|
26
26
|
jaxsim/math/rotation.py,sha256=TEUtT3X2tFieNxdlccup1pfaTgCTtfX-hTNotd8-nNk,1892
|
27
27
|
jaxsim/math/skew.py,sha256=z_9YN-NDHL3n4KXWNbzTSMkFDZ0SDpz4RUcwwYFOaao,1402
|
@@ -31,7 +31,7 @@ jaxsim/mujoco/__init__.py,sha256=1kAWzYOS7nP29S5FGyWPqiAnPf4yPSoaPW-WBGBjVV0,214
|
|
31
31
|
jaxsim/mujoco/__main__.py,sha256=GBmB7J-zj75ZnFyuAAmpSOpbxi_HhHhWJeot3ljGDJY,5291
|
32
32
|
jaxsim/mujoco/loaders.py,sha256=OCk1T11iIm3qZUibNpo_bxxLgaGSkCpLt7ae_ND0ExA,23272
|
33
33
|
jaxsim/mujoco/model.py,sha256=bRXo1uhWDN37sP9qdejr_2vq_PXHQ7p6eyBlFff_JcE,16492
|
34
|
-
jaxsim/mujoco/utils.py,sha256=
|
34
|
+
jaxsim/mujoco/utils.py,sha256=q75OSjxLU2BromVUejt0DVnSbrV5D177YW6LkOdu78g,8823
|
35
35
|
jaxsim/mujoco/visualizer.py,sha256=cmI6DhFb1XC7oEtg_wl-s-U56dWHA-F7GlBD6EDYTyA,7744
|
36
36
|
jaxsim/parsers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
37
37
|
jaxsim/parsers/kinematic_graph.py,sha256=ARq11Pv6yMDZeRRDlqrWzfVfUS7qSDwR57aWA4k54as,35758
|
@@ -65,8 +65,8 @@ jaxsim/utils/__init__.py,sha256=Y5zyoRevl3EMVQadhZ4EtSwTEkDt2vcnFoRhPJjKTZ0,215
|
|
65
65
|
jaxsim/utils/jaxsim_dataclass.py,sha256=XzmZeIibcaOzaxpprsGSxH3UrM66PAO456rFV91sNXg,11453
|
66
66
|
jaxsim/utils/tracing.py,sha256=Btwxdfhb7fJLk3r5PlQkGYj60Y2KbFT1gANGIA697FU,530
|
67
67
|
jaxsim/utils/wrappers.py,sha256=3IMwydqFgmSPqeuUQ3PRmdhDc1IoT6XC23jPC_LjWXs,4175
|
68
|
-
jaxsim-0.6.2.
|
69
|
-
jaxsim-0.6.2.
|
70
|
-
jaxsim-0.6.2.
|
71
|
-
jaxsim-0.6.2.
|
72
|
-
jaxsim-0.6.2.
|
68
|
+
jaxsim-0.6.2.dev296.dist-info/licenses/LICENSE,sha256=eaYdFmdeMbiIoIiPzEK0MjP1S9wtFXjXNR5er49uLR0,1546
|
69
|
+
jaxsim-0.6.2.dev296.dist-info/METADATA,sha256=4jgGP0Aq1uGxo0AybZcSHs9U8Xqxq-noHcs6ctLrFWM,19658
|
70
|
+
jaxsim-0.6.2.dev296.dist-info/WHEEL,sha256=pxyMxgL8-pra_rKaQ4drOZAegBVuX-G_4nRHjjgWbmo,91
|
71
|
+
jaxsim-0.6.2.dev296.dist-info/top_level.txt,sha256=LxGMA8FLtXjQ6oI7N5gd_R_oSUHxpXxUEOfT1xS_ni0,7
|
72
|
+
jaxsim-0.6.2.dev296.dist-info/RECORD,,
|
File without changes
|
File without changes
|