jaxsim 0.2.dev191__py3-none-any.whl → 0.6.1.dev2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jaxsim/__init__.py +73 -22
- jaxsim/_version.py +2 -2
- jaxsim/api/__init__.py +13 -1
- jaxsim/api/com.py +423 -0
- jaxsim/api/common.py +48 -19
- jaxsim/api/contact.py +604 -52
- jaxsim/api/data.py +308 -163
- jaxsim/api/frame.py +471 -0
- jaxsim/api/joint.py +166 -37
- jaxsim/api/kin_dyn_parameters.py +901 -0
- jaxsim/api/link.py +277 -78
- jaxsim/api/model.py +1572 -362
- jaxsim/api/ode.py +324 -133
- jaxsim/api/ode_data.py +401 -0
- jaxsim/api/references.py +216 -80
- jaxsim/exceptions.py +80 -0
- jaxsim/integrators/__init__.py +2 -2
- jaxsim/integrators/common.py +191 -107
- jaxsim/integrators/fixed_step.py +97 -102
- jaxsim/integrators/variable_step.py +706 -0
- jaxsim/logging.py +1 -2
- jaxsim/math/__init__.py +13 -0
- jaxsim/math/adjoint.py +57 -22
- jaxsim/math/cross.py +16 -7
- jaxsim/math/inertia.py +10 -8
- jaxsim/math/joint_model.py +289 -0
- jaxsim/math/quaternion.py +54 -20
- jaxsim/math/rotation.py +27 -21
- jaxsim/math/skew.py +16 -5
- jaxsim/math/transform.py +102 -0
- jaxsim/math/utils.py +31 -0
- jaxsim/mujoco/__init__.py +2 -1
- jaxsim/mujoco/loaders.py +216 -29
- jaxsim/mujoco/model.py +163 -33
- jaxsim/mujoco/utils.py +228 -0
- jaxsim/mujoco/visualizer.py +107 -22
- jaxsim/parsers/__init__.py +0 -1
- jaxsim/parsers/descriptions/__init__.py +8 -2
- jaxsim/parsers/descriptions/collision.py +83 -26
- jaxsim/parsers/descriptions/joint.py +80 -87
- jaxsim/parsers/descriptions/link.py +58 -31
- jaxsim/parsers/descriptions/model.py +101 -68
- jaxsim/parsers/kinematic_graph.py +606 -229
- jaxsim/parsers/rod/meshes.py +104 -0
- jaxsim/parsers/rod/parser.py +125 -82
- jaxsim/parsers/rod/utils.py +127 -82
- jaxsim/rbda/__init__.py +11 -0
- jaxsim/rbda/aba.py +289 -0
- jaxsim/rbda/collidable_points.py +156 -0
- jaxsim/rbda/contacts/__init__.py +13 -0
- jaxsim/rbda/contacts/common.py +313 -0
- jaxsim/rbda/contacts/relaxed_rigid.py +605 -0
- jaxsim/rbda/contacts/rigid.py +462 -0
- jaxsim/rbda/contacts/soft.py +480 -0
- jaxsim/rbda/contacts/visco_elastic.py +1066 -0
- jaxsim/rbda/crba.py +167 -0
- jaxsim/rbda/forward_kinematics.py +117 -0
- jaxsim/rbda/jacobian.py +330 -0
- jaxsim/rbda/rnea.py +235 -0
- jaxsim/rbda/utils.py +160 -0
- jaxsim/terrain/__init__.py +2 -0
- jaxsim/terrain/terrain.py +238 -0
- jaxsim/typing.py +24 -24
- jaxsim/utils/__init__.py +1 -4
- jaxsim/utils/jaxsim_dataclass.py +289 -34
- jaxsim/utils/tracing.py +5 -11
- jaxsim/utils/wrappers.py +159 -0
- {jaxsim-0.2.dev191.dist-info → jaxsim-0.6.1.dev2.dist-info}/LICENSE +1 -1
- jaxsim-0.6.1.dev2.dist-info/METADATA +465 -0
- jaxsim-0.6.1.dev2.dist-info/RECORD +74 -0
- {jaxsim-0.2.dev191.dist-info → jaxsim-0.6.1.dev2.dist-info}/WHEEL +1 -1
- jaxsim/high_level/__init__.py +0 -2
- jaxsim/high_level/common.py +0 -11
- jaxsim/high_level/joint.py +0 -148
- jaxsim/high_level/link.py +0 -259
- jaxsim/high_level/model.py +0 -1686
- jaxsim/math/conv.py +0 -114
- jaxsim/math/joint.py +0 -102
- jaxsim/math/plucker.py +0 -100
- jaxsim/physics/__init__.py +0 -12
- jaxsim/physics/algos/__init__.py +0 -0
- jaxsim/physics/algos/aba.py +0 -254
- jaxsim/physics/algos/aba_motors.py +0 -284
- jaxsim/physics/algos/crba.py +0 -154
- jaxsim/physics/algos/forward_kinematics.py +0 -79
- jaxsim/physics/algos/jacobian.py +0 -98
- jaxsim/physics/algos/rnea.py +0 -180
- jaxsim/physics/algos/rnea_motors.py +0 -196
- jaxsim/physics/algos/soft_contacts.py +0 -523
- jaxsim/physics/algos/terrain.py +0 -78
- jaxsim/physics/algos/utils.py +0 -69
- jaxsim/physics/model/__init__.py +0 -0
- jaxsim/physics/model/ground_contact.py +0 -53
- jaxsim/physics/model/physics_model.py +0 -388
- jaxsim/physics/model/physics_model_state.py +0 -283
- jaxsim/simulation/__init__.py +0 -4
- jaxsim/simulation/integrators.py +0 -393
- jaxsim/simulation/ode.py +0 -290
- jaxsim/simulation/ode_data.py +0 -96
- jaxsim/simulation/ode_integration.py +0 -62
- jaxsim/simulation/simulator.py +0 -543
- jaxsim/simulation/simulator_callbacks.py +0 -79
- jaxsim/simulation/utils.py +0 -15
- jaxsim/sixd/__init__.py +0 -2
- jaxsim/utils/oop.py +0 -536
- jaxsim/utils/vmappable.py +0 -117
- jaxsim-0.2.dev191.dist-info/METADATA +0 -184
- jaxsim-0.2.dev191.dist-info/RECORD +0 -81
- {jaxsim-0.2.dev191.dist-info → jaxsim-0.6.1.dev2.dist-info}/top_level.txt +0 -0
jaxsim/api/model.py
CHANGED
@@ -1,30 +1,30 @@
|
|
1
1
|
from __future__ import annotations
|
2
2
|
|
3
|
+
import copy
|
3
4
|
import dataclasses
|
4
5
|
import functools
|
5
6
|
import pathlib
|
7
|
+
from collections.abc import Sequence
|
6
8
|
from typing import Any
|
7
9
|
|
8
10
|
import jax
|
9
11
|
import jax.numpy as jnp
|
10
12
|
import jax_dataclasses
|
11
|
-
import jaxlie
|
12
13
|
import rod
|
13
14
|
from jax_dataclasses import Static
|
14
15
|
|
15
16
|
import jaxsim.api as js
|
16
|
-
import jaxsim.
|
17
|
-
import jaxsim.
|
18
|
-
import jaxsim.physics.algos.forward_kinematics
|
19
|
-
import jaxsim.physics.algos.rnea
|
20
|
-
import jaxsim.physics.model.physics_model
|
17
|
+
import jaxsim.exceptions
|
18
|
+
import jaxsim.terrain
|
21
19
|
import jaxsim.typing as jtp
|
22
|
-
from jaxsim.
|
23
|
-
from jaxsim.
|
24
|
-
from jaxsim.utils import JaxsimDataclass, Mutability
|
20
|
+
from jaxsim.math import Adjoint, Cross
|
21
|
+
from jaxsim.parsers.descriptions import ModelDescription
|
22
|
+
from jaxsim.utils import JaxsimDataclass, Mutability, wrappers
|
25
23
|
|
24
|
+
from .common import VelRepr
|
26
25
|
|
27
|
-
|
26
|
+
|
27
|
+
@jax_dataclasses.pytree_dataclass(eq=False, unsafe_hash=False)
|
28
28
|
class JaxSimModel(JaxsimDataclass):
|
29
29
|
"""
|
30
30
|
The JaxSim model defining the kinematics and dynamics of a robot.
|
@@ -32,46 +32,87 @@ class JaxSimModel(JaxsimDataclass):
|
|
32
32
|
|
33
33
|
model_name: Static[str]
|
34
34
|
|
35
|
-
|
36
|
-
|
35
|
+
time_step: jaxsim.integrators.TimeStep = dataclasses.field(
|
36
|
+
default_factory=lambda: jnp.array(0.001, dtype=float),
|
37
|
+
)
|
38
|
+
|
39
|
+
terrain: Static[jaxsim.terrain.Terrain] = dataclasses.field(
|
40
|
+
default_factory=jaxsim.terrain.FlatTerrain.build, repr=False
|
41
|
+
)
|
42
|
+
|
43
|
+
# Note that this is the default contact model.
|
44
|
+
contact_model: Static[jaxsim.rbda.contacts.ContactModel | None] = dataclasses.field(
|
45
|
+
default=None, repr=False
|
37
46
|
)
|
38
47
|
|
39
|
-
|
48
|
+
kin_dyn_parameters: js.kin_dyn_parameters.KinDynParameters | None = (
|
49
|
+
dataclasses.field(default=None, repr=False)
|
50
|
+
)
|
40
51
|
|
41
52
|
built_from: Static[str | pathlib.Path | rod.Model | None] = dataclasses.field(
|
42
|
-
|
53
|
+
default=None, repr=False
|
43
54
|
)
|
44
55
|
|
45
|
-
|
46
|
-
|
56
|
+
integrator: Static[jaxsim.integrators.Integrator | None] = dataclasses.field(
|
57
|
+
default=None, repr=False
|
47
58
|
)
|
48
59
|
|
49
|
-
|
50
|
-
|
60
|
+
_description: Static[wrappers.HashlessObject[ModelDescription | None]] = (
|
61
|
+
dataclasses.field(default=None, repr=False)
|
51
62
|
)
|
52
63
|
|
53
|
-
|
64
|
+
@property
|
65
|
+
def description(self) -> ModelDescription:
|
66
|
+
"""
|
67
|
+
Return the model description.
|
68
|
+
"""
|
69
|
+
return self._description.get()
|
70
|
+
|
71
|
+
def __eq__(self, other: JaxSimModel) -> bool:
|
72
|
+
|
73
|
+
if not isinstance(other, JaxSimModel):
|
74
|
+
return False
|
75
|
+
|
76
|
+
if self.model_name != other.model_name:
|
77
|
+
return False
|
78
|
+
|
79
|
+
if self.time_step != other.time_step:
|
80
|
+
return False
|
54
81
|
|
55
|
-
|
56
|
-
|
57
|
-
|
58
|
-
|
59
|
-
|
60
|
-
|
61
|
-
|
62
|
-
|
82
|
+
if self.kin_dyn_parameters != other.kin_dyn_parameters:
|
83
|
+
return False
|
84
|
+
|
85
|
+
return True
|
86
|
+
|
87
|
+
def __hash__(self) -> int:
|
88
|
+
|
89
|
+
return hash(
|
90
|
+
(
|
91
|
+
hash(self.model_name),
|
92
|
+
hash(float(self.time_step)),
|
93
|
+
hash(self.kin_dyn_parameters),
|
94
|
+
hash(self.contact_model),
|
95
|
+
)
|
96
|
+
)
|
63
97
|
|
64
98
|
# ========================
|
65
99
|
# Initialization and state
|
66
100
|
# ========================
|
67
101
|
|
68
|
-
@
|
102
|
+
@classmethod
|
69
103
|
def build_from_model_description(
|
104
|
+
cls,
|
70
105
|
model_description: str | pathlib.Path | rod.Model,
|
106
|
+
*,
|
71
107
|
model_name: str | None = None,
|
72
|
-
|
108
|
+
time_step: jtp.FloatLike | None = None,
|
109
|
+
integrator: (
|
110
|
+
jaxsim.integrators.Integrator | type[jaxsim.integrators.Integrator] | None
|
111
|
+
) = None,
|
112
|
+
terrain: jaxsim.terrain.Terrain | None = None,
|
113
|
+
contact_model: jaxsim.rbda.contacts.ContactModel | None = None,
|
73
114
|
is_urdf: bool | None = None,
|
74
|
-
considered_joints:
|
115
|
+
considered_joints: Sequence[str] | None = None,
|
75
116
|
) -> JaxSimModel:
|
76
117
|
"""
|
77
118
|
Build a Model object from a model description.
|
@@ -81,12 +122,21 @@ class JaxSimModel(JaxsimDataclass):
|
|
81
122
|
A path to an SDF/URDF file, a string containing
|
82
123
|
its content, or a pre-parsed/pre-built rod model.
|
83
124
|
model_name:
|
84
|
-
The
|
85
|
-
|
86
|
-
|
125
|
+
The name of the model. If not specified, it is read from the description.
|
126
|
+
time_step:
|
127
|
+
The default time step to consider for the simulation. It can be
|
128
|
+
manually overridden in the function that steps the simulation.
|
129
|
+
terrain: The terrain to consider (the default is a flat infinite plane).
|
130
|
+
contact_model:
|
131
|
+
The contact model to consider.
|
132
|
+
If not specified, a soft contacts model is used.
|
133
|
+
integrator:
|
134
|
+
The integrator to use. If not specified, a default one is used.
|
135
|
+
This argument can either be a pre-built integrator instance or one
|
136
|
+
of the integrator classes defined in JaxSim.
|
87
137
|
is_urdf:
|
88
|
-
|
89
|
-
|
138
|
+
The optional flag to force the model description to be parsed as a URDF.
|
139
|
+
This is usually automatically inferred.
|
90
140
|
considered_joints:
|
91
141
|
The list of joints to consider. If None, all joints are considered.
|
92
142
|
|
@@ -97,7 +147,7 @@ class JaxSimModel(JaxsimDataclass):
|
|
97
147
|
import jaxsim.parsers.rod
|
98
148
|
|
99
149
|
# Parse the input resource (either a path to file or a string with the URDF/SDF)
|
100
|
-
# and build the -intermediate- model description
|
150
|
+
# and build the -intermediate- model description.
|
101
151
|
intermediate_description = jaxsim.parsers.rod.build_model_description(
|
102
152
|
model_description=model_description, is_urdf=is_urdf
|
103
153
|
)
|
@@ -109,44 +159,134 @@ class JaxSimModel(JaxsimDataclass):
|
|
109
159
|
considered_joints=considered_joints
|
110
160
|
)
|
111
161
|
|
112
|
-
#
|
113
|
-
|
114
|
-
model_description=intermediate_description,
|
162
|
+
# Build the model.
|
163
|
+
model = cls.build(
|
164
|
+
model_description=intermediate_description,
|
165
|
+
model_name=model_name,
|
166
|
+
time_step=time_step,
|
167
|
+
integrator=integrator,
|
168
|
+
terrain=terrain,
|
169
|
+
contact_model=contact_model,
|
115
170
|
)
|
116
171
|
|
117
|
-
#
|
118
|
-
model = JaxSimModel.build(physics_model=physics_model, model_name=model_name)
|
119
|
-
|
120
|
-
# Store the origin of the model, in case downstream logic needs it
|
172
|
+
# Store the origin of the model, in case downstream logic needs it.
|
121
173
|
with model.mutable_context(mutability=Mutability.MUTABLE_NO_VALIDATION):
|
122
174
|
model.built_from = model_description
|
123
175
|
|
124
176
|
return model
|
125
177
|
|
126
|
-
@
|
178
|
+
@classmethod
|
127
179
|
def build(
|
128
|
-
|
180
|
+
cls,
|
181
|
+
model_description: ModelDescription,
|
182
|
+
*,
|
129
183
|
model_name: str | None = None,
|
184
|
+
time_step: jtp.FloatLike | None = None,
|
185
|
+
integrator: (
|
186
|
+
jaxsim.integrators.Integrator | type[jaxsim.integrators.Integrator] | None
|
187
|
+
) = None,
|
188
|
+
terrain: jaxsim.terrain.Terrain | None = None,
|
189
|
+
contact_model: jaxsim.rbda.contacts.ContactModel | None = None,
|
130
190
|
) -> JaxSimModel:
|
131
191
|
"""
|
132
|
-
Build a Model object from
|
192
|
+
Build a Model object from an intermediate model description.
|
133
193
|
|
134
194
|
Args:
|
135
|
-
|
195
|
+
model_description:
|
196
|
+
The intermediate model description defining the kinematics and dynamics
|
197
|
+
of the model.
|
136
198
|
model_name:
|
199
|
+
The name of the model. If not specified, it is read from the description.
|
200
|
+
time_step:
|
201
|
+
The default time step to consider for the simulation. It can be
|
202
|
+
manually overridden in the function that steps the simulation.
|
203
|
+
terrain: The terrain to consider (the default is a flat infinite plane).
|
137
204
|
The optional name of the model overriding the physics model name.
|
205
|
+
integrator:
|
206
|
+
The integrator to use. If not specified, a default one is used.
|
207
|
+
This argument can either be a pre-built integrator instance or one
|
208
|
+
of the integrator classes defined in JaxSim.
|
209
|
+
contact_model:
|
210
|
+
The contact model to consider.
|
211
|
+
If not specified, a soft contacts model is used.
|
138
212
|
|
139
213
|
Returns:
|
140
214
|
The built Model object.
|
141
215
|
"""
|
142
216
|
|
143
|
-
# Set the model name (if not provided, use the one from the model description)
|
144
|
-
model_name =
|
145
|
-
|
217
|
+
# Set the model name (if not provided, use the one from the model description).
|
218
|
+
model_name = model_name if model_name is not None else model_description.name
|
219
|
+
|
220
|
+
# Consider the default terrain (a flat infinite plane) if not specified.
|
221
|
+
terrain = (
|
222
|
+
terrain
|
223
|
+
if terrain is not None
|
224
|
+
else JaxSimModel.__dataclass_fields__["terrain"].default_factory()
|
225
|
+
)
|
226
|
+
|
227
|
+
# Consider the default time step if not specified.
|
228
|
+
time_step = (
|
229
|
+
time_step
|
230
|
+
if time_step is not None
|
231
|
+
else JaxSimModel.__dataclass_fields__["time_step"].default_factory()
|
232
|
+
)
|
233
|
+
|
234
|
+
# Create the default contact model.
|
235
|
+
# It will be populated with an initial estimation of good parameters.
|
236
|
+
# While these might not be the best, they are a good starting point.
|
237
|
+
contact_model = (
|
238
|
+
contact_model
|
239
|
+
if contact_model is not None
|
240
|
+
else jaxsim.rbda.contacts.SoftContacts.build()
|
146
241
|
)
|
147
242
|
|
148
|
-
# Build the
|
149
|
-
|
243
|
+
# Build the integrator if not provided.
|
244
|
+
match integrator:
|
245
|
+
|
246
|
+
# If None, build a default integrator.
|
247
|
+
case None:
|
248
|
+
|
249
|
+
integrator = jaxsim.integrators.fixed_step.Heun2SO3.build(
|
250
|
+
dynamics=js.ode.wrap_system_dynamics_for_integration(
|
251
|
+
system_dynamics=js.ode.system_dynamics
|
252
|
+
)
|
253
|
+
)
|
254
|
+
|
255
|
+
# If it's a pre-built integrator (also a custom one from the user)
|
256
|
+
# just use it as is.
|
257
|
+
case _ if isinstance(integrator, jaxsim.integrators.Integrator):
|
258
|
+
pass
|
259
|
+
|
260
|
+
# If an integrator class is passed, assume that it is a JaxSim integrator
|
261
|
+
# and build it with the default system dynamics.
|
262
|
+
case _ if issubclass(integrator, jaxsim.integrators.Integrator):
|
263
|
+
|
264
|
+
integrator_cls = integrator
|
265
|
+
integrator = integrator_cls.build(
|
266
|
+
dynamics=js.ode.wrap_system_dynamics_for_integration(
|
267
|
+
system_dynamics=js.ode.system_dynamics
|
268
|
+
)
|
269
|
+
)
|
270
|
+
|
271
|
+
case _:
|
272
|
+
raise ValueError(f"Invalid integrator: {integrator}")
|
273
|
+
|
274
|
+
# Build the model.
|
275
|
+
model = cls(
|
276
|
+
model_name=model_name,
|
277
|
+
kin_dyn_parameters=js.kin_dyn_parameters.KinDynParameters.build(
|
278
|
+
model_description=model_description
|
279
|
+
),
|
280
|
+
time_step=time_step,
|
281
|
+
terrain=terrain,
|
282
|
+
contact_model=contact_model,
|
283
|
+
integrator=integrator,
|
284
|
+
# The following is wrapped as hashless since it's a static argument, and we
|
285
|
+
# don't want to trigger recompilation if it changes. All relevant parameters
|
286
|
+
# needed to compute kinematics and dynamics quantities are stored in the
|
287
|
+
# kin_dyn_parameters attribute.
|
288
|
+
_description=wrappers.HashlessObject(obj=model_description),
|
289
|
+
)
|
150
290
|
|
151
291
|
return model
|
152
292
|
|
@@ -164,7 +304,7 @@ class JaxSimModel(JaxsimDataclass):
|
|
164
304
|
|
165
305
|
return self.model_name
|
166
306
|
|
167
|
-
def number_of_links(self) ->
|
307
|
+
def number_of_links(self) -> int:
|
168
308
|
"""
|
169
309
|
Return the number of links in the model.
|
170
310
|
|
@@ -175,9 +315,9 @@ class JaxSimModel(JaxsimDataclass):
|
|
175
315
|
The base link is included in the count and its index is always 0.
|
176
316
|
"""
|
177
317
|
|
178
|
-
return self.
|
318
|
+
return self.kin_dyn_parameters.number_of_links()
|
179
319
|
|
180
|
-
def number_of_joints(self) ->
|
320
|
+
def number_of_joints(self) -> int:
|
181
321
|
"""
|
182
322
|
Return the number of joints in the model.
|
183
323
|
|
@@ -185,7 +325,18 @@ class JaxSimModel(JaxsimDataclass):
|
|
185
325
|
The number of joints in the model.
|
186
326
|
"""
|
187
327
|
|
188
|
-
return self.
|
328
|
+
return self.kin_dyn_parameters.number_of_joints()
|
329
|
+
|
330
|
+
def number_of_frames(self) -> int:
|
331
|
+
"""
|
332
|
+
Return the number of frames in the model.
|
333
|
+
|
334
|
+
Returns:
|
335
|
+
The number of frames in the model.
|
336
|
+
|
337
|
+
"""
|
338
|
+
|
339
|
+
return self.kin_dyn_parameters.number_of_frames()
|
189
340
|
|
190
341
|
# =================
|
191
342
|
# Base link methods
|
@@ -199,7 +350,7 @@ class JaxSimModel(JaxsimDataclass):
|
|
199
350
|
True if the model is floating-base, False otherwise.
|
200
351
|
"""
|
201
352
|
|
202
|
-
return self.
|
353
|
+
return bool(self.kin_dyn_parameters.joint_model.joint_dofs[0] == 6)
|
203
354
|
|
204
355
|
def base_link(self) -> str:
|
205
356
|
"""
|
@@ -207,9 +358,12 @@ class JaxSimModel(JaxsimDataclass):
|
|
207
358
|
|
208
359
|
Returns:
|
209
360
|
The name of the base link.
|
361
|
+
|
362
|
+
Note:
|
363
|
+
By default, the base link is the root of the kinematic tree.
|
210
364
|
"""
|
211
365
|
|
212
|
-
return self.
|
366
|
+
return self.link_names()[0]
|
213
367
|
|
214
368
|
# =====================
|
215
369
|
# Joint-related methods
|
@@ -227,7 +381,7 @@ class JaxSimModel(JaxsimDataclass):
|
|
227
381
|
the number of joints. In the future, this could be different.
|
228
382
|
"""
|
229
383
|
|
230
|
-
return
|
384
|
+
return int(sum(self.kin_dyn_parameters.joint_model.joint_dofs[1:]))
|
231
385
|
|
232
386
|
def joint_names(self) -> tuple[str, ...]:
|
233
387
|
"""
|
@@ -237,7 +391,7 @@ class JaxSimModel(JaxsimDataclass):
|
|
237
391
|
The names of the joints in the model.
|
238
392
|
"""
|
239
393
|
|
240
|
-
return
|
394
|
+
return self.kin_dyn_parameters.joint_model.joint_names[1:]
|
241
395
|
|
242
396
|
# ====================
|
243
397
|
# Link-related methods
|
@@ -251,7 +405,21 @@ class JaxSimModel(JaxsimDataclass):
|
|
251
405
|
The names of the links in the model.
|
252
406
|
"""
|
253
407
|
|
254
|
-
return
|
408
|
+
return self.kin_dyn_parameters.link_names
|
409
|
+
|
410
|
+
# =====================
|
411
|
+
# Frame-related methods
|
412
|
+
# =====================
|
413
|
+
|
414
|
+
def frame_names(self) -> tuple[str, ...]:
|
415
|
+
"""
|
416
|
+
Return the names of the frames in the model.
|
417
|
+
|
418
|
+
Returns:
|
419
|
+
The names of the frames in the model.
|
420
|
+
"""
|
421
|
+
|
422
|
+
return self.kin_dyn_parameters.frame_parameters.name
|
255
423
|
|
256
424
|
|
257
425
|
# =====================
|
@@ -259,42 +427,63 @@ class JaxSimModel(JaxsimDataclass):
|
|
259
427
|
# =====================
|
260
428
|
|
261
429
|
|
262
|
-
def reduce(
|
430
|
+
def reduce(
|
431
|
+
model: JaxSimModel,
|
432
|
+
considered_joints: tuple[str, ...],
|
433
|
+
locked_joint_positions: dict[str, jtp.FloatLike] | None = None,
|
434
|
+
) -> JaxSimModel:
|
263
435
|
"""
|
264
436
|
Reduce the model by lumping together the links connected by removed joints.
|
265
437
|
|
266
438
|
Args:
|
267
439
|
model: The model to reduce.
|
268
440
|
considered_joints: The sequence of joints to consider.
|
269
|
-
|
270
|
-
|
271
|
-
|
272
|
-
|
273
|
-
|
441
|
+
locked_joint_positions:
|
442
|
+
A dictionary containing the positions of the joints to be considered
|
443
|
+
in the reduction process. The removed joints in the reduced model
|
444
|
+
will have their position locked to their value of this dictionary.
|
445
|
+
If a joint is not part of the dictionary, its position is set to zero.
|
274
446
|
"""
|
275
447
|
|
276
|
-
|
277
|
-
|
448
|
+
locked_joint_positions = (
|
449
|
+
locked_joint_positions if locked_joint_positions is not None else {}
|
450
|
+
)
|
451
|
+
|
452
|
+
# If locked joints are passed, make sure that they are valid.
|
453
|
+
if not set(locked_joint_positions).issubset(model.joint_names()):
|
454
|
+
new_joints = set(model.joint_names()) - set(locked_joint_positions)
|
455
|
+
raise ValueError(f"Passed joints not existing in the model: {new_joints}")
|
456
|
+
|
457
|
+
# Operate on a deep copy of the model description in order to prevent problems
|
458
|
+
# when mutable attributes are updated.
|
459
|
+
intermediate_description = copy.deepcopy(model.description)
|
460
|
+
|
461
|
+
# Update the initial position of the joints.
|
462
|
+
# This is necessary to compute the correct pose of the link pairs connected
|
463
|
+
# to removed joints.
|
464
|
+
for joint_name in set(model.joint_names()) - set(considered_joints):
|
465
|
+
j = intermediate_description.joints_dict[joint_name]
|
466
|
+
with j.mutable_context():
|
467
|
+
j.initial_position = float(locked_joint_positions.get(joint_name, 0.0))
|
278
468
|
|
279
469
|
# Reduce the model description.
|
280
|
-
# If considered_joints contains joints not existing in the model,
|
281
|
-
# will raise an exception.
|
282
|
-
reduced_intermediate_description =
|
470
|
+
# If `considered_joints` contains joints not existing in the model,
|
471
|
+
# the method will raise an exception.
|
472
|
+
reduced_intermediate_description = intermediate_description.reduce(
|
283
473
|
considered_joints=list(considered_joints)
|
284
474
|
)
|
285
475
|
|
286
|
-
#
|
287
|
-
physics_model = jaxsim.physics.model.physics_model.PhysicsModel.build_from(
|
288
|
-
model_description=reduced_intermediate_description,
|
289
|
-
gravity=model.physics_model.gravity[0:3],
|
290
|
-
)
|
291
|
-
|
292
|
-
# Build the reduced model
|
476
|
+
# Build the reduced model.
|
293
477
|
reduced_model = JaxSimModel.build(
|
294
|
-
|
478
|
+
model_description=reduced_intermediate_description,
|
479
|
+
model_name=model.name(),
|
480
|
+
time_step=model.time_step,
|
481
|
+
terrain=model.terrain,
|
482
|
+
contact_model=model.contact_model,
|
483
|
+
integrator=model.integrator,
|
295
484
|
)
|
296
485
|
|
297
|
-
# Store the origin of the model, in case downstream logic needs it
|
486
|
+
# Store the origin of the model, in case downstream logic needs it.
|
298
487
|
with reduced_model.mutable_context(mutability=Mutability.MUTABLE_NO_VALIDATION):
|
299
488
|
reduced_model.built_from = model.built_from
|
300
489
|
|
@@ -307,6 +496,7 @@ def reduce(model: JaxSimModel, considered_joints: tuple[str, ...]) -> JaxSimMode
|
|
307
496
|
|
308
497
|
|
309
498
|
@jax.jit
|
499
|
+
@js.common.named_scope
|
310
500
|
def total_mass(model: JaxSimModel) -> jtp.Float:
|
311
501
|
"""
|
312
502
|
Compute the total mass of the model.
|
@@ -318,52 +508,25 @@ def total_mass(model: JaxSimModel) -> jtp.Float:
|
|
318
508
|
The total mass of the model.
|
319
509
|
"""
|
320
510
|
|
321
|
-
return (
|
322
|
-
jax.vmap(lambda idx: js.link.mass(model=model, link_index=idx))(
|
323
|
-
jnp.arange(model.number_of_links())
|
324
|
-
)
|
325
|
-
.sum()
|
326
|
-
.astype(float)
|
327
|
-
)
|
328
|
-
|
329
|
-
|
330
|
-
# ==============
|
331
|
-
# Center of mass
|
332
|
-
# ==============
|
511
|
+
return model.kin_dyn_parameters.link_parameters.mass.sum().astype(float)
|
333
512
|
|
334
513
|
|
335
514
|
@jax.jit
|
336
|
-
|
515
|
+
@js.common.named_scope
|
516
|
+
def link_spatial_inertia_matrices(model: JaxSimModel) -> jtp.Array:
|
337
517
|
"""
|
338
|
-
Compute the
|
518
|
+
Compute the spatial 6D inertia matrices of all links of the model.
|
339
519
|
|
340
520
|
Args:
|
341
521
|
model: The model to consider.
|
342
|
-
data: The data of the considered model.
|
343
522
|
|
344
523
|
Returns:
|
345
|
-
|
524
|
+
A 3D array containing the stacked spatial 6D inertia matrices of the links.
|
346
525
|
"""
|
347
526
|
|
348
|
-
|
349
|
-
|
350
|
-
|
351
|
-
W_H_B = data.base_transform()
|
352
|
-
B_H_W = jaxlie.SE3.from_matrix(W_H_B).inverse().as_matrix()
|
353
|
-
|
354
|
-
def B_p̃_LCoM(i) -> jtp.Vector:
|
355
|
-
m = js.link.mass(model=model, link_index=i)
|
356
|
-
L_p_LCoM = js.link.com_position(
|
357
|
-
model=model, data=data, link_index=i, in_link_frame=True
|
358
|
-
)
|
359
|
-
return m * B_H_W @ W_H_L[i] @ jnp.hstack([L_p_LCoM, 1])
|
360
|
-
|
361
|
-
com_links = jax.vmap(B_p̃_LCoM)(jnp.arange(model.number_of_links()))
|
362
|
-
|
363
|
-
B_p̃_CoM = (1 / m) * com_links.sum(axis=0)
|
364
|
-
B_p̃_CoM = B_p̃_CoM.at[3].set(1)
|
365
|
-
|
366
|
-
return (W_H_B @ B_p̃_CoM)[0:3].astype(float)
|
527
|
+
return jax.vmap(js.kin_dyn_parameters.LinkParameters.spatial_inertia)(
|
528
|
+
model.kin_dyn_parameters.link_parameters
|
529
|
+
)
|
367
530
|
|
368
531
|
|
369
532
|
# ==============================
|
@@ -372,6 +535,7 @@ def com_position(model: JaxSimModel, data: js.data.JaxSimModelData) -> jtp.Vecto
|
|
372
535
|
|
373
536
|
|
374
537
|
@jax.jit
|
538
|
+
@js.common.named_scope
|
375
539
|
def forward_kinematics(model: JaxSimModel, data: js.data.JaxSimModelData) -> jtp.Array:
|
376
540
|
"""
|
377
541
|
Compute the SE(3) transforms from the world frame to the frames of all links.
|
@@ -385,10 +549,11 @@ def forward_kinematics(model: JaxSimModel, data: js.data.JaxSimModelData) -> jtp
|
|
385
549
|
The first axis is the link index.
|
386
550
|
"""
|
387
551
|
|
388
|
-
W_H_LL = jaxsim.
|
389
|
-
model=model
|
390
|
-
|
391
|
-
|
552
|
+
W_H_LL = jaxsim.rbda.forward_kinematics_model(
|
553
|
+
model=model,
|
554
|
+
base_position=data.base_position(),
|
555
|
+
base_quaternion=data.base_orientation(dcm=False),
|
556
|
+
joint_positions=data.joint_positions(model=model),
|
392
557
|
)
|
393
558
|
|
394
559
|
return jnp.atleast_3d(W_H_LL).astype(float)
|
@@ -424,51 +589,296 @@ def generalized_free_floating_jacobian(
|
|
424
589
|
output_vel_repr if output_vel_repr is not None else data.velocity_representation
|
425
590
|
)
|
426
591
|
|
427
|
-
#
|
428
|
-
|
429
|
-
|
430
|
-
|
592
|
+
# Compute the doubly-left free-floating full jacobian.
|
593
|
+
B_J_full_WX_B, B_H_L = jaxsim.rbda.jacobian_full_doubly_left(
|
594
|
+
model=model,
|
595
|
+
joint_positions=data.joint_positions(),
|
596
|
+
)
|
597
|
+
|
598
|
+
# ======================================================================
|
599
|
+
# Update the input velocity representation such that v_WL = J_WL_I @ I_ν
|
600
|
+
# ======================================================================
|
601
|
+
|
602
|
+
match data.velocity_representation:
|
603
|
+
|
604
|
+
case VelRepr.Inertial:
|
605
|
+
|
606
|
+
W_H_B = data.base_transform()
|
607
|
+
B_X_W = Adjoint.from_transform(transform=W_H_B, inverse=True)
|
608
|
+
|
609
|
+
B_J_full_WX_I = B_J_full_WX_W = ( # noqa: F841
|
610
|
+
B_J_full_WX_B
|
611
|
+
@ jax.scipy.linalg.block_diag(B_X_W, jnp.eye(model.dofs()))
|
612
|
+
)
|
613
|
+
|
614
|
+
case VelRepr.Body:
|
615
|
+
|
616
|
+
B_J_full_WX_I = B_J_full_WX_B
|
617
|
+
|
618
|
+
case VelRepr.Mixed:
|
619
|
+
|
620
|
+
W_R_B = data.base_orientation(dcm=True)
|
621
|
+
BW_H_B = jnp.eye(4).at[0:3, 0:3].set(W_R_B)
|
622
|
+
B_X_BW = Adjoint.from_transform(transform=BW_H_B, inverse=True)
|
623
|
+
|
624
|
+
B_J_full_WX_I = B_J_full_WX_BW = ( # noqa: F841
|
625
|
+
B_J_full_WX_B
|
626
|
+
@ jax.scipy.linalg.block_diag(B_X_BW, jnp.eye(model.dofs()))
|
627
|
+
)
|
628
|
+
|
629
|
+
case _:
|
630
|
+
raise ValueError(data.velocity_representation)
|
631
|
+
|
632
|
+
# ====================================================================
|
633
|
+
# Create stacked Jacobian for each link by filtering the full Jacobian
|
634
|
+
# ====================================================================
|
635
|
+
|
636
|
+
κ_bool = model.kin_dyn_parameters.support_body_array_bool
|
637
|
+
|
638
|
+
# Keep only the columns of the full Jacobian corresponding to the support
|
639
|
+
# body array of each link.
|
640
|
+
B_J_WL_I = jax.vmap(
|
641
|
+
lambda κ: jnp.where(
|
642
|
+
jnp.hstack([jnp.ones(5), κ]), B_J_full_WX_I, jnp.zeros_like(B_J_full_WX_I)
|
643
|
+
)
|
644
|
+
)(κ_bool)
|
645
|
+
|
646
|
+
# =======================================================================
|
647
|
+
# Update the output velocity representation such that O_v_WL = O_J_WL @ ν
|
648
|
+
# =======================================================================
|
649
|
+
|
431
650
|
match output_vel_repr:
|
651
|
+
|
432
652
|
case VelRepr.Inertial:
|
433
|
-
|
653
|
+
|
654
|
+
W_H_B = data.base_transform()
|
655
|
+
W_X_B = jaxsim.math.Adjoint.from_transform(W_H_B)
|
656
|
+
|
657
|
+
O_J_WL_I = W_J_WL_I = jax.vmap( # noqa: F841
|
658
|
+
lambda B_J_WL_I: W_X_B @ B_J_WL_I
|
659
|
+
)(B_J_WL_I)
|
434
660
|
|
435
661
|
case VelRepr.Body:
|
436
662
|
|
437
|
-
|
438
|
-
|
439
|
-
|
440
|
-
|
663
|
+
O_J_WL_I = L_J_WL_I = jax.vmap( # noqa: F841
|
664
|
+
lambda B_H_L, B_J_WL_I: jaxsim.math.Adjoint.from_transform(
|
665
|
+
B_H_L, inverse=True
|
666
|
+
)
|
667
|
+
@ B_J_WL_I
|
668
|
+
)(B_H_L, B_J_WL_I)
|
441
669
|
|
442
670
|
case VelRepr.Mixed:
|
443
671
|
|
444
|
-
|
445
|
-
|
446
|
-
|
447
|
-
|
448
|
-
|
672
|
+
W_H_B = data.base_transform()
|
673
|
+
|
674
|
+
LW_H_L = jax.vmap(
|
675
|
+
lambda B_H_L: (W_H_B @ B_H_L).at[0:3, 3].set(jnp.zeros(3))
|
676
|
+
)(B_H_L)
|
677
|
+
|
678
|
+
LW_H_B = jax.vmap(
|
679
|
+
lambda LW_H_L, B_H_L: LW_H_L @ jaxsim.math.Transform.inverse(B_H_L)
|
680
|
+
)(LW_H_L, B_H_L)
|
681
|
+
|
682
|
+
O_J_WL_I = LW_J_WL_I = jax.vmap( # noqa: F841
|
683
|
+
lambda LW_H_B, B_J_WL_I: jaxsim.math.Adjoint.from_transform(LW_H_B)
|
684
|
+
@ B_J_WL_I
|
685
|
+
)(LW_H_B, B_J_WL_I)
|
449
686
|
|
450
687
|
case _:
|
451
688
|
raise ValueError(output_vel_repr)
|
452
689
|
|
453
|
-
|
454
|
-
|
455
|
-
|
456
|
-
|
457
|
-
|
458
|
-
|
459
|
-
|
460
|
-
|
461
|
-
|
462
|
-
|
463
|
-
|
464
|
-
|
465
|
-
|
466
|
-
|
690
|
+
return O_J_WL_I
|
691
|
+
|
692
|
+
|
693
|
+
@functools.partial(jax.jit, static_argnames=["output_vel_repr"])
|
694
|
+
def generalized_free_floating_jacobian_derivative(
|
695
|
+
model: JaxSimModel,
|
696
|
+
data: js.data.JaxSimModelData,
|
697
|
+
*,
|
698
|
+
output_vel_repr: VelRepr | None = None,
|
699
|
+
) -> jtp.Matrix:
|
700
|
+
"""
|
701
|
+
Compute the free-floating jacobian derivatives of all links.
|
702
|
+
|
703
|
+
Args:
|
704
|
+
model: The model to consider.
|
705
|
+
data: The data of the considered model.
|
706
|
+
output_vel_repr:
|
707
|
+
The output velocity representation of the free-floating jacobian derivatives.
|
708
|
+
|
709
|
+
Returns:
|
710
|
+
The `(nL, 6, 6+dofs)` array containing the stacked free-floating
|
711
|
+
jacobian derivatives of the links. The first axis is the link index.
|
712
|
+
"""
|
713
|
+
|
714
|
+
output_vel_repr = (
|
715
|
+
output_vel_repr if output_vel_repr is not None else data.velocity_representation
|
716
|
+
)
|
717
|
+
|
718
|
+
# Compute the derivative of the doubly-left free-floating full jacobian.
|
719
|
+
B_J̇_full_WX_B, B_H_L = jaxsim.rbda.jacobian_derivative_full_doubly_left(
|
720
|
+
model=model,
|
721
|
+
joint_positions=data.joint_positions(),
|
722
|
+
joint_velocities=data.joint_velocities(),
|
723
|
+
)
|
724
|
+
|
725
|
+
# The derivative of the equation to change the input and output representations
|
726
|
+
# of the Jacobian derivative needs the computation of the plain link Jacobian.
|
727
|
+
B_J_full_WL_B, _ = jaxsim.rbda.jacobian_full_doubly_left(
|
728
|
+
model=model,
|
729
|
+
joint_positions=data.joint_positions(),
|
730
|
+
)
|
731
|
+
|
732
|
+
# Compute the actual doubly-left free-floating jacobian derivative of the link
|
733
|
+
# by zeroing the columns not in the path π_B(L) using the boolean κ(i).
|
734
|
+
κb = model.kin_dyn_parameters.support_body_array_bool
|
735
|
+
|
736
|
+
# Compute the base transform.
|
737
|
+
W_H_B = data.base_transform()
|
738
|
+
|
739
|
+
# We add the 5 columns of ones to the Jacobian derivative to account for the
|
740
|
+
# base velocity and acceleration (5 + number of links = 6 + number of joints).
|
741
|
+
B_J̇_WL_B = (
|
742
|
+
jnp.hstack([jnp.ones((κb.shape[0], 5)), κb])[:, jnp.newaxis] * B_J̇_full_WX_B
|
743
|
+
)
|
744
|
+
B_J_WL_B = (
|
745
|
+
jnp.hstack([jnp.ones((κb.shape[0], 5)), κb])[:, jnp.newaxis] * B_J_full_WL_B
|
746
|
+
)
|
747
|
+
|
748
|
+
# =====================================================
|
749
|
+
# Compute quantities to adjust the input representation
|
750
|
+
# =====================================================
|
751
|
+
|
752
|
+
In = jnp.eye(model.dofs())
|
753
|
+
On = jnp.zeros(shape=(model.dofs(), model.dofs()))
|
754
|
+
|
755
|
+
match data.velocity_representation:
|
756
|
+
|
757
|
+
case VelRepr.Inertial:
|
758
|
+
|
759
|
+
B_X_W = jaxsim.math.Adjoint.from_transform(transform=W_H_B, inverse=True)
|
760
|
+
|
761
|
+
W_v_WB = data.base_velocity()
|
762
|
+
B_Ẋ_W = -B_X_W @ jaxsim.math.Cross.vx(W_v_WB)
|
763
|
+
|
764
|
+
# Compute the operator to change the representation of ν, and its
|
765
|
+
# time derivative.
|
766
|
+
T = jax.scipy.linalg.block_diag(B_X_W, In)
|
767
|
+
Ṫ = jax.scipy.linalg.block_diag(B_Ẋ_W, On)
|
768
|
+
|
769
|
+
case VelRepr.Body:
|
770
|
+
|
771
|
+
B_X_B = jaxsim.math.Adjoint.from_rotation_and_translation(
|
772
|
+
translation=jnp.zeros(3), rotation=jnp.eye(3)
|
773
|
+
)
|
774
|
+
|
775
|
+
B_Ẋ_B = jnp.zeros(shape=(6, 6))
|
776
|
+
|
777
|
+
# Compute the operator to change the representation of ν, and its
|
778
|
+
# time derivative.
|
779
|
+
T = jax.scipy.linalg.block_diag(B_X_B, In)
|
780
|
+
Ṫ = jax.scipy.linalg.block_diag(B_Ẋ_B, On)
|
781
|
+
|
782
|
+
case VelRepr.Mixed:
|
783
|
+
|
784
|
+
BW_H_B = W_H_B.at[0:3, 3].set(jnp.zeros(3))
|
785
|
+
B_X_BW = jaxsim.math.Adjoint.from_transform(transform=BW_H_B, inverse=True)
|
786
|
+
|
787
|
+
BW_v_WB = data.base_velocity()
|
788
|
+
BW_v_W_BW = BW_v_WB.at[3:6].set(jnp.zeros(3))
|
789
|
+
|
790
|
+
BW_v_BW_B = BW_v_WB - BW_v_W_BW
|
791
|
+
B_Ẋ_BW = -B_X_BW @ jaxsim.math.Cross.vx(BW_v_BW_B)
|
792
|
+
|
793
|
+
# Compute the operator to change the representation of ν, and its
|
794
|
+
# time derivative.
|
795
|
+
T = jax.scipy.linalg.block_diag(B_X_BW, In)
|
796
|
+
Ṫ = jax.scipy.linalg.block_diag(B_Ẋ_BW, On)
|
797
|
+
|
798
|
+
case _:
|
799
|
+
raise ValueError(data.velocity_representation)
|
800
|
+
|
801
|
+
# ======================================================
|
802
|
+
# Compute quantities to adjust the output representation
|
803
|
+
# ======================================================
|
804
|
+
|
805
|
+
match output_vel_repr:
|
806
|
+
|
807
|
+
case VelRepr.Inertial:
|
808
|
+
|
809
|
+
O_X_B = W_X_B = jaxsim.math.Adjoint.from_transform(transform=W_H_B)
|
810
|
+
|
811
|
+
with data.switch_velocity_representation(VelRepr.Body):
|
812
|
+
B_v_WB = data.base_velocity()
|
813
|
+
|
814
|
+
O_Ẋ_B = W_Ẋ_B = W_X_B @ jaxsim.math.Cross.vx(B_v_WB) # noqa: F841
|
815
|
+
|
816
|
+
case VelRepr.Body:
|
817
|
+
|
818
|
+
O_X_B = L_X_B = jaxsim.math.Adjoint.from_transform(
|
819
|
+
transform=B_H_L, inverse=True
|
820
|
+
)
|
821
|
+
|
822
|
+
B_X_L = jaxsim.math.Adjoint.inverse(adjoint=L_X_B)
|
823
|
+
|
824
|
+
with data.switch_velocity_representation(VelRepr.Body):
|
825
|
+
B_v_WB = data.base_velocity()
|
826
|
+
L_v_WL = jnp.einsum(
|
827
|
+
"b6j,j->b6", L_X_B @ B_J_WL_B, data.generalized_velocity()
|
828
|
+
)
|
829
|
+
|
830
|
+
O_Ẋ_B = L_Ẋ_B = -L_X_B @ jaxsim.math.Cross.vx( # noqa: F841
|
831
|
+
jnp.einsum("bij,bj->bi", B_X_L, L_v_WL) - B_v_WB
|
832
|
+
)
|
833
|
+
|
834
|
+
case VelRepr.Mixed:
|
835
|
+
|
836
|
+
W_H_L = W_H_B @ B_H_L
|
837
|
+
LW_H_L = W_H_L.at[:, 0:3, 3].set(jnp.zeros_like(W_H_L[:, 0:3, 3]))
|
838
|
+
LW_H_B = LW_H_L @ jaxsim.math.Transform.inverse(B_H_L)
|
839
|
+
|
840
|
+
O_X_B = LW_X_B = jaxsim.math.Adjoint.from_transform(transform=LW_H_B)
|
841
|
+
|
842
|
+
B_X_LW = jaxsim.math.Adjoint.inverse(adjoint=LW_X_B)
|
843
|
+
|
844
|
+
with data.switch_velocity_representation(VelRepr.Body):
|
845
|
+
B_v_WB = data.base_velocity()
|
846
|
+
|
847
|
+
with data.switch_velocity_representation(VelRepr.Mixed):
|
848
|
+
BW_H_B = W_H_B.at[0:3, 3].set(jnp.zeros(3))
|
849
|
+
B_X_BW = Adjoint.from_transform(transform=BW_H_B, inverse=True)
|
850
|
+
LW_v_WL = jnp.einsum(
|
851
|
+
"bij,bj->bi",
|
852
|
+
LW_X_B,
|
853
|
+
B_J_WL_B
|
854
|
+
@ jax.scipy.linalg.block_diag(B_X_BW, jnp.eye(model.dofs()))
|
855
|
+
@ data.generalized_velocity(),
|
856
|
+
)
|
857
|
+
|
858
|
+
LW_v_W_LW = LW_v_WL.at[:, 3:6].set(jnp.zeros_like(LW_v_WL[:, 3:6]))
|
859
|
+
|
860
|
+
LW_v_LW_L = LW_v_WL - LW_v_W_LW
|
861
|
+
LW_v_B_LW = LW_v_WL - jnp.einsum("bij,j->bi", LW_X_B, B_v_WB) - LW_v_LW_L
|
862
|
+
|
863
|
+
O_Ẋ_B = LW_Ẋ_B = -LW_X_B @ jaxsim.math.Cross.vx( # noqa: F841
|
864
|
+
jnp.einsum("bij,bj->bi", B_X_LW, LW_v_B_LW)
|
467
865
|
)
|
468
|
-
)
|
469
|
-
)(jnp.arange(model.number_of_links()))
|
470
866
|
|
471
|
-
|
867
|
+
case _:
|
868
|
+
raise ValueError(output_vel_repr)
|
869
|
+
|
870
|
+
# =============================================================
|
871
|
+
# Express the Jacobian derivative in the target representations
|
872
|
+
# =============================================================
|
873
|
+
|
874
|
+
# Sum all the components that form the Jacobian derivative in the target
|
875
|
+
# input/output velocity representations.
|
876
|
+
O_J̇_WL_I = jnp.zeros_like(B_J̇_WL_B)
|
877
|
+
O_J̇_WL_I += O_Ẋ_B @ B_J_WL_B @ T
|
878
|
+
O_J̇_WL_I += O_X_B @ B_J̇_WL_B @ T
|
879
|
+
O_J̇_WL_I += O_X_B @ B_J_WL_B @ Ṫ
|
880
|
+
|
881
|
+
return O_J̇_WL_I
|
472
882
|
|
473
883
|
|
474
884
|
@functools.partial(jax.jit, static_argnames=["prefer_aba"])
|
@@ -477,7 +887,7 @@ def forward_dynamics(
|
|
477
887
|
data: js.data.JaxSimModelData,
|
478
888
|
*,
|
479
889
|
joint_forces: jtp.VectorLike | None = None,
|
480
|
-
|
890
|
+
link_forces: jtp.MatrixLike | None = None,
|
481
891
|
prefer_aba: float = True,
|
482
892
|
) -> tuple[jtp.Vector, jtp.Vector]:
|
483
893
|
"""
|
@@ -488,8 +898,8 @@ def forward_dynamics(
|
|
488
898
|
data: The data of the considered model.
|
489
899
|
joint_forces:
|
490
900
|
The joint forces to consider as a vector of shape `(dofs,)`.
|
491
|
-
|
492
|
-
The
|
901
|
+
link_forces:
|
902
|
+
The link 6D forces to consider as a matrix of shape `(nL, 6)`.
|
493
903
|
The frame in which they are expressed must be `data.velocity_representation`.
|
494
904
|
prefer_aba: Whether to prefer the ABA algorithm over the CRB one.
|
495
905
|
|
@@ -505,17 +915,18 @@ def forward_dynamics(
|
|
505
915
|
model=model,
|
506
916
|
data=data,
|
507
917
|
joint_forces=joint_forces,
|
508
|
-
|
918
|
+
link_forces=link_forces,
|
509
919
|
)
|
510
920
|
|
511
921
|
|
512
922
|
@jax.jit
|
923
|
+
@js.common.named_scope
|
513
924
|
def forward_dynamics_aba(
|
514
925
|
model: JaxSimModel,
|
515
926
|
data: js.data.JaxSimModelData,
|
516
927
|
*,
|
517
928
|
joint_forces: jtp.VectorLike | None = None,
|
518
|
-
|
929
|
+
link_forces: jtp.MatrixLike | None = None,
|
519
930
|
) -> tuple[jtp.Vector, jtp.Vector]:
|
520
931
|
"""
|
521
932
|
Compute the forward dynamics of the model with the ABA algorithm.
|
@@ -525,8 +936,8 @@ def forward_dynamics_aba(
|
|
525
936
|
data: The data of the considered model.
|
526
937
|
joint_forces:
|
527
938
|
The joint forces to consider as a vector of shape `(dofs,)`.
|
528
|
-
|
529
|
-
The
|
939
|
+
link_forces:
|
940
|
+
The link 6D forces to consider as a matrix of shape `(nL, 6)`.
|
530
941
|
The frame in which they are expressed must be `data.velocity_representation`.
|
531
942
|
|
532
943
|
Returns:
|
@@ -535,86 +946,132 @@ def forward_dynamics_aba(
|
|
535
946
|
considered joint forces and external forces.
|
536
947
|
"""
|
537
948
|
|
538
|
-
#
|
949
|
+
# ============
|
950
|
+
# Prepare data
|
951
|
+
# ============
|
952
|
+
|
953
|
+
# Build joint forces, if not provided.
|
539
954
|
τ = (
|
540
|
-
joint_forces
|
955
|
+
jnp.atleast_1d(joint_forces.squeeze())
|
541
956
|
if joint_forces is not None
|
542
957
|
else jnp.zeros_like(data.joint_positions())
|
543
958
|
)
|
544
959
|
|
545
|
-
# Build
|
546
|
-
|
547
|
-
|
548
|
-
if
|
960
|
+
# Build link forces, if not provided.
|
961
|
+
f_L = (
|
962
|
+
jnp.atleast_2d(link_forces.squeeze())
|
963
|
+
if link_forces is not None
|
549
964
|
else jnp.zeros((model.number_of_links(), 6))
|
550
965
|
)
|
551
966
|
|
552
|
-
#
|
553
|
-
|
554
|
-
model=model
|
555
|
-
|
556
|
-
|
557
|
-
|
558
|
-
|
559
|
-
f_ext=f_ext,
|
967
|
+
# Create a references object that simplifies converting among representations.
|
968
|
+
references = js.references.JaxSimModelReferences.build(
|
969
|
+
model=model,
|
970
|
+
joint_force_references=τ,
|
971
|
+
link_forces=f_L,
|
972
|
+
data=data,
|
973
|
+
velocity_representation=data.velocity_representation,
|
560
974
|
)
|
561
975
|
|
562
|
-
|
563
|
-
|
976
|
+
# Extract the link and joint serializations.
|
977
|
+
link_names = model.link_names()
|
978
|
+
joint_names = model.joint_names()
|
979
|
+
|
980
|
+
# Extract the state in inertial-fixed representation.
|
981
|
+
with data.switch_velocity_representation(VelRepr.Inertial):
|
982
|
+
W_p_B = data.base_position()
|
983
|
+
W_v_WB = data.base_velocity()
|
984
|
+
W_Q_B = data.base_orientation(dcm=False)
|
985
|
+
s = data.joint_positions(model=model, joint_names=joint_names)
|
986
|
+
ṡ = data.joint_velocities(model=model, joint_names=joint_names)
|
987
|
+
|
988
|
+
# Extract the inputs in inertial-fixed representation.
|
989
|
+
with references.switch_velocity_representation(VelRepr.Inertial):
|
990
|
+
W_f_L = references.link_forces(model=model, data=data, link_names=link_names)
|
991
|
+
τ = references.joint_force_references(model=model, joint_names=joint_names)
|
992
|
+
|
993
|
+
# ========================
|
994
|
+
# Compute forward dynamics
|
995
|
+
# ========================
|
996
|
+
|
997
|
+
W_v̇_WB, s̈ = jaxsim.rbda.aba(
|
998
|
+
model=model,
|
999
|
+
base_position=W_p_B,
|
1000
|
+
base_quaternion=W_Q_B,
|
1001
|
+
joint_positions=s,
|
1002
|
+
base_linear_velocity=W_v_WB[0:3],
|
1003
|
+
base_angular_velocity=W_v_WB[3:6],
|
1004
|
+
joint_velocities=ṡ,
|
1005
|
+
joint_forces=τ,
|
1006
|
+
link_forces=W_f_L,
|
1007
|
+
standard_gravity=data.standard_gravity(),
|
1008
|
+
)
|
564
1009
|
|
565
|
-
|
566
|
-
|
1010
|
+
# =============
|
1011
|
+
# Adjust output
|
1012
|
+
# =============
|
567
1013
|
|
568
|
-
|
1014
|
+
def to_active(
|
1015
|
+
W_v̇_WB: jtp.Vector, W_H_C: jtp.Matrix, W_v_WB: jtp.Vector, W_v_WC: jtp.Vector
|
1016
|
+
) -> jtp.Vector:
|
1017
|
+
"""
|
1018
|
+
Convert the inertial-fixed apparent base acceleration W_v̇_WB to
|
1019
|
+
another representation C_v̇_WB expressed in a generic frame C.
|
1020
|
+
"""
|
569
1021
|
|
570
|
-
|
571
|
-
|
1022
|
+
# In Mixed representation, we need to include a cross product in ℝ⁶.
|
1023
|
+
# In Inertial and Body representations, the cross product is always zero.
|
1024
|
+
C_X_W = Adjoint.from_transform(transform=W_H_C, inverse=True)
|
1025
|
+
return C_X_W @ (W_v̇_WB - Cross.vx(W_v_WC) @ W_v_WB)
|
572
1026
|
|
573
1027
|
match data.velocity_representation:
|
574
1028
|
case VelRepr.Inertial:
|
575
|
-
|
576
|
-
|
1029
|
+
# In this case C=W
|
1030
|
+
W_H_C = W_H_W = jnp.eye(4) # noqa: F841
|
1031
|
+
W_v_WC = W_v_WW = jnp.zeros(6) # noqa: F841
|
577
1032
|
|
578
1033
|
case VelRepr.Body:
|
1034
|
+
# In this case C=B
|
579
1035
|
W_H_C = W_H_B = data.base_transform()
|
580
|
-
|
1036
|
+
W_v_WC = W_v_WB
|
581
1037
|
|
582
1038
|
case VelRepr.Mixed:
|
1039
|
+
# In this case C=B[W]
|
583
1040
|
W_H_B = data.base_transform()
|
584
|
-
W_H_C = W_H_BW = W_H_B.at[0:3, 0:3].set(jnp.eye(3))
|
585
|
-
|
1041
|
+
W_H_C = W_H_BW = W_H_B.at[0:3, 0:3].set(jnp.eye(3)) # noqa: F841
|
1042
|
+
W_ṗ_B = data.base_velocity()[0:3]
|
1043
|
+
W_v_WC = W_v_W_BW = jnp.zeros(6).at[0:3].set(W_ṗ_B) # noqa: F841
|
586
1044
|
|
587
1045
|
case _:
|
588
1046
|
raise ValueError(data.velocity_representation)
|
589
1047
|
|
590
|
-
# We need to convert the derivative of the base
|
1048
|
+
# We need to convert the derivative of the base velocity to the active
|
591
1049
|
# representation. In Mixed representation, this conversion is not a plain
|
592
1050
|
# transformation with just X, but it also involves a cross product in ℝ⁶.
|
593
1051
|
C_v̇_WB = to_active(
|
594
|
-
|
1052
|
+
W_v̇_WB=W_v̇_WB,
|
595
1053
|
W_H_C=W_H_C,
|
596
|
-
W_v_WB=
|
597
|
-
|
598
|
-
data.state.physics_model.base_linear_velocity,
|
599
|
-
data.state.physics_model.base_angular_velocity,
|
600
|
-
]
|
601
|
-
),
|
602
|
-
W_vl_WC=W_vl_WC,
|
1054
|
+
W_v_WB=W_v_WB,
|
1055
|
+
W_v_WC=W_v_WC,
|
603
1056
|
)
|
604
1057
|
|
605
|
-
#
|
606
|
-
|
1058
|
+
# The ABA algorithm already returns a zero base 6D acceleration for
|
1059
|
+
# fixed-based models. However, the to_active function introduces an
|
1060
|
+
# additional acceleration component in Mixed representation.
|
1061
|
+
# Here below we make sure that the base acceleration is zero.
|
1062
|
+
C_v̇_WB = C_v̇_WB if model.floating_base() else jnp.zeros(6)
|
607
1063
|
|
608
|
-
return C_v̇_WB, s
|
1064
|
+
return C_v̇_WB.astype(float), s̈.astype(float)
|
609
1065
|
|
610
1066
|
|
611
1067
|
@jax.jit
|
1068
|
+
@js.common.named_scope
|
612
1069
|
def forward_dynamics_crb(
|
613
1070
|
model: JaxSimModel,
|
614
1071
|
data: js.data.JaxSimModelData,
|
615
1072
|
*,
|
616
1073
|
joint_forces: jtp.VectorLike | None = None,
|
617
|
-
|
1074
|
+
link_forces: jtp.MatrixLike | None = None,
|
618
1075
|
) -> tuple[jtp.Vector, jtp.Vector]:
|
619
1076
|
"""
|
620
1077
|
Compute the forward dynamics of the model with the CRB algorithm.
|
@@ -624,8 +1081,8 @@ def forward_dynamics_crb(
|
|
624
1081
|
data: The data of the considered model.
|
625
1082
|
joint_forces:
|
626
1083
|
The joint forces to consider as a vector of shape `(dofs,)`.
|
627
|
-
|
628
|
-
The
|
1084
|
+
link_forces:
|
1085
|
+
The link 6D forces to consider as a matrix of shape `(nL, 6)`.
|
629
1086
|
The frame in which they are expressed must be `data.velocity_representation`.
|
630
1087
|
|
631
1088
|
Returns:
|
@@ -638,21 +1095,25 @@ def forward_dynamics_crb(
|
|
638
1095
|
models with a large number of degrees of freedom.
|
639
1096
|
"""
|
640
1097
|
|
641
|
-
#
|
1098
|
+
# ============
|
1099
|
+
# Prepare data
|
1100
|
+
# ============
|
1101
|
+
|
1102
|
+
# Build joint torques if not provided.
|
642
1103
|
τ = (
|
643
1104
|
jnp.atleast_1d(joint_forces)
|
644
1105
|
if joint_forces is not None
|
645
1106
|
else jnp.zeros_like(data.joint_positions())
|
646
1107
|
)
|
647
1108
|
|
648
|
-
# Build external forces if not provided
|
1109
|
+
# Build external forces if not provided.
|
649
1110
|
f = (
|
650
|
-
jnp.atleast_2d(
|
651
|
-
if
|
1111
|
+
jnp.atleast_2d(link_forces)
|
1112
|
+
if link_forces is not None
|
652
1113
|
else jnp.zeros(shape=(model.number_of_links(), 6))
|
653
1114
|
)
|
654
1115
|
|
655
|
-
# Compute terms of the floating-base EoM
|
1116
|
+
# Compute terms of the floating-base EoM.
|
656
1117
|
M = free_floating_mass_matrix(model=model, data=data)
|
657
1118
|
h = free_floating_bias_forces(model=model, data=data)
|
658
1119
|
S = jnp.block([jnp.zeros(shape=(model.dofs(), 6)), jnp.eye(model.dofs())]).T
|
@@ -660,6 +1121,10 @@ def forward_dynamics_crb(
|
|
660
1121
|
|
661
1122
|
# TODO: invert the Mss block exploiting sparsity defined by the parent array λ(i)
|
662
1123
|
|
1124
|
+
# ========================
|
1125
|
+
# Compute forward dynamics
|
1126
|
+
# ========================
|
1127
|
+
|
663
1128
|
if model.floating_base():
|
664
1129
|
# l: number of links.
|
665
1130
|
# g: generalized coordinates, 6 + number of joints.
|
@@ -675,19 +1140,24 @@ def forward_dynamics_crb(
|
|
675
1140
|
v̇_WB = jnp.zeros(6)
|
676
1141
|
ν̇ = jnp.hstack([v̇_WB, s̈.squeeze()])
|
677
1142
|
|
1143
|
+
# =============
|
1144
|
+
# Adjust output
|
1145
|
+
# =============
|
1146
|
+
|
678
1147
|
# Extract the base acceleration in the active representation.
|
679
1148
|
# Note that this is an apparent acceleration (relevant in Mixed representation),
|
680
1149
|
# therefore it cannot be always expressed in different frames with just a
|
681
1150
|
# 6D transformation X.
|
682
1151
|
v̇_WB = ν̇[0:6].squeeze().astype(float)
|
683
1152
|
|
684
|
-
# Extract the joint accelerations
|
1153
|
+
# Extract the joint accelerations.
|
685
1154
|
s̈ = jnp.atleast_1d(ν̇[6:].squeeze()).astype(float)
|
686
1155
|
|
687
1156
|
return v̇_WB, s̈
|
688
1157
|
|
689
1158
|
|
690
1159
|
@jax.jit
|
1160
|
+
@js.common.named_scope
|
691
1161
|
def free_floating_mass_matrix(
|
692
1162
|
model: JaxSimModel, data: js.data.JaxSimModelData
|
693
1163
|
) -> jtp.Matrix:
|
@@ -702,9 +1172,9 @@ def free_floating_mass_matrix(
|
|
702
1172
|
The free-floating mass matrix of the model.
|
703
1173
|
"""
|
704
1174
|
|
705
|
-
M_body = jaxsim.
|
706
|
-
model=model
|
707
|
-
|
1175
|
+
M_body = jaxsim.rbda.crba(
|
1176
|
+
model=model,
|
1177
|
+
joint_positions=data.state.physics_model.joint_positions,
|
708
1178
|
)
|
709
1179
|
|
710
1180
|
match data.velocity_representation:
|
@@ -712,29 +1182,19 @@ def free_floating_mass_matrix(
|
|
712
1182
|
return M_body
|
713
1183
|
|
714
1184
|
case VelRepr.Inertial:
|
715
|
-
zero_6n = jnp.zeros(shape=(6, model.dofs()))
|
716
|
-
B_X_W = jaxlie.SE3.from_matrix(data.base_transform()).inverse().adjoint()
|
717
1185
|
|
718
|
-
|
719
|
-
|
720
|
-
jnp.block([B_X_W, zero_6n]),
|
721
|
-
jnp.block([zero_6n.T, jnp.eye(model.dofs())]),
|
722
|
-
]
|
1186
|
+
B_X_W = Adjoint.from_transform(
|
1187
|
+
transform=data.base_transform(), inverse=True
|
723
1188
|
)
|
1189
|
+
invT = jax.scipy.linalg.block_diag(B_X_W, jnp.eye(model.dofs()))
|
724
1190
|
|
725
1191
|
return invT.T @ M_body @ invT
|
726
1192
|
|
727
1193
|
case VelRepr.Mixed:
|
728
|
-
|
729
|
-
|
730
|
-
|
731
|
-
|
732
|
-
invT = jnp.vstack(
|
733
|
-
[
|
734
|
-
jnp.block([BW_X_W, zero_6n]),
|
735
|
-
jnp.block([zero_6n.T, jnp.eye(model.dofs())]),
|
736
|
-
]
|
737
|
-
)
|
1194
|
+
|
1195
|
+
BW_H_B = data.base_transform().at[0:3, 3].set(jnp.zeros(3))
|
1196
|
+
B_X_BW = Adjoint.from_transform(transform=BW_H_B, inverse=True)
|
1197
|
+
invT = jax.scipy.linalg.block_diag(B_X_BW, jnp.eye(model.dofs()))
|
738
1198
|
|
739
1199
|
return invT.T @ M_body @ invT
|
740
1200
|
|
@@ -743,77 +1203,206 @@ def free_floating_mass_matrix(
|
|
743
1203
|
|
744
1204
|
|
745
1205
|
@jax.jit
|
746
|
-
|
747
|
-
|
748
|
-
data: js.data.JaxSimModelData
|
749
|
-
|
750
|
-
joint_accelerations: jtp.Vector | None = None,
|
751
|
-
base_acceleration: jtp.Vector | None = None,
|
752
|
-
external_forces: jtp.Matrix | None = None,
|
753
|
-
) -> tuple[jtp.Vector, jtp.Vector]:
|
1206
|
+
@js.common.named_scope
|
1207
|
+
def free_floating_coriolis_matrix(
|
1208
|
+
model: JaxSimModel, data: js.data.JaxSimModelData
|
1209
|
+
) -> jtp.Matrix:
|
754
1210
|
"""
|
755
|
-
Compute
|
1211
|
+
Compute the free-floating Coriolis matrix of the model.
|
756
1212
|
|
757
1213
|
Args:
|
758
1214
|
model: The model to consider.
|
759
1215
|
data: The data of the considered model.
|
760
|
-
joint_accelerations:
|
761
|
-
The joint accelerations to consider as a vector of shape `(dofs,)`.
|
762
|
-
base_acceleration:
|
763
|
-
The base acceleration to consider as a vector of shape `(6,)`.
|
764
|
-
external_forces:
|
765
|
-
The external forces to consider as a matrix of shape `(nL, 6)`.
|
766
|
-
The frame in which they are expressed must be `data.velocity_representation`.
|
767
1216
|
|
768
1217
|
Returns:
|
769
|
-
|
770
|
-
|
771
|
-
|
1218
|
+
The free-floating Coriolis matrix of the model.
|
1219
|
+
|
1220
|
+
Note:
|
1221
|
+
This function, contrary to other quantities of the equations of motion,
|
1222
|
+
does not exploit any iterative algorithm. Therefore, the computation of
|
1223
|
+
the Coriolis matrix may be much slower than other quantities.
|
772
1224
|
"""
|
773
1225
|
|
774
|
-
#
|
775
|
-
|
776
|
-
|
777
|
-
|
778
|
-
else jnp.zeros_like(data.joint_positions())
|
779
|
-
)
|
1226
|
+
# We perform all the calculations in body-fixed representation.
|
1227
|
+
# The Coriolis matrix computed in this representation is converted later
|
1228
|
+
# to the active representation stored in data.
|
1229
|
+
with data.switch_velocity_representation(VelRepr.Body):
|
780
1230
|
|
781
|
-
|
782
|
-
base_acceleration = (
|
783
|
-
base_acceleration if base_acceleration is not None else jnp.zeros(6)
|
784
|
-
)
|
1231
|
+
B_ν = data.generalized_velocity()
|
785
1232
|
|
786
|
-
|
787
|
-
|
788
|
-
if external_forces is not None
|
789
|
-
else jnp.zeros(shape=(model.number_of_links(), 6))
|
790
|
-
)
|
1233
|
+
# Doubly-left free-floating Jacobian.
|
1234
|
+
L_J_WL_B = generalized_free_floating_jacobian(model=model, data=data)
|
791
1235
|
|
792
|
-
|
793
|
-
|
794
|
-
C_X_W = jaxlie.SE3.from_matrix(W_H_C).inverse().adjoint()
|
1236
|
+
# Doubly-left free-floating Jacobian derivative.
|
1237
|
+
L_J̇_WL_B = generalized_free_floating_jacobian_derivative(model=model, data=data)
|
795
1238
|
|
796
|
-
|
797
|
-
return W_X_C @ C_v̇_WB
|
798
|
-
else:
|
799
|
-
from jaxsim.math.cross import Cross
|
1239
|
+
L_M_L = link_spatial_inertia_matrices(model=model)
|
800
1240
|
|
801
|
-
|
802
|
-
|
1241
|
+
# Body-fixed link velocities.
|
1242
|
+
# Note: we could have called link.velocity() instead of computing it ourselves,
|
1243
|
+
# but since we need the link Jacobians later, we can save a double calculation.
|
1244
|
+
L_v_WL = jax.vmap(lambda J: J @ B_ν)(L_J_WL_B)
|
803
1245
|
|
804
|
-
|
805
|
-
|
806
|
-
W_H_C = W_H_W = jnp.eye(4)
|
807
|
-
W_vl_WC = W_vl_WW = jnp.zeros(3)
|
1246
|
+
# Compute the contribution of each link to the Coriolis matrix.
|
1247
|
+
def compute_link_contribution(M, v, J, J̇) -> jtp.Array:
|
808
1248
|
|
809
|
-
|
810
|
-
W_H_C = W_H_B = data.base_transform()
|
811
|
-
W_vl_WC = W_vl_WB = data.base_velocity()[0:3]
|
1249
|
+
return J.T @ ((Cross.vx_star(v) @ M + M @ Cross.vx(v)) @ J + M @ J̇)
|
812
1250
|
|
813
|
-
|
814
|
-
|
815
|
-
|
816
|
-
|
1251
|
+
C_B_links = jax.vmap(compute_link_contribution)(
|
1252
|
+
L_M_L,
|
1253
|
+
L_v_WL,
|
1254
|
+
L_J_WL_B,
|
1255
|
+
L_J̇_WL_B,
|
1256
|
+
)
|
1257
|
+
|
1258
|
+
# We need to adjust the Coriolis matrix for fixed-base models.
|
1259
|
+
# In this case, the base link does not contribute to the matrix, and we need to zero
|
1260
|
+
# the off-diagonal terms mapping joint quantities onto the base configuration.
|
1261
|
+
if model.floating_base():
|
1262
|
+
C_B = C_B_links.sum(axis=0)
|
1263
|
+
else:
|
1264
|
+
C_B = C_B_links[1:].sum(axis=0)
|
1265
|
+
C_B = C_B.at[0:6, 6:].set(0.0)
|
1266
|
+
C_B = C_B.at[6:, 0:6].set(0.0)
|
1267
|
+
|
1268
|
+
# Adjust the representation of the Coriolis matrix.
|
1269
|
+
# Refer to https://github.com/traversaro/traversaro-phd-thesis, Section 3.6.
|
1270
|
+
match data.velocity_representation:
|
1271
|
+
|
1272
|
+
case VelRepr.Body:
|
1273
|
+
return C_B
|
1274
|
+
|
1275
|
+
case VelRepr.Inertial:
|
1276
|
+
|
1277
|
+
n = model.dofs()
|
1278
|
+
W_H_B = data.base_transform()
|
1279
|
+
B_X_W = jaxsim.math.Adjoint.from_transform(W_H_B, inverse=True)
|
1280
|
+
B_T_W = jax.scipy.linalg.block_diag(B_X_W, jnp.eye(n))
|
1281
|
+
|
1282
|
+
with data.switch_velocity_representation(VelRepr.Inertial):
|
1283
|
+
W_v_WB = data.base_velocity()
|
1284
|
+
B_Ẋ_W = -B_X_W @ jaxsim.math.Cross.vx(W_v_WB)
|
1285
|
+
|
1286
|
+
B_Ṫ_W = jax.scipy.linalg.block_diag(B_Ẋ_W, jnp.zeros(shape=(n, n)))
|
1287
|
+
|
1288
|
+
with data.switch_velocity_representation(VelRepr.Body):
|
1289
|
+
M = free_floating_mass_matrix(model=model, data=data)
|
1290
|
+
|
1291
|
+
C = B_T_W.T @ (M @ B_Ṫ_W + C_B @ B_T_W)
|
1292
|
+
|
1293
|
+
return C
|
1294
|
+
|
1295
|
+
case VelRepr.Mixed:
|
1296
|
+
|
1297
|
+
n = model.dofs()
|
1298
|
+
BW_H_B = data.base_transform().at[0:3, 3].set(jnp.zeros(3))
|
1299
|
+
B_X_BW = jaxsim.math.Adjoint.from_transform(transform=BW_H_B, inverse=True)
|
1300
|
+
B_T_BW = jax.scipy.linalg.block_diag(B_X_BW, jnp.eye(n))
|
1301
|
+
|
1302
|
+
with data.switch_velocity_representation(VelRepr.Mixed):
|
1303
|
+
BW_v_WB = data.base_velocity()
|
1304
|
+
BW_v_W_BW = BW_v_WB.at[3:6].set(jnp.zeros(3))
|
1305
|
+
|
1306
|
+
BW_v_BW_B = BW_v_WB - BW_v_W_BW
|
1307
|
+
B_Ẋ_BW = -B_X_BW @ jaxsim.math.Cross.vx(BW_v_BW_B)
|
1308
|
+
|
1309
|
+
B_Ṫ_BW = jax.scipy.linalg.block_diag(B_Ẋ_BW, jnp.zeros(shape=(n, n)))
|
1310
|
+
|
1311
|
+
with data.switch_velocity_representation(VelRepr.Body):
|
1312
|
+
M = free_floating_mass_matrix(model=model, data=data)
|
1313
|
+
|
1314
|
+
C = B_T_BW.T @ (M @ B_Ṫ_BW + C_B @ B_T_BW)
|
1315
|
+
|
1316
|
+
return C
|
1317
|
+
|
1318
|
+
case _:
|
1319
|
+
raise ValueError(data.velocity_representation)
|
1320
|
+
|
1321
|
+
|
1322
|
+
@jax.jit
|
1323
|
+
@js.common.named_scope
|
1324
|
+
def inverse_dynamics(
|
1325
|
+
model: JaxSimModel,
|
1326
|
+
data: js.data.JaxSimModelData,
|
1327
|
+
*,
|
1328
|
+
joint_accelerations: jtp.VectorLike | None = None,
|
1329
|
+
base_acceleration: jtp.VectorLike | None = None,
|
1330
|
+
link_forces: jtp.MatrixLike | None = None,
|
1331
|
+
) -> tuple[jtp.Vector, jtp.Vector]:
|
1332
|
+
"""
|
1333
|
+
Compute inverse dynamics with the RNEA algorithm.
|
1334
|
+
|
1335
|
+
Args:
|
1336
|
+
model: The model to consider.
|
1337
|
+
data: The data of the considered model.
|
1338
|
+
joint_accelerations:
|
1339
|
+
The joint accelerations to consider as a vector of shape `(dofs,)`.
|
1340
|
+
base_acceleration:
|
1341
|
+
The base acceleration to consider as a vector of shape `(6,)`.
|
1342
|
+
link_forces:
|
1343
|
+
The link 6D forces to consider as a matrix of shape `(nL, 6)`.
|
1344
|
+
The frame in which they are expressed must be `data.velocity_representation`.
|
1345
|
+
|
1346
|
+
Returns:
|
1347
|
+
A tuple containing the 6D force in the active representation applied to the
|
1348
|
+
base to obtain the considered base acceleration, and the joint forces to apply
|
1349
|
+
to obtain the considered joint accelerations.
|
1350
|
+
"""
|
1351
|
+
|
1352
|
+
# ============
|
1353
|
+
# Prepare data
|
1354
|
+
# ============
|
1355
|
+
|
1356
|
+
# Build joint accelerations, if not provided.
|
1357
|
+
s̈ = (
|
1358
|
+
jnp.atleast_1d(jnp.array(joint_accelerations).squeeze())
|
1359
|
+
if joint_accelerations is not None
|
1360
|
+
else jnp.zeros_like(data.joint_positions())
|
1361
|
+
)
|
1362
|
+
|
1363
|
+
# Build base acceleration, if not provided.
|
1364
|
+
v̇_WB = (
|
1365
|
+
jnp.array(base_acceleration).squeeze()
|
1366
|
+
if base_acceleration is not None
|
1367
|
+
else jnp.zeros(6)
|
1368
|
+
)
|
1369
|
+
|
1370
|
+
# Build link forces, if not provided.
|
1371
|
+
f_L = (
|
1372
|
+
jnp.atleast_2d(jnp.array(link_forces).squeeze())
|
1373
|
+
if link_forces is not None
|
1374
|
+
else jnp.zeros(shape=(model.number_of_links(), 6))
|
1375
|
+
)
|
1376
|
+
|
1377
|
+
def to_inertial(C_v̇_WB, W_H_C, C_v_WB, W_v_WC):
|
1378
|
+
"""
|
1379
|
+
Convert the active representation of the base acceleration C_v̇_WB
|
1380
|
+
expressed in a generic frame C to the inertial-fixed representation W_v̇_WB.
|
1381
|
+
"""
|
1382
|
+
|
1383
|
+
W_X_C = Adjoint.from_transform(transform=W_H_C)
|
1384
|
+
C_X_W = Adjoint.from_transform(transform=W_H_C, inverse=True)
|
1385
|
+
C_v_WC = C_X_W @ W_v_WC
|
1386
|
+
|
1387
|
+
# In Mixed representation, we need to include a cross product in ℝ⁶.
|
1388
|
+
# In Inertial and Body representations, the cross product is always zero.
|
1389
|
+
return W_X_C @ (C_v̇_WB + Cross.vx(C_v_WC) @ C_v_WB)
|
1390
|
+
|
1391
|
+
match data.velocity_representation:
|
1392
|
+
case VelRepr.Inertial:
|
1393
|
+
W_H_C = W_H_W = jnp.eye(4) # noqa: F841
|
1394
|
+
W_v_WC = W_v_WW = jnp.zeros(6) # noqa: F841
|
1395
|
+
|
1396
|
+
case VelRepr.Body:
|
1397
|
+
W_H_C = W_H_B = data.base_transform()
|
1398
|
+
with data.switch_velocity_representation(VelRepr.Inertial):
|
1399
|
+
W_v_WC = W_v_WB = data.base_velocity()
|
1400
|
+
|
1401
|
+
case VelRepr.Mixed:
|
1402
|
+
W_H_B = data.base_transform()
|
1403
|
+
W_H_C = W_H_BW = W_H_B.at[0:3, 0:3].set(jnp.eye(3)) # noqa: F841
|
1404
|
+
W_ṗ_B = data.base_velocity()[0:3]
|
1405
|
+
W_v_WC = W_v_W_BW = jnp.zeros(6).at[0:3].set(W_ṗ_B) # noqa: F841
|
817
1406
|
|
818
1407
|
case _:
|
819
1408
|
raise ValueError(data.velocity_representation)
|
@@ -822,35 +1411,60 @@ def inverse_dynamics(
|
|
822
1411
|
# representation. In Mixed representation, this conversion is not a plain
|
823
1412
|
# transformation with just X, but it also involves a cross product in ℝ⁶.
|
824
1413
|
W_v̇_WB = to_inertial(
|
825
|
-
C_v̇_WB=
|
1414
|
+
C_v̇_WB=v̇_WB,
|
826
1415
|
W_H_C=W_H_C,
|
827
1416
|
C_v_WB=data.base_velocity(),
|
828
|
-
|
1417
|
+
W_v_WC=W_v_WC,
|
829
1418
|
)
|
830
1419
|
|
1420
|
+
# Create a references object that simplifies converting among representations.
|
831
1421
|
references = js.references.JaxSimModelReferences.build(
|
832
1422
|
model=model,
|
833
1423
|
data=data,
|
834
|
-
link_forces=
|
1424
|
+
link_forces=f_L,
|
835
1425
|
velocity_representation=data.velocity_representation,
|
836
1426
|
)
|
837
1427
|
|
838
|
-
#
|
1428
|
+
# Extract the link and joint serializations.
|
1429
|
+
link_names = model.link_names()
|
1430
|
+
joint_names = model.joint_names()
|
1431
|
+
|
1432
|
+
# Extract the state in inertial-fixed representation.
|
1433
|
+
with data.switch_velocity_representation(VelRepr.Inertial):
|
1434
|
+
W_p_B = data.base_position()
|
1435
|
+
W_v_WB = data.base_velocity()
|
1436
|
+
W_Q_B = data.base_orientation(dcm=False)
|
1437
|
+
s = data.joint_positions(model=model, joint_names=joint_names)
|
1438
|
+
ṡ = data.joint_velocities(model=model, joint_names=joint_names)
|
1439
|
+
|
1440
|
+
# Extract the inputs in inertial-fixed representation.
|
839
1441
|
with references.switch_velocity_representation(VelRepr.Inertial):
|
840
|
-
|
841
|
-
|
842
|
-
|
843
|
-
|
844
|
-
|
845
|
-
|
846
|
-
|
847
|
-
|
848
|
-
|
1442
|
+
W_f_L = references.link_forces(model=model, data=data, link_names=link_names)
|
1443
|
+
|
1444
|
+
# ========================
|
1445
|
+
# Compute inverse dynamics
|
1446
|
+
# ========================
|
1447
|
+
|
1448
|
+
W_f_B, τ = jaxsim.rbda.rnea(
|
1449
|
+
model=model,
|
1450
|
+
base_position=W_p_B,
|
1451
|
+
base_quaternion=W_Q_B,
|
1452
|
+
joint_positions=s,
|
1453
|
+
base_linear_velocity=W_v_WB[0:3],
|
1454
|
+
base_angular_velocity=W_v_WB[3:6],
|
1455
|
+
joint_velocities=ṡ,
|
1456
|
+
base_linear_acceleration=W_v̇_WB[0:3],
|
1457
|
+
base_angular_acceleration=W_v̇_WB[3:6],
|
1458
|
+
joint_accelerations=s̈,
|
1459
|
+
link_forces=W_f_L,
|
1460
|
+
standard_gravity=data.standard_gravity(),
|
1461
|
+
)
|
849
1462
|
|
850
|
-
#
|
851
|
-
|
1463
|
+
# =============
|
1464
|
+
# Adjust output
|
1465
|
+
# =============
|
852
1466
|
|
853
|
-
# Express W_f_B in the active representation
|
1467
|
+
# Express W_f_B in the active representation.
|
854
1468
|
f_B = js.data.JaxSimModelData.inertial_to_other_representation(
|
855
1469
|
array=W_f_B,
|
856
1470
|
other_representation=data.velocity_representation,
|
@@ -862,10 +1476,11 @@ def inverse_dynamics(
|
|
862
1476
|
|
863
1477
|
|
864
1478
|
@jax.jit
|
1479
|
+
@js.common.named_scope
|
865
1480
|
def free_floating_gravity_forces(
|
866
1481
|
model: JaxSimModel, data: js.data.JaxSimModelData
|
867
1482
|
) -> jtp.Vector:
|
868
|
-
"""
|
1483
|
+
r"""
|
869
1484
|
Compute the free-floating gravity forces :math:`g(\mathbf{q})` of the model.
|
870
1485
|
|
871
1486
|
Args:
|
@@ -876,12 +1491,12 @@ def free_floating_gravity_forces(
|
|
876
1491
|
The free-floating gravity forces of the model.
|
877
1492
|
"""
|
878
1493
|
|
879
|
-
# Build a zeroed state
|
1494
|
+
# Build a zeroed state.
|
880
1495
|
data_rnea = js.data.JaxSimModelData.zero(
|
881
1496
|
model=model, velocity_representation=data.velocity_representation
|
882
1497
|
)
|
883
1498
|
|
884
|
-
# Set just the generalized position
|
1499
|
+
# Set just the generalized position.
|
885
1500
|
with data_rnea.mutable_context(
|
886
1501
|
mutability=Mutability.MUTABLE, restore_after_exception=False
|
887
1502
|
):
|
@@ -905,16 +1520,17 @@ def free_floating_gravity_forces(
|
|
905
1520
|
# Set zero inputs:
|
906
1521
|
joint_accelerations=jnp.atleast_1d(jnp.zeros(model.dofs())),
|
907
1522
|
base_acceleration=jnp.zeros(6),
|
908
|
-
|
1523
|
+
link_forces=jnp.zeros(shape=(model.number_of_links(), 6)),
|
909
1524
|
)
|
910
1525
|
).astype(float)
|
911
1526
|
|
912
1527
|
|
913
1528
|
@jax.jit
|
1529
|
+
@js.common.named_scope
|
914
1530
|
def free_floating_bias_forces(
|
915
1531
|
model: JaxSimModel, data: js.data.JaxSimModelData
|
916
1532
|
) -> jtp.Vector:
|
917
|
-
"""
|
1533
|
+
r"""
|
918
1534
|
Compute the free-floating bias forces :math:`h(\mathbf{q}, \boldsymbol{\nu})`
|
919
1535
|
of the model.
|
920
1536
|
|
@@ -926,12 +1542,12 @@ def free_floating_bias_forces(
|
|
926
1542
|
The free-floating bias forces of the model.
|
927
1543
|
"""
|
928
1544
|
|
929
|
-
# Build a zeroed state
|
1545
|
+
# Build a zeroed state.
|
930
1546
|
data_rnea = js.data.JaxSimModelData.zero(
|
931
1547
|
model=model, velocity_representation=data.velocity_representation
|
932
1548
|
)
|
933
1549
|
|
934
|
-
# Set the generalized position and generalized velocity
|
1550
|
+
# Set the generalized position and generalized velocity.
|
935
1551
|
with data_rnea.mutable_context(
|
936
1552
|
mutability=Mutability.MUTABLE, restore_after_exception=False
|
937
1553
|
):
|
@@ -948,18 +1564,20 @@ def free_floating_bias_forces(
|
|
948
1564
|
data.state.physics_model.joint_positions
|
949
1565
|
)
|
950
1566
|
|
951
|
-
data_rnea.state.physics_model.base_linear_velocity = (
|
952
|
-
data.state.physics_model.base_linear_velocity
|
953
|
-
)
|
954
|
-
|
955
|
-
data_rnea.state.physics_model.base_angular_velocity = (
|
956
|
-
data.state.physics_model.base_angular_velocity
|
957
|
-
)
|
958
|
-
|
959
1567
|
data_rnea.state.physics_model.joint_velocities = (
|
960
1568
|
data.state.physics_model.joint_velocities
|
961
1569
|
)
|
962
1570
|
|
1571
|
+
# Copy the base velocity only for floating-base models, so that it stays zero for fixed-base models.
|
1572
|
+
if model.floating_base():
|
1573
|
+
data_rnea.state.physics_model.base_linear_velocity = (
|
1574
|
+
data.state.physics_model.base_linear_velocity
|
1575
|
+
)
|
1576
|
+
|
1577
|
+
data_rnea.state.physics_model.base_angular_velocity = (
|
1578
|
+
data.state.physics_model.base_angular_velocity
|
1579
|
+
)
|
1580
|
+
|
963
1581
|
return jnp.hstack(
|
964
1582
|
inverse_dynamics(
|
965
1583
|
model=model,
|
@@ -967,7 +1585,7 @@ def free_floating_bias_forces(
|
|
967
1585
|
# Set zero inputs:
|
968
1586
|
joint_accelerations=jnp.atleast_1d(jnp.zeros(model.dofs())),
|
969
1587
|
base_acceleration=jnp.zeros(6),
|
970
|
-
|
1588
|
+
link_forces=jnp.zeros(shape=(model.number_of_links(), 6)),
|
971
1589
|
)
|
972
1590
|
).astype(float)
|
973
1591
|
|
@@ -978,6 +1596,26 @@ def free_floating_bias_forces(
|
|
978
1596
|
|
979
1597
|
|
980
1598
|
@jax.jit
|
1599
|
+
@js.common.named_scope
|
1600
|
+
def locked_spatial_inertia(
|
1601
|
+
model: JaxSimModel, data: js.data.JaxSimModelData
|
1602
|
+
) -> jtp.Matrix:
|
1603
|
+
"""
|
1604
|
+
Compute the locked 6D inertia matrix of the model.
|
1605
|
+
|
1606
|
+
Args:
|
1607
|
+
model: The model to consider.
|
1608
|
+
data: The data of the considered model.
|
1609
|
+
|
1610
|
+
Returns:
|
1611
|
+
The locked 6D inertia matrix of the model.
|
1612
|
+
"""
|
1613
|
+
|
1614
|
+
return total_momentum_jacobian(model=model, data=data)[:, 0:6]
|
1615
|
+
|
1616
|
+
|
1617
|
+
@jax.jit
|
1618
|
+
@js.common.named_scope
|
981
1619
|
def total_momentum(model: JaxSimModel, data: js.data.JaxSimModelData) -> jtp.Vector:
|
982
1620
|
"""
|
983
1621
|
Compute the total momentum of the model.
|
@@ -987,35 +1625,453 @@ def total_momentum(model: JaxSimModel, data: js.data.JaxSimModelData) -> jtp.Vec
|
|
987
1625
|
data: The data of the considered model.
|
988
1626
|
|
989
1627
|
Returns:
|
990
|
-
The total momentum of the model.
|
1628
|
+
The total momentum of the model in the active velocity representation.
|
991
1629
|
"""
|
992
1630
|
|
993
|
-
|
994
|
-
|
995
|
-
|
996
|
-
|
997
|
-
|
998
|
-
|
1631
|
+
ν = data.generalized_velocity()
|
1632
|
+
Jh = total_momentum_jacobian(model=model, data=data)
|
1633
|
+
|
1634
|
+
return Jh @ ν
|
1635
|
+
|
1636
|
+
|
1637
|
+
@functools.partial(jax.jit, static_argnames=["output_vel_repr"])
|
1638
|
+
def total_momentum_jacobian(
|
1639
|
+
model: JaxSimModel,
|
1640
|
+
data: js.data.JaxSimModelData,
|
1641
|
+
*,
|
1642
|
+
output_vel_repr: VelRepr | None = None,
|
1643
|
+
) -> jtp.Matrix:
|
1644
|
+
"""
|
1645
|
+
Compute the jacobian of the total momentum.
|
1646
|
+
|
1647
|
+
Args:
|
1648
|
+
model: The model to consider.
|
1649
|
+
data: The data of the considered model.
|
1650
|
+
output_vel_repr: The output velocity representation of the jacobian.
|
1651
|
+
|
1652
|
+
Returns:
|
1653
|
+
The jacobian of the total momentum of the model expressed in the output velocity representation (by default, the active representation of `data`).
|
1654
|
+
"""
|
1655
|
+
|
1656
|
+
output_vel_repr = (
|
1657
|
+
output_vel_repr if output_vel_repr is not None else data.velocity_representation
|
1658
|
+
)
|
1659
|
+
|
1660
|
+
if output_vel_repr is data.velocity_representation:
|
1661
|
+
return free_floating_mass_matrix(model=model, data=data)[0:6]
|
1662
|
+
|
1663
|
+
with data.switch_velocity_representation(VelRepr.Body):
|
1664
|
+
B_Jh_B = free_floating_mass_matrix(model=model, data=data)[0:6]
|
1665
|
+
|
1666
|
+
match data.velocity_representation:
|
1667
|
+
case VelRepr.Body:
|
1668
|
+
B_Jh = B_Jh_B
|
1669
|
+
|
1670
|
+
case VelRepr.Inertial:
|
1671
|
+
B_X_W = Adjoint.from_transform(
|
1672
|
+
transform=data.base_transform(), inverse=True
|
1673
|
+
)
|
1674
|
+
B_Jh = B_Jh_B @ jax.scipy.linalg.block_diag(B_X_W, jnp.eye(model.dofs()))
|
1675
|
+
|
1676
|
+
case VelRepr.Mixed:
|
1677
|
+
BW_H_B = data.base_transform().at[0:3, 3].set(jnp.zeros(3))
|
1678
|
+
B_X_BW = Adjoint.from_transform(transform=BW_H_B, inverse=True)
|
1679
|
+
B_Jh = B_Jh_B @ jax.scipy.linalg.block_diag(B_X_BW, jnp.eye(model.dofs()))
|
1680
|
+
|
1681
|
+
case _:
|
1682
|
+
raise ValueError(data.velocity_representation)
|
1683
|
+
|
1684
|
+
match output_vel_repr:
|
1685
|
+
case VelRepr.Body:
|
1686
|
+
return B_Jh
|
1687
|
+
|
1688
|
+
case VelRepr.Inertial:
|
1689
|
+
W_H_B = data.base_transform()
|
1690
|
+
B_Xv_W = Adjoint.from_transform(transform=W_H_B, inverse=True)
|
1691
|
+
W_Xf_B = B_Xv_W.T
|
1692
|
+
W_Jh = W_Xf_B @ B_Jh
|
1693
|
+
return W_Jh
|
1694
|
+
|
1695
|
+
case VelRepr.Mixed:
|
1696
|
+
BW_H_B = data.base_transform().at[0:3, 3].set(jnp.zeros(3))
|
1697
|
+
B_Xv_BW = Adjoint.from_transform(transform=BW_H_B, inverse=True)
|
1698
|
+
BW_Xf_B = B_Xv_BW.T
|
1699
|
+
BW_Jh = BW_Xf_B @ B_Jh
|
1700
|
+
return BW_Jh
|
1701
|
+
|
1702
|
+
case _:
|
1703
|
+
raise ValueError(output_vel_repr)
|
1704
|
+
|
1705
|
+
|
1706
|
+
@jax.jit
|
1707
|
+
@js.common.named_scope
|
1708
|
+
def average_velocity(model: JaxSimModel, data: js.data.JaxSimModelData) -> jtp.Vector:
|
1709
|
+
"""
|
1710
|
+
Compute the average velocity of the model.
|
1711
|
+
|
1712
|
+
Args:
|
1713
|
+
model: The model to consider.
|
1714
|
+
data: The data of the considered model.
|
1715
|
+
|
1716
|
+
Returns:
|
1717
|
+
The average velocity of the model computed in the base frame and expressed
|
1718
|
+
in the active representation.
|
1719
|
+
"""
|
1720
|
+
|
1721
|
+
ν = data.generalized_velocity()
|
1722
|
+
J = average_velocity_jacobian(model=model, data=data)
|
1723
|
+
|
1724
|
+
return J @ ν
|
1725
|
+
|
1726
|
+
|
1727
|
+
@functools.partial(jax.jit, static_argnames=["output_vel_repr"])
|
1728
|
+
def average_velocity_jacobian(
|
1729
|
+
model: JaxSimModel,
|
1730
|
+
data: js.data.JaxSimModelData,
|
1731
|
+
*,
|
1732
|
+
output_vel_repr: VelRepr | None = None,
|
1733
|
+
) -> jtp.Matrix:
|
1734
|
+
"""
|
1735
|
+
Compute the Jacobian of the average velocity of the model.
|
1736
|
+
|
1737
|
+
Args:
|
1738
|
+
model: The model to consider.
|
1739
|
+
data: The data of the considered model.
|
1740
|
+
output_vel_repr: The output velocity representation of the jacobian.
|
1741
|
+
|
1742
|
+
Returns:
|
1743
|
+
The Jacobian of the average velocity of the model expressed in the desired
|
1744
|
+
representation.
|
1745
|
+
"""
|
1746
|
+
|
1747
|
+
output_vel_repr = (
|
1748
|
+
output_vel_repr if output_vel_repr is not None else data.velocity_representation
|
1749
|
+
)
|
1750
|
+
|
1751
|
+
# Depending on the velocity representation, the frame G is either G[W] or G[B].
|
1752
|
+
G_J = js.com.average_centroidal_velocity_jacobian(model=model, data=data)
|
1753
|
+
|
1754
|
+
match output_vel_repr:
|
1755
|
+
|
1756
|
+
case VelRepr.Inertial:
|
1757
|
+
|
1758
|
+
GW_J = G_J
|
1759
|
+
W_p_CoM = js.com.com_position(model=model, data=data)
|
1760
|
+
|
1761
|
+
W_H_GW = jnp.eye(4).at[0:3, 3].set(W_p_CoM)
|
1762
|
+
W_X_GW = Adjoint.from_transform(transform=W_H_GW)
|
1763
|
+
|
1764
|
+
return W_X_GW @ GW_J
|
1765
|
+
|
1766
|
+
case VelRepr.Body:
|
1767
|
+
|
1768
|
+
GB_J = G_J
|
1769
|
+
W_p_B = data.base_position()
|
1770
|
+
W_p_CoM = js.com.com_position(model=model, data=data)
|
1771
|
+
B_R_W = data.base_orientation(dcm=True).transpose()
|
1772
|
+
|
1773
|
+
B_H_GB = jnp.eye(4).at[0:3, 3].set(B_R_W @ (W_p_CoM - W_p_B))
|
1774
|
+
B_X_GB = Adjoint.from_transform(transform=B_H_GB)
|
1775
|
+
|
1776
|
+
return B_X_GB @ GB_J
|
1777
|
+
|
1778
|
+
case VelRepr.Mixed:
|
999
1779
|
|
1000
|
-
|
1001
|
-
|
1780
|
+
GW_J = G_J
|
1781
|
+
W_p_B = data.base_position()
|
1782
|
+
W_p_CoM = js.com.com_position(model=model, data=data)
|
1002
1783
|
|
1003
|
-
|
1784
|
+
BW_H_GW = jnp.eye(4).at[0:3, 3].set(W_p_CoM - W_p_B)
|
1785
|
+
BW_X_GW = Adjoint.from_transform(transform=BW_H_GW)
|
1786
|
+
|
1787
|
+
return BW_X_GW @ GW_J
|
1788
|
+
|
1789
|
+
|
1790
|
+
# ========================
|
1791
|
+
# Other dynamic quantities
|
1792
|
+
# ========================
|
1793
|
+
|
1794
|
+
|
1795
|
+
@jax.jit
|
1796
|
+
@js.common.named_scope
|
1797
|
+
def link_bias_accelerations(
|
1798
|
+
model: JaxSimModel,
|
1799
|
+
data: js.data.JaxSimModelData,
|
1800
|
+
) -> jtp.Vector:
|
1801
|
+
r"""
|
1802
|
+
Compute the bias accelerations of the links of the model.
|
1803
|
+
|
1804
|
+
Args:
|
1805
|
+
model: The model to consider.
|
1806
|
+
data: The data of the considered model.
|
1807
|
+
|
1808
|
+
Returns:
|
1809
|
+
The bias accelerations of the links of the model.
|
1810
|
+
|
1811
|
+
Note:
|
1812
|
+
This function computes the component of the total 6D acceleration not due to
|
1813
|
+
the joint or base acceleration.
|
1814
|
+
It is often called :math:`\dot{J} \boldsymbol{\nu}`.
|
1815
|
+
"""
|
1816
|
+
|
1817
|
+
# ================================================
|
1818
|
+
# Compute the body-fixed zero base 6D acceleration
|
1819
|
+
# ================================================
|
1820
|
+
|
1821
|
+
# Compute the base transform.
|
1004
1822
|
W_H_B = data.base_transform()
|
1005
|
-
B_X_W: jtp.Array = jaxlie.SE3.from_matrix(W_H_B).inverse().adjoint()
|
1006
1823
|
|
1007
|
-
|
1008
|
-
|
1009
|
-
|
1824
|
+
def other_representation_to_inertial(
|
1825
|
+
C_v̇_WB: jtp.Vector, C_v_WB: jtp.Vector, W_H_C: jtp.Matrix, W_v_WC: jtp.Vector
|
1826
|
+
) -> jtp.Vector:
|
1827
|
+
"""
|
1828
|
+
Convert the active representation of the base acceleration C_v̇_WB
|
1829
|
+
expressed in a generic frame C to the inertial-fixed representation W_v̇_WB.
|
1830
|
+
"""
|
1831
|
+
|
1832
|
+
W_X_C = Adjoint.from_transform(transform=W_H_C)
|
1833
|
+
C_X_W = Adjoint.from_transform(transform=W_H_C, inverse=True)
|
1010
1834
|
|
1011
|
-
|
1012
|
-
|
1013
|
-
|
1014
|
-
|
1015
|
-
|
1016
|
-
|
1835
|
+
# In Mixed representation, we need to include a cross product in ℝ⁶.
|
1836
|
+
# In Inertial and Body representations, the cross product is always zero.
|
1837
|
+
return W_X_C @ (C_v̇_WB + jaxsim.math.Cross.vx(C_X_W @ W_v_WC) @ C_v_WB)
|
1838
|
+
|
1839
|
+
# Here we initialize a zero 6D acceleration in the active representation, and
|
1840
|
+
# convert it to inertial-fixed. This is a useful intermediate representation
|
1841
|
+
# because the apparent acceleration W_v̇_WB is equal to the intrinsic acceleration
|
1842
|
+
# W_a_WB, and intrinsic accelerations can be expressed in different frames through
|
1843
|
+
# a simple C_X_W 6D transform.
|
1844
|
+
match data.velocity_representation:
|
1845
|
+
case VelRepr.Inertial:
|
1846
|
+
W_H_C = W_H_W = jnp.eye(4) # noqa: F841
|
1847
|
+
W_v_WC = W_v_WW = jnp.zeros(6) # noqa: F841
|
1848
|
+
with data.switch_velocity_representation(VelRepr.Inertial):
|
1849
|
+
C_v_WB = W_v_WB = data.base_velocity()
|
1850
|
+
|
1851
|
+
case VelRepr.Body:
|
1852
|
+
W_H_C = W_H_B
|
1853
|
+
with data.switch_velocity_representation(VelRepr.Inertial):
|
1854
|
+
W_v_WC = W_v_WB = data.base_velocity() # noqa: F841
|
1855
|
+
with data.switch_velocity_representation(VelRepr.Body):
|
1856
|
+
C_v_WB = B_v_WB = data.base_velocity()
|
1857
|
+
|
1858
|
+
case VelRepr.Mixed:
|
1859
|
+
W_H_BW = W_H_B.at[0:3, 0:3].set(jnp.eye(3))
|
1860
|
+
W_H_C = W_H_BW
|
1861
|
+
with data.switch_velocity_representation(VelRepr.Mixed):
|
1862
|
+
W_ṗ_B = data.base_velocity()[0:3]
|
1863
|
+
BW_v_W_BW = jnp.zeros(6).at[0:3].set(W_ṗ_B)
|
1864
|
+
W_X_BW = jaxsim.math.Adjoint.from_transform(transform=W_H_BW)
|
1865
|
+
W_v_WC = W_v_W_BW = W_X_BW @ BW_v_W_BW # noqa: F841
|
1866
|
+
with data.switch_velocity_representation(VelRepr.Mixed):
|
1867
|
+
C_v_WB = BW_v_WB = data.base_velocity() # noqa: F841
|
1868
|
+
|
1869
|
+
case _:
|
1870
|
+
raise ValueError(data.velocity_representation)
|
1871
|
+
|
1872
|
+
# Convert a zero 6D acceleration from the active representation to inertial-fixed.
|
1873
|
+
W_v̇_WB = other_representation_to_inertial(
|
1874
|
+
C_v̇_WB=jnp.zeros(6), C_v_WB=C_v_WB, W_H_C=W_H_C, W_v_WC=W_v_WC
|
1875
|
+
)
|
1876
|
+
|
1877
|
+
# ===================================
|
1878
|
+
# Initialize buffers and prepare data
|
1879
|
+
# ===================================
|
1880
|
+
|
1881
|
+
# Get the parent array λ(i).
|
1882
|
+
# Note: λ(0) must not be used, it's initialized to -1.
|
1883
|
+
λ = model.kin_dyn_parameters.parent_array
|
1884
|
+
|
1885
|
+
# Compute 6D transforms of the base velocity.
|
1886
|
+
B_X_W = jaxsim.math.Adjoint.from_transform(transform=W_H_B, inverse=True)
|
1887
|
+
|
1888
|
+
# Compute the parent-to-child adjoints and the motion subspaces of the joints.
|
1889
|
+
# These transforms define the relative kinematics of the entire model, including
|
1890
|
+
# the base transform for both floating-base and fixed-base models.
|
1891
|
+
i_X_λi, S = model.kin_dyn_parameters.joint_transforms_and_motion_subspaces(
|
1892
|
+
joint_positions=data.joint_positions(), base_transform=W_H_B
|
1893
|
+
)
|
1894
|
+
|
1895
|
+
# Allocate the buffer to store the body-fixed link velocities.
|
1896
|
+
L_v_WL = jnp.zeros(shape=(model.number_of_links(), 6))
|
1897
|
+
|
1898
|
+
# Store the base velocity.
|
1899
|
+
with data.switch_velocity_representation(VelRepr.Body):
|
1900
|
+
B_v_WB = data.base_velocity()
|
1901
|
+
L_v_WL = L_v_WL.at[0].set(B_v_WB)
|
1902
|
+
|
1903
|
+
# Get the joint velocities.
|
1904
|
+
ṡ = data.joint_velocities(model=model, joint_names=model.joint_names())
|
1905
|
+
|
1906
|
+
# Allocate the buffer to store the body-fixed link accelerations,
|
1907
|
+
# and initialize the base acceleration.
|
1908
|
+
L_v̇_WL = jnp.zeros(shape=(model.number_of_links(), 6))
|
1909
|
+
L_v̇_WL = L_v̇_WL.at[0].set(B_X_W @ W_v̇_WB)
|
1910
|
+
|
1911
|
+
# ======================================
|
1912
|
+
# Propagate accelerations and velocities
|
1913
|
+
# ======================================
|
1914
|
+
|
1915
|
+
# The computation of the bias accelerations is similar to the forward pass of RNEA,
|
1916
|
+
# this time with zero base and joint accelerations. Furthermore, here we do
|
1917
|
+
# not remove gravity during the propagation.
|
1918
|
+
|
1919
|
+
# Initialize the loop.
|
1920
|
+
Carry = tuple[jtp.Matrix, jtp.Matrix]
|
1921
|
+
carry0: Carry = (L_v_WL, L_v̇_WL)
|
1922
|
+
|
1923
|
+
def propagate_accelerations(carry: Carry, i: jtp.Int) -> tuple[Carry, None]:
|
1924
|
+
# Initialize index and unpack the carry.
|
1925
|
+
ii = i - 1
|
1926
|
+
v, a = carry
|
1927
|
+
|
1928
|
+
# Get the motion subspace of the joint.
|
1929
|
+
Si = S[i].squeeze()
|
1930
|
+
|
1931
|
+
# Project the joint velocity into its motion subspace.
|
1932
|
+
vJ = Si * ṡ[ii]
|
1933
|
+
|
1934
|
+
# Propagate the link body-fixed velocity.
|
1935
|
+
v_i = i_X_λi[i] @ v[λ[i]] + vJ
|
1936
|
+
v = v.at[i].set(v_i)
|
1937
|
+
|
1938
|
+
# Propagate the link body-fixed acceleration considering zero joint acceleration.
|
1939
|
+
s̈ = 0.0
|
1940
|
+
a_i = i_X_λi[i] @ a[λ[i]] + Si * s̈ + jaxsim.math.Cross.vx(v[i]) @ vJ
|
1941
|
+
a = a.at[i].set(a_i)
|
1942
|
+
|
1943
|
+
return (v, a), None
|
1944
|
+
|
1945
|
+
# Compute the body-fixed velocity and body-fixed apparent acceleration of the links.
|
1946
|
+
(L_v_WL, L_v̇_WL), _ = (
|
1947
|
+
jax.lax.scan(
|
1948
|
+
f=propagate_accelerations,
|
1949
|
+
init=carry0,
|
1950
|
+
xs=jnp.arange(start=1, stop=model.number_of_links()),
|
1951
|
+
)
|
1952
|
+
if model.number_of_links() > 1
|
1953
|
+
else [(L_v_WL, L_v̇_WL), None]
|
1954
|
+
)
|
1955
|
+
|
1956
|
+
# ===================================================================
|
1957
|
+
# Convert the body-fixed 6D acceleration to the active representation
|
1958
|
+
# ===================================================================
|
1959
|
+
|
1960
|
+
def body_to_other_representation(
|
1961
|
+
L_v̇_WL: jtp.Vector, L_v_WL: jtp.Vector, C_H_L: jtp.Matrix, L_v_CL: jtp.Vector
|
1962
|
+
) -> jtp.Vector:
|
1963
|
+
"""
|
1964
|
+
Convert the body-fixed apparent acceleration L_v̇_WL to
|
1965
|
+
another representation C_v̇_WL expressed in a generic frame C.
|
1966
|
+
"""
|
1967
|
+
|
1968
|
+
# In Mixed representation, we need to include a cross product in ℝ⁶.
|
1969
|
+
# In Inertial and Body representations, the cross product is always zero.
|
1970
|
+
C_X_L = jaxsim.math.Adjoint.from_transform(transform=C_H_L)
|
1971
|
+
return C_X_L @ (L_v̇_WL + jaxsim.math.Cross.vx(L_v_CL) @ L_v_WL)
|
1972
|
+
|
1973
|
+
match data.velocity_representation:
|
1974
|
+
case VelRepr.Body:
|
1975
|
+
C_H_L = L_H_L = jnp.stack( # noqa: F841
|
1976
|
+
[jnp.eye(4)] * model.number_of_links()
|
1977
|
+
)
|
1978
|
+
L_v_CL = L_v_LL = jnp.zeros( # noqa: F841
|
1979
|
+
shape=(model.number_of_links(), 6)
|
1980
|
+
)
|
1981
|
+
|
1982
|
+
case VelRepr.Inertial:
|
1983
|
+
C_H_L = W_H_L = js.model.forward_kinematics(model=model, data=data)
|
1984
|
+
L_v_CL = L_v_WL
|
1985
|
+
|
1986
|
+
case VelRepr.Mixed:
|
1987
|
+
W_H_L = js.model.forward_kinematics(model=model, data=data)
|
1988
|
+
LW_H_L = jax.vmap(lambda W_H_L: W_H_L.at[0:3, 3].set(jnp.zeros(3)))(W_H_L)
|
1989
|
+
C_H_L = LW_H_L
|
1990
|
+
L_v_CL = L_v_LW_L = jax.vmap( # noqa: F841
|
1991
|
+
lambda v: v.at[0:3].set(jnp.zeros(3))
|
1992
|
+
)(L_v_WL)
|
1993
|
+
|
1994
|
+
case _:
|
1995
|
+
raise ValueError(data.velocity_representation)
|
1996
|
+
|
1997
|
+
# Convert from body-fixed to the active representation.
|
1998
|
+
O_v̇_WL = jax.vmap(body_to_other_representation)(
|
1999
|
+
L_v̇_WL=L_v̇_WL, L_v_WL=L_v_WL, C_H_L=C_H_L, L_v_CL=L_v_CL
|
2000
|
+
)
|
2001
|
+
|
2002
|
+
return O_v̇_WL
|
2003
|
+
|
2004
|
+
|
2005
|
+
@jax.jit
|
2006
|
+
@js.common.named_scope
|
2007
|
+
def link_contact_forces(
|
2008
|
+
model: js.model.JaxSimModel,
|
2009
|
+
data: js.data.JaxSimModelData,
|
2010
|
+
*,
|
2011
|
+
link_forces: jtp.MatrixLike | None = None,
|
2012
|
+
joint_force_references: jtp.VectorLike | None = None,
|
2013
|
+
**kwargs,
|
2014
|
+
) -> jtp.Matrix:
|
2015
|
+
"""
|
2016
|
+
Compute the 6D contact forces of all links of the model.
|
2017
|
+
|
2018
|
+
Args:
|
2019
|
+
model: The model to consider.
|
2020
|
+
data: The data of the considered model.
|
2021
|
+
link_forces:
|
2022
|
+
The 6D external forces to apply to the links expressed in the same
|
2023
|
+
representation of data.
|
2024
|
+
joint_force_references:
|
2025
|
+
The joint force references to apply to the joints.
|
2026
|
+
kwargs: Additional keyword arguments to pass to the active contact model.
|
2027
|
+
|
2028
|
+
Returns:
|
2029
|
+
A `(nL, 6)` array containing the stacked 6D contact forces of the links,
|
2030
|
+
expressed in the frame corresponding to the active representation.
|
2031
|
+
"""
|
2032
|
+
|
2033
|
+
# Note: the following code should be kept in sync with the function
|
2034
|
+
# `jaxsim.api.ode.system_velocity_dynamics`. We cannot merge them since
|
2035
|
+
# there we need to get also aux_data.
|
2036
|
+
|
2037
|
+
# Build link forces if not provided.
|
2038
|
+
# These forces are expressed in the frame corresponding to the velocity
|
2039
|
+
# representation of data.
|
2040
|
+
O_f_L = (
|
2041
|
+
jnp.atleast_2d(link_forces.squeeze())
|
2042
|
+
if link_forces is not None
|
2043
|
+
else jnp.zeros((model.number_of_links(), 6))
|
1017
2044
|
).astype(float)
|
1018
2045
|
|
2046
|
+
# Build joint force references if not provided.
|
2047
|
+
joint_force_references = (
|
2048
|
+
jnp.atleast_1d(joint_force_references)
|
2049
|
+
if joint_force_references is not None
|
2050
|
+
else jnp.zeros(model.dofs())
|
2051
|
+
)
|
2052
|
+
|
2053
|
+
# We expect that the 6D forces included in the `link_forces` argument are expressed
|
2054
|
+
# in the frame corresponding to the velocity representation of `data`.
|
2055
|
+
input_references = js.references.JaxSimModelReferences.build(
|
2056
|
+
model=model,
|
2057
|
+
data=data,
|
2058
|
+
velocity_representation=data.velocity_representation,
|
2059
|
+
link_forces=O_f_L,
|
2060
|
+
joint_force_references=joint_force_references,
|
2061
|
+
)
|
2062
|
+
|
2063
|
+
# Compute the 6D forces applied to the links equivalent to the forces applied
|
2064
|
+
# to the frames associated to the collidable points.
|
2065
|
+
f_L, _ = model.contact_model.compute_link_contact_forces(
|
2066
|
+
model=model,
|
2067
|
+
data=data,
|
2068
|
+
link_forces=input_references.link_forces(model=model, data=data),
|
2069
|
+
joint_force_references=input_references.joint_force_references(),
|
2070
|
+
**kwargs,
|
2071
|
+
)
|
2072
|
+
|
2073
|
+
return f_L
|
2074
|
+
|
1019
2075
|
|
1020
2076
|
# ======
|
1021
2077
|
# Energy
|
@@ -1023,6 +2079,7 @@ def total_momentum(model: JaxSimModel, data: js.data.JaxSimModelData) -> jtp.Vec
|
|
1023
2079
|
|
1024
2080
|
|
1025
2081
|
@jax.jit
|
2082
|
+
@js.common.named_scope
|
1026
2083
|
def mechanical_energy(model: JaxSimModel, data: js.data.JaxSimModelData) -> jtp.Float:
|
1027
2084
|
"""
|
1028
2085
|
Compute the mechanical energy of the model.
|
@@ -1042,6 +2099,7 @@ def mechanical_energy(model: JaxSimModel, data: js.data.JaxSimModelData) -> jtp.
|
|
1042
2099
|
|
1043
2100
|
|
1044
2101
|
@jax.jit
|
2102
|
+
@js.common.named_scope
|
1045
2103
|
def kinetic_energy(model: JaxSimModel, data: js.data.JaxSimModelData) -> jtp.Float:
|
1046
2104
|
"""
|
1047
2105
|
Compute the kinetic energy of the model.
|
@@ -1063,6 +2121,7 @@ def kinetic_energy(model: JaxSimModel, data: js.data.JaxSimModelData) -> jtp.Flo
|
|
1063
2121
|
|
1064
2122
|
|
1065
2123
|
@jax.jit
|
2124
|
+
@js.common.named_scope
|
1066
2125
|
def potential_energy(model: JaxSimModel, data: js.data.JaxSimModelData) -> jtp.Float:
|
1067
2126
|
"""
|
1068
2127
|
Compute the potential energy of the model.
|
@@ -1077,7 +2136,7 @@ def potential_energy(model: JaxSimModel, data: js.data.JaxSimModelData) -> jtp.F
|
|
1077
2136
|
|
1078
2137
|
m = total_mass(model=model)
|
1079
2138
|
gravity = data.gravity.squeeze()
|
1080
|
-
W_p̃_CoM = jnp.hstack([com_position(model=model, data=data), 1])
|
2139
|
+
W_p̃_CoM = jnp.hstack([js.com.com_position(model=model, data=data), 1])
|
1081
2140
|
|
1082
2141
|
U = -jnp.hstack([gravity, 0]) @ (m * W_p̃_CoM)
|
1083
2142
|
return U.squeeze().astype(float)
|
@@ -1089,15 +2148,18 @@ def potential_energy(model: JaxSimModel, data: js.data.JaxSimModelData) -> jtp.F
|
|
1089
2148
|
|
1090
2149
|
|
1091
2150
|
@jax.jit
|
2151
|
+
@js.common.named_scope
|
1092
2152
|
def step(
|
1093
2153
|
model: JaxSimModel,
|
1094
2154
|
data: js.data.JaxSimModelData,
|
1095
2155
|
*,
|
1096
|
-
|
1097
|
-
|
1098
|
-
|
1099
|
-
|
1100
|
-
|
2156
|
+
t0: jtp.FloatLike = 0.0,
|
2157
|
+
dt: jtp.FloatLike | None = None,
|
2158
|
+
integrator: jaxsim.integrators.Integrator | None = None,
|
2159
|
+
integrator_metadata: dict[str, Any] | None = None,
|
2160
|
+
link_forces: jtp.MatrixLike | None = None,
|
2161
|
+
joint_force_references: jtp.VectorLike | None = None,
|
2162
|
+
**kwargs,
|
1101
2163
|
) -> tuple[js.data.JaxSimModelData, dict[str, Any]]:
|
1102
2164
|
"""
|
1103
2165
|
Perform a simulation step.
|
@@ -1105,40 +2167,188 @@ def step(
|
|
1105
2167
|
Args:
|
1106
2168
|
model: The model to consider.
|
1107
2169
|
data: The data of the considered model.
|
1108
|
-
dt: The time step to consider.
|
1109
2170
|
integrator: The integrator to use.
|
1110
|
-
|
1111
|
-
|
1112
|
-
|
1113
|
-
|
1114
|
-
The
|
2171
|
+
integrator_metadata: The metadata of the integrator, if needed.
|
2172
|
+
t0: The initial time to consider. Only relevant for time-dependent dynamics.
|
2173
|
+
dt: The time step to consider. If not specified, it is read from the model.
|
2174
|
+
link_forces:
|
2175
|
+
The 6D forces to apply to the links expressed in the frame corresponding to
|
2176
|
+
the velocity representation of `data`.
|
2177
|
+
joint_force_references: The joint force references to consider.
|
2178
|
+
kwargs: Additional kwargs to pass to the integrator.
|
1115
2179
|
|
1116
2180
|
Returns:
|
1117
|
-
A tuple containing the new data of the model
|
1118
|
-
|
2181
|
+
A tuple containing the new data of the model and a dictionary of auxiliary
|
2182
|
+
data computed during the step. If the integrator has metadata, the dictionary
|
2183
|
+
will contain the new metadata stored in the `integrator_metadata` key.
|
2184
|
+
|
2185
|
+
Note:
|
2186
|
+
In order to reduce the occurrences of frame conversions performed internally,
|
2187
|
+
it is recommended to use inertial-fixed velocity representation. This can be
|
2188
|
+
particularly useful for automatically differentiated logic.
|
1119
2189
|
"""
|
1120
2190
|
|
1121
|
-
|
2191
|
+
# Extract the integrator kwargs.
|
2192
|
+
# The following logic allows using integrators having kwargs colliding with the
|
2193
|
+
# kwargs of this step function.
|
2194
|
+
kwargs = kwargs if kwargs is not None else {}
|
2195
|
+
integrator_kwargs = kwargs.pop("integrator_kwargs", {})
|
2196
|
+
integrator_kwargs = kwargs | integrator_kwargs
|
2197
|
+
|
2198
|
+
# Extract the integrator and the optional metadata.
|
2199
|
+
integrator_metadata_t0 = integrator_metadata
|
2200
|
+
integrator = integrator if integrator is not None else model.integrator
|
2201
|
+
|
2202
|
+
# Initialize the time-related variables.
|
2203
|
+
state_t0 = data.state
|
2204
|
+
t0 = jnp.array(t0, dtype=float)
|
2205
|
+
dt = jnp.array(dt if dt is not None else model.time_step).astype(float)
|
2206
|
+
|
2207
|
+
# The visco-elastic contacts work best with their own integrator.
|
2208
|
+
# They can be used with Euler-like integrators, paying the price of ignoring
|
2209
|
+
# some of the benefits of continuous-time integration on the system position.
|
2210
|
+
# Furthermore, the requirement to know the Δt used by the integrator is not
|
2211
|
+
# compatible with high-order integrators, that use advanced RK stages to evaluate
|
2212
|
+
# the dynamics at intermediate times.
|
2213
|
+
module = jaxsim.rbda.contacts.visco_elastic.step.__module__
|
2214
|
+
name = jaxsim.rbda.contacts.visco_elastic.step.__name__
|
2215
|
+
msg = "You need to use the custom '{}.{}' function with this contact model."
|
2216
|
+
jaxsim.exceptions.raise_runtime_error_if(
|
2217
|
+
condition=(
|
2218
|
+
isinstance(model.contact_model, jaxsim.rbda.contacts.ViscoElasticContacts)
|
2219
|
+
& (
|
2220
|
+
~jnp.allclose(dt, model.time_step)
|
2221
|
+
| ~int(
|
2222
|
+
isinstance(integrator, jaxsim.integrators.fixed_step.ForwardEuler)
|
2223
|
+
)
|
2224
|
+
)
|
2225
|
+
),
|
2226
|
+
msg=msg.format(module, name),
|
2227
|
+
)
|
1122
2228
|
|
1123
|
-
#
|
1124
|
-
|
1125
|
-
|
1126
|
-
|
2229
|
+
# =================
|
2230
|
+
# Phase 1: pre-step
|
2231
|
+
# =================
|
2232
|
+
|
2233
|
+
# TODO: some contact models here may want to perform a dynamic filtering of
|
2234
|
+
# the enabled collidable points.
|
2235
|
+
|
2236
|
+
# Build the references object.
|
2237
|
+
# We assume that the link forces are expressed in the frame corresponding to the
|
2238
|
+
# velocity representation of the data.
|
2239
|
+
references = js.references.JaxSimModelReferences.build(
|
2240
|
+
model=model,
|
2241
|
+
data=data,
|
2242
|
+
velocity_representation=data.velocity_representation,
|
2243
|
+
link_forces=link_forces,
|
2244
|
+
joint_force_references=joint_force_references,
|
2245
|
+
)
|
2246
|
+
|
2247
|
+
# =============
|
2248
|
+
# Phase 2: step
|
2249
|
+
# =============
|
2250
|
+
|
2251
|
+
# Prepare the references to pass.
|
2252
|
+
with references.switch_velocity_representation(data.velocity_representation):
|
2253
|
+
|
2254
|
+
f_L = references.link_forces(model=model, data=data)
|
2255
|
+
τ_references = references.joint_force_references(model=model)
|
1127
2256
|
|
1128
2257
|
# Step the dynamics forward.
|
1129
|
-
|
1130
|
-
x0=
|
1131
|
-
t0=
|
2258
|
+
state_tf, integrator_metadata_tf = integrator.step(
|
2259
|
+
x0=state_t0,
|
2260
|
+
t0=t0,
|
1132
2261
|
dt=dt,
|
1133
|
-
|
1134
|
-
|
2262
|
+
metadata=integrator_metadata_t0,
|
2263
|
+
# Always inject the current (model, data) pair into the system dynamics
|
2264
|
+
# considered by the integrator, and include the input variables represented
|
2265
|
+
# by the pair (f_L, τ_references).
|
2266
|
+
# Note that the wrapper of the system dynamics will override (state_x0, t0)
|
2267
|
+
# inside the passed data even if it is not strictly needed. This logic is
|
2268
|
+
# necessary to reuse the jit-compiled step function of compatible pytrees
|
2269
|
+
# of model and data produced e.g. by parameterized applications.
|
2270
|
+
**(
|
2271
|
+
dict(
|
2272
|
+
model=model,
|
2273
|
+
data=data,
|
2274
|
+
link_forces=f_L,
|
2275
|
+
joint_force_references=τ_references,
|
2276
|
+
)
|
2277
|
+
| integrator_kwargs
|
2278
|
+
),
|
1135
2279
|
)
|
1136
2280
|
|
1137
|
-
|
1138
|
-
|
1139
|
-
|
1140
|
-
|
1141
|
-
|
1142
|
-
|
1143
|
-
|
2281
|
+
# Store the new state of the model.
|
2282
|
+
data_tf = data.replace(state=state_tf)
|
2283
|
+
|
2284
|
+
# ==================
|
2285
|
+
# Phase 3: post-step
|
2286
|
+
# ==================
|
2287
|
+
|
2288
|
+
# Post process the simulation state, if needed.
|
2289
|
+
match model.contact_model:
|
2290
|
+
|
2291
|
+
# Rigid contact models use an impact model that produces discontinuous model velocities.
|
2292
|
+
# Hence, here we need to reset the velocity after each impact to guarantee that
|
2293
|
+
# the linear velocity of the active collidable points is zero.
|
2294
|
+
case jaxsim.rbda.contacts.RigidContacts():
|
2295
|
+
|
2296
|
+
# Raise a runtime error for the unsupported case in which rigid contacts with
|
2297
|
+
# Baumgarte stabilization are used together with a ForwardEuler integrator.
|
2298
|
+
jaxsim.exceptions.raise_runtime_error_if(
|
2299
|
+
condition=isinstance(
|
2300
|
+
integrator,
|
2301
|
+
jaxsim.integrators.fixed_step.ForwardEuler
|
2302
|
+
| jaxsim.integrators.fixed_step.ForwardEulerSO3,
|
2303
|
+
)
|
2304
|
+
& ((data_tf.contacts_params.K > 0) | (data_tf.contacts_params.D > 0)),
|
2305
|
+
msg="Baumgarte stabilization is not supported with ForwardEuler integrators",
|
2306
|
+
)
|
2307
|
+
|
2308
|
+
# Extract the indices corresponding to the enabled collidable points.
|
2309
|
+
indices_of_enabled_collidable_points = (
|
2310
|
+
model.kin_dyn_parameters.contact_parameters.indices_of_enabled_collidable_points
|
2311
|
+
)
|
2312
|
+
|
2313
|
+
W_p_C = js.contact.collidable_point_positions(model, data_tf)[
|
2314
|
+
indices_of_enabled_collidable_points
|
2315
|
+
]
|
2316
|
+
|
2317
|
+
# Compute the penetration depth of the collidable points.
|
2318
|
+
δ, *_ = jax.vmap(
|
2319
|
+
jaxsim.rbda.contacts.common.compute_penetration_data,
|
2320
|
+
in_axes=(0, 0, None),
|
2321
|
+
)(W_p_C, jnp.zeros_like(W_p_C), model.terrain)
|
2322
|
+
|
2323
|
+
with data_tf.switch_velocity_representation(VelRepr.Mixed):
|
2324
|
+
J_WC = js.contact.jacobian(model, data_tf)[
|
2325
|
+
indices_of_enabled_collidable_points
|
2326
|
+
]
|
2327
|
+
M = js.model.free_floating_mass_matrix(model, data_tf)
|
2328
|
+
BW_ν_pre_impact = data_tf.generalized_velocity()
|
2329
|
+
|
2330
|
+
# Compute the impact velocity.
|
2331
|
+
# It may be discontinuous in case new contacts are made.
|
2332
|
+
BW_ν_post_impact = (
|
2333
|
+
jaxsim.rbda.contacts.RigidContacts.compute_impact_velocity(
|
2334
|
+
generalized_velocity=BW_ν_pre_impact,
|
2335
|
+
inactive_collidable_points=(δ <= 0),
|
2336
|
+
M=M,
|
2337
|
+
J_WC=J_WC,
|
2338
|
+
)
|
2339
|
+
)
|
2340
|
+
|
2341
|
+
# Reset the generalized velocity.
|
2342
|
+
data_tf = data_tf.reset_base_velocity(BW_ν_post_impact[0:6])
|
2343
|
+
data_tf = data_tf.reset_joint_velocities(BW_ν_post_impact[6:])
|
2344
|
+
|
2345
|
+
# Restore the input velocity representation.
|
2346
|
+
data_tf = data_tf.replace(
|
2347
|
+
velocity_representation=data.velocity_representation, validate=False
|
2348
|
+
)
|
2349
|
+
|
2350
|
+
return data_tf, {} | (
|
2351
|
+
dict(integrator_metadata=integrator_metadata_tf)
|
2352
|
+
if integrator_metadata is not None
|
2353
|
+
else {}
|
1144
2354
|
)
|