jaxsim 0.2.dev191__py3-none-any.whl → 0.2.dev366__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (79)
  1. jaxsim/__init__.py +3 -4
  2. jaxsim/_version.py +2 -2
  3. jaxsim/api/__init__.py +3 -1
  4. jaxsim/api/com.py +240 -0
  5. jaxsim/api/common.py +13 -2
  6. jaxsim/api/contact.py +120 -43
  7. jaxsim/api/data.py +112 -71
  8. jaxsim/api/joint.py +77 -36
  9. jaxsim/api/kin_dyn_parameters.py +777 -0
  10. jaxsim/api/link.py +150 -75
  11. jaxsim/api/model.py +542 -269
  12. jaxsim/api/ode.py +86 -74
  13. jaxsim/api/ode_data.py +694 -0
  14. jaxsim/api/references.py +12 -11
  15. jaxsim/integrators/__init__.py +2 -2
  16. jaxsim/integrators/common.py +110 -24
  17. jaxsim/integrators/fixed_step.py +11 -67
  18. jaxsim/integrators/variable_step.py +610 -0
  19. jaxsim/math/__init__.py +11 -0
  20. jaxsim/math/adjoint.py +24 -2
  21. jaxsim/math/joint_model.py +335 -0
  22. jaxsim/math/quaternion.py +44 -3
  23. jaxsim/math/rotation.py +4 -4
  24. jaxsim/math/transform.py +93 -0
  25. jaxsim/parsers/descriptions/link.py +2 -2
  26. jaxsim/parsers/rod/utils.py +7 -8
  27. jaxsim/rbda/__init__.py +7 -0
  28. jaxsim/rbda/aba.py +295 -0
  29. jaxsim/rbda/collidable_points.py +142 -0
  30. jaxsim/{physics/algos → rbda}/crba.py +43 -42
  31. jaxsim/rbda/forward_kinematics.py +113 -0
  32. jaxsim/rbda/jacobian.py +201 -0
  33. jaxsim/rbda/rnea.py +237 -0
  34. jaxsim/rbda/soft_contacts.py +296 -0
  35. jaxsim/rbda/utils.py +152 -0
  36. jaxsim/terrain/__init__.py +2 -0
  37. jaxsim/utils/__init__.py +1 -4
  38. jaxsim/utils/hashless.py +18 -0
  39. jaxsim/utils/jaxsim_dataclass.py +281 -30
  40. {jaxsim-0.2.dev191.dist-info → jaxsim-0.2.dev366.dist-info}/METADATA +4 -6
  41. jaxsim-0.2.dev366.dist-info/RECORD +64 -0
  42. {jaxsim-0.2.dev191.dist-info → jaxsim-0.2.dev366.dist-info}/WHEEL +1 -1
  43. jaxsim/high_level/__init__.py +0 -2
  44. jaxsim/high_level/common.py +0 -11
  45. jaxsim/high_level/joint.py +0 -148
  46. jaxsim/high_level/link.py +0 -259
  47. jaxsim/high_level/model.py +0 -1686
  48. jaxsim/math/conv.py +0 -114
  49. jaxsim/math/joint.py +0 -102
  50. jaxsim/math/plucker.py +0 -100
  51. jaxsim/physics/__init__.py +0 -12
  52. jaxsim/physics/algos/__init__.py +0 -0
  53. jaxsim/physics/algos/aba.py +0 -254
  54. jaxsim/physics/algos/aba_motors.py +0 -284
  55. jaxsim/physics/algos/forward_kinematics.py +0 -79
  56. jaxsim/physics/algos/jacobian.py +0 -98
  57. jaxsim/physics/algos/rnea.py +0 -180
  58. jaxsim/physics/algos/rnea_motors.py +0 -196
  59. jaxsim/physics/algos/soft_contacts.py +0 -523
  60. jaxsim/physics/algos/utils.py +0 -69
  61. jaxsim/physics/model/__init__.py +0 -0
  62. jaxsim/physics/model/ground_contact.py +0 -53
  63. jaxsim/physics/model/physics_model.py +0 -388
  64. jaxsim/physics/model/physics_model_state.py +0 -283
  65. jaxsim/simulation/__init__.py +0 -4
  66. jaxsim/simulation/integrators.py +0 -393
  67. jaxsim/simulation/ode.py +0 -290
  68. jaxsim/simulation/ode_data.py +0 -96
  69. jaxsim/simulation/ode_integration.py +0 -62
  70. jaxsim/simulation/simulator.py +0 -543
  71. jaxsim/simulation/simulator_callbacks.py +0 -79
  72. jaxsim/simulation/utils.py +0 -15
  73. jaxsim/sixd/__init__.py +0 -2
  74. jaxsim/utils/oop.py +0 -536
  75. jaxsim/utils/vmappable.py +0 -117
  76. jaxsim-0.2.dev191.dist-info/RECORD +0 -81
  77. /jaxsim/{physics/algos → terrain}/terrain.py +0 -0
  78. {jaxsim-0.2.dev191.dist-info → jaxsim-0.2.dev366.dist-info}/LICENSE +0 -0
  79. {jaxsim-0.2.dev191.dist-info → jaxsim-0.2.dev366.dist-info}/top_level.txt +0 -0
jaxsim/rbda/forward_kinematics.py ADDED
@@ -0,0 +1,113 @@
1
+ import jax
2
+ import jax.numpy as jnp
3
+ import jaxlie
4
+
5
+ import jaxsim.api as js
6
+ import jaxsim.typing as jtp
7
+ from jaxsim.math import Adjoint, Quaternion
8
+
9
+ from . import utils
10
+
11
+
12
def forward_kinematics_model(
    model: js.model.JaxSimModel,
    *,
    base_position: jtp.VectorLike,
    base_quaternion: jtp.VectorLike,
    joint_positions: jtp.VectorLike,
) -> jtp.Array:
    """
    Compute the forward kinematics.

    Args:
        model: The model to consider.
        base_position: The position of the base link.
        base_quaternion: The quaternion of the base link.
        joint_positions: The positions of the joints.

    Returns:
        A 3D array containing the SE(3) transforms of all links belonging to the
        model, stacked along the first axis and indexed by link index. Entry i is
        the transform of link i w.r.t. the world frame.
    """

    W_p_B, W_Q_B, s, _, _, _, _, _, _, _ = utils.process_inputs(
        model=model,
        base_position=base_position,
        base_quaternion=base_quaternion,
        joint_positions=joint_positions,
    )

    # Get the parent array λ(i).
    # Note: λ(0) must not be used, it's initialized to -1.
    λ = model.kin_dyn_parameters.parent_array

    # Compute the base transform.
    W_H_B = jaxlie.SE3.from_rotation_and_translation(
        rotation=jaxlie.SO3.from_quaternion_xyzw(xyzw=Quaternion.to_xyzw(wxyz=W_Q_B)),
        translation=W_p_B,
    )

    # Compute the parent-to-child adjoints of the joints.
    # These transforms define the relative kinematics of the entire model, including
    # the base transform for both floating-base and fixed-base models.
    # The motion subspaces are returned as well, but link poses do not need them.
    i_X_λi, _ = model.kin_dyn_parameters.joint_transforms_and_motion_subspaces(
        joint_positions=s, base_transform=W_H_B.as_matrix()
    )

    # Allocate the buffer of transforms world -> link and initialize the base pose.
    # Entry 0 of i_X_λi carries the base transform, so its inverse is the
    # world-to-base adjoint.
    W_X_i = jnp.zeros(shape=(model.number_of_links(), 6, 6))
    W_X_i = W_X_i.at[0].set(Adjoint.inverse(i_X_λi[0]))

    # ========================
    # Propagate the kinematics
    # ========================

    PropagateKinematicsCarry = tuple[jtp.MatrixJax]
    propagate_kinematics_carry: PropagateKinematicsCarry = (W_X_i,)

    def propagate_kinematics(
        carry: PropagateKinematicsCarry, i: jtp.Int
    ) -> tuple[PropagateKinematicsCarry, None]:

        (W_X_i,) = carry

        # Chain the parent pose with the parent-to-child transform. The parent
        # entry is already filled because links are visited in index order.
        W_X_i_i = W_X_i[λ[i]] @ Adjoint.inverse(i_X_λi[i])
        W_X_i = W_X_i.at[i].set(W_X_i_i)

        return (W_X_i,), None

    (W_X_i,), _ = jax.lax.scan(
        f=propagate_kinematics,
        init=propagate_kinematics_carry,
        xs=jnp.arange(start=1, stop=model.number_of_links()),
    )

    # Convert the 6D adjoints to SE(3) transforms.
    return jax.vmap(Adjoint.to_transform)(W_X_i)
85
+
86
+
87
def forward_kinematics(
    model: js.model.JaxSimModel,
    link_index: jtp.Int,
    base_position: jtp.VectorLike,
    base_quaternion: jtp.VectorLike,
    joint_positions: jtp.VectorLike,
) -> jtp.Matrix:
    """
    Compute the forward kinematics of a specific link.

    The transforms of all the links are computed in a single pass by
    `forward_kinematics_model`; the transform of the requested link is then
    selected from the resulting stack.

    Args:
        model: The model to consider.
        link_index: The index of the link to consider.
        base_position: The position of the base link.
        base_quaternion: The quaternion of the base link.
        joint_positions: The positions of the joints.

    Returns:
        The SE(3) transform of the link.
    """

    all_link_transforms = forward_kinematics_model(
        model=model,
        base_position=base_position,
        base_quaternion=base_quaternion,
        joint_positions=joint_positions,
    )

    link_transform = all_link_transforms[link_index]
    return link_transform
jaxsim/rbda/jacobian.py ADDED
@@ -0,0 +1,201 @@
1
+ import jax
2
+ import jax.numpy as jnp
3
+ import numpy as np
4
+
5
+ import jaxsim.api as js
6
+ import jaxsim.typing as jtp
7
+ from jaxsim.math import Adjoint
8
+
9
+ from . import utils
10
+
11
+
12
def jacobian(
    model: js.model.JaxSimModel,
    *,
    link_index: jtp.Int,
    joint_positions: jtp.VectorLike,
) -> jtp.Matrix:
    """
    Compute the free-floating Jacobian of a link.

    Args:
        model: The model to consider.
        link_index: The index of the link for which to compute the Jacobian matrix.
        joint_positions: The positions of the joints.

    Returns:
        The free-floating left-trivialized Jacobian of the link :math:`{}^L J_{W,L/B}`,
        a 6x(6+n) matrix (n being the number of DoFs). The columns associated to
        joints that do not belong to the support of the link are zero.
    """

    _, _, s, _, _, _, _, _, _, _ = utils.process_inputs(
        model=model, joint_positions=joint_positions
    )

    # Get the parent array λ(i).
    # Note: λ(0) must not be used, it's initialized to -1.
    λ = model.kin_dyn_parameters.parent_array

    # Compute the parent-to-child adjoints and the motion subspaces of the joints.
    # With an identity base transform, the propagated adjoints below are relative
    # to the base frame (0) rather than to the world frame.
    i_X_λi, S = model.kin_dyn_parameters.joint_transforms_and_motion_subspaces(
        joint_positions=s, base_transform=jnp.eye(4)
    )

    # Allocate the buffer of transforms link -> base.
    i_X_0 = jnp.zeros(shape=(model.number_of_links(), 6, 6))
    i_X_0 = i_X_0.at[0].set(jnp.eye(6))

    # ====================
    # Propagate kinematics
    # ====================

    PropagateKinematicsCarry = tuple[jtp.MatrixJax]
    propagate_kinematics_carry: PropagateKinematicsCarry = (i_X_0,)

    def propagate_kinematics(
        carry: PropagateKinematicsCarry, i: jtp.Int
    ) -> tuple[PropagateKinematicsCarry, None]:

        (i_X_0,) = carry

        # Compute the base (0) to link (i) adjoint matrix.
        # This works fine since we traverse the kinematic tree following the link
        # indices assigned with BFS.
        i_X_0_i = i_X_λi[i] @ i_X_0[λ[i]]
        i_X_0 = i_X_0.at[i].set(i_X_0_i)

        return (i_X_0,), None

    (i_X_0,), _ = jax.lax.scan(
        f=propagate_kinematics,
        init=propagate_kinematics_carry,
        xs=np.arange(start=1, stop=model.number_of_links()),
    )

    # ============================
    # Compute doubly-left Jacobian
    # ============================

    # The base-to-link adjoint of the considered link does not depend on the
    # scanned index: extract it once and reuse it for every column.
    L_X_B = i_X_0[link_index]

    J = jnp.zeros(shape=(6, 6 + model.dofs()))
    J = J.at[0:6, 0:6].set(L_X_B)

    # To make JIT happy, we operate on a boolean version of κ(i).
    # Checking if j ∈ κ(i) is equivalent to: κ_bool(j) is True.
    κ_bool = model.kin_dyn_parameters.support_body_array_bool[link_index]

    def compute_jacobian(J: jtp.MatrixJax, i: jtp.Int) -> tuple[jtp.MatrixJax, None]:

        def update_jacobian(J: jtp.MatrixJax, i: jtp.Int) -> jtp.MatrixJax:

            # The joint associated to link i has index i - 1 (the base has no joint).
            ii = i - 1

            # Motion subspace of the joint, expressed in the frame of the link.
            Js_i = L_X_B @ Adjoint.inverse(i_X_0[i]) @ S[i]
            J = J.at[0:6, 6 + ii].set(Js_i.squeeze())

            return J

        # Fill the column only if link i supports the considered link.
        J = jax.lax.select(
            pred=κ_bool[i],
            on_true=update_jacobian(J, i),
            on_false=J,
        )

        return J, None

    L_J_WL_B, _ = jax.lax.scan(
        f=compute_jacobian,
        init=J,
        xs=np.arange(start=1, stop=model.number_of_links()),
    )

    return L_J_WL_B
115
+
116
+
117
@jax.jit
def jacobian_full_doubly_left(
    model: js.model.JaxSimModel,
    *,
    joint_positions: jtp.VectorLike,
) -> tuple[jtp.Matrix, jtp.Array]:
    r"""
    Compute the doubly-left full free-floating Jacobian of a model.

    The full Jacobian is a 6x(6+n) matrix with all the columns filled.
    It is useful to run the algorithm once, and then extract the link Jacobian by
    filtering the columns of the full Jacobian using the support body array
    :math:`\kappa(i)` of the link.

    Args:
        model: The model to consider.
        joint_positions: The positions of the joints.

    Returns:
        A tuple containing the doubly-left full free-floating Jacobian of the model
        and the SE(3) transforms :math:`{}^B H_L` of all the links w.r.t. the base
        frame, stacked along the first axis and indexed by link index.
    """

    _, _, s, _, _, _, _, _, _, _ = utils.process_inputs(
        model=model, joint_positions=joint_positions
    )

    # Get the parent array λ(i).
    # Note: λ(0) must not be used, it's initialized to -1.
    λ = model.kin_dyn_parameters.parent_array

    # Compute the parent-to-child adjoints and the motion subspaces of the joints.
    # With an identity base transform, everything below is relative to the base.
    i_X_λi, S = model.kin_dyn_parameters.joint_transforms_and_motion_subspaces(
        joint_positions=s, base_transform=jnp.eye(4)
    )

    # Allocate the buffer of transforms base -> link.
    B_X_i = jnp.zeros(shape=(model.number_of_links(), 6, 6))
    B_X_i = B_X_i.at[0].set(jnp.eye(6))

    # =============================
    # Compute doubly-left Jacobian
    # =============================

    # Allocate the Jacobian matrix.
    # The Jbb section of the doubly-left Jacobian is an identity matrix.
    J = jnp.zeros(shape=(6, 6 + model.dofs()))
    J = J.at[0:6, 0:6].set(jnp.eye(6))

    ComputeFullJacobianCarry = tuple[jtp.MatrixJax, jtp.MatrixJax]
    compute_full_jacobian_carry: ComputeFullJacobianCarry = (B_X_i, J)

    def compute_full_jacobian(
        carry: ComputeFullJacobianCarry, i: jtp.Int
    ) -> tuple[ComputeFullJacobianCarry, None]:

        # The joint associated to link i has index i - 1 (the base has no joint).
        ii = i - 1
        B_X_i, J = carry

        # Compute the base (0) to link (i) adjoint matrix.
        # The parent entry is already filled because links are visited in index order.
        B_Xi_i = B_X_i[λ[i]] @ Adjoint.inverse(i_X_λi[i])
        B_X_i = B_X_i.at[i].set(B_Xi_i)

        # Compute the ii-th column of the B_S_BL(s) matrix.
        B_Sii_BL = B_Xi_i @ S[i]
        J = J.at[0:6, 6 + ii].set(B_Sii_BL.squeeze())

        return (B_X_i, J), None

    (B_X_i, J), _ = jax.lax.scan(
        f=compute_full_jacobian,
        init=compute_full_jacobian_carry,
        xs=np.arange(start=1, stop=model.number_of_links()),
    )

    # Convert adjoints to SE(3) transforms.
    # Returning them here prevents calling FK in case the output representation
    # of the Jacobian needs to be changed.
    B_H_L = jax.vmap(Adjoint.to_transform)(B_X_i)

    # Adjust shape of doubly-left free-floating full Jacobian.
    B_J_full_WL_B = J.squeeze().astype(float)

    return B_J_full_WL_B, B_H_L
