jaxsim 0.6.2.dev2__py3-none-any.whl → 0.6.2.dev102__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (46) hide show
  1. jaxsim/__init__.py +1 -1
  2. jaxsim/_version.py +2 -2
  3. jaxsim/api/__init__.py +3 -1
  4. jaxsim/api/actuation_model.py +96 -0
  5. jaxsim/api/com.py +8 -8
  6. jaxsim/api/contact.py +15 -255
  7. jaxsim/api/contact_model.py +101 -0
  8. jaxsim/api/data.py +258 -556
  9. jaxsim/api/frame.py +7 -7
  10. jaxsim/api/integrators.py +76 -0
  11. jaxsim/api/kin_dyn_parameters.py +41 -58
  12. jaxsim/api/link.py +7 -7
  13. jaxsim/api/model.py +190 -453
  14. jaxsim/api/ode.py +34 -338
  15. jaxsim/api/references.py +2 -2
  16. jaxsim/exceptions.py +2 -2
  17. jaxsim/math/__init__.py +4 -3
  18. jaxsim/math/joint_model.py +17 -107
  19. jaxsim/mujoco/model.py +1 -1
  20. jaxsim/mujoco/utils.py +2 -2
  21. jaxsim/parsers/kinematic_graph.py +1 -3
  22. jaxsim/rbda/aba.py +7 -4
  23. jaxsim/rbda/collidable_points.py +7 -98
  24. jaxsim/rbda/contacts/__init__.py +2 -10
  25. jaxsim/rbda/contacts/common.py +0 -138
  26. jaxsim/rbda/contacts/relaxed_rigid.py +154 -9
  27. jaxsim/rbda/crba.py +5 -2
  28. jaxsim/rbda/forward_kinematics.py +37 -12
  29. jaxsim/rbda/jacobian.py +15 -6
  30. jaxsim/rbda/rnea.py +7 -4
  31. jaxsim/rbda/utils.py +3 -3
  32. jaxsim/utils/jaxsim_dataclass.py +5 -1
  33. {jaxsim-0.6.2.dev2.dist-info → jaxsim-0.6.2.dev102.dist-info}/METADATA +7 -9
  34. jaxsim-0.6.2.dev102.dist-info/RECORD +69 -0
  35. jaxsim/api/ode_data.py +0 -401
  36. jaxsim/integrators/__init__.py +0 -2
  37. jaxsim/integrators/common.py +0 -592
  38. jaxsim/integrators/fixed_step.py +0 -153
  39. jaxsim/integrators/variable_step.py +0 -706
  40. jaxsim/rbda/contacts/rigid.py +0 -462
  41. jaxsim/rbda/contacts/soft.py +0 -480
  42. jaxsim/rbda/contacts/visco_elastic.py +0 -1066
  43. jaxsim-0.6.2.dev2.dist-info/RECORD +0 -74
  44. {jaxsim-0.6.2.dev2.dist-info → jaxsim-0.6.2.dev102.dist-info}/LICENSE +0 -0
  45. {jaxsim-0.6.2.dev2.dist-info → jaxsim-0.6.2.dev102.dist-info}/WHEEL +0 -0
  46. {jaxsim-0.6.2.dev2.dist-info → jaxsim-0.6.2.dev102.dist-info}/top_level.txt +0 -0
@@ -1,1066 +0,0 @@
1
- from __future__ import annotations
2
-
3
- import dataclasses
4
- import functools
5
- from typing import Any
6
-
7
- import jax
8
- import jax.numpy as jnp
9
- import jax_dataclasses
10
-
11
- import jaxsim
12
- import jaxsim.api as js
13
- import jaxsim.exceptions
14
- import jaxsim.typing as jtp
15
- from jaxsim import logging
16
- from jaxsim.api.common import ModelDataWithVelocityRepresentation
17
- from jaxsim.math import StandardGravity
18
- from jaxsim.terrain import Terrain
19
-
20
- from . import common
21
- from .soft import SoftContacts, SoftContactsParams
22
-
23
- try:
24
- from typing import Self
25
- except ImportError:
26
- from typing_extensions import Self
27
-
28
-
29
- @jax_dataclasses.pytree_dataclass
30
- class ViscoElasticContactsParams(common.ContactsParams):
31
- """Parameters of the visco-elastic contacts model."""
32
-
33
- K: jtp.Float = dataclasses.field(
34
- default_factory=lambda: jnp.array(1e6, dtype=float)
35
- )
36
-
37
- D: jtp.Float = dataclasses.field(
38
- default_factory=lambda: jnp.array(2000, dtype=float)
39
- )
40
-
41
- static_friction: jtp.Float = dataclasses.field(
42
- default_factory=lambda: jnp.array(0.5, dtype=float)
43
- )
44
-
45
- p: jtp.Float = dataclasses.field(
46
- default_factory=lambda: jnp.array(0.5, dtype=float)
47
- )
48
-
49
- q: jtp.Float = dataclasses.field(
50
- default_factory=lambda: jnp.array(0.5, dtype=float)
51
- )
52
-
53
- @classmethod
54
- def build(
55
- cls: type[Self],
56
- K: jtp.FloatLike = 1e6,
57
- D: jtp.FloatLike = 2_000,
58
- static_friction: jtp.FloatLike = 0.5,
59
- p: jtp.FloatLike = 0.5,
60
- q: jtp.FloatLike = 0.5,
61
- ) -> Self:
62
- """
63
- Create a SoftContactsParams instance with specified parameters.
64
-
65
- Args:
66
- K: The stiffness parameter.
67
- D: The damping parameter of the soft contacts model.
68
- static_friction: The static friction coefficient.
69
- p:
70
- The exponent p corresponding to the damping-related non-linearity
71
- of the Hunt/Crossley model.
72
- q:
73
- The exponent q corresponding to the spring-related non-linearity
74
- of the Hunt/Crossley model.
75
-
76
- Returns:
77
- A ViscoElasticParams instance with the specified parameters.
78
- """
79
-
80
- return ViscoElasticContactsParams(
81
- K=jnp.array(K, dtype=float),
82
- D=jnp.array(D, dtype=float),
83
- static_friction=jnp.array(static_friction, dtype=float),
84
- p=jnp.array(p, dtype=float),
85
- q=jnp.array(q, dtype=float),
86
- )
87
-
88
- @classmethod
89
- def build_default_from_jaxsim_model(
90
- cls: type[Self],
91
- model: js.model.JaxSimModel,
92
- *,
93
- standard_gravity: jtp.FloatLike = StandardGravity,
94
- static_friction_coefficient: jtp.FloatLike = 0.5,
95
- max_penetration: jtp.FloatLike = 0.001,
96
- number_of_active_collidable_points_steady_state: jtp.IntLike = 1,
97
- damping_ratio: jtp.FloatLike = 1.0,
98
- p: jtp.FloatLike = 0.5,
99
- q: jtp.FloatLike = 0.5,
100
- ) -> Self:
101
- """
102
- Create a ViscoElasticContactsParams instance with good default parameters.
103
-
104
- Args:
105
- model: The target model.
106
- standard_gravity: The standard gravity constant.
107
- static_friction_coefficient:
108
- The static friction coefficient between the model and the terrain.
109
- max_penetration: The maximum penetration depth.
110
- number_of_active_collidable_points_steady_state:
111
- The number of contacts supporting the weight of the model
112
- in steady state.
113
- damping_ratio: The ratio controlling the damping behavior.
114
- p:
115
- The exponent p corresponding to the damping-related non-linearity
116
- of the Hunt/Crossley model.
117
- q:
118
- The exponent q corresponding to the spring-related non-linearity
119
- of the Hunt/Crossley model.
120
-
121
- Returns:
122
- A `ViscoElasticContactsParams` instance with the specified parameters.
123
-
124
- Note:
125
- The `damping_ratio` parameter allows to operate on the following conditions:
126
- - ξ > 1.0: over-damped
127
- - ξ = 1.0: critically damped
128
- - ξ < 1.0: under-damped
129
- """
130
-
131
- # Call the SoftContact builder instead of duplicating the logic.
132
- soft_contacts_params = SoftContactsParams.build_default_from_jaxsim_model(
133
- model=model,
134
- standard_gravity=standard_gravity,
135
- static_friction_coefficient=static_friction_coefficient,
136
- max_penetration=max_penetration,
137
- number_of_active_collidable_points_steady_state=number_of_active_collidable_points_steady_state,
138
- damping_ratio=damping_ratio,
139
- )
140
-
141
- return ViscoElasticContactsParams.build(
142
- K=soft_contacts_params.K,
143
- D=soft_contacts_params.D,
144
- static_friction=soft_contacts_params.mu,
145
- p=p,
146
- q=q,
147
- )
148
-
149
- def valid(self) -> jtp.BoolLike:
150
- """
151
- Check if the parameters are valid.
152
-
153
- Returns:
154
- `True` if the parameters are valid, `False` otherwise.
155
- """
156
-
157
- return (
158
- jnp.all(self.K >= 0.0)
159
- and jnp.all(self.D >= 0.0)
160
- and jnp.all(self.static_friction >= 0.0)
161
- and jnp.all(self.p >= 0.0)
162
- and jnp.all(self.q >= 0.0)
163
- )
164
-
165
- def __hash__(self) -> int:
166
-
167
- from jaxsim.utils.wrappers import HashedNumpyArray
168
-
169
- return hash(
170
- (
171
- HashedNumpyArray.hash_of_array(self.K),
172
- HashedNumpyArray.hash_of_array(self.D),
173
- HashedNumpyArray.hash_of_array(self.static_friction),
174
- HashedNumpyArray.hash_of_array(self.p),
175
- HashedNumpyArray.hash_of_array(self.q),
176
- )
177
- )
178
-
179
- def __eq__(self, other: ViscoElasticContactsParams) -> bool:
180
-
181
- if not isinstance(other, ViscoElasticContactsParams):
182
- return False
183
-
184
- return hash(self) == hash(other)
185
-
186
-
187
- @jax_dataclasses.pytree_dataclass
188
- class ViscoElasticContacts(common.ContactModel):
189
- """Visco-elastic contacts model."""
190
-
191
- max_squarings: jax_dataclasses.Static[int] = dataclasses.field(default=25)
192
-
193
- @classmethod
194
- def build(
195
- cls: type[Self],
196
- model: js.model.JaxSimModel | None = None,
197
- max_squarings: jtp.IntLike | None = None,
198
- **kwargs,
199
- ) -> Self:
200
- """
201
- Create a `ViscoElasticContacts` instance with specified parameters.
202
-
203
- Args:
204
- model:
205
- The robot model considered by the contact model.
206
- If passed, it is used to estimate good default parameters.
207
- max_squarings:
208
- The maximum number of squarings performed in the matrix exponential.
209
- **kwargs: Extra arguments to ignore.
210
-
211
- Returns:
212
- The `ViscoElasticContacts` instance.
213
- """
214
-
215
- if len(kwargs) != 0:
216
- logging.debug(msg=f"Ignoring extra arguments: {kwargs}")
217
-
218
- return cls(
219
- max_squarings=int(
220
- max_squarings
221
- if max_squarings is not None
222
- else cls.__dataclass_fields__["max_squarings"].default
223
- ),
224
- )
225
-
226
- @classmethod
227
- def zero_state_variables(cls, model: js.model.JaxSimModel) -> dict[str, jtp.Array]:
228
- """
229
- Build zero state variables of the contact model.
230
- """
231
-
232
- # Initialize the material deformation to zero.
233
- tangential_deformation = jnp.zeros(
234
- shape=(len(model.kin_dyn_parameters.contact_parameters.body), 3),
235
- dtype=float,
236
- )
237
-
238
- return {"tangential_deformation": tangential_deformation}
239
-
240
- @jax.jit
241
- def compute_contact_forces(
242
- self,
243
- model: js.model.JaxSimModel,
244
- data: js.data.JaxSimModelData,
245
- *,
246
- dt: jtp.FloatLike | None = None,
247
- link_forces: jtp.MatrixLike | None = None,
248
- joint_force_references: jtp.VectorLike | None = None,
249
- ) -> tuple[jtp.Matrix, dict[str, jtp.PyTree]]:
250
- """
251
- Compute the contact forces.
252
-
253
- Args:
254
- model: The robot model considered by the contact model.
255
- data: The data of the considered model.
256
- dt: The time step to consider. If not specified, it is read from the model.
257
- link_forces:
258
- The 6D forces to apply to the links expressed in the frame corresponding
259
- to the velocity representation of `data`.
260
- joint_force_references: The joint force references to apply.
261
-
262
- Note:
263
- This contact model, contrarily to most other contact models, requires the
264
- knowledge of the integration step. It is not straightforward to assess how
265
- this contact model behaves when used with high-order Runge-Kutta schemes.
266
- For the time being, it is recommended to use a simple forward Euler scheme.
267
- The main benefit of this model is that the stiff contact dynamics is computed
268
- separately from the rest of the system dynamics, which allows to use simple
269
- integration schemes without altering significantly the simulation stability.
270
-
271
- Returns:
272
- A tuple containing as first element the computed 6D contact force applied to
273
- the contact point and expressed in the world frame, and as second element
274
- a dictionary of optional additional information.
275
- """
276
-
277
- # Extract the indices corresponding to the enabled collidable points.
278
- indices_of_enabled_collidable_points = (
279
- model.kin_dyn_parameters.contact_parameters.indices_of_enabled_collidable_points
280
- )
281
-
282
- # Initialize the time step.
283
- dt = dt if dt is not None else model.time_step
284
-
285
- # Compute the average contact linear forces in mixed representation by
286
- # integrating the contact dynamics in the continuous time domain.
287
- CW_f̅l, CW_fl̿, m_tf = (
288
- ViscoElasticContacts._compute_contact_forces_with_exponential_integration(
289
- model=model,
290
- data=data,
291
- dt=jnp.array(dt).astype(float),
292
- link_forces=link_forces,
293
- joint_force_references=joint_force_references,
294
- indices_of_enabled_collidable_points=indices_of_enabled_collidable_points,
295
- max_squarings=self.max_squarings,
296
- )
297
- )
298
-
299
- # ============================================
300
- # Compute the inertial-fixed 6D contact forces
301
- # ============================================
302
-
303
- # Compute the transforms of the mixed frames `C[W] = (W_p_C, [W])`
304
- # associated to each collidable point.
305
- W_H_C = js.contact.transforms(model=model, data=data)[
306
- indices_of_enabled_collidable_points, :, :
307
- ]
308
-
309
- # Vmapped transformation from mixed to inertial-fixed representation.
310
- compute_forces_inertial_fixed_vmap = jax.vmap(
311
- lambda CW_fl_C, W_H_C: (
312
- ModelDataWithVelocityRepresentation.other_representation_to_inertial(
313
- array=jnp.zeros(6).at[0:3].set(CW_fl_C),
314
- other_representation=jaxsim.VelRepr.Mixed,
315
- transform=W_H_C,
316
- is_force=True,
317
- )
318
- )
319
- )
320
-
321
- # Express the linear contact forces in the inertial-fixed frame.
322
- W_f̅_C, W_f̿_C = jax.vmap(
323
- lambda CW_fl: compute_forces_inertial_fixed_vmap(CW_fl, W_H_C)
324
- )(jnp.stack([CW_f̅l, CW_fl̿]))
325
-
326
- return W_f̅_C, dict(W_f_avg2_C=W_f̿_C, m_tf=m_tf)
327
-
328
- @staticmethod
329
- @functools.partial(jax.jit, static_argnames=("max_squarings",))
330
- def _compute_contact_forces_with_exponential_integration(
331
- model: js.model.JaxSimModel,
332
- data: js.data.JaxSimModelData,
333
- *,
334
- dt: jtp.FloatLike,
335
- link_forces: jtp.MatrixLike | None = None,
336
- joint_force_references: jtp.VectorLike | None = None,
337
- indices_of_enabled_collidable_points: jtp.VectorLike | None = None,
338
- max_squarings: int = 25,
339
- ) -> tuple[jtp.Matrix, jtp.Matrix, jtp.Matrix]:
340
- """
341
- Compute the average contact forces by integrating the contact dynamics.
342
-
343
- Args:
344
- model: The robot model considered by the contact model.
345
- data: The data of the considered model.
346
- dt: The integration time step.
347
- link_forces: The 6D forces to apply to the links.
348
- joint_force_references: The joint force references to apply.
349
- indices_of_enabled_collidable_points:
350
- The indices of the enabled collidable points.
351
- max_squarings:
352
- The maximum number of squarings performed in the matrix exponential.
353
-
354
- Returns:
355
- A tuple containing:
356
- - The average contact forces.
357
- - The average of the average contact forces.
358
- - The tangential deformation at the final state.
359
- """
360
-
361
- # ==========================
362
- # Populate missing arguments
363
- # ==========================
364
-
365
- indices = (
366
- indices_of_enabled_collidable_points
367
- if indices_of_enabled_collidable_points is not None
368
- else jnp.arange(
369
- len(model.kin_dyn_parameters.contact_parameters.body)
370
- ).astype(int)
371
- )
372
-
373
- # ==================================
374
- # Compute the contact point dynamics
375
- # ==================================
376
-
377
- p_t0, v_t0 = js.contact.collidable_point_kinematics(model, data)
378
- m_t0 = data.state.extended["tangential_deformation"][indices, :]
379
-
380
- p_t0 = p_t0[indices, :]
381
- v_t0 = v_t0[indices, :]
382
-
383
- # Compute the linearized contact dynamics.
384
- # Note that it linearizes the (non-linear) contact model at (p, v, m)[t0].
385
- A, b, A_sc, b_sc = ViscoElasticContacts._contact_points_dynamics(
386
- model=model,
387
- data=data,
388
- link_forces=link_forces,
389
- joint_force_references=joint_force_references,
390
- indices_of_enabled_collidable_points=indices,
391
- p_t0=p_t0,
392
- v_t0=v_t0,
393
- m_t0=m_t0,
394
- )
395
-
396
- # =============================================
397
- # Compute the integrals of the contact dynamics
398
- # =============================================
399
-
400
- # Pack the initial state of the contact points.
401
- x_t0 = jnp.hstack([p_t0.flatten(), v_t0.flatten(), m_t0.flatten()])
402
-
403
- # Pack the augmented matrix used to compute the single and double integral
404
- # of the exponential integration.
405
- A̅ = jnp.vstack(
406
- [
407
- jnp.hstack(
408
- [
409
- A,
410
- jnp.vstack(b),
411
- jnp.vstack(x_t0),
412
- jnp.vstack(jnp.zeros_like(x_t0)),
413
- ]
414
- ),
415
- jnp.hstack([jnp.zeros(A.shape[1]), 0, 1, 0]),
416
- jnp.hstack([jnp.zeros(A.shape[1]), 0, 0, 1]),
417
- jnp.hstack([jnp.zeros(A.shape[1]), 0, 0, 0]),
418
- ]
419
- )
420
-
421
- # Compute the matrix exponential.
422
- exp_tA = jax.scipy.linalg.expm(
423
- (dt * A̅).astype(float), max_squarings=max_squarings
424
- )
425
-
426
- # Integrate the contact dynamics in the continuous time domain.
427
- x_int, x_int2 = (
428
- jnp.hstack([jnp.eye(A.shape[0]), jnp.zeros(shape=(A.shape[0], 3))])
429
- @ exp_tA
430
- @ jnp.vstack([jnp.zeros(shape=(A.shape[0] + 1, 2)), jnp.eye(2)])
431
- ).T
432
-
433
- jaxsim.exceptions.raise_runtime_error_if(
434
- condition=jnp.isnan(x_int).any(),
435
- msg="NaN integration, try to increase `max_squarings` or decreasing `dt`",
436
- )
437
-
438
- # ==========================
439
- # Compute the contact forces
440
- # ==========================
441
-
442
- # Compute the average contact forces.
443
- CW_f̅, _ = jnp.split(
444
- (A_sc @ x_int / dt + b_sc).reshape(-1, 3),
445
- indices_or_sections=2,
446
- )
447
-
448
- # Compute the average of the average contact forces.
449
- CW_f̿, _ = jnp.split(
450
- (A_sc @ x_int2 * 2 / (dt**2) + b_sc).reshape(-1, 3),
451
- indices_or_sections=2,
452
- )
453
-
454
- # Extract the tangential deformation at the final state.
455
- x_tf = x_int / dt
456
- m_tf = jnp.split(x_tf, 3)[2].reshape(-1, 3)
457
-
458
- return CW_f̅, CW_f̿, m_tf
459
-
460
- @staticmethod
461
- @jax.jit
462
- def _contact_points_dynamics(
463
- model: js.model.JaxSimModel,
464
- data: js.data.JaxSimModelData,
465
- *,
466
- link_forces: jtp.MatrixLike | None = None,
467
- joint_force_references: jtp.VectorLike | None = None,
468
- indices_of_enabled_collidable_points: jtp.VectorLike | None = None,
469
- p_t0: jtp.MatrixLike | None = None,
470
- v_t0: jtp.MatrixLike | None = None,
471
- m_t0: jtp.MatrixLike | None = None,
472
- ) -> tuple[jtp.Matrix, jtp.Vector, jtp.Matrix, jtp.Vector]:
473
- """
474
- Compute the dynamics of the contact points.
475
-
476
- Note:
477
- This function projects the system dynamics to the contact space and
478
- returns the matrices of a linear system to simulate its evolution.
479
- Since the active contact model can be non-linear, this function also
480
- linearizes the contact model at the initial state.
481
-
482
- Args:
483
- model: The robot model considered by the contact model.
484
- data: The data of the considered model.
485
- link_forces: The 6D forces to apply to the links.
486
- joint_force_references: The joint force references to apply.
487
- indices_of_enabled_collidable_points:
488
- The indices of the enabled collidable points.
489
- p_t0: The initial position of the collidable points.
490
- v_t0: The initial velocity of the collidable points.
491
- m_t0: The initial tangential deformation of the collidable points.
492
-
493
- Returns:
494
- A tuple containing:
495
- - The `A` matrix of the linear system that models the contact dynamics.
496
- - The `b` vector of the linear system that models the contact dynamics.
497
- - The `A_sc` matrix of the linear system that approximates the contact model.
498
- - The `b_sc` vector of the linear system that approximates the contact model.
499
- """
500
-
501
- indices_of_enabled_collidable_points = (
502
- indices_of_enabled_collidable_points
503
- if indices_of_enabled_collidable_points is not None
504
- else jnp.arange(
505
- len(model.kin_dyn_parameters.contact_parameters.body)
506
- ).astype(int)
507
- )
508
-
509
- p_t0 = jnp.atleast_2d(
510
- p_t0
511
- if p_t0 is not None
512
- else js.contact.collidable_point_positions(model=model, data=data)[
513
- indices_of_enabled_collidable_points, :
514
- ]
515
- )
516
-
517
- v_t0 = jnp.atleast_2d(
518
- v_t0
519
- if v_t0 is not None
520
- else js.contact.collidable_point_velocities(model=model, data=data)[
521
- indices_of_enabled_collidable_points, :
522
- ]
523
- )
524
-
525
- m_t0 = jnp.atleast_2d(
526
- m_t0
527
- if m_t0 is not None
528
- else data.state.extended["tangential_deformation"][
529
- indices_of_enabled_collidable_points, :
530
- ]
531
- )
532
-
533
- # We expect that the 6D forces of the `link_forces` argument are expressed
534
- # in the frame corresponding to the velocity representation of `data`.
535
- references = js.references.JaxSimModelReferences.build(
536
- model=model,
537
- link_forces=link_forces,
538
- joint_force_references=joint_force_references,
539
- data=data,
540
- velocity_representation=data.velocity_representation,
541
- )
542
-
543
- # ===========================
544
- # Linearize the contact model
545
- # ===========================
546
-
547
- # Linearize the contact model at the initial state of all considered
548
- # contact points.
549
- A_sc_points, b_sc_points = jax.vmap(
550
- lambda p, v, m: ViscoElasticContacts._linearize_contact_model(
551
- position=p,
552
- velocity=v,
553
- tangential_deformation=m,
554
- parameters=data.contacts_params,
555
- terrain=model.terrain,
556
- )
557
- )(p_t0, v_t0, m_t0)
558
-
559
- # Since x = [p1, p2, ..., v1, v2, ..., m1, m2, ...], we need to split the A_sc of
560
- # individual points since otherwise we'd get x = [ p1, v1, m1, p2, v2, m2, ...].
561
- A_sc_p, A_sc_v, A_sc_m = jnp.split(A_sc_points, indices_or_sections=3, axis=-1)
562
-
563
- # We want to have in output first the forces and then the material deformation rates.
564
- # Therefore, we need to extract the components is A_sc_* separately.
565
- A_sc = jnp.vstack(
566
- [
567
- jnp.hstack(
568
- [
569
- jax.scipy.linalg.block_diag(*A_sc_p[:, 0:3, :]),
570
- jax.scipy.linalg.block_diag(*A_sc_v[:, 0:3, :]),
571
- jax.scipy.linalg.block_diag(*A_sc_m[:, 0:3, :]),
572
- ],
573
- ),
574
- jnp.hstack(
575
- [
576
- jax.scipy.linalg.block_diag(*A_sc_p[:, 3:6, :]),
577
- jax.scipy.linalg.block_diag(*A_sc_v[:, 3:6, :]),
578
- jax.scipy.linalg.block_diag(*A_sc_m[:, 3:6, :]),
579
- ]
580
- ),
581
- ]
582
- )
583
-
584
- # We need to do the same for the b_sc.
585
- b_sc = jnp.hstack(
586
- [b_sc_points[:, 0:3].flatten(), b_sc_points[:, 3:6].flatten()]
587
- )
588
-
589
- # ===========================================================
590
- # Compute the A and b matrices of the contact points dynamics
591
- # ===========================================================
592
-
593
- with data.switch_velocity_representation(jaxsim.VelRepr.Mixed):
594
-
595
- BW_ν = data.generalized_velocity()
596
-
597
- M = js.model.free_floating_mass_matrix(model=model, data=data)
598
-
599
- CW_Jl_WC = js.contact.jacobian(
600
- model=model,
601
- data=data,
602
- output_vel_repr=jaxsim.VelRepr.Mixed,
603
- )[indices_of_enabled_collidable_points, 0:3, :]
604
-
605
- CW_J̇l_WC = js.contact.jacobian_derivative(
606
- model=model, data=data, output_vel_repr=jaxsim.VelRepr.Mixed
607
- )[indices_of_enabled_collidable_points, 0:3, :]
608
-
609
- # Compute the Delassus matrix.
610
- ψ = jnp.vstack(CW_Jl_WC) @ jnp.linalg.lstsq(M, jnp.vstack(CW_Jl_WC).T)[0]
611
-
612
- I_nc = jnp.eye(v_t0.flatten().size)
613
- O_nc = jnp.zeros(shape=(p_t0.flatten().size, p_t0.flatten().size))
614
-
615
- # Pack the A matrix.
616
- A = jnp.vstack(
617
- [
618
- jnp.hstack([O_nc, I_nc, O_nc]),
619
- ψ @ jnp.split(A_sc, 2, axis=0)[0],
620
- jnp.split(A_sc, 2, axis=0)[1],
621
- ]
622
- )
623
-
624
- # Short names for few variables.
625
- ν = BW_ν
626
- J = jnp.vstack(CW_Jl_WC)
627
- J̇ = jnp.vstack(CW_J̇l_WC)
628
-
629
- # Compute the free system acceleration components.
630
- with (
631
- data.switch_velocity_representation(jaxsim.VelRepr.Mixed),
632
- references.switch_velocity_representation(jaxsim.VelRepr.Mixed),
633
- ):
634
-
635
- BW_v̇_free_WB, s̈_free = js.ode.system_acceleration(
636
- model=model,
637
- data=data,
638
- link_forces=references.link_forces(model=model, data=data),
639
- joint_force_references=references.joint_force_references(model=model),
640
- )
641
-
642
- # Pack the free system acceleration in mixed representation.
643
- ν̇_free = jnp.hstack([BW_v̇_free_WB, s̈_free])
644
-
645
- # Compute the acceleration of collidable points.
646
- # This is the true derivative of ṗ only in mixed representation.
647
- p̈ = J @ ν̇_free + J̇ @ ν
648
-
649
- # Pack the b array.
650
- b = jnp.hstack(
651
- [
652
- jnp.zeros_like(p_t0.flatten()),
653
- p̈ + ψ @ jnp.split(b_sc, indices_or_sections=2)[0],
654
- jnp.split(b_sc, indices_or_sections=2)[1],
655
- ]
656
- )
657
-
658
- return A, b, A_sc, b_sc
659
-
660
- @staticmethod
661
- @functools.partial(jax.jit, static_argnames=("terrain",))
662
- def _linearize_contact_model(
663
- position: jtp.VectorLike,
664
- velocity: jtp.VectorLike,
665
- tangential_deformation: jtp.VectorLike,
666
- parameters: ViscoElasticContactsParams,
667
- terrain: Terrain,
668
- ) -> tuple[jtp.Matrix, jtp.Vector]:
669
- """
670
- Linearize the Hunt/Crossley contact model at the initial state.
671
-
672
- Args:
673
- position: The position of the contact point.
674
- velocity: The velocity of the contact point.
675
- tangential_deformation: The tangential deformation of the contact point.
676
- parameters: The parameters of the contact model.
677
- terrain: The considered terrain.
678
-
679
- Returns:
680
- A tuple containing the `A` matrix and the `b` vector of the linear system
681
- corresponding to the contact dynamics linearized at the initial state.
682
- """
683
-
684
- # Initialize the state at which the model is linearized.
685
- p0 = jnp.array(position, dtype=float).squeeze()
686
- v0 = jnp.array(velocity, dtype=float).squeeze()
687
- m0 = jnp.array(tangential_deformation, dtype=float).squeeze()
688
-
689
- # ============
690
- # Compute A_sc
691
- # ============
692
-
693
- compute_contact_force_non_linear_model = functools.partial(
694
- ViscoElasticContacts._compute_contact_force_non_linear_model,
695
- parameters=parameters,
696
- terrain=terrain,
697
- )
698
-
699
- # Compute with AD the functions to get the Jacobians of CW_fl.
700
- df_dp_fun, df_dv_fun, df_dm_fun = (
701
- jax.jacrev(
702
- lambda p0, v0, m0: compute_contact_force_non_linear_model(
703
- position=p0, velocity=v0, tangential_deformation=m0
704
- )[0],
705
- argnums=num,
706
- )
707
- for num in (0, 1, 2)
708
- )
709
-
710
- # Compute with AD the functions to get the Jacobians of ṁ.
711
- dṁ_dp_fun, dṁ_dv_fun, dṁ_dm_fun = (
712
- jax.jacrev(
713
- lambda p0, v0, m0: compute_contact_force_non_linear_model(
714
- position=p0, velocity=v0, tangential_deformation=m0
715
- )[1],
716
- argnums=num,
717
- )
718
- for num in (0, 1, 2)
719
- )
720
-
721
- # Compute the Jacobians of the contact forces w.r.t. the state.
722
- df_dp = jnp.vstack(df_dp_fun(p0, v0, m0))
723
- df_dv = jnp.vstack(df_dv_fun(p0, v0, m0))
724
- df_dm = jnp.vstack(df_dm_fun(p0, v0, m0))
725
-
726
- # Compute the Jacobians of the material deformation rate w.r.t. the state.
727
- dṁ_dp = jnp.vstack(dṁ_dp_fun(p0, v0, m0))
728
- dṁ_dv = jnp.vstack(dṁ_dv_fun(p0, v0, m0))
729
- dṁ_dm = jnp.vstack(dṁ_dm_fun(p0, v0, m0))
730
-
731
- # Pack the A matrix.
732
- A_sc = jnp.vstack(
733
- [
734
- jnp.hstack([df_dp, df_dv, df_dm]),
735
- jnp.hstack([dṁ_dp, dṁ_dv, dṁ_dm]),
736
- ]
737
- )
738
-
739
- # ============
740
- # Compute b_sc
741
- # ============
742
-
743
- # Compute the output of the non-linear model at the initial state.
744
- x0 = jnp.hstack([p0, v0, m0])
745
- f0, ṁ0 = compute_contact_force_non_linear_model(
746
- position=p0, velocity=v0, tangential_deformation=m0
747
- )
748
-
749
- # Pack the b vector.
750
- b_sc = jnp.hstack([f0, ṁ0]) - A_sc @ x0
751
-
752
- return A_sc, b_sc
753
-
754
- @staticmethod
755
- @functools.partial(jax.jit, static_argnames=("terrain",))
756
- def _compute_contact_force_non_linear_model(
757
- position: jtp.VectorLike,
758
- velocity: jtp.VectorLike,
759
- tangential_deformation: jtp.VectorLike,
760
- parameters: ViscoElasticContactsParams,
761
- terrain: Terrain,
762
- ) -> tuple[jtp.Vector, jtp.Vector]:
763
- """
764
- Compute the contact forces using the non-linear Hunt/Crossley model.
765
-
766
- Args:
767
- position: The position of the contact point.
768
- velocity: The velocity of the contact point.
769
- tangential_deformation: The tangential deformation of the contact point.
770
- parameters: The parameters of the contact model.
771
- terrain: The considered terrain.
772
-
773
- Returns:
774
- A tuple containing:
775
- - The linear contact force in the mixed contact frame.
776
- - The rate of material deformation.
777
- """
778
-
779
- # Compute the linear contact force in mixed representation using
780
- # the non-linear Hunt/Crossley model.
781
- # The following function also returns the rate of material deformation.
782
- CW_fl, ṁ = SoftContacts.hunt_crossley_contact_model(
783
- position=position,
784
- velocity=velocity,
785
- tangential_deformation=tangential_deformation,
786
- terrain=terrain,
787
- K=parameters.K,
788
- D=parameters.D,
789
- mu=parameters.static_friction,
790
- p=parameters.p,
791
- q=parameters.q,
792
- )
793
-
794
- return CW_fl, ṁ
795
-
796
- @staticmethod
797
- @jax.jit
798
- def integrate_data_with_average_contact_forces(
799
- model: js.model.JaxSimModel,
800
- data: js.data.JaxSimModelData,
801
- *,
802
- dt: jtp.FloatLike,
803
- link_forces: jtp.MatrixLike | None = None,
804
- joint_force_references: jtp.VectorLike | None = None,
805
- average_link_contact_forces_inertial: jtp.MatrixLike | None = None,
806
- average_of_average_link_contact_forces_mixed: jtp.MatrixLike | None = None,
807
- ) -> js.data.JaxSimModelData:
808
- """
809
- Advance the system state by integrating the dynamics.
810
-
811
- Args:
812
- model: The model to consider.
813
- data: The data of the considered model.
814
- dt: The integration time step.
815
- link_forces:
816
- The 6D forces to apply to the links expressed in the frame corresponding
817
- to the velocity representation of `data`.
818
- joint_force_references: The joint force references to apply.
819
- average_link_contact_forces_inertial:
820
- The average contact forces computed with the exponential integrator and
821
- expressed in the inertial-fixed frame.
822
- average_of_average_link_contact_forces_mixed:
823
- The average of the average contact forces computed with the exponential
824
- integrator and expressed in the mixed frame.
825
-
826
- Returns:
827
- The data object storing the system state at the final time.
828
- """
829
-
830
- s_t0 = data.joint_positions()
831
- W_p_B_t0 = data.base_position()
832
- W_Q_B_t0 = data.base_orientation(dcm=False)
833
-
834
- ṡ_t0 = data.joint_velocities()
835
- with data.switch_velocity_representation(jaxsim.VelRepr.Mixed):
836
- W_ṗ_B_t0 = data.base_velocity()[0:3]
837
- W_ω_WB_t0 = data.base_velocity()[3:6]
838
-
839
- with data.switch_velocity_representation(jaxsim.VelRepr.Inertial):
840
- W_ν_t0 = data.generalized_velocity()
841
-
842
- # We expect that the 6D forces of the `link_forces` argument are expressed
843
- # in the frame corresponding to the velocity representation of `data`.
844
- references = js.references.JaxSimModelReferences.build(
845
- model=model,
846
- link_forces=link_forces,
847
- joint_force_references=joint_force_references,
848
- data=data,
849
- velocity_representation=data.velocity_representation,
850
- )
851
-
852
- W_f̅_L = (
853
- jnp.array(average_link_contact_forces_inertial)
854
- if average_link_contact_forces_inertial is not None
855
- else jnp.zeros_like(references._link_forces)
856
- ).astype(float)
857
-
858
- LW_f̿_L = (
859
- jnp.array(average_of_average_link_contact_forces_mixed)
860
- if average_of_average_link_contact_forces_mixed is not None
861
- else W_f̅_L
862
- ).astype(float)
863
-
864
- # Compute the system inertial acceleration, used to integrate the system velocity.
865
- # It considers the average contact forces computed with the exponential integrator.
866
- with (
867
- data.switch_velocity_representation(jaxsim.VelRepr.Inertial),
868
- references.switch_velocity_representation(jaxsim.VelRepr.Inertial),
869
- ):
870
-
871
- W_ν̇_pr = jnp.hstack(
872
- js.ode.system_acceleration(
873
- model=model,
874
- data=data,
875
- joint_force_references=references.joint_force_references(
876
- model=model
877
- ),
878
- link_forces=W_f̅_L + references.link_forces(model=model, data=data),
879
- )
880
- )
881
-
882
- # Compute the system mixed acceleration, used to integrate the system position.
883
- # It considers the average of the average contact forces computed with the
884
- # exponential integrator.
885
- with (
886
- data.switch_velocity_representation(jaxsim.VelRepr.Mixed),
887
- references.switch_velocity_representation(jaxsim.VelRepr.Mixed),
888
- ):
889
-
890
- BW_ν̇_pr2 = jnp.hstack(
891
- js.ode.system_acceleration(
892
- model=model,
893
- data=data,
894
- joint_force_references=references.joint_force_references(
895
- model=model
896
- ),
897
- link_forces=LW_f̿_L + references.link_forces(model=model, data=data),
898
- )
899
- )
900
-
901
- # Integrate the system velocity using the inertial-fixed acceleration.
902
- W_ν_plus = W_ν_t0 + dt * W_ν̇_pr
903
-
904
- # Integrate the system position using the mixed velocity.
905
- q_plus = jnp.hstack(
906
- [
907
- # Note: here both ṗ and p̈ -> need mixed representation.
908
- W_p_B_t0 + dt * W_ṗ_B_t0 + 0.5 * dt**2 * BW_ν̇_pr2[0:3],
909
- jaxsim.math.Quaternion.integration(
910
- dt=dt,
911
- quaternion=W_Q_B_t0,
912
- omega=(W_ω_WB_t0 + 0.5 * dt * BW_ν̇_pr2[3:6]),
913
- omega_in_body_fixed=False,
914
- ).squeeze(),
915
- s_t0 + dt * ṡ_t0 + 0.5 * dt**2 * BW_ν̇_pr2[6:],
916
- ]
917
- )
918
-
919
- # Create the data at the final time.
920
- data_tf = data.copy()
921
- data_tf = data_tf.reset_joint_positions(q_plus[7:])
922
- data_tf = data_tf.reset_base_position(q_plus[0:3])
923
- data_tf = data_tf.reset_base_quaternion(q_plus[3:7])
924
- data_tf = data_tf.reset_joint_velocities(W_ν_plus[6:])
925
- data_tf = data_tf.reset_base_velocity(
926
- W_ν_plus[0:6], velocity_representation=jaxsim.VelRepr.Inertial
927
- )
928
-
929
- return data_tf.replace(
930
- velocity_representation=data.velocity_representation, validate=False
931
- )
932
-
933
-
934
- @jax.jit
935
- def step(
936
- model: js.model.JaxSimModel,
937
- data: js.data.JaxSimModelData,
938
- *,
939
- dt: jtp.FloatLike | None = None,
940
- link_forces: jtp.MatrixLike | None = None,
941
- joint_force_references: jtp.VectorLike | None = None,
942
- ) -> tuple[js.data.JaxSimModelData, dict[str, Any]]:
943
- """
944
- Step the system dynamics with the visco-elastic contact model.
945
-
946
- Args:
947
- model: The model to consider.
948
- data: The data of the considered model.
949
- dt: The time step to consider. If not specified, it is read from the model.
950
- link_forces:
951
- The 6D forces to apply to the links expressed in the frame corresponding to
952
- the velocity representation of `data`.
953
- joint_force_references: The joint force references to consider.
954
-
955
- Returns:
956
- A tuple containing the new data of the model
957
- and an empty dictionary of auxiliary data.
958
- """
959
-
960
- assert isinstance(model.contact_model, ViscoElasticContacts)
961
- assert isinstance(data.contacts_params, ViscoElasticContactsParams)
962
-
963
- # Compute the contact forces in inertial-fixed representation.
964
- # TODO: understand what's wrong in other representations.
965
- data_inertial_fixed = data.replace(
966
- velocity_representation=jaxsim.VelRepr.Inertial, validate=False
967
- )
968
-
969
- # Create the references object.
970
- references = js.references.JaxSimModelReferences.build(
971
- model=model,
972
- data=data,
973
- link_forces=link_forces,
974
- joint_force_references=joint_force_references,
975
- velocity_representation=data.velocity_representation,
976
- )
977
-
978
- # Initialize the time step.
979
- dt = dt if dt is not None else model.time_step
980
-
981
- # Compute the contact forces with the exponential integrator.
982
- W_f̅_C, aux_data = model.contact_model.compute_contact_forces(
983
- model=model,
984
- data=data_inertial_fixed,
985
- dt=jnp.array(dt).astype(float),
986
- link_forces=references.link_forces(model=model, data=data),
987
- joint_force_references=references.joint_force_references(model=model),
988
- )
989
-
990
- # Extract the final material deformation and the average of average forces
991
- # from the dictionary containing auxiliary data.
992
- m_tf = aux_data["m_tf"]
993
- W_f̿_C = aux_data["W_f_avg2_C"]
994
-
995
- # ===============================
996
- # Compute the link contact forces
997
- # ===============================
998
-
999
- # Get the link contact forces by summing the forces of contact points belonging
1000
- # to the same link.
1001
- W_f̅_L, W_f̿_L = jax.vmap(
1002
- lambda W_f_C: model.contact_model.link_forces_from_contact_forces(
1003
- model=model, data=data_inertial_fixed, contact_forces=W_f_C
1004
- )
1005
- )(jnp.stack([W_f̅_C, W_f̿_C]))
1006
-
1007
- # Compute the link transforms.
1008
- W_H_L = (
1009
- js.model.forward_kinematics(model=model, data=data)
1010
- if data.velocity_representation is not jaxsim.VelRepr.Inertial
1011
- else jnp.zeros(shape=(model.number_of_links(), 4, 4))
1012
- )
1013
-
1014
- # For integration purpose, we need the average of average forces expressed in
1015
- # mixed representation.
1016
- LW_f̿_L = jax.vmap(
1017
- lambda W_f_L, W_H_L: (
1018
- ModelDataWithVelocityRepresentation.inertial_to_other_representation(
1019
- array=W_f_L,
1020
- other_representation=jaxsim.VelRepr.Mixed,
1021
- transform=W_H_L,
1022
- is_force=True,
1023
- )
1024
- )
1025
- )(W_f̿_L, W_H_L)
1026
-
1027
- # ==========================
1028
- # Integrate the system state
1029
- # ==========================
1030
-
1031
- # Integrate the system dynamics using the average contact forces.
1032
- data_tf: js.data.JaxSimModelData = (
1033
- model.contact_model.integrate_data_with_average_contact_forces(
1034
- model=model,
1035
- data=data_inertial_fixed,
1036
- dt=dt,
1037
- link_forces=references.link_forces(model=model, data=data),
1038
- joint_force_references=references.joint_force_references(model=model),
1039
- average_link_contact_forces_inertial=W_f̅_L,
1040
- average_of_average_link_contact_forces_mixed=LW_f̿_L,
1041
- )
1042
- )
1043
-
1044
- # Store the tangential deformation at the final state.
1045
- # Note that this was integrated in the continuous time domain, therefore it should
1046
- # be much more accurate than the one computed with the discrete soft contacts.
1047
- with data_tf.mutable_context():
1048
-
1049
- # Extract the indices corresponding to the enabled collidable points.
1050
- # The visco-elastic contact model computed only their contact forces.
1051
- indices_of_enabled_collidable_points = (
1052
- model.kin_dyn_parameters.contact_parameters.indices_of_enabled_collidable_points
1053
- )
1054
-
1055
- data_tf.state.extended |= {
1056
- "tangential_deformation": data_tf.state.extended["tangential_deformation"]
1057
- .at[indices_of_enabled_collidable_points]
1058
- .set(m_tf)
1059
- }
1060
-
1061
- # Restore the original velocity representation.
1062
- data_tf = data_tf.replace(
1063
- velocity_representation=data.velocity_representation, validate=False
1064
- )
1065
-
1066
- return data_tf, {}