jaxsim 0.4.3.dev181__py3-none-any.whl → 0.4.3.dev200__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -12,11 +12,10 @@ import optax
12
12
  import jaxsim.api as js
13
13
  import jaxsim.typing as jtp
14
14
  from jaxsim import logging
15
- from jaxsim.api.common import VelRepr
16
- from jaxsim.math import Adjoint
15
+ from jaxsim.api.common import ModelDataWithVelocityRepresentation, VelRepr
17
16
  from jaxsim.terrain.terrain import FlatTerrain, Terrain
18
17
 
19
- from .common import ContactModel, ContactsParams
18
+ from . import common
20
19
 
21
20
  try:
22
21
  from typing import Self
@@ -25,7 +24,7 @@ except ImportError:
25
24
 
26
25
 
27
26
  @jax_dataclasses.pytree_dataclass
28
- class RelaxedRigidContactsParams(ContactsParams):
27
+ class RelaxedRigidContactsParams(common.ContactsParams):
29
28
  """Parameters of the relaxed rigid contacts model."""
30
29
 
31
30
  # Time constant
@@ -116,14 +115,24 @@ class RelaxedRigidContactsParams(ContactsParams):
116
115
  ) -> Self:
117
116
  """Create a `RelaxedRigidContactsParams` instance"""
118
117
 
118
+ def default(name: str):
119
+ return cls.__dataclass_fields__[name].default_factory()
120
+
119
121
  return cls(
120
- **{
121
- field: jnp.array(locals().get(field, default), dtype=default.dtype)
122
- for field, default in map(
123
- lambda f: (f, cls.__dataclass_fields__[f].default),
124
- filter(lambda f: f != "__mutability__", cls.__dataclass_fields__),
125
- )
126
- }
122
+ time_constant=jnp.array(
123
+ time_constant or default("time_constant"), dtype=float
124
+ ),
125
+ damping_coefficient=jnp.array(
126
+ damping_coefficient or default("damping_coefficient"), dtype=float
127
+ ),
128
+ d_min=jnp.array(d_min or default("d_min"), dtype=float),
129
+ d_max=jnp.array(d_max or default("d_max"), dtype=float),
130
+ width=jnp.array(width or default("width"), dtype=float),
131
+ midpoint=jnp.array(midpoint or default("midpoint"), dtype=float),
132
+ power=jnp.array(power or default("power"), dtype=float),
133
+ stiffness=jnp.array(stiffness or default("stiffness"), dtype=float),
134
+ damping=jnp.array(damping or default("damping"), dtype=float),
135
+ mu=jnp.array(mu or default("mu"), dtype=float),
127
136
  )
128
137
 
129
138
  def valid(self) -> jtp.BoolLike:
@@ -142,7 +151,7 @@ class RelaxedRigidContactsParams(ContactsParams):
142
151
 
143
152
 
144
153
  @jax_dataclasses.pytree_dataclass
145
- class RelaxedRigidContacts(ContactModel):
154
+ class RelaxedRigidContacts(common.ContactModel):
146
155
  """Relaxed rigid contacts model."""
147
156
 
148
157
  parameters: RelaxedRigidContactsParams = dataclasses.field(
@@ -229,7 +238,7 @@ class RelaxedRigidContacts(ContactModel):
229
238
  *,
230
239
  link_forces: jtp.MatrixLike | None = None,
231
240
  joint_force_references: jtp.VectorLike | None = None,
232
- ) -> tuple[jtp.Vector, tuple[Any, ...]]:
241
+ ) -> tuple[jtp.Matrix, tuple]:
233
242
  """
234
243
  Compute the contact forces.
235
244
 
@@ -243,22 +252,23 @@ class RelaxedRigidContacts(ContactModel):
243
252
  Optional `(n_joints,)` vector of joint forces.
244
253
 
245
254
  Returns:
246
- A tuple containing the contact forces.
255
+ A tuple containing as first element the computed contact forces.
247
256
  """
248
257
 
249
258
  # Initialize the model and data this contact model is operating on.
250
259
  # This will raise an exception if either the contact model or the
251
260
  # contact parameters are not compatible.
252
261
  model, data = self.initialize_model_and_data(model=model, data=data)
262
+ assert isinstance(data.contacts_params, RelaxedRigidContactsParams)
253
263
 
254
- link_forces = (
255
- link_forces
264
+ link_forces = jnp.atleast_2d(
265
+ jnp.array(link_forces, dtype=float).squeeze()
256
266
  if link_forces is not None
257
267
  else jnp.zeros((model.number_of_links(), 6))
258
268
  )
259
269
 
260
- joint_force_references = (
261
- joint_force_references
270
+ joint_force_references = jnp.atleast_1d(
271
+ jnp.array(joint_force_references, dtype=float).squeeze()
262
272
  if joint_force_references is not None
263
273
  else jnp.zeros(model.number_of_joints())
264
274
  )
@@ -271,10 +281,10 @@ class RelaxedRigidContacts(ContactModel):
271
281
  joint_force_references=joint_force_references,
272
282
  )
273
283
 
274
- def _detect_contact(x: jtp.Array, y: jtp.Array, z: jtp.Array) -> jtp.Array:
284
+ def detect_contact(x: jtp.Array, y: jtp.Array, z: jtp.Array) -> jtp.Array:
275
285
  x, y, z = jax.tree.map(jnp.squeeze, (x, y, z))
276
286
 
277
- n̂ = self.terrain.normal(x=x, y=y).squeeze()
287
+ n̂ = model.terrain.normal(x=x, y=y).squeeze()
278
288
  h = jnp.array([0, 0, z - model.terrain.height(x=x, y=y)])
279
289
 
280
290
  return jnp.dot(h, n̂)
@@ -286,19 +296,19 @@ class RelaxedRigidContacts(ContactModel):
286
296
  )
287
297
 
288
298
  # Compute the activation state of the collidable points
289
- δ = jax.vmap(_detect_contact)(*position.T)
299
+ δ = jax.vmap(detect_contact)(*position.T)
300
+
301
+ # Compute the transforms of the implicit frames corresponding to the
302
+ # collidable points.
303
+ W_H_C = js.contact.transforms(model=model, data=data)
290
304
 
291
305
  with (
292
306
  references.switch_velocity_representation(VelRepr.Mixed),
293
307
  data.switch_velocity_representation(VelRepr.Mixed),
294
308
  ):
295
- M = js.model.free_floating_mass_matrix(model=model, data=data)
296
- Jl_WC = jnp.vstack(
297
- jax.vmap(lambda J, height: J * (height < 0))(
298
- js.contact.jacobian(model=model, data=data)[:, :3, :], δ
299
- )
300
- )
301
- W_H_C = js.contact.transforms(model=model, data=data)
309
+
310
+ BW_ν = data.generalized_velocity()
311
+
302
312
  BW_ν̇_free = jnp.hstack(
303
313
  js.ode.system_acceleration(
304
314
  model=model,
@@ -309,20 +319,31 @@ class RelaxedRigidContacts(ContactModel):
309
319
  ),
310
320
  )
311
321
  )
312
- BW_ν = data.generalized_velocity()
322
+
323
+ M = js.model.free_floating_mass_matrix(model=model, data=data)
324
+
325
+ Jl_WC = jnp.vstack(
326
+ jax.vmap(lambda J, height: J * (height < 0))(
327
+ js.contact.jacobian(model=model, data=data)[:, :3, :], δ
328
+ )
329
+ )
330
+
313
331
  J̇_WC = jnp.vstack(
314
332
  jax.vmap(lambda J̇, height: J̇ * (height < 0))(
315
333
  js.contact.jacobian_derivative(model=model, data=data)[:, :3], δ
316
334
  ),
317
335
  )
318
336
 
319
- a_ref, R, K, D = self._regularizers(
320
- model=model,
321
- penetration=δ,
322
- velocity=velocity,
323
- parameters=self.parameters,
324
- )
337
+ # Compute the regularization terms.
338
+ a_ref, R, K, D = self._regularizers(
339
+ model=model,
340
+ penetration=δ,
341
+ velocity=velocity,
342
+ parameters=data.contacts_params,
343
+ )
325
344
 
345
+ # Compute the Delassus matrix and the free mixed linear acceleration of
346
+ # the collidable points.
326
347
  G = Jl_WC @ jnp.linalg.lstsq(M, Jl_WC.T)[0]
327
348
  CW_al_free_WC = Jl_WC @ BW_ν̇_free + J̇_WC @ BW_ν
328
349
 
@@ -330,26 +351,40 @@ class RelaxedRigidContacts(ContactModel):
330
351
  A = G + R
331
352
  b = CW_al_free_WC - a_ref
332
353
 
354
+ # Create the objective function to minimize as a lambda computing the cost
355
+ # from the optimized variables x.
333
356
  objective = lambda x, A, b: jnp.sum(jnp.square(A @ x + b))
334
357
 
358
+ # ========================================
359
+ # Helper function to run the L-BFGS solver
360
+ # ========================================
361
+
335
362
  def run_optimization(
336
- init_params: jtp.Array,
363
+ init_params: jtp.Vector,
337
364
  fun: Callable,
338
- opt: optax.GradientTransformation,
339
- maxiter: jtp.Int,
340
- tol: jtp.Float,
341
- **kwargs,
342
- ):
365
+ opt: optax.GradientTransformationExtraArgs,
366
+ maxiter: int,
367
+ tol: float,
368
+ ) -> tuple[jtp.Vector, optax.OptState]:
369
+
370
+ # Get the function to compute the loss and the gradient w.r.t. its inputs.
343
371
  value_and_grad_fn = optax.value_and_grad_from_state(fun)
344
372
 
345
- def step(carry):
373
+ # Initialize the carry of the following loop.
374
+ OptimizationCarry = tuple[jtp.Vector, optax.OptState]
375
+ init_carry: OptimizationCarry = (init_params, opt.init(params=init_params))
376
+
377
+ def step(carry: OptimizationCarry) -> OptimizationCarry:
378
+
346
379
  params, state = carry
380
+
347
381
  value, grad = value_and_grad_fn(
348
382
  params,
349
383
  state=state,
350
384
  A=A,
351
385
  b=b,
352
386
  )
387
+
353
388
  updates, state = opt.update(
354
389
  updates=grad,
355
390
  state=state,
@@ -360,22 +395,32 @@ class RelaxedRigidContacts(ContactModel):
360
395
  A=A,
361
396
  b=b,
362
397
  )
398
+
363
399
  params = optax.apply_updates(params, updates)
400
+
364
401
  return params, state
365
402
 
366
- def continuing_criterion(carry):
403
+ def continuing_criterion(carry: OptimizationCarry) -> jtp.Bool:
404
+
367
405
  _, state = carry
406
+
368
407
  iter_num = optax.tree_utils.tree_get(state, "count")
369
408
  grad = optax.tree_utils.tree_get(state, "grad")
370
409
  err = optax.tree_utils.tree_l2_norm(grad)
410
+
371
411
  return (iter_num == 0) | ((iter_num < maxiter) & (err >= tol))
372
412
 
373
- init_carry = (init_params, opt.init(init_params))
374
413
  final_params, final_state = jax.lax.while_loop(
375
414
  continuing_criterion, step, init_carry
376
415
  )
416
+
377
417
  return final_params, final_state
378
418
 
419
+ # ======================================
420
+ # Compute the contact forces with L-BFGS
421
+ # ======================================
422
+
423
+ # Initialize the optimized forces with a linear Hunt/Crossley model.
379
424
  init_params = (
380
425
  K[:, jnp.newaxis] * jnp.zeros_like(position).at[:, 2].set(δ)
381
426
  + D[:, jnp.newaxis] * velocity
@@ -390,28 +435,30 @@ class RelaxedRigidContacts(ContactModel):
390
435
  maxiter = solver_options.pop("maxiter")
391
436
 
392
437
  # Compute the 3D linear force in C[W] frame.
393
- CW_f_Ci, _ = run_optimization(
438
+ solution, _ = run_optimization(
394
439
  init_params=init_params,
395
- A=A,
396
- b=b,
397
- maxiter=maxiter,
398
- opt=optax.lbfgs(**solver_options),
399
440
  fun=objective,
441
+ opt=optax.lbfgs(**solver_options),
400
442
  tol=tol,
443
+ maxiter=maxiter,
401
444
  )
402
445
 
403
- CW_f_Ci = CW_f_Ci.reshape((-1, 3))
404
-
405
- def mixed_to_inertial(W_H_C: jax.Array, CW_fl: jax.Array) -> jax.Array:
406
- W_Xf_CW = Adjoint.from_transform(
407
- W_H_C.at[0:3, 0:3].set(jnp.eye(3)),
408
- inverse=True,
409
- ).T
410
- return W_Xf_CW @ jnp.hstack([CW_fl, jnp.zeros(3)])
411
-
412
- W_f_C = jax.vmap(mixed_to_inertial)(W_H_C, CW_f_Ci)
446
+ # Reshape the optimized solution to be a matrix of 3D contact forces.
447
+ CW_fl_C = solution.reshape(-1, 3)
448
+
449
+ # Convert the contact forces from mixed to inertial-fixed representation.
450
+ W_f_C = jax.vmap(
451
+ lambda CW_fl_C, W_H_C: (
452
+ ModelDataWithVelocityRepresentation.other_representation_to_inertial(
453
+ array=jnp.zeros(6).at[0:3].set(CW_fl_C),
454
+ transform=W_H_C,
455
+ other_representation=VelRepr.Mixed,
456
+ is_force=True,
457
+ )
458
+ ),
459
+ )(CW_fl_C, W_H_C)
413
460
 
414
- return W_f_C, (None,)
461
+ return W_f_C, ()
415
462
 
416
463
  @staticmethod
417
464
  def _regularizers(
@@ -433,13 +480,28 @@ class RelaxedRigidContacts(ContactModel):
433
480
  A tuple containing the reference acceleration, the regularization matrix, the stiffness, and the damping.
434
481
  """
435
482
 
436
- Ω, ζ, ξ_min, ξ_max, width, mid, p, K, D, μ, *_ = jax_dataclasses.astuple(
437
- parameters
483
+ # Extract the parameters of the contact model.
484
+ Ω, ζ, ξ_min, ξ_max, width, mid, p, K, D, μ = (
485
+ getattr(parameters, field)
486
+ for field in (
487
+ "time_constant",
488
+ "damping_coefficient",
489
+ "d_min",
490
+ "d_max",
491
+ "width",
492
+ "midpoint",
493
+ "power",
494
+ "stiffness",
495
+ "damping",
496
+ "mu",
497
+ )
438
498
  )
439
499
 
440
- def _imp_aref(
441
- penetration: jtp.Array,
442
- velocity: jtp.Array,
500
+ # Compute the 6D inertia matrices of all links.
501
+ M_L = js.model.link_spatial_inertia_matrices(model=model)
502
+
503
+ def imp_aref(
504
+ penetration: jtp.Array, velocity: jtp.Array
443
505
  ) -> tuple[jtp.Array, jtp.Array]:
444
506
  """
445
507
  Calculates impedance and offset acceleration in constraint frame.
@@ -474,7 +536,7 @@ class RelaxedRigidContacts(ContactModel):
474
536
 
475
537
  return imp, a_ref, jnp.atleast_1d(K_f), jnp.atleast_1d(D_f)
476
538
 
477
- def _compute_row(
539
+ def compute_row(
478
540
  *,
479
541
  link_idx: jtp.Float,
480
542
  penetration: jtp.Array,
@@ -482,7 +544,7 @@ class RelaxedRigidContacts(ContactModel):
482
544
  ) -> tuple[jtp.Array, jtp.Array]:
483
545
 
484
546
  # Compute the reference acceleration.
485
- ξ, a_ref, K, D = _imp_aref(
547
+ ξ, a_ref, K, D = imp_aref(
486
548
  penetration=penetration,
487
549
  velocity=velocity,
488
550
  )
@@ -496,12 +558,10 @@ class RelaxedRigidContacts(ContactModel):
496
558
 
497
559
  return jax.tree.map(lambda x: x * (penetration < 0), (a_ref, R, K, D))
498
560
 
499
- M_L = js.model.link_spatial_inertia_matrices(model=model)
500
-
501
561
  a_ref, R, K, D = jax.tree.map(
502
- jnp.concatenate,
503
- (
504
- *jax.vmap(_compute_row)(
562
+ f=jnp.concatenate,
563
+ tree=(
564
+ *jax.vmap(compute_row)(
505
565
  link_idx=jnp.array(
506
566
  model.kin_dyn_parameters.contact_parameters.body
507
567
  ),
@@ -510,4 +570,5 @@ class RelaxedRigidContacts(ContactModel):
510
570
  ),
511
571
  ),
512
572
  )
573
+
513
574
  return a_ref, jnp.diag(R), K, D