jaxsim 0.4.3.dev17__py3-none-any.whl → 0.4.3.dev18__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jaxsim/_version.py +2 -2
- jaxsim/api/contact.py +36 -8
- jaxsim/api/model.py +60 -7
- jaxsim/api/ode.py +114 -36
- jaxsim/api/ode_data.py +11 -7
- jaxsim/integrators/common.py +27 -18
- jaxsim/rbda/contacts/common.py +3 -2
- jaxsim/rbda/contacts/rigid.py +478 -0
- {jaxsim-0.4.3.dev17.dist-info → jaxsim-0.4.3.dev18.dist-info}/METADATA +2 -1
- {jaxsim-0.4.3.dev17.dist-info → jaxsim-0.4.3.dev18.dist-info}/RECORD +13 -12
- {jaxsim-0.4.3.dev17.dist-info → jaxsim-0.4.3.dev18.dist-info}/LICENSE +0 -0
- {jaxsim-0.4.3.dev17.dist-info → jaxsim-0.4.3.dev18.dist-info}/WHEEL +0 -0
- {jaxsim-0.4.3.dev17.dist-info → jaxsim-0.4.3.dev18.dist-info}/top_level.txt +0 -0
jaxsim/_version.py
CHANGED
@@ -12,5 +12,5 @@ __version__: str
|
|
12
12
|
__version_tuple__: VERSION_TUPLE
|
13
13
|
version_tuple: VERSION_TUPLE
|
14
14
|
|
15
|
-
__version__ = version = '0.4.3.
|
16
|
-
__version_tuple__ = version_tuple = (0, 4, 3, '
|
15
|
+
__version__ = version = '0.4.3.dev18'
|
16
|
+
__version_tuple__ = version_tuple = (0, 4, 3, 'dev18')
|
jaxsim/api/contact.py
CHANGED
@@ -114,19 +114,24 @@ def collidable_point_forces(
|
|
114
114
|
|
115
115
|
@jax.jit
|
116
116
|
def collidable_point_dynamics(
|
117
|
-
model: js.model.JaxSimModel,
|
118
|
-
|
117
|
+
model: js.model.JaxSimModel,
|
118
|
+
data: js.data.JaxSimModelData,
|
119
|
+
link_forces: jtp.MatrixLike | None = None,
|
120
|
+
) -> tuple[jtp.Matrix, dict[str, jtp.Array]]:
|
119
121
|
r"""
|
120
|
-
Compute the 6D force applied to each collidable point
|
121
|
-
material deformation rate.
|
122
|
+
Compute the 6D force applied to each collidable point.
|
122
123
|
|
123
124
|
Args:
|
124
125
|
model: The model to consider.
|
125
126
|
data: The data of the considered model.
|
127
|
+
link_forces:
|
128
|
+
The 6D external forces to apply to the links expressed in the same
|
129
|
+
representation of data.
|
126
130
|
|
127
131
|
Returns:
|
128
|
-
The 6D force applied to each collidable point and the
|
129
|
-
material deformation rate.
|
132
|
+
The 6D force applied to each collidable point and additional data based on the contact model configured:
|
133
|
+
- Soft: the material deformation rate.
|
134
|
+
- Rigid: nothing.
|
130
135
|
|
131
136
|
Note:
|
132
137
|
The material deformation rate is always returned in the mixed frame
|
@@ -138,7 +143,8 @@ def collidable_point_dynamics(
|
|
138
143
|
# all collidable points belonging to the robot.
|
139
144
|
W_p_Ci, W_ṗ_Ci = js.contact.collidable_point_kinematics(model=model, data=data)
|
140
145
|
|
141
|
-
# Import privately the
|
146
|
+
# Import privately the contacts classes.
|
147
|
+
from jaxsim.rbda.contacts.rigid import RigidContacts, RigidContactsState
|
142
148
|
from jaxsim.rbda.contacts.soft import SoftContacts, SoftContactsState
|
143
149
|
|
144
150
|
# Build the soft contact model.
|
@@ -161,6 +167,28 @@ def collidable_point_dynamics(
|
|
161
167
|
W_f_Ci, (CW_ṁ,) = jax.vmap(soft_contacts.compute_contact_forces)(
|
162
168
|
W_p_Ci, W_ṗ_Ci, data.state.contact.tangential_deformation
|
163
169
|
)
|
170
|
+
aux_data = dict(m_dot=CW_ṁ)
|
171
|
+
|
172
|
+
case RigidContacts():
|
173
|
+
assert isinstance(model.contact_model, RigidContacts)
|
174
|
+
assert isinstance(data.state.contact, RigidContactsState)
|
175
|
+
|
176
|
+
# Build the contact model.
|
177
|
+
rigid_contacts = RigidContacts(
|
178
|
+
parameters=data.contacts_params, terrain=model.terrain
|
179
|
+
)
|
180
|
+
|
181
|
+
# Compute the 6D force expressed in the inertial frame and applied to each
|
182
|
+
# collidable point.
|
183
|
+
W_f_Ci, _ = rigid_contacts.compute_contact_forces(
|
184
|
+
position=W_p_Ci,
|
185
|
+
velocity=W_ṗ_Ci,
|
186
|
+
model=model,
|
187
|
+
data=data,
|
188
|
+
link_forces=link_forces,
|
189
|
+
)
|
190
|
+
|
191
|
+
aux_data = dict()
|
164
192
|
|
165
193
|
case _:
|
166
194
|
raise ValueError(f"Invalid contact model {model.contact_model}")
|
@@ -175,7 +203,7 @@ def collidable_point_dynamics(
|
|
175
203
|
)
|
176
204
|
)(W_f_Ci)
|
177
205
|
|
178
|
-
return f_Ci,
|
206
|
+
return f_Ci, aux_data
|
179
207
|
|
180
208
|
|
181
209
|
@functools.partial(jax.jit, static_argnames=["link_names"])
|
jaxsim/api/model.py
CHANGED
@@ -14,6 +14,7 @@ import rod
|
|
14
14
|
from jax_dataclasses import Static
|
15
15
|
|
16
16
|
import jaxsim.api as js
|
17
|
+
import jaxsim.exceptions
|
17
18
|
import jaxsim.terrain
|
18
19
|
import jaxsim.typing as jtp
|
19
20
|
from jaxsim.math import Adjoint, Cross
|
@@ -1890,6 +1891,8 @@ def step(
|
|
1890
1891
|
and the new state of the integrator.
|
1891
1892
|
"""
|
1892
1893
|
|
1894
|
+
from jaxsim.rbda.contacts.rigid import RigidContacts
|
1895
|
+
|
1893
1896
|
# Extract the integrator kwargs.
|
1894
1897
|
# The following logic allows using integrators having kwargs colliding with the
|
1895
1898
|
# kwargs of this step function.
|
@@ -1901,12 +1904,12 @@ def step(
|
|
1901
1904
|
|
1902
1905
|
# Extract the initial resources.
|
1903
1906
|
t0_ns = data.time_ns
|
1904
|
-
|
1907
|
+
state_t0 = data.state
|
1905
1908
|
integrator_state_x0 = integrator_state
|
1906
1909
|
|
1907
1910
|
# Step the dynamics forward.
|
1908
|
-
|
1909
|
-
x0=
|
1911
|
+
state_tf, integrator_state_tf = integrator.step(
|
1912
|
+
x0=state_t0,
|
1910
1913
|
t0=jnp.array(t0_ns / 1e9).astype(float),
|
1911
1914
|
dt=dt,
|
1912
1915
|
params=integrator_state_x0,
|
@@ -1928,11 +1931,61 @@ def step(
|
|
1928
1931
|
),
|
1929
1932
|
)
|
1930
1933
|
|
1931
|
-
|
1934
|
+
data_tf = (
|
1932
1935
|
# Store the new state of the model and the new time.
|
1933
1936
|
data.replace(
|
1934
|
-
state=
|
1937
|
+
state=state_tf,
|
1935
1938
|
time_ns=t0_ns + jnp.array(dt * 1e9).astype(jnp.uint64),
|
1936
|
-
)
|
1937
|
-
|
1939
|
+
)
|
1940
|
+
)
|
1941
|
+
|
1942
|
+
# Post process the simulation state, if needed.
|
1943
|
+
match model.contact_model:
|
1944
|
+
|
1945
|
+
# Rigid contact models use an impact model that produces a discontinuous model velocity.
|
1946
|
+
# Hence here we need to reset the velocity after each impact to guarantee that
|
1947
|
+
# the linear velocity of the active collidable points is zero.
|
1948
|
+
case RigidContacts():
|
1949
|
+
# Raise runtime error for not supported case in which Rigid contacts and Baumgarte stabilization
|
1950
|
+
# enabled are used with ForwardEuler integrator.
|
1951
|
+
jaxsim.exceptions.raise_runtime_error_if(
|
1952
|
+
condition=jnp.logical_and(
|
1953
|
+
isinstance(
|
1954
|
+
integrator,
|
1955
|
+
jaxsim.integrators.fixed_step.ForwardEuler
|
1956
|
+
| jaxsim.integrators.fixed_step.ForwardEulerSO3,
|
1957
|
+
),
|
1958
|
+
jnp.array(
|
1959
|
+
[data_tf.contacts_params.K, data_tf.contacts_params.D]
|
1960
|
+
).any(),
|
1961
|
+
),
|
1962
|
+
msg="Baumgarte stabilization is not supported with ForwardEuler integrators",
|
1963
|
+
)
|
1964
|
+
|
1965
|
+
with data_tf.switch_velocity_representation(VelRepr.Mixed):
|
1966
|
+
W_p_C = js.contact.collidable_point_positions(model, data_tf)
|
1967
|
+
M = js.model.free_floating_mass_matrix(model, data_tf)
|
1968
|
+
J_WC = js.contact.jacobian(model, data_tf)
|
1969
|
+
px, py, _ = W_p_C.T
|
1970
|
+
terrain_height = jax.vmap(model.terrain.height)(px, py)
|
1971
|
+
inactive_collidable_points, _ = RigidContacts.detect_contacts(
|
1972
|
+
W_p_C=W_p_C,
|
1973
|
+
terrain_height=terrain_height,
|
1974
|
+
)
|
1975
|
+
BW_nu_post_impact = RigidContacts.compute_impact_velocity(
|
1976
|
+
data=data_tf,
|
1977
|
+
inactive_collidable_points=inactive_collidable_points,
|
1978
|
+
M=M,
|
1979
|
+
J_WC=J_WC,
|
1980
|
+
)
|
1981
|
+
data_tf = data_tf.reset_base_velocity(BW_nu_post_impact[0:6])
|
1982
|
+
data_tf = data_tf.reset_joint_velocities(BW_nu_post_impact[6:])
|
1983
|
+
# Restore the input velocity representation.
|
1984
|
+
data_tf = data_tf.replace(
|
1985
|
+
velocity_representation=data.velocity_representation, validate=False
|
1986
|
+
)
|
1987
|
+
|
1988
|
+
return (
|
1989
|
+
data_tf,
|
1990
|
+
integrator_state_tf,
|
1938
1991
|
)
|
jaxsim/api/ode.py
CHANGED
@@ -50,7 +50,7 @@ def wrap_system_dynamics_for_integration(
|
|
50
50
|
# The wrapped dynamics will hold a reference of this object.
|
51
51
|
model_closed = model.copy()
|
52
52
|
data_closed = data.copy().replace(
|
53
|
-
state=js.ode_data.ODEState.zero(model=model_closed)
|
53
|
+
state=js.ode_data.ODEState.zero(model=model_closed, data=data)
|
54
54
|
)
|
55
55
|
|
56
56
|
def f(x: ODEState, t: Time, **kwargs_f) -> tuple[ODEState, dict[str, Any]]:
|
@@ -88,7 +88,7 @@ def system_velocity_dynamics(
|
|
88
88
|
*,
|
89
89
|
joint_forces: jtp.Vector | None = None,
|
90
90
|
link_forces: jtp.Vector | None = None,
|
91
|
-
) -> tuple[jtp.Vector, jtp.Vector,
|
91
|
+
) -> tuple[jtp.Vector, jtp.Vector, dict[str, Any]]:
|
92
92
|
"""
|
93
93
|
Compute the dynamics of the system velocity.
|
94
94
|
|
@@ -102,18 +102,10 @@ def system_velocity_dynamics(
|
|
102
102
|
|
103
103
|
Returns:
|
104
104
|
A tuple containing the derivative of the base 6D velocity in inertial-fixed
|
105
|
-
representation, the derivative of the joint velocities,
|
106
|
-
|
107
|
-
the system dynamics evaluation.
|
105
|
+
representation, the derivative of the joint velocities, and auxiliary data
|
106
|
+
returned by the system dynamics evaluation.
|
108
107
|
"""
|
109
108
|
|
110
|
-
# Build joint torques if not provided.
|
111
|
-
τ = (
|
112
|
-
jnp.atleast_1d(joint_forces.squeeze())
|
113
|
-
if joint_forces is not None
|
114
|
-
else jnp.zeros_like(data.joint_positions())
|
115
|
-
).astype(float)
|
116
|
-
|
117
109
|
# Build link forces if not provided.
|
118
110
|
# These forces are expressed in the frame corresponding to the velocity
|
119
111
|
# representation of data.
|
@@ -123,6 +115,15 @@ def system_velocity_dynamics(
|
|
123
115
|
else jnp.zeros((model.number_of_links(), 6))
|
124
116
|
).astype(float)
|
125
117
|
|
118
|
+
# We expect that the 6D forces included in the `link_forces` argument are expressed
|
119
|
+
# in the frame corresponding to the velocity representation of `data`.
|
120
|
+
references = js.references.JaxSimModelReferences.build(
|
121
|
+
model=model,
|
122
|
+
link_forces=O_f_L,
|
123
|
+
data=data,
|
124
|
+
velocity_representation=data.velocity_representation,
|
125
|
+
)
|
126
|
+
|
126
127
|
# ======================
|
127
128
|
# Compute contact forces
|
128
129
|
# ======================
|
@@ -131,19 +132,17 @@ def system_velocity_dynamics(
|
|
131
132
|
# with the terrain.
|
132
133
|
W_f_Li_terrain = jnp.zeros_like(O_f_L).astype(float)
|
133
134
|
|
134
|
-
|
135
|
-
from jaxsim.rbda.contacts.soft import SoftContactsState
|
136
|
-
|
137
|
-
# Initialize the derivative of the tangential deformation ṁ ∈ ℝ^{n_c × 3}.
|
138
|
-
assert isinstance(data.state.contact, SoftContactsState)
|
139
|
-
ṁ = jnp.zeros_like(data.state.contact.tangential_deformation).astype(float)
|
140
|
-
|
135
|
+
aux_data = {}
|
141
136
|
if len(model.kin_dyn_parameters.contact_parameters.body) > 0:
|
142
137
|
|
143
138
|
# Compute the 6D forces W_f ∈ ℝ^{n_c × 6} applied to each collidable point
|
144
|
-
#
|
139
|
+
# along with contact-specific auxiliary states.
|
145
140
|
with data.switch_velocity_representation(VelRepr.Inertial):
|
146
|
-
W_f_Ci,
|
141
|
+
W_f_Ci, aux_data = js.contact.collidable_point_dynamics(
|
142
|
+
model=model,
|
143
|
+
data=data,
|
144
|
+
link_forces=references.link_forces(model=model, data=data),
|
145
|
+
)
|
147
146
|
|
148
147
|
# Construct the vector defining the parent link index of each collidable point.
|
149
148
|
# We use this vector to sum the 6D forces of all collidable points rigidly
|
@@ -161,6 +160,74 @@ def system_velocity_dynamics(
|
|
161
160
|
|
162
161
|
W_f_Li_terrain = mask.T @ W_f_Ci
|
163
162
|
|
163
|
+
# ===========================
|
164
|
+
# Compute system acceleration
|
165
|
+
# ===========================
|
166
|
+
|
167
|
+
# Compute the total link forces
|
168
|
+
with (
|
169
|
+
data.switch_velocity_representation(VelRepr.Inertial),
|
170
|
+
references.switch_velocity_representation(VelRepr.Inertial),
|
171
|
+
):
|
172
|
+
references = references.apply_link_forces(
|
173
|
+
model=model,
|
174
|
+
data=data,
|
175
|
+
forces=W_f_Li_terrain,
|
176
|
+
additive=True,
|
177
|
+
)
|
178
|
+
# Get the link forces in the data representation
|
179
|
+
with references.switch_velocity_representation(data.velocity_representation):
|
180
|
+
f_L_total = references.link_forces(model=model, data=data)
|
181
|
+
|
182
|
+
# The following method always returns the inertial-fixed acceleration, and expects
|
183
|
+
# the link_forces expressed in the inertial frame.
|
184
|
+
W_v̇_WB, s̈ = system_acceleration(
|
185
|
+
model=model, data=data, joint_forces=joint_forces, link_forces=f_L_total
|
186
|
+
)
|
187
|
+
|
188
|
+
return W_v̇_WB, s̈, aux_data
|
189
|
+
|
190
|
+
|
191
|
+
def system_acceleration(
|
192
|
+
model: js.model.JaxSimModel,
|
193
|
+
data: js.data.JaxSimModelData,
|
194
|
+
*,
|
195
|
+
joint_forces: jtp.VectorLike | None = None,
|
196
|
+
link_forces: jtp.MatrixLike | None = None,
|
197
|
+
) -> tuple[jtp.Vector, jtp.Vector]:
|
198
|
+
"""
|
199
|
+
Compute the system acceleration in inertial-fixed representation.
|
200
|
+
|
201
|
+
Args:
|
202
|
+
model: The model to consider.
|
203
|
+
data: The data of the considered model.
|
204
|
+
joint_forces: The joint forces to apply.
|
205
|
+
link_forces:
|
206
|
+
The 6D forces to apply to the links expressed in the same representation of data.
|
207
|
+
|
208
|
+
Returns:
|
209
|
+
A tuple containing the base 6D acceleration in inertial-fixed representation
|
210
|
+
and the joint accelerations.
|
211
|
+
"""
|
212
|
+
|
213
|
+
# ====================
|
214
|
+
# Validate input data
|
215
|
+
# ====================
|
216
|
+
|
217
|
+
# Build link forces if not provided.
|
218
|
+
f_L = (
|
219
|
+
jnp.atleast_2d(link_forces.squeeze())
|
220
|
+
if link_forces is not None
|
221
|
+
else jnp.zeros((model.number_of_links(), 6))
|
222
|
+
).astype(float)
|
223
|
+
|
224
|
+
# Build joint torques if not provided.
|
225
|
+
τ = (
|
226
|
+
jnp.atleast_1d(joint_forces.squeeze())
|
227
|
+
if joint_forces is not None
|
228
|
+
else jnp.zeros_like(data.joint_positions())
|
229
|
+
).astype(float)
|
230
|
+
|
164
231
|
# ====================
|
165
232
|
# Enforce joint limits
|
166
233
|
# ====================
|
@@ -198,29 +265,25 @@ def system_velocity_dynamics(
|
|
198
265
|
|
199
266
|
references = js.references.JaxSimModelReferences.build(
|
200
267
|
model=model,
|
201
|
-
joint_force_references=τ_total,
|
202
|
-
link_forces=O_f_L,
|
203
268
|
data=data,
|
204
269
|
velocity_representation=data.velocity_representation,
|
270
|
+
joint_force_references=τ_total,
|
271
|
+
link_forces=f_L,
|
205
272
|
)
|
206
273
|
|
207
|
-
with references.switch_velocity_representation(VelRepr.Inertial):
|
208
|
-
W_f_L = references.link_forces(model=model, data=data)
|
209
|
-
|
210
|
-
# Compute the total external 6D forces applied to the links.
|
211
|
-
W_f_L_total = W_f_L + W_f_Li_terrain
|
212
|
-
|
213
274
|
# - Joint accelerations: s̈ ∈ ℝⁿ
|
214
275
|
# - Base inertial-fixed acceleration: W_v̇_WB = (W_p̈_B, W_ω̇_B) ∈ ℝ⁶
|
215
|
-
with
|
276
|
+
with (
|
277
|
+
data.switch_velocity_representation(velocity_representation=VelRepr.Inertial),
|
278
|
+
references.switch_velocity_representation(VelRepr.Inertial),
|
279
|
+
):
|
216
280
|
W_v̇_WB, s̈ = js.model.forward_dynamics_aba(
|
217
281
|
model=model,
|
218
282
|
data=data,
|
219
|
-
joint_forces
|
220
|
-
link_forces=
|
283
|
+
joint_forces=references.joint_force_references(),
|
284
|
+
link_forces=references.link_forces(),
|
221
285
|
)
|
222
|
-
|
223
|
-
return W_v̇_WB, s̈, ṁ, dict()
|
286
|
+
return W_v̇_WB, s̈
|
224
287
|
|
225
288
|
|
226
289
|
@jax.jit
|
@@ -291,14 +354,29 @@ def system_dynamics(
|
|
291
354
|
by the system dynamics evaluation.
|
292
355
|
"""
|
293
356
|
|
357
|
+
from jaxsim.rbda.contacts.rigid import RigidContacts
|
358
|
+
from jaxsim.rbda.contacts.soft import SoftContacts
|
359
|
+
|
294
360
|
# Compute the accelerations and the material deformation rate.
|
295
|
-
W_v̇_WB, s̈,
|
361
|
+
W_v̇_WB, s̈, aux_dict = system_velocity_dynamics(
|
296
362
|
model=model,
|
297
363
|
data=data,
|
298
364
|
joint_forces=joint_forces,
|
299
365
|
link_forces=link_forces,
|
300
366
|
)
|
301
367
|
|
368
|
+
ode_state_kwargs = {}
|
369
|
+
|
370
|
+
match model.contact_model:
|
371
|
+
case SoftContacts():
|
372
|
+
ode_state_kwargs["tangential_deformation"] = aux_dict["m_dot"]
|
373
|
+
|
374
|
+
case RigidContacts():
|
375
|
+
pass
|
376
|
+
|
377
|
+
case _:
|
378
|
+
raise ValueError("Unable to determine contact state class prefix.")
|
379
|
+
|
302
380
|
# Extract the velocities.
|
303
381
|
W_ṗ_B, W_Q̇_B, ṡ = system_position_dynamics(
|
304
382
|
model=model,
|
@@ -317,7 +395,7 @@ def system_dynamics(
|
|
317
395
|
base_linear_velocity=W_v̇_WB[0:3],
|
318
396
|
base_angular_velocity=W_v̇_WB[3:6],
|
319
397
|
joint_velocities=s̈,
|
320
|
-
|
398
|
+
**ode_state_kwargs,
|
321
399
|
)
|
322
400
|
|
323
401
|
return ode_state_derivative, aux_dict
|
jaxsim/api/ode_data.py
CHANGED
@@ -6,6 +6,7 @@ import jax_dataclasses
|
|
6
6
|
import jaxsim.api as js
|
7
7
|
import jaxsim.typing as jtp
|
8
8
|
from jaxsim.rbda import ContactsState
|
9
|
+
from jaxsim.rbda.contacts.rigid import RigidContacts, RigidContactsState
|
9
10
|
from jaxsim.rbda.contacts.soft import SoftContacts, SoftContactsState
|
10
11
|
from jaxsim.utils import JaxsimDataclass
|
11
12
|
|
@@ -133,7 +134,7 @@ class ODEState(JaxsimDataclass):
|
|
133
134
|
base_quaternion: jtp.Vector | None = None,
|
134
135
|
base_linear_velocity: jtp.Vector | None = None,
|
135
136
|
base_angular_velocity: jtp.Vector | None = None,
|
136
|
-
|
137
|
+
**kwargs,
|
137
138
|
) -> ODEState:
|
138
139
|
"""
|
139
140
|
Build an `ODEState` from a `JaxSimModel`.
|
@@ -148,9 +149,7 @@ class ODEState(JaxsimDataclass):
|
|
148
149
|
The linear velocity of the base link in inertial-fixed representation.
|
149
150
|
base_angular_velocity:
|
150
151
|
The angular velocity of the base link in inertial-fixed representation.
|
151
|
-
|
152
|
-
The matrix of 3D tangential material deformations corresponding to
|
153
|
-
each collidable point.
|
152
|
+
kwargs: Additional arguments needed to build the contact state.
|
154
153
|
|
155
154
|
Returns:
|
156
155
|
The `ODEState` built from the `JaxSimModel`.
|
@@ -163,6 +162,7 @@ class ODEState(JaxsimDataclass):
|
|
163
162
|
# Get the contact model from the `JaxSimModel`.
|
164
163
|
match model.contact_model:
|
165
164
|
case SoftContacts():
|
165
|
+
tangential_deformation = kwargs.get("tangential_deformation", None)
|
166
166
|
contact = SoftContactsState.build_from_jaxsim_model(
|
167
167
|
model=model,
|
168
168
|
**(
|
@@ -171,6 +171,8 @@ class ODEState(JaxsimDataclass):
|
|
171
171
|
else dict()
|
172
172
|
),
|
173
173
|
)
|
174
|
+
case RigidContacts():
|
175
|
+
contact = RigidContactsState.build()
|
174
176
|
case _:
|
175
177
|
raise ValueError("Unable to determine contact state class prefix.")
|
176
178
|
|
@@ -214,7 +216,7 @@ class ODEState(JaxsimDataclass):
|
|
214
216
|
|
215
217
|
# Get the contact model from the `JaxSimModel`.
|
216
218
|
match contact:
|
217
|
-
case SoftContactsState():
|
219
|
+
case SoftContactsState() | RigidContactsState():
|
218
220
|
pass
|
219
221
|
case None:
|
220
222
|
contact = SoftContactsState.zero(model=model)
|
@@ -224,7 +226,7 @@ class ODEState(JaxsimDataclass):
|
|
224
226
|
return ODEState(physics_model=physics_model_state, contact=contact)
|
225
227
|
|
226
228
|
@staticmethod
|
227
|
-
def zero(model: js.model.JaxSimModel) -> ODEState:
|
229
|
+
def zero(model: js.model.JaxSimModel, data: js.data.JaxSimModelData) -> ODEState:
|
228
230
|
"""
|
229
231
|
Build a zero `ODEState` from a `JaxSimModel`.
|
230
232
|
|
@@ -235,7 +237,9 @@ class ODEState(JaxsimDataclass):
|
|
235
237
|
A zero `ODEState` instance.
|
236
238
|
"""
|
237
239
|
|
238
|
-
model_state = ODEState.build(
|
240
|
+
model_state = ODEState.build(
|
241
|
+
model=model, contact=data.state.contact.zero(model=model)
|
242
|
+
)
|
239
243
|
|
240
244
|
return model_state
|
241
245
|
|
jaxsim/integrators/common.py
CHANGED
@@ -109,11 +109,14 @@ class Integrator(JaxsimDataclass, abc.ABC, Generic[State, StateDerivative]):
|
|
109
109
|
integrator.params = params
|
110
110
|
|
111
111
|
with integrator.mutable_context(mutability=Mutability.MUTABLE):
|
112
|
-
xf = integrator(x0, t0, dt, **kwargs)
|
112
|
+
xf, aux_dict = integrator(x0, t0, dt, **kwargs)
|
113
113
|
|
114
|
-
return
|
115
|
-
|
116
|
-
|
114
|
+
return (
|
115
|
+
xf,
|
116
|
+
integrator.params
|
117
|
+
| {Integrator.AfterInitKey: jnp.array(False).astype(bool)}
|
118
|
+
| aux_dict,
|
119
|
+
)
|
117
120
|
|
118
121
|
@abc.abstractmethod
|
119
122
|
def __call__(self, x0: State, t0: Time, dt: TimeStep, **kwargs) -> NextState:
|
@@ -277,15 +280,19 @@ class ExplicitRungeKutta(Integrator[PyTreeType, PyTreeType], Generic[PyTreeType]
|
|
277
280
|
|
278
281
|
return integrator
|
279
282
|
|
280
|
-
def __call__(
|
283
|
+
def __call__(
|
284
|
+
self, x0: State, t0: Time, dt: TimeStep, **kwargs
|
285
|
+
) -> tuple[NextState, dict[str, Any]]:
|
281
286
|
|
282
287
|
# Here z is a batched state with as many batch elements as b.T rows.
|
283
288
|
# Note that z has multiple batches only if b.T has more than one row,
|
284
289
|
# e.g. in Butcher tableau of embedded schemes.
|
285
|
-
z = self._compute_next_state(x0=x0, t0=t0, dt=dt, **kwargs)
|
290
|
+
z, aux_dict = self._compute_next_state(x0=x0, t0=t0, dt=dt, **kwargs)
|
286
291
|
|
287
292
|
# The next state is the batch element located at the configured index of solution.
|
288
|
-
|
293
|
+
next_state = jax.tree_util.tree_map(lambda l: l[self.row_index_of_solution], z)
|
294
|
+
|
295
|
+
return next_state, aux_dict
|
289
296
|
|
290
297
|
@classmethod
|
291
298
|
def integrate_rk_stage(
|
@@ -343,7 +350,7 @@ class ExplicitRungeKutta(Integrator[PyTreeType, PyTreeType], Generic[PyTreeType]
|
|
343
350
|
|
344
351
|
def _compute_next_state(
|
345
352
|
self, x0: State, t0: Time, dt: TimeStep, **kwargs
|
346
|
-
) -> NextState:
|
353
|
+
) -> tuple[NextState, dict[str, Any]]:
|
347
354
|
"""
|
348
355
|
Compute the next state of the system, returning all the output states.
|
349
356
|
|
@@ -373,19 +380,21 @@ class ExplicitRungeKutta(Integrator[PyTreeType, PyTreeType], Generic[PyTreeType]
|
|
373
380
|
)
|
374
381
|
|
375
382
|
# Apply FSAL property by passing ẋ0 = f(x0, t0) from the previous iteration.
|
376
|
-
get_ẋ
|
383
|
+
get_ẋ0_and_aux_dict = lambda: self.params.get("dxdt0", f(x0, t0))
|
377
384
|
|
378
385
|
# We use a `jax.lax.scan` to compile the `f` function only once.
|
379
386
|
# Otherwise, if we compute e.g. for RK4 sequentially, the jit-compiled code
|
380
387
|
# would include 4 repetitions of the `f` logic, making everything extremely slow.
|
381
|
-
def scan_body(
|
388
|
+
def scan_body(
|
389
|
+
carry: jax.Array, i: int | jax.Array
|
390
|
+
) -> tuple[jax.Array, dict[str, Any]]:
|
382
391
|
""""""
|
383
392
|
|
384
393
|
# Unpack the carry, i.e. the stacked kᵢ vectors.
|
385
394
|
K = carry
|
386
395
|
|
387
396
|
# Define the computation of the Runge-Kutta stage.
|
388
|
-
def compute_ki() -> jax.Array:
|
397
|
+
def compute_ki() -> tuple[jax.Array, dict[str, Any]]:
|
389
398
|
|
390
399
|
# Compute ∑ⱼ aᵢⱼ kⱼ.
|
391
400
|
op_sum_ak = lambda k: jnp.einsum("s,s...->...", A[i], k)
|
@@ -398,13 +407,13 @@ class ExplicitRungeKutta(Integrator[PyTreeType, PyTreeType], Generic[PyTreeType]
|
|
398
407
|
# Compute the next time for the kᵢ evaluation.
|
399
408
|
ti = t0 + c[i] * Δt
|
400
409
|
|
401
|
-
# This is k
|
402
|
-
return f(xi, ti)
|
410
|
+
# This is kᵢ, aux_dict = f(xᵢ, tᵢ).
|
411
|
+
return f(xi, ti)
|
403
412
|
|
404
413
|
# This selector enables FSAL property in the first iteration (i=0).
|
405
|
-
ki = jax.lax.cond(
|
414
|
+
ki, aux_dict = jax.lax.cond(
|
406
415
|
pred=jnp.logical_and(i == 0, self.has_fsal),
|
407
|
-
true_fun=get_ẋ
|
416
|
+
true_fun=get_ẋ0_and_aux_dict,
|
408
417
|
false_fun=compute_ki,
|
409
418
|
)
|
410
419
|
|
@@ -413,10 +422,10 @@ class ExplicitRungeKutta(Integrator[PyTreeType, PyTreeType], Generic[PyTreeType]
|
|
413
422
|
K = jax.tree_util.tree_map(op, K, ki)
|
414
423
|
|
415
424
|
carry = K
|
416
|
-
return carry,
|
425
|
+
return carry, aux_dict
|
417
426
|
|
418
427
|
# Compute the state derivatives kᵢ.
|
419
|
-
K,
|
428
|
+
K, aux_dict = jax.lax.scan(
|
420
429
|
f=scan_body,
|
421
430
|
init=carry0,
|
422
431
|
xs=jnp.arange(c.size),
|
@@ -439,7 +448,7 @@ class ExplicitRungeKutta(Integrator[PyTreeType, PyTreeType], Generic[PyTreeType]
|
|
439
448
|
lambda xf: self.post_process_state(x0=x0, t0=t0, xf=xf, dt=dt)
|
440
449
|
)(z)
|
441
450
|
|
442
|
-
return z_transformed
|
451
|
+
return z_transformed, aux_dict
|
443
452
|
|
444
453
|
@staticmethod
|
445
454
|
def butcher_tableau_is_valid(
|
jaxsim/rbda/contacts/common.py
CHANGED
@@ -5,6 +5,7 @@ from typing import Any
|
|
5
5
|
|
6
6
|
import jaxsim.terrain
|
7
7
|
import jaxsim.typing as jtp
|
8
|
+
from jaxsim.utils import JaxsimDataclass
|
8
9
|
|
9
10
|
|
10
11
|
class ContactsState(abc.ABC):
|
@@ -42,7 +43,7 @@ class ContactsState(abc.ABC):
|
|
42
43
|
pass
|
43
44
|
|
44
45
|
|
45
|
-
class ContactsParams(
|
46
|
+
class ContactsParams(JaxsimDataclass):
|
46
47
|
"""
|
47
48
|
Abstract class representing the parameters of a contact model.
|
48
49
|
"""
|
@@ -67,7 +68,7 @@ class ContactsParams(abc.ABC):
|
|
67
68
|
pass
|
68
69
|
|
69
70
|
|
70
|
-
class ContactModel(
|
71
|
+
class ContactModel(JaxsimDataclass):
|
71
72
|
"""
|
72
73
|
Abstract class representing a contact model.
|
73
74
|
|
@@ -0,0 +1,478 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
3
|
+
import dataclasses
|
4
|
+
from typing import Any
|
5
|
+
|
6
|
+
import jax
|
7
|
+
import jax.numpy as jnp
|
8
|
+
import jax_dataclasses
|
9
|
+
|
10
|
+
import jaxsim.api as js
|
11
|
+
import jaxsim.typing as jtp
|
12
|
+
from jaxsim import math
|
13
|
+
from jaxsim.api.common import ModelDataWithVelocityRepresentation, VelRepr
|
14
|
+
from jaxsim.terrain import FlatTerrain, Terrain
|
15
|
+
|
16
|
+
from .common import ContactModel, ContactsParams, ContactsState
|
17
|
+
|
18
|
+
|
19
|
+
@jax_dataclasses.pytree_dataclass
|
20
|
+
class RigidContactsParams(ContactsParams):
|
21
|
+
"""Parameters of the rigid contacts model."""
|
22
|
+
|
23
|
+
# Static friction coefficient
|
24
|
+
mu: jtp.Float = dataclasses.field(
|
25
|
+
default_factory=lambda: jnp.array(0.5, dtype=float)
|
26
|
+
)
|
27
|
+
|
28
|
+
# Baumgarte proportional term
|
29
|
+
K: jtp.Float = dataclasses.field(
|
30
|
+
default_factory=lambda: jnp.array(0.0, dtype=float)
|
31
|
+
)
|
32
|
+
|
33
|
+
# Baumgarte derivative term
|
34
|
+
D: jtp.Float = dataclasses.field(
|
35
|
+
default_factory=lambda: jnp.array(0.0, dtype=float)
|
36
|
+
)
|
37
|
+
|
38
|
+
def __hash__(self) -> int:
|
39
|
+
from jaxsim.utils.wrappers import HashedNumpyArray
|
40
|
+
|
41
|
+
return hash(
|
42
|
+
(
|
43
|
+
HashedNumpyArray.hash_of_array(self.mu),
|
44
|
+
HashedNumpyArray.hash_of_array(self.K),
|
45
|
+
HashedNumpyArray.hash_of_array(self.D),
|
46
|
+
)
|
47
|
+
)
|
48
|
+
|
49
|
+
def __eq__(self, other: RigidContactsParams) -> bool:
|
50
|
+
return hash(self) == hash(other)
|
51
|
+
|
52
|
+
@classmethod
|
53
|
+
def build(
|
54
|
+
cls,
|
55
|
+
mu: jtp.FloatLike | None = None,
|
56
|
+
K: jtp.FloatLike | None = None,
|
57
|
+
D: jtp.FloatLike | None = None,
|
58
|
+
) -> RigidContactsParams:
|
59
|
+
"""Create a `RigidContactParams` instance"""
|
60
|
+
return RigidContactsParams(
|
61
|
+
mu=mu or cls.__dataclass_fields__["mu"].default,
|
62
|
+
K=K or cls.__dataclass_fields__["K"].default,
|
63
|
+
D=D or cls.__dataclass_fields__["D"].default,
|
64
|
+
)
|
65
|
+
|
66
|
+
def valid(self) -> bool:
|
67
|
+
return bool(
|
68
|
+
jnp.all(self.mu >= 0.0)
|
69
|
+
and jnp.all(self.K >= 0.0)
|
70
|
+
and jnp.all(self.D >= 0.0)
|
71
|
+
)
|
72
|
+
|
73
|
+
|
74
|
+
@jax_dataclasses.pytree_dataclass
|
75
|
+
class RigidContactsState(ContactsState):
|
76
|
+
"""Class storing the state of the rigid contacts model."""
|
77
|
+
|
78
|
+
def __eq__(self, other: RigidContactsState) -> bool:
|
79
|
+
return hash(self) == hash(other)
|
80
|
+
|
81
|
+
@staticmethod
|
82
|
+
def build(**kwargs) -> RigidContactsState:
|
83
|
+
"""Create a `RigidContactsState` instance"""
|
84
|
+
|
85
|
+
return RigidContactsState()
|
86
|
+
|
87
|
+
@staticmethod
|
88
|
+
def zero(**kwargs) -> RigidContactsState:
|
89
|
+
"""Build a zero `RigidContactsState` instance from a `JaxSimModel`."""
|
90
|
+
return RigidContactsState.build()
|
91
|
+
|
92
|
+
def valid(self, **kwargs) -> bool:
|
93
|
+
return True
|
94
|
+
|
95
|
+
|
96
|
+
@jax_dataclasses.pytree_dataclass
|
97
|
+
class RigidContacts(ContactModel):
|
98
|
+
"""Rigid contacts model."""
|
99
|
+
|
100
|
+
parameters: RigidContactsParams = dataclasses.field(
|
101
|
+
default_factory=RigidContactsParams
|
102
|
+
)
|
103
|
+
|
104
|
+
terrain: jax_dataclasses.Static[Terrain] = dataclasses.field(
|
105
|
+
default_factory=FlatTerrain
|
106
|
+
)
|
107
|
+
|
108
|
+
@staticmethod
|
109
|
+
def detect_contacts(
|
110
|
+
W_p_C: jtp.ArrayLike,
|
111
|
+
terrain_height: jtp.ArrayLike,
|
112
|
+
) -> tuple[jtp.Vector, jtp.Vector]:
|
113
|
+
"""
|
114
|
+
Detect contacts between the collidable points and the terrain.
|
115
|
+
|
116
|
+
Args:
|
117
|
+
W_p_C: The position of the collidable points.
|
118
|
+
terrain_height: The height of the terrain at the collidable point position.
|
119
|
+
|
120
|
+
Returns:
|
121
|
+
A tuple containing the activation state of the collidable points and the contact penetration depth h.
|
122
|
+
"""
|
123
|
+
|
124
|
+
# TODO: reduce code duplication with js.contact.in_contact
|
125
|
+
def detect_contact(
|
126
|
+
W_p_C: jtp.ArrayLike,
|
127
|
+
terrain_height: jtp.FloatLike,
|
128
|
+
) -> tuple[jtp.Bool, jtp.Float]:
|
129
|
+
"""
|
130
|
+
Detect contacts between the collidable points and the terrain.
|
131
|
+
"""
|
132
|
+
|
133
|
+
# Unpack the position of the collidable point.
|
134
|
+
_, _, pz = W_p_C.squeeze()
|
135
|
+
|
136
|
+
inactive = pz > terrain_height
|
137
|
+
|
138
|
+
# Compute contact penetration depth
|
139
|
+
h = jnp.maximum(0.0, terrain_height - pz)
|
140
|
+
|
141
|
+
return inactive, h
|
142
|
+
|
143
|
+
inactive_collidable_points, h = jax.vmap(detect_contact)(W_p_C, terrain_height)
|
144
|
+
|
145
|
+
return inactive_collidable_points, h
|
146
|
+
|
147
|
+
@staticmethod
def compute_impact_velocity(
    inactive_collidable_points: jtp.ArrayLike,
    M: jtp.MatrixLike,
    J_WC: jtp.MatrixLike,
    data: js.data.JaxSimModelData,
) -> jtp.Vector:
    """
    Return the new velocity of the system after a potential impact.

    The post-impact velocity is the leading part of the least-squares
    solution of the linear system

        [ M   -Jl.T ] [ nu_post ]   [ M @ nu_pre ]
        [ Jl    0   ] [    x    ] = [     0      ]

    where `Jl` stacks the rows 0:3 of the Jacobian of each collidable point,
    with the rows of the inactive points replaced by zeros, so that the
    linear velocity of the active points is kept at zero.

    Args:
        inactive_collidable_points: The activation state of the collidable points
            (true entries mark the points whose Jacobian rows are zeroed).
        M: The mass matrix of the system.
        J_WC: The Jacobian matrix of the collidable points, indexed as
            `(point, row, dof)`.
        data: The `JaxSimModelData` instance.

    Returns:
        The generalized velocity after the impact. The pre-impact velocity
        is read from `data` in mixed representation.
    """

    # Accept generic array-like inputs, as promised by the annotations.
    inactive_mask = jnp.asarray(inactive_collidable_points)
    M = jnp.asarray(M)
    J_WC = jnp.asarray(J_WC)

    # Only reading the generalized velocity depends on the velocity
    # representation of `data`; the algebra below works on plain arrays.
    with data.switch_velocity_representation(VelRepr.Mixed):
        BW_ν_pre_impact = data.generalized_velocity()

    # Compute system velocity after impact maintaining zero linear velocity
    # of active points.
    Jl_WC = J_WC[:, 0:3, :]

    # Zero out the jacobian rows of inactive points and stack all the
    # per-point blocks into a single 2D matrix.
    Jl_WC = jnp.vstack(
        jnp.where(
            inactive_mask[:, jnp.newaxis, jnp.newaxis],
            jnp.zeros_like(Jl_WC),
            Jl_WC,
        )
    )

    n_constraints = Jl_WC.shape[0]

    A = jnp.vstack(
        [
            jnp.hstack([M, -Jl_WC.T]),
            jnp.hstack([Jl_WC, jnp.zeros((n_constraints, n_constraints))]),
        ]
    )
    b = jnp.hstack([M @ BW_ν_pre_impact, jnp.zeros(n_constraints)])

    # The system can be rank deficient (zeroed rows), hence least squares.
    x = jnp.linalg.lstsq(A, b)[0]

    # The leading entries of the solution are the post-impact velocity,
    # the remaining ones are discarded.
    BW_ν_post_impact = x[0 : M.shape[0]]

    return BW_ν_post_impact
|
209
|
+
|
210
|
+
def compute_contact_forces(
    self,
    position: jtp.Vector,
    velocity: jtp.Vector,
    model: js.model.JaxSimModel,
    data: js.data.JaxSimModelData,
    link_forces: jtp.MatrixLike | None = None,
    regularization_term: jtp.FloatLike = 1e-6,
) -> tuple[jtp.Vector, tuple[Any, ...]]:
    """
    Compute the contact forces.

    Args:
        position: The positions of the collidable points, one row per point.
        velocity:
            The linear velocities of the collidable points, one row per point.
        model: The `JaxSimModel` instance.
        data: The `JaxSimModelData` instance.
        link_forces:
            Optional `(n_links, 6)` matrix of external forces acting on the links,
            expressed in the same representation of data.
        regularization_term:
            The regularization term to add to the diagonal of the Delassus
            matrix for better numerical conditioning.

    Returns:
        A tuple containing the 6D contact forces of the collidable points,
        converted to inertial-fixed representation, and an empty tuple.
    """

    # Import qpax just in this method
    import qpax

    link_forces = (
        link_forces
        if link_forces is not None
        else jnp.zeros((model.number_of_links(), 6))
    )

    # Compute kin-dyn quantities used in the contact model
    with data.switch_velocity_representation(VelRepr.Mixed):
        M = js.model.free_floating_mass_matrix(model=model, data=data)
        J_WC = js.contact.jacobian(model=model, data=data)
        W_H_C = js.contact.transforms(model=model, data=data)
        terrain_height = jax.vmap(self.terrain.height)(position[:, 0], position[:, 1])
        n_collidable_points = model.kin_dyn_parameters.contact_parameters.point.shape[0]

    # Compute the activation state of the collidable points
    inactive_collidable_points, h = RigidContacts.detect_contacts(
        W_p_C=position,
        terrain_height=terrain_height,
    )

    delassus_matrix = RigidContacts._delassus_matrix(M=M, J_WC=J_WC)

    # Add regularization for better numerical conditioning
    delassus_matrix = delassus_matrix + regularization_term * jnp.eye(
        delassus_matrix.shape[0]
    )

    references = js.references.JaxSimModelReferences.build(
        model=model,
        data=data,
        velocity_representation=data.velocity_representation,
        link_forces=link_forces,
    )

    with references.switch_velocity_representation(VelRepr.Mixed):
        BW_ν̇_free = RigidContacts._compute_mixed_nu_dot_free(
            model, data, references=references
        )

    free_contact_acc = RigidContacts._linear_acceleration_of_collidable_points(
        model,
        data,
        BW_ν̇_free,
    ).flatten()

    # Compute stabilization term.
    # The slice is already one-dimensional with one entry per point. It must
    # not be squeezed: with a single collidable point it would collapse to a
    # 0-d array, which cannot be mapped over its leading axis.
    ḣ = velocity[:, 2]
    baumgarte_term = RigidContacts._compute_baumgarte_stabilization_term(
        inactive_collidable_points=inactive_collidable_points,
        h=h,
        ḣ=ḣ,
        K=self.parameters.K,
        D=self.parameters.D,
    ).flatten()

    free_contact_acc -= baumgarte_term

    # Setup optimization problem
    Q = delassus_matrix
    q = free_contact_acc
    G = RigidContacts._compute_ineq_constraint_matrix(
        inactive_collidable_points=inactive_collidable_points, mu=self.parameters.mu
    )
    h_bounds = RigidContacts._compute_ineq_bounds(
        n_collidable_points=n_collidable_points
    )
    # There are no equality constraints: pass empty matrices to the solver.
    A = jnp.zeros((0, 3 * n_collidable_points))
    b = jnp.zeros((0,))

    # Solve the optimization problem
    solution, *_ = qpax.solve_qp(Q=Q, q=q, A=A, b=b, G=G, h=h_bounds)

    # One row of three linear force components per collidable point.
    f_C_lin = solution.reshape(-1, 3)

    # Transform linear contact forces to 6D
    CW_f_C = jnp.hstack(
        (
            f_C_lin,
            jnp.zeros((f_C_lin.shape[0], 3)),
        )
    )

    # Transform the contact forces to inertial-fixed representation
    W_f_C = jax.vmap(
        lambda CW_f_C, W_H_C: ModelDataWithVelocityRepresentation.other_representation_to_inertial(
            array=CW_f_C,
            transform=W_H_C,
            other_representation=VelRepr.Mixed,
            is_force=True,
        ),
    )(
        CW_f_C,
        W_H_C,
    )

    return W_f_C, ()
|
337
|
+
|
338
|
+
@staticmethod
|
339
|
+
def _delassus_matrix(
|
340
|
+
M: jtp.MatrixLike,
|
341
|
+
J_WC: jtp.MatrixLike,
|
342
|
+
) -> jtp.Matrix:
|
343
|
+
sl = jnp.s_[:, 0:3, :]
|
344
|
+
J_WC_lin = jnp.vstack(J_WC[sl])
|
345
|
+
|
346
|
+
delassus_matrix = J_WC_lin @ jnp.linalg.pinv(M) @ J_WC_lin.T
|
347
|
+
return delassus_matrix
|
348
|
+
|
349
|
+
@staticmethod
|
350
|
+
def _compute_ineq_constraint_matrix(
|
351
|
+
inactive_collidable_points: jtp.Vector, mu: jtp.FloatLike
|
352
|
+
) -> jtp.Matrix:
|
353
|
+
def compute_G_single_point(mu: float, c: float) -> jtp.Matrix:
|
354
|
+
"""
|
355
|
+
Compute the inequality constraint matrix for a single collidable point
|
356
|
+
Rows 0-3: enforce the friction pyramid constraint,
|
357
|
+
Row 4: last one is for the non negativity of the vertical force
|
358
|
+
Row 5: contact complementarity condition
|
359
|
+
"""
|
360
|
+
G_single_point = jnp.array(
|
361
|
+
[
|
362
|
+
[1, 0, -mu],
|
363
|
+
[0, 1, -mu],
|
364
|
+
[-1, 0, -mu],
|
365
|
+
[0, -1, -mu],
|
366
|
+
[0, 0, -1],
|
367
|
+
[0, 0, c],
|
368
|
+
]
|
369
|
+
)
|
370
|
+
return G_single_point
|
371
|
+
|
372
|
+
G = jax.vmap(compute_G_single_point, in_axes=(None, 0))(
|
373
|
+
mu, inactive_collidable_points
|
374
|
+
)
|
375
|
+
G = jax.scipy.linalg.block_diag(*G)
|
376
|
+
return G
|
377
|
+
|
378
|
+
@staticmethod
|
379
|
+
def _compute_ineq_bounds(n_collidable_points: jtp.FloatLike) -> jtp.Vector:
|
380
|
+
n_constraints = 6 * n_collidable_points
|
381
|
+
return jnp.zeros(shape=(n_constraints,))
|
382
|
+
|
383
|
+
@staticmethod
def _compute_mixed_nu_dot_free(
    model: js.model.JaxSimModel,
    data: js.data.JaxSimModelData,
    references: js.references.JaxSimModelReferences | None = None,
) -> jtp.Array:
    """
    Compute the generalized acceleration of the system in mixed representation.

    The acceleration is obtained from `js.ode.system_acceleration` using the
    joint and link forces stored in the references, and its base part is
    then converted to a mixed base acceleration.

    Args:
        model: The `JaxSimModel` instance.
        data: The `JaxSimModelData` instance.
        references:
            The optional `JaxSimModelReferences` instance providing the joint
            and link forces. Zero references are used when not specified.

    Returns:
        The stacked mixed base acceleration and joint accelerations.
    """

    if references is None:
        references = js.references.JaxSimModelReferences.zero(model=model, data=data)

    with (
        data.switch_velocity_representation(VelRepr.Mixed),
        references.switch_velocity_representation(VelRepr.Mixed),
    ):
        # Split the base velocity in its two 3D halves.
        W_ṗ_B, W_ω_WB = jnp.split(data.base_velocity(), 2)

        W_v̇_WB, s̈ = js.ode.system_acceleration(
            model=model,
            data=data,
            joint_forces=references.joint_force_references(model=model),
            link_forces=references.link_forces(model=model, data=data),
        )

    # Convert the inertial-fixed base acceleration to a mixed base acceleration.
    # The transform used for the conversion has the position of the base
    # transform and an identity rotation.
    W_H_BW = data.base_transform().at[0:3, 0:3].set(jnp.eye(3))
    BW_X_W = math.Adjoint.from_transform(W_H_BW, inverse=True)

    transformed_acceleration = BW_X_W @ W_v̇_WB

    # Correction term acting only on the first three components.
    velocity_cross_term = jnp.zeros(6).at[0:3].set(jnp.cross(W_ṗ_B, W_ω_WB))

    BW_v̇_WB = transformed_acceleration - velocity_cross_term

    return jnp.hstack([BW_v̇_WB, s̈])
|
419
|
+
|
420
|
+
@staticmethod
def _linear_acceleration_of_collidable_points(
    model: js.model.JaxSimModel,
    data: js.data.JaxSimModelData,
    mixed_nu_dot: jtp.ArrayLike,
) -> jtp.Matrix:
    """
    Compute the linear acceleration of the collidable points.

    The 6D acceleration of each point is computed as `J̇ @ ν + J @ ν̇`, using
    the contact Jacobian and its derivative with mixed output representation,
    and only its first three components are returned.

    Args:
        model: The `JaxSimModel` instance.
        data: The `JaxSimModelData` instance.
        mixed_nu_dot: The generalized acceleration to use as `ν̇`.

    Returns:
        The first three components of the acceleration of each collidable
        point, with singleton dimensions removed.
    """

    mixed = VelRepr.Mixed

    with data.switch_velocity_representation(mixed):
        J = js.contact.jacobian(model=model, data=data, output_vel_repr=mixed)
        J̇ = js.contact.jacobian_derivative(
            model=model, data=data, output_vel_repr=mixed
        )
        BW_ν = data.generalized_velocity()

    # Stack the per-point matrices so that a single product gives the
    # accelerations of all the points.
    bias_acceleration = jnp.vstack(J̇) @ BW_ν
    driven_acceleration = jnp.vstack(J) @ mixed_nu_dot

    # One row of six components per collidable point.
    CW_a_WC = (bias_acceleration + driven_acceleration).reshape(-1, 6)

    linear_part = CW_a_WC[:, 0:3]

    return linear_part.squeeze()
|
445
|
+
|
446
|
+
@staticmethod
|
447
|
+
def _compute_baumgarte_stabilization_term(
|
448
|
+
inactive_collidable_points: jtp.ArrayLike,
|
449
|
+
h: jtp.ArrayLike,
|
450
|
+
ḣ: jtp.ArrayLike,
|
451
|
+
K: jtp.FloatLike,
|
452
|
+
D: jtp.FloatLike,
|
453
|
+
) -> jtp.Array:
|
454
|
+
def baumgarte_stabilization(
|
455
|
+
inactive: jtp.BoolLike,
|
456
|
+
h: jtp.FloatLike,
|
457
|
+
ḣ: jtp.FloatLike,
|
458
|
+
k_baumgarte: jtp.FloatLike,
|
459
|
+
d_baumgarte: jtp.FloatLike,
|
460
|
+
) -> jtp.Array:
|
461
|
+
baumgarte_term = jax.lax.cond(
|
462
|
+
inactive,
|
463
|
+
lambda h, ḣ, K, D: jnp.zeros(shape=(3,)),
|
464
|
+
lambda h, ḣ, K, D: jnp.zeros(shape=(3,)).at[2].set(K * h + D * ḣ),
|
465
|
+
*(
|
466
|
+
h,
|
467
|
+
ḣ,
|
468
|
+
k_baumgarte,
|
469
|
+
d_baumgarte,
|
470
|
+
),
|
471
|
+
)
|
472
|
+
return baumgarte_term
|
473
|
+
|
474
|
+
baumgarte_term = jax.vmap(
|
475
|
+
baumgarte_stabilization, in_axes=(0, 0, 0, None, None)
|
476
|
+
)(inactive_collidable_points, h, ḣ, K, D)
|
477
|
+
|
478
|
+
return baumgarte_term
|
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.1
|
2
2
|
Name: jaxsim
|
3
|
-
Version: 0.4.3.dev17
|
3
|
+
Version: 0.4.3.dev18
|
4
4
|
Summary: A differentiable physics engine and multibody dynamics library for control and robot learning.
|
5
5
|
Author-email: Diego Ferigo <dgferigo@gmail.com>
|
6
6
|
Maintainer-email: Diego Ferigo <dgferigo@gmail.com>, Filippo Luca Ferretti <filippo.ferretti@iit.it>
|
@@ -65,6 +65,7 @@ Requires-Dist: jaxlib>=0.4.13
|
|
65
65
|
Requires-Dist: jaxlie>=1.3.0
|
66
66
|
Requires-Dist: jax-dataclasses>=1.4.0
|
67
67
|
Requires-Dist: pptree
|
68
|
+
Requires-Dist: qpax
|
68
69
|
Requires-Dist: rod>=0.3.0
|
69
70
|
Requires-Dist: typing-extensions; python_version < "3.12"
|
70
71
|
Provides-Extra: all
|
@@ -1,23 +1,23 @@
|
|
1
1
|
jaxsim/__init__.py,sha256=ixsS4dYMPex2wOUUp_rkPnwrPhYzkRh1xO_YuMj3Cr4,2626
|
2
|
-
jaxsim/_version.py,sha256=
|
2
|
+
jaxsim/_version.py,sha256=SFJGfO84uy3oOc6jDmWWev6VuJIqHL3tI8_OvaYfdsA,426
|
3
3
|
jaxsim/exceptions.py,sha256=8_h8iqL8DgNR754dR8SZiQ7361GR5V1sUk3ZuZCHw1Q,2069
|
4
4
|
jaxsim/logging.py,sha256=STI-D_upXZYX-ZezLrlJJ0UlD5YspST0vZ_DcIwkzO4,1553
|
5
5
|
jaxsim/typing.py,sha256=IbFx3UkEXi-cm7UBqMPi58rJAFV_HbZ9E_K4JwfNvVM,753
|
6
6
|
jaxsim/api/__init__.py,sha256=8eV22t2S3UwNyCg8karPetG1dmX1VDBXkyv28_FwNQA,210
|
7
7
|
jaxsim/api/com.py,sha256=m-p3EJDhpnMTlXKplfbZE_aH9NqX_VyLlAE3vUhc6l4,13642
|
8
8
|
jaxsim/api/common.py,sha256=SNgxq42r6eF_-aPszvOjUYkGwXOzz4hKmhDwEUkscFQ,6650
|
9
|
-
jaxsim/api/contact.py,sha256=
|
9
|
+
jaxsim/api/contact.py,sha256=HyEAjF7BySDDOlRahN0l7V15IPB0HPXuoM0twamuEW0,20913
|
10
10
|
jaxsim/api/data.py,sha256=CUh9lvhVk3_clNQ26BUBGpjvFSsK_PrVWVMEWpMdHRM,27206
|
11
11
|
jaxsim/api/frame.py,sha256=KS8A5wRfjxhe9NgcVo2QA516iP5zky7UVnWxG7nTa7c,12911
|
12
12
|
jaxsim/api/joint.py,sha256=L81bQe-noPT6_54KOSF7KBjRmEPAS433ULn2EcXI8vI,5115
|
13
13
|
jaxsim/api/kin_dyn_parameters.py,sha256=CcfSg5Mc8qb1mZeMQ4AK_ffZIsK5yOl7tu397pFhcDA,29369
|
14
14
|
jaxsim/api/link.py,sha256=qPRtc8qqMRjZxUCZYXJMygbB6huDXBfIT1b1b8Durkw,18631
|
15
|
-
jaxsim/api/model.py,sha256=
|
16
|
-
jaxsim/api/ode.py,sha256=
|
17
|
-
jaxsim/api/ode_data.py,sha256=
|
15
|
+
jaxsim/api/model.py,sha256=HXoqCtQ3KStGoxhgvFm8P_Sc-lbEM4l5No2MoHzNlOk,65558
|
16
|
+
jaxsim/api/ode.py,sha256=Vb2sN4zwpXnaJDD9-ziz2qvfmfa4jvIQ0fONbBIRGmU,13368
|
17
|
+
jaxsim/api/ode_data.py,sha256=U7F6TL6bENAxpQQl4PupPoDG7d7VfTTFqDAs3xwu6Hs,20003
|
18
18
|
jaxsim/api/references.py,sha256=XOVKuQXRmjPoP-T5JWGSbqIGX5DzOkeGafqRpj0ZQEM,20771
|
19
19
|
jaxsim/integrators/__init__.py,sha256=hxvOD-VK_mmd6v31wtC-nb28AYve1gLuZCNLV9wS-Kg,103
|
20
|
-
jaxsim/integrators/common.py,sha256=
|
20
|
+
jaxsim/integrators/common.py,sha256=ntjflaV3qWaFH_E65pAGZ6QipdnFsgQDasKtIKpxTe4,20432
|
21
21
|
jaxsim/integrators/fixed_step.py,sha256=KpjRd6hHtapxDoo6D1kyDrVDSHnke2TepI5grFH7_bM,2693
|
22
22
|
jaxsim/integrators/variable_step.py,sha256=5StkFh9oQba34zlkIoXG2fUN78gbxkHePWbrpQ-QZOI,21274
|
23
23
|
jaxsim/math/__init__.py,sha256=8oPITEoGwgRcOeG8KxtqxPQ8b5uku1HNRMokpCoi9Tc,352
|
@@ -53,7 +53,8 @@ jaxsim/rbda/jacobian.py,sha256=p0EV_8cLzLVV-93VKznT7VPuRj8W7h7rQWkPlWJXfCA,11023
|
|
53
53
|
jaxsim/rbda/rnea.py,sha256=CLfqs9XFVaD-hvkLABshDAfdw5bm_AMV3UVAQ_IvURQ,7542
|
54
54
|
jaxsim/rbda/utils.py,sha256=eeT21Y4DiiyhrdF0lUE_VvRuwru5-rR7yOlOlWzCCWE,5381
|
55
55
|
jaxsim/rbda/contacts/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
56
|
-
jaxsim/rbda/contacts/common.py,sha256=
|
56
|
+
jaxsim/rbda/contacts/common.py,sha256=VwAs742futAmLnDgbaOuLzNDBFiKDfYItdEZ4UcFgzE,2467
|
57
|
+
jaxsim/rbda/contacts/rigid.py,sha256=8Vbnxng-ERZ5ka_eZGIBuhBDr2PNjc7m-Or255AfEw4,15862
|
57
58
|
jaxsim/rbda/contacts/soft.py,sha256=_wvb5iZDjGcVg6rNQelN4LZN7qSC2NIp0HdKvZmlGfk,15647
|
58
59
|
jaxsim/terrain/__init__.py,sha256=f7lVX-iNpH_wkkjef9Qpjh19TTAUOQw76EiLYJDVizc,78
|
59
60
|
jaxsim/terrain/terrain.py,sha256=ctyNANIFSM3tZmamprjaEDcWgUSP0oNJbmT1zw9RjPs,4565
|
@@ -61,8 +62,8 @@ jaxsim/utils/__init__.py,sha256=Y5zyoRevl3EMVQadhZ4EtSwTEkDt2vcnFoRhPJjKTZ0,215
|
|
61
62
|
jaxsim/utils/jaxsim_dataclass.py,sha256=5xJbY0G8d7C0OTNIW9T4vQxiDak6TGZT9gpNOvRykFI,11373
|
62
63
|
jaxsim/utils/tracing.py,sha256=KDMoyVPlu2NJvFkhtZwq5AkqMMgajt3munvJom-vEjQ,650
|
63
64
|
jaxsim/utils/wrappers.py,sha256=JhLUh1g8iU-lhjbuZRfkscPZhYlLCOorVM2Xl3ulRBI,4054
|
64
|
-
jaxsim-0.4.3.
|
65
|
-
jaxsim-0.4.3.
|
66
|
-
jaxsim-0.4.3.
|
67
|
-
jaxsim-0.4.3.
|
68
|
-
jaxsim-0.4.3.
|
65
|
+
jaxsim-0.4.3.dev18.dist-info/LICENSE,sha256=eaYdFmdeMbiIoIiPzEK0MjP1S9wtFXjXNR5er49uLR0,1546
|
66
|
+
jaxsim-0.4.3.dev18.dist-info/METADATA,sha256=aLpRkfa9CC7GVzXMKX3LY5DkCHEmOr4CE-u3Vbt5fx8,17247
|
67
|
+
jaxsim-0.4.3.dev18.dist-info/WHEEL,sha256=cVxcB9AmuTcXqmwrtPhNK88dr7IR_b6qagTj0UvIEbY,91
|
68
|
+
jaxsim-0.4.3.dev18.dist-info/top_level.txt,sha256=LxGMA8FLtXjQ6oI7N5gd_R_oSUHxpXxUEOfT1xS_ni0,7
|
69
|
+
jaxsim-0.4.3.dev18.dist-info/RECORD,,
|
File without changes
|
File without changes
|
File without changes
|