jaxsim 0.2.dev101__py3-none-any.whl → 0.2.dev166__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jaxsim/_version.py +2 -2
- jaxsim/api/__init__.py +1 -0
- jaxsim/api/contact.py +194 -0
- jaxsim/api/data.py +951 -0
- jaxsim/api/joint.py +148 -0
- jaxsim/api/link.py +262 -0
- jaxsim/api/model.py +1099 -0
- jaxsim/api/ode.py +280 -0
- jaxsim/integrators/__init__.py +2 -0
- jaxsim/integrators/common.py +508 -0
- jaxsim/integrators/fixed_step.py +158 -0
- jaxsim/mujoco/__init__.py +1 -1
- jaxsim/mujoco/loaders.py +30 -18
- jaxsim/mujoco/visualizer.py +3 -1
- jaxsim/physics/algos/soft_contacts.py +97 -28
- jaxsim/physics/model/physics_model.py +30 -0
- jaxsim/physics/model/physics_model_state.py +110 -11
- jaxsim/simulation/ode_data.py +43 -0
- {jaxsim-0.2.dev101.dist-info → jaxsim-0.2.dev166.dist-info}/METADATA +2 -1
- {jaxsim-0.2.dev101.dist-info → jaxsim-0.2.dev166.dist-info}/RECORD +23 -13
- {jaxsim-0.2.dev101.dist-info → jaxsim-0.2.dev166.dist-info}/LICENSE +0 -0
- {jaxsim-0.2.dev101.dist-info → jaxsim-0.2.dev166.dist-info}/WHEEL +0 -0
- {jaxsim-0.2.dev101.dist-info → jaxsim-0.2.dev166.dist-info}/top_level.txt +0 -0
jaxsim/api/data.py
ADDED
@@ -0,0 +1,951 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
3
|
+
import contextlib
|
4
|
+
import dataclasses
|
5
|
+
import functools
|
6
|
+
from typing import ContextManager, Sequence
|
7
|
+
|
8
|
+
import jax
|
9
|
+
import jax.numpy as jnp
|
10
|
+
import jax_dataclasses
|
11
|
+
import jaxlie
|
12
|
+
import numpy as np
|
13
|
+
from jax_dataclasses import Static
|
14
|
+
|
15
|
+
import jaxsim.api
|
16
|
+
import jaxsim.physics.algos.aba
|
17
|
+
import jaxsim.physics.algos.crba
|
18
|
+
import jaxsim.physics.algos.forward_kinematics
|
19
|
+
import jaxsim.physics.algos.rnea
|
20
|
+
import jaxsim.physics.model.physics_model
|
21
|
+
import jaxsim.physics.model.physics_model_state
|
22
|
+
import jaxsim.typing as jtp
|
23
|
+
from jaxsim.high_level.common import VelRepr
|
24
|
+
from jaxsim.physics.algos import soft_contacts
|
25
|
+
from jaxsim.simulation.ode_data import ODEState
|
26
|
+
from jaxsim.utils import JaxsimDataclass, Mutability
|
27
|
+
|
28
|
+
try:
|
29
|
+
from typing import Self
|
30
|
+
except ImportError:
|
31
|
+
from typing_extensions import Self
|
32
|
+
|
33
|
+
|
34
|
+
@jax_dataclasses.pytree_dataclass
|
35
|
+
class JaxSimModelData(JaxsimDataclass):
|
36
|
+
"""
|
37
|
+
Class containing the state of a `JaxSimModel` object.
|
38
|
+
"""
|
39
|
+
|
40
|
+
state: ODEState
|
41
|
+
|
42
|
+
gravity: jtp.Array
|
43
|
+
|
44
|
+
soft_contacts_params: soft_contacts.SoftContactsParams = dataclasses.field(
|
45
|
+
repr=False
|
46
|
+
)
|
47
|
+
time_ns: jtp.Int = dataclasses.field(
|
48
|
+
default_factory=lambda: jnp.array(0, dtype=jnp.uint64)
|
49
|
+
)
|
50
|
+
|
51
|
+
velocity_representation: Static[VelRepr] = VelRepr.Inertial
|
52
|
+
|
53
|
+
def valid(self, model: jaxsim.api.model.JaxSimModel | None = None) -> bool:
|
54
|
+
"""
|
55
|
+
Check if the current state is valid for the given model.
|
56
|
+
|
57
|
+
Args:
|
58
|
+
model: The model to check against.
|
59
|
+
|
60
|
+
Returns:
|
61
|
+
`True` if the current state is valid for the given model, `False` otherwise.
|
62
|
+
"""
|
63
|
+
|
64
|
+
valid = True
|
65
|
+
|
66
|
+
if model is not None:
|
67
|
+
valid = valid and self.state.valid(physics_model=model.physics_model)
|
68
|
+
|
69
|
+
return valid
|
70
|
+
|
71
|
+
@contextlib.contextmanager
|
72
|
+
def switch_velocity_representation(
|
73
|
+
self, velocity_representation: VelRepr
|
74
|
+
) -> ContextManager[Self]:
|
75
|
+
"""
|
76
|
+
Context manager to temporarily switch the velocity representation.
|
77
|
+
|
78
|
+
Args:
|
79
|
+
velocity_representation: The new velocity representation.
|
80
|
+
|
81
|
+
Yields:
|
82
|
+
The same `JaxSimModelData` object with the new velocity representation.
|
83
|
+
"""
|
84
|
+
|
85
|
+
original_representation = self.velocity_representation
|
86
|
+
|
87
|
+
try:
|
88
|
+
|
89
|
+
# First, we replace the velocity representation
|
90
|
+
with self.mutable_context(
|
91
|
+
mutability=Mutability.MUTABLE_NO_VALIDATION,
|
92
|
+
restore_after_exception=True,
|
93
|
+
):
|
94
|
+
self.velocity_representation = velocity_representation
|
95
|
+
|
96
|
+
# Then, we yield the data with changed representation.
|
97
|
+
# We run this in a mutable context with restoration so that any exception
|
98
|
+
# occurring, we restore the original object in case it was modified.
|
99
|
+
with self.mutable_context(
|
100
|
+
mutability=self._mutability(), restore_after_exception=True
|
101
|
+
):
|
102
|
+
yield self
|
103
|
+
|
104
|
+
finally:
|
105
|
+
with self.mutable_context(
|
106
|
+
mutability=Mutability.MUTABLE_NO_VALIDATION,
|
107
|
+
restore_after_exception=True,
|
108
|
+
):
|
109
|
+
self.velocity_representation = original_representation
|
110
|
+
|
111
|
+
@staticmethod
|
112
|
+
def zero(
|
113
|
+
model: jaxsim.api.model.JaxSimModel,
|
114
|
+
velocity_representation: VelRepr = VelRepr.Inertial,
|
115
|
+
) -> JaxSimModelData:
|
116
|
+
"""
|
117
|
+
Create a `JaxSimModelData` object with zero state.
|
118
|
+
|
119
|
+
Args:
|
120
|
+
model: The model for which to create the zero state.
|
121
|
+
velocity_representation: The velocity representation to use.
|
122
|
+
|
123
|
+
Returns:
|
124
|
+
A `JaxSimModelData` object with zero state.
|
125
|
+
"""
|
126
|
+
|
127
|
+
return JaxSimModelData.build(
|
128
|
+
model=model, velocity_representation=velocity_representation
|
129
|
+
)
|
130
|
+
|
131
|
+
@staticmethod
|
132
|
+
def build(
|
133
|
+
model: jaxsim.api.model.JaxSimModel,
|
134
|
+
base_position: jtp.Vector | None = None,
|
135
|
+
base_quaternion: jtp.Vector | None = None,
|
136
|
+
joint_positions: jtp.Vector | None = None,
|
137
|
+
base_linear_velocity: jtp.Vector | None = None,
|
138
|
+
base_angular_velocity: jtp.Vector | None = None,
|
139
|
+
joint_velocities: jtp.Vector | None = None,
|
140
|
+
gravity: jtp.Vector | None = None,
|
141
|
+
soft_contacts_state: soft_contacts.SoftContactsState | None = None,
|
142
|
+
soft_contacts_params: soft_contacts.SoftContactsParams | None = None,
|
143
|
+
velocity_representation: VelRepr = VelRepr.Inertial,
|
144
|
+
time: jtp.FloatLike | None = None,
|
145
|
+
) -> JaxSimModelData:
|
146
|
+
"""
|
147
|
+
Create a `JaxSimModelData` object with the given state.
|
148
|
+
|
149
|
+
Args:
|
150
|
+
model: The model for which to create the state.
|
151
|
+
base_position: The base position.
|
152
|
+
base_quaternion: The base orientation as a quaternion.
|
153
|
+
joint_positions: The joint positions.
|
154
|
+
base_linear_velocity:
|
155
|
+
The base linear velocity in the selected representation.
|
156
|
+
base_angular_velocity:
|
157
|
+
The base angular velocity in the selected representation.
|
158
|
+
joint_velocities: The joint velocities.
|
159
|
+
gravity: The gravity 3D vector.
|
160
|
+
soft_contacts_state: The state of the soft contacts.
|
161
|
+
soft_contacts_params: The parameters of the soft contacts.
|
162
|
+
velocity_representation: The velocity representation to use.
|
163
|
+
time: The time at which the state is created.
|
164
|
+
|
165
|
+
Returns:
|
166
|
+
A `JaxSimModelData` object with the given state.
|
167
|
+
"""
|
168
|
+
|
169
|
+
base_position = jnp.array(
|
170
|
+
base_position if base_position is not None else jnp.zeros(3)
|
171
|
+
).squeeze()
|
172
|
+
|
173
|
+
base_quaternion = jnp.array(
|
174
|
+
base_quaternion
|
175
|
+
if base_quaternion is not None
|
176
|
+
else jnp.array([1.0, 0, 0, 0])
|
177
|
+
).squeeze()
|
178
|
+
|
179
|
+
base_linear_velocity = jnp.array(
|
180
|
+
base_linear_velocity if base_linear_velocity is not None else jnp.zeros(3)
|
181
|
+
).squeeze()
|
182
|
+
|
183
|
+
base_angular_velocity = jnp.array(
|
184
|
+
base_angular_velocity if base_angular_velocity is not None else jnp.zeros(3)
|
185
|
+
).squeeze()
|
186
|
+
|
187
|
+
gravity = jnp.array(
|
188
|
+
gravity if gravity is not None else model.physics_model.gravity[0:3]
|
189
|
+
).squeeze()
|
190
|
+
|
191
|
+
joint_positions = jnp.atleast_1d(
|
192
|
+
joint_positions.squeeze()
|
193
|
+
if joint_positions is not None
|
194
|
+
else jnp.zeros(model.dofs())
|
195
|
+
)
|
196
|
+
|
197
|
+
joint_velocities = jnp.atleast_1d(
|
198
|
+
joint_velocities.squeeze()
|
199
|
+
if joint_velocities is not None
|
200
|
+
else jnp.zeros(model.dofs())
|
201
|
+
)
|
202
|
+
|
203
|
+
time_ns = (
|
204
|
+
jnp.array(time * 1e9, dtype=jnp.uint64)
|
205
|
+
if time is not None
|
206
|
+
else jnp.array(0, dtype=jnp.uint64)
|
207
|
+
)
|
208
|
+
|
209
|
+
soft_contacts_params = (
|
210
|
+
soft_contacts_params
|
211
|
+
if soft_contacts_params is not None
|
212
|
+
else jaxsim.api.contact.estimate_good_soft_contacts_parameters(model=model)
|
213
|
+
)
|
214
|
+
|
215
|
+
W_H_B = jaxlie.SE3.from_rotation_and_translation(
|
216
|
+
translation=base_position,
|
217
|
+
rotation=jaxlie.SO3.from_quaternion_xyzw(
|
218
|
+
base_quaternion[jnp.array([1, 2, 3, 0])]
|
219
|
+
),
|
220
|
+
).as_matrix()
|
221
|
+
|
222
|
+
v_WB = JaxSimModelData.other_representation_to_inertial(
|
223
|
+
array=jnp.hstack([base_linear_velocity, base_angular_velocity]),
|
224
|
+
other_representation=velocity_representation,
|
225
|
+
transform=W_H_B,
|
226
|
+
is_force=False,
|
227
|
+
)
|
228
|
+
|
229
|
+
ode_state = ODEState.build(
|
230
|
+
physics_model=model.physics_model,
|
231
|
+
physics_model_state=jaxsim.physics.model.physics_model_state.PhysicsModelState(
|
232
|
+
base_position=base_position.astype(float),
|
233
|
+
base_quaternion=base_quaternion.astype(float),
|
234
|
+
joint_positions=joint_positions.astype(float),
|
235
|
+
base_linear_velocity=v_WB[0:3].astype(float),
|
236
|
+
base_angular_velocity=v_WB[3:6].astype(float),
|
237
|
+
joint_velocities=joint_velocities.astype(float),
|
238
|
+
),
|
239
|
+
soft_contacts_state=soft_contacts_state,
|
240
|
+
)
|
241
|
+
|
242
|
+
if not ode_state.valid(physics_model=model.physics_model):
|
243
|
+
raise ValueError(ode_state)
|
244
|
+
|
245
|
+
return JaxSimModelData(
|
246
|
+
time_ns=time_ns,
|
247
|
+
state=ode_state,
|
248
|
+
gravity=gravity.astype(float),
|
249
|
+
soft_contacts_params=soft_contacts_params,
|
250
|
+
velocity_representation=velocity_representation,
|
251
|
+
)
|
252
|
+
|
253
|
+
# ==================
|
254
|
+
# Extract quantities
|
255
|
+
# ==================
|
256
|
+
|
257
|
+
def time(self) -> jtp.Float:
|
258
|
+
"""
|
259
|
+
Get the simulated time.
|
260
|
+
|
261
|
+
Returns:
|
262
|
+
The simulated time in seconds.
|
263
|
+
"""
|
264
|
+
|
265
|
+
return self.time_ns.astype(float) / 1e9
|
266
|
+
|
267
|
+
@functools.partial(jax.jit, static_argnames=["joint_names"])
|
268
|
+
def joint_positions(
|
269
|
+
self,
|
270
|
+
model: jaxsim.api.model.JaxSimModel | None = None,
|
271
|
+
joint_names: tuple[str, ...] | None = None,
|
272
|
+
) -> jtp.Vector:
|
273
|
+
"""
|
274
|
+
Get the joint positions.
|
275
|
+
|
276
|
+
Args:
|
277
|
+
model: The model to consider.
|
278
|
+
joint_names:
|
279
|
+
The names of the joints for which to get the positions. If `None`,
|
280
|
+
the positions of all joints are returned.
|
281
|
+
|
282
|
+
Returns:
|
283
|
+
If no model and no joint names are provided, the joint positions as a
|
284
|
+
`(DoFs,)` vector corresponding to the serialization of the original
|
285
|
+
model used to build the data object.
|
286
|
+
If a model is provided and no joint names are provided, the joint positions
|
287
|
+
as a `(DoFs,)` vector corresponding to the serialization of the
|
288
|
+
provided model.
|
289
|
+
If a model and joint names are provided, the joint positions as a
|
290
|
+
`(len(joint_names),)` vector corresponding to the serialization of
|
291
|
+
the passed joint names vector.
|
292
|
+
"""
|
293
|
+
|
294
|
+
if model is None:
|
295
|
+
return self.state.physics_model.joint_positions
|
296
|
+
|
297
|
+
if not self.valid(model=model):
|
298
|
+
msg = "The data object is not compatible with the provided model"
|
299
|
+
raise ValueError(msg)
|
300
|
+
|
301
|
+
joint_names = joint_names if joint_names is not None else model.joint_names()
|
302
|
+
|
303
|
+
return self.state.physics_model.joint_positions[
|
304
|
+
jaxsim.api.joint.names_to_idxs(joint_names=joint_names, model=model)
|
305
|
+
]
|
306
|
+
|
307
|
+
@functools.partial(jax.jit, static_argnames=["joint_names"])
|
308
|
+
def joint_velocities(
|
309
|
+
self,
|
310
|
+
model: jaxsim.api.model.JaxSimModel | None = None,
|
311
|
+
joint_names: tuple[str, ...] | None = None,
|
312
|
+
) -> jtp.Vector:
|
313
|
+
"""
|
314
|
+
Get the joint velocities.
|
315
|
+
|
316
|
+
Args:
|
317
|
+
model: The model to consider.
|
318
|
+
joint_names:
|
319
|
+
The names of the joints for which to get the velocities. If `None`,
|
320
|
+
the velocities of all joints are returned.
|
321
|
+
|
322
|
+
Returns:
|
323
|
+
If no model and no joint names are provided, the joint velocities as a
|
324
|
+
`(DoFs,)` vector corresponding to the serialization of the original
|
325
|
+
model used to build the data object.
|
326
|
+
If a model is provided and no joint names are provided, the joint velocities
|
327
|
+
as a `(DoFs,)` vector corresponding to the serialization of the
|
328
|
+
provided model.
|
329
|
+
If a model and joint names are provided, the joint velocities as a
|
330
|
+
`(len(joint_names),)` vector corresponding to the serialization of
|
331
|
+
the passed joint names vector.
|
332
|
+
"""
|
333
|
+
|
334
|
+
if model is None:
|
335
|
+
return self.state.physics_model.joint_velocities
|
336
|
+
|
337
|
+
if not self.valid(model=model):
|
338
|
+
msg = "The data object is not compatible with the provided model"
|
339
|
+
raise ValueError(msg)
|
340
|
+
|
341
|
+
joint_names = joint_names if joint_names is not None else model.joint_names()
|
342
|
+
|
343
|
+
return self.state.physics_model.joint_velocities[
|
344
|
+
jaxsim.api.joint.names_to_idxs(joint_names=joint_names, model=model)
|
345
|
+
]
|
346
|
+
|
347
|
+
@jax.jit
|
348
|
+
def base_position(self) -> jtp.Vector:
|
349
|
+
"""
|
350
|
+
Get the base position.
|
351
|
+
|
352
|
+
Returns:
|
353
|
+
The base position.
|
354
|
+
"""
|
355
|
+
|
356
|
+
return self.state.physics_model.base_position.squeeze()
|
357
|
+
|
358
|
+
@functools.partial(jax.jit, static_argnames=["dcm"])
|
359
|
+
def base_orientation(self, dcm: jtp.BoolLike = False) -> jtp.Vector | jtp.Matrix:
|
360
|
+
"""
|
361
|
+
Get the base orientation.
|
362
|
+
|
363
|
+
Args:
|
364
|
+
dcm: Whether to return the orientation as a SO(3) matrix or quaternion.
|
365
|
+
|
366
|
+
Returns:
|
367
|
+
The base orientation.
|
368
|
+
"""
|
369
|
+
|
370
|
+
# Always normalize the quaternion to avoid numerical issues.
|
371
|
+
# If the active scheme does not integrate the quaternion on its manifold,
|
372
|
+
# we introduce a Baumgarte stabilization to let the quaternion converge to
|
373
|
+
# a unit quaternion. In this case, it is not guaranteed that the quaternion
|
374
|
+
# stored in the state is a unit quaternion.
|
375
|
+
base_unit_quaternion = (
|
376
|
+
self.state.physics_model.base_quaternion.squeeze()
|
377
|
+
/ jnp.linalg.norm(self.state.physics_model.base_quaternion)
|
378
|
+
)
|
379
|
+
|
380
|
+
# Slice to convert quaternion wxyz -> xyzw
|
381
|
+
to_xyzw = np.array([1, 2, 3, 0])
|
382
|
+
|
383
|
+
return (
|
384
|
+
base_unit_quaternion
|
385
|
+
if not dcm
|
386
|
+
else jaxlie.SO3.from_quaternion_xyzw(
|
387
|
+
base_unit_quaternion[to_xyzw]
|
388
|
+
).as_matrix()
|
389
|
+
)
|
390
|
+
|
391
|
+
@jax.jit
|
392
|
+
def base_transform(self) -> jtp.MatrixJax:
|
393
|
+
"""
|
394
|
+
Get the base transform.
|
395
|
+
|
396
|
+
Returns:
|
397
|
+
The base transform as an SE(3) matrix.
|
398
|
+
"""
|
399
|
+
|
400
|
+
W_R_B = self.base_orientation(dcm=True)
|
401
|
+
W_p_B = jnp.vstack(self.base_position())
|
402
|
+
|
403
|
+
return jnp.vstack(
|
404
|
+
[
|
405
|
+
jnp.block([W_R_B, W_p_B]),
|
406
|
+
jnp.array([0, 0, 0, 1]),
|
407
|
+
]
|
408
|
+
)
|
409
|
+
|
410
|
+
@jax.jit
|
411
|
+
def base_velocity(self) -> jtp.Vector:
|
412
|
+
"""
|
413
|
+
Get the base 6D velocity.
|
414
|
+
|
415
|
+
Returns:
|
416
|
+
The base 6D velocity in the active representation.
|
417
|
+
"""
|
418
|
+
|
419
|
+
W_v_WB = jnp.hstack(
|
420
|
+
[
|
421
|
+
self.state.physics_model.base_linear_velocity,
|
422
|
+
self.state.physics_model.base_angular_velocity,
|
423
|
+
]
|
424
|
+
)
|
425
|
+
|
426
|
+
W_H_B = self.base_transform()
|
427
|
+
|
428
|
+
return (
|
429
|
+
JaxSimModelData.inertial_to_other_representation(
|
430
|
+
array=W_v_WB,
|
431
|
+
other_representation=self.velocity_representation,
|
432
|
+
transform=W_H_B,
|
433
|
+
is_force=False,
|
434
|
+
)
|
435
|
+
.squeeze()
|
436
|
+
.astype(float)
|
437
|
+
)
|
438
|
+
|
439
|
+
@jax.jit
|
440
|
+
def generalized_position(self) -> tuple[jtp.Matrix, jtp.Vector]:
|
441
|
+
"""
|
442
|
+
Get the generalized position
|
443
|
+
:math:`\mathbf{q} = ({}^W \mathbf{H}_B, \mathbf{s}) \in \text{SO}(3) \times \mathbb{R}^n`.
|
444
|
+
|
445
|
+
Returns:
|
446
|
+
A tuple containing the base transform and the joint positions.
|
447
|
+
"""
|
448
|
+
|
449
|
+
return self.base_transform(), self.joint_positions()
|
450
|
+
|
451
|
+
@jax.jit
|
452
|
+
def generalized_velocity(self) -> jtp.Vector:
|
453
|
+
"""
|
454
|
+
Get the generalized velocity
|
455
|
+
:math:`\boldsymbol{\nu} = (\boldsymbol{v}_{W,B};\, \boldsymbol{\omega}_{W,B};\, \mathbf{s}) \in \mathbb{R}^{6+n}`
|
456
|
+
|
457
|
+
Returns:
|
458
|
+
The generalized velocity in the active representation.
|
459
|
+
"""
|
460
|
+
|
461
|
+
return (
|
462
|
+
jnp.hstack([self.base_velocity(), self.joint_velocities()])
|
463
|
+
.squeeze()
|
464
|
+
.astype(float)
|
465
|
+
)
|
466
|
+
|
467
|
+
# ================
|
468
|
+
# Store quantities
|
469
|
+
# ================
|
470
|
+
|
471
|
+
@functools.partial(jax.jit, static_argnames=["joint_names"])
|
472
|
+
def reset_joint_positions(
|
473
|
+
self,
|
474
|
+
positions: jtp.VectorLike,
|
475
|
+
model: jaxsim.api.model.JaxSimModel | None = None,
|
476
|
+
joint_names: tuple[str, ...] | None = None,
|
477
|
+
) -> Self:
|
478
|
+
"""
|
479
|
+
Reset the joint positions.
|
480
|
+
|
481
|
+
Args:
|
482
|
+
positions: The joint positions.
|
483
|
+
model: The model to consider.
|
484
|
+
joint_names: The names of the joints for which to set the positions.
|
485
|
+
|
486
|
+
Returns:
|
487
|
+
The updated `JaxSimModelData` object.
|
488
|
+
"""
|
489
|
+
|
490
|
+
positions = jnp.array(positions)
|
491
|
+
|
492
|
+
def replace(s: jtp.VectorLike) -> JaxSimModelData:
|
493
|
+
return self.replace(
|
494
|
+
validate=True,
|
495
|
+
state=self.state.replace(
|
496
|
+
physics_model=self.state.physics_model.replace(
|
497
|
+
joint_positions=jnp.atleast_1d(s.squeeze()).astype(float)
|
498
|
+
)
|
499
|
+
),
|
500
|
+
)
|
501
|
+
|
502
|
+
if model is None:
|
503
|
+
return replace(s=positions)
|
504
|
+
|
505
|
+
if not self.valid(model=model):
|
506
|
+
msg = "The data object is not compatible with the provided model"
|
507
|
+
raise ValueError(msg)
|
508
|
+
|
509
|
+
joint_names = joint_names if joint_names is not None else model.joint_names()
|
510
|
+
|
511
|
+
return replace(
|
512
|
+
s=self.state.physics_model.joint_positions.at[
|
513
|
+
jaxsim.api.joint.names_to_idxs(joint_names=joint_names, model=model)
|
514
|
+
].set(positions)
|
515
|
+
)
|
516
|
+
|
517
|
+
@functools.partial(jax.jit, static_argnames=["joint_names"])
|
518
|
+
def reset_joint_velocities(
|
519
|
+
self,
|
520
|
+
velocities: jtp.VectorLike,
|
521
|
+
model: jaxsim.api.model.JaxSimModel | None = None,
|
522
|
+
joint_names: tuple[str, ...] | None = None,
|
523
|
+
) -> Self:
|
524
|
+
"""
|
525
|
+
Reset the joint velocities.
|
526
|
+
|
527
|
+
Args:
|
528
|
+
velocities: The joint velocities.
|
529
|
+
model: The model to consider.
|
530
|
+
joint_names: The names of the joints for which to set the velocities.
|
531
|
+
|
532
|
+
Returns:
|
533
|
+
The updated `JaxSimModelData` object.
|
534
|
+
"""
|
535
|
+
|
536
|
+
velocities = jnp.array(velocities)
|
537
|
+
|
538
|
+
def replace(ṡ: jtp.VectorLike) -> JaxSimModelData:
|
539
|
+
return self.replace(
|
540
|
+
validate=True,
|
541
|
+
state=self.state.replace(
|
542
|
+
physics_model=self.state.physics_model.replace(
|
543
|
+
joint_velocities=jnp.atleast_1d(ṡ.squeeze()).astype(float)
|
544
|
+
)
|
545
|
+
),
|
546
|
+
)
|
547
|
+
|
548
|
+
if model is None:
|
549
|
+
return replace(ṡ=velocities)
|
550
|
+
|
551
|
+
if not self.valid(model=model):
|
552
|
+
msg = "The data object is not compatible with the provided model"
|
553
|
+
raise ValueError(msg)
|
554
|
+
|
555
|
+
joint_names = joint_names if joint_names is not None else model.joint_names()
|
556
|
+
|
557
|
+
return replace(
|
558
|
+
ṡ=self.state.physics_model.joint_velocities.at[
|
559
|
+
jaxsim.api.joint.names_to_idxs(joint_names=joint_names, model=model)
|
560
|
+
].set(velocities)
|
561
|
+
)
|
562
|
+
|
563
|
+
@jax.jit
|
564
|
+
def reset_base_position(self, base_position: jtp.VectorLike) -> Self:
|
565
|
+
"""
|
566
|
+
Reset the base position.
|
567
|
+
|
568
|
+
Args:
|
569
|
+
base_position: The base position.
|
570
|
+
|
571
|
+
Returns:
|
572
|
+
The updated `JaxSimModelData` object.
|
573
|
+
"""
|
574
|
+
|
575
|
+
base_position = jnp.array(base_position)
|
576
|
+
|
577
|
+
return self.replace(
|
578
|
+
validate=True,
|
579
|
+
state=self.state.replace(
|
580
|
+
physics_model=self.state.physics_model.replace(
|
581
|
+
base_position=jnp.atleast_1d(base_position.squeeze()).astype(float)
|
582
|
+
)
|
583
|
+
),
|
584
|
+
)
|
585
|
+
|
586
|
+
@jax.jit
|
587
|
+
def reset_base_quaternion(self, base_quaternion: jtp.VectorLike) -> Self:
|
588
|
+
"""
|
589
|
+
Reset the base quaternion.
|
590
|
+
|
591
|
+
Args:
|
592
|
+
base_quaternion: The base orientation as a quaternion.
|
593
|
+
|
594
|
+
Returns:
|
595
|
+
The updated `JaxSimModelData` object.
|
596
|
+
"""
|
597
|
+
|
598
|
+
base_quaternion = jnp.array(base_quaternion)
|
599
|
+
|
600
|
+
return self.replace(
|
601
|
+
validate=True,
|
602
|
+
state=self.state.replace(
|
603
|
+
physics_model=self.state.physics_model.replace(
|
604
|
+
base_quaternion=jnp.atleast_1d(base_quaternion.squeeze()).astype(
|
605
|
+
float
|
606
|
+
)
|
607
|
+
)
|
608
|
+
),
|
609
|
+
)
|
610
|
+
|
611
|
+
@jax.jit
|
612
|
+
def reset_base_pose(self, base_pose: jtp.MatrixLike) -> Self:
|
613
|
+
"""
|
614
|
+
Reset the base pose.
|
615
|
+
|
616
|
+
Args:
|
617
|
+
base_pose: The base pose as an SE(3) matrix.
|
618
|
+
|
619
|
+
Returns:
|
620
|
+
The updated `JaxSimModelData` object.
|
621
|
+
"""
|
622
|
+
|
623
|
+
base_pose = jnp.array(base_pose)
|
624
|
+
|
625
|
+
W_p_B = base_pose[0:3, 3]
|
626
|
+
|
627
|
+
to_wxyz = np.array([3, 0, 1, 2])
|
628
|
+
W_R_B: jaxlie.SO3 = jaxlie.SO3.from_matrix(base_pose[0:3, 0:3]) # noqa
|
629
|
+
W_Q_B = W_R_B.as_quaternion_xyzw()[to_wxyz]
|
630
|
+
|
631
|
+
return self.reset_base_position(base_position=W_p_B).reset_base_quaternion(
|
632
|
+
base_quaternion=W_Q_B
|
633
|
+
)
|
634
|
+
|
635
|
+
@functools.partial(jax.jit, static_argnames=["velocity_representation"])
|
636
|
+
def reset_base_linear_velocity(
|
637
|
+
self,
|
638
|
+
linear_velocity: jtp.VectorLike,
|
639
|
+
velocity_representation: VelRepr | None = None,
|
640
|
+
) -> Self:
|
641
|
+
"""
|
642
|
+
Reset the base linear velocity.
|
643
|
+
|
644
|
+
Args:
|
645
|
+
linear_velocity: The base linear velocity as a 3D array.
|
646
|
+
velocity_representation:
|
647
|
+
The velocity representation in which the base velocity is expressed.
|
648
|
+
If `None`, the active representation is considered.
|
649
|
+
|
650
|
+
Returns:
|
651
|
+
The updated `JaxSimModelData` object.
|
652
|
+
"""
|
653
|
+
|
654
|
+
linear_velocity = jnp.array(linear_velocity)
|
655
|
+
|
656
|
+
return self.reset_base_velocity(
|
657
|
+
base_velocity=jnp.hstack(
|
658
|
+
[linear_velocity.squeeze(), self.base_velocity()[3:6]]
|
659
|
+
),
|
660
|
+
velocity_representation=velocity_representation,
|
661
|
+
)
|
662
|
+
|
663
|
+
@functools.partial(jax.jit, static_argnames=["velocity_representation"])
|
664
|
+
def reset_base_angular_velocity(
|
665
|
+
self,
|
666
|
+
angular_velocity: jtp.VectorLike,
|
667
|
+
velocity_representation: VelRepr | None = None,
|
668
|
+
) -> Self:
|
669
|
+
"""
|
670
|
+
Reset the base angular velocity.
|
671
|
+
|
672
|
+
Args:
|
673
|
+
angular_velocity: The base angular velocity as a 3D array.
|
674
|
+
velocity_representation:
|
675
|
+
The velocity representation in which the base velocity is expressed.
|
676
|
+
If `None`, the active representation is considered.
|
677
|
+
|
678
|
+
Returns:
|
679
|
+
The updated `JaxSimModelData` object.
|
680
|
+
"""
|
681
|
+
|
682
|
+
angular_velocity = jnp.array(angular_velocity)
|
683
|
+
|
684
|
+
return self.reset_base_velocity(
|
685
|
+
base_velocity=jnp.hstack(
|
686
|
+
[self.base_velocity()[0:3], angular_velocity.squeeze()]
|
687
|
+
),
|
688
|
+
velocity_representation=velocity_representation,
|
689
|
+
)
|
690
|
+
|
691
|
+
@functools.partial(jax.jit, static_argnames=["velocity_representation"])
|
692
|
+
def reset_base_velocity(
|
693
|
+
self,
|
694
|
+
base_velocity: jtp.VectorLike,
|
695
|
+
velocity_representation: VelRepr | None = None,
|
696
|
+
) -> Self:
|
697
|
+
"""
|
698
|
+
Reset the base 6D velocity.
|
699
|
+
|
700
|
+
Args:
|
701
|
+
base_velocity: The base 6D velocity in the active representation.
|
702
|
+
velocity_representation:
|
703
|
+
The velocity representation in which the base velocity is expressed.
|
704
|
+
If `None`, the active representation is considered.
|
705
|
+
|
706
|
+
Returns:
|
707
|
+
The updated `JaxSimModelData` object.
|
708
|
+
"""
|
709
|
+
|
710
|
+
base_velocity = jnp.array(base_velocity)
|
711
|
+
|
712
|
+
velocity_representation = (
|
713
|
+
velocity_representation
|
714
|
+
if velocity_representation is not None
|
715
|
+
else self.velocity_representation
|
716
|
+
)
|
717
|
+
|
718
|
+
W_v_WB = self.other_representation_to_inertial(
|
719
|
+
array=jnp.atleast_1d(base_velocity.squeeze()).astype(float),
|
720
|
+
other_representation=velocity_representation,
|
721
|
+
transform=self.base_transform(),
|
722
|
+
is_force=False,
|
723
|
+
)
|
724
|
+
|
725
|
+
return self.replace(
|
726
|
+
validate=True,
|
727
|
+
state=self.state.replace(
|
728
|
+
physics_model=self.state.physics_model.replace(
|
729
|
+
base_linear_velocity=W_v_WB[0:3].squeeze().astype(float),
|
730
|
+
base_angular_velocity=W_v_WB[3:6].squeeze().astype(float),
|
731
|
+
)
|
732
|
+
),
|
733
|
+
)
|
734
|
+
|
735
|
+
# =============
|
736
|
+
# Other helpers
|
737
|
+
# =============
|
738
|
+
|
739
|
+
@staticmethod
|
740
|
+
@functools.partial(jax.jit, static_argnames=["other_representation", "is_force"])
|
741
|
+
def inertial_to_other_representation(
|
742
|
+
array: jtp.Array,
|
743
|
+
other_representation: VelRepr,
|
744
|
+
transform: jtp.Matrix,
|
745
|
+
is_force: bool = False,
|
746
|
+
) -> jtp.Array:
|
747
|
+
"""
|
748
|
+
Convert a 6D quantity from the inertial to another representation.
|
749
|
+
|
750
|
+
Args:
|
751
|
+
array: The 6D quantity to convert.
|
752
|
+
other_representation: The representation to convert to.
|
753
|
+
transform:
|
754
|
+
The `math:W \mathbf{H}_O` transform, where `math:O` is the
|
755
|
+
reference frame of the other representation.
|
756
|
+
is_force: Whether the quantity is a 6D force or 6D velocity.
|
757
|
+
|
758
|
+
Returns:
|
759
|
+
The 6D quantity in the other representation.
|
760
|
+
"""
|
761
|
+
|
762
|
+
W_array = array.squeeze()
|
763
|
+
W_H_O = transform.squeeze()
|
764
|
+
|
765
|
+
if W_array.size != 6:
|
766
|
+
raise ValueError(W_array.size, 6)
|
767
|
+
|
768
|
+
if W_H_O.shape != (4, 4):
|
769
|
+
raise ValueError(W_H_O.shape, (4, 4))
|
770
|
+
|
771
|
+
match other_representation:
|
772
|
+
|
773
|
+
case VelRepr.Inertial:
|
774
|
+
return W_array
|
775
|
+
|
776
|
+
case VelRepr.Body:
|
777
|
+
|
778
|
+
if not is_force:
|
779
|
+
O_Xv_W = jaxlie.SE3.from_matrix(W_H_O).inverse().adjoint()
|
780
|
+
O_array = O_Xv_W @ W_array
|
781
|
+
|
782
|
+
else:
|
783
|
+
O_Xf_W = jaxlie.SE3.from_matrix(W_H_O).adjoint().T
|
784
|
+
O_array = O_Xf_W @ W_array
|
785
|
+
|
786
|
+
return O_array
|
787
|
+
|
788
|
+
case VelRepr.Mixed:
|
789
|
+
W_p_O = W_H_O[0:3, 3]
|
790
|
+
W_H_OW = jnp.eye(4).at[0:3, 3].set(W_p_O)
|
791
|
+
|
792
|
+
if not is_force:
|
793
|
+
OW_Xv_W = jaxlie.SE3.from_matrix(W_H_OW).inverse().adjoint()
|
794
|
+
OW_array = OW_Xv_W @ W_array
|
795
|
+
|
796
|
+
else:
|
797
|
+
OW_Xf_W = jaxlie.SE3.from_matrix(W_H_OW).adjoint().transpose()
|
798
|
+
OW_array = OW_Xf_W @ W_array
|
799
|
+
|
800
|
+
return OW_array
|
801
|
+
|
802
|
+
case _:
|
803
|
+
raise ValueError(other_representation)
|
804
|
+
|
805
|
+
@staticmethod
|
806
|
+
@functools.partial(jax.jit, static_argnames=["other_representation", "is_force"])
|
807
|
+
def other_representation_to_inertial(
|
808
|
+
array: jtp.Array,
|
809
|
+
other_representation: VelRepr,
|
810
|
+
transform: jtp.Matrix,
|
811
|
+
is_force: bool = False,
|
812
|
+
) -> jtp.Array:
|
813
|
+
"""
|
814
|
+
Convert a 6D quantity from another representation to the inertial.
|
815
|
+
|
816
|
+
Args:
|
817
|
+
array: The 6D quantity to convert.
|
818
|
+
other_representation: The representation to convert from.
|
819
|
+
transform:
|
820
|
+
The `math:W \mathbf{H}_O` transform, where `math:O` is the
|
821
|
+
reference frame of the other representation.
|
822
|
+
is_force: Whether the quantity is a 6D force or 6D velocity.
|
823
|
+
|
824
|
+
Returns:
|
825
|
+
The 6D quantity in the inertial representation.
|
826
|
+
"""
|
827
|
+
|
828
|
+
W_array = array.squeeze()
|
829
|
+
W_H_O = transform.squeeze()
|
830
|
+
|
831
|
+
if W_array.size != 6:
|
832
|
+
raise ValueError(W_array.size, 6)
|
833
|
+
|
834
|
+
if W_H_O.shape != (4, 4):
|
835
|
+
raise ValueError(W_H_O.shape, (4, 4))
|
836
|
+
|
837
|
+
match other_representation:
|
838
|
+
case VelRepr.Inertial:
|
839
|
+
W_array = array
|
840
|
+
return W_array
|
841
|
+
|
842
|
+
case VelRepr.Body:
|
843
|
+
O_array = array
|
844
|
+
|
845
|
+
if not is_force:
|
846
|
+
W_Xv_O: jtp.Array = jaxlie.SE3.from_matrix(W_H_O).adjoint()
|
847
|
+
W_array = W_Xv_O @ O_array
|
848
|
+
|
849
|
+
else:
|
850
|
+
W_Xf_O = jaxlie.SE3.from_matrix(W_H_O).inverse().adjoint().T
|
851
|
+
W_array = W_Xf_O @ O_array
|
852
|
+
|
853
|
+
return W_array
|
854
|
+
|
855
|
+
case VelRepr.Mixed:
|
856
|
+
BW_array = array
|
857
|
+
W_p_O = W_H_O[0:3, 3]
|
858
|
+
W_H_OW = jnp.eye(4).at[0:3, 3].set(W_p_O)
|
859
|
+
|
860
|
+
if not is_force:
|
861
|
+
W_Xv_BW: jtp.Array = jaxlie.SE3.from_matrix(W_H_OW).adjoint()
|
862
|
+
W_array = W_Xv_BW @ BW_array
|
863
|
+
|
864
|
+
else:
|
865
|
+
W_Xf_BW = jaxlie.SE3.from_matrix(W_H_OW).inverse().adjoint().T
|
866
|
+
W_array = W_Xf_BW @ BW_array
|
867
|
+
|
868
|
+
return W_array
|
869
|
+
|
870
|
+
case _:
|
871
|
+
raise ValueError(other_representation)
|
872
|
+
|
873
|
+
|
874
|
+
def random_model_data(
    model: jaxsim.api.model.JaxSimModel,
    *,
    key: jax.Array | None = None,
    base_pos_bounds: tuple[
        jtp.FloatLike | Sequence[jtp.FloatLike],
        jtp.FloatLike | Sequence[jtp.FloatLike],
    ] = ((-1, -1, 0.5), 1.0),
    base_vel_lin_bounds: tuple[
        jtp.FloatLike | Sequence[jtp.FloatLike],
        jtp.FloatLike | Sequence[jtp.FloatLike],
    ] = (-1.0, 1.0),
    base_vel_ang_bounds: tuple[
        jtp.FloatLike | Sequence[jtp.FloatLike],
        jtp.FloatLike | Sequence[jtp.FloatLike],
    ] = (-1.0, 1.0),
    joint_vel_bounds: tuple[
        jtp.FloatLike | Sequence[jtp.FloatLike],
        jtp.FloatLike | Sequence[jtp.FloatLike],
    ] = (-1.0, 1.0),
) -> JaxSimModelData:
    """
    Randomly generate a `JaxSimModelData` object.

    Args:
        model: The target model for the random data.
        key:
            The random key. If `None`, `jax.random.PRNGKey(seed=0)` is used, so
            repeated calls without a key produce the same data.
        base_pos_bounds:
            The (min, max) bounds for the base position, each either a scalar or
            a per-axis sequence of 3 elements.
        base_vel_lin_bounds:
            The (min, max) bounds for the base linear velocity, each either a
            scalar or a per-axis sequence of 3 elements.
        base_vel_ang_bounds:
            The (min, max) bounds for the base angular velocity, each either a
            scalar or a per-axis sequence of 3 elements.
        joint_vel_bounds:
            The (min, max) bounds for the joint velocities, each either a scalar
            or a per-joint sequence of `model.dofs()` elements.

    Returns:
        A `JaxSimModelData` object with random data. The base velocity is sampled
        directly in inertial-fixed representation, which is also the active
        representation of the returned object.

    Note:
        The base orientation is built from roll-pitch-yaw angles sampled
        uniformly in [0, 2π); this does not sample SO(3) uniformly.
    """

    key = key if key is not None else jax.random.PRNGKey(seed=0)
    k1, k2, k3, k4, k5, k6 = jax.random.split(key, num=6)

    p_min = jnp.array(base_pos_bounds[0], dtype=float)
    p_max = jnp.array(base_pos_bounds[1], dtype=float)
    v_min = jnp.array(base_vel_lin_bounds[0], dtype=float)
    v_max = jnp.array(base_vel_lin_bounds[1], dtype=float)
    ω_min = jnp.array(base_vel_ang_bounds[0], dtype=float)
    ω_max = jnp.array(base_vel_ang_bounds[1], dtype=float)

    # Convert the joint velocity bounds to arrays too, consistently with the
    # other bounds: `jax.random.uniform` does not accept plain Python sequences.
    ṡ_min = jnp.array(joint_vel_bounds[0], dtype=float)
    ṡ_max = jnp.array(joint_vel_bounds[1], dtype=float)

    random_data = JaxSimModelData.zero(model=model)

    with random_data.mutable_context(mutability=Mutability.MUTABLE):

        physics_model_state = random_data.state.physics_model

        physics_model_state.base_position = jax.random.uniform(
            key=k1, shape=(3,), minval=p_min, maxval=p_max
        )

        # Convert the sampled quaternion from xyzw to wxyz.
        physics_model_state.base_quaternion = jaxlie.SO3.from_rpy_radians(
            *jax.random.uniform(key=k2, shape=(3,), minval=0, maxval=2 * jnp.pi)
        ).as_quaternion_xyzw()[np.array([3, 0, 1, 2])]

        physics_model_state.joint_positions = jaxsim.api.joint.random_joint_positions(
            model=model, key=k3
        )

        physics_model_state.base_linear_velocity = jax.random.uniform(
            key=k4, shape=(3,), minval=v_min, maxval=v_max
        )

        physics_model_state.base_angular_velocity = jax.random.uniform(
            key=k5, shape=(3,), minval=ω_min, maxval=ω_max
        )

        physics_model_state.joint_velocities = jax.random.uniform(
            key=k6, shape=(model.dofs(),), minval=ṡ_min, maxval=ṡ_max
        )

    return random_data
|