jaxsim 0.2.dev188__py3-none-any.whl → 0.6.1.dev2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (109) hide show
  1. jaxsim/__init__.py +73 -22
  2. jaxsim/_version.py +2 -2
  3. jaxsim/api/__init__.py +13 -1
  4. jaxsim/api/com.py +423 -0
  5. jaxsim/api/common.py +48 -19
  6. jaxsim/api/contact.py +604 -52
  7. jaxsim/api/data.py +308 -163
  8. jaxsim/api/frame.py +471 -0
  9. jaxsim/api/joint.py +166 -37
  10. jaxsim/api/kin_dyn_parameters.py +901 -0
  11. jaxsim/api/link.py +277 -78
  12. jaxsim/api/model.py +1572 -362
  13. jaxsim/api/ode.py +324 -129
  14. jaxsim/api/ode_data.py +401 -0
  15. jaxsim/api/references.py +216 -80
  16. jaxsim/exceptions.py +80 -0
  17. jaxsim/integrators/__init__.py +2 -2
  18. jaxsim/integrators/common.py +191 -107
  19. jaxsim/integrators/fixed_step.py +97 -102
  20. jaxsim/integrators/variable_step.py +706 -0
  21. jaxsim/logging.py +1 -2
  22. jaxsim/math/__init__.py +13 -0
  23. jaxsim/math/adjoint.py +57 -22
  24. jaxsim/math/cross.py +16 -7
  25. jaxsim/math/inertia.py +10 -8
  26. jaxsim/math/joint_model.py +289 -0
  27. jaxsim/math/quaternion.py +54 -20
  28. jaxsim/math/rotation.py +27 -21
  29. jaxsim/math/skew.py +16 -5
  30. jaxsim/math/transform.py +102 -0
  31. jaxsim/math/utils.py +31 -0
  32. jaxsim/mujoco/__init__.py +2 -1
  33. jaxsim/mujoco/loaders.py +216 -29
  34. jaxsim/mujoco/model.py +163 -33
  35. jaxsim/mujoco/utils.py +228 -0
  36. jaxsim/mujoco/visualizer.py +107 -22
  37. jaxsim/parsers/__init__.py +0 -1
  38. jaxsim/parsers/descriptions/__init__.py +8 -2
  39. jaxsim/parsers/descriptions/collision.py +87 -16
  40. jaxsim/parsers/descriptions/joint.py +80 -87
  41. jaxsim/parsers/descriptions/link.py +62 -24
  42. jaxsim/parsers/descriptions/model.py +101 -68
  43. jaxsim/parsers/kinematic_graph.py +607 -225
  44. jaxsim/parsers/rod/meshes.py +104 -0
  45. jaxsim/parsers/rod/parser.py +125 -82
  46. jaxsim/parsers/rod/utils.py +127 -82
  47. jaxsim/rbda/__init__.py +11 -0
  48. jaxsim/rbda/aba.py +289 -0
  49. jaxsim/rbda/collidable_points.py +156 -0
  50. jaxsim/rbda/contacts/__init__.py +13 -0
  51. jaxsim/rbda/contacts/common.py +313 -0
  52. jaxsim/rbda/contacts/relaxed_rigid.py +605 -0
  53. jaxsim/rbda/contacts/rigid.py +462 -0
  54. jaxsim/rbda/contacts/soft.py +480 -0
  55. jaxsim/rbda/contacts/visco_elastic.py +1066 -0
  56. jaxsim/rbda/crba.py +167 -0
  57. jaxsim/rbda/forward_kinematics.py +117 -0
  58. jaxsim/rbda/jacobian.py +330 -0
  59. jaxsim/rbda/rnea.py +235 -0
  60. jaxsim/rbda/utils.py +160 -0
  61. jaxsim/terrain/__init__.py +2 -0
  62. jaxsim/terrain/terrain.py +238 -0
  63. jaxsim/typing.py +24 -24
  64. jaxsim/utils/__init__.py +1 -4
  65. jaxsim/utils/jaxsim_dataclass.py +289 -34
  66. jaxsim/utils/tracing.py +5 -11
  67. jaxsim/utils/wrappers.py +159 -0
  68. {jaxsim-0.2.dev188.dist-info → jaxsim-0.6.1.dev2.dist-info}/LICENSE +1 -1
  69. jaxsim-0.6.1.dev2.dist-info/METADATA +465 -0
  70. jaxsim-0.6.1.dev2.dist-info/RECORD +74 -0
  71. {jaxsim-0.2.dev188.dist-info → jaxsim-0.6.1.dev2.dist-info}/WHEEL +1 -1
  72. jaxsim/high_level/__init__.py +0 -2
  73. jaxsim/high_level/common.py +0 -11
  74. jaxsim/high_level/joint.py +0 -148
  75. jaxsim/high_level/link.py +0 -259
  76. jaxsim/high_level/model.py +0 -1686
  77. jaxsim/math/conv.py +0 -114
  78. jaxsim/math/joint.py +0 -102
  79. jaxsim/math/plucker.py +0 -100
  80. jaxsim/physics/__init__.py +0 -12
  81. jaxsim/physics/algos/__init__.py +0 -0
  82. jaxsim/physics/algos/aba.py +0 -254
  83. jaxsim/physics/algos/aba_motors.py +0 -284
  84. jaxsim/physics/algos/crba.py +0 -154
  85. jaxsim/physics/algos/forward_kinematics.py +0 -79
  86. jaxsim/physics/algos/jacobian.py +0 -98
  87. jaxsim/physics/algos/rnea.py +0 -180
  88. jaxsim/physics/algos/rnea_motors.py +0 -196
  89. jaxsim/physics/algos/soft_contacts.py +0 -523
  90. jaxsim/physics/algos/terrain.py +0 -80
  91. jaxsim/physics/algos/utils.py +0 -69
  92. jaxsim/physics/model/__init__.py +0 -0
  93. jaxsim/physics/model/ground_contact.py +0 -55
  94. jaxsim/physics/model/physics_model.py +0 -388
  95. jaxsim/physics/model/physics_model_state.py +0 -283
  96. jaxsim/simulation/__init__.py +0 -4
  97. jaxsim/simulation/integrators.py +0 -393
  98. jaxsim/simulation/ode.py +0 -290
  99. jaxsim/simulation/ode_data.py +0 -96
  100. jaxsim/simulation/ode_integration.py +0 -62
  101. jaxsim/simulation/simulator.py +0 -543
  102. jaxsim/simulation/simulator_callbacks.py +0 -79
  103. jaxsim/simulation/utils.py +0 -15
  104. jaxsim/sixd/__init__.py +0 -2
  105. jaxsim/utils/oop.py +0 -536
  106. jaxsim/utils/vmappable.py +0 -117
  107. jaxsim-0.2.dev188.dist-info/METADATA +0 -184
  108. jaxsim-0.2.dev188.dist-info/RECORD +0 -81
  109. {jaxsim-0.2.dev188.dist-info → jaxsim-0.6.1.dev2.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,605 @@
1
+ from __future__ import annotations
2
+
3
+ import dataclasses
4
+ from collections.abc import Callable
5
+ from typing import Any
6
+
7
+ import jax
8
+ import jax.numpy as jnp
9
+ import jax_dataclasses
10
+ import optax
11
+
12
+ import jaxsim.api as js
13
+ import jaxsim.rbda.contacts
14
+ import jaxsim.typing as jtp
15
+ from jaxsim.api.common import ModelDataWithVelocityRepresentation, VelRepr
16
+
17
+ from . import common
18
+
19
+ try:
20
+ from typing import Self
21
+ except ImportError:
22
+ from typing_extensions import Self
23
+
24
+
25
@jax_dataclasses.pytree_dataclass
class RelaxedRigidContactsParams(common.ContactsParams):
    """
    Parameters of the relaxed rigid contacts model.

    Every field is a float JAX array produced by its `default_factory`.
    Use `build` to create an instance: unspecified fields fall back to
    those defaults.
    """

    # Time constant
    time_constant: jtp.Float = dataclasses.field(
        default_factory=lambda: jnp.array(0.01, dtype=float)
    )

    # Adimensional damping coefficient
    damping_coefficient: jtp.Float = dataclasses.field(
        default_factory=lambda: jnp.array(1.0, dtype=float)
    )

    # Minimum impedance
    d_min: jtp.Float = dataclasses.field(
        default_factory=lambda: jnp.array(0.9, dtype=float)
    )

    # Maximum impedance
    d_max: jtp.Float = dataclasses.field(
        default_factory=lambda: jnp.array(0.95, dtype=float)
    )

    # Width
    width: jtp.Float = dataclasses.field(
        default_factory=lambda: jnp.array(0.0001, dtype=float)
    )

    # Midpoint
    midpoint: jtp.Float = dataclasses.field(
        default_factory=lambda: jnp.array(0.1, dtype=float)
    )

    # Power exponent
    power: jtp.Float = dataclasses.field(
        default_factory=lambda: jnp.array(1.0, dtype=float)
    )

    # Stiffness
    stiffness: jtp.Float = dataclasses.field(
        default_factory=lambda: jnp.array(0.0, dtype=float)
    )

    # Damping
    damping: jtp.Float = dataclasses.field(
        default_factory=lambda: jnp.array(0.0, dtype=float)
    )

    # Friction coefficient
    mu: jtp.Float = dataclasses.field(
        default_factory=lambda: jnp.array(0.5, dtype=float)
    )

    def __hash__(self) -> int:
        """Hash the parameters by the content of all their arrays."""

        from jaxsim.utils.wrappers import HashedNumpyArray

        return hash(
            (
                HashedNumpyArray(self.time_constant),
                HashedNumpyArray(self.damping_coefficient),
                HashedNumpyArray(self.d_min),
                HashedNumpyArray(self.d_max),
                HashedNumpyArray(self.width),
                HashedNumpyArray(self.midpoint),
                HashedNumpyArray(self.power),
                HashedNumpyArray(self.stiffness),
                HashedNumpyArray(self.damping),
                HashedNumpyArray(self.mu),
            )
        )

    def __eq__(self, other: object) -> bool:
        """
        Compare two parameter sets through their content-based hash.

        Objects of other types are not hashed: `NotImplemented` is returned so
        that Python falls back to its default comparison instead of raising on
        unhashable operands or comparing unrelated hashes.
        """

        if not isinstance(other, RelaxedRigidContactsParams):
            return NotImplemented

        return hash(self) == hash(other)

    @classmethod
    def build(
        cls: type[Self],
        *,
        time_constant: jtp.FloatLike | None = None,
        damping_coefficient: jtp.FloatLike | None = None,
        d_min: jtp.FloatLike | None = None,
        d_max: jtp.FloatLike | None = None,
        width: jtp.FloatLike | None = None,
        midpoint: jtp.FloatLike | None = None,
        power: jtp.FloatLike | None = None,
        stiffness: jtp.FloatLike | None = None,
        damping: jtp.FloatLike | None = None,
        mu: jtp.FloatLike | None = None,
    ) -> Self:
        """
        Create a `RelaxedRigidContactsParams` instance.

        Each argument sets the homonymous field. When an argument is `None`,
        the default produced by the field's `default_factory` is used. All
        values are converted to float JAX arrays.

        Returns:
            The `RelaxedRigidContactsParams` instance.
        """

        def value_or_default(name: str, value: jtp.FloatLike | None) -> jax.Array:
            # Fall back to the default factory declared on the dataclass field.
            if value is None:
                value = cls.__dataclass_fields__[name].default_factory()

            return jnp.array(value, dtype=float)

        return cls(
            time_constant=value_or_default("time_constant", time_constant),
            damping_coefficient=value_or_default(
                "damping_coefficient", damping_coefficient
            ),
            d_min=value_or_default("d_min", d_min),
            d_max=value_or_default("d_max", d_max),
            width=value_or_default("width", width),
            midpoint=value_or_default("midpoint", midpoint),
            power=value_or_default("power", power),
            stiffness=value_or_default("stiffness", stiffness),
            damping=value_or_default("damping", damping),
            mu=value_or_default("mu", mu),
        )

    def valid(self) -> jtp.BoolLike:
        """
        Check if the parameters are valid.

        Note:
            The result is converted to a Python `bool`, therefore this check
            requires concrete (non-traced) parameter values.

        Returns:
            `True` if all the parameters are within their admissible ranges.
        """

        return bool(
            jnp.all(self.time_constant >= 0.0)
            and jnp.all(self.damping_coefficient > 0.0)
            and jnp.all(self.d_min >= 0.0)
            and jnp.all(self.d_max <= 1.0)
            and jnp.all(self.d_min <= self.d_max)
            # The width is used as a divisor when computing the impedance of
            # the relaxed rigid model: a zero width yields non-finite values.
            and jnp.all(self.width > 0.0)
            and jnp.all(self.midpoint >= 0.0)
            and jnp.all(self.power >= 0.0)
            and jnp.all(self.mu >= 0.0)
        )
176
+
177
+
178
@jax_dataclasses.pytree_dataclass
class RelaxedRigidContacts(common.ContactModel):
    """
    Relaxed rigid contacts model.

    The 3D linear contact forces of the collidable points are obtained by
    minimizing, with the L-BFGS solver of optax, the squared residual of a
    regularized linear problem built from the Delassus matrix of the points.
    """

    # The solver options are stored as two static tuples (keys and values),
    # rather than as a dict, so that they stay hashable.
    _solver_options_keys: jax_dataclasses.Static[tuple[str, ...]] = dataclasses.field(
        default=("tol", "maxiter", "memory_size"), kw_only=True
    )
    _solver_options_values: jax_dataclasses.Static[tuple[Any, ...]] = dataclasses.field(
        default=(1e-6, 50, 10), kw_only=True
    )

    @property
    def solver_options(self) -> dict[str, Any]:
        """Get the solver options as a newly created dictionary."""

        return dict(
            zip(
                self._solver_options_keys,
                self._solver_options_values,
                strict=True,
            )
        )

    @classmethod
    def build(
        cls: type[Self],
        solver_options: dict[str, Any] | None = None,
        **kwargs,
    ) -> Self:
        """
        Create a `RelaxedRigidContacts` instance with specified parameters.

        Args:
            solver_options:
                The options to pass to the L-BFGS solver. They are merged on top
                of the default options (`tol`, `maxiter`, `memory_size`).
                `tol` and `maxiter` define the stopping criterion, all the other
                options are forwarded to `optax.lbfgs`.
            **kwargs: The parameters of the relaxed rigid contacts model.

        Returns:
            The `RelaxedRigidContacts` instance.

        Raises:
            ValueError: If the values of the solver options are not hashable.
        """

        # Get the default solver options.
        default_solver_options = dict(
            zip(cls._solver_options_keys, cls._solver_options_values, strict=True)
        )

        # Create the solver options to set by combining the default solver options
        # with the user-provided solver options.
        solver_options = default_solver_options | (
            solver_options if solver_options is not None else {}
        )

        # Make sure that the solver options are hashable.
        # We need to check this because the solver options are static.
        try:
            hash(tuple(solver_options.values()))
        except TypeError as exc:
            raise ValueError(
                "The values of the solver options must be hashable."
            ) from exc

        return cls(
            _solver_options_keys=tuple(solver_options.keys()),
            _solver_options_values=tuple(solver_options.values()),
            **kwargs,
        )

    @jax.jit
    def compute_contact_forces(
        self,
        model: js.model.JaxSimModel,
        data: js.data.JaxSimModelData,
        *,
        link_forces: jtp.MatrixLike | None = None,
        joint_force_references: jtp.VectorLike | None = None,
    ) -> tuple[jtp.Matrix, dict[str, jtp.PyTree]]:
        """
        Compute the contact forces.

        Args:
            model: The model to consider.
            data: The data of the considered model.
            link_forces:
                Optional `(n_links, 6)` matrix of external forces acting on the links,
                expressed in the same representation of data.
            joint_force_references:
                Optional `(n_joints,)` vector of joint forces.

        Returns:
            A tuple containing as first element the 6D contact forces of the
            collidable points in inertial-fixed representation, and as second
            element an empty dictionary of auxiliary data.
        """

        link_forces = jnp.atleast_2d(
            jnp.array(link_forces, dtype=float).squeeze()
            if link_forces is not None
            else jnp.zeros((model.number_of_links(), 6))
        )

        joint_force_references = jnp.atleast_1d(
            jnp.array(joint_force_references, dtype=float).squeeze()
            if joint_force_references is not None
            else jnp.zeros(model.number_of_joints())
        )

        references = js.references.JaxSimModelReferences.build(
            model=model,
            data=data,
            velocity_representation=data.velocity_representation,
            link_forces=link_forces,
            joint_force_references=joint_force_references,
        )

        # Compute the position and linear velocities (mixed representation) of
        # all collidable points belonging to the robot.
        position, velocity = js.contact.collidable_point_kinematics(
            model=model, data=data
        )

        # Compute the penetration depth and velocity of the collidable points.
        # Note that this function considers the penetration in the normal direction.
        δ, _, n̂ = jax.vmap(common.compute_penetration_data, in_axes=(0, 0, None))(
            position, velocity, model.terrain
        )

        # Compute the position in the constraint frame.
        position_constraint = jax.vmap(lambda δ, n̂: -δ * n̂)(δ, n̂)

        # Compute the transforms of the implicit frames corresponding to the
        # collidable points.
        W_H_C = js.contact.transforms(model=model, data=data)

        with (
            references.switch_velocity_representation(VelRepr.Mixed),
            data.switch_velocity_representation(VelRepr.Mixed),
        ):

            BW_ν = data.generalized_velocity()

            BW_ν̇_free = jnp.hstack(
                js.ode.system_acceleration(
                    model=model,
                    data=data,
                    link_forces=references.link_forces(model=model, data=data),
                    joint_force_references=references.joint_force_references(
                        model=model
                    ),
                )
            )

            M = js.model.free_floating_mass_matrix(model=model, data=data)

            # Linear part of the contact Jacobians, zeroed for the points
            # that are not penetrating the terrain.
            Jl_WC = jnp.vstack(
                jax.vmap(lambda J, δ: J * (δ > 0))(
                    js.contact.jacobian(model=model, data=data)[:, :3, :], δ
                )
            )

            J̇_WC = jnp.vstack(
                jax.vmap(lambda J̇, δ: J̇ * (δ > 0))(
                    js.contact.jacobian_derivative(model=model, data=data)[:, :3], δ
                ),
            )

        # Compute the regularization terms.
        a_ref, R, *_ = self._regularizers(
            model=model,
            position_constraint=position_constraint,
            velocity_constraint=velocity,
            parameters=data.contacts_params,
        )

        # Compute the Delassus matrix and the free mixed linear acceleration of
        # the collidable points.
        G = Jl_WC @ jnp.linalg.pinv(M) @ Jl_WC.T
        CW_al_free_WC = Jl_WC @ BW_ν̇_free + J̇_WC @ BW_ν

        # Calculate quantities for the linear optimization problem.
        A = G + R
        b = CW_al_free_WC - a_ref

        # Objective to minimize: squared residual of the linear problem,
        # computed from the optimized variables x.
        def objective(x: jtp.Vector, A: jtp.Matrix, b: jtp.Vector) -> jtp.Float:
            return jnp.sum(jnp.square(A @ x + b))

        # ========================================
        # Helper function to run the L-BFGS solver
        # ========================================

        def run_optimization(
            init_params: jtp.Vector,
            fun: Callable,
            opt: optax.GradientTransformationExtraArgs,
            maxiter: int,
            tol: float,
        ) -> tuple[jtp.Vector, optax.OptState]:
            """
            Minimize `fun` starting from `init_params`.

            At least one step is always performed. Afterwards, the loop stops
            when `maxiter` iterations are reached or when the L2 norm of the
            gradient stored in the optimizer state drops below `tol`.
            """

            # Get the function to compute the loss and the gradient w.r.t. its inputs.
            value_and_grad_fn = optax.value_and_grad_from_state(fun)

            # Initialize the carry of the following loop.
            OptimizationCarry = tuple[jtp.Vector, optax.OptState]
            init_carry: OptimizationCarry = (init_params, opt.init(params=init_params))

            def step(carry: OptimizationCarry) -> OptimizationCarry:

                params, state = carry

                value, grad = value_and_grad_fn(
                    params,
                    state=state,
                    A=A,
                    b=b,
                )

                updates, state = opt.update(
                    updates=grad,
                    state=state,
                    params=params,
                    value=value,
                    grad=grad,
                    value_fn=fun,
                    A=A,
                    b=b,
                )

                params = optax.apply_updates(params, updates)

                return params, state

            # TODO: maybe fix the number of iterations and switch to scan?
            def continuing_criterion(carry: OptimizationCarry) -> jtp.Bool:

                _, state = carry

                iter_num = optax.tree_utils.tree_get(state, "count")
                grad = optax.tree_utils.tree_get(state, "grad")
                err = optax.tree_utils.tree_l2_norm(grad)

                return (iter_num == 0) | ((iter_num < maxiter) & (err >= tol))

            final_params, final_state = jax.lax.while_loop(
                continuing_criterion, step, init_carry
            )

            return final_params, final_state

        # ======================================
        # Compute the contact forces with L-BFGS
        # ======================================

        # Initialize the optimized forces with a linear Hunt/Crossley model.
        init_params = jax.vmap(
            lambda p, v: jaxsim.rbda.contacts.SoftContacts.hunt_crossley_contact_model(
                position=p,
                velocity=v,
                terrain=model.terrain,
                K=1e6,
                D=2e3,
                p=0.5,
                q=0.5,
                # No tangential initial forces.
                mu=0.0,
                tangential_deformation=jnp.zeros(3),
            )[0]
        )(position, velocity).flatten()

        # Get the solver options (a fresh dict, safe to mutate).
        solver_options = self.solver_options

        # Extract the options corresponding to the convergence criteria.
        # All the remaining options are passed to the solver.
        tol = solver_options.pop("tol")
        maxiter = solver_options.pop("maxiter")

        # Compute the 3D linear force in C[W] frame.
        solution, _ = run_optimization(
            init_params=init_params,
            fun=objective,
            opt=optax.lbfgs(**solver_options),
            tol=tol,
            maxiter=maxiter,
        )

        # Reshape the optimized solution to be a matrix of 3D contact forces.
        CW_fl_C = solution.reshape(-1, 3)

        # Convert the contact forces from mixed to inertial-fixed representation.
        W_f_C = jax.vmap(
            lambda CW_fl_C, W_H_C: (
                ModelDataWithVelocityRepresentation.other_representation_to_inertial(
                    array=jnp.zeros(6).at[0:3].set(CW_fl_C),
                    transform=W_H_C,
                    other_representation=VelRepr.Mixed,
                    is_force=True,
                )
            ),
        )(CW_fl_C, W_H_C)

        return W_f_C, {}

    @staticmethod
    def _regularizers(
        model: js.model.JaxSimModel,
        position_constraint: jtp.Vector,
        velocity_constraint: jtp.Vector,
        parameters: RelaxedRigidContactsParams,
    ) -> tuple:
        """
        Compute the reference acceleration and the regularization terms.

        Args:
            model: The jaxsim model.
            position_constraint: The position of the collidable points in the constraint frame.
            velocity_constraint: The velocity of the collidable points in the constraint frame.
            parameters: The parameters of the relaxed rigid contacts model.

        Returns:
            A tuple containing the reference acceleration, the regularization matrix,
            the stiffness, and the damping.
        """

        # Extract the parameters of the contact model.
        # The user-provided stiffness and damping are only used when negative
        # (see the Baumgarte-like branch below).
        Ω, ζ, ξ_min, ξ_max, width, mid, p, K_user, D_user, μ = (
            getattr(parameters, field)
            for field in (
                "time_constant",
                "damping_coefficient",
                "d_min",
                "d_max",
                "width",
                "midpoint",
                "power",
                "stiffness",
                "damping",
                "mu",
            )
        )

        # Get the indices of the enabled collidable points.
        indices_of_enabled_collidable_points = (
            model.kin_dyn_parameters.contact_parameters.indices_of_enabled_collidable_points
        )

        parent_link_idx_of_enabled_collidable_points = jnp.array(
            model.kin_dyn_parameters.contact_parameters.body, dtype=int
        )[indices_of_enabled_collidable_points]

        # Compute the 6D inertia matrices of all links.
        M_L = js.model.link_spatial_inertia_matrices(model=model)

        def imp_aref(
            pos: jtp.Vector,
            vel: jtp.Vector,
        ) -> tuple[jtp.Vector, jtp.Vector, jtp.Vector, jtp.Vector]:
            """
            Calculate impedance and offset acceleration in constraint frame.

            Args:
                pos: position in constraint frame.
                vel: velocity in constraint frame.

            Returns:
                ξ: computed impedance
                a_ref: offset acceleration in constraint frame
                K: computed stiffness
                D: computed damping
            """

            imp_x = jnp.abs(pos) / width

            imp_a = (1.0 / jnp.power(mid, p - 1)) * jnp.power(imp_x, p)
            imp_b = 1 - (1.0 / jnp.power(1 - mid, p - 1)) * jnp.power(1 - imp_x, p)
            imp_y = jnp.where(imp_x < mid, imp_a, imp_b)

            # Compute the impedance.
            ξ = ξ_min + imp_y * (ξ_max - ξ_min)
            ξ = jnp.clip(ξ, ξ_min, ξ_max)
            ξ = jnp.where(imp_x > 1.0, ξ_max, ξ)

            # Compute the spring and damper parameters during runtime from the
            # impedance and other contact parameters.
            K_from_time_constant = 1 / (ξ_max * Ω * ζ) ** 2
            D_from_time_constant = 2 / (ξ_max * Ω)

            # If the user specifies K and D and they are negative, the computed `a_ref`
            # becomes something more similar to a classic Baumgarte regularization.
            # The sign check must be done on the user-provided values: the ones
            # derived from the time constant are always positive.
            K = jnp.where(K_user < 0, -K_user / ξ_max**2, K_from_time_constant)
            D = jnp.where(D_user < 0, -D_user / ξ_max, D_from_time_constant)

            # Compute the reference acceleration.
            a_ref = -(D * vel + K * ξ * pos)

            return ξ, a_ref, K, D

        def compute_row(
            *,
            link_idx: jtp.Int,
            pos: jtp.Vector,
            vel: jtp.Vector,
        ) -> tuple[jtp.Vector, jtp.Matrix, jtp.Vector, jtp.Vector]:

            # Compute the reference acceleration.
            ξ, a_ref, K, D = imp_aref(pos=pos, vel=vel)

            # Compute the regularization term.
            R = (
                (2 * μ**2 * (1 - ξ) / (ξ + 1e-12))
                * (1 + μ**2)
                @ jnp.linalg.inv(M_L[link_idx, :3, :3])
            )

            # Return the computed values, setting them to zero in case of no contact.
            is_active = (pos.dot(pos) > 0).astype(float)
            return jax.tree.map(
                lambda x: jnp.atleast_1d(x) * is_active, (a_ref, R, K, D)
            )

        a_ref, R, K, D = jax.tree.map(
            f=jnp.concatenate,
            tree=(
                *jax.vmap(compute_row)(
                    link_idx=parent_link_idx_of_enabled_collidable_points,
                    pos=position_constraint,
                    vel=velocity_constraint,
                ),
            ),
        )

        return a_ref, jnp.diag(R), K, D