jaxsim 0.4.3.dev12__py3-none-any.whl → 0.4.3.dev18__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
1
  import abc
2
2
  import dataclasses
3
- from typing import Any, ClassVar, Generic, Protocol, Type, TypeVar
3
+ from typing import Any, ClassVar, Generic, Protocol, TypeVar
4
4
 
5
5
  import jax
6
6
  import jax.numpy as jnp
@@ -64,7 +64,7 @@ class Integrator(JaxsimDataclass, abc.ABC, Generic[State, StateDerivative]):
64
64
 
65
65
  @classmethod
66
66
  def build(
67
- cls: Type[Self],
67
+ cls: type[Self],
68
68
  *,
69
69
  dynamics: SystemDynamics[State, StateDerivative],
70
70
  **kwargs,
@@ -109,11 +109,14 @@ class Integrator(JaxsimDataclass, abc.ABC, Generic[State, StateDerivative]):
109
109
  integrator.params = params
110
110
 
111
111
  with integrator.mutable_context(mutability=Mutability.MUTABLE):
112
- xf = integrator(x0, t0, dt, **kwargs)
112
+ xf, aux_dict = integrator(x0, t0, dt, **kwargs)
113
113
 
114
- return xf, integrator.params | {
115
- Integrator.AfterInitKey: jnp.array(False).astype(bool)
116
- }
114
+ return (
115
+ xf,
116
+ integrator.params
117
+ | {Integrator.AfterInitKey: jnp.array(False).astype(bool)}
118
+ | aux_dict,
119
+ )
117
120
 
118
121
  @abc.abstractmethod
119
122
  def __call__(self, x0: State, t0: Time, dt: TimeStep, **kwargs) -> NextState:
@@ -224,7 +227,7 @@ class ExplicitRungeKutta(Integrator[PyTreeType, PyTreeType], Generic[PyTreeType]
224
227
  @override
225
228
  @classmethod
226
229
  def build(
227
- cls: Type[Self],
230
+ cls: type[Self],
228
231
  *,
229
232
  dynamics: SystemDynamics[State, StateDerivative],
230
233
  fsal_enabled_if_supported: jtp.BoolLike = True,
@@ -277,15 +280,19 @@ class ExplicitRungeKutta(Integrator[PyTreeType, PyTreeType], Generic[PyTreeType]
277
280
 
278
281
  return integrator
279
282
 
280
- def __call__(self, x0: State, t0: Time, dt: TimeStep, **kwargs) -> NextState:
283
+ def __call__(
284
+ self, x0: State, t0: Time, dt: TimeStep, **kwargs
285
+ ) -> tuple[NextState, dict[str, Any]]:
281
286
 
282
287
  # Here z is a batched state with as many batch elements as b.T rows.
283
288
  # Note that z has multiple batches only if b.T has more than one row,
284
289
  # e.g. in Butcher tableau of embedded schemes.
285
- z = self._compute_next_state(x0=x0, t0=t0, dt=dt, **kwargs)
290
+ z, aux_dict = self._compute_next_state(x0=x0, t0=t0, dt=dt, **kwargs)
286
291
 
287
292
  # The next state is the batch element located at the configured index of solution.
288
- return jax.tree_util.tree_map(lambda l: l[self.row_index_of_solution], z)
293
+ next_state = jax.tree_util.tree_map(lambda l: l[self.row_index_of_solution], z)
294
+
295
+ return next_state, aux_dict
289
296
 
290
297
  @classmethod
291
298
  def integrate_rk_stage(
@@ -343,7 +350,7 @@ class ExplicitRungeKutta(Integrator[PyTreeType, PyTreeType], Generic[PyTreeType]
343
350
 
344
351
  def _compute_next_state(
345
352
  self, x0: State, t0: Time, dt: TimeStep, **kwargs
346
- ) -> NextState:
353
+ ) -> tuple[NextState, dict[str, Any]]:
347
354
  """
348
355
  Compute the next state of the system, returning all the output states.
349
356
 
@@ -373,19 +380,21 @@ class ExplicitRungeKutta(Integrator[PyTreeType, PyTreeType], Generic[PyTreeType]
373
380
  )
374
381
 
375
382
  # Apply FSAL property by passing ẋ0 = f(x0, t0) from the previous iteration.
376
- get_ẋ0 = lambda: self.params.get("dxdt0", f(x0, t0)[0])
383
+ get_ẋ0_and_aux_dict = lambda: self.params.get("dxdt0", f(x0, t0))
377
384
 
378
385
  # We use a `jax.lax.scan` to compile the `f` function only once.
379
386
  # Otherwise, if we compute e.g. for RK4 sequentially, the jit-compiled code
380
387
  # would include 4 repetitions of the `f` logic, making everything extremely slow.
381
- def scan_body(carry: jax.Array, i: int | jax.Array) -> tuple[jax.Array, None]:
388
+ def scan_body(
389
+ carry: jax.Array, i: int | jax.Array
390
+ ) -> tuple[jax.Array, dict[str, Any]]:
382
391
  """"""
383
392
 
384
393
  # Unpack the carry, i.e. the stacked kᵢ vectors.
385
394
  K = carry
386
395
 
387
396
  # Define the computation of the Runge-Kutta stage.
388
- def compute_ki() -> jax.Array:
397
+ def compute_ki() -> tuple[jax.Array, dict[str, Any]]:
389
398
 
390
399
  # Compute ∑ⱼ aᵢⱼ kⱼ.
391
400
  op_sum_ak = lambda k: jnp.einsum("s,s...->...", A[i], k)
@@ -398,13 +407,13 @@ class ExplicitRungeKutta(Integrator[PyTreeType, PyTreeType], Generic[PyTreeType]
398
407
  # Compute the next time for the kᵢ evaluation.
399
408
  ti = t0 + c[i] * Δt
400
409
 
401
- # This is k = f(xᵢ, tᵢ).
402
- return f(xi, ti)[0]
410
+ # This is kᵢ, aux_dict = f(xᵢ, tᵢ).
411
+ return f(xi, ti)
403
412
 
404
413
  # This selector enables FSAL property in the first iteration (i=0).
405
- ki = jax.lax.cond(
414
+ ki, aux_dict = jax.lax.cond(
406
415
  pred=jnp.logical_and(i == 0, self.has_fsal),
407
- true_fun=get_ẋ0,
416
+ true_fun=get_ẋ0_and_aux_dict,
408
417
  false_fun=compute_ki,
409
418
  )
410
419
 
@@ -413,10 +422,10 @@ class ExplicitRungeKutta(Integrator[PyTreeType, PyTreeType], Generic[PyTreeType]
413
422
  K = jax.tree_util.tree_map(op, K, ki)
414
423
 
415
424
  carry = K
416
- return carry, None
425
+ return carry, aux_dict
417
426
 
418
427
  # Compute the state derivatives kᵢ.
419
- K, _ = jax.lax.scan(
428
+ K, aux_dict = jax.lax.scan(
420
429
  f=scan_body,
421
430
  init=carry0,
422
431
  xs=jnp.arange(c.size),
@@ -439,7 +448,7 @@ class ExplicitRungeKutta(Integrator[PyTreeType, PyTreeType], Generic[PyTreeType]
439
448
  lambda xf: self.post_process_state(x0=x0, t0=t0, xf=xf, dt=dt)
440
449
  )(z)
441
450
 
442
- return z_transformed
451
+ return z_transformed, aux_dict
443
452
 
444
453
  @staticmethod
445
454
  def butcher_tableau_is_valid(
@@ -1,5 +1,5 @@
1
1
  import functools
2
- from typing import Any, ClassVar, Generic, Type
2
+ from typing import Any, ClassVar, Generic
3
3
 
4
4
  try:
5
5
  from typing import Self
@@ -495,7 +495,7 @@ class EmbeddedRungeKutta(ExplicitRungeKutta[PyTreeType], Generic[PyTreeType]):
495
495
 
496
496
  @classmethod
497
497
  def build(
498
- cls: Type[Self],
498
+ cls: type[Self],
499
499
  *,
500
500
  dynamics: SystemDynamics[State, StateDerivative],
501
501
  fsal_enabled_if_supported: jtp.BoolLike = True,
jaxsim/logging.py CHANGED
@@ -1,6 +1,5 @@
1
1
  import enum
2
2
  import logging
3
- from typing import Union
4
3
 
5
4
  import coloredlogs
6
5
 
@@ -20,7 +19,7 @@ def _logger() -> logging.Logger:
20
19
  return logging.getLogger(name=LOGGER_NAME)
21
20
 
22
21
 
23
- def set_logging_level(level: Union[int, LoggingLevel] = LoggingLevel.WARNING):
22
+ def set_logging_level(level: int | LoggingLevel = LoggingLevel.WARNING):
24
23
  if isinstance(level, int):
25
24
  level = LoggingLevel(level)
26
25
 
jaxsim/math/inertia.py CHANGED
@@ -1,5 +1,3 @@
1
- from typing import Tuple
2
-
3
1
  import jax.numpy as jnp
4
2
 
5
3
  import jaxsim.typing as jtp
@@ -39,7 +37,7 @@ class Inertia:
39
37
  return M
40
38
 
41
39
  @staticmethod
42
- def to_params(M: jtp.Matrix) -> Tuple[jtp.Float, jtp.Vector, jtp.Matrix]:
40
+ def to_params(M: jtp.Matrix) -> tuple[jtp.Float, jtp.Vector, jtp.Matrix]:
43
41
  """
44
42
  Convert a 6x6 inertia matrix to mass, center of mass, and inertia matrix.
45
43
 
@@ -107,7 +107,7 @@ class JointModel:
107
107
  λ_H_pre=λ_H_pre,
108
108
  suc_H_i=suc_H_i,
109
109
  # Static attributes
110
- joint_dofs=tuple([base_dofs] + [int(1) for _ in ordered_joints]),
110
+ joint_dofs=tuple([base_dofs] + [1 for _ in ordered_joints]),
111
111
  joint_names=tuple(["world_to_base"] + [j.name for j in ordered_joints]),
112
112
  joint_types=tuple([JointType.Fixed] + [j.jtype for j in ordered_joints]),
113
113
  joint_axis=tuple(JointGenericAxis(axis=j.axis) for j in ordered_joints),
jaxsim/math/rotation.py CHANGED
@@ -1,5 +1,3 @@
1
- from typing import Tuple
2
-
3
1
  import jax
4
2
  import jax.numpy as jnp
5
3
  import jaxlie
@@ -64,7 +62,7 @@ class Rotation:
64
62
  vector = vector.squeeze()
65
63
  theta = jnp.linalg.norm(vector)
66
64
 
67
- def theta_is_not_zero(theta_and_v: Tuple[jtp.Float, jtp.Vector]) -> jtp.Matrix:
65
+ def theta_is_not_zero(theta_and_v: tuple[jtp.Float, jtp.Vector]) -> jtp.Matrix:
68
66
  theta, v = theta_and_v
69
67
 
70
68
  s = jnp.sin(theta)
jaxsim/mujoco/loaders.py CHANGED
@@ -4,7 +4,8 @@ import dataclasses
4
4
  import pathlib
5
5
  import tempfile
6
6
  import warnings
7
- from typing import Any, Sequence
7
+ from collections.abc import Sequence
8
+ from typing import Any
8
9
 
9
10
  import mujoco as mj
10
11
  import numpy as np
jaxsim/mujoco/model.py CHANGED
@@ -2,7 +2,8 @@ from __future__ import annotations
2
2
 
3
3
  import functools
4
4
  import pathlib
5
- from typing import Any, Callable
5
+ from collections.abc import Callable
6
+ from typing import Any
6
7
 
7
8
  import mujoco as mj
8
9
  import numpy as np
@@ -1,6 +1,6 @@
1
1
  import contextlib
2
2
  import pathlib
3
- from typing import ContextManager, Sequence
3
+ from collections.abc import Sequence
4
4
 
5
5
  import mediapy as media
6
6
  import mujoco as mj
@@ -172,7 +172,7 @@ class MujocoVisualizer:
172
172
  distance: float | int | npt.NDArray | None = None,
173
173
  azimut: float | int | npt.NDArray | None = None,
174
174
  elevation: float | int | npt.NDArray | None = None,
175
- ) -> ContextManager[mujoco.viewer.Handle]:
175
+ ) -> contextlib.AbstractContextManager[mujoco.viewer.Handle]:
176
176
  """
177
177
  Context manager to open the Mujoco passive viewer.
178
178
 
@@ -2,7 +2,7 @@ from __future__ import annotations
2
2
 
3
3
  import dataclasses
4
4
  import itertools
5
- from typing import Sequence
5
+ from collections.abc import Sequence
6
6
 
7
7
  from jaxsim import logging
8
8
 
@@ -3,7 +3,8 @@ from __future__ import annotations
3
3
  import copy
4
4
  import dataclasses
5
5
  import functools
6
- from typing import Any, Callable, Iterable, Sequence
6
+ from collections.abc import Callable, Iterable, Sequence
7
+ from typing import Any
7
8
 
8
9
  import numpy as np
9
10
  import numpy.typing as npt
@@ -444,7 +445,7 @@ class KinematicGraph(Sequence[LinkDescription]):
444
445
  msg.format(
445
446
  link_to_remove.name,
446
447
  self.joints_connection_dict[
447
- (parent_of_link_to_remove.name, link_to_remove.name)
448
+ parent_of_link_to_remove.name, link_to_remove.name
448
449
  ].name,
449
450
  parent_of_link_to_remove.name,
450
451
  )
@@ -852,7 +853,7 @@ class KinematicGraphTransforms:
852
853
 
853
854
  # Get the joint between the link and its parent.
854
855
  parent_joint = self.graph.joints_connection_dict[
855
- (link.parent.name, link.name)
856
+ link.parent.name, link.name
856
857
  ]
857
858
 
858
859
  # Get the transform of the parent joint.
@@ -1,6 +1,6 @@
1
1
  import dataclasses
2
2
  import pathlib
3
- from typing import Dict, List, NamedTuple, Optional, Union
3
+ from typing import NamedTuple
4
4
 
5
5
  import jax.numpy as jnp
6
6
  import numpy as np
@@ -23,17 +23,17 @@ class SDFData(NamedTuple):
23
23
  fixed_base: bool
24
24
  base_link_name: str
25
25
 
26
- link_descriptions: List[descriptions.LinkDescription]
27
- joint_descriptions: List[descriptions.JointDescription]
28
- frame_descriptions: List[descriptions.LinkDescription]
29
- collision_shapes: List[descriptions.CollisionShape]
26
+ link_descriptions: list[descriptions.LinkDescription]
27
+ joint_descriptions: list[descriptions.JointDescription]
28
+ frame_descriptions: list[descriptions.LinkDescription]
29
+ collision_shapes: list[descriptions.CollisionShape]
30
30
 
31
31
  sdf_model: rod.Model | None = None
32
32
  model_pose: kinematic_graph.RootPose = kinematic_graph.RootPose()
33
33
 
34
34
 
35
35
  def extract_model_data(
36
- model_description: Union[pathlib.Path, str, rod.Model],
36
+ model_description: pathlib.Path | str | rod.Model,
37
37
  model_name: str | None = None,
38
38
  is_urdf: bool | None = None,
39
39
  ) -> SDFData:
@@ -114,7 +114,7 @@ def extract_model_data(
114
114
  ]
115
115
 
116
116
  # Create a dictionary to find easily links.
117
- links_dict: Dict[str, descriptions.LinkDescription] = {l.name: l for l in links}
117
+ links_dict: dict[str, descriptions.LinkDescription] = {l.name: l for l in links}
118
118
 
119
119
  # ============
120
120
  # Parse frames
@@ -304,7 +304,7 @@ def extract_model_data(
304
304
  # ================
305
305
 
306
306
  # Initialize the collision shapes
307
- collisions: List[descriptions.CollisionShape] = []
307
+ collisions: list[descriptions.CollisionShape] = []
308
308
 
309
309
  # Parse the collisions
310
310
  for link in sdf_model.links():
@@ -339,8 +339,8 @@ def extract_model_data(
339
339
 
340
340
 
341
341
  def build_model_description(
342
- model_description: Union[pathlib.Path, str, rod.Model],
343
- is_urdf: Optional[bool] = False,
342
+ model_description: pathlib.Path | str | rod.Model,
343
+ is_urdf: bool | None = False,
344
344
  ) -> descriptions.ModelDescription:
345
345
  """
346
346
  Builds a model description from an SDF/URDF resource.
@@ -5,6 +5,7 @@ from typing import Any
5
5
 
6
6
  import jaxsim.terrain
7
7
  import jaxsim.typing as jtp
8
+ from jaxsim.utils import JaxsimDataclass
8
9
 
9
10
 
10
11
  class ContactsState(abc.ABC):
@@ -42,7 +43,7 @@ class ContactsState(abc.ABC):
42
43
  pass
43
44
 
44
45
 
45
- class ContactsParams(abc.ABC):
46
+ class ContactsParams(JaxsimDataclass):
46
47
  """
47
48
  Abstract class representing the parameters of a contact model.
48
49
  """
@@ -67,7 +68,7 @@ class ContactsParams(abc.ABC):
67
68
  pass
68
69
 
69
70
 
70
- class ContactModel(abc.ABC):
71
+ class ContactModel(JaxsimDataclass):
71
72
  """
72
73
  Abstract class representing a contact model.
73
74