jaxsim 0.2.dev191__py3-none-any.whl → 0.6.1.dev5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jaxsim/__init__.py +73 -22
- jaxsim/_version.py +2 -2
- jaxsim/api/__init__.py +13 -1
- jaxsim/api/com.py +423 -0
- jaxsim/api/common.py +48 -19
- jaxsim/api/contact.py +604 -52
- jaxsim/api/data.py +308 -163
- jaxsim/api/frame.py +471 -0
- jaxsim/api/joint.py +166 -37
- jaxsim/api/kin_dyn_parameters.py +901 -0
- jaxsim/api/link.py +277 -78
- jaxsim/api/model.py +1572 -362
- jaxsim/api/ode.py +324 -133
- jaxsim/api/ode_data.py +401 -0
- jaxsim/api/references.py +216 -80
- jaxsim/exceptions.py +80 -0
- jaxsim/integrators/__init__.py +2 -2
- jaxsim/integrators/common.py +191 -107
- jaxsim/integrators/fixed_step.py +97 -102
- jaxsim/integrators/variable_step.py +706 -0
- jaxsim/logging.py +1 -2
- jaxsim/math/__init__.py +13 -0
- jaxsim/math/adjoint.py +64 -30
- jaxsim/math/cross.py +18 -9
- jaxsim/math/inertia.py +11 -9
- jaxsim/math/joint_model.py +289 -0
- jaxsim/math/quaternion.py +59 -25
- jaxsim/math/rotation.py +30 -24
- jaxsim/math/skew.py +18 -7
- jaxsim/math/transform.py +102 -0
- jaxsim/math/utils.py +31 -0
- jaxsim/mujoco/__init__.py +2 -1
- jaxsim/mujoco/loaders.py +216 -29
- jaxsim/mujoco/model.py +163 -33
- jaxsim/mujoco/utils.py +228 -0
- jaxsim/mujoco/visualizer.py +107 -22
- jaxsim/parsers/__init__.py +0 -1
- jaxsim/parsers/descriptions/__init__.py +8 -2
- jaxsim/parsers/descriptions/collision.py +83 -26
- jaxsim/parsers/descriptions/joint.py +80 -87
- jaxsim/parsers/descriptions/link.py +58 -31
- jaxsim/parsers/descriptions/model.py +101 -68
- jaxsim/parsers/kinematic_graph.py +606 -229
- jaxsim/parsers/rod/meshes.py +104 -0
- jaxsim/parsers/rod/parser.py +125 -82
- jaxsim/parsers/rod/utils.py +127 -82
- jaxsim/rbda/__init__.py +11 -0
- jaxsim/rbda/aba.py +289 -0
- jaxsim/rbda/collidable_points.py +156 -0
- jaxsim/rbda/contacts/__init__.py +13 -0
- jaxsim/rbda/contacts/common.py +313 -0
- jaxsim/rbda/contacts/relaxed_rigid.py +605 -0
- jaxsim/rbda/contacts/rigid.py +462 -0
- jaxsim/rbda/contacts/soft.py +480 -0
- jaxsim/rbda/contacts/visco_elastic.py +1066 -0
- jaxsim/rbda/crba.py +167 -0
- jaxsim/rbda/forward_kinematics.py +117 -0
- jaxsim/rbda/jacobian.py +330 -0
- jaxsim/rbda/rnea.py +235 -0
- jaxsim/rbda/utils.py +160 -0
- jaxsim/terrain/__init__.py +2 -0
- jaxsim/terrain/terrain.py +238 -0
- jaxsim/typing.py +24 -24
- jaxsim/utils/__init__.py +1 -4
- jaxsim/utils/jaxsim_dataclass.py +289 -34
- jaxsim/utils/tracing.py +5 -11
- jaxsim/utils/wrappers.py +159 -0
- {jaxsim-0.2.dev191.dist-info → jaxsim-0.6.1.dev5.dist-info}/LICENSE +1 -1
- jaxsim-0.6.1.dev5.dist-info/METADATA +465 -0
- jaxsim-0.6.1.dev5.dist-info/RECORD +74 -0
- {jaxsim-0.2.dev191.dist-info → jaxsim-0.6.1.dev5.dist-info}/WHEEL +1 -1
- jaxsim/high_level/__init__.py +0 -2
- jaxsim/high_level/common.py +0 -11
- jaxsim/high_level/joint.py +0 -148
- jaxsim/high_level/link.py +0 -259
- jaxsim/high_level/model.py +0 -1686
- jaxsim/math/conv.py +0 -114
- jaxsim/math/joint.py +0 -102
- jaxsim/math/plucker.py +0 -100
- jaxsim/physics/__init__.py +0 -12
- jaxsim/physics/algos/__init__.py +0 -0
- jaxsim/physics/algos/aba.py +0 -254
- jaxsim/physics/algos/aba_motors.py +0 -284
- jaxsim/physics/algos/crba.py +0 -154
- jaxsim/physics/algos/forward_kinematics.py +0 -79
- jaxsim/physics/algos/jacobian.py +0 -98
- jaxsim/physics/algos/rnea.py +0 -180
- jaxsim/physics/algos/rnea_motors.py +0 -196
- jaxsim/physics/algos/soft_contacts.py +0 -523
- jaxsim/physics/algos/terrain.py +0 -78
- jaxsim/physics/algos/utils.py +0 -69
- jaxsim/physics/model/__init__.py +0 -0
- jaxsim/physics/model/ground_contact.py +0 -53
- jaxsim/physics/model/physics_model.py +0 -388
- jaxsim/physics/model/physics_model_state.py +0 -283
- jaxsim/simulation/__init__.py +0 -4
- jaxsim/simulation/integrators.py +0 -393
- jaxsim/simulation/ode.py +0 -290
- jaxsim/simulation/ode_data.py +0 -96
- jaxsim/simulation/ode_integration.py +0 -62
- jaxsim/simulation/simulator.py +0 -543
- jaxsim/simulation/simulator_callbacks.py +0 -79
- jaxsim/simulation/utils.py +0 -15
- jaxsim/sixd/__init__.py +0 -2
- jaxsim/utils/oop.py +0 -536
- jaxsim/utils/vmappable.py +0 -117
- jaxsim-0.2.dev191.dist-info/METADATA +0 -184
- jaxsim-0.2.dev191.dist-info/RECORD +0 -81
- {jaxsim-0.2.dev191.dist-info → jaxsim-0.6.1.dev5.dist-info}/top_level.txt +0 -0
jaxsim/high_level/model.py
DELETED
@@ -1,1686 +0,0 @@
|
|
1
|
-
import dataclasses
|
2
|
-
import functools
|
3
|
-
import pathlib
|
4
|
-
from typing import Any, Dict, List, Optional, Tuple, Union
|
5
|
-
|
6
|
-
import jax
|
7
|
-
import jax.numpy as jnp
|
8
|
-
import jax_dataclasses
|
9
|
-
import numpy as np
|
10
|
-
import rod
|
11
|
-
from jax_dataclasses import Static
|
12
|
-
|
13
|
-
import jaxsim.physics.algos.aba
|
14
|
-
import jaxsim.physics.algos.crba
|
15
|
-
import jaxsim.physics.algos.forward_kinematics
|
16
|
-
import jaxsim.physics.algos.rnea
|
17
|
-
import jaxsim.physics.model.physics_model
|
18
|
-
import jaxsim.physics.model.physics_model_state
|
19
|
-
import jaxsim.typing as jtp
|
20
|
-
from jaxsim import high_level, logging, physics, sixd
|
21
|
-
from jaxsim.physics.algos import soft_contacts
|
22
|
-
from jaxsim.physics.algos.terrain import FlatTerrain, Terrain
|
23
|
-
from jaxsim.utils import JaxsimDataclass, Mutability, Vmappable, oop
|
24
|
-
|
25
|
-
from .common import VelRepr
|
26
|
-
|
27
|
-
|
28
|
-
@jax_dataclasses.pytree_dataclass
class ModelData(JaxsimDataclass):
    """
    Snapshot of the model state, model input and contact state at one instant.
    """

    # Joint and base quantities (positions, velocities, base pose and twist).
    model_state: jaxsim.physics.model.physics_model_state.PhysicsModelState

    # Inputs applied to the model (joint generalized forces and link external forces).
    model_input: jaxsim.physics.model.physics_model_state.PhysicsModelInput

    # State of the soft-contacts model.
    contact_state: jaxsim.physics.algos.soft_contacts.SoftContactsState

    @staticmethod
    def zero(physics_model: physics.model.physics_model.PhysicsModel) -> "ModelData":
        """
        Create a ModelData whose fields are all zero and sized for the given model.

        Args:
            physics_model: The physics model that determines the shapes of the fields.

        Returns:
            A ModelData with zero state, zero input and zero contact state.
        """

        zero_state = jaxsim.physics.model.physics_model_state.PhysicsModelState.zero(
            physics_model=physics_model
        )
        zero_input = jaxsim.physics.model.physics_model_state.PhysicsModelInput.zero(
            physics_model=physics_model
        )
        zero_contacts = jaxsim.physics.algos.soft_contacts.SoftContactsState.zero(
            physics_model=physics_model
        )

        return ModelData(
            model_state=zero_state,
            model_input=zero_input,
            contact_state=zero_contacts,
        )
|
61
|
-
|
62
|
-
|
63
|
-
@jax_dataclasses.pytree_dataclass
class StepData(JaxsimDataclass):
    """
    Data produced by a single simulation step going from t0 to tf.
    """

    # Time interval of the step.
    t0: float
    tf: float
    dt: float

    # Model data at the beginning of the step, and the real input
    # (joint torques tau and external forces f_ext) evaluated at t0.
    t0_model_data: ModelData = dataclasses.field(repr=False)
    t0_model_input_real: jaxsim.physics.model.physics_model_state.PhysicsModelInput = dataclasses.field(repr=False)

    # Accelerations computed by ABA at t0.
    t0_base_acceleration: jtp.Vector = dataclasses.field(repr=False)
    t0_joint_acceleration: jtp.Vector = dataclasses.field(repr=False)

    # New ODE state at tf. It follows from t0_model_data by integrating the ABA
    # output together with the derivative of the tangential deformation, which
    # is a function of the ODE state at t0.
    tf_model_state: jaxsim.physics.model.physics_model_state.PhysicsModelState = dataclasses.field(repr=False)
    tf_contact_state: jaxsim.physics.algos.soft_contacts.SoftContactsState = dataclasses.field(repr=False)

    # Free-form auxiliary data attached to the step.
    aux: Dict[str, Any] = dataclasses.field(default_factory=dict)
|
94
|
-
|
95
|
-
|
96
|
-
@jax_dataclasses.pytree_dataclass
|
97
|
-
class Model(Vmappable):
|
98
|
-
"""
|
99
|
-
High-level class to operate on a simulated model.
|
100
|
-
"""
|
101
|
-
|
102
|
-
model_name: Static[str]
|
103
|
-
|
104
|
-
physics_model: physics.model.physics_model.PhysicsModel = dataclasses.field(
|
105
|
-
repr=False
|
106
|
-
)
|
107
|
-
|
108
|
-
velocity_representation: Static[VelRepr] = dataclasses.field(default=VelRepr.Mixed)
|
109
|
-
|
110
|
-
data: ModelData = dataclasses.field(default=None, repr=False)
|
111
|
-
|
112
|
-
# ========================
|
113
|
-
# Initialization and state
|
114
|
-
# ========================
|
115
|
-
|
116
|
-
@staticmethod
def build_from_model_description(
    model_description: Union[str, pathlib.Path, rod.Model],
    model_name: str | None = None,
    vel_repr: VelRepr = VelRepr.Mixed,
    gravity: jtp.Array = jaxsim.physics.default_gravity(),
    is_urdf: bool | None = None,
    considered_joints: List[str] | None = None,
) -> "Model":
    """
    Create a Model from an SDF/URDF description.

    Args:
        model_description: Path to an SDF/URDF file, a string with its content,
            or an already parsed/built rod model.
        model_name: Optional name overriding the one found in the description.
        vel_repr: The velocity representation of the created model.
        gravity: The 3D gravity vector.
        is_urdf: Whether the description is URDF (True) or SDF (False).
            It is inferred automatically when the description is a file path.
        considered_joints: Joints to keep in the model. If None, all joints are kept.

    Returns:
        The created Model object.
    """

    import jaxsim.parsers.rod

    # Turn the resource (file path or string) into the intermediate description.
    intermediate_description = jaxsim.parsers.rod.build_model_description(
        model_description=model_description, is_urdf=is_urdf
    )

    # When only a subset of joints is kept, the links connected by the removed
    # joints are lumped together and those joints get a zero position.
    if considered_joints is not None:
        intermediate_description = intermediate_description.reduce(
            considered_joints=considered_joints
        )

    return Model.build(
        physics_model=jaxsim.physics.model.physics_model.PhysicsModel.build_from(
            model_description=intermediate_description, gravity=gravity
        ),
        model_name=model_name,
        vel_repr=vel_repr,
    )
|
166
|
-
|
167
|
-
@staticmethod
def build_from_sdf(
    sdf: Union[str, pathlib.Path],
    model_name: str | None = None,
    vel_repr: VelRepr = VelRepr.Mixed,
    gravity: jtp.Array = jaxsim.physics.default_gravity(),
    is_urdf: bool | None = None,
    considered_joints: List[str] | None = None,
) -> "Model":
    """
    Deprecated alias of build_from_model_description.

    It logs a deprecation warning and forwards all arguments, with `sdf`
    passed as `model_description`.
    """

    deprecation_template = "Model.{} is deprecated, use Model.{} instead."
    warning_text = deprecation_template.format(
        "build_from_sdf", "build_from_model_description"
    )
    logging.warning(msg=warning_text)

    forwarded_kwargs = dict(
        model_name=model_name,
        vel_repr=vel_repr,
        gravity=gravity,
        is_urdf=is_urdf,
        considered_joints=considered_joints,
    )

    return Model.build_from_model_description(
        model_description=sdf, **forwarded_kwargs
    )
|
194
|
-
|
195
|
-
@staticmethod
def build(
    physics_model: jaxsim.physics.model.physics_model.PhysicsModel,
    model_name: str | None = None,
    vel_repr: VelRepr = VelRepr.Mixed,
) -> "Model":
    """
    Create a Model wrapping a physics model, with zeroed data.

    Args:
        physics_model: The physics model to wrap.
        model_name: Optional name overriding the one in the physics model description.
        vel_repr: The velocity representation of the created model.

    Returns:
        The created Model object.

    Raises:
        RuntimeError: If the created model does not pass the validity check.
    """

    if model_name is None:
        model_name = physics_model.description.name

    model = Model(
        physics_model=physics_model,
        model_name=model_name,
        velocity_representation=vel_repr,
    )

    # `data` defaults to None, so populating it changes the PyTree structure.
    # That is why validation is disabled for this mutation.
    with model.mutable_context(mutability=Mutability.MUTABLE_NO_VALIDATION):
        model.zero()

    if model.valid():
        return model

    raise RuntimeError("The model is not valid.")
|
235
|
-
|
236
|
-
@functools.partial(oop.jax_tf.method_rw, jit=False, vmap=False, validate=False)
def reduce(
    self, considered_joints: tuple[str, ...], keep_base_pose: bool = False
) -> None:
    """
    Reduce the model in place, lumping together the links connected by removed joints.

    Args:
        considered_joints: The joints to keep.
        keep_base_pose: If True, the base pose before the reduction is restored
            on the reduced model. Otherwise the reduced model keeps its zero data.

    Raises:
        RuntimeError: If the model is vectorized.
    """

    if self.vectorized:
        raise RuntimeError("Cannot reduce a vectorized model.")

    # The description reduction raises if considered_joints contains joints
    # that are not part of the model.
    description = self.physics_model.description.reduce(
        considered_joints=list(considered_joints)
    )

    reduced_model = Model.build(
        physics_model=jaxsim.physics.model.physics_model.PhysicsModel.build_from(
            model_description=description,
            gravity=self.physics_model.gravity[0:3],
        ),
        model_name=self.name(),
        vel_repr=self.velocity_representation,
    )

    # Save the current base pose before it is overwritten.
    base_position = self.base_position()
    base_quaternion = self.base_orientation(dcm=False)

    # Swap in the reduced model. The PyTree structure changes, which is why
    # validation is disabled in the decorator.
    self.physics_model = reduced_model.physics_model
    self.data = reduced_model.data

    if not keep_base_pose:
        return

    self.reset_base_position(position=base_position)
    self.reset_base_orientation(orientation=base_quaternion, dcm=False)
|
283
|
-
|
284
|
-
@functools.partial(oop.jax_tf.method_rw, jit=False)
def zero(self) -> None:
    """
    Reset all the model data (state, input and contact state) to zero.
    """

    zeroed_data = ModelData.zero(physics_model=self.physics_model)
    self.data = zeroed_data
|
289
|
-
|
290
|
-
@functools.partial(oop.jax_tf.method_rw, jit=False)
def zero_input(self) -> None:
    """
    Reset the model input (joint generalized forces and external forces) to zero.

    The model state and the contact state are left untouched.
    """

    # Build only the zero input. Building a full zero ModelData would also
    # create a zero state and a zero contact state that are then discarded.
    self.data.model_input = (
        jaxsim.physics.model.physics_model_state.PhysicsModelInput.zero(
            physics_model=self.physics_model
        )
    )
|
297
|
-
|
298
|
-
@functools.partial(oop.jax_tf.method_rw, jit=False)
def zero_state(self) -> None:
    """
    Reset the model state and the contact state to zero.

    The model input is left untouched.
    """

    # Build only the two components being reset, instead of a full zero
    # ModelData whose input would be discarded.
    self.data.model_state = (
        jaxsim.physics.model.physics_model_state.PhysicsModelState.zero(
            physics_model=self.physics_model
        )
    )
    self.data.contact_state = (
        jaxsim.physics.algos.soft_contacts.SoftContactsState.zero(
            physics_model=self.physics_model
        )
    )
|
305
|
-
|
306
|
-
@functools.partial(oop.jax_tf.method_rw, jit=False, vmap=False)
def set_velocity_representation(self, vel_repr: VelRepr) -> None:
    """
    Set the velocity representation used by the model.

    Args:
        vel_repr: The new velocity representation. Nothing is assigned if it
            is already the current one.
    """

    if self.velocity_representation is not vel_repr:
        self.velocity_representation = vel_repr
|
314
|
-
|
315
|
-
# ==========
|
316
|
-
# Properties
|
317
|
-
# ==========
|
318
|
-
|
319
|
-
@functools.partial(oop.jax_tf.method_ro, jit=False)
def valid(self) -> jtp.Bool:
    """
    Check whether all links and all joints of the model are valid.

    Returns:
        A boolean array: True if every link and every joint reports being valid.
    """

    links_ok = all(link.valid() for link in self.links())

    # Joints are only checked when all the links are valid (short-circuit).
    everything_ok = links_ok and all(joint.valid() for joint in self.joints())

    return jnp.array(everything_ok, dtype=bool)
|
327
|
-
|
328
|
-
@functools.partial(oop.jax_tf.method_ro, jit=False)
def floating_base(self) -> jtp.Bool:
    """
    Return whether the model has a floating base, as a boolean array.
    """

    is_floating = self.physics_model.is_floating_base
    return jnp.array(is_floating, dtype=bool)
|
333
|
-
|
334
|
-
@functools.partial(oop.jax_tf.method_ro, jit=False)
def dofs(self) -> jtp.Int:
    """
    Return the number of joint degrees of freedom (the size of the joint positions vector).
    """

    positions = self.joint_positions()
    return positions.size
|
339
|
-
|
340
|
-
@functools.partial(oop.jax_tf.method_ro, jit=False, vmap=False)
def name(self) -> str:
    """
    Return the name of the model.
    """

    model_name = self.model_name
    return model_name
|
345
|
-
|
346
|
-
@functools.partial(oop.jax_tf.method_ro, jit=False)
def nr_of_links(self) -> jtp.Int:
    """
    Return the number of links of the model, as an integer array.
    """

    link_count = len(self.links())
    return jnp.array(link_count, dtype=int)
|
351
|
-
|
352
|
-
@functools.partial(oop.jax_tf.method_ro, jit=False)
def nr_of_joints(self) -> jtp.Int:
    """
    Return the number of joints of the model, as an integer array.
    """

    joint_count = len(self.joints())
    return jnp.array(joint_count, dtype=int)
|
357
|
-
|
358
|
-
@functools.partial(oop.jax_tf.method_ro)
def total_mass(self) -> jtp.Float:
    """
    Return the total mass of the model, computed as the sum of the link masses.
    """

    link_masses = jnp.array([link.mass() for link in self.links()])
    return jnp.sum(link_masses, dtype=float)
|
363
|
-
|
364
|
-
@functools.partial(oop.jax_tf.method_ro, jit=False, vmap=False)
def get_link(self, link_name: str) -> high_level.link.Link:
    """
    Return the high-level link with the given name.

    Args:
        link_name: The name of the link.

    Returns:
        The corresponding Link object.

    Raises:
        ValueError: If the model has no link with that name.
    """

    if link_name in self.link_names():
        (link,) = self.links(link_names=(link_name,))
        return link

    msg = f"Link '{link_name}' is not part of model '{self.name()}'"
    raise ValueError(msg)
|
373
|
-
|
374
|
-
@functools.partial(oop.jax_tf.method_ro, jit=False, vmap=False)
def get_joint(self, joint_name: str) -> high_level.joint.Joint:
    """
    Return the high-level joint with the given name.

    Args:
        joint_name: The name of the joint.

    Returns:
        The corresponding Joint object.

    Raises:
        ValueError: If the model has no joint with that name.
    """

    if joint_name in self.joint_names():
        (joint,) = self.joints(joint_names=(joint_name,))
        return joint

    msg = f"Joint '{joint_name}' is not part of model '{self.name()}'"
    raise ValueError(msg)
|
383
|
-
|
384
|
-
@functools.partial(oop.jax_tf.method_ro, jit=False, vmap=False)
def link_names(self) -> tuple[str, ...]:
    """
    Return the link names, in the key order of the description's links dictionary.
    """

    names = self.physics_model.description.links_dict.keys()
    return tuple(names)
|
389
|
-
|
390
|
-
@functools.partial(oop.jax_tf.method_ro, jit=False, vmap=False)
def joint_names(self) -> tuple[str, ...]:
    """
    Return the joint names, in the key order of the description's joints dictionary.
    """

    names = self.physics_model.description.joints_dict.keys()
    return tuple(names)
|
395
|
-
|
396
|
-
@functools.partial(oop.jax_tf.method_ro, jit=False, vmap=False)
def links(
    self, link_names: tuple[str, ...] | None = None
) -> tuple[high_level.link.Link, ...]:
    """
    Return high-level Link objects bound to this model.

    Args:
        link_names: The names of the links to return, in the desired order.
            If None, all links are returned, sorted by their index.

    Returns:
        A tuple of Link objects that share the mutability of this model.

    Raises:
        KeyError: If a requested name is not a link of the model.
    """

    descriptions_by_index = sorted(
        self.physics_model.description.links_dict.values(),
        key=lambda description: description.index,
    )

    name_to_link = {}
    for description in descriptions_by_index:
        name_to_link[description.name] = high_level.link.Link(
            link_description=description,
            _parent_model=self,
            batch_size=self.batch_size,
        )

    # The links inherit the mutability of the parent model.
    for link in name_to_link.values():
        link._set_mutability(self._mutability())

    selected_names = name_to_link.keys() if link_names is None else link_names
    return tuple(name_to_link[name] for name in selected_names)
|
419
|
-
|
420
|
-
@functools.partial(oop.jax_tf.method_ro, jit=False, vmap=False)
def joints(
    self, joint_names: tuple[str, ...] | None = None
) -> tuple[high_level.joint.Joint, ...]:
    """
    Return high-level Joint objects bound to this model.

    Args:
        joint_names: The names of the joints to return, in the desired order.
            If None, all joints are returned, sorted by their index.

    Returns:
        A tuple of Joint objects that share the mutability of this model.

    Raises:
        KeyError: If a requested name is not a joint of the model.
    """

    descriptions_by_index = sorted(
        self.physics_model.description.joints_dict.values(),
        key=lambda description: description.index,
    )

    name_to_joint = {}
    for description in descriptions_by_index:
        name_to_joint[description.name] = high_level.joint.Joint(
            joint_description=description,
            _parent_model=self,
            batch_size=self.batch_size,
        )

    # The joints inherit the mutability of the parent model.
    for joint in name_to_joint.values():
        joint._set_mutability(self._mutability())

    selected_names = name_to_joint.keys() if joint_names is None else joint_names
    return tuple(name_to_joint[name] for name in selected_names)
|
443
|
-
|
444
|
-
@functools.partial(oop.jax_tf.method_ro, static_argnames=["link_names", "terrain"])
def in_contact(
    self,
    link_names: tuple[str, ...] | None = None,
    terrain: Terrain = FlatTerrain(),
) -> jtp.Vector:
    """
    Detect which links have at least one collidable point on or below the terrain.

    Args:
        link_names: The links to check. If None, all the model links are checked.
        terrain: The terrain whose height is compared to the collidable points.

    Returns:
        A boolean vector with one element per checked link, in the order of
        `link_names`.

    Raises:
        ValueError: If any of the given names is not a link of the model.
    """

    if link_names is None:
        link_names = self.link_names()

    if not set(link_names).issubset(self.link_names()):
        raise ValueError("One or more link names are not part of the model")

    from jaxsim.physics.algos.soft_contacts import collidable_points_pos_vel

    state = self.data.model_state
    W_p_Ci, _ = collidable_points_pos_vel(
        model=self.physics_model,
        q=state.joint_positions,
        qd=state.joint_velocities,
        xfb=state.xfb(),
    )

    # Rows 0, 1 and 2 of W_p_Ci hold the x, y and z world coordinates of the
    # collidable points. A point penetrates if z is at or below the terrain height.
    heights = jax.vmap(terrain.height)(W_p_Ci[0, :], W_p_Ci[1, :])
    penetrating = W_p_Ci[2, :] <= heights

    def link_touches(link_index):
        # gc.body maps each collidable point to the index of its link.
        return jnp.where(
            self.physics_model.gc.body == link_index,
            penetrating,
            jnp.zeros_like(penetrating, dtype=bool),
        ).any()

    link_indices = jnp.array(
        [link.index() for link in self.links(link_names=link_names)]
    )

    return jax.vmap(link_touches)(link_indices)
|
479
|
-
|
480
|
-
# =================
|
481
|
-
# Multi-DoF methods
|
482
|
-
# =================
|
483
|
-
|
484
|
-
@functools.partial(oop.jax_tf.method_ro, static_argnames=["joint_names"])
def joint_positions(self, joint_names: tuple[str, ...] | None = None) -> jtp.Vector:
    """
    Return the positions of the given joints (all joints if None).
    """

    all_positions = self.data.model_state.joint_positions
    selected = self._joint_indices(joint_names=joint_names)
    return all_positions[selected]
|
491
|
-
|
492
|
-
@functools.partial(oop.jax_tf.method_ro, static_argnames=["joint_names"])
def joint_random_positions(
    self,
    joint_names: tuple[str, ...] | None = None,
    key: jax.Array | None = None,
) -> jtp.Vector:
    """
    Sample joint positions uniformly within the joint position limits.

    Args:
        joint_names: The joints to sample. If None, all joints are sampled.
        key: The PRNG key. If None, a fixed key with seed 0 is used, so the
            result is deterministic.

    Returns:
        A vector with one random position per considered joint.
    """

    rng_key = jax.random.PRNGKey(seed=0) if key is None else key

    lower, upper = self.joint_limits(joint_names=joint_names)

    return jax.random.uniform(
        minval=lower,
        maxval=upper,
        key=rng_key,
        shape=lower.shape,
    )
|
513
|
-
|
514
|
-
@functools.partial(oop.jax_tf.method_ro, static_argnames=["joint_names"])
def joint_velocities(
    self, joint_names: tuple[str, ...] | None = None
) -> jtp.Vector:
    """
    Return the velocities of the given joints (all joints if None).
    """

    all_velocities = self.data.model_state.joint_velocities
    selected = self._joint_indices(joint_names=joint_names)
    return all_velocities[selected]
|
523
|
-
|
524
|
-
@functools.partial(oop.jax_tf.method_ro, static_argnames=["joint_names"])
def joint_generalized_forces_targets(
    self, joint_names: tuple[str, ...] | None = None
) -> jtp.Vector:
    """
    Return the joint generalized force inputs (tau) of the given joints (all if None).
    """

    all_tau = self.data.model_input.tau
    selected = self._joint_indices(joint_names=joint_names)
    return all_tau[selected]
|
531
|
-
|
532
|
-
@functools.partial(oop.jax_tf.method_ro, static_argnames=["joint_names"])
def joint_limits(
    self, joint_names: tuple[str, ...] | None = None
) -> Tuple[jtp.Vector, jtp.Vector]:
    """
    Return the position limits of the considered joints.

    Args:
        joint_names: The joints to consider. If None, all joints are considered.

    Returns:
        A tuple (s_low, s_high) of vectors with one element per considered joint.
        Each (low, high) pair is reordered so that s_low <= s_high. Both vectors
        are empty if no joint is considered.
    """

    # Consider all joints if not specified otherwise
    joint_names = joint_names if joint_names is not None else self.joint_names()

    # Stack the (low, high) pairs in a (n_joints, 2) matrix. The explicit reshape
    # keeps the matrix 2D even when there are no joints. Without it, the empty
    # array would become a (1, 0) matrix and the min/max reductions below would
    # raise on a zero-size axis.
    limits = jnp.array(
        [j.position_limit() for j in self.joints(joint_names)]
    ).reshape(-1, 2)

    # Get the limits, reordering them in case low > high
    s_low = jnp.min(limits, axis=1)
    s_high = jnp.max(limits, axis=1)

    return s_low, s_high
|
551
|
-
|
552
|
-
# =========
|
553
|
-
# Base link
|
554
|
-
# =========
|
555
|
-
|
556
|
-
@functools.partial(oop.jax_tf.method_ro, jit=False, vmap=False)
def base_frame(self) -> str:
    """
    Return the name of the root link of the model description, i.e. the base frame.
    """

    root_link = self.physics_model.description.root
    return root_link.name
|
561
|
-
|
562
|
-
@functools.partial(oop.jax_tf.method_ro)
def base_position(self) -> jtp.Vector:
    """
    Return the base position stored in the model state, with singleton dimensions squeezed.
    """

    stored_position = self.data.model_state.base_position
    return stored_position.squeeze()
|
567
|
-
|
568
|
-
@functools.partial(oop.jax_tf.method_ro, static_argnames=["dcm"])
def base_orientation(self, dcm: bool = False) -> jtp.Vector:
    """
    Return the base orientation.

    Args:
        dcm: If True, return the 3x3 rotation matrix. Otherwise return the
            unit quaternion in wxyz order.

    Returns:
        The normalized base quaternion (wxyz), or the corresponding rotation matrix.
    """

    quaternion_wxyz = self.data.model_state.base_quaternion

    # Normalize the quaternion before using it. The integration logic includes a
    # Baumgarte stabilization term that drives the quaternion norm towards 1,
    # but it does not keep the norm exactly 1 at every time instant.
    unit_wxyz = quaternion_wxyz.squeeze() / jnp.linalg.norm(quaternion_wxyz)

    if not dcm:
        return unit_wxyz

    # from_quaternion_xyzw expects xyzw ordering: move w to the end.
    unit_xyzw = unit_wxyz[np.array([1, 2, 3, 0])]
    return sixd.so3.SO3.from_quaternion_xyzw(unit_xyzw).as_matrix()
|
590
|
-
|
591
|
-
@functools.partial(oop.jax_tf.method_ro)
def base_transform(self) -> jtp.MatrixJax:
    """
    Return the 4x4 homogeneous transform of the base frame w.r.t. the world frame.
    """

    rotation = self.base_orientation(dcm=True)
    translation_column = jnp.vstack(self.base_position())

    top_rows = jnp.block([rotation, translation_column])
    bottom_row = jnp.array([0, 0, 0, 1])

    return jnp.vstack([top_rows, bottom_row])
|
604
|
-
|
605
|
-
@functools.partial(oop.jax_tf.method_ro)
def base_velocity(self) -> jtp.Vector:
    """
    Return the 6D base velocity (linear, angular) in the active velocity representation.

    The base velocity is stored in inertial representation and is converted
    before being returned.
    """

    state = self.data.model_state
    W_v_WB = jnp.hstack([state.base_linear_velocity, state.base_angular_velocity])

    return self.inertial_to_active_representation(array=W_v_WB)
|
617
|
-
|
618
|
-
@functools.partial(oop.jax_tf.method_ro)
def external_forces(self) -> jtp.Matrix:
    """
    Return the user-defined external forces acting on the model links.

    These forces are a user input and are not computed by the physics engine.
    During the simulation they are summed with other terms, such as the
    external forces due to the contact with the environment.

    Returns:
        A (n_links, 6) matrix with the 6D external force of each link,
        expressed in the active representation.
    """

    def to_active(W_f):
        # Forces are stored internally in inertial representation.
        return self.inertial_to_active_representation(W_f, is_force=True)

    return jax.vmap(to_active, in_axes=0)(self.data.model_input.f_ext)
|
641
|
-
|
642
|
-
# =======================
|
643
|
-
# Single link r/w methods
|
644
|
-
# =======================
|
645
|
-
|
646
|
-
@functools.partial(
    oop.jax_tf.method_rw, jit=True, static_argnames=["link_name", "additive"]
)
def apply_external_force_to_link(
    self,
    link_name: str,
    force: jtp.Array | None = None,
    torque: jtp.Array | None = None,
    additive: bool = True,
) -> None:
    """
    Apply a 6D external force to a link, expressed in the active representation.

    Args:
        link_name: The name of the target link.
        force: The 3D linear force component (zeros if None).
        torque: The 3D torque component (zeros if None).
        additive: If True, add the force to the one already stored for the link.
            Otherwise replace it.

    Raises:
        ValueError: If the link does not exist or the velocity representation
            is not supported.
    """

    link = self.get_link(link_name=link_name)
    link._set_mutability(mutability=self._mutability())

    linear = jnp.zeros(3) if force is None else force
    angular = jnp.zeros(3) if torque is None else torque
    f_active = jnp.hstack([linear, angular])

    # Convert to inertial representation, which is how forces are stored.
    # A force expressed in frame X is mapped to W through X_X_W^T, where
    # X_X_W is the adjoint of the inverse of W_H_X.
    representation = self.velocity_representation

    if representation is VelRepr.Inertial:
        W_f_ext = f_active

    elif representation is VelRepr.Body:
        # Body: the force is expressed in the link frame L.
        L_X_W = sixd.se3.SE3.from_matrix(link.transform()).inverse().adjoint()
        W_f_ext = L_X_W.transpose() @ f_active

    elif representation is VelRepr.Mixed:
        # Mixed: frame LW with the origin of L and the orientation of W.
        W_p_L = link.transform()[0:3, 3]
        W_H_LW = jnp.eye(4).at[0:3, 3].set(W_p_L)
        LW_X_W = sixd.se3.SE3.from_matrix(W_H_LW).inverse().adjoint()
        W_f_ext = LW_X_W.transpose() @ f_active

    else:
        raise ValueError(self.velocity_representation)

    row = link.index()
    stored_forces = self.data.model_input.f_ext

    new_force = stored_forces[row, :] + W_f_ext if additive else W_f_ext

    self.data.model_input.f_ext = stored_forces.at[row, :].set(new_force)
|
700
|
-
|
701
|
-
@functools.partial(
    oop.jax_tf.method_rw, jit=True, static_argnames=["link_name", "additive"]
)
def apply_external_force_to_link_com(
    self,
    link_name: str,
    force: jtp.Array | None = None,
    torque: jtp.Array | None = None,
    additive: bool = True,
) -> None:
    """
    Apply a 6D external force to the CoM of a link, expressed in the active representation.

    Args:
        link_name: The name of the target link.
        force: The 3D linear force component (zeros if None).
        torque: The 3D torque component (zeros if None).
        additive: If True, add the force to the one already stored for the link.
            Otherwise replace it.

    Raises:
        ValueError: If the link does not exist or the velocity representation
            is not supported.
    """

    link = self.get_link(link_name=link_name)
    link._set_mutability(mutability=self._mutability())

    linear = jnp.zeros(3) if force is None else force
    angular = jnp.zeros(3) if torque is None else torque
    f_active = jnp.hstack([linear, angular])

    # Convert to inertial representation, which is how forces are stored.
    # A force expressed in frame X is mapped to W through X_X_W^T, where
    # X_X_W is the adjoint of the inverse of W_H_X.
    representation = self.velocity_representation

    if representation is VelRepr.Inertial:
        W_f_ext = f_active

    elif representation is VelRepr.Body:
        # Frame GL: origin in the link CoM, orientation of the link frame L.
        W_H_L = link.transform()
        L_p_CoM = link.com_position(in_link_frame=True)
        W_H_GL = W_H_L @ jnp.eye(4).at[0:3, 3].set(L_p_CoM)
        GL_X_W = sixd.se3.SE3.from_matrix(W_H_GL).inverse().adjoint()
        W_f_ext = GL_X_W.transpose() @ f_active

    elif representation is VelRepr.Mixed:
        # Frame GW: origin in the link CoM, orientation of the world frame W.
        W_p_CoM = link.com_position(in_link_frame=False)
        W_H_GW = jnp.eye(4).at[0:3, 3].set(W_p_CoM)
        GW_X_W = sixd.se3.SE3.from_matrix(W_H_GW).inverse().adjoint()
        W_f_ext = GW_X_W.transpose() @ f_active

    else:
        raise ValueError(self.velocity_representation)

    row = link.index()
    stored_forces = self.data.model_input.f_ext

    new_force = stored_forces[row, :] + W_f_ext if additive else W_f_ext

    self.data.model_input.f_ext = stored_forces.at[row, :].set(new_force)
|
759
|
-
|
760
|
-
# ================================================
|
761
|
-
# Generalized methods and free-floating quantities
|
762
|
-
# ================================================
|
763
|
-
|
764
|
-
@functools.partial(oop.jax_tf.method_ro)
def generalized_position(self) -> Tuple[jtp.Matrix, jtp.Vector]:
    """
    Return the generalized position of the model.

    Returns:
        A tuple with the base transform (from ``base_transform()``) and the joint
        positions (from ``joint_positions()``).
    """

    base_pose = self.base_transform()
    joint_positions = self.joint_positions()

    return base_pose, joint_positions
|
769
|
-
|
770
|
-
@functools.partial(oop.jax_tf.method_ro)
def generalized_velocity(self) -> jtp.Vector:
    """
    Return the generalized velocity of the model.

    Returns:
        The base 6D velocity (as returned by ``base_velocity()``) stacked on top of
        the joint velocities.
    """

    components = (self.base_velocity(), self.joint_velocities())

    return jnp.hstack(components)
|
775
|
-
|
776
|
-
@functools.partial(oop.jax_tf.method_ro, static_argnames=["output_vel_repr"])
def generalized_free_floating_jacobian(
    self, output_vel_repr: VelRepr | None = None
) -> jtp.Matrix:
    """
    Compute the free-floating jacobians of all links, stacked vertically.

    Args:
        output_vel_repr:
            The velocity representation of the output jacobians. When None, the
            active velocity representation of the model is used.

    Returns:
        The vertical stack of the jacobians of all links, ordered as
        ``link_names()``, where the body frame is the base link B.

    Raises:
        ValueError: if the requested representation is not supported.
    """

    out_repr = (
        self.velocity_representation if output_vel_repr is None else output_vel_repr
    )

    # Link.jacobian uses the link frame L as body frame, while here the body
    # frame must be the base link B. Therefore, the link jacobians are always
    # requested in Inertial representation and then converted to the target one.
    if out_repr is VelRepr.Inertial:

        def convert(W_J_WL):
            return W_J_WL

    elif out_repr is VelRepr.Body:

        def convert(W_J_WL):
            W_H_B = self.base_transform()
            B_X_W = sixd.se3.SE3.from_matrix(W_H_B).inverse().adjoint()
            return B_X_W @ W_J_WL

    elif out_repr is VelRepr.Mixed:

        def convert(W_J_WL):
            # Frame BW: origin of the base link B, orientation of the world W
            W_H_B = self.base_transform()
            W_H_BW = jnp.array(W_H_B).at[0:3, 0:3].set(jnp.eye(3))
            BW_X_W = sixd.se3.SE3.from_matrix(W_H_BW).inverse().adjoint()
            return BW_X_W @ W_J_WL

    else:
        raise ValueError(out_repr)

    stacked = []

    for name in self.link_names():
        W_J_WL = self.get_link(link_name=name).jacobian(
            output_vel_repr=VelRepr.Inertial
        )
        stacked.append(convert(W_J_WL))

    return jnp.vstack(stacked)
|
824
|
-
|
825
|
-
@functools.partial(oop.jax_tf.method_ro)
def free_floating_mass_matrix(self) -> jtp.Matrix:
    """
    Compute the free-floating mass matrix in the active velocity representation.

    The mass matrix is computed with CRBA in body-fixed representation. For the
    other representations it is transformed as ``invTᵀ M_body invT``, where
    ``invT`` maps the active generalized velocity to the body-fixed one.

    Returns:
        The (6 + dofs) x (6 + dofs) mass matrix.

    Raises:
        ValueError: if the active velocity representation is not supported.
    """

    M_body = jaxsim.physics.algos.crba.crba(
        model=self.physics_model,
        q=self.data.model_state.joint_positions,
    )

    vel_repr = self.velocity_representation

    if vel_repr is VelRepr.Body:
        return M_body

    if vel_repr is VelRepr.Inertial:
        # B_X_W maps inertial-fixed base velocities to body-fixed ones
        base_X = sixd.se3.SE3.from_matrix(self.base_transform()).inverse().adjoint()

    elif vel_repr is VelRepr.Mixed:
        # Zeroing the translation of W_H_B yields the pure rotation BW_H_B. Its
        # inverse adjoint B_X_BW maps mixed base velocities to body-fixed ones.
        BW_H_B = self.base_transform().at[0:3, 3].set(jnp.zeros(3))
        base_X = sixd.se3.SE3.from_matrix(BW_H_B).inverse().adjoint()

    else:
        raise ValueError(vel_repr)

    n = self.dofs()
    zero_6n = jnp.zeros(shape=(6, n))

    # Block-diagonal transform: base part converted, joint part unchanged
    invT = jnp.vstack(
        [
            jnp.block([base_X, zero_6n]),
            jnp.block([zero_6n.T, jnp.eye(n)]),
        ]
    )

    return invT.T @ M_body @ invT
|
866
|
-
|
867
|
-
@functools.partial(oop.jax_tf.method_ro)
def free_floating_bias_forces(self) -> jtp.Vector:
    """
    Compute the free-floating bias forces in the active representation.

    They are computed with inverse dynamics on an editable copy of the model
    whose inputs are zeroed (``zero_input()``), requesting zero base and joint
    accelerations.

    Returns:
        The base 6D force stacked on top of the joint generalized forces.
    """

    with self.editable(validate=True) as model:
        model.zero_input()

        f_base, tau_joints = model.inverse_dynamics(
            joint_accelerations=None, base_acceleration=jnp.zeros(6)
        )

        return jnp.hstack([f_base, tau_joints])
|
879
|
-
|
880
|
-
@functools.partial(oop.jax_tf.method_ro)
def free_floating_gravity_forces(self) -> jtp.Vector:
    """
    Compute the free-floating gravity forces in the active representation.

    They are computed with inverse dynamics on an editable copy of the model
    whose inputs, joint velocities and base velocities are all zeroed, and
    requesting zero base and joint accelerations.

    Returns:
        The base 6D force stacked on top of the joint generalized forces.
    """

    with self.editable(validate=True) as model:
        model.zero_input()

        # Remove all velocity-dependent terms
        model.data.model_state.joint_velocities = jnp.zeros_like(
            model.data.model_state.joint_velocities
        )
        model.data.model_state.base_linear_velocity = jnp.zeros_like(
            model.data.model_state.base_linear_velocity
        )
        model.data.model_state.base_angular_velocity = jnp.zeros_like(
            model.data.model_state.base_angular_velocity
        )

        f_base, tau_joints = model.inverse_dynamics(
            joint_accelerations=None, base_acceleration=jnp.zeros(6)
        )

        return jnp.hstack([f_base, tau_joints])
|
901
|
-
|
902
|
-
@functools.partial(oop.jax_tf.method_ro)
def momentum(self) -> jtp.Vector:
    """
    Compute the 6D momentum of the model.

    Returns:
        The 6D momentum expressed in the active representation.
    """

    with self.editable(validate=True) as m:
        m.set_velocity_representation(vel_repr=VelRepr.Body)

        # The first 6 rows of the body-fixed mass matrix define the jacobian of
        # the floating-base momentum, expressed in B.
        base_rows = m.free_floating_mass_matrix()[0:6, :]
        B_h = base_rows @ m.generalized_velocity()

    # Momenta transform as forces: W_Xf_B = (B_Xv_W)ᵀ
    B_Xv_W: jtp.Array = (
        sixd.se3.SE3.from_matrix(self.base_transform()).inverse().adjoint()
    )
    W_h = B_Xv_W.T @ B_h

    return self.inertial_to_active_representation(array=W_h, is_force=True)
|
919
|
-
|
920
|
-
# ===========
|
921
|
-
# CoM methods
|
922
|
-
# ===========
|
923
|
-
|
924
|
-
@functools.partial(oop.jax_tf.method_ro)
def com_position(self) -> jtp.Vector:
    """
    Compute the position of the center of mass of the whole model.

    The mass-weighted average of the link CoMs is computed in the base frame
    (homogeneous coordinates) and then expressed in the world frame.

    Returns:
        The 3D CoM position expressed in the world frame.
    """

    total_mass = self.total_mass()

    W_H_L = self.forward_kinematics()
    W_H_B = self.base_transform()
    B_H_W = sixd.se3.SE3.from_matrix(W_H_B).inverse().as_matrix()

    weighted_coms = []

    for link in self.links():
        L_ph_CoM = jnp.hstack([link.com_position(in_link_frame=True), 1])
        weighted_coms.append(link.mass() * B_H_W @ W_H_L[link.index()] @ L_ph_CoM)

    B_ph_CoM = (1 / total_mass) * jnp.sum(jnp.array(weighted_coms), axis=0)

    return (W_H_B @ B_ph_CoM)[0:3]
|
947
|
-
|
948
|
-
# ==========
|
949
|
-
# Algorithms
|
950
|
-
# ==========
|
951
|
-
|
952
|
-
@functools.partial(oop.jax_tf.method_ro)
def forward_kinematics(self) -> jtp.Array:
    """
    Compute the world-to-link transforms of all links.

    Returns:
        The link transforms as returned by ``forward_kinematics_model``, indexed
        by link index.
    """

    state = self.data.model_state

    return jaxsim.physics.algos.forward_kinematics.forward_kinematics_model(
        model=self.physics_model,
        q=state.joint_positions,
        xfb=state.xfb(),
    )
|
963
|
-
|
964
|
-
@functools.partial(oop.jax_tf.method_ro)
def inverse_dynamics(
    self,
    joint_accelerations: jtp.Vector | None = None,
    base_acceleration: jtp.Vector | None = None,
) -> Tuple[jtp.Vector, jtp.Vector]:
    """
    Compute inverse dynamics with the RNEA algorithm.

    Args:
        joint_accelerations: the joint accelerations to consider (zero if None).
        base_acceleration:
            the base acceleration in the active representation to consider
            (zero if None). Any array with 6 elements is accepted, e.g. with
            shape (6,), (6, 1) or (1, 6). It is flattened to shape (6,).

    Returns:
        A tuple containing the 6D force in active representation applied to the base
        to obtain the considered base acceleration, and the joint torques to apply
        to obtain the considered joint accelerations.

    Raises:
        ValueError: if ``base_acceleration`` does not have 6 elements, or if the
            active velocity representation is not supported.
    """

    # Build joint accelerations if not provided
    joint_accelerations = (
        joint_accelerations
        if joint_accelerations is not None
        else jnp.zeros_like(self.joint_positions())
    )

    # Build base acceleration if not provided
    base_acceleration = (
        jnp.asarray(base_acceleration)
        if base_acceleration is not None
        else jnp.zeros(6)
    )

    if base_acceleration.size != 6:
        raise ValueError(base_acceleration.size)

    # Flatten to (6,). Otherwise, in Mixed representation, a (6, 1) input would
    # be broadcast against the (6,) cross-product term below into a (6, 6) array.
    C_v̇_WB = base_acceleration.reshape(6)

    vel_repr = self.velocity_representation

    # Frame C in which the active base acceleration is expressed
    if vel_repr is VelRepr.Inertial:
        W_H_C = jnp.eye(4)

    elif vel_repr is VelRepr.Body:
        W_H_C = self.base_transform()

    elif vel_repr is VelRepr.Mixed:
        # Frame BW: origin of the base link B, orientation of the world frame W
        W_H_C = self.base_transform().at[0:3, 0:3].set(jnp.eye(3))

    else:
        raise ValueError(vel_repr)

    W_X_C = sixd.se3.SE3.from_matrix(W_H_C).adjoint()

    # We need to convert the derivative of the base velocity to the Inertial
    # representation. In Mixed representation, this conversion is not a plain
    # transformation with just X, but it also involves a cross product in ℝ⁶.
    if vel_repr is VelRepr.Mixed:
        from jaxsim.math.cross import Cross

        C_X_W = sixd.se3.SE3.from_matrix(W_H_C).inverse().adjoint()
        C_v_WB = self.base_velocity()
        W_vl_WC = C_v_WB[0:3]
        C_v_WC = C_X_W @ jnp.hstack([W_vl_WC, jnp.zeros(3)])
        W_v̇_WB = W_X_C @ (C_v̇_WB + Cross.vx(C_v_WC) @ C_v_WB)

    else:
        W_v̇_WB = W_X_C @ C_v̇_WB

    # Compute RNEA
    W_f_B, tau = jaxsim.physics.algos.rnea.rnea(
        model=self.physics_model,
        xfb=self.data.model_state.xfb(),
        q=self.data.model_state.joint_positions,
        qd=self.data.model_state.joint_velocities,
        qdd=joint_accelerations,
        a0fb=W_v̇_WB,
        f_ext=self.data.model_input.f_ext,
    )

    # Adjust shape
    tau = jnp.atleast_1d(tau.squeeze())

    # Express W_f_B in the active representation
    f_B = self.inertial_to_active_representation(array=W_f_B, is_force=True)

    return f_B, tau
|
1054
|
-
|
1055
|
-
@functools.partial(oop.jax_tf.method_ro, static_argnames=["prefer_aba"])
def forward_dynamics(
    self, tau: jtp.Vector | None = None, prefer_aba: bool = True
) -> Tuple[jtp.Vector, jtp.Vector]:
    """
    Compute forward dynamics.

    Args:
        tau: the joint generalized forces (zero if None).
        prefer_aba:
            if True, use ``forward_dynamics_aba``; otherwise, use
            ``forward_dynamics_crb``. The flag is static under jit.

    Returns:
        A tuple with the base acceleration in the active representation and the
        joint accelerations.
    """

    if prefer_aba:
        return self.forward_dynamics_aba(tau=tau)

    return self.forward_dynamics_crb(tau=tau)
|
1066
|
-
|
1067
|
-
@functools.partial(oop.jax_tf.method_ro)
def forward_dynamics_aba(
    self, tau: jtp.Vector | None = None
) -> Tuple[jtp.Vector, jtp.Vector]:
    """
    Compute forward dynamics with the ABA algorithm.

    Args:
        tau: the joint generalized forces (zero if None).

    Returns:
        A tuple with the base acceleration in the active representation and the
        joint accelerations.

    Raises:
        ValueError: if the active velocity representation is not supported.
    """

    joint_forces = jnp.zeros_like(self.joint_positions()) if tau is None else tau

    state = self.data.model_state

    # ABA returns the base acceleration in Inertial representation
    W_v̇_WB, s̈ = jaxsim.physics.algos.aba.aba(
        model=self.physics_model,
        xfb=state.xfb(),
        q=state.joint_positions,
        qd=state.joint_velocities,
        tau=joint_forces,
        f_ext=self.data.model_input.f_ext,
    )

    vel_repr = self.velocity_representation

    # Frame C of the active representation and linear velocity of its origin
    if vel_repr is VelRepr.Inertial:
        W_H_C = jnp.eye(4)
        W_vl_WC = jnp.zeros(3)

    elif vel_repr is VelRepr.Body:
        W_H_C = self.base_transform()
        W_vl_WC = self.base_velocity()[0:3]

    elif vel_repr is VelRepr.Mixed:
        W_H_C = self.base_transform().at[0:3, 0:3].set(jnp.eye(3))
        W_vl_WC = self.base_velocity()[0:3]

    else:
        raise ValueError(vel_repr)

    # Base velocity in the representation stored in the model state
    W_v_WB = jnp.hstack([state.base_linear_velocity, state.base_angular_velocity])

    C_X_W = sixd.se3.SE3.from_matrix(W_H_C).inverse().adjoint()
    W_v̇_WB = W_v̇_WB.squeeze()

    # In Mixed representation, the conversion of the acceleration is not a plain
    # transformation with X, but it also involves a cross product in ℝ⁶.
    if vel_repr != VelRepr.Mixed:
        C_v̇_WB = C_X_W @ W_v̇_WB
    else:
        from jaxsim.math.cross import Cross

        W_v_WC = jnp.hstack([W_vl_WC, jnp.zeros(3)])
        C_v̇_WB = C_X_W @ (W_v̇_WB - Cross.vx(W_v_WC) @ W_v_WB)

    return C_v̇_WB, jnp.atleast_1d(s̈.squeeze())
|
1132
|
-
|
1133
|
-
@functools.partial(oop.jax_tf.method_ro)
def forward_dynamics_crb(
    self, tau: jtp.Vector | None = None
) -> Tuple[jtp.Vector, jtp.Vector]:
    """
    Compute forward dynamics by solving the floating-base equations of motion.

    The equations of motion, augmented with the motor reflected inertia and the
    motor viscous friction, are solved for the generalized acceleration:

        M ν̇ = S τ - h + Jᵀ f_ext

    When the model is not floating-base, the base acceleration is zero and only
    the joint block of the system is solved.

    Args:
        tau: the joint generalized forces (zero if None).

    Returns:
        A tuple with the base acceleration in the active representation and the
        joint accelerations.
    """

    # Build joint torques if not provided, arranged as a column vector
    τ = tau if tau is not None else jnp.zeros(shape=(self.dofs(),))
    τ = jnp.atleast_1d(τ.squeeze())
    τ = jnp.vstack(τ) if τ.size > 0 else jnp.empty(shape=(0, 1))

    # Extract motor parameters from the physics model
    GR = self.motor_gear_ratios()
    IM = self.motor_inertias()
    KV = jnp.diag(self.motor_viscous_frictions())

    # Compute auxiliary quantities
    Γ = jnp.diag(GR)
    K̅ᵥ = Γ.T @ KV @ Γ

    # Compute terms of the floating-base EoM
    M = self.free_floating_mass_matrix()
    h = jnp.vstack(self.free_floating_bias_forces())
    J = self.generalized_free_floating_jacobian()
    f_ext = jnp.vstack(self.external_forces().flatten())
    S = jnp.block([jnp.zeros(shape=(self.dofs(), 6)), jnp.eye(self.dofs())]).T

    # Slice selecting the joint (motor) rows and columns of the EoM
    sl_m = np.s_[M.shape[0] - self.dofs() :]

    # Add the motor related terms to the EoM.
    # Note: IM is 1D, so Γ.T @ IM @ Γ is a vector, which jnp.diag makes diagonal.
    M = M.at[sl_m, sl_m].set(M[sl_m, sl_m] + jnp.diag(Γ.T @ IM @ Γ))
    h = h.at[sl_m].set(h[sl_m] + K̅ᵥ @ self.joint_velocities()[:, None])

    # Right-hand side of the EoM, computed once for both branches
    rhs = (S @ τ) - h + J.T @ f_ext

    # Compute the generalized acceleration by solving the EoM. A linear solve is
    # more accurate and cheaper than forming the explicit inverse of M.
    # Note: rhs[6:] equals (S @ τ)[6:] - h[6:] + J[:, 6:].T @ f_ext.
    ν̇ = jax.lax.select(
        pred=self.floating_base(),
        on_true=jnp.linalg.solve(M, rhs),
        on_false=jnp.vstack(
            [
                jnp.zeros(shape=(6, 1)),
                jnp.linalg.solve(M[6:, 6:], rhs[6:]),
            ]
        ),
    ).squeeze()

    # Extract the base acceleration in the active representation.
    # Note that this is an apparent acceleration (relevant in Mixed representation),
    # therefore it cannot be always expressed in different frames with just a
    # 6D transformation X.
    v̇_WB = ν̇[0:6]

    # Extract the joint accelerations
    s̈ = jnp.atleast_1d(ν̇[6:])

    return v̇_WB, s̈
|
1191
|
-
|
1192
|
-
# ======
|
1193
|
-
# Energy
|
1194
|
-
# ======
|
1195
|
-
|
1196
|
-
@functools.partial(oop.jax_tf.method_ro)
def mechanical_energy(self) -> jtp.Float:
    """
    Compute the mechanical energy of the model.

    Returns:
        The sum of ``kinetic_energy()`` and ``potential_energy()``.
    """

    kinetic = self.kinetic_energy()
    potential = self.potential_energy()

    return kinetic + potential
|
1204
|
-
|
1205
|
-
@functools.partial(oop.jax_tf.method_ro)
def kinetic_energy(self) -> jtp.Float:
    """
    Compute the kinetic energy of the model.

    The generalized velocity and the mass matrix are both taken in body-fixed
    representation, so that the quadratic form is consistent.

    Returns:
        The kinetic energy ``0.5 * νᵀ M ν``.
    """

    with self.editable(validate=True) as m:
        m.set_velocity_representation(vel_repr=VelRepr.Body)

        generalized_vel = m.generalized_velocity()
        mass_matrix = m.free_floating_mass_matrix()

    return 0.5 * generalized_vel.T @ mass_matrix @ generalized_vel
|
1216
|
-
|
1217
|
-
@functools.partial(oop.jax_tf.method_ro)
def potential_energy(self) -> jtp.Float:
    """
    Compute the gravitational potential energy of the model.

    Returns:
        ``-m gᵀ p_CoM``, with g taken from entries 3:6 of the physics model
        gravity and p_CoM from ``com_position()``.
    """

    total_mass = self.total_mass()
    W_ph_CoM = jnp.hstack([self.com_position(), 1])
    W_g = self.physics_model.gravity[3:6].squeeze()

    # The homogeneous coordinate of the CoM is multiplied by zero
    W_gh = jnp.hstack([W_g, 0])

    return -(total_mass * W_gh @ W_ph_CoM)
|
1226
|
-
|
1227
|
-
# ===========
|
1228
|
-
# Set targets
|
1229
|
-
# ===========
|
1230
|
-
|
1231
|
-
@functools.partial(oop.jax_tf.method_rw, static_argnames=["joint_names"])
def set_joint_generalized_force_targets(
    self, forces: jtp.Vector, joint_names: tuple[str, ...] | None = None
) -> None:
    """
    Set the joint generalized forces stored in the model input.

    Args:
        forces: the generalized forces, one per entry of ``joint_names``.
        joint_names: the joints to update; all the joints if None.

    Raises:
        ValueError: if the number of forces does not match the number of joints.
    """

    selected = self.joint_names() if joint_names is None else joint_names

    if forces.size != len(selected):
        raise ValueError("Wrong arguments size", forces.size, len(selected))

    current_tau = self.data.model_input.tau
    indices = self._joint_indices(joint_names=selected)

    self.data.model_input.tau = current_tau.at[indices].set(forces)
|
1246
|
-
|
1247
|
-
# ==========
|
1248
|
-
# Reset data
|
1249
|
-
# ==========
|
1250
|
-
|
1251
|
-
@functools.partial(oop.jax_tf.method_rw, static_argnames=["joint_names"])
def reset_joint_positions(
    self, positions: jtp.Vector, joint_names: tuple[str, ...] | None = None
) -> None:
    """
    Reset the positions of the given joints.

    Args:
        positions: the new joint positions, one per entry of ``joint_names``.
        joint_names: the joints to update; all the joints if None.

    Raises:
        ValueError: if the number of positions does not match the number of joints.
    """

    selected = self.joint_names() if joint_names is None else joint_names

    if positions.size != len(selected):
        raise ValueError("Wrong arguments size", positions.size, len(selected))

    # Nothing to update
    if positions.size == 0:
        return

    # TODO: joint position limits

    current = self.data.model_state.joint_positions
    indices = self._joint_indices(joint_names=selected)
    updated = current.at[indices].set(positions)

    self.data.model_state.joint_positions = jnp.atleast_1d(
        jnp.array(updated, dtype=float)
    )
|
1276
|
-
|
1277
|
-
@functools.partial(oop.jax_tf.method_rw, static_argnames=["joint_names"])
def reset_joint_velocities(
    self, velocities: jtp.Vector, joint_names: tuple[str, ...] | None = None
) -> None:
    """
    Reset the velocities of the given joints.

    Args:
        velocities: the new joint velocities, one per entry of ``joint_names``.
        joint_names: the joints to update; all the joints if None.

    Raises:
        ValueError: if the number of velocities does not match the number of joints.
    """

    selected = self.joint_names() if joint_names is None else joint_names

    if velocities.size != len(selected):
        raise ValueError("Wrong arguments size", velocities.size, len(selected))

    # Nothing to update
    if velocities.size == 0:
        return

    # TODO: joint velocity limits

    current = self.data.model_state.joint_velocities
    indices = self._joint_indices(joint_names=selected)
    updated = current.at[indices].set(velocities)

    self.data.model_state.joint_velocities = jnp.atleast_1d(
        jnp.array(updated, dtype=float)
    )
|
1302
|
-
|
1303
|
-
@functools.partial(oop.jax_tf.method_rw)
def reset_base_position(self, position: jtp.Vector) -> None:
    """
    Reset the position of the base link.

    Args:
        position: the new base position, stored as a float array.
    """

    new_position = jnp.array(position, dtype=float)
    self.data.model_state.base_position = new_position
|
1308
|
-
|
1309
|
-
@functools.partial(oop.jax_tf.method_rw, static_argnames=["dcm"])
def reset_base_orientation(self, orientation: jtp.Array, dcm: bool = False) -> None:
    """
    Reset the orientation of the base link.

    Args:
        orientation:
            the new orientation: a quaternion when ``dcm`` is False (presumably
            in wxyz order, matching the conversion below), or a 3x3 rotation
            matrix when ``dcm`` is True.
        dcm: whether ``orientation`` is a rotation matrix.

    Note:
        The quaternion is normalized before being stored.
    """

    if dcm:
        # Reorder the xyzw quaternion returned by sixd to the wxyz layout
        xyzw = sixd.so3.SO3.from_matrix(orientation).as_quaternion_xyzw()
        quaternion = xyzw[np.array([3, 0, 1, 2])]
    else:
        quaternion = orientation

    quaternion = quaternion / jnp.linalg.norm(quaternion)

    self.data.model_state.base_quaternion = jnp.array(quaternion, dtype=float)
|
1322
|
-
|
1323
|
-
@functools.partial(oop.jax_tf.method_rw)
def reset_base_transform(self, transform: jtp.Matrix) -> None:
    """
    Reset the pose of the base link from a homogeneous transform.

    Args:
        transform: the 4x4 homogeneous transform of the base link.

    Raises:
        ValueError: if ``transform`` is not 4x4.
    """

    if transform.shape != (4, 4):
        raise ValueError(transform.shape)

    translation = transform[0:3, 3]
    rotation = transform[0:3, 0:3]

    self.reset_base_position(position=translation)
    self.reset_base_orientation(orientation=rotation, dcm=True)
|
1332
|
-
|
1333
|
-
@functools.partial(oop.jax_tf.method_rw)
def reset_base_velocity(self, base_velocity: jtp.VectorJax) -> None:
    """
    Reset the 6D velocity of the base link.

    Args:
        base_velocity:
            the new base velocity in the active representation. It must have
            shape (6,) after squeezing.

    Raises:
        RuntimeError: if the model is not floating-base.
        ValueError: if the shape is invalid or the representation is unsupported.
    """

    if not self.physics_model.is_floating_base:
        msg = "Changing the base velocity of a fixed-based model is not allowed"
        raise RuntimeError(msg)

    v = base_velocity.squeeze()

    if v.shape != (6,):
        raise ValueError(v.shape)

    vel_repr = self.velocity_representation

    # The state stores the base velocity in Inertial representation
    if vel_repr is VelRepr.Inertial:
        W_v_WB = v

    elif vel_repr is VelRepr.Body:
        W_X_B = sixd.se3.SE3.from_rotation_and_translation(
            rotation=sixd.so3.SO3.from_matrix(self.base_orientation(dcm=True)),
            translation=self.base_position(),
        ).adjoint()
        W_v_WB = W_X_B @ v

    elif vel_repr is VelRepr.Mixed:
        W_X_BW = sixd.se3.SE3.from_rotation_and_translation(
            rotation=sixd.so3.SO3.identity(),
            translation=self.base_position(),
        ).adjoint()
        W_v_WB = W_X_BW @ v

    else:
        raise ValueError(vel_repr)

    self.data.model_state.base_linear_velocity = jnp.array(W_v_WB[0:3], dtype=float)
    self.data.model_state.base_angular_velocity = jnp.array(W_v_WB[3:6], dtype=float)
|
1378
|
-
|
1379
|
-
# ===========
|
1380
|
-
# Integration
|
1381
|
-
# ===========
|
1382
|
-
|
1383
|
-
@functools.partial(
    oop.jax_tf.method_rw,
    static_argnames=["sub_steps", "integrator_type", "terrain"],
    vmap_in_axes=(0, 0, 0, None, None, None, 0, None),
)
def integrate(
    self,
    t0: jtp.Float,
    tf: jtp.Float,
    sub_steps: int = 1,
    integrator_type: Optional[
        "jaxsim.simulation.ode_integration.IntegratorType"
    ] = None,
    terrain: soft_contacts.Terrain = soft_contacts.FlatTerrain(),
    contact_parameters: soft_contacts.SoftContactsParams = soft_contacts.SoftContactsParams(),
    clear_inputs: bool = False,
) -> StepData:
    """
    Integrate the model dynamics over the interval [t0, tf].

    The model data is replaced with the state at ``tf``.

    Args:
        t0: the initial time.
        tf: the final time.
        sub_steps: the number of integration sub-steps within [t0, tf].
        integrator_type: the fixed-step integrator (EulerForward if None).
        terrain: the terrain passed to the soft-contacts model.
        contact_parameters: the soft-contacts parameters.
        clear_inputs:
            if True, the model input is reset to zero after the integration;
            otherwise, the input at t0 is kept.

    Returns:
        A StepData object with the relevant quantities at t0 and tf.
    """

    from jaxsim.simulation import ode_data, ode_integration
    from jaxsim.simulation.ode_integration import IntegratorType

    integrator_type = (
        IntegratorType.EulerForward if integrator_type is None else integrator_type
    )

    ode_types = ode_integration.ode.ode_data

    x0 = ode_types.ODEState(
        physics_model=self.data.model_state,
        soft_contacts=self.data.contact_state,
    )

    ode_input = ode_types.ODEInput(physics_model=self.data.model_input)

    assert isinstance(integrator_type, IntegratorType)

    # Integrate the model dynamics
    ode_states, aux = ode_integration.ode_integration_fixed_step(
        x0=x0,
        t=jnp.array([t0, tf], dtype=float),
        ode_input=ode_input,
        physics_model=self.physics_model,
        soft_contacts_params=contact_parameters,
        num_sub_steps=sub_steps,
        terrain=terrain,
        integrator_type=integrator_type,
        return_aux=True,
    )

    def at_t0(tree):
        # First element of every leaf, i.e. the quantity at t0
        return jax.tree_util.tree_map(lambda leaf: leaf[0], tree)

    def at_tf(tree):
        # Last element of every leaf, i.e. the quantity at tf
        return jax.tree_util.tree_map(lambda leaf: leaf[-1], tree)

    # Quantities at t0
    t0_model_data = self.data
    t0_model_input = at_t0(aux["ode_input"])
    t0_model_input_real = at_t0(aux["ode_input_real"])
    t0_model_acceleration = at_t0(aux["model_acceleration"])

    # Quantities at tf
    ode_states: ode_data.ODEState
    tf_model_state = at_tf(ode_states.physics_model)
    tf_contact_state = at_tf(ode_states.soft_contacts)

    # Clear user inputs (joint torques and external forces) if asked
    model_input = jax.lax.cond(
        pred=clear_inputs,
        false_fun=lambda: t0_model_input.physics_model,
        true_fun=lambda: jaxsim.physics.model.physics_model_state.PhysicsModelInput.zero(
            physics_model=self.physics_model
        ),
    )

    # Update model state
    self.data = ModelData(
        model_state=tf_model_state,
        contact_state=tf_contact_state,
        model_input=model_input,
    )

    return StepData(
        t0=t0,
        tf=tf,
        dt=(tf - t0),
        t0_model_data=t0_model_data,
        t0_model_input_real=t0_model_input_real.physics_model,
        t0_base_acceleration=t0_model_acceleration[0:6],
        t0_joint_acceleration=t0_model_acceleration[6:],
        tf_model_state=tf_model_state,
        tf_contact_state=tf_contact_state,
        aux={"t0": at_t0(aux), "tf": at_tf(aux)},
    )
|
1493
|
-
|
1494
|
-
# ==============
|
1495
|
-
# Motor dynamics
|
1496
|
-
# ==============
|
1497
|
-
|
1498
|
-
@functools.partial(oop.jax_tf.method_rw, jit=False)
def set_motor_inertias(
    self, inertias: jtp.Vector, joint_names: tuple[str, ...] | None = None
) -> None:
    """
    Set the motor inertias of the given joints.

    Args:
        inertias: the motor inertias, one per entry of ``joint_names``.
        joint_names: the joints to update; all the joints if None.

    Raises:
        ValueError: if the sizes do not match or a joint is not part of the model.

    Note:
        Not jit-compiled, like ``set_motor_gear_ratios``. The values are stored
        in a Python dict of the physics model, and storing traced values there
        would leak tracers.
    """

    all_joint_names = self.joint_names()
    joint_names = all_joint_names if joint_names is None else joint_names

    if inertias.size != len(joint_names):
        raise ValueError("Wrong arguments size", inertias.size, len(joint_names))

    if set(joint_names) - set(all_joint_names) != set():
        raise ValueError("One or more joint names are not part of the model")

    # motor_inertias() reads the dict values positionally, so the dict order is
    # assumed to follow joint_names() (NOTE(review): confirm). Update the entry
    # of each requested joint, instead of zipping with the first dict keys.
    keys = list(self.physics_model._joint_motor_inertia)

    self.physics_model._joint_motor_inertia.update(
        {
            keys[all_joint_names.index(name)]: value
            for name, value in zip(joint_names, inertias)
        }
    )

    logging.info("Setting attribute `motor_inertias`")
|
1512
|
-
|
1513
|
-
@functools.partial(oop.jax_tf.method_rw, jit=False)
def set_motor_gear_ratios(
    self, gear_ratios: jtp.Vector, joint_names: tuple[str, ...] | None = None
) -> None:
    """
    Set the motor gear ratios of the given joints.

    Args:
        gear_ratios: the gear ratios, one per entry of ``joint_names``.
        joint_names: the joints to update; all the joints if None.

    Raises:
        ValueError: if the sizes do not match, a joint is not part of the model,
            or a non-zero gear ratio is set on a joint with zero motor inertia.
    """

    all_joint_names = self.joint_names()
    joint_names = all_joint_names if joint_names is None else joint_names

    if gear_ratios.size != len(joint_names):
        raise ValueError("Wrong arguments size", gear_ratios.size, len(joint_names))

    if set(joint_names) - set(all_joint_names) != set():
        raise ValueError("One or more joint names are not part of the model")

    # Position of each requested joint in the model's joint ordering, which is
    # assumed to match the dict order read positionally by the motor getters
    # (NOTE(review): confirm).
    positions = [all_joint_names.index(name) for name in joint_names]
    motor_inertias = self.motor_inertias()

    # Check on gear ratios if motor_inertias are not zero. The inertia must be
    # looked up at the joint's position, not at the position within the input.
    for idx, (position, gr) in enumerate(zip(positions, gear_ratios)):
        if gr != 0 and motor_inertias[position] == 0:
            raise ValueError(
                f"Zero motor inertia with non-zero gear ratio found in position {idx}"
            )

    keys = list(self.physics_model._joint_motor_gear_ratio)

    self.physics_model._joint_motor_gear_ratio.update(
        {keys[position]: gr for position, gr in zip(positions, gear_ratios)}
    )

    logging.info("Setting attribute `motor_gear_ratios`")
|
1534
|
-
|
1535
|
-
@functools.partial(oop.jax_tf.method_rw, jit=False)
def set_motor_viscous_frictions(
    self,
    viscous_frictions: jtp.Vector,
    joint_names: tuple[str, ...] | None = None,
) -> None:
    """
    Set the motor viscous frictions of the given joints.

    Args:
        viscous_frictions: the viscous frictions, one per entry of ``joint_names``.
        joint_names: the joints to update; all the joints if None.

    Raises:
        ValueError: if the sizes do not match or a joint is not part of the model.

    Note:
        Not jit-compiled, like ``set_motor_gear_ratios``. The values are stored
        in a Python dict of the physics model, and storing traced values there
        would leak tracers.
    """

    all_joint_names = self.joint_names()
    joint_names = all_joint_names if joint_names is None else joint_names

    if viscous_frictions.size != len(joint_names):
        raise ValueError(
            "Wrong arguments size", viscous_frictions.size, len(joint_names)
        )

    if set(joint_names) - set(all_joint_names) != set():
        raise ValueError("One or more joint names are not part of the model")

    # motor_viscous_frictions() reads the dict values positionally, so the dict
    # order is assumed to follow joint_names() (NOTE(review): confirm). Update
    # the entry of each requested joint, instead of zipping with the first keys.
    keys = list(self.physics_model._joint_motor_viscous_friction)

    self.physics_model._joint_motor_viscous_friction.update(
        {
            keys[all_joint_names.index(name)]: value
            for name, value in zip(joint_names, viscous_frictions)
        }
    )

    logging.info("Setting attribute `motor_viscous_frictions`")
|
1558
|
-
|
1559
|
-
@functools.partial(oop.jax_tf.method_ro, jit=False)
def motor_inertias(self) -> jtp.Vector:
    """
    Return the motor inertias of the joints.

    Returns:
        The values of the physics-model motor inertia dict, in insertion order.
    """

    values = list(self.physics_model._joint_motor_inertia.values())

    return jnp.array(values, dtype=float)
|
1564
|
-
|
1565
|
-
@functools.partial(oop.jax_tf.method_ro, jit=False)
def motor_gear_ratios(self) -> jtp.Vector:
    """
    Return the motor gear ratios of the joints.

    Returns:
        The values of the physics-model gear ratio dict, in insertion order.
    """

    values = list(self.physics_model._joint_motor_gear_ratio.values())

    return jnp.array(values, dtype=float)
|
1570
|
-
|
1571
|
-
@functools.partial(oop.jax_tf.method_ro, jit=False)
def motor_viscous_frictions(self) -> jtp.Vector:
    """
    Return the motor viscous frictions of the joints.

    Returns:
        The values of the physics-model viscous friction dict, in insertion order.
    """

    values = list(self.physics_model._joint_motor_viscous_friction.values())

    return jnp.array(values, dtype=float)
|
1576
|
-
|
1577
|
-
# ===============
|
1578
|
-
# Private methods
|
1579
|
-
# ===============
|
1580
|
-
|
1581
|
-
@functools.partial(oop.jax_tf.method_ro, static_argnames=["is_force"])
def inertial_to_active_representation(
    self, array: jtp.Array, is_force: bool = False
) -> jtp.Array:
    """
    Convert a 6D quantity from Inertial to the active representation.

    Args:
        array: the 6D quantity in Inertial representation (6 elements).
        is_force: whether the quantity is a 6D force (True) or velocity (False).

    Returns:
        The squeezed 6D quantity expressed in the active representation.

    Raises:
        ValueError: if ``array`` does not have 6 elements, or if the active
            representation is not supported.
    """

    W_array = array.squeeze()

    if W_array.size != 6:
        raise ValueError(W_array.size)

    vel_repr = self.velocity_representation

    if vel_repr is VelRepr.Inertial:
        return W_array

    # Frame C of the active representation
    if vel_repr is VelRepr.Body:
        W_H_C = self.base_transform()
    elif vel_repr is VelRepr.Mixed:
        W_H_C = jnp.eye(4).at[0:3, 3].set(self.base_position())
    else:
        raise ValueError(vel_repr)

    W_pose = sixd.se3.SE3.from_matrix(W_H_C)

    # Velocities: C_Xv_W = (W_Xv_C)⁻¹. Forces: C_Xf_W = (W_Xv_C)ᵀ.
    C_X_W = W_pose.adjoint().T if is_force else W_pose.inverse().adjoint()

    return C_X_W @ W_array
|
1623
|
-
|
1624
|
-
@functools.partial(oop.jax_tf.method_ro, static_argnames=["is_force"])
def active_to_inertial_representation(
    self, array: jtp.Array, is_force: bool = False
) -> jtp.Array:
    """
    Convert a 6D quantity from the active to the Inertial representation.

    Args:
        array: the 6D quantity in the active representation (6 elements).
        is_force: whether the quantity is a 6D force (True) or velocity (False).

    Returns:
        The squeezed 6D quantity expressed in Inertial representation.

    Raises:
        ValueError: if ``array`` does not have 6 elements, or if the active
            representation is not supported.
    """

    C_array = array.squeeze()

    if C_array.size != 6:
        raise ValueError(C_array.size)

    vel_repr = self.velocity_representation

    if vel_repr is VelRepr.Inertial:
        return C_array

    # Frame C of the active representation
    if vel_repr is VelRepr.Body:
        W_H_C = self.base_transform()
    elif vel_repr is VelRepr.Mixed:
        W_H_C = jnp.eye(4).at[0:3, 3].set(self.base_position())
    else:
        raise ValueError(vel_repr)

    W_pose = sixd.se3.SE3.from_matrix(W_H_C)

    # Velocities: W_Xv_C. Forces: W_Xf_C = (C_Xv_W)ᵀ.
    W_X_C = W_pose.inverse().adjoint().T if is_force else W_pose.adjoint()

    return W_X_C @ C_array
|
1669
|
-
|
1670
|
-
def _joint_indices(self, joint_names: tuple[str, ...] | None = None) -> jtp.Vector:
    """
    Return the zero-based indices of the given joints.

    Args:
        joint_names: The names of the joints to look up. If None, all the
            names returned by ``self.joint_names()`` are used.

    Returns:
        An integer NumPy array with one index per joint returned by
        ``self.joints(joint_names=joint_names)``, in that order.

    Raises:
        ValueError: If one or more of the requested names are not joints of
            the model. The message lists the unknown names.
    """

    model_joint_names = self.joint_names()

    if joint_names is None:
        joint_names = model_joint_names

    known_names = set(model_joint_names)
    missing = [name for name in joint_names if name not in known_names]

    if missing:
        raise ValueError(
            f"One or more joint names are not part of the model: {missing}"
        )

    # Note: joints share the same index as their child link, therefore the first
    # joint has index=1. We need to subtract one to get the right entry of
    # data stored in the PhysicsModelState arrays.
    joint_indices = [
        j.joint_description.index - 1 for j in self.joints(joint_names=joint_names)
    ]

    return np.array(joint_indices, dtype=int)
|