jaxsim 0.3.1.dev17__tar.gz → 0.3.1.dev40__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (122)
  1. {jaxsim-0.3.1.dev17 → jaxsim-0.3.1.dev40}/.pre-commit-config.yaml +5 -0
  2. {jaxsim-0.3.1.dev17 → jaxsim-0.3.1.dev40}/PKG-INFO +1 -1
  3. {jaxsim-0.3.1.dev17 → jaxsim-0.3.1.dev40}/docs/modules/rbda.rst +6 -7
  4. {jaxsim-0.3.1.dev17 → jaxsim-0.3.1.dev40}/src/jaxsim/_version.py +2 -2
  5. {jaxsim-0.3.1.dev17 → jaxsim-0.3.1.dev40}/src/jaxsim/api/contact.py +30 -22
  6. {jaxsim-0.3.1.dev17 → jaxsim-0.3.1.dev40}/src/jaxsim/api/data.py +27 -17
  7. {jaxsim-0.3.1.dev17 → jaxsim-0.3.1.dev40}/src/jaxsim/api/model.py +19 -2
  8. {jaxsim-0.3.1.dev17 → jaxsim-0.3.1.dev40}/src/jaxsim/api/ode.py +1 -1
  9. {jaxsim-0.3.1.dev17 → jaxsim-0.3.1.dev40}/src/jaxsim/api/ode_data.py +32 -152
  10. jaxsim-0.3.1.dev40/src/jaxsim/exceptions.py +63 -0
  11. {jaxsim-0.3.1.dev17 → jaxsim-0.3.1.dev40}/src/jaxsim/rbda/__init__.py +1 -1
  12. jaxsim-0.3.1.dev40/src/jaxsim/rbda/contacts/common.py +101 -0
  13. jaxsim-0.3.1.dev17/src/jaxsim/rbda/soft_contacts.py → jaxsim-0.3.1.dev40/src/jaxsim/rbda/contacts/soft.py +149 -9
  14. {jaxsim-0.3.1.dev17 → jaxsim-0.3.1.dev40}/src/jaxsim.egg-info/PKG-INFO +1 -1
  15. {jaxsim-0.3.1.dev17 → jaxsim-0.3.1.dev40}/src/jaxsim.egg-info/SOURCES.txt +5 -1
  16. jaxsim-0.3.1.dev40/tests/__init__.py +0 -0
  17. {jaxsim-0.3.1.dev17 → jaxsim-0.3.1.dev40}/tests/test_automatic_differentiation.py +7 -7
  18. jaxsim-0.3.1.dev40/tests/test_exceptions.py +88 -0
  19. {jaxsim-0.3.1.dev17 → jaxsim-0.3.1.dev40}/tests/test_simulations.py +2 -1
  20. {jaxsim-0.3.1.dev17 → jaxsim-0.3.1.dev40}/.devcontainer/Dockerfile +0 -0
  21. {jaxsim-0.3.1.dev17 → jaxsim-0.3.1.dev40}/.devcontainer/devcontainer.json +0 -0
  22. {jaxsim-0.3.1.dev17 → jaxsim-0.3.1.dev40}/.gitattributes +0 -0
  23. {jaxsim-0.3.1.dev17 → jaxsim-0.3.1.dev40}/.github/CODEOWNERS +0 -0
  24. {jaxsim-0.3.1.dev17 → jaxsim-0.3.1.dev40}/.github/workflows/ci_cd.yml +0 -0
  25. {jaxsim-0.3.1.dev17 → jaxsim-0.3.1.dev40}/.github/workflows/read_the_docs.yml +0 -0
  26. {jaxsim-0.3.1.dev17 → jaxsim-0.3.1.dev40}/.github/workflows/style.yml +0 -0
  27. {jaxsim-0.3.1.dev17 → jaxsim-0.3.1.dev40}/.gitignore +0 -0
  28. {jaxsim-0.3.1.dev17 → jaxsim-0.3.1.dev40}/.readthedocs.yaml +0 -0
  29. {jaxsim-0.3.1.dev17 → jaxsim-0.3.1.dev40}/CONTRIBUTING.md +0 -0
  30. {jaxsim-0.3.1.dev17 → jaxsim-0.3.1.dev40}/LICENSE +0 -0
  31. {jaxsim-0.3.1.dev17 → jaxsim-0.3.1.dev40}/README.md +0 -0
  32. {jaxsim-0.3.1.dev17 → jaxsim-0.3.1.dev40}/docs/Makefile +0 -0
  33. {jaxsim-0.3.1.dev17 → jaxsim-0.3.1.dev40}/docs/conf.py +0 -0
  34. {jaxsim-0.3.1.dev17 → jaxsim-0.3.1.dev40}/docs/guide/install.rst +0 -0
  35. {jaxsim-0.3.1.dev17 → jaxsim-0.3.1.dev40}/docs/index.rst +0 -0
  36. {jaxsim-0.3.1.dev17 → jaxsim-0.3.1.dev40}/docs/make.bat +0 -0
  37. {jaxsim-0.3.1.dev17 → jaxsim-0.3.1.dev40}/docs/modules/api.rst +0 -0
  38. {jaxsim-0.3.1.dev17 → jaxsim-0.3.1.dev40}/docs/modules/index.rst +0 -0
  39. {jaxsim-0.3.1.dev17 → jaxsim-0.3.1.dev40}/docs/modules/integrators.rst +0 -0
  40. {jaxsim-0.3.1.dev17 → jaxsim-0.3.1.dev40}/docs/modules/math.rst +0 -0
  41. {jaxsim-0.3.1.dev17 → jaxsim-0.3.1.dev40}/docs/modules/mujoco.rst +0 -0
  42. {jaxsim-0.3.1.dev17 → jaxsim-0.3.1.dev40}/docs/modules/parsers.rst +0 -0
  43. {jaxsim-0.3.1.dev17 → jaxsim-0.3.1.dev40}/docs/modules/typing.rst +0 -0
  44. {jaxsim-0.3.1.dev17 → jaxsim-0.3.1.dev40}/docs/modules/utils.rst +0 -0
  45. {jaxsim-0.3.1.dev17 → jaxsim-0.3.1.dev40}/environment.yml +0 -0
  46. {jaxsim-0.3.1.dev17 → jaxsim-0.3.1.dev40}/examples/.gitattributes +0 -0
  47. {jaxsim-0.3.1.dev17 → jaxsim-0.3.1.dev40}/examples/.gitignore +0 -0
  48. {jaxsim-0.3.1.dev17 → jaxsim-0.3.1.dev40}/examples/PD_controller.ipynb +0 -0
  49. {jaxsim-0.3.1.dev17 → jaxsim-0.3.1.dev40}/examples/Parallel_computing.ipynb +0 -0
  50. {jaxsim-0.3.1.dev17 → jaxsim-0.3.1.dev40}/examples/README.md +0 -0
  51. {jaxsim-0.3.1.dev17 → jaxsim-0.3.1.dev40}/examples/assets/cartpole.urdf +0 -0
  52. {jaxsim-0.3.1.dev17 → jaxsim-0.3.1.dev40}/pixi.lock +0 -0
  53. {jaxsim-0.3.1.dev17 → jaxsim-0.3.1.dev40}/pyproject.toml +0 -0
  54. {jaxsim-0.3.1.dev17 → jaxsim-0.3.1.dev40}/setup.cfg +0 -0
  55. {jaxsim-0.3.1.dev17 → jaxsim-0.3.1.dev40}/setup.py +0 -0
  56. {jaxsim-0.3.1.dev17 → jaxsim-0.3.1.dev40}/src/jaxsim/__init__.py +0 -0
  57. {jaxsim-0.3.1.dev17 → jaxsim-0.3.1.dev40}/src/jaxsim/api/__init__.py +0 -0
  58. {jaxsim-0.3.1.dev17 → jaxsim-0.3.1.dev40}/src/jaxsim/api/com.py +0 -0
  59. {jaxsim-0.3.1.dev17 → jaxsim-0.3.1.dev40}/src/jaxsim/api/common.py +0 -0
  60. {jaxsim-0.3.1.dev17 → jaxsim-0.3.1.dev40}/src/jaxsim/api/frame.py +0 -0
  61. {jaxsim-0.3.1.dev17 → jaxsim-0.3.1.dev40}/src/jaxsim/api/joint.py +0 -0
  62. {jaxsim-0.3.1.dev17 → jaxsim-0.3.1.dev40}/src/jaxsim/api/kin_dyn_parameters.py +0 -0
  63. {jaxsim-0.3.1.dev17 → jaxsim-0.3.1.dev40}/src/jaxsim/api/link.py +0 -0
  64. {jaxsim-0.3.1.dev17 → jaxsim-0.3.1.dev40}/src/jaxsim/api/references.py +0 -0
  65. {jaxsim-0.3.1.dev17 → jaxsim-0.3.1.dev40}/src/jaxsim/integrators/__init__.py +0 -0
  66. {jaxsim-0.3.1.dev17 → jaxsim-0.3.1.dev40}/src/jaxsim/integrators/common.py +0 -0
  67. {jaxsim-0.3.1.dev17 → jaxsim-0.3.1.dev40}/src/jaxsim/integrators/fixed_step.py +0 -0
  68. {jaxsim-0.3.1.dev17 → jaxsim-0.3.1.dev40}/src/jaxsim/integrators/variable_step.py +0 -0
  69. {jaxsim-0.3.1.dev17 → jaxsim-0.3.1.dev40}/src/jaxsim/logging.py +0 -0
  70. {jaxsim-0.3.1.dev17 → jaxsim-0.3.1.dev40}/src/jaxsim/math/__init__.py +0 -0
  71. {jaxsim-0.3.1.dev17 → jaxsim-0.3.1.dev40}/src/jaxsim/math/adjoint.py +0 -0
  72. {jaxsim-0.3.1.dev17 → jaxsim-0.3.1.dev40}/src/jaxsim/math/cross.py +0 -0
  73. {jaxsim-0.3.1.dev17 → jaxsim-0.3.1.dev40}/src/jaxsim/math/inertia.py +0 -0
  74. {jaxsim-0.3.1.dev17 → jaxsim-0.3.1.dev40}/src/jaxsim/math/joint_model.py +0 -0
  75. {jaxsim-0.3.1.dev17 → jaxsim-0.3.1.dev40}/src/jaxsim/math/quaternion.py +0 -0
  76. {jaxsim-0.3.1.dev17 → jaxsim-0.3.1.dev40}/src/jaxsim/math/rotation.py +0 -0
  77. {jaxsim-0.3.1.dev17 → jaxsim-0.3.1.dev40}/src/jaxsim/math/skew.py +0 -0
  78. {jaxsim-0.3.1.dev17 → jaxsim-0.3.1.dev40}/src/jaxsim/math/transform.py +0 -0
  79. {jaxsim-0.3.1.dev17 → jaxsim-0.3.1.dev40}/src/jaxsim/mujoco/__init__.py +0 -0
  80. {jaxsim-0.3.1.dev17 → jaxsim-0.3.1.dev40}/src/jaxsim/mujoco/__main__.py +0 -0
  81. {jaxsim-0.3.1.dev17 → jaxsim-0.3.1.dev40}/src/jaxsim/mujoco/loaders.py +0 -0
  82. {jaxsim-0.3.1.dev17 → jaxsim-0.3.1.dev40}/src/jaxsim/mujoco/model.py +0 -0
  83. {jaxsim-0.3.1.dev17 → jaxsim-0.3.1.dev40}/src/jaxsim/mujoco/visualizer.py +0 -0
  84. {jaxsim-0.3.1.dev17 → jaxsim-0.3.1.dev40}/src/jaxsim/parsers/__init__.py +0 -0
  85. {jaxsim-0.3.1.dev17 → jaxsim-0.3.1.dev40}/src/jaxsim/parsers/descriptions/__init__.py +0 -0
  86. {jaxsim-0.3.1.dev17 → jaxsim-0.3.1.dev40}/src/jaxsim/parsers/descriptions/collision.py +0 -0
  87. {jaxsim-0.3.1.dev17 → jaxsim-0.3.1.dev40}/src/jaxsim/parsers/descriptions/joint.py +0 -0
  88. {jaxsim-0.3.1.dev17 → jaxsim-0.3.1.dev40}/src/jaxsim/parsers/descriptions/link.py +0 -0
  89. {jaxsim-0.3.1.dev17 → jaxsim-0.3.1.dev40}/src/jaxsim/parsers/descriptions/model.py +0 -0
  90. {jaxsim-0.3.1.dev17 → jaxsim-0.3.1.dev40}/src/jaxsim/parsers/kinematic_graph.py +0 -0
  91. {jaxsim-0.3.1.dev17 → jaxsim-0.3.1.dev40}/src/jaxsim/parsers/rod/__init__.py +0 -0
  92. {jaxsim-0.3.1.dev17 → jaxsim-0.3.1.dev40}/src/jaxsim/parsers/rod/parser.py +0 -0
  93. {jaxsim-0.3.1.dev17 → jaxsim-0.3.1.dev40}/src/jaxsim/parsers/rod/utils.py +0 -0
  94. {jaxsim-0.3.1.dev17 → jaxsim-0.3.1.dev40}/src/jaxsim/rbda/aba.py +0 -0
  95. {jaxsim-0.3.1.dev17 → jaxsim-0.3.1.dev40}/src/jaxsim/rbda/collidable_points.py +0 -0
  96. {jaxsim-0.3.1.dev17/tests → jaxsim-0.3.1.dev40/src/jaxsim/rbda/contacts}/__init__.py +0 -0
  97. {jaxsim-0.3.1.dev17 → jaxsim-0.3.1.dev40}/src/jaxsim/rbda/crba.py +0 -0
  98. {jaxsim-0.3.1.dev17 → jaxsim-0.3.1.dev40}/src/jaxsim/rbda/forward_kinematics.py +0 -0
  99. {jaxsim-0.3.1.dev17 → jaxsim-0.3.1.dev40}/src/jaxsim/rbda/jacobian.py +0 -0
  100. {jaxsim-0.3.1.dev17 → jaxsim-0.3.1.dev40}/src/jaxsim/rbda/rnea.py +0 -0
  101. {jaxsim-0.3.1.dev17 → jaxsim-0.3.1.dev40}/src/jaxsim/rbda/utils.py +0 -0
  102. {jaxsim-0.3.1.dev17 → jaxsim-0.3.1.dev40}/src/jaxsim/terrain/__init__.py +0 -0
  103. {jaxsim-0.3.1.dev17 → jaxsim-0.3.1.dev40}/src/jaxsim/terrain/terrain.py +0 -0
  104. {jaxsim-0.3.1.dev17 → jaxsim-0.3.1.dev40}/src/jaxsim/typing.py +0 -0
  105. {jaxsim-0.3.1.dev17 → jaxsim-0.3.1.dev40}/src/jaxsim/utils/__init__.py +0 -0
  106. {jaxsim-0.3.1.dev17 → jaxsim-0.3.1.dev40}/src/jaxsim/utils/jaxsim_dataclass.py +0 -0
  107. {jaxsim-0.3.1.dev17 → jaxsim-0.3.1.dev40}/src/jaxsim/utils/tracing.py +0 -0
  108. {jaxsim-0.3.1.dev17 → jaxsim-0.3.1.dev40}/src/jaxsim/utils/wrappers.py +0 -0
  109. {jaxsim-0.3.1.dev17 → jaxsim-0.3.1.dev40}/src/jaxsim.egg-info/dependency_links.txt +0 -0
  110. {jaxsim-0.3.1.dev17 → jaxsim-0.3.1.dev40}/src/jaxsim.egg-info/not-zip-safe +0 -0
  111. {jaxsim-0.3.1.dev17 → jaxsim-0.3.1.dev40}/src/jaxsim.egg-info/requires.txt +0 -0
  112. {jaxsim-0.3.1.dev17 → jaxsim-0.3.1.dev40}/src/jaxsim.egg-info/top_level.txt +0 -0
  113. {jaxsim-0.3.1.dev17 → jaxsim-0.3.1.dev40}/tests/conftest.py +0 -0
  114. {jaxsim-0.3.1.dev17 → jaxsim-0.3.1.dev40}/tests/test_api_com.py +0 -0
  115. {jaxsim-0.3.1.dev17 → jaxsim-0.3.1.dev40}/tests/test_api_data.py +0 -0
  116. {jaxsim-0.3.1.dev17 → jaxsim-0.3.1.dev40}/tests/test_api_frame.py +0 -0
  117. {jaxsim-0.3.1.dev17 → jaxsim-0.3.1.dev40}/tests/test_api_joint.py +0 -0
  118. {jaxsim-0.3.1.dev17 → jaxsim-0.3.1.dev40}/tests/test_api_link.py +0 -0
  119. {jaxsim-0.3.1.dev17 → jaxsim-0.3.1.dev40}/tests/test_api_model.py +0 -0
  120. {jaxsim-0.3.1.dev17 → jaxsim-0.3.1.dev40}/tests/test_contact.py +0 -0
  121. {jaxsim-0.3.1.dev17 → jaxsim-0.3.1.dev40}/tests/test_pytree.py +0 -0
  122. {jaxsim-0.3.1.dev17 → jaxsim-0.3.1.dev40}/tests/utils_idyntree.py +0 -0
