jaxsim 0.6.2.dev281__tar.gz → 0.6.2.dev294__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (136)
  1. {jaxsim-0.6.2.dev281 → jaxsim-0.6.2.dev294}/PKG-INFO +1 -1
  2. {jaxsim-0.6.2.dev281 → jaxsim-0.6.2.dev294}/examples/jaxsim_as_physics_engine.ipynb +15 -5
  3. {jaxsim-0.6.2.dev281 → jaxsim-0.6.2.dev294}/src/jaxsim/_version.py +2 -2
  4. {jaxsim-0.6.2.dev281 → jaxsim-0.6.2.dev294}/src/jaxsim/api/kin_dyn_parameters.py +286 -6
  5. {jaxsim-0.6.2.dev281 → jaxsim-0.6.2.dev294}/src/jaxsim/api/model.py +384 -29
  6. {jaxsim-0.6.2.dev281 → jaxsim-0.6.2.dev294}/src/jaxsim/math/joint_model.py +2 -1
  7. {jaxsim-0.6.2.dev281 → jaxsim-0.6.2.dev294}/src/jaxsim/mujoco/utils.py +1 -5
  8. {jaxsim-0.6.2.dev281 → jaxsim-0.6.2.dev294}/src/jaxsim.egg-info/PKG-INFO +1 -1
  9. {jaxsim-0.6.2.dev281 → jaxsim-0.6.2.dev294}/src/jaxsim.egg-info/SOURCES.txt +1 -0
  10. {jaxsim-0.6.2.dev281 → jaxsim-0.6.2.dev294}/tests/conftest.py +235 -0
  11. jaxsim-0.6.2.dev294/tests/test_api_model_hw_parametrization.py +380 -0
  12. {jaxsim-0.6.2.dev281 → jaxsim-0.6.2.dev294}/tests/test_automatic_differentiation.py +55 -0
  13. {jaxsim-0.6.2.dev281 → jaxsim-0.6.2.dev294}/.devcontainer/Dockerfile +0 -0
  14. {jaxsim-0.6.2.dev281 → jaxsim-0.6.2.dev294}/.devcontainer/devcontainer.json +0 -0
  15. {jaxsim-0.6.2.dev281 → jaxsim-0.6.2.dev294}/.gitattributes +0 -0
  16. {jaxsim-0.6.2.dev281 → jaxsim-0.6.2.dev294}/.github/CODEOWNERS +0 -0
  17. {jaxsim-0.6.2.dev281 → jaxsim-0.6.2.dev294}/.github/dependabot.yml +0 -0
  18. {jaxsim-0.6.2.dev281 → jaxsim-0.6.2.dev294}/.github/release.yml +0 -0
  19. {jaxsim-0.6.2.dev281 → jaxsim-0.6.2.dev294}/.github/workflows/ci_cd.yml +0 -0
  20. {jaxsim-0.6.2.dev281 → jaxsim-0.6.2.dev294}/.github/workflows/gpu_benchmark.yml +0 -0
  21. {jaxsim-0.6.2.dev281 → jaxsim-0.6.2.dev294}/.github/workflows/pixi.yml +0 -0
  22. {jaxsim-0.6.2.dev281 → jaxsim-0.6.2.dev294}/.github/workflows/read_the_docs.yml +0 -0
  23. {jaxsim-0.6.2.dev281 → jaxsim-0.6.2.dev294}/.gitignore +0 -0
  24. {jaxsim-0.6.2.dev281 → jaxsim-0.6.2.dev294}/.pre-commit-config.yaml +0 -0
  25. {jaxsim-0.6.2.dev281 → jaxsim-0.6.2.dev294}/.readthedocs.yaml +0 -0
  26. {jaxsim-0.6.2.dev281 → jaxsim-0.6.2.dev294}/CONTRIBUTING.md +0 -0
  27. {jaxsim-0.6.2.dev281 → jaxsim-0.6.2.dev294}/LICENSE +0 -0
  28. {jaxsim-0.6.2.dev281 → jaxsim-0.6.2.dev294}/README.md +0 -0
  29. {jaxsim-0.6.2.dev281 → jaxsim-0.6.2.dev294}/docs/Makefile +0 -0
  30. {jaxsim-0.6.2.dev281 → jaxsim-0.6.2.dev294}/docs/conf.py +0 -0
  31. {jaxsim-0.6.2.dev281 → jaxsim-0.6.2.dev294}/docs/examples.rst +0 -0
  32. {jaxsim-0.6.2.dev281 → jaxsim-0.6.2.dev294}/docs/guide/configuration.rst +0 -0
  33. {jaxsim-0.6.2.dev281 → jaxsim-0.6.2.dev294}/docs/guide/install.rst +0 -0
  34. {jaxsim-0.6.2.dev281 → jaxsim-0.6.2.dev294}/docs/index.rst +0 -0
  35. {jaxsim-0.6.2.dev281 → jaxsim-0.6.2.dev294}/docs/make.bat +0 -0
  36. {jaxsim-0.6.2.dev281 → jaxsim-0.6.2.dev294}/docs/modules/api.rst +0 -0
  37. {jaxsim-0.6.2.dev281 → jaxsim-0.6.2.dev294}/docs/modules/math.rst +0 -0
  38. {jaxsim-0.6.2.dev281 → jaxsim-0.6.2.dev294}/docs/modules/mujoco.rst +0 -0
  39. {jaxsim-0.6.2.dev281 → jaxsim-0.6.2.dev294}/docs/modules/parsers.rst +0 -0
  40. {jaxsim-0.6.2.dev281 → jaxsim-0.6.2.dev294}/docs/modules/rbda.rst +0 -0
  41. {jaxsim-0.6.2.dev281 → jaxsim-0.6.2.dev294}/docs/modules/typing.rst +0 -0
  42. {jaxsim-0.6.2.dev281 → jaxsim-0.6.2.dev294}/docs/modules/utils.rst +0 -0
  43. {jaxsim-0.6.2.dev281 → jaxsim-0.6.2.dev294}/environment.yml +0 -0
  44. {jaxsim-0.6.2.dev281 → jaxsim-0.6.2.dev294}/examples/.gitattributes +0 -0
  45. {jaxsim-0.6.2.dev281 → jaxsim-0.6.2.dev294}/examples/.gitignore +0 -0
  46. {jaxsim-0.6.2.dev281 → jaxsim-0.6.2.dev294}/examples/README.md +0 -0
  47. {jaxsim-0.6.2.dev281 → jaxsim-0.6.2.dev294}/examples/assets/build_cartpole_urdf.py +0 -0
  48. {jaxsim-0.6.2.dev281 → jaxsim-0.6.2.dev294}/examples/assets/cartpole.urdf +0 -0
  49. {jaxsim-0.6.2.dev281 → jaxsim-0.6.2.dev294}/examples/jaxsim_as_multibody_dynamics_library.ipynb +0 -0
  50. {jaxsim-0.6.2.dev281 → jaxsim-0.6.2.dev294}/examples/jaxsim_as_physics_engine_advanced.ipynb +0 -0
  51. {jaxsim-0.6.2.dev281 → jaxsim-0.6.2.dev294}/examples/jaxsim_for_robot_controllers.ipynb +0 -0
  52. {jaxsim-0.6.2.dev281 → jaxsim-0.6.2.dev294}/pixi.lock +0 -0
  53. {jaxsim-0.6.2.dev281 → jaxsim-0.6.2.dev294}/pyproject.toml +0 -0
  54. {jaxsim-0.6.2.dev281 → jaxsim-0.6.2.dev294}/setup.cfg +0 -0
  55. {jaxsim-0.6.2.dev281 → jaxsim-0.6.2.dev294}/setup.py +0 -0
  56. {jaxsim-0.6.2.dev281 → jaxsim-0.6.2.dev294}/src/jaxsim/__init__.py +0 -0
  57. {jaxsim-0.6.2.dev281 → jaxsim-0.6.2.dev294}/src/jaxsim/api/__init__.py +0 -0
  58. {jaxsim-0.6.2.dev281 → jaxsim-0.6.2.dev294}/src/jaxsim/api/actuation_model.py +0 -0
  59. {jaxsim-0.6.2.dev281 → jaxsim-0.6.2.dev294}/src/jaxsim/api/com.py +0 -0
  60. {jaxsim-0.6.2.dev281 → jaxsim-0.6.2.dev294}/src/jaxsim/api/common.py +0 -0
  61. {jaxsim-0.6.2.dev281 → jaxsim-0.6.2.dev294}/src/jaxsim/api/contact.py +0 -0
  62. {jaxsim-0.6.2.dev281 → jaxsim-0.6.2.dev294}/src/jaxsim/api/data.py +0 -0
  63. {jaxsim-0.6.2.dev281 → jaxsim-0.6.2.dev294}/src/jaxsim/api/frame.py +0 -0
  64. {jaxsim-0.6.2.dev281 → jaxsim-0.6.2.dev294}/src/jaxsim/api/integrators.py +0 -0
  65. {jaxsim-0.6.2.dev281 → jaxsim-0.6.2.dev294}/src/jaxsim/api/joint.py +0 -0
  66. {jaxsim-0.6.2.dev281 → jaxsim-0.6.2.dev294}/src/jaxsim/api/link.py +0 -0
  67. {jaxsim-0.6.2.dev281 → jaxsim-0.6.2.dev294}/src/jaxsim/api/ode.py +0 -0
  68. {jaxsim-0.6.2.dev281 → jaxsim-0.6.2.dev294}/src/jaxsim/api/references.py +0 -0
  69. {jaxsim-0.6.2.dev281 → jaxsim-0.6.2.dev294}/src/jaxsim/exceptions.py +0 -0
  70. {jaxsim-0.6.2.dev281 → jaxsim-0.6.2.dev294}/src/jaxsim/logging.py +0 -0
  71. {jaxsim-0.6.2.dev281 → jaxsim-0.6.2.dev294}/src/jaxsim/math/__init__.py +0 -0
  72. {jaxsim-0.6.2.dev281 → jaxsim-0.6.2.dev294}/src/jaxsim/math/adjoint.py +0 -0
  73. {jaxsim-0.6.2.dev281 → jaxsim-0.6.2.dev294}/src/jaxsim/math/cross.py +0 -0
  74. {jaxsim-0.6.2.dev281 → jaxsim-0.6.2.dev294}/src/jaxsim/math/inertia.py +0 -0
  75. {jaxsim-0.6.2.dev281 → jaxsim-0.6.2.dev294}/src/jaxsim/math/quaternion.py +0 -0
  76. {jaxsim-0.6.2.dev281 → jaxsim-0.6.2.dev294}/src/jaxsim/math/rotation.py +0 -0
  77. {jaxsim-0.6.2.dev281 → jaxsim-0.6.2.dev294}/src/jaxsim/math/skew.py +0 -0
  78. {jaxsim-0.6.2.dev281 → jaxsim-0.6.2.dev294}/src/jaxsim/math/transform.py +0 -0
  79. {jaxsim-0.6.2.dev281 → jaxsim-0.6.2.dev294}/src/jaxsim/math/utils.py +0 -0
  80. {jaxsim-0.6.2.dev281 → jaxsim-0.6.2.dev294}/src/jaxsim/mujoco/__init__.py +0 -0
  81. {jaxsim-0.6.2.dev281 → jaxsim-0.6.2.dev294}/src/jaxsim/mujoco/__main__.py +0 -0
  82. {jaxsim-0.6.2.dev281 → jaxsim-0.6.2.dev294}/src/jaxsim/mujoco/loaders.py +0 -0
  83. {jaxsim-0.6.2.dev281 → jaxsim-0.6.2.dev294}/src/jaxsim/mujoco/model.py +0 -0
  84. {jaxsim-0.6.2.dev281 → jaxsim-0.6.2.dev294}/src/jaxsim/mujoco/visualizer.py +0 -0
  85. {jaxsim-0.6.2.dev281 → jaxsim-0.6.2.dev294}/src/jaxsim/parsers/__init__.py +0 -0
  86. {jaxsim-0.6.2.dev281 → jaxsim-0.6.2.dev294}/src/jaxsim/parsers/descriptions/__init__.py +0 -0
  87. {jaxsim-0.6.2.dev281 → jaxsim-0.6.2.dev294}/src/jaxsim/parsers/descriptions/collision.py +0 -0
  88. {jaxsim-0.6.2.dev281 → jaxsim-0.6.2.dev294}/src/jaxsim/parsers/descriptions/joint.py +0 -0
  89. {jaxsim-0.6.2.dev281 → jaxsim-0.6.2.dev294}/src/jaxsim/parsers/descriptions/link.py +0 -0
  90. {jaxsim-0.6.2.dev281 → jaxsim-0.6.2.dev294}/src/jaxsim/parsers/descriptions/model.py +0 -0
  91. {jaxsim-0.6.2.dev281 → jaxsim-0.6.2.dev294}/src/jaxsim/parsers/kinematic_graph.py +0 -0
  92. {jaxsim-0.6.2.dev281 → jaxsim-0.6.2.dev294}/src/jaxsim/parsers/rod/__init__.py +0 -0
  93. {jaxsim-0.6.2.dev281 → jaxsim-0.6.2.dev294}/src/jaxsim/parsers/rod/meshes.py +0 -0
  94. {jaxsim-0.6.2.dev281 → jaxsim-0.6.2.dev294}/src/jaxsim/parsers/rod/parser.py +0 -0
  95. {jaxsim-0.6.2.dev281 → jaxsim-0.6.2.dev294}/src/jaxsim/parsers/rod/utils.py +0 -0
  96. {jaxsim-0.6.2.dev281 → jaxsim-0.6.2.dev294}/src/jaxsim/rbda/__init__.py +0 -0
  97. {jaxsim-0.6.2.dev281 → jaxsim-0.6.2.dev294}/src/jaxsim/rbda/aba.py +0 -0
  98. {jaxsim-0.6.2.dev281 → jaxsim-0.6.2.dev294}/src/jaxsim/rbda/actuation/__init__.py +0 -0
  99. {jaxsim-0.6.2.dev281 → jaxsim-0.6.2.dev294}/src/jaxsim/rbda/actuation/common.py +0 -0
  100. {jaxsim-0.6.2.dev281 → jaxsim-0.6.2.dev294}/src/jaxsim/rbda/collidable_points.py +0 -0
  101. {jaxsim-0.6.2.dev281 → jaxsim-0.6.2.dev294}/src/jaxsim/rbda/contacts/__init__.py +0 -0
  102. {jaxsim-0.6.2.dev281 → jaxsim-0.6.2.dev294}/src/jaxsim/rbda/contacts/common.py +0 -0
  103. {jaxsim-0.6.2.dev281 → jaxsim-0.6.2.dev294}/src/jaxsim/rbda/contacts/relaxed_rigid.py +0 -0
  104. {jaxsim-0.6.2.dev281 → jaxsim-0.6.2.dev294}/src/jaxsim/rbda/contacts/rigid.py +0 -0
  105. {jaxsim-0.6.2.dev281 → jaxsim-0.6.2.dev294}/src/jaxsim/rbda/contacts/soft.py +0 -0
  106. {jaxsim-0.6.2.dev281 → jaxsim-0.6.2.dev294}/src/jaxsim/rbda/crba.py +0 -0
  107. {jaxsim-0.6.2.dev281 → jaxsim-0.6.2.dev294}/src/jaxsim/rbda/forward_kinematics.py +0 -0
  108. {jaxsim-0.6.2.dev281 → jaxsim-0.6.2.dev294}/src/jaxsim/rbda/jacobian.py +0 -0
  109. {jaxsim-0.6.2.dev281 → jaxsim-0.6.2.dev294}/src/jaxsim/rbda/rnea.py +0 -0
  110. {jaxsim-0.6.2.dev281 → jaxsim-0.6.2.dev294}/src/jaxsim/rbda/utils.py +0 -0
  111. {jaxsim-0.6.2.dev281 → jaxsim-0.6.2.dev294}/src/jaxsim/terrain/__init__.py +0 -0
  112. {jaxsim-0.6.2.dev281 → jaxsim-0.6.2.dev294}/src/jaxsim/terrain/terrain.py +0 -0
  113. {jaxsim-0.6.2.dev281 → jaxsim-0.6.2.dev294}/src/jaxsim/typing.py +0 -0
  114. {jaxsim-0.6.2.dev281 → jaxsim-0.6.2.dev294}/src/jaxsim/utils/__init__.py +0 -0
  115. {jaxsim-0.6.2.dev281 → jaxsim-0.6.2.dev294}/src/jaxsim/utils/jaxsim_dataclass.py +0 -0
  116. {jaxsim-0.6.2.dev281 → jaxsim-0.6.2.dev294}/src/jaxsim/utils/tracing.py +0 -0
  117. {jaxsim-0.6.2.dev281 → jaxsim-0.6.2.dev294}/src/jaxsim/utils/wrappers.py +0 -0
  118. {jaxsim-0.6.2.dev281 → jaxsim-0.6.2.dev294}/src/jaxsim.egg-info/dependency_links.txt +0 -0
  119. {jaxsim-0.6.2.dev281 → jaxsim-0.6.2.dev294}/src/jaxsim.egg-info/requires.txt +0 -0
  120. {jaxsim-0.6.2.dev281 → jaxsim-0.6.2.dev294}/src/jaxsim.egg-info/top_level.txt +0 -0
  121. {jaxsim-0.6.2.dev281 → jaxsim-0.6.2.dev294}/tests/__init__.py +0 -0
  122. {jaxsim-0.6.2.dev281 → jaxsim-0.6.2.dev294}/tests/test_actuation.py +0 -0
  123. {jaxsim-0.6.2.dev281 → jaxsim-0.6.2.dev294}/tests/test_api_com.py +0 -0
  124. {jaxsim-0.6.2.dev281 → jaxsim-0.6.2.dev294}/tests/test_api_contact.py +0 -0
  125. {jaxsim-0.6.2.dev281 → jaxsim-0.6.2.dev294}/tests/test_api_data.py +0 -0
  126. {jaxsim-0.6.2.dev281 → jaxsim-0.6.2.dev294}/tests/test_api_frame.py +0 -0
  127. {jaxsim-0.6.2.dev281 → jaxsim-0.6.2.dev294}/tests/test_api_joint.py +0 -0
  128. {jaxsim-0.6.2.dev281 → jaxsim-0.6.2.dev294}/tests/test_api_link.py +0 -0
  129. {jaxsim-0.6.2.dev281 → jaxsim-0.6.2.dev294}/tests/test_api_model.py +0 -0
  130. {jaxsim-0.6.2.dev281 → jaxsim-0.6.2.dev294}/tests/test_benchmark.py +0 -0
  131. {jaxsim-0.6.2.dev281 → jaxsim-0.6.2.dev294}/tests/test_exceptions.py +0 -0
  132. {jaxsim-0.6.2.dev281 → jaxsim-0.6.2.dev294}/tests/test_meshes.py +0 -0
  133. {jaxsim-0.6.2.dev281 → jaxsim-0.6.2.dev294}/tests/test_pytree.py +0 -0
  134. {jaxsim-0.6.2.dev281 → jaxsim-0.6.2.dev294}/tests/test_simulations.py +0 -0
  135. {jaxsim-0.6.2.dev281 → jaxsim-0.6.2.dev294}/tests/test_visualizer.py +0 -0
  136. {jaxsim-0.6.2.dev281 → jaxsim-0.6.2.dev294}/tests/utils_idyntree.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: jaxsim
3
- Version: 0.6.2.dev281
3
+ Version: 0.6.2.dev294
4
4
  Summary: A differentiable physics engine and multibody dynamics library for control and robot learning.
5
5
  Author-email: Diego Ferigo <dgferigo@gmail.com>, Filippo Luca Ferretti <filippoluca.ferretti@outlook.com>
6
6
  Maintainer-email: Filippo Luca Ferretti <filippo.ferretti@iit.it>, Alessandro Croci <alessandro.croci@iit.it>
@@ -33,6 +33,7 @@
33
33
  "outputs": [],
34
34
  "source": [
35
35
  "# @title Imports and setup\n",
36
+ "import os\n",
36
37
  "import sys\n",
37
38
  "from IPython.display import clear_output\n",
38
39
  "\n",
@@ -62,7 +63,6 @@
62
63
  "import jaxsim.api as js\n",
63
64
  "from jaxsim import logging\n",
64
65
  "import pathlib\n",
65
- "import urllib.request\n",
66
66
  "\n",
67
67
  "logging.set_logging_level(logging.LoggingLevel.WARNING)\n",
68
68
  "print(f\"Running on {jax.devices()}\")"
@@ -110,12 +110,22 @@
110
110
  "outputs": [],
111
111
  "source": [
112
112
  "# Create the JaxSim model.\n",
113
- "url = \"https://raw.githubusercontent.com/icub-tech-iit/ergocub-software/refs/heads/master/urdf/ergoCub/robots/ergoCubSN001/model.urdf\"\n",
113
+ "try:\n",
114
+ " os.environ[\"ROBOT_DESCRIPTION_COMMIT\"] = \"v0.7.7\"\n",
114
115
  "\n",
115
- "# Retrieve the file\n",
116
- "model_path, _ = urllib.request.urlretrieve(url)\n",
116
+ " import robot_descriptions.ergocub_description\n",
117
+ "\n",
118
+ "finally:\n",
119
+ " _ = os.environ.pop(\"ROBOT_DESCRIPTION_COMMIT\", None)\n",
120
+ "\n",
121
+ "model_description_path = pathlib.Path(\n",
122
+ " robot_descriptions.ergocub_description.URDF_PATH.replace(\n",
123
+ " \"ergoCubSN002\", \"ergoCubSN001\"\n",
124
+ " )\n",
125
+ ")\n",
126
+ "\n",
127
+ "clear_output()\n",
117
128
  "\n",
118
- "model_description_path = pathlib.Path(model_path)\n",
119
129
  "full_model = js.model.JaxSimModel.build_from_model_description(\n",
120
130
  " model_description=model_description_path,\n",
121
131
  " time_step=0.0001,\n",
@@ -17,5 +17,5 @@ __version__: str
17
17
  __version_tuple__: VERSION_TUPLE
18
18
  version_tuple: VERSION_TUPLE
19
19
 
20
- __version__ = version = '0.6.2.dev281'
21
- __version_tuple__ = version_tuple = (0, 6, 2, 'dev281')
20
+ __version__ = version = '0.6.2.dev294'
21
+ __version_tuple__ = version_tuple = (0, 6, 2, 'dev294')
@@ -1,6 +1,7 @@
1
1
  from __future__ import annotations
2
2
 
3
3
  import dataclasses
4
+ from typing import ClassVar
4
5
 
5
6
  import jax.lax
6
7
  import jax.numpy as jnp
@@ -9,8 +10,10 @@ import numpy as np
9
10
  import numpy.typing as npt
10
11
  from jax_dataclasses import Static
11
12
 
13
+ import jaxsim
12
14
  import jaxsim.typing as jtp
13
- from jaxsim.math import Adjoint, Inertia, JointModel, supported_joint_motion
15
+ from jaxsim.math import Inertia, JointModel, supported_joint_motion
16
+ from jaxsim.math.adjoint import Adjoint
14
17
  from jaxsim.parsers.descriptions import JointDescription, JointType, ModelDescription
15
18
  from jaxsim.utils import HashedNumpyArray, JaxsimDataclass
16
19
 
@@ -30,6 +33,7 @@ class KinDynParameters(JaxsimDataclass):
30
33
  contact_parameters: The parameters of the collidable points.
31
34
  joint_model: The joint model of the model.
32
35
  joint_parameters: The parameters of the joints.
36
+ hw_link_metadata: The hardware parameters of the model links.
33
37
  """
34
38
 
35
39
  # Static
@@ -51,6 +55,9 @@ class KinDynParameters(JaxsimDataclass):
51
55
  joint_model: JointModel
52
56
  joint_parameters: JointParameters | None
53
57
 
58
+ # Model hardware parameters
59
+ hw_link_metadata: HwLinkMetadata | None = dataclasses.field(default=None)
60
+
54
61
  @property
55
62
  def motion_subspaces(self) -> jtp.Matrix:
56
63
  r"""
@@ -197,7 +204,6 @@ class KinDynParameters(JaxsimDataclass):
197
204
  carry0 = κb, link_index
198
205
 
199
206
  def scan_body(carry: tuple, i: jtp.Int) -> tuple[tuple, None]:
200
-
201
207
  κb, active_link_index = carry
202
208
 
203
209
  κb, active_link_index = jax.lax.cond(
@@ -224,7 +230,6 @@ class KinDynParameters(JaxsimDataclass):
224
230
  )
225
231
 
226
232
  def motion_subspace(joint_type: int, axis: npt.ArrayLike) -> npt.ArrayLike:
227
-
228
233
  S = {
229
234
  JointType.Fixed: np.zeros(shape=(6, 1)),
230
235
  JointType.Revolute: np.vstack(np.hstack([np.zeros(3), axis.axis])),
@@ -265,14 +270,12 @@ class KinDynParameters(JaxsimDataclass):
265
270
  )
266
271
 
267
272
  def __eq__(self, other: KinDynParameters) -> bool:
268
-
269
273
  if not isinstance(other, KinDynParameters):
270
274
  return False
271
275
 
272
276
  return hash(self) == hash(other)
273
277
 
274
278
  def __hash__(self) -> int:
275
-
276
279
  return hash(
277
280
  (
278
281
  hash(self.number_of_links()),
@@ -671,7 +674,11 @@ class LinkParameters(JaxsimDataclass):
671
674
 
672
675
  return (
673
676
  jnp.hstack(
674
- [params.mass, params.center_of_mass.squeeze(), params.inertia_elements]
677
+ [
678
+ params.mass,
679
+ params.center_of_mass.squeeze(),
680
+ params.inertia_elements,
681
+ ]
675
682
  )
676
683
  .squeeze()
677
684
  .astype(float)
@@ -882,3 +889,276 @@ class FrameParameters(JaxsimDataclass):
882
889
  assert fp.transform.shape[0] == len(fp.body), fp.transform.shape[0]
883
890
 
884
891
  return fp
892
+
893
+
894
@dataclasses.dataclass(frozen=True)
class LinkParametrizableShape:
    """
    Integer codes of the link shapes supported by the HW parametrization.

    The codes are plain ``int`` class variables (not dataclass fields), so the
    class behaves like a lightweight enum whose members can be compared
    directly with integer values, including JAX integer arrays.
    """

    # Sentinel marking links whose geometry cannot be parametrized.
    Unsupported: ClassVar[int] = -1

    # Parametrizable shapes.
    Box: ClassVar[int] = 0
    Cylinder: ClassVar[int] = 1
    Sphere: ClassVar[int] = 2
904
+
905
+
906
@jax_dataclasses.pytree_dataclass
class HwLinkMetadata(JaxsimDataclass):
    """
    Class storing the hardware parameters of a link.

    Attributes:
        shape: The shape of the link, encoded with the values of
            `LinkParametrizableShape`:
            0 = box, 1 = cylinder, 2 = sphere, -1 = unsupported.
        dims: The dimensions of the link.
            box: [lx,ly,lz], cylinder: [r,l,0], sphere: [r,0,0]
        density: The density of the link.
        L_H_G: The homogeneous transformation matrix from the link frame to the CoM frame G.
        L_H_vis: The homogeneous transformation matrix from the link frame to the visual frame.
        L_H_pre_mask: The mask indicating the link's child joint indices.
        L_H_pre: The homogeneous transforms for child joints.
    """

    shape: jtp.Vector
    dims: jtp.Vector
    density: jtp.Float
    L_H_G: jtp.Matrix
    L_H_vis: jtp.Matrix
    L_H_pre_mask: jtp.Vector
    L_H_pre: jtp.Matrix

    @staticmethod
    def compute_mass_and_inertia(
        hw_link_metadata: HwLinkMetadata,
    ) -> tuple[jtp.Float, jtp.Matrix]:
        """
        Compute the mass and inertia of a hardware link based on its metadata.

        This function calculates the mass and inertia tensor of a hardware link
        using its shape, dimensions, and density. The computation is performed
        by using shape-specific methods.

        Args:
            hw_link_metadata: Metadata describing the hardware link,
                including its shape, dimensions, and density.

        Returns:
            tuple: A tuple containing:
                - mass: The computed mass of the hardware link.
                - inertia: The computed inertia tensor of the hardware link.

        Note:
            The branch order must match the values of `LinkParametrizableShape`
            (Box=0, Cylinder=1, Sphere=2). `jax.lax.switch` clamps out-of-range
            indices, so a link with shape `Unsupported` (-1) evaluates the box
            branch; callers should not rely on the result for such links.
        """

        mass, inertia = jax.lax.switch(
            hw_link_metadata.shape,
            [
                HwLinkMetadata._box,
                HwLinkMetadata._cylinder,
                HwLinkMetadata._sphere,
            ],
            hw_link_metadata.dims,
            hw_link_metadata.density,
        )
        return mass, inertia

    @staticmethod
    def _box(dims, density) -> tuple[jtp.Float, jtp.Matrix]:
        """
        Mass and inertia (about the CoM, box-aligned axes) of a solid box.

        Args:
            dims: The box side lengths [lx, ly, lz].
            density: The box density.
        """
        lx, ly, lz = dims

        mass = density * lx * ly * lz

        inertia = jnp.array(
            [
                [mass * (ly**2 + lz**2) / 12, 0, 0],
                [0, mass * (lx**2 + lz**2) / 12, 0],
                [0, 0, mass * (lx**2 + ly**2) / 12],
            ]
        )
        return mass, inertia

    @staticmethod
    def _cylinder(dims, density) -> tuple[jtp.Float, jtp.Matrix]:
        """
        Mass and inertia (about the CoM) of a solid cylinder whose axis is z.

        Args:
            dims: The cylinder parameters [r, l, unused].
            density: The cylinder density.
        """
        r, l, _ = dims

        mass = density * (jnp.pi * r**2 * l)

        inertia = jnp.array(
            [
                [mass * (3 * r**2 + l**2) / 12, 0, 0],
                [0, mass * (3 * r**2 + l**2) / 12, 0],
                [0, 0, mass * (r**2) / 2],
            ]
        )

        return mass, inertia

    @staticmethod
    def _sphere(dims, density) -> tuple[jtp.Float, jtp.Matrix]:
        """
        Mass and inertia (about the CoM) of a solid sphere.

        Args:
            dims: The sphere parameters [r, unused, unused].
            density: The sphere density.
        """
        r = dims[0]

        mass = density * (4 / 3 * jnp.pi * r**3)

        inertia = jnp.eye(3) * (2 / 5 * mass * r**2)

        return mass, inertia

    @staticmethod
    def _convert_scaling_to_3d_vector(
        shape: jtp.Int, scaling_factors: jtp.Vector
    ) -> jtp.Vector:
        """
        Convert scaling factors for specific shape dimensions into a 3D scaling vector.

        Args:
            shape: The shape of the link (e.g., box, sphere, cylinder).
            scaling_factors: The scaling factors for the shape dimensions.

        Returns:
            A 3D scaling vector to apply to position vectors.

        Note:
            The scaling factors are applied as follows to generate the 3D scale vector:
            - Box: [lx, ly, lz]
            - Cylinder: [r, r, l]
            - Sphere: [r, r, r]
            As in `compute_mass_and_inertia`, `jax.lax.switch` clamps the
            `Unsupported` (-1) code to the box branch.
        """
        return jax.lax.switch(
            shape,
            branches=[
                # Box
                lambda: scaling_factors,
                # Cylinder
                lambda: jnp.array(
                    [
                        scaling_factors[0],
                        scaling_factors[0],
                        scaling_factors[1],
                    ]
                ),
                # Sphere
                lambda: jnp.array(
                    [
                        scaling_factors[0],
                        scaling_factors[0],
                        scaling_factors[0],
                    ]
                ),
            ],
        )

    @staticmethod
    def compute_inertia_link(I_com, mass, L_H_G) -> jtp.Matrix:
        """
        Rotate the inertia tensor from the CoM frame G to the link frame orientation.

        The result is ``L_R_G @ I_com @ L_R_G.T``, where ``L_R_G`` is the
        rotational part of ``L_H_G``. No parallel-axis (translation) term is
        applied, so the tensor is still taken about the CoM.

        Args:
            I_com: The 3x3 inertia tensor expressed in the G frame.
            mass: Unused; kept for backward compatibility of the signature.
            L_H_G: The homogeneous transform from the link frame to the CoM frame G.

        Returns:
            The 3x3 inertia tensor expressed with the orientation of the link frame.
        """

        L_R_G = L_H_G[:3, :3]
        return L_R_G @ I_com @ L_R_G.T

    @staticmethod
    def apply_scaling(
        hw_metadata: HwLinkMetadata, scaling_factors: ScalingFactors
    ) -> HwLinkMetadata:
        """
        Apply scaling to the hardware parameters and return a new HwLinkMetadata object.

        The CoM frame G is used as the scaling origin. The positions of the link
        frame, of the visual frame and of the masked child-joint frames are
        expressed in G. They are scaled element-wise by the 3D scale vector
        derived from the shape, and then expressed back in the scaled link
        frame. Only translations are modified, so the orientations are preserved.
        The dimensions and the density are multiplied by the corresponding
        scaling factors.

        Args:
            hw_metadata: the original HwLinkMetadata object.
            scaling_factors: the scaling factors to apply.

        Returns:
            A new HwLinkMetadata object with updated parameters. Links whose shape
            is `LinkParametrizableShape.Unsupported` are returned unchanged.
        """

        # ==================================
        # Handle unsupported links
        # ==================================
        def unsupported_case(hw_metadata, scaling_factors):
            # Return the metadata unchanged for unsupported links
            return hw_metadata

        def supported_case(hw_metadata, scaling_factors):
            # ==================================
            # Update the kinematics of the link
            # ==================================

            # Get the nominal transforms
            L_H_G = hw_metadata.L_H_G
            L_H_vis = hw_metadata.L_H_vis
            L_H_pre_array = hw_metadata.L_H_pre
            L_H_pre_mask = hw_metadata.L_H_pre_mask

            # Compute the 3D scaling vector
            scale_vector = HwLinkMetadata._convert_scaling_to_3d_vector(
                hw_metadata.shape, scaling_factors.dims
            )

            # Express the transforms in the G frame
            G_H_L = jaxsim.math.Transform.inverse(L_H_G)
            G_H_vis = G_H_L @ L_H_vis
            G_H_pre_array = jax.vmap(lambda L_H_pre: G_H_L @ L_H_pre)(L_H_pre_array)

            # Apply the scaling to the position vectors.
            # The `_scaled` suffix replaces the overline ("bar") notation, which
            # produced names differing only by an invisible combining character.
            G_H_L_scaled = G_H_L.at[:3, 3].set(scale_vector * G_H_L[:3, 3])
            G_H_vis_scaled = G_H_vis.at[:3, 3].set(scale_vector * G_H_vis[:3, 3])

            # Apply scaling to the position vectors in G_H_pre_array based on the mask
            G_H_pre_array_scaled = jax.vmap(
                lambda G_H_pre, mask: jnp.where(
                    # Expand mask for broadcasting
                    mask[..., None, None],
                    # Apply scaling
                    G_H_pre.at[:3, 3].set(scale_vector * G_H_pre[:3, 3]),
                    # Keep unchanged if mask is False
                    G_H_pre,
                )
            )(G_H_pre_array, L_H_pre_mask)

            # Get back to the link frame
            L_H_G_scaled = jaxsim.math.Transform.inverse(G_H_L_scaled)
            L_H_vis_scaled = L_H_G_scaled @ G_H_vis_scaled
            L_H_pre_array_scaled = jax.vmap(
                lambda G_H_pre_scaled: L_H_G_scaled @ G_H_pre_scaled
            )(G_H_pre_array_scaled)

            # ============================
            # Update the shape parameters
            # ============================

            updated_dims = hw_metadata.dims * scaling_factors.dims

            # ==============================
            # Scale the density of the link
            # ==============================

            updated_density = hw_metadata.density * scaling_factors.density

            # ============================
            # Return updated HwLinkMetadata
            # ============================

            return hw_metadata.replace(
                dims=updated_dims,
                density=updated_density,
                L_H_G=L_H_G_scaled,
                L_H_vis=L_H_vis_scaled,
                L_H_pre=L_H_pre_array_scaled,
            )

        # Use jax.lax.cond to handle unsupported links
        return jax.lax.cond(
            hw_metadata.shape == LinkParametrizableShape.Unsupported,
            lambda: unsupported_case(hw_metadata, scaling_factors),
            lambda: supported_case(hw_metadata, scaling_factors),
        )
1151
+
1152
+
1153
@jax_dataclasses.pytree_dataclass
class ScalingFactors(JaxsimDataclass):
    """
    Multiplicative factors used to rescale the hardware parameters of a link.

    Attributes:
        dims: Factors multiplied element-wise with the shape dimensions.
        density: Factor multiplied with the link density.
    """

    # Element-wise multipliers for the shape dimensions.
    dims: jtp.Vector

    # Multiplier for the link density.
    density: jtp.Float