jaxsim 0.2.dev191__py3-none-any.whl → 0.2.dev364__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jaxsim/__init__.py +3 -4
- jaxsim/_version.py +2 -2
- jaxsim/api/__init__.py +3 -1
- jaxsim/api/com.py +240 -0
- jaxsim/api/common.py +13 -2
- jaxsim/api/contact.py +120 -43
- jaxsim/api/data.py +112 -71
- jaxsim/api/joint.py +77 -36
- jaxsim/api/kin_dyn_parameters.py +777 -0
- jaxsim/api/link.py +150 -75
- jaxsim/api/model.py +542 -269
- jaxsim/api/ode.py +86 -74
- jaxsim/api/ode_data.py +694 -0
- jaxsim/api/references.py +12 -11
- jaxsim/integrators/__init__.py +2 -2
- jaxsim/integrators/common.py +110 -24
- jaxsim/integrators/fixed_step.py +11 -67
- jaxsim/integrators/variable_step.py +610 -0
- jaxsim/math/__init__.py +11 -0
- jaxsim/math/adjoint.py +24 -2
- jaxsim/math/joint_model.py +335 -0
- jaxsim/math/quaternion.py +44 -3
- jaxsim/math/rotation.py +4 -4
- jaxsim/math/transform.py +93 -0
- jaxsim/parsers/descriptions/link.py +2 -2
- jaxsim/parsers/rod/utils.py +7 -8
- jaxsim/rbda/__init__.py +7 -0
- jaxsim/rbda/aba.py +295 -0
- jaxsim/rbda/collidable_points.py +142 -0
- jaxsim/{physics/algos → rbda}/crba.py +43 -42
- jaxsim/rbda/forward_kinematics.py +113 -0
- jaxsim/rbda/jacobian.py +201 -0
- jaxsim/rbda/rnea.py +237 -0
- jaxsim/rbda/soft_contacts.py +296 -0
- jaxsim/rbda/utils.py +152 -0
- jaxsim/terrain/__init__.py +2 -0
- jaxsim/utils/__init__.py +1 -4
- jaxsim/utils/hashless.py +18 -0
- jaxsim/utils/jaxsim_dataclass.py +281 -30
- {jaxsim-0.2.dev191.dist-info → jaxsim-0.2.dev364.dist-info}/METADATA +4 -6
- jaxsim-0.2.dev364.dist-info/RECORD +64 -0
- {jaxsim-0.2.dev191.dist-info → jaxsim-0.2.dev364.dist-info}/WHEEL +1 -1
- jaxsim/high_level/__init__.py +0 -2
- jaxsim/high_level/common.py +0 -11
- jaxsim/high_level/joint.py +0 -148
- jaxsim/high_level/link.py +0 -259
- jaxsim/high_level/model.py +0 -1686
- jaxsim/math/conv.py +0 -114
- jaxsim/math/joint.py +0 -102
- jaxsim/math/plucker.py +0 -100
- jaxsim/physics/__init__.py +0 -12
- jaxsim/physics/algos/__init__.py +0 -0
- jaxsim/physics/algos/aba.py +0 -254
- jaxsim/physics/algos/aba_motors.py +0 -284
- jaxsim/physics/algos/forward_kinematics.py +0 -79
- jaxsim/physics/algos/jacobian.py +0 -98
- jaxsim/physics/algos/rnea.py +0 -180
- jaxsim/physics/algos/rnea_motors.py +0 -196
- jaxsim/physics/algos/soft_contacts.py +0 -523
- jaxsim/physics/algos/utils.py +0 -69
- jaxsim/physics/model/__init__.py +0 -0
- jaxsim/physics/model/ground_contact.py +0 -53
- jaxsim/physics/model/physics_model.py +0 -388
- jaxsim/physics/model/physics_model_state.py +0 -283
- jaxsim/simulation/__init__.py +0 -4
- jaxsim/simulation/integrators.py +0 -393
- jaxsim/simulation/ode.py +0 -290
- jaxsim/simulation/ode_data.py +0 -96
- jaxsim/simulation/ode_integration.py +0 -62
- jaxsim/simulation/simulator.py +0 -543
- jaxsim/simulation/simulator_callbacks.py +0 -79
- jaxsim/simulation/utils.py +0 -15
- jaxsim/sixd/__init__.py +0 -2
- jaxsim/utils/oop.py +0 -536
- jaxsim/utils/vmappable.py +0 -117
- jaxsim-0.2.dev191.dist-info/RECORD +0 -81
- /jaxsim/{physics/algos → terrain}/terrain.py +0 -0
- {jaxsim-0.2.dev191.dist-info → jaxsim-0.2.dev364.dist-info}/LICENSE +0 -0
- {jaxsim-0.2.dev191.dist-info → jaxsim-0.2.dev364.dist-info}/top_level.txt +0 -0
jaxsim/utils/jaxsim_dataclass.py
CHANGED
@@ -2,7 +2,9 @@ import abc
|
|
2
2
|
import contextlib
|
3
3
|
import copy
|
4
4
|
import dataclasses
|
5
|
-
|
5
|
+
import functools
|
6
|
+
from collections.abc import Iterator
|
7
|
+
from typing import Any, Callable, ClassVar, Sequence, Type
|
6
8
|
|
7
9
|
import jax.flatten_util
|
8
10
|
import jax_dataclasses
|
@@ -19,51 +21,219 @@ except ImportError:
|
|
19
21
|
|
20
22
|
@jax_dataclasses.pytree_dataclass
|
21
23
|
class JaxsimDataclass(abc.ABC):
|
22
|
-
""""""
|
24
|
+
"""Class extending `jax_dataclasses.pytree_dataclass` instances with utilities."""
|
23
25
|
|
24
26
|
# This attribute is set by jax_dataclasses
|
25
27
|
__mutability__: ClassVar[Mutability] = Mutability.FROZEN
|
26
28
|
|
27
29
|
@contextlib.contextmanager
|
28
|
-
def editable(self: Self, validate: bool = True) ->
|
29
|
-
"""
|
30
|
+
def editable(self: Self, validate: bool = True) -> Iterator[Self]:
|
31
|
+
"""
|
32
|
+
Context manager to operate on a mutable copy of the object.
|
33
|
+
|
34
|
+
Args:
|
35
|
+
validate: Whether to validate the output PyTree upon exiting the context.
|
36
|
+
|
37
|
+
Yields:
|
38
|
+
A mutable copy of the object.
|
39
|
+
|
40
|
+
Note:
|
41
|
+
This context manager is useful to operate on an r/w copy of a PyTree making
|
42
|
+
sure that the output object does not trigger JIT recompilations.
|
43
|
+
"""
|
30
44
|
|
31
45
|
mutability = (
|
32
46
|
Mutability.MUTABLE if validate else Mutability.MUTABLE_NO_VALIDATION
|
33
47
|
)
|
34
48
|
|
35
|
-
with
|
49
|
+
with self.copy().mutable_context(mutability=mutability) as obj:
|
36
50
|
yield obj
|
37
51
|
|
38
52
|
@contextlib.contextmanager
|
39
53
|
def mutable_context(
|
40
54
|
self: Self, mutability: Mutability, restore_after_exception: bool = True
|
41
|
-
) ->
|
42
|
-
"""
|
55
|
+
) -> Iterator[Self]:
|
56
|
+
"""
|
57
|
+
Context manager to temporarily change the mutability of the object.
|
58
|
+
|
59
|
+
Args:
|
60
|
+
mutability: The mutability to set.
|
61
|
+
restore_after_exception:
|
62
|
+
Whether to restore the original object in case of an exception
|
63
|
+
occurring within the context.
|
64
|
+
|
65
|
+
Yields:
|
66
|
+
The object with the new mutability.
|
67
|
+
|
68
|
+
Note:
|
69
|
+
This context manager is useful to operate in place on a PyTree without
|
70
|
+
the need to make a copy while optionally keeping active the checks on
|
71
|
+
the PyTree structure, shapes, and dtypes.
|
72
|
+
"""
|
43
73
|
|
44
74
|
if restore_after_exception:
|
45
75
|
self_copy = self.copy()
|
46
76
|
|
47
|
-
original_mutability = self.
|
77
|
+
original_mutability = self.mutability()
|
78
|
+
|
79
|
+
original_dtypes = JaxsimDataclass.get_leaf_dtypes(tree=self)
|
80
|
+
original_shapes = JaxsimDataclass.get_leaf_shapes(tree=self)
|
81
|
+
original_weak_types = JaxsimDataclass.get_leaf_weak_types(tree=self)
|
82
|
+
original_structure = jax.tree_util.tree_structure(tree=self)
|
48
83
|
|
49
|
-
def restore_self():
|
50
|
-
self.
|
84
|
+
def restore_self() -> None:
|
85
|
+
self.set_mutability(mutability=Mutability.MUTABLE_NO_VALIDATION)
|
51
86
|
for f in dataclasses.fields(self_copy):
|
52
87
|
setattr(self, f.name, getattr(self_copy, f.name))
|
53
88
|
|
54
89
|
try:
|
55
|
-
self.
|
90
|
+
self.set_mutability(mutability)
|
56
91
|
yield self
|
92
|
+
|
93
|
+
if mutability is not Mutability.MUTABLE_NO_VALIDATION:
|
94
|
+
new_structure = jax.tree_util.tree_structure(tree=self)
|
95
|
+
if original_structure != new_structure:
|
96
|
+
msg = "Pytree structure has changed from {} to {}"
|
97
|
+
raise ValueError(msg.format(original_structure, new_structure))
|
98
|
+
|
99
|
+
new_shapes = JaxsimDataclass.get_leaf_shapes(tree=self)
|
100
|
+
if original_shapes != new_shapes:
|
101
|
+
msg = "Leaves shapes have changed from {} to {}"
|
102
|
+
raise ValueError(msg.format(original_shapes, new_shapes))
|
103
|
+
|
104
|
+
new_dtypes = JaxsimDataclass.get_leaf_dtypes(tree=self)
|
105
|
+
if original_dtypes != new_dtypes:
|
106
|
+
msg = "Leaves dtypes have changed from {} to {}"
|
107
|
+
raise ValueError(msg.format(original_dtypes, new_dtypes))
|
108
|
+
|
109
|
+
new_weak_types = JaxsimDataclass.get_leaf_weak_types(tree=self)
|
110
|
+
if original_weak_types != new_weak_types:
|
111
|
+
msg = "Leaves weak types have changed from {} to {}"
|
112
|
+
raise ValueError(msg.format(original_weak_types, new_weak_types))
|
113
|
+
|
57
114
|
except Exception as e:
|
58
115
|
if restore_after_exception:
|
59
116
|
restore_self()
|
60
|
-
self.
|
117
|
+
self.set_mutability(original_mutability)
|
61
118
|
raise e
|
119
|
+
|
62
120
|
finally:
|
63
|
-
self.
|
121
|
+
self.set_mutability(original_mutability)
|
122
|
+
|
123
|
+
@staticmethod
|
124
|
+
def get_leaf_shapes(tree: jtp.PyTree) -> tuple[tuple[int, ...] | None]:
|
125
|
+
"""
|
126
|
+
Helper method to get the leaf shapes of a PyTree.
|
127
|
+
|
128
|
+
Args:
|
129
|
+
tree: The PyTree to consider.
|
130
|
+
|
131
|
+
Returns:
|
132
|
+
A tuple containing the leaf shapes of the PyTree or `None` is the leaf is
|
133
|
+
not a numpy-like array.
|
134
|
+
"""
|
135
|
+
|
136
|
+
return tuple( # noqa
|
137
|
+
leaf.shape if hasattr(leaf, "shape") else None
|
138
|
+
for leaf in jax.tree_util.tree_leaves(tree)
|
139
|
+
if hasattr(leaf, "shape")
|
140
|
+
)
|
141
|
+
|
142
|
+
@staticmethod
|
143
|
+
def get_leaf_dtypes(tree: jtp.PyTree) -> tuple:
|
144
|
+
"""
|
145
|
+
Helper method to get the leaf dtypes of a PyTree.
|
146
|
+
|
147
|
+
Args:
|
148
|
+
tree: The PyTree to consider.
|
149
|
+
|
150
|
+
Returns:
|
151
|
+
A tuple containing the leaf dtypes of the PyTree or `None` is the leaf is
|
152
|
+
not a numpy-like array.
|
153
|
+
"""
|
154
|
+
|
155
|
+
return tuple(
|
156
|
+
leaf.dtype if hasattr(leaf, "dtype") else None
|
157
|
+
for leaf in jax.tree_util.tree_leaves(tree)
|
158
|
+
if hasattr(leaf, "dtype")
|
159
|
+
)
|
160
|
+
|
161
|
+
@staticmethod
|
162
|
+
def get_leaf_weak_types(tree: jtp.PyTree) -> tuple[bool, ...]:
|
163
|
+
"""
|
164
|
+
Helper method to get the leaf weak types of a PyTree.
|
165
|
+
|
166
|
+
Args:
|
167
|
+
tree: The PyTree to consider.
|
168
|
+
|
169
|
+
Returns:
|
170
|
+
A tuple marking whether the leaf contains a JAX array with weak type.
|
171
|
+
"""
|
172
|
+
|
173
|
+
return tuple(
|
174
|
+
leaf.weak_type if hasattr(leaf, "weak_type") else False
|
175
|
+
for leaf in jax.tree_util.tree_leaves(tree)
|
176
|
+
if hasattr(leaf, "weak_type")
|
177
|
+
)
|
178
|
+
|
179
|
+
@staticmethod
|
180
|
+
def check_compatibility(*trees: Sequence[Any]) -> None:
|
181
|
+
"""
|
182
|
+
Check whether the PyTrees are compatible in structure, shape, and dtype.
|
183
|
+
|
184
|
+
Args:
|
185
|
+
*trees: The PyTrees to compare.
|
186
|
+
|
187
|
+
Raises:
|
188
|
+
ValueError: If the PyTrees have incompatible structures, shapes, or dtypes.
|
189
|
+
"""
|
190
|
+
|
191
|
+
target_structure = jax.tree_util.tree_structure(trees[0])
|
192
|
+
|
193
|
+
compatible_structure = functools.reduce(
|
194
|
+
lambda compatible, tree: compatible
|
195
|
+
and jax.tree_util.tree_structure(tree) == target_structure,
|
196
|
+
trees[1:],
|
197
|
+
True,
|
198
|
+
)
|
199
|
+
|
200
|
+
if not compatible_structure:
|
201
|
+
raise ValueError("Pytrees have incompatible structures.")
|
202
|
+
|
203
|
+
target_shapes = JaxsimDataclass.get_leaf_shapes(trees[0])
|
204
|
+
|
205
|
+
compatible_shapes = functools.reduce(
|
206
|
+
lambda compatible, tree: compatible
|
207
|
+
and JaxsimDataclass.get_leaf_shapes(tree) == target_shapes,
|
208
|
+
trees[1:],
|
209
|
+
True,
|
210
|
+
)
|
211
|
+
|
212
|
+
if not compatible_shapes:
|
213
|
+
raise ValueError("Pytrees have incompatible shapes.")
|
214
|
+
|
215
|
+
target_dtypes = JaxsimDataclass.get_leaf_dtypes(trees[0])
|
216
|
+
|
217
|
+
compatible_dtypes = functools.reduce(
|
218
|
+
lambda compatible, tree: compatible
|
219
|
+
and JaxsimDataclass.get_leaf_dtypes(tree) == target_dtypes,
|
220
|
+
trees[1:],
|
221
|
+
True,
|
222
|
+
)
|
223
|
+
|
224
|
+
if not compatible_dtypes:
|
225
|
+
raise ValueError("Pytrees have incompatible dtypes.")
|
64
226
|
|
65
227
|
def is_mutable(self, validate: bool = False) -> bool:
|
66
|
-
"""
|
228
|
+
"""
|
229
|
+
Check whether the object is mutable.
|
230
|
+
|
231
|
+
Args:
|
232
|
+
validate: Additionally checks if the object also has validation enabled.
|
233
|
+
|
234
|
+
Returns:
|
235
|
+
True if the object is mutable, False otherwise.
|
236
|
+
"""
|
67
237
|
|
68
238
|
return (
|
69
239
|
self.__mutability__ is Mutability.MUTABLE
|
@@ -71,39 +241,120 @@ class JaxsimDataclass(abc.ABC):
|
|
71
241
|
else self.__mutability__ is Mutability.MUTABLE_NO_VALIDATION
|
72
242
|
)
|
73
243
|
|
74
|
-
def
|
75
|
-
|
76
|
-
|
77
|
-
else:
|
78
|
-
mutability = (
|
79
|
-
Mutability.MUTABLE if validate else Mutability.MUTABLE_NO_VALIDATION
|
80
|
-
)
|
244
|
+
def mutability(self) -> Mutability:
|
245
|
+
"""
|
246
|
+
Get the mutability type of the object.
|
81
247
|
|
82
|
-
|
248
|
+
Returns:
|
249
|
+
The mutability type of the object.
|
250
|
+
"""
|
83
251
|
|
84
|
-
def _mutability(self) -> Mutability:
|
85
252
|
return self.__mutability__
|
86
253
|
|
87
|
-
def
|
254
|
+
def set_mutability(self, mutability: Mutability) -> None:
|
255
|
+
"""
|
256
|
+
Set the mutability of the object in-place.
|
257
|
+
|
258
|
+
Args:
|
259
|
+
mutability: The desired mutability type.
|
260
|
+
"""
|
261
|
+
|
88
262
|
jax_dataclasses._copy_and_mutate._mark_mutable(
|
89
263
|
self, mutable=mutability, visited=set()
|
90
264
|
)
|
91
265
|
|
92
266
|
def mutable(self: Self, mutable: bool = True, validate: bool = False) -> Self:
|
93
|
-
|
267
|
+
"""
|
268
|
+
Return a mutable reference of the object.
|
269
|
+
|
270
|
+
Args:
|
271
|
+
mutable: Whether to make the object mutable.
|
272
|
+
validate: Whether to enable validation on the object.
|
273
|
+
|
274
|
+
Returns:
|
275
|
+
A mutable reference of the object.
|
276
|
+
"""
|
277
|
+
|
278
|
+
if mutable:
|
279
|
+
mutability = (
|
280
|
+
Mutability.MUTABLE if validate else Mutability.MUTABLE_NO_VALIDATION
|
281
|
+
)
|
282
|
+
else:
|
283
|
+
mutability = Mutability.FROZEN
|
284
|
+
|
285
|
+
self.set_mutability(mutability=mutability)
|
94
286
|
return self
|
95
287
|
|
96
288
|
def copy(self: Self) -> Self:
|
289
|
+
"""
|
290
|
+
Return a copy of the object.
|
291
|
+
|
292
|
+
Returns:
|
293
|
+
A copy of the object.
|
294
|
+
"""
|
295
|
+
|
296
|
+
# Make a copy calling tree_map.
|
97
297
|
obj = jax.tree_util.tree_map(lambda leaf: leaf, self)
|
98
|
-
|
298
|
+
|
299
|
+
# Make sure that the copied object and all the copied leaves have the same
|
300
|
+
# mutability of the original object.
|
301
|
+
obj.set_mutability(mutability=self.mutability())
|
302
|
+
|
99
303
|
return obj
|
100
304
|
|
101
305
|
def replace(self: Self, validate: bool = True, **kwargs) -> Self:
|
102
|
-
|
103
|
-
|
306
|
+
"""
|
307
|
+
Return a new object replacing in-place the specified fields with new values.
|
308
|
+
|
309
|
+
Args:
|
310
|
+
validate: Whether to validate that the new fields do not alter the PyTree.
|
311
|
+
**kwargs: The fields to replace.
|
312
|
+
|
313
|
+
Returns:
|
314
|
+
A reference of the object with the specified fields replaced.
|
315
|
+
"""
|
316
|
+
|
317
|
+
# Use the dataclasses replace method.
|
318
|
+
obj = dataclasses.replace(self, **kwargs)
|
319
|
+
|
320
|
+
if validate:
|
321
|
+
JaxsimDataclass.check_compatibility(self, obj)
|
322
|
+
|
323
|
+
# Make sure that all the new leaves have the same mutability of the object.
|
324
|
+
obj.set_mutability(mutability=self.mutability())
|
104
325
|
|
105
|
-
obj._set_mutability(mutability=self._mutability())
|
106
326
|
return obj
|
107
327
|
|
108
328
|
def flatten(self) -> jtp.VectorJax:
|
109
|
-
|
329
|
+
"""
|
330
|
+
Flatten the object into a 1D vector.
|
331
|
+
|
332
|
+
Returns:
|
333
|
+
A 1D vector containing the flattened object.
|
334
|
+
"""
|
335
|
+
|
336
|
+
return self.flatten_fn()(self)
|
337
|
+
|
338
|
+
@classmethod
|
339
|
+
def flatten_fn(cls: Type[Self]) -> Callable[[Self], jtp.VectorJax]:
|
340
|
+
"""
|
341
|
+
Return a function to flatten the object into a 1D vector.
|
342
|
+
|
343
|
+
Returns:
|
344
|
+
A function to flatten the object into a 1D vector.
|
345
|
+
"""
|
346
|
+
|
347
|
+
return lambda pytree: jax.flatten_util.ravel_pytree(pytree)[0]
|
348
|
+
|
349
|
+
def unflatten_fn(self: Self) -> Callable[[jtp.VectorJax], Self]:
|
350
|
+
"""
|
351
|
+
Return a function to unflatten a 1D vector into the object.
|
352
|
+
|
353
|
+
Returns:
|
354
|
+
A function to unflatten a 1D vector into the object.
|
355
|
+
|
356
|
+
Notes:
|
357
|
+
Due to JAX internals, the function to unflatten a PyTree needs to be
|
358
|
+
created from an existing instance of the PyTree.
|
359
|
+
"""
|
360
|
+
return jax.flatten_util.ravel_pytree(self)[1]
|
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.1
|
2
2
|
Name: jaxsim
|
3
|
-
Version: 0.2.
|
3
|
+
Version: 0.2.dev364
|
4
4
|
Summary: A physics engine in reduced coordinates implemented with JAX.
|
5
5
|
Home-page: https://github.com/ami-iit/jaxsim
|
6
6
|
Author: Diego Ferigo
|
@@ -31,12 +31,12 @@ Requires-Python: >=3.11
|
|
31
31
|
Description-Content-Type: text/markdown
|
32
32
|
License-File: LICENSE
|
33
33
|
Requires-Dist: coloredlogs
|
34
|
-
Requires-Dist: jax
|
35
|
-
Requires-Dist: jaxlib
|
34
|
+
Requires-Dist: jax >=0.4.13
|
35
|
+
Requires-Dist: jaxlib >=0.4.13
|
36
36
|
Requires-Dist: jaxlie >=1.3.0
|
37
37
|
Requires-Dist: jax-dataclasses >=1.4.0
|
38
38
|
Requires-Dist: pptree
|
39
|
-
Requires-Dist: rod
|
39
|
+
Requires-Dist: rod >=0.2.0
|
40
40
|
Requires-Dist: typing-extensions ; python_version < "3.12"
|
41
41
|
Provides-Extra: all
|
42
42
|
Requires-Dist: black[jupyter] ; extra == 'all'
|
@@ -44,7 +44,6 @@ Requires-Dist: isort ; extra == 'all'
|
|
44
44
|
Requires-Dist: pre-commit ; extra == 'all'
|
45
45
|
Requires-Dist: idyntree ; extra == 'all'
|
46
46
|
Requires-Dist: pytest >=6.0 ; extra == 'all'
|
47
|
-
Requires-Dist: pytest-forked ; extra == 'all'
|
48
47
|
Requires-Dist: pytest-icdiff ; extra == 'all'
|
49
48
|
Requires-Dist: robot-descriptions ; extra == 'all'
|
50
49
|
Requires-Dist: lxml ; extra == 'all'
|
@@ -57,7 +56,6 @@ Requires-Dist: pre-commit ; extra == 'style'
|
|
57
56
|
Provides-Extra: testing
|
58
57
|
Requires-Dist: idyntree ; extra == 'testing'
|
59
58
|
Requires-Dist: pytest >=6.0 ; extra == 'testing'
|
60
|
-
Requires-Dist: pytest-forked ; extra == 'testing'
|
61
59
|
Requires-Dist: pytest-icdiff ; extra == 'testing'
|
62
60
|
Requires-Dist: robot-descriptions ; extra == 'testing'
|
63
61
|
Provides-Extra: viz
|
@@ -0,0 +1,64 @@
|
|
1
|
+
jaxsim/__init__.py,sha256=OcrfoYS1DGcmAGqu2AqlCTiUVxcpi-IsVwcr_16x74Q,1789
|
2
|
+
jaxsim/_version.py,sha256=yZse2Rzb5HFwkUDJRRD3kYlkmiLmp2arTSwcTm3cl1E,423
|
3
|
+
jaxsim/logging.py,sha256=c4zhwBKf9eAYAHVp62kTEllqdsZgh0K-kPKVy8L3elU,1584
|
4
|
+
jaxsim/typing.py,sha256=MeuOCQtLAr-sPkvB_sU8FtwGNRirz1auCwIgRC-QZl8,646
|
5
|
+
jaxsim/api/__init__.py,sha256=fNTCPUeDfOAizRd4RsW3Epv0sLTu0KJGoFRSEsi75VM,162
|
6
|
+
jaxsim/api/com.py,sha256=Qtm_6qpiK4WtDVn6JMUHa8DySgBl9CjgKCybJqZ58Lc,7379
|
7
|
+
jaxsim/api/common.py,sha256=6oqZO-QTYr2mpMx5qRrAnCIjQpjIJVe7MlavNFIKbNA,6638
|
8
|
+
jaxsim/api/contact.py,sha256=Ve4ZOWkLEBRgK3KhtICxKY7YzsxYvc3lO-pPRBjqSnY,8659
|
9
|
+
jaxsim/api/data.py,sha256=1AJyKjmXdsWEf_CkvOXRmvRsDZamG706mAAg1Gttll0,26773
|
10
|
+
jaxsim/api/joint.py,sha256=q31Kp3Cqv-yTcxijjzbj_QADFnGQyjb2al9fYZtzedo,4763
|
11
|
+
jaxsim/api/kin_dyn_parameters.py,sha256=G4mtSi8fElYe0yttLgsxSOPf7vcK-yqTu06Aa5SSrYg,26012
|
12
|
+
jaxsim/api/link.py,sha256=LZVcQhQsTKsfR13KewFtEMYu4siVJl7mqoDwYsoFFes,9240
|
13
|
+
jaxsim/api/model.py,sha256=mFdEwVuzIR7Lvj4oiIsA1n1oxpRZXWyB0IDzIxcG33Q,43689
|
14
|
+
jaxsim/api/ode.py,sha256=rbSruK0Dkp09oBgHbB_-NrZS4o2tY9geK0yLLJfXzpM,9821
|
15
|
+
jaxsim/api/ode_data.py,sha256=dwRFVPJ30XMmdUbPXEu7YxsQ97jZP4L4fd5ZzhrO5Ys,22184
|
16
|
+
jaxsim/api/references.py,sha256=Lvskf17r619KKxwCJP7hAAty2kaXgDXJX1uKqoDIDgo,15483
|
17
|
+
jaxsim/integrators/__init__.py,sha256=hxvOD-VK_mmd6v31wtC-nb28AYve1gLuZCNLV9wS-Kg,103
|
18
|
+
jaxsim/integrators/common.py,sha256=APmQVXKdN9JMIq7wrVKq2HlI1UKpqvLqaTqAq0VfJJ0,20700
|
19
|
+
jaxsim/integrators/fixed_step.py,sha256=JXaEyEzfSiYea0GnPA7l27J3X0YPB0e25D4qfrxAvzQ,2766
|
20
|
+
jaxsim/integrators/variable_step.py,sha256=jq3PStzFiMciU7lux6CTj4B3gVOfSpYgK2oz2yzIbdo,21380
|
21
|
+
jaxsim/math/__init__.py,sha256=inJ9nRFkqstuGa8OyFkfWVudo5U9Ug4WgDBuKva8AIA,337
|
22
|
+
jaxsim/math/adjoint.py,sha256=DT21izjVW497GRrgNfx8tv0ZeWW5QncWMGMhI0acUNw,4425
|
23
|
+
jaxsim/math/cross.py,sha256=U7yEx_l75mSy5g6O-jsjBztApvxC3WaV4MpkS5tThu4,1330
|
24
|
+
jaxsim/math/inertia.py,sha256=UAB7ym4gXFanejcs_ovZMpteHCc6poWYmt-mLmd5hhk,1640
|
25
|
+
jaxsim/math/joint_model.py,sha256=LKLB26VMz6vx9JLdFUWhGyrElYFEQV-bJiQO5kaZUGY,10896
|
26
|
+
jaxsim/math/quaternion.py,sha256=X9b8jHf0QemKUjIZSnXRJc3DdMr42CBhBy_mi9_X_AM,5068
|
27
|
+
jaxsim/math/rotation.py,sha256=Z90daUjGpuNEVLfWB3SVtM9EtwAIaneVj9A9UpWXqhA,2182
|
28
|
+
jaxsim/math/skew.py,sha256=oOGSSR8PUGROl6IJFlrmu6K3gPH-u16hUPfKIkcVv9o,1177
|
29
|
+
jaxsim/math/transform.py,sha256=nqH6ofde6VWjfHihmYdXvDxKFChpOPH6AsoqkUI1Og0,2928
|
30
|
+
jaxsim/mujoco/__init__.py,sha256=Zo5GAlN1DYKvX8s1hu1j6HntKIbBMLB9Puv9ouaNAZ8,158
|
31
|
+
jaxsim/mujoco/__main__.py,sha256=GBmB7J-zj75ZnFyuAAmpSOpbxi_HhHhWJeot3ljGDJY,5291
|
32
|
+
jaxsim/mujoco/loaders.py,sha256=8sXc_tsDFWBYl8nesgFarYd3hA-PESLMrXsnR3Siz1Y,16400
|
33
|
+
jaxsim/mujoco/model.py,sha256=0kG2GERxjVFqWZ1K3352rgUNfchB4kRtIrsvv4pS4oc,10766
|
34
|
+
jaxsim/mujoco/visualizer.py,sha256=-qg26t5tleTva6zzQmc5SdnlC8XZ1ZAwZ_lDjdwHJ0A,4400
|
35
|
+
jaxsim/parsers/__init__.py,sha256=sonYi-bBWAoB04kp1mxT4uIORxjb7SdZ0ukGPmVx98Y,44
|
36
|
+
jaxsim/parsers/kinematic_graph.py,sha256=2B5gtUboiSVJIm6PegbwHb5g_iXltG0R9_7h8RJ-92M,23785
|
37
|
+
jaxsim/parsers/descriptions/__init__.py,sha256=EbTfnrK3oCxA3pNv--YUwllJ6uICENvFgAdRbYtS9ts,238
|
38
|
+
jaxsim/parsers/descriptions/collision.py,sha256=HUWwuRgI9KznY29FFw1_zU3bGigDEezrcPOJSxSJGNU,3382
|
39
|
+
jaxsim/parsers/descriptions/joint.py,sha256=hpH0ANvIhbEQk-NGRmWIvPv3lXW385TBIMWNgz5rzM4,4106
|
40
|
+
jaxsim/parsers/descriptions/link.py,sha256=hqLLitrAXnouia6ULk1BPOIEfRxrXwHmoPsi306IZW8,2859
|
41
|
+
jaxsim/parsers/descriptions/model.py,sha256=wenuDrjoBf6prkzm9WyYT0nFWc0l6WBpKNjLoRUDPxo,8937
|
42
|
+
jaxsim/parsers/rod/__init__.py,sha256=G2vqlLajBLUc4gyzXwsEI2Wsi4TMOIF9bLDFeT6KrGU,92
|
43
|
+
jaxsim/parsers/rod/parser.py,sha256=mFi1baSJte6EMmWLpVjVuCpicfAAF48aFUzoKzYzPpo,12555
|
44
|
+
jaxsim/parsers/rod/utils.py,sha256=xoxNdb4IgOePFFnblNo6UrgYCRm53mB3No3yXM6qlpw,6471
|
45
|
+
jaxsim/rbda/__init__.py,sha256=HLwxeU-IxaRpFGUCSQv-LDv20JHTt3Xj7ELiRbRieS8,319
|
46
|
+
jaxsim/rbda/aba.py,sha256=0OoCzHhf1v-qqr1y5PIrD7_mPwAlid0fjXxUrIa5E_s,9118
|
47
|
+
jaxsim/rbda/collidable_points.py,sha256=4ZNJbEj2nEi15jBLR-GNbdaqKgkN58FBgqd_TXupEgg,4948
|
48
|
+
jaxsim/rbda/crba.py,sha256=GodskOZjtrSlbQAqxRv1un_706O7BaJK-U2qa18vJk8,4741
|
49
|
+
jaxsim/rbda/forward_kinematics.py,sha256=OHugNU7C0UxYAW0o1rqH1ZgniSwurz6L1T1MJxfxq08,3418
|
50
|
+
jaxsim/rbda/jacobian.py,sha256=9LGGy9ya5m5U0mBmV1NFH5XYZpEMYbx74qnYBvZs7Ok,6360
|
51
|
+
jaxsim/rbda/rnea.py,sha256=DjwkvXQVUSUclM3Uy3UPZ2tao91R5dGd4o7TsS2qObI,7650
|
52
|
+
jaxsim/rbda/soft_contacts.py,sha256=2EZ9Lw4nFWqXTMEeYsirl17H61s82SmTZllKVsP1Yek,10759
|
53
|
+
jaxsim/rbda/utils.py,sha256=zpbFM2Iq8cntku0BFVu9nfEqZhInCWi9D2INT6MFEI8,5003
|
54
|
+
jaxsim/terrain/__init__.py,sha256=dzekq9yyj3DKTsCARteqc81lAw3OSnl6EhXn8_Q_ozI,64
|
55
|
+
jaxsim/terrain/terrain.py,sha256=q0xkWqEShVq-p1j2abTLZq8sEhjyJwquxQKm80PaHhM,2161
|
56
|
+
jaxsim/utils/__init__.py,sha256=tnQq1_CavdfeKaLYt3pmO7Jk4MU2RwwQU_qICkjyoTY,197
|
57
|
+
jaxsim/utils/hashless.py,sha256=bFIwKeo9KiWwsY8QM55duEGGQOyyJ4jQyPcuqTLEp5k,297
|
58
|
+
jaxsim/utils/jaxsim_dataclass.py,sha256=9FtzgXQPSa6AN26FleDgXGsJMse4dtVACDESrn0g3bw,11359
|
59
|
+
jaxsim/utils/tracing.py,sha256=KDMoyVPlu2NJvFkhtZwq5AkqMMgajt3munvJom-vEjQ,650
|
60
|
+
jaxsim-0.2.dev364.dist-info/LICENSE,sha256=EsU2z6_sWW4Zduzq3goVWjZoCZVKQsM4H_y0o7oRA7Q,1547
|
61
|
+
jaxsim-0.2.dev364.dist-info/METADATA,sha256=8p533sQ7hCyjA28HWGZ3NvohEIc56BSi06HlEVs6tJA,7630
|
62
|
+
jaxsim-0.2.dev364.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
|
63
|
+
jaxsim-0.2.dev364.dist-info/top_level.txt,sha256=LxGMA8FLtXjQ6oI7N5gd_R_oSUHxpXxUEOfT1xS_ni0,7
|
64
|
+
jaxsim-0.2.dev364.dist-info/RECORD,,
|
jaxsim/high_level/__init__.py
DELETED
jaxsim/high_level/common.py
DELETED
jaxsim/high_level/joint.py
DELETED
@@ -1,148 +0,0 @@
|
|
1
|
-
import dataclasses
|
2
|
-
import functools
|
3
|
-
from typing import Any
|
4
|
-
|
5
|
-
import jax.numpy as jnp
|
6
|
-
import jax_dataclasses
|
7
|
-
from jax_dataclasses import Static
|
8
|
-
|
9
|
-
import jaxsim.parsers
|
10
|
-
import jaxsim.typing as jtp
|
11
|
-
from jaxsim.utils import Vmappable, not_tracing, oop
|
12
|
-
|
13
|
-
|
14
|
-
@jax_dataclasses.pytree_dataclass
|
15
|
-
class Joint(Vmappable):
|
16
|
-
"""
|
17
|
-
High-level class to operate in r/o on a single joint of a simulated model.
|
18
|
-
"""
|
19
|
-
|
20
|
-
joint_description: Static[jaxsim.parsers.descriptions.JointDescription]
|
21
|
-
|
22
|
-
_parent_model: Any = dataclasses.field(
|
23
|
-
default=None, repr=False, compare=False, hash=False
|
24
|
-
)
|
25
|
-
|
26
|
-
@property
|
27
|
-
def parent_model(self) -> "jaxsim.high_level.model.Model":
|
28
|
-
""""""
|
29
|
-
|
30
|
-
return self._parent_model
|
31
|
-
|
32
|
-
@functools.partial(oop.jax_tf.method_ro, jit=False)
|
33
|
-
def valid(self) -> jtp.Bool:
|
34
|
-
""""""
|
35
|
-
|
36
|
-
return jnp.array(self.parent_model is not None, dtype=bool)
|
37
|
-
|
38
|
-
@functools.partial(oop.jax_tf.method_ro, jit=False)
|
39
|
-
def index(self) -> jtp.Int:
|
40
|
-
""""""
|
41
|
-
|
42
|
-
return jnp.array(self.joint_description.index, dtype=int)
|
43
|
-
|
44
|
-
@functools.partial(oop.jax_tf.method_ro)
|
45
|
-
def dofs(self) -> jtp.Int:
|
46
|
-
""""""
|
47
|
-
|
48
|
-
return jnp.array(1, dtype=int)
|
49
|
-
|
50
|
-
@functools.partial(oop.jax_tf.method_ro, jit=False, vmap=False)
|
51
|
-
def name(self) -> str:
|
52
|
-
""""""
|
53
|
-
|
54
|
-
return self.joint_description.name
|
55
|
-
|
56
|
-
@functools.partial(oop.jax_tf.method_ro)
|
57
|
-
def position(self, dof: int | None = None) -> jtp.Float:
|
58
|
-
""""""
|
59
|
-
|
60
|
-
dof = dof if dof is not None else 0
|
61
|
-
|
62
|
-
return jnp.array(
|
63
|
-
self.parent_model.joint_positions(joint_names=(self.name(),))[dof],
|
64
|
-
dtype=float,
|
65
|
-
)
|
66
|
-
|
67
|
-
@functools.partial(oop.jax_tf.method_ro)
|
68
|
-
def velocity(self, dof: int | None = None) -> jtp.Float:
|
69
|
-
""""""
|
70
|
-
|
71
|
-
dof = dof if dof is not None else 0
|
72
|
-
|
73
|
-
return jnp.array(
|
74
|
-
self.parent_model.joint_velocities(joint_names=(self.name(),))[dof],
|
75
|
-
dtype=float,
|
76
|
-
)
|
77
|
-
|
78
|
-
@functools.partial(oop.jax_tf.method_ro)
|
79
|
-
def force_target(self, dof: int | None = None) -> jtp.Float:
|
80
|
-
""""""
|
81
|
-
|
82
|
-
dof = dof if dof is not None else 0
|
83
|
-
|
84
|
-
return jnp.array(
|
85
|
-
self.parent_model.joint_generalized_forces_targets(
|
86
|
-
joint_names=(self.name(),)
|
87
|
-
)[dof],
|
88
|
-
dtype=float,
|
89
|
-
)
|
90
|
-
|
91
|
-
@functools.partial(oop.jax_tf.method_ro)
|
92
|
-
def position_limit(self, dof: int | None = None) -> tuple[jtp.Float, jtp.Float]:
|
93
|
-
""""""
|
94
|
-
|
95
|
-
dof = dof if dof is not None else 0
|
96
|
-
|
97
|
-
if not_tracing(dof) and dof != 0:
|
98
|
-
msg = "Only joints with 1 DoF are currently supported"
|
99
|
-
raise ValueError(msg)
|
100
|
-
|
101
|
-
low, high = self.joint_description.position_limit
|
102
|
-
|
103
|
-
return jnp.array(low, dtype=float), jnp.array(high, dtype=float)
|
104
|
-
|
105
|
-
# =============
|
106
|
-
# Motor methods
|
107
|
-
# =============
|
108
|
-
@functools.partial(oop.jax_tf.method_ro)
|
109
|
-
def motor_inertia(self) -> jtp.Vector:
|
110
|
-
""""""
|
111
|
-
|
112
|
-
return jnp.array(self.joint_description.motor_inertia, dtype=float)
|
113
|
-
|
114
|
-
@functools.partial(oop.jax_tf.method_ro)
|
115
|
-
def motor_gear_ratio(self) -> jtp.Vector:
|
116
|
-
""""""
|
117
|
-
|
118
|
-
return jnp.array(self.joint_description.motor_gear_ratio, dtype=float)
|
119
|
-
|
120
|
-
@functools.partial(oop.jax_tf.method_ro)
|
121
|
-
def motor_viscous_friction(self) -> jtp.Vector:
|
122
|
-
""""""
|
123
|
-
|
124
|
-
return jnp.array(self.joint_description.motor_viscous_friction, dtype=float)
|
125
|
-
|
126
|
-
# =================
|
127
|
-
# Multi-DoF methods
|
128
|
-
# =================
|
129
|
-
|
130
|
-
@functools.partial(oop.jax_tf.method_ro)
|
131
|
-
def joint_position(self) -> jtp.Vector:
|
132
|
-
""""""
|
133
|
-
|
134
|
-
return self.parent_model.joint_positions(joint_names=(self.name(),))
|
135
|
-
|
136
|
-
@functools.partial(oop.jax_tf.method_ro)
|
137
|
-
def joint_velocity(self) -> jtp.Vector:
|
138
|
-
""""""
|
139
|
-
|
140
|
-
return self.parent_model.joint_velocities(joint_names=(self.name(),))
|
141
|
-
|
142
|
-
@functools.partial(oop.jax_tf.method_ro)
|
143
|
-
def joint_force_target(self) -> jtp.Vector:
|
144
|
-
""""""
|
145
|
-
|
146
|
-
return self.parent_model.joint_generalized_forces_targets(
|
147
|
-
joint_names=(self.name(),)
|
148
|
-
)
|