jaxsim 0.2.dev108__py3-none-any.whl → 0.2.dev166__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
jaxsim/api/data.py ADDED
@@ -0,0 +1,951 @@
1
+ from __future__ import annotations
2
+
3
+ import contextlib
4
+ import dataclasses
5
+ import functools
6
+ from typing import ContextManager, Sequence
7
+
8
+ import jax
9
+ import jax.numpy as jnp
10
+ import jax_dataclasses
11
+ import jaxlie
12
+ import numpy as np
13
+ from jax_dataclasses import Static
14
+
15
+ import jaxsim.api
16
+ import jaxsim.physics.algos.aba
17
+ import jaxsim.physics.algos.crba
18
+ import jaxsim.physics.algos.forward_kinematics
19
+ import jaxsim.physics.algos.rnea
20
+ import jaxsim.physics.model.physics_model
21
+ import jaxsim.physics.model.physics_model_state
22
+ import jaxsim.typing as jtp
23
+ from jaxsim.high_level.common import VelRepr
24
+ from jaxsim.physics.algos import soft_contacts
25
+ from jaxsim.simulation.ode_data import ODEState
26
+ from jaxsim.utils import JaxsimDataclass, Mutability
27
+
28
+ try:
29
+ from typing import Self
30
+ except ImportError:
31
+ from typing_extensions import Self
32
+
33
+
34
+ @jax_dataclasses.pytree_dataclass
35
+ class JaxSimModelData(JaxsimDataclass):
36
+ """
37
+ Class containing the state of a `JaxSimModel` object.
38
+ """
39
+
40
+ state: ODEState
41
+
42
+ gravity: jtp.Array
43
+
44
+ soft_contacts_params: soft_contacts.SoftContactsParams = dataclasses.field(
45
+ repr=False
46
+ )
47
+ time_ns: jtp.Int = dataclasses.field(
48
+ default_factory=lambda: jnp.array(0, dtype=jnp.uint64)
49
+ )
50
+
51
+ velocity_representation: Static[VelRepr] = VelRepr.Inertial
52
+
53
+ def valid(self, model: jaxsim.api.model.JaxSimModel | None = None) -> bool:
54
+ """
55
+ Check if the current state is valid for the given model.
56
+
57
+ Args:
58
+ model: The model to check against.
59
+
60
+ Returns:
61
+ `True` if the current state is valid for the given model, `False` otherwise.
62
+ """
63
+
64
+ valid = True
65
+
66
+ if model is not None:
67
+ valid = valid and self.state.valid(physics_model=model.physics_model)
68
+
69
+ return valid
70
+
71
+ @contextlib.contextmanager
72
+ def switch_velocity_representation(
73
+ self, velocity_representation: VelRepr
74
+ ) -> ContextManager[Self]:
75
+ """
76
+ Context manager to temporarily switch the velocity representation.
77
+
78
+ Args:
79
+ velocity_representation: The new velocity representation.
80
+
81
+ Yields:
82
+ The same `JaxSimModelData` object with the new velocity representation.
83
+ """
84
+
85
+ original_representation = self.velocity_representation
86
+
87
+ try:
88
+
89
+ # First, we replace the velocity representation
90
+ with self.mutable_context(
91
+ mutability=Mutability.MUTABLE_NO_VALIDATION,
92
+ restore_after_exception=True,
93
+ ):
94
+ self.velocity_representation = velocity_representation
95
+
96
+ # Then, we yield the data with changed representation.
97
+ # We run this in a mutable context with restoration so that any exception
98
+ # occurring, we restore the original object in case it was modified.
99
+ with self.mutable_context(
100
+ mutability=self._mutability(), restore_after_exception=True
101
+ ):
102
+ yield self
103
+
104
+ finally:
105
+ with self.mutable_context(
106
+ mutability=Mutability.MUTABLE_NO_VALIDATION,
107
+ restore_after_exception=True,
108
+ ):
109
+ self.velocity_representation = original_representation
110
+
111
+ @staticmethod
112
+ def zero(
113
+ model: jaxsim.api.model.JaxSimModel,
114
+ velocity_representation: VelRepr = VelRepr.Inertial,
115
+ ) -> JaxSimModelData:
116
+ """
117
+ Create a `JaxSimModelData` object with zero state.
118
+
119
+ Args:
120
+ model: The model for which to create the zero state.
121
+ velocity_representation: The velocity representation to use.
122
+
123
+ Returns:
124
+ A `JaxSimModelData` object with zero state.
125
+ """
126
+
127
+ return JaxSimModelData.build(
128
+ model=model, velocity_representation=velocity_representation
129
+ )
130
+
131
+ @staticmethod
132
+ def build(
133
+ model: jaxsim.api.model.JaxSimModel,
134
+ base_position: jtp.Vector | None = None,
135
+ base_quaternion: jtp.Vector | None = None,
136
+ joint_positions: jtp.Vector | None = None,
137
+ base_linear_velocity: jtp.Vector | None = None,
138
+ base_angular_velocity: jtp.Vector | None = None,
139
+ joint_velocities: jtp.Vector | None = None,
140
+ gravity: jtp.Vector | None = None,
141
+ soft_contacts_state: soft_contacts.SoftContactsState | None = None,
142
+ soft_contacts_params: soft_contacts.SoftContactsParams | None = None,
143
+ velocity_representation: VelRepr = VelRepr.Inertial,
144
+ time: jtp.FloatLike | None = None,
145
+ ) -> JaxSimModelData:
146
+ """
147
+ Create a `JaxSimModelData` object with the given state.
148
+
149
+ Args:
150
+ model: The model for which to create the state.
151
+ base_position: The base position.
152
+ base_quaternion: The base orientation as a quaternion.
153
+ joint_positions: The joint positions.
154
+ base_linear_velocity:
155
+ The base linear velocity in the selected representation.
156
+ base_angular_velocity:
157
+ The base angular velocity in the selected representation.
158
+ joint_velocities: The joint velocities.
159
+ gravity: The gravity 3D vector.
160
+ soft_contacts_state: The state of the soft contacts.
161
+ soft_contacts_params: The parameters of the soft contacts.
162
+ velocity_representation: The velocity representation to use.
163
+ time: The time at which the state is created.
164
+
165
+ Returns:
166
+ A `JaxSimModelData` object with the given state.
167
+ """
168
+
169
+ base_position = jnp.array(
170
+ base_position if base_position is not None else jnp.zeros(3)
171
+ ).squeeze()
172
+
173
+ base_quaternion = jnp.array(
174
+ base_quaternion
175
+ if base_quaternion is not None
176
+ else jnp.array([1.0, 0, 0, 0])
177
+ ).squeeze()
178
+
179
+ base_linear_velocity = jnp.array(
180
+ base_linear_velocity if base_linear_velocity is not None else jnp.zeros(3)
181
+ ).squeeze()
182
+
183
+ base_angular_velocity = jnp.array(
184
+ base_angular_velocity if base_angular_velocity is not None else jnp.zeros(3)
185
+ ).squeeze()
186
+
187
+ gravity = jnp.array(
188
+ gravity if gravity is not None else model.physics_model.gravity[0:3]
189
+ ).squeeze()
190
+
191
+ joint_positions = jnp.atleast_1d(
192
+ joint_positions.squeeze()
193
+ if joint_positions is not None
194
+ else jnp.zeros(model.dofs())
195
+ )
196
+
197
+ joint_velocities = jnp.atleast_1d(
198
+ joint_velocities.squeeze()
199
+ if joint_velocities is not None
200
+ else jnp.zeros(model.dofs())
201
+ )
202
+
203
+ time_ns = (
204
+ jnp.array(time * 1e9, dtype=jnp.uint64)
205
+ if time is not None
206
+ else jnp.array(0, dtype=jnp.uint64)
207
+ )
208
+
209
+ soft_contacts_params = (
210
+ soft_contacts_params
211
+ if soft_contacts_params is not None
212
+ else jaxsim.api.contact.estimate_good_soft_contacts_parameters(model=model)
213
+ )
214
+
215
+ W_H_B = jaxlie.SE3.from_rotation_and_translation(
216
+ translation=base_position,
217
+ rotation=jaxlie.SO3.from_quaternion_xyzw(
218
+ base_quaternion[jnp.array([1, 2, 3, 0])]
219
+ ),
220
+ ).as_matrix()
221
+
222
+ v_WB = JaxSimModelData.other_representation_to_inertial(
223
+ array=jnp.hstack([base_linear_velocity, base_angular_velocity]),
224
+ other_representation=velocity_representation,
225
+ transform=W_H_B,
226
+ is_force=False,
227
+ )
228
+
229
+ ode_state = ODEState.build(
230
+ physics_model=model.physics_model,
231
+ physics_model_state=jaxsim.physics.model.physics_model_state.PhysicsModelState(
232
+ base_position=base_position.astype(float),
233
+ base_quaternion=base_quaternion.astype(float),
234
+ joint_positions=joint_positions.astype(float),
235
+ base_linear_velocity=v_WB[0:3].astype(float),
236
+ base_angular_velocity=v_WB[3:6].astype(float),
237
+ joint_velocities=joint_velocities.astype(float),
238
+ ),
239
+ soft_contacts_state=soft_contacts_state,
240
+ )
241
+
242
+ if not ode_state.valid(physics_model=model.physics_model):
243
+ raise ValueError(ode_state)
244
+
245
+ return JaxSimModelData(
246
+ time_ns=time_ns,
247
+ state=ode_state,
248
+ gravity=gravity.astype(float),
249
+ soft_contacts_params=soft_contacts_params,
250
+ velocity_representation=velocity_representation,
251
+ )
252
+
253
+ # ==================
254
+ # Extract quantities
255
+ # ==================
256
+
257
+ def time(self) -> jtp.Float:
258
+ """
259
+ Get the simulated time.
260
+
261
+ Returns:
262
+ The simulated time in seconds.
263
+ """
264
+
265
+ return self.time_ns.astype(float) / 1e9
266
+
267
+ @functools.partial(jax.jit, static_argnames=["joint_names"])
268
+ def joint_positions(
269
+ self,
270
+ model: jaxsim.api.model.JaxSimModel | None = None,
271
+ joint_names: tuple[str, ...] | None = None,
272
+ ) -> jtp.Vector:
273
+ """
274
+ Get the joint positions.
275
+
276
+ Args:
277
+ model: The model to consider.
278
+ joint_names:
279
+ The names of the joints for which to get the positions. If `None`,
280
+ the positions of all joints are returned.
281
+
282
+ Returns:
283
+ If no model and no joint names are provided, the joint positions as a
284
+ `(DoFs,)` vector corresponding to the serialization of the original
285
+ model used to build the data object.
286
+ If a model is provided and no joint names are provided, the joint positions
287
+ as a `(DoFs,)` vector corresponding to the serialization of the
288
+ provided model.
289
+ If a model and joint names are provided, the joint positions as a
290
+ `(len(joint_names),)` vector corresponding to the serialization of
291
+ the passed joint names vector.
292
+ """
293
+
294
+ if model is None:
295
+ return self.state.physics_model.joint_positions
296
+
297
+ if not self.valid(model=model):
298
+ msg = "The data object is not compatible with the provided model"
299
+ raise ValueError(msg)
300
+
301
+ joint_names = joint_names if joint_names is not None else model.joint_names()
302
+
303
+ return self.state.physics_model.joint_positions[
304
+ jaxsim.api.joint.names_to_idxs(joint_names=joint_names, model=model)
305
+ ]
306
+
307
+ @functools.partial(jax.jit, static_argnames=["joint_names"])
308
+ def joint_velocities(
309
+ self,
310
+ model: jaxsim.api.model.JaxSimModel | None = None,
311
+ joint_names: tuple[str, ...] | None = None,
312
+ ) -> jtp.Vector:
313
+ """
314
+ Get the joint velocities.
315
+
316
+ Args:
317
+ model: The model to consider.
318
+ joint_names:
319
+ The names of the joints for which to get the velocities. If `None`,
320
+ the velocities of all joints are returned.
321
+
322
+ Returns:
323
+ If no model and no joint names are provided, the joint velocities as a
324
+ `(DoFs,)` vector corresponding to the serialization of the original
325
+ model used to build the data object.
326
+ If a model is provided and no joint names are provided, the joint velocities
327
+ as a `(DoFs,)` vector corresponding to the serialization of the
328
+ provided model.
329
+ If a model and joint names are provided, the joint velocities as a
330
+ `(len(joint_names),)` vector corresponding to the serialization of
331
+ the passed joint names vector.
332
+ """
333
+
334
+ if model is None:
335
+ return self.state.physics_model.joint_velocities
336
+
337
+ if not self.valid(model=model):
338
+ msg = "The data object is not compatible with the provided model"
339
+ raise ValueError(msg)
340
+
341
+ joint_names = joint_names if joint_names is not None else model.joint_names()
342
+
343
+ return self.state.physics_model.joint_velocities[
344
+ jaxsim.api.joint.names_to_idxs(joint_names=joint_names, model=model)
345
+ ]
346
+
347
+ @jax.jit
348
+ def base_position(self) -> jtp.Vector:
349
+ """
350
+ Get the base position.
351
+
352
+ Returns:
353
+ The base position.
354
+ """
355
+
356
+ return self.state.physics_model.base_position.squeeze()
357
+
358
+ @functools.partial(jax.jit, static_argnames=["dcm"])
359
+ def base_orientation(self, dcm: jtp.BoolLike = False) -> jtp.Vector | jtp.Matrix:
360
+ """
361
+ Get the base orientation.
362
+
363
+ Args:
364
+ dcm: Whether to return the orientation as a SO(3) matrix or quaternion.
365
+
366
+ Returns:
367
+ The base orientation.
368
+ """
369
+
370
+ # Always normalize the quaternion to avoid numerical issues.
371
+ # If the active scheme does not integrate the quaternion on its manifold,
372
+ # we introduce a Baumgarte stabilization to let the quaternion converge to
373
+ # a unit quaternion. In this case, it is not guaranteed that the quaternion
374
+ # stored in the state is a unit quaternion.
375
+ base_unit_quaternion = (
376
+ self.state.physics_model.base_quaternion.squeeze()
377
+ / jnp.linalg.norm(self.state.physics_model.base_quaternion)
378
+ )
379
+
380
+ # Slice to convert quaternion wxyz -> xyzw
381
+ to_xyzw = np.array([1, 2, 3, 0])
382
+
383
+ return (
384
+ base_unit_quaternion
385
+ if not dcm
386
+ else jaxlie.SO3.from_quaternion_xyzw(
387
+ base_unit_quaternion[to_xyzw]
388
+ ).as_matrix()
389
+ )
390
+
391
+ @jax.jit
392
+ def base_transform(self) -> jtp.MatrixJax:
393
+ """
394
+ Get the base transform.
395
+
396
+ Returns:
397
+ The base transform as an SE(3) matrix.
398
+ """
399
+
400
+ W_R_B = self.base_orientation(dcm=True)
401
+ W_p_B = jnp.vstack(self.base_position())
402
+
403
+ return jnp.vstack(
404
+ [
405
+ jnp.block([W_R_B, W_p_B]),
406
+ jnp.array([0, 0, 0, 1]),
407
+ ]
408
+ )
409
+
410
+ @jax.jit
411
+ def base_velocity(self) -> jtp.Vector:
412
+ """
413
+ Get the base 6D velocity.
414
+
415
+ Returns:
416
+ The base 6D velocity in the active representation.
417
+ """
418
+
419
+ W_v_WB = jnp.hstack(
420
+ [
421
+ self.state.physics_model.base_linear_velocity,
422
+ self.state.physics_model.base_angular_velocity,
423
+ ]
424
+ )
425
+
426
+ W_H_B = self.base_transform()
427
+
428
+ return (
429
+ JaxSimModelData.inertial_to_other_representation(
430
+ array=W_v_WB,
431
+ other_representation=self.velocity_representation,
432
+ transform=W_H_B,
433
+ is_force=False,
434
+ )
435
+ .squeeze()
436
+ .astype(float)
437
+ )
438
+
439
+ @jax.jit
440
+ def generalized_position(self) -> tuple[jtp.Matrix, jtp.Vector]:
441
+ """
442
+ Get the generalized position
443
+ :math:`\mathbf{q} = ({}^W \mathbf{H}_B, \mathbf{s}) \in \text{SO}(3) \times \mathbb{R}^n`.
444
+
445
+ Returns:
446
+ A tuple containing the base transform and the joint positions.
447
+ """
448
+
449
+ return self.base_transform(), self.joint_positions()
450
+
451
+ @jax.jit
452
+ def generalized_velocity(self) -> jtp.Vector:
453
+ """
454
+ Get the generalized velocity
455
+ :math:`\boldsymbol{\nu} = (\boldsymbol{v}_{W,B};\, \boldsymbol{\omega}_{W,B};\, \mathbf{s}) \in \mathbb{R}^{6+n}`
456
+
457
+ Returns:
458
+ The generalized velocity in the active representation.
459
+ """
460
+
461
+ return (
462
+ jnp.hstack([self.base_velocity(), self.joint_velocities()])
463
+ .squeeze()
464
+ .astype(float)
465
+ )
466
+
467
+ # ================
468
+ # Store quantities
469
+ # ================
470
+
471
+ @functools.partial(jax.jit, static_argnames=["joint_names"])
472
+ def reset_joint_positions(
473
+ self,
474
+ positions: jtp.VectorLike,
475
+ model: jaxsim.api.model.JaxSimModel | None = None,
476
+ joint_names: tuple[str, ...] | None = None,
477
+ ) -> Self:
478
+ """
479
+ Reset the joint positions.
480
+
481
+ Args:
482
+ positions: The joint positions.
483
+ model: The model to consider.
484
+ joint_names: The names of the joints for which to set the positions.
485
+
486
+ Returns:
487
+ The updated `JaxSimModelData` object.
488
+ """
489
+
490
+ positions = jnp.array(positions)
491
+
492
+ def replace(s: jtp.VectorLike) -> JaxSimModelData:
493
+ return self.replace(
494
+ validate=True,
495
+ state=self.state.replace(
496
+ physics_model=self.state.physics_model.replace(
497
+ joint_positions=jnp.atleast_1d(s.squeeze()).astype(float)
498
+ )
499
+ ),
500
+ )
501
+
502
+ if model is None:
503
+ return replace(s=positions)
504
+
505
+ if not self.valid(model=model):
506
+ msg = "The data object is not compatible with the provided model"
507
+ raise ValueError(msg)
508
+
509
+ joint_names = joint_names if joint_names is not None else model.joint_names()
510
+
511
+ return replace(
512
+ s=self.state.physics_model.joint_positions.at[
513
+ jaxsim.api.joint.names_to_idxs(joint_names=joint_names, model=model)
514
+ ].set(positions)
515
+ )
516
+
517
+ @functools.partial(jax.jit, static_argnames=["joint_names"])
518
+ def reset_joint_velocities(
519
+ self,
520
+ velocities: jtp.VectorLike,
521
+ model: jaxsim.api.model.JaxSimModel | None = None,
522
+ joint_names: tuple[str, ...] | None = None,
523
+ ) -> Self:
524
+ """
525
+ Reset the joint velocities.
526
+
527
+ Args:
528
+ velocities: The joint velocities.
529
+ model: The model to consider.
530
+ joint_names: The names of the joints for which to set the velocities.
531
+
532
+ Returns:
533
+ The updated `JaxSimModelData` object.
534
+ """
535
+
536
+ velocities = jnp.array(velocities)
537
+
538
+ def replace(ṡ: jtp.VectorLike) -> JaxSimModelData:
539
+ return self.replace(
540
+ validate=True,
541
+ state=self.state.replace(
542
+ physics_model=self.state.physics_model.replace(
543
+ joint_velocities=jnp.atleast_1d(ṡ.squeeze()).astype(float)
544
+ )
545
+ ),
546
+ )
547
+
548
+ if model is None:
549
+ return replace(ṡ=velocities)
550
+
551
+ if not self.valid(model=model):
552
+ msg = "The data object is not compatible with the provided model"
553
+ raise ValueError(msg)
554
+
555
+ joint_names = joint_names if joint_names is not None else model.joint_names()
556
+
557
+ return replace(
558
+ ṡ=self.state.physics_model.joint_velocities.at[
559
+ jaxsim.api.joint.names_to_idxs(joint_names=joint_names, model=model)
560
+ ].set(velocities)
561
+ )
562
+
563
+ @jax.jit
564
+ def reset_base_position(self, base_position: jtp.VectorLike) -> Self:
565
+ """
566
+ Reset the base position.
567
+
568
+ Args:
569
+ base_position: The base position.
570
+
571
+ Returns:
572
+ The updated `JaxSimModelData` object.
573
+ """
574
+
575
+ base_position = jnp.array(base_position)
576
+
577
+ return self.replace(
578
+ validate=True,
579
+ state=self.state.replace(
580
+ physics_model=self.state.physics_model.replace(
581
+ base_position=jnp.atleast_1d(base_position.squeeze()).astype(float)
582
+ )
583
+ ),
584
+ )
585
+
586
+ @jax.jit
587
+ def reset_base_quaternion(self, base_quaternion: jtp.VectorLike) -> Self:
588
+ """
589
+ Reset the base quaternion.
590
+
591
+ Args:
592
+ base_quaternion: The base orientation as a quaternion.
593
+
594
+ Returns:
595
+ The updated `JaxSimModelData` object.
596
+ """
597
+
598
+ base_quaternion = jnp.array(base_quaternion)
599
+
600
+ return self.replace(
601
+ validate=True,
602
+ state=self.state.replace(
603
+ physics_model=self.state.physics_model.replace(
604
+ base_quaternion=jnp.atleast_1d(base_quaternion.squeeze()).astype(
605
+ float
606
+ )
607
+ )
608
+ ),
609
+ )
610
+
611
+ @jax.jit
612
+ def reset_base_pose(self, base_pose: jtp.MatrixLike) -> Self:
613
+ """
614
+ Reset the base pose.
615
+
616
+ Args:
617
+ base_pose: The base pose as an SE(3) matrix.
618
+
619
+ Returns:
620
+ The updated `JaxSimModelData` object.
621
+ """
622
+
623
+ base_pose = jnp.array(base_pose)
624
+
625
+ W_p_B = base_pose[0:3, 3]
626
+
627
+ to_wxyz = np.array([3, 0, 1, 2])
628
+ W_R_B: jaxlie.SO3 = jaxlie.SO3.from_matrix(base_pose[0:3, 0:3]) # noqa
629
+ W_Q_B = W_R_B.as_quaternion_xyzw()[to_wxyz]
630
+
631
+ return self.reset_base_position(base_position=W_p_B).reset_base_quaternion(
632
+ base_quaternion=W_Q_B
633
+ )
634
+
635
+ @functools.partial(jax.jit, static_argnames=["velocity_representation"])
636
+ def reset_base_linear_velocity(
637
+ self,
638
+ linear_velocity: jtp.VectorLike,
639
+ velocity_representation: VelRepr | None = None,
640
+ ) -> Self:
641
+ """
642
+ Reset the base linear velocity.
643
+
644
+ Args:
645
+ linear_velocity: The base linear velocity as a 3D array.
646
+ velocity_representation:
647
+ The velocity representation in which the base velocity is expressed.
648
+ If `None`, the active representation is considered.
649
+
650
+ Returns:
651
+ The updated `JaxSimModelData` object.
652
+ """
653
+
654
+ linear_velocity = jnp.array(linear_velocity)
655
+
656
+ return self.reset_base_velocity(
657
+ base_velocity=jnp.hstack(
658
+ [linear_velocity.squeeze(), self.base_velocity()[3:6]]
659
+ ),
660
+ velocity_representation=velocity_representation,
661
+ )
662
+
663
+ @functools.partial(jax.jit, static_argnames=["velocity_representation"])
664
+ def reset_base_angular_velocity(
665
+ self,
666
+ angular_velocity: jtp.VectorLike,
667
+ velocity_representation: VelRepr | None = None,
668
+ ) -> Self:
669
+ """
670
+ Reset the base angular velocity.
671
+
672
+ Args:
673
+ angular_velocity: The base angular velocity as a 3D array.
674
+ velocity_representation:
675
+ The velocity representation in which the base velocity is expressed.
676
+ If `None`, the active representation is considered.
677
+
678
+ Returns:
679
+ The updated `JaxSimModelData` object.
680
+ """
681
+
682
+ angular_velocity = jnp.array(angular_velocity)
683
+
684
+ return self.reset_base_velocity(
685
+ base_velocity=jnp.hstack(
686
+ [self.base_velocity()[0:3], angular_velocity.squeeze()]
687
+ ),
688
+ velocity_representation=velocity_representation,
689
+ )
690
+
691
+ @functools.partial(jax.jit, static_argnames=["velocity_representation"])
692
+ def reset_base_velocity(
693
+ self,
694
+ base_velocity: jtp.VectorLike,
695
+ velocity_representation: VelRepr | None = None,
696
+ ) -> Self:
697
+ """
698
+ Reset the base 6D velocity.
699
+
700
+ Args:
701
+ base_velocity: The base 6D velocity in the active representation.
702
+ velocity_representation:
703
+ The velocity representation in which the base velocity is expressed.
704
+ If `None`, the active representation is considered.
705
+
706
+ Returns:
707
+ The updated `JaxSimModelData` object.
708
+ """
709
+
710
+ base_velocity = jnp.array(base_velocity)
711
+
712
+ velocity_representation = (
713
+ velocity_representation
714
+ if velocity_representation is not None
715
+ else self.velocity_representation
716
+ )
717
+
718
+ W_v_WB = self.other_representation_to_inertial(
719
+ array=jnp.atleast_1d(base_velocity.squeeze()).astype(float),
720
+ other_representation=velocity_representation,
721
+ transform=self.base_transform(),
722
+ is_force=False,
723
+ )
724
+
725
+ return self.replace(
726
+ validate=True,
727
+ state=self.state.replace(
728
+ physics_model=self.state.physics_model.replace(
729
+ base_linear_velocity=W_v_WB[0:3].squeeze().astype(float),
730
+ base_angular_velocity=W_v_WB[3:6].squeeze().astype(float),
731
+ )
732
+ ),
733
+ )
734
+
735
+ # =============
736
+ # Other helpers
737
+ # =============
738
+
739
+ @staticmethod
740
+ @functools.partial(jax.jit, static_argnames=["other_representation", "is_force"])
741
+ def inertial_to_other_representation(
742
+ array: jtp.Array,
743
+ other_representation: VelRepr,
744
+ transform: jtp.Matrix,
745
+ is_force: bool = False,
746
+ ) -> jtp.Array:
747
+ """
748
+ Convert a 6D quantity from the inertial to another representation.
749
+
750
+ Args:
751
+ array: The 6D quantity to convert.
752
+ other_representation: The representation to convert to.
753
+ transform:
754
+ The `math:W \mathbf{H}_O` transform, where `math:O` is the
755
+ reference frame of the other representation.
756
+ is_force: Whether the quantity is a 6D force or 6D velocity.
757
+
758
+ Returns:
759
+ The 6D quantity in the other representation.
760
+ """
761
+
762
+ W_array = array.squeeze()
763
+ W_H_O = transform.squeeze()
764
+
765
+ if W_array.size != 6:
766
+ raise ValueError(W_array.size, 6)
767
+
768
+ if W_H_O.shape != (4, 4):
769
+ raise ValueError(W_H_O.shape, (4, 4))
770
+
771
+ match other_representation:
772
+
773
+ case VelRepr.Inertial:
774
+ return W_array
775
+
776
+ case VelRepr.Body:
777
+
778
+ if not is_force:
779
+ O_Xv_W = jaxlie.SE3.from_matrix(W_H_O).inverse().adjoint()
780
+ O_array = O_Xv_W @ W_array
781
+
782
+ else:
783
+ O_Xf_W = jaxlie.SE3.from_matrix(W_H_O).adjoint().T
784
+ O_array = O_Xf_W @ W_array
785
+
786
+ return O_array
787
+
788
+ case VelRepr.Mixed:
789
+ W_p_O = W_H_O[0:3, 3]
790
+ W_H_OW = jnp.eye(4).at[0:3, 3].set(W_p_O)
791
+
792
+ if not is_force:
793
+ OW_Xv_W = jaxlie.SE3.from_matrix(W_H_OW).inverse().adjoint()
794
+ OW_array = OW_Xv_W @ W_array
795
+
796
+ else:
797
+ OW_Xf_W = jaxlie.SE3.from_matrix(W_H_OW).adjoint().transpose()
798
+ OW_array = OW_Xf_W @ W_array
799
+
800
+ return OW_array
801
+
802
+ case _:
803
+ raise ValueError(other_representation)
804
+
805
+ @staticmethod
806
+ @functools.partial(jax.jit, static_argnames=["other_representation", "is_force"])
807
+ def other_representation_to_inertial(
808
+ array: jtp.Array,
809
+ other_representation: VelRepr,
810
+ transform: jtp.Matrix,
811
+ is_force: bool = False,
812
+ ) -> jtp.Array:
813
+ """
814
+ Convert a 6D quantity from another representation to the inertial.
815
+
816
+ Args:
817
+ array: The 6D quantity to convert.
818
+ other_representation: The representation to convert from.
819
+ transform:
820
+ The `math:W \mathbf{H}_O` transform, where `math:O` is the
821
+ reference frame of the other representation.
822
+ is_force: Whether the quantity is a 6D force or 6D velocity.
823
+
824
+ Returns:
825
+ The 6D quantity in the inertial representation.
826
+ """
827
+
828
+ W_array = array.squeeze()
829
+ W_H_O = transform.squeeze()
830
+
831
+ if W_array.size != 6:
832
+ raise ValueError(W_array.size, 6)
833
+
834
+ if W_H_O.shape != (4, 4):
835
+ raise ValueError(W_H_O.shape, (4, 4))
836
+
837
+ match other_representation:
838
+ case VelRepr.Inertial:
839
+ W_array = array
840
+ return W_array
841
+
842
+ case VelRepr.Body:
843
+ O_array = array
844
+
845
+ if not is_force:
846
+ W_Xv_O: jtp.Array = jaxlie.SE3.from_matrix(W_H_O).adjoint()
847
+ W_array = W_Xv_O @ O_array
848
+
849
+ else:
850
+ W_Xf_O = jaxlie.SE3.from_matrix(W_H_O).inverse().adjoint().T
851
+ W_array = W_Xf_O @ O_array
852
+
853
+ return W_array
854
+
855
+ case VelRepr.Mixed:
856
+ BW_array = array
857
+ W_p_O = W_H_O[0:3, 3]
858
+ W_H_OW = jnp.eye(4).at[0:3, 3].set(W_p_O)
859
+
860
+ if not is_force:
861
+ W_Xv_BW: jtp.Array = jaxlie.SE3.from_matrix(W_H_OW).adjoint()
862
+ W_array = W_Xv_BW @ BW_array
863
+
864
+ else:
865
+ W_Xf_BW = jaxlie.SE3.from_matrix(W_H_OW).inverse().adjoint().T
866
+ W_array = W_Xf_BW @ BW_array
867
+
868
+ return W_array
869
+
870
+ case _:
871
+ raise ValueError(other_representation)
872
+
873
+
874
def random_model_data(
    model: jaxsim.api.model.JaxSimModel,
    *,
    key: jax.Array | None = None,
    base_pos_bounds: tuple[
        jtp.FloatLike | Sequence[jtp.FloatLike],
        jtp.FloatLike | Sequence[jtp.FloatLike],
    ] = ((-1, -1, 0.5), 1.0),
    base_vel_lin_bounds: tuple[
        jtp.FloatLike | Sequence[jtp.FloatLike],
        jtp.FloatLike | Sequence[jtp.FloatLike],
    ] = (-1.0, 1.0),
    base_vel_ang_bounds: tuple[
        jtp.FloatLike | Sequence[jtp.FloatLike],
        jtp.FloatLike | Sequence[jtp.FloatLike],
    ] = (-1.0, 1.0),
    joint_vel_bounds: tuple[
        jtp.FloatLike | Sequence[jtp.FloatLike],
        jtp.FloatLike | Sequence[jtp.FloatLike],
    ] = (-1.0, 1.0),
) -> JaxSimModelData:
    """
    Randomly generate a `JaxSimModelData` object.

    All bounds are `(min, max)` pairs. The values are sampled uniformly
    between the bounds.

    Args:
        model: The target model for the random data.
        key:
            The random key. If `None`, a key built from the seed 0 is used,
            therefore the generated data is always the same.
        base_pos_bounds: The bounds for the base position.
        base_vel_lin_bounds: The bounds for the base linear velocity.
        base_vel_ang_bounds: The bounds for the base angular velocity.
        joint_vel_bounds: The bounds for the joint velocities.

    Returns:
        A `JaxSimModelData` object with random data.
    """

    key = key if key is not None else jax.random.PRNGKey(seed=0)

    # Use an independent key for each sampled quantity.
    k1, k2, k3, k4, k5, k6 = jax.random.split(key, num=6)

    # Convert all the bounds to arrays, so that both scalars and sequences
    # are supported for each of them.
    p_min = jnp.array(base_pos_bounds[0], dtype=float)
    p_max = jnp.array(base_pos_bounds[1], dtype=float)
    v_min = jnp.array(base_vel_lin_bounds[0], dtype=float)
    v_max = jnp.array(base_vel_lin_bounds[1], dtype=float)
    ω_min = jnp.array(base_vel_ang_bounds[0], dtype=float)
    ω_max = jnp.array(base_vel_ang_bounds[1], dtype=float)
    ṡ_min = jnp.array(joint_vel_bounds[0], dtype=float)
    ṡ_max = jnp.array(joint_vel_bounds[1], dtype=float)

    # Start from the zero state and overwrite its quantities in place.
    random_data = JaxSimModelData.zero(model=model)

    with random_data.mutable_context(mutability=Mutability.MUTABLE):

        physics_model_state = random_data.state.physics_model

        physics_model_state.base_position = jax.random.uniform(
            key=k1, shape=(3,), minval=p_min, maxval=p_max
        )

        # Sample roll-pitch-yaw angles in [0, 2π) and convert the resulting
        # rotation to a quaternion, reordered from xyzw to wxyz.
        physics_model_state.base_quaternion = jaxlie.SO3.from_rpy_radians(
            *jax.random.uniform(key=k2, shape=(3,), minval=0, maxval=2 * jnp.pi)
        ).as_quaternion_xyzw()[np.array([3, 0, 1, 2])]

        physics_model_state.joint_positions = jaxsim.api.joint.random_joint_positions(
            model=model, key=k3
        )

        physics_model_state.base_linear_velocity = jax.random.uniform(
            key=k4, shape=(3,), minval=v_min, maxval=v_max
        )

        physics_model_state.base_angular_velocity = jax.random.uniform(
            key=k5, shape=(3,), minval=ω_min, maxval=ω_max
        )

        physics_model_state.joint_velocities = jax.random.uniform(
            key=k6, shape=(model.dofs(),), minval=ṡ_min, maxval=ṡ_max
        )

    return random_data