jaxsim 0.6.2.dev182__py3-none-any.whl → 0.6.2.dev225__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
jaxsim/_version.py CHANGED
@@ -17,5 +17,5 @@ __version__: str
17
17
  __version_tuple__: VERSION_TUPLE
18
18
  version_tuple: VERSION_TUPLE
19
19
 
20
- __version__ = version = '0.6.2.dev182'
21
- __version_tuple__ = version_tuple = (0, 6, 2, 'dev182')
20
+ __version__ = version = '0.6.2.dev225'
21
+ __version_tuple__ = version_tuple = (0, 6, 2, 'dev225')
jaxsim/api/__init__.py CHANGED
@@ -4,7 +4,6 @@ from . import (
4
4
  actuation_model,
5
5
  com,
6
6
  contact,
7
- contact_model,
8
7
  frame,
9
8
  integrators,
10
9
  joint,
jaxsim/api/com.py CHANGED
@@ -301,9 +301,7 @@ def bias_acceleration(
301
301
  C_v̇_WL = W_v̇_bias_WL = v̇_bias_WL # noqa: F841
302
302
  C_v_WC = W_v_WW = jnp.zeros(6) # noqa: F841
303
303
 
304
- L_H_C = L_H_W = jax.vmap( # noqa: F841
305
- lambda W_H_L: jaxsim.math.Transform.inverse(W_H_L)
306
- )(W_H_L)
304
+ L_H_C = L_H_W = jax.vmap(jaxsim.math.Transform.inverse)(W_H_L) # noqa: F841
307
305
 
308
306
  L_v_LC = L_v_LW = jax.vmap( # noqa: F841
309
307
  lambda i: -js.link.velocity(
jaxsim/api/common.py CHANGED
@@ -121,14 +121,8 @@ class ModelDataWithVelocityRepresentation(JaxsimDataclass, abc.ABC):
121
121
  The 6D quantity in the other representation.
122
122
  """
123
123
 
124
- W_array = array.squeeze()
125
- W_H_O = transform.squeeze()
126
-
127
- if W_array.size != 6:
128
- raise ValueError(W_array.size, 6)
129
-
130
- if W_H_O.shape != (4, 4):
131
- raise ValueError(W_H_O.shape, (4, 4))
124
+ W_array = array
125
+ W_H_O = transform
132
126
 
133
127
  match other_representation:
134
128
 
@@ -139,25 +133,24 @@ class ModelDataWithVelocityRepresentation(JaxsimDataclass, abc.ABC):
139
133
 
140
134
  if not is_force:
141
135
  O_Xv_W = Adjoint.from_transform(transform=W_H_O, inverse=True)
142
- O_array = O_Xv_W @ W_array
136
+ O_array = jnp.einsum("...ij,...j->...i", O_Xv_W, W_array)
143
137
 
144
138
  else:
145
- O_Xf_W = Adjoint.from_transform(transform=W_H_O).T
146
- O_array = O_Xf_W @ W_array
139
+ O_Xf_W = Adjoint.from_transform(transform=W_H_O).swapaxes(-1, -2)
140
+ O_array = jnp.einsum("...ij,...j->...i", O_Xf_W, W_array)
147
141
 
148
142
  return O_array
149
143
 
150
144
  case VelRepr.Mixed:
151
- W_p_O = W_H_O[0:3, 3]
152
- W_H_OW = jnp.eye(4).at[0:3, 3].set(W_p_O)
145
+ W_H_OW = W_H_O.at[..., 0:3, 0:3].set(jnp.eye(3))
153
146
 
154
147
  if not is_force:
155
148
  OW_Xv_W = Adjoint.from_transform(transform=W_H_OW, inverse=True)
156
- OW_array = OW_Xv_W @ W_array
149
+ OW_array = jnp.einsum("...ij,...j->...i", OW_Xv_W, W_array)
157
150
 
158
151
  else:
159
- OW_Xf_W = Adjoint.from_transform(transform=W_H_OW).T
160
- OW_array = OW_Xf_W @ W_array
152
+ OW_Xf_W = Adjoint.from_transform(transform=W_H_OW).swapaxes(-1, -2)
153
+ OW_array = jnp.einsum("...ij,...j->...i", OW_Xf_W, W_array)
161
154
 
162
155
  return OW_array
163
156
 
@@ -188,45 +181,40 @@ class ModelDataWithVelocityRepresentation(JaxsimDataclass, abc.ABC):
188
181
  The 6D quantity in the inertial-fixed representation.
189
182
  """
190
183
 
191
- W_array = array.squeeze()
192
- W_H_O = transform.squeeze()
193
-
194
- if W_array.size != 6:
195
- raise ValueError(W_array.size, 6)
196
-
197
- if W_H_O.shape != (4, 4):
198
- raise ValueError(W_H_O.shape, (4, 4))
184
+ O_array = array
185
+ W_H_O = transform
199
186
 
200
187
  match other_representation:
201
188
  case VelRepr.Inertial:
202
- W_array = array
203
- return W_array
189
+ return O_array
204
190
 
205
191
  case VelRepr.Body:
206
- O_array = array
207
192
 
208
193
  if not is_force:
209
- W_Xv_O: jtp.Array = Adjoint.from_transform(W_H_O)
210
- W_array = W_Xv_O @ O_array
194
+ W_Xv_O = Adjoint.from_transform(W_H_O)
195
+ W_array = jnp.einsum("...ij,...j->...i", W_Xv_O, O_array)
211
196
 
212
197
  else:
213
- W_Xf_O = Adjoint.from_transform(transform=W_H_O, inverse=True).T
214
- W_array = W_Xf_O @ O_array
198
+ W_Xf_O = Adjoint.from_transform(
199
+ transform=W_H_O, inverse=True
200
+ ).swapaxes(-1, -2)
201
+ W_array = jnp.einsum("...ij,...j->...i", W_Xf_O, O_array)
215
202
 
216
203
  return W_array
217
204
 
218
205
  case VelRepr.Mixed:
219
- BW_array = array
220
- W_p_O = W_H_O[0:3, 3]
221
- W_H_OW = jnp.eye(4).at[0:3, 3].set(W_p_O)
206
+
207
+ W_H_OW = W_H_O.at[..., 0:3, 0:3].set(jnp.eye(3))
222
208
 
223
209
  if not is_force:
224
- W_Xv_BW: jtp.Array = Adjoint.from_transform(W_H_OW)
225
- W_array = W_Xv_BW @ BW_array
210
+ W_Xv_BW = Adjoint.from_transform(W_H_OW)
211
+ W_array = jnp.einsum("...ij,...j->...i", W_Xv_BW, O_array)
226
212
 
227
213
  else:
228
- W_Xf_BW = Adjoint.from_transform(transform=W_H_OW, inverse=True).T
229
- W_array = W_Xf_BW @ BW_array
214
+ W_Xf_BW = Adjoint.from_transform(
215
+ transform=W_H_OW, inverse=True
216
+ ).swapaxes(-1, -2)
217
+ W_array = jnp.einsum("...ij,...j->...i", W_Xf_BW, O_array)
230
218
 
231
219
  return W_array
232
220
 
jaxsim/api/contact.py CHANGED
@@ -11,7 +11,7 @@ import jaxsim.terrain
11
11
  import jaxsim.typing as jtp
12
12
  from jaxsim import logging
13
13
  from jaxsim.math import Adjoint, Cross, Transform
14
- from jaxsim.rbda import contacts
14
+ from jaxsim.rbda.contacts import SoftContacts
15
15
 
16
16
  from .common import VelRepr
17
17
 
@@ -37,14 +37,11 @@ def collidable_point_kinematics(
37
37
  the linear component of the mixed 6D frame velocity.
38
38
  """
39
39
 
40
- # Switch to inertial-fixed since the RBDAs expect velocities in this representation.
41
- with data.switch_velocity_representation(VelRepr.Inertial):
42
-
43
- W_p_Ci, W_ṗ_Ci = jaxsim.rbda.collidable_points.collidable_points_pos_vel(
44
- model=model,
45
- link_transforms=data._link_transforms,
46
- link_velocities=data._link_velocities,
47
- )
40
+ W_p_Ci, W_ṗ_Ci = jaxsim.rbda.collidable_points.collidable_points_pos_vel(
41
+ model=model,
42
+ link_transforms=data._link_transforms,
43
+ link_velocities=data._link_velocities,
44
+ )
48
45
 
49
46
  return W_p_Ci, W_ṗ_Ci
50
47
 
@@ -164,18 +161,23 @@ def estimate_good_soft_contacts_parameters(
164
161
  def estimate_good_contact_parameters(
165
162
  model: js.model.JaxSimModel,
166
163
  *,
164
+ standard_gravity: jtp.FloatLike = jaxsim.math.STANDARD_GRAVITY,
167
165
  static_friction_coefficient: jtp.FloatLike = 0.5,
168
- **kwargs,
166
+ number_of_active_collidable_points_steady_state: jtp.IntLike = 1,
167
+ damping_ratio: jtp.FloatLike = 1.0,
168
+ max_penetration: jtp.FloatLike | None = None,
169
169
  ) -> jaxsim.rbda.contacts.ContactParamsTypes:
170
170
  """
171
171
  Estimate good contact parameters.
172
172
 
173
173
  Args:
174
174
  model: The model to consider.
175
+ standard_gravity: The standard gravity acceleration.
175
176
  static_friction_coefficient: The static friction coefficient.
176
- kwargs:
177
- Additional model-specific parameters passed to the builder method of
178
- the parameters class.
177
+ number_of_active_collidable_points_steady_state:
178
+ The number of active collidable points in steady state.
179
+ damping_ratio: The damping ratio.
180
+ max_penetration: The maximum penetration allowed.
179
181
 
180
182
  Returns:
181
183
  The estimated good contacts parameters.
@@ -190,20 +192,41 @@ def estimate_good_contact_parameters(
190
192
  specific application.
191
193
  """
192
194
 
193
- match model.contact_model:
195
+ def estimate_model_height(model: js.model.JaxSimModel) -> jtp.Float:
196
+ """
197
+ Displacement between the CoM and the lowest collidable point using zero
198
+ joint positions.
199
+ """
200
+
201
+ zero_data = js.data.JaxSimModelData.build(
202
+ model=model,
203
+ )
204
+
205
+ W_pz_CoM = js.com.com_position(model=model, data=zero_data)[2]
194
206
 
195
- case contacts.RelaxedRigidContacts():
196
- assert isinstance(model.contact_model, contacts.RelaxedRigidContacts)
207
+ if model.floating_base():
208
+ W_pz_C = collidable_point_positions(model=model, data=zero_data)[:, -1]
209
+ return 2 * (W_pz_CoM - W_pz_C.min())
197
210
 
198
- parameters = contacts.RelaxedRigidContactsParams.build(
199
- mu=static_friction_coefficient,
200
- **kwargs,
201
- )
211
+ return 2 * W_pz_CoM
202
212
 
203
- case _:
204
- raise ValueError(f"Invalid contact model: {model.contact_model}")
213
+ max_δ = (
214
+ max_penetration
215
+ if max_penetration is not None
216
+ # Consider as default a 0.5% of the model height.
217
+ else 0.005 * estimate_model_height(model=model)
218
+ )
205
219
 
206
- return parameters
220
+ nc = number_of_active_collidable_points_steady_state
221
+
222
+ return model.contact_model._parameters_class().build_default_from_jaxsim_model(
223
+ model=model,
224
+ standard_gravity=standard_gravity,
225
+ static_friction_coefficient=static_friction_coefficient,
226
+ max_penetration=max_δ,
227
+ number_of_active_collidable_points_steady_state=nc,
228
+ damping_ratio=damping_ratio,
229
+ )
207
230
 
208
231
 
209
232
  @jax.jit
@@ -244,7 +267,7 @@ def transforms(model: js.model.JaxSimModel, data: js.data.JaxSimModelData) -> jt
244
267
 
245
268
  # Build the link-to-point transform from the displacement between the link frame L
246
269
  # and the implicit contact frame C.
247
- L_H_C = jax.vmap(lambda L_p_C: jnp.eye(4).at[0:3, 3].set(L_p_C))(L_p_Ci)
270
+ L_H_C = jax.vmap(jnp.eye(4).at[0:3, 3].set)(L_p_Ci)
248
271
 
249
272
  # Compose the work-to-link and link-to-point transforms.
250
273
  return jax.vmap(lambda W_H_Li, L_H_Ci: W_H_Li @ L_H_Ci)(W_H_L, L_H_C)
@@ -504,3 +527,96 @@ def jacobian_derivative(
504
527
  )
505
528
 
506
529
  return O_J̇_WC
530
+
531
+
532
+ @jax.jit
533
+ @js.common.named_scope
534
+ def link_contact_forces(
535
+ model: js.model.JaxSimModel,
536
+ data: js.data.JaxSimModelData,
537
+ *,
538
+ link_forces: jtp.MatrixLike | None = None,
539
+ joint_torques: jtp.VectorLike | None = None,
540
+ ) -> tuple[jtp.Matrix, dict[str, jtp.Matrix]]:
541
+ """
542
+ Compute the 6D contact forces of all links of the model in inertial representation.
543
+
544
+ Args:
545
+ model: The model to consider.
546
+ data: The data of the considered model.
547
+ link_forces:
548
+ The 6D external forces to apply to the links expressed in inertial representation
549
+ joint_torques:
550
+ The joint torques acting on the joints.
551
+
552
+ Returns:
553
+ A `(nL, 6)` array containing the stacked 6D contact forces of the links,
554
+ expressed in inertial representation.
555
+ """
556
+
557
+ # Compute the contact forces for each collidable point with the active contact model.
558
+ W_f_C, aux_dict = model.contact_model.compute_contact_forces(
559
+ model=model,
560
+ data=data,
561
+ **(
562
+ dict(link_forces=link_forces, joint_force_references=joint_torques)
563
+ if not isinstance(model.contact_model, SoftContacts)
564
+ else {}
565
+ ),
566
+ )
567
+
568
+ # Compute the 6D forces applied to the links equivalent to the forces applied
569
+ # to the frames associated to the collidable points.
570
+ W_f_L = link_forces_from_contact_forces(model=model, contact_forces=W_f_C)
571
+
572
+ return W_f_L, aux_dict
573
+
574
+
575
+ @staticmethod
576
+ def link_forces_from_contact_forces(
577
+ model: js.model.JaxSimModel,
578
+ *,
579
+ contact_forces: jtp.MatrixLike,
580
+ ) -> jtp.Matrix:
581
+ """
582
+ Compute the link forces from the contact forces.
583
+
584
+ Args:
585
+ model: The robot model considered by the contact model.
586
+ contact_forces: The contact forces computed by the contact model.
587
+
588
+ Returns:
589
+ The 6D contact forces applied to the links and expressed in the frame of
590
+ the velocity representation of data.
591
+ """
592
+
593
+ # Get the object storing the contact parameters of the model.
594
+ contact_parameters = model.kin_dyn_parameters.contact_parameters
595
+
596
+ # Extract the indices corresponding to the enabled collidable points.
597
+ indices_of_enabled_collidable_points = (
598
+ contact_parameters.indices_of_enabled_collidable_points
599
+ )
600
+
601
+ # Convert the contact forces to a JAX array.
602
+ W_f_C = jnp.atleast_2d(jnp.array(contact_forces, dtype=float).squeeze())
603
+
604
+ # Construct the vector defining the parent link index of each collidable point.
605
+ # We use this vector to sum the 6D forces of all collidable points rigidly
606
+ # attached to the same link.
607
+ parent_link_index_of_collidable_points = jnp.array(
608
+ contact_parameters.body, dtype=int
609
+ )[indices_of_enabled_collidable_points]
610
+
611
+ # Create the mask that associate each collidable point to their parent link.
612
+ # We use this mask to sum the collidable points to the right link.
613
+ mask = parent_link_index_of_collidable_points[:, jnp.newaxis] == jnp.arange(
614
+ model.number_of_links()
615
+ )
616
+
617
+ # Sum the forces of all collidable points rigidly attached to a body.
618
+ # Since the contact forces W_f_C are expressed in the world frame,
619
+ # we don't need any coordinate transformation.
620
+ W_f_L = mask.T @ W_f_C
621
+
622
+ return W_f_L
jaxsim/api/data.py CHANGED
@@ -5,9 +5,9 @@ import functools
5
5
  from collections.abc import Sequence
6
6
 
7
7
  try:
8
- from typing import override
8
+ from typing import Self, override
9
9
  except ImportError:
10
- from typing_extensions import override
10
+ from typing_extensions import override, Self
11
11
 
12
12
  import jax
13
13
  import jax.numpy as jnp
@@ -22,11 +22,6 @@ import jaxsim.typing as jtp
22
22
  from . import common
23
23
  from .common import VelRepr
24
24
 
25
- try:
26
- from typing import Self
27
- except ImportError:
28
- from typing_extensions import Self
29
-
30
25
 
31
26
  @jax_dataclasses.pytree_dataclass
32
27
  class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
@@ -64,6 +59,9 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
64
59
  _link_transforms: jtp.Matrix = dataclasses.field(repr=False, default=None)
65
60
  _link_velocities: jtp.Matrix = dataclasses.field(repr=False, default=None)
66
61
 
62
+ # Extended state for soft and rigid contact models.
63
+ contact_state: dict[str, jtp.Array] = dataclasses.field(default=None)
64
+
67
65
  @staticmethod
68
66
  def build(
69
67
  model: js.model.JaxSimModel,
@@ -73,6 +71,7 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
73
71
  base_linear_velocity: jtp.VectorLike | None = None,
74
72
  base_angular_velocity: jtp.VectorLike | None = None,
75
73
  joint_velocities: jtp.VectorLike | None = None,
74
+ contact_state: dict[str, jtp.Array] | None = None,
76
75
  velocity_representation: VelRepr = VelRepr.Mixed,
77
76
  ) -> JaxSimModelData:
78
77
  """
@@ -89,6 +88,7 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
89
88
  The base angular velocity in the selected representation.
90
89
  joint_velocities: The joint velocities.
91
90
  velocity_representation: The velocity representation to use. It defaults to mixed if not provided.
91
+ contact_state: The optional contact state.
92
92
 
93
93
  Returns:
94
94
  A `JaxSimModelData` initialized with the given state.
@@ -171,6 +171,14 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
171
171
  )
172
172
  )
173
173
 
174
+ contact_state = contact_state or {}
175
+
176
+ if isinstance(model.contact_model, jaxsim.rbda.contacts.SoftContacts):
177
+ contact_state.setdefault(
178
+ "tangential_deformation",
179
+ jnp.zeros_like(model.kin_dyn_parameters.contact_parameters.point),
180
+ )
181
+
174
182
  model_data = JaxSimModelData(
175
183
  velocity_representation=velocity_representation,
176
184
  _base_quaternion=base_quaternion,
@@ -183,6 +191,7 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
183
191
  _joint_transforms=joint_transforms,
184
192
  _link_transforms=link_transforms,
185
193
  _link_velocities=link_velocities_inertial,
194
+ contact_state=contact_state,
186
195
  )
187
196
 
188
197
  if not model_data.valid(model=model):
@@ -265,14 +274,14 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
265
274
  """
266
275
 
267
276
  # Extract the base quaternion.
268
- W_Q_B = self.base_quaternion.squeeze()
277
+ W_Q_B = self.base_quaternion
269
278
 
270
279
  # Always normalize the quaternion to avoid numerical issues.
271
280
  # If the active scheme does not integrate the quaternion on its manifold,
272
281
  # we introduce a Baumgarte stabilization to let the quaternion converge to
273
282
  # a unit quaternion. In this case, it is not guaranteed that the quaternion
274
283
  # stored in the state is a unit quaternion.
275
- norm = jaxsim.math.safe_norm(W_Q_B)
284
+ norm = jaxsim.math.safe_norm(W_Q_B, axis=-1, keepdims=True)
276
285
  W_Q_B = W_Q_B / (norm + jnp.finfo(float).eps * (norm == 0))
277
286
  return W_Q_B
278
287
 
@@ -285,11 +294,8 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
285
294
  The base 6D velocity in the active representation.
286
295
  """
287
296
 
288
- W_v_WB = jnp.hstack(
289
- [
290
- self._base_linear_velocity,
291
- self._base_angular_velocity,
292
- ]
297
+ W_v_WB = jnp.concatenate(
298
+ [self._base_linear_velocity, self._base_angular_velocity], axis=-1
293
299
  )
294
300
 
295
301
  W_H_B = self._base_transform
@@ -350,11 +356,14 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
350
356
 
351
357
  @js.common.named_scope
352
358
  @jax.jit
353
- def reset_base_quaternion(self, base_quaternion: jtp.VectorLike) -> Self:
359
+ def reset_base_quaternion(
360
+ self, model: js.model.JaxSimModel, base_quaternion: jtp.VectorLike
361
+ ) -> Self:
354
362
  """
355
363
  Reset the base quaternion.
356
364
 
357
365
  Args:
366
+ model: The JaxSim model to use.
358
367
  base_quaternion: The base orientation as a quaternion.
359
368
 
360
369
  Returns:
@@ -363,18 +372,21 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
363
372
 
364
373
  W_Q_B = jnp.array(base_quaternion, dtype=float)
365
374
 
366
- norm = jaxsim.math.safe_norm(W_Q_B)
375
+ norm = jaxsim.math.safe_norm(W_Q_B, axis=-1)
367
376
  W_Q_B = W_Q_B / (norm + jnp.finfo(float).eps * (norm == 0))
368
377
 
369
- return self.replace(validate=True, base_quaternion=W_Q_B)
378
+ return self.replace(model=model, base_quaternion=W_Q_B)
370
379
 
371
380
  @js.common.named_scope
372
381
  @jax.jit
373
- def reset_base_pose(self, base_pose: jtp.MatrixLike) -> Self:
382
+ def reset_base_pose(
383
+ self, model: js.model.JaxSimModel, base_pose: jtp.MatrixLike
384
+ ) -> Self:
374
385
  """
375
386
  Reset the base pose.
376
387
 
377
388
  Args:
389
+ model: The JaxSim model to use.
378
390
  base_pose: The base pose as an SE(3) matrix.
379
391
 
380
392
  Returns:
@@ -385,6 +397,7 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
385
397
  W_p_B = base_pose[0:3, 3]
386
398
  W_Q_B = jaxsim.math.Quaternion.from_dcm(dcm=base_pose[0:3, 0:3])
387
399
  return self.replace(
400
+ model=model,
388
401
  base_position=W_p_B,
389
402
  base_quaternion=W_Q_B,
390
403
  )
@@ -399,11 +412,19 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
399
412
  base_linear_velocity: jtp.Vector | None = None,
400
413
  base_angular_velocity: jtp.Vector | None = None,
401
414
  base_position: jtp.Vector | None = None,
415
+ *,
416
+ contact_state: dict[str, jtp.Array] | None = None,
402
417
  validate: bool = False,
403
418
  ) -> Self:
404
419
  """
405
420
  Replace the attributes of the `JaxSimModelData` object.
406
421
  """
422
+
423
+ # Extract the batch size.
424
+ batch_size = (
425
+ self._base_transform.shape[0] if self._base_transform.ndim > 2 else 1
426
+ )
427
+
407
428
  if joint_positions is None:
408
429
  joint_positions = self.joint_positions
409
430
  if joint_velocities is None:
@@ -412,6 +433,22 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
412
433
  base_quaternion = self.base_quaternion
413
434
  if base_position is None:
414
435
  base_position = self.base_position
436
+ if contact_state is None:
437
+ contact_state = self.contact_state
438
+
439
+ if isinstance(model.contact_model, jaxsim.rbda.contacts.SoftContacts):
440
+ contact_state.setdefault(
441
+ "tangential_deformation",
442
+ jnp.zeros_like(model.kin_dyn_parameters.contact_parameters.point),
443
+ )
444
+
445
+ # Normalize the quaternion to avoid numerical issues.
446
+ base_quaternion_norm = jaxsim.math.safe_norm(
447
+ base_quaternion, axis=-1, keepdims=True
448
+ )
449
+ base_quaternion = base_quaternion / jnp.where(
450
+ base_quaternion_norm == 0, 1.0, base_quaternion_norm
451
+ )
415
452
 
416
453
  joint_positions = jnp.atleast_1d(joint_positions.squeeze()).astype(float)
417
454
  joint_velocities = jnp.atleast_1d(joint_velocities.squeeze()).astype(float)
@@ -421,44 +458,70 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
421
458
  base_transform = jaxsim.math.Transform.from_quaternion_and_translation(
422
459
  translation=base_position, quaternion=base_quaternion
423
460
  )
424
- joint_transforms = model.kin_dyn_parameters.joint_transforms(
425
- joint_positions=joint_positions, base_transform=base_transform
461
+
462
+ joint_transforms = jax.vmap(model.kin_dyn_parameters.joint_transforms)(
463
+ joint_positions=jnp.broadcast_to(
464
+ joint_positions, (batch_size, model.dofs())
465
+ ),
466
+ base_transform=jnp.broadcast_to(base_transform, (batch_size, 4, 4)),
426
467
  )
427
468
 
428
469
  if base_linear_velocity is None and base_angular_velocity is None:
429
- base_linear_velocity = self._base_linear_velocity
430
- base_angular_velocity = self._base_angular_velocity
470
+ base_linear_velocity_inertial = self._base_linear_velocity
471
+ base_angular_velocity_inertial = self._base_angular_velocity
431
472
  else:
432
473
  if base_linear_velocity is None:
433
474
  base_linear_velocity = self.base_velocity[:3]
434
475
  if base_angular_velocity is None:
435
476
  base_angular_velocity = self.base_velocity[3:]
477
+
436
478
  base_linear_velocity = jnp.atleast_1d(base_linear_velocity.squeeze())
437
479
  base_angular_velocity = jnp.atleast_1d(base_angular_velocity.squeeze())
480
+
438
481
  W_v_WB = JaxSimModelData.other_representation_to_inertial(
439
482
  array=jnp.hstack([base_linear_velocity, base_angular_velocity]),
440
483
  other_representation=self.velocity_representation,
441
484
  transform=base_transform,
442
485
  is_force=False,
443
486
  ).astype(float)
444
- base_linear_velocity, base_angular_velocity = W_v_WB[:3], W_v_WB[3:]
445
487
 
446
- link_transforms, link_velocities = jaxsim.rbda.forward_kinematics_model(
447
- model=model,
448
- base_position=base_position,
449
- base_quaternion=base_quaternion,
450
- joint_positions=joint_positions,
451
- joint_velocities=joint_velocities,
452
- base_linear_velocity_inertial=base_linear_velocity,
453
- base_angular_velocity_inertial=base_angular_velocity,
488
+ base_linear_velocity_inertial, base_angular_velocity_inertial = (
489
+ W_v_WB[..., :3],
490
+ W_v_WB[..., 3:],
491
+ )
492
+
493
+ link_transforms, link_velocities = jax.vmap(
494
+ jaxsim.rbda.forward_kinematics_model, in_axes=(None,)
495
+ )(
496
+ model,
497
+ base_position=jnp.broadcast_to(base_position, (batch_size, 3)),
498
+ base_quaternion=jnp.broadcast_to(base_quaternion, (batch_size, 4)),
499
+ joint_positions=jnp.broadcast_to(
500
+ joint_positions, (batch_size, model.dofs())
501
+ ),
502
+ joint_velocities=jnp.broadcast_to(
503
+ joint_velocities, (batch_size, model.dofs())
504
+ ),
505
+ base_linear_velocity_inertial=jnp.broadcast_to(
506
+ base_linear_velocity_inertial, (batch_size, 3)
507
+ ),
508
+ base_angular_velocity_inertial=jnp.broadcast_to(
509
+ base_angular_velocity_inertial, (batch_size, 3)
510
+ ),
454
511
  )
455
512
 
513
+ # Adjust the output shapes.
514
+ if batch_size == 1:
515
+ link_transforms = link_transforms.reshape(self._link_transforms.shape)
516
+ link_velocities = link_velocities.reshape(self._link_velocities.shape)
517
+ joint_transforms = joint_transforms.reshape(self._joint_transforms.shape)
518
+
456
519
  return super().replace(
457
520
  _joint_positions=joint_positions,
458
521
  _joint_velocities=joint_velocities,
459
522
  _base_quaternion=base_quaternion,
460
- _base_linear_velocity=base_linear_velocity,
461
- _base_angular_velocity=base_angular_velocity,
523
+ _base_linear_velocity=base_linear_velocity_inertial,
524
+ _base_angular_velocity=base_angular_velocity_inertial,
462
525
  _base_position=base_position,
463
526
  _base_transform=base_transform,
464
527
  _joint_transforms=joint_transforms,