jaxsim 0.2.dev188__py3-none-any.whl → 0.6.1.dev2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (109)
  1. jaxsim/__init__.py +73 -22
  2. jaxsim/_version.py +2 -2
  3. jaxsim/api/__init__.py +13 -1
  4. jaxsim/api/com.py +423 -0
  5. jaxsim/api/common.py +48 -19
  6. jaxsim/api/contact.py +604 -52
  7. jaxsim/api/data.py +308 -163
  8. jaxsim/api/frame.py +471 -0
  9. jaxsim/api/joint.py +166 -37
  10. jaxsim/api/kin_dyn_parameters.py +901 -0
  11. jaxsim/api/link.py +277 -78
  12. jaxsim/api/model.py +1572 -362
  13. jaxsim/api/ode.py +324 -129
  14. jaxsim/api/ode_data.py +401 -0
  15. jaxsim/api/references.py +216 -80
  16. jaxsim/exceptions.py +80 -0
  17. jaxsim/integrators/__init__.py +2 -2
  18. jaxsim/integrators/common.py +191 -107
  19. jaxsim/integrators/fixed_step.py +97 -102
  20. jaxsim/integrators/variable_step.py +706 -0
  21. jaxsim/logging.py +1 -2
  22. jaxsim/math/__init__.py +13 -0
  23. jaxsim/math/adjoint.py +57 -22
  24. jaxsim/math/cross.py +16 -7
  25. jaxsim/math/inertia.py +10 -8
  26. jaxsim/math/joint_model.py +289 -0
  27. jaxsim/math/quaternion.py +54 -20
  28. jaxsim/math/rotation.py +27 -21
  29. jaxsim/math/skew.py +16 -5
  30. jaxsim/math/transform.py +102 -0
  31. jaxsim/math/utils.py +31 -0
  32. jaxsim/mujoco/__init__.py +2 -1
  33. jaxsim/mujoco/loaders.py +216 -29
  34. jaxsim/mujoco/model.py +163 -33
  35. jaxsim/mujoco/utils.py +228 -0
  36. jaxsim/mujoco/visualizer.py +107 -22
  37. jaxsim/parsers/__init__.py +0 -1
  38. jaxsim/parsers/descriptions/__init__.py +8 -2
  39. jaxsim/parsers/descriptions/collision.py +87 -16
  40. jaxsim/parsers/descriptions/joint.py +80 -87
  41. jaxsim/parsers/descriptions/link.py +62 -24
  42. jaxsim/parsers/descriptions/model.py +101 -68
  43. jaxsim/parsers/kinematic_graph.py +607 -225
  44. jaxsim/parsers/rod/meshes.py +104 -0
  45. jaxsim/parsers/rod/parser.py +125 -82
  46. jaxsim/parsers/rod/utils.py +127 -82
  47. jaxsim/rbda/__init__.py +11 -0
  48. jaxsim/rbda/aba.py +289 -0
  49. jaxsim/rbda/collidable_points.py +156 -0
  50. jaxsim/rbda/contacts/__init__.py +13 -0
  51. jaxsim/rbda/contacts/common.py +313 -0
  52. jaxsim/rbda/contacts/relaxed_rigid.py +605 -0
  53. jaxsim/rbda/contacts/rigid.py +462 -0
  54. jaxsim/rbda/contacts/soft.py +480 -0
  55. jaxsim/rbda/contacts/visco_elastic.py +1066 -0
  56. jaxsim/rbda/crba.py +167 -0
  57. jaxsim/rbda/forward_kinematics.py +117 -0
  58. jaxsim/rbda/jacobian.py +330 -0
  59. jaxsim/rbda/rnea.py +235 -0
  60. jaxsim/rbda/utils.py +160 -0
  61. jaxsim/terrain/__init__.py +2 -0
  62. jaxsim/terrain/terrain.py +238 -0
  63. jaxsim/typing.py +24 -24
  64. jaxsim/utils/__init__.py +1 -4
  65. jaxsim/utils/jaxsim_dataclass.py +289 -34
  66. jaxsim/utils/tracing.py +5 -11
  67. jaxsim/utils/wrappers.py +159 -0
  68. {jaxsim-0.2.dev188.dist-info → jaxsim-0.6.1.dev2.dist-info}/LICENSE +1 -1
  69. jaxsim-0.6.1.dev2.dist-info/METADATA +465 -0
  70. jaxsim-0.6.1.dev2.dist-info/RECORD +74 -0
  71. {jaxsim-0.2.dev188.dist-info → jaxsim-0.6.1.dev2.dist-info}/WHEEL +1 -1
  72. jaxsim/high_level/__init__.py +0 -2
  73. jaxsim/high_level/common.py +0 -11
  74. jaxsim/high_level/joint.py +0 -148
  75. jaxsim/high_level/link.py +0 -259
  76. jaxsim/high_level/model.py +0 -1686
  77. jaxsim/math/conv.py +0 -114
  78. jaxsim/math/joint.py +0 -102
  79. jaxsim/math/plucker.py +0 -100
  80. jaxsim/physics/__init__.py +0 -12
  81. jaxsim/physics/algos/__init__.py +0 -0
  82. jaxsim/physics/algos/aba.py +0 -254
  83. jaxsim/physics/algos/aba_motors.py +0 -284
  84. jaxsim/physics/algos/crba.py +0 -154
  85. jaxsim/physics/algos/forward_kinematics.py +0 -79
  86. jaxsim/physics/algos/jacobian.py +0 -98
  87. jaxsim/physics/algos/rnea.py +0 -180
  88. jaxsim/physics/algos/rnea_motors.py +0 -196
  89. jaxsim/physics/algos/soft_contacts.py +0 -523
  90. jaxsim/physics/algos/terrain.py +0 -80
  91. jaxsim/physics/algos/utils.py +0 -69
  92. jaxsim/physics/model/__init__.py +0 -0
  93. jaxsim/physics/model/ground_contact.py +0 -55
  94. jaxsim/physics/model/physics_model.py +0 -388
  95. jaxsim/physics/model/physics_model_state.py +0 -283
  96. jaxsim/simulation/__init__.py +0 -4
  97. jaxsim/simulation/integrators.py +0 -393
  98. jaxsim/simulation/ode.py +0 -290
  99. jaxsim/simulation/ode_data.py +0 -96
  100. jaxsim/simulation/ode_integration.py +0 -62
  101. jaxsim/simulation/simulator.py +0 -543
  102. jaxsim/simulation/simulator_callbacks.py +0 -79
  103. jaxsim/simulation/utils.py +0 -15
  104. jaxsim/sixd/__init__.py +0 -2
  105. jaxsim/utils/oop.py +0 -536
  106. jaxsim/utils/vmappable.py +0 -117
  107. jaxsim-0.2.dev188.dist-info/METADATA +0 -184
  108. jaxsim-0.2.dev188.dist-info/RECORD +0 -81
  109. {jaxsim-0.2.dev188.dist-info → jaxsim-0.6.1.dev2.dist-info}/top_level.txt +0 -0
jaxsim/api/data.py CHANGED
@@ -2,28 +2,23 @@ from __future__ import annotations
2
2
 
3
3
  import dataclasses
4
4
  import functools
5
- from typing import Sequence
5
+ from collections.abc import Sequence
6
6
 
7
7
  import jax
8
8
  import jax.numpy as jnp
9
+ import jax.scipy.spatial.transform
9
10
  import jax_dataclasses
10
- import jaxlie
11
- import numpy as np
12
-
13
- import jaxsim.api
14
- import jaxsim.physics.algos.aba
15
- import jaxsim.physics.algos.crba
16
- import jaxsim.physics.algos.forward_kinematics
17
- import jaxsim.physics.algos.rnea
18
- import jaxsim.physics.model.physics_model
19
- import jaxsim.physics.model.physics_model_state
11
+
12
+ import jaxsim.api as js
13
+ import jaxsim.math
14
+ import jaxsim.rbda
20
15
  import jaxsim.typing as jtp
21
- from jaxsim.high_level.common import VelRepr
22
- from jaxsim.physics.algos import soft_contacts
23
- from jaxsim.simulation.ode_data import ODEState
24
16
  from jaxsim.utils import Mutability
17
+ from jaxsim.utils.tracing import not_tracing
25
18
 
26
19
  from . import common
20
+ from .common import VelRepr
21
+ from .ode_data import ODEState
27
22
 
28
23
  try:
29
24
  from typing import Self
@@ -34,21 +29,35 @@ except ImportError:
34
29
  @jax_dataclasses.pytree_dataclass
35
30
  class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
36
31
  """
37
- Class containing the state of a `JaxSimModel` object.
32
+ Class containing the data of a `JaxSimModel` object.
38
33
  """
39
34
 
40
35
  state: ODEState
41
36
 
42
- gravity: jtp.Array
37
+ gravity: jtp.Vector
43
38
 
44
- soft_contacts_params: soft_contacts.SoftContactsParams = dataclasses.field(
45
- repr=False
46
- )
47
- time_ns: jtp.Int = dataclasses.field(
48
- default_factory=lambda: jnp.array(0, dtype=jnp.uint64)
49
- )
39
+ contacts_params: jaxsim.rbda.contacts.ContactsParams = dataclasses.field(repr=False)
40
+
41
+ def __hash__(self) -> int:
42
+
43
+ from jaxsim.utils.wrappers import HashedNumpyArray
44
+
45
+ return hash(
46
+ (
47
+ hash(self.state),
48
+ HashedNumpyArray.hash_of_array(self.gravity),
49
+ hash(self.contacts_params),
50
+ )
51
+ )
52
+
53
+ def __eq__(self, other: JaxSimModelData) -> bool:
50
54
 
51
- def valid(self, model: jaxsim.api.model.JaxSimModel | None = None) -> bool:
55
+ if not isinstance(other, JaxSimModelData):
56
+ return False
57
+
58
+ return hash(self) == hash(other)
59
+
60
+ def valid(self, model: js.model.JaxSimModel | None = None) -> bool:
52
61
  """
53
62
  Check if the current state is valid for the given model.
54
63
 
@@ -60,15 +69,16 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
60
69
  """
61
70
 
62
71
  valid = True
72
+ valid = valid and self.standard_gravity() > 0
63
73
 
64
74
  if model is not None:
65
- valid = valid and self.state.valid(physics_model=model.physics_model)
75
+ valid = valid and self.state.valid(model=model)
66
76
 
67
77
  return valid
68
78
 
69
79
  @staticmethod
70
80
  def zero(
71
- model: jaxsim.api.model.JaxSimModel,
81
+ model: js.model.JaxSimModel,
72
82
  velocity_representation: VelRepr = VelRepr.Inertial,
73
83
  ) -> JaxSimModelData:
74
84
  """
@@ -88,18 +98,17 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
88
98
 
89
99
  @staticmethod
90
100
  def build(
91
- model: jaxsim.api.model.JaxSimModel,
92
- base_position: jtp.Vector | None = None,
93
- base_quaternion: jtp.Vector | None = None,
94
- joint_positions: jtp.Vector | None = None,
95
- base_linear_velocity: jtp.Vector | None = None,
96
- base_angular_velocity: jtp.Vector | None = None,
97
- joint_velocities: jtp.Vector | None = None,
98
- gravity: jtp.Vector | None = None,
99
- soft_contacts_state: soft_contacts.SoftContactsState | None = None,
100
- soft_contacts_params: soft_contacts.SoftContactsParams | None = None,
101
+ model: js.model.JaxSimModel,
102
+ base_position: jtp.VectorLike | None = None,
103
+ base_quaternion: jtp.VectorLike | None = None,
104
+ joint_positions: jtp.VectorLike | None = None,
105
+ base_linear_velocity: jtp.VectorLike | None = None,
106
+ base_angular_velocity: jtp.VectorLike | None = None,
107
+ joint_velocities: jtp.VectorLike | None = None,
108
+ standard_gravity: jtp.FloatLike = jaxsim.math.StandardGravity,
109
+ contacts_params: jaxsim.rbda.contacts.ContactsParams | None = None,
101
110
  velocity_representation: VelRepr = VelRepr.Inertial,
102
- time: jtp.FloatLike | None = None,
111
+ extended_ode_state: dict[str, jtp.PyTree] | None = None,
103
112
  ) -> JaxSimModelData:
104
113
  """
105
114
  Create a `JaxSimModelData` object with the given state.
@@ -114,97 +123,119 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
114
123
  base_angular_velocity:
115
124
  The base angular velocity in the selected representation.
116
125
  joint_velocities: The joint velocities.
117
- gravity: The gravity 3D vector.
118
- soft_contacts_state: The state of the soft contacts.
119
- soft_contacts_params: The parameters of the soft contacts.
126
+ standard_gravity: The standard gravity constant.
127
+ contacts_params: The parameters of the soft contacts.
120
128
  velocity_representation: The velocity representation to use.
121
- time: The time at which the state is created.
129
+ extended_ode_state:
130
+ Additional user-defined state variables that are not part of the
131
+ standard `ODEState` object. Useful to extend the system dynamics
132
+ considered by default in JaxSim.
122
133
 
123
134
  Returns:
124
- A `JaxSimModelData` object with the given state.
135
+ A `JaxSimModelData` initialized with the given state.
125
136
  """
126
137
 
127
138
  base_position = jnp.array(
128
- base_position if base_position is not None else jnp.zeros(3)
139
+ base_position if base_position is not None else jnp.zeros(3),
140
+ dtype=float,
129
141
  ).squeeze()
130
142
 
131
143
  base_quaternion = jnp.array(
132
- base_quaternion
133
- if base_quaternion is not None
134
- else jnp.array([1.0, 0, 0, 0])
144
+ (
145
+ base_quaternion
146
+ if base_quaternion is not None
147
+ else jnp.array([1.0, 0, 0, 0])
148
+ ),
149
+ dtype=float,
135
150
  ).squeeze()
136
151
 
137
152
  base_linear_velocity = jnp.array(
138
- base_linear_velocity if base_linear_velocity is not None else jnp.zeros(3)
153
+ base_linear_velocity if base_linear_velocity is not None else jnp.zeros(3),
154
+ dtype=float,
139
155
  ).squeeze()
140
156
 
141
157
  base_angular_velocity = jnp.array(
142
- base_angular_velocity if base_angular_velocity is not None else jnp.zeros(3)
158
+ (
159
+ base_angular_velocity
160
+ if base_angular_velocity is not None
161
+ else jnp.zeros(3)
162
+ ),
163
+ dtype=float,
143
164
  ).squeeze()
144
165
 
145
- gravity = jnp.array(
146
- gravity if gravity is not None else model.physics_model.gravity[0:3]
147
- ).squeeze()
166
+ gravity = jnp.zeros(3).at[2].set(-standard_gravity)
148
167
 
149
168
  joint_positions = jnp.atleast_1d(
150
- joint_positions.squeeze()
151
- if joint_positions is not None
152
- else jnp.zeros(model.dofs())
169
+ jnp.array(
170
+ (
171
+ joint_positions
172
+ if joint_positions is not None
173
+ else jnp.zeros(model.dofs())
174
+ ),
175
+ dtype=float,
176
+ ).squeeze()
153
177
  )
154
178
 
155
179
  joint_velocities = jnp.atleast_1d(
156
- joint_velocities.squeeze()
157
- if joint_velocities is not None
158
- else jnp.zeros(model.dofs())
159
- )
160
-
161
- time_ns = (
162
- jnp.array(time * 1e9, dtype=jnp.uint64)
163
- if time is not None
164
- else jnp.array(0, dtype=jnp.uint64)
180
+ jnp.array(
181
+ (
182
+ joint_velocities
183
+ if joint_velocities is not None
184
+ else jnp.zeros(model.dofs())
185
+ ),
186
+ dtype=float,
187
+ ).squeeze()
165
188
  )
166
189
 
167
- soft_contacts_params = (
168
- soft_contacts_params
169
- if soft_contacts_params is not None
170
- else jaxsim.api.contact.estimate_good_soft_contacts_parameters(model=model)
190
+ W_H_B = jaxsim.math.Transform.from_quaternion_and_translation(
191
+ translation=base_position, quaternion=base_quaternion
171
192
  )
172
193
 
173
- W_H_B = jaxlie.SE3.from_rotation_and_translation(
174
- translation=base_position,
175
- rotation=jaxlie.SO3.from_quaternion_xyzw(
176
- base_quaternion[jnp.array([1, 2, 3, 0])]
177
- ),
178
- ).as_matrix()
179
-
180
194
  v_WB = JaxSimModelData.other_representation_to_inertial(
181
195
  array=jnp.hstack([base_linear_velocity, base_angular_velocity]),
182
196
  other_representation=velocity_representation,
183
197
  transform=W_H_B,
184
198
  is_force=False,
199
+ ).astype(float)
200
+
201
+ ode_state = ODEState.build_from_jaxsim_model(
202
+ model=model,
203
+ base_position=base_position,
204
+ base_quaternion=base_quaternion,
205
+ joint_positions=joint_positions,
206
+ base_linear_velocity=v_WB[0:3],
207
+ base_angular_velocity=v_WB[3:6],
208
+ joint_velocities=joint_velocities,
209
+ # Unpack all the additional ODE states. If the contact model requires an
210
+ # additional state that is not explicitly passed to this builder, ODEState
211
+ # automatically populates that state with zeroed variables.
212
+ # This is not true for any other custom state that the user might want to
213
+ # pass to the integrator.
214
+ **(extended_ode_state if extended_ode_state else {}),
185
215
  )
186
216
 
187
- ode_state = ODEState.build(
188
- physics_model=model.physics_model,
189
- physics_model_state=jaxsim.physics.model.physics_model_state.PhysicsModelState(
190
- base_position=base_position.astype(float),
191
- base_quaternion=base_quaternion.astype(float),
192
- joint_positions=joint_positions.astype(float),
193
- base_linear_velocity=v_WB[0:3].astype(float),
194
- base_angular_velocity=v_WB[3:6].astype(float),
195
- joint_velocities=joint_velocities.astype(float),
196
- ),
197
- soft_contacts_state=soft_contacts_state,
198
- )
199
-
200
- if not ode_state.valid(physics_model=model.physics_model):
217
+ if not ode_state.valid(model=model):
201
218
  raise ValueError(ode_state)
202
219
 
220
+ if contacts_params is None:
221
+
222
+ if isinstance(
223
+ model.contact_model,
224
+ jaxsim.rbda.contacts.SoftContacts
225
+ | jaxsim.rbda.contacts.ViscoElasticContacts,
226
+ ):
227
+
228
+ contacts_params = js.contact.estimate_good_contact_parameters(
229
+ model=model, standard_gravity=standard_gravity
230
+ )
231
+
232
+ else:
233
+ contacts_params = model.contact_model._parameters_class()
234
+
203
235
  return JaxSimModelData(
204
- time_ns=time_ns,
205
236
  state=ode_state,
206
- gravity=gravity.astype(float),
207
- soft_contacts_params=soft_contacts_params,
237
+ gravity=gravity,
238
+ contacts_params=contacts_params,
208
239
  velocity_representation=velocity_representation,
209
240
  )
210
241
 
@@ -212,20 +243,21 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
212
243
  # Extract quantities
213
244
  # ==================
214
245
 
215
- def time(self) -> jtp.Float:
246
+ def standard_gravity(self) -> jtp.Float:
216
247
  """
217
- Get the simulated time.
248
+ Get the standard gravity constant.
218
249
 
219
250
  Returns:
220
- The simulated time in seconds.
251
+ The standard gravity constant.
221
252
  """
222
253
 
223
- return self.time_ns.astype(float) / 1e9
254
+ return -self.gravity[2]
224
255
 
256
+ @js.common.named_scope
225
257
  @functools.partial(jax.jit, static_argnames=["joint_names"])
226
258
  def joint_positions(
227
259
  self,
228
- model: jaxsim.api.model.JaxSimModel | None = None,
260
+ model: js.model.JaxSimModel | None = None,
229
261
  joint_names: tuple[str, ...] | None = None,
230
262
  ) -> jtp.Vector:
231
263
  """
@@ -250,22 +282,30 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
250
282
  """
251
283
 
252
284
  if model is None:
285
+ if joint_names is not None:
286
+ raise ValueError("Joint names cannot be provided without a model")
287
+
253
288
  return self.state.physics_model.joint_positions
254
289
 
255
- if not self.valid(model=model):
290
+ if not_tracing(self.state.physics_model.joint_positions) and not self.valid(
291
+ model=model
292
+ ):
256
293
  msg = "The data object is not compatible with the provided model"
257
294
  raise ValueError(msg)
258
295
 
259
- joint_names = joint_names if joint_names is not None else model.joint_names()
296
+ joint_idxs = (
297
+ js.joint.names_to_idxs(joint_names=joint_names, model=model)
298
+ if joint_names is not None
299
+ else jnp.arange(model.number_of_joints())
300
+ )
260
301
 
261
- return self.state.physics_model.joint_positions[
262
- jaxsim.api.joint.names_to_idxs(joint_names=joint_names, model=model)
263
- ]
302
+ return self.state.physics_model.joint_positions[joint_idxs]
264
303
 
304
+ @js.common.named_scope
265
305
  @functools.partial(jax.jit, static_argnames=["joint_names"])
266
306
  def joint_velocities(
267
307
  self,
268
- model: jaxsim.api.model.JaxSimModel | None = None,
308
+ model: js.model.JaxSimModel | None = None,
269
309
  joint_names: tuple[str, ...] | None = None,
270
310
  ) -> jtp.Vector:
271
311
  """
@@ -290,18 +330,26 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
290
330
  """
291
331
 
292
332
  if model is None:
333
+ if joint_names is not None:
334
+ raise ValueError("Joint names cannot be provided without a model")
335
+
293
336
  return self.state.physics_model.joint_velocities
294
337
 
295
- if not self.valid(model=model):
338
+ if not_tracing(self.state.physics_model.joint_velocities) and not self.valid(
339
+ model=model
340
+ ):
296
341
  msg = "The data object is not compatible with the provided model"
297
342
  raise ValueError(msg)
298
343
 
299
- joint_names = joint_names if joint_names is not None else model.joint_names()
344
+ joint_idxs = (
345
+ js.joint.names_to_idxs(joint_names=joint_names, model=model)
346
+ if joint_names is not None
347
+ else jnp.arange(model.number_of_joints())
348
+ )
300
349
 
301
- return self.state.physics_model.joint_velocities[
302
- jaxsim.api.joint.names_to_idxs(joint_names=joint_names, model=model)
303
- ]
350
+ return self.state.physics_model.joint_velocities[joint_idxs]
304
351
 
352
+ @js.common.named_scope
305
353
  @jax.jit
306
354
  def base_position(self) -> jtp.Vector:
307
355
  """
@@ -313,6 +361,7 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
313
361
 
314
362
  return self.state.physics_model.base_position.squeeze()
315
363
 
364
+ @js.common.named_scope
316
365
  @functools.partial(jax.jit, static_argnames=["dcm"])
317
366
  def base_orientation(self, dcm: jtp.BoolLike = False) -> jtp.Vector | jtp.Matrix:
318
367
  """
@@ -325,29 +374,24 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
325
374
  The base orientation.
326
375
  """
327
376
 
377
+ # Extract the base quaternion.
378
+ W_Q_B = self.state.physics_model.base_quaternion.squeeze()
379
+
328
380
  # Always normalize the quaternion to avoid numerical issues.
329
381
  # If the active scheme does not integrate the quaternion on its manifold,
330
382
  # we introduce a Baumgarte stabilization to let the quaternion converge to
331
383
  # a unit quaternion. In this case, it is not guaranteed that the quaternion
332
384
  # stored in the state is a unit quaternion.
333
- base_unit_quaternion = (
334
- self.state.physics_model.base_quaternion.squeeze()
335
- / jnp.linalg.norm(self.state.physics_model.base_quaternion)
336
- )
385
+ norm = jaxsim.math.safe_norm(W_Q_B)
386
+ W_Q_B = W_Q_B / (norm + jnp.finfo(float).eps * (norm == 0))
337
387
 
338
- # Slice to convert quaternion wxyz -> xyzw
339
- to_xyzw = np.array([1, 2, 3, 0])
340
-
341
- return (
342
- base_unit_quaternion
343
- if not dcm
344
- else jaxlie.SO3.from_quaternion_xyzw(
345
- base_unit_quaternion[to_xyzw]
346
- ).as_matrix()
388
+ return (W_Q_B if not dcm else jaxsim.math.Quaternion.to_dcm(W_Q_B)).astype(
389
+ float
347
390
  )
348
391
 
392
+ @js.common.named_scope
349
393
  @jax.jit
350
- def base_transform(self) -> jtp.MatrixJax:
394
+ def base_transform(self) -> jtp.Matrix:
351
395
  """
352
396
  Get the base transform.
353
397
 
@@ -365,6 +409,7 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
365
409
  ]
366
410
  )
367
411
 
412
+ @js.common.named_scope
368
413
  @jax.jit
369
414
  def base_velocity(self) -> jtp.Vector:
370
415
  """
@@ -394,9 +439,10 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
394
439
  .astype(float)
395
440
  )
396
441
 
442
+ @js.common.named_scope
397
443
  @jax.jit
398
444
  def generalized_position(self) -> tuple[jtp.Matrix, jtp.Vector]:
399
- """
445
+ r"""
400
446
  Get the generalized position
401
447
  :math:`\mathbf{q} = ({}^W \mathbf{H}_B, \mathbf{s}) \in \text{SO}(3) \times \mathbb{R}^n`.
402
448
 
@@ -406,10 +452,12 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
406
452
 
407
453
  return self.base_transform(), self.joint_positions()
408
454
 
455
+ @js.common.named_scope
409
456
  @jax.jit
410
457
  def generalized_velocity(self) -> jtp.Vector:
411
- """
412
- Get the generalized velocity
458
+ r"""
459
+ Get the generalized velocity.
460
+
413
461
  :math:`\boldsymbol{\nu} = (\boldsymbol{v}_{W,B};\, \boldsymbol{\omega}_{W,B};\, \mathbf{s}) \in \mathbb{R}^{6+n}`
414
462
 
415
463
  Returns:
@@ -426,11 +474,12 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
426
474
  # Store quantities
427
475
  # ================
428
476
 
477
+ @js.common.named_scope
429
478
  @functools.partial(jax.jit, static_argnames=["joint_names"])
430
479
  def reset_joint_positions(
431
480
  self,
432
481
  positions: jtp.VectorLike,
433
- model: jaxsim.api.model.JaxSimModel | None = None,
482
+ model: js.model.JaxSimModel | None = None,
434
483
  joint_names: tuple[str, ...] | None = None,
435
484
  ) -> Self:
436
485
  """
@@ -460,23 +509,26 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
460
509
  if model is None:
461
510
  return replace(s=positions)
462
511
 
463
- if not self.valid(model=model):
512
+ if not_tracing(positions) and not self.valid(model=model):
464
513
  msg = "The data object is not compatible with the provided model"
465
514
  raise ValueError(msg)
466
515
 
467
- joint_names = joint_names if joint_names is not None else model.joint_names()
516
+ joint_idxs = (
517
+ js.joint.names_to_idxs(joint_names=joint_names, model=model)
518
+ if joint_names is not None
519
+ else jnp.arange(model.number_of_joints())
520
+ )
468
521
 
469
522
  return replace(
470
- s=self.state.physics_model.joint_positions.at[
471
- jaxsim.api.joint.names_to_idxs(joint_names=joint_names, model=model)
472
- ].set(positions)
523
+ s=self.state.physics_model.joint_positions.at[joint_idxs].set(positions)
473
524
  )
474
525
 
526
+ @js.common.named_scope
475
527
  @functools.partial(jax.jit, static_argnames=["joint_names"])
476
528
  def reset_joint_velocities(
477
529
  self,
478
530
  velocities: jtp.VectorLike,
479
- model: jaxsim.api.model.JaxSimModel | None = None,
531
+ model: js.model.JaxSimModel | None = None,
480
532
  joint_names: tuple[str, ...] | None = None,
481
533
  ) -> Self:
482
534
  """
@@ -506,18 +558,21 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
506
558
  if model is None:
507
559
  return replace(ṡ=velocities)
508
560
 
509
- if not self.valid(model=model):
561
+ if not_tracing(velocities) and not self.valid(model=model):
510
562
  msg = "The data object is not compatible with the provided model"
511
563
  raise ValueError(msg)
512
564
 
513
- joint_names = joint_names if joint_names is not None else model.joint_names()
565
+ joint_idxs = (
566
+ js.joint.names_to_idxs(joint_names=joint_names, model=model)
567
+ if joint_names is not None
568
+ else jnp.arange(model.number_of_joints())
569
+ )
514
570
 
515
571
  return replace(
516
- ṡ=self.state.physics_model.joint_velocities.at[
517
- jaxsim.api.joint.names_to_idxs(joint_names=joint_names, model=model)
518
- ].set(velocities)
572
+ ṡ=self.state.physics_model.joint_velocities.at[joint_idxs].set(velocities)
519
573
  )
520
574
 
575
+ @js.common.named_scope
521
576
  @jax.jit
522
577
  def reset_base_position(self, base_position: jtp.VectorLike) -> Self:
523
578
  """
@@ -541,6 +596,7 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
541
596
  ),
542
597
  )
543
598
 
599
+ @js.common.named_scope
544
600
  @jax.jit
545
601
  def reset_base_quaternion(self, base_quaternion: jtp.VectorLike) -> Self:
546
602
  """
@@ -553,19 +609,19 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
553
609
  The updated `JaxSimModelData` object.
554
610
  """
555
611
 
556
- base_quaternion = jnp.array(base_quaternion)
612
+ W_Q_B = jnp.array(base_quaternion, dtype=float)
613
+
614
+ norm = jaxsim.math.safe_norm(W_Q_B)
615
+ W_Q_B = W_Q_B / (norm + jnp.finfo(float).eps * (norm == 0))
557
616
 
558
617
  return self.replace(
559
618
  validate=True,
560
619
  state=self.state.replace(
561
- physics_model=self.state.physics_model.replace(
562
- base_quaternion=jnp.atleast_1d(base_quaternion.squeeze()).astype(
563
- float
564
- )
565
- )
620
+ physics_model=self.state.physics_model.replace(base_quaternion=W_Q_B)
566
621
  ),
567
622
  )
568
623
 
624
+ @js.common.named_scope
569
625
  @jax.jit
570
626
  def reset_base_pose(self, base_pose: jtp.MatrixLike) -> Self:
571
627
  """
@@ -582,14 +638,13 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
582
638
 
583
639
  W_p_B = base_pose[0:3, 3]
584
640
 
585
- to_wxyz = np.array([3, 0, 1, 2])
586
- W_R_B: jaxlie.SO3 = jaxlie.SO3.from_matrix(base_pose[0:3, 0:3]) # noqa
587
- W_Q_B = W_R_B.as_quaternion_xyzw()[to_wxyz]
641
+ W_Q_B = jaxsim.math.Quaternion.from_dcm(dcm=base_pose[0:3, 0:3])
588
642
 
589
643
  return self.reset_base_position(base_position=W_p_B).reset_base_quaternion(
590
644
  base_quaternion=W_Q_B
591
645
  )
592
646
 
647
+ @js.common.named_scope
593
648
  @functools.partial(jax.jit, static_argnames=["velocity_representation"])
594
649
  def reset_base_linear_velocity(
595
650
  self,
@@ -613,11 +668,15 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
613
668
 
614
669
  return self.reset_base_velocity(
615
670
  base_velocity=jnp.hstack(
616
- [linear_velocity.squeeze(), self.base_velocity()[3:6]]
671
+ [
672
+ linear_velocity.squeeze(),
673
+ self.base_velocity()[3:6],
674
+ ]
617
675
  ),
618
676
  velocity_representation=velocity_representation,
619
677
  )
620
678
 
679
+ @js.common.named_scope
621
680
  @functools.partial(jax.jit, static_argnames=["velocity_representation"])
622
681
  def reset_base_angular_velocity(
623
682
  self,
@@ -641,11 +700,15 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
641
700
 
642
701
  return self.reset_base_velocity(
643
702
  base_velocity=jnp.hstack(
644
- [self.base_velocity()[0:3], angular_velocity.squeeze()]
703
+ [
704
+ self.base_velocity()[0:3],
705
+ angular_velocity.squeeze(),
706
+ ]
645
707
  ),
646
708
  velocity_representation=velocity_representation,
647
709
  )
648
710
 
711
+ @js.common.named_scope
649
712
  @functools.partial(jax.jit, static_argnames=["velocity_representation"])
650
713
  def reset_base_velocity(
651
714
  self,
@@ -691,8 +754,9 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
691
754
  )
692
755
 
693
756
 
757
+ @functools.partial(jax.jit, static_argnames=["velocity_representation", "base_rpy_seq"])
694
758
  def random_model_data(
695
- model: jaxsim.api.model.JaxSimModel,
759
+ model: js.model.JaxSimModel,
696
760
  *,
697
761
  key: jax.Array | None = None,
698
762
  velocity_representation: VelRepr | None = None,
@@ -700,6 +764,18 @@ def random_model_data(
700
764
  jtp.FloatLike | Sequence[jtp.FloatLike],
701
765
  jtp.FloatLike | Sequence[jtp.FloatLike],
702
766
  ] = ((-1, -1, 0.5), 1.0),
767
+ base_rpy_bounds: tuple[
768
+ jtp.FloatLike | Sequence[jtp.FloatLike],
769
+ jtp.FloatLike | Sequence[jtp.FloatLike],
770
+ ] = (-jnp.pi, jnp.pi),
771
+ base_rpy_seq: str = "XYZ",
772
+ joint_pos_bounds: (
773
+ tuple[
774
+ jtp.FloatLike | Sequence[jtp.FloatLike],
775
+ jtp.FloatLike | Sequence[jtp.FloatLike],
776
+ ]
777
+ | None
778
+ ) = None,
703
779
  base_vel_lin_bounds: tuple[
704
780
  jtp.FloatLike | Sequence[jtp.FloatLike],
705
781
  jtp.FloatLike | Sequence[jtp.FloatLike],
@@ -712,6 +788,11 @@ def random_model_data(
712
788
  jtp.FloatLike | Sequence[jtp.FloatLike],
713
789
  jtp.FloatLike | Sequence[jtp.FloatLike],
714
790
  ] = (-1.0, 1.0),
791
+ contacts_params: jaxsim.rbda.contacts.ContactsParams | None = None,
792
+ standard_gravity_bounds: tuple[jtp.FloatLike, jtp.FloatLike] = (
793
+ jaxsim.math.StandardGravity,
794
+ jaxsim.math.StandardGravity,
795
+ ),
715
796
  ) -> JaxSimModelData:
716
797
  """
717
798
  Randomly generate a `JaxSimModelData` object.
@@ -721,19 +802,29 @@ def random_model_data(
721
802
  key: The random key.
722
803
  velocity_representation: The velocity representation to use.
723
804
  base_pos_bounds: The bounds for the base position.
805
+ base_rpy_bounds:
806
+ The bounds for the euler angles used to build the base orientation.
807
+ base_rpy_seq:
808
+ The sequence of axes for rotation (using `Rotation` from scipy).
809
+ joint_pos_bounds:
810
+ The bounds for the joint positions (reading the joint limits if None).
724
811
  base_vel_lin_bounds: The bounds for the base linear velocity.
725
812
  base_vel_ang_bounds: The bounds for the base angular velocity.
726
813
  joint_vel_bounds: The bounds for the joint velocities.
814
+ contacts_params: The parameters of the contact model.
815
+ standard_gravity_bounds: The bounds for the standard gravity.
727
816
 
728
817
  Returns:
729
818
  A `JaxSimModelData` object with random data.
730
819
  """
731
820
 
732
821
  key = key if key is not None else jax.random.PRNGKey(seed=0)
733
- k1, k2, k3, k4, k5, k6 = jax.random.split(key, num=6)
822
+ k1, k2, k3, k4, k5, k6, k7 = jax.random.split(key, num=7)
734
823
 
735
824
  p_min = jnp.array(base_pos_bounds[0], dtype=float)
736
825
  p_max = jnp.array(base_pos_bounds[1], dtype=float)
826
+ rpy_min = jnp.array(base_rpy_bounds[0], dtype=float)
827
+ rpy_max = jnp.array(base_rpy_bounds[1], dtype=float)
737
828
  v_min = jnp.array(base_vel_lin_bounds[0], dtype=float)
738
829
  v_max = jnp.array(base_vel_lin_bounds[1], dtype=float)
739
830
  ω_min = jnp.array(base_vel_ang_bounds[0], dtype=float)
@@ -749,7 +840,9 @@ def random_model_data(
749
840
  ),
750
841
  )
751
842
 
752
- with random_data.mutable_context(mutability=Mutability.MUTABLE):
843
+ with random_data.mutable_context(
844
+ mutability=Mutability.MUTABLE, restore_after_exception=False
845
+ ):
753
846
 
754
847
  physics_model_state = random_data.state.physics_model
755
848
 
@@ -757,24 +850,76 @@ def random_model_data(
757
850
  key=k1, shape=(3,), minval=p_min, maxval=p_max
758
851
  )
759
852
 
760
- physics_model_state.base_quaternion = jaxlie.SO3.from_rpy_radians(
761
- *jax.random.uniform(key=k2, shape=(3,), minval=0, maxval=2 * jnp.pi)
762
- ).as_quaternion_xyzw()[np.array([3, 0, 1, 2])]
763
-
764
- physics_model_state.joint_positions = jaxsim.api.joint.random_joint_positions(
765
- model=model, key=k3
853
+ physics_model_state.base_quaternion = jaxsim.math.Quaternion.to_wxyz(
854
+ xyzw=jax.scipy.spatial.transform.Rotation.from_euler(
855
+ seq=base_rpy_seq,
856
+ angles=jax.random.uniform(
857
+ key=k2, shape=(3,), minval=rpy_min, maxval=rpy_max
858
+ ),
859
+ ).as_quat()
766
860
  )
767
861
 
768
- physics_model_state.base_linear_velocity = jax.random.uniform(
769
- key=k4, shape=(3,), minval=v_min, maxval=v_max
770
- )
862
+ if model.number_of_joints() > 0:
771
863
 
772
- physics_model_state.base_angular_velocity = jax.random.uniform(
773
- key=k5, shape=(3,), minval=ω_min, maxval=ω_max
774
- )
864
+ s_min, s_max = (
865
+ jnp.array(joint_pos_bounds, dtype=float)
866
+ if joint_pos_bounds is not None
867
+ else (None, None)
868
+ )
869
+
870
+ physics_model_state.joint_positions = (
871
+ js.joint.random_joint_positions(model=model, key=k3)
872
+ if (s_min is None or s_max is None)
873
+ else jax.random.uniform(
874
+ key=k3, shape=(model.dofs(),), minval=s_min, maxval=s_max
875
+ )
876
+ )
877
+
878
+ physics_model_state.joint_velocities = jax.random.uniform(
879
+ key=k4, shape=(model.dofs(),), minval=ṡ_min, maxval=ṡ_max
880
+ )
881
+
882
+ if model.floating_base():
883
+ physics_model_state.base_linear_velocity = jax.random.uniform(
884
+ key=k5, shape=(3,), minval=v_min, maxval=v_max
885
+ )
886
+
887
+ physics_model_state.base_angular_velocity = jax.random.uniform(
888
+ key=k6, shape=(3,), minval=ω_min, maxval=ω_max
889
+ )
775
890
 
776
- physics_model_state.joint_velocities = jax.random.uniform(
777
- key=k6, shape=(model.dofs(),), minval=ṡ_min, maxval=ṡ_max
891
+ random_data.gravity = (
892
+ jnp.zeros(3, dtype=random_data.gravity.dtype)
893
+ .at[2]
894
+ .set(
895
+ -jax.random.uniform(
896
+ key=k7,
897
+ shape=(),
898
+ minval=standard_gravity_bounds[0],
899
+ maxval=standard_gravity_bounds[1],
900
+ )
901
+ )
778
902
  )
779
903
 
904
+ if contacts_params is None:
905
+
906
+ if isinstance(
907
+ model.contact_model,
908
+ jaxsim.rbda.contacts.SoftContacts
909
+ | jaxsim.rbda.contacts.ViscoElasticContacts,
910
+ ):
911
+
912
+ random_data = random_data.replace(
913
+ contacts_params=js.contact.estimate_good_contact_parameters(
914
+ model=model, standard_gravity=random_data.gravity
915
+ ),
916
+ validate=False,
917
+ )
918
+
919
+ else:
920
+ random_data = random_data.replace(
921
+ contacts_params=model.contact_model._parameters_class(),
922
+ validate=False,
923
+ )
924
+
780
925
  return random_data