jaxsim 0.6.2.dev182__tar.gz → 0.6.2.dev225__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (134) hide show
  1. {jaxsim-0.6.2.dev182 → jaxsim-0.6.2.dev225}/.github/workflows/gpu_benchmark.yml +1 -1
  2. {jaxsim-0.6.2.dev182 → jaxsim-0.6.2.dev225}/PKG-INFO +3 -2
  3. {jaxsim-0.6.2.dev182 → jaxsim-0.6.2.dev225}/examples/jaxsim_as_physics_engine_advanced.ipynb +89 -6
  4. {jaxsim-0.6.2.dev182 → jaxsim-0.6.2.dev225}/src/jaxsim/_version.py +2 -2
  5. {jaxsim-0.6.2.dev182 → jaxsim-0.6.2.dev225}/src/jaxsim/api/__init__.py +0 -1
  6. {jaxsim-0.6.2.dev182 → jaxsim-0.6.2.dev225}/src/jaxsim/api/com.py +1 -3
  7. {jaxsim-0.6.2.dev182 → jaxsim-0.6.2.dev225}/src/jaxsim/api/common.py +26 -38
  8. {jaxsim-0.6.2.dev182 → jaxsim-0.6.2.dev225}/src/jaxsim/api/contact.py +140 -24
  9. {jaxsim-0.6.2.dev182 → jaxsim-0.6.2.dev225}/src/jaxsim/api/data.py +96 -33
  10. {jaxsim-0.6.2.dev182 → jaxsim-0.6.2.dev225}/src/jaxsim/api/integrators.py +18 -11
  11. {jaxsim-0.6.2.dev182 → jaxsim-0.6.2.dev225}/src/jaxsim/api/model.py +25 -43
  12. {jaxsim-0.6.2.dev182 → jaxsim-0.6.2.dev225}/src/jaxsim/api/ode.py +28 -6
  13. {jaxsim-0.6.2.dev182 → jaxsim-0.6.2.dev225}/src/jaxsim/api/references.py +9 -16
  14. {jaxsim-0.6.2.dev182 → jaxsim-0.6.2.dev225}/src/jaxsim/math/__init__.py +1 -1
  15. {jaxsim-0.6.2.dev182 → jaxsim-0.6.2.dev225}/src/jaxsim/math/adjoint.py +2 -2
  16. {jaxsim-0.6.2.dev182 → jaxsim-0.6.2.dev225}/src/jaxsim/math/transform.py +2 -2
  17. {jaxsim-0.6.2.dev182 → jaxsim-0.6.2.dev225}/src/jaxsim/math/utils.py +3 -2
  18. {jaxsim-0.6.2.dev182 → jaxsim-0.6.2.dev225}/src/jaxsim/mujoco/visualizer.py +1 -1
  19. {jaxsim-0.6.2.dev182 → jaxsim-0.6.2.dev225}/src/jaxsim/parsers/kinematic_graph.py +1 -1
  20. {jaxsim-0.6.2.dev182 → jaxsim-0.6.2.dev225}/src/jaxsim/rbda/__init__.py +1 -1
  21. jaxsim-0.6.2.dev225/src/jaxsim/rbda/contacts/__init__.py +9 -0
  22. {jaxsim-0.6.2.dev182 → jaxsim-0.6.2.dev225}/src/jaxsim/rbda/contacts/common.py +114 -4
  23. {jaxsim-0.6.2.dev182 → jaxsim-0.6.2.dev225}/src/jaxsim/rbda/contacts/relaxed_rigid.py +57 -177
  24. jaxsim-0.6.2.dev225/src/jaxsim/rbda/contacts/rigid.py +538 -0
  25. jaxsim-0.6.2.dev225/src/jaxsim/rbda/contacts/soft.py +448 -0
  26. {jaxsim-0.6.2.dev182 → jaxsim-0.6.2.dev225}/src/jaxsim/rbda/forward_kinematics.py +0 -29
  27. {jaxsim-0.6.2.dev182 → jaxsim-0.6.2.dev225}/src/jaxsim/rbda/utils.py +2 -2
  28. {jaxsim-0.6.2.dev182 → jaxsim-0.6.2.dev225}/src/jaxsim/terrain/terrain.py +1 -1
  29. {jaxsim-0.6.2.dev182 → jaxsim-0.6.2.dev225}/src/jaxsim.egg-info/PKG-INFO +3 -2
  30. {jaxsim-0.6.2.dev182 → jaxsim-0.6.2.dev225}/src/jaxsim.egg-info/SOURCES.txt +2 -2
  31. {jaxsim-0.6.2.dev182 → jaxsim-0.6.2.dev225}/tests/test_api_contact.py +32 -0
  32. {jaxsim-0.6.2.dev182 → jaxsim-0.6.2.dev225}/tests/test_automatic_differentiation.py +58 -3
  33. {jaxsim-0.6.2.dev182 → jaxsim-0.6.2.dev225}/tests/test_simulations.py +104 -1
  34. jaxsim-0.6.2.dev182/src/jaxsim/api/contact_model.py +0 -101
  35. jaxsim-0.6.2.dev182/src/jaxsim/rbda/contacts/__init__.py +0 -5
  36. jaxsim-0.6.2.dev182/tests/test_contact.py +0 -37
  37. {jaxsim-0.6.2.dev182 → jaxsim-0.6.2.dev225}/.devcontainer/Dockerfile +0 -0
  38. {jaxsim-0.6.2.dev182 → jaxsim-0.6.2.dev225}/.devcontainer/devcontainer.json +0 -0
  39. {jaxsim-0.6.2.dev182 → jaxsim-0.6.2.dev225}/.gitattributes +0 -0
  40. {jaxsim-0.6.2.dev182 → jaxsim-0.6.2.dev225}/.github/CODEOWNERS +0 -0
  41. {jaxsim-0.6.2.dev182 → jaxsim-0.6.2.dev225}/.github/dependabot.yml +0 -0
  42. {jaxsim-0.6.2.dev182 → jaxsim-0.6.2.dev225}/.github/workflows/ci_cd.yml +0 -0
  43. {jaxsim-0.6.2.dev182 → jaxsim-0.6.2.dev225}/.github/workflows/pixi.yml +0 -0
  44. {jaxsim-0.6.2.dev182 → jaxsim-0.6.2.dev225}/.github/workflows/read_the_docs.yml +0 -0
  45. {jaxsim-0.6.2.dev182 → jaxsim-0.6.2.dev225}/.gitignore +0 -0
  46. {jaxsim-0.6.2.dev182 → jaxsim-0.6.2.dev225}/.pre-commit-config.yaml +0 -0
  47. {jaxsim-0.6.2.dev182 → jaxsim-0.6.2.dev225}/.readthedocs.yaml +0 -0
  48. {jaxsim-0.6.2.dev182 → jaxsim-0.6.2.dev225}/CONTRIBUTING.md +0 -0
  49. {jaxsim-0.6.2.dev182 → jaxsim-0.6.2.dev225}/LICENSE +0 -0
  50. {jaxsim-0.6.2.dev182 → jaxsim-0.6.2.dev225}/README.md +0 -0
  51. {jaxsim-0.6.2.dev182 → jaxsim-0.6.2.dev225}/docs/Makefile +0 -0
  52. {jaxsim-0.6.2.dev182 → jaxsim-0.6.2.dev225}/docs/conf.py +0 -0
  53. {jaxsim-0.6.2.dev182 → jaxsim-0.6.2.dev225}/docs/examples.rst +0 -0
  54. {jaxsim-0.6.2.dev182 → jaxsim-0.6.2.dev225}/docs/guide/configuration.rst +0 -0
  55. {jaxsim-0.6.2.dev182 → jaxsim-0.6.2.dev225}/docs/guide/install.rst +0 -0
  56. {jaxsim-0.6.2.dev182 → jaxsim-0.6.2.dev225}/docs/index.rst +0 -0
  57. {jaxsim-0.6.2.dev182 → jaxsim-0.6.2.dev225}/docs/make.bat +0 -0
  58. {jaxsim-0.6.2.dev182 → jaxsim-0.6.2.dev225}/docs/modules/api.rst +0 -0
  59. {jaxsim-0.6.2.dev182 → jaxsim-0.6.2.dev225}/docs/modules/math.rst +0 -0
  60. {jaxsim-0.6.2.dev182 → jaxsim-0.6.2.dev225}/docs/modules/mujoco.rst +0 -0
  61. {jaxsim-0.6.2.dev182 → jaxsim-0.6.2.dev225}/docs/modules/parsers.rst +0 -0
  62. {jaxsim-0.6.2.dev182 → jaxsim-0.6.2.dev225}/docs/modules/rbda.rst +0 -0
  63. {jaxsim-0.6.2.dev182 → jaxsim-0.6.2.dev225}/docs/modules/typing.rst +0 -0
  64. {jaxsim-0.6.2.dev182 → jaxsim-0.6.2.dev225}/docs/modules/utils.rst +0 -0
  65. {jaxsim-0.6.2.dev182 → jaxsim-0.6.2.dev225}/environment.yml +0 -0
  66. {jaxsim-0.6.2.dev182 → jaxsim-0.6.2.dev225}/examples/.gitattributes +0 -0
  67. {jaxsim-0.6.2.dev182 → jaxsim-0.6.2.dev225}/examples/.gitignore +0 -0
  68. {jaxsim-0.6.2.dev182 → jaxsim-0.6.2.dev225}/examples/README.md +0 -0
  69. {jaxsim-0.6.2.dev182 → jaxsim-0.6.2.dev225}/examples/assets/build_cartpole_urdf.py +0 -0
  70. {jaxsim-0.6.2.dev182 → jaxsim-0.6.2.dev225}/examples/assets/cartpole.urdf +0 -0
  71. {jaxsim-0.6.2.dev182 → jaxsim-0.6.2.dev225}/examples/jaxsim_as_multibody_dynamics_library.ipynb +0 -0
  72. {jaxsim-0.6.2.dev182 → jaxsim-0.6.2.dev225}/examples/jaxsim_as_physics_engine.ipynb +0 -0
  73. {jaxsim-0.6.2.dev182 → jaxsim-0.6.2.dev225}/examples/jaxsim_for_robot_controllers.ipynb +0 -0
  74. {jaxsim-0.6.2.dev182 → jaxsim-0.6.2.dev225}/pixi.lock +0 -0
  75. {jaxsim-0.6.2.dev182 → jaxsim-0.6.2.dev225}/pyproject.toml +0 -0
  76. {jaxsim-0.6.2.dev182 → jaxsim-0.6.2.dev225}/setup.cfg +0 -0
  77. {jaxsim-0.6.2.dev182 → jaxsim-0.6.2.dev225}/setup.py +0 -0
  78. {jaxsim-0.6.2.dev182 → jaxsim-0.6.2.dev225}/src/jaxsim/__init__.py +0 -0
  79. {jaxsim-0.6.2.dev182 → jaxsim-0.6.2.dev225}/src/jaxsim/api/actuation_model.py +0 -0
  80. {jaxsim-0.6.2.dev182 → jaxsim-0.6.2.dev225}/src/jaxsim/api/frame.py +0 -0
  81. {jaxsim-0.6.2.dev182 → jaxsim-0.6.2.dev225}/src/jaxsim/api/joint.py +0 -0
  82. {jaxsim-0.6.2.dev182 → jaxsim-0.6.2.dev225}/src/jaxsim/api/kin_dyn_parameters.py +0 -0
  83. {jaxsim-0.6.2.dev182 → jaxsim-0.6.2.dev225}/src/jaxsim/api/link.py +0 -0
  84. {jaxsim-0.6.2.dev182 → jaxsim-0.6.2.dev225}/src/jaxsim/exceptions.py +0 -0
  85. {jaxsim-0.6.2.dev182 → jaxsim-0.6.2.dev225}/src/jaxsim/logging.py +0 -0
  86. {jaxsim-0.6.2.dev182 → jaxsim-0.6.2.dev225}/src/jaxsim/math/cross.py +0 -0
  87. {jaxsim-0.6.2.dev182 → jaxsim-0.6.2.dev225}/src/jaxsim/math/inertia.py +0 -0
  88. {jaxsim-0.6.2.dev182 → jaxsim-0.6.2.dev225}/src/jaxsim/math/joint_model.py +0 -0
  89. {jaxsim-0.6.2.dev182 → jaxsim-0.6.2.dev225}/src/jaxsim/math/quaternion.py +0 -0
  90. {jaxsim-0.6.2.dev182 → jaxsim-0.6.2.dev225}/src/jaxsim/math/rotation.py +0 -0
  91. {jaxsim-0.6.2.dev182 → jaxsim-0.6.2.dev225}/src/jaxsim/math/skew.py +0 -0
  92. {jaxsim-0.6.2.dev182 → jaxsim-0.6.2.dev225}/src/jaxsim/mujoco/__init__.py +0 -0
  93. {jaxsim-0.6.2.dev182 → jaxsim-0.6.2.dev225}/src/jaxsim/mujoco/__main__.py +0 -0
  94. {jaxsim-0.6.2.dev182 → jaxsim-0.6.2.dev225}/src/jaxsim/mujoco/loaders.py +0 -0
  95. {jaxsim-0.6.2.dev182 → jaxsim-0.6.2.dev225}/src/jaxsim/mujoco/model.py +0 -0
  96. {jaxsim-0.6.2.dev182 → jaxsim-0.6.2.dev225}/src/jaxsim/mujoco/utils.py +0 -0
  97. {jaxsim-0.6.2.dev182 → jaxsim-0.6.2.dev225}/src/jaxsim/parsers/__init__.py +0 -0
  98. {jaxsim-0.6.2.dev182 → jaxsim-0.6.2.dev225}/src/jaxsim/parsers/descriptions/__init__.py +0 -0
  99. {jaxsim-0.6.2.dev182 → jaxsim-0.6.2.dev225}/src/jaxsim/parsers/descriptions/collision.py +0 -0
  100. {jaxsim-0.6.2.dev182 → jaxsim-0.6.2.dev225}/src/jaxsim/parsers/descriptions/joint.py +0 -0
  101. {jaxsim-0.6.2.dev182 → jaxsim-0.6.2.dev225}/src/jaxsim/parsers/descriptions/link.py +0 -0
  102. {jaxsim-0.6.2.dev182 → jaxsim-0.6.2.dev225}/src/jaxsim/parsers/descriptions/model.py +0 -0
  103. {jaxsim-0.6.2.dev182 → jaxsim-0.6.2.dev225}/src/jaxsim/parsers/rod/__init__.py +0 -0
  104. {jaxsim-0.6.2.dev182 → jaxsim-0.6.2.dev225}/src/jaxsim/parsers/rod/meshes.py +0 -0
  105. {jaxsim-0.6.2.dev182 → jaxsim-0.6.2.dev225}/src/jaxsim/parsers/rod/parser.py +0 -0
  106. {jaxsim-0.6.2.dev182 → jaxsim-0.6.2.dev225}/src/jaxsim/parsers/rod/utils.py +0 -0
  107. {jaxsim-0.6.2.dev182 → jaxsim-0.6.2.dev225}/src/jaxsim/rbda/aba.py +0 -0
  108. {jaxsim-0.6.2.dev182 → jaxsim-0.6.2.dev225}/src/jaxsim/rbda/collidable_points.py +0 -0
  109. {jaxsim-0.6.2.dev182 → jaxsim-0.6.2.dev225}/src/jaxsim/rbda/crba.py +0 -0
  110. {jaxsim-0.6.2.dev182 → jaxsim-0.6.2.dev225}/src/jaxsim/rbda/jacobian.py +0 -0
  111. {jaxsim-0.6.2.dev182 → jaxsim-0.6.2.dev225}/src/jaxsim/rbda/rnea.py +0 -0
  112. {jaxsim-0.6.2.dev182 → jaxsim-0.6.2.dev225}/src/jaxsim/terrain/__init__.py +0 -0
  113. {jaxsim-0.6.2.dev182 → jaxsim-0.6.2.dev225}/src/jaxsim/typing.py +0 -0
  114. {jaxsim-0.6.2.dev182 → jaxsim-0.6.2.dev225}/src/jaxsim/utils/__init__.py +0 -0
  115. {jaxsim-0.6.2.dev182 → jaxsim-0.6.2.dev225}/src/jaxsim/utils/jaxsim_dataclass.py +0 -0
  116. {jaxsim-0.6.2.dev182 → jaxsim-0.6.2.dev225}/src/jaxsim/utils/tracing.py +0 -0
  117. {jaxsim-0.6.2.dev182 → jaxsim-0.6.2.dev225}/src/jaxsim/utils/wrappers.py +0 -0
  118. {jaxsim-0.6.2.dev182 → jaxsim-0.6.2.dev225}/src/jaxsim.egg-info/dependency_links.txt +0 -0
  119. {jaxsim-0.6.2.dev182 → jaxsim-0.6.2.dev225}/src/jaxsim.egg-info/requires.txt +0 -0
  120. {jaxsim-0.6.2.dev182 → jaxsim-0.6.2.dev225}/src/jaxsim.egg-info/top_level.txt +0 -0
  121. {jaxsim-0.6.2.dev182 → jaxsim-0.6.2.dev225}/tests/__init__.py +0 -0
  122. {jaxsim-0.6.2.dev182 → jaxsim-0.6.2.dev225}/tests/conftest.py +0 -0
  123. {jaxsim-0.6.2.dev182 → jaxsim-0.6.2.dev225}/tests/test_api_com.py +0 -0
  124. {jaxsim-0.6.2.dev182 → jaxsim-0.6.2.dev225}/tests/test_api_data.py +0 -0
  125. {jaxsim-0.6.2.dev182 → jaxsim-0.6.2.dev225}/tests/test_api_frame.py +0 -0
  126. {jaxsim-0.6.2.dev182 → jaxsim-0.6.2.dev225}/tests/test_api_joint.py +0 -0
  127. {jaxsim-0.6.2.dev182 → jaxsim-0.6.2.dev225}/tests/test_api_link.py +0 -0
  128. {jaxsim-0.6.2.dev182 → jaxsim-0.6.2.dev225}/tests/test_api_model.py +0 -0
  129. {jaxsim-0.6.2.dev182 → jaxsim-0.6.2.dev225}/tests/test_benchmark.py +0 -0
  130. {jaxsim-0.6.2.dev182 → jaxsim-0.6.2.dev225}/tests/test_exceptions.py +0 -0
  131. {jaxsim-0.6.2.dev182 → jaxsim-0.6.2.dev225}/tests/test_meshes.py +0 -0
  132. {jaxsim-0.6.2.dev182 → jaxsim-0.6.2.dev225}/tests/test_pytree.py +0 -0
  133. {jaxsim-0.6.2.dev182 → jaxsim-0.6.2.dev225}/tests/test_visualizer.py +0 -0
  134. {jaxsim-0.6.2.dev182 → jaxsim-0.6.2.dev225}/tests/utils_idyntree.py +0 -0
@@ -51,7 +51,7 @@ jobs:
51
51
 
52
52
  - name: Run benchmark and store result
53
53
  run: |
54
- pytest tests/test_benchmark.py -k 'not test_rigid_contact_model and not test_soft_contact_model' --gpu-only --batch-size 128 --benchmark-only --benchmark-json output.json
54
+ pytest tests/test_benchmark.py --gpu-only --batch-size 128 --benchmark-only --benchmark-json output.json
55
55
 
56
56
  - name: Compare benchmark results with main branch
57
57
  uses: benchmark-action/github-action-benchmark@v1.20.4
@@ -1,6 +1,6 @@
1
- Metadata-Version: 2.2
1
+ Metadata-Version: 2.4
2
2
  Name: jaxsim
3
- Version: 0.6.2.dev182
3
+ Version: 0.6.2.dev225
4
4
  Summary: A differentiable physics engine and multibody dynamics library for control and robot learning.
5
5
  Author-email: Diego Ferigo <dgferigo@gmail.com>, Filippo Luca Ferretti <filippoluca.ferretti@outlook.com>
6
6
  Maintainer-email: Filippo Luca Ferretti <filippo.ferretti@iit.it>, Alessandro Croci <alessandro.croci@iit.it>
@@ -89,6 +89,7 @@ Requires-Dist: mujoco>=3.0.0; extra == "viz"
89
89
  Requires-Dist: scipy>=1.14.0; extra == "viz"
90
90
  Provides-Extra: all
91
91
  Requires-Dist: jaxsim[style,testing,viz]; extra == "all"
92
+ Dynamic: license-file
92
93
 
93
94
  # JaxSim
94
95
 
@@ -60,6 +60,8 @@
60
60
  "\n",
61
61
  "import os\n",
62
62
  "\n",
63
+ "os.environ[\"MUJOCO_GL\"] = \"osmesa\"\n",
64
+ "\n",
63
65
  "import jax\n",
64
66
  "\n",
65
67
  "import jax.numpy as jnp\n",
@@ -178,7 +180,6 @@
178
180
  },
179
181
  "outputs": [],
180
182
  "source": [
181
- "\n",
182
183
  "# Initialize the simulated time.\n",
183
184
  "T = jnp.arange(start=0, stop=1.0, step=model.time_step)"
184
185
  ]
@@ -212,7 +213,9 @@
212
213
  "key = jax.random.PRNGKey(seed=0)\n",
213
214
  "\n",
214
215
  "# Split subkeys for sampling random initial data.\n",
215
- "batch_size = 10\n",
216
+ "batch_size = 16\n",
217
+ "row_length = int(jnp.sqrt(batch_size))\n",
218
+ "row_dist = 0.3 * row_length\n",
216
219
  "key, *subkeys = jax.random.split(key=key, num=batch_size + 1)\n",
