jaxsim 0.4.3.dev181__py3-none-any.whl → 0.4.3.dev200__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jaxsim/_version.py +2 -2
- jaxsim/api/contact.py +58 -19
- jaxsim/api/data.py +3 -2
- jaxsim/api/model.py +8 -12
- jaxsim/mujoco/__init__.py +1 -0
- jaxsim/mujoco/loaders.py +1 -1
- jaxsim/mujoco/model.py +4 -3
- jaxsim/mujoco/utils.py +101 -0
- jaxsim/mujoco/visualizer.py +1 -1
- jaxsim/rbda/contacts/__init__.py +7 -0
- jaxsim/rbda/contacts/common.py +49 -3
- jaxsim/rbda/contacts/relaxed_rigid.py +133 -72
- jaxsim/rbda/contacts/rigid.py +94 -115
- jaxsim/rbda/contacts/soft.py +21 -37
- jaxsim/rbda/contacts/visco_elastic.py +3 -3
- {jaxsim-0.4.3.dev181.dist-info → jaxsim-0.4.3.dev200.dist-info}/METADATA +1 -1
- {jaxsim-0.4.3.dev181.dist-info → jaxsim-0.4.3.dev200.dist-info}/RECORD +20 -19
- {jaxsim-0.4.3.dev181.dist-info → jaxsim-0.4.3.dev200.dist-info}/WHEEL +1 -1
- {jaxsim-0.4.3.dev181.dist-info → jaxsim-0.4.3.dev200.dist-info}/LICENSE +0 -0
- {jaxsim-0.4.3.dev181.dist-info → jaxsim-0.4.3.dev200.dist-info}/top_level.txt +0 -0
@@ -12,11 +12,10 @@ import optax
|
|
12
12
|
import jaxsim.api as js
|
13
13
|
import jaxsim.typing as jtp
|
14
14
|
from jaxsim import logging
|
15
|
-
from jaxsim.api.common import VelRepr
|
16
|
-
from jaxsim.math import Adjoint
|
15
|
+
from jaxsim.api.common import ModelDataWithVelocityRepresentation, VelRepr
|
17
16
|
from jaxsim.terrain.terrain import FlatTerrain, Terrain
|
18
17
|
|
19
|
-
from .
|
18
|
+
from . import common
|
20
19
|
|
21
20
|
try:
|
22
21
|
from typing import Self
|
@@ -25,7 +24,7 @@ except ImportError:
|
|
25
24
|
|
26
25
|
|
27
26
|
@jax_dataclasses.pytree_dataclass
|
28
|
-
class RelaxedRigidContactsParams(ContactsParams):
|
27
|
+
class RelaxedRigidContactsParams(common.ContactsParams):
|
29
28
|
"""Parameters of the relaxed rigid contacts model."""
|
30
29
|
|
31
30
|
# Time constant
|
@@ -116,14 +115,24 @@ class RelaxedRigidContactsParams(ContactsParams):
|
|
116
115
|
) -> Self:
|
117
116
|
"""Create a `RelaxedRigidContactsParams` instance"""
|
118
117
|
|
118
|
+
def default(name: str):
|
119
|
+
return cls.__dataclass_fields__[name].default_factory()
|
120
|
+
|
119
121
|
return cls(
|
120
|
-
|
121
|
-
|
122
|
-
|
123
|
-
|
124
|
-
|
125
|
-
|
126
|
-
|
122
|
+
time_constant=jnp.array(
|
123
|
+
time_constant or default("time_constant"), dtype=float
|
124
|
+
),
|
125
|
+
damping_coefficient=jnp.array(
|
126
|
+
damping_coefficient or default("damping_coefficient"), dtype=float
|
127
|
+
),
|
128
|
+
d_min=jnp.array(d_min or default("d_min"), dtype=float),
|
129
|
+
d_max=jnp.array(d_max or default("d_max"), dtype=float),
|
130
|
+
width=jnp.array(width or default("width"), dtype=float),
|
131
|
+
midpoint=jnp.array(midpoint or default("midpoint"), dtype=float),
|
132
|
+
power=jnp.array(power or default("power"), dtype=float),
|
133
|
+
stiffness=jnp.array(stiffness or default("stiffness"), dtype=float),
|
134
|
+
damping=jnp.array(damping or default("damping"), dtype=float),
|
135
|
+
mu=jnp.array(mu or default("mu"), dtype=float),
|
127
136
|
)
|
128
137
|
|
129
138
|
def valid(self) -> jtp.BoolLike:
|
@@ -142,7 +151,7 @@ class RelaxedRigidContactsParams(ContactsParams):
|
|
142
151
|
|
143
152
|
|
144
153
|
@jax_dataclasses.pytree_dataclass
|
145
|
-
class RelaxedRigidContacts(ContactModel):
|
154
|
+
class RelaxedRigidContacts(common.ContactModel):
|
146
155
|
"""Relaxed rigid contacts model."""
|
147
156
|
|
148
157
|
parameters: RelaxedRigidContactsParams = dataclasses.field(
|
@@ -229,7 +238,7 @@ class RelaxedRigidContacts(ContactModel):
|
|
229
238
|
*,
|
230
239
|
link_forces: jtp.MatrixLike | None = None,
|
231
240
|
joint_force_references: jtp.VectorLike | None = None,
|
232
|
-
) -> tuple[jtp.
|
241
|
+
) -> tuple[jtp.Matrix, tuple]:
|
233
242
|
"""
|
234
243
|
Compute the contact forces.
|
235
244
|
|
@@ -243,22 +252,23 @@ class RelaxedRigidContacts(ContactModel):
|
|
243
252
|
Optional `(n_joints,)` vector of joint forces.
|
244
253
|
|
245
254
|
Returns:
|
246
|
-
A tuple containing the contact forces.
|
255
|
+
A tuple containing as first element the computed contact forces.
|
247
256
|
"""
|
248
257
|
|
249
258
|
# Initialize the model and data this contact model is operating on.
|
250
259
|
# This will raise an exception if either the contact model or the
|
251
260
|
# contact parameters are not compatible.
|
252
261
|
model, data = self.initialize_model_and_data(model=model, data=data)
|
262
|
+
assert isinstance(data.contacts_params, RelaxedRigidContactsParams)
|
253
263
|
|
254
|
-
link_forces = (
|
255
|
-
link_forces
|
264
|
+
link_forces = jnp.atleast_2d(
|
265
|
+
jnp.array(link_forces, dtype=float).squeeze()
|
256
266
|
if link_forces is not None
|
257
267
|
else jnp.zeros((model.number_of_links(), 6))
|
258
268
|
)
|
259
269
|
|
260
|
-
joint_force_references = (
|
261
|
-
joint_force_references
|
270
|
+
joint_force_references = jnp.atleast_1d(
|
271
|
+
jnp.array(joint_force_references, dtype=float).squeeze()
|
262
272
|
if joint_force_references is not None
|
263
273
|
else jnp.zeros(model.number_of_joints())
|
264
274
|
)
|
@@ -271,10 +281,10 @@ class RelaxedRigidContacts(ContactModel):
|
|
271
281
|
joint_force_references=joint_force_references,
|
272
282
|
)
|
273
283
|
|
274
|
-
def
|
284
|
+
def detect_contact(x: jtp.Array, y: jtp.Array, z: jtp.Array) -> jtp.Array:
|
275
285
|
x, y, z = jax.tree.map(jnp.squeeze, (x, y, z))
|
276
286
|
|
277
|
-
n̂ =
|
287
|
+
n̂ = model.terrain.normal(x=x, y=y).squeeze()
|
278
288
|
h = jnp.array([0, 0, z - model.terrain.height(x=x, y=y)])
|
279
289
|
|
280
290
|
return jnp.dot(h, n̂)
|
@@ -286,19 +296,19 @@ class RelaxedRigidContacts(ContactModel):
|
|
286
296
|
)
|
287
297
|
|
288
298
|
# Compute the activation state of the collidable points
|
289
|
-
δ = jax.vmap(
|
299
|
+
δ = jax.vmap(detect_contact)(*position.T)
|
300
|
+
|
301
|
+
# Compute the transforms of the implicit frames corresponding to the
|
302
|
+
# collidable points.
|
303
|
+
W_H_C = js.contact.transforms(model=model, data=data)
|
290
304
|
|
291
305
|
with (
|
292
306
|
references.switch_velocity_representation(VelRepr.Mixed),
|
293
307
|
data.switch_velocity_representation(VelRepr.Mixed),
|
294
308
|
):
|
295
|
-
|
296
|
-
|
297
|
-
|
298
|
-
js.contact.jacobian(model=model, data=data)[:, :3, :], δ
|
299
|
-
)
|
300
|
-
)
|
301
|
-
W_H_C = js.contact.transforms(model=model, data=data)
|
309
|
+
|
310
|
+
BW_ν = data.generalized_velocity()
|
311
|
+
|
302
312
|
BW_ν̇_free = jnp.hstack(
|
303
313
|
js.ode.system_acceleration(
|
304
314
|
model=model,
|
@@ -309,20 +319,31 @@ class RelaxedRigidContacts(ContactModel):
|
|
309
319
|
),
|
310
320
|
)
|
311
321
|
)
|
312
|
-
|
322
|
+
|
323
|
+
M = js.model.free_floating_mass_matrix(model=model, data=data)
|
324
|
+
|
325
|
+
Jl_WC = jnp.vstack(
|
326
|
+
jax.vmap(lambda J, height: J * (height < 0))(
|
327
|
+
js.contact.jacobian(model=model, data=data)[:, :3, :], δ
|
328
|
+
)
|
329
|
+
)
|
330
|
+
|
313
331
|
J̇_WC = jnp.vstack(
|
314
332
|
jax.vmap(lambda J̇, height: J̇ * (height < 0))(
|
315
333
|
js.contact.jacobian_derivative(model=model, data=data)[:, :3], δ
|
316
334
|
),
|
317
335
|
)
|
318
336
|
|
319
|
-
|
320
|
-
|
321
|
-
|
322
|
-
|
323
|
-
|
324
|
-
|
337
|
+
# Compute the regularization terms.
|
338
|
+
a_ref, R, K, D = self._regularizers(
|
339
|
+
model=model,
|
340
|
+
penetration=δ,
|
341
|
+
velocity=velocity,
|
342
|
+
parameters=data.contacts_params,
|
343
|
+
)
|
325
344
|
|
345
|
+
# Compute the Delassus matrix and the free mixed linear acceleration of
|
346
|
+
# the collidable points.
|
326
347
|
G = Jl_WC @ jnp.linalg.lstsq(M, Jl_WC.T)[0]
|
327
348
|
CW_al_free_WC = Jl_WC @ BW_ν̇_free + J̇_WC @ BW_ν
|
328
349
|
|
@@ -330,26 +351,40 @@ class RelaxedRigidContacts(ContactModel):
|
|
330
351
|
A = G + R
|
331
352
|
b = CW_al_free_WC - a_ref
|
332
353
|
|
354
|
+
# Create the objective function to minimize as a lambda computing the cost
|
355
|
+
# from the optimized variables x.
|
333
356
|
objective = lambda x, A, b: jnp.sum(jnp.square(A @ x + b))
|
334
357
|
|
358
|
+
# ========================================
|
359
|
+
# Helper function to run the L-BFGS solver
|
360
|
+
# ========================================
|
361
|
+
|
335
362
|
def run_optimization(
|
336
|
-
init_params: jtp.
|
363
|
+
init_params: jtp.Vector,
|
337
364
|
fun: Callable,
|
338
|
-
opt: optax.
|
339
|
-
maxiter:
|
340
|
-
tol:
|
341
|
-
|
342
|
-
|
365
|
+
opt: optax.GradientTransformationExtraArgs,
|
366
|
+
maxiter: int,
|
367
|
+
tol: float,
|
368
|
+
) -> tuple[jtp.Vector, optax.OptState]:
|
369
|
+
|
370
|
+
# Get the function to compute the loss and the gradient w.r.t. its inputs.
|
343
371
|
value_and_grad_fn = optax.value_and_grad_from_state(fun)
|
344
372
|
|
345
|
-
|
373
|
+
# Initialize the carry of the following loop.
|
374
|
+
OptimizationCarry = tuple[jtp.Vector, optax.OptState]
|
375
|
+
init_carry: OptimizationCarry = (init_params, opt.init(params=init_params))
|
376
|
+
|
377
|
+
def step(carry: OptimizationCarry) -> OptimizationCarry:
|
378
|
+
|
346
379
|
params, state = carry
|
380
|
+
|
347
381
|
value, grad = value_and_grad_fn(
|
348
382
|
params,
|
349
383
|
state=state,
|
350
384
|
A=A,
|
351
385
|
b=b,
|
352
386
|
)
|
387
|
+
|
353
388
|
updates, state = opt.update(
|
354
389
|
updates=grad,
|
355
390
|
state=state,
|
@@ -360,22 +395,32 @@ class RelaxedRigidContacts(ContactModel):
|
|
360
395
|
A=A,
|
361
396
|
b=b,
|
362
397
|
)
|
398
|
+
|
363
399
|
params = optax.apply_updates(params, updates)
|
400
|
+
|
364
401
|
return params, state
|
365
402
|
|
366
|
-
def continuing_criterion(carry):
|
403
|
+
def continuing_criterion(carry: OptimizationCarry) -> jtp.Bool:
|
404
|
+
|
367
405
|
_, state = carry
|
406
|
+
|
368
407
|
iter_num = optax.tree_utils.tree_get(state, "count")
|
369
408
|
grad = optax.tree_utils.tree_get(state, "grad")
|
370
409
|
err = optax.tree_utils.tree_l2_norm(grad)
|
410
|
+
|
371
411
|
return (iter_num == 0) | ((iter_num < maxiter) & (err >= tol))
|
372
412
|
|
373
|
-
init_carry = (init_params, opt.init(init_params))
|
374
413
|
final_params, final_state = jax.lax.while_loop(
|
375
414
|
continuing_criterion, step, init_carry
|
376
415
|
)
|
416
|
+
|
377
417
|
return final_params, final_state
|
378
418
|
|
419
|
+
# ======================================
|
420
|
+
# Compute the contact forces with L-BFGS
|
421
|
+
# ======================================
|
422
|
+
|
423
|
+
# Initialize the optimized forces with a linear Hunt/Crossley model.
|
379
424
|
init_params = (
|
380
425
|
K[:, jnp.newaxis] * jnp.zeros_like(position).at[:, 2].set(δ)
|
381
426
|
+ D[:, jnp.newaxis] * velocity
|
@@ -390,28 +435,30 @@ class RelaxedRigidContacts(ContactModel):
|
|
390
435
|
maxiter = solver_options.pop("maxiter")
|
391
436
|
|
392
437
|
# Compute the 3D linear force in C[W] frame.
|
393
|
-
|
438
|
+
solution, _ = run_optimization(
|
394
439
|
init_params=init_params,
|
395
|
-
A=A,
|
396
|
-
b=b,
|
397
|
-
maxiter=maxiter,
|
398
|
-
opt=optax.lbfgs(**solver_options),
|
399
440
|
fun=objective,
|
441
|
+
opt=optax.lbfgs(**solver_options),
|
400
442
|
tol=tol,
|
443
|
+
maxiter=maxiter,
|
401
444
|
)
|
402
445
|
|
403
|
-
|
404
|
-
|
405
|
-
|
406
|
-
|
407
|
-
|
408
|
-
|
409
|
-
|
410
|
-
|
411
|
-
|
412
|
-
|
446
|
+
# Reshape the optimized solution to be a matrix of 3D contact forces.
|
447
|
+
CW_fl_C = solution.reshape(-1, 3)
|
448
|
+
|
449
|
+
# Convert the contact forces from mixed to inertial-fixed representation.
|
450
|
+
W_f_C = jax.vmap(
|
451
|
+
lambda CW_fl_C, W_H_C: (
|
452
|
+
ModelDataWithVelocityRepresentation.other_representation_to_inertial(
|
453
|
+
array=jnp.zeros(6).at[0:3].set(CW_fl_C),
|
454
|
+
transform=W_H_C,
|
455
|
+
other_representation=VelRepr.Mixed,
|
456
|
+
is_force=True,
|
457
|
+
)
|
458
|
+
),
|
459
|
+
)(CW_fl_C, W_H_C)
|
413
460
|
|
414
|
-
return W_f_C, (
|
461
|
+
return W_f_C, ()
|
415
462
|
|
416
463
|
@staticmethod
|
417
464
|
def _regularizers(
|
@@ -433,13 +480,28 @@ class RelaxedRigidContacts(ContactModel):
|
|
433
480
|
A tuple containing the reference acceleration, the regularization matrix, the stiffness, and the damping.
|
434
481
|
"""
|
435
482
|
|
436
|
-
|
437
|
-
|
483
|
+
# Extract the parameters of the contact model.
|
484
|
+
Ω, ζ, ξ_min, ξ_max, width, mid, p, K, D, μ = (
|
485
|
+
getattr(parameters, field)
|
486
|
+
for field in (
|
487
|
+
"time_constant",
|
488
|
+
"damping_coefficient",
|
489
|
+
"d_min",
|
490
|
+
"d_max",
|
491
|
+
"width",
|
492
|
+
"midpoint",
|
493
|
+
"power",
|
494
|
+
"stiffness",
|
495
|
+
"damping",
|
496
|
+
"mu",
|
497
|
+
)
|
438
498
|
)
|
439
499
|
|
440
|
-
|
441
|
-
|
442
|
-
|
500
|
+
# Compute the 6D inertia matrices of all links.
|
501
|
+
M_L = js.model.link_spatial_inertia_matrices(model=model)
|
502
|
+
|
503
|
+
def imp_aref(
|
504
|
+
penetration: jtp.Array, velocity: jtp.Array
|
443
505
|
) -> tuple[jtp.Array, jtp.Array]:
|
444
506
|
"""
|
445
507
|
Calculates impedance and offset acceleration in constraint frame.
|
@@ -474,7 +536,7 @@ class RelaxedRigidContacts(ContactModel):
|
|
474
536
|
|
475
537
|
return imp, a_ref, jnp.atleast_1d(K_f), jnp.atleast_1d(D_f)
|
476
538
|
|
477
|
-
def
|
539
|
+
def compute_row(
|
478
540
|
*,
|
479
541
|
link_idx: jtp.Float,
|
480
542
|
penetration: jtp.Array,
|
@@ -482,7 +544,7 @@ class RelaxedRigidContacts(ContactModel):
|
|
482
544
|
) -> tuple[jtp.Array, jtp.Array]:
|
483
545
|
|
484
546
|
# Compute the reference acceleration.
|
485
|
-
ξ, a_ref, K, D =
|
547
|
+
ξ, a_ref, K, D = imp_aref(
|
486
548
|
penetration=penetration,
|
487
549
|
velocity=velocity,
|
488
550
|
)
|
@@ -496,12 +558,10 @@ class RelaxedRigidContacts(ContactModel):
|
|
496
558
|
|
497
559
|
return jax.tree.map(lambda x: x * (penetration < 0), (a_ref, R, K, D))
|
498
560
|
|
499
|
-
M_L = js.model.link_spatial_inertia_matrices(model=model)
|
500
|
-
|
501
561
|
a_ref, R, K, D = jax.tree.map(
|
502
|
-
jnp.concatenate,
|
503
|
-
(
|
504
|
-
*jax.vmap(
|
562
|
+
f=jnp.concatenate,
|
563
|
+
tree=(
|
564
|
+
*jax.vmap(compute_row)(
|
505
565
|
link_idx=jnp.array(
|
506
566
|
model.kin_dyn_parameters.contact_parameters.body
|
507
567
|
),
|
@@ -510,4 +570,5 @@ class RelaxedRigidContacts(ContactModel):
|
|
510
570
|
),
|
511
571
|
),
|
512
572
|
)
|
573
|
+
|
513
574
|
return a_ref, jnp.diag(R), K, D
|