jaxsim 0.1.dev401__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (89)
  1. jaxsim/__init__.py +5 -6
  2. jaxsim/_version.py +2 -2
  3. jaxsim/api/__init__.py +3 -0
  4. jaxsim/api/com.py +240 -0
  5. jaxsim/api/common.py +216 -0
  6. jaxsim/api/contact.py +271 -0
  7. jaxsim/api/data.py +821 -0
  8. jaxsim/api/joint.py +189 -0
  9. jaxsim/api/kin_dyn_parameters.py +777 -0
  10. jaxsim/api/link.py +361 -0
  11. jaxsim/api/model.py +1633 -0
  12. jaxsim/api/ode.py +295 -0
  13. jaxsim/api/ode_data.py +694 -0
  14. jaxsim/api/references.py +421 -0
  15. jaxsim/integrators/__init__.py +2 -0
  16. jaxsim/integrators/common.py +594 -0
  17. jaxsim/integrators/fixed_step.py +102 -0
  18. jaxsim/integrators/variable_step.py +610 -0
  19. jaxsim/math/__init__.py +11 -0
  20. jaxsim/math/adjoint.py +24 -2
  21. jaxsim/math/joint_model.py +335 -0
  22. jaxsim/math/quaternion.py +44 -3
  23. jaxsim/math/rotation.py +4 -4
  24. jaxsim/math/transform.py +92 -0
  25. jaxsim/mujoco/__init__.py +3 -0
  26. jaxsim/mujoco/__main__.py +192 -0
  27. jaxsim/mujoco/loaders.py +615 -0
  28. jaxsim/mujoco/model.py +414 -0
  29. jaxsim/mujoco/visualizer.py +176 -0
  30. jaxsim/parsers/descriptions/collision.py +14 -0
  31. jaxsim/parsers/descriptions/link.py +13 -2
  32. jaxsim/parsers/kinematic_graph.py +8 -3
  33. jaxsim/parsers/rod/parser.py +54 -38
  34. jaxsim/parsers/rod/utils.py +7 -8
  35. jaxsim/rbda/__init__.py +7 -0
  36. jaxsim/rbda/aba.py +295 -0
  37. jaxsim/rbda/collidable_points.py +142 -0
  38. jaxsim/{physics/algos → rbda}/crba.py +43 -42
  39. jaxsim/rbda/forward_kinematics.py +113 -0
  40. jaxsim/rbda/jacobian.py +201 -0
  41. jaxsim/rbda/rnea.py +237 -0
  42. jaxsim/rbda/soft_contacts.py +296 -0
  43. jaxsim/rbda/utils.py +152 -0
  44. jaxsim/terrain/__init__.py +2 -0
  45. jaxsim/{physics/algos → terrain}/terrain.py +4 -6
  46. jaxsim/typing.py +30 -30
  47. jaxsim/utils/__init__.py +1 -4
  48. jaxsim/utils/hashless.py +18 -0
  49. jaxsim/utils/jaxsim_dataclass.py +281 -31
  50. {jaxsim-0.1.dev401.dist-info → jaxsim-0.2.0.dist-info}/LICENSE +1 -1
  51. jaxsim-0.2.0.dist-info/METADATA +237 -0
  52. jaxsim-0.2.0.dist-info/RECORD +64 -0
  53. {jaxsim-0.1.dev401.dist-info → jaxsim-0.2.0.dist-info}/WHEEL +1 -1
  54. jaxsim/high_level/__init__.py +0 -2
  55. jaxsim/high_level/common.py +0 -11
  56. jaxsim/high_level/joint.py +0 -148
  57. jaxsim/high_level/link.py +0 -259
  58. jaxsim/high_level/model.py +0 -1695
  59. jaxsim/math/conv.py +0 -114
  60. jaxsim/math/joint.py +0 -101
  61. jaxsim/math/plucker.py +0 -100
  62. jaxsim/physics/__init__.py +0 -12
  63. jaxsim/physics/algos/__init__.py +0 -0
  64. jaxsim/physics/algos/aba.py +0 -256
  65. jaxsim/physics/algos/aba_motors.py +0 -284
  66. jaxsim/physics/algos/forward_kinematics.py +0 -79
  67. jaxsim/physics/algos/jacobian.py +0 -98
  68. jaxsim/physics/algos/rnea.py +0 -180
  69. jaxsim/physics/algos/rnea_motors.py +0 -196
  70. jaxsim/physics/algos/soft_contacts.py +0 -454
  71. jaxsim/physics/algos/utils.py +0 -69
  72. jaxsim/physics/model/__init__.py +0 -0
  73. jaxsim/physics/model/ground_contact.py +0 -55
  74. jaxsim/physics/model/physics_model.py +0 -358
  75. jaxsim/physics/model/physics_model_state.py +0 -174
  76. jaxsim/simulation/__init__.py +0 -4
  77. jaxsim/simulation/integrators.py +0 -452
  78. jaxsim/simulation/ode.py +0 -290
  79. jaxsim/simulation/ode_data.py +0 -53
  80. jaxsim/simulation/ode_integration.py +0 -125
  81. jaxsim/simulation/simulator.py +0 -544
  82. jaxsim/simulation/simulator_callbacks.py +0 -53
  83. jaxsim/simulation/utils.py +0 -15
  84. jaxsim/sixd/__init__.py +0 -2
  85. jaxsim/utils/oop.py +0 -532
  86. jaxsim/utils/vmappable.py +0 -117
  87. jaxsim-0.1.dev401.dist-info/METADATA +0 -167
  88. jaxsim-0.1.dev401.dist-info/RECORD +0 -64
  89. {jaxsim-0.1.dev401.dist-info → jaxsim-0.2.0.dist-info}/top_level.txt +0 -0
jaxsim/api/data.py ADDED
@@ -0,0 +1,821 @@
1
+ from __future__ import annotations
2
+
3
+ import dataclasses
4
+ import functools
5
+ from typing import Sequence
6
+
7
+ import jax
8
+ import jax.numpy as jnp
9
+ import jax_dataclasses
10
+ import jaxlie
11
+ import numpy as np
12
+
13
+ import jaxsim.api as js
14
+ import jaxsim.rbda
15
+ import jaxsim.typing as jtp
16
+ from jaxsim.math import Quaternion
17
+ from jaxsim.utils import Mutability
18
+ from jaxsim.utils.tracing import not_tracing
19
+
20
+ from . import common
21
+ from .common import VelRepr
22
+ from .ode_data import ODEState
23
+
24
+ try:
25
+ from typing import Self
26
+ except ImportError:
27
+ from typing_extensions import Self
28
+
29
+
30
@jax_dataclasses.pytree_dataclass
class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
    """
    Class containing the state of a `JaxSimModel` object.
    """

    # ODE state; its `physics_model` member stores the base pose/velocity and the
    # joint positions/velocities accessed by the methods below.
    state: ODEState

    # Gravity vector; `build` sets it to [0, 0, -standard_gravity].
    gravity: jtp.Array

    soft_contacts_params: jaxsim.rbda.SoftContactsParams = dataclasses.field(repr=False)

    # Simulated time stored as an integer number of nanoseconds (see `time()`).
    time_ns: jtp.Int = dataclasses.field(
        default_factory=lambda: jnp.array(0, dtype=jnp.uint64)
    )

    def valid(self, model: js.model.JaxSimModel | None = None) -> bool:
        """
        Check if the current state is valid for the given model.

        Args:
            model: The model to check against.

        Returns:
            `True` if the current state is valid for the given model, `False` otherwise.
        """

        valid = True
        valid = valid and self.standard_gravity() > 0

        if model is not None:
            valid = valid and self.state.valid(model=model)

        return valid

    @staticmethod
    def zero(
        model: js.model.JaxSimModel,
        velocity_representation: VelRepr = VelRepr.Inertial,
    ) -> JaxSimModelData:
        """
        Create a `JaxSimModelData` object with zero state.

        Args:
            model: The model for which to create the zero state.
            velocity_representation: The velocity representation to use.

        Returns:
            A `JaxSimModelData` object with zero state.
        """

        return JaxSimModelData.build(
            model=model, velocity_representation=velocity_representation
        )

    @staticmethod
    def build(
        model: js.model.JaxSimModel,
        base_position: jtp.Vector | None = None,
        base_quaternion: jtp.Vector | None = None,
        joint_positions: jtp.Vector | None = None,
        base_linear_velocity: jtp.Vector | None = None,
        base_angular_velocity: jtp.Vector | None = None,
        joint_velocities: jtp.Vector | None = None,
        standard_gravity: jtp.FloatLike = jaxsim.math.StandardGravity,
        soft_contacts_state: js.ode_data.SoftContactsState | None = None,
        soft_contacts_params: jaxsim.rbda.SoftContactsParams | None = None,
        velocity_representation: VelRepr = VelRepr.Inertial,
        time: jtp.FloatLike | None = None,
    ) -> JaxSimModelData:
        """
        Create a `JaxSimModelData` object with the given state.

        Args:
            model: The model for which to create the state.
            base_position: The base position.
            base_quaternion: The base orientation as a quaternion (wxyz).
            joint_positions: The joint positions.
            base_linear_velocity:
                The base linear velocity in the selected representation.
            base_angular_velocity:
                The base angular velocity in the selected representation.
            joint_velocities: The joint velocities.
            standard_gravity: The standard gravity constant.
            soft_contacts_state: The state of the soft contacts.
            soft_contacts_params: The parameters of the soft contacts.
            velocity_representation: The velocity representation to use.
            time: The time (in seconds) at which the state is created.

        Returns:
            A `JaxSimModelData` object with the given state.

        Raises:
            ValueError: If the resulting ODE state is not valid for the model.
        """

        base_position = jnp.array(
            base_position if base_position is not None else jnp.zeros(3)
        ).squeeze()

        base_quaternion = jnp.array(
            base_quaternion
            if base_quaternion is not None
            else jnp.array([1.0, 0, 0, 0])
        ).squeeze()

        base_linear_velocity = jnp.array(
            base_linear_velocity if base_linear_velocity is not None else jnp.zeros(3)
        ).squeeze()

        base_angular_velocity = jnp.array(
            base_angular_velocity if base_angular_velocity is not None else jnp.zeros(3)
        ).squeeze()

        gravity = jnp.zeros(3).at[2].set(-standard_gravity)

        # Wrap in jnp.array (like the other inputs) so plain sequences are accepted.
        joint_positions = jnp.atleast_1d(
            jnp.array(joint_positions).squeeze()
            if joint_positions is not None
            else jnp.zeros(model.dofs())
        )

        joint_velocities = jnp.atleast_1d(
            jnp.array(joint_velocities).squeeze()
            if joint_velocities is not None
            else jnp.zeros(model.dofs())
        )

        # Round before the integer cast: a plain cast truncates, so e.g.
        # 1.005 * 1e9 == 1004999999.9999999 would otherwise lose a nanosecond.
        time_ns = (
            jnp.array(jnp.round(time * 1e9), dtype=jnp.uint64)
            if time is not None
            else jnp.array(0, dtype=jnp.uint64)
        )

        soft_contacts_params = (
            soft_contacts_params
            if soft_contacts_params is not None
            else js.contact.estimate_good_soft_contacts_parameters(
                model=model, standard_gravity=standard_gravity
            )
        )

        # jaxlie expects xyzw while the state stores wxyz.
        W_H_B = jaxlie.SE3.from_rotation_and_translation(
            translation=base_position,
            rotation=jaxlie.SO3.from_quaternion_xyzw(
                base_quaternion[jnp.array([1, 2, 3, 0])]
            ),
        ).as_matrix()

        # The ODE state always stores the base velocity in inertial representation.
        v_WB = JaxSimModelData.other_representation_to_inertial(
            array=jnp.hstack([base_linear_velocity, base_angular_velocity]),
            other_representation=velocity_representation,
            transform=W_H_B,
            is_force=False,
        )

        ode_state = ODEState.build_from_jaxsim_model(
            model=model,
            base_position=base_position.astype(float),
            base_quaternion=base_quaternion.astype(float),
            joint_positions=joint_positions.astype(float),
            base_linear_velocity=v_WB[0:3].astype(float),
            base_angular_velocity=v_WB[3:6].astype(float),
            joint_velocities=joint_velocities.astype(float),
            tangential_deformation=(
                soft_contacts_state.tangential_deformation
                if soft_contacts_state is not None
                else None
            ),
        )

        if not ode_state.valid(model=model):
            raise ValueError(ode_state)

        return JaxSimModelData(
            time_ns=time_ns,
            state=ode_state,
            gravity=gravity.astype(float),
            soft_contacts_params=soft_contacts_params,
            velocity_representation=velocity_representation,
        )

    # ==================
    # Extract quantities
    # ==================

    def time(self) -> jtp.Float:
        """
        Get the simulated time.

        Returns:
            The simulated time in seconds.
        """

        return self.time_ns.astype(float) / 1e9

    def standard_gravity(self) -> jtp.Float:
        """
        Get the standard gravity constant.

        Returns:
            The standard gravity constant (the negated z component of `gravity`).
        """

        return -self.gravity[2]

    @functools.partial(jax.jit, static_argnames=["joint_names"])
    def joint_positions(
        self,
        model: js.model.JaxSimModel | None = None,
        joint_names: tuple[str, ...] | None = None,
    ) -> jtp.Vector:
        """
        Get the joint positions.

        Args:
            model: The model to consider.
            joint_names:
                The names of the joints for which to get the positions. If `None`,
                the positions of all joints are returned.

        Returns:
            If no model and no joint names are provided, the joint positions as a
            `(DoFs,)` vector corresponding to the serialization of the original
            model used to build the data object.
            If a model is provided and no joint names are provided, the joint positions
            as a `(DoFs,)` vector corresponding to the serialization of the
            provided model.
            If a model and joint names are provided, the joint positions as a
            `(len(joint_names),)` vector corresponding to the serialization of
            the passed joint names vector.

        Raises:
            ValueError: If joint names are given without a model, or if the data
                is not compatible with the provided model.
        """

        if model is None:
            if joint_names is not None:
                raise ValueError("Joint names cannot be provided without a model")

            return self.state.physics_model.joint_positions

        # The validity check can only run on concrete (non-traced) values.
        if not_tracing(self.state.physics_model.joint_positions) and not self.valid(
            model=model
        ):
            msg = "The data object is not compatible with the provided model"
            raise ValueError(msg)

        joint_names = joint_names if joint_names is not None else model.joint_names()

        return self.state.physics_model.joint_positions[
            js.joint.names_to_idxs(joint_names=joint_names, model=model)
        ]

    @functools.partial(jax.jit, static_argnames=["joint_names"])
    def joint_velocities(
        self,
        model: js.model.JaxSimModel | None = None,
        joint_names: tuple[str, ...] | None = None,
    ) -> jtp.Vector:
        """
        Get the joint velocities.

        Args:
            model: The model to consider.
            joint_names:
                The names of the joints for which to get the velocities. If `None`,
                the velocities of all joints are returned.

        Returns:
            If no model and no joint names are provided, the joint velocities as a
            `(DoFs,)` vector corresponding to the serialization of the original
            model used to build the data object.
            If a model is provided and no joint names are provided, the joint velocities
            as a `(DoFs,)` vector corresponding to the serialization of the
            provided model.
            If a model and joint names are provided, the joint velocities as a
            `(len(joint_names),)` vector corresponding to the serialization of
            the passed joint names vector.

        Raises:
            ValueError: If joint names are given without a model, or if the data
                is not compatible with the provided model.
        """

        if model is None:
            if joint_names is not None:
                raise ValueError("Joint names cannot be provided without a model")

            return self.state.physics_model.joint_velocities

        # The validity check can only run on concrete (non-traced) values.
        if not_tracing(self.state.physics_model.joint_velocities) and not self.valid(
            model=model
        ):
            msg = "The data object is not compatible with the provided model"
            raise ValueError(msg)

        joint_names = joint_names if joint_names is not None else model.joint_names()

        return self.state.physics_model.joint_velocities[
            js.joint.names_to_idxs(joint_names=joint_names, model=model)
        ]

    @jax.jit
    def base_position(self) -> jtp.Vector:
        """
        Get the base position.

        Returns:
            The base position.
        """

        return self.state.physics_model.base_position.squeeze()

    @functools.partial(jax.jit, static_argnames=["dcm"])
    def base_orientation(self, dcm: jtp.BoolLike = False) -> jtp.Vector | jtp.Matrix:
        """
        Get the base orientation.

        Args:
            dcm: Whether to return the orientation as a SO(3) matrix or quaternion.

        Returns:
            The base orientation, either as a wxyz quaternion or as a 3x3 matrix.
        """

        # Extract the base quaternion.
        W_Q_B = self.state.physics_model.base_quaternion.squeeze()

        # Always normalize the quaternion to avoid numerical issues.
        # If the active scheme does not integrate the quaternion on its manifold,
        # we introduce a Baumgarte stabilization to let the quaternion converge to
        # a unit quaternion. In this case, it is not guaranteed that the quaternion
        # stored in the state is a unit quaternion.
        W_Q_B = jax.lax.select(
            pred=jnp.allclose(jnp.linalg.norm(W_Q_B), 1.0, atol=1e-6, rtol=0.0),
            on_true=W_Q_B,
            on_false=W_Q_B / jnp.linalg.norm(W_Q_B),
        )

        return (
            W_Q_B
            if not dcm
            else jaxlie.SO3.from_quaternion_xyzw(
                Quaternion.to_xyzw(wxyz=W_Q_B)
            ).as_matrix()
        ).astype(float)

    @jax.jit
    def base_transform(self) -> jtp.MatrixJax:
        """
        Get the base transform.

        Returns:
            The base transform as an SE(3) matrix.
        """

        W_R_B = self.base_orientation(dcm=True)
        W_p_B = jnp.vstack(self.base_position())

        return jnp.vstack(
            [
                jnp.block([W_R_B, W_p_B]),
                jnp.array([0, 0, 0, 1]),
            ]
        )

    def _base_velocity_in_representation(
        self, velocity_representation: VelRepr
    ) -> jtp.Vector:
        """
        Get the base 6D velocity expressed in an arbitrary representation.

        Args:
            velocity_representation: The representation of the returned velocity.

        Returns:
            The base 6D velocity (linear; angular) in the requested representation.
        """

        # The state stores the base velocity in inertial representation.
        W_v_WB = jnp.hstack(
            [
                self.state.physics_model.base_linear_velocity,
                self.state.physics_model.base_angular_velocity,
            ]
        )

        return (
            JaxSimModelData.inertial_to_other_representation(
                array=W_v_WB,
                other_representation=velocity_representation,
                transform=self.base_transform(),
                is_force=False,
            )
            .squeeze()
            .astype(float)
        )

    @jax.jit
    def base_velocity(self) -> jtp.Vector:
        """
        Get the base 6D velocity.

        Returns:
            The base 6D velocity in the active representation.
        """

        return self._base_velocity_in_representation(
            velocity_representation=self.velocity_representation
        )

    @jax.jit
    def generalized_position(self) -> tuple[jtp.Matrix, jtp.Vector]:
        r"""
        Get the generalized position
        :math:`\mathbf{q} = ({}^W \mathbf{H}_B, \mathbf{s}) \in \text{SE}(3) \times \mathbb{R}^n`.

        Returns:
            A tuple containing the base transform and the joint positions.
        """

        return self.base_transform(), self.joint_positions()

    @jax.jit
    def generalized_velocity(self) -> jtp.Vector:
        r"""
        Get the generalized velocity
        :math:`\boldsymbol{\nu} = (\boldsymbol{v}_{W,B};\, \boldsymbol{\omega}_{W,B};\, \dot{\mathbf{s}}) \in \mathbb{R}^{6+n}`.

        Returns:
            The generalized velocity in the active representation.
        """

        return (
            jnp.hstack([self.base_velocity(), self.joint_velocities()])
            .squeeze()
            .astype(float)
        )

    # ================
    # Store quantities
    # ================

    @functools.partial(jax.jit, static_argnames=["joint_names"])
    def reset_joint_positions(
        self,
        positions: jtp.VectorLike,
        model: js.model.JaxSimModel | None = None,
        joint_names: tuple[str, ...] | None = None,
    ) -> Self:
        """
        Reset the joint positions.

        Args:
            positions: The joint positions.
            model: The model to consider.
            joint_names: The names of the joints for which to set the positions.

        Returns:
            The updated `JaxSimModelData` object.

        Raises:
            ValueError: If the data is not compatible with the provided model.
        """

        positions = jnp.array(positions)

        def replace(s: jtp.VectorLike) -> JaxSimModelData:
            return self.replace(
                validate=True,
                state=self.state.replace(
                    physics_model=self.state.physics_model.replace(
                        joint_positions=jnp.atleast_1d(s.squeeze()).astype(float)
                    )
                ),
            )

        if model is None:
            return replace(s=positions)

        if not_tracing(positions) and not self.valid(model=model):
            msg = "The data object is not compatible with the provided model"
            raise ValueError(msg)

        joint_names = joint_names if joint_names is not None else model.joint_names()

        return replace(
            s=self.state.physics_model.joint_positions.at[
                js.joint.names_to_idxs(joint_names=joint_names, model=model)
            ].set(positions)
        )

    @functools.partial(jax.jit, static_argnames=["joint_names"])
    def reset_joint_velocities(
        self,
        velocities: jtp.VectorLike,
        model: js.model.JaxSimModel | None = None,
        joint_names: tuple[str, ...] | None = None,
    ) -> Self:
        """
        Reset the joint velocities.

        Args:
            velocities: The joint velocities.
            model: The model to consider.
            joint_names: The names of the joints for which to set the velocities.

        Returns:
            The updated `JaxSimModelData` object.

        Raises:
            ValueError: If the data is not compatible with the provided model.
        """

        velocities = jnp.array(velocities)

        def replace(ṡ: jtp.VectorLike) -> JaxSimModelData:
            return self.replace(
                validate=True,
                state=self.state.replace(
                    physics_model=self.state.physics_model.replace(
                        joint_velocities=jnp.atleast_1d(ṡ.squeeze()).astype(float)
                    )
                ),
            )

        if model is None:
            return replace(ṡ=velocities)

        if not_tracing(velocities) and not self.valid(model=model):
            msg = "The data object is not compatible with the provided model"
            raise ValueError(msg)

        joint_names = joint_names if joint_names is not None else model.joint_names()

        return replace(
            ṡ=self.state.physics_model.joint_velocities.at[
                js.joint.names_to_idxs(joint_names=joint_names, model=model)
            ].set(velocities)
        )

    @jax.jit
    def reset_base_position(self, base_position: jtp.VectorLike) -> Self:
        """
        Reset the base position.

        Args:
            base_position: The base position.

        Returns:
            The updated `JaxSimModelData` object.
        """

        base_position = jnp.array(base_position)

        return self.replace(
            validate=True,
            state=self.state.replace(
                physics_model=self.state.physics_model.replace(
                    base_position=jnp.atleast_1d(base_position.squeeze()).astype(float)
                )
            ),
        )

    @jax.jit
    def reset_base_quaternion(self, base_quaternion: jtp.VectorLike) -> Self:
        """
        Reset the base quaternion.

        Args:
            base_quaternion: The base orientation as a quaternion (wxyz).

        Returns:
            The updated `JaxSimModelData` object.
        """

        base_quaternion = jnp.array(base_quaternion)

        return self.replace(
            validate=True,
            state=self.state.replace(
                physics_model=self.state.physics_model.replace(
                    base_quaternion=jnp.atleast_1d(base_quaternion.squeeze()).astype(
                        float
                    )
                )
            ),
        )

    @jax.jit
    def reset_base_pose(self, base_pose: jtp.MatrixLike) -> Self:
        """
        Reset the base pose.

        Args:
            base_pose: The base pose as an SE(3) matrix.

        Returns:
            The updated `JaxSimModelData` object.
        """

        base_pose = jnp.array(base_pose)

        W_p_B = base_pose[0:3, 3]

        # jaxlie returns xyzw; the state stores wxyz.
        to_wxyz = np.array([3, 0, 1, 2])
        W_R_B: jaxlie.SO3 = jaxlie.SO3.from_matrix(base_pose[0:3, 0:3])  # noqa
        W_Q_B = W_R_B.as_quaternion_xyzw()[to_wxyz]

        return self.reset_base_position(base_position=W_p_B).reset_base_quaternion(
            base_quaternion=W_Q_B
        )

    @functools.partial(jax.jit, static_argnames=["velocity_representation"])
    def reset_base_linear_velocity(
        self,
        linear_velocity: jtp.VectorLike,
        velocity_representation: VelRepr | None = None,
    ) -> Self:
        """
        Reset the base linear velocity.

        The current base angular velocity is preserved.

        Args:
            linear_velocity: The base linear velocity as a 3D array.
            velocity_representation:
                The velocity representation in which the base velocity is expressed.
                If `None`, the active representation is considered.

        Returns:
            The updated `JaxSimModelData` object.
        """

        linear_velocity = jnp.array(linear_velocity)

        velocity_representation = (
            velocity_representation
            if velocity_representation is not None
            else self.velocity_representation
        )

        # The preserved angular part must be expressed in the same representation
        # as the new linear part; otherwise the two halves of the 6D velocity
        # would be inconsistent when it differs from the active representation.
        angular_velocity = self._base_velocity_in_representation(
            velocity_representation=velocity_representation
        )[3:6]

        return self.reset_base_velocity(
            base_velocity=jnp.hstack([linear_velocity.squeeze(), angular_velocity]),
            velocity_representation=velocity_representation,
        )

    @functools.partial(jax.jit, static_argnames=["velocity_representation"])
    def reset_base_angular_velocity(
        self,
        angular_velocity: jtp.VectorLike,
        velocity_representation: VelRepr | None = None,
    ) -> Self:
        """
        Reset the base angular velocity.

        The current base linear velocity is preserved.

        Args:
            angular_velocity: The base angular velocity as a 3D array.
            velocity_representation:
                The velocity representation in which the base velocity is expressed.
                If `None`, the active representation is considered.

        Returns:
            The updated `JaxSimModelData` object.
        """

        angular_velocity = jnp.array(angular_velocity)

        velocity_representation = (
            velocity_representation
            if velocity_representation is not None
            else self.velocity_representation
        )

        # The preserved linear part must be expressed in the same representation
        # as the new angular part; otherwise the two halves of the 6D velocity
        # would be inconsistent when it differs from the active representation.
        linear_velocity = self._base_velocity_in_representation(
            velocity_representation=velocity_representation
        )[0:3]

        return self.reset_base_velocity(
            base_velocity=jnp.hstack([linear_velocity, angular_velocity.squeeze()]),
            velocity_representation=velocity_representation,
        )

    @functools.partial(jax.jit, static_argnames=["velocity_representation"])
    def reset_base_velocity(
        self,
        base_velocity: jtp.VectorLike,
        velocity_representation: VelRepr | None = None,
    ) -> Self:
        """
        Reset the base 6D velocity.

        Args:
            base_velocity:
                The base 6D velocity (linear; angular) expressed in
                `velocity_representation`.
            velocity_representation:
                The velocity representation in which the base velocity is expressed.
                If `None`, the active representation is considered.

        Returns:
            The updated `JaxSimModelData` object.
        """

        base_velocity = jnp.array(base_velocity)

        velocity_representation = (
            velocity_representation
            if velocity_representation is not None
            else self.velocity_representation
        )

        # The state stores the base velocity in inertial representation.
        W_v_WB = self.other_representation_to_inertial(
            array=jnp.atleast_1d(base_velocity.squeeze()).astype(float),
            other_representation=velocity_representation,
            transform=self.base_transform(),
            is_force=False,
        )

        return self.replace(
            validate=True,
            state=self.state.replace(
                physics_model=self.state.physics_model.replace(
                    base_linear_velocity=W_v_WB[0:3].squeeze().astype(float),
                    base_angular_velocity=W_v_WB[3:6].squeeze().astype(float),
                )
            ),
        )
711
+
712
+
713
def random_model_data(
    model: js.model.JaxSimModel,
    *,
    key: jax.Array | None = None,
    velocity_representation: VelRepr | None = None,
    base_pos_bounds: tuple[
        jtp.FloatLike | Sequence[jtp.FloatLike],
        jtp.FloatLike | Sequence[jtp.FloatLike],
    ] = ((-1, -1, 0.5), 1.0),
    base_vel_lin_bounds: tuple[
        jtp.FloatLike | Sequence[jtp.FloatLike],
        jtp.FloatLike | Sequence[jtp.FloatLike],
    ] = (-1.0, 1.0),
    base_vel_ang_bounds: tuple[
        jtp.FloatLike | Sequence[jtp.FloatLike],
        jtp.FloatLike | Sequence[jtp.FloatLike],
    ] = (-1.0, 1.0),
    joint_vel_bounds: tuple[
        jtp.FloatLike | Sequence[jtp.FloatLike],
        jtp.FloatLike | Sequence[jtp.FloatLike],
    ] = (-1.0, 1.0),
    standard_gravity_bounds: tuple[jtp.FloatLike, jtp.FloatLike] = (
        jaxsim.math.StandardGravity,
        jaxsim.math.StandardGravity,
    ),
) -> JaxSimModelData:
    """
    Randomly generate a `JaxSimModelData` object.

    Args:
        model: The target model for the random data.
        key: The random key. If `None`, `jax.random.PRNGKey(seed=0)` is used.
        velocity_representation: The velocity representation to use.
        base_pos_bounds: The (min, max) bounds for the base position.
        base_vel_lin_bounds: The (min, max) bounds for the base linear velocity.
        base_vel_ang_bounds: The (min, max) bounds for the base angular velocity.
        joint_vel_bounds: The (min, max) bounds for the joint velocities.
        standard_gravity_bounds: The (min, max) bounds for the standard gravity.

    Returns:
        A `JaxSimModelData` object with random data.
    """

    key = key if key is not None else jax.random.PRNGKey(seed=0)
    k1, k2, k3, k4, k5, k6, k7 = jax.random.split(key, num=7)

    # Convert every bound to a float array so that both scalars and sequences
    # (as allowed by the annotations) are accepted by jax.random.uniform.
    p_min = jnp.array(base_pos_bounds[0], dtype=float)
    p_max = jnp.array(base_pos_bounds[1], dtype=float)
    v_min = jnp.array(base_vel_lin_bounds[0], dtype=float)
    v_max = jnp.array(base_vel_lin_bounds[1], dtype=float)
    ω_min = jnp.array(base_vel_ang_bounds[0], dtype=float)
    ω_max = jnp.array(base_vel_ang_bounds[1], dtype=float)
    ṡ_min = jnp.array(joint_vel_bounds[0], dtype=float)
    ṡ_max = jnp.array(joint_vel_bounds[1], dtype=float)

    random_data = JaxSimModelData.zero(
        model=model,
        **(
            dict(velocity_representation=velocity_representation)
            if velocity_representation is not None
            else {}
        ),
    )

    with random_data.mutable_context(
        mutability=Mutability.MUTABLE, restore_after_exception=False
    ):

        physics_model_state = random_data.state.physics_model

        physics_model_state.base_position = jax.random.uniform(
            key=k1, shape=(3,), minval=p_min, maxval=p_max
        )

        # jaxlie returns xyzw; the state stores wxyz.
        physics_model_state.base_quaternion = jaxlie.SO3.from_rpy_radians(
            *jax.random.uniform(key=k2, shape=(3,), minval=0, maxval=2 * jnp.pi)
        ).as_quaternion_xyzw()[np.array([3, 0, 1, 2])]

        if model.number_of_joints() > 0:
            physics_model_state.joint_positions = js.joint.random_joint_positions(
                model=model, key=k3
            )

            physics_model_state.joint_velocities = jax.random.uniform(
                key=k4, shape=(model.dofs(),), minval=ṡ_min, maxval=ṡ_max
            )

        # Fixed-base models keep the zero base velocity set by `zero`.
        if model.floating_base():
            physics_model_state.base_linear_velocity = jax.random.uniform(
                key=k5, shape=(3,), minval=v_min, maxval=v_max
            )

            physics_model_state.base_angular_velocity = jax.random.uniform(
                key=k6, shape=(3,), minval=ω_min, maxval=ω_max
            )

        random_data.gravity = (
            jnp.zeros(3, dtype=random_data.gravity.dtype)
            .at[2]
            .set(
                -jax.random.uniform(
                    key=k7,
                    shape=(),
                    minval=standard_gravity_bounds[0],
                    maxval=standard_gravity_bounds[1],
                )
            )
        )

    return random_data