217
220
  "\n",
218
221
  "# Create the batched data by sampling the height from [0.5, 0.6] meters.\n",
@@ -220,12 +223,24 @@
220
223
  " lambda key: js.data.random_model_data(\n",
221
224
  " model=model,\n",
222
225
  " key=key,\n",
223
- " base_pos_bounds=([0, 0, 0.3], [0, 0, 0.6]),\n",
226
+ " base_pos_bounds=([0, 0, 0.3], [0, 0, 1.2]),\n",
224
227
  " base_vel_lin_bounds=(0, 0),\n",
225
228
  " base_vel_ang_bounds=(0, 0),\n",
226
229
  " )\n",
227
230
  ")(jnp.vstack(subkeys))\n",
228
231
  "\n",
232
+ "x, y = jnp.meshgrid(\n",
233
+ " jnp.linspace(-row_dist, row_dist, num=row_length),\n",
234
+ " jnp.linspace(-row_dist, row_dist, num=row_length),\n",
235
+ ")\n",
236
+ "xy_coordinate = jnp.stack([x.flatten(), y.flatten()], axis=-1)\n",
237
+ "\n",
238
+ "# Reset the x and y position to a grid.\n",
239
+ "data_batch_t0 = data_batch_t0.replace(\n",
240
+ " model=model,\n",
241
+ " base_position=data_batch_t0.base_position.at[:, :2].set(xy_coordinate),\n",
242
+ ")\n",
243
+ "\n",
229
244
  "print(\"W_p_B(t0)=\\n\", data_batch_t0.base_position[0:10])"
230
245
  ]
231
246
  },
@@ -339,7 +354,7 @@
339
354
  "import matplotlib.pyplot as plt\n",
340
355
  "\n",
341
356
  "\n",
342
- "plt.plot(T, data_trajectory.base_position[:, 0:5, 2])\n",
357
+ "plt.plot(T, data_trajectory.base_position[:, :, 2])\n",
343
358
  "plt.grid(True)\n",
344
359
  "plt.xlabel(\"Time [s]\")\n",
345
360
  "plt.ylabel(\"Height [m]\")\n",
@@ -347,6 +362,74 @@
347
362
  "plt.show()"
348
363
  ]
349
364
  },
365
+ {
366
+ "cell_type": "code",
367
+ "execution_count": null,
368
+ "metadata": {},
369
+ "outputs": [],
370
+ "source": [
371
+ "import jaxsim.mujoco\n",
372
+ "\n",
373
+ "mjcf_string, assets = jaxsim.mujoco.ModelToMjcf.convert(\n",
374
+ " model.built_from,\n",
375
+ " cameras=jaxsim.mujoco.loaders.MujocoCamera.build_from_target_view(\n",
376
+ " camera_name=\"sphere_cam\",\n",
377
+ " lookat=[0, 0, 0.3],\n",
378
+ " distance=4,\n",
379
+ " azimuth=150,\n",
380
+ " elevation=-10,\n",
381
+ " ),\n",
382
+ ")\n",
383
+ "\n",
384
+ "# Create a helper for each parallel instance.\n",
385
+ "mj_model_helpers = [\n",
386
+ " jaxsim.mujoco.MujocoModelHelper.build_from_xml(\n",
387
+ " mjcf_description=mjcf_string, assets=assets\n",
388
+ " )\n",
389
+ " for _ in range(batch_size)\n",
390
+ "]\n",
391
+ "\n",
392
+ "# Create the video recorder.\n",
393
+ "recorder = jaxsim.mujoco.MujocoVideoRecorder(\n",
394
+ " model=mj_model_helpers[0].model,\n",
395
+ " data=[helper.data for helper in mj_model_helpers],\n",
396
+ " fps=int(1 / model.time_step),\n",
397
+ " width=320 * 2,\n",
398
+ " height=240 * 2,\n",
399
+ ")\n",
400
+ "\n",
401
+ "for data_t in data_trajectory_list:\n",
402
+ "\n",
403
+ " for helper, base_position, base_quaternion, joint_position in zip(\n",
404
+ " mj_model_helpers,\n",
405
+ " data_t.base_position,\n",
406
+ " data_t.base_orientation,\n",
407
+ " data_t.joint_positions,\n",
408
+ " strict=True,\n",
409
+ " ):\n",
410
+ " helper.set_base_position(position=base_position)\n",
411
+ " helper.set_base_orientation(orientation=base_quaternion)\n",
412
+ "\n",
413
+ " if model.dofs() > 0:\n",
414
+ " helper.set_joint_positions(\n",
415
+ " positions=joint_position, joint_names=model.joint_names()\n",
416
+ " )\n",
417
+ "\n",
418
+ " # Record a new video frame.\n",
419
+ " recorder.record_frame(camera_name=\"sphere_cam\")"
420
+ ]
421
+ },
422
+ {
423
+ "cell_type": "code",
424
+ "execution_count": null,
425
+ "metadata": {},
426
+ "outputs": [],
427
+ "source": [
428
+ "import mediapy as media\n",
429
+ "\n",
430
+ "media.show_video(recorder.frames, fps=recorder.fps)"
431
+ ]
432
+ },
350
433
  {
351
434
  "cell_type": "markdown",
352
435
  "metadata": {},
@@ -362,7 +445,7 @@
362
445
  "toc_visible": true
363
446
  },
364
447
  "kernelspec": {
365
- "display_name": "comodo_jaxsim",
448
+ "display_name": "jaxpypi",
366
449
  "language": "python",
367
450
  "name": "python3"
368
451
  },
