jaxsim 0.4.3.dev161__tar.gz → 0.4.3.dev177__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (125)
  1. {jaxsim-0.4.3.dev161 → jaxsim-0.4.3.dev177}/PKG-INFO +1 -1
  2. {jaxsim-0.4.3.dev161 → jaxsim-0.4.3.dev177}/src/jaxsim/_version.py +2 -2
  3. {jaxsim-0.4.3.dev161 → jaxsim-0.4.3.dev177}/src/jaxsim/api/contact.py +67 -16
  4. {jaxsim-0.4.3.dev161 → jaxsim-0.4.3.dev177}/src/jaxsim/api/data.py +5 -26
  5. {jaxsim-0.4.3.dev161 → jaxsim-0.4.3.dev177}/src/jaxsim/api/kin_dyn_parameters.py +14 -1
  6. {jaxsim-0.4.3.dev161 → jaxsim-0.4.3.dev177}/src/jaxsim/api/model.py +38 -29
  7. {jaxsim-0.4.3.dev161 → jaxsim-0.4.3.dev177}/src/jaxsim/api/ode.py +11 -5
  8. {jaxsim-0.4.3.dev161 → jaxsim-0.4.3.dev177}/src/jaxsim/rbda/contacts/__init__.py +2 -1
  9. jaxsim-0.4.3.dev177/src/jaxsim/rbda/contacts/visco_elastic.py +1055 -0
  10. {jaxsim-0.4.3.dev161 → jaxsim-0.4.3.dev177}/src/jaxsim.egg-info/PKG-INFO +1 -1
  11. {jaxsim-0.4.3.dev161 → jaxsim-0.4.3.dev177}/src/jaxsim.egg-info/SOURCES.txt +1 -0
  12. {jaxsim-0.4.3.dev161 → jaxsim-0.4.3.dev177}/tests/test_simulations.py +4 -10
  13. {jaxsim-0.4.3.dev161 → jaxsim-0.4.3.dev177}/.devcontainer/Dockerfile +0 -0
  14. {jaxsim-0.4.3.dev161 → jaxsim-0.4.3.dev177}/.devcontainer/devcontainer.json +0 -0
  15. {jaxsim-0.4.3.dev161 → jaxsim-0.4.3.dev177}/.gitattributes +0 -0
  16. {jaxsim-0.4.3.dev161 → jaxsim-0.4.3.dev177}/.github/CODEOWNERS +0 -0
  17. {jaxsim-0.4.3.dev161 → jaxsim-0.4.3.dev177}/.github/workflows/ci_cd.yml +0 -0
  18. {jaxsim-0.4.3.dev161 → jaxsim-0.4.3.dev177}/.github/workflows/read_the_docs.yml +0 -0
  19. {jaxsim-0.4.3.dev161 → jaxsim-0.4.3.dev177}/.github/workflows/update_pixi_lockfile.yml +0 -0
  20. {jaxsim-0.4.3.dev161 → jaxsim-0.4.3.dev177}/.gitignore +0 -0
  21. {jaxsim-0.4.3.dev161 → jaxsim-0.4.3.dev177}/.pre-commit-config.yaml +0 -0
  22. {jaxsim-0.4.3.dev161 → jaxsim-0.4.3.dev177}/.readthedocs.yaml +0 -0
  23. {jaxsim-0.4.3.dev161 → jaxsim-0.4.3.dev177}/CONTRIBUTING.md +0 -0
  24. {jaxsim-0.4.3.dev161 → jaxsim-0.4.3.dev177}/LICENSE +0 -0
  25. {jaxsim-0.4.3.dev161 → jaxsim-0.4.3.dev177}/README.md +0 -0
  26. {jaxsim-0.4.3.dev161 → jaxsim-0.4.3.dev177}/docs/Makefile +0 -0
  27. {jaxsim-0.4.3.dev161 → jaxsim-0.4.3.dev177}/docs/conf.py +0 -0
  28. {jaxsim-0.4.3.dev161 → jaxsim-0.4.3.dev177}/docs/examples.rst +0 -0
  29. {jaxsim-0.4.3.dev161 → jaxsim-0.4.3.dev177}/docs/guide/install.rst +0 -0
  30. {jaxsim-0.4.3.dev161 → jaxsim-0.4.3.dev177}/docs/index.rst +0 -0
  31. {jaxsim-0.4.3.dev161 → jaxsim-0.4.3.dev177}/docs/make.bat +0 -0
  32. {jaxsim-0.4.3.dev161 → jaxsim-0.4.3.dev177}/docs/modules/api.rst +0 -0
  33. {jaxsim-0.4.3.dev161 → jaxsim-0.4.3.dev177}/docs/modules/integrators.rst +0 -0
  34. {jaxsim-0.4.3.dev161 → jaxsim-0.4.3.dev177}/docs/modules/math.rst +0 -0
  35. {jaxsim-0.4.3.dev161 → jaxsim-0.4.3.dev177}/docs/modules/mujoco.rst +0 -0
  36. {jaxsim-0.4.3.dev161 → jaxsim-0.4.3.dev177}/docs/modules/parsers.rst +0 -0
  37. {jaxsim-0.4.3.dev161 → jaxsim-0.4.3.dev177}/docs/modules/rbda.rst +0 -0
  38. {jaxsim-0.4.3.dev161 → jaxsim-0.4.3.dev177}/docs/modules/typing.rst +0 -0
  39. {jaxsim-0.4.3.dev161 → jaxsim-0.4.3.dev177}/docs/modules/utils.rst +0 -0
  40. {jaxsim-0.4.3.dev161 → jaxsim-0.4.3.dev177}/environment.yml +0 -0
  41. {jaxsim-0.4.3.dev161 → jaxsim-0.4.3.dev177}/examples/.gitattributes +0 -0
  42. {jaxsim-0.4.3.dev161 → jaxsim-0.4.3.dev177}/examples/.gitignore +0 -0
  43. {jaxsim-0.4.3.dev161 → jaxsim-0.4.3.dev177}/examples/PD_controller.ipynb +0 -0
  44. {jaxsim-0.4.3.dev161 → jaxsim-0.4.3.dev177}/examples/Parallel_computing.ipynb +0 -0
  45. {jaxsim-0.4.3.dev161 → jaxsim-0.4.3.dev177}/examples/README.md +0 -0
  46. {jaxsim-0.4.3.dev161 → jaxsim-0.4.3.dev177}/examples/assets/cartpole.urdf +0 -0
  47. {jaxsim-0.4.3.dev161 → jaxsim-0.4.3.dev177}/pixi.lock +0 -0
  48. {jaxsim-0.4.3.dev161 → jaxsim-0.4.3.dev177}/pyproject.toml +0 -0
  49. {jaxsim-0.4.3.dev161 → jaxsim-0.4.3.dev177}/setup.cfg +0 -0
  50. {jaxsim-0.4.3.dev161 → jaxsim-0.4.3.dev177}/setup.py +0 -0
  51. {jaxsim-0.4.3.dev161 → jaxsim-0.4.3.dev177}/src/jaxsim/__init__.py +0 -0
  52. {jaxsim-0.4.3.dev161 → jaxsim-0.4.3.dev177}/src/jaxsim/api/__init__.py +0 -0
  53. {jaxsim-0.4.3.dev161 → jaxsim-0.4.3.dev177}/src/jaxsim/api/com.py +0 -0
  54. {jaxsim-0.4.3.dev161 → jaxsim-0.4.3.dev177}/src/jaxsim/api/common.py +0 -0
  55. {jaxsim-0.4.3.dev161 → jaxsim-0.4.3.dev177}/src/jaxsim/api/frame.py +0 -0
  56. {jaxsim-0.4.3.dev161 → jaxsim-0.4.3.dev177}/src/jaxsim/api/joint.py +0 -0
  57. {jaxsim-0.4.3.dev161 → jaxsim-0.4.3.dev177}/src/jaxsim/api/link.py +0 -0
  58. {jaxsim-0.4.3.dev161 → jaxsim-0.4.3.dev177}/src/jaxsim/api/ode_data.py +0 -0
  59. {jaxsim-0.4.3.dev161 → jaxsim-0.4.3.dev177}/src/jaxsim/api/references.py +0 -0
  60. {jaxsim-0.4.3.dev161 → jaxsim-0.4.3.dev177}/src/jaxsim/exceptions.py +0 -0
  61. {jaxsim-0.4.3.dev161 → jaxsim-0.4.3.dev177}/src/jaxsim/integrators/__init__.py +0 -0
  62. {jaxsim-0.4.3.dev161 → jaxsim-0.4.3.dev177}/src/jaxsim/integrators/common.py +0 -0
  63. {jaxsim-0.4.3.dev161 → jaxsim-0.4.3.dev177}/src/jaxsim/integrators/fixed_step.py +0 -0
  64. {jaxsim-0.4.3.dev161 → jaxsim-0.4.3.dev177}/src/jaxsim/integrators/variable_step.py +0 -0
  65. {jaxsim-0.4.3.dev161 → jaxsim-0.4.3.dev177}/src/jaxsim/logging.py +0 -0
  66. {jaxsim-0.4.3.dev161 → jaxsim-0.4.3.dev177}/src/jaxsim/math/__init__.py +0 -0
  67. {jaxsim-0.4.3.dev161 → jaxsim-0.4.3.dev177}/src/jaxsim/math/adjoint.py +0 -0
  68. {jaxsim-0.4.3.dev161 → jaxsim-0.4.3.dev177}/src/jaxsim/math/cross.py +0 -0
  69. {jaxsim-0.4.3.dev161 → jaxsim-0.4.3.dev177}/src/jaxsim/math/inertia.py +0 -0
  70. {jaxsim-0.4.3.dev161 → jaxsim-0.4.3.dev177}/src/jaxsim/math/joint_model.py +0 -0
  71. {jaxsim-0.4.3.dev161 → jaxsim-0.4.3.dev177}/src/jaxsim/math/quaternion.py +0 -0
  72. {jaxsim-0.4.3.dev161 → jaxsim-0.4.3.dev177}/src/jaxsim/math/rotation.py +0 -0
  73. {jaxsim-0.4.3.dev161 → jaxsim-0.4.3.dev177}/src/jaxsim/math/skew.py +0 -0
  74. {jaxsim-0.4.3.dev161 → jaxsim-0.4.3.dev177}/src/jaxsim/math/transform.py +0 -0
  75. {jaxsim-0.4.3.dev161 → jaxsim-0.4.3.dev177}/src/jaxsim/mujoco/__init__.py +0 -0
  76. {jaxsim-0.4.3.dev161 → jaxsim-0.4.3.dev177}/src/jaxsim/mujoco/__main__.py +0 -0
  77. {jaxsim-0.4.3.dev161 → jaxsim-0.4.3.dev177}/src/jaxsim/mujoco/loaders.py +0 -0
  78. {jaxsim-0.4.3.dev161 → jaxsim-0.4.3.dev177}/src/jaxsim/mujoco/model.py +0 -0
  79. {jaxsim-0.4.3.dev161 → jaxsim-0.4.3.dev177}/src/jaxsim/mujoco/visualizer.py +0 -0
  80. {jaxsim-0.4.3.dev161 → jaxsim-0.4.3.dev177}/src/jaxsim/parsers/__init__.py +0 -0
  81. {jaxsim-0.4.3.dev161 → jaxsim-0.4.3.dev177}/src/jaxsim/parsers/descriptions/__init__.py +0 -0
  82. {jaxsim-0.4.3.dev161 → jaxsim-0.4.3.dev177}/src/jaxsim/parsers/descriptions/collision.py +0 -0
  83. {jaxsim-0.4.3.dev161 → jaxsim-0.4.3.dev177}/src/jaxsim/parsers/descriptions/joint.py +0 -0
  84. {jaxsim-0.4.3.dev161 → jaxsim-0.4.3.dev177}/src/jaxsim/parsers/descriptions/link.py +0 -0
  85. {jaxsim-0.4.3.dev161 → jaxsim-0.4.3.dev177}/src/jaxsim/parsers/descriptions/model.py +0 -0
  86. {jaxsim-0.4.3.dev161 → jaxsim-0.4.3.dev177}/src/jaxsim/parsers/kinematic_graph.py +0 -0
  87. {jaxsim-0.4.3.dev161 → jaxsim-0.4.3.dev177}/src/jaxsim/parsers/rod/__init__.py +0 -0
  88. {jaxsim-0.4.3.dev161 → jaxsim-0.4.3.dev177}/src/jaxsim/parsers/rod/parser.py +0 -0
  89. {jaxsim-0.4.3.dev161 → jaxsim-0.4.3.dev177}/src/jaxsim/parsers/rod/utils.py +0 -0
  90. {jaxsim-0.4.3.dev161 → jaxsim-0.4.3.dev177}/src/jaxsim/rbda/__init__.py +0 -0
  91. {jaxsim-0.4.3.dev161 → jaxsim-0.4.3.dev177}/src/jaxsim/rbda/aba.py +0 -0
  92. {jaxsim-0.4.3.dev161 → jaxsim-0.4.3.dev177}/src/jaxsim/rbda/collidable_points.py +0 -0
  93. {jaxsim-0.4.3.dev161 → jaxsim-0.4.3.dev177}/src/jaxsim/rbda/contacts/common.py +0 -0
  94. {jaxsim-0.4.3.dev161 → jaxsim-0.4.3.dev177}/src/jaxsim/rbda/contacts/relaxed_rigid.py +0 -0
  95. {jaxsim-0.4.3.dev161 → jaxsim-0.4.3.dev177}/src/jaxsim/rbda/contacts/rigid.py +0 -0
  96. {jaxsim-0.4.3.dev161 → jaxsim-0.4.3.dev177}/src/jaxsim/rbda/contacts/soft.py +0 -0
  97. {jaxsim-0.4.3.dev161 → jaxsim-0.4.3.dev177}/src/jaxsim/rbda/crba.py +0 -0
  98. {jaxsim-0.4.3.dev161 → jaxsim-0.4.3.dev177}/src/jaxsim/rbda/forward_kinematics.py +0 -0
  99. {jaxsim-0.4.3.dev161 → jaxsim-0.4.3.dev177}/src/jaxsim/rbda/jacobian.py +0 -0
  100. {jaxsim-0.4.3.dev161 → jaxsim-0.4.3.dev177}/src/jaxsim/rbda/rnea.py +0 -0
  101. {jaxsim-0.4.3.dev161 → jaxsim-0.4.3.dev177}/src/jaxsim/rbda/utils.py +0 -0
  102. {jaxsim-0.4.3.dev161 → jaxsim-0.4.3.dev177}/src/jaxsim/terrain/__init__.py +0 -0
  103. {jaxsim-0.4.3.dev161 → jaxsim-0.4.3.dev177}/src/jaxsim/terrain/terrain.py +0 -0
  104. {jaxsim-0.4.3.dev161 → jaxsim-0.4.3.dev177}/src/jaxsim/typing.py +0 -0
  105. {jaxsim-0.4.3.dev161 → jaxsim-0.4.3.dev177}/src/jaxsim/utils/__init__.py +0 -0
  106. {jaxsim-0.4.3.dev161 → jaxsim-0.4.3.dev177}/src/jaxsim/utils/jaxsim_dataclass.py +0 -0
  107. {jaxsim-0.4.3.dev161 → jaxsim-0.4.3.dev177}/src/jaxsim/utils/tracing.py +0 -0
  108. {jaxsim-0.4.3.dev161 → jaxsim-0.4.3.dev177}/src/jaxsim/utils/wrappers.py +0 -0
  109. {jaxsim-0.4.3.dev161 → jaxsim-0.4.3.dev177}/src/jaxsim.egg-info/dependency_links.txt +0 -0
  110. {jaxsim-0.4.3.dev161 → jaxsim-0.4.3.dev177}/src/jaxsim.egg-info/requires.txt +0 -0
  111. {jaxsim-0.4.3.dev161 → jaxsim-0.4.3.dev177}/src/jaxsim.egg-info/top_level.txt +0 -0
  112. {jaxsim-0.4.3.dev161 → jaxsim-0.4.3.dev177}/tests/__init__.py +0 -0
  113. {jaxsim-0.4.3.dev161 → jaxsim-0.4.3.dev177}/tests/conftest.py +0 -0
  114. {jaxsim-0.4.3.dev161 → jaxsim-0.4.3.dev177}/tests/test_api_com.py +0 -0
  115. {jaxsim-0.4.3.dev161 → jaxsim-0.4.3.dev177}/tests/test_api_contact.py +0 -0
  116. {jaxsim-0.4.3.dev161 → jaxsim-0.4.3.dev177}/tests/test_api_data.py +0 -0
  117. {jaxsim-0.4.3.dev161 → jaxsim-0.4.3.dev177}/tests/test_api_frame.py +0 -0
  118. {jaxsim-0.4.3.dev161 → jaxsim-0.4.3.dev177}/tests/test_api_joint.py +0 -0
  119. {jaxsim-0.4.3.dev161 → jaxsim-0.4.3.dev177}/tests/test_api_link.py +0 -0
  120. {jaxsim-0.4.3.dev161 → jaxsim-0.4.3.dev177}/tests/test_api_model.py +0 -0
  121. {jaxsim-0.4.3.dev161 → jaxsim-0.4.3.dev177}/tests/test_automatic_differentiation.py +0 -0
  122. {jaxsim-0.4.3.dev161 → jaxsim-0.4.3.dev177}/tests/test_contact.py +0 -0
  123. {jaxsim-0.4.3.dev161 → jaxsim-0.4.3.dev177}/tests/test_exceptions.py +0 -0
  124. {jaxsim-0.4.3.dev161 → jaxsim-0.4.3.dev177}/tests/test_pytree.py +0 -0
  125. {jaxsim-0.4.3.dev161 → jaxsim-0.4.3.dev177}/tests/utils_idyntree.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: jaxsim
