jaxsim 0.7.1.dev40__tar.gz → 0.7.1.dev46__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (136)
  1. {jaxsim-0.7.1.dev40 → jaxsim-0.7.1.dev46}/.github/workflows/gpu_benchmark.yml +5 -3
  2. {jaxsim-0.7.1.dev40 → jaxsim-0.7.1.dev46}/PKG-INFO +1 -1
  3. {jaxsim-0.7.1.dev40 → jaxsim-0.7.1.dev46}/src/jaxsim/_version.py +2 -2
  4. jaxsim-0.7.1.dev46/src/jaxsim/math/utils.py +58 -0
  5. {jaxsim-0.7.1.dev40 → jaxsim-0.7.1.dev46}/src/jaxsim.egg-info/PKG-INFO +1 -1
  6. jaxsim-0.7.1.dev40/src/jaxsim/math/utils.py +0 -32
  7. {jaxsim-0.7.1.dev40 → jaxsim-0.7.1.dev46}/.devcontainer/Dockerfile +0 -0
  8. {jaxsim-0.7.1.dev40 → jaxsim-0.7.1.dev46}/.devcontainer/devcontainer.json +0 -0
  9. {jaxsim-0.7.1.dev40 → jaxsim-0.7.1.dev46}/.gitattributes +0 -0
  10. {jaxsim-0.7.1.dev40 → jaxsim-0.7.1.dev46}/.github/CODEOWNERS +0 -0
  11. {jaxsim-0.7.1.dev40 → jaxsim-0.7.1.dev46}/.github/dependabot.yml +0 -0
  12. {jaxsim-0.7.1.dev40 → jaxsim-0.7.1.dev46}/.github/release.yml +0 -0
  13. {jaxsim-0.7.1.dev40 → jaxsim-0.7.1.dev46}/.github/workflows/ci_cd.yml +0 -0
  14. {jaxsim-0.7.1.dev40 → jaxsim-0.7.1.dev46}/.github/workflows/pixi.yml +0 -0
  15. {jaxsim-0.7.1.dev40 → jaxsim-0.7.1.dev46}/.github/workflows/read_the_docs.yml +0 -0
  16. {jaxsim-0.7.1.dev40 → jaxsim-0.7.1.dev46}/.gitignore +0 -0
  17. {jaxsim-0.7.1.dev40 → jaxsim-0.7.1.dev46}/.pre-commit-config.yaml +0 -0
  18. {jaxsim-0.7.1.dev40 → jaxsim-0.7.1.dev46}/.readthedocs.yaml +0 -0
  19. {jaxsim-0.7.1.dev40 → jaxsim-0.7.1.dev46}/CONTRIBUTING.md +0 -0
  20. {jaxsim-0.7.1.dev40 → jaxsim-0.7.1.dev46}/LICENSE +0 -0
  21. {jaxsim-0.7.1.dev40 → jaxsim-0.7.1.dev46}/README.md +0 -0
  22. {jaxsim-0.7.1.dev40 → jaxsim-0.7.1.dev46}/docs/Makefile +0 -0
  23. {jaxsim-0.7.1.dev40 → jaxsim-0.7.1.dev46}/docs/conf.py +0 -0
  24. {jaxsim-0.7.1.dev40 → jaxsim-0.7.1.dev46}/docs/examples.rst +0 -0
  25. {jaxsim-0.7.1.dev40 → jaxsim-0.7.1.dev46}/docs/guide/configuration.rst +0 -0
  26. {jaxsim-0.7.1.dev40 → jaxsim-0.7.1.dev46}/docs/guide/install.rst +0 -0
  27. {jaxsim-0.7.1.dev40 → jaxsim-0.7.1.dev46}/docs/index.rst +0 -0
  28. {jaxsim-0.7.1.dev40 → jaxsim-0.7.1.dev46}/docs/make.bat +0 -0
  29. {jaxsim-0.7.1.dev40 → jaxsim-0.7.1.dev46}/docs/modules/api.rst +0 -0
  30. {jaxsim-0.7.1.dev40 → jaxsim-0.7.1.dev46}/docs/modules/math.rst +0 -0
  31. {jaxsim-0.7.1.dev40 → jaxsim-0.7.1.dev46}/docs/modules/mujoco.rst +0 -0
  32. {jaxsim-0.7.1.dev40 → jaxsim-0.7.1.dev46}/docs/modules/parsers.rst +0 -0
  33. {jaxsim-0.7.1.dev40 → jaxsim-0.7.1.dev46}/docs/modules/rbda.rst +0 -0
  34. {jaxsim-0.7.1.dev40 → jaxsim-0.7.1.dev46}/docs/modules/typing.rst +0 -0
  35. {jaxsim-0.7.1.dev40 → jaxsim-0.7.1.dev46}/docs/modules/utils.rst +0 -0
  36. {jaxsim-0.7.1.dev40 → jaxsim-0.7.1.dev46}/environment.yml +0 -0
  37. {jaxsim-0.7.1.dev40 → jaxsim-0.7.1.dev46}/examples/.gitattributes +0 -0
  38. {jaxsim-0.7.1.dev40 → jaxsim-0.7.1.dev46}/examples/.gitignore +0 -0
  39. {jaxsim-0.7.1.dev40 → jaxsim-0.7.1.dev46}/examples/README.md +0 -0
  40. {jaxsim-0.7.1.dev40 → jaxsim-0.7.1.dev46}/examples/assets/build_cartpole_urdf.py +0 -0
  41. {jaxsim-0.7.1.dev40 → jaxsim-0.7.1.dev46}/examples/assets/cartpole.urdf +0 -0
  42. {jaxsim-0.7.1.dev40 → jaxsim-0.7.1.dev46}/examples/jaxsim_as_multibody_dynamics_library.ipynb +0 -0
  43. {jaxsim-0.7.1.dev40 → jaxsim-0.7.1.dev46}/examples/jaxsim_as_physics_engine.ipynb +0 -0
  44. {jaxsim-0.7.1.dev40 → jaxsim-0.7.1.dev46}/examples/jaxsim_as_physics_engine_advanced.ipynb +0 -0
  45. {jaxsim-0.7.1.dev40 → jaxsim-0.7.1.dev46}/examples/jaxsim_for_robot_controllers.ipynb +0 -0
  46. {jaxsim-0.7.1.dev40 → jaxsim-0.7.1.dev46}/pixi.lock +0 -0
  47. {jaxsim-0.7.1.dev40 → jaxsim-0.7.1.dev46}/pyproject.toml +0 -0
  48. {jaxsim-0.7.1.dev40 → jaxsim-0.7.1.dev46}/setup.cfg +0 -0
  49. {jaxsim-0.7.1.dev40 → jaxsim-0.7.1.dev46}/src/jaxsim/__init__.py +0 -0
  50. {jaxsim-0.7.1.dev40 → jaxsim-0.7.1.dev46}/src/jaxsim/api/__init__.py +0 -0
  51. {jaxsim-0.7.1.dev40 → jaxsim-0.7.1.dev46}/src/jaxsim/api/actuation_model.py +0 -0
  52. {jaxsim-0.7.1.dev40 → jaxsim-0.7.1.dev46}/src/jaxsim/api/com.py +0 -0
  53. {jaxsim-0.7.1.dev40 → jaxsim-0.7.1.dev46}/src/jaxsim/api/common.py +0 -0
  54. {jaxsim-0.7.1.dev40 → jaxsim-0.7.1.dev46}/src/jaxsim/api/contact.py +0 -0
  55. {jaxsim-0.7.1.dev40 → jaxsim-0.7.1.dev46}/src/jaxsim/api/data.py +0 -0
  56. {jaxsim-0.7.1.dev40 → jaxsim-0.7.1.dev46}/src/jaxsim/api/frame.py +0 -0
  57. {jaxsim-0.7.1.dev40 → jaxsim-0.7.1.dev46}/src/jaxsim/api/integrators.py +0 -0
  58. {jaxsim-0.7.1.dev40 → jaxsim-0.7.1.dev46}/src/jaxsim/api/joint.py +0 -0
  59. {jaxsim-0.7.1.dev40 → jaxsim-0.7.1.dev46}/src/jaxsim/api/kin_dyn_parameters.py +0 -0
  60. {jaxsim-0.7.1.dev40 → jaxsim-0.7.1.dev46}/src/jaxsim/api/link.py +0 -0
  61. {jaxsim-0.7.1.dev40 → jaxsim-0.7.1.dev46}/src/jaxsim/api/model.py +0 -0
  62. {jaxsim-0.7.1.dev40 → jaxsim-0.7.1.dev46}/src/jaxsim/api/ode.py +0 -0
  63. {jaxsim-0.7.1.dev40 → jaxsim-0.7.1.dev46}/src/jaxsim/api/references.py +0 -0
  64. {jaxsim-0.7.1.dev40 → jaxsim-0.7.1.dev46}/src/jaxsim/exceptions.py +0 -0
  65. {jaxsim-0.7.1.dev40 → jaxsim-0.7.1.dev46}/src/jaxsim/logging.py +0 -0
  66. {jaxsim-0.7.1.dev40 → jaxsim-0.7.1.dev46}/src/jaxsim/math/__init__.py +0 -0
  67. {jaxsim-0.7.1.dev40 → jaxsim-0.7.1.dev46}/src/jaxsim/math/adjoint.py +0 -0
  68. {jaxsim-0.7.1.dev40 → jaxsim-0.7.1.dev46}/src/jaxsim/math/cross.py +0 -0
  69. {jaxsim-0.7.1.dev40 → jaxsim-0.7.1.dev46}/src/jaxsim/math/inertia.py +0 -0
  70. {jaxsim-0.7.1.dev40 → jaxsim-0.7.1.dev46}/src/jaxsim/math/joint_model.py +0 -0
  71. {jaxsim-0.7.1.dev40 → jaxsim-0.7.1.dev46}/src/jaxsim/math/quaternion.py +0 -0
  72. {jaxsim-0.7.1.dev40 → jaxsim-0.7.1.dev46}/src/jaxsim/math/rotation.py +0 -0
  73. {jaxsim-0.7.1.dev40 → jaxsim-0.7.1.dev46}/src/jaxsim/math/skew.py +0 -0
  74. {jaxsim-0.7.1.dev40 → jaxsim-0.7.1.dev46}/src/jaxsim/math/transform.py +0 -0
  75. {jaxsim-0.7.1.dev40 → jaxsim-0.7.1.dev46}/src/jaxsim/mujoco/__init__.py +0 -0
  76. {jaxsim-0.7.1.dev40 → jaxsim-0.7.1.dev46}/src/jaxsim/mujoco/__main__.py +0 -0
  77. {jaxsim-0.7.1.dev40 → jaxsim-0.7.1.dev46}/src/jaxsim/mujoco/loaders.py +0 -0
  78. {jaxsim-0.7.1.dev40 → jaxsim-0.7.1.dev46}/src/jaxsim/mujoco/model.py +0 -0
  79. {jaxsim-0.7.1.dev40 → jaxsim-0.7.1.dev46}/src/jaxsim/mujoco/utils.py +0 -0
  80. {jaxsim-0.7.1.dev40 → jaxsim-0.7.1.dev46}/src/jaxsim/mujoco/visualizer.py +0 -0
  81. {jaxsim-0.7.1.dev40 → jaxsim-0.7.1.dev46}/src/jaxsim/parsers/__init__.py +0 -0
  82. {jaxsim-0.7.1.dev40 → jaxsim-0.7.1.dev46}/src/jaxsim/parsers/descriptions/__init__.py +0 -0
  83. {jaxsim-0.7.1.dev40 → jaxsim-0.7.1.dev46}/src/jaxsim/parsers/descriptions/collision.py +0 -0
  84. {jaxsim-0.7.1.dev40 → jaxsim-0.7.1.dev46}/src/jaxsim/parsers/descriptions/joint.py +0 -0
  85. {jaxsim-0.7.1.dev40 → jaxsim-0.7.1.dev46}/src/jaxsim/parsers/descriptions/link.py +0 -0
  86. {jaxsim-0.7.1.dev40 → jaxsim-0.7.1.dev46}/src/jaxsim/parsers/descriptions/model.py +0 -0
  87. {jaxsim-0.7.1.dev40 → jaxsim-0.7.1.dev46}/src/jaxsim/parsers/kinematic_graph.py +0 -0
  88. {jaxsim-0.7.1.dev40 → jaxsim-0.7.1.dev46}/src/jaxsim/parsers/rod/__init__.py +0 -0
  89. {jaxsim-0.7.1.dev40 → jaxsim-0.7.1.dev46}/src/jaxsim/parsers/rod/meshes.py +0 -0
  90. {jaxsim-0.7.1.dev40 → jaxsim-0.7.1.dev46}/src/jaxsim/parsers/rod/parser.py +0 -0
  91. {jaxsim-0.7.1.dev40 → jaxsim-0.7.1.dev46}/src/jaxsim/parsers/rod/utils.py +0 -0
  92. {jaxsim-0.7.1.dev40 → jaxsim-0.7.1.dev46}/src/jaxsim/rbda/__init__.py +0 -0
  93. {jaxsim-0.7.1.dev40 → jaxsim-0.7.1.dev46}/src/jaxsim/rbda/aba.py +0 -0
  94. {jaxsim-0.7.1.dev40 → jaxsim-0.7.1.dev46}/src/jaxsim/rbda/actuation/__init__.py +0 -0
  95. {jaxsim-0.7.1.dev40 → jaxsim-0.7.1.dev46}/src/jaxsim/rbda/actuation/common.py +0 -0
  96. {jaxsim-0.7.1.dev40 → jaxsim-0.7.1.dev46}/src/jaxsim/rbda/collidable_points.py +0 -0
  97. {jaxsim-0.7.1.dev40 → jaxsim-0.7.1.dev46}/src/jaxsim/rbda/contacts/__init__.py +0 -0
  98. {jaxsim-0.7.1.dev40 → jaxsim-0.7.1.dev46}/src/jaxsim/rbda/contacts/common.py +0 -0
  99. {jaxsim-0.7.1.dev40 → jaxsim-0.7.1.dev46}/src/jaxsim/rbda/contacts/relaxed_rigid.py +0 -0
  100. {jaxsim-0.7.1.dev40 → jaxsim-0.7.1.dev46}/src/jaxsim/rbda/contacts/rigid.py +0 -0
  101. {jaxsim-0.7.1.dev40 → jaxsim-0.7.1.dev46}/src/jaxsim/rbda/contacts/soft.py +0 -0
  102. {jaxsim-0.7.1.dev40 → jaxsim-0.7.1.dev46}/src/jaxsim/rbda/crba.py +0 -0
  103. {jaxsim-0.7.1.dev40 → jaxsim-0.7.1.dev46}/src/jaxsim/rbda/forward_kinematics.py +0 -0
  104. {jaxsim-0.7.1.dev40 → jaxsim-0.7.1.dev46}/src/jaxsim/rbda/jacobian.py +0 -0
  105. {jaxsim-0.7.1.dev40 → jaxsim-0.7.1.dev46}/src/jaxsim/rbda/rnea.py +0 -0
  106. {jaxsim-0.7.1.dev40 → jaxsim-0.7.1.dev46}/src/jaxsim/rbda/utils.py +0 -0
  107. {jaxsim-0.7.1.dev40 → jaxsim-0.7.1.dev46}/src/jaxsim/terrain/__init__.py +0 -0
  108. {jaxsim-0.7.1.dev40 → jaxsim-0.7.1.dev46}/src/jaxsim/terrain/terrain.py +0 -0
  109. {jaxsim-0.7.1.dev40 → jaxsim-0.7.1.dev46}/src/jaxsim/typing.py +0 -0
  110. {jaxsim-0.7.1.dev40 → jaxsim-0.7.1.dev46}/src/jaxsim/utils/__init__.py +0 -0
  111. {jaxsim-0.7.1.dev40 → jaxsim-0.7.1.dev46}/src/jaxsim/utils/jaxsim_dataclass.py +0 -0
  112. {jaxsim-0.7.1.dev40 → jaxsim-0.7.1.dev46}/src/jaxsim/utils/tracing.py +0 -0
  113. {jaxsim-0.7.1.dev40 → jaxsim-0.7.1.dev46}/src/jaxsim/utils/wrappers.py +0 -0
  114. {jaxsim-0.7.1.dev40 → jaxsim-0.7.1.dev46}/src/jaxsim.egg-info/SOURCES.txt +0 -0
  115. {jaxsim-0.7.1.dev40 → jaxsim-0.7.1.dev46}/src/jaxsim.egg-info/dependency_links.txt +0 -0
  116. {jaxsim-0.7.1.dev40 → jaxsim-0.7.1.dev46}/src/jaxsim.egg-info/requires.txt +0 -0
  117. {jaxsim-0.7.1.dev40 → jaxsim-0.7.1.dev46}/src/jaxsim.egg-info/top_level.txt +0 -0
  118. {jaxsim-0.7.1.dev40 → jaxsim-0.7.1.dev46}/tests/__init__.py +0 -0
  119. {jaxsim-0.7.1.dev40 → jaxsim-0.7.1.dev46}/tests/conftest.py +0 -0
  120. {jaxsim-0.7.1.dev40 → jaxsim-0.7.1.dev46}/tests/test_actuation.py +0 -0
  121. {jaxsim-0.7.1.dev40 → jaxsim-0.7.1.dev46}/tests/test_api_com.py +0 -0
  122. {jaxsim-0.7.1.dev40 → jaxsim-0.7.1.dev46}/tests/test_api_contact.py +0 -0
  123. {jaxsim-0.7.1.dev40 → jaxsim-0.7.1.dev46}/tests/test_api_data.py +0 -0
  124. {jaxsim-0.7.1.dev40 → jaxsim-0.7.1.dev46}/tests/test_api_frame.py +0 -0
  125. {jaxsim-0.7.1.dev40 → jaxsim-0.7.1.dev46}/tests/test_api_joint.py +0 -0
  126. {jaxsim-0.7.1.dev40 → jaxsim-0.7.1.dev46}/tests/test_api_link.py +0 -0
  127. {jaxsim-0.7.1.dev40 → jaxsim-0.7.1.dev46}/tests/test_api_model.py +0 -0
  128. {jaxsim-0.7.1.dev40 → jaxsim-0.7.1.dev46}/tests/test_api_model_hw_parametrization.py +0 -0
  129. {jaxsim-0.7.1.dev40 → jaxsim-0.7.1.dev46}/tests/test_automatic_differentiation.py +0 -0
  130. {jaxsim-0.7.1.dev40 → jaxsim-0.7.1.dev46}/tests/test_benchmark.py +0 -0
  131. {jaxsim-0.7.1.dev40 → jaxsim-0.7.1.dev46}/tests/test_exceptions.py +0 -0
  132. {jaxsim-0.7.1.dev40 → jaxsim-0.7.1.dev46}/tests/test_meshes.py +0 -0
  133. {jaxsim-0.7.1.dev40 → jaxsim-0.7.1.dev46}/tests/test_pytree.py +0 -0
  134. {jaxsim-0.7.1.dev40 → jaxsim-0.7.1.dev46}/tests/test_simulations.py +0 -0
  135. {jaxsim-0.7.1.dev40 → jaxsim-0.7.1.dev46}/tests/test_visualizer.py +0 -0
  136. {jaxsim-0.7.1.dev40 → jaxsim-0.7.1.dev46}/tests/utils_idyntree.py +0 -0
