jaxsim 0.5.1.dev103__tar.gz → 0.5.1.dev126__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (135) hide show
  1. {jaxsim-0.5.1.dev103 → jaxsim-0.5.1.dev126}/PKG-INFO +102 -38
  2. {jaxsim-0.5.1.dev103 → jaxsim-0.5.1.dev126}/README.md +100 -37
  3. {jaxsim-0.5.1.dev103 → jaxsim-0.5.1.dev126}/environment.yml +2 -0
  4. {jaxsim-0.5.1.dev103 → jaxsim-0.5.1.dev126}/examples/README.md +2 -0
  5. jaxsim-0.5.1.dev126/examples/jaxsim_as_physics_engine.ipynb +274 -0
  6. jaxsim-0.5.1.dev103/examples/jaxsim_as_physics_engine.ipynb → jaxsim-0.5.1.dev126/examples/jaxsim_as_physics_engine_advanced.ipynb +3 -3
  7. {jaxsim-0.5.1.dev103 → jaxsim-0.5.1.dev126}/pyproject.toml +2 -1
  8. {jaxsim-0.5.1.dev103 → jaxsim-0.5.1.dev126}/src/jaxsim/_version.py +2 -2
  9. {jaxsim-0.5.1.dev103 → jaxsim-0.5.1.dev126}/src/jaxsim/api/kin_dyn_parameters.py +8 -8
  10. {jaxsim-0.5.1.dev103 → jaxsim-0.5.1.dev126}/src/jaxsim/api/model.py +2 -2
  11. {jaxsim-0.5.1.dev103 → jaxsim-0.5.1.dev126}/src/jaxsim.egg-info/PKG-INFO +102 -38
  12. {jaxsim-0.5.1.dev103 → jaxsim-0.5.1.dev126}/src/jaxsim.egg-info/SOURCES.txt +1 -0
  13. {jaxsim-0.5.1.dev103 → jaxsim-0.5.1.dev126}/src/jaxsim.egg-info/requires.txt +1 -0
  14. {jaxsim-0.5.1.dev103 → jaxsim-0.5.1.dev126}/.devcontainer/Dockerfile +0 -0
  15. {jaxsim-0.5.1.dev103 → jaxsim-0.5.1.dev126}/.devcontainer/devcontainer.json +0 -0
  16. {jaxsim-0.5.1.dev103 → jaxsim-0.5.1.dev126}/.gitattributes +0 -0
  17. {jaxsim-0.5.1.dev103 → jaxsim-0.5.1.dev126}/.github/CODEOWNERS +0 -0
  18. {jaxsim-0.5.1.dev103 → jaxsim-0.5.1.dev126}/.github/dependabot.yml +0 -0
  19. {jaxsim-0.5.1.dev103 → jaxsim-0.5.1.dev126}/.github/workflows/ci_cd.yml +0 -0
  20. {jaxsim-0.5.1.dev103 → jaxsim-0.5.1.dev126}/.github/workflows/pixi.yml +0 -0
  21. {jaxsim-0.5.1.dev103 → jaxsim-0.5.1.dev126}/.github/workflows/read_the_docs.yml +0 -0
  22. {jaxsim-0.5.1.dev103 → jaxsim-0.5.1.dev126}/.gitignore +0 -0
  23. {jaxsim-0.5.1.dev103 → jaxsim-0.5.1.dev126}/.pre-commit-config.yaml +0 -0
  24. {jaxsim-0.5.1.dev103 → jaxsim-0.5.1.dev126}/.readthedocs.yaml +0 -0
  25. {jaxsim-0.5.1.dev103 → jaxsim-0.5.1.dev126}/CONTRIBUTING.md +0 -0
  26. {jaxsim-0.5.1.dev103 → jaxsim-0.5.1.dev126}/LICENSE +0 -0
  27. {jaxsim-0.5.1.dev103 → jaxsim-0.5.1.dev126}/docs/Makefile +0 -0
  28. {jaxsim-0.5.1.dev103 → jaxsim-0.5.1.dev126}/docs/conf.py +0 -0
  29. {jaxsim-0.5.1.dev103 → jaxsim-0.5.1.dev126}/docs/examples.rst +0 -0
  30. {jaxsim-0.5.1.dev103 → jaxsim-0.5.1.dev126}/docs/guide/configuration.rst +0 -0
  31. {jaxsim-0.5.1.dev103 → jaxsim-0.5.1.dev126}/docs/guide/install.rst +0 -0
  32. {jaxsim-0.5.1.dev103 → jaxsim-0.5.1.dev126}/docs/index.rst +0 -0
  33. {jaxsim-0.5.1.dev103 → jaxsim-0.5.1.dev126}/docs/make.bat +0 -0
  34. {jaxsim-0.5.1.dev103 → jaxsim-0.5.1.dev126}/docs/modules/api.rst +0 -0
  35. {jaxsim-0.5.1.dev103 → jaxsim-0.5.1.dev126}/docs/modules/integrators.rst +0 -0
  36. {jaxsim-0.5.1.dev103 → jaxsim-0.5.1.dev126}/docs/modules/math.rst +0 -0
  37. {jaxsim-0.5.1.dev103 → jaxsim-0.5.1.dev126}/docs/modules/mujoco.rst +0 -0
  38. {jaxsim-0.5.1.dev103 → jaxsim-0.5.1.dev126}/docs/modules/parsers.rst +0 -0
  39. {jaxsim-0.5.1.dev103 → jaxsim-0.5.1.dev126}/docs/modules/rbda.rst +0 -0
  40. {jaxsim-0.5.1.dev103 → jaxsim-0.5.1.dev126}/docs/modules/typing.rst +0 -0
  41. {jaxsim-0.5.1.dev103 → jaxsim-0.5.1.dev126}/docs/modules/utils.rst +0 -0
  42. {jaxsim-0.5.1.dev103 → jaxsim-0.5.1.dev126}/examples/.gitattributes +0 -0
  43. {jaxsim-0.5.1.dev103 → jaxsim-0.5.1.dev126}/examples/.gitignore +0 -0
  44. {jaxsim-0.5.1.dev103 → jaxsim-0.5.1.dev126}/examples/assets/build_cartpole_urdf.py +0 -0
  45. {jaxsim-0.5.1.dev103 → jaxsim-0.5.1.dev126}/examples/assets/cartpole.urdf +0 -0
  46. {jaxsim-0.5.1.dev103 → jaxsim-0.5.1.dev126}/examples/jaxsim_as_multibody_dynamics_library.ipynb +0 -0
  47. {jaxsim-0.5.1.dev103 → jaxsim-0.5.1.dev126}/examples/jaxsim_for_robot_controllers.ipynb +0 -0
  48. {jaxsim-0.5.1.dev103 → jaxsim-0.5.1.dev126}/pixi.lock +0 -0
  49. {jaxsim-0.5.1.dev103 → jaxsim-0.5.1.dev126}/setup.cfg +0 -0
  50. {jaxsim-0.5.1.dev103 → jaxsim-0.5.1.dev126}/setup.py +0 -0
  51. {jaxsim-0.5.1.dev103 → jaxsim-0.5.1.dev126}/src/jaxsim/__init__.py +0 -0
  52. {jaxsim-0.5.1.dev103 → jaxsim-0.5.1.dev126}/src/jaxsim/api/__init__.py +0 -0
  53. {jaxsim-0.5.1.dev103 → jaxsim-0.5.1.dev126}/src/jaxsim/api/com.py +0 -0
  54. {jaxsim-0.5.1.dev103 → jaxsim-0.5.1.dev126}/src/jaxsim/api/common.py +0 -0
  55. {jaxsim-0.5.1.dev103 → jaxsim-0.5.1.dev126}/src/jaxsim/api/contact.py +0 -0
  56. {jaxsim-0.5.1.dev103 → jaxsim-0.5.1.dev126}/src/jaxsim/api/data.py +0 -0
  57. {jaxsim-0.5.1.dev103 → jaxsim-0.5.1.dev126}/src/jaxsim/api/frame.py +0 -0
  58. {jaxsim-0.5.1.dev103 → jaxsim-0.5.1.dev126}/src/jaxsim/api/joint.py +0 -0
  59. {jaxsim-0.5.1.dev103 → jaxsim-0.5.1.dev126}/src/jaxsim/api/link.py +0 -0
  60. {jaxsim-0.5.1.dev103 → jaxsim-0.5.1.dev126}/src/jaxsim/api/ode.py +0 -0
  61. {jaxsim-0.5.1.dev103 → jaxsim-0.5.1.dev126}/src/jaxsim/api/ode_data.py +0 -0
  62. {jaxsim-0.5.1.dev103 → jaxsim-0.5.1.dev126}/src/jaxsim/api/references.py +0 -0
  63. {jaxsim-0.5.1.dev103 → jaxsim-0.5.1.dev126}/src/jaxsim/exceptions.py +0 -0
  64. {jaxsim-0.5.1.dev103 → jaxsim-0.5.1.dev126}/src/jaxsim/integrators/__init__.py +0 -0
  65. {jaxsim-0.5.1.dev103 → jaxsim-0.5.1.dev126}/src/jaxsim/integrators/common.py +0 -0
  66. {jaxsim-0.5.1.dev103 → jaxsim-0.5.1.dev126}/src/jaxsim/integrators/fixed_step.py +0 -0
  67. {jaxsim-0.5.1.dev103 → jaxsim-0.5.1.dev126}/src/jaxsim/integrators/variable_step.py +0 -0
  68. {jaxsim-0.5.1.dev103 → jaxsim-0.5.1.dev126}/src/jaxsim/logging.py +0 -0
  69. {jaxsim-0.5.1.dev103 → jaxsim-0.5.1.dev126}/src/jaxsim/math/__init__.py +0 -0
  70. {jaxsim-0.5.1.dev103 → jaxsim-0.5.1.dev126}/src/jaxsim/math/adjoint.py +0 -0
  71. {jaxsim-0.5.1.dev103 → jaxsim-0.5.1.dev126}/src/jaxsim/math/cross.py +0 -0
  72. {jaxsim-0.5.1.dev103 → jaxsim-0.5.1.dev126}/src/jaxsim/math/inertia.py +0 -0
  73. {jaxsim-0.5.1.dev103 → jaxsim-0.5.1.dev126}/src/jaxsim/math/joint_model.py +0 -0
  74. {jaxsim-0.5.1.dev103 → jaxsim-0.5.1.dev126}/src/jaxsim/math/quaternion.py +0 -0
  75. {jaxsim-0.5.1.dev103 → jaxsim-0.5.1.dev126}/src/jaxsim/math/rotation.py +0 -0
  76. {jaxsim-0.5.1.dev103 → jaxsim-0.5.1.dev126}/src/jaxsim/math/skew.py +0 -0
  77. {jaxsim-0.5.1.dev103 → jaxsim-0.5.1.dev126}/src/jaxsim/math/transform.py +0 -0
  78. {jaxsim-0.5.1.dev103 → jaxsim-0.5.1.dev126}/src/jaxsim/math/utils.py +0 -0
  79. {jaxsim-0.5.1.dev103 → jaxsim-0.5.1.dev126}/src/jaxsim/mujoco/__init__.py +0 -0
  80. {jaxsim-0.5.1.dev103 → jaxsim-0.5.1.dev126}/src/jaxsim/mujoco/__main__.py +0 -0
  81. {jaxsim-0.5.1.dev103 → jaxsim-0.5.1.dev126}/src/jaxsim/mujoco/loaders.py +0 -0
  82. {jaxsim-0.5.1.dev103 → jaxsim-0.5.1.dev126}/src/jaxsim/mujoco/model.py +0 -0
  83. {jaxsim-0.5.1.dev103 → jaxsim-0.5.1.dev126}/src/jaxsim/mujoco/utils.py +0 -0
  84. {jaxsim-0.5.1.dev103 → jaxsim-0.5.1.dev126}/src/jaxsim/mujoco/visualizer.py +0 -0
  85. {jaxsim-0.5.1.dev103 → jaxsim-0.5.1.dev126}/src/jaxsim/parsers/__init__.py +0 -0
  86. {jaxsim-0.5.1.dev103 → jaxsim-0.5.1.dev126}/src/jaxsim/parsers/descriptions/__init__.py +0 -0
  87. {jaxsim-0.5.1.dev103 → jaxsim-0.5.1.dev126}/src/jaxsim/parsers/descriptions/collision.py +0 -0
  88. {jaxsim-0.5.1.dev103 → jaxsim-0.5.1.dev126}/src/jaxsim/parsers/descriptions/joint.py +0 -0
  89. {jaxsim-0.5.1.dev103 → jaxsim-0.5.1.dev126}/src/jaxsim/parsers/descriptions/link.py +0 -0
  90. {jaxsim-0.5.1.dev103 → jaxsim-0.5.1.dev126}/src/jaxsim/parsers/descriptions/model.py +0 -0
  91. {jaxsim-0.5.1.dev103 → jaxsim-0.5.1.dev126}/src/jaxsim/parsers/kinematic_graph.py +0 -0
  92. {jaxsim-0.5.1.dev103 → jaxsim-0.5.1.dev126}/src/jaxsim/parsers/rod/__init__.py +0 -0
  93. {jaxsim-0.5.1.dev103 → jaxsim-0.5.1.dev126}/src/jaxsim/parsers/rod/meshes.py +0 -0
  94. {jaxsim-0.5.1.dev103 → jaxsim-0.5.1.dev126}/src/jaxsim/parsers/rod/parser.py +0 -0
  95. {jaxsim-0.5.1.dev103 → jaxsim-0.5.1.dev126}/src/jaxsim/parsers/rod/utils.py +0 -0
  96. {jaxsim-0.5.1.dev103 → jaxsim-0.5.1.dev126}/src/jaxsim/rbda/__init__.py +0 -0
  97. {jaxsim-0.5.1.dev103 → jaxsim-0.5.1.dev126}/src/jaxsim/rbda/aba.py +0 -0
  98. {jaxsim-0.5.1.dev103 → jaxsim-0.5.1.dev126}/src/jaxsim/rbda/collidable_points.py +0 -0
  99. {jaxsim-0.5.1.dev103 → jaxsim-0.5.1.dev126}/src/jaxsim/rbda/contacts/__init__.py +0 -0
  100. {jaxsim-0.5.1.dev103 → jaxsim-0.5.1.dev126}/src/jaxsim/rbda/contacts/common.py +0 -0
  101. {jaxsim-0.5.1.dev103 → jaxsim-0.5.1.dev126}/src/jaxsim/rbda/contacts/relaxed_rigid.py +0 -0
  102. {jaxsim-0.5.1.dev103 → jaxsim-0.5.1.dev126}/src/jaxsim/rbda/contacts/rigid.py +0 -0
  103. {jaxsim-0.5.1.dev103 → jaxsim-0.5.1.dev126}/src/jaxsim/rbda/contacts/soft.py +0 -0
  104. {jaxsim-0.5.1.dev103 → jaxsim-0.5.1.dev126}/src/jaxsim/rbda/contacts/visco_elastic.py +0 -0
  105. {jaxsim-0.5.1.dev103 → jaxsim-0.5.1.dev126}/src/jaxsim/rbda/crba.py +0 -0
  106. {jaxsim-0.5.1.dev103 → jaxsim-0.5.1.dev126}/src/jaxsim/rbda/forward_kinematics.py +0 -0
  107. {jaxsim-0.5.1.dev103 → jaxsim-0.5.1.dev126}/src/jaxsim/rbda/jacobian.py +0 -0
  108. {jaxsim-0.5.1.dev103 → jaxsim-0.5.1.dev126}/src/jaxsim/rbda/rnea.py +0 -0
  109. {jaxsim-0.5.1.dev103 → jaxsim-0.5.1.dev126}/src/jaxsim/rbda/utils.py +0 -0
  110. {jaxsim-0.5.1.dev103 → jaxsim-0.5.1.dev126}/src/jaxsim/terrain/__init__.py +0 -0
  111. {jaxsim-0.5.1.dev103 → jaxsim-0.5.1.dev126}/src/jaxsim/terrain/terrain.py +0 -0
  112. {jaxsim-0.5.1.dev103 → jaxsim-0.5.1.dev126}/src/jaxsim/typing.py +0 -0
  113. {jaxsim-0.5.1.dev103 → jaxsim-0.5.1.dev126}/src/jaxsim/utils/__init__.py +0 -0
  114. {jaxsim-0.5.1.dev103 → jaxsim-0.5.1.dev126}/src/jaxsim/utils/jaxsim_dataclass.py +0 -0
  115. {jaxsim-0.5.1.dev103 → jaxsim-0.5.1.dev126}/src/jaxsim/utils/tracing.py +0 -0
  116. {jaxsim-0.5.1.dev103 → jaxsim-0.5.1.dev126}/src/jaxsim/utils/wrappers.py +0 -0
  117. {jaxsim-0.5.1.dev103 → jaxsim-0.5.1.dev126}/src/jaxsim.egg-info/dependency_links.txt +0 -0
  118. {jaxsim-0.5.1.dev103 → jaxsim-0.5.1.dev126}/src/jaxsim.egg-info/top_level.txt +0 -0
  119. {jaxsim-0.5.1.dev103 → jaxsim-0.5.1.dev126}/tests/__init__.py +0 -0
  120. {jaxsim-0.5.1.dev103 → jaxsim-0.5.1.dev126}/tests/conftest.py +0 -0
  121. {jaxsim-0.5.1.dev103 → jaxsim-0.5.1.dev126}/tests/test_api_com.py +0 -0
  122. {jaxsim-0.5.1.dev103 → jaxsim-0.5.1.dev126}/tests/test_api_contact.py +0 -0
  123. {jaxsim-0.5.1.dev103 → jaxsim-0.5.1.dev126}/tests/test_api_data.py +0 -0
  124. {jaxsim-0.5.1.dev103 → jaxsim-0.5.1.dev126}/tests/test_api_frame.py +0 -0
  125. {jaxsim-0.5.1.dev103 → jaxsim-0.5.1.dev126}/tests/test_api_joint.py +0 -0
  126. {jaxsim-0.5.1.dev103 → jaxsim-0.5.1.dev126}/tests/test_api_link.py +0 -0
  127. {jaxsim-0.5.1.dev103 → jaxsim-0.5.1.dev126}/tests/test_api_model.py +0 -0
  128. {jaxsim-0.5.1.dev103 → jaxsim-0.5.1.dev126}/tests/test_automatic_differentiation.py +0 -0
  129. {jaxsim-0.5.1.dev103 → jaxsim-0.5.1.dev126}/tests/test_benchmark.py +0 -0
  130. {jaxsim-0.5.1.dev103 → jaxsim-0.5.1.dev126}/tests/test_contact.py +0 -0
  131. {jaxsim-0.5.1.dev103 → jaxsim-0.5.1.dev126}/tests/test_exceptions.py +0 -0
  132. {jaxsim-0.5.1.dev103 → jaxsim-0.5.1.dev126}/tests/test_meshes.py +0 -0
  133. {jaxsim-0.5.1.dev103 → jaxsim-0.5.1.dev126}/tests/test_pytree.py +0 -0
  134. {jaxsim-0.5.1.dev103 → jaxsim-0.5.1.dev126}/tests/test_simulations.py +0 -0
  135. {jaxsim-0.5.1.dev103 → jaxsim-0.5.1.dev126}/tests/utils_idyntree.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: jaxsim
3
- Version: 0.5.1.dev103
3
+ Version: 0.5.1.dev126
4
4
  Summary: A differentiable physics engine and multibody dynamics library for control and robot learning.
5
5
  Author-email: Diego Ferigo <dgferigo@gmail.com>, Filippo Luca Ferretti <filippoluca.ferretti@outlook.com>
6
6
  Maintainer-email: Filippo Luca Ferretti <filippo.ferretti@iit.it>, Alessandro Croci <alessandro.croci@iit.it>
@@ -81,6 +81,7 @@ Requires-Dist: pytest>=6.0; extra == "testing"
81
81
  Requires-Dist: pytest-benchmark; extra == "testing"
82
82
  Requires-Dist: pytest-icdiff; extra == "testing"
83
83
  Requires-Dist: robot-descriptions; extra == "testing"
84
+ Requires-Dist: icub-models; extra == "testing"
84
85
  Provides-Extra: viz
85
86
  Requires-Dist: lxml; extra == "viz"
86
87
  Requires-Dist: mediapy; extra == "viz"
@@ -91,9 +92,7 @@ Requires-Dist: jaxsim[style,testing,viz]; extra == "all"
91
92
 
92
93
  # JaxSim
93
94
 
94
- JaxSim is a **differentiable physics engine** and **multibody dynamics library** designed for applications in control and robot learning, implemented with JAX.
95
-
96
- Its design facilitates research and accelerates prototyping in the intersection of robotics and artificial intelligence.
95
+ **JaxSim** is a **differentiable physics engine** and **multibody dynamics library** built with JAX, tailored for control and robot learning applications.
97
96
 
98
97
  <div align="center">
99
98
  <br/>
@@ -108,41 +107,105 @@ Its design facilitates research and accelerates prototyping in the intersection
108
107
  </div>
109
108
 
110
109
  ## Features
110
+ - Reduced-coordinate physics engine for **fixed-base** and **floating-base** robots.
111
+ - Multibody dynamics library for model-based control algorithms.
112
+ - Fully Python-based, leveraging [jax][jax] following a functional programming paradigm.
113
+ - Seamless execution on CPUs, GPUs, and TPUs.
114
+ - Supports JIT compilation and automatic vectorization for high performance.
115
+ - Compatible with SDF models and URDF (via [sdformat][sdformat] conversion).
116
+
117
+ ## Usage
118
+
119
+ ### Using JaxSim as a simulator
120
+
121
+
122
+ ```python
123
+ import jax.numpy as jnp
124
+ import jaxsim.api as js
125
+ import icub_models
126
+ import pathlib
127
+
128
+ # Load the iCub model
129
+ model_path = icub_models.get_model_file("iCubGazeboV2_5")
130
+ joints = ('torso_pitch', 'torso_roll', 'torso_yaw', 'l_shoulder_pitch',
131
+ 'l_shoulder_roll', 'l_shoulder_yaw', 'l_elbow', 'r_shoulder_pitch',
132
+ 'r_shoulder_roll', 'r_shoulder_yaw', 'r_elbow', 'l_hip_pitch',
133
+ 'l_hip_roll', 'l_hip_yaw', 'l_knee', 'l_ankle_pitch', 'l_ankle_roll',
134
+ 'r_hip_pitch', 'r_hip_roll', 'r_hip_yaw', 'r_knee', 'r_ankle_pitch',
135
+ 'r_ankle_roll')
136
+
137
+ # Build and reduce the model
138
+ model_description = pathlib.Path(model_path)
139
+ full_model = js.model.JaxSimModel.build_from_model_description(
140
+ model_description=model_description, time_step=0.0001, is_urdf=True
141
+ )
142
+ model = js.model.reduce(model=full_model, considered_joints=joints)
143
+
144
+ ndof = model.dofs()
145
+ # Initialize data and simulation
146
+ data = js.data.JaxSimModelData.zero(model=model).reset_base_position(
147
+ base_position=jnp.array([0.0, 0.0, 1.0])
148
+ )
149
+ T = jnp.arange(start=0, stop=1.0, step=model.time_step)
150
+ tau = jnp.zeros(ndof)
151
+
152
+ # Simulate
153
+ for t in T:
154
+ data, _ = js.model.step(model=model, data=data, link_forces=None, joint_force_references=tau)
111
155
 
112
- - Physics engine in reduced coordinates supporting fixed-base and floating-base robots.
113
- - Multibody dynamics library providing all the necessary components for developing model-based control algorithms.
114
- - Completely developed in Python with [`google/jax`][jax] following a functional programming paradigm.
115
- - Transparent support for running on CPUs, GPUs, and TPUs.
116
- - Full support for JIT compilation for increased performance.
117
- - Full support for automatic vectorization for massive parallelization of open-loop and closed-loop architectures.
118
- - Support for SDF models and, upon conversion with [sdformat][sdformat], URDF models.
119
- - Visualization based on the [passive viewer][passive_viewer_mujoco] of Mujoco.
120
-
121
- ### JaxSim as a simulator
122
-
123
- - Wide range of fixed-step explicit Runge-Kutta integrators.
124
- - Support for variable-step integrators implemented as embedded Runge-Kutta schemes.
125
- - Improved stability by optionally integrating the base orientation on the $\text{SO}(3)$ manifold.
126
- - Soft contacts model supporting full friction cone and sticking-slipping transition.
127
- - Collision detection between points rigidly attached to bodies and uneven ground surfaces.
128
-
129
- ### JaxSim as a multibody dynamics library
156
+ ```
130
157
 
131
- - Provides rigid body dynamics algorithms (RBDAs) like RNEA, ABA, CRBA, and Jacobians.
132
- - Provides all the quantities included in the Euler-Poincarè formulation of the equations of motion.
133
- - Supports body-fixed, inertial-fixed, and mixed [velocity representations][notation].
134
- - Exposes all the necessary quantities to develop controllers in centroidal coordinates.
135
- - Supports running open-loop and full closed-loop control architectures on hardware accelerators.
158
+ ### Using JaxSim as a multibody dynamics library
159
+ ``` python
160
+ import jax.numpy as jnp
161
+ import jaxsim.api as js
162
+ import icub_models
163
+ import pathlib
164
+
165
+ # Load the iCub model
166
+ model_path = icub_models.get_model_file("iCubGazeboV2_5")
167
+ joints = ('torso_pitch', 'torso_roll', 'torso_yaw', 'l_shoulder_pitch',
168
+ 'l_shoulder_roll', 'l_shoulder_yaw', 'l_elbow', 'r_shoulder_pitch',
169
+ 'r_shoulder_roll', 'r_shoulder_yaw', 'r_elbow', 'l_hip_pitch',
170
+ 'l_hip_roll', 'l_hip_yaw', 'l_knee', 'l_ankle_pitch', 'l_ankle_roll',
171
+ 'r_hip_pitch', 'r_hip_roll', 'r_hip_yaw', 'r_knee', 'r_ankle_pitch',
172
+ 'r_ankle_roll')
173
+
174
+ # Build and reduce the model
175
+ model_description = pathlib.Path(model_path)
176
+ full_model = js.model.JaxSimModel.build_from_model_description(
177
+ model_description=model_description, time_step=0.0001, is_urdf=True
178
+ )
179
+ model = js.model.reduce(model=full_model, considered_joints=joints)
180
+
181
+ # Initialize model data
182
+ data = js.data.JaxSimModelData.build(
183
+ model=model,
184
+ base_position=jnp.array([0.0, 0.0, 1.0]),
185
+ )
186
+
187
+ # Frame and dynamics computations
188
+ frame_index = js.frame.name_to_idx(model=model, frame_name="l_foot")
189
+ W_H_F = js.frame.transform(model=model, data=data, frame_index=frame_index) # Frame transformation
190
+ W_J_F = js.frame.jacobian(model=model, data=data, frame_index=frame_index) # Frame Jacobian
191
+
192
+ # Dynamics properties
193
+ M = js.model.free_floating_mass_matrix(model=model, data=data) # Mass matrix
194
+ h = js.model.free_floating_bias_forces(model=model, data=data) # Bias forces
195
+ g = js.model.free_floating_gravity_forces(model=model, data=data) # Gravity forces
196
+ C = js.model.free_floating_coriolis_matrix(model=model, data=data) # Coriolis matrix
197
+
198
+ # Print dynamics results
199
+ print(f"M: shape={M.shape}, h: shape={h.shape}, g: shape={g.shape}, C: shape={C.shape}")
136
200
 
137
- ### JaxSim for robot learning
201
+ ```
202
+ ### Additional features
138
203
 
139
- - Being developed with JAX, all the RBDAs support automatic differentiation both in forward and reverse modes.
204
+ - Full support for automatic differentiation of RBDAs (forward and reverse modes) with JAX.
140
205
  - Support for automatically differentiating against kinematics and dynamics parameters.
141
206
  - All fixed-step integrators are forward and reverse differentiable.
142
207
  - All variable-step integrators are forward differentiable.
143
- - Ideal for sampling synthetic data for reinforcement learning (RL).
144
- - Ideal for designing physics-informed neural networks (PINNs) with loss functions requiring model-based quantities.
145
- - Ideal for combining model-based control with learning-based components.
208
+ - Check the examples folder for additional use cases!
146
209
 
147
210
  [jax]: https://github.com/google/jax/
148
211
  [sdformat]: https://github.com/gazebosim/sdformat
@@ -156,12 +219,6 @@ Its design facilitates research and accelerates prototyping in the intersection
156
219
  > JaxSim currently focuses on locomotion applications.
157
220
  > Only contacts between bodies and smooth ground surfaces are supported.
158
221
 
159
- ## Documentation
160
-
161
- The JaxSim API documentation is available at [jaxsim.readthedocs.io][readthedocs].
162
-
163
- [readthedocs]: https://jaxsim.readthedocs.io/
164
-
165
222
  ## Installation
166
223
 
167
224
  <details>
@@ -233,6 +290,13 @@ pip install --no-deps -e .
233
290
  [venv]: https://docs.python.org/3/tutorial/venv.html
234
291
  [jax_gpu]: https://github.com/google/jax/#installation
235
292
 
293
+ ## Documentation
294
+
295
+ The JaxSim API documentation is available at [jaxsim.readthedocs.io][readthedocs].
296
+
297
+ [readthedocs]: https://jaxsim.readthedocs.io/
298
+
299
+
236
300
  ## Overview
237
301
 
238
302
  <details>
@@ -1,8 +1,6 @@
1
1
  # JaxSim
2
2
 
3
- JaxSim is a **differentiable physics engine** and **multibody dynamics library** designed for applications in control and robot learning, implemented with JAX.
4
-
5
- Its design facilitates research and accelerates prototyping in the intersection of robotics and artificial intelligence.
3
+ **JaxSim** is a **differentiable physics engine** and **multibody dynamics library** built with JAX, tailored for control and robot learning applications.
6
4
 
7
5
  <div align="center">
8
6
  <br/>
@@ -17,41 +15,105 @@ Its design facilitates research and accelerates prototyping in the intersection
17
15
  </div>
18
16
 
19
17
  ## Features
18
+ - Reduced-coordinate physics engine for **fixed-base** and **floating-base** robots.
19
+ - Multibody dynamics library for model-based control algorithms.
20
+ - Fully Python-based, leveraging [jax][jax] following a functional programming paradigm.
21
+ - Seamless execution on CPUs, GPUs, and TPUs.
22
+ - Supports JIT compilation and automatic vectorization for high performance.
23
+ - Compatible with SDF models and URDF (via [sdformat][sdformat] conversion).
24
+
25
+ ## Usage
26
+
27
+ ### Using JaxSim as a simulator
28
+
29
+
30
+ ```python
31
+ import jax.numpy as jnp
32
+ import jaxsim.api as js
33
+ import icub_models
34
+ import pathlib
35
+
36
+ # Load the iCub model
37
+ model_path = icub_models.get_model_file("iCubGazeboV2_5")
38
+ joints = ('torso_pitch', 'torso_roll', 'torso_yaw', 'l_shoulder_pitch',
39
+ 'l_shoulder_roll', 'l_shoulder_yaw', 'l_elbow', 'r_shoulder_pitch',
40
+ 'r_shoulder_roll', 'r_shoulder_yaw', 'r_elbow', 'l_hip_pitch',
41
+ 'l_hip_roll', 'l_hip_yaw', 'l_knee', 'l_ankle_pitch', 'l_ankle_roll',
42
+ 'r_hip_pitch', 'r_hip_roll', 'r_hip_yaw', 'r_knee', 'r_ankle_pitch',
43
+ 'r_ankle_roll')
44
+
45
+ # Build and reduce the model
46
+ model_description = pathlib.Path(model_path)
47
+ full_model = js.model.JaxSimModel.build_from_model_description(
48
+ model_description=model_description, time_step=0.0001, is_urdf=True
49
+ )
50
+ model = js.model.reduce(model=full_model, considered_joints=joints)
51
+
52
+ ndof = model.dofs()
53
+ # Initialize data and simulation
54
+ data = js.data.JaxSimModelData.zero(model=model).reset_base_position(
55
+ base_position=jnp.array([0.0, 0.0, 1.0])
56
+ )
57
+ T = jnp.arange(start=0, stop=1.0, step=model.time_step)
58
+ tau = jnp.zeros(ndof)
59
+
60
+ # Simulate
61
+ for t in T:
62
+ data, _ = js.model.step(model=model, data=data, link_forces=None, joint_force_references=tau)
20
63
 
21
- - Physics engine in reduced coordinates supporting fixed-base and floating-base robots.
22
- - Multibody dynamics library providing all the necessary components for developing model-based control algorithms.
23
- - Completely developed in Python with [`google/jax`][jax] following a functional programming paradigm.
24
- - Transparent support for running on CPUs, GPUs, and TPUs.
25
- - Full support for JIT compilation for increased performance.
26
- - Full support for automatic vectorization for massive parallelization of open-loop and closed-loop architectures.
27
- - Support for SDF models and, upon conversion with [sdformat][sdformat], URDF models.
28
- - Visualization based on the [passive viewer][passive_viewer_mujoco] of Mujoco.
29
-
30
- ### JaxSim as a simulator
31
-
32
- - Wide range of fixed-step explicit Runge-Kutta integrators.
33
- - Support for variable-step integrators implemented as embedded Runge-Kutta schemes.
34
- - Improved stability by optionally integrating the base orientation on the $\text{SO}(3)$ manifold.
35
- - Soft contacts model supporting full friction cone and sticking-slipping transition.
36
- - Collision detection between points rigidly attached to bodies and uneven ground surfaces.
37
-
38
- ### JaxSim as a multibody dynamics library
64
+ ```
39
65
 
40
- - Provides rigid body dynamics algorithms (RBDAs) like RNEA, ABA, CRBA, and Jacobians.
41
- - Provides all the quantities included in the Euler-Poincarè formulation of the equations of motion.
42
- - Supports body-fixed, inertial-fixed, and mixed [velocity representations][notation].
43
- - Exposes all the necessary quantities to develop controllers in centroidal coordinates.
44
- - Supports running open-loop and full closed-loop control architectures on hardware accelerators.
66
+ ### Using JaxSim as a multibody dynamics library
67
+ ``` python
68
+ import jax.numpy as jnp
69
+ import jaxsim.api as js
70
+ import icub_models
71
+ import pathlib
72
+
73
+ # Load the iCub model
74
+ model_path = icub_models.get_model_file("iCubGazeboV2_5")
75
+ joints = ('torso_pitch', 'torso_roll', 'torso_yaw', 'l_shoulder_pitch',
76
+ 'l_shoulder_roll', 'l_shoulder_yaw', 'l_elbow', 'r_shoulder_pitch',
77
+ 'r_shoulder_roll', 'r_shoulder_yaw', 'r_elbow', 'l_hip_pitch',
78
+ 'l_hip_roll', 'l_hip_yaw', 'l_knee', 'l_ankle_pitch', 'l_ankle_roll',
79
+ 'r_hip_pitch', 'r_hip_roll', 'r_hip_yaw', 'r_knee', 'r_ankle_pitch',
80
+ 'r_ankle_roll')
81
+
82
+ # Build and reduce the model
83
+ model_description = pathlib.Path(model_path)
84
+ full_model = js.model.JaxSimModel.build_from_model_description(
85
+ model_description=model_description, time_step=0.0001, is_urdf=True
86
+ )
87
+ model = js.model.reduce(model=full_model, considered_joints=joints)
88
+
89
+ # Initialize model data
90
+ data = js.data.JaxSimModelData.build(
91
+ model=model,
92
+ base_position=jnp.array([0.0, 0.0, 1.0]),
93
+ )
94
+
95
+ # Frame and dynamics computations
96
+ frame_index = js.frame.name_to_idx(model=model, frame_name="l_foot")
97
+ W_H_F = js.frame.transform(model=model, data=data, frame_index=frame_index) # Frame transformation
98
+ W_J_F = js.frame.jacobian(model=model, data=data, frame_index=frame_index) # Frame Jacobian
99
+
100
+ # Dynamics properties
101
+ M = js.model.free_floating_mass_matrix(model=model, data=data) # Mass matrix
102
+ h = js.model.free_floating_bias_forces(model=model, data=data) # Bias forces
103
+ g = js.model.free_floating_gravity_forces(model=model, data=data) # Gravity forces
104
+ C = js.model.free_floating_coriolis_matrix(model=model, data=data) # Coriolis matrix
105
+
106
+ # Print dynamics results
107
+ print(f"M: shape={M.shape}, h: shape={h.shape}, g: shape={g.shape}, C: shape={C.shape}")
45
108
 
46
- ### JaxSim for robot learning
109
+ ```
110
+ ### Additional features
47
111
 
48
- - Being developed with JAX, all the RBDAs support automatic differentiation both in forward and reverse modes.
112
+ - Full support for automatic differentiation of RBDAs (forward and reverse modes) with JAX.
49
113
  - Support for automatically differentiating against kinematics and dynamics parameters.
50
114
  - All fixed-step integrators are forward and reverse differentiable.
51
115
  - All variable-step integrators are forward differentiable.
52
- - Ideal for sampling synthetic data for reinforcement learning (RL).
53
- - Ideal for designing physics-informed neural networks (PINNs) with loss functions requiring model-based quantities.
54
- - Ideal for combining model-based control with learning-based components.
116
+ - Check the examples folder for additional use cases!
55
117
 
56
118
  [jax]: https://github.com/google/jax/
57
119
  [sdformat]: https://github.com/gazebosim/sdformat
@@ -65,12 +127,6 @@ Its design facilitates research and accelerates prototyping in the intersection
65
127
  > JaxSim currently focuses on locomotion applications.
66
128
  > Only contacts between bodies and smooth ground surfaces are supported.
67
129
 
68
- ## Documentation
69
-
70
- The JaxSim API documentation is available at [jaxsim.readthedocs.io][readthedocs].
71
-
72
- [readthedocs]: https://jaxsim.readthedocs.io/
73
-
74
130
  ## Installation
75
131
 
76
132
  <details>
@@ -142,6 +198,13 @@ pip install --no-deps -e .
142
198
  [venv]: https://docs.python.org/3/tutorial/venv.html
143
199
  [jax_gpu]: https://github.com/google/jax/#installation
144
200
 
201
+ ## Documentation
202
+
203
+ The JaxSim API documentation is available at [jaxsim.readthedocs.io][readthedocs].
204
+
205
+ [readthedocs]: https://jaxsim.readthedocs.io/
206
+
207
+
145
208
  ## Overview
146
209
 
147
210
  <details>
@@ -29,6 +29,7 @@ dependencies:
29
29
  - pytest
30
30
  - pytest-icdiff
31
31
  - robot_descriptions
32
+ - icub-models
32
33
  # [viz]
33
34
  - lxml
34
35
  - mediapy
@@ -55,6 +56,7 @@ dependencies:
55
56
  - sphinx-multiversion
56
57
  - sphinx_rtd_theme
57
58
  - sphinx-toolbox
59
+ - icub-models
58
60
  # ========================================
59
61
  # Other dependencies for GitHub Codespaces
60
62
  # ========================================
@@ -8,11 +8,13 @@ This folder contains Jupyter notebooks that demonstrate the practical usage of J
8
8
  | :--- | :---: | :--- |
9
9
  | [`jaxsim_as_multibody_dynamics_library`](./jaxsim_as_multibody_dynamics_library.ipynb) | [![Open In Colab][colab_badge]][ipynb_jaxsim_as_multibody_dynamics] | An example demonstrating how to use JaxSim as a multibody dynamics library. |
10
10
  | [`jaxsim_as_physics_engine.ipynb`](./jaxsim_as_physics_engine.ipynb) | [![Open In Colab][colab_badge]][ipynb_jaxsim_as_physics_engine] | An example demonstrating how to simulate vectorized models in parallel. |
11
+ | [`jaxsim_as_physics_engine_advanced.ipynb`](./jaxsim_as_physics_engine_advanced.ipynb) | [![Open In Colab][colab_badge]][jaxsim_as_physics_engine_advanced] | An example showcasing advanced JaxSim usage, such as customizing the integrator, contact model, and more. |
11
12
  | [`jaxsim_for_robot_controllers.ipynb`](./jaxsim_for_robot_controllers.ipynb) | [![Open In Colab][colab_badge]][ipynb_jaxsim_closed_loop] | A basic example showing how to simulate a PD controller with gravity compensation for a 2-DOF cart-pole. |
12
13
 
13
14
  [colab_badge]: https://colab.research.google.com/assets/colab-badge.svg
14
15
  [ipynb_jaxsim_closed_loop]: https://colab.research.google.com/github/ami-iit/jaxsim/blob/main/examples/jaxsim_for_robot_controllers.ipynb
15
16
  [ipynb_jaxsim_as_physics_engine]: https://colab.research.google.com/github/ami-iit/jaxsim/blob/main/examples/jaxsim_as_physics_engine.ipynb
17
+ [jaxsim_as_physics_engine_advanced]: https://colab.research.google.com/github/ami-iit/jaxsim/blob/main/examples/jaxsim_as_physics_engine_advanced.ipynb
16
18
  [ipynb_jaxsim_as_multibody_dynamics]: https://colab.research.google.com/github/ami-iit/jaxsim/blob/main/examples/jaxsim_as_multibody_dynamics_library.ipynb
17
19
 
18
20
  ## How to run the examples
