jaxsim 0.5.1.dev123__py3-none-any.whl → 0.5.1.dev133__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (46) hide show
  1. jaxsim/_version.py +2 -2
  2. jaxsim/api/com.py +1 -1
  3. jaxsim/api/common.py +1 -1
  4. jaxsim/api/contact.py +3 -0
  5. jaxsim/api/data.py +2 -1
  6. jaxsim/api/kin_dyn_parameters.py +26 -9
  7. jaxsim/api/model.py +9 -6
  8. jaxsim/api/ode.py +21 -1
  9. jaxsim/exceptions.py +8 -0
  10. jaxsim/integrators/common.py +60 -2
  11. jaxsim/integrators/fixed_step.py +21 -0
  12. jaxsim/integrators/variable_step.py +44 -0
  13. jaxsim/math/adjoint.py +13 -10
  14. jaxsim/math/cross.py +6 -2
  15. jaxsim/math/inertia.py +8 -4
  16. jaxsim/math/quaternion.py +10 -6
  17. jaxsim/math/rotation.py +6 -3
  18. jaxsim/math/skew.py +2 -2
  19. jaxsim/math/transform.py +3 -0
  20. jaxsim/math/utils.py +2 -2
  21. jaxsim/mujoco/loaders.py +17 -7
  22. jaxsim/mujoco/model.py +15 -15
  23. jaxsim/mujoco/utils.py +6 -1
  24. jaxsim/mujoco/visualizer.py +11 -7
  25. jaxsim/parsers/descriptions/collision.py +7 -4
  26. jaxsim/parsers/descriptions/joint.py +16 -14
  27. jaxsim/parsers/descriptions/model.py +1 -1
  28. jaxsim/parsers/kinematic_graph.py +38 -0
  29. jaxsim/parsers/rod/meshes.py +5 -5
  30. jaxsim/parsers/rod/parser.py +1 -1
  31. jaxsim/parsers/rod/utils.py +11 -0
  32. jaxsim/rbda/contacts/common.py +2 -0
  33. jaxsim/rbda/contacts/relaxed_rigid.py +7 -4
  34. jaxsim/rbda/contacts/rigid.py +8 -4
  35. jaxsim/rbda/contacts/soft.py +37 -0
  36. jaxsim/rbda/contacts/visco_elastic.py +1 -0
  37. jaxsim/terrain/terrain.py +52 -0
  38. jaxsim/utils/jaxsim_dataclass.py +3 -3
  39. jaxsim/utils/tracing.py +2 -2
  40. jaxsim/utils/wrappers.py +9 -0
  41. {jaxsim-0.5.1.dev123.dist-info → jaxsim-0.5.1.dev133.dist-info}/METADATA +1 -1
  42. jaxsim-0.5.1.dev133.dist-info/RECORD +74 -0
  43. jaxsim-0.5.1.dev123.dist-info/RECORD +0 -74
  44. {jaxsim-0.5.1.dev123.dist-info → jaxsim-0.5.1.dev133.dist-info}/LICENSE +0 -0
  45. {jaxsim-0.5.1.dev123.dist-info → jaxsim-0.5.1.dev133.dist-info}/WHEEL +0 -0
  46. {jaxsim-0.5.1.dev123.dist-info → jaxsim-0.5.1.dev133.dist-info}/top_level.txt +0 -0
jaxsim/_version.py CHANGED
@@ -12,5 +12,5 @@ __version__: str
12
12
  __version_tuple__: VERSION_TUPLE
13
13
  version_tuple: VERSION_TUPLE
14
14
 
15
- __version__ = version = '0.5.1.dev123'
16
- __version_tuple__ = version_tuple = (0, 5, 1, 'dev123')
15
+ __version__ = version = '0.5.1.dev133'
16
+ __version_tuple__ = version_tuple = (0, 5, 1, 'dev133')
jaxsim/api/com.py CHANGED
@@ -279,7 +279,7 @@ def bias_acceleration(
279
279
  C_v̇_WL: jtp.Vector, C_v_WC: jtp.Vector, L_H_C: jtp.Matrix, L_v_LC: jtp.Vector
280
280
  ) -> jtp.Vector:
281
281
  """
282
- Helper to convert the body-fixed representation of the link bias acceleration
282
+ Convert the body-fixed representation of the link bias acceleration
283
283
  C_v̇_WL expressed in a generic frame C to the body-fixed representation L_v̇_WL.
284
284
  """
285
285
 
jaxsim/api/common.py CHANGED
@@ -26,7 +26,7 @@ _R = TypeVar("_R")
26
26
 
27
27
 
28
28
  def named_scope(fn, name: str | None = None) -> Callable[_P, _R]:
29
- """Applies a JAX named scope to a function for improved profiling and clarity."""
29
+ """Apply a JAX named scope to a function for improved profiling and clarity."""
30
30
 
31
31
  @functools.wraps(fn)
32
32
  def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _R:
jaxsim/api/contact.py CHANGED
@@ -293,6 +293,9 @@ def in_contact(
293
293
  def estimate_good_soft_contacts_parameters(
294
294
  *args, **kwargs
295
295
  ) -> jaxsim.rbda.contacts.ContactParamsTypes:
296
+ """
297
+ Estimate good soft contacts parameters. Deprecated, use `estimate_good_contact_parameters` instead.
298
+ """
296
299
 
297
300
  msg = "This method is deprecated, please use `{}`."
298
301
  logging.warning(msg.format(estimate_good_contact_parameters.__name__))
jaxsim/api/data.py CHANGED
@@ -456,7 +456,8 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
456
456
  @jax.jit
457
457
  def generalized_velocity(self) -> jtp.Vector:
458
458
  r"""
459
- Get the generalized velocity
459
+ Get the generalized velocity.
460
+
460
461
  :math:`\boldsymbol{\nu} = (\boldsymbol{v}_{W,B};\, \boldsymbol{\omega}_{W,B};\, \mathbf{s}) \in \mathbb{R}^{6+n}`
461
462
 
462
463
  Returns:
@@ -16,7 +16,7 @@ from jaxsim.utils import HashedNumpyArray, JaxsimDataclass
16
16
 
17
17
 
18
18
  @jax_dataclasses.pytree_dataclass(eq=False, unsafe_hash=False)
19
- class KynDynParameters(JaxsimDataclass):
19
+ class KinDynParameters(JaxsimDataclass):
20
20
  r"""
21
21
  Class storing the kinematic and dynamic parameters of a model.
22
22
 
@@ -52,14 +52,20 @@ class KynDynParameters(JaxsimDataclass):
52
52
 
53
53
  @property
54
54
  def parent_array(self) -> jtp.Vector:
55
+ r"""
56
+ Return the parent array :math:`\lambda(i)` of the model.
57
+ """
55
58
  return self._parent_array.get()
56
59
 
57
60
  @property
58
61
  def support_body_array_bool(self) -> jtp.Matrix:
62
+ r"""
63
+ Return the boolean support parent array :math:`\kappa_{b}(i)` of the model.
64
+ """
59
65
  return self._support_body_array_bool.get()
60
66
 
61
67
  @staticmethod
62
- def build(model_description: ModelDescription) -> KynDynParameters:
68
+ def build(model_description: ModelDescription) -> KinDynParameters:
63
69
  """
64
70
  Construct the kinematic and dynamic parameters of the model.
65
71
 
@@ -210,10 +216,10 @@ class KynDynParameters(JaxsimDataclass):
210
216
  )
211
217
 
212
218
  # =================================
213
- # Build and return KynDynParameters
219
+ # Build and return KinDynParameters
214
220
  # =================================
215
221
 
216
- return KynDynParameters(
222
+ return KinDynParameters(
217
223
  link_names=tuple(l.name for l in ordered_links),
218
224
  _parent_array=HashedNumpyArray(array=parent_array),
219
225
  _support_body_array_bool=HashedNumpyArray(array=support_body_array_bool),
@@ -224,9 +230,9 @@ class KynDynParameters(JaxsimDataclass):
224
230
  frame_parameters=frame_parameters,
225
231
  )
226
232
 
227
- def __eq__(self, other: KynDynParameters) -> bool:
233
+ def __eq__(self, other: KinDynParameters) -> bool:
228
234
 
229
- if not isinstance(other, KynDynParameters):
235
+ if not isinstance(other, KinDynParameters):
230
236
  return False
231
237
 
232
238
  return hash(self) == hash(other)
@@ -450,7 +456,7 @@ class KynDynParameters(JaxsimDataclass):
450
456
 
451
457
  def set_link_mass(
452
458
  self, link_index: jtp.IntLike, mass: jtp.FloatLike
453
- ) -> KynDynParameters:
459
+ ) -> KinDynParameters:
454
460
  """
455
461
  Set the mass of a link.
456
462
 
@@ -470,7 +476,7 @@ class KynDynParameters(JaxsimDataclass):
470
476
 
471
477
  def set_link_inertia(
472
478
  self, link_index: jtp.IntLike, inertia: jtp.MatrixLike
473
- ) -> KynDynParameters:
479
+ ) -> KinDynParameters:
474
480
  r"""
475
481
  Set the inertia tensor of a link.
476
482
 
@@ -648,7 +654,16 @@ class LinkParameters(JaxsimDataclass):
648
654
  def build_from_flat_parameters(
649
655
  index: jtp.IntLike, parameters: jtp.VectorLike
650
656
  ) -> LinkParameters:
657
+ """
658
+ Build a LinkParameters object from a flat vector of parameters.
659
+
660
+ Args:
661
+ index: The index of the link.
662
+ parameters: The flat vector of parameters.
651
663
 
664
+ Returns:
665
+ The LinkParameters object.
666
+ """
652
667
  index = jnp.array(index).squeeze().astype(int)
653
668
 
654
669
  m = jnp.array(parameters[0]).squeeze().astype(float)
@@ -772,7 +787,9 @@ class ContactParameters(JaxsimDataclass):
772
787
 
773
788
  @property
774
789
  def indices_of_enabled_collidable_points(self) -> npt.NDArray:
775
-
790
+ """
791
+ Return the indices of the enabled collidable points.
792
+ """
776
793
  return np.where(np.array(self.enabled))[0]
777
794
 
778
795
  @staticmethod
jaxsim/api/model.py CHANGED
@@ -45,7 +45,7 @@ class JaxSimModel(JaxsimDataclass):
45
45
  default=None, repr=False
46
46
  )
47
47
 
48
- kin_dyn_parameters: js.kin_dyn_parameters.KynDynParameters | None = (
48
+ kin_dyn_parameters: js.kin_dyn_parameters.KinDynParameters | None = (
49
49
  dataclasses.field(default=None, repr=False)
50
50
  )
51
51
 
@@ -63,6 +63,9 @@ class JaxSimModel(JaxsimDataclass):
63
63
 
64
64
  @property
65
65
  def description(self) -> ModelDescription:
66
+ """
67
+ Return the model description.
68
+ """
66
69
  return self._description.get()
67
70
 
68
71
  def __eq__(self, other: JaxSimModel) -> bool:
@@ -271,7 +274,7 @@ class JaxSimModel(JaxsimDataclass):
271
274
  # Build the model.
272
275
  model = cls(
273
276
  model_name=model_name,
274
- kin_dyn_parameters=js.kin_dyn_parameters.KynDynParameters.build(
277
+ kin_dyn_parameters=js.kin_dyn_parameters.KinDynParameters.build(
275
278
  model_description=model_description
276
279
  ),
277
280
  time_step=time_step,
@@ -1015,7 +1018,7 @@ def forward_dynamics_aba(
1015
1018
  W_v̇_WB: jtp.Vector, W_H_C: jtp.Matrix, W_v_WB: jtp.Vector, W_v_WC: jtp.Vector
1016
1019
  ) -> jtp.Vector:
1017
1020
  """
1018
- Helper to convert the inertial-fixed apparent base acceleration W_v̇_WB to
1021
+ Convert the inertial-fixed apparent base acceleration W_v̇_WB to
1019
1022
  another representation C_v̇_WB expressed in a generic frame C.
1020
1023
  """
1021
1024
 
@@ -1376,7 +1379,7 @@ def inverse_dynamics(
1376
1379
 
1377
1380
  def to_inertial(C_v̇_WB, W_H_C, C_v_WB, W_v_WC):
1378
1381
  """
1379
- Helper to convert the active representation of the base acceleration C_v̇_WB
1382
+ Convert the active representation of the base acceleration C_v̇_WB
1380
1383
  expressed in a generic frame C to the inertial-fixed representation W_v̇_WB.
1381
1384
  """
1382
1385
 
@@ -1825,7 +1828,7 @@ def link_bias_accelerations(
1825
1828
  C_v̇_WB: jtp.Vector, C_v_WB: jtp.Vector, W_H_C: jtp.Matrix, W_v_WC: jtp.Vector
1826
1829
  ) -> jtp.Vector:
1827
1830
  """
1828
- Helper to convert the active representation of the base acceleration C_v̇_WB
1831
+ Convert the active representation of the base acceleration C_v̇_WB
1829
1832
  expressed in a generic frame C to the inertial-fixed representation W_v̇_WB.
