jaxsim 0.2.dev188__py3-none-any.whl → 0.6.1.dev2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jaxsim/__init__.py +73 -22
- jaxsim/_version.py +2 -2
- jaxsim/api/__init__.py +13 -1
- jaxsim/api/com.py +423 -0
- jaxsim/api/common.py +48 -19
- jaxsim/api/contact.py +604 -52
- jaxsim/api/data.py +308 -163
- jaxsim/api/frame.py +471 -0
- jaxsim/api/joint.py +166 -37
- jaxsim/api/kin_dyn_parameters.py +901 -0
- jaxsim/api/link.py +277 -78
- jaxsim/api/model.py +1572 -362
- jaxsim/api/ode.py +324 -129
- jaxsim/api/ode_data.py +401 -0
- jaxsim/api/references.py +216 -80
- jaxsim/exceptions.py +80 -0
- jaxsim/integrators/__init__.py +2 -2
- jaxsim/integrators/common.py +191 -107
- jaxsim/integrators/fixed_step.py +97 -102
- jaxsim/integrators/variable_step.py +706 -0
- jaxsim/logging.py +1 -2
- jaxsim/math/__init__.py +13 -0
- jaxsim/math/adjoint.py +57 -22
- jaxsim/math/cross.py +16 -7
- jaxsim/math/inertia.py +10 -8
- jaxsim/math/joint_model.py +289 -0
- jaxsim/math/quaternion.py +54 -20
- jaxsim/math/rotation.py +27 -21
- jaxsim/math/skew.py +16 -5
- jaxsim/math/transform.py +102 -0
- jaxsim/math/utils.py +31 -0
- jaxsim/mujoco/__init__.py +2 -1
- jaxsim/mujoco/loaders.py +216 -29
- jaxsim/mujoco/model.py +163 -33
- jaxsim/mujoco/utils.py +228 -0
- jaxsim/mujoco/visualizer.py +107 -22
- jaxsim/parsers/__init__.py +0 -1
- jaxsim/parsers/descriptions/__init__.py +8 -2
- jaxsim/parsers/descriptions/collision.py +87 -16
- jaxsim/parsers/descriptions/joint.py +80 -87
- jaxsim/parsers/descriptions/link.py +62 -24
- jaxsim/parsers/descriptions/model.py +101 -68
- jaxsim/parsers/kinematic_graph.py +607 -225
- jaxsim/parsers/rod/meshes.py +104 -0
- jaxsim/parsers/rod/parser.py +125 -82
- jaxsim/parsers/rod/utils.py +127 -82
- jaxsim/rbda/__init__.py +11 -0
- jaxsim/rbda/aba.py +289 -0
- jaxsim/rbda/collidable_points.py +156 -0
- jaxsim/rbda/contacts/__init__.py +13 -0
- jaxsim/rbda/contacts/common.py +313 -0
- jaxsim/rbda/contacts/relaxed_rigid.py +605 -0
- jaxsim/rbda/contacts/rigid.py +462 -0
- jaxsim/rbda/contacts/soft.py +480 -0
- jaxsim/rbda/contacts/visco_elastic.py +1066 -0
- jaxsim/rbda/crba.py +167 -0
- jaxsim/rbda/forward_kinematics.py +117 -0
- jaxsim/rbda/jacobian.py +330 -0
- jaxsim/rbda/rnea.py +235 -0
- jaxsim/rbda/utils.py +160 -0
- jaxsim/terrain/__init__.py +2 -0
- jaxsim/terrain/terrain.py +238 -0
- jaxsim/typing.py +24 -24
- jaxsim/utils/__init__.py +1 -4
- jaxsim/utils/jaxsim_dataclass.py +289 -34
- jaxsim/utils/tracing.py +5 -11
- jaxsim/utils/wrappers.py +159 -0
- {jaxsim-0.2.dev188.dist-info → jaxsim-0.6.1.dev2.dist-info}/LICENSE +1 -1
- jaxsim-0.6.1.dev2.dist-info/METADATA +465 -0
- jaxsim-0.6.1.dev2.dist-info/RECORD +74 -0
- {jaxsim-0.2.dev188.dist-info → jaxsim-0.6.1.dev2.dist-info}/WHEEL +1 -1
- jaxsim/high_level/__init__.py +0 -2
- jaxsim/high_level/common.py +0 -11
- jaxsim/high_level/joint.py +0 -148
- jaxsim/high_level/link.py +0 -259
- jaxsim/high_level/model.py +0 -1686
- jaxsim/math/conv.py +0 -114
- jaxsim/math/joint.py +0 -102
- jaxsim/math/plucker.py +0 -100
- jaxsim/physics/__init__.py +0 -12
- jaxsim/physics/algos/__init__.py +0 -0
- jaxsim/physics/algos/aba.py +0 -254
- jaxsim/physics/algos/aba_motors.py +0 -284
- jaxsim/physics/algos/crba.py +0 -154
- jaxsim/physics/algos/forward_kinematics.py +0 -79
- jaxsim/physics/algos/jacobian.py +0 -98
- jaxsim/physics/algos/rnea.py +0 -180
- jaxsim/physics/algos/rnea_motors.py +0 -196
- jaxsim/physics/algos/soft_contacts.py +0 -523
- jaxsim/physics/algos/terrain.py +0 -80
- jaxsim/physics/algos/utils.py +0 -69
- jaxsim/physics/model/__init__.py +0 -0
- jaxsim/physics/model/ground_contact.py +0 -55
- jaxsim/physics/model/physics_model.py +0 -388
- jaxsim/physics/model/physics_model_state.py +0 -283
- jaxsim/simulation/__init__.py +0 -4
- jaxsim/simulation/integrators.py +0 -393
- jaxsim/simulation/ode.py +0 -290
- jaxsim/simulation/ode_data.py +0 -96
- jaxsim/simulation/ode_integration.py +0 -62
- jaxsim/simulation/simulator.py +0 -543
- jaxsim/simulation/simulator_callbacks.py +0 -79
- jaxsim/simulation/utils.py +0 -15
- jaxsim/sixd/__init__.py +0 -2
- jaxsim/utils/oop.py +0 -536
- jaxsim/utils/vmappable.py +0 -117
- jaxsim-0.2.dev188.dist-info/METADATA +0 -184
- jaxsim-0.2.dev188.dist-info/RECORD +0 -81
- {jaxsim-0.2.dev188.dist-info → jaxsim-0.6.1.dev2.dist-info}/top_level.txt +0 -0
jaxsim/api/ode.py
CHANGED
@@ -2,64 +2,99 @@ from typing import Any, Protocol
|
|
2
2
|
|
3
3
|
import jax
|
4
4
|
import jax.numpy as jnp
|
5
|
-
import jaxlie
|
6
5
|
|
7
|
-
import jaxsim.
|
6
|
+
import jaxsim.api as js
|
7
|
+
import jaxsim.rbda
|
8
8
|
import jaxsim.typing as jtp
|
9
|
-
from jaxsim import
|
10
|
-
from jaxsim.
|
11
|
-
from jaxsim.
|
12
|
-
from jaxsim.physics.algos.soft_contacts import SoftContactsState
|
13
|
-
from jaxsim.physics.model.physics_model_state import PhysicsModelState
|
14
|
-
from jaxsim.simulation.ode_data import ODEState
|
9
|
+
from jaxsim.integrators import Time
|
10
|
+
from jaxsim.math import Quaternion
|
11
|
+
from jaxsim.rbda import contacts
|
15
12
|
|
16
|
-
from . import
|
17
|
-
from . import
|
18
|
-
from . import model as Model
|
13
|
+
from .common import VelRepr
|
14
|
+
from .ode_data import ODEState
|
19
15
|
|
20
16
|
|
21
17
|
class SystemDynamicsFromModelAndData(Protocol):
|
18
|
+
"""
|
19
|
+
Protocol defining the signature of a function computing the system dynamics
|
20
|
+
given a model and data object.
|
21
|
+
"""
|
22
|
+
|
22
23
|
def __call__(
|
23
24
|
self,
|
24
|
-
model:
|
25
|
-
data:
|
25
|
+
model: js.model.JaxSimModel,
|
26
|
+
data: js.data.JaxSimModelData,
|
26
27
|
**kwargs: dict[str, Any],
|
27
|
-
) -> tuple[ODEState, dict[str, Any]]:
|
28
|
+
) -> tuple[ODEState, dict[str, Any]]:
|
29
|
+
"""
|
30
|
+
Compute the system dynamics given a model and data object.
|
31
|
+
|
32
|
+
Args:
|
33
|
+
model: The model to consider.
|
34
|
+
data: The data of the considered model.
|
35
|
+
**kwargs: Additional keyword arguments.
|
36
|
+
|
37
|
+
Returns:
|
38
|
+
A tuple with an `ODEState` object storing in each of its attributes the
|
39
|
+
corresponding derivative, and the dictionary of auxiliary data returned
|
40
|
+
by the system dynamics evaluation.
|
41
|
+
"""
|
42
|
+
|
43
|
+
pass
|
28
44
|
|
29
45
|
|
30
46
|
def wrap_system_dynamics_for_integration(
|
31
|
-
model: Model.JaxSimModel,
|
32
|
-
data: Data.JaxSimModelData,
|
33
47
|
*,
|
34
48
|
system_dynamics: SystemDynamicsFromModelAndData,
|
35
|
-
**kwargs,
|
49
|
+
**kwargs: dict[str, Any],
|
36
50
|
) -> jaxsim.integrators.common.SystemDynamics[ODEState, ODEState]:
|
37
51
|
"""
|
38
|
-
Wrap
|
39
|
-
|
52
|
+
Wrap the system dynamics considered by JaxSim integrators in a generic
|
53
|
+
`f(x, t, **u, **parameters)` function.
|
40
54
|
|
41
55
|
Args:
|
42
|
-
model: The model to consider.
|
43
|
-
data: The data of the considered model.
|
44
56
|
system_dynamics: The system dynamics to wrap.
|
45
57
|
**kwargs: Additional kwargs to close over the system dynamics.
|
46
58
|
|
47
59
|
Returns:
|
48
|
-
The system dynamics closed over the
|
60
|
+
The system dynamics closed over the additional kwargs to be used by
|
61
|
+
JaxSim integrators.
|
49
62
|
"""
|
50
63
|
|
51
|
-
#
|
52
|
-
|
53
|
-
|
54
|
-
|
55
|
-
|
56
|
-
|
57
|
-
|
64
|
+
# Close `system_dynamics` over additional kwargs.
|
65
|
+
# Similarly to what done in `jaxsim.api.model.step`, to be future-proof, we use the
|
66
|
+
# following logic to allow the caller to close over arguments having the same name
|
67
|
+
# of the ones used in the `wrap_system_dynamics_for_integration` function.
|
68
|
+
kwargs = kwargs.copy() if kwargs is not None else {}
|
69
|
+
colliding_system_dynamics_kwargs = kwargs.pop("system_dynamics_kwargs", {})
|
70
|
+
system_dynamics_kwargs = kwargs | colliding_system_dynamics_kwargs
|
71
|
+
|
72
|
+
# Remove `model` and `data` for backward compatibility.
|
73
|
+
# It's no longer necessary to close over them at this stage, as this is always
|
74
|
+
# done in `jaxsim.api.model.step`.
|
75
|
+
# We can remove the following lines in a few releases.
|
76
|
+
_ = system_dynamics_kwargs.pop("data", None)
|
77
|
+
_ = system_dynamics_kwargs.pop("model", None)
|
78
|
+
|
79
|
+
# Create the function with the signature expected by our generic integrators.
|
80
|
+
# Note that our system dynamics is time independent.
|
81
|
+
def f(x: ODEState, t: Time, **kwargs_f) -> tuple[ODEState, dict[str, Any]]:
|
82
|
+
|
83
|
+
# Get the data and model objects from the kwargs.
|
84
|
+
data_f = kwargs_f.pop("data")
|
85
|
+
model_f = kwargs_f.pop("model")
|
86
|
+
|
87
|
+
# Update the state and time stored inside data.
|
88
|
+
with data_f.editable(validate=True) as data_rw:
|
58
89
|
data_rw.state = x
|
59
|
-
data_rw.time_ns = jnp.array(t * 1e9).astype(jnp.uint64)
|
60
90
|
|
61
|
-
#
|
62
|
-
|
91
|
+
# Evaluate the system dynamics, allowing to override the kwargs originally
|
92
|
+
# passed when the closure was created.
|
93
|
+
return system_dynamics(
|
94
|
+
model=model_f,
|
95
|
+
data=data_rw,
|
96
|
+
**(system_dynamics_kwargs | kwargs_f),
|
97
|
+
)
|
63
98
|
|
64
99
|
f: jaxsim.integrators.common.SystemDynamics[ODEState, ODEState]
|
65
100
|
return f
|
@@ -71,101 +106,217 @@ def wrap_system_dynamics_for_integration(
|
|
71
106
|
|
72
107
|
|
73
108
|
@jax.jit
|
109
|
+
@js.common.named_scope
|
74
110
|
def system_velocity_dynamics(
|
75
|
-
model:
|
76
|
-
data:
|
111
|
+
model: js.model.JaxSimModel,
|
112
|
+
data: js.data.JaxSimModelData,
|
77
113
|
*,
|
78
|
-
|
79
|
-
|
80
|
-
) -> tuple[jtp.Vector, jtp.Vector,
|
114
|
+
link_forces: jtp.Vector | None = None,
|
115
|
+
joint_force_references: jtp.Vector | None = None,
|
116
|
+
) -> tuple[jtp.Vector, jtp.Vector, dict[str, Any]]:
|
81
117
|
"""
|
82
118
|
Compute the dynamics of the system velocity.
|
83
119
|
|
84
120
|
Args:
|
85
121
|
model: The model to consider.
|
86
122
|
data: The data of the considered model.
|
87
|
-
|
88
|
-
|
123
|
+
link_forces:
|
124
|
+
The 6D forces to apply to the links expressed in the frame corresponding to
|
125
|
+
the velocity representation of `data`.
|
126
|
+
joint_force_references: The joint force references to apply.
|
89
127
|
|
90
128
|
Returns:
|
91
129
|
A tuple containing the derivative of the base 6D velocity in inertial-fixed
|
92
|
-
representation, the derivative of the joint velocities,
|
93
|
-
|
94
|
-
the system dynamics evalutation.
|
130
|
+
representation, the derivative of the joint velocities, and auxiliary data
|
131
|
+
returned by the system dynamics evaluation.
|
95
132
|
"""
|
96
133
|
|
97
|
-
# Build
|
98
|
-
|
99
|
-
|
100
|
-
|
101
|
-
|
102
|
-
|
103
|
-
|
104
|
-
# Build external forces if not provided
|
105
|
-
f_ext = (
|
106
|
-
jnp.atleast_2d(external_forces.squeeze())
|
107
|
-
if external_forces is not None
|
134
|
+
# Build link forces if not provided.
|
135
|
+
# These forces are expressed in the frame corresponding to the velocity
|
136
|
+
# representation of data.
|
137
|
+
O_f_L = (
|
138
|
+
jnp.atleast_2d(link_forces.squeeze())
|
139
|
+
if link_forces is not None
|
108
140
|
else jnp.zeros((model.number_of_links(), 6))
|
109
141
|
).astype(float)
|
110
142
|
|
143
|
+
# We expect that the 6D forces included in the `link_forces` argument are expressed
|
144
|
+
# in the frame corresponding to the velocity representation of `data`.
|
145
|
+
references = js.references.JaxSimModelReferences.build(
|
146
|
+
model=model,
|
147
|
+
link_forces=O_f_L,
|
148
|
+
joint_force_references=joint_force_references,
|
149
|
+
data=data,
|
150
|
+
velocity_representation=data.velocity_representation,
|
151
|
+
)
|
152
|
+
|
111
153
|
# ======================
|
112
154
|
# Compute contact forces
|
113
155
|
# ======================
|
114
156
|
|
115
157
|
# Initialize the 6D forces W_f ∈ ℝ^{n_L × 6} applied to links due to contact
|
116
158
|
# with the terrain.
|
117
|
-
|
118
|
-
|
119
|
-
# Initialize
|
120
|
-
#
|
121
|
-
|
122
|
-
|
123
|
-
|
124
|
-
|
125
|
-
|
126
|
-
|
127
|
-
|
128
|
-
|
129
|
-
|
130
|
-
|
131
|
-
|
132
|
-
|
133
|
-
|
134
|
-
|
135
|
-
|
136
|
-
|
137
|
-
|
138
|
-
|
139
|
-
|
140
|
-
|
141
|
-
|
142
|
-
|
143
|
-
|
144
|
-
|
145
|
-
)
|
146
|
-
|
159
|
+
W_f_L_terrain = jnp.zeros_like(O_f_L).astype(float)
|
160
|
+
|
161
|
+
# Initialize a dictionary of auxiliary data.
|
162
|
+
# This dictionary is used to store additional data computed by the contact model.
|
163
|
+
aux_data = {}
|
164
|
+
|
165
|
+
if len(model.kin_dyn_parameters.contact_parameters.body) > 0:
|
166
|
+
|
167
|
+
with (
|
168
|
+
data.switch_velocity_representation(VelRepr.Inertial),
|
169
|
+
references.switch_velocity_representation(VelRepr.Inertial),
|
170
|
+
):
|
171
|
+
|
172
|
+
# Compute the 6D forces W_f ∈ ℝ^{n_c × 6} applied to each collidable point
|
173
|
+
# along with contact-specific auxiliary states.
|
174
|
+
W_f_C, aux_data = js.contact.collidable_point_dynamics(
|
175
|
+
model=model,
|
176
|
+
data=data,
|
177
|
+
link_forces=references.link_forces(model=model, data=data),
|
178
|
+
joint_force_references=references.joint_force_references(model=model),
|
179
|
+
)
|
180
|
+
|
181
|
+
# Compute the 6D forces applied to the links equivalent to the forces applied
|
182
|
+
# to the frames associated to the collidable points.
|
183
|
+
W_f_L_terrain = model.contact_model.link_forces_from_contact_forces(
|
184
|
+
model=model,
|
185
|
+
data=data,
|
186
|
+
contact_forces=W_f_C,
|
187
|
+
)
|
188
|
+
|
189
|
+
# ===========================
|
190
|
+
# Compute system acceleration
|
191
|
+
# ===========================
|
192
|
+
|
193
|
+
# Compute the total link forces.
|
194
|
+
with (
|
195
|
+
data.switch_velocity_representation(VelRepr.Inertial),
|
196
|
+
references.switch_velocity_representation(VelRepr.Inertial),
|
197
|
+
):
|
198
|
+
|
199
|
+
# Sum the contact forces just computed with the link forces applied by the user.
|
200
|
+
references = references.apply_link_forces(
|
201
|
+
model=model,
|
202
|
+
data=data,
|
203
|
+
forces=W_f_L_terrain,
|
204
|
+
additive=True,
|
205
|
+
)
|
206
|
+
|
207
|
+
# Get the link forces in inertial-fixed representation.
|
208
|
+
f_L_total = references.link_forces(model=model, data=data)
|
209
|
+
|
210
|
+
# Compute the system acceleration in inertial-fixed representation.
|
211
|
+
# This representation is useful for integration purpose.
|
212
|
+
W_v̇_WB, s̈ = system_acceleration(
|
213
|
+
model=model,
|
214
|
+
data=data,
|
215
|
+
joint_force_references=joint_force_references,
|
216
|
+
link_forces=f_L_total,
|
217
|
+
)
|
218
|
+
|
219
|
+
return W_v̇_WB, s̈, aux_data
|
220
|
+
|
221
|
+
|
222
|
+
def system_acceleration(
|
223
|
+
model: js.model.JaxSimModel,
|
224
|
+
data: js.data.JaxSimModelData,
|
225
|
+
*,
|
226
|
+
link_forces: jtp.MatrixLike | None = None,
|
227
|
+
joint_force_references: jtp.VectorLike | None = None,
|
228
|
+
) -> tuple[jtp.Vector, jtp.Vector]:
|
229
|
+
"""
|
230
|
+
Compute the system acceleration in the active representation.
|
231
|
+
|
232
|
+
Args:
|
233
|
+
model: The model to consider.
|
234
|
+
data: The data of the considered model.
|
235
|
+
link_forces:
|
236
|
+
The 6D forces to apply to the links expressed in the same
|
237
|
+
velocity representation of data.
|
238
|
+
joint_force_references: The joint force references to apply.
|
239
|
+
|
240
|
+
Returns:
|
241
|
+
A tuple containing the base 6D acceleration in the active representation
|
242
|
+
and the joint accelerations.
|
243
|
+
"""
|
244
|
+
|
245
|
+
# ====================
|
246
|
+
# Validate input data
|
247
|
+
# ====================
|
248
|
+
|
249
|
+
# Build link forces if not provided.
|
250
|
+
f_L = (
|
251
|
+
jnp.atleast_2d(link_forces.squeeze())
|
252
|
+
if link_forces is not None
|
253
|
+
else jnp.zeros((model.number_of_links(), 6))
|
254
|
+
).astype(float)
|
255
|
+
|
256
|
+
# Build joint torques if not provided.
|
257
|
+
τ_references = (
|
258
|
+
jnp.atleast_1d(joint_force_references.squeeze())
|
259
|
+
if joint_force_references is not None
|
260
|
+
else jnp.zeros_like(data.joint_positions())
|
261
|
+
).astype(float)
|
147
262
|
|
148
263
|
# ====================
|
149
264
|
# Enforce joint limits
|
150
265
|
# ====================
|
151
266
|
|
152
|
-
|
153
|
-
|
267
|
+
τ_position_limit = jnp.zeros_like(τ_references).astype(float)
|
268
|
+
|
269
|
+
if model.dofs() > 0:
|
270
|
+
|
271
|
+
# Stiffness and damper parameters for the joint position limits.
|
272
|
+
k_j = jnp.array(
|
273
|
+
model.kin_dyn_parameters.joint_parameters.position_limit_spring
|
274
|
+
).astype(float)
|
275
|
+
d_j = jnp.array(
|
276
|
+
model.kin_dyn_parameters.joint_parameters.position_limit_damper
|
277
|
+
).astype(float)
|
278
|
+
|
279
|
+
# Compute the joint position limit violations.
|
280
|
+
lower_violation = jnp.clip(
|
281
|
+
data.state.physics_model.joint_positions
|
282
|
+
- model.kin_dyn_parameters.joint_parameters.position_limits_min,
|
283
|
+
max=0.0,
|
284
|
+
)
|
285
|
+
|
286
|
+
upper_violation = jnp.clip(
|
287
|
+
data.state.physics_model.joint_positions
|
288
|
+
- model.kin_dyn_parameters.joint_parameters.position_limits_max,
|
289
|
+
min=0.0,
|
290
|
+
)
|
291
|
+
|
292
|
+
# Compute the joint position limit torque.
|
293
|
+
τ_position_limit -= jnp.diag(k_j) @ (lower_violation + upper_violation)
|
294
|
+
|
295
|
+
τ_position_limit -= (
|
296
|
+
jnp.positive(τ_position_limit)
|
297
|
+
* jnp.diag(d_j)
|
298
|
+
@ data.state.physics_model.joint_velocities
|
299
|
+
)
|
154
300
|
|
155
301
|
# ====================
|
156
302
|
# Joint friction model
|
157
303
|
# ====================
|
158
304
|
|
159
|
-
τ_friction = jnp.zeros_like(τ).astype(float)
|
305
|
+
τ_friction = jnp.zeros_like(τ_references).astype(float)
|
160
306
|
|
161
307
|
if model.dofs() > 0:
|
162
|
-
# Static and viscous joint friction parameters
|
163
|
-
kc = jnp.array(list(model.physics_model._joint_friction_static.values()))
|
164
|
-
kv = jnp.array(list(model.physics_model._joint_friction_viscous.values()))
|
165
308
|
|
166
|
-
#
|
309
|
+
# Static and viscous joint friction parameters
|
310
|
+
kc = jnp.array(
|
311
|
+
model.kin_dyn_parameters.joint_parameters.friction_static
|
312
|
+
).astype(float)
|
313
|
+
kv = jnp.array(
|
314
|
+
model.kin_dyn_parameters.joint_parameters.friction_viscous
|
315
|
+
).astype(float)
|
316
|
+
|
317
|
+
# Compute the joint friction torque.
|
167
318
|
τ_friction = -(
|
168
|
-
jnp.diag(kc) @ jnp.sign(data.state.physics_model.
|
319
|
+
jnp.diag(kc) @ jnp.sign(data.state.physics_model.joint_velocities)
|
169
320
|
+ jnp.diag(kv) @ data.state.physics_model.joint_velocities
|
170
321
|
)
|
171
322
|
|
@@ -173,28 +324,40 @@ def system_velocity_dynamics(
|
|
173
324
|
# Compute forward dynamics
|
174
325
|
# ========================
|
175
326
|
|
176
|
-
# Compute the total joint forces
|
177
|
-
τ_total = τ + τ_friction + τ_position_limit
|
327
|
+
# Compute the total joint forces.
|
328
|
+
τ_total = τ_references + τ_friction + τ_position_limit
|
178
329
|
|
179
|
-
#
|
180
|
-
|
330
|
+
# Store the link forces in a references object.
|
331
|
+
references = js.references.JaxSimModelReferences.build(
|
332
|
+
model=model,
|
333
|
+
data=data,
|
334
|
+
velocity_representation=data.velocity_representation,
|
335
|
+
link_forces=f_L,
|
336
|
+
)
|
181
337
|
|
338
|
+
# Compute forward dynamics.
|
339
|
+
#
|
182
340
|
# - Joint accelerations: s̈ ∈ ℝⁿ
|
183
|
-
# - Base
|
184
|
-
|
185
|
-
|
186
|
-
|
187
|
-
|
188
|
-
|
189
|
-
|
190
|
-
|
341
|
+
# - Base acceleration: v̇_WB ∈ ℝ⁶
|
342
|
+
#
|
343
|
+
# Note that ABA returns the base acceleration in the velocity representation
|
344
|
+
# stored in the `data` object.
|
345
|
+
v̇_WB, s̈ = js.model.forward_dynamics_aba(
|
346
|
+
model=model,
|
347
|
+
data=data,
|
348
|
+
joint_forces=τ_total,
|
349
|
+
link_forces=references.link_forces(model=model, data=data),
|
350
|
+
)
|
191
351
|
|
192
|
-
return
|
352
|
+
return v̇_WB, s̈
|
193
353
|
|
194
354
|
|
195
355
|
@jax.jit
|
356
|
+
@js.common.named_scope
|
196
357
|
def system_position_dynamics(
|
197
|
-
model:
|
358
|
+
model: js.model.JaxSimModel,
|
359
|
+
data: js.data.JaxSimModelData,
|
360
|
+
baumgarte_quaternion_regularization: jtp.FloatLike = 1.0,
|
198
361
|
) -> tuple[jtp.Vector, jtp.Vector, jtp.Vector]:
|
199
362
|
"""
|
200
363
|
Compute the dynamics of the system position.
|
@@ -202,14 +365,16 @@ def system_position_dynamics(
|
|
202
365
|
Args:
|
203
366
|
model: The model to consider.
|
204
367
|
data: The data of the considered model.
|
368
|
+
baumgarte_quaternion_regularization:
|
369
|
+
The Baumgarte regularization coefficient for adjusting the quaternion norm.
|
205
370
|
|
206
371
|
Returns:
|
207
372
|
A tuple containing the derivative of the base position, the derivative of the
|
208
373
|
base quaternion, and the derivative of the joint positions.
|
209
374
|
"""
|
210
375
|
|
211
|
-
ṡ = data.
|
212
|
-
W_Q_B = data.
|
376
|
+
ṡ = data.joint_velocities(model=model)
|
377
|
+
W_Q_B = data.base_orientation(dcm=False)
|
213
378
|
|
214
379
|
with data.switch_velocity_representation(velocity_representation=VelRepr.Mixed):
|
215
380
|
W_ṗ_B = data.base_velocity()[0:3]
|
@@ -221,18 +386,21 @@ def system_position_dynamics(
|
|
221
386
|
quaternion=W_Q_B,
|
222
387
|
omega=W_ω_WB,
|
223
388
|
omega_in_body_fixed=False,
|
389
|
+
K=baumgarte_quaternion_regularization,
|
224
390
|
).squeeze()
|
225
391
|
|
226
392
|
return W_ṗ_B, W_Q̇_B, ṡ
|
227
393
|
|
228
394
|
|
229
395
|
@jax.jit
|
396
|
+
@js.common.named_scope
|
230
397
|
def system_dynamics(
|
231
|
-
model:
|
232
|
-
data:
|
398
|
+
model: js.model.JaxSimModel,
|
399
|
+
data: js.data.JaxSimModelData,
|
233
400
|
*,
|
234
|
-
|
235
|
-
|
401
|
+
link_forces: jtp.Vector | None = None,
|
402
|
+
joint_force_references: jtp.Vector | None = None,
|
403
|
+
baumgarte_quaternion_regularization: jtp.FloatLike = 1.0,
|
236
404
|
) -> tuple[ODEState, dict[str, Any]]:
|
237
405
|
"""
|
238
406
|
Compute the dynamics of the system.
|
@@ -240,8 +408,13 @@ def system_dynamics(
|
|
240
408
|
Args:
|
241
409
|
model: The model to consider.
|
242
410
|
data: The data of the considered model.
|
243
|
-
|
244
|
-
|
411
|
+
link_forces:
|
412
|
+
The 6D forces to apply to the links expressed in the frame corresponding to
|
413
|
+
the velocity representation of `data`.
|
414
|
+
joint_force_references: The joint force references to apply.
|
415
|
+
baumgarte_quaternion_regularization:
|
416
|
+
The Baumgarte regularization coefficient used to adjust the norm of the
|
417
|
+
quaternion (only used in integrators not operating on the SO(3) manifold).
|
245
418
|
|
246
419
|
Returns:
|
247
420
|
A tuple with an `ODEState` object storing in each of its attributes the
|
@@ -250,31 +423,53 @@ def system_dynamics(
|
|
250
423
|
"""
|
251
424
|
|
252
425
|
# Compute the accelerations and the material deformation rate.
|
253
|
-
W_v̇_WB, s̈,
|
426
|
+
W_v̇_WB, s̈, aux_dict = system_velocity_dynamics(
|
254
427
|
model=model,
|
255
428
|
data=data,
|
256
|
-
|
257
|
-
|
429
|
+
joint_force_references=joint_force_references,
|
430
|
+
link_forces=link_forces,
|
258
431
|
)
|
259
432
|
|
433
|
+
# Initialize the dictionary storing the derivative of the additional state variables
|
434
|
+
# that extend the state vector of the integrated ODE system.
|
435
|
+
extended_ode_state = {}
|
436
|
+
|
437
|
+
match model.contact_model:
|
438
|
+
|
439
|
+
case contacts.SoftContacts():
|
440
|
+
extended_ode_state["tangential_deformation"] = aux_dict["m_dot"]
|
441
|
+
|
442
|
+
case contacts.ViscoElasticContacts():
|
443
|
+
|
444
|
+
extended_ode_state["tangential_deformation"] = jnp.zeros_like(
|
445
|
+
data.state.extended["tangential_deformation"]
|
446
|
+
)
|
447
|
+
|
448
|
+
case contacts.RigidContacts() | contacts.RelaxedRigidContacts():
|
449
|
+
pass
|
450
|
+
|
451
|
+
case _:
|
452
|
+
raise ValueError(f"Invalid contact model: {model.contact_model}")
|
453
|
+
|
260
454
|
# Extract the velocities.
|
261
|
-
W_ṗ_B, W_Q̇_B, ṡ = system_position_dynamics(
|
455
|
+
W_ṗ_B, W_Q̇_B, ṡ = system_position_dynamics(
|
456
|
+
model=model,
|
457
|
+
data=data,
|
458
|
+
baumgarte_quaternion_regularization=baumgarte_quaternion_regularization,
|
459
|
+
)
|
262
460
|
|
263
461
|
# Create an ODEState object populated with the derivative of each leaf.
|
264
462
|
# Our integrators, operating on generic pytrees, will be able to handle it
|
265
463
|
# automatically as state derivative.
|
266
|
-
ode_state_derivative = ODEState.
|
267
|
-
|
268
|
-
|
269
|
-
|
270
|
-
|
271
|
-
|
272
|
-
|
273
|
-
|
274
|
-
|
275
|
-
soft_contacts_state=SoftContactsState.build(
|
276
|
-
tangential_deformation=ṁ,
|
277
|
-
),
|
464
|
+
ode_state_derivative = ODEState.build_from_jaxsim_model(
|
465
|
+
model=model,
|
466
|
+
base_position=W_ṗ_B,
|
467
|
+
base_quaternion=W_Q̇_B,
|
468
|
+
joint_positions=ṡ,
|
469
|
+
base_linear_velocity=W_v̇_WB[0:3],
|
470
|
+
base_angular_velocity=W_v̇_WB[3:6],
|
471
|
+
joint_velocities=s̈,
|
472
|
+
**extended_ode_state,
|
278
473
|
)
|
279
474
|
|
280
475
|
return ode_state_derivative, aux_dict
|