3
- Version: 0.4.3.dev161
3
+ Version: 0.4.3.dev177
4
4
  Summary: A differentiable physics engine and multibody dynamics library for control and robot learning.
5
5
  Author-email: Diego Ferigo <dgferigo@gmail.com>
6
6
  Maintainer-email: Diego Ferigo <dgferigo@gmail.com>, Filippo Luca Ferretti <filippo.ferretti@iit.it>
@@ -12,5 +12,5 @@ __version__: str
12
12
  __version_tuple__: VERSION_TUPLE
13
13
  version_tuple: VERSION_TUPLE
14
14
 
15
- __version__ = version = '0.4.3.dev161'
16
- __version_tuple__ = version_tuple = (0, 4, 3, 'dev161')
15
+ __version__ = version = '0.4.3.dev177'
16
+ __version_tuple__ = version_tuple = (0, 4, 3, 'dev177')
@@ -6,9 +6,12 @@ import jax
6
6
  import jax.numpy as jnp
7
7
 
8
8
  import jaxsim.api as js
9
+ import jaxsim.exceptions
9
10
  import jaxsim.terrain
10
11
  import jaxsim.typing as jtp
12
+ from jaxsim import logging
11
13
  from jaxsim.math import Adjoint, Cross, Transform
14
+ from jaxsim.rbda import contacts
12
15
 
13
16
  from .common import VelRepr
14
17
 
@@ -156,14 +159,11 @@ def collidable_point_dynamics(
156
159
  Instead, the 6D forces are returned in the active representation.
157
160
  """
158
161
 
159
- # Import privately the contacts classes.
160
- from jaxsim.rbda.contacts import RelaxedRigidContacts, RigidContacts, SoftContacts
161
-
162
162
  # Build the soft contact model.
163
163
  match model.contact_model:
164
164
 
165
- case SoftContacts():
166
- assert isinstance(model.contact_model, SoftContacts)
165
+ case contacts.SoftContacts():
166
+ assert isinstance(model.contact_model, contacts.SoftContacts)
167
167
 
168
168
  # Compute the 6D force expressed in the inertial frame and applied to each
169
169
  # collidable point, and the corresponding material deformation rate.
@@ -178,8 +178,8 @@ def collidable_point_dynamics(
178
178
  # of the ODE system. We need to pass its dynamics to the integrator.
179
179
  aux_data = dict(m_dot=CW_ṁ)
180
180
 
181
- case RigidContacts():
182
- assert isinstance(model.contact_model, RigidContacts)
181
+ case contacts.RigidContacts():
182
+ assert isinstance(model.contact_model, contacts.RigidContacts)
183
183
 
184
184
  # Compute the 6D force expressed in the inertial frame and applied to each
185
185
  # collidable point.
@@ -192,8 +192,8 @@ def collidable_point_dynamics(
192
192
 
193
193
  aux_data = dict()
194
194
 
195
- case RelaxedRigidContacts():
196
- assert isinstance(model.contact_model, RelaxedRigidContacts)
195
+ case contacts.RelaxedRigidContacts():
196
+ assert isinstance(model.contact_model, contacts.RelaxedRigidContacts)
197
197
 
198
198
  # Compute the 6D force expressed in the inertial frame and applied to each
199
199
  # collidable point.
@@ -206,6 +206,31 @@ def collidable_point_dynamics(
206
206
 
207
207
  aux_data = dict()
208
208
 
209
+ case contacts.ViscoElasticContacts():
210
+ assert isinstance(model.contact_model, contacts.ViscoElasticContacts)
211
+
212
+ # It is not yet clear how to pass the time step to this stage.
213
+ # A possibility is to restrict the integrator to only forward Euler
214
+ # and store the Δt inside the model.
215
+ module = jaxsim.rbda.contacts.visco_elastic.step.__module__
216
+ name = jaxsim.rbda.contacts.visco_elastic.step.__name__
217
+ msg = "You need to use the custom '{}.{}' function with this contact model."
218
+ jaxsim.exceptions.raise_runtime_error_if(
219
+ condition=True, msg=msg.format(module, name)
220
+ )
221
+
222
+ # Compute the 6D force expressed in the inertial frame and applied to each
223
+ # collidable point.
224
+ W_f_Ci, (W_f̿_Ci, m_tf) = model.contact_model.compute_contact_forces(
225
+ model=model,
226
+ data=data,
227
+ dt=None, # TODO
228
+ link_forces=link_forces,
229
+ joint_force_references=joint_force_references,
230
+ )
231
+
232
+ aux_data = dict(W_f_avg2_C=W_f̿_Ci, m_tf=m_tf)
233
+
209
234
  case _:
210
235
  raise ValueError(f"Invalid contact model {model.contact_model}")
211
236
 
@@ -278,7 +303,6 @@ def in_contact(
278
303
  return links_in_contact
279
304
 
280
305
 
281
- @jax.jit
282
306
  def estimate_good_soft_contacts_parameters(
283
307
  model: js.model.JaxSimModel,
284
308
  *,
@@ -287,9 +311,15 @@ def estimate_good_soft_contacts_parameters(
287
311
  number_of_active_collidable_points_steady_state: jtp.IntLike = 1,
288
312
  damping_ratio: jtp.FloatLike = 1.0,
289
313
  max_penetration: jtp.FloatLike | None = None,
290
- ) -> jaxsim.rbda.contacts.SoftContactsParams:
314
+ **kwargs,
315
+ ) -> (
316
+ jaxsim.rbda.contacts.RelaxedRigidContactsParams
317
+ | jaxsim.rbda.contacts.RigidContactsParams
318
+ | jaxsim.rbda.contacts.SoftContactsParams
319
+ | jaxsim.rbda.contacts.ViscoElasticContactsParams
320
+ ):
291
321
  """
292
- Estimate good soft contacts parameters for the given model.
322
+ Estimate good parameters for soft-like contact models.
293
323
 
294
324
  Args:
295
325
  model: The model to consider.
@@ -313,7 +343,10 @@ def estimate_good_soft_contacts_parameters(
313
343
  """
314
344
 
315
345
  def estimate_model_height(model: js.model.JaxSimModel) -> jtp.Float:
316
- """"""
346
+ """
347
+ Displacement between the CoM and the lowest collidable point using zero
348
+ joint positions.
349
+ """
317
350
 
318
351
  zero_data = js.data.JaxSimModelData.build(
319
352
  model=model,
@@ -338,21 +371,39 @@ def estimate_good_soft_contacts_parameters(
338
371
 
339
372
  match model.contact_model:
340
373
 
341
- case jaxsim.rbda.contacts.SoftContacts():
342
- assert isinstance(model.contact_model, jaxsim.rbda.contacts.SoftContacts)
374
+ case contacts.SoftContacts():
375
+ assert isinstance(model.contact_model, contacts.SoftContacts)
376
+
377
+ parameters = contacts.SoftContactsParams.build_default_from_jaxsim_model(
378
+ model=model,
379
+ standard_gravity=standard_gravity,
380
+ static_friction_coefficient=static_friction_coefficient,
381
+ max_penetration=max_δ,
382
+ number_of_active_collidable_points_steady_state=nc,
383
+ damping_ratio=damping_ratio,
384
+ p=model.contact_model.parameters.p,
385
+ q=model.contact_model.parameters.q,
386
+ )
387
+
388
+ case contacts.ViscoElasticContacts():
389
+ assert isinstance(model.contact_model, contacts.ViscoElasticContacts)
343
390
 
344
391
  parameters = (
345
- jaxsim.rbda.contacts.SoftContactsParams.build_default_from_jaxsim_model(
392
+ contacts.ViscoElasticContactsParams.build_default_from_jaxsim_model(
346
393
  model=model,
347
394
  standard_gravity=standard_gravity,
348
395
  static_friction_coefficient=static_friction_coefficient,
349
396
  max_penetration=max_δ,
350
397
  number_of_active_collidable_points_steady_state=nc,
351
398
  damping_ratio=damping_ratio,
399
+ p=model.contact_model.parameters.p,
400
+ q=model.contact_model.parameters.q,
401
+ **kwargs,
352
402
  )
353
403
  )
354
404
 
355
405
  case _:
406
+ logging.warning("The active contact model is not soft-like, no-op.")
356
407
  parameters = model.contact_model.parameters
357
408
 
358
409
  return parameters
@@ -38,12 +38,6 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
38
38
 
39
39
  contacts_params: jaxsim.rbda.contacts.ContactsParams = dataclasses.field(repr=False)
40
40
 
41
- time_ns: jtp.Int = dataclasses.field(
42
- default_factory=lambda: jnp.array(
43
- 0, dtype=jnp.uint64 if jax.config.read("jax_enable_x64") else jnp.uint32
44
- ),
45
- )
46
-
47
41
  def __hash__(self) -> int:
48
42
 
49
43
  from jaxsim.utils.wrappers import HashedNumpyArray
@@ -52,7 +46,6 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
52
46
  (
53
47
  hash(self.state),
54
48
  HashedNumpyArray.hash_of_array(self.gravity),
55
- HashedNumpyArray.hash_of_array(self.time_ns),
56
49
  hash(self.contacts_params),
57
50
  )
58
51
  )
@@ -115,7 +108,6 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
115
108
  standard_gravity: jtp.FloatLike = jaxsim.math.StandardGravity,
116
109
  contacts_params: jaxsim.rbda.contacts.ContactsParams | None = None,
117
110
  velocity_representation: VelRepr = VelRepr.Inertial,
118
- time: jtp.FloatLike | None = None,
119
111
  extended_ode_state: dict[str, jtp.PyTree] | None = None,
120
112
  ) -> JaxSimModelData:
121
113
  """
@@ -134,7 +126,6 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
134
126
  standard_gravity: The standard gravity constant.
135
127
  contacts_params: The parameters of the soft contacts.
136
128
  velocity_representation: The velocity representation to use.
137
- time: The time at which the state is created.
138
129
  extended_ode_state:
139
130
  Additional user-defined state variables that are not part of the
140
131
  standard `ODEState` object. Useful to extend the system dynamics
@@ -196,11 +187,6 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
196
187
  ).squeeze()
197
188
  )
198
189
 
199
- time_ns = jnp.array(
200
- time * 1e9 if time is not None else 0.0,
201
- dtype=jnp.uint64 if jax.config.read("jax_enable_x64") else jnp.uint32,
202
- )
203
-
204
190
  W_H_B = jaxsim.math.Transform.from_quaternion_and_translation(
205
191
  translation=base_position, quaternion=base_quaternion
206
192
  )
@@ -233,7 +219,11 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
233
219
 
234
220
  if contacts_params is None:
235
221
 
236
- if isinstance(model.contact_model, jaxsim.rbda.contacts.SoftContacts):
222
+ if isinstance(
223
+ model.contact_model,
224
+ jaxsim.rbda.contacts.SoftContacts
225
+ | jaxsim.rbda.contacts.ViscoElasticContacts,
226
+ ):
237
227
  contacts_params = js.contact.estimate_good_soft_contacts_parameters(
238
228
  model=model, standard_gravity=standard_gravity
239
229
  )
@@ -242,7 +232,6 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
242
232
  contacts_params = model.contact_model.parameters
243
233
 
244
234
  return JaxSimModelData(
245
- time_ns=time_ns,
246
235
  state=ode_state,
247
236
  gravity=gravity,
248
237
  contacts_params=contacts_params,
@@ -253,16 +242,6 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
253
242
  # Extract quantities
254
243
  # ==================
255
244
 
256
- def time(self) -> jtp.Float:
257
- """
258
- Get the simulated time.
259
-
260
- Returns:
261
- The simulated time in seconds.
262
- """
263
-
264
- return self.time_ns.astype(float) / 1e9
265
-
266
245
  def standard_gravity(self) -> jtp.Float:
267
246
  """
268
247
  Get the standard gravity constant.
@@ -5,6 +5,8 @@ import dataclasses
5
5
  import jax.lax
6
6
  import jax.numpy as jnp
7
7
  import jax_dataclasses
8
+ import numpy as np
9
+ import numpy.typing as npt
8
10
  from jax_dataclasses import Static
9
11
 
10
12
  import jaxsim.typing as jtp
@@ -753,6 +755,13 @@ class ContactParameters(JaxsimDataclass):
753
755
 
754
756
  point: jtp.Matrix = dataclasses.field(default_factory=lambda: jnp.array([]))
755
757
 
758
+ enabled: Static[tuple[bool, ...]] = dataclasses.field(default_factory=tuple)
759
+
760
+ @property
761
+ def indices_of_enabled_collidable_points(self) -> npt.NDArray:
762
+
763
+ return np.where(np.array(self.enabled))[0]
764
+
756
765
  @staticmethod
757
766
  def build_from(model_description: ModelDescription) -> ContactParameters:
758
767
  """
@@ -785,7 +794,11 @@ class ContactParameters(JaxsimDataclass):
785
794
  )
786
795
 
787
796
  # Build the ContactParameters object.
788
- cp = ContactParameters(point=points, body=link_index_of_points)
797
+ cp = ContactParameters(
798
+ point=points,
799
+ body=link_index_of_points,
800
+ enabled=tuple(True for _ in link_index_of_points),
801
+ )
789
802
 
790
803
  assert cp.point.shape[1] == 3, cp.point.shape[1]
791
804
  assert cp.point.shape[0] == len(cp.body), cp.point.shape[0]
@@ -32,6 +32,10 @@ class JaxSimModel(JaxsimDataclass):
32
32
 
33
33
  model_name: Static[str]
34
34
 
35
+ time_step: jaxsim.integrators.TimeStep = dataclasses.field(
36
+ default_factory=lambda: jnp.array(0.001, dtype=float),
37
+ )
38
+
35
39
  terrain: Static[jaxsim.terrain.Terrain] = dataclasses.field(
36
40
  default_factory=jaxsim.terrain.FlatTerrain.build, repr=False
37
41
  )
@@ -64,6 +68,9 @@ class JaxSimModel(JaxsimDataclass):
64
68
  if self.model_name != other.model_name:
65
69
  return False
66
70
 
71
+ if self.time_step != other.time_step:
72
+ return False
73
+
67
74
  if self.kin_dyn_parameters != other.kin_dyn_parameters:
68
75
  return False
69
76
 
@@ -74,6 +81,7 @@ class JaxSimModel(JaxsimDataclass):
74
81
  return hash(
75
82
  (
76
83
  hash(self.model_name),
84
+ hash(float(self.time_step)),
77
85
  hash(self.kin_dyn_parameters),
78
86
  hash(self.contact_model),
79
87
  )
@@ -88,6 +96,7 @@ class JaxSimModel(JaxsimDataclass):
88
96
  model_description: str | pathlib.Path | rod.Model,
89
97
  model_name: str | None = None,
90
98
  *,
99
+ time_step: jtp.FloatLike | None = None,
91
100
  terrain: jaxsim.terrain.Terrain | None = None,
92
101
  contact_model: jaxsim.rbda.contacts.ContactModel | None = None,
93
102
  is_urdf: bool | None = None,
@@ -102,6 +111,9 @@ class JaxSimModel(JaxsimDataclass):
102
111
  its content, or a pre-parsed/pre-built rod model.
103
112
  model_name:
104
113
  The name of the model. If not specified, it is read from the description.
114
+ time_step:
115
+ The default time step to consider for the simulation. It can be
116
+ manually overridden in the function that steps the simulation.
105
117
  terrain: The terrain to consider (the default is a flat infinite plane).
106
118
  contact_model:
107
119
  The contact model to consider.
@@ -135,6 +147,7 @@ class JaxSimModel(JaxsimDataclass):
135
147
  model = JaxSimModel.build(
136
148
  model_description=intermediate_description,
137
149
  model_name=model_name,
150
+ time_step=time_step,
138
151
  terrain=terrain,
139
152
  contact_model=contact_model,
140
153
  )
@@ -150,6 +163,7 @@ class JaxSimModel(JaxsimDataclass):
150
163
  model_description: ModelDescription,
151
164
  model_name: str | None = None,
152
165
  *,
166
+ time_step: jtp.FloatLike | None = None,
153
167
  terrain: jaxsim.terrain.Terrain | None = None,
154
168
  contact_model: jaxsim.rbda.contacts.ContactModel | None = None,
155
169
  ) -> JaxSimModel:
@@ -162,6 +176,9 @@ class JaxSimModel(JaxsimDataclass):
162
176
  of the model.
163
177
  model_name:
164
178
  The name of the model. If not specified, it is read from the description.
179
+ time_step:
180
+ The default time step to consider for the simulation. It can be
181
+ manually overridden in the function that steps the simulation.
165
182
  terrain: The terrain to consider (the default is a flat infinite plane).
166
183
  contact_model:
167
184
  The contact model to consider.
@@ -179,6 +196,11 @@ class JaxSimModel(JaxsimDataclass):
179
196
  terrain or JaxSimModel.__dataclass_fields__["terrain"].default_factory()
180
197
  )
181
198
 
199
+ # Consider the default time step if not specified.
200
+ time_step = (
201
+ time_step or JaxSimModel.__dataclass_fields__["time_step"].default_factory()
202
+ )
203
+
182
204
  # Create the default contact model.
183
205
  # It will be populated with an initial estimation of good parameters.
184
206
  # While these might not be the best, they are a good starting point.
@@ -192,6 +214,7 @@ class JaxSimModel(JaxsimDataclass):
192
214
  kin_dyn_parameters=js.kin_dyn_parameters.KynDynParameters.build(
193
215
  model_description=model_description
194
216
  ),
217
+ time_step=time_step,
195
218
  terrain=terrain,
196
219
  contact_model=contact_model,
197
220
  # The following is wrapped as hashless since it's a static argument, and we
@@ -1915,8 +1938,9 @@ def step(
1915
1938
  model: JaxSimModel,
1916
1939
  data: js.data.JaxSimModelData,
1917
1940
  *,
1918
- dt: jtp.FloatLike,
1919
1941
  integrator: jaxsim.integrators.Integrator,
1942
+ t0: jtp.FloatLike = 0.0,
1943
+ dt: jtp.FloatLike | None = None,
1920
1944
  integrator_state: dict[str, Any] | None = None,
1921
1945
  link_forces: jtp.MatrixLike | None = None,
1922
1946
  joint_force_references: jtp.VectorLike | None = None,
@@ -1928,9 +1952,10 @@ def step(
1928
1952
  Args:
1929
1953
  model: The model to consider.
1930
1954
  data: The data of the considered model.
1931
- dt: The time step to consider.
1932
1955
  integrator: The integrator to use.
1933
1956
  integrator_state: The state of the integrator.
1957
+ t0: The initial time to consider. Only relevant for time-dependent dynamics.
1958
+ dt: The time step to consider. If not specified, it is read from the model.
1934
1959
  link_forces:
1935
1960
  The 6D forces to apply to the links expressed in the frame corresponding to
1936
1961
  the velocity representation of `data`.
@@ -1951,17 +1976,20 @@ def step(
1951
1976
 
1952
1977
  integrator_state = integrator_state if integrator_state is not None else dict()
1953
1978
 
1954
- # Extract the initial resources.
1955
- t0_ns = data.time_ns
1979
+ # Initialize the time-related variables.
1956
1980
  state_t0 = data.state
1957
- integrator_state_x0 = integrator_state
1981
+ t0 = jnp.array(t0, dtype=float)
1982
+ dt = jnp.array(dt if dt is not None else model.time_step).astype(float)
1983
+
1984
+ # Rename the integrator state.
1985
+ integrator_state_t0 = integrator_state
1958
1986
 
1959
1987
  # Step the dynamics forward.
1960
1988
  state_tf, integrator_state_tf = integrator.step(
1961
1989
  x0=state_t0,
1962
- t0=jnp.array(t0_ns / 1e9).astype(float),
1990
+ t0=t0,
1963
1991
  dt=dt,
1964
- params=integrator_state_x0,
1992
+ params=integrator_state_t0,
1965
1993
  # Always inject the current (model, data) pair into the system dynamics
1966
1994
  # considered by the integrator, and include the input variables represented
1967
1995
  # by the pair (joint_force_references, link_forces).
@@ -1980,24 +2008,8 @@ def step(
1980
2008
  ),
1981
2009
  )
1982
2010
 
1983
- tf_ns = t0_ns + jnp.array(dt * 1e9, dtype=t0_ns.dtype)
1984
- tf_ns = jnp.where(tf_ns >= t0_ns, tf_ns, jnp.array(0, dtype=t0_ns.dtype))
1985
-
1986
- jax.lax.cond(
1987
- pred=tf_ns < t0_ns,
1988
- true_fun=lambda: jax.debug.print(
1989
- "The simulation time overflowed, resetting simulation time to 0."
1990
- ),
1991
- false_fun=lambda: None,
1992
- )
1993
-
1994
- data_tf = (
1995
- # Store the new state of the model and the new time.
1996
- data.replace(
1997
- state=state_tf,
1998
- time_ns=tf_ns,
1999
- )
2000
- )
2011
+ # Store the new state of the model.
2012
+ data_tf = data.replace(state=state_tf)
2001
2013
 
2002
2014
  # Post process the simulation state, if needed.
2003
2015
  match model.contact_model:
@@ -2064,7 +2076,4 @@ def step(
2064
2076
  velocity_representation=data.velocity_representation, validate=False
2065
2077
  )
2066
2078
 
2067
- return (
2068
- data_tf,
2069
- integrator_state_tf,
2070
- )
2079
+ return data_tf, integrator_state_tf
@@ -8,6 +8,7 @@ import jaxsim.rbda
8
8
  import jaxsim.typing as jtp
9
9
  from jaxsim.integrators import Time
10
10
  from jaxsim.math import Quaternion
11
+ from jaxsim.rbda import contacts
11
12
 
12
13
  from .common import VelRepr
13
14
  from .ode_data import ODEState
@@ -62,7 +63,6 @@ def wrap_system_dynamics_for_integration(
62
63
  # Update the state and time stored inside data.
63
64
  with data_f.editable(validate=True) as data_rw:
64
65
  data_rw.state = x
65
- data_rw.time_ns = jnp.array(t * 1e9).astype(data_rw.time_ns.dtype)
66
66
 
67
67
  # Evaluate the system dynamics, allowing to override the kwargs originally
68
68
  # passed when the closure was created.
@@ -371,8 +371,6 @@ def system_dynamics(
371
371
  by the system dynamics evaluation.
372
372
  """
373
373
 
374
- from jaxsim.rbda.contacts import RelaxedRigidContacts, RigidContacts, SoftContacts
375
-
376
374
  # Compute the accelerations and the material deformation rate.
377
375
  W_v̇_WB, s̈, aux_dict = system_velocity_dynamics(
378
376
  model=model,
@@ -387,10 +385,18 @@ def system_dynamics(
387
385
 
388
386
  match model.contact_model:
389
387
 
390
- case SoftContacts():
388
+ case contacts.SoftContacts():
391
389
  extended_ode_state["tangential_deformation"] = aux_dict["m_dot"]
392
390
 
393
- case RigidContacts() | RelaxedRigidContacts():
391
+ case contacts.ViscoElasticContacts():
392
+
393
+ extended_ode_state["contacts_state"] = {
394
+ "tangential_deformation": jnp.zeros_like(
395
+ data.state.extended["tangential_deformation"]
396
+ )
397
+ }
398
+
399
+ case contacts.RigidContacts() | contacts.RelaxedRigidContacts():
394
400
  pass
395
401
 
396
402
  case _:
@@ -1,5 +1,6 @@
1
- from . import relaxed_rigid, rigid, soft
1
+ from . import relaxed_rigid, rigid, soft, visco_elastic
2
2
  from .common import ContactModel, ContactsParams
3
3
  from .relaxed_rigid import RelaxedRigidContacts, RelaxedRigidContactsParams
4
4
  from .rigid import RigidContacts, RigidContactsParams
5
5
  from .soft import SoftContacts, SoftContactsParams
6
+ from .visco_elastic import ViscoElasticContacts, ViscoElasticContactsParams