jaxsim 0.4.3.dev177__py3-none-any.whl → 0.4.3.dev186__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
jaxsim/__init__.py CHANGED
@@ -8,21 +8,35 @@ def _jnp_options() -> None:
8
8
 
9
9
  import jax
10
10
 
11
- # Enable by default 64bit precision in JAX.
12
- if os.environ.get("JAX_ENABLE_X64", "1") != "0":
13
-
14
- logging.info("Enabling JAX to use 64bit precision")
11
+ # Check if running on TPU
12
+ is_tpu = jax.devices()[0].platform == "tpu"
13
+
14
+ # Enable by default 64-bit precision to get accurate physics.
15
+ # Users can enforce 32-bit precision by setting the following variable to 0.
16
+ use_x64 = os.environ.get("JAX_ENABLE_X64", "1") != "0"
17
+
18
+ # Notify the user if unsupported 64-bit precision was enforced on TPU.
19
+ if is_tpu and use_x64:
20
+ msg = "64-bit precision is not allowed on TPU. Enforcing 32bit precision."
21
+ logging.warning(msg)
22
+ use_x64 = False
23
+
24
+ # Enable 64-bit precision in JAX.
25
+ if use_x64:
26
+ logging.info("Enabling JAX to use 64-bit precision")
15
27
  jax.config.update("jax_enable_x64", True)
16
28
 
17
29
  import jax.numpy as jnp
18
30
  import numpy as np
19
31
 
32
+ # Verify that 64-bit precision is correctly set.
20
33
  if jnp.empty(0, dtype=float).dtype != jnp.empty(0, dtype=np.float64).dtype:
21
- logging.warning("Failed to enable 64bit precision in JAX")
34
+ logging.warning("Failed to enable 64-bit precision in JAX")
22
35
 
36
+ # Warn about experimental usage of 32-bit precision.
23
37
  else:
24
38
  logging.warning(
25
- "Using 32bit precision in JaxSim is still experimental, please avoid to use variable step integrators."
39
+ "Using 32-bit precision in JaxSim is still experimental, please avoid to use variable step integrators."
26
40
  )
27
41
 
28
42
 
jaxsim/_version.py CHANGED
@@ -12,5 +12,5 @@ __version__: str
12
12
  __version_tuple__: VERSION_TUPLE
13
13
  version_tuple: VERSION_TUPLE
14
14
 
15
- __version__ = version = '0.4.3.dev177'
16
- __version_tuple__ = version_tuple = (0, 4, 3, 'dev177')
15
+ __version__ = version = '0.4.3.dev186'
16
+ __version_tuple__ = version_tuple = (0, 4, 3, 'dev186')
jaxsim/exceptions.py CHANGED
@@ -17,6 +17,10 @@ def raise_if(
17
17
  format string (fmt), whose fields are filled with the args and kwargs.
18
18
  """
19
19
 
20
+ # Disable host callback if running on TPU.
21
+ if jax.devices()[0].platform == "tpu":
22
+ return
23
+
20
24
  # Check early that the format string is well-formed.
21
25
  try:
22
26
  _ = msg.format(*args, **kwargs)
jaxsim/mujoco/__init__.py CHANGED
@@ -1,3 +1,4 @@
1
1
  from .loaders import RodModelToMjcf, SdfToMjcf, UrdfToMjcf
2
2
  from .model import MujocoModelHelper
3
+ from .utils import mujoco_data_from_jaxsim
3
4
  from .visualizer import MujocoVideoRecorder, MujocoVisualizer
jaxsim/mujoco/loaders.py CHANGED
@@ -646,7 +646,7 @@ class MujocoCamera:
646
646
  def build(cls, **kwargs) -> MujocoCamera:
647
647
 
648
648
  if not all(isinstance(value, str) for value in kwargs.values()):
649
- raise ValueError("Values must be strings")
649
+ raise ValueError(f"Values must be strings: {kwargs}")
650
650
 
651
651
  return cls(**kwargs)
652
652
 
jaxsim/mujoco/model.py CHANGED
@@ -2,7 +2,7 @@ from __future__ import annotations
2
2
 
3
3
  import functools
4
4
  import pathlib
5
- from collections.abc import Callable
5
+ from collections.abc import Callable, Sequence
6
6
  from typing import Any
7
7
 
8
8
  import mujoco as mj
@@ -107,7 +107,8 @@ class MujocoModelHelper:
107
107
  size = [float(el) for el in hfield_element["@size"].split(" ")]
108
108
  size[0], size[1] = heightmap_radius_xy
109
109
  size[2] = 1.0
110
- size[3] = max(0, -min(hfield))
110
+ # The following could be zero but Mujoco complains if it's exactly zero.
111
+ size[3] = max(0.000_001, -min(hfield))
111
112
 
112
113
  # Replace the 'size' attribute.
113
114
  hfields_dict[heightmap_name]["@size"] = " ".join(str(el) for el in size)
@@ -315,7 +316,7 @@ class MujocoModelHelper:
315
316
  self.data.qpos[sl] = position
316
317
 
317
318
  def set_joint_positions(
318
- self, joint_names: list[str], positions: npt.NDArray | list[npt.NDArray]
319
+ self, joint_names: Sequence[str], positions: npt.NDArray | list[npt.NDArray]
319
320
  ) -> None:
320
321
  """Set the positions of multiple joints."""
321
322
 
jaxsim/mujoco/utils.py ADDED
@@ -0,0 +1,101 @@
1
+ import mujoco as mj
2
+ import numpy as np
3
+
4
+ from . import MujocoModelHelper
5
+
6
+
7
+ def mujoco_data_from_jaxsim(
8
+ mujoco_model: mj.MjModel,
9
+ jaxsim_model,
10
+ jaxsim_data,
11
+ mujoco_data: mj.MjData | None = None,
12
+ update_removed_joints: bool = True,
13
+ ) -> mj.MjData:
14
+ """
15
+ Create a Mujoco data object from a JaxSim model and data objects.
16
+
17
+ Args:
18
+ mujoco_model: The Mujoco model object corresponding to the JaxSim model.
19
+ jaxsim_model: The JaxSim model object from which the Mujoco model was created.
20
+ jaxsim_data: The JaxSim data object containing the state of the model.
21
+ mujoco_data: An optional Mujoco data object. If None, a new one will be created.
22
+ update_removed_joints:
23
+ If True, the positions of the joints that have been removed during the
24
+ model reduction process will be set to their initial values.
25
+
26
+ Returns:
27
+ The Mujoco data object containing the state of the JaxSim model.
28
+
29
+ Note:
30
+ This method is useful to initialize a Mujoco data object used for visualization
31
+ with the state of a JaxSim model. In particular, this function takes care of
32
+ initializing the positions of the joints that have been removed during the
33
+ model reduction process. After the initial creation of the Mujoco data object,
34
+ it's faster to update the state using an external MujocoModelHelper object.
35
+ """
36
+
37
+ # The package `jaxsim.mujoco` is supposed to be jax-independent.
38
+ # We import all the JaxSim resources privately.
39
+ import jaxsim.api as js
40
+
41
+ if not isinstance(jaxsim_model, js.model.JaxSimModel):
42
+ raise ValueError("The `jaxsim_model` argument must be a JaxSimModel object.")
43
+
44
+ if not isinstance(jaxsim_data, js.data.JaxSimModelData):
45
+ raise ValueError("The `jaxsim_data` argument must be a JaxSimModelData object.")
46
+
47
+ # Create the helper to operate on the Mujoco model and data.
48
+ model_helper = MujocoModelHelper(model=mujoco_model, data=mujoco_data)
49
+
50
+ # If the model is fixed-base, the Mujoco model won't have the joint corresponding
51
+ # to the floating base, and the helper would raise an exception.
52
+ if jaxsim_model.floating_base():
53
+
54
+ # Set the model position.
55
+ model_helper.set_base_position(position=np.array(jaxsim_data.base_position()))
56
+
57
+ # Set the model orientation.
58
+ model_helper.set_base_orientation(
59
+ orientation=np.array(jaxsim_data.base_orientation())
60
+ )
61
+
62
+ # Set the joint positions.
63
+ if jaxsim_model.dofs() > 0:
64
+
65
+ model_helper.set_joint_positions(
66
+ joint_names=list(jaxsim_model.joint_names()),
67
+ positions=np.array(
68
+ jaxsim_data.joint_positions(
69
+ model=jaxsim_model, joint_names=jaxsim_model.joint_names()
70
+ )
71
+ ),
72
+ )
73
+
74
+ # Updating these joints is not necessary after the first time.
75
+ # Users can disable this update after initialization.
76
+ if update_removed_joints:
77
+
78
+ # Create a dictionary with the joints that have been removed for various reasons
79
+ # (like link lumping due to model reduction).
80
+ joints_removed_dict = {
81
+ j.name: j
82
+ for j in jaxsim_model.description._joints_removed
83
+ if j.name not in set(jaxsim_model.joint_names())
84
+ }
85
+
86
+ # Set the positions of the removed joints.
87
+ _ = [
88
+ model_helper.set_joint_position(
89
+ position=joints_removed_dict[joint_name].initial_position,
90
+ joint_name=joint_name,
91
+ )
92
+ # Select all original joint that have been removed from the JaxSim model
93
+ # that are still present in the Mujoco model.
94
+ for joint_name in joints_removed_dict
95
+ if joint_name in model_helper.joint_names()
96
+ ]
97
+
98
+ # Return the mujoco data with updated kinematics.
99
+ mj.mj_forward(mujoco_model, model_helper.data)
100
+
101
+ return model_helper.data
@@ -89,7 +89,7 @@ class MujocoVideoRecorder:
89
89
  if not exist_ok and path.is_file():
90
90
  raise FileExistsError(f"The file '{path}' already exists.")
91
91
 
92
- media.write_video(path=path, images=self.frames, fps=self.fps)
92
+ media.write_video(path=path, images=np.array(self.frames), fps=self.fps)
93
93
 
94
94
  @staticmethod
95
95
  def compute_down_sampling(original_fps: int, target_min_fps: int) -> int:
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: jaxsim
3
- Version: 0.4.3.dev177
3
+ Version: 0.4.3.dev186
4
4
  Summary: A differentiable physics engine and multibody dynamics library for control and robot learning.
5
5
  Author-email: Diego Ferigo <dgferigo@gmail.com>
6
6
  Maintainer-email: Diego Ferigo <dgferigo@gmail.com>, Filippo Luca Ferretti <filippo.ferretti@iit.it>
@@ -1,6 +1,6 @@
1
- jaxsim/__init__.py,sha256=bSbpggIz5aG6QuGZLa0V2EfHjAOeucMxi-vIYxzLmN8,2788
2
- jaxsim/_version.py,sha256=SzGoIDpeznpZHWyfdxXEtnO3y8zLaXJORHRJRSUcxsU,428
3
- jaxsim/exceptions.py,sha256=8_h8iqL8DgNR754dR8SZiQ7361GR5V1sUk3ZuZCHw1Q,2069
1
+ jaxsim/__init__.py,sha256=opgtbhhd1kDsHI4H1vOd3loMPDRi884yQ3tohfFGfNc,3382
2
+ jaxsim/_version.py,sha256=K0Qt3IiihQ28Vnsxow16UNJNyDGyw4M94790Je5aXw8,428
3
+ jaxsim/exceptions.py,sha256=vSoScaRD4nvh6jltgK9Ry5pKnE0O5hb4_yI_pk_fvR8,2175
4
4
  jaxsim/logging.py,sha256=STI-D_upXZYX-ZezLrlJJ0UlD5YspST0vZ_DcIwkzO4,1553
5
5
  jaxsim/typing.py,sha256=2HXy9hgazPXjofi1vLQ09ZubPtgVmg80U9NKmZ6NYiI,761
6
6
  jaxsim/api/__init__.py,sha256=8eV22t2S3UwNyCg8karPetG1dmX1VDBXkyv28_FwNQA,210
@@ -29,11 +29,12 @@ jaxsim/math/quaternion.py,sha256=_WA7W3iv7px83sWO1V1n0-J78hqAlO4SL1-jofE-UZ4,475
29
29
  jaxsim/math/rotation.py,sha256=k-nwT79zmWrys3NNAB-lGWxat7Kqm_6JnFRoimJ8rBg,2156
30
30
  jaxsim/math/skew.py,sha256=oOGSSR8PUGROl6IJFlrmu6K3gPH-u16hUPfKIkcVv9o,1177
31
31
  jaxsim/math/transform.py,sha256=KXzQgOnCfAtbXCwxhplpJ3F0JT3oEyeLVby1_uRAryQ,2892
32
- jaxsim/mujoco/__init__.py,sha256=Zo5GAlN1DYKvX8s1hu1j6HntKIbBMLB9Puv9ouaNAZ8,158
32
+ jaxsim/mujoco/__init__.py,sha256=fZyRWre49pIhOrYdf6yJk_hOax8qWGe8OCmoq-dMVq8,201
33
33
  jaxsim/mujoco/__main__.py,sha256=GBmB7J-zj75ZnFyuAAmpSOpbxi_HhHhWJeot3ljGDJY,5291
34
- jaxsim/mujoco/loaders.py,sha256=qT7Le_L7z2prXKA7O9x5rkbbh-_lIrrmLXTjgoAjhZ4,25339
35
- jaxsim/mujoco/model.py,sha256=AQksXemXWACJ3yvefV2G5HLwwBU9ISoJrOD1wlxdY5w,16386
36
- jaxsim/mujoco/visualizer.py,sha256=T1vU-w4NKSmgEkZ0FqVcGmIvYrYO0len2UBSsU4MOZ0,6978
34
+ jaxsim/mujoco/loaders.py,sha256=CkFGydgOku5P_Pz7wdWlM2SCJRs71ePF-vsY9i90-I0,25350
35
+ jaxsim/mujoco/model.py,sha256=5_7rWk_WBkNKDHqeewIFj0t2ZGqJpE6RDXHSbRvw4e4,16493
36
+ jaxsim/mujoco/utils.py,sha256=bGbLMSzcdqbinIwHHJHt8ZN1uup_6DLdB2dWqKiXwO4,3955
37
+ jaxsim/mujoco/visualizer.py,sha256=nD6SNWmn-nxjjjIY9oPAHvL2j8q93DJDjZeepzke_DQ,6988
37
38
  jaxsim/parsers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
38
39
  jaxsim/parsers/kinematic_graph.py,sha256=wT2bgaCS8VQJTHy2H9sENkVPDOiMkRikxEF1t_WaahQ,34748
39
40
  jaxsim/parsers/descriptions/__init__.py,sha256=PbIlunVfb59pB5jSX97YVpMAANRZPRkJ0X-hS14rzv4,221
@@ -64,8 +65,8 @@ jaxsim/utils/__init__.py,sha256=Y5zyoRevl3EMVQadhZ4EtSwTEkDt2vcnFoRhPJjKTZ0,215
64
65
  jaxsim/utils/jaxsim_dataclass.py,sha256=TGmTQV2Lq7Q-2nLoAEaeNtkPa_qj0IKkdBm4COj46Os,11312
65
66
  jaxsim/utils/tracing.py,sha256=KDMoyVPlu2NJvFkhtZwq5AkqMMgajt3munvJom-vEjQ,650
66
67
  jaxsim/utils/wrappers.py,sha256=Fh82ZcaFi5fUnByyFLnmumaobsu1hJIvFdopUVzJ1ps,4052
67
- jaxsim-0.4.3.dev177.dist-info/LICENSE,sha256=eaYdFmdeMbiIoIiPzEK0MjP1S9wtFXjXNR5er49uLR0,1546
68
- jaxsim-0.4.3.dev177.dist-info/METADATA,sha256=BMlT_szB4WbLIZmCucz750x_aRXoe2n7ycwlzI3l-sk,17276
69
- jaxsim-0.4.3.dev177.dist-info/WHEEL,sha256=GV9aMThwP_4oNCtvEC2ec3qUYutgWeAzklro_0m4WJQ,91
70
- jaxsim-0.4.3.dev177.dist-info/top_level.txt,sha256=LxGMA8FLtXjQ6oI7N5gd_R_oSUHxpXxUEOfT1xS_ni0,7
71
- jaxsim-0.4.3.dev177.dist-info/RECORD,,
68
+ jaxsim-0.4.3.dev186.dist-info/LICENSE,sha256=eaYdFmdeMbiIoIiPzEK0MjP1S9wtFXjXNR5er49uLR0,1546
69
+ jaxsim-0.4.3.dev186.dist-info/METADATA,sha256=YYb7FonjyOeyop0Ni-0dB0ijfk215c0-PZo6k9v6JAo,17276
70
+ jaxsim-0.4.3.dev186.dist-info/WHEEL,sha256=GV9aMThwP_4oNCtvEC2ec3qUYutgWeAzklro_0m4WJQ,91
71
+ jaxsim-0.4.3.dev186.dist-info/top_level.txt,sha256=LxGMA8FLtXjQ6oI7N5gd_R_oSUHxpXxUEOfT1xS_ni0,7
72
+ jaxsim-0.4.3.dev186.dist-info/RECORD,,