jaxsim 0.5.1.dev126__py3-none-any.whl → 0.5.1.dev133__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jaxsim/_version.py +2 -2
- jaxsim/api/com.py +1 -1
- jaxsim/api/common.py +1 -1
- jaxsim/api/contact.py +3 -0
- jaxsim/api/data.py +2 -1
- jaxsim/api/kin_dyn_parameters.py +18 -1
- jaxsim/api/model.py +7 -4
- jaxsim/api/ode.py +21 -1
- jaxsim/exceptions.py +8 -0
- jaxsim/integrators/common.py +60 -2
- jaxsim/integrators/fixed_step.py +21 -0
- jaxsim/integrators/variable_step.py +44 -0
- jaxsim/math/adjoint.py +13 -10
- jaxsim/math/cross.py +6 -2
- jaxsim/math/inertia.py +8 -4
- jaxsim/math/quaternion.py +10 -6
- jaxsim/math/rotation.py +6 -3
- jaxsim/math/skew.py +2 -2
- jaxsim/math/transform.py +3 -0
- jaxsim/math/utils.py +2 -2
- jaxsim/mujoco/loaders.py +17 -7
- jaxsim/mujoco/model.py +15 -15
- jaxsim/mujoco/utils.py +6 -1
- jaxsim/mujoco/visualizer.py +11 -7
- jaxsim/parsers/descriptions/collision.py +7 -4
- jaxsim/parsers/descriptions/joint.py +16 -14
- jaxsim/parsers/descriptions/model.py +1 -1
- jaxsim/parsers/kinematic_graph.py +38 -0
- jaxsim/parsers/rod/meshes.py +5 -5
- jaxsim/parsers/rod/parser.py +1 -1
- jaxsim/parsers/rod/utils.py +11 -0
- jaxsim/rbda/contacts/common.py +2 -0
- jaxsim/rbda/contacts/relaxed_rigid.py +7 -4
- jaxsim/rbda/contacts/rigid.py +8 -4
- jaxsim/rbda/contacts/soft.py +37 -0
- jaxsim/rbda/contacts/visco_elastic.py +1 -0
- jaxsim/terrain/terrain.py +52 -0
- jaxsim/utils/jaxsim_dataclass.py +3 -3
- jaxsim/utils/tracing.py +2 -2
- jaxsim/utils/wrappers.py +9 -0
- {jaxsim-0.5.1.dev126.dist-info → jaxsim-0.5.1.dev133.dist-info}/METADATA +1 -1
- jaxsim-0.5.1.dev133.dist-info/RECORD +74 -0
- jaxsim-0.5.1.dev126.dist-info/RECORD +0 -74
- {jaxsim-0.5.1.dev126.dist-info → jaxsim-0.5.1.dev133.dist-info}/LICENSE +0 -0
- {jaxsim-0.5.1.dev126.dist-info → jaxsim-0.5.1.dev133.dist-info}/WHEEL +0 -0
- {jaxsim-0.5.1.dev126.dist-info → jaxsim-0.5.1.dev133.dist-info}/top_level.txt +0 -0
jaxsim/_version.py
CHANGED
@@ -12,5 +12,5 @@ __version__: str
|
|
12
12
|
__version_tuple__: VERSION_TUPLE
|
13
13
|
version_tuple: VERSION_TUPLE
|
14
14
|
|
15
|
-
__version__ = version = '0.5.1.dev126'
|
16
|
-
__version_tuple__ = version_tuple = (0, 5, 1, 'dev126')
|
15
|
+
__version__ = version = '0.5.1.dev133'
|
16
|
+
__version_tuple__ = version_tuple = (0, 5, 1, 'dev133')
|
jaxsim/api/com.py
CHANGED
@@ -279,7 +279,7 @@ def bias_acceleration(
|
|
279
279
|
C_v̇_WL: jtp.Vector, C_v_WC: jtp.Vector, L_H_C: jtp.Matrix, L_v_LC: jtp.Vector
|
280
280
|
) -> jtp.Vector:
|
281
281
|
"""
|
282
|
-
|
282
|
+
Convert the body-fixed representation of the link bias acceleration
|
283
283
|
C_v̇_WL expressed in a generic frame C to the body-fixed representation L_v̇_WL.
|
284
284
|
"""
|
285
285
|
|
jaxsim/api/common.py
CHANGED
@@ -26,7 +26,7 @@ _R = TypeVar("_R")
|
|
26
26
|
|
27
27
|
|
28
28
|
def named_scope(fn, name: str | None = None) -> Callable[_P, _R]:
|
29
|
-
"""
|
29
|
+
"""Apply a JAX named scope to a function for improved profiling and clarity."""
|
30
30
|
|
31
31
|
@functools.wraps(fn)
|
32
32
|
def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _R:
|
jaxsim/api/contact.py
CHANGED
@@ -293,6 +293,9 @@ def in_contact(
|
|
293
293
|
def estimate_good_soft_contacts_parameters(
|
294
294
|
*args, **kwargs
|
295
295
|
) -> jaxsim.rbda.contacts.ContactParamsTypes:
|
296
|
+
"""
|
297
|
+
Estimate good soft contacts parameters. Deprecated, use `estimate_good_contact_parameters` instead.
|
298
|
+
"""
|
296
299
|
|
297
300
|
msg = "This method is deprecated, please use `{}`."
|
298
301
|
logging.warning(msg.format(estimate_good_contact_parameters.__name__))
|
jaxsim/api/data.py
CHANGED
@@ -456,7 +456,8 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
|
|
456
456
|
@jax.jit
|
457
457
|
def generalized_velocity(self) -> jtp.Vector:
|
458
458
|
r"""
|
459
|
-
Get the generalized velocity
|
459
|
+
Get the generalized velocity.
|
460
|
+
|
460
461
|
:math:`\boldsymbol{\nu} = (\boldsymbol{v}_{W,B};\, \boldsymbol{\omega}_{W,B};\, \mathbf{s}) \in \mathbb{R}^{6+n}`
|
461
462
|
|
462
463
|
Returns:
|
jaxsim/api/kin_dyn_parameters.py
CHANGED
@@ -52,10 +52,16 @@ class KinDynParameters(JaxsimDataclass):
|
|
52
52
|
|
53
53
|
@property
|
54
54
|
def parent_array(self) -> jtp.Vector:
|
55
|
+
r"""
|
56
|
+
Return the parent array :math:`\lambda(i)` of the model.
|
57
|
+
"""
|
55
58
|
return self._parent_array.get()
|
56
59
|
|
57
60
|
@property
|
58
61
|
def support_body_array_bool(self) -> jtp.Matrix:
|
62
|
+
r"""
|
63
|
+
Return the boolean support parent array :math:`\kappa_{b}(i)` of the model.
|
64
|
+
"""
|
59
65
|
return self._support_body_array_bool.get()
|
60
66
|
|
61
67
|
@staticmethod
|
@@ -648,7 +654,16 @@ class LinkParameters(JaxsimDataclass):
|
|
648
654
|
def build_from_flat_parameters(
|
649
655
|
index: jtp.IntLike, parameters: jtp.VectorLike
|
650
656
|
) -> LinkParameters:
|
657
|
+
"""
|
658
|
+
Build a LinkParameters object from a flat vector of parameters.
|
659
|
+
|
660
|
+
Args:
|
661
|
+
index: The index of the link.
|
662
|
+
parameters: The flat vector of parameters.
|
651
663
|
|
664
|
+
Returns:
|
665
|
+
The LinkParameters object.
|
666
|
+
"""
|
652
667
|
index = jnp.array(index).squeeze().astype(int)
|
653
668
|
|
654
669
|
m = jnp.array(parameters[0]).squeeze().astype(float)
|
@@ -772,7 +787,9 @@ class ContactParameters(JaxsimDataclass):
|
|
772
787
|
|
773
788
|
@property
|
774
789
|
def indices_of_enabled_collidable_points(self) -> npt.NDArray:
|
775
|
-
|
790
|
+
"""
|
791
|
+
Return the indices of the enabled collidable points.
|
792
|
+
"""
|
776
793
|
return np.where(np.array(self.enabled))[0]
|
777
794
|
|
778
795
|
@staticmethod
|
jaxsim/api/model.py
CHANGED
@@ -63,6 +63,9 @@ class JaxSimModel(JaxsimDataclass):
|
|
63
63
|
|
64
64
|
@property
|
65
65
|
def description(self) -> ModelDescription:
|
66
|
+
"""
|
67
|
+
Return the model description.
|
68
|
+
"""
|
66
69
|
return self._description.get()
|
67
70
|
|
68
71
|
def __eq__(self, other: JaxSimModel) -> bool:
|
@@ -1015,7 +1018,7 @@ def forward_dynamics_aba(
|
|
1015
1018
|
W_v̇_WB: jtp.Vector, W_H_C: jtp.Matrix, W_v_WB: jtp.Vector, W_v_WC: jtp.Vector
|
1016
1019
|
) -> jtp.Vector:
|
1017
1020
|
"""
|
1018
|
-
|
1021
|
+
Convert the inertial-fixed apparent base acceleration W_v̇_WB to
|
1019
1022
|
another representation C_v̇_WB expressed in a generic frame C.
|
1020
1023
|
"""
|
1021
1024
|
|
@@ -1376,7 +1379,7 @@ def inverse_dynamics(
|
|
1376
1379
|
|
1377
1380
|
def to_inertial(C_v̇_WB, W_H_C, C_v_WB, W_v_WC):
|
1378
1381
|
"""
|
1379
|
-
|
1382
|
+
Convert the active representation of the base acceleration C_v̇_WB
|
1380
1383
|
expressed in a generic frame C to the inertial-fixed representation W_v̇_WB.
|
1381
1384
|
"""
|
1382
1385
|
|
@@ -1825,7 +1828,7 @@ def link_bias_accelerations(
|
|
1825
1828
|
C_v̇_WB: jtp.Vector, C_v_WB: jtp.Vector, W_H_C: jtp.Matrix, W_v_WC: jtp.Vector
|
1826
1829
|
) -> jtp.Vector:
|
1827
1830
|
"""
|
1828
|
-
|
1831
|
+
Convert the active representation of the base acceleration C_v̇_WB
|
1829
1832
|
expressed in a generic frame C to the inertial-fixed representation W_v̇_WB.
|
1830
1833
|
"""
|
1831
1834
|
|
@@ -1961,7 +1964,7 @@ def link_bias_accelerations(
|
|
1961
1964
|
L_v̇_WL: jtp.Vector, L_v_WL: jtp.Vector, C_H_L: jtp.Matrix, L_v_CL: jtp.Vector
|
1962
1965
|
) -> jtp.Vector:
|
1963
1966
|
"""
|
1964
|
-
|
1967
|
+
Convert the body-fixed apparent acceleration L_v̇_WL to
|
1965
1968
|
another representation C_v̇_WL expressed in a generic frame C.
|
1966
1969
|
"""
|
1967
1970
|
|
jaxsim/api/ode.py
CHANGED
@@ -15,12 +15,32 @@ from .ode_data import ODEState
|
|
15
15
|
|
16
16
|
|
17
17
|
class SystemDynamicsFromModelAndData(Protocol):
|
18
|
+
"""
|
19
|
+
Protocol defining the signature of a function computing the system dynamics
|
20
|
+
given a model and data object.
|
21
|
+
"""
|
22
|
+
|
18
23
|
def __call__(
|
19
24
|
self,
|
20
25
|
model: js.model.JaxSimModel,
|
21
26
|
data: js.data.JaxSimModelData,
|
22
27
|
**kwargs: dict[str, Any],
|
23
|
-
) -> tuple[ODEState, dict[str, Any]]:
|
28
|
+
) -> tuple[ODEState, dict[str, Any]]:
|
29
|
+
"""
|
30
|
+
Compute the system dynamics given a model and data object.
|
31
|
+
|
32
|
+
Args:
|
33
|
+
model: The model to consider.
|
34
|
+
data: The data of the considered model.
|
35
|
+
**kwargs: Additional keyword arguments.
|
36
|
+
|
37
|
+
Returns:
|
38
|
+
A tuple with an `ODEState` object storing in each of its attributes the
|
39
|
+
corresponding derivative, and the dictionary of auxiliary data returned
|
40
|
+
by the system dynamics evaluation.
|
41
|
+
"""
|
42
|
+
|
43
|
+
pass
|
24
44
|
|
25
45
|
|
26
46
|
def wrap_system_dynamics_for_integration(
|
jaxsim/exceptions.py
CHANGED
@@ -17,6 +17,8 @@ def raise_if(
|
|
17
17
|
msg:
|
18
18
|
The message to display when the exception is raised. The message can be a
|
19
19
|
format string (fmt), whose fields are filled with the args and kwargs.
|
20
|
+
*args: The arguments to fill the format string.
|
21
|
+
**kwargs: The keyword arguments to fill the format string
|
20
22
|
"""
|
21
23
|
|
22
24
|
# Disable host callback if running on unsupported hardware or if the user
|
@@ -61,6 +63,9 @@ def raise_if(
|
|
61
63
|
def raise_runtime_error_if(
|
62
64
|
condition: bool | jax.Array, msg: str, *args, **kwargs
|
63
65
|
) -> None:
|
66
|
+
"""
|
67
|
+
Raise a RuntimeError if a condition is met. Useful in jit-compiled functions.
|
68
|
+
"""
|
64
69
|
|
65
70
|
return raise_if(condition, RuntimeError, msg, *args, **kwargs)
|
66
71
|
|
@@ -68,5 +73,8 @@ def raise_runtime_error_if(
|
|
68
73
|
def raise_value_error_if(
|
69
74
|
condition: bool | jax.Array, msg: str, *args, **kwargs
|
70
75
|
) -> None:
|
76
|
+
"""
|
77
|
+
Raise a ValueError if a condition is met. Useful in jit-compiled functions.
|
78
|
+
"""
|
71
79
|
|
72
80
|
return raise_if(condition, ValueError, msg, *args, **kwargs)
|
jaxsim/integrators/common.py
CHANGED
@@ -36,9 +36,25 @@ PyTreeType = TypeVar("PyTreeType", bound=jtp.PyTree)
|
|
36
36
|
|
37
37
|
|
38
38
|
class SystemDynamics(Protocol[State, StateDerivative]):
|
39
|
+
"""
|
40
|
+
Protocol defining the system dynamics.
|
41
|
+
"""
|
42
|
+
|
39
43
|
def __call__(
|
40
44
|
self, x: State, t: Time, **kwargs
|
41
|
-
) -> tuple[StateDerivative, dict[str, Any]]:
|
45
|
+
) -> tuple[StateDerivative, dict[str, Any]]:
|
46
|
+
"""
|
47
|
+
Compute the state derivative of the system.
|
48
|
+
|
49
|
+
Args:
|
50
|
+
x: The state of the system.
|
51
|
+
t: The time of the system.
|
52
|
+
**kwargs: Additional keyword arguments.
|
53
|
+
|
54
|
+
Returns:
|
55
|
+
The state derivative of the system and the auxiliary dictionary.
|
56
|
+
"""
|
57
|
+
pass
|
42
58
|
|
43
59
|
|
44
60
|
# =======================
|
@@ -48,6 +64,9 @@ class SystemDynamics(Protocol[State, StateDerivative]):
|
|
48
64
|
|
49
65
|
@jax_dataclasses.pytree_dataclass
|
50
66
|
class Integrator(JaxsimDataclass, abc.ABC, Generic[State, StateDerivative]):
|
67
|
+
"""
|
68
|
+
Factory class for integrators.
|
69
|
+
"""
|
51
70
|
|
52
71
|
dynamics: Static[SystemDynamics[State, StateDerivative]] = dataclasses.field(
|
53
72
|
repr=False, hash=False, compare=False, kw_only=True
|
@@ -110,6 +129,9 @@ class Integrator(JaxsimDataclass, abc.ABC, Generic[State, StateDerivative]):
|
|
110
129
|
def __call__(
|
111
130
|
self, x0: State, t0: Time, dt: TimeStep, **kwargs
|
112
131
|
) -> tuple[NextState, dict[str, Any]]:
|
132
|
+
"""
|
133
|
+
Perform a single integration step.
|
134
|
+
"""
|
113
135
|
pass
|
114
136
|
|
115
137
|
def init(
|
@@ -121,6 +143,9 @@ class Integrator(JaxsimDataclass, abc.ABC, Generic[State, StateDerivative]):
|
|
121
143
|
include_dynamics_aux_dict: bool = False,
|
122
144
|
**kwargs,
|
123
145
|
) -> dict[str, Any]:
|
146
|
+
"""
|
147
|
+
Initialize the integrator. This method is deprecated.
|
148
|
+
"""
|
124
149
|
|
125
150
|
logging.warning(
|
126
151
|
"The 'init' method has been deprecated. There is no need to call it."
|
@@ -131,6 +156,18 @@ class Integrator(JaxsimDataclass, abc.ABC, Generic[State, StateDerivative]):
|
|
131
156
|
|
132
157
|
@jax_dataclasses.pytree_dataclass
|
133
158
|
class ExplicitRungeKutta(Integrator[PyTreeType, PyTreeType], Generic[PyTreeType]):
|
159
|
+
"""
|
160
|
+
Base class for explicit Runge-Kutta integrators.
|
161
|
+
|
162
|
+
Attributes:
|
163
|
+
A: The Runge-Kutta matrix.
|
164
|
+
b: The weights coefficients.
|
165
|
+
c: The nodes coefficients.
|
166
|
+
order_of_bT_rows: The order of the solution.
|
167
|
+
row_index_of_solution: The row of the integration output corresponding to the final solution.
|
168
|
+
fsal_enabled_if_supported: Whether to enable the FSAL property, if supported.
|
169
|
+
index_of_fsal: The index of the intermediate derivative to be used as the first derivative of the next iteration.
|
170
|
+
"""
|
134
171
|
|
135
172
|
# The Runge-Kutta matrix.
|
136
173
|
A: ClassVar[jtp.Matrix]
|
@@ -156,10 +193,16 @@ class ExplicitRungeKutta(Integrator[PyTreeType, PyTreeType], Generic[PyTreeType]
|
|
156
193
|
|
157
194
|
@property
|
158
195
|
def has_fsal(self) -> bool:
|
196
|
+
"""
|
197
|
+
Check if the integrator supports the FSAL property.
|
198
|
+
"""
|
159
199
|
return self.fsal_enabled_if_supported and self.index_of_fsal is not None
|
160
200
|
|
161
201
|
@property
|
162
202
|
def order(self) -> int:
|
203
|
+
"""
|
204
|
+
Return the order of the integrator.
|
205
|
+
"""
|
163
206
|
return self.order_of_bT_rows[self.row_index_of_solution]
|
164
207
|
|
165
208
|
@override
|
@@ -221,6 +264,9 @@ class ExplicitRungeKutta(Integrator[PyTreeType, PyTreeType], Generic[PyTreeType]
|
|
221
264
|
def __call__(
|
222
265
|
self, x0: State, t0: Time, dt: TimeStep, **kwargs
|
223
266
|
) -> tuple[NextState, dict[str, Any]]:
|
267
|
+
"""
|
268
|
+
Perform a single integration step.
|
269
|
+
"""
|
224
270
|
|
225
271
|
# Here z is a batched state with as many batch elements as b.T rows.
|
226
272
|
# Note that z has multiple batches only if b.T has more than one row,
|
@@ -331,7 +377,9 @@ class ExplicitRungeKutta(Integrator[PyTreeType, PyTreeType], Generic[PyTreeType]
|
|
331
377
|
def scan_body(
|
332
378
|
carry: jax.Array, i: int | jax.Array
|
333
379
|
) -> tuple[jax.Array, dict[str, Any]]:
|
334
|
-
"""
|
380
|
+
"""
|
381
|
+
Compute the kᵢ derivative of the Runge-Kutta stage.
|
382
|
+
"""
|
335
383
|
|
336
384
|
# Unpack the carry, i.e. the stacked kᵢ vectors.
|
337
385
|
K = carry
|
@@ -498,6 +546,16 @@ class ExplicitRungeKuttaSO3Mixin:
|
|
498
546
|
def post_process_state(
|
499
547
|
cls, x0: js.ode_data.ODEState, t0: Time, xf: js.ode_data.ODEState, dt: TimeStep
|
500
548
|
) -> js.ode_data.ODEState:
|
549
|
+
r"""
|
550
|
+
Post-process the integrated state at :math:`t_f = t_0 + \Delta t` so that the
|
551
|
+
quaternion is normalized.
|
552
|
+
|
553
|
+
Args:
|
554
|
+
x0: The initial state of the system.
|
555
|
+
t0: The initial time of the system.
|
556
|
+
xf: The final state of the system obtain through the integration.
|
557
|
+
dt: The time step used for the integration.
|
558
|
+
"""
|
501
559
|
|
502
560
|
# Extract the initial base quaternion.
|
503
561
|
W_Q_B_t0 = x0.physics_model.base_quaternion
|
jaxsim/integrators/fixed_step.py
CHANGED
@@ -17,6 +17,9 @@ ODEStateDerivative = js.ode_data.ODEState
|
|
17
17
|
|
18
18
|
@jax_dataclasses.pytree_dataclass
|
19
19
|
class ForwardEuler(ExplicitRungeKutta[PyTreeType], Generic[PyTreeType]):
|
20
|
+
"""
|
21
|
+
Forward Euler integrator.
|
22
|
+
"""
|
20
23
|
|
21
24
|
A: ClassVar[jtp.Matrix] = jnp.atleast_2d(0).astype(float)
|
22
25
|
|
@@ -30,6 +33,9 @@ class ForwardEuler(ExplicitRungeKutta[PyTreeType], Generic[PyTreeType]):
|
|
30
33
|
|
31
34
|
@jax_dataclasses.pytree_dataclass
|
32
35
|
class Heun2(ExplicitRungeKutta[PyTreeType], Generic[PyTreeType]):
|
36
|
+
"""
|
37
|
+
Heun's second-order integrator.
|
38
|
+
"""
|
33
39
|
|
34
40
|
A: ClassVar[jtp.Matrix] = jnp.array(
|
35
41
|
[
|
@@ -56,6 +62,9 @@ class Heun2(ExplicitRungeKutta[PyTreeType], Generic[PyTreeType]):
|
|
56
62
|
|
57
63
|
@jax_dataclasses.pytree_dataclass
|
58
64
|
class RungeKutta4(ExplicitRungeKutta[PyTreeType], Generic[PyTreeType]):
|
65
|
+
"""
|
66
|
+
Fourth-order Runge-Kutta integrator.
|
67
|
+
"""
|
59
68
|
|
60
69
|
A: ClassVar[jtp.Matrix] = jnp.array(
|
61
70
|
[
|
@@ -89,14 +98,26 @@ class RungeKutta4(ExplicitRungeKutta[PyTreeType], Generic[PyTreeType]):
|
|
89
98
|
|
90
99
|
@jax_dataclasses.pytree_dataclass
|
91
100
|
class ForwardEulerSO3(ExplicitRungeKuttaSO3Mixin, ForwardEuler[js.ode_data.ODEState]):
|
101
|
+
"""
|
102
|
+
Forward Euler integrator for SO(3) states.
|
103
|
+
"""
|
104
|
+
|
92
105
|
pass
|
93
106
|
|
94
107
|
|
95
108
|
@jax_dataclasses.pytree_dataclass
|
96
109
|
class Heun2SO3(ExplicitRungeKuttaSO3Mixin, Heun2[js.ode_data.ODEState]):
|
110
|
+
"""
|
111
|
+
Heun's second-order integrator for SO(3) states.
|
112
|
+
"""
|
113
|
+
|
97
114
|
pass
|
98
115
|
|
99
116
|
|
100
117
|
@jax_dataclasses.pytree_dataclass
|
101
118
|
class RungeKutta4SO3(ExplicitRungeKuttaSO3Mixin, RungeKutta4[js.ode_data.ODEState]):
|
119
|
+
"""
|
120
|
+
Fourth-order Runge-Kutta integrator for SO(3) states.
|
121
|
+
"""
|
122
|
+
|
102
123
|
pass
|
@@ -216,6 +216,17 @@ def local_error_estimation(
|
|
216
216
|
|
217
217
|
@jax_dataclasses.pytree_dataclass
|
218
218
|
class EmbeddedRungeKutta(ExplicitRungeKutta[PyTreeType], Generic[PyTreeType]):
|
219
|
+
"""
|
220
|
+
An Embedded Runge-Kutta integrator.
|
221
|
+
|
222
|
+
This class implements a general-purpose Embedded Runge-Kutta integrator
|
223
|
+
that can be used to solve ordinary differential equations with adaptive
|
224
|
+
step sizes.
|
225
|
+
|
226
|
+
The integrator is based on an Explicit Runge-Kutta method, and it uses
|
227
|
+
two different solutions to estimate the local integration error. The
|
228
|
+
error is then used to adapt the step size to reach a desired accuracy.
|
229
|
+
"""
|
219
230
|
|
220
231
|
AfterInitKey: ClassVar[str] = "after_init"
|
221
232
|
InitializingKey: ClassVar[str] = "initializing"
|
@@ -257,6 +268,7 @@ class EmbeddedRungeKutta(ExplicitRungeKutta[PyTreeType], Generic[PyTreeType]):
|
|
257
268
|
x0: The initial state of the system.
|
258
269
|
t0: The initial time of the system.
|
259
270
|
dt: The time step of the integration.
|
271
|
+
**kwargs: Additional parameters.
|
260
272
|
|
261
273
|
Returns:
|
262
274
|
The metadata of the integrator to be passed to the first step.
|
@@ -296,6 +308,9 @@ class EmbeddedRungeKutta(ExplicitRungeKutta[PyTreeType], Generic[PyTreeType]):
|
|
296
308
|
def __call__(
|
297
309
|
self, x0: State, t0: Time, dt: TimeStep, **kwargs
|
298
310
|
) -> tuple[NextState, dict[str, Any]]:
|
311
|
+
"""
|
312
|
+
Integrate the system for a single step.
|
313
|
+
"""
|
299
314
|
|
300
315
|
# This method is called differently in three stages:
|
301
316
|
#
|
@@ -512,10 +527,16 @@ class EmbeddedRungeKutta(ExplicitRungeKutta[PyTreeType], Generic[PyTreeType]):
|
|
512
527
|
|
513
528
|
@property
|
514
529
|
def order_of_solution(self) -> int:
|
530
|
+
"""
|
531
|
+
The order of the solution.
|
532
|
+
"""
|
515
533
|
return self.order_of_bT_rows[self.row_index_of_solution]
|
516
534
|
|
517
535
|
@property
|
518
536
|
def order_of_solution_estimate(self) -> int:
|
537
|
+
"""
|
538
|
+
The order of the solution estimate.
|
539
|
+
"""
|
519
540
|
return self.order_of_bT_rows[self.row_index_of_solution_estimate]
|
520
541
|
|
521
542
|
@classmethod
|
@@ -534,6 +555,23 @@ class EmbeddedRungeKutta(ExplicitRungeKutta[PyTreeType], Generic[PyTreeType]):
|
|
534
555
|
max_step_rejections: jtp.IntLike = MAX_STEP_REJECTIONS_DEFAULT,
|
535
556
|
**kwargs,
|
536
557
|
) -> Self:
|
558
|
+
"""
|
559
|
+
Build an Embedded Runge-Kutta integrator.
|
560
|
+
|
561
|
+
Args:
|
562
|
+
dynamics: The system dynamics function.
|
563
|
+
fsal_enabled_if_supported:
|
564
|
+
Whether to enable the FSAL property if supported by the integrator.
|
565
|
+
dt_max: The maximum step size.
|
566
|
+
dt_min: The minimum step size.
|
567
|
+
rtol: The relative tolerance.
|
568
|
+
atol: The absolute tolerance.
|
569
|
+
safety: The safety factor to shrink the step size.
|
570
|
+
beta_max: The maximum factor to increase the step size.
|
571
|
+
beta_min: The minimum factor to increase the step size.
|
572
|
+
max_step_rejections: The maximum number of step rejections.
|
573
|
+
**kwargs: Additional parameters.
|
574
|
+
"""
|
537
575
|
|
538
576
|
# Check that b.T has enough rows based on the configured index of the
|
539
577
|
# solution estimate. This is necessary for embedded methods.
|
@@ -569,6 +607,9 @@ class EmbeddedRungeKutta(ExplicitRungeKutta[PyTreeType], Generic[PyTreeType]):
|
|
569
607
|
|
570
608
|
@jax_dataclasses.pytree_dataclass
|
571
609
|
class HeunEulerSO3(EmbeddedRungeKutta[PyTreeType], ExplicitRungeKuttaSO3Mixin):
|
610
|
+
"""
|
611
|
+
The Heun-Euler integrator for SO(3) dynamics.
|
612
|
+
"""
|
572
613
|
|
573
614
|
A: ClassVar[jtp.Matrix] = jnp.array(
|
574
615
|
[
|
@@ -602,6 +643,9 @@ class HeunEulerSO3(EmbeddedRungeKutta[PyTreeType], ExplicitRungeKuttaSO3Mixin):
|
|
602
643
|
|
603
644
|
@jax_dataclasses.pytree_dataclass
|
604
645
|
class BogackiShampineSO3(EmbeddedRungeKutta[PyTreeType], ExplicitRungeKuttaSO3Mixin):
|
646
|
+
"""
|
647
|
+
The Bogacki-Shampine integrator for SO(3) dynamics.
|
648
|
+
"""
|
605
649
|
|
606
650
|
A: ClassVar[jtp.Matrix] = jnp.array(
|
607
651
|
[
|
jaxsim/math/adjoint.py
CHANGED
@@ -7,6 +7,10 @@ from .skew import Skew
|
|
7
7
|
|
8
8
|
|
9
9
|
class Adjoint:
|
10
|
+
"""
|
11
|
+
A utility class for adjoint matrix operations.
|
12
|
+
"""
|
13
|
+
|
10
14
|
@staticmethod
|
11
15
|
def from_quaternion_and_translation(
|
12
16
|
quaternion: jtp.Vector = jnp.array([1.0, 0, 0, 0]),
|
@@ -18,11 +22,10 @@ class Adjoint:
|
|
18
22
|
Create an adjoint matrix from a quaternion and a translation.
|
19
23
|
|
20
24
|
Args:
|
21
|
-
quaternion
|
22
|
-
translation
|
23
|
-
inverse
|
24
|
-
normalize_quaternion
|
25
|
-
Default is False.
|
25
|
+
quaternion: A quaternion vector (4D) representing orientation.
|
26
|
+
translation: A translation vector (3D).
|
27
|
+
inverse: Whether to compute the inverse adjoint.
|
28
|
+
normalize_quaternion: Whether to normalize the quaternion before creating the adjoint.
|
26
29
|
|
27
30
|
Returns:
|
28
31
|
jtp.Matrix: The adjoint matrix.
|
@@ -69,9 +72,9 @@ class Adjoint:
|
|
69
72
|
Create an adjoint matrix from a rotation matrix and a translation vector.
|
70
73
|
|
71
74
|
Args:
|
72
|
-
rotation
|
73
|
-
translation
|
74
|
-
inverse
|
75
|
+
rotation: A 3x3 rotation matrix.
|
76
|
+
translation: A translation vector (3D).
|
77
|
+
inverse: Whether to compute the inverse adjoint. Default is False.
|
75
78
|
|
76
79
|
Returns:
|
77
80
|
jtp.Matrix: The adjoint matrix.
|
@@ -105,7 +108,7 @@ class Adjoint:
|
|
105
108
|
Convert an adjoint matrix to a transformation matrix.
|
106
109
|
|
107
110
|
Args:
|
108
|
-
adjoint
|
111
|
+
adjoint: The adjoint matrix (6x6).
|
109
112
|
|
110
113
|
Returns:
|
111
114
|
jtp.Matrix: The transformation matrix (4x4).
|
@@ -131,7 +134,7 @@ class Adjoint:
|
|
131
134
|
Compute the inverse of an adjoint matrix.
|
132
135
|
|
133
136
|
Args:
|
134
|
-
adjoint
|
137
|
+
adjoint: The adjoint matrix.
|
135
138
|
|
136
139
|
Returns:
|
137
140
|
jtp.Matrix: The inverse adjoint matrix.
|
jaxsim/math/cross.py
CHANGED
@@ -6,13 +6,17 @@ from .skew import Skew
|
|
6
6
|
|
7
7
|
|
8
8
|
class Cross:
|
9
|
+
"""
|
10
|
+
A utility class for cross product matrix operations.
|
11
|
+
"""
|
12
|
+
|
9
13
|
@staticmethod
|
10
14
|
def vx(velocity_sixd: jtp.Vector) -> jtp.Matrix:
|
11
15
|
"""
|
12
16
|
Compute the cross product matrix for 6D velocities.
|
13
17
|
|
14
18
|
Args:
|
15
|
-
velocity_sixd
|
19
|
+
velocity_sixd: A 6D velocity vector [v, ω].
|
16
20
|
|
17
21
|
Returns:
|
18
22
|
jtp.Matrix: The cross product matrix (6x6).
|
@@ -37,7 +41,7 @@ class Cross:
|
|
37
41
|
Compute the negative transpose of the cross product matrix for 6D velocities.
|
38
42
|
|
39
43
|
Args:
|
40
|
-
velocity_sixd
|
44
|
+
velocity_sixd: A 6D velocity vector [v, ω].
|
41
45
|
|
42
46
|
Returns:
|
43
47
|
jtp.Matrix: The negative transpose of the cross product matrix (6x6).
|
jaxsim/math/inertia.py
CHANGED
@@ -6,15 +6,19 @@ from .skew import Skew
|
|
6
6
|
|
7
7
|
|
8
8
|
class Inertia:
|
9
|
+
"""
|
10
|
+
A utility class for inertia matrix operations.
|
11
|
+
"""
|
12
|
+
|
9
13
|
@staticmethod
|
10
14
|
def to_sixd(mass: jtp.Float, com: jtp.Vector, I: jtp.Matrix) -> jtp.Matrix:
|
11
15
|
"""
|
12
16
|
Convert mass, center of mass, and inertia matrix to a 6x6 inertia matrix.
|
13
17
|
|
14
18
|
Args:
|
15
|
-
mass
|
16
|
-
com
|
17
|
-
I
|
19
|
+
mass: The mass of the body.
|
20
|
+
com: The center of mass position (3D).
|
21
|
+
I: The 3x3 inertia matrix.
|
18
22
|
|
19
23
|
Returns:
|
20
24
|
jtp.Matrix: The 6x6 inertia matrix.
|
@@ -42,7 +46,7 @@ class Inertia:
|
|
42
46
|
Convert a 6x6 inertia matrix to mass, center of mass, and inertia matrix.
|
43
47
|
|
44
48
|
Args:
|
45
|
-
M
|
49
|
+
M: The 6x6 inertia matrix.
|
46
50
|
|
47
51
|
Returns:
|
48
52
|
tuple[jtp.Float, jtp.Vector, jtp.Matrix]: A tuple containing mass, center of mass (3D), and inertia matrix (3x3).
|
jaxsim/math/quaternion.py
CHANGED
@@ -8,13 +8,17 @@ from .utils import safe_norm
|
|
8
8
|
|
9
9
|
|
10
10
|
class Quaternion:
|
11
|
+
"""
|
12
|
+
A utility class for quaternion operations.
|
13
|
+
"""
|
14
|
+
|
11
15
|
@staticmethod
|
12
16
|
def to_xyzw(wxyz: jtp.Vector) -> jtp.Vector:
|
13
17
|
"""
|
14
18
|
Convert a quaternion from WXYZ to XYZW representation.
|
15
19
|
|
16
20
|
Args:
|
17
|
-
wxyz
|
21
|
+
wxyz: Quaternion in WXYZ representation.
|
18
22
|
|
19
23
|
Returns:
|
20
24
|
jtp.Vector: Quaternion in XYZW representation.
|
@@ -27,7 +31,7 @@ class Quaternion:
|
|
27
31
|
Convert a quaternion from XYZW to WXYZ representation.
|
28
32
|
|
29
33
|
Args:
|
30
|
-
xyzw
|
34
|
+
xyzw: Quaternion in XYZW representation.
|
31
35
|
|
32
36
|
Returns:
|
33
37
|
jtp.Vector: Quaternion in WXYZ representation.
|
@@ -40,7 +44,7 @@ class Quaternion:
|
|
40
44
|
Convert a quaternion to a direction cosine matrix (DCM).
|
41
45
|
|
42
46
|
Args:
|
43
|
-
quaternion
|
47
|
+
quaternion: Quaternion in XYZW representation.
|
44
48
|
|
45
49
|
Returns:
|
46
50
|
jtp.Matrix: Direction cosine matrix (DCM).
|
@@ -53,7 +57,7 @@ class Quaternion:
|
|
53
57
|
Convert a direction cosine matrix (DCM) to a quaternion.
|
54
58
|
|
55
59
|
Args:
|
56
|
-
dcm
|
60
|
+
dcm: Direction cosine matrix (DCM).
|
57
61
|
|
58
62
|
Returns:
|
59
63
|
jtp.Vector: Quaternion in XYZW representation.
|
@@ -71,8 +75,8 @@ class Quaternion:
|
|
71
75
|
Compute the derivative of a quaternion given angular velocity.
|
72
76
|
|
73
77
|
Args:
|
74
|
-
quaternion
|
75
|
-
omega
|
78
|
+
quaternion: Quaternion in XYZW representation.
|
79
|
+
omega: Angular velocity vector.
|
76
80
|
omega_in_body_fixed (bool): Whether the angular velocity is in the body-fixed frame.
|
77
81
|
K (float): A scaling factor.
|
78
82
|
|