jaxsim 0.2.dev2__tar.gz → 0.2.dev8__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (105) hide show
  1. {jaxsim-0.2.dev2 → jaxsim-0.2.dev8}/PKG-INFO +3 -3
  2. {jaxsim-0.2.dev2 → jaxsim-0.2.dev8}/examples/PD_controller.ipynb +105 -76
  3. {jaxsim-0.2.dev2 → jaxsim-0.2.dev8}/examples/Parallel_computing.ipynb +3 -5
  4. {jaxsim-0.2.dev2 → jaxsim-0.2.dev8}/setup.cfg +1 -1
  5. {jaxsim-0.2.dev2 → jaxsim-0.2.dev8}/src/jaxsim/_version.py +2 -2
  6. {jaxsim-0.2.dev2 → jaxsim-0.2.dev8}/src/jaxsim/high_level/model.py +2 -2
  7. {jaxsim-0.2.dev2 → jaxsim-0.2.dev8}/src/jaxsim/parsers/rod/parser.py +52 -36
  8. {jaxsim-0.2.dev2 → jaxsim-0.2.dev8}/src/jaxsim/physics/model/physics_model.py +6 -6
  9. {jaxsim-0.2.dev2 → jaxsim-0.2.dev8}/src/jaxsim/simulation/simulator.py +5 -6
  10. {jaxsim-0.2.dev2 → jaxsim-0.2.dev8}/src/jaxsim/utils/oop.py +14 -10
  11. {jaxsim-0.2.dev2 → jaxsim-0.2.dev8}/src/jaxsim.egg-info/PKG-INFO +3 -3
  12. {jaxsim-0.2.dev2 → jaxsim-0.2.dev8}/src/jaxsim.egg-info/requires.txt +2 -2
  13. {jaxsim-0.2.dev2 → jaxsim-0.2.dev8}/.github/workflows/ci_cd.yml +0 -0
  14. {jaxsim-0.2.dev2 → jaxsim-0.2.dev8}/.github/workflows/read_the_docs.yml +0 -0
  15. {jaxsim-0.2.dev2 → jaxsim-0.2.dev8}/.github/workflows/style.yml +0 -0
  16. {jaxsim-0.2.dev2 → jaxsim-0.2.dev8}/.gitignore +0 -0
  17. {jaxsim-0.2.dev2 → jaxsim-0.2.dev8}/.readthedocs.yaml +0 -0
  18. {jaxsim-0.2.dev2 → jaxsim-0.2.dev8}/LICENSE +0 -0
  19. {jaxsim-0.2.dev2 → jaxsim-0.2.dev8}/README.md +0 -0
  20. {jaxsim-0.2.dev2 → jaxsim-0.2.dev8}/docs/Makefile +0 -0
  21. {jaxsim-0.2.dev2 → jaxsim-0.2.dev8}/docs/conf.py +0 -0
  22. {jaxsim-0.2.dev2 → jaxsim-0.2.dev8}/docs/guide/install.rst +0 -0
  23. {jaxsim-0.2.dev2 → jaxsim-0.2.dev8}/docs/index.rst +0 -0
  24. {jaxsim-0.2.dev2 → jaxsim-0.2.dev8}/docs/jaxsim_conda_env.yml +0 -0
  25. {jaxsim-0.2.dev2 → jaxsim-0.2.dev8}/docs/make.bat +0 -0
  26. {jaxsim-0.2.dev2 → jaxsim-0.2.dev8}/docs/modules/high_level.rst +0 -0
  27. {jaxsim-0.2.dev2 → jaxsim-0.2.dev8}/docs/modules/math.rst +0 -0
  28. {jaxsim-0.2.dev2 → jaxsim-0.2.dev8}/docs/modules/parsers.rst +0 -0
  29. {jaxsim-0.2.dev2 → jaxsim-0.2.dev8}/docs/modules/physics.rst +0 -0
  30. {jaxsim-0.2.dev2 → jaxsim-0.2.dev8}/docs/modules/simulation.rst +0 -0
  31. {jaxsim-0.2.dev2 → jaxsim-0.2.dev8}/docs/modules/typing.rst +0 -0
  32. {jaxsim-0.2.dev2 → jaxsim-0.2.dev8}/docs/modules/utils.rst +0 -0
  33. {jaxsim-0.2.dev2 → jaxsim-0.2.dev8}/examples/.gitattributes +0 -0
  34. {jaxsim-0.2.dev2 → jaxsim-0.2.dev8}/examples/.gitignore +0 -0
  35. {jaxsim-0.2.dev2 → jaxsim-0.2.dev8}/examples/README.md +0 -0
  36. {jaxsim-0.2.dev2 → jaxsim-0.2.dev8}/examples/assets/cartpole.urdf +0 -0
  37. {jaxsim-0.2.dev2 → jaxsim-0.2.dev8}/examples/pixi.lock +0 -0
  38. {jaxsim-0.2.dev2 → jaxsim-0.2.dev8}/examples/pixi.toml +0 -0
  39. {jaxsim-0.2.dev2 → jaxsim-0.2.dev8}/pyproject.toml +0 -0
  40. {jaxsim-0.2.dev2 → jaxsim-0.2.dev8}/setup.py +0 -0
  41. {jaxsim-0.2.dev2 → jaxsim-0.2.dev8}/src/jaxsim/__init__.py +0 -0
  42. {jaxsim-0.2.dev2 → jaxsim-0.2.dev8}/src/jaxsim/high_level/__init__.py +0 -0
  43. {jaxsim-0.2.dev2 → jaxsim-0.2.dev8}/src/jaxsim/high_level/common.py +0 -0
  44. {jaxsim-0.2.dev2 → jaxsim-0.2.dev8}/src/jaxsim/high_level/joint.py +0 -0
  45. {jaxsim-0.2.dev2 → jaxsim-0.2.dev8}/src/jaxsim/high_level/link.py +0 -0
  46. {jaxsim-0.2.dev2 → jaxsim-0.2.dev8}/src/jaxsim/logging.py +0 -0
  47. {jaxsim-0.2.dev2 → jaxsim-0.2.dev8}/src/jaxsim/math/__init__.py +0 -0
  48. {jaxsim-0.2.dev2 → jaxsim-0.2.dev8}/src/jaxsim/math/adjoint.py +0 -0
  49. {jaxsim-0.2.dev2 → jaxsim-0.2.dev8}/src/jaxsim/math/conv.py +0 -0
  50. {jaxsim-0.2.dev2 → jaxsim-0.2.dev8}/src/jaxsim/math/cross.py +0 -0
  51. {jaxsim-0.2.dev2 → jaxsim-0.2.dev8}/src/jaxsim/math/inertia.py +0 -0
  52. {jaxsim-0.2.dev2 → jaxsim-0.2.dev8}/src/jaxsim/math/joint.py +0 -0
  53. {jaxsim-0.2.dev2 → jaxsim-0.2.dev8}/src/jaxsim/math/plucker.py +0 -0
  54. {jaxsim-0.2.dev2 → jaxsim-0.2.dev8}/src/jaxsim/math/quaternion.py +0 -0
  55. {jaxsim-0.2.dev2 → jaxsim-0.2.dev8}/src/jaxsim/math/rotation.py +0 -0
  56. {jaxsim-0.2.dev2 → jaxsim-0.2.dev8}/src/jaxsim/math/skew.py +0 -0
  57. {jaxsim-0.2.dev2 → jaxsim-0.2.dev8}/src/jaxsim/parsers/__init__.py +0 -0
  58. {jaxsim-0.2.dev2 → jaxsim-0.2.dev8}/src/jaxsim/parsers/descriptions/__init__.py +0 -0
  59. {jaxsim-0.2.dev2 → jaxsim-0.2.dev8}/src/jaxsim/parsers/descriptions/collision.py +0 -0
  60. {jaxsim-0.2.dev2 → jaxsim-0.2.dev8}/src/jaxsim/parsers/descriptions/joint.py +0 -0
  61. {jaxsim-0.2.dev2 → jaxsim-0.2.dev8}/src/jaxsim/parsers/descriptions/link.py +0 -0
  62. {jaxsim-0.2.dev2 → jaxsim-0.2.dev8}/src/jaxsim/parsers/descriptions/model.py +0 -0
  63. {jaxsim-0.2.dev2 → jaxsim-0.2.dev8}/src/jaxsim/parsers/kinematic_graph.py +0 -0
  64. {jaxsim-0.2.dev2 → jaxsim-0.2.dev8}/src/jaxsim/parsers/rod/__init__.py +0 -0
  65. {jaxsim-0.2.dev2 → jaxsim-0.2.dev8}/src/jaxsim/parsers/rod/utils.py +0 -0
  66. {jaxsim-0.2.dev2 → jaxsim-0.2.dev8}/src/jaxsim/physics/__init__.py +0 -0
  67. {jaxsim-0.2.dev2 → jaxsim-0.2.dev8}/src/jaxsim/physics/algos/__init__.py +0 -0
  68. {jaxsim-0.2.dev2 → jaxsim-0.2.dev8}/src/jaxsim/physics/algos/aba.py +0 -0
  69. {jaxsim-0.2.dev2 → jaxsim-0.2.dev8}/src/jaxsim/physics/algos/aba_motors.py +0 -0
  70. {jaxsim-0.2.dev2 → jaxsim-0.2.dev8}/src/jaxsim/physics/algos/crba.py +0 -0
  71. {jaxsim-0.2.dev2 → jaxsim-0.2.dev8}/src/jaxsim/physics/algos/forward_kinematics.py +0 -0
  72. {jaxsim-0.2.dev2 → jaxsim-0.2.dev8}/src/jaxsim/physics/algos/jacobian.py +0 -0
  73. {jaxsim-0.2.dev2 → jaxsim-0.2.dev8}/src/jaxsim/physics/algos/rnea.py +0 -0
  74. {jaxsim-0.2.dev2 → jaxsim-0.2.dev8}/src/jaxsim/physics/algos/rnea_motors.py +0 -0
  75. {jaxsim-0.2.dev2 → jaxsim-0.2.dev8}/src/jaxsim/physics/algos/soft_contacts.py +0 -0
  76. {jaxsim-0.2.dev2 → jaxsim-0.2.dev8}/src/jaxsim/physics/algos/terrain.py +0 -0
  77. {jaxsim-0.2.dev2 → jaxsim-0.2.dev8}/src/jaxsim/physics/algos/utils.py +0 -0
  78. {jaxsim-0.2.dev2 → jaxsim-0.2.dev8}/src/jaxsim/physics/model/__init__.py +0 -0
  79. {jaxsim-0.2.dev2 → jaxsim-0.2.dev8}/src/jaxsim/physics/model/ground_contact.py +0 -0
  80. {jaxsim-0.2.dev2 → jaxsim-0.2.dev8}/src/jaxsim/physics/model/physics_model_state.py +0 -0
  81. {jaxsim-0.2.dev2 → jaxsim-0.2.dev8}/src/jaxsim/simulation/__init__.py +0 -0
  82. {jaxsim-0.2.dev2 → jaxsim-0.2.dev8}/src/jaxsim/simulation/integrators.py +0 -0
  83. {jaxsim-0.2.dev2 → jaxsim-0.2.dev8}/src/jaxsim/simulation/ode.py +0 -0
  84. {jaxsim-0.2.dev2 → jaxsim-0.2.dev8}/src/jaxsim/simulation/ode_data.py +0 -0
  85. {jaxsim-0.2.dev2 → jaxsim-0.2.dev8}/src/jaxsim/simulation/ode_integration.py +0 -0
  86. {jaxsim-0.2.dev2 → jaxsim-0.2.dev8}/src/jaxsim/simulation/simulator_callbacks.py +0 -0
  87. {jaxsim-0.2.dev2 → jaxsim-0.2.dev8}/src/jaxsim/simulation/utils.py +0 -0
  88. {jaxsim-0.2.dev2 → jaxsim-0.2.dev8}/src/jaxsim/sixd/__init__.py +0 -0
  89. {jaxsim-0.2.dev2 → jaxsim-0.2.dev8}/src/jaxsim/typing.py +0 -0
  90. {jaxsim-0.2.dev2 → jaxsim-0.2.dev8}/src/jaxsim/utils/__init__.py +0 -0
  91. {jaxsim-0.2.dev2 → jaxsim-0.2.dev8}/src/jaxsim/utils/jaxsim_dataclass.py +0 -0
  92. {jaxsim-0.2.dev2 → jaxsim-0.2.dev8}/src/jaxsim/utils/tracing.py +0 -0
  93. {jaxsim-0.2.dev2 → jaxsim-0.2.dev8}/src/jaxsim/utils/vmappable.py +0 -0
  94. {jaxsim-0.2.dev2 → jaxsim-0.2.dev8}/src/jaxsim.egg-info/SOURCES.txt +0 -0
  95. {jaxsim-0.2.dev2 → jaxsim-0.2.dev8}/src/jaxsim.egg-info/dependency_links.txt +0 -0
  96. {jaxsim-0.2.dev2 → jaxsim-0.2.dev8}/src/jaxsim.egg-info/not-zip-safe +0 -0
  97. {jaxsim-0.2.dev2 → jaxsim-0.2.dev8}/src/jaxsim.egg-info/top_level.txt +0 -0
  98. {jaxsim-0.2.dev2 → jaxsim-0.2.dev8}/tests/__init__.py +0 -0
  99. {jaxsim-0.2.dev2 → jaxsim-0.2.dev8}/tests/test_ad_physics.py +0 -0
  100. {jaxsim-0.2.dev2 → jaxsim-0.2.dev8}/tests/test_eom.py +0 -0
  101. {jaxsim-0.2.dev2 → jaxsim-0.2.dev8}/tests/test_forward_dynamics.py +0 -0
  102. {jaxsim-0.2.dev2 → jaxsim-0.2.dev8}/tests/test_jax_oop.py +0 -0
  103. {jaxsim-0.2.dev2 → jaxsim-0.2.dev8}/tests/utils_idyntree.py +0 -0
  104. {jaxsim-0.2.dev2 → jaxsim-0.2.dev8}/tests/utils_models.py +0 -0
  105. {jaxsim-0.2.dev2 → jaxsim-0.2.dev8}/tests/utils_rng.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: jaxsim
3
- Version: 0.2.dev2
3
+ Version: 0.2.dev8
4
4
  Summary: A physics engine in reduced coordinates implemented with JAX.
5
5
  Home-page: https://github.com/ami-iit/jaxsim
6
6
  Author: Diego Ferigo
@@ -38,7 +38,7 @@ Requires-Dist: jax_dataclasses>=1.4.0
38
38
  Requires-Dist: pptree
39
39
  Requires-Dist: rod
40
40
  Provides-Extra: style
41
- Requires-Dist: black; extra == "style"
41
+ Requires-Dist: black[jupyter]; extra == "style"
42
42
  Requires-Dist: isort; extra == "style"
43
43
  Provides-Extra: testing
44
44
  Requires-Dist: idyntree; extra == "testing"
@@ -47,7 +47,7 @@ Requires-Dist: pytest-forked; extra == "testing"
47
47
  Requires-Dist: pytest-icdiff; extra == "testing"
48
48
  Requires-Dist: robot-descriptions; extra == "testing"
49
49
  Provides-Extra: all
50
- Requires-Dist: black; extra == "all"
50
+ Requires-Dist: black[jupyter]; extra == "all"
51
51
  Requires-Dist: isort; extra == "all"
52
52
  Requires-Dist: idyntree; extra == "all"
53
53
  Requires-Dist: pytest>=6.0; extra == "all"
@@ -19,7 +19,7 @@
19
19
  "metadata": {},
20
20
  "outputs": [],
21
21
  "source": [
22
- "#@title Imports and setup\n",
22
+ "# @title Imports and setup\n",
23
23
  "from IPython.display import clear_output, HTML, display\n",
24
24
  "import sys\n",
25
25
  "\n",
@@ -52,7 +52,7 @@
52
52
  "metadata": {},
53
53
  "outputs": [],
54
54
  "source": [
55
- "#@title Fetch the URDF file\n",
55
+ "# @title Fetch the URDF file\n",
56
56
  "import requests\n",
57
57
  "\n",
58
58
  "url = \"https://raw.githubusercontent.com/ami-iit/jaxsim/main/examples/assets/cartpole.urdf\"\n",
@@ -117,7 +117,7 @@
117
117
  "metadata": {},
118
118
  "outputs": [],
119
119
  "source": [
120
- "#@title Set up MuJoCo renderer\n",
120
+ "# @title Set up MuJoCo renderer\n",
121
121
  "!{sys.executable} -m pip install -U -q mujoco\n",
122
122
  "!{sys.executable} -m pip install -q mediapy\n",
123
123
  "\n",
@@ -131,96 +131,106 @@
131
131
  "import subprocess\n",
132
132
  "\n",
133
133
  "if IS_COLAB:\n",
134
- " if subprocess.run('ffmpeg -version', shell=True).returncode:\n",
135
- " !command -v ffmpeg >/dev/null || (apt update && apt install -y ffmpeg)\n",
136
- " clear_output()\n",
134
+ " if subprocess.run(\"ffmpeg -version\", shell=True).returncode:\n",
135
+ " !command -v ffmpeg >/dev/null || (apt update && apt install -y ffmpeg)\n",
136
+ " clear_output()\n",
137
+ "\n",
138
+ " if subprocess.run(\"nvidia-smi\").returncode:\n",
139
+ " raise RuntimeError(\n",
140
+ " \"Cannot communicate with GPU. \"\n",
141
+ " \"Make sure you are using a GPU Colab runtime. \"\n",
142
+ " \"Go to the Runtime menu and select Choose runtime type.\"\n",
143
+ " )\n",
137
144
  "\n",
138
- " if subprocess.run('nvidia-smi').returncode:\n",
139
- " raise RuntimeError(\n",
140
- " 'Cannot communicate with GPU. '\n",
141
- " 'Make sure you are using a GPU Colab runtime. '\n",
142
- " 'Go to the Runtime menu and select Choose runtime type.')\n",
143
- "\n",
144
- "# Add an ICD config so that glvnd can pick up the Nvidia EGL driver.\n",
145
- "# This is usually installed as part of an Nvidia driver package, but the Colab\n",
146
- "# kernel doesn't install its driver via APT, and as a result the ICD is missing.\n",
147
- "# (https://github.com/NVIDIA/libglvnd/blob/master/src/EGL/icd_enumeration.md)\n",
148
- " NVIDIA_ICD_CONFIG_PATH = '/usr/share/glvnd/egl_vendor.d/10_nvidia.json'\n",
149
- " if not os.path.exists(NVIDIA_ICD_CONFIG_PATH):\n",
150
- " with open(NVIDIA_ICD_CONFIG_PATH, 'w') as f:\n",
151
- " f.write(\"\"\"{\n",
145
+ " # Add an ICD config so that glvnd can pick up the Nvidia EGL driver.\n",
146
+ " # This is usually installed as part of an Nvidia driver package, but the Colab\n",
147
+ " # kernel doesn't install its driver via APT, and as a result the ICD is missing.\n",
148
+ " # (https://github.com/NVIDIA/libglvnd/blob/master/src/EGL/icd_enumeration.md)\n",
149
+ " NVIDIA_ICD_CONFIG_PATH = \"/usr/share/glvnd/egl_vendor.d/10_nvidia.json\"\n",
150
+ " if not os.path.exists(NVIDIA_ICD_CONFIG_PATH):\n",
151
+ " with open(NVIDIA_ICD_CONFIG_PATH, \"w\") as f:\n",
152
+ " f.write(\n",
153
+ " \"\"\"{\n",
152
154
  " \"file_format_version\" : \"1.0.0\",\n",
153
155
  " \"ICD\" : {\n",
154
156
  " \"library_path\" : \"libEGL_nvidia.so.0\"\n",
155
157
  " }\n",
156
158
  " }\n",
157
- " \"\"\")\n",
159
+ " \"\"\"\n",
160
+ " )\n",
158
161
  "\n",
159
162
  "%env MUJOCO_GL=egl\n",
160
163
  "\n",
161
164
  "try:\n",
162
- " import mujoco\n",
165
+ " import mujoco\n",
163
166
  "except Exception as e:\n",
164
- " raise e from RuntimeError(\n",
165
- " 'Something went wrong during installation. Check the shell output above '\n",
166
- " 'for more information.\\n'\n",
167
- " 'If using a hosted Colab runtime, make sure you enable GPU acceleration '\n",
168
- " 'by going to the Runtime menu and selecting \"Choose runtime type\".')\n",
167
+ " raise e from RuntimeError(\n",
168
+ " \"Something went wrong during installation. Check the shell output above \"\n",
169
+ " \"for more information.\\n\"\n",
170
+ " \"If using a hosted Colab runtime, make sure you enable GPU acceleration \"\n",
171
+ " 'by going to the Runtime menu and selecting \"Choose runtime type\".'\n",
172
+ " )\n",
173
+ "\n",
169
174
  "\n",
170
175
  "def load_mujoco_model_with_camera(xml_string, camera_pos, camera_xyaxes):\n",
171
- " def to_mjcf_string(list_to_str):\n",
172
- " return ' '.join(map(str, list_to_str))\n",
173
- "\n",
174
- " mj_model_raw = mujoco.MjModel.from_xml_string(model_urdf_string)\n",
175
- " path_temp_xml = tempfile.NamedTemporaryFile(mode=\"w+\")\n",
176
- " mujoco.mj_saveLastXML(path_temp_xml.name, mj_model_raw)\n",
177
- " # Add camera in mujoco model\n",
178
- " tree = ET.parse(path_temp_xml)\n",
179
- " for elem in tree.getroot().iter(\"worldbody\"):\n",
180
- " worldbody_elem = elem\n",
181
- " camera_elem = ET.Element(\"camera\")\n",
182
- " # Set attributes \n",
183
- " camera_elem.set(\"name\", \"side\")\n",
184
- " camera_elem.set(\"pos\", to_mjcf_string(camera_pos))\n",
185
- " camera_elem.set(\"xyaxes\", to_mjcf_string(camera_xyaxes))\n",
186
- " camera_elem.set(\"mode\", \"fixed\")\n",
187
- " worldbody_elem.append(camera_elem)\n",
188
- "\n",
189
- " # Save new model\n",
190
- " mujoco_xml_with_camera = ET.tostring(tree.getroot(), encoding=\"unicode\")\n",
191
- " mj_model = mujoco.MjModel.from_xml_string(mujoco_xml_with_camera)\n",
192
- " return mj_model\n",
176
+ " def to_mjcf_string(list_to_str):\n",
177
+ " return \" \".join(map(str, list_to_str))\n",
178
+ "\n",
179
+ " mj_model_raw = mujoco.MjModel.from_xml_string(model_urdf_string)\n",
180
+ " path_temp_xml = tempfile.NamedTemporaryFile(mode=\"w+\")\n",
181
+ " mujoco.mj_saveLastXML(path_temp_xml.name, mj_model_raw)\n",
182
+ " # Add camera in mujoco model\n",
183
+ " tree = ET.parse(path_temp_xml)\n",
184
+ " for elem in tree.getroot().iter(\"worldbody\"):\n",
185
+ " worldbody_elem = elem\n",
186
+ " camera_elem = ET.Element(\"camera\")\n",
187
+ " # Set attributes\n",
188
+ " camera_elem.set(\"name\", \"side\")\n",
189
+ " camera_elem.set(\"pos\", to_mjcf_string(camera_pos))\n",
190
+ " camera_elem.set(\"xyaxes\", to_mjcf_string(camera_xyaxes))\n",
191
+ " camera_elem.set(\"mode\", \"fixed\")\n",
192
+ " worldbody_elem.append(camera_elem)\n",
193
+ "\n",
194
+ " # Save new model\n",
195
+ " mujoco_xml_with_camera = ET.tostring(tree.getroot(), encoding=\"unicode\")\n",
196
+ " mj_model = mujoco.MjModel.from_xml_string(mujoco_xml_with_camera)\n",
197
+ " return mj_model\n",
193
198
  "\n",
194
199
  "\n",
195
200
  "def from_jaxsim_to_mujoco_pos(jaxsim_jointpos, mjmodel, jaxsimmodel):\n",
196
- " mujocoqposaddr2jaxindex = {}\n",
197
- " for jaxjnt in jaxsimmodel.joints():\n",
198
- " jntname = jaxjnt.name()\n",
199
- " mujocoqposaddr2jaxindex[mjmodel.joint(jntname).qposadr[0]] = jaxjnt.index()-1\n",
201
+ " mujocoqposaddr2jaxindex = {}\n",
202
+ " for jaxjnt in jaxsimmodel.joints():\n",
203
+ " jntname = jaxjnt.name()\n",
204
+ " mujocoqposaddr2jaxindex[mjmodel.joint(jntname).qposadr[0]] = jaxjnt.index() - 1\n",
200
205
  "\n",
201
- " mujoco_jointpos = jaxsim_jointpos\n",
202
- " for i in range(0, len(mujoco_jointpos)):\n",
203
- " mujoco_jointpos[i] = jaxsim_jointpos[mujocoqposaddr2jaxindex[i]]\n",
206
+ " mujoco_jointpos = jaxsim_jointpos\n",
207
+ " for i in range(0, len(mujoco_jointpos)):\n",
208
+ " mujoco_jointpos[i] = jaxsim_jointpos[mujocoqposaddr2jaxindex[i]]\n",
209
+ "\n",
210
+ " return mujoco_jointpos\n",
204
211
  "\n",
205
- " return mujoco_jointpos\n",
206
- " \n",
207
212
  "\n",
208
213
  "# To get a good camera location, you can use \"Copy camera\" functionality in MuJoCo GUI\n",
209
- "mj_model = load_mujoco_model_with_camera(model_urdf_string, [3.954, 3.533, 2.343], [-0.594, 0.804, -0.000, -0.163, -0.120, 0.979])\n",
214
+ "mj_model = load_mujoco_model_with_camera(\n",
215
+ " model_urdf_string,\n",
216
+ " [3.954, 3.533, 2.343],\n",
217
+ " [-0.594, 0.804, -0.000, -0.163, -0.120, 0.979],\n",
218
+ ")\n",
210
219
  "renderer = mujoco.Renderer(mj_model, height=480, width=640)\n",
211
220
  "\n",
221
+ "\n",
212
222
  "def get_image(camera, mujocojointpos) -> np.ndarray:\n",
213
- " \"\"\"Renders the environment state.\"\"\"\n",
214
- " # Copy joint data in mjdata state\n",
215
- " d = mujoco.MjData(mj_model)\n",
216
- " d.qpos = mujocojointpos\n",
217
- " \n",
218
- " # Forward kinematics\n",
219
- " mujoco.mj_forward(mj_model, d)\n",
220
- " \n",
221
- " # use the mjData object to update the renderer\n",
222
- " renderer.update_scene(d, camera=camera)\n",
223
- " return renderer.render()"
223
+ " \"\"\"Renders the environment state.\"\"\"\n",
224
+ " # Copy joint data in mjdata state\n",
225
+ " d = mujoco.MjData(mj_model)\n",
226
+ " d.qpos = mujocojointpos\n",
227
+ "\n",
228
+ " # Forward kinematics\n",
229
+ " mujoco.mj_forward(mj_model, d)\n",
230
+ "\n",
231
+ " # use the mjData object to update the renderer\n",
232
+ " renderer.update_scene(d, camera=camera)\n",
233
+ " return renderer.render()"
224
234
  ]
225
235
  },
226
236
  {
@@ -241,10 +251,19 @@
241
251
  "sim_images = []\n",
242
252
  "timestep = 0.01\n",
243
253
  "for _ in range(300):\n",
244
- " sim_images.append(get_image(\"side\", from_jaxsim_to_mujoco_pos(np.array(model.joint_positions()), mj_model, model)))\n",
245
- " model.integrate(t0=0.0, tf=timestep, integrator_type=IntegratorType.EulerSemiImplicit)\n",
254
+ " sim_images.append(\n",
255
+ " get_image(\n",
256
+ " \"side\",\n",
257
+ " from_jaxsim_to_mujoco_pos(\n",
258
+ " np.array(model.joint_positions()), mj_model, model\n",
259
+ " ),\n",
260
+ " )\n",
261
+ " )\n",
262
+ " model.integrate(\n",
263
+ " t0=0.0, tf=timestep, integrator_type=IntegratorType.EulerSemiImplicit\n",
264
+ " )\n",
246
265
  "\n",
247
- "media.show_video(sim_images, fps=1/timestep)"
266
+ "media.show_video(sim_images, fps=1 / timestep)"
248
267
  ]
249
268
  },
