jaxsim 0.6.2.dev2__py3-none-any.whl → 0.6.2.dev102__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jaxsim/__init__.py +1 -1
- jaxsim/_version.py +2 -2
- jaxsim/api/__init__.py +3 -1
- jaxsim/api/actuation_model.py +96 -0
- jaxsim/api/com.py +8 -8
- jaxsim/api/contact.py +15 -255
- jaxsim/api/contact_model.py +101 -0
- jaxsim/api/data.py +258 -556
- jaxsim/api/frame.py +7 -7
- jaxsim/api/integrators.py +76 -0
- jaxsim/api/kin_dyn_parameters.py +41 -58
- jaxsim/api/link.py +7 -7
- jaxsim/api/model.py +190 -453
- jaxsim/api/ode.py +34 -338
- jaxsim/api/references.py +2 -2
- jaxsim/exceptions.py +2 -2
- jaxsim/math/__init__.py +4 -3
- jaxsim/math/joint_model.py +17 -107
- jaxsim/mujoco/model.py +1 -1
- jaxsim/mujoco/utils.py +2 -2
- jaxsim/parsers/kinematic_graph.py +1 -3
- jaxsim/rbda/aba.py +7 -4
- jaxsim/rbda/collidable_points.py +7 -98
- jaxsim/rbda/contacts/__init__.py +2 -10
- jaxsim/rbda/contacts/common.py +0 -138
- jaxsim/rbda/contacts/relaxed_rigid.py +154 -9
- jaxsim/rbda/crba.py +5 -2
- jaxsim/rbda/forward_kinematics.py +37 -12
- jaxsim/rbda/jacobian.py +15 -6
- jaxsim/rbda/rnea.py +7 -4
- jaxsim/rbda/utils.py +3 -3
- jaxsim/utils/jaxsim_dataclass.py +5 -1
- {jaxsim-0.6.2.dev2.dist-info → jaxsim-0.6.2.dev102.dist-info}/METADATA +7 -9
- jaxsim-0.6.2.dev102.dist-info/RECORD +69 -0
- jaxsim/api/ode_data.py +0 -401
- jaxsim/integrators/__init__.py +0 -2
- jaxsim/integrators/common.py +0 -592
- jaxsim/integrators/fixed_step.py +0 -153
- jaxsim/integrators/variable_step.py +0 -706
- jaxsim/rbda/contacts/rigid.py +0 -462
- jaxsim/rbda/contacts/soft.py +0 -480
- jaxsim/rbda/contacts/visco_elastic.py +0 -1066
- jaxsim-0.6.2.dev2.dist-info/RECORD +0 -74
- {jaxsim-0.6.2.dev2.dist-info → jaxsim-0.6.2.dev102.dist-info}/LICENSE +0 -0
- {jaxsim-0.6.2.dev2.dist-info → jaxsim-0.6.2.dev102.dist-info}/WHEEL +0 -0
- {jaxsim-0.6.2.dev2.dist-info → jaxsim-0.6.2.dev102.dist-info}/top_level.txt +0 -0
jaxsim/rbda/contacts/soft.py
DELETED
@@ -1,480 +0,0 @@
|
|
1
|
-
from __future__ import annotations
|
2
|
-
|
3
|
-
import dataclasses
|
4
|
-
import functools
|
5
|
-
|
6
|
-
import jax
|
7
|
-
import jax.numpy as jnp
|
8
|
-
import jax_dataclasses
|
9
|
-
|
10
|
-
import jaxsim.api as js
|
11
|
-
import jaxsim.math
|
12
|
-
import jaxsim.typing as jtp
|
13
|
-
from jaxsim import logging
|
14
|
-
from jaxsim.math import StandardGravity
|
15
|
-
from jaxsim.terrain import Terrain
|
16
|
-
|
17
|
-
from . import common
|
18
|
-
|
19
|
-
try:
|
20
|
-
from typing import Self
|
21
|
-
except ImportError:
|
22
|
-
from typing_extensions import Self
|
23
|
-
|
24
|
-
|
25
|
-
@jax_dataclasses.pytree_dataclass
|
26
|
-
class SoftContactsParams(common.ContactsParams):
|
27
|
-
"""Parameters of the soft contacts model."""
|
28
|
-
|
29
|
-
K: jtp.Float = dataclasses.field(
|
30
|
-
default_factory=lambda: jnp.array(1e6, dtype=float)
|
31
|
-
)
|
32
|
-
|
33
|
-
D: jtp.Float = dataclasses.field(
|
34
|
-
default_factory=lambda: jnp.array(2000, dtype=float)
|
35
|
-
)
|
36
|
-
|
37
|
-
mu: jtp.Float = dataclasses.field(
|
38
|
-
default_factory=lambda: jnp.array(0.5, dtype=float)
|
39
|
-
)
|
40
|
-
|
41
|
-
p: jtp.Float = dataclasses.field(
|
42
|
-
default_factory=lambda: jnp.array(0.5, dtype=float)
|
43
|
-
)
|
44
|
-
|
45
|
-
q: jtp.Float = dataclasses.field(
|
46
|
-
default_factory=lambda: jnp.array(0.5, dtype=float)
|
47
|
-
)
|
48
|
-
|
49
|
-
def __hash__(self) -> int:
|
50
|
-
|
51
|
-
from jaxsim.utils.wrappers import HashedNumpyArray
|
52
|
-
|
53
|
-
return hash(
|
54
|
-
(
|
55
|
-
HashedNumpyArray.hash_of_array(self.K),
|
56
|
-
HashedNumpyArray.hash_of_array(self.D),
|
57
|
-
HashedNumpyArray.hash_of_array(self.mu),
|
58
|
-
HashedNumpyArray.hash_of_array(self.p),
|
59
|
-
HashedNumpyArray.hash_of_array(self.q),
|
60
|
-
)
|
61
|
-
)
|
62
|
-
|
63
|
-
def __eq__(self, other: SoftContactsParams) -> bool:
|
64
|
-
|
65
|
-
if not isinstance(other, SoftContactsParams):
|
66
|
-
return NotImplemented
|
67
|
-
|
68
|
-
return hash(self) == hash(other)
|
69
|
-
|
70
|
-
@classmethod
|
71
|
-
def build(
|
72
|
-
cls: type[Self],
|
73
|
-
*,
|
74
|
-
K: jtp.FloatLike = 1e6,
|
75
|
-
D: jtp.FloatLike = 2_000,
|
76
|
-
mu: jtp.FloatLike = 0.5,
|
77
|
-
p: jtp.FloatLike = 0.5,
|
78
|
-
q: jtp.FloatLike = 0.5,
|
79
|
-
) -> Self:
|
80
|
-
"""
|
81
|
-
Create a SoftContactsParams instance with specified parameters.
|
82
|
-
|
83
|
-
Args:
|
84
|
-
K: The stiffness parameter.
|
85
|
-
D: The damping parameter of the soft contacts model.
|
86
|
-
mu: The static friction coefficient.
|
87
|
-
p:
|
88
|
-
The exponent p corresponding to the damping-related non-linearity
|
89
|
-
of the Hunt/Crossley model.
|
90
|
-
q:
|
91
|
-
The exponent q corresponding to the spring-related non-linearity
|
92
|
-
of the Hunt/Crossley model
|
93
|
-
|
94
|
-
Returns:
|
95
|
-
A SoftContactsParams instance with the specified parameters.
|
96
|
-
"""
|
97
|
-
|
98
|
-
return SoftContactsParams(
|
99
|
-
K=jnp.array(K, dtype=float),
|
100
|
-
D=jnp.array(D, dtype=float),
|
101
|
-
mu=jnp.array(mu, dtype=float),
|
102
|
-
p=jnp.array(p, dtype=float),
|
103
|
-
q=jnp.array(q, dtype=float),
|
104
|
-
)
|
105
|
-
|
106
|
-
@classmethod
|
107
|
-
def build_default_from_jaxsim_model(
|
108
|
-
cls: type[Self],
|
109
|
-
model: js.model.JaxSimModel,
|
110
|
-
*,
|
111
|
-
standard_gravity: jtp.FloatLike = StandardGravity,
|
112
|
-
static_friction_coefficient: jtp.FloatLike = 0.5,
|
113
|
-
max_penetration: jtp.FloatLike = 0.001,
|
114
|
-
number_of_active_collidable_points_steady_state: jtp.IntLike = 1,
|
115
|
-
damping_ratio: jtp.FloatLike = 1.0,
|
116
|
-
p: jtp.FloatLike = 0.5,
|
117
|
-
q: jtp.FloatLike = 0.5,
|
118
|
-
) -> SoftContactsParams:
|
119
|
-
"""
|
120
|
-
Create a SoftContactsParams instance with good default parameters.
|
121
|
-
|
122
|
-
Args:
|
123
|
-
model: The target model.
|
124
|
-
standard_gravity: The standard gravity constant.
|
125
|
-
static_friction_coefficient:
|
126
|
-
The static friction coefficient between the model and the terrain.
|
127
|
-
max_penetration: The maximum penetration depth.
|
128
|
-
number_of_active_collidable_points_steady_state:
|
129
|
-
The number of contacts supporting the weight of the model
|
130
|
-
in steady state.
|
131
|
-
damping_ratio: The ratio controlling the damping behavior.
|
132
|
-
p:
|
133
|
-
The exponent p corresponding to the damping-related non-linearity
|
134
|
-
of the Hunt/Crossley model.
|
135
|
-
q:
|
136
|
-
The exponent q corresponding to the spring-related non-linearity
|
137
|
-
of the Hunt/Crossley model
|
138
|
-
|
139
|
-
Returns:
|
140
|
-
A `SoftContactsParams` instance with the specified parameters.
|
141
|
-
|
142
|
-
Note:
|
143
|
-
The `damping_ratio` parameter allows to operate on the following conditions:
|
144
|
-
- ξ > 1.0: over-damped
|
145
|
-
- ξ = 1.0: critically damped
|
146
|
-
- ξ < 1.0: under-damped
|
147
|
-
"""
|
148
|
-
|
149
|
-
# Use symbols for input parameters.
|
150
|
-
ξ = damping_ratio
|
151
|
-
δ_max = max_penetration
|
152
|
-
μc = static_friction_coefficient
|
153
|
-
|
154
|
-
# Compute the total mass of the model.
|
155
|
-
m = jnp.array(model.kin_dyn_parameters.link_parameters.mass).sum()
|
156
|
-
|
157
|
-
# Rename the standard gravity.
|
158
|
-
g = standard_gravity
|
159
|
-
|
160
|
-
# Compute the average support force on each collidable point.
|
161
|
-
f_average = m * g / number_of_active_collidable_points_steady_state
|
162
|
-
|
163
|
-
# Compute the stiffness to get the desired steady-state penetration.
|
164
|
-
# Note that this is dependent on the non-linear exponent used in
|
165
|
-
# the damping term of the Hunt/Crossley model.
|
166
|
-
K = f_average / jnp.power(δ_max, 1 + p)
|
167
|
-
|
168
|
-
# Compute the damping using the damping ratio.
|
169
|
-
critical_damping = 2 * jnp.sqrt(K * m)
|
170
|
-
D = ξ * critical_damping
|
171
|
-
|
172
|
-
return SoftContactsParams.build(K=K, D=D, mu=μc, p=p, q=q)
|
173
|
-
|
174
|
-
def valid(self) -> jtp.BoolLike:
|
175
|
-
"""
|
176
|
-
Check if the parameters are valid.
|
177
|
-
|
178
|
-
Returns:
|
179
|
-
`True` if the parameters are valid, `False` otherwise.
|
180
|
-
"""
|
181
|
-
|
182
|
-
return jnp.hstack(
|
183
|
-
[
|
184
|
-
self.K >= 0.0,
|
185
|
-
self.D >= 0.0,
|
186
|
-
self.mu >= 0.0,
|
187
|
-
self.p >= 0.0,
|
188
|
-
self.q >= 0.0,
|
189
|
-
]
|
190
|
-
).all()
|
191
|
-
|
192
|
-
|
193
|
-
@jax_dataclasses.pytree_dataclass
|
194
|
-
class SoftContacts(common.ContactModel):
|
195
|
-
"""Soft contacts model."""
|
196
|
-
|
197
|
-
@classmethod
|
198
|
-
def build(
|
199
|
-
cls: type[Self],
|
200
|
-
model: js.model.JaxSimModel | None = None,
|
201
|
-
**kwargs,
|
202
|
-
) -> Self:
|
203
|
-
"""
|
204
|
-
Create a `SoftContacts` instance with specified parameters.
|
205
|
-
|
206
|
-
Args:
|
207
|
-
model:
|
208
|
-
The robot model considered by the contact model.
|
209
|
-
If passed, it is used to estimate good default parameters.
|
210
|
-
**kwargs: Additional parameters to pass to the contact model.
|
211
|
-
|
212
|
-
Returns:
|
213
|
-
The `SoftContacts` instance.
|
214
|
-
"""
|
215
|
-
|
216
|
-
if len(kwargs) != 0:
|
217
|
-
logging.debug(msg=f"Ignoring extra arguments: {kwargs}")
|
218
|
-
|
219
|
-
return cls(**kwargs)
|
220
|
-
|
221
|
-
@classmethod
|
222
|
-
def zero_state_variables(cls, model: js.model.JaxSimModel) -> dict[str, jtp.Array]:
|
223
|
-
"""
|
224
|
-
Build zero state variables of the contact model.
|
225
|
-
"""
|
226
|
-
|
227
|
-
# Initialize the material deformation to zero.
|
228
|
-
tangential_deformation = jnp.zeros(
|
229
|
-
shape=(len(model.kin_dyn_parameters.contact_parameters.body), 3),
|
230
|
-
dtype=float,
|
231
|
-
)
|
232
|
-
|
233
|
-
return {"tangential_deformation": tangential_deformation}
|
234
|
-
|
235
|
-
@staticmethod
|
236
|
-
@functools.partial(jax.jit, static_argnames=("terrain",))
|
237
|
-
def hunt_crossley_contact_model(
|
238
|
-
position: jtp.VectorLike,
|
239
|
-
velocity: jtp.VectorLike,
|
240
|
-
tangential_deformation: jtp.VectorLike,
|
241
|
-
terrain: Terrain,
|
242
|
-
K: jtp.FloatLike,
|
243
|
-
D: jtp.FloatLike,
|
244
|
-
mu: jtp.FloatLike,
|
245
|
-
p: jtp.FloatLike = 0.5,
|
246
|
-
q: jtp.FloatLike = 0.5,
|
247
|
-
) -> tuple[jtp.Vector, jtp.Vector]:
|
248
|
-
"""
|
249
|
-
Compute the contact force using the Hunt/Crossley model.
|
250
|
-
|
251
|
-
Args:
|
252
|
-
position: The position of the collidable point.
|
253
|
-
velocity: The velocity of the collidable point.
|
254
|
-
tangential_deformation: The material deformation of the collidable point.
|
255
|
-
terrain: The terrain model.
|
256
|
-
K: The stiffness parameter.
|
257
|
-
D: The damping parameter of the soft contacts model.
|
258
|
-
mu: The static friction coefficient.
|
259
|
-
p:
|
260
|
-
The exponent p corresponding to the damping-related non-linearity
|
261
|
-
of the Hunt/Crossley model.
|
262
|
-
q:
|
263
|
-
The exponent q corresponding to the spring-related non-linearity
|
264
|
-
of the Hunt/Crossley model
|
265
|
-
|
266
|
-
Returns:
|
267
|
-
A tuple containing the computed contact force and the derivative of the
|
268
|
-
material deformation.
|
269
|
-
"""
|
270
|
-
|
271
|
-
# Convert the input vectors to arrays.
|
272
|
-
W_p_C = jnp.array(position, dtype=float).squeeze()
|
273
|
-
W_ṗ_C = jnp.array(velocity, dtype=float).squeeze()
|
274
|
-
m = jnp.array(tangential_deformation, dtype=float).squeeze()
|
275
|
-
|
276
|
-
# Use symbol for the static friction.
|
277
|
-
μ = mu
|
278
|
-
|
279
|
-
# Compute the penetration depth, its rate, and the considered terrain normal.
|
280
|
-
δ, δ̇, n̂ = common.compute_penetration_data(p=W_p_C, v=W_ṗ_C, terrain=terrain)
|
281
|
-
|
282
|
-
# There are few operations like computing the norm of a vector with zero length
|
283
|
-
# or computing the square root of zero that are problematic in an AD context.
|
284
|
-
# To avoid these issues, we introduce a small tolerance ε to their arguments
|
285
|
-
# and make sure that we do not check them against zero directly.
|
286
|
-
ε = jnp.finfo(float).eps
|
287
|
-
|
288
|
-
# Compute the powers of the penetration depth.
|
289
|
-
# Inject ε to address AD issues in differentiating the square root when
|
290
|
-
# p and q are fractional.
|
291
|
-
δp = jnp.power(δ + ε, p)
|
292
|
-
δq = jnp.power(δ + ε, q)
|
293
|
-
|
294
|
-
# ========================
|
295
|
-
# Compute the normal force
|
296
|
-
# ========================
|
297
|
-
|
298
|
-
# Non-linear spring-damper model (Hunt/Crossley model).
|
299
|
-
# This is the force magnitude along the direction normal to the terrain.
|
300
|
-
force_normal_mag = (K * δp) * δ + (D * δq) * δ̇
|
301
|
-
|
302
|
-
# Depending on the magnitude of δ̇, the normal force could be negative.
|
303
|
-
force_normal_mag = jnp.maximum(0.0, force_normal_mag)
|
304
|
-
|
305
|
-
# Compute the 3D linear force in C[W] frame.
|
306
|
-
f_normal = force_normal_mag * n̂
|
307
|
-
|
308
|
-
# ============================
|
309
|
-
# Compute the tangential force
|
310
|
-
# ============================
|
311
|
-
|
312
|
-
# Extract the tangential component of the velocity.
|
313
|
-
v_tangential = W_ṗ_C - jnp.dot(W_ṗ_C, n̂) * n̂
|
314
|
-
|
315
|
-
# Extract the normal and tangential components of the material deformation.
|
316
|
-
m_normal = jnp.dot(m, n̂) * n̂
|
317
|
-
m_tangential = m - jnp.dot(m, n̂) * n̂
|
318
|
-
|
319
|
-
# Compute the tangential force in the sticking case.
|
320
|
-
# Using the tangential component of the material deformation should not be
|
321
|
-
# necessary if the sticking-slipping transition occurs in a terrain area
|
322
|
-
# with a locally constant normal. However, this assumption is not true in
|
323
|
-
# general, especially for highly uneven terrains.
|
324
|
-
f_tangential = -((K * δp) * m_tangential + (D * δq) * v_tangential)
|
325
|
-
|
326
|
-
# Detect the contact type (sticking or slipping).
|
327
|
-
# Note that if there is no contact, sticking is set to True, and this detail
|
328
|
-
# is exploited in the computation of the `contact_status` variable.
|
329
|
-
sticking = jnp.logical_or(
|
330
|
-
δ <= 0, f_tangential.dot(f_tangential) <= (μ * force_normal_mag) ** 2
|
331
|
-
)
|
332
|
-
|
333
|
-
# Compute the direction of the tangential force.
|
334
|
-
# To prevent dividing by zero, we use a switch statement.
|
335
|
-
norm = jaxsim.math.safe_norm(f_tangential)
|
336
|
-
f_tangential_direction = f_tangential / (
|
337
|
-
norm + jnp.finfo(float).eps * (norm == 0)
|
338
|
-
)
|
339
|
-
|
340
|
-
# Project the tangential force to the friction cone if slipping.
|
341
|
-
f_tangential = jnp.where(
|
342
|
-
sticking,
|
343
|
-
f_tangential,
|
344
|
-
jnp.minimum(μ * force_normal_mag, norm) * f_tangential_direction,
|
345
|
-
)
|
346
|
-
|
347
|
-
# Set the tangential force to zero if there is no contact.
|
348
|
-
f_tangential = jnp.where(δ <= 0, jnp.zeros(3), f_tangential)
|
349
|
-
|
350
|
-
# =====================================
|
351
|
-
# Compute the material deformation rate
|
352
|
-
# =====================================
|
353
|
-
|
354
|
-
# Compute the derivative of the material deformation.
|
355
|
-
# Note that we included an additional relaxation of `m_normal` in the
|
356
|
-
# sticking case, so that the normal deformation that could have accumulated
|
357
|
-
# from a previous slipping phase can relax to zero.
|
358
|
-
ṁ_no_contact = -(K / D) * m
|
359
|
-
ṁ_sticking = v_tangential - (K / D) * m_normal
|
360
|
-
ṁ_slipping = -(f_tangential + (K * δp) * m_tangential) / (D * δq)
|
361
|
-
|
362
|
-
# Compute the contact status:
|
363
|
-
# 0: slipping
|
364
|
-
# 1: sticking
|
365
|
-
# 2: no contact
|
366
|
-
contact_status = sticking.astype(int)
|
367
|
-
contact_status += (δ <= 0).astype(int)
|
368
|
-
|
369
|
-
# Select the right material deformation rate depending on the contact status.
|
370
|
-
ṁ = jax.lax.select_n(contact_status, ṁ_slipping, ṁ_sticking, ṁ_no_contact)
|
371
|
-
|
372
|
-
# ==========================================
|
373
|
-
# Compute and return the final contact force
|
374
|
-
# ==========================================
|
375
|
-
|
376
|
-
# Sum the normal and tangential forces.
|
377
|
-
CW_fl = f_normal + f_tangential
|
378
|
-
|
379
|
-
return CW_fl, ṁ
|
380
|
-
|
381
|
-
@staticmethod
|
382
|
-
@functools.partial(jax.jit, static_argnames=("terrain",))
|
383
|
-
def compute_contact_force(
|
384
|
-
position: jtp.VectorLike,
|
385
|
-
velocity: jtp.VectorLike,
|
386
|
-
tangential_deformation: jtp.VectorLike,
|
387
|
-
parameters: SoftContactsParams,
|
388
|
-
terrain: Terrain,
|
389
|
-
) -> tuple[jtp.Vector, jtp.Vector]:
|
390
|
-
"""
|
391
|
-
Compute the contact force.
|
392
|
-
|
393
|
-
Args:
|
394
|
-
position: The position of the collidable point.
|
395
|
-
velocity: The velocity of the collidable point.
|
396
|
-
tangential_deformation: The material deformation of the collidable point.
|
397
|
-
parameters: The parameters of the soft contacts model.
|
398
|
-
terrain: The terrain model.
|
399
|
-
|
400
|
-
Returns:
|
401
|
-
A tuple containing the computed contact force and the derivative of the
|
402
|
-
material deformation.
|
403
|
-
"""
|
404
|
-
|
405
|
-
CW_fl, ṁ = SoftContacts.hunt_crossley_contact_model(
|
406
|
-
position=position,
|
407
|
-
velocity=velocity,
|
408
|
-
tangential_deformation=tangential_deformation,
|
409
|
-
terrain=terrain,
|
410
|
-
K=parameters.K,
|
411
|
-
D=parameters.D,
|
412
|
-
mu=parameters.mu,
|
413
|
-
p=parameters.p,
|
414
|
-
q=parameters.q,
|
415
|
-
)
|
416
|
-
|
417
|
-
# Pack a mixed 6D force.
|
418
|
-
CW_f = jnp.hstack([CW_fl, jnp.zeros(3)])
|
419
|
-
|
420
|
-
# Compute the 6D force transform from the mixed to the inertial-fixed frame.
|
421
|
-
W_Xf_CW = jaxsim.math.Adjoint.from_quaternion_and_translation(
|
422
|
-
translation=jnp.array(position), inverse=True
|
423
|
-
).T
|
424
|
-
|
425
|
-
# Compute the 6D force in the inertial-fixed frame.
|
426
|
-
W_f = W_Xf_CW @ CW_f
|
427
|
-
|
428
|
-
return W_f, ṁ
|
429
|
-
|
430
|
-
@staticmethod
|
431
|
-
@jax.jit
|
432
|
-
def compute_contact_forces(
|
433
|
-
model: js.model.JaxSimModel,
|
434
|
-
data: js.data.JaxSimModelData,
|
435
|
-
) -> tuple[jtp.Matrix, dict[str, jtp.PyTree]]:
|
436
|
-
"""
|
437
|
-
Compute the contact forces.
|
438
|
-
|
439
|
-
Args:
|
440
|
-
model: The model to consider.
|
441
|
-
data: The data of the considered model.
|
442
|
-
|
443
|
-
Returns:
|
444
|
-
A tuple containing as first element the computed contact forces, and as
|
445
|
-
second element a dictionary with derivative of the material deformation.
|
446
|
-
"""
|
447
|
-
|
448
|
-
# Get the indices of the enabled collidable points.
|
449
|
-
indices_of_enabled_collidable_points = (
|
450
|
-
model.kin_dyn_parameters.contact_parameters.indices_of_enabled_collidable_points
|
451
|
-
)
|
452
|
-
|
453
|
-
# Compute the position and linear velocities (mixed representation) of
|
454
|
-
# all the collidable points belonging to the robot and extract the ones
|
455
|
-
# for the enabled collidable points.
|
456
|
-
W_p_C, W_ṗ_C = js.contact.collidable_point_kinematics(model=model, data=data)
|
457
|
-
|
458
|
-
# Extract the material deformation corresponding to the collidable points.
|
459
|
-
m = data.state.extended["tangential_deformation"]
|
460
|
-
|
461
|
-
m_enabled = m[indices_of_enabled_collidable_points]
|
462
|
-
|
463
|
-
# Initialize the tangential deformation rate array for every collidable point.
|
464
|
-
ṁ = jnp.zeros_like(m)
|
465
|
-
|
466
|
-
# Compute the contact forces only for the enabled collidable points.
|
467
|
-
# Since we treat them as independent, we can vmap the computation.
|
468
|
-
W_f, ṁ_enabled = jax.vmap(
|
469
|
-
lambda p, v, m: SoftContacts.compute_contact_force(
|
470
|
-
position=p,
|
471
|
-
velocity=v,
|
472
|
-
tangential_deformation=m,
|
473
|
-
parameters=data.contacts_params,
|
474
|
-
terrain=model.terrain,
|
475
|
-
)
|
476
|
-
)(W_p_C, W_ṗ_C, m_enabled)
|
477
|
-
|
478
|
-
ṁ = ṁ.at[indices_of_enabled_collidable_points].set(ṁ_enabled)
|
479
|
-
|
480
|
-
return W_f, dict(m_dot=ṁ)
|