jaxsim 0.5.1.dev126__py3-none-any.whl → 0.5.1.dev139__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (47)
  1. jaxsim/__init__.py +0 -7
  2. jaxsim/_version.py +2 -2
  3. jaxsim/api/com.py +1 -1
  4. jaxsim/api/common.py +1 -1
  5. jaxsim/api/contact.py +3 -0
  6. jaxsim/api/data.py +2 -1
  7. jaxsim/api/kin_dyn_parameters.py +18 -1
  8. jaxsim/api/model.py +7 -4
  9. jaxsim/api/ode.py +21 -1
  10. jaxsim/exceptions.py +8 -0
  11. jaxsim/integrators/common.py +72 -11
  12. jaxsim/integrators/fixed_step.py +91 -40
  13. jaxsim/integrators/variable_step.py +117 -46
  14. jaxsim/math/adjoint.py +19 -10
  15. jaxsim/math/cross.py +6 -2
  16. jaxsim/math/inertia.py +8 -4
  17. jaxsim/math/quaternion.py +10 -6
  18. jaxsim/math/rotation.py +6 -3
  19. jaxsim/math/skew.py +2 -2
  20. jaxsim/math/transform.py +12 -4
  21. jaxsim/math/utils.py +2 -2
  22. jaxsim/mujoco/loaders.py +17 -7
  23. jaxsim/mujoco/model.py +15 -15
  24. jaxsim/mujoco/utils.py +6 -1
  25. jaxsim/mujoco/visualizer.py +11 -7
  26. jaxsim/parsers/descriptions/collision.py +7 -4
  27. jaxsim/parsers/descriptions/joint.py +16 -14
  28. jaxsim/parsers/descriptions/model.py +1 -1
  29. jaxsim/parsers/kinematic_graph.py +38 -0
  30. jaxsim/parsers/rod/meshes.py +5 -5
  31. jaxsim/parsers/rod/parser.py +1 -1
  32. jaxsim/parsers/rod/utils.py +11 -0
  33. jaxsim/rbda/contacts/common.py +2 -0
  34. jaxsim/rbda/contacts/relaxed_rigid.py +7 -4
  35. jaxsim/rbda/contacts/rigid.py +8 -4
  36. jaxsim/rbda/contacts/soft.py +37 -0
  37. jaxsim/rbda/contacts/visco_elastic.py +1 -0
  38. jaxsim/terrain/terrain.py +52 -0
  39. jaxsim/utils/jaxsim_dataclass.py +3 -3
  40. jaxsim/utils/tracing.py +2 -2
  41. jaxsim/utils/wrappers.py +9 -0
  42. {jaxsim-0.5.1.dev126.dist-info → jaxsim-0.5.1.dev139.dist-info}/METADATA +1 -1
  43. jaxsim-0.5.1.dev139.dist-info/RECORD +74 -0
  44. jaxsim-0.5.1.dev126.dist-info/RECORD +0 -74
  45. {jaxsim-0.5.1.dev126.dist-info → jaxsim-0.5.1.dev139.dist-info}/LICENSE +0 -0
  46. {jaxsim-0.5.1.dev126.dist-info → jaxsim-0.5.1.dev139.dist-info}/WHEEL +0 -0
  47. {jaxsim-0.5.1.dev126.dist-info → jaxsim-0.5.1.dev139.dist-info}/top_level.txt +0 -0
jaxsim/integrators/variable_step.py CHANGED
@@ -1,3 +1,4 @@
1
+ import dataclasses
1
2
  import functools
2
3
  from typing import Any, ClassVar, Generic
3
4
 
