jaxsim 0.2.dev191__py3-none-any.whl → 0.6.1.dev5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jaxsim/__init__.py +73 -22
- jaxsim/_version.py +2 -2
- jaxsim/api/__init__.py +13 -1
- jaxsim/api/com.py +423 -0
- jaxsim/api/common.py +48 -19
- jaxsim/api/contact.py +604 -52
- jaxsim/api/data.py +308 -163
- jaxsim/api/frame.py +471 -0
- jaxsim/api/joint.py +166 -37
- jaxsim/api/kin_dyn_parameters.py +901 -0
- jaxsim/api/link.py +277 -78
- jaxsim/api/model.py +1572 -362
- jaxsim/api/ode.py +324 -133
- jaxsim/api/ode_data.py +401 -0
- jaxsim/api/references.py +216 -80
- jaxsim/exceptions.py +80 -0
- jaxsim/integrators/__init__.py +2 -2
- jaxsim/integrators/common.py +191 -107
- jaxsim/integrators/fixed_step.py +97 -102
- jaxsim/integrators/variable_step.py +706 -0
- jaxsim/logging.py +1 -2
- jaxsim/math/__init__.py +13 -0
- jaxsim/math/adjoint.py +64 -30
- jaxsim/math/cross.py +18 -9
- jaxsim/math/inertia.py +11 -9
- jaxsim/math/joint_model.py +289 -0
- jaxsim/math/quaternion.py +59 -25
- jaxsim/math/rotation.py +30 -24
- jaxsim/math/skew.py +18 -7
- jaxsim/math/transform.py +102 -0
- jaxsim/math/utils.py +31 -0
- jaxsim/mujoco/__init__.py +2 -1
- jaxsim/mujoco/loaders.py +216 -29
- jaxsim/mujoco/model.py +163 -33
- jaxsim/mujoco/utils.py +228 -0
- jaxsim/mujoco/visualizer.py +107 -22
- jaxsim/parsers/__init__.py +0 -1
- jaxsim/parsers/descriptions/__init__.py +8 -2
- jaxsim/parsers/descriptions/collision.py +83 -26
- jaxsim/parsers/descriptions/joint.py +80 -87
- jaxsim/parsers/descriptions/link.py +58 -31
- jaxsim/parsers/descriptions/model.py +101 -68
- jaxsim/parsers/kinematic_graph.py +606 -229
- jaxsim/parsers/rod/meshes.py +104 -0
- jaxsim/parsers/rod/parser.py +125 -82
- jaxsim/parsers/rod/utils.py +127 -82
- jaxsim/rbda/__init__.py +11 -0
- jaxsim/rbda/aba.py +289 -0
- jaxsim/rbda/collidable_points.py +156 -0
- jaxsim/rbda/contacts/__init__.py +13 -0
- jaxsim/rbda/contacts/common.py +313 -0
- jaxsim/rbda/contacts/relaxed_rigid.py +605 -0
- jaxsim/rbda/contacts/rigid.py +462 -0
- jaxsim/rbda/contacts/soft.py +480 -0
- jaxsim/rbda/contacts/visco_elastic.py +1066 -0
- jaxsim/rbda/crba.py +167 -0
- jaxsim/rbda/forward_kinematics.py +117 -0
- jaxsim/rbda/jacobian.py +330 -0
- jaxsim/rbda/rnea.py +235 -0
- jaxsim/rbda/utils.py +160 -0
- jaxsim/terrain/__init__.py +2 -0
- jaxsim/terrain/terrain.py +238 -0
- jaxsim/typing.py +24 -24
- jaxsim/utils/__init__.py +1 -4
- jaxsim/utils/jaxsim_dataclass.py +289 -34
- jaxsim/utils/tracing.py +5 -11
- jaxsim/utils/wrappers.py +159 -0
- {jaxsim-0.2.dev191.dist-info → jaxsim-0.6.1.dev5.dist-info}/LICENSE +1 -1
- jaxsim-0.6.1.dev5.dist-info/METADATA +465 -0
- jaxsim-0.6.1.dev5.dist-info/RECORD +74 -0
- {jaxsim-0.2.dev191.dist-info → jaxsim-0.6.1.dev5.dist-info}/WHEEL +1 -1
- jaxsim/high_level/__init__.py +0 -2
- jaxsim/high_level/common.py +0 -11
- jaxsim/high_level/joint.py +0 -148
- jaxsim/high_level/link.py +0 -259
- jaxsim/high_level/model.py +0 -1686
- jaxsim/math/conv.py +0 -114
- jaxsim/math/joint.py +0 -102
- jaxsim/math/plucker.py +0 -100
- jaxsim/physics/__init__.py +0 -12
- jaxsim/physics/algos/__init__.py +0 -0
- jaxsim/physics/algos/aba.py +0 -254
- jaxsim/physics/algos/aba_motors.py +0 -284
- jaxsim/physics/algos/crba.py +0 -154
- jaxsim/physics/algos/forward_kinematics.py +0 -79
- jaxsim/physics/algos/jacobian.py +0 -98
- jaxsim/physics/algos/rnea.py +0 -180
- jaxsim/physics/algos/rnea_motors.py +0 -196
- jaxsim/physics/algos/soft_contacts.py +0 -523
- jaxsim/physics/algos/terrain.py +0 -78
- jaxsim/physics/algos/utils.py +0 -69
- jaxsim/physics/model/__init__.py +0 -0
- jaxsim/physics/model/ground_contact.py +0 -53
- jaxsim/physics/model/physics_model.py +0 -388
- jaxsim/physics/model/physics_model_state.py +0 -283
- jaxsim/simulation/__init__.py +0 -4
- jaxsim/simulation/integrators.py +0 -393
- jaxsim/simulation/ode.py +0 -290
- jaxsim/simulation/ode_data.py +0 -96
- jaxsim/simulation/ode_integration.py +0 -62
- jaxsim/simulation/simulator.py +0 -543
- jaxsim/simulation/simulator_callbacks.py +0 -79
- jaxsim/simulation/utils.py +0 -15
- jaxsim/sixd/__init__.py +0 -2
- jaxsim/utils/oop.py +0 -536
- jaxsim/utils/vmappable.py +0 -117
- jaxsim-0.2.dev191.dist-info/METADATA +0 -184
- jaxsim-0.2.dev191.dist-info/RECORD +0 -81
- {jaxsim-0.2.dev191.dist-info → jaxsim-0.6.1.dev5.dist-info}/top_level.txt +0 -0
jaxsim/integrators/fixed_step.py
CHANGED
@@ -1,16 +1,15 @@
|
|
1
|
+
import dataclasses
|
1
2
|
from typing import ClassVar, Generic
|
2
3
|
|
3
|
-
import jax
|
4
4
|
import jax.numpy as jnp
|
5
5
|
import jax_dataclasses
|
6
|
-
import jaxlie
|
7
6
|
|
8
|
-
|
7
|
+
import jaxsim.api as js
|
8
|
+
import jaxsim.typing as jtp
|
9
9
|
|
10
|
-
from .common import ExplicitRungeKutta,
|
11
|
-
|
12
|
-
ODEStateDerivative = ODEState
|
10
|
+
from .common import ExplicitRungeKutta, ExplicitRungeKuttaSO3Mixin, PyTreeType
|
13
11
|
|
12
|
+
ODEStateDerivative = js.ode_data.ODEState
|
14
13
|
|
15
14
|
# =====================================================
|
16
15
|
# Explicit Runge-Kutta integrators operating on PyTrees
|
@@ -19,83 +18,107 @@ ODEStateDerivative = ODEState
|
|
19
18
|
|
20
19
|
@jax_dataclasses.pytree_dataclass
|
21
20
|
class ForwardEuler(ExplicitRungeKutta[PyTreeType], Generic[PyTreeType]):
|
21
|
+
"""
|
22
|
+
Forward Euler integrator.
|
23
|
+
"""
|
22
24
|
|
23
|
-
A:
|
24
|
-
|
25
|
-
|
26
|
-
|
27
|
-
|
28
|
-
|
29
|
-
b: ClassVar[jax.typing.ArrayLike] = (
|
30
|
-
jnp.array(
|
31
|
-
[
|
32
|
-
[1],
|
33
|
-
]
|
34
|
-
)
|
35
|
-
.astype(float)
|
36
|
-
.transpose()
|
25
|
+
A: jtp.Matrix = dataclasses.field(
|
26
|
+
default_factory=lambda: jnp.atleast_2d(0).astype(float), compare=False
|
27
|
+
)
|
28
|
+
b: jtp.Matrix = dataclasses.field(
|
29
|
+
default_factory=lambda: jnp.atleast_2d(1).astype(float), compare=False
|
37
30
|
)
|
38
31
|
|
39
|
-
c:
|
40
|
-
|
41
|
-
)
|
32
|
+
c: jtp.Vector = dataclasses.field(
|
33
|
+
default_factory=lambda: jnp.atleast_1d(0).astype(float), compare=False
|
34
|
+
)
|
42
35
|
|
43
|
-
row_index_of_solution:
|
44
|
-
order_of_bT_rows:
|
36
|
+
row_index_of_solution: int = 0
|
37
|
+
order_of_bT_rows: tuple[int, ...] = (1,)
|
38
|
+
index_of_fsal: jtp.IntLike | None = None
|
39
|
+
fsal_enabled_if_supported: bool = False
|
45
40
|
|
46
41
|
|
47
42
|
@jax_dataclasses.pytree_dataclass
|
48
|
-
class
|
49
|
-
|
50
|
-
|
51
|
-
|
52
|
-
|
53
|
-
|
54
|
-
|
55
|
-
|
56
|
-
|
57
|
-
|
58
|
-
|
59
|
-
|
60
|
-
|
61
|
-
|
62
|
-
|
43
|
+
class Heun2(ExplicitRungeKutta[PyTreeType], Generic[PyTreeType]):
|
44
|
+
"""
|
45
|
+
Heun's second-order integrator.
|
46
|
+
"""
|
47
|
+
|
48
|
+
A: jtp.Matrix = dataclasses.field(
|
49
|
+
default_factory=lambda: jnp.array(
|
50
|
+
[
|
51
|
+
[0, 0],
|
52
|
+
[1, 0],
|
53
|
+
]
|
54
|
+
).astype(float),
|
55
|
+
compare=False,
|
56
|
+
)
|
57
|
+
|
58
|
+
b: jtp.Matrix = dataclasses.field(
|
59
|
+
default_factory=lambda: (
|
60
|
+
jnp.atleast_2d(
|
61
|
+
jnp.array([1 / 2, 1 / 2]),
|
62
|
+
)
|
63
|
+
.astype(float)
|
64
|
+
.transpose()
|
65
|
+
),
|
66
|
+
compare=False,
|
63
67
|
)
|
64
68
|
|
65
|
-
c:
|
66
|
-
|
67
|
-
|
69
|
+
c: jtp.Vector = dataclasses.field(
|
70
|
+
default_factory=lambda: jnp.array(
|
71
|
+
[0, 1],
|
72
|
+
).astype(float),
|
73
|
+
compare=False,
|
74
|
+
)
|
68
75
|
|
69
76
|
row_index_of_solution: ClassVar[int] = 0
|
70
77
|
order_of_bT_rows: ClassVar[tuple[int, ...]] = (2,)
|
78
|
+
index_of_fsal: jtp.IntLike | None = None
|
79
|
+
fsal_enabled_if_supported: bool = False
|
71
80
|
|
72
81
|
|
73
82
|
@jax_dataclasses.pytree_dataclass
|
74
83
|
class RungeKutta4(ExplicitRungeKutta[PyTreeType], Generic[PyTreeType]):
|
84
|
+
"""
|
85
|
+
Fourth-order Runge-Kutta integrator.
|
86
|
+
"""
|
75
87
|
|
76
|
-
A:
|
77
|
-
|
78
|
-
[
|
79
|
-
|
80
|
-
|
81
|
-
|
82
|
-
|
83
|
-
|
84
|
-
|
85
|
-
|
86
|
-
jnp.atleast_2d(
|
87
|
-
jnp.array([1 / 6, 1 / 3, 1 / 3, 1 / 6]),
|
88
|
-
)
|
89
|
-
.astype(float)
|
90
|
-
.transpose()
|
88
|
+
A: jtp.Matrix = dataclasses.field(
|
89
|
+
default_factory=lambda: jnp.array(
|
90
|
+
[
|
91
|
+
[0, 0, 0, 0],
|
92
|
+
[1 / 2, 0, 0, 0],
|
93
|
+
[0, 1 / 2, 0, 0],
|
94
|
+
[0, 0, 1, 0],
|
95
|
+
]
|
96
|
+
).astype(float),
|
97
|
+
compare=False,
|
91
98
|
)
|
92
99
|
|
93
|
-
|
94
|
-
|
95
|
-
|
100
|
+
b: jtp.Matrix = dataclasses.field(
|
101
|
+
default_factory=lambda: (
|
102
|
+
jnp.atleast_2d(
|
103
|
+
jnp.array([1 / 6, 1 / 3, 1 / 3, 1 / 6]),
|
104
|
+
)
|
105
|
+
.astype(float)
|
106
|
+
.transpose()
|
107
|
+
),
|
108
|
+
compare=False,
|
109
|
+
)
|
110
|
+
|
111
|
+
c: jtp.Vector = dataclasses.field(
|
112
|
+
default_factory=lambda: jnp.array(
|
113
|
+
[0, 1 / 2, 1 / 2, 1],
|
114
|
+
).astype(float),
|
115
|
+
compare=False,
|
116
|
+
)
|
96
117
|
|
97
118
|
row_index_of_solution: ClassVar[int] = 0
|
98
119
|
order_of_bT_rows: ClassVar[tuple[int, ...]] = (4,)
|
120
|
+
index_of_fsal: jtp.IntLike | None = None
|
121
|
+
fsal_enabled_if_supported: bool = False
|
99
122
|
|
100
123
|
|
101
124
|
# ===============================================================================
|
@@ -103,56 +126,28 @@ class RungeKutta4(ExplicitRungeKutta[PyTreeType], Generic[PyTreeType]):
|
|
103
126
|
# ===============================================================================
|
104
127
|
|
105
128
|
|
106
|
-
|
129
|
+
@jax_dataclasses.pytree_dataclass
|
130
|
+
class ForwardEulerSO3(ExplicitRungeKuttaSO3Mixin, ForwardEuler[js.ode_data.ODEState]):
|
107
131
|
"""
|
108
|
-
|
109
|
-
`PyTreeType = ODEState` to integrate the quaternion on SO(3).
|
132
|
+
Forward Euler integrator for SO(3) states.
|
110
133
|
"""
|
111
134
|
|
112
|
-
@classmethod
|
113
|
-
def post_process_state(
|
114
|
-
cls, x0: ODEState, t0: Time, xf: ODEState, dt: TimeStep
|
115
|
-
) -> ODEState:
|
116
|
-
|
117
|
-
# Indices to convert quaternions between serializations.
|
118
|
-
to_xyzw = jnp.array([1, 2, 3, 0])
|
119
|
-
to_wxyz = jnp.array([3, 0, 1, 2])
|
120
|
-
|
121
|
-
# Get the initial quaternion.
|
122
|
-
W_Q_B_t0 = jaxlie.SO3.from_quaternion_xyzw(
|
123
|
-
xyzw=x0.physics_model.base_quaternion[to_xyzw]
|
124
|
-
)
|
125
|
-
|
126
|
-
# Get the final angular velocity.
|
127
|
-
# This is already computed by averaging the kᵢ in RK-based schemes.
|
128
|
-
# Therefore, by using the ω at tf, we obtain a RK scheme operating
|
129
|
-
# on the SO(3) manifold.
|
130
|
-
W_ω_WB_tf = xf.physics_model.base_angular_velocity
|
131
|
-
|
132
|
-
# Integrate the quaternion on SO(3).
|
133
|
-
# Note that we left-multiply with the exponential map since the angular
|
134
|
-
# velocity is expressed in the inertial frame.
|
135
|
-
W_Q_B_tf = jaxlie.SO3.exp(tangent=dt * W_ω_WB_tf) @ W_Q_B_t0
|
136
|
-
|
137
|
-
# Replace the quaternion in the final state.
|
138
|
-
return xf.replace(
|
139
|
-
physics_model=xf.physics_model.replace(
|
140
|
-
base_quaternion=W_Q_B_tf.as_quaternion_xyzw()[to_wxyz]
|
141
|
-
),
|
142
|
-
validate=True,
|
143
|
-
)
|
144
|
-
|
145
|
-
|
146
|
-
@jax_dataclasses.pytree_dataclass
|
147
|
-
class ForwardEulerSO3(ExplicitRungeKuttaSO3Mixin, Heun[ODEState]):
|
148
135
|
pass
|
149
136
|
|
150
137
|
|
151
138
|
@jax_dataclasses.pytree_dataclass
|
152
|
-
class
|
139
|
+
class Heun2SO3(ExplicitRungeKuttaSO3Mixin, Heun2[js.ode_data.ODEState]):
|
140
|
+
"""
|
141
|
+
Heun's second-order integrator for SO(3) states.
|
142
|
+
"""
|
143
|
+
|
153
144
|
pass
|
154
145
|
|
155
146
|
|
156
147
|
@jax_dataclasses.pytree_dataclass
|
157
|
-
class RungeKutta4SO3(ExplicitRungeKuttaSO3Mixin, RungeKutta4[ODEState]):
|
148
|
+
class RungeKutta4SO3(ExplicitRungeKuttaSO3Mixin, RungeKutta4[js.ode_data.ODEState]):
|
149
|
+
"""
|
150
|
+
Fourth-order Runge-Kutta integrator for SO(3) states.
|
151
|
+
"""
|
152
|
+
|
158
153
|
pass
|