jaxsim 0.6.2.dev2__py3-none-any.whl → 0.6.2.dev102__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (46)
  1. jaxsim/__init__.py +1 -1
  2. jaxsim/_version.py +2 -2
  3. jaxsim/api/__init__.py +3 -1
  4. jaxsim/api/actuation_model.py +96 -0
  5. jaxsim/api/com.py +8 -8
  6. jaxsim/api/contact.py +15 -255
  7. jaxsim/api/contact_model.py +101 -0
  8. jaxsim/api/data.py +258 -556
  9. jaxsim/api/frame.py +7 -7
  10. jaxsim/api/integrators.py +76 -0
  11. jaxsim/api/kin_dyn_parameters.py +41 -58
  12. jaxsim/api/link.py +7 -7
  13. jaxsim/api/model.py +190 -453
  14. jaxsim/api/ode.py +34 -338
  15. jaxsim/api/references.py +2 -2
  16. jaxsim/exceptions.py +2 -2
  17. jaxsim/math/__init__.py +4 -3
  18. jaxsim/math/joint_model.py +17 -107
  19. jaxsim/mujoco/model.py +1 -1
  20. jaxsim/mujoco/utils.py +2 -2
  21. jaxsim/parsers/kinematic_graph.py +1 -3
  22. jaxsim/rbda/aba.py +7 -4
  23. jaxsim/rbda/collidable_points.py +7 -98
  24. jaxsim/rbda/contacts/__init__.py +2 -10
  25. jaxsim/rbda/contacts/common.py +0 -138
  26. jaxsim/rbda/contacts/relaxed_rigid.py +154 -9
  27. jaxsim/rbda/crba.py +5 -2
  28. jaxsim/rbda/forward_kinematics.py +37 -12
  29. jaxsim/rbda/jacobian.py +15 -6
  30. jaxsim/rbda/rnea.py +7 -4
  31. jaxsim/rbda/utils.py +3 -3
  32. jaxsim/utils/jaxsim_dataclass.py +5 -1
  33. {jaxsim-0.6.2.dev2.dist-info → jaxsim-0.6.2.dev102.dist-info}/METADATA +7 -9
  34. jaxsim-0.6.2.dev102.dist-info/RECORD +69 -0
  35. jaxsim/api/ode_data.py +0 -401
  36. jaxsim/integrators/__init__.py +0 -2
  37. jaxsim/integrators/common.py +0 -592
  38. jaxsim/integrators/fixed_step.py +0 -153
  39. jaxsim/integrators/variable_step.py +0 -706
  40. jaxsim/rbda/contacts/rigid.py +0 -462
  41. jaxsim/rbda/contacts/soft.py +0 -480
  42. jaxsim/rbda/contacts/visco_elastic.py +0 -1066
  43. jaxsim-0.6.2.dev2.dist-info/RECORD +0 -74
  44. {jaxsim-0.6.2.dev2.dist-info → jaxsim-0.6.2.dev102.dist-info}/LICENSE +0 -0
  45. {jaxsim-0.6.2.dev2.dist-info → jaxsim-0.6.2.dev102.dist-info}/WHEEL +0 -0
  46. {jaxsim-0.6.2.dev2.dist-info → jaxsim-0.6.2.dev102.dist-info}/top_level.txt +0 -0
jaxsim/api/frame.py CHANGED
@@ -229,7 +229,7 @@ def velocity(
229
229
  )
230
230
 
231
231
  # Get the generalized velocity in the input velocity representation.
232
- I_ν = data.generalized_velocity()
232
+ I_ν = data.generalized_velocity
233
233
 
234
234
  # Compute the frame velocity in the output velocity representation.
235
235
  return O_J_WF_I @ I_ν
@@ -401,9 +401,9 @@ def jacobian_derivative(
401
401
  Ṫ = compute_Ṫ(model=model, Ẋ=W_Ẋ_W)
402
402
 
403
403
  case VelRepr.Body:
404
- W_H_B = data.base_transform()
404
+ W_H_B = data._base_transform
405
405
  W_X_B = Adjoint.from_transform(transform=W_H_B)
406
- B_v_WB = data.base_velocity()
406
+ B_v_WB = data.base_velocity
407
407
  B_vx_WB = Cross.vx(B_v_WB)
408
408
  W_Ẋ_B = W_X_B @ B_vx_WB
409
409
 
@@ -411,10 +411,10 @@ def jacobian_derivative(
411
411
  Ṫ = compute_Ṫ(model=model, Ẋ=W_Ẋ_B)
412
412
 
413
413
  case VelRepr.Mixed:
414
- W_H_B = data.base_transform()
414
+ W_H_B = data._base_transform
415
415
  W_H_BW = W_H_B.at[0:3, 0:3].set(jnp.eye(3))
416
416
  W_X_BW = Adjoint.from_transform(transform=W_H_BW)
417
- BW_v_WB = data.base_velocity()
417
+ BW_v_WB = data.base_velocity
418
418
  BW_v_W_BW = BW_v_WB.at[3:6].set(jnp.zeros(3))
419
419
  BW_vx_W_BW = Cross.vx(BW_v_W_BW)
420
420
  W_Ẋ_BW = W_X_BW @ BW_vx_W_BW
@@ -438,7 +438,7 @@ def jacobian_derivative(
438
438
  W_H_F = transform(model=model, data=data, frame_index=frame_index)
439
439
  O_X_W = F_X_W = Adjoint.from_transform(transform=W_H_F, inverse=True)
440
440
  with data.switch_velocity_representation(VelRepr.Inertial):
441
- W_nu = data.generalized_velocity()
441
+ W_nu = data.generalized_velocity
442
442
  W_v_WF = W_J_WL_W @ W_nu
443
443
  W_vx_WF = Cross.vx(W_v_WF)
444
444
  O_Ẋ_W = F_Ẋ_W = -F_X_W @ W_vx_WF # noqa: F841
@@ -455,7 +455,7 @@ def jacobian_derivative(
455
455
  frame_index=frame_index,
456
456
  output_vel_repr=VelRepr.Mixed,
457
457
  )
458
- FW_v_WF = FW_J_WF_FW @ data.generalized_velocity()
458
+ FW_v_WF = FW_J_WF_FW @ data.generalized_velocity
459
459
  W_v_W_FW = jnp.zeros(6).at[0:3].set(FW_v_WF[0:3])
460
460
  W_vx_W_FW = Cross.vx(W_v_W_FW)
461
461
  O_Ẋ_W = FW_Ẋ_W = -FW_X_W @ W_vx_W_FW # noqa: F841
@@ -0,0 +1,76 @@
1
+ import dataclasses
2
+
3
+ import jax.numpy as jnp
4
+
5
+ import jaxsim
6
+ import jaxsim.api as js
7
+ import jaxsim.typing as jtp
8
+ from jaxsim.api.data import JaxSimModelData
9
+ from jaxsim.math import Adjoint, Transform
10
+
11
+
12
+ def semi_implicit_euler_integration(
13
+ model: js.model.JaxSimModel,
14
+ data: js.data.JaxSimModelData,
15
+ base_acceleration_inertial: jtp.Vector,
16
+ joint_accelerations: jtp.Vector,
17
+ ) -> JaxSimModelData:
18
+ """Integrate the system state using the semi-implicit Euler method."""
19
+ # Step the dynamics forward.
20
+ with data.switch_velocity_representation(jaxsim.VelRepr.Inertial):
21
+
22
+ dt = model.time_step
23
+ W_v̇_WB = base_acceleration_inertial
24
+ s̈ = joint_accelerations
25
+
26
+ B_H_W = Transform.inverse(data._base_transform).at[:3, :3].set(jnp.eye(3))
27
+ BW_X_W = Adjoint.from_transform(B_H_W)
28
+
29
+ new_generalized_acceleration = jnp.hstack([W_v̇_WB, s̈])
30
+
31
+ new_generalized_velocity = (
32
+ data.generalized_velocity + dt * new_generalized_acceleration
33
+ )
34
+
35
+ new_base_velocity_inertial = new_generalized_velocity[0:6]
36
+ new_joint_velocities = new_generalized_velocity[6:]
37
+
38
+ base_lin_velocity_inertial = new_base_velocity_inertial[0:3]
39
+
40
+ new_base_velocity_mixed = BW_X_W @ new_generalized_velocity[0:6]
41
+ base_lin_velocity_mixed = new_base_velocity_mixed[0:3]
42
+ base_ang_velocity_mixed = new_base_velocity_mixed[3:6]
43
+
44
+ base_quaternion_derivative = jaxsim.math.Quaternion.derivative(
45
+ quaternion=data.base_orientation,
46
+ omega=base_ang_velocity_mixed,
47
+ omega_in_body_fixed=False,
48
+ ).squeeze()
49
+
50
+ new_base_position = data.base_position + dt * base_lin_velocity_mixed
51
+ new_base_quaternion = data.base_orientation + dt * base_quaternion_derivative
52
+
53
+ base_quaternion_norm = jaxsim.math.safe_norm(new_base_quaternion)
54
+
55
+ new_base_quaternion = new_base_quaternion / jnp.where(
56
+ base_quaternion_norm == 0, 1.0, base_quaternion_norm
57
+ )
58
+
59
+ new_joint_position = data.joint_positions + dt * new_joint_velocities
60
+
61
+ # TODO: Avoid double replace, e.g. by computing cached value here
62
+ data = dataclasses.replace(
63
+ data,
64
+ _base_quaternion=new_base_quaternion,
65
+ _base_position=new_base_position,
66
+ _joint_positions=new_joint_position,
67
+ _joint_velocities=new_joint_velocities,
68
+ _base_linear_velocity=base_lin_velocity_inertial,
69
+ # Here we use the base angular velocity in mixed representation since
70
+ # it's equivalent to the one in inertial representation
71
+ # See: S. Traversaro and A. Saccon, “Multibody Dynamics Notation (Version 2), pg.9
72
+ _base_angular_velocity=base_ang_velocity_mixed,
73
+ )
74
+ data = data.replace(model=model) # update cache
75
+
76
+ return data
@@ -11,7 +11,7 @@ from jax_dataclasses import Static
11
11
 
12
12
  import jaxsim.typing as jtp
13
13
  from jaxsim.math import Adjoint, Inertia, JointModel, supported_joint_motion
14
- from jaxsim.parsers.descriptions import JointDescription, ModelDescription
14
+ from jaxsim.parsers.descriptions import JointDescription, JointType, ModelDescription
15
15
  from jaxsim.utils import HashedNumpyArray, JaxsimDataclass
16
16
 
17
17
 
@@ -36,6 +36,7 @@ class KinDynParameters(JaxsimDataclass):
36
36
  link_names: Static[tuple[str]]
37
37
  _parent_array: Static[HashedNumpyArray]
38
38
  _support_body_array_bool: Static[HashedNumpyArray]
39
+ _motion_subspaces: Static[HashedNumpyArray]
39
40
 
40
41
  # Links
41
42
  link_parameters: LinkParameters
@@ -50,6 +51,13 @@ class KinDynParameters(JaxsimDataclass):
50
51
  joint_model: JointModel
51
52
  joint_parameters: JointParameters | None
52
53
 
54
+ @property
55
+ def motion_subspaces(self) -> jtp.Matrix:
56
+ r"""
57
+ Return the motion subspaces :math:`\mathbf{S}(s)` of the joints.
58
+ """
59
+ return self._motion_subspaces.get()
60
+
53
61
  @property
54
62
  def parent_array(self) -> jtp.Vector:
55
63
  r"""
@@ -215,6 +223,31 @@ class KinDynParameters(JaxsimDataclass):
215
223
  jnp.arange(start=0, stop=len(ordered_links))
216
224
  )
217
225
 
226
+ def motion_subspace(joint_type: int, axis: npt.ArrayLike) -> npt.ArrayLike:
227
+
228
+ S = {
229
+ JointType.Fixed: np.zeros(shape=(6, 1)),
230
+ JointType.Revolute: np.vstack(np.hstack([np.zeros(3), axis.axis])),
231
+ JointType.Prismatic: np.vstack(np.hstack([axis.axis, np.zeros(3)])),
232
+ }
233
+
234
+ return S[joint_type]
235
+
236
+ S_J = (
237
+ jnp.array(
238
+ [
239
+ motion_subspace(joint_type, axis)
240
+ for joint_type, axis in zip(
241
+ joint_model.joint_types[1:], joint_model.joint_axis, strict=True
242
+ )
243
+ ]
244
+ )
245
+ if len(joint_model.joint_axis) != 0
246
+ else jnp.empty((0, 6, 1))
247
+ )
248
+
249
+ motion_subspaces = jnp.vstack([jnp.zeros((6, 1))[jnp.newaxis, ...], S_J])
250
+
218
251
  # =================================
219
252
  # Build and return KinDynParameters
220
253
  # =================================
@@ -223,6 +256,7 @@ class KinDynParameters(JaxsimDataclass):
223
256
  link_names=tuple(l.name for l in ordered_links),
224
257
  _parent_array=HashedNumpyArray(array=parent_array),
225
258
  _support_body_array_bool=HashedNumpyArray(array=support_body_array_bool),
259
+ _motion_subspaces=HashedNumpyArray(array=motion_subspaces),
226
260
  link_parameters=link_parameters,
227
261
  joint_model=joint_model,
228
262
  joint_parameters=joint_parameters,
@@ -359,54 +393,6 @@ class KinDynParameters(JaxsimDataclass):
359
393
  of each joint.
360
394
  """
361
395
 
362
- return self.joint_transforms_and_motion_subspaces(
363
- joint_positions=joint_positions,
364
- base_transform=base_transform,
365
- )[0]
366
-
367
- @jax.jit
368
- def joint_motion_subspaces(
369
- self, joint_positions: jtp.VectorLike, base_transform: jtp.MatrixLike
370
- ) -> jtp.Array:
371
- r"""
372
- Return the motion subspaces of the joints.
373
-
374
- Args:
375
- joint_positions: The joint positions.
376
- base_transform: The homogeneous matrix defining the base pose.
377
-
378
- Returns:
379
- The stacked motion subspaces :math:`\mathbf{S}(s)` of each joint.
380
- """
381
-
382
- return self.joint_transforms_and_motion_subspaces(
383
- joint_positions=joint_positions,
384
- base_transform=base_transform,
385
- )[1]
386
-
387
- @jax.jit
388
- def joint_transforms_and_motion_subspaces(
389
- self, joint_positions: jtp.VectorLike, base_transform: jtp.MatrixLike
390
- ) -> tuple[jtp.Array, jtp.Array]:
391
- r"""
392
- Return the transforms and the motion subspaces of the joints.
393
-
394
- Args:
395
- joint_positions: The joint positions.
396
- base_transform: The homogeneous matrix defining the base pose.
397
-
398
- Returns:
399
- A tuple containing the stacked transforms
400
- :math:`{}^{i} \mathbf{H}_{\lambda(i)}(s)`
401
- and the stacked motion subspaces :math:`\mathbf{S}(s)` of each joint.
402
-
403
- Note:
404
- The first transform, at index 0, provides the pose of the base link
405
- w.r.t. the world frame. For both floating-base and fixed-base systems,
406
- it takes into account the base pose and the optional transform
407
- between the root frame of the model and the base link.
408
- """
409
-
410
396
  # Rename the base transform.
411
397
  W_H_B = base_transform
412
398
 
@@ -417,22 +403,19 @@ class KinDynParameters(JaxsimDataclass):
417
403
  self.joint_model.λ_H_pre[1 : 1 + self.number_of_joints()],
418
404
  ]
419
405
  )
420
-
421
- # Compute the transforms and motion subspaces of the joints.
422
406
  if self.number_of_joints() == 0:
423
- pre_H_suc_J, S_J = jnp.empty((0, 4, 4)), jnp.empty((0, 6, 1))
407
+ pre_H_suc_J = jnp.empty((0, 4, 4))
424
408
  else:
425
- pre_H_suc_J, S_J = jax.vmap(supported_joint_motion)(
426
- jnp.array(self.joint_model.joint_types[1:]).astype(int),
427
- jnp.array(joint_positions),
428
- jnp.array([j.axis for j in self.joint_model.joint_axis]),
409
+ pre_H_suc_J = jax.vmap(supported_joint_motion)(
410
+ joint_types=jnp.array(self.joint_model.joint_types[1:]).astype(int),
411
+ joint_positions=jnp.array(joint_positions),
412
+ joint_axes=jnp.array([j.axis for j in self.joint_model.joint_axis]),
429
413
  )
430
414
 
431
415
  # Extract the transforms and motion subspaces of the joints.
432
416
  # We stack the base transform W_H_B at index 0, and a dummy motion subspace
433
417
  # for either the fixed or free-floating joint connecting the world to the base.
434
418
  pre_H_suc = jnp.vstack([W_H_B[jnp.newaxis, ...], pre_H_suc_J])
435
- S = jnp.vstack([jnp.zeros((6, 1))[jnp.newaxis, ...], S_J])
436
419
 
437
420
  # Extract the successor-to-child fixed transforms.
438
421
  # Note that here we include also the index 0 since suc_H_child[0] stores the
@@ -448,7 +431,7 @@ class KinDynParameters(JaxsimDataclass):
448
431
  )
449
432
  )(λ_H_pre, pre_H_suc, suc_H_i)
450
433
 
451
- return i_X_λ, S
434
+ return i_X_λ
452
435
 
453
436
  # ============================
454
437
  # Helpers to update parameters
jaxsim/api/link.py CHANGED
@@ -187,7 +187,7 @@ def transform(
187
187
  idx=link_index,
188
188
  )
189
189
 
190
- return js.model.forward_kinematics(model=model, data=data)[link_index]
190
+ return data._link_transforms[link_index]
191
191
 
192
192
 
193
193
  @jax.jit
@@ -275,7 +275,7 @@ def jacobian(
275
275
  # Compute the doubly-left free-floating full jacobian.
276
276
  B_J_full_WX_B, B_H_Li = jaxsim.rbda.jacobian_full_doubly_left(
277
277
  model=model,
278
- joint_positions=data.joint_positions(),
278
+ joint_positions=data.joint_positions,
279
279
  )
280
280
 
281
281
  # Compute the actual doubly-left free-floating jacobian of the link.
@@ -285,7 +285,7 @@ def jacobian(
285
285
  # Adjust the input representation such that `J_WL_I @ I_ν`.
286
286
  match data.velocity_representation:
287
287
  case VelRepr.Inertial:
288
- W_H_B = data.base_transform()
288
+ W_H_B = data._base_transform
289
289
  B_X_W = Adjoint.from_transform(transform=W_H_B, inverse=True)
290
290
  B_J_WL_I = B_J_WL_W = B_J_WL_B @ jax.scipy.linalg.block_diag( # noqa: F841
291
291
  B_X_W, jnp.eye(model.dofs())
@@ -295,7 +295,7 @@ def jacobian(
295
295
  B_J_WL_I = B_J_WL_B
296
296
 
297
297
  case VelRepr.Mixed:
298
- W_R_B = data.base_orientation(dcm=True)
298
+ W_R_B = jaxsim.math.Quaternion.to_dcm(data.base_orientation)
299
299
  BW_H_B = jnp.eye(4).at[0:3, 0:3].set(W_R_B)
300
300
  B_X_BW = Adjoint.from_transform(transform=BW_H_B, inverse=True)
301
301
  B_J_WL_I = B_J_WL_BW = B_J_WL_B @ jax.scipy.linalg.block_diag( # noqa: F841
@@ -310,7 +310,7 @@ def jacobian(
310
310
  # Adjust the output representation such that `O_v_WL_I = O_J_WL_I @ I_ν`.
311
311
  match output_vel_repr:
312
312
  case VelRepr.Inertial:
313
- W_H_B = data.base_transform()
313
+ W_H_B = data._base_transform
314
314
  W_X_B = Adjoint.from_transform(transform=W_H_B)
315
315
  O_J_WL_I = W_J_WL_I = W_X_B @ B_J_WL_I # noqa: F841
316
316
 
@@ -320,7 +320,7 @@ def jacobian(
320
320
  O_J_WL_I = L_J_WL_I
321
321
 
322
322
  case VelRepr.Mixed:
323
- W_H_B = data.base_transform()
323
+ W_H_B = data._base_transform
324
324
  W_H_L = W_H_B @ B_H_L
325
325
  LW_H_L = W_H_L.at[0:3, 3].set(jnp.zeros(3))
326
326
  LW_H_B = LW_H_L @ jaxsim.math.Transform.inverse(B_H_L)
@@ -378,7 +378,7 @@ def velocity(
378
378
  )
379
379
 
380
380
  # Get the generalized velocity in the input velocity representation.
381
- I_ν = data.generalized_velocity()
381
+ I_ν = data.generalized_velocity
382
382
 
383
383
  # Compute the link velocity in the output velocity representation.
384
384
  return O_J_WL_I @ I_ν