@@ -376,7 +459,7 @@
376
459
  "name": "python",
377
460
  "nbconvert_exporter": "python",
378
461
  "pygments_lexer": "ipython3",
379
- "version": "3.12.8"
462
+ "version": "3.13.1"
380
463
  }
381
464
  },
382
465
  "nbformat": 4,
@@ -17,5 +17,5 @@ __version__: str
17
17
  __version_tuple__: VERSION_TUPLE
18
18
  version_tuple: VERSION_TUPLE
19
19
 
20
- __version__ = version = '0.6.2.dev182'
21
- __version_tuple__ = version_tuple = (0, 6, 2, 'dev182')
20
+ __version__ = version = '0.6.2.dev225'
21
+ __version_tuple__ = version_tuple = (0, 6, 2, 'dev225')
@@ -4,7 +4,6 @@ from . import (
4
4
  actuation_model,
5
5
  com,
6
6
  contact,
7
- contact_model,
8
7
  frame,
9
8
  integrators,
10
9
  joint,
@@ -301,9 +301,7 @@ def bias_acceleration(
301
301
  C_v̇_WL = W_v̇_bias_WL = v̇_bias_WL # noqa: F841
302
302
  C_v_WC = W_v_WW = jnp.zeros(6) # noqa: F841
303
303
 
304
- L_H_C = L_H_W = jax.vmap( # noqa: F841
305
- lambda W_H_L: jaxsim.math.Transform.inverse(W_H_L)
306
- )(W_H_L)
304
+ L_H_C = L_H_W = jax.vmap(jaxsim.math.Transform.inverse)(W_H_L) # noqa: F841
307
305
 
308
306
  L_v_LC = L_v_LW = jax.vmap( # noqa: F841
309
307
  lambda i: -js.link.velocity(
@@ -121,14 +121,8 @@ class ModelDataWithVelocityRepresentation(JaxsimDataclass, abc.ABC):
121
121
  The 6D quantity in the other representation.
122
122
  """
123
123
 
124
- W_array = array.squeeze()
125
- W_H_O = transform.squeeze()
126
-
127
- if W_array.size != 6:
128
- raise ValueError(W_array.size, 6)
129
-
130
- if W_H_O.shape != (4, 4):
131
- raise ValueError(W_H_O.shape, (4, 4))
124
+ W_array = array
125
+ W_H_O = transform
132
126
 
133
127
  match other_representation:
134
128
 
@@ -139,25 +133,24 @@ class ModelDataWithVelocityRepresentation(JaxsimDataclass, abc.ABC):
139
133
 
140
134
  if not is_force:
141
135
  O_Xv_W = Adjoint.from_transform(transform=W_H_O, inverse=True)
142
- O_array = O_Xv_W @ W_array
136
+ O_array = jnp.einsum("...ij,...j->...i", O_Xv_W, W_array)
143
137
 
144
138
  else:
145
- O_Xf_W = Adjoint.from_transform(transform=W_H_O).T
146
- O_array = O_Xf_W @ W_array
139
+ O_Xf_W = Adjoint.from_transform(transform=W_H_O).swapaxes(-1, -2)
140
+ O_array = jnp.einsum("...ij,...j->...i", O_Xf_W, W_array)
147
141
 
148
142
  return O_array
149
143
 
150
144
  case VelRepr.Mixed:
151
- W_p_O = W_H_O[0:3, 3]
152
- W_H_OW = jnp.eye(4).at[0:3, 3].set(W_p_O)
145
+ W_H_OW = W_H_O.at[..., 0:3, 0:3].set(jnp.eye(3))
153
146
 
154
147
  if not is_force:
155
148
  OW_Xv_W = Adjoint.from_transform(transform=W_H_OW, inverse=True)
156
- OW_array = OW_Xv_W @ W_array
149
+ OW_array = jnp.einsum("...ij,...j->...i", OW_Xv_W, W_array)
157
150
 
158
151
  else:
159
- OW_Xf_W = Adjoint.from_transform(transform=W_H_OW).T
160
- OW_array = OW_Xf_W @ W_array
152
+ OW_Xf_W = Adjoint.from_transform(transform=W_H_OW).swapaxes(-1, -2)
153
+ OW_array = jnp.einsum("...ij,...j->...i", OW_Xf_W, W_array)
161
154
 
162
155
  return OW_array
163
156
 
@@ -188,45 +181,40 @@ class ModelDataWithVelocityRepresentation(JaxsimDataclass, abc.ABC):
188
181
  The 6D quantity in the inertial-fixed representation.
189
182
  """
190
183
 
191
- W_array = array.squeeze()
192
- W_H_O = transform.squeeze()
193
-
194
- if W_array.size != 6:
195
- raise ValueError(W_array.size, 6)
196
-
197
- if W_H_O.shape != (4, 4):
198
- raise ValueError(W_H_O.shape, (4, 4))
184
+ O_array = array
185
+ W_H_O = transform
199
186
 
200
187
  match other_representation:
201
188
  case VelRepr.Inertial:
202
- W_array = array
203
- return W_array
189
+ return O_array
204
190
 
205
191
  case VelRepr.Body:
206
- O_array = array
207
192
 
208
193
  if not is_force:
209
- W_Xv_O: jtp.Array = Adjoint.from_transform(W_H_O)
210
- W_array = W_Xv_O @ O_array
194
+ W_Xv_O = Adjoint.from_transform(W_H_O)
195
+ W_array = jnp.einsum("...ij,...j->...i", W_Xv_O, O_array)
211
196
 
212
197
  else:
213
- W_Xf_O = Adjoint.from_transform(transform=W_H_O, inverse=True).T
214
- W_array = W_Xf_O @ O_array
198
+ W_Xf_O = Adjoint.from_transform(
199
+ transform=W_H_O, inverse=True
200
+ ).swapaxes(-1, -2)
201
+ W_array = jnp.einsum("...ij,...j->...i", W_Xf_O, O_array)
215
202
 
216
203
  return W_array
217
204
 
218
205
  case VelRepr.Mixed:
219
- BW_array = array
220
- W_p_O = W_H_O[0:3, 3]
221
- W_H_OW = jnp.eye(4).at[0:3, 3].set(W_p_O)
206
+
207
+ W_H_OW = W_H_O.at[..., 0:3, 0:3].set(jnp.eye(3))
222
208
 
223
209
  if not is_force:
224
- W_Xv_BW: jtp.Array = Adjoint.from_transform(W_H_OW)
225
- W_array = W_Xv_BW @ BW_array
210
+ W_Xv_BW = Adjoint.from_transform(W_H_OW)
211
+ W_array = jnp.einsum("...ij,...j->...i", W_Xv_BW, O_array)
226
212
 
227
213
  else:
228
- W_Xf_BW = Adjoint.from_transform(transform=W_H_OW, inverse=True).T
229
- W_array = W_Xf_BW @ BW_array
214
+ W_Xf_BW = Adjoint.from_transform(
215
+ transform=W_H_OW, inverse=True
216
+ ).swapaxes(-1, -2)
217
+ W_array = jnp.einsum("...ij,...j->...i", W_Xf_BW, O_array)
230
218
 
231
219
  return W_array
232
220
 
@@ -11,7 +11,7 @@ import jaxsim.terrain
11
11
  import jaxsim.typing as jtp
12
12
  from jaxsim import logging
13
13
  from jaxsim.math import Adjoint, Cross, Transform
14
- from jaxsim.rbda import contacts
14
+ from jaxsim.rbda.contacts import SoftContacts
15
15
 
16
16
  from .common import VelRepr
17
17
 
@@ -37,14 +37,11 @@ def collidable_point_kinematics(
37
37
  the linear component of the mixed 6D frame velocity.
38
38
  """
39
39
 
40
- # Switch to inertial-fixed since the RBDAs expect velocities in this representation.
41
- with data.switch_velocity_representation(VelRepr.Inertial):
42
-
43
- W_p_Ci, W_ṗ_Ci = jaxsim.rbda.collidable_points.collidable_points_pos_vel(
44
- model=model,
45
- link_transforms=data._link_transforms,
46
- link_velocities=data._link_velocities,
47
- )
40
+ W_p_Ci, W_ṗ_Ci = jaxsim.rbda.collidable_points.collidable_points_pos_vel(
41
+ model=model,
42
+ link_transforms=data._link_transforms,
43
+ link_velocities=data._link_velocities,
44
+ )
48
45
 
49
46
  return W_p_Ci, W_ṗ_Ci
50
47
 
@@ -164,18 +161,23 @@ def estimate_good_soft_contacts_parameters(
164
161
  def estimate_good_contact_parameters(
165
162
  model: js.model.JaxSimModel,
166
163
  *,
164
+ standard_gravity: jtp.FloatLike = jaxsim.math.STANDARD_GRAVITY,
167
165
  static_friction_coefficient: jtp.FloatLike = 0.5,
168
- **kwargs,
166
+ number_of_active_collidable_points_steady_state: jtp.IntLike = 1,
167
+ damping_ratio: jtp.FloatLike = 1.0,
168
+ max_penetration: jtp.FloatLike | None = None,
169
169
  ) -> jaxsim.rbda.contacts.ContactParamsTypes:
170
170
  """
171
171
  Estimate good contact parameters.
172
172
 
173
173
  Args:
174
174
  model: The model to consider.
175
+ standard_gravity: The standard gravity acceleration.
175
176
  static_friction_coefficient: The static friction coefficient.
176
- kwargs:
177
- Additional model-specific parameters passed to the builder method of
178
- the parameters class.
177
+ number_of_active_collidable_points_steady_state:
178
+ The number of active collidable points in steady state.
179
+ damping_ratio: The damping ratio.
180
+ max_penetration: The maximum penetration allowed.
179
181
 
180
182
  Returns:
181
183
  The estimated good contacts parameters.
@@ -190,20 +192,41 @@ def estimate_good_contact_parameters(
190
192
  specific application.
191
193
  """
192
194
 
193
- match model.contact_model:
195
+ def estimate_model_height(model: js.model.JaxSimModel) -> jtp.Float:
196
+ """
197
+ Displacement between the CoM and the lowest collidable point using zero
198
+ joint positions.
199
+ """
200
+
201
+ zero_data = js.data.JaxSimModelData.build(
202
+ model=model,
203
+ )
204
+
205
+ W_pz_CoM = js.com.com_position(model=model, data=zero_data)[2]
194
206
 
195
- case contacts.RelaxedRigidContacts():
196
- assert isinstance(model.contact_model, contacts.RelaxedRigidContacts)
207
+ if model.floating_base():
208
+ W_pz_C = collidable_point_positions(model=model, data=zero_data)[:, -1]
209
+ return 2 * (W_pz_CoM - W_pz_C.min())
197
210
 
198
- parameters = contacts.RelaxedRigidContactsParams.build(
199
- mu=static_friction_coefficient,
200
- **kwargs,
201
- )
211
+ return 2 * W_pz_CoM
202
212
 
203
- case _:
204
- raise ValueError(f"Invalid contact model: {model.contact_model}")
213
+ max_δ = (
214
+ max_penetration
215
+ if max_penetration is not None
216
+ # Consider as default a 0.5% of the model height.
217
+ else 0.005 * estimate_model_height(model=model)
218
+ )
205
219
 
206
- return parameters
220
+ nc = number_of_active_collidable_points_steady_state
221
+
222
+ return model.contact_model._parameters_class().build_default_from_jaxsim_model(
223
+ model=model,
224
+ standard_gravity=standard_gravity,
225
+ static_friction_coefficient=static_friction_coefficient,
226
+ max_penetration=max_δ,
227
+ number_of_active_collidable_points_steady_state=nc,
228
+ damping_ratio=damping_ratio,
229
+ )
207
230
 
208
231
 
209
232
  @jax.jit
@@ -244,7 +267,7 @@ def transforms(model: js.model.JaxSimModel, data: js.data.JaxSimModelData) -> jt
244
267
 
245
268
  # Build the link-to-point transform from the displacement between the link frame L
246
269
  # and the implicit contact frame C.
247
- L_H_C = jax.vmap(lambda L_p_C: jnp.eye(4).at[0:3, 3].set(L_p_C))(L_p_Ci)
270
+ L_H_C = jax.vmap(jnp.eye(4).at[0:3, 3].set)(L_p_Ci)
248
271
 
249
272
  # Compose the work-to-link and link-to-point transforms.
250
273
  return jax.vmap(lambda W_H_Li, L_H_Ci: W_H_Li @ L_H_Ci)(W_H_L, L_H_C)
@@ -504,3 +527,96 @@ def jacobian_derivative(
504
527
  )
505
528
 
506
529
  return O_J̇_WC
530
+
531
+
532
+ @jax.jit
533
+ @js.common.named_scope
534
+ def link_contact_forces(
535
+ model: js.model.JaxSimModel,
536
+ data: js.data.JaxSimModelData,
537
+ *,
538
+ link_forces: jtp.MatrixLike | None = None,
539
+ joint_torques: jtp.VectorLike | None = None,
540
+ ) -> tuple[jtp.Matrix, dict[str, jtp.Matrix]]:
541
+ """
542
+ Compute the 6D contact forces of all links of the model in inertial representation.
543
+
544
+ Args:
545
+ model: The model to consider.
546
+ data: The data of the considered model.
547
+ link_forces:
548
+ The 6D external forces to apply to the links expressed in inertial representation
549
+ joint_torques:
550
+ The joint torques acting on the joints.
551
+
552
+ Returns:
553
+ A `(nL, 6)` array containing the stacked 6D contact forces of the links,
554
+ expressed in inertial representation.
555
+ """
556
+
557
+ # Compute the contact forces for each collidable point with the active contact model.
558
+ W_f_C, aux_dict = model.contact_model.compute_contact_forces(
559
+ model=model,
560
+ data=data,
561
+ **(
562
+ dict(link_forces=link_forces, joint_force_references=joint_torques)
563
+ if not isinstance(model.contact_model, SoftContacts)
564
+ else {}
565
+ ),
566
+ )
567
+
568
+ # Compute the 6D forces applied to the links equivalent to the forces applied
569
+ # to the frames associated to the collidable points.
570
+ W_f_L = link_forces_from_contact_forces(model=model, contact_forces=W_f_C)
571
+
572
+ return W_f_L, aux_dict
573
+
574
+
575
+ @staticmethod
576
+ def link_forces_from_contact_forces(
577
+ model: js.model.JaxSimModel,
578
+ *,
579
+ contact_forces: jtp.MatrixLike,
580
+ ) -> jtp.Matrix:
581
+ """
582
+ Compute the link forces from the contact forces.
583
+
584
+ Args:
585
+ model: The robot model considered by the contact model.
586
+ contact_forces: The contact forces computed by the contact model.
587
+
588
+ Returns:
589
+ The 6D contact forces applied to the links and expressed in the frame of
590
+ the velocity representation of data.
591
+ """
592
+
593
+ # Get the object storing the contact parameters of the model.
594
+ contact_parameters = model.kin_dyn_parameters.contact_parameters
595
+
596
+ # Extract the indices corresponding to the enabled collidable points.
597
+ indices_of_enabled_collidable_points = (
598
+ contact_parameters.indices_of_enabled_collidable_points
599
+ )
600
+
601
+ # Convert the contact forces to a JAX array.
602
+ W_f_C = jnp.atleast_2d(jnp.array(contact_forces, dtype=float).squeeze())
603
+
604
+ # Construct the vector defining the parent link index of each collidable point.
605
+ # We use this vector to sum the 6D forces of all collidable points rigidly
606
+ # attached to the same link.
607
+ parent_link_index_of_collidable_points = jnp.array(
608
+ contact_parameters.body, dtype=int
609
+ )[indices_of_enabled_collidable_points]
610
+
611
+ # Create the mask that associate each collidable point to their parent link.
612
+ # We use this mask to sum the collidable points to the right link.
613
+ mask = parent_link_index_of_collidable_points[:, jnp.newaxis] == jnp.arange(
614
+ model.number_of_links()
615
+ )
616
+
617
+ # Sum the forces of all collidable points rigidly attached to a body.
618
+ # Since the contact forces W_f_C are expressed in the world frame,
619
+ # we don't need any coordinate transformation.
620
+ W_f_L = mask.T @ W_f_C
621
+
622
+ return W_f_L