jaxsim 0.4.3.dev77__tar.gz → 0.4.3.dev88__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (124):
  1. {jaxsim-0.4.3.dev77 → jaxsim-0.4.3.dev88}/PKG-INFO +3 -3
  2. {jaxsim-0.4.3.dev77 → jaxsim-0.4.3.dev88}/environment.yml +2 -2
  3. {jaxsim-0.4.3.dev77 → jaxsim-0.4.3.dev88}/pyproject.toml +2 -2
  4. {jaxsim-0.4.3.dev77 → jaxsim-0.4.3.dev88}/src/jaxsim/_version.py +2 -2
  5. {jaxsim-0.4.3.dev77 → jaxsim-0.4.3.dev88}/src/jaxsim/api/contact.py +5 -7
  6. {jaxsim-0.4.3.dev77 → jaxsim-0.4.3.dev88}/src/jaxsim/api/kin_dyn_parameters.py +3 -7
  7. {jaxsim-0.4.3.dev77 → jaxsim-0.4.3.dev88}/src/jaxsim/api/link.py +1 -1
  8. {jaxsim-0.4.3.dev77 → jaxsim-0.4.3.dev88}/src/jaxsim/api/model.py +2 -8
  9. {jaxsim-0.4.3.dev77 → jaxsim-0.4.3.dev88}/src/jaxsim/integrators/common.py +8 -12
  10. {jaxsim-0.4.3.dev77 → jaxsim-0.4.3.dev88}/src/jaxsim/integrators/variable_step.py +13 -15
  11. {jaxsim-0.4.3.dev77 → jaxsim-0.4.3.dev88}/src/jaxsim/rbda/contacts/relaxed_rigid.py +1 -1
  12. {jaxsim-0.4.3.dev77 → jaxsim-0.4.3.dev88}/src/jaxsim/utils/jaxsim_dataclass.py +1 -1
  13. {jaxsim-0.4.3.dev77 → jaxsim-0.4.3.dev88}/src/jaxsim.egg-info/PKG-INFO +3 -3
  14. {jaxsim-0.4.3.dev77 → jaxsim-0.4.3.dev88}/src/jaxsim.egg-info/requires.txt +2 -2
  15. {jaxsim-0.4.3.dev77 → jaxsim-0.4.3.dev88}/tests/test_api_link.py +10 -30
  16. {jaxsim-0.4.3.dev77 → jaxsim-0.4.3.dev88}/tests/test_automatic_differentiation.py +1 -1
  17. {jaxsim-0.4.3.dev77 → jaxsim-0.4.3.dev88}/.devcontainer/Dockerfile +0 -0
  18. {jaxsim-0.4.3.dev77 → jaxsim-0.4.3.dev88}/.devcontainer/devcontainer.json +0 -0
  19. {jaxsim-0.4.3.dev77 → jaxsim-0.4.3.dev88}/.gitattributes +0 -0
  20. {jaxsim-0.4.3.dev77 → jaxsim-0.4.3.dev88}/.github/CODEOWNERS +0 -0
  21. {jaxsim-0.4.3.dev77 → jaxsim-0.4.3.dev88}/.github/workflows/ci_cd.yml +0 -0
  22. {jaxsim-0.4.3.dev77 → jaxsim-0.4.3.dev88}/.github/workflows/read_the_docs.yml +0 -0
  23. {jaxsim-0.4.3.dev77 → jaxsim-0.4.3.dev88}/.github/workflows/update_pixi_lockfile.yml +0 -0
  24. {jaxsim-0.4.3.dev77 → jaxsim-0.4.3.dev88}/.gitignore +0 -0
  25. {jaxsim-0.4.3.dev77 → jaxsim-0.4.3.dev88}/.pre-commit-config.yaml +0 -0
  26. {jaxsim-0.4.3.dev77 → jaxsim-0.4.3.dev88}/.readthedocs.yaml +0 -0
  27. {jaxsim-0.4.3.dev77 → jaxsim-0.4.3.dev88}/CONTRIBUTING.md +0 -0
  28. {jaxsim-0.4.3.dev77 → jaxsim-0.4.3.dev88}/LICENSE +0 -0
  29. {jaxsim-0.4.3.dev77 → jaxsim-0.4.3.dev88}/README.md +0 -0
  30. {jaxsim-0.4.3.dev77 → jaxsim-0.4.3.dev88}/docs/Makefile +0 -0
  31. {jaxsim-0.4.3.dev77 → jaxsim-0.4.3.dev88}/docs/conf.py +0 -0
  32. {jaxsim-0.4.3.dev77 → jaxsim-0.4.3.dev88}/docs/examples.rst +0 -0
  33. {jaxsim-0.4.3.dev77 → jaxsim-0.4.3.dev88}/docs/guide/install.rst +0 -0
  34. {jaxsim-0.4.3.dev77 → jaxsim-0.4.3.dev88}/docs/index.rst +0 -0
  35. {jaxsim-0.4.3.dev77 → jaxsim-0.4.3.dev88}/docs/make.bat +0 -0
  36. {jaxsim-0.4.3.dev77 → jaxsim-0.4.3.dev88}/docs/modules/api.rst +0 -0
  37. {jaxsim-0.4.3.dev77 → jaxsim-0.4.3.dev88}/docs/modules/integrators.rst +0 -0
  38. {jaxsim-0.4.3.dev77 → jaxsim-0.4.3.dev88}/docs/modules/math.rst +0 -0
  39. {jaxsim-0.4.3.dev77 → jaxsim-0.4.3.dev88}/docs/modules/mujoco.rst +0 -0
  40. {jaxsim-0.4.3.dev77 → jaxsim-0.4.3.dev88}/docs/modules/parsers.rst +0 -0
  41. {jaxsim-0.4.3.dev77 → jaxsim-0.4.3.dev88}/docs/modules/rbda.rst +0 -0
  42. {jaxsim-0.4.3.dev77 → jaxsim-0.4.3.dev88}/docs/modules/typing.rst +0 -0
  43. {jaxsim-0.4.3.dev77 → jaxsim-0.4.3.dev88}/docs/modules/utils.rst +0 -0
  44. {jaxsim-0.4.3.dev77 → jaxsim-0.4.3.dev88}/examples/.gitattributes +0 -0
  45. {jaxsim-0.4.3.dev77 → jaxsim-0.4.3.dev88}/examples/.gitignore +0 -0
  46. {jaxsim-0.4.3.dev77 → jaxsim-0.4.3.dev88}/examples/PD_controller.ipynb +0 -0
  47. {jaxsim-0.4.3.dev77 → jaxsim-0.4.3.dev88}/examples/Parallel_computing.ipynb +0 -0
  48. {jaxsim-0.4.3.dev77 → jaxsim-0.4.3.dev88}/examples/README.md +0 -0
  49. {jaxsim-0.4.3.dev77 → jaxsim-0.4.3.dev88}/examples/assets/cartpole.urdf +0 -0
  50. {jaxsim-0.4.3.dev77 → jaxsim-0.4.3.dev88}/pixi.lock +0 -0
  51. {jaxsim-0.4.3.dev77 → jaxsim-0.4.3.dev88}/setup.cfg +0 -0
  52. {jaxsim-0.4.3.dev77 → jaxsim-0.4.3.dev88}/setup.py +0 -0
  53. {jaxsim-0.4.3.dev77 → jaxsim-0.4.3.dev88}/src/jaxsim/__init__.py +0 -0
  54. {jaxsim-0.4.3.dev77 → jaxsim-0.4.3.dev88}/src/jaxsim/api/__init__.py +0 -0
  55. {jaxsim-0.4.3.dev77 → jaxsim-0.4.3.dev88}/src/jaxsim/api/com.py +0 -0
  56. {jaxsim-0.4.3.dev77 → jaxsim-0.4.3.dev88}/src/jaxsim/api/common.py +0 -0
  57. {jaxsim-0.4.3.dev77 → jaxsim-0.4.3.dev88}/src/jaxsim/api/data.py +0 -0
  58. {jaxsim-0.4.3.dev77 → jaxsim-0.4.3.dev88}/src/jaxsim/api/frame.py +0 -0
  59. {jaxsim-0.4.3.dev77 → jaxsim-0.4.3.dev88}/src/jaxsim/api/joint.py +0 -0
  60. {jaxsim-0.4.3.dev77 → jaxsim-0.4.3.dev88}/src/jaxsim/api/ode.py +0 -0
  61. {jaxsim-0.4.3.dev77 → jaxsim-0.4.3.dev88}/src/jaxsim/api/ode_data.py +0 -0
  62. {jaxsim-0.4.3.dev77 → jaxsim-0.4.3.dev88}/src/jaxsim/api/references.py +0 -0
  63. {jaxsim-0.4.3.dev77 → jaxsim-0.4.3.dev88}/src/jaxsim/exceptions.py +0 -0
  64. {jaxsim-0.4.3.dev77 → jaxsim-0.4.3.dev88}/src/jaxsim/integrators/__init__.py +0 -0
  65. {jaxsim-0.4.3.dev77 → jaxsim-0.4.3.dev88}/src/jaxsim/integrators/fixed_step.py +0 -0
  66. {jaxsim-0.4.3.dev77 → jaxsim-0.4.3.dev88}/src/jaxsim/logging.py +0 -0
  67. {jaxsim-0.4.3.dev77 → jaxsim-0.4.3.dev88}/src/jaxsim/math/__init__.py +0 -0
  68. {jaxsim-0.4.3.dev77 → jaxsim-0.4.3.dev88}/src/jaxsim/math/adjoint.py +0 -0
  69. {jaxsim-0.4.3.dev77 → jaxsim-0.4.3.dev88}/src/jaxsim/math/cross.py +0 -0
  70. {jaxsim-0.4.3.dev77 → jaxsim-0.4.3.dev88}/src/jaxsim/math/inertia.py +0 -0
  71. {jaxsim-0.4.3.dev77 → jaxsim-0.4.3.dev88}/src/jaxsim/math/joint_model.py +0 -0
  72. {jaxsim-0.4.3.dev77 → jaxsim-0.4.3.dev88}/src/jaxsim/math/quaternion.py +0 -0
  73. {jaxsim-0.4.3.dev77 → jaxsim-0.4.3.dev88}/src/jaxsim/math/rotation.py +0 -0
  74. {jaxsim-0.4.3.dev77 → jaxsim-0.4.3.dev88}/src/jaxsim/math/skew.py +0 -0
  75. {jaxsim-0.4.3.dev77 → jaxsim-0.4.3.dev88}/src/jaxsim/math/transform.py +0 -0
  76. {jaxsim-0.4.3.dev77 → jaxsim-0.4.3.dev88}/src/jaxsim/mujoco/__init__.py +0 -0
  77. {jaxsim-0.4.3.dev77 → jaxsim-0.4.3.dev88}/src/jaxsim/mujoco/__main__.py +0 -0
  78. {jaxsim-0.4.3.dev77 → jaxsim-0.4.3.dev88}/src/jaxsim/mujoco/loaders.py +0 -0
  79. {jaxsim-0.4.3.dev77 → jaxsim-0.4.3.dev88}/src/jaxsim/mujoco/model.py +0 -0
  80. {jaxsim-0.4.3.dev77 → jaxsim-0.4.3.dev88}/src/jaxsim/mujoco/visualizer.py +0 -0
  81. {jaxsim-0.4.3.dev77 → jaxsim-0.4.3.dev88}/src/jaxsim/parsers/__init__.py +0 -0
  82. {jaxsim-0.4.3.dev77 → jaxsim-0.4.3.dev88}/src/jaxsim/parsers/descriptions/__init__.py +0 -0
  83. {jaxsim-0.4.3.dev77 → jaxsim-0.4.3.dev88}/src/jaxsim/parsers/descriptions/collision.py +0 -0
  84. {jaxsim-0.4.3.dev77 → jaxsim-0.4.3.dev88}/src/jaxsim/parsers/descriptions/joint.py +0 -0
  85. {jaxsim-0.4.3.dev77 → jaxsim-0.4.3.dev88}/src/jaxsim/parsers/descriptions/link.py +0 -0
  86. {jaxsim-0.4.3.dev77 → jaxsim-0.4.3.dev88}/src/jaxsim/parsers/descriptions/model.py +0 -0
  87. {jaxsim-0.4.3.dev77 → jaxsim-0.4.3.dev88}/src/jaxsim/parsers/kinematic_graph.py +0 -0
  88. {jaxsim-0.4.3.dev77 → jaxsim-0.4.3.dev88}/src/jaxsim/parsers/rod/__init__.py +0 -0
  89. {jaxsim-0.4.3.dev77 → jaxsim-0.4.3.dev88}/src/jaxsim/parsers/rod/parser.py +0 -0
  90. {jaxsim-0.4.3.dev77 → jaxsim-0.4.3.dev88}/src/jaxsim/parsers/rod/utils.py +0 -0
  91. {jaxsim-0.4.3.dev77 → jaxsim-0.4.3.dev88}/src/jaxsim/rbda/__init__.py +0 -0
  92. {jaxsim-0.4.3.dev77 → jaxsim-0.4.3.dev88}/src/jaxsim/rbda/aba.py +0 -0
  93. {jaxsim-0.4.3.dev77 → jaxsim-0.4.3.dev88}/src/jaxsim/rbda/collidable_points.py +0 -0
  94. {jaxsim-0.4.3.dev77 → jaxsim-0.4.3.dev88}/src/jaxsim/rbda/contacts/__init__.py +0 -0
  95. {jaxsim-0.4.3.dev77 → jaxsim-0.4.3.dev88}/src/jaxsim/rbda/contacts/common.py +0 -0
  96. {jaxsim-0.4.3.dev77 → jaxsim-0.4.3.dev88}/src/jaxsim/rbda/contacts/rigid.py +0 -0
  97. {jaxsim-0.4.3.dev77 → jaxsim-0.4.3.dev88}/src/jaxsim/rbda/contacts/soft.py +0 -0
  98. {jaxsim-0.4.3.dev77 → jaxsim-0.4.3.dev88}/src/jaxsim/rbda/crba.py +0 -0
  99. {jaxsim-0.4.3.dev77 → jaxsim-0.4.3.dev88}/src/jaxsim/rbda/forward_kinematics.py +0 -0
  100. {jaxsim-0.4.3.dev77 → jaxsim-0.4.3.dev88}/src/jaxsim/rbda/jacobian.py +0 -0
  101. {jaxsim-0.4.3.dev77 → jaxsim-0.4.3.dev88}/src/jaxsim/rbda/rnea.py +0 -0
  102. {jaxsim-0.4.3.dev77 → jaxsim-0.4.3.dev88}/src/jaxsim/rbda/utils.py +0 -0
  103. {jaxsim-0.4.3.dev77 → jaxsim-0.4.3.dev88}/src/jaxsim/terrain/__init__.py +0 -0
  104. {jaxsim-0.4.3.dev77 → jaxsim-0.4.3.dev88}/src/jaxsim/terrain/terrain.py +0 -0
  105. {jaxsim-0.4.3.dev77 → jaxsim-0.4.3.dev88}/src/jaxsim/typing.py +0 -0
  106. {jaxsim-0.4.3.dev77 → jaxsim-0.4.3.dev88}/src/jaxsim/utils/__init__.py +0 -0
  107. {jaxsim-0.4.3.dev77 → jaxsim-0.4.3.dev88}/src/jaxsim/utils/tracing.py +0 -0
  108. {jaxsim-0.4.3.dev77 → jaxsim-0.4.3.dev88}/src/jaxsim/utils/wrappers.py +0 -0
  109. {jaxsim-0.4.3.dev77 → jaxsim-0.4.3.dev88}/src/jaxsim.egg-info/SOURCES.txt +0 -0
  110. {jaxsim-0.4.3.dev77 → jaxsim-0.4.3.dev88}/src/jaxsim.egg-info/dependency_links.txt +0 -0
  111. {jaxsim-0.4.3.dev77 → jaxsim-0.4.3.dev88}/src/jaxsim.egg-info/top_level.txt +0 -0
  112. {jaxsim-0.4.3.dev77 → jaxsim-0.4.3.dev88}/tests/__init__.py +0 -0
  113. {jaxsim-0.4.3.dev77 → jaxsim-0.4.3.dev88}/tests/conftest.py +0 -0
  114. {jaxsim-0.4.3.dev77 → jaxsim-0.4.3.dev88}/tests/test_api_com.py +0 -0
  115. {jaxsim-0.4.3.dev77 → jaxsim-0.4.3.dev88}/tests/test_api_contact.py +0 -0
  116. {jaxsim-0.4.3.dev77 → jaxsim-0.4.3.dev88}/tests/test_api_data.py +0 -0
  117. {jaxsim-0.4.3.dev77 → jaxsim-0.4.3.dev88}/tests/test_api_frame.py +0 -0
  118. {jaxsim-0.4.3.dev77 → jaxsim-0.4.3.dev88}/tests/test_api_joint.py +0 -0
  119. {jaxsim-0.4.3.dev77 → jaxsim-0.4.3.dev88}/tests/test_api_model.py +0 -0
  120. {jaxsim-0.4.3.dev77 → jaxsim-0.4.3.dev88}/tests/test_contact.py +0 -0
  121. {jaxsim-0.4.3.dev77 → jaxsim-0.4.3.dev88}/tests/test_exceptions.py +0 -0
  122. {jaxsim-0.4.3.dev77 → jaxsim-0.4.3.dev88}/tests/test_pytree.py +0 -0
  123. {jaxsim-0.4.3.dev77 → jaxsim-0.4.3.dev88}/tests/test_simulations.py +0 -0
  124. {jaxsim-0.4.3.dev77 → jaxsim-0.4.3.dev88}/tests/utils_idyntree.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: jaxsim
3
- Version: 0.4.3.dev77
3
+ Version: 0.4.3.dev88
4
4
  Summary: A differentiable physics engine and multibody dynamics library for control and robot learning.
5
5
  Author-email: Diego Ferigo <dgferigo@gmail.com>
6
6
  Maintainer-email: Diego Ferigo <dgferigo@gmail.com>, Filippo Luca Ferretti <filippo.ferretti@iit.it>
@@ -60,9 +60,9 @@ Requires-Python: >=3.10
60
60
  Description-Content-Type: text/markdown
61
61
  License-File: LICENSE
62
62
  Requires-Dist: coloredlogs
63
- Requires-Dist: jax>=0.4.13
63
+ Requires-Dist: jax>=0.4.26
64
64
  Requires-Dist: jaxopt>=0.8.0
65
- Requires-Dist: jaxlib>=0.4.13
65
+ Requires-Dist: jaxlib>=0.4.26
66
66
  Requires-Dist: jaxlie>=1.3.0
67
67
  Requires-Dist: jax_dataclasses>=1.4.0
68
68
  Requires-Dist: pptree
@@ -7,9 +7,9 @@ dependencies:
7
7
  # ===========================
8
8
  - python >= 3.12.0
9
9
  - coloredlogs
10
- - jax >= 0.4.13
10
+ - jax >= 0.4.26
11
11
  - jaxopt >= 0.8.0
12
- - jaxlib >= 0.4.13
12
+ - jaxlib >= 0.4.26
13
13
  - jaxlie >= 1.3.0
14
14
  - jax-dataclasses >= 1.4.0
15
15
  - pptree
@@ -44,9 +44,9 @@ classifiers = [
44
44
  ]
45
45
  dependencies = [
46
46
  "coloredlogs",
47
- "jax >= 0.4.13",
47
+ "jax >= 0.4.26",
48
48
  "jaxopt >= 0.8.0",
49
- "jaxlib >= 0.4.13",
49
+ "jaxlib >= 0.4.26",
50
50
  "jaxlie >= 1.3.0",
51
51
  "jax_dataclasses >= 1.4.0",
52
52
  "pptree",
@@ -12,5 +12,5 @@ __version__: str
12
12
  __version_tuple__: VERSION_TUPLE
13
13
  version_tuple: VERSION_TUPLE
14
14
 
15
- __version__ = version = '0.4.3.dev77'
16
- __version_tuple__ = version_tuple = (0, 4, 3, 'dev77')
15
+ __version__ = version = '0.4.3.dev88'
16
+ __version_tuple__ = version_tuple = (0, 4, 3, 'dev88')
@@ -372,11 +372,9 @@ def transforms(model: js.model.JaxSimModel, data: js.data.JaxSimModelData) -> jt
372
372
  """
373
373
 
374
374
  # Get the transforms of the parent link of all collidable points.
375
- W_H_L = jax.vmap(
376
- lambda parent_link_idx: js.link.transform(
377
- model=model, data=data, link_index=parent_link_idx
378
- )
379
- )(jnp.array(model.kin_dyn_parameters.contact_parameters.body, dtype=int))
375
+ W_H_L = js.model.forward_kinematics(model=model, data=data)[
376
+ jnp.array(model.kin_dyn_parameters.contact_parameters.body, dtype=int)
377
+ ]
380
378
 
381
379
  # Build the link-to-point transform from the displacement between the link frame L
382
380
  # and the implicit contact frame C.
@@ -427,9 +425,9 @@ def jacobian(
427
425
  # Compute the contact Jacobian.
428
426
  # In inertial-fixed output representation, the Jacobian of the parent link is also
429
427
  # the Jacobian of the frame C implicitly associated with the collidable point.
430
- W_J_WC = jax.vmap(lambda parent_link_idx: W_J_WL[parent_link_idx])(
428
+ W_J_WC = W_J_WL[
431
429
  jnp.array(model.kin_dyn_parameters.contact_parameters.body, dtype=int)
432
- )
430
+ ]
433
431
 
434
432
  # Adjust the output representation.
435
433
  match output_vel_repr:
@@ -98,9 +98,7 @@ class KynDynParameters(JaxsimDataclass):
98
98
  ]
99
99
 
100
100
  # Create a vectorized object of link parameters.
101
- link_parameters = jax.tree_util.tree_map(
102
- lambda *l: jnp.stack(l), *link_parameters_list
103
- )
101
+ link_parameters = jax.tree.map(lambda *l: jnp.stack(l), *link_parameters_list)
104
102
 
105
103
  # =================
106
104
  # Joints properties
@@ -114,7 +112,7 @@ class KynDynParameters(JaxsimDataclass):
114
112
 
115
113
  # Create a vectorized object of joint parameters.
116
114
  joint_parameters = (
117
- jax.tree_util.tree_map(lambda *l: jnp.stack(l), *joint_parameters_list)
115
+ jax.tree.map(lambda *l: jnp.stack(l), *joint_parameters_list)
118
116
  if len(ordered_joints) > 0
119
117
  else JointParameters(
120
118
  index=jnp.array([], dtype=int),
@@ -424,9 +422,7 @@ class KynDynParameters(JaxsimDataclass):
424
422
  # Note that here we include also the index 0 since suc_H_child[0] stores the
425
423
  # optional pose of the base link w.r.t. the root frame of the model.
426
424
  # This is supported by SDF when the base link <pose> element is defined.
427
- suc_H_i = jax.vmap(lambda i: self.joint_model.successor_H_child(joint_index=i))(
428
- jnp.arange(0, 1 + self.number_of_joints())
429
- )
425
+ suc_H_i = self.joint_model.suc_H_i[jnp.arange(0, 1 + self.number_of_joints())]
430
426
 
431
427
  # Compute the overall transforms from the parent to the child of each joint by
432
428
  # composing all the components of our joint model.
@@ -154,7 +154,7 @@ def spatial_inertia(
154
154
  idx=link_index,
155
155
  )
156
156
 
157
- link_parameters = jax.tree_util.tree_map(
157
+ link_parameters = jax.tree.map(
158
158
  lambda l: l[link_index], model.kin_dyn_parameters.link_parameters
159
159
  )
160
160
 
@@ -395,13 +395,7 @@ def total_mass(model: JaxSimModel) -> jtp.Float:
395
395
  The total mass of the model.
396
396
  """
397
397
 
398
- return (
399
- jax.vmap(lambda idx: js.link.mass(model=model, link_index=idx))(
400
- jnp.arange(model.number_of_links())
401
- )
402
- .sum()
403
- .astype(float)
404
- )
398
+ return model.kin_dyn_parameters.link_parameters.mass.sum().astype(float)
405
399
 
406
400
 
407
401
  @jax.jit
@@ -974,7 +968,7 @@ def free_floating_coriolis_matrix(
974
968
  lambda link_index: js.link.jacobian_derivative(
975
969
  model=model, data=data, link_index=link_index
976
970
  )
977
- )(js.link.names_to_idxs(model=model, link_names=model.link_names()))
971
+ )(jnp.arange(model.number_of_links()))
978
972
 
979
973
  L_M_L = link_spatial_inertia_matrices(model=model)
980
974
 
@@ -173,9 +173,7 @@ class Integrator(JaxsimDataclass, abc.ABC, Generic[State, StateDerivative]):
173
173
 
174
174
  # Make sure that all leafs of the dictionary are JAX arrays.
175
175
  # Also, since these are dummy parameters, set them all to zero.
176
- params_after_init = jax.tree_util.tree_map(
177
- lambda l: jnp.zeros_like(l), integrator.params
178
- )
176
+ params_after_init = jax.tree.map(lambda l: jnp.zeros_like(l), integrator.params)
179
177
 
180
178
  # Mark the next step as first step after initialization.
181
179
  params_after_init = params_after_init | {
@@ -290,7 +288,7 @@ class ExplicitRungeKutta(Integrator[PyTreeType, PyTreeType], Generic[PyTreeType]
290
288
  z, aux_dict = self._compute_next_state(x0=x0, t0=t0, dt=dt, **kwargs)
291
289
 
292
290
  # The next state is the batch element located at the configured index of solution.
293
- next_state = jax.tree_util.tree_map(lambda l: l[self.row_index_of_solution], z)
291
+ next_state = jax.tree.map(lambda l: l[self.row_index_of_solution], z)
294
292
 
295
293
  return next_state, aux_dict
296
294
 
@@ -327,7 +325,7 @@ class ExplicitRungeKutta(Integrator[PyTreeType, PyTreeType], Generic[PyTreeType]
327
325
  """
328
326
 
329
327
  op = lambda x0_leaf, k_leaf: x0_leaf + dt * k_leaf
330
- return jax.tree_util.tree_map(op, x0, k)
328
+ return jax.tree.map(op, x0, k)
331
329
 
332
330
  @classmethod
333
331
  def post_process_state(
@@ -374,7 +372,7 @@ class ExplicitRungeKutta(Integrator[PyTreeType, PyTreeType], Generic[PyTreeType]
374
372
  f = lambda x, t: self.dynamics(x=x, t=t, **kwargs)
375
373
 
376
374
  # Initialize the carry of the for loop with the stacked kᵢ vectors.
377
- carry0 = jax.tree_util.tree_map(
375
+ carry0 = jax.tree.map(
378
376
  lambda l: jnp.repeat(jnp.zeros_like(l)[jnp.newaxis, ...], c.size, axis=0),
379
377
  x0,
380
378
  )
@@ -398,7 +396,7 @@ class ExplicitRungeKutta(Integrator[PyTreeType, PyTreeType], Generic[PyTreeType]
398
396
 
399
397
  # Compute ∑ⱼ aᵢⱼ kⱼ.
400
398
  op_sum_ak = lambda k: jnp.einsum("s,s...->...", A[i], k)
401
- sum_ak = jax.tree_util.tree_map(op_sum_ak, K)
399
+ sum_ak = jax.tree.map(op_sum_ak, K)
402
400
 
403
401
  # Compute the next state for the kᵢ evaluation.
404
402
  # Note that this is not a Δt integration since aᵢⱼ could be fractional.
@@ -419,7 +417,7 @@ class ExplicitRungeKutta(Integrator[PyTreeType, PyTreeType], Generic[PyTreeType]
419
417
 
420
418
  # Store the kᵢ derivative in K.
421
419
  op = lambda l_k, l_ki: l_k.at[i].set(l_ki)
422
- K = jax.tree_util.tree_map(op, K, ki)
420
+ K = jax.tree.map(op, K, ki)
423
421
 
424
422
  carry = K
425
423
  return carry, aux_dict
@@ -433,14 +431,12 @@ class ExplicitRungeKutta(Integrator[PyTreeType, PyTreeType], Generic[PyTreeType]
433
431
 
434
432
  # Update the FSAL property for the next iteration.
435
433
  if self.has_fsal:
436
- self.params["dxdt0"] = jax.tree_util.tree_map(
437
- lambda l: l[self.index_of_fsal], K
438
- )
434
+ self.params["dxdt0"] = jax.tree.map(lambda l: l[self.index_of_fsal], K)
439
435
 
440
436
  # Compute the output state.
441
437
  # Note that z contains as many new states as the rows of `b.T`.
442
438
  op = lambda x0, k: x0 + Δt * jnp.einsum("zs,s...->z...", b.T, k)
443
- z = jax.tree_util.tree_map(op, x0, K)
439
+ z = jax.tree.map(op, x0, K)
444
440
 
445
441
  # Transform the final state of the integration.
446
442
  # This allows to inject custom logic, if needed.
@@ -87,13 +87,13 @@ def estimate_step_size(
87
87
 
88
88
  # Compute the scaling factors of the initial state and its derivative.
89
89
  compute_scale = lambda x: atol + jnp.abs(x) * rtol
90
- scale0 = jax.tree_util.tree_map(compute_scale, x0)
91
- scale1 = jax.tree_util.tree_map(compute_scale, ẋ0)
90
+ scale0 = jax.tree.map(compute_scale, x0)
91
+ scale1 = jax.tree.map(compute_scale, ẋ0)
92
92
 
93
93
  # Scale the initial state and its derivative.
94
94
  scale_pytree = lambda x, scale: jnp.abs(x) / scale
95
- x0_scaled = jax.tree_util.tree_map(scale_pytree, x0, scale0)
96
- ẋ0_scaled = jax.tree_util.tree_map(scale_pytree, ẋ0, scale1)
95
+ x0_scaled = jax.tree.map(scale_pytree, x0, scale0)
96
+ ẋ0_scaled = jax.tree.map(scale_pytree, ẋ0, scale1)
97
97
 
98
98
  # Get the maximum of the scaled pytrees.
99
99
  d0 = jnp.linalg.norm(flatten(x0_scaled), ord=jnp.inf)
@@ -103,16 +103,16 @@ def estimate_step_size(
103
103
  h0 = jnp.where(jnp.minimum(d0, d1) <= 1e-5, 1e-6, 0.01 * d0 / d1)
104
104
 
105
105
  # Compute the next state (explicit Euler step) and its derivative.
106
- x1 = jax.tree_util.tree_map(lambda x0, ẋ0: x0 + h0 * ẋ0, x0, ẋ0)
106
+ x1 = jax.tree.map(lambda x0, ẋ0: x0 + h0 * ẋ0, x0, ẋ0)
107
107
  ẋ1 = f(x1, t0 + h0)[0]
108
108
 
109
109
  # Compute the scaling factor of the state derivatives.
110
110
  compute_scale_2 = lambda ẋ0, ẋ1: atol + jnp.maximum(jnp.abs(ẋ0), jnp.abs(ẋ1)) * rtol
111
- scale2 = jax.tree_util.tree_map(compute_scale_2, ẋ0, ẋ1)
111
+ scale2 = jax.tree.map(compute_scale_2, ẋ0, ẋ1)
112
112
 
113
113
  # Scale the difference of the state derivatives.
114
114
  scale_ẋ_difference = lambda ẋ0, ẋ1, scale: jnp.abs((ẋ0 - ẋ1) / scale)
115
- ẋ_difference_scaled = jax.tree_util.tree_map(scale_ẋ_difference, ẋ0, ẋ1, scale2)
115
+ ẋ_difference_scaled = jax.tree.map(scale_ẋ_difference, ẋ0, ẋ1, scale2)
116
116
 
117
117
  # Get the maximum of the scaled derivatives difference.
118
118
  d2 = jnp.linalg.norm(flatten(ẋ_difference_scaled), ord=jnp.inf) / h0
@@ -151,11 +151,11 @@ def compute_pytree_scale(
151
151
  """
152
152
 
153
153
  # Consider a zero second pytree, if not given.
154
- x2 = jax.tree_util.tree_map(lambda l: jnp.zeros_like(l), x1) if x2 is None else x2
154
+ x2 = jax.tree.map(lambda l: jnp.zeros_like(l), x1) if x2 is None else x2
155
155
 
156
156
  # Compute the scaling factors of the initial state and its derivative.
157
157
  compute_scale = lambda l1, l2: atol + jnp.maximum(jnp.abs(l1), jnp.abs(l2)) * rtol
158
- scale = jax.tree_util.tree_map(compute_scale, x1, x2)
158
+ scale = jax.tree.map(compute_scale, x1, x2)
159
159
 
160
160
  return scale
161
161
 
@@ -198,14 +198,14 @@ def local_error_estimation(
198
198
 
199
199
  # Consider a zero estimated final state, if not given.
200
200
  xf_estimate = (
201
- jax.tree_util.tree_map(lambda l: jnp.zeros_like(l), xf)
201
+ jax.tree.map(lambda l: jnp.zeros_like(l), xf)
202
202
  if xf_estimate is None
203
203
  else xf_estimate
204
204
  )
205
205
 
206
206
  # Estimate the error.
207
207
  estimate_error = lambda l, l̂, sc: jnp.abs(l - l̂) / sc
208
- error_estimate = jax.tree_util.tree_map(estimate_error, xf, xf_estimate, scale)
208
+ error_estimate = jax.tree.map(estimate_error, xf, xf_estimate, scale)
209
209
 
210
210
  # Return the highest element of the error estimate.
211
211
  return jnp.linalg.norm(flatten(error_estimate), ord=norm_ord)
@@ -359,10 +359,8 @@ class EmbeddedRungeKutta(ExplicitRungeKutta[PyTreeType], Generic[PyTreeType]):
359
359
  params_next = integrator.params
360
360
 
361
361
  # Extract the high-order solution xf and the low-order estimate x̂f.
362
- xf = jax.tree_util.tree_map(lambda l: l[self.row_index_of_solution], z)
363
- x̂f = jax.tree_util.tree_map(
364
- lambda l: l[self.row_index_of_solution_estimate], z
365
- )
362
+ xf = jax.tree.map(lambda l: l[self.row_index_of_solution], z)
363
+ x̂f = jax.tree.map(lambda l: l[self.row_index_of_solution_estimate], z)
366
364
 
367
365
  # Calculate the local integration error.
368
366
  local_error = local_error_estimation(
@@ -230,7 +230,7 @@ class RelaxedRigidContacts(ContactModel):
230
230
  )
231
231
 
232
232
  def _detect_contact(x: jtp.Array, y: jtp.Array, z: jtp.Array) -> jtp.Array:
233
- x, y, z = jax.tree_map(jnp.squeeze, (x, y, z))
233
+ x, y, z = jax.tree.map(jnp.squeeze, (x, y, z))
234
234
 
235
235
  n̂ = self.terrain.normal(x=x, y=y).squeeze()
236
236
  h = jnp.array([0, 0, z - model.terrain.height(x=x, y=y)])
@@ -298,7 +298,7 @@ class JaxsimDataclass(abc.ABC):
298
298
  """
299
299
 
300
300
  # Make a copy calling tree_map.
301
- obj = jax.tree_util.tree_map(lambda leaf: leaf, self)
301
+ obj = jax.tree.map(lambda leaf: leaf, self)
302
302
 
303
303
  # Make sure that the copied object and all the copied leaves have the same
304
304
  # mutability of the original object.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: jaxsim
3
- Version: 0.4.3.dev77
3
+ Version: 0.4.3.dev88
4
4
  Summary: A differentiable physics engine and multibody dynamics library for control and robot learning.
5
5
  Author-email: Diego Ferigo <dgferigo@gmail.com>
6
6
  Maintainer-email: Diego Ferigo <dgferigo@gmail.com>, Filippo Luca Ferretti <filippo.ferretti@iit.it>
@@ -60,9 +60,9 @@ Requires-Python: >=3.10
60
60
  Description-Content-Type: text/markdown
61
61
  License-File: LICENSE
62
62
  Requires-Dist: coloredlogs
63
- Requires-Dist: jax>=0.4.13
63
+ Requires-Dist: jax>=0.4.26
64
64
  Requires-Dist: jaxopt>=0.8.0
65
- Requires-Dist: jaxlib>=0.4.13
65
+ Requires-Dist: jaxlib>=0.4.26
66
66
  Requires-Dist: jaxlie>=1.3.0
67
67
  Requires-Dist: jax_dataclasses>=1.4.0
68
68
  Requires-Dist: pptree
@@ -1,7 +1,7 @@
1
1
  coloredlogs
2
- jax>=0.4.13
2
+ jax>=0.4.26
3
3
  jaxopt>=0.8.0
4
- jaxlib>=0.4.13
4
+ jaxlib>=0.4.26
5
5
  jaxlie>=1.3.0
6
6
  jax_dataclasses>=1.4.0
7
7
  pptree
@@ -74,7 +74,7 @@ def test_link_inertial_properties(
74
74
 
75
75
  for link_name, link_idx in zip(
76
76
  model.link_names(),
77
- js.link.names_to_idxs(model=model, link_names=model.link_names()),
77
+ jnp.arange(model.number_of_links()),
78
78
  strict=True,
79
79
  ):
80
80
  if link_name == model.base_link():
@@ -164,7 +164,7 @@ def test_link_jacobians(
164
164
 
165
165
  for link_name, link_idx in zip(
166
166
  model.link_names(),
167
- js.link.names_to_idxs(model=model, link_names=model.link_names()),
167
+ jnp.arange(model.number_of_links()),
168
168
  strict=True,
169
169
  ):
170
170
  v_WL_idt = kin_dyn.frame_velocity(frame_name=link_name)
@@ -185,7 +185,7 @@ def test_link_jacobians(
185
185
 
186
186
  for link_name, link_idx in zip(
187
187
  model.link_names(),
188
- js.link.names_to_idxs(model=model, link_names=model.link_names()),
188
+ jnp.arange(model.number_of_links()),
189
189
  strict=True,
190
190
  ):
191
191
  v_WL_idt = kin_dyn_other_repr.frame_velocity(frame_name=link_name)
@@ -220,7 +220,7 @@ def test_link_bias_acceleration(
220
220
 
221
221
  for name, index in zip(
222
222
  model.link_names(),
223
- js.link.names_to_idxs(model=model, link_names=model.link_names()),
223
+ jnp.arange(model.number_of_links()),
224
224
  strict=True,
225
225
  ):
226
226
  Jν_idt = kin_dyn.frame_bias_acc(frame_name=name)
@@ -240,11 +240,7 @@ def test_link_bias_acceleration(
240
240
 
241
241
  W_H_L = js.model.forward_kinematics(model=model, data=data)
242
242
 
243
- W_a_bias_WL = jax.vmap(
244
- lambda index: js.link.bias_acceleration(
245
- model=model, data=data, link_index=index
246
- )
247
- )(jnp.arange(model.number_of_links()))
243
+ W_a_bias_WL = js.model.link_bias_accelerations(model=model, data=data)
248
244
 
249
245
  with data.switch_velocity_representation(VelRepr.Body):
250
246
 
@@ -252,11 +248,7 @@ def test_link_bias_acceleration(
252
248
  lambda W_H_L: jaxsim.math.Adjoint.from_transform(transform=W_H_L)
253
249
  )(W_H_L)
254
250
 
255
- L_a_bias_WL = jax.vmap(
256
- lambda index: js.link.bias_acceleration(
257
- model=model, data=data, link_index=index
258
- )
259
- )(jnp.arange(model.number_of_links()))
251
+ L_a_bias_WL = js.model.link_bias_accelerations(model=model, data=data)
260
252
 
261
253
  W_a_bias_WL_converted = jax.vmap(
262
254
  lambda W_X_L, L_a_bias_WL: W_X_L @ L_a_bias_WL
@@ -269,11 +261,7 @@ def test_link_bias_acceleration(
269
261
 
270
262
  W_H_L = js.model.forward_kinematics(model=model, data=data)
271
263
 
272
- L_a_bias_WL = jax.vmap(
273
- lambda index: js.link.bias_acceleration(
274
- model=model, data=data, link_index=index
275
- )
276
- )(jnp.arange(model.number_of_links()))
264
+ L_a_bias_WL = js.model.link_bias_accelerations(model=model, data=data)
277
265
 
278
266
  with data.switch_velocity_representation(VelRepr.Inertial):
279
267
 
@@ -283,11 +271,7 @@ def test_link_bias_acceleration(
283
271
  )
284
272
  )(W_H_L)
285
273
 
286
- W_a_bias_WL = jax.vmap(
287
- lambda index: js.link.bias_acceleration(
288
- model=model, data=data, link_index=index
289
- )
290
- )(jnp.arange(model.number_of_links()))
274
+ W_a_bias_WL = js.model.link_bias_accelerations(model=model, data=data)
291
275
 
292
276
  L_a_bias_WL_converted = jax.vmap(
293
277
  lambda L_X_W, W_a_bias_WL: L_X_W @ W_a_bias_WL
@@ -323,14 +307,10 @@ def test_link_jacobian_derivative(
323
307
  lambda link_index: js.link.jacobian_derivative(
324
308
  model=model, data=data, link_index=link_index
325
309
  )
326
- )(js.link.names_to_idxs(model=model, link_names=model.link_names()))
310
+ )(jnp.arange(model.number_of_links()))
327
311
 
328
312
  # Compute the product J̇ν.
329
- O_a_bias_WL = jax.vmap(
330
- lambda link_index: js.link.bias_acceleration(
331
- model=model, data=data, link_index=link_index
332
- )
333
- )(js.link.names_to_idxs(model=model, link_names=model.link_names()))
313
+ O_a_bias_WL = js.model.link_bias_accelerations(model=model, data=data)
334
314
 
335
315
  # Compare the two computations.
336
316
  assert jnp.einsum("l6g,g->l6", O_J̇_WL_I, I_ν) == pytest.approx(
@@ -263,7 +263,7 @@ def test_ad_jacobian(
263
263
  # ====
264
264
 
265
265
  # Get the link indices.
266
- link_indices = js.link.names_to_idxs(model=model, link_names=model.link_names())
266
+ link_indices = jnp.arange(model.number_of_links())
267
267
 
268
268
  # Get a closure exposing only the parameters to be differentiated.
269
269
  # We differentiate the jacobian of the last link, likely among those
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes