jaxsim 0.6.1.dev13__py3-none-any.whl → 0.6.2.dev102__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jaxsim/__init__.py +1 -1
- jaxsim/_version.py +2 -2
- jaxsim/api/__init__.py +3 -1
- jaxsim/api/actuation_model.py +96 -0
- jaxsim/api/com.py +8 -8
- jaxsim/api/contact.py +15 -255
- jaxsim/api/contact_model.py +101 -0
- jaxsim/api/data.py +258 -556
- jaxsim/api/frame.py +7 -7
- jaxsim/api/integrators.py +76 -0
- jaxsim/api/kin_dyn_parameters.py +41 -58
- jaxsim/api/link.py +7 -7
- jaxsim/api/model.py +190 -453
- jaxsim/api/ode.py +34 -338
- jaxsim/api/references.py +2 -2
- jaxsim/exceptions.py +2 -2
- jaxsim/math/__init__.py +4 -3
- jaxsim/math/joint_model.py +17 -107
- jaxsim/mujoco/model.py +1 -1
- jaxsim/mujoco/utils.py +2 -2
- jaxsim/parsers/kinematic_graph.py +1 -3
- jaxsim/rbda/aba.py +7 -4
- jaxsim/rbda/collidable_points.py +7 -98
- jaxsim/rbda/contacts/__init__.py +2 -10
- jaxsim/rbda/contacts/common.py +0 -138
- jaxsim/rbda/contacts/relaxed_rigid.py +154 -9
- jaxsim/rbda/crba.py +5 -2
- jaxsim/rbda/forward_kinematics.py +37 -12
- jaxsim/rbda/jacobian.py +15 -6
- jaxsim/rbda/rnea.py +7 -4
- jaxsim/rbda/utils.py +3 -3
- jaxsim/utils/jaxsim_dataclass.py +5 -1
- {jaxsim-0.6.1.dev13.dist-info → jaxsim-0.6.2.dev102.dist-info}/METADATA +7 -9
- jaxsim-0.6.2.dev102.dist-info/RECORD +69 -0
- jaxsim/api/ode_data.py +0 -401
- jaxsim/integrators/__init__.py +0 -2
- jaxsim/integrators/common.py +0 -592
- jaxsim/integrators/fixed_step.py +0 -153
- jaxsim/integrators/variable_step.py +0 -706
- jaxsim/rbda/contacts/rigid.py +0 -462
- jaxsim/rbda/contacts/soft.py +0 -480
- jaxsim/rbda/contacts/visco_elastic.py +0 -1066
- jaxsim-0.6.1.dev13.dist-info/RECORD +0 -74
- {jaxsim-0.6.1.dev13.dist-info → jaxsim-0.6.2.dev102.dist-info}/LICENSE +0 -0
- {jaxsim-0.6.1.dev13.dist-info → jaxsim-0.6.2.dev102.dist-info}/WHEEL +0 -0
- {jaxsim-0.6.1.dev13.dist-info → jaxsim-0.6.2.dev102.dist-info}/top_level.txt +0 -0
@@ -1,1066 +0,0 @@
|
|
1
|
-
from __future__ import annotations
|
2
|
-
|
3
|
-
import dataclasses
|
4
|
-
import functools
|
5
|
-
from typing import Any
|
6
|
-
|
7
|
-
import jax
|
8
|
-
import jax.numpy as jnp
|
9
|
-
import jax_dataclasses
|
10
|
-
|
11
|
-
import jaxsim
|
12
|
-
import jaxsim.api as js
|
13
|
-
import jaxsim.exceptions
|
14
|
-
import jaxsim.typing as jtp
|
15
|
-
from jaxsim import logging
|
16
|
-
from jaxsim.api.common import ModelDataWithVelocityRepresentation
|
17
|
-
from jaxsim.math import StandardGravity
|
18
|
-
from jaxsim.terrain import Terrain
|
19
|
-
|
20
|
-
from . import common
|
21
|
-
from .soft import SoftContacts, SoftContactsParams
|
22
|
-
|
23
|
-
try:
|
24
|
-
from typing import Self
|
25
|
-
except ImportError:
|
26
|
-
from typing_extensions import Self
|
27
|
-
|
28
|
-
|
29
|
-
@jax_dataclasses.pytree_dataclass
class ViscoElasticContactsParams(common.ContactsParams):
    """
    Parameters of the visco-elastic contacts model.

    All fields are stored as float JAX arrays. They are consumed by the
    Hunt/Crossley contact model evaluated by `ViscoElasticContacts`.
    """

    # Stiffness parameter of the contact model.
    K: jtp.Float = dataclasses.field(
        default_factory=lambda: jnp.array(1e6, dtype=float)
    )

    # Damping parameter of the contact model.
    D: jtp.Float = dataclasses.field(
        default_factory=lambda: jnp.array(2000, dtype=float)
    )

    # Static friction coefficient.
    static_friction: jtp.Float = dataclasses.field(
        default_factory=lambda: jnp.array(0.5, dtype=float)
    )

    # Exponent of the damping-related non-linearity of the Hunt/Crossley model.
    p: jtp.Float = dataclasses.field(
        default_factory=lambda: jnp.array(0.5, dtype=float)
    )

    # Exponent of the spring-related non-linearity of the Hunt/Crossley model.
    q: jtp.Float = dataclasses.field(
        default_factory=lambda: jnp.array(0.5, dtype=float)
    )

    @classmethod
    def build(
        cls: type[Self],
        K: jtp.FloatLike = 1e6,
        D: jtp.FloatLike = 2_000,
        static_friction: jtp.FloatLike = 0.5,
        p: jtp.FloatLike = 0.5,
        q: jtp.FloatLike = 0.5,
    ) -> Self:
        """
        Create a `ViscoElasticContactsParams` instance with specified parameters.

        Args:
            K: The stiffness parameter.
            D: The damping parameter of the contacts model.
            static_friction: The static friction coefficient.
            p:
                The exponent p corresponding to the damping-related non-linearity
                of the Hunt/Crossley model.
            q:
                The exponent q corresponding to the spring-related non-linearity
                of the Hunt/Crossley model.

        Returns:
            An instance of the calling class (`ViscoElasticContactsParams` or a
            subclass) whose fields are the given values converted to float arrays.
        """

        # Use `cls` so that subclasses calling `build` get their own type back.
        return cls(
            K=jnp.array(K, dtype=float),
            D=jnp.array(D, dtype=float),
            static_friction=jnp.array(static_friction, dtype=float),
            p=jnp.array(p, dtype=float),
            q=jnp.array(q, dtype=float),
        )

    @classmethod
    def build_default_from_jaxsim_model(
        cls: type[Self],
        model: js.model.JaxSimModel,
        *,
        standard_gravity: jtp.FloatLike = StandardGravity,
        static_friction_coefficient: jtp.FloatLike = 0.5,
        max_penetration: jtp.FloatLike = 0.001,
        number_of_active_collidable_points_steady_state: jtp.IntLike = 1,
        damping_ratio: jtp.FloatLike = 1.0,
        p: jtp.FloatLike = 0.5,
        q: jtp.FloatLike = 0.5,
    ) -> Self:
        """
        Create a `ViscoElasticContactsParams` instance with good default parameters.

        The stiffness, damping and friction values are delegated to
        `SoftContactsParams.build_default_from_jaxsim_model`. Only the
        Hunt/Crossley exponents `p` and `q` are taken directly from the arguments.

        Args:
            model: The target model.
            standard_gravity: The standard gravity constant.
            static_friction_coefficient:
                The static friction coefficient between the model and the terrain.
            max_penetration: The maximum penetration depth.
            number_of_active_collidable_points_steady_state:
                The number of contacts supporting the weight of the model
                in steady state.
            damping_ratio: The ratio controlling the damping behavior.
            p:
                The exponent p corresponding to the damping-related non-linearity
                of the Hunt/Crossley model.
            q:
                The exponent q corresponding to the spring-related non-linearity
                of the Hunt/Crossley model.

        Returns:
            An instance of the calling class with the estimated parameters.

        Note:
            The `damping_ratio` parameter allows to operate on the following conditions:
            - ξ > 1.0: over-damped
            - ξ = 1.0: critically damped
            - ξ < 1.0: under-damped
        """

        # Call the SoftContact builder instead of duplicating the logic.
        soft_contacts_params = SoftContactsParams.build_default_from_jaxsim_model(
            model=model,
            standard_gravity=standard_gravity,
            static_friction_coefficient=static_friction_coefficient,
            max_penetration=max_penetration,
            number_of_active_collidable_points_steady_state=number_of_active_collidable_points_steady_state,
            damping_ratio=damping_ratio,
        )

        return cls.build(
            K=soft_contacts_params.K,
            D=soft_contacts_params.D,
            static_friction=soft_contacts_params.mu,
            p=p,
            q=q,
        )

    def valid(self) -> jtp.BoolLike:
        """
        Check if the parameters are valid.

        Returns:
            A boolean JAX array that is `True` when all parameters are
            non-negative and `False` otherwise.
        """

        # Combine with element-wise `&` instead of Python `and`. `and` calls
        # `bool()` on its operands, which raises on traced arrays inside
        # `jax.jit`. The result is the same boolean scalar in eager mode.
        return (
            jnp.all(self.K >= 0.0)
            & jnp.all(self.D >= 0.0)
            & jnp.all(self.static_friction >= 0.0)
            & jnp.all(self.p >= 0.0)
            & jnp.all(self.q >= 0.0)
        )

    def __hash__(self) -> int:

        # Local import, presumably to avoid an import cycle at module load time.
        # NOTE(review): confirm.
        from jaxsim.utils.wrappers import HashedNumpyArray

        return hash(
            (
                HashedNumpyArray.hash_of_array(self.K),
                HashedNumpyArray.hash_of_array(self.D),
                HashedNumpyArray.hash_of_array(self.static_friction),
                HashedNumpyArray.hash_of_array(self.p),
                HashedNumpyArray.hash_of_array(self.q),
            )
        )

    def __eq__(self, other: object) -> bool:

        if not isinstance(other, ViscoElasticContactsParams):
            return False

        # Equality is defined through the content hash computed by `__hash__`.
        return hash(self) == hash(other)
|
185
|
-
|
186
|
-
|
187
|
-
@jax_dataclasses.pytree_dataclass
|
188
|
-
class ViscoElasticContacts(common.ContactModel):
|
189
|
-
"""Visco-elastic contacts model."""
|
190
|
-
|
191
|
-
max_squarings: jax_dataclasses.Static[int] = dataclasses.field(default=25)
|
192
|
-
|
193
|
-
@classmethod
def build(
    cls: type[Self],
    model: js.model.JaxSimModel | None = None,
    max_squarings: jtp.IntLike | None = None,
    **kwargs,
) -> Self:
    """
    Build a `ViscoElasticContacts` contact model.

    Args:
        model:
            The robot model considered by the contact model. It is accepted
            for interface compatibility. The visible implementation does not
            read it.
        max_squarings:
            The maximum number of squarings performed in the matrix exponential.
            When `None`, the dataclass default of the `max_squarings` field is used.
        **kwargs: Extra arguments, logged at debug level and otherwise ignored.

    Returns:
        The `ViscoElasticContacts` instance.
    """

    if kwargs:
        logging.debug(msg=f"Ignoring extra arguments: {kwargs}")

    # Fall back to the declared default of the dataclass field.
    resolved_max_squarings = (
        cls.__dataclass_fields__["max_squarings"].default
        if max_squarings is None
        else max_squarings
    )

    return cls(max_squarings=int(resolved_max_squarings))
|
225
|
-
|
226
|
-
@classmethod
def zero_state_variables(cls, model: js.model.JaxSimModel) -> dict[str, jtp.Array]:
    """
    Build zero state variables of the contact model.

    Args:
        model: The model whose contact parameters size the state.

    Returns:
        A dictionary with a single `"tangential_deformation"` entry. It holds a
        float zero array with one 3-element row per entry of
        `model.kin_dyn_parameters.contact_parameters.body`.
    """

    n_rows = len(model.kin_dyn_parameters.contact_parameters.body)

    # The material (tangential) deformation starts at zero.
    return {
        "tangential_deformation": jnp.zeros(shape=(n_rows, 3), dtype=float)
    }
|
239
|
-
|
240
|
-
@jax.jit
def compute_contact_forces(
    self,
    model: js.model.JaxSimModel,
    data: js.data.JaxSimModelData,
    *,
    dt: jtp.FloatLike | None = None,
    link_forces: jtp.MatrixLike | None = None,
    joint_force_references: jtp.VectorLike | None = None,
) -> tuple[jtp.Matrix, dict[str, jtp.PyTree]]:
    """
    Compute the contact forces.

    Args:
        model: The robot model considered by the contact model.
        data: The data of the considered model.
        dt: The time step to consider. If not specified, it is read from the model.
        link_forces:
            The 6D forces to apply to the links expressed in the frame corresponding
            to the velocity representation of `data`.
        joint_force_references: The joint force references to apply.

    Note:
        Unlike most other contact models, this one needs to know the
        integration step. How it behaves with high-order Runge-Kutta schemes
        is not straightforward to assess. For the time being, a simple forward
        Euler scheme is recommended. The main benefit of this model is that
        the stiff contact dynamics is integrated separately from the rest of
        the system dynamics. This lets simple integration schemes be used
        without significantly altering simulation stability.

    Returns:
        A tuple with two elements. The first is the average 6D contact force
        of each enabled collidable point, expressed in the inertial-fixed
        representation. The second is a dictionary with:
        - `W_f_avg2_C`: the average-of-average 6D contact forces, in the same
          representation.
        - `m_tf`: the tangential deformation returned by the exponential
          integration.
    """

    enabled_points = (
        model.kin_dyn_parameters.contact_parameters.indices_of_enabled_collidable_points
    )

    # Use the model time step unless an explicit one is provided.
    if dt is None:
        dt = model.time_step

    # Average and average-of-average linear contact forces in mixed
    # representation, from the closed-form integration of the contact dynamics.
    f_avg_mixed, f_avg2_mixed, m_tf = (
        ViscoElasticContacts._compute_contact_forces_with_exponential_integration(
            model=model,
            data=data,
            dt=jnp.array(dt).astype(float),
            link_forces=link_forces,
            joint_force_references=joint_force_references,
            indices_of_enabled_collidable_points=enabled_points,
            max_squarings=self.max_squarings,
        )
    )

    # Transforms of the mixed frames `C[W] = (W_p_C, [W])` of the enabled points.
    all_transforms = js.contact.transforms(model=model, data=data)
    W_H_C = all_transforms[enabled_points, :, :]

    def to_inertial(CW_fl_C, W_H_Ci):
        # Pad the linear force to a 6D force with zero angular part, then
        # convert it from mixed to inertial-fixed representation.
        CW_f_C = jnp.zeros(6).at[0:3].set(CW_fl_C)
        return ModelDataWithVelocityRepresentation.other_representation_to_inertial(
            array=CW_f_C,
            other_representation=jaxsim.VelRepr.Mixed,
            transform=W_H_Ci,
            is_force=True,
        )

    # Inner vmap runs over collidable points, outer vmap over the two force sets.
    to_inertial_all_points = jax.vmap(to_inertial)
    stacked_forces = jnp.stack([f_avg_mixed, f_avg2_mixed])

    W_f_avg_C, W_f_avg2_C = jax.vmap(
        lambda CW_fl: to_inertial_all_points(CW_fl, W_H_C)
    )(stacked_forces)

    return W_f_avg_C, dict(W_f_avg2_C=W_f_avg2_C, m_tf=m_tf)
|
327
|
-
|
328
|
-
@staticmethod
@functools.partial(jax.jit, static_argnames=("max_squarings",))
def _compute_contact_forces_with_exponential_integration(
    model: js.model.JaxSimModel,
    data: js.data.JaxSimModelData,
    *,
    dt: jtp.FloatLike,
    link_forces: jtp.MatrixLike | None = None,
    joint_force_references: jtp.VectorLike | None = None,
    indices_of_enabled_collidable_points: jtp.VectorLike | None = None,
    max_squarings: int = 25,
) -> tuple[jtp.Matrix, jtp.Matrix, jtp.Matrix]:
    """
    Compute the average contact forces by integrating the contact dynamics.

    First, the contact-point dynamics is linearized at the current state as
    ``x' = A x + b``, with ``x = [p, v, m]`` (positions, velocities and
    tangential deformations of the considered points). Then the integral and
    the double integral of ``x`` over ``[0, dt]`` are read from the
    exponential of an augmented matrix. Finally, they are mapped to forces
    through the linearized contact model ``A_sc @ x + b_sc``.

    Args:
        model: The robot model considered by the contact model.
        data: The data of the considered model.
        dt: The integration time step.
        link_forces: The 6D forces to apply to the links.
        joint_force_references: The joint force references to apply.
        indices_of_enabled_collidable_points:
            The indices of the enabled collidable points. When `None`, all
            collidable points are considered.
        max_squarings:
            The maximum number of squarings performed in the matrix exponential.

    Returns:
        A tuple containing:
        - The average linear contact forces over ``dt`` in mixed
          representation, one 3-element row per considered point.
        - The average of the average contact forces. These come from the
          double integral, scaled by ``2 / dt**2``.
        - The tangential deformation, one 3-element row per considered point.
          It is computed as ``x_int / dt``, the time-average of the
          linearized deformation over ``[0, dt]``.
          NOTE(review): this output is described as the final-state
          deformation. The exact final value of the linear system would be
          ``A @ x_int + b * dt + x_t0``. Confirm which one is intended.
    """

    # Fall back to all collidable points when no subset is given.
    if indices_of_enabled_collidable_points is None:
        indices = jnp.arange(
            len(model.kin_dyn_parameters.contact_parameters.body)
        ).astype(int)
    else:
        indices = indices_of_enabled_collidable_points

    # Initial state (p, v, m) of the considered collidable points.
    W_p_Ci, W_v_Ci = js.contact.collidable_point_kinematics(model, data)
    p_t0 = W_p_Ci[indices, :]
    v_t0 = W_v_Ci[indices, :]
    m_t0 = data.state.extended["tangential_deformation"][indices, :]

    # Linear contact-point dynamics (A, b) and linearized contact model
    # (A_sc, b_sc), both evaluated at (p, v, m)[t0].
    A, b, A_sc, b_sc = ViscoElasticContacts._contact_points_dynamics(
        model=model,
        data=data,
        link_forces=link_forces,
        joint_force_references=joint_force_references,
        indices_of_enabled_collidable_points=indices,
        p_t0=p_t0,
        v_t0=v_t0,
        m_t0=m_t0,
    )

    x_t0 = jnp.hstack([p_t0.flatten(), v_t0.flatten(), m_t0.flatten()])

    n_rows, n_cols = A.shape

    # Augmented matrix
    #     [ A  b  x_t0  0 ]
    #     [ 0  0  1     0 ]
    #     [ 0  0  0     1 ]
    #     [ 0  0  0     0 ]
    # Columns n+1 and n+2 of its exponential hold the single and double
    # integrals of x(t) over [0, dt].
    upper_rows = jnp.hstack(
        [A, jnp.column_stack([b, x_t0, jnp.zeros_like(x_t0)])]
    )
    lower_rows = jnp.vstack(
        [
            jnp.hstack([jnp.zeros(n_cols), *tail])
            for tail in ((0, 1, 0), (0, 0, 1), (0, 0, 0))
        ]
    )
    A_aug = jnp.vstack([upper_rows, lower_rows])

    exp_tA = jax.scipy.linalg.expm(
        (dt * A_aug).astype(float), max_squarings=max_squarings
    )

    # Keep the first n_rows rows and the last two columns of exp(dt * A_aug).
    pick_rows = jnp.hstack([jnp.eye(n_rows), jnp.zeros(shape=(n_rows, 3))])
    pick_cols = jnp.vstack([jnp.zeros(shape=(n_rows + 1, 2)), jnp.eye(2)])
    x_int, x_int2 = (pick_rows @ exp_tA @ pick_cols).T

    jaxsim.exceptions.raise_runtime_error_if(
        condition=jnp.isnan(x_int).any(),
        msg="NaN integration, try to increase `max_squarings` or decreasing `dt`",
    )

    # The outputs of the linearized contact model are stacked as
    # [forces; deformation rates]. Only the force half is kept.
    avg_outputs = (A_sc @ x_int / dt + b_sc).reshape(-1, 3)
    CW_f_avg, _ = jnp.split(avg_outputs, indices_or_sections=2)

    avg2_outputs = (A_sc @ x_int2 * 2 / (dt**2) + b_sc).reshape(-1, 3)
    CW_f_avg2, _ = jnp.split(avg2_outputs, indices_or_sections=2)

    # Deformation block of the time-averaged state.
    x_avg = x_int / dt
    m_tf = jnp.split(x_avg, 3)[2].reshape(-1, 3)

    return CW_f_avg, CW_f_avg2, m_tf
|
459
|
-
|
460
|
-
@staticmethod
@jax.jit
def _contact_points_dynamics(
    model: js.model.JaxSimModel,
    data: js.data.JaxSimModelData,
    *,
    link_forces: jtp.MatrixLike | None = None,
    joint_force_references: jtp.VectorLike | None = None,
    indices_of_enabled_collidable_points: jtp.VectorLike | None = None,
    p_t0: jtp.MatrixLike | None = None,
    v_t0: jtp.MatrixLike | None = None,
    m_t0: jtp.MatrixLike | None = None,
) -> tuple[jtp.Matrix, jtp.Vector, jtp.Matrix, jtp.Vector]:
    """
    Compute the dynamics of the contact points.

    Note:
        The system dynamics is projected to the contact space. The function
        returns the matrices of a linear system ``x' = A x + b``, with
        ``x = [p, v, m]``, that simulates its evolution. The contact model
        may be non-linear, so it is also linearized at the initial state as
        ``A_sc @ x + b_sc``.

    Args:
        model: The robot model considered by the contact model.
        data: The data of the considered model.
        link_forces: The 6D forces to apply to the links.
        joint_force_references: The joint force references to apply.
        indices_of_enabled_collidable_points:
            The indices of the enabled collidable points. When `None`, all
            collidable points are considered.
        p_t0: The initial position of the collidable points (read from `data` if `None`).
        v_t0: The initial velocity of the collidable points (read from `data` if `None`).
        m_t0: The initial tangential deformation of the collidable points
            (read from `data` if `None`).

    Returns:
        A tuple containing:
        - The `A` matrix of the linear system that models the contact dynamics.
        - The `b` vector of the linear system that models the contact dynamics.
        - The `A_sc` matrix of the linear system that approximates the contact model.
        - The `b_sc` vector of the linear system that approximates the contact model.
    """

    # -------------------------------------------------------
    # Resolve the optional arguments from `data` when missing
    # -------------------------------------------------------

    if indices_of_enabled_collidable_points is None:
        indices_of_enabled_collidable_points = jnp.arange(
            len(model.kin_dyn_parameters.contact_parameters.body)
        ).astype(int)
    idx = indices_of_enabled_collidable_points

    if p_t0 is None:
        p_t0 = js.contact.collidable_point_positions(model=model, data=data)[idx, :]
    p_t0 = jnp.atleast_2d(p_t0)

    if v_t0 is None:
        v_t0 = js.contact.collidable_point_velocities(model=model, data=data)[idx, :]
    v_t0 = jnp.atleast_2d(v_t0)

    if m_t0 is None:
        m_t0 = data.state.extended["tangential_deformation"][idx, :]
    m_t0 = jnp.atleast_2d(m_t0)

    # The 6D forces in `link_forces` are expected in the frame matching the
    # velocity representation of `data`.
    references = js.references.JaxSimModelReferences.build(
        model=model,
        link_forces=link_forces,
        joint_force_references=joint_force_references,
        data=data,
        velocity_representation=data.velocity_representation,
    )

    # ----------------------------------------------------
    # Linearize the contact model at every considered point
    # ----------------------------------------------------

    def linearize_point(p, v, m):
        return ViscoElasticContacts._linearize_contact_model(
            position=p,
            velocity=v,
            tangential_deformation=m,
            parameters=data.contacts_params,
            terrain=model.terrain,
        )

    A_sc_points, b_sc_points = jax.vmap(linearize_point)(p_t0, v_t0, m_t0)

    # The per-point Jacobians have columns [p, v, m] for that single point.
    # The global state is ordered as x = [p1, p2, ..., v1, v2, ..., m1, m2, ...],
    # so each column group is split out and laid out block-diagonally.
    A_sc_p, A_sc_v, A_sc_m = jnp.split(A_sc_points, indices_or_sections=3, axis=-1)

    # Rows 0:3 of each per-point output are forces, rows 3:6 deformation rates.
    force_rows, mdot_rows = slice(0, 3), slice(3, 6)

    def stack_blocks(rows):
        # Side-by-side block-diagonal matrices for the p, v and m columns.
        return jnp.hstack(
            [
                jax.scipy.linalg.block_diag(*A_blk[:, rows, :])
                for A_blk in (A_sc_p, A_sc_v, A_sc_m)
            ]
        )

    # All forces first, then all deformation rates.
    A_sc = jnp.vstack([stack_blocks(force_rows), stack_blocks(mdot_rows)])
    b_sc = jnp.hstack(
        [b_sc_points[:, rows].flatten() for rows in (force_rows, mdot_rows)]
    )

    # --------------------------------------------------------
    # Assemble the A and b matrices of the contact points dynamics
    # --------------------------------------------------------

    with data.switch_velocity_representation(jaxsim.VelRepr.Mixed):

        nu = data.generalized_velocity()

        M = js.model.free_floating_mass_matrix(model=model, data=data)

        CW_Jl_WC = js.contact.jacobian(
            model=model,
            data=data,
            output_vel_repr=jaxsim.VelRepr.Mixed,
        )[idx, 0:3, :]

        CW_Jl_dot_WC = js.contact.jacobian_derivative(
            model=model, data=data, output_vel_repr=jaxsim.VelRepr.Mixed
        )[idx, 0:3, :]

    J = jnp.vstack(CW_Jl_WC)
    J_dot = jnp.vstack(CW_Jl_dot_WC)

    # Delassus matrix J M^{-1} J^T (the inverse is applied through lstsq).
    delassus = J @ jnp.linalg.lstsq(M, J.T)[0]

    n_p = p_t0.flatten().size
    zeros_pp = jnp.zeros(shape=(n_p, n_p))
    eye_v = jnp.eye(v_t0.flatten().size)

    A_sc_forces, A_sc_mdot = jnp.split(A_sc, 2, axis=0)
    b_sc_forces, b_sc_mdot = jnp.split(b_sc, indices_or_sections=2)

    # p' = v, v' = delassus @ f(x) + p_ddot_free, m' = mdot(x).
    A = jnp.vstack(
        [
            jnp.hstack([zeros_pp, eye_v, zeros_pp]),
            delassus @ A_sc_forces,
            A_sc_mdot,
        ]
    )

    # Free (contact-less) system acceleration, computed in mixed representation.
    with (
        data.switch_velocity_representation(jaxsim.VelRepr.Mixed),
        references.switch_velocity_representation(jaxsim.VelRepr.Mixed),
    ):

        base_acc_free, joint_acc_free = js.ode.system_acceleration(
            model=model,
            data=data,
            link_forces=references.link_forces(model=model, data=data),
            joint_force_references=references.joint_force_references(model=model),
        )

    nu_dot_free = jnp.hstack([base_acc_free, joint_acc_free])

    # Collidable-point acceleration. This is the true derivative of the point
    # velocity only in mixed representation.
    p_ddot = J @ nu_dot_free + J_dot @ nu

    b = jnp.hstack(
        [
            jnp.zeros_like(p_t0.flatten()),
            p_ddot + delassus @ b_sc_forces,
            b_sc_mdot,
        ]
    )

    return A, b, A_sc, b_sc
|
659
|
-
|
660
|
-
@staticmethod
@functools.partial(jax.jit, static_argnames=("terrain",))
def _linearize_contact_model(
    position: jtp.VectorLike,
    velocity: jtp.VectorLike,
    tangential_deformation: jtp.VectorLike,
    parameters: ViscoElasticContactsParams,
    terrain: Terrain,
) -> tuple[jtp.Matrix, jtp.Vector]:
    """
    Linearize the Hunt/Crossley contact model at the initial state.

    The model maps ``x = [p, v, m]`` to the pair (contact force, deformation
    rate). It is replaced by the affine map ``A_sc @ x + b_sc``, which
    matches the non-linear model and its Jacobian at the given point.

    Args:
        position: The position of the contact point.
        velocity: The velocity of the contact point.
        tangential_deformation: The tangential deformation of the contact point.
        parameters: The parameters of the contact model.
        terrain: The considered terrain.

    Returns:
        A tuple containing the `A` matrix and the `b` vector of the linear system
        corresponding to the contact dynamics linearized at the initial state.
    """

    # Linearization point, as squeezed float arrays.
    p0, v0, m0 = (
        jnp.array(arg, dtype=float).squeeze()
        for arg in (position, velocity, tangential_deformation)
    )

    contact_model = functools.partial(
        ViscoElasticContacts._compute_contact_force_non_linear_model,
        parameters=parameters,
        terrain=terrain,
    )

    def output_component(index):
        # Select one output of the model: 0 -> linear force, 1 -> deformation rate.
        return lambda p, v, m: contact_model(
            position=p, velocity=v, tangential_deformation=m
        )[index]

    def jacobian_row_block(index):
        # [dy/dp, dy/dv, dy/dm] for the selected output, one reverse-mode
        # Jacobian per argument.
        y = output_component(index)
        return jnp.hstack(
            [jnp.vstack(jax.jacrev(y, argnums=k)(p0, v0, m0)) for k in (0, 1, 2)]
        )

    A_sc = jnp.vstack([jacobian_row_block(0), jacobian_row_block(1)])

    # Affine offset so that A_sc @ x0 + b_sc reproduces the non-linear output.
    f0, mdot0 = contact_model(position=p0, velocity=v0, tangential_deformation=m0)
    x0 = jnp.hstack([p0, v0, m0])
    b_sc = jnp.hstack([f0, mdot0]) - A_sc @ x0

    return A_sc, b_sc
|
753
|
-
|
754
|
-
@staticmethod
@functools.partial(jax.jit, static_argnames=("terrain",))
def _compute_contact_force_non_linear_model(
    position: jtp.VectorLike,
    velocity: jtp.VectorLike,
    tangential_deformation: jtp.VectorLike,
    parameters: ViscoElasticContactsParams,
    terrain: Terrain,
) -> tuple[jtp.Vector, jtp.Vector]:
    """
    Evaluate the non-linear Hunt/Crossley contact model at a single point.

    This is a thin adapter that maps the fields of `ViscoElasticContactsParams`
    onto the keyword arguments of `SoftContacts.hunt_crossley_contact_model`.

    Args:
        position: The position of the contact point.
        velocity: The velocity of the contact point.
        tangential_deformation: The tangential deformation of the contact point.
        parameters: The parameters of the contact model.
        terrain: The considered terrain.

    Returns:
        A tuple containing:
        - The linear contact force in the mixed contact frame.
        - The rate of material deformation.
    """

    # `static_friction` is passed to the soft-contacts model as `mu`.
    hunt_crossley_params = {
        "K": parameters.K,
        "D": parameters.D,
        "mu": parameters.static_friction,
        "p": parameters.p,
        "q": parameters.q,
    }

    force, deformation_rate = SoftContacts.hunt_crossley_contact_model(
        position=position,
        velocity=velocity,
        tangential_deformation=tangential_deformation,
        terrain=terrain,
        **hunt_crossley_params,
    )

    return force, deformation_rate
|
795
|
-
|
796
|
-
@staticmethod
@jax.jit
def integrate_data_with_average_contact_forces(
    model: js.model.JaxSimModel,
    data: js.data.JaxSimModelData,
    *,
    dt: jtp.FloatLike,
    link_forces: jtp.MatrixLike | None = None,
    joint_force_references: jtp.VectorLike | None = None,
    average_link_contact_forces_inertial: jtp.MatrixLike | None = None,
    average_of_average_link_contact_forces_mixed: jtp.MatrixLike | None = None,
) -> js.data.JaxSimModelData:
    """
    Advance the system state by integrating the dynamics.

    The generalized velocity is integrated with the inertial-fixed
    acceleration computed from the average contact forces. The generalized
    position is integrated with a second-order expansion driven by the mixed
    acceleration computed from the average of the average contact forces.

    Args:
        model: The model to consider.
        data: The data of the considered model.
        dt: The integration time step.
        link_forces:
            The 6D forces to apply to the links expressed in the frame corresponding
            to the velocity representation of `data`.
        joint_force_references: The joint force references to apply.
        average_link_contact_forces_inertial:
            The average contact forces computed with the exponential integrator and
            expressed in the inertial-fixed frame. Zeros are used if not given.
        average_of_average_link_contact_forces_mixed:
            The average of the average contact forces computed with the exponential
            integrator and expressed in the mixed frame. If not given, the
            average contact forces are used instead, after converting them from
            the inertial-fixed to the mixed representation.

    Returns:
        The data object storing the system state at the final time.
    """

    # Initial generalized position.
    s_t0 = data.joint_positions()
    W_p_B_t0 = data.base_position()
    W_Q_B_t0 = data.base_orientation(dcm=False)

    # Initial generalized velocity. The mixed base velocity gives the time
    # derivative of the base position and the angular velocity in the world frame.
    ṡ_t0 = data.joint_velocities()
    with data.switch_velocity_representation(jaxsim.VelRepr.Mixed):
        BW_v_WB_t0 = data.base_velocity()
        W_ṗ_B_t0 = BW_v_WB_t0[0:3]
        W_ω_WB_t0 = BW_v_WB_t0[3:6]

    with data.switch_velocity_representation(jaxsim.VelRepr.Inertial):
        W_ν_t0 = data.generalized_velocity()

    # We expect that the 6D forces of the `link_forces` argument are expressed
    # in the frame corresponding to the velocity representation of `data`.
    references = js.references.JaxSimModelReferences.build(
        model=model,
        link_forces=link_forces,
        joint_force_references=joint_force_references,
        data=data,
        velocity_representation=data.velocity_representation,
    )

    # Average contact forces, inertial-fixed representation.
    W_f_avg_L = (
        jnp.array(average_link_contact_forces_inertial)
        if average_link_contact_forces_inertial is not None
        else jnp.zeros_like(references._link_forces)
    ).astype(float)

    # Average of the average contact forces, mixed representation.
    if average_of_average_link_contact_forces_mixed is not None:
        LW_f_avg2_L = jnp.array(average_of_average_link_contact_forces_mixed).astype(
            float
        )
    else:
        # Fall back to the average forces. They are expressed in the
        # inertial-fixed frame, so they must be converted to the mixed
        # representation before being used as mixed link forces below.
        W_H_L = js.model.forward_kinematics(model=model, data=data)
        LW_f_avg2_L = jax.vmap(
            lambda W_f_L, W_H_L: (
                ModelDataWithVelocityRepresentation.inertial_to_other_representation(
                    array=W_f_L,
                    other_representation=jaxsim.VelRepr.Mixed,
                    transform=W_H_L,
                    is_force=True,
                )
            )
        )(W_f_avg_L, W_H_L).astype(float)

    # Compute the system inertial acceleration, used to integrate the system velocity.
    # It considers the average contact forces computed with the exponential integrator.
    with (
        data.switch_velocity_representation(jaxsim.VelRepr.Inertial),
        references.switch_velocity_representation(jaxsim.VelRepr.Inertial),
    ):

        W_ν̇_pr = jnp.hstack(
            js.ode.system_acceleration(
                model=model,
                data=data,
                joint_force_references=references.joint_force_references(
                    model=model
                ),
                link_forces=W_f_avg_L + references.link_forces(model=model, data=data),
            )
        )

    # Compute the system mixed acceleration, used to integrate the system position.
    # It considers the average of the average contact forces computed with the
    # exponential integrator.
    with (
        data.switch_velocity_representation(jaxsim.VelRepr.Mixed),
        references.switch_velocity_representation(jaxsim.VelRepr.Mixed),
    ):

        BW_ν̇_pr2 = jnp.hstack(
            js.ode.system_acceleration(
                model=model,
                data=data,
                joint_force_references=references.joint_force_references(
                    model=model
                ),
                link_forces=LW_f_avg2_L
                + references.link_forces(model=model, data=data),
            )
        )

    # Integrate the system velocity using the inertial-fixed acceleration.
    W_ν_plus = W_ν_t0 + dt * W_ν̇_pr

    # Integrate the system position using the mixed velocity.
    q_plus = jnp.hstack(
        [
            # Note: here both ṗ and p̈ -> need mixed representation.
            W_p_B_t0 + dt * W_ṗ_B_t0 + 0.5 * dt**2 * BW_ν̇_pr2[0:3],
            jaxsim.math.Quaternion.integration(
                dt=dt,
                quaternion=W_Q_B_t0,
                omega=(W_ω_WB_t0 + 0.5 * dt * BW_ν̇_pr2[3:6]),
                omega_in_body_fixed=False,
            ).squeeze(),
            s_t0 + dt * ṡ_t0 + 0.5 * dt**2 * BW_ν̇_pr2[6:],
        ]
    )

    # Create the data at the final time.
    data_tf = data.copy()
    data_tf = data_tf.reset_joint_positions(q_plus[7:])
    data_tf = data_tf.reset_base_position(q_plus[0:3])
    data_tf = data_tf.reset_base_quaternion(q_plus[3:7])
    data_tf = data_tf.reset_joint_velocities(W_ν_plus[6:])
    data_tf = data_tf.reset_base_velocity(
        W_ν_plus[0:6], velocity_representation=jaxsim.VelRepr.Inertial
    )

    # Restore the velocity representation of the input data.
    return data_tf.replace(
        velocity_representation=data.velocity_representation, validate=False
    )
|
932
|
-
|
933
|
-
|
934
|
-
@jax.jit
def step(
    model: js.model.JaxSimModel,
    data: js.data.JaxSimModelData,
    *,
    dt: jtp.FloatLike | None = None,
    link_forces: jtp.MatrixLike | None = None,
    joint_force_references: jtp.VectorLike | None = None,
) -> tuple[js.data.JaxSimModelData, dict[str, Any]]:
    """
    Step the system dynamics with the visco-elastic contact model.

    Args:
        model: The model to consider.
        data: The data of the considered model.
        dt: The time step to consider. If not specified, it is read from the model.
        link_forces:
            The 6D forces to apply to the links expressed in the frame corresponding to
            the velocity representation of `data`.
        joint_force_references: The joint force references to consider.

    Returns:
        A tuple containing the new data of the model
        and an empty dictionary of auxiliary data.
    """

    assert isinstance(model.contact_model, ViscoElasticContacts)
    assert isinstance(data.contacts_params, ViscoElasticContactsParams)

    # Compute the contact forces in inertial-fixed representation.
    # TODO: understand what's wrong in other representations.
    data_inertial_fixed = data.replace(
        velocity_representation=jaxsim.VelRepr.Inertial, validate=False
    )

    # Create the references object. The user link forces are expressed in the
    # velocity representation of `data`.
    references = js.references.JaxSimModelReferences.build(
        model=model,
        data=data,
        link_forces=link_forces,
        joint_force_references=joint_force_references,
        velocity_representation=data.velocity_representation,
    )

    # The helpers below receive `data_inertial_fixed`, and
    # `integrate_data_with_average_contact_forces` documents that it reads
    # `link_forces` in the representation of its `data` argument. Extract the
    # external link forces in the inertial-fixed representation so they are
    # not misinterpreted when `data` uses another representation.
    with references.switch_velocity_representation(jaxsim.VelRepr.Inertial):
        W_f_L_external = references.link_forces(model=model, data=data)

    joint_force_refs = references.joint_force_references(model=model)

    # Initialize the time step.
    dt = dt if dt is not None else model.time_step

    # Compute the contact forces with the exponential integrator.
    W_f_avg_C, aux_data = model.contact_model.compute_contact_forces(
        model=model,
        data=data_inertial_fixed,
        dt=jnp.array(dt).astype(float),
        link_forces=W_f_L_external,
        joint_force_references=joint_force_refs,
    )

    # Extract the final material deformation and the average of average forces
    # from the dictionary containing auxiliary data.
    m_tf = aux_data["m_tf"]
    W_f_avg2_C = aux_data["W_f_avg2_C"]

    # ===============================
    # Compute the link contact forces
    # ===============================

    # Get the link contact forces by summing the forces of contact points belonging
    # to the same link.
    W_f_avg_L, W_f_avg2_L = jax.vmap(
        lambda W_f_C: model.contact_model.link_forces_from_contact_forces(
            model=model, data=data_inertial_fixed, contact_forces=W_f_C
        )
    )(jnp.stack([W_f_avg_C, W_f_avg2_C]))

    # Compute the link transforms. They are always needed because the target
    # representation of the conversion below is always Mixed, independently of
    # the velocity representation of `data`.
    W_H_L = js.model.forward_kinematics(model=model, data=data)

    # For integration purpose, we need the average of average forces expressed in
    # mixed representation.
    LW_f_avg2_L = jax.vmap(
        lambda W_f_L, W_H_L: (
            ModelDataWithVelocityRepresentation.inertial_to_other_representation(
                array=W_f_L,
                other_representation=jaxsim.VelRepr.Mixed,
                transform=W_H_L,
                is_force=True,
            )
        )
    )(W_f_avg2_L, W_H_L)

    # ==========================
    # Integrate the system state
    # ==========================

    # Integrate the system dynamics using the average contact forces.
    data_tf: js.data.JaxSimModelData = (
        model.contact_model.integrate_data_with_average_contact_forces(
            model=model,
            data=data_inertial_fixed,
            dt=dt,
            link_forces=W_f_L_external,
            joint_force_references=joint_force_refs,
            average_link_contact_forces_inertial=W_f_avg_L,
            average_of_average_link_contact_forces_mixed=LW_f_avg2_L,
        )
    )

    # Store the tangential deformation at the final state.
    # Note that this was integrated in the continuous time domain, therefore it should
    # be much more accurate than the one computed with the discrete soft contacts.
    with data_tf.mutable_context():

        # Extract the indices corresponding to the enabled collidable points.
        # The visco-elastic contact model computed only their contact forces.
        indices_of_enabled_collidable_points = (
            model.kin_dyn_parameters.contact_parameters.indices_of_enabled_collidable_points
        )

        data_tf.state.extended |= {
            "tangential_deformation": data_tf.state.extended["tangential_deformation"]
            .at[indices_of_enabled_collidable_points]
            .set(m_tf)
        }

    # Restore the original velocity representation.
    data_tf = data_tf.replace(
        velocity_representation=data.velocity_representation, validate=False
    )

    return data_tf, {}
|