jaxsim 0.5.1.dev95__tar.gz → 0.5.1.dev103__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (134)
  1. {jaxsim-0.5.1.dev95 → jaxsim-0.5.1.dev103}/.pre-commit-config.yaml +3 -3
  2. {jaxsim-0.5.1.dev95 → jaxsim-0.5.1.dev103}/PKG-INFO +1 -1
  3. {jaxsim-0.5.1.dev95 → jaxsim-0.5.1.dev103}/src/jaxsim/__init__.py +11 -3
  4. {jaxsim-0.5.1.dev95 → jaxsim-0.5.1.dev103}/src/jaxsim/_version.py +2 -2
  5. {jaxsim-0.5.1.dev95 → jaxsim-0.5.1.dev103}/src/jaxsim/exceptions.py +3 -2
  6. {jaxsim-0.5.1.dev95 → jaxsim-0.5.1.dev103}/src/jaxsim/parsers/rod/utils.py +3 -3
  7. {jaxsim-0.5.1.dev95 → jaxsim-0.5.1.dev103}/src/jaxsim.egg-info/PKG-INFO +1 -1
  8. {jaxsim-0.5.1.dev95 → jaxsim-0.5.1.dev103}/.devcontainer/Dockerfile +0 -0
  9. {jaxsim-0.5.1.dev95 → jaxsim-0.5.1.dev103}/.devcontainer/devcontainer.json +0 -0
  10. {jaxsim-0.5.1.dev95 → jaxsim-0.5.1.dev103}/.gitattributes +0 -0
  11. {jaxsim-0.5.1.dev95 → jaxsim-0.5.1.dev103}/.github/CODEOWNERS +0 -0
  12. {jaxsim-0.5.1.dev95 → jaxsim-0.5.1.dev103}/.github/dependabot.yml +0 -0
  13. {jaxsim-0.5.1.dev95 → jaxsim-0.5.1.dev103}/.github/workflows/ci_cd.yml +0 -0
  14. {jaxsim-0.5.1.dev95 → jaxsim-0.5.1.dev103}/.github/workflows/pixi.yml +0 -0
  15. {jaxsim-0.5.1.dev95 → jaxsim-0.5.1.dev103}/.github/workflows/read_the_docs.yml +0 -0
  16. {jaxsim-0.5.1.dev95 → jaxsim-0.5.1.dev103}/.gitignore +0 -0
  17. {jaxsim-0.5.1.dev95 → jaxsim-0.5.1.dev103}/.readthedocs.yaml +0 -0
  18. {jaxsim-0.5.1.dev95 → jaxsim-0.5.1.dev103}/CONTRIBUTING.md +0 -0
  19. {jaxsim-0.5.1.dev95 → jaxsim-0.5.1.dev103}/LICENSE +0 -0
  20. {jaxsim-0.5.1.dev95 → jaxsim-0.5.1.dev103}/README.md +0 -0
  21. {jaxsim-0.5.1.dev95 → jaxsim-0.5.1.dev103}/docs/Makefile +0 -0
  22. {jaxsim-0.5.1.dev95 → jaxsim-0.5.1.dev103}/docs/conf.py +0 -0
  23. {jaxsim-0.5.1.dev95 → jaxsim-0.5.1.dev103}/docs/examples.rst +0 -0
  24. {jaxsim-0.5.1.dev95 → jaxsim-0.5.1.dev103}/docs/guide/configuration.rst +0 -0
  25. {jaxsim-0.5.1.dev95 → jaxsim-0.5.1.dev103}/docs/guide/install.rst +0 -0
  26. {jaxsim-0.5.1.dev95 → jaxsim-0.5.1.dev103}/docs/index.rst +0 -0
  27. {jaxsim-0.5.1.dev95 → jaxsim-0.5.1.dev103}/docs/make.bat +0 -0
  28. {jaxsim-0.5.1.dev95 → jaxsim-0.5.1.dev103}/docs/modules/api.rst +0 -0
  29. {jaxsim-0.5.1.dev95 → jaxsim-0.5.1.dev103}/docs/modules/integrators.rst +0 -0
  30. {jaxsim-0.5.1.dev95 → jaxsim-0.5.1.dev103}/docs/modules/math.rst +0 -0
  31. {jaxsim-0.5.1.dev95 → jaxsim-0.5.1.dev103}/docs/modules/mujoco.rst +0 -0
  32. {jaxsim-0.5.1.dev95 → jaxsim-0.5.1.dev103}/docs/modules/parsers.rst +0 -0
  33. {jaxsim-0.5.1.dev95 → jaxsim-0.5.1.dev103}/docs/modules/rbda.rst +0 -0
  34. {jaxsim-0.5.1.dev95 → jaxsim-0.5.1.dev103}/docs/modules/typing.rst +0 -0
  35. {jaxsim-0.5.1.dev95 → jaxsim-0.5.1.dev103}/docs/modules/utils.rst +0 -0
  36. {jaxsim-0.5.1.dev95 → jaxsim-0.5.1.dev103}/environment.yml +0 -0
  37. {jaxsim-0.5.1.dev95 → jaxsim-0.5.1.dev103}/examples/.gitattributes +0 -0
  38. {jaxsim-0.5.1.dev95 → jaxsim-0.5.1.dev103}/examples/.gitignore +0 -0
  39. {jaxsim-0.5.1.dev95 → jaxsim-0.5.1.dev103}/examples/README.md +0 -0
  40. {jaxsim-0.5.1.dev95 → jaxsim-0.5.1.dev103}/examples/assets/build_cartpole_urdf.py +0 -0
  41. {jaxsim-0.5.1.dev95 → jaxsim-0.5.1.dev103}/examples/assets/cartpole.urdf +0 -0
  42. {jaxsim-0.5.1.dev95 → jaxsim-0.5.1.dev103}/examples/jaxsim_as_multibody_dynamics_library.ipynb +0 -0
  43. {jaxsim-0.5.1.dev95 → jaxsim-0.5.1.dev103}/examples/jaxsim_as_physics_engine.ipynb +0 -0
  44. {jaxsim-0.5.1.dev95 → jaxsim-0.5.1.dev103}/examples/jaxsim_for_robot_controllers.ipynb +0 -0
  45. {jaxsim-0.5.1.dev95 → jaxsim-0.5.1.dev103}/pixi.lock +0 -0
  46. {jaxsim-0.5.1.dev95 → jaxsim-0.5.1.dev103}/pyproject.toml +0 -0
  47. {jaxsim-0.5.1.dev95 → jaxsim-0.5.1.dev103}/setup.cfg +0 -0
  48. {jaxsim-0.5.1.dev95 → jaxsim-0.5.1.dev103}/setup.py +0 -0
  49. {jaxsim-0.5.1.dev95 → jaxsim-0.5.1.dev103}/src/jaxsim/api/__init__.py +0 -0
  50. {jaxsim-0.5.1.dev95 → jaxsim-0.5.1.dev103}/src/jaxsim/api/com.py +0 -0
  51. {jaxsim-0.5.1.dev95 → jaxsim-0.5.1.dev103}/src/jaxsim/api/common.py +0 -0
  52. {jaxsim-0.5.1.dev95 → jaxsim-0.5.1.dev103}/src/jaxsim/api/contact.py +0 -0
  53. {jaxsim-0.5.1.dev95 → jaxsim-0.5.1.dev103}/src/jaxsim/api/data.py +0 -0
  54. {jaxsim-0.5.1.dev95 → jaxsim-0.5.1.dev103}/src/jaxsim/api/frame.py +0 -0
  55. {jaxsim-0.5.1.dev95 → jaxsim-0.5.1.dev103}/src/jaxsim/api/joint.py +0 -0
  56. {jaxsim-0.5.1.dev95 → jaxsim-0.5.1.dev103}/src/jaxsim/api/kin_dyn_parameters.py +0 -0
  57. {jaxsim-0.5.1.dev95 → jaxsim-0.5.1.dev103}/src/jaxsim/api/link.py +0 -0
  58. {jaxsim-0.5.1.dev95 → jaxsim-0.5.1.dev103}/src/jaxsim/api/model.py +0 -0
  59. {jaxsim-0.5.1.dev95 → jaxsim-0.5.1.dev103}/src/jaxsim/api/ode.py +0 -0
  60. {jaxsim-0.5.1.dev95 → jaxsim-0.5.1.dev103}/src/jaxsim/api/ode_data.py +0 -0
  61. {jaxsim-0.5.1.dev95 → jaxsim-0.5.1.dev103}/src/jaxsim/api/references.py +0 -0
  62. {jaxsim-0.5.1.dev95 → jaxsim-0.5.1.dev103}/src/jaxsim/integrators/__init__.py +0 -0
  63. {jaxsim-0.5.1.dev95 → jaxsim-0.5.1.dev103}/src/jaxsim/integrators/common.py +0 -0
  64. {jaxsim-0.5.1.dev95 → jaxsim-0.5.1.dev103}/src/jaxsim/integrators/fixed_step.py +0 -0
  65. {jaxsim-0.5.1.dev95 → jaxsim-0.5.1.dev103}/src/jaxsim/integrators/variable_step.py +0 -0
  66. {jaxsim-0.5.1.dev95 → jaxsim-0.5.1.dev103}/src/jaxsim/logging.py +0 -0
  67. {jaxsim-0.5.1.dev95 → jaxsim-0.5.1.dev103}/src/jaxsim/math/__init__.py +0 -0
  68. {jaxsim-0.5.1.dev95 → jaxsim-0.5.1.dev103}/src/jaxsim/math/adjoint.py +0 -0
  69. {jaxsim-0.5.1.dev95 → jaxsim-0.5.1.dev103}/src/jaxsim/math/cross.py +0 -0
  70. {jaxsim-0.5.1.dev95 → jaxsim-0.5.1.dev103}/src/jaxsim/math/inertia.py +0 -0
  71. {jaxsim-0.5.1.dev95 → jaxsim-0.5.1.dev103}/src/jaxsim/math/joint_model.py +0 -0
  72. {jaxsim-0.5.1.dev95 → jaxsim-0.5.1.dev103}/src/jaxsim/math/quaternion.py +0 -0
  73. {jaxsim-0.5.1.dev95 → jaxsim-0.5.1.dev103}/src/jaxsim/math/rotation.py +0 -0
  74. {jaxsim-0.5.1.dev95 → jaxsim-0.5.1.dev103}/src/jaxsim/math/skew.py +0 -0
  75. {jaxsim-0.5.1.dev95 → jaxsim-0.5.1.dev103}/src/jaxsim/math/transform.py +0 -0
  76. {jaxsim-0.5.1.dev95 → jaxsim-0.5.1.dev103}/src/jaxsim/math/utils.py +0 -0
  77. {jaxsim-0.5.1.dev95 → jaxsim-0.5.1.dev103}/src/jaxsim/mujoco/__init__.py +0 -0
  78. {jaxsim-0.5.1.dev95 → jaxsim-0.5.1.dev103}/src/jaxsim/mujoco/__main__.py +0 -0
  79. {jaxsim-0.5.1.dev95 → jaxsim-0.5.1.dev103}/src/jaxsim/mujoco/loaders.py +0 -0
  80. {jaxsim-0.5.1.dev95 → jaxsim-0.5.1.dev103}/src/jaxsim/mujoco/model.py +0 -0
  81. {jaxsim-0.5.1.dev95 → jaxsim-0.5.1.dev103}/src/jaxsim/mujoco/utils.py +0 -0
  82. {jaxsim-0.5.1.dev95 → jaxsim-0.5.1.dev103}/src/jaxsim/mujoco/visualizer.py +0 -0
  83. {jaxsim-0.5.1.dev95 → jaxsim-0.5.1.dev103}/src/jaxsim/parsers/__init__.py +0 -0
  84. {jaxsim-0.5.1.dev95 → jaxsim-0.5.1.dev103}/src/jaxsim/parsers/descriptions/__init__.py +0 -0
  85. {jaxsim-0.5.1.dev95 → jaxsim-0.5.1.dev103}/src/jaxsim/parsers/descriptions/collision.py +0 -0
  86. {jaxsim-0.5.1.dev95 → jaxsim-0.5.1.dev103}/src/jaxsim/parsers/descriptions/joint.py +0 -0
  87. {jaxsim-0.5.1.dev95 → jaxsim-0.5.1.dev103}/src/jaxsim/parsers/descriptions/link.py +0 -0
  88. {jaxsim-0.5.1.dev95 → jaxsim-0.5.1.dev103}/src/jaxsim/parsers/descriptions/model.py +0 -0
  89. {jaxsim-0.5.1.dev95 → jaxsim-0.5.1.dev103}/src/jaxsim/parsers/kinematic_graph.py +0 -0
  90. {jaxsim-0.5.1.dev95 → jaxsim-0.5.1.dev103}/src/jaxsim/parsers/rod/__init__.py +0 -0
  91. {jaxsim-0.5.1.dev95 → jaxsim-0.5.1.dev103}/src/jaxsim/parsers/rod/meshes.py +0 -0
  92. {jaxsim-0.5.1.dev95 → jaxsim-0.5.1.dev103}/src/jaxsim/parsers/rod/parser.py +0 -0
  93. {jaxsim-0.5.1.dev95 → jaxsim-0.5.1.dev103}/src/jaxsim/rbda/__init__.py +0 -0
  94. {jaxsim-0.5.1.dev95 → jaxsim-0.5.1.dev103}/src/jaxsim/rbda/aba.py +0 -0
  95. {jaxsim-0.5.1.dev95 → jaxsim-0.5.1.dev103}/src/jaxsim/rbda/collidable_points.py +0 -0
  96. {jaxsim-0.5.1.dev95 → jaxsim-0.5.1.dev103}/src/jaxsim/rbda/contacts/__init__.py +0 -0
  97. {jaxsim-0.5.1.dev95 → jaxsim-0.5.1.dev103}/src/jaxsim/rbda/contacts/common.py +0 -0
  98. {jaxsim-0.5.1.dev95 → jaxsim-0.5.1.dev103}/src/jaxsim/rbda/contacts/relaxed_rigid.py +0 -0
  99. {jaxsim-0.5.1.dev95 → jaxsim-0.5.1.dev103}/src/jaxsim/rbda/contacts/rigid.py +0 -0
  100. {jaxsim-0.5.1.dev95 → jaxsim-0.5.1.dev103}/src/jaxsim/rbda/contacts/soft.py +0 -0
  101. {jaxsim-0.5.1.dev95 → jaxsim-0.5.1.dev103}/src/jaxsim/rbda/contacts/visco_elastic.py +0 -0
  102. {jaxsim-0.5.1.dev95 → jaxsim-0.5.1.dev103}/src/jaxsim/rbda/crba.py +0 -0
  103. {jaxsim-0.5.1.dev95 → jaxsim-0.5.1.dev103}/src/jaxsim/rbda/forward_kinematics.py +0 -0
  104. {jaxsim-0.5.1.dev95 → jaxsim-0.5.1.dev103}/src/jaxsim/rbda/jacobian.py +0 -0
  105. {jaxsim-0.5.1.dev95 → jaxsim-0.5.1.dev103}/src/jaxsim/rbda/rnea.py +0 -0
  106. {jaxsim-0.5.1.dev95 → jaxsim-0.5.1.dev103}/src/jaxsim/rbda/utils.py +0 -0
  107. {jaxsim-0.5.1.dev95 → jaxsim-0.5.1.dev103}/src/jaxsim/terrain/__init__.py +0 -0
  108. {jaxsim-0.5.1.dev95 → jaxsim-0.5.1.dev103}/src/jaxsim/terrain/terrain.py +0 -0
  109. {jaxsim-0.5.1.dev95 → jaxsim-0.5.1.dev103}/src/jaxsim/typing.py +1 -1
  110. {jaxsim-0.5.1.dev95 → jaxsim-0.5.1.dev103}/src/jaxsim/utils/__init__.py +0 -0
  111. {jaxsim-0.5.1.dev95 → jaxsim-0.5.1.dev103}/src/jaxsim/utils/jaxsim_dataclass.py +0 -0
  112. {jaxsim-0.5.1.dev95 → jaxsim-0.5.1.dev103}/src/jaxsim/utils/tracing.py +0 -0
  113. {jaxsim-0.5.1.dev95 → jaxsim-0.5.1.dev103}/src/jaxsim/utils/wrappers.py +0 -0
  114. {jaxsim-0.5.1.dev95 → jaxsim-0.5.1.dev103}/src/jaxsim.egg-info/SOURCES.txt +0 -0
  115. {jaxsim-0.5.1.dev95 → jaxsim-0.5.1.dev103}/src/jaxsim.egg-info/dependency_links.txt +0 -0
  116. {jaxsim-0.5.1.dev95 → jaxsim-0.5.1.dev103}/src/jaxsim.egg-info/requires.txt +0 -0
  117. {jaxsim-0.5.1.dev95 → jaxsim-0.5.1.dev103}/src/jaxsim.egg-info/top_level.txt +0 -0
  118. {jaxsim-0.5.1.dev95 → jaxsim-0.5.1.dev103}/tests/__init__.py +0 -0
  119. {jaxsim-0.5.1.dev95 → jaxsim-0.5.1.dev103}/tests/conftest.py +0 -0
  120. {jaxsim-0.5.1.dev95 → jaxsim-0.5.1.dev103}/tests/test_api_com.py +0 -0
  121. {jaxsim-0.5.1.dev95 → jaxsim-0.5.1.dev103}/tests/test_api_contact.py +0 -0
  122. {jaxsim-0.5.1.dev95 → jaxsim-0.5.1.dev103}/tests/test_api_data.py +0 -0
  123. {jaxsim-0.5.1.dev95 → jaxsim-0.5.1.dev103}/tests/test_api_frame.py +0 -0
  124. {jaxsim-0.5.1.dev95 → jaxsim-0.5.1.dev103}/tests/test_api_joint.py +0 -0
  125. {jaxsim-0.5.1.dev95 → jaxsim-0.5.1.dev103}/tests/test_api_link.py +0 -0
  126. {jaxsim-0.5.1.dev95 → jaxsim-0.5.1.dev103}/tests/test_api_model.py +0 -0
  127. {jaxsim-0.5.1.dev95 → jaxsim-0.5.1.dev103}/tests/test_automatic_differentiation.py +0 -0
  128. {jaxsim-0.5.1.dev95 → jaxsim-0.5.1.dev103}/tests/test_benchmark.py +0 -0
  129. {jaxsim-0.5.1.dev95 → jaxsim-0.5.1.dev103}/tests/test_contact.py +0 -0
  130. {jaxsim-0.5.1.dev95 → jaxsim-0.5.1.dev103}/tests/test_exceptions.py +0 -0
  131. {jaxsim-0.5.1.dev95 → jaxsim-0.5.1.dev103}/tests/test_meshes.py +0 -0
  132. {jaxsim-0.5.1.dev95 → jaxsim-0.5.1.dev103}/tests/test_pytree.py +0 -0
  133. {jaxsim-0.5.1.dev95 → jaxsim-0.5.1.dev103}/tests/test_simulations.py +0 -0
  134. {jaxsim-0.5.1.dev95 → jaxsim-0.5.1.dev103}/tests/utils_idyntree.py +0 -0
