jaxsim 0.2.1.dev38__tar.gz → 0.2.1.dev47__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (116)
  1. {jaxsim-0.2.1.dev38 → jaxsim-0.2.1.dev47}/PKG-INFO +1 -1
  2. {jaxsim-0.2.1.dev38 → jaxsim-0.2.1.dev47}/examples/PD_controller.ipynb +0 -41
  3. {jaxsim-0.2.1.dev38 → jaxsim-0.2.1.dev47}/examples/Parallel_computing.ipynb +4 -14
  4. jaxsim-0.2.1.dev47/examples/README.md +29 -0
  5. {jaxsim-0.2.1.dev38 → jaxsim-0.2.1.dev47}/src/jaxsim/_version.py +2 -2
  6. {jaxsim-0.2.1.dev38 → jaxsim-0.2.1.dev47}/src/jaxsim/api/joint.py +3 -12
  7. {jaxsim-0.2.1.dev38 → jaxsim-0.2.1.dev47}/src/jaxsim/api/kin_dyn_parameters.py +9 -10
  8. {jaxsim-0.2.1.dev38 → jaxsim-0.2.1.dev47}/src/jaxsim/math/joint_model.py +38 -44
  9. {jaxsim-0.2.1.dev38 → jaxsim-0.2.1.dev47}/src/jaxsim/parsers/descriptions/__init__.py +1 -1
  10. {jaxsim-0.2.1.dev38 → jaxsim-0.2.1.dev47}/src/jaxsim/parsers/descriptions/joint.py +10 -35
  11. {jaxsim-0.2.1.dev38 → jaxsim-0.2.1.dev47}/src/jaxsim/parsers/kinematic_graph.py +6 -4
  12. {jaxsim-0.2.1.dev38 → jaxsim-0.2.1.dev47}/src/jaxsim/parsers/rod/parser.py +1 -1
  13. {jaxsim-0.2.1.dev38 → jaxsim-0.2.1.dev47}/src/jaxsim/parsers/rod/utils.py +4 -8
  14. {jaxsim-0.2.1.dev38 → jaxsim-0.2.1.dev47}/src/jaxsim.egg-info/PKG-INFO +1 -1
  15. jaxsim-0.2.1.dev38/examples/README.md +0 -37
  16. {jaxsim-0.2.1.dev38 → jaxsim-0.2.1.dev47}/.devcontainer/Dockerfile +0 -0
  17. {jaxsim-0.2.1.dev38 → jaxsim-0.2.1.dev47}/.devcontainer/devcontainer.json +0 -0
  18. {jaxsim-0.2.1.dev38 → jaxsim-0.2.1.dev47}/.gitattributes +0 -0
  19. {jaxsim-0.2.1.dev38 → jaxsim-0.2.1.dev47}/.github/CODEOWNERS +0 -0
  20. {jaxsim-0.2.1.dev38 → jaxsim-0.2.1.dev47}/.github/workflows/ci_cd.yml +0 -0
  21. {jaxsim-0.2.1.dev38 → jaxsim-0.2.1.dev47}/.github/workflows/read_the_docs.yml +0 -0
  22. {jaxsim-0.2.1.dev38 → jaxsim-0.2.1.dev47}/.github/workflows/style.yml +0 -0
  23. {jaxsim-0.2.1.dev38 → jaxsim-0.2.1.dev47}/.gitignore +0 -0
  24. {jaxsim-0.2.1.dev38 → jaxsim-0.2.1.dev47}/.pre-commit-config.yaml +0 -0
  25. {jaxsim-0.2.1.dev38 → jaxsim-0.2.1.dev47}/.readthedocs.yaml +0 -0
  26. {jaxsim-0.2.1.dev38 → jaxsim-0.2.1.dev47}/CONTRIBUTING.md +0 -0
  27. {jaxsim-0.2.1.dev38 → jaxsim-0.2.1.dev47}/LICENSE +0 -0
  28. {jaxsim-0.2.1.dev38 → jaxsim-0.2.1.dev47}/README.md +0 -0
  29. {jaxsim-0.2.1.dev38 → jaxsim-0.2.1.dev47}/docs/Makefile +0 -0
  30. {jaxsim-0.2.1.dev38 → jaxsim-0.2.1.dev47}/docs/conf.py +0 -0
  31. {jaxsim-0.2.1.dev38 → jaxsim-0.2.1.dev47}/docs/guide/install.rst +0 -0
  32. {jaxsim-0.2.1.dev38 → jaxsim-0.2.1.dev47}/docs/index.rst +0 -0
  33. {jaxsim-0.2.1.dev38 → jaxsim-0.2.1.dev47}/docs/make.bat +0 -0
  34. {jaxsim-0.2.1.dev38 → jaxsim-0.2.1.dev47}/docs/modules/api.rst +0 -0
  35. {jaxsim-0.2.1.dev38 → jaxsim-0.2.1.dev47}/docs/modules/index.rst +0 -0
  36. {jaxsim-0.2.1.dev38 → jaxsim-0.2.1.dev47}/docs/modules/integrators.rst +0 -0
  37. {jaxsim-0.2.1.dev38 → jaxsim-0.2.1.dev47}/docs/modules/math.rst +0 -0
  38. {jaxsim-0.2.1.dev38 → jaxsim-0.2.1.dev47}/docs/modules/mujoco.rst +0 -0
  39. {jaxsim-0.2.1.dev38 → jaxsim-0.2.1.dev47}/docs/modules/parsers.rst +0 -0
  40. {jaxsim-0.2.1.dev38 → jaxsim-0.2.1.dev47}/docs/modules/rbda.rst +0 -0
  41. {jaxsim-0.2.1.dev38 → jaxsim-0.2.1.dev47}/docs/modules/typing.rst +0 -0
  42. {jaxsim-0.2.1.dev38 → jaxsim-0.2.1.dev47}/docs/modules/utils.rst +0 -0
  43. {jaxsim-0.2.1.dev38 → jaxsim-0.2.1.dev47}/environment.yml +0 -0
  44. {jaxsim-0.2.1.dev38 → jaxsim-0.2.1.dev47}/examples/.gitattributes +0 -0
  45. {jaxsim-0.2.1.dev38 → jaxsim-0.2.1.dev47}/examples/.gitignore +0 -0
  46. {jaxsim-0.2.1.dev38 → jaxsim-0.2.1.dev47}/examples/assets/cartpole.urdf +0 -0
  47. {jaxsim-0.2.1.dev38 → jaxsim-0.2.1.dev47}/pixi.lock +0 -0
  48. {jaxsim-0.2.1.dev38 → jaxsim-0.2.1.dev47}/pyproject.toml +0 -0
  49. {jaxsim-0.2.1.dev38 → jaxsim-0.2.1.dev47}/setup.cfg +0 -0
  50. {jaxsim-0.2.1.dev38 → jaxsim-0.2.1.dev47}/setup.py +0 -0
  51. {jaxsim-0.2.1.dev38 → jaxsim-0.2.1.dev47}/src/jaxsim/__init__.py +0 -0
  52. {jaxsim-0.2.1.dev38 → jaxsim-0.2.1.dev47}/src/jaxsim/api/__init__.py +0 -0
  53. {jaxsim-0.2.1.dev38 → jaxsim-0.2.1.dev47}/src/jaxsim/api/com.py +0 -0
  54. {jaxsim-0.2.1.dev38 → jaxsim-0.2.1.dev47}/src/jaxsim/api/common.py +0 -0
  55. {jaxsim-0.2.1.dev38 → jaxsim-0.2.1.dev47}/src/jaxsim/api/contact.py +0 -0
  56. {jaxsim-0.2.1.dev38 → jaxsim-0.2.1.dev47}/src/jaxsim/api/data.py +0 -0
  57. {jaxsim-0.2.1.dev38 → jaxsim-0.2.1.dev47}/src/jaxsim/api/link.py +0 -0
  58. {jaxsim-0.2.1.dev38 → jaxsim-0.2.1.dev47}/src/jaxsim/api/model.py +0 -0
  59. {jaxsim-0.2.1.dev38 → jaxsim-0.2.1.dev47}/src/jaxsim/api/ode.py +0 -0
  60. {jaxsim-0.2.1.dev38 → jaxsim-0.2.1.dev47}/src/jaxsim/api/ode_data.py +0 -0
  61. {jaxsim-0.2.1.dev38 → jaxsim-0.2.1.dev47}/src/jaxsim/api/references.py +0 -0
  62. {jaxsim-0.2.1.dev38 → jaxsim-0.2.1.dev47}/src/jaxsim/integrators/__init__.py +0 -0
  63. {jaxsim-0.2.1.dev38 → jaxsim-0.2.1.dev47}/src/jaxsim/integrators/common.py +0 -0
  64. {jaxsim-0.2.1.dev38 → jaxsim-0.2.1.dev47}/src/jaxsim/integrators/fixed_step.py +0 -0
  65. {jaxsim-0.2.1.dev38 → jaxsim-0.2.1.dev47}/src/jaxsim/integrators/variable_step.py +0 -0
  66. {jaxsim-0.2.1.dev38 → jaxsim-0.2.1.dev47}/src/jaxsim/logging.py +0 -0
  67. {jaxsim-0.2.1.dev38 → jaxsim-0.2.1.dev47}/src/jaxsim/math/__init__.py +0 -0
  68. {jaxsim-0.2.1.dev38 → jaxsim-0.2.1.dev47}/src/jaxsim/math/adjoint.py +0 -0
  69. {jaxsim-0.2.1.dev38 → jaxsim-0.2.1.dev47}/src/jaxsim/math/cross.py +0 -0
  70. {jaxsim-0.2.1.dev38 → jaxsim-0.2.1.dev47}/src/jaxsim/math/inertia.py +0 -0
  71. {jaxsim-0.2.1.dev38 → jaxsim-0.2.1.dev47}/src/jaxsim/math/quaternion.py +0 -0
  72. {jaxsim-0.2.1.dev38 → jaxsim-0.2.1.dev47}/src/jaxsim/math/rotation.py +0 -0
  73. {jaxsim-0.2.1.dev38 → jaxsim-0.2.1.dev47}/src/jaxsim/math/skew.py +0 -0
  74. {jaxsim-0.2.1.dev38 → jaxsim-0.2.1.dev47}/src/jaxsim/math/transform.py +0 -0
  75. {jaxsim-0.2.1.dev38 → jaxsim-0.2.1.dev47}/src/jaxsim/mujoco/__init__.py +0 -0
  76. {jaxsim-0.2.1.dev38 → jaxsim-0.2.1.dev47}/src/jaxsim/mujoco/__main__.py +0 -0
  77. {jaxsim-0.2.1.dev38 → jaxsim-0.2.1.dev47}/src/jaxsim/mujoco/loaders.py +0 -0
  78. {jaxsim-0.2.1.dev38 → jaxsim-0.2.1.dev47}/src/jaxsim/mujoco/model.py +0 -0
  79. {jaxsim-0.2.1.dev38 → jaxsim-0.2.1.dev47}/src/jaxsim/mujoco/visualizer.py +0 -0
  80. {jaxsim-0.2.1.dev38 → jaxsim-0.2.1.dev47}/src/jaxsim/parsers/__init__.py +0 -0
  81. {jaxsim-0.2.1.dev38 → jaxsim-0.2.1.dev47}/src/jaxsim/parsers/descriptions/collision.py +0 -0
  82. {jaxsim-0.2.1.dev38 → jaxsim-0.2.1.dev47}/src/jaxsim/parsers/descriptions/link.py +0 -0
  83. {jaxsim-0.2.1.dev38 → jaxsim-0.2.1.dev47}/src/jaxsim/parsers/descriptions/model.py +0 -0
  84. {jaxsim-0.2.1.dev38 → jaxsim-0.2.1.dev47}/src/jaxsim/parsers/rod/__init__.py +0 -0
  85. {jaxsim-0.2.1.dev38 → jaxsim-0.2.1.dev47}/src/jaxsim/rbda/__init__.py +0 -0
  86. {jaxsim-0.2.1.dev38 → jaxsim-0.2.1.dev47}/src/jaxsim/rbda/aba.py +0 -0
  87. {jaxsim-0.2.1.dev38 → jaxsim-0.2.1.dev47}/src/jaxsim/rbda/collidable_points.py +0 -0
  88. {jaxsim-0.2.1.dev38 → jaxsim-0.2.1.dev47}/src/jaxsim/rbda/crba.py +0 -0
  89. {jaxsim-0.2.1.dev38 → jaxsim-0.2.1.dev47}/src/jaxsim/rbda/forward_kinematics.py +0 -0
  90. {jaxsim-0.2.1.dev38 → jaxsim-0.2.1.dev47}/src/jaxsim/rbda/jacobian.py +0 -0
  91. {jaxsim-0.2.1.dev38 → jaxsim-0.2.1.dev47}/src/jaxsim/rbda/rnea.py +0 -0
  92. {jaxsim-0.2.1.dev38 → jaxsim-0.2.1.dev47}/src/jaxsim/rbda/soft_contacts.py +0 -0
  93. {jaxsim-0.2.1.dev38 → jaxsim-0.2.1.dev47}/src/jaxsim/rbda/utils.py +0 -0
  94. {jaxsim-0.2.1.dev38 → jaxsim-0.2.1.dev47}/src/jaxsim/terrain/__init__.py +0 -0
  95. {jaxsim-0.2.1.dev38 → jaxsim-0.2.1.dev47}/src/jaxsim/terrain/terrain.py +0 -0
  96. {jaxsim-0.2.1.dev38 → jaxsim-0.2.1.dev47}/src/jaxsim/typing.py +0 -0
  97. {jaxsim-0.2.1.dev38 → jaxsim-0.2.1.dev47}/src/jaxsim/utils/__init__.py +0 -0
  98. {jaxsim-0.2.1.dev38 → jaxsim-0.2.1.dev47}/src/jaxsim/utils/hashless.py +0 -0
  99. {jaxsim-0.2.1.dev38 → jaxsim-0.2.1.dev47}/src/jaxsim/utils/jaxsim_dataclass.py +0 -0
  100. {jaxsim-0.2.1.dev38 → jaxsim-0.2.1.dev47}/src/jaxsim/utils/tracing.py +0 -0
  101. {jaxsim-0.2.1.dev38 → jaxsim-0.2.1.dev47}/src/jaxsim.egg-info/SOURCES.txt +0 -0
  102. {jaxsim-0.2.1.dev38 → jaxsim-0.2.1.dev47}/src/jaxsim.egg-info/dependency_links.txt +0 -0
  103. {jaxsim-0.2.1.dev38 → jaxsim-0.2.1.dev47}/src/jaxsim.egg-info/not-zip-safe +0 -0
  104. {jaxsim-0.2.1.dev38 → jaxsim-0.2.1.dev47}/src/jaxsim.egg-info/requires.txt +0 -0
  105. {jaxsim-0.2.1.dev38 → jaxsim-0.2.1.dev47}/src/jaxsim.egg-info/top_level.txt +0 -0
  106. {jaxsim-0.2.1.dev38 → jaxsim-0.2.1.dev47}/tests/__init__.py +0 -0
  107. {jaxsim-0.2.1.dev38 → jaxsim-0.2.1.dev47}/tests/conftest.py +0 -0
  108. {jaxsim-0.2.1.dev38 → jaxsim-0.2.1.dev47}/tests/test_api_com.py +0 -0
  109. {jaxsim-0.2.1.dev38 → jaxsim-0.2.1.dev47}/tests/test_api_data.py +0 -0
  110. {jaxsim-0.2.1.dev38 → jaxsim-0.2.1.dev47}/tests/test_api_joint.py +0 -0
  111. {jaxsim-0.2.1.dev38 → jaxsim-0.2.1.dev47}/tests/test_api_link.py +0 -0
  112. {jaxsim-0.2.1.dev38 → jaxsim-0.2.1.dev47}/tests/test_api_model.py +0 -0
  113. {jaxsim-0.2.1.dev38 → jaxsim-0.2.1.dev47}/tests/test_automatic_differentiation.py +0 -0
  114. {jaxsim-0.2.1.dev38 → jaxsim-0.2.1.dev47}/tests/test_pytree.py +0 -0
  115. {jaxsim-0.2.1.dev38 → jaxsim-0.2.1.dev47}/tests/test_simulations.py +0 -0
  116. {jaxsim-0.2.1.dev38 → jaxsim-0.2.1.dev47}/tests/utils_idyntree.py +0 -0
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: jaxsim
- Version: 0.2.1.dev38
+ Version: 0.2.1.dev47
  Home-page: https://github.com/ami-iit/jaxsim
  Author: Diego Ferigo
  Author-email: diego.ferigo@iit.it
@@ -6,10 +6,6 @@
  "source": [
  "# `JAXsim` Showcase: PD Controller\n",
  "\n",
- "<a target=\"_blank\" href=\"https://colab.research.google.com/github/ami-iit/jaxsim/blob/main/examples/PD_controller.ipynb\">\n",
- " <img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/>\n",
- "</a>\n",
- "\n",
  "First, we install the necessary packages and import them."
  ]
  },
@@ -23,14 +19,6 @@
  "from IPython.display import clear_output, HTML, display\n",
  "import sys\n",
  "\n",
- "IS_COLAB = \"google.colab\" in sys.modules\n",
- "\n",
- "# Install JAX and Gazebo\n",
- "if IS_COLAB:\n",
- " !{sys.executable} -m pip install -U -q jaxsim\n",
- " !apt -qq update && apt install -qq --no-install-recommends gazebo\n",
- " clear_output()\n",
- "\n",
  "import jax\n",
  "import jax.numpy as jnp\n",
  "from jaxsim import logging\n",
@@ -141,35 +129,6 @@
  "from jaxsim.mujoco.loaders import UrdfToMjcf\n",
  "\n",
  "\n",
- "if IS_COLAB:\n",
- " if subprocess.run(\"ffmpeg -version\", shell=True).returncode:\n",
- " !command -v ffmpeg >/dev/null || (apt update && apt install -y ffmpeg)\n",
- " clear_output()\n",
- "\n",
- " if subprocess.run(\"nvidia-smi\").returncode:\n",
- " raise RuntimeError(\n",
- " \"Cannot communicate with GPU. \"\n",
- " \"Make sure you are using a GPU Colab runtime. \"\n",
- " \"Go to the Runtime menu and select Choose runtime type.\"\n",
- " )\n",
- "\n",
- " # Add an ICD config so that glvnd can pick up the Nvidia EGL driver.\n",
- " # This is usually installed as part of an Nvidia driver package, but the Colab\n",
- " # kernel doesn't install its driver via APT, and as a result the ICD is missing.\n",
- " # (https://github.com/NVIDIA/libglvnd/blob/master/src/EGL/icd_enumeration.md)\n",
- " NVIDIA_ICD_CONFIG_PATH = \"/usr/share/glvnd/egl_vendor.d/10_nvidia.json\"\n",
- " if not os.path.exists(NVIDIA_ICD_CONFIG_PATH):\n",
- " with open(NVIDIA_ICD_CONFIG_PATH, \"w\") as f:\n",
- " f.write(\n",
- " \"\"\"{\n",
- " \"file_format_version\" : \"1.0.0\",\n",
- " \"ICD\" : {\n",
- " \"library_path\" : \"libEGL_nvidia.so.0\"\n",
- " }\n",
- " }\n",
- " \"\"\"\n",
- " )\n",
- "\n",
  "%env MUJOCO_GL=egl\n",
  "\n",
  "try:\n",
@@ -4,11 +4,7 @@
  "cell_type": "markdown",
  "metadata": {},
  "source": [
- "# `JAXsim` Showcase: Parallel Simulation of a free-falling body\n",
- "\n",
- "<a target=\"_blank\" href=\"https://colab.research.google.com/github/ami-iit/jaxsim/blob/main/examples/Parallel_computing.ipynb\">\n",
- " <img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/>\n",
- "</a>"
+ "# `JAXsim` Showcase: Parallel Simulation of a free-falling body"
  ]
  },
  {
@@ -29,16 +25,10 @@
  "\n",
  "from IPython.display import HTML, clear_output, display\n",
  "\n",
- "IS_COLAB = \"google.colab\" in sys.modules\n",
- "\n",
  "# Install JAX and Gazebo\n",
- "if IS_COLAB:\n",
- " !{sys.executable} -m pip install -U -q jaxsim\n",
- " !apt -qq update && apt install -qq --no-install-recommends gazebo\n",
- " clear_output()\n",
- "else:\n",
- " # Set environment variable to avoid GPU out of memory errors\n",
- " %env XLA_PYTHON_CLIENT_MEM_PREALLOCATE=false\n",
+ "\n",
+ "# Set environment variable to avoid GPU out of memory errors\n",
+ "%env XLA_PYTHON_CLIENT_MEM_PREALLOCATE=false\n",
  "\n",
  "import time\n",
  "from typing import Dict, Tuple\n",
@@ -0,0 +1,29 @@
+ # JAXsim Notebook Examples
+
+ This folder includes a Jupyter Notebook demonstrating the practical usage of JAXsim for system simulations.
+
+ ### Examples
+
+ - [PD_controller](./PD_controller.ipynb) - A simple example demonstrating the use of JAXsim to simulate a PD controller with gravity compensation for a 2-DOF cartpole.
+ - [Parallel_computing](./Parallel_computing.ipynb) - An example demonstrating how to simulate vectorized models in parallel using JAXsim.
+
+ > [!TIP]
+ > Stay tuned for more examples!
+
+ ## Running the Examples
+
+ To execute these examples utilizing JAXsim with hardware acceleration, you can use [pixi](https://pixi.sh) to run the examples in a local environment:
+
+ 1. **Install `pixi`:**
+
+ As per the [official documentation](https://pixi.sh/#installation):
+
+ ```bash
+ curl -fsSL https://pixi.sh/install.sh | bash
+ ```
+
+ 2. **Run the Example Notebook:**
+
+ Use `pixi run examples` from the project source directory to execute the example notebook locally.
+
+ This command will automatically handle the installation of necessary dependencies and execute the examples within a self-contained environment
@@ -12,5 +12,5 @@ __version__: str
  __version_tuple__: VERSION_TUPLE
  version_tuple: VERSION_TUPLE

- __version__ = version = '0.2.1.dev38'
- __version_tuple__ = version_tuple = (0, 2, 1, 'dev38')
+ __version__ = version = '0.2.1.dev47'
+ __version_tuple__ = version_tuple = (0, 2, 1, 'dev47')
@@ -3,7 +3,6 @@ from typing import Sequence

  import jax
  import jax.numpy as jnp
- import numpy as np

  import jaxsim.api as js
  import jaxsim.typing as jtp
@@ -30,17 +29,9 @@ def name_to_idx(model: js.model.JaxSimModel, *, joint_name: str) -> jtp.Int:
          # Note: the index of the joint for RBDAs starts from 1, but
          # the index for accessing the right element starts from 0.
          # Therefore, there is a -1.
-         return (
-             jnp.array(
-                 np.argwhere(
-                     np.array(model.kin_dyn_parameters.joint_model.joint_names)
-                     == joint_name
-                 )
-                 - 1
-             )
-             .squeeze()
-             .astype(int)
-         )
+         return jnp.array(
+             model.kin_dyn_parameters.joint_model.joint_names.index(joint_name) - 1
+         ).squeeze()
      return jnp.array(-1).astype(int)

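The hunk above swaps the NumPy `argwhere` lookup for `tuple.index`. A minimal pure-Python sketch of the same lookup follows; the joint names and the membership guard are illustrative assumptions, not taken from the package, and the real function returns a JAX array rather than a plain `int`.

```python
from typing import Sequence

# Hypothetical joint names; index 0 mirrors the "world_to_base" entry
# that JointModel.build places in front of the real joints.
joint_names = ("world_to_base", "shoulder", "elbow")


def name_to_idx(names: Sequence[str], joint_name: str) -> int:
    """Return the 0-based joint index, or -1 if the name is unknown."""
    if joint_name in names:
        # RBDA indices start from 1, array indices start from 0, hence the -1.
        return names.index(joint_name) - 1
    return -1


print(name_to_idx(joint_names, "elbow"))
```

Since `joint_names` is a static tuple of strings, `tuple.index` avoids the round trip through a NumPy array that the removed code needed.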
 
@@ -382,21 +382,20 @@ class KynDynParameters(JaxsimDataclass):
          )

          # Compute the transforms and motion subspaces of the joints.
-         # TODO: understand how to use joint_indices to access joint_types, right now
-         # it fails when used within a JIT context.
-         pre_H_suc_and_S = [
-             supported_joint_motion(
-                 joint_type=self.joint_model.joint_types[i + 1],
-                 joint_position=jnp.array(s),
+         if self.number_of_joints() == 0:
+             pre_H_suc_J, S_J = jnp.empty((0, 4, 4)), jnp.empty((0, 6, 1))
+         else:
+             pre_H_suc_J, S_J = jax.vmap(supported_joint_motion)(
+                 jnp.array(self.joint_model.joint_types[1:]).astype(int),
+                 jnp.array(joint_positions),
+                 jnp.array(self.joint_model.joint_axis),
              )
-             for i, s in enumerate(jnp.array(joint_positions).astype(float))
-         ]

          # Extract the transforms and motion subspaces of the joints.
          # We stack the base transform W_H_B at index 0, and a dummy motion subspace
          # for either the fixed or free-floating joint connecting the world to the base.
-         pre_H_suc = jnp.stack([W_H_B] + [H for H, _ in pre_H_suc_and_S])
-         S = jnp.stack([jnp.vstack(jnp.zeros(6))] + [S for _, S in pre_H_suc_and_S])
+         pre_H_suc = jnp.vstack([W_H_B[jnp.newaxis, ...], pre_H_suc_J])
+         S = jnp.vstack([jnp.zeros((6, 1))[jnp.newaxis, ...], S_J])

          # Extract the successor-to-child fixed transforms.
          # Note that here we include also the index 0 since suc_H_child[0] stores the
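The change above replaces a Python list comprehension over the joints with a single `jax.vmap` call and then prepends the base entry at index 0. The sketch below reproduces only that batch-and-stack pattern; `joint_transform` is a hypothetical stand-in for `supported_joint_motion` (a toy prismatic joint), and all numeric values are made up.

```python
import jax
import jax.numpy as jnp


def joint_transform(joint_position, joint_axis):
    """Toy prismatic joint: translate by joint_position along joint_axis."""
    H = jnp.eye(4).at[:3, 3].set(joint_position * joint_axis)
    S = jnp.concatenate([joint_axis, jnp.zeros(3)]).reshape(6, 1)
    return H, S


positions = jnp.array([0.1, 0.2])
axes = jnp.array([[1.0, 0.0, 0.0], [0.0, 0.0, 1.0]])

# One batched call replaces the per-joint Python loop.
H_J, S_J = jax.vmap(joint_transform)(positions, axes)  # (2, 4, 4) and (2, 6, 1)

# Prepend the base transform and a zero motion subspace at index 0.
W_H_B = jnp.eye(4)
H = jnp.vstack([W_H_B[jnp.newaxis, ...], H_J])
S = jnp.vstack([jnp.zeros((6, 1))[jnp.newaxis, ...], S_J])

print(H.shape, S.shape)
```

Adding a leading axis with `jnp.newaxis` lets `jnp.vstack` concatenate the base entry with the batched results along axis 0. The diff special-cases a model with zero joints by building empty `(0, 4, 4)` and `(0, 6, 1)` arrays instead of calling `jax.vmap`.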
@@ -1,7 +1,5 @@
  from __future__ import annotations

- import functools
-
  import jax
  import jax.numpy as jnp
  import jax_dataclasses
@@ -9,12 +7,7 @@ import jaxlie
  from jax_dataclasses import Static

  import jaxsim.typing as jtp
- from jaxsim.parsers.descriptions import (
-     JointDescriptor,
-     JointGenericAxis,
-     JointType,
-     ModelDescription,
- )
+ from jaxsim.parsers.descriptions import JointGenericAxis, JointType, ModelDescription
  from jaxsim.parsers.kinematic_graph import KinematicGraphTransforms

  from .rotation import Rotation
@@ -46,7 +39,8 @@ class JointModel:

      joint_dofs: Static[tuple[int, ...]]
      joint_names: Static[tuple[str, ...]]
-     joint_types: Static[tuple[JointType | JointDescriptor, ...]]
+     joint_types: Static[tuple[JointType, ...]]
+     joint_axis: Static[tuple[JointGenericAxis, ...]]

      @staticmethod
      def build(description: ModelDescription) -> JointModel:
@@ -114,7 +108,8 @@ class JointModel:
              # Static attributes
              joint_dofs=tuple([base_dofs] + [int(1) for _ in ordered_joints]),
              joint_names=tuple(["world_to_base"] + [j.name for j in ordered_joints]),
-             joint_types=tuple([JointType.F] + [j.jtype for j in ordered_joints]),
+             joint_types=tuple([JointType.Fixed] + [j.jtype for j in ordered_joints]),
+             joint_axis=tuple([j.axis for j in ordered_joints]),
          )

      def parent_H_child(
@@ -204,8 +199,9 @@ class JointModel:
          """

          pre_H_suc, S = supported_joint_motion(
-             joint_type=self.joint_types[joint_index],
-             joint_position=joint_position,
+             self.joint_types[joint_index],
+             joint_position,
+             self.joint_axis[joint_index],
          )

          return pre_H_suc, S
@@ -226,59 +222,57 @@ class JointModel:
          return self.suc_H_i[joint_index]


- @functools.partial(jax.jit, static_argnames=["joint_type"])
+ @jax.jit
  def supported_joint_motion(
-     joint_type: JointType | JointDescriptor, joint_position: jtp.VectorLike
+     joint_type: JointType,
+     joint_position: jtp.VectorLike,
+     joint_axis: JointGenericAxis,
+     /,
  ) -> tuple[jtp.Matrix, jtp.Array]:
      """
      Compute the homogeneous transformation and motion subspace of a joint.

      Args:
          joint_type: The type of the joint.
+         joint_axis: The axis of rotation or translation of the joint.
          joint_position: The position of the joint.

      Returns:
          A tuple containing the homogeneous transformation and the motion subspace.
      """

-     if isinstance(joint_type, JointType):
-         type_enum = joint_type
-     elif isinstance(joint_type, JointDescriptor):
-         type_enum = joint_type.joint_type
-     else:
-         raise ValueError(joint_type)
-
      # Prepare the joint position
      s = jnp.array(joint_position).astype(float)

-     match type_enum:
-
-         case JointType.R:
-             joint_type: JointGenericAxis
+     def compute_F():
+         return jaxlie.SE3.identity(), jnp.zeros(shape=(6, 1))

-             pre_H_suc = jaxlie.SE3.from_rotation(
-                 rotation=jaxlie.SO3.from_matrix(
-                     Rotation.from_axis_angle(vector=s * joint_type.axis)
-                 )
+     def compute_R():
+         pre_H_suc = jaxlie.SE3.from_rotation(
+             rotation=jaxlie.SO3.from_matrix(
+                 Rotation.from_axis_angle(vector=s * joint_axis)
              )
+         )

-             S = jnp.vstack(jnp.hstack([jnp.zeros(3), joint_type.axis.squeeze()]))
-
-         case JointType.P:
-             joint_type: JointGenericAxis
-
-             pre_H_suc = jaxlie.SE3.from_rotation_and_translation(
-                 rotation=jaxlie.SO3.identity(),
-                 translation=jnp.array(s * joint_type.axis),
-             )
+         S = jnp.vstack(jnp.hstack([jnp.zeros(3), joint_axis.squeeze()]))
+         return pre_H_suc, S

-             S = jnp.vstack(jnp.hstack([joint_type.axis.squeeze(), jnp.zeros(3)]))
+     def compute_P():
+         pre_H_suc = jaxlie.SE3.from_rotation_and_translation(
+             rotation=jaxlie.SO3.identity(),
+             translation=jnp.array(s * joint_axis),
+         )

-         case JointType.F:
-             pre_H_suc = jaxlie.SE3.identity()
-             S = jnp.zeros(shape=(6, 1))
+         S = jnp.vstack(jnp.hstack([joint_axis.squeeze(), jnp.zeros(3)]))
+         return pre_H_suc, S

-         case _:
-             raise ValueError(joint_type)
+     pre_H_suc, S = jax.lax.switch(
+         index=joint_type,
+         branches=(
+             compute_F,  # JointType.Fixed
+             compute_R,  # JointType.Revolute
+             compute_P,  # JointType.Prismatic
+         ),
+     )

      return pre_H_suc.as_matrix(), S
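The rewrite above moves from a Python `match` on a static `joint_type` to `jax.lax.switch`, which picks a branch from an integer index that may be traced. The following is a minimal sketch of that dispatch, assuming the 0/1/2 encoding of the new `JointType`; the `motion_subspace` helper and its branch functions are hypothetical and compute only the motion subspace, not the transform.

```python
import jax
import jax.numpy as jnp

# Integer codes assumed to match JointType.Fixed / Revolute / Prismatic.
FIXED, REVOLUTE, PRISMATIC = 0, 1, 2


@jax.jit
def motion_subspace(joint_type, joint_axis):
    """Select the 6x1 motion subspace by integer joint type."""

    def fixed():
        return jnp.zeros((6, 1))

    def revolute():
        # Angular part carries the axis.
        return jnp.concatenate([jnp.zeros(3), joint_axis]).reshape(6, 1)

    def prismatic():
        # Linear part carries the axis.
        return jnp.concatenate([joint_axis, jnp.zeros(3)]).reshape(6, 1)

    # Every branch must return the same shape and dtype.
    return jax.lax.switch(joint_type, (fixed, revolute, prismatic))


z = jnp.array([0.0, 0.0, 1.0])
S_R = motion_subspace(REVOLUTE, z)
S_P = motion_subspace(PRISMATIC, z)
print(S_R.ravel(), S_P.ravel())
```

Because the index is an ordinary integer value rather than a static argument, one compiled function serves every joint type, which is what lets a caller `jax.vmap` over an array of types. Note that `jax.lax.switch` clamps an out-of-range index instead of raising, so an unsupported type would not trigger the `ValueError` that the removed `case _` branch did.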
@@ -1,4 +1,4 @@
  from .collision import BoxCollision, CollidablePoint, CollisionShape, SphereCollision
- from .joint import JointDescription, JointDescriptor, JointGenericAxis, JointType
+ from .joint import JointDescription, JointGenericAxis, JointType
  from .link import LinkDescription
  from .model import ModelDescription
@@ -1,8 +1,7 @@
  from __future__ import annotations

  import dataclasses
- import enum
- from typing import Tuple, Union
+ from typing import ClassVar, Tuple, Union

  import jax_dataclasses
  import numpy as np
@@ -14,39 +13,15 @@ from jaxsim.utils import JaxsimDataclass, Mutability
  from .link import LinkDescription


- @enum.unique
- class JointType(enum.IntEnum):
-     """
-     Type of supported joints.
-     """
-
-     @staticmethod
-     def _generate_next_value_(name, start, count, last_values):
-         # Start auto Enum value from 0 instead of 1
-         return count
-
-     #: Fixed joint.
-     F = enum.auto()
-
-     #: Revolute joint (1 DoF around axis).
-     R = enum.auto()
-
-     #: Prismatic joint (1 DoF along axis).
-     P = enum.auto()
-
-
- @jax_dataclasses.pytree_dataclass
- class JointDescriptor:
-     """
-     Base class for joint types requiring to store additional metadata.
-     """
-
-     #: The joint type.
-     joint_type: JointType
+ @dataclasses.dataclass(frozen=True)
+ class JointType:
+     Fixed: ClassVar[int] = 0
+     Revolute: ClassVar[int] = 1
+     Prismatic: ClassVar[int] = 2


  @jax_dataclasses.pytree_dataclass
- class JointGenericAxis(JointDescriptor):
+ class JointGenericAxis:
      """
      A joint requiring the specification of a 3D axis.
      """
@@ -55,7 +30,7 @@ class JointGenericAxis(JointDescriptor):
      axis: jtp.Vector

      def __hash__(self) -> int:
-         return hash((self.joint_type, tuple(np.array(self.axis).tolist())))
+         return hash((tuple(np.array(self.axis).tolist())))

      def __eq__(self, other: JointGenericAxis) -> bool:
          if not isinstance(other, JointGenericAxis):
@@ -73,7 +48,7 @@ class JointDescription(JaxsimDataclass):
          name (str): The name of the joint.
          axis (npt.NDArray): The axis of rotation or translation for the joint.
          pose (npt.NDArray): The pose transformation matrix of the joint.
-         jtype (Union[JointType, JointDescriptor]): The type of the joint.
+         jtype (JointType): The type of the joint.
          child (LinkDescription): The child link attached to the joint.
          parent (LinkDescription): The parent link attached to the joint.
          index (Optional[int]): An optional index for the joint.
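`JointType` changes earlier in this file from an `enum.IntEnum` to a frozen dataclass whose members are `ClassVar[int]` constants. The sketch below repeats the new definition from the diff and checks two consequences using only the standard library: the class has no dataclass fields, and its members are plain `int` values.

```python
import dataclasses
from typing import ClassVar


@dataclasses.dataclass(frozen=True)
class JointType:
    Fixed: ClassVar[int] = 0
    Revolute: ClassVar[int] = 1
    Prismatic: ClassVar[int] = 2


# ClassVar annotations are skipped by the dataclass machinery,
# so the class acts as a namespace of integer constants.
print(dataclasses.fields(JointType))

# Members are ordinary ints, not enum members.
print(type(JointType.Revolute), JointType.Revolute)
```

Plain integers can be converted directly to an integer array, which is how the new code feeds `joint_types` to `jax.vmap`. The trade-off is that enum conveniences such as iteration over members and name lookup are no longer available.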
@@ -89,7 +64,7 @@ class JointDescription(JaxsimDataclass):
      name: jax_dataclasses.Static[str]
      axis: npt.NDArray
      pose: npt.NDArray
-     jtype: jax_dataclasses.Static[Union[JointType, JointDescriptor]]
+     jtype: jax_dataclasses.Static[JointType]
      child: LinkDescription = dataclasses.dataclass(repr=False)
      parent: LinkDescription = dataclasses.dataclass(repr=False)

@@ -689,6 +689,7 @@ class KinematicGraphTransforms:
          # Compute the joint transform from the predecessor to the successor frame.
          pre_H_J = self.pre_H_suc(
              joint_type=joint.jtype,
+             joint_axis=joint.axis,
              joint_position=self._initial_joint_positions[joint.name],
          )

@@ -762,14 +763,15 @@ class KinematicGraphTransforms:

      @staticmethod
      def pre_H_suc(
-         joint_type: descriptions.JointType | descriptions.JointDescriptor,
+         joint_type: descriptions.JointType,
+         joint_axis: descriptions.JointGenericAxis,
          joint_position: float | None = None,
      ) -> npt.NDArray:

          import jaxsim.math

          return np.array(
-             jaxsim.math.supported_joint_motion(
-                 joint_type=joint_type, joint_position=joint_position
-             )[0]
+             jaxsim.math.supported_joint_motion(joint_type, joint_position, joint_axis)[
+                 0
+             ]
          )
@@ -352,7 +352,7 @@ def build_model_description(
          considered_joints=[
              j.name
              for j in sdf_data.joint_descriptions
-             if j.jtype is not descriptions.JointType.F
+             if j.jtype is not descriptions.JointType.Fixed
          ],
      )

@@ -61,7 +61,7 @@ def from_sdf_inertial(inertial: rod.Inertial) -> jtp.Matrix:

  def joint_to_joint_type(
      joint: rod.Joint,
- ) -> descriptions.JointType | descriptions.JointDescriptor:
+ ) -> descriptions.JointType:
      """
      Extract the joint type from an SDF joint.

@@ -76,7 +76,7 @@ def joint_to_joint_type(
      joint_type = joint.type

      if joint_type == "fixed":
-         return descriptions.JointType.F
+         return descriptions.JointType.Fixed

      if not (axis.xyz is not None and axis.xyz.xyz is not None):
          raise ValueError("Failed to read axis xyz data")
@@ -86,14 +86,10 @@ def joint_to_joint_type(
      axis_xyz = axis_xyz / np.linalg.norm(axis_xyz)

      if joint_type in {"revolute", "continuous"}:
-         return descriptions.JointGenericAxis(
-             joint_type=descriptions.JointType.R, axis=axis_xyz
-         )
+         return descriptions.JointType.Revolute

      if joint_type == "prismatic":
-         return descriptions.JointGenericAxis(
-             joint_type=descriptions.JointType.P, axis=axis_xyz
-         )
+         return descriptions.JointType.Prismatic

      raise ValueError("Joint not supported", axis_xyz, joint_type)

@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: jaxsim
- Version: 0.2.1.dev38
+ Version: 0.2.1.dev47
  Home-page: https://github.com/ami-iit/jaxsim
  Author: Diego Ferigo
  Author-email: diego.ferigo@iit.it
@@ -1,37 +0,0 @@
- # JAXsim Notebook Examples
-
- This folder includes a Jupyter Notebook demonstrating the practical usage of JAXsim for system simulations.
-
- ### Examples
-
- - [PD_controller](https://colab.research.google.com/github/ami-iit/jaxsim/blob/main/examples/PD_controller.ipynb) - A simple example demonstrating the use of JAXsim to simulate a PD controller with gravity compensation for a 2-DOF cartpole.
- - [Parallel_computing](https://colab.research.google.com/github/ami-iit/jaxsim/blob/main/examples/Parallel_computing.ipynb) - An example demonstrating how to simulate vectorized models in parallel using JAXsim.
-
- > [!TIP]
- > Stay tuned for more examples!
-
- ## Running the Examples
-
- To execute these examples utilizing JAXsim with hardware acceleration, there are a couple of options available:
-
- ### Option 1: Google Colab (Recommended)
-
- The simplest way to run the examples is by accessing the provided Google Colab notebook link mentioned above. This will enable you to execute the examples in a hosted environment.
-
- ### Option 2: Local Execution with `pixi`
-
- For local execution, follow these steps:
-
- 1. **Install `pixi`:**
-
- As per the [official documentation](https://pixi.sh/#installation):
-
- ```bash
- curl -fsSL https://pixi.sh/install.sh | bash
- ```
-
- 2. **Run the Example Notebook:**
-
- Use `pixi run examples` from the project source directory to execute the example notebook locally.
-
- This command will automatically handle the installation of necessary dependencies and execute the examples within a self-contained environment
File without changes