jaxsim 0.4.3.dev115__tar.gz → 0.4.3.dev129__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (124) hide show
  1. {jaxsim-0.4.3.dev115 → jaxsim-0.4.3.dev129}/PKG-INFO +1 -1
  2. {jaxsim-0.4.3.dev115 → jaxsim-0.4.3.dev129}/src/jaxsim/_version.py +2 -2
  3. {jaxsim-0.4.3.dev115 → jaxsim-0.4.3.dev129}/src/jaxsim/api/contact.py +34 -46
  4. {jaxsim-0.4.3.dev115 → jaxsim-0.4.3.dev129}/src/jaxsim/api/data.py +13 -15
  5. {jaxsim-0.4.3.dev115 → jaxsim-0.4.3.dev129}/src/jaxsim/api/model.py +40 -22
  6. {jaxsim-0.4.3.dev115 → jaxsim-0.4.3.dev129}/src/jaxsim/api/ode_data.py +10 -5
  7. {jaxsim-0.4.3.dev115 → jaxsim-0.4.3.dev129}/src/jaxsim/rbda/__init__.py +1 -1
  8. {jaxsim-0.4.3.dev115 → jaxsim-0.4.3.dev129}/src/jaxsim/rbda/contacts/__init__.py +1 -0
  9. {jaxsim-0.4.3.dev115 → jaxsim-0.4.3.dev129}/src/jaxsim/rbda/contacts/common.py +34 -5
  10. {jaxsim-0.4.3.dev115 → jaxsim-0.4.3.dev129}/src/jaxsim/rbda/contacts/relaxed_rigid.py +16 -9
  11. {jaxsim-0.4.3.dev115 → jaxsim-0.4.3.dev129}/src/jaxsim/rbda/contacts/rigid.py +17 -9
  12. {jaxsim-0.4.3.dev115 → jaxsim-0.4.3.dev129}/src/jaxsim/rbda/contacts/soft.py +92 -21
  13. {jaxsim-0.4.3.dev115 → jaxsim-0.4.3.dev129}/src/jaxsim/terrain/terrain.py +1 -1
  14. {jaxsim-0.4.3.dev115 → jaxsim-0.4.3.dev129}/src/jaxsim.egg-info/PKG-INFO +1 -1
  15. {jaxsim-0.4.3.dev115 → jaxsim-0.4.3.dev129}/tests/test_automatic_differentiation.py +12 -3
  16. {jaxsim-0.4.3.dev115 → jaxsim-0.4.3.dev129}/.devcontainer/Dockerfile +0 -0
  17. {jaxsim-0.4.3.dev115 → jaxsim-0.4.3.dev129}/.devcontainer/devcontainer.json +0 -0
  18. {jaxsim-0.4.3.dev115 → jaxsim-0.4.3.dev129}/.gitattributes +0 -0
  19. {jaxsim-0.4.3.dev115 → jaxsim-0.4.3.dev129}/.github/CODEOWNERS +0 -0
  20. {jaxsim-0.4.3.dev115 → jaxsim-0.4.3.dev129}/.github/workflows/ci_cd.yml +0 -0
  21. {jaxsim-0.4.3.dev115 → jaxsim-0.4.3.dev129}/.github/workflows/read_the_docs.yml +0 -0
  22. {jaxsim-0.4.3.dev115 → jaxsim-0.4.3.dev129}/.github/workflows/update_pixi_lockfile.yml +0 -0
  23. {jaxsim-0.4.3.dev115 → jaxsim-0.4.3.dev129}/.gitignore +0 -0
  24. {jaxsim-0.4.3.dev115 → jaxsim-0.4.3.dev129}/.pre-commit-config.yaml +0 -0
  25. {jaxsim-0.4.3.dev115 → jaxsim-0.4.3.dev129}/.readthedocs.yaml +0 -0
  26. {jaxsim-0.4.3.dev115 → jaxsim-0.4.3.dev129}/CONTRIBUTING.md +0 -0
  27. {jaxsim-0.4.3.dev115 → jaxsim-0.4.3.dev129}/LICENSE +0 -0
  28. {jaxsim-0.4.3.dev115 → jaxsim-0.4.3.dev129}/README.md +0 -0
  29. {jaxsim-0.4.3.dev115 → jaxsim-0.4.3.dev129}/docs/Makefile +0 -0
  30. {jaxsim-0.4.3.dev115 → jaxsim-0.4.3.dev129}/docs/conf.py +0 -0
  31. {jaxsim-0.4.3.dev115 → jaxsim-0.4.3.dev129}/docs/examples.rst +0 -0
  32. {jaxsim-0.4.3.dev115 → jaxsim-0.4.3.dev129}/docs/guide/install.rst +0 -0
  33. {jaxsim-0.4.3.dev115 → jaxsim-0.4.3.dev129}/docs/index.rst +0 -0
  34. {jaxsim-0.4.3.dev115 → jaxsim-0.4.3.dev129}/docs/make.bat +0 -0
  35. {jaxsim-0.4.3.dev115 → jaxsim-0.4.3.dev129}/docs/modules/api.rst +0 -0
  36. {jaxsim-0.4.3.dev115 → jaxsim-0.4.3.dev129}/docs/modules/integrators.rst +0 -0
  37. {jaxsim-0.4.3.dev115 → jaxsim-0.4.3.dev129}/docs/modules/math.rst +0 -0
  38. {jaxsim-0.4.3.dev115 → jaxsim-0.4.3.dev129}/docs/modules/mujoco.rst +0 -0
  39. {jaxsim-0.4.3.dev115 → jaxsim-0.4.3.dev129}/docs/modules/parsers.rst +0 -0
  40. {jaxsim-0.4.3.dev115 → jaxsim-0.4.3.dev129}/docs/modules/rbda.rst +0 -0
  41. {jaxsim-0.4.3.dev115 → jaxsim-0.4.3.dev129}/docs/modules/typing.rst +0 -0
  42. {jaxsim-0.4.3.dev115 → jaxsim-0.4.3.dev129}/docs/modules/utils.rst +0 -0
  43. {jaxsim-0.4.3.dev115 → jaxsim-0.4.3.dev129}/environment.yml +0 -0
  44. {jaxsim-0.4.3.dev115 → jaxsim-0.4.3.dev129}/examples/.gitattributes +0 -0
  45. {jaxsim-0.4.3.dev115 → jaxsim-0.4.3.dev129}/examples/.gitignore +0 -0
  46. {jaxsim-0.4.3.dev115 → jaxsim-0.4.3.dev129}/examples/PD_controller.ipynb +0 -0
  47. {jaxsim-0.4.3.dev115 → jaxsim-0.4.3.dev129}/examples/Parallel_computing.ipynb +0 -0
  48. {jaxsim-0.4.3.dev115 → jaxsim-0.4.3.dev129}/examples/README.md +0 -0
  49. {jaxsim-0.4.3.dev115 → jaxsim-0.4.3.dev129}/examples/assets/cartpole.urdf +0 -0
  50. {jaxsim-0.4.3.dev115 → jaxsim-0.4.3.dev129}/pixi.lock +0 -0
  51. {jaxsim-0.4.3.dev115 → jaxsim-0.4.3.dev129}/pyproject.toml +0 -0
  52. {jaxsim-0.4.3.dev115 → jaxsim-0.4.3.dev129}/setup.cfg +0 -0
  53. {jaxsim-0.4.3.dev115 → jaxsim-0.4.3.dev129}/setup.py +0 -0
  54. {jaxsim-0.4.3.dev115 → jaxsim-0.4.3.dev129}/src/jaxsim/__init__.py +0 -0
  55. {jaxsim-0.4.3.dev115 → jaxsim-0.4.3.dev129}/src/jaxsim/api/__init__.py +0 -0
  56. {jaxsim-0.4.3.dev115 → jaxsim-0.4.3.dev129}/src/jaxsim/api/com.py +0 -0
  57. {jaxsim-0.4.3.dev115 → jaxsim-0.4.3.dev129}/src/jaxsim/api/common.py +0 -0
  58. {jaxsim-0.4.3.dev115 → jaxsim-0.4.3.dev129}/src/jaxsim/api/frame.py +0 -0
  59. {jaxsim-0.4.3.dev115 → jaxsim-0.4.3.dev129}/src/jaxsim/api/joint.py +0 -0
  60. {jaxsim-0.4.3.dev115 → jaxsim-0.4.3.dev129}/src/jaxsim/api/kin_dyn_parameters.py +0 -0
  61. {jaxsim-0.4.3.dev115 → jaxsim-0.4.3.dev129}/src/jaxsim/api/link.py +0 -0
  62. {jaxsim-0.4.3.dev115 → jaxsim-0.4.3.dev129}/src/jaxsim/api/ode.py +0 -0
  63. {jaxsim-0.4.3.dev115 → jaxsim-0.4.3.dev129}/src/jaxsim/api/references.py +0 -0
  64. {jaxsim-0.4.3.dev115 → jaxsim-0.4.3.dev129}/src/jaxsim/exceptions.py +0 -0
  65. {jaxsim-0.4.3.dev115 → jaxsim-0.4.3.dev129}/src/jaxsim/integrators/__init__.py +0 -0
  66. {jaxsim-0.4.3.dev115 → jaxsim-0.4.3.dev129}/src/jaxsim/integrators/common.py +0 -0
  67. {jaxsim-0.4.3.dev115 → jaxsim-0.4.3.dev129}/src/jaxsim/integrators/fixed_step.py +0 -0
  68. {jaxsim-0.4.3.dev115 → jaxsim-0.4.3.dev129}/src/jaxsim/integrators/variable_step.py +0 -0
  69. {jaxsim-0.4.3.dev115 → jaxsim-0.4.3.dev129}/src/jaxsim/logging.py +0 -0
  70. {jaxsim-0.4.3.dev115 → jaxsim-0.4.3.dev129}/src/jaxsim/math/__init__.py +0 -0
  71. {jaxsim-0.4.3.dev115 → jaxsim-0.4.3.dev129}/src/jaxsim/math/adjoint.py +0 -0
  72. {jaxsim-0.4.3.dev115 → jaxsim-0.4.3.dev129}/src/jaxsim/math/cross.py +0 -0
  73. {jaxsim-0.4.3.dev115 → jaxsim-0.4.3.dev129}/src/jaxsim/math/inertia.py +0 -0
  74. {jaxsim-0.4.3.dev115 → jaxsim-0.4.3.dev129}/src/jaxsim/math/joint_model.py +0 -0
  75. {jaxsim-0.4.3.dev115 → jaxsim-0.4.3.dev129}/src/jaxsim/math/quaternion.py +0 -0
  76. {jaxsim-0.4.3.dev115 → jaxsim-0.4.3.dev129}/src/jaxsim/math/rotation.py +0 -0
  77. {jaxsim-0.4.3.dev115 → jaxsim-0.4.3.dev129}/src/jaxsim/math/skew.py +0 -0
  78. {jaxsim-0.4.3.dev115 → jaxsim-0.4.3.dev129}/src/jaxsim/math/transform.py +0 -0
  79. {jaxsim-0.4.3.dev115 → jaxsim-0.4.3.dev129}/src/jaxsim/mujoco/__init__.py +0 -0
  80. {jaxsim-0.4.3.dev115 → jaxsim-0.4.3.dev129}/src/jaxsim/mujoco/__main__.py +0 -0
  81. {jaxsim-0.4.3.dev115 → jaxsim-0.4.3.dev129}/src/jaxsim/mujoco/loaders.py +0 -0
  82. {jaxsim-0.4.3.dev115 → jaxsim-0.4.3.dev129}/src/jaxsim/mujoco/model.py +0 -0
  83. {jaxsim-0.4.3.dev115 → jaxsim-0.4.3.dev129}/src/jaxsim/mujoco/visualizer.py +0 -0
  84. {jaxsim-0.4.3.dev115 → jaxsim-0.4.3.dev129}/src/jaxsim/parsers/__init__.py +0 -0
  85. {jaxsim-0.4.3.dev115 → jaxsim-0.4.3.dev129}/src/jaxsim/parsers/descriptions/__init__.py +0 -0
  86. {jaxsim-0.4.3.dev115 → jaxsim-0.4.3.dev129}/src/jaxsim/parsers/descriptions/collision.py +0 -0
  87. {jaxsim-0.4.3.dev115 → jaxsim-0.4.3.dev129}/src/jaxsim/parsers/descriptions/joint.py +0 -0
  88. {jaxsim-0.4.3.dev115 → jaxsim-0.4.3.dev129}/src/jaxsim/parsers/descriptions/link.py +0 -0
  89. {jaxsim-0.4.3.dev115 → jaxsim-0.4.3.dev129}/src/jaxsim/parsers/descriptions/model.py +0 -0
  90. {jaxsim-0.4.3.dev115 → jaxsim-0.4.3.dev129}/src/jaxsim/parsers/kinematic_graph.py +0 -0
  91. {jaxsim-0.4.3.dev115 → jaxsim-0.4.3.dev129}/src/jaxsim/parsers/rod/__init__.py +0 -0
  92. {jaxsim-0.4.3.dev115 → jaxsim-0.4.3.dev129}/src/jaxsim/parsers/rod/parser.py +0 -0
  93. {jaxsim-0.4.3.dev115 → jaxsim-0.4.3.dev129}/src/jaxsim/parsers/rod/utils.py +0 -0
  94. {jaxsim-0.4.3.dev115 → jaxsim-0.4.3.dev129}/src/jaxsim/rbda/aba.py +0 -0
  95. {jaxsim-0.4.3.dev115 → jaxsim-0.4.3.dev129}/src/jaxsim/rbda/collidable_points.py +0 -0
  96. {jaxsim-0.4.3.dev115 → jaxsim-0.4.3.dev129}/src/jaxsim/rbda/crba.py +0 -0
  97. {jaxsim-0.4.3.dev115 → jaxsim-0.4.3.dev129}/src/jaxsim/rbda/forward_kinematics.py +0 -0
  98. {jaxsim-0.4.3.dev115 → jaxsim-0.4.3.dev129}/src/jaxsim/rbda/jacobian.py +0 -0
  99. {jaxsim-0.4.3.dev115 → jaxsim-0.4.3.dev129}/src/jaxsim/rbda/rnea.py +0 -0
  100. {jaxsim-0.4.3.dev115 → jaxsim-0.4.3.dev129}/src/jaxsim/rbda/utils.py +0 -0
  101. {jaxsim-0.4.3.dev115 → jaxsim-0.4.3.dev129}/src/jaxsim/terrain/__init__.py +0 -0
  102. {jaxsim-0.4.3.dev115 → jaxsim-0.4.3.dev129}/src/jaxsim/typing.py +0 -0
  103. {jaxsim-0.4.3.dev115 → jaxsim-0.4.3.dev129}/src/jaxsim/utils/__init__.py +0 -0
  104. {jaxsim-0.4.3.dev115 → jaxsim-0.4.3.dev129}/src/jaxsim/utils/jaxsim_dataclass.py +0 -0
  105. {jaxsim-0.4.3.dev115 → jaxsim-0.4.3.dev129}/src/jaxsim/utils/tracing.py +0 -0
  106. {jaxsim-0.4.3.dev115 → jaxsim-0.4.3.dev129}/src/jaxsim/utils/wrappers.py +0 -0
  107. {jaxsim-0.4.3.dev115 → jaxsim-0.4.3.dev129}/src/jaxsim.egg-info/SOURCES.txt +0 -0
  108. {jaxsim-0.4.3.dev115 → jaxsim-0.4.3.dev129}/src/jaxsim.egg-info/dependency_links.txt +0 -0
  109. {jaxsim-0.4.3.dev115 → jaxsim-0.4.3.dev129}/src/jaxsim.egg-info/requires.txt +0 -0
  110. {jaxsim-0.4.3.dev115 → jaxsim-0.4.3.dev129}/src/jaxsim.egg-info/top_level.txt +0 -0
  111. {jaxsim-0.4.3.dev115 → jaxsim-0.4.3.dev129}/tests/__init__.py +0 -0
  112. {jaxsim-0.4.3.dev115 → jaxsim-0.4.3.dev129}/tests/conftest.py +0 -0
  113. {jaxsim-0.4.3.dev115 → jaxsim-0.4.3.dev129}/tests/test_api_com.py +0 -0
  114. {jaxsim-0.4.3.dev115 → jaxsim-0.4.3.dev129}/tests/test_api_contact.py +0 -0
  115. {jaxsim-0.4.3.dev115 → jaxsim-0.4.3.dev129}/tests/test_api_data.py +0 -0
  116. {jaxsim-0.4.3.dev115 → jaxsim-0.4.3.dev129}/tests/test_api_frame.py +0 -0
  117. {jaxsim-0.4.3.dev115 → jaxsim-0.4.3.dev129}/tests/test_api_joint.py +0 -0
  118. {jaxsim-0.4.3.dev115 → jaxsim-0.4.3.dev129}/tests/test_api_link.py +0 -0
  119. {jaxsim-0.4.3.dev115 → jaxsim-0.4.3.dev129}/tests/test_api_model.py +0 -0
  120. {jaxsim-0.4.3.dev115 → jaxsim-0.4.3.dev129}/tests/test_contact.py +0 -0
  121. {jaxsim-0.4.3.dev115 → jaxsim-0.4.3.dev129}/tests/test_exceptions.py +0 -0
  122. {jaxsim-0.4.3.dev115 → jaxsim-0.4.3.dev129}/tests/test_pytree.py +0 -0
  123. {jaxsim-0.4.3.dev115 → jaxsim-0.4.3.dev129}/tests/test_simulations.py +0 -0
  124. {jaxsim-0.4.3.dev115 → jaxsim-0.4.3.dev129}/tests/utils_idyntree.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: jaxsim
3
- Version: 0.4.3.dev115
3
+ Version: 0.4.3.dev129
4
4
  Summary: A differentiable physics engine and multibody dynamics library for control and robot learning.
5
5
  Author-email: Diego Ferigo <dgferigo@gmail.com>
6
6
  Maintainer-email: Diego Ferigo <dgferigo@gmail.com>, Filippo Luca Ferretti <filippo.ferretti@iit.it>
@@ -12,5 +12,5 @@ __version__: str
12
12
  __version_tuple__: VERSION_TUPLE
13
13
  version_tuple: VERSION_TUPLE
14
14
 
15
- __version__ = version = '0.4.3.dev115'
16
- __version_tuple__ = version_tuple = (0, 4, 3, 'dev115')
15
+ __version__ = version = '0.4.3.dev129'
16
+ __version_tuple__ = version_tuple = (0, 4, 3, 'dev129')
@@ -9,7 +9,6 @@ import jaxsim.api as js
9
9
  import jaxsim.terrain
10
10
  import jaxsim.typing as jtp
11
11
  from jaxsim.math import Adjoint, Cross, Transform
12
- from jaxsim.rbda.contacts.soft import SoftContactsParams
13
12
 
14
13
  from .common import VelRepr
15
14
 
@@ -156,56 +155,43 @@ def collidable_point_dynamics(
156
155
  Instead, the 6D forces are returned in the active representation.
157
156
  """
158
157
 
159
- # Compute the position and linear velocities (mixed representation) of
160
- # all collidable points belonging to the robot.
161
- W_p_Ci, W_ṗ_Ci = js.contact.collidable_point_kinematics(model=model, data=data)
162
-
163
158
  # Import privately the contacts classes.
164
- from jaxsim.rbda.contacts.relaxed_rigid import (
159
+ from jaxsim.rbda.contacts import (
165
160
  RelaxedRigidContacts,
166
161
  RelaxedRigidContactsState,
162
+ RigidContacts,
163
+ RigidContactsState,
164
+ SoftContacts,
165
+ SoftContactsState,
167
166
  )
168
- from jaxsim.rbda.contacts.rigid import RigidContacts, RigidContactsState
169
- from jaxsim.rbda.contacts.soft import SoftContacts, SoftContactsState
170
167
 
171
168
  # Build the soft contact model.
172
169
  match model.contact_model:
173
170
 
174
171
  case SoftContacts():
175
-
176
172
  assert isinstance(model.contact_model, SoftContacts)
177
173
  assert isinstance(data.state.contact, SoftContactsState)
178
174
 
179
- # Build the contact model.
180
- soft_contacts = SoftContacts(
181
- parameters=data.contacts_params, terrain=model.terrain
182
- )
183
-
184
175
  # Compute the 6D force expressed in the inertial frame and applied to each
185
176
  # collidable point, and the corresponding material deformation rate.
186
177
  # Note that the material deformation rate is always returned in the mixed frame
187
178
  # C[W] = (W_p_C, [W]). This is convenient for integration purpose.
188
- W_f_Ci, (CW_ṁ,) = jax.vmap(soft_contacts.compute_contact_forces)(
189
- position=W_p_Ci,
190
- velocity=W_ṗ_Ci,
191
- tangential_deformation=data.state.contact.tangential_deformation,
179
+ W_f_Ci, (CW_ṁ,) = model.contact_model.compute_contact_forces(
180
+ model=model, data=data
192
181
  )
182
+
183
+ # Create the dictionary of auxiliary data.
184
+ # This contact model considers the material deformation as additional state
185
+ # of the ODE system. We need to pass its dynamics to the integrator.
193
186
  aux_data = dict(m_dot=CW_ṁ)
194
187
 
195
188
  case RigidContacts():
196
189
  assert isinstance(model.contact_model, RigidContacts)
197
190
  assert isinstance(data.state.contact, RigidContactsState)
198
191
 
199
- # Build the contact model.
200
- rigid_contacts = RigidContacts(
201
- parameters=data.contacts_params, terrain=model.terrain
202
- )
203
-
204
192
  # Compute the 6D force expressed in the inertial frame and applied to each
205
193
  # collidable point.
206
- W_f_Ci, _ = rigid_contacts.compute_contact_forces(
207
- position=W_p_Ci,
208
- velocity=W_ṗ_Ci,
194
+ W_f_Ci, _ = model.contact_model.compute_contact_forces(
209
195
  model=model,
210
196
  data=data,
211
197
  link_forces=link_forces,
@@ -219,16 +205,9 @@ def collidable_point_dynamics(
219
205
  assert isinstance(model.contact_model, RelaxedRigidContacts)
220
206
  assert isinstance(data.state.contact, RelaxedRigidContactsState)
221
207
 
222
- # Build the contact model.
223
- relaxed_rigid_contacts = RelaxedRigidContacts(
224
- parameters=data.contacts_params, terrain=model.terrain
225
- )
226
-
227
208
  # Compute the 6D force expressed in the inertial frame and applied to each
228
209
  # collidable point.
229
- W_f_Ci, _ = relaxed_rigid_contacts.compute_contact_forces(
230
- position=W_p_Ci,
231
- velocity=W_ṗ_Ci,
210
+ W_f_Ci, _ = model.contact_model.compute_contact_forces(
232
211
  model=model,
233
212
  data=data,
234
213
  link_forces=link_forces,
@@ -318,7 +297,7 @@ def estimate_good_soft_contacts_parameters(
318
297
  number_of_active_collidable_points_steady_state: jtp.IntLike = 1,
319
298
  damping_ratio: jtp.FloatLike = 1.0,
320
299
  max_penetration: jtp.FloatLike | None = None,
321
- ) -> SoftContactsParams:
300
+ ) -> jaxsim.rbda.contacts.SoftContactsParams:
322
301
  """
323
302
  Estimate good soft contacts parameters for the given model.
324
303
 
@@ -342,14 +321,13 @@ def estimate_good_soft_contacts_parameters(
342
321
  The user is encouraged to fine-tune the parameters based on the
343
322
  specific application.
344
323
  """
345
- from jaxsim.rbda.contacts.soft import SoftContactsParams
346
324
 
347
325
  def estimate_model_height(model: js.model.JaxSimModel) -> jtp.Float:
348
326
  """"""
349
327
 
350
328
  zero_data = js.data.JaxSimModelData.build(
351
329
  model=model,
352
- contacts_params=SoftContactsParams(),
330
+ contacts_params=jaxsim.rbda.contacts.SoftContactsParams(),
353
331
  )
354
332
 
355
333
  W_pz_CoM = js.com.com_position(model=model, data=zero_data)[2]
@@ -368,16 +346,26 @@ def estimate_good_soft_contacts_parameters(
368
346
 
369
347
  nc = number_of_active_collidable_points_steady_state
370
348
 
371
- sc_parameters = SoftContactsParams.build_default_from_jaxsim_model(
372
- model=model,
373
- standard_gravity=standard_gravity,
374
- static_friction_coefficient=static_friction_coefficient,
375
- max_penetration=max_δ,
376
- number_of_active_collidable_points_steady_state=nc,
377
- damping_ratio=damping_ratio,
378
- )
349
+ match model.contact_model:
350
+
351
+ case jaxsim.rbda.contacts.SoftContacts():
352
+ assert isinstance(model.contact_model, jaxsim.rbda.contacts.SoftContacts)
353
+
354
+ parameters = (
355
+ jaxsim.rbda.contacts.SoftContactsParams.build_default_from_jaxsim_model(
356
+ model=model,
357
+ standard_gravity=standard_gravity,
358
+ static_friction_coefficient=static_friction_coefficient,
359
+ max_penetration=max_δ,
360
+ number_of_active_collidable_points_steady_state=nc,
361
+ damping_ratio=damping_ratio,
362
+ )
363
+ )
364
+
365
+ case _:
366
+ parameters = model.contact_model.parameters
379
367
 
380
- return sc_parameters
368
+ return parameters
381
369
 
382
370
 
383
371
  @jax.jit
@@ -13,7 +13,7 @@ import jaxsim.api as js
13
13
  import jaxsim.math
14
14
  import jaxsim.rbda
15
15
  import jaxsim.typing as jtp
16
- from jaxsim.rbda.contacts.soft import SoftContacts
16
+ from jaxsim.rbda.contacts import SoftContacts
17
17
  from jaxsim.utils import Mutability
18
18
  from jaxsim.utils.tracing import not_tracing
19
19
 
@@ -37,7 +37,7 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
37
37
 
38
38
  gravity: jtp.Array
39
39
 
40
- contacts_params: jaxsim.rbda.ContactsParams = dataclasses.field(repr=False)
40
+ contacts_params: jaxsim.rbda.contacts.ContactsParams = dataclasses.field(repr=False)
41
41
 
42
42
  time_ns: jtp.Int = dataclasses.field(
43
43
  default_factory=lambda: jnp.array(
@@ -114,8 +114,8 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
114
114
  base_angular_velocity: jtp.Vector | None = None,
115
115
  joint_velocities: jtp.Vector | None = None,
116
116
  standard_gravity: jtp.FloatLike = jaxsim.math.StandardGravity,
117
- contact: jaxsim.rbda.ContactsState | None = None,
118
- contacts_params: jaxsim.rbda.ContactsParams | None = None,
117
+ contact: jaxsim.rbda.contacts.ContactsState | None = None,
118
+ contacts_params: jaxsim.rbda.contacts.ContactsParams | None = None,
119
119
  velocity_representation: VelRepr = VelRepr.Inertial,
120
120
  time: jtp.FloatLike | None = None,
121
121
  ) -> JaxSimModelData:
@@ -185,17 +185,6 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
185
185
  )
186
186
  )
187
187
 
188
- if isinstance(model.contact_model, SoftContacts):
189
- contacts_params = (
190
- contacts_params
191
- if contacts_params is not None
192
- else js.contact.estimate_good_soft_contacts_parameters(
193
- model=model, standard_gravity=standard_gravity
194
- )
195
- )
196
- else:
197
- contacts_params = model.contact_model.parameters
198
-
199
188
  W_H_B = jaxsim.math.Transform.from_quaternion_and_translation(
200
189
  translation=base_position, quaternion=base_quaternion
201
190
  )
@@ -225,6 +214,15 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
225
214
  if not ode_state.valid(model=model):
226
215
  raise ValueError(ode_state)
227
216
 
217
+ if contacts_params is None:
218
+
219
+ if isinstance(model.contact_model, jaxsim.rbda.contacts.SoftContacts):
220
+ contacts_params = js.contact.estimate_good_soft_contacts_parameters(
221
+ model=model, standard_gravity=standard_gravity
222
+ )
223
+ else:
224
+ contacts_params = model.contact_model.parameters
225
+
228
226
  return JaxSimModelData(
229
227
  time_ns=time_ns,
230
228
  state=ode_state,
@@ -36,7 +36,7 @@ class JaxSimModel(JaxsimDataclass):
36
36
  default=jaxsim.terrain.FlatTerrain(), repr=False
37
37
  )
38
38
 
39
- contact_model: jaxsim.rbda.ContactModel | None = dataclasses.field(
39
+ contact_model: jaxsim.rbda.contacts.ContactModel | None = dataclasses.field(
40
40
  default=None, repr=False
41
41
  )
42
42
 
@@ -89,7 +89,7 @@ class JaxSimModel(JaxsimDataclass):
89
89
  model_name: str | None = None,
90
90
  *,
91
91
  terrain: jaxsim.terrain.Terrain | None = None,
92
- contact_model: jaxsim.rbda.ContactModel | None = None,
92
+ contact_model: jaxsim.rbda.contacts.ContactModel | None = None,
93
93
  is_urdf: bool | None = None,
94
94
  considered_joints: Sequence[str] | None = None,
95
95
  ) -> JaxSimModel:
@@ -150,7 +150,7 @@ class JaxSimModel(JaxsimDataclass):
150
150
  model_name: str | None = None,
151
151
  *,
152
152
  terrain: jaxsim.terrain.Terrain | None = None,
153
- contact_model: jaxsim.rbda.ContactModel | None = None,
153
+ contact_model: jaxsim.rbda.contacts.ContactModel | None = None,
154
154
  ) -> JaxSimModel:
155
155
  """
156
156
  Build a Model object from an intermediate model description.
@@ -169,14 +169,15 @@ class JaxSimModel(JaxsimDataclass):
169
169
  Returns:
170
170
  The built Model object.
171
171
  """
172
- from jaxsim.rbda.contacts.soft import SoftContacts
173
172
 
174
173
  # Set the model name (if not provided, use the one from the model description).
175
174
  model_name = model_name if model_name is not None else model_description.name
176
175
 
177
176
  # Set the terrain (if not provided, use the default flat terrain).
178
177
  terrain = terrain or JaxSimModel.__dataclass_fields__["terrain"].default
179
- contact_model = contact_model or SoftContacts(terrain=terrain)
178
+ contact_model = contact_model or jaxsim.rbda.contacts.SoftContacts(
179
+ terrain=terrain
180
+ )
180
181
 
181
182
  # Build the model.
182
183
  model = JaxSimModel(
@@ -1930,8 +1931,6 @@ def step(
1930
1931
  and the new state of the integrator.
1931
1932
  """
1932
1933
 
1933
- from jaxsim.rbda.contacts.rigid import RigidContacts
1934
-
1935
1934
  # Extract the integrator kwargs.
1936
1935
  # The following logic allows using integrators having kwargs colliding with the
1937
1936
  # kwargs of this step function.
@@ -1992,12 +1991,16 @@ def step(
1992
1991
  # Post process the simulation state, if needed.
1993
1992
  match model.contact_model:
1994
1993
 
1995
- # Rigid contact models use an impact model that produces a discontinuous model velocity.
1996
- # Hence here we need to reset the velocity after each impact to guarantee that
1994
+ # Rigid contact models use an impact model that produces discontinuous model velocities.
1995
+ # Hence, here we need to reset the velocity after each impact to guarantee that
1997
1996
  # the linear velocity of the active collidable points is zero.
1998
- case RigidContacts():
1999
- # Raise runtime error for not supported case in which Rigid contacts and Baumgarte stabilization
2000
- # enabled are used with ForwardEuler integrator.
1997
+ case jaxsim.rbda.contacts.RigidContacts():
1998
+ assert isinstance(
1999
+ data_tf.contacts_params, jaxsim.rbda.contacts.RigidContactsParams
2000
+ )
2001
+
2002
+ # Raise runtime error for not supported case in which Rigid contacts and
2003
+ # Baumgarte stabilization are enabled and used with ForwardEuler integrator.
2001
2004
  jaxsim.exceptions.raise_runtime_error_if(
2002
2005
  condition=jnp.logical_and(
2003
2006
  isinstance(
@@ -2013,23 +2016,38 @@ def step(
2013
2016
  )
2014
2017
 
2015
2018
  with data_tf.switch_velocity_representation(VelRepr.Mixed):
2016
- W_p_C = js.contact.collidable_point_positions(model, data_tf)
2017
- M = js.model.free_floating_mass_matrix(model, data_tf)
2019
+
2018
2020
  J_WC = js.contact.jacobian(model, data_tf)
2021
+ M = js.model.free_floating_mass_matrix(model, data_tf)
2022
+ W_p_C = js.contact.collidable_point_positions(model, data_tf)
2023
+
2024
+ # Compute the height of the terrain below each collidable point.
2019
2025
  px, py, _ = W_p_C.T
2020
2026
  terrain_height = jax.vmap(model.terrain.height)(px, py)
2021
- inactive_collidable_points, _ = RigidContacts.detect_contacts(
2022
- W_p_C=W_p_C,
2023
- terrain_height=terrain_height,
2027
+
2028
+ # Compute the contact state.
2029
+ inactive_collidable_points, _ = (
2030
+ jaxsim.rbda.contacts.RigidContacts.detect_contacts(
2031
+ W_p_C=W_p_C,
2032
+ terrain_height=terrain_height,
2033
+ )
2024
2034
  )
2025
- BW_nu_post_impact = RigidContacts.compute_impact_velocity(
2026
- data=data_tf,
2027
- inactive_collidable_points=inactive_collidable_points,
2028
- M=M,
2029
- J_WC=J_WC,
2035
+
2036
+ # Compute the impact velocity.
2037
+ # It may be discontinuous in case new contacts are made.
2038
+ BW_nu_post_impact = (
2039
+ jaxsim.rbda.contacts.RigidContacts.compute_impact_velocity(
2040
+ data=data_tf,
2041
+ inactive_collidable_points=inactive_collidable_points,
2042
+ M=M,
2043
+ J_WC=J_WC,
2044
+ )
2030
2045
  )
2046
+
2047
+ # Reset the generalized velocity.
2031
2048
  data_tf = data_tf.reset_base_velocity(BW_nu_post_impact[0:6])
2032
2049
  data_tf = data_tf.reset_joint_velocities(BW_nu_post_impact[6:])
2050
+
2033
2051
  # Restore the input velocity representation.
2034
2052
  data_tf = data_tf.replace(
2035
2053
  velocity_representation=data.velocity_representation, validate=False
@@ -5,13 +5,15 @@ import jax_dataclasses
5
5
 
6
6
  import jaxsim.api as js
7
7
  import jaxsim.typing as jtp
8
- from jaxsim.rbda import ContactsState
9
- from jaxsim.rbda.contacts.relaxed_rigid import (
8
+ from jaxsim.rbda.contacts import (
9
+ ContactsState,
10
10
  RelaxedRigidContacts,
11
11
  RelaxedRigidContactsState,
12
+ RigidContacts,
13
+ RigidContactsState,
14
+ SoftContacts,
15
+ SoftContactsState,
12
16
  )
13
- from jaxsim.rbda.contacts.rigid import RigidContacts, RigidContactsState
14
- from jaxsim.rbda.contacts.soft import SoftContacts, SoftContactsState
15
17
  from jaxsim.utils import JaxsimDataclass
16
18
 
17
19
  # =============================================================================
@@ -165,8 +167,11 @@ class ODEState(JaxsimDataclass):
165
167
 
166
168
  # Get the contact model from the `JaxSimModel`.
167
169
  match model.contact_model:
170
+
168
171
  case SoftContacts():
172
+
169
173
  tangential_deformation = kwargs.get("tangential_deformation", None)
174
+
170
175
  contact = SoftContactsState.build_from_jaxsim_model(
171
176
  model=model,
172
177
  **(
@@ -182,7 +187,7 @@ class ODEState(JaxsimDataclass):
182
187
  contact = RelaxedRigidContactsState.build()
183
188
 
184
189
  case _:
185
- raise ValueError("Unable to determine contact state class prefix.")
190
+ raise ValueError("Unsupported contact model.")
186
191
 
187
192
  return ODEState.build(
188
193
  model=model,
@@ -1,6 +1,6 @@
1
+ from . import contacts
1
2
  from .aba import aba
2
3
  from .collidable_points import collidable_points_pos_vel
3
- from .contacts.common import ContactModel, ContactsParams, ContactsState
4
4
  from .crba import crba
5
5
  from .forward_kinematics import forward_kinematics, forward_kinematics_model
6
6
  from .jacobian import (
@@ -1,3 +1,4 @@
1
+ from . import relaxed_rigid, rigid, soft
1
2
  from .common import ContactModel, ContactsParams, ContactsState
2
3
  from .relaxed_rigid import (
3
4
  RelaxedRigidContacts,
@@ -3,6 +3,7 @@ from __future__ import annotations
3
3
  import abc
4
4
  from typing import Any
5
5
 
6
+ import jaxsim.api as js
6
7
  import jaxsim.terrain
7
8
  import jaxsim.typing as jtp
8
9
  from jaxsim.utils import JaxsimDataclass
@@ -90,20 +91,48 @@ class ContactModel(JaxsimDataclass):
90
91
  @abc.abstractmethod
91
92
  def compute_contact_forces(
92
93
  self,
93
- position: jtp.VectorLike,
94
- velocity: jtp.VectorLike,
94
+ model: js.model.JaxSimModel,
95
+ data: js.data.JaxSimModelData,
95
96
  **kwargs,
96
97
  ) -> tuple[jtp.Vector, tuple[Any, ...]]:
97
98
  """
98
99
  Compute the contact forces.
99
100
 
100
101
  Args:
101
- position: The position of the collidable point w.r.t. the world frame.
102
- velocity:
103
- The linear velocity of the collidable point (linear component of the mixed 6D velocity).
102
+ model: The model to consider.
103
+ data: The data of the considered model.
104
104
 
105
105
  Returns:
106
106
  A tuple containing as first element the computed 6D contact force applied to the contact point and expressed in the world frame,
107
107
  and as second element a tuple of optional additional information.
108
108
  """
109
+
109
110
  pass
111
+
112
+ def initialize_model_and_data(
113
+ self,
114
+ model: js.model.JaxSimModel,
115
+ data: js.data.JaxSimModelData,
116
+ validate: bool = True,
117
+ ) -> tuple[js.model.JaxSimModel, js.data.JaxSimModelData]:
118
+ """
119
+ Helper function to initialize the active model and data objects.
120
+
121
+ Args:
122
+ model: The robot model considered by the contact model.
123
+ data: The data of the considered robot model.
124
+ validate:
125
+ Whether to validate if the model and data objects have been
126
+ initialized with the current contact model.
127
+
128
+ Returns:
129
+ The initialized model and data objects.
130
+ """
131
+
132
+ with model.editable(validate=validate) as model_out:
133
+ model_out.contact_model = self
134
+
135
+ with data.editable(validate=validate) as data_out:
136
+ data_out.contacts_params = data.contacts_params
137
+
138
+ return model_out, data_out
@@ -169,12 +169,12 @@ class RelaxedRigidContactsState(ContactsState):
169
169
  return cls()
170
170
 
171
171
  @classmethod
172
- def zero(cls: type[Self]) -> Self:
172
+ def zero(cls: type[Self], **kwargs) -> Self:
173
173
  """Build a zero `RelaxedRigidContactsState` instance from a `JaxSimModel`."""
174
174
 
175
175
  return cls.build()
176
176
 
177
- def valid(self, *, model: js.model.JaxSimModel) -> jtp.BoolLike:
177
+ def valid(self, **kwargs) -> jtp.BoolLike:
178
178
  return True
179
179
 
180
180
 
@@ -193,11 +193,9 @@ class RelaxedRigidContacts(ContactModel):
193
193
  @jax.jit
194
194
  def compute_contact_forces(
195
195
  self,
196
- position: jtp.VectorLike,
197
- velocity: jtp.VectorLike,
198
- *,
199
196
  model: js.model.JaxSimModel,
200
197
  data: js.data.JaxSimModelData,
198
+ *,
201
199
  link_forces: jtp.MatrixLike | None = None,
202
200
  joint_force_references: jtp.VectorLike | None = None,
203
201
  ) -> tuple[jtp.Vector, tuple[Any, ...]]:
@@ -205,10 +203,8 @@ class RelaxedRigidContacts(ContactModel):
205
203
  Compute the contact forces.
206
204
 
207
205
  Args:
208
- position: The position of the collidable point.
209
- velocity: The linear velocity of the collidable point.
210
- model: The `JaxSimModel` instance.
211
- data: The `JaxSimModelData` instance.
206
+ model: The model to consider.
207
+ data: The data of the considered model.
212
208
  link_forces:
213
209
  Optional `(n_links, 6)` matrix of external forces acting on the links,
214
210
  expressed in the same representation of data.
@@ -219,6 +215,11 @@ class RelaxedRigidContacts(ContactModel):
219
215
  A tuple containing the contact forces.
220
216
  """
221
217
 
218
+ # Initialize the model and data this contact model is operating on.
219
+ # This will raise an exception if either the contact model or the
220
+ # contact parameters are not compatible.
221
+ model, data = self.initialize_model_and_data(model=model, data=data)
222
+
222
223
  link_forces = (
223
224
  link_forces
224
225
  if link_forces is not None
@@ -247,6 +248,12 @@ class RelaxedRigidContacts(ContactModel):
247
248
 
248
249
  return jnp.dot(h, n̂)
249
250
 
251
+ # Compute the position and linear velocities (mixed representation) of
252
+ # all collidable points belonging to the robot.
253
+ position, velocity = js.contact.collidable_point_kinematics(
254
+ model=model, data=data
255
+ )
256
+
250
257
  # Compute the activation state of the collidable points
251
258
  δ = jax.vmap(_detect_contact)(*position.T)
252
259
 
@@ -92,12 +92,12 @@ class RigidContactsState(ContactsState):
92
92
  return cls()
93
93
 
94
94
  @classmethod
95
- def zero(cls: type[Self]) -> Self:
95
+ def zero(cls: type[Self], **kwargs) -> Self:
96
96
  """Build a zero `RigidContactsState` instance from a `JaxSimModel`."""
97
97
 
98
98
  return cls.build()
99
99
 
100
- def valid(self) -> jtp.BoolLike:
100
+ def valid(self, **kwargs) -> jtp.BoolLike:
101
101
  return True
102
102
 
103
103
 
@@ -219,11 +219,9 @@ class RigidContacts(ContactModel):
219
219
  @jax.jit
220
220
  def compute_contact_forces(
221
221
  self,
222
- position: jtp.VectorLike,
223
- velocity: jtp.VectorLike,
224
- *,
225
222
  model: js.model.JaxSimModel,
226
223
  data: js.data.JaxSimModelData,
224
+ *,
227
225
  link_forces: jtp.MatrixLike | None = None,
228
226
  joint_force_references: jtp.VectorLike | None = None,
229
227
  regularization_term: jtp.FloatLike = 1e-6,
@@ -233,10 +231,8 @@ class RigidContacts(ContactModel):
233
231
  Compute the contact forces.
234
232
 
235
233
  Args:
236
- position: The position of the collidable point.
237
- velocity: The linear velocity of the collidable point.
238
- model: The `JaxSimModel` instance.
239
- data: The `JaxSimModelData` instance.
234
+ model: The model to consider.
235
+ data: The data of the considered model.
240
236
  link_forces:
241
237
  Optional `(n_links, 6)` matrix of external forces acting on the links,
242
238
  expressed in the same representation of data.
@@ -245,11 +241,17 @@ class RigidContacts(ContactModel):
245
241
  regularization_term:
246
242
  The regularization term to add to the diagonal of the Delassus
247
243
  matrix for better numerical conditioning.
244
+ solver_tol: The convergence tolerance to consider in the QP solver.
248
245
 
249
246
  Returns:
250
247
  A tuple containing the contact forces.
251
248
  """
252
249
 
250
+ # Initialize the model and data this contact model is operating on.
251
+ # This will raise an exception if either the contact model or the
252
+ # contact parameters are not compatible.
253
+ model, data = self.initialize_model_and_data(model=model, data=data)
254
+
253
255
  # Import qpax just in this method
254
256
  import qpax
255
257
 
@@ -273,6 +275,12 @@ class RigidContacts(ContactModel):
273
275
  J̇_WC_BW = js.contact.jacobian_derivative(model=model, data=data)
274
276
  BW_ν = data.generalized_velocity()
275
277
 
278
+ # Compute the position and linear velocities (mixed representation) of
279
+ # all collidable points belonging to the robot.
280
+ position, velocity = js.contact.collidable_point_kinematics(
281
+ model=model, data=data
282
+ )
283
+
276
284
  terrain_height = jax.vmap(self.terrain.height)(position[:, 0], position[:, 1])
277
285
  n_collidable_points = model.kin_dyn_parameters.contact_parameters.point.shape[0]
278
286