jaxsim 0.5.1.dev126__tar.gz → 0.5.1.dev139__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (136) hide show
  1. {jaxsim-0.5.1.dev126 → jaxsim-0.5.1.dev139}/PKG-INFO +1 -1
  2. {jaxsim-0.5.1.dev126 → jaxsim-0.5.1.dev139}/pyproject.toml +13 -1
  3. {jaxsim-0.5.1.dev126 → jaxsim-0.5.1.dev139}/src/jaxsim/__init__.py +0 -7
  4. {jaxsim-0.5.1.dev126 → jaxsim-0.5.1.dev139}/src/jaxsim/_version.py +2 -2
  5. {jaxsim-0.5.1.dev126 → jaxsim-0.5.1.dev139}/src/jaxsim/api/com.py +1 -1
  6. {jaxsim-0.5.1.dev126 → jaxsim-0.5.1.dev139}/src/jaxsim/api/common.py +1 -1
  7. {jaxsim-0.5.1.dev126 → jaxsim-0.5.1.dev139}/src/jaxsim/api/contact.py +3 -0
  8. {jaxsim-0.5.1.dev126 → jaxsim-0.5.1.dev139}/src/jaxsim/api/data.py +2 -1
  9. {jaxsim-0.5.1.dev126 → jaxsim-0.5.1.dev139}/src/jaxsim/api/kin_dyn_parameters.py +18 -1
  10. {jaxsim-0.5.1.dev126 → jaxsim-0.5.1.dev139}/src/jaxsim/api/model.py +7 -4
  11. {jaxsim-0.5.1.dev126 → jaxsim-0.5.1.dev139}/src/jaxsim/api/ode.py +21 -1
  12. {jaxsim-0.5.1.dev126 → jaxsim-0.5.1.dev139}/src/jaxsim/exceptions.py +8 -0
  13. {jaxsim-0.5.1.dev126 → jaxsim-0.5.1.dev139}/src/jaxsim/integrators/common.py +72 -11
  14. jaxsim-0.5.1.dev139/src/jaxsim/integrators/fixed_step.py +153 -0
  15. {jaxsim-0.5.1.dev126 → jaxsim-0.5.1.dev139}/src/jaxsim/integrators/variable_step.py +117 -46
  16. {jaxsim-0.5.1.dev126 → jaxsim-0.5.1.dev139}/src/jaxsim/math/adjoint.py +19 -10
  17. {jaxsim-0.5.1.dev126 → jaxsim-0.5.1.dev139}/src/jaxsim/math/cross.py +6 -2
  18. {jaxsim-0.5.1.dev126 → jaxsim-0.5.1.dev139}/src/jaxsim/math/inertia.py +8 -4
  19. {jaxsim-0.5.1.dev126 → jaxsim-0.5.1.dev139}/src/jaxsim/math/quaternion.py +10 -6
  20. {jaxsim-0.5.1.dev126 → jaxsim-0.5.1.dev139}/src/jaxsim/math/rotation.py +6 -3
  21. {jaxsim-0.5.1.dev126 → jaxsim-0.5.1.dev139}/src/jaxsim/math/skew.py +2 -2
  22. {jaxsim-0.5.1.dev126 → jaxsim-0.5.1.dev139}/src/jaxsim/math/transform.py +12 -4
  23. {jaxsim-0.5.1.dev126 → jaxsim-0.5.1.dev139}/src/jaxsim/math/utils.py +2 -2
  24. {jaxsim-0.5.1.dev126 → jaxsim-0.5.1.dev139}/src/jaxsim/mujoco/loaders.py +17 -7
  25. {jaxsim-0.5.1.dev126 → jaxsim-0.5.1.dev139}/src/jaxsim/mujoco/model.py +15 -15
  26. {jaxsim-0.5.1.dev126 → jaxsim-0.5.1.dev139}/src/jaxsim/mujoco/utils.py +6 -1
  27. {jaxsim-0.5.1.dev126 → jaxsim-0.5.1.dev139}/src/jaxsim/mujoco/visualizer.py +11 -7
  28. {jaxsim-0.5.1.dev126 → jaxsim-0.5.1.dev139}/src/jaxsim/parsers/descriptions/collision.py +7 -4
  29. {jaxsim-0.5.1.dev126 → jaxsim-0.5.1.dev139}/src/jaxsim/parsers/descriptions/joint.py +16 -14
  30. {jaxsim-0.5.1.dev126 → jaxsim-0.5.1.dev139}/src/jaxsim/parsers/descriptions/model.py +1 -1
  31. {jaxsim-0.5.1.dev126 → jaxsim-0.5.1.dev139}/src/jaxsim/parsers/kinematic_graph.py +38 -0
  32. {jaxsim-0.5.1.dev126 → jaxsim-0.5.1.dev139}/src/jaxsim/parsers/rod/meshes.py +5 -5
  33. {jaxsim-0.5.1.dev126 → jaxsim-0.5.1.dev139}/src/jaxsim/parsers/rod/parser.py +1 -1
  34. {jaxsim-0.5.1.dev126 → jaxsim-0.5.1.dev139}/src/jaxsim/parsers/rod/utils.py +11 -0
  35. {jaxsim-0.5.1.dev126 → jaxsim-0.5.1.dev139}/src/jaxsim/rbda/contacts/common.py +2 -0
  36. {jaxsim-0.5.1.dev126 → jaxsim-0.5.1.dev139}/src/jaxsim/rbda/contacts/relaxed_rigid.py +7 -4
  37. {jaxsim-0.5.1.dev126 → jaxsim-0.5.1.dev139}/src/jaxsim/rbda/contacts/rigid.py +8 -4
  38. {jaxsim-0.5.1.dev126 → jaxsim-0.5.1.dev139}/src/jaxsim/rbda/contacts/soft.py +37 -0
  39. {jaxsim-0.5.1.dev126 → jaxsim-0.5.1.dev139}/src/jaxsim/rbda/contacts/visco_elastic.py +1 -0
  40. {jaxsim-0.5.1.dev126 → jaxsim-0.5.1.dev139}/src/jaxsim/terrain/terrain.py +52 -0
  41. {jaxsim-0.5.1.dev126 → jaxsim-0.5.1.dev139}/src/jaxsim/utils/jaxsim_dataclass.py +3 -3
  42. {jaxsim-0.5.1.dev126 → jaxsim-0.5.1.dev139}/src/jaxsim/utils/tracing.py +2 -2
  43. {jaxsim-0.5.1.dev126 → jaxsim-0.5.1.dev139}/src/jaxsim/utils/wrappers.py +9 -0
  44. {jaxsim-0.5.1.dev126 → jaxsim-0.5.1.dev139}/src/jaxsim.egg-info/PKG-INFO +1 -1
  45. {jaxsim-0.5.1.dev126 → jaxsim-0.5.1.dev139}/tests/conftest.py +10 -2
  46. {jaxsim-0.5.1.dev126 → jaxsim-0.5.1.dev139}/tests/test_meshes.py +5 -2
  47. {jaxsim-0.5.1.dev126 → jaxsim-0.5.1.dev139}/tests/test_simulations.py +2 -1
  48. {jaxsim-0.5.1.dev126 → jaxsim-0.5.1.dev139}/tests/utils_idyntree.py +2 -0
  49. jaxsim-0.5.1.dev126/src/jaxsim/integrators/fixed_step.py +0 -102
  50. {jaxsim-0.5.1.dev126 → jaxsim-0.5.1.dev139}/.devcontainer/Dockerfile +0 -0
  51. {jaxsim-0.5.1.dev126 → jaxsim-0.5.1.dev139}/.devcontainer/devcontainer.json +0 -0
  52. {jaxsim-0.5.1.dev126 → jaxsim-0.5.1.dev139}/.gitattributes +0 -0
  53. {jaxsim-0.5.1.dev126 → jaxsim-0.5.1.dev139}/.github/CODEOWNERS +0 -0
  54. {jaxsim-0.5.1.dev126 → jaxsim-0.5.1.dev139}/.github/dependabot.yml +0 -0
  55. {jaxsim-0.5.1.dev126 → jaxsim-0.5.1.dev139}/.github/workflows/ci_cd.yml +0 -0
  56. {jaxsim-0.5.1.dev126 → jaxsim-0.5.1.dev139}/.github/workflows/pixi.yml +0 -0
  57. {jaxsim-0.5.1.dev126 → jaxsim-0.5.1.dev139}/.github/workflows/read_the_docs.yml +0 -0
  58. {jaxsim-0.5.1.dev126 → jaxsim-0.5.1.dev139}/.gitignore +0 -0
  59. {jaxsim-0.5.1.dev126 → jaxsim-0.5.1.dev139}/.pre-commit-config.yaml +0 -0
  60. {jaxsim-0.5.1.dev126 → jaxsim-0.5.1.dev139}/.readthedocs.yaml +0 -0
  61. {jaxsim-0.5.1.dev126 → jaxsim-0.5.1.dev139}/CONTRIBUTING.md +0 -0
  62. {jaxsim-0.5.1.dev126 → jaxsim-0.5.1.dev139}/LICENSE +0 -0
  63. {jaxsim-0.5.1.dev126 → jaxsim-0.5.1.dev139}/README.md +0 -0
  64. {jaxsim-0.5.1.dev126 → jaxsim-0.5.1.dev139}/docs/Makefile +0 -0
  65. {jaxsim-0.5.1.dev126 → jaxsim-0.5.1.dev139}/docs/conf.py +0 -0
  66. {jaxsim-0.5.1.dev126 → jaxsim-0.5.1.dev139}/docs/examples.rst +0 -0
  67. {jaxsim-0.5.1.dev126 → jaxsim-0.5.1.dev139}/docs/guide/configuration.rst +0 -0
  68. {jaxsim-0.5.1.dev126 → jaxsim-0.5.1.dev139}/docs/guide/install.rst +0 -0
  69. {jaxsim-0.5.1.dev126 → jaxsim-0.5.1.dev139}/docs/index.rst +0 -0
  70. {jaxsim-0.5.1.dev126 → jaxsim-0.5.1.dev139}/docs/make.bat +0 -0
  71. {jaxsim-0.5.1.dev126 → jaxsim-0.5.1.dev139}/docs/modules/api.rst +0 -0
  72. {jaxsim-0.5.1.dev126 → jaxsim-0.5.1.dev139}/docs/modules/integrators.rst +0 -0
  73. {jaxsim-0.5.1.dev126 → jaxsim-0.5.1.dev139}/docs/modules/math.rst +0 -0
  74. {jaxsim-0.5.1.dev126 → jaxsim-0.5.1.dev139}/docs/modules/mujoco.rst +0 -0
  75. {jaxsim-0.5.1.dev126 → jaxsim-0.5.1.dev139}/docs/modules/parsers.rst +0 -0
  76. {jaxsim-0.5.1.dev126 → jaxsim-0.5.1.dev139}/docs/modules/rbda.rst +0 -0
  77. {jaxsim-0.5.1.dev126 → jaxsim-0.5.1.dev139}/docs/modules/typing.rst +0 -0
  78. {jaxsim-0.5.1.dev126 → jaxsim-0.5.1.dev139}/docs/modules/utils.rst +0 -0
  79. {jaxsim-0.5.1.dev126 → jaxsim-0.5.1.dev139}/environment.yml +0 -0
  80. {jaxsim-0.5.1.dev126 → jaxsim-0.5.1.dev139}/examples/.gitattributes +0 -0
  81. {jaxsim-0.5.1.dev126 → jaxsim-0.5.1.dev139}/examples/.gitignore +0 -0
  82. {jaxsim-0.5.1.dev126 → jaxsim-0.5.1.dev139}/examples/README.md +0 -0
  83. {jaxsim-0.5.1.dev126 → jaxsim-0.5.1.dev139}/examples/assets/build_cartpole_urdf.py +0 -0
  84. {jaxsim-0.5.1.dev126 → jaxsim-0.5.1.dev139}/examples/assets/cartpole.urdf +0 -0
  85. {jaxsim-0.5.1.dev126 → jaxsim-0.5.1.dev139}/examples/jaxsim_as_multibody_dynamics_library.ipynb +0 -0
  86. {jaxsim-0.5.1.dev126 → jaxsim-0.5.1.dev139}/examples/jaxsim_as_physics_engine.ipynb +0 -0
  87. {jaxsim-0.5.1.dev126 → jaxsim-0.5.1.dev139}/examples/jaxsim_as_physics_engine_advanced.ipynb +0 -0
  88. {jaxsim-0.5.1.dev126 → jaxsim-0.5.1.dev139}/examples/jaxsim_for_robot_controllers.ipynb +0 -0
  89. {jaxsim-0.5.1.dev126 → jaxsim-0.5.1.dev139}/pixi.lock +0 -0
  90. {jaxsim-0.5.1.dev126 → jaxsim-0.5.1.dev139}/setup.cfg +0 -0
  91. {jaxsim-0.5.1.dev126 → jaxsim-0.5.1.dev139}/setup.py +0 -0
  92. {jaxsim-0.5.1.dev126 → jaxsim-0.5.1.dev139}/src/jaxsim/api/__init__.py +0 -0
  93. {jaxsim-0.5.1.dev126 → jaxsim-0.5.1.dev139}/src/jaxsim/api/frame.py +0 -0
  94. {jaxsim-0.5.1.dev126 → jaxsim-0.5.1.dev139}/src/jaxsim/api/joint.py +0 -0
  95. {jaxsim-0.5.1.dev126 → jaxsim-0.5.1.dev139}/src/jaxsim/api/link.py +0 -0
  96. {jaxsim-0.5.1.dev126 → jaxsim-0.5.1.dev139}/src/jaxsim/api/ode_data.py +0 -0
  97. {jaxsim-0.5.1.dev126 → jaxsim-0.5.1.dev139}/src/jaxsim/api/references.py +0 -0
  98. {jaxsim-0.5.1.dev126 → jaxsim-0.5.1.dev139}/src/jaxsim/integrators/__init__.py +0 -0
  99. {jaxsim-0.5.1.dev126 → jaxsim-0.5.1.dev139}/src/jaxsim/logging.py +0 -0
  100. {jaxsim-0.5.1.dev126 → jaxsim-0.5.1.dev139}/src/jaxsim/math/__init__.py +0 -0
  101. {jaxsim-0.5.1.dev126 → jaxsim-0.5.1.dev139}/src/jaxsim/math/joint_model.py +0 -0
  102. {jaxsim-0.5.1.dev126 → jaxsim-0.5.1.dev139}/src/jaxsim/mujoco/__init__.py +0 -0
  103. {jaxsim-0.5.1.dev126 → jaxsim-0.5.1.dev139}/src/jaxsim/mujoco/__main__.py +0 -0
  104. {jaxsim-0.5.1.dev126 → jaxsim-0.5.1.dev139}/src/jaxsim/parsers/__init__.py +0 -0
  105. {jaxsim-0.5.1.dev126 → jaxsim-0.5.1.dev139}/src/jaxsim/parsers/descriptions/__init__.py +0 -0
  106. {jaxsim-0.5.1.dev126 → jaxsim-0.5.1.dev139}/src/jaxsim/parsers/descriptions/link.py +0 -0
  107. {jaxsim-0.5.1.dev126 → jaxsim-0.5.1.dev139}/src/jaxsim/parsers/rod/__init__.py +0 -0
  108. {jaxsim-0.5.1.dev126 → jaxsim-0.5.1.dev139}/src/jaxsim/rbda/__init__.py +0 -0
  109. {jaxsim-0.5.1.dev126 → jaxsim-0.5.1.dev139}/src/jaxsim/rbda/aba.py +0 -0
  110. {jaxsim-0.5.1.dev126 → jaxsim-0.5.1.dev139}/src/jaxsim/rbda/collidable_points.py +0 -0
  111. {jaxsim-0.5.1.dev126 → jaxsim-0.5.1.dev139}/src/jaxsim/rbda/contacts/__init__.py +0 -0
  112. {jaxsim-0.5.1.dev126 → jaxsim-0.5.1.dev139}/src/jaxsim/rbda/crba.py +0 -0
  113. {jaxsim-0.5.1.dev126 → jaxsim-0.5.1.dev139}/src/jaxsim/rbda/forward_kinematics.py +0 -0
  114. {jaxsim-0.5.1.dev126 → jaxsim-0.5.1.dev139}/src/jaxsim/rbda/jacobian.py +0 -0
  115. {jaxsim-0.5.1.dev126 → jaxsim-0.5.1.dev139}/src/jaxsim/rbda/rnea.py +0 -0
  116. {jaxsim-0.5.1.dev126 → jaxsim-0.5.1.dev139}/src/jaxsim/rbda/utils.py +0 -0
  117. {jaxsim-0.5.1.dev126 → jaxsim-0.5.1.dev139}/src/jaxsim/terrain/__init__.py +0 -0
  118. {jaxsim-0.5.1.dev126 → jaxsim-0.5.1.dev139}/src/jaxsim/typing.py +0 -0
  119. {jaxsim-0.5.1.dev126 → jaxsim-0.5.1.dev139}/src/jaxsim/utils/__init__.py +0 -0
  120. {jaxsim-0.5.1.dev126 → jaxsim-0.5.1.dev139}/src/jaxsim.egg-info/SOURCES.txt +0 -0
  121. {jaxsim-0.5.1.dev126 → jaxsim-0.5.1.dev139}/src/jaxsim.egg-info/dependency_links.txt +0 -0
  122. {jaxsim-0.5.1.dev126 → jaxsim-0.5.1.dev139}/src/jaxsim.egg-info/requires.txt +0 -0
  123. {jaxsim-0.5.1.dev126 → jaxsim-0.5.1.dev139}/src/jaxsim.egg-info/top_level.txt +0 -0
  124. {jaxsim-0.5.1.dev126 → jaxsim-0.5.1.dev139}/tests/__init__.py +0 -0
  125. {jaxsim-0.5.1.dev126 → jaxsim-0.5.1.dev139}/tests/test_api_com.py +0 -0
  126. {jaxsim-0.5.1.dev126 → jaxsim-0.5.1.dev139}/tests/test_api_contact.py +0 -0
  127. {jaxsim-0.5.1.dev126 → jaxsim-0.5.1.dev139}/tests/test_api_data.py +0 -0
  128. {jaxsim-0.5.1.dev126 → jaxsim-0.5.1.dev139}/tests/test_api_frame.py +0 -0
  129. {jaxsim-0.5.1.dev126 → jaxsim-0.5.1.dev139}/tests/test_api_joint.py +0 -0
  130. {jaxsim-0.5.1.dev126 → jaxsim-0.5.1.dev139}/tests/test_api_link.py +0 -0
  131. {jaxsim-0.5.1.dev126 → jaxsim-0.5.1.dev139}/tests/test_api_model.py +0 -0
  132. {jaxsim-0.5.1.dev126 → jaxsim-0.5.1.dev139}/tests/test_automatic_differentiation.py +0 -0
  133. {jaxsim-0.5.1.dev126 → jaxsim-0.5.1.dev139}/tests/test_benchmark.py +0 -0
  134. {jaxsim-0.5.1.dev126 → jaxsim-0.5.1.dev139}/tests/test_contact.py +0 -0
  135. {jaxsim-0.5.1.dev126 → jaxsim-0.5.1.dev139}/tests/test_exceptions.py +0 -0
  136. {jaxsim-0.5.1.dev126 → jaxsim-0.5.1.dev139}/tests/test_pytree.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: jaxsim
3
- Version: 0.5.1.dev126
3
+ Version: 0.5.1.dev139
4
4
  Summary: A differentiable physics engine and multibody dynamics library for control and robot learning.
5
5
  Author-email: Diego Ferigo <dgferigo@gmail.com>, Filippo Luca Ferretti <filippoluca.ferretti@outlook.com>
6
6
  Maintainer-email: Filippo Luca Ferretti <filippo.ferretti@iit.it>, Alessandro Croci <alessandro.croci@iit.it>
@@ -150,6 +150,7 @@ preview = true
150
150
  # https://docs.astral.sh/ruff/rules/
151
151
  select = [
152
152
  "B",
153
+ "D",
153
154
  "E",
154
155
  "F",
155
156
  "I",
@@ -162,6 +163,15 @@ select = [
162
163
  ignore = [
163
164
  "B008", # Function call in default argument
164
165
  "B024", # Abstract base class without abstract methods
166
+ "D100", # Missing docstring in public module
167
+ "D104", # Missing docstring in public package
168
+ "D105", # Missing docstring in magic method
169
+ "D200", # One-line docstring should fit on one line with quotes
170
+ "D202", # No blank lines allowed after function docstring
171
+ "D205", # 1 blank line required between summary line and description
172
+ "D212", # Multi-line docstring summary should start at the first line
173
+ "D411", # Missing blank line before section
174
+ "D413", # Missing blank line after last section
165
175
  "E402", # Module level import not at top of file
166
176
  "E501", # Line too long
167
177
  "E731", # Do not assign a `lambda` expression, use a `def`
@@ -173,9 +183,11 @@ ignore = [
173
183
  [tool.ruff.lint.per-file-ignores]
174
184
  # Ignore `E402` (import violations) in all `__init__.py` files
175
185
  "**/{tests,docs,tools}/*" = ["E402"]
176
- "**/{tests}/*" = ["B007"]
186
+ "**/{tests,examples}/*" = ["B007", "D100", "D102", "D103"]
177
187
  "__init__.py" = ["F401"]
178
188
  "docs/conf.py" = ["F401"]
189
+ "src/jaxsim/exceptions.py" = ["D401"]
190
+ "src/jaxsim/logging.py" = ["D101", "D103"]
179
191
 
180
192
  # ==================
181
193
  # Pixi configuration
@@ -34,13 +34,6 @@ def _jnp_options() -> None:
34
34
  logging.info("Enabling JAX to use 64-bit precision")
35
35
  jax.config.update("jax_enable_x64", True)
36
36
 
37
- import jax.numpy as jnp
38
- import numpy as np
39
-
40
- # Verify that 64-bit precision is correctly set.
41
- if jnp.empty(0, dtype=float).dtype != jnp.empty(0, dtype=np.float64).dtype:
42
- logging.warning("Failed to enable 64-bit precision in JAX")
43
-
44
37
  # Warn about experimental usage of 32-bit precision.
45
38
  else:
46
39
  logging.warning(
@@ -12,5 +12,5 @@ __version__: str
12
12
  __version_tuple__: VERSION_TUPLE
13
13
  version_tuple: VERSION_TUPLE
14
14
 
15
- __version__ = version = '0.5.1.dev126'
16
- __version_tuple__ = version_tuple = (0, 5, 1, 'dev126')
15
+ __version__ = version = '0.5.1.dev139'
16
+ __version_tuple__ = version_tuple = (0, 5, 1, 'dev139')
@@ -279,7 +279,7 @@ def bias_acceleration(
279
279
  C_v̇_WL: jtp.Vector, C_v_WC: jtp.Vector, L_H_C: jtp.Matrix, L_v_LC: jtp.Vector
280
280
  ) -> jtp.Vector:
281
281
  """
282
- Helper to convert the body-fixed representation of the link bias acceleration
282
+ Convert the body-fixed representation of the link bias acceleration
283
283
  C_v̇_WL expressed in a generic frame C to the body-fixed representation L_v̇_WL.
284
284
  """
285
285
 
@@ -26,7 +26,7 @@ _R = TypeVar("_R")
26
26
 
27
27
 
28
28
  def named_scope(fn, name: str | None = None) -> Callable[_P, _R]:
29
- """Applies a JAX named scope to a function for improved profiling and clarity."""
29
+ """Apply a JAX named scope to a function for improved profiling and clarity."""
30
30
 
31
31
  @functools.wraps(fn)
32
32
  def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _R:
@@ -293,6 +293,9 @@ def in_contact(
293
293
  def estimate_good_soft_contacts_parameters(
294
294
  *args, **kwargs
295
295
  ) -> jaxsim.rbda.contacts.ContactParamsTypes:
296
+ """
297
+ Estimate good soft contacts parameters. Deprecated, use `estimate_good_contact_parameters` instead.
298
+ """
296
299
 
297
300
  msg = "This method is deprecated, please use `{}`."
298
301
  logging.warning(msg.format(estimate_good_contact_parameters.__name__))
@@ -456,7 +456,8 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
456
456
  @jax.jit
457
457
  def generalized_velocity(self) -> jtp.Vector:
458
458
  r"""
459
- Get the generalized velocity
459
+ Get the generalized velocity.
460
+
460
461
  :math:`\boldsymbol{\nu} = (\boldsymbol{v}_{W,B};\, \boldsymbol{\omega}_{W,B};\, \mathbf{s}) \in \mathbb{R}^{6+n}`
461
462
 
462
463
  Returns:
@@ -52,10 +52,16 @@ class KinDynParameters(JaxsimDataclass):
52
52
 
53
53
  @property
54
54
  def parent_array(self) -> jtp.Vector:
55
+ r"""
56
+ Return the parent array :math:`\lambda(i)` of the model.
57
+ """
55
58
  return self._parent_array.get()
56
59
 
57
60
  @property
58
61
  def support_body_array_bool(self) -> jtp.Matrix:
62
+ r"""
63
+ Return the boolean support parent array :math:`\kappa_{b}(i)` of the model.
64
+ """
59
65
  return self._support_body_array_bool.get()
60
66
 
61
67
  @staticmethod
@@ -648,7 +654,16 @@ class LinkParameters(JaxsimDataclass):
648
654
  def build_from_flat_parameters(
649
655
  index: jtp.IntLike, parameters: jtp.VectorLike
650
656
  ) -> LinkParameters:
657
+ """
658
+ Build a LinkParameters object from a flat vector of parameters.
659
+
660
+ Args:
661
+ index: The index of the link.
662
+ parameters: The flat vector of parameters.
651
663
 
664
+ Returns:
665
+ The LinkParameters object.
666
+ """
652
667
  index = jnp.array(index).squeeze().astype(int)
653
668
 
654
669
  m = jnp.array(parameters[0]).squeeze().astype(float)
@@ -772,7 +787,9 @@ class ContactParameters(JaxsimDataclass):
772
787
 
773
788
  @property
774
789
  def indices_of_enabled_collidable_points(self) -> npt.NDArray:
775
-
790
+ """
791
+ Return the indices of the enabled collidable points.
792
+ """
776
793
  return np.where(np.array(self.enabled))[0]
777
794
 
778
795
  @staticmethod
@@ -63,6 +63,9 @@ class JaxSimModel(JaxsimDataclass):
63
63
 
64
64
  @property
65
65
  def description(self) -> ModelDescription:
66
+ """
67
+ Return the model description.
68
+ """
66
69
  return self._description.get()
67
70
 
68
71
  def __eq__(self, other: JaxSimModel) -> bool:
@@ -1015,7 +1018,7 @@ def forward_dynamics_aba(
1015
1018
  W_v̇_WB: jtp.Vector, W_H_C: jtp.Matrix, W_v_WB: jtp.Vector, W_v_WC: jtp.Vector
1016
1019
  ) -> jtp.Vector:
1017
1020
  """
1018
- Helper to convert the inertial-fixed apparent base acceleration W_v̇_WB to
1021
+ Convert the inertial-fixed apparent base acceleration W_v̇_WB to
1019
1022
  another representation C_v̇_WB expressed in a generic frame C.
1020
1023
  """
1021
1024
 
@@ -1376,7 +1379,7 @@ def inverse_dynamics(
1376
1379
 
1377
1380
  def to_inertial(C_v̇_WB, W_H_C, C_v_WB, W_v_WC):
1378
1381
  """
1379
- Helper to convert the active representation of the base acceleration C_v̇_WB
1382
+ Convert the active representation of the base acceleration C_v̇_WB
1380
1383
  expressed in a generic frame C to the inertial-fixed representation W_v̇_WB.
1381
1384
  """
1382
1385
 
@@ -1825,7 +1828,7 @@ def link_bias_accelerations(
1825
1828
  C_v̇_WB: jtp.Vector, C_v_WB: jtp.Vector, W_H_C: jtp.Matrix, W_v_WC: jtp.Vector
1826
1829
  ) -> jtp.Vector:
1827
1830
  """
1828
- Helper to convert the active representation of the base acceleration C_v̇_WB
1831
+ Convert the active representation of the base acceleration C_v̇_WB
1829
1832
  expressed in a generic frame C to the inertial-fixed representation W_v̇_WB.
1830
1833
  """
1831
1834
 
@@ -1961,7 +1964,7 @@ def link_bias_accelerations(
1961
1964
  L_v̇_WL: jtp.Vector, L_v_WL: jtp.Vector, C_H_L: jtp.Matrix, L_v_CL: jtp.Vector
1962
1965
  ) -> jtp.Vector:
1963
1966
  """
1964
- Helper to convert the body-fixed apparent acceleration L_v̇_WL to
1967
+ Convert the body-fixed apparent acceleration L_v̇_WL to
1965
1968
  another representation C_v̇_WL expressed in a generic frame C.
1966
1969
  """
1967
1970
 
@@ -15,12 +15,32 @@ from .ode_data import ODEState
15
15
 
16
16
 
17
17
  class SystemDynamicsFromModelAndData(Protocol):
18
+ """
19
+ Protocol defining the signature of a function computing the system dynamics
20
+ given a model and data object.
21
+ """
22
+
18
23
  def __call__(
19
24
  self,
20
25
  model: js.model.JaxSimModel,
21
26
  data: js.data.JaxSimModelData,
22
27
  **kwargs: dict[str, Any],
23
- ) -> tuple[ODEState, dict[str, Any]]: ...
28
+ ) -> tuple[ODEState, dict[str, Any]]:
29
+ """
30
+ Compute the system dynamics given a model and data object.
31
+
32
+ Args:
33
+ model: The model to consider.
34
+ data: The data of the considered model.
35
+ **kwargs: Additional keyword arguments.
36
+
37
+ Returns:
38
+ A tuple with an `ODEState` object storing in each of its attributes the
39
+ corresponding derivative, and the dictionary of auxiliary data returned
40
+ by the system dynamics evaluation.
41
+ """
42
+
43
+ pass
24
44
 
25
45
 
26
46
  def wrap_system_dynamics_for_integration(
@@ -17,6 +17,8 @@ def raise_if(
17
17
  msg:
18
18
  The message to display when the exception is raised. The message can be a
19
19
  format string (fmt), whose fields are filled with the args and kwargs.
20
+ *args: The arguments to fill the format string.
21
+ **kwargs: The keyword arguments to fill the format string
20
22
  """
21
23
 
22
24
  # Disable host callback if running on unsupported hardware or if the user
@@ -61,6 +63,9 @@ def raise_if(
61
63
  def raise_runtime_error_if(
62
64
  condition: bool | jax.Array, msg: str, *args, **kwargs
63
65
  ) -> None:
66
+ """
67
+ Raise a RuntimeError if a condition is met. Useful in jit-compiled functions.
68
+ """
64
69
 
65
70
  return raise_if(condition, RuntimeError, msg, *args, **kwargs)
66
71
 
@@ -68,5 +73,8 @@ def raise_runtime_error_if(
68
73
  def raise_value_error_if(
69
74
  condition: bool | jax.Array, msg: str, *args, **kwargs
70
75
  ) -> None:
76
+ """
77
+ Raise a ValueError if a condition is met. Useful in jit-compiled functions.
78
+ """
71
79
 
72
80
  return raise_if(condition, ValueError, msg, *args, **kwargs)
@@ -36,9 +36,25 @@ PyTreeType = TypeVar("PyTreeType", bound=jtp.PyTree)
36
36
 
37
37
 
38
38
  class SystemDynamics(Protocol[State, StateDerivative]):
39
+ """
40
+ Protocol defining the system dynamics.
41
+ """
42
+
39
43
  def __call__(
40
44
  self, x: State, t: Time, **kwargs
41
- ) -> tuple[StateDerivative, dict[str, Any]]: ...
45
+ ) -> tuple[StateDerivative, dict[str, Any]]:
46
+ """
47
+ Compute the state derivative of the system.
48
+
49
+ Args:
50
+ x: The state of the system.
51
+ t: The time of the system.
52
+ **kwargs: Additional keyword arguments.
53
+
54
+ Returns:
55
+ The state derivative of the system and the auxiliary dictionary.
56
+ """
57
+ pass
42
58
 
43
59
 
44
60
  # =======================
@@ -48,6 +64,9 @@ class SystemDynamics(Protocol[State, StateDerivative]):
48
64
 
49
65
  @jax_dataclasses.pytree_dataclass
50
66
  class Integrator(JaxsimDataclass, abc.ABC, Generic[State, StateDerivative]):
67
+ """
68
+ Factory class for integrators.
69
+ """
51
70
 
52
71
  dynamics: Static[SystemDynamics[State, StateDerivative]] = dataclasses.field(
53
72
  repr=False, hash=False, compare=False, kw_only=True
@@ -110,6 +129,9 @@ class Integrator(JaxsimDataclass, abc.ABC, Generic[State, StateDerivative]):
110
129
  def __call__(
111
130
  self, x0: State, t0: Time, dt: TimeStep, **kwargs
112
131
  ) -> tuple[NextState, dict[str, Any]]:
132
+ """
133
+ Perform a single integration step.
134
+ """
113
135
  pass
114
136
 
115
137
  def init(
@@ -121,6 +143,9 @@ class Integrator(JaxsimDataclass, abc.ABC, Generic[State, StateDerivative]):
121
143
  include_dynamics_aux_dict: bool = False,
122
144
  **kwargs,
123
145
  ) -> dict[str, Any]:
146
+ """
147
+ Initialize the integrator. This method is deprecated.
148
+ """
124
149
 
125
150
  logging.warning(
126
151
  "The 'init' method has been deprecated. There is no need to call it."
@@ -131,16 +156,28 @@ class Integrator(JaxsimDataclass, abc.ABC, Generic[State, StateDerivative]):
131
156
 
132
157
  @jax_dataclasses.pytree_dataclass
133
158
  class ExplicitRungeKutta(Integrator[PyTreeType, PyTreeType], Generic[PyTreeType]):
159
+ """
160
+ Base class for explicit Runge-Kutta integrators.
161
+
162
+ Attributes:
163
+ A: The Runge-Kutta matrix.
164
+ b: The weights coefficients.
165
+ c: The nodes coefficients.
166
+ order_of_bT_rows: The order of the solution.
167
+ row_index_of_solution: The row of the integration output corresponding to the final solution.
168
+ fsal_enabled_if_supported: Whether to enable the FSAL property, if supported.
169
+ index_of_fsal: The index of the intermediate derivative to be used as the first derivative of the next iteration.
170
+ """
134
171
 
135
172
  # The Runge-Kutta matrix.
136
- A: ClassVar[jtp.Matrix]
173
+ A: jtp.Matrix
137
174
 
138
175
  # The weights coefficients.
139
176
  # Note that in practice we typically use its transpose `b.transpose()`.
140
- b: ClassVar[jtp.Matrix]
177
+ b: jtp.Matrix
141
178
 
142
179
  # The nodes coefficients.
143
- c: ClassVar[jtp.Vector]
180
+ c: jtp.Vector
144
181
 
145
182
  # Define the order of the solution.
146
183
  # It should have as many elements as the number of rows of `b.transpose()`.
@@ -156,10 +193,16 @@ class ExplicitRungeKutta(Integrator[PyTreeType, PyTreeType], Generic[PyTreeType]
156
193
 
157
194
  @property
158
195
  def has_fsal(self) -> bool:
196
+ """
197
+ Check if the integrator supports the FSAL property.
198
+ """
159
199
  return self.fsal_enabled_if_supported and self.index_of_fsal is not None
160
200
 
161
201
  @property
162
202
  def order(self) -> int:
203
+ """
204
+ Return the order of the integrator.
205
+ """
163
206
  return self.order_of_bT_rows[self.row_index_of_solution]
164
207
 
165
208
  @override
@@ -183,28 +226,31 @@ class ExplicitRungeKutta(Integrator[PyTreeType, PyTreeType], Generic[PyTreeType]
183
226
  Returns:
184
227
  The integrator object.
185
228
  """
229
+ A = cls.__dataclass_fields__["A"].default_factory()
230
+ b = cls.__dataclass_fields__["b"].default_factory()
231
+ c = cls.__dataclass_fields__["c"].default_factory()
186
232
 
187
233
  # Check validity of the Butcher tableau.
188
- if not ExplicitRungeKutta.butcher_tableau_is_valid(A=cls.A, b=cls.b, c=cls.c):
234
+ if not ExplicitRungeKutta.butcher_tableau_is_valid(A=A, b=b, c=c):
189
235
  raise ValueError("The Butcher tableau of this class is not valid.")
190
236
 
191
237
  # Check that b.T has enough rows based on the configured index of the solution.
192
- if cls.row_index_of_solution >= cls.b.T.shape[0]:
238
+ if cls.row_index_of_solution >= b.T.shape[0]:
193
239
  msg = "The index of the solution ({}-th row of `b.T`) is out of range ({})."
194
- raise ValueError(msg.format(cls.row_index_of_solution, cls.b.T.shape[0]))
240
+ raise ValueError(msg.format(cls.row_index_of_solution, b.T.shape[0]))
195
241
 
196
242
  # Check that the tuple containing the order of the b.T rows matches the number
197
243
  # of the b.T rows.
198
- if len(cls.order_of_bT_rows) != cls.b.T.shape[0]:
244
+ if len(cls.order_of_bT_rows) != b.T.shape[0]:
199
245
  msg = "Wrong size of 'order_of_bT_rows' ({}), should be {}."
200
- raise ValueError(msg.format(len(cls.order_of_bT_rows), cls.b.T.shape[0]))
246
+ raise ValueError(msg.format(len(cls.order_of_bT_rows), b.T.shape[0]))
201
247
 
202
248
  # Check if the Butcher tableau supports FSAL (first-same-as-last).
203
249
  # If it does, store the index of the intermediate derivative to be used as the
204
250
  # first derivative of the next iteration.
205
251
  has_fsal, index_of_fsal = ( # noqa: F841
206
252
  ExplicitRungeKutta.butcher_tableau_supports_fsal(
207
- A=cls.A, b=cls.b, c=cls.c, index_of_solution=cls.row_index_of_solution
253
+ A=A, b=b, c=c, index_of_solution=cls.row_index_of_solution
208
254
  )
209
255
  )
210
256
 
@@ -221,6 +267,9 @@ class ExplicitRungeKutta(Integrator[PyTreeType, PyTreeType], Generic[PyTreeType]
221
267
  def __call__(
222
268
  self, x0: State, t0: Time, dt: TimeStep, **kwargs
223
269
  ) -> tuple[NextState, dict[str, Any]]:
270
+ """
271
+ Perform a single integration step.
272
+ """
224
273
 
225
274
  # Here z is a batched state with as many batch elements as b.T rows.
226
275
  # Note that z has multiple batches only if b.T has more than one row,
@@ -331,7 +380,9 @@ class ExplicitRungeKutta(Integrator[PyTreeType, PyTreeType], Generic[PyTreeType]
331
380
  def scan_body(
332
381
  carry: jax.Array, i: int | jax.Array
333
382
  ) -> tuple[jax.Array, dict[str, Any]]:
334
- """"""
383
+ """
384
+ Compute the kᵢ derivative of the Runge-Kutta stage.
385
+ """
335
386
 
336
387
  # Unpack the carry, i.e. the stacked kᵢ vectors.
337
388
  K = carry
@@ -498,6 +549,16 @@ class ExplicitRungeKuttaSO3Mixin:
498
549
  def post_process_state(
499
550
  cls, x0: js.ode_data.ODEState, t0: Time, xf: js.ode_data.ODEState, dt: TimeStep
500
551
  ) -> js.ode_data.ODEState:
552
+ r"""
553
+ Post-process the integrated state at :math:`t_f = t_0 + \Delta t` so that the
554
+ quaternion is normalized.
555
+
556
+ Args:
557
+ x0: The initial state of the system.
558
+ t0: The initial time of the system.
559
+ xf: The final state of the system obtain through the integration.
560
+ dt: The time step used for the integration.
561
+ """
501
562
 
502
563
  # Extract the initial base quaternion.
503
564
  W_Q_B_t0 = x0.physics_model.base_quaternion
@@ -0,0 +1,153 @@
1
+ import dataclasses
2
+ from typing import ClassVar, Generic
3
+
4
+ import jax.numpy as jnp
5
+ import jax_dataclasses
6
+
7
+ import jaxsim.api as js
8
+ import jaxsim.typing as jtp
9
+
10
+ from .common import ExplicitRungeKutta, ExplicitRungeKuttaSO3Mixin, PyTreeType
11
+
12
+ ODEStateDerivative = js.ode_data.ODEState
13
+
14
+ # =====================================================
15
+ # Explicit Runge-Kutta integrators operating on PyTrees
16
+ # =====================================================
17
+
18
+
19
+ @jax_dataclasses.pytree_dataclass
20
+ class ForwardEuler(ExplicitRungeKutta[PyTreeType], Generic[PyTreeType]):
21
+ """
22
+ Forward Euler integrator.
23
+ """
24
+
25
+ A: jtp.Matrix = dataclasses.field(
26
+ default_factory=lambda: jnp.atleast_2d(0).astype(float), compare=False
27
+ )
28
+ b: jtp.Matrix = dataclasses.field(
29
+ default_factory=lambda: jnp.atleast_2d(1).astype(float), compare=False
30
+ )
31
+
32
+ c: jtp.Vector = dataclasses.field(
33
+ default_factory=lambda: jnp.atleast_1d(0).astype(float), compare=False
34
+ )
35
+
36
+ row_index_of_solution: int = 0
37
+ order_of_bT_rows: tuple[int, ...] = (1,)
38
+ index_of_fsal: jtp.IntLike | None = None
39
+ fsal_enabled_if_supported: bool = False
40
+
41
+
42
+ @jax_dataclasses.pytree_dataclass
43
+ class Heun2(ExplicitRungeKutta[PyTreeType], Generic[PyTreeType]):
44
+ """
45
+ Heun's second-order integrator.
46
+ """
47
+
48
+ A: jtp.Matrix = dataclasses.field(
49
+ default_factory=lambda: jnp.array(
50
+ [
51
+ [0, 0],
52
+ [1, 0],
53
+ ]
54
+ ).astype(float),
55
+ compare=False,
56
+ )
57
+
58
+ b: jtp.Matrix = dataclasses.field(
59
+ default_factory=lambda: (
60
+ jnp.atleast_2d(
61
+ jnp.array([1 / 2, 1 / 2]),
62
+ )
63
+ .astype(float)
64
+ .transpose()
65
+ ),
66
+ compare=False,
67
+ )
68
+
69
+ c: jtp.Vector = dataclasses.field(
70
+ default_factory=lambda: jnp.array(
71
+ [0, 1],
72
+ ).astype(float),
73
+ compare=False,
74
+ )
75
+
76
+ row_index_of_solution: ClassVar[int] = 0
77
+ order_of_bT_rows: ClassVar[tuple[int, ...]] = (2,)
78
+ index_of_fsal: jtp.IntLike | None = None
79
+ fsal_enabled_if_supported: bool = False
80
+
81
+
82
+ @jax_dataclasses.pytree_dataclass
83
+ class RungeKutta4(ExplicitRungeKutta[PyTreeType], Generic[PyTreeType]):
84
+ """
85
+ Fourth-order Runge-Kutta integrator.
86
+ """
87
+
88
+ A: jtp.Matrix = dataclasses.field(
89
+ default_factory=lambda: jnp.array(
90
+ [
91
+ [0, 0, 0, 0],
92
+ [1 / 2, 0, 0, 0],
93
+ [0, 1 / 2, 0, 0],
94
+ [0, 0, 1, 0],
95
+ ]
96
+ ).astype(float),
97
+ compare=False,
98
+ )
99
+
100
+ b: jtp.Matrix = dataclasses.field(
101
+ default_factory=lambda: (
102
+ jnp.atleast_2d(
103
+ jnp.array([1 / 6, 1 / 3, 1 / 3, 1 / 6]),
104
+ )
105
+ .astype(float)
106
+ .transpose()
107
+ ),
108
+ compare=False,
109
+ )
110
+
111
+ c: jtp.Vector = dataclasses.field(
112
+ default_factory=lambda: jnp.array(
113
+ [0, 1 / 2, 1 / 2, 1],
114
+ ).astype(float),
115
+ compare=False,
116
+ )
117
+
118
+ row_index_of_solution: ClassVar[int] = 0
119
+ order_of_bT_rows: ClassVar[tuple[int, ...]] = (4,)
120
+ index_of_fsal: jtp.IntLike | None = None
121
+ fsal_enabled_if_supported: bool = False
122
+
123
+
124
+ # ===============================================================================
125
+ # Explicit Runge-Kutta integrators operating on ODEState and integrating on SO(3)
126
+ # ===============================================================================
127
+
128
+
129
+ @jax_dataclasses.pytree_dataclass
130
+ class ForwardEulerSO3(ExplicitRungeKuttaSO3Mixin, ForwardEuler[js.ode_data.ODEState]):
131
+ """
132
+ Forward Euler integrator for SO(3) states.
133
+ """
134
+
135
+ pass
136
+
137
+
138
+ @jax_dataclasses.pytree_dataclass
139
+ class Heun2SO3(ExplicitRungeKuttaSO3Mixin, Heun2[js.ode_data.ODEState]):
140
+ """
141
+ Heun's second-order integrator for SO(3) states.
142
+ """
143
+
144
+ pass
145
+
146
+
147
+ @jax_dataclasses.pytree_dataclass
148
+ class RungeKutta4SO3(ExplicitRungeKuttaSO3Mixin, RungeKutta4[js.ode_data.ODEState]):
149
+ """
150
+ Fourth-order Runge-Kutta integrator for SO(3) states.
151
+ """
152
+
153
+ pass