jaxsim 0.4.3.dev245__tar.gz → 0.4.3.dev271__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (129) hide show
  1. {jaxsim-0.4.3.dev245 → jaxsim-0.4.3.dev271}/PKG-INFO +1 -1
  2. {jaxsim-0.4.3.dev245 → jaxsim-0.4.3.dev271}/examples/jaxsim_as_physics_engine.ipynb +27 -35
  3. {jaxsim-0.4.3.dev245 → jaxsim-0.4.3.dev271}/examples/jaxsim_for_robot_controllers.ipynb +7 -27
  4. {jaxsim-0.4.3.dev245 → jaxsim-0.4.3.dev271}/src/jaxsim/_version.py +2 -2
  5. {jaxsim-0.4.3.dev245 → jaxsim-0.4.3.dev271}/src/jaxsim/api/model.py +92 -23
  6. {jaxsim-0.4.3.dev245 → jaxsim-0.4.3.dev271}/src/jaxsim/api/ode.py +26 -22
  7. {jaxsim-0.4.3.dev245 → jaxsim-0.4.3.dev271}/src/jaxsim/integrators/common.py +27 -76
  8. {jaxsim-0.4.3.dev245 → jaxsim-0.4.3.dev271}/src/jaxsim/integrators/variable_step.py +96 -61
  9. {jaxsim-0.4.3.dev245 → jaxsim-0.4.3.dev271}/src/jaxsim.egg-info/PKG-INFO +1 -1
  10. {jaxsim-0.4.3.dev245 → jaxsim-0.4.3.dev271}/tests/test_automatic_differentiation.py +0 -20
  11. {jaxsim-0.4.3.dev245 → jaxsim-0.4.3.dev271}/tests/test_simulations.py +7 -57
  12. {jaxsim-0.4.3.dev245 → jaxsim-0.4.3.dev271}/.devcontainer/Dockerfile +0 -0
  13. {jaxsim-0.4.3.dev245 → jaxsim-0.4.3.dev271}/.devcontainer/devcontainer.json +0 -0
  14. {jaxsim-0.4.3.dev245 → jaxsim-0.4.3.dev271}/.gitattributes +0 -0
  15. {jaxsim-0.4.3.dev245 → jaxsim-0.4.3.dev271}/.github/CODEOWNERS +0 -0
  16. {jaxsim-0.4.3.dev245 → jaxsim-0.4.3.dev271}/.github/dependabot.yml +0 -0
  17. {jaxsim-0.4.3.dev245 → jaxsim-0.4.3.dev271}/.github/workflows/ci_cd.yml +0 -0
  18. {jaxsim-0.4.3.dev245 → jaxsim-0.4.3.dev271}/.github/workflows/read_the_docs.yml +0 -0
  19. {jaxsim-0.4.3.dev245 → jaxsim-0.4.3.dev271}/.github/workflows/update_pixi_lockfile.yml +0 -0
  20. {jaxsim-0.4.3.dev245 → jaxsim-0.4.3.dev271}/.gitignore +0 -0
  21. {jaxsim-0.4.3.dev245 → jaxsim-0.4.3.dev271}/.pre-commit-config.yaml +0 -0
  22. {jaxsim-0.4.3.dev245 → jaxsim-0.4.3.dev271}/.readthedocs.yaml +0 -0
  23. {jaxsim-0.4.3.dev245 → jaxsim-0.4.3.dev271}/CONTRIBUTING.md +0 -0
  24. {jaxsim-0.4.3.dev245 → jaxsim-0.4.3.dev271}/LICENSE +0 -0
  25. {jaxsim-0.4.3.dev245 → jaxsim-0.4.3.dev271}/README.md +0 -0
  26. {jaxsim-0.4.3.dev245 → jaxsim-0.4.3.dev271}/docs/Makefile +0 -0
  27. {jaxsim-0.4.3.dev245 → jaxsim-0.4.3.dev271}/docs/conf.py +0 -0
  28. {jaxsim-0.4.3.dev245 → jaxsim-0.4.3.dev271}/docs/examples.rst +0 -0
  29. {jaxsim-0.4.3.dev245 → jaxsim-0.4.3.dev271}/docs/guide/install.rst +0 -0
  30. {jaxsim-0.4.3.dev245 → jaxsim-0.4.3.dev271}/docs/index.rst +0 -0
  31. {jaxsim-0.4.3.dev245 → jaxsim-0.4.3.dev271}/docs/make.bat +0 -0
  32. {jaxsim-0.4.3.dev245 → jaxsim-0.4.3.dev271}/docs/modules/api.rst +0 -0
  33. {jaxsim-0.4.3.dev245 → jaxsim-0.4.3.dev271}/docs/modules/integrators.rst +0 -0
  34. {jaxsim-0.4.3.dev245 → jaxsim-0.4.3.dev271}/docs/modules/math.rst +0 -0
  35. {jaxsim-0.4.3.dev245 → jaxsim-0.4.3.dev271}/docs/modules/mujoco.rst +0 -0
  36. {jaxsim-0.4.3.dev245 → jaxsim-0.4.3.dev271}/docs/modules/parsers.rst +0 -0
  37. {jaxsim-0.4.3.dev245 → jaxsim-0.4.3.dev271}/docs/modules/rbda.rst +0 -0
  38. {jaxsim-0.4.3.dev245 → jaxsim-0.4.3.dev271}/docs/modules/typing.rst +0 -0
  39. {jaxsim-0.4.3.dev245 → jaxsim-0.4.3.dev271}/docs/modules/utils.rst +0 -0
  40. {jaxsim-0.4.3.dev245 → jaxsim-0.4.3.dev271}/environment.yml +0 -0
  41. {jaxsim-0.4.3.dev245 → jaxsim-0.4.3.dev271}/examples/.gitattributes +0 -0
  42. {jaxsim-0.4.3.dev245 → jaxsim-0.4.3.dev271}/examples/.gitignore +0 -0
  43. {jaxsim-0.4.3.dev245 → jaxsim-0.4.3.dev271}/examples/README.md +0 -0
  44. {jaxsim-0.4.3.dev245 → jaxsim-0.4.3.dev271}/examples/assets/build_cartpole_urdf.py +0 -0
  45. {jaxsim-0.4.3.dev245 → jaxsim-0.4.3.dev271}/examples/assets/cartpole.urdf +0 -0
  46. {jaxsim-0.4.3.dev245 → jaxsim-0.4.3.dev271}/examples/jaxsim_as_multibody_dynamics_library.ipynb +0 -0
  47. {jaxsim-0.4.3.dev245 → jaxsim-0.4.3.dev271}/pixi.lock +0 -0
  48. {jaxsim-0.4.3.dev245 → jaxsim-0.4.3.dev271}/pyproject.toml +0 -0
  49. {jaxsim-0.4.3.dev245 → jaxsim-0.4.3.dev271}/setup.cfg +0 -0
  50. {jaxsim-0.4.3.dev245 → jaxsim-0.4.3.dev271}/setup.py +0 -0
  51. {jaxsim-0.4.3.dev245 → jaxsim-0.4.3.dev271}/src/jaxsim/__init__.py +0 -0
  52. {jaxsim-0.4.3.dev245 → jaxsim-0.4.3.dev271}/src/jaxsim/api/__init__.py +0 -0
  53. {jaxsim-0.4.3.dev245 → jaxsim-0.4.3.dev271}/src/jaxsim/api/com.py +0 -0
  54. {jaxsim-0.4.3.dev245 → jaxsim-0.4.3.dev271}/src/jaxsim/api/common.py +0 -0
  55. {jaxsim-0.4.3.dev245 → jaxsim-0.4.3.dev271}/src/jaxsim/api/contact.py +0 -0
  56. {jaxsim-0.4.3.dev245 → jaxsim-0.4.3.dev271}/src/jaxsim/api/data.py +0 -0
  57. {jaxsim-0.4.3.dev245 → jaxsim-0.4.3.dev271}/src/jaxsim/api/frame.py +0 -0
  58. {jaxsim-0.4.3.dev245 → jaxsim-0.4.3.dev271}/src/jaxsim/api/joint.py +0 -0
  59. {jaxsim-0.4.3.dev245 → jaxsim-0.4.3.dev271}/src/jaxsim/api/kin_dyn_parameters.py +0 -0
  60. {jaxsim-0.4.3.dev245 → jaxsim-0.4.3.dev271}/src/jaxsim/api/link.py +0 -0
  61. {jaxsim-0.4.3.dev245 → jaxsim-0.4.3.dev271}/src/jaxsim/api/ode_data.py +0 -0
  62. {jaxsim-0.4.3.dev245 → jaxsim-0.4.3.dev271}/src/jaxsim/api/references.py +0 -0
  63. {jaxsim-0.4.3.dev245 → jaxsim-0.4.3.dev271}/src/jaxsim/exceptions.py +0 -0
  64. {jaxsim-0.4.3.dev245 → jaxsim-0.4.3.dev271}/src/jaxsim/integrators/__init__.py +0 -0
  65. {jaxsim-0.4.3.dev245 → jaxsim-0.4.3.dev271}/src/jaxsim/integrators/fixed_step.py +0 -0
  66. {jaxsim-0.4.3.dev245 → jaxsim-0.4.3.dev271}/src/jaxsim/logging.py +0 -0
  67. {jaxsim-0.4.3.dev245 → jaxsim-0.4.3.dev271}/src/jaxsim/math/__init__.py +0 -0
  68. {jaxsim-0.4.3.dev245 → jaxsim-0.4.3.dev271}/src/jaxsim/math/adjoint.py +0 -0
  69. {jaxsim-0.4.3.dev245 → jaxsim-0.4.3.dev271}/src/jaxsim/math/cross.py +0 -0
  70. {jaxsim-0.4.3.dev245 → jaxsim-0.4.3.dev271}/src/jaxsim/math/inertia.py +0 -0
  71. {jaxsim-0.4.3.dev245 → jaxsim-0.4.3.dev271}/src/jaxsim/math/joint_model.py +0 -0
  72. {jaxsim-0.4.3.dev245 → jaxsim-0.4.3.dev271}/src/jaxsim/math/quaternion.py +0 -0
  73. {jaxsim-0.4.3.dev245 → jaxsim-0.4.3.dev271}/src/jaxsim/math/rotation.py +0 -0
  74. {jaxsim-0.4.3.dev245 → jaxsim-0.4.3.dev271}/src/jaxsim/math/skew.py +0 -0
  75. {jaxsim-0.4.3.dev245 → jaxsim-0.4.3.dev271}/src/jaxsim/math/transform.py +0 -0
  76. {jaxsim-0.4.3.dev245 → jaxsim-0.4.3.dev271}/src/jaxsim/mujoco/__init__.py +0 -0
  77. {jaxsim-0.4.3.dev245 → jaxsim-0.4.3.dev271}/src/jaxsim/mujoco/__main__.py +0 -0
  78. {jaxsim-0.4.3.dev245 → jaxsim-0.4.3.dev271}/src/jaxsim/mujoco/loaders.py +0 -0
  79. {jaxsim-0.4.3.dev245 → jaxsim-0.4.3.dev271}/src/jaxsim/mujoco/model.py +0 -0
  80. {jaxsim-0.4.3.dev245 → jaxsim-0.4.3.dev271}/src/jaxsim/mujoco/utils.py +0 -0
  81. {jaxsim-0.4.3.dev245 → jaxsim-0.4.3.dev271}/src/jaxsim/mujoco/visualizer.py +0 -0
  82. {jaxsim-0.4.3.dev245 → jaxsim-0.4.3.dev271}/src/jaxsim/parsers/__init__.py +0 -0
  83. {jaxsim-0.4.3.dev245 → jaxsim-0.4.3.dev271}/src/jaxsim/parsers/descriptions/__init__.py +0 -0
  84. {jaxsim-0.4.3.dev245 → jaxsim-0.4.3.dev271}/src/jaxsim/parsers/descriptions/collision.py +0 -0
  85. {jaxsim-0.4.3.dev245 → jaxsim-0.4.3.dev271}/src/jaxsim/parsers/descriptions/joint.py +0 -0
  86. {jaxsim-0.4.3.dev245 → jaxsim-0.4.3.dev271}/src/jaxsim/parsers/descriptions/link.py +0 -0
  87. {jaxsim-0.4.3.dev245 → jaxsim-0.4.3.dev271}/src/jaxsim/parsers/descriptions/model.py +0 -0
  88. {jaxsim-0.4.3.dev245 → jaxsim-0.4.3.dev271}/src/jaxsim/parsers/kinematic_graph.py +0 -0
  89. {jaxsim-0.4.3.dev245 → jaxsim-0.4.3.dev271}/src/jaxsim/parsers/rod/__init__.py +0 -0
  90. {jaxsim-0.4.3.dev245 → jaxsim-0.4.3.dev271}/src/jaxsim/parsers/rod/parser.py +0 -0
  91. {jaxsim-0.4.3.dev245 → jaxsim-0.4.3.dev271}/src/jaxsim/parsers/rod/utils.py +0 -0
  92. {jaxsim-0.4.3.dev245 → jaxsim-0.4.3.dev271}/src/jaxsim/rbda/__init__.py +0 -0
  93. {jaxsim-0.4.3.dev245 → jaxsim-0.4.3.dev271}/src/jaxsim/rbda/aba.py +0 -0
  94. {jaxsim-0.4.3.dev245 → jaxsim-0.4.3.dev271}/src/jaxsim/rbda/collidable_points.py +0 -0
  95. {jaxsim-0.4.3.dev245 → jaxsim-0.4.3.dev271}/src/jaxsim/rbda/contacts/__init__.py +0 -0
  96. {jaxsim-0.4.3.dev245 → jaxsim-0.4.3.dev271}/src/jaxsim/rbda/contacts/common.py +0 -0
  97. {jaxsim-0.4.3.dev245 → jaxsim-0.4.3.dev271}/src/jaxsim/rbda/contacts/relaxed_rigid.py +0 -0
  98. {jaxsim-0.4.3.dev245 → jaxsim-0.4.3.dev271}/src/jaxsim/rbda/contacts/rigid.py +0 -0
  99. {jaxsim-0.4.3.dev245 → jaxsim-0.4.3.dev271}/src/jaxsim/rbda/contacts/soft.py +0 -0
  100. {jaxsim-0.4.3.dev245 → jaxsim-0.4.3.dev271}/src/jaxsim/rbda/contacts/visco_elastic.py +0 -0
  101. {jaxsim-0.4.3.dev245 → jaxsim-0.4.3.dev271}/src/jaxsim/rbda/crba.py +0 -0
  102. {jaxsim-0.4.3.dev245 → jaxsim-0.4.3.dev271}/src/jaxsim/rbda/forward_kinematics.py +0 -0
  103. {jaxsim-0.4.3.dev245 → jaxsim-0.4.3.dev271}/src/jaxsim/rbda/jacobian.py +0 -0
  104. {jaxsim-0.4.3.dev245 → jaxsim-0.4.3.dev271}/src/jaxsim/rbda/rnea.py +0 -0
  105. {jaxsim-0.4.3.dev245 → jaxsim-0.4.3.dev271}/src/jaxsim/rbda/utils.py +0 -0
  106. {jaxsim-0.4.3.dev245 → jaxsim-0.4.3.dev271}/src/jaxsim/terrain/__init__.py +0 -0
  107. {jaxsim-0.4.3.dev245 → jaxsim-0.4.3.dev271}/src/jaxsim/terrain/terrain.py +0 -0
  108. {jaxsim-0.4.3.dev245 → jaxsim-0.4.3.dev271}/src/jaxsim/typing.py +0 -0
  109. {jaxsim-0.4.3.dev245 → jaxsim-0.4.3.dev271}/src/jaxsim/utils/__init__.py +0 -0
  110. {jaxsim-0.4.3.dev245 → jaxsim-0.4.3.dev271}/src/jaxsim/utils/jaxsim_dataclass.py +0 -0
  111. {jaxsim-0.4.3.dev245 → jaxsim-0.4.3.dev271}/src/jaxsim/utils/tracing.py +0 -0
  112. {jaxsim-0.4.3.dev245 → jaxsim-0.4.3.dev271}/src/jaxsim/utils/wrappers.py +0 -0
  113. {jaxsim-0.4.3.dev245 → jaxsim-0.4.3.dev271}/src/jaxsim.egg-info/SOURCES.txt +0 -0
  114. {jaxsim-0.4.3.dev245 → jaxsim-0.4.3.dev271}/src/jaxsim.egg-info/dependency_links.txt +0 -0
  115. {jaxsim-0.4.3.dev245 → jaxsim-0.4.3.dev271}/src/jaxsim.egg-info/requires.txt +0 -0
  116. {jaxsim-0.4.3.dev245 → jaxsim-0.4.3.dev271}/src/jaxsim.egg-info/top_level.txt +0 -0
  117. {jaxsim-0.4.3.dev245 → jaxsim-0.4.3.dev271}/tests/__init__.py +0 -0
  118. {jaxsim-0.4.3.dev245 → jaxsim-0.4.3.dev271}/tests/conftest.py +0 -0
  119. {jaxsim-0.4.3.dev245 → jaxsim-0.4.3.dev271}/tests/test_api_com.py +0 -0
  120. {jaxsim-0.4.3.dev245 → jaxsim-0.4.3.dev271}/tests/test_api_contact.py +0 -0
  121. {jaxsim-0.4.3.dev245 → jaxsim-0.4.3.dev271}/tests/test_api_data.py +0 -0
  122. {jaxsim-0.4.3.dev245 → jaxsim-0.4.3.dev271}/tests/test_api_frame.py +0 -0
  123. {jaxsim-0.4.3.dev245 → jaxsim-0.4.3.dev271}/tests/test_api_joint.py +0 -0
  124. {jaxsim-0.4.3.dev245 → jaxsim-0.4.3.dev271}/tests/test_api_link.py +0 -0
  125. {jaxsim-0.4.3.dev245 → jaxsim-0.4.3.dev271}/tests/test_api_model.py +0 -0
  126. {jaxsim-0.4.3.dev245 → jaxsim-0.4.3.dev271}/tests/test_contact.py +0 -0
  127. {jaxsim-0.4.3.dev245 → jaxsim-0.4.3.dev271}/tests/test_exceptions.py +0 -0
  128. {jaxsim-0.4.3.dev245 → jaxsim-0.4.3.dev271}/tests/test_pytree.py +0 -0
  129. {jaxsim-0.4.3.dev245 → jaxsim-0.4.3.dev271}/tests/utils_idyntree.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: jaxsim
3
- Version: 0.4.3.dev245
3
+ Version: 0.4.3.dev271
4
4
  Summary: A differentiable physics engine and multibody dynamics library for control and robot learning.
5
5
  Author-email: Diego Ferigo <dgferigo@gmail.com>
6
6
  Maintainer-email: Diego Ferigo <dgferigo@gmail.com>, Filippo Luca Ferretti <filippo.ferretti@iit.it>
@@ -63,8 +63,9 @@
63
63
  "import jax\n",
64
64
  "import jax.numpy as jnp\n",
65
65
  "import jaxsim.api as js\n",
66
+ "import jaxsim\n",
66
67
  "import rod\n",
67
- "from jaxsim import integrators, logging\n",
68
+ "from jaxsim import logging\n",
68
69
  "from rod.builder.primitives import SphereBuilder\n",
69
70
  "\n",
70
71
  "logging.set_logging_level(logging.LoggingLevel.WARNING)\n",
@@ -142,8 +143,11 @@
142
143
  "\n",
143
144
  "- `model`: an object that defines the dynamics of the system.\n",
144
145
  "- `data`: an object that contains the state of the system.\n",
145
- "- `integrator`: an object that defines the integration method.\n",
146
- "- `integrator_state`: an object that contains the state of the integrator."
146
+ "- `integrator` *(Optional)*: an object that defines the integration method.\n",
147
+ "- `integrator_state` *(Optional)*: an object that contains the state of the integrator.\n",
148
+ "\n",
149
+ "The `JaxSimModel` object contains the simulation time step, the integrator and the contact model.\n",
150
+ "In this example, we will explicitly pass an integrator class to the `model` object and we will use the default `SoftContacts` contact model."
147
151
  ]
