jaxsim 0.6.2.dev182__py3-none-any.whl → 0.6.2.dev225__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jaxsim/_version.py +2 -2
- jaxsim/api/__init__.py +0 -1
- jaxsim/api/com.py +1 -3
- jaxsim/api/common.py +26 -38
- jaxsim/api/contact.py +140 -24
- jaxsim/api/data.py +96 -33
- jaxsim/api/integrators.py +18 -11
- jaxsim/api/model.py +25 -43
- jaxsim/api/ode.py +28 -6
- jaxsim/api/references.py +9 -16
- jaxsim/math/__init__.py +1 -1
- jaxsim/math/adjoint.py +2 -2
- jaxsim/math/transform.py +2 -2
- jaxsim/math/utils.py +3 -2
- jaxsim/mujoco/visualizer.py +1 -1
- jaxsim/parsers/kinematic_graph.py +1 -1
- jaxsim/rbda/__init__.py +1 -1
- jaxsim/rbda/contacts/__init__.py +6 -2
- jaxsim/rbda/contacts/common.py +114 -4
- jaxsim/rbda/contacts/relaxed_rigid.py +57 -177
- jaxsim/rbda/contacts/rigid.py +538 -0
- jaxsim/rbda/contacts/soft.py +448 -0
- jaxsim/rbda/forward_kinematics.py +0 -29
- jaxsim/rbda/utils.py +2 -2
- jaxsim/terrain/terrain.py +1 -1
- {jaxsim-0.6.2.dev182.dist-info → jaxsim-0.6.2.dev225.dist-info}/METADATA +3 -2
- {jaxsim-0.6.2.dev182.dist-info → jaxsim-0.6.2.dev225.dist-info}/RECORD +30 -29
- {jaxsim-0.6.2.dev182.dist-info → jaxsim-0.6.2.dev225.dist-info}/WHEEL +1 -1
- jaxsim/api/contact_model.py +0 -101
- {jaxsim-0.6.2.dev182.dist-info → jaxsim-0.6.2.dev225.dist-info/licenses}/LICENSE +0 -0
- {jaxsim-0.6.2.dev182.dist-info → jaxsim-0.6.2.dev225.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,538 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
3
|
+
import dataclasses
|
4
|
+
from typing import Any
|
5
|
+
|
6
|
+
import jax
|
7
|
+
import jax.numpy as jnp
|
8
|
+
import jax_dataclasses
|
9
|
+
import qpax
|
10
|
+
|
11
|
+
import jaxsim.api as js
|
12
|
+
import jaxsim.typing as jtp
|
13
|
+
from jaxsim import logging
|
14
|
+
from jaxsim.api.common import ModelDataWithVelocityRepresentation, VelRepr
|
15
|
+
|
16
|
+
from . import common
|
17
|
+
from .common import ContactModel, ContactsParams
|
18
|
+
|
19
|
+
try:
|
20
|
+
from typing import Self
|
21
|
+
except ImportError:
|
22
|
+
from typing_extensions import Self
|
23
|
+
|
24
|
+
|
25
|
+
@jax_dataclasses.pytree_dataclass
class RigidContactsParams(ContactsParams):
    """
    Parameters of the rigid contacts model.

    Attributes:
        mu: The static friction coefficient.
        K: The proportional gain of the Baumgarte stabilization term.
        D: The derivative gain of the Baumgarte stabilization term.
    """

    # Static friction coefficient
    mu: jtp.Float = dataclasses.field(
        default_factory=lambda: jnp.array(0.5, dtype=float)
    )

    # Baumgarte proportional term
    K: jtp.Float = dataclasses.field(
        default_factory=lambda: jnp.array(0.0, dtype=float)
    )

    # Baumgarte derivative term
    D: jtp.Float = dataclasses.field(
        default_factory=lambda: jnp.array(0.0, dtype=float)
    )

    def __hash__(self) -> int:
        """Hash the parameters from the content of their arrays."""

        from jaxsim.utils.wrappers import HashedNumpyArray

        return hash(
            (
                HashedNumpyArray.hash_of_array(self.mu),
                HashedNumpyArray.hash_of_array(self.K),
                HashedNumpyArray.hash_of_array(self.D),
            )
        )

    def __eq__(self, other: object) -> bool:
        """Compare two parameter sets through their content-based hash."""

        if not isinstance(other, RigidContactsParams):
            return False

        return hash(self) == hash(other)

    @classmethod
    def build(
        cls: type[Self],
        *,
        mu: jtp.FloatLike | None = None,
        K: jtp.FloatLike | None = None,
        D: jtp.FloatLike | None = None,
    ) -> Self:
        """
        Create a `RigidContactsParams` instance.

        Args:
            mu: The static friction coefficient.
            K: The proportional gain of the Baumgarte stabilization term.
            D: The derivative gain of the Baumgarte stabilization term.

        Returns:
            A `RigidContactsParams` instance. Every argument left to `None` is
            replaced by the default of the corresponding dataclass field, and
            all the values are stored as float arrays.
        """

        def _as_float_array(value: jtp.FloatLike | None, name: str) -> jtp.Float:
            # Fall back to the field's default factory when no value is given.
            if value is None:
                value = cls.__dataclass_fields__[name].default_factory()

            return jnp.array(value).astype(float)

        return cls(
            mu=_as_float_array(mu, "mu"),
            K=_as_float_array(K, "K"),
            D=_as_float_array(D, "D"),
        )

    def valid(self) -> jtp.BoolLike:
        """Check that all the parameters are non-negative."""
        return bool(
            jnp.all(self.mu >= 0.0)
            and jnp.all(self.K >= 0.0)
            and jnp.all(self.D >= 0.0)
        )
|
92
|
+
|
93
|
+
|
94
|
+
@jax_dataclasses.pytree_dataclass
class RigidContacts(ContactModel):
    """
    Rigid contacts model.

    The contact forces of the enabled collidable points are obtained by solving,
    with `qpax`, a QP whose cost is built from the regularized Delassus matrix.
    Its inequality constraints encode the friction pyramid and the unilaterality
    of the contact forces.
    """

    # Regularization added to the diagonal of the Delassus matrix.
    regularization_delassus: jax_dataclasses.Static[float] = dataclasses.field(
        default=1e-6, kw_only=True
    )

    # The QP solver options are stored as two static tuples (keys and values)
    # rather than a dict, because static fields must be hashable.
    _solver_options_keys: jax_dataclasses.Static[tuple[str, ...]] = dataclasses.field(
        default=("solver_tol",), kw_only=True
    )
    _solver_options_values: jax_dataclasses.Static[tuple[Any, ...]] = dataclasses.field(
        default=(1e-3,), kw_only=True
    )

    @property
    def solver_options(self) -> dict[str, Any]:
        """Get the solver options as a dictionary."""

        return dict(
            zip(
                self._solver_options_keys,
                self._solver_options_values,
                strict=True,
            )
        )

    @classmethod
    def build(
        cls: type[Self],
        regularization_delassus: jtp.FloatLike | None = None,
        solver_options: dict[str, Any] | None = None,
        **kwargs,
    ) -> Self:
        """
        Create a `RigidContacts` instance with specified parameters.

        Args:
            regularization_delassus:
                The regularization term to add to the diagonal of the Delassus matrix.
            solver_options:
                The options to pass to the QP solver. They are merged on top of
                the default solver options.
            **kwargs: Extra arguments which are ignored.

        Returns:
            The `RigidContacts` instance.

        Raises:
            ValueError: If the values of the solver options are not hashable.
        """

        if kwargs:
            logging.debug(msg=f"Ignoring extra arguments: {kwargs}")

        # Get the default solver options.
        default_solver_options = dict(
            zip(cls._solver_options_keys, cls._solver_options_values, strict=True)
        )

        # Create the solver options to set by combining the default solver options
        # with the user-provided solver options.
        solver_options = default_solver_options | (
            solver_options if solver_options is not None else {}
        )

        # Make sure that the solver options are hashable.
        # We need to check this because the solver options are static.
        try:
            hash(tuple(solver_options.values()))
        except TypeError as exc:
            raise ValueError(
                "The values of the solver options must be hashable."
            ) from exc

        # The extra kwargs are deliberately not forwarded: they are documented
        # (and logged) as ignored, and the dataclass has no field to receive them.
        return cls(
            regularization_delassus=float(
                regularization_delassus
                if regularization_delassus is not None
                else cls.__dataclass_fields__["regularization_delassus"].default
            ),
            _solver_options_keys=tuple(solver_options.keys()),
            _solver_options_values=tuple(solver_options.values()),
        )

    @staticmethod
    def compute_impact_velocity(
        inactive_collidable_points: jtp.ArrayLike,
        M: jtp.MatrixLike,
        J_WC: jtp.MatrixLike,
        generalized_velocity: jtp.VectorLike,
    ) -> jtp.Vector:
        """
        Return the new velocity of the system after a potential impact.

        The post-impact velocity ν⁺ is obtained by solving, in the least-squares
        sense, the linear system

            [ M   -Jₗᵀ ] [ν⁺]   [M ν⁻]
            [ Jₗ    0  ] [λ ] = [ 0  ]

        where Jₗ stacks the linear rows of the Jacobians of the active points.
        The rows of the inactive points are zeroed, so they do not constrain ν⁺.

        Args:
            inactive_collidable_points:
                Boolean mask of the collidable points, True for the inactive ones.
            M: The mass matrix of the system (in mixed representation).
            J_WC: The Jacobian matrix of the collidable points (in mixed representation).
            generalized_velocity: The generalized velocity of the system.

        Returns:
            The generalized velocity of the system after the impact.

        Note:
            The mass matrix `M`, the Jacobian `J_WC`, and the generalized velocity `generalized_velocity`
            must be expressed in the same velocity representation.
        """

        # Compute system velocity after impact maintaining zero linear velocity of active points.
        sl = jnp.s_[:, 0:3, :]
        Jl_WC = J_WC[sl]

        # Zero out the jacobian rows of inactive points.
        Jl_WC = jnp.vstack(
            jnp.where(
                inactive_collidable_points[:, jnp.newaxis, jnp.newaxis],
                jnp.zeros_like(Jl_WC),
                Jl_WC,
            )
        )

        A = jnp.vstack(
            [
                jnp.hstack([M, -Jl_WC.T]),
                jnp.hstack([Jl_WC, jnp.zeros((Jl_WC.shape[0], Jl_WC.shape[0]))]),
            ]
        )
        b = jnp.hstack([M @ generalized_velocity, jnp.zeros(Jl_WC.shape[0])])

        BW_ν_post_impact = jnp.linalg.lstsq(A, b)[0]

        # Keep only the velocity part, discarding the impulse multipliers.
        return BW_ν_post_impact[0 : M.shape[0]]

    @jax.jit
    @js.common.named_scope
    def compute_contact_forces(
        self,
        model: js.model.JaxSimModel,
        data: js.data.JaxSimModelData,
        *,
        link_forces: jtp.MatrixLike | None = None,
        joint_force_references: jtp.VectorLike | None = None,
    ) -> tuple[jtp.Matrix, dict[str, jtp.PyTree]]:
        """
        Compute the contact forces.

        Args:
            model: The model to consider.
            data: The data of the considered model.
            link_forces:
                Optional `(n_links, 6)` matrix of external forces acting on the links,
                expressed in the same representation of data.
            joint_force_references:
                Optional `(n_joints,)` vector of joint forces.

        Returns:
            A tuple containing as first element the computed contact forces of
            the enabled collidable points (one 6D force per point, in inertial-fixed
            representation), and as second element an empty dictionary.
        """

        # Get the indices of the enabled collidable points.
        indices_of_enabled_collidable_points = (
            model.kin_dyn_parameters.contact_parameters.indices_of_enabled_collidable_points
        )

        n_collidable_points = len(indices_of_enabled_collidable_points)

        link_forces = jnp.atleast_2d(
            jnp.array(link_forces, dtype=float).squeeze()
            if link_forces is not None
            else jnp.zeros((model.number_of_links(), 6))
        )

        joint_force_references = jnp.atleast_1d(
            jnp.array(joint_force_references, dtype=float).squeeze()
            if joint_force_references is not None
            else jnp.zeros((model.number_of_joints(),))
        )

        # Build a references object to simplify converting link forces.
        references = js.references.JaxSimModelReferences.build(
            model=model,
            data=data,
            velocity_representation=data.velocity_representation,
            link_forces=link_forces,
            joint_force_references=joint_force_references,
        )

        # Compute the position and linear velocities (mixed representation) of
        # all enabled collidable points belonging to the robot.
        position, velocity = js.contact.collidable_point_kinematics(
            model=model, data=data
        )

        # Compute the penetration depth and velocity of the collidable points.
        # Note that this function considers the penetration in the normal direction.
        δ, δ_dot, n̂ = jax.vmap(common.compute_penetration_data, in_axes=(0, 0, None))(
            position, velocity, model.terrain
        )

        # Points that do not penetrate the terrain are treated as inactive.
        inactive_collidable_points = δ <= 0

        W_H_C = js.contact.transforms(model=model, data=data)

        with (
            references.switch_velocity_representation(VelRepr.Mixed),
            data.switch_velocity_representation(VelRepr.Mixed),
        ):
            # Compute kin-dyn quantities used in the contact model.
            BW_ν = data.generalized_velocity

            M = js.model.free_floating_mass_matrix(model=model, data=data)

            J_WC = js.contact.jacobian(model=model, data=data)
            J̇_WC = js.contact.jacobian_derivative(model=model, data=data)

            # Compute the generalized free acceleration.
            BW_ν̇_free = jnp.hstack(
                js.model.forward_dynamics_aba(
                    model=model,
                    data=data,
                    link_forces=references.link_forces(model=model, data=data),
                    joint_forces=references.joint_force_references(model=model),
                )
            )

            # Compute the free linear acceleration of the collidable points.
            # Since we use doubly-mixed jacobian, this corresponds to W_p̈_C.
            free_contact_acc = _linear_acceleration_of_collidable_points(
                BW_nu=BW_ν,
                BW_nu_dot=BW_ν̇_free,
                CW_J_WC_BW=J_WC,
                CW_J_dot_WC_BW=J̇_WC,
            ).flatten()

            # Compute stabilization term.
            baumgarte_term = _compute_baumgarte_stabilization_term(
                inactive_collidable_points=inactive_collidable_points,
                δ=δ,
                δ_dot=δ_dot,
                n=n̂,
                K=model.contact_params.K,
                D=model.contact_params.D,
            ).flatten()

            # Compute the Delassus matrix.
            delassus_matrix = _delassus_matrix(M=M, J_WC=J_WC)

            # Initialize regularization term of the Delassus matrix for
            # better numerical conditioning.
            Iε = self.regularization_delassus * jnp.eye(delassus_matrix.shape[0])

            # Construct the quadratic cost function.
            Q = delassus_matrix + Iε
            q = free_contact_acc - baumgarte_term

            # Construct the inequality constraints.
            G = _compute_ineq_constraint_matrix(
                inactive_collidable_points=inactive_collidable_points,
                mu=model.contact_params.mu,
            )
            h_bounds = jnp.zeros(shape=(n_collidable_points * 6,))

            # Construct the (empty) equality constraints.
            A = jnp.zeros((0, 3 * n_collidable_points))
            b = jnp.zeros((0,))

            # Solve the following optimization problem with qpax:
            #
            # min_{x} 0.5 x⊤ Q x + q⊤ x
            #
            # s.t. A x = b
            #      G x ≤ h
            #
            # TODO: add possibility to notify if the QP problem did not converge.
            solution, _, _, _, converged, _ = qpax.solve_qp(  # noqa: F841
                Q=Q, q=q, A=A, b=b, G=G, h=h_bounds, **self.solver_options
            )

            # Reshape the optimized solution to be a matrix of 3D contact forces.
            CW_fl_C = solution.reshape(-1, 3)

            # Convert the contact forces from mixed to inertial-fixed representation.
            W_f_C = jax.vmap(
                lambda CW_fl_C, W_H_C: (
                    ModelDataWithVelocityRepresentation.other_representation_to_inertial(
                        array=jnp.zeros(6).at[0:3].set(CW_fl_C),
                        transform=W_H_C,
                        other_representation=VelRepr.Mixed,
                        is_force=True,
                    )
                ),
            )(CW_fl_C, W_H_C)

        return W_f_C, {}

    @jax.jit
    @js.common.named_scope
    def update_velocity_after_impact(
        self, model: js.model.JaxSimModel, data: js.data.JaxSimModelData
    ) -> js.data.JaxSimModelData:
        """
        Update the velocity after an impact.

        Args:
            model: The robot model considered by the contact model.
            data: The data of the considered model.

        Returns:
            The updated data of the considered model, with base and joint
            velocities replaced by their post-impact values.
        """

        # Positions of the enabled collidable points. As in `compute_contact_forces`,
        # the collidable-point kinematics and the contact Jacobian returned by
        # `js.contact` already refer only to the enabled collidable points, so
        # they must not be indexed again with the enabled indices (doing so
        # selects the wrong rows when some points are disabled).
        W_p_C, _ = js.contact.collidable_point_kinematics(model=model, data=data)

        # Compute the penetration depth of the collidable points.
        δ, *_ = jax.vmap(
            common.compute_penetration_data,
            in_axes=(0, 0, None),
        )(W_p_C, jnp.zeros_like(W_p_C), model.terrain)

        with data.switch_velocity_representation(VelRepr.Mixed):
            J_WC = js.contact.jacobian(model, data)
            M = js.model.free_floating_mass_matrix(model, data)
            BW_ν_pre_impact = data.generalized_velocity

            # Compute the impact velocity.
            # It may be discontinuous in case new contacts are made.
            BW_ν_post_impact = RigidContacts.compute_impact_velocity(
                generalized_velocity=BW_ν_pre_impact,
                inactive_collidable_points=(δ <= 0),
                M=M,
                J_WC=J_WC,
            )

            # Convert the base velocity from mixed to inertial-fixed: the
            # transform keeps the base position and uses an identity rotation.
            BW_ν_post_impact_inertial = data.other_representation_to_inertial(
                array=BW_ν_post_impact[0:6],
                other_representation=VelRepr.Mixed,
                transform=data._base_transform.at[0:3, 0:3].set(jnp.eye(3)),
                is_force=False,
            )

        # Reset the generalized velocity. This is done after leaving the context
        # manager so that the new data does not copy the temporarily switched
        # velocity representation.
        data = dataclasses.replace(
            data,
            _base_linear_velocity=BW_ν_post_impact_inertial[0:3],
            _base_angular_velocity=BW_ν_post_impact_inertial[3:6],
            _joint_velocities=BW_ν_post_impact[6:],
        )

        return data

    def update_contact_state(
        self, old_contact_state: dict[str, jtp.Array]
    ) -> dict[str, jtp.Array]:
        """
        Update the contact state.

        The rigid contacts model is stateless, so the returned state is always empty.

        Args:
            old_contact_state: The old contact state.

        Returns:
            The updated contact state.
        """

        return {}
|
458
|
+
|
459
|
+
|
460
|
+
@staticmethod
|
461
|
+
def _delassus_matrix(
|
462
|
+
M: jtp.MatrixLike,
|
463
|
+
J_WC: jtp.MatrixLike,
|
464
|
+
) -> jtp.Matrix:
|
465
|
+
|
466
|
+
sl = jnp.s_[:, 0:3, :]
|
467
|
+
J_WC_lin = jnp.vstack(J_WC[sl])
|
468
|
+
|
469
|
+
delassus_matrix = J_WC_lin @ jnp.linalg.pinv(M) @ J_WC_lin.T
|
470
|
+
return delassus_matrix
|
471
|
+
|
472
|
+
|
473
|
+
@jax.jit
@js.common.named_scope
def _compute_ineq_constraint_matrix(
    inactive_collidable_points: jtp.Vector, mu: jtp.FloatLike
) -> jtp.Matrix:
    """
    Compute the inequality constraint matrix G of the contact-force QP.

    For every collidable point, a 6x3 block acting on its 3D linear contact
    force f = (f_0, f_1, f_2) is built. The blocks are then arranged
    block-diagonally. With a zero right-hand side (G f ≤ 0, as used by
    `RigidContacts.compute_contact_forces`), the rows of each block encode:

    - Rows 0-3: the friction pyramid |f_0| ≤ μ f_2 and |f_1| ≤ μ f_2.
    - Row 4: the non-negativity of the vertical force, f_2 ≥ 0.
    - Row 5: for inactive points, f_2 ≤ 0, which together with row 4 forces
      f_2 = 0. For active points, this row is all zeros and always satisfied.

    Args:
        inactive_collidable_points:
            Boolean mask of the collidable points, True for the inactive ones.
        mu: The static friction coefficient.

    Returns:
        The block-diagonal `(6 n, 3 n)` constraint matrix, where n is the number
        of collidable points.
    """
    G_single_point = jnp.array(
        [
            [1, 0, -mu],
            [0, 1, -mu],
            [-1, 0, -mu],
            [0, -1, -mu],
            [0, 0, -1],
            [0, 0, 0],
        ]
    )
    G = jnp.tile(G_single_point, (len(inactive_collidable_points), 1, 1))

    # Activate the complementarity row only for the inactive points.
    G = G.at[:, 5, 2].set(inactive_collidable_points)

    G = jax.scipy.linalg.block_diag(*G)
    return G
|
500
|
+
|
501
|
+
|
502
|
+
@jax.jit
@js.common.named_scope
def _linear_acceleration_of_collidable_points(
    BW_nu: jtp.ArrayLike,
    BW_nu_dot: jtp.ArrayLike,
    CW_J_WC_BW: jtp.MatrixLike,
    CW_J_dot_WC_BW: jtp.MatrixLike,
) -> jtp.Matrix:
    """
    Compute the linear acceleration of the collidable points.

    The 6D acceleration of every point is J̇ ν + J ν̇, computed from the stacked
    per-point Jacobians and their time derivatives. Only its linear part is kept.

    Args:
        BW_nu: The generalized velocity of the system.
        BW_nu_dot: The generalized acceleration of the system.
        CW_J_WC_BW: The per-point Jacobians of the collidable points.
        CW_J_dot_WC_BW: The time derivatives of the per-point Jacobians.

    Returns:
        A matrix with one row per point holding its 3D linear acceleration.
        Size-one dimensions are squeezed. Since doubly-mixed Jacobians are used,
        this corresponds to W_p̈_C.
    """

    stacked_jacobian = jnp.vstack(CW_J_WC_BW)
    stacked_jacobian_dot = jnp.vstack(CW_J_dot_WC_BW)

    # One 6D acceleration per collidable point.
    acceleration_6d = (
        stacked_jacobian_dot @ BW_nu + stacked_jacobian @ BW_nu_dot
    ).reshape(-1, 6)

    linear_part = acceleration_6d[:, 0:3]
    return linear_part.squeeze()
|
521
|
+
|
522
|
+
|
523
|
+
@jax.jit
@js.common.named_scope
def _compute_baumgarte_stabilization_term(
    inactive_collidable_points: jtp.ArrayLike,
    δ: jtp.ArrayLike,
    δ_dot: jtp.ArrayLike,
    n: jtp.ArrayLike,
    K: jtp.FloatLike,
    D: jtp.FloatLike,
) -> jtp.Array:
    """
    Compute the Baumgarte stabilization term of the collidable points.

    For every active point, the term is (K δ + D δ̇) n. For every inactive point,
    it is a zero vector.

    Args:
        inactive_collidable_points:
            Boolean mask of the collidable points, True for the inactive ones.
        δ: The penetration depth of each point.
        δ_dot: The penetration velocity of each point.
        n: The normal direction associated to each point, one row per point.
        K: The proportional gain.
        D: The derivative gain.

    Returns:
        An array shaped like `n` with the stabilization term of each point.
    """

    # Scalar gain per point, broadcast along the normal direction.
    gain = K * δ + D * δ_dot
    stabilization = gain[:, jnp.newaxis] * n

    is_inactive = inactive_collidable_points[:, jnp.newaxis]
    return jnp.where(is_inactive, jnp.zeros_like(n), stabilization)
|