jaxsim 0.4.3.dev12__py3-none-any.whl → 0.4.3.dev18__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jaxsim/_version.py +2 -2
- jaxsim/api/common.py +2 -2
- jaxsim/api/contact.py +37 -9
- jaxsim/api/data.py +1 -1
- jaxsim/api/frame.py +1 -1
- jaxsim/api/joint.py +1 -1
- jaxsim/api/link.py +1 -1
- jaxsim/api/model.py +62 -8
- jaxsim/api/ode.py +114 -36
- jaxsim/api/ode_data.py +11 -7
- jaxsim/integrators/common.py +30 -21
- jaxsim/integrators/variable_step.py +2 -2
- jaxsim/logging.py +1 -2
- jaxsim/math/inertia.py +1 -3
- jaxsim/math/joint_model.py +1 -1
- jaxsim/math/rotation.py +1 -3
- jaxsim/mujoco/loaders.py +2 -1
- jaxsim/mujoco/model.py +2 -1
- jaxsim/mujoco/visualizer.py +2 -2
- jaxsim/parsers/descriptions/model.py +1 -1
- jaxsim/parsers/kinematic_graph.py +4 -3
- jaxsim/parsers/rod/parser.py +10 -10
- jaxsim/rbda/contacts/common.py +3 -2
- jaxsim/rbda/contacts/rigid.py +478 -0
- jaxsim/rbda/rnea.py +5 -7
- jaxsim/utils/jaxsim_dataclass.py +3 -3
- jaxsim/utils/wrappers.py +2 -1
- {jaxsim-0.4.3.dev12.dist-info → jaxsim-0.4.3.dev18.dist-info}/METADATA +2 -1
- {jaxsim-0.4.3.dev12.dist-info → jaxsim-0.4.3.dev18.dist-info}/RECORD +32 -31
- {jaxsim-0.4.3.dev12.dist-info → jaxsim-0.4.3.dev18.dist-info}/WHEEL +1 -1
- {jaxsim-0.4.3.dev12.dist-info → jaxsim-0.4.3.dev18.dist-info}/LICENSE +0 -0
- {jaxsim-0.4.3.dev12.dist-info → jaxsim-0.4.3.dev18.dist-info}/top_level.txt +0 -0
jaxsim/integrators/common.py
CHANGED
@@ -1,6 +1,6 @@
|
|
1
1
|
import abc
|
2
2
|
import dataclasses
|
3
|
-
from typing import Any, ClassVar, Generic, Protocol,
|
3
|
+
from typing import Any, ClassVar, Generic, Protocol, TypeVar
|
4
4
|
|
5
5
|
import jax
|
6
6
|
import jax.numpy as jnp
|
@@ -64,7 +64,7 @@ class Integrator(JaxsimDataclass, abc.ABC, Generic[State, StateDerivative]):
|
|
64
64
|
|
65
65
|
@classmethod
|
66
66
|
def build(
|
67
|
-
cls:
|
67
|
+
cls: type[Self],
|
68
68
|
*,
|
69
69
|
dynamics: SystemDynamics[State, StateDerivative],
|
70
70
|
**kwargs,
|
@@ -109,11 +109,14 @@ class Integrator(JaxsimDataclass, abc.ABC, Generic[State, StateDerivative]):
|
|
109
109
|
integrator.params = params
|
110
110
|
|
111
111
|
with integrator.mutable_context(mutability=Mutability.MUTABLE):
|
112
|
-
xf = integrator(x0, t0, dt, **kwargs)
|
112
|
+
xf, aux_dict = integrator(x0, t0, dt, **kwargs)
|
113
113
|
|
114
|
-
return
|
115
|
-
|
116
|
-
|
114
|
+
return (
|
115
|
+
xf,
|
116
|
+
integrator.params
|
117
|
+
| {Integrator.AfterInitKey: jnp.array(False).astype(bool)}
|
118
|
+
| aux_dict,
|
119
|
+
)
|
117
120
|
|
118
121
|
@abc.abstractmethod
|
119
122
|
def __call__(self, x0: State, t0: Time, dt: TimeStep, **kwargs) -> NextState:
|
@@ -224,7 +227,7 @@ class ExplicitRungeKutta(Integrator[PyTreeType, PyTreeType], Generic[PyTreeType]
|
|
224
227
|
@override
|
225
228
|
@classmethod
|
226
229
|
def build(
|
227
|
-
cls:
|
230
|
+
cls: type[Self],
|
228
231
|
*,
|
229
232
|
dynamics: SystemDynamics[State, StateDerivative],
|
230
233
|
fsal_enabled_if_supported: jtp.BoolLike = True,
|
@@ -277,15 +280,19 @@ class ExplicitRungeKutta(Integrator[PyTreeType, PyTreeType], Generic[PyTreeType]
|
|
277
280
|
|
278
281
|
return integrator
|
279
282
|
|
280
|
-
def __call__(
|
283
|
+
def __call__(
|
284
|
+
self, x0: State, t0: Time, dt: TimeStep, **kwargs
|
285
|
+
) -> tuple[NextState, dict[str, Any]]:
|
281
286
|
|
282
287
|
# Here z is a batched state with as many batch elements as b.T rows.
|
283
288
|
# Note that z has multiple batches only if b.T has more than one row,
|
284
289
|
# e.g. in Butcher tableau of embedded schemes.
|
285
|
-
z = self._compute_next_state(x0=x0, t0=t0, dt=dt, **kwargs)
|
290
|
+
z, aux_dict = self._compute_next_state(x0=x0, t0=t0, dt=dt, **kwargs)
|
286
291
|
|
287
292
|
# The next state is the batch element located at the configured index of solution.
|
288
|
-
|
293
|
+
next_state = jax.tree_util.tree_map(lambda l: l[self.row_index_of_solution], z)
|
294
|
+
|
295
|
+
return next_state, aux_dict
|
289
296
|
|
290
297
|
@classmethod
|
291
298
|
def integrate_rk_stage(
|
@@ -343,7 +350,7 @@ class ExplicitRungeKutta(Integrator[PyTreeType, PyTreeType], Generic[PyTreeType]
|
|
343
350
|
|
344
351
|
def _compute_next_state(
|
345
352
|
self, x0: State, t0: Time, dt: TimeStep, **kwargs
|
346
|
-
) -> NextState:
|
353
|
+
) -> tuple[NextState, dict[str, Any]]:
|
347
354
|
"""
|
348
355
|
Compute the next state of the system, returning all the output states.
|
349
356
|
|
@@ -373,19 +380,21 @@ class ExplicitRungeKutta(Integrator[PyTreeType, PyTreeType], Generic[PyTreeType]
|
|
373
380
|
)
|
374
381
|
|
375
382
|
# Apply FSAL property by passing ẋ0 = f(x0, t0) from the previous iteration.
|
376
|
-
get_ẋ
|
383
|
+
get_ẋ0_and_aux_dict = lambda: self.params.get("dxdt0", f(x0, t0))
|
377
384
|
|
378
385
|
# We use a `jax.lax.scan` to compile the `f` function only once.
|
379
386
|
# Otherwise, if we compute e.g. for RK4 sequentially, the jit-compiled code
|
380
387
|
# would include 4 repetitions of the `f` logic, making everything extremely slow.
|
381
|
-
def scan_body(
|
388
|
+
def scan_body(
|
389
|
+
carry: jax.Array, i: int | jax.Array
|
390
|
+
) -> tuple[jax.Array, dict[str, Any]]:
|
382
391
|
""""""
|
383
392
|
|
384
393
|
# Unpack the carry, i.e. the stacked kᵢ vectors.
|
385
394
|
K = carry
|
386
395
|
|
387
396
|
# Define the computation of the Runge-Kutta stage.
|
388
|
-
def compute_ki() -> jax.Array:
|
397
|
+
def compute_ki() -> tuple[jax.Array, dict[str, Any]]:
|
389
398
|
|
390
399
|
# Compute ∑ⱼ aᵢⱼ kⱼ.
|
391
400
|
op_sum_ak = lambda k: jnp.einsum("s,s...->...", A[i], k)
|
@@ -398,13 +407,13 @@ class ExplicitRungeKutta(Integrator[PyTreeType, PyTreeType], Generic[PyTreeType]
|
|
398
407
|
# Compute the next time for the kᵢ evaluation.
|
399
408
|
ti = t0 + c[i] * Δt
|
400
409
|
|
401
|
-
# This is k
|
402
|
-
return f(xi, ti)
|
410
|
+
# This is kᵢ, aux_dict = f(xᵢ, tᵢ).
|
411
|
+
return f(xi, ti)
|
403
412
|
|
404
413
|
# This selector enables FSAL property in the first iteration (i=0).
|
405
|
-
ki = jax.lax.cond(
|
414
|
+
ki, aux_dict = jax.lax.cond(
|
406
415
|
pred=jnp.logical_and(i == 0, self.has_fsal),
|
407
|
-
true_fun=get_ẋ
|
416
|
+
true_fun=get_ẋ0_and_aux_dict,
|
408
417
|
false_fun=compute_ki,
|
409
418
|
)
|
410
419
|
|
@@ -413,10 +422,10 @@ class ExplicitRungeKutta(Integrator[PyTreeType, PyTreeType], Generic[PyTreeType]
|
|
413
422
|
K = jax.tree_util.tree_map(op, K, ki)
|
414
423
|
|
415
424
|
carry = K
|
416
|
-
return carry,
|
425
|
+
return carry, aux_dict
|
417
426
|
|
418
427
|
# Compute the state derivatives kᵢ.
|
419
|
-
K,
|
428
|
+
K, aux_dict = jax.lax.scan(
|
420
429
|
f=scan_body,
|
421
430
|
init=carry0,
|
422
431
|
xs=jnp.arange(c.size),
|
@@ -439,7 +448,7 @@ class ExplicitRungeKutta(Integrator[PyTreeType, PyTreeType], Generic[PyTreeType]
|
|
439
448
|
lambda xf: self.post_process_state(x0=x0, t0=t0, xf=xf, dt=dt)
|
440
449
|
)(z)
|
441
450
|
|
442
|
-
return z_transformed
|
451
|
+
return z_transformed, aux_dict
|
443
452
|
|
444
453
|
@staticmethod
|
445
454
|
def butcher_tableau_is_valid(
|
@@ -1,5 +1,5 @@
|
|
1
1
|
import functools
|
2
|
-
from typing import Any, ClassVar, Generic
|
2
|
+
from typing import Any, ClassVar, Generic
|
3
3
|
|
4
4
|
try:
|
5
5
|
from typing import Self
|
@@ -495,7 +495,7 @@ class EmbeddedRungeKutta(ExplicitRungeKutta[PyTreeType], Generic[PyTreeType]):
|
|
495
495
|
|
496
496
|
@classmethod
|
497
497
|
def build(
|
498
|
-
cls:
|
498
|
+
cls: type[Self],
|
499
499
|
*,
|
500
500
|
dynamics: SystemDynamics[State, StateDerivative],
|
501
501
|
fsal_enabled_if_supported: jtp.BoolLike = True,
|
jaxsim/logging.py
CHANGED
@@ -1,6 +1,5 @@
|
|
1
1
|
import enum
|
2
2
|
import logging
|
3
|
-
from typing import Union
|
4
3
|
|
5
4
|
import coloredlogs
|
6
5
|
|
@@ -20,7 +19,7 @@ def _logger() -> logging.Logger:
|
|
20
19
|
return logging.getLogger(name=LOGGER_NAME)
|
21
20
|
|
22
21
|
|
23
|
-
def set_logging_level(level:
|
22
|
+
def set_logging_level(level: int | LoggingLevel = LoggingLevel.WARNING):
|
24
23
|
if isinstance(level, int):
|
25
24
|
level = LoggingLevel(level)
|
26
25
|
|
jaxsim/math/inertia.py
CHANGED
@@ -1,5 +1,3 @@
|
|
1
|
-
from typing import Tuple
|
2
|
-
|
3
1
|
import jax.numpy as jnp
|
4
2
|
|
5
3
|
import jaxsim.typing as jtp
|
@@ -39,7 +37,7 @@ class Inertia:
|
|
39
37
|
return M
|
40
38
|
|
41
39
|
@staticmethod
|
42
|
-
def to_params(M: jtp.Matrix) ->
|
40
|
+
def to_params(M: jtp.Matrix) -> tuple[jtp.Float, jtp.Vector, jtp.Matrix]:
|
43
41
|
"""
|
44
42
|
Convert a 6x6 inertia matrix to mass, center of mass, and inertia matrix.
|
45
43
|
|
jaxsim/math/joint_model.py
CHANGED
@@ -107,7 +107,7 @@ class JointModel:
|
|
107
107
|
λ_H_pre=λ_H_pre,
|
108
108
|
suc_H_i=suc_H_i,
|
109
109
|
# Static attributes
|
110
|
-
joint_dofs=tuple([base_dofs] + [
|
110
|
+
joint_dofs=tuple([base_dofs] + [1 for _ in ordered_joints]),
|
111
111
|
joint_names=tuple(["world_to_base"] + [j.name for j in ordered_joints]),
|
112
112
|
joint_types=tuple([JointType.Fixed] + [j.jtype for j in ordered_joints]),
|
113
113
|
joint_axis=tuple(JointGenericAxis(axis=j.axis) for j in ordered_joints),
|
jaxsim/math/rotation.py
CHANGED
@@ -1,5 +1,3 @@
|
|
1
|
-
from typing import Tuple
|
2
|
-
|
3
1
|
import jax
|
4
2
|
import jax.numpy as jnp
|
5
3
|
import jaxlie
|
@@ -64,7 +62,7 @@ class Rotation:
|
|
64
62
|
vector = vector.squeeze()
|
65
63
|
theta = jnp.linalg.norm(vector)
|
66
64
|
|
67
|
-
def theta_is_not_zero(theta_and_v:
|
65
|
+
def theta_is_not_zero(theta_and_v: tuple[jtp.Float, jtp.Vector]) -> jtp.Matrix:
|
68
66
|
theta, v = theta_and_v
|
69
67
|
|
70
68
|
s = jnp.sin(theta)
|
jaxsim/mujoco/loaders.py
CHANGED
jaxsim/mujoco/model.py
CHANGED
jaxsim/mujoco/visualizer.py
CHANGED
@@ -1,6 +1,6 @@
|
|
1
1
|
import contextlib
|
2
2
|
import pathlib
|
3
|
-
from
|
3
|
+
from collections.abc import Sequence
|
4
4
|
|
5
5
|
import mediapy as media
|
6
6
|
import mujoco as mj
|
@@ -172,7 +172,7 @@ class MujocoVisualizer:
|
|
172
172
|
distance: float | int | npt.NDArray | None = None,
|
173
173
|
azimut: float | int | npt.NDArray | None = None,
|
174
174
|
elevation: float | int | npt.NDArray | None = None,
|
175
|
-
) ->
|
175
|
+
) -> contextlib.AbstractContextManager[mujoco.viewer.Handle]:
|
176
176
|
"""
|
177
177
|
Context manager to open the Mujoco passive viewer.
|
178
178
|
|
@@ -3,7 +3,8 @@ from __future__ import annotations
|
|
3
3
|
import copy
|
4
4
|
import dataclasses
|
5
5
|
import functools
|
6
|
-
from
|
6
|
+
from collections.abc import Callable, Iterable, Sequence
|
7
|
+
from typing import Any
|
7
8
|
|
8
9
|
import numpy as np
|
9
10
|
import numpy.typing as npt
|
@@ -444,7 +445,7 @@ class KinematicGraph(Sequence[LinkDescription]):
|
|
444
445
|
msg.format(
|
445
446
|
link_to_remove.name,
|
446
447
|
self.joints_connection_dict[
|
447
|
-
|
448
|
+
parent_of_link_to_remove.name, link_to_remove.name
|
448
449
|
].name,
|
449
450
|
parent_of_link_to_remove.name,
|
450
451
|
)
|
@@ -852,7 +853,7 @@ class KinematicGraphTransforms:
|
|
852
853
|
|
853
854
|
# Get the joint between the link and its parent.
|
854
855
|
parent_joint = self.graph.joints_connection_dict[
|
855
|
-
|
856
|
+
link.parent.name, link.name
|
856
857
|
]
|
857
858
|
|
858
859
|
# Get the transform of the parent joint.
|
jaxsim/parsers/rod/parser.py
CHANGED
@@ -1,6 +1,6 @@
|
|
1
1
|
import dataclasses
|
2
2
|
import pathlib
|
3
|
-
from typing import
|
3
|
+
from typing import NamedTuple
|
4
4
|
|
5
5
|
import jax.numpy as jnp
|
6
6
|
import numpy as np
|
@@ -23,17 +23,17 @@ class SDFData(NamedTuple):
|
|
23
23
|
fixed_base: bool
|
24
24
|
base_link_name: str
|
25
25
|
|
26
|
-
link_descriptions:
|
27
|
-
joint_descriptions:
|
28
|
-
frame_descriptions:
|
29
|
-
collision_shapes:
|
26
|
+
link_descriptions: list[descriptions.LinkDescription]
|
27
|
+
joint_descriptions: list[descriptions.JointDescription]
|
28
|
+
frame_descriptions: list[descriptions.LinkDescription]
|
29
|
+
collision_shapes: list[descriptions.CollisionShape]
|
30
30
|
|
31
31
|
sdf_model: rod.Model | None = None
|
32
32
|
model_pose: kinematic_graph.RootPose = kinematic_graph.RootPose()
|
33
33
|
|
34
34
|
|
35
35
|
def extract_model_data(
|
36
|
-
model_description:
|
36
|
+
model_description: pathlib.Path | str | rod.Model,
|
37
37
|
model_name: str | None = None,
|
38
38
|
is_urdf: bool | None = None,
|
39
39
|
) -> SDFData:
|
@@ -114,7 +114,7 @@ def extract_model_data(
|
|
114
114
|
]
|
115
115
|
|
116
116
|
# Create a dictionary to find easily links.
|
117
|
-
links_dict:
|
117
|
+
links_dict: dict[str, descriptions.LinkDescription] = {l.name: l for l in links}
|
118
118
|
|
119
119
|
# ============
|
120
120
|
# Parse frames
|
@@ -304,7 +304,7 @@ def extract_model_data(
|
|
304
304
|
# ================
|
305
305
|
|
306
306
|
# Initialize the collision shapes
|
307
|
-
collisions:
|
307
|
+
collisions: list[descriptions.CollisionShape] = []
|
308
308
|
|
309
309
|
# Parse the collisions
|
310
310
|
for link in sdf_model.links():
|
@@ -339,8 +339,8 @@ def extract_model_data(
|
|
339
339
|
|
340
340
|
|
341
341
|
def build_model_description(
|
342
|
-
model_description:
|
343
|
-
is_urdf:
|
342
|
+
model_description: pathlib.Path | str | rod.Model,
|
343
|
+
is_urdf: bool | None = False,
|
344
344
|
) -> descriptions.ModelDescription:
|
345
345
|
"""
|
346
346
|
Builds a model description from an SDF/URDF resource.
|
jaxsim/rbda/contacts/common.py
CHANGED
@@ -5,6 +5,7 @@ from typing import Any
|
|
5
5
|
|
6
6
|
import jaxsim.terrain
|
7
7
|
import jaxsim.typing as jtp
|
8
|
+
from jaxsim.utils import JaxsimDataclass
|
8
9
|
|
9
10
|
|
10
11
|
class ContactsState(abc.ABC):
|
@@ -42,7 +43,7 @@ class ContactsState(abc.ABC):
|
|
42
43
|
pass
|
43
44
|
|
44
45
|
|
45
|
-
class ContactsParams(
|
46
|
+
class ContactsParams(JaxsimDataclass):
|
46
47
|
"""
|
47
48
|
Abstract class representing the parameters of a contact model.
|
48
49
|
"""
|
@@ -67,7 +68,7 @@ class ContactsParams(abc.ABC):
|
|
67
68
|
pass
|
68
69
|
|
69
70
|
|
70
|
-
class ContactModel(
|
71
|
+
class ContactModel(JaxsimDataclass):
|
71
72
|
"""
|
72
73
|
Abstract class representing a contact model.
|
73
74
|
|