@@ -1,6 +1,9 @@
1
1
  name: GPU Benchmarks
2
2
 
3
3
  on:
4
+ push:
5
+ branches:
6
+ - main
4
7
  pull_request:
5
8
  types: [opened, reopened, synchronize]
6
9
  workflow_dispatch:
@@ -46,7 +49,7 @@ jobs:
46
49
  uses: actions/cache/restore@v4
47
50
  with:
48
51
  path: ./cache
49
- key: ${{ steps.get-main-branch-sha.outputs.sha }}-${{ runner.os }}-benchmark
52
+ key: ${{ runner.os }}-benchmark
50
53
 
51
54
  - name: Ensure version file is written
52
55
  run: |
@@ -60,7 +63,6 @@ jobs:
60
63
 
61
64
  - name: Compare benchmark results with main branch
62
65
  uses: benchmark-action/github-action-benchmark@v1.20.4
63
- if: steps.cache.outputs.cache-hit == 'true'
64
66
  with:
65
67
  tool: 'pytest'
66
68
  output-file-path: output.json
@@ -106,4 +108,4 @@ jobs:
106
108
  if: ${{ github.ref_name == 'main' }}
107
109
  with:
108
110
  path: ./cache
109
- key: ${{ steps.get-main-branch-sha.outputs.sha }}-${{ runner.os }}-benchmark
111
+ key: ${{ runner.os }}-benchmark
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: jaxsim
3
- Version: 0.7.1.dev40
3
+ Version: 0.7.1.dev46
4
4
  Summary: A differentiable physics engine and multibody dynamics library for control and robot learning.
