jaxsim 0.6.1.dev13__py3-none-any.whl → 0.6.2.dev102__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jaxsim/__init__.py +1 -1
- jaxsim/_version.py +2 -2
- jaxsim/api/__init__.py +3 -1
- jaxsim/api/actuation_model.py +96 -0
- jaxsim/api/com.py +8 -8
- jaxsim/api/contact.py +15 -255
- jaxsim/api/contact_model.py +101 -0
- jaxsim/api/data.py +258 -556
- jaxsim/api/frame.py +7 -7
- jaxsim/api/integrators.py +76 -0
- jaxsim/api/kin_dyn_parameters.py +41 -58
- jaxsim/api/link.py +7 -7
- jaxsim/api/model.py +190 -453
- jaxsim/api/ode.py +34 -338
- jaxsim/api/references.py +2 -2
- jaxsim/exceptions.py +2 -2
- jaxsim/math/__init__.py +4 -3
- jaxsim/math/joint_model.py +17 -107
- jaxsim/mujoco/model.py +1 -1
- jaxsim/mujoco/utils.py +2 -2
- jaxsim/parsers/kinematic_graph.py +1 -3
- jaxsim/rbda/aba.py +7 -4
- jaxsim/rbda/collidable_points.py +7 -98
- jaxsim/rbda/contacts/__init__.py +2 -10
- jaxsim/rbda/contacts/common.py +0 -138
- jaxsim/rbda/contacts/relaxed_rigid.py +154 -9
- jaxsim/rbda/crba.py +5 -2
- jaxsim/rbda/forward_kinematics.py +37 -12
- jaxsim/rbda/jacobian.py +15 -6
- jaxsim/rbda/rnea.py +7 -4
- jaxsim/rbda/utils.py +3 -3
- jaxsim/utils/jaxsim_dataclass.py +5 -1
- {jaxsim-0.6.1.dev13.dist-info → jaxsim-0.6.2.dev102.dist-info}/METADATA +7 -9
- jaxsim-0.6.2.dev102.dist-info/RECORD +69 -0
- jaxsim/api/ode_data.py +0 -401
- jaxsim/integrators/__init__.py +0 -2
- jaxsim/integrators/common.py +0 -592
- jaxsim/integrators/fixed_step.py +0 -153
- jaxsim/integrators/variable_step.py +0 -706
- jaxsim/rbda/contacts/rigid.py +0 -462
- jaxsim/rbda/contacts/soft.py +0 -480
- jaxsim/rbda/contacts/visco_elastic.py +0 -1066
- jaxsim-0.6.1.dev13.dist-info/RECORD +0 -74
- {jaxsim-0.6.1.dev13.dist-info → jaxsim-0.6.2.dev102.dist-info}/LICENSE +0 -0
- {jaxsim-0.6.1.dev13.dist-info → jaxsim-0.6.2.dev102.dist-info}/WHEEL +0 -0
- {jaxsim-0.6.1.dev13.dist-info → jaxsim-0.6.2.dev102.dist-info}/top_level.txt +0 -0
jaxsim/mujoco/model.py
CHANGED
@@ -148,7 +148,7 @@ class MujocoModelHelper:
|
|
148
148
|
def gravity(self) -> npt.NDArray:
|
149
149
|
"""Return the 3D gravity vector."""
|
150
150
|
|
151
|
-
return self.model.
|
151
|
+
return np.array([0, 0, self.model.gravity])
|
152
152
|
|
153
153
|
# =========================
|
154
154
|
# Methods for the base link
|
jaxsim/mujoco/utils.py
CHANGED
@@ -59,11 +59,11 @@ def mujoco_data_from_jaxsim(
|
|
59
59
|
if jaxsim_model.floating_base():
|
60
60
|
|
61
61
|
# Set the model position.
|
62
|
-
model_helper.set_base_position(position=np.array(jaxsim_data.base_position
|
62
|
+
model_helper.set_base_position(position=np.array(jaxsim_data.base_position))
|
63
63
|
|
64
64
|
# Set the model orientation.
|
65
65
|
model_helper.set_base_orientation(
|
66
|
-
orientation=np.array(jaxsim_data.base_orientation
|
66
|
+
orientation=np.array(jaxsim_data.base_orientation)
|
67
67
|
)
|
68
68
|
|
69
69
|
# Set the joint positions.
|
@@ -952,9 +952,7 @@ class KinematicGraphTransforms:
|
|
952
952
|
import jaxsim.math
|
953
953
|
|
954
954
|
return np.array(
|
955
|
-
jaxsim.math.supported_joint_motion(joint_type, joint_position, joint_axis)
|
956
|
-
0
|
957
|
-
]
|
955
|
+
jaxsim.math.supported_joint_motion(joint_type, joint_position, joint_axis)
|
958
956
|
)
|
959
957
|
|
960
958
|
def find_parent_link_of_frame(self, name: str) -> str:
|
jaxsim/rbda/aba.py
CHANGED
@@ -4,7 +4,7 @@ import jaxlie
|
|
4
4
|
|
5
5
|
import jaxsim.api as js
|
6
6
|
import jaxsim.typing as jtp
|
7
|
-
from jaxsim.math import Adjoint, Cross
|
7
|
+
from jaxsim.math import STANDARD_GRAVITY, Adjoint, Cross
|
8
8
|
|
9
9
|
from . import utils
|
10
10
|
|
@@ -20,7 +20,7 @@ def aba(
|
|
20
20
|
joint_velocities: jtp.VectorLike,
|
21
21
|
joint_forces: jtp.VectorLike | None = None,
|
22
22
|
link_forces: jtp.MatrixLike | None = None,
|
23
|
-
standard_gravity: jtp.FloatLike =
|
23
|
+
standard_gravity: jtp.FloatLike = STANDARD_GRAVITY,
|
24
24
|
) -> tuple[jtp.Vector, jtp.Vector]:
|
25
25
|
"""
|
26
26
|
Compute forward dynamics using the Articulated Body Algorithm (ABA).
|
@@ -85,13 +85,16 @@ def aba(
|
|
85
85
|
W_X_B = W_H_B.adjoint()
|
86
86
|
B_X_W = W_H_B.inverse().adjoint()
|
87
87
|
|
88
|
-
# Compute the parent-to-child adjoints
|
88
|
+
# Compute the parent-to-child adjoints of the joints.
|
89
89
|
# These transforms define the relative kinematics of the entire model, including
|
90
90
|
# the base transform for both floating-base and fixed-base models.
|
91
|
-
i_X_λi
|
91
|
+
i_X_λi = model.kin_dyn_parameters.joint_transforms(
|
92
92
|
joint_positions=s, base_transform=W_H_B.as_matrix()
|
93
93
|
)
|
94
94
|
|
95
|
+
# Extract the joint motion subspaces.
|
96
|
+
S = model.kin_dyn_parameters.motion_subspaces
|
97
|
+
|
95
98
|
# Allocate buffers.
|
96
99
|
v = jnp.zeros(shape=(model.number_of_links(), 6, 1))
|
97
100
|
c = jnp.zeros(shape=(model.number_of_links(), 6, 1))
|
jaxsim/rbda/collidable_points.py
CHANGED
@@ -1,23 +1,16 @@
|
|
1
1
|
import jax
|
2
2
|
import jax.numpy as jnp
|
3
|
-
import jaxlie
|
4
3
|
|
5
4
|
import jaxsim.api as js
|
6
5
|
import jaxsim.typing as jtp
|
7
|
-
from jaxsim.math import
|
8
|
-
|
9
|
-
from . import utils
|
6
|
+
from jaxsim.math import Skew
|
10
7
|
|
11
8
|
|
12
9
|
def collidable_points_pos_vel(
|
13
10
|
model: js.model.JaxSimModel,
|
14
11
|
*,
|
15
|
-
|
16
|
-
|
17
|
-
joint_positions: jtp.Vector,
|
18
|
-
base_linear_velocity: jtp.Vector,
|
19
|
-
base_angular_velocity: jtp.Vector,
|
20
|
-
joint_velocities: jtp.Vector,
|
12
|
+
link_transforms: jtp.Matrix,
|
13
|
+
link_velocities: jtp.Matrix,
|
21
14
|
) -> tuple[jtp.Matrix, jtp.Matrix]:
|
22
15
|
"""
|
23
16
|
|
@@ -25,14 +18,8 @@ def collidable_points_pos_vel(
|
|
25
18
|
|
26
19
|
Args:
|
27
20
|
model: The model to consider.
|
28
|
-
|
29
|
-
|
30
|
-
joint_positions: The positions of the joints.
|
31
|
-
base_linear_velocity:
|
32
|
-
The linear velocity of the base link in inertial-fixed representation.
|
33
|
-
base_angular_velocity:
|
34
|
-
The angular velocity of the base link in inertial-fixed representation.
|
35
|
-
joint_velocities: The velocities of the joints.
|
21
|
+
link_transforms: The transforms from the world frame to each link.
|
22
|
+
link_velocities: The linear and angular velocities of each link.
|
36
23
|
|
37
24
|
Returns:
|
38
25
|
A tuple containing the position and linear velocity of the enabled collidable points.
|
@@ -54,95 +41,17 @@ def collidable_points_pos_vel(
|
|
54
41
|
if len(indices_of_enabled_collidable_points) == 0:
|
55
42
|
return jnp.array(0).astype(float), jnp.empty(0).astype(float)
|
56
43
|
|
57
|
-
W_p_B, W_Q_B, s, W_v_WB, ṡ, _, _, _, _, _ = utils.process_inputs(
|
58
|
-
model=model,
|
59
|
-
base_position=base_position,
|
60
|
-
base_quaternion=base_quaternion,
|
61
|
-
joint_positions=joint_positions,
|
62
|
-
base_linear_velocity=base_linear_velocity,
|
63
|
-
base_angular_velocity=base_angular_velocity,
|
64
|
-
joint_velocities=joint_velocities,
|
65
|
-
)
|
66
|
-
|
67
|
-
# Get the parent array λ(i).
|
68
|
-
# Note: λ(0) must not be used, it's initialized to -1.
|
69
|
-
λ = model.kin_dyn_parameters.parent_array
|
70
|
-
|
71
|
-
# Compute the base transform.
|
72
|
-
W_H_B = jaxlie.SE3.from_rotation_and_translation(
|
73
|
-
rotation=jaxlie.SO3(wxyz=W_Q_B),
|
74
|
-
translation=W_p_B,
|
75
|
-
)
|
76
|
-
|
77
|
-
# Compute the parent-to-child adjoints and the motion subspaces of the joints.
|
78
|
-
# These transforms define the relative kinematics of the entire model, including
|
79
|
-
# the base transform for both floating-base and fixed-base models.
|
80
|
-
i_X_λi, S = model.kin_dyn_parameters.joint_transforms_and_motion_subspaces(
|
81
|
-
joint_positions=s, base_transform=W_H_B.as_matrix()
|
82
|
-
)
|
83
|
-
|
84
|
-
# Allocate buffer of transforms world -> link and initialize the base pose.
|
85
|
-
W_X_i = jnp.zeros(shape=(model.number_of_links(), 6, 6))
|
86
|
-
W_X_i = W_X_i.at[0].set(Adjoint.inverse(i_X_λi[0]))
|
87
|
-
|
88
|
-
# Allocate buffer of 6D inertial-fixed velocities and initialize the base velocity.
|
89
|
-
W_v_Wi = jnp.zeros(shape=(model.number_of_links(), 6))
|
90
|
-
W_v_Wi = W_v_Wi.at[0].set(W_v_WB)
|
91
|
-
|
92
|
-
# ====================
|
93
|
-
# Propagate kinematics
|
94
|
-
# ====================
|
95
|
-
|
96
|
-
PropagateTransformsCarry = tuple[jtp.Matrix, jtp.Matrix]
|
97
|
-
propagate_transforms_carry: PropagateTransformsCarry = (W_X_i, W_v_Wi)
|
98
|
-
|
99
|
-
def propagate_kinematics(
|
100
|
-
carry: PropagateTransformsCarry, i: jtp.Int
|
101
|
-
) -> tuple[PropagateTransformsCarry, None]:
|
102
|
-
|
103
|
-
ii = i - 1
|
104
|
-
W_X_i, W_v_Wi = carry
|
105
|
-
|
106
|
-
# Compute the parent to child 6D transform.
|
107
|
-
λi_X_i = Adjoint.inverse(adjoint=i_X_λi[i])
|
108
|
-
|
109
|
-
# Compute the world to child 6D transform.
|
110
|
-
W_Xi_i = W_X_i[λ[i]] @ λi_X_i
|
111
|
-
W_X_i = W_X_i.at[i].set(W_Xi_i)
|
112
|
-
|
113
|
-
# Propagate the 6D velocity.
|
114
|
-
W_vi_Wi = W_v_Wi[λ[i]] + W_X_i[i] @ (S[i] * ṡ[ii]).squeeze()
|
115
|
-
W_v_Wi = W_v_Wi.at[i].set(W_vi_Wi)
|
116
|
-
|
117
|
-
return (W_X_i, W_v_Wi), None
|
118
|
-
|
119
|
-
(W_X_i, W_v_Wi), _ = (
|
120
|
-
jax.lax.scan(
|
121
|
-
f=propagate_kinematics,
|
122
|
-
init=propagate_transforms_carry,
|
123
|
-
xs=jnp.arange(start=1, stop=model.number_of_links()),
|
124
|
-
)
|
125
|
-
if model.number_of_links() > 1
|
126
|
-
else [(W_X_i, W_v_Wi), None]
|
127
|
-
)
|
128
|
-
|
129
|
-
# ==================================================
|
130
|
-
# Compute position and velocity of collidable points
|
131
|
-
# ==================================================
|
132
|
-
|
133
44
|
def process_point_kinematics(
|
134
45
|
Li_p_C: jtp.Vector, parent_body: jtp.Int
|
135
46
|
) -> tuple[jtp.Vector, jtp.Vector]:
|
136
47
|
|
137
48
|
# Compute the position of the collidable point.
|
138
|
-
W_p_Ci = (
|
139
|
-
Adjoint.to_transform(adjoint=W_X_i[parent_body]) @ jnp.hstack([Li_p_C, 1])
|
140
|
-
)[0:3]
|
49
|
+
W_p_Ci = (link_transforms[parent_body] @ jnp.hstack([Li_p_C, 1]))[0:3]
|
141
50
|
|
142
51
|
# Compute the linear part of the mixed velocity Ci[W]_v_{W,Ci}.
|
143
52
|
CW_vl_WCi = (
|
144
53
|
jnp.block([jnp.eye(3), -Skew.wedge(vector=W_p_Ci).squeeze()])
|
145
|
-
@
|
54
|
+
@ link_velocities[parent_body].squeeze()
|
146
55
|
)
|
147
56
|
|
148
57
|
return W_p_Ci, CW_vl_WCi
|
jaxsim/rbda/contacts/__init__.py
CHANGED
@@ -1,13 +1,5 @@
|
|
1
|
-
from . import relaxed_rigid
|
1
|
+
from . import relaxed_rigid
|
2
2
|
from .common import ContactModel, ContactsParams
|
3
3
|
from .relaxed_rigid import RelaxedRigidContacts, RelaxedRigidContactsParams
|
4
|
-
from .rigid import RigidContacts, RigidContactsParams
|
5
|
-
from .soft import SoftContacts, SoftContactsParams
|
6
|
-
from .visco_elastic import ViscoElasticContacts, ViscoElasticContactsParams
|
7
4
|
|
8
|
-
ContactParamsTypes =
|
9
|
-
SoftContactsParams
|
10
|
-
| RigidContactsParams
|
11
|
-
| RelaxedRigidContactsParams
|
12
|
-
| ViscoElasticContactsParams
|
13
|
-
)
|
5
|
+
ContactParamsTypes = RelaxedRigidContactsParams
|
jaxsim/rbda/contacts/common.py
CHANGED
@@ -9,7 +9,6 @@ import jax.numpy as jnp
|
|
9
9
|
import jaxsim.api as js
|
10
10
|
import jaxsim.terrain
|
11
11
|
import jaxsim.typing as jtp
|
12
|
-
from jaxsim.api.common import ModelDataWithVelocityRepresentation
|
13
12
|
from jaxsim.utils import JaxsimDataclass
|
14
13
|
|
15
14
|
try:
|
@@ -135,143 +134,6 @@ class ContactModel(JaxsimDataclass):
|
|
135
134
|
|
136
135
|
pass
|
137
136
|
|
138
|
-
def compute_link_contact_forces(
|
139
|
-
self,
|
140
|
-
model: js.model.JaxSimModel,
|
141
|
-
data: js.data.JaxSimModelData,
|
142
|
-
**kwargs,
|
143
|
-
) -> tuple[jtp.Matrix, dict[str, jtp.PyTree]]:
|
144
|
-
"""
|
145
|
-
Compute the link contact forces.
|
146
|
-
|
147
|
-
Args:
|
148
|
-
model: The robot model considered by the contact model.
|
149
|
-
data: The data of the considered model.
|
150
|
-
**kwargs: Optional additional arguments, specific to the contact model.
|
151
|
-
|
152
|
-
Returns:
|
153
|
-
A tuple containing as first element the 6D contact force applied to the
|
154
|
-
links and expressed in the frame of the velocity representation of data,
|
155
|
-
and as second element a dictionary of optional additional information.
|
156
|
-
"""
|
157
|
-
|
158
|
-
# Compute the contact forces expressed in the inertial frame.
|
159
|
-
# This function, contrarily to `compute_contact_forces`, already handles how
|
160
|
-
# the optional kwargs should be passed to the specific contact models.
|
161
|
-
W_f_C, aux_dict = js.contact.collidable_point_dynamics(
|
162
|
-
model=model, data=data, **kwargs
|
163
|
-
)
|
164
|
-
|
165
|
-
# Compute the 6D forces applied to the links equivalent to the forces applied
|
166
|
-
# to the frames associated to the collidable points.
|
167
|
-
with data.switch_velocity_representation(jaxsim.VelRepr.Inertial):
|
168
|
-
|
169
|
-
W_f_L = self.link_forces_from_contact_forces(
|
170
|
-
model=model, data=data, contact_forces=W_f_C
|
171
|
-
)
|
172
|
-
|
173
|
-
# Store the link forces in the references object for easy conversion.
|
174
|
-
references = js.references.JaxSimModelReferences.build(
|
175
|
-
model=model,
|
176
|
-
data=data,
|
177
|
-
link_forces=W_f_L,
|
178
|
-
velocity_representation=jaxsim.VelRepr.Inertial,
|
179
|
-
)
|
180
|
-
|
181
|
-
# Convert the link forces to the frame corresponding to the velocity
|
182
|
-
# representation of data.
|
183
|
-
with references.switch_velocity_representation(data.velocity_representation):
|
184
|
-
f_L = references.link_forces(model=model, data=data)
|
185
|
-
|
186
|
-
return f_L, aux_dict
|
187
|
-
|
188
|
-
@staticmethod
|
189
|
-
def link_forces_from_contact_forces(
|
190
|
-
model: js.model.JaxSimModel,
|
191
|
-
data: js.data.JaxSimModelData,
|
192
|
-
*,
|
193
|
-
contact_forces: jtp.MatrixLike,
|
194
|
-
) -> jtp.Matrix:
|
195
|
-
"""
|
196
|
-
Compute the link forces from the contact forces.
|
197
|
-
|
198
|
-
Args:
|
199
|
-
model: The robot model considered by the contact model.
|
200
|
-
data: The data of the considered model.
|
201
|
-
contact_forces: The contact forces computed by the contact model.
|
202
|
-
|
203
|
-
Returns:
|
204
|
-
The 6D contact forces applied to the links and expressed in the frame of
|
205
|
-
the velocity representation of data.
|
206
|
-
"""
|
207
|
-
|
208
|
-
# Get the object storing the contact parameters of the model.
|
209
|
-
contact_parameters = model.kin_dyn_parameters.contact_parameters
|
210
|
-
|
211
|
-
# Extract the indices corresponding to the enabled collidable points.
|
212
|
-
indices_of_enabled_collidable_points = (
|
213
|
-
contact_parameters.indices_of_enabled_collidable_points
|
214
|
-
)
|
215
|
-
|
216
|
-
# Convert the contact forces to a JAX array.
|
217
|
-
f_C = jnp.atleast_2d(jnp.array(contact_forces, dtype=float).squeeze())
|
218
|
-
|
219
|
-
# Get the pose of the enabled collidable points.
|
220
|
-
W_H_C = js.contact.transforms(model=model, data=data)[
|
221
|
-
indices_of_enabled_collidable_points
|
222
|
-
]
|
223
|
-
|
224
|
-
# Convert the contact forces to inertial-fixed representation.
|
225
|
-
W_f_C = jax.vmap(
|
226
|
-
lambda f_C, W_H_C: (
|
227
|
-
ModelDataWithVelocityRepresentation.other_representation_to_inertial(
|
228
|
-
array=f_C,
|
229
|
-
other_representation=data.velocity_representation,
|
230
|
-
transform=W_H_C,
|
231
|
-
is_force=True,
|
232
|
-
)
|
233
|
-
)
|
234
|
-
)(f_C, W_H_C)
|
235
|
-
|
236
|
-
# Construct the vector defining the parent link index of each collidable point.
|
237
|
-
# We use this vector to sum the 6D forces of all collidable points rigidly
|
238
|
-
# attached to the same link.
|
239
|
-
parent_link_index_of_collidable_points = jnp.array(
|
240
|
-
contact_parameters.body, dtype=int
|
241
|
-
)[indices_of_enabled_collidable_points]
|
242
|
-
|
243
|
-
# Create the mask that associate each collidable point to their parent link.
|
244
|
-
# We use this mask to sum the collidable points to the right link.
|
245
|
-
mask = parent_link_index_of_collidable_points[:, jnp.newaxis] == jnp.arange(
|
246
|
-
model.number_of_links()
|
247
|
-
)
|
248
|
-
|
249
|
-
# Sum the forces of all collidable points rigidly attached to a body.
|
250
|
-
# Since the contact forces W_f_C are expressed in the world frame,
|
251
|
-
# we don't need any coordinate transformation.
|
252
|
-
W_f_L = mask.T @ W_f_C
|
253
|
-
|
254
|
-
# Compute the link transforms.
|
255
|
-
W_H_L = (
|
256
|
-
js.model.forward_kinematics(model=model, data=data)
|
257
|
-
if data.velocity_representation is not jaxsim.VelRepr.Inertial
|
258
|
-
else jnp.zeros(shape=(model.number_of_links(), 4, 4))
|
259
|
-
)
|
260
|
-
|
261
|
-
# Convert the inertial-fixed link forces to the velocity representation of data.
|
262
|
-
f_L = jax.vmap(
|
263
|
-
lambda W_f_L, W_H_L: (
|
264
|
-
ModelDataWithVelocityRepresentation.inertial_to_other_representation(
|
265
|
-
array=W_f_L,
|
266
|
-
other_representation=data.velocity_representation,
|
267
|
-
transform=W_H_L,
|
268
|
-
is_force=True,
|
269
|
-
)
|
270
|
-
)
|
271
|
-
)(W_f_L, W_H_L)
|
272
|
-
|
273
|
-
return f_L
|
274
|
-
|
275
137
|
@classmethod
|
276
138
|
def zero_state_variables(cls, model: js.model.JaxSimModel) -> dict[str, jtp.Array]:
|
277
139
|
"""
|
@@ -1,6 +1,7 @@
|
|
1
1
|
from __future__ import annotations
|
2
2
|
|
3
3
|
import dataclasses
|
4
|
+
import functools
|
4
5
|
from collections.abc import Callable
|
5
6
|
from typing import Any
|
6
7
|
|
@@ -13,6 +14,7 @@ import jaxsim.api as js
|
|
13
14
|
import jaxsim.rbda.contacts
|
14
15
|
import jaxsim.typing as jtp
|
15
16
|
from jaxsim.api.common import ModelDataWithVelocityRepresentation, VelRepr
|
17
|
+
from jaxsim.terrain.terrain import Terrain
|
16
18
|
|
17
19
|
from . import common
|
18
20
|
|
@@ -263,7 +265,7 @@ class RelaxedRigidContacts(common.ContactModel):
|
|
263
265
|
Optional `(n_joints,)` vector of joint forces.
|
264
266
|
|
265
267
|
Returns:
|
266
|
-
A tuple containing as first element the computed contact forces.
|
268
|
+
A tuple containing as first element the computed contact forces in inertial representation.
|
267
269
|
"""
|
268
270
|
|
269
271
|
link_forces = jnp.atleast_2d(
|
@@ -306,20 +308,17 @@ class RelaxedRigidContacts(common.ContactModel):
|
|
306
308
|
W_H_C = js.contact.transforms(model=model, data=data)
|
307
309
|
|
308
310
|
with (
|
309
|
-
references.switch_velocity_representation(VelRepr.Mixed),
|
310
311
|
data.switch_velocity_representation(VelRepr.Mixed),
|
312
|
+
references.switch_velocity_representation(VelRepr.Mixed),
|
311
313
|
):
|
312
|
-
|
313
|
-
BW_ν = data.generalized_velocity()
|
314
|
+
BW_ν = data.generalized_velocity
|
314
315
|
|
315
316
|
BW_ν̇_free = jnp.hstack(
|
316
317
|
js.ode.system_acceleration(
|
317
318
|
model=model,
|
318
319
|
data=data,
|
319
320
|
link_forces=references.link_forces(model=model, data=data),
|
320
|
-
|
321
|
-
model=model
|
322
|
-
),
|
321
|
+
joint_torques=references.joint_force_references(model=model),
|
323
322
|
)
|
324
323
|
)
|
325
324
|
|
@@ -342,7 +341,7 @@ class RelaxedRigidContacts(common.ContactModel):
|
|
342
341
|
model=model,
|
343
342
|
position_constraint=position_constraint,
|
344
343
|
velocity_constraint=velocity,
|
345
|
-
parameters=
|
344
|
+
parameters=model.contacts_params,
|
346
345
|
)
|
347
346
|
|
348
347
|
# Compute the Delassus matrix and the free mixed linear acceleration of
|
@@ -426,7 +425,7 @@ class RelaxedRigidContacts(common.ContactModel):
|
|
426
425
|
|
427
426
|
# Initialize the optimized forces with a linear Hunt/Crossley model.
|
428
427
|
init_params = jax.vmap(
|
429
|
-
lambda p, v:
|
428
|
+
lambda p, v: self._hunt_crossley_contact_model(
|
430
429
|
position=p,
|
431
430
|
velocity=v,
|
432
431
|
terrain=model.terrain,
|
@@ -603,3 +602,149 @@ class RelaxedRigidContacts(common.ContactModel):
|
|
603
602
|
)
|
604
603
|
|
605
604
|
return a_ref, jnp.diag(R), K, D
|
605
|
+
|
606
|
+
@staticmethod
|
607
|
+
@functools.partial(jax.jit, static_argnames=("terrain",))
|
608
|
+
def _hunt_crossley_contact_model(
|
609
|
+
position: jtp.VectorLike,
|
610
|
+
velocity: jtp.VectorLike,
|
611
|
+
tangential_deformation: jtp.VectorLike,
|
612
|
+
terrain: Terrain,
|
613
|
+
K: jtp.FloatLike,
|
614
|
+
D: jtp.FloatLike,
|
615
|
+
mu: jtp.FloatLike,
|
616
|
+
p: jtp.FloatLike = 0.5,
|
617
|
+
q: jtp.FloatLike = 0.5,
|
618
|
+
) -> tuple[jtp.Vector, jtp.Vector]:
|
619
|
+
"""
|
620
|
+
Compute the contact force using the Hunt/Crossley model.
|
621
|
+
|
622
|
+
Args:
|
623
|
+
position: The position of the collidable point.
|
624
|
+
velocity: The velocity of the collidable point.
|
625
|
+
tangential_deformation: The material deformation of the collidable point.
|
626
|
+
terrain: The terrain model.
|
627
|
+
K: The stiffness parameter.
|
628
|
+
D: The damping parameter of the soft contacts model.
|
629
|
+
mu: The static friction coefficient.
|
630
|
+
p:
|
631
|
+
The exponent p corresponding to the damping-related non-linearity
|
632
|
+
of the Hunt/Crossley model.
|
633
|
+
q:
|
634
|
+
The exponent q corresponding to the spring-related non-linearity
|
635
|
+
of the Hunt/Crossley model
|
636
|
+
|
637
|
+
Returns:
|
638
|
+
A tuple containing the computed contact force and the derivative of the
|
639
|
+
material deformation.
|
640
|
+
"""
|
641
|
+
|
642
|
+
# Convert the input vectors to arrays.
|
643
|
+
W_p_C = jnp.array(position, dtype=float).squeeze()
|
644
|
+
W_ṗ_C = jnp.array(velocity, dtype=float).squeeze()
|
645
|
+
m = jnp.array(tangential_deformation, dtype=float).squeeze()
|
646
|
+
|
647
|
+
# Use symbol for the static friction.
|
648
|
+
μ = mu
|
649
|
+
|
650
|
+
# Compute the penetration depth, its rate, and the considered terrain normal.
|
651
|
+
δ, δ̇, n̂ = common.compute_penetration_data(p=W_p_C, v=W_ṗ_C, terrain=terrain)
|
652
|
+
|
653
|
+
# There are few operations like computing the norm of a vector with zero length
|
654
|
+
# or computing the square root of zero that are problematic in an AD context.
|
655
|
+
# To avoid these issues, we introduce a small tolerance ε to their arguments
|
656
|
+
# and make sure that we do not check them against zero directly.
|
657
|
+
ε = jnp.finfo(float).eps
|
658
|
+
|
659
|
+
# Compute the powers of the penetration depth.
|
660
|
+
# Inject ε to address AD issues in differentiating the square root when
|
661
|
+
# p and q are fractional.
|
662
|
+
δp = jnp.power(δ + ε, p)
|
663
|
+
δq = jnp.power(δ + ε, q)
|
664
|
+
|
665
|
+
# ========================
|
666
|
+
# Compute the normal force
|
667
|
+
# ========================
|
668
|
+
|
669
|
+
# Non-linear spring-damper model (Hunt/Crossley model).
|
670
|
+
# This is the force magnitude along the direction normal to the terrain.
|
671
|
+
force_normal_mag = (K * δp) * δ + (D * δq) * δ̇
|
672
|
+
|
673
|
+
# Depending on the magnitude of δ̇, the normal force could be negative.
|
674
|
+
force_normal_mag = jnp.maximum(0.0, force_normal_mag)
|
675
|
+
|
676
|
+
# Compute the 3D linear force in C[W] frame.
|
677
|
+
f_normal = force_normal_mag * n̂
|
678
|
+
|
679
|
+
# ============================
|
680
|
+
# Compute the tangential force
|
681
|
+
# ============================
|
682
|
+
|
683
|
+
# Extract the tangential component of the velocity.
|
684
|
+
v_tangential = W_ṗ_C - jnp.dot(W_ṗ_C, n̂) * n̂
|
685
|
+
|
686
|
+
# Extract the normal and tangential components of the material deformation.
|
687
|
+
m_normal = jnp.dot(m, n̂) * n̂
|
688
|
+
m_tangential = m - jnp.dot(m, n̂) * n̂
|
689
|
+
|
690
|
+
# Compute the tangential force in the sticking case.
|
691
|
+
# Using the tangential component of the material deformation should not be
|
692
|
+
# necessary if the sticking-slipping transition occurs in a terrain area
|
693
|
+
# with a locally constant normal. However, this assumption is not true in
|
694
|
+
# general, especially for highly uneven terrains.
|
695
|
+
f_tangential = -((K * δp) * m_tangential + (D * δq) * v_tangential)
|
696
|
+
|
697
|
+
# Detect the contact type (sticking or slipping).
|
698
|
+
# Note that if there is no contact, sticking is set to True, and this detail
|
699
|
+
# is exploited in the computation of the `contact_status` variable.
|
700
|
+
sticking = jnp.logical_or(
|
701
|
+
δ <= 0, f_tangential.dot(f_tangential) <= (μ * force_normal_mag) ** 2
|
702
|
+
)
|
703
|
+
|
704
|
+
# Compute the direction of the tangential force.
|
705
|
+
# To prevent dividing by zero, we use a switch statement.
|
706
|
+
norm = jaxsim.math.safe_norm(f_tangential)
|
707
|
+
f_tangential_direction = f_tangential / (
|
708
|
+
norm + jnp.finfo(float).eps * (norm == 0)
|
709
|
+
)
|
710
|
+
|
711
|
+
# Project the tangential force to the friction cone if slipping.
|
712
|
+
f_tangential = jnp.where(
|
713
|
+
sticking,
|
714
|
+
f_tangential,
|
715
|
+
jnp.minimum(μ * force_normal_mag, norm) * f_tangential_direction,
|
716
|
+
)
|
717
|
+
|
718
|
+
# Set the tangential force to zero if there is no contact.
|
719
|
+
f_tangential = jnp.where(δ <= 0, jnp.zeros(3), f_tangential)
|
720
|
+
|
721
|
+
# =====================================
|
722
|
+
# Compute the material deformation rate
|
723
|
+
# =====================================
|
724
|
+
|
725
|
+
# Compute the derivative of the material deformation.
|
726
|
+
# Note that we included an additional relaxation of `m_normal` in the
|
727
|
+
# sticking case, so that the normal deformation that could have accumulated
|
728
|
+
# from a previous slipping phase can relax to zero.
|
729
|
+
ṁ_no_contact = -(K / D) * m
|
730
|
+
ṁ_sticking = v_tangential - (K / D) * m_normal
|
731
|
+
ṁ_slipping = -(f_tangential + (K * δp) * m_tangential) / (D * δq)
|
732
|
+
|
733
|
+
# Compute the contact status:
|
734
|
+
# 0: slipping
|
735
|
+
# 1: sticking
|
736
|
+
# 2: no contact
|
737
|
+
contact_status = sticking.astype(int)
|
738
|
+
contact_status += (δ <= 0).astype(int)
|
739
|
+
|
740
|
+
# Select the right material deformation rate depending on the contact status.
|
741
|
+
ṁ = jax.lax.select_n(contact_status, ṁ_slipping, ṁ_sticking, ṁ_no_contact)
|
742
|
+
|
743
|
+
# ==========================================
|
744
|
+
# Compute and return the final contact force
|
745
|
+
# ==========================================
|
746
|
+
|
747
|
+
# Sum the normal and tangential forces.
|
748
|
+
CW_fl = f_normal + f_tangential
|
749
|
+
|
750
|
+
return CW_fl, ṁ
|
jaxsim/rbda/crba.py
CHANGED
@@ -30,13 +30,16 @@ def crba(model: js.model.JaxSimModel, *, joint_positions: jtp.Vector) -> jtp.Mat
|
|
30
30
|
# Note: λ(0) must not be used, it's initialized to -1.
|
31
31
|
λ = model.kin_dyn_parameters.parent_array
|
32
32
|
|
33
|
-
# Compute the parent-to-child adjoints
|
33
|
+
# Compute the parent-to-child adjoints of the joints.
|
34
34
|
# These transforms define the relative kinematics of the entire model, including
|
35
35
|
# the base transform for both floating-base and fixed-base models.
|
36
|
-
i_X_λi
|
36
|
+
i_X_λi = model.kin_dyn_parameters.joint_transforms(
|
37
37
|
joint_positions=s, base_transform=jnp.eye(4)
|
38
38
|
)
|
39
39
|
|
40
|
+
# Extract the joint motion subspaces.
|
41
|
+
S = model.kin_dyn_parameters.motion_subspaces
|
42
|
+
|
40
43
|
# Allocate the buffer of transforms link -> base.
|
41
44
|
i_X_0 = jnp.zeros(shape=(model.number_of_links(), 6, 6))
|
42
45
|
i_X_0 = i_X_0.at[0].set(jnp.eye(6))
|