jaxsim 0.2.dev191__py3-none-any.whl → 0.6.1.dev2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jaxsim/__init__.py +73 -22
- jaxsim/_version.py +2 -2
- jaxsim/api/__init__.py +13 -1
- jaxsim/api/com.py +423 -0
- jaxsim/api/common.py +48 -19
- jaxsim/api/contact.py +604 -52
- jaxsim/api/data.py +308 -163
- jaxsim/api/frame.py +471 -0
- jaxsim/api/joint.py +166 -37
- jaxsim/api/kin_dyn_parameters.py +901 -0
- jaxsim/api/link.py +277 -78
- jaxsim/api/model.py +1572 -362
- jaxsim/api/ode.py +324 -133
- jaxsim/api/ode_data.py +401 -0
- jaxsim/api/references.py +216 -80
- jaxsim/exceptions.py +80 -0
- jaxsim/integrators/__init__.py +2 -2
- jaxsim/integrators/common.py +191 -107
- jaxsim/integrators/fixed_step.py +97 -102
- jaxsim/integrators/variable_step.py +706 -0
- jaxsim/logging.py +1 -2
- jaxsim/math/__init__.py +13 -0
- jaxsim/math/adjoint.py +57 -22
- jaxsim/math/cross.py +16 -7
- jaxsim/math/inertia.py +10 -8
- jaxsim/math/joint_model.py +289 -0
- jaxsim/math/quaternion.py +54 -20
- jaxsim/math/rotation.py +27 -21
- jaxsim/math/skew.py +16 -5
- jaxsim/math/transform.py +102 -0
- jaxsim/math/utils.py +31 -0
- jaxsim/mujoco/__init__.py +2 -1
- jaxsim/mujoco/loaders.py +216 -29
- jaxsim/mujoco/model.py +163 -33
- jaxsim/mujoco/utils.py +228 -0
- jaxsim/mujoco/visualizer.py +107 -22
- jaxsim/parsers/__init__.py +0 -1
- jaxsim/parsers/descriptions/__init__.py +8 -2
- jaxsim/parsers/descriptions/collision.py +83 -26
- jaxsim/parsers/descriptions/joint.py +80 -87
- jaxsim/parsers/descriptions/link.py +58 -31
- jaxsim/parsers/descriptions/model.py +101 -68
- jaxsim/parsers/kinematic_graph.py +606 -229
- jaxsim/parsers/rod/meshes.py +104 -0
- jaxsim/parsers/rod/parser.py +125 -82
- jaxsim/parsers/rod/utils.py +127 -82
- jaxsim/rbda/__init__.py +11 -0
- jaxsim/rbda/aba.py +289 -0
- jaxsim/rbda/collidable_points.py +156 -0
- jaxsim/rbda/contacts/__init__.py +13 -0
- jaxsim/rbda/contacts/common.py +313 -0
- jaxsim/rbda/contacts/relaxed_rigid.py +605 -0
- jaxsim/rbda/contacts/rigid.py +462 -0
- jaxsim/rbda/contacts/soft.py +480 -0
- jaxsim/rbda/contacts/visco_elastic.py +1066 -0
- jaxsim/rbda/crba.py +167 -0
- jaxsim/rbda/forward_kinematics.py +117 -0
- jaxsim/rbda/jacobian.py +330 -0
- jaxsim/rbda/rnea.py +235 -0
- jaxsim/rbda/utils.py +160 -0
- jaxsim/terrain/__init__.py +2 -0
- jaxsim/terrain/terrain.py +238 -0
- jaxsim/typing.py +24 -24
- jaxsim/utils/__init__.py +1 -4
- jaxsim/utils/jaxsim_dataclass.py +289 -34
- jaxsim/utils/tracing.py +5 -11
- jaxsim/utils/wrappers.py +159 -0
- {jaxsim-0.2.dev191.dist-info → jaxsim-0.6.1.dev2.dist-info}/LICENSE +1 -1
- jaxsim-0.6.1.dev2.dist-info/METADATA +465 -0
- jaxsim-0.6.1.dev2.dist-info/RECORD +74 -0
- {jaxsim-0.2.dev191.dist-info → jaxsim-0.6.1.dev2.dist-info}/WHEEL +1 -1
- jaxsim/high_level/__init__.py +0 -2
- jaxsim/high_level/common.py +0 -11
- jaxsim/high_level/joint.py +0 -148
- jaxsim/high_level/link.py +0 -259
- jaxsim/high_level/model.py +0 -1686
- jaxsim/math/conv.py +0 -114
- jaxsim/math/joint.py +0 -102
- jaxsim/math/plucker.py +0 -100
- jaxsim/physics/__init__.py +0 -12
- jaxsim/physics/algos/__init__.py +0 -0
- jaxsim/physics/algos/aba.py +0 -254
- jaxsim/physics/algos/aba_motors.py +0 -284
- jaxsim/physics/algos/crba.py +0 -154
- jaxsim/physics/algos/forward_kinematics.py +0 -79
- jaxsim/physics/algos/jacobian.py +0 -98
- jaxsim/physics/algos/rnea.py +0 -180
- jaxsim/physics/algos/rnea_motors.py +0 -196
- jaxsim/physics/algos/soft_contacts.py +0 -523
- jaxsim/physics/algos/terrain.py +0 -78
- jaxsim/physics/algos/utils.py +0 -69
- jaxsim/physics/model/__init__.py +0 -0
- jaxsim/physics/model/ground_contact.py +0 -53
- jaxsim/physics/model/physics_model.py +0 -388
- jaxsim/physics/model/physics_model_state.py +0 -283
- jaxsim/simulation/__init__.py +0 -4
- jaxsim/simulation/integrators.py +0 -393
- jaxsim/simulation/ode.py +0 -290
- jaxsim/simulation/ode_data.py +0 -96
- jaxsim/simulation/ode_integration.py +0 -62
- jaxsim/simulation/simulator.py +0 -543
- jaxsim/simulation/simulator_callbacks.py +0 -79
- jaxsim/simulation/utils.py +0 -15
- jaxsim/sixd/__init__.py +0 -2
- jaxsim/utils/oop.py +0 -536
- jaxsim/utils/vmappable.py +0 -117
- jaxsim-0.2.dev191.dist-info/METADATA +0 -184
- jaxsim-0.2.dev191.dist-info/RECORD +0 -81
- {jaxsim-0.2.dev191.dist-info → jaxsim-0.6.1.dev2.dist-info}/top_level.txt +0 -0
jaxsim/api/common.py
CHANGED
@@ -1,17 +1,18 @@
|
|
1
1
|
import abc
|
2
2
|
import contextlib
|
3
3
|
import dataclasses
|
4
|
+
import enum
|
4
5
|
import functools
|
5
|
-
from
|
6
|
+
from collections.abc import Callable, Iterator
|
7
|
+
from typing import ParamSpec, TypeVar
|
6
8
|
|
7
9
|
import jax
|
8
10
|
import jax.numpy as jnp
|
9
11
|
import jax_dataclasses
|
10
|
-
import jaxlie
|
11
12
|
from jax_dataclasses import Static
|
12
13
|
|
13
14
|
import jaxsim.typing as jtp
|
14
|
-
from jaxsim.
|
15
|
+
from jaxsim.math import Adjoint
|
15
16
|
from jaxsim.utils import JaxsimDataclass, Mutability
|
16
17
|
|
17
18
|
try:
|
@@ -20,6 +21,32 @@ except ImportError:
|
|
20
21
|
from typing_extensions import Self
|
21
22
|
|
22
23
|
|
24
|
+
_P = ParamSpec("_P")
|
25
|
+
_R = TypeVar("_R")
|
26
|
+
|
27
|
+
|
28
|
+
def named_scope(fn, name: str | None = None) -> Callable[_P, _R]:
|
29
|
+
"""Apply a JAX named scope to a function for improved profiling and clarity."""
|
30
|
+
|
31
|
+
@functools.wraps(fn)
|
32
|
+
def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _R:
|
33
|
+
with jax.named_scope(name or fn.__name__):
|
34
|
+
return fn(*args, **kwargs)
|
35
|
+
|
36
|
+
return wrapper
|
37
|
+
|
38
|
+
|
39
|
+
@enum.unique
|
40
|
+
class VelRepr(enum.IntEnum):
|
41
|
+
"""
|
42
|
+
Enumeration of all supported 6D velocity representations.
|
43
|
+
"""
|
44
|
+
|
45
|
+
Body = enum.auto()
|
46
|
+
Mixed = enum.auto()
|
47
|
+
Inertial = enum.auto()
|
48
|
+
|
49
|
+
|
23
50
|
@jax_dataclasses.pytree_dataclass
|
24
51
|
class ModelDataWithVelocityRepresentation(JaxsimDataclass, abc.ABC):
|
25
52
|
"""
|
@@ -33,7 +60,7 @@ class ModelDataWithVelocityRepresentation(JaxsimDataclass, abc.ABC):
|
|
33
60
|
@contextlib.contextmanager
|
34
61
|
def switch_velocity_representation(
|
35
62
|
self, velocity_representation: VelRepr
|
36
|
-
) ->
|
63
|
+
) -> Iterator[Self]:
|
37
64
|
"""
|
38
65
|
Context manager to temporarily switch the velocity representation.
|
39
66
|
|
@@ -48,7 +75,7 @@ class ModelDataWithVelocityRepresentation(JaxsimDataclass, abc.ABC):
|
|
48
75
|
|
49
76
|
try:
|
50
77
|
|
51
|
-
# First, we replace the velocity representation
|
78
|
+
# First, we replace the velocity representation.
|
52
79
|
with self.mutable_context(
|
53
80
|
mutability=Mutability.MUTABLE_NO_VALIDATION,
|
54
81
|
restore_after_exception=True,
|
@@ -59,7 +86,7 @@ class ModelDataWithVelocityRepresentation(JaxsimDataclass, abc.ABC):
|
|
59
86
|
# We run this in a mutable context with restoration so that any exception
|
60
87
|
# occurring, we restore the original object in case it was modified.
|
61
88
|
with self.mutable_context(
|
62
|
-
mutability=self.
|
89
|
+
mutability=self.mutability(), restore_after_exception=True
|
63
90
|
):
|
64
91
|
yield self
|
65
92
|
|
@@ -76,16 +103,17 @@ class ModelDataWithVelocityRepresentation(JaxsimDataclass, abc.ABC):
|
|
76
103
|
array: jtp.Array,
|
77
104
|
other_representation: VelRepr,
|
78
105
|
transform: jtp.Matrix,
|
79
|
-
|
106
|
+
*,
|
107
|
+
is_force: bool,
|
80
108
|
) -> jtp.Array:
|
81
|
-
"""
|
109
|
+
r"""
|
82
110
|
Convert a 6D quantity from inertial-fixed to another representation.
|
83
111
|
|
84
112
|
Args:
|
85
113
|
array: The 6D quantity to convert.
|
86
114
|
other_representation: The representation to convert to.
|
87
115
|
transform:
|
88
|
-
The
|
116
|
+
The :math:`W \mathbf{H}_O` transform, where :math:`O` is the
|
89
117
|
reference frame of the other representation.
|
90
118
|
is_force: Whether the quantity is a 6D force or a 6D velocity.
|
91
119
|
|
@@ -110,11 +138,11 @@ class ModelDataWithVelocityRepresentation(JaxsimDataclass, abc.ABC):
|
|
110
138
|
case VelRepr.Body:
|
111
139
|
|
112
140
|
if not is_force:
|
113
|
-
O_Xv_W =
|
141
|
+
O_Xv_W = Adjoint.from_transform(transform=W_H_O, inverse=True)
|
114
142
|
O_array = O_Xv_W @ W_array
|
115
143
|
|
116
144
|
else:
|
117
|
-
O_Xf_W =
|
145
|
+
O_Xf_W = Adjoint.from_transform(transform=W_H_O).T
|
118
146
|
O_array = O_Xf_W @ W_array
|
119
147
|
|
120
148
|
return O_array
|
@@ -124,11 +152,11 @@ class ModelDataWithVelocityRepresentation(JaxsimDataclass, abc.ABC):
|
|
124
152
|
W_H_OW = jnp.eye(4).at[0:3, 3].set(W_p_O)
|
125
153
|
|
126
154
|
if not is_force:
|
127
|
-
OW_Xv_W =
|
155
|
+
OW_Xv_W = Adjoint.from_transform(transform=W_H_OW, inverse=True)
|
128
156
|
OW_array = OW_Xv_W @ W_array
|
129
157
|
|
130
158
|
else:
|
131
|
-
OW_Xf_W =
|
159
|
+
OW_Xf_W = Adjoint.from_transform(transform=W_H_OW).T
|
132
160
|
OW_array = OW_Xf_W @ W_array
|
133
161
|
|
134
162
|
return OW_array
|
@@ -142,9 +170,10 @@ class ModelDataWithVelocityRepresentation(JaxsimDataclass, abc.ABC):
|
|
142
170
|
array: jtp.Array,
|
143
171
|
other_representation: VelRepr,
|
144
172
|
transform: jtp.Matrix,
|
145
|
-
|
173
|
+
*,
|
174
|
+
is_force: bool,
|
146
175
|
) -> jtp.Array:
|
147
|
-
"""
|
176
|
+
r"""
|
148
177
|
Convert a 6D quantity from another representation to inertial-fixed.
|
149
178
|
|
150
179
|
Args:
|
@@ -177,11 +206,11 @@ class ModelDataWithVelocityRepresentation(JaxsimDataclass, abc.ABC):
|
|
177
206
|
O_array = array
|
178
207
|
|
179
208
|
if not is_force:
|
180
|
-
W_Xv_O: jtp.Array =
|
209
|
+
W_Xv_O: jtp.Array = Adjoint.from_transform(W_H_O)
|
181
210
|
W_array = W_Xv_O @ O_array
|
182
211
|
|
183
212
|
else:
|
184
|
-
W_Xf_O =
|
213
|
+
W_Xf_O = Adjoint.from_transform(transform=W_H_O, inverse=True).T
|
185
214
|
W_array = W_Xf_O @ O_array
|
186
215
|
|
187
216
|
return W_array
|
@@ -192,11 +221,11 @@ class ModelDataWithVelocityRepresentation(JaxsimDataclass, abc.ABC):
|
|
192
221
|
W_H_OW = jnp.eye(4).at[0:3, 3].set(W_p_O)
|
193
222
|
|
194
223
|
if not is_force:
|
195
|
-
W_Xv_BW: jtp.Array =
|
224
|
+
W_Xv_BW: jtp.Array = Adjoint.from_transform(W_H_OW)
|
196
225
|
W_array = W_Xv_BW @ BW_array
|
197
226
|
|
198
227
|
else:
|
199
|
-
W_Xf_BW =
|
228
|
+
W_Xf_BW = Adjoint.from_transform(transform=W_H_OW, inverse=True).T
|
200
229
|
W_array = W_Xf_BW @ BW_array
|
201
230
|
|
202
231
|
return W_array
|