jaxsim 0.4.3.dev229__tar.gz → 0.4.3.dev242__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (129)
  1. {jaxsim-0.4.3.dev229 → jaxsim-0.4.3.dev242}/PKG-INFO +1 -1
  2. {jaxsim-0.4.3.dev229 → jaxsim-0.4.3.dev242}/examples/jaxsim_as_multibody_dynamics_library.ipynb +2 -2
  3. {jaxsim-0.4.3.dev229 → jaxsim-0.4.3.dev242}/src/jaxsim/_version.py +2 -2
  4. {jaxsim-0.4.3.dev229 → jaxsim-0.4.3.dev242}/src/jaxsim/api/contact.py +48 -77
  5. {jaxsim-0.4.3.dev229 → jaxsim-0.4.3.dev242}/src/jaxsim/api/model.py +87 -59
  6. {jaxsim-0.4.3.dev229 → jaxsim-0.4.3.dev242}/src/jaxsim/api/ode.py +25 -34
  7. {jaxsim-0.4.3.dev229 → jaxsim-0.4.3.dev242}/src/jaxsim/rbda/contacts/common.py +137 -3
  8. {jaxsim-0.4.3.dev229 → jaxsim-0.4.3.dev242}/src/jaxsim/rbda/contacts/relaxed_rigid.py +48 -15
  9. {jaxsim-0.4.3.dev229 → jaxsim-0.4.3.dev242}/src/jaxsim/rbda/contacts/rigid.py +26 -9
  10. {jaxsim-0.4.3.dev229 → jaxsim-0.4.3.dev242}/src/jaxsim/rbda/contacts/soft.py +9 -5
  11. {jaxsim-0.4.3.dev229 → jaxsim-0.4.3.dev242}/src/jaxsim/rbda/contacts/visco_elastic.py +94 -52
  12. {jaxsim-0.4.3.dev229 → jaxsim-0.4.3.dev242}/src/jaxsim.egg-info/PKG-INFO +1 -1
  13. {jaxsim-0.4.3.dev229 → jaxsim-0.4.3.dev242}/.devcontainer/Dockerfile +0 -0
  14. {jaxsim-0.4.3.dev229 → jaxsim-0.4.3.dev242}/.devcontainer/devcontainer.json +0 -0
  15. {jaxsim-0.4.3.dev229 → jaxsim-0.4.3.dev242}/.gitattributes +0 -0
  16. {jaxsim-0.4.3.dev229 → jaxsim-0.4.3.dev242}/.github/CODEOWNERS +0 -0
  17. {jaxsim-0.4.3.dev229 → jaxsim-0.4.3.dev242}/.github/dependabot.yml +0 -0
  18. {jaxsim-0.4.3.dev229 → jaxsim-0.4.3.dev242}/.github/workflows/ci_cd.yml +0 -0
  19. {jaxsim-0.4.3.dev229 → jaxsim-0.4.3.dev242}/.github/workflows/read_the_docs.yml +0 -0
  20. {jaxsim-0.4.3.dev229 → jaxsim-0.4.3.dev242}/.github/workflows/update_pixi_lockfile.yml +0 -0
  21. {jaxsim-0.4.3.dev229 → jaxsim-0.4.3.dev242}/.gitignore +0 -0
  22. {jaxsim-0.4.3.dev229 → jaxsim-0.4.3.dev242}/.pre-commit-config.yaml +0 -0
  23. {jaxsim-0.4.3.dev229 → jaxsim-0.4.3.dev242}/.readthedocs.yaml +0 -0
  24. {jaxsim-0.4.3.dev229 → jaxsim-0.4.3.dev242}/CONTRIBUTING.md +0 -0
  25. {jaxsim-0.4.3.dev229 → jaxsim-0.4.3.dev242}/LICENSE +0 -0
  26. {jaxsim-0.4.3.dev229 → jaxsim-0.4.3.dev242}/README.md +0 -0
  27. {jaxsim-0.4.3.dev229 → jaxsim-0.4.3.dev242}/docs/Makefile +0 -0
  28. {jaxsim-0.4.3.dev229 → jaxsim-0.4.3.dev242}/docs/conf.py +0 -0
  29. {jaxsim-0.4.3.dev229 → jaxsim-0.4.3.dev242}/docs/examples.rst +0 -0
  30. {jaxsim-0.4.3.dev229 → jaxsim-0.4.3.dev242}/docs/guide/install.rst +0 -0
  31. {jaxsim-0.4.3.dev229 → jaxsim-0.4.3.dev242}/docs/index.rst +0 -0
  32. {jaxsim-0.4.3.dev229 → jaxsim-0.4.3.dev242}/docs/make.bat +0 -0
  33. {jaxsim-0.4.3.dev229 → jaxsim-0.4.3.dev242}/docs/modules/api.rst +0 -0
  34. {jaxsim-0.4.3.dev229 → jaxsim-0.4.3.dev242}/docs/modules/integrators.rst +0 -0
  35. {jaxsim-0.4.3.dev229 → jaxsim-0.4.3.dev242}/docs/modules/math.rst +0 -0
  36. {jaxsim-0.4.3.dev229 → jaxsim-0.4.3.dev242}/docs/modules/mujoco.rst +0 -0
  37. {jaxsim-0.4.3.dev229 → jaxsim-0.4.3.dev242}/docs/modules/parsers.rst +0 -0
  38. {jaxsim-0.4.3.dev229 → jaxsim-0.4.3.dev242}/docs/modules/rbda.rst +0 -0
  39. {jaxsim-0.4.3.dev229 → jaxsim-0.4.3.dev242}/docs/modules/typing.rst +0 -0
  40. {jaxsim-0.4.3.dev229 → jaxsim-0.4.3.dev242}/docs/modules/utils.rst +0 -0
  41. {jaxsim-0.4.3.dev229 → jaxsim-0.4.3.dev242}/environment.yml +0 -0
  42. {jaxsim-0.4.3.dev229 → jaxsim-0.4.3.dev242}/examples/.gitattributes +0 -0
  43. {jaxsim-0.4.3.dev229 → jaxsim-0.4.3.dev242}/examples/.gitignore +0 -0
  44. {jaxsim-0.4.3.dev229 → jaxsim-0.4.3.dev242}/examples/README.md +0 -0
  45. {jaxsim-0.4.3.dev229 → jaxsim-0.4.3.dev242}/examples/assets/build_cartpole_urdf.py +0 -0
  46. {jaxsim-0.4.3.dev229 → jaxsim-0.4.3.dev242}/examples/assets/cartpole.urdf +0 -0
  47. {jaxsim-0.4.3.dev229 → jaxsim-0.4.3.dev242}/examples/jaxsim_as_physics_engine.ipynb +0 -0
  48. {jaxsim-0.4.3.dev229 → jaxsim-0.4.3.dev242}/examples/jaxsim_for_robot_controllers.ipynb +0 -0
  49. {jaxsim-0.4.3.dev229 → jaxsim-0.4.3.dev242}/pixi.lock +0 -0
  50. {jaxsim-0.4.3.dev229 → jaxsim-0.4.3.dev242}/pyproject.toml +0 -0
  51. {jaxsim-0.4.3.dev229 → jaxsim-0.4.3.dev242}/setup.cfg +0 -0
  52. {jaxsim-0.4.3.dev229 → jaxsim-0.4.3.dev242}/setup.py +0 -0
  53. {jaxsim-0.4.3.dev229 → jaxsim-0.4.3.dev242}/src/jaxsim/__init__.py +0 -0
  54. {jaxsim-0.4.3.dev229 → jaxsim-0.4.3.dev242}/src/jaxsim/api/__init__.py +0 -0
  55. {jaxsim-0.4.3.dev229 → jaxsim-0.4.3.dev242}/src/jaxsim/api/com.py +0 -0
  56. {jaxsim-0.4.3.dev229 → jaxsim-0.4.3.dev242}/src/jaxsim/api/common.py +0 -0
  57. {jaxsim-0.4.3.dev229 → jaxsim-0.4.3.dev242}/src/jaxsim/api/data.py +0 -0
  58. {jaxsim-0.4.3.dev229 → jaxsim-0.4.3.dev242}/src/jaxsim/api/frame.py +0 -0
  59. {jaxsim-0.4.3.dev229 → jaxsim-0.4.3.dev242}/src/jaxsim/api/joint.py +0 -0
  60. {jaxsim-0.4.3.dev229 → jaxsim-0.4.3.dev242}/src/jaxsim/api/kin_dyn_parameters.py +0 -0
  61. {jaxsim-0.4.3.dev229 → jaxsim-0.4.3.dev242}/src/jaxsim/api/link.py +0 -0
  62. {jaxsim-0.4.3.dev229 → jaxsim-0.4.3.dev242}/src/jaxsim/api/ode_data.py +0 -0
  63. {jaxsim-0.4.3.dev229 → jaxsim-0.4.3.dev242}/src/jaxsim/api/references.py +0 -0
  64. {jaxsim-0.4.3.dev229 → jaxsim-0.4.3.dev242}/src/jaxsim/exceptions.py +0 -0
  65. {jaxsim-0.4.3.dev229 → jaxsim-0.4.3.dev242}/src/jaxsim/integrators/__init__.py +0 -0
  66. {jaxsim-0.4.3.dev229 → jaxsim-0.4.3.dev242}/src/jaxsim/integrators/common.py +0 -0
  67. {jaxsim-0.4.3.dev229 → jaxsim-0.4.3.dev242}/src/jaxsim/integrators/fixed_step.py +0 -0
  68. {jaxsim-0.4.3.dev229 → jaxsim-0.4.3.dev242}/src/jaxsim/integrators/variable_step.py +0 -0
  69. {jaxsim-0.4.3.dev229 → jaxsim-0.4.3.dev242}/src/jaxsim/logging.py +0 -0
  70. {jaxsim-0.4.3.dev229 → jaxsim-0.4.3.dev242}/src/jaxsim/math/__init__.py +0 -0
  71. {jaxsim-0.4.3.dev229 → jaxsim-0.4.3.dev242}/src/jaxsim/math/adjoint.py +0 -0
  72. {jaxsim-0.4.3.dev229 → jaxsim-0.4.3.dev242}/src/jaxsim/math/cross.py +0 -0
  73. {jaxsim-0.4.3.dev229 → jaxsim-0.4.3.dev242}/src/jaxsim/math/inertia.py +0 -0
  74. {jaxsim-0.4.3.dev229 → jaxsim-0.4.3.dev242}/src/jaxsim/math/joint_model.py +0 -0
  75. {jaxsim-0.4.3.dev229 → jaxsim-0.4.3.dev242}/src/jaxsim/math/quaternion.py +0 -0
  76. {jaxsim-0.4.3.dev229 → jaxsim-0.4.3.dev242}/src/jaxsim/math/rotation.py +0 -0
  77. {jaxsim-0.4.3.dev229 → jaxsim-0.4.3.dev242}/src/jaxsim/math/skew.py +0 -0
  78. {jaxsim-0.4.3.dev229 → jaxsim-0.4.3.dev242}/src/jaxsim/math/transform.py +0 -0
  79. {jaxsim-0.4.3.dev229 → jaxsim-0.4.3.dev242}/src/jaxsim/mujoco/__init__.py +0 -0
  80. {jaxsim-0.4.3.dev229 → jaxsim-0.4.3.dev242}/src/jaxsim/mujoco/__main__.py +0 -0
  81. {jaxsim-0.4.3.dev229 → jaxsim-0.4.3.dev242}/src/jaxsim/mujoco/loaders.py +0 -0
  82. {jaxsim-0.4.3.dev229 → jaxsim-0.4.3.dev242}/src/jaxsim/mujoco/model.py +0 -0
  83. {jaxsim-0.4.3.dev229 → jaxsim-0.4.3.dev242}/src/jaxsim/mujoco/utils.py +0 -0
  84. {jaxsim-0.4.3.dev229 → jaxsim-0.4.3.dev242}/src/jaxsim/mujoco/visualizer.py +0 -0
  85. {jaxsim-0.4.3.dev229 → jaxsim-0.4.3.dev242}/src/jaxsim/parsers/__init__.py +0 -0
  86. {jaxsim-0.4.3.dev229 → jaxsim-0.4.3.dev242}/src/jaxsim/parsers/descriptions/__init__.py +0 -0
  87. {jaxsim-0.4.3.dev229 → jaxsim-0.4.3.dev242}/src/jaxsim/parsers/descriptions/collision.py +0 -0
  88. {jaxsim-0.4.3.dev229 → jaxsim-0.4.3.dev242}/src/jaxsim/parsers/descriptions/joint.py +0 -0
  89. {jaxsim-0.4.3.dev229 → jaxsim-0.4.3.dev242}/src/jaxsim/parsers/descriptions/link.py +0 -0
  90. {jaxsim-0.4.3.dev229 → jaxsim-0.4.3.dev242}/src/jaxsim/parsers/descriptions/model.py +0 -0
  91. {jaxsim-0.4.3.dev229 → jaxsim-0.4.3.dev242}/src/jaxsim/parsers/kinematic_graph.py +0 -0
  92. {jaxsim-0.4.3.dev229 → jaxsim-0.4.3.dev242}/src/jaxsim/parsers/rod/__init__.py +0 -0
  93. {jaxsim-0.4.3.dev229 → jaxsim-0.4.3.dev242}/src/jaxsim/parsers/rod/parser.py +0 -0
  94. {jaxsim-0.4.3.dev229 → jaxsim-0.4.3.dev242}/src/jaxsim/parsers/rod/utils.py +0 -0
  95. {jaxsim-0.4.3.dev229 → jaxsim-0.4.3.dev242}/src/jaxsim/rbda/__init__.py +0 -0
  96. {jaxsim-0.4.3.dev229 → jaxsim-0.4.3.dev242}/src/jaxsim/rbda/aba.py +0 -0
  97. {jaxsim-0.4.3.dev229 → jaxsim-0.4.3.dev242}/src/jaxsim/rbda/collidable_points.py +0 -0
  98. {jaxsim-0.4.3.dev229 → jaxsim-0.4.3.dev242}/src/jaxsim/rbda/contacts/__init__.py +0 -0
  99. {jaxsim-0.4.3.dev229 → jaxsim-0.4.3.dev242}/src/jaxsim/rbda/crba.py +0 -0
  100. {jaxsim-0.4.3.dev229 → jaxsim-0.4.3.dev242}/src/jaxsim/rbda/forward_kinematics.py +0 -0
  101. {jaxsim-0.4.3.dev229 → jaxsim-0.4.3.dev242}/src/jaxsim/rbda/jacobian.py +0 -0
  102. {jaxsim-0.4.3.dev229 → jaxsim-0.4.3.dev242}/src/jaxsim/rbda/rnea.py +0 -0
  103. {jaxsim-0.4.3.dev229 → jaxsim-0.4.3.dev242}/src/jaxsim/rbda/utils.py +0 -0
  104. {jaxsim-0.4.3.dev229 → jaxsim-0.4.3.dev242}/src/jaxsim/terrain/__init__.py +0 -0
  105. {jaxsim-0.4.3.dev229 → jaxsim-0.4.3.dev242}/src/jaxsim/terrain/terrain.py +0 -0
  106. {jaxsim-0.4.3.dev229 → jaxsim-0.4.3.dev242}/src/jaxsim/typing.py +0 -0
  107. {jaxsim-0.4.3.dev229 → jaxsim-0.4.3.dev242}/src/jaxsim/utils/__init__.py +0 -0
  108. {jaxsim-0.4.3.dev229 → jaxsim-0.4.3.dev242}/src/jaxsim/utils/jaxsim_dataclass.py +0 -0
  109. {jaxsim-0.4.3.dev229 → jaxsim-0.4.3.dev242}/src/jaxsim/utils/tracing.py +0 -0
  110. {jaxsim-0.4.3.dev229 → jaxsim-0.4.3.dev242}/src/jaxsim/utils/wrappers.py +0 -0
  111. {jaxsim-0.4.3.dev229 → jaxsim-0.4.3.dev242}/src/jaxsim.egg-info/SOURCES.txt +0 -0
  112. {jaxsim-0.4.3.dev229 → jaxsim-0.4.3.dev242}/src/jaxsim.egg-info/dependency_links.txt +0 -0
  113. {jaxsim-0.4.3.dev229 → jaxsim-0.4.3.dev242}/src/jaxsim.egg-info/requires.txt +0 -0
  114. {jaxsim-0.4.3.dev229 → jaxsim-0.4.3.dev242}/src/jaxsim.egg-info/top_level.txt +0 -0
  115. {jaxsim-0.4.3.dev229 → jaxsim-0.4.3.dev242}/tests/__init__.py +0 -0
  116. {jaxsim-0.4.3.dev229 → jaxsim-0.4.3.dev242}/tests/conftest.py +0 -0
  117. {jaxsim-0.4.3.dev229 → jaxsim-0.4.3.dev242}/tests/test_api_com.py +0 -0
  118. {jaxsim-0.4.3.dev229 → jaxsim-0.4.3.dev242}/tests/test_api_contact.py +0 -0
  119. {jaxsim-0.4.3.dev229 → jaxsim-0.4.3.dev242}/tests/test_api_data.py +0 -0
  120. {jaxsim-0.4.3.dev229 → jaxsim-0.4.3.dev242}/tests/test_api_frame.py +0 -0
  121. {jaxsim-0.4.3.dev229 → jaxsim-0.4.3.dev242}/tests/test_api_joint.py +0 -0
  122. {jaxsim-0.4.3.dev229 → jaxsim-0.4.3.dev242}/tests/test_api_link.py +0 -0
  123. {jaxsim-0.4.3.dev229 → jaxsim-0.4.3.dev242}/tests/test_api_model.py +0 -0
  124. {jaxsim-0.4.3.dev229 → jaxsim-0.4.3.dev242}/tests/test_automatic_differentiation.py +0 -0
  125. {jaxsim-0.4.3.dev229 → jaxsim-0.4.3.dev242}/tests/test_contact.py +0 -0
  126. {jaxsim-0.4.3.dev229 → jaxsim-0.4.3.dev242}/tests/test_exceptions.py +0 -0
  127. {jaxsim-0.4.3.dev229 → jaxsim-0.4.3.dev242}/tests/test_pytree.py +0 -0
  128. {jaxsim-0.4.3.dev229 → jaxsim-0.4.3.dev242}/tests/test_simulations.py +0 -0
  129. {jaxsim-0.4.3.dev229 → jaxsim-0.4.3.dev242}/tests/utils_idyntree.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: jaxsim
3
- Version: 0.4.3.dev229
3
+ Version: 0.4.3.dev242
4
4
  Summary: A differentiable physics engine and multibody dynamics library for control and robot learning.
5
5
  Author-email: Diego Ferigo <dgferigo@gmail.com>
6
6
  Maintainer-email: Diego Ferigo <dgferigo@gmail.com>, Filippo Luca Ferretti <filippo.ferretti@iit.it>
@@ -728,7 +728,7 @@
728
728
  "Key methodologies in this area may involve the Delassus matrix:\n",
729
729
  "\n",
730
730
  "$$\n",
731
- "\\Psi(\\mathbf{q}) = J_{W,C}^T(\\mathbf{q}) \\, M(\\mathbf{q})^{-1} \\, J_{W,C}^T(\\mathbf{q})\n",
731
+ "\\Psi(\\mathbf{q}) = J_{W,C}(\\mathbf{q}) \\, M(\\mathbf{q})^{-1} \\, J_{W,C}^T(\\mathbf{q})\n",
732
732
  "$$\n",
733
733
  "\n",
734
734
  "or the linear acceleration of a contact point:\n",
@@ -780,7 +780,7 @@
780
780
  " J̇l_WC = js.contact.jacobian_derivative(model=model, data=data)[:, 0:3, :]\n",
781
781
  "\n",
782
782
  "# Compute the Delassus matrix.\n",
783
- "Ψ = jnp.vstack(Jl_WC) @ jnp.linalg.lstsq(BW_M, jnp.vstack(J̇l_WC).T)[0]\n",
783
+ "Ψ = jnp.vstack(Jl_WC) @ jnp.linalg.lstsq(BW_M, jnp.vstack(Jl_WC).T)[0]\n",
784
784
  "print(f\"Ψ: shape={Ψ.shape}\")\n",
785
785
  "\n",
786
786
  "# Compute the transforms of the mixed frames implicitly associated\n",
@@ -12,5 +12,5 @@ __version__: str
12
12
  __version_tuple__: VERSION_TUPLE
13
13
  version_tuple: VERSION_TUPLE
14
14
 
15
- __version__ = version = '0.4.3.dev229'
16
- __version_tuple__ = version_tuple = (0, 4, 3, 'dev229')
15
+ __version__ = version = '0.4.3.dev242'
16
+ __version_tuple__ = version_tuple = (0, 4, 3, 'dev242')
@@ -98,6 +98,7 @@ def collidable_point_forces(
98
98
  data: js.data.JaxSimModelData,
99
99
  link_forces: jtp.MatrixLike | None = None,
100
100
  joint_force_references: jtp.VectorLike | None = None,
101
+ **kwargs,
101
102
  ) -> jtp.Matrix:
102
103
  """
103
104
  Compute the 6D forces applied to each collidable point.
@@ -110,6 +111,7 @@ def collidable_point_forces(
110
111
  representation of data.
111
112
  joint_force_references:
112
113
  The joint force references to apply to the joints.
114
+ kwargs: Additional keyword arguments to pass to the active contact model.
113
115
 
114
116
  Returns:
115
117
  The 6D forces applied to each collidable point expressed in the frame
@@ -121,6 +123,7 @@ def collidable_point_forces(
121
123
  data=data,
122
124
  link_forces=link_forces,
123
125
  joint_force_references=joint_force_references,
126
+ **kwargs,
124
127
  )
125
128
 
126
129
  return f_Ci
@@ -132,7 +135,8 @@ def collidable_point_dynamics(
132
135
  data: js.data.JaxSimModelData,
133
136
  link_forces: jtp.MatrixLike | None = None,
134
137
  joint_force_references: jtp.VectorLike | None = None,
135
- ) -> tuple[jtp.Matrix, dict[str, jtp.Array]]:
138
+ **kwargs,
139
+ ) -> tuple[jtp.Matrix, dict[str, jtp.PyTree]]:
136
140
  r"""
137
141
  Compute the 6D force applied to each collidable point.
138
142
 
@@ -144,6 +148,7 @@ def collidable_point_dynamics(
144
148
  representation of data.
145
149
  joint_force_references:
146
150
  The joint force references to apply to the joints.
151
+ kwargs: Additional keyword arguments to pass to the active contact model.
147
152
 
148
153
  Returns:
149
154
  The 6D force applied to each collidable point and additional data based
@@ -158,86 +163,46 @@ def collidable_point_dynamics(
158
163
  Instead, the 6D forces are returned in the active representation.
159
164
  """
160
165
 
161
- # Build the soft contact model.
166
+ # Build the common kw arguments to pass to the computation of the contact forces.
167
+ common_kwargs = dict(
168
+ link_forces=link_forces,
169
+ joint_force_references=joint_force_references,
170
+ )
171
+
172
+ # Build the additional kwargs to pass to the computation of the contact forces.
162
173
  match model.contact_model:
163
174
 
164
175
  case contacts.SoftContacts():
165
- assert isinstance(model.contact_model, contacts.SoftContacts)
166
176
 
167
- # Compute the 6D force expressed in the inertial frame and applied to each
168
- # collidable point, and the corresponding material deformation rate.
169
- # Note that the material deformation rate is always returned in the mixed frame
170
- # C[W] = (W_p_C, [W]). This is convenient for integration purpose.
171
- W_f_Ci, (CW_ṁ,) = model.contact_model.compute_contact_forces(
172
- model=model, data=data
173
- )
174
-
175
- # Create the dictionary of auxiliary data.
176
- # This contact model considers the material deformation as additional state
177
- # of the ODE system. We need to pass its dynamics to the integrator.
178
- aux_data = dict(m_dot=CW_ṁ)
177
+ kwargs_contact_model = {}
179
178
 
180
179
  case contacts.RigidContacts():
181
- assert isinstance(model.contact_model, contacts.RigidContacts)
182
180
 
183
- # Compute the 6D force expressed in the inertial frame and applied to each
184
- # collidable point.
185
- W_f_Ci, _ = model.contact_model.compute_contact_forces(
186
- model=model,
187
- data=data,
188
- link_forces=link_forces,
189
- joint_force_references=joint_force_references,
190
- )
191
-
192
- aux_data = dict()
181
+ kwargs_contact_model = common_kwargs | kwargs
193
182
 
194
183
  case contacts.RelaxedRigidContacts():
195
- assert isinstance(model.contact_model, contacts.RelaxedRigidContacts)
196
184
 
197
- # Compute the 6D force expressed in the inertial frame and applied to each
198
- # collidable point.
199
- W_f_Ci, _ = model.contact_model.compute_contact_forces(
200
- model=model,
201
- data=data,
202
- link_forces=link_forces,
203
- joint_force_references=joint_force_references,
204
- )
205
-
206
- aux_data = dict()
185
+ kwargs_contact_model = common_kwargs | kwargs
207
186
 
208
187
  case contacts.ViscoElasticContacts():
209
- assert isinstance(model.contact_model, contacts.ViscoElasticContacts)
210
188
 
211
- # It is not yet clear how to pass the time step to this stage.
212
- # A possibility is to restrict the integrator to only forward Euler
213
- # and store the Δt inside the model.
214
- module = jaxsim.rbda.contacts.visco_elastic.step.__module__
215
- name = jaxsim.rbda.contacts.visco_elastic.step.__name__
216
- msg = "You need to use the custom '{}.{}' function with this contact model."
217
- jaxsim.exceptions.raise_runtime_error_if(
218
- condition=True, msg=msg.format(module, name)
219
- )
220
-
221
- # Compute the 6D force expressed in the inertial frame and applied to each
222
- # collidable point.
223
- W_f_Ci, (W_f̿_Ci, m_tf) = model.contact_model.compute_contact_forces(
224
- model=model,
225
- data=data,
226
- dt=None, # TODO
227
- link_forces=link_forces,
228
- joint_force_references=joint_force_references,
229
- )
230
-
231
- aux_data = dict(W_f_avg2_C=W_f̿_Ci, m_tf=m_tf)
189
+ kwargs_contact_model = common_kwargs | dict(dt=model.time_step) | kwargs
232
190
 
233
191
  case _:
234
- raise ValueError(f"Invalid contact model {model.contact_model}")
192
+ raise ValueError(f"Invalid contact model: {model.contact_model}")
193
+
194
+ # Compute the contact forces with the active contact model.
195
+ W_f_C, aux_data = model.contact_model.compute_contact_forces(
196
+ model=model,
197
+ data=data,
198
+ **kwargs_contact_model,
199
+ )
235
200
 
236
201
  # Compute the transforms of the implicit frames `C[L] = (W_p_C, [L])`
237
202
  # associated to each collidable point.
238
203
  # In inertial-fixed representation, the computation of these transforms
239
204
  # is not necessary and the conversion below becomes a no-op.
240
- W_H_Ci = (
205
+ W_H_C = (
241
206
  js.contact.transforms(model=model, data=data)
242
207
  if data.velocity_representation is not VelRepr.Inertial
243
208
  else jnp.zeros(
@@ -253,7 +218,7 @@ def collidable_point_dynamics(
253
218
  transform=W_H_C,
254
219
  is_force=True,
255
220
  )
256
- )(W_f_Ci, W_H_Ci)
221
+ )(W_f_C, W_H_C)
257
222
 
258
223
  return f_Ci, aux_data
259
224
 
@@ -392,11 +357,13 @@ def estimate_good_contact_parameters(
392
357
  max_penetration=max_δ,
393
358
  number_of_active_collidable_points_steady_state=nc,
394
359
  damping_ratio=damping_ratio,
395
- **dict(
396
- p=model.contact_model.parameters.p,
397
- q=model.contact_model.parameters.q,
398
- )
399
- | kwargs,
360
+ **(
361
+ dict(
362
+ p=model.contact_model.parameters.p,
363
+ q=model.contact_model.parameters.q,
364
+ )
365
+ | kwargs
366
+ ),
400
367
  )
401
368
 
402
369
  case contacts.ViscoElasticContacts():
@@ -410,11 +377,13 @@ def estimate_good_contact_parameters(
410
377
  max_penetration=max_δ,
411
378
  number_of_active_collidable_points_steady_state=nc,
412
379
  damping_ratio=damping_ratio,
413
- **dict(
414
- p=model.contact_model.parameters.p,
415
- q=model.contact_model.parameters.q,
416
- )
417
- | kwargs,
380
+ **(
381
+ dict(
382
+ p=model.contact_model.parameters.p,
383
+ q=model.contact_model.parameters.q,
384
+ )
385
+ | kwargs
386
+ ),
418
387
  )
419
388
  )
420
389
 
@@ -427,11 +396,13 @@ def estimate_good_contact_parameters(
427
396
 
428
397
  parameters = contacts.RigidContactsParams.build(
429
398
  mu=static_friction_coefficient,
430
- **dict(
431
- K=K,
432
- D=2 * jnp.sqrt(K),
433
- )
434
- | kwargs,
399
+ **(
400
+ dict(
401
+ K=K,
402
+ D=2 * jnp.sqrt(K),
403
+ )
404
+ | kwargs
405
+ ),
435
406
  )
436
407
 
437
408
  case contacts.RelaxedRigidContacts():
@@ -1770,8 +1770,10 @@ def link_bias_accelerations(
1770
1770
  def link_contact_forces(
1771
1771
  model: js.model.JaxSimModel,
1772
1772
  data: js.data.JaxSimModelData,
1773
+ *,
1773
1774
  link_forces: jtp.MatrixLike | None = None,
1774
1775
  joint_force_references: jtp.VectorLike | None = None,
1776
+ **kwargs,
1775
1777
  ) -> jtp.Matrix:
1776
1778
  """
1777
1779
  Compute the 6D contact forces of all links of the model.
@@ -1784,6 +1786,7 @@ def link_contact_forces(
1784
1786
  representation of data.
1785
1787
  joint_force_references:
1786
1788
  The joint force references to apply to the joints.
1789
+ kwargs: Additional keyword arguments to pass to the active contact model..
1787
1790
 
1788
1791
  Returns:
1789
1792
  A `(nL, 6)` array containing the stacked 6D contact forces of the links,
@@ -1820,47 +1823,16 @@ def link_contact_forces(
1820
1823
  joint_force_references=joint_force_references,
1821
1824
  )
1822
1825
 
1823
- # Compute the 6D forces applied to each collidable point expressed in the
1824
- # inertial frame.
1825
- with (
1826
- data.switch_velocity_representation(VelRepr.Inertial),
1827
- input_references.switch_velocity_representation(VelRepr.Inertial),
1828
- ):
1829
- W_f_C = js.contact.collidable_point_forces(
1830
- model=model,
1831
- data=data,
1832
- link_forces=input_references.link_forces(),
1833
- joint_force_references=input_references.joint_force_references(),
1834
- )
1835
-
1836
- # Construct the vector defining the parent link index of each collidable point.
1837
- # We use this vector to sum the 6D forces of all collidable points rigidly
1838
- # attached to the same link.
1839
- parent_link_index_of_collidable_points = jnp.array(
1840
- model.kin_dyn_parameters.contact_parameters.body, dtype=int
1841
- )
1842
-
1843
- # Create the mask that associate each collidable point to their parent link.
1844
- # We use this mask to sum the collidable points to the right link.
1845
- mask = parent_link_index_of_collidable_points[:, jnp.newaxis] == jnp.arange(
1846
- model.number_of_links()
1847
- )
1848
-
1849
- # Sum the forces of all collidable points rigidly attached to a body.
1850
- # Since the contact forces W_f_C are expressed in the world frame,
1851
- # we don't need any coordinate transformation.
1852
- W_f_L = mask.T @ W_f_C
1853
-
1854
- # Create a references object to store the link forces.
1855
- references = js.references.JaxSimModelReferences.build(
1856
- model=model, link_forces=W_f_L, velocity_representation=VelRepr.Inertial
1826
+ # Compute the 6D forces applied to the links equivalent to the forces applied
1827
+ # to the frames associated to the collidable points.
1828
+ f_L, _ = model.contact_model.compute_link_contact_forces(
1829
+ model=model,
1830
+ data=data,
1831
+ link_forces=input_references.link_forces(model=model, data=data),
1832
+ joint_force_references=input_references.joint_force_references(),
1833
+ **kwargs,
1857
1834
  )
1858
1835
 
1859
- # Use the references object to convert the link forces to the velocity
1860
- # representation of data.
1861
- with references.switch_velocity_representation(data.velocity_representation):
1862
- f_L = references.link_forces(model=model, data=data)
1863
-
1864
1836
  return f_L
1865
1837
 
1866
1838
 
@@ -1967,6 +1939,11 @@ def step(
1967
1939
  Returns:
1968
1940
  A tuple containing the new data of the model
1969
1941
  and the new state of the integrator.
1942
+
1943
+ Note:
1944
+ In order to reduce the occurrences of frame conversions performed internally,
1945
+ it is recommended to use inertial-fixed velocity representation. This can be
1946
+ particularly useful for automatically differentiated logic.
1970
1947
  """
1971
1948
 
1972
1949
  # Extract the integrator kwargs.
@@ -1976,15 +1953,61 @@ def step(
1976
1953
  integrator_kwargs = kwargs.pop("integrator_kwargs", {})
1977
1954
  integrator_kwargs = kwargs | integrator_kwargs
1978
1955
 
1979
- integrator_state = integrator_state if integrator_state is not None else dict()
1956
+ # Initialize the integrator state.
1957
+ integrator_state_t0 = integrator_state if integrator_state is not None else dict()
1980
1958
 
1981
1959
  # Initialize the time-related variables.
1982
1960
  state_t0 = data.state
1983
1961
  t0 = jnp.array(t0, dtype=float)
1984
1962
  dt = jnp.array(dt if dt is not None else model.time_step).astype(float)
1985
1963
 
1986
- # Rename the integrator state.
1987
- integrator_state_t0 = integrator_state
1964
+ # The visco-elastic contacts operate at best with their own integrator.
1965
+ # They can be used with Euler-like integrators, paying the price of ignoring
1966
+ # some of the benefits of continuous-time integration on the system position.
1967
+ # Furthermore, the requirement to know the Δt used by the integrator is not
1968
+ # compatible with high-order integrators, that use advanced RK stages to evaluate
1969
+ # the dynamics at intermediate times.
1970
+ module = jaxsim.rbda.contacts.visco_elastic.step.__module__
1971
+ name = jaxsim.rbda.contacts.visco_elastic.step.__name__
1972
+ msg = "You need to use the custom '{}.{}' function with this contact model."
1973
+ jaxsim.exceptions.raise_runtime_error_if(
1974
+ condition=(
1975
+ isinstance(model.contact_model, jaxsim.rbda.contacts.ViscoElasticContacts)
1976
+ & (
1977
+ ~jnp.allclose(dt, model.time_step)
1978
+ | ~isinstance(integrator, jaxsim.integrators.fixed_step.ForwardEuler)
1979
+ )
1980
+ ),
1981
+ msg=msg.format(module, name),
1982
+ )
1983
+
1984
+ # =================
1985
+ # Phase 1: pre-step
1986
+ # =================
1987
+
1988
+ # TODO: some contact models here may want to perform a dynamic filtering of
1989
+ # the enabled collidable points.
1990
+
1991
+ # Build the references object.
1992
+ # We assume that the link forces are expressed in the frame corresponding to the
1993
+ # velocity representation of the data.
1994
+ references = js.references.JaxSimModelReferences.build(
1995
+ model=model,
1996
+ data=data,
1997
+ velocity_representation=data.velocity_representation,
1998
+ link_forces=link_forces,
1999
+ joint_force_references=joint_force_references,
2000
+ )
2001
+
2002
+ # =============
2003
+ # Phase 2: step
2004
+ # =============
2005
+
2006
+ # Prepare the references to pass.
2007
+ with references.switch_velocity_representation(data.velocity_representation):
2008
+
2009
+ f_L = references.link_forces(model=model, data=data)
2010
+ τ_references = references.joint_force_references(model=model)
1988
2011
 
1989
2012
  # Step the dynamics forward.
1990
2013
  state_tf, integrator_state_tf = integrator.step(
@@ -1994,7 +2017,7 @@ def step(
1994
2017
  params=integrator_state_t0,
1995
2018
  # Always inject the current (model, data) pair into the system dynamics
1996
2019
  # considered by the integrator, and include the input variables represented
1997
- # by the pair (joint_force_references, link_forces).
2020
+ # by the pair (f_L, τ_references).
1998
2021
  # Note that the wrapper of the system dynamics will override (state_x0, t0)
1999
2022
  # inside the passed data even if it is not strictly needed. This logic is
2000
2023
  # necessary to re-use the jit-compiled step function of compatible pytrees
@@ -2003,8 +2026,8 @@ def step(
2003
2026
  dict(
2004
2027
  model=model,
2005
2028
  data=data,
2006
- joint_force_references=joint_force_references,
2007
- link_forces=link_forces,
2029
+ link_forces=f_L,
2030
+ joint_force_references=τ_references,
2008
2031
  )
2009
2032
  | integrator_kwargs
2010
2033
  ),
@@ -2013,6 +2036,10 @@ def step(
2013
2036
  # Store the new state of the model.
2014
2037
  data_tf = data.replace(state=state_tf)
2015
2038
 
2039
+ # ==================
2040
+ # Phase 3: post-step
2041
+ # ==================
2042
+
2016
2043
  # Post process the simulation state, if needed.
2017
2044
  match model.contact_model:
2018
2045
 
@@ -2040,17 +2067,18 @@ def step(
2040
2067
  msg="Baumgarte stabilization is not supported with ForwardEuler integrators",
2041
2068
  )
2042
2069
 
2070
+ W_p_C = js.contact.collidable_point_positions(model, data_tf)
2071
+
2072
+ # Compute the penetration depth of the collidable points.
2073
+ δ, *_ = jax.vmap(
2074
+ jaxsim.rbda.contacts.common.compute_penetration_data,
2075
+ in_axes=(0, 0, None),
2076
+ )(W_p_C, jnp.zeros_like(W_p_C), model.terrain)
2077
+
2043
2078
  with data_tf.switch_velocity_representation(VelRepr.Mixed):
2044
2079
 
2045
2080
  J_WC = js.contact.jacobian(model, data_tf)
2046
2081
  M = js.model.free_floating_mass_matrix(model, data_tf)
2047
- W_p_C = js.contact.collidable_point_positions(model, data_tf)
2048
-
2049
- # Compute the penetration depth of the collidable points.
2050
- δ, *_ = jax.vmap(
2051
- jaxsim.rbda.contacts.common.compute_penetration_data,
2052
- in_axes=(0, 0, None),
2053
- )(W_p_C, jnp.zeros_like(W_p_C), model.terrain)
2054
2082
 
2055
2083
  # Compute the impact velocity.
2056
2084
  # It may be discontinuous in case new contacts are made.
@@ -2063,13 +2091,13 @@ def step(
2063
2091
  )
2064
2092
  )
2065
2093
 
2066
- # Reset the generalized velocity.
2067
- data_tf = data_tf.reset_base_velocity(BW_nu_post_impact[0:6])
2068
- data_tf = data_tf.reset_joint_velocities(BW_nu_post_impact[6:])
2094
+ # Reset the generalized velocity.
2095
+ data_tf = data_tf.reset_base_velocity(BW_nu_post_impact[0:6])
2096
+ data_tf = data_tf.reset_joint_velocities(BW_nu_post_impact[6:])
2069
2097
 
2070
- # Restore the input velocity representation.
2071
- data_tf = data_tf.replace(
2072
- velocity_representation=data.velocity_representation, validate=False
2073
- )
2098
+ # Restore the input velocity representation.
2099
+ data_tf = data_tf.replace(
2100
+ velocity_representation=data.velocity_representation, validate=False
2101
+ )
2074
2102
 
2075
2103
  return data_tf, integrator_state_tf
@@ -131,7 +131,7 @@ def system_velocity_dynamics(
131
131
 
132
132
  # Initialize the 6D forces W_f ∈ ℝ^{n_L × 6} applied to links due to contact
133
133
  # with the terrain.
134
- W_f_Li_terrain = jnp.zeros_like(O_f_L).astype(float)
134
+ W_f_L_terrain = jnp.zeros_like(O_f_L).astype(float)
135
135
 
136
136
  # Initialize a dictionary of auxiliary data.
137
137
  # This dictionary is used to store additional data computed by the contact model.
@@ -139,66 +139,59 @@ def system_velocity_dynamics(
139
139
 
140
140
  if len(model.kin_dyn_parameters.contact_parameters.body) > 0:
141
141
 
142
- # Note: the following code should be kept in sync with the function
143
- # `jaxsim.api.model.link_contact_forces`. We cannot merge them since
144
- # here we need to get also aux_data.
145
-
146
- # Compute the 6D forces W_f ∈ ℝ^{n_c × 6} applied to each collidable point
147
- # along with contact-specific auxiliary states.
148
142
  with (
149
143
  data.switch_velocity_representation(VelRepr.Inertial),
150
144
  references.switch_velocity_representation(VelRepr.Inertial),
151
145
  ):
152
- W_f_Ci, aux_data = js.contact.collidable_point_dynamics(
146
+
147
+ # Compute the 6D forces W_f ∈ ℝ^{n_c × 6} applied to each collidable point
148
+ # along with contact-specific auxiliary states.
149
+ W_f_C, aux_data = js.contact.collidable_point_dynamics(
153
150
  model=model,
154
151
  data=data,
155
152
  link_forces=references.link_forces(model=model, data=data),
156
153
  joint_force_references=references.joint_force_references(model=model),
157
154
  )
158
155
 
159
- # Construct the vector defining the parent link index of each collidable point.
160
- # We use this vector to sum the 6D forces of all collidable points rigidly
161
- # attached to the same link.
162
- parent_link_index_of_collidable_points = jnp.array(
163
- model.kin_dyn_parameters.contact_parameters.body, dtype=int
164
- )
165
-
166
- # Sum the forces of all collidable points rigidly attached to a body.
167
- # Since the contact forces W_f_Ci are expressed in the world frame,
168
- # we don't need any coordinate transformation.
169
- mask = parent_link_index_of_collidable_points[:, jnp.newaxis] == jnp.arange(
170
- model.number_of_links()
171
- )
172
-
173
- W_f_Li_terrain = mask.T @ W_f_Ci
156
+ # Compute the 6D forces applied to the links equivalent to the forces applied
157
+ # to the frames associated to the collidable points.
158
+ W_f_L_terrain = model.contact_model.link_forces_from_contact_forces(
159
+ model=model,
160
+ data=data,
161
+ contact_forces=W_f_C,
162
+ )
174
163
 
175
164
  # ===========================
176
165
  # Compute system acceleration
177
166
  # ===========================
178
167
 
179
- # Compute the total link forces
168
+ # Compute the total link forces.
180
169
  with (
181
170
  data.switch_velocity_representation(VelRepr.Inertial),
182
171
  references.switch_velocity_representation(VelRepr.Inertial),
183
172
  ):
173
+
174
+ # Sum the contact forces just computed with the link forces applied by the user.
184
175
  references = references.apply_link_forces(
185
176
  model=model,
186
177
  data=data,
187
- forces=W_f_Li_terrain,
178
+ forces=W_f_L_terrain,
188
179
  additive=True,
189
180
  )
190
181
 
191
- # Get the link forces in inertial representation
182
+ # Get the link forces in inertial-fixed representation.
192
183
  f_L_total = references.link_forces(model=model, data=data)
193
184
 
194
- v̇_WB, = system_acceleration(
185
+ # Compute the system acceleration in inertial-fixed representation.
186
+ # This representation is useful for integration purpose.
187
+ W_v̇_WB, s̈ = system_acceleration(
195
188
  model=model,
196
189
  data=data,
197
190
  joint_force_references=joint_force_references,
198
191
  link_forces=f_L_total,
199
192
  )
200
193
 
201
- return v̇_WB, s̈, aux_data
194
+ return W_v̇_WB, s̈, aux_data
202
195
 
203
196
 
204
197
  def system_acceleration(
@@ -390,17 +383,15 @@ def system_dynamics(
390
383
 
391
384
  case contacts.ViscoElasticContacts():
392
385
 
393
- extended_ode_state["contacts_state"] = {
394
- "tangential_deformation": jnp.zeros_like(
395
- data.state.extended["tangential_deformation"]
396
- )
397
- }
386
+ extended_ode_state["tangential_deformation"] = jnp.zeros_like(
387
+ data.state.extended["tangential_deformation"]
388
+ )
398
389
 
399
390
  case contacts.RigidContacts() | contacts.RelaxedRigidContacts():
400
391
  pass
401
392
 
402
393
  case _:
403
- raise ValueError(f"Invalid contact model {model.contact_model}")
394
+ raise ValueError(f"Invalid contact model: {model.contact_model}")
404
395
 
405
396
  # Extract the velocities.
406
397
  W_ṗ_B, W_Q̇_B, ṡ = system_position_dynamics(