jaxsim 0.6.2.dev182__py3-none-any.whl → 0.6.2.dev225__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,538 @@
1
+ from __future__ import annotations
2
+
3
+ import dataclasses
4
+ from typing import Any
5
+
6
+ import jax
7
+ import jax.numpy as jnp
8
+ import jax_dataclasses
9
+ import qpax
10
+
11
+ import jaxsim.api as js
12
+ import jaxsim.typing as jtp
13
+ from jaxsim import logging
14
+ from jaxsim.api.common import ModelDataWithVelocityRepresentation, VelRepr
15
+
16
+ from . import common
17
+ from .common import ContactModel, ContactsParams
18
+
19
+ try:
20
+ from typing import Self
21
+ except ImportError:
22
+ from typing_extensions import Self
23
+
24
+
25
@jax_dataclasses.pytree_dataclass
class RigidContactsParams(ContactsParams):
    """
    Parameters of the rigid contacts model.

    Attributes:
        mu: The static friction coefficient.
        K: The proportional gain of the Baumgarte stabilization.
        D: The derivative gain of the Baumgarte stabilization.
    """

    # Static friction coefficient.
    mu: jtp.Float = dataclasses.field(
        default_factory=lambda: jnp.array(0.5, dtype=float)
    )

    # Proportional gain of the Baumgarte stabilization term.
    K: jtp.Float = dataclasses.field(
        default_factory=lambda: jnp.array(0.0, dtype=float)
    )

    # Derivative gain of the Baumgarte stabilization term.
    D: jtp.Float = dataclasses.field(
        default_factory=lambda: jnp.array(0.0, dtype=float)
    )

    def __hash__(self) -> int:
        from jaxsim.utils.wrappers import HashedNumpyArray

        # Hash the content of every parameter array, in declaration order.
        array_hashes = tuple(
            HashedNumpyArray.hash_of_array(value)
            for value in (self.mu, self.K, self.D)
        )
        return hash(array_hashes)

    def __eq__(self, other: RigidContactsParams) -> bool:
        # Two parameter sets are considered equal when their content hashes match.
        return isinstance(other, RigidContactsParams) and hash(self) == hash(other)

    @classmethod
    def build(
        cls: type[Self],
        *,
        mu: jtp.FloatLike | None = None,
        K: jtp.FloatLike | None = None,
        D: jtp.FloatLike | None = None,
    ) -> Self:
        """
        Create a `RigidContactParams` instance.

        Args:
            mu: The static friction coefficient, or ``None`` for the default.
            K: The Baumgarte proportional gain, or ``None`` for the default.
            D: The Baumgarte derivative gain, or ``None`` for the default.

        Returns:
            The parameters, each stored as a float JAX array.
        """

        field_specs = cls.__dataclass_fields__

        def _as_float_array(name: str, value: jtp.FloatLike | None) -> jtp.Array:
            # Missing values fall back to the default factory of the field.
            raw = field_specs[name].default_factory() if value is None else value
            return jnp.array(raw).astype(float)

        return cls(
            mu=_as_float_array("mu", mu),
            K=_as_float_array("K", K),
            D=_as_float_array("D", D),
        )

    def valid(self) -> jtp.BoolLike:
        """
        Check if the parameters are valid.

        Returns:
            ``True`` when all entries of ``mu``, ``K`` and ``D`` are non-negative.
        """
        return all(jnp.all(value >= 0.0) for value in (self.mu, self.K, self.D))
92
+
93
+
94
@jax_dataclasses.pytree_dataclass
class RigidContacts(ContactModel):
    """
    Rigid contacts model.

    The contact forces are obtained by solving a quadratic program with
    ``qpax``. The cost is built from the regularized Delassus matrix of the
    collidable points. The inequality constraints enforce a linearized friction
    pyramid, a non-negative vertical force, and a zero force on the points that
    are not penetrating the terrain.
    """

    # Regularization added to the diagonal of the Delassus matrix to improve
    # the numerical conditioning of the QP.
    regularization_delassus: jax_dataclasses.Static[float] = dataclasses.field(
        default=1e-6, kw_only=True
    )

    # The QP solver options are stored as two parallel static tuples (keys and
    # values) instead of a dict, because static fields must be hashable.
    _solver_options_keys: jax_dataclasses.Static[tuple[str, ...]] = dataclasses.field(
        default=("solver_tol",), kw_only=True
    )
    _solver_options_values: jax_dataclasses.Static[tuple[Any, ...]] = dataclasses.field(
        default=(1e-3,), kw_only=True
    )

    @property
    def solver_options(self) -> dict[str, Any]:
        """Get the solver options as a dictionary."""

        return dict(
            zip(
                self._solver_options_keys,
                self._solver_options_values,
                strict=True,
            )
        )

    @classmethod
    def build(
        cls: type[Self],
        regularization_delassus: jtp.FloatLike | None = None,
        solver_options: dict[str, Any] | None = None,
        **kwargs,
    ) -> Self:
        """
        Create a `RigidContacts` instance with specified parameters.

        Args:
            regularization_delassus:
                The regularization term to add to the diagonal of the Delassus
                matrix. If ``None``, the class default is used.
            solver_options:
                The options to pass to the QP solver. They are merged on top of
                the default options, so only the overridden keys are needed.
            **kwargs: Extra arguments which are ignored.

        Returns:
            The `RigidContacts` instance.

        Raises:
            ValueError: If any of the solver option values is not hashable.
        """

        if kwargs:
            logging.debug(msg=f"Ignoring extra arguments: {kwargs}")

        # Get the default solver options.
        default_solver_options = dict(
            zip(cls._solver_options_keys, cls._solver_options_values, strict=True)
        )

        # Combine the default solver options with the user-provided ones,
        # the latter taking precedence.
        solver_options = default_solver_options | (
            solver_options if solver_options is not None else {}
        )

        # The solver options are stored in static fields, so their values
        # must be hashable.
        try:
            hash(tuple(solver_options.values()))
        except TypeError as exc:
            raise ValueError(
                "The values of the solver options must be hashable."
            ) from exc

        if regularization_delassus is None:
            regularization_delassus = cls.__dataclass_fields__[
                "regularization_delassus"
            ].default

        # The extra arguments are documented and logged as ignored, so they are
        # not forwarded to the constructor. Forwarding them would raise a
        # `TypeError` for any name that is not a field of the class.
        return cls(
            regularization_delassus=float(regularization_delassus),
            _solver_options_keys=tuple(solver_options.keys()),
            _solver_options_values=tuple(solver_options.values()),
        )

    @staticmethod
    def compute_impact_velocity(
        inactive_collidable_points: jtp.ArrayLike,
        M: jtp.MatrixLike,
        J_WC: jtp.MatrixLike,
        generalized_velocity: jtp.VectorLike,
    ) -> jtp.Vector:
        """
        Return the new velocity of the system after a potential impact.

        Args:
            inactive_collidable_points:
                Boolean mask with one entry per collidable point, ``True`` for
                the points that are not in contact. Their Jacobian rows are
                zeroed so that they do not constrain the post-impact velocity.
            M: The mass matrix of the system (in mixed representation).
            J_WC: The Jacobian matrix of the collidable points (in mixed representation).
            generalized_velocity: The generalized velocity of the system.

        Returns:
            The generalized velocity after the impact. It keeps the linear
            velocity of the active points at zero.

        Note:
            The mass matrix `M`, the Jacobian `J_WC`, and the generalized velocity `generalized_velocity`
            must be expressed in the same velocity representation.
        """

        # Keep only the linear rows of each point Jacobian.
        sl = jnp.s_[:, 0:3, :]
        Jl_WC = J_WC[sl]

        # Zero out the jacobian rows of inactive points.
        Jl_WC = jnp.vstack(
            jnp.where(
                inactive_collidable_points[:, jnp.newaxis, jnp.newaxis],
                jnp.zeros_like(Jl_WC),
                Jl_WC,
            )
        )

        # KKT system [M, -Jᵀ; J, 0] [ν⁺; λ] = [M ν⁻; 0]. The zeroed rows make it
        # singular, which is why it is solved in the least-squares sense.
        A = jnp.vstack(
            [
                jnp.hstack([M, -Jl_WC.T]),
                jnp.hstack([Jl_WC, jnp.zeros((Jl_WC.shape[0], Jl_WC.shape[0]))]),
            ]
        )
        b = jnp.hstack([M @ generalized_velocity, jnp.zeros(Jl_WC.shape[0])])

        BW_ν_post_impact = jnp.linalg.lstsq(A, b)[0]

        # Discard the multipliers and keep only the generalized velocity.
        return BW_ν_post_impact[0 : M.shape[0]]

    @jax.jit
    @js.common.named_scope
    def compute_contact_forces(
        self,
        model: js.model.JaxSimModel,
        data: js.data.JaxSimModelData,
        *,
        link_forces: jtp.MatrixLike | None = None,
        joint_force_references: jtp.VectorLike | None = None,
    ) -> tuple[jtp.Matrix, dict[str, jtp.PyTree]]:
        """
        Compute the contact forces.

        Args:
            model: The model to consider.
            data: The data of the considered model.
            link_forces:
                Optional `(n_links, 6)` matrix of external forces acting on the links,
                expressed in the same representation of data.
            joint_force_references:
                Optional `(n_joints,)` vector of joint forces.

        Returns:
            A tuple with two elements. The first is the contact forces of the
            collidable points, expressed in inertial-fixed representation. The
            second is an empty dictionary of auxiliary data.
        """

        # Get the indices of the enabled collidable points.
        indices_of_enabled_collidable_points = (
            model.kin_dyn_parameters.contact_parameters.indices_of_enabled_collidable_points
        )

        n_collidable_points = len(indices_of_enabled_collidable_points)

        link_forces = jnp.atleast_2d(
            jnp.array(link_forces, dtype=float).squeeze()
            if link_forces is not None
            else jnp.zeros((model.number_of_links(), 6))
        )

        joint_force_references = jnp.atleast_1d(
            jnp.array(joint_force_references, dtype=float).squeeze()
            if joint_force_references is not None
            else jnp.zeros((model.number_of_joints(),))
        )

        # Build a references object to simplify converting link forces.
        references = js.references.JaxSimModelReferences.build(
            model=model,
            data=data,
            velocity_representation=data.velocity_representation,
            link_forces=link_forces,
            joint_force_references=joint_force_references,
        )

        # Compute the position and linear velocities (mixed representation) of
        # all enabled collidable points belonging to the robot.
        position, velocity = js.contact.collidable_point_kinematics(
            model=model, data=data
        )

        # Compute the penetration depth and velocity of the collidable points.
        # Note that this function considers the penetration in the normal direction.
        δ, δ_dot, n̂ = jax.vmap(common.compute_penetration_data, in_axes=(0, 0, None))(
            position, velocity, model.terrain
        )

        # Points that do not penetrate the terrain must not exchange forces.
        inactive_collidable_points = δ <= 0

        W_H_C = js.contact.transforms(model=model, data=data)

        with (
            references.switch_velocity_representation(VelRepr.Mixed),
            data.switch_velocity_representation(VelRepr.Mixed),
        ):
            # Compute kin-dyn quantities used in the contact model.
            BW_ν = data.generalized_velocity

            M = js.model.free_floating_mass_matrix(model=model, data=data)

            # NOTE(review): `update_velocity_after_impact` indexes the Jacobian
            # with `indices_of_enabled_collidable_points`, while this method does
            # not. Confirm that `js.contact.jacobian` already returns only the
            # enabled points; otherwise Q and G would have mismatching sizes.
            J_WC = js.contact.jacobian(model=model, data=data)
            J̇_WC = js.contact.jacobian_derivative(model=model, data=data)

            # Compute the generalized free acceleration.
            BW_ν̇_free = jnp.hstack(
                js.model.forward_dynamics_aba(
                    model=model,
                    data=data,
                    link_forces=references.link_forces(model=model, data=data),
                    joint_forces=references.joint_force_references(model=model),
                )
            )

        # Compute the free linear acceleration of the collidable points.
        # Since we use doubly-mixed jacobian, this corresponds to W_p̈_C.
        free_contact_acc = _linear_acceleration_of_collidable_points(
            BW_nu=BW_ν,
            BW_nu_dot=BW_ν̇_free,
            CW_J_WC_BW=J_WC,
            CW_J_dot_WC_BW=J̇_WC,
        ).flatten()

        # Compute stabilization term.
        baumgarte_term = _compute_baumgarte_stabilization_term(
            inactive_collidable_points=inactive_collidable_points,
            δ=δ,
            δ_dot=δ_dot,
            n=n̂,
            K=model.contact_params.K,
            D=model.contact_params.D,
        ).flatten()

        # Compute the Delassus matrix.
        delassus_matrix = _delassus_matrix(M=M, J_WC=J_WC)

        # Initialize regularization term of the Delassus matrix for
        # better numerical conditioning.
        Iε = self.regularization_delassus * jnp.eye(delassus_matrix.shape[0])

        # Construct the quadratic cost function.
        Q = delassus_matrix + Iε
        q = free_contact_acc - baumgarte_term

        # Construct the inequality constraints (6 rows per collidable point).
        G = _compute_ineq_constraint_matrix(
            inactive_collidable_points=inactive_collidable_points,
            mu=model.contact_params.mu,
        )
        h_bounds = jnp.zeros(shape=(n_collidable_points * 6,))

        # There are no equality constraints.
        A = jnp.zeros((0, 3 * n_collidable_points))
        b = jnp.zeros((0,))

        # Solve the following optimization problem with qpax:
        #
        # min_{x} 0.5 x⊤ Q x + q⊤ x
        #
        # s.t. A x = b
        #      G x ≤ h
        #
        # TODO: add possibility to notify if the QP problem did not converge.
        solution, _, _, _, converged, _ = qpax.solve_qp(  # noqa: F841
            Q=Q, q=q, A=A, b=b, G=G, h=h_bounds, **self.solver_options
        )

        # Reshape the optimized solution to be a matrix of 3D contact forces.
        CW_fl_C = solution.reshape(-1, 3)

        # Convert the contact forces from mixed to inertial-fixed representation.
        W_f_C = jax.vmap(
            lambda CW_fl_C, W_H_C: (
                ModelDataWithVelocityRepresentation.other_representation_to_inertial(
                    array=jnp.zeros(6).at[0:3].set(CW_fl_C),
                    transform=W_H_C,
                    other_representation=VelRepr.Mixed,
                    is_force=True,
                )
            ),
        )(CW_fl_C, W_H_C)

        return W_f_C, {}

    @jax.jit
    @js.common.named_scope
    def update_velocity_after_impact(
        self, model: js.model.JaxSimModel, data: js.data.JaxSimModelData
    ) -> js.data.JaxSimModelData:
        """
        Update the velocity after an impact.

        Args:
            model: The robot model considered by the contact model.
            data: The data of the considered model.

        Returns:
            The updated data of the considered model.
        """

        # Extract the indices corresponding to the enabled collidable points.
        indices_of_enabled_collidable_points = (
            model.kin_dyn_parameters.contact_parameters.indices_of_enabled_collidable_points
        )

        W_p_C = js.contact.collidable_point_positions(model, data)[
            indices_of_enabled_collidable_points
        ]

        # Compute the penetration depth of the collidable points.
        # The velocity is irrelevant here, so zeros are passed.
        δ, *_ = jax.vmap(
            common.compute_penetration_data,
            in_axes=(0, 0, None),
        )(W_p_C, jnp.zeros_like(W_p_C), model.terrain)

        with data.switch_velocity_representation(VelRepr.Mixed):
            J_WC = js.contact.jacobian(model, data)[
                indices_of_enabled_collidable_points
            ]
            M = js.model.free_floating_mass_matrix(model, data)
            BW_ν_pre_impact = data.generalized_velocity

            # Compute the impact velocity.
            # It may be discontinuous in case new contacts are made.
            BW_ν_post_impact = RigidContacts.compute_impact_velocity(
                generalized_velocity=BW_ν_pre_impact,
                inactive_collidable_points=(δ <= 0),
                M=M,
                J_WC=J_WC,
            )

            # Convert the base velocity from mixed to inertial-fixed. The mixed
            # frame has the base origin and the inertial orientation, hence the
            # identity rotation block.
            BW_ν_post_impact_inertial = data.other_representation_to_inertial(
                array=BW_ν_post_impact[0:6],
                other_representation=VelRepr.Mixed,
                transform=data._base_transform.at[0:3, 0:3].set(jnp.eye(3)),
                is_force=False,
            )

            # Reset the generalized velocity.
            data = dataclasses.replace(
                data,
                _base_linear_velocity=BW_ν_post_impact_inertial[0:3],
                _base_angular_velocity=BW_ν_post_impact_inertial[3:6],
                _joint_velocities=BW_ν_post_impact[6:],
            )

        return data

    def update_contact_state(
        self, old_contact_state: dict[str, jtp.Array]
    ) -> dict[str, jtp.Array]:
        """
        Update the contact state.

        The rigid contacts model has no internal contact state, so the input is
        unused and an empty dictionary is returned.

        Args:
            old_contact_state: The old contact state.

        Returns:
            The updated contact state.
        """

        return {}
458
+
459
+
460
+ @staticmethod
461
+ def _delassus_matrix(
462
+ M: jtp.MatrixLike,
463
+ J_WC: jtp.MatrixLike,
464
+ ) -> jtp.Matrix:
465
+
466
+ sl = jnp.s_[:, 0:3, :]
467
+ J_WC_lin = jnp.vstack(J_WC[sl])
468
+
469
+ delassus_matrix = J_WC_lin @ jnp.linalg.pinv(M) @ J_WC_lin.T
470
+ return delassus_matrix
471
+
472
+
473
@jax.jit
@js.common.named_scope
def _compute_ineq_constraint_matrix(
    inactive_collidable_points: jtp.Vector, mu: jtp.FloatLike
) -> jtp.Matrix:
    """
    Compute the inequality constraint matrix of all the collidable points.

    Each collidable point contributes a ``(6, 3)`` block acting on its 3D force
    ``f = (f_x, f_y, f_z)``. The block is meant to be used as ``G f ≤ 0``:

    - Rows 0-3: linearized friction pyramid, ``±f_x ≤ μ f_z`` and ``±f_y ≤ μ f_z``.
    - Row 4: non-negativity of the vertical force, ``-f_z ≤ 0``.
    - Row 5: contact complementarity. For inactive points it reads
      ``f_z ≤ 0``, which together with row 4 forces ``f_z = 0``. Through the
      friction pyramid it then forces a zero force. For active points the row
      is all zeros.

    Args:
        inactive_collidable_points:
            Boolean vector with one entry per collidable point, ``True`` for the
            points that are not in contact.
        mu: The static friction coefficient.

    Returns:
        The block-diagonal ``(6 * n_points, 3 * n_points)`` constraint matrix.
    """
    G_single_point = jnp.array(
        [
            [1, 0, -mu],
            [0, 1, -mu],
            [-1, 0, -mu],
            [0, -1, -mu],
            [0, 0, -1],
            [0, 0, 0],
        ]
    )
    G = jnp.tile(G_single_point, (len(inactive_collidable_points), 1, 1))

    # Activate the complementarity row only for the inactive points.
    G = G.at[:, 5, 2].set(inactive_collidable_points)

    G = jax.scipy.linalg.block_diag(*G)
    return G
500
+
501
+
502
@jax.jit
@js.common.named_scope
def _linear_acceleration_of_collidable_points(
    BW_nu: jtp.ArrayLike,
    BW_nu_dot: jtp.ArrayLike,
    CW_J_WC_BW: jtp.MatrixLike,
    CW_J_dot_WC_BW: jtp.MatrixLike,
) -> jtp.Matrix:
    """
    Compute the linear acceleration of the collidable points.

    Args:
        BW_nu: The generalized velocity.
        BW_nu_dot: The generalized acceleration.
        CW_J_WC_BW: The per-point Jacobians, stacked along the first axis.
        CW_J_dot_WC_BW: The per-point Jacobian derivatives, stacked likewise.

    Returns:
        The first three (linear) components of each point's acceleration, one
        row per point. The result is squeezed, so a single point yields a 1D
        array.
    """

    # Flatten the per-point Jacobians into tall 2D matrices.
    jacobian_stack = jnp.vstack(CW_J_WC_BW)
    jacobian_dot_stack = jnp.vstack(CW_J_dot_WC_BW)

    # a = J̇ ν + J ν̇. With doubly-mixed Jacobians this corresponds to W_p̈_C.
    stacked_acceleration = jacobian_dot_stack @ BW_nu + jacobian_stack @ BW_nu_dot

    # One 6D row per point, of which only the linear part is kept.
    per_point_acceleration = stacked_acceleration.reshape(-1, 6)
    return per_point_acceleration[:, 0:3].squeeze()
521
+
522
+
523
@jax.jit
@js.common.named_scope
def _compute_baumgarte_stabilization_term(
    inactive_collidable_points: jtp.ArrayLike,
    δ: jtp.ArrayLike,
    δ_dot: jtp.ArrayLike,
    n: jtp.ArrayLike,
    K: jtp.FloatLike,
    D: jtp.FloatLike,
) -> jtp.Array:
    """
    Compute the Baumgarte stabilization term of the collidable points.

    Args:
        inactive_collidable_points:
            Boolean mask, ``True`` for the points that are not in contact.
        δ: The penetration depth of each point.
        δ_dot: The penetration velocity of each point.
        n: The per-point direction along which the correction is applied.
        K: The proportional gain.
        D: The derivative gain.

    Returns:
        One vector ``(K δ + D δ̇) n`` per point, replaced by zeros for the
        inactive points.
    """

    # Scalar correction of each point, scaled along its direction vector.
    correction = (K * δ + D * δ_dot)[:, jnp.newaxis] * n

    # Inactive points get no stabilization.
    mask = inactive_collidable_points[:, jnp.newaxis]
    return jnp.where(mask, jnp.zeros_like(n), correction)