jaxsim 0.6.2.dev2__py3-none-any.whl → 0.6.2.dev105__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jaxsim/__init__.py +1 -1
- jaxsim/_version.py +2 -2
- jaxsim/api/__init__.py +3 -1
- jaxsim/api/actuation_model.py +96 -0
- jaxsim/api/com.py +8 -8
- jaxsim/api/contact.py +15 -255
- jaxsim/api/contact_model.py +101 -0
- jaxsim/api/data.py +258 -556
- jaxsim/api/frame.py +7 -7
- jaxsim/api/integrators.py +76 -0
- jaxsim/api/kin_dyn_parameters.py +41 -58
- jaxsim/api/link.py +7 -7
- jaxsim/api/model.py +190 -453
- jaxsim/api/ode.py +34 -338
- jaxsim/api/references.py +2 -2
- jaxsim/exceptions.py +2 -2
- jaxsim/math/__init__.py +4 -3
- jaxsim/math/joint_model.py +17 -107
- jaxsim/mujoco/model.py +1 -1
- jaxsim/mujoco/utils.py +2 -2
- jaxsim/parsers/kinematic_graph.py +1 -3
- jaxsim/rbda/aba.py +7 -4
- jaxsim/rbda/collidable_points.py +7 -98
- jaxsim/rbda/contacts/__init__.py +2 -10
- jaxsim/rbda/contacts/common.py +0 -138
- jaxsim/rbda/contacts/relaxed_rigid.py +156 -11
- jaxsim/rbda/crba.py +5 -2
- jaxsim/rbda/forward_kinematics.py +37 -12
- jaxsim/rbda/jacobian.py +15 -6
- jaxsim/rbda/rnea.py +7 -4
- jaxsim/rbda/utils.py +3 -3
- jaxsim/utils/jaxsim_dataclass.py +5 -1
- {jaxsim-0.6.2.dev2.dist-info → jaxsim-0.6.2.dev105.dist-info}/METADATA +6 -8
- jaxsim-0.6.2.dev105.dist-info/RECORD +69 -0
- jaxsim/api/ode_data.py +0 -401
- jaxsim/integrators/__init__.py +0 -2
- jaxsim/integrators/common.py +0 -592
- jaxsim/integrators/fixed_step.py +0 -153
- jaxsim/integrators/variable_step.py +0 -706
- jaxsim/rbda/contacts/rigid.py +0 -462
- jaxsim/rbda/contacts/soft.py +0 -480
- jaxsim/rbda/contacts/visco_elastic.py +0 -1066
- jaxsim-0.6.2.dev2.dist-info/RECORD +0 -74
- {jaxsim-0.6.2.dev2.dist-info → jaxsim-0.6.2.dev105.dist-info}/LICENSE +0 -0
- {jaxsim-0.6.2.dev2.dist-info → jaxsim-0.6.2.dev105.dist-info}/WHEEL +0 -0
- {jaxsim-0.6.2.dev2.dist-info → jaxsim-0.6.2.dev105.dist-info}/top_level.txt +0 -0
jaxsim/api/data.py
CHANGED
@@ -4,6 +4,11 @@ import dataclasses
|
|
4
4
|
import functools
|
5
5
|
from collections.abc import Sequence
|
6
6
|
|
7
|
+
try:
|
8
|
+
from typing import override
|
9
|
+
except ImportError:
|
10
|
+
from typing_extensions import override
|
11
|
+
|
7
12
|
import jax
|
8
13
|
import jax.numpy as jnp
|
9
14
|
import jax.scipy.spatial.transform
|
@@ -13,12 +18,9 @@ import jaxsim.api as js
|
|
13
18
|
import jaxsim.math
|
14
19
|
import jaxsim.rbda
|
15
20
|
import jaxsim.typing as jtp
|
16
|
-
from jaxsim.utils import Mutability
|
17
|
-
from jaxsim.utils.tracing import not_tracing
|
18
21
|
|
19
22
|
from . import common
|
20
23
|
from .common import VelRepr
|
21
|
-
from .ode_data import ODEState
|
22
24
|
|
23
25
|
try:
|
24
26
|
from typing import Self
|
@@ -29,72 +31,38 @@ except ImportError:
|
|
29
31
|
@jax_dataclasses.pytree_dataclass
|
30
32
|
class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
|
31
33
|
"""
|
32
|
-
Class
|
34
|
+
Class storing the state of the physics model dynamics.
|
35
|
+
|
36
|
+
Attributes:
|
37
|
+
joint_positions: The vector of joint positions.
|
38
|
+
joint_velocities: The vector of joint velocities.
|
39
|
+
base_position: The 3D position of the base link.
|
40
|
+
base_quaternion: The quaternion defining the orientation of the base link.
|
41
|
+
base_linear_velocity:
|
42
|
+
The linear velocity of the base link in inertial-fixed representation.
|
43
|
+
base_angular_velocity:
|
44
|
+
The angular velocity of the base link in inertial-fixed representation.
|
45
|
+
base_transform: The base transform.
|
46
|
+
joint_transforms: The joint transforms.
|
47
|
+
link_transforms: The link transforms.
|
48
|
+
link_velocities: The link velocities in inertial-fixed representation.
|
33
49
|
"""
|
34
50
|
|
35
|
-
state
|
36
|
-
|
37
|
-
|
38
|
-
|
39
|
-
contacts_params: jaxsim.rbda.contacts.ContactsParams = dataclasses.field(repr=False)
|
40
|
-
|
41
|
-
def __hash__(self) -> int:
|
42
|
-
|
43
|
-
from jaxsim.utils.wrappers import HashedNumpyArray
|
44
|
-
|
45
|
-
return hash(
|
46
|
-
(
|
47
|
-
hash(self.state),
|
48
|
-
HashedNumpyArray.hash_of_array(self.gravity),
|
49
|
-
hash(self.contacts_params),
|
50
|
-
)
|
51
|
-
)
|
52
|
-
|
53
|
-
def __eq__(self, other: JaxSimModelData) -> bool:
|
54
|
-
|
55
|
-
if not isinstance(other, JaxSimModelData):
|
56
|
-
return False
|
57
|
-
|
58
|
-
return hash(self) == hash(other)
|
59
|
-
|
60
|
-
def valid(self, model: js.model.JaxSimModel | None = None) -> bool:
|
61
|
-
"""
|
62
|
-
Check if the current state is valid for the given model.
|
63
|
-
|
64
|
-
Args:
|
65
|
-
model: The model to check against.
|
66
|
-
|
67
|
-
Returns:
|
68
|
-
`True` if the current state is valid for the given model, `False` otherwise.
|
69
|
-
"""
|
51
|
+
# Joint state
|
52
|
+
_joint_positions: jtp.Vector
|
53
|
+
_joint_velocities: jtp.Vector
|
70
54
|
|
71
|
-
|
72
|
-
|
55
|
+
# Base state
|
56
|
+
_base_quaternion: jtp.Vector
|
57
|
+
_base_linear_velocity: jtp.Vector
|
58
|
+
_base_angular_velocity: jtp.Vector
|
59
|
+
_base_position: jtp.Vector
|
73
60
|
|
74
|
-
|
75
|
-
|
76
|
-
|
77
|
-
|
78
|
-
|
79
|
-
@staticmethod
|
80
|
-
def zero(
|
81
|
-
model: js.model.JaxSimModel,
|
82
|
-
velocity_representation: VelRepr = VelRepr.Inertial,
|
83
|
-
) -> JaxSimModelData:
|
84
|
-
"""
|
85
|
-
Create a `JaxSimModelData` object with zero state.
|
86
|
-
|
87
|
-
Args:
|
88
|
-
model: The model for which to create the zero state.
|
89
|
-
velocity_representation: The velocity representation to use.
|
90
|
-
|
91
|
-
Returns:
|
92
|
-
A `JaxSimModelData` object with zero state.
|
93
|
-
"""
|
94
|
-
|
95
|
-
return JaxSimModelData.build(
|
96
|
-
model=model, velocity_representation=velocity_representation
|
97
|
-
)
|
61
|
+
# Cached computations.
|
62
|
+
_base_transform: jtp.Matrix = dataclasses.field(repr=False, default=None)
|
63
|
+
_joint_transforms: jtp.Matrix = dataclasses.field(repr=False, default=None)
|
64
|
+
_link_transforms: jtp.Matrix = dataclasses.field(repr=False, default=None)
|
65
|
+
_link_velocities: jtp.Matrix = dataclasses.field(repr=False, default=None)
|
98
66
|
|
99
67
|
@staticmethod
|
100
68
|
def build(
|
@@ -105,10 +73,7 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
|
|
105
73
|
base_linear_velocity: jtp.VectorLike | None = None,
|
106
74
|
base_angular_velocity: jtp.VectorLike | None = None,
|
107
75
|
joint_velocities: jtp.VectorLike | None = None,
|
108
|
-
|
109
|
-
contacts_params: jaxsim.rbda.contacts.ContactsParams | None = None,
|
110
|
-
velocity_representation: VelRepr = VelRepr.Inertial,
|
111
|
-
extended_ode_state: dict[str, jtp.PyTree] | None = None,
|
76
|
+
velocity_representation: VelRepr = VelRepr.Mixed,
|
112
77
|
) -> JaxSimModelData:
|
113
78
|
"""
|
114
79
|
Create a `JaxSimModelData` object with the given state.
|
@@ -123,13 +88,7 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
|
|
123
88
|
base_angular_velocity:
|
124
89
|
The base angular velocity in the selected representation.
|
125
90
|
joint_velocities: The joint velocities.
|
126
|
-
|
127
|
-
contacts_params: The parameters of the soft contacts.
|
128
|
-
velocity_representation: The velocity representation to use.
|
129
|
-
extended_ode_state:
|
130
|
-
Additional user-defined state variables that are not part of the
|
131
|
-
standard `ODEState` object. Useful to extend the system dynamics
|
132
|
-
considered by default in JaxSim.
|
91
|
+
velocity_representation: The velocity representation to use. It defaults to mixed if not provided.
|
133
92
|
|
134
93
|
Returns:
|
135
94
|
A `JaxSimModelData` initialized with the given state.
|
@@ -163,8 +122,6 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
|
|
163
122
|
dtype=float,
|
164
123
|
).squeeze()
|
165
124
|
|
166
|
-
gravity = jnp.zeros(3).at[2].set(-standard_gravity)
|
167
|
-
|
168
125
|
joint_positions = jnp.atleast_1d(
|
169
126
|
jnp.array(
|
170
127
|
(
|
@@ -191,166 +148,104 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
|
|
191
148
|
translation=base_position, quaternion=base_quaternion
|
192
149
|
)
|
193
150
|
|
194
|
-
|
151
|
+
W_v_WB = JaxSimModelData.other_representation_to_inertial(
|
195
152
|
array=jnp.hstack([base_linear_velocity, base_angular_velocity]),
|
196
153
|
other_representation=velocity_representation,
|
197
154
|
transform=W_H_B,
|
198
155
|
is_force=False,
|
199
156
|
).astype(float)
|
200
157
|
|
201
|
-
|
202
|
-
|
203
|
-
base_position=base_position,
|
204
|
-
base_quaternion=base_quaternion,
|
205
|
-
joint_positions=joint_positions,
|
206
|
-
base_linear_velocity=v_WB[0:3],
|
207
|
-
base_angular_velocity=v_WB[3:6],
|
208
|
-
joint_velocities=joint_velocities,
|
209
|
-
# Unpack all the additional ODE states. If the contact model requires an
|
210
|
-
# additional state that is not explicitly passed to this builder, ODEState
|
211
|
-
# automatically populates that state with zeroed variables.
|
212
|
-
# This is not true for any other custom state that the user might want to
|
213
|
-
# pass to the integrator.
|
214
|
-
**(extended_ode_state if extended_ode_state else {}),
|
158
|
+
joint_transforms = model.kin_dyn_parameters.joint_transforms(
|
159
|
+
joint_positions=joint_positions, base_transform=W_H_B
|
215
160
|
)
|
216
161
|
|
217
|
-
|
218
|
-
|
162
|
+
link_transforms, link_velocities_inertial = (
|
163
|
+
jaxsim.rbda.forward_kinematics_model(
|
164
|
+
model=model,
|
165
|
+
base_position=base_position,
|
166
|
+
base_quaternion=base_quaternion,
|
167
|
+
joint_positions=joint_positions,
|
168
|
+
base_linear_velocity_inertial=W_v_WB[0:3],
|
169
|
+
base_angular_velocity_inertial=W_v_WB[3:6],
|
170
|
+
joint_velocities=joint_velocities,
|
171
|
+
)
|
172
|
+
)
|
219
173
|
|
220
|
-
|
174
|
+
model_data = JaxSimModelData(
|
175
|
+
velocity_representation=velocity_representation,
|
176
|
+
_base_quaternion=base_quaternion,
|
177
|
+
_base_position=base_position,
|
178
|
+
_joint_positions=joint_positions,
|
179
|
+
_base_linear_velocity=W_v_WB[0:3],
|
180
|
+
_base_angular_velocity=W_v_WB[3:6],
|
181
|
+
_joint_velocities=joint_velocities,
|
182
|
+
_base_transform=W_H_B,
|
183
|
+
_joint_transforms=joint_transforms,
|
184
|
+
_link_transforms=link_transforms,
|
185
|
+
_link_velocities=link_velocities_inertial,
|
186
|
+
)
|
221
187
|
|
222
|
-
|
223
|
-
|
224
|
-
|
225
|
-
|
226
|
-
):
|
188
|
+
if not model_data.valid(model=model):
|
189
|
+
raise ValueError(
|
190
|
+
"The built state is not compatible with the model.", model_data
|
191
|
+
)
|
227
192
|
|
228
|
-
|
229
|
-
model=model, standard_gravity=standard_gravity
|
230
|
-
)
|
193
|
+
return model_data
|
231
194
|
|
232
|
-
|
233
|
-
|
195
|
+
@staticmethod
|
196
|
+
def zero(
|
197
|
+
model: js.model.JaxSimModel,
|
198
|
+
velocity_representation: VelRepr = VelRepr.Mixed,
|
199
|
+
) -> JaxSimModelData:
|
200
|
+
"""
|
201
|
+
Create a `JaxSimModelData` object with zero state.
|
234
202
|
|
235
|
-
|
236
|
-
state
|
237
|
-
|
238
|
-
|
239
|
-
|
203
|
+
Args:
|
204
|
+
model: The model for which to create the state.
|
205
|
+
velocity_representation: The velocity representation to use. It defaults to mixed if not provided.
|
206
|
+
|
207
|
+
Returns:
|
208
|
+
A `JaxSimModelData` initialized with zero state.
|
209
|
+
"""
|
210
|
+
return JaxSimModelData.build(
|
211
|
+
model=model, velocity_representation=velocity_representation
|
240
212
|
)
|
241
213
|
|
242
214
|
# ==================
|
243
215
|
# Extract quantities
|
244
216
|
# ==================
|
245
217
|
|
246
|
-
|
218
|
+
@property
|
219
|
+
def joint_positions(self) -> jtp.Vector:
|
247
220
|
"""
|
248
|
-
Get the
|
221
|
+
Get the joint positions.
|
249
222
|
|
250
223
|
Returns:
|
251
|
-
The
|
224
|
+
The joint positions.
|
252
225
|
"""
|
226
|
+
return self._joint_positions
|
253
227
|
|
254
|
-
|
255
|
-
|
256
|
-
@js.common.named_scope
|
257
|
-
@functools.partial(jax.jit, static_argnames=["joint_names"])
|
258
|
-
def joint_positions(
|
259
|
-
self,
|
260
|
-
model: js.model.JaxSimModel | None = None,
|
261
|
-
joint_names: tuple[str, ...] | None = None,
|
262
|
-
) -> jtp.Vector:
|
228
|
+
@property
|
229
|
+
def joint_velocities(self) -> jtp.Vector:
|
263
230
|
"""
|
264
|
-
Get the joint
|
265
|
-
|
266
|
-
Args:
|
267
|
-
model: The model to consider.
|
268
|
-
joint_names:
|
269
|
-
The names of the joints for which to get the positions. If `None`,
|
270
|
-
the positions of all joints are returned.
|
231
|
+
Get the joint velocities.
|
271
232
|
|
272
233
|
Returns:
|
273
|
-
|
274
|
-
`(DoFs,)` vector corresponding to the serialization of the original
|
275
|
-
model used to build the data object.
|
276
|
-
If a model is provided and no joint names are provided, the joint positions
|
277
|
-
as a `(DoFs,)` vector corresponding to the serialization of the
|
278
|
-
provided model.
|
279
|
-
If a model and joint names are provided, the joint positions as a
|
280
|
-
`(len(joint_names),)` vector corresponding to the serialization of
|
281
|
-
the passed joint names vector.
|
234
|
+
The joint velocities.
|
282
235
|
"""
|
236
|
+
return self._joint_velocities
|
283
237
|
|
284
|
-
|
285
|
-
|
286
|
-
raise ValueError("Joint names cannot be provided without a model")
|
287
|
-
|
288
|
-
return self.state.physics_model.joint_positions
|
289
|
-
|
290
|
-
if not_tracing(self.state.physics_model.joint_positions) and not self.valid(
|
291
|
-
model=model
|
292
|
-
):
|
293
|
-
msg = "The data object is not compatible with the provided model"
|
294
|
-
raise ValueError(msg)
|
295
|
-
|
296
|
-
joint_idxs = (
|
297
|
-
js.joint.names_to_idxs(joint_names=joint_names, model=model)
|
298
|
-
if joint_names is not None
|
299
|
-
else jnp.arange(model.number_of_joints())
|
300
|
-
)
|
301
|
-
|
302
|
-
return self.state.physics_model.joint_positions[joint_idxs]
|
303
|
-
|
304
|
-
@js.common.named_scope
|
305
|
-
@functools.partial(jax.jit, static_argnames=["joint_names"])
|
306
|
-
def joint_velocities(
|
307
|
-
self,
|
308
|
-
model: js.model.JaxSimModel | None = None,
|
309
|
-
joint_names: tuple[str, ...] | None = None,
|
310
|
-
) -> jtp.Vector:
|
238
|
+
@property
|
239
|
+
def base_quaternion(self) -> jtp.Vector:
|
311
240
|
"""
|
312
|
-
Get the
|
313
|
-
|
314
|
-
Args:
|
315
|
-
model: The model to consider.
|
316
|
-
joint_names:
|
317
|
-
The names of the joints for which to get the velocities. If `None`,
|
318
|
-
the velocities of all joints are returned.
|
241
|
+
Get the base quaternion.
|
319
242
|
|
320
243
|
Returns:
|
321
|
-
|
322
|
-
`(DoFs,)` vector corresponding to the serialization of the original
|
323
|
-
model used to build the data object.
|
324
|
-
If a model is provided and no joint names are provided, the joint velocities
|
325
|
-
as a `(DoFs,)` vector corresponding to the serialization of the
|
326
|
-
provided model.
|
327
|
-
If a model and joint names are provided, the joint velocities as a
|
328
|
-
`(len(joint_names),)` vector corresponding to the serialization of
|
329
|
-
the passed joint names vector.
|
244
|
+
The base quaternion.
|
330
245
|
"""
|
246
|
+
return self._base_quaternion
|
331
247
|
|
332
|
-
|
333
|
-
if joint_names is not None:
|
334
|
-
raise ValueError("Joint names cannot be provided without a model")
|
335
|
-
|
336
|
-
return self.state.physics_model.joint_velocities
|
337
|
-
|
338
|
-
if not_tracing(self.state.physics_model.joint_velocities) and not self.valid(
|
339
|
-
model=model
|
340
|
-
):
|
341
|
-
msg = "The data object is not compatible with the provided model"
|
342
|
-
raise ValueError(msg)
|
343
|
-
|
344
|
-
joint_idxs = (
|
345
|
-
js.joint.names_to_idxs(joint_names=joint_names, model=model)
|
346
|
-
if joint_names is not None
|
347
|
-
else jnp.arange(model.number_of_joints())
|
348
|
-
)
|
349
|
-
|
350
|
-
return self.state.physics_model.joint_velocities[joint_idxs]
|
351
|
-
|
352
|
-
@js.common.named_scope
|
353
|
-
@jax.jit
|
248
|
+
@property
|
354
249
|
def base_position(self) -> jtp.Vector:
|
355
250
|
"""
|
356
251
|
Get the base position.
|
@@ -358,24 +253,19 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
|
|
358
253
|
Returns:
|
359
254
|
The base position.
|
360
255
|
"""
|
256
|
+
return self._base_position
|
361
257
|
|
362
|
-
|
363
|
-
|
364
|
-
@js.common.named_scope
|
365
|
-
@functools.partial(jax.jit, static_argnames=["dcm"])
|
366
|
-
def base_orientation(self, dcm: jtp.BoolLike = False) -> jtp.Vector | jtp.Matrix:
|
258
|
+
@property
|
259
|
+
def base_orientation(self) -> jtp.Matrix:
|
367
260
|
"""
|
368
261
|
Get the base orientation.
|
369
262
|
|
370
|
-
Args:
|
371
|
-
dcm: Whether to return the orientation as a SO(3) matrix or quaternion.
|
372
|
-
|
373
263
|
Returns:
|
374
264
|
The base orientation.
|
375
265
|
"""
|
376
266
|
|
377
267
|
# Extract the base quaternion.
|
378
|
-
W_Q_B = self.
|
268
|
+
W_Q_B = self.base_quaternion.squeeze()
|
379
269
|
|
380
270
|
# Always normalize the quaternion to avoid numerical issues.
|
381
271
|
# If the active scheme does not integrate the quaternion on its manifold,
|
@@ -384,33 +274,9 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
|
|
384
274
|
# stored in the state is a unit quaternion.
|
385
275
|
norm = jaxsim.math.safe_norm(W_Q_B)
|
386
276
|
W_Q_B = W_Q_B / (norm + jnp.finfo(float).eps * (norm == 0))
|
277
|
+
return W_Q_B
|
387
278
|
|
388
|
-
|
389
|
-
float
|
390
|
-
)
|
391
|
-
|
392
|
-
@js.common.named_scope
|
393
|
-
@jax.jit
|
394
|
-
def base_transform(self) -> jtp.Matrix:
|
395
|
-
"""
|
396
|
-
Get the base transform.
|
397
|
-
|
398
|
-
Returns:
|
399
|
-
The base transform as an SE(3) matrix.
|
400
|
-
"""
|
401
|
-
|
402
|
-
W_R_B = self.base_orientation(dcm=True)
|
403
|
-
W_p_B = jnp.vstack(self.base_position())
|
404
|
-
|
405
|
-
return jnp.vstack(
|
406
|
-
[
|
407
|
-
jnp.block([W_R_B, W_p_B]),
|
408
|
-
jnp.array([0, 0, 0, 1]),
|
409
|
-
]
|
410
|
-
)
|
411
|
-
|
412
|
-
@js.common.named_scope
|
413
|
-
@jax.jit
|
279
|
+
@property
|
414
280
|
def base_velocity(self) -> jtp.Vector:
|
415
281
|
"""
|
416
282
|
Get the base 6D velocity.
|
@@ -421,12 +287,12 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
|
|
421
287
|
|
422
288
|
W_v_WB = jnp.hstack(
|
423
289
|
[
|
424
|
-
self.
|
425
|
-
self.
|
290
|
+
self._base_linear_velocity,
|
291
|
+
self._base_angular_velocity,
|
426
292
|
]
|
427
293
|
)
|
428
294
|
|
429
|
-
W_H_B = self.
|
295
|
+
W_H_B = self._base_transform
|
430
296
|
|
431
297
|
return (
|
432
298
|
JaxSimModelData.inertial_to_other_representation(
|
@@ -439,8 +305,7 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
|
|
439
305
|
.astype(float)
|
440
306
|
)
|
441
307
|
|
442
|
-
@
|
443
|
-
@jax.jit
|
308
|
+
@property
|
444
309
|
def generalized_position(self) -> tuple[jtp.Matrix, jtp.Vector]:
|
445
310
|
r"""
|
446
311
|
Get the generalized position
|
@@ -450,10 +315,9 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
|
|
450
315
|
A tuple containing the base transform and the joint positions.
|
451
316
|
"""
|
452
317
|
|
453
|
-
return self.
|
318
|
+
return self._base_transform, self.joint_positions
|
454
319
|
|
455
|
-
@
|
456
|
-
@jax.jit
|
320
|
+
@property
|
457
321
|
def generalized_velocity(self) -> jtp.Vector:
|
458
322
|
r"""
|
459
323
|
Get the generalized velocity.
|
@@ -465,136 +329,24 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
|
|
465
329
|
"""
|
466
330
|
|
467
331
|
return (
|
468
|
-
jnp.hstack([self.base_velocity
|
332
|
+
jnp.hstack([self.base_velocity, self.joint_velocities])
|
469
333
|
.squeeze()
|
470
334
|
.astype(float)
|
471
335
|
)
|
472
336
|
|
473
|
-
|
474
|
-
|
475
|
-
# ================
|
476
|
-
|
477
|
-
@js.common.named_scope
|
478
|
-
@functools.partial(jax.jit, static_argnames=["joint_names"])
|
479
|
-
def reset_joint_positions(
|
480
|
-
self,
|
481
|
-
positions: jtp.VectorLike,
|
482
|
-
model: js.model.JaxSimModel | None = None,
|
483
|
-
joint_names: tuple[str, ...] | None = None,
|
484
|
-
) -> Self:
|
485
|
-
"""
|
486
|
-
Reset the joint positions.
|
487
|
-
|
488
|
-
Args:
|
489
|
-
positions: The joint positions.
|
490
|
-
model: The model to consider.
|
491
|
-
joint_names: The names of the joints for which to set the positions.
|
492
|
-
|
493
|
-
Returns:
|
494
|
-
The updated `JaxSimModelData` object.
|
495
|
-
"""
|
496
|
-
|
497
|
-
positions = jnp.array(positions)
|
498
|
-
|
499
|
-
def replace(s: jtp.VectorLike) -> JaxSimModelData:
|
500
|
-
return self.replace(
|
501
|
-
validate=True,
|
502
|
-
state=self.state.replace(
|
503
|
-
physics_model=self.state.physics_model.replace(
|
504
|
-
joint_positions=jnp.atleast_1d(s.squeeze()).astype(float)
|
505
|
-
)
|
506
|
-
),
|
507
|
-
)
|
508
|
-
|
509
|
-
if model is None:
|
510
|
-
return replace(s=positions)
|
511
|
-
|
512
|
-
if not_tracing(positions) and not self.valid(model=model):
|
513
|
-
msg = "The data object is not compatible with the provided model"
|
514
|
-
raise ValueError(msg)
|
515
|
-
|
516
|
-
joint_idxs = (
|
517
|
-
js.joint.names_to_idxs(joint_names=joint_names, model=model)
|
518
|
-
if joint_names is not None
|
519
|
-
else jnp.arange(model.number_of_joints())
|
520
|
-
)
|
521
|
-
|
522
|
-
return replace(
|
523
|
-
s=self.state.physics_model.joint_positions.at[joint_idxs].set(positions)
|
524
|
-
)
|
525
|
-
|
526
|
-
@js.common.named_scope
|
527
|
-
@functools.partial(jax.jit, static_argnames=["joint_names"])
|
528
|
-
def reset_joint_velocities(
|
529
|
-
self,
|
530
|
-
velocities: jtp.VectorLike,
|
531
|
-
model: js.model.JaxSimModel | None = None,
|
532
|
-
joint_names: tuple[str, ...] | None = None,
|
533
|
-
) -> Self:
|
534
|
-
"""
|
535
|
-
Reset the joint velocities.
|
536
|
-
|
537
|
-
Args:
|
538
|
-
velocities: The joint velocities.
|
539
|
-
model: The model to consider.
|
540
|
-
joint_names: The names of the joints for which to set the velocities.
|
541
|
-
|
542
|
-
Returns:
|
543
|
-
The updated `JaxSimModelData` object.
|
544
|
-
"""
|
545
|
-
|
546
|
-
velocities = jnp.array(velocities)
|
547
|
-
|
548
|
-
def replace(ṡ: jtp.VectorLike) -> JaxSimModelData:
|
549
|
-
return self.replace(
|
550
|
-
validate=True,
|
551
|
-
state=self.state.replace(
|
552
|
-
physics_model=self.state.physics_model.replace(
|
553
|
-
joint_velocities=jnp.atleast_1d(ṡ.squeeze()).astype(float)
|
554
|
-
)
|
555
|
-
),
|
556
|
-
)
|
557
|
-
|
558
|
-
if model is None:
|
559
|
-
return replace(ṡ=velocities)
|
560
|
-
|
561
|
-
if not_tracing(velocities) and not self.valid(model=model):
|
562
|
-
msg = "The data object is not compatible with the provided model"
|
563
|
-
raise ValueError(msg)
|
564
|
-
|
565
|
-
joint_idxs = (
|
566
|
-
js.joint.names_to_idxs(joint_names=joint_names, model=model)
|
567
|
-
if joint_names is not None
|
568
|
-
else jnp.arange(model.number_of_joints())
|
569
|
-
)
|
570
|
-
|
571
|
-
return replace(
|
572
|
-
ṡ=self.state.physics_model.joint_velocities.at[joint_idxs].set(velocities)
|
573
|
-
)
|
574
|
-
|
575
|
-
@js.common.named_scope
|
576
|
-
@jax.jit
|
577
|
-
def reset_base_position(self, base_position: jtp.VectorLike) -> Self:
|
337
|
+
@property
|
338
|
+
def base_transform(self) -> jtp.Matrix:
|
578
339
|
"""
|
579
|
-
|
580
|
-
|
581
|
-
Args:
|
582
|
-
base_position: The base position.
|
340
|
+
Get the base transform.
|
583
341
|
|
584
342
|
Returns:
|
585
|
-
The
|
343
|
+
The base transform.
|
586
344
|
"""
|
345
|
+
return self._base_transform
|
587
346
|
|
588
|
-
|
589
|
-
|
590
|
-
|
591
|
-
validate=True,
|
592
|
-
state=self.state.replace(
|
593
|
-
physics_model=self.state.physics_model.replace(
|
594
|
-
base_position=jnp.atleast_1d(base_position.squeeze()).astype(float)
|
595
|
-
)
|
596
|
-
),
|
597
|
-
)
|
347
|
+
# ================
|
348
|
+
# Store quantities
|
349
|
+
# ================
|
598
350
|
|
599
351
|
@js.common.named_scope
|
600
352
|
@jax.jit
|
@@ -614,12 +366,7 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
|
|
614
366
|
norm = jaxsim.math.safe_norm(W_Q_B)
|
615
367
|
W_Q_B = W_Q_B / (norm + jnp.finfo(float).eps * (norm == 0))
|
616
368
|
|
617
|
-
return self.replace(
|
618
|
-
validate=True,
|
619
|
-
state=self.state.replace(
|
620
|
-
physics_model=self.state.physics_model.replace(base_quaternion=W_Q_B)
|
621
|
-
),
|
622
|
-
)
|
369
|
+
return self.replace(validate=True, base_quaternion=W_Q_B)
|
623
370
|
|
624
371
|
@js.common.named_scope
|
625
372
|
@jax.jit
|
@@ -635,123 +382,116 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
|
|
635
382
|
"""
|
636
383
|
|
637
384
|
base_pose = jnp.array(base_pose)
|
638
|
-
|
639
385
|
W_p_B = base_pose[0:3, 3]
|
640
|
-
|
641
386
|
W_Q_B = jaxsim.math.Quaternion.from_dcm(dcm=base_pose[0:3, 0:3])
|
642
|
-
|
643
|
-
|
644
|
-
base_quaternion=W_Q_B
|
387
|
+
return self.replace(
|
388
|
+
base_position=W_p_B,
|
389
|
+
base_quaternion=W_Q_B,
|
645
390
|
)
|
646
391
|
|
647
|
-
@
|
648
|
-
|
649
|
-
def reset_base_linear_velocity(
|
392
|
+
@override
|
393
|
+
def replace(
|
650
394
|
self,
|
651
|
-
|
652
|
-
|
395
|
+
model: js.model.JaxSimModel,
|
396
|
+
joint_positions: jtp.Vector | None = None,
|
397
|
+
joint_velocities: jtp.Vector | None = None,
|
398
|
+
base_quaternion: jtp.Vector | None = None,
|
399
|
+
base_linear_velocity: jtp.Vector | None = None,
|
400
|
+
base_angular_velocity: jtp.Vector | None = None,
|
401
|
+
base_position: jtp.Vector | None = None,
|
402
|
+
validate: bool = False,
|
653
403
|
) -> Self:
|
654
404
|
"""
|
655
|
-
|
656
|
-
|
657
|
-
Args:
|
658
|
-
linear_velocity: The base linear velocity as a 3D array.
|
659
|
-
velocity_representation:
|
660
|
-
The velocity representation in which the base velocity is expressed.
|
661
|
-
If `None`, the active representation is considered.
|
662
|
-
|
663
|
-
Returns:
|
664
|
-
The updated `JaxSimModelData` object.
|
405
|
+
Replace the attributes of the `JaxSimModelData` object.
|
665
406
|
"""
|
407
|
+
if joint_positions is None:
|
408
|
+
joint_positions = self.joint_positions
|
409
|
+
if joint_velocities is None:
|
410
|
+
joint_velocities = self.joint_velocities
|
411
|
+
if base_quaternion is None:
|
412
|
+
base_quaternion = self.base_quaternion
|
413
|
+
if base_position is None:
|
414
|
+
base_position = self.base_position
|
666
415
|
|
667
|
-
|
416
|
+
joint_positions = jnp.atleast_1d(joint_positions.squeeze()).astype(float)
|
417
|
+
joint_velocities = jnp.atleast_1d(joint_velocities.squeeze()).astype(float)
|
418
|
+
base_quaternion = jnp.atleast_1d(base_quaternion.squeeze()).astype(float)
|
419
|
+
base_position = jnp.atleast_1d(base_position.squeeze()).astype(float)
|
668
420
|
|
669
|
-
|
670
|
-
|
671
|
-
|
672
|
-
|
673
|
-
|
674
|
-
]
|
675
|
-
),
|
676
|
-
velocity_representation=velocity_representation,
|
421
|
+
base_transform = jaxsim.math.Transform.from_quaternion_and_translation(
|
422
|
+
translation=base_position, quaternion=base_quaternion
|
423
|
+
)
|
424
|
+
joint_transforms = model.kin_dyn_parameters.joint_transforms(
|
425
|
+
joint_positions=joint_positions, base_transform=base_transform
|
677
426
|
)
|
678
427
|
|
679
|
-
|
680
|
-
|
681
|
-
|
682
|
-
|
683
|
-
|
684
|
-
|
685
|
-
|
686
|
-
|
687
|
-
|
688
|
-
|
689
|
-
|
690
|
-
|
691
|
-
|
692
|
-
|
693
|
-
|
694
|
-
|
695
|
-
|
696
|
-
The updated `JaxSimModelData` object.
|
697
|
-
"""
|
428
|
+
if base_linear_velocity is None and base_angular_velocity is None:
|
429
|
+
base_linear_velocity = self._base_linear_velocity
|
430
|
+
base_angular_velocity = self._base_angular_velocity
|
431
|
+
else:
|
432
|
+
if base_linear_velocity is None:
|
433
|
+
base_linear_velocity = self.base_velocity[:3]
|
434
|
+
if base_angular_velocity is None:
|
435
|
+
base_angular_velocity = self.base_velocity[3:]
|
436
|
+
base_linear_velocity = jnp.atleast_1d(base_linear_velocity.squeeze())
|
437
|
+
base_angular_velocity = jnp.atleast_1d(base_angular_velocity.squeeze())
|
438
|
+
W_v_WB = JaxSimModelData.other_representation_to_inertial(
|
439
|
+
array=jnp.hstack([base_linear_velocity, base_angular_velocity]),
|
440
|
+
other_representation=self.velocity_representation,
|
441
|
+
transform=base_transform,
|
442
|
+
is_force=False,
|
443
|
+
).astype(float)
|
444
|
+
base_linear_velocity, base_angular_velocity = W_v_WB[:3], W_v_WB[3:]
|
698
445
|
|
699
|
-
|
446
|
+
link_transforms, link_velocities = jaxsim.rbda.forward_kinematics_model(
|
447
|
+
model=model,
|
448
|
+
base_position=base_position,
|
449
|
+
base_quaternion=base_quaternion,
|
450
|
+
joint_positions=joint_positions,
|
451
|
+
joint_velocities=joint_velocities,
|
452
|
+
base_linear_velocity_inertial=base_linear_velocity,
|
453
|
+
base_angular_velocity_inertial=base_angular_velocity,
|
454
|
+
)
|
700
455
|
|
701
|
-
return
|
702
|
-
|
703
|
-
|
704
|
-
|
705
|
-
|
706
|
-
|
707
|
-
|
708
|
-
|
456
|
+
return super().replace(
|
457
|
+
_joint_positions=joint_positions,
|
458
|
+
_joint_velocities=joint_velocities,
|
459
|
+
_base_quaternion=base_quaternion,
|
460
|
+
_base_linear_velocity=base_linear_velocity,
|
461
|
+
_base_angular_velocity=base_angular_velocity,
|
462
|
+
_base_position=base_position,
|
463
|
+
_base_transform=base_transform,
|
464
|
+
_joint_transforms=joint_transforms,
|
465
|
+
_link_transforms=link_transforms,
|
466
|
+
_link_velocities=link_velocities,
|
467
|
+
validate=validate,
|
709
468
|
)
|
710
469
|
|
711
|
-
|
712
|
-
@functools.partial(jax.jit, static_argnames=["velocity_representation"])
|
713
|
-
def reset_base_velocity(
|
714
|
-
self,
|
715
|
-
base_velocity: jtp.VectorLike,
|
716
|
-
velocity_representation: VelRepr | None = None,
|
717
|
-
) -> Self:
|
470
|
+
def valid(self, model: js.model.JaxSimModel) -> bool:
|
718
471
|
"""
|
719
|
-
|
472
|
+
Check if the `JaxSimModelData` is valid for a given `JaxSimModel`.
|
720
473
|
|
721
474
|
Args:
|
722
|
-
|
723
|
-
velocity_representation:
|
724
|
-
The velocity representation in which the base velocity is expressed.
|
725
|
-
If `None`, the active representation is considered.
|
475
|
+
model: The `JaxSimModel` to validate the `JaxSimModelData` against.
|
726
476
|
|
727
477
|
Returns:
|
728
|
-
|
478
|
+
`True` if the `JaxSimModelData` is valid for the given model,
|
479
|
+
`False` otherwise.
|
729
480
|
"""
|
481
|
+
if self._joint_positions.shape != (model.dofs(),):
|
482
|
+
return False
|
483
|
+
if self._joint_velocities.shape != (model.dofs(),):
|
484
|
+
return False
|
485
|
+
if self._base_position.shape != (3,):
|
486
|
+
return False
|
487
|
+
if self._base_quaternion.shape != (4,):
|
488
|
+
return False
|
489
|
+
if self._base_linear_velocity.shape != (3,):
|
490
|
+
return False
|
491
|
+
if self._base_angular_velocity.shape != (3,):
|
492
|
+
return False
|
730
493
|
|
731
|
-
|
732
|
-
|
733
|
-
velocity_representation = (
|
734
|
-
velocity_representation
|
735
|
-
if velocity_representation is not None
|
736
|
-
else self.velocity_representation
|
737
|
-
)
|
738
|
-
|
739
|
-
W_v_WB = self.other_representation_to_inertial(
|
740
|
-
array=jnp.atleast_1d(base_velocity.squeeze()).astype(float),
|
741
|
-
other_representation=velocity_representation,
|
742
|
-
transform=self.base_transform(),
|
743
|
-
is_force=False,
|
744
|
-
)
|
745
|
-
|
746
|
-
return self.replace(
|
747
|
-
validate=True,
|
748
|
-
state=self.state.replace(
|
749
|
-
physics_model=self.state.physics_model.replace(
|
750
|
-
base_linear_velocity=W_v_WB[0:3].squeeze().astype(float),
|
751
|
-
base_angular_velocity=W_v_WB[3:6].squeeze().astype(float),
|
752
|
-
)
|
753
|
-
),
|
754
|
-
)
|
494
|
+
return True
|
755
495
|
|
756
496
|
|
757
497
|
@functools.partial(jax.jit, static_argnames=["velocity_representation", "base_rpy_seq"])
|
@@ -788,11 +528,6 @@ def random_model_data(
|
|
788
528
|
jtp.FloatLike | Sequence[jtp.FloatLike],
|
789
529
|
jtp.FloatLike | Sequence[jtp.FloatLike],
|
790
530
|
] = (-1.0, 1.0),
|
791
|
-
contacts_params: jaxsim.rbda.contacts.ContactsParams | None = None,
|
792
|
-
standard_gravity_bounds: tuple[jtp.FloatLike, jtp.FloatLike] = (
|
793
|
-
jaxsim.math.StandardGravity,
|
794
|
-
jaxsim.math.StandardGravity,
|
795
|
-
),
|
796
531
|
) -> JaxSimModelData:
|
797
532
|
"""
|
798
533
|
Randomly generate a `JaxSimModelData` object.
|
@@ -811,15 +546,13 @@ def random_model_data(
|
|
811
546
|
base_vel_lin_bounds: The bounds for the base linear velocity.
|
812
547
|
base_vel_ang_bounds: The bounds for the base angular velocity.
|
813
548
|
joint_vel_bounds: The bounds for the joint velocities.
|
814
|
-
contacts_params: The parameters of the contact model.
|
815
|
-
standard_gravity_bounds: The bounds for the standard gravity.
|
816
549
|
|
817
550
|
Returns:
|
818
551
|
A `JaxSimModelData` object with random data.
|
819
552
|
"""
|
820
553
|
|
821
554
|
key = key if key is not None else jax.random.PRNGKey(seed=0)
|
822
|
-
k1, k2, k3, k4, k5, k6
|
555
|
+
k1, k2, k3, k4, k5, k6 = jax.random.split(key, num=6)
|
823
556
|
|
824
557
|
p_min = jnp.array(base_pos_bounds[0], dtype=float)
|
825
558
|
p_max = jnp.array(base_pos_bounds[1], dtype=float)
|
@@ -831,95 +564,64 @@ def random_model_data(
|
|
831
564
|
ω_max = jnp.array(base_vel_ang_bounds[1], dtype=float)
|
832
565
|
ṡ_min, ṡ_max = joint_vel_bounds
|
833
566
|
|
834
|
-
|
835
|
-
model=model,
|
836
|
-
**(
|
837
|
-
dict(velocity_representation=velocity_representation)
|
838
|
-
if velocity_representation is not None
|
839
|
-
else {}
|
840
|
-
),
|
841
|
-
)
|
567
|
+
base_position = jax.random.uniform(key=k1, shape=(3,), minval=p_min, maxval=p_max)
|
842
568
|
|
843
|
-
|
844
|
-
|
845
|
-
|
569
|
+
base_quaternion = jaxsim.math.Quaternion.to_wxyz(
|
570
|
+
xyzw=jax.scipy.spatial.transform.Rotation.from_euler(
|
571
|
+
seq=base_rpy_seq,
|
572
|
+
angles=jax.random.uniform(
|
573
|
+
key=k2, shape=(3,), minval=rpy_min, maxval=rpy_max
|
574
|
+
),
|
575
|
+
).as_quat()
|
576
|
+
)
|
846
577
|
|
847
|
-
|
578
|
+
(
|
579
|
+
joint_positions,
|
580
|
+
joint_velocities,
|
581
|
+
base_linear_velocity,
|
582
|
+
base_angular_velocity,
|
583
|
+
) = (None,) * 4
|
848
584
|
|
849
|
-
|
850
|
-
key=k1, shape=(3,), minval=p_min, maxval=p_max
|
851
|
-
)
|
585
|
+
if model.number_of_joints() > 0:
|
852
586
|
|
853
|
-
|
854
|
-
|
855
|
-
|
856
|
-
|
857
|
-
key=k2, shape=(3,), minval=rpy_min, maxval=rpy_max
|
858
|
-
),
|
859
|
-
).as_quat()
|
587
|
+
s_min, s_max = (
|
588
|
+
jnp.array(joint_pos_bounds, dtype=float)
|
589
|
+
if joint_pos_bounds is not None
|
590
|
+
else (None, None)
|
860
591
|
)
|
861
592
|
|
862
|
-
|
863
|
-
|
864
|
-
s_min
|
865
|
-
|
866
|
-
|
867
|
-
else (None, None)
|
868
|
-
)
|
869
|
-
|
870
|
-
physics_model_state.joint_positions = (
|
871
|
-
js.joint.random_joint_positions(model=model, key=k3)
|
872
|
-
if (s_min is None or s_max is None)
|
873
|
-
else jax.random.uniform(
|
874
|
-
key=k3, shape=(model.dofs(),), minval=s_min, maxval=s_max
|
875
|
-
)
|
876
|
-
)
|
877
|
-
|
878
|
-
physics_model_state.joint_velocities = jax.random.uniform(
|
879
|
-
key=k4, shape=(model.dofs(),), minval=ṡ_min, maxval=ṡ_max
|
880
|
-
)
|
881
|
-
|
882
|
-
if model.floating_base():
|
883
|
-
physics_model_state.base_linear_velocity = jax.random.uniform(
|
884
|
-
key=k5, shape=(3,), minval=v_min, maxval=v_max
|
885
|
-
)
|
886
|
-
|
887
|
-
physics_model_state.base_angular_velocity = jax.random.uniform(
|
888
|
-
key=k6, shape=(3,), minval=ω_min, maxval=ω_max
|
889
|
-
)
|
890
|
-
|
891
|
-
random_data.gravity = (
|
892
|
-
jnp.zeros(3, dtype=random_data.gravity.dtype)
|
893
|
-
.at[2]
|
894
|
-
.set(
|
895
|
-
-jax.random.uniform(
|
896
|
-
key=k7,
|
897
|
-
shape=(),
|
898
|
-
minval=standard_gravity_bounds[0],
|
899
|
-
maxval=standard_gravity_bounds[1],
|
900
|
-
)
|
593
|
+
joint_positions = (
|
594
|
+
js.joint.random_joint_positions(model=model, key=k3)
|
595
|
+
if (s_min is None or s_max is None)
|
596
|
+
else jax.random.uniform(
|
597
|
+
key=k3, shape=(model.dofs(),), minval=s_min, maxval=s_max
|
901
598
|
)
|
902
599
|
)
|
903
600
|
|
904
|
-
|
905
|
-
|
906
|
-
|
907
|
-
model.contact_model,
|
908
|
-
jaxsim.rbda.contacts.SoftContacts
|
909
|
-
| jaxsim.rbda.contacts.ViscoElasticContacts,
|
910
|
-
):
|
601
|
+
joint_velocities = jax.random.uniform(
|
602
|
+
key=k4, shape=(model.dofs(),), minval=ṡ_min, maxval=ṡ_max
|
603
|
+
)
|
911
604
|
|
912
|
-
|
913
|
-
|
914
|
-
|
915
|
-
|
916
|
-
validate=False,
|
917
|
-
)
|
605
|
+
if model.floating_base():
|
606
|
+
base_linear_velocity = jax.random.uniform(
|
607
|
+
key=k5, shape=(3,), minval=v_min, maxval=v_max
|
608
|
+
)
|
918
609
|
|
919
|
-
|
920
|
-
|
921
|
-
|
922
|
-
validate=False,
|
923
|
-
)
|
610
|
+
base_angular_velocity = jax.random.uniform(
|
611
|
+
key=k6, shape=(3,), minval=ω_min, maxval=ω_max
|
612
|
+
)
|
924
613
|
|
925
|
-
return
|
614
|
+
return JaxSimModelData.build(
|
615
|
+
model=model,
|
616
|
+
base_position=base_position,
|
617
|
+
base_quaternion=base_quaternion,
|
618
|
+
joint_positions=joint_positions,
|
619
|
+
joint_velocities=joint_velocities,
|
620
|
+
base_linear_velocity=base_linear_velocity,
|
621
|
+
base_angular_velocity=base_angular_velocity,
|
622
|
+
**(
|
623
|
+
{"velocity_representation": velocity_representation}
|
624
|
+
if velocity_representation is not None
|
625
|
+
else {}
|
626
|
+
),
|
627
|
+
)
|