jaxsim 0.6.2.dev2__py3-none-any.whl → 0.6.2.dev105__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jaxsim/__init__.py +1 -1
- jaxsim/_version.py +2 -2
- jaxsim/api/__init__.py +3 -1
- jaxsim/api/actuation_model.py +96 -0
- jaxsim/api/com.py +8 -8
- jaxsim/api/contact.py +15 -255
- jaxsim/api/contact_model.py +101 -0
- jaxsim/api/data.py +258 -556
- jaxsim/api/frame.py +7 -7
- jaxsim/api/integrators.py +76 -0
- jaxsim/api/kin_dyn_parameters.py +41 -58
- jaxsim/api/link.py +7 -7
- jaxsim/api/model.py +190 -453
- jaxsim/api/ode.py +34 -338
- jaxsim/api/references.py +2 -2
- jaxsim/exceptions.py +2 -2
- jaxsim/math/__init__.py +4 -3
- jaxsim/math/joint_model.py +17 -107
- jaxsim/mujoco/model.py +1 -1
- jaxsim/mujoco/utils.py +2 -2
- jaxsim/parsers/kinematic_graph.py +1 -3
- jaxsim/rbda/aba.py +7 -4
- jaxsim/rbda/collidable_points.py +7 -98
- jaxsim/rbda/contacts/__init__.py +2 -10
- jaxsim/rbda/contacts/common.py +0 -138
- jaxsim/rbda/contacts/relaxed_rigid.py +156 -11
- jaxsim/rbda/crba.py +5 -2
- jaxsim/rbda/forward_kinematics.py +37 -12
- jaxsim/rbda/jacobian.py +15 -6
- jaxsim/rbda/rnea.py +7 -4
- jaxsim/rbda/utils.py +3 -3
- jaxsim/utils/jaxsim_dataclass.py +5 -1
- {jaxsim-0.6.2.dev2.dist-info → jaxsim-0.6.2.dev105.dist-info}/METADATA +6 -8
- jaxsim-0.6.2.dev105.dist-info/RECORD +69 -0
- jaxsim/api/ode_data.py +0 -401
- jaxsim/integrators/__init__.py +0 -2
- jaxsim/integrators/common.py +0 -592
- jaxsim/integrators/fixed_step.py +0 -153
- jaxsim/integrators/variable_step.py +0 -706
- jaxsim/rbda/contacts/rigid.py +0 -462
- jaxsim/rbda/contacts/soft.py +0 -480
- jaxsim/rbda/contacts/visco_elastic.py +0 -1066
- jaxsim-0.6.2.dev2.dist-info/RECORD +0 -74
- {jaxsim-0.6.2.dev2.dist-info → jaxsim-0.6.2.dev105.dist-info}/LICENSE +0 -0
- {jaxsim-0.6.2.dev2.dist-info → jaxsim-0.6.2.dev105.dist-info}/WHEEL +0 -0
- {jaxsim-0.6.2.dev2.dist-info → jaxsim-0.6.2.dev105.dist-info}/top_level.txt +0 -0
jaxsim/__init__.py
CHANGED
jaxsim/_version.py
CHANGED
@@ -12,5 +12,5 @@ __version__: str
|
|
12
12
|
__version_tuple__: VERSION_TUPLE
|
13
13
|
version_tuple: VERSION_TUPLE
|
14
14
|
|
15
|
-
__version__ = version = '0.6.2.
|
16
|
-
__version_tuple__ = version_tuple = (0, 6, 2, '
|
15
|
+
__version__ = version = '0.6.2.dev105'
|
16
|
+
__version_tuple__ = version_tuple = (0, 6, 2, 'dev105')
|
jaxsim/api/__init__.py
CHANGED
@@ -1,13 +1,15 @@
|
|
1
1
|
from . import common # isort:skip
|
2
2
|
from . import model, data # isort:skip
|
3
3
|
from . import (
|
4
|
+
actuation_model,
|
4
5
|
com,
|
5
6
|
contact,
|
7
|
+
contact_model,
|
6
8
|
frame,
|
9
|
+
integrators,
|
7
10
|
joint,
|
8
11
|
kin_dyn_parameters,
|
9
12
|
link,
|
10
13
|
ode,
|
11
|
-
ode_data,
|
12
14
|
references,
|
13
15
|
)
|
@@ -0,0 +1,96 @@
|
|
1
|
+
import jax.numpy as jnp
|
2
|
+
|
3
|
+
import jaxsim.api as js
|
4
|
+
import jaxsim.typing as jtp
|
5
|
+
|
6
|
+
|
7
|
+
def compute_resultant_torques(
|
8
|
+
model: js.model.JaxSimModel,
|
9
|
+
data: js.data.JaxSimModelData,
|
10
|
+
*,
|
11
|
+
joint_force_references: jtp.Vector | None = None,
|
12
|
+
) -> jtp.Vector:
|
13
|
+
"""
|
14
|
+
Compute the resultant torques acting on the joints.
|
15
|
+
|
16
|
+
Args:
|
17
|
+
model: The model to consider.
|
18
|
+
data: The data of the considered model.
|
19
|
+
joint_force_references: The joint force references to apply.
|
20
|
+
|
21
|
+
Returns:
|
22
|
+
The resultant torques acting on the joints.
|
23
|
+
"""
|
24
|
+
|
25
|
+
# Build joint torques if not provided.
|
26
|
+
τ_references = (
|
27
|
+
jnp.atleast_1d(joint_force_references.squeeze())
|
28
|
+
if joint_force_references is not None
|
29
|
+
else jnp.zeros_like(data.joint_positions)
|
30
|
+
).astype(float)
|
31
|
+
|
32
|
+
# ====================
|
33
|
+
# Enforce joint limits
|
34
|
+
# ====================
|
35
|
+
|
36
|
+
τ_position_limit = jnp.zeros_like(τ_references).astype(float)
|
37
|
+
|
38
|
+
if model.dofs() > 0:
|
39
|
+
|
40
|
+
# Stiffness and damper parameters for the joint position limits.
|
41
|
+
k_j = jnp.array(
|
42
|
+
model.kin_dyn_parameters.joint_parameters.position_limit_spring
|
43
|
+
).astype(float)
|
44
|
+
d_j = jnp.array(
|
45
|
+
model.kin_dyn_parameters.joint_parameters.position_limit_damper
|
46
|
+
).astype(float)
|
47
|
+
|
48
|
+
# Compute the joint position limit violations.
|
49
|
+
lower_violation = jnp.clip(
|
50
|
+
data.joint_positions
|
51
|
+
- model.kin_dyn_parameters.joint_parameters.position_limits_min,
|
52
|
+
max=0.0,
|
53
|
+
)
|
54
|
+
|
55
|
+
upper_violation = jnp.clip(
|
56
|
+
data.joint_positions
|
57
|
+
- model.kin_dyn_parameters.joint_parameters.position_limits_max,
|
58
|
+
min=0.0,
|
59
|
+
)
|
60
|
+
|
61
|
+
# Compute the joint position limit torque.
|
62
|
+
τ_position_limit -= jnp.diag(k_j) @ (lower_violation + upper_violation)
|
63
|
+
|
64
|
+
τ_position_limit -= (
|
65
|
+
jnp.positive(τ_position_limit) * jnp.diag(d_j) @ data.joint_velocities
|
66
|
+
)
|
67
|
+
|
68
|
+
# ====================
|
69
|
+
# Joint friction model
|
70
|
+
# ====================
|
71
|
+
|
72
|
+
τ_friction = jnp.zeros_like(τ_references).astype(float)
|
73
|
+
|
74
|
+
if model.dofs() > 0:
|
75
|
+
|
76
|
+
# Static and viscous joint friction parameters
|
77
|
+
kc = jnp.array(
|
78
|
+
model.kin_dyn_parameters.joint_parameters.friction_static
|
79
|
+
).astype(float)
|
80
|
+
kv = jnp.array(
|
81
|
+
model.kin_dyn_parameters.joint_parameters.friction_viscous
|
82
|
+
).astype(float)
|
83
|
+
|
84
|
+
# Compute the joint friction torque.
|
85
|
+
τ_friction = -(
|
86
|
+
jnp.diag(kc) @ jnp.sign(data.joint_velocities)
|
87
|
+
+ jnp.diag(kv) @ data.joint_velocities
|
88
|
+
)
|
89
|
+
|
90
|
+
# ===============================
|
91
|
+
# Compute the total joint forces.
|
92
|
+
# ===============================
|
93
|
+
|
94
|
+
τ_total = τ_references + τ_friction + τ_position_limit
|
95
|
+
|
96
|
+
return τ_total
|
jaxsim/api/com.py
CHANGED
@@ -26,8 +26,8 @@ def com_position(
|
|
26
26
|
|
27
27
|
m = js.model.total_mass(model=model)
|
28
28
|
|
29
|
-
W_H_L =
|
30
|
-
W_H_B = data.
|
29
|
+
W_H_L = data._link_transforms
|
30
|
+
W_H_B = data._base_transform
|
31
31
|
B_H_W = jaxsim.math.Transform.inverse(transform=W_H_B)
|
32
32
|
|
33
33
|
def B_p̃_LCoM(i) -> jtp.Vector:
|
@@ -98,7 +98,7 @@ def centroidal_momentum(
|
|
98
98
|
and :math:`C = B` if the active velocity representation is body-fixed.
|
99
99
|
"""
|
100
100
|
|
101
|
-
ν = data.generalized_velocity
|
101
|
+
ν = data.generalized_velocity
|
102
102
|
G_J = centroidal_momentum_jacobian(model=model, data=data)
|
103
103
|
|
104
104
|
return G_J @ ν
|
@@ -134,7 +134,7 @@ def centroidal_momentum_jacobian(
|
|
134
134
|
model=model, data=data, output_vel_repr=VelRepr.Body
|
135
135
|
)
|
136
136
|
|
137
|
-
W_H_B = data.
|
137
|
+
W_H_B = data._base_transform
|
138
138
|
B_H_W = jaxsim.math.Transform.inverse(W_H_B)
|
139
139
|
|
140
140
|
W_p_CoM = com_position(model=model, data=data)
|
@@ -172,7 +172,7 @@ def locked_centroidal_spatial_inertia(
|
|
172
172
|
with data.switch_velocity_representation(VelRepr.Body):
|
173
173
|
B_Mbb_B = js.model.locked_spatial_inertia(model=model, data=data)
|
174
174
|
|
175
|
-
W_H_B = data.
|
175
|
+
W_H_B = data._base_transform
|
176
176
|
W_p_CoM = com_position(model=model, data=data)
|
177
177
|
|
178
178
|
match data.velocity_representation:
|
@@ -213,7 +213,7 @@ def average_centroidal_velocity(
|
|
213
213
|
and :math:`[C] = [B]` if the active velocity representation is body-fixed.
|
214
214
|
"""
|
215
215
|
|
216
|
-
ν = data.generalized_velocity
|
216
|
+
ν = data.generalized_velocity
|
217
217
|
G_J = average_centroidal_velocity_jacobian(model=model, data=data)
|
218
218
|
|
219
219
|
return G_J @ ν
|
@@ -269,7 +269,7 @@ def bias_acceleration(
|
|
269
269
|
"""
|
270
270
|
|
271
271
|
# Compute the pose of all links with forward kinematics.
|
272
|
-
W_H_L =
|
272
|
+
W_H_L = data._link_transforms
|
273
273
|
|
274
274
|
# Compute the bias acceleration of all links by zeroing the generalized velocity
|
275
275
|
# in the active representation.
|
@@ -411,7 +411,7 @@ def bias_acceleration(
|
|
411
411
|
case VelRepr.Body:
|
412
412
|
|
413
413
|
GB_Xf_W = jaxsim.math.Adjoint.from_transform(
|
414
|
-
transform=data.
|
414
|
+
transform=data._base_transform.at[0:3].set(W_p_CoM)
|
415
415
|
).T
|
416
416
|
|
417
417
|
GB_ḣ_bias = GB_Xf_W @ W_ḣ_bias
|
jaxsim/api/contact.py
CHANGED
@@ -42,12 +42,8 @@ def collidable_point_kinematics(
|
|
42
42
|
|
43
43
|
W_p_Ci, W_ṗ_Ci = jaxsim.rbda.collidable_points.collidable_points_pos_vel(
|
44
44
|
model=model,
|
45
|
-
|
46
|
-
|
47
|
-
joint_positions=data.joint_positions(model=model),
|
48
|
-
base_linear_velocity=data.base_velocity()[0:3],
|
49
|
-
base_angular_velocity=data.base_velocity()[3:6],
|
50
|
-
joint_velocities=data.joint_velocities(model=model),
|
45
|
+
link_transforms=data._link_transforms,
|
46
|
+
link_velocities=data._link_velocities,
|
51
47
|
)
|
52
48
|
|
53
49
|
return W_p_Ci, W_ṗ_Ci
|
@@ -95,143 +91,6 @@ def collidable_point_velocities(
|
|
95
91
|
return W_ṗ_Ci
|
96
92
|
|
97
93
|
|
98
|
-
@jax.jit
|
99
|
-
@js.common.named_scope
|
100
|
-
def collidable_point_forces(
|
101
|
-
model: js.model.JaxSimModel,
|
102
|
-
data: js.data.JaxSimModelData,
|
103
|
-
link_forces: jtp.MatrixLike | None = None,
|
104
|
-
joint_force_references: jtp.VectorLike | None = None,
|
105
|
-
**kwargs,
|
106
|
-
) -> jtp.Matrix:
|
107
|
-
"""
|
108
|
-
Compute the 6D forces applied to each collidable point.
|
109
|
-
|
110
|
-
Args:
|
111
|
-
model: The model to consider.
|
112
|
-
data: The data of the considered model.
|
113
|
-
link_forces:
|
114
|
-
The 6D external forces to apply to the links expressed in the same
|
115
|
-
representation of data.
|
116
|
-
joint_force_references:
|
117
|
-
The joint force references to apply to the joints.
|
118
|
-
kwargs: Additional keyword arguments to pass to the active contact model.
|
119
|
-
|
120
|
-
Returns:
|
121
|
-
The 6D forces applied to each collidable point expressed in the frame
|
122
|
-
corresponding to the active representation.
|
123
|
-
"""
|
124
|
-
|
125
|
-
f_Ci, _ = collidable_point_dynamics(
|
126
|
-
model=model,
|
127
|
-
data=data,
|
128
|
-
link_forces=link_forces,
|
129
|
-
joint_force_references=joint_force_references,
|
130
|
-
**kwargs,
|
131
|
-
)
|
132
|
-
|
133
|
-
return f_Ci
|
134
|
-
|
135
|
-
|
136
|
-
@jax.jit
|
137
|
-
@js.common.named_scope
|
138
|
-
def collidable_point_dynamics(
|
139
|
-
model: js.model.JaxSimModel,
|
140
|
-
data: js.data.JaxSimModelData,
|
141
|
-
link_forces: jtp.MatrixLike | None = None,
|
142
|
-
joint_force_references: jtp.VectorLike | None = None,
|
143
|
-
**kwargs,
|
144
|
-
) -> tuple[jtp.Matrix, dict[str, jtp.PyTree]]:
|
145
|
-
r"""
|
146
|
-
Compute the 6D force applied to each enabled collidable point.
|
147
|
-
|
148
|
-
Args:
|
149
|
-
model: The model to consider.
|
150
|
-
data: The data of the considered model.
|
151
|
-
link_forces:
|
152
|
-
The 6D external forces to apply to the links expressed in the same
|
153
|
-
representation of data.
|
154
|
-
joint_force_references:
|
155
|
-
The joint force references to apply to the joints.
|
156
|
-
kwargs: Additional keyword arguments to pass to the active contact model.
|
157
|
-
|
158
|
-
Returns:
|
159
|
-
The 6D force applied to each enabled collidable point and additional data based
|
160
|
-
on the contact model configured:
|
161
|
-
- Soft: the material deformation rate.
|
162
|
-
- Rigid: no additional data.
|
163
|
-
- QuasiRigid: no additional data.
|
164
|
-
|
165
|
-
Note:
|
166
|
-
The material deformation rate is always returned in the mixed frame
|
167
|
-
`C[W] = ({}^W \mathbf{p}_C, [W])`. This is convenient for integration purpose.
|
168
|
-
Instead, the 6D forces are returned in the active representation.
|
169
|
-
"""
|
170
|
-
|
171
|
-
# Build the common kw arguments to pass to the computation of the contact forces.
|
172
|
-
common_kwargs = dict(
|
173
|
-
link_forces=link_forces,
|
174
|
-
joint_force_references=joint_force_references,
|
175
|
-
)
|
176
|
-
|
177
|
-
# Build the additional kwargs to pass to the computation of the contact forces.
|
178
|
-
match model.contact_model:
|
179
|
-
|
180
|
-
case contacts.SoftContacts():
|
181
|
-
|
182
|
-
kwargs_contact_model = {}
|
183
|
-
|
184
|
-
case contacts.RigidContacts():
|
185
|
-
|
186
|
-
kwargs_contact_model = common_kwargs | kwargs
|
187
|
-
|
188
|
-
case contacts.RelaxedRigidContacts():
|
189
|
-
|
190
|
-
kwargs_contact_model = common_kwargs | kwargs
|
191
|
-
|
192
|
-
case contacts.ViscoElasticContacts():
|
193
|
-
|
194
|
-
kwargs_contact_model = common_kwargs | dict(dt=model.time_step) | kwargs
|
195
|
-
|
196
|
-
case _:
|
197
|
-
raise ValueError(f"Invalid contact model: {model.contact_model}")
|
198
|
-
|
199
|
-
# Compute the contact forces with the active contact model.
|
200
|
-
W_f_C, aux_data = model.contact_model.compute_contact_forces(
|
201
|
-
model=model,
|
202
|
-
data=data,
|
203
|
-
**kwargs_contact_model,
|
204
|
-
)
|
205
|
-
|
206
|
-
# Compute the transforms of the implicit frames `C[L] = (W_p_C, [L])`
|
207
|
-
# associated to the enabled collidable point.
|
208
|
-
# In inertial-fixed representation, the computation of these transforms
|
209
|
-
# is not necessary and the conversion below becomes a no-op.
|
210
|
-
|
211
|
-
# Get the indices of the enabled collidable points.
|
212
|
-
indices_of_enabled_collidable_points = (
|
213
|
-
model.kin_dyn_parameters.contact_parameters.indices_of_enabled_collidable_points
|
214
|
-
)
|
215
|
-
|
216
|
-
W_H_C = (
|
217
|
-
js.contact.transforms(model=model, data=data)
|
218
|
-
if data.velocity_representation is not VelRepr.Inertial
|
219
|
-
else jnp.stack([jnp.eye(4)] * len(indices_of_enabled_collidable_points))
|
220
|
-
)
|
221
|
-
|
222
|
-
# Convert the 6D forces to the active representation.
|
223
|
-
f_Ci = jax.vmap(
|
224
|
-
lambda W_f_C, W_H_C: data.inertial_to_other_representation(
|
225
|
-
array=W_f_C,
|
226
|
-
other_representation=data.velocity_representation,
|
227
|
-
transform=W_H_C,
|
228
|
-
is_force=True,
|
229
|
-
)
|
230
|
-
)(W_f_C, W_H_C)
|
231
|
-
|
232
|
-
return f_Ci, aux_data
|
233
|
-
|
234
|
-
|
235
94
|
@functools.partial(jax.jit, static_argnames=["link_names"])
|
236
95
|
@js.common.named_scope
|
237
96
|
def in_contact(
|
@@ -305,11 +164,7 @@ def estimate_good_soft_contacts_parameters(
|
|
305
164
|
def estimate_good_contact_parameters(
|
306
165
|
model: js.model.JaxSimModel,
|
307
166
|
*,
|
308
|
-
standard_gravity: jtp.FloatLike = jaxsim.math.StandardGravity,
|
309
167
|
static_friction_coefficient: jtp.FloatLike = 0.5,
|
310
|
-
number_of_active_collidable_points_steady_state: jtp.IntLike = 1,
|
311
|
-
damping_ratio: jtp.FloatLike = 1.0,
|
312
|
-
max_penetration: jtp.FloatLike | None = None,
|
313
168
|
**kwargs,
|
314
169
|
) -> jaxsim.rbda.contacts.ContactParamsTypes:
|
315
170
|
"""
|
@@ -317,15 +172,7 @@ def estimate_good_contact_parameters(
|
|
317
172
|
|
318
173
|
Args:
|
319
174
|
model: The model to consider.
|
320
|
-
standard_gravity: The standard gravity constant.
|
321
175
|
static_friction_coefficient: The static friction coefficient.
|
322
|
-
number_of_active_collidable_points_steady_state:
|
323
|
-
The number of active collidable points in steady state supporting
|
324
|
-
the weight of the robot.
|
325
|
-
damping_ratio: The damping ratio.
|
326
|
-
max_penetration:
|
327
|
-
The maximum penetration allowed in steady state when the robot is
|
328
|
-
supported by the configured number of active collidable points.
|
329
176
|
kwargs:
|
330
177
|
Additional model-specific parameters passed to the builder method of
|
331
178
|
the parameters class.
|
@@ -343,82 +190,8 @@ def estimate_good_contact_parameters(
|
|
343
190
|
specific application.
|
344
191
|
"""
|
345
192
|
|
346
|
-
def estimate_model_height(model: js.model.JaxSimModel) -> jtp.Float:
|
347
|
-
"""
|
348
|
-
Displacement between the CoM and the lowest collidable point using zero
|
349
|
-
joint positions.
|
350
|
-
"""
|
351
|
-
|
352
|
-
zero_data = js.data.JaxSimModelData.build(
|
353
|
-
model=model,
|
354
|
-
contacts_params=jaxsim.rbda.contacts.SoftContactsParams(),
|
355
|
-
)
|
356
|
-
|
357
|
-
W_pz_CoM = js.com.com_position(model=model, data=zero_data)[2]
|
358
|
-
|
359
|
-
if model.floating_base():
|
360
|
-
W_pz_C = collidable_point_positions(model=model, data=zero_data)[:, -1]
|
361
|
-
return 2 * (W_pz_CoM - W_pz_C.min())
|
362
|
-
|
363
|
-
return 2 * W_pz_CoM
|
364
|
-
|
365
|
-
max_δ = (
|
366
|
-
max_penetration
|
367
|
-
if max_penetration is not None
|
368
|
-
# Consider as default a 0.5% of the model height.
|
369
|
-
else 0.005 * estimate_model_height(model=model)
|
370
|
-
)
|
371
|
-
|
372
|
-
nc = number_of_active_collidable_points_steady_state
|
373
|
-
|
374
193
|
match model.contact_model:
|
375
194
|
|
376
|
-
case contacts.SoftContacts():
|
377
|
-
assert isinstance(model.contact_model, contacts.SoftContacts)
|
378
|
-
|
379
|
-
parameters = contacts.SoftContactsParams.build_default_from_jaxsim_model(
|
380
|
-
model=model,
|
381
|
-
standard_gravity=standard_gravity,
|
382
|
-
static_friction_coefficient=static_friction_coefficient,
|
383
|
-
max_penetration=max_δ,
|
384
|
-
number_of_active_collidable_points_steady_state=nc,
|
385
|
-
damping_ratio=damping_ratio,
|
386
|
-
**kwargs,
|
387
|
-
)
|
388
|
-
|
389
|
-
case contacts.ViscoElasticContacts():
|
390
|
-
assert isinstance(model.contact_model, contacts.ViscoElasticContacts)
|
391
|
-
|
392
|
-
parameters = (
|
393
|
-
contacts.ViscoElasticContactsParams.build_default_from_jaxsim_model(
|
394
|
-
model=model,
|
395
|
-
standard_gravity=standard_gravity,
|
396
|
-
static_friction_coefficient=static_friction_coefficient,
|
397
|
-
max_penetration=max_δ,
|
398
|
-
number_of_active_collidable_points_steady_state=nc,
|
399
|
-
damping_ratio=damping_ratio,
|
400
|
-
**kwargs,
|
401
|
-
)
|
402
|
-
)
|
403
|
-
|
404
|
-
case contacts.RigidContacts():
|
405
|
-
assert isinstance(model.contact_model, contacts.RigidContacts)
|
406
|
-
|
407
|
-
# Disable Baumgarte stabilization by default since it does not play
|
408
|
-
# well with the forward Euler integrator.
|
409
|
-
K = kwargs.get("K", 0.0)
|
410
|
-
|
411
|
-
parameters = contacts.RigidContactsParams.build(
|
412
|
-
mu=static_friction_coefficient,
|
413
|
-
**(
|
414
|
-
dict(
|
415
|
-
K=K,
|
416
|
-
D=2 * jnp.sqrt(K),
|
417
|
-
)
|
418
|
-
| kwargs
|
419
|
-
),
|
420
|
-
)
|
421
|
-
|
422
195
|
case contacts.RelaxedRigidContacts():
|
423
196
|
assert isinstance(model.contact_model, contacts.RelaxedRigidContacts)
|
424
197
|
|
@@ -463,9 +236,7 @@ def transforms(model: js.model.JaxSimModel, data: js.data.JaxSimModelData) -> jt
|
|
463
236
|
)[indices_of_enabled_collidable_points]
|
464
237
|
|
465
238
|
# Get the transforms of the parent link of all collidable points.
|
466
|
-
W_H_L =
|
467
|
-
parent_link_idx_of_enabled_collidable_points
|
468
|
-
]
|
239
|
+
W_H_L = data._link_transforms[parent_link_idx_of_enabled_collidable_points]
|
469
240
|
|
470
241
|
L_p_Ci = model.kin_dyn_parameters.contact_parameters.point[
|
471
242
|
indices_of_enabled_collidable_points
|
@@ -615,7 +386,10 @@ def jacobian_derivative(
|
|
615
386
|
]
|
616
387
|
|
617
388
|
# Get the transforms of all the parent links.
|
618
|
-
W_H_Li =
|
389
|
+
W_H_Li = data._link_transforms
|
390
|
+
|
391
|
+
# Get the link velocities.
|
392
|
+
W_v_WLi = data._link_velocities
|
619
393
|
|
620
394
|
# =====================================================
|
621
395
|
# Compute quantities to adjust the input representation
|
@@ -643,9 +417,9 @@ def jacobian_derivative(
|
|
643
417
|
Ṫ = compute_Ṫ(model=model, Ẋ=W_Ẋ_W)
|
644
418
|
|
645
419
|
case VelRepr.Body:
|
646
|
-
W_H_B = data.
|
420
|
+
W_H_B = data._base_transform
|
647
421
|
W_X_B = Adjoint.from_transform(transform=W_H_B)
|
648
|
-
B_v_WB = data.base_velocity
|
422
|
+
B_v_WB = data.base_velocity
|
649
423
|
B_vx_WB = Cross.vx(B_v_WB)
|
650
424
|
W_Ẋ_B = W_X_B @ B_vx_WB
|
651
425
|
|
@@ -653,10 +427,10 @@ def jacobian_derivative(
|
|
653
427
|
Ṫ = compute_Ṫ(model=model, Ẋ=W_Ẋ_B)
|
654
428
|
|
655
429
|
case VelRepr.Mixed:
|
656
|
-
W_H_B = data.
|
430
|
+
W_H_B = data._base_transform
|
657
431
|
W_H_BW = W_H_B.at[0:3, 0:3].set(jnp.eye(3))
|
658
432
|
W_X_BW = Adjoint.from_transform(transform=W_H_BW)
|
659
|
-
BW_v_WB = data.base_velocity
|
433
|
+
BW_v_WB = data.base_velocity
|
660
434
|
BW_v_W_BW = BW_v_WB.at[3:6].set(jnp.zeros(3))
|
661
435
|
BW_vx_W_BW = Cross.vx(BW_v_W_BW)
|
662
436
|
W_Ẋ_BW = W_X_BW @ BW_vx_W_BW
|
@@ -676,27 +450,16 @@ def jacobian_derivative(
|
|
676
450
|
W_J_WL_W = js.model.generalized_free_floating_jacobian(
|
677
451
|
model=model,
|
678
452
|
data=data,
|
679
|
-
output_vel_repr=VelRepr.Inertial,
|
680
453
|
)
|
681
454
|
# Compute the Jacobian derivative of the parent link in inertial representation.
|
682
455
|
W_J̇_WL_W = js.model.generalized_free_floating_jacobian_derivative(
|
683
456
|
model=model,
|
684
457
|
data=data,
|
685
|
-
output_vel_repr=VelRepr.Inertial,
|
686
|
-
)
|
687
|
-
|
688
|
-
# Get the Jacobian of the enabled collidable points in the mixed representation.
|
689
|
-
with data.switch_velocity_representation(VelRepr.Mixed):
|
690
|
-
CW_J_WC_BW = jacobian(
|
691
|
-
model=model,
|
692
|
-
data=data,
|
693
|
-
output_vel_repr=VelRepr.Mixed,
|
694
458
|
)
|
695
459
|
|
696
460
|
def compute_O_J̇_WC_I(
|
697
461
|
L_p_C: jtp.Vector,
|
698
462
|
parent_link_idx: jtp.Int,
|
699
|
-
CW_J_WC_BW: jtp.Matrix,
|
700
463
|
W_H_L: jtp.Matrix,
|
701
464
|
) -> jtp.Matrix:
|
702
465
|
|
@@ -711,9 +474,7 @@ def jacobian_derivative(
|
|
711
474
|
L_H_C = Transform.from_rotation_and_translation(translation=L_p_C)
|
712
475
|
W_H_C = W_H_L[parent_link_idx] @ L_H_C
|
713
476
|
O_X_W = C_X_W = Adjoint.from_transform(transform=W_H_C, inverse=True)
|
714
|
-
|
715
|
-
W_nu = data.generalized_velocity()
|
716
|
-
W_v_WC = W_J_WL_W[parent_link_idx] @ W_nu
|
477
|
+
W_v_WC = W_v_WLi[parent_link_idx]
|
717
478
|
W_vx_WC = Cross.vx(W_v_WC)
|
718
479
|
O_Ẋ_W = C_Ẋ_W = -C_X_W @ W_vx_WC # noqa: F841
|
719
480
|
|
@@ -723,8 +484,7 @@ def jacobian_derivative(
|
|
723
484
|
W_H_CW = W_H_C.at[0:3, 0:3].set(jnp.eye(3))
|
724
485
|
CW_H_W = Transform.inverse(W_H_CW)
|
725
486
|
O_X_W = CW_X_W = Adjoint.from_transform(transform=CW_H_W)
|
726
|
-
|
727
|
-
CW_v_WC = CW_J_WC_BW @ data.generalized_velocity()
|
487
|
+
CW_v_WC = CW_X_W @ W_v_WLi[parent_link_idx]
|
728
488
|
W_v_W_CW = jnp.zeros(6).at[0:3].set(CW_v_WC[0:3])
|
729
489
|
W_vx_W_CW = Cross.vx(W_v_W_CW)
|
730
490
|
O_Ẋ_W = CW_Ẋ_W = -CW_X_W @ W_vx_W_CW # noqa: F841
|
@@ -739,8 +499,8 @@ def jacobian_derivative(
|
|
739
499
|
|
740
500
|
return O_J̇_WC_I
|
741
501
|
|
742
|
-
O_J̇_WC = jax.vmap(compute_O_J̇_WC_I, in_axes=(0, 0,
|
743
|
-
L_p_Ci, parent_link_idx_of_enabled_collidable_points,
|
502
|
+
O_J̇_WC = jax.vmap(compute_O_J̇_WC_I, in_axes=(0, 0, None))(
|
503
|
+
L_p_Ci, parent_link_idx_of_enabled_collidable_points, W_H_Li
|
744
504
|
)
|
745
505
|
|
746
506
|
return O_J̇_WC
|
@@ -0,0 +1,101 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
3
|
+
import jax
|
4
|
+
import jax.numpy as jnp
|
5
|
+
|
6
|
+
import jaxsim.api as js
|
7
|
+
import jaxsim.typing as jtp
|
8
|
+
|
9
|
+
|
10
|
+
@jax.jit
|
11
|
+
@js.common.named_scope
|
12
|
+
def link_contact_forces(
|
13
|
+
model: js.model.JaxSimModel,
|
14
|
+
data: js.data.JaxSimModelData,
|
15
|
+
*,
|
16
|
+
link_forces: jtp.MatrixLike | None = None,
|
17
|
+
joint_torques: jtp.VectorLike | None = None,
|
18
|
+
) -> jtp.Matrix:
|
19
|
+
"""
|
20
|
+
Compute the 6D contact forces of all links of the model in inertial representation.
|
21
|
+
|
22
|
+
Args:
|
23
|
+
model: The model to consider.
|
24
|
+
data: The data of the considered model.
|
25
|
+
link_forces:
|
26
|
+
The 6D external forces to apply to the links expressed in inertial representation
|
27
|
+
joint_torques:
|
28
|
+
The joint torques acting on the joints.
|
29
|
+
|
30
|
+
Returns:
|
31
|
+
A `(nL, 6)` array containing the stacked 6D contact forces of the links,
|
32
|
+
expressed in inertial representation.
|
33
|
+
"""
|
34
|
+
|
35
|
+
# Compute the contact forces for each collidable point with the active contact model.
|
36
|
+
W_f_C, _ = model.contact_model.compute_contact_forces(
|
37
|
+
model=model,
|
38
|
+
data=data,
|
39
|
+
link_forces=link_forces,
|
40
|
+
joint_force_references=joint_torques,
|
41
|
+
)
|
42
|
+
|
43
|
+
# Compute the 6D forces applied to the links equivalent to the forces applied
|
44
|
+
# to the frames associated to the collidable points.
|
45
|
+
W_f_L = link_forces_from_contact_forces(
|
46
|
+
model=model, data=data, contact_forces=W_f_C
|
47
|
+
)
|
48
|
+
|
49
|
+
return W_f_L
|
50
|
+
|
51
|
+
|
52
|
+
@staticmethod
|
53
|
+
def link_forces_from_contact_forces(
|
54
|
+
model: js.model.JaxSimModel,
|
55
|
+
data: js.data.JaxSimModelData,
|
56
|
+
*,
|
57
|
+
contact_forces: jtp.MatrixLike,
|
58
|
+
) -> jtp.Matrix:
|
59
|
+
"""
|
60
|
+
Compute the link forces from the contact forces.
|
61
|
+
|
62
|
+
Args:
|
63
|
+
model: The robot model considered by the contact model.
|
64
|
+
data: The data of the considered model.
|
65
|
+
contact_forces: The contact forces computed by the contact model.
|
66
|
+
|
67
|
+
Returns:
|
68
|
+
The 6D contact forces applied to the links and expressed in the frame of
|
69
|
+
the velocity representation of data.
|
70
|
+
"""
|
71
|
+
|
72
|
+
# Get the object storing the contact parameters of the model.
|
73
|
+
contact_parameters = model.kin_dyn_parameters.contact_parameters
|
74
|
+
|
75
|
+
# Extract the indices corresponding to the enabled collidable points.
|
76
|
+
indices_of_enabled_collidable_points = (
|
77
|
+
contact_parameters.indices_of_enabled_collidable_points
|
78
|
+
)
|
79
|
+
|
80
|
+
# Convert the contact forces to a JAX array.
|
81
|
+
W_f_C = jnp.atleast_2d(jnp.array(contact_forces, dtype=float).squeeze())
|
82
|
+
|
83
|
+
# Construct the vector defining the parent link index of each collidable point.
|
84
|
+
# We use this vector to sum the 6D forces of all collidable points rigidly
|
85
|
+
# attached to the same link.
|
86
|
+
parent_link_index_of_collidable_points = jnp.array(
|
87
|
+
contact_parameters.body, dtype=int
|
88
|
+
)[indices_of_enabled_collidable_points]
|
89
|
+
|
90
|
+
# Create the mask that associate each collidable point to their parent link.
|
91
|
+
# We use this mask to sum the collidable points to the right link.
|
92
|
+
mask = parent_link_index_of_collidable_points[:, jnp.newaxis] == jnp.arange(
|
93
|
+
model.number_of_links()
|
94
|
+
)
|
95
|
+
|
96
|
+
# Sum the forces of all collidable points rigidly attached to a body.
|
97
|
+
# Since the contact forces W_f_C are expressed in the world frame,
|
98
|
+
# we don't need any coordinate transformation.
|
99
|
+
W_f_L = mask.T @ W_f_C
|
100
|
+
|
101
|
+
return W_f_L
|