jaxsim 0.5.1.dev91__tar.gz → 0.5.1.dev95__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (134)
  1. {jaxsim-0.5.1.dev91 → jaxsim-0.5.1.dev95}/PKG-INFO +1 -1
  2. {jaxsim-0.5.1.dev91 → jaxsim-0.5.1.dev95}/src/jaxsim/_version.py +2 -2
  3. {jaxsim-0.5.1.dev91 → jaxsim-0.5.1.dev95}/src/jaxsim/api/data.py +4 -8
  4. {jaxsim-0.5.1.dev91 → jaxsim-0.5.1.dev95}/src/jaxsim/math/__init__.py +1 -0
  5. {jaxsim-0.5.1.dev91 → jaxsim-0.5.1.dev95}/src/jaxsim/math/quaternion.py +4 -7
  6. {jaxsim-0.5.1.dev91 → jaxsim-0.5.1.dev95}/src/jaxsim/math/rotation.py +4 -13
  7. jaxsim-0.5.1.dev95/src/jaxsim/math/utils.py +31 -0
  8. {jaxsim-0.5.1.dev91 → jaxsim-0.5.1.dev95}/src/jaxsim/rbda/contacts/soft.py +4 -7
  9. {jaxsim-0.5.1.dev91 → jaxsim-0.5.1.dev95}/src/jaxsim/terrain/terrain.py +2 -1
  10. {jaxsim-0.5.1.dev91 → jaxsim-0.5.1.dev95}/src/jaxsim.egg-info/PKG-INFO +1 -1
  11. {jaxsim-0.5.1.dev91 → jaxsim-0.5.1.dev95}/src/jaxsim.egg-info/SOURCES.txt +1 -0
  12. {jaxsim-0.5.1.dev91 → jaxsim-0.5.1.dev95}/tests/test_automatic_differentiation.py +43 -0
  13. {jaxsim-0.5.1.dev91 → jaxsim-0.5.1.dev95}/.devcontainer/Dockerfile +0 -0
  14. {jaxsim-0.5.1.dev91 → jaxsim-0.5.1.dev95}/.devcontainer/devcontainer.json +0 -0
  15. {jaxsim-0.5.1.dev91 → jaxsim-0.5.1.dev95}/.gitattributes +0 -0
  16. {jaxsim-0.5.1.dev91 → jaxsim-0.5.1.dev95}/.github/CODEOWNERS +0 -0
  17. {jaxsim-0.5.1.dev91 → jaxsim-0.5.1.dev95}/.github/dependabot.yml +0 -0
  18. {jaxsim-0.5.1.dev91 → jaxsim-0.5.1.dev95}/.github/workflows/ci_cd.yml +0 -0
  19. {jaxsim-0.5.1.dev91 → jaxsim-0.5.1.dev95}/.github/workflows/pixi.yml +0 -0
  20. {jaxsim-0.5.1.dev91 → jaxsim-0.5.1.dev95}/.github/workflows/read_the_docs.yml +0 -0
  21. {jaxsim-0.5.1.dev91 → jaxsim-0.5.1.dev95}/.gitignore +0 -0
  22. {jaxsim-0.5.1.dev91 → jaxsim-0.5.1.dev95}/.pre-commit-config.yaml +0 -0
  23. {jaxsim-0.5.1.dev91 → jaxsim-0.5.1.dev95}/.readthedocs.yaml +0 -0
  24. {jaxsim-0.5.1.dev91 → jaxsim-0.5.1.dev95}/CONTRIBUTING.md +0 -0
  25. {jaxsim-0.5.1.dev91 → jaxsim-0.5.1.dev95}/LICENSE +0 -0
  26. {jaxsim-0.5.1.dev91 → jaxsim-0.5.1.dev95}/README.md +0 -0
  27. {jaxsim-0.5.1.dev91 → jaxsim-0.5.1.dev95}/docs/Makefile +0 -0
  28. {jaxsim-0.5.1.dev91 → jaxsim-0.5.1.dev95}/docs/conf.py +0 -0
  29. {jaxsim-0.5.1.dev91 → jaxsim-0.5.1.dev95}/docs/examples.rst +0 -0
  30. {jaxsim-0.5.1.dev91 → jaxsim-0.5.1.dev95}/docs/guide/configuration.rst +0 -0
  31. {jaxsim-0.5.1.dev91 → jaxsim-0.5.1.dev95}/docs/guide/install.rst +0 -0
  32. {jaxsim-0.5.1.dev91 → jaxsim-0.5.1.dev95}/docs/index.rst +0 -0
  33. {jaxsim-0.5.1.dev91 → jaxsim-0.5.1.dev95}/docs/make.bat +0 -0
  34. {jaxsim-0.5.1.dev91 → jaxsim-0.5.1.dev95}/docs/modules/api.rst +0 -0
  35. {jaxsim-0.5.1.dev91 → jaxsim-0.5.1.dev95}/docs/modules/integrators.rst +0 -0
  36. {jaxsim-0.5.1.dev91 → jaxsim-0.5.1.dev95}/docs/modules/math.rst +0 -0
  37. {jaxsim-0.5.1.dev91 → jaxsim-0.5.1.dev95}/docs/modules/mujoco.rst +0 -0
  38. {jaxsim-0.5.1.dev91 → jaxsim-0.5.1.dev95}/docs/modules/parsers.rst +0 -0
  39. {jaxsim-0.5.1.dev91 → jaxsim-0.5.1.dev95}/docs/modules/rbda.rst +0 -0
  40. {jaxsim-0.5.1.dev91 → jaxsim-0.5.1.dev95}/docs/modules/typing.rst +0 -0
  41. {jaxsim-0.5.1.dev91 → jaxsim-0.5.1.dev95}/docs/modules/utils.rst +0 -0
  42. {jaxsim-0.5.1.dev91 → jaxsim-0.5.1.dev95}/environment.yml +0 -0
  43. {jaxsim-0.5.1.dev91 → jaxsim-0.5.1.dev95}/examples/.gitattributes +0 -0
  44. {jaxsim-0.5.1.dev91 → jaxsim-0.5.1.dev95}/examples/.gitignore +0 -0
  45. {jaxsim-0.5.1.dev91 → jaxsim-0.5.1.dev95}/examples/README.md +0 -0
  46. {jaxsim-0.5.1.dev91 → jaxsim-0.5.1.dev95}/examples/assets/build_cartpole_urdf.py +0 -0
  47. {jaxsim-0.5.1.dev91 → jaxsim-0.5.1.dev95}/examples/assets/cartpole.urdf +0 -0
  48. {jaxsim-0.5.1.dev91 → jaxsim-0.5.1.dev95}/examples/jaxsim_as_multibody_dynamics_library.ipynb +0 -0
  49. {jaxsim-0.5.1.dev91 → jaxsim-0.5.1.dev95}/examples/jaxsim_as_physics_engine.ipynb +0 -0
  50. {jaxsim-0.5.1.dev91 → jaxsim-0.5.1.dev95}/examples/jaxsim_for_robot_controllers.ipynb +0 -0
  51. {jaxsim-0.5.1.dev91 → jaxsim-0.5.1.dev95}/pixi.lock +0 -0
  52. {jaxsim-0.5.1.dev91 → jaxsim-0.5.1.dev95}/pyproject.toml +0 -0
  53. {jaxsim-0.5.1.dev91 → jaxsim-0.5.1.dev95}/setup.cfg +0 -0
  54. {jaxsim-0.5.1.dev91 → jaxsim-0.5.1.dev95}/setup.py +0 -0
  55. {jaxsim-0.5.1.dev91 → jaxsim-0.5.1.dev95}/src/jaxsim/__init__.py +0 -0
  56. {jaxsim-0.5.1.dev91 → jaxsim-0.5.1.dev95}/src/jaxsim/api/__init__.py +0 -0
  57. {jaxsim-0.5.1.dev91 → jaxsim-0.5.1.dev95}/src/jaxsim/api/com.py +0 -0
  58. {jaxsim-0.5.1.dev91 → jaxsim-0.5.1.dev95}/src/jaxsim/api/common.py +0 -0
  59. {jaxsim-0.5.1.dev91 → jaxsim-0.5.1.dev95}/src/jaxsim/api/contact.py +0 -0
  60. {jaxsim-0.5.1.dev91 → jaxsim-0.5.1.dev95}/src/jaxsim/api/frame.py +0 -0
  61. {jaxsim-0.5.1.dev91 → jaxsim-0.5.1.dev95}/src/jaxsim/api/joint.py +0 -0
  62. {jaxsim-0.5.1.dev91 → jaxsim-0.5.1.dev95}/src/jaxsim/api/kin_dyn_parameters.py +0 -0
  63. {jaxsim-0.5.1.dev91 → jaxsim-0.5.1.dev95}/src/jaxsim/api/link.py +0 -0
  64. {jaxsim-0.5.1.dev91 → jaxsim-0.5.1.dev95}/src/jaxsim/api/model.py +0 -0
  65. {jaxsim-0.5.1.dev91 → jaxsim-0.5.1.dev95}/src/jaxsim/api/ode.py +0 -0
  66. {jaxsim-0.5.1.dev91 → jaxsim-0.5.1.dev95}/src/jaxsim/api/ode_data.py +0 -0
  67. {jaxsim-0.5.1.dev91 → jaxsim-0.5.1.dev95}/src/jaxsim/api/references.py +0 -0
  68. {jaxsim-0.5.1.dev91 → jaxsim-0.5.1.dev95}/src/jaxsim/exceptions.py +0 -0
  69. {jaxsim-0.5.1.dev91 → jaxsim-0.5.1.dev95}/src/jaxsim/integrators/__init__.py +0 -0
  70. {jaxsim-0.5.1.dev91 → jaxsim-0.5.1.dev95}/src/jaxsim/integrators/common.py +0 -0
  71. {jaxsim-0.5.1.dev91 → jaxsim-0.5.1.dev95}/src/jaxsim/integrators/fixed_step.py +0 -0
  72. {jaxsim-0.5.1.dev91 → jaxsim-0.5.1.dev95}/src/jaxsim/integrators/variable_step.py +0 -0
  73. {jaxsim-0.5.1.dev91 → jaxsim-0.5.1.dev95}/src/jaxsim/logging.py +0 -0
  74. {jaxsim-0.5.1.dev91 → jaxsim-0.5.1.dev95}/src/jaxsim/math/adjoint.py +0 -0
  75. {jaxsim-0.5.1.dev91 → jaxsim-0.5.1.dev95}/src/jaxsim/math/cross.py +0 -0
  76. {jaxsim-0.5.1.dev91 → jaxsim-0.5.1.dev95}/src/jaxsim/math/inertia.py +0 -0
  77. {jaxsim-0.5.1.dev91 → jaxsim-0.5.1.dev95}/src/jaxsim/math/joint_model.py +0 -0
  78. {jaxsim-0.5.1.dev91 → jaxsim-0.5.1.dev95}/src/jaxsim/math/skew.py +0 -0
  79. {jaxsim-0.5.1.dev91 → jaxsim-0.5.1.dev95}/src/jaxsim/math/transform.py +0 -0
  80. {jaxsim-0.5.1.dev91 → jaxsim-0.5.1.dev95}/src/jaxsim/mujoco/__init__.py +0 -0
  81. {jaxsim-0.5.1.dev91 → jaxsim-0.5.1.dev95}/src/jaxsim/mujoco/__main__.py +0 -0
  82. {jaxsim-0.5.1.dev91 → jaxsim-0.5.1.dev95}/src/jaxsim/mujoco/loaders.py +0 -0
  83. {jaxsim-0.5.1.dev91 → jaxsim-0.5.1.dev95}/src/jaxsim/mujoco/model.py +0 -0
  84. {jaxsim-0.5.1.dev91 → jaxsim-0.5.1.dev95}/src/jaxsim/mujoco/utils.py +0 -0
  85. {jaxsim-0.5.1.dev91 → jaxsim-0.5.1.dev95}/src/jaxsim/mujoco/visualizer.py +0 -0
  86. {jaxsim-0.5.1.dev91 → jaxsim-0.5.1.dev95}/src/jaxsim/parsers/__init__.py +0 -0
  87. {jaxsim-0.5.1.dev91 → jaxsim-0.5.1.dev95}/src/jaxsim/parsers/descriptions/__init__.py +0 -0
  88. {jaxsim-0.5.1.dev91 → jaxsim-0.5.1.dev95}/src/jaxsim/parsers/descriptions/collision.py +0 -0
  89. {jaxsim-0.5.1.dev91 → jaxsim-0.5.1.dev95}/src/jaxsim/parsers/descriptions/joint.py +0 -0
  90. {jaxsim-0.5.1.dev91 → jaxsim-0.5.1.dev95}/src/jaxsim/parsers/descriptions/link.py +0 -0
  91. {jaxsim-0.5.1.dev91 → jaxsim-0.5.1.dev95}/src/jaxsim/parsers/descriptions/model.py +0 -0
  92. {jaxsim-0.5.1.dev91 → jaxsim-0.5.1.dev95}/src/jaxsim/parsers/kinematic_graph.py +0 -0
  93. {jaxsim-0.5.1.dev91 → jaxsim-0.5.1.dev95}/src/jaxsim/parsers/rod/__init__.py +0 -0
  94. {jaxsim-0.5.1.dev91 → jaxsim-0.5.1.dev95}/src/jaxsim/parsers/rod/meshes.py +0 -0
  95. {jaxsim-0.5.1.dev91 → jaxsim-0.5.1.dev95}/src/jaxsim/parsers/rod/parser.py +0 -0
  96. {jaxsim-0.5.1.dev91 → jaxsim-0.5.1.dev95}/src/jaxsim/parsers/rod/utils.py +0 -0
  97. {jaxsim-0.5.1.dev91 → jaxsim-0.5.1.dev95}/src/jaxsim/rbda/__init__.py +0 -0
  98. {jaxsim-0.5.1.dev91 → jaxsim-0.5.1.dev95}/src/jaxsim/rbda/aba.py +0 -0
  99. {jaxsim-0.5.1.dev91 → jaxsim-0.5.1.dev95}/src/jaxsim/rbda/collidable_points.py +0 -0
  100. {jaxsim-0.5.1.dev91 → jaxsim-0.5.1.dev95}/src/jaxsim/rbda/contacts/__init__.py +0 -0
  101. {jaxsim-0.5.1.dev91 → jaxsim-0.5.1.dev95}/src/jaxsim/rbda/contacts/common.py +0 -0
  102. {jaxsim-0.5.1.dev91 → jaxsim-0.5.1.dev95}/src/jaxsim/rbda/contacts/relaxed_rigid.py +0 -0
  103. {jaxsim-0.5.1.dev91 → jaxsim-0.5.1.dev95}/src/jaxsim/rbda/contacts/rigid.py +0 -0
  104. {jaxsim-0.5.1.dev91 → jaxsim-0.5.1.dev95}/src/jaxsim/rbda/contacts/visco_elastic.py +0 -0
  105. {jaxsim-0.5.1.dev91 → jaxsim-0.5.1.dev95}/src/jaxsim/rbda/crba.py +0 -0
  106. {jaxsim-0.5.1.dev91 → jaxsim-0.5.1.dev95}/src/jaxsim/rbda/forward_kinematics.py +0 -0
  107. {jaxsim-0.5.1.dev91 → jaxsim-0.5.1.dev95}/src/jaxsim/rbda/jacobian.py +0 -0
  108. {jaxsim-0.5.1.dev91 → jaxsim-0.5.1.dev95}/src/jaxsim/rbda/rnea.py +0 -0
  109. {jaxsim-0.5.1.dev91 → jaxsim-0.5.1.dev95}/src/jaxsim/rbda/utils.py +0 -0
  110. {jaxsim-0.5.1.dev91 → jaxsim-0.5.1.dev95}/src/jaxsim/terrain/__init__.py +0 -0
  111. {jaxsim-0.5.1.dev91 → jaxsim-0.5.1.dev95}/src/jaxsim/typing.py +0 -0
  112. {jaxsim-0.5.1.dev91 → jaxsim-0.5.1.dev95}/src/jaxsim/utils/__init__.py +0 -0
  113. {jaxsim-0.5.1.dev91 → jaxsim-0.5.1.dev95}/src/jaxsim/utils/jaxsim_dataclass.py +0 -0
  114. {jaxsim-0.5.1.dev91 → jaxsim-0.5.1.dev95}/src/jaxsim/utils/tracing.py +0 -0
  115. {jaxsim-0.5.1.dev91 → jaxsim-0.5.1.dev95}/src/jaxsim/utils/wrappers.py +0 -0
  116. {jaxsim-0.5.1.dev91 → jaxsim-0.5.1.dev95}/src/jaxsim.egg-info/dependency_links.txt +0 -0
  117. {jaxsim-0.5.1.dev91 → jaxsim-0.5.1.dev95}/src/jaxsim.egg-info/requires.txt +0 -0
  118. {jaxsim-0.5.1.dev91 → jaxsim-0.5.1.dev95}/src/jaxsim.egg-info/top_level.txt +0 -0
  119. {jaxsim-0.5.1.dev91 → jaxsim-0.5.1.dev95}/tests/__init__.py +0 -0
  120. {jaxsim-0.5.1.dev91 → jaxsim-0.5.1.dev95}/tests/conftest.py +0 -0
  121. {jaxsim-0.5.1.dev91 → jaxsim-0.5.1.dev95}/tests/test_api_com.py +0 -0
  122. {jaxsim-0.5.1.dev91 → jaxsim-0.5.1.dev95}/tests/test_api_contact.py +0 -0
  123. {jaxsim-0.5.1.dev91 → jaxsim-0.5.1.dev95}/tests/test_api_data.py +0 -0
  124. {jaxsim-0.5.1.dev91 → jaxsim-0.5.1.dev95}/tests/test_api_frame.py +0 -0
  125. {jaxsim-0.5.1.dev91 → jaxsim-0.5.1.dev95}/tests/test_api_joint.py +0 -0
  126. {jaxsim-0.5.1.dev91 → jaxsim-0.5.1.dev95}/tests/test_api_link.py +0 -0
  127. {jaxsim-0.5.1.dev91 → jaxsim-0.5.1.dev95}/tests/test_api_model.py +0 -0
  128. {jaxsim-0.5.1.dev91 → jaxsim-0.5.1.dev95}/tests/test_benchmark.py +0 -0
  129. {jaxsim-0.5.1.dev91 → jaxsim-0.5.1.dev95}/tests/test_contact.py +0 -0
  130. {jaxsim-0.5.1.dev91 → jaxsim-0.5.1.dev95}/tests/test_exceptions.py +0 -0
  131. {jaxsim-0.5.1.dev91 → jaxsim-0.5.1.dev95}/tests/test_meshes.py +0 -0
  132. {jaxsim-0.5.1.dev91 → jaxsim-0.5.1.dev95}/tests/test_pytree.py +0 -0
  133. {jaxsim-0.5.1.dev91 → jaxsim-0.5.1.dev95}/tests/test_simulations.py +0 -0
  134. {jaxsim-0.5.1.dev91 → jaxsim-0.5.1.dev95}/tests/utils_idyntree.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: jaxsim
3
- Version: 0.5.1.dev91
3
+ Version: 0.5.1.dev95
4
4
  Summary: A differentiable physics engine and multibody dynamics library for control and robot learning.
5
5
  Author-email: Diego Ferigo <dgferigo@gmail.com>, Filippo Luca Ferretti <filippoluca.ferretti@outlook.com>
6
6
  Maintainer-email: Filippo Luca Ferretti <filippo.ferretti@iit.it>, Alessandro Croci <alessandro.croci@iit.it>
@@ -12,5 +12,5 @@ __version__: str
12
12
  __version_tuple__: VERSION_TUPLE
13
13
  version_tuple: VERSION_TUPLE
14
14
 
15
- __version__ = version = '0.5.1.dev91'
16
- __version_tuple__ = version_tuple = (0, 5, 1, 'dev91')
15
+ __version__ = version = '0.5.1.dev95'
16
+ __version_tuple__ = version_tuple = (0, 5, 1, 'dev95')
@@ -382,9 +382,8 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
382
382
  # we introduce a Baumgarte stabilization to let the quaternion converge to
383
383
  # a unit quaternion. In this case, it is not guaranteed that the quaternion
384
384
  # stored in the state is a unit quaternion.
385
- W_Q_B = jnp.where(
386
- jnp.allclose(W_Q_B.dot(W_Q_B), 1.0), W_Q_B, W_Q_B / jnp.linalg.norm(W_Q_B)
387
- )
385
+ norm = jaxsim.math.safe_norm(W_Q_B)
386
+ W_Q_B = W_Q_B / (norm + jnp.finfo(float).eps * (norm == 0))
388
387
 
389
388
  return (W_Q_B if not dcm else jaxsim.math.Quaternion.to_dcm(W_Q_B)).astype(
390
389
  float
@@ -611,11 +610,8 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
611
610
 
612
611
  W_Q_B = jnp.array(base_quaternion, dtype=float)
613
612
 
614
- W_Q_B = jax.lax.select(
615
- pred=jnp.allclose(jnp.linalg.norm(W_Q_B), 1.0, atol=1e-6, rtol=0.0),
616
- on_true=W_Q_B,
617
- on_false=W_Q_B / jnp.linalg.norm(W_Q_B),
618
- )
613
+ norm = jaxsim.math.safe_norm(W_Q_B)
614
+ W_Q_B = W_Q_B / (norm + jnp.finfo(float).eps * (norm == 0))
619
615
 
620
616
  return self.replace(
621
617
  validate=True,
@@ -8,5 +8,6 @@ from .quaternion import Quaternion
8
8
  from .rotation import Rotation
9
9
  from .skew import Skew
10
10
  from .transform import Transform
11
+ from .utils import safe_norm
11
12
 
12
13
  from .joint_model import JointModel, supported_joint_motion # isort:skip
@@ -4,6 +4,8 @@ import jaxlie
4
4
 
5
5
  import jaxsim.typing as jtp
6
6
 
7
+ from .utils import safe_norm
8
+
7
9
 
8
10
  class Quaternion:
9
11
  @staticmethod
@@ -111,18 +113,13 @@ class Quaternion:
111
113
  operand=quaternion,
112
114
  )
113
115
 
114
- norm_ω = jax.lax.cond(
115
- pred=ω.dot(ω) < (1e-6) ** 2,
116
- true_fun=lambda _: 1e-6,
117
- false_fun=lambda _: jnp.linalg.norm(ω),
118
- operand=None,
119
- )
116
+ norm_ω = safe_norm(ω)
120
117
 
121
118
  qd = 0.5 * (
122
119
  Q
123
120
  @ jnp.hstack(
124
121
  [
125
- K * norm_ω * (1 - jnp.linalg.norm(quaternion)),
122
+ K * norm_ω * (1 - safe_norm(quaternion)),
126
123
  ω,
127
124
  ]
128
125
  )
@@ -4,6 +4,7 @@ import jaxlie
4
4
  import jaxsim.typing as jtp
5
5
 
6
6
  from .skew import Skew
7
+ from .utils import safe_norm
7
8
 
8
9
 
9
10
  class Rotation:
@@ -67,7 +68,7 @@ class Rotation:
67
68
  def theta_is_not_zero(axis: jtp.Vector) -> jtp.Matrix:
68
69
 
69
70
  v = axis
70
- theta = jnp.linalg.norm(v)
71
+ theta = safe_norm(v)
71
72
 
72
73
  s = jnp.sin(theta)
73
74
  c = jnp.cos(theta)
@@ -81,19 +82,9 @@ class Rotation:
81
82
 
82
83
  return R.transpose()
83
84
 
84
- # Use the double-where trick to prevent JAX problems when the
85
- # jax.jit and jax.grad transforms are applied.
86
85
  return jnp.where(
87
- jnp.linalg.norm(vector) > 0,
88
- theta_is_not_zero(
89
- axis=jnp.where(
90
- jnp.linalg.norm(vector) > 0,
91
- vector,
92
- # The following line is a workaround to prevent division by 0.
93
- # Considering the outer where, this branch is never executed.
94
- jnp.ones(3),
95
- )
96
- ),
86
+ jnp.allclose(vector, 0.0),
97
87
  # Return an identity rotation matrix when the input vector is zero.
98
88
  jnp.eye(3),
89
+ theta_is_not_zero(axis=vector),
99
90
  )
@@ -0,0 +1,31 @@
1
+ import jax.numpy as jnp
2
+
3
+ import jaxsim.typing as jtp
4
+
5
+
6
+ def safe_norm(array: jtp.ArrayLike, axis=None) -> jtp.Array:
7
+ """
8
+ Provides a calculation for an array norm so that it is safe
9
+ to compute the gradient and handle NaNs.
10
+
11
+ Args:
12
+ array: The array for which to compute the norm.
13
+ axis: The axis for which to compute the norm.
14
+
15
+ Returns:
16
+ The norm of the array with handling for zero arrays to avoid NaNs.
17
+ """
18
+
19
+ # Check if the entire array is composed of zeros.
20
+ is_zero = jnp.allclose(array, 0.0)
21
+
22
+ # Replace zeros with an array of ones temporarily to avoid division by zero.
23
+ # This ensures the computation of norm does not produce NaNs or Infs.
24
+ array = jnp.where(is_zero, jnp.ones_like(array), array)
25
+
26
+ # Compute the norm of the array along the specified axis.
27
+ norm = jnp.linalg.norm(array, axis=axis)
28
+
29
+ # Use `jnp.where` to set the norm to 0.0 where the input array was all zeros.
30
+ # This usage supports potential batch processing for future scalability.
31
+ return jnp.where(is_zero, 0.0, norm)
@@ -309,19 +309,16 @@ class SoftContacts(common.ContactModel):
309
309
 
310
310
  # Compute the direction of the tangential force.
311
311
  # To prevent dividing by zero, we use a switch statement.
312
- # The ε, instead, is needed to make AD happy.
313
- f_tangential_direction = jnp.where(
314
- f_tangential.dot(f_tangential) != 0,
315
- f_tangential / jnp.linalg.norm(f_tangential + ε),
316
- jnp.zeros(3),
312
+ norm = jaxsim.math.safe_norm(f_tangential)
313
+ f_tangential_direction = f_tangential / (
314
+ norm + jnp.finfo(float).eps * (norm == 0)
317
315
  )
318
316
 
319
317
  # Project the tangential force to the friction cone if slipping.
320
318
  f_tangential = jnp.where(
321
319
  sticking,
322
320
  f_tangential,
323
- jnp.minimum(μ * force_normal_mag, jnp.linalg.norm(f_tangential + ε))
324
- * f_tangential_direction,
321
+ jnp.minimum(μ * force_normal_mag, norm) * f_tangential_direction,
325
322
  )
326
323
 
327
324
  # Set the tangential force to zero if there is no contact.
@@ -7,6 +7,7 @@ import jax.numpy as jnp
7
7
  import jax_dataclasses
8
8
  import numpy as np
9
9
 
10
+ import jaxsim.math
10
11
  import jaxsim.typing as jtp
11
12
  from jaxsim import exceptions
12
13
 
@@ -41,7 +42,7 @@ class Terrain(abc.ABC):
41
42
  [(h_xm - h_xp) / (2 * self.delta), (h_ym - h_yp) / (2 * self.delta), 1.0]
42
43
  )
43
44
 
44
- return n / jnp.linalg.norm(n)
45
+ return n / jaxsim.math.safe_norm(n)
45
46
 
46
47
 
47
48
  @jax_dataclasses.pytree_dataclass
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: jaxsim
3
- Version: 0.5.1.dev91
3
+ Version: 0.5.1.dev95
4
4
  Summary: A differentiable physics engine and multibody dynamics library for control and robot learning.
5
5
  Author-email: Diego Ferigo <dgferigo@gmail.com>, Filippo Luca Ferretti <filippoluca.ferretti@outlook.com>
6
6
  Maintainer-email: Filippo Luca Ferretti <filippo.ferretti@iit.it>, Alessandro Croci <alessandro.croci@iit.it>
@@ -75,6 +75,7 @@ src/jaxsim/math/quaternion.py
75
75
  src/jaxsim/math/rotation.py
76
76
  src/jaxsim/math/skew.py
77
77
  src/jaxsim/math/transform.py
78
+ src/jaxsim/math/utils.py
78
79
  src/jaxsim/mujoco/__init__.py
79
80
  src/jaxsim/mujoco/__main__.py
80
81
  src/jaxsim/mujoco/loaders.py
@@ -2,6 +2,7 @@ import os
2
2
 
3
3
  import jax
4
4
  import jax.numpy as jnp
5
+ import numpy as np
5
6
  from jax.test_util import check_grads
6
7
 
7
8
  import jaxsim.api as js
@@ -413,3 +414,45 @@ def test_ad_integration(
413
414
  modes=["rev", "fwd"],
414
415
  eps=ε,
415
416
  )
417
+
418
+
419
+ def test_ad_safe_norm(
420
+ prng_key: jax.Array,
421
+ ):
422
+
423
+ _, subkey = jax.random.split(prng_key, num=2)
424
+ array = jax.random.uniform(subkey, shape=(4,), minval=-5, maxval=5)
425
+
426
+ # ====
427
+ # Test
428
+ # ====
429
+
430
+ # Test that the safe_norm function is compatible with batching.
431
+ array = jnp.stack([array, array])
432
+ assert jaxsim.math.safe_norm(array, axis=1).shape == (2,)
433
+
434
+ # Test that the safe_norm function is correctly computing the norm.
435
+ assert np.allclose(jaxsim.math.safe_norm(array), np.linalg.norm(array))
436
+
437
+ # Function exposing only the parameters to be differentiated.
438
+ def safe_norm(array: jtp.Array) -> jtp.Array:
439
+
440
+ return jaxsim.math.safe_norm(array)
441
+
442
+ # Check derivatives against finite differences.
443
+ check_grads(
444
+ f=safe_norm,
445
+ args=(array,),
446
+ order=AD_ORDER,
447
+ modes=["rev", "fwd"],
448
+ eps=ε,
449
+ )
450
+
451
+ # Check derivatives against finite differences when the array is zero.
452
+ check_grads(
453
+ f=safe_norm,
454
+ args=(jnp.zeros_like(array),),
455
+ order=AD_ORDER,
456
+ modes=["rev", "fwd"],
457
+ eps=ε,
458
+ )
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes