jaxsim 0.2.dev191__py3-none-any.whl → 0.6.1.dev2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jaxsim/__init__.py +73 -22
- jaxsim/_version.py +2 -2
- jaxsim/api/__init__.py +13 -1
- jaxsim/api/com.py +423 -0
- jaxsim/api/common.py +48 -19
- jaxsim/api/contact.py +604 -52
- jaxsim/api/data.py +308 -163
- jaxsim/api/frame.py +471 -0
- jaxsim/api/joint.py +166 -37
- jaxsim/api/kin_dyn_parameters.py +901 -0
- jaxsim/api/link.py +277 -78
- jaxsim/api/model.py +1572 -362
- jaxsim/api/ode.py +324 -133
- jaxsim/api/ode_data.py +401 -0
- jaxsim/api/references.py +216 -80
- jaxsim/exceptions.py +80 -0
- jaxsim/integrators/__init__.py +2 -2
- jaxsim/integrators/common.py +191 -107
- jaxsim/integrators/fixed_step.py +97 -102
- jaxsim/integrators/variable_step.py +706 -0
- jaxsim/logging.py +1 -2
- jaxsim/math/__init__.py +13 -0
- jaxsim/math/adjoint.py +57 -22
- jaxsim/math/cross.py +16 -7
- jaxsim/math/inertia.py +10 -8
- jaxsim/math/joint_model.py +289 -0
- jaxsim/math/quaternion.py +54 -20
- jaxsim/math/rotation.py +27 -21
- jaxsim/math/skew.py +16 -5
- jaxsim/math/transform.py +102 -0
- jaxsim/math/utils.py +31 -0
- jaxsim/mujoco/__init__.py +2 -1
- jaxsim/mujoco/loaders.py +216 -29
- jaxsim/mujoco/model.py +163 -33
- jaxsim/mujoco/utils.py +228 -0
- jaxsim/mujoco/visualizer.py +107 -22
- jaxsim/parsers/__init__.py +0 -1
- jaxsim/parsers/descriptions/__init__.py +8 -2
- jaxsim/parsers/descriptions/collision.py +83 -26
- jaxsim/parsers/descriptions/joint.py +80 -87
- jaxsim/parsers/descriptions/link.py +58 -31
- jaxsim/parsers/descriptions/model.py +101 -68
- jaxsim/parsers/kinematic_graph.py +606 -229
- jaxsim/parsers/rod/meshes.py +104 -0
- jaxsim/parsers/rod/parser.py +125 -82
- jaxsim/parsers/rod/utils.py +127 -82
- jaxsim/rbda/__init__.py +11 -0
- jaxsim/rbda/aba.py +289 -0
- jaxsim/rbda/collidable_points.py +156 -0
- jaxsim/rbda/contacts/__init__.py +13 -0
- jaxsim/rbda/contacts/common.py +313 -0
- jaxsim/rbda/contacts/relaxed_rigid.py +605 -0
- jaxsim/rbda/contacts/rigid.py +462 -0
- jaxsim/rbda/contacts/soft.py +480 -0
- jaxsim/rbda/contacts/visco_elastic.py +1066 -0
- jaxsim/rbda/crba.py +167 -0
- jaxsim/rbda/forward_kinematics.py +117 -0
- jaxsim/rbda/jacobian.py +330 -0
- jaxsim/rbda/rnea.py +235 -0
- jaxsim/rbda/utils.py +160 -0
- jaxsim/terrain/__init__.py +2 -0
- jaxsim/terrain/terrain.py +238 -0
- jaxsim/typing.py +24 -24
- jaxsim/utils/__init__.py +1 -4
- jaxsim/utils/jaxsim_dataclass.py +289 -34
- jaxsim/utils/tracing.py +5 -11
- jaxsim/utils/wrappers.py +159 -0
- {jaxsim-0.2.dev191.dist-info → jaxsim-0.6.1.dev2.dist-info}/LICENSE +1 -1
- jaxsim-0.6.1.dev2.dist-info/METADATA +465 -0
- jaxsim-0.6.1.dev2.dist-info/RECORD +74 -0
- {jaxsim-0.2.dev191.dist-info → jaxsim-0.6.1.dev2.dist-info}/WHEEL +1 -1
- jaxsim/high_level/__init__.py +0 -2
- jaxsim/high_level/common.py +0 -11
- jaxsim/high_level/joint.py +0 -148
- jaxsim/high_level/link.py +0 -259
- jaxsim/high_level/model.py +0 -1686
- jaxsim/math/conv.py +0 -114
- jaxsim/math/joint.py +0 -102
- jaxsim/math/plucker.py +0 -100
- jaxsim/physics/__init__.py +0 -12
- jaxsim/physics/algos/__init__.py +0 -0
- jaxsim/physics/algos/aba.py +0 -254
- jaxsim/physics/algos/aba_motors.py +0 -284
- jaxsim/physics/algos/crba.py +0 -154
- jaxsim/physics/algos/forward_kinematics.py +0 -79
- jaxsim/physics/algos/jacobian.py +0 -98
- jaxsim/physics/algos/rnea.py +0 -180
- jaxsim/physics/algos/rnea_motors.py +0 -196
- jaxsim/physics/algos/soft_contacts.py +0 -523
- jaxsim/physics/algos/terrain.py +0 -78
- jaxsim/physics/algos/utils.py +0 -69
- jaxsim/physics/model/__init__.py +0 -0
- jaxsim/physics/model/ground_contact.py +0 -53
- jaxsim/physics/model/physics_model.py +0 -388
- jaxsim/physics/model/physics_model_state.py +0 -283
- jaxsim/simulation/__init__.py +0 -4
- jaxsim/simulation/integrators.py +0 -393
- jaxsim/simulation/ode.py +0 -290
- jaxsim/simulation/ode_data.py +0 -96
- jaxsim/simulation/ode_integration.py +0 -62
- jaxsim/simulation/simulator.py +0 -543
- jaxsim/simulation/simulator_callbacks.py +0 -79
- jaxsim/simulation/utils.py +0 -15
- jaxsim/sixd/__init__.py +0 -2
- jaxsim/utils/oop.py +0 -536
- jaxsim/utils/vmappable.py +0 -117
- jaxsim-0.2.dev191.dist-info/METADATA +0 -184
- jaxsim-0.2.dev191.dist-info/RECORD +0 -81
- {jaxsim-0.2.dev191.dist-info → jaxsim-0.6.1.dev2.dist-info}/top_level.txt +0 -0
jaxsim/api/ode.py
CHANGED
@@ -2,64 +2,99 @@ from typing import Any, Protocol
|
|
2
2
|
|
3
3
|
import jax
|
4
4
|
import jax.numpy as jnp
|
5
|
-
import jaxlie
|
6
5
|
|
7
|
-
import jaxsim.
|
6
|
+
import jaxsim.api as js
|
7
|
+
import jaxsim.rbda
|
8
8
|
import jaxsim.typing as jtp
|
9
|
-
from jaxsim import
|
10
|
-
from jaxsim.
|
11
|
-
from jaxsim.
|
12
|
-
from jaxsim.physics.algos.soft_contacts import SoftContactsState
|
13
|
-
from jaxsim.physics.model.physics_model_state import PhysicsModelState
|
14
|
-
from jaxsim.simulation.ode_data import ODEState
|
9
|
+
from jaxsim.integrators import Time
|
10
|
+
from jaxsim.math import Quaternion
|
11
|
+
from jaxsim.rbda import contacts
|
15
12
|
|
16
|
-
from . import
|
17
|
-
from . import
|
18
|
-
from . import model as Model
|
13
|
+
from .common import VelRepr
|
14
|
+
from .ode_data import ODEState
|
19
15
|
|
20
16
|
|
21
17
|
class SystemDynamicsFromModelAndData(Protocol):
|
18
|
+
"""
|
19
|
+
Protocol defining the signature of a function computing the system dynamics
|
20
|
+
given a model and data object.
|
21
|
+
"""
|
22
|
+
|
22
23
|
def __call__(
|
23
24
|
self,
|
24
|
-
model:
|
25
|
-
data:
|
25
|
+
model: js.model.JaxSimModel,
|
26
|
+
data: js.data.JaxSimModelData,
|
26
27
|
**kwargs: dict[str, Any],
|
27
|
-
) -> tuple[ODEState, dict[str, Any]]:
|
28
|
+
) -> tuple[ODEState, dict[str, Any]]:
|
29
|
+
"""
|
30
|
+
Compute the system dynamics given a model and data object.
|
31
|
+
|
32
|
+
Args:
|
33
|
+
model: The model to consider.
|
34
|
+
data: The data of the considered model.
|
35
|
+
**kwargs: Additional keyword arguments.
|
36
|
+
|
37
|
+
Returns:
|
38
|
+
A tuple with an `ODEState` object storing in each of its attributes the
|
39
|
+
corresponding derivative, and the dictionary of auxiliary data returned
|
40
|
+
by the system dynamics evaluation.
|
41
|
+
"""
|
42
|
+
|
43
|
+
pass
|
28
44
|
|
29
45
|
|
30
46
|
def wrap_system_dynamics_for_integration(
|
31
|
-
model: Model.JaxSimModel,
|
32
|
-
data: Data.JaxSimModelData,
|
33
47
|
*,
|
34
48
|
system_dynamics: SystemDynamicsFromModelAndData,
|
35
|
-
**kwargs,
|
49
|
+
**kwargs: dict[str, Any],
|
36
50
|
) -> jaxsim.integrators.common.SystemDynamics[ODEState, ODEState]:
|
37
51
|
"""
|
38
|
-
Wrap
|
39
|
-
|
52
|
+
Wrap the system dynamics considered by JaxSim integrators in a generic
|
53
|
+
`f(x, t, **u, **parameters)` function.
|
40
54
|
|
41
55
|
Args:
|
42
|
-
model: The model to consider.
|
43
|
-
data: The data of the considered model.
|
44
56
|
system_dynamics: The system dynamics to wrap.
|
45
57
|
**kwargs: Additional kwargs to close over the system dynamics.
|
46
58
|
|
47
59
|
Returns:
|
48
|
-
The system dynamics closed over the
|
60
|
+
The system dynamics closed over the additional kwargs to be used by
|
61
|
+
JaxSim integrators.
|
49
62
|
"""
|
50
63
|
|
51
|
-
#
|
52
|
-
|
53
|
-
|
54
|
-
|
55
|
-
|
56
|
-
|
57
|
-
|
64
|
+
# Close `system_dynamics` over additional kwargs.
|
65
|
+
# Similarly to what done in `jaxsim.api.model.step`, to be future-proof, we use the
|
66
|
+
# following logic to allow the caller to close over arguments having the same name
|
67
|
+
# of the ones used in the `wrap_system_dynamics_for_integration` function.
|
68
|
+
kwargs = kwargs.copy() if kwargs is not None else {}
|
69
|
+
colliding_system_dynamics_kwargs = kwargs.pop("system_dynamics_kwargs", {})
|
70
|
+
system_dynamics_kwargs = kwargs | colliding_system_dynamics_kwargs
|
71
|
+
|
72
|
+
# Remove `model` and `data` for backward compatibility.
|
73
|
+
# It's no longer necessary to close over them at this stage, as this is always
|
74
|
+
# done in `jaxsim.api.model.step`.
|
75
|
+
# We can remove the following lines in a few releases.
|
76
|
+
_ = system_dynamics_kwargs.pop("data", None)
|
77
|
+
_ = system_dynamics_kwargs.pop("model", None)
|
78
|
+
|
79
|
+
# Create the function with the signature expected by our generic integrators.
|
80
|
+
# Note that our system dynamics is time independent.
|
81
|
+
def f(x: ODEState, t: Time, **kwargs_f) -> tuple[ODEState, dict[str, Any]]:
|
82
|
+
|
83
|
+
# Get the data and model objects from the kwargs.
|
84
|
+
data_f = kwargs_f.pop("data")
|
85
|
+
model_f = kwargs_f.pop("model")
|
86
|
+
|
87
|
+
# Update the state and time stored inside data.
|
88
|
+
with data_f.editable(validate=True) as data_rw:
|
58
89
|
data_rw.state = x
|
59
|
-
data_rw.time_ns = jnp.array(t * 1e9).astype(jnp.uint64)
|
60
90
|
|
61
|
-
#
|
62
|
-
|
91
|
+
# Evaluate the system dynamics, allowing to override the kwargs originally
|
92
|
+
# passed when the closure was created.
|
93
|
+
return system_dynamics(
|
94
|
+
model=model_f,
|
95
|
+
data=data_rw,
|
96
|
+
**(system_dynamics_kwargs | kwargs_f),
|
97
|
+
)
|
63
98
|
|
64
99
|
f: jaxsim.integrators.common.SystemDynamics[ODEState, ODEState]
|
65
100
|
return f
|
@@ -71,105 +106,217 @@ def wrap_system_dynamics_for_integration(
|
|
71
106
|
|
72
107
|
|
73
108
|
@jax.jit
|
109
|
+
@js.common.named_scope
|
74
110
|
def system_velocity_dynamics(
|
75
|
-
model:
|
76
|
-
data:
|
111
|
+
model: js.model.JaxSimModel,
|
112
|
+
data: js.data.JaxSimModelData,
|
77
113
|
*,
|
78
|
-
|
79
|
-
|
80
|
-
) -> tuple[jtp.Vector, jtp.Vector,
|
114
|
+
link_forces: jtp.Vector | None = None,
|
115
|
+
joint_force_references: jtp.Vector | None = None,
|
116
|
+
) -> tuple[jtp.Vector, jtp.Vector, dict[str, Any]]:
|
81
117
|
"""
|
82
118
|
Compute the dynamics of the system velocity.
|
83
119
|
|
84
120
|
Args:
|
85
121
|
model: The model to consider.
|
86
122
|
data: The data of the considered model.
|
87
|
-
|
88
|
-
|
123
|
+
link_forces:
|
124
|
+
The 6D forces to apply to the links expressed in the frame corresponding to
|
125
|
+
the velocity representation of `data`.
|
126
|
+
joint_force_references: The joint force references to apply.
|
89
127
|
|
90
128
|
Returns:
|
91
129
|
A tuple containing the derivative of the base 6D velocity in inertial-fixed
|
92
|
-
representation, the derivative of the joint velocities,
|
93
|
-
|
94
|
-
the system dynamics evalutation.
|
130
|
+
representation, the derivative of the joint velocities, and auxiliary data
|
131
|
+
returned by the system dynamics evaluation.
|
95
132
|
"""
|
96
133
|
|
97
|
-
# Build
|
98
|
-
|
99
|
-
|
100
|
-
|
101
|
-
|
102
|
-
|
103
|
-
|
104
|
-
# Build external forces if not provided
|
105
|
-
f_ext = (
|
106
|
-
jnp.atleast_2d(external_forces.squeeze())
|
107
|
-
if external_forces is not None
|
134
|
+
# Build link forces if not provided.
|
135
|
+
# These forces are expressed in the frame corresponding to the velocity
|
136
|
+
# representation of data.
|
137
|
+
O_f_L = (
|
138
|
+
jnp.atleast_2d(link_forces.squeeze())
|
139
|
+
if link_forces is not None
|
108
140
|
else jnp.zeros((model.number_of_links(), 6))
|
109
141
|
).astype(float)
|
110
142
|
|
143
|
+
# We expect that the 6D forces included in the `link_forces` argument are expressed
|
144
|
+
# in the frame corresponding to the velocity representation of `data`.
|
145
|
+
references = js.references.JaxSimModelReferences.build(
|
146
|
+
model=model,
|
147
|
+
link_forces=O_f_L,
|
148
|
+
joint_force_references=joint_force_references,
|
149
|
+
data=data,
|
150
|
+
velocity_representation=data.velocity_representation,
|
151
|
+
)
|
152
|
+
|
111
153
|
# ======================
|
112
154
|
# Compute contact forces
|
113
155
|
# ======================
|
114
156
|
|
115
157
|
# Initialize the 6D forces W_f ∈ ℝ^{n_L × 6} applied to links due to contact
|
116
158
|
# with the terrain.
|
117
|
-
|
118
|
-
|
119
|
-
# Initialize
|
120
|
-
#
|
121
|
-
|
122
|
-
|
123
|
-
|
124
|
-
|
125
|
-
|
126
|
-
|
127
|
-
|
128
|
-
|
129
|
-
|
130
|
-
|
131
|
-
|
132
|
-
|
133
|
-
|
134
|
-
|
135
|
-
|
136
|
-
|
137
|
-
|
138
|
-
|
139
|
-
|
140
|
-
|
141
|
-
|
142
|
-
|
143
|
-
|
144
|
-
|
145
|
-
|
146
|
-
|
147
|
-
|
148
|
-
|
149
|
-
|
150
|
-
|
159
|
+
W_f_L_terrain = jnp.zeros_like(O_f_L).astype(float)
|
160
|
+
|
161
|
+
# Initialize a dictionary of auxiliary data.
|
162
|
+
# This dictionary is used to store additional data computed by the contact model.
|
163
|
+
aux_data = {}
|
164
|
+
|
165
|
+
if len(model.kin_dyn_parameters.contact_parameters.body) > 0:
|
166
|
+
|
167
|
+
with (
|
168
|
+
data.switch_velocity_representation(VelRepr.Inertial),
|
169
|
+
references.switch_velocity_representation(VelRepr.Inertial),
|
170
|
+
):
|
171
|
+
|
172
|
+
# Compute the 6D forces W_f ∈ ℝ^{n_c × 6} applied to each collidable point
|
173
|
+
# along with contact-specific auxiliary states.
|
174
|
+
W_f_C, aux_data = js.contact.collidable_point_dynamics(
|
175
|
+
model=model,
|
176
|
+
data=data,
|
177
|
+
link_forces=references.link_forces(model=model, data=data),
|
178
|
+
joint_force_references=references.joint_force_references(model=model),
|
179
|
+
)
|
180
|
+
|
181
|
+
# Compute the 6D forces applied to the links equivalent to the forces applied
|
182
|
+
# to the frames associated to the collidable points.
|
183
|
+
W_f_L_terrain = model.contact_model.link_forces_from_contact_forces(
|
184
|
+
model=model,
|
185
|
+
data=data,
|
186
|
+
contact_forces=W_f_C,
|
187
|
+
)
|
188
|
+
|
189
|
+
# ===========================
|
190
|
+
# Compute system acceleration
|
191
|
+
# ===========================
|
192
|
+
|
193
|
+
# Compute the total link forces.
|
194
|
+
with (
|
195
|
+
data.switch_velocity_representation(VelRepr.Inertial),
|
196
|
+
references.switch_velocity_representation(VelRepr.Inertial),
|
197
|
+
):
|
198
|
+
|
199
|
+
# Sum the contact forces just computed with the link forces applied by the user.
|
200
|
+
references = references.apply_link_forces(
|
201
|
+
model=model,
|
202
|
+
data=data,
|
203
|
+
forces=W_f_L_terrain,
|
204
|
+
additive=True,
|
205
|
+
)
|
206
|
+
|
207
|
+
# Get the link forces in inertial-fixed representation.
|
208
|
+
f_L_total = references.link_forces(model=model, data=data)
|
209
|
+
|
210
|
+
# Compute the system acceleration in inertial-fixed representation.
|
211
|
+
# This representation is useful for integration purpose.
|
212
|
+
W_v̇_WB, s̈ = system_acceleration(
|
213
|
+
model=model,
|
214
|
+
data=data,
|
215
|
+
joint_force_references=joint_force_references,
|
216
|
+
link_forces=f_L_total,
|
217
|
+
)
|
218
|
+
|
219
|
+
return W_v̇_WB, s̈, aux_data
|
220
|
+
|
221
|
+
|
222
|
+
def system_acceleration(
|
223
|
+
model: js.model.JaxSimModel,
|
224
|
+
data: js.data.JaxSimModelData,
|
225
|
+
*,
|
226
|
+
link_forces: jtp.MatrixLike | None = None,
|
227
|
+
joint_force_references: jtp.VectorLike | None = None,
|
228
|
+
) -> tuple[jtp.Vector, jtp.Vector]:
|
229
|
+
"""
|
230
|
+
Compute the system acceleration in the active representation.
|
231
|
+
|
232
|
+
Args:
|
233
|
+
model: The model to consider.
|
234
|
+
data: The data of the considered model.
|
235
|
+
link_forces:
|
236
|
+
The 6D forces to apply to the links expressed in the same
|
237
|
+
velocity representation of data.
|
238
|
+
joint_force_references: The joint force references to apply.
|
239
|
+
|
240
|
+
Returns:
|
241
|
+
A tuple containing the base 6D acceleration in the active representation
|
242
|
+
and the joint accelerations.
|
243
|
+
"""
|
244
|
+
|
245
|
+
# ====================
|
246
|
+
# Validate input data
|
247
|
+
# ====================
|
248
|
+
|
249
|
+
# Build link forces if not provided.
|
250
|
+
f_L = (
|
251
|
+
jnp.atleast_2d(link_forces.squeeze())
|
252
|
+
if link_forces is not None
|
253
|
+
else jnp.zeros((model.number_of_links(), 6))
|
254
|
+
).astype(float)
|
255
|
+
|
256
|
+
# Build joint torques if not provided.
|
257
|
+
τ_references = (
|
258
|
+
jnp.atleast_1d(joint_force_references.squeeze())
|
259
|
+
if joint_force_references is not None
|
260
|
+
else jnp.zeros_like(data.joint_positions())
|
261
|
+
).astype(float)
|
151
262
|
|
152
263
|
# ====================
|
153
264
|
# Enforce joint limits
|
154
265
|
# ====================
|
155
266
|
|
156
|
-
|
157
|
-
|
267
|
+
τ_position_limit = jnp.zeros_like(τ_references).astype(float)
|
268
|
+
|
269
|
+
if model.dofs() > 0:
|
270
|
+
|
271
|
+
# Stiffness and damper parameters for the joint position limits.
|
272
|
+
k_j = jnp.array(
|
273
|
+
model.kin_dyn_parameters.joint_parameters.position_limit_spring
|
274
|
+
).astype(float)
|
275
|
+
d_j = jnp.array(
|
276
|
+
model.kin_dyn_parameters.joint_parameters.position_limit_damper
|
277
|
+
).astype(float)
|
278
|
+
|
279
|
+
# Compute the joint position limit violations.
|
280
|
+
lower_violation = jnp.clip(
|
281
|
+
data.state.physics_model.joint_positions
|
282
|
+
- model.kin_dyn_parameters.joint_parameters.position_limits_min,
|
283
|
+
max=0.0,
|
284
|
+
)
|
285
|
+
|
286
|
+
upper_violation = jnp.clip(
|
287
|
+
data.state.physics_model.joint_positions
|
288
|
+
- model.kin_dyn_parameters.joint_parameters.position_limits_max,
|
289
|
+
min=0.0,
|
290
|
+
)
|
291
|
+
|
292
|
+
# Compute the joint position limit torque.
|
293
|
+
τ_position_limit -= jnp.diag(k_j) @ (lower_violation + upper_violation)
|
294
|
+
|
295
|
+
τ_position_limit -= (
|
296
|
+
jnp.positive(τ_position_limit)
|
297
|
+
* jnp.diag(d_j)
|
298
|
+
@ data.state.physics_model.joint_velocities
|
299
|
+
)
|
158
300
|
|
159
301
|
# ====================
|
160
302
|
# Joint friction model
|
161
303
|
# ====================
|
162
304
|
|
163
|
-
τ_friction = jnp.zeros_like(τ).astype(float)
|
305
|
+
τ_friction = jnp.zeros_like(τ_references).astype(float)
|
164
306
|
|
165
307
|
if model.dofs() > 0:
|
166
|
-
# Static and viscous joint friction parameters
|
167
|
-
kc = jnp.array(list(model.physics_model._joint_friction_static.values()))
|
168
|
-
kv = jnp.array(list(model.physics_model._joint_friction_viscous.values()))
|
169
308
|
|
170
|
-
#
|
309
|
+
# Static and viscous joint friction parameters
|
310
|
+
kc = jnp.array(
|
311
|
+
model.kin_dyn_parameters.joint_parameters.friction_static
|
312
|
+
).astype(float)
|
313
|
+
kv = jnp.array(
|
314
|
+
model.kin_dyn_parameters.joint_parameters.friction_viscous
|
315
|
+
).astype(float)
|
316
|
+
|
317
|
+
# Compute the joint friction torque.
|
171
318
|
τ_friction = -(
|
172
|
-
jnp.diag(kc) @ jnp.sign(data.state.physics_model.
|
319
|
+
jnp.diag(kc) @ jnp.sign(data.state.physics_model.joint_velocities)
|
173
320
|
+ jnp.diag(kv) @ data.state.physics_model.joint_velocities
|
174
321
|
)
|
175
322
|
|
@@ -177,28 +324,40 @@ def system_velocity_dynamics(
|
|
177
324
|
# Compute forward dynamics
|
178
325
|
# ========================
|
179
326
|
|
180
|
-
# Compute the total joint forces
|
181
|
-
τ_total = τ + τ_friction + τ_position_limit
|
327
|
+
# Compute the total joint forces.
|
328
|
+
τ_total = τ_references + τ_friction + τ_position_limit
|
182
329
|
|
183
|
-
#
|
184
|
-
|
330
|
+
# Store the link forces in a references object.
|
331
|
+
references = js.references.JaxSimModelReferences.build(
|
332
|
+
model=model,
|
333
|
+
data=data,
|
334
|
+
velocity_representation=data.velocity_representation,
|
335
|
+
link_forces=f_L,
|
336
|
+
)
|
185
337
|
|
338
|
+
# Compute forward dynamics.
|
339
|
+
#
|
186
340
|
# - Joint accelerations: s̈ ∈ ℝⁿ
|
187
|
-
# - Base
|
188
|
-
|
189
|
-
|
190
|
-
|
191
|
-
|
192
|
-
|
193
|
-
|
194
|
-
|
341
|
+
# - Base acceleration: v̇_WB ∈ ℝ⁶
|
342
|
+
#
|
343
|
+
# Note that ABA returns the base acceleration in the velocity representation
|
344
|
+
# stored in the `data` object.
|
345
|
+
v̇_WB, s̈ = js.model.forward_dynamics_aba(
|
346
|
+
model=model,
|
347
|
+
data=data,
|
348
|
+
joint_forces=τ_total,
|
349
|
+
link_forces=references.link_forces(model=model, data=data),
|
350
|
+
)
|
195
351
|
|
196
|
-
return
|
352
|
+
return v̇_WB, s̈
|
197
353
|
|
198
354
|
|
199
355
|
@jax.jit
|
356
|
+
@js.common.named_scope
|
200
357
|
def system_position_dynamics(
|
201
|
-
model:
|
358
|
+
model: js.model.JaxSimModel,
|
359
|
+
data: js.data.JaxSimModelData,
|
360
|
+
baumgarte_quaternion_regularization: jtp.FloatLike = 1.0,
|
202
361
|
) -> tuple[jtp.Vector, jtp.Vector, jtp.Vector]:
|
203
362
|
"""
|
204
363
|
Compute the dynamics of the system position.
|
@@ -206,14 +365,16 @@ def system_position_dynamics(
|
|
206
365
|
Args:
|
207
366
|
model: The model to consider.
|
208
367
|
data: The data of the considered model.
|
368
|
+
baumgarte_quaternion_regularization:
|
369
|
+
The Baumgarte regularization coefficient for adjusting the quaternion norm.
|
209
370
|
|
210
371
|
Returns:
|
211
372
|
A tuple containing the derivative of the base position, the derivative of the
|
212
373
|
base quaternion, and the derivative of the joint positions.
|
213
374
|
"""
|
214
375
|
|
215
|
-
ṡ = data.
|
216
|
-
W_Q_B = data.
|
376
|
+
ṡ = data.joint_velocities(model=model)
|
377
|
+
W_Q_B = data.base_orientation(dcm=False)
|
217
378
|
|
218
379
|
with data.switch_velocity_representation(velocity_representation=VelRepr.Mixed):
|
219
380
|
W_ṗ_B = data.base_velocity()[0:3]
|
@@ -225,18 +386,21 @@ def system_position_dynamics(
|
|
225
386
|
quaternion=W_Q_B,
|
226
387
|
omega=W_ω_WB,
|
227
388
|
omega_in_body_fixed=False,
|
389
|
+
K=baumgarte_quaternion_regularization,
|
228
390
|
).squeeze()
|
229
391
|
|
230
392
|
return W_ṗ_B, W_Q̇_B, ṡ
|
231
393
|
|
232
394
|
|
233
395
|
@jax.jit
|
396
|
+
@js.common.named_scope
|
234
397
|
def system_dynamics(
|
235
|
-
model:
|
236
|
-
data:
|
398
|
+
model: js.model.JaxSimModel,
|
399
|
+
data: js.data.JaxSimModelData,
|
237
400
|
*,
|
238
|
-
|
239
|
-
|
401
|
+
link_forces: jtp.Vector | None = None,
|
402
|
+
joint_force_references: jtp.Vector | None = None,
|
403
|
+
baumgarte_quaternion_regularization: jtp.FloatLike = 1.0,
|
240
404
|
) -> tuple[ODEState, dict[str, Any]]:
|
241
405
|
"""
|
242
406
|
Compute the dynamics of the system.
|
@@ -244,8 +408,13 @@ def system_dynamics(
|
|
244
408
|
Args:
|
245
409
|
model: The model to consider.
|
246
410
|
data: The data of the considered model.
|
247
|
-
|
248
|
-
|
411
|
+
link_forces:
|
412
|
+
The 6D forces to apply to the links expressed in the frame corresponding to
|
413
|
+
the velocity representation of `data`.
|
414
|
+
joint_force_references: The joint force references to apply.
|
415
|
+
baumgarte_quaternion_regularization:
|
416
|
+
The Baumgarte regularization coefficient used to adjust the norm of the
|
417
|
+
quaternion (only used in integrators not operating on the SO(3) manifold).
|
249
418
|
|
250
419
|
Returns:
|
251
420
|
A tuple with an `ODEState` object storing in each of its attributes the
|
@@ -254,31 +423,53 @@ def system_dynamics(
|
|
254
423
|
"""
|
255
424
|
|
256
425
|
# Compute the accelerations and the material deformation rate.
|
257
|
-
W_v̇_WB, s̈,
|
426
|
+
W_v̇_WB, s̈, aux_dict = system_velocity_dynamics(
|
258
427
|
model=model,
|
259
428
|
data=data,
|
260
|
-
|
261
|
-
|
429
|
+
joint_force_references=joint_force_references,
|
430
|
+
link_forces=link_forces,
|
262
431
|
)
|
263
432
|
|
433
|
+
# Initialize the dictionary storing the derivative of the additional state variables
|
434
|
+
# that extend the state vector of the integrated ODE system.
|
435
|
+
extended_ode_state = {}
|
436
|
+
|
437
|
+
match model.contact_model:
|
438
|
+
|
439
|
+
case contacts.SoftContacts():
|
440
|
+
extended_ode_state["tangential_deformation"] = aux_dict["m_dot"]
|
441
|
+
|
442
|
+
case contacts.ViscoElasticContacts():
|
443
|
+
|
444
|
+
extended_ode_state["tangential_deformation"] = jnp.zeros_like(
|
445
|
+
data.state.extended["tangential_deformation"]
|
446
|
+
)
|
447
|
+
|
448
|
+
case contacts.RigidContacts() | contacts.RelaxedRigidContacts():
|
449
|
+
pass
|
450
|
+
|
451
|
+
case _:
|
452
|
+
raise ValueError(f"Invalid contact model: {model.contact_model}")
|
453
|
+
|
264
454
|
# Extract the velocities.
|
265
|
-
W_ṗ_B, W_Q̇_B, ṡ = system_position_dynamics(
|
455
|
+
W_ṗ_B, W_Q̇_B, ṡ = system_position_dynamics(
|
456
|
+
model=model,
|
457
|
+
data=data,
|
458
|
+
baumgarte_quaternion_regularization=baumgarte_quaternion_regularization,
|
459
|
+
)
|
266
460
|
|
267
461
|
# Create an ODEState object populated with the derivative of each leaf.
|
268
462
|
# Our integrators, operating on generic pytrees, will be able to handle it
|
269
463
|
# automatically as state derivative.
|
270
|
-
ode_state_derivative = ODEState.
|
271
|
-
|
272
|
-
|
273
|
-
|
274
|
-
|
275
|
-
|
276
|
-
|
277
|
-
|
278
|
-
|
279
|
-
soft_contacts_state=SoftContactsState.build(
|
280
|
-
tangential_deformation=ṁ,
|
281
|
-
),
|
464
|
+
ode_state_derivative = ODEState.build_from_jaxsim_model(
|
465
|
+
model=model,
|
466
|
+
base_position=W_ṗ_B,
|
467
|
+
base_quaternion=W_Q̇_B,
|
468
|
+
joint_positions=ṡ,
|
469
|
+
base_linear_velocity=W_v̇_WB[0:3],
|
470
|
+
base_angular_velocity=W_v̇_WB[3:6],
|
471
|
+
joint_velocities=s̈,
|
472
|
+
**extended_ode_state,
|
282
473
|
)
|
283
474
|
|
284
475
|
return ode_state_derivative, aux_dict
|