jaxsim 0.7.1.dev38__py3-none-any.whl → 0.7.1.dev43__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
jaxsim/_version.py CHANGED
@@ -17,5 +17,5 @@ __version__: str
17
17
  __version_tuple__: VERSION_TUPLE
18
18
  version_tuple: VERSION_TUPLE
19
19
 
20
- __version__ = version = '0.7.1.dev38'
21
- __version_tuple__ = version_tuple = (0, 7, 1, 'dev38')
20
+ __version__ = version = '0.7.1.dev43'
21
+ __version_tuple__ = version_tuple = (0, 7, 1, 'dev43')
jaxsim/math/utils.py CHANGED
@@ -1,8 +1,47 @@
1
+ import jax
1
2
  import jax.numpy as jnp
2
3
 
3
4
  import jaxsim.typing as jtp
4
5
 
5
6
 
7
+ def _make_safe_norm(axis, keepdims):
8
+ @jax.custom_jvp
9
+ def _safe_norm(array: jtp.ArrayLike) -> jtp.Array:
10
+ """
11
+ Compute an array norm handling NaNs and making sure that
12
+ it is safe to get the gradient.
13
+
14
+ Args:
15
+ array: The array for which to compute the norm.
16
+
17
+ Returns:
18
+ The norm of the array with handling for zero arrays to avoid NaNs.
19
+ """
20
+ # Compute the norm of the array along the specified axis.
21
+ return jnp.linalg.norm(array, axis=axis, keepdims=keepdims)
22
+
23
+ @_safe_norm.defjvp
24
+ def _safe_norm_jvp(primals, tangents):
25
+ (x,), (x_dot,) = primals, tangents
26
+
27
+ # Check if the entire array is composed of zeros.
28
+ is_zero = jnp.all(x == 0.0)
29
+
30
+ # Replace zeros with an array of ones temporarily to avoid division by zero.
31
+ # This ensures the computation of norm does not produce NaNs or Infs.
32
+ array = jnp.where(is_zero, jnp.ones_like(x), x)
33
+
34
+ # Compute the norm of the array along the specified axis.
35
+ norm = jnp.linalg.norm(array, axis=axis, keepdims=keepdims)
36
+
37
+ dot = jnp.sum(array * x_dot, axis=axis, keepdims=keepdims)
38
+ tangent = jnp.where(is_zero, 0.0, dot / norm)
39
+
40
+ return jnp.where(is_zero, 0.0, norm), tangent
41
+
42
+ return _safe_norm
43
+
44
+
6
45
  def safe_norm(array: jtp.ArrayLike, *, axis=None, keepdims: bool = False) -> jtp.Array:
7
46
  """
8
47
  Compute an array norm handling NaNs and making sure that
@@ -16,17 +55,4 @@ def safe_norm(array: jtp.ArrayLike, *, axis=None, keepdims: bool = False) -> jtp
16
55
  Returns:
17
56
  The norm of the array with handling for zero arrays to avoid NaNs.
18
57
  """
19
-
20
- # Check if the entire array is composed of zeros.
21
- is_zero = jnp.allclose(array, 0.0)
22
-
23
- # Replace zeros with an array of ones temporarily to avoid division by zero.
24
- # This ensures the computation of norm does not produce NaNs or Infs.
25
- array = jnp.where(is_zero, jnp.ones_like(array), array)
26
-
27
- # Compute the norm of the array along the specified axis.
28
- norm = jnp.linalg.norm(array, axis=axis, keepdims=keepdims)
29
-
30
- # Use `jnp.where` to set the norm to 0.0 where the input array was all zeros.
31
- # This usage supports potential batch processing for future scalability.
32
- return jnp.where(is_zero, 0.0, norm)
58
+ return _make_safe_norm(axis, keepdims)(array)
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: jaxsim
3
- Version: 0.7.1.dev38
3
+ Version: 0.7.1.dev43
4
4
  Summary: A differentiable physics engine and multibody dynamics library for control and robot learning.
5
5
  Author-email: Diego Ferigo <dgferigo@gmail.com>, Filippo Luca Ferretti <filippoluca.ferretti@outlook.com>
6
6
  Maintainer-email: Filippo Luca Ferretti <filippo.ferretti@iit.it>, Alessandro Croci <alessandro.croci@iit.it>
@@ -1,5 +1,5 @@
1
1
  jaxsim/__init__.py,sha256=EKeysKN-7UswwJLCl7n6qIBKQIVUtYsCMYu_tCoFn7g,3628
2
- jaxsim/_version.py,sha256=jJU4vdmIgtkYFUJO3x0tHykHUkxDr69boyVHg05l7rU,526
2
+ jaxsim/_version.py,sha256=voU-Mq2GmQtb3vS3AKLaBQRwskNrzOE3uH9-98vPJIo,526
3
3
  jaxsim/exceptions.py,sha256=MQ3LRMfVMX2-g3qYj7mUVNV9OLlIA48TANJegbcQyXI,2641
4
4
  jaxsim/logging.py,sha256=STI-D_upXZYX-ZezLrlJJ0UlD5YspST0vZ_DcIwkzO4,1553
5
5
  jaxsim/typing.py,sha256=7msl8t5Jt09RNYfKdPJtpjLfWurldcycDappb045Eso,761
@@ -26,7 +26,7 @@ jaxsim/math/quaternion.py,sha256=MSaZywzJDxs2te1ZELeIcupKSFIA9q_pdXy7fDAEqM4,453
26
26
  jaxsim/math/rotation.py,sha256=TEUtT3X2tFieNxdlccup1pfaTgCTtfX-hTNotd8-nNk,1892
27
27
  jaxsim/math/skew.py,sha256=z_9YN-NDHL3n4KXWNbzTSMkFDZ0SDpz4RUcwwYFOaao,1402
28
28
  jaxsim/math/transform.py,sha256=d0_m_obmUOmnI8Bte0ktvibR9Hv9M9qpg8tVuLON2g0,3192
29
- jaxsim/math/utils.py,sha256=IiH01iN54BtLnULC04pDfYe8Av99p3FGdYp2jJInm30,1166
29
+ jaxsim/math/utils.py,sha256=JgJrBPeuCvi0969VqoNsyk3CflQiLzopngKDjl6RfiE,1898
30
30
  jaxsim/mujoco/__init__.py,sha256=1kAWzYOS7nP29S5FGyWPqiAnPf4yPSoaPW-WBGBjVV0,214
31
31
  jaxsim/mujoco/__main__.py,sha256=GBmB7J-zj75ZnFyuAAmpSOpbxi_HhHhWJeot3ljGDJY,5291
32
32
  jaxsim/mujoco/loaders.py,sha256=OCk1T11iIm3qZUibNpo_bxxLgaGSkCpLt7ae_ND0ExA,23272
@@ -65,8 +65,8 @@ jaxsim/utils/__init__.py,sha256=Y5zyoRevl3EMVQadhZ4EtSwTEkDt2vcnFoRhPJjKTZ0,215
65
65
  jaxsim/utils/jaxsim_dataclass.py,sha256=XzmZeIibcaOzaxpprsGSxH3UrM66PAO456rFV91sNXg,11453
66
66
  jaxsim/utils/tracing.py,sha256=Btwxdfhb7fJLk3r5PlQkGYj60Y2KbFT1gANGIA697FU,530
67
67
  jaxsim/utils/wrappers.py,sha256=3IMwydqFgmSPqeuUQ3PRmdhDc1IoT6XC23jPC_LjWXs,4175
68
- jaxsim-0.7.1.dev38.dist-info/licenses/LICENSE,sha256=eaYdFmdeMbiIoIiPzEK0MjP1S9wtFXjXNR5er49uLR0,1546
69
- jaxsim-0.7.1.dev38.dist-info/METADATA,sha256=zzMOWYv37A9BbHQiYLQ6iTRNb1hsLaesMyLPC4GBAMM,17851
70
- jaxsim-0.7.1.dev38.dist-info/WHEEL,sha256=Nw36Djuh_5VDukK0H78QzOX-_FQEo6V37m3nkm96gtU,91
71
- jaxsim-0.7.1.dev38.dist-info/top_level.txt,sha256=LxGMA8FLtXjQ6oI7N5gd_R_oSUHxpXxUEOfT1xS_ni0,7
72
- jaxsim-0.7.1.dev38.dist-info/RECORD,,
68
+ jaxsim-0.7.1.dev43.dist-info/licenses/LICENSE,sha256=eaYdFmdeMbiIoIiPzEK0MjP1S9wtFXjXNR5er49uLR0,1546
69
+ jaxsim-0.7.1.dev43.dist-info/METADATA,sha256=fWu7t97w6jHa9j_ICI4Fs90WwHQmwSGo5DzkrQbYoQ4,17851
70
+ jaxsim-0.7.1.dev43.dist-info/WHEEL,sha256=zaaOINJESkSfm_4HQVc5ssNzHCPXhJm0kEUakpsEHaU,91
71
+ jaxsim-0.7.1.dev43.dist-info/top_level.txt,sha256=LxGMA8FLtXjQ6oI7N5gd_R_oSUHxpXxUEOfT1xS_ni0,7
72
+ jaxsim-0.7.1.dev43.dist-info/RECORD,,
@@ -1,5 +1,5 @@
1
1
  Wheel-Version: 1.0
2
- Generator: setuptools (80.7.1)
2
+ Generator: setuptools (80.8.0)
3
3
  Root-Is-Purelib: true
4
4
  Tag: py3-none-any
5
5