jaxsim/rbda/rnea.py ADDED
@@ -0,0 +1,237 @@
1
+ from typing import Tuple
2
+
3
+ import jax
4
+ import jax.numpy as jnp
5
+ import jaxlie
6
+
7
+ import jaxsim.api as js
8
+ import jaxsim.typing as jtp
9
+ from jaxsim.math import Adjoint, Cross, Quaternion, StandardGravity
10
+
11
+ from . import utils
12
+
13
+
14
def rnea(
    model: js.model.JaxSimModel,
    *,
    base_position: jtp.Vector,
    base_quaternion: jtp.Vector,
    joint_positions: jtp.Vector,
    base_linear_velocity: jtp.Vector,
    base_angular_velocity: jtp.Vector,
    joint_velocities: jtp.Vector,
    base_linear_acceleration: jtp.Vector | None = None,
    base_angular_acceleration: jtp.Vector | None = None,
    joint_accelerations: jtp.Vector | None = None,
    link_forces: jtp.Matrix | None = None,
    standard_gravity: jtp.FloatLike = StandardGravity,
) -> tuple[jtp.Vector, jtp.Vector]:
    """
    Compute inverse dynamics using the Recursive Newton-Euler Algorithm (RNEA).

    Args:
        model: The model to consider.
        base_position: The position of the base link.
        base_quaternion: The quaternion of the base link.
        joint_positions: The positions of the joints.
        base_linear_velocity:
            The linear velocity of the base link in inertial-fixed representation.
        base_angular_velocity:
            The angular velocity of the base link in inertial-fixed representation.
        joint_velocities: The velocities of the joints.
        base_linear_acceleration:
            The linear acceleration of the base link in inertial-fixed representation.
        base_angular_acceleration:
            The angular acceleration of the base link in inertial-fixed representation.
        joint_accelerations: The accelerations of the joints.
        link_forces:
            The forces applied to the links expressed in the world frame.
        standard_gravity: The standard gravity constant.

    Returns:
        A tuple containing the 6D force applied to the base link expressed in the
        world frame and the joint forces that, when applied respectively to the base
        link and joints, produce the given base and joint accelerations.
        For fixed-base models, the returned base force is zero.
    """

    W_p_B, W_Q_B, s, W_v_WB, ṡ, W_v̇_WB, s̈, _, W_f, W_g = utils.process_inputs(
        model=model,
        base_position=base_position,
        base_quaternion=base_quaternion,
        joint_positions=joint_positions,
        base_linear_velocity=base_linear_velocity,
        base_angular_velocity=base_angular_velocity,
        joint_velocities=joint_velocities,
        base_linear_acceleration=base_linear_acceleration,
        base_angular_acceleration=base_angular_acceleration,
        joint_accelerations=joint_accelerations,
        link_forces=link_forces,
        standard_gravity=standard_gravity,
    )

    # Convert the 6D vectors to column vectors.
    W_g = jnp.atleast_2d(W_g).T
    W_v_WB = jnp.atleast_2d(W_v_WB).T
    W_v̇_WB = jnp.atleast_2d(W_v̇_WB).T

    # Get the 6D spatial inertia matrices of all links.
    M = js.model.link_spatial_inertia_matrices(model=model)

    # Get the parent array λ(i).
    # Note: λ(0) must not be used, it's initialized to -1.
    λ = model.kin_dyn_parameters.parent_array

    # Compute the base transform.
    W_H_B = jaxlie.SE3.from_rotation_and_translation(
        rotation=jaxlie.SO3.from_quaternion_xyzw(xyzw=Quaternion.to_xyzw(wxyz=W_Q_B)),
        translation=W_p_B,
    )

    # Compute 6D transforms of the base velocity.
    W_X_B = W_H_B.adjoint()
    B_X_W = W_H_B.inverse().adjoint()

    # Compute the parent-to-child adjoints and the motion subspaces of the joints.
    # These transforms define the relative kinematics of the entire model, including
    # the base transform for both floating-base and fixed-base models.
    i_X_λi, S = model.kin_dyn_parameters.joint_transforms_and_motion_subspaces(
        joint_positions=s, base_transform=W_H_B.as_matrix()
    )

    # Allocate buffers of link velocities, accelerations, and forces.
    v = jnp.zeros(shape=(model.number_of_links(), 6, 1))
    a = jnp.zeros(shape=(model.number_of_links(), 6, 1))
    f = jnp.zeros(shape=(model.number_of_links(), 6, 1))

    # Allocate the buffer of transforms link -> base.
    i_X_0 = jnp.zeros(shape=(model.number_of_links(), 6, 6))
    i_X_0 = i_X_0.at[0].set(jnp.eye(6))

    if model.floating_base():

        # Base velocity v₀ in body-fixed representation.
        v_0 = B_X_W @ W_v_WB
        v = v.at[0].set(v_0)

        # Base acceleration a₀ in body-fixed representation.
        # Gravity is included by subtracting it from the base acceleration, so that
        # its effect propagates to all the links through the forward pass.
        a_0 = B_X_W @ (W_v̇_WB - W_g)
        a = a.at[0].set(a_0)

        # Force acting on the base link: inertial and bias terms minus the external
        # force, the latter converted from world to body-fixed representation.
        f_0 = (
            M[0] @ a[0]
            + Cross.vx_star(v[0]) @ M[0] @ v[0]
            - W_X_B.T @ jnp.vstack(W_f[0])
        )
        f = f.at[0].set(f_0)

    else:

        # The base of a fixed-base model does not move (v₀ stays zero): its only
        # acceleration is the opposite of gravity, used to account for it.
        a_0 = -B_X_W @ W_g
        a = a.at[0].set(a_0)

    # ======
    # Pass 1
    # ======

    ForwardPassCarry = tuple[jtp.MatrixJax, jtp.MatrixJax, jtp.MatrixJax, jtp.MatrixJax]
    forward_pass_carry: ForwardPassCarry = (v, a, i_X_0, f)

    def forward_pass(
        carry: ForwardPassCarry, i: jtp.Int
    ) -> tuple[ForwardPassCarry, None]:

        # The joint associated to link i has index i - 1 (the base has no joint).
        ii = i - 1
        v, a, i_X_0, f = carry

        # Project the joint velocity into its motion subspace.
        vJ = S[i] * ṡ[ii]

        # Propagate the link velocity.
        v_i = i_X_λi[i] @ v[λ[i]] + vJ
        v = v.at[i].set(v_i)

        # Propagate the link acceleration.
        a_i = i_X_λi[i] @ a[λ[i]] + S[i] * s̈[ii] + Cross.vx(v[i]) @ vJ
        a = a.at[i].set(a_i)

        # Compute the link-to-base transform.
        i_X_0_i = i_X_λi[i] @ i_X_0[λ[i]]
        i_X_0 = i_X_0.at[i].set(i_X_0_i)

        # Compute link-to-world transform for the 6D force.
        i_Xf_W = Adjoint.inverse(i_X_0[i] @ B_X_W).T

        # Compute the force acting on the link.
        f_i = (
            M[i] @ a[i]
            + Cross.vx_star(v[i]) @ M[i] @ v[i]
            - i_Xf_W @ jnp.vstack(W_f[i])
        )
        f = f.at[i].set(f_i)

        return (v, a, i_X_0, f), None

    # Models with a single link have nothing to propagate.
    (v, a, i_X_0, f), _ = (
        jax.lax.scan(
            f=forward_pass,
            init=forward_pass_carry,
            xs=jnp.arange(start=1, stop=model.number_of_links()),
        )
        if model.number_of_links() > 1
        else [(v, a, i_X_0, f), None]
    )

    # ======
    # Pass 2
    # ======

    τ = jnp.zeros_like(s)

    BackwardPassCarry = tuple[jtp.MatrixJax, jtp.MatrixJax]
    backward_pass_carry: BackwardPassCarry = (τ, f)

    def backward_pass(
        carry: BackwardPassCarry, i: jtp.Int
    ) -> tuple[BackwardPassCarry, None]:

        ii = i - 1
        τ, f = carry

        # Project the 6D force to the DoF of the joint.
        τ_i = S[i].T @ f[i]
        τ = τ.at[ii].set(τ_i.squeeze())

        # Propagate the force to the parent link.
        def update_f(f: jtp.MatrixJax) -> jtp.MatrixJax:

            f_λi = f[λ[i]] + i_X_λi[i].T @ f[i]
            f = f.at[λ[i]].set(f_λi)

            return f

        # For fixed-base models, the force is not propagated to the base link.
        f = jax.lax.cond(
            pred=jnp.logical_or(λ[i] != 0, model.floating_base()),
            true_fun=update_f,
            false_fun=lambda f: f,
            operand=f,
        )

        return (τ, f), None

    # Links are visited in reverse index order, from the leaves to the base.
    (τ, f), _ = (
        jax.lax.scan(
            f=backward_pass,
            init=backward_pass_carry,
            xs=jnp.flip(jnp.arange(start=1, stop=model.number_of_links())),
        )
        if model.number_of_links() > 1
        else [(τ, f), None]
    )

    # ==============
    # Adjust outputs
    # ==============

    # Express the base 6D force in the world frame.
    W_f0 = B_X_W.T @ f[0]

    return W_f0.squeeze(), jnp.atleast_1d(τ.squeeze())