1830
1833
  """
1831
1834
 
@@ -1961,7 +1964,7 @@ def link_bias_accelerations(
1961
1964
  L_v̇_WL: jtp.Vector, L_v_WL: jtp.Vector, C_H_L: jtp.Matrix, L_v_CL: jtp.Vector
1962
1965
  ) -> jtp.Vector:
1963
1966
  """
1964
- Helper to convert the body-fixed apparent acceleration L_v̇_WL to
1967
+ Convert the body-fixed apparent acceleration L_v̇_WL to
1965
1968
  another representation C_v̇_WL expressed in a generic frame C.
1966
1969
  """
1967
1970
 
jaxsim/api/ode.py CHANGED
@@ -15,12 +15,32 @@ from .ode_data import ODEState
15
15
 
16
16
 
17
17
  class SystemDynamicsFromModelAndData(Protocol):
18
+ """
19
+ Protocol defining the signature of a function computing the system dynamics
20
+ given a model and data object.
21
+ """
22
+
18
23
  def __call__(
19
24
  self,
20
25
  model: js.model.JaxSimModel,
21
26
  data: js.data.JaxSimModelData,
22
27
  **kwargs: dict[str, Any],
23
- ) -> tuple[ODEState, dict[str, Any]]: ...
28
+ ) -> tuple[ODEState, dict[str, Any]]:
29
+ """
30
+ Compute the system dynamics given a model and data object.
31
+
32
+ Args:
33
+ model: The model to consider.
34
+ data: The data of the considered model.
35
+ **kwargs: Additional keyword arguments.
36
+
37
+ Returns:
38
+ A tuple with an `ODEState` object storing in each of its attributes the
39
+ corresponding derivative, and the dictionary of auxiliary data returned
40
+ by the system dynamics evaluation.
41
+ """
42
+
43
+ pass
24
44
 
25
45
 
26
46
  def wrap_system_dynamics_for_integration(
jaxsim/exceptions.py CHANGED
@@ -17,6 +17,8 @@ def raise_if(
17
17
  msg:
18
18
  The message to display when the exception is raised. The message can be a
19
19
  format string (fmt), whose fields are filled with the args and kwargs.
20
+ *args: The arguments to fill the format string.
21
+ **kwargs: The keyword arguments to fill the format string
20
22
  """
21
23
 
22
24
  # Disable host callback if running on unsupported hardware or if the user
@@ -61,6 +63,9 @@ def raise_if(
61
63
  def raise_runtime_error_if(
62
64
  condition: bool | jax.Array, msg: str, *args, **kwargs
63
65
  ) -> None:
66
+ """
67
+ Raise a RuntimeError if a condition is met. Useful in jit-compiled functions.
68
+ """
64
69
 
65
70
  return raise_if(condition, RuntimeError, msg, *args, **kwargs)
66
71
 
@@ -68,5 +73,8 @@ def raise_runtime_error_if(
68
73
  def raise_value_error_if(
69
74
  condition: bool | jax.Array, msg: str, *args, **kwargs
70
75
  ) -> None:
76
+ """
77
+ Raise a ValueError if a condition is met. Useful in jit-compiled functions.
78
+ """
71
79
 
72
80
  return raise_if(condition, ValueError, msg, *args, **kwargs)
@@ -36,9 +36,25 @@ PyTreeType = TypeVar("PyTreeType", bound=jtp.PyTree)
36
36
 
37
37
 
38
38
  class SystemDynamics(Protocol[State, StateDerivative]):
39
+ """
40
+ Protocol defining the system dynamics.
41
+ """
42
+
39
43
  def __call__(
40
44
  self, x: State, t: Time, **kwargs
41
- ) -> tuple[StateDerivative, dict[str, Any]]: ...
45
+ ) -> tuple[StateDerivative, dict[str, Any]]:
46
+ """
47
+ Compute the state derivative of the system.
48
+
49
+ Args:
50
+ x: The state of the system.
51
+ t: The time of the system.
52
+ **kwargs: Additional keyword arguments.
53
+
54
+ Returns:
55
+ The state derivative of the system and the auxiliary dictionary.
56
+ """
57
+ pass
42
58
 
43
59
 
44
60
  # =======================
@@ -48,6 +64,9 @@ class SystemDynamics(Protocol[State, StateDerivative]):
48
64
 
49
65
  @jax_dataclasses.pytree_dataclass
50
66
  class Integrator(JaxsimDataclass, abc.ABC, Generic[State, StateDerivative]):
67
+ """
68
+ Factory class for integrators.
69
+ """
51
70
 
52
71
  dynamics: Static[SystemDynamics[State, StateDerivative]] = dataclasses.field(
53
72
  repr=False, hash=False, compare=False, kw_only=True
@@ -110,6 +129,9 @@ class Integrator(JaxsimDataclass, abc.ABC, Generic[State, StateDerivative]):
110
129
  def __call__(
111
130
  self, x0: State, t0: Time, dt: TimeStep, **kwargs
112
131
  ) -> tuple[NextState, dict[str, Any]]:
132
+ """
133
+ Perform a single integration step.
134
+ """
113
135
  pass
114
136
 
115
137
  def init(
@@ -121,6 +143,9 @@ class Integrator(JaxsimDataclass, abc.ABC, Generic[State, StateDerivative]):
121
143
  include_dynamics_aux_dict: bool = False,
122
144
  **kwargs,
123
145
  ) -> dict[str, Any]:
146
+ """
147
+ Initialize the integrator. This method is deprecated.
148
+ """
124
149
 
125
150
  logging.warning(
126
151
  "The 'init' method has been deprecated. There is no need to call it."
@@ -131,6 +156,18 @@ class Integrator(JaxsimDataclass, abc.ABC, Generic[State, StateDerivative]):
131
156
 
132
157
  @jax_dataclasses.pytree_dataclass
133
158
  class ExplicitRungeKutta(Integrator[PyTreeType, PyTreeType], Generic[PyTreeType]):
159
+ """
160
+ Base class for explicit Runge-Kutta integrators.
161
+
162
+ Attributes:
163
+ A: The Runge-Kutta matrix.
164
+ b: The weights coefficients.
165
+ c: The nodes coefficients.
166
+ order_of_bT_rows: The order of the solution.
167
+ row_index_of_solution: The row of the integration output corresponding to the final solution.
168
+ fsal_enabled_if_supported: Whether to enable the FSAL property, if supported.
169
+ index_of_fsal: The index of the intermediate derivative to be used as the first derivative of the next iteration.
170
+ """
134
171
 
135
172
  # The Runge-Kutta matrix.
136
173
  A: ClassVar[jtp.Matrix]
@@ -156,10 +193,16 @@ class ExplicitRungeKutta(Integrator[PyTreeType, PyTreeType], Generic[PyTreeType]
156
193
 
157
194
  @property
158
195
  def has_fsal(self) -> bool:
196
+ """
197
+ Check if the integrator supports the FSAL property.
198
+ """
159
199
  return self.fsal_enabled_if_supported and self.index_of_fsal is not None
160
200
 
161
201
  @property
162
202
  def order(self) -> int:
203
+ """
204
+ Return the order of the integrator.
205
+ """
163
206
  return self.order_of_bT_rows[self.row_index_of_solution]
164
207
 
165
208
  @override
@@ -221,6 +264,9 @@ class ExplicitRungeKutta(Integrator[PyTreeType, PyTreeType], Generic[PyTreeType]
221
264
  def __call__(
222
265
  self, x0: State, t0: Time, dt: TimeStep, **kwargs
223
266
  ) -> tuple[NextState, dict[str, Any]]:
267
+ """
268
+ Perform a single integration step.
269
+ """
224
270
 
225
271
  # Here z is a batched state with as many batch elements as b.T rows.
226
272
  # Note that z has multiple batches only if b.T has more than one row,
@@ -331,7 +377,9 @@ class ExplicitRungeKutta(Integrator[PyTreeType, PyTreeType], Generic[PyTreeType]
331
377
  def scan_body(
332
378
  carry: jax.Array, i: int | jax.Array
333
379
  ) -> tuple[jax.Array, dict[str, Any]]:
334
- """"""
380
+ """
381
+ Compute the kᵢ derivative of the Runge-Kutta stage.
382
+ """
335
383
 
336
384
  # Unpack the carry, i.e. the stacked kᵢ vectors.
337
385
  K = carry
@@ -498,6 +546,16 @@ class ExplicitRungeKuttaSO3Mixin:
498
546
  def post_process_state(
499
547
  cls, x0: js.ode_data.ODEState, t0: Time, xf: js.ode_data.ODEState, dt: TimeStep
500
548
  ) -> js.ode_data.ODEState:
549
+ r"""
550
+ Post-process the integrated state at :math:`t_f = t_0 + \Delta t` so that the
551
+ quaternion is normalized.
552
+
553
+ Args:
554
+ x0: The initial state of the system.
555
+ t0: The initial time of the system.
556
+ xf: The final state of the system obtain through the integration.
557
+ dt: The time step used for the integration.
558
+ """
501
559
 
502
560
  # Extract the initial base quaternion.
503
561
  W_Q_B_t0 = x0.physics_model.base_quaternion
@@ -17,6 +17,9 @@ ODEStateDerivative = js.ode_data.ODEState
17
17
 
18
18
  @jax_dataclasses.pytree_dataclass
19
19
  class ForwardEuler(ExplicitRungeKutta[PyTreeType], Generic[PyTreeType]):
20
+ """
21
+ Forward Euler integrator.
22
+ """
20
23
 
21
24
  A: ClassVar[jtp.Matrix] = jnp.atleast_2d(0).astype(float)
22
25
 
@@ -30,6 +33,9 @@ class ForwardEuler(ExplicitRungeKutta[PyTreeType], Generic[PyTreeType]):
30
33
 
31
34
  @jax_dataclasses.pytree_dataclass
32
35
  class Heun2(ExplicitRungeKutta[PyTreeType], Generic[PyTreeType]):
36
+ """
37
+ Heun's second-order integrator.
38
+ """
33
39
 
34
40
  A: ClassVar[jtp.Matrix] = jnp.array(
35
41
  [
@@ -56,6 +62,9 @@ class Heun2(ExplicitRungeKutta[PyTreeType], Generic[PyTreeType]):
56
62
 
57
63
  @jax_dataclasses.pytree_dataclass
58
64
  class RungeKutta4(ExplicitRungeKutta[PyTreeType], Generic[PyTreeType]):
65
+ """
66
+ Fourth-order Runge-Kutta integrator.
67
+ """
59
68
 
60
69
  A: ClassVar[jtp.Matrix] = jnp.array(
61
70
  [
@@ -89,14 +98,26 @@ class RungeKutta4(ExplicitRungeKutta[PyTreeType], Generic[PyTreeType]):
89
98
 
90
99
  @jax_dataclasses.pytree_dataclass
91
100
  class ForwardEulerSO3(ExplicitRungeKuttaSO3Mixin, ForwardEuler[js.ode_data.ODEState]):
101
+ """
102
+ Forward Euler integrator for SO(3) states.
103
+ """
104
+
92
105
  pass
93
106
 
94
107
 
95
108
  @jax_dataclasses.pytree_dataclass
96
109
  class Heun2SO3(ExplicitRungeKuttaSO3Mixin, Heun2[js.ode_data.ODEState]):
110
+ """
111
+ Heun's second-order integrator for SO(3) states.
112
+ """
113
+
97
114
  pass
98
115
 
99
116
 
100
117
  @jax_dataclasses.pytree_dataclass
101
118
  class RungeKutta4SO3(ExplicitRungeKuttaSO3Mixin, RungeKutta4[js.ode_data.ODEState]):
119
+ """
120
+ Fourth-order Runge-Kutta integrator for SO(3) states.
121
+ """
122
+
102
123
  pass
@@ -216,6 +216,17 @@ def local_error_estimation(
216
216
 
217
217
  @jax_dataclasses.pytree_dataclass
218
218
  class EmbeddedRungeKutta(ExplicitRungeKutta[PyTreeType], Generic[PyTreeType]):
219
+ """
220
+ An Embedded Runge-Kutta integrator.
221
+
222
+ This class implements a general-purpose Embedded Runge-Kutta integrator
223
+ that can be used to solve ordinary differential equations with adaptive
224
+ step sizes.
225
+
226
+ The integrator is based on an Explicit Runge-Kutta method, and it uses
227
+ two different solutions to estimate the local integration error. The
228
+ error is then used to adapt the step size to reach a desired accuracy.
229
+ """
219
230
 
220
231
  AfterInitKey: ClassVar[str] = "after_init"
221
232
  InitializingKey: ClassVar[str] = "initializing"
@@ -257,6 +268,7 @@ class EmbeddedRungeKutta(ExplicitRungeKutta[PyTreeType], Generic[PyTreeType]):
257
268
  x0: The initial state of the system.
258
269
  t0: The initial time of the system.
259
270
  dt: The time step of the integration.
271
+ **kwargs: Additional parameters.
260
272
 
261
273
  Returns:
262
274
  The metadata of the integrator to be passed to the first step.
@@ -296,6 +308,9 @@ class EmbeddedRungeKutta(ExplicitRungeKutta[PyTreeType], Generic[PyTreeType]):
296
308
  def __call__(
297
309
  self, x0: State, t0: Time, dt: TimeStep, **kwargs
298
310
  ) -> tuple[NextState, dict[str, Any]]:
311
+ """
312
+ Integrate the system for a single step.
313
+ """
299
314
 
300
315
  # This method is called differently in three stages:
301
316
  #
@@ -512,10 +527,16 @@ class EmbeddedRungeKutta(ExplicitRungeKutta[PyTreeType], Generic[PyTreeType]):
512
527
 
513
528
  @property
514
529
  def order_of_solution(self) -> int:
530
+ """
531
+ The order of the solution.
532
+ """
515
533
  return self.order_of_bT_rows[self.row_index_of_solution]
516
534
 
517
535
  @property
518
536
  def order_of_solution_estimate(self) -> int:
537
+ """
538
+ The order of the solution estimate.
539
+ """
519
540
  return self.order_of_bT_rows[self.row_index_of_solution_estimate]
520
541
 
521
542
  @classmethod
@@ -534,6 +555,23 @@ class EmbeddedRungeKutta(ExplicitRungeKutta[PyTreeType], Generic[PyTreeType]):
534
555
  max_step_rejections: jtp.IntLike = MAX_STEP_REJECTIONS_DEFAULT,
535
556
  **kwargs,
536
557
  ) -> Self:
558
+ """
559
+ Build an Embedded Runge-Kutta integrator.
560
+
561
+ Args:
562
+ dynamics: The system dynamics function.
563
+ fsal_enabled_if_supported:
564
+ Whether to enable the FSAL property if supported by the integrator.
565
+ dt_max: The maximum step size.
566
+ dt_min: The minimum step size.
567
+ rtol: The relative tolerance.
568
+ atol: The absolute tolerance.
569
+ safety: The safety factor to shrink the step size.
570
+ beta_max: The maximum factor to increase the step size.
571
+ beta_min: The minimum factor to increase the step size.
572
+ max_step_rejections: The maximum number of step rejections.
573
+ **kwargs: Additional parameters.
574
+ """
537
575
 
538
576
  # Check that b.T has enough rows based on the configured index of the
539
577
  # solution estimate. This is necessary for embedded methods.
@@ -569,6 +607,9 @@ class EmbeddedRungeKutta(ExplicitRungeKutta[PyTreeType], Generic[PyTreeType]):
569
607
 
570
608
  @jax_dataclasses.pytree_dataclass
571
609
  class HeunEulerSO3(EmbeddedRungeKutta[PyTreeType], ExplicitRungeKuttaSO3Mixin):
610
+ """
611
+ The Heun-Euler integrator for SO(3) dynamics.
612
+ """
572
613
 
573
614
  A: ClassVar[jtp.Matrix] = jnp.array(
574
615
  [
@@ -602,6 +643,9 @@ class HeunEulerSO3(EmbeddedRungeKutta[PyTreeType], ExplicitRungeKuttaSO3Mixin):
602
643
 
603
644
  @jax_dataclasses.pytree_dataclass
604
645
  class BogackiShampineSO3(EmbeddedRungeKutta[PyTreeType], ExplicitRungeKuttaSO3Mixin):
646
+ """
647
+ The Bogacki-Shampine integrator for SO(3) dynamics.
648
+ """
605
649
 
606
650
  A: ClassVar[jtp.Matrix] = jnp.array(
607
651
  [
jaxsim/math/adjoint.py CHANGED
@@ -7,6 +7,10 @@ from .skew import Skew
7
7
 
8
8
 
9
9
  class Adjoint:
10
+ """
11
+ A utility class for adjoint matrix operations.
12
+ """
13
+
10
14
  @staticmethod
11
15
  def from_quaternion_and_translation(
12
16
  quaternion: jtp.Vector = jnp.array([1.0, 0, 0, 0]),
@@ -18,11 +22,10 @@ class Adjoint:
18
22
  Create an adjoint matrix from a quaternion and a translation.
19
23
 
20
24
  Args:
21
- quaternion (jtp.Vector): A quaternion vector (4D) representing orientation.
22
- translation (jtp.Vector): A translation vector (3D).
23
- inverse (bool): Whether to compute the inverse adjoint. Default is False.
24
- normalize_quaternion (bool): Whether to normalize the quaternion before creating the adjoint.
25
- Default is False.
25
+ quaternion: A quaternion vector (4D) representing orientation.
26
+ translation: A translation vector (3D).
27
+ inverse: Whether to compute the inverse adjoint.
28
+ normalize_quaternion: Whether to normalize the quaternion before creating the adjoint.
26
29
 
27
30
  Returns:
28
31
  jtp.Matrix: The adjoint matrix.
@@ -69,9 +72,9 @@ class Adjoint:
69
72
  Create an adjoint matrix from a rotation matrix and a translation vector.
70
73
 
71
74
  Args:
72
- rotation (jtp.Matrix): A 3x3 rotation matrix.
73
- translation (jtp.Vector): A translation vector (3D).
74
- inverse (bool): Whether to compute the inverse adjoint. Default is False.
75
+ rotation: A 3x3 rotation matrix.
76
+ translation: A translation vector (3D).
77
+ inverse: Whether to compute the inverse adjoint. Default is False.
75
78
 
76
79
  Returns:
77
80
  jtp.Matrix: The adjoint matrix.
@@ -105,7 +108,7 @@ class Adjoint:
105
108
  Convert an adjoint matrix to a transformation matrix.
106
109
 
107
110
  Args:
108
- adjoint (jtp.Matrix): The adjoint matrix (6x6).
111
+ adjoint: The adjoint matrix (6x6).
109
112
 
110
113
  Returns:
111
114
  jtp.Matrix: The transformation matrix (4x4).
@@ -131,7 +134,7 @@ class Adjoint:
131
134
  Compute the inverse of an adjoint matrix.
132
135
 
133
136
  Args:
134
- adjoint (jtp.Matrix): The adjoint matrix.
137
+ adjoint: The adjoint matrix.
135
138
 
136
139
  Returns:
137
140
  jtp.Matrix: The inverse adjoint matrix.
jaxsim/math/cross.py CHANGED
@@ -6,13 +6,17 @@ from .skew import Skew
6
6
 
7
7
 
8
8
  class Cross:
9
+ """
10
+ A utility class for cross product matrix operations.
11
+ """
12
+
9
13
  @staticmethod
10
14
  def vx(velocity_sixd: jtp.Vector) -> jtp.Matrix:
11
15
  """
12
16
  Compute the cross product matrix for 6D velocities.
13
17
 
14
18
  Args:
15
- velocity_sixd (jtp.Vector): A 6D velocity vector [v, ω].
19
+ velocity_sixd: A 6D velocity vector [v, ω].
16
20
 
17
21
  Returns:
18
22
  jtp.Matrix: The cross product matrix (6x6).
@@ -37,7 +41,7 @@ class Cross:
37
41
  Compute the negative transpose of the cross product matrix for 6D velocities.
38
42
 
39
43
  Args:
40
- velocity_sixd (jtp.Vector): A 6D velocity vector [v, ω].
44
+ velocity_sixd: A 6D velocity vector [v, ω].
41
45
 
42
46
  Returns:
43
47
  jtp.Matrix: The negative transpose of the cross product matrix (6x6).
jaxsim/math/inertia.py CHANGED
@@ -6,15 +6,19 @@ from .skew import Skew
6
6
 
7
7
 
8
8
  class Inertia:
9
+ """
10
+ A utility class for inertia matrix operations.
11
+ """
12
+
9
13
  @staticmethod
10
14
  def to_sixd(mass: jtp.Float, com: jtp.Vector, I: jtp.Matrix) -> jtp.Matrix:
11
15
  """
12
16
  Convert mass, center of mass, and inertia matrix to a 6x6 inertia matrix.
13
17
 
14
18
  Args:
15
- mass (jtp.Float): The mass of the body.
16
- com (jtp.Vector): The center of mass position (3D).
17
- I (jtp.Matrix): The 3x3 inertia matrix.
19
+ mass: The mass of the body.
20
+ com: The center of mass position (3D).
21
+ I: The 3x3 inertia matrix.
18
22
 
19
23
  Returns:
20
24
  jtp.Matrix: The 6x6 inertia matrix.
@@ -42,7 +46,7 @@ class Inertia:
42
46
  Convert a 6x6 inertia matrix to mass, center of mass, and inertia matrix.
43
47
 
44
48
  Args:
45
- M (jtp.Matrix): The 6x6 inertia matrix.
49
+ M: The 6x6 inertia matrix.
46
50
 
47
51
  Returns:
48
52
  tuple[jtp.Float, jtp.Vector, jtp.Matrix]: A tuple containing mass, center of mass (3D), and inertia matrix (3x3).