jaxsim 0.2.dev191__py3-none-any.whl → 0.2.dev366__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jaxsim/__init__.py +3 -4
- jaxsim/_version.py +2 -2
- jaxsim/api/__init__.py +3 -1
- jaxsim/api/com.py +240 -0
- jaxsim/api/common.py +13 -2
- jaxsim/api/contact.py +120 -43
- jaxsim/api/data.py +112 -71
- jaxsim/api/joint.py +77 -36
- jaxsim/api/kin_dyn_parameters.py +777 -0
- jaxsim/api/link.py +150 -75
- jaxsim/api/model.py +542 -269
- jaxsim/api/ode.py +86 -74
- jaxsim/api/ode_data.py +694 -0
- jaxsim/api/references.py +12 -11
- jaxsim/integrators/__init__.py +2 -2
- jaxsim/integrators/common.py +110 -24
- jaxsim/integrators/fixed_step.py +11 -67
- jaxsim/integrators/variable_step.py +610 -0
- jaxsim/math/__init__.py +11 -0
- jaxsim/math/adjoint.py +24 -2
- jaxsim/math/joint_model.py +335 -0
- jaxsim/math/quaternion.py +44 -3
- jaxsim/math/rotation.py +4 -4
- jaxsim/math/transform.py +93 -0
- jaxsim/parsers/descriptions/link.py +2 -2
- jaxsim/parsers/rod/utils.py +7 -8
- jaxsim/rbda/__init__.py +7 -0
- jaxsim/rbda/aba.py +295 -0
- jaxsim/rbda/collidable_points.py +142 -0
- jaxsim/{physics/algos → rbda}/crba.py +43 -42
- jaxsim/rbda/forward_kinematics.py +113 -0
- jaxsim/rbda/jacobian.py +201 -0
- jaxsim/rbda/rnea.py +237 -0
- jaxsim/rbda/soft_contacts.py +296 -0
- jaxsim/rbda/utils.py +152 -0
- jaxsim/terrain/__init__.py +2 -0
- jaxsim/utils/__init__.py +1 -4
- jaxsim/utils/hashless.py +18 -0
- jaxsim/utils/jaxsim_dataclass.py +281 -30
- {jaxsim-0.2.dev191.dist-info → jaxsim-0.2.dev366.dist-info}/METADATA +4 -6
- jaxsim-0.2.dev366.dist-info/RECORD +64 -0
- {jaxsim-0.2.dev191.dist-info → jaxsim-0.2.dev366.dist-info}/WHEEL +1 -1
- jaxsim/high_level/__init__.py +0 -2
- jaxsim/high_level/common.py +0 -11
- jaxsim/high_level/joint.py +0 -148
- jaxsim/high_level/link.py +0 -259
- jaxsim/high_level/model.py +0 -1686
- jaxsim/math/conv.py +0 -114
- jaxsim/math/joint.py +0 -102
- jaxsim/math/plucker.py +0 -100
- jaxsim/physics/__init__.py +0 -12
- jaxsim/physics/algos/__init__.py +0 -0
- jaxsim/physics/algos/aba.py +0 -254
- jaxsim/physics/algos/aba_motors.py +0 -284
- jaxsim/physics/algos/forward_kinematics.py +0 -79
- jaxsim/physics/algos/jacobian.py +0 -98
- jaxsim/physics/algos/rnea.py +0 -180
- jaxsim/physics/algos/rnea_motors.py +0 -196
- jaxsim/physics/algos/soft_contacts.py +0 -523
- jaxsim/physics/algos/utils.py +0 -69
- jaxsim/physics/model/__init__.py +0 -0
- jaxsim/physics/model/ground_contact.py +0 -53
- jaxsim/physics/model/physics_model.py +0 -388
- jaxsim/physics/model/physics_model_state.py +0 -283
- jaxsim/simulation/__init__.py +0 -4
- jaxsim/simulation/integrators.py +0 -393
- jaxsim/simulation/ode.py +0 -290
- jaxsim/simulation/ode_data.py +0 -96
- jaxsim/simulation/ode_integration.py +0 -62
- jaxsim/simulation/simulator.py +0 -543
- jaxsim/simulation/simulator_callbacks.py +0 -79
- jaxsim/simulation/utils.py +0 -15
- jaxsim/sixd/__init__.py +0 -2
- jaxsim/utils/oop.py +0 -536
- jaxsim/utils/vmappable.py +0 -117
- jaxsim-0.2.dev191.dist-info/RECORD +0 -81
- /jaxsim/{physics/algos → terrain}/terrain.py +0 -0
- {jaxsim-0.2.dev191.dist-info → jaxsim-0.2.dev366.dist-info}/LICENSE +0 -0
- {jaxsim-0.2.dev191.dist-info → jaxsim-0.2.dev366.dist-info}/top_level.txt +0 -0
jaxsim/__init__.py
CHANGED
@@ -61,7 +61,6 @@ del _jnp_options
|
|
61
61
|
del _np_options
|
62
62
|
del _is_editable
|
63
63
|
|
64
|
-
from . import
|
65
|
-
from .
|
66
|
-
from .
|
67
|
-
from .simulation.simulator import JaxSim
|
64
|
+
from . import terrain # isort:skip
|
65
|
+
from . import api, integrators, logging, math, rbda
|
66
|
+
from .api.common import VelRepr
|
jaxsim/_version.py
CHANGED
@@ -12,5 +12,5 @@ __version__: str
|
|
12
12
|
__version_tuple__: VERSION_TUPLE
|
13
13
|
version_tuple: VERSION_TUPLE
|
14
14
|
|
15
|
-
__version__ = version = '0.2.
|
16
|
-
__version_tuple__ = version_tuple = (0, 2, '
|
15
|
+
__version__ = version = '0.2.dev366'
|
16
|
+
__version_tuple__ = version_tuple = (0, 2, 'dev366')
|
jaxsim/api/__init__.py
CHANGED
jaxsim/api/com.py
ADDED
@@ -0,0 +1,240 @@
|
|
1
|
+
import jax
|
2
|
+
import jax.numpy as jnp
|
3
|
+
import jaxlie
|
4
|
+
|
5
|
+
import jaxsim.api as js
|
6
|
+
import jaxsim.math
|
7
|
+
import jaxsim.typing as jtp
|
8
|
+
|
9
|
+
from .common import VelRepr
|
10
|
+
|
11
|
+
|
12
|
+
@jax.jit
|
13
|
+
def com_position(
|
14
|
+
model: js.model.JaxSimModel, data: js.data.JaxSimModelData
|
15
|
+
) -> jtp.Vector:
|
16
|
+
"""
|
17
|
+
Compute the position of the center of mass of the model.
|
18
|
+
|
19
|
+
Args:
|
20
|
+
model: The model to consider.
|
21
|
+
data: The data of the considered model.
|
22
|
+
|
23
|
+
Returns:
|
24
|
+
The position of the center of mass of the model w.r.t. the world frame.
|
25
|
+
"""
|
26
|
+
|
27
|
+
m = js.model.total_mass(model=model)
|
28
|
+
|
29
|
+
W_H_L = js.model.forward_kinematics(model=model, data=data)
|
30
|
+
W_H_B = data.base_transform()
|
31
|
+
B_H_W = jaxlie.SE3.from_matrix(W_H_B).inverse().as_matrix()
|
32
|
+
|
33
|
+
def B_p̃_LCoM(i) -> jtp.Vector:
|
34
|
+
m = js.link.mass(model=model, link_index=i)
|
35
|
+
L_p_LCoM = js.link.com_position(
|
36
|
+
model=model, data=data, link_index=i, in_link_frame=True
|
37
|
+
)
|
38
|
+
return m * B_H_W @ W_H_L[i] @ jnp.hstack([L_p_LCoM, 1])
|
39
|
+
|
40
|
+
com_links = jax.vmap(B_p̃_LCoM)(jnp.arange(model.number_of_links()))
|
41
|
+
|
42
|
+
B_p̃_CoM = (1 / m) * com_links.sum(axis=0)
|
43
|
+
B_p̃_CoM = B_p̃_CoM.at[3].set(1)
|
44
|
+
|
45
|
+
return (W_H_B @ B_p̃_CoM)[0:3].astype(float)
|
46
|
+
|
47
|
+
|
48
|
+
@jax.jit
|
49
|
+
def com_linear_velocity(
|
50
|
+
model: js.model.JaxSimModel, data: js.data.JaxSimModelData
|
51
|
+
) -> jtp.Vector:
|
52
|
+
r"""
|
53
|
+
Compute the linear velocity of the center of mass of the model.
|
54
|
+
|
55
|
+
Args:
|
56
|
+
model: The model to consider.
|
57
|
+
data: The data of the considered model.
|
58
|
+
|
59
|
+
Returns:
|
60
|
+
The linear velocity of the center of mass of the model in the
|
61
|
+
active representation.
|
62
|
+
|
63
|
+
Note:
|
64
|
+
The linear velocity of the center of mass is expressed in the mixed frame
|
65
|
+
:math:`G = ({}^W \mathbf{p}_{\text{CoM}}, [C])`, where :math:`[C] = [W]` if the
|
66
|
+
active velocity representation is either inertial-fixed or mixed,
|
67
|
+
and :math:`[C] = [B]` if the active velocity representation is body-fixed.
|
68
|
+
"""
|
69
|
+
|
70
|
+
# Extract the linear component of the 6D average centroidal velocity.
|
71
|
+
# This is expressed in G[B] in body-fixed representation, and in G[W] in
|
72
|
+
# inertial-fixed or mixed representation.
|
73
|
+
G_vl_WG = average_centroidal_velocity(model=model, data=data)[0:3]
|
74
|
+
|
75
|
+
return G_vl_WG
|
76
|
+
|
77
|
+
|
78
|
+
@jax.jit
|
79
|
+
def centroidal_momentum(
|
80
|
+
model: js.model.JaxSimModel, data: js.data.JaxSimModelData
|
81
|
+
) -> jtp.Vector:
|
82
|
+
r"""
|
83
|
+
Compute the centroidal momentum of the model.
|
84
|
+
|
85
|
+
Args:
|
86
|
+
model: The model to consider.
|
87
|
+
data: The data of the considered model.
|
88
|
+
|
89
|
+
Returns:
|
90
|
+
The centroidal momentum of the model.
|
91
|
+
|
92
|
+
Note:
|
93
|
+
The centroidal momentum is expressed in the mixed frame
|
94
|
+
:math:`({}^W \mathbf{p}_{\text{CoM}}, [C])`, where :math:`C = W` if the
|
95
|
+
active velocity representation is either inertial-fixed or mixed,
|
96
|
+
and :math:`C = B` if the active velocity representation is body-fixed.
|
97
|
+
"""
|
98
|
+
|
99
|
+
ν = data.generalized_velocity()
|
100
|
+
G_J = centroidal_momentum_jacobian(model=model, data=data)
|
101
|
+
|
102
|
+
return G_J @ ν
|
103
|
+
|
104
|
+
|
105
|
+
@jax.jit
|
106
|
+
def centroidal_momentum_jacobian(
|
107
|
+
model: js.model.JaxSimModel, data: js.data.JaxSimModelData
|
108
|
+
) -> jtp.Matrix:
|
109
|
+
r"""
|
110
|
+
Compute the Jacobian of the centroidal momentum of the model.
|
111
|
+
|
112
|
+
Args:
|
113
|
+
model: The model to consider.
|
114
|
+
data: The data of the considered model.
|
115
|
+
|
116
|
+
Returns:
|
117
|
+
The Jacobian of the centroidal momentum of the model.
|
118
|
+
|
119
|
+
Note:
|
120
|
+
The frame corresponding to the output representation of this Jacobian is either
|
121
|
+
:math:`G[W]`, if the active velocity representation is inertial-fixed or mixed,
|
122
|
+
or :math:`G[B]`, if the active velocity representation is body-fixed.
|
123
|
+
|
124
|
+
Note:
|
125
|
+
This Jacobian is also known in the literature as Centroidal Momentum Matrix.
|
126
|
+
"""
|
127
|
+
|
128
|
+
# Compute the Jacobian of the total momentum with body-fixed output representation.
|
129
|
+
# We convert the output representation either to G[W] or G[B] below.
|
130
|
+
B_Jh = js.model.total_momentum_jacobian(
|
131
|
+
model=model, data=data, output_vel_repr=VelRepr.Body
|
132
|
+
)
|
133
|
+
|
134
|
+
W_H_B = data.base_transform()
|
135
|
+
B_H_W = jaxsim.math.Transform.inverse(W_H_B)
|
136
|
+
|
137
|
+
W_p_CoM = com_position(model=model, data=data)
|
138
|
+
|
139
|
+
match data.velocity_representation:
|
140
|
+
case VelRepr.Inertial | VelRepr.Mixed:
|
141
|
+
W_H_G = W_H_GW = jnp.eye(4).at[0:3, 3].set(W_p_CoM)
|
142
|
+
case VelRepr.Body:
|
143
|
+
W_H_G = W_H_GB = W_H_B.at[0:3, 3].set(W_p_CoM)
|
144
|
+
case _:
|
145
|
+
raise ValueError(data.velocity_representation)
|
146
|
+
|
147
|
+
# Compute the transform for 6D forces.
|
148
|
+
G_Xf_B = jaxsim.math.Adjoint.from_transform(transform=B_H_W @ W_H_G).T
|
149
|
+
|
150
|
+
return G_Xf_B @ B_Jh
|
151
|
+
|
152
|
+
|
153
|
+
@jax.jit
|
154
|
+
def locked_centroidal_spatial_inertia(
|
155
|
+
model: js.model.JaxSimModel, data: js.data.JaxSimModelData
|
156
|
+
):
|
157
|
+
"""
|
158
|
+
Compute the locked centroidal spatial inertia of the model.
|
159
|
+
|
160
|
+
Args:
|
161
|
+
model: The model to consider.
|
162
|
+
data: The data of the considered model.
|
163
|
+
|
164
|
+
Returns:
|
165
|
+
The locked centroidal spatial inertia of the model.
|
166
|
+
"""
|
167
|
+
|
168
|
+
with data.switch_velocity_representation(VelRepr.Body):
|
169
|
+
B_Mbb_B = js.model.locked_spatial_inertia(model=model, data=data)
|
170
|
+
|
171
|
+
W_H_B = data.base_transform()
|
172
|
+
W_p_CoM = com_position(model=model, data=data)
|
173
|
+
|
174
|
+
match data.velocity_representation:
|
175
|
+
case VelRepr.Inertial | VelRepr.Mixed:
|
176
|
+
W_H_G = W_H_GW = jnp.eye(4).at[0:3, 3].set(W_p_CoM)
|
177
|
+
case VelRepr.Body:
|
178
|
+
W_H_G = W_H_GB = W_H_B.at[0:3, 3].set(W_p_CoM)
|
179
|
+
case _:
|
180
|
+
raise ValueError(data.velocity_representation)
|
181
|
+
|
182
|
+
B_H_G = jaxlie.SE3.from_matrix(jaxsim.math.Transform.inverse(W_H_B) @ W_H_G)
|
183
|
+
|
184
|
+
B_Xv_G = B_H_G.adjoint()
|
185
|
+
G_Xf_B = B_Xv_G.transpose()
|
186
|
+
|
187
|
+
return G_Xf_B @ B_Mbb_B @ B_Xv_G
|
188
|
+
|
189
|
+
|
190
|
+
@jax.jit
|
191
|
+
def average_centroidal_velocity(
|
192
|
+
model: js.model.JaxSimModel, data: js.data.JaxSimModelData
|
193
|
+
) -> jtp.Vector:
|
194
|
+
r"""
|
195
|
+
Compute the average centroidal velocity of the model.
|
196
|
+
|
197
|
+
Args:
|
198
|
+
model: The model to consider.
|
199
|
+
data: The data of the considered model.
|
200
|
+
|
201
|
+
Returns:
|
202
|
+
The average centroidal velocity of the model.
|
203
|
+
|
204
|
+
Note:
|
205
|
+
The average velocity is expressed in the mixed frame
|
206
|
+
:math:`G = ({}^W \mathbf{p}_{\text{CoM}}, [C])`, where :math:`[C] = [W]` if the
|
207
|
+
active velocity representation is either inertial-fixed or mixed,
|
208
|
+
and :math:`[C] = [B]` if the active velocity representation is body-fixed.
|
209
|
+
"""
|
210
|
+
|
211
|
+
ν = data.generalized_velocity()
|
212
|
+
G_J = average_centroidal_velocity_jacobian(model=model, data=data)
|
213
|
+
|
214
|
+
return G_J @ ν
|
215
|
+
|
216
|
+
|
217
|
+
@jax.jit
|
218
|
+
def average_centroidal_velocity_jacobian(
|
219
|
+
model: js.model.JaxSimModel, data: js.data.JaxSimModelData
|
220
|
+
) -> jtp.Matrix:
|
221
|
+
r"""
|
222
|
+
Compute the Jacobian of the average centroidal velocity of the model.
|
223
|
+
|
224
|
+
Args:
|
225
|
+
model: The model to consider.
|
226
|
+
data: The data of the considered model.
|
227
|
+
|
228
|
+
Returns:
|
229
|
+
The Jacobian of the average centroidal velocity of the model.
|
230
|
+
|
231
|
+
Note:
|
232
|
+
The frame corresponding to the output representation of this Jacobian is either
|
233
|
+
:math:`G[W]`, if the active velocity representation is inertial-fixed or mixed,
|
234
|
+
or :math:`G[B]`, if the active velocity representation is body-fixed.
|
235
|
+
"""
|
236
|
+
|
237
|
+
G_J = centroidal_momentum_jacobian(model=model, data=data)
|
238
|
+
G_Mbb = locked_centroidal_spatial_inertia(model=model, data=data)
|
239
|
+
|
240
|
+
return jnp.linalg.inv(G_Mbb) @ G_J
|
jaxsim/api/common.py
CHANGED
@@ -1,6 +1,7 @@
|
|
1
1
|
import abc
|
2
2
|
import contextlib
|
3
3
|
import dataclasses
|
4
|
+
import enum
|
4
5
|
import functools
|
5
6
|
from typing import ContextManager
|
6
7
|
|
@@ -11,7 +12,6 @@ import jaxlie
|
|
11
12
|
from jax_dataclasses import Static
|
12
13
|
|
13
14
|
import jaxsim.typing as jtp
|
14
|
-
from jaxsim.high_level.common import VelRepr
|
15
15
|
from jaxsim.utils import JaxsimDataclass, Mutability
|
16
16
|
|
17
17
|
try:
|
@@ -20,6 +20,17 @@ except ImportError:
|
|
20
20
|
from typing_extensions import Self
|
21
21
|
|
22
22
|
|
23
|
+
@enum.unique
|
24
|
+
class VelRepr(enum.IntEnum):
|
25
|
+
"""
|
26
|
+
Enumeration of all supported 6D velocity representations.
|
27
|
+
"""
|
28
|
+
|
29
|
+
Body = enum.auto()
|
30
|
+
Mixed = enum.auto()
|
31
|
+
Inertial = enum.auto()
|
32
|
+
|
33
|
+
|
23
34
|
@jax_dataclasses.pytree_dataclass
|
24
35
|
class ModelDataWithVelocityRepresentation(JaxsimDataclass, abc.ABC):
|
25
36
|
"""
|
@@ -59,7 +70,7 @@ class ModelDataWithVelocityRepresentation(JaxsimDataclass, abc.ABC):
|
|
59
70
|
# We run this in a mutable context with restoration so that any exception
|
60
71
|
# occurring, we restore the original object in case it was modified.
|
61
72
|
with self.mutable_context(
|
62
|
-
mutability=self.
|
73
|
+
mutability=self.mutability(), restore_after_exception=True
|
63
74
|
):
|
64
75
|
yield self
|
65
76
|
|
jaxsim/api/contact.py
CHANGED
@@ -3,16 +3,16 @@ import functools
|
|
3
3
|
import jax
|
4
4
|
import jax.numpy as jnp
|
5
5
|
|
6
|
+
import jaxsim.api as js
|
7
|
+
import jaxsim.rbda
|
6
8
|
import jaxsim.typing as jtp
|
7
|
-
from jaxsim.physics.algos import soft_contacts
|
8
9
|
|
9
|
-
from . import
|
10
|
-
from . import model as Model
|
10
|
+
from .common import VelRepr
|
11
11
|
|
12
12
|
|
13
13
|
@jax.jit
|
14
14
|
def collidable_point_kinematics(
|
15
|
-
model:
|
15
|
+
model: js.model.JaxSimModel, data: js.data.JaxSimModelData
|
16
16
|
) -> tuple[jtp.Matrix, jtp.Matrix]:
|
17
17
|
"""
|
18
18
|
Compute the position and 3D velocity of the collidable points in the world frame.
|
@@ -30,21 +30,25 @@ def collidable_point_kinematics(
|
|
30
30
|
the linear component of the mixed 6D frame velocity.
|
31
31
|
"""
|
32
32
|
|
33
|
-
from jaxsim.
|
34
|
-
|
35
|
-
|
36
|
-
|
37
|
-
|
38
|
-
|
39
|
-
|
40
|
-
|
33
|
+
from jaxsim.rbda import collidable_points
|
34
|
+
|
35
|
+
with data.switch_velocity_representation(VelRepr.Inertial):
|
36
|
+
W_p_Ci, W_ṗ_Ci = collidable_points.collidable_points_pos_vel(
|
37
|
+
model=model,
|
38
|
+
base_position=data.base_position(),
|
39
|
+
base_quaternion=data.base_orientation(dcm=False),
|
40
|
+
joint_positions=data.joint_positions(model=model),
|
41
|
+
base_linear_velocity=data.base_velocity()[0:3],
|
42
|
+
base_angular_velocity=data.base_velocity()[3:6],
|
43
|
+
joint_velocities=data.joint_velocities(model=model),
|
44
|
+
)
|
41
45
|
|
42
|
-
return W_p_Ci
|
46
|
+
return W_p_Ci, W_ṗ_Ci
|
43
47
|
|
44
48
|
|
45
49
|
@jax.jit
|
46
50
|
def collidable_point_positions(
|
47
|
-
model:
|
51
|
+
model: js.model.JaxSimModel, data: js.data.JaxSimModelData
|
48
52
|
) -> jtp.Matrix:
|
49
53
|
"""
|
50
54
|
Compute the position of the collidable points in the world frame.
|
@@ -62,7 +66,7 @@ def collidable_point_positions(
|
|
62
66
|
|
63
67
|
@jax.jit
|
64
68
|
def collidable_point_velocities(
|
65
|
-
model:
|
69
|
+
model: js.model.JaxSimModel, data: js.data.JaxSimModelData
|
66
70
|
) -> jtp.Matrix:
|
67
71
|
"""
|
68
72
|
Compute the 3D velocity of the collidable points in the world frame.
|
@@ -78,10 +82,83 @@ def collidable_point_velocities(
|
|
78
82
|
return collidable_point_kinematics(model=model, data=data)[1]
|
79
83
|
|
80
84
|
|
85
|
+
@jax.jit
|
86
|
+
def collidable_point_forces(
|
87
|
+
model: js.model.JaxSimModel, data: js.data.JaxSimModelData
|
88
|
+
) -> jtp.Matrix:
|
89
|
+
"""
|
90
|
+
Compute the 6D forces applied to each collidable point.
|
91
|
+
|
92
|
+
Args:
|
93
|
+
model: The model to consider.
|
94
|
+
data: The data of the considered model.
|
95
|
+
|
96
|
+
Returns:
|
97
|
+
The 6D forces applied to each collidable point expressed in the frame
|
98
|
+
corresponding to the active representation.
|
99
|
+
"""
|
100
|
+
|
101
|
+
f_Ci, _ = collidable_point_dynamics(model=model, data=data)
|
102
|
+
|
103
|
+
return f_Ci
|
104
|
+
|
105
|
+
|
106
|
+
@jax.jit
|
107
|
+
def collidable_point_dynamics(
|
108
|
+
model: js.model.JaxSimModel, data: js.data.JaxSimModelData
|
109
|
+
) -> tuple[jtp.Matrix, jtp.Matrix]:
|
110
|
+
r"""
|
111
|
+
Compute the 6D force applied to each collidable point and the corresponding
|
112
|
+
material deformation rate.
|
113
|
+
|
114
|
+
Args:
|
115
|
+
model: The model to consider.
|
116
|
+
data: The data of the considered model.
|
117
|
+
|
118
|
+
Returns:
|
119
|
+
The 6D force applied to each collidable point and the corresponding
|
120
|
+
material deformation rate.
|
121
|
+
|
122
|
+
Note:
|
123
|
+
The material deformation rate is always returned in the mixed frame
|
124
|
+
`C[W] = ({}^W \mathbf{p}_C, [W])`. This is convenient for integration purpose.
|
125
|
+
Instead, the 6D forces are returned in the active representation.
|
126
|
+
"""
|
127
|
+
|
128
|
+
# Compute the position and linear velocities (mixed representation) of
|
129
|
+
# all collidable points belonging to the robot.
|
130
|
+
W_p_Ci, W_ṗ_Ci = js.contact.collidable_point_kinematics(model=model, data=data)
|
131
|
+
|
132
|
+
# Build the soft contact model.
|
133
|
+
soft_contacts = jaxsim.rbda.SoftContacts(
|
134
|
+
parameters=data.soft_contacts_params, terrain=model.terrain
|
135
|
+
)
|
136
|
+
|
137
|
+
# Compute the 6D force expressed in the inertial frame and applied to each
|
138
|
+
# collidable point, and the corresponding material deformation rate.
|
139
|
+
# Note that the material deformation rate is always returned in the mixed frame
|
140
|
+
# C[W] = (W_p_C, [W]). This is convenient for integration purpose.
|
141
|
+
W_f_Ci, CW_ṁ = jax.vmap(soft_contacts.contact_model)(
|
142
|
+
W_p_Ci, W_ṗ_Ci, data.state.soft_contacts.tangential_deformation
|
143
|
+
)
|
144
|
+
|
145
|
+
# Convert the 6D forces to the active representation.
|
146
|
+
f_Ci = jax.vmap(
|
147
|
+
lambda W_f_C: data.inertial_to_other_representation(
|
148
|
+
array=W_f_C,
|
149
|
+
other_representation=data.velocity_representation,
|
150
|
+
transform=data.base_transform(),
|
151
|
+
is_force=True,
|
152
|
+
)
|
153
|
+
)(W_f_Ci)
|
154
|
+
|
155
|
+
return f_Ci, CW_ṁ
|
156
|
+
|
157
|
+
|
81
158
|
@functools.partial(jax.jit, static_argnames=["link_names"])
|
82
159
|
def in_contact(
|
83
|
-
model:
|
84
|
-
data:
|
160
|
+
model: js.model.JaxSimModel,
|
161
|
+
data: js.data.JaxSimModelData,
|
85
162
|
*,
|
86
163
|
link_names: tuple[str, ...] | None = None,
|
87
164
|
) -> jtp.Vector:
|
@@ -100,48 +177,44 @@ def in_contact(
|
|
100
177
|
|
101
178
|
link_names = link_names if link_names is not None else model.link_names()
|
102
179
|
|
103
|
-
if set(link_names)
|
180
|
+
if set(link_names).difference(model.link_names()):
|
104
181
|
raise ValueError("One or more link names are not part of the model")
|
105
182
|
|
106
|
-
|
107
|
-
|
108
|
-
W_p_Ci, _ = collidable_points_pos_vel(
|
109
|
-
model=model.physics_model,
|
110
|
-
q=data.state.physics_model.joint_positions,
|
111
|
-
qd=data.state.physics_model.joint_velocities,
|
112
|
-
xfb=data.state.physics_model.xfb(),
|
113
|
-
)
|
183
|
+
W_p_Ci = collidable_point_positions(model=model, data=data)
|
114
184
|
|
115
185
|
terrain_height = jax.vmap(lambda x, y: model.terrain.height(x=x, y=y))(
|
116
|
-
W_p_Ci[0
|
186
|
+
W_p_Ci[:, 0], W_p_Ci[:, 1]
|
117
187
|
)
|
118
188
|
|
119
|
-
below_terrain = W_p_Ci[2
|
189
|
+
below_terrain = W_p_Ci[:, 2] <= terrain_height
|
120
190
|
|
121
191
|
links_in_contact = jax.vmap(
|
122
192
|
lambda link_index: jnp.where(
|
123
|
-
model.
|
193
|
+
jnp.array(model.kin_dyn_parameters.contact_parameters.body) == link_index,
|
124
194
|
below_terrain,
|
125
195
|
jnp.zeros_like(below_terrain, dtype=bool),
|
126
196
|
).any()
|
127
|
-
)(
|
197
|
+
)(js.link.names_to_idxs(link_names=link_names, model=model))
|
128
198
|
|
129
199
|
return links_in_contact
|
130
200
|
|
131
201
|
|
132
202
|
@jax.jit
|
133
203
|
def estimate_good_soft_contacts_parameters(
|
134
|
-
model:
|
204
|
+
model: js.model.JaxSimModel,
|
205
|
+
*,
|
206
|
+
standard_gravity: jtp.FloatLike = jaxsim.math.StandardGravity,
|
135
207
|
static_friction_coefficient: jtp.FloatLike = 0.5,
|
136
208
|
number_of_active_collidable_points_steady_state: jtp.IntLike = 1,
|
137
209
|
damping_ratio: jtp.FloatLike = 1.0,
|
138
210
|
max_penetration: jtp.FloatLike | None = None,
|
139
|
-
) -> soft_contacts.SoftContactsParams:
|
211
|
+
) -> jaxsim.rbda.soft_contacts.SoftContactsParams:
|
140
212
|
"""
|
141
213
|
Estimate good soft contacts parameters for the given model.
|
142
214
|
|
143
215
|
Args:
|
144
216
|
model: The model to consider.
|
217
|
+
standard_gravity: The standard gravity constant.
|
145
218
|
static_friction_coefficient: The static friction coefficient.
|
146
219
|
number_of_active_collidable_points_steady_state:
|
147
220
|
The number of active collidable points in steady state supporting
|
@@ -160,16 +233,17 @@ def estimate_good_soft_contacts_parameters(
|
|
160
233
|
specific application.
|
161
234
|
"""
|
162
235
|
|
163
|
-
def estimate_model_height(model:
|
236
|
+
def estimate_model_height(model: js.model.JaxSimModel) -> jtp.Float:
|
164
237
|
""""""
|
165
238
|
|
166
|
-
zero_data =
|
167
|
-
model=model,
|
239
|
+
zero_data = js.data.JaxSimModelData.build(
|
240
|
+
model=model,
|
241
|
+
soft_contacts_params=jaxsim.rbda.soft_contacts.SoftContactsParams(),
|
168
242
|
)
|
169
243
|
|
170
|
-
W_pz_CoM =
|
244
|
+
W_pz_CoM = js.com.com_position(model=model, data=zero_data)[2]
|
171
245
|
|
172
|
-
if model.
|
246
|
+
if model.floating_base():
|
173
247
|
W_pz_C = collidable_point_positions(model=model, data=zero_data)[:, -1]
|
174
248
|
return 2 * (W_pz_CoM - W_pz_C.min())
|
175
249
|
|
@@ -183,12 +257,15 @@ def estimate_good_soft_contacts_parameters(
|
|
183
257
|
|
184
258
|
nc = number_of_active_collidable_points_steady_state
|
185
259
|
|
186
|
-
sc_parameters =
|
187
|
-
|
188
|
-
|
189
|
-
|
190
|
-
|
191
|
-
|
260
|
+
sc_parameters = (
|
261
|
+
jaxsim.rbda.soft_contacts.SoftContactsParams.build_default_from_jaxsim_model(
|
262
|
+
model=model,
|
263
|
+
standard_gravity=standard_gravity,
|
264
|
+
static_friction_coefficient=static_friction_coefficient,
|
265
|
+
max_penetration=max_δ,
|
266
|
+
number_of_active_collidable_points_steady_state=nc,
|
267
|
+
damping_ratio=damping_ratio,
|
268
|
+
)
|
192
269
|
)
|
193
270
|
|
194
271
|
return sc_parameters
|