jaxsim 0.4.3.dev177__py3-none-any.whl → 0.4.3.dev181__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
jaxsim/__init__.py CHANGED
@@ -8,21 +8,35 @@ def _jnp_options() -> None:
8
8
 
9
9
  import jax
10
10
 
11
- # Enable by default 64bit precision in JAX.
12
- if os.environ.get("JAX_ENABLE_X64", "1") != "0":
13
-
14
- logging.info("Enabling JAX to use 64bit precision")
11
+ # Check if running on TPU
12
+ is_tpu = jax.devices()[0].platform == "tpu"
13
+
14
+ # Enable by default 64-bit precision to get accurate physics.
15
+ # Users can enforce 32-bit precision by setting the following variable to 0.
16
+ use_x64 = os.environ.get("JAX_ENABLE_X64", "1") != "0"
17
+
18
+ # Notify the user if unsupported 64-bit precision was enforced on TPU.
19
+ if is_tpu and use_x64:
20
+ msg = "64-bit precision is not allowed on TPU. Enforcing 32bit precision."
21
+ logging.warning(msg)
22
+ use_x64 = False
23
+
24
+ # Enable 64-bit precision in JAX.
25
+ if use_x64:
26
+ logging.info("Enabling JAX to use 64-bit precision")
15
27
  jax.config.update("jax_enable_x64", True)
16
28
 
17
29
  import jax.numpy as jnp
18
30
  import numpy as np
19
31
 
32
+ # Verify that 64-bit precision is correctly set.
20
33
  if jnp.empty(0, dtype=float).dtype != jnp.empty(0, dtype=np.float64).dtype:
21
- logging.warning("Failed to enable 64bit precision in JAX")
34
+ logging.warning("Failed to enable 64-bit precision in JAX")
22
35
 
36
+ # Warn about experimental usage of 32-bit precision.
23
37
  else:
24
38
  logging.warning(
25
- "Using 32bit precision in JaxSim is still experimental, please avoid to use variable step integrators."
39
+ "Using 32-bit precision in JaxSim is still experimental, please avoid to use variable step integrators."
26
40
  )
27
41
 
28
42
 
jaxsim/_version.py CHANGED
@@ -12,5 +12,5 @@ __version__: str
12
12
  __version_tuple__: VERSION_TUPLE
13
13
  version_tuple: VERSION_TUPLE
14
14
 
15
- __version__ = version = '0.4.3.dev177'
16
- __version_tuple__ = version_tuple = (0, 4, 3, 'dev177')
15
+ __version__ = version = '0.4.3.dev181'
16
+ __version_tuple__ = version_tuple = (0, 4, 3, 'dev181')
jaxsim/exceptions.py CHANGED
@@ -17,6 +17,10 @@ def raise_if(
17
17
  format string (fmt), whose fields are filled with the args and kwargs.
18
18
  """
19
19
 
20
+ # Disable host callback if running on TPU.
21
+ if jax.devices()[0].platform == "tpu":
22
+ return
23
+
20
24
  # Check early that the format string is well-formed.
21
25
  try:
22
26
  _ = msg.format(*args, **kwargs)
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: jaxsim
3
- Version: 0.4.3.dev177
3
+ Version: 0.4.3.dev181
4
4
  Summary: A differentiable physics engine and multibody dynamics library for control and robot learning.
5
5
  Author-email: Diego Ferigo <dgferigo@gmail.com>
6
6
  Maintainer-email: Diego Ferigo <dgferigo@gmail.com>, Filippo Luca Ferretti <filippo.ferretti@iit.it>
@@ -1,6 +1,6 @@
1
- jaxsim/__init__.py,sha256=bSbpggIz5aG6QuGZLa0V2EfHjAOeucMxi-vIYxzLmN8,2788
2
- jaxsim/_version.py,sha256=SzGoIDpeznpZHWyfdxXEtnO3y8zLaXJORHRJRSUcxsU,428
3
- jaxsim/exceptions.py,sha256=8_h8iqL8DgNR754dR8SZiQ7361GR5V1sUk3ZuZCHw1Q,2069
1
+ jaxsim/__init__.py,sha256=opgtbhhd1kDsHI4H1vOd3loMPDRi884yQ3tohfFGfNc,3382
2
+ jaxsim/_version.py,sha256=2YMpT461ObsI3rceAHrUVR-OPYPsQ43SHkzRH20mlDY,428
3
+ jaxsim/exceptions.py,sha256=vSoScaRD4nvh6jltgK9Ry5pKnE0O5hb4_yI_pk_fvR8,2175
4
4
  jaxsim/logging.py,sha256=STI-D_upXZYX-ZezLrlJJ0UlD5YspST0vZ_DcIwkzO4,1553
5
5
  jaxsim/typing.py,sha256=2HXy9hgazPXjofi1vLQ09ZubPtgVmg80U9NKmZ6NYiI,761
6
6
  jaxsim/api/__init__.py,sha256=8eV22t2S3UwNyCg8karPetG1dmX1VDBXkyv28_FwNQA,210
@@ -64,8 +64,8 @@ jaxsim/utils/__init__.py,sha256=Y5zyoRevl3EMVQadhZ4EtSwTEkDt2vcnFoRhPJjKTZ0,215
64
64
  jaxsim/utils/jaxsim_dataclass.py,sha256=TGmTQV2Lq7Q-2nLoAEaeNtkPa_qj0IKkdBm4COj46Os,11312
65
65
  jaxsim/utils/tracing.py,sha256=KDMoyVPlu2NJvFkhtZwq5AkqMMgajt3munvJom-vEjQ,650
66
66
  jaxsim/utils/wrappers.py,sha256=Fh82ZcaFi5fUnByyFLnmumaobsu1hJIvFdopUVzJ1ps,4052
67
- jaxsim-0.4.3.dev177.dist-info/LICENSE,sha256=eaYdFmdeMbiIoIiPzEK0MjP1S9wtFXjXNR5er49uLR0,1546
68
- jaxsim-0.4.3.dev177.dist-info/METADATA,sha256=BMlT_szB4WbLIZmCucz750x_aRXoe2n7ycwlzI3l-sk,17276
69
- jaxsim-0.4.3.dev177.dist-info/WHEEL,sha256=GV9aMThwP_4oNCtvEC2ec3qUYutgWeAzklro_0m4WJQ,91
70
- jaxsim-0.4.3.dev177.dist-info/top_level.txt,sha256=LxGMA8FLtXjQ6oI7N5gd_R_oSUHxpXxUEOfT1xS_ni0,7
71
- jaxsim-0.4.3.dev177.dist-info/RECORD,,
67
+ jaxsim-0.4.3.dev181.dist-info/LICENSE,sha256=eaYdFmdeMbiIoIiPzEK0MjP1S9wtFXjXNR5er49uLR0,1546
68
+ jaxsim-0.4.3.dev181.dist-info/METADATA,sha256=w-b474j6ugFST4V5QJ2h9DTtrb2xRH0cdCRAMyBG8wg,17276
69
+ jaxsim-0.4.3.dev181.dist-info/WHEEL,sha256=GV9aMThwP_4oNCtvEC2ec3qUYutgWeAzklro_0m4WJQ,91
70
+ jaxsim-0.4.3.dev181.dist-info/top_level.txt,sha256=LxGMA8FLtXjQ6oI7N5gd_R_oSUHxpXxUEOfT1xS_ni0,7
71
+ jaxsim-0.4.3.dev181.dist-info/RECORD,,