jaxsim 0.2.dev428__tar.gz → 0.2.dev435__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (115)
  1. {jaxsim-0.2.dev428 → jaxsim-0.2.dev435}/PKG-INFO +1 -1
  2. {jaxsim-0.2.dev428 → jaxsim-0.2.dev435}/examples/PD_controller.ipynb +92 -112
  3. {jaxsim-0.2.dev428 → jaxsim-0.2.dev435}/examples/Parallel_computing.ipynb +73 -60
  4. {jaxsim-0.2.dev428 → jaxsim-0.2.dev435}/src/jaxsim/_version.py +2 -2
  5. {jaxsim-0.2.dev428 → jaxsim-0.2.dev435}/src/jaxsim/mujoco/loaders.py +45 -0
  6. {jaxsim-0.2.dev428 → jaxsim-0.2.dev435}/src/jaxsim.egg-info/PKG-INFO +1 -1
  7. {jaxsim-0.2.dev428 → jaxsim-0.2.dev435}/.devcontainer/Dockerfile +0 -0
  8. {jaxsim-0.2.dev428 → jaxsim-0.2.dev435}/.devcontainer/devcontainer.json +0 -0
  9. {jaxsim-0.2.dev428 → jaxsim-0.2.dev435}/.github/CODEOWNERS +0 -0
  10. {jaxsim-0.2.dev428 → jaxsim-0.2.dev435}/.github/workflows/ci_cd.yml +0 -0
  11. {jaxsim-0.2.dev428 → jaxsim-0.2.dev435}/.github/workflows/read_the_docs.yml +0 -0
  12. {jaxsim-0.2.dev428 → jaxsim-0.2.dev435}/.github/workflows/style.yml +0 -0
  13. {jaxsim-0.2.dev428 → jaxsim-0.2.dev435}/.gitignore +0 -0
  14. {jaxsim-0.2.dev428 → jaxsim-0.2.dev435}/.pre-commit-config.yaml +0 -0
  15. {jaxsim-0.2.dev428 → jaxsim-0.2.dev435}/.readthedocs.yaml +0 -0
  16. {jaxsim-0.2.dev428 → jaxsim-0.2.dev435}/CONTRIBUTING.md +0 -0
  17. {jaxsim-0.2.dev428 → jaxsim-0.2.dev435}/LICENSE +0 -0
  18. {jaxsim-0.2.dev428 → jaxsim-0.2.dev435}/README.md +0 -0
  19. {jaxsim-0.2.dev428 → jaxsim-0.2.dev435}/docs/Makefile +0 -0
  20. {jaxsim-0.2.dev428 → jaxsim-0.2.dev435}/docs/conf.py +0 -0
  21. {jaxsim-0.2.dev428 → jaxsim-0.2.dev435}/docs/guide/install.rst +0 -0
  22. {jaxsim-0.2.dev428 → jaxsim-0.2.dev435}/docs/index.rst +0 -0
  23. {jaxsim-0.2.dev428 → jaxsim-0.2.dev435}/docs/make.bat +0 -0
  24. {jaxsim-0.2.dev428 → jaxsim-0.2.dev435}/docs/modules/api.rst +0 -0
  25. {jaxsim-0.2.dev428 → jaxsim-0.2.dev435}/docs/modules/index.rst +0 -0
  26. {jaxsim-0.2.dev428 → jaxsim-0.2.dev435}/docs/modules/integrators.rst +0 -0
  27. {jaxsim-0.2.dev428 → jaxsim-0.2.dev435}/docs/modules/math.rst +0 -0
  28. {jaxsim-0.2.dev428 → jaxsim-0.2.dev435}/docs/modules/mujoco.rst +0 -0
  29. {jaxsim-0.2.dev428 → jaxsim-0.2.dev435}/docs/modules/parsers.rst +0 -0
  30. {jaxsim-0.2.dev428 → jaxsim-0.2.dev435}/docs/modules/rbda.rst +0 -0
  31. {jaxsim-0.2.dev428 → jaxsim-0.2.dev435}/docs/modules/typing.rst +0 -0
  32. {jaxsim-0.2.dev428 → jaxsim-0.2.dev435}/docs/modules/utils.rst +0 -0
  33. {jaxsim-0.2.dev428 → jaxsim-0.2.dev435}/environment.yml +0 -0
  34. {jaxsim-0.2.dev428 → jaxsim-0.2.dev435}/examples/.gitattributes +0 -0
  35. {jaxsim-0.2.dev428 → jaxsim-0.2.dev435}/examples/.gitignore +0 -0
  36. {jaxsim-0.2.dev428 → jaxsim-0.2.dev435}/examples/README.md +0 -0
  37. {jaxsim-0.2.dev428 → jaxsim-0.2.dev435}/examples/assets/cartpole.urdf +0 -0
  38. {jaxsim-0.2.dev428 → jaxsim-0.2.dev435}/examples/pixi.lock +0 -0
  39. {jaxsim-0.2.dev428 → jaxsim-0.2.dev435}/examples/pixi.toml +0 -0
  40. {jaxsim-0.2.dev428 → jaxsim-0.2.dev435}/pyproject.toml +0 -0
  41. {jaxsim-0.2.dev428 → jaxsim-0.2.dev435}/setup.cfg +0 -0
  42. {jaxsim-0.2.dev428 → jaxsim-0.2.dev435}/setup.py +0 -0
  43. {jaxsim-0.2.dev428 → jaxsim-0.2.dev435}/src/jaxsim/__init__.py +0 -0
  44. {jaxsim-0.2.dev428 → jaxsim-0.2.dev435}/src/jaxsim/api/__init__.py +0 -0
  45. {jaxsim-0.2.dev428 → jaxsim-0.2.dev435}/src/jaxsim/api/com.py +0 -0
  46. {jaxsim-0.2.dev428 → jaxsim-0.2.dev435}/src/jaxsim/api/common.py +0 -0
  47. {jaxsim-0.2.dev428 → jaxsim-0.2.dev435}/src/jaxsim/api/contact.py +0 -0
  48. {jaxsim-0.2.dev428 → jaxsim-0.2.dev435}/src/jaxsim/api/data.py +0 -0
  49. {jaxsim-0.2.dev428 → jaxsim-0.2.dev435}/src/jaxsim/api/joint.py +0 -0
  50. {jaxsim-0.2.dev428 → jaxsim-0.2.dev435}/src/jaxsim/api/kin_dyn_parameters.py +0 -0
  51. {jaxsim-0.2.dev428 → jaxsim-0.2.dev435}/src/jaxsim/api/link.py +0 -0
  52. {jaxsim-0.2.dev428 → jaxsim-0.2.dev435}/src/jaxsim/api/model.py +0 -0
  53. {jaxsim-0.2.dev428 → jaxsim-0.2.dev435}/src/jaxsim/api/ode.py +0 -0
  54. {jaxsim-0.2.dev428 → jaxsim-0.2.dev435}/src/jaxsim/api/ode_data.py +0 -0
  55. {jaxsim-0.2.dev428 → jaxsim-0.2.dev435}/src/jaxsim/api/references.py +0 -0
  56. {jaxsim-0.2.dev428 → jaxsim-0.2.dev435}/src/jaxsim/integrators/__init__.py +0 -0
  57. {jaxsim-0.2.dev428 → jaxsim-0.2.dev435}/src/jaxsim/integrators/common.py +0 -0
  58. {jaxsim-0.2.dev428 → jaxsim-0.2.dev435}/src/jaxsim/integrators/fixed_step.py +0 -0
  59. {jaxsim-0.2.dev428 → jaxsim-0.2.dev435}/src/jaxsim/integrators/variable_step.py +0 -0
  60. {jaxsim-0.2.dev428 → jaxsim-0.2.dev435}/src/jaxsim/logging.py +0 -0
  61. {jaxsim-0.2.dev428 → jaxsim-0.2.dev435}/src/jaxsim/math/__init__.py +0 -0
  62. {jaxsim-0.2.dev428 → jaxsim-0.2.dev435}/src/jaxsim/math/adjoint.py +0 -0
  63. {jaxsim-0.2.dev428 → jaxsim-0.2.dev435}/src/jaxsim/math/cross.py +0 -0
  64. {jaxsim-0.2.dev428 → jaxsim-0.2.dev435}/src/jaxsim/math/inertia.py +0 -0
  65. {jaxsim-0.2.dev428 → jaxsim-0.2.dev435}/src/jaxsim/math/joint_model.py +0 -0
  66. {jaxsim-0.2.dev428 → jaxsim-0.2.dev435}/src/jaxsim/math/quaternion.py +0 -0
  67. {jaxsim-0.2.dev428 → jaxsim-0.2.dev435}/src/jaxsim/math/rotation.py +0 -0
  68. {jaxsim-0.2.dev428 → jaxsim-0.2.dev435}/src/jaxsim/math/skew.py +0 -0
  69. {jaxsim-0.2.dev428 → jaxsim-0.2.dev435}/src/jaxsim/math/transform.py +0 -0
  70. {jaxsim-0.2.dev428 → jaxsim-0.2.dev435}/src/jaxsim/mujoco/__init__.py +0 -0
  71. {jaxsim-0.2.dev428 → jaxsim-0.2.dev435}/src/jaxsim/mujoco/__main__.py +0 -0
  72. {jaxsim-0.2.dev428 → jaxsim-0.2.dev435}/src/jaxsim/mujoco/model.py +0 -0
  73. {jaxsim-0.2.dev428 → jaxsim-0.2.dev435}/src/jaxsim/mujoco/visualizer.py +0 -0
  74. {jaxsim-0.2.dev428 → jaxsim-0.2.dev435}/src/jaxsim/parsers/__init__.py +0 -0
  75. {jaxsim-0.2.dev428 → jaxsim-0.2.dev435}/src/jaxsim/parsers/descriptions/__init__.py +0 -0
  76. {jaxsim-0.2.dev428 → jaxsim-0.2.dev435}/src/jaxsim/parsers/descriptions/collision.py +0 -0
  77. {jaxsim-0.2.dev428 → jaxsim-0.2.dev435}/src/jaxsim/parsers/descriptions/joint.py +0 -0
  78. {jaxsim-0.2.dev428 → jaxsim-0.2.dev435}/src/jaxsim/parsers/descriptions/link.py +0 -0
  79. {jaxsim-0.2.dev428 → jaxsim-0.2.dev435}/src/jaxsim/parsers/descriptions/model.py +0 -0
  80. {jaxsim-0.2.dev428 → jaxsim-0.2.dev435}/src/jaxsim/parsers/kinematic_graph.py +0 -0
  81. {jaxsim-0.2.dev428 → jaxsim-0.2.dev435}/src/jaxsim/parsers/rod/__init__.py +0 -0
  82. {jaxsim-0.2.dev428 → jaxsim-0.2.dev435}/src/jaxsim/parsers/rod/parser.py +0 -0
  83. {jaxsim-0.2.dev428 → jaxsim-0.2.dev435}/src/jaxsim/parsers/rod/utils.py +0 -0
  84. {jaxsim-0.2.dev428 → jaxsim-0.2.dev435}/src/jaxsim/rbda/__init__.py +0 -0
  85. {jaxsim-0.2.dev428 → jaxsim-0.2.dev435}/src/jaxsim/rbda/aba.py +0 -0
  86. {jaxsim-0.2.dev428 → jaxsim-0.2.dev435}/src/jaxsim/rbda/collidable_points.py +0 -0
  87. {jaxsim-0.2.dev428 → jaxsim-0.2.dev435}/src/jaxsim/rbda/crba.py +0 -0
  88. {jaxsim-0.2.dev428 → jaxsim-0.2.dev435}/src/jaxsim/rbda/forward_kinematics.py +0 -0
  89. {jaxsim-0.2.dev428 → jaxsim-0.2.dev435}/src/jaxsim/rbda/jacobian.py +0 -0
  90. {jaxsim-0.2.dev428 → jaxsim-0.2.dev435}/src/jaxsim/rbda/rnea.py +0 -0
  91. {jaxsim-0.2.dev428 → jaxsim-0.2.dev435}/src/jaxsim/rbda/soft_contacts.py +0 -0
  92. {jaxsim-0.2.dev428 → jaxsim-0.2.dev435}/src/jaxsim/rbda/utils.py +0 -0
  93. {jaxsim-0.2.dev428 → jaxsim-0.2.dev435}/src/jaxsim/terrain/__init__.py +0 -0
  94. {jaxsim-0.2.dev428 → jaxsim-0.2.dev435}/src/jaxsim/terrain/terrain.py +0 -0
  95. {jaxsim-0.2.dev428 → jaxsim-0.2.dev435}/src/jaxsim/typing.py +0 -0
  96. {jaxsim-0.2.dev428 → jaxsim-0.2.dev435}/src/jaxsim/utils/__init__.py +0 -0
  97. {jaxsim-0.2.dev428 → jaxsim-0.2.dev435}/src/jaxsim/utils/hashless.py +0 -0
  98. {jaxsim-0.2.dev428 → jaxsim-0.2.dev435}/src/jaxsim/utils/jaxsim_dataclass.py +0 -0
  99. {jaxsim-0.2.dev428 → jaxsim-0.2.dev435}/src/jaxsim/utils/tracing.py +0 -0
  100. {jaxsim-0.2.dev428 → jaxsim-0.2.dev435}/src/jaxsim.egg-info/SOURCES.txt +0 -0
  101. {jaxsim-0.2.dev428 → jaxsim-0.2.dev435}/src/jaxsim.egg-info/dependency_links.txt +0 -0
  102. {jaxsim-0.2.dev428 → jaxsim-0.2.dev435}/src/jaxsim.egg-info/not-zip-safe +0 -0
  103. {jaxsim-0.2.dev428 → jaxsim-0.2.dev435}/src/jaxsim.egg-info/requires.txt +0 -0
  104. {jaxsim-0.2.dev428 → jaxsim-0.2.dev435}/src/jaxsim.egg-info/top_level.txt +0 -0
  105. {jaxsim-0.2.dev428 → jaxsim-0.2.dev435}/tests/__init__.py +0 -0
  106. {jaxsim-0.2.dev428 → jaxsim-0.2.dev435}/tests/conftest.py +0 -0
  107. {jaxsim-0.2.dev428 → jaxsim-0.2.dev435}/tests/test_api_com.py +0 -0
  108. {jaxsim-0.2.dev428 → jaxsim-0.2.dev435}/tests/test_api_data.py +0 -0
  109. {jaxsim-0.2.dev428 → jaxsim-0.2.dev435}/tests/test_api_joint.py +0 -0
  110. {jaxsim-0.2.dev428 → jaxsim-0.2.dev435}/tests/test_api_link.py +0 -0
  111. {jaxsim-0.2.dev428 → jaxsim-0.2.dev435}/tests/test_api_model.py +0 -0
  112. {jaxsim-0.2.dev428 → jaxsim-0.2.dev435}/tests/test_automatic_differentiation.py +0 -0
  113. {jaxsim-0.2.dev428 → jaxsim-0.2.dev435}/tests/test_pytree.py +0 -0
  114. {jaxsim-0.2.dev428 → jaxsim-0.2.dev435}/tests/test_simulations.py +0 -0
  115. {jaxsim-0.2.dev428 → jaxsim-0.2.dev435}/tests/utils_idyntree.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: jaxsim
3
- Version: 0.2.dev428
3
+ Version: 0.2.dev435
4
4
  Summary: A differentiable physics engine and multibody dynamics library for control and robot learning.
5
5
  Home-page: https://github.com/ami-iit/jaxsim
6
6
  Author: Diego Ferigo
@@ -68,7 +68,12 @@
68
68
  "cell_type": "markdown",
69
69
  "metadata": {},
70
70
  "source": [
71
- "JAXsim offers a simple high-level API in order to extract quantities needed in most robotic applications. "
71
+ "JAXsim offers a simple functional API in order to interact in a memory-efficient way with the simulation. Four main elements are used to define a simulation:\n",
72
+ "\n",
73
+ "- `model`: an object that defines the dynamics of the system.\n",
74
+ "- `data`: an object that contains the state of the system.\n",
75
+ "- `integrator`: an object that defines the integration method.\n",
76
+ "- `integrator_state`: an object that contains the state of the integrator."
72
77
  ]
73
78
  },
74
79
  {
@@ -77,11 +82,23 @@
77
82
  "metadata": {},
78
83
  "outputs": [],
79
84
  "source": [
80
- "from jaxsim.high_level.model import Model\n",
85
+ "import jaxsim.api as js\n",
86
+ "from jaxsim import integrators\n",
87
+ "\n",
88
+ "dt = 0.01\n",
81
89
  "\n",
82
- "model = Model.build_from_model_description(\n",
90
+ "model = js.model.JaxSimModel.build_from_model_description(\n",
83
91
  " model_description=model_urdf_string, is_urdf=True\n",
84
- ")"
92
+ ")\n",
93
+ "data = js.data.JaxSimModelData.build(model=model)\n",
94
+ "integrator = integrators.fixed_step.RungeKutta4SO3.build(\n",
95
+ " dynamics=js.ode.wrap_system_dynamics_for_integration(\n",
96
+ " model=model,\n",
97
+ " data=data,\n",
98
+ " system_dynamics=js.ode.system_dynamics,\n",
99
+ " ),\n",
100
+ ")\n",
101
+ "integrator_state = integrator.init(x0=data.state, t0=0.0, dt=dt)"
85
102
  ]
86
103
  },
87
104
  {
@@ -101,7 +118,7 @@
101
118
  " minval=-1.0, maxval=1.0, shape=(model.dofs(),), key=jax.random.PRNGKey(0)\n",
102
119
  ")\n",
103
120
  "\n",
104
- "model.reset_joint_positions(positions=random_positions)"
121
+ "data = data.reset_joint_positions(positions=random_positions)"
105
122
  ]
106
123
  },
107
124
  {
@@ -118,17 +135,11 @@
118
135
  "outputs": [],
119
136
  "source": [
120
137
  "# @title Set up MuJoCo renderer\n",
121
- "!{sys.executable} -m pip install -U -q mujoco\n",
122
- "!{sys.executable} -m pip install -q mediapy\n",
123
138
  "\n",
124
- "import mediapy as media\n",
125
- "import tempfile\n",
126
- "import xml.etree.ElementTree as ET\n",
127
- "import numpy as np\n",
139
+ "from jaxsim.mujoco.visualizer import MujocoVisualizer\n",
140
+ "from jaxsim.mujoco import RodModelToMjcf, MujocoModelHelper, MujocoVideoRecorder\n",
141
+ "from jaxsim.mujoco.loaders import UrdfToMjcf\n",
128
142
  "\n",
129
- "import distutils.util\n",
130
- "import os\n",
131
- "import subprocess\n",
132
143
  "\n",
133
144
  "if IS_COLAB:\n",
134
145
  " if subprocess.run(\"ffmpeg -version\", shell=True).returncode:\n",
@@ -171,66 +182,28 @@
171
182
  " 'by going to the Runtime menu and selecting \"Choose runtime type\".'\n",
172
183
  " )\n",
173
184
  "\n",
185
+ "camera = {\n",
186
+ " \"name\":\"cartpole_camera\",\n",
187
+ " \"mode\":\"fixed\",\n",
188
+ " \"pos\":\"3.954 3.533 2.343\",\n",
189
+ " \"xyaxes\":\"-0.594 0.804 -0.000 -0.163 -0.120 0.979\",\n",
190
+ " \"fovy\":\"60\",\n",
191
+ "}\n",
174
192
  "\n",
175
- "def load_mujoco_model_with_camera(xml_string, camera_pos, camera_xyaxes):\n",
176
- " def to_mjcf_string(list_to_str):\n",
177
- " return \" \".join(map(str, list_to_str))\n",
178
- "\n",
179
- " mj_model_raw = mujoco.MjModel.from_xml_string(model_urdf_string)\n",
180
- " path_temp_xml = tempfile.NamedTemporaryFile(mode=\"w+\")\n",
181
- " mujoco.mj_saveLastXML(path_temp_xml.name, mj_model_raw)\n",
182
- " # Add camera in mujoco model\n",
183
- " tree = ET.parse(path_temp_xml)\n",
184
- " for elem in tree.getroot().iter(\"worldbody\"):\n",
185
- " worldbody_elem = elem\n",
186
- " camera_elem = ET.Element(\"camera\")\n",
187
- " # Set attributes\n",
188
- " camera_elem.set(\"name\", \"side\")\n",
189
- " camera_elem.set(\"pos\", to_mjcf_string(camera_pos))\n",
190
- " camera_elem.set(\"xyaxes\", to_mjcf_string(camera_xyaxes))\n",
191
- " camera_elem.set(\"mode\", \"fixed\")\n",
192
- " worldbody_elem.append(camera_elem)\n",
193
- "\n",
194
- " # Save new model\n",
195
- " mujoco_xml_with_camera = ET.tostring(tree.getroot(), encoding=\"unicode\")\n",
196
- " mj_model = mujoco.MjModel.from_xml_string(mujoco_xml_with_camera)\n",
197
- " return mj_model\n",
198
- "\n",
199
- "\n",
200
- "def from_jaxsim_to_mujoco_pos(jaxsim_jointpos, mjmodel, jaxsimmodel):\n",
201
- " mujocoqposaddr2jaxindex = {}\n",
202
- " for jaxjnt in jaxsimmodel.joints():\n",
203
- " jntname = jaxjnt.name()\n",
204
- " mujocoqposaddr2jaxindex[mjmodel.joint(jntname).qposadr[0]] = jaxjnt.index() - 1\n",
205
- "\n",
206
- " mujoco_jointpos = jaxsim_jointpos\n",
207
- " for i in range(0, len(mujoco_jointpos)):\n",
208
- " mujoco_jointpos[i] = jaxsim_jointpos[mujocoqposaddr2jaxindex[i]]\n",
209
- "\n",
210
- " return mujoco_jointpos\n",
211
- "\n",
212
- "\n",
213
- "# To get a good camera location, you can use \"Copy camera\" functionality in MuJoCo GUI\n",
214
- "mj_model = load_mujoco_model_with_camera(\n",
215
- " model_urdf_string,\n",
216
- " [3.954, 3.533, 2.343],\n",
217
- " [-0.594, 0.804, -0.000, -0.163, -0.120, 0.979],\n",
218
- ")\n",
219
- "renderer = mujoco.Renderer(mj_model, height=480, width=640)\n",
193
+ "mjcf_string, assets = UrdfToMjcf.convert(urdf=model.built_from, cameras=camera)\n",
220
194
  "\n",
195
+ "mj_model_helper = MujocoModelHelper.build_from_xml(\n",
196
+ " mjcf_description=mjcf_string, assets=assets\n",
197
+ ")\n",
221
198
  "\n",
222
- "def get_image(camera, mujocojointpos) -> np.ndarray:\n",
223
- " \"\"\"Renders the environment state.\"\"\"\n",
224
- " # Copy joint data in mjdata state\n",
225
- " d = mujoco.MjData(mj_model)\n",
226
- " d.qpos = mujocojointpos\n",
227
- "\n",
228
- " # Forward kinematics\n",
229
- " mujoco.mj_forward(mj_model, d)\n",
230
- "\n",
231
- " # use the mjData object to update the renderer\n",
232
- " renderer.update_scene(d, camera=camera)\n",
233
- " return renderer.render()"
199
+ "# Create the video recorder.\n",
200
+ "recorder = MujocoVideoRecorder(\n",
201
+ " model=mj_model_helper.model,\n",
202
+ " data=mj_model_helper.data,\n",
203
+ " fps=int(1 / 0.010),\n",
204
+ " width=320 * 4,\n",
205
+ " height=240 * 4,\n",
206
+ ")"
234
207
  ]
235
208
  },
236
209
  {
@@ -246,24 +219,27 @@
246
219
  "metadata": {},
247
220
  "outputs": [],
248
221
  "source": [
249
- "from jaxsim.simulation.ode_integration import IntegratorType\n",
250
- "\n",
251
- "sim_images = []\n",
252
- "timestep = 0.01\n",
253
- "for _ in range(300):\n",
254
- " sim_images.append(\n",
255
- " get_image(\n",
256
- " \"side\",\n",
257
- " from_jaxsim_to_mujoco_pos(\n",
258
- " np.array(model.joint_positions()), mj_model, model\n",
259
- " ),\n",
260
- " )\n",
222
+ "import mediapy as media\n",
223
+ "\n",
224
+ "for _ in range(500):\n",
225
+ " data, integrator_state = js.model.step(\n",
226
+ " dt=dt,\n",
227
+ " model=model,\n",
228
+ " data=data,\n",
229
+ " integrator=integrator,\n",
230
+ " integrator_state=integrator_state,\n",
231
+ " joint_forces=None,\n",
232
+ " link_forces=None,\n",
261
233
  " )\n",
262
- " model.integrate(\n",
263
- " t0=0.0, tf=timestep, integrator_type=IntegratorType.EulerSemiImplicit\n",
234
+ "\n",
235
+ " mj_model_helper.set_joint_positions(\n",
236
+ " positions=data.joint_positions(), joint_names=model.joint_names()\n",
264
237
  " )\n",
265
238
  "\n",
266
- "media.show_video(sim_images, fps=1 / timestep)"
239
+ " recorder.record_frame(camera_name=\"cartpole_camera\")\n",
240
+ "\n",
241
+ "media.show_video(recorder.frames, fps=1 / dt)\n",
242
+ "recorder.frames = []"
267
243
  ]
268
244
  },
269
245
  {
@@ -290,13 +266,17 @@
290
266
  "KP = 10.0\n",
291
267
  "KD = 6.0\n",
292
268
  "\n",
293
- "# Compute the gravity compensation term\n",
294
- "H = model.free_floating_bias_forces()[6:]\n",
295
- "\n",
296
269
  "\n",
297
270
  "def pd_controller(\n",
298
- " q: jax.Array, q_d: jax.Array, q_dot: jax.Array, q_dot_d: jax.Array\n",
271
+ " data: js.data.JaxSimModelData, q_d: jax.Array, q_dot_d: jax.Array\n",
299
272
  ") -> jax.Array:\n",
273
+ "\n",
274
+ " # Compute the gravity compensation term\n",
275
+ " H = js.model.free_floating_bias_forces(model=model, data=data)[6:]\n",
276
+ "\n",
277
+ " q = data.joint_positions()\n",
278
+ " q_dot = data.joint_velocities()\n",
279
+ "\n",
300
280
  " return H + KP * (q_d - q) + KD * (q_dot_d - q_dot)"
301
281
  ]
302
282
  },
@@ -313,31 +293,31 @@
313
293
  "metadata": {},
314
294
  "outputs": [],
315
295
  "source": [
316
- "sim_images = []\n",
317
- "timestep = 0.01\n",
318
- "\n",
319
- "for _ in range(300):\n",
320
- " sim_images.append(\n",
321
- " get_image(\n",
322
- " \"side\",\n",
323
- " from_jaxsim_to_mujoco_pos(\n",
324
- " np.array(model.joint_positions()), mj_model, model\n",
325
- " ),\n",
326
- " )\n",
296
+ "for _ in range(500):\n",
297
+ " control_torques = pd_controller(\n",
298
+ " data=data,\n",
299
+ " q_d=jnp.array([0.0, 0.0]),\n",
300
+ " q_dot_d=jnp.array([0.0, 0.0]),\n",
327
301
  " )\n",
328
- " model.set_joint_generalized_force_targets(\n",
329
- " forces=pd_controller(\n",
330
- " q=model.joint_positions(),\n",
331
- " q_d=jnp.array([0.0, 0.0]),\n",
332
- " q_dot=model.joint_velocities(),\n",
333
- " q_dot_d=jnp.array([0.0, 0.0]),\n",
334
- " )\n",
302
+ "\n",
303
+ " data, integrator_state = js.model.step(\n",
304
+ " dt=dt,\n",
305
+ " model=model,\n",
306
+ " data=data,\n",
307
+ " integrator=integrator,\n",
308
+ " integrator_state=integrator_state,\n",
309
+ " joint_forces=control_torques,\n",
310
+ " link_forces=None,\n",
335
311
  " )\n",
336
- " model.integrate(\n",
337
- " t0=0.0, tf=timestep, integrator_type=IntegratorType.EulerSemiImplicit\n",
312
+ "\n",
313
+ " mj_model_helper.set_joint_positions(\n",
314
+ " positions=data.joint_positions(), joint_names=model.joint_names()\n",
338
315
  " )\n",
339
316
  "\n",
340
- "media.show_video(sim_images, fps=1 / timestep)"
317
+ " recorder.record_frame(camera_name=\"cartpole_camera\")\n",
318
+ "\n",
319
+ "media.show_video(recorder.frames, fps=1 / dt)\n",
320
+ "recorder.frames = []"
341
321
  ]
342
322
  }
343
323
  ],
@@ -370,7 +350,7 @@
370
350
  "name": "python",
371
351
  "nbconvert_exporter": "python",
372
352
  "pygments_lexer": "ipython3",
373
- "version": "3.11.6"
353
+ "version": "3.11.8"
374
354
  }
375
355
  },
376
356
  "nbformat": 4,
@@ -88,7 +88,12 @@
88
88
  "cell_type": "markdown",
89
89
  "metadata": {},
90
90
  "source": [
91
- "Now, we can create a simulator instance and load the model into it."
91
+ "JAXsim offers a simple functional API in order to interact in a memory-efficient way with the simulation. Four main elements are used to define a simulation:\n",
92
+ "\n",
93
+ "- `model`: an object that defines the dynamics of the system.\n",
94
+ "- `data`: an object that contains the state of the system.\n",
95
+ "- `integrator`: an object that defines the integration method.\n",
96
+ "- `integrator_state`: an object that contains the state of the integrator."
92
97
  ]
93
98
  },
94
99
  {
@@ -97,29 +102,48 @@
97
102
  "metadata": {},
98
103
  "outputs": [],
99
104
  "source": [
100
- "from jaxsim.high_level.model import VelRepr\n",
101
- "from jaxsim.physics.algos.soft_contacts import SoftContactsParams\n",
102
- "from jaxsim.simulation.ode_integration import IntegratorType\n",
103
- "from jaxsim.simulation.simulator import JaxSim, SimulatorData, StepData\n",
105
+ "import jaxsim.api as js\n",
106
+ "from jaxsim import integrators\n",
104
107
  "\n",
105
- "# Simulation Step Parameters\n",
106
- "integration_time = 3.0 # seconds\n",
107
- "step_size = 0.001\n",
108
- "steps_per_run = 1\n",
108
+ "dt = 0.001\n",
109
+ "integration_time = 1500\n",
109
110
  "\n",
110
- "simulator = JaxSim.build(\n",
111
- " step_size=step_size,\n",
112
- " steps_per_run=steps_per_run,\n",
113
- " velocity_representation=VelRepr.Body,\n",
114
- " integrator_type=IntegratorType.EulerSemiImplicit,\n",
115
- " simulator_data=SimulatorData(\n",
116
- " contact_parameters=SoftContactsParams(K=1e6, D=2e3, mu=0.5),\n",
111
+ "model = js.model.JaxSimModel.build_from_model_description(\n",
112
+ " model_description=model_sdf_string\n",
113
+ ")\n",
114
+ "data = js.data.JaxSimModelData.build(model=model)\n",
115
+ "integrator = integrators.fixed_step.RungeKutta4SO3.build(\n",
116
+ " dynamics=js.ode.wrap_system_dynamics_for_integration(\n",
117
+ " model=model,\n",
118
+ " data=data,\n",
119
+ " system_dynamics=js.ode.system_dynamics,\n",
117
120
  " ),\n",
118
121
  ")\n",
122
+ "integrator_state = integrator.init(x0=data.state, t0=0.0, dt=dt)"
123
+ ]
124
+ },
125
+ {
126
+ "cell_type": "markdown",
127
+ "metadata": {},
128
+ "source": [
129
+ "It is possible to automatically choose a good set of parameters for the terrain. \n",
119
130
  "\n",
131
+ "By default, in JaxSim a sphere primitive has 250 collision points. This can be modified by setting the `JAXSIM_COLLISION_SPHERE_POINTS` environment variable.\n",
120
132
  "\n",
121
- "# Add model to simulator\n",
122
- "model = simulator.insert_model_from_description(model_description=model_sdf_string)"
133
+ "Given that at its steady-state the sphere will act on two or three points, we can estimate the ground parameters by explicitly setting the number of active points to these values."
134
+ ]
135
+ },
136
+ {
137
+ "cell_type": "code",
138
+ "execution_count": null,
139
+ "metadata": {},
140
+ "outputs": [],
141
+ "source": [
142
+ "data = data.replace(\n",
143
+ " soft_contacts_params=js.contact.estimate_good_soft_contacts_parameters(\n",
144
+ " model, number_of_active_collidable_points_steady_state=3\n",
145
+ " )\n",
146
+ ")"
123
147
  ]
124
148
  },
125
149
  {
@@ -136,8 +160,9 @@
136
160
  "outputs": [],
137
161
  "source": [
138
162
  "# Primary Calculations\n",
163
+ "envs_per_row = 4 # @slider(2, 10, 1)\n",
164
+ "\n",
139
165
  "env_spacing = 0.5\n",
140
- "envs_per_row = 3\n",
141
166
  "edge_len = env_spacing * (2 * envs_per_row - 1)\n",
142
167
  "\n",
143
168
  "\n",
@@ -155,6 +180,7 @@
155
180
  " return jnp.array(poses)\n",
156
181
  "\n",
157
182
  "\n",
183
+ "logging.info(f\"Simulating {envs_per_row**2} environments\")\n",
158
184
  "poses = grid(edge_len, envs_per_row)"
159
185
  ]
160
186
  },
@@ -162,9 +188,7 @@
162
188
  "cell_type": "markdown",
163
189
  "metadata": {},
164
190
  "source": [
165
- "In order to parallelize the simulation, we first need to define a function `simulate` for a single element of the batch.\n",
166
- "\n",
167
- "**Note:** [`step_over_horizon`](https://github.com/ami-iit/jaxsim/blob/427b1e646297495f6b33e4c0bb2273ca89bd5ae2/src/jaxsim/simulation/simulator.py#L432C1-L529C10) is useful only in open-loop simulations and where the horizon is known in advance. Please checkout [`step`](https://github.com/ami-iit/jaxsim/blob/427b1e646297495f6b33e4c0bb2273ca89bd5ae2/src/jaxsim/simulation/simulator.py#L384C10-L425) for closed-loop simulations."
191
+ "In order to parallelize the simulation, we first need to define a function `simulate` for a single element of the batch."
168
192
  ]
169
193
  },
170
194
  {
@@ -173,35 +197,27 @@
173
197
  "metadata": {},
174
198
  "outputs": [],
175
199
  "source": [
176
- "from jaxsim.simulation import simulator_callbacks\n",
177
- "\n",
178
- "\n",
179
- "# Create a logger to store simulation data\n",
180
- "@jax_dataclasses.pytree_dataclass\n",
181
- "class SimulatorLogger(simulator_callbacks.PostStepCallback):\n",
182
- " def post_step(\n",
183
- " self, sim: JaxSim, step_data: Dict[str, StepData]\n",
184
- " ) -> Tuple[JaxSim, jtp.PyTree]:\n",
185
- " \"\"\"Return the StepData object of each simulated model\"\"\"\n",
186
- " return sim, step_data\n",
187
- "\n",
188
- "\n",
189
200
  "# Define a function to simulate a single model instance\n",
190
- "def simulate(sim: JaxSim, pose) -> JaxSim:\n",
191
- " model.zero()\n",
192
- " model.reset_base_position(position=jnp.array(pose))\n",
193
- "\n",
194
- " with sim.editable(validate=True) as sim:\n",
195
- " m = sim.get_model(model.name())\n",
196
- " m.data = model.data\n",
197
- "\n",
198
- " sim, (cb, (_, step_data)) = simulator.step_over_horizon(\n",
199
- " horizon_steps=integration_time // step_size,\n",
200
- " callback_handler=SimulatorLogger(),\n",
201
- " clear_inputs=True,\n",
202
- " )\n",
203
- "\n",
204
- " return step_data"
201
+ "def simulate(\n",
202
+ " data: js.data.JaxSimModelData, integrator_state: dict, pose: jnp.array\n",
203
+ ") -> tuple:\n",
204
+ "\n",
205
+ " data = data.reset_base_position(base_position=pose)\n",
206
+ " x_t_i = []\n",
207
+ "\n",
208
+ " for _ in range(integration_time):\n",
209
+ " data, integrator_state = js.model.step(\n",
210
+ " dt=dt,\n",
211
+ " model=model,\n",
212
+ " data=data,\n",
213
+ " integrator=integrator,\n",
214
+ " integrator_state=integrator_state,\n",
215
+ " joint_forces=None,\n",
216
+ " link_forces=None,\n",
217
+ " )\n",
218
+ " x_t_i.append(data.base_position())\n",
219
+ "\n",
220
+ " return x_t_i"
205
221
  ]
206
222
  },
207
223
  {
@@ -213,7 +229,7 @@
213
229
  "\n",
214
230
  "Note that in our case we are vectorizing over the `pose` argument of the function `simulate`, this correspond to the value assigned to the `in_axes` parameter of `jax.vmap`:\n",
215
231
  "\n",
216
- "`in_axes=(None, 0)` means that the first argument of `simulate` is not vectorized, while the second argument is vectorized over the zero-th dimension."
232
+ "`in_axes=(None, None, 0)` means that the first two arguments of `simulate` are not vectorized, while the third argument is vectorized over the zero-th dimension."
217
233
  ]
218
234
  },
219
235
  {
@@ -223,12 +239,12 @@
223
239
  "outputs": [],
224
240
  "source": [
225
241
  "# Define a function to simulate multiple model instances\n",
226
- "simulate_vectorized = jax.vmap(simulate, in_axes=(None, 0))\n",
242
+ "simulate_vectorized = jax.vmap(simulate, in_axes=(None, None, 0))\n",
227
243
  "\n",
228
244
  "# Run and time the simulation\n",
229
245
  "now = time.perf_counter()\n",
230
246
  "\n",
231
- "time_history = simulate_vectorized(simulator, poses[:, 0])\n",
247
+ "x_t = simulate_vectorized(data, integrator_state, poses[:, 0])\n",
232
248
  "\n",
233
249
  "comp_time = time.perf_counter() - now\n",
234
250
  "\n",
@@ -236,7 +252,7 @@
236
252
  " f\"Running simulation with {envs_per_row**2} models took {comp_time} seconds.\"\n",
237
253
  ")\n",
238
254
  "logging.info(\n",
239
- " f\"This corresponds to an RTF (Real Time Factor) of {envs_per_row**2 *integration_time/comp_time}\"\n",
255
+ " f\"This corresponds to an RTF (Real Time Factor) of {(envs_per_row**2 *integration_time/comp_time):.2f}\"\n",
240
256
  ")"
241
257
  ]
242
258
  },
@@ -253,13 +269,10 @@
253
269
  "metadata": {},
254
270
  "outputs": [],
255
271
  "source": [
256
- "time_history: Dict[str, StepData]\n",
257
- "x_t = time_history[model.name()].tf_model_state\n",
258
- "\n",
259
- "\n",
260
272
  "import matplotlib.pyplot as plt\n",
273
+ "import numpy as np\n",
261
274
  "\n",
262
- "plt.plot(time_history[model.name()].tf[0], x_t.base_position[:, :, 2].T)\n",
275
+ "plt.plot(np.arange(len(x_t)) * dt, np.array(x_t)[:, :, 2])\n",
263
276
  "plt.grid(True)\n",
264
277
  "plt.xlabel(\"Time [s]\")\n",
265
278
  "plt.ylabel(\"Height [m]\")\n",
@@ -297,7 +310,7 @@
297
310
  "name": "python",
298
311
  "nbconvert_exporter": "python",
299
312
  "pygments_lexer": "ipython3",
300
- "version": "3.12.1"
313
+ "version": "3.11.8"
301
314
  }
302
315
  },
303
316
  "nbformat": 4,
@@ -12,5 +12,5 @@ __version__: str
12
12
  __version_tuple__: VERSION_TUPLE
13
13
  version_tuple: VERSION_TUPLE
14
14
 
15
- __version__ = version = '0.2.dev428'
16
- __version_tuple__ = version_tuple = (0, 2, 'dev428')
15
+ __version__ = version = '0.2.dev435'
16
+ __version_tuple__ = version_tuple = (0, 2, 'dev435')
@@ -1,3 +1,4 @@
1
+ import dataclasses
1
2
  import pathlib
2
3
  import tempfile
3
4
  import warnings
@@ -159,6 +160,7 @@ class RodModelToMjcf:
159
160
  considered_joints: list[str] | None = None,
160
161
  plane_normal: tuple[float, float, float] = (0, 0, 1),
161
162
  heightmap: bool | None = None,
163
+ cameras: list[dict[str, str]] | dict[str, str] = None,
162
164
  ) -> tuple[str, dict[str, Any]]:
163
165
  """
164
166
  Converts a ROD model to a Mujoco MJCF string.
@@ -166,6 +168,9 @@ class RodModelToMjcf:
166
168
  Args:
167
169
  rod_model: The ROD model to convert.
168
170
  considered_joints: The list of joint names to consider in the conversion.
171
+ plane_normal: The normal vector of the plane.
172
+ heightmap: Whether to generate a heightmap.
173
+ cameras: The list of cameras to add to the scene.
169
174
 
170
175
  Returns:
171
176
  tuple: A tuple containing the MJCF string and the assets dictionary.
@@ -470,6 +475,14 @@ class RodModelToMjcf:
470
475
  fovy="60",
471
476
  )
472
477
 
478
+ # Add user-defined camera
479
+ cameras = cameras if cameras is not None else {}
480
+ for camera in cameras if isinstance(cameras, list) else [cameras]:
481
+ mj_camera = MujocoCamera.build(**camera)
482
+ _ = ET.SubElement(
483
+ worldbody_element, "camera", dataclasses.asdict(mj_camera)
484
+ )
485
+
473
486
  # ------------------------------------------------
474
487
  # Add a light following the CoM of the first link
475
488
  # ------------------------------------------------
@@ -504,6 +517,7 @@ class UrdfToMjcf:
504
517
  model_name: str | None = None,
505
518
  plane_normal: tuple[float, float, float] = (0, 0, 1),
506
519
  heightmap: bool | None = None,
520
+ cameras: list[dict[str, str]] | dict[str, str] = None,
507
521
  ) -> tuple[str, dict[str, Any]]:
508
522
  """
509
523
  Converts a URDF file to a Mujoco MJCF string.
@@ -512,6 +526,9 @@ class UrdfToMjcf:
512
526
  urdf: The URDF file to convert.
513
527
  considered_joints: The list of joint names to consider in the conversion.
514
528
  model_name: The name of the model to convert.
529
+ plane_normal: The normal vector of the plane.
530
+ heightmap: Whether to generate a heightmap.
531
+ cameras: The list of cameras to add to the scene.
515
532
 
516
533
  Returns:
517
534
  tuple: A tuple containing the MJCF string and the assets dictionary.
@@ -530,6 +547,7 @@ class UrdfToMjcf:
530
547
  considered_joints=considered_joints,
531
548
  plane_normal=plane_normal,
532
549
  heightmap=heightmap,
550
+ cameras=cameras,
533
551
  )
534
552
 
535
553
 
@@ -541,6 +559,7 @@ class SdfToMjcf:
541
559
  model_name: str | None = None,
542
560
  plane_normal: tuple[float, float, float] = (0, 0, 1),
543
561
  heightmap: bool | None = None,
562
+ cameras: list[dict[str, str]] | dict[str, str] = None,
544
563
  ) -> tuple[str, dict[str, Any]]:
545
564
  """
546
565
  Converts a SDF file to a Mujoco MJCF string.
@@ -549,6 +568,9 @@ class SdfToMjcf:
549
568
  sdf: The SDF file to convert.
550
569
  considered_joints: The list of joint names to consider in the conversion.
551
570
  model_name: The name of the model to convert.
571
+ plane_normal: The normal vector of the plane.
572
+ heightmap: Whether to generate a heightmap.
573
+ cameras: The list of cameras to add to the scene.
552
574
 
553
575
  Returns:
554
576
  tuple: A tuple containing the MJCF string and the assets dictionary.
@@ -567,4 +589,27 @@ class SdfToMjcf:
567
589
  considered_joints=considered_joints,
568
590
  plane_normal=plane_normal,
569
591
  heightmap=heightmap,
592
+ cameras=cameras,
570
593
  )
594
+
595
+
596
+ @dataclasses.dataclass
597
+ class MujocoCamera:
598
+ name: str
599
+ mode: str
600
+ pos: str
601
+ xyaxes: str
602
+ fovy: str
603
+
604
+ @classmethod
605
+ def build(cls, **kwargs):
606
+ if not all(isinstance(value, str) for value in kwargs.values()):
607
+ raise ValueError("Values must be strings")
608
+
609
+ if len(kwargs["pos"].split()) != 3:
610
+ raise ValueError("pos must have three values separated by space")
611
+
612
+ if len(kwargs["xyaxes"].split()) != 6:
613
+ raise ValueError("xyaxes must have six values separated by space")
614
+
615
+ return cls(**kwargs)
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: jaxsim
3
- Version: 0.2.dev428
3
+ Version: 0.2.dev435
4
4
  Summary: A differentiable physics engine and multibody dynamics library for control and robot learning.
5
5
  Home-page: https://github.com/ami-iit/jaxsim
6
6
  Author: Diego Ferigo
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes