jaxsim 0.5.1.dev91__py3-none-any.whl → 0.5.1.dev95__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
jaxsim/_version.py CHANGED
@@ -12,5 +12,5 @@ __version__: str
12
12
  __version_tuple__: VERSION_TUPLE
13
13
  version_tuple: VERSION_TUPLE
14
14
 
15
- __version__ = version = '0.5.1.dev91'
16
- __version_tuple__ = version_tuple = (0, 5, 1, 'dev91')
15
+ __version__ = version = '0.5.1.dev95'
16
+ __version_tuple__ = version_tuple = (0, 5, 1, 'dev95')
jaxsim/api/data.py CHANGED
@@ -382,9 +382,8 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
382
382
  # we introduce a Baumgarte stabilization to let the quaternion converge to
383
383
  # a unit quaternion. In this case, it is not guaranteed that the quaternion
384
384
  # stored in the state is a unit quaternion.
385
- W_Q_B = jnp.where(
386
- jnp.allclose(W_Q_B.dot(W_Q_B), 1.0), W_Q_B, W_Q_B / jnp.linalg.norm(W_Q_B)
387
- )
385
+ norm = jaxsim.math.safe_norm(W_Q_B)
386
+ W_Q_B = W_Q_B / (norm + jnp.finfo(float).eps * (norm == 0))
388
387
 
389
388
  return (W_Q_B if not dcm else jaxsim.math.Quaternion.to_dcm(W_Q_B)).astype(
390
389
  float
@@ -611,11 +610,8 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
611
610
 
612
611
  W_Q_B = jnp.array(base_quaternion, dtype=float)
613
612
 
614
- W_Q_B = jax.lax.select(
615
- pred=jnp.allclose(jnp.linalg.norm(W_Q_B), 1.0, atol=1e-6, rtol=0.0),
616
- on_true=W_Q_B,
617
- on_false=W_Q_B / jnp.linalg.norm(W_Q_B),
618
- )
613
+ norm = jaxsim.math.safe_norm(W_Q_B)
614
+ W_Q_B = W_Q_B / (norm + jnp.finfo(float).eps * (norm == 0))
619
615
 
620
616
  return self.replace(
621
617
  validate=True,
jaxsim/math/__init__.py CHANGED
@@ -8,5 +8,6 @@ from .quaternion import Quaternion
8
8
  from .rotation import Rotation
9
9
  from .skew import Skew
10
10
  from .transform import Transform
11
+ from .utils import safe_norm
11
12
 
12
13
  from .joint_model import JointModel, supported_joint_motion # isort:skip
jaxsim/math/quaternion.py CHANGED
@@ -4,6 +4,8 @@ import jaxlie
4
4
 
5
5
  import jaxsim.typing as jtp
6
6
 
7
+ from .utils import safe_norm
8
+
7
9
 
8
10
  class Quaternion:
9
11
  @staticmethod
@@ -111,18 +113,13 @@ class Quaternion:
111
113
  operand=quaternion,
112
114
  )
113
115
 
114
- norm_ω = jax.lax.cond(
115
- pred=ω.dot(ω) < (1e-6) ** 2,
116
- true_fun=lambda _: 1e-6,
117
- false_fun=lambda _: jnp.linalg.norm(ω),
118
- operand=None,
119
- )
116
+ norm_ω = safe_norm(ω)
120
117
 
121
118
  qd = 0.5 * (
122
119
  Q
123
120
  @ jnp.hstack(
124
121
  [
125
- K * norm_ω * (1 - jnp.linalg.norm(quaternion)),
122
+ K * norm_ω * (1 - safe_norm(quaternion)),
126
123
  ω,
127
124
  ]
128
125
  )
jaxsim/math/rotation.py CHANGED
@@ -4,6 +4,7 @@ import jaxlie
4
4
  import jaxsim.typing as jtp
5
5
 
6
6
  from .skew import Skew
7
+ from .utils import safe_norm
7
8
 
8
9
 
9
10
  class Rotation:
@@ -67,7 +68,7 @@ class Rotation:
67
68
  def theta_is_not_zero(axis: jtp.Vector) -> jtp.Matrix:
68
69
 
69
70
  v = axis
70
- theta = jnp.linalg.norm(v)
71
+ theta = safe_norm(v)
71
72
 
72
73
  s = jnp.sin(theta)
73
74
  c = jnp.cos(theta)
@@ -81,19 +82,9 @@ class Rotation:
81
82
 
82
83
  return R.transpose()
83
84
 
84
- # Use the double-where trick to prevent JAX problems when the
85
- # jax.jit and jax.grad transforms are applied.
86
85
  return jnp.where(
87
- jnp.linalg.norm(vector) > 0,
88
- theta_is_not_zero(
89
- axis=jnp.where(
90
- jnp.linalg.norm(vector) > 0,
91
- vector,
92
- # The following line is a workaround to prevent division by 0.
93
- # Considering the outer where, this branch is never executed.
94
- jnp.ones(3),
95
- )
96
- ),
86
+ jnp.allclose(vector, 0.0),
97
87
  # Return an identity rotation matrix when the input vector is zero.
98
88
  jnp.eye(3),
89
+ theta_is_not_zero(axis=vector),
99
90
  )
jaxsim/math/utils.py ADDED
@@ -0,0 +1,31 @@
1
+ import jax.numpy as jnp
2
+
3
+ import jaxsim.typing as jtp
4
+
5
+
6
+ def safe_norm(array: jtp.ArrayLike, axis=None) -> jtp.Array:
7
+ """
8
+ Provides a calculation for an array norm so that it is safe
9
+ to compute the gradient and handle NaNs.
10
+
11
+ Args:
12
+ array: The array for which to compute the norm.
13
+ axis: The axis for which to compute the norm.
14
+
15
+ Returns:
16
+ The norm of the array with handling for zero arrays to avoid NaNs.
17
+ """
18
+
19
+ # Check if the entire array is composed of zeros.
20
+ is_zero = jnp.allclose(array, 0.0)
21
+
22
+ # Replace zeros with an array of ones temporarily to avoid division by zero.
23
+ # This ensures the computation of norm does not produce NaNs or Infs.
24
+ array = jnp.where(is_zero, jnp.ones_like(array), array)
25
+
26
+ # Compute the norm of the array along the specified axis.
27
+ norm = jnp.linalg.norm(array, axis=axis)
28
+
29
+ # Use `jnp.where` to set the norm to 0.0 where the input array was all zeros.
30
+ # This usage supports potential batch processing for future scalability.
31
+ return jnp.where(is_zero, 0.0, norm)
@@ -309,19 +309,16 @@ class SoftContacts(common.ContactModel):
309
309
 
310
310
  # Compute the direction of the tangential force.
311
311
  # To prevent dividing by zero, we use a switch statement.
312
- # The ε, instead, is needed to make AD happy.
313
- f_tangential_direction = jnp.where(
314
- f_tangential.dot(f_tangential) != 0,
315
- f_tangential / jnp.linalg.norm(f_tangential + ε),
316
- jnp.zeros(3),
312
+ norm = jaxsim.math.safe_norm(f_tangential)
313
+ f_tangential_direction = f_tangential / (
314
+ norm + jnp.finfo(float).eps * (norm == 0)
317
315
  )
318
316
 
319
317
  # Project the tangential force to the friction cone if slipping.
320
318
  f_tangential = jnp.where(
321
319
  sticking,
322
320
  f_tangential,
323
- jnp.minimum(μ * force_normal_mag, jnp.linalg.norm(f_tangential + ε))
324
- * f_tangential_direction,
321
+ jnp.minimum(μ * force_normal_mag, norm) * f_tangential_direction,
325
322
  )
326
323
 
327
324
  # Set the tangential force to zero if there is no contact.
jaxsim/terrain/terrain.py CHANGED
@@ -7,6 +7,7 @@ import jax.numpy as jnp
7
7
  import jax_dataclasses
8
8
  import numpy as np
9
9
 
10
+ import jaxsim.math
10
11
  import jaxsim.typing as jtp
11
12
  from jaxsim import exceptions
12
13
 
@@ -41,7 +42,7 @@ class Terrain(abc.ABC):
41
42
  [(h_xm - h_xp) / (2 * self.delta), (h_ym - h_yp) / (2 * self.delta), 1.0]
42
43
  )
43
44
 
44
- return n / jnp.linalg.norm(n)
45
+ return n / jaxsim.math.safe_norm(n)
45
46
 
46
47
 
47
48
  @jax_dataclasses.pytree_dataclass
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: jaxsim
3
- Version: 0.5.1.dev91
3
+ Version: 0.5.1.dev95
4
4
  Summary: A differentiable physics engine and multibody dynamics library for control and robot learning.
5
5
  Author-email: Diego Ferigo <dgferigo@gmail.com>, Filippo Luca Ferretti <filippoluca.ferretti@outlook.com>
6
6
  Maintainer-email: Filippo Luca Ferretti <filippo.ferretti@iit.it>, Alessandro Croci <alessandro.croci@iit.it>
@@ -1,5 +1,5 @@
1
1
  jaxsim/__init__.py,sha256=opgtbhhd1kDsHI4H1vOd3loMPDRi884yQ3tohfFGfNc,3382
2
- jaxsim/_version.py,sha256=HHdXV3EXu0rha3QfUW2g4pSsGNWqjfhD2e1Qwed6NGk,426
2
+ jaxsim/_version.py,sha256=uLmioR8f_16hLO76GPSpzQyclhE2S-4nH8pntMObYGY,426
3
3
  jaxsim/exceptions.py,sha256=Sq3qtqeiy-CK76new_W2KKQ-4MAzyOUK5j5pBLr4RPQ,2250
4
4
  jaxsim/logging.py,sha256=STI-D_upXZYX-ZezLrlJJ0UlD5YspST0vZ_DcIwkzO4,1553
5
5
  jaxsim/typing.py,sha256=2HXy9hgazPXjofi1vLQ09ZubPtgVmg80U9NKmZ6NYiI,761
@@ -7,7 +7,7 @@ jaxsim/api/__init__.py,sha256=8eV22t2S3UwNyCg8karPetG1dmX1VDBXkyv28_FwNQA,210
7
7
  jaxsim/api/com.py,sha256=5fYNRUhKE5VGGdW88zY8mqqEy5VTWyaHu5k6MgW4Jt4,13826
8
8
  jaxsim/api/common.py,sha256=SvEOGxCKOxKLLVHaNp1sFkBX0sku3-wH0-HUlYVWCDk,7090
9
9
  jaxsim/api/contact.py,sha256=vfW-HEvQcAUHl7dOOwI-ndRxgMeAtkKT7tTaMDFlh7k,25421
10
- jaxsim/api/data.py,sha256=hz8g-P0o7XoDYyUFPx6yA8QlgHmjfFf2_OdiwcRV6W8,30292
10
+ jaxsim/api/data.py,sha256=AFp1sDNRIkwpBom6ZlW6L7vJBtf4D9woVRJ8bGICr3s,30189
11
11
  jaxsim/api/frame.py,sha256=d6pa6vywGDqfaJU76F_-yjLJs6R3mrjZ6B-KXPu6f3Q,14595
12
12
  jaxsim/api/joint.py,sha256=AnqlNWmBOay-gsoo0y4AbfFQ2OCJm-8T1E0IMhZeLoY,7457
13
13
  jaxsim/api/kin_dyn_parameters.py,sha256=wnto0nzzEJ_M8tH2PUdldEyxQwQdsStYUoQFu696uuw,29897
@@ -20,15 +20,16 @@ jaxsim/integrators/__init__.py,sha256=hxvOD-VK_mmd6v31wtC-nb28AYve1gLuZCNLV9wS-K
20
20
  jaxsim/integrators/common.py,sha256=fnDqVIIXMYe2aiT_qnEhJSAeFuYRGhmElVCl7zPTrN8,18229
21
21
  jaxsim/integrators/fixed_step.py,sha256=KpjRd6hHtapxDoo6D1kyDrVDSHnke2TepI5grFH7_bM,2693
22
22
  jaxsim/integrators/variable_step.py,sha256=HuUKudeFj0W7dvVATVNZK3uk1Nh_qKlGO_CDqXJFV14,22166
23
- jaxsim/math/__init__.py,sha256=8oPITEoGwgRcOeG8KxtqxPQ8b5uku1HNRMokpCoi9Tc,352
23
+ jaxsim/math/__init__.py,sha256=2T1WUU_chNBCvyvkKSdiesPlckbo-gXVbCZEGoF-W0I,381
24
24
  jaxsim/math/adjoint.py,sha256=V7r5VrTCKPLEL5gavNSx9U7xSsrb11a5e4gWqJ2MuRo,4375
25
25
  jaxsim/math/cross.py,sha256=U7yEx_l75mSy5g6O-jsjBztApvxC3WaV4MpkS5tThu4,1330
26
26
  jaxsim/math/inertia.py,sha256=01hz6wMFreN2jBA0rVoBS1YMVh77KvwuzXSOpI3pxNk,1614
27
27
  jaxsim/math/joint_model.py,sha256=EzAveaG5B6ZnCFNUzN30KEQUVesd83lfWXJarYR-kUw,9989
28
- jaxsim/math/quaternion.py,sha256=_WA7W3iv7px83sWO1V1n0-J78hqAlO4SL1-jofE-UZ4,4754
29
- jaxsim/math/rotation.py,sha256=P34sx28Rh1MhNwxUxqXjxP-ZN1_5tvoflMoAIpy2LbE,2586
28
+ jaxsim/math/quaternion.py,sha256=vrPkdSUPfv1RZOULY9uG_bxmqgOARdvMxKtb6QixHuY,4609
29
+ jaxsim/math/rotation.py,sha256=W6vaIWmoBBuPgNJKey1vFpl2a1IxXTToCTaPfs-kd9I,2155
30
30
  jaxsim/math/skew.py,sha256=oOGSSR8PUGROl6IJFlrmu6K3gPH-u16hUPfKIkcVv9o,1177
31
31
  jaxsim/math/transform.py,sha256=KXzQgOnCfAtbXCwxhplpJ3F0JT3oEyeLVby1_uRAryQ,2892
32
+ jaxsim/math/utils.py,sha256=C5kP11KWsIacWtzouaI5tNH8BHjZ-ZgZ67U9wzjz7jw,1070
32
33
  jaxsim/mujoco/__init__.py,sha256=fZyRWre49pIhOrYdf6yJk_hOax8qWGe8OCmoq-dMVq8,201
33
34
  jaxsim/mujoco/__main__.py,sha256=GBmB7J-zj75ZnFyuAAmpSOpbxi_HhHhWJeot3ljGDJY,5291
34
35
  jaxsim/mujoco/loaders.py,sha256=_CZekIqZNe8oFeH7zSv4gGZAZENRISwMd8dt640zjRI,20860
@@ -58,16 +59,16 @@ jaxsim/rbda/contacts/__init__.py,sha256=L5MM-2pv76YPGzxExdz2EErgGBATuAjYnNHlq5QO
58
59
  jaxsim/rbda/contacts/common.py,sha256=BjwZMCkzd1ZOdZW7_Zt09Cl5j2JUHXM5Q8ao_qS6e64,10406
59
60
  jaxsim/rbda/contacts/relaxed_rigid.py,sha256=PgwKfProN5sLXJsSov3nIidHHMVpJqIp7eIv6_bPGjs,20345
60
61
  jaxsim/rbda/contacts/rigid.py,sha256=X-PE6PmZqlKoZTY6JhYBSW-vom-rq2uBKmBUNQeQHCg,15991
61
- jaxsim/rbda/contacts/soft.py,sha256=sIWT4NUJmoVR5T1Fo0ExdPfzL_gPfiPiB-9CFuotE_s,15567
62
+ jaxsim/rbda/contacts/soft.py,sha256=MhqHThn3XVyy8lRg6QVYzcJF4vVRKwnMvOMEyyF96OE,15443
62
63
  jaxsim/rbda/contacts/visco_elastic.py,sha256=QhyJHjDowyBTAhoSdZcCIkOqzp__gMXhLON-qYyMgQc,39886
63
64
  jaxsim/terrain/__init__.py,sha256=f7lVX-iNpH_wkkjef9Qpjh19TTAUOQw76EiLYJDVizc,78
64
- jaxsim/terrain/terrain.py,sha256=_G1QS3zWycj089R8fTP5s2VjcZpEdJxREjXZJ-oXIvc,5248
65
+ jaxsim/terrain/terrain.py,sha256=bv-YAwG06EydHnB6bcNtl7xIyB3LSl0vVXSVLFC4JpQ,5273
65
66
  jaxsim/utils/__init__.py,sha256=Y5zyoRevl3EMVQadhZ4EtSwTEkDt2vcnFoRhPJjKTZ0,215
66
67
  jaxsim/utils/jaxsim_dataclass.py,sha256=TGmTQV2Lq7Q-2nLoAEaeNtkPa_qj0IKkdBm4COj46Os,11312
67
68
  jaxsim/utils/tracing.py,sha256=eEY28MZW0Lm_jJNt1NkFqZz0ek01tvhR46OXZYCo7tc,532
68
69
  jaxsim/utils/wrappers.py,sha256=ZY7olSORzZRvSzkdeNLj8yjwUIAt9L0Douwl7wItjpk,4008
69
- jaxsim-0.5.1.dev91.dist-info/LICENSE,sha256=eaYdFmdeMbiIoIiPzEK0MjP1S9wtFXjXNR5er49uLR0,1546
70
- jaxsim-0.5.1.dev91.dist-info/METADATA,sha256=PPDDHpeFVoVidQMnmZcNT_Spo8ryctjuFISZWYM_ZgI,17937
71
- jaxsim-0.5.1.dev91.dist-info/WHEEL,sha256=PZUExdf71Ui_so67QXpySuHtCi3-J3wvF4ORK6k_S8U,91
72
- jaxsim-0.5.1.dev91.dist-info/top_level.txt,sha256=LxGMA8FLtXjQ6oI7N5gd_R_oSUHxpXxUEOfT1xS_ni0,7
73
- jaxsim-0.5.1.dev91.dist-info/RECORD,,
70
+ jaxsim-0.5.1.dev95.dist-info/LICENSE,sha256=eaYdFmdeMbiIoIiPzEK0MjP1S9wtFXjXNR5er49uLR0,1546
71
+ jaxsim-0.5.1.dev95.dist-info/METADATA,sha256=xOMthPKBfJqlJrzyVADzPjB5b-h8MKI3K8CGRdcKZy4,17937
72
+ jaxsim-0.5.1.dev95.dist-info/WHEEL,sha256=A3WOREP4zgxI0fKrHUG8DC8013e3dK3n7a6HDbcEIwE,91
73
+ jaxsim-0.5.1.dev95.dist-info/top_level.txt,sha256=LxGMA8FLtXjQ6oI7N5gd_R_oSUHxpXxUEOfT1xS_ni0,7
74
+ jaxsim-0.5.1.dev95.dist-info/RECORD,,
@@ -1,5 +1,5 @@
1
1
  Wheel-Version: 1.0
2
- Generator: setuptools (75.6.0)
2
+ Generator: setuptools (75.7.0)
3
3
  Root-Is-Purelib: true
4
4
  Tag: py3-none-any
5
5