jaxsim 0.3.1.dev21__tar.gz → 0.3.1.dev46__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (122)
  1. {jaxsim-0.3.1.dev21 → jaxsim-0.3.1.dev46}/.pre-commit-config.yaml +5 -0
  2. {jaxsim-0.3.1.dev21 → jaxsim-0.3.1.dev46}/PKG-INFO +1 -1
  3. {jaxsim-0.3.1.dev21 → jaxsim-0.3.1.dev46}/docs/modules/rbda.rst +6 -7
  4. {jaxsim-0.3.1.dev21 → jaxsim-0.3.1.dev46}/src/jaxsim/_version.py +2 -2
  5. {jaxsim-0.3.1.dev21 → jaxsim-0.3.1.dev46}/src/jaxsim/api/contact.py +30 -22
  6. {jaxsim-0.3.1.dev21 → jaxsim-0.3.1.dev46}/src/jaxsim/api/data.py +27 -17
  7. {jaxsim-0.3.1.dev21 → jaxsim-0.3.1.dev46}/src/jaxsim/api/frame.py +53 -18
  8. {jaxsim-0.3.1.dev21 → jaxsim-0.3.1.dev46}/src/jaxsim/api/joint.py +28 -7
  9. {jaxsim-0.3.1.dev21 → jaxsim-0.3.1.dev46}/src/jaxsim/api/link.py +73 -10
  10. {jaxsim-0.3.1.dev21 → jaxsim-0.3.1.dev46}/src/jaxsim/api/model.py +19 -2
  11. {jaxsim-0.3.1.dev21 → jaxsim-0.3.1.dev46}/src/jaxsim/api/ode.py +1 -1
  12. {jaxsim-0.3.1.dev21 → jaxsim-0.3.1.dev46}/src/jaxsim/api/ode_data.py +32 -152
  13. {jaxsim-0.3.1.dev21 → jaxsim-0.3.1.dev46}/src/jaxsim/rbda/__init__.py +1 -1
  14. jaxsim-0.3.1.dev46/src/jaxsim/rbda/contacts/common.py +101 -0
  15. jaxsim-0.3.1.dev21/src/jaxsim/rbda/soft_contacts.py → jaxsim-0.3.1.dev46/src/jaxsim/rbda/contacts/soft.py +149 -9
  16. {jaxsim-0.3.1.dev21 → jaxsim-0.3.1.dev46}/src/jaxsim.egg-info/PKG-INFO +1 -1
  17. {jaxsim-0.3.1.dev21 → jaxsim-0.3.1.dev46}/src/jaxsim.egg-info/SOURCES.txt +3 -1
  18. jaxsim-0.3.1.dev46/tests/__init__.py +0 -0
  19. {jaxsim-0.3.1.dev21 → jaxsim-0.3.1.dev46}/tests/test_api_frame.py +25 -3
  20. {jaxsim-0.3.1.dev21 → jaxsim-0.3.1.dev46}/tests/test_api_joint.py +10 -0
  21. {jaxsim-0.3.1.dev21 → jaxsim-0.3.1.dev46}/tests/test_api_link.py +10 -0
  22. {jaxsim-0.3.1.dev21 → jaxsim-0.3.1.dev46}/tests/test_automatic_differentiation.py +7 -7
  23. {jaxsim-0.3.1.dev21 → jaxsim-0.3.1.dev46}/tests/test_simulations.py +2 -1
  24. {jaxsim-0.3.1.dev21 → jaxsim-0.3.1.dev46}/.devcontainer/Dockerfile +0 -0
  25. {jaxsim-0.3.1.dev21 → jaxsim-0.3.1.dev46}/.devcontainer/devcontainer.json +0 -0
  26. {jaxsim-0.3.1.dev21 → jaxsim-0.3.1.dev46}/.gitattributes +0 -0
  27. {jaxsim-0.3.1.dev21 → jaxsim-0.3.1.dev46}/.github/CODEOWNERS +0 -0
  28. {jaxsim-0.3.1.dev21 → jaxsim-0.3.1.dev46}/.github/workflows/ci_cd.yml +0 -0
  29. {jaxsim-0.3.1.dev21 → jaxsim-0.3.1.dev46}/.github/workflows/read_the_docs.yml +0 -0
  30. {jaxsim-0.3.1.dev21 → jaxsim-0.3.1.dev46}/.github/workflows/style.yml +0 -0
  31. {jaxsim-0.3.1.dev21 → jaxsim-0.3.1.dev46}/.gitignore +0 -0
  32. {jaxsim-0.3.1.dev21 → jaxsim-0.3.1.dev46}/.readthedocs.yaml +0 -0
  33. {jaxsim-0.3.1.dev21 → jaxsim-0.3.1.dev46}/CONTRIBUTING.md +0 -0
  34. {jaxsim-0.3.1.dev21 → jaxsim-0.3.1.dev46}/LICENSE +0 -0
  35. {jaxsim-0.3.1.dev21 → jaxsim-0.3.1.dev46}/README.md +0 -0
  36. {jaxsim-0.3.1.dev21 → jaxsim-0.3.1.dev46}/docs/Makefile +0 -0
  37. {jaxsim-0.3.1.dev21 → jaxsim-0.3.1.dev46}/docs/conf.py +0 -0
  38. {jaxsim-0.3.1.dev21 → jaxsim-0.3.1.dev46}/docs/guide/install.rst +0 -0
  39. {jaxsim-0.3.1.dev21 → jaxsim-0.3.1.dev46}/docs/index.rst +0 -0
  40. {jaxsim-0.3.1.dev21 → jaxsim-0.3.1.dev46}/docs/make.bat +0 -0
  41. {jaxsim-0.3.1.dev21 → jaxsim-0.3.1.dev46}/docs/modules/api.rst +0 -0
  42. {jaxsim-0.3.1.dev21 → jaxsim-0.3.1.dev46}/docs/modules/index.rst +0 -0
  43. {jaxsim-0.3.1.dev21 → jaxsim-0.3.1.dev46}/docs/modules/integrators.rst +0 -0
  44. {jaxsim-0.3.1.dev21 → jaxsim-0.3.1.dev46}/docs/modules/math.rst +0 -0
  45. {jaxsim-0.3.1.dev21 → jaxsim-0.3.1.dev46}/docs/modules/mujoco.rst +0 -0
  46. {jaxsim-0.3.1.dev21 → jaxsim-0.3.1.dev46}/docs/modules/parsers.rst +0 -0
  47. {jaxsim-0.3.1.dev21 → jaxsim-0.3.1.dev46}/docs/modules/typing.rst +0 -0
  48. {jaxsim-0.3.1.dev21 → jaxsim-0.3.1.dev46}/docs/modules/utils.rst +0 -0
  49. {jaxsim-0.3.1.dev21 → jaxsim-0.3.1.dev46}/environment.yml +0 -0
  50. {jaxsim-0.3.1.dev21 → jaxsim-0.3.1.dev46}/examples/.gitattributes +0 -0
  51. {jaxsim-0.3.1.dev21 → jaxsim-0.3.1.dev46}/examples/.gitignore +0 -0
  52. {jaxsim-0.3.1.dev21 → jaxsim-0.3.1.dev46}/examples/PD_controller.ipynb +0 -0
  53. {jaxsim-0.3.1.dev21 → jaxsim-0.3.1.dev46}/examples/Parallel_computing.ipynb +0 -0
  54. {jaxsim-0.3.1.dev21 → jaxsim-0.3.1.dev46}/examples/README.md +0 -0
  55. {jaxsim-0.3.1.dev21 → jaxsim-0.3.1.dev46}/examples/assets/cartpole.urdf +0 -0
  56. {jaxsim-0.3.1.dev21 → jaxsim-0.3.1.dev46}/pixi.lock +0 -0
  57. {jaxsim-0.3.1.dev21 → jaxsim-0.3.1.dev46}/pyproject.toml +0 -0
  58. {jaxsim-0.3.1.dev21 → jaxsim-0.3.1.dev46}/setup.cfg +0 -0
  59. {jaxsim-0.3.1.dev21 → jaxsim-0.3.1.dev46}/setup.py +0 -0
  60. {jaxsim-0.3.1.dev21 → jaxsim-0.3.1.dev46}/src/jaxsim/__init__.py +0 -0
  61. {jaxsim-0.3.1.dev21 → jaxsim-0.3.1.dev46}/src/jaxsim/api/__init__.py +0 -0
  62. {jaxsim-0.3.1.dev21 → jaxsim-0.3.1.dev46}/src/jaxsim/api/com.py +0 -0
  63. {jaxsim-0.3.1.dev21 → jaxsim-0.3.1.dev46}/src/jaxsim/api/common.py +0 -0
  64. {jaxsim-0.3.1.dev21 → jaxsim-0.3.1.dev46}/src/jaxsim/api/kin_dyn_parameters.py +0 -0
  65. {jaxsim-0.3.1.dev21 → jaxsim-0.3.1.dev46}/src/jaxsim/api/references.py +0 -0
  66. {jaxsim-0.3.1.dev21 → jaxsim-0.3.1.dev46}/src/jaxsim/exceptions.py +0 -0
  67. {jaxsim-0.3.1.dev21 → jaxsim-0.3.1.dev46}/src/jaxsim/integrators/__init__.py +0 -0
  68. {jaxsim-0.3.1.dev21 → jaxsim-0.3.1.dev46}/src/jaxsim/integrators/common.py +0 -0
  69. {jaxsim-0.3.1.dev21 → jaxsim-0.3.1.dev46}/src/jaxsim/integrators/fixed_step.py +0 -0
  70. {jaxsim-0.3.1.dev21 → jaxsim-0.3.1.dev46}/src/jaxsim/integrators/variable_step.py +0 -0
  71. {jaxsim-0.3.1.dev21 → jaxsim-0.3.1.dev46}/src/jaxsim/logging.py +0 -0
  72. {jaxsim-0.3.1.dev21 → jaxsim-0.3.1.dev46}/src/jaxsim/math/__init__.py +0 -0
  73. {jaxsim-0.3.1.dev21 → jaxsim-0.3.1.dev46}/src/jaxsim/math/adjoint.py +0 -0
  74. {jaxsim-0.3.1.dev21 → jaxsim-0.3.1.dev46}/src/jaxsim/math/cross.py +0 -0
  75. {jaxsim-0.3.1.dev21 → jaxsim-0.3.1.dev46}/src/jaxsim/math/inertia.py +0 -0
  76. {jaxsim-0.3.1.dev21 → jaxsim-0.3.1.dev46}/src/jaxsim/math/joint_model.py +0 -0
  77. {jaxsim-0.3.1.dev21 → jaxsim-0.3.1.dev46}/src/jaxsim/math/quaternion.py +0 -0
  78. {jaxsim-0.3.1.dev21 → jaxsim-0.3.1.dev46}/src/jaxsim/math/rotation.py +0 -0
  79. {jaxsim-0.3.1.dev21 → jaxsim-0.3.1.dev46}/src/jaxsim/math/skew.py +0 -0
  80. {jaxsim-0.3.1.dev21 → jaxsim-0.3.1.dev46}/src/jaxsim/math/transform.py +0 -0
  81. {jaxsim-0.3.1.dev21 → jaxsim-0.3.1.dev46}/src/jaxsim/mujoco/__init__.py +0 -0
  82. {jaxsim-0.3.1.dev21 → jaxsim-0.3.1.dev46}/src/jaxsim/mujoco/__main__.py +0 -0
  83. {jaxsim-0.3.1.dev21 → jaxsim-0.3.1.dev46}/src/jaxsim/mujoco/loaders.py +0 -0
  84. {jaxsim-0.3.1.dev21 → jaxsim-0.3.1.dev46}/src/jaxsim/mujoco/model.py +0 -0
  85. {jaxsim-0.3.1.dev21 → jaxsim-0.3.1.dev46}/src/jaxsim/mujoco/visualizer.py +0 -0
  86. {jaxsim-0.3.1.dev21 → jaxsim-0.3.1.dev46}/src/jaxsim/parsers/__init__.py +0 -0
  87. {jaxsim-0.3.1.dev21 → jaxsim-0.3.1.dev46}/src/jaxsim/parsers/descriptions/__init__.py +0 -0
  88. {jaxsim-0.3.1.dev21 → jaxsim-0.3.1.dev46}/src/jaxsim/parsers/descriptions/collision.py +0 -0
  89. {jaxsim-0.3.1.dev21 → jaxsim-0.3.1.dev46}/src/jaxsim/parsers/descriptions/joint.py +0 -0
  90. {jaxsim-0.3.1.dev21 → jaxsim-0.3.1.dev46}/src/jaxsim/parsers/descriptions/link.py +0 -0
  91. {jaxsim-0.3.1.dev21 → jaxsim-0.3.1.dev46}/src/jaxsim/parsers/descriptions/model.py +0 -0
  92. {jaxsim-0.3.1.dev21 → jaxsim-0.3.1.dev46}/src/jaxsim/parsers/kinematic_graph.py +0 -0
  93. {jaxsim-0.3.1.dev21 → jaxsim-0.3.1.dev46}/src/jaxsim/parsers/rod/__init__.py +0 -0
  94. {jaxsim-0.3.1.dev21 → jaxsim-0.3.1.dev46}/src/jaxsim/parsers/rod/parser.py +0 -0
  95. {jaxsim-0.3.1.dev21 → jaxsim-0.3.1.dev46}/src/jaxsim/parsers/rod/utils.py +0 -0
  96. {jaxsim-0.3.1.dev21 → jaxsim-0.3.1.dev46}/src/jaxsim/rbda/aba.py +0 -0
  97. {jaxsim-0.3.1.dev21 → jaxsim-0.3.1.dev46}/src/jaxsim/rbda/collidable_points.py +0 -0
  98. {jaxsim-0.3.1.dev21/tests → jaxsim-0.3.1.dev46/src/jaxsim/rbda/contacts}/__init__.py +0 -0
  99. {jaxsim-0.3.1.dev21 → jaxsim-0.3.1.dev46}/src/jaxsim/rbda/crba.py +0 -0
  100. {jaxsim-0.3.1.dev21 → jaxsim-0.3.1.dev46}/src/jaxsim/rbda/forward_kinematics.py +0 -0
  101. {jaxsim-0.3.1.dev21 → jaxsim-0.3.1.dev46}/src/jaxsim/rbda/jacobian.py +0 -0
  102. {jaxsim-0.3.1.dev21 → jaxsim-0.3.1.dev46}/src/jaxsim/rbda/rnea.py +0 -0
  103. {jaxsim-0.3.1.dev21 → jaxsim-0.3.1.dev46}/src/jaxsim/rbda/utils.py +0 -0
  104. {jaxsim-0.3.1.dev21 → jaxsim-0.3.1.dev46}/src/jaxsim/terrain/__init__.py +0 -0
  105. {jaxsim-0.3.1.dev21 → jaxsim-0.3.1.dev46}/src/jaxsim/terrain/terrain.py +0 -0
  106. {jaxsim-0.3.1.dev21 → jaxsim-0.3.1.dev46}/src/jaxsim/typing.py +0 -0
  107. {jaxsim-0.3.1.dev21 → jaxsim-0.3.1.dev46}/src/jaxsim/utils/__init__.py +0 -0
  108. {jaxsim-0.3.1.dev21 → jaxsim-0.3.1.dev46}/src/jaxsim/utils/jaxsim_dataclass.py +0 -0
  109. {jaxsim-0.3.1.dev21 → jaxsim-0.3.1.dev46}/src/jaxsim/utils/tracing.py +0 -0
  110. {jaxsim-0.3.1.dev21 → jaxsim-0.3.1.dev46}/src/jaxsim/utils/wrappers.py +0 -0
  111. {jaxsim-0.3.1.dev21 → jaxsim-0.3.1.dev46}/src/jaxsim.egg-info/dependency_links.txt +0 -0
  112. {jaxsim-0.3.1.dev21 → jaxsim-0.3.1.dev46}/src/jaxsim.egg-info/not-zip-safe +0 -0
  113. {jaxsim-0.3.1.dev21 → jaxsim-0.3.1.dev46}/src/jaxsim.egg-info/requires.txt +0 -0
  114. {jaxsim-0.3.1.dev21 → jaxsim-0.3.1.dev46}/src/jaxsim.egg-info/top_level.txt +0 -0
  115. {jaxsim-0.3.1.dev21 → jaxsim-0.3.1.dev46}/tests/conftest.py +0 -0
  116. {jaxsim-0.3.1.dev21 → jaxsim-0.3.1.dev46}/tests/test_api_com.py +0 -0
  117. {jaxsim-0.3.1.dev21 → jaxsim-0.3.1.dev46}/tests/test_api_data.py +0 -0
  118. {jaxsim-0.3.1.dev21 → jaxsim-0.3.1.dev46}/tests/test_api_model.py +0 -0
  119. {jaxsim-0.3.1.dev21 → jaxsim-0.3.1.dev46}/tests/test_contact.py +0 -0
  120. {jaxsim-0.3.1.dev21 → jaxsim-0.3.1.dev46}/tests/test_exceptions.py +0 -0
  121. {jaxsim-0.3.1.dev21 → jaxsim-0.3.1.dev46}/tests/test_pytree.py +0 -0
  122. {jaxsim-0.3.1.dev21 → jaxsim-0.3.1.dev46}/tests/utils_idyntree.py +0 -0
@@ -24,3 +24,8 @@ repos:
24
24
  rev: v0.3.2
25
25
  hooks:
26
26
  - id: ruff
27
+
28
+ - repo: https://github.com/kynan/nbstripout
29
+ rev: 0.7.1
30
+ hooks:
31
+ - id: nbstripout
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: jaxsim
3
- Version: 0.3.1.dev21
3
+ Version: 0.3.1.dev46
4
4
  Home-page: https://github.com/ami-iit/jaxsim
5
5
  Author: Diego Ferigo
6
6
  Author-email: diego.ferigo@iit.it
@@ -13,7 +13,6 @@ This module provides a set of algorithms for rigid body dynamics.
13
13
  crba
14
14
  forward_kinematics
15
15
  jacobian
16
- soft_contacts
17
16
  utils
18
17
 
19
18
  Articulated Body Algorithm
@@ -28,6 +27,12 @@ Collision Detection
28
27
  .. automodule:: jaxsim.rbda.collidable_points
29
28
  :members:
30
29
 
30
+ Contact Models
31
+ ~~~~~~~~~~~~~~
32
+
33
+ .. automodule:: jaxsim.rbda.contacts.soft
34
+ :members:
35
+
31
36
  Composite Rigid Body Algorithm
32
37
  ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
33
38
 
@@ -46,12 +51,6 @@ Jacobians
46
51
  .. automodule:: jaxsim.rbda.jacobian
47
52
  :members:
48
53
 
49
- Soft Contacts
50
- ~~~~~~~~~~~~~
51
-
52
- .. automodule:: jaxsim.rbda.soft_contacts
53
- :members:
54
-
55
54
  Utilities
56
55
  ~~~~~~~~~
57
56
 
@@ -12,5 +12,5 @@ __version__: str
12
12
  __version_tuple__: VERSION_TUPLE
13
13
  version_tuple: VERSION_TUPLE
14
14
 
15
- __version__ = version = '0.3.1.dev21'
16
- __version_tuple__ = version_tuple = (0, 3, 1, 'dev21')
15
+ __version__ = version = '0.3.1.dev46'
16
+ __version_tuple__ = version_tuple = (0, 3, 1, 'dev46')
@@ -1,11 +1,14 @@
1
+ from __future__ import annotations
2
+
1
3
  import functools
2
4
 
3
5
  import jax
4
6
  import jax.numpy as jnp
5
7
 
6
8
  import jaxsim.api as js
7
- import jaxsim.rbda
9
+ import jaxsim.terrain
8
10
  import jaxsim.typing as jtp
11
+ from jaxsim.rbda.contacts.soft import SoftContacts, SoftContactsParams
9
12
 
10
13
  from .common import VelRepr
11
14
 
@@ -135,17 +138,23 @@ def collidable_point_dynamics(
135
138
  W_p_Ci, W_ṗ_Ci = js.contact.collidable_point_kinematics(model=model, data=data)
136
139
 
137
140
  # Build the soft contact model.
138
- soft_contacts = jaxsim.rbda.SoftContacts(
139
- parameters=data.soft_contacts_params, terrain=model.terrain
140
- )
141
+ match model.contact_model:
142
+ case s if isinstance(s, SoftContacts):
143
+ # Build the contact model.
144
+ soft_contacts = SoftContacts(
145
+ parameters=data.contacts_params, terrain=model.terrain
146
+ )
147
+
148
+ # Compute the 6D force expressed in the inertial frame and applied to each
149
+ # collidable point, and the corresponding material deformation rate.
150
+ # Note that the material deformation rate is always returned in the mixed frame
151
+ # C[W] = (W_p_C, [W]). This is convenient for integration purpose.
152
+ W_f_Ci, (CW_ṁ,) = jax.vmap(soft_contacts.compute_contact_forces)(
153
+ W_p_Ci, W_ṗ_Ci, data.state.contact.tangential_deformation
154
+ )
141
155
 
142
- # Compute the 6D force expressed in the inertial frame and applied to each
143
- # collidable point, and the corresponding material deformation rate.
144
- # Note that the material deformation rate is always returned in the mixed frame
145
- # C[W] = (W_p_C, [W]). This is convenient for integration purpose.
146
- W_f_Ci, CW_ṁ = jax.vmap(soft_contacts.contact_model)(
147
- W_p_Ci, W_ṗ_Ci, data.state.soft_contacts.tangential_deformation
148
- )
156
+ case _:
157
+ raise ValueError("Invalid contact model {}".format(model.contact_model))
149
158
 
150
159
  # Convert the 6D forces to the active representation.
151
160
  f_Ci = jax.vmap(
@@ -213,7 +222,7 @@ def estimate_good_soft_contacts_parameters(
213
222
  number_of_active_collidable_points_steady_state: jtp.IntLike = 1,
214
223
  damping_ratio: jtp.FloatLike = 1.0,
215
224
  max_penetration: jtp.FloatLike | None = None,
216
- ) -> jaxsim.rbda.soft_contacts.SoftContactsParams:
225
+ ) -> SoftContactsParams:
217
226
  """
218
227
  Estimate good soft contacts parameters for the given model.
219
228
 
@@ -237,13 +246,14 @@ def estimate_good_soft_contacts_parameters(
237
246
  The user is encouraged to fine-tune the parameters based on the
238
247
  specific application.
239
248
  """
249
+ from jaxsim.rbda.contacts.soft import SoftContactsParams
240
250
 
241
251
  def estimate_model_height(model: js.model.JaxSimModel) -> jtp.Float:
242
252
  """"""
243
253
 
244
254
  zero_data = js.data.JaxSimModelData.build(
245
255
  model=model,
246
- soft_contacts_params=jaxsim.rbda.soft_contacts.SoftContactsParams(),
256
+ contacts_params=SoftContactsParams(),
247
257
  )
248
258
 
249
259
  W_pz_CoM = js.com.com_position(model=model, data=zero_data)[2]
@@ -262,15 +272,13 @@ def estimate_good_soft_contacts_parameters(
262
272
 
263
273
  nc = number_of_active_collidable_points_steady_state
264
274
 
265
- sc_parameters = (
266
- jaxsim.rbda.soft_contacts.SoftContactsParams.build_default_from_jaxsim_model(
267
- model=model,
268
- standard_gravity=standard_gravity,
269
- static_friction_coefficient=static_friction_coefficient,
270
- max_penetration=max_δ,
271
- number_of_active_collidable_points_steady_state=nc,
272
- damping_ratio=damping_ratio,
273
- )
275
+ sc_parameters = SoftContactsParams.build_default_from_jaxsim_model(
276
+ model=model,
277
+ standard_gravity=standard_gravity,
278
+ static_friction_coefficient=static_friction_coefficient,
279
+ max_penetration=max_δ,
280
+ number_of_active_collidable_points_steady_state=nc,
281
+ damping_ratio=damping_ratio,
274
282
  )
275
283
 
276
284
  return sc_parameters
@@ -14,6 +14,7 @@ import jaxsim.api as js
14
14
  import jaxsim.rbda
15
15
  import jaxsim.typing as jtp
16
16
  from jaxsim.math import Quaternion
17
+ from jaxsim.rbda.contacts.soft import SoftContacts
17
18
  from jaxsim.utils import Mutability
18
19
  from jaxsim.utils.tracing import not_tracing
19
20
 
@@ -37,7 +38,7 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
37
38
 
38
39
  gravity: jtp.Array
39
40
 
40
- soft_contacts_params: jaxsim.rbda.SoftContactsParams = dataclasses.field(repr=False)
41
+ contacts_params: jaxsim.rbda.ContactsParams = dataclasses.field(repr=False)
41
42
 
42
43
  time_ns: jtp.Int = dataclasses.field(
43
44
  default_factory=lambda: jnp.array(0, dtype=jnp.uint64)
@@ -51,8 +52,8 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
51
52
  (
52
53
  hash(self.state),
53
54
  HashedNumpyArray.hash_of_array(self.gravity),
54
- hash(self.soft_contacts_params),
55
55
  HashedNumpyArray.hash_of_array(self.time_ns),
56
+ hash(self.contacts_params),
56
57
  )
57
58
  )
58
59
 
@@ -112,8 +113,8 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
112
113
  base_angular_velocity: jtp.Vector | None = None,
113
114
  joint_velocities: jtp.Vector | None = None,
114
115
  standard_gravity: jtp.FloatLike = jaxsim.math.StandardGravity,
115
- soft_contacts_state: js.ode_data.SoftContactsState | None = None,
116
- soft_contacts_params: jaxsim.rbda.SoftContactsParams | None = None,
116
+ contact: jaxsim.rbda.ContactsState | None = None,
117
+ contacts_params: jaxsim.rbda.ContactsParams | None = None,
117
118
  velocity_representation: VelRepr = VelRepr.Inertial,
118
119
  time: jtp.FloatLike | None = None,
119
120
  ) -> JaxSimModelData:
@@ -131,8 +132,8 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
131
132
  The base angular velocity in the selected representation.
132
133
  joint_velocities: The joint velocities.
133
134
  standard_gravity: The standard gravity constant.
134
- soft_contacts_state: The state of the soft contacts.
135
- soft_contacts_params: The parameters of the soft contacts.
135
+ contact: The state of the soft contacts.
136
+ contacts_params: The parameters of the soft contacts.
136
137
  velocity_representation: The velocity representation to use.
137
138
  time: The time at which the state is created.
138
139
 
@@ -178,13 +179,16 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
178
179
  else jnp.array(0, dtype=jnp.uint64)
179
180
  )
180
181
 
181
- soft_contacts_params = (
182
- soft_contacts_params
183
- if soft_contacts_params is not None
184
- else js.contact.estimate_good_soft_contacts_parameters(
185
- model=model, standard_gravity=standard_gravity
182
+ if isinstance(model.contact_model, SoftContacts):
183
+ contacts_params = (
184
+ contacts_params
185
+ if contacts_params is not None
186
+ else js.contact.estimate_good_soft_contacts_parameters(
187
+ model=model, standard_gravity=standard_gravity
188
+ )
186
189
  )
187
- )
190
+ else:
191
+ contacts_params = model.contact_model.parameters
188
192
 
189
193
  W_H_B = jaxlie.SE3.from_rotation_and_translation(
190
194
  translation=base_position,
@@ -209,8 +213,8 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
209
213
  base_angular_velocity=v_WB[3:6].astype(float),
210
214
  joint_velocities=joint_velocities.astype(float),
211
215
  tangential_deformation=(
212
- soft_contacts_state.tangential_deformation
213
- if soft_contacts_state is not None
216
+ contact.tangential_deformation
217
+ if contact is not None and isinstance(model.contact_model, SoftContacts)
214
218
  else None
215
219
  ),
216
220
  )
@@ -222,7 +226,7 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
222
226
  time_ns=time_ns,
223
227
  state=ode_state,
224
228
  gravity=gravity.astype(float),
225
- soft_contacts_params=soft_contacts_params,
229
+ contacts_params=contacts_params,
226
230
  velocity_representation=velocity_representation,
227
231
  )
228
232
 
@@ -652,7 +656,10 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
652
656
 
653
657
  return self.reset_base_velocity(
654
658
  base_velocity=jnp.hstack(
655
- [linear_velocity.squeeze(), self.base_velocity()[3:6]]
659
+ [
660
+ linear_velocity.squeeze(),
661
+ self.base_velocity()[3:6],
662
+ ]
656
663
  ),
657
664
  velocity_representation=velocity_representation,
658
665
  )
@@ -680,7 +687,10 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
680
687
 
681
688
  return self.reset_base_velocity(
682
689
  base_velocity=jnp.hstack(
683
- [self.base_velocity()[0:3], angular_velocity.squeeze()]
690
+ [
691
+ self.base_velocity()[0:3],
692
+ angular_velocity.squeeze(),
693
+ ]
684
694
  ),
685
695
  velocity_representation=velocity_representation,
686
696
  )
@@ -4,11 +4,11 @@ from typing import Sequence
4
4
  import jax
5
5
  import jax.numpy as jnp
6
6
  import jaxlie
7
- import numpy as np
8
7
 
9
8
  import jaxsim.api as js
10
9
  import jaxsim.math
11
10
  import jaxsim.typing as jtp
11
+ from jaxsim import exceptions
12
12
 
13
13
  from .common import VelRepr
14
14
 
@@ -17,22 +17,32 @@ from .common import VelRepr
17
17
  # =======================
18
18
 
19
19
 
20
+ @jax.jit
20
21
  def idx_of_parent_link(
21
- model: js.model.JaxSimModel, *, frame_idx: jtp.IntLike
22
+ model: js.model.JaxSimModel, *, frame_index: jtp.IntLike
22
23
  ) -> jtp.Int:
23
24
  """
24
25
  Get the index of the link to which the frame is rigidly attached.
25
26
 
26
27
  Args:
27
28
  model: The model to consider.
28
- frame_idx: The index of the frame.
29
+ frame_index: The index of the frame.
29
30
 
30
31
  Returns:
31
32
  The index of the frame's parent link.
32
33
  """
33
34
 
35
+ n_l = model.number_of_links()
36
+ n_f = len(model.frame_names())
37
+
38
+ exceptions.raise_value_error_if(
39
+ condition=jnp.array([frame_index < n_l, frame_index >= n_l + n_f]).any(),
40
+ msg="Invalid frame index '{idx}'",
41
+ idx=frame_index,
42
+ )
43
+
34
44
  return model.kin_dyn_parameters.frame_parameters.body[
35
- frame_idx - model.number_of_links()
45
+ frame_index - model.number_of_links()
36
46
  ]
37
47
 
38
48
 
@@ -49,19 +59,17 @@ def name_to_idx(model: js.model.JaxSimModel, *, frame_name: str) -> jtp.Int:
49
59
  The index of the frame.
50
60
  """
51
61
 
52
- if frame_name in model.kin_dyn_parameters.frame_parameters.name:
53
- return (
54
- jnp.array(
55
- np.argwhere(
56
- np.array(model.kin_dyn_parameters.frame_parameters.name)
57
- == frame_name
58
- )
59
- )
60
- .squeeze()
61
- .astype(int)
62
- ) + model.number_of_links()
62
+ if frame_name not in model.kin_dyn_parameters.frame_parameters.name:
63
+ raise ValueError(f"Frame '{frame_name}' not found in the model.")
63
64
 
64
- return jnp.array(-1).astype(int)
65
+ return (
66
+ jnp.array(
67
+ model.number_of_links()
68
+ + model.kin_dyn_parameters.frame_parameters.name.index(frame_name)
69
+ )
70
+ .astype(int)
71
+ .squeeze()
72
+ )
65
73
 
66
74
 
67
75
  def idx_to_name(model: js.model.JaxSimModel, *, frame_index: jtp.IntLike) -> str:
@@ -76,6 +84,15 @@ def idx_to_name(model: js.model.JaxSimModel, *, frame_index: jtp.IntLike) -> str
76
84
  The name of the frame.
77
85
  """
78
86
 
87
+ n_l = model.number_of_links()
88
+ n_f = len(model.frame_names())
89
+
90
+ exceptions.raise_value_error_if(
91
+ condition=jnp.array([frame_index < n_l, frame_index >= n_l + n_f]).any(),
92
+ msg="Invalid frame index '{idx}'",
93
+ idx=frame_index,
94
+ )
95
+
79
96
  return model.kin_dyn_parameters.frame_parameters.name[
80
97
  frame_index - model.number_of_links()
81
98
  ]
@@ -142,8 +159,17 @@ def transform(
142
159
  The 4x4 matrix representing the transform.
143
160
  """
144
161
 
162
+ n_l = model.number_of_links()
163
+ n_f = len(model.frame_names())
164
+
165
+ exceptions.raise_value_error_if(
166
+ condition=jnp.array([frame_index < n_l, frame_index >= n_l + n_f]).any(),
167
+ msg="Invalid frame index '{idx}'",
168
+ idx=frame_index,
169
+ )
170
+
145
171
  # Compute the necessary transforms.
146
- L = idx_of_parent_link(model=model, frame_idx=frame_index)
172
+ L = idx_of_parent_link(model=model, frame_index=frame_index)
147
173
  W_H_L = js.link.transform(model=model, data=data, link_index=L)
148
174
 
149
175
  # Get the static frame pose wrt the parent link.
@@ -181,12 +207,21 @@ def jacobian(
181
207
  velocity representation.
182
208
  """
183
209
 
210
+ n_l = model.number_of_links()
211
+ n_f = len(model.frame_names())
212
+
213
+ exceptions.raise_value_error_if(
214
+ condition=jnp.array([frame_index < n_l, frame_index >= n_l + n_f]).any(),
215
+ msg="Invalid frame index '{idx}'",
216
+ idx=frame_index,
217
+ )
218
+
184
219
  output_vel_repr = (
185
220
  output_vel_repr if output_vel_repr is not None else data.velocity_representation
186
221
  )
187
222
 
188
223
  # Get the index of the parent link.
189
- L = idx_of_parent_link(model=model, frame_idx=frame_index)
224
+ L = idx_of_parent_link(model=model, frame_index=frame_index)
190
225
 
191
226
  # Compute the Jacobian of the parent link using body-fixed output representation.
192
227
  L_J_WL = js.link.jacobian(
@@ -6,6 +6,7 @@ import jax.numpy as jnp
6
6
 
7
7
  import jaxsim.api as js
8
8
  import jaxsim.typing as jtp
9
+ from jaxsim import exceptions
9
10
 
10
11
  # =======================
11
12
  # Index-related functions
@@ -25,14 +26,18 @@ def name_to_idx(model: js.model.JaxSimModel, *, joint_name: str) -> jtp.Int:
25
26
  The index of the joint.
26
27
  """
27
28
 
28
- if joint_name in model.kin_dyn_parameters.joint_model.joint_names:
29
- # Note: the index of the joint for RBDAs starts from 1, but
30
- # the index for accessing the right element starts from 0.
31
- # Therefore, there is a -1.
32
- return jnp.array(
29
+ if joint_name not in model.joint_names():
30
+ raise ValueError(f"Joint '{joint_name}' not found in the model.")
31
+
32
+ # Note: the index of the joint for RBDAs starts from 1, but the index for
33
+ # accessing the right element starts from 0. Therefore, there is a -1.
34
+ return (
35
+ jnp.array(
33
36
  model.kin_dyn_parameters.joint_model.joint_names.index(joint_name) - 1
34
- ).squeeze()
35
- return jnp.array(-1).astype(int)
37
+ )
38
+ .astype(int)
39
+ .squeeze()
40
+ )
36
41
 
37
42
 
38
43
  def idx_to_name(model: js.model.JaxSimModel, *, joint_index: jtp.IntLike) -> str:
@@ -47,6 +52,14 @@ def idx_to_name(model: js.model.JaxSimModel, *, joint_index: jtp.IntLike) -> str
47
52
  The name of the joint.
48
53
  """
49
54
 
55
+ exceptions.raise_value_error_if(
56
+ condition=jnp.array(
57
+ [joint_index < 0, joint_index >= model.number_of_joints()]
58
+ ).any(),
59
+ msg="Invalid joint index '{idx}'",
60
+ idx=joint_index,
61
+ )
62
+
50
63
  return model.kin_dyn_parameters.joint_model.joint_names[joint_index + 1]
51
64
 
52
65
 
@@ -112,6 +125,14 @@ def position_limit(
112
125
  if model.number_of_joints() <= 1:
113
126
  return jnp.empty(0).astype(float), jnp.empty(0).astype(float)
114
127
 
128
+ exceptions.raise_value_error_if(
129
+ condition=jnp.array(
130
+ [joint_index < 0, joint_index >= model.number_of_joints()]
131
+ ).any(),
132
+ msg="Invalid joint index '{idx}'",
133
+ idx=joint_index,
134
+ )
135
+
115
136
  s_min = model.kin_dyn_parameters.joint_parameters.position_limits_min[joint_index]
116
137
  s_max = model.kin_dyn_parameters.joint_parameters.position_limits_max[joint_index]
117
138
 
@@ -5,11 +5,11 @@ import jax
5
5
  import jax.numpy as jnp
6
6
  import jax.scipy.linalg
7
7
  import jaxlie
8
- import numpy as np
9
8
 
10
9
  import jaxsim.api as js
11
10
  import jaxsim.rbda
12
11
  import jaxsim.typing as jtp
12
+ from jaxsim import exceptions
13
13
 
14
14
  from .common import VelRepr
15
15
 
@@ -31,15 +31,14 @@ def name_to_idx(model: js.model.JaxSimModel, *, link_name: str) -> jtp.Int:
31
31
  The index of the link.
32
32
  """
33
33
 
34
- if link_name in model.kin_dyn_parameters.link_names:
35
- return (
36
- jnp.array(
37
- np.argwhere(np.array(model.kin_dyn_parameters.link_names) == link_name)
38
- )
39
- .squeeze()
40
- .astype(int)
41
- )
42
- return jnp.array(-1).astype(int)
34
+ if link_name not in model.link_names():
35
+ raise ValueError(f"Link '{link_name}' not found in the model.")
36
+
37
+ return (
38
+ jnp.array(model.kin_dyn_parameters.link_names.index(link_name))
39
+ .astype(int)
40
+ .squeeze()
41
+ )
43
42
 
44
43
 
45
44
  def idx_to_name(model: js.model.JaxSimModel, *, link_index: jtp.IntLike) -> str:
@@ -54,6 +53,14 @@ def idx_to_name(model: js.model.JaxSimModel, *, link_index: jtp.IntLike) -> str:
54
53
  The name of the link.
55
54
  """
56
55
 
56
+ exceptions.raise_value_error_if(
57
+ condition=jnp.array(
58
+ [link_index < 0, link_index >= model.number_of_links()]
59
+ ).any(),
60
+ msg="Invalid link index '{idx}'",
61
+ idx=link_index,
62
+ )
63
+
57
64
  return model.kin_dyn_parameters.link_names[link_index]
58
65
 
59
66
 
@@ -112,6 +119,14 @@ def mass(model: js.model.JaxSimModel, *, link_index: jtp.IntLike) -> jtp.Float:
112
119
  The mass of the link.
113
120
  """
114
121
 
122
+ exceptions.raise_value_error_if(
123
+ condition=jnp.array(
124
+ [link_index < 0, link_index >= model.number_of_links()]
125
+ ).any(),
126
+ msg="Invalid link index '{idx}'",
127
+ idx=link_index,
128
+ )
129
+
115
130
  return model.kin_dyn_parameters.link_parameters.mass[link_index].astype(float)
116
131
 
117
132
 
@@ -131,6 +146,14 @@ def spatial_inertia(
131
146
  the link frame (body-fixed representation).
132
147
  """
133
148
 
149
+ exceptions.raise_value_error_if(
150
+ condition=jnp.array(
151
+ [link_index < 0, link_index >= model.number_of_links()]
152
+ ).any(),
153
+ msg="Invalid link index '{idx}'",
154
+ idx=link_index,
155
+ )
156
+
134
157
  link_parameters = jax.tree_util.tree_map(
135
158
  lambda l: l[link_index], model.kin_dyn_parameters.link_parameters
136
159
  )
@@ -157,6 +180,14 @@ def transform(
157
180
  The 4x4 matrix representing the transform.
158
181
  """
159
182
 
183
+ exceptions.raise_value_error_if(
184
+ condition=jnp.array(
185
+ [link_index < 0, link_index >= model.number_of_links()]
186
+ ).any(),
187
+ msg="Invalid link index '{idx}'",
188
+ idx=link_index,
189
+ )
190
+
160
191
  return js.model.forward_kinematics(model=model, data=data)[link_index]
161
192
 
162
193
 
@@ -230,6 +261,14 @@ def jacobian(
230
261
  velocity representation.
231
262
  """
232
263
 
264
+ exceptions.raise_value_error_if(
265
+ condition=jnp.array(
266
+ [link_index < 0, link_index >= model.number_of_links()]
267
+ ).any(),
268
+ msg="Invalid link index '{idx}'",
269
+ idx=link_index,
270
+ )
271
+
233
272
  output_vel_repr = (
234
273
  output_vel_repr if output_vel_repr is not None else data.velocity_representation
235
274
  )
@@ -318,6 +357,14 @@ def velocity(
318
357
  The 6D velocity of the link in the specified velocity representation.
319
358
  """
320
359
 
360
+ exceptions.raise_value_error_if(
361
+ condition=jnp.array(
362
+ [link_index < 0, link_index >= model.number_of_links()]
363
+ ).any(),
364
+ msg="Invalid link index '{idx}'",
365
+ idx=link_index,
366
+ )
367
+
321
368
  output_vel_repr = (
322
369
  output_vel_repr if output_vel_repr is not None else data.velocity_representation
323
370
  )
@@ -364,6 +411,14 @@ def jacobian_derivative(
364
411
  velocity representation.
365
412
  """
366
413
 
414
+ exceptions.raise_value_error_if(
415
+ condition=jnp.array(
416
+ [link_index < 0, link_index >= model.number_of_links()]
417
+ ).any(),
418
+ msg="Invalid link index '{idx}'",
419
+ idx=link_index,
420
+ )
421
+
367
422
  output_vel_repr = (
368
423
  output_vel_repr if output_vel_repr is not None else data.velocity_representation
369
424
  )
@@ -538,6 +593,14 @@ def bias_acceleration(
538
593
  The 6D bias acceleration of the link.
539
594
  """
540
595
 
596
+ exceptions.raise_value_error_if(
597
+ condition=jnp.array(
598
+ [link_index < 0, link_index >= model.number_of_links()]
599
+ ).any(),
600
+ msg="Invalid link index '{idx}'",
601
+ idx=link_index,
602
+ )
603
+
541
604
  # Compute the bias acceleration of all links in the active representation.
542
605
  O_v̇_WL = js.model.link_bias_accelerations(model=model, data=data)[link_index]
543
606
  return O_v̇_WL