jaxsim 0.6.2.dev233__tar.gz → 0.6.2.dev238__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (131):
  1. {jaxsim-0.6.2.dev233 → jaxsim-0.6.2.dev238}/PKG-INFO +1 -1
  2. {jaxsim-0.6.2.dev233 → jaxsim-0.6.2.dev238}/src/jaxsim/_version.py +2 -2
  3. {jaxsim-0.6.2.dev233 → jaxsim-0.6.2.dev238}/src/jaxsim/api/data.py +35 -37
  4. {jaxsim-0.6.2.dev233 → jaxsim-0.6.2.dev238}/src/jaxsim/rbda/contacts/relaxed_rigid.py +10 -2
  5. {jaxsim-0.6.2.dev233 → jaxsim-0.6.2.dev238}/src/jaxsim.egg-info/PKG-INFO +1 -1
  6. {jaxsim-0.6.2.dev233 → jaxsim-0.6.2.dev238}/tests/test_automatic_differentiation.py +1 -10
  7. {jaxsim-0.6.2.dev233 → jaxsim-0.6.2.dev238}/.devcontainer/Dockerfile +0 -0
  8. {jaxsim-0.6.2.dev233 → jaxsim-0.6.2.dev238}/.devcontainer/devcontainer.json +0 -0
  9. {jaxsim-0.6.2.dev233 → jaxsim-0.6.2.dev238}/.gitattributes +0 -0
  10. {jaxsim-0.6.2.dev233 → jaxsim-0.6.2.dev238}/.github/CODEOWNERS +0 -0
  11. {jaxsim-0.6.2.dev233 → jaxsim-0.6.2.dev238}/.github/dependabot.yml +0 -0
  12. {jaxsim-0.6.2.dev233 → jaxsim-0.6.2.dev238}/.github/workflows/ci_cd.yml +0 -0
  13. {jaxsim-0.6.2.dev233 → jaxsim-0.6.2.dev238}/.github/workflows/gpu_benchmark.yml +0 -0
  14. {jaxsim-0.6.2.dev233 → jaxsim-0.6.2.dev238}/.github/workflows/pixi.yml +0 -0
  15. {jaxsim-0.6.2.dev233 → jaxsim-0.6.2.dev238}/.github/workflows/read_the_docs.yml +0 -0
  16. {jaxsim-0.6.2.dev233 → jaxsim-0.6.2.dev238}/.gitignore +0 -0
  17. {jaxsim-0.6.2.dev233 → jaxsim-0.6.2.dev238}/.pre-commit-config.yaml +0 -0
  18. {jaxsim-0.6.2.dev233 → jaxsim-0.6.2.dev238}/.readthedocs.yaml +0 -0
  19. {jaxsim-0.6.2.dev233 → jaxsim-0.6.2.dev238}/CONTRIBUTING.md +0 -0
  20. {jaxsim-0.6.2.dev233 → jaxsim-0.6.2.dev238}/LICENSE +0 -0
  21. {jaxsim-0.6.2.dev233 → jaxsim-0.6.2.dev238}/README.md +0 -0
  22. {jaxsim-0.6.2.dev233 → jaxsim-0.6.2.dev238}/docs/Makefile +0 -0
  23. {jaxsim-0.6.2.dev233 → jaxsim-0.6.2.dev238}/docs/conf.py +0 -0
  24. {jaxsim-0.6.2.dev233 → jaxsim-0.6.2.dev238}/docs/examples.rst +0 -0
  25. {jaxsim-0.6.2.dev233 → jaxsim-0.6.2.dev238}/docs/guide/configuration.rst +0 -0
  26. {jaxsim-0.6.2.dev233 → jaxsim-0.6.2.dev238}/docs/guide/install.rst +0 -0
  27. {jaxsim-0.6.2.dev233 → jaxsim-0.6.2.dev238}/docs/index.rst +0 -0
  28. {jaxsim-0.6.2.dev233 → jaxsim-0.6.2.dev238}/docs/make.bat +0 -0
  29. {jaxsim-0.6.2.dev233 → jaxsim-0.6.2.dev238}/docs/modules/api.rst +0 -0
  30. {jaxsim-0.6.2.dev233 → jaxsim-0.6.2.dev238}/docs/modules/math.rst +0 -0
  31. {jaxsim-0.6.2.dev233 → jaxsim-0.6.2.dev238}/docs/modules/mujoco.rst +0 -0
  32. {jaxsim-0.6.2.dev233 → jaxsim-0.6.2.dev238}/docs/modules/parsers.rst +0 -0
  33. {jaxsim-0.6.2.dev233 → jaxsim-0.6.2.dev238}/docs/modules/rbda.rst +0 -0
  34. {jaxsim-0.6.2.dev233 → jaxsim-0.6.2.dev238}/docs/modules/typing.rst +0 -0
  35. {jaxsim-0.6.2.dev233 → jaxsim-0.6.2.dev238}/docs/modules/utils.rst +0 -0
  36. {jaxsim-0.6.2.dev233 → jaxsim-0.6.2.dev238}/environment.yml +0 -0
  37. {jaxsim-0.6.2.dev233 → jaxsim-0.6.2.dev238}/examples/.gitattributes +0 -0
  38. {jaxsim-0.6.2.dev233 → jaxsim-0.6.2.dev238}/examples/.gitignore +0 -0
  39. {jaxsim-0.6.2.dev233 → jaxsim-0.6.2.dev238}/examples/README.md +0 -0
  40. {jaxsim-0.6.2.dev233 → jaxsim-0.6.2.dev238}/examples/assets/build_cartpole_urdf.py +0 -0
  41. {jaxsim-0.6.2.dev233 → jaxsim-0.6.2.dev238}/examples/assets/cartpole.urdf +0 -0
  42. {jaxsim-0.6.2.dev233 → jaxsim-0.6.2.dev238}/examples/jaxsim_as_multibody_dynamics_library.ipynb +0 -0
  43. {jaxsim-0.6.2.dev233 → jaxsim-0.6.2.dev238}/examples/jaxsim_as_physics_engine.ipynb +0 -0
  44. {jaxsim-0.6.2.dev233 → jaxsim-0.6.2.dev238}/examples/jaxsim_as_physics_engine_advanced.ipynb +0 -0
  45. {jaxsim-0.6.2.dev233 → jaxsim-0.6.2.dev238}/examples/jaxsim_for_robot_controllers.ipynb +0 -0
  46. {jaxsim-0.6.2.dev233 → jaxsim-0.6.2.dev238}/pixi.lock +0 -0
  47. {jaxsim-0.6.2.dev233 → jaxsim-0.6.2.dev238}/pyproject.toml +0 -0
  48. {jaxsim-0.6.2.dev233 → jaxsim-0.6.2.dev238}/setup.cfg +0 -0
  49. {jaxsim-0.6.2.dev233 → jaxsim-0.6.2.dev238}/setup.py +0 -0
  50. {jaxsim-0.6.2.dev233 → jaxsim-0.6.2.dev238}/src/jaxsim/__init__.py +0 -0
  51. {jaxsim-0.6.2.dev233 → jaxsim-0.6.2.dev238}/src/jaxsim/api/__init__.py +0 -0
  52. {jaxsim-0.6.2.dev233 → jaxsim-0.6.2.dev238}/src/jaxsim/api/actuation_model.py +0 -0
  53. {jaxsim-0.6.2.dev233 → jaxsim-0.6.2.dev238}/src/jaxsim/api/com.py +0 -0
  54. {jaxsim-0.6.2.dev233 → jaxsim-0.6.2.dev238}/src/jaxsim/api/common.py +0 -0
  55. {jaxsim-0.6.2.dev233 → jaxsim-0.6.2.dev238}/src/jaxsim/api/contact.py +0 -0
  56. {jaxsim-0.6.2.dev233 → jaxsim-0.6.2.dev238}/src/jaxsim/api/frame.py +0 -0
  57. {jaxsim-0.6.2.dev233 → jaxsim-0.6.2.dev238}/src/jaxsim/api/integrators.py +0 -0
  58. {jaxsim-0.6.2.dev233 → jaxsim-0.6.2.dev238}/src/jaxsim/api/joint.py +0 -0
  59. {jaxsim-0.6.2.dev233 → jaxsim-0.6.2.dev238}/src/jaxsim/api/kin_dyn_parameters.py +0 -0
  60. {jaxsim-0.6.2.dev233 → jaxsim-0.6.2.dev238}/src/jaxsim/api/link.py +0 -0
  61. {jaxsim-0.6.2.dev233 → jaxsim-0.6.2.dev238}/src/jaxsim/api/model.py +0 -0
  62. {jaxsim-0.6.2.dev233 → jaxsim-0.6.2.dev238}/src/jaxsim/api/ode.py +0 -0
  63. {jaxsim-0.6.2.dev233 → jaxsim-0.6.2.dev238}/src/jaxsim/api/references.py +0 -0
  64. {jaxsim-0.6.2.dev233 → jaxsim-0.6.2.dev238}/src/jaxsim/exceptions.py +0 -0
  65. {jaxsim-0.6.2.dev233 → jaxsim-0.6.2.dev238}/src/jaxsim/logging.py +0 -0
  66. {jaxsim-0.6.2.dev233 → jaxsim-0.6.2.dev238}/src/jaxsim/math/__init__.py +0 -0
  67. {jaxsim-0.6.2.dev233 → jaxsim-0.6.2.dev238}/src/jaxsim/math/adjoint.py +0 -0
  68. {jaxsim-0.6.2.dev233 → jaxsim-0.6.2.dev238}/src/jaxsim/math/cross.py +0 -0
  69. {jaxsim-0.6.2.dev233 → jaxsim-0.6.2.dev238}/src/jaxsim/math/inertia.py +0 -0
  70. {jaxsim-0.6.2.dev233 → jaxsim-0.6.2.dev238}/src/jaxsim/math/joint_model.py +0 -0
  71. {jaxsim-0.6.2.dev233 → jaxsim-0.6.2.dev238}/src/jaxsim/math/quaternion.py +0 -0
  72. {jaxsim-0.6.2.dev233 → jaxsim-0.6.2.dev238}/src/jaxsim/math/rotation.py +0 -0
  73. {jaxsim-0.6.2.dev233 → jaxsim-0.6.2.dev238}/src/jaxsim/math/skew.py +0 -0
  74. {jaxsim-0.6.2.dev233 → jaxsim-0.6.2.dev238}/src/jaxsim/math/transform.py +0 -0
  75. {jaxsim-0.6.2.dev233 → jaxsim-0.6.2.dev238}/src/jaxsim/math/utils.py +0 -0
  76. {jaxsim-0.6.2.dev233 → jaxsim-0.6.2.dev238}/src/jaxsim/mujoco/__init__.py +0 -0
  77. {jaxsim-0.6.2.dev233 → jaxsim-0.6.2.dev238}/src/jaxsim/mujoco/__main__.py +0 -0
  78. {jaxsim-0.6.2.dev233 → jaxsim-0.6.2.dev238}/src/jaxsim/mujoco/loaders.py +0 -0
  79. {jaxsim-0.6.2.dev233 → jaxsim-0.6.2.dev238}/src/jaxsim/mujoco/model.py +0 -0
  80. {jaxsim-0.6.2.dev233 → jaxsim-0.6.2.dev238}/src/jaxsim/mujoco/utils.py +0 -0
  81. {jaxsim-0.6.2.dev233 → jaxsim-0.6.2.dev238}/src/jaxsim/mujoco/visualizer.py +0 -0
  82. {jaxsim-0.6.2.dev233 → jaxsim-0.6.2.dev238}/src/jaxsim/parsers/__init__.py +0 -0
  83. {jaxsim-0.6.2.dev233 → jaxsim-0.6.2.dev238}/src/jaxsim/parsers/descriptions/__init__.py +0 -0
  84. {jaxsim-0.6.2.dev233 → jaxsim-0.6.2.dev238}/src/jaxsim/parsers/descriptions/collision.py +0 -0
  85. {jaxsim-0.6.2.dev233 → jaxsim-0.6.2.dev238}/src/jaxsim/parsers/descriptions/joint.py +0 -0
  86. {jaxsim-0.6.2.dev233 → jaxsim-0.6.2.dev238}/src/jaxsim/parsers/descriptions/link.py +0 -0
  87. {jaxsim-0.6.2.dev233 → jaxsim-0.6.2.dev238}/src/jaxsim/parsers/descriptions/model.py +0 -0
  88. {jaxsim-0.6.2.dev233 → jaxsim-0.6.2.dev238}/src/jaxsim/parsers/kinematic_graph.py +0 -0
  89. {jaxsim-0.6.2.dev233 → jaxsim-0.6.2.dev238}/src/jaxsim/parsers/rod/__init__.py +0 -0
  90. {jaxsim-0.6.2.dev233 → jaxsim-0.6.2.dev238}/src/jaxsim/parsers/rod/meshes.py +0 -0
  91. {jaxsim-0.6.2.dev233 → jaxsim-0.6.2.dev238}/src/jaxsim/parsers/rod/parser.py +0 -0
  92. {jaxsim-0.6.2.dev233 → jaxsim-0.6.2.dev238}/src/jaxsim/parsers/rod/utils.py +0 -0
  93. {jaxsim-0.6.2.dev233 → jaxsim-0.6.2.dev238}/src/jaxsim/rbda/__init__.py +0 -0
  94. {jaxsim-0.6.2.dev233 → jaxsim-0.6.2.dev238}/src/jaxsim/rbda/aba.py +0 -0
  95. {jaxsim-0.6.2.dev233 → jaxsim-0.6.2.dev238}/src/jaxsim/rbda/collidable_points.py +0 -0
  96. {jaxsim-0.6.2.dev233 → jaxsim-0.6.2.dev238}/src/jaxsim/rbda/contacts/__init__.py +0 -0
  97. {jaxsim-0.6.2.dev233 → jaxsim-0.6.2.dev238}/src/jaxsim/rbda/contacts/common.py +0 -0
  98. {jaxsim-0.6.2.dev233 → jaxsim-0.6.2.dev238}/src/jaxsim/rbda/contacts/rigid.py +0 -0
  99. {jaxsim-0.6.2.dev233 → jaxsim-0.6.2.dev238}/src/jaxsim/rbda/contacts/soft.py +0 -0
  100. {jaxsim-0.6.2.dev233 → jaxsim-0.6.2.dev238}/src/jaxsim/rbda/crba.py +0 -0
  101. {jaxsim-0.6.2.dev233 → jaxsim-0.6.2.dev238}/src/jaxsim/rbda/forward_kinematics.py +0 -0
  102. {jaxsim-0.6.2.dev233 → jaxsim-0.6.2.dev238}/src/jaxsim/rbda/jacobian.py +0 -0
  103. {jaxsim-0.6.2.dev233 → jaxsim-0.6.2.dev238}/src/jaxsim/rbda/rnea.py +0 -0
  104. {jaxsim-0.6.2.dev233 → jaxsim-0.6.2.dev238}/src/jaxsim/rbda/utils.py +0 -0
  105. {jaxsim-0.6.2.dev233 → jaxsim-0.6.2.dev238}/src/jaxsim/terrain/__init__.py +0 -0
  106. {jaxsim-0.6.2.dev233 → jaxsim-0.6.2.dev238}/src/jaxsim/terrain/terrain.py +0 -0
  107. {jaxsim-0.6.2.dev233 → jaxsim-0.6.2.dev238}/src/jaxsim/typing.py +0 -0
  108. {jaxsim-0.6.2.dev233 → jaxsim-0.6.2.dev238}/src/jaxsim/utils/__init__.py +0 -0
  109. {jaxsim-0.6.2.dev233 → jaxsim-0.6.2.dev238}/src/jaxsim/utils/jaxsim_dataclass.py +0 -0
  110. {jaxsim-0.6.2.dev233 → jaxsim-0.6.2.dev238}/src/jaxsim/utils/tracing.py +0 -0
  111. {jaxsim-0.6.2.dev233 → jaxsim-0.6.2.dev238}/src/jaxsim/utils/wrappers.py +0 -0
  112. {jaxsim-0.6.2.dev233 → jaxsim-0.6.2.dev238}/src/jaxsim.egg-info/SOURCES.txt +0 -0
  113. {jaxsim-0.6.2.dev233 → jaxsim-0.6.2.dev238}/src/jaxsim.egg-info/dependency_links.txt +0 -0
  114. {jaxsim-0.6.2.dev233 → jaxsim-0.6.2.dev238}/src/jaxsim.egg-info/requires.txt +0 -0
  115. {jaxsim-0.6.2.dev233 → jaxsim-0.6.2.dev238}/src/jaxsim.egg-info/top_level.txt +0 -0
  116. {jaxsim-0.6.2.dev233 → jaxsim-0.6.2.dev238}/tests/__init__.py +0 -0
  117. {jaxsim-0.6.2.dev233 → jaxsim-0.6.2.dev238}/tests/conftest.py +0 -0
  118. {jaxsim-0.6.2.dev233 → jaxsim-0.6.2.dev238}/tests/test_api_com.py +0 -0
  119. {jaxsim-0.6.2.dev233 → jaxsim-0.6.2.dev238}/tests/test_api_contact.py +0 -0
  120. {jaxsim-0.6.2.dev233 → jaxsim-0.6.2.dev238}/tests/test_api_data.py +0 -0
  121. {jaxsim-0.6.2.dev233 → jaxsim-0.6.2.dev238}/tests/test_api_frame.py +0 -0
  122. {jaxsim-0.6.2.dev233 → jaxsim-0.6.2.dev238}/tests/test_api_joint.py +0 -0
  123. {jaxsim-0.6.2.dev233 → jaxsim-0.6.2.dev238}/tests/test_api_link.py +0 -0
  124. {jaxsim-0.6.2.dev233 → jaxsim-0.6.2.dev238}/tests/test_api_model.py +0 -0
  125. {jaxsim-0.6.2.dev233 → jaxsim-0.6.2.dev238}/tests/test_benchmark.py +0 -0
  126. {jaxsim-0.6.2.dev233 → jaxsim-0.6.2.dev238}/tests/test_exceptions.py +0 -0
  127. {jaxsim-0.6.2.dev233 → jaxsim-0.6.2.dev238}/tests/test_meshes.py +0 -0
  128. {jaxsim-0.6.2.dev233 → jaxsim-0.6.2.dev238}/tests/test_pytree.py +0 -0
  129. {jaxsim-0.6.2.dev233 → jaxsim-0.6.2.dev238}/tests/test_simulations.py +0 -0
  130. {jaxsim-0.6.2.dev233 → jaxsim-0.6.2.dev238}/tests/test_visualizer.py +0 -0
  131. {jaxsim-0.6.2.dev233 → jaxsim-0.6.2.dev238}/tests/utils_idyntree.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: jaxsim
3
- Version: 0.6.2.dev233
3
+ Version: 0.6.2.dev238
4
4
  Summary: A differentiable physics engine and multibody dynamics library for control and robot learning.
5
5
  Author-email: Diego Ferigo <dgferigo@gmail.com>, Filippo Luca Ferretti <filippoluca.ferretti@outlook.com>
6
6
  Maintainer-email: Filippo Luca Ferretti <filippo.ferretti@iit.it>, Alessandro Croci <alessandro.croci@iit.it>
@@ -17,5 +17,5 @@ __version__: str
17
17
  __version_tuple__: VERSION_TUPLE
18
18
  version_tuple: VERSION_TUPLE
19
19
 
20
- __version__ = version = '0.6.2.dev233'
21
- __version_tuple__ = version_tuple = (0, 6, 2, 'dev233')
20
+ __version__ = version = '0.6.2.dev238'
21
+ __version_tuple__ = version_tuple = (0, 6, 2, 'dev238')
@@ -60,7 +60,7 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
60
60
  _link_velocities: jtp.Matrix = dataclasses.field(repr=False, default=None)
61
61
 
62
62
  # Extended state for soft and rigid contact models.
63
- contact_state: dict[str, jtp.Array] = dataclasses.field(default=None)
63
+ contact_state: dict[str, jtp.Array] = dataclasses.field(default_factory=dict)
64
64
 
65
65
  @staticmethod
66
66
  def build(
@@ -174,7 +174,7 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
174
174
  contact_state = contact_state or {}
175
175
 
176
176
  if isinstance(model.contact_model, jaxsim.rbda.contacts.SoftContacts):
177
- contact_state.setdefault(
177
+ contact_state["tangential_deformation"] = contact_state.get(
178
178
  "tangential_deformation",
179
179
  jnp.zeros_like(model.kin_dyn_parameters.contact_parameters.point),
180
180
  )
@@ -420,11 +420,6 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
420
420
  Replace the attributes of the `JaxSimModelData` object.
421
421
  """
422
422
 
423
- # Extract the batch size.
424
- batch_size = (
425
- self._base_transform.shape[0] if self._base_transform.ndim > 2 else 1
426
- )
427
-
428
423
  if joint_positions is None:
429
424
  joint_positions = self.joint_positions
430
425
  if joint_velocities is None:
@@ -437,10 +432,11 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
437
432
  contact_state = self.contact_state
438
433
 
439
434
  if isinstance(model.contact_model, jaxsim.rbda.contacts.SoftContacts):
440
- contact_state.setdefault(
441
- "tangential_deformation",
442
- jnp.zeros_like(model.kin_dyn_parameters.contact_parameters.point),
443
- )
435
+ contact_state = {
436
+ "tangential_deformation": jnp.zeros_like(
437
+ contact_state["tangential_deformation"]
438
+ )
439
+ }
444
440
 
445
441
  # Normalize the quaternion to avoid numerical issues.
446
442
  base_quaternion_norm = jaxsim.math.safe_norm(
@@ -450,20 +446,18 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
450
446
  base_quaternion_norm == 0, 1.0, base_quaternion_norm
451
447
  )
452
448
 
453
- joint_positions = jnp.atleast_1d(joint_positions.squeeze()).astype(float)
454
- joint_velocities = jnp.atleast_1d(joint_velocities.squeeze()).astype(float)
455
- base_quaternion = jnp.atleast_1d(base_quaternion.squeeze()).astype(float)
456
- base_position = jnp.atleast_1d(base_position.squeeze()).astype(float)
449
+ joint_positions = jnp.atleast_2d(joint_positions.squeeze()).astype(float)
450
+ joint_velocities = jnp.atleast_2d(joint_velocities.squeeze()).astype(float)
451
+ base_quaternion = jnp.atleast_2d(base_quaternion.squeeze()).astype(float)
452
+ base_position = jnp.atleast_2d(base_position.squeeze()).astype(float)
457
453
 
458
454
  base_transform = jaxsim.math.Transform.from_quaternion_and_translation(
459
455
  translation=base_position, quaternion=base_quaternion
460
- )
456
+ ).reshape((-1, 4, 4))
461
457
 
462
458
  joint_transforms = jax.vmap(model.kin_dyn_parameters.joint_transforms)(
463
- joint_positions=jnp.broadcast_to(
464
- joint_positions, (batch_size, model.dofs())
465
- ),
466
- base_transform=jnp.broadcast_to(base_transform, (batch_size, 4, 4)),
459
+ joint_positions=joint_positions,
460
+ base_transform=base_transform,
467
461
  )
468
462
 
469
463
  if base_linear_velocity is None and base_angular_velocity is None:
@@ -494,27 +488,31 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
494
488
  jaxsim.rbda.forward_kinematics_model, in_axes=(None,)
495
489
  )(
496
490
  model,
497
- base_position=jnp.broadcast_to(base_position, (batch_size, 3)),
498
- base_quaternion=jnp.broadcast_to(base_quaternion, (batch_size, 4)),
499
- joint_positions=jnp.broadcast_to(
500
- joint_positions, (batch_size, model.dofs())
501
- ),
502
- joint_velocities=jnp.broadcast_to(
503
- joint_velocities, (batch_size, model.dofs())
504
- ),
505
- base_linear_velocity_inertial=jnp.broadcast_to(
506
- base_linear_velocity_inertial, (batch_size, 3)
507
- ),
508
- base_angular_velocity_inertial=jnp.broadcast_to(
509
- base_angular_velocity_inertial, (batch_size, 3)
491
+ base_position=base_position,
492
+ base_quaternion=base_quaternion,
493
+ joint_positions=joint_positions,
494
+ joint_velocities=joint_velocities,
495
+ base_linear_velocity_inertial=jnp.atleast_2d(base_linear_velocity_inertial),
496
+ base_angular_velocity_inertial=jnp.atleast_2d(
497
+ base_angular_velocity_inertial
510
498
  ),
511
499
  )
512
500
 
513
501
  # Adjust the output shapes.
514
- if batch_size == 1:
515
- link_transforms = link_transforms.reshape(self._link_transforms.shape)
516
- link_velocities = link_velocities.reshape(self._link_velocities.shape)
517
- joint_transforms = joint_transforms.reshape(self._joint_transforms.shape)
502
+ joint_positions = joint_positions.reshape(self._joint_positions.shape)
503
+ joint_velocities = joint_velocities.reshape(self._joint_velocities.shape)
504
+ base_quaternion = base_quaternion.reshape(self._base_quaternion.shape)
505
+ base_linear_velocity_inertial = base_linear_velocity_inertial.reshape(
506
+ self._base_linear_velocity.shape
507
+ )
508
+ base_angular_velocity_inertial = base_angular_velocity_inertial.reshape(
509
+ self._base_angular_velocity.shape
510
+ )
511
+ base_position = base_position.reshape(self._base_position.shape)
512
+ base_transform = base_transform.reshape(self._base_transform.shape)
513
+ joint_transforms = joint_transforms.reshape(self._joint_transforms.shape)
514
+ link_transforms = link_transforms.reshape(self._link_transforms.shape)
515
+ link_velocities = link_velocities.reshape(self._link_velocities.shape)
518
516
 
519
517
  return super().replace(
520
518
  _joint_positions=joint_positions,
@@ -477,8 +477,7 @@ class RelaxedRigidContacts(common.ContactModel):
477
477
  tol = solver_options.pop("tol")
478
478
  maxiter = solver_options.pop("maxiter")
479
479
 
480
- # Compute the 3D linear force in C[W] frame.
481
- solution, _ = run_optimization(
480
+ solve_fn = lambda *_: run_optimization(
482
481
  init_params=init_params,
483
482
  fun=objective,
484
483
  opt=optax.lbfgs(**solver_options),
@@ -486,6 +485,15 @@ class RelaxedRigidContacts(common.ContactModel):
486
485
  maxiter=maxiter,
487
486
  )
488
487
 
488
+ # Compute the 3D linear force in C[W] frame.
489
+ solution, _ = jax.lax.custom_linear_solve(
490
+ lambda x: A @ x,
491
+ -b,
492
+ solve=solve_fn,
493
+ symmetric=True,
494
+ has_aux=True,
495
+ )
496
+
489
497
  # Reshape the optimized solution to be a matrix of 3D contact forces.
490
498
  CW_fl_C = solution.reshape(-1, 3)
491
499
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: jaxsim
3
- Version: 0.6.2.dev233
3
+ Version: 0.6.2.dev238
4
4
  Summary: A differentiable physics engine and multibody dynamics library for control and robot learning.
5
5
  Author-email: Diego Ferigo <dgferigo@gmail.com>, Filippo Luca Ferretti <filippoluca.ferretti@outlook.com>
6
6
  Maintainer-email: Filippo Luca Ferretti <filippo.ferretti@iit.it>, Alessandro Croci <alessandro.croci@iit.it>
@@ -3,8 +3,6 @@ import os
3
3
  import jax
4
4
  import jax.numpy as jnp
5
5
  import numpy as np
6
- import optax
7
- import pytest
8
6
  from jax.test_util import check_grads
9
7
 
10
8
  import jaxsim.api as js
@@ -350,11 +348,6 @@ def test_ad_integration(
350
348
 
351
349
  model = jaxsim_models_types
352
350
 
353
- # TODO: Remove when https://github.com/google-deepmind/optax/pull/1190 is included in a release.
354
- # Skip if `optax` version is less or equal to "0.2.4" and the model is ergoCub.
355
- if model.name() == "ergoCub" and optax.__version__ <= "0.2.4":
356
- pytest.skip("Skipping ergoCub model with optax version <= 0.2.4.")
357
-
358
351
  _, subkey = jax.random.split(prng_key, num=2)
359
352
  data, references = get_random_data_and_references(
360
353
  model=model, velocity_representation=VelRepr.Inertial, key=subkey
@@ -416,13 +409,11 @@ def test_ad_integration(
416
409
  return xf_W_p_B, xf_W_Q_B, xf_s, xf_W_v_WB, xf_ṡ
417
410
 
418
411
  # Check derivatives against finite differences.
419
- # We set forward mode only because the backward mode is not supported by the
420
- # current implementation of `optax` optimizers in the relaxed rigid contact model.
421
412
  check_grads(
422
413
  f=step,
423
414
  args=(W_p_B, W_Q_B, s, W_v_WB, ṡ, τ, W_f_L),
424
415
  order=AD_ORDER,
425
- modes=["fwd"],
416
+ modes=["fwd", "rev"],
426
417
  eps=ε,
427
418
  )
428
419
 
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes