jaxsim 0.4.3.dev64__tar.gz → 0.4.3.dev68__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (124) hide show
  1. {jaxsim-0.4.3.dev64 → jaxsim-0.4.3.dev68}/PKG-INFO +1 -2
  2. {jaxsim-0.4.3.dev64 → jaxsim-0.4.3.dev68}/environment.yml +0 -1
  3. {jaxsim-0.4.3.dev64 → jaxsim-0.4.3.dev68}/pyproject.toml +0 -2
  4. {jaxsim-0.4.3.dev64 → jaxsim-0.4.3.dev68}/src/jaxsim/__init__.py +0 -5
  5. {jaxsim-0.4.3.dev64 → jaxsim-0.4.3.dev68}/src/jaxsim/_version.py +2 -2
  6. {jaxsim-0.4.3.dev64 → jaxsim-0.4.3.dev68}/src/jaxsim/api/contact.py +1 -27
  7. {jaxsim-0.4.3.dev64 → jaxsim-0.4.3.dev68}/src/jaxsim/api/data.py +11 -40
  8. {jaxsim-0.4.3.dev64 → jaxsim-0.4.3.dev68}/src/jaxsim/api/joint.py +2 -62
  9. {jaxsim-0.4.3.dev64 → jaxsim-0.4.3.dev68}/src/jaxsim/api/model.py +1 -12
  10. {jaxsim-0.4.3.dev64 → jaxsim-0.4.3.dev68}/src/jaxsim/api/ode.py +24 -19
  11. {jaxsim-0.4.3.dev64 → jaxsim-0.4.3.dev68}/src/jaxsim/api/ode_data.py +1 -11
  12. {jaxsim-0.4.3.dev64 → jaxsim-0.4.3.dev68}/src/jaxsim/integrators/common.py +1 -1
  13. {jaxsim-0.4.3.dev64 → jaxsim-0.4.3.dev68}/src/jaxsim/math/inertia.py +1 -1
  14. {jaxsim-0.4.3.dev64 → jaxsim-0.4.3.dev68}/src/jaxsim/mujoco/loaders.py +3 -3
  15. {jaxsim-0.4.3.dev64 → jaxsim-0.4.3.dev68}/src/jaxsim/parsers/kinematic_graph.py +3 -3
  16. {jaxsim-0.4.3.dev64 → jaxsim-0.4.3.dev68}/src/jaxsim/parsers/rod/parser.py +14 -18
  17. {jaxsim-0.4.3.dev64 → jaxsim-0.4.3.dev68}/src/jaxsim/rbda/contacts/rigid.py +41 -11
  18. {jaxsim-0.4.3.dev64 → jaxsim-0.4.3.dev68}/src/jaxsim/terrain/terrain.py +25 -41
  19. {jaxsim-0.4.3.dev64 → jaxsim-0.4.3.dev68}/src/jaxsim/typing.py +1 -1
  20. {jaxsim-0.4.3.dev64 → jaxsim-0.4.3.dev68}/src/jaxsim/utils/jaxsim_dataclass.py +9 -12
  21. {jaxsim-0.4.3.dev64 → jaxsim-0.4.3.dev68}/src/jaxsim/utils/wrappers.py +1 -1
  22. {jaxsim-0.4.3.dev64 → jaxsim-0.4.3.dev68}/src/jaxsim.egg-info/PKG-INFO +1 -2
  23. {jaxsim-0.4.3.dev64 → jaxsim-0.4.3.dev68}/src/jaxsim.egg-info/SOURCES.txt +0 -1
  24. {jaxsim-0.4.3.dev64 → jaxsim-0.4.3.dev68}/src/jaxsim.egg-info/requires.txt +0 -1
  25. {jaxsim-0.4.3.dev64 → jaxsim-0.4.3.dev68}/tests/conftest.py +0 -25
  26. jaxsim-0.4.3.dev64/src/jaxsim/rbda/contacts/relaxed_rigid.py +0 -384
  27. {jaxsim-0.4.3.dev64 → jaxsim-0.4.3.dev68}/.devcontainer/Dockerfile +0 -0
  28. {jaxsim-0.4.3.dev64 → jaxsim-0.4.3.dev68}/.devcontainer/devcontainer.json +0 -0
  29. {jaxsim-0.4.3.dev64 → jaxsim-0.4.3.dev68}/.gitattributes +0 -0
  30. {jaxsim-0.4.3.dev64 → jaxsim-0.4.3.dev68}/.github/CODEOWNERS +0 -0
  31. {jaxsim-0.4.3.dev64 → jaxsim-0.4.3.dev68}/.github/workflows/ci_cd.yml +0 -0
  32. {jaxsim-0.4.3.dev64 → jaxsim-0.4.3.dev68}/.github/workflows/read_the_docs.yml +0 -0
  33. {jaxsim-0.4.3.dev64 → jaxsim-0.4.3.dev68}/.github/workflows/update_pixi_lockfile.yml +0 -0
  34. {jaxsim-0.4.3.dev64 → jaxsim-0.4.3.dev68}/.gitignore +0 -0
  35. {jaxsim-0.4.3.dev64 → jaxsim-0.4.3.dev68}/.pre-commit-config.yaml +0 -0
  36. {jaxsim-0.4.3.dev64 → jaxsim-0.4.3.dev68}/.readthedocs.yaml +0 -0
  37. {jaxsim-0.4.3.dev64 → jaxsim-0.4.3.dev68}/CONTRIBUTING.md +0 -0
  38. {jaxsim-0.4.3.dev64 → jaxsim-0.4.3.dev68}/LICENSE +0 -0
  39. {jaxsim-0.4.3.dev64 → jaxsim-0.4.3.dev68}/README.md +0 -0
  40. {jaxsim-0.4.3.dev64 → jaxsim-0.4.3.dev68}/docs/Makefile +0 -0
  41. {jaxsim-0.4.3.dev64 → jaxsim-0.4.3.dev68}/docs/conf.py +0 -0
  42. {jaxsim-0.4.3.dev64 → jaxsim-0.4.3.dev68}/docs/examples.rst +0 -0
  43. {jaxsim-0.4.3.dev64 → jaxsim-0.4.3.dev68}/docs/guide/install.rst +0 -0
  44. {jaxsim-0.4.3.dev64 → jaxsim-0.4.3.dev68}/docs/index.rst +0 -0
  45. {jaxsim-0.4.3.dev64 → jaxsim-0.4.3.dev68}/docs/make.bat +0 -0
  46. {jaxsim-0.4.3.dev64 → jaxsim-0.4.3.dev68}/docs/modules/api.rst +0 -0
  47. {jaxsim-0.4.3.dev64 → jaxsim-0.4.3.dev68}/docs/modules/integrators.rst +0 -0
  48. {jaxsim-0.4.3.dev64 → jaxsim-0.4.3.dev68}/docs/modules/math.rst +0 -0
  49. {jaxsim-0.4.3.dev64 → jaxsim-0.4.3.dev68}/docs/modules/mujoco.rst +0 -0
  50. {jaxsim-0.4.3.dev64 → jaxsim-0.4.3.dev68}/docs/modules/parsers.rst +0 -0
  51. {jaxsim-0.4.3.dev64 → jaxsim-0.4.3.dev68}/docs/modules/rbda.rst +0 -0
  52. {jaxsim-0.4.3.dev64 → jaxsim-0.4.3.dev68}/docs/modules/typing.rst +0 -0
  53. {jaxsim-0.4.3.dev64 → jaxsim-0.4.3.dev68}/docs/modules/utils.rst +0 -0
  54. {jaxsim-0.4.3.dev64 → jaxsim-0.4.3.dev68}/examples/.gitattributes +0 -0
  55. {jaxsim-0.4.3.dev64 → jaxsim-0.4.3.dev68}/examples/.gitignore +0 -0
  56. {jaxsim-0.4.3.dev64 → jaxsim-0.4.3.dev68}/examples/PD_controller.ipynb +0 -0
  57. {jaxsim-0.4.3.dev64 → jaxsim-0.4.3.dev68}/examples/Parallel_computing.ipynb +0 -0
  58. {jaxsim-0.4.3.dev64 → jaxsim-0.4.3.dev68}/examples/README.md +0 -0
  59. {jaxsim-0.4.3.dev64 → jaxsim-0.4.3.dev68}/examples/assets/cartpole.urdf +0 -0
  60. {jaxsim-0.4.3.dev64 → jaxsim-0.4.3.dev68}/pixi.lock +0 -0
  61. {jaxsim-0.4.3.dev64 → jaxsim-0.4.3.dev68}/setup.cfg +0 -0
  62. {jaxsim-0.4.3.dev64 → jaxsim-0.4.3.dev68}/setup.py +0 -0
  63. {jaxsim-0.4.3.dev64 → jaxsim-0.4.3.dev68}/src/jaxsim/api/__init__.py +0 -0
  64. {jaxsim-0.4.3.dev64 → jaxsim-0.4.3.dev68}/src/jaxsim/api/com.py +0 -0
  65. {jaxsim-0.4.3.dev64 → jaxsim-0.4.3.dev68}/src/jaxsim/api/common.py +0 -0
  66. {jaxsim-0.4.3.dev64 → jaxsim-0.4.3.dev68}/src/jaxsim/api/frame.py +0 -0
  67. {jaxsim-0.4.3.dev64 → jaxsim-0.4.3.dev68}/src/jaxsim/api/kin_dyn_parameters.py +0 -0
  68. {jaxsim-0.4.3.dev64 → jaxsim-0.4.3.dev68}/src/jaxsim/api/link.py +0 -0
  69. {jaxsim-0.4.3.dev64 → jaxsim-0.4.3.dev68}/src/jaxsim/api/references.py +0 -0
  70. {jaxsim-0.4.3.dev64 → jaxsim-0.4.3.dev68}/src/jaxsim/exceptions.py +0 -0
  71. {jaxsim-0.4.3.dev64 → jaxsim-0.4.3.dev68}/src/jaxsim/integrators/__init__.py +0 -0
  72. {jaxsim-0.4.3.dev64 → jaxsim-0.4.3.dev68}/src/jaxsim/integrators/fixed_step.py +0 -0
  73. {jaxsim-0.4.3.dev64 → jaxsim-0.4.3.dev68}/src/jaxsim/integrators/variable_step.py +0 -0
  74. {jaxsim-0.4.3.dev64 → jaxsim-0.4.3.dev68}/src/jaxsim/logging.py +0 -0
  75. {jaxsim-0.4.3.dev64 → jaxsim-0.4.3.dev68}/src/jaxsim/math/__init__.py +0 -0
  76. {jaxsim-0.4.3.dev64 → jaxsim-0.4.3.dev68}/src/jaxsim/math/adjoint.py +0 -0
  77. {jaxsim-0.4.3.dev64 → jaxsim-0.4.3.dev68}/src/jaxsim/math/cross.py +0 -0
  78. {jaxsim-0.4.3.dev64 → jaxsim-0.4.3.dev68}/src/jaxsim/math/joint_model.py +0 -0
  79. {jaxsim-0.4.3.dev64 → jaxsim-0.4.3.dev68}/src/jaxsim/math/quaternion.py +0 -0
  80. {jaxsim-0.4.3.dev64 → jaxsim-0.4.3.dev68}/src/jaxsim/math/rotation.py +0 -0
  81. {jaxsim-0.4.3.dev64 → jaxsim-0.4.3.dev68}/src/jaxsim/math/skew.py +0 -0
  82. {jaxsim-0.4.3.dev64 → jaxsim-0.4.3.dev68}/src/jaxsim/math/transform.py +0 -0
  83. {jaxsim-0.4.3.dev64 → jaxsim-0.4.3.dev68}/src/jaxsim/mujoco/__init__.py +0 -0
  84. {jaxsim-0.4.3.dev64 → jaxsim-0.4.3.dev68}/src/jaxsim/mujoco/__main__.py +0 -0
  85. {jaxsim-0.4.3.dev64 → jaxsim-0.4.3.dev68}/src/jaxsim/mujoco/model.py +0 -0
  86. {jaxsim-0.4.3.dev64 → jaxsim-0.4.3.dev68}/src/jaxsim/mujoco/visualizer.py +0 -0
  87. {jaxsim-0.4.3.dev64 → jaxsim-0.4.3.dev68}/src/jaxsim/parsers/__init__.py +0 -0
  88. {jaxsim-0.4.3.dev64 → jaxsim-0.4.3.dev68}/src/jaxsim/parsers/descriptions/__init__.py +0 -0
  89. {jaxsim-0.4.3.dev64 → jaxsim-0.4.3.dev68}/src/jaxsim/parsers/descriptions/collision.py +0 -0
  90. {jaxsim-0.4.3.dev64 → jaxsim-0.4.3.dev68}/src/jaxsim/parsers/descriptions/joint.py +0 -0
  91. {jaxsim-0.4.3.dev64 → jaxsim-0.4.3.dev68}/src/jaxsim/parsers/descriptions/link.py +0 -0
  92. {jaxsim-0.4.3.dev64 → jaxsim-0.4.3.dev68}/src/jaxsim/parsers/descriptions/model.py +0 -0
  93. {jaxsim-0.4.3.dev64 → jaxsim-0.4.3.dev68}/src/jaxsim/parsers/rod/__init__.py +0 -0
  94. {jaxsim-0.4.3.dev64 → jaxsim-0.4.3.dev68}/src/jaxsim/parsers/rod/utils.py +0 -0
  95. {jaxsim-0.4.3.dev64 → jaxsim-0.4.3.dev68}/src/jaxsim/rbda/__init__.py +0 -0
  96. {jaxsim-0.4.3.dev64 → jaxsim-0.4.3.dev68}/src/jaxsim/rbda/aba.py +0 -0
  97. {jaxsim-0.4.3.dev64 → jaxsim-0.4.3.dev68}/src/jaxsim/rbda/collidable_points.py +0 -0
  98. {jaxsim-0.4.3.dev64 → jaxsim-0.4.3.dev68}/src/jaxsim/rbda/contacts/__init__.py +0 -0
  99. {jaxsim-0.4.3.dev64 → jaxsim-0.4.3.dev68}/src/jaxsim/rbda/contacts/common.py +0 -0
  100. {jaxsim-0.4.3.dev64 → jaxsim-0.4.3.dev68}/src/jaxsim/rbda/contacts/soft.py +0 -0
  101. {jaxsim-0.4.3.dev64 → jaxsim-0.4.3.dev68}/src/jaxsim/rbda/crba.py +0 -0
  102. {jaxsim-0.4.3.dev64 → jaxsim-0.4.3.dev68}/src/jaxsim/rbda/forward_kinematics.py +0 -0
  103. {jaxsim-0.4.3.dev64 → jaxsim-0.4.3.dev68}/src/jaxsim/rbda/jacobian.py +0 -0
  104. {jaxsim-0.4.3.dev64 → jaxsim-0.4.3.dev68}/src/jaxsim/rbda/rnea.py +0 -0
  105. {jaxsim-0.4.3.dev64 → jaxsim-0.4.3.dev68}/src/jaxsim/rbda/utils.py +0 -0
  106. {jaxsim-0.4.3.dev64 → jaxsim-0.4.3.dev68}/src/jaxsim/terrain/__init__.py +0 -0
  107. {jaxsim-0.4.3.dev64 → jaxsim-0.4.3.dev68}/src/jaxsim/utils/__init__.py +0 -0
  108. {jaxsim-0.4.3.dev64 → jaxsim-0.4.3.dev68}/src/jaxsim/utils/tracing.py +0 -0
  109. {jaxsim-0.4.3.dev64 → jaxsim-0.4.3.dev68}/src/jaxsim.egg-info/dependency_links.txt +0 -0
  110. {jaxsim-0.4.3.dev64 → jaxsim-0.4.3.dev68}/src/jaxsim.egg-info/top_level.txt +0 -0
  111. {jaxsim-0.4.3.dev64 → jaxsim-0.4.3.dev68}/tests/__init__.py +0 -0
  112. {jaxsim-0.4.3.dev64 → jaxsim-0.4.3.dev68}/tests/test_api_com.py +0 -0
  113. {jaxsim-0.4.3.dev64 → jaxsim-0.4.3.dev68}/tests/test_api_contact.py +0 -0
  114. {jaxsim-0.4.3.dev64 → jaxsim-0.4.3.dev68}/tests/test_api_data.py +0 -0
  115. {jaxsim-0.4.3.dev64 → jaxsim-0.4.3.dev68}/tests/test_api_frame.py +0 -0
  116. {jaxsim-0.4.3.dev64 → jaxsim-0.4.3.dev68}/tests/test_api_joint.py +0 -0
  117. {jaxsim-0.4.3.dev64 → jaxsim-0.4.3.dev68}/tests/test_api_link.py +0 -0
  118. {jaxsim-0.4.3.dev64 → jaxsim-0.4.3.dev68}/tests/test_api_model.py +0 -0
  119. {jaxsim-0.4.3.dev64 → jaxsim-0.4.3.dev68}/tests/test_automatic_differentiation.py +0 -0
  120. {jaxsim-0.4.3.dev64 → jaxsim-0.4.3.dev68}/tests/test_contact.py +0 -0
  121. {jaxsim-0.4.3.dev64 → jaxsim-0.4.3.dev68}/tests/test_exceptions.py +0 -0
  122. {jaxsim-0.4.3.dev64 → jaxsim-0.4.3.dev68}/tests/test_pytree.py +0 -0
  123. {jaxsim-0.4.3.dev64 → jaxsim-0.4.3.dev68}/tests/test_simulations.py +0 -0
  124. {jaxsim-0.4.3.dev64 → jaxsim-0.4.3.dev68}/tests/utils_idyntree.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: jaxsim
3
- Version: 0.4.3.dev64
3
+ Version: 0.4.3.dev68
4
4
  Summary: A differentiable physics engine and multibody dynamics library for control and robot learning.
5
5
  Author-email: Diego Ferigo <dgferigo@gmail.com>
6
6
  Maintainer-email: Diego Ferigo <dgferigo@gmail.com>, Filippo Luca Ferretti <filippo.ferretti@iit.it>
@@ -61,7 +61,6 @@ Description-Content-Type: text/markdown
61
61
  License-File: LICENSE
62
62
  Requires-Dist: coloredlogs
63
63
  Requires-Dist: jax>=0.4.13
64
- Requires-Dist: jaxopt>=0.8.0
65
64
  Requires-Dist: jaxlib>=0.4.13
66
65
  Requires-Dist: jaxlie>=1.3.0
67
66
  Requires-Dist: jax_dataclasses>=1.4.0
@@ -8,7 +8,6 @@ dependencies:
8
8
  - python >= 3.12.0
9
9
  - coloredlogs
10
10
  - jax >= 0.4.13
11
- - jaxopt >= 0.8.0
12
11
  - jaxlib >= 0.4.13
13
12
  - jaxlie >= 1.3.0
14
13
  - jax-dataclasses >= 1.4.0
@@ -45,7 +45,6 @@ classifiers = [
45
45
  dependencies = [
46
46
  "coloredlogs",
47
47
  "jax >= 0.4.13",
48
- "jaxopt >= 0.8.0",
49
48
  "jaxlib >= 0.4.13",
50
49
  "jaxlie >= 1.3.0",
51
50
  "jax_dataclasses >= 1.4.0",
@@ -182,7 +181,6 @@ platforms = ["linux-64", "osx-arm64", "osx-64"]
182
181
  coloredlogs = "*"
183
182
  jax = "*"
184
183
  jax-dataclasses = "*"
185
- jaxopt = "*"
186
184
  jaxlib = "*"
187
185
  jaxlie = "*"
188
186
  lxml = "*"
@@ -20,11 +20,6 @@ def _jnp_options() -> None:
20
20
  if jnp.empty(0, dtype=float).dtype != jnp.empty(0, dtype=np.float64).dtype:
21
21
  logging.warning("Failed to enable 64bit precision in JAX")
22
22
 
23
- else:
24
- logging.warning(
25
- "Using 32bit precision in JaxSim is still experimental, please avoid to use variable step integrators."
26
- )
27
-
28
23
 
29
24
  def _np_options() -> None:
30
25
  import numpy as np
@@ -12,5 +12,5 @@ __version__: str
12
12
  __version_tuple__: VERSION_TUPLE
13
13
  version_tuple: VERSION_TUPLE
14
14
 
15
- __version__ = version = '0.4.3.dev64'
16
- __version_tuple__ = version_tuple = (0, 4, 3, 'dev64')
15
+ __version__ = version = '0.4.3.dev68'
16
+ __version_tuple__ = version_tuple = (0, 4, 3, 'dev68')
@@ -131,8 +131,7 @@ def collidable_point_dynamics(
131
131
  Returns:
132
132
  The 6D force applied to each collidable point and additional data based on the contact model configured:
133
133
  - Soft: the material deformation rate.
134
- - Rigid: no additional data.
135
- - QuasiRigid: no additional data.
134
+ - Rigid: nothing.
136
135
 
137
136
  Note:
138
137
  The material deformation rate is always returned in the mixed frame
@@ -145,10 +144,6 @@ def collidable_point_dynamics(
145
144
  W_p_Ci, W_ṗ_Ci = js.contact.collidable_point_kinematics(model=model, data=data)
146
145
 
147
146
  # Import privately the contacts classes.
148
- from jaxsim.rbda.contacts.relaxed_rigid import (
149
- RelaxedRigidContacts,
150
- RelaxedRigidContactsState,
151
- )
152
147
  from jaxsim.rbda.contacts.rigid import RigidContacts, RigidContactsState
153
148
  from jaxsim.rbda.contacts.soft import SoftContacts, SoftContactsState
154
149
 
@@ -195,27 +190,6 @@ def collidable_point_dynamics(
195
190
 
196
191
  aux_data = dict()
197
192
 
198
- case RelaxedRigidContacts():
199
- assert isinstance(model.contact_model, RelaxedRigidContacts)
200
- assert isinstance(data.state.contact, RelaxedRigidContactsState)
201
-
202
- # Build the contact model.
203
- relaxed_rigid_contacts = RelaxedRigidContacts(
204
- parameters=data.contacts_params, terrain=model.terrain
205
- )
206
-
207
- # Compute the 6D force expressed in the inertial frame and applied to each
208
- # collidable point.
209
- W_f_Ci, _ = relaxed_rigid_contacts.compute_contact_forces(
210
- position=W_p_Ci,
211
- velocity=W_ṗ_Ci,
212
- model=model,
213
- data=data,
214
- link_forces=link_forces,
215
- )
216
-
217
- aux_data = dict()
218
-
219
193
  case _:
220
194
  raise ValueError(f"Invalid contact model {model.contact_model}")
221
195
 
@@ -39,9 +39,7 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
39
39
  contacts_params: jaxsim.rbda.ContactsParams = dataclasses.field(repr=False)
40
40
 
41
41
  time_ns: jtp.Int = dataclasses.field(
42
- default_factory=lambda: jnp.array(
43
- 0, dtype=jnp.uint64 if jax.config.read("jax_enable_x64") else jnp.uint32
44
- ),
42
+ default_factory=lambda: jnp.array(0, dtype=jnp.uint64)
45
43
  )
46
44
 
47
45
  def __hash__(self) -> int:
@@ -174,14 +172,9 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
174
172
  )
175
173
 
176
174
  time_ns = (
177
- jnp.array(
178
- time * 1e9,
179
- dtype=jnp.uint64 if jax.config.read("jax_enable_x64") else jnp.uint32,
180
- )
175
+ jnp.array(time * 1e9, dtype=jnp.uint64)
181
176
  if time is not None
182
- else jnp.array(
183
- 0, dtype=jnp.uint64 if jax.config.read("jax_enable_x64") else jnp.uint32
184
- )
177
+ else jnp.array(0, dtype=jnp.uint64)
185
178
  )
186
179
 
187
180
  if isinstance(model.contact_model, SoftContacts):
@@ -593,18 +586,16 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
593
586
  The updated `JaxSimModelData` object.
594
587
  """
595
588
 
596
- W_Q_B = jnp.array(base_quaternion, dtype=float)
597
-
598
- W_Q_B = jax.lax.select(
599
- pred=jnp.allclose(jnp.linalg.norm(W_Q_B), 1.0, atol=1e-6, rtol=0.0),
600
- on_true=W_Q_B,
601
- on_false=W_Q_B / jnp.linalg.norm(W_Q_B),
602
- )
589
+ base_quaternion = jnp.array(base_quaternion)
603
590
 
604
591
  return self.replace(
605
592
  validate=True,
606
593
  state=self.state.replace(
607
- physics_model=self.state.physics_model.replace(base_quaternion=W_Q_B)
594
+ physics_model=self.state.physics_model.replace(
595
+ base_quaternion=jnp.atleast_1d(base_quaternion.squeeze()).astype(
596
+ float
597
+ )
598
+ )
608
599
  ),
609
600
  )
610
601
 
@@ -746,13 +737,6 @@ def random_model_data(
746
737
  jtp.FloatLike | Sequence[jtp.FloatLike],
747
738
  jtp.FloatLike | Sequence[jtp.FloatLike],
748
739
  ] = ((-1, -1, 0.5), 1.0),
749
- joint_pos_bounds: (
750
- tuple[
751
- jtp.FloatLike | Sequence[jtp.FloatLike],
752
- jtp.FloatLike | Sequence[jtp.FloatLike],
753
- ]
754
- | None
755
- ) = None,
756
740
  base_vel_lin_bounds: tuple[
757
741
  jtp.FloatLike | Sequence[jtp.FloatLike],
758
742
  jtp.FloatLike | Sequence[jtp.FloatLike],
@@ -778,8 +762,6 @@ def random_model_data(
778
762
  key: The random key.
779
763
  velocity_representation: The velocity representation to use.
780
764
  base_pos_bounds: The bounds for the base position.
781
- joint_pos_bounds:
782
- The bounds for the joint positions (reading the joint limits if None).
783
765
  base_vel_lin_bounds: The bounds for the base linear velocity.
784
766
  base_vel_ang_bounds: The bounds for the base angular velocity.
785
767
  joint_vel_bounds: The bounds for the joint velocities.
@@ -824,19 +806,8 @@ def random_model_data(
824
806
  ).wxyz
825
807
 
826
808
  if model.number_of_joints() > 0:
827
-
828
- s_min, s_max = (
829
- jnp.array(joint_pos_bounds, dtype=float)
830
- if joint_pos_bounds is not None
831
- else (None, None)
832
- )
833
-
834
- physics_model_state.joint_positions = (
835
- js.joint.random_joint_positions(model=model, key=k3)
836
- if (s_min is None or s_max is None)
837
- else jax.random.uniform(
838
- key=k3, shape=(model.dofs(),), minval=s_min, maxval=s_max
839
- )
809
+ physics_model_state.joint_positions = js.joint.random_joint_positions(
810
+ model=model, key=k3
840
811
  )
841
812
 
842
813
  physics_model_state.joint_velocities = jax.random.uniform(
@@ -180,77 +180,17 @@ def random_joint_positions(
180
180
 
181
181
  Args:
182
182
  model: The model to consider.
183
- joint_names: The names of the considered joints (all if None).
184
- key: The random key (initialized from seed 0 if None).
185
-
186
- Note:
187
- If the joint range or revolute joints is larger than 2π, their joint positions
188
- will be sampled from an interval of size 2π.
183
+ joint_names: The names of the joints.
184
+ key: The random key.
189
185
 
190
186
  Returns:
191
187
  The random joint positions.
192
188
  """
193
189
 
194
- # Consider the key corresponding to a zero seed if it was not passed.
195
190
  key = key if key is not None else jax.random.PRNGKey(seed=0)
196
191
 
197
- # Get the joint limits parsed from the model description.
198
192
  s_min, s_max = position_limits(model=model, joint_names=joint_names)
199
193
 
200
- # Get the joint indices.
201
- # Note that it will trigger an exception if the given `joint_names` are not valid.
202
- joint_names = joint_names if joint_names is not None else model.joint_names()
203
- joint_indices = names_to_idxs(model=model, joint_names=joint_names)
204
-
205
- from jaxsim.parsers.descriptions.joint import JointType
206
-
207
- # Filter for revolute joints.
208
- is_revolute = jnp.where(
209
- jnp.array(model.kin_dyn_parameters.joint_model.joint_types[1:])[joint_indices]
210
- == JointType.Revolute,
211
- True,
212
- False,
213
- )
214
-
215
- # Shorthand for π.
216
- π = jnp.pi
217
-
218
- # Filter for revolute with full range (or continuous).
219
- is_revolute_full_range = jnp.logical_and(is_revolute, s_max - s_min >= 2 * π)
220
-
221
- # Clip the lower limit to -π if the joint range is larger than [-π, π].
222
- s_min = jnp.where(
223
- jnp.logical_and(
224
- is_revolute_full_range, jnp.logical_and(s_min <= -π, s_max >= π)
225
- ),
226
- -π,
227
- s_min,
228
- )
229
-
230
- # Clip the upper limit to +π if the joint range is larger than [-π, π].
231
- s_max = jnp.where(
232
- jnp.logical_and(
233
- is_revolute_full_range, jnp.logical_and(s_min <= -π, s_max >= π)
234
- ),
235
- π,
236
- s_max,
237
- )
238
-
239
- # Shift the lower limit if the upper limit is smaller than +π.
240
- s_min = jnp.where(
241
- jnp.logical_and(is_revolute_full_range, s_max < π),
242
- s_max - 2 * π,
243
- s_min,
244
- )
245
-
246
- # Shift the upper limit if the lower limit is larger than -π.
247
- s_max = jnp.where(
248
- jnp.logical_and(is_revolute_full_range, s_min > -π),
249
- s_min + 2 * π,
250
- s_max,
251
- )
252
-
253
- # Sample the joint positions.
254
194
  s_random = jax.random.uniform(
255
195
  minval=s_min,
256
196
  maxval=s_max,
@@ -1931,22 +1931,11 @@ def step(
1931
1931
  ),
1932
1932
  )
1933
1933
 
1934
- tf_ns = t0_ns + jnp.array(dt * 1e9, dtype=t0_ns.dtype)
1935
- tf_ns = jnp.where(tf_ns >= t0_ns, tf_ns, jnp.array(0, dtype=t0_ns.dtype))
1936
-
1937
- jax.lax.cond(
1938
- pred=tf_ns < t0_ns,
1939
- true_fun=lambda: jax.debug.print(
1940
- "The simulation time overflowed, resetting simulation time to 0."
1941
- ),
1942
- false_fun=lambda: None,
1943
- )
1944
-
1945
1934
  data_tf = (
1946
1935
  # Store the new state of the model and the new time.
1947
1936
  data.replace(
1948
1937
  state=state_tf,
1949
- time_ns=tf_ns,
1938
+ time_ns=t0_ns + jnp.array(dt * 1e9).astype(jnp.uint64),
1950
1939
  )
1951
1940
  )
1952
1941
 
@@ -175,15 +175,17 @@ def system_velocity_dynamics(
175
175
  forces=W_f_Li_terrain,
176
176
  additive=True,
177
177
  )
178
-
179
- # Get the link forces in inertial representation
178
+ # Get the link forces in the data representation
179
+ with references.switch_velocity_representation(data.velocity_representation):
180
180
  f_L_total = references.link_forces(model=model, data=data)
181
181
 
182
- v̇_WB, = system_acceleration(
183
- model=model, data=data, joint_forces=joint_forces, link_forces=f_L_total
184
- )
182
+ # The following method always returns the inertial-fixed acceleration, and expects
183
+ # the link_forces expressed in the inertial frame.
184
+ W_v̇_WB, s̈ = system_acceleration(
185
+ model=model, data=data, joint_forces=joint_forces, link_forces=f_L_total
186
+ )
185
187
 
186
- return v̇_WB, s̈, aux_data
188
+ return W_v̇_WB, s̈, aux_data
187
189
 
188
190
 
189
191
  def system_acceleration(
@@ -194,7 +196,7 @@ def system_acceleration(
194
196
  link_forces: jtp.MatrixLike | None = None,
195
197
  ) -> tuple[jtp.Vector, jtp.Vector]:
196
198
  """
197
- Compute the system acceleration in the active representation.
199
+ Compute the system acceleration in inertial-fixed representation.
198
200
 
199
201
  Args:
200
202
  model: The model to consider.
@@ -204,7 +206,7 @@ def system_acceleration(
204
206
  The 6D forces to apply to the links expressed in the same representation of data.
205
207
 
206
208
  Returns:
207
- A tuple containing the base 6D acceleration in in the active representation
209
+ A tuple containing the base 6D acceleration in inertial-fixed representation
208
210
  and the joint accelerations.
209
211
  """
210
212
 
@@ -270,15 +272,18 @@ def system_acceleration(
270
272
  )
271
273
 
272
274
  # - Joint accelerations: s̈ ∈ ℝⁿ
273
- # - Base acceleration: v̇_WB ∈ ℝ⁶
274
- v̇_WB, s̈ = js.model.forward_dynamics_aba(
275
- model=model,
276
- data=data,
277
- joint_forces=references.joint_force_references(model=model),
278
- link_forces=references.link_forces(model=model, data=data),
279
- )
280
-
281
- return v̇_WB,
275
+ # - Base inertial-fixed acceleration: W_v̇_WB = (W_p̈_B, W_ω̇_B) ∈ ℝ⁶
276
+ with (
277
+ data.switch_velocity_representation(velocity_representation=VelRepr.Inertial),
278
+ references.switch_velocity_representation(VelRepr.Inertial),
279
+ ):
280
+ W_v̇_WB, s̈ = js.model.forward_dynamics_aba(
281
+ model=model,
282
+ data=data,
283
+ joint_forces=references.joint_force_references(),
284
+ link_forces=references.link_forces(),
285
+ )
286
+ return W_v̇_WB, s̈
282
287
 
283
288
 
284
289
  @jax.jit
@@ -348,7 +353,7 @@ def system_dynamics(
348
353
  corresponding derivative, and the dictionary of auxiliary data returned
349
354
  by the system dynamics evaluation.
350
355
  """
351
- from jaxsim.rbda.contacts.relaxed_rigid import RelaxedRigidContacts
356
+
352
357
  from jaxsim.rbda.contacts.rigid import RigidContacts
353
358
  from jaxsim.rbda.contacts.soft import SoftContacts
354
359
 
@@ -366,7 +371,7 @@ def system_dynamics(
366
371
  case SoftContacts():
367
372
  ode_state_kwargs["tangential_deformation"] = aux_dict["m_dot"]
368
373
 
369
- case RigidContacts() | RelaxedRigidContacts():
374
+ case RigidContacts():
370
375
  pass
371
376
 
372
377
  case _:
@@ -6,10 +6,6 @@ import jax_dataclasses
6
6
  import jaxsim.api as js
7
7
  import jaxsim.typing as jtp
8
8
  from jaxsim.rbda import ContactsState
9
- from jaxsim.rbda.contacts.relaxed_rigid import (
10
- RelaxedRigidContacts,
11
- RelaxedRigidContactsState,
12
- )
13
9
  from jaxsim.rbda.contacts.rigid import RigidContacts, RigidContactsState
14
10
  from jaxsim.rbda.contacts.soft import SoftContacts, SoftContactsState
15
11
  from jaxsim.utils import JaxsimDataclass
@@ -177,10 +173,6 @@ class ODEState(JaxsimDataclass):
177
173
  )
178
174
  case RigidContacts():
179
175
  contact = RigidContactsState.build()
180
-
181
- case RelaxedRigidContacts():
182
- contact = RelaxedRigidContactsState.build()
183
-
184
176
  case _:
185
177
  raise ValueError("Unable to determine contact state class prefix.")
186
178
 
@@ -224,9 +216,7 @@ class ODEState(JaxsimDataclass):
224
216
 
225
217
  # Get the contact model from the `JaxSimModel`.
226
218
  match contact:
227
- case (
228
- SoftContactsState() | RigidContactsState() | RelaxedRigidContactsState()
229
- ):
219
+ case SoftContactsState() | RigidContactsState():
230
220
  pass
231
221
  case None:
232
222
  contact = SoftContactsState.zero(model=model)
@@ -497,7 +497,7 @@ class ExplicitRungeKutta(Integrator[PyTreeType, PyTreeType], Generic[PyTreeType]
497
497
  b: jtp.Matrix,
498
498
  c: jtp.Vector,
499
499
  index_of_solution: jtp.IntLike = 0,
500
- ) -> tuple[bool, int | None]:
500
+ ) -> [bool, int | None]:
501
501
  """
502
502
  Check if the Butcher tableau supports the FSAL (first-same-as-last) property.
503
503
 
@@ -45,7 +45,7 @@ class Inertia:
45
45
  M (jtp.Matrix): The 6x6 inertia matrix.
46
46
 
47
47
  Returns:
48
- tuple[jtp.Float, jtp.Vector, jtp.Matrix]: A tuple containing mass, center of mass (3D), and inertia matrix (3x3).
48
+ Tuple[jtp.Float, jtp.Vector, jtp.Matrix]: A tuple containing mass, center of mass (3D), and inertia matrix (3x3).
49
49
 
50
50
  Raises:
51
51
  ValueError: If the input matrix M has an unexpected shape.
@@ -211,7 +211,7 @@ class RodModelToMjcf:
211
211
  joints_dict = {j.name: j for j in rod_model.joints()}
212
212
 
213
213
  # Convert all the joints not considered to fixed joints.
214
- for joint_name in {j.name for j in rod_model.joints()} - considered_joints:
214
+ for joint_name in set(j.name for j in rod_model.joints()) - considered_joints:
215
215
  joints_dict[joint_name].type = "fixed"
216
216
 
217
217
  # Convert the ROD model to URDF.
@@ -289,10 +289,10 @@ class RodModelToMjcf:
289
289
  mj_model = mj.MjModel.from_xml_string(xml=urdf_string, assets=assets)
290
290
 
291
291
  # Get the joint names.
292
- mj_joint_names = {
292
+ mj_joint_names = set(
293
293
  mj.mj_id2name(mj_model, mj.mjtObj.mjOBJ_JOINT, idx)
294
294
  for idx in range(mj_model.njnt)
295
- }
295
+ )
296
296
 
297
297
  # Check that the Mujoco model only has the considered joints.
298
298
  if mj_joint_names != considered_joints:
@@ -394,7 +394,7 @@ class KinematicGraph(Sequence[LinkDescription]):
394
394
  return copy.deepcopy(self)
395
395
 
396
396
  # Check if all considered joints are part of the full kinematic graph
397
- if len(set(considered_joints) - {j.name for j in full_graph.joints}) != 0:
397
+ if len(set(considered_joints) - set(j.name for j in full_graph.joints)) != 0:
398
398
  extra_j = set(considered_joints) - {j.name for j in full_graph.joints}
399
399
  msg = f"Not all joints to consider are part of the graph ({{{extra_j}}})"
400
400
  raise ValueError(msg)
@@ -536,8 +536,8 @@ class KinematicGraph(Sequence[LinkDescription]):
536
536
  root_link_name=full_graph.root.name,
537
537
  )
538
538
 
539
- assert {f.name for f in self.frames}.isdisjoint(
540
- {f.name for f in unconnected_frames + reduced_frames}
539
+ assert set(f.name for f in self.frames).isdisjoint(
540
+ set(f.name for f in unconnected_frames + reduced_frames)
541
541
  )
542
542
 
543
543
  for link in unconnected_links:
@@ -223,7 +223,7 @@ def extract_model_data(
223
223
  child=links_dict[j.child],
224
224
  jtype=utils.joint_to_joint_type(joint=j),
225
225
  axis=(
226
- np.array(j.axis.xyz.xyz, dtype=float)
226
+ np.array(j.axis.xyz.xyz)
227
227
  if j.axis is not None
228
228
  and j.axis.xyz is not None
229
229
  and j.axis.xyz.xyz is not None
@@ -232,43 +232,39 @@ def extract_model_data(
232
232
  pose=j.pose.transform() if j.pose is not None else np.eye(4),
233
233
  initial_position=0.0,
234
234
  position_limit=(
235
- float(
236
- j.axis.limit.lower
237
- if j.axis is not None
238
- and j.axis.limit is not None
239
- and j.axis.limit.lower is not None
240
- else jnp.finfo(float).min
235
+ (
236
+ float(j.axis.limit.lower)
237
+ if j.axis is not None and j.axis.limit is not None
238
+ else np.finfo(float).min
241
239
  ),
242
- float(
243
- j.axis.limit.upper
244
- if j.axis is not None
245
- and j.axis.limit is not None
246
- and j.axis.limit.upper is not None
247
- else jnp.finfo(float).max
240
+ (
241
+ float(j.axis.limit.upper)
242
+ if j.axis is not None and j.axis.limit is not None
243
+ else np.finfo(float).max
248
244
  ),
249
245
  ),
250
- friction_static=float(
246
+ friction_static=(
251
247
  j.axis.dynamics.friction
252
248
  if j.axis is not None
253
249
  and j.axis.dynamics is not None
254
250
  and j.axis.dynamics.friction is not None
255
251
  else 0.0
256
252
  ),
257
- friction_viscous=float(
253
+ friction_viscous=(
258
254
  j.axis.dynamics.damping
259
255
  if j.axis is not None
260
256
  and j.axis.dynamics is not None
261
257
  and j.axis.dynamics.damping is not None
262
258
  else 0.0
263
259
  ),
264
- position_limit_damper=float(
260
+ position_limit_damper=(
265
261
  j.axis.limit.dissipation
266
262
  if j.axis is not None
267
263
  and j.axis.limit is not None
268
264
  and j.axis.limit.dissipation is not None
269
265
  else 0.0
270
266
  ),
271
- position_limit_spring=float(
267
+ position_limit_spring=(
272
268
  j.axis.limit.stiffness
273
269
  if j.axis is not None
274
270
  and j.axis.limit is not None
@@ -277,7 +273,7 @@ def extract_model_data(
277
273
  ),
278
274
  )
279
275
  for j in sdf_model.joints()
280
- if j.type in {"revolute", "continuous", "prismatic", "fixed"}
276
+ if j.type in {"revolute", "prismatic", "fixed"}
281
277
  and j.parent != "world"
282
278
  and j.child in links_dict.keys()
283
279
  ]
@@ -9,6 +9,7 @@ import jax_dataclasses
9
9
 
10
10
  import jaxsim.api as js
11
11
  import jaxsim.typing as jtp
12
+ from jaxsim import math
12
13
  from jaxsim.api.common import ModelDataWithVelocityRepresentation, VelRepr
13
14
  from jaxsim.terrain import FlatTerrain, Terrain
14
15
 
@@ -271,17 +272,9 @@ class RigidContacts(ContactModel):
271
272
  link_forces=link_forces,
272
273
  )
273
274
 
274
- with (
275
- references.switch_velocity_representation(VelRepr.Mixed),
276
- data.switch_velocity_representation(VelRepr.Mixed),
277
- ):
278
- BW_ν̇_free = jnp.hstack(
279
- js.ode.system_acceleration(
280
- model=model,
281
- data=data,
282
- joint_forces=references.joint_force_references(model=model),
283
- link_forces=references.link_forces(model=model, data=data),
284
- )
275
+ with references.switch_velocity_representation(VelRepr.Mixed):
276
+ BW_ν̇_free = RigidContacts._compute_mixed_nu_dot_free(
277
+ model, data, references=references
285
278
  )
286
279
 
287
280
  free_contact_acc = RigidContacts._linear_acceleration_of_collidable_points(
@@ -387,6 +380,43 @@ class RigidContacts(ContactModel):
387
380
  n_constraints = 6 * n_collidable_points
388
381
  return jnp.zeros(shape=(n_constraints,))
389
382
 
383
+ @staticmethod
384
+ def _compute_mixed_nu_dot_free(
385
+ model: js.model.JaxSimModel,
386
+ data: js.data.JaxSimModelData,
387
+ references: js.references.JaxSimModelReferences | None = None,
388
+ ) -> jtp.Array:
389
+ references = (
390
+ references
391
+ if references is not None
392
+ else js.references.JaxSimModelReferences.zero(model=model, data=data)
393
+ )
394
+
395
+ with (
396
+ data.switch_velocity_representation(VelRepr.Mixed),
397
+ references.switch_velocity_representation(VelRepr.Mixed),
398
+ ):
399
+ BW_v_WB = data.base_velocity()
400
+ W_ṗ_B, W_ω_WB = jnp.split(BW_v_WB, 2)
401
+ W_v̇_WB, s̈ = js.ode.system_acceleration(
402
+ model=model,
403
+ data=data,
404
+ joint_forces=references.joint_force_references(model=model),
405
+ link_forces=references.link_forces(model=model, data=data),
406
+ )
407
+
408
+ # Convert the inertial-fixed base acceleration to a mixed base acceleration.
409
+ W_H_B = data.base_transform()
410
+ W_H_BW = W_H_B.at[0:3, 0:3].set(jnp.eye(3))
411
+ BW_X_W = math.Adjoint.from_transform(W_H_BW, inverse=True)
412
+ term1 = BW_X_W @ W_v̇_WB
413
+ term2 = jnp.zeros(6).at[0:3].set(jnp.cross(W_ṗ_B, W_ω_WB))
414
+ BW_v̇_WB = term1 - term2
415
+
416
+ BW_ν̇ = jnp.hstack([BW_v̇_WB, s̈])
417
+
418
+ return BW_ν̇
419
+
390
420
  @staticmethod
391
421
  def _linear_acceleration_of_collidable_points(
392
422
  model: js.model.JaxSimModel,