jaxsim 0.6.2.dev2__py3-none-any.whl → 0.6.2.dev105__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (46) hide show
  1. jaxsim/__init__.py +1 -1
  2. jaxsim/_version.py +2 -2
  3. jaxsim/api/__init__.py +3 -1
  4. jaxsim/api/actuation_model.py +96 -0
  5. jaxsim/api/com.py +8 -8
  6. jaxsim/api/contact.py +15 -255
  7. jaxsim/api/contact_model.py +101 -0
  8. jaxsim/api/data.py +258 -556
  9. jaxsim/api/frame.py +7 -7
  10. jaxsim/api/integrators.py +76 -0
  11. jaxsim/api/kin_dyn_parameters.py +41 -58
  12. jaxsim/api/link.py +7 -7
  13. jaxsim/api/model.py +190 -453
  14. jaxsim/api/ode.py +34 -338
  15. jaxsim/api/references.py +2 -2
  16. jaxsim/exceptions.py +2 -2
  17. jaxsim/math/__init__.py +4 -3
  18. jaxsim/math/joint_model.py +17 -107
  19. jaxsim/mujoco/model.py +1 -1
  20. jaxsim/mujoco/utils.py +2 -2
  21. jaxsim/parsers/kinematic_graph.py +1 -3
  22. jaxsim/rbda/aba.py +7 -4
  23. jaxsim/rbda/collidable_points.py +7 -98
  24. jaxsim/rbda/contacts/__init__.py +2 -10
  25. jaxsim/rbda/contacts/common.py +0 -138
  26. jaxsim/rbda/contacts/relaxed_rigid.py +156 -11
  27. jaxsim/rbda/crba.py +5 -2
  28. jaxsim/rbda/forward_kinematics.py +37 -12
  29. jaxsim/rbda/jacobian.py +15 -6
  30. jaxsim/rbda/rnea.py +7 -4
  31. jaxsim/rbda/utils.py +3 -3
  32. jaxsim/utils/jaxsim_dataclass.py +5 -1
  33. {jaxsim-0.6.2.dev2.dist-info → jaxsim-0.6.2.dev105.dist-info}/METADATA +6 -8
  34. jaxsim-0.6.2.dev105.dist-info/RECORD +69 -0
  35. jaxsim/api/ode_data.py +0 -401
  36. jaxsim/integrators/__init__.py +0 -2
  37. jaxsim/integrators/common.py +0 -592
  38. jaxsim/integrators/fixed_step.py +0 -153
  39. jaxsim/integrators/variable_step.py +0 -706
  40. jaxsim/rbda/contacts/rigid.py +0 -462
  41. jaxsim/rbda/contacts/soft.py +0 -480
  42. jaxsim/rbda/contacts/visco_elastic.py +0 -1066
  43. jaxsim-0.6.2.dev2.dist-info/RECORD +0 -74
  44. {jaxsim-0.6.2.dev2.dist-info → jaxsim-0.6.2.dev105.dist-info}/LICENSE +0 -0
  45. {jaxsim-0.6.2.dev2.dist-info → jaxsim-0.6.2.dev105.dist-info}/WHEEL +0 -0
  46. {jaxsim-0.6.2.dev2.dist-info → jaxsim-0.6.2.dev105.dist-info}/top_level.txt +0 -0
jaxsim/api/data.py CHANGED
@@ -4,6 +4,11 @@ import dataclasses
4
4
  import functools
5
5
  from collections.abc import Sequence
6
6
 
7
+ try:
8
+ from typing import override
9
+ except ImportError:
10
+ from typing_extensions import override
11
+
7
12
  import jax
8
13
  import jax.numpy as jnp
9
14
  import jax.scipy.spatial.transform
@@ -13,12 +18,9 @@ import jaxsim.api as js
13
18
  import jaxsim.math
14
19
  import jaxsim.rbda
15
20
  import jaxsim.typing as jtp
16
- from jaxsim.utils import Mutability
17
- from jaxsim.utils.tracing import not_tracing
18
21
 
19
22
  from . import common
20
23
  from .common import VelRepr
21
- from .ode_data import ODEState
22
24
 
23
25
  try:
24
26
  from typing import Self
@@ -29,72 +31,38 @@ except ImportError:
29
31
  @jax_dataclasses.pytree_dataclass
30
32
  class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
31
33
  """
32
- Class containing the data of a `JaxSimModel` object.
34
+ Class storing the state of the physics model dynamics.
35
+
36
+ Attributes:
37
+ joint_positions: The vector of joint positions.
38
+ joint_velocities: The vector of joint velocities.
39
+ base_position: The 3D position of the base link.
40
+ base_quaternion: The quaternion defining the orientation of the base link.
41
+ base_linear_velocity:
42
+ The linear velocity of the base link in inertial-fixed representation.
43
+ base_angular_velocity:
44
+ The angular velocity of the base link in inertial-fixed representation.
45
+ base_transform: The base transform.
46
+ joint_transforms: The joint transforms.
47
+ link_transforms: The link transforms.
48
+ link_velocities: The link velocities in inertial-fixed representation.
33
49
  """
34
50
 
35
- state: ODEState
36
-
37
- gravity: jtp.Vector
38
-
39
- contacts_params: jaxsim.rbda.contacts.ContactsParams = dataclasses.field(repr=False)
40
-
41
- def __hash__(self) -> int:
42
-
43
- from jaxsim.utils.wrappers import HashedNumpyArray
44
-
45
- return hash(
46
- (
47
- hash(self.state),
48
- HashedNumpyArray.hash_of_array(self.gravity),
49
- hash(self.contacts_params),
50
- )
51
- )
52
-
53
- def __eq__(self, other: JaxSimModelData) -> bool:
54
-
55
- if not isinstance(other, JaxSimModelData):
56
- return False
57
-
58
- return hash(self) == hash(other)
59
-
60
- def valid(self, model: js.model.JaxSimModel | None = None) -> bool:
61
- """
62
- Check if the current state is valid for the given model.
63
-
64
- Args:
65
- model: The model to check against.
66
-
67
- Returns:
68
- `True` if the current state is valid for the given model, `False` otherwise.
69
- """
51
+ # Joint state
52
+ _joint_positions: jtp.Vector
53
+ _joint_velocities: jtp.Vector
70
54
 
71
- valid = True
72
- valid = valid and self.standard_gravity() > 0
55
+ # Base state
56
+ _base_quaternion: jtp.Vector
57
+ _base_linear_velocity: jtp.Vector
58
+ _base_angular_velocity: jtp.Vector
59
+ _base_position: jtp.Vector
73
60
 
74
- if model is not None:
75
- valid = valid and self.state.valid(model=model)
76
-
77
- return valid
78
-
79
- @staticmethod
80
- def zero(
81
- model: js.model.JaxSimModel,
82
- velocity_representation: VelRepr = VelRepr.Inertial,
83
- ) -> JaxSimModelData:
84
- """
85
- Create a `JaxSimModelData` object with zero state.
86
-
87
- Args:
88
- model: The model for which to create the zero state.
89
- velocity_representation: The velocity representation to use.
90
-
91
- Returns:
92
- A `JaxSimModelData` object with zero state.
93
- """
94
-
95
- return JaxSimModelData.build(
96
- model=model, velocity_representation=velocity_representation
97
- )
61
+ # Cached computations.
62
+ _base_transform: jtp.Matrix = dataclasses.field(repr=False, default=None)
63
+ _joint_transforms: jtp.Matrix = dataclasses.field(repr=False, default=None)
64
+ _link_transforms: jtp.Matrix = dataclasses.field(repr=False, default=None)
65
+ _link_velocities: jtp.Matrix = dataclasses.field(repr=False, default=None)
98
66
 
99
67
  @staticmethod
100
68
  def build(
@@ -105,10 +73,7 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
105
73
  base_linear_velocity: jtp.VectorLike | None = None,
106
74
  base_angular_velocity: jtp.VectorLike | None = None,
107
75
  joint_velocities: jtp.VectorLike | None = None,
108
- standard_gravity: jtp.FloatLike = jaxsim.math.StandardGravity,
109
- contacts_params: jaxsim.rbda.contacts.ContactsParams | None = None,
110
- velocity_representation: VelRepr = VelRepr.Inertial,
111
- extended_ode_state: dict[str, jtp.PyTree] | None = None,
76
+ velocity_representation: VelRepr = VelRepr.Mixed,
112
77
  ) -> JaxSimModelData:
113
78
  """
114
79
  Create a `JaxSimModelData` object with the given state.
@@ -123,13 +88,7 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
123
88
  base_angular_velocity:
124
89
  The base angular velocity in the selected representation.
125
90
  joint_velocities: The joint velocities.
126
- standard_gravity: The standard gravity constant.
127
- contacts_params: The parameters of the soft contacts.
128
- velocity_representation: The velocity representation to use.
129
- extended_ode_state:
130
- Additional user-defined state variables that are not part of the
131
- standard `ODEState` object. Useful to extend the system dynamics
132
- considered by default in JaxSim.
91
+ velocity_representation: The velocity representation to use. It defaults to mixed if not provided.
133
92
 
134
93
  Returns:
135
94
  A `JaxSimModelData` initialized with the given state.
@@ -163,8 +122,6 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
163
122
  dtype=float,
164
123
  ).squeeze()
165
124
 
166
- gravity = jnp.zeros(3).at[2].set(-standard_gravity)
167
-
168
125
  joint_positions = jnp.atleast_1d(
169
126
  jnp.array(
170
127
  (
@@ -191,166 +148,104 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
191
148
  translation=base_position, quaternion=base_quaternion
192
149
  )
193
150
 
194
- v_WB = JaxSimModelData.other_representation_to_inertial(
151
+ W_v_WB = JaxSimModelData.other_representation_to_inertial(
195
152
  array=jnp.hstack([base_linear_velocity, base_angular_velocity]),
196
153
  other_representation=velocity_representation,
197
154
  transform=W_H_B,
198
155
  is_force=False,
199
156
  ).astype(float)
200
157
 
201
- ode_state = ODEState.build_from_jaxsim_model(
202
- model=model,
203
- base_position=base_position,
204
- base_quaternion=base_quaternion,
205
- joint_positions=joint_positions,
206
- base_linear_velocity=v_WB[0:3],
207
- base_angular_velocity=v_WB[3:6],
208
- joint_velocities=joint_velocities,
209
- # Unpack all the additional ODE states. If the contact model requires an
210
- # additional state that is not explicitly passed to this builder, ODEState
211
- # automatically populates that state with zeroed variables.
212
- # This is not true for any other custom state that the user might want to
213
- # pass to the integrator.
214
- **(extended_ode_state if extended_ode_state else {}),
158
+ joint_transforms = model.kin_dyn_parameters.joint_transforms(
159
+ joint_positions=joint_positions, base_transform=W_H_B
215
160
  )
216
161
 
217
- if not ode_state.valid(model=model):
218
- raise ValueError(ode_state)
162
+ link_transforms, link_velocities_inertial = (
163
+ jaxsim.rbda.forward_kinematics_model(
164
+ model=model,
165
+ base_position=base_position,
166
+ base_quaternion=base_quaternion,
167
+ joint_positions=joint_positions,
168
+ base_linear_velocity_inertial=W_v_WB[0:3],
169
+ base_angular_velocity_inertial=W_v_WB[3:6],
170
+ joint_velocities=joint_velocities,
171
+ )
172
+ )
219
173
 
220
- if contacts_params is None:
174
+ model_data = JaxSimModelData(
175
+ velocity_representation=velocity_representation,
176
+ _base_quaternion=base_quaternion,
177
+ _base_position=base_position,
178
+ _joint_positions=joint_positions,
179
+ _base_linear_velocity=W_v_WB[0:3],
180
+ _base_angular_velocity=W_v_WB[3:6],
181
+ _joint_velocities=joint_velocities,
182
+ _base_transform=W_H_B,
183
+ _joint_transforms=joint_transforms,
184
+ _link_transforms=link_transforms,
185
+ _link_velocities=link_velocities_inertial,
186
+ )
221
187
 
222
- if isinstance(
223
- model.contact_model,
224
- jaxsim.rbda.contacts.SoftContacts
225
- | jaxsim.rbda.contacts.ViscoElasticContacts,
226
- ):
188
+ if not model_data.valid(model=model):
189
+ raise ValueError(
190
+ "The built state is not compatible with the model.", model_data
191
+ )
227
192
 
228
- contacts_params = js.contact.estimate_good_contact_parameters(
229
- model=model, standard_gravity=standard_gravity
230
- )
193
+ return model_data
231
194
 
232
- else:
233
- contacts_params = model.contact_model._parameters_class()
195
+ @staticmethod
196
+ def zero(
197
+ model: js.model.JaxSimModel,
198
+ velocity_representation: VelRepr = VelRepr.Mixed,
199
+ ) -> JaxSimModelData:
200
+ """
201
+ Create a `JaxSimModelData` object with zero state.
234
202
 
235
- return JaxSimModelData(
236
- state=ode_state,
237
- gravity=gravity,
238
- contacts_params=contacts_params,
239
- velocity_representation=velocity_representation,
203
+ Args:
204
+ model: The model for which to create the state.
205
+ velocity_representation: The velocity representation to use. It defaults to mixed if not provided.
206
+
207
+ Returns:
208
+ A `JaxSimModelData` initialized with zero state.
209
+ """
210
+ return JaxSimModelData.build(
211
+ model=model, velocity_representation=velocity_representation
240
212
  )
241
213
 
242
214
  # ==================
243
215
  # Extract quantities
244
216
  # ==================
245
217
 
246
- def standard_gravity(self) -> jtp.Float:
218
+ @property
219
+ def joint_positions(self) -> jtp.Vector:
247
220
  """
248
- Get the standard gravity constant.
221
+ Get the joint positions.
249
222
 
250
223
  Returns:
251
- The standard gravity constant.
224
+ The joint positions.
252
225
  """
226
+ return self._joint_positions
253
227
 
254
- return -self.gravity[2]
255
-
256
- @js.common.named_scope
257
- @functools.partial(jax.jit, static_argnames=["joint_names"])
258
- def joint_positions(
259
- self,
260
- model: js.model.JaxSimModel | None = None,
261
- joint_names: tuple[str, ...] | None = None,
262
- ) -> jtp.Vector:
228
+ @property
229
+ def joint_velocities(self) -> jtp.Vector:
263
230
  """
264
- Get the joint positions.
265
-
266
- Args:
267
- model: The model to consider.
268
- joint_names:
269
- The names of the joints for which to get the positions. If `None`,
270
- the positions of all joints are returned.
231
+ Get the joint velocities.
271
232
 
272
233
  Returns:
273
- If no model and no joint names are provided, the joint positions as a
274
- `(DoFs,)` vector corresponding to the serialization of the original
275
- model used to build the data object.
276
- If a model is provided and no joint names are provided, the joint positions
277
- as a `(DoFs,)` vector corresponding to the serialization of the
278
- provided model.
279
- If a model and joint names are provided, the joint positions as a
280
- `(len(joint_names),)` vector corresponding to the serialization of
281
- the passed joint names vector.
234
+ The joint velocities.
282
235
  """
236
+ return self._joint_velocities
283
237
 
284
- if model is None:
285
- if joint_names is not None:
286
- raise ValueError("Joint names cannot be provided without a model")
287
-
288
- return self.state.physics_model.joint_positions
289
-
290
- if not_tracing(self.state.physics_model.joint_positions) and not self.valid(
291
- model=model
292
- ):
293
- msg = "The data object is not compatible with the provided model"
294
- raise ValueError(msg)
295
-
296
- joint_idxs = (
297
- js.joint.names_to_idxs(joint_names=joint_names, model=model)
298
- if joint_names is not None
299
- else jnp.arange(model.number_of_joints())
300
- )
301
-
302
- return self.state.physics_model.joint_positions[joint_idxs]
303
-
304
- @js.common.named_scope
305
- @functools.partial(jax.jit, static_argnames=["joint_names"])
306
- def joint_velocities(
307
- self,
308
- model: js.model.JaxSimModel | None = None,
309
- joint_names: tuple[str, ...] | None = None,
310
- ) -> jtp.Vector:
238
+ @property
239
+ def base_quaternion(self) -> jtp.Vector:
311
240
  """
312
- Get the joint velocities.
313
-
314
- Args:
315
- model: The model to consider.
316
- joint_names:
317
- The names of the joints for which to get the velocities. If `None`,
318
- the velocities of all joints are returned.
241
+ Get the base quaternion.
319
242
 
320
243
  Returns:
321
- If no model and no joint names are provided, the joint velocities as a
322
- `(DoFs,)` vector corresponding to the serialization of the original
323
- model used to build the data object.
324
- If a model is provided and no joint names are provided, the joint velocities
325
- as a `(DoFs,)` vector corresponding to the serialization of the
326
- provided model.
327
- If a model and joint names are provided, the joint velocities as a
328
- `(len(joint_names),)` vector corresponding to the serialization of
329
- the passed joint names vector.
244
+ The base quaternion.
330
245
  """
246
+ return self._base_quaternion
331
247
 
332
- if model is None:
333
- if joint_names is not None:
334
- raise ValueError("Joint names cannot be provided without a model")
335
-
336
- return self.state.physics_model.joint_velocities
337
-
338
- if not_tracing(self.state.physics_model.joint_velocities) and not self.valid(
339
- model=model
340
- ):
341
- msg = "The data object is not compatible with the provided model"
342
- raise ValueError(msg)
343
-
344
- joint_idxs = (
345
- js.joint.names_to_idxs(joint_names=joint_names, model=model)
346
- if joint_names is not None
347
- else jnp.arange(model.number_of_joints())
348
- )
349
-
350
- return self.state.physics_model.joint_velocities[joint_idxs]
351
-
352
- @js.common.named_scope
353
- @jax.jit
248
+ @property
354
249
  def base_position(self) -> jtp.Vector:
355
250
  """
356
251
  Get the base position.
@@ -358,24 +253,19 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
358
253
  Returns:
359
254
  The base position.
360
255
  """
256
+ return self._base_position
361
257
 
362
- return self.state.physics_model.base_position.squeeze()
363
-
364
- @js.common.named_scope
365
- @functools.partial(jax.jit, static_argnames=["dcm"])
366
- def base_orientation(self, dcm: jtp.BoolLike = False) -> jtp.Vector | jtp.Matrix:
258
+ @property
259
+ def base_orientation(self) -> jtp.Matrix:
367
260
  """
368
261
  Get the base orientation.
369
262
 
370
- Args:
371
- dcm: Whether to return the orientation as a SO(3) matrix or quaternion.
372
-
373
263
  Returns:
374
264
  The base orientation.
375
265
  """
376
266
 
377
267
  # Extract the base quaternion.
378
- W_Q_B = self.state.physics_model.base_quaternion.squeeze()
268
+ W_Q_B = self.base_quaternion.squeeze()
379
269
 
380
270
  # Always normalize the quaternion to avoid numerical issues.
381
271
  # If the active scheme does not integrate the quaternion on its manifold,
@@ -384,33 +274,9 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
384
274
  # stored in the state is a unit quaternion.
385
275
  norm = jaxsim.math.safe_norm(W_Q_B)
386
276
  W_Q_B = W_Q_B / (norm + jnp.finfo(float).eps * (norm == 0))
277
+ return W_Q_B
387
278
 
388
- return (W_Q_B if not dcm else jaxsim.math.Quaternion.to_dcm(W_Q_B)).astype(
389
- float
390
- )
391
-
392
- @js.common.named_scope
393
- @jax.jit
394
- def base_transform(self) -> jtp.Matrix:
395
- """
396
- Get the base transform.
397
-
398
- Returns:
399
- The base transform as an SE(3) matrix.
400
- """
401
-
402
- W_R_B = self.base_orientation(dcm=True)
403
- W_p_B = jnp.vstack(self.base_position())
404
-
405
- return jnp.vstack(
406
- [
407
- jnp.block([W_R_B, W_p_B]),
408
- jnp.array([0, 0, 0, 1]),
409
- ]
410
- )
411
-
412
- @js.common.named_scope
413
- @jax.jit
279
+ @property
414
280
  def base_velocity(self) -> jtp.Vector:
415
281
  """
416
282
  Get the base 6D velocity.
@@ -421,12 +287,12 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
421
287
 
422
288
  W_v_WB = jnp.hstack(
423
289
  [
424
- self.state.physics_model.base_linear_velocity,
425
- self.state.physics_model.base_angular_velocity,
290
+ self._base_linear_velocity,
291
+ self._base_angular_velocity,
426
292
  ]
427
293
  )
428
294
 
429
- W_H_B = self.base_transform()
295
+ W_H_B = self._base_transform
430
296
 
431
297
  return (
432
298
  JaxSimModelData.inertial_to_other_representation(
@@ -439,8 +305,7 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
439
305
  .astype(float)
440
306
  )
441
307
 
442
- @js.common.named_scope
443
- @jax.jit
308
+ @property
444
309
  def generalized_position(self) -> tuple[jtp.Matrix, jtp.Vector]:
445
310
  r"""
446
311
  Get the generalized position
@@ -450,10 +315,9 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
450
315
  A tuple containing the base transform and the joint positions.
451
316
  """
452
317
 
453
- return self.base_transform(), self.joint_positions()
318
+ return self._base_transform, self.joint_positions
454
319
 
455
- @js.common.named_scope
456
- @jax.jit
320
+ @property
457
321
  def generalized_velocity(self) -> jtp.Vector:
458
322
  r"""
459
323
  Get the generalized velocity.
@@ -465,136 +329,24 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
465
329
  """
466
330
 
467
331
  return (
468
- jnp.hstack([self.base_velocity(), self.joint_velocities()])
332
+ jnp.hstack([self.base_velocity, self.joint_velocities])
469
333
  .squeeze()
470
334
  .astype(float)
471
335
  )
472
336
 
473
- # ================
474
- # Store quantities
475
- # ================
476
-
477
- @js.common.named_scope
478
- @functools.partial(jax.jit, static_argnames=["joint_names"])
479
- def reset_joint_positions(
480
- self,
481
- positions: jtp.VectorLike,
482
- model: js.model.JaxSimModel | None = None,
483
- joint_names: tuple[str, ...] | None = None,
484
- ) -> Self:
485
- """
486
- Reset the joint positions.
487
-
488
- Args:
489
- positions: The joint positions.
490
- model: The model to consider.
491
- joint_names: The names of the joints for which to set the positions.
492
-
493
- Returns:
494
- The updated `JaxSimModelData` object.
495
- """
496
-
497
- positions = jnp.array(positions)
498
-
499
- def replace(s: jtp.VectorLike) -> JaxSimModelData:
500
- return self.replace(
501
- validate=True,
502
- state=self.state.replace(
503
- physics_model=self.state.physics_model.replace(
504
- joint_positions=jnp.atleast_1d(s.squeeze()).astype(float)
505
- )
506
- ),
507
- )
508
-
509
- if model is None:
510
- return replace(s=positions)
511
-
512
- if not_tracing(positions) and not self.valid(model=model):
513
- msg = "The data object is not compatible with the provided model"
514
- raise ValueError(msg)
515
-
516
- joint_idxs = (
517
- js.joint.names_to_idxs(joint_names=joint_names, model=model)
518
- if joint_names is not None
519
- else jnp.arange(model.number_of_joints())
520
- )
521
-
522
- return replace(
523
- s=self.state.physics_model.joint_positions.at[joint_idxs].set(positions)
524
- )
525
-
526
- @js.common.named_scope
527
- @functools.partial(jax.jit, static_argnames=["joint_names"])
528
- def reset_joint_velocities(
529
- self,
530
- velocities: jtp.VectorLike,
531
- model: js.model.JaxSimModel | None = None,
532
- joint_names: tuple[str, ...] | None = None,
533
- ) -> Self:
534
- """
535
- Reset the joint velocities.
536
-
537
- Args:
538
- velocities: The joint velocities.
539
- model: The model to consider.
540
- joint_names: The names of the joints for which to set the velocities.
541
-
542
- Returns:
543
- The updated `JaxSimModelData` object.
544
- """
545
-
546
- velocities = jnp.array(velocities)
547
-
548
- def replace(ṡ: jtp.VectorLike) -> JaxSimModelData:
549
- return self.replace(
550
- validate=True,
551
- state=self.state.replace(
552
- physics_model=self.state.physics_model.replace(
553
- joint_velocities=jnp.atleast_1d(ṡ.squeeze()).astype(float)
554
- )
555
- ),
556
- )
557
-
558
- if model is None:
559
- return replace(ṡ=velocities)
560
-
561
- if not_tracing(velocities) and not self.valid(model=model):
562
- msg = "The data object is not compatible with the provided model"
563
- raise ValueError(msg)
564
-
565
- joint_idxs = (
566
- js.joint.names_to_idxs(joint_names=joint_names, model=model)
567
- if joint_names is not None
568
- else jnp.arange(model.number_of_joints())
569
- )
570
-
571
- return replace(
572
- ṡ=self.state.physics_model.joint_velocities.at[joint_idxs].set(velocities)
573
- )
574
-
575
- @js.common.named_scope
576
- @jax.jit
577
- def reset_base_position(self, base_position: jtp.VectorLike) -> Self:
337
+ @property
338
+ def base_transform(self) -> jtp.Matrix:
578
339
  """
579
- Reset the base position.
580
-
581
- Args:
582
- base_position: The base position.
340
+ Get the base transform.
583
341
 
584
342
  Returns:
585
- The updated `JaxSimModelData` object.
343
+ The base transform.
586
344
  """
345
+ return self._base_transform
587
346
 
588
- base_position = jnp.array(base_position)
589
-
590
- return self.replace(
591
- validate=True,
592
- state=self.state.replace(
593
- physics_model=self.state.physics_model.replace(
594
- base_position=jnp.atleast_1d(base_position.squeeze()).astype(float)
595
- )
596
- ),
597
- )
347
+ # ================
348
+ # Store quantities
349
+ # ================
598
350
 
599
351
  @js.common.named_scope
600
352
  @jax.jit
@@ -614,12 +366,7 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
614
366
  norm = jaxsim.math.safe_norm(W_Q_B)
615
367
  W_Q_B = W_Q_B / (norm + jnp.finfo(float).eps * (norm == 0))
616
368
 
617
- return self.replace(
618
- validate=True,
619
- state=self.state.replace(
620
- physics_model=self.state.physics_model.replace(base_quaternion=W_Q_B)
621
- ),
622
- )
369
+ return self.replace(validate=True, base_quaternion=W_Q_B)
623
370
 
624
371
  @js.common.named_scope
625
372
  @jax.jit
@@ -635,123 +382,116 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
635
382
  """
636
383
 
637
384
  base_pose = jnp.array(base_pose)
638
-
639
385
  W_p_B = base_pose[0:3, 3]
640
-
641
386
  W_Q_B = jaxsim.math.Quaternion.from_dcm(dcm=base_pose[0:3, 0:3])
642
-
643
- return self.reset_base_position(base_position=W_p_B).reset_base_quaternion(
644
- base_quaternion=W_Q_B
387
+ return self.replace(
388
+ base_position=W_p_B,
389
+ base_quaternion=W_Q_B,
645
390
  )
646
391
 
647
- @js.common.named_scope
648
- @functools.partial(jax.jit, static_argnames=["velocity_representation"])
649
- def reset_base_linear_velocity(
392
+ @override
393
+ def replace(
650
394
  self,
651
- linear_velocity: jtp.VectorLike,
652
- velocity_representation: VelRepr | None = None,
395
+ model: js.model.JaxSimModel,
396
+ joint_positions: jtp.Vector | None = None,
397
+ joint_velocities: jtp.Vector | None = None,
398
+ base_quaternion: jtp.Vector | None = None,
399
+ base_linear_velocity: jtp.Vector | None = None,
400
+ base_angular_velocity: jtp.Vector | None = None,
401
+ base_position: jtp.Vector | None = None,
402
+ validate: bool = False,
653
403
  ) -> Self:
654
404
  """
655
- Reset the base linear velocity.
656
-
657
- Args:
658
- linear_velocity: The base linear velocity as a 3D array.
659
- velocity_representation:
660
- The velocity representation in which the base velocity is expressed.
661
- If `None`, the active representation is considered.
662
-
663
- Returns:
664
- The updated `JaxSimModelData` object.
405
+ Replace the attributes of the `JaxSimModelData` object.
665
406
  """
407
+ if joint_positions is None:
408
+ joint_positions = self.joint_positions
409
+ if joint_velocities is None:
410
+ joint_velocities = self.joint_velocities
411
+ if base_quaternion is None:
412
+ base_quaternion = self.base_quaternion
413
+ if base_position is None:
414
+ base_position = self.base_position
666
415
 
667
- linear_velocity = jnp.array(linear_velocity)
416
+ joint_positions = jnp.atleast_1d(joint_positions.squeeze()).astype(float)
417
+ joint_velocities = jnp.atleast_1d(joint_velocities.squeeze()).astype(float)
418
+ base_quaternion = jnp.atleast_1d(base_quaternion.squeeze()).astype(float)
419
+ base_position = jnp.atleast_1d(base_position.squeeze()).astype(float)
668
420
 
669
- return self.reset_base_velocity(
670
- base_velocity=jnp.hstack(
671
- [
672
- linear_velocity.squeeze(),
673
- self.base_velocity()[3:6],
674
- ]
675
- ),
676
- velocity_representation=velocity_representation,
421
+ base_transform = jaxsim.math.Transform.from_quaternion_and_translation(
422
+ translation=base_position, quaternion=base_quaternion
423
+ )
424
+ joint_transforms = model.kin_dyn_parameters.joint_transforms(
425
+ joint_positions=joint_positions, base_transform=base_transform
677
426
  )
678
427
 
679
- @js.common.named_scope
680
- @functools.partial(jax.jit, static_argnames=["velocity_representation"])
681
- def reset_base_angular_velocity(
682
- self,
683
- angular_velocity: jtp.VectorLike,
684
- velocity_representation: VelRepr | None = None,
685
- ) -> Self:
686
- """
687
- Reset the base angular velocity.
688
-
689
- Args:
690
- angular_velocity: The base angular velocity as a 3D array.
691
- velocity_representation:
692
- The velocity representation in which the base velocity is expressed.
693
- If `None`, the active representation is considered.
694
-
695
- Returns:
696
- The updated `JaxSimModelData` object.
697
- """
428
+ if base_linear_velocity is None and base_angular_velocity is None:
429
+ base_linear_velocity = self._base_linear_velocity
430
+ base_angular_velocity = self._base_angular_velocity
431
+ else:
432
+ if base_linear_velocity is None:
433
+ base_linear_velocity = self.base_velocity[:3]
434
+ if base_angular_velocity is None:
435
+ base_angular_velocity = self.base_velocity[3:]
436
+ base_linear_velocity = jnp.atleast_1d(base_linear_velocity.squeeze())
437
+ base_angular_velocity = jnp.atleast_1d(base_angular_velocity.squeeze())
438
+ W_v_WB = JaxSimModelData.other_representation_to_inertial(
439
+ array=jnp.hstack([base_linear_velocity, base_angular_velocity]),
440
+ other_representation=self.velocity_representation,
441
+ transform=base_transform,
442
+ is_force=False,
443
+ ).astype(float)
444
+ base_linear_velocity, base_angular_velocity = W_v_WB[:3], W_v_WB[3:]
698
445
 
699
- angular_velocity = jnp.array(angular_velocity)
446
+ link_transforms, link_velocities = jaxsim.rbda.forward_kinematics_model(
447
+ model=model,
448
+ base_position=base_position,
449
+ base_quaternion=base_quaternion,
450
+ joint_positions=joint_positions,
451
+ joint_velocities=joint_velocities,
452
+ base_linear_velocity_inertial=base_linear_velocity,
453
+ base_angular_velocity_inertial=base_angular_velocity,
454
+ )
700
455
 
701
- return self.reset_base_velocity(
702
- base_velocity=jnp.hstack(
703
- [
704
- self.base_velocity()[0:3],
705
- angular_velocity.squeeze(),
706
- ]
707
- ),
708
- velocity_representation=velocity_representation,
456
+ return super().replace(
457
+ _joint_positions=joint_positions,
458
+ _joint_velocities=joint_velocities,
459
+ _base_quaternion=base_quaternion,
460
+ _base_linear_velocity=base_linear_velocity,
461
+ _base_angular_velocity=base_angular_velocity,
462
+ _base_position=base_position,
463
+ _base_transform=base_transform,
464
+ _joint_transforms=joint_transforms,
465
+ _link_transforms=link_transforms,
466
+ _link_velocities=link_velocities,
467
+ validate=validate,
709
468
  )
710
469
 
711
- @js.common.named_scope
712
- @functools.partial(jax.jit, static_argnames=["velocity_representation"])
713
- def reset_base_velocity(
714
- self,
715
- base_velocity: jtp.VectorLike,
716
- velocity_representation: VelRepr | None = None,
717
- ) -> Self:
470
+ def valid(self, model: js.model.JaxSimModel) -> bool:
718
471
  """
719
- Reset the base 6D velocity.
472
+ Check if the `JaxSimModelData` is valid for a given `JaxSimModel`.
720
473
 
721
474
  Args:
722
- base_velocity: The base 6D velocity in the active representation.
723
- velocity_representation:
724
- The velocity representation in which the base velocity is expressed.
725
- If `None`, the active representation is considered.
475
+ model: The `JaxSimModel` to validate the `JaxSimModelData` against.
726
476
 
727
477
  Returns:
728
- The updated `JaxSimModelData` object.
478
+ `True` if the `JaxSimModelData` is valid for the given model,
479
+ `False` otherwise.
729
480
  """
481
+ if self._joint_positions.shape != (model.dofs(),):
482
+ return False
483
+ if self._joint_velocities.shape != (model.dofs(),):
484
+ return False
485
+ if self._base_position.shape != (3,):
486
+ return False
487
+ if self._base_quaternion.shape != (4,):
488
+ return False
489
+ if self._base_linear_velocity.shape != (3,):
490
+ return False
491
+ if self._base_angular_velocity.shape != (3,):
492
+ return False
730
493
 
731
- base_velocity = jnp.array(base_velocity)
732
-
733
- velocity_representation = (
734
- velocity_representation
735
- if velocity_representation is not None
736
- else self.velocity_representation
737
- )
738
-
739
- W_v_WB = self.other_representation_to_inertial(
740
- array=jnp.atleast_1d(base_velocity.squeeze()).astype(float),
741
- other_representation=velocity_representation,
742
- transform=self.base_transform(),
743
- is_force=False,
744
- )
745
-
746
- return self.replace(
747
- validate=True,
748
- state=self.state.replace(
749
- physics_model=self.state.physics_model.replace(
750
- base_linear_velocity=W_v_WB[0:3].squeeze().astype(float),
751
- base_angular_velocity=W_v_WB[3:6].squeeze().astype(float),
752
- )
753
- ),
754
- )
494
+ return True
755
495
 
756
496
 
757
497
  @functools.partial(jax.jit, static_argnames=["velocity_representation", "base_rpy_seq"])
@@ -788,11 +528,6 @@ def random_model_data(
788
528
  jtp.FloatLike | Sequence[jtp.FloatLike],
789
529
  jtp.FloatLike | Sequence[jtp.FloatLike],
790
530
  ] = (-1.0, 1.0),
791
- contacts_params: jaxsim.rbda.contacts.ContactsParams | None = None,
792
- standard_gravity_bounds: tuple[jtp.FloatLike, jtp.FloatLike] = (
793
- jaxsim.math.StandardGravity,
794
- jaxsim.math.StandardGravity,
795
- ),
796
531
  ) -> JaxSimModelData:
797
532
  """
798
533
  Randomly generate a `JaxSimModelData` object.
@@ -811,15 +546,13 @@ def random_model_data(
811
546
  base_vel_lin_bounds: The bounds for the base linear velocity.
812
547
  base_vel_ang_bounds: The bounds for the base angular velocity.
813
548
  joint_vel_bounds: The bounds for the joint velocities.
814
- contacts_params: The parameters of the contact model.
815
- standard_gravity_bounds: The bounds for the standard gravity.
816
549
 
817
550
  Returns:
818
551
  A `JaxSimModelData` object with random data.
819
552
  """
820
553
 
821
554
  key = key if key is not None else jax.random.PRNGKey(seed=0)
822
- k1, k2, k3, k4, k5, k6, k7 = jax.random.split(key, num=7)
555
+ k1, k2, k3, k4, k5, k6 = jax.random.split(key, num=6)
823
556
 
824
557
  p_min = jnp.array(base_pos_bounds[0], dtype=float)
825
558
  p_max = jnp.array(base_pos_bounds[1], dtype=float)
@@ -831,95 +564,64 @@ def random_model_data(
831
564
  ω_max = jnp.array(base_vel_ang_bounds[1], dtype=float)
832
565
  ṡ_min, ṡ_max = joint_vel_bounds
833
566
 
834
- random_data = JaxSimModelData.zero(
835
- model=model,
836
- **(
837
- dict(velocity_representation=velocity_representation)
838
- if velocity_representation is not None
839
- else {}
840
- ),
841
- )
567
+ base_position = jax.random.uniform(key=k1, shape=(3,), minval=p_min, maxval=p_max)
842
568
 
843
- with random_data.mutable_context(
844
- mutability=Mutability.MUTABLE, restore_after_exception=False
845
- ):
569
+ base_quaternion = jaxsim.math.Quaternion.to_wxyz(
570
+ xyzw=jax.scipy.spatial.transform.Rotation.from_euler(
571
+ seq=base_rpy_seq,
572
+ angles=jax.random.uniform(
573
+ key=k2, shape=(3,), minval=rpy_min, maxval=rpy_max
574
+ ),
575
+ ).as_quat()
576
+ )
846
577
 
847
- physics_model_state = random_data.state.physics_model
578
+ (
579
+ joint_positions,
580
+ joint_velocities,
581
+ base_linear_velocity,
582
+ base_angular_velocity,
583
+ ) = (None,) * 4
848
584
 
849
- physics_model_state.base_position = jax.random.uniform(
850
- key=k1, shape=(3,), minval=p_min, maxval=p_max
851
- )
585
+ if model.number_of_joints() > 0:
852
586
 
853
- physics_model_state.base_quaternion = jaxsim.math.Quaternion.to_wxyz(
854
- xyzw=jax.scipy.spatial.transform.Rotation.from_euler(
855
- seq=base_rpy_seq,
856
- angles=jax.random.uniform(
857
- key=k2, shape=(3,), minval=rpy_min, maxval=rpy_max
858
- ),
859
- ).as_quat()
587
+ s_min, s_max = (
588
+ jnp.array(joint_pos_bounds, dtype=float)
589
+ if joint_pos_bounds is not None
590
+ else (None, None)
860
591
  )
861
592
 
862
- if model.number_of_joints() > 0:
863
-
864
- s_min, s_max = (
865
- jnp.array(joint_pos_bounds, dtype=float)
866
- if joint_pos_bounds is not None
867
- else (None, None)
868
- )
869
-
870
- physics_model_state.joint_positions = (
871
- js.joint.random_joint_positions(model=model, key=k3)
872
- if (s_min is None or s_max is None)
873
- else jax.random.uniform(
874
- key=k3, shape=(model.dofs(),), minval=s_min, maxval=s_max
875
- )
876
- )
877
-
878
- physics_model_state.joint_velocities = jax.random.uniform(
879
- key=k4, shape=(model.dofs(),), minval=ṡ_min, maxval=ṡ_max
880
- )
881
-
882
- if model.floating_base():
883
- physics_model_state.base_linear_velocity = jax.random.uniform(
884
- key=k5, shape=(3,), minval=v_min, maxval=v_max
885
- )
886
-
887
- physics_model_state.base_angular_velocity = jax.random.uniform(
888
- key=k6, shape=(3,), minval=ω_min, maxval=ω_max
889
- )
890
-
891
- random_data.gravity = (
892
- jnp.zeros(3, dtype=random_data.gravity.dtype)
893
- .at[2]
894
- .set(
895
- -jax.random.uniform(
896
- key=k7,
897
- shape=(),
898
- minval=standard_gravity_bounds[0],
899
- maxval=standard_gravity_bounds[1],
900
- )
593
+ joint_positions = (
594
+ js.joint.random_joint_positions(model=model, key=k3)
595
+ if (s_min is None or s_max is None)
596
+ else jax.random.uniform(
597
+ key=k3, shape=(model.dofs(),), minval=s_min, maxval=s_max
901
598
  )
902
599
  )
903
600
 
904
- if contacts_params is None:
905
-
906
- if isinstance(
907
- model.contact_model,
908
- jaxsim.rbda.contacts.SoftContacts
909
- | jaxsim.rbda.contacts.ViscoElasticContacts,
910
- ):
601
+ joint_velocities = jax.random.uniform(
602
+ key=k4, shape=(model.dofs(),), minval=ṡ_min, maxval=ṡ_max
603
+ )
911
604
 
912
- random_data = random_data.replace(
913
- contacts_params=js.contact.estimate_good_contact_parameters(
914
- model=model, standard_gravity=random_data.gravity
915
- ),
916
- validate=False,
917
- )
605
+ if model.floating_base():
606
+ base_linear_velocity = jax.random.uniform(
607
+ key=k5, shape=(3,), minval=v_min, maxval=v_max
608
+ )
918
609
 
919
- else:
920
- random_data = random_data.replace(
921
- contacts_params=model.contact_model._parameters_class(),
922
- validate=False,
923
- )
610
+ base_angular_velocity = jax.random.uniform(
611
+ key=k6, shape=(3,), minval=ω_min, maxval=ω_max
612
+ )
924
613
 
925
- return random_data
614
+ return JaxSimModelData.build(
615
+ model=model,
616
+ base_position=base_position,
617
+ base_quaternion=base_quaternion,
618
+ joint_positions=joint_positions,
619
+ joint_velocities=joint_velocities,
620
+ base_linear_velocity=base_linear_velocity,
621
+ base_angular_velocity=base_angular_velocity,
622
+ **(
623
+ {"velocity_representation": velocity_representation}
624
+ if velocity_representation is not None
625
+ else {}
626
+ ),
627
+ )