jaxsim 0.3.1.dev64__py3-none-any.whl → 0.3.1.dev94__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (47):
  1. jaxsim/__init__.py +5 -5
  2. jaxsim/_version.py +2 -2
  3. jaxsim/api/com.py +3 -4
  4. jaxsim/api/common.py +11 -11
  5. jaxsim/api/contact.py +11 -3
  6. jaxsim/api/data.py +3 -6
  7. jaxsim/api/frame.py +9 -10
  8. jaxsim/api/kin_dyn_parameters.py +25 -28
  9. jaxsim/api/link.py +12 -12
  10. jaxsim/api/model.py +47 -43
  11. jaxsim/api/ode.py +19 -12
  12. jaxsim/api/ode_data.py +11 -11
  13. jaxsim/integrators/common.py +17 -20
  14. jaxsim/integrators/fixed_step.py +10 -10
  15. jaxsim/integrators/variable_step.py +13 -13
  16. jaxsim/math/__init__.py +2 -1
  17. jaxsim/math/joint_model.py +2 -1
  18. jaxsim/math/quaternion.py +3 -9
  19. jaxsim/math/transform.py +2 -2
  20. jaxsim/mujoco/loaders.py +5 -5
  21. jaxsim/mujoco/model.py +6 -6
  22. jaxsim/mujoco/visualizer.py +3 -0
  23. jaxsim/parsers/__init__.py +0 -1
  24. jaxsim/parsers/descriptions/joint.py +1 -1
  25. jaxsim/parsers/descriptions/link.py +3 -4
  26. jaxsim/parsers/descriptions/model.py +1 -1
  27. jaxsim/parsers/kinematic_graph.py +38 -39
  28. jaxsim/parsers/rod/parser.py +14 -14
  29. jaxsim/parsers/rod/utils.py +9 -11
  30. jaxsim/rbda/aba.py +6 -12
  31. jaxsim/rbda/collidable_points.py +8 -7
  32. jaxsim/rbda/contacts/soft.py +29 -27
  33. jaxsim/rbda/crba.py +3 -3
  34. jaxsim/rbda/forward_kinematics.py +1 -1
  35. jaxsim/rbda/jacobian.py +8 -8
  36. jaxsim/rbda/rnea.py +3 -3
  37. jaxsim/rbda/utils.py +1 -1
  38. jaxsim/terrain/terrain.py +100 -22
  39. jaxsim/typing.py +14 -22
  40. jaxsim/utils/jaxsim_dataclass.py +4 -4
  41. jaxsim/utils/wrappers.py +5 -1
  42. {jaxsim-0.3.1.dev64.dist-info → jaxsim-0.3.1.dev94.dist-info}/METADATA +1 -1
  43. jaxsim-0.3.1.dev94.dist-info/RECORD +68 -0
  44. jaxsim-0.3.1.dev64.dist-info/RECORD +0 -68
  45. {jaxsim-0.3.1.dev64.dist-info → jaxsim-0.3.1.dev94.dist-info}/LICENSE +0 -0
  46. {jaxsim-0.3.1.dev64.dist-info → jaxsim-0.3.1.dev94.dist-info}/WHEEL +0 -0
  47. {jaxsim-0.3.1.dev64.dist-info → jaxsim-0.3.1.dev94.dist-info}/top_level.txt +0 -0
jaxsim/api/model.py CHANGED
@@ -9,14 +9,14 @@ from typing import Any, Sequence
9
9
  import jax
10
10
  import jax.numpy as jnp
11
11
  import jax_dataclasses
12
- import jaxlie
13
12
  import rod
14
13
  from jax_dataclasses import Static
15
14
 
16
15
  import jaxsim.api as js
17
- import jaxsim.parsers.descriptions
16
+ import jaxsim.terrain
18
17
  import jaxsim.typing as jtp
19
- from jaxsim.math import Cross
18
+ from jaxsim.math import Adjoint, Cross
19
+ from jaxsim.parsers.descriptions import ModelDescription
20
20
  from jaxsim.utils import JaxsimDataclass, Mutability, wrappers
21
21
 
22
22
  from .common import VelRepr
@@ -46,12 +46,12 @@ class JaxSimModel(JaxsimDataclass):
46
46
  default=None, repr=False
47
47
  )
48
48
 
49
- _description: Static[
50
- wrappers.HashlessObject[jaxsim.parsers.descriptions.ModelDescription | None]
51
- ] = dataclasses.field(default=None, repr=False)
49
+ _description: Static[wrappers.HashlessObject[ModelDescription | None]] = (
50
+ dataclasses.field(default=None, repr=False)
51
+ )
52
52
 
53
53
  @property
54
- def description(self) -> jaxsim.parsers.descriptions.ModelDescription:
54
+ def description(self) -> ModelDescription:
55
55
  return self._description.get()
56
56
 
57
57
  def __eq__(self, other: JaxSimModel) -> bool:
@@ -116,7 +116,7 @@ class JaxSimModel(JaxsimDataclass):
116
116
  import jaxsim.parsers.rod
117
117
 
118
118
  # Parse the input resource (either a path to file or a string with the URDF/SDF)
119
- # and build the -intermediate- model description
119
+ # and build the -intermediate- model description.
120
120
  intermediate_description = jaxsim.parsers.rod.build_model_description(
121
121
  model_description=model_description, is_urdf=is_urdf
122
122
  )
@@ -128,7 +128,7 @@ class JaxSimModel(JaxsimDataclass):
128
128
  considered_joints=considered_joints
129
129
  )
130
130
 
131
- # Build the model
131
+ # Build the model.
132
132
  model = JaxSimModel.build(
133
133
  model_description=intermediate_description,
134
134
  model_name=model_name,
@@ -136,7 +136,7 @@ class JaxSimModel(JaxsimDataclass):
136
136
  contact_model=contact_model,
137
137
  )
138
138
 
139
- # Store the origin of the model, in case downstream logic needs it
139
+ # Store the origin of the model, in case downstream logic needs it.
140
140
  with model.mutable_context(mutability=Mutability.MUTABLE_NO_VALIDATION):
141
141
  model.built_from = model_description
142
142
 
@@ -144,7 +144,7 @@ class JaxSimModel(JaxsimDataclass):
144
144
 
145
145
  @staticmethod
146
146
  def build(
147
- model_description: jaxsim.parsers.descriptions.ModelDescription,
147
+ model_description: ModelDescription,
148
148
  model_name: str | None = None,
149
149
  *,
150
150
  terrain: jaxsim.terrain.Terrain | None = None,
@@ -169,14 +169,14 @@ class JaxSimModel(JaxsimDataclass):
169
169
  """
170
170
  from jaxsim.rbda.contacts.soft import SoftContacts
171
171
 
172
- # Set the model name (if not provided, use the one from the model description)
172
+ # Set the model name (if not provided, use the one from the model description).
173
173
  model_name = model_name if model_name is not None else model_description.name
174
174
 
175
- # Set the terrain (if not provided, use the default flat terrain)
175
+ # Set the terrain (if not provided, use the default flat terrain).
176
176
  terrain = terrain or JaxSimModel.__dataclass_fields__["terrain"].default
177
177
  contact_model = contact_model or SoftContacts(terrain=terrain)
178
178
 
179
- # Build the model
179
+ # Build the model.
180
180
  model = JaxSimModel(
181
181
  model_name=model_name,
182
182
  _description=wrappers.HashlessObject(obj=model_description),
@@ -361,7 +361,7 @@ def reduce(
361
361
  considered_joints=list(considered_joints)
362
362
  )
363
363
 
364
- # Build the reduced model
364
+ # Build the reduced model.
365
365
  reduced_model = JaxSimModel.build(
366
366
  model_description=reduced_intermediate_description,
367
367
  model_name=model.name(),
@@ -369,7 +369,7 @@ def reduce(
369
369
  contact_model=model.contact_model,
370
370
  )
371
371
 
372
- # Store the origin of the model, in case downstream logic needs it
372
+ # Store the origin of the model, in case downstream logic needs it.
373
373
  with reduced_model.mutable_context(mutability=Mutability.MUTABLE_NO_VALIDATION):
374
374
  reduced_model.built_from = model.built_from
375
375
 
@@ -493,7 +493,7 @@ def generalized_free_floating_jacobian(
493
493
  case VelRepr.Inertial:
494
494
 
495
495
  W_H_B = data.base_transform()
496
- B_X_W = jaxlie.SE3.from_matrix(W_H_B).inverse().adjoint()
496
+ B_X_W = Adjoint.from_transform(transform=W_H_B, inverse=True)
497
497
 
498
498
  B_J_full_WX_I = B_J_full_WX_W = B_J_full_WX_B @ jax.scipy.linalg.block_diag(
499
499
  B_X_W, jnp.eye(model.dofs())
@@ -507,7 +507,7 @@ def generalized_free_floating_jacobian(
507
507
 
508
508
  W_R_B = data.base_orientation(dcm=True)
509
509
  BW_H_B = jnp.eye(4).at[0:3, 0:3].set(W_R_B)
510
- B_X_BW = jaxlie.SE3.from_matrix(BW_H_B).inverse().adjoint()
510
+ B_X_BW = Adjoint.from_transform(transform=BW_H_B, inverse=True)
511
511
 
512
512
  B_J_full_WX_I = B_J_full_WX_BW = (
513
513
  B_J_full_WX_B
@@ -715,7 +715,7 @@ def forward_dynamics_aba(
715
715
 
716
716
  # In Mixed representation, we need to include a cross product in ℝ⁶.
717
717
  # In Inertial and Body representations, the cross product is always zero.
718
- C_X_W = jaxlie.SE3.from_matrix(W_H_C).inverse().adjoint()
718
+ C_X_W = Adjoint.from_transform(transform=W_H_C, inverse=True)
719
719
  return C_X_W @ (W_v̇_WB - Cross.vx(W_v_WC) @ W_v_WB)
720
720
 
721
721
  match data.velocity_representation:
@@ -797,21 +797,21 @@ def forward_dynamics_crb(
797
797
  # Prepare data
798
798
  # ============
799
799
 
800
- # Build joint torques if not provided
800
+ # Build joint torques if not provided.
801
801
  τ = (
802
802
  jnp.atleast_1d(joint_forces)
803
803
  if joint_forces is not None
804
804
  else jnp.zeros_like(data.joint_positions())
805
805
  )
806
806
 
807
- # Build external forces if not provided
807
+ # Build external forces if not provided.
808
808
  f = (
809
809
  jnp.atleast_2d(link_forces)
810
810
  if link_forces is not None
811
811
  else jnp.zeros(shape=(model.number_of_links(), 6))
812
812
  )
813
813
 
814
- # Compute terms of the floating-base EoM
814
+ # Compute terms of the floating-base EoM.
815
815
  M = free_floating_mass_matrix(model=model, data=data)
816
816
  h = free_floating_bias_forces(model=model, data=data)
817
817
  S = jnp.block([jnp.zeros(shape=(model.dofs(), 6)), jnp.eye(model.dofs())]).T
@@ -848,7 +848,7 @@ def forward_dynamics_crb(
848
848
  # 6D transformation X.
849
849
  v̇_WB = ν̇[0:6].squeeze().astype(float)
850
850
 
851
- # Extract the joint accelerations
851
+ # Extract the joint accelerations.
852
852
  s̈ = jnp.atleast_1d(ν̇[6:].squeeze()).astype(float)
853
853
 
854
854
  return v̇_WB, s̈
@@ -880,7 +880,9 @@ def free_floating_mass_matrix(
880
880
 
881
881
  case VelRepr.Inertial:
882
882
 
883
- B_X_W = jaxlie.SE3.from_matrix(data.base_transform()).inverse().adjoint()
883
+ B_X_W = Adjoint.from_transform(
884
+ transform=data.base_transform(), inverse=True
885
+ )
884
886
  invT = jax.scipy.linalg.block_diag(B_X_W, jnp.eye(model.dofs()))
885
887
 
886
888
  return invT.T @ M_body @ invT
@@ -888,7 +890,7 @@ def free_floating_mass_matrix(
888
890
  case VelRepr.Mixed:
889
891
 
890
892
  BW_H_B = data.base_transform().at[0:3, 3].set(jnp.zeros(3))
891
- B_X_BW = jaxlie.SE3.from_matrix(BW_H_B).inverse().adjoint()
893
+ B_X_BW = Adjoint.from_transform(transform=BW_H_B, inverse=True)
892
894
  invT = jax.scipy.linalg.block_diag(B_X_BW, jnp.eye(model.dofs()))
893
895
 
894
896
  return invT.T @ M_body @ invT
@@ -1077,8 +1079,8 @@ def inverse_dynamics(
1077
1079
  expressed in a generic frame C to the inertial-fixed representation W_v̇_WB.
1078
1080
  """
1079
1081
 
1080
- W_X_C = jaxlie.SE3.from_matrix(W_H_C).adjoint()
1081
- C_X_W = jaxlie.SE3.from_matrix(W_H_C).inverse().adjoint()
1082
+ W_X_C = Adjoint.from_transform(transform=W_H_C)
1083
+ C_X_W = Adjoint.from_transform(transform=W_H_C, inverse=True)
1082
1084
  C_v_WC = C_X_W @ W_v_WC
1083
1085
 
1084
1086
  # In Mixed representation, we need to include a cross product in ℝ⁶.
@@ -1187,12 +1189,12 @@ def free_floating_gravity_forces(
1187
1189
  The free-floating gravity forces of the model.
1188
1190
  """
1189
1191
 
1190
- # Build a zeroed state
1192
+ # Build a zeroed state.
1191
1193
  data_rnea = js.data.JaxSimModelData.zero(
1192
1194
  model=model, velocity_representation=data.velocity_representation
1193
1195
  )
1194
1196
 
1195
- # Set just the generalized position
1197
+ # Set just the generalized position.
1196
1198
  with data_rnea.mutable_context(
1197
1199
  mutability=Mutability.MUTABLE, restore_after_exception=False
1198
1200
  ):
@@ -1237,12 +1239,12 @@ def free_floating_bias_forces(
1237
1239
  The free-floating bias forces of the model.
1238
1240
  """
1239
1241
 
1240
- # Build a zeroed state
1242
+ # Build a zeroed state.
1241
1243
  data_rnea = js.data.JaxSimModelData.zero(
1242
1244
  model=model, velocity_representation=data.velocity_representation
1243
1245
  )
1244
1246
 
1245
- # Set the generalized position and generalized velocity
1247
+ # Set the generalized position and generalized velocity.
1246
1248
  with data_rnea.mutable_context(
1247
1249
  mutability=Mutability.MUTABLE, restore_after_exception=False
1248
1250
  ):
@@ -1361,12 +1363,14 @@ def total_momentum_jacobian(
1361
1363
  B_Jh = B_Jh_B
1362
1364
 
1363
1365
  case VelRepr.Inertial:
1364
- B_X_W = jaxlie.SE3.from_matrix(data.base_transform()).inverse().adjoint()
1366
+ B_X_W = Adjoint.from_transform(
1367
+ transform=data.base_transform(), inverse=True
1368
+ )
1365
1369
  B_Jh = B_Jh_B @ jax.scipy.linalg.block_diag(B_X_W, jnp.eye(model.dofs()))
1366
1370
 
1367
1371
  case VelRepr.Mixed:
1368
1372
  BW_H_B = data.base_transform().at[0:3, 3].set(jnp.zeros(3))
1369
- B_X_BW = jaxlie.SE3.from_matrix(BW_H_B).inverse().adjoint()
1373
+ B_X_BW = Adjoint.from_transform(transform=BW_H_B, inverse=True)
1370
1374
  B_Jh = B_Jh_B @ jax.scipy.linalg.block_diag(B_X_BW, jnp.eye(model.dofs()))
1371
1375
 
1372
1376
  case _:
@@ -1378,14 +1382,14 @@ def total_momentum_jacobian(
1378
1382
 
1379
1383
  case VelRepr.Inertial:
1380
1384
  W_H_B = data.base_transform()
1381
- B_Xv_W = jaxlie.SE3.from_matrix(W_H_B).inverse().adjoint()
1385
+ B_Xv_W = Adjoint.from_transform(transform=W_H_B, inverse=True)
1382
1386
  W_Xf_B = B_Xv_W.T
1383
1387
  W_Jh = W_Xf_B @ B_Jh
1384
1388
  return W_Jh
1385
1389
 
1386
1390
  case VelRepr.Mixed:
1387
1391
  BW_H_B = data.base_transform().at[0:3, 3].set(jnp.zeros(3))
1388
- B_Xv_BW = jaxlie.SE3.from_matrix(BW_H_B).inverse().adjoint()
1392
+ B_Xv_BW = Adjoint.from_transform(transform=BW_H_B, inverse=True)
1389
1393
  BW_Xf_B = B_Xv_BW.T
1390
1394
  BW_Jh = BW_Xf_B @ B_Jh
1391
1395
  return BW_Jh
@@ -1449,7 +1453,7 @@ def average_velocity_jacobian(
1449
1453
  W_p_CoM = js.com.com_position(model=model, data=data)
1450
1454
 
1451
1455
  W_H_GW = jnp.eye(4).at[0:3, 3].set(W_p_CoM)
1452
- W_X_GW = jaxlie.SE3.from_matrix(W_H_GW).adjoint()
1456
+ W_X_GW = Adjoint.from_transform(transform=W_H_GW)
1453
1457
 
1454
1458
  return W_X_GW @ GW_J
1455
1459
 
@@ -1461,7 +1465,7 @@ def average_velocity_jacobian(
1461
1465
  B_R_W = data.base_orientation(dcm=True).transpose()
1462
1466
 
1463
1467
  B_H_GB = jnp.eye(4).at[0:3, 3].set(B_R_W @ (W_p_CoM - W_p_B))
1464
- B_X_GB = jaxlie.SE3.from_matrix(B_H_GB).adjoint()
1468
+ B_X_GB = Adjoint.from_transform(transform=B_H_GB)
1465
1469
 
1466
1470
  return B_X_GB @ GB_J
1467
1471
 
@@ -1472,7 +1476,7 @@ def average_velocity_jacobian(
1472
1476
  W_p_CoM = js.com.com_position(model=model, data=data)
1473
1477
 
1474
1478
  BW_H_GW = jnp.eye(4).at[0:3, 3].set(W_p_CoM - W_p_B)
1475
- BW_X_GW = jaxlie.SE3.from_matrix(BW_H_GW).adjoint()
1479
+ BW_X_GW = Adjoint.from_transform(transform=BW_H_GW)
1476
1480
 
1477
1481
  return BW_X_GW @ GW_J
1478
1482
 
@@ -1518,8 +1522,8 @@ def link_bias_accelerations(
1518
1522
  expressed in a generic frame C to the inertial-fixed representation W_v̇_WB.
1519
1523
  """
1520
1524
 
1521
- W_X_C = jaxlie.SE3.from_matrix(W_H_C).adjoint()
1522
- C_X_W = jaxlie.SE3.from_matrix(W_H_C).inverse().adjoint()
1525
+ W_X_C = Adjoint.from_transform(transform=W_H_C)
1526
+ C_X_W = Adjoint.from_transform(transform=W_H_C, inverse=True)
1523
1527
 
1524
1528
  # In Mixed representation, we need to include a cross product in ℝ⁶.
1525
1529
  # In Inertial and Body representations, the cross product is always zero.
@@ -1606,7 +1610,7 @@ def link_bias_accelerations(
1606
1610
  # not remove gravity during the propagation.
1607
1611
 
1608
1612
  # Initialize the loop.
1609
- Carry = tuple[jtp.MatrixJax, jtp.MatrixJax]
1613
+ Carry = tuple[jtp.Matrix, jtp.Matrix]
1610
1614
  carry0: Carry = (L_v_WL, L_v̇_WL)
1611
1615
 
1612
1616
  def propagate_accelerations(carry: Carry, i: jtp.Int) -> tuple[Carry, None]:
@@ -1832,8 +1836,8 @@ def step(
1832
1836
  integrator_state: The state of the integrator.
1833
1837
  joint_forces: The joint forces to consider.
1834
1838
  link_forces:
1835
- The link 6D forces to consider.
1836
- The frame in which they are expressed must be `data.velocity_representation`.
1839
+ The 6D forces to apply to the links expressed in the frame corresponding to
1840
+ the velocity representation of `data`.
1837
1841
  kwargs: Additional kwargs to pass to the integrator.
1838
1842
 
1839
1843
  Returns:
jaxsim/api/ode.py CHANGED
@@ -96,7 +96,9 @@ def system_velocity_dynamics(
96
96
  model: The model to consider.
97
97
  data: The data of the considered model.
98
98
  joint_forces: The joint forces to apply.
99
- link_forces: The 6D forces to apply to the links.
99
+ link_forces:
100
+ The 6D forces to apply to the links expressed in the frame corresponding to
101
+ the velocity representation of `data`.
100
102
 
101
103
  Returns:
102
104
  A tuple containing the derivative of the base 6D velocity in inertial-fixed
@@ -105,14 +107,16 @@ def system_velocity_dynamics(
105
107
  the system dynamics evaluation.
106
108
  """
107
109
 
108
- # Build joint torques if not provided
110
+ # Build joint torques if not provided.
109
111
  τ = (
110
112
  jnp.atleast_1d(joint_forces.squeeze())
111
113
  if joint_forces is not None
112
114
  else jnp.zeros_like(data.joint_positions())
113
115
  ).astype(float)
114
116
 
115
- # Build link forces if not provided
117
+ # Build link forces if not provided.
118
+ # These forces are expressed in the frame corresponding to the velocity
119
+ # representation of data.
116
120
  O_f_L = (
117
121
  jnp.atleast_2d(link_forces.squeeze())
118
122
  if link_forces is not None
@@ -127,16 +131,17 @@ def system_velocity_dynamics(
127
131
  # with the terrain.
128
132
  W_f_Li_terrain = jnp.zeros_like(O_f_L).astype(float)
129
133
 
130
- # Initialize the 6D contact forces W_f ∈ ℝ^{n_c × 6} applied to collidable points,
131
- # expressed in the world frame.
132
- W_f_Ci = None
134
+ # Import privately the soft contacts classes.
135
+ from jaxsim.rbda.contacts.soft import SoftContactsState
133
136
 
134
137
  # Initialize the derivative of the tangential deformation ṁ ∈ ℝ^{n_c × 3}.
138
+ assert isinstance(data.state.contact, SoftContactsState)
135
139
  ṁ = jnp.zeros_like(data.state.contact.tangential_deformation).astype(float)
136
140
 
137
141
  if len(model.kin_dyn_parameters.contact_parameters.body) > 0:
138
- # Compute the 6D forces applied to each collidable point and the
139
- # corresponding material deformation rates.
142
+
143
+ # Compute the 6D forces W_f ∈ ℝ^{n_c × 6} applied to each collidable point
144
+ # and the corresponding material deformation rates.
140
145
  with data.switch_velocity_representation(VelRepr.Inertial):
141
146
  W_f_Ci, ṁ = js.contact.collidable_point_dynamics(model=model, data=data)
142
147
 
@@ -178,7 +183,7 @@ def system_velocity_dynamics(
178
183
  model.kin_dyn_parameters.joint_parameters.friction_viscous
179
184
  ).astype(float)
180
185
 
181
- # Compute the joint friction torque
186
+ # Compute the joint friction torque.
182
187
  τ_friction = -(
183
188
  jnp.diag(kc) @ jnp.sign(data.state.physics_model.joint_velocities)
184
189
  + jnp.diag(kv) @ data.state.physics_model.joint_velocities
@@ -188,7 +193,7 @@ def system_velocity_dynamics(
188
193
  # Compute forward dynamics
189
194
  # ========================
190
195
 
191
- # Compute the total joint forces
196
+ # Compute the total joint forces.
192
197
  τ_total = τ + τ_friction + τ_position_limit
193
198
 
194
199
  references = js.references.JaxSimModelReferences.build(
@@ -202,7 +207,7 @@ def system_velocity_dynamics(
202
207
  with references.switch_velocity_representation(VelRepr.Inertial):
203
208
  W_f_L = references.link_forces(model=model, data=data)
204
209
 
205
- # Compute the total external 6D forces applied to the links
210
+ # Compute the total external 6D forces applied to the links.
206
211
  W_f_L_total = W_f_L + W_f_Li_terrain
207
212
 
208
213
  # - Joint accelerations: s̈ ∈ ℝⁿ
@@ -273,7 +278,9 @@ def system_dynamics(
273
278
  model: The model to consider.
274
279
  data: The data of the considered model.
275
280
  joint_forces: The joint forces to apply.
276
- link_forces: The 6D forces to apply to the links.
281
+ link_forces:
282
+ The 6D forces to apply to the links expressed in the frame corresponding to
283
+ the velocity representation of `data`.
277
284
  baumgarte_quaternion_regularization:
278
285
  The Baumgarte regularization coefficient used to adjust the norm of the
279
286
  quaternion (only used in integrators not operating on the SO(3) manifold).
jaxsim/api/ode_data.py CHANGED
@@ -31,8 +31,8 @@ class ODEInput(JaxsimDataclass):
31
31
  @staticmethod
32
32
  def build_from_jaxsim_model(
33
33
  model: js.model.JaxSimModel | None = None,
34
- joint_forces: jtp.VectorJax | None = None,
35
- link_forces: jtp.MatrixJax | None = None,
34
+ joint_forces: jtp.VectorLike | None = None,
35
+ link_forces: jtp.MatrixLike | None = None,
36
36
  ) -> ODEInput:
37
37
  """
38
38
  Build an `ODEInput` from a `JaxSimModel`.
@@ -160,7 +160,7 @@ class ODEState(JaxsimDataclass):
160
160
  `JaxSimModel` and initialized to zero.
161
161
  """
162
162
 
163
- # Get the contact model from the `JaxSimModel`
163
+ # Get the contact model from the `JaxSimModel`.
164
164
  match model.contact_model:
165
165
  case SoftContacts():
166
166
  contact = SoftContactsState.build_from_jaxsim_model(
@@ -212,7 +212,7 @@ class ODEState(JaxsimDataclass):
212
212
  else PhysicsModelState.zero(model=model)
213
213
  )
214
214
 
215
- # Get the contact model from the `JaxSimModel`
215
+ # Get the contact model from the `JaxSimModel`.
216
216
  match contact:
217
217
  case SoftContactsState():
218
218
  pass
@@ -423,7 +423,7 @@ class PhysicsModelState(JaxsimDataclass):
423
423
  base_angular_velocity=jnp.array(base_angular_velocity, dtype=float),
424
424
  )
425
425
 
426
- # assert state.valid(physics_model)
426
+ # TODO (diegoferigo): assert state.valid(physics_model)
427
427
  return physics_model_state
428
428
 
429
429
  @staticmethod
@@ -501,14 +501,14 @@ class PhysicsModelInput(JaxsimDataclass):
501
501
  f_ext: The matrix of external forces applied to the links.
502
502
  """
503
503
 
504
- tau: jtp.VectorJax
505
- f_ext: jtp.MatrixJax
504
+ tau: jtp.Vector
505
+ f_ext: jtp.Matrix
506
506
 
507
507
  @staticmethod
508
508
  def build_from_jaxsim_model(
509
509
  model: js.model.JaxSimModel | None = None,
510
- joint_forces: jtp.VectorJax | None = None,
511
- link_forces: jtp.MatrixJax | None = None,
510
+ joint_forces: jtp.VectorLike | None = None,
511
+ link_forces: jtp.MatrixLike | None = None,
512
512
  ) -> PhysicsModelInput:
513
513
  """
514
514
  Build a `PhysicsModelInput` from a `JaxSimModel`.
@@ -535,8 +535,8 @@ class PhysicsModelInput(JaxsimDataclass):
535
535
 
536
536
  @staticmethod
537
537
  def build(
538
- joint_forces: jtp.VectorJax | None = None,
539
- link_forces: jtp.MatrixJax | None = None,
538
+ joint_forces: jtp.VectorLike | None = None,
539
+ link_forces: jtp.MatrixLike | None = None,
540
540
  number_of_dofs: jtp.Int | None = None,
541
541
  number_of_links: jtp.Int | None = None,
542
542
  ) -> PhysicsModelInput:
@@ -27,8 +27,8 @@ except ImportError:
27
27
  # Generic types
28
28
  # =============
29
29
 
30
- Time = jax.typing.ArrayLike
31
- TimeStep = jax.typing.ArrayLike
30
+ Time = jtp.FloatLike
31
+ TimeStep = jtp.FloatLike
32
32
  State = NextState = TypeVar("State")
33
33
  StateDerivative = TypeVar("StateDerivative")
34
34
  PyTreeType = TypeVar("PyTreeType", bound=jtp.PyTree)
@@ -79,7 +79,7 @@ class Integrator(JaxsimDataclass, abc.ABC, Generic[State, StateDerivative]):
79
79
  The integrator object.
80
80
  """
81
81
 
82
- return cls(dynamics=dynamics, **kwargs) # noqa
82
+ return cls(dynamics=dynamics, **kwargs)
83
83
 
84
84
  def step(
85
85
  self,
@@ -191,14 +191,14 @@ class Integrator(JaxsimDataclass, abc.ABC, Generic[State, StateDerivative]):
191
191
  class ExplicitRungeKutta(Integrator[PyTreeType, PyTreeType], Generic[PyTreeType]):
192
192
 
193
193
  # The Runge-Kutta matrix.
194
- A: ClassVar[jax.typing.ArrayLike]
194
+ A: ClassVar[jtp.Matrix]
195
195
 
196
196
  # The weights coefficients.
197
197
  # Note that in practice we typically use its transpose `b.transpose()`.
198
- b: ClassVar[jax.typing.ArrayLike]
198
+ b: ClassVar[jtp.Matrix]
199
199
 
200
200
  # The nodes coefficients.
201
- c: ClassVar[jax.typing.ArrayLike]
201
+ c: ClassVar[jtp.Vector]
202
202
 
203
203
  # Define the order of the solution.
204
204
  # It should have as many elements as the number of rows of `b.transpose()`.
@@ -384,7 +384,7 @@ class ExplicitRungeKutta(Integrator[PyTreeType, PyTreeType], Generic[PyTreeType]
384
384
  # Define the computation of the Runge-Kutta stage.
385
385
  def compute_ki() -> jax.Array:
386
386
 
387
- # Compute ∑ⱼ aᵢⱼ k
387
+ # Compute ∑ⱼ aᵢⱼ kⱼ.
388
388
  op_sum_ak = lambda k: jnp.einsum("s,s...->...", A[i], k)
389
389
  sum_ak = jax.tree_util.tree_map(op_sum_ak, K)
390
390
 
@@ -440,7 +440,7 @@ class ExplicitRungeKutta(Integrator[PyTreeType, PyTreeType], Generic[PyTreeType]
440
440
 
441
441
  @staticmethod
442
442
  def butcher_tableau_is_valid(
443
- A: jax.typing.ArrayLike, b: jax.typing.ArrayLike, c: jax.typing.ArrayLike
443
+ A: jtp.Matrix, b: jtp.Matrix, c: jtp.Vector
444
444
  ) -> jtp.Bool:
445
445
  """
446
446
  Check if the Butcher tableau is valid.
@@ -466,7 +466,7 @@ class ExplicitRungeKutta(Integrator[PyTreeType, PyTreeType], Generic[PyTreeType]
466
466
  return valid
467
467
 
468
468
  @staticmethod
469
- def butcher_tableau_is_explicit(A: jax.typing.ArrayLike) -> jtp.Bool:
469
+ def butcher_tableau_is_explicit(A: jtp.Matrix) -> jtp.Bool:
470
470
  """
471
471
  Check if the Butcher tableau corresponds to an explicit integration scheme.
472
472
 
@@ -481,9 +481,9 @@ class ExplicitRungeKutta(Integrator[PyTreeType, PyTreeType], Generic[PyTreeType]
481
481
 
482
482
  @staticmethod
483
483
  def butcher_tableau_supports_fsal(
484
- A: jax.typing.ArrayLike,
485
- b: jax.typing.ArrayLike,
486
- c: jax.typing.ArrayLike,
484
+ A: jtp.Matrix,
485
+ b: jtp.Matrix,
486
+ c: jtp.Vector,
487
487
  index_of_solution: jtp.IntLike = 0,
488
488
  ) -> [bool, int | None]:
489
489
  """
@@ -562,10 +562,9 @@ class ExplicitRungeKuttaSO3Mixin:
562
562
 
563
563
  # Indices to convert quaternions between serializations.
564
564
  to_xyzw = jnp.array([1, 2, 3, 0])
565
- to_wxyz = jnp.array([3, 0, 1, 2])
566
565
 
567
- # Get the initial quaternion.
568
- W_Q_B_t0 = jaxlie.SO3.from_quaternion_xyzw(
566
+ # Get the initial rotation.
567
+ W_R_B_t0 = jaxlie.SO3.from_quaternion_xyzw(
569
568
  xyzw=x0.physics_model.base_quaternion[to_xyzw]
570
569
  )
571
570
 
@@ -575,15 +574,13 @@ class ExplicitRungeKuttaSO3Mixin:
575
574
  # on the SO(3) manifold.
576
575
  W_ω_WB_tf = xf.physics_model.base_angular_velocity
577
576
 
578
- # Integrate the quaternion on SO(3).
577
+ # Integrate the orientation on SO(3).
579
578
  # Note that we left-multiply with the exponential map since the angular
580
579
  # velocity is expressed in the inertial frame.
581
- W_Q_B_tf = jaxlie.SO3.exp(tangent=dt * W_ω_WB_tf) @ W_Q_B_t0
580
+ W_R_B_tf = jaxlie.SO3.exp(tangent=dt * W_ω_WB_tf) @ W_R_B_t0
582
581
 
583
582
  # Replace the quaternion in the final state.
584
583
  return xf.replace(
585
- physics_model=xf.physics_model.replace(
586
- base_quaternion=W_Q_B_tf.as_quaternion_xyzw()[to_wxyz]
587
- ),
584
+ physics_model=xf.physics_model.replace(base_quaternion=W_R_B_tf.wxyz),
588
585
  validate=True,
589
586
  )
@@ -1,10 +1,10 @@
1
1
  from typing import ClassVar, Generic
2
2
 
3
- import jax
4
3
  import jax.numpy as jnp
5
4
  import jax_dataclasses
6
5
 
7
6
  import jaxsim.api as js
7
+ import jaxsim.typing as jtp
8
8
 
9
9
  from .common import ExplicitRungeKutta, ExplicitRungeKuttaSO3Mixin, PyTreeType
10
10
 
@@ -18,11 +18,11 @@ ODEStateDerivative = js.ode_data.ODEState
18
18
  @jax_dataclasses.pytree_dataclass
19
19
  class ForwardEuler(ExplicitRungeKutta[PyTreeType], Generic[PyTreeType]):
20
20
 
21
- A: ClassVar[jax.typing.ArrayLike] = jnp.atleast_2d(0).astype(float)
21
+ A: ClassVar[jtp.Matrix] = jnp.atleast_2d(0).astype(float)
22
22
 
23
- b: ClassVar[jax.typing.ArrayLike] = jnp.atleast_2d(1).astype(float).transpose()
23
+ b: ClassVar[jtp.Matrix] = jnp.atleast_2d(1).astype(float).transpose()
24
24
 
25
- c: ClassVar[jax.typing.ArrayLike] = jnp.atleast_1d(0).astype(float)
25
+ c: ClassVar[jtp.Vector] = jnp.atleast_1d(0).astype(float)
26
26
 
27
27
  row_index_of_solution: ClassVar[int] = 0
28
28
  order_of_bT_rows: ClassVar[tuple[int, ...]] = (1,)
@@ -31,14 +31,14 @@ class ForwardEuler(ExplicitRungeKutta[PyTreeType], Generic[PyTreeType]):
31
31
  @jax_dataclasses.pytree_dataclass
32
32
  class Heun2(ExplicitRungeKutta[PyTreeType], Generic[PyTreeType]):
33
33
 
34
- A: ClassVar[jax.typing.ArrayLike] = jnp.array(
34
+ A: ClassVar[jtp.Matrix] = jnp.array(
35
35
  [
36
36
  [0, 0],
37
37
  [1, 0],
38
38
  ]
39
39
  ).astype(float)
40
40
 
41
- b: ClassVar[jax.typing.ArrayLike] = (
41
+ b: ClassVar[jtp.Matrix] = (
42
42
  jnp.atleast_2d(
43
43
  jnp.array([1 / 2, 1 / 2]),
44
44
  )
@@ -46,7 +46,7 @@ class Heun2(ExplicitRungeKutta[PyTreeType], Generic[PyTreeType]):
46
46
  .transpose()
47
47
  )
48
48
 
49
- c: ClassVar[jax.typing.ArrayLike] = jnp.array(
49
+ c: ClassVar[jtp.Vector] = jnp.array(
50
50
  [0, 1],
51
51
  ).astype(float)
52
52
 
@@ -57,7 +57,7 @@ class Heun2(ExplicitRungeKutta[PyTreeType], Generic[PyTreeType]):
57
57
  @jax_dataclasses.pytree_dataclass
58
58
  class RungeKutta4(ExplicitRungeKutta[PyTreeType], Generic[PyTreeType]):
59
59
 
60
- A: ClassVar[jax.typing.ArrayLike] = jnp.array(
60
+ A: ClassVar[jtp.Matrix] = jnp.array(
61
61
  [
62
62
  [0, 0, 0, 0],
63
63
  [1 / 2, 0, 0, 0],
@@ -66,7 +66,7 @@ class RungeKutta4(ExplicitRungeKutta[PyTreeType], Generic[PyTreeType]):
66
66
  ]
67
67
  ).astype(float)
68
68
 
69
- b: ClassVar[jax.typing.ArrayLike] = (
69
+ b: ClassVar[jtp.Matrix] = (
70
70
  jnp.atleast_2d(
71
71
  jnp.array([1 / 6, 1 / 3, 1 / 3, 1 / 6]),
72
72
  )
@@ -74,7 +74,7 @@ class RungeKutta4(ExplicitRungeKutta[PyTreeType], Generic[PyTreeType]):
74
74
  .transpose()
75
75
  )
76
76
 
77
- c: ClassVar[jax.typing.ArrayLike] = jnp.array(
77
+ c: ClassVar[jtp.Vector] = jnp.array(
78
78
  [0, 1 / 2, 1 / 2, 1],
79
79
  ).astype(float)
80
80