jaxsim 0.4.3.dev359__py3-none-any.whl → 0.4.3.dev362__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jaxsim/_version.py +2 -2
- jaxsim/rbda/contacts/common.py +5 -5
- jaxsim/rbda/contacts/relaxed_rigid.py +74 -56
- {jaxsim-0.4.3.dev359.dist-info → jaxsim-0.4.3.dev362.dist-info}/METADATA +1 -1
- {jaxsim-0.4.3.dev359.dist-info → jaxsim-0.4.3.dev362.dist-info}/RECORD +8 -8
- {jaxsim-0.4.3.dev359.dist-info → jaxsim-0.4.3.dev362.dist-info}/LICENSE +0 -0
- {jaxsim-0.4.3.dev359.dist-info → jaxsim-0.4.3.dev362.dist-info}/WHEEL +0 -0
- {jaxsim-0.4.3.dev359.dist-info → jaxsim-0.4.3.dev362.dist-info}/top_level.txt +0 -0
jaxsim/_version.py
CHANGED
@@ -12,5 +12,5 @@ __version__: str
|
|
12
12
|
__version_tuple__: VERSION_TUPLE
|
13
13
|
version_tuple: VERSION_TUPLE
|
14
14
|
|
15
|
-
__version__ = version = '0.4.3.dev359'
|
16
|
-
__version_tuple__ = version_tuple = (0, 4, 3, 'dev359')
|
15
|
+
__version__ = version = '0.4.3.dev362'
|
16
|
+
__version_tuple__ = version_tuple = (0, 4, 3, 'dev362')
|
jaxsim/rbda/contacts/common.py
CHANGED
@@ -324,10 +324,10 @@ class ContactModel(JaxsimDataclass):
|
|
324
324
|
The initialized model and data objects.
|
325
325
|
"""
|
326
326
|
|
327
|
-
with
|
328
|
-
|
327
|
+
with self.editable(validate=validate) as contact_model:
|
328
|
+
contact_model.parameters = data.contacts_params
|
329
329
|
|
330
|
-
with
|
331
|
-
|
330
|
+
with model.editable(validate=validate) as model_out:
|
331
|
+
model_out.contact_model = contact_model
|
332
332
|
|
333
|
-
return model_out,
|
333
|
+
return model_out, data
|
jaxsim/rbda/contacts/relaxed_rigid.py
CHANGED
@@ -10,6 +10,7 @@ import jax_dataclasses
|
|
10
10
|
import optax
|
11
11
|
|
12
12
|
import jaxsim.api as js
|
13
|
+
import jaxsim.rbda.contacts
|
13
14
|
import jaxsim.typing as jtp
|
14
15
|
from jaxsim import logging
|
15
16
|
from jaxsim.api.common import ModelDataWithVelocityRepresentation, VelRepr
|
@@ -314,22 +315,20 @@ class RelaxedRigidContacts(common.ContactModel):
|
|
314
315
|
joint_force_references=joint_force_references,
|
315
316
|
)
|
316
317
|
|
317
|
-
def detect_contact(x: jtp.Array, y: jtp.Array, z: jtp.Array) -> jtp.Array:
|
318
|
-
x, y, z = jax.tree.map(jnp.squeeze, (x, y, z))
|
319
|
-
|
320
|
-
n̂ = model.terrain.normal(x=x, y=y).squeeze()
|
321
|
-
h = jnp.array([0, 0, z - model.terrain.height(x=x, y=y)])
|
322
|
-
|
323
|
-
return jnp.dot(h, n̂)
|
324
|
-
|
325
318
|
# Compute the position and linear velocities (mixed representation) of
|
326
319
|
# all collidable points belonging to the robot.
|
327
320
|
position, velocity = js.contact.collidable_point_kinematics(
|
328
321
|
model=model, data=data
|
329
322
|
)
|
330
323
|
|
331
|
-
# Compute the
|
332
|
-
|
324
|
+
# Compute the penetration depth and velocity of the collidable points.
|
325
|
+
# Note that this function considers the penetration in the normal direction.
|
326
|
+
δ, _, n̂ = jax.vmap(common.compute_penetration_data, in_axes=(0, 0, None))(
|
327
|
+
position, velocity, model.terrain
|
328
|
+
)
|
329
|
+
|
330
|
+
# Compute the position in the constraint frame.
|
331
|
+
position_constraint = jax.vmap(lambda δ, n̂: -δ * n̂)(δ, n̂)
|
333
332
|
|
334
333
|
# Compute the transforms of the implicit frames corresponding to the
|
335
334
|
# collidable points.
|
@@ -356,24 +355,22 @@ class RelaxedRigidContacts(common.ContactModel):
|
|
356
355
|
M = js.model.free_floating_mass_matrix(model=model, data=data)
|
357
356
|
|
358
357
|
Jl_WC = jnp.vstack(
|
359
|
-
jax.vmap(lambda J,
|
360
|
-
js.contact.jacobian(model=model, data=data)[:, :3, :],
|
361
|
-
δ,
|
358
|
+
jax.vmap(lambda J, δ: J * (δ > 0))(
|
359
|
+
js.contact.jacobian(model=model, data=data)[:, :3, :], δ
|
362
360
|
)
|
363
361
|
)
|
364
362
|
|
365
363
|
J̇_WC = jnp.vstack(
|
366
|
-
jax.vmap(lambda J̇,
|
367
|
-
js.contact.jacobian_derivative(model=model, data=data)[:, :3],
|
368
|
-
δ,
|
364
|
+
jax.vmap(lambda J̇, δ: J̇ * (δ > 0))(
|
365
|
+
js.contact.jacobian_derivative(model=model, data=data)[:, :3], δ
|
369
366
|
),
|
370
367
|
)
|
371
368
|
|
372
369
|
# Compute the regularization terms.
|
373
|
-
a_ref, R,
|
370
|
+
a_ref, R, *_ = self._regularizers(
|
374
371
|
model=model,
|
375
|
-
|
376
|
-
|
372
|
+
position_constraint=position_constraint,
|
373
|
+
velocity_constraint=velocity,
|
377
374
|
parameters=data.contacts_params,
|
378
375
|
)
|
379
376
|
|
@@ -435,6 +432,7 @@ class RelaxedRigidContacts(common.ContactModel):
|
|
435
432
|
|
436
433
|
return params, state
|
437
434
|
|
435
|
+
# TODO: maybe fix the number of iterations and switch to scan?
|
438
436
|
def continuing_criterion(carry: OptimizationCarry) -> jtp.Bool:
|
439
437
|
|
440
438
|
_, state = carry
|
@@ -456,10 +454,20 @@ class RelaxedRigidContacts(common.ContactModel):
|
|
456
454
|
# ======================================
|
457
455
|
|
458
456
|
# Initialize the optimized forces with a linear Hunt/Crossley model.
|
459
|
-
init_params = (
|
460
|
-
|
461
|
-
|
462
|
-
|
457
|
+
init_params = jax.vmap(
|
458
|
+
lambda p, v: jaxsim.rbda.contacts.SoftContacts.hunt_crossley_contact_model(
|
459
|
+
position=p,
|
460
|
+
velocity=v,
|
461
|
+
terrain=model.terrain,
|
462
|
+
K=1e6,
|
463
|
+
D=2e3,
|
464
|
+
p=0.0,
|
465
|
+
q=0.0,
|
466
|
+
# No tangential initial forces.
|
467
|
+
mu=0.0,
|
468
|
+
tangential_deformation=jnp.zeros(3),
|
469
|
+
)[0]
|
470
|
+
)(position, velocity).flatten()
|
463
471
|
|
464
472
|
# Get the solver options.
|
465
473
|
solver_options = self.solver_options
|
@@ -498,8 +506,8 @@ class RelaxedRigidContacts(common.ContactModel):
|
|
498
506
|
@staticmethod
|
499
507
|
def _regularizers(
|
500
508
|
model: js.model.JaxSimModel,
|
501
|
-
|
502
|
-
|
509
|
+
position_constraint: jtp.Vector,
|
510
|
+
velocity_constraint: jtp.Vector,
|
503
511
|
parameters: RelaxedRigidContactsParams,
|
504
512
|
) -> tuple:
|
505
513
|
"""
|
@@ -507,12 +515,13 @@ class RelaxedRigidContacts(common.ContactModel):
|
|
507
515
|
|
508
516
|
Args:
|
509
517
|
model: The jaxsim model.
|
510
|
-
penetration: The
|
511
|
-
velocity: The velocity
|
518
|
+
penetration: The point position in the constraint frame.
|
519
|
+
velocity: The point velocity in the constraint frame.
|
512
520
|
parameters: The parameters of the relaxed rigid contacts model.
|
513
521
|
|
514
522
|
Returns:
|
515
|
-
A tuple containing the reference acceleration, the regularization matrix,
|
523
|
+
A tuple containing the reference acceleration, the regularization matrix,
|
524
|
+
the stiffness, and the damping.
|
516
525
|
"""
|
517
526
|
|
518
527
|
# Extract the parameters of the contact model.
|
@@ -545,70 +554,79 @@ class RelaxedRigidContacts(common.ContactModel):
|
|
545
554
|
M_L = js.model.link_spatial_inertia_matrices(model=model)
|
546
555
|
|
547
556
|
def imp_aref(
|
548
|
-
|
549
|
-
|
557
|
+
pos: jtp.Vector,
|
558
|
+
vel: jtp.Vector,
|
559
|
+
) -> tuple[jtp.Vector, jtp.Vector, jtp.Vector, jtp.Vector]:
|
550
560
|
"""
|
551
561
|
Calculates impedance and offset acceleration in constraint frame.
|
552
562
|
|
553
563
|
Args:
|
554
|
-
|
555
|
-
|
564
|
+
pos: position in constraint frame.
|
565
|
+
vel: velocity in constraint frame.
|
556
566
|
|
557
567
|
Returns:
|
568
|
+
ξ: computed impedance
|
558
569
|
a_ref: offset acceleration in constraint frame
|
559
|
-
R: regularization matrix
|
560
570
|
K: computed stiffness
|
561
571
|
D: computed damping
|
562
572
|
"""
|
563
|
-
position = jnp.zeros(shape=(3,)).at[2].set(penetration)
|
564
573
|
|
565
|
-
imp_x = jnp.abs(
|
566
|
-
imp_a = (1.0 / jnp.power(mid, p - 1)) * jnp.power(imp_x, p)
|
574
|
+
imp_x = jnp.abs(pos) / width
|
567
575
|
|
576
|
+
imp_a = (1.0 / jnp.power(mid, p - 1)) * jnp.power(imp_x, p)
|
568
577
|
imp_b = 1 - (1.0 / jnp.power(1 - mid, p - 1)) * jnp.power(1 - imp_x, p)
|
569
|
-
|
570
578
|
imp_y = jnp.where(imp_x < mid, imp_a, imp_b)
|
571
579
|
|
572
|
-
|
573
|
-
|
580
|
+
# Compute the impedance.
|
581
|
+
ξ = ξ_min + imp_y * (ξ_max - ξ_min)
|
582
|
+
ξ = jnp.clip(ξ, ξ_min, ξ_max)
|
583
|
+
ξ = jnp.where(imp_x > 1.0, ξ_max, ξ)
|
584
|
+
|
585
|
+
# Compute the spring and damper parameters during runtime from the
|
586
|
+
# impedance and other contact parameters.
|
587
|
+
K = 1 / (ξ_max * Ω * ζ) ** 2
|
588
|
+
D = 2 / (ξ_max * Ω)
|
574
589
|
|
575
|
-
#
|
576
|
-
|
577
|
-
|
590
|
+
# If the user specifies K and D and they are negative, the computed `a_ref`
|
591
|
+
# becomes something more similar to a classic Baumgarte regularization.
|
592
|
+
K = jnp.where(K < 0, -K / ξ_max**2, K)
|
593
|
+
D = jnp.where(D < 0, -D / ξ_max, D)
|
578
594
|
|
579
|
-
|
595
|
+
# Compute the reference acceleration.
|
596
|
+
a_ref = -(D * vel + K * ξ * pos)
|
580
597
|
|
581
|
-
return
|
598
|
+
return ξ, a_ref, K, D
|
582
599
|
|
583
600
|
def compute_row(
|
584
601
|
*,
|
585
|
-
link_idx: jtp.
|
586
|
-
|
587
|
-
|
588
|
-
) -> tuple[jtp.
|
602
|
+
link_idx: jtp.Int,
|
603
|
+
pos: jtp.Vector,
|
604
|
+
vel: jtp.Vector,
|
605
|
+
) -> tuple[jtp.Vector, jtp.Matrix, jtp.Vector, jtp.Vector]:
|
589
606
|
|
590
607
|
# Compute the reference acceleration.
|
591
|
-
ξ, a_ref, K, D = imp_aref(
|
592
|
-
penetration=penetration,
|
593
|
-
velocity=velocity,
|
594
|
-
)
|
608
|
+
ξ, a_ref, K, D = imp_aref(pos=pos, vel=vel)
|
595
609
|
|
596
|
-
# Compute the regularization
|
610
|
+
# Compute the regularization term.
|
597
611
|
R = (
|
598
612
|
(2 * μ**2 * (1 - ξ) / (ξ + 1e-12))
|
599
613
|
* (1 + μ**2)
|
600
614
|
@ jnp.linalg.inv(M_L[link_idx, :3, :3])
|
601
615
|
)
|
602
616
|
|
603
|
-
|
617
|
+
# Return the computed values, setting them to zero in case of no contact.
|
618
|
+
is_active = (pos.dot(pos) > 0).astype(float)
|
619
|
+
return jax.tree.map(
|
620
|
+
lambda x: jnp.atleast_1d(x) * is_active, (a_ref, R, K, D)
|
621
|
+
)
|
604
622
|
|
605
623
|
a_ref, R, K, D = jax.tree.map(
|
606
624
|
f=jnp.concatenate,
|
607
625
|
tree=(
|
608
626
|
*jax.vmap(compute_row)(
|
609
627
|
link_idx=parent_link_idx_of_enabled_collidable_points,
|
610
|
-
|
611
|
-
|
628
|
+
pos=position_constraint,
|
629
|
+
vel=velocity_constraint,
|
612
630
|
),
|
613
631
|
),
|
614
632
|
)
|
{jaxsim-0.4.3.dev359.dist-info → jaxsim-0.4.3.dev362.dist-info}/METADATA
CHANGED
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.1
|
2
2
|
Name: jaxsim
|
3
|
-
Version: 0.4.3.dev359
|
3
|
+
Version: 0.4.3.dev362
|
4
4
|
Summary: A differentiable physics engine and multibody dynamics library for control and robot learning.
|
5
5
|
Author-email: Diego Ferigo <dgferigo@gmail.com>, Filippo Luca Ferretti <filippoluca.ferretti@outlook.com>
|
6
6
|
Maintainer-email: Filippo Luca Ferretti <filippo.ferretti@iit.it>, Alessandro Croci <alessandro.croci@iit.it>
|
{jaxsim-0.4.3.dev359.dist-info → jaxsim-0.4.3.dev362.dist-info}/RECORD
CHANGED
@@ -1,5 +1,5 @@
|
|
1
1
|
jaxsim/__init__.py,sha256=opgtbhhd1kDsHI4H1vOd3loMPDRi884yQ3tohfFGfNc,3382
|
2
|
-
jaxsim/_version.py,sha256=
|
2
|
+
jaxsim/_version.py,sha256=yPZj4x5C1RSfHbZrDhkcc8pxpqIJyV7TN-gB5XXkOBo,428
|
3
3
|
jaxsim/exceptions.py,sha256=vSoScaRD4nvh6jltgK9Ry5pKnE0O5hb4_yI_pk_fvR8,2175
|
4
4
|
jaxsim/logging.py,sha256=STI-D_upXZYX-ZezLrlJJ0UlD5YspST0vZ_DcIwkzO4,1553
|
5
5
|
jaxsim/typing.py,sha256=2HXy9hgazPXjofi1vLQ09ZubPtgVmg80U9NKmZ6NYiI,761
|
@@ -54,8 +54,8 @@ jaxsim/rbda/jacobian.py,sha256=L6Vn4Kf9I6wj-MYcFY6o67mgIfLFaaW4i2wNQJ2PDL0,10981
|
|
54
54
|
jaxsim/rbda/rnea.py,sha256=CLfqs9XFVaD-hvkLABshDAfdw5bm_AMV3UVAQ_IvURQ,7542
|
55
55
|
jaxsim/rbda/utils.py,sha256=GLt7XIl1ROkx0_fnBCKUHYdB9_IBF3Yi4OnkHSX3gxA,5365
|
56
56
|
jaxsim/rbda/contacts/__init__.py,sha256=L5MM-2pv76YPGzxExdz2EErgGBATuAjYnNHlq5QOySs,503
|
57
|
-
jaxsim/rbda/contacts/common.py,sha256=
|
58
|
-
jaxsim/rbda/contacts/relaxed_rigid.py,sha256=
|
57
|
+
jaxsim/rbda/contacts/common.py,sha256=mjOS1MJkf9zRfcAKBwmYtO5Vrcrv-kLFM57FfFB8LgM,11244
|
58
|
+
jaxsim/rbda/contacts/relaxed_rigid.py,sha256=u7WliwuKff2RjS85eIEtJXbDLilZMqlz-j46-Pv7QAw,21681
|
59
59
|
jaxsim/rbda/contacts/rigid.py,sha256=MSzkU6SFbW6CryNlyyxQ7K0-U-8k6VROGKv_DQrwqiw,17156
|
60
60
|
jaxsim/rbda/contacts/soft.py,sha256=t6bqBfGAtV1AWoevY82LAcXy2XW8w_uu7bNywcyxF0s,17001
|
61
61
|
jaxsim/rbda/contacts/visco_elastic.py,sha256=vQkfMuqQ3Qu8nbDTPY4jWBZjV3U7qtoRK1Aya3O3oFA,41424
|
@@ -65,8 +65,8 @@ jaxsim/utils/__init__.py,sha256=Y5zyoRevl3EMVQadhZ4EtSwTEkDt2vcnFoRhPJjKTZ0,215
|
|
65
65
|
jaxsim/utils/jaxsim_dataclass.py,sha256=TGmTQV2Lq7Q-2nLoAEaeNtkPa_qj0IKkdBm4COj46Os,11312
|
66
66
|
jaxsim/utils/tracing.py,sha256=eEY28MZW0Lm_jJNt1NkFqZz0ek01tvhR46OXZYCo7tc,532
|
67
67
|
jaxsim/utils/wrappers.py,sha256=ZY7olSORzZRvSzkdeNLj8yjwUIAt9L0Douwl7wItjpk,4008
|
68
|
-
jaxsim-0.4.3.
|
69
|
-
jaxsim-0.4.3.
|
70
|
-
jaxsim-0.4.3.
|
71
|
-
jaxsim-0.4.3.
|
72
|
-
jaxsim-0.4.3.
|
68
|
+
jaxsim-0.4.3.dev362.dist-info/LICENSE,sha256=eaYdFmdeMbiIoIiPzEK0MjP1S9wtFXjXNR5er49uLR0,1546
|
69
|
+
jaxsim-0.4.3.dev362.dist-info/METADATA,sha256=YNsB6rTQXbGbbhmBKHhGQQ54P8nOASpjyEfGsXt5l-A,17513
|
70
|
+
jaxsim-0.4.3.dev362.dist-info/WHEEL,sha256=R06PA3UVYHThwHvxuRWMqaGcr-PuniXahwjmQRFMEkY,91
|
71
|
+
jaxsim-0.4.3.dev362.dist-info/top_level.txt,sha256=LxGMA8FLtXjQ6oI7N5gd_R_oSUHxpXxUEOfT1xS_ni0,7
|
72
|
+
jaxsim-0.4.3.dev362.dist-info/RECORD,,
|
File without changes
|
File without changes
|
File without changes
|