250
269
  {
@@ -274,6 +293,7 @@
274
293
  "# Compute the gravity compensation term\n",
275
294
  "H = model.free_floating_bias_forces()[6:]\n",
276
295
  "\n",
296
+ "\n",
277
297
  "def pd_controller(\n",
278
298
  " q: jax.Array, q_d: jax.Array, q_dot: jax.Array, q_dot_d: jax.Array\n",
279
299
  ") -> jax.Array:\n",
@@ -297,7 +317,14 @@
297
317
  "timestep = 0.01\n",
298
318
  "\n",
299
319
  "for _ in range(300):\n",
300
- " sim_images.append(get_image(\"side\", from_jaxsim_to_mujoco_pos(np.array(model.joint_positions()), mj_model, model)))\n",
320
+ " sim_images.append(\n",
321
+ " get_image(\n",
322
+ " \"side\",\n",
323
+ " from_jaxsim_to_mujoco_pos(\n",
324
+ " np.array(model.joint_positions()), mj_model, model\n",
325
+ " ),\n",
326
+ " )\n",
327
+ " )\n",
301
328
  " model.set_joint_generalized_force_targets(\n",
302
329
  " forces=pd_controller(\n",
303
330
  " q=model.joint_positions(),\n",
@@ -306,9 +333,11 @@
306
333
  " q_dot_d=jnp.array([0.0, 0.0]),\n",
307
334
  " )\n",
308
335
  " )\n",
309
- " model.integrate(t0=0.0, tf=timestep, integrator_type=IntegratorType.EulerSemiImplicit)\n",
336
+ " model.integrate(\n",
337
+ " t0=0.0, tf=timestep, integrator_type=IntegratorType.EulerSemiImplicit\n",
338
+ " )\n",
310
339
  "\n",
311
- "media.show_video(sim_images, fps=1/timestep)"
340
+ "media.show_video(sim_images, fps=1 / timestep)"
312
341
  ]
313
342
  }
314
343
  ],
@@ -24,7 +24,7 @@
24
24
  "metadata": {},
25
25
  "outputs": [],
26
26
  "source": [
27
- "#@title Imports and setup\n",
27
+ "# @title Imports and setup\n",
28
28
  "import sys\n",
29
29
  "\n",
30
30
  "from IPython.display import HTML, clear_output, display\n",
@@ -71,7 +71,7 @@
71
71
  "metadata": {},
72
72
  "outputs": [],
73
73
  "source": [
74
- "#@title Create a sphere model\n",
74
+ "# @title Create a sphere model\n",
75
75
  "model_sdf_string = rod.Sdf(\n",
76
76
  " version=\"1.7\",\n",
77
77
  " model=SphereBuilder(radius=0.10, mass=1.0, name=\"sphere\")\n",
@@ -119,9 +119,7 @@
119
119
  "\n",
120
120
  "\n",
121
121
  "# Add model to simulator\n",
122
- "model = simulator.insert_model_from_description(\n",
123
- " model_description=model_sdf_string\n",
124
- ")"
122
+ "model = simulator.insert_model_from_description(model_description=model_sdf_string)"
125
123
  ]
126
124
  },
127
125
  {
@@ -62,7 +62,7 @@ where = src
62
62
 
63
63
  [options.extras_require]
64
64
  style =
65
- black
65
+ black[jupyter]
66
66
  isort
67
67
  testing =
68
68
  idyntree
@@ -12,5 +12,5 @@ __version__: str
12
12
  __version_tuple__: VERSION_TUPLE
13
13
  version_tuple: VERSION_TUPLE
14
14
 
15
- __version__ = version = '0.2.dev2'
16
- __version_tuple__ = version_tuple = (0, 2, 'dev2')
15
+ __version__ = version = '0.2.dev8'
16
+ __version_tuple__ = version_tuple = (0, 2, 'dev8')
@@ -385,13 +385,13 @@ class Model(Vmappable):
385
385
  def link_names(self) -> tuple[str, ...]:
386
386
  """"""
387
387
 
388
- return tuple(l.name() for l in self.links())
388
+ return tuple(self.physics_model.description.links_dict.keys())
389
389
 
390
390
  @functools.partial(oop.jax_tf.method_ro, jit=False, vmap=False)
391
391
  def joint_names(self) -> tuple[str, ...]:
392
392
  """"""
393
393
 
394
- return tuple(j.name() for j in self.joints())
394
+ return tuple(self.physics_model.description.joints_dict.keys())
395
395
 
396
396
  @functools.partial(oop.jax_tf.method_ro, jit=False, vmap=False)
397
397
  def links(
@@ -135,11 +135,13 @@ def extract_model_data(
135
135
  parent=world_link,
136
136
  child=links_dict[j.child],
137
137
  jtype=utils.axis_to_jtype(axis=j.axis, type=j.type),
138
- axis=np.array(j.axis.xyz.xyz)
139
- if j.axis is not None
140
- and j.axis.xyz is not None
141
- and j.axis.xyz.xyz is not None
142
- else None,
138
+ axis=(
139
+ np.array(j.axis.xyz.xyz)
140
+ if j.axis is not None
141
+ and j.axis.xyz is not None
142
+ and j.axis.xyz.xyz is not None
143
+ else None
144
+ ),
143
145
  pose=j.pose.transform() if j.pose is not None else np.eye(4),
144
146
  )
145
147
  for j in sdf_model.joints()
@@ -200,41 +202,55 @@ def extract_model_data(
200
202
  parent=links_dict[j.parent],
201
203
  child=links_dict[j.child],
202
204
  jtype=utils.axis_to_jtype(axis=j.axis, type=j.type),
203
- axis=np.array(j.axis.xyz.xyz)
204
- if j.axis is not None
205
- and j.axis.xyz is not None
206
- and j.axis.xyz.xyz is not None
207
- else None,
205
+ axis=(
206
+ np.array(j.axis.xyz.xyz)
207
+ if j.axis is not None
208
+ and j.axis.xyz is not None
209
+ and j.axis.xyz.xyz is not None
210
+ else None
211
+ ),
208
212
  pose=j.pose.transform() if j.pose is not None else np.eye(4),
209
213
  initial_position=0.0,
210
214
  position_limit=(
211
- float(j.axis.limit.lower)
212
- if j.axis is not None and j.axis.limit is not None
213
- else np.finfo(float).min,
214
- float(j.axis.limit.upper)
215
- if j.axis is not None and j.axis.limit is not None
216
- else np.finfo(float).max,
215
+ (
216
+ float(j.axis.limit.lower)
217
+ if j.axis is not None and j.axis.limit is not None
218
+ else np.finfo(float).min
219
+ ),
220
+ (
221
+ float(j.axis.limit.upper)
222
+ if j.axis is not None and j.axis.limit is not None
223
+ else np.finfo(float).max
224
+ ),
225
+ ),
226
+ friction_static=(
227
+ j.axis.dynamics.friction
228
+ if j.axis is not None
229
+ and j.axis.dynamics is not None
230
+ and j.axis.dynamics.friction is not None
231
+ else 0.0
232
+ ),
233
+ friction_viscous=(
234
+ j.axis.dynamics.damping
235
+ if j.axis is not None
236
+ and j.axis.dynamics is not None
237
+ and j.axis.dynamics.damping is not None
238
+ else 0.0
239
+ ),
240
+ position_limit_damper=(
241
+ j.axis.limit.dissipation
242
+ if j.axis is not None
243
+ and j.axis.limit is not None
244
+ and j.axis.limit.dissipation is not None
245
+ else 0.0
246
+ ),
247
+ position_limit_spring=(
248
+ j.axis.limit.stiffness
249
+ if j.axis is not None
250
+ and j.axis.limit is not None
251
+ and j.axis.limit.stiffness is not None
252
+ else 0.0
217
253
  ),
218
- friction_static=j.axis.dynamics.friction
219
- if j.axis is not None
220
- and j.axis.dynamics is not None
221
- and j.axis.dynamics.friction is not None
222
- else 0.0,
223
- friction_viscous=j.axis.dynamics.damping
224
- if j.axis is not None
225
- and j.axis.dynamics is not None
226
- and j.axis.dynamics.damping is not None
227
- else 0.0,
228
- position_limit_damper=j.axis.limit.dissipation
229
- if j.axis is not None
230
- and j.axis.limit is not None
231
- and j.axis.limit.dissipation is not None
232
- else 0.0,
233
- position_limit_spring=j.axis.limit.stiffness
234
- if j.axis is not None
235
- and j.axis.limit is not None
236
- and j.axis.limit.stiffness is not None
237
- else 0.0,
238
254
  )
239
255
  for j in sdf_model.joints()
240
256
  if j.type in {"revolute", "prismatic", "fixed"}
@@ -45,14 +45,14 @@ class PhysicsModel(JaxsimDataclass):
45
45
  )
46
46
  is_floating_base: Static[bool] = dataclasses.field(default=False)
47
47
  gc: GroundContact = dataclasses.field(default_factory=lambda: GroundContact())
48
- description: Static[
49
- jaxsim.parsers.descriptions.model.ModelDescription
50
- ] = dataclasses.field(default=None)
48
+ description: Static[jaxsim.parsers.descriptions.model.ModelDescription] = (
49
+ dataclasses.field(default=None)
50
+ )
51
51
 
52
52
  _parent_array_dict: Static[Dict[int, int]] = dataclasses.field(default_factory=dict)
53
- _jtype_dict: Static[
54
- Dict[int, Union[JointType, JointDescriptor]]
55
- ] = dataclasses.field(default_factory=dict)
53
+ _jtype_dict: Static[Dict[int, Union[JointType, JointDescriptor]]] = (
54
+ dataclasses.field(default_factory=dict)
55
+ )
56
56
  _tree_transforms_dict: Dict[int, jtp.Matrix] = dataclasses.field(
57
57
  default_factory=dict
58
58
  )
@@ -432,8 +432,9 @@ class JaxSim(Vmappable):
432
432
  def step_over_horizon(
433
433
  self,
434
434
  horizon_steps: jtp.Int,
435
- callback_handler: Union["scb.SimulatorCallback", "scb.CallbackHandler"]
436
- | None = None,
435
+ callback_handler: (
436
+ Union["scb.SimulatorCallback", "scb.CallbackHandler"] | None
437
+ ) = None,
437
438
  clear_inputs: jtp.Bool = False,
438
439
  ) -> Union[
439
440
  "JaxSim",
@@ -459,10 +460,8 @@ class JaxSim(Vmappable):
459
460
  sim = self.copy().mutable(validate=True)
460
461
 
461
462
  # Helper to get callbacks from the handler
462
- get_cb = (
463
- lambda h, cb_name: getattr(h, cb_name)
464
- if h is not None and hasattr(h, cb_name)
465
- else None
463
+ get_cb = lambda h, cb_name: (
464
+ getattr(h, cb_name) if h is not None and hasattr(h, cb_name) else None
466
465
  )
467
466
 
468
467
  # Get the callbacks
@@ -3,16 +3,20 @@ import dataclasses
3
3
  import functools
4
4
  import inspect
5
5
  import os
6
- from typing import Any, Callable, Generator
6
+ from typing import Any, Callable, Generator, TypeVar
7
7
 
8
8
  import jax
9
9
  import jax.flatten_util
10
+ from typing_extensions import ParamSpec
10
11
 
11
12
  from jaxsim import logging
12
13
  from jaxsim.utils import tracing
13
14
 
14
15
  from . import Mutability, Vmappable
15
16
 
17
+ _P = ParamSpec("_P")
18
+ _R = TypeVar("_R")
19
+
16
20
 
17
21
  class jax_tf:
18
22
  """
@@ -27,13 +31,13 @@ class jax_tf:
27
31
 
28
32
  @staticmethod
29
33
  def method_ro(
30
- fn: Callable,
34
+ fn: Callable[_P, _R],
31
35
  jit: bool = True,
32
36
  static_argnames: tuple[str, ...] | list[str] = (),
33
37
  vmap: bool | None = None,
34
38
  vmap_in_axes: tuple[int, ...] | int | None = None,
35
39
  vmap_out_axes: tuple[int, ...] | int | None = None,
36
- ):
40
+ ) -> Callable[_P, _R]:
37
41
  """
38
42
  Decorator for r/o methods of classes inheriting from Vmappable.
39
43
  """
@@ -51,14 +55,14 @@ class jax_tf:
51
55
 
52
56
  @staticmethod
53
57
  def method_rw(
54
- fn: Callable,
58
+ fn: Callable[_P, _R],
55
59
  validate: bool = True,
56
60
  jit: bool = True,
57
61
  static_argnames: tuple[str, ...] | list[str] = (),
58
62
  vmap: bool | None = None,
59
63
  vmap_in_axes: tuple[int, ...] | int | None = None,
60
64
  vmap_out_axes: tuple[int, ...] | int | None = None,
61
- ):
65
+ ) -> Callable[_P, _R]:
62
66
  """
63
67
  Decorator for r/w methods of classes inheriting from Vmappable.
64
68
  """
@@ -76,7 +80,7 @@ class jax_tf:
76
80
 
77
81
  @staticmethod
78
82
  def method(
79
- fn: Callable,
83
+ fn: Callable[_P, _R],
80
84
  read_only: bool = True,
81
85
  validate: bool = True,
82
86
  jit_enabled: bool = True,
@@ -109,7 +113,7 @@ class jax_tf:
109
113
  """
110
114
 
111
115
  @functools.wraps(fn)
112
- def wrapper(*args, **kwargs):
116
+ def wrapper(*args: _P.args, **kwargs: _P.kwargs):
113
117
  """The wrapper function that is returned by this decorator."""
114
118
 
115
119
  # Methods of classes inheriting from Vmappable decorated by this wrapper
@@ -202,9 +206,9 @@ class jax_tf:
202
206
  mutability_dict = {
203
207
  Mutability.MUTABLE_NO_VALIDATION: Mutability.MUTABLE_NO_VALIDATION,
204
208
  Mutability.MUTABLE: Mutability.MUTABLE,
205
- Mutability.FROZEN: Mutability.MUTABLE
206
- if validate
207
- else Mutability.MUTABLE_NO_VALIDATION,
209
+ Mutability.FROZEN: (
210
+ Mutability.MUTABLE if validate else Mutability.MUTABLE_NO_VALIDATION
211
+ ),
208
212
  }
209
213
 
210
214
  # We need to replace all the dynamic leafs of the original instance with those
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: jaxsim
3
- Version: 0.2.dev2
3
+ Version: 0.2.dev8
4
4
  Summary: A physics engine in reduced coordinates implemented with JAX.
5
5
  Home-page: https://github.com/ami-iit/jaxsim
6
6
  Author: Diego Ferigo
@@ -38,7 +38,7 @@ Requires-Dist: jax_dataclasses>=1.4.0
38
38
  Requires-Dist: pptree
39
39
  Requires-Dist: rod
40
40
  Provides-Extra: style
41
- Requires-Dist: black; extra == "style"
41
+ Requires-Dist: black[jupyter]; extra == "style"
42
42
  Requires-Dist: isort; extra == "style"
43
43
  Provides-Extra: testing
44
44
  Requires-Dist: idyntree; extra == "testing"
@@ -47,7 +47,7 @@ Requires-Dist: pytest-forked; extra == "testing"
47
47
  Requires-Dist: pytest-icdiff; extra == "testing"
48
48
  Requires-Dist: robot-descriptions; extra == "testing"
49
49
  Provides-Extra: all
50
- Requires-Dist: black; extra == "all"
50
+ Requires-Dist: black[jupyter]; extra == "all"
51
51
  Requires-Dist: isort; extra == "all"
52
52
  Requires-Dist: idyntree; extra == "all"
53
53
  Requires-Dist: pytest>=6.0; extra == "all"
@@ -7,7 +7,7 @@ pptree
7
7
  rod
8
8
 
9
9
  [all]
10
- black
10
+ black[jupyter]
11
11
  isort
12
12
  idyntree
13
13
  pytest>=6.0
@@ -16,7 +16,7 @@ pytest-icdiff
16
16
  robot-descriptions
17
17
 
18
18
  [style]
19
- black
19
+ black[jupyter]
20
20
  isort
21
21
 
22
22
  [testing]
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes