jaxsim 0.4.3.dev245__tar.gz → 0.4.3.dev271__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {jaxsim-0.4.3.dev245 → jaxsim-0.4.3.dev271}/PKG-INFO +1 -1
- {jaxsim-0.4.3.dev245 → jaxsim-0.4.3.dev271}/examples/jaxsim_as_physics_engine.ipynb +27 -35
- {jaxsim-0.4.3.dev245 → jaxsim-0.4.3.dev271}/examples/jaxsim_for_robot_controllers.ipynb +7 -27
- {jaxsim-0.4.3.dev245 → jaxsim-0.4.3.dev271}/src/jaxsim/_version.py +2 -2
- {jaxsim-0.4.3.dev245 → jaxsim-0.4.3.dev271}/src/jaxsim/api/model.py +92 -23
- {jaxsim-0.4.3.dev245 → jaxsim-0.4.3.dev271}/src/jaxsim/api/ode.py +26 -22
- {jaxsim-0.4.3.dev245 → jaxsim-0.4.3.dev271}/src/jaxsim/integrators/common.py +27 -76
- {jaxsim-0.4.3.dev245 → jaxsim-0.4.3.dev271}/src/jaxsim/integrators/variable_step.py +96 -61
- {jaxsim-0.4.3.dev245 → jaxsim-0.4.3.dev271}/src/jaxsim.egg-info/PKG-INFO +1 -1
- {jaxsim-0.4.3.dev245 → jaxsim-0.4.3.dev271}/tests/test_automatic_differentiation.py +0 -20
- {jaxsim-0.4.3.dev245 → jaxsim-0.4.3.dev271}/tests/test_simulations.py +7 -57
- {jaxsim-0.4.3.dev245 → jaxsim-0.4.3.dev271}/.devcontainer/Dockerfile +0 -0
- {jaxsim-0.4.3.dev245 → jaxsim-0.4.3.dev271}/.devcontainer/devcontainer.json +0 -0
- {jaxsim-0.4.3.dev245 → jaxsim-0.4.3.dev271}/.gitattributes +0 -0
- {jaxsim-0.4.3.dev245 → jaxsim-0.4.3.dev271}/.github/CODEOWNERS +0 -0
- {jaxsim-0.4.3.dev245 → jaxsim-0.4.3.dev271}/.github/dependabot.yml +0 -0
- {jaxsim-0.4.3.dev245 → jaxsim-0.4.3.dev271}/.github/workflows/ci_cd.yml +0 -0
- {jaxsim-0.4.3.dev245 → jaxsim-0.4.3.dev271}/.github/workflows/read_the_docs.yml +0 -0
- {jaxsim-0.4.3.dev245 → jaxsim-0.4.3.dev271}/.github/workflows/update_pixi_lockfile.yml +0 -0
- {jaxsim-0.4.3.dev245 → jaxsim-0.4.3.dev271}/.gitignore +0 -0
- {jaxsim-0.4.3.dev245 → jaxsim-0.4.3.dev271}/.pre-commit-config.yaml +0 -0
- {jaxsim-0.4.3.dev245 → jaxsim-0.4.3.dev271}/.readthedocs.yaml +0 -0
- {jaxsim-0.4.3.dev245 → jaxsim-0.4.3.dev271}/CONTRIBUTING.md +0 -0
- {jaxsim-0.4.3.dev245 → jaxsim-0.4.3.dev271}/LICENSE +0 -0
- {jaxsim-0.4.3.dev245 → jaxsim-0.4.3.dev271}/README.md +0 -0
- {jaxsim-0.4.3.dev245 → jaxsim-0.4.3.dev271}/docs/Makefile +0 -0
- {jaxsim-0.4.3.dev245 → jaxsim-0.4.3.dev271}/docs/conf.py +0 -0
- {jaxsim-0.4.3.dev245 → jaxsim-0.4.3.dev271}/docs/examples.rst +0 -0
- {jaxsim-0.4.3.dev245 → jaxsim-0.4.3.dev271}/docs/guide/install.rst +0 -0
- {jaxsim-0.4.3.dev245 → jaxsim-0.4.3.dev271}/docs/index.rst +0 -0
- {jaxsim-0.4.3.dev245 → jaxsim-0.4.3.dev271}/docs/make.bat +0 -0
- {jaxsim-0.4.3.dev245 → jaxsim-0.4.3.dev271}/docs/modules/api.rst +0 -0
- {jaxsim-0.4.3.dev245 → jaxsim-0.4.3.dev271}/docs/modules/integrators.rst +0 -0
- {jaxsim-0.4.3.dev245 → jaxsim-0.4.3.dev271}/docs/modules/math.rst +0 -0
- {jaxsim-0.4.3.dev245 → jaxsim-0.4.3.dev271}/docs/modules/mujoco.rst +0 -0
- {jaxsim-0.4.3.dev245 → jaxsim-0.4.3.dev271}/docs/modules/parsers.rst +0 -0
- {jaxsim-0.4.3.dev245 → jaxsim-0.4.3.dev271}/docs/modules/rbda.rst +0 -0
- {jaxsim-0.4.3.dev245 → jaxsim-0.4.3.dev271}/docs/modules/typing.rst +0 -0
- {jaxsim-0.4.3.dev245 → jaxsim-0.4.3.dev271}/docs/modules/utils.rst +0 -0
- {jaxsim-0.4.3.dev245 → jaxsim-0.4.3.dev271}/environment.yml +0 -0
- {jaxsim-0.4.3.dev245 → jaxsim-0.4.3.dev271}/examples/.gitattributes +0 -0
- {jaxsim-0.4.3.dev245 → jaxsim-0.4.3.dev271}/examples/.gitignore +0 -0
- {jaxsim-0.4.3.dev245 → jaxsim-0.4.3.dev271}/examples/README.md +0 -0
- {jaxsim-0.4.3.dev245 → jaxsim-0.4.3.dev271}/examples/assets/build_cartpole_urdf.py +0 -0
- {jaxsim-0.4.3.dev245 → jaxsim-0.4.3.dev271}/examples/assets/cartpole.urdf +0 -0
- {jaxsim-0.4.3.dev245 → jaxsim-0.4.3.dev271}/examples/jaxsim_as_multibody_dynamics_library.ipynb +0 -0
- {jaxsim-0.4.3.dev245 → jaxsim-0.4.3.dev271}/pixi.lock +0 -0
- {jaxsim-0.4.3.dev245 → jaxsim-0.4.3.dev271}/pyproject.toml +0 -0
- {jaxsim-0.4.3.dev245 → jaxsim-0.4.3.dev271}/setup.cfg +0 -0
- {jaxsim-0.4.3.dev245 → jaxsim-0.4.3.dev271}/setup.py +0 -0
- {jaxsim-0.4.3.dev245 → jaxsim-0.4.3.dev271}/src/jaxsim/__init__.py +0 -0
- {jaxsim-0.4.3.dev245 → jaxsim-0.4.3.dev271}/src/jaxsim/api/__init__.py +0 -0
- {jaxsim-0.4.3.dev245 → jaxsim-0.4.3.dev271}/src/jaxsim/api/com.py +0 -0
- {jaxsim-0.4.3.dev245 → jaxsim-0.4.3.dev271}/src/jaxsim/api/common.py +0 -0
- {jaxsim-0.4.3.dev245 → jaxsim-0.4.3.dev271}/src/jaxsim/api/contact.py +0 -0
- {jaxsim-0.4.3.dev245 → jaxsim-0.4.3.dev271}/src/jaxsim/api/data.py +0 -0
- {jaxsim-0.4.3.dev245 → jaxsim-0.4.3.dev271}/src/jaxsim/api/frame.py +0 -0
- {jaxsim-0.4.3.dev245 → jaxsim-0.4.3.dev271}/src/jaxsim/api/joint.py +0 -0
- {jaxsim-0.4.3.dev245 → jaxsim-0.4.3.dev271}/src/jaxsim/api/kin_dyn_parameters.py +0 -0
- {jaxsim-0.4.3.dev245 → jaxsim-0.4.3.dev271}/src/jaxsim/api/link.py +0 -0
- {jaxsim-0.4.3.dev245 → jaxsim-0.4.3.dev271}/src/jaxsim/api/ode_data.py +0 -0
- {jaxsim-0.4.3.dev245 → jaxsim-0.4.3.dev271}/src/jaxsim/api/references.py +0 -0
- {jaxsim-0.4.3.dev245 → jaxsim-0.4.3.dev271}/src/jaxsim/exceptions.py +0 -0
- {jaxsim-0.4.3.dev245 → jaxsim-0.4.3.dev271}/src/jaxsim/integrators/__init__.py +0 -0
- {jaxsim-0.4.3.dev245 → jaxsim-0.4.3.dev271}/src/jaxsim/integrators/fixed_step.py +0 -0
- {jaxsim-0.4.3.dev245 → jaxsim-0.4.3.dev271}/src/jaxsim/logging.py +0 -0
- {jaxsim-0.4.3.dev245 → jaxsim-0.4.3.dev271}/src/jaxsim/math/__init__.py +0 -0
- {jaxsim-0.4.3.dev245 → jaxsim-0.4.3.dev271}/src/jaxsim/math/adjoint.py +0 -0
- {jaxsim-0.4.3.dev245 → jaxsim-0.4.3.dev271}/src/jaxsim/math/cross.py +0 -0
- {jaxsim-0.4.3.dev245 → jaxsim-0.4.3.dev271}/src/jaxsim/math/inertia.py +0 -0
- {jaxsim-0.4.3.dev245 → jaxsim-0.4.3.dev271}/src/jaxsim/math/joint_model.py +0 -0
- {jaxsim-0.4.3.dev245 → jaxsim-0.4.3.dev271}/src/jaxsim/math/quaternion.py +0 -0
- {jaxsim-0.4.3.dev245 → jaxsim-0.4.3.dev271}/src/jaxsim/math/rotation.py +0 -0
- {jaxsim-0.4.3.dev245 → jaxsim-0.4.3.dev271}/src/jaxsim/math/skew.py +0 -0
- {jaxsim-0.4.3.dev245 → jaxsim-0.4.3.dev271}/src/jaxsim/math/transform.py +0 -0
- {jaxsim-0.4.3.dev245 → jaxsim-0.4.3.dev271}/src/jaxsim/mujoco/__init__.py +0 -0
- {jaxsim-0.4.3.dev245 → jaxsim-0.4.3.dev271}/src/jaxsim/mujoco/__main__.py +0 -0
- {jaxsim-0.4.3.dev245 → jaxsim-0.4.3.dev271}/src/jaxsim/mujoco/loaders.py +0 -0
- {jaxsim-0.4.3.dev245 → jaxsim-0.4.3.dev271}/src/jaxsim/mujoco/model.py +0 -0
- {jaxsim-0.4.3.dev245 → jaxsim-0.4.3.dev271}/src/jaxsim/mujoco/utils.py +0 -0
- {jaxsim-0.4.3.dev245 → jaxsim-0.4.3.dev271}/src/jaxsim/mujoco/visualizer.py +0 -0
- {jaxsim-0.4.3.dev245 → jaxsim-0.4.3.dev271}/src/jaxsim/parsers/__init__.py +0 -0
- {jaxsim-0.4.3.dev245 → jaxsim-0.4.3.dev271}/src/jaxsim/parsers/descriptions/__init__.py +0 -0
- {jaxsim-0.4.3.dev245 → jaxsim-0.4.3.dev271}/src/jaxsim/parsers/descriptions/collision.py +0 -0
- {jaxsim-0.4.3.dev245 → jaxsim-0.4.3.dev271}/src/jaxsim/parsers/descriptions/joint.py +0 -0
- {jaxsim-0.4.3.dev245 → jaxsim-0.4.3.dev271}/src/jaxsim/parsers/descriptions/link.py +0 -0
- {jaxsim-0.4.3.dev245 → jaxsim-0.4.3.dev271}/src/jaxsim/parsers/descriptions/model.py +0 -0
- {jaxsim-0.4.3.dev245 → jaxsim-0.4.3.dev271}/src/jaxsim/parsers/kinematic_graph.py +0 -0
- {jaxsim-0.4.3.dev245 → jaxsim-0.4.3.dev271}/src/jaxsim/parsers/rod/__init__.py +0 -0
- {jaxsim-0.4.3.dev245 → jaxsim-0.4.3.dev271}/src/jaxsim/parsers/rod/parser.py +0 -0
- {jaxsim-0.4.3.dev245 → jaxsim-0.4.3.dev271}/src/jaxsim/parsers/rod/utils.py +0 -0
- {jaxsim-0.4.3.dev245 → jaxsim-0.4.3.dev271}/src/jaxsim/rbda/__init__.py +0 -0
- {jaxsim-0.4.3.dev245 → jaxsim-0.4.3.dev271}/src/jaxsim/rbda/aba.py +0 -0
- {jaxsim-0.4.3.dev245 → jaxsim-0.4.3.dev271}/src/jaxsim/rbda/collidable_points.py +0 -0
- {jaxsim-0.4.3.dev245 → jaxsim-0.4.3.dev271}/src/jaxsim/rbda/contacts/__init__.py +0 -0
- {jaxsim-0.4.3.dev245 → jaxsim-0.4.3.dev271}/src/jaxsim/rbda/contacts/common.py +0 -0
- {jaxsim-0.4.3.dev245 → jaxsim-0.4.3.dev271}/src/jaxsim/rbda/contacts/relaxed_rigid.py +0 -0
- {jaxsim-0.4.3.dev245 → jaxsim-0.4.3.dev271}/src/jaxsim/rbda/contacts/rigid.py +0 -0
- {jaxsim-0.4.3.dev245 → jaxsim-0.4.3.dev271}/src/jaxsim/rbda/contacts/soft.py +0 -0
- {jaxsim-0.4.3.dev245 → jaxsim-0.4.3.dev271}/src/jaxsim/rbda/contacts/visco_elastic.py +0 -0
- {jaxsim-0.4.3.dev245 → jaxsim-0.4.3.dev271}/src/jaxsim/rbda/crba.py +0 -0
- {jaxsim-0.4.3.dev245 → jaxsim-0.4.3.dev271}/src/jaxsim/rbda/forward_kinematics.py +0 -0
- {jaxsim-0.4.3.dev245 → jaxsim-0.4.3.dev271}/src/jaxsim/rbda/jacobian.py +0 -0
- {jaxsim-0.4.3.dev245 → jaxsim-0.4.3.dev271}/src/jaxsim/rbda/rnea.py +0 -0
- {jaxsim-0.4.3.dev245 → jaxsim-0.4.3.dev271}/src/jaxsim/rbda/utils.py +0 -0
- {jaxsim-0.4.3.dev245 → jaxsim-0.4.3.dev271}/src/jaxsim/terrain/__init__.py +0 -0
- {jaxsim-0.4.3.dev245 → jaxsim-0.4.3.dev271}/src/jaxsim/terrain/terrain.py +0 -0
- {jaxsim-0.4.3.dev245 → jaxsim-0.4.3.dev271}/src/jaxsim/typing.py +0 -0
- {jaxsim-0.4.3.dev245 → jaxsim-0.4.3.dev271}/src/jaxsim/utils/__init__.py +0 -0
- {jaxsim-0.4.3.dev245 → jaxsim-0.4.3.dev271}/src/jaxsim/utils/jaxsim_dataclass.py +0 -0
- {jaxsim-0.4.3.dev245 → jaxsim-0.4.3.dev271}/src/jaxsim/utils/tracing.py +0 -0
- {jaxsim-0.4.3.dev245 → jaxsim-0.4.3.dev271}/src/jaxsim/utils/wrappers.py +0 -0
- {jaxsim-0.4.3.dev245 → jaxsim-0.4.3.dev271}/src/jaxsim.egg-info/SOURCES.txt +0 -0
- {jaxsim-0.4.3.dev245 → jaxsim-0.4.3.dev271}/src/jaxsim.egg-info/dependency_links.txt +0 -0
- {jaxsim-0.4.3.dev245 → jaxsim-0.4.3.dev271}/src/jaxsim.egg-info/requires.txt +0 -0
- {jaxsim-0.4.3.dev245 → jaxsim-0.4.3.dev271}/src/jaxsim.egg-info/top_level.txt +0 -0
- {jaxsim-0.4.3.dev245 → jaxsim-0.4.3.dev271}/tests/__init__.py +0 -0
- {jaxsim-0.4.3.dev245 → jaxsim-0.4.3.dev271}/tests/conftest.py +0 -0
- {jaxsim-0.4.3.dev245 → jaxsim-0.4.3.dev271}/tests/test_api_com.py +0 -0
- {jaxsim-0.4.3.dev245 → jaxsim-0.4.3.dev271}/tests/test_api_contact.py +0 -0
- {jaxsim-0.4.3.dev245 → jaxsim-0.4.3.dev271}/tests/test_api_data.py +0 -0
- {jaxsim-0.4.3.dev245 → jaxsim-0.4.3.dev271}/tests/test_api_frame.py +0 -0
- {jaxsim-0.4.3.dev245 → jaxsim-0.4.3.dev271}/tests/test_api_joint.py +0 -0
- {jaxsim-0.4.3.dev245 → jaxsim-0.4.3.dev271}/tests/test_api_link.py +0 -0
- {jaxsim-0.4.3.dev245 → jaxsim-0.4.3.dev271}/tests/test_api_model.py +0 -0
- {jaxsim-0.4.3.dev245 → jaxsim-0.4.3.dev271}/tests/test_contact.py +0 -0
- {jaxsim-0.4.3.dev245 → jaxsim-0.4.3.dev271}/tests/test_exceptions.py +0 -0
- {jaxsim-0.4.3.dev245 → jaxsim-0.4.3.dev271}/tests/test_pytree.py +0 -0
- {jaxsim-0.4.3.dev245 → jaxsim-0.4.3.dev271}/tests/utils_idyntree.py +0 -0
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.1
|
2
2
|
Name: jaxsim
|
3
|
-
Version: 0.4.3.
|
3
|
+
Version: 0.4.3.dev271
|
4
4
|
Summary: A differentiable physics engine and multibody dynamics library for control and robot learning.
|
5
5
|
Author-email: Diego Ferigo <dgferigo@gmail.com>
|
6
6
|
Maintainer-email: Diego Ferigo <dgferigo@gmail.com>, Filippo Luca Ferretti <filippo.ferretti@iit.it>
|
@@ -63,8 +63,9 @@
|
|
63
63
|
"import jax\n",
|
64
64
|
"import jax.numpy as jnp\n",
|
65
65
|
"import jaxsim.api as js\n",
|
66
|
+
"import jaxsim\n",
|
66
67
|
"import rod\n",
|
67
|
-
"from jaxsim import
|
68
|
+
"from jaxsim import logging\n",
|
68
69
|
"from rod.builder.primitives import SphereBuilder\n",
|
69
70
|
"\n",
|
70
71
|
"logging.set_logging_level(logging.LoggingLevel.WARNING)\n",
|
@@ -142,8 +143,11 @@
|
|
142
143
|
"\n",
|
143
144
|
"- `model`: an object that defines the dynamics of the system.\n",
|
144
145
|
"- `data`: an object that contains the state of the system.\n",
|
145
|
-
"- `integrator
|
146
|
-
"- `integrator_state
|
146
|
+
"- `integrator` *(Optional)*: an object that defines the integration method.\n",
|
147
|
+
"- `integrator_state` *(Optional)*: an object that contains the state of the integrator.\n",
|
148
|
+
"\n",
|
149
|
+
"The `JaxSimModel` object contains the simulation time step, the integrator and the contact model.\n",
|
150
|
+
"In this example, we will explicity pass an integrator class to the `model` object and we will use the default `SoftContacts` contact model."
|
147
151
|
]
|
148
152
|
},
|
149
153
|
{
|
@@ -157,7 +161,9 @@
|
|
157
161
|
"# Create the JaxSim model.\n",
|
158
162
|
"# This is shared among all the parallel instances.\n",
|
159
163
|
"model = js.model.JaxSimModel.build_from_model_description(\n",
|
160
|
-
" model_description=model_sdf_string
|
164
|
+
" model_description=model_sdf_string,\n",
|
165
|
+
" time_step=0.001,\n",
|
166
|
+
" integrator=jaxsim.integrators.fixed_step.Heun2,\n",
|
161
167
|
")\n",
|
162
168
|
"\n",
|
163
169
|
"# Create the data of a single model.\n",
|
@@ -240,21 +246,7 @@
|
|
240
246
|
},
|
241
247
|
"outputs": [],
|
242
248
|
"source": [
|
243
|
-
"
|
244
|
-
"integrator = integrators.fixed_step.Heun2SO3.build(\n",
|
245
|
-
" dynamics=js.ode.wrap_system_dynamics_for_integration(\n",
|
246
|
-
" model=model,\n",
|
247
|
-
" data=data_single,\n",
|
248
|
-
" system_dynamics=js.ode.system_dynamics,\n",
|
249
|
-
" ),\n",
|
250
|
-
")\n",
|
251
|
-
"\n",
|
252
|
-
"# Initialize the integrator.\n",
|
253
|
-
"integrator_state = integrator.init(\n",
|
254
|
-
" x0=data_single.state,\n",
|
255
|
-
" t0=0.0,\n",
|
256
|
-
" dt=model.time_step,\n",
|
257
|
-
")\n",
|
249
|
+
"print(f\"Using integrator: {model.integrator}\")\n",
|
258
250
|
"\n",
|
259
251
|
"# Initialize the simulated time.\n",
|
260
252
|
"T = jnp.arange(start=0, stop=1.0, step=model.time_step)"
|
@@ -324,46 +316,42 @@
|
|
324
316
|
"def step_single(\n",
|
325
317
|
" model: js.model.JaxSimModel,\n",
|
326
318
|
" data: js.data.JaxSimModelData,\n",
|
327
|
-
" integrator_state: dict[str, Any],\n",
|
328
319
|
") -> tuple[js.data.JaxSimModelData, dict[str, Any]]:\n",
|
329
320
|
"\n",
|
330
321
|
" # Close step over static arguments.\n",
|
331
322
|
" return js.model.step(\n",
|
332
323
|
" model=model,\n",
|
333
324
|
" data=data,\n",
|
334
|
-
" integrator=integrator,\n",
|
335
|
-
" integrator_state=integrator_state,\n",
|
336
325
|
" link_forces=None,\n",
|
337
326
|
" joint_force_references=None,\n",
|
338
327
|
" )\n",
|
339
328
|
"\n",
|
340
329
|
"\n",
|
341
330
|
"@jax.jit\n",
|
342
|
-
"@functools.partial(jax.vmap, in_axes=(None, 0
|
331
|
+
"@functools.partial(jax.vmap, in_axes=(None, 0))\n",
|
343
332
|
"def step_parallel(\n",
|
344
333
|
" model: js.model.JaxSimModel,\n",
|
345
334
|
" data: js.data.JaxSimModelData,\n",
|
346
|
-
" integrator_state: dict[str, Any],\n",
|
347
335
|
") -> tuple[js.data.JaxSimModelData, dict[str, Any]]:\n",
|
348
336
|
"\n",
|
349
337
|
" return step_single(\n",
|
350
|
-
" model=model, data=data
|
338
|
+
" model=model, data=data\n",
|
351
339
|
" )\n",
|
352
340
|
"\n",
|
353
341
|
"\n",
|
354
342
|
"# The first run will be slow since JAX needs to JIT-compile the functions.\n",
|
355
|
-
"_ = step_single(model, data_single
|
356
|
-
"_ = step_parallel(model, data_batch_t0
|
343
|
+
"_ = step_single(model, data_single)\n",
|
344
|
+
"_ = step_parallel(model, data_batch_t0)\n",
|
357
345
|
"\n",
|
358
346
|
"# Benchmark the execution of a single step.\n",
|
359
347
|
"print(\"\\nSingle simulation step:\")\n",
|
360
|
-
"%timeit step_single(model, data_single
|
348
|
+
"%timeit step_single(model, data_single)\n",
|
361
349
|
"\n",
|
362
350
|
"# On hardware accelerators, there's a range of batch_size values where\n",
|
363
351
|
"# increasing the number of parallel instances doesn't affect computation time.\n",
|
364
352
|
"# This range depends on the GPU/TPU specifications.\n",
|
365
353
|
"print(f\"\\nParallel simulation steps (batch_size={batch_size} on {jax.devices()[0]}):\")\n",
|
366
|
-
"%timeit step_parallel(model, data_batch_t0
|
354
|
+
"%timeit step_parallel(model, data_batch_t0)"
|
367
355
|
]
|
368
356
|
},
|
369
357
|
{
|
@@ -381,7 +369,7 @@
|
|
381
369
|
"\n",
|
382
370
|
"for _ in T:\n",
|
383
371
|
"\n",
|
384
|
-
" data,
|
372
|
+
" data, _ = step_parallel(model, data)\n",
|
385
373
|
" data_trajectory_list.append(data)"
|
386
374
|
]
|
387
375
|
},
|
@@ -404,9 +392,7 @@
|
|
404
392
|
"source": [
|
405
393
|
"# Convert a list of PyTrees to a batched PyTree.\n",
|
406
394
|
"# This operation is called 'tree transpose' in JAX.\n",
|
407
|
-
"data_trajectory = jax.tree.map(\n",
|
408
|
-
" lambda *leafs: jnp.stack(leafs), *data_trajectory_list\n",
|
409
|
-
")\n",
|
395
|
+
"data_trajectory = jax.tree.map(lambda *leafs: jnp.stack(leafs), *data_trajectory_list)\n",
|
410
396
|
"\n",
|
411
397
|
"print(f\"W_p_B: shape={data_trajectory.base_position().shape}\")"
|
412
398
|
]
|
@@ -448,6 +434,11 @@
|
|
448
434
|
"\n",
|
449
435
|
"Have fun!"
|
450
436
|
]
|
437
|
+
},
|
438
|
+
{
|
439
|
+
"cell_type": "markdown",
|
440
|
+
"metadata": {},
|
441
|
+
"source": []
|
451
442
|
}
|
452
443
|
],
|
453
444
|
"metadata": {
|
@@ -459,7 +450,8 @@
|
|
459
450
|
"toc_visible": true
|
460
451
|
},
|
461
452
|
"kernelspec": {
|
462
|
-
"display_name": "
|
453
|
+
"display_name": "jaxsim",
|
454
|
+
"language": "python",
|
463
455
|
"name": "python3"
|
464
456
|
},
|
465
457
|
"language_info": {
|
@@ -472,7 +464,7 @@
|
|
472
464
|
"name": "python",
|
473
465
|
"nbconvert_exporter": "python",
|
474
466
|
"pygments_lexer": "ipython3",
|
475
|
-
"version": "3.
|
467
|
+
"version": "3.12.7"
|
476
468
|
}
|
477
469
|
},
|
478
470
|
"nbformat": 4,
|
@@ -123,7 +123,6 @@
|
|
123
123
|
"# @title Create the model and its data\n",
|
124
124
|
"\n",
|
125
125
|
"import jaxsim.api as js\n",
|
126
|
-
"from jaxsim import integrators\n",
|
127
126
|
"\n",
|
128
127
|
"# Create the model from the model description.\n",
|
129
128
|
"model = js.model.JaxSimModel.build_from_model_description(\n",
|
@@ -143,23 +142,7 @@
|
|
143
142
|
},
|
144
143
|
"outputs": [],
|
145
144
|
"source": [
|
146
|
-
"# @title
|
147
|
-
"\n",
|
148
|
-
"# Create the integrator.\n",
|
149
|
-
"integrator = integrators.fixed_step.RungeKutta4.build(\n",
|
150
|
-
" dynamics=js.ode.wrap_system_dynamics_for_integration(\n",
|
151
|
-
" model=model,\n",
|
152
|
-
" data=data_zero,\n",
|
153
|
-
" system_dynamics=js.ode.system_dynamics,\n",
|
154
|
-
" ),\n",
|
155
|
-
")\n",
|
156
|
-
"\n",
|
157
|
-
"# Initialize the integrator.\n",
|
158
|
-
"integrator_state = integrator.init(\n",
|
159
|
-
" x0=data_zero.state,\n",
|
160
|
-
" t0=0.0,\n",
|
161
|
-
" dt=model.time_step,\n",
|
162
|
-
")\n",
|
145
|
+
"# @title Define simulation parameters\n",
|
163
146
|
"\n",
|
164
147
|
"# Initialize the simulated time.\n",
|
165
148
|
"T = jnp.arange(start=0, stop=5.0, step=model.time_step)"
|
@@ -255,11 +238,9 @@
|
|
255
238
|
"for _ in T:\n",
|
256
239
|
"\n",
|
257
240
|
" # Step the JaxSim simulation.\n",
|
258
|
-
" data,
|
241
|
+
" data, _ = js.model.step(\n",
|
259
242
|
" model=model,\n",
|
260
243
|
" data=data,\n",
|
261
|
-
" integrator=integrator,\n",
|
262
|
-
" integrator_state=integrator_state,\n",
|
263
244
|
" joint_force_references=None,\n",
|
264
245
|
" link_forces=None,\n",
|
265
246
|
" )\n",
|
@@ -359,7 +340,7 @@
|
|
359
340
|
" ṡ = data.joint_velocities()\n",
|
360
341
|
"\n",
|
361
342
|
" # Compute the actuated joint torques.\n",
|
362
|
-
" s_star = -
|
343
|
+
" s_star = -kp * (s - s_des) - kd * (ṡ - s_dot_des)\n",
|
363
344
|
" τ = Mss @ s_star + hs\n",
|
364
345
|
"\n",
|
365
346
|
" return τ"
|
@@ -407,11 +388,9 @@
|
|
407
388
|
" )\n",
|
408
389
|
"\n",
|
409
390
|
" # Step the JaxSim simulation.\n",
|
410
|
-
" data,
|
391
|
+
" data, _ = js.model.step(\n",
|
411
392
|
" model=model,\n",
|
412
393
|
" data=data,\n",
|
413
|
-
" integrator=integrator,\n",
|
414
|
-
" integrator_state=integrator_state,\n",
|
415
394
|
" joint_force_references=τ,\n",
|
416
395
|
" )\n",
|
417
396
|
"\n",
|
@@ -461,7 +440,8 @@
|
|
461
440
|
"toc_visible": true
|
462
441
|
},
|
463
442
|
"kernelspec": {
|
464
|
-
"display_name": "
|
443
|
+
"display_name": "jaxsim",
|
444
|
+
"language": "python",
|
465
445
|
"name": "python3"
|
466
446
|
},
|
467
447
|
"language_info": {
|
@@ -474,7 +454,7 @@
|
|
474
454
|
"name": "python",
|
475
455
|
"nbconvert_exporter": "python",
|
476
456
|
"pygments_lexer": "ipython3",
|
477
|
-
"version": "3.
|
457
|
+
"version": "3.12.7"
|
478
458
|
}
|
479
459
|
},
|
480
460
|
"nbformat": 4,
|
@@ -12,5 +12,5 @@ __version__: str
|
|
12
12
|
__version_tuple__: VERSION_TUPLE
|
13
13
|
version_tuple: VERSION_TUPLE
|
14
14
|
|
15
|
-
__version__ = version = '0.4.3.
|
16
|
-
__version_tuple__ = version_tuple = (0, 4, 3, '
|
15
|
+
__version__ = version = '0.4.3.dev271'
|
16
|
+
__version_tuple__ = version_tuple = (0, 4, 3, 'dev271')
|
@@ -54,6 +54,10 @@ class JaxSimModel(JaxsimDataclass):
|
|
54
54
|
default=None, repr=False
|
55
55
|
)
|
56
56
|
|
57
|
+
integrator: Static[jaxsim.integrators.Integrator | None] = dataclasses.field(
|
58
|
+
default=None, repr=False
|
59
|
+
)
|
60
|
+
|
57
61
|
_description: Static[wrappers.HashlessObject[ModelDescription | None]] = (
|
58
62
|
dataclasses.field(default=None, repr=False)
|
59
63
|
)
|
@@ -93,12 +97,16 @@ class JaxSimModel(JaxsimDataclass):
|
|
93
97
|
# Initialization and state
|
94
98
|
# ========================
|
95
99
|
|
96
|
-
@
|
100
|
+
@classmethod
|
97
101
|
def build_from_model_description(
|
102
|
+
cls,
|
98
103
|
model_description: str | pathlib.Path | rod.Model,
|
99
|
-
model_name: str | None = None,
|
100
104
|
*,
|
105
|
+
model_name: str | None = None,
|
101
106
|
time_step: jtp.FloatLike | None = None,
|
107
|
+
integrator: (
|
108
|
+
jaxsim.integrators.Integrator | type[jaxsim.integrators.Integrator] | None
|
109
|
+
) = None,
|
102
110
|
terrain: jaxsim.terrain.Terrain | None = None,
|
103
111
|
contact_model: jaxsim.rbda.contacts.ContactModel | None = None,
|
104
112
|
is_urdf: bool | None = None,
|
@@ -120,6 +128,10 @@ class JaxSimModel(JaxsimDataclass):
|
|
120
128
|
contact_model:
|
121
129
|
The contact model to consider.
|
122
130
|
If not specified, a soft contacts model is used.
|
131
|
+
integrator:
|
132
|
+
The integrator to use. If not specified, a default one is used.
|
133
|
+
This argument can either be a pre-built integrator instance or one
|
134
|
+
of the integrator classes defined in JaxSim.
|
123
135
|
is_urdf:
|
124
136
|
The optional flag to force the model description to be parsed as a URDF.
|
125
137
|
This is usually automatically inferred.
|
@@ -146,10 +158,11 @@ class JaxSimModel(JaxsimDataclass):
|
|
146
158
|
)
|
147
159
|
|
148
160
|
# Build the model.
|
149
|
-
model =
|
161
|
+
model = cls.build(
|
150
162
|
model_description=intermediate_description,
|
151
163
|
model_name=model_name,
|
152
164
|
time_step=time_step,
|
165
|
+
integrator=integrator,
|
153
166
|
terrain=terrain,
|
154
167
|
contact_model=contact_model,
|
155
168
|
)
|
@@ -160,12 +173,16 @@ class JaxSimModel(JaxsimDataclass):
|
|
160
173
|
|
161
174
|
return model
|
162
175
|
|
163
|
-
@
|
176
|
+
@classmethod
|
164
177
|
def build(
|
178
|
+
cls,
|
165
179
|
model_description: ModelDescription,
|
166
|
-
model_name: str | None = None,
|
167
180
|
*,
|
181
|
+
model_name: str | None = None,
|
168
182
|
time_step: jtp.FloatLike | None = None,
|
183
|
+
integrator: (
|
184
|
+
jaxsim.integrators.Integrator | type[jaxsim.integrators.Integrator] | None
|
185
|
+
) = None,
|
169
186
|
terrain: jaxsim.terrain.Terrain | None = None,
|
170
187
|
contact_model: jaxsim.rbda.contacts.ContactModel | None = None,
|
171
188
|
) -> JaxSimModel:
|
@@ -182,6 +199,11 @@ class JaxSimModel(JaxsimDataclass):
|
|
182
199
|
The default time step to consider for the simulation. It can be
|
183
200
|
manually overridden in the function that steps the simulation.
|
184
201
|
terrain: The terrain to consider (the default is a flat infinite plane).
|
202
|
+
The optional name of the model overriding the physics model name.
|
203
|
+
integrator:
|
204
|
+
The integrator to use. If not specified, a default one is used.
|
205
|
+
This argument can either be a pre-built integrator instance or one
|
206
|
+
of the integrator classes defined in JaxSim.
|
185
207
|
contact_model:
|
186
208
|
The contact model to consider.
|
187
209
|
If not specified, a soft contacts model is used.
|
@@ -195,23 +217,62 @@ class JaxSimModel(JaxsimDataclass):
|
|
195
217
|
|
196
218
|
# Consider the default terrain (a flat infinite plane) if not specified.
|
197
219
|
terrain = (
|
198
|
-
terrain
|
220
|
+
terrain
|
221
|
+
if terrain is not None
|
222
|
+
else JaxSimModel.__dataclass_fields__["terrain"].default_factory()
|
199
223
|
)
|
200
224
|
|
201
225
|
# Consider the default time step if not specified.
|
202
226
|
time_step = (
|
203
|
-
time_step
|
227
|
+
time_step
|
228
|
+
if time_step is not None
|
229
|
+
else JaxSimModel.__dataclass_fields__["time_step"].default_factory()
|
204
230
|
)
|
205
231
|
|
206
232
|
# Create the default contact model.
|
207
233
|
# It will be populated with an initial estimation of good parameters.
|
208
234
|
# While these might not be the best, they are a good starting point.
|
209
|
-
contact_model =
|
210
|
-
|
235
|
+
contact_model = (
|
236
|
+
contact_model
|
237
|
+
if contact_model is not None
|
238
|
+
else jaxsim.rbda.contacts.SoftContacts.build(
|
239
|
+
terrain=terrain, parameters=None
|
240
|
+
)
|
211
241
|
)
|
212
242
|
|
243
|
+
# Build the integrator if not provided.
|
244
|
+
match integrator:
|
245
|
+
|
246
|
+
# If None, build a default integrator.
|
247
|
+
case None:
|
248
|
+
|
249
|
+
integrator = jaxsim.integrators.fixed_step.Heun2SO3.build(
|
250
|
+
dynamics=js.ode.wrap_system_dynamics_for_integration(
|
251
|
+
system_dynamics=js.ode.system_dynamics
|
252
|
+
)
|
253
|
+
)
|
254
|
+
|
255
|
+
# If it's a pre-built integrator (also a custom one from the user)
|
256
|
+
# just use it as is.
|
257
|
+
case _ if isinstance(integrator, jaxsim.integrators.Integrator):
|
258
|
+
pass
|
259
|
+
|
260
|
+
# If an integrator class is passed, assume that it is a JaxSim integrator
|
261
|
+
# and build it with the default system dynamics.
|
262
|
+
case _ if issubclass(integrator, jaxsim.integrators.Integrator):
|
263
|
+
|
264
|
+
integrator_cls = integrator
|
265
|
+
integrator = integrator_cls.build(
|
266
|
+
dynamics=js.ode.wrap_system_dynamics_for_integration(
|
267
|
+
system_dynamics=js.ode.system_dynamics
|
268
|
+
)
|
269
|
+
)
|
270
|
+
|
271
|
+
case _:
|
272
|
+
raise ValueError(f"Invalid integrator: {integrator}")
|
273
|
+
|
213
274
|
# Build the model.
|
214
|
-
model =
|
275
|
+
model = cls(
|
215
276
|
model_name=model_name,
|
216
277
|
kin_dyn_parameters=js.kin_dyn_parameters.KynDynParameters.build(
|
217
278
|
model_description=model_description
|
@@ -219,6 +280,7 @@ class JaxSimModel(JaxsimDataclass):
|
|
219
280
|
time_step=time_step,
|
220
281
|
terrain=terrain,
|
221
282
|
contact_model=contact_model,
|
283
|
+
integrator=integrator,
|
222
284
|
# The following is wrapped as hashless since it's a static argument, and we
|
223
285
|
# don't want to trigger recompilation if it changes. All relevant parameters
|
224
286
|
# needed to compute kinematics and dynamics quantities are stored in the
|
@@ -404,6 +466,7 @@ def reduce(
|
|
404
466
|
reduced_model = JaxSimModel.build(
|
405
467
|
model_description=reduced_intermediate_description,
|
406
468
|
model_name=model.name(),
|
469
|
+
time_step=model.time_step,
|
407
470
|
terrain=model.terrain,
|
408
471
|
contact_model=model.contact_model,
|
409
472
|
)
|
@@ -1912,10 +1975,10 @@ def step(
|
|
1912
1975
|
model: JaxSimModel,
|
1913
1976
|
data: js.data.JaxSimModelData,
|
1914
1977
|
*,
|
1915
|
-
integrator: jaxsim.integrators.Integrator,
|
1916
1978
|
t0: jtp.FloatLike = 0.0,
|
1917
1979
|
dt: jtp.FloatLike | None = None,
|
1918
|
-
|
1980
|
+
integrator: jaxsim.integrators.Integrator | None = None,
|
1981
|
+
integrator_metadata: dict[str, Any] | None = None,
|
1919
1982
|
link_forces: jtp.MatrixLike | None = None,
|
1920
1983
|
joint_force_references: jtp.VectorLike | None = None,
|
1921
1984
|
**kwargs,
|
@@ -1927,7 +1990,7 @@ def step(
|
|
1927
1990
|
model: The model to consider.
|
1928
1991
|
data: The data of the considered model.
|
1929
1992
|
integrator: The integrator to use.
|
1930
|
-
|
1993
|
+
integrator_metadata: The metadata of the integrator, if needed.
|
1931
1994
|
t0: The initial time to consider. Only relevant for time-dependent dynamics.
|
1932
1995
|
dt: The time step to consider. If not specified, it is read from the model.
|
1933
1996
|
link_forces:
|
@@ -1937,8 +2000,9 @@ def step(
|
|
1937
2000
|
kwargs: Additional kwargs to pass to the integrator.
|
1938
2001
|
|
1939
2002
|
Returns:
|
1940
|
-
A tuple containing the new data of the model
|
1941
|
-
|
2003
|
+
A tuple containing the new data of the model and a dictionary of auxiliary
|
2004
|
+
data computed during the step. If the integrator has metadata, the dictionary
|
2005
|
+
will contain the new metadata stored in the `integrator_metadata` key.
|
1942
2006
|
|
1943
2007
|
Note:
|
1944
2008
|
In order to reduce the occurrences of frame conversions performed internally,
|
@@ -1953,8 +2017,9 @@ def step(
|
|
1953
2017
|
integrator_kwargs = kwargs.pop("integrator_kwargs", {})
|
1954
2018
|
integrator_kwargs = kwargs | integrator_kwargs
|
1955
2019
|
|
1956
|
-
#
|
1957
|
-
|
2020
|
+
# Extract the integrator and the optional metadata.
|
2021
|
+
integrator_metadata_t0 = integrator_metadata
|
2022
|
+
integrator = integrator if integrator is not None else model.integrator
|
1958
2023
|
|
1959
2024
|
# Initialize the time-related variables.
|
1960
2025
|
state_t0 = data.state
|
@@ -2010,11 +2075,11 @@ def step(
|
|
2010
2075
|
τ_references = references.joint_force_references(model=model)
|
2011
2076
|
|
2012
2077
|
# Step the dynamics forward.
|
2013
|
-
state_tf,
|
2078
|
+
state_tf, integrator_metadata_tf = integrator.step(
|
2014
2079
|
x0=state_t0,
|
2015
2080
|
t0=t0,
|
2016
2081
|
dt=dt,
|
2017
|
-
|
2082
|
+
metadata=integrator_metadata_t0,
|
2018
2083
|
# Always inject the current (model, data) pair into the system dynamics
|
2019
2084
|
# considered by the integrator, and include the input variables represented
|
2020
2085
|
# by the pair (f_L, τ_references).
|
@@ -2091,13 +2156,17 @@ def step(
|
|
2091
2156
|
)
|
2092
2157
|
)
|
2093
2158
|
|
2094
|
-
|
2095
|
-
|
2096
|
-
|
2159
|
+
# Reset the generalized velocity.
|
2160
|
+
data_tf = data_tf.reset_base_velocity(BW_nu_post_impact[0:6])
|
2161
|
+
data_tf = data_tf.reset_joint_velocities(BW_nu_post_impact[6:])
|
2097
2162
|
|
2098
2163
|
# Restore the input velocity representation.
|
2099
2164
|
data_tf = data_tf.replace(
|
2100
2165
|
velocity_representation=data.velocity_representation, validate=False
|
2101
2166
|
)
|
2102
2167
|
|
2103
|
-
return data_tf,
|
2168
|
+
return data_tf, {} | (
|
2169
|
+
dict(integrator_metadata=integrator_metadata_tf)
|
2170
|
+
if integrator_metadata is not None
|
2171
|
+
else {}
|
2172
|
+
)
|
@@ -24,41 +24,45 @@ class SystemDynamicsFromModelAndData(Protocol):
|
|
24
24
|
|
25
25
|
|
26
26
|
def wrap_system_dynamics_for_integration(
|
27
|
-
model: js.model.JaxSimModel,
|
28
|
-
data: js.data.JaxSimModelData,
|
29
27
|
*,
|
30
28
|
system_dynamics: SystemDynamicsFromModelAndData,
|
31
|
-
**kwargs,
|
29
|
+
**kwargs: dict[str, Any],
|
32
30
|
) -> jaxsim.integrators.common.SystemDynamics[ODEState, ODEState]:
|
33
31
|
"""
|
34
|
-
Wrap
|
35
|
-
|
32
|
+
Wrap the system dynamics considered by JaxSim integrators in a generic
|
33
|
+
`f(x, t, **u, **parameters)` function.
|
36
34
|
|
37
35
|
Args:
|
38
|
-
model: The model to consider.
|
39
|
-
data: The data of the considered model.
|
40
36
|
system_dynamics: The system dynamics to wrap.
|
41
37
|
**kwargs: Additional kwargs to close over the system dynamics.
|
42
38
|
|
43
39
|
Returns:
|
44
|
-
The system dynamics closed over the
|
40
|
+
The system dynamics closed over the additional kwargs to be used by
|
41
|
+
JaxSim integrators.
|
45
42
|
"""
|
46
43
|
|
47
|
-
#
|
48
|
-
|
49
|
-
|
50
|
-
#
|
51
|
-
|
52
|
-
|
53
|
-
|
54
|
-
|
55
|
-
|
56
|
-
|
44
|
+
# Close `system_dynamics` over additional kwargs.
|
45
|
+
# Similarly to what is done in `jaxsim.api.model.step`, to be future-proof, we use the
|
46
|
+
# following logic to allow the caller to close over arguments having the same name
|
47
|
+
# as the ones used in the `wrap_system_dynamics_for_integration` function.
|
48
|
+
kwargs = kwargs.copy() if kwargs is not None else {}
|
49
|
+
colliding_system_dynamics_kwargs = kwargs.pop("system_dynamics_kwargs", {})
|
50
|
+
system_dynamics_kwargs = kwargs | colliding_system_dynamics_kwargs
|
51
|
+
|
52
|
+
# Remove `model` and `data` for backward compatibility.
|
53
|
+
# It's no longer necessary to close over them at this stage, as this is always
|
54
|
+
# done in `jaxsim.api.model.step`.
|
55
|
+
# We can remove the following lines in a few releases.
|
56
|
+
_ = system_dynamics_kwargs.pop("data", None)
|
57
|
+
_ = system_dynamics_kwargs.pop("model", None)
|
58
|
+
|
59
|
+
# Create the function with the signature expected by our generic integrators.
|
60
|
+
# Note that our system dynamics is time independent.
|
57
61
|
def f(x: ODEState, t: Time, **kwargs_f) -> tuple[ODEState, dict[str, Any]]:
|
58
62
|
|
59
|
-
#
|
60
|
-
data_f = kwargs_f.pop("data"
|
61
|
-
model_f = kwargs_f.pop("model"
|
63
|
+
# Get the data and model objects from the kwargs.
|
64
|
+
data_f = kwargs_f.pop("data")
|
65
|
+
model_f = kwargs_f.pop("model")
|
62
66
|
|
63
67
|
# Update the state and time stored inside data.
|
64
68
|
with data_f.editable(validate=True) as data_rw:
|
@@ -69,7 +73,7 @@ def wrap_system_dynamics_for_integration(
|
|
69
73
|
return system_dynamics(
|
70
74
|
model=model_f,
|
71
75
|
data=data_rw,
|
72
|
-
**(
|
76
|
+
**(system_dynamics_kwargs | kwargs_f),
|
73
77
|
)
|
74
78
|
|
75
79
|
f: jaxsim.integrators.common.SystemDynamics[ODEState, ODEState]
|