jaxsim 0.2.dev191__py3-none-any.whl → 0.6.1.dev2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (109)
  1. jaxsim/__init__.py +73 -22
  2. jaxsim/_version.py +2 -2
  3. jaxsim/api/__init__.py +13 -1
  4. jaxsim/api/com.py +423 -0
  5. jaxsim/api/common.py +48 -19
  6. jaxsim/api/contact.py +604 -52
  7. jaxsim/api/data.py +308 -163
  8. jaxsim/api/frame.py +471 -0
  9. jaxsim/api/joint.py +166 -37
  10. jaxsim/api/kin_dyn_parameters.py +901 -0
  11. jaxsim/api/link.py +277 -78
  12. jaxsim/api/model.py +1572 -362
  13. jaxsim/api/ode.py +324 -133
  14. jaxsim/api/ode_data.py +401 -0
  15. jaxsim/api/references.py +216 -80
  16. jaxsim/exceptions.py +80 -0
  17. jaxsim/integrators/__init__.py +2 -2
  18. jaxsim/integrators/common.py +191 -107
  19. jaxsim/integrators/fixed_step.py +97 -102
  20. jaxsim/integrators/variable_step.py +706 -0
  21. jaxsim/logging.py +1 -2
  22. jaxsim/math/__init__.py +13 -0
  23. jaxsim/math/adjoint.py +57 -22
  24. jaxsim/math/cross.py +16 -7
  25. jaxsim/math/inertia.py +10 -8
  26. jaxsim/math/joint_model.py +289 -0
  27. jaxsim/math/quaternion.py +54 -20
  28. jaxsim/math/rotation.py +27 -21
  29. jaxsim/math/skew.py +16 -5
  30. jaxsim/math/transform.py +102 -0
  31. jaxsim/math/utils.py +31 -0
  32. jaxsim/mujoco/__init__.py +2 -1
  33. jaxsim/mujoco/loaders.py +216 -29
  34. jaxsim/mujoco/model.py +163 -33
  35. jaxsim/mujoco/utils.py +228 -0
  36. jaxsim/mujoco/visualizer.py +107 -22
  37. jaxsim/parsers/__init__.py +0 -1
  38. jaxsim/parsers/descriptions/__init__.py +8 -2
  39. jaxsim/parsers/descriptions/collision.py +83 -26
  40. jaxsim/parsers/descriptions/joint.py +80 -87
  41. jaxsim/parsers/descriptions/link.py +58 -31
  42. jaxsim/parsers/descriptions/model.py +101 -68
  43. jaxsim/parsers/kinematic_graph.py +606 -229
  44. jaxsim/parsers/rod/meshes.py +104 -0
  45. jaxsim/parsers/rod/parser.py +125 -82
  46. jaxsim/parsers/rod/utils.py +127 -82
  47. jaxsim/rbda/__init__.py +11 -0
  48. jaxsim/rbda/aba.py +289 -0
  49. jaxsim/rbda/collidable_points.py +156 -0
  50. jaxsim/rbda/contacts/__init__.py +13 -0
  51. jaxsim/rbda/contacts/common.py +313 -0
  52. jaxsim/rbda/contacts/relaxed_rigid.py +605 -0
  53. jaxsim/rbda/contacts/rigid.py +462 -0
  54. jaxsim/rbda/contacts/soft.py +480 -0
  55. jaxsim/rbda/contacts/visco_elastic.py +1066 -0
  56. jaxsim/rbda/crba.py +167 -0
  57. jaxsim/rbda/forward_kinematics.py +117 -0
  58. jaxsim/rbda/jacobian.py +330 -0
  59. jaxsim/rbda/rnea.py +235 -0
  60. jaxsim/rbda/utils.py +160 -0
  61. jaxsim/terrain/__init__.py +2 -0
  62. jaxsim/terrain/terrain.py +238 -0
  63. jaxsim/typing.py +24 -24
  64. jaxsim/utils/__init__.py +1 -4
  65. jaxsim/utils/jaxsim_dataclass.py +289 -34
  66. jaxsim/utils/tracing.py +5 -11
  67. jaxsim/utils/wrappers.py +159 -0
  68. {jaxsim-0.2.dev191.dist-info → jaxsim-0.6.1.dev2.dist-info}/LICENSE +1 -1
  69. jaxsim-0.6.1.dev2.dist-info/METADATA +465 -0
  70. jaxsim-0.6.1.dev2.dist-info/RECORD +74 -0
  71. {jaxsim-0.2.dev191.dist-info → jaxsim-0.6.1.dev2.dist-info}/WHEEL +1 -1
  72. jaxsim/high_level/__init__.py +0 -2
  73. jaxsim/high_level/common.py +0 -11
  74. jaxsim/high_level/joint.py +0 -148
  75. jaxsim/high_level/link.py +0 -259
  76. jaxsim/high_level/model.py +0 -1686
  77. jaxsim/math/conv.py +0 -114
  78. jaxsim/math/joint.py +0 -102
  79. jaxsim/math/plucker.py +0 -100
  80. jaxsim/physics/__init__.py +0 -12
  81. jaxsim/physics/algos/__init__.py +0 -0
  82. jaxsim/physics/algos/aba.py +0 -254
  83. jaxsim/physics/algos/aba_motors.py +0 -284
  84. jaxsim/physics/algos/crba.py +0 -154
  85. jaxsim/physics/algos/forward_kinematics.py +0 -79
  86. jaxsim/physics/algos/jacobian.py +0 -98
  87. jaxsim/physics/algos/rnea.py +0 -180
  88. jaxsim/physics/algos/rnea_motors.py +0 -196
  89. jaxsim/physics/algos/soft_contacts.py +0 -523
  90. jaxsim/physics/algos/terrain.py +0 -78
  91. jaxsim/physics/algos/utils.py +0 -69
  92. jaxsim/physics/model/__init__.py +0 -0
  93. jaxsim/physics/model/ground_contact.py +0 -53
  94. jaxsim/physics/model/physics_model.py +0 -388
  95. jaxsim/physics/model/physics_model_state.py +0 -283
  96. jaxsim/simulation/__init__.py +0 -4
  97. jaxsim/simulation/integrators.py +0 -393
  98. jaxsim/simulation/ode.py +0 -290
  99. jaxsim/simulation/ode_data.py +0 -96
  100. jaxsim/simulation/ode_integration.py +0 -62
  101. jaxsim/simulation/simulator.py +0 -543
  102. jaxsim/simulation/simulator_callbacks.py +0 -79
  103. jaxsim/simulation/utils.py +0 -15
  104. jaxsim/sixd/__init__.py +0 -2
  105. jaxsim/utils/oop.py +0 -536
  106. jaxsim/utils/vmappable.py +0 -117
  107. jaxsim-0.2.dev191.dist-info/METADATA +0 -184
  108. jaxsim-0.2.dev191.dist-info/RECORD +0 -81
  109. {jaxsim-0.2.dev191.dist-info → jaxsim-0.6.1.dev2.dist-info}/top_level.txt +0 -0
@@ -1,15 +1,24 @@
1
1
  import os
2
- from typing import Union
2
+ import pathlib
3
+ from collections.abc import Callable
4
+ from typing import TypeVar
3
5
 
4
- import jax.numpy as jnp
5
6
  import numpy as np
6
7
  import numpy.typing as npt
7
8
  import rod
9
+ import trimesh
10
+ from rod.utils.resolve_uris import resolve_local_uri
8
11
 
12
+ import jaxsim.typing as jtp
13
+ from jaxsim import logging
14
+ from jaxsim.math import Adjoint, Inertia
9
15
  from jaxsim.parsers import descriptions
16
+ from jaxsim.parsers.rod import meshes
10
17
 
18
# Type variable for the callables that `create_mesh_collision` uses to turn a
# mesh loaded with trimesh into collidable points: they are invoked as
# `method(mesh=mesh)` and must return an array of points.
MeshMappingMethod = TypeVar("MeshMappingMethod", bound=Callable[..., npt.NDArray])
11
19
 
12
- def from_sdf_inertial(inertial: rod.Inertial) -> npt.NDArray:
20
+
21
+ def from_sdf_inertial(inertial: rod.Inertial) -> jtp.Matrix:
13
22
  """
14
23
  Extract the 6D inertia matrix from an SDF inertial element.
15
24
 
@@ -20,13 +29,10 @@ def from_sdf_inertial(inertial: rod.Inertial) -> npt.NDArray:
20
29
  The 6D inertia matrix of the link expressed in the link frame.
21
30
  """
22
31
 
23
- from jaxsim.math.inertia import Inertia
24
- from jaxsim.sixd import se3
25
-
26
- # Extract the "mass" element
32
+ # Extract the "mass" element.
27
33
  m = inertial.mass
28
34
 
29
- # Extract the "inertia" element
35
+ # Extract the "inertia" element.
30
36
  inertia_element = inertial.inertia
31
37
 
32
38
  ixx = inertia_element.ixx
@@ -36,7 +42,7 @@ def from_sdf_inertial(inertial: rod.Inertial) -> npt.NDArray:
36
42
  ixz = inertia_element.ixz if inertia_element.ixz is not None else 0.0
37
43
  iyz = inertia_element.iyz if inertia_element.iyz is not None else 0.0
38
44
 
39
- # Build the 3x3 inertia matrix expressed in the CoM
45
+ # Build the 3x3 inertia matrix expressed in the CoM.
40
46
  I_CoM = np.array(
41
47
  [
42
48
  [ixx, ixy, ixz],
@@ -45,73 +51,52 @@ def from_sdf_inertial(inertial: rod.Inertial) -> npt.NDArray:
45
51
  ]
46
52
  )
47
53
 
48
- # Build the 6x6 generalized inertia at the CoM
54
+ # Build the 6x6 generalized inertia at the CoM.
49
55
  M_CoM = Inertia.to_sixd(mass=m, com=np.zeros(3), I=I_CoM)
50
56
 
51
- # Compute the transform from the inertial frame (CoM) to the link frame
57
+ # Compute the transform from the inertial frame (CoM) to the link frame.
52
58
  L_H_CoM = inertial.pose.transform() if inertial.pose is not None else np.eye(4)
53
59
 
54
- # We need its inverse
55
- CoM_H_L = se3.SE3.from_matrix(matrix=L_H_CoM).inverse()
56
- CoM_X_L: npt.NDArray = CoM_H_L.adjoint()
60
+ # We need its inverse.
61
+ CoM_X_L = Adjoint.from_transform(transform=L_H_CoM, inverse=True)
57
62
 
58
- # Express the CoM inertia matrix in the link frame L
63
+ # Express the CoM inertia matrix in the link frame L.
59
64
  M_L = CoM_X_L.T @ M_CoM @ CoM_X_L
60
65
 
61
- return jnp.array(M_L)
66
+ return M_L.astype(dtype=float)
62
67
 
63
68
 
64
def joint_to_joint_type(joint: rod.Joint) -> int:
    """
    Extract the joint type from an SDF joint.

    Args:
        joint: The parsed SDF joint.

    Returns:
        The integer corresponding to the joint type.

    Raises:
        ValueError:
            If a non-fixed joint has no axis data, if its axis has zero
            length, or if its type is not supported.
    """

    axis = joint.axis
    joint_type = joint.type

    # Fixed joints do not need an axis, so they are handled before any check.
    if joint_type == "fixed":
        return descriptions.JointType.Fixed

    # Check `axis is None` first: otherwise a joint without an <axis> element
    # would fail with an AttributeError instead of the intended ValueError.
    if axis is None or axis.xyz is None or axis.xyz.xyz is None:
        raise ValueError("Failed to read axis xyz data")

    axis_xyz = np.array(axis.xyz.xyz).astype(float)
    axis_norm = np.linalg.norm(axis_xyz)

    # A zero-length axis cannot be normalized: dividing by it would silently
    # produce NaNs.
    if axis_norm == 0.0:
        raise ValueError("Joint axis has zero length", axis_xyz, joint_type)

    # Make sure that the axis is a unit vector.
    axis_xyz = axis_xyz / axis_norm

    if joint_type in {"revolute", "continuous"}:
        return descriptions.JointType.Revolute

    if joint_type == "prismatic":
        return descriptions.JointType.Prismatic

    raise ValueError("Joint not supported", axis_xyz, joint_type)
115
100
 
116
101
 
117
102
  def create_box_collision(
@@ -132,22 +117,19 @@ def create_box_collision(
132
117
 
133
118
  center = np.array([x / 2, y / 2, z / 2])
134
119
 
135
- box_corners = (
136
- np.vstack(
137
- [
138
- np.array([0, 0, 0]),
139
- np.array([x, 0, 0]),
140
- np.array([x, y, 0]),
141
- np.array([0, y, 0]),
142
- np.array([0, 0, z]),
143
- np.array([x, 0, z]),
144
- np.array([x, y, z]),
145
- np.array([0, y, z]),
146
- ]
147
- )
148
- - center
120
+ # Define the bottom corners.
121
+ bottom_corners = np.array([[0, 0, 0], [x, 0, 0], [x, y, 0], [0, y, 0]])
122
+
123
+ # Conditionally add the top corners based on the environment variable.
124
+ top_corners = (
125
+ np.array([[0, 0, z], [x, 0, z], [x, y, z], [0, y, z]])
126
+ if not os.environ.get("JAXSIM_COLLISION_USE_BOTTOM_ONLY", "0")
127
+ else []
149
128
  )
150
129
 
130
+ # Combine and shift by the center
131
+ box_corners = np.vstack([bottom_corners, *top_corners]) - center
132
+
151
133
  H = collision.pose.transform() if collision.pose is not None else np.eye(4)
152
134
 
153
135
  center_wrt_link = (H @ np.hstack([center, 1.0]))[0:-1]
@@ -158,7 +140,7 @@ def create_box_collision(
158
140
  collidable_points = [
159
141
  descriptions.CollidablePoint(
160
142
  parent_link=link_description,
161
- position=corner,
143
+ position=np.array(corner),
162
144
  enabled=True,
163
145
  )
164
146
  for corner in box_corners_wrt_link.T
@@ -185,25 +167,33 @@ def create_sphere_collision(
185
167
 
186
168
  # From https://stackoverflow.com/a/26127012
187
169
  def fibonacci_sphere(samples: int) -> npt.NDArray:
188
- points = []
189
- phi = np.pi * (3.0 - np.sqrt(5.0)) # golden angle in radians
190
-
191
- for i in range(samples):
192
- y = 1 - (i / float(samples - 1)) * 2 # y goes from 1 to -1
193
- radius = np.sqrt(1 - y * y) # radius at y
194
-
195
- theta = phi * i # golden angle increment
196
-
197
- x = np.cos(theta) * radius
198
- z = np.sin(theta) * radius
170
+ # Get the golden ratio in radians.
171
+ phi = np.pi * (3.0 - np.sqrt(5.0))
172
+
173
+ # Generate the points.
174
+ points = [
175
+ np.array(
176
+ [
177
+ np.cos(phi * i)
178
+ * np.sqrt(1 - (y := 1 - 2 * i / (samples - 1)) ** 2),
179
+ y,
180
+ np.sin(phi * i) * np.sqrt(1 - y**2),
181
+ ]
182
+ )
183
+ for i in range(samples)
184
+ ]
199
185
 
200
- points.append(np.array([x, y, z]))
186
+ # Filter to keep only the bottom half if required.
187
+ if os.environ.get("JAXSIM_COLLISION_USE_BOTTOM_ONLY", "0"):
188
+ # Keep only the points with z <= 0.
189
+ points = [point for point in points if point[2] <= 0]
201
190
 
202
191
  return np.vstack(points)
203
192
 
204
193
  r = collision.geometry.sphere.radius
194
+
205
195
  sphere_points = r * fibonacci_sphere(
206
- samples=int(os.getenv(key="JAXSIM_COLLISION_SPHERE_POINTS", default="250"))
196
+ samples=int(os.getenv(key="JAXSIM_COLLISION_SPHERE_POINTS", default="50"))
207
197
  )
208
198
 
209
199
  H = collision.pose.transform() if collision.pose is not None else np.eye(4)
@@ -217,7 +207,7 @@ def create_sphere_collision(
217
207
  collidable_points = [
218
208
  descriptions.CollidablePoint(
219
209
  parent_link=link_description,
220
- position=point,
210
+ position=np.array(point),
221
211
  enabled=True,
222
212
  )
223
213
  for point in sphere_points_wrt_link.T
@@ -226,3 +216,58 @@ def create_sphere_collision(
226
216
  return descriptions.SphereCollision(
227
217
  collidable_points=collidable_points, center=center_wrt_link
228
218
  )
219
+
220
+
221
def create_mesh_collision(
    collision: rod.Collision,
    link_description: descriptions.LinkDescription,
    method: MeshMappingMethod | None = None,
) -> descriptions.MeshCollision:
    """
    Create a mesh collision from an SDF collision element.

    Args:
        collision: The SDF collision element.
        link_description: The link description.
        method:
            The method used to extract the collidable points from the mesh.
            It is called as `method(mesh=mesh)` and must return an array of
            points, one per row, expressed in the collision frame.
            Defaults to `meshes.VertexExtraction()`.

    Returns:
        The mesh collision description. Its collidable points and its center
        are expressed in the parent link frame.

    Raises:
        RuntimeError: If the mesh loaded by trimesh is empty.
    """

    # Resolve the mesh URI to a local file and let trimesh infer the format
    # from the file extension.
    file = pathlib.Path(resolve_local_uri(uri=collision.geometry.mesh.uri))
    file_type = file.suffix.replace(".", "")
    mesh = trimesh.load_mesh(file, file_type=file_type)

    if mesh.is_empty:
        raise RuntimeError(f"Failed to process '{file}' with trimesh")

    mesh.apply_scale(collision.geometry.mesh.scale)
    logging.info(
        msg=f"Loading mesh {collision.geometry.mesh.uri} with scale {collision.geometry.mesh.scale}, file type '{file_type}'"
    )

    if method is None:
        method = meshes.VertexExtraction()
        logging.debug("Using default Vertex Extraction method for mesh wrapping")
    else:
        logging.debug(f"Using method {method} for mesh wrapping")

    points = method(mesh=mesh)
    logging.debug(f"Extracted {len(points)} points from mesh")

    # Pose of the collision frame C with respect to the parent link frame L.
    L_H_C = collision.pose.transform() if collision.pose is not None else np.eye(4)

    # Split the homogeneous transform into its rotation and translation parts.
    L_R_C = L_H_C[0:3, 0:3]
    L_p_C = L_H_C[0:3, 3]

    # Map every point (one per row) from the collision frame to the link frame.
    mesh_points_wrt_link = points @ L_R_C.T + L_p_C

    collidable_points = [
        descriptions.CollidablePoint(
            parent_link=link_description,
            # Copy each row, as the box and sphere factories do, so that the
            # point does not keep a view into the whole array.
            position=np.array(point),
            enabled=True,
        )
        for point in mesh_points_wrt_link
    ]

    return descriptions.MeshCollision(collidable_points=collidable_points, center=L_p_C)
jaxsim/rbda/__init__.py ADDED
@@ -0,0 +1,11 @@
1
+ from . import contacts
2
+ from .aba import aba
3
+ from .collidable_points import collidable_points_pos_vel
4
+ from .crba import crba
5
+ from .forward_kinematics import forward_kinematics, forward_kinematics_model
6
+ from .jacobian import (
7
+ jacobian,
8
+ jacobian_derivative_full_doubly_left,
9
+ jacobian_full_doubly_left,
10
+ )
11
+ from .rnea import rnea
jaxsim/rbda/aba.py ADDED
@@ -0,0 +1,289 @@
1
+ import jax
2
+ import jax.numpy as jnp
3
+ import jaxlie
4
+
5
+ import jaxsim.api as js
6
+ import jaxsim.typing as jtp
7
+ from jaxsim.math import Adjoint, Cross, StandardGravity
8
+
9
+ from . import utils
10
+
11
+
12
def aba(
    model: js.model.JaxSimModel,
    *,
    base_position: jtp.VectorLike,
    base_quaternion: jtp.VectorLike,
    joint_positions: jtp.VectorLike,
    base_linear_velocity: jtp.VectorLike,
    base_angular_velocity: jtp.VectorLike,
    joint_velocities: jtp.VectorLike,
    joint_forces: jtp.VectorLike | None = None,
    link_forces: jtp.MatrixLike | None = None,
    standard_gravity: jtp.FloatLike = StandardGravity,
) -> tuple[jtp.Vector, jtp.Vector]:
    """
    Compute forward dynamics using the Articulated Body Algorithm (ABA).

    Args:
        model: The model to consider.
        base_position: The position of the base link.
        base_quaternion: The quaternion of the base link.
        joint_positions: The positions of the joints.
        base_linear_velocity:
            The linear velocity of the base link in inertial-fixed representation.
        base_angular_velocity:
            The angular velocity of the base link in inertial-fixed representation.
        joint_velocities: The velocities of the joints.
        joint_forces: The forces applied to the joints.
        link_forces:
            The forces applied to the links expressed in the world frame.
        standard_gravity: The standard gravity constant.

    Returns:
        A tuple containing the base acceleration in inertial-fixed representation
        and the joint accelerations that result from the applications of the given
        joint and link forces. For fixed-base models the returned base
        acceleration is a vector of zeros.

    Note:
        The algorithm expects a quaternion with unit norm.

        Inside the three passes, link quantities (v, c, MA, pA, S, i_X_λi, ...)
        are indexed by the link index `i` (0 being the base), while joint
        quantities (ṡ, τ, s̈) are indexed by `ii = i - 1`.
    """

    W_p_B, W_Q_B, s, W_v_WB, ṡ, _, _, τ, W_f, W_g = utils.process_inputs(
        model=model,
        base_position=base_position,
        base_quaternion=base_quaternion,
        joint_positions=joint_positions,
        base_linear_velocity=base_linear_velocity,
        base_angular_velocity=base_angular_velocity,
        joint_velocities=joint_velocities,
        base_linear_acceleration=None,
        base_angular_acceleration=None,
        joint_accelerations=None,
        joint_forces=joint_forces,
        link_forces=link_forces,
        standard_gravity=standard_gravity,
    )

    # Turn the 6D gravity and base velocity (presumably 1D vectors returned by
    # process_inputs) into (6, 1) column vectors, so that they compose with the
    # 6x6 adjoint matrices used below.
    W_g = jnp.atleast_2d(W_g).T
    W_v_WB = jnp.atleast_2d(W_v_WB).T

    # Get the 6D spatial inertia matrices of all links.
    M = js.model.link_spatial_inertia_matrices(model=model)

    # Get the parent array λ(i).
    # Note: λ(0) must not be used, it's initialized to -1.
    λ = model.kin_dyn_parameters.parent_array

    # Compute the base transform.
    W_H_B = jaxlie.SE3.from_rotation_and_translation(
        rotation=jaxlie.SO3(wxyz=W_Q_B),
        translation=W_p_B,
    )

    # Compute 6D transforms of the base velocity.
    W_X_B = W_H_B.adjoint()
    B_X_W = W_H_B.inverse().adjoint()

    # Compute the parent-to-child adjoints and the motion subspaces of the joints.
    # These transforms define the relative kinematics of the entire model, including
    # the base transform for both floating-base and fixed-base models.
    i_X_λi, S = model.kin_dyn_parameters.joint_transforms_and_motion_subspaces(
        joint_positions=s, base_transform=W_H_B.as_matrix()
    )

    # Allocate buffers.
    # v: link velocities, c: velocity-product accelerations,
    # pA: articulated-body bias forces, MA: articulated-body inertias.
    v = jnp.zeros(shape=(model.number_of_links(), 6, 1))
    c = jnp.zeros(shape=(model.number_of_links(), 6, 1))
    pA = jnp.zeros(shape=(model.number_of_links(), 6, 1))
    MA = jnp.zeros(shape=(model.number_of_links(), 6, 6))

    # Allocate the buffer of transforms link -> base.
    i_X_0 = jnp.zeros(shape=(model.number_of_links(), 6, 6))
    i_X_0 = i_X_0.at[0].set(jnp.eye(6))

    # Initialize base quantities.
    # For fixed-base models v[0], MA[0] and pA[0] are left at zero.
    if model.floating_base():

        # Base velocity v₀ in body-fixed representation.
        v_0 = B_X_W @ W_v_WB
        v = v.at[0].set(v_0)

        # Initialize the articulated-body inertia (Mᴬ) of base link.
        MA_0 = M[0]
        MA = MA.at[0].set(MA_0)

        # Initialize the articulated-body bias force (pᴬ) of the base link.
        # W_X_B.T maps the world-frame external force of the base to the base
        # frame (force transforms are the inverse-transpose of motion adjoints).
        pA_0 = Cross.vx_star(v[0]) @ MA[0] @ v[0] - W_X_B.T @ jnp.vstack(W_f[0])
        pA = pA.at[0].set(pA_0)

    # ======
    # Pass 1
    # ======
    # Iterate links in increasing index order (root to leaves). This relies on
    # the parent of every link being processed before the link itself,
    # presumably guaranteed by the link ordering of the model (λ(i) < i).

    Pass1Carry = tuple[jtp.Matrix, jtp.Matrix, jtp.Matrix, jtp.Matrix, jtp.Matrix]
    pass_1_carry: Pass1Carry = (v, c, MA, pA, i_X_0)

    # Propagate kinematics and initialize AB inertia and AB bias forces.
    def loop_body_pass1(carry: Pass1Carry, i: jtp.Int) -> tuple[Pass1Carry, None]:

        # Joint index associated with link i.
        ii = i - 1
        v, c, MA, pA, i_X_0 = carry

        # Project the joint velocity into its motion subspace.
        vJ = S[i] * ṡ[ii]

        # Propagate the link velocity.
        v_i = i_X_λi[i] @ v[λ[i]] + vJ
        v = v.at[i].set(v_i)

        # Velocity-product acceleration term of link i.
        c_i = Cross.vx(v[i]) @ vJ
        c = c.at[i].set(c_i)

        # Initialize the articulated-body inertia.
        MA_i = jnp.array(M[i])
        MA = MA.at[i].set(MA_i)

        # Compute the link-to-base transform.
        i_Xi_0 = i_X_λi[i] @ i_X_0[λ[i]]
        i_X_0 = i_X_0.at[i].set(i_Xi_0)

        # Compute link-to-world transform for the 6D force.
        # i_X_W = i_X_0 @ B_X_W; the corresponding force transform is its
        # inverse-transpose.
        i_Xf_W = Adjoint.inverse(i_X_0[i] @ B_X_W).T

        # Initialize articulated-body bias force.
        pA_i = Cross.vx_star(v[i]) @ M[i] @ v[i] - i_Xf_W @ jnp.vstack(W_f[i])
        pA = pA.at[i].set(pA_i)

        return (v, c, MA, pA, i_X_0), None

    # Skip the scan when the model has only the base link, since there is
    # nothing to iterate over.
    (v, c, MA, pA, i_X_0), _ = (
        jax.lax.scan(
            f=loop_body_pass1,
            init=pass_1_carry,
            xs=jnp.arange(start=1, stop=model.number_of_links()),
        )
        if model.number_of_links() > 1
        else [(v, c, MA, pA, i_X_0), None]
    )

    # ======
    # Pass 2
    # ======
    # Iterate links in decreasing index order (leaves to root), folding each
    # link's articulated-body inertia and bias force into its parent.

    U = jnp.zeros_like(S)
    d = jnp.zeros(shape=(model.number_of_links(), 1))
    u = jnp.zeros(shape=(model.number_of_links(), 1))

    Pass2Carry = tuple[jtp.Matrix, jtp.Matrix, jtp.Matrix, jtp.Matrix, jtp.Matrix]
    pass_2_carry: Pass2Carry = (U, d, u, MA, pA)

    def loop_body_pass2(carry: Pass2Carry, i: jtp.Int) -> tuple[Pass2Carry, None]:

        # Joint index associated with link i.
        ii = i - 1
        U, d, u, MA, pA = carry

        U_i = MA[i] @ S[i]
        U = U.at[i].set(U_i)

        d_i = S[i].T @ U[i]
        d = d.at[i].set(d_i.squeeze())

        u_i = τ[ii] - S[i].T @ pA[i]
        u = u.at[i].set(u_i.squeeze())

        # Compute the articulated-body inertia and bias force of this link.
        # Note: `/` and `@` share precedence and associate left to right, so the
        # first line evaluates as (U[i] / d[i]) @ U[i].T.
        Ma = MA[i] - U[i] / d[i] @ U[i].T
        pa = pA[i] + Ma @ c[i] + U[i] * (u[i] / d[i])

        # Propagate them to the parent, handling the base link.
        def propagate(
            MA_pA: tuple[jtp.Matrix, jtp.Matrix]
        ) -> tuple[jtp.Matrix, jtp.Matrix]:

            MA, pA = MA_pA

            MA_λi = MA[λ[i]] + i_X_λi[i].T @ Ma @ i_X_λi[i]
            MA = MA.at[λ[i]].set(MA_λi)

            pA_λi = pA[λ[i]] + i_X_λi[i].T @ pa
            pA = pA.at[λ[i]].set(pA_λi)

            return MA, pA

        # Propagate to the parent, except when the parent is the base of a
        # fixed-base model, whose quantities are never used.
        MA, pA = jax.lax.cond(
            pred=jnp.logical_or(λ[i] != 0, model.floating_base()),
            true_fun=propagate,
            false_fun=lambda MA_pA: MA_pA,
            operand=(MA, pA),
        )

        return (U, d, u, MA, pA), None

    (U, d, u, MA, pA), _ = (
        jax.lax.scan(
            f=loop_body_pass2,
            init=pass_2_carry,
            xs=jnp.flip(jnp.arange(start=1, stop=model.number_of_links())),
        )
        if model.number_of_links() > 1
        else [(U, d, u, MA, pA), None]
    )

    # ======
    # Pass 3
    # ======
    # Iterate links in increasing index order (root to leaves), computing joint
    # and link accelerations.

    if model.floating_base():
        # Solve MA[0] a0 = -pA[0] for the body-fixed base acceleration.
        a0 = jnp.linalg.solve(-MA[0], pA[0])
    else:
        # Fixed base: gravity enters the chain as the base acceleration -W_g,
        # expressed in the base frame.
        a0 = -B_X_W @ W_g

    s̈ = jnp.zeros_like(s)
    a = jnp.zeros_like(v).at[0].set(a0)

    Pass3Carry = tuple[jtp.Matrix, jtp.Vector]
    pass_3_carry = (a, s̈)

    def loop_body_pass3(carry: Pass3Carry, i: jtp.Int) -> tuple[Pass3Carry, None]:

        # Joint index associated with link i.
        ii = i - 1
        a, s̈ = carry

        # Propagate the link acceleration.
        a_i = i_X_λi[i] @ a[λ[i]] + c[i]

        # Compute the joint acceleration.
        s̈_ii = (u[i] - U[i].T @ a_i) / d[i]
        s̈ = s̈.at[ii].set(s̈_ii.squeeze())

        # Sum the joint acceleration to the parent link acceleration.
        a_i = a_i + S[i] * s̈[ii]
        a = a.at[i].set(a_i)

        return (a, s̈), None

    (a, s̈), _ = (
        jax.lax.scan(
            f=loop_body_pass3,
            init=pass_3_carry,
            xs=jnp.arange(1, model.number_of_links()),
        )
        if model.number_of_links() > 1
        else [(a, s̈), None]
    )

    # ==============
    # Adjust outputs
    # ==============

    # TODO: remove vstack and shape=(6, 1)?
    if model.floating_base():
        # Convert the base acceleration to inertial-fixed representation,
        # and add gravity.
        # For floating-base models gravity was not injected at the base in
        # pass 3, so it is added back here.
        B_a_WB = a[0]
        W_a_WB = W_X_B @ B_a_WB + W_g
    else:
        W_a_WB = jnp.zeros(6)

    return W_a_WB.squeeze(), jnp.atleast_1d(s̈.squeeze())