jaxsim 0.4.3.dev231__py3-none-any.whl → 0.4.3.dev245__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jaxsim/_version.py +2 -2
- jaxsim/api/contact.py +48 -77
- jaxsim/api/frame.py +1 -1
- jaxsim/api/kin_dyn_parameters.py +3 -3
- jaxsim/api/model.py +87 -59
- jaxsim/api/ode.py +25 -34
- jaxsim/rbda/contacts/common.py +137 -3
- jaxsim/rbda/contacts/relaxed_rigid.py +48 -15
- jaxsim/rbda/contacts/rigid.py +26 -9
- jaxsim/rbda/contacts/soft.py +9 -5
- jaxsim/rbda/contacts/visco_elastic.py +94 -52
- {jaxsim-0.4.3.dev231.dist-info → jaxsim-0.4.3.dev245.dist-info}/METADATA +1 -1
- {jaxsim-0.4.3.dev231.dist-info → jaxsim-0.4.3.dev245.dist-info}/RECORD +16 -16
- {jaxsim-0.4.3.dev231.dist-info → jaxsim-0.4.3.dev245.dist-info}/LICENSE +0 -0
- {jaxsim-0.4.3.dev231.dist-info → jaxsim-0.4.3.dev245.dist-info}/WHEEL +0 -0
- {jaxsim-0.4.3.dev231.dist-info → jaxsim-0.4.3.dev245.dist-info}/top_level.txt +0 -0
jaxsim/rbda/contacts/common.py
CHANGED
@@ -2,7 +2,6 @@ from __future__ import annotations
|
|
2
2
|
|
3
3
|
import abc
|
4
4
|
import functools
|
5
|
-
from typing import Any
|
6
5
|
|
7
6
|
import jax
|
8
7
|
import jax.numpy as jnp
|
@@ -10,6 +9,7 @@ import jax.numpy as jnp
|
|
10
9
|
import jaxsim.api as js
|
11
10
|
import jaxsim.terrain
|
12
11
|
import jaxsim.typing as jtp
|
12
|
+
from jaxsim.api.common import ModelDataWithVelocityRepresentation
|
13
13
|
from jaxsim.utils import JaxsimDataclass
|
14
14
|
|
15
15
|
try:
|
@@ -131,7 +131,7 @@ class ContactModel(JaxsimDataclass):
|
|
131
131
|
model: js.model.JaxSimModel,
|
132
132
|
data: js.data.JaxSimModelData,
|
133
133
|
**kwargs,
|
134
|
-
) -> tuple[jtp.Matrix,
|
134
|
+
) -> tuple[jtp.Matrix, dict[str, jtp.PyTree]]:
|
135
135
|
"""
|
136
136
|
Compute the contact forces.
|
137
137
|
|
@@ -142,11 +142,145 @@ class ContactModel(JaxsimDataclass):
|
|
142
142
|
Returns:
|
143
143
|
A tuple containing as first element the computed 6D contact force applied to
|
144
144
|
the contact points and expressed in the world frame, and as second element
|
145
|
-
a
|
145
|
+
a dictionary of optional additional information.
|
146
146
|
"""
|
147
147
|
|
148
148
|
pass
|
149
149
|
|
150
|
+
def compute_link_contact_forces(
|
151
|
+
self,
|
152
|
+
model: js.model.JaxSimModel,
|
153
|
+
data: js.data.JaxSimModelData,
|
154
|
+
**kwargs,
|
155
|
+
) -> tuple[jtp.Matrix, dict[str, jtp.PyTree]]:
|
156
|
+
"""
|
157
|
+
Compute the link contact forces.
|
158
|
+
|
159
|
+
Args:
|
160
|
+
model: The robot model considered by the contact model.
|
161
|
+
data: The data of the considered model.
|
162
|
+
|
163
|
+
Returns:
|
164
|
+
A tuple containing as first element the 6D contact force applied to the
|
165
|
+
links and expressed in the frame of the velocity representation of data,
|
166
|
+
and as second element a dictionary of optional additional information.
|
167
|
+
"""
|
168
|
+
|
169
|
+
# Compute the contact forces expressed in the inertial frame.
|
170
|
+
# This function, contrarily to `compute_contact_forces`, already handles how
|
171
|
+
# the optional kwargs should be passed to the specific contact models.
|
172
|
+
W_f_C, aux_dict = js.contact.collidable_point_dynamics(
|
173
|
+
model=model, data=data, **kwargs
|
174
|
+
)
|
175
|
+
|
176
|
+
# Compute the 6D forces applied to the links equivalent to the forces applied
|
177
|
+
# to the frames associated to the collidable points.
|
178
|
+
with data.switch_velocity_representation(jaxsim.VelRepr.Inertial):
|
179
|
+
|
180
|
+
W_f_L = self.link_forces_from_contact_forces(
|
181
|
+
model=model, data=data, contact_forces=W_f_C
|
182
|
+
)
|
183
|
+
|
184
|
+
# Store the link forces in the references object for easy conversion.
|
185
|
+
references = js.references.JaxSimModelReferences.build(
|
186
|
+
model=model,
|
187
|
+
data=data,
|
188
|
+
link_forces=W_f_L,
|
189
|
+
velocity_representation=jaxsim.VelRepr.Inertial,
|
190
|
+
)
|
191
|
+
|
192
|
+
# Convert the link forces to the frame corresponding to the velocity
|
193
|
+
# representation of data.
|
194
|
+
with references.switch_velocity_representation(data.velocity_representation):
|
195
|
+
f_L = references.link_forces(model=model, data=data)
|
196
|
+
|
197
|
+
return f_L, aux_dict
|
198
|
+
|
199
|
+
@staticmethod
|
200
|
+
def link_forces_from_contact_forces(
|
201
|
+
model: js.model.JaxSimModel,
|
202
|
+
data: js.data.JaxSimModelData,
|
203
|
+
*,
|
204
|
+
contact_forces: jtp.MatrixLike,
|
205
|
+
) -> jtp.Matrix:
|
206
|
+
"""
|
207
|
+
Compute the link forces from the contact forces.
|
208
|
+
|
209
|
+
Args:
|
210
|
+
model: The robot model considered by the contact model.
|
211
|
+
data: The data of the considered model.
|
212
|
+
contact_forces: The contact forces computed by the contact model.
|
213
|
+
|
214
|
+
Returns:
|
215
|
+
The 6D contact forces applied to the links and expressed in the frame of
|
216
|
+
the velocity representation of data.
|
217
|
+
"""
|
218
|
+
|
219
|
+
# Convert the contact forces to a JAX array.
|
220
|
+
f_C = jnp.atleast_2d(jnp.array(contact_forces, dtype=float).squeeze())
|
221
|
+
|
222
|
+
# Get the pose of the enabled collidable points.
|
223
|
+
W_H_C = js.contact.transforms(model=model, data=data)
|
224
|
+
|
225
|
+
# Convert the contact forces to inertial-fixed representation.
|
226
|
+
W_f_C = jax.vmap(
|
227
|
+
lambda f_C, W_H_C: (
|
228
|
+
ModelDataWithVelocityRepresentation.other_representation_to_inertial(
|
229
|
+
array=f_C,
|
230
|
+
other_representation=data.velocity_representation,
|
231
|
+
transform=W_H_C,
|
232
|
+
is_force=True,
|
233
|
+
)
|
234
|
+
)
|
235
|
+
)(f_C, W_H_C)
|
236
|
+
|
237
|
+
# Get the object storing the contact parameters of the model.
|
238
|
+
contact_parameters = model.kin_dyn_parameters.contact_parameters
|
239
|
+
|
240
|
+
# Extract the indices corresponding to the enabled collidable points.
|
241
|
+
indices_of_enabled_collidable_points = (
|
242
|
+
contact_parameters.indices_of_enabled_collidable_points
|
243
|
+
)
|
244
|
+
|
245
|
+
# Construct the vector defining the parent link index of each collidable point.
|
246
|
+
# We use this vector to sum the 6D forces of all collidable points rigidly
|
247
|
+
# attached to the same link.
|
248
|
+
parent_link_index_of_collidable_points = jnp.array(
|
249
|
+
contact_parameters.body, dtype=int
|
250
|
+
)[indices_of_enabled_collidable_points]
|
251
|
+
|
252
|
+
# Create the mask that associate each collidable point to their parent link.
|
253
|
+
# We use this mask to sum the collidable points to the right link.
|
254
|
+
mask = parent_link_index_of_collidable_points[:, jnp.newaxis] == jnp.arange(
|
255
|
+
model.number_of_links()
|
256
|
+
)
|
257
|
+
|
258
|
+
# Sum the forces of all collidable points rigidly attached to a body.
|
259
|
+
# Since the contact forces W_f_C are expressed in the world frame,
|
260
|
+
# we don't need any coordinate transformation.
|
261
|
+
W_f_L = mask.T @ W_f_C
|
262
|
+
|
263
|
+
# Compute the link transforms.
|
264
|
+
W_H_L = (
|
265
|
+
js.model.forward_kinematics(model=model, data=data)
|
266
|
+
if data.velocity_representation is not jaxsim.VelRepr.Inertial
|
267
|
+
else jnp.zeros(shape=(model.number_of_links(), 4, 4))
|
268
|
+
)
|
269
|
+
|
270
|
+
# Convert the inertial-fixed link forces to the velocity representation of data.
|
271
|
+
f_L = jax.vmap(
|
272
|
+
lambda W_f_L, W_H_L: (
|
273
|
+
ModelDataWithVelocityRepresentation.inertial_to_other_representation(
|
274
|
+
array=W_f_L,
|
275
|
+
other_representation=data.velocity_representation,
|
276
|
+
transform=W_H_L,
|
277
|
+
is_force=True,
|
278
|
+
)
|
279
|
+
)
|
280
|
+
)(W_f_L, W_H_L)
|
281
|
+
|
282
|
+
return f_L
|
283
|
+
|
150
284
|
@classmethod
|
151
285
|
def zero_state_variables(cls, model: js.model.JaxSimModel) -> dict[str, jtp.Array]:
|
152
286
|
"""
|
@@ -120,19 +120,44 @@ class RelaxedRigidContactsParams(common.ContactsParams):
|
|
120
120
|
|
121
121
|
return cls(
|
122
122
|
time_constant=jnp.array(
|
123
|
-
|
123
|
+
(
|
124
|
+
time_constant
|
125
|
+
if time_constant is not None
|
126
|
+
else default("time_constant")
|
127
|
+
),
|
128
|
+
dtype=float,
|
124
129
|
),
|
125
130
|
damping_coefficient=jnp.array(
|
126
|
-
|
131
|
+
(
|
132
|
+
damping_coefficient
|
133
|
+
if damping_coefficient is not None
|
134
|
+
else default("damping_coefficient")
|
135
|
+
),
|
136
|
+
dtype=float,
|
137
|
+
),
|
138
|
+
d_min=jnp.array(
|
139
|
+
d_min if d_min is not None else default("d_min"), dtype=float
|
140
|
+
),
|
141
|
+
d_max=jnp.array(
|
142
|
+
d_max if d_max is not None else default("d_max"), dtype=float
|
143
|
+
),
|
144
|
+
width=jnp.array(
|
145
|
+
width if width is not None else default("width"), dtype=float
|
146
|
+
),
|
147
|
+
midpoint=jnp.array(
|
148
|
+
midpoint if midpoint is not None else default("midpoint"), dtype=float
|
127
149
|
),
|
128
|
-
|
129
|
-
|
130
|
-
|
131
|
-
|
132
|
-
|
133
|
-
|
134
|
-
|
135
|
-
|
150
|
+
power=jnp.array(
|
151
|
+
power if power is not None else default("power"), dtype=float
|
152
|
+
),
|
153
|
+
stiffness=jnp.array(
|
154
|
+
stiffness if stiffness is not None else default("stiffness"),
|
155
|
+
dtype=float,
|
156
|
+
),
|
157
|
+
damping=jnp.array(
|
158
|
+
damping if damping is not None else default("damping"), dtype=float
|
159
|
+
),
|
160
|
+
mu=jnp.array(mu if mu is not None else default("mu"), dtype=float),
|
136
161
|
)
|
137
162
|
|
138
163
|
def valid(self) -> jtp.BoolLike:
|
@@ -210,7 +235,9 @@ class RelaxedRigidContacts(common.ContactModel):
|
|
210
235
|
|
211
236
|
# Create the solver options to set by combining the default solver options
|
212
237
|
# with the user-provided solver options.
|
213
|
-
solver_options = default_solver_options | (
|
238
|
+
solver_options = default_solver_options | (
|
239
|
+
solver_options if solver_options is not None else {}
|
240
|
+
)
|
214
241
|
|
215
242
|
# Make sure that the solver options are hashable.
|
216
243
|
# We need to check this because the solver options are static.
|
@@ -223,9 +250,15 @@ class RelaxedRigidContacts(common.ContactModel):
|
|
223
250
|
|
224
251
|
return cls(
|
225
252
|
parameters=(
|
226
|
-
parameters
|
253
|
+
parameters
|
254
|
+
if parameters is not None
|
255
|
+
else cls.__dataclass_fields__["parameters"].default_factory()
|
256
|
+
),
|
257
|
+
terrain=(
|
258
|
+
terrain
|
259
|
+
if terrain is not None
|
260
|
+
else cls.__dataclass_fields__["terrain"].default_factory()
|
227
261
|
),
|
228
|
-
terrain=terrain or cls.__dataclass_fields__["terrain"].default_factory(),
|
229
262
|
_solver_options_keys=tuple(solver_options.keys()),
|
230
263
|
_solver_options_values=tuple(solver_options.values()),
|
231
264
|
)
|
@@ -238,7 +271,7 @@ class RelaxedRigidContacts(common.ContactModel):
|
|
238
271
|
*,
|
239
272
|
link_forces: jtp.MatrixLike | None = None,
|
240
273
|
joint_force_references: jtp.VectorLike | None = None,
|
241
|
-
) -> tuple[jtp.Matrix,
|
274
|
+
) -> tuple[jtp.Matrix, dict[str, jtp.PyTree]]:
|
242
275
|
"""
|
243
276
|
Compute the contact forces.
|
244
277
|
|
@@ -458,7 +491,7 @@ class RelaxedRigidContacts(common.ContactModel):
|
|
458
491
|
),
|
459
492
|
)(CW_fl_C, W_H_C)
|
460
493
|
|
461
|
-
return W_f_C,
|
494
|
+
return W_f_C, {}
|
462
495
|
|
463
496
|
@staticmethod
|
464
497
|
def _regularizers(
|
jaxsim/rbda/contacts/rigid.py
CHANGED
@@ -66,9 +66,17 @@ class RigidContactsParams(ContactsParams):
|
|
66
66
|
"""Create a `RigidContactParams` instance"""
|
67
67
|
|
68
68
|
return cls(
|
69
|
-
mu=
|
70
|
-
|
71
|
-
|
69
|
+
mu=jnp.array(
|
70
|
+
mu
|
71
|
+
if mu is not None
|
72
|
+
else cls.__dataclass_fields__["mu"].default_factory()
|
73
|
+
).astype(float),
|
74
|
+
K=jnp.array(
|
75
|
+
K if K is not None else cls.__dataclass_fields__["K"].default_factory()
|
76
|
+
).astype(float),
|
77
|
+
D=jnp.array(
|
78
|
+
D if D is not None else cls.__dataclass_fields__["D"].default_factory()
|
79
|
+
).astype(float),
|
72
80
|
)
|
73
81
|
|
74
82
|
def valid(self) -> jtp.BoolLike:
|
@@ -147,7 +155,9 @@ class RigidContacts(ContactModel):
|
|
147
155
|
|
148
156
|
# Create the solver options to set by combining the default solver options
|
149
157
|
# with the user-provided solver options.
|
150
|
-
solver_options = default_solver_options | (
|
158
|
+
solver_options = default_solver_options | (
|
159
|
+
solver_options if solver_options is not None else {}
|
160
|
+
)
|
151
161
|
|
152
162
|
# Make sure that the solver options are hashable.
|
153
163
|
# We need to check this because the solver options are static.
|
@@ -160,12 +170,19 @@ class RigidContacts(ContactModel):
|
|
160
170
|
|
161
171
|
return cls(
|
162
172
|
parameters=(
|
163
|
-
parameters
|
173
|
+
parameters
|
174
|
+
if parameters is not None
|
175
|
+
else cls.__dataclass_fields__["parameters"].default_factory()
|
176
|
+
),
|
177
|
+
terrain=(
|
178
|
+
terrain
|
179
|
+
if terrain is not None
|
180
|
+
else cls.__dataclass_fields__["terrain"].default_factory()
|
164
181
|
),
|
165
|
-
terrain=terrain or cls.__dataclass_fields__["terrain"].default_factory(),
|
166
182
|
regularization_delassus=float(
|
167
183
|
regularization_delassus
|
168
|
-
|
184
|
+
if regularization_delassus is not None
|
185
|
+
else cls.__dataclass_fields__["regularization_delassus"].default
|
169
186
|
),
|
170
187
|
_solver_options_keys=tuple(solver_options.keys()),
|
171
188
|
_solver_options_values=tuple(solver_options.values()),
|
@@ -242,7 +259,7 @@ class RigidContacts(ContactModel):
|
|
242
259
|
*,
|
243
260
|
link_forces: jtp.MatrixLike | None = None,
|
244
261
|
joint_force_references: jtp.VectorLike | None = None,
|
245
|
-
) -> tuple[jtp.Matrix,
|
262
|
+
) -> tuple[jtp.Matrix, dict[str, jtp.PyTree]]:
|
246
263
|
"""
|
247
264
|
Compute the contact forces.
|
248
265
|
|
@@ -402,7 +419,7 @@ class RigidContacts(ContactModel):
|
|
402
419
|
),
|
403
420
|
)(CW_fl_C, W_H_C)
|
404
421
|
|
405
|
-
return W_f_C,
|
422
|
+
return W_f_C, {}
|
406
423
|
|
407
424
|
@staticmethod
|
408
425
|
def _delassus_matrix(
|
jaxsim/rbda/contacts/soft.py
CHANGED
@@ -237,9 +237,13 @@ class SoftContacts(common.ContactModel):
|
|
237
237
|
else cls.__dataclass_fields__["parameters"].default_factory()
|
238
238
|
)
|
239
239
|
|
240
|
-
return
|
240
|
+
return cls(
|
241
241
|
parameters=parameters,
|
242
|
-
terrain=
|
242
|
+
terrain=(
|
243
|
+
terrain
|
244
|
+
if terrain is not None
|
245
|
+
else cls.__dataclass_fields__["terrain"].default_factory()
|
246
|
+
),
|
243
247
|
)
|
244
248
|
|
245
249
|
@classmethod
|
@@ -423,7 +427,7 @@ class SoftContacts(common.ContactModel):
|
|
423
427
|
self,
|
424
428
|
model: js.model.JaxSimModel,
|
425
429
|
data: js.data.JaxSimModelData,
|
426
|
-
) -> tuple[jtp.Matrix,
|
430
|
+
) -> tuple[jtp.Matrix, dict[str, jtp.PyTree]]:
|
427
431
|
"""
|
428
432
|
Compute the contact forces.
|
429
433
|
|
@@ -433,7 +437,7 @@ class SoftContacts(common.ContactModel):
|
|
433
437
|
|
434
438
|
Returns:
|
435
439
|
A tuple containing as first element the computed contact forces, and as
|
436
|
-
second element
|
440
|
+
second element a dictionary with derivative of the material deformation.
|
437
441
|
"""
|
438
442
|
|
439
443
|
# Initialize the model and data this contact model is operating on.
|
@@ -460,4 +464,4 @@ class SoftContacts(common.ContactModel):
|
|
460
464
|
)
|
461
465
|
)(W_p_C, W_ṗ_C, m)
|
462
466
|
|
463
|
-
return W_f, (m
|
467
|
+
return W_f, dict(m_dot=ṁ)
|
@@ -13,6 +13,7 @@ import jaxsim.api as js
|
|
13
13
|
import jaxsim.exceptions
|
14
14
|
import jaxsim.typing as jtp
|
15
15
|
from jaxsim import logging
|
16
|
+
from jaxsim.api.common import ModelDataWithVelocityRepresentation
|
16
17
|
from jaxsim.math import StandardGravity
|
17
18
|
from jaxsim.terrain import FlatTerrain, Terrain
|
18
19
|
|
@@ -235,11 +236,17 @@ class ViscoElasticContacts(common.ContactModel):
|
|
235
236
|
else cls.__dataclass_fields__["parameters"].default_factory()
|
236
237
|
)
|
237
238
|
|
238
|
-
return
|
239
|
+
return cls(
|
239
240
|
parameters=parameters,
|
240
|
-
terrain=
|
241
|
+
terrain=(
|
242
|
+
terrain
|
243
|
+
if terrain is not None
|
244
|
+
else cls.__dataclass_fields__["terrain"].default_factory()
|
245
|
+
),
|
241
246
|
max_squarings=int(
|
242
|
-
max_squarings
|
247
|
+
max_squarings
|
248
|
+
if max_squarings is not None
|
249
|
+
else cls.__dataclass_fields__["max_squarings"].default
|
243
250
|
),
|
244
251
|
)
|
245
252
|
|
@@ -266,7 +273,7 @@ class ViscoElasticContacts(common.ContactModel):
|
|
266
273
|
dt: jtp.FloatLike | None = None,
|
267
274
|
link_forces: jtp.MatrixLike | None = None,
|
268
275
|
joint_force_references: jtp.VectorLike | None = None,
|
269
|
-
) -> tuple[jtp.Matrix,
|
276
|
+
) -> tuple[jtp.Matrix, dict[str, jtp.PyTree]]:
|
270
277
|
"""
|
271
278
|
Compute the contact forces.
|
272
279
|
|
@@ -291,7 +298,7 @@ class ViscoElasticContacts(common.ContactModel):
|
|
291
298
|
Returns:
|
292
299
|
A tuple containing as first element the computed 6D contact force applied to
|
293
300
|
the contact point and expressed in the world frame, and as second element
|
294
|
-
a
|
301
|
+
a dictionary of optional additional information.
|
295
302
|
"""
|
296
303
|
|
297
304
|
# Initialize the model and data this contact model is operating on.
|
@@ -315,8 +322,8 @@ class ViscoElasticContacts(common.ContactModel):
|
|
315
322
|
model=model,
|
316
323
|
data=data,
|
317
324
|
dt=jnp.array(dt).astype(float),
|
318
|
-
joint_force_references=joint_force_references,
|
319
325
|
link_forces=link_forces,
|
326
|
+
joint_force_references=joint_force_references,
|
320
327
|
indices_of_enabled_collidable_points=indices_of_enabled_collidable_points,
|
321
328
|
max_squarings=self.max_squarings,
|
322
329
|
)
|
@@ -334,11 +341,13 @@ class ViscoElasticContacts(common.ContactModel):
|
|
334
341
|
|
335
342
|
# Vmapped transformation from mixed to inertial-fixed representation.
|
336
343
|
compute_forces_inertial_fixed_vmap = jax.vmap(
|
337
|
-
lambda CW_fl_C, W_H_C:
|
338
|
-
|
339
|
-
|
340
|
-
|
341
|
-
|
344
|
+
lambda CW_fl_C, W_H_C: (
|
345
|
+
ModelDataWithVelocityRepresentation.other_representation_to_inertial(
|
346
|
+
array=jnp.zeros(6).at[0:3].set(CW_fl_C),
|
347
|
+
other_representation=jaxsim.VelRepr.Mixed,
|
348
|
+
transform=W_H_C,
|
349
|
+
is_force=True,
|
350
|
+
)
|
342
351
|
)
|
343
352
|
)
|
344
353
|
|
@@ -347,7 +356,7 @@ class ViscoElasticContacts(common.ContactModel):
|
|
347
356
|
lambda CW_fl: compute_forces_inertial_fixed_vmap(CW_fl, W_H_C)
|
348
357
|
)(jnp.stack([CW_f̅l, CW_fl̿]))
|
349
358
|
|
350
|
-
return W_f̅_C, (W_f̿_C, m_tf)
|
359
|
+
return W_f̅_C, dict(W_f_avg2_C=W_f̿_C, m_tf=m_tf)
|
351
360
|
|
352
361
|
@staticmethod
|
353
362
|
@functools.partial(jax.jit, static_argnames=("max_squarings",))
|
@@ -407,8 +416,8 @@ class ViscoElasticContacts(common.ContactModel):
|
|
407
416
|
A, b, A_sc, b_sc = ViscoElasticContacts._contact_points_dynamics(
|
408
417
|
model=model,
|
409
418
|
data=data,
|
410
|
-
joint_force_references=joint_force_references,
|
411
419
|
link_forces=link_forces,
|
420
|
+
joint_force_references=joint_force_references,
|
412
421
|
indices_of_enabled_collidable_points=indices,
|
413
422
|
p_t0=p_t0,
|
414
423
|
v_t0=v_t0,
|
@@ -657,8 +666,8 @@ class ViscoElasticContacts(common.ContactModel):
|
|
657
666
|
BW_v̇_free_WB, s̈_free = js.ode.system_acceleration(
|
658
667
|
model=model,
|
659
668
|
data=data,
|
660
|
-
joint_force_references=references.joint_force_references(model=model),
|
661
669
|
link_forces=references.link_forces(model=model, data=data),
|
670
|
+
joint_force_references=references.joint_force_references(model=model),
|
662
671
|
)
|
663
672
|
|
664
673
|
# Pack the free system acceleration in mixed representation.
|
@@ -688,7 +697,20 @@ class ViscoElasticContacts(common.ContactModel):
|
|
688
697
|
parameters: ViscoElasticContactsParams,
|
689
698
|
terrain: Terrain,
|
690
699
|
) -> tuple[jtp.Matrix, jtp.Vector]:
|
691
|
-
"""
|
700
|
+
"""
|
701
|
+
Linearize the Hunt/Crossley contact model at the initial state.
|
702
|
+
|
703
|
+
Args:
|
704
|
+
position: The position of the contact point.
|
705
|
+
velocity: The velocity of the contact point.
|
706
|
+
tangential_deformation: The tangential deformation of the contact point.
|
707
|
+
parameters: The parameters of the contact model.
|
708
|
+
terrain: The considered terrain.
|
709
|
+
|
710
|
+
Returns:
|
711
|
+
A tuple containing the `A` matrix and the `b` vector of the linear system
|
712
|
+
corresponding to the contact dynamics linearized at the initial state.
|
713
|
+
"""
|
692
714
|
|
693
715
|
# Initialize the state at which the model is linearized.
|
694
716
|
p0 = jnp.array(position, dtype=float).squeeze()
|
@@ -969,58 +991,67 @@ def step(
|
|
969
991
|
assert isinstance(model.contact_model, ViscoElasticContacts)
|
970
992
|
assert isinstance(data.contacts_params, ViscoElasticContactsParams)
|
971
993
|
|
994
|
+
# Compute the contact forces in inertial-fixed representation.
|
995
|
+
# TODO: understand what's wrong in other representations.
|
996
|
+
data_inertial_fixed = data.replace(
|
997
|
+
velocity_representation=jaxsim.VelRepr.Inertial, validate=False
|
998
|
+
)
|
999
|
+
|
1000
|
+
# Create the references object.
|
1001
|
+
references = js.references.JaxSimModelReferences.build(
|
1002
|
+
model=model,
|
1003
|
+
data=data,
|
1004
|
+
link_forces=link_forces,
|
1005
|
+
joint_force_references=joint_force_references,
|
1006
|
+
velocity_representation=data.velocity_representation,
|
1007
|
+
)
|
1008
|
+
|
972
1009
|
# Initialize the time step.
|
973
1010
|
dt = dt if dt is not None else model.time_step
|
974
1011
|
|
975
1012
|
# Compute the contact forces with the exponential integrator.
|
976
|
-
W_f̅_C,
|
1013
|
+
W_f̅_C, aux_data = model.contact_model.compute_contact_forces(
|
977
1014
|
model=model,
|
978
|
-
data=
|
1015
|
+
data=data_inertial_fixed,
|
979
1016
|
dt=jnp.array(dt).astype(float),
|
980
|
-
link_forces=link_forces,
|
981
|
-
joint_force_references=joint_force_references,
|
1017
|
+
link_forces=references.link_forces(model=model, data=data),
|
1018
|
+
joint_force_references=references.joint_force_references(model=model),
|
982
1019
|
)
|
983
1020
|
|
1021
|
+
# Extract the final material deformation and the average of average forces
|
1022
|
+
# from the dictionary containing auxiliary data.
|
1023
|
+
m_tf = aux_data["m_tf"]
|
1024
|
+
W_f̿_C = aux_data["W_f_avg2_C"]
|
1025
|
+
|
984
1026
|
# ===============================
|
985
1027
|
# Compute the link contact forces
|
986
1028
|
# ===============================
|
987
1029
|
|
988
|
-
#
|
989
|
-
#
|
990
|
-
|
991
|
-
model.
|
992
|
-
|
1030
|
+
# Get the link contact forces by summing the forces of contact points belonging
|
1031
|
+
# to the same link.
|
1032
|
+
W_f̅_L, W_f̿_L = jax.vmap(
|
1033
|
+
lambda W_f_C: model.contact_model.link_forces_from_contact_forces(
|
1034
|
+
model=model, data=data_inertial_fixed, contact_forces=W_f_C
|
1035
|
+
)
|
1036
|
+
)(jnp.stack([W_f̅_C, W_f̿_C]))
|
993
1037
|
|
994
1038
|
# Compute the link transforms.
|
995
|
-
W_H_L =
|
996
|
-
|
997
|
-
|
998
|
-
|
999
|
-
# attached to the same link.
|
1000
|
-
parent_link_index_of_collidable_points = jnp.array(
|
1001
|
-
model.kin_dyn_parameters.contact_parameters.body, dtype=int
|
1002
|
-
)[indices_of_enabled_collidable_points]
|
1003
|
-
|
1004
|
-
# Create the mask that associate each collidable point to their parent link.
|
1005
|
-
# We use this mask to sum the collidable points to the right link.
|
1006
|
-
mask = parent_link_index_of_collidable_points[:, jnp.newaxis] == jnp.arange(
|
1007
|
-
model.number_of_links()
|
1039
|
+
W_H_L = (
|
1040
|
+
js.model.forward_kinematics(model=model, data=data)
|
1041
|
+
if data.velocity_representation is not jaxsim.VelRepr.Inertial
|
1042
|
+
else jnp.zeros(shape=(model.number_of_links(), 4, 4))
|
1008
1043
|
)
|
1009
1044
|
|
1010
|
-
#
|
1011
|
-
# Since the contact forces W_f_C are expressed in the world frame,
|
1012
|
-
# we don't need any coordinate transformation.
|
1013
|
-
W_f̅_L = mask.T @ W_f̅_C
|
1014
|
-
W_f̿_L = mask.T @ W_f̿_C
|
1015
|
-
|
1016
|
-
# For integration purpose, we need these average of averages expressed in
|
1045
|
+
# For integration purpose, we need the average of average forces expressed in
|
1017
1046
|
# mixed representation.
|
1018
1047
|
LW_f̿_L = jax.vmap(
|
1019
|
-
lambda W_f_L, W_H_L:
|
1020
|
-
|
1021
|
-
|
1022
|
-
|
1023
|
-
|
1048
|
+
lambda W_f_L, W_H_L: (
|
1049
|
+
ModelDataWithVelocityRepresentation.inertial_to_other_representation(
|
1050
|
+
array=W_f_L,
|
1051
|
+
other_representation=jaxsim.VelRepr.Mixed,
|
1052
|
+
transform=W_H_L,
|
1053
|
+
is_force=True,
|
1054
|
+
)
|
1024
1055
|
)
|
1025
1056
|
)(W_f̿_L, W_H_L)
|
1026
1057
|
|
@@ -1032,10 +1063,10 @@ def step(
|
|
1032
1063
|
data_tf: js.data.JaxSimModelData = (
|
1033
1064
|
model.contact_model.integrate_data_with_average_contact_forces(
|
1034
1065
|
model=model,
|
1035
|
-
data=
|
1066
|
+
data=data_inertial_fixed,
|
1036
1067
|
dt=dt,
|
1037
|
-
link_forces=link_forces,
|
1038
|
-
joint_force_references=joint_force_references,
|
1068
|
+
link_forces=references.link_forces(model=model, data=data),
|
1069
|
+
joint_force_references=references.joint_force_references(model=model),
|
1039
1070
|
average_link_contact_forces_inertial=W_f̅_L,
|
1040
1071
|
average_of_average_link_contact_forces_mixed=LW_f̿_L,
|
1041
1072
|
)
|
@@ -1046,10 +1077,21 @@ def step(
|
|
1046
1077
|
# be much more accurate than the one computed with the discrete soft contacts.
|
1047
1078
|
with data_tf.mutable_context():
|
1048
1079
|
|
1080
|
+
# Extract the indices corresponding to the enabled collidable points.
|
1081
|
+
# The visco-elastic contact model computed only their contact forces.
|
1082
|
+
indices_of_enabled_collidable_points = (
|
1083
|
+
model.kin_dyn_parameters.contact_parameters.indices_of_enabled_collidable_points
|
1084
|
+
)
|
1085
|
+
|
1049
1086
|
data_tf.state.extended |= {
|
1050
1087
|
"tangential_deformation": data_tf.state.extended["tangential_deformation"]
|
1051
1088
|
.at[indices_of_enabled_collidable_points]
|
1052
1089
|
.set(m_tf)
|
1053
1090
|
}
|
1054
1091
|
|
1092
|
+
# Restore the original velocity representation.
|
1093
|
+
data_tf = data_tf.replace(
|
1094
|
+
velocity_representation=data.velocity_representation, validate=False
|
1095
|
+
)
|
1096
|
+
|
1055
1097
|
return data_tf, {}
|
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.1
|
2
2
|
Name: jaxsim
|
3
|
-
Version: 0.4.3.
|
3
|
+
Version: 0.4.3.dev245
|
4
4
|
Summary: A differentiable physics engine and multibody dynamics library for control and robot learning.
|
5
5
|
Author-email: Diego Ferigo <dgferigo@gmail.com>
|
6
6
|
Maintainer-email: Diego Ferigo <dgferigo@gmail.com>, Filippo Luca Ferretti <filippo.ferretti@iit.it>
|