jaxsim 0.6.2.dev194__tar.gz → 0.6.2.dev225__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (134) hide show
  1. {jaxsim-0.6.2.dev194 → jaxsim-0.6.2.dev225}/.github/workflows/gpu_benchmark.yml +1 -1
  2. {jaxsim-0.6.2.dev194 → jaxsim-0.6.2.dev225}/PKG-INFO +3 -2
  3. {jaxsim-0.6.2.dev194 → jaxsim-0.6.2.dev225}/src/jaxsim/_version.py +2 -2
  4. {jaxsim-0.6.2.dev194 → jaxsim-0.6.2.dev225}/src/jaxsim/api/__init__.py +0 -1
  5. {jaxsim-0.6.2.dev194 → jaxsim-0.6.2.dev225}/src/jaxsim/api/com.py +1 -3
  6. {jaxsim-0.6.2.dev194 → jaxsim-0.6.2.dev225}/src/jaxsim/api/contact.py +140 -24
  7. {jaxsim-0.6.2.dev194 → jaxsim-0.6.2.dev225}/src/jaxsim/api/data.py +39 -12
  8. {jaxsim-0.6.2.dev194 → jaxsim-0.6.2.dev225}/src/jaxsim/api/integrators.py +16 -9
  9. {jaxsim-0.6.2.dev194 → jaxsim-0.6.2.dev225}/src/jaxsim/api/model.py +19 -35
  10. {jaxsim-0.6.2.dev194 → jaxsim-0.6.2.dev225}/src/jaxsim/api/ode.py +28 -6
  11. {jaxsim-0.6.2.dev194 → jaxsim-0.6.2.dev225}/src/jaxsim/math/__init__.py +1 -1
  12. {jaxsim-0.6.2.dev194 → jaxsim-0.6.2.dev225}/src/jaxsim/parsers/kinematic_graph.py +1 -1
  13. {jaxsim-0.6.2.dev194 → jaxsim-0.6.2.dev225}/src/jaxsim/rbda/__init__.py +1 -1
  14. jaxsim-0.6.2.dev225/src/jaxsim/rbda/contacts/__init__.py +9 -0
  15. {jaxsim-0.6.2.dev194 → jaxsim-0.6.2.dev225}/src/jaxsim/rbda/contacts/common.py +114 -4
  16. {jaxsim-0.6.2.dev194 → jaxsim-0.6.2.dev225}/src/jaxsim/rbda/contacts/relaxed_rigid.py +51 -167
  17. jaxsim-0.6.2.dev225/src/jaxsim/rbda/contacts/rigid.py +538 -0
  18. jaxsim-0.6.2.dev225/src/jaxsim/rbda/contacts/soft.py +448 -0
  19. {jaxsim-0.6.2.dev194 → jaxsim-0.6.2.dev225}/src/jaxsim/rbda/forward_kinematics.py +0 -29
  20. {jaxsim-0.6.2.dev194 → jaxsim-0.6.2.dev225}/src/jaxsim/rbda/utils.py +2 -2
  21. {jaxsim-0.6.2.dev194 → jaxsim-0.6.2.dev225}/src/jaxsim.egg-info/PKG-INFO +3 -2
  22. {jaxsim-0.6.2.dev194 → jaxsim-0.6.2.dev225}/src/jaxsim.egg-info/SOURCES.txt +2 -2
  23. {jaxsim-0.6.2.dev194 → jaxsim-0.6.2.dev225}/tests/test_api_contact.py +32 -0
  24. {jaxsim-0.6.2.dev194 → jaxsim-0.6.2.dev225}/tests/test_automatic_differentiation.py +53 -0
  25. {jaxsim-0.6.2.dev194 → jaxsim-0.6.2.dev225}/tests/test_simulations.py +104 -1
  26. jaxsim-0.6.2.dev194/src/jaxsim/api/contact_model.py +0 -101
  27. jaxsim-0.6.2.dev194/src/jaxsim/rbda/contacts/__init__.py +0 -5
  28. jaxsim-0.6.2.dev194/tests/test_contact.py +0 -37
  29. {jaxsim-0.6.2.dev194 → jaxsim-0.6.2.dev225}/.devcontainer/Dockerfile +0 -0
  30. {jaxsim-0.6.2.dev194 → jaxsim-0.6.2.dev225}/.devcontainer/devcontainer.json +0 -0
  31. {jaxsim-0.6.2.dev194 → jaxsim-0.6.2.dev225}/.gitattributes +0 -0
  32. {jaxsim-0.6.2.dev194 → jaxsim-0.6.2.dev225}/.github/CODEOWNERS +0 -0
  33. {jaxsim-0.6.2.dev194 → jaxsim-0.6.2.dev225}/.github/dependabot.yml +0 -0
  34. {jaxsim-0.6.2.dev194 → jaxsim-0.6.2.dev225}/.github/workflows/ci_cd.yml +0 -0
  35. {jaxsim-0.6.2.dev194 → jaxsim-0.6.2.dev225}/.github/workflows/pixi.yml +0 -0
  36. {jaxsim-0.6.2.dev194 → jaxsim-0.6.2.dev225}/.github/workflows/read_the_docs.yml +0 -0
  37. {jaxsim-0.6.2.dev194 → jaxsim-0.6.2.dev225}/.gitignore +0 -0
  38. {jaxsim-0.6.2.dev194 → jaxsim-0.6.2.dev225}/.pre-commit-config.yaml +0 -0
  39. {jaxsim-0.6.2.dev194 → jaxsim-0.6.2.dev225}/.readthedocs.yaml +0 -0
  40. {jaxsim-0.6.2.dev194 → jaxsim-0.6.2.dev225}/CONTRIBUTING.md +0 -0
  41. {jaxsim-0.6.2.dev194 → jaxsim-0.6.2.dev225}/LICENSE +0 -0
  42. {jaxsim-0.6.2.dev194 → jaxsim-0.6.2.dev225}/README.md +0 -0
  43. {jaxsim-0.6.2.dev194 → jaxsim-0.6.2.dev225}/docs/Makefile +0 -0
  44. {jaxsim-0.6.2.dev194 → jaxsim-0.6.2.dev225}/docs/conf.py +0 -0
  45. {jaxsim-0.6.2.dev194 → jaxsim-0.6.2.dev225}/docs/examples.rst +0 -0
  46. {jaxsim-0.6.2.dev194 → jaxsim-0.6.2.dev225}/docs/guide/configuration.rst +0 -0
  47. {jaxsim-0.6.2.dev194 → jaxsim-0.6.2.dev225}/docs/guide/install.rst +0 -0
  48. {jaxsim-0.6.2.dev194 → jaxsim-0.6.2.dev225}/docs/index.rst +0 -0
  49. {jaxsim-0.6.2.dev194 → jaxsim-0.6.2.dev225}/docs/make.bat +0 -0
  50. {jaxsim-0.6.2.dev194 → jaxsim-0.6.2.dev225}/docs/modules/api.rst +0 -0
  51. {jaxsim-0.6.2.dev194 → jaxsim-0.6.2.dev225}/docs/modules/math.rst +0 -0
  52. {jaxsim-0.6.2.dev194 → jaxsim-0.6.2.dev225}/docs/modules/mujoco.rst +0 -0
  53. {jaxsim-0.6.2.dev194 → jaxsim-0.6.2.dev225}/docs/modules/parsers.rst +0 -0
  54. {jaxsim-0.6.2.dev194 → jaxsim-0.6.2.dev225}/docs/modules/rbda.rst +0 -0
  55. {jaxsim-0.6.2.dev194 → jaxsim-0.6.2.dev225}/docs/modules/typing.rst +0 -0
  56. {jaxsim-0.6.2.dev194 → jaxsim-0.6.2.dev225}/docs/modules/utils.rst +0 -0
  57. {jaxsim-0.6.2.dev194 → jaxsim-0.6.2.dev225}/environment.yml +0 -0
  58. {jaxsim-0.6.2.dev194 → jaxsim-0.6.2.dev225}/examples/.gitattributes +0 -0
  59. {jaxsim-0.6.2.dev194 → jaxsim-0.6.2.dev225}/examples/.gitignore +0 -0
  60. {jaxsim-0.6.2.dev194 → jaxsim-0.6.2.dev225}/examples/README.md +0 -0
  61. {jaxsim-0.6.2.dev194 → jaxsim-0.6.2.dev225}/examples/assets/build_cartpole_urdf.py +0 -0
  62. {jaxsim-0.6.2.dev194 → jaxsim-0.6.2.dev225}/examples/assets/cartpole.urdf +0 -0
  63. {jaxsim-0.6.2.dev194 → jaxsim-0.6.2.dev225}/examples/jaxsim_as_multibody_dynamics_library.ipynb +0 -0
  64. {jaxsim-0.6.2.dev194 → jaxsim-0.6.2.dev225}/examples/jaxsim_as_physics_engine.ipynb +0 -0
  65. {jaxsim-0.6.2.dev194 → jaxsim-0.6.2.dev225}/examples/jaxsim_as_physics_engine_advanced.ipynb +0 -0
  66. {jaxsim-0.6.2.dev194 → jaxsim-0.6.2.dev225}/examples/jaxsim_for_robot_controllers.ipynb +0 -0
  67. {jaxsim-0.6.2.dev194 → jaxsim-0.6.2.dev225}/pixi.lock +0 -0
  68. {jaxsim-0.6.2.dev194 → jaxsim-0.6.2.dev225}/pyproject.toml +0 -0
  69. {jaxsim-0.6.2.dev194 → jaxsim-0.6.2.dev225}/setup.cfg +0 -0
  70. {jaxsim-0.6.2.dev194 → jaxsim-0.6.2.dev225}/setup.py +0 -0
  71. {jaxsim-0.6.2.dev194 → jaxsim-0.6.2.dev225}/src/jaxsim/__init__.py +0 -0
  72. {jaxsim-0.6.2.dev194 → jaxsim-0.6.2.dev225}/src/jaxsim/api/actuation_model.py +0 -0
  73. {jaxsim-0.6.2.dev194 → jaxsim-0.6.2.dev225}/src/jaxsim/api/common.py +0 -0
  74. {jaxsim-0.6.2.dev194 → jaxsim-0.6.2.dev225}/src/jaxsim/api/frame.py +0 -0
  75. {jaxsim-0.6.2.dev194 → jaxsim-0.6.2.dev225}/src/jaxsim/api/joint.py +0 -0
  76. {jaxsim-0.6.2.dev194 → jaxsim-0.6.2.dev225}/src/jaxsim/api/kin_dyn_parameters.py +0 -0
  77. {jaxsim-0.6.2.dev194 → jaxsim-0.6.2.dev225}/src/jaxsim/api/link.py +0 -0
  78. {jaxsim-0.6.2.dev194 → jaxsim-0.6.2.dev225}/src/jaxsim/api/references.py +0 -0
  79. {jaxsim-0.6.2.dev194 → jaxsim-0.6.2.dev225}/src/jaxsim/exceptions.py +0 -0
  80. {jaxsim-0.6.2.dev194 → jaxsim-0.6.2.dev225}/src/jaxsim/logging.py +0 -0
  81. {jaxsim-0.6.2.dev194 → jaxsim-0.6.2.dev225}/src/jaxsim/math/adjoint.py +0 -0
  82. {jaxsim-0.6.2.dev194 → jaxsim-0.6.2.dev225}/src/jaxsim/math/cross.py +0 -0
  83. {jaxsim-0.6.2.dev194 → jaxsim-0.6.2.dev225}/src/jaxsim/math/inertia.py +0 -0
  84. {jaxsim-0.6.2.dev194 → jaxsim-0.6.2.dev225}/src/jaxsim/math/joint_model.py +0 -0
  85. {jaxsim-0.6.2.dev194 → jaxsim-0.6.2.dev225}/src/jaxsim/math/quaternion.py +0 -0
  86. {jaxsim-0.6.2.dev194 → jaxsim-0.6.2.dev225}/src/jaxsim/math/rotation.py +0 -0
  87. {jaxsim-0.6.2.dev194 → jaxsim-0.6.2.dev225}/src/jaxsim/math/skew.py +0 -0
  88. {jaxsim-0.6.2.dev194 → jaxsim-0.6.2.dev225}/src/jaxsim/math/transform.py +0 -0
  89. {jaxsim-0.6.2.dev194 → jaxsim-0.6.2.dev225}/src/jaxsim/math/utils.py +0 -0
  90. {jaxsim-0.6.2.dev194 → jaxsim-0.6.2.dev225}/src/jaxsim/mujoco/__init__.py +0 -0
  91. {jaxsim-0.6.2.dev194 → jaxsim-0.6.2.dev225}/src/jaxsim/mujoco/__main__.py +0 -0
  92. {jaxsim-0.6.2.dev194 → jaxsim-0.6.2.dev225}/src/jaxsim/mujoco/loaders.py +0 -0
  93. {jaxsim-0.6.2.dev194 → jaxsim-0.6.2.dev225}/src/jaxsim/mujoco/model.py +0 -0
  94. {jaxsim-0.6.2.dev194 → jaxsim-0.6.2.dev225}/src/jaxsim/mujoco/utils.py +0 -0
  95. {jaxsim-0.6.2.dev194 → jaxsim-0.6.2.dev225}/src/jaxsim/mujoco/visualizer.py +0 -0
  96. {jaxsim-0.6.2.dev194 → jaxsim-0.6.2.dev225}/src/jaxsim/parsers/__init__.py +0 -0
  97. {jaxsim-0.6.2.dev194 → jaxsim-0.6.2.dev225}/src/jaxsim/parsers/descriptions/__init__.py +0 -0
  98. {jaxsim-0.6.2.dev194 → jaxsim-0.6.2.dev225}/src/jaxsim/parsers/descriptions/collision.py +0 -0
  99. {jaxsim-0.6.2.dev194 → jaxsim-0.6.2.dev225}/src/jaxsim/parsers/descriptions/joint.py +0 -0
  100. {jaxsim-0.6.2.dev194 → jaxsim-0.6.2.dev225}/src/jaxsim/parsers/descriptions/link.py +0 -0
  101. {jaxsim-0.6.2.dev194 → jaxsim-0.6.2.dev225}/src/jaxsim/parsers/descriptions/model.py +0 -0
  102. {jaxsim-0.6.2.dev194 → jaxsim-0.6.2.dev225}/src/jaxsim/parsers/rod/__init__.py +0 -0
  103. {jaxsim-0.6.2.dev194 → jaxsim-0.6.2.dev225}/src/jaxsim/parsers/rod/meshes.py +0 -0
  104. {jaxsim-0.6.2.dev194 → jaxsim-0.6.2.dev225}/src/jaxsim/parsers/rod/parser.py +0 -0
  105. {jaxsim-0.6.2.dev194 → jaxsim-0.6.2.dev225}/src/jaxsim/parsers/rod/utils.py +0 -0
  106. {jaxsim-0.6.2.dev194 → jaxsim-0.6.2.dev225}/src/jaxsim/rbda/aba.py +0 -0
  107. {jaxsim-0.6.2.dev194 → jaxsim-0.6.2.dev225}/src/jaxsim/rbda/collidable_points.py +0 -0
  108. {jaxsim-0.6.2.dev194 → jaxsim-0.6.2.dev225}/src/jaxsim/rbda/crba.py +0 -0
  109. {jaxsim-0.6.2.dev194 → jaxsim-0.6.2.dev225}/src/jaxsim/rbda/jacobian.py +0 -0
  110. {jaxsim-0.6.2.dev194 → jaxsim-0.6.2.dev225}/src/jaxsim/rbda/rnea.py +0 -0
  111. {jaxsim-0.6.2.dev194 → jaxsim-0.6.2.dev225}/src/jaxsim/terrain/__init__.py +0 -0
  112. {jaxsim-0.6.2.dev194 → jaxsim-0.6.2.dev225}/src/jaxsim/terrain/terrain.py +0 -0
  113. {jaxsim-0.6.2.dev194 → jaxsim-0.6.2.dev225}/src/jaxsim/typing.py +0 -0
  114. {jaxsim-0.6.2.dev194 → jaxsim-0.6.2.dev225}/src/jaxsim/utils/__init__.py +0 -0
  115. {jaxsim-0.6.2.dev194 → jaxsim-0.6.2.dev225}/src/jaxsim/utils/jaxsim_dataclass.py +0 -0
  116. {jaxsim-0.6.2.dev194 → jaxsim-0.6.2.dev225}/src/jaxsim/utils/tracing.py +0 -0
  117. {jaxsim-0.6.2.dev194 → jaxsim-0.6.2.dev225}/src/jaxsim/utils/wrappers.py +0 -0
  118. {jaxsim-0.6.2.dev194 → jaxsim-0.6.2.dev225}/src/jaxsim.egg-info/dependency_links.txt +0 -0
  119. {jaxsim-0.6.2.dev194 → jaxsim-0.6.2.dev225}/src/jaxsim.egg-info/requires.txt +0 -0
  120. {jaxsim-0.6.2.dev194 → jaxsim-0.6.2.dev225}/src/jaxsim.egg-info/top_level.txt +0 -0
  121. {jaxsim-0.6.2.dev194 → jaxsim-0.6.2.dev225}/tests/__init__.py +0 -0
  122. {jaxsim-0.6.2.dev194 → jaxsim-0.6.2.dev225}/tests/conftest.py +0 -0
  123. {jaxsim-0.6.2.dev194 → jaxsim-0.6.2.dev225}/tests/test_api_com.py +0 -0
  124. {jaxsim-0.6.2.dev194 → jaxsim-0.6.2.dev225}/tests/test_api_data.py +0 -0
  125. {jaxsim-0.6.2.dev194 → jaxsim-0.6.2.dev225}/tests/test_api_frame.py +0 -0
  126. {jaxsim-0.6.2.dev194 → jaxsim-0.6.2.dev225}/tests/test_api_joint.py +0 -0
  127. {jaxsim-0.6.2.dev194 → jaxsim-0.6.2.dev225}/tests/test_api_link.py +0 -0
  128. {jaxsim-0.6.2.dev194 → jaxsim-0.6.2.dev225}/tests/test_api_model.py +0 -0
  129. {jaxsim-0.6.2.dev194 → jaxsim-0.6.2.dev225}/tests/test_benchmark.py +0 -0
  130. {jaxsim-0.6.2.dev194 → jaxsim-0.6.2.dev225}/tests/test_exceptions.py +0 -0
  131. {jaxsim-0.6.2.dev194 → jaxsim-0.6.2.dev225}/tests/test_meshes.py +0 -0
  132. {jaxsim-0.6.2.dev194 → jaxsim-0.6.2.dev225}/tests/test_pytree.py +0 -0
  133. {jaxsim-0.6.2.dev194 → jaxsim-0.6.2.dev225}/tests/test_visualizer.py +0 -0
  134. {jaxsim-0.6.2.dev194 → jaxsim-0.6.2.dev225}/tests/utils_idyntree.py +0 -0
@@ -51,7 +51,7 @@ jobs:
51
51
 
52
52
  - name: Run benchmark and store result
53
53
  run: |
54
- pytest tests/test_benchmark.py -k 'not test_rigid_contact_model and not test_soft_contact_model' --gpu-only --batch-size 128 --benchmark-only --benchmark-json output.json
54
+ pytest tests/test_benchmark.py --gpu-only --batch-size 128 --benchmark-only --benchmark-json output.json
55
55
 
56
56
  - name: Compare benchmark results with main branch
57
57
  uses: benchmark-action/github-action-benchmark@v1.20.4
@@ -1,6 +1,6 @@
1
- Metadata-Version: 2.2
1
+ Metadata-Version: 2.4
2
2
  Name: jaxsim
3
- Version: 0.6.2.dev194
3
+ Version: 0.6.2.dev225
4
4
  Summary: A differentiable physics engine and multibody dynamics library for control and robot learning.
5
5
  Author-email: Diego Ferigo <dgferigo@gmail.com>, Filippo Luca Ferretti <filippoluca.ferretti@outlook.com>
6
6
  Maintainer-email: Filippo Luca Ferretti <filippo.ferretti@iit.it>, Alessandro Croci <alessandro.croci@iit.it>
@@ -89,6 +89,7 @@ Requires-Dist: mujoco>=3.0.0; extra == "viz"
89
89
  Requires-Dist: scipy>=1.14.0; extra == "viz"
90
90
  Provides-Extra: all
91
91
  Requires-Dist: jaxsim[style,testing,viz]; extra == "all"
92
+ Dynamic: license-file
92
93
 
93
94
  # JaxSim
94
95
 
@@ -17,5 +17,5 @@ __version__: str
17
17
  __version_tuple__: VERSION_TUPLE
18
18
  version_tuple: VERSION_TUPLE
19
19
 
20
- __version__ = version = '0.6.2.dev194'
21
- __version_tuple__ = version_tuple = (0, 6, 2, 'dev194')
20
+ __version__ = version = '0.6.2.dev225'
21
+ __version_tuple__ = version_tuple = (0, 6, 2, 'dev225')
@@ -4,7 +4,6 @@ from . import (
4
4
  actuation_model,
5
5
  com,
6
6
  contact,
7
- contact_model,
8
7
  frame,
9
8
  integrators,
10
9
  joint,
@@ -301,9 +301,7 @@ def bias_acceleration(
301
301
  C_v̇_WL = W_v̇_bias_WL = v̇_bias_WL # noqa: F841
302
302
  C_v_WC = W_v_WW = jnp.zeros(6) # noqa: F841
303
303
 
304
- L_H_C = L_H_W = jax.vmap( # noqa: F841
305
- lambda W_H_L: jaxsim.math.Transform.inverse(W_H_L)
306
- )(W_H_L)
304
+ L_H_C = L_H_W = jax.vmap(jaxsim.math.Transform.inverse)(W_H_L) # noqa: F841
307
305
 
308
306
  L_v_LC = L_v_LW = jax.vmap( # noqa: F841
309
307
  lambda i: -js.link.velocity(
@@ -11,7 +11,7 @@ import jaxsim.terrain
11
11
  import jaxsim.typing as jtp
12
12
  from jaxsim import logging
13
13
  from jaxsim.math import Adjoint, Cross, Transform
14
- from jaxsim.rbda import contacts
14
+ from jaxsim.rbda.contacts import SoftContacts
15
15
 
16
16
  from .common import VelRepr
17
17
 
@@ -37,14 +37,11 @@ def collidable_point_kinematics(
37
37
  the linear component of the mixed 6D frame velocity.
38
38
  """
39
39
 
40
- # Switch to inertial-fixed since the RBDAs expect velocities in this representation.
41
- with data.switch_velocity_representation(VelRepr.Inertial):
42
-
43
- W_p_Ci, W_ṗ_Ci = jaxsim.rbda.collidable_points.collidable_points_pos_vel(
44
- model=model,
45
- link_transforms=data._link_transforms,
46
- link_velocities=data._link_velocities,
47
- )
40
+ W_p_Ci, W_ṗ_Ci = jaxsim.rbda.collidable_points.collidable_points_pos_vel(
41
+ model=model,
42
+ link_transforms=data._link_transforms,
43
+ link_velocities=data._link_velocities,
44
+ )
48
45
 
49
46
  return W_p_Ci, W_ṗ_Ci
50
47
 
@@ -164,18 +161,23 @@ def estimate_good_soft_contacts_parameters(
164
161
  def estimate_good_contact_parameters(
165
162
  model: js.model.JaxSimModel,
166
163
  *,
164
+ standard_gravity: jtp.FloatLike = jaxsim.math.STANDARD_GRAVITY,
167
165
  static_friction_coefficient: jtp.FloatLike = 0.5,
168
- **kwargs,
166
+ number_of_active_collidable_points_steady_state: jtp.IntLike = 1,
167
+ damping_ratio: jtp.FloatLike = 1.0,
168
+ max_penetration: jtp.FloatLike | None = None,
169
169
  ) -> jaxsim.rbda.contacts.ContactParamsTypes:
170
170
  """
171
171
  Estimate good contact parameters.
172
172
 
173
173
  Args:
174
174
  model: The model to consider.
175
+ standard_gravity: The standard gravity acceleration.
175
176
  static_friction_coefficient: The static friction coefficient.
176
- kwargs:
177
- Additional model-specific parameters passed to the builder method of
178
- the parameters class.
177
+ number_of_active_collidable_points_steady_state:
178
+ The number of active collidable points in steady state.
179
+ damping_ratio: The damping ratio.
180
+ max_penetration: The maximum penetration allowed.
179
181
 
180
182
  Returns:
181
183
  The estimated good contacts parameters.
@@ -190,20 +192,41 @@ def estimate_good_contact_parameters(
190
192
  specific application.
191
193
  """
192
194
 
193
- match model.contact_model:
195
+ def estimate_model_height(model: js.model.JaxSimModel) -> jtp.Float:
196
+ """
197
+ Displacement between the CoM and the lowest collidable point using zero
198
+ joint positions.
199
+ """
200
+
201
+ zero_data = js.data.JaxSimModelData.build(
202
+ model=model,
203
+ )
204
+
205
+ W_pz_CoM = js.com.com_position(model=model, data=zero_data)[2]
194
206
 
195
- case contacts.RelaxedRigidContacts():
196
- assert isinstance(model.contact_model, contacts.RelaxedRigidContacts)
207
+ if model.floating_base():
208
+ W_pz_C = collidable_point_positions(model=model, data=zero_data)[:, -1]
209
+ return 2 * (W_pz_CoM - W_pz_C.min())
197
210
 
198
- parameters = contacts.RelaxedRigidContactsParams.build(
199
- mu=static_friction_coefficient,
200
- **kwargs,
201
- )
211
+ return 2 * W_pz_CoM
202
212
 
203
- case _:
204
- raise ValueError(f"Invalid contact model: {model.contact_model}")
213
+ max_δ = (
214
+ max_penetration
215
+ if max_penetration is not None
216
+ # Consider as default a 0.5% of the model height.
217
+ else 0.005 * estimate_model_height(model=model)
218
+ )
205
219
 
206
- return parameters
220
+ nc = number_of_active_collidable_points_steady_state
221
+
222
+ return model.contact_model._parameters_class().build_default_from_jaxsim_model(
223
+ model=model,
224
+ standard_gravity=standard_gravity,
225
+ static_friction_coefficient=static_friction_coefficient,
226
+ max_penetration=max_δ,
227
+ number_of_active_collidable_points_steady_state=nc,
228
+ damping_ratio=damping_ratio,
229
+ )
207
230
 
208
231
 
209
232
  @jax.jit
@@ -244,7 +267,7 @@ def transforms(model: js.model.JaxSimModel, data: js.data.JaxSimModelData) -> jt
244
267
 
245
268
  # Build the link-to-point transform from the displacement between the link frame L
246
269
  # and the implicit contact frame C.
247
- L_H_C = jax.vmap(lambda L_p_C: jnp.eye(4).at[0:3, 3].set(L_p_C))(L_p_Ci)
270
+ L_H_C = jax.vmap(jnp.eye(4).at[0:3, 3].set)(L_p_Ci)
248
271
 
249
272
  # Compose the work-to-link and link-to-point transforms.
250
273
  return jax.vmap(lambda W_H_Li, L_H_Ci: W_H_Li @ L_H_Ci)(W_H_L, L_H_C)
@@ -504,3 +527,96 @@ def jacobian_derivative(
504
527
  )
505
528
 
506
529
  return O_J̇_WC
530
+
531
+
532
+ @jax.jit
533
+ @js.common.named_scope
534
+ def link_contact_forces(
535
+ model: js.model.JaxSimModel,
536
+ data: js.data.JaxSimModelData,
537
+ *,
538
+ link_forces: jtp.MatrixLike | None = None,
539
+ joint_torques: jtp.VectorLike | None = None,
540
+ ) -> tuple[jtp.Matrix, dict[str, jtp.Matrix]]:
541
+ """
542
+ Compute the 6D contact forces of all links of the model in inertial representation.
543
+
544
+ Args:
545
+ model: The model to consider.
546
+ data: The data of the considered model.
547
+ link_forces:
548
+ The 6D external forces to apply to the links expressed in inertial representation
549
+ joint_torques:
550
+ The joint torques acting on the joints.
551
+
552
+ Returns:
553
+ A `(nL, 6)` array containing the stacked 6D contact forces of the links,
554
+ expressed in inertial representation.
555
+ """
556
+
557
+ # Compute the contact forces for each collidable point with the active contact model.
558
+ W_f_C, aux_dict = model.contact_model.compute_contact_forces(
559
+ model=model,
560
+ data=data,
561
+ **(
562
+ dict(link_forces=link_forces, joint_force_references=joint_torques)
563
+ if not isinstance(model.contact_model, SoftContacts)
564
+ else {}
565
+ ),
566
+ )
567
+
568
+ # Compute the 6D forces applied to the links equivalent to the forces applied
569
+ # to the frames associated to the collidable points.
570
+ W_f_L = link_forces_from_contact_forces(model=model, contact_forces=W_f_C)
571
+
572
+ return W_f_L, aux_dict
573
+
574
+
575
+ @staticmethod
576
+ def link_forces_from_contact_forces(
577
+ model: js.model.JaxSimModel,
578
+ *,
579
+ contact_forces: jtp.MatrixLike,
580
+ ) -> jtp.Matrix:
581
+ """
582
+ Compute the link forces from the contact forces.
583
+
584
+ Args:
585
+ model: The robot model considered by the contact model.
586
+ contact_forces: The contact forces computed by the contact model.
587
+
588
+ Returns:
589
+ The 6D contact forces applied to the links and expressed in the frame of
590
+ the velocity representation of data.
591
+ """
592
+
593
+ # Get the object storing the contact parameters of the model.
594
+ contact_parameters = model.kin_dyn_parameters.contact_parameters
595
+
596
+ # Extract the indices corresponding to the enabled collidable points.
597
+ indices_of_enabled_collidable_points = (
598
+ contact_parameters.indices_of_enabled_collidable_points
599
+ )
600
+
601
+ # Convert the contact forces to a JAX array.
602
+ W_f_C = jnp.atleast_2d(jnp.array(contact_forces, dtype=float).squeeze())
603
+
604
+ # Construct the vector defining the parent link index of each collidable point.
605
+ # We use this vector to sum the 6D forces of all collidable points rigidly
606
+ # attached to the same link.
607
+ parent_link_index_of_collidable_points = jnp.array(
608
+ contact_parameters.body, dtype=int
609
+ )[indices_of_enabled_collidable_points]
610
+
611
+ # Create the mask that associate each collidable point to their parent link.
612
+ # We use this mask to sum the collidable points to the right link.
613
+ mask = parent_link_index_of_collidable_points[:, jnp.newaxis] == jnp.arange(
614
+ model.number_of_links()
615
+ )
616
+
617
+ # Sum the forces of all collidable points rigidly attached to a body.
618
+ # Since the contact forces W_f_C are expressed in the world frame,
619
+ # we don't need any coordinate transformation.
620
+ W_f_L = mask.T @ W_f_C
621
+
622
+ return W_f_L
@@ -5,9 +5,9 @@ import functools
5
5
  from collections.abc import Sequence
6
6
 
7
7
  try:
8
- from typing import override
8
+ from typing import Self, override
9
9
  except ImportError:
10
- from typing_extensions import override
10
+ from typing_extensions import override, Self
11
11
 
12
12
  import jax
13
13
  import jax.numpy as jnp
@@ -22,11 +22,6 @@ import jaxsim.typing as jtp
22
22
  from . import common
23
23
  from .common import VelRepr
24
24
 
25
- try:
26
- from typing import Self
27
- except ImportError:
28
- from typing_extensions import Self
29
-
30
25
 
31
26
  @jax_dataclasses.pytree_dataclass
32
27
  class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
@@ -64,6 +59,9 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
64
59
  _link_transforms: jtp.Matrix = dataclasses.field(repr=False, default=None)
65
60
  _link_velocities: jtp.Matrix = dataclasses.field(repr=False, default=None)
66
61
 
62
+ # Extended state for soft and rigid contact models.
63
+ contact_state: dict[str, jtp.Array] = dataclasses.field(default=None)
64
+
67
65
  @staticmethod
68
66
  def build(
69
67
  model: js.model.JaxSimModel,
@@ -73,6 +71,7 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
73
71
  base_linear_velocity: jtp.VectorLike | None = None,
74
72
  base_angular_velocity: jtp.VectorLike | None = None,
75
73
  joint_velocities: jtp.VectorLike | None = None,
74
+ contact_state: dict[str, jtp.Array] | None = None,
76
75
  velocity_representation: VelRepr = VelRepr.Mixed,
77
76
  ) -> JaxSimModelData:
78
77
  """
@@ -89,6 +88,7 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
89
88
  The base angular velocity in the selected representation.
90
89
  joint_velocities: The joint velocities.
91
90
  velocity_representation: The velocity representation to use. It defaults to mixed if not provided.
91
+ contact_state: The optional contact state.
92
92
 
93
93
  Returns:
94
94
  A `JaxSimModelData` initialized with the given state.
@@ -171,6 +171,14 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
171
171
  )
172
172
  )
173
173
 
174
+ contact_state = contact_state or {}
175
+
176
+ if isinstance(model.contact_model, jaxsim.rbda.contacts.SoftContacts):
177
+ contact_state.setdefault(
178
+ "tangential_deformation",
179
+ jnp.zeros_like(model.kin_dyn_parameters.contact_parameters.point),
180
+ )
181
+
174
182
  model_data = JaxSimModelData(
175
183
  velocity_representation=velocity_representation,
176
184
  _base_quaternion=base_quaternion,
@@ -183,6 +191,7 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
183
191
  _joint_transforms=joint_transforms,
184
192
  _link_transforms=link_transforms,
185
193
  _link_velocities=link_velocities_inertial,
194
+ contact_state=contact_state,
186
195
  )
187
196
 
188
197
  if not model_data.valid(model=model):
@@ -347,11 +356,14 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
347
356
 
348
357
  @js.common.named_scope
349
358
  @jax.jit
350
- def reset_base_quaternion(self, base_quaternion: jtp.VectorLike) -> Self:
359
+ def reset_base_quaternion(
360
+ self, model: js.model.JaxSimModel, base_quaternion: jtp.VectorLike
361
+ ) -> Self:
351
362
  """
352
363
  Reset the base quaternion.
353
364
 
354
365
  Args:
366
+ model: The JaxSim model to use.
355
367
  base_quaternion: The base orientation as a quaternion.
356
368
 
357
369
  Returns:
@@ -363,15 +375,18 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
363
375
  norm = jaxsim.math.safe_norm(W_Q_B, axis=-1)
364
376
  W_Q_B = W_Q_B / (norm + jnp.finfo(float).eps * (norm == 0))
365
377
 
366
- return self.replace(validate=True, base_quaternion=W_Q_B)
378
+ return self.replace(model=model, base_quaternion=W_Q_B)
367
379
 
368
380
  @js.common.named_scope
369
381
  @jax.jit
370
- def reset_base_pose(self, base_pose: jtp.MatrixLike) -> Self:
382
+ def reset_base_pose(
383
+ self, model: js.model.JaxSimModel, base_pose: jtp.MatrixLike
384
+ ) -> Self:
371
385
  """
372
386
  Reset the base pose.
373
387
 
374
388
  Args:
389
+ model: The JaxSim model to use.
375
390
  base_pose: The base pose as an SE(3) matrix.
376
391
 
377
392
  Returns:
@@ -382,6 +397,7 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
382
397
  W_p_B = base_pose[0:3, 3]
383
398
  W_Q_B = jaxsim.math.Quaternion.from_dcm(dcm=base_pose[0:3, 0:3])
384
399
  return self.replace(
400
+ model=model,
385
401
  base_position=W_p_B,
386
402
  base_quaternion=W_Q_B,
387
403
  )
@@ -396,6 +412,8 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
396
412
  base_linear_velocity: jtp.Vector | None = None,
397
413
  base_angular_velocity: jtp.Vector | None = None,
398
414
  base_position: jtp.Vector | None = None,
415
+ *,
416
+ contact_state: dict[str, jtp.Array] | None = None,
399
417
  validate: bool = False,
400
418
  ) -> Self:
401
419
  """
@@ -415,6 +433,14 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
415
433
  base_quaternion = self.base_quaternion
416
434
  if base_position is None:
417
435
  base_position = self.base_position
436
+ if contact_state is None:
437
+ contact_state = self.contact_state
438
+
439
+ if isinstance(model.contact_model, jaxsim.rbda.contacts.SoftContacts):
440
+ contact_state.setdefault(
441
+ "tangential_deformation",
442
+ jnp.zeros_like(model.kin_dyn_parameters.contact_parameters.point),
443
+ )
418
444
 
419
445
  # Normalize the quaternion to avoid numerical issues.
420
446
  base_quaternion_norm = jaxsim.math.safe_norm(
@@ -486,8 +512,9 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
486
512
 
487
513
  # Adjust the output shapes.
488
514
  if batch_size == 1:
489
- link_transforms = link_transforms.reshape(link_transforms.shape[1:])
490
- link_velocities = link_velocities.reshape(link_velocities.shape[1:])
515
+ link_transforms = link_transforms.reshape(self._link_transforms.shape)
516
+ link_velocities = link_velocities.reshape(self._link_velocities.shape)
517
+ joint_transforms = joint_transforms.reshape(self._joint_transforms.shape)
491
518
 
492
519
  return super().replace(
493
520
  _joint_positions=joint_positions,
@@ -22,7 +22,7 @@ def semi_implicit_euler_integration(
22
22
  with data.switch_velocity_representation(jaxsim.VelRepr.Inertial):
23
23
 
24
24
  # Compute the system acceleration
25
- W_v̇_WB, s̈ = js.ode.system_acceleration(
25
+ W_v̇_WB, s̈, contact_state_derivative = js.ode.system_acceleration(
26
26
  model=model,
27
27
  data=data,
28
28
  link_forces=link_forces,
@@ -64,6 +64,12 @@ def semi_implicit_euler_integration(
64
64
 
65
65
  s = data.joint_positions + dt * ṡ
66
66
 
67
+ integrated_contact_state = jax.tree.map(
68
+ lambda x, x_dot: x + dt * x_dot,
69
+ data.contact_state,
70
+ contact_state_derivative,
71
+ )
72
+
67
73
  # TODO: Avoid double replace, e.g. by computing cached value here
68
74
  data = dataclasses.replace(
69
75
  data,
@@ -73,6 +79,7 @@ def semi_implicit_euler_integration(
73
79
  _joint_velocities=ṡ,
74
80
  _base_linear_velocity=W_v_B[0:3],
75
81
  _base_angular_velocity=W_ω_WB,
82
+ contact_state=integrated_contact_state,
76
83
  )
77
84
 
78
85
  # Update the cached computations.
@@ -116,6 +123,7 @@ def rk4_integration(
116
123
  base_linear_velocity=data._base_linear_velocity,
117
124
  base_angular_velocity=data._base_angular_velocity,
118
125
  joint_velocities=data._joint_velocities,
126
+ contact_state=data.contact_state,
119
127
  )
120
128
 
121
129
  euler_mid = lambda x, dxdt: x + (0.5 * dt) * dxdt
@@ -136,14 +144,13 @@ def rk4_integration(
136
144
 
137
145
  data_tf = dataclasses.replace(
138
146
  data,
139
- **{
140
- "_base_position": x_tf["base_position"],
141
- "_base_quaternion": x_tf["base_quaternion"],
142
- "_joint_positions": x_tf["joint_positions"],
143
- "_base_linear_velocity": x_tf["base_linear_velocity"],
144
- "_base_angular_velocity": x_tf["base_angular_velocity"],
145
- "_joint_velocities": x_tf["joint_velocities"],
146
- },
147
+ _base_position=x_tf["base_position"],
148
+ _base_quaternion=x_tf["base_quaternion"],
149
+ _joint_positions=x_tf["joint_positions"],
150
+ _base_linear_velocity=x_tf["base_linear_velocity"],
151
+ _base_angular_velocity=x_tf["base_angular_velocity"],
152
+ _joint_velocities=x_tf["joint_velocities"],
153
+ contact_state=x_tf["contact_state"],
147
154
  )
148
155
 
149
156
  return data_tf.replace(model=model)
@@ -47,13 +47,13 @@ class JaxSimModel(JaxsimDataclass):
47
47
  default_factory=jaxsim.terrain.FlatTerrain.build, repr=False
48
48
  )
49
49
 
50
- gravity: Static[float] = jaxsim.math.STANDARD_GRAVITY
50
+ gravity: Static[float] = -jaxsim.math.STANDARD_GRAVITY
51
51
 
52
52
  contact_model: Static[jaxsim.rbda.contacts.ContactModel | None] = dataclasses.field(
53
53
  default=None, repr=False
54
54
  )
55
55
 
56
- contacts_params: Static[jaxsim.rbda.contacts.ContactsParams] = dataclasses.field(
56
+ contact_params: Static[jaxsim.rbda.contacts.ContactsParams] = dataclasses.field(
57
57
  default=None, repr=False
58
58
  )
59
59
 
@@ -177,9 +177,9 @@ class JaxSimModel(JaxsimDataclass):
177
177
  time_step=time_step,
178
178
  terrain=terrain,
179
179
  contact_model=contact_model,
180
- contacts_params=contact_params,
180
+ contact_params=contact_params,
181
181
  integrator=integrator,
182
- gravity=gravity,
182
+ gravity=-gravity,
183
183
  )
184
184
 
185
185
  # Store the origin of the model, in case downstream logic needs it.
@@ -197,7 +197,7 @@ class JaxSimModel(JaxsimDataclass):
197
197
  time_step: jtp.FloatLike | None = None,
198
198
  terrain: jaxsim.terrain.Terrain | None = None,
199
199
  contact_model: jaxsim.rbda.contacts.ContactModel | None = None,
200
- contacts_params: jaxsim.rbda.contacts.ContactsParams | None = None,
200
+ contact_params: jaxsim.rbda.contacts.ContactsParams | None = None,
201
201
  integrator: IntegratorType | None = None,
202
202
  gravity: jtp.FloatLike = jaxsim.math.STANDARD_GRAVITY,
203
203
  ) -> JaxSimModel:
@@ -217,8 +217,8 @@ class JaxSimModel(JaxsimDataclass):
217
217
  The optional name of the model overriding the physics model name.
218
218
  contact_model:
219
219
  The contact model to consider.
220
- If not specified, a soft contacts model is used.
221
- contacts_params: The parameters of the soft contacts.
220
+ If not specified, a relaxed-constraints rigid contacts model is used.
221
+ contact_params: The parameters of the contact model.
222
222
  integrator: The integrator to use for the simulation.
223
223
  gravity: The gravity constant.
224
224
 
@@ -252,8 +252,8 @@ class JaxSimModel(JaxsimDataclass):
252
252
  else jaxsim.rbda.contacts.RelaxedRigidContacts.build()
253
253
  )
254
254
 
255
- if contacts_params is None:
256
- contacts_params = contact_model._parameters_class()
255
+ if contact_params is None:
256
+ contact_params = contact_model._parameters_class()
257
257
 
258
258
  # Consider the default integrator if not specified.
259
259
  integrator = (
@@ -271,7 +271,7 @@ class JaxSimModel(JaxsimDataclass):
271
271
  time_step=time_step,
272
272
  terrain=terrain,
273
273
  contact_model=contact_model,
274
- contacts_params=contacts_params,
274
+ contact_params=contact_params,
275
275
  integrator=integrator,
276
276
  gravity=gravity,
277
277
  # The following is wrapped as hashless since it's a static argument, and we
@@ -473,7 +473,7 @@ def reduce(
473
473
  time_step=model.time_step,
474
474
  terrain=model.terrain,
475
475
  contact_model=model.contact_model,
476
- contacts_params=model.contacts_params,
476
+ contact_params=model.contact_params,
477
477
  gravity=model.gravity,
478
478
  integrator=model.integrator,
479
479
  )
@@ -2090,29 +2090,6 @@ def step(
2090
2090
  model, data, joint_force_references=τ_references
2091
2091
  )
2092
2092
 
2093
- # ======================
2094
- # Compute contact forces
2095
- # ======================
2096
-
2097
- W_f_L_terrain = jnp.zeros_like(W_f_L_external)
2098
-
2099
- if len(model.kin_dyn_parameters.contact_parameters.body) > 0:
2100
-
2101
- # Compute the 6D forces W_f ∈ ℝ^{n_L × 6} applied to links due to contact
2102
- # with the terrain.
2103
- W_f_L_terrain = js.contact_model.link_contact_forces(
2104
- model=model,
2105
- data=data,
2106
- link_forces=W_f_L_external,
2107
- joint_torques=τ_total,
2108
- )
2109
-
2110
- # ==============================
2111
- # Compute the total link forces
2112
- # ==============================
2113
-
2114
- W_f_L_total = W_f_L_external + W_f_L_terrain
2115
-
2116
2093
  # =============================
2117
2094
  # Advance the simulation state
2118
2095
  # =============================
@@ -2122,7 +2099,14 @@ def step(
2122
2099
  integrator_fn = _INTEGRATORS_MAP[model.integrator]
2123
2100
 
2124
2101
  data_tf = integrator_fn(
2125
- model=model, data=data, link_forces=W_f_L_total, joint_torques=τ_total
2102
+ model=model,
2103
+ data=data,
2104
+ link_forces=W_f_L_external,
2105
+ joint_torques=τ_total,
2106
+ )
2107
+
2108
+ data_tf = model.contact_model.update_velocity_after_impact(
2109
+ model=model, data=data_tf
2126
2110
  )
2127
2111
 
2128
2112
  return data_tf
@@ -46,12 +46,36 @@ def system_acceleration(
46
46
  else jnp.zeros((model.number_of_links(), 6))
47
47
  ).astype(float)
48
48
 
49
+ # ======================
50
+ # Compute contact forces
51
+ # ======================
52
+
53
+ W_f_L_terrain = jnp.zeros_like(f_L)
54
+ contact_state_derivative = {}
55
+
56
+ if len(model.kin_dyn_parameters.contact_parameters.body) > 0:
57
+
58
+ # Compute the 6D forces W_f ∈ ℝ^{n_L × 6} applied to links due to contact
59
+ # with the terrain.
60
+ W_f_L_terrain, contact_state_derivative = js.contact.link_contact_forces(
61
+ model=model,
62
+ data=data,
63
+ link_forces=f_L,
64
+ joint_torques=joint_torques,
65
+ )
66
+
67
+ W_f_L_total = f_L + W_f_L_terrain
68
+
69
+ # Update the contact state data. This is necessary only for the contact models
70
+ # that require propagation and integration of contact state.
71
+ contact_state = model.contact_model.update_contact_state(contact_state_derivative)
72
+
49
73
  # Store the link forces in a references object.
50
74
  references = js.references.JaxSimModelReferences.build(
51
75
  model=model,
52
76
  data=data,
53
77
  velocity_representation=data.velocity_representation,
54
- link_forces=f_L,
78
+ link_forces=W_f_L_total,
55
79
  )
56
80
 
57
81
  # Compute forward dynamics.
@@ -68,13 +92,12 @@ def system_acceleration(
68
92
  link_forces=references.link_forces(model=model, data=data),
69
93
  )
70
94
 
71
- return v̇_WB, s̈
95
+ return v̇_WB, s̈, contact_state
72
96
 
73
97
 
74
98
  @jax.jit
75
99
  @js.common.named_scope
76
100
  def system_position_dynamics(
77
- model: js.model.JaxSimModel,
78
101
  data: js.data.JaxSimModelData,
79
102
  baumgarte_quaternion_regularization: jtp.FloatLike = 1.0,
80
103
  ) -> tuple[jtp.Vector, jtp.Vector, jtp.Vector]:
@@ -82,7 +105,6 @@ def system_position_dynamics(
82
105
  Compute the dynamics of the system position.
83
106
 
84
107
  Args:
85
- model: The model to consider.
86
108
  data: The data of the considered model.
87
109
  baumgarte_quaternion_regularization:
88
110
  The Baumgarte regularization coefficient for adjusting the quaternion norm.
@@ -144,7 +166,7 @@ def system_dynamics(
144
166
  """
145
167
 
146
168
  with data.switch_velocity_representation(velocity_representation=VelRepr.Inertial):
147
- W_v̇_WB, s̈ = system_acceleration(
169
+ W_v̇_WB, s̈, contact_state_derivative = system_acceleration(
148
170
  model=model,
149
171
  data=data,
150
172
  joint_torques=joint_torques,
@@ -152,7 +174,6 @@ def system_dynamics(
152
174
  )
153
175
 
154
176
  W_ṗ_B, W_Q̇_B, ṡ = system_position_dynamics(
155
- model=model,
156
177
  data=data,
157
178
  baumgarte_quaternion_regularization=baumgarte_quaternion_regularization,
158
179
  )
@@ -164,4 +185,5 @@ def system_dynamics(
164
185
  base_linear_velocity=W_v̇_WB[0:3],
165
186
  base_angular_velocity=W_v̇_WB[3:6],
166
187
  joint_velocities=s̈,
188
+ contact_state=contact_state_derivative,
167
189
  )