@@ -0,0 +1,274 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "metadata": {
6
+ "id": "H-WgcgGQaTG7"
7
+ },
8
+ "source": [
9
+ "# JaxSim as a hardware-accelerated parallel physics engine\n",
10
+ "\n",
11
+ "This notebook shows how to use the key APIs to load a robot model and simulate multiple trajectories simultaneously.\n",
12
+ "\n",
13
+ "<a target=\"_blank\" href=\"https://colab.research.google.com/github/ami-iit/jaxsim/blob/main/examples/jaxsim_as_physics_engine.ipynb\">\n",
14
+ " <img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/>\n",
15
+ "</a>"
16
+ ]
17
+ },
18
+ {
19
+ "cell_type": "markdown",
20
+ "metadata": {
21
+ "id": "SgOSnrSscEkt"
22
+ },
23
+ "source": [
24
+ "## Prepare the environment"
25
+ ]
26
+ },
27
+ {
28
+ "cell_type": "code",
29
+ "execution_count": null,
30
+ "metadata": {
31
+ "id": "fdqvAqMDaTG9"
32
+ },
33
+ "outputs": [],
34
+ "source": [
35
+ "# @title Imports and setup\n",
36
+ "import sys\n",
37
+ "from IPython.display import clear_output\n",
38
+ "\n",
39
+ "IS_COLAB = \"google.colab\" in sys.modules\n",
40
+ "\n",
41
+ "# Install JAX and Gazebo\n",
42
+ "if IS_COLAB:\n",
43
+ " !{sys.executable} -m pip install --pre -qU jaxsim\n",
44
+ " !apt install -qq lsb-release wget gnupg\n",
45
+ " !wget https://packages.osrfoundation.org/gazebo.gpg -O /usr/share/keyrings/pkgs-osrf-archive-keyring.gpg\n",
46
+ " !echo \"deb [arch=$(dpkg --print-architecture) signed-by=/usr/share/keyrings/pkgs-osrf-archive-keyring.gpg] http://packages.osrfoundation.org/gazebo/ubuntu-stable $(lsb_release -cs) main\" | sudo tee /etc/apt/sources.list.d/gazebo-stable.list > /dev/null\n",
47
+ " !apt -qq update\n",
48
+ " !apt install -qq --no-install-recommends libsdformat13 gz-tools2\n",
49
+ "\n",
50
+ " clear_output()\n",
51
+ "\n",
52
+ "# Set environment variable to avoid GPU out of memory errors\n",
53
+ "%env XLA_PYTHON_CLIENT_MEM_PREALLOCATE=false\n",
54
+ "\n",
55
+ "# ================\n",
56
+ "# Notebook imports\n",
57
+ "# ================\n",
58
+ "\n",
59
+ "import jax\n",
60
+ "import jax.numpy as jnp\n",
61
+ "import jaxsim.api as js\n",
62
+ "from jaxsim import logging\n",
63
+ "import pathlib\n",
64
+ "import urllib.request\n",
65
+ "\n",
66
+ "logging.set_logging_level(logging.LoggingLevel.WARNING)\n",
67
+ "print(f\"Running on {jax.devices()}\")"
68
+ ]
69
+ },
70
+ {
71
+ "cell_type": "markdown",
72
+ "metadata": {
73
+ "id": "NqjuZKvOaTG_"
74
+ },
75
+ "source": [
76
+ "## Prepare the simulation\n",
77
+ "\n",
78
+ "JaxSim supports loading robot descriptions from both [SDF][sdformat] and [URDF][urdf] files. In this example, we will load the [ergoCub][ergocub] model urdf.\n",
79
+ "\n",
80
+ "[sdformat]: http://sdformat.org/\n",
81
+ "[urdf]: http://wiki.ros.org/urdf/\n",
82
+ "[ergocub]: https://ergocub.eu/\n",
83
+ "[rod]: https://github.com/ami-iit/rod\n",
84
+ "\n",
85
+ "### Create the model and its data\n",
86
+ " To define a simulation we need two main objects:\n",
87
+ "\n",
88
+ "- `model`: an object that defines the dynamics of the system.\n",
89
+ "- `data`: an object that contains the state of the system.\n",
90
+ "\n",
91
+ "\n",
92
+ "The `JaxSimModel` object contains the simulation time step, the integrator and the contact model.\n",
93
+ "To see the advanced usage, check the advanced example, where you will see how to pass explicitly an integrator class and state to the `model` object and how to change the contact model."
94
+ ]
95
+ },
96
+ {
97
+ "cell_type": "markdown",
98
+ "metadata": {},
99
+ "source": [
100
+ "### Create the model "
101
+ ]
102
+ },
103
+ {
104
+ "cell_type": "code",
105
+ "execution_count": null,
106
+ "metadata": {
107
+ "id": "etQ577cFaTHA"
108
+ },
109
+ "outputs": [],
110
+ "source": [
111
+ "# Create the JaxSim model.\n",
112
+ "url = \"https://raw.githubusercontent.com/icub-tech-iit/ergocub-software/refs/heads/master/urdf/ergoCub/robots/ergoCubSN001/model.urdf\"\n",
113
+ "\n",
114
+ "# Retrieve the file\n",
115
+ "model_path, _ = urllib.request.urlretrieve(url)\n",
116
+ "\n",
117
+ "model_description_path = pathlib.Path(model_path)\n",
118
+ "full_model = js.model.JaxSimModel.build_from_model_description(\n",
119
+ " model_description=model_description_path,\n",
120
+ " time_step=0.0001,\n",
121
+ " is_urdf=True\n",
122
+ ")\n",
123
+ "\n",
124
+ "joints_list = tuple(('l_shoulder_pitch', 'l_shoulder_roll', 'l_shoulder_yaw', 'l_elbow',\n",
125
+ " 'r_shoulder_pitch', 'r_shoulder_roll', 'r_shoulder_yaw', 'r_elbow',\n",
126
+ " 'l_hip_pitch', 'l_hip_roll', 'l_hip_yaw', 'l_knee', 'l_ankle_pitch', 'l_ankle_roll',\n",
127
+ " 'r_hip_pitch', 'r_hip_roll', 'r_hip_yaw', 'r_knee', 'r_ankle_pitch', 'r_ankle_roll'))\n",
128
+ "\n",
129
+ "model = js.model.reduce(\n",
130
+ " model=full_model,\n",
131
+ " considered_joints=joints_list\n",
132
+ ")\n"
133
+ ]
134
+ },
135
+ {
136
+ "cell_type": "markdown",
137
+ "metadata": {},
138
+ "source": [
139
+ "### Create the data object \n",
140
+ "\n",
141
+ "The data object is never changed by reference. Anytime you call a method aimed at modifying data, like `reset_base_position`, a new data object will be returned with the updated attributes while the original data will not be changed."
142
+ ]
143
+ },
144
+ {
145
+ "cell_type": "code",
146
+ "execution_count": null,
147
+ "metadata": {},
148
+ "outputs": [],
149
+ "source": [
150
+ "# Create the data of a single model.\n",
151
+ "data_zero = js.data.JaxSimModelData.zero(model=model)\n",
152
+ "base_position = jnp.array([0.0, 0.0, 1.0])\n",
153
+ "data = data_zero.reset_base_position(base_position=base_position) # Note that the reset position returns the updated data object"
154
+ ]
155
+ },
156
+ {
157
+ "cell_type": "markdown",
158
+ "metadata": {},
159
+ "source": [
160
+ "### Simulation"
161
+ ]
162
+ },
163
+ {
164
+ "cell_type": "code",
165
+ "execution_count": null,
166
+ "metadata": {},
167
+ "outputs": [],
168
+ "source": [
169
+ "# Create a random JAX key.\n",
170
+ "\n",
171
+ "key = jax.random.PRNGKey(seed=0)\n",
172
+ "\n",
173
+ "# Initialize the simulated time.\n",
174
+ "T = jnp.arange(start=0, stop=0.3, step=model.time_step)\n",
175
+ "\n",
176
+ "# Simulate\n",
177
+ "for _t in T:\n",
178
+ " data, _ = js.model.step(\n",
179
+ " model=model,\n",
180
+ " data=data,\n",
181
+ " link_forces=None,\n",
182
+ " joint_force_references=None,\n",
183
+ " )"
184
+ ]
185
+ },
186
+ {
187
+ "cell_type": "markdown",
188
+ "metadata": {},
189
+ "source": [
190
+ "### Vectorized simulation \n",
191
+ "\n",
192
+ "We will now vectorize the simulation on batched data using `jax.vmap`"
193
+ ]
194
+ },
195
+ {
196
+ "cell_type": "code",
197
+ "execution_count": null,
198
+ "metadata": {},
199
+ "outputs": [],
200
+ "source": [
201
+ "# first we have to vmap the function\n",
202
+ "\n",
203
+ "import functools\n",
204
+ "from typing import Any\n",
205
+ "\n",
206
+ "\n",
207
+ "@jax.jit\n",
208
+ "def step_single(\n",
209
+ " model: js.model.JaxSimModel,\n",
210
+ " data: js.data.JaxSimModelData,\n",
211
+ ") -> tuple[js.data.JaxSimModelData, dict[str, Any]]:\n",
212
+ "\n",
213
+ " # Close step over static arguments.\n",
214
+ " return js.model.step(\n",
215
+ " model=model,\n",
216
+ " data=data,\n",
217
+ " link_forces=None,\n",
218
+ " joint_force_references=None,\n",
219
+ " )\n",
220
+ "\n",
221
+ "\n",
222
+ "@jax.jit\n",
223
+ "@functools.partial(jax.vmap, in_axes=(None, 0))\n",
224
+ "def step_parallel(\n",
225
+ " model: js.model.JaxSimModel,\n",
226
+ " data: js.data.JaxSimModelData,\n",
227
+ ") -> tuple[js.data.JaxSimModelData, dict[str, Any]]:\n",
228
+ "\n",
229
+ " return step_single(\n",
230
+ " model=model, data=data\n",
231
+ " )\n",
232
+ "\n",
233
+ "\n",
234
+ "# Then we have to create the vector of initial state\n",
235
+ "batch_size = 5\n",
236
+ "data_batch_t0 = jax.vmap(\n",
237
+ " lambda pos: data_zero.reset_base_position(base_position=pos)\n",
238
+ ")(jnp.tile(jnp.array([0.0, 0.0, 1.0]), (batch_size, 1)))\n",
239
+ "\n",
240
+ "data = data_batch_t0\n",
241
+ "for _t in T:\n",
242
+ " data, _ = step_parallel(model, data)"
243
+ ]
244
+ }
245
+ ],
246
+ "metadata": {
247
+ "accelerator": "GPU",
248
+ "colab": {
249
+ "gpuClass": "premium",
250
+ "private_outputs": true,
251
+ "provenance": [],
252
+ "toc_visible": true
253
+ },
254
+ "kernelspec": {
255
+ "display_name": "jaxsim",
256
+ "language": "python",
257
+ "name": "python3"
258
+ },
259
+ "language_info": {
260
+ "codemirror_mode": {
261
+ "name": "ipython",
262
+ "version": 3
263
+ },
264
+ "file_extension": ".py",
265
+ "mimetype": "text/x-python",
266
+ "name": "python",
267
+ "nbconvert_exporter": "python",
268
+ "pygments_lexer": "ipython3",
269
+ "version": "3.13.0"
270
+ }
271
+ },
272
+ "nbformat": 4,
273
+ "nbformat_minor": 0
274
+ }