jaxsim 0.6.2.dev182__py3-none-any.whl → 0.6.2.dev225__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jaxsim/_version.py +2 -2
- jaxsim/api/__init__.py +0 -1
- jaxsim/api/com.py +1 -3
- jaxsim/api/common.py +26 -38
- jaxsim/api/contact.py +140 -24
- jaxsim/api/data.py +96 -33
- jaxsim/api/integrators.py +18 -11
- jaxsim/api/model.py +25 -43
- jaxsim/api/ode.py +28 -6
- jaxsim/api/references.py +9 -16
- jaxsim/math/__init__.py +1 -1
- jaxsim/math/adjoint.py +2 -2
- jaxsim/math/transform.py +2 -2
- jaxsim/math/utils.py +3 -2
- jaxsim/mujoco/visualizer.py +1 -1
- jaxsim/parsers/kinematic_graph.py +1 -1
- jaxsim/rbda/__init__.py +1 -1
- jaxsim/rbda/contacts/__init__.py +6 -2
- jaxsim/rbda/contacts/common.py +114 -4
- jaxsim/rbda/contacts/relaxed_rigid.py +57 -177
- jaxsim/rbda/contacts/rigid.py +538 -0
- jaxsim/rbda/contacts/soft.py +448 -0
- jaxsim/rbda/forward_kinematics.py +0 -29
- jaxsim/rbda/utils.py +2 -2
- jaxsim/terrain/terrain.py +1 -1
- {jaxsim-0.6.2.dev182.dist-info → jaxsim-0.6.2.dev225.dist-info}/METADATA +3 -2
- {jaxsim-0.6.2.dev182.dist-info → jaxsim-0.6.2.dev225.dist-info}/RECORD +30 -29
- {jaxsim-0.6.2.dev182.dist-info → jaxsim-0.6.2.dev225.dist-info}/WHEEL +1 -1
- jaxsim/api/contact_model.py +0 -101
- {jaxsim-0.6.2.dev182.dist-info → jaxsim-0.6.2.dev225.dist-info/licenses}/LICENSE +0 -0
- {jaxsim-0.6.2.dev182.dist-info → jaxsim-0.6.2.dev225.dist-info}/top_level.txt +0 -0
jaxsim/api/integrators.py
CHANGED
@@ -22,7 +22,7 @@ def semi_implicit_euler_integration(
|
|
22
22
|
with data.switch_velocity_representation(jaxsim.VelRepr.Inertial):
|
23
23
|
|
24
24
|
# Compute the system acceleration
|
25
|
-
W_v̇_WB, s
|
25
|
+
W_v̇_WB, s̈, contact_state_derivative = js.ode.system_acceleration(
|
26
26
|
model=model,
|
27
27
|
data=data,
|
28
28
|
link_forces=link_forces,
|
@@ -58,12 +58,18 @@ def semi_implicit_euler_integration(
|
|
58
58
|
W_p_B = data.base_position + dt * W_ṗ_B
|
59
59
|
W_Q_B = data.base_orientation + dt * W_Q̇_B
|
60
60
|
|
61
|
-
base_quaternion_norm = jaxsim.math.safe_norm(W_Q_B)
|
61
|
+
base_quaternion_norm = jaxsim.math.safe_norm(W_Q_B, axis=-1)
|
62
62
|
|
63
63
|
W_Q_B = W_Q_B / jnp.where(base_quaternion_norm == 0, 1.0, base_quaternion_norm)
|
64
64
|
|
65
65
|
s = data.joint_positions + dt * ṡ
|
66
66
|
|
67
|
+
integrated_contact_state = jax.tree.map(
|
68
|
+
lambda x, x_dot: x + dt * x_dot,
|
69
|
+
data.contact_state,
|
70
|
+
contact_state_derivative,
|
71
|
+
)
|
72
|
+
|
67
73
|
# TODO: Avoid double replace, e.g. by computing cached value here
|
68
74
|
data = dataclasses.replace(
|
69
75
|
data,
|
@@ -73,6 +79,7 @@ def semi_implicit_euler_integration(
|
|
73
79
|
_joint_velocities=ṡ,
|
74
80
|
_base_linear_velocity=W_v_B[0:3],
|
75
81
|
_base_angular_velocity=W_ω_WB,
|
82
|
+
contact_state=integrated_contact_state,
|
76
83
|
)
|
77
84
|
|
78
85
|
# Update the cached computations.
|
@@ -104,7 +111,7 @@ def rk4_integration(
|
|
104
111
|
joint_torques=joint_torques,
|
105
112
|
)
|
106
113
|
|
107
|
-
base_quaternion_norm = jaxsim.math.safe_norm(data._base_quaternion)
|
114
|
+
base_quaternion_norm = jaxsim.math.safe_norm(data._base_quaternion, axis=-1)
|
108
115
|
base_quaternion = data._base_quaternion / jnp.where(
|
109
116
|
base_quaternion_norm == 0, 1.0, base_quaternion_norm
|
110
117
|
)
|
@@ -116,6 +123,7 @@ def rk4_integration(
|
|
116
123
|
base_linear_velocity=data._base_linear_velocity,
|
117
124
|
base_angular_velocity=data._base_angular_velocity,
|
118
125
|
joint_velocities=data._joint_velocities,
|
126
|
+
contact_state=data.contact_state,
|
119
127
|
)
|
120
128
|
|
121
129
|
euler_mid = lambda x, dxdt: x + (0.5 * dt) * dxdt
|
@@ -136,14 +144,13 @@ def rk4_integration(
|
|
136
144
|
|
137
145
|
data_tf = dataclasses.replace(
|
138
146
|
data,
|
139
|
-
|
140
|
-
|
141
|
-
|
142
|
-
|
143
|
-
|
144
|
-
|
145
|
-
|
146
|
-
},
|
147
|
+
_base_position=x_tf["base_position"],
|
148
|
+
_base_quaternion=x_tf["base_quaternion"],
|
149
|
+
_joint_positions=x_tf["joint_positions"],
|
150
|
+
_base_linear_velocity=x_tf["base_linear_velocity"],
|
151
|
+
_base_angular_velocity=x_tf["base_angular_velocity"],
|
152
|
+
_joint_velocities=x_tf["joint_velocities"],
|
153
|
+
contact_state=x_tf["contact_state"],
|
147
154
|
)
|
148
155
|
|
149
156
|
return data_tf.replace(model=model)
|
jaxsim/api/model.py
CHANGED
@@ -47,13 +47,13 @@ class JaxSimModel(JaxsimDataclass):
|
|
47
47
|
default_factory=jaxsim.terrain.FlatTerrain.build, repr=False
|
48
48
|
)
|
49
49
|
|
50
|
-
gravity: Static[float] = jaxsim.math.STANDARD_GRAVITY
|
50
|
+
gravity: Static[float] = -jaxsim.math.STANDARD_GRAVITY
|
51
51
|
|
52
52
|
contact_model: Static[jaxsim.rbda.contacts.ContactModel | None] = dataclasses.field(
|
53
53
|
default=None, repr=False
|
54
54
|
)
|
55
55
|
|
56
|
-
|
56
|
+
contact_params: Static[jaxsim.rbda.contacts.ContactsParams] = dataclasses.field(
|
57
57
|
default=None, repr=False
|
58
58
|
)
|
59
59
|
|
@@ -177,9 +177,9 @@ class JaxSimModel(JaxsimDataclass):
|
|
177
177
|
time_step=time_step,
|
178
178
|
terrain=terrain,
|
179
179
|
contact_model=contact_model,
|
180
|
-
|
180
|
+
contact_params=contact_params,
|
181
181
|
integrator=integrator,
|
182
|
-
gravity
|
182
|
+
gravity=-gravity,
|
183
183
|
)
|
184
184
|
|
185
185
|
# Store the origin of the model, in case downstream logic needs it.
|
@@ -197,7 +197,7 @@ class JaxSimModel(JaxsimDataclass):
|
|
197
197
|
time_step: jtp.FloatLike | None = None,
|
198
198
|
terrain: jaxsim.terrain.Terrain | None = None,
|
199
199
|
contact_model: jaxsim.rbda.contacts.ContactModel | None = None,
|
200
|
-
|
200
|
+
contact_params: jaxsim.rbda.contacts.ContactsParams | None = None,
|
201
201
|
integrator: IntegratorType | None = None,
|
202
202
|
gravity: jtp.FloatLike = jaxsim.math.STANDARD_GRAVITY,
|
203
203
|
) -> JaxSimModel:
|
@@ -217,8 +217,8 @@ class JaxSimModel(JaxsimDataclass):
|
|
217
217
|
The optional name of the model overriding the physics model name.
|
218
218
|
contact_model:
|
219
219
|
The contact model to consider.
|
220
|
-
If not specified, a
|
221
|
-
|
220
|
+
If not specified, a relaxed-constraints rigid contacts model is used.
|
221
|
+
contact_params: The parameters of the contact model.
|
222
222
|
integrator: The integrator to use for the simulation.
|
223
223
|
gravity: The gravity constant.
|
224
224
|
|
@@ -252,8 +252,8 @@ class JaxSimModel(JaxsimDataclass):
|
|
252
252
|
else jaxsim.rbda.contacts.RelaxedRigidContacts.build()
|
253
253
|
)
|
254
254
|
|
255
|
-
if
|
256
|
-
|
255
|
+
if contact_params is None:
|
256
|
+
contact_params = contact_model._parameters_class()
|
257
257
|
|
258
258
|
# Consider the default integrator if not specified.
|
259
259
|
integrator = (
|
@@ -271,7 +271,7 @@ class JaxSimModel(JaxsimDataclass):
|
|
271
271
|
time_step=time_step,
|
272
272
|
terrain=terrain,
|
273
273
|
contact_model=contact_model,
|
274
|
-
|
274
|
+
contact_params=contact_params,
|
275
275
|
integrator=integrator,
|
276
276
|
gravity=gravity,
|
277
277
|
# The following is wrapped as hashless since it's a static argument, and we
|
@@ -473,7 +473,7 @@ def reduce(
|
|
473
473
|
time_step=model.time_step,
|
474
474
|
terrain=model.terrain,
|
475
475
|
contact_model=model.contact_model,
|
476
|
-
|
476
|
+
contact_params=model.contact_params,
|
477
477
|
gravity=model.gravity,
|
478
478
|
integrator=model.integrator,
|
479
479
|
)
|
@@ -2069,14 +2069,12 @@ def step(
|
|
2069
2069
|
)
|
2070
2070
|
|
2071
2071
|
# Get the external forces in inertial-fixed representation.
|
2072
|
-
W_f_L_external =
|
2073
|
-
|
2074
|
-
|
2075
|
-
|
2076
|
-
|
2077
|
-
|
2078
|
-
)
|
2079
|
-
)(O_f_L_external, data._link_transforms)
|
2072
|
+
W_f_L_external = js.data.JaxSimModelData.other_representation_to_inertial(
|
2073
|
+
O_f_L_external,
|
2074
|
+
other_representation=data.velocity_representation,
|
2075
|
+
transform=data._link_transforms,
|
2076
|
+
is_force=True,
|
2077
|
+
)
|
2080
2078
|
|
2081
2079
|
τ_references = jnp.atleast_1d(
|
2082
2080
|
jnp.array(joint_force_references, dtype=float).squeeze()
|
@@ -2092,29 +2090,6 @@ def step(
|
|
2092
2090
|
model, data, joint_force_references=τ_references
|
2093
2091
|
)
|
2094
2092
|
|
2095
|
-
# ======================
|
2096
|
-
# Compute contact forces
|
2097
|
-
# ======================
|
2098
|
-
|
2099
|
-
W_f_L_terrain = jnp.zeros_like(W_f_L_external)
|
2100
|
-
|
2101
|
-
if len(model.kin_dyn_parameters.contact_parameters.body) > 0:
|
2102
|
-
|
2103
|
-
# Compute the 6D forces W_f ∈ ℝ^{n_L × 6} applied to links due to contact
|
2104
|
-
# with the terrain.
|
2105
|
-
W_f_L_terrain = js.contact_model.link_contact_forces(
|
2106
|
-
model=model,
|
2107
|
-
data=data,
|
2108
|
-
link_forces=W_f_L_external,
|
2109
|
-
joint_torques=τ_total,
|
2110
|
-
)
|
2111
|
-
|
2112
|
-
# ==============================
|
2113
|
-
# Compute the total link forces
|
2114
|
-
# ==============================
|
2115
|
-
|
2116
|
-
W_f_L_total = W_f_L_external + W_f_L_terrain
|
2117
|
-
|
2118
2093
|
# =============================
|
2119
2094
|
# Advance the simulation state
|
2120
2095
|
# =============================
|
@@ -2124,7 +2099,14 @@ def step(
|
|
2124
2099
|
integrator_fn = _INTEGRATORS_MAP[model.integrator]
|
2125
2100
|
|
2126
2101
|
data_tf = integrator_fn(
|
2127
|
-
model=model,
|
2102
|
+
model=model,
|
2103
|
+
data=data,
|
2104
|
+
link_forces=W_f_L_external,
|
2105
|
+
joint_torques=τ_total,
|
2106
|
+
)
|
2107
|
+
|
2108
|
+
data_tf = model.contact_model.update_velocity_after_impact(
|
2109
|
+
model=model, data=data_tf
|
2128
2110
|
)
|
2129
2111
|
|
2130
2112
|
return data_tf
|
jaxsim/api/ode.py
CHANGED
@@ -46,12 +46,36 @@ def system_acceleration(
|
|
46
46
|
else jnp.zeros((model.number_of_links(), 6))
|
47
47
|
).astype(float)
|
48
48
|
|
49
|
+
# ======================
|
50
|
+
# Compute contact forces
|
51
|
+
# ======================
|
52
|
+
|
53
|
+
W_f_L_terrain = jnp.zeros_like(f_L)
|
54
|
+
contact_state_derivative = {}
|
55
|
+
|
56
|
+
if len(model.kin_dyn_parameters.contact_parameters.body) > 0:
|
57
|
+
|
58
|
+
# Compute the 6D forces W_f ∈ ℝ^{n_L × 6} applied to links due to contact
|
59
|
+
# with the terrain.
|
60
|
+
W_f_L_terrain, contact_state_derivative = js.contact.link_contact_forces(
|
61
|
+
model=model,
|
62
|
+
data=data,
|
63
|
+
link_forces=f_L,
|
64
|
+
joint_torques=joint_torques,
|
65
|
+
)
|
66
|
+
|
67
|
+
W_f_L_total = f_L + W_f_L_terrain
|
68
|
+
|
69
|
+
# Update the contact state data. This is necessary only for the contact models
|
70
|
+
# that require propagation and integration of contact state.
|
71
|
+
contact_state = model.contact_model.update_contact_state(contact_state_derivative)
|
72
|
+
|
49
73
|
# Store the link forces in a references object.
|
50
74
|
references = js.references.JaxSimModelReferences.build(
|
51
75
|
model=model,
|
52
76
|
data=data,
|
53
77
|
velocity_representation=data.velocity_representation,
|
54
|
-
link_forces=
|
78
|
+
link_forces=W_f_L_total,
|
55
79
|
)
|
56
80
|
|
57
81
|
# Compute forward dynamics.
|
@@ -68,13 +92,12 @@ def system_acceleration(
|
|
68
92
|
link_forces=references.link_forces(model=model, data=data),
|
69
93
|
)
|
70
94
|
|
71
|
-
return v̇_WB, s
|
95
|
+
return v̇_WB, s̈, contact_state
|
72
96
|
|
73
97
|
|
74
98
|
@jax.jit
|
75
99
|
@js.common.named_scope
|
76
100
|
def system_position_dynamics(
|
77
|
-
model: js.model.JaxSimModel,
|
78
101
|
data: js.data.JaxSimModelData,
|
79
102
|
baumgarte_quaternion_regularization: jtp.FloatLike = 1.0,
|
80
103
|
) -> tuple[jtp.Vector, jtp.Vector, jtp.Vector]:
|
@@ -82,7 +105,6 @@ def system_position_dynamics(
|
|
82
105
|
Compute the dynamics of the system position.
|
83
106
|
|
84
107
|
Args:
|
85
|
-
model: The model to consider.
|
86
108
|
data: The data of the considered model.
|
87
109
|
baumgarte_quaternion_regularization:
|
88
110
|
The Baumgarte regularization coefficient for adjusting the quaternion norm.
|
@@ -144,7 +166,7 @@ def system_dynamics(
|
|
144
166
|
"""
|
145
167
|
|
146
168
|
with data.switch_velocity_representation(velocity_representation=VelRepr.Inertial):
|
147
|
-
W_v̇_WB, s
|
169
|
+
W_v̇_WB, s̈, contact_state_derivative = system_acceleration(
|
148
170
|
model=model,
|
149
171
|
data=data,
|
150
172
|
joint_torques=joint_torques,
|
@@ -152,7 +174,6 @@ def system_dynamics(
|
|
152
174
|
)
|
153
175
|
|
154
176
|
W_ṗ_B, W_Q̇_B, ṡ = system_position_dynamics(
|
155
|
-
model=model,
|
156
177
|
data=data,
|
157
178
|
baumgarte_quaternion_regularization=baumgarte_quaternion_regularization,
|
158
179
|
)
|
@@ -164,4 +185,5 @@ def system_dynamics(
|
|
164
185
|
base_linear_velocity=W_v̇_WB[0:3],
|
165
186
|
base_angular_velocity=W_v̇_WB[3:6],
|
166
187
|
joint_velocities=s̈,
|
188
|
+
contact_state=contact_state_derivative,
|
167
189
|
)
|
jaxsim/api/references.py
CHANGED
@@ -434,24 +434,17 @@ class JaxSimModelReferences(js.common.ModelDataWithVelocityRepresentation):
|
|
434
434
|
if not_tracing(forces) and not data.valid(model=model):
|
435
435
|
raise ValueError("The provided data is not valid for the model")
|
436
436
|
|
437
|
-
|
438
|
-
# considering as body the link (i.e. L_f_L and LW_f_L).
|
439
|
-
def convert_using_link_frame(
|
440
|
-
f_L: jtp.MatrixLike, W_H_L: jtp.ArrayLike
|
441
|
-
) -> jtp.Matrix:
|
442
|
-
|
443
|
-
return jax.vmap(
|
444
|
-
lambda f_L, W_H_L: JaxSimModelReferences.other_representation_to_inertial(
|
445
|
-
array=f_L,
|
446
|
-
other_representation=self.velocity_representation,
|
447
|
-
transform=W_H_L,
|
448
|
-
is_force=True,
|
449
|
-
)
|
450
|
-
)(f_L, W_H_L)
|
437
|
+
W_H_L = data._link_transforms
|
451
438
|
|
439
|
+
# Convert a single 6D force to the inertial representation
|
440
|
+
# considering as body the link (i.e. L_f_L and LW_f_L).
|
452
441
|
# The f_L input is either L_f_L or LW_f_L, depending on the representation.
|
453
|
-
|
454
|
-
|
442
|
+
W_f_L = JaxSimModelReferences.other_representation_to_inertial(
|
443
|
+
array=f_L,
|
444
|
+
other_representation=self.velocity_representation,
|
445
|
+
transform=W_H_L[link_idxs] if model.number_of_links() > 1 else W_H_L,
|
446
|
+
is_force=True,
|
447
|
+
)
|
455
448
|
|
456
449
|
return replace(forces=self._link_forces.at[link_idxs, :].set(W_f0_L + W_f_L))
|
457
450
|
|
jaxsim/math/__init__.py
CHANGED
jaxsim/math/adjoint.py
CHANGED
@@ -55,13 +55,13 @@ class Adjoint:
|
|
55
55
|
The 6x6 adjoint matrix.
|
56
56
|
"""
|
57
57
|
|
58
|
-
A_H_B =
|
58
|
+
A_H_B = transform
|
59
59
|
|
60
60
|
return (
|
61
61
|
jaxlie.SE3.from_matrix(matrix=A_H_B).adjoint()
|
62
62
|
if not inverse
|
63
63
|
else jaxlie.SE3.from_matrix(matrix=A_H_B).inverse().adjoint()
|
64
|
-
)
|
64
|
+
)
|
65
65
|
|
66
66
|
@staticmethod
|
67
67
|
def from_rotation_and_translation(
|
jaxsim/math/transform.py
CHANGED
@@ -36,8 +36,8 @@ class Transform:
|
|
36
36
|
W_Q_B = jnp.array(quaternion).astype(float)
|
37
37
|
W_p_B = jnp.array(translation).astype(float)
|
38
38
|
|
39
|
-
assert W_p_B.
|
40
|
-
assert W_Q_B.
|
39
|
+
assert W_p_B.shape[-1] == 3
|
40
|
+
assert W_Q_B.shape[-1] == 4
|
41
41
|
|
42
42
|
A_R_B = jaxlie.SO3(wxyz=W_Q_B)
|
43
43
|
A_R_B = A_R_B if not normalize_quaternion else A_R_B.normalize()
|
jaxsim/math/utils.py
CHANGED
@@ -3,7 +3,7 @@ import jax.numpy as jnp
|
|
3
3
|
import jaxsim.typing as jtp
|
4
4
|
|
5
5
|
|
6
|
-
def safe_norm(array: jtp.ArrayLike, axis=None) -> jtp.Array:
|
6
|
+
def safe_norm(array: jtp.ArrayLike, *, axis=None, keepdims: bool = False) -> jtp.Array:
|
7
7
|
"""
|
8
8
|
Compute an array norm handling NaNs and making sure that
|
9
9
|
it is safe to get the gradient.
|
@@ -11,6 +11,7 @@ def safe_norm(array: jtp.ArrayLike, axis=None) -> jtp.Array:
|
|
11
11
|
Args:
|
12
12
|
array: The array for which to compute the norm.
|
13
13
|
axis: The axis for which to compute the norm.
|
14
|
+
keepdims: Whether to keep the dimensions of the input
|
14
15
|
|
15
16
|
Returns:
|
16
17
|
The norm of the array with handling for zero arrays to avoid NaNs.
|
@@ -24,7 +25,7 @@ def safe_norm(array: jtp.ArrayLike, axis=None) -> jtp.Array:
|
|
24
25
|
array = jnp.where(is_zero, jnp.ones_like(array), array)
|
25
26
|
|
26
27
|
# Compute the norm of the array along the specified axis.
|
27
|
-
norm = jnp.linalg.norm(array, axis=axis)
|
28
|
+
norm = jnp.linalg.norm(array, axis=axis, keepdims=keepdims)
|
28
29
|
|
29
30
|
# Use `jnp.where` to set the norm to 0.0 where the input array was all zeros.
|
30
31
|
# This usage supports potential batch processing for future scalability.
|
jaxsim/mujoco/visualizer.py
CHANGED
@@ -64,7 +64,7 @@ class MujocoVideoRecorder:
|
|
64
64
|
|
65
65
|
self.frames = []
|
66
66
|
|
67
|
-
self.data =
|
67
|
+
self.data = data if data is not None else self.data
|
68
68
|
self.data = self.data if isinstance(self.data, list) else [self.data]
|
69
69
|
|
70
70
|
self.model = model if model is not None else self.model
|
@@ -973,7 +973,7 @@ class KinematicGraphTransforms:
|
|
973
973
|
|
974
974
|
if frame.parent_name in self.graph.links_dict:
|
975
975
|
return frame.parent_name
|
976
|
-
|
976
|
+
if frame.parent_name in self.graph.frames_dict:
|
977
977
|
return self.find_parent_link_of_frame(name=frame.parent_name)
|
978
978
|
|
979
979
|
msg = f"Failed to find parent element of frame '{name}' with name '{frame.parent_name}'"
|
jaxsim/rbda/__init__.py
CHANGED
@@ -2,7 +2,7 @@ from . import contacts
|
|
2
2
|
from .aba import aba
|
3
3
|
from .collidable_points import collidable_points_pos_vel
|
4
4
|
from .crba import crba
|
5
|
-
from .forward_kinematics import
|
5
|
+
from .forward_kinematics import forward_kinematics_model
|
6
6
|
from .jacobian import (
|
7
7
|
jacobian,
|
8
8
|
jacobian_derivative_full_doubly_left,
|
jaxsim/rbda/contacts/__init__.py
CHANGED
@@ -1,5 +1,9 @@
|
|
1
|
-
from . import relaxed_rigid
|
1
|
+
from . import relaxed_rigid, rigid, soft
|
2
2
|
from .common import ContactModel, ContactsParams
|
3
3
|
from .relaxed_rigid import RelaxedRigidContacts, RelaxedRigidContactsParams
|
4
|
+
from .rigid import RigidContacts, RigidContactsParams
|
5
|
+
from .soft import SoftContacts, SoftContactsParams
|
4
6
|
|
5
|
-
ContactParamsTypes =
|
7
|
+
ContactParamsTypes = (
|
8
|
+
SoftContactsParams | RigidContactsParams | RelaxedRigidContactsParams
|
9
|
+
)
|
jaxsim/rbda/contacts/common.py
CHANGED
@@ -9,6 +9,7 @@ import jax.numpy as jnp
|
|
9
9
|
import jaxsim.api as js
|
10
10
|
import jaxsim.terrain
|
11
11
|
import jaxsim.typing as jtp
|
12
|
+
from jaxsim.math import STANDARD_GRAVITY
|
12
13
|
from jaxsim.utils import JaxsimDataclass
|
13
14
|
|
14
15
|
try:
|
@@ -80,6 +81,86 @@ class ContactsParams(JaxsimDataclass):
|
|
80
81
|
"""
|
81
82
|
pass
|
82
83
|
|
84
|
+
def build_default_from_jaxsim_model(
|
85
|
+
self: type[Self],
|
86
|
+
model: js.model.JaxSimModel,
|
87
|
+
*,
|
88
|
+
stiffness: jtp.FloatLike | None = None,
|
89
|
+
damping: jtp.FloatLike | None = None,
|
90
|
+
standard_gravity: jtp.FloatLike = STANDARD_GRAVITY,
|
91
|
+
static_friction_coefficient: jtp.FloatLike = 0.5,
|
92
|
+
max_penetration: jtp.FloatLike = 0.001,
|
93
|
+
number_of_active_collidable_points_steady_state: jtp.IntLike = 1,
|
94
|
+
damping_ratio: jtp.FloatLike = 1.0,
|
95
|
+
p: jtp.FloatLike = 0.5,
|
96
|
+
q: jtp.FloatLike = 0.5,
|
97
|
+
**kwargs,
|
98
|
+
) -> Self:
|
99
|
+
"""
|
100
|
+
Create a `ContactsParams` instance with default parameters.
|
101
|
+
|
102
|
+
Args:
|
103
|
+
model: The robot model considered by the contact model.
|
104
|
+
stiffness: The stiffness of the contact model.
|
105
|
+
damping: The damping of the contact model.
|
106
|
+
standard_gravity: The standard gravity acceleration.
|
107
|
+
static_friction_coefficient: The static friction coefficient.
|
108
|
+
max_penetration: The maximum penetration depth.
|
109
|
+
number_of_active_collidable_points_steady_state:
|
110
|
+
The number of active collidable points in steady state.
|
111
|
+
damping_ratio: The damping ratio.
|
112
|
+
p: The first parameter of the contact model.
|
113
|
+
q: The second parameter of the contact model.
|
114
|
+
**kwargs: Optional additional arguments.
|
115
|
+
|
116
|
+
Returns:
|
117
|
+
The `ContactsParams` instance.
|
118
|
+
|
119
|
+
Note:
|
120
|
+
The `stiffness` is intended as the terrain stiffness in the Soft Contacts model,
|
121
|
+
while it is the Baumgarte stabilization stiffness in the Rigid Contacts model.
|
122
|
+
|
123
|
+
The `damping` is intended as the terrain damping in the Soft Contacts model,
|
124
|
+
while it is the Baumgarte stabilization damping in the Rigid Contacts model.
|
125
|
+
|
126
|
+
The `damping_ratio` parameter allows to operate on the following conditions:
|
127
|
+
- ξ > 1.0: over-damped
|
128
|
+
- ξ = 1.0: critically damped
|
129
|
+
- ξ < 1.0: under-damped
|
130
|
+
"""
|
131
|
+
|
132
|
+
# Use symbols for input parameters.
|
133
|
+
ξ = damping_ratio
|
134
|
+
δ_max = max_penetration
|
135
|
+
μc = static_friction_coefficient
|
136
|
+
|
137
|
+
# Compute the total mass of the model.
|
138
|
+
m = jnp.array(model.kin_dyn_parameters.link_parameters.mass).sum()
|
139
|
+
|
140
|
+
# Rename the standard gravity.
|
141
|
+
g = standard_gravity
|
142
|
+
|
143
|
+
# Compute the average support force on each collidable point.
|
144
|
+
f_average = m * g / number_of_active_collidable_points_steady_state
|
145
|
+
|
146
|
+
# Compute the stiffness to get the desired steady-state penetration.
|
147
|
+
# Note that this is dependent on the non-linear exponent used in
|
148
|
+
# the damping term of the Hunt/Crossley model.
|
149
|
+
K = f_average / jnp.power(δ_max, 1 + p) if stiffness is None else stiffness
|
150
|
+
|
151
|
+
# Compute the damping using the damping ratio.
|
152
|
+
critical_damping = 2 * jnp.sqrt(K * m)
|
153
|
+
D = ξ * critical_damping if damping is None else damping
|
154
|
+
|
155
|
+
return self.build(
|
156
|
+
K=K,
|
157
|
+
D=D,
|
158
|
+
mu=μc,
|
159
|
+
p=p,
|
160
|
+
q=q,
|
161
|
+
**kwargs,
|
162
|
+
)
|
163
|
+
|
83
164
|
@abc.abstractmethod
|
84
165
|
def valid(self, **kwargs) -> jtp.BoolLike:
|
85
166
|
"""
|
@@ -156,7 +237,7 @@ class ContactModel(JaxsimDataclass):
|
|
156
237
|
return {}
|
157
238
|
|
158
239
|
@property
|
159
|
-
def _parameters_class(
|
240
|
+
def _parameters_class(self) -> type[ContactsParams]:
|
160
241
|
"""
|
161
242
|
Return the class of the contact parameters.
|
162
243
|
|
@@ -168,8 +249,37 @@ class ContactModel(JaxsimDataclass):
|
|
168
249
|
return getattr(
|
169
250
|
importlib.import_module("jaxsim.rbda.contacts"),
|
170
251
|
(
|
171
|
-
|
172
|
-
if isinstance(
|
173
|
-
else
|
252
|
+
self.__name__ + "Params"
|
253
|
+
if isinstance(self, type)
|
254
|
+
else self.__class__.__name__ + "Params"
|
174
255
|
),
|
175
256
|
)
|
257
|
+
|
258
|
+
@abc.abstractmethod
|
259
|
+
def update_contact_state(
|
260
|
+
self: type[Self], old_contact_state: dict[str, jtp.Array]
|
261
|
+
) -> dict[str, jtp.Array]:
|
262
|
+
"""
|
263
|
+
Update the contact state.
|
264
|
+
|
265
|
+
Args:
|
266
|
+
old_contact_state: The old contact state.
|
267
|
+
|
268
|
+
Returns:
|
269
|
+
The updated contact state.
|
270
|
+
"""
|
271
|
+
|
272
|
+
@abc.abstractmethod
|
273
|
+
def update_velocity_after_impact(
|
274
|
+
self: type[Self], model: js.model.JaxSimModel, data: js.data.JaxSimModelData
|
275
|
+
) -> js.data.JaxSimModelData:
|
276
|
+
"""
|
277
|
+
Update the velocity after an impact.
|
278
|
+
|
279
|
+
Args:
|
280
|
+
model: The robot model considered by the contact model.
|
281
|
+
data: The data of the considered model.
|
282
|
+
|
283
|
+
Returns:
|
284
|
+
The updated data of the considered model.
|
285
|
+
"""
|