jaxsim 0.4.3.dev352__py3-none-any.whl → 0.4.3.dev362__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
jaxsim/_version.py CHANGED
@@ -12,5 +12,5 @@ __version__: str
12
12
  __version_tuple__: VERSION_TUPLE
13
13
  version_tuple: VERSION_TUPLE
14
14
 
15
- __version__ = version = '0.4.3.dev352'
16
- __version_tuple__ = version_tuple = (0, 4, 3, 'dev352')
15
+ __version__ = version = '0.4.3.dev362'
16
+ __version_tuple__ = version_tuple = (0, 4, 3, 'dev362')
jaxsim/api/contact.py CHANGED
@@ -151,7 +151,7 @@ def collidable_point_dynamics(
151
151
  kwargs: Additional keyword arguments to pass to the active contact model.
152
152
 
153
153
  Returns:
154
- The 6D force applied to each eneabled collidable point and additional data based
154
+ The 6D force applied to each enabled collidable point and additional data based
155
155
  on the contact model configured:
156
156
  - Soft: the material deformation rate.
157
157
  - Rigid: no additional data.
jaxsim/api/model.py CHANGED
@@ -2235,7 +2235,7 @@ def step(
2235
2235
  # by the pair (f_L, τ_references).
2236
2236
  # Note that the wrapper of the system dynamics will override (state_x0, t0)
2237
2237
  # inside the passed data even if it is not strictly needed. This logic is
2238
- # necessary to re-use the jit-compiled step function of compatible pytrees
2238
+ # necessary to reuse the jit-compiled step function of compatible pytrees
2239
2239
  # of model and data produced e.g. by parameterized applications.
2240
2240
  **(
2241
2241
  dict(
@@ -324,10 +324,10 @@ class ContactModel(JaxsimDataclass):
324
324
  The initialized model and data objects.
325
325
  """
326
326
 
327
- with model.editable(validate=validate) as model_out:
328
- model_out.contact_model = self
327
+ with self.editable(validate=validate) as contact_model:
328
+ contact_model.parameters = data.contacts_params
329
329
 
330
- with data.editable(validate=validate) as data_out:
331
- data_out.contacts_params = data.contacts_params
330
+ with model.editable(validate=validate) as model_out:
331
+ model_out.contact_model = contact_model
332
332
 
333
- return model_out, data_out
333
+ return model_out, data
@@ -10,6 +10,7 @@ import jax_dataclasses
10
10
  import optax
11
11
 
12
12
  import jaxsim.api as js
13
+ import jaxsim.rbda.contacts
13
14
  import jaxsim.typing as jtp
14
15
  from jaxsim import logging
15
16
  from jaxsim.api.common import ModelDataWithVelocityRepresentation, VelRepr
@@ -314,22 +315,20 @@ class RelaxedRigidContacts(common.ContactModel):
314
315
  joint_force_references=joint_force_references,
315
316
  )
316
317
 
317
- def detect_contact(x: jtp.Array, y: jtp.Array, z: jtp.Array) -> jtp.Array:
318
- x, y, z = jax.tree.map(jnp.squeeze, (x, y, z))
319
-
320
- n̂ = model.terrain.normal(x=x, y=y).squeeze()
321
- h = jnp.array([0, 0, z - model.terrain.height(x=x, y=y)])
322
-
323
- return jnp.dot(h, n̂)
324
-
325
318
  # Compute the position and linear velocities (mixed representation) of
326
319
  # all collidable points belonging to the robot.
327
320
  position, velocity = js.contact.collidable_point_kinematics(
328
321
  model=model, data=data
329
322
  )
330
323
 
331
- # Compute the activation state of the collidable points
332
- δ = jax.vmap(detect_contact)(*position.T)
324
+ # Compute the penetration depth and velocity of the collidable points.
325
+ # Note that this function considers the penetration in the normal direction.
326
+ δ, _, n̂ = jax.vmap(common.compute_penetration_data, in_axes=(0, 0, None))(
327
+ position, velocity, model.terrain
328
+ )
329
+
330
+ # Compute the position in the constraint frame.
331
+ position_constraint = jax.vmap(lambda δ, n̂: -δ * n̂)(δ, n̂)
333
332
 
334
333
  # Compute the transforms of the implicit frames corresponding to the
335
334
  # collidable points.
@@ -356,24 +355,22 @@ class RelaxedRigidContacts(common.ContactModel):
356
355
  M = js.model.free_floating_mass_matrix(model=model, data=data)
357
356
 
358
357
  Jl_WC = jnp.vstack(
359
- jax.vmap(lambda J, height: J * (height < 0))(
360
- js.contact.jacobian(model=model, data=data)[:, :3, :],
361
- δ,
358
+ jax.vmap(lambda J, δ: J * (δ > 0))(
359
+ js.contact.jacobian(model=model, data=data)[:, :3, :], δ
362
360
  )
363
361
  )
364
362
 
365
363
  J̇_WC = jnp.vstack(
366
- jax.vmap(lambda J̇, height: J̇ * (height < 0))(
367
- js.contact.jacobian_derivative(model=model, data=data)[:, :3],
368
- δ,
364
+ jax.vmap(lambda J̇, δ: J̇ * (δ > 0))(
365
+ js.contact.jacobian_derivative(model=model, data=data)[:, :3], δ
369
366
  ),
370
367
  )
371
368
 
372
369
  # Compute the regularization terms.
373
- a_ref, R, K, D = self._regularizers(
370
+ a_ref, R, *_ = self._regularizers(
374
371
  model=model,
375
- penetration=δ,
376
- velocity=velocity,
372
+ position_constraint=position_constraint,
373
+ velocity_constraint=velocity,
377
374
  parameters=data.contacts_params,
378
375
  )
379
376
 
@@ -435,6 +432,7 @@ class RelaxedRigidContacts(common.ContactModel):
435
432
 
436
433
  return params, state
437
434
 
435
+ # TODO: maybe fix the number of iterations and switch to scan?
438
436
  def continuing_criterion(carry: OptimizationCarry) -> jtp.Bool:
439
437
 
440
438
  _, state = carry
@@ -456,10 +454,20 @@ class RelaxedRigidContacts(common.ContactModel):
456
454
  # ======================================
457
455
 
458
456
  # Initialize the optimized forces with a linear Hunt/Crossley model.
459
- init_params = (
460
- K[:, jnp.newaxis] * jnp.zeros_like(position).at[:, 2].set(δ)
461
- + D[:, jnp.newaxis] * velocity
462
- ).flatten()
457
+ init_params = jax.vmap(
458
+ lambda p, v: jaxsim.rbda.contacts.SoftContacts.hunt_crossley_contact_model(
459
+ position=p,
460
+ velocity=v,
461
+ terrain=model.terrain,
462
+ K=1e6,
463
+ D=2e3,
464
+ p=0.0,
465
+ q=0.0,
466
+ # No tangential initial forces.
467
+ mu=0.0,
468
+ tangential_deformation=jnp.zeros(3),
469
+ )[0]
470
+ )(position, velocity).flatten()
463
471
 
464
472
  # Get the solver options.
465
473
  solver_options = self.solver_options
@@ -498,8 +506,8 @@ class RelaxedRigidContacts(common.ContactModel):
498
506
  @staticmethod
499
507
  def _regularizers(
500
508
  model: js.model.JaxSimModel,
501
- penetration: jtp.Array,
502
- velocity: jtp.Array,
509
+ position_constraint: jtp.Vector,
510
+ velocity_constraint: jtp.Vector,
503
511
  parameters: RelaxedRigidContactsParams,
504
512
  ) -> tuple:
505
513
  """
@@ -507,12 +515,13 @@ class RelaxedRigidContacts(common.ContactModel):
507
515
 
508
516
  Args:
509
517
  model: The jaxsim model.
510
- penetration: The penetration of the collidable points.
511
- velocity: The velocity of the collidable points.
518
+ position_constraint: The point position in the constraint frame.
519
+ velocity_constraint: The point velocity in the constraint frame.
512
520
  parameters: The parameters of the relaxed rigid contacts model.
513
521
 
514
522
  Returns:
515
- A tuple containing the reference acceleration, the regularization matrix, the stiffness, and the damping.
523
+ A tuple containing the reference acceleration, the regularization matrix,
524
+ the stiffness, and the damping.
516
525
  """
517
526
 
518
527
  # Extract the parameters of the contact model.
@@ -545,70 +554,79 @@ class RelaxedRigidContacts(common.ContactModel):
545
554
  M_L = js.model.link_spatial_inertia_matrices(model=model)
546
555
 
547
556
  def imp_aref(
548
- penetration: jtp.Array, velocity: jtp.Array
549
- ) -> tuple[jtp.Array, jtp.Array]:
557
+ pos: jtp.Vector,
558
+ vel: jtp.Vector,
559
+ ) -> tuple[jtp.Vector, jtp.Vector, jtp.Vector, jtp.Vector]:
550
560
  """
551
561
  Calculates impedance and offset acceleration in constraint frame.
552
562
 
553
563
  Args:
554
- penetration: penetration in constraint frame
555
- velocity: velocity in constraint frame
564
+ pos: position in constraint frame.
565
+ vel: velocity in constraint frame.
556
566
 
557
567
  Returns:
568
+ ξ: computed impedance
558
569
  a_ref: offset acceleration in constraint frame
559
- R: regularization matrix
560
570
  K: computed stiffness
561
571
  D: computed damping
562
572
  """
563
- position = jnp.zeros(shape=(3,)).at[2].set(penetration)
564
573
 
565
- imp_x = jnp.abs(position) / width
566
- imp_a = (1.0 / jnp.power(mid, p - 1)) * jnp.power(imp_x, p)
574
+ imp_x = jnp.abs(pos) / width
567
575
 
576
+ imp_a = (1.0 / jnp.power(mid, p - 1)) * jnp.power(imp_x, p)
568
577
  imp_b = 1 - (1.0 / jnp.power(1 - mid, p - 1)) * jnp.power(1 - imp_x, p)
569
-
570
578
  imp_y = jnp.where(imp_x < mid, imp_a, imp_b)
571
579
 
572
- imp = jnp.clip(ξ_min + imp_y * (ξ_max - ξ_min), ξ_min, ξ_max)
573
- imp = jnp.atleast_1d(jnp.where(imp_x > 1.0, ξ_max, imp))
580
+ # Compute the impedance.
581
+ ξ = ξ_min + imp_y * (ξ_max - ξ_min)
582
+ ξ = jnp.clip(ξ, ξ_min, ξ_max)
583
+ ξ = jnp.where(imp_x > 1.0, ξ_max, ξ)
584
+
585
+ # Compute the spring and damper parameters during runtime from the
586
+ # impedance and other contact parameters.
587
+ K = 1 / (ξ_max * Ω * ζ) ** 2
588
+ D = 2 / (ξ_max * Ω)
574
589
 
575
- # When passing negative values, K and D represent a spring and damper, respectively.
576
- K_f = jnp.where(K < 0, -K / ξ_max**2, 1 / (ξ_max * Ω * ζ) ** 2)
577
- D_f = jnp.where(D < 0, -D / ξ_max, 2 / (ξ_max * Ω))
590
+ # If the user specifies K and D and they are negative, the computed `a_ref`
591
+ # becomes something more similar to a classic Baumgarte regularization.
592
+ K = jnp.where(K < 0, -K / ξ_max**2, K)
593
+ D = jnp.where(D < 0, -D / ξ_max, D)
578
594
 
579
- a_ref = -jnp.atleast_1d(D_f * velocity + K_f * imp * position)
595
+ # Compute the reference acceleration.
596
+ a_ref = -(D * vel + K * ξ * pos)
580
597
 
581
- return imp, a_ref, jnp.atleast_1d(K_f), jnp.atleast_1d(D_f)
598
+ return ξ, a_ref, K, D
582
599
 
583
600
  def compute_row(
584
601
  *,
585
- link_idx: jtp.Float,
586
- penetration: jtp.Array,
587
- velocity: jtp.Array,
588
- ) -> tuple[jtp.Array, jtp.Array]:
602
+ link_idx: jtp.Int,
603
+ pos: jtp.Vector,
604
+ vel: jtp.Vector,
605
+ ) -> tuple[jtp.Vector, jtp.Matrix, jtp.Vector, jtp.Vector]:
589
606
 
590
607
  # Compute the reference acceleration.
591
- ξ, a_ref, K, D = imp_aref(
592
- penetration=penetration,
593
- velocity=velocity,
594
- )
608
+ ξ, a_ref, K, D = imp_aref(pos=pos, vel=vel)
595
609
 
596
- # Compute the regularization terms.
610
+ # Compute the regularization term.
597
611
  R = (
598
612
  (2 * μ**2 * (1 - ξ) / (ξ + 1e-12))
599
613
  * (1 + μ**2)
600
614
  @ jnp.linalg.inv(M_L[link_idx, :3, :3])
601
615
  )
602
616
 
603
- return jax.tree.map(lambda x: x * (penetration < 0), (a_ref, R, K, D))
617
+ # Return the computed values, setting them to zero in case of no contact.
618
+ is_active = (pos.dot(pos) > 0).astype(float)
619
+ return jax.tree.map(
620
+ lambda x: jnp.atleast_1d(x) * is_active, (a_ref, R, K, D)
621
+ )
604
622
 
605
623
  a_ref, R, K, D = jax.tree.map(
606
624
  f=jnp.concatenate,
607
625
  tree=(
608
626
  *jax.vmap(compute_row)(
609
627
  link_idx=parent_link_idx_of_enabled_collidable_points,
610
- penetration=penetration,
611
- velocity=velocity,
628
+ pos=position_constraint,
629
+ vel=velocity_constraint,
612
630
  ),
613
631
  ),
614
632
  )
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: jaxsim
3
- Version: 0.4.3.dev352
3
+ Version: 0.4.3.dev362
4
4
  Summary: A differentiable physics engine and multibody dynamics library for control and robot learning.
5
5
  Author-email: Diego Ferigo <dgferigo@gmail.com>, Filippo Luca Ferretti <filippoluca.ferretti@outlook.com>
6
6
  Maintainer-email: Filippo Luca Ferretti <filippo.ferretti@iit.it>, Alessandro Croci <alessandro.croci@iit.it>
@@ -1,18 +1,18 @@
1
1
  jaxsim/__init__.py,sha256=opgtbhhd1kDsHI4H1vOd3loMPDRi884yQ3tohfFGfNc,3382
2
- jaxsim/_version.py,sha256=0_67mou8TrQcvuWEau9dM_rtURNW7KokTnrV6Fq6e_g,428
2
+ jaxsim/_version.py,sha256=yPZj4x5C1RSfHbZrDhkcc8pxpqIJyV7TN-gB5XXkOBo,428
3
3
  jaxsim/exceptions.py,sha256=vSoScaRD4nvh6jltgK9Ry5pKnE0O5hb4_yI_pk_fvR8,2175
4
4
  jaxsim/logging.py,sha256=STI-D_upXZYX-ZezLrlJJ0UlD5YspST0vZ_DcIwkzO4,1553
5
5
  jaxsim/typing.py,sha256=2HXy9hgazPXjofi1vLQ09ZubPtgVmg80U9NKmZ6NYiI,761
6
6
  jaxsim/api/__init__.py,sha256=8eV22t2S3UwNyCg8karPetG1dmX1VDBXkyv28_FwNQA,210
7
7
  jaxsim/api/com.py,sha256=m-p3EJDhpnMTlXKplfbZE_aH9NqX_VyLlAE3vUhc6l4,13642
8
8
  jaxsim/api/common.py,sha256=SNgxq42r6eF_-aPszvOjUYkGwXOzz4hKmhDwEUkscFQ,6650
9
- jaxsim/api/contact.py,sha256=D6RucrH9gnoUFLdmAEYwLGrimU0wLmuoDeOONu4ni74,25658
9
+ jaxsim/api/contact.py,sha256=cPE7FIUycGzmmNW4Zh9w7qsWd4cQYowo_Tg3mL2evL0,25657
10
10
  jaxsim/api/data.py,sha256=ThRpoBlbdwf1N3xs8SWrY5d8RbfdYRwFcmkdIPgtee4,29004
11
11
  jaxsim/api/frame.py,sha256=yPSgNygHkvWlln4wShNt7vZm_fFobVEm7phsklNNyH8,12922
12
12
  jaxsim/api/joint.py,sha256=8rCIxRMeAidsaBbw7kkGp6z3-UmBPtqmYmV_arHDQJ8,7365
13
13
  jaxsim/api/kin_dyn_parameters.py,sha256=Y9wnMshz83Zm4UEPOAOTINdtfkBZ86w853c8Yi2qaVs,29670
14
14
  jaxsim/api/link.py,sha256=nHjffhNdi_xGkteMsqdb_hC9mdV9rNw7k3pl89Uhw_8,12798
15
- jaxsim/api/model.py,sha256=A88AaBZpWvQ-L9blFyl1GHvTWI05rvVFKbSaHzD77_k,79563
15
+ jaxsim/api/model.py,sha256=HVoZ8AtFe5XOq8GlPbtpbMy-gzo4o1gM5gPchH0tHGw,79562
16
16
  jaxsim/api/ode.py,sha256=_t18avoCJngQk6eMFTGpaeahbpchQP20qJnUOCPkz8s,15360
17
17
  jaxsim/api/ode_data.py,sha256=1SD-x-lYk_YSEnVpxTLd69uOKC0mFUj44ZqpSmEDOxw,20190
18
18
  jaxsim/api/references.py,sha256=eIOk3MAOc9LJSKfI8M4WA8gGD-meo50vRfhXdea4sNI,20539
@@ -54,8 +54,8 @@ jaxsim/rbda/jacobian.py,sha256=L6Vn4Kf9I6wj-MYcFY6o67mgIfLFaaW4i2wNQJ2PDL0,10981
54
54
  jaxsim/rbda/rnea.py,sha256=CLfqs9XFVaD-hvkLABshDAfdw5bm_AMV3UVAQ_IvURQ,7542
55
55
  jaxsim/rbda/utils.py,sha256=GLt7XIl1ROkx0_fnBCKUHYdB9_IBF3Yi4OnkHSX3gxA,5365
56
56
  jaxsim/rbda/contacts/__init__.py,sha256=L5MM-2pv76YPGzxExdz2EErgGBATuAjYnNHlq5QOySs,503
57
- jaxsim/rbda/contacts/common.py,sha256=ai49HeLQOsnckG0H2tUKW2KQ0Au_v9jRuNdnqie-YBk,11234
58
- jaxsim/rbda/contacts/relaxed_rigid.py,sha256=tbyskONuUhC6BZnZSpNUnlCjkI7LR6mCtmU_HimOAVE,20893
57
+ jaxsim/rbda/contacts/common.py,sha256=mjOS1MJkf9zRfcAKBwmYtO5Vrcrv-kLFM57FfFB8LgM,11244
58
+ jaxsim/rbda/contacts/relaxed_rigid.py,sha256=u7WliwuKff2RjS85eIEtJXbDLilZMqlz-j46-Pv7QAw,21681
59
59
  jaxsim/rbda/contacts/rigid.py,sha256=MSzkU6SFbW6CryNlyyxQ7K0-U-8k6VROGKv_DQrwqiw,17156
60
60
  jaxsim/rbda/contacts/soft.py,sha256=t6bqBfGAtV1AWoevY82LAcXy2XW8w_uu7bNywcyxF0s,17001
61
61
  jaxsim/rbda/contacts/visco_elastic.py,sha256=vQkfMuqQ3Qu8nbDTPY4jWBZjV3U7qtoRK1Aya3O3oFA,41424
@@ -65,8 +65,8 @@ jaxsim/utils/__init__.py,sha256=Y5zyoRevl3EMVQadhZ4EtSwTEkDt2vcnFoRhPJjKTZ0,215
65
65
  jaxsim/utils/jaxsim_dataclass.py,sha256=TGmTQV2Lq7Q-2nLoAEaeNtkPa_qj0IKkdBm4COj46Os,11312
66
66
  jaxsim/utils/tracing.py,sha256=eEY28MZW0Lm_jJNt1NkFqZz0ek01tvhR46OXZYCo7tc,532
67
67
  jaxsim/utils/wrappers.py,sha256=ZY7olSORzZRvSzkdeNLj8yjwUIAt9L0Douwl7wItjpk,4008
68
- jaxsim-0.4.3.dev352.dist-info/LICENSE,sha256=eaYdFmdeMbiIoIiPzEK0MjP1S9wtFXjXNR5er49uLR0,1546
69
- jaxsim-0.4.3.dev352.dist-info/METADATA,sha256=ZNz9p67E-hIsIWmtnTKkQZakk-LI8Nt_mKX3ywW_gOU,17513
70
- jaxsim-0.4.3.dev352.dist-info/WHEEL,sha256=R06PA3UVYHThwHvxuRWMqaGcr-PuniXahwjmQRFMEkY,91
71
- jaxsim-0.4.3.dev352.dist-info/top_level.txt,sha256=LxGMA8FLtXjQ6oI7N5gd_R_oSUHxpXxUEOfT1xS_ni0,7
72
- jaxsim-0.4.3.dev352.dist-info/RECORD,,
68
+ jaxsim-0.4.3.dev362.dist-info/LICENSE,sha256=eaYdFmdeMbiIoIiPzEK0MjP1S9wtFXjXNR5er49uLR0,1546
69
+ jaxsim-0.4.3.dev362.dist-info/METADATA,sha256=YNsB6rTQXbGbbhmBKHhGQQ54P8nOASpjyEfGsXt5l-A,17513
70
+ jaxsim-0.4.3.dev362.dist-info/WHEEL,sha256=R06PA3UVYHThwHvxuRWMqaGcr-PuniXahwjmQRFMEkY,91
71
+ jaxsim-0.4.3.dev362.dist-info/top_level.txt,sha256=LxGMA8FLtXjQ6oI7N5gd_R_oSUHxpXxUEOfT1xS_ni0,7
72
+ jaxsim-0.4.3.dev362.dist-info/RECORD,,