jaxsim 0.1rc0__py3-none-any.whl → 0.2.dev8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jaxsim/_version.py +2 -2
- jaxsim/high_level/model.py +2 -2
- jaxsim/parsers/rod/parser.py +52 -36
- jaxsim/physics/model/physics_model.py +6 -6
- jaxsim/simulation/simulator.py +5 -6
- jaxsim/utils/oop.py +14 -10
- {jaxsim-0.1rc0.dist-info → jaxsim-0.2.dev8.dist-info}/METADATA +3 -3
- {jaxsim-0.1rc0.dist-info → jaxsim-0.2.dev8.dist-info}/RECORD +11 -11
- {jaxsim-0.1rc0.dist-info → jaxsim-0.2.dev8.dist-info}/LICENSE +0 -0
- {jaxsim-0.1rc0.dist-info → jaxsim-0.2.dev8.dist-info}/WHEEL +0 -0
- {jaxsim-0.1rc0.dist-info → jaxsim-0.2.dev8.dist-info}/top_level.txt +0 -0
jaxsim/_version.py
CHANGED
jaxsim/high_level/model.py
CHANGED
@@ -385,13 +385,13 @@ class Model(Vmappable):
|
|
385
385
|
def link_names(self) -> tuple[str, ...]:
|
386
386
|
""""""
|
387
387
|
|
388
|
-
return tuple(
|
388
|
+
return tuple(self.physics_model.description.links_dict.keys())
|
389
389
|
|
390
390
|
@functools.partial(oop.jax_tf.method_ro, jit=False, vmap=False)
|
391
391
|
def joint_names(self) -> tuple[str, ...]:
|
392
392
|
""""""
|
393
393
|
|
394
|
-
return tuple(
|
394
|
+
return tuple(self.physics_model.description.joints_dict.keys())
|
395
395
|
|
396
396
|
@functools.partial(oop.jax_tf.method_ro, jit=False, vmap=False)
|
397
397
|
def links(
|
jaxsim/parsers/rod/parser.py
CHANGED
@@ -135,11 +135,13 @@ def extract_model_data(
|
|
135
135
|
parent=world_link,
|
136
136
|
child=links_dict[j.child],
|
137
137
|
jtype=utils.axis_to_jtype(axis=j.axis, type=j.type),
|
138
|
-
axis=
|
139
|
-
|
140
|
-
|
141
|
-
|
142
|
-
|
138
|
+
axis=(
|
139
|
+
np.array(j.axis.xyz.xyz)
|
140
|
+
if j.axis is not None
|
141
|
+
and j.axis.xyz is not None
|
142
|
+
and j.axis.xyz.xyz is not None
|
143
|
+
else None
|
144
|
+
),
|
143
145
|
pose=j.pose.transform() if j.pose is not None else np.eye(4),
|
144
146
|
)
|
145
147
|
for j in sdf_model.joints()
|
@@ -200,41 +202,55 @@ def extract_model_data(
|
|
200
202
|
parent=links_dict[j.parent],
|
201
203
|
child=links_dict[j.child],
|
202
204
|
jtype=utils.axis_to_jtype(axis=j.axis, type=j.type),
|
203
|
-
axis=
|
204
|
-
|
205
|
-
|
206
|
-
|
207
|
-
|
205
|
+
axis=(
|
206
|
+
np.array(j.axis.xyz.xyz)
|
207
|
+
if j.axis is not None
|
208
|
+
and j.axis.xyz is not None
|
209
|
+
and j.axis.xyz.xyz is not None
|
210
|
+
else None
|
211
|
+
),
|
208
212
|
pose=j.pose.transform() if j.pose is not None else np.eye(4),
|
209
213
|
initial_position=0.0,
|
210
214
|
position_limit=(
|
211
|
-
|
212
|
-
|
213
|
-
|
214
|
-
|
215
|
-
|
216
|
-
|
215
|
+
(
|
216
|
+
float(j.axis.limit.lower)
|
217
|
+
if j.axis is not None and j.axis.limit is not None
|
218
|
+
else np.finfo(float).min
|
219
|
+
),
|
220
|
+
(
|
221
|
+
float(j.axis.limit.upper)
|
222
|
+
if j.axis is not None and j.axis.limit is not None
|
223
|
+
else np.finfo(float).max
|
224
|
+
),
|
225
|
+
),
|
226
|
+
friction_static=(
|
227
|
+
j.axis.dynamics.friction
|
228
|
+
if j.axis is not None
|
229
|
+
and j.axis.dynamics is not None
|
230
|
+
and j.axis.dynamics.friction is not None
|
231
|
+
else 0.0
|
232
|
+
),
|
233
|
+
friction_viscous=(
|
234
|
+
j.axis.dynamics.damping
|
235
|
+
if j.axis is not None
|
236
|
+
and j.axis.dynamics is not None
|
237
|
+
and j.axis.dynamics.damping is not None
|
238
|
+
else 0.0
|
239
|
+
),
|
240
|
+
position_limit_damper=(
|
241
|
+
j.axis.limit.dissipation
|
242
|
+
if j.axis is not None
|
243
|
+
and j.axis.limit is not None
|
244
|
+
and j.axis.limit.dissipation is not None
|
245
|
+
else 0.0
|
246
|
+
),
|
247
|
+
position_limit_spring=(
|
248
|
+
j.axis.limit.stiffness
|
249
|
+
if j.axis is not None
|
250
|
+
and j.axis.limit is not None
|
251
|
+
and j.axis.limit.stiffness is not None
|
252
|
+
else 0.0
|
217
253
|
),
|
218
|
-
friction_static=j.axis.dynamics.friction
|
219
|
-
if j.axis is not None
|
220
|
-
and j.axis.dynamics is not None
|
221
|
-
and j.axis.dynamics.friction is not None
|
222
|
-
else 0.0,
|
223
|
-
friction_viscous=j.axis.dynamics.damping
|
224
|
-
if j.axis is not None
|
225
|
-
and j.axis.dynamics is not None
|
226
|
-
and j.axis.dynamics.damping is not None
|
227
|
-
else 0.0,
|
228
|
-
position_limit_damper=j.axis.limit.dissipation
|
229
|
-
if j.axis is not None
|
230
|
-
and j.axis.limit is not None
|
231
|
-
and j.axis.limit.dissipation is not None
|
232
|
-
else 0.0,
|
233
|
-
position_limit_spring=j.axis.limit.stiffness
|
234
|
-
if j.axis is not None
|
235
|
-
and j.axis.limit is not None
|
236
|
-
and j.axis.limit.stiffness is not None
|
237
|
-
else 0.0,
|
238
254
|
)
|
239
255
|
for j in sdf_model.joints()
|
240
256
|
if j.type in {"revolute", "prismatic", "fixed"}
|
@@ -45,14 +45,14 @@ class PhysicsModel(JaxsimDataclass):
|
|
45
45
|
)
|
46
46
|
is_floating_base: Static[bool] = dataclasses.field(default=False)
|
47
47
|
gc: GroundContact = dataclasses.field(default_factory=lambda: GroundContact())
|
48
|
-
description: Static[
|
49
|
-
|
50
|
-
|
48
|
+
description: Static[jaxsim.parsers.descriptions.model.ModelDescription] = (
|
49
|
+
dataclasses.field(default=None)
|
50
|
+
)
|
51
51
|
|
52
52
|
_parent_array_dict: Static[Dict[int, int]] = dataclasses.field(default_factory=dict)
|
53
|
-
_jtype_dict: Static[
|
54
|
-
|
55
|
-
|
53
|
+
_jtype_dict: Static[Dict[int, Union[JointType, JointDescriptor]]] = (
|
54
|
+
dataclasses.field(default_factory=dict)
|
55
|
+
)
|
56
56
|
_tree_transforms_dict: Dict[int, jtp.Matrix] = dataclasses.field(
|
57
57
|
default_factory=dict
|
58
58
|
)
|
jaxsim/simulation/simulator.py
CHANGED
@@ -432,8 +432,9 @@ class JaxSim(Vmappable):
|
|
432
432
|
def step_over_horizon(
|
433
433
|
self,
|
434
434
|
horizon_steps: jtp.Int,
|
435
|
-
callback_handler:
|
436
|
-
|
435
|
+
callback_handler: (
|
436
|
+
Union["scb.SimulatorCallback", "scb.CallbackHandler"] | None
|
437
|
+
) = None,
|
437
438
|
clear_inputs: jtp.Bool = False,
|
438
439
|
) -> Union[
|
439
440
|
"JaxSim",
|
@@ -459,10 +460,8 @@ class JaxSim(Vmappable):
|
|
459
460
|
sim = self.copy().mutable(validate=True)
|
460
461
|
|
461
462
|
# Helper to get callbacks from the handler
|
462
|
-
get_cb = (
|
463
|
-
|
464
|
-
if h is not None and hasattr(h, cb_name)
|
465
|
-
else None
|
463
|
+
get_cb = lambda h, cb_name: (
|
464
|
+
getattr(h, cb_name) if h is not None and hasattr(h, cb_name) else None
|
466
465
|
)
|
467
466
|
|
468
467
|
# Get the callbacks
|
jaxsim/utils/oop.py
CHANGED
@@ -3,16 +3,20 @@ import dataclasses
|
|
3
3
|
import functools
|
4
4
|
import inspect
|
5
5
|
import os
|
6
|
-
from typing import Any, Callable, Generator
|
6
|
+
from typing import Any, Callable, Generator, TypeVar
|
7
7
|
|
8
8
|
import jax
|
9
9
|
import jax.flatten_util
|
10
|
+
from typing_extensions import ParamSpec
|
10
11
|
|
11
12
|
from jaxsim import logging
|
12
13
|
from jaxsim.utils import tracing
|
13
14
|
|
14
15
|
from . import Mutability, Vmappable
|
15
16
|
|
17
|
+
_P = ParamSpec("_P")
|
18
|
+
_R = TypeVar("_R")
|
19
|
+
|
16
20
|
|
17
21
|
class jax_tf:
|
18
22
|
"""
|
@@ -27,13 +31,13 @@ class jax_tf:
|
|
27
31
|
|
28
32
|
@staticmethod
|
29
33
|
def method_ro(
|
30
|
-
fn: Callable,
|
34
|
+
fn: Callable[_P, _R],
|
31
35
|
jit: bool = True,
|
32
36
|
static_argnames: tuple[str, ...] | list[str] = (),
|
33
37
|
vmap: bool | None = None,
|
34
38
|
vmap_in_axes: tuple[int, ...] | int | None = None,
|
35
39
|
vmap_out_axes: tuple[int, ...] | int | None = None,
|
36
|
-
):
|
40
|
+
) -> Callable[_P, _R]:
|
37
41
|
"""
|
38
42
|
Decorator for r/o methods of classes inheriting from Vmappable.
|
39
43
|
"""
|
@@ -51,14 +55,14 @@ class jax_tf:
|
|
51
55
|
|
52
56
|
@staticmethod
|
53
57
|
def method_rw(
|
54
|
-
fn: Callable,
|
58
|
+
fn: Callable[_P, _R],
|
55
59
|
validate: bool = True,
|
56
60
|
jit: bool = True,
|
57
61
|
static_argnames: tuple[str, ...] | list[str] = (),
|
58
62
|
vmap: bool | None = None,
|
59
63
|
vmap_in_axes: tuple[int, ...] | int | None = None,
|
60
64
|
vmap_out_axes: tuple[int, ...] | int | None = None,
|
61
|
-
):
|
65
|
+
) -> Callable[_P, _R]:
|
62
66
|
"""
|
63
67
|
Decorator for r/w methods of classes inheriting from Vmappable.
|
64
68
|
"""
|
@@ -76,7 +80,7 @@ class jax_tf:
|
|
76
80
|
|
77
81
|
@staticmethod
|
78
82
|
def method(
|
79
|
-
fn: Callable,
|
83
|
+
fn: Callable[_P, _R],
|
80
84
|
read_only: bool = True,
|
81
85
|
validate: bool = True,
|
82
86
|
jit_enabled: bool = True,
|
@@ -109,7 +113,7 @@ class jax_tf:
|
|
109
113
|
"""
|
110
114
|
|
111
115
|
@functools.wraps(fn)
|
112
|
-
def wrapper(*args, **kwargs):
|
116
|
+
def wrapper(*args: _P.args, **kwargs: _P.kwargs):
|
113
117
|
"""The wrapper function that is returned by this decorator."""
|
114
118
|
|
115
119
|
# Methods of classes inheriting from Vmappable decorated by this wrapper
|
@@ -202,9 +206,9 @@ class jax_tf:
|
|
202
206
|
mutability_dict = {
|
203
207
|
Mutability.MUTABLE_NO_VALIDATION: Mutability.MUTABLE_NO_VALIDATION,
|
204
208
|
Mutability.MUTABLE: Mutability.MUTABLE,
|
205
|
-
Mutability.FROZEN:
|
206
|
-
|
207
|
-
|
209
|
+
Mutability.FROZEN: (
|
210
|
+
Mutability.MUTABLE if validate else Mutability.MUTABLE_NO_VALIDATION
|
211
|
+
),
|
208
212
|
}
|
209
213
|
|
210
214
|
# We need to replace all the dynamic leafs of the original instance with those
|
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.1
|
2
2
|
Name: jaxsim
|
3
|
-
Version: 0.1rc0
|
3
|
+
Version: 0.2.dev8
|
4
4
|
Summary: A physics engine in reduced coordinates implemented with JAX.
|
5
5
|
Home-page: https://github.com/ami-iit/jaxsim
|
6
6
|
Author: Diego Ferigo
|
@@ -38,7 +38,7 @@ Requires-Dist: jax-dataclasses >=1.4.0
|
|
38
38
|
Requires-Dist: pptree
|
39
39
|
Requires-Dist: rod
|
40
40
|
Provides-Extra: all
|
41
|
-
Requires-Dist: black ; extra == 'all'
|
41
|
+
Requires-Dist: black[jupyter] ; extra == 'all'
|
42
42
|
Requires-Dist: isort ; extra == 'all'
|
43
43
|
Requires-Dist: idyntree ; extra == 'all'
|
44
44
|
Requires-Dist: pytest >=6.0 ; extra == 'all'
|
@@ -46,7 +46,7 @@ Requires-Dist: pytest-forked ; extra == 'all'
|
|
46
46
|
Requires-Dist: pytest-icdiff ; extra == 'all'
|
47
47
|
Requires-Dist: robot-descriptions ; extra == 'all'
|
48
48
|
Provides-Extra: style
|
49
|
-
Requires-Dist: black ; extra == 'style'
|
49
|
+
Requires-Dist: black[jupyter] ; extra == 'style'
|
50
50
|
Requires-Dist: isort ; extra == 'style'
|
51
51
|
Provides-Extra: testing
|
52
52
|
Requires-Dist: idyntree ; extra == 'testing'
|
@@ -1,12 +1,12 @@
|
|
1
1
|
jaxsim/__init__.py,sha256=LJhCG4rsmCrTKTocwRIvllPQeYTxDn-VFn6NjPngn4s,1877
|
2
|
-
jaxsim/_version.py,sha256=
|
2
|
+
jaxsim/_version.py,sha256=PIvMaj0I1p6HezVGBy1PEqXZf5zmL-w3B5rbYTFjdUY,419
|
3
3
|
jaxsim/logging.py,sha256=c4zhwBKf9eAYAHVp62kTEllqdsZgh0K-kPKVy8L3elU,1584
|
4
4
|
jaxsim/typing.py,sha256=Skdm3OrTCT0MvaL57kD10-6LAfceYIUTzO8zuP5b0RA,777
|
5
5
|
jaxsim/high_level/__init__.py,sha256=aWYBCsYmEO76Qt4GEi91Hye_ifGFLvc_bpy9OQplz2o,69
|
6
6
|
jaxsim/high_level/common.py,sha256=6nyRlFsNOLEy5JvLH70VPWeHGSL_ZKNxX3Q62ccqSuY,196
|
7
7
|
jaxsim/high_level/joint.py,sha256=0WF0QWkZzP0SXw0QYpn3PAwdZq0_uXFr2_f1OATiOBA,4089
|
8
8
|
jaxsim/high_level/link.py,sha256=4kcBMh-3w9c-fkTYm3_sXfdwd3NwUm7jKf5BjwEge94,8010
|
9
|
-
jaxsim/high_level/model.py,sha256=
|
9
|
+
jaxsim/high_level/model.py,sha256=oyLMHkhIOqXL_cnj98wG_npHCsBhiLd5uF3AFcrTZqw,57383
|
10
10
|
jaxsim/math/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
11
11
|
jaxsim/math/adjoint.py,sha256=ImkOkWHQKMukBprLTsOPpSuqb1NNPA3_t447zRVo79s,3779
|
12
12
|
jaxsim/math/conv.py,sha256=jbr9MU_vGtTpLTQpbqwAT4huF51skCfORFUBmKSdlaI,3138
|
@@ -25,7 +25,7 @@ jaxsim/parsers/descriptions/joint.py,sha256=hpH0ANvIhbEQk-NGRmWIvPv3lXW385TBIMWN
|
|
25
25
|
jaxsim/parsers/descriptions/link.py,sha256=SONPaSwtNhIX93RIVng8Fb_Y7I5h3sk-5rsqd5U7Fmw,2493
|
26
26
|
jaxsim/parsers/descriptions/model.py,sha256=wenuDrjoBf6prkzm9WyYT0nFWc0l6WBpKNjLoRUDPxo,8937
|
27
27
|
jaxsim/parsers/rod/__init__.py,sha256=G2vqlLajBLUc4gyzXwsEI2Wsi4TMOIF9bLDFeT6KrGU,92
|
28
|
-
jaxsim/parsers/rod/parser.py,sha256=
|
28
|
+
jaxsim/parsers/rod/parser.py,sha256=_lbU3Oaj7-2GDw-KxCCjWB4ZpAj0I5oEDxquN5Ff0p4,12565
|
29
29
|
jaxsim/parsers/rod/utils.py,sha256=C3RfLSnHHR7rgZhnM15QjM_tpJFHKf1Jf2mOsdep3LM,6492
|
30
30
|
jaxsim/physics/__init__.py,sha256=TKH7IqQi39eobWT03b820ky6tWVWFFWcO8YPayNpnZc,216
|
31
31
|
jaxsim/physics/algos/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
@@ -41,24 +41,24 @@ jaxsim/physics/algos/terrain.py,sha256=Gw9-1AjU4c4Yd2xzo0i-fgWwYlroj03TjScJsz_2m
|
|
41
41
|
jaxsim/physics/algos/utils.py,sha256=0OiELbXtv5Jink3l-vMK_OGHgGkZ_wTAAclcd7vDKoc,2230
|
42
42
|
jaxsim/physics/model/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
43
43
|
jaxsim/physics/model/ground_contact.py,sha256=mva-yDzYHREmgUu8jGJmIAsf66_SF6ZISmN-XQQ9Ktw,1924
|
44
|
-
jaxsim/physics/model/physics_model.py,sha256=
|
44
|
+
jaxsim/physics/model/physics_model.py,sha256=kVTIaJQrxALzyWjWrDLnwDOcxmzaPGSpUOS8BCq-g6M,13249
|
45
45
|
jaxsim/physics/model/physics_model_state.py,sha256=nAXAwXnL7kez8DlxI_AxOY-p9ZJjVFAgZ-P0JZ5-crA,5384
|
46
46
|
jaxsim/simulation/__init__.py,sha256=WOWkzq7rMGa4xWvjNqTYtD0Nl4yLQtULGW1xU7hD9m0,182
|
47
47
|
jaxsim/simulation/integrators.py,sha256=E6kz4irJYzPi4hUnzyxOgGZJiOb1ATT9QQ-UtyLOjMo,13779
|
48
48
|
jaxsim/simulation/ode.py,sha256=cawB4qsuSoGifF_wkwGwOFU8NB0R89isa7sK-IZipv8,10376
|
49
49
|
jaxsim/simulation/ode_data.py,sha256=spzHU5LnOL6mJPuuhho-J61koT-bcTRonqMMkiPo3M4,1750
|
50
50
|
jaxsim/simulation/ode_integration.py,sha256=56PqhI_MEeiBQ3MNpE2ytSQBntBPPlUM-oe50DOpbHw,4042
|
51
|
-
jaxsim/simulation/simulator.py,sha256=
|
51
|
+
jaxsim/simulation/simulator.py,sha256=2wy6cVmwgMk6AVZ8GJ8jdLIfCRMzLe64wwqUCMtL9Xs,18160
|
52
52
|
jaxsim/simulation/simulator_callbacks.py,sha256=FEKyQyhVmDRckIOizvNDhSj1Xh6gYLO7yXDrK09_VQY,1476
|
53
53
|
jaxsim/simulation/utils.py,sha256=YdNA1mYGBAE7xVA-Dw7_OoBEuh0J8RS2X0RPQZf4c5E,329
|
54
54
|
jaxsim/sixd/__init__.py,sha256=3tbynXQjvJ6X1IRcDH5eQBgBL0ilTSerDkS8SEF7a8A,62
|
55
55
|
jaxsim/utils/__init__.py,sha256=UQsjrWMrhSQUfUXIIQhzd0kEioLHR3U0OhB-sIQqOd4,291
|
56
56
|
jaxsim/utils/jaxsim_dataclass.py,sha256=FbjfEoCoYC_F-M3wUggXiEhQ7MMS-V_ciYQca-uSiMQ,3272
|
57
|
-
jaxsim/utils/oop.py,sha256=
|
57
|
+
jaxsim/utils/oop.py,sha256=LQhBXkSOD0zgYNJLO7Bl0FPRg-LvtvPzxyQa1WFP0rM,22616
|
58
58
|
jaxsim/utils/tracing.py,sha256=KDMoyVPlu2NJvFkhtZwq5AkqMMgajt3munvJom-vEjQ,650
|
59
59
|
jaxsim/utils/vmappable.py,sha256=NqGL9nGFRI5OorCfnjXsjR_yXigzDxL0lW1YhQ_nMTY,3655
|
60
|
-
jaxsim-0.
|
61
|
-
jaxsim-0.
|
62
|
-
jaxsim-0.
|
63
|
-
jaxsim-0.
|
64
|
-
jaxsim-0.
|
60
|
+
jaxsim-0.2.dev8.dist-info/LICENSE,sha256=EsU2z6_sWW4Zduzq3goVWjZoCZVKQsM4H_y0o7oRA7Q,1547
|
61
|
+
jaxsim-0.2.dev8.dist-info/METADATA,sha256=8OFNTQyRWt1A5WWP4-K6qtGGFTU-xNvnC34W-0emyys,7163
|
62
|
+
jaxsim-0.2.dev8.dist-info/WHEEL,sha256=oiQVh_5PnQM0E3gPdiz09WCNmwiHDMaGer_elqB3coM,92
|
63
|
+
jaxsim-0.2.dev8.dist-info/top_level.txt,sha256=LxGMA8FLtXjQ6oI7N5gd_R_oSUHxpXxUEOfT1xS_ni0,7
|
64
|
+
jaxsim-0.2.dev8.dist-info/RECORD,,
|
File without changes
|
File without changes
|
File without changes
|