jaxsim 0.2.1.dev123__py3-none-any.whl → 0.2.1.dev155__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jaxsim/__init__.py +39 -7
- jaxsim/_version.py +2 -2
- jaxsim/api/contact.py +4 -4
- jaxsim/api/kin_dyn_parameters.py +4 -1
- jaxsim/api/link.py +2 -2
- jaxsim/api/model.py +137 -25
- jaxsim/api/ode.py +15 -2
- jaxsim/math/joint_model.py +2 -2
- jaxsim/mujoco/loaders.py +4 -5
- jaxsim/mujoco/model.py +1 -1
- jaxsim/mujoco/visualizer.py +1 -1
- jaxsim/parsers/descriptions/collision.py +58 -23
- jaxsim/parsers/descriptions/joint.py +39 -17
- jaxsim/parsers/descriptions/link.py +9 -6
- jaxsim/parsers/descriptions/model.py +31 -11
- jaxsim/parsers/kinematic_graph.py +36 -11
- jaxsim/parsers/rod/parser.py +1 -1
- jaxsim/parsers/rod/utils.py +5 -7
- jaxsim/rbda/crba.py +2 -2
- jaxsim/rbda/forward_kinematics.py +1 -1
- jaxsim/typing.py +4 -1
- {jaxsim-0.2.1.dev123.dist-info → jaxsim-0.2.1.dev155.dist-info}/METADATA +5 -5
- {jaxsim-0.2.1.dev123.dist-info → jaxsim-0.2.1.dev155.dist-info}/RECORD +26 -26
- {jaxsim-0.2.1.dev123.dist-info → jaxsim-0.2.1.dev155.dist-info}/LICENSE +0 -0
- {jaxsim-0.2.1.dev123.dist-info → jaxsim-0.2.1.dev155.dist-info}/WHEEL +0 -0
- {jaxsim-0.2.1.dev123.dist-info → jaxsim-0.2.1.dev155.dist-info}/top_level.txt +0 -0
jaxsim/__init__.py
CHANGED
@@ -8,8 +8,9 @@ def _jnp_options() -> None:
|
|
8
8
|
|
9
9
|
import jax
|
10
10
|
|
11
|
-
# Enable by default
|
12
|
-
if
|
11
|
+
# Enable by default 64bit precision in JAX.
|
12
|
+
if os.environ.get("JAX_ENABLE_X64", "1") != "0":
|
13
|
+
|
13
14
|
logging.info("Enabling JAX to use 64bit precision")
|
14
15
|
jax.config.update("jax_enable_x64", True)
|
15
16
|
|
@@ -27,6 +28,7 @@ def _np_options() -> None:
|
|
27
28
|
|
28
29
|
|
29
30
|
def _is_editable() -> bool:
|
31
|
+
|
30
32
|
import importlib.util
|
31
33
|
import pathlib
|
32
34
|
import site
|
@@ -45,11 +47,40 @@ def _is_editable() -> bool:
|
|
45
47
|
return jaxsim_package_dir not in site.getsitepackages()
|
46
48
|
|
47
49
|
|
48
|
-
|
49
|
-
|
50
|
-
logging
|
51
|
-
|
52
|
-
|
50
|
+
def _get_default_logging_level(env_var: str) -> logging.LoggingLevel:
|
51
|
+
"""
|
52
|
+
Get the default logging level.
|
53
|
+
|
54
|
+
Args:
|
55
|
+
env_var: The environment variable to check.
|
56
|
+
|
57
|
+
Returns:
|
58
|
+
The logging level to set.
|
59
|
+
"""
|
60
|
+
|
61
|
+
import os
|
62
|
+
|
63
|
+
# Define the default logging level depending on the installation mode.
|
64
|
+
default_logging_level = (
|
65
|
+
logging.LoggingLevel.DEBUG
|
66
|
+
if _is_editable() # noqa: F821
|
67
|
+
else logging.LoggingLevel.WARNING
|
68
|
+
)
|
69
|
+
|
70
|
+
# Allow to override the default logging level with an environment variable.
|
71
|
+
try:
|
72
|
+
return logging.LoggingLevel[
|
73
|
+
os.environ.get(env_var, default_logging_level.name).upper()
|
74
|
+
]
|
75
|
+
|
76
|
+
except KeyError as exc:
|
77
|
+
msg = f"Invalid logging level defined in {env_var}='{os.environ[env_var]}'"
|
78
|
+
raise RuntimeError(msg) from exc
|
79
|
+
|
80
|
+
|
81
|
+
# Configure the logger with the default logging level.
|
82
|
+
logging.configure(level=_get_default_logging_level(env_var="JAXSIM_LOGGING_LEVEL"))
|
83
|
+
|
53
84
|
|
54
85
|
# Configure JAX
|
55
86
|
_jnp_options()
|
@@ -59,6 +90,7 @@ _np_options()
|
|
59
90
|
|
60
91
|
del _jnp_options
|
61
92
|
del _np_options
|
93
|
+
del _get_default_logging_level
|
62
94
|
del _is_editable
|
63
95
|
|
64
96
|
from . import terrain # isort:skip
|
jaxsim/_version.py
CHANGED
@@ -12,5 +12,5 @@ __version__: str
|
|
12
12
|
__version_tuple__: VERSION_TUPLE
|
13
13
|
version_tuple: VERSION_TUPLE
|
14
14
|
|
15
|
-
__version__ = version = '0.2.1.
|
16
|
-
__version_tuple__ = version_tuple = (0, 2, 1, '
|
15
|
+
__version__ = version = '0.2.1.dev155'
|
16
|
+
__version_tuple__ = version_tuple = (0, 2, 1, 'dev155')
|
jaxsim/api/contact.py
CHANGED
@@ -365,20 +365,20 @@ def jacobian(
|
|
365
365
|
|
366
366
|
W_H_C = transforms(model=model, data=data)
|
367
367
|
|
368
|
-
def
|
368
|
+
def body_jacobian(W_H_C: jtp.Matrix, W_J_WC: jtp.Matrix) -> jtp.Matrix:
|
369
369
|
C_X_W = jaxsim.math.Adjoint.from_transform(
|
370
370
|
transform=W_H_C, inverse=True
|
371
371
|
)
|
372
372
|
C_J_WC = C_X_W @ W_J_WC
|
373
373
|
return C_J_WC
|
374
374
|
|
375
|
-
O_J_WC = jax.vmap(
|
375
|
+
O_J_WC = jax.vmap(body_jacobian)(W_H_C, W_J_WC)
|
376
376
|
|
377
377
|
case VelRepr.Mixed:
|
378
378
|
|
379
379
|
W_H_C = transforms(model=model, data=data)
|
380
380
|
|
381
|
-
def
|
381
|
+
def mixed_jacobian(W_H_C: jtp.Matrix, W_J_WC: jtp.Matrix) -> jtp.Matrix:
|
382
382
|
|
383
383
|
W_H_CW = W_H_C.at[0:3, 0:3].set(jnp.eye(3))
|
384
384
|
|
@@ -389,7 +389,7 @@ def jacobian(
|
|
389
389
|
CW_J_WC = CW_X_W @ W_J_WC
|
390
390
|
return CW_J_WC
|
391
391
|
|
392
|
-
O_J_WC = jax.vmap(
|
392
|
+
O_J_WC = jax.vmap(mixed_jacobian)(W_H_C, W_J_WC)
|
393
393
|
|
394
394
|
case _:
|
395
395
|
raise ValueError(output_vel_repr)
|
jaxsim/api/kin_dyn_parameters.py
CHANGED
@@ -6,6 +6,7 @@ import jax.lax
|
|
6
6
|
import jax.numpy as jnp
|
7
7
|
import jax_dataclasses
|
8
8
|
import jaxlie
|
9
|
+
import numpy as np
|
9
10
|
from jax_dataclasses import Static
|
10
11
|
|
11
12
|
import jaxsim.typing as jtp
|
@@ -220,7 +221,9 @@ class KynDynParameters(JaxsimDataclass):
|
|
220
221
|
(
|
221
222
|
hash(self.number_of_links()),
|
222
223
|
hash(self.number_of_joints()),
|
223
|
-
hash(tuple(
|
224
|
+
hash(tuple(np.atleast_1d(self.parent_array).flatten().tolist())),
|
225
|
+
hash(self._parent_array),
|
226
|
+
hash(self._support_body_array_bool),
|
224
227
|
)
|
225
228
|
)
|
226
229
|
|
jaxsim/api/link.py
CHANGED
@@ -241,8 +241,8 @@ def jacobian(
|
|
241
241
|
)
|
242
242
|
|
243
243
|
# Compute the actual doubly-left free-floating jacobian of the link.
|
244
|
-
κ = model.kin_dyn_parameters.support_body_array_bool[link_index]
|
245
|
-
B_J_WL_B = jnp.hstack([jnp.ones(5), κ]) * B_J_full_WX_B
|
244
|
+
κb = model.kin_dyn_parameters.support_body_array_bool[link_index]
|
245
|
+
B_J_WL_B = jnp.hstack([jnp.ones(5), κb]) * B_J_full_WX_B
|
246
246
|
|
247
247
|
# Adjust the input representation such that `J_WL_I @ I_ν`.
|
248
248
|
match data.velocity_representation:
|
jaxsim/api/model.py
CHANGED
@@ -16,7 +16,8 @@ from jax_dataclasses import Static
|
|
16
16
|
import jaxsim.api as js
|
17
17
|
import jaxsim.parsers.descriptions
|
18
18
|
import jaxsim.typing as jtp
|
19
|
-
from jaxsim.
|
19
|
+
from jaxsim.math import Cross
|
20
|
+
from jaxsim.utils import JaxsimDataclass, Mutability
|
20
21
|
|
21
22
|
from .common import VelRepr
|
22
23
|
|
@@ -32,6 +33,7 @@ class JaxSimModel(JaxsimDataclass):
|
|
32
33
|
terrain: Static[jaxsim.terrain.Terrain] = dataclasses.field(
|
33
34
|
default=jaxsim.terrain.FlatTerrain(), repr=False, compare=False, hash=False
|
34
35
|
)
|
36
|
+
|
35
37
|
kin_dyn_parameters: js.kin_dyn_parameters.KynDynParameters | None = (
|
36
38
|
dataclasses.field(default=None, repr=False, compare=False, hash=False)
|
37
39
|
)
|
@@ -40,13 +42,9 @@ class JaxSimModel(JaxsimDataclass):
|
|
40
42
|
default=None, repr=False, compare=False, hash=False
|
41
43
|
)
|
42
44
|
|
43
|
-
|
44
|
-
|
45
|
-
|
46
|
-
|
47
|
-
@property
|
48
|
-
def description(self) -> jaxsim.parsers.descriptions.ModelDescription:
|
49
|
-
return self._description.get()
|
45
|
+
description: Static[jaxsim.parsers.descriptions.ModelDescription | None] = (
|
46
|
+
dataclasses.field(default=None, repr=False, compare=False, hash=False)
|
47
|
+
)
|
50
48
|
|
51
49
|
def __eq__(self, other: JaxSimModel) -> bool:
|
52
50
|
|
@@ -60,6 +58,7 @@ class JaxSimModel(JaxsimDataclass):
|
|
60
58
|
return hash(
|
61
59
|
(
|
62
60
|
hash(self.model_name),
|
61
|
+
hash(self.description),
|
63
62
|
hash(self.kin_dyn_parameters),
|
64
63
|
)
|
65
64
|
)
|
@@ -156,7 +155,7 @@ class JaxSimModel(JaxsimDataclass):
|
|
156
155
|
# Build the model
|
157
156
|
model = JaxSimModel(
|
158
157
|
model_name=model_name,
|
159
|
-
|
158
|
+
description=model_description,
|
160
159
|
kin_dyn_parameters=js.kin_dyn_parameters.KynDynParameters.build(
|
161
160
|
model_description=model_description
|
162
161
|
),
|
@@ -301,7 +300,7 @@ def reduce(
|
|
301
300
|
locked_joint_positions:
|
302
301
|
A dictionary containing the positions of the joints to be considered
|
303
302
|
in the reduction process. The removed joints in the reduced model
|
304
|
-
will have their position locked to their value
|
303
|
+
will have their position locked to their value of this dictionary.
|
305
304
|
If a joint is not part of the dictionary, its position is set to zero.
|
306
305
|
"""
|
307
306
|
|
@@ -314,10 +313,9 @@ def reduce(
|
|
314
313
|
new_joints = set(model.joint_names()) - set(locked_joint_positions)
|
315
314
|
raise ValueError(f"Passed joints not existing in the model: {new_joints}")
|
316
315
|
|
317
|
-
#
|
318
|
-
|
319
|
-
|
320
|
-
)
|
316
|
+
# Operate on a deep copy of the model description in order to prevent problems
|
317
|
+
# when mutable attributes are updated.
|
318
|
+
intermediate_description = copy.deepcopy(model.description)
|
321
319
|
|
322
320
|
# Update the initial position of the joints.
|
323
321
|
# This is necessary to compute the correct pose of the link pairs connected
|
@@ -685,8 +683,6 @@ def forward_dynamics_aba(
|
|
685
683
|
another representation C_v̇_WB expressed in a generic frame C.
|
686
684
|
"""
|
687
685
|
|
688
|
-
from jaxsim.math import Cross
|
689
|
-
|
690
686
|
# In Mixed representation, we need to include a cross product in ℝ⁶.
|
691
687
|
# In Inertial and Body representations, the cross product is always zero.
|
692
688
|
C_X_W = jaxlie.SE3.from_matrix(W_H_C).inverse().adjoint()
|
@@ -871,6 +867,126 @@ def free_floating_mass_matrix(
|
|
871
867
|
raise ValueError(data.velocity_representation)
|
872
868
|
|
873
869
|
|
870
|
+
@jax.jit
|
871
|
+
def free_floating_coriolis_matrix(
|
872
|
+
model: JaxSimModel, data: js.data.JaxSimModelData
|
873
|
+
) -> jtp.Matrix:
|
874
|
+
"""
|
875
|
+
Compute the free-floating Coriolis matrix of the model.
|
876
|
+
|
877
|
+
Args:
|
878
|
+
model: The model to consider.
|
879
|
+
data: The data of the considered model.
|
880
|
+
|
881
|
+
Returns:
|
882
|
+
The free-floating Coriolis matrix of the model.
|
883
|
+
|
884
|
+
Note:
|
885
|
+
This function, contrarily to other quantities of the equations of motion,
|
886
|
+
does not exploit any iterative algorithm. Therefore, the computation of
|
887
|
+
the Coriolis matrix may be much slower than other quantities.
|
888
|
+
"""
|
889
|
+
|
890
|
+
# We perform all the calculation in body-fixed.
|
891
|
+
# The Coriolis matrix computed in this representation is converted later
|
892
|
+
# to the active representation stored in data.
|
893
|
+
with data.switch_velocity_representation(VelRepr.Body):
|
894
|
+
|
895
|
+
B_ν = data.generalized_velocity()
|
896
|
+
|
897
|
+
# Doubly-left free-floating Jacobian.
|
898
|
+
L_J_WL_B = generalized_free_floating_jacobian(model=model, data=data)
|
899
|
+
|
900
|
+
# Doubly-left free-floating Jacobian derivative.
|
901
|
+
L_J̇_WL_B = jax.vmap(
|
902
|
+
lambda link_index: js.link.jacobian_derivative(
|
903
|
+
model=model, data=data, link_index=link_index
|
904
|
+
)
|
905
|
+
)(js.link.names_to_idxs(model=model, link_names=model.link_names()))
|
906
|
+
|
907
|
+
L_M_L = link_spatial_inertia_matrices(model=model)
|
908
|
+
|
909
|
+
# Body-fixed link velocities.
|
910
|
+
# Note: we could have called link.velocity() instead of computing it ourselves,
|
911
|
+
# but since we need the link Jacobians later, we can save a double calculation.
|
912
|
+
L_v_WL = jax.vmap(lambda J: J @ B_ν)(L_J_WL_B)
|
913
|
+
|
914
|
+
# Compute the contribution of each link to the Coriolis matrix.
|
915
|
+
def compute_link_contribution(M, v, J, J̇) -> jtp.Array:
|
916
|
+
|
917
|
+
return J.T @ ((Cross.vx_star(v) @ M + M @ Cross.vx(v)) @ J + M @ J̇)
|
918
|
+
|
919
|
+
C_B_links = jax.vmap(compute_link_contribution)(
|
920
|
+
L_M_L,
|
921
|
+
L_v_WL,
|
922
|
+
L_J_WL_B,
|
923
|
+
L_J̇_WL_B,
|
924
|
+
)
|
925
|
+
|
926
|
+
# We need to adjust the Coriolis matrix for fixed-base models.
|
927
|
+
# In this case, the base link does not contribute to the matrix, and we need to zero
|
928
|
+
# the off-diagonal terms mapping joint quantities onto the base configuration.
|
929
|
+
if model.floating_base():
|
930
|
+
C_B = C_B_links.sum(axis=0)
|
931
|
+
else:
|
932
|
+
C_B = C_B_links[1:].sum(axis=0)
|
933
|
+
C_B = C_B.at[0:6, 6:].set(0.0)
|
934
|
+
C_B = C_B.at[6:, 0:6].set(0.0)
|
935
|
+
|
936
|
+
# Adjust the representation of the Coriolis matrix.
|
937
|
+
# Refer to https://github.com/traversaro/traversaro-phd-thesis, Section 3.6.
|
938
|
+
match data.velocity_representation:
|
939
|
+
|
940
|
+
case VelRepr.Body:
|
941
|
+
return C_B
|
942
|
+
|
943
|
+
case VelRepr.Inertial:
|
944
|
+
|
945
|
+
n = model.dofs()
|
946
|
+
W_H_B = data.base_transform()
|
947
|
+
B_X_W = jaxsim.math.Adjoint.from_transform(W_H_B, inverse=True)
|
948
|
+
B_T_W = jax.scipy.linalg.block_diag(B_X_W, jnp.eye(n))
|
949
|
+
|
950
|
+
with data.switch_velocity_representation(VelRepr.Inertial):
|
951
|
+
W_v_WB = data.base_velocity()
|
952
|
+
B_Ẋ_W = -B_X_W @ jaxsim.math.Cross.vx(W_v_WB)
|
953
|
+
|
954
|
+
B_Ṫ_W = jax.scipy.linalg.block_diag(B_Ẋ_W, jnp.zeros(shape=(n, n)))
|
955
|
+
|
956
|
+
with data.switch_velocity_representation(VelRepr.Body):
|
957
|
+
M = free_floating_mass_matrix(model=model, data=data)
|
958
|
+
|
959
|
+
C = B_T_W.T @ (M @ B_Ṫ_W + C_B @ B_T_W)
|
960
|
+
|
961
|
+
return C
|
962
|
+
|
963
|
+
case VelRepr.Mixed:
|
964
|
+
|
965
|
+
n = model.dofs()
|
966
|
+
BW_H_B = data.base_transform().at[0:3, 3].set(jnp.zeros(3))
|
967
|
+
B_X_BW = jaxsim.math.Adjoint.from_transform(transform=BW_H_B, inverse=True)
|
968
|
+
B_T_BW = jax.scipy.linalg.block_diag(B_X_BW, jnp.eye(n))
|
969
|
+
|
970
|
+
with data.switch_velocity_representation(VelRepr.Mixed):
|
971
|
+
BW_v_WB = data.base_velocity()
|
972
|
+
BW_v_W_BW = BW_v_WB.at[3:6].set(jnp.zeros(3))
|
973
|
+
|
974
|
+
BW_v_BW_B = BW_v_WB - BW_v_W_BW
|
975
|
+
B_Ẋ_BW = -B_X_BW @ jaxsim.math.Cross.vx(BW_v_BW_B)
|
976
|
+
|
977
|
+
B_Ṫ_BW = jax.scipy.linalg.block_diag(B_Ẋ_BW, jnp.zeros(shape=(n, n)))
|
978
|
+
|
979
|
+
with data.switch_velocity_representation(VelRepr.Body):
|
980
|
+
M = free_floating_mass_matrix(model=model, data=data)
|
981
|
+
|
982
|
+
C = B_T_BW.T @ (M @ B_Ṫ_BW + C_B @ B_T_BW)
|
983
|
+
|
984
|
+
return C
|
985
|
+
|
986
|
+
case _:
|
987
|
+
raise ValueError(data.velocity_representation)
|
988
|
+
|
989
|
+
|
874
990
|
@jax.jit
|
875
991
|
def inverse_dynamics(
|
876
992
|
model: JaxSimModel,
|
@@ -931,8 +1047,6 @@ def inverse_dynamics(
|
|
931
1047
|
expressed in a generic frame C to the inertial-fixed representation W_v̇_WB.
|
932
1048
|
"""
|
933
1049
|
|
934
|
-
from jaxsim.math import Cross
|
935
|
-
|
936
1050
|
W_X_C = jaxlie.SE3.from_matrix(W_H_C).adjoint()
|
937
1051
|
C_X_W = jaxlie.SE3.from_matrix(W_H_C).inverse().adjoint()
|
938
1052
|
C_v_WC = C_X_W @ W_v_WC
|
@@ -1364,12 +1478,7 @@ def link_bias_accelerations(
|
|
1364
1478
|
# ================================================
|
1365
1479
|
|
1366
1480
|
# Compute the base transform.
|
1367
|
-
W_H_B =
|
1368
|
-
rotation=jaxlie.SO3.from_quaternion_xyzw(
|
1369
|
-
xyzw=jaxsim.math.Quaternion.to_xyzw(wxyz=data.base_orientation())
|
1370
|
-
),
|
1371
|
-
translation=data.base_position(),
|
1372
|
-
).as_matrix()
|
1481
|
+
W_H_B = data.base_transform()
|
1373
1482
|
|
1374
1483
|
def other_representation_to_inertial(
|
1375
1484
|
C_v̇_WB: jtp.Vector, C_v_WB: jtp.Vector, W_H_C: jtp.Matrix, W_v_WC: jtp.Vector
|
@@ -1410,9 +1519,12 @@ def link_bias_accelerations(
|
|
1410
1519
|
W_H_C = W_H_BW
|
1411
1520
|
with data.switch_velocity_representation(VelRepr.Mixed):
|
1412
1521
|
W_ṗ_B = data.base_velocity()[0:3]
|
1413
|
-
|
1522
|
+
BW_v_W_BW = jnp.zeros(6).at[0:3].set(W_ṗ_B)
|
1523
|
+
W_X_BW = jaxsim.math.Adjoint.from_transform(transform=W_H_BW)
|
1524
|
+
W_v_WC = W_v_W_BW = W_X_BW @ BW_v_W_BW
|
1414
1525
|
with data.switch_velocity_representation(VelRepr.Mixed):
|
1415
1526
|
C_v_WB = BW_v_WB = data.base_velocity()
|
1527
|
+
|
1416
1528
|
case _:
|
1417
1529
|
raise ValueError(data.velocity_representation)
|
1418
1530
|
|
jaxsim/api/ode.py
CHANGED
@@ -223,7 +223,9 @@ def system_velocity_dynamics(
|
|
223
223
|
|
224
224
|
@jax.jit
|
225
225
|
def system_position_dynamics(
|
226
|
-
model: js.model.JaxSimModel,
|
226
|
+
model: js.model.JaxSimModel,
|
227
|
+
data: js.data.JaxSimModelData,
|
228
|
+
baumgarte_quaternion_regularization: jtp.FloatLike = 1.0,
|
227
229
|
) -> tuple[jtp.Vector, jtp.Vector, jtp.Vector]:
|
228
230
|
"""
|
229
231
|
Compute the dynamics of the system position.
|
@@ -231,6 +233,8 @@ def system_position_dynamics(
|
|
231
233
|
Args:
|
232
234
|
model: The model to consider.
|
233
235
|
data: The data of the considered model.
|
236
|
+
baumgarte_quaternion_regularization:
|
237
|
+
The Baumgarte regularization coefficient for adjusting the quaternion norm.
|
234
238
|
|
235
239
|
Returns:
|
236
240
|
A tuple containing the derivative of the base position, the derivative of the
|
@@ -250,6 +254,7 @@ def system_position_dynamics(
|
|
250
254
|
quaternion=W_Q_B,
|
251
255
|
omega=W_ω_WB,
|
252
256
|
omega_in_body_fixed=False,
|
257
|
+
K=baumgarte_quaternion_regularization,
|
253
258
|
).squeeze()
|
254
259
|
|
255
260
|
return W_ṗ_B, W_Q̇_B, ṡ
|
@@ -262,6 +267,7 @@ def system_dynamics(
|
|
262
267
|
*,
|
263
268
|
joint_forces: jtp.Vector | None = None,
|
264
269
|
link_forces: jtp.Vector | None = None,
|
270
|
+
baumgarte_quaternion_regularization: jtp.FloatLike = 1.0,
|
265
271
|
) -> tuple[ODEState, dict[str, Any]]:
|
266
272
|
"""
|
267
273
|
Compute the dynamics of the system.
|
@@ -271,6 +277,9 @@ def system_dynamics(
|
|
271
277
|
data: The data of the considered model.
|
272
278
|
joint_forces: The joint forces to apply.
|
273
279
|
link_forces: The 6D forces to apply to the links.
|
280
|
+
baumgarte_quaternion_regularization:
|
281
|
+
The Baumgarte regularization coefficient used to adjust the norm of the
|
282
|
+
quaternion (only used in integrators not operating on the SO(3) manifold).
|
274
283
|
|
275
284
|
Returns:
|
276
285
|
A tuple with an `ODEState` object storing in each of its attributes the
|
@@ -287,7 +296,11 @@ def system_dynamics(
|
|
287
296
|
)
|
288
297
|
|
289
298
|
# Extract the velocities.
|
290
|
-
W_ṗ_B, W_Q̇_B, ṡ = system_position_dynamics(
|
299
|
+
W_ṗ_B, W_Q̇_B, ṡ = system_position_dynamics(
|
300
|
+
model=model,
|
301
|
+
data=data,
|
302
|
+
baumgarte_quaternion_regularization=baumgarte_quaternion_regularization,
|
303
|
+
)
|
291
304
|
|
292
305
|
# Create an ODEState object populated with the derivative of each leaf.
|
293
306
|
# Our integrators, operating on generic pytrees, will be able to handle it
|
jaxsim/math/joint_model.py
CHANGED
@@ -34,8 +34,8 @@ class JointModel:
|
|
34
34
|
already in a vectorized form. In other words, it cannot be created using vmap.
|
35
35
|
"""
|
36
36
|
|
37
|
-
λ_H_pre:
|
38
|
-
suc_H_i:
|
37
|
+
λ_H_pre: jtp.Array
|
38
|
+
suc_H_i: jtp.Array
|
39
39
|
|
40
40
|
joint_dofs: Static[tuple[int, ...]]
|
41
41
|
joint_names: Static[tuple[str, ...]]
|
jaxsim/mujoco/loaders.py
CHANGED
@@ -188,10 +188,9 @@ class RodModelToMjcf:
|
|
188
188
|
)
|
189
189
|
|
190
190
|
# If considered joints are passed, make sure that they are all part of the model.
|
191
|
-
if considered_joints -
|
192
|
-
extra_joints = set(considered_joints) -
|
193
|
-
|
194
|
-
)
|
191
|
+
if considered_joints - {j.name for j in rod_model.joints()}:
|
192
|
+
extra_joints = set(considered_joints) - {j.name for j in rod_model.joints()}
|
193
|
+
|
195
194
|
msg = f"Couldn't find the following joints in the model: '{extra_joints}'"
|
196
195
|
raise ValueError(msg)
|
197
196
|
|
@@ -352,7 +351,7 @@ class RodModelToMjcf:
|
|
352
351
|
# Set alpha=0 to the color of all collision elements
|
353
352
|
for geometry_element in mujoco_element.findall(".//geom[@rgba]"):
|
354
353
|
if geometry_element.attrib.get("name") in collision_names:
|
355
|
-
r, g, b,
|
354
|
+
r, g, b, _ = geometry_element.attrib["rgba"].split(" ")
|
356
355
|
geometry_element.set("rgba", f"{r} {g} {b} 0")
|
357
356
|
|
358
357
|
# -----------------------
|
jaxsim/mujoco/model.py
CHANGED
@@ -73,7 +73,7 @@ class MujocoModelHelper:
|
|
73
73
|
new_hfield = generate_hfield(heightmap, (nrow, ncol))
|
74
74
|
model.hfield_data = new_hfield
|
75
75
|
|
76
|
-
return MujocoModelHelper(model=model, data=
|
76
|
+
return MujocoModelHelper(model=model, data=data)
|
77
77
|
|
78
78
|
def time(self) -> float:
|
79
79
|
"""Return the simulation time."""
|
jaxsim/mujoco/visualizer.py
CHANGED
@@ -1,11 +1,13 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
1
3
|
import abc
|
2
4
|
import dataclasses
|
3
|
-
from typing import List
|
4
5
|
|
5
6
|
import jax.numpy as jnp
|
6
7
|
import numpy as np
|
7
8
|
import numpy.typing as npt
|
8
9
|
|
10
|
+
import jaxsim.typing as jtp
|
9
11
|
from jaxsim import logging
|
10
12
|
|
11
13
|
from .link import LinkDescription
|
@@ -17,9 +19,9 @@ class CollidablePoint:
|
|
17
19
|
Represents a collidable point associated with a parent link.
|
18
20
|
|
19
21
|
Attributes:
|
20
|
-
parent_link
|
21
|
-
position
|
22
|
-
enabled
|
22
|
+
parent_link: The parent link to which the collidable point is attached.
|
23
|
+
position: The position of the collidable point relative to the parent link.
|
24
|
+
enabled: A flag indicating whether the collidable point is enabled for collision detection.
|
23
25
|
|
24
26
|
"""
|
25
27
|
|
@@ -29,7 +31,7 @@ class CollidablePoint:
|
|
29
31
|
|
30
32
|
def change_link(
|
31
33
|
self, new_link: LinkDescription, new_H_old: npt.NDArray
|
32
|
-
) ->
|
34
|
+
) -> CollidablePoint:
|
33
35
|
"""
|
34
36
|
Move the collidable point to a new parent link.
|
35
37
|
|
@@ -39,8 +41,8 @@ class CollidablePoint:
|
|
39
41
|
|
40
42
|
Returns:
|
41
43
|
CollidablePoint: A new collidable point associated with the new parent link.
|
42
|
-
|
43
44
|
"""
|
45
|
+
|
44
46
|
msg = f"Moving collidable point: {self.parent_link.name} -> {new_link.name}"
|
45
47
|
logging.debug(msg=msg)
|
46
48
|
|
@@ -50,15 +52,24 @@ class CollidablePoint:
|
|
50
52
|
enabled=self.enabled,
|
51
53
|
)
|
52
54
|
|
53
|
-
def
|
54
|
-
|
55
|
-
|
56
|
-
|
57
|
-
|
55
|
+
def __hash__(self) -> int:
|
56
|
+
|
57
|
+
return hash(
|
58
|
+
(
|
59
|
+
hash(self.parent_link),
|
60
|
+
hash(tuple(self.position.tolist())),
|
61
|
+
hash(self.enabled),
|
62
|
+
)
|
58
63
|
)
|
59
|
-
return retval
|
60
64
|
|
61
|
-
def
|
65
|
+
def __eq__(self, other: CollidablePoint) -> bool:
|
66
|
+
|
67
|
+
if not isinstance(other, CollidablePoint):
|
68
|
+
return False
|
69
|
+
|
70
|
+
return hash(self) == hash(other)
|
71
|
+
|
72
|
+
def __str__(self) -> str:
|
62
73
|
return (
|
63
74
|
f"{self.__class__.__name__}("
|
64
75
|
+ f"parent_link={self.parent_link.name}"
|
@@ -74,11 +85,11 @@ class CollisionShape(abc.ABC):
|
|
74
85
|
Abstract base class for representing collision shapes.
|
75
86
|
|
76
87
|
Attributes:
|
77
|
-
collidable_points
|
88
|
+
collidable_points: A list of collidable points associated with the collision shape.
|
78
89
|
|
79
90
|
"""
|
80
91
|
|
81
|
-
collidable_points:
|
92
|
+
collidable_points: tuple[CollidablePoint]
|
82
93
|
|
83
94
|
def __str__(self):
|
84
95
|
return (
|
@@ -95,14 +106,26 @@ class BoxCollision(CollisionShape):
|
|
95
106
|
Represents a box-shaped collision shape.
|
96
107
|
|
97
108
|
Attributes:
|
98
|
-
center
|
109
|
+
center: The center of the box in the local frame of the collision shape.
|
99
110
|
|
100
111
|
"""
|
101
112
|
|
102
|
-
center:
|
113
|
+
center: jtp.VectorLike
|
103
114
|
|
104
|
-
def
|
105
|
-
return (
|
115
|
+
def __hash__(self) -> int:
|
116
|
+
return hash(
|
117
|
+
(
|
118
|
+
hash(super()),
|
119
|
+
hash(tuple(self.center.tolist())),
|
120
|
+
)
|
121
|
+
)
|
122
|
+
|
123
|
+
def __eq__(self, other: BoxCollision) -> bool:
|
124
|
+
|
125
|
+
if not isinstance(other, BoxCollision):
|
126
|
+
return False
|
127
|
+
|
128
|
+
return hash(self) == hash(other)
|
106
129
|
|
107
130
|
|
108
131
|
@dataclasses.dataclass
|
@@ -111,11 +134,23 @@ class SphereCollision(CollisionShape):
|
|
111
134
|
Represents a spherical collision shape.
|
112
135
|
|
113
136
|
Attributes:
|
114
|
-
center
|
137
|
+
center: The center of the sphere in the local frame of the collision shape.
|
115
138
|
|
116
139
|
"""
|
117
140
|
|
118
|
-
center:
|
141
|
+
center: jtp.VectorLike
|
142
|
+
|
143
|
+
def __hash__(self) -> int:
|
144
|
+
return hash(
|
145
|
+
(
|
146
|
+
hash(super()),
|
147
|
+
hash(tuple(self.center.tolist())),
|
148
|
+
)
|
149
|
+
)
|
150
|
+
|
151
|
+
def __eq__(self, other: BoxCollision) -> bool:
|
152
|
+
|
153
|
+
if not isinstance(other, BoxCollision):
|
154
|
+
return False
|
119
155
|
|
120
|
-
|
121
|
-
return (self.center == other.center).all() and super().__eq__(other)
|
156
|
+
return hash(self) == hash(other)
|
@@ -1,11 +1,10 @@
|
|
1
1
|
from __future__ import annotations
|
2
2
|
|
3
3
|
import dataclasses
|
4
|
-
from typing import ClassVar
|
4
|
+
from typing import ClassVar
|
5
5
|
|
6
6
|
import jax_dataclasses
|
7
7
|
import numpy as np
|
8
|
-
import numpy.typing as npt
|
9
8
|
|
10
9
|
import jaxsim.typing as jtp
|
11
10
|
from jaxsim.utils import JaxsimDataclass, Mutability
|
@@ -15,6 +14,7 @@ from .link import LinkDescription
|
|
15
14
|
|
16
15
|
@dataclasses.dataclass(frozen=True)
|
17
16
|
class JointType:
|
17
|
+
|
18
18
|
Fixed: ClassVar[int] = 0
|
19
19
|
Revolute: ClassVar[int] = 1
|
20
20
|
Prismatic: ClassVar[int] = 2
|
@@ -64,29 +64,31 @@ class JointDescription(JaxsimDataclass):
|
|
64
64
|
"""
|
65
65
|
|
66
66
|
name: jax_dataclasses.Static[str]
|
67
|
-
axis:
|
68
|
-
pose:
|
69
|
-
jtype: jax_dataclasses.Static[
|
67
|
+
axis: jtp.Vector
|
68
|
+
pose: jtp.Matrix
|
69
|
+
jtype: jax_dataclasses.Static[jtp.IntLike]
|
70
70
|
child: LinkDescription = dataclasses.dataclass(repr=False)
|
71
71
|
parent: LinkDescription = dataclasses.dataclass(repr=False)
|
72
72
|
|
73
|
-
index:
|
73
|
+
index: jtp.IntLike | None = None
|
74
|
+
|
75
|
+
friction_static: jtp.FloatLike = 0.0
|
76
|
+
friction_viscous: jtp.FloatLike = 0.0
|
74
77
|
|
75
|
-
|
76
|
-
|
78
|
+
position_limit_damper: jtp.FloatLike = 0.0
|
79
|
+
position_limit_spring: jtp.FloatLike = 0.0
|
77
80
|
|
78
|
-
|
79
|
-
|
81
|
+
position_limit: tuple[jtp.FloatLike, jtp.FloatLike] = (0.0, 0.0)
|
82
|
+
initial_position: jtp.FloatLike | jtp.VectorLike = 0.0
|
80
83
|
|
81
|
-
|
82
|
-
|
84
|
+
motor_inertia: jtp.FloatLike = 0.0
|
85
|
+
motor_viscous_friction: jtp.FloatLike = 0.0
|
86
|
+
motor_gear_ratio: jtp.FloatLike = 1.0
|
83
87
|
|
84
|
-
|
85
|
-
motor_viscous_friction: float = 0.0
|
86
|
-
motor_gear_ratio: float = 1.0
|
88
|
+
def __post_init__(self) -> None:
|
87
89
|
|
88
|
-
def __post_init__(self):
|
89
90
|
if self.axis is not None:
|
91
|
+
|
90
92
|
with self.mutable_context(
|
91
93
|
mutability=Mutability.MUTABLE, restore_after_exception=False
|
92
94
|
):
|
@@ -94,4 +96,24 @@ class JointDescription(JaxsimDataclass):
|
|
94
96
|
self.axis = self.axis / norm_of_axis
|
95
97
|
|
96
98
|
def __hash__(self) -> int:
|
97
|
-
|
99
|
+
|
100
|
+
return hash(
|
101
|
+
(
|
102
|
+
hash(self.name),
|
103
|
+
hash(tuple(self.axis.tolist())),
|
104
|
+
hash(tuple(self.pose.flatten().tolist())),
|
105
|
+
hash(int(self.jtype)),
|
106
|
+
hash(self.child),
|
107
|
+
hash(self.parent),
|
108
|
+
hash(int(self.index)) if self.index is not None else 0,
|
109
|
+
hash(float(self.friction_static)),
|
110
|
+
hash(float(self.friction_viscous)),
|
111
|
+
hash(float(self.position_limit_damper)),
|
112
|
+
hash(float(self.position_limit_spring)),
|
113
|
+
hash((float(el) for el in self.position_limit)),
|
114
|
+
hash(tuple(np.atleast_1d(self.initial_position).tolist())),
|
115
|
+
hash(float(self.motor_inertia)),
|
116
|
+
hash(float(self.motor_viscous_friction)),
|
117
|
+
hash(float(self.motor_gear_ratio)),
|
118
|
+
),
|
119
|
+
)
|
@@ -5,6 +5,7 @@ import dataclasses
|
|
5
5
|
import jax.numpy as jnp
|
6
6
|
import jax_dataclasses
|
7
7
|
import jaxlie
|
8
|
+
import numpy as np
|
8
9
|
from jax_dataclasses import Static
|
9
10
|
|
10
11
|
import jaxsim.typing as jtp
|
@@ -23,7 +24,7 @@ class LinkDescription(JaxsimDataclass):
|
|
23
24
|
index: An optional index for the link (it gets automatically assigned).
|
24
25
|
parent: The parent link of this link.
|
25
26
|
pose: The pose transformation matrix of the link.
|
26
|
-
children:
|
27
|
+
children: The children links.
|
27
28
|
"""
|
28
29
|
|
29
30
|
name: Static[str]
|
@@ -33,7 +34,7 @@ class LinkDescription(JaxsimDataclass):
|
|
33
34
|
parent: LinkDescription = dataclasses.field(default=None, repr=False)
|
34
35
|
pose: jtp.Matrix = dataclasses.field(default_factory=lambda: jnp.eye(4), repr=False)
|
35
36
|
|
36
|
-
children: Static[
|
37
|
+
children: Static[tuple[LinkDescription]] = dataclasses.field(
|
37
38
|
default_factory=list, repr=False
|
38
39
|
)
|
39
40
|
|
@@ -43,10 +44,12 @@ class LinkDescription(JaxsimDataclass):
|
|
43
44
|
(
|
44
45
|
hash(self.name),
|
45
46
|
hash(float(self.mass)),
|
46
|
-
hash(tuple(self.inertia.flatten().tolist())),
|
47
|
-
hash(int(self.index)),
|
48
|
-
hash(self.
|
49
|
-
hash(tuple(
|
47
|
+
hash(tuple(np.atleast_1d(self.inertia).flatten().tolist())),
|
48
|
+
hash(int(self.index)) if self.index is not None else 0,
|
49
|
+
hash(tuple(np.atleast_1d(self.pose).flatten().tolist())),
|
50
|
+
hash(tuple(self.children)),
|
51
|
+
# Here only using the name to prevent circular recursion:
|
52
|
+
hash(self.parent.name) if self.parent is not None else 0,
|
50
53
|
)
|
51
54
|
)
|
52
55
|
|
@@ -27,7 +27,7 @@ class ModelDescription(KinematicGraph):
|
|
27
27
|
|
28
28
|
fixed_base: bool = True
|
29
29
|
|
30
|
-
collision_shapes:
|
30
|
+
collision_shapes: tuple[CollisionShape, ...] = dataclasses.field(
|
31
31
|
default_factory=list, repr=False, hash=False
|
32
32
|
)
|
33
33
|
|
@@ -37,7 +37,7 @@ class ModelDescription(KinematicGraph):
|
|
37
37
|
links: list[LinkDescription],
|
38
38
|
joints: list[JointDescription],
|
39
39
|
frames: list[LinkDescription] | None = None,
|
40
|
-
collisions:
|
40
|
+
collisions: tuple[CollisionShape, ...] = (),
|
41
41
|
fixed_base: bool = False,
|
42
42
|
base_link_name: str | None = None,
|
43
43
|
considered_joints: Sequence[str] | None = None,
|
@@ -87,7 +87,7 @@ class ModelDescription(KinematicGraph):
|
|
87
87
|
for collision_shape in collisions:
|
88
88
|
|
89
89
|
# Get all the collidable points of the shape
|
90
|
-
coll_points =
|
90
|
+
coll_points = tuple(collision_shape.collidable_points)
|
91
91
|
|
92
92
|
# Assume they have an unique parent link
|
93
93
|
if not len(set({cp.parent_link.name for cp in coll_points})) == 1:
|
@@ -111,7 +111,7 @@ class ModelDescription(KinematicGraph):
|
|
111
111
|
continue
|
112
112
|
|
113
113
|
# Create a new collision shape
|
114
|
-
new_collision_shape = CollisionShape(collidable_points=
|
114
|
+
new_collision_shape = CollisionShape(collidable_points=())
|
115
115
|
final_collisions.append(new_collision_shape)
|
116
116
|
|
117
117
|
# If the frame was found, update the collidable points' pose and add them
|
@@ -133,19 +133,19 @@ class ModelDescription(KinematicGraph):
|
|
133
133
|
),
|
134
134
|
)
|
135
135
|
|
136
|
-
# Store the updated collision
|
137
|
-
new_collision_shape.collidable_points
|
136
|
+
# Store the updated collision.
|
137
|
+
new_collision_shape.collidable_points += (moved_cp,)
|
138
138
|
|
139
139
|
# Build the model
|
140
140
|
model = ModelDescription(
|
141
141
|
name=name,
|
142
142
|
root_pose=kinematic_graph.root_pose,
|
143
143
|
fixed_base=fixed_base,
|
144
|
-
collision_shapes=final_collisions,
|
144
|
+
collision_shapes=tuple(final_collisions),
|
145
145
|
root=kinematic_graph.root,
|
146
146
|
joints=kinematic_graph.joints,
|
147
147
|
frames=kinematic_graph.frames,
|
148
|
-
_joints_removed=kinematic_graph.
|
148
|
+
_joints_removed=kinematic_graph.joints_removed,
|
149
149
|
)
|
150
150
|
|
151
151
|
# Check that the root link of kinematic graph is the desired base link.
|
@@ -174,7 +174,7 @@ class ModelDescription(KinematicGraph):
|
|
174
174
|
links=list(self.links_dict.values()),
|
175
175
|
joints=self.joints,
|
176
176
|
frames=self.frames,
|
177
|
-
collisions=self.collision_shapes,
|
177
|
+
collisions=tuple(self.collision_shapes),
|
178
178
|
fixed_base=self.fixed_base,
|
179
179
|
base_link_name=list(iter(self))[0].name,
|
180
180
|
model_pose=self.root_pose,
|
@@ -182,8 +182,8 @@ class ModelDescription(KinematicGraph):
|
|
182
182
|
)
|
183
183
|
|
184
184
|
# Include the unconnected/removed joints from the original model.
|
185
|
-
for joint in self.
|
186
|
-
reduced_model_description.
|
185
|
+
for joint in self.joints_removed:
|
186
|
+
reduced_model_description.joints_removed.append(joint)
|
187
187
|
|
188
188
|
return reduced_model_description
|
189
189
|
|
@@ -243,3 +243,23 @@ class ModelDescription(KinematicGraph):
|
|
243
243
|
|
244
244
|
# Return enabled collidable points
|
245
245
|
return [cp for cp in all_collidable_points if cp.enabled]
|
246
|
+
|
247
|
+
def __eq__(self, other: ModelDescription) -> bool:
|
248
|
+
|
249
|
+
if not isinstance(other, ModelDescription):
|
250
|
+
return False
|
251
|
+
|
252
|
+
return hash(self) == hash(other)
|
253
|
+
|
254
|
+
def __hash__(self) -> int:
|
255
|
+
|
256
|
+
return hash(
|
257
|
+
(
|
258
|
+
hash(self.name),
|
259
|
+
hash(self.fixed_base),
|
260
|
+
hash(self.root),
|
261
|
+
hash(tuple(self.joints)),
|
262
|
+
hash(tuple(self.frames)),
|
263
|
+
hash(self.root_pose),
|
264
|
+
)
|
265
|
+
)
|
@@ -31,14 +31,21 @@ class RootPose(NamedTuple):
|
|
31
31
|
root_position: npt.NDArray = np.zeros(3)
|
32
32
|
root_quaternion: npt.NDArray = np.array([1.0, 0, 0, 0])
|
33
33
|
|
34
|
+
def __hash__(self) -> int:
|
35
|
+
|
36
|
+
return hash(
|
37
|
+
(
|
38
|
+
hash(tuple(self.root_position.tolist())),
|
39
|
+
hash(tuple(self.root_quaternion.tolist())),
|
40
|
+
)
|
41
|
+
)
|
42
|
+
|
34
43
|
def __eq__(self, other: RootPose) -> bool:
|
35
44
|
|
36
45
|
if not isinstance(other, RootPose):
|
37
46
|
return False
|
38
47
|
|
39
|
-
return
|
40
|
-
self.root_quaternion, other.root_quaternion
|
41
|
-
)
|
48
|
+
return hash(self) == hash(other)
|
42
49
|
|
43
50
|
|
44
51
|
@dataclasses.dataclass(frozen=True)
|
@@ -54,22 +61,24 @@ class KinematicGraph(Sequence[descriptions.LinkDescription]):
|
|
54
61
|
"""
|
55
62
|
|
56
63
|
root: descriptions.LinkDescription
|
57
|
-
frames: list[descriptions.LinkDescription] = dataclasses.field(
|
64
|
+
frames: list[descriptions.LinkDescription] = dataclasses.field(
|
65
|
+
default_factory=list, hash=False, compare=False
|
66
|
+
)
|
58
67
|
joints: list[descriptions.JointDescription] = dataclasses.field(
|
59
|
-
default_factory=list
|
68
|
+
default_factory=list, hash=False, compare=False
|
60
69
|
)
|
61
70
|
|
62
71
|
root_pose: RootPose = dataclasses.field(default_factory=lambda: RootPose())
|
63
72
|
|
64
73
|
# Private attribute storing optional additional info.
|
65
74
|
_extra_info: dict[str, Any] = dataclasses.field(
|
66
|
-
repr=False,
|
75
|
+
default_factory=dict, repr=False, hash=False, compare=False
|
67
76
|
)
|
68
77
|
|
69
78
|
# Private attribute storing the unconnected joints from the parsed model and
|
70
79
|
# the joints removed after model reduction.
|
71
80
|
_joints_removed: list[descriptions.JointDescription] = dataclasses.field(
|
72
|
-
default_factory=list, repr=False, compare=False
|
81
|
+
default_factory=list, repr=False, hash=False, compare=False
|
73
82
|
)
|
74
83
|
|
75
84
|
@functools.cached_property
|
@@ -98,14 +107,17 @@ class KinematicGraph(Sequence[descriptions.LinkDescription]):
|
|
98
107
|
for index, link in enumerate(self):
|
99
108
|
link.mutable(validate=False).index = index
|
100
109
|
|
101
|
-
# Get the names of the links and
|
110
|
+
# Get the names of the links, frames, and joints.
|
102
111
|
link_names = [l.name for l in self]
|
103
112
|
frame_names = [f.name for f in self.frames]
|
113
|
+
joint_names = [j.name for j in self.joints]
|
104
114
|
|
105
115
|
# Make sure that they are unique.
|
106
116
|
assert len(link_names) == len(set(link_names))
|
107
117
|
assert len(frame_names) == len(set(frame_names))
|
118
|
+
assert len(joint_names) == len(set(joint_names))
|
108
119
|
assert set(link_names).isdisjoint(set(frame_names))
|
120
|
+
assert set(link_names).isdisjoint(set(joint_names))
|
109
121
|
|
110
122
|
# Order frames with their name.
|
111
123
|
super().__setattr__("frames", sorted(self.frames, key=lambda f: f.name))
|
@@ -251,7 +263,7 @@ class KinematicGraph(Sequence[descriptions.LinkDescription]):
|
|
251
263
|
|
252
264
|
# Reset the connections of the root link.
|
253
265
|
for link in links_dict.values():
|
254
|
-
link.children =
|
266
|
+
link.children = tuple()
|
255
267
|
|
256
268
|
# Couple links and joints creating the kinematic graph.
|
257
269
|
for joint in joints:
|
@@ -268,7 +280,8 @@ class KinematicGraph(Sequence[descriptions.LinkDescription]):
|
|
268
280
|
|
269
281
|
# Assign link's children and make sure they are unique.
|
270
282
|
if child_link.name not in {l.name for l in parent_link.children}:
|
271
|
-
parent_link.
|
283
|
+
with parent_link.mutable_context(Mutability.MUTABLE_NO_VALIDATION):
|
284
|
+
parent_link.children = parent_link.children + (child_link,)
|
272
285
|
|
273
286
|
# Collect all the links of the kinematic graph.
|
274
287
|
all_links_in_graph = list(
|
@@ -315,7 +328,7 @@ class KinematicGraph(Sequence[descriptions.LinkDescription]):
|
|
315
328
|
# Update the unconnected links by removing their children. The other properties
|
316
329
|
# are left untouched, it's caller responsibility to post-process them if needed.
|
317
330
|
for link in unconnected_links:
|
318
|
-
link.children =
|
331
|
+
link.children = tuple()
|
319
332
|
msg = "Link '{}' won't be part of the kinematic graph because unconnected"
|
320
333
|
logging.debug(msg=msg.format(link.name))
|
321
334
|
|
@@ -615,6 +628,17 @@ class KinematicGraph(Sequence[descriptions.LinkDescription]):
|
|
615
628
|
horizontal=True,
|
616
629
|
)
|
617
630
|
|
631
|
+
@property
|
632
|
+
def joints_removed(self) -> list[descriptions.JointDescription]:
|
633
|
+
"""
|
634
|
+
Get the list of joints removed during the graph reduction.
|
635
|
+
|
636
|
+
Returns:
|
637
|
+
The list of removed joints.
|
638
|
+
"""
|
639
|
+
|
640
|
+
return self._joints_removed
|
641
|
+
|
618
642
|
@staticmethod
|
619
643
|
def breadth_first_search(
|
620
644
|
root: descriptions.LinkDescription,
|
@@ -785,6 +809,7 @@ class KinematicGraphTransforms:
|
|
785
809
|
|
786
810
|
# Get the joint.
|
787
811
|
joint = self.graph.joints_dict[name]
|
812
|
+
assert joint.name == name
|
788
813
|
|
789
814
|
# Get the transform of the parent link.
|
790
815
|
M_H_L = self.transform(name=joint.parent.name)
|
jaxsim/parsers/rod/parser.py
CHANGED
jaxsim/parsers/rod/utils.py
CHANGED
@@ -6,7 +6,7 @@ import numpy.typing as npt
|
|
6
6
|
import rod
|
7
7
|
|
8
8
|
import jaxsim.typing as jtp
|
9
|
-
from jaxsim.math
|
9
|
+
from jaxsim.math import Inertia
|
10
10
|
from jaxsim.parsers import descriptions
|
11
11
|
|
12
12
|
|
@@ -59,9 +59,7 @@ def from_sdf_inertial(inertial: rod.Inertial) -> jtp.Matrix:
|
|
59
59
|
return M_L.astype(dtype=float)
|
60
60
|
|
61
61
|
|
62
|
-
def joint_to_joint_type(
|
63
|
-
joint: rod.Joint,
|
64
|
-
) -> descriptions.JointType:
|
62
|
+
def joint_to_joint_type(joint: rod.Joint) -> int:
|
65
63
|
"""
|
66
64
|
Extract the joint type from an SDF joint.
|
67
65
|
|
@@ -69,7 +67,7 @@ def joint_to_joint_type(
|
|
69
67
|
joint: The parsed SDF joint.
|
70
68
|
|
71
69
|
Returns:
|
72
|
-
The corresponding joint type
|
70
|
+
The integer corresponding to the joint type.
|
73
71
|
"""
|
74
72
|
|
75
73
|
axis = joint.axis
|
@@ -138,7 +136,7 @@ def create_box_collision(
|
|
138
136
|
collidable_points = [
|
139
137
|
descriptions.CollidablePoint(
|
140
138
|
parent_link=link_description,
|
141
|
-
position=corner,
|
139
|
+
position=np.array(corner),
|
142
140
|
enabled=True,
|
143
141
|
)
|
144
142
|
for corner in box_corners_wrt_link.T
|
@@ -197,7 +195,7 @@ def create_sphere_collision(
|
|
197
195
|
collidable_points = [
|
198
196
|
descriptions.CollidablePoint(
|
199
197
|
parent_link=link_description,
|
200
|
-
position=point,
|
198
|
+
position=np.array(point),
|
201
199
|
enabled=True,
|
202
200
|
)
|
203
201
|
for point in sphere_points_wrt_link.T
|
jaxsim/rbda/crba.py
CHANGED
@@ -111,7 +111,7 @@ def crba(model: js.model.JaxSimModel, *, joint_positions: jtp.Vector) -> jtp.Mat
|
|
111
111
|
# a while loop using a for loop with fixed number of iterations.
|
112
112
|
def inner_fn(carry: CarryInnerFn, k: jtp.Int) -> tuple[CarryInnerFn, None]:
|
113
113
|
def compute_inner(carry: CarryInnerFn) -> tuple[CarryInnerFn, None]:
|
114
|
-
j,
|
114
|
+
j, _, _ = carry
|
115
115
|
out = jax.lax.cond(
|
116
116
|
pred=(λ[j] > 0),
|
117
117
|
true_fun=while_loop_body,
|
@@ -120,7 +120,7 @@ def crba(model: js.model.JaxSimModel, *, joint_positions: jtp.Vector) -> jtp.Mat
|
|
120
120
|
)
|
121
121
|
return out, None
|
122
122
|
|
123
|
-
j,
|
123
|
+
j, _, _ = carry
|
124
124
|
return jax.lax.cond(
|
125
125
|
pred=(k == j),
|
126
126
|
true_fun=compute_inner,
|
@@ -49,7 +49,7 @@ def forward_kinematics_model(
|
|
49
49
|
# Compute the parent-to-child adjoints and the motion subspaces of the joints.
|
50
50
|
# These transforms define the relative kinematics of the entire model, including
|
51
51
|
# the base transform for both floating-base and fixed-base models.
|
52
|
-
i_X_λi,
|
52
|
+
i_X_λi, _ = model.kin_dyn_parameters.joint_transforms_and_motion_subspaces(
|
53
53
|
joint_positions=s, base_transform=W_H_B.as_matrix()
|
54
54
|
)
|
55
55
|
|
jaxsim/typing.py
CHANGED
@@ -1,4 +1,5 @@
|
|
1
|
-
from
|
1
|
+
from collections.abc import Hashable
|
2
|
+
from typing import Any
|
2
3
|
|
3
4
|
import jax
|
4
5
|
|
@@ -24,6 +25,7 @@ PyTree = (
|
|
24
25
|
# =======================
|
25
26
|
|
26
27
|
Array = jax.typing.ArrayLike
|
28
|
+
Scalar = Array
|
27
29
|
Vector = Array
|
28
30
|
Matrix = Array
|
29
31
|
|
@@ -31,6 +33,7 @@ Int = int | IntJax
|
|
31
33
|
Bool = bool | ArrayJax
|
32
34
|
Float = float | FloatJax
|
33
35
|
|
36
|
+
ScalarLike = Scalar | int | float
|
34
37
|
ArrayLike = Array
|
35
38
|
VectorLike = Vector
|
36
39
|
MatrixLike = Matrix
|
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.1
|
2
2
|
Name: jaxsim
|
3
|
-
Version: 0.2.1.
|
3
|
+
Version: 0.2.1.dev155
|
4
4
|
Home-page: https://github.com/ami-iit/jaxsim
|
5
5
|
Author: Diego Ferigo
|
6
6
|
Author-email: diego.ferigo@iit.it
|
@@ -68,7 +68,7 @@ Requires-Dist: mujoco >=3.0.0 ; extra == 'viz'
|
|
68
68
|
|
69
69
|
JaxSim is a **differentiable physics engine** and **multibody dynamics library** designed for applications in control and robot learning, implemented with JAX.
|
70
70
|
|
71
|
-
Its design facilitates research and accelerates prototyping in the intersection of robotics and artificial intelligence.
|
71
|
+
Its design facilitates research and accelerates prototyping in the intersection of robotics and artificial intelligence.
|
72
72
|
|
73
73
|
## Features
|
74
74
|
|
@@ -91,7 +91,7 @@ Its design facilitates research and accelerates prototyping in the intersection
|
|
91
91
|
|
92
92
|
### JaxSim as a multibody dynamics library
|
93
93
|
|
94
|
-
- Provides rigid body dynamics algorithms (RBDAs) like RNEA, ABA, CRBA, and Jacobians.
|
94
|
+
- Provides rigid body dynamics algorithms (RBDAs) like RNEA, ABA, CRBA, and Jacobians.
|
95
95
|
- Provides all the quantities included in the Euler-Poincarè formulation of the equations of motion.
|
96
96
|
- Supports body-fixed, inertial-fixed, and mixed [velocity representations][notation].
|
97
97
|
- Exposes all the necessary quantities to develop controllers in centroidal coordinates.
|
@@ -198,10 +198,10 @@ The main differences between MJX/Brax and JaxSim are as follows:
|
|
198
198
|
|
199
199
|
- JaxSim supports out-of-the-box all SDF models with [Pose Frame Semantics][PFS].
|
200
200
|
- JaxSim only supports collisions between points rigidly attached to bodies and a compliant ground surface.
|
201
|
-
Our contact model requires careful tuning of its spring-damper parameters, but being an instantaneous
|
201
|
+
Our contact model requires careful tuning of its spring-damper parameters, but being an instantaneous
|
202
202
|
function of the state $(\mathbf{q}, \boldsymbol{\nu})$, it doesn't require running any optimization algorithm
|
203
203
|
when stepping the simulation forward.
|
204
|
-
- JaxSim mitigates the stiffness of the contact-aware system dynamics by providing variable-step integrators.
|
204
|
+
- JaxSim mitigates the stiffness of the contact-aware system dynamics by providing variable-step integrators.
|
205
205
|
|
206
206
|
[brax]: https://github.com/google/brax
|
207
207
|
[mjx]: https://mujoco.readthedocs.io/en/3.0.0/mjx.html
|
@@ -1,18 +1,18 @@
|
|
1
|
-
jaxsim/__init__.py,sha256=
|
2
|
-
jaxsim/_version.py,sha256=
|
1
|
+
jaxsim/__init__.py,sha256=xzuTuZrgKdWLqqDzbvqzm2cJrEtAbepOeUqDu7ByVek,2621
|
2
|
+
jaxsim/_version.py,sha256=VTwovSkbhBoA3d5Ku6UfV56QTou0p3bVvIXRd63p0yQ,428
|
3
3
|
jaxsim/logging.py,sha256=c4zhwBKf9eAYAHVp62kTEllqdsZgh0K-kPKVy8L3elU,1584
|
4
|
-
jaxsim/typing.py,sha256=
|
4
|
+
jaxsim/typing.py,sha256=cl7HHQCeP3mHmtF6EuQZcCjGvDmc_AryMWntP_lRBGg,722
|
5
5
|
jaxsim/api/__init__.py,sha256=8eV22t2S3UwNyCg8karPetG1dmX1VDBXkyv28_FwNQA,210
|
6
6
|
jaxsim/api/com.py,sha256=Yof6otFi-mLWAs1rqjmeNJTOWIH9gn7BdU5EIjiL6Ts,13481
|
7
7
|
jaxsim/api/common.py,sha256=bqQ__pIQZbh-j8rkoHUkYHAgGiJnDzjHG-q4Ny0OOYQ,6646
|
8
|
-
jaxsim/api/contact.py,sha256=
|
8
|
+
jaxsim/api/contact.py,sha256=79kcdq7C1_kWgxd1QWBabBhIPkwWEVLk-Fiz9kh-4so,12800
|
9
9
|
jaxsim/api/data.py,sha256=xfKJz6Rw0YTk-EHCGiT8BFQrs_ggOz01lRi1Qh1mb28,27256
|
10
10
|
jaxsim/api/frame.py,sha256=0YXOrGmx3cSQqa4_Ky-n6zyup3I3xvXNEgub-Bc5xUw,6222
|
11
11
|
jaxsim/api/joint.py,sha256=-5DogPg4g4mmLckyVIVNjwv-Rxz0IWS7_md9nDlhPWA,4581
|
12
|
-
jaxsim/api/kin_dyn_parameters.py,sha256=
|
13
|
-
jaxsim/api/link.py,sha256=
|
14
|
-
jaxsim/api/model.py,sha256=
|
15
|
-
jaxsim/api/ode.py,sha256=
|
12
|
+
jaxsim/api/kin_dyn_parameters.py,sha256=b1e96I8hKU5fh4StLObdVcDpr_6ZglrgD3SRyrqTu18,26203
|
13
|
+
jaxsim/api/link.py,sha256=MdMWaMpM5Dj5JHK8uwHZ4zR4Fjq3R4asi2sGTxk1OAs,16647
|
14
|
+
jaxsim/api/model.py,sha256=sCx9CcP23A1I_ae4UqTq4Fpq5u0aDki72CqgnR1H50w,59465
|
15
|
+
jaxsim/api/ode.py,sha256=luTQJsIXUtCp_81dR42X7WrMvwrXtYbyJiqss29v7zA,10786
|
16
16
|
jaxsim/api/ode_data.py,sha256=D6FzMkvY_qNuoFEImyp7sxAk-0pJOd3oZeSr9bBTcLk,23089
|
17
17
|
jaxsim/api/references.py,sha256=UA6kSQVBoq-bXSo99EOELf-_MD5MTy2zS0GtG3wQ410,16618
|
18
18
|
jaxsim/integrators/__init__.py,sha256=hxvOD-VK_mmd6v31wtC-nb28AYve1gLuZCNLV9wS-Kg,103
|
@@ -23,31 +23,31 @@ jaxsim/math/__init__.py,sha256=inJ9nRFkqstuGa8OyFkfWVudo5U9Ug4WgDBuKva8AIA,337
|
|
23
23
|
jaxsim/math/adjoint.py,sha256=DT21izjVW497GRrgNfx8tv0ZeWW5QncWMGMhI0acUNw,4425
|
24
24
|
jaxsim/math/cross.py,sha256=U7yEx_l75mSy5g6O-jsjBztApvxC3WaV4MpkS5tThu4,1330
|
25
25
|
jaxsim/math/inertia.py,sha256=UAB7ym4gXFanejcs_ovZMpteHCc6poWYmt-mLmd5hhk,1640
|
26
|
-
jaxsim/math/joint_model.py,sha256=
|
26
|
+
jaxsim/math/joint_model.py,sha256=xJSocGOyLzLJIQo4j5rBfMCPD4ltUQ2jCZfN747i2Ck,9989
|
27
27
|
jaxsim/math/quaternion.py,sha256=X9b8jHf0QemKUjIZSnXRJc3DdMr42CBhBy_mi9_X_AM,5068
|
28
28
|
jaxsim/math/rotation.py,sha256=Z90daUjGpuNEVLfWB3SVtM9EtwAIaneVj9A9UpWXqhA,2182
|
29
29
|
jaxsim/math/skew.py,sha256=oOGSSR8PUGROl6IJFlrmu6K3gPH-u16hUPfKIkcVv9o,1177
|
30
30
|
jaxsim/math/transform.py,sha256=ZKzoXwKRowE17p3P0EgKrq8mocdMTGUKNyaKt6-bAX0,2941
|
31
31
|
jaxsim/mujoco/__init__.py,sha256=Zo5GAlN1DYKvX8s1hu1j6HntKIbBMLB9Puv9ouaNAZ8,158
|
32
32
|
jaxsim/mujoco/__main__.py,sha256=GBmB7J-zj75ZnFyuAAmpSOpbxi_HhHhWJeot3ljGDJY,5291
|
33
|
-
jaxsim/mujoco/loaders.py,sha256=
|
34
|
-
jaxsim/mujoco/model.py,sha256=
|
35
|
-
jaxsim/mujoco/visualizer.py,sha256=
|
33
|
+
jaxsim/mujoco/loaders.py,sha256=7rjpeJ6_GuitlCty-ZkLhTILQ0GmsFzDMgve-7Gkkh4,20984
|
34
|
+
jaxsim/mujoco/model.py,sha256=1KVRjSLOTCuHt53apBPQTnFYJRknlVoKLQaxWsNK8qc,13494
|
35
|
+
jaxsim/mujoco/visualizer.py,sha256=PXgQzwetS9mRJYHBknDMLsQ9152FdrSvZuT9xE_dfIQ,5069
|
36
36
|
jaxsim/parsers/__init__.py,sha256=sonYi-bBWAoB04kp1mxT4uIORxjb7SdZ0ukGPmVx98Y,44
|
37
|
-
jaxsim/parsers/kinematic_graph.py,sha256=
|
37
|
+
jaxsim/parsers/kinematic_graph.py,sha256=WdIxntWfxXf67x90oM5KHHXFrSITMwVahqWgcOjYFzc,34730
|
38
38
|
jaxsim/parsers/descriptions/__init__.py,sha256=PbIlunVfb59pB5jSX97YVpMAANRZPRkJ0X-hS14rzv4,221
|
39
|
-
jaxsim/parsers/descriptions/collision.py,sha256=
|
40
|
-
jaxsim/parsers/descriptions/joint.py,sha256=
|
41
|
-
jaxsim/parsers/descriptions/link.py,sha256=
|
42
|
-
jaxsim/parsers/descriptions/model.py,sha256=
|
39
|
+
jaxsim/parsers/descriptions/collision.py,sha256=BQeIG-TKi4SVny23w6riDrQ5itC6VRwEMBX6HgAXHxA,3973
|
40
|
+
jaxsim/parsers/descriptions/joint.py,sha256=z_nYSS0fdkcaerjUlPX0U1Vn1ArBT0u_XdKjqxG3HcY,3959
|
41
|
+
jaxsim/parsers/descriptions/link.py,sha256=QvEE7J6iMQLibpLqlcBV428UA7NMpFFXJwe35GYnjAY,3124
|
42
|
+
jaxsim/parsers/descriptions/model.py,sha256=V9nSyCK3mo7680WYMDEx1MTfdDTJzbCGPqAp3qA2XRE,9511
|
43
43
|
jaxsim/parsers/rod/__init__.py,sha256=G2vqlLajBLUc4gyzXwsEI2Wsi4TMOIF9bLDFeT6KrGU,92
|
44
|
-
jaxsim/parsers/rod/parser.py,sha256=
|
45
|
-
jaxsim/parsers/rod/utils.py,sha256=
|
44
|
+
jaxsim/parsers/rod/parser.py,sha256=4COuhkAYv4-GIpCqvkXEJWpDEQczEkBM3KwpqX48Rek,13514
|
45
|
+
jaxsim/parsers/rod/utils.py,sha256=KSjgy6WsmTrD5HZEA2x8hOBSRU4bUGOOHzxKkeFO5r8,5721
|
46
46
|
jaxsim/rbda/__init__.py,sha256=MqEZwzu8SHPAlIFHmSXmCjehuOJGRX58OrBVAbBVMwg,374
|
47
47
|
jaxsim/rbda/aba.py,sha256=0OoCzHhf1v-qqr1y5PIrD7_mPwAlid0fjXxUrIa5E_s,9118
|
48
48
|
jaxsim/rbda/collidable_points.py,sha256=4ZNJbEj2nEi15jBLR-GNbdaqKgkN58FBgqd_TXupEgg,4948
|
49
|
-
jaxsim/rbda/crba.py,sha256=
|
50
|
-
jaxsim/rbda/forward_kinematics.py,sha256=
|
49
|
+
jaxsim/rbda/crba.py,sha256=awsWEQXLE0UPEXIcZCVsAqBEPjyahMNzY9ux6nE1l-s,4739
|
50
|
+
jaxsim/rbda/forward_kinematics.py,sha256=94W7TUXvZjMb-99CyYR8pObuxIYYX9B_dtRZqsNcThs,3418
|
51
51
|
jaxsim/rbda/jacobian.py,sha256=M79bGir-2w_iJ2GurYhOGgMfJnp7ZMOCW6AeeWKK8iM,10745
|
52
52
|
jaxsim/rbda/rnea.py,sha256=DjwkvXQVUSUclM3Uy3UPZ2tao91R5dGd4o7TsS2qObI,7650
|
53
53
|
jaxsim/rbda/soft_contacts.py,sha256=52zJOF31hFpqoaOednTvi8j_UxhRcdGNjzOPb2v2MPc,11257
|
@@ -58,8 +58,8 @@ jaxsim/utils/__init__.py,sha256=Y5zyoRevl3EMVQadhZ4EtSwTEkDt2vcnFoRhPJjKTZ0,215
|
|
58
58
|
jaxsim/utils/jaxsim_dataclass.py,sha256=h26timZ_XrBL_Q_oymv-DkQd-EcUiHn8QexAaZXBY9c,11396
|
59
59
|
jaxsim/utils/tracing.py,sha256=KDMoyVPlu2NJvFkhtZwq5AkqMMgajt3munvJom-vEjQ,650
|
60
60
|
jaxsim/utils/wrappers.py,sha256=EJMcblYKUjxw9HJShVf81Ig3pHUJno6Dx6h-RnY--wM,2040
|
61
|
-
jaxsim-0.2.1.
|
62
|
-
jaxsim-0.2.1.
|
63
|
-
jaxsim-0.2.1.
|
64
|
-
jaxsim-0.2.1.
|
65
|
-
jaxsim-0.2.1.
|
61
|
+
jaxsim-0.2.1.dev155.dist-info/LICENSE,sha256=eaYdFmdeMbiIoIiPzEK0MjP1S9wtFXjXNR5er49uLR0,1546
|
62
|
+
jaxsim-0.2.1.dev155.dist-info/METADATA,sha256=Jhyuk3qGnK7WzHYbuA4b5mCXofQw7WZAj6zqh01jKkU,9740
|
63
|
+
jaxsim-0.2.1.dev155.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
|
64
|
+
jaxsim-0.2.1.dev155.dist-info/top_level.txt,sha256=LxGMA8FLtXjQ6oI7N5gd_R_oSUHxpXxUEOfT1xS_ni0,7
|
65
|
+
jaxsim-0.2.1.dev155.dist-info/RECORD,,
|
File without changes
|
File without changes
|
File without changes
|