jaxsim 0.2.dev2__py3-none-any.whl → 0.2.dev8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
jaxsim/_version.py CHANGED
@@ -12,5 +12,5 @@ __version__: str
12
12
  __version_tuple__: VERSION_TUPLE
13
13
  version_tuple: VERSION_TUPLE
14
14
 
15
- __version__ = version = '0.2.dev2'
16
- __version_tuple__ = version_tuple = (0, 2, 'dev2')
15
+ __version__ = version = '0.2.dev8'
16
+ __version_tuple__ = version_tuple = (0, 2, 'dev8')
@@ -385,13 +385,13 @@ class Model(Vmappable):
385
385
  def link_names(self) -> tuple[str, ...]:
386
386
  """"""
387
387
 
388
- return tuple(l.name() for l in self.links())
388
+ return tuple(self.physics_model.description.links_dict.keys())
389
389
 
390
390
  @functools.partial(oop.jax_tf.method_ro, jit=False, vmap=False)
391
391
  def joint_names(self) -> tuple[str, ...]:
392
392
  """"""
393
393
 
394
- return tuple(j.name() for j in self.joints())
394
+ return tuple(self.physics_model.description.joints_dict.keys())
395
395
 
396
396
  @functools.partial(oop.jax_tf.method_ro, jit=False, vmap=False)
397
397
  def links(
@@ -135,11 +135,13 @@ def extract_model_data(
135
135
  parent=world_link,
136
136
  child=links_dict[j.child],
137
137
  jtype=utils.axis_to_jtype(axis=j.axis, type=j.type),
138
- axis=np.array(j.axis.xyz.xyz)
139
- if j.axis is not None
140
- and j.axis.xyz is not None
141
- and j.axis.xyz.xyz is not None
142
- else None,
138
+ axis=(
139
+ np.array(j.axis.xyz.xyz)
140
+ if j.axis is not None
141
+ and j.axis.xyz is not None
142
+ and j.axis.xyz.xyz is not None
143
+ else None
144
+ ),
143
145
  pose=j.pose.transform() if j.pose is not None else np.eye(4),
144
146
  )
145
147
  for j in sdf_model.joints()
@@ -200,41 +202,55 @@ def extract_model_data(
200
202
  parent=links_dict[j.parent],
201
203
  child=links_dict[j.child],
202
204
  jtype=utils.axis_to_jtype(axis=j.axis, type=j.type),
203
- axis=np.array(j.axis.xyz.xyz)
204
- if j.axis is not None
205
- and j.axis.xyz is not None
206
- and j.axis.xyz.xyz is not None
207
- else None,
205
+ axis=(
206
+ np.array(j.axis.xyz.xyz)
207
+ if j.axis is not None
208
+ and j.axis.xyz is not None
209
+ and j.axis.xyz.xyz is not None
210
+ else None
211
+ ),
208
212
  pose=j.pose.transform() if j.pose is not None else np.eye(4),
209
213
  initial_position=0.0,
210
214
  position_limit=(
211
- float(j.axis.limit.lower)
212
- if j.axis is not None and j.axis.limit is not None
213
- else np.finfo(float).min,
214
- float(j.axis.limit.upper)
215
- if j.axis is not None and j.axis.limit is not None
216
- else np.finfo(float).max,
215
+ (
216
+ float(j.axis.limit.lower)
217
+ if j.axis is not None and j.axis.limit is not None
218
+ else np.finfo(float).min
219
+ ),
220
+ (
221
+ float(j.axis.limit.upper)
222
+ if j.axis is not None and j.axis.limit is not None
223
+ else np.finfo(float).max
224
+ ),
225
+ ),
226
+ friction_static=(
227
+ j.axis.dynamics.friction
228
+ if j.axis is not None
229
+ and j.axis.dynamics is not None
230
+ and j.axis.dynamics.friction is not None
231
+ else 0.0
232
+ ),
233
+ friction_viscous=(
234
+ j.axis.dynamics.damping
235
+ if j.axis is not None
236
+ and j.axis.dynamics is not None
237
+ and j.axis.dynamics.damping is not None
238
+ else 0.0
239
+ ),
240
+ position_limit_damper=(
241
+ j.axis.limit.dissipation
242
+ if j.axis is not None
243
+ and j.axis.limit is not None
244
+ and j.axis.limit.dissipation is not None
245
+ else 0.0
246
+ ),
247
+ position_limit_spring=(
248
+ j.axis.limit.stiffness
249
+ if j.axis is not None
250
+ and j.axis.limit is not None
251
+ and j.axis.limit.stiffness is not None
252
+ else 0.0
217
253
  ),
218
- friction_static=j.axis.dynamics.friction
219
- if j.axis is not None
220
- and j.axis.dynamics is not None
221
- and j.axis.dynamics.friction is not None
222
- else 0.0,
223
- friction_viscous=j.axis.dynamics.damping
224
- if j.axis is not None
225
- and j.axis.dynamics is not None
226
- and j.axis.dynamics.damping is not None
227
- else 0.0,
228
- position_limit_damper=j.axis.limit.dissipation
229
- if j.axis is not None
230
- and j.axis.limit is not None
231
- and j.axis.limit.dissipation is not None
232
- else 0.0,
233
- position_limit_spring=j.axis.limit.stiffness
234
- if j.axis is not None
235
- and j.axis.limit is not None
236
- and j.axis.limit.stiffness is not None
237
- else 0.0,
238
254
  )
239
255
  for j in sdf_model.joints()
240
256
  if j.type in {"revolute", "prismatic", "fixed"}
@@ -45,14 +45,14 @@ class PhysicsModel(JaxsimDataclass):
45
45
  )
46
46
  is_floating_base: Static[bool] = dataclasses.field(default=False)
47
47
  gc: GroundContact = dataclasses.field(default_factory=lambda: GroundContact())
48
- description: Static[
49
- jaxsim.parsers.descriptions.model.ModelDescription
50
- ] = dataclasses.field(default=None)
48
+ description: Static[jaxsim.parsers.descriptions.model.ModelDescription] = (
49
+ dataclasses.field(default=None)
50
+ )
51
51
 
52
52
  _parent_array_dict: Static[Dict[int, int]] = dataclasses.field(default_factory=dict)
53
- _jtype_dict: Static[
54
- Dict[int, Union[JointType, JointDescriptor]]
55
- ] = dataclasses.field(default_factory=dict)
53
+ _jtype_dict: Static[Dict[int, Union[JointType, JointDescriptor]]] = (
54
+ dataclasses.field(default_factory=dict)
55
+ )
56
56
  _tree_transforms_dict: Dict[int, jtp.Matrix] = dataclasses.field(
57
57
  default_factory=dict
58
58
  )
@@ -432,8 +432,9 @@ class JaxSim(Vmappable):
432
432
  def step_over_horizon(
433
433
  self,
434
434
  horizon_steps: jtp.Int,
435
- callback_handler: Union["scb.SimulatorCallback", "scb.CallbackHandler"]
436
- | None = None,
435
+ callback_handler: (
436
+ Union["scb.SimulatorCallback", "scb.CallbackHandler"] | None
437
+ ) = None,
437
438
  clear_inputs: jtp.Bool = False,
438
439
  ) -> Union[
439
440
  "JaxSim",
@@ -459,10 +460,8 @@ class JaxSim(Vmappable):
459
460
  sim = self.copy().mutable(validate=True)
460
461
 
461
462
  # Helper to get callbacks from the handler
462
- get_cb = (
463
- lambda h, cb_name: getattr(h, cb_name)
464
- if h is not None and hasattr(h, cb_name)
465
- else None
463
+ get_cb = lambda h, cb_name: (
464
+ getattr(h, cb_name) if h is not None and hasattr(h, cb_name) else None
466
465
  )
467
466
 
468
467
  # Get the callbacks
jaxsim/utils/oop.py CHANGED
@@ -3,16 +3,20 @@ import dataclasses
3
3
  import functools
4
4
  import inspect
5
5
  import os
6
- from typing import Any, Callable, Generator
6
+ from typing import Any, Callable, Generator, TypeVar
7
7
 
8
8
  import jax
9
9
  import jax.flatten_util
10
+ from typing_extensions import ParamSpec
10
11
 
11
12
  from jaxsim import logging
12
13
  from jaxsim.utils import tracing
13
14
 
14
15
  from . import Mutability, Vmappable
15
16
 
17
+ _P = ParamSpec("_P")
18
+ _R = TypeVar("_R")
19
+
16
20
 
17
21
  class jax_tf:
18
22
  """
@@ -27,13 +31,13 @@ class jax_tf:
27
31
 
28
32
  @staticmethod
29
33
  def method_ro(
30
- fn: Callable,
34
+ fn: Callable[_P, _R],
31
35
  jit: bool = True,
32
36
  static_argnames: tuple[str, ...] | list[str] = (),
33
37
  vmap: bool | None = None,
34
38
  vmap_in_axes: tuple[int, ...] | int | None = None,
35
39
  vmap_out_axes: tuple[int, ...] | int | None = None,
36
- ):
40
+ ) -> Callable[_P, _R]:
37
41
  """
38
42
  Decorator for r/o methods of classes inheriting from Vmappable.
39
43
  """
@@ -51,14 +55,14 @@ class jax_tf:
51
55
 
52
56
  @staticmethod
53
57
  def method_rw(
54
- fn: Callable,
58
+ fn: Callable[_P, _R],
55
59
  validate: bool = True,
56
60
  jit: bool = True,
57
61
  static_argnames: tuple[str, ...] | list[str] = (),
58
62
  vmap: bool | None = None,
59
63
  vmap_in_axes: tuple[int, ...] | int | None = None,
60
64
  vmap_out_axes: tuple[int, ...] | int | None = None,
61
- ):
65
+ ) -> Callable[_P, _R]:
62
66
  """
63
67
  Decorator for r/w methods of classes inheriting from Vmappable.
64
68
  """
@@ -76,7 +80,7 @@ class jax_tf:
76
80
 
77
81
  @staticmethod
78
82
  def method(
79
- fn: Callable,
83
+ fn: Callable[_P, _R],
80
84
  read_only: bool = True,
81
85
  validate: bool = True,
82
86
  jit_enabled: bool = True,
@@ -109,7 +113,7 @@ class jax_tf:
109
113
  """
110
114
 
111
115
  @functools.wraps(fn)
112
- def wrapper(*args, **kwargs):
116
+ def wrapper(*args: _P.args, **kwargs: _P.kwargs):
113
117
  """The wrapper function that is returned by this decorator."""
114
118
 
115
119
  # Methods of classes inheriting from Vmappable decorated by this wrapper
@@ -202,9 +206,9 @@ class jax_tf:
202
206
  mutability_dict = {
203
207
  Mutability.MUTABLE_NO_VALIDATION: Mutability.MUTABLE_NO_VALIDATION,
204
208
  Mutability.MUTABLE: Mutability.MUTABLE,
205
- Mutability.FROZEN: Mutability.MUTABLE
206
- if validate
207
- else Mutability.MUTABLE_NO_VALIDATION,
209
+ Mutability.FROZEN: (
210
+ Mutability.MUTABLE if validate else Mutability.MUTABLE_NO_VALIDATION
211
+ ),
208
212
  }
209
213
 
210
214
  # We need to replace all the dynamic leafs of the original instance with those
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: jaxsim
3
- Version: 0.2.dev2
3
+ Version: 0.2.dev8
4
4
  Summary: A physics engine in reduced coordinates implemented with JAX.
5
5
  Home-page: https://github.com/ami-iit/jaxsim
6
6
  Author: Diego Ferigo
@@ -38,7 +38,7 @@ Requires-Dist: jax-dataclasses >=1.4.0
38
38
  Requires-Dist: pptree
39
39
  Requires-Dist: rod
40
40
  Provides-Extra: all
41
- Requires-Dist: black ; extra == 'all'
41
+ Requires-Dist: black[jupyter] ; extra == 'all'
42
42
  Requires-Dist: isort ; extra == 'all'
43
43
  Requires-Dist: idyntree ; extra == 'all'
44
44
  Requires-Dist: pytest >=6.0 ; extra == 'all'
@@ -46,7 +46,7 @@ Requires-Dist: pytest-forked ; extra == 'all'
46
46
  Requires-Dist: pytest-icdiff ; extra == 'all'
47
47
  Requires-Dist: robot-descriptions ; extra == 'all'
48
48
  Provides-Extra: style
49
- Requires-Dist: black ; extra == 'style'
49
+ Requires-Dist: black[jupyter] ; extra == 'style'
50
50
  Requires-Dist: isort ; extra == 'style'
51
51
  Provides-Extra: testing
52
52
  Requires-Dist: idyntree ; extra == 'testing'
@@ -1,12 +1,12 @@
1
1
  jaxsim/__init__.py,sha256=LJhCG4rsmCrTKTocwRIvllPQeYTxDn-VFn6NjPngn4s,1877
2
- jaxsim/_version.py,sha256=g1WA1Um8OhHzOMYGMH0Q7Zv83YuHk_sKrO6hYFFg2IQ,419
2
+ jaxsim/_version.py,sha256=PIvMaj0I1p6HezVGBy1PEqXZf5zmL-w3B5rbYTFjdUY,419
3
3
  jaxsim/logging.py,sha256=c4zhwBKf9eAYAHVp62kTEllqdsZgh0K-kPKVy8L3elU,1584
4
4
  jaxsim/typing.py,sha256=Skdm3OrTCT0MvaL57kD10-6LAfceYIUTzO8zuP5b0RA,777
5
5
  jaxsim/high_level/__init__.py,sha256=aWYBCsYmEO76Qt4GEi91Hye_ifGFLvc_bpy9OQplz2o,69
6
6
  jaxsim/high_level/common.py,sha256=6nyRlFsNOLEy5JvLH70VPWeHGSL_ZKNxX3Q62ccqSuY,196
7
7
  jaxsim/high_level/joint.py,sha256=0WF0QWkZzP0SXw0QYpn3PAwdZq0_uXFr2_f1OATiOBA,4089
8
8
  jaxsim/high_level/link.py,sha256=4kcBMh-3w9c-fkTYm3_sXfdwd3NwUm7jKf5BjwEge94,8010
9
- jaxsim/high_level/model.py,sha256=Krv1ZPWWz7bL_XdeyupHXoGLDI0fZy7aQcJPKpJJS2w,57347
9
+ jaxsim/high_level/model.py,sha256=oyLMHkhIOqXL_cnj98wG_npHCsBhiLd5uF3AFcrTZqw,57383
10
10
  jaxsim/math/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
11
11
  jaxsim/math/adjoint.py,sha256=ImkOkWHQKMukBprLTsOPpSuqb1NNPA3_t447zRVo79s,3779
12
12
  jaxsim/math/conv.py,sha256=jbr9MU_vGtTpLTQpbqwAT4huF51skCfORFUBmKSdlaI,3138
@@ -25,7 +25,7 @@ jaxsim/parsers/descriptions/joint.py,sha256=hpH0ANvIhbEQk-NGRmWIvPv3lXW385TBIMWN
25
25
  jaxsim/parsers/descriptions/link.py,sha256=SONPaSwtNhIX93RIVng8Fb_Y7I5h3sk-5rsqd5U7Fmw,2493
26
26
  jaxsim/parsers/descriptions/model.py,sha256=wenuDrjoBf6prkzm9WyYT0nFWc0l6WBpKNjLoRUDPxo,8937
27
27
  jaxsim/parsers/rod/__init__.py,sha256=G2vqlLajBLUc4gyzXwsEI2Wsi4TMOIF9bLDFeT6KrGU,92
28
- jaxsim/parsers/rod/parser.py,sha256=JcHkCoIhsecA1MVfn8Q0Ubkhyzb5FDpaiZzFh-rbQ6Q,12173
28
+ jaxsim/parsers/rod/parser.py,sha256=_lbU3Oaj7-2GDw-KxCCjWB4ZpAj0I5oEDxquN5Ff0p4,12565
29
29
  jaxsim/parsers/rod/utils.py,sha256=C3RfLSnHHR7rgZhnM15QjM_tpJFHKf1Jf2mOsdep3LM,6492
30
30
  jaxsim/physics/__init__.py,sha256=TKH7IqQi39eobWT03b820ky6tWVWFFWcO8YPayNpnZc,216
31
31
  jaxsim/physics/algos/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
@@ -41,24 +41,24 @@ jaxsim/physics/algos/terrain.py,sha256=Gw9-1AjU4c4Yd2xzo0i-fgWwYlroj03TjScJsz_2m
41
41
  jaxsim/physics/algos/utils.py,sha256=0OiELbXtv5Jink3l-vMK_OGHgGkZ_wTAAclcd7vDKoc,2230
42
42
  jaxsim/physics/model/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
43
43
  jaxsim/physics/model/ground_contact.py,sha256=mva-yDzYHREmgUu8jGJmIAsf66_SF6ZISmN-XQQ9Ktw,1924
44
- jaxsim/physics/model/physics_model.py,sha256=UtGBTxXR0cLEc9EHzDo6Q__RRNgX4VCS88wQ7SsJuyU,13245
44
+ jaxsim/physics/model/physics_model.py,sha256=kVTIaJQrxALzyWjWrDLnwDOcxmzaPGSpUOS8BCq-g6M,13249
45
45
  jaxsim/physics/model/physics_model_state.py,sha256=nAXAwXnL7kez8DlxI_AxOY-p9ZJjVFAgZ-P0JZ5-crA,5384
46
46
  jaxsim/simulation/__init__.py,sha256=WOWkzq7rMGa4xWvjNqTYtD0Nl4yLQtULGW1xU7hD9m0,182
47
47
  jaxsim/simulation/integrators.py,sha256=E6kz4irJYzPi4hUnzyxOgGZJiOb1ATT9QQ-UtyLOjMo,13779
48
48
  jaxsim/simulation/ode.py,sha256=cawB4qsuSoGifF_wkwGwOFU8NB0R89isa7sK-IZipv8,10376
49
49
  jaxsim/simulation/ode_data.py,sha256=spzHU5LnOL6mJPuuhho-J61koT-bcTRonqMMkiPo3M4,1750
50
50
  jaxsim/simulation/ode_integration.py,sha256=56PqhI_MEeiBQ3MNpE2ytSQBntBPPlUM-oe50DOpbHw,4042
51
- jaxsim/simulation/simulator.py,sha256=Et58GPQkXIWL0NfM8DczbjJCixz2KI6Imz8w9KFWzN4,18168
51
+ jaxsim/simulation/simulator.py,sha256=2wy6cVmwgMk6AVZ8GJ8jdLIfCRMzLe64wwqUCMtL9Xs,18160
52
52
  jaxsim/simulation/simulator_callbacks.py,sha256=FEKyQyhVmDRckIOizvNDhSj1Xh6gYLO7yXDrK09_VQY,1476
53
53
  jaxsim/simulation/utils.py,sha256=YdNA1mYGBAE7xVA-Dw7_OoBEuh0J8RS2X0RPQZf4c5E,329
54
54
  jaxsim/sixd/__init__.py,sha256=3tbynXQjvJ6X1IRcDH5eQBgBL0ilTSerDkS8SEF7a8A,62
55
55
  jaxsim/utils/__init__.py,sha256=UQsjrWMrhSQUfUXIIQhzd0kEioLHR3U0OhB-sIQqOd4,291
56
56
  jaxsim/utils/jaxsim_dataclass.py,sha256=FbjfEoCoYC_F-M3wUggXiEhQ7MMS-V_ciYQca-uSiMQ,3272
57
- jaxsim/utils/oop.py,sha256=PPQvD2_pXx81UD4dbu8E8RiYLegJNEsD23IT-YkEuQs,22434
57
+ jaxsim/utils/oop.py,sha256=LQhBXkSOD0zgYNJLO7Bl0FPRg-LvtvPzxyQa1WFP0rM,22616
58
58
  jaxsim/utils/tracing.py,sha256=KDMoyVPlu2NJvFkhtZwq5AkqMMgajt3munvJom-vEjQ,650
59
59
  jaxsim/utils/vmappable.py,sha256=NqGL9nGFRI5OorCfnjXsjR_yXigzDxL0lW1YhQ_nMTY,3655
60
- jaxsim-0.2.dev2.dist-info/LICENSE,sha256=EsU2z6_sWW4Zduzq3goVWjZoCZVKQsM4H_y0o7oRA7Q,1547
61
- jaxsim-0.2.dev2.dist-info/METADATA,sha256=WQOGcK4TiSg2sMJRsPxS2HTHQhwhT90JHJZmYQMCXe0,7145
62
- jaxsim-0.2.dev2.dist-info/WHEEL,sha256=oiQVh_5PnQM0E3gPdiz09WCNmwiHDMaGer_elqB3coM,92
63
- jaxsim-0.2.dev2.dist-info/top_level.txt,sha256=LxGMA8FLtXjQ6oI7N5gd_R_oSUHxpXxUEOfT1xS_ni0,7
64
- jaxsim-0.2.dev2.dist-info/RECORD,,
60
+ jaxsim-0.2.dev8.dist-info/LICENSE,sha256=EsU2z6_sWW4Zduzq3goVWjZoCZVKQsM4H_y0o7oRA7Q,1547
61
+ jaxsim-0.2.dev8.dist-info/METADATA,sha256=8OFNTQyRWt1A5WWP4-K6qtGGFTU-xNvnC34W-0emyys,7163
62
+ jaxsim-0.2.dev8.dist-info/WHEEL,sha256=oiQVh_5PnQM0E3gPdiz09WCNmwiHDMaGer_elqB3coM,92
63
+ jaxsim-0.2.dev8.dist-info/top_level.txt,sha256=LxGMA8FLtXjQ6oI7N5gd_R_oSUHxpXxUEOfT1xS_ni0,7
64
+ jaxsim-0.2.dev8.dist-info/RECORD,,