jaxsim 0.4.3.dev312__py3-none-any.whl → 0.4.3.dev350__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jaxsim/_version.py +2 -2
- jaxsim/api/contact.py +65 -28
- jaxsim/api/joint.py +8 -9
- jaxsim/api/kin_dyn_parameters.py +9 -4
- jaxsim/api/link.py +3 -4
- jaxsim/api/model.py +21 -22
- jaxsim/api/references.py +1 -1
- jaxsim/integrators/common.py +2 -2
- jaxsim/integrators/variable_step.py +6 -12
- jaxsim/mujoco/loaders.py +9 -138
- jaxsim/mujoco/utils.py +123 -1
- jaxsim/parsers/descriptions/joint.py +1 -26
- jaxsim/parsers/kinematic_graph.py +3 -3
- jaxsim/parsers/rod/parser.py +3 -6
- jaxsim/parsers/rod/utils.py +1 -1
- jaxsim/rbda/collidable_points.py +18 -5
- jaxsim/rbda/contacts/common.py +11 -9
- jaxsim/rbda/contacts/relaxed_rigid.py +14 -5
- jaxsim/rbda/contacts/rigid.py +9 -6
- jaxsim/rbda/contacts/soft.py +17 -4
- jaxsim/rbda/jacobian.py +2 -2
- jaxsim/rbda/utils.py +1 -1
- jaxsim/terrain/terrain.py +9 -1
- jaxsim/utils/tracing.py +3 -9
- jaxsim/utils/wrappers.py +1 -1
- {jaxsim-0.4.3.dev312.dist-info → jaxsim-0.4.3.dev350.dist-info}/METADATA +1 -1
- {jaxsim-0.4.3.dev312.dist-info → jaxsim-0.4.3.dev350.dist-info}/RECORD +30 -30
- {jaxsim-0.4.3.dev312.dist-info → jaxsim-0.4.3.dev350.dist-info}/LICENSE +0 -0
- {jaxsim-0.4.3.dev312.dist-info → jaxsim-0.4.3.dev350.dist-info}/WHEEL +0 -0
- {jaxsim-0.4.3.dev312.dist-info → jaxsim-0.4.3.dev350.dist-info}/top_level.txt +0 -0
jaxsim/_version.py
CHANGED
@@ -12,5 +12,5 @@ __version__: str
|
|
12
12
|
__version_tuple__: VERSION_TUPLE
|
13
13
|
version_tuple: VERSION_TUPLE
|
14
14
|
|
15
|
-
__version__ = version = '0.4.3.
|
16
|
-
__version_tuple__ = version_tuple = (0, 4, 3, '
|
15
|
+
__version__ = version = '0.4.3.dev350'
|
16
|
+
__version_tuple__ = version_tuple = (0, 4, 3, 'dev350')
|
jaxsim/api/contact.py
CHANGED
@@ -138,7 +138,7 @@ def collidable_point_dynamics(
|
|
138
138
|
**kwargs,
|
139
139
|
) -> tuple[jtp.Matrix, dict[str, jtp.PyTree]]:
|
140
140
|
r"""
|
141
|
-
Compute the 6D force applied to each collidable point.
|
141
|
+
Compute the 6D force applied to each enabled collidable point.
|
142
142
|
|
143
143
|
Args:
|
144
144
|
model: The model to consider.
|
@@ -151,7 +151,7 @@ def collidable_point_dynamics(
|
|
151
151
|
kwargs: Additional keyword arguments to pass to the active contact model.
|
152
152
|
|
153
153
|
Returns:
|
154
|
-
The 6D force applied to each collidable point and additional data based
|
154
|
+
The 6D force applied to each eneabled collidable point and additional data based
|
155
155
|
on the contact model configured:
|
156
156
|
- Soft: the material deformation rate.
|
157
157
|
- Rigid: no additional data.
|
@@ -199,15 +199,19 @@ def collidable_point_dynamics(
|
|
199
199
|
)
|
200
200
|
|
201
201
|
# Compute the transforms of the implicit frames `C[L] = (W_p_C, [L])`
|
202
|
-
# associated to
|
202
|
+
# associated to the enabled collidable point.
|
203
203
|
# In inertial-fixed representation, the computation of these transforms
|
204
204
|
# is not necessary and the conversion below becomes a no-op.
|
205
|
+
|
206
|
+
# Get the indices of the enabled collidable points.
|
207
|
+
indices_of_enabled_collidable_points = (
|
208
|
+
model.kin_dyn_parameters.contact_parameters.indices_of_enabled_collidable_points
|
209
|
+
)
|
210
|
+
|
205
211
|
W_H_C = (
|
206
212
|
js.contact.transforms(model=model, data=data)
|
207
213
|
if data.velocity_representation is not VelRepr.Inertial
|
208
|
-
else jnp.zeros(
|
209
|
-
shape=(len(model.kin_dyn_parameters.contact_parameters.body), 4, 4)
|
210
|
-
)
|
214
|
+
else jnp.zeros(shape=(len(indices_of_enabled_collidable_points), 4, 4))
|
211
215
|
)
|
212
216
|
|
213
217
|
# Convert the 6D forces to the active representation.
|
@@ -246,6 +250,15 @@ def in_contact(
|
|
246
250
|
if link_names is not None and set(link_names).difference(model.link_names()):
|
247
251
|
raise ValueError("One or more link names are not part of the model")
|
248
252
|
|
253
|
+
# Get the indices of the enabled collidable points.
|
254
|
+
indices_of_enabled_collidable_points = (
|
255
|
+
model.kin_dyn_parameters.contact_parameters.indices_of_enabled_collidable_points
|
256
|
+
)
|
257
|
+
|
258
|
+
parent_link_idx_of_enabled_collidable_points = jnp.array(
|
259
|
+
model.kin_dyn_parameters.contact_parameters.body, dtype=int
|
260
|
+
)[indices_of_enabled_collidable_points]
|
261
|
+
|
249
262
|
W_p_Ci = collidable_point_positions(model=model, data=data)
|
250
263
|
|
251
264
|
terrain_height = jax.vmap(lambda x, y: model.terrain.height(x=x, y=y))(
|
@@ -262,7 +275,7 @@ def in_contact(
|
|
262
275
|
|
263
276
|
links_in_contact = jax.vmap(
|
264
277
|
lambda link_index: jnp.where(
|
265
|
-
|
278
|
+
parent_link_idx_of_enabled_collidable_points == link_index,
|
266
279
|
below_terrain,
|
267
280
|
jnp.zeros_like(below_terrain, dtype=bool),
|
268
281
|
).any()
|
@@ -426,14 +439,14 @@ def estimate_good_contact_parameters(
|
|
426
439
|
@jax.jit
|
427
440
|
def transforms(model: js.model.JaxSimModel, data: js.data.JaxSimModelData) -> jtp.Array:
|
428
441
|
r"""
|
429
|
-
Return the pose of the collidable points.
|
442
|
+
Return the pose of the enabled collidable points.
|
430
443
|
|
431
444
|
Args:
|
432
445
|
model: The model to consider.
|
433
446
|
data: The data of the considered model.
|
434
447
|
|
435
448
|
Returns:
|
436
|
-
The stacked SE(3) matrices of all collidable points.
|
449
|
+
The stacked SE(3) matrices of all enabled collidable points.
|
437
450
|
|
438
451
|
Note:
|
439
452
|
Each collidable point is implicitly associated with a frame
|
@@ -442,16 +455,27 @@ def transforms(model: js.model.JaxSimModel, data: js.data.JaxSimModelData) -> jt
|
|
442
455
|
rigidly attached to.
|
443
456
|
"""
|
444
457
|
|
458
|
+
# Get the indices of the enabled collidable points.
|
459
|
+
indices_of_enabled_collidable_points = (
|
460
|
+
model.kin_dyn_parameters.contact_parameters.indices_of_enabled_collidable_points
|
461
|
+
)
|
462
|
+
|
463
|
+
parent_link_idx_of_enabled_collidable_points = jnp.array(
|
464
|
+
model.kin_dyn_parameters.contact_parameters.body, dtype=int
|
465
|
+
)[indices_of_enabled_collidable_points]
|
466
|
+
|
445
467
|
# Get the transforms of the parent link of all collidable points.
|
446
468
|
W_H_L = js.model.forward_kinematics(model=model, data=data)[
|
447
|
-
|
469
|
+
parent_link_idx_of_enabled_collidable_points
|
470
|
+
]
|
471
|
+
|
472
|
+
L_p_Ci = model.kin_dyn_parameters.contact_parameters.point[
|
473
|
+
indices_of_enabled_collidable_points
|
448
474
|
]
|
449
475
|
|
450
476
|
# Build the link-to-point transform from the displacement between the link frame L
|
451
477
|
# and the implicit contact frame C.
|
452
|
-
L_H_C = jax.vmap(lambda L_p_C: jnp.eye(4).at[0:3, 3].set(L_p_C))(
|
453
|
-
model.kin_dyn_parameters.contact_parameters.point
|
454
|
-
)
|
478
|
+
L_H_C = jax.vmap(lambda L_p_C: jnp.eye(4).at[0:3, 3].set(L_p_C))(L_p_Ci)
|
455
479
|
|
456
480
|
# Compose the work-to-link and link-to-point transforms.
|
457
481
|
return jax.vmap(lambda W_H_Li, L_H_Ci: W_H_Li @ L_H_Ci)(W_H_L, L_H_C)
|
@@ -465,7 +489,7 @@ def jacobian(
|
|
465
489
|
output_vel_repr: VelRepr | None = None,
|
466
490
|
) -> jtp.Array:
|
467
491
|
r"""
|
468
|
-
Return the free-floating Jacobian of the collidable points.
|
492
|
+
Return the free-floating Jacobian of the enabled collidable points.
|
469
493
|
|
470
494
|
Args:
|
471
495
|
model: The model to consider.
|
@@ -475,7 +499,7 @@ def jacobian(
|
|
475
499
|
|
476
500
|
Returns:
|
477
501
|
The stacked :math:`6 \times (6+n)` free-floating jacobians of the frames associated to the
|
478
|
-
collidable points.
|
502
|
+
enabled collidable points.
|
479
503
|
|
480
504
|
Note:
|
481
505
|
Each collidable point is implicitly associated with a frame
|
@@ -488,6 +512,15 @@ def jacobian(
|
|
488
512
|
output_vel_repr if output_vel_repr is not None else data.velocity_representation
|
489
513
|
)
|
490
514
|
|
515
|
+
# Get the indices of the enabled collidable points.
|
516
|
+
indices_of_enabled_collidable_points = (
|
517
|
+
model.kin_dyn_parameters.contact_parameters.indices_of_enabled_collidable_points
|
518
|
+
)
|
519
|
+
|
520
|
+
parent_link_idx_of_enabled_collidable_points = jnp.array(
|
521
|
+
model.kin_dyn_parameters.contact_parameters.body, dtype=int
|
522
|
+
)[indices_of_enabled_collidable_points]
|
523
|
+
|
491
524
|
# Compute the Jacobians of all links.
|
492
525
|
W_J_WL = js.model.generalized_free_floating_jacobian(
|
493
526
|
model=model, data=data, output_vel_repr=VelRepr.Inertial
|
@@ -496,9 +529,7 @@ def jacobian(
|
|
496
529
|
# Compute the contact Jacobian.
|
497
530
|
# In inertial-fixed output representation, the Jacobian of the parent link is also
|
498
531
|
# the Jacobian of the frame C implicitly associated with the collidable point.
|
499
|
-
W_J_WC = W_J_WL[
|
500
|
-
jnp.array(model.kin_dyn_parameters.contact_parameters.body, dtype=int)
|
501
|
-
]
|
532
|
+
W_J_WC = W_J_WL[parent_link_idx_of_enabled_collidable_points]
|
502
533
|
|
503
534
|
# Adjust the output representation.
|
504
535
|
match output_vel_repr:
|
@@ -550,7 +581,7 @@ def jacobian_derivative(
|
|
550
581
|
output_vel_repr: VelRepr | None = None,
|
551
582
|
) -> jtp.Matrix:
|
552
583
|
r"""
|
553
|
-
Compute the derivative of the free-floating jacobian of the
|
584
|
+
Compute the derivative of the free-floating jacobian of the enabled collidable points.
|
554
585
|
|
555
586
|
Args:
|
556
587
|
model: The model to consider.
|
@@ -559,7 +590,7 @@ def jacobian_derivative(
|
|
559
590
|
The output velocity representation of the free-floating jacobian derivative.
|
560
591
|
|
561
592
|
Returns:
|
562
|
-
The derivative of the :math:`6 \times (6+n)` free-floating jacobian of the
|
593
|
+
The derivative of the :math:`6 \times (6+n)` free-floating jacobian of the enabled collidable points.
|
563
594
|
|
564
595
|
Note:
|
565
596
|
The input representation of the free-floating jacobian derivative is the active
|
@@ -570,10 +601,18 @@ def jacobian_derivative(
|
|
570
601
|
output_vel_repr if output_vel_repr is not None else data.velocity_representation
|
571
602
|
)
|
572
603
|
|
604
|
+
indices_of_enabled_collidable_points = (
|
605
|
+
model.kin_dyn_parameters.contact_parameters.indices_of_enabled_collidable_points
|
606
|
+
)
|
607
|
+
|
573
608
|
# Get the index of the parent link and the position of the collidable point.
|
574
|
-
|
575
|
-
|
576
|
-
|
609
|
+
parent_link_idx_of_enabled_collidable_points = jnp.array(
|
610
|
+
model.kin_dyn_parameters.contact_parameters.body, dtype=int
|
611
|
+
)[indices_of_enabled_collidable_points]
|
612
|
+
|
613
|
+
L_p_Ci = model.kin_dyn_parameters.contact_parameters.point[
|
614
|
+
indices_of_enabled_collidable_points
|
615
|
+
]
|
577
616
|
|
578
617
|
# Get the transforms of all the parent links.
|
579
618
|
W_H_Li = js.model.forward_kinematics(model=model, data=data)
|
@@ -646,7 +685,7 @@ def jacobian_derivative(
|
|
646
685
|
output_vel_repr=VelRepr.Inertial,
|
647
686
|
)
|
648
687
|
|
649
|
-
# Get the Jacobian of the collidable points in the mixed representation.
|
688
|
+
# Get the Jacobian of the enabled collidable points in the mixed representation.
|
650
689
|
with data.switch_velocity_representation(VelRepr.Mixed):
|
651
690
|
CW_J_WC_BW = jacobian(
|
652
691
|
model=model,
|
@@ -656,13 +695,11 @@ def jacobian_derivative(
|
|
656
695
|
|
657
696
|
def compute_O_J̇_WC_I(
|
658
697
|
L_p_C: jtp.Vector,
|
659
|
-
|
698
|
+
parent_link_idx: jtp.Int,
|
660
699
|
CW_J_WC_BW: jtp.Matrix,
|
661
700
|
W_H_L: jtp.Matrix,
|
662
701
|
) -> jtp.Matrix:
|
663
702
|
|
664
|
-
parent_link_idx = parent_link_idxs[contact_idx]
|
665
|
-
|
666
703
|
match output_vel_repr:
|
667
704
|
case VelRepr.Inertial:
|
668
705
|
O_X_W = W_X_W = Adjoint.from_transform( # noqa: F841
|
@@ -703,7 +740,7 @@ def jacobian_derivative(
|
|
703
740
|
return O_J̇_WC_I
|
704
741
|
|
705
742
|
O_J̇_WC = jax.vmap(compute_O_J̇_WC_I, in_axes=(0, 0, 0, None))(
|
706
|
-
L_p_Ci,
|
743
|
+
L_p_Ci, parent_link_idx_of_enabled_collidable_points, CW_J_WC_BW, W_H_Li
|
707
744
|
)
|
708
745
|
|
709
746
|
return O_J̇_WC
|
jaxsim/api/joint.py
CHANGED
@@ -53,9 +53,7 @@ def idx_to_name(model: js.model.JaxSimModel, *, joint_index: jtp.IntLike) -> str
|
|
53
53
|
"""
|
54
54
|
|
55
55
|
exceptions.raise_value_error_if(
|
56
|
-
condition=
|
57
|
-
[joint_index < 0, joint_index >= model.number_of_joints()]
|
58
|
-
).any(),
|
56
|
+
condition=joint_index < 0,
|
59
57
|
msg="Invalid joint index '{idx}'",
|
60
58
|
idx=joint_index,
|
61
59
|
)
|
@@ -123,10 +121,7 @@ def position_limit(
|
|
123
121
|
"""
|
124
122
|
|
125
123
|
if model.number_of_joints() == 0:
|
126
|
-
|
127
|
-
s_max = model.kin_dyn_parameters.joint_parameters.position_limits_max
|
128
|
-
|
129
|
-
return jnp.atleast_1d(s_min).astype(float), jnp.atleast_1d(s_max).astype(float)
|
124
|
+
return jnp.empty(0).astype(float), jnp.empty(0).astype(float)
|
130
125
|
|
131
126
|
exceptions.raise_value_error_if(
|
132
127
|
condition=jnp.array(
|
@@ -136,8 +131,12 @@ def position_limit(
|
|
136
131
|
idx=joint_index,
|
137
132
|
)
|
138
133
|
|
139
|
-
s_min =
|
140
|
-
|
134
|
+
s_min = jnp.atleast_1d(
|
135
|
+
model.kin_dyn_parameters.joint_parameters.position_limits_min
|
136
|
+
)[joint_index]
|
137
|
+
s_max = jnp.atleast_1d(
|
138
|
+
model.kin_dyn_parameters.joint_parameters.position_limits_max
|
139
|
+
)[joint_index]
|
141
140
|
|
142
141
|
return s_min.astype(float), s_max.astype(float)
|
143
142
|
|
jaxsim/api/kin_dyn_parameters.py
CHANGED
@@ -438,7 +438,9 @@ class KynDynParameters(JaxsimDataclass):
|
|
438
438
|
# Helpers to update parameters
|
439
439
|
# ============================
|
440
440
|
|
441
|
-
def set_link_mass(
|
441
|
+
def set_link_mass(
|
442
|
+
self, link_index: jtp.IntLike, mass: jtp.FloatLike
|
443
|
+
) -> KynDynParameters:
|
442
444
|
"""
|
443
445
|
Set the mass of a link.
|
444
446
|
|
@@ -457,7 +459,7 @@ class KynDynParameters(JaxsimDataclass):
|
|
457
459
|
return self.replace(link_parameters=link_parameters)
|
458
460
|
|
459
461
|
def set_link_inertia(
|
460
|
-
self, link_index:
|
462
|
+
self, link_index: jtp.IntLike, inertia: jtp.MatrixLike
|
461
463
|
) -> KynDynParameters:
|
462
464
|
r"""
|
463
465
|
Set the inertia tensor of a link.
|
@@ -593,10 +595,10 @@ class LinkParameters(JaxsimDataclass):
|
|
593
595
|
"""
|
594
596
|
|
595
597
|
# Extract the link parameters from the 6D spatial inertia.
|
596
|
-
m, L_p_CoM,
|
598
|
+
m, L_p_CoM, I_CoM = Inertia.to_params(M=M)
|
597
599
|
|
598
600
|
# Extract only the necessary elements of the inertia tensor.
|
599
|
-
inertia_elements =
|
601
|
+
inertia_elements = I_CoM[jnp.triu_indices(3)]
|
600
602
|
|
601
603
|
return LinkParameters(
|
602
604
|
index=jnp.array(index).squeeze().astype(int),
|
@@ -743,6 +745,9 @@ class ContactParameters(JaxsimDataclass):
|
|
743
745
|
point:
|
744
746
|
The translations between the link frame and the collidable point, expressed
|
745
747
|
in the coordinates of the parent link frame.
|
748
|
+
enabled:
|
749
|
+
A tuple of booleans representing, for each collidable point, whether it is
|
750
|
+
enabled or not in contact models.
|
746
751
|
|
747
752
|
Note:
|
748
753
|
Contrarily to LinkParameters and JointParameters, this class is not meant
|
jaxsim/api/link.py
CHANGED
@@ -4,6 +4,7 @@ from collections.abc import Sequence
|
|
4
4
|
import jax
|
5
5
|
import jax.numpy as jnp
|
6
6
|
import jax.scipy.linalg
|
7
|
+
import numpy as np
|
7
8
|
|
8
9
|
import jaxsim.api as js
|
9
10
|
import jaxsim.rbda
|
@@ -54,9 +55,7 @@ def idx_to_name(model: js.model.JaxSimModel, *, link_index: jtp.IntLike) -> str:
|
|
54
55
|
"""
|
55
56
|
|
56
57
|
exceptions.raise_value_error_if(
|
57
|
-
condition=
|
58
|
-
[link_index < 0, link_index >= model.number_of_links()]
|
59
|
-
).any(),
|
58
|
+
condition=link_index < 0,
|
60
59
|
msg="Invalid link index '{idx}'",
|
61
60
|
idx=link_index,
|
62
61
|
)
|
@@ -98,7 +97,7 @@ def idxs_to_names(
|
|
98
97
|
The names of the links.
|
99
98
|
"""
|
100
99
|
|
101
|
-
return tuple(
|
100
|
+
return tuple(np.array(model.kin_dyn_parameters.link_names)[list(link_indices)])
|
102
101
|
|
103
102
|
|
104
103
|
# =========
|
jaxsim/api/model.py
CHANGED
@@ -304,7 +304,7 @@ class JaxSimModel(JaxsimDataclass):
|
|
304
304
|
|
305
305
|
return self.model_name
|
306
306
|
|
307
|
-
def number_of_links(self) ->
|
307
|
+
def number_of_links(self) -> int:
|
308
308
|
"""
|
309
309
|
Return the number of links in the model.
|
310
310
|
|
@@ -317,7 +317,7 @@ class JaxSimModel(JaxsimDataclass):
|
|
317
317
|
|
318
318
|
return self.kin_dyn_parameters.number_of_links()
|
319
319
|
|
320
|
-
def number_of_joints(self) ->
|
320
|
+
def number_of_joints(self) -> int:
|
321
321
|
"""
|
322
322
|
Return the number of joints in the model.
|
323
323
|
|
@@ -419,7 +419,7 @@ class JaxSimModel(JaxsimDataclass):
|
|
419
419
|
def reduce(
|
420
420
|
model: JaxSimModel,
|
421
421
|
considered_joints: tuple[str, ...],
|
422
|
-
locked_joint_positions: dict[str, jtp.
|
422
|
+
locked_joint_positions: dict[str, jtp.FloatLike] | None = None,
|
423
423
|
) -> JaxSimModel:
|
424
424
|
"""
|
425
425
|
Reduce the model by lumping together the links connected by removed joints.
|
@@ -1038,12 +1038,7 @@ def forward_dynamics_aba(
|
|
1038
1038
|
C_v̇_WB = to_active(
|
1039
1039
|
W_v̇_WB=W_v̇_WB,
|
1040
1040
|
W_H_C=W_H_C,
|
1041
|
-
W_v_WB=
|
1042
|
-
[
|
1043
|
-
data.state.physics_model.base_linear_velocity,
|
1044
|
-
data.state.physics_model.base_angular_velocity,
|
1045
|
-
]
|
1046
|
-
),
|
1041
|
+
W_v_WB=W_v_WB,
|
1047
1042
|
W_v_WC=W_v_WC,
|
1048
1043
|
)
|
1049
1044
|
|
@@ -2274,20 +2269,23 @@ def step(
|
|
2274
2269
|
# Raise runtime error for not supported case in which Rigid contacts and
|
2275
2270
|
# Baumgarte stabilization are enabled and used with ForwardEuler integrator.
|
2276
2271
|
jaxsim.exceptions.raise_runtime_error_if(
|
2277
|
-
condition=
|
2278
|
-
|
2279
|
-
|
2280
|
-
|
2281
|
-
|
2282
|
-
|
2283
|
-
jnp.array(
|
2284
|
-
[data_tf.contacts_params.K, data_tf.contacts_params.D]
|
2285
|
-
).any(),
|
2286
|
-
),
|
2272
|
+
condition=isinstance(
|
2273
|
+
integrator,
|
2274
|
+
jaxsim.integrators.fixed_step.ForwardEuler
|
2275
|
+
| jaxsim.integrators.fixed_step.ForwardEulerSO3,
|
2276
|
+
)
|
2277
|
+
& ((data_tf.contacts_params.K > 0) | (data_tf.contacts_params.D > 0)),
|
2287
2278
|
msg="Baumgarte stabilization is not supported with ForwardEuler integrators",
|
2288
2279
|
)
|
2289
2280
|
|
2290
|
-
|
2281
|
+
# Extract the indices corresponding to the enabled collidable points.
|
2282
|
+
indices_of_enabled_collidable_points = (
|
2283
|
+
model.kin_dyn_parameters.contact_parameters.indices_of_enabled_collidable_points
|
2284
|
+
)
|
2285
|
+
|
2286
|
+
W_p_C = js.contact.collidable_point_positions(model, data_tf)[
|
2287
|
+
indices_of_enabled_collidable_points
|
2288
|
+
]
|
2291
2289
|
|
2292
2290
|
# Compute the penetration depth of the collidable points.
|
2293
2291
|
δ, *_ = jax.vmap(
|
@@ -2296,8 +2294,9 @@ def step(
|
|
2296
2294
|
)(W_p_C, jnp.zeros_like(W_p_C), model.terrain)
|
2297
2295
|
|
2298
2296
|
with data_tf.switch_velocity_representation(VelRepr.Mixed):
|
2299
|
-
|
2300
|
-
|
2297
|
+
J_WC = js.contact.jacobian(model, data_tf)[
|
2298
|
+
indices_of_enabled_collidable_points
|
2299
|
+
]
|
2301
2300
|
M = js.model.free_floating_mass_matrix(model, data_tf)
|
2302
2301
|
|
2303
2302
|
# Compute the impact velocity.
|
jaxsim/api/references.py
CHANGED
@@ -503,7 +503,7 @@ class JaxSimModelReferences(js.common.ModelDataWithVelocityRepresentation):
|
|
503
503
|
]
|
504
504
|
|
505
505
|
exceptions.raise_value_error_if(
|
506
|
-
condition
|
506
|
+
condition=~data.valid(model=model),
|
507
507
|
msg="The provided data is not valid for the model",
|
508
508
|
)
|
509
509
|
W_H_Fi = jax.vmap(
|
jaxsim/integrators/common.py
CHANGED
@@ -319,7 +319,7 @@ class ExplicitRungeKutta(Integrator[PyTreeType, PyTreeType], Generic[PyTreeType]
|
|
319
319
|
f = lambda x, t: self.dynamics(x=x, t=t, **kwargs)
|
320
320
|
|
321
321
|
# Initialize the carry of the for loop with the stacked kᵢ vectors.
|
322
|
-
carry0 = jax.
|
322
|
+
carry0 = jax.tree.map(
|
323
323
|
lambda l: jnp.zeros((c.size, *l.shape), dtype=l.dtype), x0
|
324
324
|
)
|
325
325
|
|
@@ -507,7 +507,7 @@ class ExplicitRungeKuttaSO3Mixin:
|
|
507
507
|
|
508
508
|
# We assume that the initial quaternion is already unary.
|
509
509
|
exceptions.raise_runtime_error_if(
|
510
|
-
condition
|
510
|
+
condition=~jnp.allclose(W_Q_B_t0.dot(W_Q_B_t0), 1.0),
|
511
511
|
msg="The SO(3) integrator received a quaternion at t0 that is not unary.",
|
512
512
|
)
|
513
513
|
|
@@ -152,7 +152,7 @@ def compute_pytree_scale(
|
|
152
152
|
"""
|
153
153
|
|
154
154
|
# Consider a zero second pytree, if not given.
|
155
|
-
x2 = jax.tree.map(
|
155
|
+
x2 = jax.tree.map(jnp.zeros_like, x1) if x2 is None else x2
|
156
156
|
|
157
157
|
# Compute the scaling factors of the initial state and its derivative.
|
158
158
|
compute_scale = lambda l1, l2: atol + jnp.maximum(jnp.abs(l1), jnp.abs(l2)) * rtol
|
@@ -199,9 +199,7 @@ def local_error_estimation(
|
|
199
199
|
|
200
200
|
# Consider a zero estimated final state, if not given.
|
201
201
|
xf_estimate = (
|
202
|
-
jax.tree.map(
|
203
|
-
if xf_estimate is None
|
204
|
-
else xf_estimate
|
202
|
+
jax.tree.map(jnp.zeros_like, xf) if xf_estimate is None else xf_estimate
|
205
203
|
)
|
206
204
|
|
207
205
|
# Estimate the error.
|
@@ -483,14 +481,10 @@ class EmbeddedRungeKutta(ExplicitRungeKutta[PyTreeType], Generic[PyTreeType]):
|
|
483
481
|
metadata_next,
|
484
482
|
discarded_steps,
|
485
483
|
) = jax.lax.cond(
|
486
|
-
pred=
|
487
|
-
|
488
|
-
|
489
|
-
|
490
|
-
Δt_next < self.dt_min,
|
491
|
-
integrator_init,
|
492
|
-
]
|
493
|
-
).any(),
|
484
|
+
pred=discarded_steps
|
485
|
+
>= self.max_step_rejections | local_error
|
486
|
+
<= 1.0 | Δt_next
|
487
|
+
< self.dt_min | integrator_init,
|
494
488
|
true_fun=accept_step,
|
495
489
|
false_fun=reject_step,
|
496
490
|
)
|
jaxsim/mujoco/loaders.py
CHANGED
@@ -1,6 +1,3 @@
|
|
1
|
-
from __future__ import annotations
|
2
|
-
|
3
|
-
import dataclasses
|
4
1
|
import pathlib
|
5
2
|
import tempfile
|
6
3
|
import warnings
|
@@ -9,10 +6,14 @@ from typing import Any
|
|
9
6
|
|
10
7
|
import mujoco as mj
|
11
8
|
import numpy as np
|
12
|
-
import numpy.typing as npt
|
13
9
|
import rod.urdf.exporter
|
14
10
|
from lxml import etree as ET
|
15
|
-
|
11
|
+
|
12
|
+
from .utils import MujocoCamera
|
13
|
+
|
14
|
+
MujocoCameraType = (
|
15
|
+
MujocoCamera | Sequence[MujocoCamera] | dict[str, str] | Sequence[dict[str, str]]
|
16
|
+
)
|
16
17
|
|
17
18
|
|
18
19
|
def load_rod_model(
|
@@ -167,12 +168,7 @@ class RodModelToMjcf:
|
|
167
168
|
plane_normal: tuple[float, float, float] = (0, 0, 1),
|
168
169
|
heightmap: bool | None = None,
|
169
170
|
heightmap_samples_xy: tuple[int, int] = (101, 101),
|
170
|
-
cameras: (
|
171
|
-
MujocoCamera
|
172
|
-
| Sequence[MujocoCamera]
|
173
|
-
| dict[str, str]
|
174
|
-
| Sequence[dict[str, str]]
|
175
|
-
) = (),
|
171
|
+
cameras: MujocoCameraType = (),
|
176
172
|
) -> tuple[str, dict[str, Any]]:
|
177
173
|
"""
|
178
174
|
Converts a ROD model to a Mujoco MJCF string.
|
@@ -533,12 +529,7 @@ class UrdfToMjcf:
|
|
533
529
|
model_name: str | None = None,
|
534
530
|
plane_normal: tuple[float, float, float] = (0, 0, 1),
|
535
531
|
heightmap: bool | None = None,
|
536
|
-
cameras: (
|
537
|
-
MujocoCamera
|
538
|
-
| Sequence[MujocoCamera]
|
539
|
-
| dict[str, str]
|
540
|
-
| Sequence[dict[str, str]]
|
541
|
-
) = (),
|
532
|
+
cameras: MujocoCameraType = (),
|
542
533
|
) -> tuple[str, dict[str, Any]]:
|
543
534
|
"""
|
544
535
|
Converts a URDF file to a Mujoco MJCF string.
|
@@ -580,12 +571,7 @@ class SdfToMjcf:
|
|
580
571
|
model_name: str | None = None,
|
581
572
|
plane_normal: tuple[float, float, float] = (0, 0, 1),
|
582
573
|
heightmap: bool | None = None,
|
583
|
-
cameras: (
|
584
|
-
MujocoCamera
|
585
|
-
| Sequence[MujocoCamera]
|
586
|
-
| dict[str, str]
|
587
|
-
| Sequence[dict[str, str]]
|
588
|
-
) = (),
|
574
|
+
cameras: MujocoCameraType = (),
|
589
575
|
) -> tuple[str, dict[str, Any]]:
|
590
576
|
"""
|
591
577
|
Converts a SDF file to a Mujoco MJCF string.
|
@@ -617,118 +603,3 @@ class SdfToMjcf:
|
|
617
603
|
heightmap=heightmap,
|
618
604
|
cameras=cameras,
|
619
605
|
)
|
620
|
-
|
621
|
-
|
622
|
-
@dataclasses.dataclass
|
623
|
-
class MujocoCamera:
|
624
|
-
"""
|
625
|
-
Helper class storing parameters of a Mujoco camera.
|
626
|
-
|
627
|
-
Refer to the official documentation for more details:
|
628
|
-
https://mujoco.readthedocs.io/en/stable/XMLreference.html#body-camera
|
629
|
-
"""
|
630
|
-
|
631
|
-
mode: str = "fixed"
|
632
|
-
|
633
|
-
target: str | None = None
|
634
|
-
fovy: str = "45"
|
635
|
-
pos: str = "0 0 0"
|
636
|
-
|
637
|
-
quat: str | None = None
|
638
|
-
axisangle: str | None = None
|
639
|
-
xyaxes: str | None = None
|
640
|
-
zaxis: str | None = None
|
641
|
-
euler: str | None = None
|
642
|
-
|
643
|
-
name: str | None = None
|
644
|
-
|
645
|
-
@classmethod
|
646
|
-
def build(cls, **kwargs) -> MujocoCamera:
|
647
|
-
|
648
|
-
if not all(isinstance(value, str) for value in kwargs.values()):
|
649
|
-
raise ValueError(f"Values must be strings: {kwargs}")
|
650
|
-
|
651
|
-
return cls(**kwargs)
|
652
|
-
|
653
|
-
@staticmethod
|
654
|
-
def build_from_target_view(
|
655
|
-
camera_name: str,
|
656
|
-
lookat: Sequence[float | int] | npt.NDArray = (0, 0, 0),
|
657
|
-
distance: float | int | npt.NDArray = 3,
|
658
|
-
azimut: float | int | npt.NDArray = 90,
|
659
|
-
elevation: float | int | npt.NDArray = -45,
|
660
|
-
fovy: float | int | npt.NDArray = 45,
|
661
|
-
degrees: bool = True,
|
662
|
-
**kwargs,
|
663
|
-
) -> MujocoCamera:
|
664
|
-
"""
|
665
|
-
Create a custom camera that looks at a target point.
|
666
|
-
|
667
|
-
Note:
|
668
|
-
The choice of the parameters is easier if we imagine to consider a target
|
669
|
-
frame `T` whose origin is located over the lookat point and having the same
|
670
|
-
orientation of the world frame `W`. We also introduce a camera frame `C`
|
671
|
-
whose origin is located over the lower-left corner of the image, and having
|
672
|
-
the x-axis pointing right and the y-axis pointing up in image coordinates.
|
673
|
-
The camera renders what it sees in the -z direction of frame `C`.
|
674
|
-
|
675
|
-
Args:
|
676
|
-
camera_name: The name of the camera.
|
677
|
-
lookat: The target point to look at (origin of `T`).
|
678
|
-
distance:
|
679
|
-
The distance from the target point (displacement between the origins
|
680
|
-
of `T` and `C`).
|
681
|
-
azimut:
|
682
|
-
The rotation around z of the camera. With an angle of 0, the camera
|
683
|
-
would loot at the target point towards the positive x-axis of `T`.
|
684
|
-
elevation:
|
685
|
-
The rotation around the x-axis of the camera frame `C`. Note that if
|
686
|
-
you want to lift the view angle, the elevation is negative.
|
687
|
-
fovy: The field of view of the camera.
|
688
|
-
degrees: Whether the angles are in degrees or radians.
|
689
|
-
**kwargs: Additional camera parameters.
|
690
|
-
|
691
|
-
Returns:
|
692
|
-
The custom camera.
|
693
|
-
"""
|
694
|
-
|
695
|
-
# Start from a frame whose origin is located over the lookat point.
|
696
|
-
# We initialize a -90 degrees rotation around the z-axis because due to
|
697
|
-
# the default camera coordinate system (x pointing right, y pointing up).
|
698
|
-
W_H_C = np.eye(4)
|
699
|
-
W_H_C[0:3, 3] = np.array(lookat)
|
700
|
-
W_H_C[0:3, 0:3] = Rotation.from_euler(
|
701
|
-
seq="ZX", angles=[-90, 90], degrees=True
|
702
|
-
).as_matrix()
|
703
|
-
|
704
|
-
# Process the azimut.
|
705
|
-
R_az = Rotation.from_euler(seq="Y", angles=azimut, degrees=degrees).as_matrix()
|
706
|
-
W_H_C[0:3, 0:3] = W_H_C[0:3, 0:3] @ R_az
|
707
|
-
|
708
|
-
# Process elevation.
|
709
|
-
R_el = Rotation.from_euler(
|
710
|
-
seq="X", angles=elevation, degrees=degrees
|
711
|
-
).as_matrix()
|
712
|
-
W_H_C[0:3, 0:3] = W_H_C[0:3, 0:3] @ R_el
|
713
|
-
|
714
|
-
# Process distance.
|
715
|
-
tf_distance = np.eye(4)
|
716
|
-
tf_distance[2, 3] = distance
|
717
|
-
W_H_C = W_H_C @ tf_distance
|
718
|
-
|
719
|
-
# Extract the position and the quaternion.
|
720
|
-
p = W_H_C[0:3, 3]
|
721
|
-
Q = Rotation.from_matrix(W_H_C[0:3, 0:3]).as_quat(scalar_first=True)
|
722
|
-
|
723
|
-
return MujocoCamera.build(
|
724
|
-
name=camera_name,
|
725
|
-
mode="fixed",
|
726
|
-
fovy=f"{fovy if degrees else np.rad2deg(fovy)}",
|
727
|
-
pos=" ".join(p.astype(str).tolist()),
|
728
|
-
quat=" ".join(Q.astype(str).tolist()),
|
729
|
-
**kwargs,
|
730
|
-
)
|
731
|
-
|
732
|
-
def asdict(self) -> dict[str, str]:
|
733
|
-
|
734
|
-
return {k: v for k, v in dataclasses.asdict(self).items() if v is not None}
|