jaxsim 0.2.dev188__py3-none-any.whl → 0.6.1.dev2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jaxsim/__init__.py +73 -22
- jaxsim/_version.py +2 -2
- jaxsim/api/__init__.py +13 -1
- jaxsim/api/com.py +423 -0
- jaxsim/api/common.py +48 -19
- jaxsim/api/contact.py +604 -52
- jaxsim/api/data.py +308 -163
- jaxsim/api/frame.py +471 -0
- jaxsim/api/joint.py +166 -37
- jaxsim/api/kin_dyn_parameters.py +901 -0
- jaxsim/api/link.py +277 -78
- jaxsim/api/model.py +1572 -362
- jaxsim/api/ode.py +324 -129
- jaxsim/api/ode_data.py +401 -0
- jaxsim/api/references.py +216 -80
- jaxsim/exceptions.py +80 -0
- jaxsim/integrators/__init__.py +2 -2
- jaxsim/integrators/common.py +191 -107
- jaxsim/integrators/fixed_step.py +97 -102
- jaxsim/integrators/variable_step.py +706 -0
- jaxsim/logging.py +1 -2
- jaxsim/math/__init__.py +13 -0
- jaxsim/math/adjoint.py +57 -22
- jaxsim/math/cross.py +16 -7
- jaxsim/math/inertia.py +10 -8
- jaxsim/math/joint_model.py +289 -0
- jaxsim/math/quaternion.py +54 -20
- jaxsim/math/rotation.py +27 -21
- jaxsim/math/skew.py +16 -5
- jaxsim/math/transform.py +102 -0
- jaxsim/math/utils.py +31 -0
- jaxsim/mujoco/__init__.py +2 -1
- jaxsim/mujoco/loaders.py +216 -29
- jaxsim/mujoco/model.py +163 -33
- jaxsim/mujoco/utils.py +228 -0
- jaxsim/mujoco/visualizer.py +107 -22
- jaxsim/parsers/__init__.py +0 -1
- jaxsim/parsers/descriptions/__init__.py +8 -2
- jaxsim/parsers/descriptions/collision.py +87 -16
- jaxsim/parsers/descriptions/joint.py +80 -87
- jaxsim/parsers/descriptions/link.py +62 -24
- jaxsim/parsers/descriptions/model.py +101 -68
- jaxsim/parsers/kinematic_graph.py +607 -225
- jaxsim/parsers/rod/meshes.py +104 -0
- jaxsim/parsers/rod/parser.py +125 -82
- jaxsim/parsers/rod/utils.py +127 -82
- jaxsim/rbda/__init__.py +11 -0
- jaxsim/rbda/aba.py +289 -0
- jaxsim/rbda/collidable_points.py +156 -0
- jaxsim/rbda/contacts/__init__.py +13 -0
- jaxsim/rbda/contacts/common.py +313 -0
- jaxsim/rbda/contacts/relaxed_rigid.py +605 -0
- jaxsim/rbda/contacts/rigid.py +462 -0
- jaxsim/rbda/contacts/soft.py +480 -0
- jaxsim/rbda/contacts/visco_elastic.py +1066 -0
- jaxsim/rbda/crba.py +167 -0
- jaxsim/rbda/forward_kinematics.py +117 -0
- jaxsim/rbda/jacobian.py +330 -0
- jaxsim/rbda/rnea.py +235 -0
- jaxsim/rbda/utils.py +160 -0
- jaxsim/terrain/__init__.py +2 -0
- jaxsim/terrain/terrain.py +238 -0
- jaxsim/typing.py +24 -24
- jaxsim/utils/__init__.py +1 -4
- jaxsim/utils/jaxsim_dataclass.py +289 -34
- jaxsim/utils/tracing.py +5 -11
- jaxsim/utils/wrappers.py +159 -0
- {jaxsim-0.2.dev188.dist-info → jaxsim-0.6.1.dev2.dist-info}/LICENSE +1 -1
- jaxsim-0.6.1.dev2.dist-info/METADATA +465 -0
- jaxsim-0.6.1.dev2.dist-info/RECORD +74 -0
- {jaxsim-0.2.dev188.dist-info → jaxsim-0.6.1.dev2.dist-info}/WHEEL +1 -1
- jaxsim/high_level/__init__.py +0 -2
- jaxsim/high_level/common.py +0 -11
- jaxsim/high_level/joint.py +0 -148
- jaxsim/high_level/link.py +0 -259
- jaxsim/high_level/model.py +0 -1686
- jaxsim/math/conv.py +0 -114
- jaxsim/math/joint.py +0 -102
- jaxsim/math/plucker.py +0 -100
- jaxsim/physics/__init__.py +0 -12
- jaxsim/physics/algos/__init__.py +0 -0
- jaxsim/physics/algos/aba.py +0 -254
- jaxsim/physics/algos/aba_motors.py +0 -284
- jaxsim/physics/algos/crba.py +0 -154
- jaxsim/physics/algos/forward_kinematics.py +0 -79
- jaxsim/physics/algos/jacobian.py +0 -98
- jaxsim/physics/algos/rnea.py +0 -180
- jaxsim/physics/algos/rnea_motors.py +0 -196
- jaxsim/physics/algos/soft_contacts.py +0 -523
- jaxsim/physics/algos/terrain.py +0 -80
- jaxsim/physics/algos/utils.py +0 -69
- jaxsim/physics/model/__init__.py +0 -0
- jaxsim/physics/model/ground_contact.py +0 -55
- jaxsim/physics/model/physics_model.py +0 -388
- jaxsim/physics/model/physics_model_state.py +0 -283
- jaxsim/simulation/__init__.py +0 -4
- jaxsim/simulation/integrators.py +0 -393
- jaxsim/simulation/ode.py +0 -290
- jaxsim/simulation/ode_data.py +0 -96
- jaxsim/simulation/ode_integration.py +0 -62
- jaxsim/simulation/simulator.py +0 -543
- jaxsim/simulation/simulator_callbacks.py +0 -79
- jaxsim/simulation/utils.py +0 -15
- jaxsim/sixd/__init__.py +0 -2
- jaxsim/utils/oop.py +0 -536
- jaxsim/utils/vmappable.py +0 -117
- jaxsim-0.2.dev188.dist-info/METADATA +0 -184
- jaxsim-0.2.dev188.dist-info/RECORD +0 -81
- {jaxsim-0.2.dev188.dist-info → jaxsim-0.6.1.dev2.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,605 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
3
|
+
import dataclasses
|
4
|
+
from collections.abc import Callable
|
5
|
+
from typing import Any
|
6
|
+
|
7
|
+
import jax
|
8
|
+
import jax.numpy as jnp
|
9
|
+
import jax_dataclasses
|
10
|
+
import optax
|
11
|
+
|
12
|
+
import jaxsim.api as js
|
13
|
+
import jaxsim.rbda.contacts
|
14
|
+
import jaxsim.typing as jtp
|
15
|
+
from jaxsim.api.common import ModelDataWithVelocityRepresentation, VelRepr
|
16
|
+
|
17
|
+
from . import common
|
18
|
+
|
19
|
+
try:
|
20
|
+
from typing import Self
|
21
|
+
except ImportError:
|
22
|
+
from typing_extensions import Self
|
23
|
+
|
24
|
+
|
25
|
+
@jax_dataclasses.pytree_dataclass
class RelaxedRigidContactsParams(common.ContactsParams):
    """
    Parameters of the relaxed rigid contacts model.

    All fields are stored as float JAX arrays so that the instance is a valid pytree.
    ``d_min``, ``d_max``, ``width``, ``midpoint`` and ``power`` shape the impedance
    curve of the model. ``time_constant`` and ``damping_coefficient`` are used to
    derive the gains of the reference acceleration. ``mu`` is the friction
    coefficient.
    """

    # Time constant
    time_constant: jtp.Float = dataclasses.field(
        default_factory=lambda: jnp.array(0.01, dtype=float)
    )

    # Adimensional damping coefficient
    damping_coefficient: jtp.Float = dataclasses.field(
        default_factory=lambda: jnp.array(1.0, dtype=float)
    )

    # Minimum impedance
    d_min: jtp.Float = dataclasses.field(
        default_factory=lambda: jnp.array(0.9, dtype=float)
    )

    # Maximum impedance
    d_max: jtp.Float = dataclasses.field(
        default_factory=lambda: jnp.array(0.95, dtype=float)
    )

    # Width of the impedance transition region (the penetration is divided by it).
    width: jtp.Float = dataclasses.field(
        default_factory=lambda: jnp.array(0.0001, dtype=float)
    )

    # Midpoint of the impedance curve
    midpoint: jtp.Float = dataclasses.field(
        default_factory=lambda: jnp.array(0.1, dtype=float)
    )

    # Power exponent of the impedance curve
    power: jtp.Float = dataclasses.field(
        default_factory=lambda: jnp.array(1.0, dtype=float)
    )

    # Stiffness
    stiffness: jtp.Float = dataclasses.field(
        default_factory=lambda: jnp.array(0.0, dtype=float)
    )

    # Damping
    damping: jtp.Float = dataclasses.field(
        default_factory=lambda: jnp.array(0.0, dtype=float)
    )

    # Friction coefficient
    mu: jtp.Float = dataclasses.field(
        default_factory=lambda: jnp.array(0.5, dtype=float)
    )

    def __hash__(self) -> int:
        """Hash the parameters by the content of all their array fields."""

        from jaxsim.utils.wrappers import HashedNumpyArray

        return hash(
            (
                HashedNumpyArray(self.time_constant),
                HashedNumpyArray(self.damping_coefficient),
                HashedNumpyArray(self.d_min),
                HashedNumpyArray(self.d_max),
                HashedNumpyArray(self.width),
                HashedNumpyArray(self.midpoint),
                HashedNumpyArray(self.power),
                HashedNumpyArray(self.stiffness),
                HashedNumpyArray(self.damping),
                HashedNumpyArray(self.mu),
            )
        )

    def __eq__(self, other: object) -> bool:
        """
        Compare two parameter sets by the hash of their content.

        Objects of other types are not handled here, so that comparing against
        unhashable objects (e.g. lists or arrays) does not raise ``TypeError``.
        """

        if not isinstance(other, RelaxedRigidContactsParams):
            return NotImplemented

        return hash(self) == hash(other)

    @classmethod
    def build(
        cls: type[Self],
        *,
        time_constant: jtp.FloatLike | None = None,
        damping_coefficient: jtp.FloatLike | None = None,
        d_min: jtp.FloatLike | None = None,
        d_max: jtp.FloatLike | None = None,
        width: jtp.FloatLike | None = None,
        midpoint: jtp.FloatLike | None = None,
        power: jtp.FloatLike | None = None,
        stiffness: jtp.FloatLike | None = None,
        damping: jtp.FloatLike | None = None,
        mu: jtp.FloatLike | None = None,
    ) -> Self:
        """
        Create a `RelaxedRigidContactsParams` instance.

        Every argument left to ``None`` takes the default value of the
        corresponding dataclass field. All values are converted to float arrays.

        Returns:
            The `RelaxedRigidContactsParams` instance.
        """

        def value_or_default(name: str, value: jtp.FloatLike | None) -> jtp.Float:
            # Fall back to the field's default factory when no value is given.
            if value is None:
                value = cls.__dataclass_fields__[name].default_factory()
            return jnp.array(value, dtype=float)

        return cls(
            time_constant=value_or_default("time_constant", time_constant),
            damping_coefficient=value_or_default(
                "damping_coefficient", damping_coefficient
            ),
            d_min=value_or_default("d_min", d_min),
            d_max=value_or_default("d_max", d_max),
            width=value_or_default("width", width),
            midpoint=value_or_default("midpoint", midpoint),
            power=value_or_default("power", power),
            stiffness=value_or_default("stiffness", stiffness),
            damping=value_or_default("damping", damping),
            mu=value_or_default("mu", mu),
        )

    def valid(self) -> jtp.BoolLike:
        """
        Check if the parameters are valid.

        Returns:
            ``True`` if all the parameters lie in their admissible ranges.

        Note:
            ``width`` must be strictly positive: the penetration is divided by it
            when evaluating the impedance, so a zero width yields inf/NaN values
            (0/0 for a zero penetration).
        """

        return bool(
            jnp.all(self.time_constant >= 0.0)
            and jnp.all(self.damping_coefficient > 0.0)
            and jnp.all(self.d_min >= 0.0)
            and jnp.all(self.d_max <= 1.0)
            and jnp.all(self.d_min <= self.d_max)
            and jnp.all(self.width > 0.0)
            and jnp.all(self.midpoint >= 0.0)
            and jnp.all(self.power >= 0.0)
            and jnp.all(self.mu >= 0.0)
        )
|
176
|
+
|
177
|
+
|
178
|
+
@jax_dataclasses.pytree_dataclass
class RelaxedRigidContacts(common.ContactModel):
    """
    Relaxed rigid contacts model.

    The 3D linear contact forces are obtained by minimizing, with L-BFGS, the
    squared residual of a regularized linear system built from the Delassus
    matrix, the free acceleration of the collidable points, and a reference
    acceleration.
    """

    # Keys and values of the options of the L-BFGS solver. They are stored as two
    # static tuples (rather than a dict) because static fields must be hashable.
    _solver_options_keys: jax_dataclasses.Static[tuple[str, ...]] = dataclasses.field(
        default=("tol", "maxiter", "memory_size"), kw_only=True
    )
    _solver_options_values: jax_dataclasses.Static[tuple[Any, ...]] = dataclasses.field(
        default=(1e-6, 50, 10), kw_only=True
    )

    @property
    def solver_options(self) -> dict[str, Any]:
        """Get the solver options as a new dictionary (safe to mutate)."""

        return dict(
            zip(
                self._solver_options_keys,
                self._solver_options_values,
                strict=True,
            )
        )

    @classmethod
    def build(
        cls: type[Self],
        solver_options: dict[str, Any] | None = None,
        **kwargs,
    ) -> Self:
        """
        Create a `RelaxedRigidContacts` instance with specified parameters.

        Args:
            solver_options:
                The options to pass to the L-BFGS solver. They override the defaults.
                ``tol`` and ``maxiter`` control the stopping criterion; all the
                other options are forwarded to ``optax.lbfgs``.
            **kwargs: The parameters of the relaxed rigid contacts model.

        Returns:
            The `RelaxedRigidContacts` instance.

        Raises:
            ValueError: If any value of the solver options is not hashable.
        """

        # Get the default solver options.
        default_solver_options = dict(
            zip(cls._solver_options_keys, cls._solver_options_values, strict=True)
        )

        # Create the solver options to set by combining the default solver options
        # with the user-provided solver options.
        solver_options = default_solver_options | (
            solver_options if solver_options is not None else {}
        )

        # Make sure that the solver options are hashable.
        # We need to check this because the solver options are static.
        try:
            hash(tuple(solver_options.values()))
        except TypeError as exc:
            raise ValueError(
                "The values of the solver options must be hashable."
            ) from exc

        return cls(
            _solver_options_keys=tuple(solver_options.keys()),
            _solver_options_values=tuple(solver_options.values()),
            **kwargs,
        )

    @jax.jit
    def compute_contact_forces(
        self,
        model: js.model.JaxSimModel,
        data: js.data.JaxSimModelData,
        *,
        link_forces: jtp.MatrixLike | None = None,
        joint_force_references: jtp.VectorLike | None = None,
    ) -> tuple[jtp.Matrix, dict[str, jtp.PyTree]]:
        """
        Compute the contact forces.

        Args:
            model: The model to consider.
            data: The data of the considered model.
            link_forces:
                Optional `(n_links, 6)` matrix of external forces acting on the links,
                expressed in the same representation of data.
            joint_force_references:
                Optional `(n_joints,)` vector of joint forces.

        Returns:
            A tuple containing as first element the computed 6D contact forces in
            inertial-fixed representation, one row per collidable point returned
            by `js.contact.collidable_point_kinematics`, and as second element an
            empty dictionary.
        """

        link_forces = jnp.atleast_2d(
            jnp.array(link_forces, dtype=float).squeeze()
            if link_forces is not None
            else jnp.zeros((model.number_of_links(), 6))
        )

        joint_force_references = jnp.atleast_1d(
            jnp.array(joint_force_references, dtype=float).squeeze()
            if joint_force_references is not None
            else jnp.zeros(model.number_of_joints())
        )

        references = js.references.JaxSimModelReferences.build(
            model=model,
            data=data,
            velocity_representation=data.velocity_representation,
            link_forces=link_forces,
            joint_force_references=joint_force_references,
        )

        # Compute the position and linear velocities (mixed representation) of
        # all collidable points belonging to the robot.
        position, velocity = js.contact.collidable_point_kinematics(
            model=model, data=data
        )

        # Compute the penetration depth and velocity of the collidable points.
        # Note that this function considers the penetration in the normal direction.
        δ, _, n̂ = jax.vmap(common.compute_penetration_data, in_axes=(0, 0, None))(
            position, velocity, model.terrain
        )

        # Compute the position in the constraint frame.
        position_constraint = jax.vmap(lambda δ, n̂: -δ * n̂)(δ, n̂)

        # Compute the transforms of the implicit frames corresponding to the
        # collidable points.
        W_H_C = js.contact.transforms(model=model, data=data)

        # All the dynamics quantities below are computed in mixed representation,
        # consistently with the mixed linear velocities of the collidable points.
        with (
            references.switch_velocity_representation(VelRepr.Mixed),
            data.switch_velocity_representation(VelRepr.Mixed),
        ):

            BW_ν = data.generalized_velocity()

            # Free acceleration: the system acceleration without contact forces.
            BW_ν̇_free = jnp.hstack(
                js.ode.system_acceleration(
                    model=model,
                    data=data,
                    link_forces=references.link_forces(model=model, data=data),
                    joint_force_references=references.joint_force_references(
                        model=model
                    ),
                )
            )

            M = js.model.free_floating_mass_matrix(model=model, data=data)

            # Linear part of the contact Jacobians, zeroed for the points that
            # are not penetrating the terrain (δ <= 0).
            Jl_WC = jnp.vstack(
                jax.vmap(lambda J, δ: J * (δ > 0))(
                    js.contact.jacobian(model=model, data=data)[:, :3, :], δ
                )
            )

            # Linear part of the Jacobian derivatives, masked in the same way.
            J̇_WC = jnp.vstack(
                jax.vmap(lambda J̇, δ: J̇ * (δ > 0))(
                    js.contact.jacobian_derivative(model=model, data=data)[:, :3], δ
                ),
            )

        # Compute the regularization terms.
        a_ref, R, *_ = self._regularizers(
            model=model,
            position_constraint=position_constraint,
            velocity_constraint=velocity,
            parameters=data.contacts_params,
        )

        # Compute the Delassus matrix and the free mixed linear acceleration of
        # the collidable points.
        G = Jl_WC @ jnp.linalg.pinv(M) @ Jl_WC.T
        CW_al_free_WC = Jl_WC @ BW_ν̇_free + J̇_WC @ BW_ν

        # Calculate quantities for the linear optimization problem.
        A = G + R
        b = CW_al_free_WC - a_ref

        def objective(x: jtp.Vector, A: jtp.Matrix, b: jtp.Vector) -> jtp.Float:
            # Squared norm of the residual of the regularized linear system.
            return jnp.sum(jnp.square(A @ x + b))

        # ========================================
        # Helper function to run the L-BFGS solver
        # ========================================

        def run_optimization(
            init_params: jtp.Vector,
            fun: Callable,
            opt: optax.GradientTransformationExtraArgs,
            maxiter: int,
            tol: float,
        ) -> tuple[jtp.Vector, optax.OptState]:
            """
            Minimize `fun` (called with the extra arguments `A` and `b`) with `opt`.

            The loop stops when the L2 norm of the gradient stored in the optimizer
            state drops below `tol`, or after `maxiter` iterations. At least one
            iteration is always performed.
            """

            # Get the function to compute the loss and the gradient w.r.t. its inputs.
            value_and_grad_fn = optax.value_and_grad_from_state(fun)

            # Initialize the carry of the following loop.
            OptimizationCarry = tuple[jtp.Vector, optax.OptState]
            init_carry: OptimizationCarry = (init_params, opt.init(params=init_params))

            def step(carry: OptimizationCarry) -> OptimizationCarry:

                params, state = carry

                value, grad = value_and_grad_fn(
                    params,
                    state=state,
                    A=A,
                    b=b,
                )

                updates, state = opt.update(
                    updates=grad,
                    state=state,
                    params=params,
                    value=value,
                    grad=grad,
                    value_fn=fun,
                    A=A,
                    b=b,
                )

                params = optax.apply_updates(params, updates)

                return params, state

            # TODO: maybe fix the number of iterations and switch to scan?
            def continuing_criterion(carry: OptimizationCarry) -> jtp.Bool:

                _, state = carry

                iter_num = optax.tree_utils.tree_get(state, "count")
                grad = optax.tree_utils.tree_get(state, "grad")
                err = optax.tree_utils.tree_l2_norm(grad)

                return (iter_num == 0) | ((iter_num < maxiter) & (err >= tol))

            final_params, final_state = jax.lax.while_loop(
                continuing_criterion, step, init_carry
            )

            return final_params, final_state

        # ======================================
        # Compute the contact forces with L-BFGS
        # ======================================

        # Initialize the optimized forces with a linear Hunt/Crossley model.
        init_params = jax.vmap(
            lambda p, v: jaxsim.rbda.contacts.SoftContacts.hunt_crossley_contact_model(
                position=p,
                velocity=v,
                terrain=model.terrain,
                K=1e6,
                D=2e3,
                p=0.5,
                q=0.5,
                # No tangential initial forces.
                mu=0.0,
                tangential_deformation=jnp.zeros(3),
            )[0]
        )(position, velocity).flatten()

        # Get the solver options (a fresh dict, so popping keys is safe).
        solver_options = self.solver_options

        # Extract the options corresponding to the convergence criteria.
        # All the remaining options are passed to the solver.
        tol = solver_options.pop("tol")
        maxiter = solver_options.pop("maxiter")

        # Compute the 3D linear force in C[W] frame.
        solution, _ = run_optimization(
            init_params=init_params,
            fun=objective,
            opt=optax.lbfgs(**solver_options),
            tol=tol,
            maxiter=maxiter,
        )

        # Reshape the optimized solution to be a matrix of 3D contact forces.
        CW_fl_C = solution.reshape(-1, 3)

        # Convert the contact forces from mixed to inertial-fixed representation.
        W_f_C = jax.vmap(
            lambda CW_fl_C, W_H_C: (
                ModelDataWithVelocityRepresentation.other_representation_to_inertial(
                    array=jnp.zeros(6).at[0:3].set(CW_fl_C),
                    transform=W_H_C,
                    other_representation=VelRepr.Mixed,
                    is_force=True,
                )
            ),
        )(CW_fl_C, W_H_C)

        return W_f_C, {}

    @staticmethod
    def _regularizers(
        model: js.model.JaxSimModel,
        position_constraint: jtp.Vector,
        velocity_constraint: jtp.Vector,
        parameters: RelaxedRigidContactsParams,
    ) -> tuple:
        """
        Compute the reference acceleration and the regularization terms.

        Args:
            model: The jaxsim model.
            position_constraint: The position of the collidable points in the constraint frame.
            velocity_constraint: The velocity of the collidable points in the constraint frame.
            parameters: The parameters of the relaxed rigid contacts model.

        Returns:
            A tuple containing the reference acceleration, the (diagonal)
            regularization matrix, the stiffness, and the damping. The per-point
            values are concatenated, and all of them are zero for the points with
            a zero constraint position.
        """

        # Extract the parameters of the contact model.
        Ω, ζ, ξ_min, ξ_max, width, mid, p, K, D, μ = (
            getattr(parameters, field)
            for field in (
                "time_constant",
                "damping_coefficient",
                "d_min",
                "d_max",
                "width",
                "midpoint",
                "power",
                "stiffness",
                "damping",
                "mu",
            )
        )

        # Get the indices of the enabled collidable points.
        indices_of_enabled_collidable_points = (
            model.kin_dyn_parameters.contact_parameters.indices_of_enabled_collidable_points
        )

        parent_link_idx_of_enabled_collidable_points = jnp.array(
            model.kin_dyn_parameters.contact_parameters.body, dtype=int
        )[indices_of_enabled_collidable_points]

        # Compute the 6D inertia matrices of all links.
        M_L = js.model.link_spatial_inertia_matrices(model=model)

        def imp_aref(
            pos: jtp.Vector,
            vel: jtp.Vector,
        ) -> tuple[jtp.Vector, jtp.Vector, jtp.Vector, jtp.Vector]:
            """
            Calculate impedance and offset acceleration in constraint frame.

            Args:
                pos: position in constraint frame.
                vel: velocity in constraint frame.

            Returns:
                ξ: computed impedance
                a_ref: offset acceleration in constraint frame
                K: computed stiffness
                D: computed damping
            """

            imp_x = jnp.abs(pos) / width

            imp_a = (1.0 / jnp.power(mid, p - 1)) * jnp.power(imp_x, p)
            imp_b = 1 - (1.0 / jnp.power(1 - mid, p - 1)) * jnp.power(1 - imp_x, p)
            imp_y = jnp.where(imp_x < mid, imp_a, imp_b)

            # Compute the impedance.
            ξ = ξ_min + imp_y * (ξ_max - ξ_min)
            ξ = jnp.clip(ξ, ξ_min, ξ_max)
            ξ = jnp.where(imp_x > 1.0, ξ_max, ξ)

            # Compute the spring and damper parameters during runtime from the
            # impedance and other contact parameters.
            K_tc = 1 / (ξ_max * Ω * ζ) ** 2
            D_tc = 2 / (ξ_max * Ω)

            # If the user specifies K and D and they are negative, they are used
            # directly (as in MuJoCo's `solref` convention) and the computed `a_ref`
            # becomes something more similar to a classic Baumgarte regularization.
            # Non-negative values (the default is 0) select the gains derived from
            # the time constant and the damping coefficient.
            K_eff = jnp.where(K < 0, -K / ξ_max**2, K_tc)
            D_eff = jnp.where(D < 0, -D / ξ_max, D_tc)

            # Compute the reference acceleration.
            a_ref = -(D_eff * vel + K_eff * ξ * pos)

            return ξ, a_ref, K_eff, D_eff

        def compute_row(
            *,
            link_idx: jtp.Int,
            pos: jtp.Vector,
            vel: jtp.Vector,
        ) -> tuple[jtp.Vector, jtp.Matrix, jtp.Vector, jtp.Vector]:

            # Compute the reference acceleration.
            ξ, a_ref, K, D = imp_aref(pos=pos, vel=vel)

            # Compute the regularization term. `*` and `@` share precedence and
            # are evaluated left to right: the scaled impedance term is multiplied
            # by the inverse of the top-left 3x3 block of the parent link's
            # spatial inertia. The 1e-12 guards against division by a zero impedance.
            R = (
                (2 * μ**2 * (1 - ξ) / (ξ + 1e-12))
                * (1 + μ**2)
                @ jnp.linalg.inv(M_L[link_idx, :3, :3])
            )

            # Return the computed values, setting them to zero in case of no contact.
            is_active = (pos.dot(pos) > 0).astype(float)
            return jax.tree.map(
                lambda x: jnp.atleast_1d(x) * is_active, (a_ref, R, K, D)
            )

        # Evaluate all the rows in parallel and flatten each output by
        # concatenating the per-point values.
        a_ref, R, K, D = jax.tree.map(
            f=jnp.concatenate,
            tree=(
                *jax.vmap(compute_row)(
                    link_idx=parent_link_idx_of_enabled_collidable_points,
                    pos=position_constraint,
                    vel=velocity_constraint,
                ),
            ),
        )

        return a_ref, jnp.diag(R), K, D
|