jaxsim 0.4.3.dev186__py3-none-any.whl → 0.4.3.dev200__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jaxsim/_version.py +2 -2
- jaxsim/api/contact.py +58 -19
- jaxsim/api/data.py +3 -2
- jaxsim/api/model.py +8 -12
- jaxsim/rbda/contacts/__init__.py +7 -0
- jaxsim/rbda/contacts/common.py +49 -3
- jaxsim/rbda/contacts/relaxed_rigid.py +133 -72
- jaxsim/rbda/contacts/rigid.py +94 -115
- jaxsim/rbda/contacts/soft.py +21 -37
- jaxsim/rbda/contacts/visco_elastic.py +3 -3
- {jaxsim-0.4.3.dev186.dist-info → jaxsim-0.4.3.dev200.dist-info}/METADATA +1 -1
- {jaxsim-0.4.3.dev186.dist-info → jaxsim-0.4.3.dev200.dist-info}/RECORD +15 -15
- {jaxsim-0.4.3.dev186.dist-info → jaxsim-0.4.3.dev200.dist-info}/WHEEL +1 -1
- {jaxsim-0.4.3.dev186.dist-info → jaxsim-0.4.3.dev200.dist-info}/LICENSE +0 -0
- {jaxsim-0.4.3.dev186.dist-info → jaxsim-0.4.3.dev200.dist-info}/top_level.txt +0 -0
jaxsim/_version.py
CHANGED
@@ -12,5 +12,5 @@ __version__: str
|
|
12
12
|
__version_tuple__: VERSION_TUPLE
|
13
13
|
version_tuple: VERSION_TUPLE
|
14
14
|
|
15
|
-
__version__ = version = '0.4.3.
|
16
|
-
__version_tuple__ = version_tuple = (0, 4, 3, '
|
15
|
+
__version__ = version = '0.4.3.dev200'
|
16
|
+
__version_tuple__ = version_tuple = (0, 4, 3, 'dev200')
|
jaxsim/api/contact.py
CHANGED
@@ -36,11 +36,10 @@ def collidable_point_kinematics(
|
|
36
36
|
the linear component of the mixed 6D frame velocity.
|
37
37
|
"""
|
38
38
|
|
39
|
-
from jaxsim.rbda import collidable_points
|
40
|
-
|
41
39
|
# Switch to inertial-fixed since the RBDAs expect velocities in this representation.
|
42
40
|
with data.switch_velocity_representation(VelRepr.Inertial):
|
43
|
-
|
41
|
+
|
42
|
+
W_p_Ci, W_ṗ_Ci = jaxsim.rbda.collidable_points.collidable_points_pos_vel(
|
44
43
|
model=model,
|
45
44
|
base_position=data.base_position(),
|
46
45
|
base_quaternion=data.base_orientation(dcm=False),
|
@@ -304,6 +303,15 @@ def in_contact(
|
|
304
303
|
|
305
304
|
|
306
305
|
def estimate_good_soft_contacts_parameters(
|
306
|
+
*args, **kwargs
|
307
|
+
) -> jaxsim.rbda.contacts.ContactParamsTypes:
|
308
|
+
|
309
|
+
msg = "This method is deprecated, please use `{}`."
|
310
|
+
logging.warning(msg.format(estimate_good_contact_parameters.__name__))
|
311
|
+
return estimate_good_contact_parameters(*args, **kwargs)
|
312
|
+
|
313
|
+
|
314
|
+
def estimate_good_contact_parameters(
|
307
315
|
model: js.model.JaxSimModel,
|
308
316
|
*,
|
309
317
|
standard_gravity: jtp.FloatLike = jaxsim.math.StandardGravity,
|
@@ -312,14 +320,9 @@ def estimate_good_soft_contacts_parameters(
|
|
312
320
|
damping_ratio: jtp.FloatLike = 1.0,
|
313
321
|
max_penetration: jtp.FloatLike | None = None,
|
314
322
|
**kwargs,
|
315
|
-
) ->
|
316
|
-
jaxsim.rbda.contacts.RelaxedRigidContactsParams
|
317
|
-
| jaxsim.rbda.contacts.RigidContactsParams
|
318
|
-
| jaxsim.rbda.contacts.SoftContactsParams
|
319
|
-
| jaxsim.rbda.contacts.ViscoElasticContactsParams
|
320
|
-
):
|
323
|
+
) -> jaxsim.rbda.contacts.ContactParamsTypes:
|
321
324
|
"""
|
322
|
-
Estimate good
|
325
|
+
Estimate good contact parameters.
|
323
326
|
|
324
327
|
Args:
|
325
328
|
model: The model to consider.
|
@@ -332,12 +335,19 @@ def estimate_good_soft_contacts_parameters(
|
|
332
335
|
max_penetration:
|
333
336
|
The maximum penetration allowed in steady state when the robot is
|
334
337
|
supported by the configured number of active collidable points.
|
338
|
+
kwargs:
|
339
|
+
Additional model-specific parameters passed to the builder method of
|
340
|
+
the parameters class.
|
335
341
|
|
336
342
|
Returns:
|
337
|
-
The estimated good
|
343
|
+
The estimated good contacts parameters.
|
344
|
+
|
345
|
+
Note:
|
346
|
+
This is primarily a convenience function for soft-like contact models.
|
347
|
+
However, it provides with some good default parameters also for the other ones.
|
338
348
|
|
339
349
|
Note:
|
340
|
-
This method provides a good
|
350
|
+
This method provides a good set of contacts parameters.
|
341
351
|
The user is encouraged to fine-tune the parameters based on the
|
342
352
|
specific application.
|
343
353
|
"""
|
@@ -364,6 +374,7 @@ def estimate_good_soft_contacts_parameters(
|
|
364
374
|
max_δ = (
|
365
375
|
max_penetration
|
366
376
|
if max_penetration is not None
|
377
|
+
# Consider as default a 0.5% of the model height.
|
367
378
|
else 0.005 * estimate_model_height(model=model)
|
368
379
|
)
|
369
380
|
|
@@ -381,8 +392,11 @@ def estimate_good_soft_contacts_parameters(
|
|
381
392
|
max_penetration=max_δ,
|
382
393
|
number_of_active_collidable_points_steady_state=nc,
|
383
394
|
damping_ratio=damping_ratio,
|
384
|
-
|
385
|
-
|
395
|
+
**dict(
|
396
|
+
p=model.contact_model.parameters.p,
|
397
|
+
q=model.contact_model.parameters.q,
|
398
|
+
)
|
399
|
+
| kwargs,
|
386
400
|
)
|
387
401
|
|
388
402
|
case contacts.ViscoElasticContacts():
|
@@ -396,15 +410,40 @@ def estimate_good_soft_contacts_parameters(
|
|
396
410
|
max_penetration=max_δ,
|
397
411
|
number_of_active_collidable_points_steady_state=nc,
|
398
412
|
damping_ratio=damping_ratio,
|
399
|
-
|
400
|
-
|
401
|
-
|
413
|
+
**dict(
|
414
|
+
p=model.contact_model.parameters.p,
|
415
|
+
q=model.contact_model.parameters.q,
|
416
|
+
)
|
417
|
+
| kwargs,
|
402
418
|
)
|
403
419
|
)
|
404
420
|
|
421
|
+
case contacts.RigidContacts():
|
422
|
+
assert isinstance(model.contact_model, contacts.RigidContacts)
|
423
|
+
|
424
|
+
# Disable Baumgarte stabilization by default since it does not play
|
425
|
+
# well with the forward Euler integrator.
|
426
|
+
K = kwargs.get("K", 0.0)
|
427
|
+
|
428
|
+
parameters = contacts.RigidContactsParams.build(
|
429
|
+
mu=static_friction_coefficient,
|
430
|
+
**dict(
|
431
|
+
K=K,
|
432
|
+
D=2 * jnp.sqrt(K),
|
433
|
+
)
|
434
|
+
| kwargs,
|
435
|
+
)
|
436
|
+
|
437
|
+
case contacts.RelaxedRigidContacts():
|
438
|
+
assert isinstance(model.contact_model, contacts.RelaxedRigidContacts)
|
439
|
+
|
440
|
+
parameters = contacts.RelaxedRigidContactsParams.build(
|
441
|
+
mu=static_friction_coefficient,
|
442
|
+
**kwargs,
|
443
|
+
)
|
444
|
+
|
405
445
|
case _:
|
406
|
-
|
407
|
-
parameters = model.contact_model.parameters
|
446
|
+
raise ValueError(f"Invalid contact model: {model.contact_model}")
|
408
447
|
|
409
448
|
return parameters
|
410
449
|
|
jaxsim/api/data.py
CHANGED
@@ -34,7 +34,7 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
|
|
34
34
|
|
35
35
|
state: ODEState
|
36
36
|
|
37
|
-
gravity: jtp.
|
37
|
+
gravity: jtp.Vector
|
38
38
|
|
39
39
|
contacts_params: jaxsim.rbda.contacts.ContactsParams = dataclasses.field(repr=False)
|
40
40
|
|
@@ -224,7 +224,8 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
|
|
224
224
|
jaxsim.rbda.contacts.SoftContacts
|
225
225
|
| jaxsim.rbda.contacts.ViscoElasticContacts,
|
226
226
|
):
|
227
|
-
|
227
|
+
|
228
|
+
contacts_params = js.contact.estimate_good_contact_parameters(
|
228
229
|
model=model, standard_gravity=standard_gravity
|
229
230
|
)
|
230
231
|
|
jaxsim/api/model.py
CHANGED
@@ -40,6 +40,8 @@ class JaxSimModel(JaxsimDataclass):
|
|
40
40
|
default_factory=jaxsim.terrain.FlatTerrain.build, repr=False
|
41
41
|
)
|
42
42
|
|
43
|
+
# Note that this is the default contact model.
|
44
|
+
# Its parameters, if any, are then overridden from those stored in JaxSimModelData.
|
43
45
|
contact_model: jaxsim.rbda.contacts.ContactModel | None = dataclasses.field(
|
44
46
|
default=None, repr=False
|
45
47
|
)
|
@@ -2044,24 +2046,18 @@ def step(
|
|
2044
2046
|
M = js.model.free_floating_mass_matrix(model, data_tf)
|
2045
2047
|
W_p_C = js.contact.collidable_point_positions(model, data_tf)
|
2046
2048
|
|
2047
|
-
# Compute the
|
2048
|
-
|
2049
|
-
|
2050
|
-
|
2051
|
-
|
2052
|
-
inactive_collidable_points, _ = (
|
2053
|
-
jaxsim.rbda.contacts.RigidContacts.detect_contacts(
|
2054
|
-
W_p_C=W_p_C,
|
2055
|
-
terrain_height=terrain_height,
|
2056
|
-
)
|
2057
|
-
)
|
2049
|
+
# Compute the penetration depth of the collidable points.
|
2050
|
+
δ, *_ = jax.vmap(
|
2051
|
+
jaxsim.rbda.contacts.common.compute_penetration_data,
|
2052
|
+
in_axes=(0, 0, None),
|
2053
|
+
)(W_p_C, jnp.zeros_like(W_p_C), model.terrain)
|
2058
2054
|
|
2059
2055
|
# Compute the impact velocity.
|
2060
2056
|
# It may be discontinuous in case new contacts are made.
|
2061
2057
|
BW_nu_post_impact = (
|
2062
2058
|
jaxsim.rbda.contacts.RigidContacts.compute_impact_velocity(
|
2063
2059
|
data=data_tf,
|
2064
|
-
inactive_collidable_points=
|
2060
|
+
inactive_collidable_points=(δ <= 0),
|
2065
2061
|
M=M,
|
2066
2062
|
J_WC=J_WC,
|
2067
2063
|
)
|
jaxsim/rbda/contacts/__init__.py
CHANGED
@@ -4,3 +4,10 @@ from .relaxed_rigid import RelaxedRigidContacts, RelaxedRigidContactsParams
|
|
4
4
|
from .rigid import RigidContacts, RigidContactsParams
|
5
5
|
from .soft import SoftContacts, SoftContactsParams
|
6
6
|
from .visco_elastic import ViscoElasticContacts, ViscoElasticContactsParams
|
7
|
+
|
8
|
+
ContactParamsTypes = (
|
9
|
+
SoftContactsParams
|
10
|
+
| RigidContactsParams
|
11
|
+
| RelaxedRigidContactsParams
|
12
|
+
| ViscoElasticContactsParams
|
13
|
+
)
|
jaxsim/rbda/contacts/common.py
CHANGED
@@ -1,8 +1,12 @@
|
|
1
1
|
from __future__ import annotations
|
2
2
|
|
3
3
|
import abc
|
4
|
+
import functools
|
4
5
|
from typing import Any
|
5
6
|
|
7
|
+
import jax
|
8
|
+
import jax.numpy as jnp
|
9
|
+
|
6
10
|
import jaxsim.api as js
|
7
11
|
import jaxsim.terrain
|
8
12
|
import jaxsim.typing as jtp
|
@@ -14,6 +18,47 @@ except ImportError:
|
|
14
18
|
from typing_extensions import Self
|
15
19
|
|
16
20
|
|
21
|
+
@functools.partial(jax.jit, static_argnames=("terrain",))
|
22
|
+
def compute_penetration_data(
|
23
|
+
p: jtp.VectorLike,
|
24
|
+
v: jtp.VectorLike,
|
25
|
+
terrain: jaxsim.terrain.Terrain,
|
26
|
+
) -> tuple[jtp.Float, jtp.Float, jtp.Vector]:
|
27
|
+
"""
|
28
|
+
Compute the penetration data (depth, rate, and terrain normal) of a collidable point.
|
29
|
+
|
30
|
+
Args:
|
31
|
+
p: The position of the collidable point.
|
32
|
+
v:
|
33
|
+
The linear velocity of the point (linear component of the mixed 6D velocity
|
34
|
+
of the implicit frame `C = (W_p_C, [W])` associated to the point).
|
35
|
+
terrain: The considered terrain.
|
36
|
+
|
37
|
+
Returns:
|
38
|
+
A tuple containing the penetration depth, the penetration velocity,
|
39
|
+
and the considered terrain normal.
|
40
|
+
"""
|
41
|
+
|
42
|
+
# Pre-process the position and the linear velocity of the collidable point.
|
43
|
+
W_ṗ_C = jnp.array(v).squeeze()
|
44
|
+
px, py, pz = jnp.array(p).squeeze()
|
45
|
+
|
46
|
+
# Compute the terrain normal and the contact depth.
|
47
|
+
n̂ = terrain.normal(x=px, y=py).squeeze()
|
48
|
+
h = jnp.array([0, 0, terrain.height(x=px, y=py) - pz])
|
49
|
+
|
50
|
+
# Compute the penetration depth normal to the terrain.
|
51
|
+
δ = jnp.maximum(0.0, jnp.dot(h, n̂))
|
52
|
+
|
53
|
+
# Compute the penetration normal velocity.
|
54
|
+
δ_dot = -jnp.dot(W_ṗ_C, n̂)
|
55
|
+
|
56
|
+
# Enforce the penetration rate to be zero when the penetration depth is zero.
|
57
|
+
δ_dot = jnp.where(δ > 0, δ_dot, 0.0)
|
58
|
+
|
59
|
+
return δ, δ_dot, n̂
|
60
|
+
|
61
|
+
|
17
62
|
class ContactsParams(JaxsimDataclass):
|
18
63
|
"""
|
19
64
|
Abstract class representing the parameters of a contact model.
|
@@ -86,7 +131,7 @@ class ContactModel(JaxsimDataclass):
|
|
86
131
|
model: js.model.JaxSimModel,
|
87
132
|
data: js.data.JaxSimModelData,
|
88
133
|
**kwargs,
|
89
|
-
) -> tuple[jtp.
|
134
|
+
) -> tuple[jtp.Matrix, tuple[Any, ...]]:
|
90
135
|
"""
|
91
136
|
Compute the contact forces.
|
92
137
|
|
@@ -95,8 +140,9 @@ class ContactModel(JaxsimDataclass):
|
|
95
140
|
data: The data of the considered model.
|
96
141
|
|
97
142
|
Returns:
|
98
|
-
A tuple containing as first element the computed 6D contact force applied to
|
99
|
-
and
|
143
|
+
A tuple containing as first element the computed 6D contact force applied to
|
144
|
+
the contact points and expressed in the world frame, and as second element
|
145
|
+
a tuple of optional additional information.
|
100
146
|
"""
|
101
147
|
|
102
148
|
pass
|
@@ -12,11 +12,10 @@ import optax
|
|
12
12
|
import jaxsim.api as js
|
13
13
|
import jaxsim.typing as jtp
|
14
14
|
from jaxsim import logging
|
15
|
-
from jaxsim.api.common import VelRepr
|
16
|
-
from jaxsim.math import Adjoint
|
15
|
+
from jaxsim.api.common import ModelDataWithVelocityRepresentation, VelRepr
|
17
16
|
from jaxsim.terrain.terrain import FlatTerrain, Terrain
|
18
17
|
|
19
|
-
from .
|
18
|
+
from . import common
|
20
19
|
|
21
20
|
try:
|
22
21
|
from typing import Self
|
@@ -25,7 +24,7 @@ except ImportError:
|
|
25
24
|
|
26
25
|
|
27
26
|
@jax_dataclasses.pytree_dataclass
|
28
|
-
class RelaxedRigidContactsParams(ContactsParams):
|
27
|
+
class RelaxedRigidContactsParams(common.ContactsParams):
|
29
28
|
"""Parameters of the relaxed rigid contacts model."""
|
30
29
|
|
31
30
|
# Time constant
|
@@ -116,14 +115,24 @@ class RelaxedRigidContactsParams(ContactsParams):
|
|
116
115
|
) -> Self:
|
117
116
|
"""Create a `RelaxedRigidContactsParams` instance"""
|
118
117
|
|
118
|
+
def default(name: str):
|
119
|
+
return cls.__dataclass_fields__[name].default_factory()
|
120
|
+
|
119
121
|
return cls(
|
120
|
-
|
121
|
-
|
122
|
-
|
123
|
-
|
124
|
-
|
125
|
-
|
126
|
-
|
122
|
+
time_constant=jnp.array(
|
123
|
+
time_constant or default("time_constant"), dtype=float
|
124
|
+
),
|
125
|
+
damping_coefficient=jnp.array(
|
126
|
+
damping_coefficient or default("damping_coefficient"), dtype=float
|
127
|
+
),
|
128
|
+
d_min=jnp.array(d_min or default("d_min"), dtype=float),
|
129
|
+
d_max=jnp.array(d_max or default("d_max"), dtype=float),
|
130
|
+
width=jnp.array(width or default("width"), dtype=float),
|
131
|
+
midpoint=jnp.array(midpoint or default("midpoint"), dtype=float),
|
132
|
+
power=jnp.array(power or default("power"), dtype=float),
|
133
|
+
stiffness=jnp.array(stiffness or default("stiffness"), dtype=float),
|
134
|
+
damping=jnp.array(damping or default("damping"), dtype=float),
|
135
|
+
mu=jnp.array(mu or default("mu"), dtype=float),
|
127
136
|
)
|
128
137
|
|
129
138
|
def valid(self) -> jtp.BoolLike:
|
@@ -142,7 +151,7 @@ class RelaxedRigidContactsParams(ContactsParams):
|
|
142
151
|
|
143
152
|
|
144
153
|
@jax_dataclasses.pytree_dataclass
|
145
|
-
class RelaxedRigidContacts(ContactModel):
|
154
|
+
class RelaxedRigidContacts(common.ContactModel):
|
146
155
|
"""Relaxed rigid contacts model."""
|
147
156
|
|
148
157
|
parameters: RelaxedRigidContactsParams = dataclasses.field(
|
@@ -229,7 +238,7 @@ class RelaxedRigidContacts(ContactModel):
|
|
229
238
|
*,
|
230
239
|
link_forces: jtp.MatrixLike | None = None,
|
231
240
|
joint_force_references: jtp.VectorLike | None = None,
|
232
|
-
) -> tuple[jtp.
|
241
|
+
) -> tuple[jtp.Matrix, tuple]:
|
233
242
|
"""
|
234
243
|
Compute the contact forces.
|
235
244
|
|
@@ -243,22 +252,23 @@ class RelaxedRigidContacts(ContactModel):
|
|
243
252
|
Optional `(n_joints,)` vector of joint forces.
|
244
253
|
|
245
254
|
Returns:
|
246
|
-
A tuple containing the contact forces.
|
255
|
+
A tuple containing as first element the computed contact forces.
|
247
256
|
"""
|
248
257
|
|
249
258
|
# Initialize the model and data this contact model is operating on.
|
250
259
|
# This will raise an exception if either the contact model or the
|
251
260
|
# contact parameters are not compatible.
|
252
261
|
model, data = self.initialize_model_and_data(model=model, data=data)
|
262
|
+
assert isinstance(data.contacts_params, RelaxedRigidContactsParams)
|
253
263
|
|
254
|
-
link_forces = (
|
255
|
-
link_forces
|
264
|
+
link_forces = jnp.atleast_2d(
|
265
|
+
jnp.array(link_forces, dtype=float).squeeze()
|
256
266
|
if link_forces is not None
|
257
267
|
else jnp.zeros((model.number_of_links(), 6))
|
258
268
|
)
|
259
269
|
|
260
|
-
joint_force_references = (
|
261
|
-
joint_force_references
|
270
|
+
joint_force_references = jnp.atleast_1d(
|
271
|
+
jnp.array(joint_force_references, dtype=float).squeeze()
|
262
272
|
if joint_force_references is not None
|
263
273
|
else jnp.zeros(model.number_of_joints())
|
264
274
|
)
|
@@ -271,10 +281,10 @@ class RelaxedRigidContacts(ContactModel):
|
|
271
281
|
joint_force_references=joint_force_references,
|
272
282
|
)
|
273
283
|
|
274
|
-
def
|
284
|
+
def detect_contact(x: jtp.Array, y: jtp.Array, z: jtp.Array) -> jtp.Array:
|
275
285
|
x, y, z = jax.tree.map(jnp.squeeze, (x, y, z))
|
276
286
|
|
277
|
-
n̂ =
|
287
|
+
n̂ = model.terrain.normal(x=x, y=y).squeeze()
|
278
288
|
h = jnp.array([0, 0, z - model.terrain.height(x=x, y=y)])
|
279
289
|
|
280
290
|
return jnp.dot(h, n̂)
|
@@ -286,19 +296,19 @@ class RelaxedRigidContacts(ContactModel):
|
|
286
296
|
)
|
287
297
|
|
288
298
|
# Compute the activation state of the collidable points
|
289
|
-
δ = jax.vmap(
|
299
|
+
δ = jax.vmap(detect_contact)(*position.T)
|
300
|
+
|
301
|
+
# Compute the transforms of the implicit frames corresponding to the
|
302
|
+
# collidable points.
|
303
|
+
W_H_C = js.contact.transforms(model=model, data=data)
|
290
304
|
|
291
305
|
with (
|
292
306
|
references.switch_velocity_representation(VelRepr.Mixed),
|
293
307
|
data.switch_velocity_representation(VelRepr.Mixed),
|
294
308
|
):
|
295
|
-
|
296
|
-
|
297
|
-
|
298
|
-
js.contact.jacobian(model=model, data=data)[:, :3, :], δ
|
299
|
-
)
|
300
|
-
)
|
301
|
-
W_H_C = js.contact.transforms(model=model, data=data)
|
309
|
+
|
310
|
+
BW_ν = data.generalized_velocity()
|
311
|
+
|
302
312
|
BW_ν̇_free = jnp.hstack(
|
303
313
|
js.ode.system_acceleration(
|
304
314
|
model=model,
|
@@ -309,20 +319,31 @@ class RelaxedRigidContacts(ContactModel):
|
|
309
319
|
),
|
310
320
|
)
|
311
321
|
)
|
312
|
-
|
322
|
+
|
323
|
+
M = js.model.free_floating_mass_matrix(model=model, data=data)
|
324
|
+
|
325
|
+
Jl_WC = jnp.vstack(
|
326
|
+
jax.vmap(lambda J, height: J * (height < 0))(
|
327
|
+
js.contact.jacobian(model=model, data=data)[:, :3, :], δ
|
328
|
+
)
|
329
|
+
)
|
330
|
+
|
313
331
|
J̇_WC = jnp.vstack(
|
314
332
|
jax.vmap(lambda J̇, height: J̇ * (height < 0))(
|
315
333
|
js.contact.jacobian_derivative(model=model, data=data)[:, :3], δ
|
316
334
|
),
|
317
335
|
)
|
318
336
|
|
319
|
-
|
320
|
-
|
321
|
-
|
322
|
-
|
323
|
-
|
324
|
-
|
337
|
+
# Compute the regularization terms.
|
338
|
+
a_ref, R, K, D = self._regularizers(
|
339
|
+
model=model,
|
340
|
+
penetration=δ,
|
341
|
+
velocity=velocity,
|
342
|
+
parameters=data.contacts_params,
|
343
|
+
)
|
325
344
|
|
345
|
+
# Compute the Delassus matrix and the free mixed linear acceleration of
|
346
|
+
# the collidable points.
|
326
347
|
G = Jl_WC @ jnp.linalg.lstsq(M, Jl_WC.T)[0]
|
327
348
|
CW_al_free_WC = Jl_WC @ BW_ν̇_free + J̇_WC @ BW_ν
|
328
349
|
|
@@ -330,26 +351,40 @@ class RelaxedRigidContacts(ContactModel):
|
|
330
351
|
A = G + R
|
331
352
|
b = CW_al_free_WC - a_ref
|
332
353
|
|
354
|
+
# Create the objective function to minimize as a lambda computing the cost
|
355
|
+
# from the optimized variables x.
|
333
356
|
objective = lambda x, A, b: jnp.sum(jnp.square(A @ x + b))
|
334
357
|
|
358
|
+
# ========================================
|
359
|
+
# Helper function to run the L-BFGS solver
|
360
|
+
# ========================================
|
361
|
+
|
335
362
|
def run_optimization(
|
336
|
-
init_params: jtp.
|
363
|
+
init_params: jtp.Vector,
|
337
364
|
fun: Callable,
|
338
|
-
opt: optax.
|
339
|
-
maxiter:
|
340
|
-
tol:
|
341
|
-
|
342
|
-
|
365
|
+
opt: optax.GradientTransformationExtraArgs,
|
366
|
+
maxiter: int,
|
367
|
+
tol: float,
|
368
|
+
) -> tuple[jtp.Vector, optax.OptState]:
|
369
|
+
|
370
|
+
# Get the function to compute the loss and the gradient w.r.t. its inputs.
|
343
371
|
value_and_grad_fn = optax.value_and_grad_from_state(fun)
|
344
372
|
|
345
|
-
|
373
|
+
# Initialize the carry of the following loop.
|
374
|
+
OptimizationCarry = tuple[jtp.Vector, optax.OptState]
|
375
|
+
init_carry: OptimizationCarry = (init_params, opt.init(params=init_params))
|
376
|
+
|
377
|
+
def step(carry: OptimizationCarry) -> OptimizationCarry:
|
378
|
+
|
346
379
|
params, state = carry
|
380
|
+
|
347
381
|
value, grad = value_and_grad_fn(
|
348
382
|
params,
|
349
383
|
state=state,
|
350
384
|
A=A,
|
351
385
|
b=b,
|
352
386
|
)
|
387
|
+
|
353
388
|
updates, state = opt.update(
|
354
389
|
updates=grad,
|
355
390
|
state=state,
|
@@ -360,22 +395,32 @@ class RelaxedRigidContacts(ContactModel):
|
|
360
395
|
A=A,
|
361
396
|
b=b,
|
362
397
|
)
|
398
|
+
|
363
399
|
params = optax.apply_updates(params, updates)
|
400
|
+
|
364
401
|
return params, state
|
365
402
|
|
366
|
-
def continuing_criterion(carry):
|
403
|
+
def continuing_criterion(carry: OptimizationCarry) -> jtp.Bool:
|
404
|
+
|
367
405
|
_, state = carry
|
406
|
+
|
368
407
|
iter_num = optax.tree_utils.tree_get(state, "count")
|
369
408
|
grad = optax.tree_utils.tree_get(state, "grad")
|
370
409
|
err = optax.tree_utils.tree_l2_norm(grad)
|
410
|
+
|
371
411
|
return (iter_num == 0) | ((iter_num < maxiter) & (err >= tol))
|
372
412
|
|
373
|
-
init_carry = (init_params, opt.init(init_params))
|
374
413
|
final_params, final_state = jax.lax.while_loop(
|
375
414
|
continuing_criterion, step, init_carry
|
376
415
|
)
|
416
|
+
|
377
417
|
return final_params, final_state
|
378
418
|
|
419
|
+
# ======================================
|
420
|
+
# Compute the contact forces with L-BFGS
|
421
|
+
# ======================================
|
422
|
+
|
423
|
+
# Initialize the optimized forces with a linear Hunt/Crossley model.
|
379
424
|
init_params = (
|
380
425
|
K[:, jnp.newaxis] * jnp.zeros_like(position).at[:, 2].set(δ)
|
381
426
|
+ D[:, jnp.newaxis] * velocity
|
@@ -390,28 +435,30 @@ class RelaxedRigidContacts(ContactModel):
|
|
390
435
|
maxiter = solver_options.pop("maxiter")
|
391
436
|
|
392
437
|
# Compute the 3D linear force in C[W] frame.
|
393
|
-
|
438
|
+
solution, _ = run_optimization(
|
394
439
|
init_params=init_params,
|
395
|
-
A=A,
|
396
|
-
b=b,
|
397
|
-
maxiter=maxiter,
|
398
|
-
opt=optax.lbfgs(**solver_options),
|
399
440
|
fun=objective,
|
441
|
+
opt=optax.lbfgs(**solver_options),
|
400
442
|
tol=tol,
|
443
|
+
maxiter=maxiter,
|
401
444
|
)
|
402
445
|
|
403
|
-
|
404
|
-
|
405
|
-
|
406
|
-
|
407
|
-
|
408
|
-
|
409
|
-
|
410
|
-
|
411
|
-
|
412
|
-
|
446
|
+
# Reshape the optimized solution to be a matrix of 3D contact forces.
|
447
|
+
CW_fl_C = solution.reshape(-1, 3)
|
448
|
+
|
449
|
+
# Convert the contact forces from mixed to inertial-fixed representation.
|
450
|
+
W_f_C = jax.vmap(
|
451
|
+
lambda CW_fl_C, W_H_C: (
|
452
|
+
ModelDataWithVelocityRepresentation.other_representation_to_inertial(
|
453
|
+
array=jnp.zeros(6).at[0:3].set(CW_fl_C),
|
454
|
+
transform=W_H_C,
|
455
|
+
other_representation=VelRepr.Mixed,
|
456
|
+
is_force=True,
|
457
|
+
)
|
458
|
+
),
|
459
|
+
)(CW_fl_C, W_H_C)
|
413
460
|
|
414
|
-
return W_f_C, (
|
461
|
+
return W_f_C, ()
|
415
462
|
|
416
463
|
@staticmethod
|
417
464
|
def _regularizers(
|
@@ -433,13 +480,28 @@ class RelaxedRigidContacts(ContactModel):
|
|
433
480
|
A tuple containing the reference acceleration, the regularization matrix, the stiffness, and the damping.
|
434
481
|
"""
|
435
482
|
|
436
|
-
|
437
|
-
|
483
|
+
# Extract the parameters of the contact model.
|
484
|
+
Ω, ζ, ξ_min, ξ_max, width, mid, p, K, D, μ = (
|
485
|
+
getattr(parameters, field)
|
486
|
+
for field in (
|
487
|
+
"time_constant",
|
488
|
+
"damping_coefficient",
|
489
|
+
"d_min",
|
490
|
+
"d_max",
|
491
|
+
"width",
|
492
|
+
"midpoint",
|
493
|
+
"power",
|
494
|
+
"stiffness",
|
495
|
+
"damping",
|
496
|
+
"mu",
|
497
|
+
)
|
438
498
|
)
|
439
499
|
|
440
|
-
|
441
|
-
|
442
|
-
|
500
|
+
# Compute the 6D inertia matrices of all links.
|
501
|
+
M_L = js.model.link_spatial_inertia_matrices(model=model)
|
502
|
+
|
503
|
+
def imp_aref(
|
504
|
+
penetration: jtp.Array, velocity: jtp.Array
|
443
505
|
) -> tuple[jtp.Array, jtp.Array]:
|
444
506
|
"""
|
445
507
|
Calculates impedance and offset acceleration in constraint frame.
|
@@ -474,7 +536,7 @@ class RelaxedRigidContacts(ContactModel):
|
|
474
536
|
|
475
537
|
return imp, a_ref, jnp.atleast_1d(K_f), jnp.atleast_1d(D_f)
|
476
538
|
|
477
|
-
def
|
539
|
+
def compute_row(
|
478
540
|
*,
|
479
541
|
link_idx: jtp.Float,
|
480
542
|
penetration: jtp.Array,
|
@@ -482,7 +544,7 @@ class RelaxedRigidContacts(ContactModel):
|
|
482
544
|
) -> tuple[jtp.Array, jtp.Array]:
|
483
545
|
|
484
546
|
# Compute the reference acceleration.
|
485
|
-
ξ, a_ref, K, D =
|
547
|
+
ξ, a_ref, K, D = imp_aref(
|
486
548
|
penetration=penetration,
|
487
549
|
velocity=velocity,
|
488
550
|
)
|
@@ -496,12 +558,10 @@ class RelaxedRigidContacts(ContactModel):
|
|
496
558
|
|
497
559
|
return jax.tree.map(lambda x: x * (penetration < 0), (a_ref, R, K, D))
|
498
560
|
|
499
|
-
M_L = js.model.link_spatial_inertia_matrices(model=model)
|
500
|
-
|
501
561
|
a_ref, R, K, D = jax.tree.map(
|
502
|
-
jnp.concatenate,
|
503
|
-
(
|
504
|
-
*jax.vmap(
|
562
|
+
f=jnp.concatenate,
|
563
|
+
tree=(
|
564
|
+
*jax.vmap(compute_row)(
|
505
565
|
link_idx=jnp.array(
|
506
566
|
model.kin_dyn_parameters.contact_parameters.body
|
507
567
|
),
|
@@ -510,4 +570,5 @@ class RelaxedRigidContacts(ContactModel):
|
|
510
570
|
),
|
511
571
|
),
|
512
572
|
)
|
573
|
+
|
513
574
|
return a_ref, jnp.diag(R), K, D
|
jaxsim/rbda/contacts/rigid.py
CHANGED
@@ -13,6 +13,7 @@ from jaxsim import logging
|
|
13
13
|
from jaxsim.api.common import ModelDataWithVelocityRepresentation, VelRepr
|
14
14
|
from jaxsim.terrain import FlatTerrain, Terrain
|
15
15
|
|
16
|
+
from . import common
|
16
17
|
from .common import ContactModel, ContactsParams
|
17
18
|
|
18
19
|
try:
|
@@ -170,46 +171,6 @@ class RigidContacts(ContactModel):
|
|
170
171
|
_solver_options_values=tuple(solver_options.values()),
|
171
172
|
)
|
172
173
|
|
173
|
-
@staticmethod
|
174
|
-
def detect_contacts(
|
175
|
-
W_p_C: jtp.ArrayLike,
|
176
|
-
terrain_height: jtp.ArrayLike,
|
177
|
-
) -> tuple[jtp.Vector, jtp.Vector]:
|
178
|
-
"""
|
179
|
-
Detect contacts between the collidable points and the terrain.
|
180
|
-
|
181
|
-
Args:
|
182
|
-
W_p_C: The position of the collidable points.
|
183
|
-
terrain_height: The height of the terrain at the collidable point position.
|
184
|
-
|
185
|
-
Returns:
|
186
|
-
A tuple containing the activation state of the collidable points
|
187
|
-
and the contact penetration depth h.
|
188
|
-
"""
|
189
|
-
|
190
|
-
# TODO: reduce code duplication with js.contact.in_contact
|
191
|
-
def detect_contact(
|
192
|
-
W_p_C: jtp.ArrayLike,
|
193
|
-
terrain_height: jtp.FloatLike,
|
194
|
-
) -> tuple[jtp.Bool, jtp.Float]:
|
195
|
-
"""
|
196
|
-
Detect contacts between the collidable points and the terrain.
|
197
|
-
"""
|
198
|
-
|
199
|
-
# Unpack the position of the collidable point.
|
200
|
-
_, _, pz = W_p_C.squeeze()
|
201
|
-
|
202
|
-
inactive = pz > terrain_height
|
203
|
-
|
204
|
-
# Compute contact penetration depth
|
205
|
-
h = jnp.maximum(0.0, terrain_height - pz)
|
206
|
-
|
207
|
-
return inactive, h
|
208
|
-
|
209
|
-
inactive_collidable_points, h = jax.vmap(detect_contact)(W_p_C, terrain_height)
|
210
|
-
|
211
|
-
return inactive_collidable_points, h
|
212
|
-
|
213
174
|
@staticmethod
|
214
175
|
def compute_impact_velocity(
|
215
176
|
inactive_collidable_points: jtp.ArrayLike,
|
@@ -281,7 +242,7 @@ class RigidContacts(ContactModel):
|
|
281
242
|
*,
|
282
243
|
link_forces: jtp.MatrixLike | None = None,
|
283
244
|
joint_force_references: jtp.VectorLike | None = None,
|
284
|
-
) -> tuple[jtp.
|
245
|
+
) -> tuple[jtp.Matrix, tuple]:
|
285
246
|
"""
|
286
247
|
Compute the contact forces.
|
287
248
|
|
@@ -295,36 +256,41 @@ class RigidContacts(ContactModel):
|
|
295
256
|
Optional `(n_joints,)` vector of joint forces.
|
296
257
|
|
297
258
|
Returns:
|
298
|
-
A tuple containing the contact forces.
|
259
|
+
A tuple containing as first element the computed contact forces.
|
299
260
|
"""
|
300
261
|
|
301
262
|
# Initialize the model and data this contact model is operating on.
|
302
263
|
# This will raise an exception if either the contact model or the
|
303
264
|
# contact parameters are not compatible.
|
304
265
|
model, data = self.initialize_model_and_data(model=model, data=data)
|
266
|
+
assert isinstance(data.contacts_params, RigidContactsParams)
|
305
267
|
|
306
|
-
# Import qpax just in this method
|
268
|
+
# Import qpax privately just in this method.
|
307
269
|
import qpax
|
308
270
|
|
309
|
-
link_forces = (
|
310
|
-
link_forces
|
271
|
+
link_forces = jnp.atleast_2d(
|
272
|
+
jnp.array(link_forces, dtype=float).squeeze()
|
311
273
|
if link_forces is not None
|
312
274
|
else jnp.zeros((model.number_of_links(), 6))
|
313
275
|
)
|
314
276
|
|
315
|
-
joint_force_references = (
|
316
|
-
joint_force_references
|
277
|
+
joint_force_references = jnp.atleast_1d(
|
278
|
+
jnp.array(joint_force_references, dtype=float).squeeze()
|
317
279
|
if joint_force_references is not None
|
318
280
|
else jnp.zeros((model.number_of_joints(),))
|
319
281
|
)
|
320
282
|
|
321
|
-
# Compute kin-dyn quantities used in the contact model
|
283
|
+
# Compute kin-dyn quantities used in the contact model.
|
322
284
|
with data.switch_velocity_representation(VelRepr.Mixed):
|
285
|
+
|
286
|
+
BW_ν = data.generalized_velocity()
|
287
|
+
|
323
288
|
M = js.model.free_floating_mass_matrix(model=model, data=data)
|
289
|
+
|
324
290
|
J_WC = js.contact.jacobian(model=model, data=data)
|
291
|
+
J̇_WC = js.contact.jacobian_derivative(model=model, data=data)
|
292
|
+
|
325
293
|
W_H_C = js.contact.transforms(model=model, data=data)
|
326
|
-
J̇_WC_BW = js.contact.jacobian_derivative(model=model, data=data)
|
327
|
-
BW_ν = data.generalized_velocity()
|
328
294
|
|
329
295
|
# Compute the position and linear velocities (mixed representation) of
|
330
296
|
# all collidable points belonging to the robot.
|
@@ -332,23 +298,16 @@ class RigidContacts(ContactModel):
|
|
332
298
|
model=model, data=data
|
333
299
|
)
|
334
300
|
|
335
|
-
|
336
|
-
n_collidable_points = model.kin_dyn_parameters.contact_parameters.
|
301
|
+
# Get the number of collidable points.
|
302
|
+
n_collidable_points = len(model.kin_dyn_parameters.contact_parameters.body)
|
337
303
|
|
338
|
-
# Compute the
|
339
|
-
|
340
|
-
|
341
|
-
|
342
|
-
)
|
343
|
-
|
344
|
-
# Compute the Delassus matrix.
|
345
|
-
delassus_matrix = RigidContacts._delassus_matrix(M=M, J_WC=J_WC)
|
346
|
-
|
347
|
-
# Add regularization for better numerical conditioning.
|
348
|
-
delassus_matrix = delassus_matrix + self.regularization_delassus * jnp.eye(
|
349
|
-
delassus_matrix.shape[0]
|
304
|
+
# Compute the penetration depth and velocity of the collidable points.
|
305
|
+
# Note that this function considers the penetration in the normal direction.
|
306
|
+
δ, δ_dot, n̂ = jax.vmap(common.compute_penetration_data, in_axes=(0, 0, None))(
|
307
|
+
position, velocity, model.terrain
|
350
308
|
)
|
351
309
|
|
310
|
+
# Build a references object to simplify converting link forces.
|
352
311
|
references = js.references.JaxSimModelReferences.build(
|
353
312
|
model=model,
|
354
313
|
data=data,
|
@@ -357,10 +316,12 @@ class RigidContacts(ContactModel):
|
|
357
316
|
joint_force_references=joint_force_references,
|
358
317
|
)
|
359
318
|
|
319
|
+
# Compute the generalized free acceleration.
|
360
320
|
with (
|
361
321
|
references.switch_velocity_representation(VelRepr.Mixed),
|
362
322
|
data.switch_velocity_representation(VelRepr.Mixed),
|
363
323
|
):
|
324
|
+
|
364
325
|
BW_ν̇_free = jnp.hstack(
|
365
326
|
js.ode.system_acceleration(
|
366
327
|
model=model,
|
@@ -372,64 +333,74 @@ class RigidContacts(ContactModel):
|
|
372
333
|
)
|
373
334
|
)
|
374
335
|
|
336
|
+
# Compute the free linear acceleration of the collidable points.
|
337
|
+
# Since we use doubly-mixed jacobian, this corresponds to W_p̈_C.
|
375
338
|
free_contact_acc = RigidContacts._linear_acceleration_of_collidable_points(
|
376
339
|
BW_nu=BW_ν,
|
377
340
|
BW_nu_dot=BW_ν̇_free,
|
378
341
|
CW_J_WC_BW=J_WC,
|
379
|
-
CW_J_dot_WC_BW=J̇
|
342
|
+
CW_J_dot_WC_BW=J̇_WC,
|
380
343
|
).flatten()
|
381
344
|
|
382
|
-
# Compute stabilization term
|
383
|
-
ḣ = velocity[:, 2].squeeze()
|
345
|
+
# Compute stabilization term.
|
384
346
|
baumgarte_term = RigidContacts._compute_baumgarte_stabilization_term(
|
385
|
-
inactive_collidable_points=
|
386
|
-
|
387
|
-
|
388
|
-
|
389
|
-
|
347
|
+
inactive_collidable_points=(δ <= 0),
|
348
|
+
δ=δ,
|
349
|
+
δ_dot=δ_dot,
|
350
|
+
n=n̂,
|
351
|
+
K=data.contacts_params.K,
|
352
|
+
D=data.contacts_params.D,
|
390
353
|
).flatten()
|
391
354
|
|
392
|
-
|
355
|
+
# Compute the Delassus matrix.
|
356
|
+
delassus_matrix = RigidContacts._delassus_matrix(M=M, J_WC=J_WC)
|
357
|
+
|
358
|
+
# Initialize regularization term of the Delassus matrix for
|
359
|
+
# better numerical conditioning.
|
360
|
+
Iε = self.regularization_delassus * jnp.eye(delassus_matrix.shape[0])
|
361
|
+
|
362
|
+
# Construct the quadratic cost function.
|
363
|
+
Q = delassus_matrix + Iε
|
364
|
+
q = free_contact_acc - baumgarte_term
|
393
365
|
|
394
|
-
#
|
395
|
-
Q = delassus_matrix
|
396
|
-
q = free_contact_acc
|
366
|
+
# Construct the inequality constraints.
|
397
367
|
G = RigidContacts._compute_ineq_constraint_matrix(
|
398
|
-
inactive_collidable_points=
|
368
|
+
inactive_collidable_points=(δ <= 0), mu=data.contacts_params.mu
|
399
369
|
)
|
400
370
|
h_bounds = RigidContacts._compute_ineq_bounds(
|
401
371
|
n_collidable_points=n_collidable_points
|
402
372
|
)
|
373
|
+
|
374
|
+
# Construct the equality constraints.
|
403
375
|
A = jnp.zeros((0, 3 * n_collidable_points))
|
404
376
|
b = jnp.zeros((0,))
|
405
377
|
|
406
|
-
# Solve the optimization problem
|
407
|
-
|
378
|
+
# Solve the following optimization problem with qpax:
|
379
|
+
#
|
380
|
+
# min_{x} 0.5 x⊤ Q x + q⊤ x
|
381
|
+
#
|
382
|
+
# s.t. A x = b
|
383
|
+
# G x ≤ h
|
384
|
+
#
|
385
|
+
# TODO: add possibility to notify if the QP problem did not converge.
|
386
|
+
solution, _, _, _, converged, _ = qpax.solve_qp( # noqa: F841
|
408
387
|
Q=Q, q=q, A=A, b=b, G=G, h=h_bounds, **self.solver_options
|
409
388
|
)
|
410
389
|
|
411
|
-
|
412
|
-
|
413
|
-
# Transform linear contact forces to 6D
|
414
|
-
CW_f_C = jnp.hstack(
|
415
|
-
(
|
416
|
-
f_C_lin,
|
417
|
-
jnp.zeros((f_C_lin.shape[0], 3)),
|
418
|
-
)
|
419
|
-
)
|
390
|
+
# Reshape the optimized solution to be a matrix of 3D contact forces.
|
391
|
+
CW_fl_C = solution.reshape(-1, 3)
|
420
392
|
|
421
|
-
#
|
393
|
+
# Convert the contact forces from mixed to inertial-fixed representation.
|
422
394
|
W_f_C = jax.vmap(
|
423
|
-
lambda
|
424
|
-
|
425
|
-
|
426
|
-
|
427
|
-
|
395
|
+
lambda CW_fl_C, W_H_C: (
|
396
|
+
ModelDataWithVelocityRepresentation.other_representation_to_inertial(
|
397
|
+
array=jnp.zeros(6).at[0:3].set(CW_fl_C),
|
398
|
+
transform=W_H_C,
|
399
|
+
other_representation=VelRepr.Mixed,
|
400
|
+
is_force=True,
|
401
|
+
)
|
428
402
|
),
|
429
|
-
)(
|
430
|
-
CW_f_C,
|
431
|
-
W_H_C,
|
432
|
-
)
|
403
|
+
)(CW_fl_C, W_H_C)
|
433
404
|
|
434
405
|
return W_f_C, ()
|
435
406
|
|
@@ -438,6 +409,7 @@ class RigidContacts(ContactModel):
|
|
438
409
|
M: jtp.MatrixLike,
|
439
410
|
J_WC: jtp.MatrixLike,
|
440
411
|
) -> jtp.Matrix:
|
412
|
+
|
441
413
|
sl = jnp.s_[:, 0:3, :]
|
442
414
|
J_WC_lin = jnp.vstack(J_WC[sl])
|
443
415
|
|
@@ -448,6 +420,7 @@ class RigidContacts(ContactModel):
|
|
448
420
|
def _compute_ineq_constraint_matrix(
|
449
421
|
inactive_collidable_points: jtp.Vector, mu: jtp.FloatLike
|
450
422
|
) -> jtp.Matrix:
|
423
|
+
|
451
424
|
def compute_G_single_point(mu: float, c: float) -> jtp.Matrix:
|
452
425
|
"""
|
453
426
|
Compute the inequality constraint matrix for a single collidable point
|
@@ -475,6 +448,7 @@ class RigidContacts(ContactModel):
|
|
475
448
|
|
476
449
|
@staticmethod
|
477
450
|
def _compute_ineq_bounds(n_collidable_points: jtp.FloatLike) -> jtp.Vector:
|
451
|
+
|
478
452
|
n_constraints = 6 * n_collidable_points
|
479
453
|
return jnp.zeros(shape=(n_constraints,))
|
480
454
|
|
@@ -485,45 +459,50 @@ class RigidContacts(ContactModel):
|
|
485
459
|
CW_J_WC_BW: jtp.MatrixLike,
|
486
460
|
CW_J_dot_WC_BW: jtp.MatrixLike,
|
487
461
|
) -> jtp.Matrix:
|
488
|
-
|
462
|
+
|
489
463
|
BW_ν = BW_nu
|
490
464
|
BW_ν̇ = BW_nu_dot
|
465
|
+
CW_J̇_WC_BW = CW_J_dot_WC_BW
|
491
466
|
|
467
|
+
# Compute the linear acceleration of the collidable points.
|
468
|
+
# Since we use doubly-mixed jacobians, this corresponds to W_p̈_C.
|
492
469
|
CW_a_WC = jnp.vstack(CW_J̇_WC_BW) @ BW_ν + jnp.vstack(CW_J_WC_BW) @ BW_ν̇
|
493
|
-
CW_a_WC = CW_a_WC.reshape(-1, 6)
|
494
470
|
|
471
|
+
CW_a_WC = CW_a_WC.reshape(-1, 6)
|
495
472
|
return CW_a_WC[:, 0:3].squeeze()
|
496
473
|
|
497
474
|
@staticmethod
|
498
475
|
def _compute_baumgarte_stabilization_term(
|
499
476
|
inactive_collidable_points: jtp.ArrayLike,
|
500
|
-
|
501
|
-
|
477
|
+
δ: jtp.ArrayLike,
|
478
|
+
δ_dot: jtp.ArrayLike,
|
479
|
+
n: jtp.ArrayLike,
|
502
480
|
K: jtp.FloatLike,
|
503
481
|
D: jtp.FloatLike,
|
504
482
|
) -> jtp.Array:
|
505
|
-
|
483
|
+
|
484
|
+
def baumgarte_stabilization_of_single_point(
|
506
485
|
inactive: jtp.BoolLike,
|
507
|
-
|
508
|
-
|
486
|
+
δ: jtp.FloatLike,
|
487
|
+
δ_dot: jtp.FloatLike,
|
488
|
+
n: jtp.ArrayLike,
|
509
489
|
k_baumgarte: jtp.FloatLike,
|
510
490
|
d_baumgarte: jtp.FloatLike,
|
511
491
|
) -> jtp.Array:
|
492
|
+
|
512
493
|
baumgarte_term = jax.lax.cond(
|
513
494
|
inactive,
|
514
|
-
lambda
|
515
|
-
|
516
|
-
|
517
|
-
|
518
|
-
|
519
|
-
k_baumgarte,
|
520
|
-
d_baumgarte,
|
521
|
-
),
|
495
|
+
lambda δ, δ_dot, n, K, D: jnp.zeros(3),
|
496
|
+
# This is equivalent to: K*(pT - p)⋅n̂ + D*(0 - v)⋅n̂,
|
497
|
+
# where pT is the point on the terrain surface vertical to p.
|
498
|
+
lambda δ, δ_dot, n, K, D: (K * δ + D * δ_dot) * n,
|
499
|
+
*(δ, δ_dot, n, k_baumgarte, d_baumgarte),
|
522
500
|
)
|
501
|
+
|
523
502
|
return baumgarte_term
|
524
503
|
|
525
504
|
baumgarte_term = jax.vmap(
|
526
|
-
|
527
|
-
)(inactive_collidable_points,
|
505
|
+
baumgarte_stabilization_of_single_point, in_axes=(0, 0, 0, 0, None, None)
|
506
|
+
)(inactive_collidable_points, δ, δ_dot, n, K, D)
|
528
507
|
|
529
508
|
return baumgarte_term
|
jaxsim/rbda/contacts/soft.py
CHANGED
@@ -14,7 +14,7 @@ from jaxsim import logging
|
|
14
14
|
from jaxsim.math import StandardGravity
|
15
15
|
from jaxsim.terrain import FlatTerrain, Terrain
|
16
16
|
|
17
|
-
from .
|
17
|
+
from . import common
|
18
18
|
|
19
19
|
try:
|
20
20
|
from typing import Self
|
@@ -23,7 +23,7 @@ except ImportError:
|
|
23
23
|
|
24
24
|
|
25
25
|
@jax_dataclasses.pytree_dataclass
|
26
|
-
class SoftContactsParams(ContactsParams):
|
26
|
+
class SoftContactsParams(common.ContactsParams):
|
27
27
|
"""Parameters of the soft contacts model."""
|
28
28
|
|
29
29
|
K: jtp.Float = dataclasses.field(
|
@@ -161,7 +161,9 @@ class SoftContactsParams(ContactsParams):
|
|
161
161
|
f_average = m * g / number_of_active_collidable_points_steady_state
|
162
162
|
|
163
163
|
# Compute the stiffness to get the desired steady-state penetration.
|
164
|
-
|
164
|
+
# Note that this is dependent on the non-linear exponent used in
|
165
|
+
# the damping term of the Hunt/Crossley model.
|
166
|
+
K = f_average / jnp.power(δ_max, 1 + p)
|
165
167
|
|
166
168
|
# Compute the damping using the damping ratio.
|
167
169
|
critical_damping = 2 * jnp.sqrt(K * m)
|
@@ -189,7 +191,7 @@ class SoftContactsParams(ContactsParams):
|
|
189
191
|
|
190
192
|
|
191
193
|
@jax_dataclasses.pytree_dataclass
|
192
|
-
class SoftContacts(ContactModel):
|
194
|
+
class SoftContacts(common.ContactModel):
|
193
195
|
"""Soft contacts model."""
|
194
196
|
|
195
197
|
parameters: SoftContactsParams = dataclasses.field(
|
@@ -277,9 +279,7 @@ class SoftContacts(ContactModel):
|
|
277
279
|
μ = mu
|
278
280
|
|
279
281
|
# Compute the penetration depth, its rate, and the considered terrain normal.
|
280
|
-
δ, δ̇, n̂ =
|
281
|
-
p=W_p_C, v=W_ṗ_C, terrain=terrain
|
282
|
-
)
|
282
|
+
δ, δ̇, n̂ = common.compute_penetration_data(p=W_p_C, v=W_ṗ_C, terrain=terrain)
|
283
283
|
|
284
284
|
# There are few operations like computing the norm of a vector with zero length
|
285
285
|
# or computing the square root of zero that are problematic in an AD context.
|
@@ -423,7 +423,18 @@ class SoftContacts(ContactModel):
|
|
423
423
|
self,
|
424
424
|
model: js.model.JaxSimModel,
|
425
425
|
data: js.data.JaxSimModelData,
|
426
|
-
) -> tuple[jtp.
|
426
|
+
) -> tuple[jtp.Matrix, tuple[jtp.Matrix]]:
|
427
|
+
"""
|
428
|
+
Compute the contact forces.
|
429
|
+
|
430
|
+
Args:
|
431
|
+
model: The model to consider.
|
432
|
+
data: The data of the considered model.
|
433
|
+
|
434
|
+
Returns:
|
435
|
+
A tuple containing as first element the computed contact forces, and as
|
436
|
+
second element the derivative of the material deformation.
|
437
|
+
"""
|
427
438
|
|
428
439
|
# Initialize the model and data this contact model is operating on.
|
429
440
|
# This will raise an exception if either the contact model or the
|
@@ -444,36 +455,9 @@ class SoftContacts(ContactModel):
|
|
444
455
|
position=p,
|
445
456
|
velocity=v,
|
446
457
|
tangential_deformation=m,
|
447
|
-
parameters=
|
448
|
-
terrain=
|
458
|
+
parameters=data.contacts_params,
|
459
|
+
terrain=model.terrain,
|
449
460
|
)
|
450
461
|
)(W_p_C, W_ṗ_C, m)
|
451
462
|
|
452
463
|
return W_f, (ṁ,)
|
453
|
-
|
454
|
-
@staticmethod
|
455
|
-
@jax.jit
|
456
|
-
def compute_penetration_data(
|
457
|
-
p: jtp.VectorLike,
|
458
|
-
v: jtp.VectorLike,
|
459
|
-
terrain: jaxsim.terrain.Terrain,
|
460
|
-
) -> tuple[jtp.Float, jtp.Float, jtp.Vector]:
|
461
|
-
|
462
|
-
# Pre-process the position and the linear velocity of the collidable point.
|
463
|
-
W_ṗ_C = jnp.array(v).squeeze()
|
464
|
-
px, py, pz = jnp.array(p).squeeze()
|
465
|
-
|
466
|
-
# Compute the terrain normal and the contact depth.
|
467
|
-
n̂ = terrain.normal(x=px, y=py).squeeze()
|
468
|
-
h = jnp.array([0, 0, terrain.height(x=px, y=py) - pz])
|
469
|
-
|
470
|
-
# Compute the penetration depth normal to the terrain.
|
471
|
-
δ = jnp.maximum(0.0, jnp.dot(h, n̂))
|
472
|
-
|
473
|
-
# Compute the penetration normal velocity.
|
474
|
-
δ̇ = -jnp.dot(W_ṗ_C, n̂)
|
475
|
-
|
476
|
-
# Enforce the penetration rate to be zero when the penetration depth is zero.
|
477
|
-
δ̇ = jnp.where(δ > 0, δ̇, 0.0)
|
478
|
-
|
479
|
-
return δ, δ̇, n̂
|
@@ -195,7 +195,7 @@ class ViscoElasticContacts(common.ContactModel):
|
|
195
195
|
default_factory=FlatTerrain
|
196
196
|
)
|
197
197
|
|
198
|
-
max_squarings: jax_dataclasses.Static[int] = 25
|
198
|
+
max_squarings: jax_dataclasses.Static[int] = dataclasses.field(default=25)
|
199
199
|
|
200
200
|
@classmethod
|
201
201
|
def build(
|
@@ -239,7 +239,7 @@ class ViscoElasticContacts(common.ContactModel):
|
|
239
239
|
parameters=parameters,
|
240
240
|
terrain=terrain or cls.__dataclass_fields__["terrain"].default_factory(),
|
241
241
|
max_squarings=int(
|
242
|
-
max_squarings or cls.__dataclass_fields__["max_squarings"].default
|
242
|
+
max_squarings or cls.__dataclass_fields__["max_squarings"].default
|
243
243
|
),
|
244
244
|
)
|
245
245
|
|
@@ -266,7 +266,7 @@ class ViscoElasticContacts(common.ContactModel):
|
|
266
266
|
dt: jtp.FloatLike | None = None,
|
267
267
|
link_forces: jtp.MatrixLike | None = None,
|
268
268
|
joint_force_references: jtp.VectorLike | None = None,
|
269
|
-
) -> tuple[jtp.
|
269
|
+
) -> tuple[jtp.Matrix, tuple[jtp.Matrix, jtp.Matrix]]:
|
270
270
|
"""
|
271
271
|
Compute the contact forces.
|
272
272
|
|
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.1
|
2
2
|
Name: jaxsim
|
3
|
-
Version: 0.4.3.
|
3
|
+
Version: 0.4.3.dev200
|
4
4
|
Summary: A differentiable physics engine and multibody dynamics library for control and robot learning.
|
5
5
|
Author-email: Diego Ferigo <dgferigo@gmail.com>
|
6
6
|
Maintainer-email: Diego Ferigo <dgferigo@gmail.com>, Filippo Luca Ferretti <filippo.ferretti@iit.it>
|
@@ -1,18 +1,18 @@
|
|
1
1
|
jaxsim/__init__.py,sha256=opgtbhhd1kDsHI4H1vOd3loMPDRi884yQ3tohfFGfNc,3382
|
2
|
-
jaxsim/_version.py,sha256=
|
2
|
+
jaxsim/_version.py,sha256=WDziMJEeSmuE81cozOtxmazlb4qAX6VPTrKOR0f3akg,428
|
3
3
|
jaxsim/exceptions.py,sha256=vSoScaRD4nvh6jltgK9Ry5pKnE0O5hb4_yI_pk_fvR8,2175
|
4
4
|
jaxsim/logging.py,sha256=STI-D_upXZYX-ZezLrlJJ0UlD5YspST0vZ_DcIwkzO4,1553
|
5
5
|
jaxsim/typing.py,sha256=2HXy9hgazPXjofi1vLQ09ZubPtgVmg80U9NKmZ6NYiI,761
|
6
6
|
jaxsim/api/__init__.py,sha256=8eV22t2S3UwNyCg8karPetG1dmX1VDBXkyv28_FwNQA,210
|
7
7
|
jaxsim/api/com.py,sha256=m-p3EJDhpnMTlXKplfbZE_aH9NqX_VyLlAE3vUhc6l4,13642
|
8
8
|
jaxsim/api/common.py,sha256=SNgxq42r6eF_-aPszvOjUYkGwXOzz4hKmhDwEUkscFQ,6650
|
9
|
-
jaxsim/api/contact.py,sha256=
|
10
|
-
jaxsim/api/data.py,sha256=
|
9
|
+
jaxsim/api/contact.py,sha256=Egc62310ljn5goXlswwJYSB-LyW6M5gmPoT_a3mkd7U,25812
|
10
|
+
jaxsim/api/data.py,sha256=gQX6hfEaw0ooJYvpr5f8UvEJwqhtflEK_NHWn9XgTZY,28935
|
11
11
|
jaxsim/api/frame.py,sha256=KS8A5wRfjxhe9NgcVo2QA516iP5zky7UVnWxG7nTa7c,12911
|
12
12
|
jaxsim/api/joint.py,sha256=lksT1Doxz2jknHyhb4ls20z6f6dofpZSzBJtVacZXAE,7129
|
13
13
|
jaxsim/api/kin_dyn_parameters.py,sha256=thJbz9XhpXgom23S6MXX2ugxGoAD-k947ZMAHDisy2w,29620
|
14
14
|
jaxsim/api/link.py,sha256=LAA6ZMQXkWomXeptURBtc7z3_xDZ2BBnBMhVrohh0bE,18621
|
15
|
-
jaxsim/api/model.py,sha256
|
15
|
+
jaxsim/api/model.py,sha256=s2i4obxMjZ_XntJgT0dEV57LCo0GIC7VppUnxsqC1fc,69704
|
16
16
|
jaxsim/api/ode.py,sha256=J_WuaoPl3ZY-yvTrCQun-rQoIAv_duynSXAGxqx93sg,14211
|
17
17
|
jaxsim/api/ode_data.py,sha256=1SD-x-lYk_YSEnVpxTLd69uOKC0mFUj44ZqpSmEDOxw,20190
|
18
18
|
jaxsim/api/references.py,sha256=fW77LitZ8DYgT6ZmUInJfm5luBV1mTcqcNRiC_i79og,20862
|
@@ -53,20 +53,20 @@ jaxsim/rbda/forward_kinematics.py,sha256=2GmEoWsrioVl_SAbKRKfhOLz57pY4aR81PKRdul
|
|
53
53
|
jaxsim/rbda/jacobian.py,sha256=p0EV_8cLzLVV-93VKznT7VPuRj8W7h7rQWkPlWJXfCA,11023
|
54
54
|
jaxsim/rbda/rnea.py,sha256=CLfqs9XFVaD-hvkLABshDAfdw5bm_AMV3UVAQ_IvURQ,7542
|
55
55
|
jaxsim/rbda/utils.py,sha256=eeT21Y4DiiyhrdF0lUE_VvRuwru5-rR7yOlOlWzCCWE,5381
|
56
|
-
jaxsim/rbda/contacts/__init__.py,sha256=
|
57
|
-
jaxsim/rbda/contacts/common.py,sha256=
|
58
|
-
jaxsim/rbda/contacts/relaxed_rigid.py,sha256=
|
59
|
-
jaxsim/rbda/contacts/rigid.py,sha256=
|
60
|
-
jaxsim/rbda/contacts/soft.py,sha256=
|
61
|
-
jaxsim/rbda/contacts/visco_elastic.py,sha256=
|
56
|
+
jaxsim/rbda/contacts/__init__.py,sha256=L5MM-2pv76YPGzxExdz2EErgGBATuAjYnNHlq5QOySs,503
|
57
|
+
jaxsim/rbda/contacts/common.py,sha256=iywCQtesrnrwywRQv8cjyot2bG11dT_iONyF8OJztIA,5798
|
58
|
+
jaxsim/rbda/contacts/relaxed_rigid.py,sha256=TR81tJ4ipcpvPnwlfkpyNDhvWizpEG542SFVu_CwHRU,19614
|
59
|
+
jaxsim/rbda/contacts/rigid.py,sha256=3aDPFrIm2_QpKKRpTqJJk8qBK-W63gq7Arc8WDVAcHc,17382
|
60
|
+
jaxsim/rbda/contacts/soft.py,sha256=6eFgV2hJK793RZfoY8oSqw-zC1UqFldaE0hfGHELnmU,16325
|
61
|
+
jaxsim/rbda/contacts/visco_elastic.py,sha256=wATvBhLrV-7IyVLJhW7OaMg_HDAmczl_8MnYm3wuqSc,39819
|
62
62
|
jaxsim/terrain/__init__.py,sha256=f7lVX-iNpH_wkkjef9Qpjh19TTAUOQw76EiLYJDVizc,78
|
63
63
|
jaxsim/terrain/terrain.py,sha256=K91HEzPqTSyNrc_j1KfAAEF_5oDeuk_-jnnZGrcMEcY,5015
|
64
64
|
jaxsim/utils/__init__.py,sha256=Y5zyoRevl3EMVQadhZ4EtSwTEkDt2vcnFoRhPJjKTZ0,215
|
65
65
|
jaxsim/utils/jaxsim_dataclass.py,sha256=TGmTQV2Lq7Q-2nLoAEaeNtkPa_qj0IKkdBm4COj46Os,11312
|
66
66
|
jaxsim/utils/tracing.py,sha256=KDMoyVPlu2NJvFkhtZwq5AkqMMgajt3munvJom-vEjQ,650
|
67
67
|
jaxsim/utils/wrappers.py,sha256=Fh82ZcaFi5fUnByyFLnmumaobsu1hJIvFdopUVzJ1ps,4052
|
68
|
-
jaxsim-0.4.3.
|
69
|
-
jaxsim-0.4.3.
|
70
|
-
jaxsim-0.4.3.
|
71
|
-
jaxsim-0.4.3.
|
72
|
-
jaxsim-0.4.3.
|
68
|
+
jaxsim-0.4.3.dev200.dist-info/LICENSE,sha256=eaYdFmdeMbiIoIiPzEK0MjP1S9wtFXjXNR5er49uLR0,1546
|
69
|
+
jaxsim-0.4.3.dev200.dist-info/METADATA,sha256=NUJ6GXIFFK-y9-p-M2OTTdI3g7utYHa3Lsg2VXSXtoI,17276
|
70
|
+
jaxsim-0.4.3.dev200.dist-info/WHEEL,sha256=OVMc5UfuAQiSplgO0_WdW7vXVGAt9Hdd6qtN4HotdyA,91
|
71
|
+
jaxsim-0.4.3.dev200.dist-info/top_level.txt,sha256=LxGMA8FLtXjQ6oI7N5gd_R_oSUHxpXxUEOfT1xS_ni0,7
|
72
|
+
jaxsim-0.4.3.dev200.dist-info/RECORD,,
|
File without changes
|
File without changes
|