jaxsim 0.1rc0__py3-none-any.whl → 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jaxsim/__init__.py +5 -6
- jaxsim/_version.py +2 -2
- jaxsim/api/__init__.py +3 -0
- jaxsim/api/com.py +240 -0
- jaxsim/api/common.py +216 -0
- jaxsim/api/contact.py +271 -0
- jaxsim/api/data.py +821 -0
- jaxsim/api/joint.py +189 -0
- jaxsim/api/kin_dyn_parameters.py +777 -0
- jaxsim/api/link.py +361 -0
- jaxsim/api/model.py +1633 -0
- jaxsim/api/ode.py +295 -0
- jaxsim/api/ode_data.py +694 -0
- jaxsim/api/references.py +421 -0
- jaxsim/integrators/__init__.py +2 -0
- jaxsim/integrators/common.py +594 -0
- jaxsim/integrators/fixed_step.py +102 -0
- jaxsim/integrators/variable_step.py +610 -0
- jaxsim/math/__init__.py +11 -0
- jaxsim/math/adjoint.py +24 -2
- jaxsim/math/joint_model.py +335 -0
- jaxsim/math/quaternion.py +44 -3
- jaxsim/math/rotation.py +4 -4
- jaxsim/math/transform.py +92 -0
- jaxsim/mujoco/__init__.py +3 -0
- jaxsim/mujoco/__main__.py +192 -0
- jaxsim/mujoco/loaders.py +615 -0
- jaxsim/mujoco/model.py +414 -0
- jaxsim/mujoco/visualizer.py +176 -0
- jaxsim/parsers/descriptions/collision.py +14 -0
- jaxsim/parsers/descriptions/link.py +13 -2
- jaxsim/parsers/kinematic_graph.py +8 -3
- jaxsim/parsers/rod/parser.py +54 -38
- jaxsim/parsers/rod/utils.py +7 -8
- jaxsim/rbda/__init__.py +7 -0
- jaxsim/rbda/aba.py +295 -0
- jaxsim/rbda/collidable_points.py +142 -0
- jaxsim/{physics/algos → rbda}/crba.py +43 -42
- jaxsim/rbda/forward_kinematics.py +113 -0
- jaxsim/rbda/jacobian.py +201 -0
- jaxsim/rbda/rnea.py +237 -0
- jaxsim/rbda/soft_contacts.py +296 -0
- jaxsim/rbda/utils.py +152 -0
- jaxsim/terrain/__init__.py +2 -0
- jaxsim/{physics/algos → terrain}/terrain.py +4 -6
- jaxsim/typing.py +30 -30
- jaxsim/utils/__init__.py +1 -4
- jaxsim/utils/hashless.py +18 -0
- jaxsim/utils/jaxsim_dataclass.py +281 -31
- {jaxsim-0.1rc0.dist-info → jaxsim-0.2.0.dist-info}/LICENSE +1 -1
- jaxsim-0.2.0.dist-info/METADATA +237 -0
- jaxsim-0.2.0.dist-info/RECORD +64 -0
- {jaxsim-0.1rc0.dist-info → jaxsim-0.2.0.dist-info}/WHEEL +1 -1
- jaxsim/high_level/__init__.py +0 -2
- jaxsim/high_level/common.py +0 -11
- jaxsim/high_level/joint.py +0 -148
- jaxsim/high_level/link.py +0 -259
- jaxsim/high_level/model.py +0 -1695
- jaxsim/math/conv.py +0 -114
- jaxsim/math/joint.py +0 -101
- jaxsim/math/plucker.py +0 -100
- jaxsim/physics/__init__.py +0 -12
- jaxsim/physics/algos/__init__.py +0 -0
- jaxsim/physics/algos/aba.py +0 -256
- jaxsim/physics/algos/aba_motors.py +0 -284
- jaxsim/physics/algos/forward_kinematics.py +0 -79
- jaxsim/physics/algos/jacobian.py +0 -98
- jaxsim/physics/algos/rnea.py +0 -180
- jaxsim/physics/algos/rnea_motors.py +0 -196
- jaxsim/physics/algos/soft_contacts.py +0 -454
- jaxsim/physics/algos/utils.py +0 -69
- jaxsim/physics/model/__init__.py +0 -0
- jaxsim/physics/model/ground_contact.py +0 -55
- jaxsim/physics/model/physics_model.py +0 -358
- jaxsim/physics/model/physics_model_state.py +0 -174
- jaxsim/simulation/__init__.py +0 -4
- jaxsim/simulation/integrators.py +0 -452
- jaxsim/simulation/ode.py +0 -290
- jaxsim/simulation/ode_data.py +0 -53
- jaxsim/simulation/ode_integration.py +0 -125
- jaxsim/simulation/simulator.py +0 -544
- jaxsim/simulation/simulator_callbacks.py +0 -53
- jaxsim/simulation/utils.py +0 -15
- jaxsim/sixd/__init__.py +0 -2
- jaxsim/utils/oop.py +0 -532
- jaxsim/utils/vmappable.py +0 -117
- jaxsim-0.1rc0.dist-info/METADATA +0 -167
- jaxsim-0.1rc0.dist-info/RECORD +0 -64
- {jaxsim-0.1rc0.dist-info → jaxsim-0.2.0.dist-info}/top_level.txt +0 -0
jaxsim/api/data.py
ADDED
@@ -0,0 +1,821 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
3
|
+
import dataclasses
|
4
|
+
import functools
|
5
|
+
from typing import Sequence
|
6
|
+
|
7
|
+
import jax
|
8
|
+
import jax.numpy as jnp
|
9
|
+
import jax_dataclasses
|
10
|
+
import jaxlie
|
11
|
+
import numpy as np
|
12
|
+
|
13
|
+
import jaxsim.api as js
|
14
|
+
import jaxsim.rbda
|
15
|
+
import jaxsim.typing as jtp
|
16
|
+
from jaxsim.math import Quaternion
|
17
|
+
from jaxsim.utils import Mutability
|
18
|
+
from jaxsim.utils.tracing import not_tracing
|
19
|
+
|
20
|
+
from . import common
|
21
|
+
from .common import VelRepr
|
22
|
+
from .ode_data import ODEState
|
23
|
+
|
24
|
+
try:
|
25
|
+
from typing import Self
|
26
|
+
except ImportError:
|
27
|
+
from typing_extensions import Self
|
28
|
+
|
29
|
+
|
30
|
+
@jax_dataclasses.pytree_dataclass
class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
    """
    Class containing the state of a `JaxSimModel` object.
    """

    # The ODE state. Its `physics_model` field holds the base and joint
    # positions/velocities (base velocities are stored in inertial-fixed form).
    state: ODEState

    # The 3D gravity vector. `build` sets it to (0, 0, -standard_gravity).
    gravity: jtp.Array

    # Parameters of the soft-contacts model (hidden from the dataclass repr).
    soft_contacts_params: jaxsim.rbda.SoftContactsParams = dataclasses.field(repr=False)

    # The simulated time, stored as an integer number of nanoseconds.
    time_ns: jtp.Int = dataclasses.field(
        default_factory=lambda: jnp.array(0, dtype=jnp.uint64)
    )

    def valid(self, model: js.model.JaxSimModel | None = None) -> bool:
        """
        Check if the current state is valid for the given model.

        Args:
            model: The model to check against.

        Returns:
            `True` if the current state is valid for the given model, `False` otherwise.
        """

        valid = True
        valid = valid and self.standard_gravity() > 0

        if model is not None:
            valid = valid and self.state.valid(model=model)

        return valid

    @staticmethod
    def zero(
        model: js.model.JaxSimModel,
        velocity_representation: VelRepr = VelRepr.Inertial,
    ) -> JaxSimModelData:
        """
        Create a `JaxSimModelData` object with zero state.

        Args:
            model: The model for which to create the zero state.
            velocity_representation: The velocity representation to use.

        Returns:
            A `JaxSimModelData` object with zero state.
        """

        return JaxSimModelData.build(
            model=model, velocity_representation=velocity_representation
        )

    @staticmethod
    def build(
        model: js.model.JaxSimModel,
        base_position: jtp.Vector | None = None,
        base_quaternion: jtp.Vector | None = None,
        joint_positions: jtp.Vector | None = None,
        base_linear_velocity: jtp.Vector | None = None,
        base_angular_velocity: jtp.Vector | None = None,
        joint_velocities: jtp.Vector | None = None,
        standard_gravity: jtp.FloatLike = jaxsim.math.StandardGravity,
        soft_contacts_state: js.ode_data.SoftContactsState | None = None,
        soft_contacts_params: jaxsim.rbda.SoftContactsParams | None = None,
        velocity_representation: VelRepr = VelRepr.Inertial,
        time: jtp.FloatLike | None = None,
    ) -> JaxSimModelData:
        """
        Create a `JaxSimModelData` object with the given state.

        Args:
            model: The model for which to create the state.
            base_position: The base position.
            base_quaternion: The base orientation as a quaternion (wxyz).
            joint_positions: The joint positions (any array-like, including lists).
            base_linear_velocity:
                The base linear velocity in the selected representation.
            base_angular_velocity:
                The base angular velocity in the selected representation.
            joint_velocities: The joint velocities (any array-like, including lists).
            standard_gravity: The standard gravity constant.
            soft_contacts_state: The state of the soft contacts.
            soft_contacts_params:
                The parameters of the soft contacts. If `None`, they are estimated
                from the model and the standard gravity.
            velocity_representation: The velocity representation to use.
            time: The time (in seconds) at which the state is created.

        Returns:
            A `JaxSimModelData` object with the given state.

        Raises:
            ValueError: If the resulting ODE state is not valid for the model.
        """

        base_position = jnp.array(
            base_position if base_position is not None else jnp.zeros(3)
        ).squeeze()

        base_quaternion = jnp.array(
            base_quaternion
            if base_quaternion is not None
            else jnp.array([1.0, 0, 0, 0])
        ).squeeze()

        base_linear_velocity = jnp.array(
            base_linear_velocity if base_linear_velocity is not None else jnp.zeros(3)
        ).squeeze()

        base_angular_velocity = jnp.array(
            base_angular_velocity if base_angular_velocity is not None else jnp.zeros(3)
        ).squeeze()

        gravity = jnp.zeros(3).at[2].set(-standard_gravity)

        # Convert to arrays before squeezing so that plain Python sequences are
        # accepted, consistently with the base quantities above.
        joint_positions = jnp.atleast_1d(
            jnp.array(joint_positions).squeeze()
            if joint_positions is not None
            else jnp.zeros(model.dofs())
        )

        joint_velocities = jnp.atleast_1d(
            jnp.array(joint_velocities).squeeze()
            if joint_velocities is not None
            else jnp.zeros(model.dofs())
        )

        # Round to the nearest nanosecond instead of truncating: floating-point
        # error can make `time * 1e9` fall just below the intended integer
        # (e.g. 0.7 + 0.1 = 0.7999999999999999).
        time_ns = (
            jnp.array(jnp.round(time * 1e9), dtype=jnp.uint64)
            if time is not None
            else jnp.array(0, dtype=jnp.uint64)
        )

        soft_contacts_params = (
            soft_contacts_params
            if soft_contacts_params is not None
            else js.contact.estimate_good_soft_contacts_parameters(
                model=model, standard_gravity=standard_gravity
            )
        )

        # jaxlie expects xyzw quaternions, while the state stores wxyz.
        W_H_B = jaxlie.SE3.from_rotation_and_translation(
            translation=base_position,
            rotation=jaxlie.SO3.from_quaternion_xyzw(
                base_quaternion[jnp.array([1, 2, 3, 0])]
            ),
        ).as_matrix()

        # The ODE state always stores the base velocity in inertial-fixed form.
        v_WB = JaxSimModelData.other_representation_to_inertial(
            array=jnp.hstack([base_linear_velocity, base_angular_velocity]),
            other_representation=velocity_representation,
            transform=W_H_B,
            is_force=False,
        )

        ode_state = ODEState.build_from_jaxsim_model(
            model=model,
            base_position=base_position.astype(float),
            base_quaternion=base_quaternion.astype(float),
            joint_positions=joint_positions.astype(float),
            base_linear_velocity=v_WB[0:3].astype(float),
            base_angular_velocity=v_WB[3:6].astype(float),
            joint_velocities=joint_velocities.astype(float),
            tangential_deformation=(
                soft_contacts_state.tangential_deformation
                if soft_contacts_state is not None
                else None
            ),
        )

        if not ode_state.valid(model=model):
            raise ValueError(ode_state)

        return JaxSimModelData(
            time_ns=time_ns,
            state=ode_state,
            gravity=gravity.astype(float),
            soft_contacts_params=soft_contacts_params,
            velocity_representation=velocity_representation,
        )

    # ==================
    # Extract quantities
    # ==================

    def time(self) -> jtp.Float:
        """
        Get the simulated time.

        Returns:
            The simulated time in seconds.
        """

        return self.time_ns.astype(float) / 1e9

    def standard_gravity(self) -> jtp.Float:
        """
        Get the standard gravity constant.

        Returns:
            The standard gravity constant.
        """

        return -self.gravity[2]

    @functools.partial(jax.jit, static_argnames=["joint_names"])
    def joint_positions(
        self,
        model: js.model.JaxSimModel | None = None,
        joint_names: tuple[str, ...] | None = None,
    ) -> jtp.Vector:
        """
        Get the joint positions.

        Args:
            model: The model to consider.
            joint_names:
                The names of the joints for which to get the positions. If `None`,
                the positions of all joints are returned.

        Returns:
            If no model and no joint names are provided, the joint positions as a
            `(DoFs,)` vector corresponding to the serialization of the original
            model used to build the data object.
            If a model is provided and no joint names are provided, the joint positions
            as a `(DoFs,)` vector corresponding to the serialization of the
            provided model.
            If a model and joint names are provided, the joint positions as a
            `(len(joint_names),)` vector corresponding to the serialization of
            the passed joint names vector.

        Raises:
            ValueError: If joint names are given without a model, or if the data
                object is not compatible with the provided model.
        """

        if model is None:
            if joint_names is not None:
                raise ValueError("Joint names cannot be provided without a model")

            return self.state.physics_model.joint_positions

        if not_tracing(self.state.physics_model.joint_positions) and not self.valid(
            model=model
        ):
            msg = "The data object is not compatible with the provided model"
            raise ValueError(msg)

        joint_names = joint_names if joint_names is not None else model.joint_names()

        return self.state.physics_model.joint_positions[
            js.joint.names_to_idxs(joint_names=joint_names, model=model)
        ]

    @functools.partial(jax.jit, static_argnames=["joint_names"])
    def joint_velocities(
        self,
        model: js.model.JaxSimModel | None = None,
        joint_names: tuple[str, ...] | None = None,
    ) -> jtp.Vector:
        """
        Get the joint velocities.

        Args:
            model: The model to consider.
            joint_names:
                The names of the joints for which to get the velocities. If `None`,
                the velocities of all joints are returned.

        Returns:
            If no model and no joint names are provided, the joint velocities as a
            `(DoFs,)` vector corresponding to the serialization of the original
            model used to build the data object.
            If a model is provided and no joint names are provided, the joint velocities
            as a `(DoFs,)` vector corresponding to the serialization of the
            provided model.
            If a model and joint names are provided, the joint velocities as a
            `(len(joint_names),)` vector corresponding to the serialization of
            the passed joint names vector.

        Raises:
            ValueError: If joint names are given without a model, or if the data
                object is not compatible with the provided model.
        """

        if model is None:
            if joint_names is not None:
                raise ValueError("Joint names cannot be provided without a model")

            return self.state.physics_model.joint_velocities

        if not_tracing(self.state.physics_model.joint_velocities) and not self.valid(
            model=model
        ):
            msg = "The data object is not compatible with the provided model"
            raise ValueError(msg)

        joint_names = joint_names if joint_names is not None else model.joint_names()

        return self.state.physics_model.joint_velocities[
            js.joint.names_to_idxs(joint_names=joint_names, model=model)
        ]

    @jax.jit
    def base_position(self) -> jtp.Vector:
        """
        Get the base position.

        Returns:
            The base position.
        """

        return self.state.physics_model.base_position.squeeze()

    @functools.partial(jax.jit, static_argnames=["dcm"])
    def base_orientation(self, dcm: jtp.BoolLike = False) -> jtp.Vector | jtp.Matrix:
        """
        Get the base orientation.

        Args:
            dcm: Whether to return the orientation as a SO(3) matrix or quaternion.

        Returns:
            The base orientation, either as a wxyz quaternion or as a 3x3 matrix.
        """

        # Extract the base quaternion.
        W_Q_B = self.state.physics_model.base_quaternion.squeeze()

        # Always normalize the quaternion to avoid numerical issues.
        # If the active scheme does not integrate the quaternion on its manifold,
        # we introduce a Baumgarte stabilization to let the quaternion converge to
        # a unit quaternion. In this case, it is not guaranteed that the quaternion
        # stored in the state is a unit quaternion.
        W_Q_B = jax.lax.select(
            pred=jnp.allclose(jnp.linalg.norm(W_Q_B), 1.0, atol=1e-6, rtol=0.0),
            on_true=W_Q_B,
            on_false=W_Q_B / jnp.linalg.norm(W_Q_B),
        )

        return (
            W_Q_B
            if not dcm
            else jaxlie.SO3.from_quaternion_xyzw(
                Quaternion.to_xyzw(wxyz=W_Q_B)
            ).as_matrix()
        ).astype(float)

    @jax.jit
    def base_transform(self) -> jtp.MatrixJax:
        """
        Get the base transform.

        Returns:
            The base transform as an SE(3) matrix.
        """

        W_R_B = self.base_orientation(dcm=True)
        W_p_B = jnp.vstack(self.base_position())

        return jnp.vstack(
            [
                jnp.block([W_R_B, W_p_B]),
                jnp.array([0, 0, 0, 1]),
            ]
        )

    @jax.jit
    def base_velocity(self) -> jtp.Vector:
        """
        Get the base 6D velocity.

        Returns:
            The base 6D velocity in the active representation.
        """

        W_v_WB = jnp.hstack(
            [
                self.state.physics_model.base_linear_velocity,
                self.state.physics_model.base_angular_velocity,
            ]
        )

        W_H_B = self.base_transform()

        return (
            JaxSimModelData.inertial_to_other_representation(
                array=W_v_WB,
                other_representation=self.velocity_representation,
                transform=W_H_B,
                is_force=False,
            )
            .squeeze()
            .astype(float)
        )

    @jax.jit
    def generalized_position(self) -> tuple[jtp.Matrix, jtp.Vector]:
        r"""
        Get the generalized position
        :math:`\mathbf{q} = ({}^W \mathbf{H}_B, \mathbf{s}) \in \text{SE}(3) \times \mathbb{R}^n`.

        Returns:
            A tuple containing the base transform and the joint positions.
        """

        return self.base_transform(), self.joint_positions()

    @jax.jit
    def generalized_velocity(self) -> jtp.Vector:
        r"""
        Get the generalized velocity
        :math:`\boldsymbol{\nu} = (\boldsymbol{v}_{W,B};\, \boldsymbol{\omega}_{W,B};\, \dot{\mathbf{s}}) \in \mathbb{R}^{6+n}`

        Returns:
            The generalized velocity in the active representation.
        """

        return (
            jnp.hstack([self.base_velocity(), self.joint_velocities()])
            .squeeze()
            .astype(float)
        )

    # ================
    # Store quantities
    # ================

    @functools.partial(jax.jit, static_argnames=["joint_names"])
    def reset_joint_positions(
        self,
        positions: jtp.VectorLike,
        model: js.model.JaxSimModel | None = None,
        joint_names: tuple[str, ...] | None = None,
    ) -> Self:
        """
        Reset the joint positions.

        Args:
            positions: The joint positions.
            model: The model to consider.
            joint_names: The names of the joints for which to set the positions.

        Returns:
            The updated `JaxSimModelData` object.

        Raises:
            ValueError: If the data object is not compatible with the provided model.
        """

        positions = jnp.array(positions)

        def replace(s: jtp.VectorLike) -> JaxSimModelData:
            return self.replace(
                validate=True,
                state=self.state.replace(
                    physics_model=self.state.physics_model.replace(
                        joint_positions=jnp.atleast_1d(s.squeeze()).astype(float)
                    )
                ),
            )

        if model is None:
            return replace(s=positions)

        if not_tracing(positions) and not self.valid(model=model):
            msg = "The data object is not compatible with the provided model"
            raise ValueError(msg)

        joint_names = joint_names if joint_names is not None else model.joint_names()

        return replace(
            s=self.state.physics_model.joint_positions.at[
                js.joint.names_to_idxs(joint_names=joint_names, model=model)
            ].set(positions)
        )

    @functools.partial(jax.jit, static_argnames=["joint_names"])
    def reset_joint_velocities(
        self,
        velocities: jtp.VectorLike,
        model: js.model.JaxSimModel | None = None,
        joint_names: tuple[str, ...] | None = None,
    ) -> Self:
        """
        Reset the joint velocities.

        Args:
            velocities: The joint velocities.
            model: The model to consider.
            joint_names: The names of the joints for which to set the velocities.

        Returns:
            The updated `JaxSimModelData` object.

        Raises:
            ValueError: If the data object is not compatible with the provided model.
        """

        velocities = jnp.array(velocities)

        def replace(ṡ: jtp.VectorLike) -> JaxSimModelData:
            return self.replace(
                validate=True,
                state=self.state.replace(
                    physics_model=self.state.physics_model.replace(
                        joint_velocities=jnp.atleast_1d(ṡ.squeeze()).astype(float)
                    )
                ),
            )

        if model is None:
            return replace(ṡ=velocities)

        if not_tracing(velocities) and not self.valid(model=model):
            msg = "The data object is not compatible with the provided model"
            raise ValueError(msg)

        joint_names = joint_names if joint_names is not None else model.joint_names()

        return replace(
            ṡ=self.state.physics_model.joint_velocities.at[
                js.joint.names_to_idxs(joint_names=joint_names, model=model)
            ].set(velocities)
        )

    @jax.jit
    def reset_base_position(self, base_position: jtp.VectorLike) -> Self:
        """
        Reset the base position.

        Args:
            base_position: The base position.

        Returns:
            The updated `JaxSimModelData` object.
        """

        base_position = jnp.array(base_position)

        return self.replace(
            validate=True,
            state=self.state.replace(
                physics_model=self.state.physics_model.replace(
                    base_position=jnp.atleast_1d(base_position.squeeze()).astype(float)
                )
            ),
        )

    @jax.jit
    def reset_base_quaternion(self, base_quaternion: jtp.VectorLike) -> Self:
        """
        Reset the base quaternion.

        Args:
            base_quaternion: The base orientation as a quaternion.

        Returns:
            The updated `JaxSimModelData` object.
        """

        base_quaternion = jnp.array(base_quaternion)

        return self.replace(
            validate=True,
            state=self.state.replace(
                physics_model=self.state.physics_model.replace(
                    base_quaternion=jnp.atleast_1d(base_quaternion.squeeze()).astype(
                        float
                    )
                )
            ),
        )

    @jax.jit
    def reset_base_pose(self, base_pose: jtp.MatrixLike) -> Self:
        """
        Reset the base pose.

        Args:
            base_pose: The base pose as an SE(3) matrix.

        Returns:
            The updated `JaxSimModelData` object.
        """

        base_pose = jnp.array(base_pose)

        W_p_B = base_pose[0:3, 3]

        # jaxlie returns xyzw quaternions, while the state stores wxyz.
        to_wxyz = np.array([3, 0, 1, 2])
        W_R_B: jaxlie.SO3 = jaxlie.SO3.from_matrix(base_pose[0:3, 0:3])  # noqa
        W_Q_B = W_R_B.as_quaternion_xyzw()[to_wxyz]

        return self.reset_base_position(base_position=W_p_B).reset_base_quaternion(
            base_quaternion=W_Q_B
        )

    @functools.partial(jax.jit, static_argnames=["velocity_representation"])
    def reset_base_linear_velocity(
        self,
        linear_velocity: jtp.VectorLike,
        velocity_representation: VelRepr | None = None,
    ) -> Self:
        """
        Reset the base linear velocity.

        The base angular velocity is kept unchanged. It is read in the same
        representation as `linear_velocity` so that the two halves stay consistent.

        Args:
            linear_velocity: The base linear velocity as a 3D array.
            velocity_representation:
                The velocity representation in which the base velocity is expressed.
                If `None`, the active representation is considered.

        Returns:
            The updated `JaxSimModelData` object.
        """

        linear_velocity = jnp.array(linear_velocity)

        velocity_representation = (
            velocity_representation
            if velocity_representation is not None
            else self.velocity_representation
        )

        # Express the current base velocity in the requested representation.
        # Reading it through `self.base_velocity()` would return it in the active
        # representation. Mixing that with a linear velocity in a different
        # representation would corrupt the angular part.
        W_v_WB = jnp.hstack(
            [
                self.state.physics_model.base_linear_velocity,
                self.state.physics_model.base_angular_velocity,
            ]
        )

        v_WB = JaxSimModelData.inertial_to_other_representation(
            array=W_v_WB,
            other_representation=velocity_representation,
            transform=self.base_transform(),
            is_force=False,
        ).squeeze()

        return self.reset_base_velocity(
            base_velocity=jnp.hstack([linear_velocity.squeeze(), v_WB[3:6]]),
            velocity_representation=velocity_representation,
        )

    @functools.partial(jax.jit, static_argnames=["velocity_representation"])
    def reset_base_angular_velocity(
        self,
        angular_velocity: jtp.VectorLike,
        velocity_representation: VelRepr | None = None,
    ) -> Self:
        """
        Reset the base angular velocity.

        The base linear velocity is kept unchanged. It is read in the same
        representation as `angular_velocity` so that the two halves stay consistent.

        Args:
            angular_velocity: The base angular velocity as a 3D array.
            velocity_representation:
                The velocity representation in which the base velocity is expressed.
                If `None`, the active representation is considered.

        Returns:
            The updated `JaxSimModelData` object.
        """

        angular_velocity = jnp.array(angular_velocity)

        velocity_representation = (
            velocity_representation
            if velocity_representation is not None
            else self.velocity_representation
        )

        # Express the current base velocity in the requested representation.
        # Reading it through `self.base_velocity()` would return it in the active
        # representation. Mixing that with an angular velocity in a different
        # representation would corrupt the linear part.
        W_v_WB = jnp.hstack(
            [
                self.state.physics_model.base_linear_velocity,
                self.state.physics_model.base_angular_velocity,
            ]
        )

        v_WB = JaxSimModelData.inertial_to_other_representation(
            array=W_v_WB,
            other_representation=velocity_representation,
            transform=self.base_transform(),
            is_force=False,
        ).squeeze()

        return self.reset_base_velocity(
            base_velocity=jnp.hstack([v_WB[0:3], angular_velocity.squeeze()]),
            velocity_representation=velocity_representation,
        )

    @functools.partial(jax.jit, static_argnames=["velocity_representation"])
    def reset_base_velocity(
        self,
        base_velocity: jtp.VectorLike,
        velocity_representation: VelRepr | None = None,
    ) -> Self:
        """
        Reset the base 6D velocity.

        Args:
            base_velocity:
                The base 6D velocity (linear first, then angular), expressed in
                `velocity_representation`, or in the active representation if
                `velocity_representation` is `None`.
            velocity_representation:
                The velocity representation in which the base velocity is expressed.
                If `None`, the active representation is considered.

        Returns:
            The updated `JaxSimModelData` object.
        """

        base_velocity = jnp.array(base_velocity)

        velocity_representation = (
            velocity_representation
            if velocity_representation is not None
            else self.velocity_representation
        )

        # The state stores the base velocity in inertial-fixed form.
        W_v_WB = self.other_representation_to_inertial(
            array=jnp.atleast_1d(base_velocity.squeeze()).astype(float),
            other_representation=velocity_representation,
            transform=self.base_transform(),
            is_force=False,
        )

        return self.replace(
            validate=True,
            state=self.state.replace(
                physics_model=self.state.physics_model.replace(
                    base_linear_velocity=W_v_WB[0:3].squeeze().astype(float),
                    base_angular_velocity=W_v_WB[3:6].squeeze().astype(float),
                )
            ),
        )
|
711
|
+
|
712
|
+
|
713
|
+
def random_model_data(
    model: js.model.JaxSimModel,
    *,
    key: jax.Array | None = None,
    velocity_representation: VelRepr | None = None,
    base_pos_bounds: tuple[
        jtp.FloatLike | Sequence[jtp.FloatLike],
        jtp.FloatLike | Sequence[jtp.FloatLike],
    ] = ((-1, -1, 0.5), 1.0),
    base_vel_lin_bounds: tuple[
        jtp.FloatLike | Sequence[jtp.FloatLike],
        jtp.FloatLike | Sequence[jtp.FloatLike],
    ] = (-1.0, 1.0),
    base_vel_ang_bounds: tuple[
        jtp.FloatLike | Sequence[jtp.FloatLike],
        jtp.FloatLike | Sequence[jtp.FloatLike],
    ] = (-1.0, 1.0),
    joint_vel_bounds: tuple[
        jtp.FloatLike | Sequence[jtp.FloatLike],
        jtp.FloatLike | Sequence[jtp.FloatLike],
    ] = (-1.0, 1.0),
    standard_gravity_bounds: tuple[jtp.FloatLike, jtp.FloatLike] = (
        jaxsim.math.StandardGravity,
        jaxsim.math.StandardGravity,
    ),
) -> JaxSimModelData:
    """
    Sample a `JaxSimModelData` object with a randomized state.

    The base position is drawn uniformly between the given bounds. The base
    orientation is built from roll-pitch-yaw angles drawn uniformly in [0, 2π).
    Joint positions come from `js.joint.random_joint_positions`. Velocities and
    gravity are drawn uniformly between their bounds. The base velocities are
    written directly into the state, which stores them in inertial-fixed form, so
    their bounds refer to that form regardless of `velocity_representation`.

    Args:
        model: The target model for the random data.
        key: The random key. If `None`, `jax.random.PRNGKey(seed=0)` is used.
        velocity_representation: The velocity representation to use.
        base_pos_bounds: The (min, max) bounds for the base position.
        base_vel_lin_bounds: The (min, max) bounds for the base linear velocity.
        base_vel_ang_bounds: The (min, max) bounds for the base angular velocity.
        joint_vel_bounds: The (min, max) bounds for the joint velocities.
        standard_gravity_bounds: The (min, max) bounds for the standard gravity.

    Returns:
        A `JaxSimModelData` object with random data.
    """

    if key is None:
        key = jax.random.PRNGKey(seed=0)

    (
        key_base_pos,
        key_base_rpy,
        key_joint_pos,
        key_joint_vel,
        key_base_lin_vel,
        key_base_ang_vel,
        key_gravity,
    ) = jax.random.split(key, num=7)

    to_float_array = functools.partial(jnp.array, dtype=float)

    pos_low = to_float_array(base_pos_bounds[0])
    pos_high = to_float_array(base_pos_bounds[1])
    lin_vel_low = to_float_array(base_vel_lin_bounds[0])
    lin_vel_high = to_float_array(base_vel_lin_bounds[1])
    ang_vel_low = to_float_array(base_vel_ang_bounds[0])
    ang_vel_high = to_float_array(base_vel_ang_bounds[1])
    joint_vel_low, joint_vel_high = joint_vel_bounds

    # Only forward the representation when given, so that `zero` keeps its default.
    zero_kwargs = (
        {}
        if velocity_representation is None
        else {"velocity_representation": velocity_representation}
    )
    random_data = JaxSimModelData.zero(model=model, **zero_kwargs)

    with random_data.mutable_context(
        mutability=Mutability.MUTABLE, restore_after_exception=False
    ):

        pm_state = random_data.state.physics_model

        pm_state.base_position = jax.random.uniform(
            key=key_base_pos, shape=(3,), minval=pos_low, maxval=pos_high
        )

        rpy = jax.random.uniform(
            key=key_base_rpy, shape=(3,), minval=0, maxval=2 * jnp.pi
        )
        quaternion_xyzw = jaxlie.SO3.from_rpy_radians(*rpy).as_quaternion_xyzw()
        # Reorder xyzw -> wxyz, the convention used by the state.
        pm_state.base_quaternion = quaternion_xyzw[np.array([3, 0, 1, 2])]

        if model.number_of_joints() > 0:
            pm_state.joint_positions = js.joint.random_joint_positions(
                model=model, key=key_joint_pos
            )

            pm_state.joint_velocities = jax.random.uniform(
                key=key_joint_vel,
                shape=(model.dofs(),),
                minval=joint_vel_low,
                maxval=joint_vel_high,
            )

        if model.floating_base():
            pm_state.base_linear_velocity = jax.random.uniform(
                key=key_base_lin_vel,
                shape=(3,),
                minval=lin_vel_low,
                maxval=lin_vel_high,
            )

            pm_state.base_angular_velocity = jax.random.uniform(
                key=key_base_ang_vel,
                shape=(3,),
                minval=ang_vel_low,
                maxval=ang_vel_high,
            )

        sampled_gravity = jax.random.uniform(
            key=key_gravity,
            shape=(),
            minval=standard_gravity_bounds[0],
            maxval=standard_gravity_bounds[1],
        )

        # Gravity points along the negative z axis.
        random_data.gravity = (
            jnp.zeros(3, dtype=random_data.gravity.dtype).at[2].set(-sampled_gravity)
        )

    return random_data
|