jaxsim 0.4.3.dev64__tar.gz → 0.4.3.dev68__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {jaxsim-0.4.3.dev64 → jaxsim-0.4.3.dev68}/PKG-INFO +1 -2
- {jaxsim-0.4.3.dev64 → jaxsim-0.4.3.dev68}/environment.yml +0 -1
- {jaxsim-0.4.3.dev64 → jaxsim-0.4.3.dev68}/pyproject.toml +0 -2
- {jaxsim-0.4.3.dev64 → jaxsim-0.4.3.dev68}/src/jaxsim/__init__.py +0 -5
- {jaxsim-0.4.3.dev64 → jaxsim-0.4.3.dev68}/src/jaxsim/_version.py +2 -2
- {jaxsim-0.4.3.dev64 → jaxsim-0.4.3.dev68}/src/jaxsim/api/contact.py +1 -27
- {jaxsim-0.4.3.dev64 → jaxsim-0.4.3.dev68}/src/jaxsim/api/data.py +11 -40
- {jaxsim-0.4.3.dev64 → jaxsim-0.4.3.dev68}/src/jaxsim/api/joint.py +2 -62
- {jaxsim-0.4.3.dev64 → jaxsim-0.4.3.dev68}/src/jaxsim/api/model.py +1 -12
- {jaxsim-0.4.3.dev64 → jaxsim-0.4.3.dev68}/src/jaxsim/api/ode.py +24 -19
- {jaxsim-0.4.3.dev64 → jaxsim-0.4.3.dev68}/src/jaxsim/api/ode_data.py +1 -11
- {jaxsim-0.4.3.dev64 → jaxsim-0.4.3.dev68}/src/jaxsim/integrators/common.py +1 -1
- {jaxsim-0.4.3.dev64 → jaxsim-0.4.3.dev68}/src/jaxsim/math/inertia.py +1 -1
- {jaxsim-0.4.3.dev64 → jaxsim-0.4.3.dev68}/src/jaxsim/mujoco/loaders.py +3 -3
- {jaxsim-0.4.3.dev64 → jaxsim-0.4.3.dev68}/src/jaxsim/parsers/kinematic_graph.py +3 -3
- {jaxsim-0.4.3.dev64 → jaxsim-0.4.3.dev68}/src/jaxsim/parsers/rod/parser.py +14 -18
- {jaxsim-0.4.3.dev64 → jaxsim-0.4.3.dev68}/src/jaxsim/rbda/contacts/rigid.py +41 -11
- {jaxsim-0.4.3.dev64 → jaxsim-0.4.3.dev68}/src/jaxsim/terrain/terrain.py +25 -41
- {jaxsim-0.4.3.dev64 → jaxsim-0.4.3.dev68}/src/jaxsim/typing.py +1 -1
- {jaxsim-0.4.3.dev64 → jaxsim-0.4.3.dev68}/src/jaxsim/utils/jaxsim_dataclass.py +9 -12
- {jaxsim-0.4.3.dev64 → jaxsim-0.4.3.dev68}/src/jaxsim/utils/wrappers.py +1 -1
- {jaxsim-0.4.3.dev64 → jaxsim-0.4.3.dev68}/src/jaxsim.egg-info/PKG-INFO +1 -2
- {jaxsim-0.4.3.dev64 → jaxsim-0.4.3.dev68}/src/jaxsim.egg-info/SOURCES.txt +0 -1
- {jaxsim-0.4.3.dev64 → jaxsim-0.4.3.dev68}/src/jaxsim.egg-info/requires.txt +0 -1
- {jaxsim-0.4.3.dev64 → jaxsim-0.4.3.dev68}/tests/conftest.py +0 -25
- jaxsim-0.4.3.dev64/src/jaxsim/rbda/contacts/relaxed_rigid.py +0 -384
- {jaxsim-0.4.3.dev64 → jaxsim-0.4.3.dev68}/.devcontainer/Dockerfile +0 -0
- {jaxsim-0.4.3.dev64 → jaxsim-0.4.3.dev68}/.devcontainer/devcontainer.json +0 -0
- {jaxsim-0.4.3.dev64 → jaxsim-0.4.3.dev68}/.gitattributes +0 -0
- {jaxsim-0.4.3.dev64 → jaxsim-0.4.3.dev68}/.github/CODEOWNERS +0 -0
- {jaxsim-0.4.3.dev64 → jaxsim-0.4.3.dev68}/.github/workflows/ci_cd.yml +0 -0
- {jaxsim-0.4.3.dev64 → jaxsim-0.4.3.dev68}/.github/workflows/read_the_docs.yml +0 -0
- {jaxsim-0.4.3.dev64 → jaxsim-0.4.3.dev68}/.github/workflows/update_pixi_lockfile.yml +0 -0
- {jaxsim-0.4.3.dev64 → jaxsim-0.4.3.dev68}/.gitignore +0 -0
- {jaxsim-0.4.3.dev64 → jaxsim-0.4.3.dev68}/.pre-commit-config.yaml +0 -0
- {jaxsim-0.4.3.dev64 → jaxsim-0.4.3.dev68}/.readthedocs.yaml +0 -0
- {jaxsim-0.4.3.dev64 → jaxsim-0.4.3.dev68}/CONTRIBUTING.md +0 -0
- {jaxsim-0.4.3.dev64 → jaxsim-0.4.3.dev68}/LICENSE +0 -0
- {jaxsim-0.4.3.dev64 → jaxsim-0.4.3.dev68}/README.md +0 -0
- {jaxsim-0.4.3.dev64 → jaxsim-0.4.3.dev68}/docs/Makefile +0 -0
- {jaxsim-0.4.3.dev64 → jaxsim-0.4.3.dev68}/docs/conf.py +0 -0
- {jaxsim-0.4.3.dev64 → jaxsim-0.4.3.dev68}/docs/examples.rst +0 -0
- {jaxsim-0.4.3.dev64 → jaxsim-0.4.3.dev68}/docs/guide/install.rst +0 -0
- {jaxsim-0.4.3.dev64 → jaxsim-0.4.3.dev68}/docs/index.rst +0 -0
- {jaxsim-0.4.3.dev64 → jaxsim-0.4.3.dev68}/docs/make.bat +0 -0
- {jaxsim-0.4.3.dev64 → jaxsim-0.4.3.dev68}/docs/modules/api.rst +0 -0
- {jaxsim-0.4.3.dev64 → jaxsim-0.4.3.dev68}/docs/modules/integrators.rst +0 -0
- {jaxsim-0.4.3.dev64 → jaxsim-0.4.3.dev68}/docs/modules/math.rst +0 -0
- {jaxsim-0.4.3.dev64 → jaxsim-0.4.3.dev68}/docs/modules/mujoco.rst +0 -0
- {jaxsim-0.4.3.dev64 → jaxsim-0.4.3.dev68}/docs/modules/parsers.rst +0 -0
- {jaxsim-0.4.3.dev64 → jaxsim-0.4.3.dev68}/docs/modules/rbda.rst +0 -0
- {jaxsim-0.4.3.dev64 → jaxsim-0.4.3.dev68}/docs/modules/typing.rst +0 -0
- {jaxsim-0.4.3.dev64 → jaxsim-0.4.3.dev68}/docs/modules/utils.rst +0 -0
- {jaxsim-0.4.3.dev64 → jaxsim-0.4.3.dev68}/examples/.gitattributes +0 -0
- {jaxsim-0.4.3.dev64 → jaxsim-0.4.3.dev68}/examples/.gitignore +0 -0
- {jaxsim-0.4.3.dev64 → jaxsim-0.4.3.dev68}/examples/PD_controller.ipynb +0 -0
- {jaxsim-0.4.3.dev64 → jaxsim-0.4.3.dev68}/examples/Parallel_computing.ipynb +0 -0
- {jaxsim-0.4.3.dev64 → jaxsim-0.4.3.dev68}/examples/README.md +0 -0
- {jaxsim-0.4.3.dev64 → jaxsim-0.4.3.dev68}/examples/assets/cartpole.urdf +0 -0
- {jaxsim-0.4.3.dev64 → jaxsim-0.4.3.dev68}/pixi.lock +0 -0
- {jaxsim-0.4.3.dev64 → jaxsim-0.4.3.dev68}/setup.cfg +0 -0
- {jaxsim-0.4.3.dev64 → jaxsim-0.4.3.dev68}/setup.py +0 -0
- {jaxsim-0.4.3.dev64 → jaxsim-0.4.3.dev68}/src/jaxsim/api/__init__.py +0 -0
- {jaxsim-0.4.3.dev64 → jaxsim-0.4.3.dev68}/src/jaxsim/api/com.py +0 -0
- {jaxsim-0.4.3.dev64 → jaxsim-0.4.3.dev68}/src/jaxsim/api/common.py +0 -0
- {jaxsim-0.4.3.dev64 → jaxsim-0.4.3.dev68}/src/jaxsim/api/frame.py +0 -0
- {jaxsim-0.4.3.dev64 → jaxsim-0.4.3.dev68}/src/jaxsim/api/kin_dyn_parameters.py +0 -0
- {jaxsim-0.4.3.dev64 → jaxsim-0.4.3.dev68}/src/jaxsim/api/link.py +0 -0
- {jaxsim-0.4.3.dev64 → jaxsim-0.4.3.dev68}/src/jaxsim/api/references.py +0 -0
- {jaxsim-0.4.3.dev64 → jaxsim-0.4.3.dev68}/src/jaxsim/exceptions.py +0 -0
- {jaxsim-0.4.3.dev64 → jaxsim-0.4.3.dev68}/src/jaxsim/integrators/__init__.py +0 -0
- {jaxsim-0.4.3.dev64 → jaxsim-0.4.3.dev68}/src/jaxsim/integrators/fixed_step.py +0 -0
- {jaxsim-0.4.3.dev64 → jaxsim-0.4.3.dev68}/src/jaxsim/integrators/variable_step.py +0 -0
- {jaxsim-0.4.3.dev64 → jaxsim-0.4.3.dev68}/src/jaxsim/logging.py +0 -0
- {jaxsim-0.4.3.dev64 → jaxsim-0.4.3.dev68}/src/jaxsim/math/__init__.py +0 -0
- {jaxsim-0.4.3.dev64 → jaxsim-0.4.3.dev68}/src/jaxsim/math/adjoint.py +0 -0
- {jaxsim-0.4.3.dev64 → jaxsim-0.4.3.dev68}/src/jaxsim/math/cross.py +0 -0
- {jaxsim-0.4.3.dev64 → jaxsim-0.4.3.dev68}/src/jaxsim/math/joint_model.py +0 -0
- {jaxsim-0.4.3.dev64 → jaxsim-0.4.3.dev68}/src/jaxsim/math/quaternion.py +0 -0
- {jaxsim-0.4.3.dev64 → jaxsim-0.4.3.dev68}/src/jaxsim/math/rotation.py +0 -0
- {jaxsim-0.4.3.dev64 → jaxsim-0.4.3.dev68}/src/jaxsim/math/skew.py +0 -0
- {jaxsim-0.4.3.dev64 → jaxsim-0.4.3.dev68}/src/jaxsim/math/transform.py +0 -0
- {jaxsim-0.4.3.dev64 → jaxsim-0.4.3.dev68}/src/jaxsim/mujoco/__init__.py +0 -0
- {jaxsim-0.4.3.dev64 → jaxsim-0.4.3.dev68}/src/jaxsim/mujoco/__main__.py +0 -0
- {jaxsim-0.4.3.dev64 → jaxsim-0.4.3.dev68}/src/jaxsim/mujoco/model.py +0 -0
- {jaxsim-0.4.3.dev64 → jaxsim-0.4.3.dev68}/src/jaxsim/mujoco/visualizer.py +0 -0
- {jaxsim-0.4.3.dev64 → jaxsim-0.4.3.dev68}/src/jaxsim/parsers/__init__.py +0 -0
- {jaxsim-0.4.3.dev64 → jaxsim-0.4.3.dev68}/src/jaxsim/parsers/descriptions/__init__.py +0 -0
- {jaxsim-0.4.3.dev64 → jaxsim-0.4.3.dev68}/src/jaxsim/parsers/descriptions/collision.py +0 -0
- {jaxsim-0.4.3.dev64 → jaxsim-0.4.3.dev68}/src/jaxsim/parsers/descriptions/joint.py +0 -0
- {jaxsim-0.4.3.dev64 → jaxsim-0.4.3.dev68}/src/jaxsim/parsers/descriptions/link.py +0 -0
- {jaxsim-0.4.3.dev64 → jaxsim-0.4.3.dev68}/src/jaxsim/parsers/descriptions/model.py +0 -0
- {jaxsim-0.4.3.dev64 → jaxsim-0.4.3.dev68}/src/jaxsim/parsers/rod/__init__.py +0 -0
- {jaxsim-0.4.3.dev64 → jaxsim-0.4.3.dev68}/src/jaxsim/parsers/rod/utils.py +0 -0
- {jaxsim-0.4.3.dev64 → jaxsim-0.4.3.dev68}/src/jaxsim/rbda/__init__.py +0 -0
- {jaxsim-0.4.3.dev64 → jaxsim-0.4.3.dev68}/src/jaxsim/rbda/aba.py +0 -0
- {jaxsim-0.4.3.dev64 → jaxsim-0.4.3.dev68}/src/jaxsim/rbda/collidable_points.py +0 -0
- {jaxsim-0.4.3.dev64 → jaxsim-0.4.3.dev68}/src/jaxsim/rbda/contacts/__init__.py +0 -0
- {jaxsim-0.4.3.dev64 → jaxsim-0.4.3.dev68}/src/jaxsim/rbda/contacts/common.py +0 -0
- {jaxsim-0.4.3.dev64 → jaxsim-0.4.3.dev68}/src/jaxsim/rbda/contacts/soft.py +0 -0
- {jaxsim-0.4.3.dev64 → jaxsim-0.4.3.dev68}/src/jaxsim/rbda/crba.py +0 -0
- {jaxsim-0.4.3.dev64 → jaxsim-0.4.3.dev68}/src/jaxsim/rbda/forward_kinematics.py +0 -0
- {jaxsim-0.4.3.dev64 → jaxsim-0.4.3.dev68}/src/jaxsim/rbda/jacobian.py +0 -0
- {jaxsim-0.4.3.dev64 → jaxsim-0.4.3.dev68}/src/jaxsim/rbda/rnea.py +0 -0
- {jaxsim-0.4.3.dev64 → jaxsim-0.4.3.dev68}/src/jaxsim/rbda/utils.py +0 -0
- {jaxsim-0.4.3.dev64 → jaxsim-0.4.3.dev68}/src/jaxsim/terrain/__init__.py +0 -0
- {jaxsim-0.4.3.dev64 → jaxsim-0.4.3.dev68}/src/jaxsim/utils/__init__.py +0 -0
- {jaxsim-0.4.3.dev64 → jaxsim-0.4.3.dev68}/src/jaxsim/utils/tracing.py +0 -0
- {jaxsim-0.4.3.dev64 → jaxsim-0.4.3.dev68}/src/jaxsim.egg-info/dependency_links.txt +0 -0
- {jaxsim-0.4.3.dev64 → jaxsim-0.4.3.dev68}/src/jaxsim.egg-info/top_level.txt +0 -0
- {jaxsim-0.4.3.dev64 → jaxsim-0.4.3.dev68}/tests/__init__.py +0 -0
- {jaxsim-0.4.3.dev64 → jaxsim-0.4.3.dev68}/tests/test_api_com.py +0 -0
- {jaxsim-0.4.3.dev64 → jaxsim-0.4.3.dev68}/tests/test_api_contact.py +0 -0
- {jaxsim-0.4.3.dev64 → jaxsim-0.4.3.dev68}/tests/test_api_data.py +0 -0
- {jaxsim-0.4.3.dev64 → jaxsim-0.4.3.dev68}/tests/test_api_frame.py +0 -0
- {jaxsim-0.4.3.dev64 → jaxsim-0.4.3.dev68}/tests/test_api_joint.py +0 -0
- {jaxsim-0.4.3.dev64 → jaxsim-0.4.3.dev68}/tests/test_api_link.py +0 -0
- {jaxsim-0.4.3.dev64 → jaxsim-0.4.3.dev68}/tests/test_api_model.py +0 -0
- {jaxsim-0.4.3.dev64 → jaxsim-0.4.3.dev68}/tests/test_automatic_differentiation.py +0 -0
- {jaxsim-0.4.3.dev64 → jaxsim-0.4.3.dev68}/tests/test_contact.py +0 -0
- {jaxsim-0.4.3.dev64 → jaxsim-0.4.3.dev68}/tests/test_exceptions.py +0 -0
- {jaxsim-0.4.3.dev64 → jaxsim-0.4.3.dev68}/tests/test_pytree.py +0 -0
- {jaxsim-0.4.3.dev64 → jaxsim-0.4.3.dev68}/tests/test_simulations.py +0 -0
- {jaxsim-0.4.3.dev64 → jaxsim-0.4.3.dev68}/tests/utils_idyntree.py +0 -0
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.1
|
2
2
|
Name: jaxsim
|
3
|
-
Version: 0.4.3.
|
3
|
+
Version: 0.4.3.dev68
|
4
4
|
Summary: A differentiable physics engine and multibody dynamics library for control and robot learning.
|
5
5
|
Author-email: Diego Ferigo <dgferigo@gmail.com>
|
6
6
|
Maintainer-email: Diego Ferigo <dgferigo@gmail.com>, Filippo Luca Ferretti <filippo.ferretti@iit.it>
|
@@ -61,7 +61,6 @@ Description-Content-Type: text/markdown
|
|
61
61
|
License-File: LICENSE
|
62
62
|
Requires-Dist: coloredlogs
|
63
63
|
Requires-Dist: jax>=0.4.13
|
64
|
-
Requires-Dist: jaxopt>=0.8.0
|
65
64
|
Requires-Dist: jaxlib>=0.4.13
|
66
65
|
Requires-Dist: jaxlie>=1.3.0
|
67
66
|
Requires-Dist: jax_dataclasses>=1.4.0
|
@@ -45,7 +45,6 @@ classifiers = [
|
|
45
45
|
dependencies = [
|
46
46
|
"coloredlogs",
|
47
47
|
"jax >= 0.4.13",
|
48
|
-
"jaxopt >= 0.8.0",
|
49
48
|
"jaxlib >= 0.4.13",
|
50
49
|
"jaxlie >= 1.3.0",
|
51
50
|
"jax_dataclasses >= 1.4.0",
|
@@ -182,7 +181,6 @@ platforms = ["linux-64", "osx-arm64", "osx-64"]
|
|
182
181
|
coloredlogs = "*"
|
183
182
|
jax = "*"
|
184
183
|
jax-dataclasses = "*"
|
185
|
-
jaxopt = "*"
|
186
184
|
jaxlib = "*"
|
187
185
|
jaxlie = "*"
|
188
186
|
lxml = "*"
|
@@ -20,11 +20,6 @@ def _jnp_options() -> None:
|
|
20
20
|
if jnp.empty(0, dtype=float).dtype != jnp.empty(0, dtype=np.float64).dtype:
|
21
21
|
logging.warning("Failed to enable 64bit precision in JAX")
|
22
22
|
|
23
|
-
else:
|
24
|
-
logging.warning(
|
25
|
-
"Using 32bit precision in JaxSim is still experimental, please avoid to use variable step integrators."
|
26
|
-
)
|
27
|
-
|
28
23
|
|
29
24
|
def _np_options() -> None:
|
30
25
|
import numpy as np
|
@@ -12,5 +12,5 @@ __version__: str
|
|
12
12
|
__version_tuple__: VERSION_TUPLE
|
13
13
|
version_tuple: VERSION_TUPLE
|
14
14
|
|
15
|
-
__version__ = version = '0.4.3.
|
16
|
-
__version_tuple__ = version_tuple = (0, 4, 3, '
|
15
|
+
__version__ = version = '0.4.3.dev68'
|
16
|
+
__version_tuple__ = version_tuple = (0, 4, 3, 'dev68')
|
@@ -131,8 +131,7 @@ def collidable_point_dynamics(
|
|
131
131
|
Returns:
|
132
132
|
The 6D force applied to each collidable point and additional data based on the contact model configured:
|
133
133
|
- Soft: the material deformation rate.
|
134
|
-
- Rigid:
|
135
|
-
- QuasiRigid: no additional data.
|
134
|
+
- Rigid: nothing.
|
136
135
|
|
137
136
|
Note:
|
138
137
|
The material deformation rate is always returned in the mixed frame
|
@@ -145,10 +144,6 @@ def collidable_point_dynamics(
|
|
145
144
|
W_p_Ci, W_ṗ_Ci = js.contact.collidable_point_kinematics(model=model, data=data)
|
146
145
|
|
147
146
|
# Import privately the contacts classes.
|
148
|
-
from jaxsim.rbda.contacts.relaxed_rigid import (
|
149
|
-
RelaxedRigidContacts,
|
150
|
-
RelaxedRigidContactsState,
|
151
|
-
)
|
152
147
|
from jaxsim.rbda.contacts.rigid import RigidContacts, RigidContactsState
|
153
148
|
from jaxsim.rbda.contacts.soft import SoftContacts, SoftContactsState
|
154
149
|
|
@@ -195,27 +190,6 @@ def collidable_point_dynamics(
|
|
195
190
|
|
196
191
|
aux_data = dict()
|
197
192
|
|
198
|
-
case RelaxedRigidContacts():
|
199
|
-
assert isinstance(model.contact_model, RelaxedRigidContacts)
|
200
|
-
assert isinstance(data.state.contact, RelaxedRigidContactsState)
|
201
|
-
|
202
|
-
# Build the contact model.
|
203
|
-
relaxed_rigid_contacts = RelaxedRigidContacts(
|
204
|
-
parameters=data.contacts_params, terrain=model.terrain
|
205
|
-
)
|
206
|
-
|
207
|
-
# Compute the 6D force expressed in the inertial frame and applied to each
|
208
|
-
# collidable point.
|
209
|
-
W_f_Ci, _ = relaxed_rigid_contacts.compute_contact_forces(
|
210
|
-
position=W_p_Ci,
|
211
|
-
velocity=W_ṗ_Ci,
|
212
|
-
model=model,
|
213
|
-
data=data,
|
214
|
-
link_forces=link_forces,
|
215
|
-
)
|
216
|
-
|
217
|
-
aux_data = dict()
|
218
|
-
|
219
193
|
case _:
|
220
194
|
raise ValueError(f"Invalid contact model {model.contact_model}")
|
221
195
|
|
@@ -39,9 +39,7 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
|
|
39
39
|
contacts_params: jaxsim.rbda.ContactsParams = dataclasses.field(repr=False)
|
40
40
|
|
41
41
|
time_ns: jtp.Int = dataclasses.field(
|
42
|
-
default_factory=lambda: jnp.array(
|
43
|
-
0, dtype=jnp.uint64 if jax.config.read("jax_enable_x64") else jnp.uint32
|
44
|
-
),
|
42
|
+
default_factory=lambda: jnp.array(0, dtype=jnp.uint64)
|
45
43
|
)
|
46
44
|
|
47
45
|
def __hash__(self) -> int:
|
@@ -174,14 +172,9 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
|
|
174
172
|
)
|
175
173
|
|
176
174
|
time_ns = (
|
177
|
-
jnp.array(
|
178
|
-
time * 1e9,
|
179
|
-
dtype=jnp.uint64 if jax.config.read("jax_enable_x64") else jnp.uint32,
|
180
|
-
)
|
175
|
+
jnp.array(time * 1e9, dtype=jnp.uint64)
|
181
176
|
if time is not None
|
182
|
-
else jnp.array(
|
183
|
-
0, dtype=jnp.uint64 if jax.config.read("jax_enable_x64") else jnp.uint32
|
184
|
-
)
|
177
|
+
else jnp.array(0, dtype=jnp.uint64)
|
185
178
|
)
|
186
179
|
|
187
180
|
if isinstance(model.contact_model, SoftContacts):
|
@@ -593,18 +586,16 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
|
|
593
586
|
The updated `JaxSimModelData` object.
|
594
587
|
"""
|
595
588
|
|
596
|
-
|
597
|
-
|
598
|
-
W_Q_B = jax.lax.select(
|
599
|
-
pred=jnp.allclose(jnp.linalg.norm(W_Q_B), 1.0, atol=1e-6, rtol=0.0),
|
600
|
-
on_true=W_Q_B,
|
601
|
-
on_false=W_Q_B / jnp.linalg.norm(W_Q_B),
|
602
|
-
)
|
589
|
+
base_quaternion = jnp.array(base_quaternion)
|
603
590
|
|
604
591
|
return self.replace(
|
605
592
|
validate=True,
|
606
593
|
state=self.state.replace(
|
607
|
-
physics_model=self.state.physics_model.replace(
|
594
|
+
physics_model=self.state.physics_model.replace(
|
595
|
+
base_quaternion=jnp.atleast_1d(base_quaternion.squeeze()).astype(
|
596
|
+
float
|
597
|
+
)
|
598
|
+
)
|
608
599
|
),
|
609
600
|
)
|
610
601
|
|
@@ -746,13 +737,6 @@ def random_model_data(
|
|
746
737
|
jtp.FloatLike | Sequence[jtp.FloatLike],
|
747
738
|
jtp.FloatLike | Sequence[jtp.FloatLike],
|
748
739
|
] = ((-1, -1, 0.5), 1.0),
|
749
|
-
joint_pos_bounds: (
|
750
|
-
tuple[
|
751
|
-
jtp.FloatLike | Sequence[jtp.FloatLike],
|
752
|
-
jtp.FloatLike | Sequence[jtp.FloatLike],
|
753
|
-
]
|
754
|
-
| None
|
755
|
-
) = None,
|
756
740
|
base_vel_lin_bounds: tuple[
|
757
741
|
jtp.FloatLike | Sequence[jtp.FloatLike],
|
758
742
|
jtp.FloatLike | Sequence[jtp.FloatLike],
|
@@ -778,8 +762,6 @@ def random_model_data(
|
|
778
762
|
key: The random key.
|
779
763
|
velocity_representation: The velocity representation to use.
|
780
764
|
base_pos_bounds: The bounds for the base position.
|
781
|
-
joint_pos_bounds:
|
782
|
-
The bounds for the joint positions (reading the joint limits if None).
|
783
765
|
base_vel_lin_bounds: The bounds for the base linear velocity.
|
784
766
|
base_vel_ang_bounds: The bounds for the base angular velocity.
|
785
767
|
joint_vel_bounds: The bounds for the joint velocities.
|
@@ -824,19 +806,8 @@ def random_model_data(
|
|
824
806
|
).wxyz
|
825
807
|
|
826
808
|
if model.number_of_joints() > 0:
|
827
|
-
|
828
|
-
|
829
|
-
jnp.array(joint_pos_bounds, dtype=float)
|
830
|
-
if joint_pos_bounds is not None
|
831
|
-
else (None, None)
|
832
|
-
)
|
833
|
-
|
834
|
-
physics_model_state.joint_positions = (
|
835
|
-
js.joint.random_joint_positions(model=model, key=k3)
|
836
|
-
if (s_min is None or s_max is None)
|
837
|
-
else jax.random.uniform(
|
838
|
-
key=k3, shape=(model.dofs(),), minval=s_min, maxval=s_max
|
839
|
-
)
|
809
|
+
physics_model_state.joint_positions = js.joint.random_joint_positions(
|
810
|
+
model=model, key=k3
|
840
811
|
)
|
841
812
|
|
842
813
|
physics_model_state.joint_velocities = jax.random.uniform(
|
@@ -180,77 +180,17 @@ def random_joint_positions(
|
|
180
180
|
|
181
181
|
Args:
|
182
182
|
model: The model to consider.
|
183
|
-
joint_names: The names of the
|
184
|
-
key: The random key
|
185
|
-
|
186
|
-
Note:
|
187
|
-
If the joint range or revolute joints is larger than 2π, their joint positions
|
188
|
-
will be sampled from an interval of size 2π.
|
183
|
+
joint_names: The names of the joints.
|
184
|
+
key: The random key.
|
189
185
|
|
190
186
|
Returns:
|
191
187
|
The random joint positions.
|
192
188
|
"""
|
193
189
|
|
194
|
-
# Consider the key corresponding to a zero seed if it was not passed.
|
195
190
|
key = key if key is not None else jax.random.PRNGKey(seed=0)
|
196
191
|
|
197
|
-
# Get the joint limits parsed from the model description.
|
198
192
|
s_min, s_max = position_limits(model=model, joint_names=joint_names)
|
199
193
|
|
200
|
-
# Get the joint indices.
|
201
|
-
# Note that it will trigger an exception if the given `joint_names` are not valid.
|
202
|
-
joint_names = joint_names if joint_names is not None else model.joint_names()
|
203
|
-
joint_indices = names_to_idxs(model=model, joint_names=joint_names)
|
204
|
-
|
205
|
-
from jaxsim.parsers.descriptions.joint import JointType
|
206
|
-
|
207
|
-
# Filter for revolute joints.
|
208
|
-
is_revolute = jnp.where(
|
209
|
-
jnp.array(model.kin_dyn_parameters.joint_model.joint_types[1:])[joint_indices]
|
210
|
-
== JointType.Revolute,
|
211
|
-
True,
|
212
|
-
False,
|
213
|
-
)
|
214
|
-
|
215
|
-
# Shorthand for π.
|
216
|
-
π = jnp.pi
|
217
|
-
|
218
|
-
# Filter for revolute with full range (or continuous).
|
219
|
-
is_revolute_full_range = jnp.logical_and(is_revolute, s_max - s_min >= 2 * π)
|
220
|
-
|
221
|
-
# Clip the lower limit to -π if the joint range is larger than [-π, π].
|
222
|
-
s_min = jnp.where(
|
223
|
-
jnp.logical_and(
|
224
|
-
is_revolute_full_range, jnp.logical_and(s_min <= -π, s_max >= π)
|
225
|
-
),
|
226
|
-
-π,
|
227
|
-
s_min,
|
228
|
-
)
|
229
|
-
|
230
|
-
# Clip the upper limit to +π if the joint range is larger than [-π, π].
|
231
|
-
s_max = jnp.where(
|
232
|
-
jnp.logical_and(
|
233
|
-
is_revolute_full_range, jnp.logical_and(s_min <= -π, s_max >= π)
|
234
|
-
),
|
235
|
-
π,
|
236
|
-
s_max,
|
237
|
-
)
|
238
|
-
|
239
|
-
# Shift the lower limit if the upper limit is smaller than +π.
|
240
|
-
s_min = jnp.where(
|
241
|
-
jnp.logical_and(is_revolute_full_range, s_max < π),
|
242
|
-
s_max - 2 * π,
|
243
|
-
s_min,
|
244
|
-
)
|
245
|
-
|
246
|
-
# Shift the upper limit if the lower limit is larger than -π.
|
247
|
-
s_max = jnp.where(
|
248
|
-
jnp.logical_and(is_revolute_full_range, s_min > -π),
|
249
|
-
s_min + 2 * π,
|
250
|
-
s_max,
|
251
|
-
)
|
252
|
-
|
253
|
-
# Sample the joint positions.
|
254
194
|
s_random = jax.random.uniform(
|
255
195
|
minval=s_min,
|
256
196
|
maxval=s_max,
|
@@ -1931,22 +1931,11 @@ def step(
|
|
1931
1931
|
),
|
1932
1932
|
)
|
1933
1933
|
|
1934
|
-
tf_ns = t0_ns + jnp.array(dt * 1e9, dtype=t0_ns.dtype)
|
1935
|
-
tf_ns = jnp.where(tf_ns >= t0_ns, tf_ns, jnp.array(0, dtype=t0_ns.dtype))
|
1936
|
-
|
1937
|
-
jax.lax.cond(
|
1938
|
-
pred=tf_ns < t0_ns,
|
1939
|
-
true_fun=lambda: jax.debug.print(
|
1940
|
-
"The simulation time overflowed, resetting simulation time to 0."
|
1941
|
-
),
|
1942
|
-
false_fun=lambda: None,
|
1943
|
-
)
|
1944
|
-
|
1945
1934
|
data_tf = (
|
1946
1935
|
# Store the new state of the model and the new time.
|
1947
1936
|
data.replace(
|
1948
1937
|
state=state_tf,
|
1949
|
-
time_ns=
|
1938
|
+
time_ns=t0_ns + jnp.array(dt * 1e9).astype(jnp.uint64),
|
1950
1939
|
)
|
1951
1940
|
)
|
1952
1941
|
|
@@ -175,15 +175,17 @@ def system_velocity_dynamics(
|
|
175
175
|
forces=W_f_Li_terrain,
|
176
176
|
additive=True,
|
177
177
|
)
|
178
|
-
|
179
|
-
|
178
|
+
# Get the link forces in the data representation
|
179
|
+
with references.switch_velocity_representation(data.velocity_representation):
|
180
180
|
f_L_total = references.link_forces(model=model, data=data)
|
181
181
|
|
182
|
-
|
183
|
-
|
184
|
-
|
182
|
+
# The following method always returns the inertial-fixed acceleration, and expects
|
183
|
+
# the link_forces expressed in the inertial frame.
|
184
|
+
W_v̇_WB, s̈ = system_acceleration(
|
185
|
+
model=model, data=data, joint_forces=joint_forces, link_forces=f_L_total
|
186
|
+
)
|
185
187
|
|
186
|
-
return
|
188
|
+
return W_v̇_WB, s̈, aux_data
|
187
189
|
|
188
190
|
|
189
191
|
def system_acceleration(
|
@@ -194,7 +196,7 @@ def system_acceleration(
|
|
194
196
|
link_forces: jtp.MatrixLike | None = None,
|
195
197
|
) -> tuple[jtp.Vector, jtp.Vector]:
|
196
198
|
"""
|
197
|
-
Compute the system acceleration in
|
199
|
+
Compute the system acceleration in inertial-fixed representation.
|
198
200
|
|
199
201
|
Args:
|
200
202
|
model: The model to consider.
|
@@ -204,7 +206,7 @@ def system_acceleration(
|
|
204
206
|
The 6D forces to apply to the links expressed in the same representation of data.
|
205
207
|
|
206
208
|
Returns:
|
207
|
-
A tuple containing the base 6D acceleration in
|
209
|
+
A tuple containing the base 6D acceleration in inertial-fixed representation
|
208
210
|
and the joint accelerations.
|
209
211
|
"""
|
210
212
|
|
@@ -270,15 +272,18 @@ def system_acceleration(
|
|
270
272
|
)
|
271
273
|
|
272
274
|
# - Joint accelerations: s̈ ∈ ℝⁿ
|
273
|
-
# - Base acceleration:
|
274
|
-
|
275
|
-
|
276
|
-
|
277
|
-
|
278
|
-
|
279
|
-
|
280
|
-
|
281
|
-
|
275
|
+
# - Base inertial-fixed acceleration: W_v̇_WB = (W_p̈_B, W_ω̇_B) ∈ ℝ⁶
|
276
|
+
with (
|
277
|
+
data.switch_velocity_representation(velocity_representation=VelRepr.Inertial),
|
278
|
+
references.switch_velocity_representation(VelRepr.Inertial),
|
279
|
+
):
|
280
|
+
W_v̇_WB, s̈ = js.model.forward_dynamics_aba(
|
281
|
+
model=model,
|
282
|
+
data=data,
|
283
|
+
joint_forces=references.joint_force_references(),
|
284
|
+
link_forces=references.link_forces(),
|
285
|
+
)
|
286
|
+
return W_v̇_WB, s̈
|
282
287
|
|
283
288
|
|
284
289
|
@jax.jit
|
@@ -348,7 +353,7 @@ def system_dynamics(
|
|
348
353
|
corresponding derivative, and the dictionary of auxiliary data returned
|
349
354
|
by the system dynamics evaluation.
|
350
355
|
"""
|
351
|
-
|
356
|
+
|
352
357
|
from jaxsim.rbda.contacts.rigid import RigidContacts
|
353
358
|
from jaxsim.rbda.contacts.soft import SoftContacts
|
354
359
|
|
@@ -366,7 +371,7 @@ def system_dynamics(
|
|
366
371
|
case SoftContacts():
|
367
372
|
ode_state_kwargs["tangential_deformation"] = aux_dict["m_dot"]
|
368
373
|
|
369
|
-
case RigidContacts()
|
374
|
+
case RigidContacts():
|
370
375
|
pass
|
371
376
|
|
372
377
|
case _:
|
@@ -6,10 +6,6 @@ import jax_dataclasses
|
|
6
6
|
import jaxsim.api as js
|
7
7
|
import jaxsim.typing as jtp
|
8
8
|
from jaxsim.rbda import ContactsState
|
9
|
-
from jaxsim.rbda.contacts.relaxed_rigid import (
|
10
|
-
RelaxedRigidContacts,
|
11
|
-
RelaxedRigidContactsState,
|
12
|
-
)
|
13
9
|
from jaxsim.rbda.contacts.rigid import RigidContacts, RigidContactsState
|
14
10
|
from jaxsim.rbda.contacts.soft import SoftContacts, SoftContactsState
|
15
11
|
from jaxsim.utils import JaxsimDataclass
|
@@ -177,10 +173,6 @@ class ODEState(JaxsimDataclass):
|
|
177
173
|
)
|
178
174
|
case RigidContacts():
|
179
175
|
contact = RigidContactsState.build()
|
180
|
-
|
181
|
-
case RelaxedRigidContacts():
|
182
|
-
contact = RelaxedRigidContactsState.build()
|
183
|
-
|
184
176
|
case _:
|
185
177
|
raise ValueError("Unable to determine contact state class prefix.")
|
186
178
|
|
@@ -224,9 +216,7 @@ class ODEState(JaxsimDataclass):
|
|
224
216
|
|
225
217
|
# Get the contact model from the `JaxSimModel`.
|
226
218
|
match contact:
|
227
|
-
case (
|
228
|
-
SoftContactsState() | RigidContactsState() | RelaxedRigidContactsState()
|
229
|
-
):
|
219
|
+
case SoftContactsState() | RigidContactsState():
|
230
220
|
pass
|
231
221
|
case None:
|
232
222
|
contact = SoftContactsState.zero(model=model)
|
@@ -497,7 +497,7 @@ class ExplicitRungeKutta(Integrator[PyTreeType, PyTreeType], Generic[PyTreeType]
|
|
497
497
|
b: jtp.Matrix,
|
498
498
|
c: jtp.Vector,
|
499
499
|
index_of_solution: jtp.IntLike = 0,
|
500
|
-
) ->
|
500
|
+
) -> [bool, int | None]:
|
501
501
|
"""
|
502
502
|
Check if the Butcher tableau supports the FSAL (first-same-as-last) property.
|
503
503
|
|
@@ -45,7 +45,7 @@ class Inertia:
|
|
45
45
|
M (jtp.Matrix): The 6x6 inertia matrix.
|
46
46
|
|
47
47
|
Returns:
|
48
|
-
|
48
|
+
Tuple[jtp.Float, jtp.Vector, jtp.Matrix]: A tuple containing mass, center of mass (3D), and inertia matrix (3x3).
|
49
49
|
|
50
50
|
Raises:
|
51
51
|
ValueError: If the input matrix M has an unexpected shape.
|
@@ -211,7 +211,7 @@ class RodModelToMjcf:
|
|
211
211
|
joints_dict = {j.name: j for j in rod_model.joints()}
|
212
212
|
|
213
213
|
# Convert all the joints not considered to fixed joints.
|
214
|
-
for joint_name in
|
214
|
+
for joint_name in set(j.name for j in rod_model.joints()) - considered_joints:
|
215
215
|
joints_dict[joint_name].type = "fixed"
|
216
216
|
|
217
217
|
# Convert the ROD model to URDF.
|
@@ -289,10 +289,10 @@ class RodModelToMjcf:
|
|
289
289
|
mj_model = mj.MjModel.from_xml_string(xml=urdf_string, assets=assets)
|
290
290
|
|
291
291
|
# Get the joint names.
|
292
|
-
mj_joint_names =
|
292
|
+
mj_joint_names = set(
|
293
293
|
mj.mj_id2name(mj_model, mj.mjtObj.mjOBJ_JOINT, idx)
|
294
294
|
for idx in range(mj_model.njnt)
|
295
|
-
|
295
|
+
)
|
296
296
|
|
297
297
|
# Check that the Mujoco model only has the considered joints.
|
298
298
|
if mj_joint_names != considered_joints:
|
@@ -394,7 +394,7 @@ class KinematicGraph(Sequence[LinkDescription]):
|
|
394
394
|
return copy.deepcopy(self)
|
395
395
|
|
396
396
|
# Check if all considered joints are part of the full kinematic graph
|
397
|
-
if len(set(considered_joints) -
|
397
|
+
if len(set(considered_joints) - set(j.name for j in full_graph.joints)) != 0:
|
398
398
|
extra_j = set(considered_joints) - {j.name for j in full_graph.joints}
|
399
399
|
msg = f"Not all joints to consider are part of the graph ({{{extra_j}}})"
|
400
400
|
raise ValueError(msg)
|
@@ -536,8 +536,8 @@ class KinematicGraph(Sequence[LinkDescription]):
|
|
536
536
|
root_link_name=full_graph.root.name,
|
537
537
|
)
|
538
538
|
|
539
|
-
assert
|
540
|
-
|
539
|
+
assert set(f.name for f in self.frames).isdisjoint(
|
540
|
+
set(f.name for f in unconnected_frames + reduced_frames)
|
541
541
|
)
|
542
542
|
|
543
543
|
for link in unconnected_links:
|
@@ -223,7 +223,7 @@ def extract_model_data(
|
|
223
223
|
child=links_dict[j.child],
|
224
224
|
jtype=utils.joint_to_joint_type(joint=j),
|
225
225
|
axis=(
|
226
|
-
np.array(j.axis.xyz.xyz
|
226
|
+
np.array(j.axis.xyz.xyz)
|
227
227
|
if j.axis is not None
|
228
228
|
and j.axis.xyz is not None
|
229
229
|
and j.axis.xyz.xyz is not None
|
@@ -232,43 +232,39 @@ def extract_model_data(
|
|
232
232
|
pose=j.pose.transform() if j.pose is not None else np.eye(4),
|
233
233
|
initial_position=0.0,
|
234
234
|
position_limit=(
|
235
|
-
|
236
|
-
j.axis.limit.lower
|
237
|
-
if j.axis is not None
|
238
|
-
|
239
|
-
and j.axis.limit.lower is not None
|
240
|
-
else jnp.finfo(float).min
|
235
|
+
(
|
236
|
+
float(j.axis.limit.lower)
|
237
|
+
if j.axis is not None and j.axis.limit is not None
|
238
|
+
else np.finfo(float).min
|
241
239
|
),
|
242
|
-
|
243
|
-
j.axis.limit.upper
|
244
|
-
if j.axis is not None
|
245
|
-
|
246
|
-
and j.axis.limit.upper is not None
|
247
|
-
else jnp.finfo(float).max
|
240
|
+
(
|
241
|
+
float(j.axis.limit.upper)
|
242
|
+
if j.axis is not None and j.axis.limit is not None
|
243
|
+
else np.finfo(float).max
|
248
244
|
),
|
249
245
|
),
|
250
|
-
friction_static=
|
246
|
+
friction_static=(
|
251
247
|
j.axis.dynamics.friction
|
252
248
|
if j.axis is not None
|
253
249
|
and j.axis.dynamics is not None
|
254
250
|
and j.axis.dynamics.friction is not None
|
255
251
|
else 0.0
|
256
252
|
),
|
257
|
-
friction_viscous=
|
253
|
+
friction_viscous=(
|
258
254
|
j.axis.dynamics.damping
|
259
255
|
if j.axis is not None
|
260
256
|
and j.axis.dynamics is not None
|
261
257
|
and j.axis.dynamics.damping is not None
|
262
258
|
else 0.0
|
263
259
|
),
|
264
|
-
position_limit_damper=
|
260
|
+
position_limit_damper=(
|
265
261
|
j.axis.limit.dissipation
|
266
262
|
if j.axis is not None
|
267
263
|
and j.axis.limit is not None
|
268
264
|
and j.axis.limit.dissipation is not None
|
269
265
|
else 0.0
|
270
266
|
),
|
271
|
-
position_limit_spring=
|
267
|
+
position_limit_spring=(
|
272
268
|
j.axis.limit.stiffness
|
273
269
|
if j.axis is not None
|
274
270
|
and j.axis.limit is not None
|
@@ -277,7 +273,7 @@ def extract_model_data(
|
|
277
273
|
),
|
278
274
|
)
|
279
275
|
for j in sdf_model.joints()
|
280
|
-
if j.type in {"revolute", "
|
276
|
+
if j.type in {"revolute", "prismatic", "fixed"}
|
281
277
|
and j.parent != "world"
|
282
278
|
and j.child in links_dict.keys()
|
283
279
|
]
|
@@ -9,6 +9,7 @@ import jax_dataclasses
|
|
9
9
|
|
10
10
|
import jaxsim.api as js
|
11
11
|
import jaxsim.typing as jtp
|
12
|
+
from jaxsim import math
|
12
13
|
from jaxsim.api.common import ModelDataWithVelocityRepresentation, VelRepr
|
13
14
|
from jaxsim.terrain import FlatTerrain, Terrain
|
14
15
|
|
@@ -271,17 +272,9 @@ class RigidContacts(ContactModel):
|
|
271
272
|
link_forces=link_forces,
|
272
273
|
)
|
273
274
|
|
274
|
-
with (
|
275
|
-
|
276
|
-
|
277
|
-
):
|
278
|
-
BW_ν̇_free = jnp.hstack(
|
279
|
-
js.ode.system_acceleration(
|
280
|
-
model=model,
|
281
|
-
data=data,
|
282
|
-
joint_forces=references.joint_force_references(model=model),
|
283
|
-
link_forces=references.link_forces(model=model, data=data),
|
284
|
-
)
|
275
|
+
with references.switch_velocity_representation(VelRepr.Mixed):
|
276
|
+
BW_ν̇_free = RigidContacts._compute_mixed_nu_dot_free(
|
277
|
+
model, data, references=references
|
285
278
|
)
|
286
279
|
|
287
280
|
free_contact_acc = RigidContacts._linear_acceleration_of_collidable_points(
|
@@ -387,6 +380,43 @@ class RigidContacts(ContactModel):
|
|
387
380
|
n_constraints = 6 * n_collidable_points
|
388
381
|
return jnp.zeros(shape=(n_constraints,))
|
389
382
|
|
383
|
+
@staticmethod
|
384
|
+
def _compute_mixed_nu_dot_free(
|
385
|
+
model: js.model.JaxSimModel,
|
386
|
+
data: js.data.JaxSimModelData,
|
387
|
+
references: js.references.JaxSimModelReferences | None = None,
|
388
|
+
) -> jtp.Array:
|
389
|
+
references = (
|
390
|
+
references
|
391
|
+
if references is not None
|
392
|
+
else js.references.JaxSimModelReferences.zero(model=model, data=data)
|
393
|
+
)
|
394
|
+
|
395
|
+
with (
|
396
|
+
data.switch_velocity_representation(VelRepr.Mixed),
|
397
|
+
references.switch_velocity_representation(VelRepr.Mixed),
|
398
|
+
):
|
399
|
+
BW_v_WB = data.base_velocity()
|
400
|
+
W_ṗ_B, W_ω_WB = jnp.split(BW_v_WB, 2)
|
401
|
+
W_v̇_WB, s̈ = js.ode.system_acceleration(
|
402
|
+
model=model,
|
403
|
+
data=data,
|
404
|
+
joint_forces=references.joint_force_references(model=model),
|
405
|
+
link_forces=references.link_forces(model=model, data=data),
|
406
|
+
)
|
407
|
+
|
408
|
+
# Convert the inertial-fixed base acceleration to a mixed base acceleration.
|
409
|
+
W_H_B = data.base_transform()
|
410
|
+
W_H_BW = W_H_B.at[0:3, 0:3].set(jnp.eye(3))
|
411
|
+
BW_X_W = math.Adjoint.from_transform(W_H_BW, inverse=True)
|
412
|
+
term1 = BW_X_W @ W_v̇_WB
|
413
|
+
term2 = jnp.zeros(6).at[0:3].set(jnp.cross(W_ṗ_B, W_ω_WB))
|
414
|
+
BW_v̇_WB = term1 - term2
|
415
|
+
|
416
|
+
BW_ν̇ = jnp.hstack([BW_v̇_WB, s̈])
|
417
|
+
|
418
|
+
return BW_ν̇
|
419
|
+
|
390
420
|
@staticmethod
|
391
421
|
def _linear_acceleration_of_collidable_points(
|
392
422
|
model: js.model.JaxSimModel,
|