5
5
  Author-email: Diego Ferigo <dgferigo@gmail.com>, Filippo Luca Ferretti <filippoluca.ferretti@outlook.com>
6
6
  Maintainer-email: Filippo Luca Ferretti <filippo.ferretti@iit.it>, Alessandro Croci <alessandro.croci@iit.it>
@@ -17,5 +17,5 @@ __version__: str
17
17
  __version_tuple__: VERSION_TUPLE
18
18
  version_tuple: VERSION_TUPLE
19
19
 
20
- __version__ = version = '0.7.1.dev40'
21
- __version_tuple__ = version_tuple = (0, 7, 1, 'dev40')
20
+ __version__ = version = '0.7.1.dev46'
21
+ __version_tuple__ = version_tuple = (0, 7, 1, 'dev46')
@@ -0,0 +1,58 @@
1
+ import jax
2
+ import jax.numpy as jnp
3
+
4
+ import jaxsim.typing as jtp
5
+
6
+
7
+ def _make_safe_norm(axis, keepdims):
8
+ @jax.custom_jvp
9
+ def _safe_norm(array: jtp.ArrayLike) -> jtp.Array:
10
+ """
11
+ Compute an array norm handling NaNs and making sure that
12
+ it is safe to get the gradient.
13
+
14
+ Args:
15
+ array: The array for which to compute the norm.
16
+
17
+ Returns:
18
+ The norm of the array with handling for zero arrays to avoid NaNs.
19
+ """
20
+ # Compute the norm of the array along the specified axis.
21
+ return jnp.linalg.norm(array, axis=axis, keepdims=keepdims)
22
+
23
+ @_safe_norm.defjvp
24
+ def _safe_norm_jvp(primals, tangents):
25
+ (x,), (x_dot,) = primals, tangents
26
+
27
+ # Check if the entire array is composed of zeros.
28
+ is_zero = jnp.all(x == 0.0)
29
+
30
+ # Replace zeros with an array of ones temporarily to avoid division by zero.
31
+ # This ensures the computation of norm does not produce NaNs or Infs.
32
+ array = jnp.where(is_zero, jnp.ones_like(x), x)
33
+
34
+ # Compute the norm of the array along the specified axis.
35
+ norm = jnp.linalg.norm(array, axis=axis, keepdims=keepdims)
36
+
37
+ dot = jnp.sum(array * x_dot, axis=axis, keepdims=keepdims)
38
+ tangent = jnp.where(is_zero, 0.0, dot / norm)
39
+
40
+ return jnp.where(is_zero, 0.0, norm), tangent
41
+
42
+ return _safe_norm
43
+
44
+
45
def safe_norm(array: jtp.ArrayLike, *, axis=None, keepdims: bool = False) -> jtp.Array:
    """
    Compute an array norm handling NaNs and making sure that
    it is safe to get the gradient.

    Args:
        array: The array for which to compute the norm. Any array-like
            accepted by `jnp.asarray` (e.g. lists or tuples) is converted
            to a JAX array first.
        axis: The axis (or axes) along which to compute the norm.
        keepdims: Whether to keep the reduced dimensions of the input
            in the output as axes of size one.

    Returns:
        The norm of the array with handling for zero arrays to avoid NaNs.
    """
    # `jax.custom_jvp` flattens its arguments as pytrees, so a list or tuple
    # would reach the norm (and its JVP rule) as a sequence of scalars rather
    # than as one array. Convert up front so any array-like is handled.
    array = jnp.asarray(array)

    return _make_safe_norm(axis, keepdims)(array)
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: jaxsim
3
- Version: 0.7.1.dev40
3
+ Version: 0.7.1.dev46
4
4
  Summary: A differentiable physics engine and multibody dynamics library for control and robot learning.
5
5
  Author-email: Diego Ferigo <dgferigo@gmail.com>, Filippo Luca Ferretti <filippoluca.ferretti@outlook.com>
6
6
  Maintainer-email: Filippo Luca Ferretti <filippo.ferretti@iit.it>, Alessandro Croci <alessandro.croci@iit.it>
@@ -1,32 +0,0 @@
1
- import jax.numpy as jnp
2
-
3
- import jaxsim.typing as jtp
4
-
5
-
6
def safe_norm(array: jtp.ArrayLike, *, axis=None, keepdims: bool = False) -> jtp.Array:
    """
    Compute an array norm handling NaNs and making sure that
    it is safe to get the gradient.

    Args:
        array: The array for which to compute the norm.
        axis: The axis (or axes) along which to compute the norm.
        keepdims: Whether to keep the reduced dimensions of the input
            in the output as axes of size one.

    Returns:
        The norm of the array, equal to zero (with a zero gradient) on every
        reduced slice whose norm is zero.
    """

    x = jnp.asarray(array)

    # Zero-norm mask computed per reduced slice, not for the whole array, and
    # with exact zero rather than an `allclose` tolerance (which reported a
    # norm of 0 for tiny non-zero inputs). It is kept broadcastable against `x`
    # and only feeds comparisons, so no gradient flows through this raw norm.
    is_zero = jnp.linalg.norm(x, axis=axis, keepdims=True) == 0.0

    # Replace zero slices with ones so the norm's gradient never evaluates
    # sqrt'(0); the outer `where` then restores the exact zero value.
    safe_x = jnp.where(is_zero, jnp.ones_like(x), x)
    norm = jnp.where(
        is_zero, 0.0, jnp.linalg.norm(safe_x, axis=axis, keepdims=True)
    )

    # The reductions above kept the reduced axes so the mask could broadcast;
    # drop them unless the caller asked to keep them.
    return norm if keepdims else jnp.squeeze(norm, axis=axis)
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes