jaxsim 0.5.1.dev133__tar.gz → 0.5.1.dev139__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (136) hide show
  1. {jaxsim-0.5.1.dev133 → jaxsim-0.5.1.dev139}/PKG-INFO +1 -1
  2. {jaxsim-0.5.1.dev133 → jaxsim-0.5.1.dev139}/src/jaxsim/__init__.py +0 -7
  3. {jaxsim-0.5.1.dev133 → jaxsim-0.5.1.dev139}/src/jaxsim/_version.py +2 -2
  4. {jaxsim-0.5.1.dev133 → jaxsim-0.5.1.dev139}/src/jaxsim/integrators/common.py +12 -9
  5. jaxsim-0.5.1.dev139/src/jaxsim/integrators/fixed_step.py +153 -0
  6. {jaxsim-0.5.1.dev133 → jaxsim-0.5.1.dev139}/src/jaxsim/integrators/variable_step.py +73 -46
  7. {jaxsim-0.5.1.dev133 → jaxsim-0.5.1.dev139}/src/jaxsim/math/adjoint.py +17 -11
  8. {jaxsim-0.5.1.dev133 → jaxsim-0.5.1.dev139}/src/jaxsim/math/transform.py +9 -4
  9. {jaxsim-0.5.1.dev133 → jaxsim-0.5.1.dev139}/src/jaxsim.egg-info/PKG-INFO +1 -1
  10. jaxsim-0.5.1.dev133/src/jaxsim/integrators/fixed_step.py +0 -123
  11. {jaxsim-0.5.1.dev133 → jaxsim-0.5.1.dev139}/.devcontainer/Dockerfile +0 -0
  12. {jaxsim-0.5.1.dev133 → jaxsim-0.5.1.dev139}/.devcontainer/devcontainer.json +0 -0
  13. {jaxsim-0.5.1.dev133 → jaxsim-0.5.1.dev139}/.gitattributes +0 -0
  14. {jaxsim-0.5.1.dev133 → jaxsim-0.5.1.dev139}/.github/CODEOWNERS +0 -0
  15. {jaxsim-0.5.1.dev133 → jaxsim-0.5.1.dev139}/.github/dependabot.yml +0 -0
  16. {jaxsim-0.5.1.dev133 → jaxsim-0.5.1.dev139}/.github/workflows/ci_cd.yml +0 -0
  17. {jaxsim-0.5.1.dev133 → jaxsim-0.5.1.dev139}/.github/workflows/pixi.yml +0 -0
  18. {jaxsim-0.5.1.dev133 → jaxsim-0.5.1.dev139}/.github/workflows/read_the_docs.yml +0 -0
  19. {jaxsim-0.5.1.dev133 → jaxsim-0.5.1.dev139}/.gitignore +0 -0
  20. {jaxsim-0.5.1.dev133 → jaxsim-0.5.1.dev139}/.pre-commit-config.yaml +0 -0
  21. {jaxsim-0.5.1.dev133 → jaxsim-0.5.1.dev139}/.readthedocs.yaml +0 -0
  22. {jaxsim-0.5.1.dev133 → jaxsim-0.5.1.dev139}/CONTRIBUTING.md +0 -0
  23. {jaxsim-0.5.1.dev133 → jaxsim-0.5.1.dev139}/LICENSE +0 -0
  24. {jaxsim-0.5.1.dev133 → jaxsim-0.5.1.dev139}/README.md +0 -0
  25. {jaxsim-0.5.1.dev133 → jaxsim-0.5.1.dev139}/docs/Makefile +0 -0
  26. {jaxsim-0.5.1.dev133 → jaxsim-0.5.1.dev139}/docs/conf.py +0 -0
  27. {jaxsim-0.5.1.dev133 → jaxsim-0.5.1.dev139}/docs/examples.rst +0 -0
  28. {jaxsim-0.5.1.dev133 → jaxsim-0.5.1.dev139}/docs/guide/configuration.rst +0 -0
  29. {jaxsim-0.5.1.dev133 → jaxsim-0.5.1.dev139}/docs/guide/install.rst +0 -0
  30. {jaxsim-0.5.1.dev133 → jaxsim-0.5.1.dev139}/docs/index.rst +0 -0
  31. {jaxsim-0.5.1.dev133 → jaxsim-0.5.1.dev139}/docs/make.bat +0 -0
  32. {jaxsim-0.5.1.dev133 → jaxsim-0.5.1.dev139}/docs/modules/api.rst +0 -0
  33. {jaxsim-0.5.1.dev133 → jaxsim-0.5.1.dev139}/docs/modules/integrators.rst +0 -0
  34. {jaxsim-0.5.1.dev133 → jaxsim-0.5.1.dev139}/docs/modules/math.rst +0 -0
  35. {jaxsim-0.5.1.dev133 → jaxsim-0.5.1.dev139}/docs/modules/mujoco.rst +0 -0
  36. {jaxsim-0.5.1.dev133 → jaxsim-0.5.1.dev139}/docs/modules/parsers.rst +0 -0
  37. {jaxsim-0.5.1.dev133 → jaxsim-0.5.1.dev139}/docs/modules/rbda.rst +0 -0
  38. {jaxsim-0.5.1.dev133 → jaxsim-0.5.1.dev139}/docs/modules/typing.rst +0 -0
  39. {jaxsim-0.5.1.dev133 → jaxsim-0.5.1.dev139}/docs/modules/utils.rst +0 -0
  40. {jaxsim-0.5.1.dev133 → jaxsim-0.5.1.dev139}/environment.yml +0 -0
  41. {jaxsim-0.5.1.dev133 → jaxsim-0.5.1.dev139}/examples/.gitattributes +0 -0
  42. {jaxsim-0.5.1.dev133 → jaxsim-0.5.1.dev139}/examples/.gitignore +0 -0
  43. {jaxsim-0.5.1.dev133 → jaxsim-0.5.1.dev139}/examples/README.md +0 -0
  44. {jaxsim-0.5.1.dev133 → jaxsim-0.5.1.dev139}/examples/assets/build_cartpole_urdf.py +0 -0
  45. {jaxsim-0.5.1.dev133 → jaxsim-0.5.1.dev139}/examples/assets/cartpole.urdf +0 -0
  46. {jaxsim-0.5.1.dev133 → jaxsim-0.5.1.dev139}/examples/jaxsim_as_multibody_dynamics_library.ipynb +0 -0
  47. {jaxsim-0.5.1.dev133 → jaxsim-0.5.1.dev139}/examples/jaxsim_as_physics_engine.ipynb +0 -0
  48. {jaxsim-0.5.1.dev133 → jaxsim-0.5.1.dev139}/examples/jaxsim_as_physics_engine_advanced.ipynb +0 -0
  49. {jaxsim-0.5.1.dev133 → jaxsim-0.5.1.dev139}/examples/jaxsim_for_robot_controllers.ipynb +0 -0
  50. {jaxsim-0.5.1.dev133 → jaxsim-0.5.1.dev139}/pixi.lock +0 -0
  51. {jaxsim-0.5.1.dev133 → jaxsim-0.5.1.dev139}/pyproject.toml +0 -0
  52. {jaxsim-0.5.1.dev133 → jaxsim-0.5.1.dev139}/setup.cfg +0 -0
  53. {jaxsim-0.5.1.dev133 → jaxsim-0.5.1.dev139}/setup.py +0 -0
  54. {jaxsim-0.5.1.dev133 → jaxsim-0.5.1.dev139}/src/jaxsim/api/__init__.py +0 -0
  55. {jaxsim-0.5.1.dev133 → jaxsim-0.5.1.dev139}/src/jaxsim/api/com.py +0 -0
  56. {jaxsim-0.5.1.dev133 → jaxsim-0.5.1.dev139}/src/jaxsim/api/common.py +0 -0
  57. {jaxsim-0.5.1.dev133 → jaxsim-0.5.1.dev139}/src/jaxsim/api/contact.py +0 -0
  58. {jaxsim-0.5.1.dev133 → jaxsim-0.5.1.dev139}/src/jaxsim/api/data.py +0 -0
  59. {jaxsim-0.5.1.dev133 → jaxsim-0.5.1.dev139}/src/jaxsim/api/frame.py +0 -0
  60. {jaxsim-0.5.1.dev133 → jaxsim-0.5.1.dev139}/src/jaxsim/api/joint.py +0 -0
  61. {jaxsim-0.5.1.dev133 → jaxsim-0.5.1.dev139}/src/jaxsim/api/kin_dyn_parameters.py +0 -0
  62. {jaxsim-0.5.1.dev133 → jaxsim-0.5.1.dev139}/src/jaxsim/api/link.py +0 -0
  63. {jaxsim-0.5.1.dev133 → jaxsim-0.5.1.dev139}/src/jaxsim/api/model.py +0 -0
  64. {jaxsim-0.5.1.dev133 → jaxsim-0.5.1.dev139}/src/jaxsim/api/ode.py +0 -0
  65. {jaxsim-0.5.1.dev133 → jaxsim-0.5.1.dev139}/src/jaxsim/api/ode_data.py +0 -0
  66. {jaxsim-0.5.1.dev133 → jaxsim-0.5.1.dev139}/src/jaxsim/api/references.py +0 -0
  67. {jaxsim-0.5.1.dev133 → jaxsim-0.5.1.dev139}/src/jaxsim/exceptions.py +0 -0
  68. {jaxsim-0.5.1.dev133 → jaxsim-0.5.1.dev139}/src/jaxsim/integrators/__init__.py +0 -0
  69. {jaxsim-0.5.1.dev133 → jaxsim-0.5.1.dev139}/src/jaxsim/logging.py +0 -0
  70. {jaxsim-0.5.1.dev133 → jaxsim-0.5.1.dev139}/src/jaxsim/math/__init__.py +0 -0
  71. {jaxsim-0.5.1.dev133 → jaxsim-0.5.1.dev139}/src/jaxsim/math/cross.py +0 -0
  72. {jaxsim-0.5.1.dev133 → jaxsim-0.5.1.dev139}/src/jaxsim/math/inertia.py +0 -0
  73. {jaxsim-0.5.1.dev133 → jaxsim-0.5.1.dev139}/src/jaxsim/math/joint_model.py +0 -0
  74. {jaxsim-0.5.1.dev133 → jaxsim-0.5.1.dev139}/src/jaxsim/math/quaternion.py +0 -0
  75. {jaxsim-0.5.1.dev133 → jaxsim-0.5.1.dev139}/src/jaxsim/math/rotation.py +0 -0
  76. {jaxsim-0.5.1.dev133 → jaxsim-0.5.1.dev139}/src/jaxsim/math/skew.py +0 -0
  77. {jaxsim-0.5.1.dev133 → jaxsim-0.5.1.dev139}/src/jaxsim/math/utils.py +0 -0
  78. {jaxsim-0.5.1.dev133 → jaxsim-0.5.1.dev139}/src/jaxsim/mujoco/__init__.py +0 -0
  79. {jaxsim-0.5.1.dev133 → jaxsim-0.5.1.dev139}/src/jaxsim/mujoco/__main__.py +0 -0
  80. {jaxsim-0.5.1.dev133 → jaxsim-0.5.1.dev139}/src/jaxsim/mujoco/loaders.py +0 -0
  81. {jaxsim-0.5.1.dev133 → jaxsim-0.5.1.dev139}/src/jaxsim/mujoco/model.py +0 -0
  82. {jaxsim-0.5.1.dev133 → jaxsim-0.5.1.dev139}/src/jaxsim/mujoco/utils.py +0 -0
  83. {jaxsim-0.5.1.dev133 → jaxsim-0.5.1.dev139}/src/jaxsim/mujoco/visualizer.py +0 -0
  84. {jaxsim-0.5.1.dev133 → jaxsim-0.5.1.dev139}/src/jaxsim/parsers/__init__.py +0 -0
  85. {jaxsim-0.5.1.dev133 → jaxsim-0.5.1.dev139}/src/jaxsim/parsers/descriptions/__init__.py +0 -0
  86. {jaxsim-0.5.1.dev133 → jaxsim-0.5.1.dev139}/src/jaxsim/parsers/descriptions/collision.py +0 -0
  87. {jaxsim-0.5.1.dev133 → jaxsim-0.5.1.dev139}/src/jaxsim/parsers/descriptions/joint.py +0 -0
  88. {jaxsim-0.5.1.dev133 → jaxsim-0.5.1.dev139}/src/jaxsim/parsers/descriptions/link.py +0 -0
  89. {jaxsim-0.5.1.dev133 → jaxsim-0.5.1.dev139}/src/jaxsim/parsers/descriptions/model.py +0 -0
  90. {jaxsim-0.5.1.dev133 → jaxsim-0.5.1.dev139}/src/jaxsim/parsers/kinematic_graph.py +0 -0
  91. {jaxsim-0.5.1.dev133 → jaxsim-0.5.1.dev139}/src/jaxsim/parsers/rod/__init__.py +0 -0
  92. {jaxsim-0.5.1.dev133 → jaxsim-0.5.1.dev139}/src/jaxsim/parsers/rod/meshes.py +0 -0
  93. {jaxsim-0.5.1.dev133 → jaxsim-0.5.1.dev139}/src/jaxsim/parsers/rod/parser.py +0 -0
  94. {jaxsim-0.5.1.dev133 → jaxsim-0.5.1.dev139}/src/jaxsim/parsers/rod/utils.py +0 -0
  95. {jaxsim-0.5.1.dev133 → jaxsim-0.5.1.dev139}/src/jaxsim/rbda/__init__.py +0 -0
  96. {jaxsim-0.5.1.dev133 → jaxsim-0.5.1.dev139}/src/jaxsim/rbda/aba.py +0 -0
  97. {jaxsim-0.5.1.dev133 → jaxsim-0.5.1.dev139}/src/jaxsim/rbda/collidable_points.py +0 -0
  98. {jaxsim-0.5.1.dev133 → jaxsim-0.5.1.dev139}/src/jaxsim/rbda/contacts/__init__.py +0 -0
  99. {jaxsim-0.5.1.dev133 → jaxsim-0.5.1.dev139}/src/jaxsim/rbda/contacts/common.py +0 -0
  100. {jaxsim-0.5.1.dev133 → jaxsim-0.5.1.dev139}/src/jaxsim/rbda/contacts/relaxed_rigid.py +0 -0
  101. {jaxsim-0.5.1.dev133 → jaxsim-0.5.1.dev139}/src/jaxsim/rbda/contacts/rigid.py +0 -0
  102. {jaxsim-0.5.1.dev133 → jaxsim-0.5.1.dev139}/src/jaxsim/rbda/contacts/soft.py +0 -0
  103. {jaxsim-0.5.1.dev133 → jaxsim-0.5.1.dev139}/src/jaxsim/rbda/contacts/visco_elastic.py +0 -0
  104. {jaxsim-0.5.1.dev133 → jaxsim-0.5.1.dev139}/src/jaxsim/rbda/crba.py +0 -0
  105. {jaxsim-0.5.1.dev133 → jaxsim-0.5.1.dev139}/src/jaxsim/rbda/forward_kinematics.py +0 -0
  106. {jaxsim-0.5.1.dev133 → jaxsim-0.5.1.dev139}/src/jaxsim/rbda/jacobian.py +0 -0
  107. {jaxsim-0.5.1.dev133 → jaxsim-0.5.1.dev139}/src/jaxsim/rbda/rnea.py +0 -0
  108. {jaxsim-0.5.1.dev133 → jaxsim-0.5.1.dev139}/src/jaxsim/rbda/utils.py +0 -0
  109. {jaxsim-0.5.1.dev133 → jaxsim-0.5.1.dev139}/src/jaxsim/terrain/__init__.py +0 -0
  110. {jaxsim-0.5.1.dev133 → jaxsim-0.5.1.dev139}/src/jaxsim/terrain/terrain.py +0 -0
  111. {jaxsim-0.5.1.dev133 → jaxsim-0.5.1.dev139}/src/jaxsim/typing.py +0 -0
  112. {jaxsim-0.5.1.dev133 → jaxsim-0.5.1.dev139}/src/jaxsim/utils/__init__.py +0 -0
  113. {jaxsim-0.5.1.dev133 → jaxsim-0.5.1.dev139}/src/jaxsim/utils/jaxsim_dataclass.py +0 -0
  114. {jaxsim-0.5.1.dev133 → jaxsim-0.5.1.dev139}/src/jaxsim/utils/tracing.py +0 -0
  115. {jaxsim-0.5.1.dev133 → jaxsim-0.5.1.dev139}/src/jaxsim/utils/wrappers.py +0 -0
  116. {jaxsim-0.5.1.dev133 → jaxsim-0.5.1.dev139}/src/jaxsim.egg-info/SOURCES.txt +0 -0
  117. {jaxsim-0.5.1.dev133 → jaxsim-0.5.1.dev139}/src/jaxsim.egg-info/dependency_links.txt +0 -0
  118. {jaxsim-0.5.1.dev133 → jaxsim-0.5.1.dev139}/src/jaxsim.egg-info/requires.txt +0 -0
  119. {jaxsim-0.5.1.dev133 → jaxsim-0.5.1.dev139}/src/jaxsim.egg-info/top_level.txt +0 -0
  120. {jaxsim-0.5.1.dev133 → jaxsim-0.5.1.dev139}/tests/__init__.py +0 -0
  121. {jaxsim-0.5.1.dev133 → jaxsim-0.5.1.dev139}/tests/conftest.py +0 -0
  122. {jaxsim-0.5.1.dev133 → jaxsim-0.5.1.dev139}/tests/test_api_com.py +0 -0
  123. {jaxsim-0.5.1.dev133 → jaxsim-0.5.1.dev139}/tests/test_api_contact.py +0 -0
  124. {jaxsim-0.5.1.dev133 → jaxsim-0.5.1.dev139}/tests/test_api_data.py +0 -0
  125. {jaxsim-0.5.1.dev133 → jaxsim-0.5.1.dev139}/tests/test_api_frame.py +0 -0
  126. {jaxsim-0.5.1.dev133 → jaxsim-0.5.1.dev139}/tests/test_api_joint.py +0 -0
  127. {jaxsim-0.5.1.dev133 → jaxsim-0.5.1.dev139}/tests/test_api_link.py +0 -0
  128. {jaxsim-0.5.1.dev133 → jaxsim-0.5.1.dev139}/tests/test_api_model.py +0 -0
  129. {jaxsim-0.5.1.dev133 → jaxsim-0.5.1.dev139}/tests/test_automatic_differentiation.py +0 -0
  130. {jaxsim-0.5.1.dev133 → jaxsim-0.5.1.dev139}/tests/test_benchmark.py +0 -0
  131. {jaxsim-0.5.1.dev133 → jaxsim-0.5.1.dev139}/tests/test_contact.py +0 -0
  132. {jaxsim-0.5.1.dev133 → jaxsim-0.5.1.dev139}/tests/test_exceptions.py +0 -0
  133. {jaxsim-0.5.1.dev133 → jaxsim-0.5.1.dev139}/tests/test_meshes.py +0 -0
  134. {jaxsim-0.5.1.dev133 → jaxsim-0.5.1.dev139}/tests/test_pytree.py +0 -0
  135. {jaxsim-0.5.1.dev133 → jaxsim-0.5.1.dev139}/tests/test_simulations.py +0 -0
  136. {jaxsim-0.5.1.dev133 → jaxsim-0.5.1.dev139}/tests/utils_idyntree.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: jaxsim
3
- Version: 0.5.1.dev133
3
+ Version: 0.5.1.dev139
4
4
  Summary: A differentiable physics engine and multibody dynamics library for control and robot learning.
5
5
  Author-email: Diego Ferigo <dgferigo@gmail.com>, Filippo Luca Ferretti <filippoluca.ferretti@outlook.com>
6
6
  Maintainer-email: Filippo Luca Ferretti <filippo.ferretti@iit.it>, Alessandro Croci <alessandro.croci@iit.it>
@@ -34,13 +34,6 @@ def _jnp_options() -> None:
34
34
  logging.info("Enabling JAX to use 64-bit precision")
35
35
  jax.config.update("jax_enable_x64", True)
36
36
 
37
- import jax.numpy as jnp
38
- import numpy as np
39
-
40
- # Verify that 64-bit precision is correctly set.
41
- if jnp.empty(0, dtype=float).dtype != jnp.empty(0, dtype=np.float64).dtype:
42
- logging.warning("Failed to enable 64-bit precision in JAX")
43
-
44
37
  # Warn about experimental usage of 32-bit precision.
45
38
  else:
46
39
  logging.warning(
@@ -12,5 +12,5 @@ __version__: str
12
12
  __version_tuple__: VERSION_TUPLE
13
13
  version_tuple: VERSION_TUPLE
14
14
 
15
- __version__ = version = '0.5.1.dev133'
16
- __version_tuple__ = version_tuple = (0, 5, 1, 'dev133')
15
+ __version__ = version = '0.5.1.dev139'
16
+ __version_tuple__ = version_tuple = (0, 5, 1, 'dev139')
@@ -170,14 +170,14 @@ class ExplicitRungeKutta(Integrator[PyTreeType, PyTreeType], Generic[PyTreeType]
170
170
  """
171
171
 
172
172
  # The Runge-Kutta matrix.
173
- A: ClassVar[jtp.Matrix]
173
+ A: jtp.Matrix
174
174
 
175
175
  # The weights coefficients.
176
176
  # Note that in practice we typically use its transpose `b.transpose()`.
177
- b: ClassVar[jtp.Matrix]
177
+ b: jtp.Matrix
178
178
 
179
179
  # The nodes coefficients.
180
- c: ClassVar[jtp.Vector]
180
+ c: jtp.Vector
181
181
 
182
182
  # Define the order of the solution.
183
183
  # It should have as many elements as the number of rows of `b.transpose()`.
@@ -226,28 +226,31 @@ class ExplicitRungeKutta(Integrator[PyTreeType, PyTreeType], Generic[PyTreeType]
226
226
  Returns:
227
227
  The integrator object.
228
228
  """
229
+ A = cls.__dataclass_fields__["A"].default_factory()
230
+ b = cls.__dataclass_fields__["b"].default_factory()
231
+ c = cls.__dataclass_fields__["c"].default_factory()
229
232
 
230
233
  # Check validity of the Butcher tableau.
231
- if not ExplicitRungeKutta.butcher_tableau_is_valid(A=cls.A, b=cls.b, c=cls.c):
234
+ if not ExplicitRungeKutta.butcher_tableau_is_valid(A=A, b=b, c=c):
232
235
  raise ValueError("The Butcher tableau of this class is not valid.")
233
236
 
234
237
  # Check that b.T has enough rows based on the configured index of the solution.
235
- if cls.row_index_of_solution >= cls.b.T.shape[0]:
238
+ if cls.row_index_of_solution >= b.T.shape[0]:
236
239
  msg = "The index of the solution ({}-th row of `b.T`) is out of range ({})."
237
- raise ValueError(msg.format(cls.row_index_of_solution, cls.b.T.shape[0]))
240
+ raise ValueError(msg.format(cls.row_index_of_solution, b.T.shape[0]))
238
241
 
239
242
  # Check that the tuple containing the order of the b.T rows matches the number
240
243
  # of the b.T rows.
241
- if len(cls.order_of_bT_rows) != cls.b.T.shape[0]:
244
+ if len(cls.order_of_bT_rows) != b.T.shape[0]:
242
245
  msg = "Wrong size of 'order_of_bT_rows' ({}), should be {}."
243
- raise ValueError(msg.format(len(cls.order_of_bT_rows), cls.b.T.shape[0]))
246
+ raise ValueError(msg.format(len(cls.order_of_bT_rows), b.T.shape[0]))
244
247
 
245
248
  # Check if the Butcher tableau supports FSAL (first-same-as-last).
246
249
  # If it does, store the index of the intermediate derivative to be used as the
247
250
  # first derivative of the next iteration.
248
251
  has_fsal, index_of_fsal = ( # noqa: F841
249
252
  ExplicitRungeKutta.butcher_tableau_supports_fsal(
250
- A=cls.A, b=cls.b, c=cls.c, index_of_solution=cls.row_index_of_solution
253
+ A=A, b=b, c=c, index_of_solution=cls.row_index_of_solution
251
254
  )
252
255
  )
253
256
 
@@ -0,0 +1,153 @@
1
+ import dataclasses
2
+ from typing import ClassVar, Generic
3
+
4
+ import jax.numpy as jnp
5
+ import jax_dataclasses
6
+
7
+ import jaxsim.api as js
8
+ import jaxsim.typing as jtp
9
+
10
+ from .common import ExplicitRungeKutta, ExplicitRungeKuttaSO3Mixin, PyTreeType
11
+
12
+ ODEStateDerivative = js.ode_data.ODEState
13
+
14
+ # =====================================================
15
+ # Explicit Runge-Kutta integrators operating on PyTrees
16
+ # =====================================================
17
+
18
+
19
+ @jax_dataclasses.pytree_dataclass
20
+ class ForwardEuler(ExplicitRungeKutta[PyTreeType], Generic[PyTreeType]):
21
+ """
22
+ Forward Euler integrator.
23
+ """
24
+
25
+ A: jtp.Matrix = dataclasses.field(
26
+ default_factory=lambda: jnp.atleast_2d(0).astype(float), compare=False
27
+ )
28
+ b: jtp.Matrix = dataclasses.field(
29
+ default_factory=lambda: jnp.atleast_2d(1).astype(float), compare=False
30
+ )
31
+
32
+ c: jtp.Vector = dataclasses.field(
33
+ default_factory=lambda: jnp.atleast_1d(0).astype(float), compare=False
34
+ )
35
+
36
+ row_index_of_solution: int = 0
37
+ order_of_bT_rows: tuple[int, ...] = (1,)
38
+ index_of_fsal: jtp.IntLike | None = None
39
+ fsal_enabled_if_supported: bool = False
40
+
41
+
42
+ @jax_dataclasses.pytree_dataclass
43
+ class Heun2(ExplicitRungeKutta[PyTreeType], Generic[PyTreeType]):
44
+ """
45
+ Heun's second-order integrator.
46
+ """
47
+
48
+ A: jtp.Matrix = dataclasses.field(
49
+ default_factory=lambda: jnp.array(
50
+ [
51
+ [0, 0],
52
+ [1, 0],
53
+ ]
54
+ ).astype(float),
55
+ compare=False,
56
+ )
57
+
58
+ b: jtp.Matrix = dataclasses.field(
59
+ default_factory=lambda: (
60
+ jnp.atleast_2d(
61
+ jnp.array([1 / 2, 1 / 2]),
62
+ )
63
+ .astype(float)
64
+ .transpose()
65
+ ),
66
+ compare=False,
67
+ )
68
+
69
+ c: jtp.Vector = dataclasses.field(
70
+ default_factory=lambda: jnp.array(
71
+ [0, 1],
72
+ ).astype(float),
73
+ compare=False,
74
+ )
75
+
76
+ row_index_of_solution: ClassVar[int] = 0
77
+ order_of_bT_rows: ClassVar[tuple[int, ...]] = (2,)
78
+ index_of_fsal: jtp.IntLike | None = None
79
+ fsal_enabled_if_supported: bool = False
80
+
81
+
82
+ @jax_dataclasses.pytree_dataclass
83
+ class RungeKutta4(ExplicitRungeKutta[PyTreeType], Generic[PyTreeType]):
84
+ """
85
+ Fourth-order Runge-Kutta integrator.
86
+ """
87
+
88
+ A: jtp.Matrix = dataclasses.field(
89
+ default_factory=lambda: jnp.array(
90
+ [
91
+ [0, 0, 0, 0],
92
+ [1 / 2, 0, 0, 0],
93
+ [0, 1 / 2, 0, 0],
94
+ [0, 0, 1, 0],
95
+ ]
96
+ ).astype(float),
97
+ compare=False,
98
+ )
99
+
100
+ b: jtp.Matrix = dataclasses.field(
101
+ default_factory=lambda: (
102
+ jnp.atleast_2d(
103
+ jnp.array([1 / 6, 1 / 3, 1 / 3, 1 / 6]),
104
+ )
105
+ .astype(float)
106
+ .transpose()
107
+ ),
108
+ compare=False,
109
+ )
110
+
111
+ c: jtp.Vector = dataclasses.field(
112
+ default_factory=lambda: jnp.array(
113
+ [0, 1 / 2, 1 / 2, 1],
114
+ ).astype(float),
115
+ compare=False,
116
+ )
117
+
118
+ row_index_of_solution: ClassVar[int] = 0
119
+ order_of_bT_rows: ClassVar[tuple[int, ...]] = (4,)
120
+ index_of_fsal: jtp.IntLike | None = None
121
+ fsal_enabled_if_supported: bool = False
122
+
123
+
124
+ # ===============================================================================
125
+ # Explicit Runge-Kutta integrators operating on ODEState and integrating on SO(3)
126
+ # ===============================================================================
127
+
128
+
129
+ @jax_dataclasses.pytree_dataclass
130
+ class ForwardEulerSO3(ExplicitRungeKuttaSO3Mixin, ForwardEuler[js.ode_data.ODEState]):
131
+ """
132
+ Forward Euler integrator for SO(3) states.
133
+ """
134
+
135
+ pass
136
+
137
+
138
+ @jax_dataclasses.pytree_dataclass
139
+ class Heun2SO3(ExplicitRungeKuttaSO3Mixin, Heun2[js.ode_data.ODEState]):
140
+ """
141
+ Heun's second-order integrator for SO(3) states.
142
+ """
143
+
144
+ pass
145
+
146
+
147
+ @jax_dataclasses.pytree_dataclass
148
+ class RungeKutta4SO3(ExplicitRungeKuttaSO3Mixin, RungeKutta4[js.ode_data.ODEState]):
149
+ """
150
+ Fourth-order Runge-Kutta integrator for SO(3) states.
151
+ """
152
+
153
+ pass
@@ -1,3 +1,4 @@
1
+ import dataclasses
1
2
  import functools
2
3
  from typing import Any, ClassVar, Generic
3
4
 
@@ -254,6 +255,9 @@ class EmbeddedRungeKutta(ExplicitRungeKutta[PyTreeType], Generic[PyTreeType]):
254
255
  # Maximum number of rejected steps when the Δt needs to be reduced.
255
256
  max_step_rejections: Static[jtp.IntLike] = MAX_STEP_REJECTIONS_DEFAULT
256
257
 
258
+ index_of_fsal: jtp.IntLike | None = None
259
+ fsal_enabled_if_supported: bool = False
260
+
257
261
  def init(
258
262
  self,
259
263
  x0: State,
@@ -573,16 +577,18 @@ class EmbeddedRungeKutta(ExplicitRungeKutta[PyTreeType], Generic[PyTreeType]):
573
577
  **kwargs: Additional parameters.
574
578
  """
575
579
 
580
+ b = cls.__dataclass_fields__["b"].default_factory()
581
+
576
582
  # Check that b.T has enough rows based on the configured index of the
577
583
  # solution estimate. This is necessary for embedded methods.
578
584
  if (
579
585
  cls.row_index_of_solution_estimate is not None
580
- and cls.row_index_of_solution_estimate >= cls.b.T.shape[0]
586
+ and cls.row_index_of_solution_estimate >= b.T.shape[0]
581
587
  ):
582
588
  msg = "The index of the solution estimate ({}-th row of `b.T`) "
583
589
  msg += "is out of range ({})."
584
590
  raise ValueError(
585
- msg.format(cls.row_index_of_solution_estimate, cls.b.T.shape[0])
591
+ msg.format(cls.row_index_of_solution_estimate, b.T.shape[0])
586
592
  )
587
593
 
588
594
  integrator = super().build(
@@ -611,35 +617,47 @@ class HeunEulerSO3(EmbeddedRungeKutta[PyTreeType], ExplicitRungeKuttaSO3Mixin):
611
617
  The Heun-Euler integrator for SO(3) dynamics.
612
618
  """
613
619
 
614
- A: ClassVar[jtp.Matrix] = jnp.array(
615
- [
616
- [0, 0],
617
- [1, 0],
618
- ]
619
- ).astype(float)
620
-
621
- b: ClassVar[jtp.Matrix] = (
622
- jnp.atleast_2d(
623
- jnp.array(
624
- [
625
- [1 / 2, 1 / 2],
626
- [1, 0],
627
- ]
628
- ),
629
- )
630
- .astype(float)
631
- .transpose()
620
+ A: jtp.Matrix = dataclasses.field(
621
+ default_factory=lambda: jnp.array(
622
+ [
623
+ [0, 0],
624
+ [1, 0],
625
+ ]
626
+ ).astype(float),
627
+ compare=False,
632
628
  )
633
629
 
634
- c: ClassVar[jtp.Vector] = jnp.array(
635
- [0, 1],
636
- ).astype(float)
630
+ b: jtp.Matrix = dataclasses.field(
631
+ default_factory=lambda: (
632
+ jnp.atleast_2d(
633
+ jnp.array(
634
+ [
635
+ [1 / 2, 1 / 2],
636
+ [1, 0],
637
+ ]
638
+ ),
639
+ )
640
+ .astype(float)
641
+ .transpose()
642
+ ),
643
+ compare=False,
644
+ )
645
+
646
+ c: jtp.Vector = dataclasses.field(
647
+ default_factory=lambda: jnp.array(
648
+ [0, 1],
649
+ ).astype(float),
650
+ compare=False,
651
+ )
637
652
 
638
653
  row_index_of_solution: ClassVar[int] = 0
639
654
  row_index_of_solution_estimate: ClassVar[int | None] = 1
640
655
 
641
656
  order_of_bT_rows: ClassVar[tuple[int, ...]] = (2, 1)
642
657
 
658
+ index_of_fsal: jtp.IntLike | None = None
659
+ fsal_enabled_if_supported: bool = False
660
+
643
661
 
644
662
  @jax_dataclasses.pytree_dataclass
645
663
  class BogackiShampineSO3(EmbeddedRungeKutta[PyTreeType], ExplicitRungeKuttaSO3Mixin):
@@ -647,31 +665,40 @@ class BogackiShampineSO3(EmbeddedRungeKutta[PyTreeType], ExplicitRungeKuttaSO3Mi
647
665
  The Bogacki-Shampine integrator for SO(3) dynamics.
648
666
  """
649
667
 
650
- A: ClassVar[jtp.Matrix] = jnp.array(
651
- [
652
- [0, 0, 0, 0],
653
- [1 / 2, 0, 0, 0],
654
- [0, 3 / 4, 0, 0],
655
- [2 / 9, 1 / 3, 4 / 9, 0],
656
- ]
657
- ).astype(float)
658
-
659
- b: ClassVar[jtp.Matrix] = (
660
- jnp.atleast_2d(
661
- jnp.array(
662
- [
663
- [2 / 9, 1 / 3, 4 / 9, 0],
664
- [7 / 24, 1 / 4, 1 / 3, 1 / 8],
665
- ]
666
- ),
667
- )
668
- .astype(float)
669
- .transpose()
668
+ A: jtp.Matrix = dataclasses.field(
669
+ default_factory=lambda: jnp.array(
670
+ [
671
+ [0, 0, 0, 0],
672
+ [1 / 2, 0, 0, 0],
673
+ [0, 3 / 4, 0, 0],
674
+ [2 / 9, 1 / 3, 4 / 9, 0],
675
+ ]
676
+ ).astype(float),
677
+ compare=False,
670
678
  )
671
679
 
672
- c: ClassVar[jtp.Vector] = jnp.array(
673
- [0, 1 / 2, 3 / 4, 1],
674
- ).astype(float)
680
+ b: jtp.Matrix = dataclasses.field(
681
+ default_factory=lambda: (
682
+ jnp.atleast_2d(
683
+ jnp.array(
684
+ [
685
+ [2 / 9, 1 / 3, 4 / 9, 0],
686
+ [7 / 24, 1 / 4, 1 / 3, 1 / 8],
687
+ ]
688
+ ),
689
+ )
690
+ .astype(float)
691
+ .transpose()
692
+ ),
693
+ compare=False,
694
+ )
695
+
696
+ c: jtp.Vector = dataclasses.field(
697
+ default_factory=lambda: jnp.array(
698
+ [0, 1 / 2, 3 / 4, 1],
699
+ ).astype(float),
700
+ compare=False,
701
+ )
675
702
 
676
703
  row_index_of_solution: ClassVar[int] = 0
677
704
  row_index_of_solution_estimate: ClassVar[int | None] = 1
@@ -13,8 +13,8 @@ class Adjoint:
13
13
 
14
14
  @staticmethod
15
15
  def from_quaternion_and_translation(
16
- quaternion: jtp.Vector = jnp.array([1.0, 0, 0, 0]),
17
- translation: jtp.Vector = jnp.zeros(3),
16
+ quaternion: jtp.Vector | None = None,
17
+ translation: jtp.Vector | None = None,
18
18
  inverse: bool = False,
19
19
  normalize_quaternion: bool = False,
20
20
  ) -> jtp.Matrix:
@@ -22,14 +22,17 @@ class Adjoint:
22
22
  Create an adjoint matrix from a quaternion and a translation.
23
23
 
24
24
  Args:
25
- quaternion: A quaternion vector (4D) representing orientation.
26
- translation: A translation vector (3D).
27
- inverse: Whether to compute the inverse adjoint.
28
- normalize_quaternion: Whether to normalize the quaternion before creating the adjoint.
25
+ quaternion (jtp.Vector): A quaternion vector (4D) representing orientation. Default is [1, 0, 0, 0].
26
+ translation (jtp.Vector): A translation vector (3D). Default is [0, 0, 0].
27
+ inverse (bool): Whether to compute the inverse adjoint. Default is False.
28
+ normalize_quaternion (bool): Whether to normalize the quaternion before creating the adjoint.
29
+ Default is False.
29
30
 
30
31
  Returns:
31
32
  jtp.Matrix: The adjoint matrix.
32
33
  """
34
+ quaternion = quaternion if quaternion is not None else jnp.array([1.0, 0, 0, 0])
35
+ translation = translation if translation is not None else jnp.zeros(3)
33
36
  assert quaternion.size == 4
34
37
  assert translation.size == 3
35
38
 
@@ -64,21 +67,24 @@ class Adjoint:
64
67
 
65
68
  @staticmethod
66
69
  def from_rotation_and_translation(
67
- rotation: jtp.Matrix = jnp.eye(3),
68
- translation: jtp.Vector = jnp.zeros(3),
70
+ rotation: jtp.Matrix | None = None,
71
+ translation: jtp.Vector | None = None,
69
72
  inverse: bool = False,
70
73
  ) -> jtp.Matrix:
71
74
  """
72
75
  Create an adjoint matrix from a rotation matrix and a translation vector.
73
76
 
74
77
  Args:
75
- rotation: A 3x3 rotation matrix.
76
- translation: A translation vector (3D).
77
- inverse: Whether to compute the inverse adjoint. Default is False.
78
+ rotation (jtp.Matrix): A 3x3 rotation matrix. Default is identity.
79
+ translation (jtp.Vector): A translation vector (3D). Default is [0, 0, 0].
80
+ inverse (bool): Whether to compute the inverse adjoint. Default is False.
78
81
 
79
82
  Returns:
80
83
  jtp.Matrix: The adjoint matrix.
81
84
  """
85
+ rotation = rotation if rotation is not None else jnp.eye(3)
86
+ translation = translation if translation is not None else jnp.zeros(3)
87
+
82
88
  assert rotation.shape == (3, 3)
83
89
  assert translation.size == 3
84
90
 
@@ -11,8 +11,8 @@ class Transform:
11
11
 
12
12
  @staticmethod
13
13
  def from_quaternion_and_translation(
14
- quaternion: jtp.VectorLike = jnp.array([1.0, 0, 0, 0]),
15
- translation: jtp.VectorLike = jnp.zeros(3),
14
+ quaternion: jtp.VectorLike | None = None,
15
+ translation: jtp.VectorLike | None = None,
16
16
  inverse: jtp.BoolLike = False,
17
17
  normalize_quaternion: jtp.BoolLike = False,
18
18
  ) -> jtp.Matrix:
@@ -30,6 +30,9 @@ class Transform:
30
30
  The 4x4 transformation matrix representing the SE(3) transformation.
31
31
  """
32
32
 
33
+ quaternion = quaternion if quaternion is not None else jnp.array([1.0, 0, 0, 0])
34
+ translation = translation if translation is not None else jnp.zeros(3)
35
+
33
36
  W_Q_B = jnp.array(quaternion).astype(float)
34
37
  W_p_B = jnp.array(translation).astype(float)
35
38
 
@@ -47,8 +50,8 @@ class Transform:
47
50
 
48
51
  @staticmethod
49
52
  def from_rotation_and_translation(
50
- rotation: jtp.MatrixLike = jnp.eye(3),
51
- translation: jtp.VectorLike = jnp.zeros(3),
53
+ rotation: jtp.MatrixLike | None = None,
54
+ translation: jtp.VectorLike | None = None,
52
55
  inverse: jtp.BoolLike = False,
53
56
  ) -> jtp.Matrix:
54
57
  """
@@ -62,6 +65,8 @@ class Transform:
62
65
  Returns:
63
66
  The 4x4 transformation matrix representing the SE(3) transformation.
64
67
  """
68
+ rotation = rotation if rotation is not None else jnp.eye(3)
69
+ translation = translation if translation is not None else jnp.zeros(3)
65
70
 
66
71
  A_R_B = jnp.array(rotation).astype(float)
67
72
  W_p_B = jnp.array(translation).astype(float)
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: jaxsim
3
- Version: 0.5.1.dev133
3
+ Version: 0.5.1.dev139
4
4
  Summary: A differentiable physics engine and multibody dynamics library for control and robot learning.
5
5
  Author-email: Diego Ferigo <dgferigo@gmail.com>, Filippo Luca Ferretti <filippoluca.ferretti@outlook.com>
6
6
  Maintainer-email: Filippo Luca Ferretti <filippo.ferretti@iit.it>, Alessandro Croci <alessandro.croci@iit.it>
@@ -1,123 +0,0 @@
1
- from typing import ClassVar, Generic
2
-
3
- import jax.numpy as jnp
4
- import jax_dataclasses
5
-
6
- import jaxsim.api as js
7
- import jaxsim.typing as jtp
8
-
9
- from .common import ExplicitRungeKutta, ExplicitRungeKuttaSO3Mixin, PyTreeType
10
-
11
- ODEStateDerivative = js.ode_data.ODEState
12
-
13
- # =====================================================
14
- # Explicit Runge-Kutta integrators operating on PyTrees
15
- # =====================================================
16
-
17
-
18
- @jax_dataclasses.pytree_dataclass
19
- class ForwardEuler(ExplicitRungeKutta[PyTreeType], Generic[PyTreeType]):
20
- """
21
- Forward Euler integrator.
22
- """
23
-
24
- A: ClassVar[jtp.Matrix] = jnp.atleast_2d(0).astype(float)
25
-
26
- b: ClassVar[jtp.Matrix] = jnp.atleast_2d(1).astype(float).transpose()
27
-
28
- c: ClassVar[jtp.Vector] = jnp.atleast_1d(0).astype(float)
29
-
30
- row_index_of_solution: ClassVar[int] = 0
31
- order_of_bT_rows: ClassVar[tuple[int, ...]] = (1,)
32
-
33
-
34
- @jax_dataclasses.pytree_dataclass
35
- class Heun2(ExplicitRungeKutta[PyTreeType], Generic[PyTreeType]):
36
- """
37
- Heun's second-order integrator.
38
- """
39
-
40
- A: ClassVar[jtp.Matrix] = jnp.array(
41
- [
42
- [0, 0],
43
- [1, 0],
44
- ]
45
- ).astype(float)
46
-
47
- b: ClassVar[jtp.Matrix] = (
48
- jnp.atleast_2d(
49
- jnp.array([1 / 2, 1 / 2]),
50
- )
51
- .astype(float)
52
- .transpose()
53
- )
54
-
55
- c: ClassVar[jtp.Vector] = jnp.array(
56
- [0, 1],
57
- ).astype(float)
58
-
59
- row_index_of_solution: ClassVar[int] = 0
60
- order_of_bT_rows: ClassVar[tuple[int, ...]] = (2,)
61
-
62
-
63
- @jax_dataclasses.pytree_dataclass
64
- class RungeKutta4(ExplicitRungeKutta[PyTreeType], Generic[PyTreeType]):
65
- """
66
- Fourth-order Runge-Kutta integrator.
67
- """
68
-
69
- A: ClassVar[jtp.Matrix] = jnp.array(
70
- [
71
- [0, 0, 0, 0],
72
- [1 / 2, 0, 0, 0],
73
- [0, 1 / 2, 0, 0],
74
- [0, 0, 1, 0],
75
- ]
76
- ).astype(float)
77
-
78
- b: ClassVar[jtp.Matrix] = (
79
- jnp.atleast_2d(
80
- jnp.array([1 / 6, 1 / 3, 1 / 3, 1 / 6]),
81
- )
82
- .astype(float)
83
- .transpose()
84
- )
85
-
86
- c: ClassVar[jtp.Vector] = jnp.array(
87
- [0, 1 / 2, 1 / 2, 1],
88
- ).astype(float)
89
-
90
- row_index_of_solution: ClassVar[int] = 0
91
- order_of_bT_rows: ClassVar[tuple[int, ...]] = (4,)
92
-
93
-
94
- # ===============================================================================
95
- # Explicit Runge-Kutta integrators operating on ODEState and integrating on SO(3)
96
- # ===============================================================================
97
-
98
-
99
- @jax_dataclasses.pytree_dataclass
100
- class ForwardEulerSO3(ExplicitRungeKuttaSO3Mixin, ForwardEuler[js.ode_data.ODEState]):
101
- """
102
- Forward Euler integrator for SO(3) states.
103
- """
104
-
105
- pass
106
-
107
-
108
- @jax_dataclasses.pytree_dataclass
109
- class Heun2SO3(ExplicitRungeKuttaSO3Mixin, Heun2[js.ode_data.ODEState]):
110
- """
111
- Heun's second-order integrator for SO(3) states.
112
- """
113
-
114
- pass
115
-
116
-
117
- @jax_dataclasses.pytree_dataclass
118
- class RungeKutta4SO3(ExplicitRungeKuttaSO3Mixin, RungeKutta4[js.ode_data.ODEState]):
119
- """
120
- Fourth-order Runge-Kutta integrator for SO(3) states.
121
- """
122
-
123
- pass
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes