jaxsim 0.4.3.dev282__tar.gz → 0.4.3.dev295__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {jaxsim-0.4.3.dev282 → jaxsim-0.4.3.dev295}/PKG-INFO +1 -1
- {jaxsim-0.4.3.dev282 → jaxsim-0.4.3.dev295}/src/jaxsim/_version.py +2 -2
- {jaxsim-0.4.3.dev282 → jaxsim-0.4.3.dev295}/src/jaxsim/api/contact.py +8 -4
- {jaxsim-0.4.3.dev282 → jaxsim-0.4.3.dev295}/src/jaxsim/api/data.py +26 -20
- {jaxsim-0.4.3.dev282 → jaxsim-0.4.3.dev295}/src/jaxsim/api/joint.py +15 -5
- {jaxsim-0.4.3.dev282 → jaxsim-0.4.3.dev295}/src/jaxsim/api/kin_dyn_parameters.py +1 -3
- {jaxsim-0.4.3.dev282 → jaxsim-0.4.3.dev295}/src/jaxsim/api/link.py +3 -150
- {jaxsim-0.4.3.dev282 → jaxsim-0.4.3.dev295}/src/jaxsim/api/model.py +167 -12
- {jaxsim-0.4.3.dev282 → jaxsim-0.4.3.dev295}/src/jaxsim/api/references.py +29 -23
- {jaxsim-0.4.3.dev282 → jaxsim-0.4.3.dev295}/src/jaxsim/integrators/common.py +2 -3
- {jaxsim-0.4.3.dev282 → jaxsim-0.4.3.dev295}/src/jaxsim/math/adjoint.py +6 -5
- {jaxsim-0.4.3.dev282 → jaxsim-0.4.3.dev295}/src/jaxsim/parsers/rod/parser.py +1 -1
- {jaxsim-0.4.3.dev282 → jaxsim-0.4.3.dev295}/src/jaxsim/rbda/contacts/rigid.py +23 -47
- {jaxsim-0.4.3.dev282 → jaxsim-0.4.3.dev295}/src/jaxsim/rbda/contacts/visco_elastic.py +4 -2
- {jaxsim-0.4.3.dev282 → jaxsim-0.4.3.dev295}/src/jaxsim.egg-info/PKG-INFO +1 -1
- {jaxsim-0.4.3.dev282 → jaxsim-0.4.3.dev295}/tests/test_api_frame.py +1 -0
- {jaxsim-0.4.3.dev282 → jaxsim-0.4.3.dev295}/tests/test_api_joint.py +1 -0
- {jaxsim-0.4.3.dev282 → jaxsim-0.4.3.dev295}/tests/test_api_link.py +1 -0
- {jaxsim-0.4.3.dev282 → jaxsim-0.4.3.dev295}/.devcontainer/Dockerfile +0 -0
- {jaxsim-0.4.3.dev282 → jaxsim-0.4.3.dev295}/.devcontainer/devcontainer.json +0 -0
- {jaxsim-0.4.3.dev282 → jaxsim-0.4.3.dev295}/.gitattributes +0 -0
- {jaxsim-0.4.3.dev282 → jaxsim-0.4.3.dev295}/.github/CODEOWNERS +0 -0
- {jaxsim-0.4.3.dev282 → jaxsim-0.4.3.dev295}/.github/dependabot.yml +0 -0
- {jaxsim-0.4.3.dev282 → jaxsim-0.4.3.dev295}/.github/workflows/ci_cd.yml +0 -0
- {jaxsim-0.4.3.dev282 → jaxsim-0.4.3.dev295}/.github/workflows/read_the_docs.yml +0 -0
- {jaxsim-0.4.3.dev282 → jaxsim-0.4.3.dev295}/.github/workflows/update_pixi_lockfile.yml +0 -0
- {jaxsim-0.4.3.dev282 → jaxsim-0.4.3.dev295}/.gitignore +0 -0
- {jaxsim-0.4.3.dev282 → jaxsim-0.4.3.dev295}/.pre-commit-config.yaml +0 -0
- {jaxsim-0.4.3.dev282 → jaxsim-0.4.3.dev295}/.readthedocs.yaml +0 -0
- {jaxsim-0.4.3.dev282 → jaxsim-0.4.3.dev295}/CONTRIBUTING.md +0 -0
- {jaxsim-0.4.3.dev282 → jaxsim-0.4.3.dev295}/LICENSE +0 -0
- {jaxsim-0.4.3.dev282 → jaxsim-0.4.3.dev295}/README.md +0 -0
- {jaxsim-0.4.3.dev282 → jaxsim-0.4.3.dev295}/docs/Makefile +0 -0
- {jaxsim-0.4.3.dev282 → jaxsim-0.4.3.dev295}/docs/conf.py +0 -0
- {jaxsim-0.4.3.dev282 → jaxsim-0.4.3.dev295}/docs/examples.rst +0 -0
- {jaxsim-0.4.3.dev282 → jaxsim-0.4.3.dev295}/docs/guide/install.rst +0 -0
- {jaxsim-0.4.3.dev282 → jaxsim-0.4.3.dev295}/docs/index.rst +0 -0
- {jaxsim-0.4.3.dev282 → jaxsim-0.4.3.dev295}/docs/make.bat +0 -0
- {jaxsim-0.4.3.dev282 → jaxsim-0.4.3.dev295}/docs/modules/api.rst +0 -0
- {jaxsim-0.4.3.dev282 → jaxsim-0.4.3.dev295}/docs/modules/integrators.rst +0 -0
- {jaxsim-0.4.3.dev282 → jaxsim-0.4.3.dev295}/docs/modules/math.rst +0 -0
- {jaxsim-0.4.3.dev282 → jaxsim-0.4.3.dev295}/docs/modules/mujoco.rst +0 -0
- {jaxsim-0.4.3.dev282 → jaxsim-0.4.3.dev295}/docs/modules/parsers.rst +0 -0
- {jaxsim-0.4.3.dev282 → jaxsim-0.4.3.dev295}/docs/modules/rbda.rst +0 -0
- {jaxsim-0.4.3.dev282 → jaxsim-0.4.3.dev295}/docs/modules/typing.rst +0 -0
- {jaxsim-0.4.3.dev282 → jaxsim-0.4.3.dev295}/docs/modules/utils.rst +0 -0
- {jaxsim-0.4.3.dev282 → jaxsim-0.4.3.dev295}/environment.yml +0 -0
- {jaxsim-0.4.3.dev282 → jaxsim-0.4.3.dev295}/examples/.gitattributes +0 -0
- {jaxsim-0.4.3.dev282 → jaxsim-0.4.3.dev295}/examples/.gitignore +0 -0
- {jaxsim-0.4.3.dev282 → jaxsim-0.4.3.dev295}/examples/README.md +0 -0
- {jaxsim-0.4.3.dev282 → jaxsim-0.4.3.dev295}/examples/assets/build_cartpole_urdf.py +0 -0
- {jaxsim-0.4.3.dev282 → jaxsim-0.4.3.dev295}/examples/assets/cartpole.urdf +0 -0
- {jaxsim-0.4.3.dev282 → jaxsim-0.4.3.dev295}/examples/jaxsim_as_multibody_dynamics_library.ipynb +0 -0
- {jaxsim-0.4.3.dev282 → jaxsim-0.4.3.dev295}/examples/jaxsim_as_physics_engine.ipynb +0 -0
- {jaxsim-0.4.3.dev282 → jaxsim-0.4.3.dev295}/examples/jaxsim_for_robot_controllers.ipynb +0 -0
- {jaxsim-0.4.3.dev282 → jaxsim-0.4.3.dev295}/pixi.lock +0 -0
- {jaxsim-0.4.3.dev282 → jaxsim-0.4.3.dev295}/pyproject.toml +0 -0
- {jaxsim-0.4.3.dev282 → jaxsim-0.4.3.dev295}/setup.cfg +0 -0
- {jaxsim-0.4.3.dev282 → jaxsim-0.4.3.dev295}/setup.py +0 -0
- {jaxsim-0.4.3.dev282 → jaxsim-0.4.3.dev295}/src/jaxsim/__init__.py +0 -0
- {jaxsim-0.4.3.dev282 → jaxsim-0.4.3.dev295}/src/jaxsim/api/__init__.py +0 -0
- {jaxsim-0.4.3.dev282 → jaxsim-0.4.3.dev295}/src/jaxsim/api/com.py +0 -0
- {jaxsim-0.4.3.dev282 → jaxsim-0.4.3.dev295}/src/jaxsim/api/common.py +0 -0
- {jaxsim-0.4.3.dev282 → jaxsim-0.4.3.dev295}/src/jaxsim/api/frame.py +0 -0
- {jaxsim-0.4.3.dev282 → jaxsim-0.4.3.dev295}/src/jaxsim/api/ode.py +0 -0
- {jaxsim-0.4.3.dev282 → jaxsim-0.4.3.dev295}/src/jaxsim/api/ode_data.py +0 -0
- {jaxsim-0.4.3.dev282 → jaxsim-0.4.3.dev295}/src/jaxsim/exceptions.py +0 -0
- {jaxsim-0.4.3.dev282 → jaxsim-0.4.3.dev295}/src/jaxsim/integrators/__init__.py +0 -0
- {jaxsim-0.4.3.dev282 → jaxsim-0.4.3.dev295}/src/jaxsim/integrators/fixed_step.py +0 -0
- {jaxsim-0.4.3.dev282 → jaxsim-0.4.3.dev295}/src/jaxsim/integrators/variable_step.py +0 -0
- {jaxsim-0.4.3.dev282 → jaxsim-0.4.3.dev295}/src/jaxsim/logging.py +0 -0
- {jaxsim-0.4.3.dev282 → jaxsim-0.4.3.dev295}/src/jaxsim/math/__init__.py +0 -0
- {jaxsim-0.4.3.dev282 → jaxsim-0.4.3.dev295}/src/jaxsim/math/cross.py +0 -0
- {jaxsim-0.4.3.dev282 → jaxsim-0.4.3.dev295}/src/jaxsim/math/inertia.py +0 -0
- {jaxsim-0.4.3.dev282 → jaxsim-0.4.3.dev295}/src/jaxsim/math/joint_model.py +0 -0
- {jaxsim-0.4.3.dev282 → jaxsim-0.4.3.dev295}/src/jaxsim/math/quaternion.py +0 -0
- {jaxsim-0.4.3.dev282 → jaxsim-0.4.3.dev295}/src/jaxsim/math/rotation.py +0 -0
- {jaxsim-0.4.3.dev282 → jaxsim-0.4.3.dev295}/src/jaxsim/math/skew.py +0 -0
- {jaxsim-0.4.3.dev282 → jaxsim-0.4.3.dev295}/src/jaxsim/math/transform.py +0 -0
- {jaxsim-0.4.3.dev282 → jaxsim-0.4.3.dev295}/src/jaxsim/mujoco/__init__.py +0 -0
- {jaxsim-0.4.3.dev282 → jaxsim-0.4.3.dev295}/src/jaxsim/mujoco/__main__.py +0 -0
- {jaxsim-0.4.3.dev282 → jaxsim-0.4.3.dev295}/src/jaxsim/mujoco/loaders.py +0 -0
- {jaxsim-0.4.3.dev282 → jaxsim-0.4.3.dev295}/src/jaxsim/mujoco/model.py +0 -0
- {jaxsim-0.4.3.dev282 → jaxsim-0.4.3.dev295}/src/jaxsim/mujoco/utils.py +0 -0
- {jaxsim-0.4.3.dev282 → jaxsim-0.4.3.dev295}/src/jaxsim/mujoco/visualizer.py +0 -0
- {jaxsim-0.4.3.dev282 → jaxsim-0.4.3.dev295}/src/jaxsim/parsers/__init__.py +0 -0
- {jaxsim-0.4.3.dev282 → jaxsim-0.4.3.dev295}/src/jaxsim/parsers/descriptions/__init__.py +0 -0
- {jaxsim-0.4.3.dev282 → jaxsim-0.4.3.dev295}/src/jaxsim/parsers/descriptions/collision.py +0 -0
- {jaxsim-0.4.3.dev282 → jaxsim-0.4.3.dev295}/src/jaxsim/parsers/descriptions/joint.py +0 -0
- {jaxsim-0.4.3.dev282 → jaxsim-0.4.3.dev295}/src/jaxsim/parsers/descriptions/link.py +0 -0
- {jaxsim-0.4.3.dev282 → jaxsim-0.4.3.dev295}/src/jaxsim/parsers/descriptions/model.py +0 -0
- {jaxsim-0.4.3.dev282 → jaxsim-0.4.3.dev295}/src/jaxsim/parsers/kinematic_graph.py +0 -0
- {jaxsim-0.4.3.dev282 → jaxsim-0.4.3.dev295}/src/jaxsim/parsers/rod/__init__.py +0 -0
- {jaxsim-0.4.3.dev282 → jaxsim-0.4.3.dev295}/src/jaxsim/parsers/rod/utils.py +0 -0
- {jaxsim-0.4.3.dev282 → jaxsim-0.4.3.dev295}/src/jaxsim/rbda/__init__.py +0 -0
- {jaxsim-0.4.3.dev282 → jaxsim-0.4.3.dev295}/src/jaxsim/rbda/aba.py +0 -0
- {jaxsim-0.4.3.dev282 → jaxsim-0.4.3.dev295}/src/jaxsim/rbda/collidable_points.py +0 -0
- {jaxsim-0.4.3.dev282 → jaxsim-0.4.3.dev295}/src/jaxsim/rbda/contacts/__init__.py +0 -0
- {jaxsim-0.4.3.dev282 → jaxsim-0.4.3.dev295}/src/jaxsim/rbda/contacts/common.py +0 -0
- {jaxsim-0.4.3.dev282 → jaxsim-0.4.3.dev295}/src/jaxsim/rbda/contacts/relaxed_rigid.py +0 -0
- {jaxsim-0.4.3.dev282 → jaxsim-0.4.3.dev295}/src/jaxsim/rbda/contacts/soft.py +0 -0
- {jaxsim-0.4.3.dev282 → jaxsim-0.4.3.dev295}/src/jaxsim/rbda/crba.py +0 -0
- {jaxsim-0.4.3.dev282 → jaxsim-0.4.3.dev295}/src/jaxsim/rbda/forward_kinematics.py +0 -0
- {jaxsim-0.4.3.dev282 → jaxsim-0.4.3.dev295}/src/jaxsim/rbda/jacobian.py +0 -0
- {jaxsim-0.4.3.dev282 → jaxsim-0.4.3.dev295}/src/jaxsim/rbda/rnea.py +0 -0
- {jaxsim-0.4.3.dev282 → jaxsim-0.4.3.dev295}/src/jaxsim/rbda/utils.py +0 -0
- {jaxsim-0.4.3.dev282 → jaxsim-0.4.3.dev295}/src/jaxsim/terrain/__init__.py +0 -0
- {jaxsim-0.4.3.dev282 → jaxsim-0.4.3.dev295}/src/jaxsim/terrain/terrain.py +0 -0
- {jaxsim-0.4.3.dev282 → jaxsim-0.4.3.dev295}/src/jaxsim/typing.py +0 -0
- {jaxsim-0.4.3.dev282 → jaxsim-0.4.3.dev295}/src/jaxsim/utils/__init__.py +0 -0
- {jaxsim-0.4.3.dev282 → jaxsim-0.4.3.dev295}/src/jaxsim/utils/jaxsim_dataclass.py +0 -0
- {jaxsim-0.4.3.dev282 → jaxsim-0.4.3.dev295}/src/jaxsim/utils/tracing.py +0 -0
- {jaxsim-0.4.3.dev282 → jaxsim-0.4.3.dev295}/src/jaxsim/utils/wrappers.py +0 -0
- {jaxsim-0.4.3.dev282 → jaxsim-0.4.3.dev295}/src/jaxsim.egg-info/SOURCES.txt +0 -0
- {jaxsim-0.4.3.dev282 → jaxsim-0.4.3.dev295}/src/jaxsim.egg-info/dependency_links.txt +0 -0
- {jaxsim-0.4.3.dev282 → jaxsim-0.4.3.dev295}/src/jaxsim.egg-info/requires.txt +0 -0
- {jaxsim-0.4.3.dev282 → jaxsim-0.4.3.dev295}/src/jaxsim.egg-info/top_level.txt +0 -0
- {jaxsim-0.4.3.dev282 → jaxsim-0.4.3.dev295}/tests/__init__.py +0 -0
- {jaxsim-0.4.3.dev282 → jaxsim-0.4.3.dev295}/tests/conftest.py +0 -0
- {jaxsim-0.4.3.dev282 → jaxsim-0.4.3.dev295}/tests/test_api_com.py +0 -0
- {jaxsim-0.4.3.dev282 → jaxsim-0.4.3.dev295}/tests/test_api_contact.py +0 -0
- {jaxsim-0.4.3.dev282 → jaxsim-0.4.3.dev295}/tests/test_api_data.py +0 -0
- {jaxsim-0.4.3.dev282 → jaxsim-0.4.3.dev295}/tests/test_api_model.py +0 -0
- {jaxsim-0.4.3.dev282 → jaxsim-0.4.3.dev295}/tests/test_automatic_differentiation.py +0 -0
- {jaxsim-0.4.3.dev282 → jaxsim-0.4.3.dev295}/tests/test_contact.py +0 -0
- {jaxsim-0.4.3.dev282 → jaxsim-0.4.3.dev295}/tests/test_exceptions.py +0 -0
- {jaxsim-0.4.3.dev282 → jaxsim-0.4.3.dev295}/tests/test_pytree.py +0 -0
- {jaxsim-0.4.3.dev282 → jaxsim-0.4.3.dev295}/tests/test_simulations.py +0 -0
- {jaxsim-0.4.3.dev282 → jaxsim-0.4.3.dev295}/tests/utils_idyntree.py +0 -0
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.1
|
2
2
|
Name: jaxsim
|
3
|
-
Version: 0.4.3.
|
3
|
+
Version: 0.4.3.dev295
|
4
4
|
Summary: A differentiable physics engine and multibody dynamics library for control and robot learning.
|
5
5
|
Author-email: Diego Ferigo <dgferigo@gmail.com>
|
6
6
|
Maintainer-email: Diego Ferigo <dgferigo@gmail.com>, Filippo Luca Ferretti <filippo.ferretti@iit.it>
|
@@ -12,5 +12,5 @@ __version__: str
|
|
12
12
|
__version_tuple__: VERSION_TUPLE
|
13
13
|
version_tuple: VERSION_TUPLE
|
14
14
|
|
15
|
-
__version__ = version = '0.4.3.
|
16
|
-
__version_tuple__ = version_tuple = (0, 4, 3, '
|
15
|
+
__version__ = version = '0.4.3.dev295'
|
16
|
+
__version_tuple__ = version_tuple = (0, 4, 3, 'dev295')
|
@@ -243,9 +243,7 @@ def in_contact(
|
|
243
243
|
A boolean vector indicating whether the links are in contact with the terrain.
|
244
244
|
"""
|
245
245
|
|
246
|
-
|
247
|
-
|
248
|
-
if set(link_names).difference(model.link_names()):
|
246
|
+
if link_names is not None and set(link_names).difference(model.link_names()):
|
249
247
|
raise ValueError("One or more link names are not part of the model")
|
250
248
|
|
251
249
|
W_p_Ci = collidable_point_positions(model=model, data=data)
|
@@ -256,13 +254,19 @@ def in_contact(
|
|
256
254
|
|
257
255
|
below_terrain = W_p_Ci[:, 2] <= terrain_height
|
258
256
|
|
257
|
+
link_idxs = (
|
258
|
+
js.link.names_to_idxs(link_names=link_names, model=model)
|
259
|
+
if link_names is not None
|
260
|
+
else jnp.arange(model.number_of_links())
|
261
|
+
)
|
262
|
+
|
259
263
|
links_in_contact = jax.vmap(
|
260
264
|
lambda link_index: jnp.where(
|
261
265
|
jnp.array(model.kin_dyn_parameters.contact_parameters.body) == link_index,
|
262
266
|
below_terrain,
|
263
267
|
jnp.zeros_like(below_terrain, dtype=bool),
|
264
268
|
).any()
|
265
|
-
)(
|
269
|
+
)(link_idxs)
|
266
270
|
|
267
271
|
return links_in_contact
|
268
272
|
|
@@ -292,11 +292,13 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
|
|
292
292
|
msg = "The data object is not compatible with the provided model"
|
293
293
|
raise ValueError(msg)
|
294
294
|
|
295
|
-
|
296
|
-
|
297
|
-
return self.state.physics_model.joint_positions[
|
295
|
+
joint_idxs = (
|
298
296
|
js.joint.names_to_idxs(joint_names=joint_names, model=model)
|
299
|
-
|
297
|
+
if joint_names is not None
|
298
|
+
else jnp.arange(model.number_of_joints())
|
299
|
+
)
|
300
|
+
|
301
|
+
return self.state.physics_model.joint_positions[joint_idxs]
|
300
302
|
|
301
303
|
@functools.partial(jax.jit, static_argnames=["joint_names"])
|
302
304
|
def joint_velocities(
|
@@ -337,11 +339,13 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
|
|
337
339
|
msg = "The data object is not compatible with the provided model"
|
338
340
|
raise ValueError(msg)
|
339
341
|
|
340
|
-
|
341
|
-
|
342
|
-
return self.state.physics_model.joint_velocities[
|
342
|
+
joint_idxs = (
|
343
343
|
js.joint.names_to_idxs(joint_names=joint_names, model=model)
|
344
|
-
|
344
|
+
if joint_names is not None
|
345
|
+
else jnp.arange(model.number_of_joints())
|
346
|
+
)
|
347
|
+
|
348
|
+
return self.state.physics_model.joint_velocities[joint_idxs]
|
345
349
|
|
346
350
|
@jax.jit
|
347
351
|
def base_position(self) -> jtp.Vector:
|
@@ -374,10 +378,8 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
|
|
374
378
|
# we introduce a Baumgarte stabilization to let the quaternion converge to
|
375
379
|
# a unit quaternion. In this case, it is not guaranteed that the quaternion
|
376
380
|
# stored in the state is a unit quaternion.
|
377
|
-
W_Q_B =
|
378
|
-
|
379
|
-
on_true=W_Q_B,
|
380
|
-
on_false=W_Q_B / jnp.linalg.norm(W_Q_B),
|
381
|
+
W_Q_B = jnp.where(
|
382
|
+
jnp.allclose(W_Q_B.dot(W_Q_B), 1.0), W_Q_B, W_Q_B / jnp.linalg.norm(W_Q_B)
|
381
383
|
)
|
382
384
|
|
383
385
|
return (W_Q_B if not dcm else jaxsim.math.Quaternion.to_dcm(W_Q_B)).astype(
|
@@ -502,12 +504,14 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
|
|
502
504
|
msg = "The data object is not compatible with the provided model"
|
503
505
|
raise ValueError(msg)
|
504
506
|
|
505
|
-
|
507
|
+
joint_idxs = (
|
508
|
+
js.joint.names_to_idxs(joint_names=joint_names, model=model)
|
509
|
+
if joint_names is not None
|
510
|
+
else jnp.arange(model.number_of_joints())
|
511
|
+
)
|
506
512
|
|
507
513
|
return replace(
|
508
|
-
s=self.state.physics_model.joint_positions.at[
|
509
|
-
js.joint.names_to_idxs(joint_names=joint_names, model=model)
|
510
|
-
].set(positions)
|
514
|
+
s=self.state.physics_model.joint_positions.at[joint_idxs].set(positions)
|
511
515
|
)
|
512
516
|
|
513
517
|
@functools.partial(jax.jit, static_argnames=["joint_names"])
|
@@ -548,12 +552,14 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
|
|
548
552
|
msg = "The data object is not compatible with the provided model"
|
549
553
|
raise ValueError(msg)
|
550
554
|
|
551
|
-
|
555
|
+
joint_idxs = (
|
556
|
+
js.joint.names_to_idxs(joint_names=joint_names, model=model)
|
557
|
+
if joint_names is not None
|
558
|
+
else jnp.arange(model.number_of_joints())
|
559
|
+
)
|
552
560
|
|
553
561
|
return replace(
|
554
|
-
ṡ=self.state.physics_model.joint_velocities.at[
|
555
|
-
js.joint.names_to_idxs(joint_names=joint_names, model=model)
|
556
|
-
].set(velocities)
|
562
|
+
ṡ=self.state.physics_model.joint_velocities.at[joint_idxs].set(velocities)
|
557
563
|
)
|
558
564
|
|
559
565
|
@jax.jit
|
@@ -157,13 +157,19 @@ def position_limits(
|
|
157
157
|
The position limits of the joints.
|
158
158
|
"""
|
159
159
|
|
160
|
-
|
160
|
+
joint_idxs = (
|
161
|
+
names_to_idxs(joint_names=joint_names, model=model)
|
162
|
+
if joint_names is not None
|
163
|
+
else jnp.arange(model.number_of_joints())
|
164
|
+
)
|
161
165
|
|
162
|
-
if len(
|
166
|
+
if len(joint_idxs) == 0:
|
163
167
|
return jnp.empty(0).astype(float), jnp.empty(0).astype(float)
|
164
168
|
|
165
|
-
|
166
|
-
|
169
|
+
s_min = model.kin_dyn_parameters.joint_parameters.position_limits_min[joint_idxs]
|
170
|
+
s_max = model.kin_dyn_parameters.joint_parameters.position_limits_max[joint_idxs]
|
171
|
+
|
172
|
+
return s_min.astype(float), s_max.astype(float)
|
167
173
|
|
168
174
|
|
169
175
|
# ======================
|
@@ -203,7 +209,11 @@ def random_joint_positions(
|
|
203
209
|
# Get the joint indices.
|
204
210
|
# Note that it will trigger an exception if the given `joint_names` are not valid.
|
205
211
|
joint_names = joint_names if joint_names is not None else model.joint_names()
|
206
|
-
joint_indices =
|
212
|
+
joint_indices = (
|
213
|
+
names_to_idxs(model=model, joint_names=joint_names)
|
214
|
+
if joint_names is not None
|
215
|
+
else jnp.arange(model.number_of_joints())
|
216
|
+
)
|
207
217
|
|
208
218
|
from jaxsim.parsers.descriptions.joint import JointType
|
209
219
|
|
@@ -398,9 +398,7 @@ class KynDynParameters(JaxsimDataclass):
|
|
398
398
|
λ_H_pre = jnp.vstack(
|
399
399
|
[
|
400
400
|
jnp.eye(4)[jnp.newaxis],
|
401
|
-
|
402
|
-
lambda i: self.joint_model.parent_H_predecessor(joint_index=i)
|
403
|
-
)(jnp.arange(1, 1 + self.number_of_joints())),
|
401
|
+
self.joint_model.λ_H_pre[1 : 1 + self.number_of_joints()],
|
404
402
|
]
|
405
403
|
)
|
406
404
|
|
@@ -423,156 +423,9 @@ def jacobian_derivative(
|
|
423
423
|
output_vel_repr if output_vel_repr is not None else data.velocity_representation
|
424
424
|
)
|
425
425
|
|
426
|
-
|
427
|
-
|
428
|
-
|
429
|
-
joint_positions=data.joint_positions(),
|
430
|
-
joint_velocities=data.joint_velocities(),
|
431
|
-
)
|
432
|
-
|
433
|
-
# Compute the actual doubly-left free-floating jacobian derivative of the link
|
434
|
-
# by zeroing the columns not in the path π_B(L) using the boolean κ(i).
|
435
|
-
κb = model.kin_dyn_parameters.support_body_array_bool[link_index]
|
436
|
-
B_J̇_WL_B = jnp.hstack([jnp.ones(5), κb]) * B_J̇_full_WX_B
|
437
|
-
|
438
|
-
# =====================================================
|
439
|
-
# Compute quantities to adjust the input representation
|
440
|
-
# =====================================================
|
441
|
-
|
442
|
-
In = jnp.eye(model.dofs())
|
443
|
-
On = jnp.zeros(shape=(model.dofs(), model.dofs()))
|
444
|
-
|
445
|
-
match data.velocity_representation:
|
446
|
-
|
447
|
-
case VelRepr.Inertial:
|
448
|
-
|
449
|
-
W_H_B = data.base_transform()
|
450
|
-
B_X_W = jaxsim.math.Adjoint.from_transform(transform=W_H_B, inverse=True)
|
451
|
-
|
452
|
-
with data.switch_velocity_representation(VelRepr.Inertial):
|
453
|
-
W_v_WB = data.base_velocity()
|
454
|
-
B_Ẋ_W = -B_X_W @ jaxsim.math.Cross.vx(W_v_WB)
|
455
|
-
|
456
|
-
# Compute the operator to change the representation of ν, and its
|
457
|
-
# time derivative.
|
458
|
-
T = jax.scipy.linalg.block_diag(B_X_W, In)
|
459
|
-
Ṫ = jax.scipy.linalg.block_diag(B_Ẋ_W, On)
|
460
|
-
|
461
|
-
case VelRepr.Body:
|
462
|
-
|
463
|
-
B_X_B = jaxsim.math.Adjoint.from_rotation_and_translation(
|
464
|
-
translation=jnp.zeros(3), rotation=jnp.eye(3)
|
465
|
-
)
|
466
|
-
|
467
|
-
B_Ẋ_B = jnp.zeros(shape=(6, 6))
|
468
|
-
|
469
|
-
# Compute the operator to change the representation of ν, and its
|
470
|
-
# time derivative.
|
471
|
-
T = jax.scipy.linalg.block_diag(B_X_B, In)
|
472
|
-
Ṫ = jax.scipy.linalg.block_diag(B_Ẋ_B, On)
|
473
|
-
|
474
|
-
case VelRepr.Mixed:
|
475
|
-
|
476
|
-
BW_H_B = data.base_transform().at[0:3, 3].set(jnp.zeros(3))
|
477
|
-
B_X_BW = jaxsim.math.Adjoint.from_transform(transform=BW_H_B, inverse=True)
|
478
|
-
|
479
|
-
with data.switch_velocity_representation(VelRepr.Mixed):
|
480
|
-
BW_v_WB = data.base_velocity()
|
481
|
-
BW_v_W_BW = BW_v_WB.at[3:6].set(jnp.zeros(3))
|
482
|
-
|
483
|
-
BW_v_BW_B = BW_v_WB - BW_v_W_BW
|
484
|
-
B_Ẋ_BW = -B_X_BW @ jaxsim.math.Cross.vx(BW_v_BW_B)
|
485
|
-
|
486
|
-
# Compute the operator to change the representation of ν, and its
|
487
|
-
# time derivative.
|
488
|
-
T = jax.scipy.linalg.block_diag(B_X_BW, In)
|
489
|
-
Ṫ = jax.scipy.linalg.block_diag(B_Ẋ_BW, On)
|
490
|
-
|
491
|
-
case _:
|
492
|
-
raise ValueError(data.velocity_representation)
|
493
|
-
|
494
|
-
# ======================================================
|
495
|
-
# Compute quantities to adjust the output representation
|
496
|
-
# ======================================================
|
497
|
-
|
498
|
-
match output_vel_repr:
|
499
|
-
|
500
|
-
case VelRepr.Inertial:
|
501
|
-
|
502
|
-
W_H_B = data.base_transform()
|
503
|
-
O_X_B = W_X_B = jaxsim.math.Adjoint.from_transform(transform=W_H_B)
|
504
|
-
|
505
|
-
with data.switch_velocity_representation(VelRepr.Body):
|
506
|
-
B_v_WB = data.base_velocity()
|
507
|
-
|
508
|
-
O_Ẋ_B = W_Ẋ_B = W_X_B @ jaxsim.math.Cross.vx(B_v_WB) # noqa: F841
|
509
|
-
|
510
|
-
case VelRepr.Body:
|
511
|
-
|
512
|
-
O_X_B = L_X_B = jaxsim.math.Adjoint.from_transform(
|
513
|
-
transform=B_H_L[link_index, :, :], inverse=True
|
514
|
-
)
|
515
|
-
|
516
|
-
B_X_L = jaxsim.math.Adjoint.inverse(adjoint=L_X_B)
|
517
|
-
|
518
|
-
with data.switch_velocity_representation(VelRepr.Body):
|
519
|
-
B_v_WB = data.base_velocity()
|
520
|
-
L_v_WL = js.link.velocity(model=model, data=data, link_index=link_index)
|
521
|
-
|
522
|
-
O_Ẋ_B = L_Ẋ_B = -L_X_B @ jaxsim.math.Cross.vx( # noqa: F841
|
523
|
-
B_X_L @ L_v_WL - B_v_WB
|
524
|
-
)
|
525
|
-
|
526
|
-
case VelRepr.Mixed:
|
527
|
-
|
528
|
-
W_H_B = data.base_transform()
|
529
|
-
W_H_L = W_H_B @ B_H_L[link_index, :, :]
|
530
|
-
LW_H_L = W_H_L.at[0:3, 3].set(jnp.zeros(3))
|
531
|
-
LW_H_B = LW_H_L @ jaxsim.math.Transform.inverse(B_H_L[link_index, :, :])
|
532
|
-
|
533
|
-
O_X_B = LW_X_B = jaxsim.math.Adjoint.from_transform(transform=LW_H_B)
|
534
|
-
|
535
|
-
B_X_LW = jaxsim.math.Adjoint.inverse(adjoint=LW_X_B)
|
536
|
-
|
537
|
-
with data.switch_velocity_representation(VelRepr.Body):
|
538
|
-
B_v_WB = data.base_velocity()
|
539
|
-
|
540
|
-
with data.switch_velocity_representation(VelRepr.Mixed):
|
541
|
-
LW_v_WL = js.link.velocity(
|
542
|
-
model=model, data=data, link_index=link_index
|
543
|
-
)
|
544
|
-
LW_v_W_LW = LW_v_WL.at[3:6].set(jnp.zeros(3))
|
545
|
-
|
546
|
-
LW_v_LW_L = LW_v_WL - LW_v_W_LW
|
547
|
-
LW_v_B_LW = LW_v_WL - LW_X_B @ B_v_WB - LW_v_LW_L
|
548
|
-
|
549
|
-
O_Ẋ_B = LW_Ẋ_B = -LW_X_B @ jaxsim.math.Cross.vx( # noqa: F841
|
550
|
-
B_X_LW @ LW_v_B_LW
|
551
|
-
)
|
552
|
-
case _:
|
553
|
-
raise ValueError(output_vel_repr)
|
554
|
-
|
555
|
-
# =============================================================
|
556
|
-
# Express the Jacobian derivative in the target representations
|
557
|
-
# =============================================================
|
558
|
-
|
559
|
-
# The derivative of the equation to change the input and output representations
|
560
|
-
# of the Jacobian derivative needs the computation of the plain link Jacobian.
|
561
|
-
# Compute here the full Jacobian of the model...
|
562
|
-
B_J_full_WL_B, _ = jaxsim.rbda.jacobian_full_doubly_left(
|
563
|
-
model=model,
|
564
|
-
joint_positions=data.joint_positions(),
|
565
|
-
)
|
566
|
-
|
567
|
-
# ... and extract the link Jacobian using the boolean support body array.
|
568
|
-
B_J_WL_B = jnp.hstack([jnp.ones(5), κb]) * B_J_full_WL_B
|
569
|
-
|
570
|
-
# Sum all the components that form the Jacobian derivative in the target
|
571
|
-
# input/output velocity representations.
|
572
|
-
O_J̇_WL_I = jnp.zeros(shape=(6, 6 + model.dofs()))
|
573
|
-
O_J̇_WL_I += O_Ẋ_B @ B_J_WL_B @ T
|
574
|
-
O_J̇_WL_I += O_X_B @ B_J̇_WL_B @ T
|
575
|
-
O_J̇_WL_I += O_X_B @ B_J_WL_B @ Ṫ
|
426
|
+
O_J̇_WL_I = js.model.generalized_free_floating_jacobian_derivative(
|
427
|
+
model=model, data=data, output_vel_repr=output_vel_repr
|
428
|
+
)[link_index]
|
576
429
|
|
577
430
|
return O_J̇_WL_I
|
578
431
|
|
@@ -700,14 +700,173 @@ def generalized_free_floating_jacobian_derivative(
|
|
700
700
|
output_vel_repr if output_vel_repr is not None else data.velocity_representation
|
701
701
|
)
|
702
702
|
|
703
|
-
|
704
|
-
|
705
|
-
|
706
|
-
),
|
707
|
-
|
708
|
-
)
|
703
|
+
# Compute the derivative of the doubly-left free-floating full jacobian.
|
704
|
+
B_J̇_full_WX_B, B_H_L = jaxsim.rbda.jacobian_derivative_full_doubly_left(
|
705
|
+
model=model,
|
706
|
+
joint_positions=data.joint_positions(),
|
707
|
+
joint_velocities=data.joint_velocities(),
|
708
|
+
)
|
709
|
+
|
710
|
+
# The derivative of the equation to change the input and output representations
|
711
|
+
# of the Jacobian derivative needs the computation of the plain link Jacobian.
|
712
|
+
B_J_full_WL_B, _ = jaxsim.rbda.jacobian_full_doubly_left(
|
713
|
+
model=model,
|
714
|
+
joint_positions=data.joint_positions(),
|
715
|
+
)
|
716
|
+
|
717
|
+
# Compute the actual doubly-left free-floating jacobian derivative of the link
|
718
|
+
# by zeroing the columns not in the path π_B(L) using the boolean κ(i).
|
719
|
+
κb = model.kin_dyn_parameters.support_body_array_bool
|
720
|
+
|
721
|
+
# Compute the base transform.
|
722
|
+
W_H_B = data.base_transform()
|
723
|
+
|
724
|
+
@functools.partial(jax.vmap, in_axes=(0, None, None, 0))
|
725
|
+
def _compute_row(
|
726
|
+
B_H_L: jtp.Matrix,
|
727
|
+
B_J_full_WL_B: jtp.Matrix,
|
728
|
+
W_H_B: jtp.Matrix,
|
729
|
+
κb: jtp.Matrix,
|
730
|
+
) -> jtp.Matrix:
|
731
|
+
|
732
|
+
# =====================================================
|
733
|
+
# Compute quantities to adjust the input representation
|
734
|
+
# =====================================================
|
735
|
+
|
736
|
+
In = jnp.eye(model.dofs())
|
737
|
+
On = jnp.zeros(shape=(model.dofs(), model.dofs()))
|
738
|
+
|
739
|
+
# Extract the link quantities using the boolean support body array.
|
740
|
+
B_J̇_WL_B = jnp.hstack([jnp.ones(5), κb]) * B_J̇_full_WX_B
|
741
|
+
B_J_WL_B = jnp.hstack([jnp.ones(5), κb]) * B_J_full_WL_B
|
742
|
+
|
743
|
+
match data.velocity_representation:
|
709
744
|
|
710
|
-
|
745
|
+
case VelRepr.Inertial:
|
746
|
+
|
747
|
+
B_X_W = jaxsim.math.Adjoint.from_transform(
|
748
|
+
transform=W_H_B, inverse=True
|
749
|
+
)
|
750
|
+
|
751
|
+
W_v_WB = data.base_velocity()
|
752
|
+
B_Ẋ_W = -B_X_W @ jaxsim.math.Cross.vx(W_v_WB)
|
753
|
+
|
754
|
+
# Compute the operator to change the representation of ν, and its
|
755
|
+
# time derivative.
|
756
|
+
T = jax.scipy.linalg.block_diag(B_X_W, In)
|
757
|
+
Ṫ = jax.scipy.linalg.block_diag(B_Ẋ_W, On)
|
758
|
+
|
759
|
+
case VelRepr.Body:
|
760
|
+
|
761
|
+
B_X_B = jaxsim.math.Adjoint.from_rotation_and_translation(
|
762
|
+
translation=jnp.zeros(3), rotation=jnp.eye(3)
|
763
|
+
)
|
764
|
+
|
765
|
+
B_Ẋ_B = jnp.zeros(shape=(6, 6))
|
766
|
+
|
767
|
+
# Compute the operator to change the representation of ν, and its
|
768
|
+
# time derivative.
|
769
|
+
T = jax.scipy.linalg.block_diag(B_X_B, In)
|
770
|
+
Ṫ = jax.scipy.linalg.block_diag(B_Ẋ_B, On)
|
771
|
+
|
772
|
+
case VelRepr.Mixed:
|
773
|
+
|
774
|
+
BW_H_B = W_H_B.at[0:3, 3].set(jnp.zeros(3))
|
775
|
+
B_X_BW = jaxsim.math.Adjoint.from_transform(
|
776
|
+
transform=BW_H_B, inverse=True
|
777
|
+
)
|
778
|
+
|
779
|
+
BW_v_WB = data.base_velocity()
|
780
|
+
BW_v_W_BW = BW_v_WB.at[3:6].set(jnp.zeros(3))
|
781
|
+
|
782
|
+
BW_v_BW_B = BW_v_WB - BW_v_W_BW
|
783
|
+
B_Ẋ_BW = -B_X_BW @ jaxsim.math.Cross.vx(BW_v_BW_B)
|
784
|
+
|
785
|
+
# Compute the operator to change the representation of ν, and its
|
786
|
+
# time derivative.
|
787
|
+
T = jax.scipy.linalg.block_diag(B_X_BW, In)
|
788
|
+
Ṫ = jax.scipy.linalg.block_diag(B_Ẋ_BW, On)
|
789
|
+
|
790
|
+
case _:
|
791
|
+
raise ValueError(data.velocity_representation)
|
792
|
+
|
793
|
+
# ======================================================
|
794
|
+
# Compute quantities to adjust the output representation
|
795
|
+
# ======================================================
|
796
|
+
|
797
|
+
match output_vel_repr:
|
798
|
+
|
799
|
+
case VelRepr.Inertial:
|
800
|
+
|
801
|
+
O_X_B = W_X_B = jaxsim.math.Adjoint.from_transform(transform=W_H_B)
|
802
|
+
|
803
|
+
with data.switch_velocity_representation(VelRepr.Body):
|
804
|
+
B_v_WB = data.base_velocity()
|
805
|
+
|
806
|
+
O_Ẋ_B = W_Ẋ_B = W_X_B @ jaxsim.math.Cross.vx(B_v_WB) # noqa: F841
|
807
|
+
|
808
|
+
case VelRepr.Body:
|
809
|
+
|
810
|
+
O_X_B = L_X_B = jaxsim.math.Adjoint.from_transform(
|
811
|
+
transform=B_H_L, inverse=True
|
812
|
+
)
|
813
|
+
|
814
|
+
B_X_L = jaxsim.math.Adjoint.inverse(adjoint=L_X_B)
|
815
|
+
|
816
|
+
with data.switch_velocity_representation(VelRepr.Body):
|
817
|
+
B_v_WB = data.base_velocity()
|
818
|
+
L_v_WL = L_X_B @ B_J_WL_B @ data.generalized_velocity()
|
819
|
+
|
820
|
+
O_Ẋ_B = L_Ẋ_B = -L_X_B @ jaxsim.math.Cross.vx( # noqa: F841
|
821
|
+
B_X_L @ L_v_WL - B_v_WB
|
822
|
+
)
|
823
|
+
|
824
|
+
case VelRepr.Mixed:
|
825
|
+
|
826
|
+
W_H_L = W_H_B @ B_H_L
|
827
|
+
LW_H_L = W_H_L.at[0:3, 3].set(jnp.zeros(3))
|
828
|
+
LW_H_B = LW_H_L @ jaxsim.math.Transform.inverse(B_H_L)
|
829
|
+
|
830
|
+
O_X_B = LW_X_B = jaxsim.math.Adjoint.from_transform(transform=LW_H_B)
|
831
|
+
|
832
|
+
B_X_LW = jaxsim.math.Adjoint.inverse(adjoint=LW_X_B)
|
833
|
+
|
834
|
+
with data.switch_velocity_representation(VelRepr.Body):
|
835
|
+
B_v_WB = data.base_velocity()
|
836
|
+
|
837
|
+
with data.switch_velocity_representation(VelRepr.Mixed):
|
838
|
+
BW_H_B = W_H_B.at[0:3, 3].set(jnp.zeros(3))
|
839
|
+
B_X_BW = Adjoint.from_transform(transform=BW_H_B, inverse=True)
|
840
|
+
LW_v_WL = LW_X_B @ (
|
841
|
+
B_J_WL_B
|
842
|
+
@ jax.scipy.linalg.block_diag(B_X_BW, jnp.eye(model.dofs()))
|
843
|
+
@ data.generalized_velocity()
|
844
|
+
)
|
845
|
+
LW_v_W_LW = LW_v_WL.at[3:6].set(jnp.zeros(3))
|
846
|
+
|
847
|
+
LW_v_LW_L = LW_v_WL - LW_v_W_LW
|
848
|
+
LW_v_B_LW = LW_v_WL - LW_X_B @ B_v_WB - LW_v_LW_L
|
849
|
+
|
850
|
+
O_Ẋ_B = LW_Ẋ_B = -LW_X_B @ jaxsim.math.Cross.vx( # noqa: F841
|
851
|
+
B_X_LW @ LW_v_B_LW
|
852
|
+
)
|
853
|
+
case _:
|
854
|
+
raise ValueError(output_vel_repr)
|
855
|
+
|
856
|
+
# =============================================================
|
857
|
+
# Express the Jacobian derivative in the target representations
|
858
|
+
# =============================================================
|
859
|
+
|
860
|
+
# Sum all the components that form the Jacobian derivative in the target
|
861
|
+
# input/output velocity representations.
|
862
|
+
O_J̇_WL_I = jnp.zeros(shape=(6, 6 + model.dofs()))
|
863
|
+
O_J̇_WL_I += O_Ẋ_B @ B_J_WL_B @ T
|
864
|
+
O_J̇_WL_I += O_X_B @ B_J̇_WL_B @ T
|
865
|
+
O_J̇_WL_I += O_X_B @ B_J_WL_B @ Ṫ
|
866
|
+
|
867
|
+
return O_J̇_WL_I
|
868
|
+
|
869
|
+
return _compute_row(B_H_L, B_J_full_WL_B, W_H_B, κb)
|
711
870
|
|
712
871
|
|
713
872
|
@functools.partial(jax.jit, static_argnames=["prefer_aba"])
|
@@ -1064,11 +1223,7 @@ def free_floating_coriolis_matrix(
|
|
1064
1223
|
L_J_WL_B = generalized_free_floating_jacobian(model=model, data=data)
|
1065
1224
|
|
1066
1225
|
# Doubly-left free-floating Jacobian derivative.
|
1067
|
-
L_J̇_WL_B =
|
1068
|
-
lambda link_index: js.link.jacobian_derivative(
|
1069
|
-
model=model, data=data, link_index=link_index
|
1070
|
-
)
|
1071
|
-
)(jnp.arange(model.number_of_links()))
|
1226
|
+
L_J̇_WL_B = generalized_free_floating_jacobian_derivative(model=model, data=data)
|
1072
1227
|
|
1073
1228
|
L_M_L = link_spatial_inertia_matrices(model=model)
|
1074
1229
|
|
@@ -193,8 +193,11 @@ class JaxSimModelReferences(js.common.ModelDataWithVelocityRepresentation):
|
|
193
193
|
return self.input.physics_model.f_ext
|
194
194
|
|
195
195
|
# If we have the model, we can extract the link names, if not provided.
|
196
|
-
|
197
|
-
|
196
|
+
link_idxs = (
|
197
|
+
js.link.names_to_idxs(link_names=link_names, model=model)
|
198
|
+
if link_names is not None
|
199
|
+
else jnp.arange(model.number_of_links())
|
200
|
+
)
|
198
201
|
|
199
202
|
# In inertial-fixed representation, we already have the link forces.
|
200
203
|
if self.velocity_representation is VelRepr.Inertial:
|
@@ -267,8 +270,11 @@ class JaxSimModelReferences(js.common.ModelDataWithVelocityRepresentation):
|
|
267
270
|
msg = "The actuation object is not compatible with the provided model"
|
268
271
|
raise ValueError(msg)
|
269
272
|
|
270
|
-
|
271
|
-
|
273
|
+
joint_idxs = (
|
274
|
+
js.joint.names_to_idxs(joint_names=joint_names, model=model)
|
275
|
+
if joint_names is not None
|
276
|
+
else jnp.arange(model.number_of_joints())
|
277
|
+
)
|
272
278
|
|
273
279
|
return jnp.atleast_1d(
|
274
280
|
self.input.physics_model.tau[joint_idxs].squeeze()
|
@@ -318,8 +324,11 @@ class JaxSimModelReferences(js.common.ModelDataWithVelocityRepresentation):
|
|
318
324
|
msg = "The references object is not compatible with the provided model"
|
319
325
|
raise ValueError(msg)
|
320
326
|
|
321
|
-
|
322
|
-
|
327
|
+
joint_idxs = (
|
328
|
+
js.joint.names_to_idxs(joint_names=joint_names, model=model)
|
329
|
+
if joint_names is not None
|
330
|
+
else jnp.arange(model.number_of_joints())
|
331
|
+
)
|
323
332
|
|
324
333
|
return replace(forces=self.input.physics_model.tau.at[joint_idxs].set(forces))
|
325
334
|
|
@@ -388,18 +397,16 @@ class JaxSimModelReferences(js.common.ModelDataWithVelocityRepresentation):
|
|
388
397
|
|
389
398
|
return replace(forces=W_f0_L + W_f_L)
|
390
399
|
|
391
|
-
|
392
|
-
link_names = link_names if link_names is not None else model.link_names()
|
393
|
-
|
394
|
-
# Make sure that the link names are a tuple if they are provided by the user.
|
395
|
-
link_names = (link_names,) if isinstance(link_names, str) else link_names
|
396
|
-
|
397
|
-
if len(link_names) != f_L.shape[0]:
|
400
|
+
if link_names is not None and len(link_names) != f_L.shape[0]:
|
398
401
|
msg = "The number of link names ({}) must match the number of forces ({})"
|
399
402
|
raise ValueError(msg.format(len(link_names), f_L.shape[0]))
|
400
403
|
|
401
404
|
# Extract the link indices.
|
402
|
-
link_idxs =
|
405
|
+
link_idxs = (
|
406
|
+
js.link.names_to_idxs(link_names=link_names, model=model)
|
407
|
+
if link_names is not None
|
408
|
+
else jnp.arange(model.number_of_links())
|
409
|
+
)
|
403
410
|
|
404
411
|
# Compute the bias depending on whether we either set or add the link forces.
|
405
412
|
W_f0_L = (
|
@@ -480,22 +487,21 @@ class JaxSimModelReferences(js.common.ModelDataWithVelocityRepresentation):
|
|
480
487
|
|
481
488
|
f_F = jnp.atleast_2d(forces).astype(float)
|
482
489
|
|
483
|
-
# If we have the model, we can extract the frame names if not provided.
|
484
|
-
frame_names = frame_names if frame_names is not None else model.frame_names()
|
485
|
-
|
486
|
-
# Make sure that the frame names are a tuple if they are provided by the user.
|
487
|
-
frame_names = (frame_names,) if isinstance(frame_names, str) else frame_names
|
488
|
-
|
489
490
|
if len(frame_names) != f_F.shape[0]:
|
490
491
|
msg = "The number of frame names ({}) must match the number of forces ({})"
|
491
492
|
raise ValueError(msg.format(len(frame_names), f_F.shape[0]))
|
492
493
|
|
493
494
|
# Extract the frame indices.
|
494
|
-
frame_idxs =
|
495
|
-
|
496
|
-
|
495
|
+
frame_idxs = (
|
496
|
+
js.frame.names_to_idxs(frame_names=frame_names, model=model)
|
497
|
+
if frame_names is not None
|
498
|
+
else jnp.arange(len(model.frame_names()))
|
497
499
|
)
|
498
500
|
|
501
|
+
parent_link_idxs = jnp.array(model.kin_dyn_parameters.frame_parameters.body)[
|
502
|
+
frame_idxs - model.number_of_links()
|
503
|
+
]
|
504
|
+
|
499
505
|
exceptions.raise_value_error_if(
|
500
506
|
condition=jnp.logical_not(data.valid(model=model)),
|
501
507
|
msg="The provided data is not valid for the model",
|
@@ -319,9 +319,8 @@ class ExplicitRungeKutta(Integrator[PyTreeType, PyTreeType], Generic[PyTreeType]
|
|
319
319
|
f = lambda x, t: self.dynamics(x=x, t=t, **kwargs)
|
320
320
|
|
321
321
|
# Initialize the carry of the for loop with the stacked kᵢ vectors.
|
322
|
-
carry0 = jax.
|
323
|
-
lambda l: jnp.
|
324
|
-
x0,
|
322
|
+
carry0 = jax.tree_map(
|
323
|
+
lambda l: jnp.zeros((c.size, *l.shape), dtype=l.dtype), x0
|
325
324
|
)
|
326
325
|
|
327
326
|
# Closure on metadata to either evaluate the dynamics at the initial state
|
@@ -137,11 +137,12 @@ class Adjoint:
|
|
137
137
|
jtp.Matrix: The inverse adjoint matrix.
|
138
138
|
"""
|
139
139
|
A_X_B = adjoint
|
140
|
-
A_H_B = Adjoint.to_transform(adjoint=A_X_B)
|
141
140
|
|
142
|
-
A_R_B =
|
143
|
-
A_o_B = A_H_B[0:3, 3]
|
141
|
+
A_R_B = A_X_B[0:3, 0:3]
|
144
142
|
|
145
|
-
return
|
146
|
-
|
143
|
+
return jnp.vstack(
|
144
|
+
[
|
145
|
+
jnp.block([A_R_B.T, -A_R_B.T @ A_X_B[0:3, 3:6] @ A_R_B.T]),
|
146
|
+
jnp.block([jnp.zeros(shape=(3, 3)), A_R_B.T]),
|
147
|
+
]
|
147
148
|
)
|