jaxsim 0.4.3.dev68__py3-none-any.whl → 0.4.3.dev77__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jaxsim/__init__.py +5 -0
- jaxsim/_version.py +2 -2
- jaxsim/api/contact.py +32 -1
- jaxsim/api/data.py +68 -20
- jaxsim/api/joint.py +62 -2
- jaxsim/api/model.py +37 -23
- jaxsim/api/ode.py +29 -25
- jaxsim/api/ode_data.py +11 -1
- jaxsim/integrators/common.py +1 -1
- jaxsim/math/inertia.py +1 -1
- jaxsim/mujoco/loaders.py +3 -3
- jaxsim/parsers/kinematic_graph.py +3 -3
- jaxsim/parsers/rod/parser.py +18 -14
- jaxsim/rbda/contacts/relaxed_rigid.py +409 -0
- jaxsim/rbda/contacts/rigid.py +21 -41
- jaxsim/terrain/terrain.py +41 -25
- jaxsim/typing.py +1 -1
- jaxsim/utils/jaxsim_dataclass.py +12 -9
- jaxsim/utils/wrappers.py +1 -1
- {jaxsim-0.4.3.dev68.dist-info → jaxsim-0.4.3.dev77.dist-info}/METADATA +2 -1
- {jaxsim-0.4.3.dev68.dist-info → jaxsim-0.4.3.dev77.dist-info}/RECORD +24 -23
- {jaxsim-0.4.3.dev68.dist-info → jaxsim-0.4.3.dev77.dist-info}/WHEEL +1 -1
- {jaxsim-0.4.3.dev68.dist-info → jaxsim-0.4.3.dev77.dist-info}/LICENSE +0 -0
- {jaxsim-0.4.3.dev68.dist-info → jaxsim-0.4.3.dev77.dist-info}/top_level.txt +0 -0
jaxsim/parsers/rod/parser.py
CHANGED
@@ -223,7 +223,7 @@ def extract_model_data(
|
|
223
223
|
child=links_dict[j.child],
|
224
224
|
jtype=utils.joint_to_joint_type(joint=j),
|
225
225
|
axis=(
|
226
|
-
np.array(j.axis.xyz.xyz)
|
226
|
+
np.array(j.axis.xyz.xyz, dtype=float)
|
227
227
|
if j.axis is not None
|
228
228
|
and j.axis.xyz is not None
|
229
229
|
and j.axis.xyz.xyz is not None
|
@@ -232,39 +232,43 @@ def extract_model_data(
|
|
232
232
|
pose=j.pose.transform() if j.pose is not None else np.eye(4),
|
233
233
|
initial_position=0.0,
|
234
234
|
position_limit=(
|
235
|
-
(
|
236
|
-
|
237
|
-
if j.axis is not None
|
238
|
-
|
235
|
+
float(
|
236
|
+
j.axis.limit.lower
|
237
|
+
if j.axis is not None
|
238
|
+
and j.axis.limit is not None
|
239
|
+
and j.axis.limit.lower is not None
|
240
|
+
else jnp.finfo(float).min
|
239
241
|
),
|
240
|
-
(
|
241
|
-
|
242
|
-
if j.axis is not None
|
243
|
-
|
242
|
+
float(
|
243
|
+
j.axis.limit.upper
|
244
|
+
if j.axis is not None
|
245
|
+
and j.axis.limit is not None
|
246
|
+
and j.axis.limit.upper is not None
|
247
|
+
else jnp.finfo(float).max
|
244
248
|
),
|
245
249
|
),
|
246
|
-
friction_static=(
|
250
|
+
friction_static=float(
|
247
251
|
j.axis.dynamics.friction
|
248
252
|
if j.axis is not None
|
249
253
|
and j.axis.dynamics is not None
|
250
254
|
and j.axis.dynamics.friction is not None
|
251
255
|
else 0.0
|
252
256
|
),
|
253
|
-
friction_viscous=(
|
257
|
+
friction_viscous=float(
|
254
258
|
j.axis.dynamics.damping
|
255
259
|
if j.axis is not None
|
256
260
|
and j.axis.dynamics is not None
|
257
261
|
and j.axis.dynamics.damping is not None
|
258
262
|
else 0.0
|
259
263
|
),
|
260
|
-
position_limit_damper=(
|
264
|
+
position_limit_damper=float(
|
261
265
|
j.axis.limit.dissipation
|
262
266
|
if j.axis is not None
|
263
267
|
and j.axis.limit is not None
|
264
268
|
and j.axis.limit.dissipation is not None
|
265
269
|
else 0.0
|
266
270
|
),
|
267
|
-
position_limit_spring=(
|
271
|
+
position_limit_spring=float(
|
268
272
|
j.axis.limit.stiffness
|
269
273
|
if j.axis is not None
|
270
274
|
and j.axis.limit is not None
|
@@ -273,7 +277,7 @@ def extract_model_data(
|
|
273
277
|
),
|
274
278
|
)
|
275
279
|
for j in sdf_model.joints()
|
276
|
-
if j.type in {"revolute", "prismatic", "fixed"}
|
280
|
+
if j.type in {"revolute", "continuous", "prismatic", "fixed"}
|
277
281
|
and j.parent != "world"
|
278
282
|
and j.child in links_dict.keys()
|
279
283
|
]
|
@@ -0,0 +1,409 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
3
|
+
import dataclasses
|
4
|
+
from typing import Any
|
5
|
+
|
6
|
+
import jax
|
7
|
+
import jax.numpy as jnp
|
8
|
+
import jax_dataclasses
|
9
|
+
import jaxopt
|
10
|
+
|
11
|
+
import jaxsim.api as js
|
12
|
+
import jaxsim.typing as jtp
|
13
|
+
from jaxsim.api.common import VelRepr
|
14
|
+
from jaxsim.math import Adjoint
|
15
|
+
from jaxsim.terrain.terrain import FlatTerrain, Terrain
|
16
|
+
|
17
|
+
from .common import ContactModel, ContactsParams, ContactsState
|
18
|
+
|
19
|
+
|
20
|
+
@jax_dataclasses.pytree_dataclass
class RelaxedRigidContactsParams(ContactsParams):
    """Parameters of the relaxed rigid contacts model."""

    # NOTE: the declaration order of the fields below matters:
    # `RelaxedRigidContacts._regularizers` unpacks them positionally with `astuple`.

    # Time constant
    time_constant: jtp.Float = dataclasses.field(
        default_factory=lambda: jnp.array(0.01, dtype=float)
    )

    # Adimensional damping coefficient
    damping_coefficient: jtp.Float = dataclasses.field(
        default_factory=lambda: jnp.array(1.0, dtype=float)
    )

    # Minimum impedance
    d_min: jtp.Float = dataclasses.field(
        default_factory=lambda: jnp.array(0.9, dtype=float)
    )

    # Maximum impedance
    d_max: jtp.Float = dataclasses.field(
        default_factory=lambda: jnp.array(0.95, dtype=float)
    )

    # Width of the impedance function
    width: jtp.Float = dataclasses.field(
        default_factory=lambda: jnp.array(0.0001, dtype=float)
    )

    # Midpoint of the impedance function
    midpoint: jtp.Float = dataclasses.field(
        default_factory=lambda: jnp.array(0.1, dtype=float)
    )

    # Power exponent of the impedance function
    power: jtp.Float = dataclasses.field(
        default_factory=lambda: jnp.array(1.0, dtype=float)
    )

    # Stiffness. Negative values are interpreted as a spring; non-negative
    # values fall back to a stiffness derived from the time constant and the
    # damping coefficient.
    stiffness: jtp.Float = dataclasses.field(
        default_factory=lambda: jnp.array(0.0, dtype=float)
    )

    # Damping. Negative values are interpreted as a damper; non-negative
    # values fall back to a damping derived from the time constant.
    damping: jtp.Float = dataclasses.field(
        default_factory=lambda: jnp.array(0.0, dtype=float)
    )

    # Friction coefficient
    mu: jtp.Float = dataclasses.field(
        default_factory=lambda: jnp.array(0.5, dtype=float)
    )

    # Maximum number of iterations of the solver
    max_iterations: jtp.Int = dataclasses.field(
        default_factory=lambda: jnp.array(50, dtype=int)
    )

    # Solver tolerance
    tolerance: jtp.Float = dataclasses.field(
        default_factory=lambda: jnp.array(1e-6, dtype=float)
    )

    def __hash__(self) -> int:
        from jaxsim.utils.wrappers import HashedNumpyArray

        return hash(
            (
                HashedNumpyArray(self.time_constant),
                HashedNumpyArray(self.damping_coefficient),
                HashedNumpyArray(self.d_min),
                HashedNumpyArray(self.d_max),
                HashedNumpyArray(self.width),
                HashedNumpyArray(self.midpoint),
                HashedNumpyArray(self.power),
                HashedNumpyArray(self.stiffness),
                HashedNumpyArray(self.damping),
                HashedNumpyArray(self.mu),
                HashedNumpyArray(self.max_iterations),
                HashedNumpyArray(self.tolerance),
            )
        )

    def __eq__(self, other: RelaxedRigidContactsParams) -> bool:
        # Objects of a different type are never equal, regardless of their hash.
        if not isinstance(other, RelaxedRigidContactsParams):
            return False

        return hash(self) == hash(other)

    @classmethod
    def build(
        cls,
        time_constant: jtp.FloatLike | None = None,
        damping_coefficient: jtp.FloatLike | None = None,
        d_min: jtp.FloatLike | None = None,
        d_max: jtp.FloatLike | None = None,
        width: jtp.FloatLike | None = None,
        midpoint: jtp.FloatLike | None = None,
        power: jtp.FloatLike | None = None,
        stiffness: jtp.FloatLike | None = None,
        damping: jtp.FloatLike | None = None,
        mu: jtp.FloatLike | None = None,
        max_iterations: jtp.IntLike | None = None,
        tolerance: jtp.FloatLike | None = None,
    ) -> RelaxedRigidContactsParams:
        """
        Create a `RelaxedRigidContactsParams` instance.

        Every argument left to `None` takes the default value of the
        corresponding field. Provided values are converted to arrays having
        the same dtype as the field default.

        Returns:
            A `RelaxedRigidContactsParams` instance.
        """

        # The arguments are collected explicitly instead of through `locals()`,
        # whose content inside a comprehension depends on the Python version.
        values = {
            "time_constant": time_constant,
            "damping_coefficient": damping_coefficient,
            "d_min": d_min,
            "d_max": d_max,
            "width": width,
            "midpoint": midpoint,
            "power": power,
            "stiffness": stiffness,
            "damping": damping,
            "mu": mu,
            "max_iterations": max_iterations,
            "tolerance": tolerance,
        }

        # The fields are declared with `default_factory`, therefore their
        # `Field.default` is `dataclasses.MISSING`. Take the defaults from a
        # default-constructed instance instead.
        default = cls()

        def _to_array(name: str, value: Any) -> jax.Array:
            default_value = getattr(default, name)
            return jnp.array(
                default_value if value is None else value,
                dtype=default_value.dtype,
            )

        return cls(**{name: _to_array(name, value) for name, value in values.items()})

    def valid(self) -> bool:
        """
        Check whether the parameters are within their admissible ranges.

        Returns:
            `True` if all the checks pass, `False` otherwise.
        """

        return bool(
            jnp.all(self.time_constant >= 0.0)
            and jnp.all(self.damping_coefficient > 0.0)
            and jnp.all(self.d_min >= 0.0)
            and jnp.all(self.d_max <= 1.0)
            and jnp.all(self.d_min <= self.d_max)
            and jnp.all(self.width >= 0.0)
            and jnp.all(self.midpoint >= 0.0)
            and jnp.all(self.power >= 0.0)
            and jnp.all(self.mu >= 0.0)
            and jnp.all(self.max_iterations > 0)
            and jnp.all(self.tolerance > 0.0)
        )
|
149
|
+
|
150
|
+
|
151
|
+
@jax_dataclasses.pytree_dataclass
class RelaxedRigidContactsState(ContactsState):
    """
    State of the relaxed rigid contacts model.

    This class declares no fields of its own.
    """

    def __eq__(self, other: RelaxedRigidContactsState) -> bool:
        # Equality is defined in terms of the hashes of the two objects.
        lhs, rhs = hash(self), hash(other)
        return lhs == rhs

    @staticmethod
    def build() -> RelaxedRigidContactsState:
        """Build a new `RelaxedRigidContactsState` instance."""

        state = RelaxedRigidContactsState()
        return state

    @staticmethod
    def zero(model: js.model.JaxSimModel) -> RelaxedRigidContactsState:
        """
        Build a zero `RelaxedRigidContactsState` instance from a `JaxSimModel`.

        Args:
            model: The model. It is not used to build the state.

        Returns:
            An instance equivalent to the one returned by `build()`.
        """

        return RelaxedRigidContactsState.build()

    def valid(self, model: js.model.JaxSimModel) -> bool:
        """
        Check whether the state is valid for the given model.

        Args:
            model: The model. It is not used by the check.

        Returns:
            Always `True`.
        """

        return True
|
171
|
+
|
172
|
+
|
173
|
+
@jax_dataclasses.pytree_dataclass
class RelaxedRigidContacts(ContactModel):
    """Relaxed rigid contacts model."""

    # Parameters of the contact model.
    parameters: RelaxedRigidContactsParams = dataclasses.field(
        default_factory=RelaxedRigidContactsParams
    )

    # Terrain providing the height and the normal used to detect the contacts.
    terrain: jax_dataclasses.Static[Terrain] = dataclasses.field(
        default_factory=FlatTerrain
    )

    def compute_contact_forces(
        self,
        position: jtp.Vector,
        velocity: jtp.Vector,
        model: js.model.JaxSimModel,
        data: js.data.JaxSimModelData,
        link_forces: jtp.MatrixLike | None = None,
        joint_force_references: jtp.VectorLike | None = None,
    ) -> tuple[jtp.Vector, tuple[Any, ...]]:
        """
        Compute the contact forces.

        Args:
            position:
                The positions of the collidable points, stacked row-wise
                (its columns are unpacked as the x, y, z coordinates).
            velocity:
                The linear velocities of the collidable points, stacked row-wise.
            model: The `JaxSimModel` instance.
            data: The `JaxSimModelData` instance.
            link_forces:
                Optional `(n_links, 6)` matrix of external forces acting on the links,
                expressed in the same representation of data.
            joint_force_references:
                Optional `(n_joints,)` vector of joint forces.

        Returns:
            A tuple whose first element contains the 6D contact forces of the
            collidable points, converted from the mixed to the inertial
            representation, and whose second element is `(None,)`.
        """

        link_forces = (
            link_forces
            if link_forces is not None
            else jnp.zeros((model.number_of_links(), 6))
        )

        joint_force_references = (
            joint_force_references
            if joint_force_references is not None
            else jnp.zeros(model.number_of_joints())
        )

        references = js.references.JaxSimModelReferences.build(
            model=model,
            data=data,
            velocity_representation=data.velocity_representation,
            link_forces=link_forces,
            joint_force_references=joint_force_references,
        )

        def _detect_contact(x: jtp.Array, y: jtp.Array, z: jtp.Array) -> jtp.Array:
            """Return the height of a point projected on the terrain normal."""

            x, y, z = jax.tree.map(jnp.squeeze, (x, y, z))

            # Use the same terrain for both the normal and the height.
            n̂ = self.terrain.normal(x=x, y=y).squeeze()
            h = jnp.array([0, 0, z - self.terrain.height(x=x, y=y)])

            return jnp.dot(h, n̂)

        # Compute the activation state of the collidable points:
        # a point is considered in contact when its value is negative.
        δ = jax.vmap(_detect_contact)(*position.T)

        with (
            references.switch_velocity_representation(VelRepr.Mixed),
            data.switch_velocity_representation(VelRepr.Mixed),
        ):
            M = js.model.free_floating_mass_matrix(model=model, data=data)

            # Linear part of the contact Jacobians, zeroed for inactive points.
            Jl_WC = jnp.vstack(
                jax.vmap(lambda J, height: J * (height < 0))(
                    js.contact.jacobian(model=model, data=data)[:, :3, :], δ
                )
            )

            W_H_C = js.contact.transforms(model=model, data=data)

            # Acceleration of the system without contact forces. Both the
            # external link forces and the joint force references are included.
            BW_ν̇_free = jnp.hstack(
                js.ode.system_acceleration(
                    model=model,
                    data=data,
                    joint_forces=references.joint_force_references(model=model),
                    link_forces=references.link_forces(model=model, data=data),
                )
            )

            BW_ν = data.generalized_velocity()

            # Linear part of the contact Jacobian derivatives, zeroed for
            # inactive points.
            J̇_WC = jnp.vstack(
                jax.vmap(lambda J̇, height: J̇ * (height < 0))(
                    js.contact.jacobian_derivative(model=model, data=data)[:, :3], δ
                ),
            )

        a_ref, R, K, D = self._regularizers(
            model=model,
            penetration=δ,
            velocity=velocity,
            parameters=self.parameters,
        )

        # Delassus matrix J M⁻¹ Jᵀ of the active contacts.
        G = Jl_WC @ jnp.linalg.lstsq(M, Jl_WC.T)[0]

        # Linear acceleration of the contact points without contact forces.
        CW_al_free_WC = Jl_WC @ BW_ν̇_free + J̇_WC @ BW_ν

        # Calculate quantities for the least-squares problem min_x ||A x + b||².
        A = G + R
        b = CW_al_free_WC - a_ref

        def objective(x: jtp.Array) -> jtp.Array:
            return jnp.sum(jnp.square(A @ x + b))

        # Compute the 3D linear force in C[W] frame
        opt = jaxopt.LBFGS(
            fun=objective,
            maxiter=self.parameters.max_iterations,
            tol=self.parameters.tolerance,
            maxls=30,
            history_size=10,
            max_stepsize=100.0,
        )

        # Warm-start the solver with a spring-damper guess.
        init_params = (
            K[:, jnp.newaxis] * jnp.zeros_like(position).at[:, 2].set(δ)
            + D[:, jnp.newaxis] * velocity
        ).flatten()

        CW_f_Ci = opt.run(init_params=init_params).params.reshape(-1, 3)

        def mixed_to_inertial(W_H_C: jax.Array, CW_fl: jax.Array) -> jax.Array:
            """Convert a linear force from mixed to inertial 6D representation."""

            W_Xf_CW = Adjoint.from_transform(
                W_H_C.at[0:3, 0:3].set(jnp.eye(3)),
                inverse=True,
            ).T

            return W_Xf_CW @ jnp.hstack([CW_fl, jnp.zeros(3)])

        W_f_C = jax.vmap(mixed_to_inertial)(W_H_C, CW_f_Ci)

        return W_f_C, (None,)

    @staticmethod
    def _regularizers(
        model: js.model.JaxSimModel,
        penetration: jtp.Array,
        velocity: jtp.Array,
        parameters: RelaxedRigidContactsParams,
    ) -> tuple:
        """
        Compute the reference acceleration and the regularization terms.

        Args:
            model: The jaxsim model.
            penetration: The penetration of the collidable points.
            velocity: The velocity of the collidable points.
            parameters: The parameters of the relaxed rigid contacts model.

        Returns:
            A tuple containing the reference acceleration, the regularization
            matrix, the stiffness, and the damping, stacked over the collidable
            points. All terms of inactive points (non-negative penetration) are zero.
        """

        # NOTE: relies on the declaration order of the fields of
        # `RelaxedRigidContactsParams`. The user stiffness and damping get
        # distinct names so that they are not shadowed by the outputs below.
        (
            Ω,
            ζ,
            ξ_min,
            ξ_max,
            width,
            mid,
            p,
            stiffness,
            damping,
            μ,
            *_,
        ) = jax_dataclasses.astuple(parameters)

        # Spatial inertia matrices of all the links.
        M_L = js.model.link_spatial_inertia_matrices(model=model)

        def _imp_aref(
            penetration: jtp.Array,
            velocity: jtp.Array,
        ) -> tuple[jtp.Array, jtp.Array, jtp.Array, jtp.Array]:
            """
            Calculates impedance and offset acceleration in constraint frame.

            Args:
                penetration: penetration in constraint frame
                velocity: velocity in constraint frame

            Returns:
                imp: impedance, clipped within [ξ_min, ξ_max]
                a_ref: offset acceleration in constraint frame
                K: computed stiffness
                D: computed damping
            """
            position = jnp.zeros(shape=(3,)).at[2].set(penetration)

            imp_x = jnp.abs(position) / width
            imp_a = (1.0 / jnp.power(mid, p - 1)) * jnp.power(imp_x, p)

            imp_b = 1 - (1.0 / jnp.power(1 - mid, p - 1)) * jnp.power(1 - imp_x, p)

            imp_y = jnp.where(imp_x < mid, imp_a, imp_b)

            imp = jnp.clip(ξ_min + imp_y * (ξ_max - ξ_min), ξ_min, ξ_max)
            imp = jnp.atleast_1d(jnp.where(imp_x > 1.0, ξ_max, imp))

            # When passing negative values, K and D represent a spring and damper, respectively.
            K_f = jnp.where(
                stiffness < 0, -stiffness / ξ_max**2, 1 / (ξ_max * Ω * ζ) ** 2
            )
            D_f = jnp.where(damping < 0, -damping / ξ_max, 2 / (ξ_max * Ω))

            a_ref = -jnp.atleast_1d(D_f * velocity + K_f * imp * position)

            return imp, a_ref, jnp.atleast_1d(K_f), jnp.atleast_1d(D_f)

        def _compute_row(
            *,
            link_idx: jtp.Int,
            penetration: jtp.Array,
            velocity: jtp.Array,
        ) -> tuple[jtp.Array, jtp.Array, jtp.Array, jtp.Array]:
            """Compute the terms of a single collidable point."""

            # Compute the reference acceleration.
            ξ, a_ref, K, D = _imp_aref(
                penetration=penetration,
                velocity=velocity,
            )

            # Compute the regularization terms.
            # NOTE(review): `*` and `@` share precedence and associate
            # left-to-right, so this is ((coef * (1 + μ²)) @ inv(M)), i.e. a
            # 3-vector that becomes a piece of the diagonal of R.
            R = (
                (2 * μ**2 * (1 - ξ) / (ξ + 1e-12))
                * (1 + μ**2)
                @ jnp.linalg.inv(M_L[link_idx, :3, :3])
            )

            # Zero all the terms of points that are not in contact.
            return jax.tree.map(lambda x: x * (penetration < 0), (a_ref, R, K, D))

        a_ref, R, K, D = jax.tree.map(
            jnp.concatenate,
            jax.vmap(_compute_row)(
                link_idx=jnp.array(model.kin_dyn_parameters.contact_parameters.body),
                penetration=penetration,
                velocity=velocity,
            ),
        )

        return a_ref, jnp.diag(R), K, D
|
jaxsim/rbda/contacts/rigid.py
CHANGED
@@ -9,7 +9,6 @@ import jax_dataclasses
|
|
9
9
|
|
10
10
|
import jaxsim.api as js
|
11
11
|
import jaxsim.typing as jtp
|
12
|
-
from jaxsim import math
|
13
12
|
from jaxsim.api.common import ModelDataWithVelocityRepresentation, VelRepr
|
14
13
|
from jaxsim.terrain import FlatTerrain, Terrain
|
15
14
|
|
@@ -214,6 +213,7 @@ class RigidContacts(ContactModel):
|
|
214
213
|
model: js.model.JaxSimModel,
|
215
214
|
data: js.data.JaxSimModelData,
|
216
215
|
link_forces: jtp.MatrixLike | None = None,
|
216
|
+
joint_force_references: jtp.VectorLike | None = None,
|
217
217
|
regularization_term: jtp.FloatLike = 1e-6,
|
218
218
|
) -> tuple[jtp.Vector, tuple[Any, ...]]:
|
219
219
|
"""
|
@@ -227,6 +227,8 @@ class RigidContacts(ContactModel):
|
|
227
227
|
link_forces:
|
228
228
|
Optional `(n_links, 6)` matrix of external forces acting on the links,
|
229
229
|
expressed in the same representation of data.
|
230
|
+
joint_force_references:
|
231
|
+
Optional `(n_joints,)` vector of joint forces.
|
230
232
|
regularization_term:
|
231
233
|
The regularization term to add to the diagonal of the Delassus
|
232
234
|
matrix for better numerical conditioning.
|
@@ -244,6 +246,12 @@ class RigidContacts(ContactModel):
|
|
244
246
|
else jnp.zeros((model.number_of_links(), 6))
|
245
247
|
)
|
246
248
|
|
249
|
+
joint_force_references = (
|
250
|
+
joint_force_references
|
251
|
+
if joint_force_references is not None
|
252
|
+
else jnp.zeros((model.number_of_joints(),))
|
253
|
+
)
|
254
|
+
|
247
255
|
# Compute kin-dyn quantities used in the contact model
|
248
256
|
with data.switch_velocity_representation(VelRepr.Mixed):
|
249
257
|
M = js.model.free_floating_mass_matrix(model=model, data=data)
|
@@ -270,11 +278,20 @@ class RigidContacts(ContactModel):
|
|
270
278
|
data=data,
|
271
279
|
velocity_representation=data.velocity_representation,
|
272
280
|
link_forces=link_forces,
|
281
|
+
joint_force_references=joint_force_references,
|
273
282
|
)
|
274
283
|
|
275
|
-
with
|
276
|
-
|
277
|
-
|
284
|
+
with (
|
285
|
+
references.switch_velocity_representation(VelRepr.Mixed),
|
286
|
+
data.switch_velocity_representation(VelRepr.Mixed),
|
287
|
+
):
|
288
|
+
BW_ν̇_free = jnp.hstack(
|
289
|
+
js.ode.system_acceleration(
|
290
|
+
model=model,
|
291
|
+
data=data,
|
292
|
+
joint_forces=references.joint_force_references(model=model),
|
293
|
+
link_forces=references.link_forces(model=model, data=data),
|
294
|
+
)
|
278
295
|
)
|
279
296
|
|
280
297
|
free_contact_acc = RigidContacts._linear_acceleration_of_collidable_points(
|
@@ -380,43 +397,6 @@ class RigidContacts(ContactModel):
|
|
380
397
|
n_constraints = 6 * n_collidable_points
|
381
398
|
return jnp.zeros(shape=(n_constraints,))
|
382
399
|
|
383
|
-
@staticmethod
|
384
|
-
def _compute_mixed_nu_dot_free(
|
385
|
-
model: js.model.JaxSimModel,
|
386
|
-
data: js.data.JaxSimModelData,
|
387
|
-
references: js.references.JaxSimModelReferences | None = None,
|
388
|
-
) -> jtp.Array:
|
389
|
-
references = (
|
390
|
-
references
|
391
|
-
if references is not None
|
392
|
-
else js.references.JaxSimModelReferences.zero(model=model, data=data)
|
393
|
-
)
|
394
|
-
|
395
|
-
with (
|
396
|
-
data.switch_velocity_representation(VelRepr.Mixed),
|
397
|
-
references.switch_velocity_representation(VelRepr.Mixed),
|
398
|
-
):
|
399
|
-
BW_v_WB = data.base_velocity()
|
400
|
-
W_ṗ_B, W_ω_WB = jnp.split(BW_v_WB, 2)
|
401
|
-
W_v̇_WB, s̈ = js.ode.system_acceleration(
|
402
|
-
model=model,
|
403
|
-
data=data,
|
404
|
-
joint_forces=references.joint_force_references(model=model),
|
405
|
-
link_forces=references.link_forces(model=model, data=data),
|
406
|
-
)
|
407
|
-
|
408
|
-
# Convert the inertial-fixed base acceleration to a mixed base acceleration.
|
409
|
-
W_H_B = data.base_transform()
|
410
|
-
W_H_BW = W_H_B.at[0:3, 0:3].set(jnp.eye(3))
|
411
|
-
BW_X_W = math.Adjoint.from_transform(W_H_BW, inverse=True)
|
412
|
-
term1 = BW_X_W @ W_v̇_WB
|
413
|
-
term2 = jnp.zeros(6).at[0:3].set(jnp.cross(W_ṗ_B, W_ω_WB))
|
414
|
-
BW_v̇_WB = term1 - term2
|
415
|
-
|
416
|
-
BW_ν̇ = jnp.hstack([BW_v̇_WB, s̈])
|
417
|
-
|
418
|
-
return BW_ν̇
|
419
|
-
|
420
400
|
@staticmethod
|
421
401
|
def _linear_acceleration_of_collidable_points(
|
422
402
|
model: js.model.JaxSimModel,
|