jaxsim 0.4.3.dev12__py3-none-any.whl → 0.4.3.dev18__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jaxsim/_version.py +2 -2
- jaxsim/api/common.py +2 -2
- jaxsim/api/contact.py +37 -9
- jaxsim/api/data.py +1 -1
- jaxsim/api/frame.py +1 -1
- jaxsim/api/joint.py +1 -1
- jaxsim/api/link.py +1 -1
- jaxsim/api/model.py +62 -8
- jaxsim/api/ode.py +114 -36
- jaxsim/api/ode_data.py +11 -7
- jaxsim/integrators/common.py +30 -21
- jaxsim/integrators/variable_step.py +2 -2
- jaxsim/logging.py +1 -2
- jaxsim/math/inertia.py +1 -3
- jaxsim/math/joint_model.py +1 -1
- jaxsim/math/rotation.py +1 -3
- jaxsim/mujoco/loaders.py +2 -1
- jaxsim/mujoco/model.py +2 -1
- jaxsim/mujoco/visualizer.py +2 -2
- jaxsim/parsers/descriptions/model.py +1 -1
- jaxsim/parsers/kinematic_graph.py +4 -3
- jaxsim/parsers/rod/parser.py +10 -10
- jaxsim/rbda/contacts/common.py +3 -2
- jaxsim/rbda/contacts/rigid.py +478 -0
- jaxsim/rbda/rnea.py +5 -7
- jaxsim/utils/jaxsim_dataclass.py +3 -3
- jaxsim/utils/wrappers.py +2 -1
- {jaxsim-0.4.3.dev12.dist-info → jaxsim-0.4.3.dev18.dist-info}/METADATA +2 -1
- {jaxsim-0.4.3.dev12.dist-info → jaxsim-0.4.3.dev18.dist-info}/RECORD +32 -31
- {jaxsim-0.4.3.dev12.dist-info → jaxsim-0.4.3.dev18.dist-info}/WHEEL +1 -1
- {jaxsim-0.4.3.dev12.dist-info → jaxsim-0.4.3.dev18.dist-info}/LICENSE +0 -0
- {jaxsim-0.4.3.dev12.dist-info → jaxsim-0.4.3.dev18.dist-info}/top_level.txt +0 -0
jaxsim/_version.py
CHANGED
@@ -12,5 +12,5 @@ __version__: str
|
|
12
12
|
__version_tuple__: VERSION_TUPLE
|
13
13
|
version_tuple: VERSION_TUPLE
|
14
14
|
|
15
|
-
__version__ = version = '0.4.3.
|
16
|
-
__version_tuple__ = version_tuple = (0, 4, 3, '
|
15
|
+
__version__ = version = '0.4.3.dev18'
|
16
|
+
__version_tuple__ = version_tuple = (0, 4, 3, 'dev18')
|
jaxsim/api/common.py
CHANGED
@@ -3,7 +3,7 @@ import contextlib
|
|
3
3
|
import dataclasses
|
4
4
|
import enum
|
5
5
|
import functools
|
6
|
-
from
|
6
|
+
from collections.abc import Iterator
|
7
7
|
|
8
8
|
import jax
|
9
9
|
import jax.numpy as jnp
|
@@ -44,7 +44,7 @@ class ModelDataWithVelocityRepresentation(JaxsimDataclass, abc.ABC):
|
|
44
44
|
@contextlib.contextmanager
|
45
45
|
def switch_velocity_representation(
|
46
46
|
self, velocity_representation: VelRepr
|
47
|
-
) ->
|
47
|
+
) -> Iterator[Self]:
|
48
48
|
"""
|
49
49
|
Context manager to temporarily switch the velocity representation.
|
50
50
|
|
jaxsim/api/contact.py
CHANGED
@@ -114,19 +114,24 @@ def collidable_point_forces(
|
|
114
114
|
|
115
115
|
@jax.jit
|
116
116
|
def collidable_point_dynamics(
|
117
|
-
model: js.model.JaxSimModel,
|
118
|
-
|
117
|
+
model: js.model.JaxSimModel,
|
118
|
+
data: js.data.JaxSimModelData,
|
119
|
+
link_forces: jtp.MatrixLike | None = None,
|
120
|
+
) -> tuple[jtp.Matrix, dict[str, jtp.Array]]:
|
119
121
|
r"""
|
120
|
-
Compute the 6D force applied to each collidable point
|
121
|
-
material deformation rate.
|
122
|
+
Compute the 6D force applied to each collidable point.
|
122
123
|
|
123
124
|
Args:
|
124
125
|
model: The model to consider.
|
125
126
|
data: The data of the considered model.
|
127
|
+
link_forces:
|
128
|
+
The 6D external forces to apply to the links expressed in the same
|
129
|
+
representation of data.
|
126
130
|
|
127
131
|
Returns:
|
128
|
-
The 6D force applied to each collidable point and the
|
129
|
-
material deformation rate.
|
132
|
+
The 6D force applied to each collidable point and additional data based on the contact model configured:
|
133
|
+
- Soft: the material deformation rate.
|
134
|
+
- Rigid: nothing.
|
130
135
|
|
131
136
|
Note:
|
132
137
|
The material deformation rate is always returned in the mixed frame
|
@@ -138,7 +143,8 @@ def collidable_point_dynamics(
|
|
138
143
|
# all collidable points belonging to the robot.
|
139
144
|
W_p_Ci, W_ṗ_Ci = js.contact.collidable_point_kinematics(model=model, data=data)
|
140
145
|
|
141
|
-
# Import privately the
|
146
|
+
# Import privately the contacts classes.
|
147
|
+
from jaxsim.rbda.contacts.rigid import RigidContacts, RigidContactsState
|
142
148
|
from jaxsim.rbda.contacts.soft import SoftContacts, SoftContactsState
|
143
149
|
|
144
150
|
# Build the soft contact model.
|
@@ -161,9 +167,31 @@ def collidable_point_dynamics(
|
|
161
167
|
W_f_Ci, (CW_ṁ,) = jax.vmap(soft_contacts.compute_contact_forces)(
|
162
168
|
W_p_Ci, W_ṗ_Ci, data.state.contact.tangential_deformation
|
163
169
|
)
|
170
|
+
aux_data = dict(m_dot=CW_ṁ)
|
171
|
+
|
172
|
+
case RigidContacts():
|
173
|
+
assert isinstance(model.contact_model, RigidContacts)
|
174
|
+
assert isinstance(data.state.contact, RigidContactsState)
|
175
|
+
|
176
|
+
# Build the contact model.
|
177
|
+
rigid_contacts = RigidContacts(
|
178
|
+
parameters=data.contacts_params, terrain=model.terrain
|
179
|
+
)
|
180
|
+
|
181
|
+
# Compute the 6D force expressed in the inertial frame and applied to each
|
182
|
+
# collidable point.
|
183
|
+
W_f_Ci, _ = rigid_contacts.compute_contact_forces(
|
184
|
+
position=W_p_Ci,
|
185
|
+
velocity=W_ṗ_Ci,
|
186
|
+
model=model,
|
187
|
+
data=data,
|
188
|
+
link_forces=link_forces,
|
189
|
+
)
|
190
|
+
|
191
|
+
aux_data = dict()
|
164
192
|
|
165
193
|
case _:
|
166
|
-
raise ValueError("Invalid contact model {
|
194
|
+
raise ValueError(f"Invalid contact model {model.contact_model}")
|
167
195
|
|
168
196
|
# Convert the 6D forces to the active representation.
|
169
197
|
f_Ci = jax.vmap(
|
@@ -175,7 +203,7 @@ def collidable_point_dynamics(
|
|
175
203
|
)
|
176
204
|
)(W_f_Ci)
|
177
205
|
|
178
|
-
return f_Ci,
|
206
|
+
return f_Ci, aux_data
|
179
207
|
|
180
208
|
|
181
209
|
@functools.partial(jax.jit, static_argnames=["link_names"])
|
jaxsim/api/data.py
CHANGED
jaxsim/api/frame.py
CHANGED
jaxsim/api/joint.py
CHANGED
jaxsim/api/link.py
CHANGED
jaxsim/api/model.py
CHANGED
@@ -4,7 +4,8 @@ import copy
|
|
4
4
|
import dataclasses
|
5
5
|
import functools
|
6
6
|
import pathlib
|
7
|
-
from
|
7
|
+
from collections.abc import Sequence
|
8
|
+
from typing import Any
|
8
9
|
|
9
10
|
import jax
|
10
11
|
import jax.numpy as jnp
|
@@ -13,6 +14,7 @@ import rod
|
|
13
14
|
from jax_dataclasses import Static
|
14
15
|
|
15
16
|
import jaxsim.api as js
|
17
|
+
import jaxsim.exceptions
|
16
18
|
import jaxsim.terrain
|
17
19
|
import jaxsim.typing as jtp
|
18
20
|
from jaxsim.math import Adjoint, Cross
|
@@ -1889,6 +1891,8 @@ def step(
|
|
1889
1891
|
and the new state of the integrator.
|
1890
1892
|
"""
|
1891
1893
|
|
1894
|
+
from jaxsim.rbda.contacts.rigid import RigidContacts
|
1895
|
+
|
1892
1896
|
# Extract the integrator kwargs.
|
1893
1897
|
# The following logic allows using integrators having kwargs colliding with the
|
1894
1898
|
# kwargs of this step function.
|
@@ -1900,12 +1904,12 @@ def step(
|
|
1900
1904
|
|
1901
1905
|
# Extract the initial resources.
|
1902
1906
|
t0_ns = data.time_ns
|
1903
|
-
|
1907
|
+
state_t0 = data.state
|
1904
1908
|
integrator_state_x0 = integrator_state
|
1905
1909
|
|
1906
1910
|
# Step the dynamics forward.
|
1907
|
-
|
1908
|
-
x0=
|
1911
|
+
state_tf, integrator_state_tf = integrator.step(
|
1912
|
+
x0=state_t0,
|
1909
1913
|
t0=jnp.array(t0_ns / 1e9).astype(float),
|
1910
1914
|
dt=dt,
|
1911
1915
|
params=integrator_state_x0,
|
@@ -1927,11 +1931,61 @@ def step(
|
|
1927
1931
|
),
|
1928
1932
|
)
|
1929
1933
|
|
1930
|
-
|
1934
|
+
data_tf = (
|
1931
1935
|
# Store the new state of the model and the new time.
|
1932
1936
|
data.replace(
|
1933
|
-
state=
|
1937
|
+
state=state_tf,
|
1934
1938
|
time_ns=t0_ns + jnp.array(dt * 1e9).astype(jnp.uint64),
|
1935
|
-
)
|
1936
|
-
|
1939
|
+
)
|
1940
|
+
)
|
1941
|
+
|
1942
|
+
# Post process the simulation state, if needed.
|
1943
|
+
match model.contact_model:
|
1944
|
+
|
1945
|
+
# Rigid contact models use an impact model that produces a discontinuous model velocity.
|
1946
|
+
# Hence here we need to reset the velocity after each impact to guarantee that
|
1947
|
+
# the linear velocity of the active collidable points is zero.
|
1948
|
+
case RigidContacts():
|
1949
|
+
# Raise runtime error for not supported case in which Rigid contacts and Baumgarte stabilization
|
1950
|
+
# enabled are used with ForwardEuler integrator.
|
1951
|
+
jaxsim.exceptions.raise_runtime_error_if(
|
1952
|
+
condition=jnp.logical_and(
|
1953
|
+
isinstance(
|
1954
|
+
integrator,
|
1955
|
+
jaxsim.integrators.fixed_step.ForwardEuler
|
1956
|
+
| jaxsim.integrators.fixed_step.ForwardEulerSO3,
|
1957
|
+
),
|
1958
|
+
jnp.array(
|
1959
|
+
[data_tf.contacts_params.K, data_tf.contacts_params.D]
|
1960
|
+
).any(),
|
1961
|
+
),
|
1962
|
+
msg="Baumgarte stabilization is not supported with ForwardEuler integrators",
|
1963
|
+
)
|
1964
|
+
|
1965
|
+
with data_tf.switch_velocity_representation(VelRepr.Mixed):
|
1966
|
+
W_p_C = js.contact.collidable_point_positions(model, data_tf)
|
1967
|
+
M = js.model.free_floating_mass_matrix(model, data_tf)
|
1968
|
+
J_WC = js.contact.jacobian(model, data_tf)
|
1969
|
+
px, py, _ = W_p_C.T
|
1970
|
+
terrain_height = jax.vmap(model.terrain.height)(px, py)
|
1971
|
+
inactive_collidable_points, _ = RigidContacts.detect_contacts(
|
1972
|
+
W_p_C=W_p_C,
|
1973
|
+
terrain_height=terrain_height,
|
1974
|
+
)
|
1975
|
+
BW_nu_post_impact = RigidContacts.compute_impact_velocity(
|
1976
|
+
data=data_tf,
|
1977
|
+
inactive_collidable_points=inactive_collidable_points,
|
1978
|
+
M=M,
|
1979
|
+
J_WC=J_WC,
|
1980
|
+
)
|
1981
|
+
data_tf = data_tf.reset_base_velocity(BW_nu_post_impact[0:6])
|
1982
|
+
data_tf = data_tf.reset_joint_velocities(BW_nu_post_impact[6:])
|
1983
|
+
# Restore the input velocity representation.
|
1984
|
+
data_tf = data_tf.replace(
|
1985
|
+
velocity_representation=data.velocity_representation, validate=False
|
1986
|
+
)
|
1987
|
+
|
1988
|
+
return (
|
1989
|
+
data_tf,
|
1990
|
+
integrator_state_tf,
|
1937
1991
|
)
|
jaxsim/api/ode.py
CHANGED
@@ -50,7 +50,7 @@ def wrap_system_dynamics_for_integration(
|
|
50
50
|
# The wrapped dynamics will hold a reference of this object.
|
51
51
|
model_closed = model.copy()
|
52
52
|
data_closed = data.copy().replace(
|
53
|
-
state=js.ode_data.ODEState.zero(model=model_closed)
|
53
|
+
state=js.ode_data.ODEState.zero(model=model_closed, data=data)
|
54
54
|
)
|
55
55
|
|
56
56
|
def f(x: ODEState, t: Time, **kwargs_f) -> tuple[ODEState, dict[str, Any]]:
|
@@ -88,7 +88,7 @@ def system_velocity_dynamics(
|
|
88
88
|
*,
|
89
89
|
joint_forces: jtp.Vector | None = None,
|
90
90
|
link_forces: jtp.Vector | None = None,
|
91
|
-
) -> tuple[jtp.Vector, jtp.Vector,
|
91
|
+
) -> tuple[jtp.Vector, jtp.Vector, dict[str, Any]]:
|
92
92
|
"""
|
93
93
|
Compute the dynamics of the system velocity.
|
94
94
|
|
@@ -102,18 +102,10 @@ def system_velocity_dynamics(
|
|
102
102
|
|
103
103
|
Returns:
|
104
104
|
A tuple containing the derivative of the base 6D velocity in inertial-fixed
|
105
|
-
representation, the derivative of the joint velocities,
|
106
|
-
|
107
|
-
the system dynamics evaluation.
|
105
|
+
representation, the derivative of the joint velocities, and auxiliary data
|
106
|
+
returned by the system dynamics evaluation.
|
108
107
|
"""
|
109
108
|
|
110
|
-
# Build joint torques if not provided.
|
111
|
-
τ = (
|
112
|
-
jnp.atleast_1d(joint_forces.squeeze())
|
113
|
-
if joint_forces is not None
|
114
|
-
else jnp.zeros_like(data.joint_positions())
|
115
|
-
).astype(float)
|
116
|
-
|
117
109
|
# Build link forces if not provided.
|
118
110
|
# These forces are expressed in the frame corresponding to the velocity
|
119
111
|
# representation of data.
|
@@ -123,6 +115,15 @@ def system_velocity_dynamics(
|
|
123
115
|
else jnp.zeros((model.number_of_links(), 6))
|
124
116
|
).astype(float)
|
125
117
|
|
118
|
+
# We expect that the 6D forces included in the `link_forces` argument are expressed
|
119
|
+
# in the frame corresponding to the velocity representation of `data`.
|
120
|
+
references = js.references.JaxSimModelReferences.build(
|
121
|
+
model=model,
|
122
|
+
link_forces=O_f_L,
|
123
|
+
data=data,
|
124
|
+
velocity_representation=data.velocity_representation,
|
125
|
+
)
|
126
|
+
|
126
127
|
# ======================
|
127
128
|
# Compute contact forces
|
128
129
|
# ======================
|
@@ -131,19 +132,17 @@ def system_velocity_dynamics(
|
|
131
132
|
# with the terrain.
|
132
133
|
W_f_Li_terrain = jnp.zeros_like(O_f_L).astype(float)
|
133
134
|
|
134
|
-
|
135
|
-
from jaxsim.rbda.contacts.soft import SoftContactsState
|
136
|
-
|
137
|
-
# Initialize the derivative of the tangential deformation ṁ ∈ ℝ^{n_c × 3}.
|
138
|
-
assert isinstance(data.state.contact, SoftContactsState)
|
139
|
-
ṁ = jnp.zeros_like(data.state.contact.tangential_deformation).astype(float)
|
140
|
-
|
135
|
+
aux_data = {}
|
141
136
|
if len(model.kin_dyn_parameters.contact_parameters.body) > 0:
|
142
137
|
|
143
138
|
# Compute the 6D forces W_f ∈ ℝ^{n_c × 6} applied to each collidable point
|
144
|
-
#
|
139
|
+
# along with contact-specific auxiliary states.
|
145
140
|
with data.switch_velocity_representation(VelRepr.Inertial):
|
146
|
-
W_f_Ci,
|
141
|
+
W_f_Ci, aux_data = js.contact.collidable_point_dynamics(
|
142
|
+
model=model,
|
143
|
+
data=data,
|
144
|
+
link_forces=references.link_forces(model=model, data=data),
|
145
|
+
)
|
147
146
|
|
148
147
|
# Construct the vector defining the parent link index of each collidable point.
|
149
148
|
# We use this vector to sum the 6D forces of all collidable points rigidly
|
@@ -161,6 +160,74 @@ def system_velocity_dynamics(
|
|
161
160
|
|
162
161
|
W_f_Li_terrain = mask.T @ W_f_Ci
|
163
162
|
|
163
|
+
# ===========================
|
164
|
+
# Compute system acceleration
|
165
|
+
# ===========================
|
166
|
+
|
167
|
+
# Compute the total link forces
|
168
|
+
with (
|
169
|
+
data.switch_velocity_representation(VelRepr.Inertial),
|
170
|
+
references.switch_velocity_representation(VelRepr.Inertial),
|
171
|
+
):
|
172
|
+
references = references.apply_link_forces(
|
173
|
+
model=model,
|
174
|
+
data=data,
|
175
|
+
forces=W_f_Li_terrain,
|
176
|
+
additive=True,
|
177
|
+
)
|
178
|
+
# Get the link forces in the data representation
|
179
|
+
with references.switch_velocity_representation(data.velocity_representation):
|
180
|
+
f_L_total = references.link_forces(model=model, data=data)
|
181
|
+
|
182
|
+
# The following method always returns the inertial-fixed acceleration, and expects
|
183
|
+
# the link_forces expressed in the inertial frame.
|
184
|
+
W_v̇_WB, s̈ = system_acceleration(
|
185
|
+
model=model, data=data, joint_forces=joint_forces, link_forces=f_L_total
|
186
|
+
)
|
187
|
+
|
188
|
+
return W_v̇_WB, s̈, aux_data
|
189
|
+
|
190
|
+
|
191
|
+
def system_acceleration(
|
192
|
+
model: js.model.JaxSimModel,
|
193
|
+
data: js.data.JaxSimModelData,
|
194
|
+
*,
|
195
|
+
joint_forces: jtp.VectorLike | None = None,
|
196
|
+
link_forces: jtp.MatrixLike | None = None,
|
197
|
+
) -> tuple[jtp.Vector, jtp.Vector]:
|
198
|
+
"""
|
199
|
+
Compute the system acceleration in inertial-fixed representation.
|
200
|
+
|
201
|
+
Args:
|
202
|
+
model: The model to consider.
|
203
|
+
data: The data of the considered model.
|
204
|
+
joint_forces: The joint forces to apply.
|
205
|
+
link_forces:
|
206
|
+
The 6D forces to apply to the links expressed in the same representation of data.
|
207
|
+
|
208
|
+
Returns:
|
209
|
+
A tuple containing the base 6D acceleration in inertial-fixed representation
|
210
|
+
and the joint accelerations.
|
211
|
+
"""
|
212
|
+
|
213
|
+
# ====================
|
214
|
+
# Validate input data
|
215
|
+
# ====================
|
216
|
+
|
217
|
+
# Build link forces if not provided.
|
218
|
+
f_L = (
|
219
|
+
jnp.atleast_2d(link_forces.squeeze())
|
220
|
+
if link_forces is not None
|
221
|
+
else jnp.zeros((model.number_of_links(), 6))
|
222
|
+
).astype(float)
|
223
|
+
|
224
|
+
# Build joint torques if not provided.
|
225
|
+
τ = (
|
226
|
+
jnp.atleast_1d(joint_forces.squeeze())
|
227
|
+
if joint_forces is not None
|
228
|
+
else jnp.zeros_like(data.joint_positions())
|
229
|
+
).astype(float)
|
230
|
+
|
164
231
|
# ====================
|
165
232
|
# Enforce joint limits
|
166
233
|
# ====================
|
@@ -198,29 +265,25 @@ def system_velocity_dynamics(
|
|
198
265
|
|
199
266
|
references = js.references.JaxSimModelReferences.build(
|
200
267
|
model=model,
|
201
|
-
joint_force_references=τ_total,
|
202
|
-
link_forces=O_f_L,
|
203
268
|
data=data,
|
204
269
|
velocity_representation=data.velocity_representation,
|
270
|
+
joint_force_references=τ_total,
|
271
|
+
link_forces=f_L,
|
205
272
|
)
|
206
273
|
|
207
|
-
with references.switch_velocity_representation(VelRepr.Inertial):
|
208
|
-
W_f_L = references.link_forces(model=model, data=data)
|
209
|
-
|
210
|
-
# Compute the total external 6D forces applied to the links.
|
211
|
-
W_f_L_total = W_f_L + W_f_Li_terrain
|
212
|
-
|
213
274
|
# - Joint accelerations: s̈ ∈ ℝⁿ
|
214
275
|
# - Base inertial-fixed acceleration: W_v̇_WB = (W_p̈_B, W_ω̇_B) ∈ ℝ⁶
|
215
|
-
with
|
276
|
+
with (
|
277
|
+
data.switch_velocity_representation(velocity_representation=VelRepr.Inertial),
|
278
|
+
references.switch_velocity_representation(VelRepr.Inertial),
|
279
|
+
):
|
216
280
|
W_v̇_WB, s̈ = js.model.forward_dynamics_aba(
|
217
281
|
model=model,
|
218
282
|
data=data,
|
219
|
-
joint_forces
|
220
|
-
link_forces=
|
283
|
+
joint_forces=references.joint_force_references(),
|
284
|
+
link_forces=references.link_forces(),
|
221
285
|
)
|
222
|
-
|
223
|
-
return W_v̇_WB, s̈, ṁ, dict()
|
286
|
+
return W_v̇_WB, s̈
|
224
287
|
|
225
288
|
|
226
289
|
@jax.jit
|
@@ -291,14 +354,29 @@ def system_dynamics(
|
|
291
354
|
by the system dynamics evaluation.
|
292
355
|
"""
|
293
356
|
|
357
|
+
from jaxsim.rbda.contacts.rigid import RigidContacts
|
358
|
+
from jaxsim.rbda.contacts.soft import SoftContacts
|
359
|
+
|
294
360
|
# Compute the accelerations and the material deformation rate.
|
295
|
-
W_v̇_WB, s̈,
|
361
|
+
W_v̇_WB, s̈, aux_dict = system_velocity_dynamics(
|
296
362
|
model=model,
|
297
363
|
data=data,
|
298
364
|
joint_forces=joint_forces,
|
299
365
|
link_forces=link_forces,
|
300
366
|
)
|
301
367
|
|
368
|
+
ode_state_kwargs = {}
|
369
|
+
|
370
|
+
match model.contact_model:
|
371
|
+
case SoftContacts():
|
372
|
+
ode_state_kwargs["tangential_deformation"] = aux_dict["m_dot"]
|
373
|
+
|
374
|
+
case RigidContacts():
|
375
|
+
pass
|
376
|
+
|
377
|
+
case _:
|
378
|
+
raise ValueError("Unable to determine contact state class prefix.")
|
379
|
+
|
302
380
|
# Extract the velocities.
|
303
381
|
W_ṗ_B, W_Q̇_B, ṡ = system_position_dynamics(
|
304
382
|
model=model,
|
@@ -317,7 +395,7 @@ def system_dynamics(
|
|
317
395
|
base_linear_velocity=W_v̇_WB[0:3],
|
318
396
|
base_angular_velocity=W_v̇_WB[3:6],
|
319
397
|
joint_velocities=s̈,
|
320
|
-
|
398
|
+
**ode_state_kwargs,
|
321
399
|
)
|
322
400
|
|
323
401
|
return ode_state_derivative, aux_dict
|
jaxsim/api/ode_data.py
CHANGED
@@ -6,6 +6,7 @@ import jax_dataclasses
|
|
6
6
|
import jaxsim.api as js
|
7
7
|
import jaxsim.typing as jtp
|
8
8
|
from jaxsim.rbda import ContactsState
|
9
|
+
from jaxsim.rbda.contacts.rigid import RigidContacts, RigidContactsState
|
9
10
|
from jaxsim.rbda.contacts.soft import SoftContacts, SoftContactsState
|
10
11
|
from jaxsim.utils import JaxsimDataclass
|
11
12
|
|
@@ -133,7 +134,7 @@ class ODEState(JaxsimDataclass):
|
|
133
134
|
base_quaternion: jtp.Vector | None = None,
|
134
135
|
base_linear_velocity: jtp.Vector | None = None,
|
135
136
|
base_angular_velocity: jtp.Vector | None = None,
|
136
|
-
|
137
|
+
**kwargs,
|
137
138
|
) -> ODEState:
|
138
139
|
"""
|
139
140
|
Build an `ODEState` from a `JaxSimModel`.
|
@@ -148,9 +149,7 @@ class ODEState(JaxsimDataclass):
|
|
148
149
|
The linear velocity of the base link in inertial-fixed representation.
|
149
150
|
base_angular_velocity:
|
150
151
|
The angular velocity of the base link in inertial-fixed representation.
|
151
|
-
|
152
|
-
The matrix of 3D tangential material deformations corresponding to
|
153
|
-
each collidable point.
|
152
|
+
kwargs: Additional arguments needed to build the contact state.
|
154
153
|
|
155
154
|
Returns:
|
156
155
|
The `ODEState` built from the `JaxSimModel`.
|
@@ -163,6 +162,7 @@ class ODEState(JaxsimDataclass):
|
|
163
162
|
# Get the contact model from the `JaxSimModel`.
|
164
163
|
match model.contact_model:
|
165
164
|
case SoftContacts():
|
165
|
+
tangential_deformation = kwargs.get("tangential_deformation", None)
|
166
166
|
contact = SoftContactsState.build_from_jaxsim_model(
|
167
167
|
model=model,
|
168
168
|
**(
|
@@ -171,6 +171,8 @@ class ODEState(JaxsimDataclass):
|
|
171
171
|
else dict()
|
172
172
|
),
|
173
173
|
)
|
174
|
+
case RigidContacts():
|
175
|
+
contact = RigidContactsState.build()
|
174
176
|
case _:
|
175
177
|
raise ValueError("Unable to determine contact state class prefix.")
|
176
178
|
|
@@ -214,7 +216,7 @@ class ODEState(JaxsimDataclass):
|
|
214
216
|
|
215
217
|
# Get the contact model from the `JaxSimModel`.
|
216
218
|
match contact:
|
217
|
-
case SoftContactsState():
|
219
|
+
case SoftContactsState() | RigidContactsState():
|
218
220
|
pass
|
219
221
|
case None:
|
220
222
|
contact = SoftContactsState.zero(model=model)
|
@@ -224,7 +226,7 @@ class ODEState(JaxsimDataclass):
|
|
224
226
|
return ODEState(physics_model=physics_model_state, contact=contact)
|
225
227
|
|
226
228
|
@staticmethod
|
227
|
-
def zero(model: js.model.JaxSimModel) -> ODEState:
|
229
|
+
def zero(model: js.model.JaxSimModel, data: js.data.JaxSimModelData) -> ODEState:
|
228
230
|
"""
|
229
231
|
Build a zero `ODEState` from a `JaxSimModel`.
|
230
232
|
|
@@ -235,7 +237,9 @@ class ODEState(JaxsimDataclass):
|
|
235
237
|
A zero `ODEState` instance.
|
236
238
|
"""
|
237
239
|
|
238
|
-
model_state = ODEState.build(
|
240
|
+
model_state = ODEState.build(
|
241
|
+
model=model, contact=data.state.contact.zero(model=model)
|
242
|
+
)
|
239
243
|
|
240
244
|
return model_state
|
241
245
|
|