jaxsim 0.4.3.dev143__tar.gz → 0.4.3.dev155__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (125)
  1. {jaxsim-0.4.3.dev143 → jaxsim-0.4.3.dev155}/PKG-INFO +1 -1
  2. {jaxsim-0.4.3.dev143 → jaxsim-0.4.3.dev155}/src/jaxsim/_version.py +2 -2
  3. {jaxsim-0.4.3.dev143 → jaxsim-0.4.3.dev155}/src/jaxsim/api/contact.py +3 -12
  4. {jaxsim-0.4.3.dev143 → jaxsim-0.4.3.dev155}/src/jaxsim/api/data.py +62 -44
  5. {jaxsim-0.4.3.dev143 → jaxsim-0.4.3.dev155}/src/jaxsim/api/model.py +28 -17
  6. {jaxsim-0.4.3.dev143 → jaxsim-0.4.3.dev155}/src/jaxsim/api/ode.py +9 -7
  7. {jaxsim-0.4.3.dev143 → jaxsim-0.4.3.dev155}/src/jaxsim/api/ode_data.py +42 -57
  8. jaxsim-0.4.3.dev155/src/jaxsim/rbda/contacts/__init__.py +5 -0
  9. {jaxsim-0.4.3.dev143 → jaxsim-0.4.3.dev155}/src/jaxsim/rbda/contacts/common.py +42 -35
  10. {jaxsim-0.4.3.dev143 → jaxsim-0.4.3.dev155}/src/jaxsim/rbda/contacts/relaxed_rigid.py +32 -26
  11. {jaxsim-0.4.3.dev143 → jaxsim-0.4.3.dev155}/src/jaxsim/rbda/contacts/rigid.py +31 -25
  12. {jaxsim-0.4.3.dev143 → jaxsim-0.4.3.dev155}/src/jaxsim/rbda/contacts/soft.py +59 -133
  13. {jaxsim-0.4.3.dev143 → jaxsim-0.4.3.dev155}/src/jaxsim/terrain/terrain.py +1 -1
  14. {jaxsim-0.4.3.dev143 → jaxsim-0.4.3.dev155}/src/jaxsim.egg-info/PKG-INFO +1 -1
  15. {jaxsim-0.4.3.dev143 → jaxsim-0.4.3.dev155}/tests/test_automatic_differentiation.py +4 -7
  16. jaxsim-0.4.3.dev143/src/jaxsim/rbda/contacts/__init__.py +0 -9
  17. {jaxsim-0.4.3.dev143 → jaxsim-0.4.3.dev155}/.devcontainer/Dockerfile +0 -0
  18. {jaxsim-0.4.3.dev143 → jaxsim-0.4.3.dev155}/.devcontainer/devcontainer.json +0 -0
  19. {jaxsim-0.4.3.dev143 → jaxsim-0.4.3.dev155}/.gitattributes +0 -0
  20. {jaxsim-0.4.3.dev143 → jaxsim-0.4.3.dev155}/.github/CODEOWNERS +0 -0
  21. {jaxsim-0.4.3.dev143 → jaxsim-0.4.3.dev155}/.github/workflows/ci_cd.yml +0 -0
  22. {jaxsim-0.4.3.dev143 → jaxsim-0.4.3.dev155}/.github/workflows/read_the_docs.yml +0 -0
  23. {jaxsim-0.4.3.dev143 → jaxsim-0.4.3.dev155}/.github/workflows/update_pixi_lockfile.yml +0 -0
  24. {jaxsim-0.4.3.dev143 → jaxsim-0.4.3.dev155}/.gitignore +0 -0
  25. {jaxsim-0.4.3.dev143 → jaxsim-0.4.3.dev155}/.pre-commit-config.yaml +0 -0
  26. {jaxsim-0.4.3.dev143 → jaxsim-0.4.3.dev155}/.readthedocs.yaml +0 -0
  27. {jaxsim-0.4.3.dev143 → jaxsim-0.4.3.dev155}/CONTRIBUTING.md +0 -0
  28. {jaxsim-0.4.3.dev143 → jaxsim-0.4.3.dev155}/LICENSE +0 -0
  29. {jaxsim-0.4.3.dev143 → jaxsim-0.4.3.dev155}/README.md +0 -0
  30. {jaxsim-0.4.3.dev143 → jaxsim-0.4.3.dev155}/docs/Makefile +0 -0
  31. {jaxsim-0.4.3.dev143 → jaxsim-0.4.3.dev155}/docs/conf.py +0 -0
  32. {jaxsim-0.4.3.dev143 → jaxsim-0.4.3.dev155}/docs/examples.rst +0 -0
  33. {jaxsim-0.4.3.dev143 → jaxsim-0.4.3.dev155}/docs/guide/install.rst +0 -0
  34. {jaxsim-0.4.3.dev143 → jaxsim-0.4.3.dev155}/docs/index.rst +0 -0
  35. {jaxsim-0.4.3.dev143 → jaxsim-0.4.3.dev155}/docs/make.bat +0 -0
  36. {jaxsim-0.4.3.dev143 → jaxsim-0.4.3.dev155}/docs/modules/api.rst +0 -0
  37. {jaxsim-0.4.3.dev143 → jaxsim-0.4.3.dev155}/docs/modules/integrators.rst +0 -0
  38. {jaxsim-0.4.3.dev143 → jaxsim-0.4.3.dev155}/docs/modules/math.rst +0 -0
  39. {jaxsim-0.4.3.dev143 → jaxsim-0.4.3.dev155}/docs/modules/mujoco.rst +0 -0
  40. {jaxsim-0.4.3.dev143 → jaxsim-0.4.3.dev155}/docs/modules/parsers.rst +0 -0
  41. {jaxsim-0.4.3.dev143 → jaxsim-0.4.3.dev155}/docs/modules/rbda.rst +0 -0
  42. {jaxsim-0.4.3.dev143 → jaxsim-0.4.3.dev155}/docs/modules/typing.rst +0 -0
  43. {jaxsim-0.4.3.dev143 → jaxsim-0.4.3.dev155}/docs/modules/utils.rst +0 -0
  44. {jaxsim-0.4.3.dev143 → jaxsim-0.4.3.dev155}/environment.yml +0 -0
  45. {jaxsim-0.4.3.dev143 → jaxsim-0.4.3.dev155}/examples/.gitattributes +0 -0
  46. {jaxsim-0.4.3.dev143 → jaxsim-0.4.3.dev155}/examples/.gitignore +0 -0
  47. {jaxsim-0.4.3.dev143 → jaxsim-0.4.3.dev155}/examples/PD_controller.ipynb +0 -0
  48. {jaxsim-0.4.3.dev143 → jaxsim-0.4.3.dev155}/examples/Parallel_computing.ipynb +0 -0
  49. {jaxsim-0.4.3.dev143 → jaxsim-0.4.3.dev155}/examples/README.md +0 -0
  50. {jaxsim-0.4.3.dev143 → jaxsim-0.4.3.dev155}/examples/assets/cartpole.urdf +0 -0
  51. {jaxsim-0.4.3.dev143 → jaxsim-0.4.3.dev155}/pixi.lock +0 -0
  52. {jaxsim-0.4.3.dev143 → jaxsim-0.4.3.dev155}/pyproject.toml +0 -0
  53. {jaxsim-0.4.3.dev143 → jaxsim-0.4.3.dev155}/setup.cfg +0 -0
  54. {jaxsim-0.4.3.dev143 → jaxsim-0.4.3.dev155}/setup.py +0 -0
  55. {jaxsim-0.4.3.dev143 → jaxsim-0.4.3.dev155}/src/jaxsim/__init__.py +0 -0
  56. {jaxsim-0.4.3.dev143 → jaxsim-0.4.3.dev155}/src/jaxsim/api/__init__.py +0 -0
  57. {jaxsim-0.4.3.dev143 → jaxsim-0.4.3.dev155}/src/jaxsim/api/com.py +0 -0
  58. {jaxsim-0.4.3.dev143 → jaxsim-0.4.3.dev155}/src/jaxsim/api/common.py +0 -0
  59. {jaxsim-0.4.3.dev143 → jaxsim-0.4.3.dev155}/src/jaxsim/api/frame.py +0 -0
  60. {jaxsim-0.4.3.dev143 → jaxsim-0.4.3.dev155}/src/jaxsim/api/joint.py +0 -0
  61. {jaxsim-0.4.3.dev143 → jaxsim-0.4.3.dev155}/src/jaxsim/api/kin_dyn_parameters.py +0 -0
  62. {jaxsim-0.4.3.dev143 → jaxsim-0.4.3.dev155}/src/jaxsim/api/link.py +0 -0
  63. {jaxsim-0.4.3.dev143 → jaxsim-0.4.3.dev155}/src/jaxsim/api/references.py +0 -0
  64. {jaxsim-0.4.3.dev143 → jaxsim-0.4.3.dev155}/src/jaxsim/exceptions.py +0 -0
  65. {jaxsim-0.4.3.dev143 → jaxsim-0.4.3.dev155}/src/jaxsim/integrators/__init__.py +0 -0
  66. {jaxsim-0.4.3.dev143 → jaxsim-0.4.3.dev155}/src/jaxsim/integrators/common.py +0 -0
  67. {jaxsim-0.4.3.dev143 → jaxsim-0.4.3.dev155}/src/jaxsim/integrators/fixed_step.py +0 -0
  68. {jaxsim-0.4.3.dev143 → jaxsim-0.4.3.dev155}/src/jaxsim/integrators/variable_step.py +0 -0
  69. {jaxsim-0.4.3.dev143 → jaxsim-0.4.3.dev155}/src/jaxsim/logging.py +0 -0
  70. {jaxsim-0.4.3.dev143 → jaxsim-0.4.3.dev155}/src/jaxsim/math/__init__.py +0 -0
  71. {jaxsim-0.4.3.dev143 → jaxsim-0.4.3.dev155}/src/jaxsim/math/adjoint.py +0 -0
  72. {jaxsim-0.4.3.dev143 → jaxsim-0.4.3.dev155}/src/jaxsim/math/cross.py +0 -0
  73. {jaxsim-0.4.3.dev143 → jaxsim-0.4.3.dev155}/src/jaxsim/math/inertia.py +0 -0
  74. {jaxsim-0.4.3.dev143 → jaxsim-0.4.3.dev155}/src/jaxsim/math/joint_model.py +0 -0
  75. {jaxsim-0.4.3.dev143 → jaxsim-0.4.3.dev155}/src/jaxsim/math/quaternion.py +0 -0
  76. {jaxsim-0.4.3.dev143 → jaxsim-0.4.3.dev155}/src/jaxsim/math/rotation.py +0 -0
  77. {jaxsim-0.4.3.dev143 → jaxsim-0.4.3.dev155}/src/jaxsim/math/skew.py +0 -0
  78. {jaxsim-0.4.3.dev143 → jaxsim-0.4.3.dev155}/src/jaxsim/math/transform.py +0 -0
  79. {jaxsim-0.4.3.dev143 → jaxsim-0.4.3.dev155}/src/jaxsim/mujoco/__init__.py +0 -0
  80. {jaxsim-0.4.3.dev143 → jaxsim-0.4.3.dev155}/src/jaxsim/mujoco/__main__.py +0 -0
  81. {jaxsim-0.4.3.dev143 → jaxsim-0.4.3.dev155}/src/jaxsim/mujoco/loaders.py +0 -0
  82. {jaxsim-0.4.3.dev143 → jaxsim-0.4.3.dev155}/src/jaxsim/mujoco/model.py +0 -0
  83. {jaxsim-0.4.3.dev143 → jaxsim-0.4.3.dev155}/src/jaxsim/mujoco/visualizer.py +0 -0
  84. {jaxsim-0.4.3.dev143 → jaxsim-0.4.3.dev155}/src/jaxsim/parsers/__init__.py +0 -0
  85. {jaxsim-0.4.3.dev143 → jaxsim-0.4.3.dev155}/src/jaxsim/parsers/descriptions/__init__.py +0 -0
  86. {jaxsim-0.4.3.dev143 → jaxsim-0.4.3.dev155}/src/jaxsim/parsers/descriptions/collision.py +0 -0
  87. {jaxsim-0.4.3.dev143 → jaxsim-0.4.3.dev155}/src/jaxsim/parsers/descriptions/joint.py +0 -0
  88. {jaxsim-0.4.3.dev143 → jaxsim-0.4.3.dev155}/src/jaxsim/parsers/descriptions/link.py +0 -0
  89. {jaxsim-0.4.3.dev143 → jaxsim-0.4.3.dev155}/src/jaxsim/parsers/descriptions/model.py +0 -0
  90. {jaxsim-0.4.3.dev143 → jaxsim-0.4.3.dev155}/src/jaxsim/parsers/kinematic_graph.py +0 -0
  91. {jaxsim-0.4.3.dev143 → jaxsim-0.4.3.dev155}/src/jaxsim/parsers/rod/__init__.py +0 -0
  92. {jaxsim-0.4.3.dev143 → jaxsim-0.4.3.dev155}/src/jaxsim/parsers/rod/parser.py +0 -0
  93. {jaxsim-0.4.3.dev143 → jaxsim-0.4.3.dev155}/src/jaxsim/parsers/rod/utils.py +0 -0
  94. {jaxsim-0.4.3.dev143 → jaxsim-0.4.3.dev155}/src/jaxsim/rbda/__init__.py +0 -0
  95. {jaxsim-0.4.3.dev143 → jaxsim-0.4.3.dev155}/src/jaxsim/rbda/aba.py +0 -0
  96. {jaxsim-0.4.3.dev143 → jaxsim-0.4.3.dev155}/src/jaxsim/rbda/collidable_points.py +0 -0
  97. {jaxsim-0.4.3.dev143 → jaxsim-0.4.3.dev155}/src/jaxsim/rbda/crba.py +0 -0
  98. {jaxsim-0.4.3.dev143 → jaxsim-0.4.3.dev155}/src/jaxsim/rbda/forward_kinematics.py +0 -0
  99. {jaxsim-0.4.3.dev143 → jaxsim-0.4.3.dev155}/src/jaxsim/rbda/jacobian.py +0 -0
  100. {jaxsim-0.4.3.dev143 → jaxsim-0.4.3.dev155}/src/jaxsim/rbda/rnea.py +0 -0
  101. {jaxsim-0.4.3.dev143 → jaxsim-0.4.3.dev155}/src/jaxsim/rbda/utils.py +0 -0
  102. {jaxsim-0.4.3.dev143 → jaxsim-0.4.3.dev155}/src/jaxsim/terrain/__init__.py +0 -0
  103. {jaxsim-0.4.3.dev143 → jaxsim-0.4.3.dev155}/src/jaxsim/typing.py +0 -0
  104. {jaxsim-0.4.3.dev143 → jaxsim-0.4.3.dev155}/src/jaxsim/utils/__init__.py +0 -0
  105. {jaxsim-0.4.3.dev143 → jaxsim-0.4.3.dev155}/src/jaxsim/utils/jaxsim_dataclass.py +0 -0
  106. {jaxsim-0.4.3.dev143 → jaxsim-0.4.3.dev155}/src/jaxsim/utils/tracing.py +0 -0
  107. {jaxsim-0.4.3.dev143 → jaxsim-0.4.3.dev155}/src/jaxsim/utils/wrappers.py +0 -0
  108. {jaxsim-0.4.3.dev143 → jaxsim-0.4.3.dev155}/src/jaxsim.egg-info/SOURCES.txt +0 -0
  109. {jaxsim-0.4.3.dev143 → jaxsim-0.4.3.dev155}/src/jaxsim.egg-info/dependency_links.txt +0 -0
  110. {jaxsim-0.4.3.dev143 → jaxsim-0.4.3.dev155}/src/jaxsim.egg-info/requires.txt +0 -0
  111. {jaxsim-0.4.3.dev143 → jaxsim-0.4.3.dev155}/src/jaxsim.egg-info/top_level.txt +0 -0
  112. {jaxsim-0.4.3.dev143 → jaxsim-0.4.3.dev155}/tests/__init__.py +0 -0
  113. {jaxsim-0.4.3.dev143 → jaxsim-0.4.3.dev155}/tests/conftest.py +0 -0
  114. {jaxsim-0.4.3.dev143 → jaxsim-0.4.3.dev155}/tests/test_api_com.py +0 -0
  115. {jaxsim-0.4.3.dev143 → jaxsim-0.4.3.dev155}/tests/test_api_contact.py +0 -0
  116. {jaxsim-0.4.3.dev143 → jaxsim-0.4.3.dev155}/tests/test_api_data.py +0 -0
  117. {jaxsim-0.4.3.dev143 → jaxsim-0.4.3.dev155}/tests/test_api_frame.py +0 -0
  118. {jaxsim-0.4.3.dev143 → jaxsim-0.4.3.dev155}/tests/test_api_joint.py +0 -0
  119. {jaxsim-0.4.3.dev143 → jaxsim-0.4.3.dev155}/tests/test_api_link.py +0 -0
  120. {jaxsim-0.4.3.dev143 → jaxsim-0.4.3.dev155}/tests/test_api_model.py +0 -0
  121. {jaxsim-0.4.3.dev143 → jaxsim-0.4.3.dev155}/tests/test_contact.py +0 -0
  122. {jaxsim-0.4.3.dev143 → jaxsim-0.4.3.dev155}/tests/test_exceptions.py +0 -0
  123. {jaxsim-0.4.3.dev143 → jaxsim-0.4.3.dev155}/tests/test_pytree.py +0 -0
  124. {jaxsim-0.4.3.dev143 → jaxsim-0.4.3.dev155}/tests/test_simulations.py +0 -0
  125. {jaxsim-0.4.3.dev143 → jaxsim-0.4.3.dev155}/tests/utils_idyntree.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: jaxsim
3
- Version: 0.4.3.dev143
3
+ Version: 0.4.3.dev155
4
4
  Summary: A differentiable physics engine and multibody dynamics library for control and robot learning.
5
5
  Author-email: Diego Ferigo <dgferigo@gmail.com>
6
6
  Maintainer-email: Diego Ferigo <dgferigo@gmail.com>, Filippo Luca Ferretti <filippo.ferretti@iit.it>
@@ -12,5 +12,5 @@ __version__: str
12
12
  __version_tuple__: VERSION_TUPLE
13
13
  version_tuple: VERSION_TUPLE
14
14
 
15
- __version__ = version = '0.4.3.dev143'
16
- __version_tuple__ = version_tuple = (0, 4, 3, 'dev143')
15
+ __version__ = version = '0.4.3.dev155'
16
+ __version_tuple__ = version_tuple = (0, 4, 3, 'dev155')
@@ -144,7 +144,8 @@ def collidable_point_dynamics(
144
144
  The joint force references to apply to the joints.
145
145
 
146
146
  Returns:
147
- The 6D force applied to each collidable point and additional data based on the contact model configured:
147
+ The 6D force applied to each collidable point and additional data based
148
+ on the contact model configured:
148
149
  - Soft: the material deformation rate.
149
150
  - Rigid: no additional data.
150
151
  - QuasiRigid: no additional data.
@@ -156,21 +157,13 @@ def collidable_point_dynamics(
156
157
  """
157
158
 
158
159
  # Import privately the contacts classes.
159
- from jaxsim.rbda.contacts import (
160
- RelaxedRigidContacts,
161
- RelaxedRigidContactsState,
162
- RigidContacts,
163
- RigidContactsState,
164
- SoftContacts,
165
- SoftContactsState,
166
- )
160
+ from jaxsim.rbda.contacts import RelaxedRigidContacts, RigidContacts, SoftContacts
167
161
 
168
162
  # Build the soft contact model.
169
163
  match model.contact_model:
170
164
 
171
165
  case SoftContacts():
172
166
  assert isinstance(model.contact_model, SoftContacts)
173
- assert isinstance(data.state.contact, SoftContactsState)
174
167
 
175
168
  # Compute the 6D force expressed in the inertial frame and applied to each
176
169
  # collidable point, and the corresponding material deformation rate.
@@ -187,7 +180,6 @@ def collidable_point_dynamics(
187
180
 
188
181
  case RigidContacts():
189
182
  assert isinstance(model.contact_model, RigidContacts)
190
- assert isinstance(data.state.contact, RigidContactsState)
191
183
 
192
184
  # Compute the 6D force expressed in the inertial frame and applied to each
193
185
  # collidable point.
@@ -203,7 +195,6 @@ def collidable_point_dynamics(
203
195
 
204
196
  case RelaxedRigidContacts():
205
197
  assert isinstance(model.contact_model, RelaxedRigidContacts)
206
- assert isinstance(data.state.contact, RelaxedRigidContactsState)
207
198
 
208
199
  # Compute the 6D force expressed in the inertial frame and applied to each
209
200
  # collidable point.
@@ -13,7 +13,6 @@ import jaxsim.api as js
13
13
  import jaxsim.math
14
14
  import jaxsim.rbda
15
15
  import jaxsim.typing as jtp
16
- from jaxsim.rbda.contacts import SoftContacts
17
16
  from jaxsim.utils import Mutability
18
17
  from jaxsim.utils.tracing import not_tracing
19
18
 
@@ -107,17 +106,17 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
107
106
  @staticmethod
108
107
  def build(
109
108
  model: js.model.JaxSimModel,
110
- base_position: jtp.Vector | None = None,
111
- base_quaternion: jtp.Vector | None = None,
112
- joint_positions: jtp.Vector | None = None,
113
- base_linear_velocity: jtp.Vector | None = None,
114
- base_angular_velocity: jtp.Vector | None = None,
115
- joint_velocities: jtp.Vector | None = None,
109
+ base_position: jtp.VectorLike | None = None,
110
+ base_quaternion: jtp.VectorLike | None = None,
111
+ joint_positions: jtp.VectorLike | None = None,
112
+ base_linear_velocity: jtp.VectorLike | None = None,
113
+ base_angular_velocity: jtp.VectorLike | None = None,
114
+ joint_velocities: jtp.VectorLike | None = None,
116
115
  standard_gravity: jtp.FloatLike = jaxsim.math.StandardGravity,
117
- contact: jaxsim.rbda.contacts.ContactsState | None = None,
118
116
  contacts_params: jaxsim.rbda.contacts.ContactsParams | None = None,
119
117
  velocity_representation: VelRepr = VelRepr.Inertial,
120
118
  time: jtp.FloatLike | None = None,
119
+ extended_ode_state: dict[str, jtp.PyTree] | None = None,
121
120
  ) -> JaxSimModelData:
122
121
  """
123
122
  Create a `JaxSimModelData` object with the given state.
@@ -133,56 +132,73 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
133
132
  The base angular velocity in the selected representation.
134
133
  joint_velocities: The joint velocities.
135
134
  standard_gravity: The standard gravity constant.
136
- contact: The state of the soft contacts.
137
135
  contacts_params: The parameters of the soft contacts.
138
136
  velocity_representation: The velocity representation to use.
139
137
  time: The time at which the state is created.
138
+ extended_ode_state:
139
+ Additional user-defined state variables that are not part of the
140
+ standard `ODEState` object. Useful to extend the system dynamics
141
+ considered by default in JaxSim.
140
142
 
141
143
  Returns:
142
- A `JaxSimModelData` object with the given state.
144
+ A `JaxSimModelData` initialized with the given state.
143
145
  """
144
146
 
145
147
  base_position = jnp.array(
146
- base_position if base_position is not None else jnp.zeros(3)
148
+ base_position if base_position is not None else jnp.zeros(3),
149
+ dtype=float,
147
150
  ).squeeze()
148
151
 
149
152
  base_quaternion = jnp.array(
150
- base_quaternion
151
- if base_quaternion is not None
152
- else jnp.array([1.0, 0, 0, 0])
153
+ (
154
+ base_quaternion
155
+ if base_quaternion is not None
156
+ else jnp.array([1.0, 0, 0, 0])
157
+ ),
158
+ dtype=float,
153
159
  ).squeeze()
154
160
 
155
161
  base_linear_velocity = jnp.array(
156
- base_linear_velocity if base_linear_velocity is not None else jnp.zeros(3)
162
+ base_linear_velocity if base_linear_velocity is not None else jnp.zeros(3),
163
+ dtype=float,
157
164
  ).squeeze()
158
165
 
159
166
  base_angular_velocity = jnp.array(
160
- base_angular_velocity if base_angular_velocity is not None else jnp.zeros(3)
167
+ (
168
+ base_angular_velocity
169
+ if base_angular_velocity is not None
170
+ else jnp.zeros(3)
171
+ ),
172
+ dtype=float,
161
173
  ).squeeze()
162
174
 
163
175
  gravity = jnp.zeros(3).at[2].set(-standard_gravity)
164
176
 
165
177
  joint_positions = jnp.atleast_1d(
166
- joint_positions.squeeze()
167
- if joint_positions is not None
168
- else jnp.zeros(model.dofs())
178
+ jnp.array(
179
+ (
180
+ joint_positions
181
+ if joint_positions is not None
182
+ else jnp.zeros(model.dofs())
183
+ ),
184
+ dtype=float,
185
+ ).squeeze()
169
186
  )
170
187
 
171
188
  joint_velocities = jnp.atleast_1d(
172
- joint_velocities.squeeze()
173
- if joint_velocities is not None
174
- else jnp.zeros(model.dofs())
189
+ jnp.array(
190
+ (
191
+ joint_velocities
192
+ if joint_velocities is not None
193
+ else jnp.zeros(model.dofs())
194
+ ),
195
+ dtype=float,
196
+ ).squeeze()
175
197
  )
176
198
 
177
- time_ns = (
178
- jnp.array(
179
- time * 1e9,
180
- dtype=jnp.uint64 if jax.config.read("jax_enable_x64") else jnp.uint32,
181
- )
182
- if time is not None
183
- else jnp.array(
184
- 0, dtype=jnp.uint64 if jax.config.read("jax_enable_x64") else jnp.uint32
185
- )
199
+ time_ns = jnp.array(
200
+ time * 1e9 if time is not None else 0.0,
201
+ dtype=jnp.uint64 if jax.config.read("jax_enable_x64") else jnp.uint32,
186
202
  )
187
203
 
188
204
  W_H_B = jaxsim.math.Transform.from_quaternion_and_translation(
@@ -194,21 +210,22 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
194
210
  other_representation=velocity_representation,
195
211
  transform=W_H_B,
196
212
  is_force=False,
197
- )
213
+ ).astype(float)
198
214
 
199
215
  ode_state = ODEState.build_from_jaxsim_model(
200
216
  model=model,
201
- base_position=base_position.astype(float),
202
- base_quaternion=base_quaternion.astype(float),
203
- joint_positions=joint_positions.astype(float),
204
- base_linear_velocity=v_WB[0:3].astype(float),
205
- base_angular_velocity=v_WB[3:6].astype(float),
206
- joint_velocities=joint_velocities.astype(float),
207
- tangential_deformation=(
208
- contact.tangential_deformation
209
- if contact is not None and isinstance(model.contact_model, SoftContacts)
210
- else None
211
- ),
217
+ base_position=base_position,
218
+ base_quaternion=base_quaternion,
219
+ joint_positions=joint_positions,
220
+ base_linear_velocity=v_WB[0:3],
221
+ base_angular_velocity=v_WB[3:6],
222
+ joint_velocities=joint_velocities,
223
+ # Unpack all the additional ODE states. If the contact model requires an
224
+ # additional state that is not explicitly passed to this builder, ODEState
225
+ # automatically populates that state with zeroed variables.
226
+ # This is not true for any other custom state that the user might want to
227
+ # pass to the integrator.
228
+ **(extended_ode_state if extended_ode_state else {}),
212
229
  )
213
230
 
214
231
  if not ode_state.valid(model=model):
@@ -220,13 +237,14 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
220
237
  contacts_params = js.contact.estimate_good_soft_contacts_parameters(
221
238
  model=model, standard_gravity=standard_gravity
222
239
  )
240
+
223
241
  else:
224
242
  contacts_params = model.contact_model.parameters
225
243
 
226
244
  return JaxSimModelData(
227
245
  time_ns=time_ns,
228
246
  state=ode_state,
229
- gravity=gravity.astype(float),
247
+ gravity=gravity,
230
248
  contacts_params=contacts_params,
231
249
  velocity_representation=velocity_representation,
232
250
  )
@@ -33,7 +33,7 @@ class JaxSimModel(JaxsimDataclass):
33
33
  model_name: Static[str]
34
34
 
35
35
  terrain: Static[jaxsim.terrain.Terrain] = dataclasses.field(
36
- default=jaxsim.terrain.FlatTerrain(), repr=False
36
+ default_factory=jaxsim.terrain.FlatTerrain.build, repr=False
37
37
  )
38
38
 
39
39
  contact_model: jaxsim.rbda.contacts.ContactModel | None = dataclasses.field(
@@ -101,13 +101,14 @@ class JaxSimModel(JaxsimDataclass):
101
101
  A path to an SDF/URDF file, a string containing
102
102
  its content, or a pre-parsed/pre-built rod model.
103
103
  model_name:
104
- The optional name of the model that overrides the one in
105
- the description.
106
- terrain:
107
- The optional terrain to consider.
104
+ The name of the model. If not specified, it is read from the description.
105
+ terrain: The terrain to consider (the default is a flat infinite plane).
106
+ contact_model:
107
+ The contact model to consider.
108
+ If not specified, a soft contacts model is used.
108
109
  is_urdf:
109
- The optional flag to force the model description to be parsed as a
110
- URDF or a SDF. This is otherwise automatically inferred.
110
+ The optional flag to force the model description to be parsed as a URDF.
111
+ This is usually automatically inferred.
111
112
  considered_joints:
112
113
  The list of joints to consider. If None, all joints are considered.
113
114
 
@@ -120,7 +121,7 @@ class JaxSimModel(JaxsimDataclass):
120
121
  # Parse the input resource (either a path to file or a string with the URDF/SDF)
121
122
  # and build the -intermediate- model description.
122
123
  intermediate_description = jaxsim.parsers.rod.build_model_description(
123
- model_description=model_description
124
+ model_description=model_description, is_urdf=is_urdf
124
125
  )
125
126
 
126
127
  # Lump links together if not all joints are considered.
@@ -160,11 +161,11 @@ class JaxSimModel(JaxsimDataclass):
160
161
  The intermediate model description defining the kinematics and dynamics
161
162
  of the model.
162
163
  model_name:
163
- The optional name of the model overriding the physics model name.
164
- terrain:
165
- The optional terrain to consider.
164
+ The name of the model. If not specified, it is read from the description.
165
+ terrain: The terrain to consider (the default is a flat infinite plane).
166
166
  contact_model:
167
- The optional contact model to consider. If None, the soft contact model is used.
167
+ The contact model to consider.
168
+ If not specified, a soft contacts model is used.
168
169
 
169
170
  Returns:
170
171
  The built Model object.
@@ -173,21 +174,31 @@ class JaxSimModel(JaxsimDataclass):
173
174
  # Set the model name (if not provided, use the one from the model description).
174
175
  model_name = model_name if model_name is not None else model_description.name
175
176
 
176
- # Set the terrain (if not provided, use the default flat terrain).
177
- terrain = terrain or JaxSimModel.__dataclass_fields__["terrain"].default
178
- contact_model = contact_model or jaxsim.rbda.contacts.SoftContacts(
179
- terrain=terrain
177
+ # Consider the default terrain (a flat infinite plane) if not specified.
178
+ terrain = (
179
+ terrain or JaxSimModel.__dataclass_fields__["terrain"].default_factory()
180
+ )
181
+
182
+ # Create the default contact model.
183
+ # It will be populated with an initial estimation of good parameters.
184
+ # While these might not be the best, they are a good starting point.
185
+ contact_model = contact_model or jaxsim.rbda.contacts.SoftContacts.build(
186
+ terrain=terrain, parameters=None
180
187
  )
181
188
 
182
189
  # Build the model.
183
190
  model = JaxSimModel(
184
191
  model_name=model_name,
185
- _description=wrappers.HashlessObject(obj=model_description),
186
192
  kin_dyn_parameters=js.kin_dyn_parameters.KynDynParameters.build(
187
193
  model_description=model_description
188
194
  ),
189
195
  terrain=terrain,
190
196
  contact_model=contact_model,
197
+ # The following is wrapped as hashless since it's a static argument, and we
198
+ # don't want to trigger recompilation if it changes. All relevant parameters
199
+ # needed to compute kinematics and dynamics quantities are stored in the
200
+ # kin_dyn_parameters attribute.
201
+ _description=wrappers.HashlessObject(obj=model_description),
191
202
  )
192
203
 
193
204
  return model
@@ -370,9 +370,8 @@ def system_dynamics(
370
370
  corresponding derivative, and the dictionary of auxiliary data returned
371
371
  by the system dynamics evaluation.
372
372
  """
373
- from jaxsim.rbda.contacts.relaxed_rigid import RelaxedRigidContacts
374
- from jaxsim.rbda.contacts.rigid import RigidContacts
375
- from jaxsim.rbda.contacts.soft import SoftContacts
373
+
374
+ from jaxsim.rbda.contacts import RelaxedRigidContacts, RigidContacts, SoftContacts
376
375
 
377
376
  # Compute the accelerations and the material deformation rate.
378
377
  W_v̇_WB, s̈, aux_dict = system_velocity_dynamics(
@@ -382,17 +381,20 @@ def system_dynamics(
382
381
  link_forces=link_forces,
383
382
  )
384
383
 
385
- ode_state_kwargs = {}
384
+ # Initialize the dictionary storing the derivative of the additional state variables
385
+ # that extend the state vector of the integrated ODE system.
386
+ extended_ode_state = {}
386
387
 
387
388
  match model.contact_model:
389
+
388
390
  case SoftContacts():
389
- ode_state_kwargs["tangential_deformation"] = aux_dict["m_dot"]
391
+ extended_ode_state["tangential_deformation"] = aux_dict["m_dot"]
390
392
 
391
393
  case RigidContacts() | RelaxedRigidContacts():
392
394
  pass
393
395
 
394
396
  case _:
395
- raise ValueError("Unable to determine contact state class prefix.")
397
+ raise ValueError(f"Invalid contact model {model.contact_model}")
396
398
 
397
399
  # Extract the velocities.
398
400
  W_ṗ_B, W_Q̇_B, ṡ = system_position_dynamics(
@@ -412,7 +414,7 @@ def system_dynamics(
412
414
  base_linear_velocity=W_v̇_WB[0:3],
413
415
  base_angular_velocity=W_v̇_WB[3:6],
414
416
  joint_velocities=s̈,
415
- **ode_state_kwargs,
417
+ **extended_ode_state,
416
418
  )
417
419
 
418
420
  return ode_state_derivative, aux_dict
@@ -1,19 +1,13 @@
1
1
  from __future__ import annotations
2
2
 
3
+ import dataclasses
4
+
5
+ import jax
3
6
  import jax.numpy as jnp
4
7
  import jax_dataclasses
5
8
 
6
9
  import jaxsim.api as js
7
10
  import jaxsim.typing as jtp
8
- from jaxsim.rbda.contacts import (
9
- ContactsState,
10
- RelaxedRigidContacts,
11
- RelaxedRigidContactsState,
12
- RigidContacts,
13
- RigidContactsState,
14
- SoftContacts,
15
- SoftContactsState,
16
- )
17
11
  from jaxsim.utils import JaxsimDataclass
18
12
 
19
13
  # =============================================================================
@@ -125,15 +119,18 @@ class ODEState(JaxsimDataclass):
125
119
 
126
120
  Attributes:
127
121
  physics_model: The state of the physics model.
128
- contact: The state of the contacts model.
122
+ extended:
123
+ Additional state variables extending the state vector corresponding to
124
+ equations of motion. These extended variables are passed to the integrator.
129
125
  """
130
126
 
131
127
  physics_model: PhysicsModelState
132
- contact: ContactsState
128
+
129
+ extended: dict[str, jtp.PyTree] = dataclasses.field(default_factory=dict)
133
130
 
134
131
  @staticmethod
135
132
  def build_from_jaxsim_model(
136
- model: js.model.JaxSimModel | None = None,
133
+ model: js.model.JaxSimModel,
137
134
  joint_positions: jtp.Vector | None = None,
138
135
  joint_velocities: jtp.Vector | None = None,
139
136
  base_position: jtp.Vector | None = None,
@@ -155,7 +152,15 @@ class ODEState(JaxsimDataclass):
155
152
  The linear velocity of the base link in inertial-fixed representation.
156
153
  base_angular_velocity:
157
154
  The angular velocity of the base link in inertial-fixed representation.
158
- kwargs: Additional arguments needed to build the contact state.
155
+ kwargs:
156
+ Additional arguments corresponding to variables extending the default
157
+ state vector of the physics model.
158
+
159
+ Note:
160
+ Kwargs can be used to supply any additional state variables that are passed
161
+ to the integrator. This is useful to extend the default system dynamics,
162
+ for example if the contact model requires additional state variables or to
163
+ simulate additional dynamics like actuators or musculoskeletal models.
159
164
 
160
165
  Returns:
161
166
  The `ODEState` built from the `JaxSimModel`.
@@ -165,29 +170,11 @@ class ODEState(JaxsimDataclass):
165
170
  `JaxSimModel` and initialized to zero.
166
171
  """
167
172
 
168
- # Get the contact model from the `JaxSimModel`.
169
- match model.contact_model:
170
-
171
- case SoftContacts():
172
-
173
- tangential_deformation = kwargs.get("tangential_deformation", None)
174
-
175
- contact = SoftContactsState.build_from_jaxsim_model(
176
- model=model,
177
- **(
178
- dict(tangential_deformation=tangential_deformation)
179
- if tangential_deformation is not None
180
- else dict()
181
- ),
182
- )
183
- case RigidContacts():
184
- contact = RigidContactsState.build()
173
+ # Initialize the extended state with the optional contact state.
174
+ extended_state = model.contact_model.zero_state_variables(model=model)
185
175
 
186
- case RelaxedRigidContacts():
187
- contact = RelaxedRigidContactsState.build()
188
-
189
- case _:
190
- raise ValueError("Unsupported contact model.")
176
+ # Override the default extended state with optional kwargs.
177
+ extended_state |= kwargs
191
178
 
192
179
  return ODEState.build(
193
180
  model=model,
@@ -200,13 +187,13 @@ class ODEState(JaxsimDataclass):
200
187
  base_linear_velocity=base_linear_velocity,
201
188
  base_angular_velocity=base_angular_velocity,
202
189
  ),
203
- contact=contact,
190
+ extended_state=extended_state,
204
191
  )
205
192
 
206
193
  @staticmethod
207
194
  def build(
208
195
  physics_model_state: PhysicsModelState | None = None,
209
- contact: ContactsState | None = None,
196
+ extended_state: dict[str, jtp.PyTree] | None = None,
210
197
  model: js.model.JaxSimModel | None = None,
211
198
  ) -> ODEState:
212
199
  """
@@ -214,62 +201,60 @@ class ODEState(JaxsimDataclass):
214
201
 
215
202
  Args:
216
203
  physics_model_state: The state of the physics model.
217
- contact: The state of the contacts model.
204
+ extended_state: Additional state variables extending the state vector.
218
205
  model: The `JaxSimModel` associated with the ODE state.
219
206
 
220
207
  Returns:
221
208
  A `ODEState` instance.
222
209
  """
223
210
 
211
+ # Build a zero state for the physics model if not provided.
224
212
  physics_model_state = (
225
213
  physics_model_state
226
214
  if physics_model_state is not None
227
215
  else PhysicsModelState.zero(model=model)
228
216
  )
229
217
 
230
- # Get the contact model from the `JaxSimModel`.
231
- match contact:
232
- case (
233
- SoftContactsState() | RigidContactsState() | RelaxedRigidContactsState()
234
- ):
235
- pass
236
- case None:
237
- contact = SoftContactsState.zero(model=model)
238
- case _:
239
- raise ValueError("Unable to determine contact state class prefix.")
240
-
241
- return ODEState(physics_model=physics_model_state, contact=contact)
218
+ return ODEState(
219
+ physics_model=physics_model_state,
220
+ extended=extended_state,
221
+ )
242
222
 
243
223
  @staticmethod
244
224
  def zero(model: js.model.JaxSimModel, data: js.data.JaxSimModelData) -> ODEState:
245
225
  """
246
- Build a zero `ODEState` from a `JaxSimModel`.
226
+ Build a zero `ODEState` corresponding to a `JaxSimModel`.
247
227
 
248
228
  Args:
249
- model: The `JaxSimModel` associated with the ODE state.
229
+ model: The model to consider.
230
+ data: The data of the considered model.
250
231
 
251
232
  Returns:
252
233
  A zero `ODEState` instance.
253
234
  """
254
235
 
255
- model_state = ODEState.build(
256
- model=model, contact=data.state.contact.zero(model=model)
236
+ ode_state = ODEState.build(
237
+ model=model,
238
+ extended_state=jax.tree.map(
239
+ lambda x: jnp.zeros_like(x), data.state.extended
240
+ ),
257
241
  )
258
242
 
259
- return model_state
243
+ return ode_state
260
244
 
261
245
  def valid(self, model: js.model.JaxSimModel) -> bool:
262
246
  """
263
247
  Check if the `ODEState` is valid for a given `JaxSimModel`.
264
248
 
265
249
  Args:
266
- model: The `JaxSimModel` to validate the `ODEState` against.
250
+ model: The model to validate this `ODEState` against.
267
251
 
268
252
  Returns:
269
253
  `True` if the ODE state is valid for the given model, `False` otherwise.
270
254
  """
271
255
 
272
- return self.physics_model.valid(model=model) and self.contact.valid(model=model)
256
+ # TODO: should we validate the extended state?
257
+ return self.physics_model.valid(model=model)
273
258
 
274
259
 
275
260
  # ==================================================
@@ -0,0 +1,5 @@
1
+ from . import relaxed_rigid, rigid, soft
2
+ from .common import ContactModel, ContactsParams
3
+ from .relaxed_rigid import RelaxedRigidContacts, RelaxedRigidContactsParams
4
+ from .rigid import RigidContacts, RigidContactsParams
5
+ from .soft import SoftContacts, SoftContactsParams