jaxsim 0.4.3.dev159__py3-none-any.whl → 0.4.3.dev169__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,1053 @@
1
+ from __future__ import annotations
2
+
3
+ import dataclasses
4
+ import functools
5
+ from typing import Any
6
+
7
+ import jax
8
+ import jax.numpy as jnp
9
+ import jax_dataclasses
10
+
11
+ import jaxsim
12
+ import jaxsim.api as js
13
+ import jaxsim.exceptions
14
+ import jaxsim.typing as jtp
15
+ from jaxsim import logging
16
+ from jaxsim.math import StandardGravity
17
+ from jaxsim.terrain import FlatTerrain, Terrain
18
+
19
+ from . import common
20
+ from .soft import SoftContacts, SoftContactsParams
21
+
22
+ try:
23
+ from typing import Self
24
+ except ImportError:
25
+ from typing_extensions import Self
26
+
27
+
28
+ @jax_dataclasses.pytree_dataclass
29
+ class ViscoElasticContactsParams(common.ContactsParams):
30
+ """Parameters of the visco-elastic contacts model."""
31
+
32
+ K: jtp.Float = dataclasses.field(
33
+ default_factory=lambda: jnp.array(1e6, dtype=float)
34
+ )
35
+
36
+ D: jtp.Float = dataclasses.field(
37
+ default_factory=lambda: jnp.array(2000, dtype=float)
38
+ )
39
+
40
+ static_friction: jtp.Float = dataclasses.field(
41
+ default_factory=lambda: jnp.array(0.5, dtype=float)
42
+ )
43
+
44
+ p: jtp.Float = dataclasses.field(
45
+ default_factory=lambda: jnp.array(0.5, dtype=float)
46
+ )
47
+
48
+ q: jtp.Float = dataclasses.field(
49
+ default_factory=lambda: jnp.array(0.5, dtype=float)
50
+ )
51
+
52
+ @classmethod
53
+ def build(
54
+ cls: type[Self],
55
+ K: jtp.FloatLike = 1e6,
56
+ D: jtp.FloatLike = 2_000,
57
+ static_friction: jtp.FloatLike = 0.5,
58
+ p: jtp.FloatLike = 0.5,
59
+ q: jtp.FloatLike = 0.5,
60
+ ) -> Self:
61
+ """
62
+         Create a ViscoElasticContactsParams instance with specified parameters.
63
+
64
+ Args:
65
+ K: The stiffness parameter.
66
+             D: The damping parameter of the visco-elastic contacts model.
67
+ static_friction: The static friction coefficient.
68
+ p:
69
+ The exponent p corresponding to the damping-related non-linearity
70
+ of the Hunt/Crossley model.
71
+ q:
72
+ The exponent q corresponding to the spring-related non-linearity
73
+ of the Hunt/Crossley model.
74
+
75
+ Returns:
76
+             A ViscoElasticContactsParams instance with the specified parameters.
77
+ """
78
+
79
+ return ViscoElasticContactsParams(
80
+ K=jnp.array(K, dtype=float),
81
+ D=jnp.array(D, dtype=float),
82
+ static_friction=jnp.array(static_friction, dtype=float),
83
+ p=jnp.array(p, dtype=float),
84
+ q=jnp.array(q, dtype=float),
85
+ )
86
+
87
+ @classmethod
88
+ def build_default_from_jaxsim_model(
89
+ cls: type[Self],
90
+ model: js.model.JaxSimModel,
91
+ *,
92
+ standard_gravity: jtp.FloatLike = StandardGravity,
93
+ static_friction_coefficient: jtp.FloatLike = 0.5,
94
+ max_penetration: jtp.FloatLike = 0.001,
95
+ number_of_active_collidable_points_steady_state: jtp.IntLike = 1,
96
+ damping_ratio: jtp.FloatLike = 1.0,
97
+ p: jtp.FloatLike = 0.5,
98
+ q: jtp.FloatLike = 0.5,
99
+ ) -> Self:
100
+ """
101
+ Create a ViscoElasticContactsParams instance with good default parameters.
102
+
103
+ Args:
104
+ model: The target model.
105
+ standard_gravity: The standard gravity constant.
106
+ static_friction_coefficient:
107
+ The static friction coefficient between the model and the terrain.
108
+ max_penetration: The maximum penetration depth.
109
+ number_of_active_collidable_points_steady_state:
110
+ The number of contacts supporting the weight of the model
111
+ in steady state.
112
+ damping_ratio: The ratio controlling the damping behavior.
113
+ p:
114
+ The exponent p corresponding to the damping-related non-linearity
115
+ of the Hunt/Crossley model.
116
+ q:
117
+ The exponent q corresponding to the spring-related non-linearity
118
+ of the Hunt/Crossley model.
119
+
120
+ Returns:
121
+ A `ViscoElasticContactsParams` instance with the specified parameters.
122
+
123
+ Note:
124
+             The `damping_ratio` parameter (ξ) selects one of the following regimes:
125
+ - ξ > 1.0: over-damped
126
+ - ξ = 1.0: critically damped
127
+ - ξ < 1.0: under-damped
128
+ """
129
+
130
+ # Call the SoftContact builder instead of duplicating the logic.
131
+ soft_contacts_params = SoftContactsParams.build_default_from_jaxsim_model(
132
+ model=model,
133
+ standard_gravity=standard_gravity,
134
+ static_friction_coefficient=static_friction_coefficient,
135
+ max_penetration=max_penetration,
136
+ number_of_active_collidable_points_steady_state=number_of_active_collidable_points_steady_state,
137
+ damping_ratio=damping_ratio,
138
+ )
139
+
140
+ return ViscoElasticContactsParams.build(
141
+ K=soft_contacts_params.K,
142
+ D=soft_contacts_params.D,
143
+ static_friction=soft_contacts_params.mu,
144
+ p=p,
145
+ q=q,
146
+ )
147
+
148
+ def valid(self) -> jtp.BoolLike:
149
+ """
150
+ Check if the parameters are valid.
151
+
152
+ Returns:
153
+ `True` if the parameters are valid, `False` otherwise.
154
+ """
155
+
156
+ return (
157
+ jnp.all(self.K >= 0.0)
158
+ and jnp.all(self.D >= 0.0)
159
+ and jnp.all(self.static_friction >= 0.0)
160
+ and jnp.all(self.p >= 0.0)
161
+ and jnp.all(self.q >= 0.0)
162
+ )
163
+
164
+ def __hash__(self) -> int:
165
+
166
+ from jaxsim.utils.wrappers import HashedNumpyArray
167
+
168
+ return hash(
169
+ (
170
+ HashedNumpyArray.hash_of_array(self.K),
171
+ HashedNumpyArray.hash_of_array(self.D),
172
+ HashedNumpyArray.hash_of_array(self.static_friction),
173
+ HashedNumpyArray.hash_of_array(self.p),
174
+ HashedNumpyArray.hash_of_array(self.q),
175
+ )
176
+ )
177
+
178
+ def __eq__(self, other: ViscoElasticContactsParams) -> bool:
179
+
180
+ if not isinstance(other, ViscoElasticContactsParams):
181
+ return False
182
+
183
+ return hash(self) == hash(other)
184
+
185
+
186
+ @jax_dataclasses.pytree_dataclass
187
+ class ViscoElasticContacts(common.ContactModel):
188
+ """Visco-elastic contacts model."""
189
+
190
+ parameters: ViscoElasticContactsParams = dataclasses.field(
191
+ default_factory=ViscoElasticContactsParams
192
+ )
193
+
194
+ terrain: jax_dataclasses.Static[Terrain] = dataclasses.field(
195
+ default_factory=FlatTerrain
196
+ )
197
+
198
+ max_squarings: jax_dataclasses.Static[int] = 25
199
+
200
+     @classmethod
201
+     def build(
202
+         cls: type[Self],
203
+         parameters: ViscoElasticContactsParams | None = None,
204
+         terrain: Terrain | None = None,
205
+         model: js.model.JaxSimModel | None = None,
206
+         max_squarings: jtp.IntLike | None = None,
207
+         **kwargs,
208
+     ) -> Self:
209
+         """
210
+         Create a `ViscoElasticContacts` instance with specified parameters.
211
+
212
+         Args:
213
+             parameters: The parameters of the visco-elastic contacts model.
214
+             terrain: The considered terrain.
215
+             model:
216
+                 The robot model considered by the contact model.
217
+                 If passed, it is used to estimate good default parameters.
218
+             max_squarings:
219
+                 The maximum number of squarings performed in the matrix exponential.
220
+
221
+         Returns:
222
+             The `ViscoElasticContacts` instance.
223
+         """
224
+
225
+         if len(kwargs) != 0:
226
+             logging.debug(msg=f"Ignoring extra arguments: {kwargs}")
227
+
228
+         # Build the contact parameters if not provided. Use the model to estimate
229
+         # good default parameters, if passed. Users can later override these default
230
+         # parameters with their own values -- possibly tuned better.
231
+         if parameters is None:
232
+             parameters = (
233
+                 ViscoElasticContactsParams.build_default_from_jaxsim_model(model=model)
234
+                 if model is not None
235
+                 else cls.__dataclass_fields__["parameters"].default_factory()
236
+             )
237
+         # `max_squarings` has a plain (non-factory) default: read it, do not call it.
238
+         return ViscoElasticContacts(
239
+             parameters=parameters,
240
+             terrain=terrain or cls.__dataclass_fields__["terrain"].default_factory(),
241
+             max_squarings=int(
242
+                 max_squarings or cls.__dataclass_fields__["max_squarings"].default
243
+             ),
244
+         )
245
+
246
+ @classmethod
247
+ def zero_state_variables(cls, model: js.model.JaxSimModel) -> dict[str, jtp.Array]:
248
+ """
249
+ Build zero state variables of the contact model.
250
+ """
251
+
252
+ # Initialize the material deformation to zero.
253
+ tangential_deformation = jnp.zeros(
254
+ shape=(len(model.kin_dyn_parameters.contact_parameters.body), 3),
255
+ dtype=float,
256
+ )
257
+
258
+ return {"tangential_deformation": tangential_deformation}
259
+
260
+ @jax.jit
261
+ def compute_contact_forces(
262
+ self,
263
+ model: js.model.JaxSimModel,
264
+ data: js.data.JaxSimModelData,
265
+ *,
266
+ dt: jtp.FloatLike,
267
+ link_forces: jtp.MatrixLike | None = None,
268
+ joint_force_references: jtp.VectorLike | None = None,
269
+ ) -> tuple[jtp.Vector, tuple[Any, ...]]:
270
+ """
271
+ Compute the contact forces.
272
+
273
+ Args:
274
+ model: The robot model considered by the contact model.
275
+ data: The data of the considered model.
276
+ dt: The integration time step.
277
+ link_forces:
278
+ The 6D forces to apply to the links expressed in the frame corresponding
279
+ to the velocity representation of `data`.
280
+ joint_force_references: The joint force references to apply.
281
+
282
+ Note:
283
+             This contact model, unlike most other contact models, requires the
284
+ knowledge of the integration step. It is not straightforward to assess how
285
+ this contact model behaves when used with high-order Runge-Kutta schemes.
286
+ For the time being, it is recommended to use a simple forward Euler scheme.
287
+ The main benefit of this model is that the stiff contact dynamics is computed
288
+             separately from the rest of the system dynamics, which allows using simple
289
+ integration schemes without altering significantly the simulation stability.
290
+
291
+ Returns:
292
+ A tuple containing as first element the computed 6D contact force applied to
293
+ the contact point and expressed in the world frame, and as second element
294
+ a tuple of optional additional information.
295
+ """
296
+
297
+ # Initialize the model and data this contact model is operating on.
298
+ # This will raise an exception if either the contact model or the
299
+ # contact parameters are not compatible.
300
+ model, data = self.initialize_model_and_data(model=model, data=data)
301
+ assert isinstance(data.contacts_params, ViscoElasticContactsParams)
302
+
303
+ # Extract the indices corresponding to the enabled collidable points.
304
+ indices_of_enabled_collidable_points = (
305
+ model.kin_dyn_parameters.contact_parameters.indices_of_enabled_collidable_points
306
+ )
307
+
308
+ # Compute the average contact linear forces in mixed representation by
309
+ # integrating the contact dynamics in the continuous time domain.
310
+ CW_f̅l, CW_fl̿, m_tf = (
311
+ ViscoElasticContacts._compute_contact_forces_with_exponential_integration(
312
+ model=model,
313
+ data=data,
314
+ dt=dt,
315
+ joint_force_references=joint_force_references,
316
+ link_forces=link_forces,
317
+ indices_of_enabled_collidable_points=indices_of_enabled_collidable_points,
318
+ max_squarings=self.max_squarings,
319
+ )
320
+ )
321
+
322
+ # ============================================
323
+ # Compute the inertial-fixed 6D contact forces
324
+ # ============================================
325
+
326
+ # Compute the transforms of the mixed frames `C[W] = (W_p_C, [W])`
327
+ # associated to each collidable point.
328
+ W_H_C = js.contact.transforms(model=model, data=data)[
329
+ indices_of_enabled_collidable_points, :, :
330
+ ]
331
+
332
+ # Vmapped transformation from mixed to inertial-fixed representation.
333
+ compute_forces_inertial_fixed_vmap = jax.vmap(
334
+ lambda CW_fl_C, W_H_C: data.other_representation_to_inertial(
335
+ array=jnp.zeros(6).at[0:3].set(CW_fl_C),
336
+ other_representation=jaxsim.VelRepr.Mixed,
337
+ transform=W_H_C,
338
+ is_force=True,
339
+ )
340
+ )
341
+
342
+ # Express the linear contact forces in the inertial-fixed frame.
343
+ W_f̅_C, W_f̿_C = jax.vmap(
344
+ lambda CW_fl: compute_forces_inertial_fixed_vmap(CW_fl, W_H_C)
345
+ )(jnp.stack([CW_f̅l, CW_fl̿]))
346
+
347
+ return W_f̅_C, (W_f̿_C, m_tf)
348
+
349
+ @staticmethod
350
+ @functools.partial(jax.jit, static_argnames=("max_squarings",))
351
+ def _compute_contact_forces_with_exponential_integration(
352
+ model: js.model.JaxSimModel,
353
+ data: js.data.JaxSimModelData,
354
+ *,
355
+ dt: jtp.FloatLike,
356
+ link_forces: jtp.MatrixLike | None = None,
357
+ joint_force_references: jtp.VectorLike | None = None,
358
+ indices_of_enabled_collidable_points: jtp.VectorLike | None = None,
359
+ max_squarings: int = 25,
360
+ ) -> tuple[jtp.Matrix, jtp.Matrix, jtp.Matrix]:
361
+ """
362
+ Compute the average contact forces by integrating the contact dynamics.
363
+
364
+ Args:
365
+ model: The robot model considered by the contact model.
366
+ data: The data of the considered model.
367
+ dt: The integration time step.
368
+ link_forces: The 6D forces to apply to the links.
369
+ joint_force_references: The joint force references to apply.
370
+ indices_of_enabled_collidable_points:
371
+ The indices of the enabled collidable points.
372
+ max_squarings:
373
+ The maximum number of squarings performed in the matrix exponential.
374
+
375
+ Returns:
376
+ A tuple containing:
377
+ - The average contact forces.
378
+ - The average of the average contact forces.
379
+ - The tangential deformation at the final state.
380
+ """
381
+
382
+ # ==========================
383
+ # Populate missing arguments
384
+ # ==========================
385
+
386
+ indices = (
387
+ indices_of_enabled_collidable_points
388
+ if indices_of_enabled_collidable_points is not None
389
+ else jnp.arange(
390
+ len(model.kin_dyn_parameters.contact_parameters.body)
391
+ ).astype(int)
392
+ )
393
+
394
+ # ==================================
395
+ # Compute the contact point dynamics
396
+ # ==================================
397
+
398
+ p_t0 = js.contact.collidable_point_positions(model, data)[indices, :]
399
+ v_t0 = js.contact.collidable_point_velocities(model, data)[indices, :]
400
+ m_t0 = data.state.extended["tangential_deformation"][indices, :]
401
+
402
+ # Compute the linearized contact dynamics.
403
+ # Note that it linearizes the (non-linear) contact model at (p, v, m)[t0].
404
+ A, b, A_sc, b_sc = ViscoElasticContacts._contact_points_dynamics(
405
+ model=model,
406
+ data=data,
407
+ joint_force_references=joint_force_references,
408
+ link_forces=link_forces,
409
+ indices_of_enabled_collidable_points=indices,
410
+ p_t0=p_t0,
411
+ v_t0=v_t0,
412
+ m_t0=m_t0,
413
+ )
414
+
415
+ # =============================================
416
+ # Compute the integrals of the contact dynamics
417
+ # =============================================
418
+
419
+ # Pack the initial state of the contact points.
420
+ x_t0 = jnp.hstack([p_t0.flatten(), v_t0.flatten(), m_t0.flatten()])
421
+
422
+ # Pack the augmented matrix used to compute the single and double integral
423
+ # of the exponential integration.
424
+ A̅ = jnp.vstack(
425
+ [
426
+ jnp.hstack(
427
+ [
428
+ A,
429
+ jnp.vstack(b),
430
+ jnp.vstack(x_t0),
431
+ jnp.vstack(jnp.zeros_like(x_t0)),
432
+ ]
433
+ ),
434
+ jnp.hstack([jnp.zeros(A.shape[1]), 0, 1, 0]),
435
+ jnp.hstack([jnp.zeros(A.shape[1]), 0, 0, 1]),
436
+ jnp.hstack([jnp.zeros(A.shape[1]), 0, 0, 0]),
437
+ ]
438
+ )
439
+
440
+ # Compute the matrix exponential.
441
+ exp_tA = jax.scipy.linalg.expm(
442
+ (dt * A̅).astype(float), max_squarings=max_squarings
443
+ )
444
+
445
+ # Integrate the contact dynamics in the continuous time domain.
446
+ x_int, x_int2 = (
447
+ jnp.hstack([jnp.eye(A.shape[0]), jnp.zeros(shape=(A.shape[0], 3))])
448
+ @ exp_tA
449
+ @ jnp.vstack([jnp.zeros(shape=(A.shape[0] + 1, 2)), jnp.eye(2)])
450
+ ).T
451
+
452
+ jaxsim.exceptions.raise_runtime_error_if(
453
+ condition=jnp.isnan(x_int).any(),
454
+ msg="NaN integration, try to increase `max_squarings` or decreasing `dt`",
455
+ )
456
+
457
+ # ==========================
458
+ # Compute the contact forces
459
+ # ==========================
460
+
461
+ # Compute the average contact forces.
462
+ CW_f̅, _ = jnp.split(
463
+ (A_sc @ x_int / dt + b_sc).reshape(-1, 3),
464
+ indices_or_sections=2,
465
+ )
466
+
467
+ # Compute the average of the average contact forces.
468
+ CW_f̿, _ = jnp.split(
469
+ (A_sc @ x_int2 * 2 / (dt**2) + b_sc).reshape(-1, 3),
470
+ indices_or_sections=2,
471
+ )
472
+
473
+ # Extract the tangential deformation at the final state.
474
+ x_tf = x_int / dt
475
+ m_tf = jnp.split(x_tf, 3)[2].reshape(-1, 3)
476
+
477
+ return CW_f̅, CW_f̿, m_tf
478
+
479
+ @staticmethod
480
+ @jax.jit
481
+ def _contact_points_dynamics(
482
+ model: js.model.JaxSimModel,
483
+ data: js.data.JaxSimModelData,
484
+ *,
485
+ link_forces: jtp.MatrixLike | None = None,
486
+ joint_force_references: jtp.VectorLike | None = None,
487
+ indices_of_enabled_collidable_points: jtp.VectorLike | None = None,
488
+ p_t0: jtp.MatrixLike | None = None,
489
+ v_t0: jtp.MatrixLike | None = None,
490
+ m_t0: jtp.MatrixLike | None = None,
491
+ ) -> tuple[jtp.Matrix, jtp.Vector, jtp.Matrix, jtp.Vector]:
492
+ """
493
+ Compute the dynamics of the contact points.
494
+
495
+ Note:
496
+ This function projects the system dynamics to the contact space and
497
+ returns the matrices of a linear system to simulate its evolution.
498
+ Since the active contact model can be non-linear, this function also
499
+ linearizes the contact model at the initial state.
500
+
501
+ Args:
502
+ model: The robot model considered by the contact model.
503
+ data: The data of the considered model.
504
+ link_forces: The 6D forces to apply to the links.
505
+ joint_force_references: The joint force references to apply.
506
+ indices_of_enabled_collidable_points:
507
+ The indices of the enabled collidable points.
508
+ p_t0: The initial position of the collidable points.
509
+ v_t0: The initial velocity of the collidable points.
510
+ m_t0: The initial tangential deformation of the collidable points.
511
+
512
+ Returns:
513
+ A tuple containing:
514
+ - The `A` matrix of the linear system that models the contact dynamics.
515
+ - The `b` vector of the linear system that models the contact dynamics.
516
+ - The `A_sc` matrix of the linear system that approximates the contact model.
517
+ - The `b_sc` vector of the linear system that approximates the contact model.
518
+ """
519
+
520
+ indices_of_enabled_collidable_points = (
521
+ indices_of_enabled_collidable_points
522
+ if indices_of_enabled_collidable_points is not None
523
+ else jnp.arange(
524
+ len(model.kin_dyn_parameters.contact_parameters.body)
525
+ ).astype(int)
526
+ )
527
+
528
+ p_t0 = jnp.atleast_2d(
529
+ p_t0
530
+ if p_t0 is not None
531
+ else js.contact.collidable_point_positions(model=model, data=data)[
532
+ indices_of_enabled_collidable_points, :
533
+ ]
534
+ )
535
+
536
+ v_t0 = jnp.atleast_2d(
537
+ v_t0
538
+ if v_t0 is not None
539
+ else js.contact.collidable_point_velocities(model=model, data=data)[
540
+ indices_of_enabled_collidable_points, :
541
+ ]
542
+ )
543
+
544
+ m_t0 = jnp.atleast_2d(
545
+ m_t0
546
+ if m_t0 is not None
547
+ else data.state.extended["tangential_deformation"][
548
+ indices_of_enabled_collidable_points, :
549
+ ]
550
+ )
551
+
552
+ # We expect that the 6D forces of the `link_forces` argument are expressed
553
+ # in the frame corresponding to the velocity representation of `data`.
554
+ references = js.references.JaxSimModelReferences.build(
555
+ model=model,
556
+ link_forces=link_forces,
557
+ joint_force_references=joint_force_references,
558
+ data=data,
559
+ velocity_representation=data.velocity_representation,
560
+ )
561
+
562
+ # ===========================
563
+ # Linearize the contact model
564
+ # ===========================
565
+
566
+ # Linearize the contact model at the initial state of all considered
567
+ # contact points.
568
+ A_sc_points, b_sc_points = jax.vmap(
569
+ lambda p, v, m: ViscoElasticContacts._linearize_contact_model(
570
+ position=p,
571
+ velocity=v,
572
+ tangential_deformation=m,
573
+ parameters=data.contacts_params,
574
+ terrain=model.terrain,
575
+ )
576
+ )(p_t0, v_t0, m_t0)
577
+
578
+ # Since x = [p1, p2, ..., v1, v2, ..., m1, m2, ...], we need to split the A_sc of
579
+ # individual points since otherwise we'd get x = [ p1, v1, m1, p2, v2, m2, ...].
580
+ A_sc_p, A_sc_v, A_sc_m = jnp.split(A_sc_points, indices_or_sections=3, axis=-1)
581
+
582
+ # We want to have in output first the forces and then the material deformation rates.
583
+         # Therefore, we need to extract the components in A_sc_* separately.
584
+ A_sc = jnp.vstack(
585
+ [
586
+ jnp.hstack(
587
+ [
588
+ jax.scipy.linalg.block_diag(*A_sc_p[:, 0:3, :]),
589
+ jax.scipy.linalg.block_diag(*A_sc_v[:, 0:3, :]),
590
+ jax.scipy.linalg.block_diag(*A_sc_m[:, 0:3, :]),
591
+ ],
592
+ ),
593
+ jnp.hstack(
594
+ [
595
+ jax.scipy.linalg.block_diag(*A_sc_p[:, 3:6, :]),
596
+ jax.scipy.linalg.block_diag(*A_sc_v[:, 3:6, :]),
597
+ jax.scipy.linalg.block_diag(*A_sc_m[:, 3:6, :]),
598
+ ]
599
+ ),
600
+ ]
601
+ )
602
+
603
+ # We need to do the same for the b_sc.
604
+ b_sc = jnp.hstack(
605
+ [b_sc_points[:, 0:3].flatten(), b_sc_points[:, 3:6].flatten()]
606
+ )
607
+
608
+ # ===========================================================
609
+ # Compute the A and b matrices of the contact points dynamics
610
+ # ===========================================================
611
+
612
+ with data.switch_velocity_representation(jaxsim.VelRepr.Mixed):
613
+
614
+ BW_ν = data.generalized_velocity()
615
+
616
+ M = js.model.free_floating_mass_matrix(model=model, data=data)
617
+
618
+ CW_Jl_WC = js.contact.jacobian(
619
+ model=model,
620
+ data=data,
621
+ output_vel_repr=jaxsim.VelRepr.Mixed,
622
+ )[indices_of_enabled_collidable_points, 0:3, :]
623
+
624
+ CW_J̇l_WC = js.contact.jacobian_derivative(
625
+ model=model, data=data, output_vel_repr=jaxsim.VelRepr.Mixed
626
+ )[indices_of_enabled_collidable_points, 0:3, :]
627
+
628
+ # Compute the Delassus matrix.
629
+ ψ = jnp.vstack(CW_Jl_WC) @ jnp.linalg.lstsq(M, jnp.vstack(CW_Jl_WC).T)[0]
630
+
631
+ I_nc = jnp.eye(v_t0.flatten().size)
632
+ O_nc = jnp.zeros(shape=(p_t0.flatten().size, p_t0.flatten().size))
633
+
634
+ # Pack the A matrix.
635
+ A = jnp.vstack(
636
+ [
637
+ jnp.hstack([O_nc, I_nc, O_nc]),
638
+ ψ @ jnp.split(A_sc, 2, axis=0)[0],
639
+ jnp.split(A_sc, 2, axis=0)[1],
640
+ ]
641
+ )
642
+
643
+ # Short names for few variables.
644
+ ν = BW_ν
645
+ J = jnp.vstack(CW_Jl_WC)
646
+ J̇ = jnp.vstack(CW_J̇l_WC)
647
+
648
+ # Compute the free system acceleration components.
649
+ with (
650
+ data.switch_velocity_representation(jaxsim.VelRepr.Mixed),
651
+ references.switch_velocity_representation(jaxsim.VelRepr.Mixed),
652
+ ):
653
+
654
+ BW_v̇_free_WB, s̈_free = js.ode.system_acceleration(
655
+ model=model,
656
+ data=data,
657
+ joint_force_references=references.joint_force_references(model=model),
658
+ link_forces=references.link_forces(model=model, data=data),
659
+ )
660
+
661
+ # Pack the free system acceleration in mixed representation.
662
+ ν̇_free = jnp.hstack([BW_v̇_free_WB, s̈_free])
663
+
664
+ # Compute the acceleration of collidable points.
665
+ # This is the true derivative of ṗ only in mixed representation.
666
+ p̈ = J @ ν̇_free + J̇ @ ν
667
+
668
+ # Pack the b array.
669
+ b = jnp.hstack(
670
+ [
671
+ jnp.zeros_like(p_t0.flatten()),
672
+ p̈ + ψ @ jnp.split(b_sc, indices_or_sections=2)[0],
673
+ jnp.split(b_sc, indices_or_sections=2)[1],
674
+ ]
675
+ )
676
+
677
+ return A, b, A_sc, b_sc
678
+
679
+ @staticmethod
680
+ @functools.partial(jax.jit, static_argnames=("terrain",))
681
+ def _linearize_contact_model(
682
+ position: jtp.VectorLike,
683
+ velocity: jtp.VectorLike,
684
+ tangential_deformation: jtp.VectorLike,
685
+ parameters: ViscoElasticContactsParams,
686
+ terrain: Terrain,
687
+ ) -> tuple[jtp.Matrix, jtp.Vector]:
688
+ """"""
689
+
690
+ # Initialize the state at which the model is linearized.
691
+ p0 = jnp.array(position, dtype=float).squeeze()
692
+ v0 = jnp.array(velocity, dtype=float).squeeze()
693
+ m0 = jnp.array(tangential_deformation, dtype=float).squeeze()
694
+
695
+ # ============
696
+ # Compute A_sc
697
+ # ============
698
+
699
+ compute_contact_force_non_linear_model = functools.partial(
700
+ ViscoElasticContacts._compute_contact_force_non_linear_model,
701
+ parameters=parameters,
702
+ terrain=terrain,
703
+ )
704
+
705
+ # Compute with AD the functions to get the Jacobians of CW_fl.
706
+ df_dp_fun, df_dv_fun, df_dm_fun = (
707
+ jax.jacrev(
708
+ lambda p0, v0, m0: compute_contact_force_non_linear_model(
709
+ position=p0, velocity=v0, tangential_deformation=m0
710
+ )[0],
711
+ argnums=num,
712
+ )
713
+ for num in (0, 1, 2)
714
+ )
715
+
716
+ # Compute with AD the functions to get the Jacobians of ṁ.
717
+ dṁ_dp_fun, dṁ_dv_fun, dṁ_dm_fun = (
718
+ jax.jacrev(
719
+ lambda p0, v0, m0: compute_contact_force_non_linear_model(
720
+ position=p0, velocity=v0, tangential_deformation=m0
721
+ )[1],
722
+ argnums=num,
723
+ )
724
+ for num in (0, 1, 2)
725
+ )
726
+
727
+ # Compute the Jacobians of the contact forces w.r.t. the state.
728
+ df_dp = jnp.vstack(df_dp_fun(p0, v0, m0))
729
+ df_dv = jnp.vstack(df_dv_fun(p0, v0, m0))
730
+ df_dm = jnp.vstack(df_dm_fun(p0, v0, m0))
731
+
732
+ # Compute the Jacobians of the material deformation rate w.r.t. the state.
733
+ dṁ_dp = jnp.vstack(dṁ_dp_fun(p0, v0, m0))
734
+ dṁ_dv = jnp.vstack(dṁ_dv_fun(p0, v0, m0))
735
+ dṁ_dm = jnp.vstack(dṁ_dm_fun(p0, v0, m0))
736
+
737
+ # Pack the A matrix.
738
+ A_sc = jnp.vstack(
739
+ [
740
+ jnp.hstack([df_dp, df_dv, df_dm]),
741
+ jnp.hstack([dṁ_dp, dṁ_dv, dṁ_dm]),
742
+ ]
743
+ )
744
+
745
+ # ============
746
+ # Compute b_sc
747
+ # ============
748
+
749
+ # Compute the output of the non-linear model at the initial state.
750
+ x0 = jnp.hstack([p0, v0, m0])
751
+ f0, ṁ0 = compute_contact_force_non_linear_model(
752
+ position=p0, velocity=v0, tangential_deformation=m0
753
+ )
754
+
755
+ # Pack the b vector.
756
+ b_sc = jnp.hstack([f0, ṁ0]) - A_sc @ x0
757
+
758
+ return A_sc, b_sc
759
+
760
+ @staticmethod
761
+ @functools.partial(jax.jit, static_argnames=("terrain",))
762
+ def _compute_contact_force_non_linear_model(
763
+ position: jtp.VectorLike,
764
+ velocity: jtp.VectorLike,
765
+ tangential_deformation: jtp.VectorLike,
766
+ parameters: ViscoElasticContactsParams,
767
+ terrain: Terrain,
768
+ ) -> tuple[jtp.Vector, jtp.Vector]:
769
+ """
770
+ Compute the contact forces using the non-linear Hunt/Crossley model.
771
+
772
+ Args:
773
+ position: The position of the contact point.
774
+ velocity: The velocity of the contact point.
775
+ tangential_deformation: The tangential deformation of the contact point.
776
+ parameters: The parameters of the contact model.
777
+ terrain: The considered terrain.
778
+
779
+ Returns:
780
+ A tuple containing:
781
+ - The linear contact force in the mixed contact frame.
782
+ - The rate of material deformation.
783
+ """
784
+
785
+ # Compute the linear contact force in mixed representation using
786
+ # the non-linear Hunt/Crossley model.
787
+ # The following function also returns the rate of material deformation.
788
+ CW_fl, ṁ = SoftContacts.hunt_crossley_contact_model(
789
+ position=position,
790
+ velocity=velocity,
791
+ tangential_deformation=tangential_deformation,
792
+ terrain=terrain,
793
+ K=parameters.K,
794
+ D=parameters.D,
795
+ mu=parameters.static_friction,
796
+ p=parameters.p,
797
+ q=parameters.q,
798
+ )
799
+
800
+ return CW_fl, ṁ
801
+
802
+ @staticmethod
803
+ @jax.jit
804
+ def integrate_data_with_average_contact_forces(
805
+ model: js.model.JaxSimModel,
806
+ data: js.data.JaxSimModelData,
807
+ *,
808
+ dt: jtp.FloatLike,
809
+ link_forces: jtp.MatrixLike | None = None,
810
+ joint_force_references: jtp.VectorLike | None = None,
811
+ average_link_contact_forces_inertial: jtp.MatrixLike | None = None,
812
+ average_of_average_link_contact_forces_mixed: jtp.MatrixLike | None = None,
813
+ ) -> js.data.JaxSimModelData:
814
+ """
815
+ Advance the system state by integrating the dynamics.
816
+
817
+ Args:
818
+ model: The model to consider.
819
+ data: The data of the considered model.
820
+ dt: The integration time step.
821
+ link_forces:
822
+ The 6D forces to apply to the links expressed in the frame corresponding
823
+ to the velocity representation of `data`.
824
+ joint_force_references: The joint force references to apply.
825
+ average_link_contact_forces_inertial:
826
+ The average contact forces computed with the exponential integrator and
827
+ expressed in the inertial-fixed frame.
828
+ average_of_average_link_contact_forces_mixed:
829
+ The average of the average contact forces computed with the exponential
830
+ integrator and expressed in the mixed frame.
831
+
832
+ Returns:
833
+ The data object storing the system state at the final time.
834
+ """
835
+
836
+ s_t0 = data.joint_positions()
837
+ W_p_B_t0 = data.base_position()
838
+ W_Q_B_t0 = data.base_orientation(dcm=False)
839
+
840
+ ṡ_t0 = data.joint_velocities()
841
+ with data.switch_velocity_representation(jaxsim.VelRepr.Mixed):
842
+ W_ṗ_B_t0 = data.base_velocity()[0:3]
843
+ W_ω_WB_t0 = data.base_velocity()[3:6]
844
+
845
+ with data.switch_velocity_representation(jaxsim.VelRepr.Inertial):
846
+ W_ν_t0 = data.generalized_velocity()
847
+
848
+ # We expect that the 6D forces of the `link_forces` argument are expressed
849
+ # in the frame corresponding to the velocity representation of `data`.
850
+ references = js.references.JaxSimModelReferences.build(
851
+ model=model,
852
+ link_forces=link_forces,
853
+ joint_force_references=joint_force_references,
854
+ data=data,
855
+ velocity_representation=data.velocity_representation,
856
+ )
857
+
858
+ W_f̅_L = (
859
+ jnp.array(average_link_contact_forces_inertial)
860
+ if average_link_contact_forces_inertial is not None
861
+ else jnp.zeros_like(references.input.physics_model.f_ext)
862
+ ).astype(float)
863
+
864
+ LW_f̿_L = (
865
+ jnp.array(average_of_average_link_contact_forces_mixed)
866
+ if average_of_average_link_contact_forces_mixed is not None
867
+ else W_f̅_L
868
+ ).astype(float)
869
+
870
+ # Compute the system inertial acceleration, used to integrate the system velocity.
871
+ # It considers the average contact forces computed with the exponential integrator.
872
+ with (
873
+ data.switch_velocity_representation(jaxsim.VelRepr.Inertial),
874
+ references.switch_velocity_representation(jaxsim.VelRepr.Inertial),
875
+ ):
876
+
877
+ W_ν̇_pr = jnp.hstack(
878
+ js.ode.system_acceleration(
879
+ model=model,
880
+ data=data,
881
+ joint_force_references=references.joint_force_references(
882
+ model=model
883
+ ),
884
+ link_forces=W_f̅_L + references.link_forces(model=model, data=data),
885
+ )
886
+ )
887
+
888
+ # Compute the system mixed acceleration, used to integrate the system position.
889
+ # It considers the average of the average contact forces computed with the
890
+ # exponential integrator.
891
+ with (
892
+ data.switch_velocity_representation(jaxsim.VelRepr.Mixed),
893
+ references.switch_velocity_representation(jaxsim.VelRepr.Mixed),
894
+ ):
895
+
896
+ BW_ν̇_pr2 = jnp.hstack(
897
+ js.ode.system_acceleration(
898
+ model=model,
899
+ data=data,
900
+ joint_force_references=references.joint_force_references(
901
+ model=model
902
+ ),
903
+ link_forces=LW_f̿_L + references.link_forces(model=model, data=data),
904
+ )
905
+ )
906
+
907
+ # Integrate the system velocity using the inertial-fixed acceleration.
908
+ W_ν_plus = W_ν_t0 + dt * W_ν̇_pr
909
+
910
+ # Integrate the system position using the mixed velocity.
911
+ q_plus = jnp.hstack(
912
+ [
913
+ # Note: here both ṗ and p̈ -> need mixed representation.
914
+ W_p_B_t0 + dt * W_ṗ_B_t0 + 0.5 * dt**2 * BW_ν̇_pr2[0:3],
915
+ jaxsim.math.Quaternion.integration(
916
+ dt=dt,
917
+ quaternion=W_Q_B_t0,
918
+ omega=(W_ω_WB_t0 + 0.5 * dt * BW_ν̇_pr2[3:6]),
919
+ omega_in_body_fixed=False,
920
+ ).squeeze(),
921
+ s_t0 + dt * ṡ_t0 + 0.5 * dt**2 * BW_ν̇_pr2[6:],
922
+ ]
923
+ )
924
+
925
+ # Create the data at the final time.
926
+ with data.editable(validate=True) as data_tf:
927
+ data_tf: js.data.JaxSimModelData
928
+ data_tf.time_ns = data.time_ns + (dt * 1e9).astype(data.time_ns.dtype)
929
+
930
+ data_tf = data_tf.reset_joint_positions(q_plus[7:])
931
+ data_tf = data_tf.reset_base_position(q_plus[0:3])
932
+ data_tf = data_tf.reset_base_quaternion(q_plus[3:7])
933
+
934
+ data_tf = data_tf.reset_joint_velocities(W_ν_plus[6:])
935
+ data_tf = data_tf.reset_base_velocity(
936
+ W_ν_plus[0:6], velocity_representation=jaxsim.VelRepr.Inertial
937
+ )
938
+
939
+ return data_tf.replace(
940
+ velocity_representation=data.velocity_representation, validate=False
941
+ )
942
+
943
+
944
@jax.jit
def step(
    model: js.model.JaxSimModel,
    data: js.data.JaxSimModelData,
    *,
    dt: jtp.FloatLike,
    link_forces: jtp.MatrixLike | None = None,
    joint_force_references: jtp.VectorLike | None = None,
) -> tuple[js.data.JaxSimModelData, dict[str, Any]]:
    """
    Advance the system by one time step using the visco-elastic contact model.

    The contact forces of the enabled collidable points are first obtained from
    the contact model, then accumulated on their parent links, and finally used
    to integrate the state of the system over `dt`.

    Args:
        model: The model to consider.
        data: The data of the considered model.
        dt: The duration of the integration step.
        link_forces:
            The 6D forces to apply to the links, expressed in the frame corresponding
            to the velocity representation of `data`.
        joint_force_references: The joint force references to consider.

    Returns:
        A tuple containing the data of the model at the end of the step
        and an empty dictionary of auxiliary data.
    """

    assert isinstance(model.contact_model, ViscoElasticContacts)
    assert isinstance(data.contacts_params, ViscoElasticContactsParams)

    # ---------------------------------------------------------------------
    # Contact forces of the collidable points (exponential integrator)
    # ---------------------------------------------------------------------

    # The contact model returns the average contact forces, together with the
    # average of the average contact forces and the tangential deformation at
    # the end of the step.
    W_f_avg_C, (W_f_avg2_C, m_final) = model.contact_model.compute_contact_forces(
        model=model,
        data=data,
        dt=dt,
        link_forces=link_forces,
        joint_force_references=joint_force_references,
    )

    # ---------------------------------------------------------------------
    # Contact forces of the links
    # ---------------------------------------------------------------------

    contact_parameters = model.kin_dyn_parameters.contact_parameters

    # Only the enabled collidable points have been processed by the
    # visco-elastic contact model, so only their indices are relevant here.
    enabled_point_indices = contact_parameters.indices_of_enabled_collidable_points

    # Transforms of all the links of the model.
    W_H_L = js.model.forward_kinematics(model=model, data=data)

    # Index of the parent link of each enabled collidable point.
    all_parent_links = jnp.array(contact_parameters.body, dtype=int)
    parent_link_of_point = all_parent_links[enabled_point_indices]

    # Boolean matrix associating each collidable point (rows) to its parent
    # link (columns). Its transpose sums the per-point forces link by link.
    link_indices = jnp.arange(model.number_of_links())
    point_belongs_to_link = parent_link_of_point[:, jnp.newaxis] == link_indices

    # The contact forces are expressed in the world frame, therefore summing
    # the forces of the points rigidly attached to the same link does not
    # require any coordinate transformation.
    W_f_avg_L = point_belongs_to_link.T @ W_f_avg_C
    W_f_avg2_L = point_belongs_to_link.T @ W_f_avg2_C

    def inertial_force_to_mixed(W_f_Li, W_H_Li):
        # Express the inertial-fixed 6D force of a single link in mixed representation.
        return data.inertial_to_other_representation(
            array=W_f_Li,
            other_representation=jaxsim.VelRepr.Mixed,
            transform=W_H_Li,
            is_force=True,
        )

    # The integration of the position needs the average of the averages
    # expressed in mixed representation.
    LW_f_avg2_L = jax.vmap(inertial_force_to_mixed)(W_f_avg2_L, W_H_L)

    # ---------------------------------------------------------------------
    # Integration of the system state
    # ---------------------------------------------------------------------

    data_tf: js.data.JaxSimModelData = (
        model.contact_model.integrate_data_with_average_contact_forces(
            model=model,
            data=data,
            dt=dt,
            link_forces=link_forces,
            joint_force_references=joint_force_references,
            average_link_contact_forces_inertial=W_f_avg_L,
            average_of_average_link_contact_forces_mixed=LW_f_avg2_L,
        )
    )

    # Write the tangential deformation at the final state into the new data.
    # Since it was integrated in the continuous time domain, it should be much
    # more accurate than the one computed with the discrete soft contacts.
    with data_tf.mutable_context():

        updated_deformation = (
            data_tf.state.extended["tangential_deformation"]
            .at[enabled_point_indices]
            .set(m_final)
        )

        data_tf.state.extended |= {"tangential_deformation": updated_deformation}

    return data_tf, {}