jaxsim 0.2.dev188__py3-none-any.whl → 0.6.1.dev2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jaxsim/__init__.py +73 -22
- jaxsim/_version.py +2 -2
- jaxsim/api/__init__.py +13 -1
- jaxsim/api/com.py +423 -0
- jaxsim/api/common.py +48 -19
- jaxsim/api/contact.py +604 -52
- jaxsim/api/data.py +308 -163
- jaxsim/api/frame.py +471 -0
- jaxsim/api/joint.py +166 -37
- jaxsim/api/kin_dyn_parameters.py +901 -0
- jaxsim/api/link.py +277 -78
- jaxsim/api/model.py +1572 -362
- jaxsim/api/ode.py +324 -129
- jaxsim/api/ode_data.py +401 -0
- jaxsim/api/references.py +216 -80
- jaxsim/exceptions.py +80 -0
- jaxsim/integrators/__init__.py +2 -2
- jaxsim/integrators/common.py +191 -107
- jaxsim/integrators/fixed_step.py +97 -102
- jaxsim/integrators/variable_step.py +706 -0
- jaxsim/logging.py +1 -2
- jaxsim/math/__init__.py +13 -0
- jaxsim/math/adjoint.py +57 -22
- jaxsim/math/cross.py +16 -7
- jaxsim/math/inertia.py +10 -8
- jaxsim/math/joint_model.py +289 -0
- jaxsim/math/quaternion.py +54 -20
- jaxsim/math/rotation.py +27 -21
- jaxsim/math/skew.py +16 -5
- jaxsim/math/transform.py +102 -0
- jaxsim/math/utils.py +31 -0
- jaxsim/mujoco/__init__.py +2 -1
- jaxsim/mujoco/loaders.py +216 -29
- jaxsim/mujoco/model.py +163 -33
- jaxsim/mujoco/utils.py +228 -0
- jaxsim/mujoco/visualizer.py +107 -22
- jaxsim/parsers/__init__.py +0 -1
- jaxsim/parsers/descriptions/__init__.py +8 -2
- jaxsim/parsers/descriptions/collision.py +87 -16
- jaxsim/parsers/descriptions/joint.py +80 -87
- jaxsim/parsers/descriptions/link.py +62 -24
- jaxsim/parsers/descriptions/model.py +101 -68
- jaxsim/parsers/kinematic_graph.py +607 -225
- jaxsim/parsers/rod/meshes.py +104 -0
- jaxsim/parsers/rod/parser.py +125 -82
- jaxsim/parsers/rod/utils.py +127 -82
- jaxsim/rbda/__init__.py +11 -0
- jaxsim/rbda/aba.py +289 -0
- jaxsim/rbda/collidable_points.py +156 -0
- jaxsim/rbda/contacts/__init__.py +13 -0
- jaxsim/rbda/contacts/common.py +313 -0
- jaxsim/rbda/contacts/relaxed_rigid.py +605 -0
- jaxsim/rbda/contacts/rigid.py +462 -0
- jaxsim/rbda/contacts/soft.py +480 -0
- jaxsim/rbda/contacts/visco_elastic.py +1066 -0
- jaxsim/rbda/crba.py +167 -0
- jaxsim/rbda/forward_kinematics.py +117 -0
- jaxsim/rbda/jacobian.py +330 -0
- jaxsim/rbda/rnea.py +235 -0
- jaxsim/rbda/utils.py +160 -0
- jaxsim/terrain/__init__.py +2 -0
- jaxsim/terrain/terrain.py +238 -0
- jaxsim/typing.py +24 -24
- jaxsim/utils/__init__.py +1 -4
- jaxsim/utils/jaxsim_dataclass.py +289 -34
- jaxsim/utils/tracing.py +5 -11
- jaxsim/utils/wrappers.py +159 -0
- {jaxsim-0.2.dev188.dist-info → jaxsim-0.6.1.dev2.dist-info}/LICENSE +1 -1
- jaxsim-0.6.1.dev2.dist-info/METADATA +465 -0
- jaxsim-0.6.1.dev2.dist-info/RECORD +74 -0
- {jaxsim-0.2.dev188.dist-info → jaxsim-0.6.1.dev2.dist-info}/WHEEL +1 -1
- jaxsim/high_level/__init__.py +0 -2
- jaxsim/high_level/common.py +0 -11
- jaxsim/high_level/joint.py +0 -148
- jaxsim/high_level/link.py +0 -259
- jaxsim/high_level/model.py +0 -1686
- jaxsim/math/conv.py +0 -114
- jaxsim/math/joint.py +0 -102
- jaxsim/math/plucker.py +0 -100
- jaxsim/physics/__init__.py +0 -12
- jaxsim/physics/algos/__init__.py +0 -0
- jaxsim/physics/algos/aba.py +0 -254
- jaxsim/physics/algos/aba_motors.py +0 -284
- jaxsim/physics/algos/crba.py +0 -154
- jaxsim/physics/algos/forward_kinematics.py +0 -79
- jaxsim/physics/algos/jacobian.py +0 -98
- jaxsim/physics/algos/rnea.py +0 -180
- jaxsim/physics/algos/rnea_motors.py +0 -196
- jaxsim/physics/algos/soft_contacts.py +0 -523
- jaxsim/physics/algos/terrain.py +0 -80
- jaxsim/physics/algos/utils.py +0 -69
- jaxsim/physics/model/__init__.py +0 -0
- jaxsim/physics/model/ground_contact.py +0 -55
- jaxsim/physics/model/physics_model.py +0 -388
- jaxsim/physics/model/physics_model_state.py +0 -283
- jaxsim/simulation/__init__.py +0 -4
- jaxsim/simulation/integrators.py +0 -393
- jaxsim/simulation/ode.py +0 -290
- jaxsim/simulation/ode_data.py +0 -96
- jaxsim/simulation/ode_integration.py +0 -62
- jaxsim/simulation/simulator.py +0 -543
- jaxsim/simulation/simulator_callbacks.py +0 -79
- jaxsim/simulation/utils.py +0 -15
- jaxsim/sixd/__init__.py +0 -2
- jaxsim/utils/oop.py +0 -536
- jaxsim/utils/vmappable.py +0 -117
- jaxsim-0.2.dev188.dist-info/METADATA +0 -184
- jaxsim-0.2.dev188.dist-info/RECORD +0 -81
- {jaxsim-0.2.dev188.dist-info → jaxsim-0.6.1.dev2.dist-info}/top_level.txt +0 -0
jaxsim/api/contact.py
CHANGED
@@ -1,18 +1,25 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
1
3
|
import functools
|
2
4
|
|
3
5
|
import jax
|
4
6
|
import jax.numpy as jnp
|
5
7
|
|
8
|
+
import jaxsim.api as js
|
9
|
+
import jaxsim.exceptions
|
10
|
+
import jaxsim.terrain
|
6
11
|
import jaxsim.typing as jtp
|
7
|
-
from jaxsim
|
12
|
+
from jaxsim import logging
|
13
|
+
from jaxsim.math import Adjoint, Cross, Transform
|
14
|
+
from jaxsim.rbda import contacts
|
8
15
|
|
9
|
-
from . import
|
10
|
-
from . import model as Model
|
16
|
+
from .common import VelRepr
|
11
17
|
|
12
18
|
|
13
19
|
@jax.jit
|
20
|
+
@js.common.named_scope
|
14
21
|
def collidable_point_kinematics(
|
15
|
-
model:
|
22
|
+
model: js.model.JaxSimModel, data: js.data.JaxSimModelData
|
16
23
|
) -> tuple[jtp.Matrix, jtp.Matrix]:
|
17
24
|
"""
|
18
25
|
Compute the position and 3D velocity of the collidable points in the world frame.
|
@@ -30,21 +37,26 @@ def collidable_point_kinematics(
|
|
30
37
|
the linear component of the mixed 6D frame velocity.
|
31
38
|
"""
|
32
39
|
|
33
|
-
|
34
|
-
|
35
|
-
|
36
|
-
|
37
|
-
|
38
|
-
|
39
|
-
|
40
|
-
|
40
|
+
# Switch to inertial-fixed since the RBDAs expect velocities in this representation.
|
41
|
+
with data.switch_velocity_representation(VelRepr.Inertial):
|
42
|
+
|
43
|
+
W_p_Ci, W_ṗ_Ci = jaxsim.rbda.collidable_points.collidable_points_pos_vel(
|
44
|
+
model=model,
|
45
|
+
base_position=data.base_position(),
|
46
|
+
base_quaternion=data.base_orientation(dcm=False),
|
47
|
+
joint_positions=data.joint_positions(model=model),
|
48
|
+
base_linear_velocity=data.base_velocity()[0:3],
|
49
|
+
base_angular_velocity=data.base_velocity()[3:6],
|
50
|
+
joint_velocities=data.joint_velocities(model=model),
|
51
|
+
)
|
41
52
|
|
42
|
-
return W_p_Ci
|
53
|
+
return W_p_Ci, W_ṗ_Ci
|
43
54
|
|
44
55
|
|
45
56
|
@jax.jit
|
57
|
+
@js.common.named_scope
|
46
58
|
def collidable_point_positions(
|
47
|
-
model:
|
59
|
+
model: js.model.JaxSimModel, data: js.data.JaxSimModelData
|
48
60
|
) -> jtp.Matrix:
|
49
61
|
"""
|
50
62
|
Compute the position of the collidable points in the world frame.
|
@@ -57,12 +69,15 @@ def collidable_point_positions(
|
|
57
69
|
The position of the collidable points in the world frame.
|
58
70
|
"""
|
59
71
|
|
60
|
-
|
72
|
+
W_p_Ci, _ = collidable_point_kinematics(model=model, data=data)
|
73
|
+
|
74
|
+
return W_p_Ci
|
61
75
|
|
62
76
|
|
63
77
|
@jax.jit
|
78
|
+
@js.common.named_scope
|
64
79
|
def collidable_point_velocities(
|
65
|
-
model:
|
80
|
+
model: js.model.JaxSimModel, data: js.data.JaxSimModelData
|
66
81
|
) -> jtp.Matrix:
|
67
82
|
"""
|
68
83
|
Compute the 3D velocity of the collidable points in the world frame.
|
@@ -75,13 +90,153 @@ def collidable_point_velocities(
|
|
75
90
|
The 3D velocity of the collidable points.
|
76
91
|
"""
|
77
92
|
|
78
|
-
|
93
|
+
_, W_ṗ_Ci = collidable_point_kinematics(model=model, data=data)
|
94
|
+
|
95
|
+
return W_ṗ_Ci
|
96
|
+
|
97
|
+
|
98
|
+
@jax.jit
|
99
|
+
@js.common.named_scope
|
100
|
+
def collidable_point_forces(
|
101
|
+
model: js.model.JaxSimModel,
|
102
|
+
data: js.data.JaxSimModelData,
|
103
|
+
link_forces: jtp.MatrixLike | None = None,
|
104
|
+
joint_force_references: jtp.VectorLike | None = None,
|
105
|
+
**kwargs,
|
106
|
+
) -> jtp.Matrix:
|
107
|
+
"""
|
108
|
+
Compute the 6D forces applied to each collidable point.
|
109
|
+
|
110
|
+
Args:
|
111
|
+
model: The model to consider.
|
112
|
+
data: The data of the considered model.
|
113
|
+
link_forces:
|
114
|
+
The 6D external forces to apply to the links expressed in the same
|
115
|
+
representation as data.
|
116
|
+
joint_force_references:
|
117
|
+
The joint force references to apply to the joints.
|
118
|
+
kwargs: Additional keyword arguments to pass to the active contact model.
|
119
|
+
|
120
|
+
Returns:
|
121
|
+
The 6D forces applied to each collidable point expressed in the frame
|
122
|
+
corresponding to the active representation.
|
123
|
+
"""
|
124
|
+
|
125
|
+
f_Ci, _ = collidable_point_dynamics(
|
126
|
+
model=model,
|
127
|
+
data=data,
|
128
|
+
link_forces=link_forces,
|
129
|
+
joint_force_references=joint_force_references,
|
130
|
+
**kwargs,
|
131
|
+
)
|
132
|
+
|
133
|
+
return f_Ci
|
134
|
+
|
135
|
+
|
136
|
+
@jax.jit
|
137
|
+
@js.common.named_scope
|
138
|
+
def collidable_point_dynamics(
|
139
|
+
model: js.model.JaxSimModel,
|
140
|
+
data: js.data.JaxSimModelData,
|
141
|
+
link_forces: jtp.MatrixLike | None = None,
|
142
|
+
joint_force_references: jtp.VectorLike | None = None,
|
143
|
+
**kwargs,
|
144
|
+
) -> tuple[jtp.Matrix, dict[str, jtp.PyTree]]:
|
145
|
+
r"""
|
146
|
+
Compute the 6D force applied to each enabled collidable point.
|
147
|
+
|
148
|
+
Args:
|
149
|
+
model: The model to consider.
|
150
|
+
data: The data of the considered model.
|
151
|
+
link_forces:
|
152
|
+
The 6D external forces to apply to the links expressed in the same
|
153
|
+
representation as data.
|
154
|
+
joint_force_references:
|
155
|
+
The joint force references to apply to the joints.
|
156
|
+
kwargs: Additional keyword arguments to pass to the active contact model.
|
157
|
+
|
158
|
+
Returns:
|
159
|
+
The 6D force applied to each enabled collidable point and additional data based
|
160
|
+
on the contact model configured:
|
161
|
+
- Soft: the material deformation rate.
|
162
|
+
- Rigid: no additional data.
|
163
|
+
- RelaxedRigid: no additional data.
|
164
|
+
|
165
|
+
Note:
|
166
|
+
The material deformation rate is always returned in the mixed frame
|
167
|
+
`C[W] = ({}^W \mathbf{p}_C, [W])`. This is convenient for integration purposes.
|
168
|
+
Instead, the 6D forces are returned in the active representation.
|
169
|
+
"""
|
170
|
+
|
171
|
+
# Build the common kw arguments to pass to the computation of the contact forces.
|
172
|
+
common_kwargs = dict(
|
173
|
+
link_forces=link_forces,
|
174
|
+
joint_force_references=joint_force_references,
|
175
|
+
)
|
176
|
+
|
177
|
+
# Build the additional kwargs to pass to the computation of the contact forces.
|
178
|
+
match model.contact_model:
|
179
|
+
|
180
|
+
case contacts.SoftContacts():
|
181
|
+
|
182
|
+
kwargs_contact_model = {}
|
183
|
+
|
184
|
+
case contacts.RigidContacts():
|
185
|
+
|
186
|
+
kwargs_contact_model = common_kwargs | kwargs
|
187
|
+
|
188
|
+
case contacts.RelaxedRigidContacts():
|
189
|
+
|
190
|
+
kwargs_contact_model = common_kwargs | kwargs
|
191
|
+
|
192
|
+
case contacts.ViscoElasticContacts():
|
193
|
+
|
194
|
+
kwargs_contact_model = common_kwargs | dict(dt=model.time_step) | kwargs
|
195
|
+
|
196
|
+
case _:
|
197
|
+
raise ValueError(f"Invalid contact model: {model.contact_model}")
|
198
|
+
|
199
|
+
# Compute the contact forces with the active contact model.
|
200
|
+
W_f_C, aux_data = model.contact_model.compute_contact_forces(
|
201
|
+
model=model,
|
202
|
+
data=data,
|
203
|
+
**kwargs_contact_model,
|
204
|
+
)
|
205
|
+
|
206
|
+
# Compute the transforms of the implicit frames `C[L] = (W_p_C, [L])`
|
207
|
+
# associated with the enabled collidable points.
|
208
|
+
# In inertial-fixed representation, the computation of these transforms
|
209
|
+
# is not necessary and the conversion below becomes a no-op.
|
210
|
+
|
211
|
+
# Get the indices of the enabled collidable points.
|
212
|
+
indices_of_enabled_collidable_points = (
|
213
|
+
model.kin_dyn_parameters.contact_parameters.indices_of_enabled_collidable_points
|
214
|
+
)
|
215
|
+
|
216
|
+
W_H_C = (
|
217
|
+
js.contact.transforms(model=model, data=data)
|
218
|
+
if data.velocity_representation is not VelRepr.Inertial
|
219
|
+
else jnp.stack([jnp.eye(4)] * len(indices_of_enabled_collidable_points))
|
220
|
+
)
|
221
|
+
|
222
|
+
# Convert the 6D forces to the active representation.
|
223
|
+
f_Ci = jax.vmap(
|
224
|
+
lambda W_f_C, W_H_C: data.inertial_to_other_representation(
|
225
|
+
array=W_f_C,
|
226
|
+
other_representation=data.velocity_representation,
|
227
|
+
transform=W_H_C,
|
228
|
+
is_force=True,
|
229
|
+
)
|
230
|
+
)(W_f_C, W_H_C)
|
231
|
+
|
232
|
+
return f_Ci, aux_data
|
79
233
|
|
80
234
|
|
81
235
|
@functools.partial(jax.jit, static_argnames=["link_names"])
|
236
|
+
@js.common.named_scope
|
82
237
|
def in_contact(
|
83
|
-
model:
|
84
|
-
data:
|
238
|
+
model: js.model.JaxSimModel,
|
239
|
+
data: js.data.JaxSimModelData,
|
85
240
|
*,
|
86
241
|
link_names: tuple[str, ...] | None = None,
|
87
242
|
) -> jtp.Vector:
|
@@ -98,50 +253,71 @@ def in_contact(
|
|
98
253
|
A boolean vector indicating whether the links are in contact with the terrain.
|
99
254
|
"""
|
100
255
|
|
101
|
-
|
102
|
-
|
103
|
-
if set(link_names) - set(model.link_names()) != set():
|
256
|
+
if link_names is not None and set(link_names).difference(model.link_names()):
|
104
257
|
raise ValueError("One or more link names are not part of the model")
|
105
258
|
|
106
|
-
|
107
|
-
|
108
|
-
|
109
|
-
model=model.physics_model,
|
110
|
-
q=data.state.physics_model.joint_positions,
|
111
|
-
qd=data.state.physics_model.joint_velocities,
|
112
|
-
xfb=data.state.physics_model.xfb(),
|
259
|
+
# Get the indices of the enabled collidable points.
|
260
|
+
indices_of_enabled_collidable_points = (
|
261
|
+
model.kin_dyn_parameters.contact_parameters.indices_of_enabled_collidable_points
|
113
262
|
)
|
114
263
|
|
264
|
+
parent_link_idx_of_enabled_collidable_points = jnp.array(
|
265
|
+
model.kin_dyn_parameters.contact_parameters.body, dtype=int
|
266
|
+
)[indices_of_enabled_collidable_points]
|
267
|
+
|
268
|
+
W_p_Ci = collidable_point_positions(model=model, data=data)
|
269
|
+
|
115
270
|
terrain_height = jax.vmap(lambda x, y: model.terrain.height(x=x, y=y))(
|
116
|
-
W_p_Ci[0
|
271
|
+
W_p_Ci[:, 0], W_p_Ci[:, 1]
|
117
272
|
)
|
118
273
|
|
119
|
-
below_terrain = W_p_Ci[2
|
274
|
+
below_terrain = W_p_Ci[:, 2] <= terrain_height
|
275
|
+
|
276
|
+
link_idxs = (
|
277
|
+
js.link.names_to_idxs(link_names=link_names, model=model)
|
278
|
+
if link_names is not None
|
279
|
+
else jnp.arange(model.number_of_links())
|
280
|
+
)
|
120
281
|
|
121
282
|
links_in_contact = jax.vmap(
|
122
283
|
lambda link_index: jnp.where(
|
123
|
-
|
284
|
+
parent_link_idx_of_enabled_collidable_points == link_index,
|
124
285
|
below_terrain,
|
125
286
|
jnp.zeros_like(below_terrain, dtype=bool),
|
126
287
|
).any()
|
127
|
-
)(
|
288
|
+
)(link_idxs)
|
128
289
|
|
129
290
|
return links_in_contact
|
130
291
|
|
131
292
|
|
132
|
-
@jax.jit
|
133
293
|
def estimate_good_soft_contacts_parameters(
|
134
|
-
|
294
|
+
*args, **kwargs
|
295
|
+
) -> jaxsim.rbda.contacts.ContactParamsTypes:
|
296
|
+
"""
|
297
|
+
Estimate good soft contacts parameters. Deprecated, use `estimate_good_contact_parameters` instead.
|
298
|
+
"""
|
299
|
+
|
300
|
+
msg = "This method is deprecated, please use `{}`."
|
301
|
+
logging.warning(msg.format(estimate_good_contact_parameters.__name__))
|
302
|
+
return estimate_good_contact_parameters(*args, **kwargs)
|
303
|
+
|
304
|
+
|
305
|
+
def estimate_good_contact_parameters(
|
306
|
+
model: js.model.JaxSimModel,
|
307
|
+
*,
|
308
|
+
standard_gravity: jtp.FloatLike = jaxsim.math.StandardGravity,
|
135
309
|
static_friction_coefficient: jtp.FloatLike = 0.5,
|
136
310
|
number_of_active_collidable_points_steady_state: jtp.IntLike = 1,
|
137
311
|
damping_ratio: jtp.FloatLike = 1.0,
|
138
312
|
max_penetration: jtp.FloatLike | None = None,
|
139
|
-
|
313
|
+
**kwargs,
|
314
|
+
) -> jaxsim.rbda.contacts.ContactParamsTypes:
|
140
315
|
"""
|
141
|
-
Estimate good
|
316
|
+
Estimate good contact parameters.
|
142
317
|
|
143
318
|
Args:
|
144
319
|
model: The model to consider.
|
320
|
+
standard_gravity: The standard gravity constant.
|
145
321
|
static_friction_coefficient: The static friction coefficient.
|
146
322
|
number_of_active_collidable_points_steady_state:
|
147
323
|
The number of active collidable points in steady state supporting
|
@@ -150,26 +326,37 @@ def estimate_good_soft_contacts_parameters(
|
|
150
326
|
max_penetration:
|
151
327
|
The maximum penetration allowed in steady state when the robot is
|
152
328
|
supported by the configured number of active collidable points.
|
329
|
+
kwargs:
|
330
|
+
Additional model-specific parameters passed to the builder method of
|
331
|
+
the parameters class.
|
153
332
|
|
154
333
|
Returns:
|
155
|
-
The estimated good
|
334
|
+
The estimated good contacts parameters.
|
335
|
+
|
336
|
+
Note:
|
337
|
+
This is primarily a convenience function for soft-like contact models.
|
338
|
+
However, it also provides good default parameters for the other ones.
|
156
339
|
|
157
340
|
Note:
|
158
|
-
This method provides a good
|
341
|
+
This method provides a good set of contacts parameters.
|
159
342
|
The user is encouraged to fine-tune the parameters based on the
|
160
343
|
specific application.
|
161
344
|
"""
|
162
345
|
|
163
|
-
def estimate_model_height(model:
|
164
|
-
"""
|
346
|
+
def estimate_model_height(model: js.model.JaxSimModel) -> jtp.Float:
|
347
|
+
"""
|
348
|
+
Displacement between the CoM and the lowest collidable point using zero
|
349
|
+
joint positions.
|
350
|
+
"""
|
165
351
|
|
166
|
-
zero_data =
|
167
|
-
model=model,
|
352
|
+
zero_data = js.data.JaxSimModelData.build(
|
353
|
+
model=model,
|
354
|
+
contacts_params=jaxsim.rbda.contacts.SoftContactsParams(),
|
168
355
|
)
|
169
356
|
|
170
|
-
W_pz_CoM =
|
357
|
+
W_pz_CoM = js.com.com_position(model=model, data=zero_data)[2]
|
171
358
|
|
172
|
-
if model.
|
359
|
+
if model.floating_base():
|
173
360
|
W_pz_C = collidable_point_positions(model=model, data=zero_data)[:, -1]
|
174
361
|
return 2 * (W_pz_CoM - W_pz_C.min())
|
175
362
|
|
@@ -178,17 +365,382 @@ def estimate_good_soft_contacts_parameters(
|
|
178
365
|
max_δ = (
|
179
366
|
max_penetration
|
180
367
|
if max_penetration is not None
|
368
|
+
# Consider as default a 0.5% of the model height.
|
181
369
|
else 0.005 * estimate_model_height(model=model)
|
182
370
|
)
|
183
371
|
|
184
372
|
nc = number_of_active_collidable_points_steady_state
|
185
373
|
|
186
|
-
|
187
|
-
|
188
|
-
|
189
|
-
|
190
|
-
|
191
|
-
|
374
|
+
match model.contact_model:
|
375
|
+
|
376
|
+
case contacts.SoftContacts():
|
377
|
+
assert isinstance(model.contact_model, contacts.SoftContacts)
|
378
|
+
|
379
|
+
parameters = contacts.SoftContactsParams.build_default_from_jaxsim_model(
|
380
|
+
model=model,
|
381
|
+
standard_gravity=standard_gravity,
|
382
|
+
static_friction_coefficient=static_friction_coefficient,
|
383
|
+
max_penetration=max_δ,
|
384
|
+
number_of_active_collidable_points_steady_state=nc,
|
385
|
+
damping_ratio=damping_ratio,
|
386
|
+
**kwargs,
|
387
|
+
)
|
388
|
+
|
389
|
+
case contacts.ViscoElasticContacts():
|
390
|
+
assert isinstance(model.contact_model, contacts.ViscoElasticContacts)
|
391
|
+
|
392
|
+
parameters = (
|
393
|
+
contacts.ViscoElasticContactsParams.build_default_from_jaxsim_model(
|
394
|
+
model=model,
|
395
|
+
standard_gravity=standard_gravity,
|
396
|
+
static_friction_coefficient=static_friction_coefficient,
|
397
|
+
max_penetration=max_δ,
|
398
|
+
number_of_active_collidable_points_steady_state=nc,
|
399
|
+
damping_ratio=damping_ratio,
|
400
|
+
**kwargs,
|
401
|
+
)
|
402
|
+
)
|
403
|
+
|
404
|
+
case contacts.RigidContacts():
|
405
|
+
assert isinstance(model.contact_model, contacts.RigidContacts)
|
406
|
+
|
407
|
+
# Disable Baumgarte stabilization by default since it does not play
|
408
|
+
# well with the forward Euler integrator.
|
409
|
+
K = kwargs.get("K", 0.0)
|
410
|
+
|
411
|
+
parameters = contacts.RigidContactsParams.build(
|
412
|
+
mu=static_friction_coefficient,
|
413
|
+
**(
|
414
|
+
dict(
|
415
|
+
K=K,
|
416
|
+
D=2 * jnp.sqrt(K),
|
417
|
+
)
|
418
|
+
| kwargs
|
419
|
+
),
|
420
|
+
)
|
421
|
+
|
422
|
+
case contacts.RelaxedRigidContacts():
|
423
|
+
assert isinstance(model.contact_model, contacts.RelaxedRigidContacts)
|
424
|
+
|
425
|
+
parameters = contacts.RelaxedRigidContactsParams.build(
|
426
|
+
mu=static_friction_coefficient,
|
427
|
+
**kwargs,
|
428
|
+
)
|
429
|
+
|
430
|
+
case _:
|
431
|
+
raise ValueError(f"Invalid contact model: {model.contact_model}")
|
432
|
+
|
433
|
+
return parameters
|
434
|
+
|
435
|
+
|
436
|
+
@jax.jit
|
437
|
+
@js.common.named_scope
|
438
|
+
def transforms(model: js.model.JaxSimModel, data: js.data.JaxSimModelData) -> jtp.Array:
|
439
|
+
r"""
|
440
|
+
Return the pose of the enabled collidable points.
|
441
|
+
|
442
|
+
Args:
|
443
|
+
model: The model to consider.
|
444
|
+
data: The data of the considered model.
|
445
|
+
|
446
|
+
Returns:
|
447
|
+
The stacked SE(3) matrices of all enabled collidable points.
|
448
|
+
|
449
|
+
Note:
|
450
|
+
Each collidable point is implicitly associated with a frame
|
451
|
+
:math:`C = ({}^W p_C, [L])`, where :math:`{}^W p_C` is the position of the
|
452
|
+
collidable point and :math:`[L]` is the orientation frame of the link it is
|
453
|
+
rigidly attached to.
|
454
|
+
"""
|
455
|
+
|
456
|
+
# Get the indices of the enabled collidable points.
|
457
|
+
indices_of_enabled_collidable_points = (
|
458
|
+
model.kin_dyn_parameters.contact_parameters.indices_of_enabled_collidable_points
|
459
|
+
)
|
460
|
+
|
461
|
+
parent_link_idx_of_enabled_collidable_points = jnp.array(
|
462
|
+
model.kin_dyn_parameters.contact_parameters.body, dtype=int
|
463
|
+
)[indices_of_enabled_collidable_points]
|
464
|
+
|
465
|
+
# Get the transforms of the parent link of all collidable points.
|
466
|
+
W_H_L = js.model.forward_kinematics(model=model, data=data)[
|
467
|
+
parent_link_idx_of_enabled_collidable_points
|
468
|
+
]
|
469
|
+
|
470
|
+
L_p_Ci = model.kin_dyn_parameters.contact_parameters.point[
|
471
|
+
indices_of_enabled_collidable_points
|
472
|
+
]
|
473
|
+
|
474
|
+
# Build the link-to-point transform from the displacement between the link frame L
|
475
|
+
# and the implicit contact frame C.
|
476
|
+
L_H_C = jax.vmap(lambda L_p_C: jnp.eye(4).at[0:3, 3].set(L_p_C))(L_p_Ci)
|
477
|
+
|
478
|
+
# Compose the world-to-link and link-to-point transforms.
|
479
|
+
return jax.vmap(lambda W_H_Li, L_H_Ci: W_H_Li @ L_H_Ci)(W_H_L, L_H_C)
|
480
|
+
|
481
|
+
|
482
|
+
@functools.partial(jax.jit, static_argnames=["output_vel_repr"])
|
483
|
+
@js.common.named_scope
|
484
|
+
def jacobian(
|
485
|
+
model: js.model.JaxSimModel,
|
486
|
+
data: js.data.JaxSimModelData,
|
487
|
+
*,
|
488
|
+
output_vel_repr: VelRepr | None = None,
|
489
|
+
) -> jtp.Array:
|
490
|
+
r"""
|
491
|
+
Return the free-floating Jacobian of the enabled collidable points.
|
492
|
+
|
493
|
+
Args:
|
494
|
+
model: The model to consider.
|
495
|
+
data: The data of the considered model.
|
496
|
+
output_vel_repr:
|
497
|
+
The output velocity representation of the free-floating jacobian.
|
498
|
+
|
499
|
+
Returns:
|
500
|
+
The stacked :math:`6 \times (6+n)` free-floating jacobians of the frames associated to the
|
501
|
+
enabled collidable points.
|
502
|
+
|
503
|
+
Note:
|
504
|
+
Each collidable point is implicitly associated with a frame
|
505
|
+
:math:`C = ({}^W p_C, [L])`, where :math:`{}^W p_C` is the position of the
|
506
|
+
collidable point and :math:`[L]` is the orientation frame of the link it is
|
507
|
+
rigidly attached to.
|
508
|
+
"""
|
509
|
+
|
510
|
+
output_vel_repr = (
|
511
|
+
output_vel_repr if output_vel_repr is not None else data.velocity_representation
|
512
|
+
)
|
513
|
+
|
514
|
+
# Get the indices of the enabled collidable points.
|
515
|
+
indices_of_enabled_collidable_points = (
|
516
|
+
model.kin_dyn_parameters.contact_parameters.indices_of_enabled_collidable_points
|
517
|
+
)
|
518
|
+
|
519
|
+
parent_link_idx_of_enabled_collidable_points = jnp.array(
|
520
|
+
model.kin_dyn_parameters.contact_parameters.body, dtype=int
|
521
|
+
)[indices_of_enabled_collidable_points]
|
522
|
+
|
523
|
+
# Compute the Jacobians of all links.
|
524
|
+
W_J_WL = js.model.generalized_free_floating_jacobian(
|
525
|
+
model=model, data=data, output_vel_repr=VelRepr.Inertial
|
526
|
+
)
|
527
|
+
|
528
|
+
# Compute the contact Jacobian.
|
529
|
+
# In inertial-fixed output representation, the Jacobian of the parent link is also
|
530
|
+
# the Jacobian of the frame C implicitly associated with the collidable point.
|
531
|
+
W_J_WC = W_J_WL[parent_link_idx_of_enabled_collidable_points]
|
532
|
+
|
533
|
+
# Adjust the output representation.
|
534
|
+
match output_vel_repr:
|
535
|
+
|
536
|
+
case VelRepr.Inertial:
|
537
|
+
O_J_WC = W_J_WC
|
538
|
+
|
539
|
+
case VelRepr.Body:
|
540
|
+
|
541
|
+
W_H_C = transforms(model=model, data=data)
|
542
|
+
|
543
|
+
def body_jacobian(W_H_C: jtp.Matrix, W_J_WC: jtp.Matrix) -> jtp.Matrix:
|
544
|
+
C_X_W = jaxsim.math.Adjoint.from_transform(
|
545
|
+
transform=W_H_C, inverse=True
|
546
|
+
)
|
547
|
+
C_J_WC = C_X_W @ W_J_WC
|
548
|
+
return C_J_WC
|
549
|
+
|
550
|
+
O_J_WC = jax.vmap(body_jacobian)(W_H_C, W_J_WC)
|
551
|
+
|
552
|
+
case VelRepr.Mixed:
|
553
|
+
|
554
|
+
W_H_C = transforms(model=model, data=data)
|
555
|
+
|
556
|
+
def mixed_jacobian(W_H_C: jtp.Matrix, W_J_WC: jtp.Matrix) -> jtp.Matrix:
|
557
|
+
|
558
|
+
W_H_CW = W_H_C.at[0:3, 0:3].set(jnp.eye(3))
|
559
|
+
|
560
|
+
CW_X_W = jaxsim.math.Adjoint.from_transform(
|
561
|
+
transform=W_H_CW, inverse=True
|
562
|
+
)
|
563
|
+
|
564
|
+
CW_J_WC = CW_X_W @ W_J_WC
|
565
|
+
return CW_J_WC
|
566
|
+
|
567
|
+
O_J_WC = jax.vmap(mixed_jacobian)(W_H_C, W_J_WC)
|
568
|
+
|
569
|
+
case _:
|
570
|
+
raise ValueError(output_vel_repr)
|
571
|
+
|
572
|
+
return O_J_WC
|
573
|
+
|
574
|
+
|
575
|
+
@functools.partial(jax.jit, static_argnames=["output_vel_repr"])
|
576
|
+
@js.common.named_scope
|
577
|
+
def jacobian_derivative(
|
578
|
+
model: js.model.JaxSimModel,
|
579
|
+
data: js.data.JaxSimModelData,
|
580
|
+
*,
|
581
|
+
output_vel_repr: VelRepr | None = None,
|
582
|
+
) -> jtp.Matrix:
|
583
|
+
r"""
|
584
|
+
Compute the derivative of the free-floating jacobian of the enabled collidable points.
|
585
|
+
|
586
|
+
Args:
|
587
|
+
model: The model to consider.
|
588
|
+
data: The data of the considered model.
|
589
|
+
output_vel_repr:
|
590
|
+
The output velocity representation of the free-floating jacobian derivative.
|
591
|
+
|
592
|
+
Returns:
|
593
|
+
The derivative of the :math:`6 \times (6+n)` free-floating jacobian of the enabled collidable points.
|
594
|
+
|
595
|
+
Note:
|
596
|
+
The input representation of the free-floating jacobian derivative is the active
|
597
|
+
velocity representation.
|
598
|
+
"""
|
599
|
+
|
600
|
+
output_vel_repr = (
|
601
|
+
output_vel_repr if output_vel_repr is not None else data.velocity_representation
|
602
|
+
)
|
603
|
+
|
604
|
+
indices_of_enabled_collidable_points = (
|
605
|
+
model.kin_dyn_parameters.contact_parameters.indices_of_enabled_collidable_points
|
606
|
+
)
|
607
|
+
|
608
|
+
# Get the index of the parent link and the position of the collidable point.
|
609
|
+
parent_link_idx_of_enabled_collidable_points = jnp.array(
|
610
|
+
model.kin_dyn_parameters.contact_parameters.body, dtype=int
|
611
|
+
)[indices_of_enabled_collidable_points]
|
612
|
+
|
613
|
+
L_p_Ci = model.kin_dyn_parameters.contact_parameters.point[
|
614
|
+
indices_of_enabled_collidable_points
|
615
|
+
]
|
616
|
+
|
617
|
+
# Get the transforms of all the parent links.
|
618
|
+
W_H_Li = js.model.forward_kinematics(model=model, data=data)
|
619
|
+
|
620
|
+
# =====================================================
|
621
|
+
# Compute quantities to adjust the input representation
|
622
|
+
# =====================================================
|
623
|
+
|
624
|
+
def compute_T(model: js.model.JaxSimModel, X: jtp.Matrix) -> jtp.Matrix:
|
625
|
+
In = jnp.eye(model.dofs())
|
626
|
+
T = jax.scipy.linalg.block_diag(X, In)
|
627
|
+
return T
|
628
|
+
|
629
|
+
def compute_Ṫ(model: js.model.JaxSimModel, Ẋ: jtp.Matrix) -> jtp.Matrix:
|
630
|
+
On = jnp.zeros(shape=(model.dofs(), model.dofs()))
|
631
|
+
Ṫ = jax.scipy.linalg.block_diag(Ẋ, On)
|
632
|
+
return Ṫ
|
633
|
+
|
634
|
+
# Compute the operator to change the representation of ν, and its
|
635
|
+
# time derivative.
|
636
|
+
match data.velocity_representation:
|
637
|
+
case VelRepr.Inertial:
|
638
|
+
W_H_W = jnp.eye(4)
|
639
|
+
W_X_W = Adjoint.from_transform(transform=W_H_W)
|
640
|
+
W_Ẋ_W = jnp.zeros((6, 6))
|
641
|
+
|
642
|
+
T = compute_T(model=model, X=W_X_W)
|
643
|
+
Ṫ = compute_Ṫ(model=model, Ẋ=W_Ẋ_W)
|
644
|
+
|
645
|
+
case VelRepr.Body:
|
646
|
+
W_H_B = data.base_transform()
|
647
|
+
W_X_B = Adjoint.from_transform(transform=W_H_B)
|
648
|
+
B_v_WB = data.base_velocity()
|
649
|
+
B_vx_WB = Cross.vx(B_v_WB)
|
650
|
+
W_Ẋ_B = W_X_B @ B_vx_WB
|
651
|
+
|
652
|
+
T = compute_T(model=model, X=W_X_B)
|
653
|
+
Ṫ = compute_Ṫ(model=model, Ẋ=W_Ẋ_B)
|
654
|
+
|
655
|
+
case VelRepr.Mixed:
|
656
|
+
W_H_B = data.base_transform()
|
657
|
+
W_H_BW = W_H_B.at[0:3, 0:3].set(jnp.eye(3))
|
658
|
+
W_X_BW = Adjoint.from_transform(transform=W_H_BW)
|
659
|
+
BW_v_WB = data.base_velocity()
|
660
|
+
BW_v_W_BW = BW_v_WB.at[3:6].set(jnp.zeros(3))
|
661
|
+
BW_vx_W_BW = Cross.vx(BW_v_W_BW)
|
662
|
+
W_Ẋ_BW = W_X_BW @ BW_vx_W_BW
|
663
|
+
|
664
|
+
T = compute_T(model=model, X=W_X_BW)
|
665
|
+
Ṫ = compute_Ṫ(model=model, Ẋ=W_Ẋ_BW)
|
666
|
+
|
667
|
+
case _:
|
668
|
+
raise ValueError(data.velocity_representation)
|
669
|
+
|
670
|
+
# =====================================================
|
671
|
+
# Compute quantities to adjust the output representation
|
672
|
+
# =====================================================
|
673
|
+
|
674
|
+
with data.switch_velocity_representation(VelRepr.Inertial):
|
675
|
+
# Compute the Jacobian of the parent link in inertial representation.
|
676
|
+
W_J_WL_W = js.model.generalized_free_floating_jacobian(
|
677
|
+
model=model,
|
678
|
+
data=data,
|
679
|
+
output_vel_repr=VelRepr.Inertial,
|
680
|
+
)
|
681
|
+
# Compute the Jacobian derivative of the parent link in inertial representation.
|
682
|
+
W_J̇_WL_W = js.model.generalized_free_floating_jacobian_derivative(
|
683
|
+
model=model,
|
684
|
+
data=data,
|
685
|
+
output_vel_repr=VelRepr.Inertial,
|
686
|
+
)
|
687
|
+
|
688
|
+
# Get the Jacobian of the enabled collidable points in the mixed representation.
|
689
|
+
with data.switch_velocity_representation(VelRepr.Mixed):
|
690
|
+
CW_J_WC_BW = jacobian(
|
691
|
+
model=model,
|
692
|
+
data=data,
|
693
|
+
output_vel_repr=VelRepr.Mixed,
|
694
|
+
)
|
695
|
+
|
696
|
+
def compute_O_J̇_WC_I(
|
697
|
+
L_p_C: jtp.Vector,
|
698
|
+
parent_link_idx: jtp.Int,
|
699
|
+
CW_J_WC_BW: jtp.Matrix,
|
700
|
+
W_H_L: jtp.Matrix,
|
701
|
+
) -> jtp.Matrix:
|
702
|
+
|
703
|
+
match output_vel_repr:
|
704
|
+
case VelRepr.Inertial:
|
705
|
+
O_X_W = W_X_W = Adjoint.from_transform( # noqa: F841
|
706
|
+
transform=jnp.eye(4)
|
707
|
+
)
|
708
|
+
O_Ẋ_W = W_Ẋ_W = jnp.zeros((6, 6)) # noqa: F841
|
709
|
+
|
710
|
+
case VelRepr.Body:
|
711
|
+
L_H_C = Transform.from_rotation_and_translation(translation=L_p_C)
|
712
|
+
W_H_C = W_H_L[parent_link_idx] @ L_H_C
|
713
|
+
O_X_W = C_X_W = Adjoint.from_transform(transform=W_H_C, inverse=True)
|
714
|
+
with data.switch_velocity_representation(VelRepr.Inertial):
|
715
|
+
W_nu = data.generalized_velocity()
|
716
|
+
W_v_WC = W_J_WL_W[parent_link_idx] @ W_nu
|
717
|
+
W_vx_WC = Cross.vx(W_v_WC)
|
718
|
+
O_Ẋ_W = C_Ẋ_W = -C_X_W @ W_vx_WC # noqa: F841
|
719
|
+
|
720
|
+
case VelRepr.Mixed:
|
721
|
+
L_H_C = Transform.from_rotation_and_translation(translation=L_p_C)
|
722
|
+
W_H_C = W_H_L[parent_link_idx] @ L_H_C
|
723
|
+
W_H_CW = W_H_C.at[0:3, 0:3].set(jnp.eye(3))
|
724
|
+
CW_H_W = Transform.inverse(W_H_CW)
|
725
|
+
O_X_W = CW_X_W = Adjoint.from_transform(transform=CW_H_W)
|
726
|
+
with data.switch_velocity_representation(VelRepr.Mixed):
|
727
|
+
CW_v_WC = CW_J_WC_BW @ data.generalized_velocity()
|
728
|
+
W_v_W_CW = jnp.zeros(6).at[0:3].set(CW_v_WC[0:3])
|
729
|
+
W_vx_W_CW = Cross.vx(W_v_W_CW)
|
730
|
+
O_Ẋ_W = CW_Ẋ_W = -CW_X_W @ W_vx_W_CW # noqa: F841
|
731
|
+
|
732
|
+
case _:
|
733
|
+
raise ValueError(output_vel_repr)
|
734
|
+
|
735
|
+
O_J̇_WC_I = jnp.zeros(shape=(6, 6 + model.dofs()))
|
736
|
+
O_J̇_WC_I += O_Ẋ_W @ W_J_WL_W[parent_link_idx] @ T
|
737
|
+
O_J̇_WC_I += O_X_W @ W_J̇_WL_W[parent_link_idx] @ T
|
738
|
+
O_J̇_WC_I += O_X_W @ W_J_WL_W[parent_link_idx] @ Ṫ
|
739
|
+
|
740
|
+
return O_J̇_WC_I
|
741
|
+
|
742
|
+
O_J̇_WC = jax.vmap(compute_O_J̇_WC_I, in_axes=(0, 0, 0, None))(
|
743
|
+
L_p_Ci, parent_link_idx_of_enabled_collidable_points, CW_J_WC_BW, W_H_Li
|
192
744
|
)
|
193
745
|
|
194
|
-
return
|
746
|
+
return O_J̇_WC
|