jaxsim 0.5.1.dev126__py3-none-any.whl → 0.5.1.dev139__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jaxsim/__init__.py +0 -7
- jaxsim/_version.py +2 -2
- jaxsim/api/com.py +1 -1
- jaxsim/api/common.py +1 -1
- jaxsim/api/contact.py +3 -0
- jaxsim/api/data.py +2 -1
- jaxsim/api/kin_dyn_parameters.py +18 -1
- jaxsim/api/model.py +7 -4
- jaxsim/api/ode.py +21 -1
- jaxsim/exceptions.py +8 -0
- jaxsim/integrators/common.py +72 -11
- jaxsim/integrators/fixed_step.py +91 -40
- jaxsim/integrators/variable_step.py +117 -46
- jaxsim/math/adjoint.py +19 -10
- jaxsim/math/cross.py +6 -2
- jaxsim/math/inertia.py +8 -4
- jaxsim/math/quaternion.py +10 -6
- jaxsim/math/rotation.py +6 -3
- jaxsim/math/skew.py +2 -2
- jaxsim/math/transform.py +12 -4
- jaxsim/math/utils.py +2 -2
- jaxsim/mujoco/loaders.py +17 -7
- jaxsim/mujoco/model.py +15 -15
- jaxsim/mujoco/utils.py +6 -1
- jaxsim/mujoco/visualizer.py +11 -7
- jaxsim/parsers/descriptions/collision.py +7 -4
- jaxsim/parsers/descriptions/joint.py +16 -14
- jaxsim/parsers/descriptions/model.py +1 -1
- jaxsim/parsers/kinematic_graph.py +38 -0
- jaxsim/parsers/rod/meshes.py +5 -5
- jaxsim/parsers/rod/parser.py +1 -1
- jaxsim/parsers/rod/utils.py +11 -0
- jaxsim/rbda/contacts/common.py +2 -0
- jaxsim/rbda/contacts/relaxed_rigid.py +7 -4
- jaxsim/rbda/contacts/rigid.py +8 -4
- jaxsim/rbda/contacts/soft.py +37 -0
- jaxsim/rbda/contacts/visco_elastic.py +1 -0
- jaxsim/terrain/terrain.py +52 -0
- jaxsim/utils/jaxsim_dataclass.py +3 -3
- jaxsim/utils/tracing.py +2 -2
- jaxsim/utils/wrappers.py +9 -0
- {jaxsim-0.5.1.dev126.dist-info → jaxsim-0.5.1.dev139.dist-info}/METADATA +1 -1
- jaxsim-0.5.1.dev139.dist-info/RECORD +74 -0
- jaxsim-0.5.1.dev126.dist-info/RECORD +0 -74
- {jaxsim-0.5.1.dev126.dist-info → jaxsim-0.5.1.dev139.dist-info}/LICENSE +0 -0
- {jaxsim-0.5.1.dev126.dist-info → jaxsim-0.5.1.dev139.dist-info}/WHEEL +0 -0
- {jaxsim-0.5.1.dev126.dist-info → jaxsim-0.5.1.dev139.dist-info}/top_level.txt +0 -0
@@ -1,3 +1,4 @@
|
|
1
|
+
import dataclasses
|
1
2
|
import functools
|
2
3
|
from typing import Any, ClassVar, Generic
|
3
4
|
|
@@ -216,6 +217,17 @@ def local_error_estimation(
|
|
216
217
|
|
217
218
|
@jax_dataclasses.pytree_dataclass
|
218
219
|
class EmbeddedRungeKutta(ExplicitRungeKutta[PyTreeType], Generic[PyTreeType]):
|
220
|
+
"""
|
221
|
+
An Embedded Runge-Kutta integrator.
|
222
|
+
|
223
|
+
This class implements a general-purpose Embedded Runge-Kutta integrator
|
224
|
+
that can be used to solve ordinary differential equations with adaptive
|
225
|
+
step sizes.
|
226
|
+
|
227
|
+
The integrator is based on an Explicit Runge-Kutta method, and it uses
|
228
|
+
two different solutions to estimate the local integration error. The
|
229
|
+
error is then used to adapt the step size to reach a desired accuracy.
|
230
|
+
"""
|
219
231
|
|
220
232
|
AfterInitKey: ClassVar[str] = "after_init"
|
221
233
|
InitializingKey: ClassVar[str] = "initializing"
|
@@ -243,6 +255,9 @@ class EmbeddedRungeKutta(ExplicitRungeKutta[PyTreeType], Generic[PyTreeType]):
|
|
243
255
|
# Maximum number of rejected steps when the Δt needs to be reduced.
|
244
256
|
max_step_rejections: Static[jtp.IntLike] = MAX_STEP_REJECTIONS_DEFAULT
|
245
257
|
|
258
|
+
index_of_fsal: jtp.IntLike | None = None
|
259
|
+
fsal_enabled_if_supported: bool = False
|
260
|
+
|
246
261
|
def init(
|
247
262
|
self,
|
248
263
|
x0: State,
|
@@ -257,6 +272,7 @@ class EmbeddedRungeKutta(ExplicitRungeKutta[PyTreeType], Generic[PyTreeType]):
|
|
257
272
|
x0: The initial state of the system.
|
258
273
|
t0: The initial time of the system.
|
259
274
|
dt: The time step of the integration.
|
275
|
+
**kwargs: Additional parameters.
|
260
276
|
|
261
277
|
Returns:
|
262
278
|
The metadata of the integrator to be passed to the first step.
|
@@ -296,6 +312,9 @@ class EmbeddedRungeKutta(ExplicitRungeKutta[PyTreeType], Generic[PyTreeType]):
|
|
296
312
|
def __call__(
|
297
313
|
self, x0: State, t0: Time, dt: TimeStep, **kwargs
|
298
314
|
) -> tuple[NextState, dict[str, Any]]:
|
315
|
+
"""
|
316
|
+
Integrate the system for a single step.
|
317
|
+
"""
|
299
318
|
|
300
319
|
# This method is called differently in three stages:
|
301
320
|
#
|
@@ -512,10 +531,16 @@ class EmbeddedRungeKutta(ExplicitRungeKutta[PyTreeType], Generic[PyTreeType]):
|
|
512
531
|
|
513
532
|
@property
|
514
533
|
def order_of_solution(self) -> int:
|
534
|
+
"""
|
535
|
+
The order of the solution.
|
536
|
+
"""
|
515
537
|
return self.order_of_bT_rows[self.row_index_of_solution]
|
516
538
|
|
517
539
|
@property
|
518
540
|
def order_of_solution_estimate(self) -> int:
|
541
|
+
"""
|
542
|
+
The order of the solution estimate.
|
543
|
+
"""
|
519
544
|
return self.order_of_bT_rows[self.row_index_of_solution_estimate]
|
520
545
|
|
521
546
|
@classmethod
|
@@ -534,17 +559,36 @@ class EmbeddedRungeKutta(ExplicitRungeKutta[PyTreeType], Generic[PyTreeType]):
|
|
534
559
|
max_step_rejections: jtp.IntLike = MAX_STEP_REJECTIONS_DEFAULT,
|
535
560
|
**kwargs,
|
536
561
|
) -> Self:
|
562
|
+
"""
|
563
|
+
Build an Embedded Runge-Kutta integrator.
|
564
|
+
|
565
|
+
Args:
|
566
|
+
dynamics: The system dynamics function.
|
567
|
+
fsal_enabled_if_supported:
|
568
|
+
Whether to enable the FSAL property if supported by the integrator.
|
569
|
+
dt_max: The maximum step size.
|
570
|
+
dt_min: The minimum step size.
|
571
|
+
rtol: The relative tolerance.
|
572
|
+
atol: The absolute tolerance.
|
573
|
+
safety: The safety factor to shrink the step size.
|
574
|
+
beta_max: The maximum factor to increase the step size.
|
575
|
+
beta_min: The minimum factor to increase the step size.
|
576
|
+
max_step_rejections: The maximum number of step rejections.
|
577
|
+
**kwargs: Additional parameters.
|
578
|
+
"""
|
579
|
+
|
580
|
+
b = cls.__dataclass_fields__["b"].default_factory()
|
537
581
|
|
538
582
|
# Check that b.T has enough rows based on the configured index of the
|
539
583
|
# solution estimate. This is necessary for embedded methods.
|
540
584
|
if (
|
541
585
|
cls.row_index_of_solution_estimate is not None
|
542
|
-
and cls.row_index_of_solution_estimate >=
|
586
|
+
and cls.row_index_of_solution_estimate >= b.T.shape[0]
|
543
587
|
):
|
544
588
|
msg = "The index of the solution estimate ({}-th row of `b.T`) "
|
545
589
|
msg += "is out of range ({})."
|
546
590
|
raise ValueError(
|
547
|
-
msg.format(cls.row_index_of_solution_estimate,
|
591
|
+
msg.format(cls.row_index_of_solution_estimate, b.T.shape[0])
|
548
592
|
)
|
549
593
|
|
550
594
|
integrator = super().build(
|
@@ -569,65 +613,92 @@ class EmbeddedRungeKutta(ExplicitRungeKutta[PyTreeType], Generic[PyTreeType]):
|
|
569
613
|
|
570
614
|
@jax_dataclasses.pytree_dataclass
|
571
615
|
class HeunEulerSO3(EmbeddedRungeKutta[PyTreeType], ExplicitRungeKuttaSO3Mixin):
|
616
|
+
"""
|
617
|
+
The Heun-Euler integrator for SO(3) dynamics.
|
618
|
+
"""
|
572
619
|
|
573
|
-
A:
|
574
|
-
|
575
|
-
[
|
576
|
-
|
577
|
-
|
578
|
-
|
579
|
-
|
580
|
-
|
581
|
-
|
582
|
-
|
583
|
-
|
584
|
-
|
585
|
-
|
586
|
-
|
587
|
-
|
588
|
-
|
589
|
-
|
590
|
-
|
620
|
+
A: jtp.Matrix = dataclasses.field(
|
621
|
+
default_factory=lambda: jnp.array(
|
622
|
+
[
|
623
|
+
[0, 0],
|
624
|
+
[1, 0],
|
625
|
+
]
|
626
|
+
).astype(float),
|
627
|
+
compare=False,
|
628
|
+
)
|
629
|
+
|
630
|
+
b: jtp.Matrix = dataclasses.field(
|
631
|
+
default_factory=lambda: (
|
632
|
+
jnp.atleast_2d(
|
633
|
+
jnp.array(
|
634
|
+
[
|
635
|
+
[1 / 2, 1 / 2],
|
636
|
+
[1, 0],
|
637
|
+
]
|
638
|
+
),
|
639
|
+
)
|
640
|
+
.astype(float)
|
641
|
+
.transpose()
|
642
|
+
),
|
643
|
+
compare=False,
|
591
644
|
)
|
592
645
|
|
593
|
-
c:
|
594
|
-
|
595
|
-
|
646
|
+
c: jtp.Vector = dataclasses.field(
|
647
|
+
default_factory=lambda: jnp.array(
|
648
|
+
[0, 1],
|
649
|
+
).astype(float),
|
650
|
+
compare=False,
|
651
|
+
)
|
596
652
|
|
597
653
|
row_index_of_solution: ClassVar[int] = 0
|
598
654
|
row_index_of_solution_estimate: ClassVar[int | None] = 1
|
599
655
|
|
600
656
|
order_of_bT_rows: ClassVar[tuple[int, ...]] = (2, 1)
|
601
657
|
|
658
|
+
index_of_fsal: jtp.IntLike | None = None
|
659
|
+
fsal_enabled_if_supported: bool = False
|
660
|
+
|
602
661
|
|
603
662
|
@jax_dataclasses.pytree_dataclass
|
604
663
|
class BogackiShampineSO3(EmbeddedRungeKutta[PyTreeType], ExplicitRungeKuttaSO3Mixin):
|
664
|
+
"""
|
665
|
+
The Bogacki-Shampine integrator for SO(3) dynamics.
|
666
|
+
"""
|
605
667
|
|
606
|
-
A:
|
607
|
-
|
608
|
-
[
|
609
|
-
|
610
|
-
|
611
|
-
|
612
|
-
|
613
|
-
|
614
|
-
|
615
|
-
|
616
|
-
jnp.atleast_2d(
|
617
|
-
jnp.array(
|
618
|
-
[
|
619
|
-
[2 / 9, 1 / 3, 4 / 9, 0],
|
620
|
-
[7 / 24, 1 / 4, 1 / 3, 1 / 8],
|
621
|
-
]
|
622
|
-
),
|
623
|
-
)
|
624
|
-
.astype(float)
|
625
|
-
.transpose()
|
668
|
+
A: jtp.Matrix = dataclasses.field(
|
669
|
+
default_factory=lambda: jnp.array(
|
670
|
+
[
|
671
|
+
[0, 0, 0, 0],
|
672
|
+
[1 / 2, 0, 0, 0],
|
673
|
+
[0, 3 / 4, 0, 0],
|
674
|
+
[2 / 9, 1 / 3, 4 / 9, 0],
|
675
|
+
]
|
676
|
+
).astype(float),
|
677
|
+
compare=False,
|
626
678
|
)
|
627
679
|
|
628
|
-
|
629
|
-
|
630
|
-
|
680
|
+
b: jtp.Matrix = dataclasses.field(
|
681
|
+
default_factory=lambda: (
|
682
|
+
jnp.atleast_2d(
|
683
|
+
jnp.array(
|
684
|
+
[
|
685
|
+
[2 / 9, 1 / 3, 4 / 9, 0],
|
686
|
+
[7 / 24, 1 / 4, 1 / 3, 1 / 8],
|
687
|
+
]
|
688
|
+
),
|
689
|
+
)
|
690
|
+
.astype(float)
|
691
|
+
.transpose()
|
692
|
+
),
|
693
|
+
compare=False,
|
694
|
+
)
|
695
|
+
|
696
|
+
c: jtp.Vector = dataclasses.field(
|
697
|
+
default_factory=lambda: jnp.array(
|
698
|
+
[0, 1 / 2, 3 / 4, 1],
|
699
|
+
).astype(float),
|
700
|
+
compare=False,
|
701
|
+
)
|
631
702
|
|
632
703
|
row_index_of_solution: ClassVar[int] = 0
|
633
704
|
row_index_of_solution_estimate: ClassVar[int | None] = 1
|
jaxsim/math/adjoint.py
CHANGED
@@ -7,10 +7,14 @@ from .skew import Skew
|
|
7
7
|
|
8
8
|
|
9
9
|
class Adjoint:
|
10
|
+
"""
|
11
|
+
A utility class for adjoint matrix operations.
|
12
|
+
"""
|
13
|
+
|
10
14
|
@staticmethod
|
11
15
|
def from_quaternion_and_translation(
|
12
|
-
quaternion: jtp.Vector
|
13
|
-
translation: jtp.Vector =
|
16
|
+
quaternion: jtp.Vector | None = None,
|
17
|
+
translation: jtp.Vector | None = None,
|
14
18
|
inverse: bool = False,
|
15
19
|
normalize_quaternion: bool = False,
|
16
20
|
) -> jtp.Matrix:
|
@@ -18,8 +22,8 @@ class Adjoint:
|
|
18
22
|
Create an adjoint matrix from a quaternion and a translation.
|
19
23
|
|
20
24
|
Args:
|
21
|
-
quaternion (jtp.Vector): A quaternion vector (4D) representing orientation.
|
22
|
-
translation (jtp.Vector): A translation vector (3D).
|
25
|
+
quaternion (jtp.Vector): A quaternion vector (4D) representing orientation. Default is [1, 0, 0, 0].
|
26
|
+
translation (jtp.Vector): A translation vector (3D). Default is [0, 0, 0].
|
23
27
|
inverse (bool): Whether to compute the inverse adjoint. Default is False.
|
24
28
|
normalize_quaternion (bool): Whether to normalize the quaternion before creating the adjoint.
|
25
29
|
Default is False.
|
@@ -27,6 +31,8 @@ class Adjoint:
|
|
27
31
|
Returns:
|
28
32
|
jtp.Matrix: The adjoint matrix.
|
29
33
|
"""
|
34
|
+
quaternion = quaternion if quaternion is not None else jnp.array([1.0, 0, 0, 0])
|
35
|
+
translation = translation if translation is not None else jnp.zeros(3)
|
30
36
|
assert quaternion.size == 4
|
31
37
|
assert translation.size == 3
|
32
38
|
|
@@ -61,21 +67,24 @@ class Adjoint:
|
|
61
67
|
|
62
68
|
@staticmethod
|
63
69
|
def from_rotation_and_translation(
|
64
|
-
rotation: jtp.Matrix =
|
65
|
-
translation: jtp.Vector =
|
70
|
+
rotation: jtp.Matrix | None = None,
|
71
|
+
translation: jtp.Vector | None = None,
|
66
72
|
inverse: bool = False,
|
67
73
|
) -> jtp.Matrix:
|
68
74
|
"""
|
69
75
|
Create an adjoint matrix from a rotation matrix and a translation vector.
|
70
76
|
|
71
77
|
Args:
|
72
|
-
rotation (jtp.Matrix): A 3x3 rotation matrix.
|
73
|
-
translation (jtp.Vector): A translation vector (3D).
|
78
|
+
rotation (jtp.Matrix): A 3x3 rotation matrix. Default is identity.
|
79
|
+
translation (jtp.Vector): A translation vector (3D). Default is [0, 0, 0].
|
74
80
|
inverse (bool): Whether to compute the inverse adjoint. Default is False.
|
75
81
|
|
76
82
|
Returns:
|
77
83
|
jtp.Matrix: The adjoint matrix.
|
78
84
|
"""
|
85
|
+
rotation = rotation if rotation is not None else jnp.eye(3)
|
86
|
+
translation = translation if translation is not None else jnp.zeros(3)
|
87
|
+
|
79
88
|
assert rotation.shape == (3, 3)
|
80
89
|
assert translation.size == 3
|
81
90
|
|
@@ -105,7 +114,7 @@ class Adjoint:
|
|
105
114
|
Convert an adjoint matrix to a transformation matrix.
|
106
115
|
|
107
116
|
Args:
|
108
|
-
adjoint
|
117
|
+
adjoint: The adjoint matrix (6x6).
|
109
118
|
|
110
119
|
Returns:
|
111
120
|
jtp.Matrix: The transformation matrix (4x4).
|
@@ -131,7 +140,7 @@ class Adjoint:
|
|
131
140
|
Compute the inverse of an adjoint matrix.
|
132
141
|
|
133
142
|
Args:
|
134
|
-
adjoint
|
143
|
+
adjoint: The adjoint matrix.
|
135
144
|
|
136
145
|
Returns:
|
137
146
|
jtp.Matrix: The inverse adjoint matrix.
|
jaxsim/math/cross.py
CHANGED
@@ -6,13 +6,17 @@ from .skew import Skew
|
|
6
6
|
|
7
7
|
|
8
8
|
class Cross:
|
9
|
+
"""
|
10
|
+
A utility class for cross product matrix operations.
|
11
|
+
"""
|
12
|
+
|
9
13
|
@staticmethod
|
10
14
|
def vx(velocity_sixd: jtp.Vector) -> jtp.Matrix:
|
11
15
|
"""
|
12
16
|
Compute the cross product matrix for 6D velocities.
|
13
17
|
|
14
18
|
Args:
|
15
|
-
velocity_sixd
|
19
|
+
velocity_sixd: A 6D velocity vector [v, ω].
|
16
20
|
|
17
21
|
Returns:
|
18
22
|
jtp.Matrix: The cross product matrix (6x6).
|
@@ -37,7 +41,7 @@ class Cross:
|
|
37
41
|
Compute the negative transpose of the cross product matrix for 6D velocities.
|
38
42
|
|
39
43
|
Args:
|
40
|
-
velocity_sixd
|
44
|
+
velocity_sixd: A 6D velocity vector [v, ω].
|
41
45
|
|
42
46
|
Returns:
|
43
47
|
jtp.Matrix: The negative transpose of the cross product matrix (6x6).
|
jaxsim/math/inertia.py
CHANGED
@@ -6,15 +6,19 @@ from .skew import Skew
|
|
6
6
|
|
7
7
|
|
8
8
|
class Inertia:
|
9
|
+
"""
|
10
|
+
A utility class for inertia matrix operations.
|
11
|
+
"""
|
12
|
+
|
9
13
|
@staticmethod
|
10
14
|
def to_sixd(mass: jtp.Float, com: jtp.Vector, I: jtp.Matrix) -> jtp.Matrix:
|
11
15
|
"""
|
12
16
|
Convert mass, center of mass, and inertia matrix to a 6x6 inertia matrix.
|
13
17
|
|
14
18
|
Args:
|
15
|
-
mass
|
16
|
-
com
|
17
|
-
I
|
19
|
+
mass: The mass of the body.
|
20
|
+
com: The center of mass position (3D).
|
21
|
+
I: The 3x3 inertia matrix.
|
18
22
|
|
19
23
|
Returns:
|
20
24
|
jtp.Matrix: The 6x6 inertia matrix.
|
@@ -42,7 +46,7 @@ class Inertia:
|
|
42
46
|
Convert a 6x6 inertia matrix to mass, center of mass, and inertia matrix.
|
43
47
|
|
44
48
|
Args:
|
45
|
-
M
|
49
|
+
M: The 6x6 inertia matrix.
|
46
50
|
|
47
51
|
Returns:
|
48
52
|
tuple[jtp.Float, jtp.Vector, jtp.Matrix]: A tuple containing mass, center of mass (3D), and inertia matrix (3x3).
|
jaxsim/math/quaternion.py
CHANGED
@@ -8,13 +8,17 @@ from .utils import safe_norm
|
|
8
8
|
|
9
9
|
|
10
10
|
class Quaternion:
|
11
|
+
"""
|
12
|
+
A utility class for quaternion operations.
|
13
|
+
"""
|
14
|
+
|
11
15
|
@staticmethod
|
12
16
|
def to_xyzw(wxyz: jtp.Vector) -> jtp.Vector:
|
13
17
|
"""
|
14
18
|
Convert a quaternion from WXYZ to XYZW representation.
|
15
19
|
|
16
20
|
Args:
|
17
|
-
wxyz
|
21
|
+
wxyz: Quaternion in WXYZ representation.
|
18
22
|
|
19
23
|
Returns:
|
20
24
|
jtp.Vector: Quaternion in XYZW representation.
|
@@ -27,7 +31,7 @@ class Quaternion:
|
|
27
31
|
Convert a quaternion from XYZW to WXYZ representation.
|
28
32
|
|
29
33
|
Args:
|
30
|
-
xyzw
|
34
|
+
xyzw: Quaternion in XYZW representation.
|
31
35
|
|
32
36
|
Returns:
|
33
37
|
jtp.Vector: Quaternion in WXYZ representation.
|
@@ -40,7 +44,7 @@ class Quaternion:
|
|
40
44
|
Convert a quaternion to a direction cosine matrix (DCM).
|
41
45
|
|
42
46
|
Args:
|
43
|
-
quaternion
|
47
|
+
quaternion: Quaternion in XYZW representation.
|
44
48
|
|
45
49
|
Returns:
|
46
50
|
jtp.Matrix: Direction cosine matrix (DCM).
|
@@ -53,7 +57,7 @@ class Quaternion:
|
|
53
57
|
Convert a direction cosine matrix (DCM) to a quaternion.
|
54
58
|
|
55
59
|
Args:
|
56
|
-
dcm
|
60
|
+
dcm: Direction cosine matrix (DCM).
|
57
61
|
|
58
62
|
Returns:
|
59
63
|
jtp.Vector: Quaternion in XYZW representation.
|
@@ -71,8 +75,8 @@ class Quaternion:
|
|
71
75
|
Compute the derivative of a quaternion given angular velocity.
|
72
76
|
|
73
77
|
Args:
|
74
|
-
quaternion
|
75
|
-
omega
|
78
|
+
quaternion: Quaternion in XYZW representation.
|
79
|
+
omega: Angular velocity vector.
|
76
80
|
omega_in_body_fixed (bool): Whether the angular velocity is in the body-fixed frame.
|
77
81
|
K (float): A scaling factor.
|
78
82
|
|
jaxsim/math/rotation.py
CHANGED
@@ -8,6 +8,9 @@ from .utils import safe_norm
|
|
8
8
|
|
9
9
|
|
10
10
|
class Rotation:
|
11
|
+
"""
|
12
|
+
A utility class for rotation matrix operations.
|
13
|
+
"""
|
11
14
|
|
12
15
|
@staticmethod
|
13
16
|
def x(theta: jtp.Float) -> jtp.Matrix:
|
@@ -15,7 +18,7 @@ class Rotation:
|
|
15
18
|
Generate a 3D rotation matrix around the X-axis.
|
16
19
|
|
17
20
|
Args:
|
18
|
-
theta
|
21
|
+
theta: Rotation angle in radians.
|
19
22
|
|
20
23
|
Returns:
|
21
24
|
jtp.Matrix: 3D rotation matrix.
|
@@ -29,7 +32,7 @@ class Rotation:
|
|
29
32
|
Generate a 3D rotation matrix around the Y-axis.
|
30
33
|
|
31
34
|
Args:
|
32
|
-
theta
|
35
|
+
theta: Rotation angle in radians.
|
33
36
|
|
34
37
|
Returns:
|
35
38
|
jtp.Matrix: 3D rotation matrix.
|
@@ -43,7 +46,7 @@ class Rotation:
|
|
43
46
|
Generate a 3D rotation matrix around the Z-axis.
|
44
47
|
|
45
48
|
Args:
|
46
|
-
theta
|
49
|
+
theta: Rotation angle in radians.
|
47
50
|
|
48
51
|
Returns:
|
49
52
|
jtp.Matrix: 3D rotation matrix.
|
jaxsim/math/skew.py
CHANGED
@@ -14,7 +14,7 @@ class Skew:
|
|
14
14
|
Compute the skew-symmetric matrix (wedge operator) of a 3D vector.
|
15
15
|
|
16
16
|
Args:
|
17
|
-
vector
|
17
|
+
vector: A 3D vector.
|
18
18
|
|
19
19
|
Returns:
|
20
20
|
jtp.Matrix: The skew-symmetric matrix corresponding to the input vector.
|
@@ -31,7 +31,7 @@ class Skew:
|
|
31
31
|
Extract the 3D vector from a skew-symmetric matrix (vee operator).
|
32
32
|
|
33
33
|
Args:
|
34
|
-
matrix
|
34
|
+
matrix: A 3x3 skew-symmetric matrix.
|
35
35
|
|
36
36
|
Returns:
|
37
37
|
jtp.Vector: The 3D vector extracted from the input matrix.
|
jaxsim/math/transform.py
CHANGED
@@ -5,11 +5,14 @@ import jaxsim.typing as jtp
|
|
5
5
|
|
6
6
|
|
7
7
|
class Transform:
|
8
|
+
"""
|
9
|
+
A utility class for transformation matrix operations.
|
10
|
+
"""
|
8
11
|
|
9
12
|
@staticmethod
|
10
13
|
def from_quaternion_and_translation(
|
11
|
-
quaternion: jtp.VectorLike
|
12
|
-
translation: jtp.VectorLike =
|
14
|
+
quaternion: jtp.VectorLike | None = None,
|
15
|
+
translation: jtp.VectorLike | None = None,
|
13
16
|
inverse: jtp.BoolLike = False,
|
14
17
|
normalize_quaternion: jtp.BoolLike = False,
|
15
18
|
) -> jtp.Matrix:
|
@@ -27,6 +30,9 @@ class Transform:
|
|
27
30
|
The 4x4 transformation matrix representing the SE(3) transformation.
|
28
31
|
"""
|
29
32
|
|
33
|
+
quaternion = quaternion if quaternion is not None else jnp.array([1.0, 0, 0, 0])
|
34
|
+
translation = translation if translation is not None else jnp.zeros(3)
|
35
|
+
|
30
36
|
W_Q_B = jnp.array(quaternion).astype(float)
|
31
37
|
W_p_B = jnp.array(translation).astype(float)
|
32
38
|
|
@@ -44,8 +50,8 @@ class Transform:
|
|
44
50
|
|
45
51
|
@staticmethod
|
46
52
|
def from_rotation_and_translation(
|
47
|
-
rotation: jtp.MatrixLike =
|
48
|
-
translation: jtp.VectorLike =
|
53
|
+
rotation: jtp.MatrixLike | None = None,
|
54
|
+
translation: jtp.VectorLike | None = None,
|
49
55
|
inverse: jtp.BoolLike = False,
|
50
56
|
) -> jtp.Matrix:
|
51
57
|
"""
|
@@ -59,6 +65,8 @@ class Transform:
|
|
59
65
|
Returns:
|
60
66
|
The 4x4 transformation matrix representing the SE(3) transformation.
|
61
67
|
"""
|
68
|
+
rotation = rotation if rotation is not None else jnp.eye(3)
|
69
|
+
translation = translation if translation is not None else jnp.zeros(3)
|
62
70
|
|
63
71
|
A_R_B = jnp.array(rotation).astype(float)
|
64
72
|
W_p_B = jnp.array(translation).astype(float)
|
jaxsim/math/utils.py
CHANGED
@@ -5,8 +5,8 @@ import jaxsim.typing as jtp
|
|
5
5
|
|
6
6
|
def safe_norm(array: jtp.ArrayLike, axis=None) -> jtp.Array:
|
7
7
|
"""
|
8
|
-
|
9
|
-
to
|
8
|
+
Compute an array norm handling NaNs and making sure that
|
9
|
+
it is safe to get the gradient.
|
10
10
|
|
11
11
|
Args:
|
12
12
|
array: The array for which to compute the norm.
|
jaxsim/mujoco/loaders.py
CHANGED
@@ -22,7 +22,7 @@ def load_rod_model(
|
|
22
22
|
model_name: str | None = None,
|
23
23
|
) -> rod.Model:
|
24
24
|
"""
|
25
|
-
|
25
|
+
Load a ROD model from a URDF/SDF file or a ROD model.
|
26
26
|
|
27
27
|
Args:
|
28
28
|
model_description: The URDF/SDF file or ROD model to load.
|
@@ -62,14 +62,16 @@ def load_rod_model(
|
|
62
62
|
|
63
63
|
|
64
64
|
class RodModelToMjcf:
|
65
|
-
"""
|
65
|
+
"""
|
66
|
+
Class to convert a ROD model to a Mujoco MJCF string.
|
67
|
+
"""
|
66
68
|
|
67
69
|
@staticmethod
|
68
70
|
def assets_from_rod_model(
|
69
71
|
rod_model: rod.Model,
|
70
72
|
) -> dict[str, bytes]:
|
71
73
|
"""
|
72
|
-
|
74
|
+
Generate a dictionary of assets from a ROD model.
|
73
75
|
|
74
76
|
Args:
|
75
77
|
rod_model: The ROD model to extract the assets from.
|
@@ -112,7 +114,7 @@ class RodModelToMjcf:
|
|
112
114
|
floating_joint_name: str = "world_to_base",
|
113
115
|
) -> str:
|
114
116
|
"""
|
115
|
-
|
117
|
+
Add a floating joint to a URDF string.
|
116
118
|
|
117
119
|
Args:
|
118
120
|
urdf_string: The URDF string to modify.
|
@@ -171,7 +173,7 @@ class RodModelToMjcf:
|
|
171
173
|
cameras: MujocoCameraType = (),
|
172
174
|
) -> tuple[str, dict[str, Any]]:
|
173
175
|
"""
|
174
|
-
|
176
|
+
Convert a ROD model to a Mujoco MJCF string.
|
175
177
|
|
176
178
|
Args:
|
177
179
|
rod_model: The ROD model to convert.
|
@@ -522,6 +524,10 @@ class RodModelToMjcf:
|
|
522
524
|
|
523
525
|
|
524
526
|
class UrdfToMjcf:
|
527
|
+
"""
|
528
|
+
Class to convert a URDF file to a Mujoco MJCF string.
|
529
|
+
"""
|
530
|
+
|
525
531
|
@staticmethod
|
526
532
|
def convert(
|
527
533
|
urdf: str | pathlib.Path,
|
@@ -532,7 +538,7 @@ class UrdfToMjcf:
|
|
532
538
|
cameras: MujocoCameraType = (),
|
533
539
|
) -> tuple[str, dict[str, Any]]:
|
534
540
|
"""
|
535
|
-
|
541
|
+
Convert a URDF file to a Mujoco MJCF string.
|
536
542
|
|
537
543
|
Args:
|
538
544
|
urdf: The URDF file to convert.
|
@@ -564,6 +570,10 @@ class UrdfToMjcf:
|
|
564
570
|
|
565
571
|
|
566
572
|
class SdfToMjcf:
|
573
|
+
"""
|
574
|
+
Class to convert a SDF file to a Mujoco MJCF string.
|
575
|
+
"""
|
576
|
+
|
567
577
|
@staticmethod
|
568
578
|
def convert(
|
569
579
|
sdf: str | pathlib.Path,
|
@@ -574,7 +584,7 @@ class SdfToMjcf:
|
|
574
584
|
cameras: MujocoCameraType = (),
|
575
585
|
) -> tuple[str, dict[str, Any]]:
|
576
586
|
"""
|
577
|
-
|
587
|
+
Convert a SDF file to a Mujoco MJCF string.
|
578
588
|
|
579
589
|
Args:
|
580
590
|
sdf: The SDF file to convert.
|