jaxsim 0.4.1.dev26__py3-none-any.whl → 0.4.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jaxsim/_version.py +2 -2
- jaxsim/api/com.py +12 -12
- jaxsim/api/contact.py +168 -0
- jaxsim/api/data.py +2 -11
- jaxsim/api/frame.py +156 -1
- jaxsim/api/link.py +10 -7
- jaxsim/api/model.py +68 -24
- jaxsim/api/references.py +107 -1
- jaxsim/integrators/common.py +4 -2
- jaxsim/math/adjoint.py +3 -4
- jaxsim/math/joint_model.py +2 -2
- jaxsim/math/quaternion.py +2 -4
- jaxsim/math/transform.py +1 -3
- jaxsim/mujoco/loaders.py +12 -2
- jaxsim/mujoco/model.py +5 -5
- jaxsim/rbda/aba.py +2 -2
- jaxsim/rbda/collidable_points.py +2 -2
- jaxsim/rbda/contacts/soft.py +1 -1
- jaxsim/rbda/crba.py +24 -12
- jaxsim/rbda/forward_kinematics.py +10 -6
- jaxsim/rbda/jacobian.py +24 -12
- jaxsim/rbda/rnea.py +2 -2
- {jaxsim-0.4.1.dev26.dist-info → jaxsim-0.4.2.dist-info}/METADATA +31 -19
- {jaxsim-0.4.1.dev26.dist-info → jaxsim-0.4.2.dist-info}/RECORD +27 -27
- {jaxsim-0.4.1.dev26.dist-info → jaxsim-0.4.2.dist-info}/WHEEL +1 -1
- {jaxsim-0.4.1.dev26.dist-info → jaxsim-0.4.2.dist-info}/LICENSE +0 -0
- {jaxsim-0.4.1.dev26.dist-info → jaxsim-0.4.2.dist-info}/top_level.txt +0 -0
jaxsim/api/model.py
CHANGED
@@ -301,10 +301,10 @@ class JaxSimModel(JaxsimDataclass):
|
|
301
301
|
|
302
302
|
def frame_names(self) -> tuple[str, ...]:
|
303
303
|
"""
|
304
|
-
Return the names of the
|
304
|
+
Return the names of the frames in the model.
|
305
305
|
|
306
306
|
Returns:
|
307
|
-
The names of the
|
307
|
+
The names of the frames in the model.
|
308
308
|
"""
|
309
309
|
|
310
310
|
return self.kin_dyn_parameters.frame_parameters.name
|
@@ -495,8 +495,9 @@ def generalized_free_floating_jacobian(
|
|
495
495
|
W_H_B = data.base_transform()
|
496
496
|
B_X_W = Adjoint.from_transform(transform=W_H_B, inverse=True)
|
497
497
|
|
498
|
-
B_J_full_WX_I = B_J_full_WX_W =
|
499
|
-
|
498
|
+
B_J_full_WX_I = B_J_full_WX_W = ( # noqa: F841
|
499
|
+
B_J_full_WX_B
|
500
|
+
@ jax.scipy.linalg.block_diag(B_X_W, jnp.eye(model.dofs()))
|
500
501
|
)
|
501
502
|
|
502
503
|
case VelRepr.Body:
|
@@ -509,7 +510,7 @@ def generalized_free_floating_jacobian(
|
|
509
510
|
BW_H_B = jnp.eye(4).at[0:3, 0:3].set(W_R_B)
|
510
511
|
B_X_BW = Adjoint.from_transform(transform=BW_H_B, inverse=True)
|
511
512
|
|
512
|
-
B_J_full_WX_I = B_J_full_WX_BW = (
|
513
|
+
B_J_full_WX_I = B_J_full_WX_BW = ( # noqa: F841
|
513
514
|
B_J_full_WX_B
|
514
515
|
@ jax.scipy.linalg.block_diag(B_X_BW, jnp.eye(model.dofs()))
|
515
516
|
)
|
@@ -542,11 +543,13 @@ def generalized_free_floating_jacobian(
|
|
542
543
|
W_H_B = data.base_transform()
|
543
544
|
W_X_B = jaxsim.math.Adjoint.from_transform(W_H_B)
|
544
545
|
|
545
|
-
O_J_WL_I = W_J_WL_I = jax.vmap(
|
546
|
+
O_J_WL_I = W_J_WL_I = jax.vmap( # noqa: F841
|
547
|
+
lambda B_J_WL_I: W_X_B @ B_J_WL_I
|
548
|
+
)(B_J_WL_I)
|
546
549
|
|
547
550
|
case VelRepr.Body:
|
548
551
|
|
549
|
-
O_J_WL_I = L_J_WL_I = jax.vmap(
|
552
|
+
O_J_WL_I = L_J_WL_I = jax.vmap( # noqa: F841
|
550
553
|
lambda B_H_L, B_J_WL_I: jaxsim.math.Adjoint.from_transform(
|
551
554
|
B_H_L, inverse=True
|
552
555
|
)
|
@@ -565,7 +568,7 @@ def generalized_free_floating_jacobian(
|
|
565
568
|
lambda LW_H_L, B_H_L: LW_H_L @ jaxsim.math.Transform.inverse(B_H_L)
|
566
569
|
)(LW_H_L, B_H_L)
|
567
570
|
|
568
|
-
O_J_WL_I = LW_J_WL_I = jax.vmap(
|
571
|
+
O_J_WL_I = LW_J_WL_I = jax.vmap( # noqa: F841
|
569
572
|
lambda LW_H_B, B_J_WL_I: jaxsim.math.Adjoint.from_transform(LW_H_B)
|
570
573
|
@ B_J_WL_I
|
571
574
|
)(LW_H_B, B_J_WL_I)
|
@@ -576,6 +579,41 @@ def generalized_free_floating_jacobian(
|
|
576
579
|
return O_J_WL_I
|
577
580
|
|
578
581
|
|
582
|
+
@functools.partial(jax.jit, static_argnames=["output_vel_repr"])
|
583
|
+
def generalized_free_floating_jacobian_derivative(
|
584
|
+
model: JaxSimModel,
|
585
|
+
data: js.data.JaxSimModelData,
|
586
|
+
*,
|
587
|
+
output_vel_repr: VelRepr | None = None,
|
588
|
+
) -> jtp.Matrix:
|
589
|
+
"""
|
590
|
+
Compute the free-floating jacobian derivatives of all links.
|
591
|
+
|
592
|
+
Args:
|
593
|
+
model: The model to consider.
|
594
|
+
data: The data of the considered model.
|
595
|
+
output_vel_repr:
|
596
|
+
The output velocity representation of the free-floating jacobian derivatives.
|
597
|
+
|
598
|
+
Returns:
|
599
|
+
The `(nL, 6, 6+dofs)` array containing the stacked free-floating
|
600
|
+
jacobian derivatives of the links. The first axis is the link index.
|
601
|
+
"""
|
602
|
+
|
603
|
+
output_vel_repr = (
|
604
|
+
output_vel_repr if output_vel_repr is not None else data.velocity_representation
|
605
|
+
)
|
606
|
+
|
607
|
+
O_J̇_WL_I = jax.vmap(
|
608
|
+
lambda model, data, link_idxs, output_vel_repr: js.link.jacobian_derivative(
|
609
|
+
model, data, link_index=link_idxs, output_vel_repr=output_vel_repr
|
610
|
+
),
|
611
|
+
in_axes=(None, None, 0, None),
|
612
|
+
)(model, data, jnp.arange(model.number_of_links()), output_vel_repr)
|
613
|
+
|
614
|
+
return O_J̇_WL_I
|
615
|
+
|
616
|
+
|
579
617
|
@functools.partial(jax.jit, static_argnames=["prefer_aba"])
|
580
618
|
def forward_dynamics(
|
581
619
|
model: JaxSimModel,
|
@@ -721,8 +759,8 @@ def forward_dynamics_aba(
|
|
721
759
|
match data.velocity_representation:
|
722
760
|
case VelRepr.Inertial:
|
723
761
|
# In this case C=W
|
724
|
-
W_H_C = W_H_W = jnp.eye(4)
|
725
|
-
W_v_WC = W_v_WW = jnp.zeros(6)
|
762
|
+
W_H_C = W_H_W = jnp.eye(4) # noqa: F841
|
763
|
+
W_v_WC = W_v_WW = jnp.zeros(6) # noqa: F841
|
726
764
|
|
727
765
|
case VelRepr.Body:
|
728
766
|
# In this case C=B
|
@@ -732,9 +770,9 @@ def forward_dynamics_aba(
|
|
732
770
|
case VelRepr.Mixed:
|
733
771
|
# In this case C=B[W]
|
734
772
|
W_H_B = data.base_transform()
|
735
|
-
W_H_C = W_H_BW = W_H_B.at[0:3, 0:3].set(jnp.eye(3))
|
773
|
+
W_H_C = W_H_BW = W_H_B.at[0:3, 0:3].set(jnp.eye(3)) # noqa: F841
|
736
774
|
W_ṗ_B = data.base_velocity()[0:3]
|
737
|
-
W_v_WC = W_v_W_BW = jnp.zeros(6).at[0:3].set(W_ṗ_B)
|
775
|
+
W_v_WC = W_v_W_BW = jnp.zeros(6).at[0:3].set(W_ṗ_B) # noqa: F841
|
738
776
|
|
739
777
|
case _:
|
740
778
|
raise ValueError(data.velocity_representation)
|
@@ -1089,8 +1127,8 @@ def inverse_dynamics(
|
|
1089
1127
|
|
1090
1128
|
match data.velocity_representation:
|
1091
1129
|
case VelRepr.Inertial:
|
1092
|
-
W_H_C = W_H_W = jnp.eye(4)
|
1093
|
-
W_v_WC = W_v_WW = jnp.zeros(6)
|
1130
|
+
W_H_C = W_H_W = jnp.eye(4) # noqa: F841
|
1131
|
+
W_v_WC = W_v_WW = jnp.zeros(6) # noqa: F841
|
1094
1132
|
|
1095
1133
|
case VelRepr.Body:
|
1096
1134
|
W_H_C = W_H_B = data.base_transform()
|
@@ -1099,9 +1137,9 @@ def inverse_dynamics(
|
|
1099
1137
|
|
1100
1138
|
case VelRepr.Mixed:
|
1101
1139
|
W_H_B = data.base_transform()
|
1102
|
-
W_H_C = W_H_BW = W_H_B.at[0:3, 0:3].set(jnp.eye(3))
|
1140
|
+
W_H_C = W_H_BW = W_H_B.at[0:3, 0:3].set(jnp.eye(3)) # noqa: F841
|
1103
1141
|
W_ṗ_B = data.base_velocity()[0:3]
|
1104
|
-
W_v_WC = W_v_W_BW = jnp.zeros(6).at[0:3].set(W_ṗ_B)
|
1142
|
+
W_v_WC = W_v_W_BW = jnp.zeros(6).at[0:3].set(W_ṗ_B) # noqa: F841
|
1105
1143
|
|
1106
1144
|
case _:
|
1107
1145
|
raise ValueError(data.velocity_representation)
|
@@ -1536,15 +1574,15 @@ def link_bias_accelerations(
|
|
1536
1574
|
# a simple C_X_W 6D transform.
|
1537
1575
|
match data.velocity_representation:
|
1538
1576
|
case VelRepr.Inertial:
|
1539
|
-
W_H_C = W_H_W = jnp.eye(4)
|
1540
|
-
W_v_WC = W_v_WW = jnp.zeros(6)
|
1577
|
+
W_H_C = W_H_W = jnp.eye(4) # noqa: F841
|
1578
|
+
W_v_WC = W_v_WW = jnp.zeros(6) # noqa: F841
|
1541
1579
|
with data.switch_velocity_representation(VelRepr.Inertial):
|
1542
1580
|
C_v_WB = W_v_WB = data.base_velocity()
|
1543
1581
|
|
1544
1582
|
case VelRepr.Body:
|
1545
1583
|
W_H_C = W_H_B
|
1546
1584
|
with data.switch_velocity_representation(VelRepr.Inertial):
|
1547
|
-
W_v_WC = W_v_WB = data.base_velocity()
|
1585
|
+
W_v_WC = W_v_WB = data.base_velocity() # noqa: F841
|
1548
1586
|
with data.switch_velocity_representation(VelRepr.Body):
|
1549
1587
|
C_v_WB = B_v_WB = data.base_velocity()
|
1550
1588
|
|
@@ -1555,9 +1593,9 @@ def link_bias_accelerations(
|
|
1555
1593
|
W_ṗ_B = data.base_velocity()[0:3]
|
1556
1594
|
BW_v_W_BW = jnp.zeros(6).at[0:3].set(W_ṗ_B)
|
1557
1595
|
W_X_BW = jaxsim.math.Adjoint.from_transform(transform=W_H_BW)
|
1558
|
-
W_v_WC = W_v_W_BW = W_X_BW @ BW_v_W_BW
|
1596
|
+
W_v_WC = W_v_W_BW = W_X_BW @ BW_v_W_BW # noqa: F841
|
1559
1597
|
with data.switch_velocity_representation(VelRepr.Mixed):
|
1560
|
-
C_v_WB = BW_v_WB = data.base_velocity()
|
1598
|
+
C_v_WB = BW_v_WB = data.base_velocity() # noqa: F841
|
1561
1599
|
|
1562
1600
|
case _:
|
1563
1601
|
raise ValueError(data.velocity_representation)
|
@@ -1665,8 +1703,12 @@ def link_bias_accelerations(
|
|
1665
1703
|
|
1666
1704
|
match data.velocity_representation:
|
1667
1705
|
case VelRepr.Body:
|
1668
|
-
C_H_L = L_H_L = jnp.stack(
|
1669
|
-
|
1706
|
+
C_H_L = L_H_L = jnp.stack( # noqa: F841
|
1707
|
+
[jnp.eye(4)] * model.number_of_links()
|
1708
|
+
)
|
1709
|
+
L_v_CL = L_v_LL = jnp.zeros( # noqa: F841
|
1710
|
+
shape=(model.number_of_links(), 6)
|
1711
|
+
)
|
1670
1712
|
|
1671
1713
|
case VelRepr.Inertial:
|
1672
1714
|
C_H_L = W_H_L = js.model.forward_kinematics(model=model, data=data)
|
@@ -1676,7 +1718,9 @@ def link_bias_accelerations(
|
|
1676
1718
|
W_H_L = js.model.forward_kinematics(model=model, data=data)
|
1677
1719
|
LW_H_L = jax.vmap(lambda W_H_L: W_H_L.at[0:3, 3].set(jnp.zeros(3)))(W_H_L)
|
1678
1720
|
C_H_L = LW_H_L
|
1679
|
-
L_v_CL = L_v_LW_L = jax.vmap(
|
1721
|
+
L_v_CL = L_v_LW_L = jax.vmap( # noqa: F841
|
1722
|
+
lambda v: v.at[0:3].set(jnp.zeros(3))
|
1723
|
+
)(L_v_WL)
|
1680
1724
|
|
1681
1725
|
case _:
|
1682
1726
|
raise ValueError(data.velocity_representation)
|
jaxsim/api/references.py
CHANGED
@@ -8,6 +8,7 @@ import jax_dataclasses
|
|
8
8
|
|
9
9
|
import jaxsim.api as js
|
10
10
|
import jaxsim.typing as jtp
|
11
|
+
from jaxsim import exceptions
|
11
12
|
from jaxsim.utils.tracing import not_tracing
|
12
13
|
|
13
14
|
from .common import VelRepr
|
@@ -30,6 +31,7 @@ class JaxSimModelReferences(js.common.ModelDataWithVelocityRepresentation):
|
|
30
31
|
@staticmethod
|
31
32
|
def zero(
|
32
33
|
model: js.model.JaxSimModel,
|
34
|
+
data: js.data.JaxSimModelData | None = None,
|
33
35
|
velocity_representation: VelRepr = VelRepr.Inertial,
|
34
36
|
) -> JaxSimModelReferences:
|
35
37
|
"""
|
@@ -37,6 +39,9 @@ class JaxSimModelReferences(js.common.ModelDataWithVelocityRepresentation):
|
|
37
39
|
|
38
40
|
Args:
|
39
41
|
model: The model for which to create the zero references.
|
42
|
+
data:
|
43
|
+
The data of the model, only needed if the velocity representation is
|
44
|
+
not inertial-fixed.
|
40
45
|
velocity_representation: The velocity representation to use.
|
41
46
|
|
42
47
|
Returns:
|
@@ -44,7 +49,7 @@ class JaxSimModelReferences(js.common.ModelDataWithVelocityRepresentation):
|
|
44
49
|
"""
|
45
50
|
|
46
51
|
return JaxSimModelReferences.build(
|
47
|
-
model=model, velocity_representation=velocity_representation
|
52
|
+
model=model, data=data, velocity_representation=velocity_representation
|
48
53
|
)
|
49
54
|
|
50
55
|
@staticmethod
|
@@ -441,3 +446,104 @@ class JaxSimModelReferences(js.common.ModelDataWithVelocityRepresentation):
|
|
441
446
|
return replace(
|
442
447
|
forces=self.input.physics_model.f_ext.at[link_idxs, :].set(W_f0_L + W_f_L)
|
443
448
|
)
|
449
|
+
|
450
|
+
def apply_frame_forces(
|
451
|
+
self,
|
452
|
+
forces: jtp.MatrixLike,
|
453
|
+
model: js.model.JaxSimModel,
|
454
|
+
data: js.data.JaxSimModelData,
|
455
|
+
frame_names: tuple[str, ...] | str | None = None,
|
456
|
+
additive: bool = False,
|
457
|
+
) -> Self:
|
458
|
+
"""
|
459
|
+
Apply the frame forces.
|
460
|
+
|
461
|
+
Args:
|
462
|
+
forces: The frame 6D forces in the active representation.
|
463
|
+
model:
|
464
|
+
The model to consider, only needed if a frame serialization different
|
465
|
+
from the implicit one is used.
|
466
|
+
data:
|
467
|
+
The data of the considered model, only needed if the velocity
|
468
|
+
representation is not inertial-fixed.
|
469
|
+
frame_names: The names of the frames corresponding to the forces.
|
470
|
+
additive:
|
471
|
+
Whether to add the forces to the existing ones instead of replacing them.
|
472
|
+
|
473
|
+
Returns:
|
474
|
+
A new `JaxSimModelReferences` object with the given frame forces.
|
475
|
+
|
476
|
+
Note:
|
477
|
+
The frame forces must be expressed in the active representation.
|
478
|
+
Then, we always convert and store forces in inertial-fixed representation.
|
479
|
+
"""
|
480
|
+
|
481
|
+
f_F = jnp.atleast_2d(forces).astype(float)
|
482
|
+
|
483
|
+
# If we have the model, we can extract the frame names if not provided.
|
484
|
+
frame_names = frame_names if frame_names is not None else model.frame_names()
|
485
|
+
|
486
|
+
# Make sure that the frame names are a tuple if they are provided by the user.
|
487
|
+
frame_names = (frame_names,) if isinstance(frame_names, str) else frame_names
|
488
|
+
|
489
|
+
if len(frame_names) != f_F.shape[0]:
|
490
|
+
msg = "The number of frame names ({}) must match the number of forces ({})"
|
491
|
+
raise ValueError(msg.format(len(frame_names), f_F.shape[0]))
|
492
|
+
|
493
|
+
# Extract the frame indices.
|
494
|
+
frame_idxs = js.frame.names_to_idxs(frame_names=frame_names, model=model)
|
495
|
+
parent_link_idxs = jax.vmap(js.frame.idx_of_parent_link, in_axes=(None,))(
|
496
|
+
model, frame_index=frame_idxs
|
497
|
+
)
|
498
|
+
|
499
|
+
exceptions.raise_value_error_if(
|
500
|
+
condition=jnp.logical_not(data.valid(model=model)),
|
501
|
+
msg="The provided data is not valid for the model",
|
502
|
+
)
|
503
|
+
W_H_Fi = jax.vmap(
|
504
|
+
lambda frame_idx: js.frame.transform(
|
505
|
+
model=model, data=data, frame_index=frame_idx
|
506
|
+
)
|
507
|
+
)(frame_idxs)
|
508
|
+
|
509
|
+
# Helper function to convert a single 6D force to the inertial representation
|
510
|
+
# considering as body the frame (i.e. L_f_F and LW_f_F).
|
511
|
+
def to_inertial(f_F: jtp.MatrixLike, W_H_F: jtp.MatrixLike) -> jtp.Matrix:
|
512
|
+
return JaxSimModelReferences.other_representation_to_inertial(
|
513
|
+
array=f_F,
|
514
|
+
other_representation=self.velocity_representation,
|
515
|
+
transform=W_H_F,
|
516
|
+
is_force=True,
|
517
|
+
)
|
518
|
+
|
519
|
+
match self.velocity_representation:
|
520
|
+
case VelRepr.Inertial:
|
521
|
+
W_f_F = f_F
|
522
|
+
|
523
|
+
case VelRepr.Body | VelRepr.Mixed:
|
524
|
+
W_f_F = jax.vmap(to_inertial)(f_F, W_H_Fi)
|
525
|
+
|
526
|
+
case _:
|
527
|
+
raise ValueError("Invalid velocity representation.")
|
528
|
+
|
529
|
+
# Sum the forces on the parent links.
|
530
|
+
mask = parent_link_idxs[:, jnp.newaxis] == jnp.arange(model.number_of_links())
|
531
|
+
W_f_L = mask.T @ W_f_F
|
532
|
+
|
533
|
+
with self.switch_velocity_representation(
|
534
|
+
velocity_representation=VelRepr.Inertial
|
535
|
+
):
|
536
|
+
references = self.apply_link_forces(
|
537
|
+
model=model,
|
538
|
+
data=data,
|
539
|
+
link_names=js.link.idxs_to_names(
|
540
|
+
model=model, link_indices=parent_link_idxs
|
541
|
+
),
|
542
|
+
forces=W_f_L,
|
543
|
+
additive=additive,
|
544
|
+
)
|
545
|
+
|
546
|
+
with references.switch_velocity_representation(
|
547
|
+
velocity_representation=self.velocity_representation
|
548
|
+
):
|
549
|
+
return references
|
jaxsim/integrators/common.py
CHANGED
@@ -261,8 +261,10 @@ class ExplicitRungeKutta(Integrator[PyTreeType, PyTreeType], Generic[PyTreeType]
|
|
261
261
|
# Check if the Butcher tableau supports FSAL (first-same-as-last).
|
262
262
|
# If it does, store the index of the intermediate derivative to be used as the
|
263
263
|
# first derivative of the next iteration.
|
264
|
-
has_fsal, index_of_fsal =
|
265
|
-
|
264
|
+
has_fsal, index_of_fsal = ( # noqa: F841
|
265
|
+
ExplicitRungeKutta.butcher_tableau_supports_fsal(
|
266
|
+
A=cls.A, b=cls.b, c=cls.c, index_of_solution=cls.row_index_of_solution
|
267
|
+
)
|
266
268
|
)
|
267
269
|
|
268
270
|
# Build the integrator object.
|
jaxsim/math/adjoint.py
CHANGED
@@ -3,7 +3,6 @@ import jaxlie
|
|
3
3
|
|
4
4
|
import jaxsim.typing as jtp
|
5
5
|
|
6
|
-
from .quaternion import Quaternion
|
7
6
|
from .skew import Skew
|
8
7
|
|
9
8
|
|
@@ -31,7 +30,7 @@ class Adjoint:
|
|
31
30
|
assert quaternion.size == 4
|
32
31
|
assert translation.size == 3
|
33
32
|
|
34
|
-
Q_sixd = jaxlie.SO3
|
33
|
+
Q_sixd = jaxlie.SO3(wxyz=quaternion)
|
35
34
|
Q_sixd = Q_sixd if not normalize_quaternion else Q_sixd.normalize()
|
36
35
|
|
37
36
|
return Adjoint.from_rotation_and_translation(
|
@@ -84,14 +83,14 @@ class Adjoint:
|
|
84
83
|
A_o_B = translation.squeeze()
|
85
84
|
|
86
85
|
if not inverse:
|
87
|
-
X = A_X_B = jnp.vstack(
|
86
|
+
X = A_X_B = jnp.vstack( # noqa: F841
|
88
87
|
[
|
89
88
|
jnp.block([A_R_B, Skew.wedge(A_o_B) @ A_R_B]),
|
90
89
|
jnp.block([jnp.zeros(shape=(3, 3)), A_R_B]),
|
91
90
|
]
|
92
91
|
)
|
93
92
|
else:
|
94
|
-
X = B_X_A = jnp.vstack(
|
93
|
+
X = B_X_A = jnp.vstack( # noqa: F841
|
95
94
|
[
|
96
95
|
jnp.block([A_R_B.T, -A_R_B.T @ Skew.wedge(A_o_B)]),
|
97
96
|
jnp.block([jnp.zeros(shape=(3, 3)), A_R_B.T]),
|
jaxsim/math/joint_model.py
CHANGED
@@ -254,8 +254,8 @@ def supported_joint_motion(
|
|
254
254
|
# This is a metadata required by only some joint types.
|
255
255
|
axis = jnp.array(joint_axis).astype(float).squeeze()
|
256
256
|
|
257
|
-
pre_H_suc = jaxlie.SE3.
|
258
|
-
|
257
|
+
pre_H_suc = jaxlie.SE3.from_matrix(
|
258
|
+
matrix=jnp.eye(4).at[:3, :3].set(Rotation.from_axis_angle(vector=s * axis))
|
259
259
|
)
|
260
260
|
|
261
261
|
S = jnp.vstack(jnp.hstack([jnp.zeros(3), axis]))
|
jaxsim/math/quaternion.py
CHANGED
@@ -43,9 +43,7 @@ class Quaternion:
|
|
43
43
|
Returns:
|
44
44
|
jtp.Matrix: Direction cosine matrix (DCM).
|
45
45
|
"""
|
46
|
-
return jaxlie.SO3.
|
47
|
-
xyzw=Quaternion.to_xyzw(quaternion)
|
48
|
-
).as_matrix()
|
46
|
+
return jaxlie.SO3(wxyz=quaternion).as_matrix()
|
49
47
|
|
50
48
|
@staticmethod
|
51
49
|
def from_dcm(dcm: jtp.Matrix) -> jtp.Vector:
|
@@ -158,7 +156,7 @@ class Quaternion:
|
|
158
156
|
A_Q_B = jnp.array(quaternion).squeeze().astype(float)
|
159
157
|
|
160
158
|
# Build the initial SO(3) quaternion.
|
161
|
-
W_Q_B_t0 = jaxlie.SO3
|
159
|
+
W_Q_B_t0 = jaxlie.SO3(wxyz=A_Q_B)
|
162
160
|
|
163
161
|
# Integrate the quaternion on the manifold.
|
164
162
|
W_Q_B_tf = jax.lax.select(
|
jaxsim/math/transform.py
CHANGED
@@ -3,8 +3,6 @@ import jaxlie
|
|
3
3
|
|
4
4
|
import jaxsim.typing as jtp
|
5
5
|
|
6
|
-
from .quaternion import Quaternion
|
7
|
-
|
8
6
|
|
9
7
|
class Transform:
|
10
8
|
|
@@ -35,7 +33,7 @@ class Transform:
|
|
35
33
|
assert W_p_B.size == 3
|
36
34
|
assert W_Q_B.size == 4
|
37
35
|
|
38
|
-
A_R_B = jaxlie.SO3
|
36
|
+
A_R_B = jaxlie.SO3(wxyz=W_Q_B)
|
39
37
|
A_R_B = A_R_B if not normalize_quaternion else A_R_B.normalize()
|
40
38
|
|
41
39
|
A_H_B = jaxlie.SE3.from_rotation_and_translation(
|
jaxsim/mujoco/loaders.py
CHANGED
@@ -532,7 +532,12 @@ class UrdfToMjcf:
|
|
532
532
|
model_name: str | None = None,
|
533
533
|
plane_normal: tuple[float, float, float] = (0, 0, 1),
|
534
534
|
heightmap: bool | None = None,
|
535
|
-
cameras:
|
535
|
+
cameras: (
|
536
|
+
MujocoCamera
|
537
|
+
| Sequence[MujocoCamera]
|
538
|
+
| dict[str, str]
|
539
|
+
| Sequence[dict[str, str]]
|
540
|
+
) = (),
|
536
541
|
) -> tuple[str, dict[str, Any]]:
|
537
542
|
"""
|
538
543
|
Converts a URDF file to a Mujoco MJCF string.
|
@@ -574,7 +579,12 @@ class SdfToMjcf:
|
|
574
579
|
model_name: str | None = None,
|
575
580
|
plane_normal: tuple[float, float, float] = (0, 0, 1),
|
576
581
|
heightmap: bool | None = None,
|
577
|
-
cameras:
|
582
|
+
cameras: (
|
583
|
+
MujocoCamera
|
584
|
+
| Sequence[MujocoCamera]
|
585
|
+
| dict[str, str]
|
586
|
+
| Sequence[dict[str, str]]
|
587
|
+
) = (),
|
578
588
|
) -> tuple[str, dict[str, Any]]:
|
579
589
|
"""
|
580
590
|
Converts a SDF file to a Mujoco MJCF string.
|
jaxsim/mujoco/model.py
CHANGED
@@ -238,9 +238,9 @@ class MujocoModelHelper:
|
|
238
238
|
raise ValueError("The orientation is not a valid element of SO(3)")
|
239
239
|
|
240
240
|
W_Q_B = (
|
241
|
-
Rotation.from_matrix(orientation).as_quat(
|
242
|
-
|
243
|
-
|
241
|
+
Rotation.from_matrix(orientation).as_quat(
|
242
|
+
canonical=True, scalar_first=False
|
243
|
+
)
|
244
244
|
if dcm
|
245
245
|
else orientation
|
246
246
|
)
|
@@ -394,8 +394,8 @@ class MujocoModelHelper:
|
|
394
394
|
if dcm:
|
395
395
|
return R
|
396
396
|
|
397
|
-
q_xyzw = Rotation.from_matrix(R).as_quat(canonical=True)
|
398
|
-
return q_xyzw
|
397
|
+
q_xyzw = Rotation.from_matrix(R).as_quat(canonical=True, scalar_first=False)
|
398
|
+
return q_xyzw
|
399
399
|
|
400
400
|
# ===============
|
401
401
|
# Private methods
|
jaxsim/rbda/aba.py
CHANGED
@@ -4,7 +4,7 @@ import jaxlie
|
|
4
4
|
|
5
5
|
import jaxsim.api as js
|
6
6
|
import jaxsim.typing as jtp
|
7
|
-
from jaxsim.math import Adjoint, Cross,
|
7
|
+
from jaxsim.math import Adjoint, Cross, StandardGravity
|
8
8
|
|
9
9
|
from . import utils
|
10
10
|
|
@@ -77,7 +77,7 @@ def aba(
|
|
77
77
|
|
78
78
|
# Compute the base transform.
|
79
79
|
W_H_B = jaxlie.SE3.from_rotation_and_translation(
|
80
|
-
rotation=jaxlie.SO3
|
80
|
+
rotation=jaxlie.SO3(wxyz=W_Q_B),
|
81
81
|
translation=W_p_B,
|
82
82
|
)
|
83
83
|
|
jaxsim/rbda/collidable_points.py
CHANGED
@@ -4,7 +4,7 @@ import jaxlie
|
|
4
4
|
|
5
5
|
import jaxsim.api as js
|
6
6
|
import jaxsim.typing as jtp
|
7
|
-
from jaxsim.math import Adjoint,
|
7
|
+
from jaxsim.math import Adjoint, Skew
|
8
8
|
|
9
9
|
from . import utils
|
10
10
|
|
@@ -57,7 +57,7 @@ def collidable_points_pos_vel(
|
|
57
57
|
|
58
58
|
# Compute the base transform.
|
59
59
|
W_H_B = jaxlie.SE3.from_rotation_and_translation(
|
60
|
-
rotation=jaxlie.SO3
|
60
|
+
rotation=jaxlie.SO3(wxyz=W_Q_B),
|
61
61
|
translation=W_p_B,
|
62
62
|
)
|
63
63
|
|
jaxsim/rbda/contacts/soft.py
CHANGED
@@ -192,7 +192,7 @@ class SoftContacts(ContactModel):
|
|
192
192
|
|
193
193
|
# Unpack the position of the collidable point.
|
194
194
|
px, py, pz = W_p_C = position.squeeze()
|
195
|
-
|
195
|
+
W_ṗ_C = velocity.squeeze()
|
196
196
|
|
197
197
|
# Compute the terrain normal and the contact depth.
|
198
198
|
n̂ = self.terrain.normal(x=px, y=py).squeeze()
|
jaxsim/rbda/crba.py
CHANGED
@@ -59,10 +59,14 @@ def crba(model: js.model.JaxSimModel, *, joint_positions: jtp.Vector) -> jtp.Mat
|
|
59
59
|
|
60
60
|
return (i_X_0,), None
|
61
61
|
|
62
|
-
(i_X_0,), _ =
|
63
|
-
|
64
|
-
|
65
|
-
|
62
|
+
(i_X_0,), _ = (
|
63
|
+
jax.lax.scan(
|
64
|
+
f=propagate_kinematics,
|
65
|
+
init=forward_pass_carry,
|
66
|
+
xs=jnp.arange(start=1, stop=model.number_of_links()),
|
67
|
+
)
|
68
|
+
if model.number_of_links() > 1
|
69
|
+
else [(i_X_0,), None]
|
66
70
|
)
|
67
71
|
|
68
72
|
# ===================
|
@@ -128,10 +132,14 @@ def crba(model: js.model.JaxSimModel, *, joint_positions: jtp.Vector) -> jtp.Mat
|
|
128
132
|
operand=carry,
|
129
133
|
)
|
130
134
|
|
131
|
-
(j, Fi, M), _ =
|
132
|
-
|
133
|
-
|
134
|
-
|
135
|
+
(j, Fi, M), _ = (
|
136
|
+
jax.lax.scan(
|
137
|
+
f=inner_fn,
|
138
|
+
init=carry_inner_fn,
|
139
|
+
xs=jnp.flip(jnp.arange(start=1, stop=model.number_of_links())),
|
140
|
+
)
|
141
|
+
if model.number_of_links() > 1
|
142
|
+
else [(j, Fi, M), None]
|
135
143
|
)
|
136
144
|
|
137
145
|
Fi = i_X_0[j].T @ Fi
|
@@ -143,10 +151,14 @@ def crba(model: js.model.JaxSimModel, *, joint_positions: jtp.Vector) -> jtp.Mat
|
|
143
151
|
|
144
152
|
# This scan performs the backward pass to compute Mbj, Mjb and Mjj, that
|
145
153
|
# also includes a fake while loop implemented with a scan and two cond.
|
146
|
-
(Mc, M), _ =
|
147
|
-
|
148
|
-
|
149
|
-
|
154
|
+
(Mc, M), _ = (
|
155
|
+
jax.lax.scan(
|
156
|
+
f=backward_pass,
|
157
|
+
init=backward_pass_carry,
|
158
|
+
xs=jnp.flip(jnp.arange(start=1, stop=model.number_of_links())),
|
159
|
+
)
|
160
|
+
if model.number_of_links() > 1
|
161
|
+
else [(Mc, M), None]
|
150
162
|
)
|
151
163
|
|
152
164
|
# Store the locked 6D rigid-body inertia matrix Mbb ∈ ℝ⁶ˣ⁶.
|
@@ -4,7 +4,7 @@ import jaxlie
|
|
4
4
|
|
5
5
|
import jaxsim.api as js
|
6
6
|
import jaxsim.typing as jtp
|
7
|
-
from jaxsim.math import Adjoint
|
7
|
+
from jaxsim.math import Adjoint
|
8
8
|
|
9
9
|
from . import utils
|
10
10
|
|
@@ -42,7 +42,7 @@ def forward_kinematics_model(
|
|
42
42
|
|
43
43
|
# Compute the base transform.
|
44
44
|
W_H_B = jaxlie.SE3.from_rotation_and_translation(
|
45
|
-
rotation=jaxlie.SO3
|
45
|
+
rotation=jaxlie.SO3(wxyz=W_Q_B),
|
46
46
|
translation=W_p_B,
|
47
47
|
)
|
48
48
|
|
@@ -75,10 +75,14 @@ def forward_kinematics_model(
|
|
75
75
|
|
76
76
|
return (W_X_i,), None
|
77
77
|
|
78
|
-
(W_X_i,), _ =
|
79
|
-
|
80
|
-
|
81
|
-
|
78
|
+
(W_X_i,), _ = (
|
79
|
+
jax.lax.scan(
|
80
|
+
f=propagate_kinematics,
|
81
|
+
init=propagate_kinematics_carry,
|
82
|
+
xs=jnp.arange(start=1, stop=model.number_of_links()),
|
83
|
+
)
|
84
|
+
if model.number_of_links() > 1
|
85
|
+
else [(W_X_i,), None]
|
82
86
|
)
|
83
87
|
|
84
88
|
return jax.vmap(Adjoint.to_transform)(W_X_i)
|