jaxsim 0.4.3.dev312__tar.gz → 0.4.3.dev350__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (130)
  1. {jaxsim-0.4.3.dev312 → jaxsim-0.4.3.dev350}/PKG-INFO +1 -1
  2. {jaxsim-0.4.3.dev312 → jaxsim-0.4.3.dev350}/src/jaxsim/_version.py +2 -2
  3. {jaxsim-0.4.3.dev312 → jaxsim-0.4.3.dev350}/src/jaxsim/api/contact.py +65 -28
  4. {jaxsim-0.4.3.dev312 → jaxsim-0.4.3.dev350}/src/jaxsim/api/joint.py +8 -9
  5. {jaxsim-0.4.3.dev312 → jaxsim-0.4.3.dev350}/src/jaxsim/api/kin_dyn_parameters.py +9 -4
  6. {jaxsim-0.4.3.dev312 → jaxsim-0.4.3.dev350}/src/jaxsim/api/link.py +3 -4
  7. {jaxsim-0.4.3.dev312 → jaxsim-0.4.3.dev350}/src/jaxsim/api/model.py +21 -22
  8. {jaxsim-0.4.3.dev312 → jaxsim-0.4.3.dev350}/src/jaxsim/api/references.py +1 -1
  9. {jaxsim-0.4.3.dev312 → jaxsim-0.4.3.dev350}/src/jaxsim/integrators/common.py +2 -2
  10. {jaxsim-0.4.3.dev312 → jaxsim-0.4.3.dev350}/src/jaxsim/integrators/variable_step.py +6 -12
  11. {jaxsim-0.4.3.dev312 → jaxsim-0.4.3.dev350}/src/jaxsim/mujoco/loaders.py +9 -138
  12. jaxsim-0.4.3.dev350/src/jaxsim/mujoco/utils.py +223 -0
  13. {jaxsim-0.4.3.dev312 → jaxsim-0.4.3.dev350}/src/jaxsim/parsers/descriptions/joint.py +1 -26
  14. {jaxsim-0.4.3.dev312 → jaxsim-0.4.3.dev350}/src/jaxsim/parsers/kinematic_graph.py +3 -3
  15. {jaxsim-0.4.3.dev312 → jaxsim-0.4.3.dev350}/src/jaxsim/parsers/rod/parser.py +3 -6
  16. {jaxsim-0.4.3.dev312 → jaxsim-0.4.3.dev350}/src/jaxsim/parsers/rod/utils.py +1 -1
  17. {jaxsim-0.4.3.dev312 → jaxsim-0.4.3.dev350}/src/jaxsim/rbda/collidable_points.py +18 -5
  18. {jaxsim-0.4.3.dev312 → jaxsim-0.4.3.dev350}/src/jaxsim/rbda/contacts/common.py +11 -9
  19. {jaxsim-0.4.3.dev312 → jaxsim-0.4.3.dev350}/src/jaxsim/rbda/contacts/relaxed_rigid.py +14 -5
  20. {jaxsim-0.4.3.dev312 → jaxsim-0.4.3.dev350}/src/jaxsim/rbda/contacts/rigid.py +9 -6
  21. {jaxsim-0.4.3.dev312 → jaxsim-0.4.3.dev350}/src/jaxsim/rbda/contacts/soft.py +17 -4
  22. {jaxsim-0.4.3.dev312 → jaxsim-0.4.3.dev350}/src/jaxsim/rbda/jacobian.py +2 -2
  23. {jaxsim-0.4.3.dev312 → jaxsim-0.4.3.dev350}/src/jaxsim/rbda/utils.py +1 -1
  24. {jaxsim-0.4.3.dev312 → jaxsim-0.4.3.dev350}/src/jaxsim/terrain/terrain.py +9 -1
  25. {jaxsim-0.4.3.dev312 → jaxsim-0.4.3.dev350}/src/jaxsim/utils/tracing.py +3 -9
  26. {jaxsim-0.4.3.dev312 → jaxsim-0.4.3.dev350}/src/jaxsim/utils/wrappers.py +1 -1
  27. {jaxsim-0.4.3.dev312 → jaxsim-0.4.3.dev350}/src/jaxsim.egg-info/PKG-INFO +1 -1
  28. {jaxsim-0.4.3.dev312 → jaxsim-0.4.3.dev350}/tests/test_api_contact.py +30 -10
  29. {jaxsim-0.4.3.dev312 → jaxsim-0.4.3.dev350}/tests/test_api_data.py +5 -3
  30. {jaxsim-0.4.3.dev312 → jaxsim-0.4.3.dev350}/tests/test_api_joint.py +1 -1
  31. {jaxsim-0.4.3.dev312 → jaxsim-0.4.3.dev350}/tests/test_api_link.py +1 -1
  32. {jaxsim-0.4.3.dev312 → jaxsim-0.4.3.dev350}/tests/test_api_model.py +8 -6
  33. {jaxsim-0.4.3.dev312 → jaxsim-0.4.3.dev350}/tests/test_exceptions.py +10 -12
  34. {jaxsim-0.4.3.dev312 → jaxsim-0.4.3.dev350}/tests/test_pytree.py +6 -7
  35. {jaxsim-0.4.3.dev312 → jaxsim-0.4.3.dev350}/tests/test_simulations.py +34 -0
  36. jaxsim-0.4.3.dev312/src/jaxsim/mujoco/utils.py +0 -101
  37. {jaxsim-0.4.3.dev312 → jaxsim-0.4.3.dev350}/.devcontainer/Dockerfile +0 -0
  38. {jaxsim-0.4.3.dev312 → jaxsim-0.4.3.dev350}/.devcontainer/devcontainer.json +0 -0
  39. {jaxsim-0.4.3.dev312 → jaxsim-0.4.3.dev350}/.gitattributes +0 -0
  40. {jaxsim-0.4.3.dev312 → jaxsim-0.4.3.dev350}/.github/CODEOWNERS +0 -0
  41. {jaxsim-0.4.3.dev312 → jaxsim-0.4.3.dev350}/.github/dependabot.yml +0 -0
  42. {jaxsim-0.4.3.dev312 → jaxsim-0.4.3.dev350}/.github/workflows/ci_cd.yml +0 -0
  43. {jaxsim-0.4.3.dev312 → jaxsim-0.4.3.dev350}/.github/workflows/pixi.yml +0 -0
  44. {jaxsim-0.4.3.dev312 → jaxsim-0.4.3.dev350}/.github/workflows/read_the_docs.yml +0 -0
  45. {jaxsim-0.4.3.dev312 → jaxsim-0.4.3.dev350}/.gitignore +0 -0
  46. {jaxsim-0.4.3.dev312 → jaxsim-0.4.3.dev350}/.pre-commit-config.yaml +0 -0
  47. {jaxsim-0.4.3.dev312 → jaxsim-0.4.3.dev350}/.readthedocs.yaml +0 -0
  48. {jaxsim-0.4.3.dev312 → jaxsim-0.4.3.dev350}/CONTRIBUTING.md +0 -0
  49. {jaxsim-0.4.3.dev312 → jaxsim-0.4.3.dev350}/LICENSE +0 -0
  50. {jaxsim-0.4.3.dev312 → jaxsim-0.4.3.dev350}/README.md +0 -0
  51. {jaxsim-0.4.3.dev312 → jaxsim-0.4.3.dev350}/docs/Makefile +0 -0
  52. {jaxsim-0.4.3.dev312 → jaxsim-0.4.3.dev350}/docs/conf.py +0 -0
  53. {jaxsim-0.4.3.dev312 → jaxsim-0.4.3.dev350}/docs/examples.rst +0 -0
  54. {jaxsim-0.4.3.dev312 → jaxsim-0.4.3.dev350}/docs/guide/install.rst +0 -0
  55. {jaxsim-0.4.3.dev312 → jaxsim-0.4.3.dev350}/docs/index.rst +0 -0
  56. {jaxsim-0.4.3.dev312 → jaxsim-0.4.3.dev350}/docs/make.bat +0 -0
  57. {jaxsim-0.4.3.dev312 → jaxsim-0.4.3.dev350}/docs/modules/api.rst +0 -0
  58. {jaxsim-0.4.3.dev312 → jaxsim-0.4.3.dev350}/docs/modules/integrators.rst +0 -0
  59. {jaxsim-0.4.3.dev312 → jaxsim-0.4.3.dev350}/docs/modules/math.rst +0 -0
  60. {jaxsim-0.4.3.dev312 → jaxsim-0.4.3.dev350}/docs/modules/mujoco.rst +0 -0
  61. {jaxsim-0.4.3.dev312 → jaxsim-0.4.3.dev350}/docs/modules/parsers.rst +0 -0
  62. {jaxsim-0.4.3.dev312 → jaxsim-0.4.3.dev350}/docs/modules/rbda.rst +0 -0
  63. {jaxsim-0.4.3.dev312 → jaxsim-0.4.3.dev350}/docs/modules/typing.rst +0 -0
  64. {jaxsim-0.4.3.dev312 → jaxsim-0.4.3.dev350}/docs/modules/utils.rst +0 -0
  65. {jaxsim-0.4.3.dev312 → jaxsim-0.4.3.dev350}/environment.yml +0 -0
  66. {jaxsim-0.4.3.dev312 → jaxsim-0.4.3.dev350}/examples/.gitattributes +0 -0
  67. {jaxsim-0.4.3.dev312 → jaxsim-0.4.3.dev350}/examples/.gitignore +0 -0
  68. {jaxsim-0.4.3.dev312 → jaxsim-0.4.3.dev350}/examples/README.md +0 -0
  69. {jaxsim-0.4.3.dev312 → jaxsim-0.4.3.dev350}/examples/assets/build_cartpole_urdf.py +0 -0
  70. {jaxsim-0.4.3.dev312 → jaxsim-0.4.3.dev350}/examples/assets/cartpole.urdf +0 -0
  71. {jaxsim-0.4.3.dev312 → jaxsim-0.4.3.dev350}/examples/jaxsim_as_multibody_dynamics_library.ipynb +0 -0
  72. {jaxsim-0.4.3.dev312 → jaxsim-0.4.3.dev350}/examples/jaxsim_as_physics_engine.ipynb +0 -0
  73. {jaxsim-0.4.3.dev312 → jaxsim-0.4.3.dev350}/examples/jaxsim_for_robot_controllers.ipynb +0 -0
  74. {jaxsim-0.4.3.dev312 → jaxsim-0.4.3.dev350}/pixi.lock +0 -0
  75. {jaxsim-0.4.3.dev312 → jaxsim-0.4.3.dev350}/pyproject.toml +0 -0
  76. {jaxsim-0.4.3.dev312 → jaxsim-0.4.3.dev350}/setup.cfg +0 -0
  77. {jaxsim-0.4.3.dev312 → jaxsim-0.4.3.dev350}/setup.py +0 -0
  78. {jaxsim-0.4.3.dev312 → jaxsim-0.4.3.dev350}/src/jaxsim/__init__.py +0 -0
  79. {jaxsim-0.4.3.dev312 → jaxsim-0.4.3.dev350}/src/jaxsim/api/__init__.py +0 -0
  80. {jaxsim-0.4.3.dev312 → jaxsim-0.4.3.dev350}/src/jaxsim/api/com.py +0 -0
  81. {jaxsim-0.4.3.dev312 → jaxsim-0.4.3.dev350}/src/jaxsim/api/common.py +0 -0
  82. {jaxsim-0.4.3.dev312 → jaxsim-0.4.3.dev350}/src/jaxsim/api/data.py +0 -0
  83. {jaxsim-0.4.3.dev312 → jaxsim-0.4.3.dev350}/src/jaxsim/api/frame.py +0 -0
  84. {jaxsim-0.4.3.dev312 → jaxsim-0.4.3.dev350}/src/jaxsim/api/ode.py +0 -0
  85. {jaxsim-0.4.3.dev312 → jaxsim-0.4.3.dev350}/src/jaxsim/api/ode_data.py +0 -0
  86. {jaxsim-0.4.3.dev312 → jaxsim-0.4.3.dev350}/src/jaxsim/exceptions.py +0 -0
  87. {jaxsim-0.4.3.dev312 → jaxsim-0.4.3.dev350}/src/jaxsim/integrators/__init__.py +0 -0
  88. {jaxsim-0.4.3.dev312 → jaxsim-0.4.3.dev350}/src/jaxsim/integrators/fixed_step.py +0 -0
  89. {jaxsim-0.4.3.dev312 → jaxsim-0.4.3.dev350}/src/jaxsim/logging.py +0 -0
  90. {jaxsim-0.4.3.dev312 → jaxsim-0.4.3.dev350}/src/jaxsim/math/__init__.py +0 -0
  91. {jaxsim-0.4.3.dev312 → jaxsim-0.4.3.dev350}/src/jaxsim/math/adjoint.py +0 -0
  92. {jaxsim-0.4.3.dev312 → jaxsim-0.4.3.dev350}/src/jaxsim/math/cross.py +0 -0
  93. {jaxsim-0.4.3.dev312 → jaxsim-0.4.3.dev350}/src/jaxsim/math/inertia.py +0 -0
  94. {jaxsim-0.4.3.dev312 → jaxsim-0.4.3.dev350}/src/jaxsim/math/joint_model.py +0 -0
  95. {jaxsim-0.4.3.dev312 → jaxsim-0.4.3.dev350}/src/jaxsim/math/quaternion.py +0 -0
  96. {jaxsim-0.4.3.dev312 → jaxsim-0.4.3.dev350}/src/jaxsim/math/rotation.py +0 -0
  97. {jaxsim-0.4.3.dev312 → jaxsim-0.4.3.dev350}/src/jaxsim/math/skew.py +0 -0
  98. {jaxsim-0.4.3.dev312 → jaxsim-0.4.3.dev350}/src/jaxsim/math/transform.py +0 -0
  99. {jaxsim-0.4.3.dev312 → jaxsim-0.4.3.dev350}/src/jaxsim/mujoco/__init__.py +0 -0
  100. {jaxsim-0.4.3.dev312 → jaxsim-0.4.3.dev350}/src/jaxsim/mujoco/__main__.py +0 -0
  101. {jaxsim-0.4.3.dev312 → jaxsim-0.4.3.dev350}/src/jaxsim/mujoco/model.py +0 -0
  102. {jaxsim-0.4.3.dev312 → jaxsim-0.4.3.dev350}/src/jaxsim/mujoco/visualizer.py +0 -0
  103. {jaxsim-0.4.3.dev312 → jaxsim-0.4.3.dev350}/src/jaxsim/parsers/__init__.py +0 -0
  104. {jaxsim-0.4.3.dev312 → jaxsim-0.4.3.dev350}/src/jaxsim/parsers/descriptions/__init__.py +0 -0
  105. {jaxsim-0.4.3.dev312 → jaxsim-0.4.3.dev350}/src/jaxsim/parsers/descriptions/collision.py +0 -0
  106. {jaxsim-0.4.3.dev312 → jaxsim-0.4.3.dev350}/src/jaxsim/parsers/descriptions/link.py +0 -0
  107. {jaxsim-0.4.3.dev312 → jaxsim-0.4.3.dev350}/src/jaxsim/parsers/descriptions/model.py +0 -0
  108. {jaxsim-0.4.3.dev312 → jaxsim-0.4.3.dev350}/src/jaxsim/parsers/rod/__init__.py +0 -0
  109. {jaxsim-0.4.3.dev312 → jaxsim-0.4.3.dev350}/src/jaxsim/rbda/__init__.py +0 -0
  110. {jaxsim-0.4.3.dev312 → jaxsim-0.4.3.dev350}/src/jaxsim/rbda/aba.py +0 -0
  111. {jaxsim-0.4.3.dev312 → jaxsim-0.4.3.dev350}/src/jaxsim/rbda/contacts/__init__.py +0 -0
  112. {jaxsim-0.4.3.dev312 → jaxsim-0.4.3.dev350}/src/jaxsim/rbda/contacts/visco_elastic.py +0 -0
  113. {jaxsim-0.4.3.dev312 → jaxsim-0.4.3.dev350}/src/jaxsim/rbda/crba.py +0 -0
  114. {jaxsim-0.4.3.dev312 → jaxsim-0.4.3.dev350}/src/jaxsim/rbda/forward_kinematics.py +0 -0
  115. {jaxsim-0.4.3.dev312 → jaxsim-0.4.3.dev350}/src/jaxsim/rbda/rnea.py +0 -0
  116. {jaxsim-0.4.3.dev312 → jaxsim-0.4.3.dev350}/src/jaxsim/terrain/__init__.py +0 -0
  117. {jaxsim-0.4.3.dev312 → jaxsim-0.4.3.dev350}/src/jaxsim/typing.py +0 -0
  118. {jaxsim-0.4.3.dev312 → jaxsim-0.4.3.dev350}/src/jaxsim/utils/__init__.py +0 -0
  119. {jaxsim-0.4.3.dev312 → jaxsim-0.4.3.dev350}/src/jaxsim/utils/jaxsim_dataclass.py +0 -0
  120. {jaxsim-0.4.3.dev312 → jaxsim-0.4.3.dev350}/src/jaxsim.egg-info/SOURCES.txt +0 -0
  121. {jaxsim-0.4.3.dev312 → jaxsim-0.4.3.dev350}/src/jaxsim.egg-info/dependency_links.txt +0 -0
  122. {jaxsim-0.4.3.dev312 → jaxsim-0.4.3.dev350}/src/jaxsim.egg-info/requires.txt +0 -0
  123. {jaxsim-0.4.3.dev312 → jaxsim-0.4.3.dev350}/src/jaxsim.egg-info/top_level.txt +0 -0
  124. {jaxsim-0.4.3.dev312 → jaxsim-0.4.3.dev350}/tests/__init__.py +0 -0
  125. {jaxsim-0.4.3.dev312 → jaxsim-0.4.3.dev350}/tests/conftest.py +0 -0
  126. {jaxsim-0.4.3.dev312 → jaxsim-0.4.3.dev350}/tests/test_api_com.py +0 -0
  127. {jaxsim-0.4.3.dev312 → jaxsim-0.4.3.dev350}/tests/test_api_frame.py +0 -0
  128. {jaxsim-0.4.3.dev312 → jaxsim-0.4.3.dev350}/tests/test_automatic_differentiation.py +0 -0
  129. {jaxsim-0.4.3.dev312 → jaxsim-0.4.3.dev350}/tests/test_contact.py +0 -0
  130. {jaxsim-0.4.3.dev312 → jaxsim-0.4.3.dev350}/tests/utils_idyntree.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: jaxsim
3
- Version: 0.4.3.dev312
3
+ Version: 0.4.3.dev350
4
4
  Summary: A differentiable physics engine and multibody dynamics library for control and robot learning.
5
5
  Author-email: Diego Ferigo <dgferigo@gmail.com>, Filippo Luca Ferretti <filippoluca.ferretti@outlook.com>
6
6
  Maintainer-email: Filippo Luca Ferretti <filippo.ferretti@iit.it>, Alessandro Croci <alessandro.croci@iit.it>
@@ -12,5 +12,5 @@ __version__: str
12
12
  __version_tuple__: VERSION_TUPLE
13
13
  version_tuple: VERSION_TUPLE
14
14
 
15
- __version__ = version = '0.4.3.dev312'
16
- __version_tuple__ = version_tuple = (0, 4, 3, 'dev312')
15
+ __version__ = version = '0.4.3.dev350'
16
+ __version_tuple__ = version_tuple = (0, 4, 3, 'dev350')
@@ -138,7 +138,7 @@ def collidable_point_dynamics(
138
138
  **kwargs,
139
139
  ) -> tuple[jtp.Matrix, dict[str, jtp.PyTree]]:
140
140
  r"""
141
- Compute the 6D force applied to each collidable point.
141
+ Compute the 6D force applied to each enabled collidable point.
142
142
 
143
143
  Args:
144
144
  model: The model to consider.
@@ -151,7 +151,7 @@ def collidable_point_dynamics(
151
151
  kwargs: Additional keyword arguments to pass to the active contact model.
152
152
 
153
153
  Returns:
154
- The 6D force applied to each collidable point and additional data based
154
+ The 6D force applied to each eneabled collidable point and additional data based
155
155
  on the contact model configured:
156
156
  - Soft: the material deformation rate.
157
157
  - Rigid: no additional data.
@@ -199,15 +199,19 @@ def collidable_point_dynamics(
199
199
  )
200
200
 
201
201
  # Compute the transforms of the implicit frames `C[L] = (W_p_C, [L])`
202
- # associated to each collidable point.
202
+ # associated to the enabled collidable point.
203
203
  # In inertial-fixed representation, the computation of these transforms
204
204
  # is not necessary and the conversion below becomes a no-op.
205
+
206
+ # Get the indices of the enabled collidable points.
207
+ indices_of_enabled_collidable_points = (
208
+ model.kin_dyn_parameters.contact_parameters.indices_of_enabled_collidable_points
209
+ )
210
+
205
211
  W_H_C = (
206
212
  js.contact.transforms(model=model, data=data)
207
213
  if data.velocity_representation is not VelRepr.Inertial
208
- else jnp.zeros(
209
- shape=(len(model.kin_dyn_parameters.contact_parameters.body), 4, 4)
210
- )
214
+ else jnp.zeros(shape=(len(indices_of_enabled_collidable_points), 4, 4))
211
215
  )
212
216
 
213
217
  # Convert the 6D forces to the active representation.
@@ -246,6 +250,15 @@ def in_contact(
246
250
  if link_names is not None and set(link_names).difference(model.link_names()):
247
251
  raise ValueError("One or more link names are not part of the model")
248
252
 
253
+ # Get the indices of the enabled collidable points.
254
+ indices_of_enabled_collidable_points = (
255
+ model.kin_dyn_parameters.contact_parameters.indices_of_enabled_collidable_points
256
+ )
257
+
258
+ parent_link_idx_of_enabled_collidable_points = jnp.array(
259
+ model.kin_dyn_parameters.contact_parameters.body, dtype=int
260
+ )[indices_of_enabled_collidable_points]
261
+
249
262
  W_p_Ci = collidable_point_positions(model=model, data=data)
250
263
 
251
264
  terrain_height = jax.vmap(lambda x, y: model.terrain.height(x=x, y=y))(
@@ -262,7 +275,7 @@ def in_contact(
262
275
 
263
276
  links_in_contact = jax.vmap(
264
277
  lambda link_index: jnp.where(
265
- jnp.array(model.kin_dyn_parameters.contact_parameters.body) == link_index,
278
+ parent_link_idx_of_enabled_collidable_points == link_index,
266
279
  below_terrain,
267
280
  jnp.zeros_like(below_terrain, dtype=bool),
268
281
  ).any()
@@ -426,14 +439,14 @@ def estimate_good_contact_parameters(
426
439
  @jax.jit
427
440
  def transforms(model: js.model.JaxSimModel, data: js.data.JaxSimModelData) -> jtp.Array:
428
441
  r"""
429
- Return the pose of the collidable points.
442
+ Return the pose of the enabled collidable points.
430
443
 
431
444
  Args:
432
445
  model: The model to consider.
433
446
  data: The data of the considered model.
434
447
 
435
448
  Returns:
436
- The stacked SE(3) matrices of all collidable points.
449
+ The stacked SE(3) matrices of all enabled collidable points.
437
450
 
438
451
  Note:
439
452
  Each collidable point is implicitly associated with a frame
@@ -442,16 +455,27 @@ def transforms(model: js.model.JaxSimModel, data: js.data.JaxSimModelData) -> jt
442
455
  rigidly attached to.
443
456
  """
444
457
 
458
+ # Get the indices of the enabled collidable points.
459
+ indices_of_enabled_collidable_points = (
460
+ model.kin_dyn_parameters.contact_parameters.indices_of_enabled_collidable_points
461
+ )
462
+
463
+ parent_link_idx_of_enabled_collidable_points = jnp.array(
464
+ model.kin_dyn_parameters.contact_parameters.body, dtype=int
465
+ )[indices_of_enabled_collidable_points]
466
+
445
467
  # Get the transforms of the parent link of all collidable points.
446
468
  W_H_L = js.model.forward_kinematics(model=model, data=data)[
447
- jnp.array(model.kin_dyn_parameters.contact_parameters.body, dtype=int)
469
+ parent_link_idx_of_enabled_collidable_points
470
+ ]
471
+
472
+ L_p_Ci = model.kin_dyn_parameters.contact_parameters.point[
473
+ indices_of_enabled_collidable_points
448
474
  ]
449
475
 
450
476
  # Build the link-to-point transform from the displacement between the link frame L
451
477
  # and the implicit contact frame C.
452
- L_H_C = jax.vmap(lambda L_p_C: jnp.eye(4).at[0:3, 3].set(L_p_C))(
453
- model.kin_dyn_parameters.contact_parameters.point
454
- )
478
+ L_H_C = jax.vmap(lambda L_p_C: jnp.eye(4).at[0:3, 3].set(L_p_C))(L_p_Ci)
455
479
 
456
480
  # Compose the work-to-link and link-to-point transforms.
457
481
  return jax.vmap(lambda W_H_Li, L_H_Ci: W_H_Li @ L_H_Ci)(W_H_L, L_H_C)
@@ -465,7 +489,7 @@ def jacobian(
465
489
  output_vel_repr: VelRepr | None = None,
466
490
  ) -> jtp.Array:
467
491
  r"""
468
- Return the free-floating Jacobian of the collidable points.
492
+ Return the free-floating Jacobian of the enabled collidable points.
469
493
 
470
494
  Args:
471
495
  model: The model to consider.
@@ -475,7 +499,7 @@ def jacobian(
475
499
 
476
500
  Returns:
477
501
  The stacked :math:`6 \times (6+n)` free-floating jacobians of the frames associated to the
478
- collidable points.
502
+ enabled collidable points.
479
503
 
480
504
  Note:
481
505
  Each collidable point is implicitly associated with a frame
@@ -488,6 +512,15 @@ def jacobian(
488
512
  output_vel_repr if output_vel_repr is not None else data.velocity_representation
489
513
  )
490
514
 
515
+ # Get the indices of the enabled collidable points.
516
+ indices_of_enabled_collidable_points = (
517
+ model.kin_dyn_parameters.contact_parameters.indices_of_enabled_collidable_points
518
+ )
519
+
520
+ parent_link_idx_of_enabled_collidable_points = jnp.array(
521
+ model.kin_dyn_parameters.contact_parameters.body, dtype=int
522
+ )[indices_of_enabled_collidable_points]
523
+
491
524
  # Compute the Jacobians of all links.
492
525
  W_J_WL = js.model.generalized_free_floating_jacobian(
493
526
  model=model, data=data, output_vel_repr=VelRepr.Inertial
@@ -496,9 +529,7 @@ def jacobian(
496
529
  # Compute the contact Jacobian.
497
530
  # In inertial-fixed output representation, the Jacobian of the parent link is also
498
531
  # the Jacobian of the frame C implicitly associated with the collidable point.
499
- W_J_WC = W_J_WL[
500
- jnp.array(model.kin_dyn_parameters.contact_parameters.body, dtype=int)
501
- ]
532
+ W_J_WC = W_J_WL[parent_link_idx_of_enabled_collidable_points]
502
533
 
503
534
  # Adjust the output representation.
504
535
  match output_vel_repr:
@@ -550,7 +581,7 @@ def jacobian_derivative(
550
581
  output_vel_repr: VelRepr | None = None,
551
582
  ) -> jtp.Matrix:
552
583
  r"""
553
- Compute the derivative of the free-floating jacobian of the contact points.
584
+ Compute the derivative of the free-floating jacobian of the enabled collidable points.
554
585
 
555
586
  Args:
556
587
  model: The model to consider.
@@ -559,7 +590,7 @@ def jacobian_derivative(
559
590
  The output velocity representation of the free-floating jacobian derivative.
560
591
 
561
592
  Returns:
562
- The derivative of the :math:`6 \times (6+n)` free-floating jacobian of the contact points.
593
+ The derivative of the :math:`6 \times (6+n)` free-floating jacobian of the enabled collidable points.
563
594
 
564
595
  Note:
565
596
  The input representation of the free-floating jacobian derivative is the active
@@ -570,10 +601,18 @@ def jacobian_derivative(
570
601
  output_vel_repr if output_vel_repr is not None else data.velocity_representation
571
602
  )
572
603
 
604
+ indices_of_enabled_collidable_points = (
605
+ model.kin_dyn_parameters.contact_parameters.indices_of_enabled_collidable_points
606
+ )
607
+
573
608
  # Get the index of the parent link and the position of the collidable point.
574
- parent_link_idxs = jnp.array(model.kin_dyn_parameters.contact_parameters.body)
575
- L_p_Ci = jnp.array(model.kin_dyn_parameters.contact_parameters.point)
576
- contact_idxs = jnp.arange(L_p_Ci.shape[0])
609
+ parent_link_idx_of_enabled_collidable_points = jnp.array(
610
+ model.kin_dyn_parameters.contact_parameters.body, dtype=int
611
+ )[indices_of_enabled_collidable_points]
612
+
613
+ L_p_Ci = model.kin_dyn_parameters.contact_parameters.point[
614
+ indices_of_enabled_collidable_points
615
+ ]
577
616
 
578
617
  # Get the transforms of all the parent links.
579
618
  W_H_Li = js.model.forward_kinematics(model=model, data=data)
@@ -646,7 +685,7 @@ def jacobian_derivative(
646
685
  output_vel_repr=VelRepr.Inertial,
647
686
  )
648
687
 
649
- # Get the Jacobian of the collidable points in the mixed representation.
688
+ # Get the Jacobian of the enabled collidable points in the mixed representation.
650
689
  with data.switch_velocity_representation(VelRepr.Mixed):
651
690
  CW_J_WC_BW = jacobian(
652
691
  model=model,
@@ -656,13 +695,11 @@ def jacobian_derivative(
656
695
 
657
696
  def compute_O_J̇_WC_I(
658
697
  L_p_C: jtp.Vector,
659
- contact_idx: jtp.Int,
698
+ parent_link_idx: jtp.Int,
660
699
  CW_J_WC_BW: jtp.Matrix,
661
700
  W_H_L: jtp.Matrix,
662
701
  ) -> jtp.Matrix:
663
702
 
664
- parent_link_idx = parent_link_idxs[contact_idx]
665
-
666
703
  match output_vel_repr:
667
704
  case VelRepr.Inertial:
668
705
  O_X_W = W_X_W = Adjoint.from_transform( # noqa: F841
@@ -703,7 +740,7 @@ def jacobian_derivative(
703
740
  return O_J̇_WC_I
704
741
 
705
742
  O_J̇_WC = jax.vmap(compute_O_J̇_WC_I, in_axes=(0, 0, 0, None))(
706
- L_p_Ci, contact_idxs, CW_J_WC_BW, W_H_Li
743
+ L_p_Ci, parent_link_idx_of_enabled_collidable_points, CW_J_WC_BW, W_H_Li
707
744
  )
708
745
 
709
746
  return O_J̇_WC
@@ -53,9 +53,7 @@ def idx_to_name(model: js.model.JaxSimModel, *, joint_index: jtp.IntLike) -> str
53
53
  """
54
54
 
55
55
  exceptions.raise_value_error_if(
56
- condition=jnp.array(
57
- [joint_index < 0, joint_index >= model.number_of_joints()]
58
- ).any(),
56
+ condition=joint_index < 0,
59
57
  msg="Invalid joint index '{idx}'",
60
58
  idx=joint_index,
61
59
  )
@@ -123,10 +121,7 @@ def position_limit(
123
121
  """
124
122
 
125
123
  if model.number_of_joints() == 0:
126
- s_min = model.kin_dyn_parameters.joint_parameters.position_limits_min
127
- s_max = model.kin_dyn_parameters.joint_parameters.position_limits_max
128
-
129
- return jnp.atleast_1d(s_min).astype(float), jnp.atleast_1d(s_max).astype(float)
124
+ return jnp.empty(0).astype(float), jnp.empty(0).astype(float)
130
125
 
131
126
  exceptions.raise_value_error_if(
132
127
  condition=jnp.array(
@@ -136,8 +131,12 @@ def position_limit(
136
131
  idx=joint_index,
137
132
  )
138
133
 
139
- s_min = model.kin_dyn_parameters.joint_parameters.position_limits_min[joint_index]
140
- s_max = model.kin_dyn_parameters.joint_parameters.position_limits_max[joint_index]
134
+ s_min = jnp.atleast_1d(
135
+ model.kin_dyn_parameters.joint_parameters.position_limits_min
136
+ )[joint_index]
137
+ s_max = jnp.atleast_1d(
138
+ model.kin_dyn_parameters.joint_parameters.position_limits_max
139
+ )[joint_index]
141
140
 
142
141
  return s_min.astype(float), s_max.astype(float)
143
142
 
@@ -438,7 +438,9 @@ class KynDynParameters(JaxsimDataclass):
438
438
  # Helpers to update parameters
439
439
  # ============================
440
440
 
441
- def set_link_mass(self, link_index: int, mass: jtp.FloatLike) -> KynDynParameters:
441
+ def set_link_mass(
442
+ self, link_index: jtp.IntLike, mass: jtp.FloatLike
443
+ ) -> KynDynParameters:
442
444
  """
443
445
  Set the mass of a link.
444
446
 
@@ -457,7 +459,7 @@ class KynDynParameters(JaxsimDataclass):
457
459
  return self.replace(link_parameters=link_parameters)
458
460
 
459
461
  def set_link_inertia(
460
- self, link_index: int, inertia: jtp.MatrixLike
462
+ self, link_index: jtp.IntLike, inertia: jtp.MatrixLike
461
463
  ) -> KynDynParameters:
462
464
  r"""
463
465
  Set the inertia tensor of a link.
@@ -593,10 +595,10 @@ class LinkParameters(JaxsimDataclass):
593
595
  """
594
596
 
595
597
  # Extract the link parameters from the 6D spatial inertia.
596
- m, L_p_CoM, I = Inertia.to_params(M=M)
598
+ m, L_p_CoM, I_CoM = Inertia.to_params(M=M)
597
599
 
598
600
  # Extract only the necessary elements of the inertia tensor.
599
- inertia_elements = I[jnp.triu_indices(3)]
601
+ inertia_elements = I_CoM[jnp.triu_indices(3)]
600
602
 
601
603
  return LinkParameters(
602
604
  index=jnp.array(index).squeeze().astype(int),
@@ -743,6 +745,9 @@ class ContactParameters(JaxsimDataclass):
743
745
  point:
744
746
  The translations between the link frame and the collidable point, expressed
745
747
  in the coordinates of the parent link frame.
748
+ enabled:
749
+ A tuple of booleans representing, for each collidable point, whether it is
750
+ enabled or not in contact models.
746
751
 
747
752
  Note:
748
753
  Contrarily to LinkParameters and JointParameters, this class is not meant
@@ -4,6 +4,7 @@ from collections.abc import Sequence
4
4
  import jax
5
5
  import jax.numpy as jnp
6
6
  import jax.scipy.linalg
7
+ import numpy as np
7
8
 
8
9
  import jaxsim.api as js
9
10
  import jaxsim.rbda
@@ -54,9 +55,7 @@ def idx_to_name(model: js.model.JaxSimModel, *, link_index: jtp.IntLike) -> str:
54
55
  """
55
56
 
56
57
  exceptions.raise_value_error_if(
57
- condition=jnp.array(
58
- [link_index < 0, link_index >= model.number_of_links()]
59
- ).any(),
58
+ condition=link_index < 0,
60
59
  msg="Invalid link index '{idx}'",
61
60
  idx=link_index,
62
61
  )
@@ -98,7 +97,7 @@ def idxs_to_names(
98
97
  The names of the links.
99
98
  """
100
99
 
101
- return tuple(idx_to_name(model=model, link_index=idx) for idx in link_indices)
100
+ return tuple(np.array(model.kin_dyn_parameters.link_names)[list(link_indices)])
102
101
 
103
102
 
104
103
  # =========
@@ -304,7 +304,7 @@ class JaxSimModel(JaxsimDataclass):
304
304
 
305
305
  return self.model_name
306
306
 
307
- def number_of_links(self) -> jtp.Int:
307
+ def number_of_links(self) -> int:
308
308
  """
309
309
  Return the number of links in the model.
310
310
 
@@ -317,7 +317,7 @@ class JaxSimModel(JaxsimDataclass):
317
317
 
318
318
  return self.kin_dyn_parameters.number_of_links()
319
319
 
320
- def number_of_joints(self) -> jtp.Int:
320
+ def number_of_joints(self) -> int:
321
321
  """
322
322
  Return the number of joints in the model.
323
323
 
@@ -419,7 +419,7 @@ class JaxSimModel(JaxsimDataclass):
419
419
  def reduce(
420
420
  model: JaxSimModel,
421
421
  considered_joints: tuple[str, ...],
422
- locked_joint_positions: dict[str, jtp.Float] | None = None,
422
+ locked_joint_positions: dict[str, jtp.FloatLike] | None = None,
423
423
  ) -> JaxSimModel:
424
424
  """
425
425
  Reduce the model by lumping together the links connected by removed joints.
@@ -1038,12 +1038,7 @@ def forward_dynamics_aba(
1038
1038
  C_v̇_WB = to_active(
1039
1039
  W_v̇_WB=W_v̇_WB,
1040
1040
  W_H_C=W_H_C,
1041
- W_v_WB=jnp.hstack(
1042
- [
1043
- data.state.physics_model.base_linear_velocity,
1044
- data.state.physics_model.base_angular_velocity,
1045
- ]
1046
- ),
1041
+ W_v_WB=W_v_WB,
1047
1042
  W_v_WC=W_v_WC,
1048
1043
  )
1049
1044
 
@@ -2274,20 +2269,23 @@ def step(
2274
2269
  # Raise runtime error for not supported case in which Rigid contacts and
2275
2270
  # Baumgarte stabilization are enabled and used with ForwardEuler integrator.
2276
2271
  jaxsim.exceptions.raise_runtime_error_if(
2277
- condition=jnp.logical_and(
2278
- isinstance(
2279
- integrator,
2280
- jaxsim.integrators.fixed_step.ForwardEuler
2281
- | jaxsim.integrators.fixed_step.ForwardEulerSO3,
2282
- ),
2283
- jnp.array(
2284
- [data_tf.contacts_params.K, data_tf.contacts_params.D]
2285
- ).any(),
2286
- ),
2272
+ condition=isinstance(
2273
+ integrator,
2274
+ jaxsim.integrators.fixed_step.ForwardEuler
2275
+ | jaxsim.integrators.fixed_step.ForwardEulerSO3,
2276
+ )
2277
+ & ((data_tf.contacts_params.K > 0) | (data_tf.contacts_params.D > 0)),
2287
2278
  msg="Baumgarte stabilization is not supported with ForwardEuler integrators",
2288
2279
  )
2289
2280
 
2290
- W_p_C = js.contact.collidable_point_positions(model, data_tf)
2281
+ # Extract the indices corresponding to the enabled collidable points.
2282
+ indices_of_enabled_collidable_points = (
2283
+ model.kin_dyn_parameters.contact_parameters.indices_of_enabled_collidable_points
2284
+ )
2285
+
2286
+ W_p_C = js.contact.collidable_point_positions(model, data_tf)[
2287
+ indices_of_enabled_collidable_points
2288
+ ]
2291
2289
 
2292
2290
  # Compute the penetration depth of the collidable points.
2293
2291
  δ, *_ = jax.vmap(
@@ -2296,8 +2294,9 @@ def step(
2296
2294
  )(W_p_C, jnp.zeros_like(W_p_C), model.terrain)
2297
2295
 
2298
2296
  with data_tf.switch_velocity_representation(VelRepr.Mixed):
2299
-
2300
- J_WC = js.contact.jacobian(model, data_tf)
2297
+ J_WC = js.contact.jacobian(model, data_tf)[
2298
+ indices_of_enabled_collidable_points
2299
+ ]
2301
2300
  M = js.model.free_floating_mass_matrix(model, data_tf)
2302
2301
 
2303
2302
  # Compute the impact velocity.
@@ -503,7 +503,7 @@ class JaxSimModelReferences(js.common.ModelDataWithVelocityRepresentation):
503
503
  ]
504
504
 
505
505
  exceptions.raise_value_error_if(
506
- condition=jnp.logical_not(data.valid(model=model)),
506
+ condition=~data.valid(model=model),
507
507
  msg="The provided data is not valid for the model",
508
508
  )
509
509
  W_H_Fi = jax.vmap(
@@ -319,7 +319,7 @@ class ExplicitRungeKutta(Integrator[PyTreeType, PyTreeType], Generic[PyTreeType]
319
319
  f = lambda x, t: self.dynamics(x=x, t=t, **kwargs)
320
320
 
321
321
  # Initialize the carry of the for loop with the stacked kᵢ vectors.
322
- carry0 = jax.tree_map(
322
+ carry0 = jax.tree.map(
323
323
  lambda l: jnp.zeros((c.size, *l.shape), dtype=l.dtype), x0
324
324
  )
325
325
 
@@ -507,7 +507,7 @@ class ExplicitRungeKuttaSO3Mixin:
507
507
 
508
508
  # We assume that the initial quaternion is already unary.
509
509
  exceptions.raise_runtime_error_if(
510
- condition=jnp.logical_not(jnp.allclose(W_Q_B_t0.dot(W_Q_B_t0), 1.0)),
510
+ condition=~jnp.allclose(W_Q_B_t0.dot(W_Q_B_t0), 1.0),
511
511
  msg="The SO(3) integrator received a quaternion at t0 that is not unary.",
512
512
  )
513
513
 
@@ -152,7 +152,7 @@ def compute_pytree_scale(
152
152
  """
153
153
 
154
154
  # Consider a zero second pytree, if not given.
155
- x2 = jax.tree.map(lambda l: jnp.zeros_like(l), x1) if x2 is None else x2
155
+ x2 = jax.tree.map(jnp.zeros_like, x1) if x2 is None else x2
156
156
 
157
157
  # Compute the scaling factors of the initial state and its derivative.
158
158
  compute_scale = lambda l1, l2: atol + jnp.maximum(jnp.abs(l1), jnp.abs(l2)) * rtol
@@ -199,9 +199,7 @@ def local_error_estimation(
199
199
 
200
200
  # Consider a zero estimated final state, if not given.
201
201
  xf_estimate = (
202
- jax.tree.map(lambda l: jnp.zeros_like(l), xf)
203
- if xf_estimate is None
204
- else xf_estimate
202
+ jax.tree.map(jnp.zeros_like, xf) if xf_estimate is None else xf_estimate
205
203
  )
206
204
 
207
205
  # Estimate the error.
@@ -483,14 +481,10 @@ class EmbeddedRungeKutta(ExplicitRungeKutta[PyTreeType], Generic[PyTreeType]):
483
481
  metadata_next,
484
482
  discarded_steps,
485
483
  ) = jax.lax.cond(
486
- pred=jnp.array(
487
- [
488
- discarded_steps >= self.max_step_rejections,
489
- local_error <= 1.0,
490
- Δt_next < self.dt_min,
491
- integrator_init,
492
- ]
493
- ).any(),
484
+ pred=discarded_steps
485
+ >= self.max_step_rejections | local_error
486
+ <= 1.0 | Δt_next
487
+ < self.dt_min | integrator_init,
494
488
  true_fun=accept_step,
495
489
  false_fun=reject_step,
496
490
  )