@@ -24,3 +24,8 @@ repos:
24
24
  rev: v0.3.2
25
25
  hooks:
26
26
  - id: ruff
27
+
28
+ - repo: https://github.com/kynan/nbstripout
29
+ rev: 0.7.1
30
+ hooks:
31
+ - id: nbstripout
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: jaxsim
3
- Version: 0.3.1.dev17
3
+ Version: 0.3.1.dev40
4
4
  Home-page: https://github.com/ami-iit/jaxsim
5
5
  Author: Diego Ferigo
6
6
  Author-email: diego.ferigo@iit.it
@@ -13,7 +13,6 @@ This module provides a set of algorithms for rigid body dynamics.
13
13
  crba
14
14
  forward_kinematics
15
15
  jacobian
16
- soft_contacts
17
16
  utils
18
17
 
19
18
  Articulated Body Algorithm
@@ -28,6 +27,12 @@ Collision Detection
28
27
  .. automodule:: jaxsim.rbda.collidable_points
29
28
  :members:
30
29
 
30
+ Contact Models
31
+ ~~~~~~~~~~~~~~
32
+
33
+ .. automodule:: jaxsim.rbda.contacts.soft
34
+ :members:
35
+
31
36
  Composite Rigid Body Algorithm
32
37
  ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
33
38
 
@@ -46,12 +51,6 @@ Jacobians
46
51
  .. automodule:: jaxsim.rbda.jacobian
47
52
  :members:
48
53
 
49
- Soft Contacts
50
- ~~~~~~~~~~~~~
51
-
52
- .. automodule:: jaxsim.rbda.soft_contacts
53
- :members:
54
-
55
54
  Utilities
56
55
  ~~~~~~~~~
57
56
 
@@ -12,5 +12,5 @@ __version__: str
12
12
  __version_tuple__: VERSION_TUPLE
13
13
  version_tuple: VERSION_TUPLE
14
14
 
15
- __version__ = version = '0.3.1.dev17'
16
- __version_tuple__ = version_tuple = (0, 3, 1, 'dev17')
15
+ __version__ = version = '0.3.1.dev40'
16
+ __version_tuple__ = version_tuple = (0, 3, 1, 'dev40')
@@ -1,11 +1,14 @@
1
+ from __future__ import annotations
2
+
1
3
  import functools
2
4
 
3
5
  import jax
4
6
  import jax.numpy as jnp
5
7
 
6
8
  import jaxsim.api as js
7
- import jaxsim.rbda
9
+ import jaxsim.terrain
8
10
  import jaxsim.typing as jtp
11
+ from jaxsim.rbda.contacts.soft import SoftContacts, SoftContactsParams
9
12
 
10
13
  from .common import VelRepr
11
14
 
@@ -135,17 +138,23 @@ def collidable_point_dynamics(
135
138
  W_p_Ci, W_ṗ_Ci = js.contact.collidable_point_kinematics(model=model, data=data)
136
139
 
137
140
  # Build the soft contact model.
138
- soft_contacts = jaxsim.rbda.SoftContacts(
139
- parameters=data.soft_contacts_params, terrain=model.terrain
140
- )
141
+ match model.contact_model:
142
+ case s if isinstance(s, SoftContacts):
143
+ # Build the contact model.
144
+ soft_contacts = SoftContacts(
145
+ parameters=data.contacts_params, terrain=model.terrain
146
+ )
147
+
148
+ # Compute the 6D force expressed in the inertial frame and applied to each
149
+ # collidable point, and the corresponding material deformation rate.
150
+ # Note that the material deformation rate is always returned in the mixed frame
151
+ # C[W] = (W_p_C, [W]). This is convenient for integration purpose.
152
+ W_f_Ci, (CW_ṁ,) = jax.vmap(soft_contacts.compute_contact_forces)(
153
+ W_p_Ci, W_ṗ_Ci, data.state.contact.tangential_deformation
154
+ )
141
155
 
142
- # Compute the 6D force expressed in the inertial frame and applied to each
143
- # collidable point, and the corresponding material deformation rate.
144
- # Note that the material deformation rate is always returned in the mixed frame
145
- # C[W] = (W_p_C, [W]). This is convenient for integration purpose.
146
- W_f_Ci, CW_ṁ = jax.vmap(soft_contacts.contact_model)(
147
- W_p_Ci, W_ṗ_Ci, data.state.soft_contacts.tangential_deformation
148
- )
156
+ case _:
157
+ raise ValueError("Invalid contact model {}".format(model.contact_model))
149
158
 
150
159
  # Convert the 6D forces to the active representation.
151
160
  f_Ci = jax.vmap(
@@ -213,7 +222,7 @@ def estimate_good_soft_contacts_parameters(
213
222
  number_of_active_collidable_points_steady_state: jtp.IntLike = 1,
214
223
  damping_ratio: jtp.FloatLike = 1.0,
215
224
  max_penetration: jtp.FloatLike | None = None,
216
- ) -> jaxsim.rbda.soft_contacts.SoftContactsParams:
225
+ ) -> SoftContactsParams:
217
226
  """
218
227
  Estimate good soft contacts parameters for the given model.
219
228
 
@@ -237,13 +246,14 @@ def estimate_good_soft_contacts_parameters(
237
246
  The user is encouraged to fine-tune the parameters based on the
238
247
  specific application.
239
248
  """
249
+ from jaxsim.rbda.contacts.soft import SoftContactsParams
240
250
 
241
251
  def estimate_model_height(model: js.model.JaxSimModel) -> jtp.Float:
242
252
  """"""
243
253
 
244
254
  zero_data = js.data.JaxSimModelData.build(
245
255
  model=model,
246
- soft_contacts_params=jaxsim.rbda.soft_contacts.SoftContactsParams(),
256
+ contacts_params=SoftContactsParams(),
247
257
  )
248
258
 
249
259
  W_pz_CoM = js.com.com_position(model=model, data=zero_data)[2]
@@ -262,15 +272,13 @@ def estimate_good_soft_contacts_parameters(
262
272
 
263
273
  nc = number_of_active_collidable_points_steady_state
264
274
 
265
- sc_parameters = (
266
- jaxsim.rbda.soft_contacts.SoftContactsParams.build_default_from_jaxsim_model(
267
- model=model,
268
- standard_gravity=standard_gravity,
269
- static_friction_coefficient=static_friction_coefficient,
270
- max_penetration=max_δ,
271
- number_of_active_collidable_points_steady_state=nc,
272
- damping_ratio=damping_ratio,
273
- )
275
+ sc_parameters = SoftContactsParams.build_default_from_jaxsim_model(
276
+ model=model,
277
+ standard_gravity=standard_gravity,
278
+ static_friction_coefficient=static_friction_coefficient,
279
+ max_penetration=max_δ,
280
+ number_of_active_collidable_points_steady_state=nc,
281
+ damping_ratio=damping_ratio,
274
282
  )
275
283
 
276
284
  return sc_parameters
@@ -14,6 +14,7 @@ import jaxsim.api as js
14
14
  import jaxsim.rbda
15
15
  import jaxsim.typing as jtp
16
16
  from jaxsim.math import Quaternion
17
+ from jaxsim.rbda.contacts.soft import SoftContacts
17
18
  from jaxsim.utils import Mutability
18
19
  from jaxsim.utils.tracing import not_tracing
19
20
 
@@ -37,7 +38,7 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
37
38
 
38
39
  gravity: jtp.Array
39
40
 
40
- soft_contacts_params: jaxsim.rbda.SoftContactsParams = dataclasses.field(repr=False)
41
+ contacts_params: jaxsim.rbda.ContactsParams = dataclasses.field(repr=False)
41
42
 
42
43
  time_ns: jtp.Int = dataclasses.field(
43
44
  default_factory=lambda: jnp.array(0, dtype=jnp.uint64)
@@ -51,8 +52,8 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
51
52
  (
52
53
  hash(self.state),
53
54
  HashedNumpyArray.hash_of_array(self.gravity),
54
- hash(self.soft_contacts_params),
55
55
  HashedNumpyArray.hash_of_array(self.time_ns),
56
+ hash(self.contacts_params),
56
57
  )
57
58
  )
58
59
 
@@ -112,8 +113,8 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
112
113
  base_angular_velocity: jtp.Vector | None = None,
113
114
  joint_velocities: jtp.Vector | None = None,
114
115
  standard_gravity: jtp.FloatLike = jaxsim.math.StandardGravity,
115
- soft_contacts_state: js.ode_data.SoftContactsState | None = None,
116
- soft_contacts_params: jaxsim.rbda.SoftContactsParams | None = None,
116
+ contact: jaxsim.rbda.ContactsState | None = None,
117
+ contacts_params: jaxsim.rbda.ContactsParams | None = None,
117
118
  velocity_representation: VelRepr = VelRepr.Inertial,
118
119
  time: jtp.FloatLike | None = None,
119
120
  ) -> JaxSimModelData:
@@ -131,8 +132,8 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
131
132
  The base angular velocity in the selected representation.
132
133
  joint_velocities: The joint velocities.
133
134
  standard_gravity: The standard gravity constant.
134
- soft_contacts_state: The state of the soft contacts.
135
- soft_contacts_params: The parameters of the soft contacts.
135
+ contact: The state of the soft contacts.
136
+ contacts_params: The parameters of the soft contacts.
136
137
  velocity_representation: The velocity representation to use.
137
138
  time: The time at which the state is created.
138
139
 
@@ -178,13 +179,16 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
178
179
  else jnp.array(0, dtype=jnp.uint64)
179
180
  )
180
181
 
181
- soft_contacts_params = (
182
- soft_contacts_params
183
- if soft_contacts_params is not None
184
- else js.contact.estimate_good_soft_contacts_parameters(
185
- model=model, standard_gravity=standard_gravity
182
+ if isinstance(model.contact_model, SoftContacts):
183
+ contacts_params = (
184
+ contacts_params
185
+ if contacts_params is not None
186
+ else js.contact.estimate_good_soft_contacts_parameters(
187
+ model=model, standard_gravity=standard_gravity
188
+ )
186
189
  )
187
- )
190
+ else:
191
+ contacts_params = model.contact_model.parameters
188
192
 
189
193
  W_H_B = jaxlie.SE3.from_rotation_and_translation(
190
194
  translation=base_position,
@@ -209,8 +213,8 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
209
213
  base_angular_velocity=v_WB[3:6].astype(float),
210
214
  joint_velocities=joint_velocities.astype(float),
211
215
  tangential_deformation=(
212
- soft_contacts_state.tangential_deformation
213
- if soft_contacts_state is not None
216
+ contact.tangential_deformation
217
+ if contact is not None and isinstance(model.contact_model, SoftContacts)
214
218
  else None
215
219
  ),
216
220
  )
@@ -222,7 +226,7 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
222
226
  time_ns=time_ns,
223
227
  state=ode_state,
224
228
  gravity=gravity.astype(float),
225
- soft_contacts_params=soft_contacts_params,
229
+ contacts_params=contacts_params,
226
230
  velocity_representation=velocity_representation,
227
231
  )
228
232
 
@@ -652,7 +656,10 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
652
656
 
653
657
  return self.reset_base_velocity(
654
658
  base_velocity=jnp.hstack(
655
- [linear_velocity.squeeze(), self.base_velocity()[3:6]]
659
+ [
660
+ linear_velocity.squeeze(),
661
+ self.base_velocity()[3:6],
662
+ ]
656
663
  ),
657
664
  velocity_representation=velocity_representation,
658
665
  )
@@ -680,7 +687,10 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
680
687
 
681
688
  return self.reset_base_velocity(
682
689
  base_velocity=jnp.hstack(
683
- [self.base_velocity()[0:3], angular_velocity.squeeze()]
690
+ [
691
+ self.base_velocity()[0:3],
692
+ angular_velocity.squeeze(),
693
+ ]
684
694
  ),
685
695
  velocity_representation=velocity_representation,
686
696
  )
@@ -34,6 +34,10 @@ class JaxSimModel(JaxsimDataclass):
34
34
  default=jaxsim.terrain.FlatTerrain(), repr=False
35
35
  )
36
36
 
37
+ contact_model: jaxsim.rbda.ContactModel | None = dataclasses.field(
38
+ default=None, repr=False
39
+ )
40
+
37
41
  kin_dyn_parameters: js.kin_dyn_parameters.KynDynParameters | None = (
38
42
  dataclasses.field(default=None, repr=False)
39
43
  )
@@ -69,6 +73,7 @@ class JaxSimModel(JaxsimDataclass):
69
73
  (
70
74
  hash(self.model_name),
71
75
  hash(self.kin_dyn_parameters),
76
+ hash(self.contact_model),
72
77
  )
73
78
  )
74
79
 
@@ -82,6 +87,7 @@ class JaxSimModel(JaxsimDataclass):
82
87
  model_name: str | None = None,
83
88
  *,
84
89
  terrain: jaxsim.terrain.Terrain | None = None,
90
+ contact_model: jaxsim.rbda.ContactModel | None = None,
85
91
  is_urdf: bool | None = None,
86
92
  considered_joints: Sequence[str] | None = None,
87
93
  ) -> JaxSimModel:
@@ -127,6 +133,7 @@ class JaxSimModel(JaxsimDataclass):
127
133
  model_description=intermediate_description,
128
134
  model_name=model_name,
129
135
  terrain=terrain,
136
+ contact_model=contact_model,
130
137
  )
131
138
 
132
139
  # Store the origin of the model, in case downstream logic needs it
@@ -141,6 +148,7 @@ class JaxSimModel(JaxsimDataclass):
141
148
  model_name: str | None = None,
142
149
  *,
143
150
  terrain: jaxsim.terrain.Terrain | None = None,
151
+ contact_model: jaxsim.rbda.ContactModel | None = None,
144
152
  ) -> JaxSimModel:
145
153
  """
146
154
  Build a Model object from an intermediate model description.
@@ -153,22 +161,30 @@ class JaxSimModel(JaxsimDataclass):
153
161
  The optional name of the model overriding the physics model name.
154
162
  terrain:
155
163
  The optional terrain to consider.
164
+ contact_model:
165
+ The optional contact model to consider. If None, the soft contact model is used.
156
166
 
157
167
  Returns:
158
168
  The built Model object.
159
169
  """
170
+ from jaxsim.rbda.contacts.soft import SoftContacts
160
171
 
161
172
  # Set the model name (if not provided, use the one from the model description)
162
173
  model_name = model_name if model_name is not None else model_description.name
163
174
 
164
- # Build the model.
175
+ # Set the terrain (if not provided, use the default flat terrain)
176
+ terrain = terrain or JaxSimModel.__dataclass_fields__["terrain"].default
177
+ contact_model = contact_model or SoftContacts(terrain=terrain)
178
+
179
+ # Build the model
165
180
  model = JaxSimModel(
166
181
  model_name=model_name,
167
182
  _description=wrappers.HashlessObject(obj=model_description),
168
183
  kin_dyn_parameters=js.kin_dyn_parameters.KynDynParameters.build(
169
184
  model_description=model_description
170
185
  ),
171
- terrain=terrain or JaxSimModel.__dataclass_fields__["terrain"].default,
186
+ terrain=terrain,
187
+ contact_model=contact_model,
172
188
  )
173
189
 
174
190
  return model
@@ -350,6 +366,7 @@ def reduce(
350
366
  model_description=reduced_intermediate_description,
351
367
  model_name=model.name(),
352
368
  terrain=model.terrain,
369
+ contact_model=model.contact_model,
353
370
  )
354
371
 
355
372
  # Store the origin of the model, in case downstream logic needs it
@@ -132,7 +132,7 @@ def system_velocity_dynamics(
132
132
  W_f_Ci = None
133
133
 
134
134
  # Initialize the derivative of the tangential deformation ṁ ∈ ℝ^{n_c × 3}.
135
- ṁ = jnp.zeros_like(data.state.soft_contacts.tangential_deformation).astype(float)
135
+ ṁ = jnp.zeros_like(data.state.contact.tangential_deformation).astype(float)
136
136
 
137
137
  if len(model.kin_dyn_parameters.contact_parameters.body) > 0:
138
138
  # Compute the 6D forces applied to each collidable point and the
@@ -5,6 +5,8 @@ import jax_dataclasses
5
5
 
6
6
  import jaxsim.api as js
7
7
  import jaxsim.typing as jtp
8
+ from jaxsim.rbda import ContactsState
9
+ from jaxsim.rbda.contacts.soft import SoftContacts, SoftContactsState
8
10
  from jaxsim.utils import JaxsimDataclass
9
11
 
10
12
  # =============================================================================
@@ -116,11 +118,11 @@ class ODEState(JaxsimDataclass):
116
118
 
117
119
  Attributes:
118
120
  physics_model: The state of the physics model.
119
- soft_contacts: The state of the soft-contacts model.
121
+ contact: The state of the contacts model.
120
122
  """
121
123
 
122
124
  physics_model: PhysicsModelState
123
- soft_contacts: SoftContactsState
125
+ contact: ContactsState
124
126
 
125
127
  @staticmethod
126
128
  def build_from_jaxsim_model(
@@ -158,6 +160,20 @@ class ODEState(JaxsimDataclass):
158
160
  `JaxSimModel` and initialized to zero.
159
161
  """
160
162
 
163
+ # Get the contact model from the `JaxSimModel`
164
+ match model.contact_model:
165
+ case SoftContacts():
166
+ contact = SoftContactsState.build_from_jaxsim_model(
167
+ model=model,
168
+ **(
169
+ dict(tangential_deformation=tangential_deformation)
170
+ if tangential_deformation is not None
171
+ else dict()
172
+ ),
173
+ )
174
+ case _:
175
+ raise ValueError("Unable to determine contact state class prefix.")
176
+
161
177
  return ODEState.build(
162
178
  model=model,
163
179
  physics_model_state=PhysicsModelState.build_from_jaxsim_model(
@@ -169,24 +185,21 @@ class ODEState(JaxsimDataclass):
169
185
  base_linear_velocity=base_linear_velocity,
170
186
  base_angular_velocity=base_angular_velocity,
171
187
  ),
172
- soft_contacts_state=SoftContactsState.build_from_jaxsim_model(
173
- model=model,
174
- tangential_deformation=tangential_deformation,
175
- ),
188
+ contact=contact,
176
189
  )
177
190
 
178
191
  @staticmethod
179
192
  def build(
180
193
  physics_model_state: PhysicsModelState | None = None,
181
- soft_contacts_state: SoftContactsState | None = None,
194
+ contact: ContactsState | None = None,
182
195
  model: js.model.JaxSimModel | None = None,
183
196
  ) -> ODEState:
184
197
  """
185
- Build an `ODEState` from a `PhysicsModelState` and a `SoftContactsState`.
198
+ Build an `ODEState` from a `PhysicsModelState` and a `ContactsState`.
186
199
 
187
200
  Args:
188
201
  physics_model_state: The state of the physics model.
189
- soft_contacts_state: The state of the soft-contacts model.
202
+ contact: The state of the contacts model.
190
203
  model: The `JaxSimModel` associated with the ODE state.
191
204
 
192
205
  Returns:
@@ -199,15 +212,16 @@ class ODEState(JaxsimDataclass):
199
212
  else PhysicsModelState.zero(model=model)
200
213
  )
201
214
 
202
- soft_contacts_state = (
203
- soft_contacts_state
204
- if soft_contacts_state is not None
205
- else SoftContactsState.zero(model=model)
206
- )
215
+ # Get the contact model from the `JaxSimModel`
216
+ match contact:
217
+ case SoftContactsState():
218
+ pass
219
+ case None:
220
+ contact = SoftContactsState.zero(model=model)
221
+ case _:
222
+ raise ValueError("Unable to determine contact state class prefix.")
207
223
 
208
- return ODEState(
209
- physics_model=physics_model_state, soft_contacts=soft_contacts_state
210
- )
224
+ return ODEState(physics_model=physics_model_state, contact=contact)
211
225
 
212
226
  @staticmethod
213
227
  def zero(model: js.model.JaxSimModel) -> ODEState:
@@ -236,9 +250,7 @@ class ODEState(JaxsimDataclass):
236
250
  `True` if the ODE state is valid for the given model, `False` otherwise.
237
251
  """
238
252
 
239
- return self.physics_model.valid(model=model) and self.soft_contacts.valid(
240
- model=model
241
- )
253
+ return self.physics_model.valid(model=model) and self.contact.valid(model=model)
242
254
 
243
255
 
244
256
  # ==================================================
@@ -595,135 +607,3 @@ class PhysicsModelInput(JaxsimDataclass):
595
607
  return False
596
608
 
597
609
  return True
598
-
599
-
600
- # ===========================================
601
- # Define the state of the soft-contacts model
602
- # ===========================================
603
-
604
-
605
- @jax_dataclasses.pytree_dataclass
606
- class SoftContactsState(JaxsimDataclass):
607
- """
608
- Class storing the state of the soft contacts model.
609
-
610
- Attributes:
611
- tangential_deformation:
612
- The matrix of 3D tangential material deformations corresponding to
613
- each collidable point.
614
- """
615
-
616
- tangential_deformation: jtp.Matrix
617
-
618
- def __hash__(self) -> int:
619
-
620
- from jaxsim.utils.wrappers import HashedNumpyArray
621
-
622
- return HashedNumpyArray.hash_of_array(self.tangential_deformation)
623
-
624
- def __eq__(self, other: SoftContactsState) -> bool:
625
-
626
- if not isinstance(other, SoftContactsState):
627
- return False
628
-
629
- return hash(self) == hash(other)
630
-
631
- @staticmethod
632
- def build_from_jaxsim_model(
633
- model: js.model.JaxSimModel | None = None,
634
- tangential_deformation: jtp.Matrix | None = None,
635
- ) -> SoftContactsState:
636
- """
637
- Build a `SoftContactsState` from a `JaxSimModel`.
638
-
639
- Args:
640
- model: The `JaxSimModel` associated with the soft contacts state.
641
- tangential_deformation: The matrix of 3D tangential material deformations.
642
-
643
- Returns:
644
- The `SoftContactsState` built from the `JaxSimModel`.
645
-
646
- Note:
647
- If any of the state components are not provided, they are built from the
648
- `JaxSimModel` and initialized to zero.
649
- """
650
-
651
- return SoftContactsState.build(
652
- tangential_deformation=tangential_deformation,
653
- number_of_collidable_points=len(
654
- model.kin_dyn_parameters.contact_parameters.body
655
- ),
656
- )
657
-
658
- @staticmethod
659
- def build(
660
- tangential_deformation: jtp.Matrix | None = None,
661
- number_of_collidable_points: int | None = None,
662
- ) -> SoftContactsState:
663
- """
664
- Create a `SoftContactsState`.
665
-
666
- Args:
667
- tangential_deformation:
668
- The matrix of 3D tangential material deformations corresponding to
669
- each collidable point.
670
- number_of_collidable_points: The number of collidable points.
671
-
672
- Returns:
673
- A `SoftContactsState` instance.
674
- """
675
-
676
- tangential_deformation = (
677
- tangential_deformation
678
- if tangential_deformation is not None
679
- else jnp.zeros(shape=(number_of_collidable_points, 3))
680
- )
681
-
682
- if tangential_deformation.shape[1] != 3:
683
- raise RuntimeError("The tangential deformation matrix must have 3 columns.")
684
-
685
- if (
686
- number_of_collidable_points is not None
687
- and tangential_deformation.shape[0] != number_of_collidable_points
688
- ):
689
- msg = "The number of collidable points must match the number of rows "
690
- msg += "in the tangential deformation matrix."
691
- raise RuntimeError(msg)
692
-
693
- return SoftContactsState(
694
- tangential_deformation=jnp.array(tangential_deformation).astype(float)
695
- )
696
-
697
- @staticmethod
698
- def zero(model: js.model.JaxSimModel) -> SoftContactsState:
699
- """
700
- Build a zero `SoftContactsState` from a `JaxSimModel`.
701
-
702
- Args:
703
- model: The `JaxSimModel` associated with the soft contacts state.
704
-
705
- Returns:
706
- A zero `SoftContactsState` instance.
707
- """
708
-
709
- return SoftContactsState.build_from_jaxsim_model(model=model)
710
-
711
- def valid(self, model: js.model.JaxSimModel) -> bool:
712
- """
713
- Check if the `SoftContactsState` is valid for a given `JaxSimModel`.
714
-
715
- Args:
716
- model: The `JaxSimModel` to validate the `SoftContactsState` against.
717
-
718
- Returns:
719
- `True` if the soft contacts state is valid for the given `JaxSimModel`,
720
- `False` otherwise.
721
- """
722
-
723
- shape = self.tangential_deformation.shape
724
- expected = (len(model.kin_dyn_parameters.contact_parameters.body), 3)
725
-
726
- if shape != expected:
727
- return False
728
-
729
- return True
@@ -0,0 +1,63 @@
1
+ import jax
2
+
3
+
4
+ def raise_if(
5
+ condition: bool | jax.Array, exception: type, msg: str, *args, **kwargs
6
+ ) -> None:
7
+ """
8
+ Raise a host-side exception if a condition is met. Useful in jit-compiled functions.
9
+
10
+ Args:
11
+ condition:
12
+ The boolean condition of the evaluated expression that triggers
13
+ the exception during runtime.
14
+ exception: The type of exception to raise.
15
+ msg:
16
+ The message to display when the exception is raised. The message can be a
17
+ format string (fmt), whose fields are filled with the args and kwargs.
18
+ """
19
+
20
+ # Check early that the format string is well-formed.
21
+ try:
22
+ _ = msg.format(*args, **kwargs)
23
+ except Exception as e:
24
+ msg = "Error in formatting exception message with args={} and kwargs={}"
25
+ raise ValueError(msg.format(args, kwargs)) from e
26
+
27
+ def _raise_exception(condition: bool, *args, **kwargs) -> None:
28
+ """The function called by the JAX callback."""
29
+
30
+ if condition:
31
+ raise exception(msg.format(*args, **kwargs))
32
+
33
+ def _callback(args, kwargs) -> None:
34
+ """The function that calls the JAX callback, executed only when needed."""
35
+
36
+ jax.debug.callback(_raise_exception, condition, *args, **kwargs)
37
+
38
+ # Since running a callable on the host is expensive, we prevent its execution
39
+ # if the condition is False with a low-level conditional expression.
40
+ def _run_callback_only_if_condition_is_true(*args, **kwargs) -> None:
41
+ return jax.lax.cond(
42
+ condition,
43
+ _callback,
44
+ lambda args, kwargs: None,
45
+ args,
46
+ kwargs,
47
+ )
48
+
49
+ return _run_callback_only_if_condition_is_true(*args, **kwargs)
50
+
51
+
52
+ def raise_runtime_error_if(
53
+ condition: bool | jax.Array, msg: str, *args, **kwargs
54
+ ) -> None:
55
+
56
+ return raise_if(condition, RuntimeError, msg, *args, **kwargs)
57
+
58
+
59
+ def raise_value_error_if(
60
+ condition: bool | jax.Array, msg: str, *args, **kwargs
61
+ ) -> None:
62
+
63
+ return raise_if(condition, ValueError, msg, *args, **kwargs)
@@ -1,5 +1,6 @@
1
1
  from .aba import aba
2
2
  from .collidable_points import collidable_points_pos_vel
3
+ from .contacts.common import ContactModel, ContactsParams, ContactsState
3
4
  from .crba import crba
4
5
  from .forward_kinematics import forward_kinematics, forward_kinematics_model
5
6
  from .jacobian import (
@@ -8,4 +9,3 @@ from .jacobian import (
8
9
  jacobian_full_doubly_left,
9
10
  )
10
11
  from .rnea import rnea
11
- from .soft_contacts import SoftContacts, SoftContactsParams