@@ -216,6 +217,17 @@ def local_error_estimation(
216
217
 
217
218
  @jax_dataclasses.pytree_dataclass
218
219
  class EmbeddedRungeKutta(ExplicitRungeKutta[PyTreeType], Generic[PyTreeType]):
220
+ """
221
+ An Embedded Runge-Kutta integrator.
222
+
223
+ This class implements a general-purpose Embedded Runge-Kutta integrator
224
+ that can be used to solve ordinary differential equations with adaptive
225
+ step sizes.
226
+
227
+ The integrator is based on an Explicit Runge-Kutta method, and it uses
228
+ two different solutions to estimate the local integration error. The
229
+ error is then used to adapt the step size to reach a desired accuracy.
230
+ """
219
231
 
220
232
  AfterInitKey: ClassVar[str] = "after_init"
221
233
  InitializingKey: ClassVar[str] = "initializing"
@@ -243,6 +255,9 @@ class EmbeddedRungeKutta(ExplicitRungeKutta[PyTreeType], Generic[PyTreeType]):
243
255
  # Maximum number of rejected steps when the Δt needs to be reduced.
244
256
  max_step_rejections: Static[jtp.IntLike] = MAX_STEP_REJECTIONS_DEFAULT
245
257
 
258
+ index_of_fsal: jtp.IntLike | None = None
259
+ fsal_enabled_if_supported: bool = False
260
+
246
261
  def init(
247
262
  self,
248
263
  x0: State,
@@ -257,6 +272,7 @@ class EmbeddedRungeKutta(ExplicitRungeKutta[PyTreeType], Generic[PyTreeType]):
257
272
  x0: The initial state of the system.
258
273
  t0: The initial time of the system.
259
274
  dt: The time step of the integration.
275
+ **kwargs: Additional parameters.
260
276
 
261
277
  Returns:
262
278
  The metadata of the integrator to be passed to the first step.
@@ -296,6 +312,9 @@ class EmbeddedRungeKutta(ExplicitRungeKutta[PyTreeType], Generic[PyTreeType]):
296
312
  def __call__(
297
313
  self, x0: State, t0: Time, dt: TimeStep, **kwargs
298
314
  ) -> tuple[NextState, dict[str, Any]]:
315
+ """
316
+ Integrate the system for a single step.
317
+ """
299
318
 
300
319
  # This method is called differently in three stages:
301
320
  #
@@ -512,10 +531,16 @@ class EmbeddedRungeKutta(ExplicitRungeKutta[PyTreeType], Generic[PyTreeType]):
512
531
 
513
532
  @property
514
533
  def order_of_solution(self) -> int:
534
+ """
535
+ The order of the solution.
536
+ """
515
537
  return self.order_of_bT_rows[self.row_index_of_solution]
516
538
 
517
539
  @property
518
540
  def order_of_solution_estimate(self) -> int:
541
+ """
542
+ The order of the solution estimate.
543
+ """
519
544
  return self.order_of_bT_rows[self.row_index_of_solution_estimate]
520
545
 
521
546
  @classmethod
@@ -534,17 +559,36 @@ class EmbeddedRungeKutta(ExplicitRungeKutta[PyTreeType], Generic[PyTreeType]):
534
559
  max_step_rejections: jtp.IntLike = MAX_STEP_REJECTIONS_DEFAULT,
535
560
  **kwargs,
536
561
  ) -> Self:
562
+ """
563
+ Build an Embedded Runge-Kutta integrator.
564
+
565
+ Args:
566
+ dynamics: The system dynamics function.
567
+ fsal_enabled_if_supported:
568
+ Whether to enable the FSAL property if supported by the integrator.
569
+ dt_max: The maximum step size.
570
+ dt_min: The minimum step size.
571
+ rtol: The relative tolerance.
572
+ atol: The absolute tolerance.
573
+ safety: The safety factor to shrink the step size.
574
+ beta_max: The maximum factor to increase the step size.
575
+ beta_min: The minimum factor to increase the step size.
576
+ max_step_rejections: The maximum number of step rejections.
577
+ **kwargs: Additional parameters.
578
+ """
579
+
580
+ b = cls.__dataclass_fields__["b"].default_factory()
537
581
 
538
582
  # Check that b.T has enough rows based on the configured index of the
539
583
  # solution estimate. This is necessary for embedded methods.
540
584
  if (
541
585
  cls.row_index_of_solution_estimate is not None
542
- and cls.row_index_of_solution_estimate >= cls.b.T.shape[0]
586
+ and cls.row_index_of_solution_estimate >= b.T.shape[0]
543
587
  ):
544
588
  msg = "The index of the solution estimate ({}-th row of `b.T`) "
545
589
  msg += "is out of range ({})."
546
590
  raise ValueError(
547
- msg.format(cls.row_index_of_solution_estimate, cls.b.T.shape[0])
591
+ msg.format(cls.row_index_of_solution_estimate, b.T.shape[0])
548
592
  )
549
593
 
550
594
  integrator = super().build(
@@ -569,65 +613,92 @@ class EmbeddedRungeKutta(ExplicitRungeKutta[PyTreeType], Generic[PyTreeType]):
569
613
 
570
614
  @jax_dataclasses.pytree_dataclass
571
615
  class HeunEulerSO3(EmbeddedRungeKutta[PyTreeType], ExplicitRungeKuttaSO3Mixin):
616
+ """
617
+ The Heun-Euler integrator for SO(3) dynamics.
618
+ """
572
619
 
573
- A: ClassVar[jtp.Matrix] = jnp.array(
574
- [
575
- [0, 0],
576
- [1, 0],
577
- ]
578
- ).astype(float)
579
-
580
- b: ClassVar[jtp.Matrix] = (
581
- jnp.atleast_2d(
582
- jnp.array(
583
- [
584
- [1 / 2, 1 / 2],
585
- [1, 0],
586
- ]
587
- ),
588
- )
589
- .astype(float)
590
- .transpose()
620
+ A: jtp.Matrix = dataclasses.field(
621
+ default_factory=lambda: jnp.array(
622
+ [
623
+ [0, 0],
624
+ [1, 0],
625
+ ]
626
+ ).astype(float),
627
+ compare=False,
628
+ )
629
+
630
+ b: jtp.Matrix = dataclasses.field(
631
+ default_factory=lambda: (
632
+ jnp.atleast_2d(
633
+ jnp.array(
634
+ [
635
+ [1 / 2, 1 / 2],
636
+ [1, 0],
637
+ ]
638
+ ),
639
+ )
640
+ .astype(float)
641
+ .transpose()
642
+ ),
643
+ compare=False,
591
644
  )
592
645
 
593
- c: ClassVar[jtp.Vector] = jnp.array(
594
- [0, 1],
595
- ).astype(float)
646
+ c: jtp.Vector = dataclasses.field(
647
+ default_factory=lambda: jnp.array(
648
+ [0, 1],
649
+ ).astype(float),
650
+ compare=False,
651
+ )
596
652
 
597
653
  row_index_of_solution: ClassVar[int] = 0
598
654
  row_index_of_solution_estimate: ClassVar[int | None] = 1
599
655
 
600
656
  order_of_bT_rows: ClassVar[tuple[int, ...]] = (2, 1)
601
657
 
658
+ index_of_fsal: jtp.IntLike | None = None
659
+ fsal_enabled_if_supported: bool = False
660
+
602
661
 
603
662
  @jax_dataclasses.pytree_dataclass
604
663
  class BogackiShampineSO3(EmbeddedRungeKutta[PyTreeType], ExplicitRungeKuttaSO3Mixin):
664
+ """
665
+ The Bogacki-Shampine integrator for SO(3) dynamics.
666
+ """
605
667
 
606
- A: ClassVar[jtp.Matrix] = jnp.array(
607
- [
608
- [0, 0, 0, 0],
609
- [1 / 2, 0, 0, 0],
610
- [0, 3 / 4, 0, 0],
611
- [2 / 9, 1 / 3, 4 / 9, 0],
612
- ]
613
- ).astype(float)
614
-
615
- b: ClassVar[jtp.Matrix] = (
616
- jnp.atleast_2d(
617
- jnp.array(
618
- [
619
- [2 / 9, 1 / 3, 4 / 9, 0],
620
- [7 / 24, 1 / 4, 1 / 3, 1 / 8],
621
- ]
622
- ),
623
- )
624
- .astype(float)
625
- .transpose()
668
+ A: jtp.Matrix = dataclasses.field(
669
+ default_factory=lambda: jnp.array(
670
+ [
671
+ [0, 0, 0, 0],
672
+ [1 / 2, 0, 0, 0],
673
+ [0, 3 / 4, 0, 0],
674
+ [2 / 9, 1 / 3, 4 / 9, 0],
675
+ ]
676
+ ).astype(float),
677
+ compare=False,
626
678
  )
627
679
 
628
- c: ClassVar[jtp.Vector] = jnp.array(
629
- [0, 1 / 2, 3 / 4, 1],
630
- ).astype(float)
680
+ b: jtp.Matrix = dataclasses.field(
681
+ default_factory=lambda: (
682
+ jnp.atleast_2d(
683
+ jnp.array(
684
+ [
685
+ [2 / 9, 1 / 3, 4 / 9, 0],
686
+ [7 / 24, 1 / 4, 1 / 3, 1 / 8],
687
+ ]
688
+ ),
689
+ )
690
+ .astype(float)
691
+ .transpose()
692
+ ),
693
+ compare=False,
694
+ )
695
+
696
+ c: jtp.Vector = dataclasses.field(
697
+ default_factory=lambda: jnp.array(
698
+ [0, 1 / 2, 3 / 4, 1],
699
+ ).astype(float),
700
+ compare=False,
701
+ )
631
702
 
632
703
  row_index_of_solution: ClassVar[int] = 0
633
704
  row_index_of_solution_estimate: ClassVar[int | None] = 1
jaxsim/math/adjoint.py CHANGED
@@ -7,10 +7,14 @@ from .skew import Skew
7
7
 
8
8
 
9
9
  class Adjoint:
10
+ """
11
+ A utility class for adjoint matrix operations.
12
+ """
13
+
10
14
  @staticmethod
11
15
  def from_quaternion_and_translation(
12
- quaternion: jtp.Vector = jnp.array([1.0, 0, 0, 0]),
13
- translation: jtp.Vector = jnp.zeros(3),
16
+ quaternion: jtp.Vector | None = None,
17
+ translation: jtp.Vector | None = None,
14
18
  inverse: bool = False,
15
19
  normalize_quaternion: bool = False,
16
20
  ) -> jtp.Matrix:
@@ -18,8 +22,8 @@ class Adjoint:
18
22
  Create an adjoint matrix from a quaternion and a translation.
19
23
 
20
24
  Args:
21
- quaternion (jtp.Vector): A quaternion vector (4D) representing orientation.
22
- translation (jtp.Vector): A translation vector (3D).
25
+ quaternion (jtp.Vector): A quaternion vector (4D) representing orientation. Default is [1, 0, 0, 0].
26
+ translation (jtp.Vector): A translation vector (3D). Default is [0, 0, 0].
23
27
  inverse (bool): Whether to compute the inverse adjoint. Default is False.
24
28
  normalize_quaternion (bool): Whether to normalize the quaternion before creating the adjoint.
25
29
  Default is False.
@@ -27,6 +31,8 @@ class Adjoint:
27
31
  Returns:
28
32
  jtp.Matrix: The adjoint matrix.
29
33
  """
34
+ quaternion = quaternion if quaternion is not None else jnp.array([1.0, 0, 0, 0])
35
+ translation = translation if translation is not None else jnp.zeros(3)
30
36
  assert quaternion.size == 4
31
37
  assert translation.size == 3
32
38
 
@@ -61,21 +67,24 @@ class Adjoint:
61
67
 
62
68
  @staticmethod
63
69
  def from_rotation_and_translation(
64
- rotation: jtp.Matrix = jnp.eye(3),
65
- translation: jtp.Vector = jnp.zeros(3),
70
+ rotation: jtp.Matrix | None = None,
71
+ translation: jtp.Vector | None = None,
66
72
  inverse: bool = False,
67
73
  ) -> jtp.Matrix:
68
74
  """
69
75
  Create an adjoint matrix from a rotation matrix and a translation vector.
70
76
 
71
77
  Args:
72
- rotation (jtp.Matrix): A 3x3 rotation matrix.
73
- translation (jtp.Vector): A translation vector (3D).
78
+ rotation (jtp.Matrix): A 3x3 rotation matrix. Default is identity.
79
+ translation (jtp.Vector): A translation vector (3D). Default is [0, 0, 0].
74
80
  inverse (bool): Whether to compute the inverse adjoint. Default is False.
75
81
 
76
82
  Returns:
77
83
  jtp.Matrix: The adjoint matrix.
78
84
  """
85
+ rotation = rotation if rotation is not None else jnp.eye(3)
86
+ translation = translation if translation is not None else jnp.zeros(3)
87
+
79
88
  assert rotation.shape == (3, 3)
80
89
  assert translation.size == 3
81
90
 
@@ -105,7 +114,7 @@ class Adjoint:
105
114
  Convert an adjoint matrix to a transformation matrix.
106
115
 
107
116
  Args:
108
- adjoint (jtp.Matrix): The adjoint matrix (6x6).
117
+ adjoint: The adjoint matrix (6x6).
109
118
 
110
119
  Returns:
111
120
  jtp.Matrix: The transformation matrix (4x4).
@@ -131,7 +140,7 @@ class Adjoint:
131
140
  Compute the inverse of an adjoint matrix.
132
141
 
133
142
  Args:
134
- adjoint (jtp.Matrix): The adjoint matrix.
143
+ adjoint: The adjoint matrix.
135
144
 
136
145
  Returns:
137
146
  jtp.Matrix: The inverse adjoint matrix.
jaxsim/math/cross.py CHANGED
@@ -6,13 +6,17 @@ from .skew import Skew
6
6
 
7
7
 
8
8
  class Cross:
9
+ """
10
+ A utility class for cross product matrix operations.
11
+ """
12
+
9
13
  @staticmethod
10
14
  def vx(velocity_sixd: jtp.Vector) -> jtp.Matrix:
11
15
  """
12
16
  Compute the cross product matrix for 6D velocities.
13
17
 
14
18
  Args:
15
- velocity_sixd (jtp.Vector): A 6D velocity vector [v, ω].
19
+ velocity_sixd: A 6D velocity vector [v, ω].
16
20
 
17
21
  Returns:
18
22
  jtp.Matrix: The cross product matrix (6x6).
@@ -37,7 +41,7 @@ class Cross:
37
41
  Compute the negative transpose of the cross product matrix for 6D velocities.
38
42
 
39
43
  Args:
40
- velocity_sixd (jtp.Vector): A 6D velocity vector [v, ω].
44
+ velocity_sixd: A 6D velocity vector [v, ω].
41
45
 
42
46
  Returns:
43
47
  jtp.Matrix: The negative transpose of the cross product matrix (6x6).
jaxsim/math/inertia.py CHANGED
@@ -6,15 +6,19 @@ from .skew import Skew
6
6
 
7
7
 
8
8
  class Inertia:
9
+ """
10
+ A utility class for inertia matrix operations.
11
+ """
12
+
9
13
  @staticmethod
10
14
  def to_sixd(mass: jtp.Float, com: jtp.Vector, I: jtp.Matrix) -> jtp.Matrix:
11
15
  """
12
16
  Convert mass, center of mass, and inertia matrix to a 6x6 inertia matrix.
13
17
 
14
18
  Args:
15
- mass (jtp.Float): The mass of the body.
16
- com (jtp.Vector): The center of mass position (3D).
17
- I (jtp.Matrix): The 3x3 inertia matrix.
19
+ mass: The mass of the body.
20
+ com: The center of mass position (3D).
21
+ I: The 3x3 inertia matrix.
18
22
 
19
23
  Returns:
20
24
  jtp.Matrix: The 6x6 inertia matrix.
@@ -42,7 +46,7 @@ class Inertia:
42
46
  Convert a 6x6 inertia matrix to mass, center of mass, and inertia matrix.
43
47
 
44
48
  Args:
45
- M (jtp.Matrix): The 6x6 inertia matrix.
49
+ M: The 6x6 inertia matrix.
46
50
 
47
51
  Returns:
48
52
  tuple[jtp.Float, jtp.Vector, jtp.Matrix]: A tuple containing mass, center of mass (3D), and inertia matrix (3x3).
jaxsim/math/quaternion.py CHANGED
@@ -8,13 +8,17 @@ from .utils import safe_norm
8
8
 
9
9
 
10
10
  class Quaternion:
11
+ """
12
+ A utility class for quaternion operations.
13
+ """
14
+
11
15
  @staticmethod
12
16
  def to_xyzw(wxyz: jtp.Vector) -> jtp.Vector:
13
17
  """
14
18
  Convert a quaternion from WXYZ to XYZW representation.
15
19
 
16
20
  Args:
17
- wxyz (jtp.Vector): Quaternion in WXYZ representation.
21
+ wxyz: Quaternion in WXYZ representation.
18
22
 
19
23
  Returns:
20
24
  jtp.Vector: Quaternion in XYZW representation.
@@ -27,7 +31,7 @@ class Quaternion:
27
31
  Convert a quaternion from XYZW to WXYZ representation.
28
32
 
29
33
  Args:
30
- xyzw (jtp.Vector): Quaternion in XYZW representation.
34
+ xyzw: Quaternion in XYZW representation.
31
35
 
32
36
  Returns:
33
37
  jtp.Vector: Quaternion in WXYZ representation.
@@ -40,7 +44,7 @@ class Quaternion:
40
44
  Convert a quaternion to a direction cosine matrix (DCM).
41
45
 
42
46
  Args:
43
- quaternion (jtp.Vector): Quaternion in XYZW representation.
47
+ quaternion: Quaternion in XYZW representation.
44
48
 
45
49
  Returns:
46
50
  jtp.Matrix: Direction cosine matrix (DCM).
@@ -53,7 +57,7 @@ class Quaternion:
53
57
  Convert a direction cosine matrix (DCM) to a quaternion.
54
58
 
55
59
  Args:
56
- dcm (jtp.Matrix): Direction cosine matrix (DCM).
60
+ dcm: Direction cosine matrix (DCM).
57
61
 
58
62
  Returns:
59
63
  jtp.Vector: Quaternion in XYZW representation.
@@ -71,8 +75,8 @@ class Quaternion:
71
75
  Compute the derivative of a quaternion given angular velocity.
72
76
 
73
77
  Args:
74
- quaternion (jtp.Vector): Quaternion in XYZW representation.
75
- omega (jtp.Vector): Angular velocity vector.
78
+ quaternion: Quaternion in XYZW representation.
79
+ omega: Angular velocity vector.
76
80
  omega_in_body_fixed (bool): Whether the angular velocity is in the body-fixed frame.
77
81
  K (float): A scaling factor.
78
82
 
jaxsim/math/rotation.py CHANGED
@@ -8,6 +8,9 @@ from .utils import safe_norm
8
8
 
9
9
 
10
10
  class Rotation:
11
+ """
12
+ A utility class for rotation matrix operations.
13
+ """
11
14
 
12
15
  @staticmethod
13
16
  def x(theta: jtp.Float) -> jtp.Matrix:
@@ -15,7 +18,7 @@ class Rotation:
15
18
  Generate a 3D rotation matrix around the X-axis.
16
19
 
17
20
  Args:
18
- theta (jtp.Float): Rotation angle in radians.
21
+ theta: Rotation angle in radians.
19
22
 
20
23
  Returns:
21
24
  jtp.Matrix: 3D rotation matrix.
@@ -29,7 +32,7 @@ class Rotation:
29
32
  Generate a 3D rotation matrix around the Y-axis.
30
33
 
31
34
  Args:
32
- theta (jtp.Float): Rotation angle in radians.
35
+ theta: Rotation angle in radians.
33
36
 
34
37
  Returns:
35
38
  jtp.Matrix: 3D rotation matrix.
@@ -43,7 +46,7 @@ class Rotation:
43
46
  Generate a 3D rotation matrix around the Z-axis.
44
47
 
45
48
  Args:
46
- theta (jtp.Float): Rotation angle in radians.
49
+ theta: Rotation angle in radians.
47
50
 
48
51
  Returns:
49
52
  jtp.Matrix: 3D rotation matrix.
jaxsim/math/skew.py CHANGED
@@ -14,7 +14,7 @@ class Skew:
14
14
  Compute the skew-symmetric matrix (wedge operator) of a 3D vector.
15
15
 
16
16
  Args:
17
- vector (jtp.Vector): A 3D vector.
17
+ vector: A 3D vector.
18
18
 
19
19
  Returns:
20
20
  jtp.Matrix: The skew-symmetric matrix corresponding to the input vector.
@@ -31,7 +31,7 @@ class Skew:
31
31
  Extract the 3D vector from a skew-symmetric matrix (vee operator).
32
32
 
33
33
  Args:
34
- matrix (jtp.Matrix): A 3x3 skew-symmetric matrix.
34
+ matrix: A 3x3 skew-symmetric matrix.
35
35
 
36
36
  Returns:
37
37
  jtp.Vector: The 3D vector extracted from the input matrix.
jaxsim/math/transform.py CHANGED
@@ -5,11 +5,14 @@ import jaxsim.typing as jtp
5
5
 
6
6
 
7
7
  class Transform:
8
+ """
9
+ A utility class for transformation matrix operations.
10
+ """
8
11
 
9
12
  @staticmethod
10
13
  def from_quaternion_and_translation(
11
- quaternion: jtp.VectorLike = jnp.array([1.0, 0, 0, 0]),
12
- translation: jtp.VectorLike = jnp.zeros(3),
14
+ quaternion: jtp.VectorLike | None = None,
15
+ translation: jtp.VectorLike | None = None,
13
16
  inverse: jtp.BoolLike = False,
14
17
  normalize_quaternion: jtp.BoolLike = False,
15
18
  ) -> jtp.Matrix:
@@ -27,6 +30,9 @@ class Transform:
27
30
  The 4x4 transformation matrix representing the SE(3) transformation.
28
31
  """
29
32
 
33
+ quaternion = quaternion if quaternion is not None else jnp.array([1.0, 0, 0, 0])
34
+ translation = translation if translation is not None else jnp.zeros(3)
35
+
30
36
  W_Q_B = jnp.array(quaternion).astype(float)
31
37
  W_p_B = jnp.array(translation).astype(float)
32
38
 
@@ -44,8 +50,8 @@ class Transform:
44
50
 
45
51
  @staticmethod
46
52
  def from_rotation_and_translation(
47
- rotation: jtp.MatrixLike = jnp.eye(3),
48
- translation: jtp.VectorLike = jnp.zeros(3),
53
+ rotation: jtp.MatrixLike | None = None,
54
+ translation: jtp.VectorLike | None = None,
49
55
  inverse: jtp.BoolLike = False,
50
56
  ) -> jtp.Matrix:
51
57
  """
@@ -59,6 +65,8 @@ class Transform:
59
65
  Returns:
60
66
  The 4x4 transformation matrix representing the SE(3) transformation.
61
67
  """
68
+ rotation = rotation if rotation is not None else jnp.eye(3)
69
+ translation = translation if translation is not None else jnp.zeros(3)
62
70
 
63
71
  A_R_B = jnp.array(rotation).astype(float)
64
72
  W_p_B = jnp.array(translation).astype(float)
jaxsim/math/utils.py CHANGED
@@ -5,8 +5,8 @@ import jaxsim.typing as jtp
5
5
 
6
6
  def safe_norm(array: jtp.ArrayLike, axis=None) -> jtp.Array:
7
7
  """
8
- Provides a calculation for an array norm so that it is safe
9
- to compute the gradient and handle NaNs.
8
+ Compute an array norm handling NaNs and making sure that
9
+ it is safe to get the gradient.
10
10
 
11
11
  Args:
12
12
  array: The array for which to compute the norm.
jaxsim/mujoco/loaders.py CHANGED
@@ -22,7 +22,7 @@ def load_rod_model(
22
22
  model_name: str | None = None,
23
23
  ) -> rod.Model:
24
24
  """
25
- Loads a ROD model from a URDF/SDF file or a ROD model.
25
+ Load a ROD model from a URDF/SDF file or a ROD model.
26
26
 
27
27
  Args:
28
28
  model_description: The URDF/SDF file or ROD model to load.
@@ -62,14 +62,16 @@ def load_rod_model(
62
62
 
63
63
 
64
64
  class RodModelToMjcf:
65
- """"""
65
+ """
66
+ Class to convert a ROD model to a Mujoco MJCF string.
67
+ """
66
68
 
67
69
  @staticmethod
68
70
  def assets_from_rod_model(
69
71
  rod_model: rod.Model,
70
72
  ) -> dict[str, bytes]:
71
73
  """
72
- Generates a dictionary of assets from a ROD model.
74
+ Generate a dictionary of assets from a ROD model.
73
75
 
74
76
  Args:
75
77
  rod_model: The ROD model to extract the assets from.
@@ -112,7 +114,7 @@ class RodModelToMjcf:
112
114
  floating_joint_name: str = "world_to_base",
113
115
  ) -> str:
114
116
  """
115
- Adds a floating joint to a URDF string.
117
+ Add a floating joint to a URDF string.
116
118
 
117
119
  Args:
118
120
  urdf_string: The URDF string to modify.
@@ -171,7 +173,7 @@ class RodModelToMjcf:
171
173
  cameras: MujocoCameraType = (),
172
174
  ) -> tuple[str, dict[str, Any]]:
173
175
  """
174
- Converts a ROD model to a Mujoco MJCF string.
176
+ Convert a ROD model to a Mujoco MJCF string.
175
177
 
176
178
  Args:
177
179
  rod_model: The ROD model to convert.
@@ -522,6 +524,10 @@ class RodModelToMjcf:
522
524
 
523
525
 
524
526
  class UrdfToMjcf:
527
+ """
528
+ Class to convert a URDF file to a Mujoco MJCF string.
529
+ """
530
+
525
531
  @staticmethod
526
532
  def convert(
527
533
  urdf: str | pathlib.Path,
@@ -532,7 +538,7 @@ class UrdfToMjcf:
532
538
  cameras: MujocoCameraType = (),
533
539
  ) -> tuple[str, dict[str, Any]]:
534
540
  """
535
- Converts a URDF file to a Mujoco MJCF string.
541
+ Convert a URDF file to a Mujoco MJCF string.
536
542
 
537
543
  Args:
538
544
  urdf: The URDF file to convert.
@@ -564,6 +570,10 @@ class UrdfToMjcf:
564
570
 
565
571
 
566
572
  class SdfToMjcf:
573
+ """
574
+ Class to convert a SDF file to a Mujoco MJCF string.
575
+ """
576
+
567
577
  @staticmethod
568
578
  def convert(
569
579
  sdf: str | pathlib.Path,
@@ -574,7 +584,7 @@ class SdfToMjcf:
574
584
  cameras: MujocoCameraType = (),
575
585
  ) -> tuple[str, dict[str, Any]]:
576
586
  """
577
- Converts a SDF file to a Mujoco MJCF string.
587
+ Convert a SDF file to a Mujoco MJCF string.
578
588
 
579
589
  Args:
580
590
  sdf: The SDF file to convert.