148
152
  },
149
153
  {
@@ -157,7 +161,9 @@
157
161
  "# Create the JaxSim model.\n",
158
162
  "# This is shared among all the parallel instances.\n",
159
163
  "model = js.model.JaxSimModel.build_from_model_description(\n",
160
- " model_description=model_sdf_string, time_step=0.001\n",
164
+ " model_description=model_sdf_string,\n",
165
+ " time_step=0.001,\n",
166
+ " integrator=jaxsim.integrators.fixed_step.Heun2,\n",
161
167
  ")\n",
162
168
  "\n",
163
169
  "# Create the data of a single model.\n",
@@ -240,21 +246,7 @@
240
246
  },
241
247
  "outputs": [],
242
248
  "source": [
243
- "# Create the integrator.\n",
244
- "integrator = integrators.fixed_step.Heun2SO3.build(\n",
245
- " dynamics=js.ode.wrap_system_dynamics_for_integration(\n",
246
- " model=model,\n",
247
- " data=data_single,\n",
248
- " system_dynamics=js.ode.system_dynamics,\n",
249
- " ),\n",
250
- ")\n",
251
- "\n",
252
- "# Initialize the integrator.\n",
253
- "integrator_state = integrator.init(\n",
254
- " x0=data_single.state,\n",
255
- " t0=0.0,\n",
256
- " dt=model.time_step,\n",
257
- ")\n",
249
+ "print(f\"Using integrator: {model.integrator}\")\n",
258
250
  "\n",
259
251
  "# Initialize the simulated time.\n",
260
252
  "T = jnp.arange(start=0, stop=1.0, step=model.time_step)"
@@ -324,46 +316,42 @@
324
316
  "def step_single(\n",
325
317
  " model: js.model.JaxSimModel,\n",
326
318
  " data: js.data.JaxSimModelData,\n",
327
- " integrator_state: dict[str, Any],\n",
328
319
  ") -> tuple[js.data.JaxSimModelData, dict[str, Any]]:\n",
329
320
  "\n",
330
321
  " # Close step over static arguments.\n",
331
322
  " return js.model.step(\n",
332
323
  " model=model,\n",
333
324
  " data=data,\n",
334
- " integrator=integrator,\n",
335
- " integrator_state=integrator_state,\n",
336
325
  " link_forces=None,\n",
337
326
  " joint_force_references=None,\n",
338
327
  " )\n",
339
328
  "\n",
340
329
  "\n",
341
330
  "@jax.jit\n",
342
- "@functools.partial(jax.vmap, in_axes=(None, 0, None))\n",
331
+ "@functools.partial(jax.vmap, in_axes=(None, 0))\n",
343
332
  "def step_parallel(\n",
344
333
  " model: js.model.JaxSimModel,\n",
345
334
  " data: js.data.JaxSimModelData,\n",
346
- " integrator_state: dict[str, Any],\n",
347
335
  ") -> tuple[js.data.JaxSimModelData, dict[str, Any]]:\n",
348
336
  "\n",
349
337
  " return step_single(\n",
350
- " model=model, data=data, integrator_state=integrator_state\n",
338
+ " model=model, data=data\n",
351
339
  " )\n",
352
340
  "\n",
353
341
  "\n",
354
342
  "# The first run will be slow since JAX needs to JIT-compile the functions.\n",
355
- "_ = step_single(model, data_single, integrator_state)\n",
356
- "_ = step_parallel(model, data_batch_t0, integrator_state)\n",
343
+ "_ = step_single(model, data_single)\n",
344
+ "_ = step_parallel(model, data_batch_t0)\n",
357
345
  "\n",
358
346
  "# Benchmark the execution of a single step.\n",
359
347
  "print(\"\\nSingle simulation step:\")\n",
360
- "%timeit step_single(model, data_single, integrator_state)\n",
348
+ "%timeit step_single(model, data_single)\n",
361
349
  "\n",
362
350
  "# On hardware accelerators, there's a range of batch_size values where\n",
363
351
  "# increasing the number of parallel instances doesn't affect computation time.\n",
364
352
  "# This range depends on the GPU/TPU specifications.\n",
365
353
  "print(f\"\\nParallel simulation steps (batch_size={batch_size} on {jax.devices()[0]}):\")\n",
366
- "%timeit step_parallel(model, data_batch_t0, integrator_state)"
354
+ "%timeit step_parallel(model, data_batch_t0)"
367
355
  ]
368
356
  },
369
357
  {
@@ -381,7 +369,7 @@
381
369
  "\n",
382
370
  "for _ in T:\n",
383
371
  "\n",
384
- " data, integrator_state = step_parallel(model, data, integrator_state)\n",
372
+ " data, _ = step_parallel(model, data)\n",
385
373
  " data_trajectory_list.append(data)"
386
374
  ]
387
375
  },
@@ -404,9 +392,7 @@
404
392
  "source": [
405
393
  "# Convert a list of PyTrees to a batched PyTree.\n",
406
394
  "# This operation is called 'tree transpose' in JAX.\n",
407
- "data_trajectory = jax.tree.map(\n",
408
- " lambda *leafs: jnp.stack(leafs), *data_trajectory_list\n",
409
- ")\n",
395
+ "data_trajectory = jax.tree.map(lambda *leafs: jnp.stack(leafs), *data_trajectory_list)\n",
410
396
  "\n",
411
397
  "print(f\"W_p_B: shape={data_trajectory.base_position().shape}\")"
412
398
  ]
@@ -448,6 +434,11 @@
448
434
  "\n",
449
435
  "Have fun!"
450
436
  ]
437
+ },
438
+ {
439
+ "cell_type": "markdown",
440
+ "metadata": {},
441
+ "source": []
451
442
  }
452
443
  ],
453
444
  "metadata": {
@@ -459,7 +450,8 @@
459
450
  "toc_visible": true
460
451
  },
461
452
  "kernelspec": {
462
- "display_name": "Python 3",
453
+ "display_name": "jaxsim",
454
+ "language": "python",
463
455
  "name": "python3"
464
456
  },
465
457
  "language_info": {
@@ -472,7 +464,7 @@
472
464
  "name": "python",
473
465
  "nbconvert_exporter": "python",
474
466
  "pygments_lexer": "ipython3",
475
- "version": "3.11.8"
467
+ "version": "3.12.7"
476
468
  }
477
469
  },
478
470
  "nbformat": 4,
@@ -123,7 +123,6 @@
123
123
  "# @title Create the model and its data\n",
124
124
  "\n",
125
125
  "import jaxsim.api as js\n",
126
- "from jaxsim import integrators\n",
127
126
  "\n",
128
127
  "# Create the model from the model description.\n",
129
128
  "model = js.model.JaxSimModel.build_from_model_description(\n",
@@ -143,23 +142,7 @@
143
142
  },
144
143
  "outputs": [],
145
144
  "source": [
146
- "# @title Select the integrator\n",
147
- "\n",
148
- "# Create the integrator.\n",
149
- "integrator = integrators.fixed_step.RungeKutta4.build(\n",
150
- " dynamics=js.ode.wrap_system_dynamics_for_integration(\n",
151
- " model=model,\n",
152
- " data=data_zero,\n",
153
- " system_dynamics=js.ode.system_dynamics,\n",
154
- " ),\n",
155
- ")\n",
156
- "\n",
157
- "# Initialize the integrator.\n",
158
- "integrator_state = integrator.init(\n",
159
- " x0=data_zero.state,\n",
160
- " t0=0.0,\n",
161
- " dt=model.time_step,\n",
162
- ")\n",
145
+ "# @title Define simulation parameters\n",
163
146
  "\n",
164
147
  "# Initialize the simulated time.\n",
165
148
  "T = jnp.arange(start=0, stop=5.0, step=model.time_step)"
@@ -255,11 +238,9 @@
255
238
  "for _ in T:\n",
256
239
  "\n",
257
240
  " # Step the JaxSim simulation.\n",
258
- " data, integrator_state = js.model.step(\n",
241
+ " data, _ = js.model.step(\n",
259
242
  " model=model,\n",
260
243
  " data=data,\n",
261
- " integrator=integrator,\n",
262
- " integrator_state=integrator_state,\n",
263
244
  " joint_force_references=None,\n",
264
245
  " link_forces=None,\n",
265
246
  " )\n",
@@ -359,7 +340,7 @@
359
340
  " ṡ = data.joint_velocities()\n",
360
341
  "\n",
361
342
  " # Compute the actuated joint torques.\n",
362
- " s_star = - kp * (s - s_des) - kd * (ṡ - s_dot_des)\n",
343
+ " s_star = -kp * (s - s_des) - kd * (ṡ - s_dot_des)\n",
363
344
  " τ = Mss @ s_star + hs\n",
364
345
  "\n",
365
346
  " return τ"
@@ -407,11 +388,9 @@
407
388
  " )\n",
408
389
  "\n",
409
390
  " # Step the JaxSim simulation.\n",
410
- " data, integrator_state = js.model.step(\n",
391
+ " data, _ = js.model.step(\n",
411
392
  " model=model,\n",
412
393
  " data=data,\n",
413
- " integrator=integrator,\n",
414
- " integrator_state=integrator_state,\n",
415
394
  " joint_force_references=τ,\n",
416
395
  " )\n",
417
396
  "\n",
@@ -461,7 +440,8 @@
461
440
  "toc_visible": true
462
441
  },
463
442
  "kernelspec": {
464
- "display_name": "Python 3",
443
+ "display_name": "jaxsim",
444
+ "language": "python",
465
445
  "name": "python3"
466
446
  },
467
447
  "language_info": {
@@ -474,7 +454,7 @@
474
454
  "name": "python",
475
455
  "nbconvert_exporter": "python",
476
456
  "pygments_lexer": "ipython3",
477
- "version": "3.11.8"
457
+ "version": "3.12.7"
478
458
  }
479
459
  },
480
460
  "nbformat": 4,
@@ -12,5 +12,5 @@ __version__: str
12
12
  __version_tuple__: VERSION_TUPLE
13
13
  version_tuple: VERSION_TUPLE
14
14
 
15
- __version__ = version = '0.4.3.dev245'
16
- __version_tuple__ = version_tuple = (0, 4, 3, 'dev245')
15
+ __version__ = version = '0.4.3.dev271'
16
+ __version_tuple__ = version_tuple = (0, 4, 3, 'dev271')
@@ -54,6 +54,10 @@ class JaxSimModel(JaxsimDataclass):
54
54
  default=None, repr=False
55
55
  )
56
56
 
57
+ integrator: Static[jaxsim.integrators.Integrator | None] = dataclasses.field(
58
+ default=None, repr=False
59
+ )
60
+
57
61
  _description: Static[wrappers.HashlessObject[ModelDescription | None]] = (
58
62
  dataclasses.field(default=None, repr=False)
59
63
  )
@@ -93,12 +97,16 @@ class JaxSimModel(JaxsimDataclass):
93
97
  # Initialization and state
94
98
  # ========================
95
99
 
96
- @staticmethod
100
+ @classmethod
97
101
  def build_from_model_description(
102
+ cls,
98
103
  model_description: str | pathlib.Path | rod.Model,
99
- model_name: str | None = None,
100
104
  *,
105
+ model_name: str | None = None,
101
106
  time_step: jtp.FloatLike | None = None,
107
+ integrator: (
108
+ jaxsim.integrators.Integrator | type[jaxsim.integrators.Integrator] | None
109
+ ) = None,
102
110
  terrain: jaxsim.terrain.Terrain | None = None,
103
111
  contact_model: jaxsim.rbda.contacts.ContactModel | None = None,
104
112
  is_urdf: bool | None = None,
@@ -120,6 +128,10 @@ class JaxSimModel(JaxsimDataclass):
120
128
  contact_model:
121
129
  The contact model to consider.
122
130
  If not specified, a soft contacts model is used.
131
+ integrator:
132
+ The integrator to use. If not specified, a default one is used.
133
+ This argument can either be a pre-built integrator instance or one
134
+ of the integrator classes defined in JaxSim.
123
135
  is_urdf:
124
136
  The optional flag to force the model description to be parsed as a URDF.
125
137
  This is usually automatically inferred.
@@ -146,10 +158,11 @@ class JaxSimModel(JaxsimDataclass):
146
158
  )
147
159
 
148
160
  # Build the model.
149
- model = JaxSimModel.build(
161
+ model = cls.build(
150
162
  model_description=intermediate_description,
151
163
  model_name=model_name,
152
164
  time_step=time_step,
165
+ integrator=integrator,
153
166
  terrain=terrain,
154
167
  contact_model=contact_model,
155
168
  )
@@ -160,12 +173,16 @@ class JaxSimModel(JaxsimDataclass):
160
173
 
161
174
  return model
162
175
 
163
- @staticmethod
176
+ @classmethod
164
177
  def build(
178
+ cls,
165
179
  model_description: ModelDescription,
166
- model_name: str | None = None,
167
180
  *,
181
+ model_name: str | None = None,
168
182
  time_step: jtp.FloatLike | None = None,
183
+ integrator: (
184
+ jaxsim.integrators.Integrator | type[jaxsim.integrators.Integrator] | None
185
+ ) = None,
169
186
  terrain: jaxsim.terrain.Terrain | None = None,
170
187
  contact_model: jaxsim.rbda.contacts.ContactModel | None = None,
171
188
  ) -> JaxSimModel:
@@ -182,6 +199,11 @@ class JaxSimModel(JaxsimDataclass):
182
199
  The default time step to consider for the simulation. It can be
183
200
  manually overridden in the function that steps the simulation.
184
201
  terrain: The terrain to consider (the default is a flat infinite plane).
202
+ The optional name of the model overriding the physics model name.
203
+ integrator:
204
+ The integrator to use. If not specified, a default one is used.
205
+ This argument can either be a pre-built integrator instance or one
206
+ of the integrator classes defined in JaxSim.
185
207
  contact_model:
186
208
  The contact model to consider.
187
209
  If not specified, a soft contacts model is used.
@@ -195,23 +217,62 @@ class JaxSimModel(JaxsimDataclass):
195
217
 
196
218
  # Consider the default terrain (a flat infinite plane) if not specified.
197
219
  terrain = (
198
- terrain or JaxSimModel.__dataclass_fields__["terrain"].default_factory()
220
+ terrain
221
+ if terrain is not None
222
+ else JaxSimModel.__dataclass_fields__["terrain"].default_factory()
199
223
  )
200
224
 
201
225
  # Consider the default time step if not specified.
202
226
  time_step = (
203
- time_step or JaxSimModel.__dataclass_fields__["time_step"].default_factory()
227
+ time_step
228
+ if time_step is not None
229
+ else JaxSimModel.__dataclass_fields__["time_step"].default_factory()
204
230
  )
205
231
 
206
232
  # Create the default contact model.
207
233
  # It will be populated with an initial estimation of good parameters.
208
234
  # While these might not be the best, they are a good starting point.
209
- contact_model = contact_model or jaxsim.rbda.contacts.SoftContacts.build(
210
- terrain=terrain, parameters=None
235
+ contact_model = (
236
+ contact_model
237
+ if contact_model is not None
238
+ else jaxsim.rbda.contacts.SoftContacts.build(
239
+ terrain=terrain, parameters=None
240
+ )
211
241
  )
212
242
 
243
+ # Build the integrator if not provided.
244
+ match integrator:
245
+
246
+ # If None, build a default integrator.
247
+ case None:
248
+
249
+ integrator = jaxsim.integrators.fixed_step.Heun2SO3.build(
250
+ dynamics=js.ode.wrap_system_dynamics_for_integration(
251
+ system_dynamics=js.ode.system_dynamics
252
+ )
253
+ )
254
+
255
+ # If it's a pre-built integrator (also a custom one from the user)
256
+ # just use it as is.
257
+ case _ if isinstance(integrator, jaxsim.integrators.Integrator):
258
+ pass
259
+
260
+ # If an integrator class is passed, assume that it is a JaxSim integrator
261
+ # and build it with the default system dynamics.
262
+ case _ if issubclass(integrator, jaxsim.integrators.Integrator):
263
+
264
+ integrator_cls = integrator
265
+ integrator = integrator_cls.build(
266
+ dynamics=js.ode.wrap_system_dynamics_for_integration(
267
+ system_dynamics=js.ode.system_dynamics
268
+ )
269
+ )
270
+
271
+ case _:
272
+ raise ValueError(f"Invalid integrator: {integrator}")
273
+
213
274
  # Build the model.
214
- model = JaxSimModel(
275
+ model = cls(
215
276
  model_name=model_name,
216
277
  kin_dyn_parameters=js.kin_dyn_parameters.KynDynParameters.build(
217
278
  model_description=model_description
@@ -219,6 +280,7 @@ class JaxSimModel(JaxsimDataclass):
219
280
  time_step=time_step,
220
281
  terrain=terrain,
221
282
  contact_model=contact_model,
283
+ integrator=integrator,
222
284
  # The following is wrapped as hashless since it's a static argument, and we
223
285
  # don't want to trigger recompilation if it changes. All relevant parameters
224
286
  # needed to compute kinematics and dynamics quantities are stored in the
@@ -404,6 +466,7 @@ def reduce(
404
466
  reduced_model = JaxSimModel.build(
405
467
  model_description=reduced_intermediate_description,
406
468
  model_name=model.name(),
469
+ time_step=model.time_step,
407
470
  terrain=model.terrain,
408
471
  contact_model=model.contact_model,
409
472
  )
@@ -1912,10 +1975,10 @@ def step(
1912
1975
  model: JaxSimModel,
1913
1976
  data: js.data.JaxSimModelData,
1914
1977
  *,
1915
- integrator: jaxsim.integrators.Integrator,
1916
1978
  t0: jtp.FloatLike = 0.0,
1917
1979
  dt: jtp.FloatLike | None = None,
1918
- integrator_state: dict[str, Any] | None = None,
1980
+ integrator: jaxsim.integrators.Integrator | None = None,
1981
+ integrator_metadata: dict[str, Any] | None = None,
1919
1982
  link_forces: jtp.MatrixLike | None = None,
1920
1983
  joint_force_references: jtp.VectorLike | None = None,
1921
1984
  **kwargs,
@@ -1927,7 +1990,7 @@ def step(
1927
1990
  model: The model to consider.
1928
1991
  data: The data of the considered model.
1929
1992
  integrator: The integrator to use.
1930
- integrator_state: The state of the integrator.
1993
+ integrator_metadata: The metadata of the integrator, if needed.
1931
1994
  t0: The initial time to consider. Only relevant for time-dependent dynamics.
1932
1995
  dt: The time step to consider. If not specified, it is read from the model.
1933
1996
  link_forces:
@@ -1937,8 +2000,9 @@ def step(
1937
2000
  kwargs: Additional kwargs to pass to the integrator.
1938
2001
 
1939
2002
  Returns:
1940
- A tuple containing the new data of the model
1941
- and the new state of the integrator.
2003
+ A tuple containing the new data of the model and a dictionary of auxiliary
2004
+ data computed during the step. If the integrator has metadata, the dictionary
2005
+ will contain the new metadata stored in the `integrator_metadata` key.
1942
2006
 
1943
2007
  Note:
1944
2008
  In order to reduce the occurrences of frame conversions performed internally,
@@ -1953,8 +2017,9 @@ def step(
1953
2017
  integrator_kwargs = kwargs.pop("integrator_kwargs", {})
1954
2018
  integrator_kwargs = kwargs | integrator_kwargs
1955
2019
 
1956
- # Initialize the integrator state.
1957
- integrator_state_t0 = integrator_state if integrator_state is not None else dict()
2020
+ # Extract the integrator and the optional metadata.
2021
+ integrator_metadata_t0 = integrator_metadata
2022
+ integrator = integrator if integrator is not None else model.integrator
1958
2023
 
1959
2024
  # Initialize the time-related variables.
1960
2025
  state_t0 = data.state
@@ -2010,11 +2075,11 @@ def step(
2010
2075
  τ_references = references.joint_force_references(model=model)
2011
2076
 
2012
2077
  # Step the dynamics forward.
2013
- state_tf, integrator_state_tf = integrator.step(
2078
+ state_tf, integrator_metadata_tf = integrator.step(
2014
2079
  x0=state_t0,
2015
2080
  t0=t0,
2016
2081
  dt=dt,
2017
- params=integrator_state_t0,
2082
+ metadata=integrator_metadata_t0,
2018
2083
  # Always inject the current (model, data) pair into the system dynamics
2019
2084
  # considered by the integrator, and include the input variables represented
2020
2085
  # by the pair (f_L, τ_references).
@@ -2091,13 +2156,17 @@ def step(
2091
2156
  )
2092
2157
  )
2093
2158
 
2094
- # Reset the generalized velocity.
2095
- data_tf = data_tf.reset_base_velocity(BW_nu_post_impact[0:6])
2096
- data_tf = data_tf.reset_joint_velocities(BW_nu_post_impact[6:])
2159
+ # Reset the generalized velocity.
2160
+ data_tf = data_tf.reset_base_velocity(BW_nu_post_impact[0:6])
2161
+ data_tf = data_tf.reset_joint_velocities(BW_nu_post_impact[6:])
2097
2162
 
2098
2163
  # Restore the input velocity representation.
2099
2164
  data_tf = data_tf.replace(
2100
2165
  velocity_representation=data.velocity_representation, validate=False
2101
2166
  )
2102
2167
 
2103
- return data_tf, integrator_state_tf
2168
+ return data_tf, {} | (
2169
+ dict(integrator_metadata=integrator_metadata_tf)
2170
+ if integrator_metadata is not None
2171
+ else {}
2172
+ )
@@ -24,41 +24,45 @@ class SystemDynamicsFromModelAndData(Protocol):
24
24
 
25
25
 
26
26
  def wrap_system_dynamics_for_integration(
27
- model: js.model.JaxSimModel,
28
- data: js.data.JaxSimModelData,
29
27
  *,
30
28
  system_dynamics: SystemDynamicsFromModelAndData,
31
- **kwargs,
29
+ **kwargs: dict[str, Any],
32
30
  ) -> jaxsim.integrators.common.SystemDynamics[ODEState, ODEState]:
33
31
  """
34
- Wrap generic system dynamics operating on `JaxSimModel` and `JaxSimModelData`
35
- for integration with `jaxsim.integrators`.
32
+ Wrap the system dynamics considered by JaxSim integrators in a generic
33
+ `f(x, t, **u, **parameters)` function.
36
34
 
37
35
  Args:
38
- model: The model to consider.
39
- data: The data of the considered model.
40
36
  system_dynamics: The system dynamics to wrap.
41
37
  **kwargs: Additional kwargs to close over the system dynamics.
42
38
 
43
39
  Returns:
44
- The system dynamics closed over the model, the data, and the additional kwargs.
40
+ The system dynamics closed over the additional kwargs to be used by
41
+ JaxSim integrators.
45
42
  """
46
43
 
47
- # We allow to close `system_dynamics` over additional kwargs.
48
- kwargs_closed = kwargs.copy()
49
-
50
- # Create a local copy of model and data.
51
- # The wrapped dynamics will hold a reference of this object.
52
- model_closed = model.copy()
53
- data_closed = data.copy().replace(
54
- state=js.ode_data.ODEState.zero(model=model_closed, data=data)
55
- )
56
-
44
+ # Close `system_dynamics` over additional kwargs.
45
+ # Similarly to what is done in `jaxsim.api.model.step`, to be future-proof, we use the
46
+ # following logic to allow the caller to close over arguments having the same name
47
+ # as the ones used in the `wrap_system_dynamics_for_integration` function.
48
+ kwargs = kwargs.copy() if kwargs is not None else {}
49
+ colliding_system_dynamics_kwargs = kwargs.pop("system_dynamics_kwargs", {})
50
+ system_dynamics_kwargs = kwargs | colliding_system_dynamics_kwargs
51
+
52
+ # Remove `model` and `data` for backward compatibility.
53
+ # It's no longer necessary to close over them at this stage, as this is always
54
+ # done in `jaxsim.api.model.step`.
55
+ # We can remove the following lines in a few releases.
56
+ _ = system_dynamics_kwargs.pop("data", None)
57
+ _ = system_dynamics_kwargs.pop("model", None)
58
+
59
+ # Create the function with the signature expected by our generic integrators.
60
+ # Note that our system dynamics is time independent.
57
61
  def f(x: ODEState, t: Time, **kwargs_f) -> tuple[ODEState, dict[str, Any]]:
58
62
 
59
- # Allow caller to override the closed data and model objects.
60
- data_f = kwargs_f.pop("data", data_closed)
61
- model_f = kwargs_f.pop("model", model_closed)
63
+ # Get the data and model objects from the kwargs.
64
+ data_f = kwargs_f.pop("data")
65
+ model_f = kwargs_f.pop("model")
62
66
 
63
67
  # Update the state and time stored inside data.
64
68
  with data_f.editable(validate=True) as data_rw:
@@ -69,7 +73,7 @@ def wrap_system_dynamics_for_integration(
69
73
  return system_dynamics(
70
74
  model=model_f,
71
75
  data=data_rw,
72
- **(kwargs_closed | kwargs_f),
76
+ **(system_dynamics_kwargs | kwargs_f),
73
77
  )
74
78
 
75
79
  f: jaxsim.integrators.common.SystemDynamics[ODEState, ODEState]