@@ -20,7 +20,7 @@ repos:
20
20
  args: ["--maxkb=2000"]
21
21
 
22
22
  - repo: https://github.com/psf/black-pre-commit-mirror
23
- rev: 24.8.0
23
+ rev: 24.10.0
24
24
  hooks:
25
25
  - id: black
26
26
  args: ["--check", "--diff"]
@@ -44,11 +44,11 @@ repos:
44
44
  - id: codespell
45
45
 
46
46
  - repo: https://github.com/astral-sh/ruff-pre-commit
47
- rev: v0.6.9
47
+ rev: v0.8.6
48
48
  hooks:
49
49
  - id: ruff
50
50
 
51
51
  - repo: https://github.com/kynan/nbstripout
52
- rev: 0.7.1
52
+ rev: 0.8.1
53
53
  hooks:
54
54
  - id: nbstripout
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: jaxsim
3
- Version: 0.5.1.dev95
3
+ Version: 0.5.1.dev103
4
4
  Summary: A differentiable physics engine and multibody dynamics library for control and robot learning.
5
5
  Author-email: Diego Ferigo <dgferigo@gmail.com>, Filippo Luca Ferretti <filippoluca.ferretti@outlook.com>
6
6
  Maintainer-email: Filippo Luca Ferretti <filippo.ferretti@iit.it>, Alessandro Croci <alessandro.croci@iit.it>
@@ -8,19 +8,27 @@ def _jnp_options() -> None:
8
8
 
9
9
  import jax
10
10
 
11
- # Check if running on TPU
11
+ # Check if running on TPU.
12
12
  is_tpu = jax.devices()[0].platform == "tpu"
13
13
 
14
+ # Check if running on Metal.
15
+ is_metal = jax.devices()[0].platform == "METAL"
16
+
14
17
  # Enable by default 64-bit precision to get accurate physics.
15
18
  # Users can enforce 32-bit precision by setting the following variable to 0.
16
19
  use_x64 = os.environ.get("JAX_ENABLE_X64", "1") != "0"
17
20
 
18
21
  # Notify the user if unsupported 64-bit precision was enforced on TPU.
19
- if is_tpu and use_x64:
20
- msg = "64-bit precision is not allowed on TPU. Enforcing 32bit precision."
22
+ if (is_tpu or is_metal) and use_x64:
23
+ msg = f"64-bit precision is not allowed on {jax.devices()[0].platform.upper}. Enforcing 32bit precision."
21
24
  logging.warning(msg)
22
25
  use_x64 = False
23
26
 
27
+ if is_metal:
28
+ logging.warning(
29
+ "JAX Metal backend is experimental. Some functionalities may not be available."
30
+ )
31
+
24
32
  # Enable 64-bit precision in JAX.
25
33
  if use_x64:
26
34
  logging.info("Enabling JAX to use 64-bit precision")
@@ -12,5 +12,5 @@ __version__: str
12
12
  __version_tuple__: VERSION_TUPLE
13
13
  version_tuple: VERSION_TUPLE
14
14
 
15
- __version__ = version = '0.5.1.dev95'
16
- __version_tuple__ = version_tuple = (0, 5, 1, 'dev95')
15
+ __version__ = version = '0.5.1.dev103'
16
+ __version_tuple__ = version_tuple = (0, 5, 1, 'dev103')
@@ -19,8 +19,9 @@ def raise_if(
19
19
  format string (fmt), whose fields are filled with the args and kwargs.
20
20
  """
21
21
 
22
- # Disable host callback if running on TPU.
23
- if jax.devices()[0].platform == "tpu" or os.environ.get(
22
+ # Disable host callback if running on unsupported hardware or if the user
23
+ # explicitly disabled it.
24
+ if jax.devices()[0].platform in {"tpu", "METAL"} or os.environ.get(
24
25
  "JAXSIM_DISABLE_EXCEPTIONS", 0
25
26
  ):
26
27
  return
@@ -225,15 +225,15 @@ def create_mesh_collision(
225
225
  ) -> descriptions.MeshCollision:
226
226
 
227
227
  file = pathlib.Path(resolve_local_uri(uri=collision.geometry.mesh.uri))
228
- _file_type = file.suffix.replace(".", "")
229
- mesh = trimesh.load_mesh(file, file_type=_file_type)
228
+ file_type = file.suffix.replace(".", "")
229
+ mesh = trimesh.load_mesh(file, file_type=file_type)
230
230
 
231
231
  if mesh.is_empty:
232
232
  raise RuntimeError(f"Failed to process '{file}' with trimesh")
233
233
 
234
234
  mesh.apply_scale(collision.geometry.mesh.scale)
235
235
  logging.info(
236
- msg=f"Loading mesh {collision.geometry.mesh.uri} with scale {collision.geometry.mesh.scale}, file type '{_file_type}'"
236
+ msg=f"Loading mesh {collision.geometry.mesh.uri} with scale {collision.geometry.mesh.scale}, file type '{file_type}'"
237
237
  )
238
238
 
239
239
  if method is None:
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: jaxsim
3
- Version: 0.5.1.dev95
3
+ Version: 0.5.1.dev103
4
4
  Summary: A differentiable physics engine and multibody dynamics library for control and robot learning.
5
5
  Author-email: Diego Ferigo <dgferigo@gmail.com>, Filippo Luca Ferretti <filippoluca.ferretti@outlook.com>
6
6
  Maintainer-email: Filippo Luca Ferretti <filippo.ferretti@iit.it>, Alessandro Croci <alessandro.croci@iit.it>
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
@@ -20,9 +20,9 @@ PyTree: object = (
20
20
  dict[Hashable, TypeVar("PyTree")]
21
21
  | list[TypeVar("PyTree")]
22
22
  | tuple[TypeVar("PyTree")]
23
- | None
24
23
  | jax.Array
25
24
  | Any
25
+ | None
26
26
  )
27
27
 
28
28
  # =======================