jaxsim 0.2.dev191__py3-none-any.whl → 0.6.1.dev2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jaxsim/__init__.py +73 -22
- jaxsim/_version.py +2 -2
- jaxsim/api/__init__.py +13 -1
- jaxsim/api/com.py +423 -0
- jaxsim/api/common.py +48 -19
- jaxsim/api/contact.py +604 -52
- jaxsim/api/data.py +308 -163
- jaxsim/api/frame.py +471 -0
- jaxsim/api/joint.py +166 -37
- jaxsim/api/kin_dyn_parameters.py +901 -0
- jaxsim/api/link.py +277 -78
- jaxsim/api/model.py +1572 -362
- jaxsim/api/ode.py +324 -133
- jaxsim/api/ode_data.py +401 -0
- jaxsim/api/references.py +216 -80
- jaxsim/exceptions.py +80 -0
- jaxsim/integrators/__init__.py +2 -2
- jaxsim/integrators/common.py +191 -107
- jaxsim/integrators/fixed_step.py +97 -102
- jaxsim/integrators/variable_step.py +706 -0
- jaxsim/logging.py +1 -2
- jaxsim/math/__init__.py +13 -0
- jaxsim/math/adjoint.py +57 -22
- jaxsim/math/cross.py +16 -7
- jaxsim/math/inertia.py +10 -8
- jaxsim/math/joint_model.py +289 -0
- jaxsim/math/quaternion.py +54 -20
- jaxsim/math/rotation.py +27 -21
- jaxsim/math/skew.py +16 -5
- jaxsim/math/transform.py +102 -0
- jaxsim/math/utils.py +31 -0
- jaxsim/mujoco/__init__.py +2 -1
- jaxsim/mujoco/loaders.py +216 -29
- jaxsim/mujoco/model.py +163 -33
- jaxsim/mujoco/utils.py +228 -0
- jaxsim/mujoco/visualizer.py +107 -22
- jaxsim/parsers/__init__.py +0 -1
- jaxsim/parsers/descriptions/__init__.py +8 -2
- jaxsim/parsers/descriptions/collision.py +83 -26
- jaxsim/parsers/descriptions/joint.py +80 -87
- jaxsim/parsers/descriptions/link.py +58 -31
- jaxsim/parsers/descriptions/model.py +101 -68
- jaxsim/parsers/kinematic_graph.py +606 -229
- jaxsim/parsers/rod/meshes.py +104 -0
- jaxsim/parsers/rod/parser.py +125 -82
- jaxsim/parsers/rod/utils.py +127 -82
- jaxsim/rbda/__init__.py +11 -0
- jaxsim/rbda/aba.py +289 -0
- jaxsim/rbda/collidable_points.py +156 -0
- jaxsim/rbda/contacts/__init__.py +13 -0
- jaxsim/rbda/contacts/common.py +313 -0
- jaxsim/rbda/contacts/relaxed_rigid.py +605 -0
- jaxsim/rbda/contacts/rigid.py +462 -0
- jaxsim/rbda/contacts/soft.py +480 -0
- jaxsim/rbda/contacts/visco_elastic.py +1066 -0
- jaxsim/rbda/crba.py +167 -0
- jaxsim/rbda/forward_kinematics.py +117 -0
- jaxsim/rbda/jacobian.py +330 -0
- jaxsim/rbda/rnea.py +235 -0
- jaxsim/rbda/utils.py +160 -0
- jaxsim/terrain/__init__.py +2 -0
- jaxsim/terrain/terrain.py +238 -0
- jaxsim/typing.py +24 -24
- jaxsim/utils/__init__.py +1 -4
- jaxsim/utils/jaxsim_dataclass.py +289 -34
- jaxsim/utils/tracing.py +5 -11
- jaxsim/utils/wrappers.py +159 -0
- {jaxsim-0.2.dev191.dist-info → jaxsim-0.6.1.dev2.dist-info}/LICENSE +1 -1
- jaxsim-0.6.1.dev2.dist-info/METADATA +465 -0
- jaxsim-0.6.1.dev2.dist-info/RECORD +74 -0
- {jaxsim-0.2.dev191.dist-info → jaxsim-0.6.1.dev2.dist-info}/WHEEL +1 -1
- jaxsim/high_level/__init__.py +0 -2
- jaxsim/high_level/common.py +0 -11
- jaxsim/high_level/joint.py +0 -148
- jaxsim/high_level/link.py +0 -259
- jaxsim/high_level/model.py +0 -1686
- jaxsim/math/conv.py +0 -114
- jaxsim/math/joint.py +0 -102
- jaxsim/math/plucker.py +0 -100
- jaxsim/physics/__init__.py +0 -12
- jaxsim/physics/algos/__init__.py +0 -0
- jaxsim/physics/algos/aba.py +0 -254
- jaxsim/physics/algos/aba_motors.py +0 -284
- jaxsim/physics/algos/crba.py +0 -154
- jaxsim/physics/algos/forward_kinematics.py +0 -79
- jaxsim/physics/algos/jacobian.py +0 -98
- jaxsim/physics/algos/rnea.py +0 -180
- jaxsim/physics/algos/rnea_motors.py +0 -196
- jaxsim/physics/algos/soft_contacts.py +0 -523
- jaxsim/physics/algos/terrain.py +0 -78
- jaxsim/physics/algos/utils.py +0 -69
- jaxsim/physics/model/__init__.py +0 -0
- jaxsim/physics/model/ground_contact.py +0 -53
- jaxsim/physics/model/physics_model.py +0 -388
- jaxsim/physics/model/physics_model_state.py +0 -283
- jaxsim/simulation/__init__.py +0 -4
- jaxsim/simulation/integrators.py +0 -393
- jaxsim/simulation/ode.py +0 -290
- jaxsim/simulation/ode_data.py +0 -96
- jaxsim/simulation/ode_integration.py +0 -62
- jaxsim/simulation/simulator.py +0 -543
- jaxsim/simulation/simulator_callbacks.py +0 -79
- jaxsim/simulation/utils.py +0 -15
- jaxsim/sixd/__init__.py +0 -2
- jaxsim/utils/oop.py +0 -536
- jaxsim/utils/vmappable.py +0 -117
- jaxsim-0.2.dev191.dist-info/METADATA +0 -184
- jaxsim-0.2.dev191.dist-info/RECORD +0 -81
- {jaxsim-0.2.dev191.dist-info → jaxsim-0.6.1.dev2.dist-info}/top_level.txt +0 -0
jaxsim/__init__.py
CHANGED
@@ -8,16 +8,37 @@ def _jnp_options() -> None:
|
|
8
8
|
|
9
9
|
import jax
|
10
10
|
|
11
|
-
#
|
12
|
-
|
13
|
-
|
11
|
+
# Check if running on TPU.
|
12
|
+
is_tpu = jax.devices()[0].platform == "tpu"
|
13
|
+
|
14
|
+
# Check if running on Metal.
|
15
|
+
is_metal = jax.devices()[0].platform == "METAL"
|
16
|
+
|
17
|
+
# Enable by default 64-bit precision to get accurate physics.
|
18
|
+
# Users can enforce 32-bit precision by setting the following variable to 0.
|
19
|
+
use_x64 = os.environ.get("JAX_ENABLE_X64", "1") != "0"
|
20
|
+
|
21
|
+
# Notify the user if unsupported 64-bit precision was enforced on TPU.
|
22
|
+
if (is_tpu or is_metal) and use_x64:
|
23
|
+
msg = f"64-bit precision is not allowed on {jax.devices()[0].platform.upper}. Enforcing 32bit precision."
|
24
|
+
logging.warning(msg)
|
25
|
+
use_x64 = False
|
26
|
+
|
27
|
+
if is_metal:
|
28
|
+
logging.warning(
|
29
|
+
"JAX Metal backend is experimental. Some functionalities may not be available."
|
30
|
+
)
|
31
|
+
|
32
|
+
# Enable 64-bit precision in JAX.
|
33
|
+
if use_x64:
|
34
|
+
logging.info("Enabling JAX to use 64-bit precision")
|
14
35
|
jax.config.update("jax_enable_x64", True)
|
15
36
|
|
16
|
-
|
17
|
-
|
18
|
-
|
19
|
-
|
20
|
-
|
37
|
+
# Warn about experimental usage of 32-bit precision.
|
38
|
+
else:
|
39
|
+
logging.warning(
|
40
|
+
"Using 32-bit precision in JaxSim is still experimental, please avoid to use variable step integrators."
|
41
|
+
)
|
21
42
|
|
22
43
|
|
23
44
|
def _np_options() -> None:
|
@@ -27,41 +48,71 @@ def _np_options() -> None:
|
|
27
48
|
|
28
49
|
|
29
50
|
def _is_editable() -> bool:
|
51
|
+
|
30
52
|
import importlib.util
|
31
53
|
import pathlib
|
32
54
|
import site
|
33
55
|
|
34
|
-
# Get the ModuleSpec of jaxsim
|
56
|
+
# Get the ModuleSpec of jaxsim.
|
35
57
|
jaxsim_spec = importlib.util.find_spec(name="jaxsim")
|
36
58
|
|
37
59
|
# This can be None. If it's None, assume non-editable installation.
|
38
60
|
if jaxsim_spec.origin is None:
|
39
61
|
return False
|
40
62
|
|
41
|
-
# Get the folder containing the jaxsim package
|
63
|
+
# Get the folder containing the jaxsim package.
|
42
64
|
jaxsim_package_dir = str(pathlib.Path(jaxsim_spec.origin).parent.parent)
|
43
65
|
|
44
|
-
# The installation is editable if the package dir is not in any {site|dist}-packages
|
66
|
+
# The installation is editable if the package dir is not in any {site|dist}-packages.
|
45
67
|
return jaxsim_package_dir not in site.getsitepackages()
|
46
68
|
|
47
69
|
|
48
|
-
|
49
|
-
|
50
|
-
logging
|
51
|
-
|
52
|
-
|
70
|
+
def _get_default_logging_level(env_var: str) -> logging.LoggingLevel:
|
71
|
+
"""
|
72
|
+
Get the default logging level.
|
73
|
+
|
74
|
+
Args:
|
75
|
+
env_var: The environment variable to check.
|
76
|
+
|
77
|
+
Returns:
|
78
|
+
The logging level to set.
|
79
|
+
"""
|
80
|
+
|
81
|
+
import os
|
82
|
+
|
83
|
+
# Define the default logging level depending on the installation mode.
|
84
|
+
default_logging_level = (
|
85
|
+
logging.LoggingLevel.DEBUG
|
86
|
+
if _is_editable() # noqa: F821
|
87
|
+
else logging.LoggingLevel.WARNING
|
88
|
+
)
|
89
|
+
|
90
|
+
# Allow to override the default logging level with an environment variable.
|
91
|
+
try:
|
92
|
+
return logging.LoggingLevel[
|
93
|
+
os.environ.get(env_var, default_logging_level.name).upper()
|
94
|
+
]
|
95
|
+
|
96
|
+
except KeyError as exc:
|
97
|
+
msg = f"Invalid logging level defined in {env_var}='{os.environ[env_var]}'"
|
98
|
+
raise RuntimeError(msg) from exc
|
99
|
+
|
100
|
+
|
101
|
+
# Configure the logger with the default logging level.
|
102
|
+
logging.configure(level=_get_default_logging_level(env_var="JAXSIM_LOGGING_LEVEL"))
|
103
|
+
|
53
104
|
|
54
|
-
# Configure JAX
|
105
|
+
# Configure JAX.
|
55
106
|
_jnp_options()
|
56
107
|
|
57
|
-
# Initialize the numpy print options
|
108
|
+
# Initialize the numpy print options.
|
58
109
|
_np_options()
|
59
110
|
|
60
111
|
del _jnp_options
|
61
112
|
del _np_options
|
113
|
+
del _get_default_logging_level
|
62
114
|
del _is_editable
|
63
115
|
|
64
|
-
from . import
|
65
|
-
from .
|
66
|
-
from .
|
67
|
-
from .simulation.simulator import JaxSim
|
116
|
+
from . import terrain # isort:skip
|
117
|
+
from . import api, integrators, logging, math, rbda
|
118
|
+
from .api.common import VelRepr
|
jaxsim/_version.py
CHANGED
@@ -12,5 +12,5 @@ __version__: str
|
|
12
12
|
__version_tuple__: VERSION_TUPLE
|
13
13
|
version_tuple: VERSION_TUPLE
|
14
14
|
|
15
|
-
__version__ = version = '0.
|
16
|
-
__version_tuple__ = version_tuple = (0,
|
15
|
+
__version__ = version = '0.6.1.dev2'
|
16
|
+
__version_tuple__ = version_tuple = (0, 6, 1, 'dev2')
|
jaxsim/api/__init__.py
CHANGED
jaxsim/api/com.py
ADDED
@@ -0,0 +1,423 @@
|
|
1
|
+
import jax
|
2
|
+
import jax.numpy as jnp
|
3
|
+
|
4
|
+
import jaxsim.api as js
|
5
|
+
import jaxsim.math
|
6
|
+
import jaxsim.typing as jtp
|
7
|
+
|
8
|
+
from .common import VelRepr
|
9
|
+
|
10
|
+
|
11
|
+
@jax.jit
|
12
|
+
@js.common.named_scope
|
13
|
+
def com_position(
|
14
|
+
model: js.model.JaxSimModel, data: js.data.JaxSimModelData
|
15
|
+
) -> jtp.Vector:
|
16
|
+
"""
|
17
|
+
Compute the position of the center of mass of the model.
|
18
|
+
|
19
|
+
Args:
|
20
|
+
model: The model to consider.
|
21
|
+
data: The data of the considered model.
|
22
|
+
|
23
|
+
Returns:
|
24
|
+
The position of the center of mass of the model w.r.t. the world frame.
|
25
|
+
"""
|
26
|
+
|
27
|
+
m = js.model.total_mass(model=model)
|
28
|
+
|
29
|
+
W_H_L = js.model.forward_kinematics(model=model, data=data)
|
30
|
+
W_H_B = data.base_transform()
|
31
|
+
B_H_W = jaxsim.math.Transform.inverse(transform=W_H_B)
|
32
|
+
|
33
|
+
def B_p̃_LCoM(i) -> jtp.Vector:
|
34
|
+
m = js.link.mass(model=model, link_index=i)
|
35
|
+
L_p_LCoM = js.link.com_position(
|
36
|
+
model=model, data=data, link_index=i, in_link_frame=True
|
37
|
+
)
|
38
|
+
return m * B_H_W @ W_H_L[i] @ jnp.hstack([L_p_LCoM, 1])
|
39
|
+
|
40
|
+
com_links = jax.vmap(B_p̃_LCoM)(jnp.arange(model.number_of_links()))
|
41
|
+
|
42
|
+
B_p̃_CoM = (1 / m) * com_links.sum(axis=0)
|
43
|
+
B_p̃_CoM = B_p̃_CoM.at[3].set(1)
|
44
|
+
|
45
|
+
return (W_H_B @ B_p̃_CoM)[0:3].astype(float)
|
46
|
+
|
47
|
+
|
48
|
+
@jax.jit
|
49
|
+
@js.common.named_scope
|
50
|
+
def com_linear_velocity(
|
51
|
+
model: js.model.JaxSimModel, data: js.data.JaxSimModelData
|
52
|
+
) -> jtp.Vector:
|
53
|
+
r"""
|
54
|
+
Compute the linear velocity of the center of mass of the model.
|
55
|
+
|
56
|
+
Args:
|
57
|
+
model: The model to consider.
|
58
|
+
data: The data of the considered model.
|
59
|
+
|
60
|
+
Returns:
|
61
|
+
The linear velocity of the center of mass of the model in the
|
62
|
+
active representation.
|
63
|
+
|
64
|
+
Note:
|
65
|
+
The linear velocity of the center of mass is expressed in the mixed frame
|
66
|
+
:math:`G = ({}^W \mathbf{p}_{\text{CoM}}, [C])`, where :math:`[C] = [W]` if the
|
67
|
+
active velocity representation is either inertial-fixed or mixed,
|
68
|
+
and :math:`[C] = [B]` if the active velocity representation is body-fixed.
|
69
|
+
"""
|
70
|
+
|
71
|
+
# Extract the linear component of the 6D average centroidal velocity.
|
72
|
+
# This is expressed in G[B] in body-fixed representation, and in G[W] in
|
73
|
+
# inertial-fixed or mixed representation.
|
74
|
+
G_vl_WG = average_centroidal_velocity(model=model, data=data)[0:3]
|
75
|
+
|
76
|
+
return G_vl_WG
|
77
|
+
|
78
|
+
|
79
|
+
@jax.jit
|
80
|
+
@js.common.named_scope
|
81
|
+
def centroidal_momentum(
|
82
|
+
model: js.model.JaxSimModel, data: js.data.JaxSimModelData
|
83
|
+
) -> jtp.Vector:
|
84
|
+
r"""
|
85
|
+
Compute the centroidal momentum of the model.
|
86
|
+
|
87
|
+
Args:
|
88
|
+
model: The model to consider.
|
89
|
+
data: The data of the considered model.
|
90
|
+
|
91
|
+
Returns:
|
92
|
+
The centroidal momentum of the model.
|
93
|
+
|
94
|
+
Note:
|
95
|
+
The centroidal momentum is expressed in the mixed frame
|
96
|
+
:math:`({}^W \mathbf{p}_{\text{CoM}}, [C])`, where :math:`C = W` if the
|
97
|
+
active velocity representation is either inertial-fixed or mixed,
|
98
|
+
and :math:`C = B` if the active velocity representation is body-fixed.
|
99
|
+
"""
|
100
|
+
|
101
|
+
ν = data.generalized_velocity()
|
102
|
+
G_J = centroidal_momentum_jacobian(model=model, data=data)
|
103
|
+
|
104
|
+
return G_J @ ν
|
105
|
+
|
106
|
+
|
107
|
+
@jax.jit
|
108
|
+
@js.common.named_scope
|
109
|
+
def centroidal_momentum_jacobian(
|
110
|
+
model: js.model.JaxSimModel, data: js.data.JaxSimModelData
|
111
|
+
) -> jtp.Matrix:
|
112
|
+
r"""
|
113
|
+
Compute the Jacobian of the centroidal momentum of the model.
|
114
|
+
|
115
|
+
Args:
|
116
|
+
model: The model to consider.
|
117
|
+
data: The data of the considered model.
|
118
|
+
|
119
|
+
Returns:
|
120
|
+
The Jacobian of the centroidal momentum of the model.
|
121
|
+
|
122
|
+
Note:
|
123
|
+
The frame corresponding to the output representation of this Jacobian is either
|
124
|
+
:math:`G[W]`, if the active velocity representation is inertial-fixed or mixed,
|
125
|
+
or :math:`G[B]`, if the active velocity representation is body-fixed.
|
126
|
+
|
127
|
+
Note:
|
128
|
+
This Jacobian is also known in the literature as Centroidal Momentum Matrix.
|
129
|
+
"""
|
130
|
+
|
131
|
+
# Compute the Jacobian of the total momentum with body-fixed output representation.
|
132
|
+
# We convert the output representation either to G[W] or G[B] below.
|
133
|
+
B_Jh = js.model.total_momentum_jacobian(
|
134
|
+
model=model, data=data, output_vel_repr=VelRepr.Body
|
135
|
+
)
|
136
|
+
|
137
|
+
W_H_B = data.base_transform()
|
138
|
+
B_H_W = jaxsim.math.Transform.inverse(W_H_B)
|
139
|
+
|
140
|
+
W_p_CoM = com_position(model=model, data=data)
|
141
|
+
|
142
|
+
match data.velocity_representation:
|
143
|
+
case VelRepr.Inertial | VelRepr.Mixed:
|
144
|
+
W_H_G = W_H_GW = jnp.eye(4).at[0:3, 3].set(W_p_CoM) # noqa: F841
|
145
|
+
case VelRepr.Body:
|
146
|
+
W_H_G = W_H_GB = W_H_B.at[0:3, 3].set(W_p_CoM) # noqa: F841
|
147
|
+
case _:
|
148
|
+
raise ValueError(data.velocity_representation)
|
149
|
+
|
150
|
+
# Compute the transform for 6D forces.
|
151
|
+
G_Xf_B = jaxsim.math.Adjoint.from_transform(transform=B_H_W @ W_H_G).T
|
152
|
+
|
153
|
+
return G_Xf_B @ B_Jh
|
154
|
+
|
155
|
+
|
156
|
+
@jax.jit
|
157
|
+
@js.common.named_scope
|
158
|
+
def locked_centroidal_spatial_inertia(
|
159
|
+
model: js.model.JaxSimModel, data: js.data.JaxSimModelData
|
160
|
+
):
|
161
|
+
"""
|
162
|
+
Compute the locked centroidal spatial inertia of the model.
|
163
|
+
|
164
|
+
Args:
|
165
|
+
model: The model to consider.
|
166
|
+
data: The data of the considered model.
|
167
|
+
|
168
|
+
Returns:
|
169
|
+
The locked centroidal spatial inertia of the model.
|
170
|
+
"""
|
171
|
+
|
172
|
+
with data.switch_velocity_representation(VelRepr.Body):
|
173
|
+
B_Mbb_B = js.model.locked_spatial_inertia(model=model, data=data)
|
174
|
+
|
175
|
+
W_H_B = data.base_transform()
|
176
|
+
W_p_CoM = com_position(model=model, data=data)
|
177
|
+
|
178
|
+
match data.velocity_representation:
|
179
|
+
case VelRepr.Inertial | VelRepr.Mixed:
|
180
|
+
W_H_G = W_H_GW = jnp.eye(4).at[0:3, 3].set(W_p_CoM) # noqa: F841
|
181
|
+
case VelRepr.Body:
|
182
|
+
W_H_G = W_H_GB = W_H_B.at[0:3, 3].set(W_p_CoM) # noqa: F841
|
183
|
+
case _:
|
184
|
+
raise ValueError(data.velocity_representation)
|
185
|
+
|
186
|
+
B_H_G = jaxsim.math.Transform.inverse(W_H_B) @ W_H_G
|
187
|
+
|
188
|
+
B_Xv_G = jaxsim.math.Adjoint.from_transform(transform=B_H_G)
|
189
|
+
G_Xf_B = B_Xv_G.transpose()
|
190
|
+
|
191
|
+
return G_Xf_B @ B_Mbb_B @ B_Xv_G
|
192
|
+
|
193
|
+
|
194
|
+
@jax.jit
|
195
|
+
@js.common.named_scope
|
196
|
+
def average_centroidal_velocity(
|
197
|
+
model: js.model.JaxSimModel, data: js.data.JaxSimModelData
|
198
|
+
) -> jtp.Vector:
|
199
|
+
r"""
|
200
|
+
Compute the average centroidal velocity of the model.
|
201
|
+
|
202
|
+
Args:
|
203
|
+
model: The model to consider.
|
204
|
+
data: The data of the considered model.
|
205
|
+
|
206
|
+
Returns:
|
207
|
+
The average centroidal velocity of the model.
|
208
|
+
|
209
|
+
Note:
|
210
|
+
The average velocity is expressed in the mixed frame
|
211
|
+
:math:`G = ({}^W \mathbf{p}_{\text{CoM}}, [C])`, where :math:`[C] = [W]` if the
|
212
|
+
active velocity representation is either inertial-fixed or mixed,
|
213
|
+
and :math:`[C] = [B]` if the active velocity representation is body-fixed.
|
214
|
+
"""
|
215
|
+
|
216
|
+
ν = data.generalized_velocity()
|
217
|
+
G_J = average_centroidal_velocity_jacobian(model=model, data=data)
|
218
|
+
|
219
|
+
return G_J @ ν
|
220
|
+
|
221
|
+
|
222
|
+
@jax.jit
|
223
|
+
@js.common.named_scope
|
224
|
+
def average_centroidal_velocity_jacobian(
|
225
|
+
model: js.model.JaxSimModel, data: js.data.JaxSimModelData
|
226
|
+
) -> jtp.Matrix:
|
227
|
+
r"""
|
228
|
+
Compute the Jacobian of the average centroidal velocity of the model.
|
229
|
+
|
230
|
+
Args:
|
231
|
+
model: The model to consider.
|
232
|
+
data: The data of the considered model.
|
233
|
+
|
234
|
+
Returns:
|
235
|
+
The Jacobian of the average centroidal velocity of the model.
|
236
|
+
|
237
|
+
Note:
|
238
|
+
The frame corresponding to the output representation of this Jacobian is either
|
239
|
+
:math:`G[W]`, if the active velocity representation is inertial-fixed or mixed,
|
240
|
+
or :math:`G[B]`, if the active velocity representation is body-fixed.
|
241
|
+
"""
|
242
|
+
|
243
|
+
G_J = centroidal_momentum_jacobian(model=model, data=data)
|
244
|
+
G_Mbb = locked_centroidal_spatial_inertia(model=model, data=data)
|
245
|
+
|
246
|
+
return jnp.linalg.inv(G_Mbb) @ G_J
|
247
|
+
|
248
|
+
|
249
|
+
@jax.jit
|
250
|
+
@js.common.named_scope
|
251
|
+
def bias_acceleration(
|
252
|
+
model: js.model.JaxSimModel, data: js.data.JaxSimModelData
|
253
|
+
) -> jtp.Vector:
|
254
|
+
r"""
|
255
|
+
Compute the bias linear acceleration of the center of mass.
|
256
|
+
|
257
|
+
Args:
|
258
|
+
model: The model to consider.
|
259
|
+
data: The data of the considered model.
|
260
|
+
|
261
|
+
Returns:
|
262
|
+
The bias linear acceleration of the center of mass in the active representation.
|
263
|
+
|
264
|
+
Note:
|
265
|
+
The bias acceleration is expressed in the mixed frame
|
266
|
+
:math:`G = ({}^W \mathbf{p}_{\text{CoM}}, [C])`, where :math:`[C] = [W]` if the
|
267
|
+
active velocity representation is either inertial-fixed or mixed,
|
268
|
+
and :math:`[C] = [B]` if the active velocity representation is body-fixed.
|
269
|
+
"""
|
270
|
+
|
271
|
+
# Compute the pose of all links with forward kinematics.
|
272
|
+
W_H_L = js.model.forward_kinematics(model=model, data=data)
|
273
|
+
|
274
|
+
# Compute the bias acceleration of all links by zeroing the generalized velocity
|
275
|
+
# in the active representation.
|
276
|
+
v̇_bias_WL = js.model.link_bias_accelerations(model=model, data=data)
|
277
|
+
|
278
|
+
def other_representation_to_body(
|
279
|
+
C_v̇_WL: jtp.Vector, C_v_WC: jtp.Vector, L_H_C: jtp.Matrix, L_v_LC: jtp.Vector
|
280
|
+
) -> jtp.Vector:
|
281
|
+
"""
|
282
|
+
Convert the body-fixed representation of the link bias acceleration
|
283
|
+
C_v̇_WL expressed in a generic frame C to the body-fixed representation L_v̇_WL.
|
284
|
+
"""
|
285
|
+
|
286
|
+
L_X_C = jaxsim.math.Adjoint.from_transform(transform=L_H_C)
|
287
|
+
C_X_L = jaxsim.math.Adjoint.inverse(L_X_C)
|
288
|
+
|
289
|
+
L_v̇_WL = L_X_C @ (C_v̇_WL + jaxsim.math.Cross.vx(C_X_L @ L_v_LC) @ C_v_WC)
|
290
|
+
return L_v̇_WL
|
291
|
+
|
292
|
+
# We need here to get the body-fixed bias acceleration of the links.
|
293
|
+
# Since it's computed in the active representation, we need to convert it to body.
|
294
|
+
match data.velocity_representation:
|
295
|
+
|
296
|
+
case VelRepr.Body:
|
297
|
+
L_a_bias_WL = v̇_bias_WL
|
298
|
+
|
299
|
+
case VelRepr.Inertial:
|
300
|
+
|
301
|
+
C_v̇_WL = W_v̇_bias_WL = v̇_bias_WL # noqa: F841
|
302
|
+
C_v_WC = W_v_WW = jnp.zeros(6) # noqa: F841
|
303
|
+
|
304
|
+
L_H_C = L_H_W = jax.vmap( # noqa: F841
|
305
|
+
lambda W_H_L: jaxsim.math.Transform.inverse(W_H_L)
|
306
|
+
)(W_H_L)
|
307
|
+
|
308
|
+
L_v_LC = L_v_LW = jax.vmap( # noqa: F841
|
309
|
+
lambda i: -js.link.velocity(
|
310
|
+
model=model, data=data, link_index=i, output_vel_repr=VelRepr.Body
|
311
|
+
)
|
312
|
+
)(jnp.arange(model.number_of_links()))
|
313
|
+
|
314
|
+
L_a_bias_WL = jax.vmap(
|
315
|
+
lambda i: other_representation_to_body(
|
316
|
+
C_v̇_WL=C_v̇_WL[i],
|
317
|
+
C_v_WC=C_v_WC,
|
318
|
+
L_H_C=L_H_C[i],
|
319
|
+
L_v_LC=L_v_LC[i],
|
320
|
+
)
|
321
|
+
)(jnp.arange(model.number_of_links()))
|
322
|
+
|
323
|
+
case VelRepr.Mixed:
|
324
|
+
|
325
|
+
C_v̇_WL = LW_v̇_bias_WL = v̇_bias_WL # noqa: F841
|
326
|
+
|
327
|
+
C_v_WC = LW_v_W_LW = jax.vmap( # noqa: F841
|
328
|
+
lambda i: js.link.velocity(
|
329
|
+
model=model, data=data, link_index=i, output_vel_repr=VelRepr.Mixed
|
330
|
+
)
|
331
|
+
.at[3:6]
|
332
|
+
.set(jnp.zeros(3))
|
333
|
+
)(jnp.arange(model.number_of_links()))
|
334
|
+
|
335
|
+
L_H_C = L_H_LW = jax.vmap( # noqa: F841
|
336
|
+
lambda W_H_L: jaxsim.math.Transform.inverse(
|
337
|
+
W_H_L.at[0:3, 3].set(jnp.zeros(3))
|
338
|
+
)
|
339
|
+
)(W_H_L)
|
340
|
+
|
341
|
+
L_v_LC = L_v_L_LW = jax.vmap( # noqa: F841
|
342
|
+
lambda i: -js.link.velocity(
|
343
|
+
model=model, data=data, link_index=i, output_vel_repr=VelRepr.Body
|
344
|
+
)
|
345
|
+
.at[0:3]
|
346
|
+
.set(jnp.zeros(3))
|
347
|
+
)(jnp.arange(model.number_of_links()))
|
348
|
+
|
349
|
+
L_a_bias_WL = jax.vmap(
|
350
|
+
lambda i: other_representation_to_body(
|
351
|
+
C_v̇_WL=C_v̇_WL[i],
|
352
|
+
C_v_WC=C_v_WC[i],
|
353
|
+
L_H_C=L_H_C[i],
|
354
|
+
L_v_LC=L_v_LC[i],
|
355
|
+
)
|
356
|
+
)(jnp.arange(model.number_of_links()))
|
357
|
+
|
358
|
+
case _:
|
359
|
+
raise ValueError(data.velocity_representation)
|
360
|
+
|
361
|
+
# Compute the bias of the 6D momentum derivative.
|
362
|
+
def bias_momentum_derivative_term(
|
363
|
+
link_index: jtp.Int, L_a_bias_WL: jtp.Vector
|
364
|
+
) -> jtp.Vector:
|
365
|
+
|
366
|
+
# Get the body-fixed 6D inertia matrix.
|
367
|
+
L_M_L = js.link.spatial_inertia(model=model, link_index=link_index)
|
368
|
+
|
369
|
+
# Compute the body-fixed 6D velocity.
|
370
|
+
L_v_WL = js.link.velocity(
|
371
|
+
model=model, data=data, link_index=link_index, output_vel_repr=VelRepr.Body
|
372
|
+
)
|
373
|
+
|
374
|
+
# Compute the world-to-link transformations for 6D forces.
|
375
|
+
W_Xf_L = jaxsim.math.Adjoint.from_transform(
|
376
|
+
transform=W_H_L[link_index], inverse=True
|
377
|
+
).T
|
378
|
+
|
379
|
+
# Compute the contribution of the link to the bias acceleration of the CoM.
|
380
|
+
W_ḣ_bias_link_contribution = W_Xf_L @ (
|
381
|
+
L_M_L @ L_a_bias_WL + jaxsim.math.Cross.vx_star(L_v_WL) @ L_M_L @ L_v_WL
|
382
|
+
)
|
383
|
+
|
384
|
+
return W_ḣ_bias_link_contribution
|
385
|
+
|
386
|
+
# Sum the contributions of all links to the bias acceleration of the CoM.
|
387
|
+
W_ḣ_bias = jax.vmap(bias_momentum_derivative_term)(
|
388
|
+
jnp.arange(model.number_of_links()), L_a_bias_WL
|
389
|
+
).sum(axis=0)
|
390
|
+
|
391
|
+
# Compute the total mass of the model.
|
392
|
+
m = js.model.total_mass(model=model)
|
393
|
+
|
394
|
+
# Compute the position of the CoM.
|
395
|
+
W_p_CoM = com_position(model=model, data=data)
|
396
|
+
|
397
|
+
match data.velocity_representation:
|
398
|
+
|
399
|
+
# G := G[W] = (W_p_CoM, [W])
|
400
|
+
case VelRepr.Inertial | VelRepr.Mixed:
|
401
|
+
|
402
|
+
W_H_GW = jnp.eye(4).at[0:3, 3].set(W_p_CoM)
|
403
|
+
GW_Xf_W = jaxsim.math.Adjoint.from_transform(W_H_GW).T
|
404
|
+
|
405
|
+
GW_ḣ_bias = GW_Xf_W @ W_ḣ_bias
|
406
|
+
GW_v̇l_com_bias = GW_ḣ_bias[0:3] / m
|
407
|
+
|
408
|
+
return GW_v̇l_com_bias
|
409
|
+
|
410
|
+
# G := G[B] = (W_p_CoM, [B])
|
411
|
+
case VelRepr.Body:
|
412
|
+
|
413
|
+
GB_Xf_W = jaxsim.math.Adjoint.from_transform(
|
414
|
+
transform=data.base_transform().at[0:3].set(W_p_CoM)
|
415
|
+
).T
|
416
|
+
|
417
|
+
GB_ḣ_bias = GB_Xf_W @ W_ḣ_bias
|
418
|
+
GB_v̇l_com_bias = GB_ḣ_bias[0:3] / m
|
419
|
+
|
420
|
+
return GB_v̇l_com_bias
|
421
|
+
|
422
|
+
case _:
|
423
|
+
raise ValueError(data.velocity_representation)
|