jaxsim 0.2.dev191__py3-none-any.whl → 0.2.dev366__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (79)
  1. jaxsim/__init__.py +3 -4
  2. jaxsim/_version.py +2 -2
  3. jaxsim/api/__init__.py +3 -1
  4. jaxsim/api/com.py +240 -0
  5. jaxsim/api/common.py +13 -2
  6. jaxsim/api/contact.py +120 -43
  7. jaxsim/api/data.py +112 -71
  8. jaxsim/api/joint.py +77 -36
  9. jaxsim/api/kin_dyn_parameters.py +777 -0
  10. jaxsim/api/link.py +150 -75
  11. jaxsim/api/model.py +542 -269
  12. jaxsim/api/ode.py +86 -74
  13. jaxsim/api/ode_data.py +694 -0
  14. jaxsim/api/references.py +12 -11
  15. jaxsim/integrators/__init__.py +2 -2
  16. jaxsim/integrators/common.py +110 -24
  17. jaxsim/integrators/fixed_step.py +11 -67
  18. jaxsim/integrators/variable_step.py +610 -0
  19. jaxsim/math/__init__.py +11 -0
  20. jaxsim/math/adjoint.py +24 -2
  21. jaxsim/math/joint_model.py +335 -0
  22. jaxsim/math/quaternion.py +44 -3
  23. jaxsim/math/rotation.py +4 -4
  24. jaxsim/math/transform.py +93 -0
  25. jaxsim/parsers/descriptions/link.py +2 -2
  26. jaxsim/parsers/rod/utils.py +7 -8
  27. jaxsim/rbda/__init__.py +7 -0
  28. jaxsim/rbda/aba.py +295 -0
  29. jaxsim/rbda/collidable_points.py +142 -0
  30. jaxsim/{physics/algos → rbda}/crba.py +43 -42
  31. jaxsim/rbda/forward_kinematics.py +113 -0
  32. jaxsim/rbda/jacobian.py +201 -0
  33. jaxsim/rbda/rnea.py +237 -0
  34. jaxsim/rbda/soft_contacts.py +296 -0
  35. jaxsim/rbda/utils.py +152 -0
  36. jaxsim/terrain/__init__.py +2 -0
  37. jaxsim/utils/__init__.py +1 -4
  38. jaxsim/utils/hashless.py +18 -0
  39. jaxsim/utils/jaxsim_dataclass.py +281 -30
  40. {jaxsim-0.2.dev191.dist-info → jaxsim-0.2.dev366.dist-info}/METADATA +4 -6
  41. jaxsim-0.2.dev366.dist-info/RECORD +64 -0
  42. {jaxsim-0.2.dev191.dist-info → jaxsim-0.2.dev366.dist-info}/WHEEL +1 -1
  43. jaxsim/high_level/__init__.py +0 -2
  44. jaxsim/high_level/common.py +0 -11
  45. jaxsim/high_level/joint.py +0 -148
  46. jaxsim/high_level/link.py +0 -259
  47. jaxsim/high_level/model.py +0 -1686
  48. jaxsim/math/conv.py +0 -114
  49. jaxsim/math/joint.py +0 -102
  50. jaxsim/math/plucker.py +0 -100
  51. jaxsim/physics/__init__.py +0 -12
  52. jaxsim/physics/algos/__init__.py +0 -0
  53. jaxsim/physics/algos/aba.py +0 -254
  54. jaxsim/physics/algos/aba_motors.py +0 -284
  55. jaxsim/physics/algos/forward_kinematics.py +0 -79
  56. jaxsim/physics/algos/jacobian.py +0 -98
  57. jaxsim/physics/algos/rnea.py +0 -180
  58. jaxsim/physics/algos/rnea_motors.py +0 -196
  59. jaxsim/physics/algos/soft_contacts.py +0 -523
  60. jaxsim/physics/algos/utils.py +0 -69
  61. jaxsim/physics/model/__init__.py +0 -0
  62. jaxsim/physics/model/ground_contact.py +0 -53
  63. jaxsim/physics/model/physics_model.py +0 -388
  64. jaxsim/physics/model/physics_model_state.py +0 -283
  65. jaxsim/simulation/__init__.py +0 -4
  66. jaxsim/simulation/integrators.py +0 -393
  67. jaxsim/simulation/ode.py +0 -290
  68. jaxsim/simulation/ode_data.py +0 -96
  69. jaxsim/simulation/ode_integration.py +0 -62
  70. jaxsim/simulation/simulator.py +0 -543
  71. jaxsim/simulation/simulator_callbacks.py +0 -79
  72. jaxsim/simulation/utils.py +0 -15
  73. jaxsim/sixd/__init__.py +0 -2
  74. jaxsim/utils/oop.py +0 -536
  75. jaxsim/utils/vmappable.py +0 -117
  76. jaxsim-0.2.dev191.dist-info/RECORD +0 -81
  77. /jaxsim/{physics/algos → terrain}/terrain.py +0 -0
  78. {jaxsim-0.2.dev191.dist-info → jaxsim-0.2.dev366.dist-info}/LICENSE +0 -0
  79. {jaxsim-0.2.dev191.dist-info → jaxsim-0.2.dev366.dist-info}/top_level.txt +0 -0
jaxsim/rbda/aba.py ADDED
@@ -0,0 +1,295 @@
1
+ import jax
2
+ import jax.numpy as jnp
3
+ import jaxlie
4
+
5
+ import jaxsim.api as js
6
+ import jaxsim.typing as jtp
7
+ from jaxsim.math import Adjoint, Cross, Quaternion, StandardGravity
8
+
9
+ from . import utils
10
+
11
+
12
def aba(
    model: js.model.JaxSimModel,
    *,
    base_position: jtp.VectorLike,
    base_quaternion: jtp.VectorLike,
    joint_positions: jtp.VectorLike,
    base_linear_velocity: jtp.VectorLike,
    base_angular_velocity: jtp.VectorLike,
    joint_velocities: jtp.VectorLike,
    joint_forces: jtp.VectorLike | None = None,
    link_forces: jtp.MatrixLike | None = None,
    standard_gravity: jtp.FloatLike = StandardGravity,
) -> tuple[jtp.Vector, jtp.Vector]:
    """
    Solve the forward dynamics with the Articulated Body Algorithm (ABA).

    The computation is organized in three sweeps over the kinematic tree:
    an outward sweep that propagates link velocities and initializes the
    articulated-body quantities, an inward sweep that accumulates the
    articulated-body inertias and bias forces into the parent links, and a
    final outward sweep that resolves link and joint accelerations.

    Args:
        model: The model to consider.
        base_position: The position of the base link.
        base_quaternion: The quaternion of the base link.
        joint_positions: The positions of the joints.
        base_linear_velocity:
            The linear velocity of the base link in inertial-fixed representation.
        base_angular_velocity:
            The angular velocity of the base link in inertial-fixed representation.
        joint_velocities: The velocities of the joints.
        joint_forces: The forces applied to the joints.
        link_forces:
            The forces applied to the links expressed in the world frame.
        standard_gravity: The standard gravity constant.

    Returns:
        A tuple containing the 6D base acceleration in inertial-fixed
        representation (all zeros for fixed-base models) and the joint
        accelerations resulting from the application of the given joint and
        link forces.

    Note:
        The algorithm expects a quaternion with unit norm.
    """

    W_p_B, W_Q_B, s, W_v_WB, ṡ, _, _, τ, W_f, W_g = utils.process_inputs(
        model=model,
        base_position=base_position,
        base_quaternion=base_quaternion,
        joint_positions=joint_positions,
        base_linear_velocity=base_linear_velocity,
        base_angular_velocity=base_angular_velocity,
        joint_velocities=joint_velocities,
        base_linear_acceleration=None,
        base_angular_acceleration=None,
        joint_accelerations=None,
        joint_forces=joint_forces,
        link_forces=link_forces,
        standard_gravity=standard_gravity,
    )

    # Use column vectors (shape (6, 1)) for gravity and base velocity so that
    # they compose with the 6x6 transforms below.
    W_g = jnp.atleast_2d(W_g).T
    W_v_WB = jnp.atleast_2d(W_v_WB).T

    n_links = model.number_of_links()
    is_floating = model.floating_base()

    # 6D spatial inertia matrices of all links.
    M = js.model.link_spatial_inertia_matrices(model=model)

    # Parent array λ(i). Note: λ(0) is initialized to -1 and must not be used.
    λ = model.kin_dyn_parameters.parent_array

    # Base pose as an SE(3) transform, and its velocity adjoints.
    W_H_B = jaxlie.SE3.from_rotation_and_translation(
        rotation=jaxlie.SO3.from_quaternion_xyzw(xyzw=Quaternion.to_xyzw(wxyz=W_Q_B)),
        translation=W_p_B,
    )
    W_X_B = W_H_B.adjoint()
    B_X_W = W_H_B.inverse().adjoint()

    # Parent-to-child adjoints and joint motion subspaces. The base transform is
    # included, so these describe the whole model (floating- or fixed-base).
    i_X_λi, S = model.kin_dyn_parameters.joint_transforms_and_motion_subspaces(
        joint_positions=s, base_transform=W_H_B.as_matrix()
    )

    def scan_links(body, carry, *, inward: bool = False):
        # Apply `body` over the non-base links 1..n-1 (reversed when sweeping
        # inward). Models made only of the base link skip the scan entirely.
        if n_links <= 1:
            return carry
        link_indices = jnp.arange(start=1, stop=n_links)
        if inward:
            link_indices = jnp.flip(link_indices)
        carry, _ = jax.lax.scan(f=body, init=carry, xs=link_indices)
        return carry

    # Per-link buffers: velocities, velocity-product terms, AB bias forces,
    # AB inertias, and link <- base 6D transforms.
    v = jnp.zeros(shape=(n_links, 6, 1))
    c = jnp.zeros(shape=(n_links, 6, 1))
    pA = jnp.zeros(shape=(n_links, 6, 1))
    MA = jnp.zeros(shape=(n_links, 6, 6))
    i_X_0 = jnp.zeros(shape=(n_links, 6, 6)).at[0].set(jnp.eye(6))

    # Only a floating base has its own velocity, inertia, and bias force.
    # For a fixed base these entries stay zero.
    if is_floating:
        # Base velocity in body-fixed representation.
        v = v.at[0].set(B_X_W @ W_v_WB)
        MA = MA.at[0].set(M[0])
        pA = pA.at[0].set(
            Cross.vx_star(v[0]) @ MA[0] @ v[0] - W_X_B.T @ jnp.vstack(W_f[0])
        )

    # ==========================================
    # Sweep 1 (outward): kinematics and AB init
    # ==========================================

    OutwardCarry = tuple[
        jtp.MatrixJax, jtp.MatrixJax, jtp.MatrixJax, jtp.MatrixJax, jtp.MatrixJax
    ]

    def outward_kinematics(
        carry: OutwardCarry, i: jtp.Int
    ) -> tuple[OutwardCarry, None]:
        v, c, MA, pA, i_X_0 = carry
        joint_idx = i - 1

        # Joint velocity mapped through its motion subspace.
        vJ = S[i] * ṡ[joint_idx]

        # Link velocity and velocity-product (bias acceleration) term.
        v = v.at[i].set(i_X_λi[i] @ v[λ[i]] + vJ)
        c = c.at[i].set(Cross.vx(v[i]) @ vJ)

        # The articulated-body inertia starts from the rigid-body inertia.
        MA = MA.at[i].set(jnp.array(M[i]))

        # Link <- base transform.
        i_X_0 = i_X_0.at[i].set(i_X_λi[i] @ i_X_0[λ[i]])

        # Force transform world -> link, used to bring the world-frame
        # external force into the link frame.
        i_Xf_W = Adjoint.inverse(i_X_0[i] @ B_X_W).T

        pA = pA.at[i].set(
            Cross.vx_star(v[i]) @ M[i] @ v[i] - i_Xf_W @ jnp.vstack(W_f[i])
        )

        return (v, c, MA, pA, i_X_0), None

    v, c, MA, pA, i_X_0 = scan_links(outward_kinematics, (v, c, MA, pA, i_X_0))

    # ==========================================
    # Sweep 2 (inward): AB inertia and bias
    # ==========================================

    U = jnp.zeros_like(S)
    d = jnp.zeros(shape=(n_links, 1))
    u = jnp.zeros(shape=(n_links, 1))

    InwardCarry = tuple[
        jtp.MatrixJax, jtp.MatrixJax, jtp.MatrixJax, jtp.MatrixJax, jtp.MatrixJax
    ]

    def inward_articulation(
        carry: InwardCarry, i: jtp.Int
    ) -> tuple[InwardCarry, None]:
        U, d, u, MA, pA = carry
        joint_idx = i - 1

        U = U.at[i].set(MA[i] @ S[i])
        d = d.at[i].set((S[i].T @ U[i]).squeeze())
        u = u.at[i].set((τ[joint_idx] - S[i].T @ pA[i]).squeeze())

        # Articulated-body inertia and bias force transmitted by this link.
        Ma = MA[i] - U[i] / d[i] @ U[i].T
        pa = pA[i] + Ma @ c[i] + U[i] * (u[i] / d[i])

        def accumulate_into_parent(
            MA_pA: tuple[jtp.MatrixJax, jtp.MatrixJax]
        ) -> tuple[jtp.MatrixJax, jtp.MatrixJax]:
            MA, pA = MA_pA
            parent = λ[i]
            MA = MA.at[parent].set(MA[parent] + i_X_λi[i].T @ Ma @ i_X_λi[i])
            pA = pA.at[parent].set(pA[parent] + i_X_λi[i].T @ pa)
            return MA, pA

        # Children of a fixed base do not contribute to the base quantities.
        MA, pA = jax.lax.cond(
            pred=jnp.logical_or(λ[i] != 0, is_floating),
            true_fun=accumulate_into_parent,
            false_fun=lambda MA_pA: MA_pA,
            operand=(MA, pA),
        )

        return (U, d, u, MA, pA), None

    U, d, u, MA, pA = scan_links(
        inward_articulation, (U, d, u, MA, pA), inward=True
    )

    # ==========================================
    # Sweep 3 (outward): accelerations
    # ==========================================

    # Base acceleration: solved from the AB quantities for a floating base,
    # otherwise set to the fictitious acceleration that models gravity.
    if is_floating:
        a0 = jnp.linalg.solve(-MA[0], pA[0])
    else:
        a0 = -B_X_W @ W_g

    s̈ = jnp.zeros_like(s)
    a = jnp.zeros_like(v).at[0].set(a0)

    AccelCarry = tuple[jtp.MatrixJax, jtp.VectorJax]

    def outward_accelerations(
        carry: AccelCarry, i: jtp.Int
    ) -> tuple[AccelCarry, None]:
        a, s̈ = carry
        joint_idx = i - 1

        # Acceleration inherited from the parent plus the bias term.
        a_i = i_X_λi[i] @ a[λ[i]] + c[i]

        s̈ = s̈.at[joint_idx].set(((u[i] - U[i].T @ a_i) / d[i]).squeeze())

        # Add the contribution of this link's joint acceleration.
        a = a.at[i].set(a_i + S[i] * s̈[joint_idx])

        return (a, s̈), None

    a, s̈ = scan_links(outward_accelerations, (a, s̈))

    # ==============
    # Adjust outputs
    # ==============

    # TODO: remove vstack and shape=(6, 1)?
    if is_floating:
        # Back to inertial-fixed representation, adding gravity.
        W_a_WB = W_X_B @ a[0] + W_g
    else:
        W_a_WB = jnp.zeros(6)

    return W_a_WB.squeeze(), jnp.atleast_1d(s̈.squeeze())
jaxsim/rbda/collidable_points.py ADDED
@@ -0,0 +1,142 @@
1
+ import jax
2
+ import jax.numpy as jnp
3
+ import jaxlie
4
+
5
+ import jaxsim.api as js
6
+ import jaxsim.typing as jtp
7
+ from jaxsim.math import Adjoint, Quaternion, Skew
8
+
9
+ from . import utils
10
+
11
+
12
def collidable_points_pos_vel(
    model: js.model.JaxSimModel,
    *,
    base_position: jtp.Vector,
    base_quaternion: jtp.Vector,
    joint_positions: jtp.Vector,
    base_linear_velocity: jtp.Vector,
    base_angular_velocity: jtp.Vector,
    joint_velocities: jtp.Vector,
) -> tuple[jtp.Matrix, jtp.Matrix]:
    """
    Compute the position and linear velocity of collidable points in the world frame.

    Args:
        model: The model to consider.
        base_position: The position of the base link.
        base_quaternion: The quaternion of the base link.
        joint_positions: The positions of the joints.
        base_linear_velocity:
            The linear velocity of the base link in inertial-fixed representation.
        base_angular_velocity:
            The angular velocity of the base link in inertial-fixed representation.
        joint_velocities: The velocities of the joints.

    Returns:
        A tuple containing the positions of the collidable points and the
        linear part of their mixed velocities, one row per point, each of
        shape ``(n_points, 3)``. Models without collidable points produce two
        empty matrices of shape ``(0, 3)``.
    """

    # Nothing to compute without collidable points. Return empty matrices with
    # the same column layout as the general case, so both outputs are
    # consistent. The previous behavior returned a scalar position and a
    # (0,)-shaped velocity.
    if len(model.kin_dyn_parameters.contact_parameters.body) == 0:
        empty = jnp.zeros(shape=(0, 3), dtype=float)
        return empty, empty

    W_p_B, W_Q_B, s, W_v_WB, ṡ, _, _, _, _, _ = utils.process_inputs(
        model=model,
        base_position=base_position,
        base_quaternion=base_quaternion,
        joint_positions=joint_positions,
        base_linear_velocity=base_linear_velocity,
        base_angular_velocity=base_angular_velocity,
        joint_velocities=joint_velocities,
    )

    n_links = model.number_of_links()

    # Get the parent array λ(i).
    # Note: λ(0) must not be used, it's initialized to -1.
    λ = model.kin_dyn_parameters.parent_array

    # Compute the base transform.
    W_H_B = jaxlie.SE3.from_rotation_and_translation(
        rotation=jaxlie.SO3.from_quaternion_xyzw(xyzw=Quaternion.to_xyzw(wxyz=W_Q_B)),
        translation=W_p_B,
    )

    # Compute the parent-to-child adjoints and the motion subspaces of the joints.
    # These transforms define the relative kinematics of the entire model, including
    # the base transform for both floating-base and fixed-base models.
    i_X_λi, S = model.kin_dyn_parameters.joint_transforms_and_motion_subspaces(
        joint_positions=s, base_transform=W_H_B.as_matrix()
    )

    # Allocate buffer of transforms world -> link and initialize the base pose.
    # i_X_λi[0] is the base <- world adjoint, so its inverse is world <- base.
    W_X_i = jnp.zeros(shape=(n_links, 6, 6))
    W_X_i = W_X_i.at[0].set(Adjoint.inverse(i_X_λi[0]))

    # Allocate buffer of 6D inertial-fixed velocities and initialize the base velocity.
    W_v_Wi = jnp.zeros(shape=(n_links, 6))
    W_v_Wi = W_v_Wi.at[0].set(W_v_WB)

    # ====================
    # Propagate kinematics
    # ====================

    PropagateTransformsCarry = tuple[jtp.MatrixJax, jtp.Matrix]
    propagate_transforms_carry: PropagateTransformsCarry = (W_X_i, W_v_Wi)

    def propagate_kinematics(
        carry: PropagateTransformsCarry, i: jtp.Int
    ) -> tuple[PropagateTransformsCarry, None]:

        ii = i - 1
        W_X_i, W_v_Wi = carry

        # Compute the parent to child 6D transform.
        λi_X_i = Adjoint.inverse(adjoint=i_X_λi[i])

        # Compute the world to child 6D transform.
        W_Xi_i = W_X_i[λ[i]] @ λi_X_i
        W_X_i = W_X_i.at[i].set(W_Xi_i)

        # Propagate the 6D velocity: in inertial-fixed representation, the child
        # velocity is the parent velocity plus the joint velocity expressed in W.
        W_vi_Wi = W_v_Wi[λ[i]] + W_X_i[i] @ (S[i] * ṡ[ii]).squeeze()
        W_v_Wi = W_v_Wi.at[i].set(W_vi_Wi)

        return (W_X_i, W_v_Wi), None

    (W_X_i, W_v_Wi), _ = (
        jax.lax.scan(
            f=propagate_kinematics,
            init=propagate_transforms_carry,
            xs=jnp.arange(start=1, stop=n_links),
        )
        if n_links > 1
        else [(W_X_i, W_v_Wi), None]
    )

    # ==================================================
    # Compute position and velocity of collidable points
    # ==================================================

    def process_point_kinematics(
        Li_p_C: jtp.VectorJax, parent_body: jtp.Int
    ) -> tuple[jtp.VectorJax, jtp.VectorJax]:
        # Position of the collidable point: apply the world <- link homogeneous
        # transform to the point's homogeneous coordinates.
        W_p_Ci = (
            Adjoint.to_transform(adjoint=W_X_i[parent_body]) @ jnp.hstack([Li_p_C, 1])
        )[0:3]

        # Linear part of the mixed velocity Ci[W]_v_{W,Ci}:
        # ṗ_C = v_lin + ω × p_C = [I₃, -p_C^] @ W_v_Wi.
        CW_vl_WCi = (
            jnp.block([jnp.eye(3), -Skew.wedge(vector=W_p_Ci).squeeze()])
            @ W_v_Wi[parent_body].squeeze()
        )

        return W_p_Ci, CW_vl_WCi

    # Process all the collidable points in parallel.
    W_p_Ci, CW_vl_WC = jax.vmap(process_point_kinematics)(
        model.kin_dyn_parameters.contact_parameters.point,
        jnp.array(model.kin_dyn_parameters.contact_parameters.body),
    )

    return W_p_Ci, CW_vl_WC
jaxsim/{physics/algos → rbda}/crba.py RENAMED
@@ -1,68 +1,68 @@
1
- from typing import Tuple
2
-
3
1
  import jax
4
2
  import jax.numpy as jnp
5
- import numpy as np
6
3
 
4
+ import jaxsim.api as js
7
5
  import jaxsim.typing as jtp
8
- from jaxsim.physics.model.physics_model import PhysicsModel
9
6
 
10
7
  from . import utils
11
8
 
12
9
 
13
- def crba(model: PhysicsModel, q: jtp.Vector) -> jtp.Matrix:
10
+ def crba(model: js.model.JaxSimModel, *, joint_positions: jtp.Vector) -> jtp.Matrix:
14
11
  """
15
- Compute the Composite Rigid-Body Inertia Matrix (CRBA) for an articulated body or robot given joint positions.
12
+ Compute the free-floating mass matrix using the Composite Rigid-Body Algorithm (CRBA).
16
13
 
17
14
  Args:
18
- model (PhysicsModel): The physics model of the articulated body or robot.
19
- q (jtp.Vector): Joint positions (Generalized coordinates).
15
+ model: The model to consider.
16
+ joint_positions: The positions of the joints.
20
17
 
21
18
  Returns:
22
- jtp.Matrix: The Composite Rigid-Body Inertia Matrix (CRBA) of the articulated body or robot.
19
+ The free-floating mass matrix of the model in body-fixed representation.
23
20
  """
24
21
 
25
- _, q, _, _, _, _ = utils.process_inputs(
26
- physics_model=model, xfb=None, q=q, qd=None, tau=None, f_ext=None
22
+ _, _, s, _, _, _, _, _, _, _ = utils.process_inputs(
23
+ model=model, joint_positions=joint_positions
27
24
  )
28
25
 
29
- Xtree = model.tree_transforms
30
- Mc = model.spatial_inertias
31
- S = model.motion_subspaces(q=q)
32
- Xj = model.joint_transforms(q=q)
26
+ # Get the 6D spatial inertia matrices of all links.
27
+ Mc = js.model.link_spatial_inertia_matrices(model=model)
33
28
 
34
- Xup = jnp.zeros_like(Xtree)
35
- i_X_0 = jnp.zeros_like(Xtree)
36
- i_X_0 = i_X_0.at[0].set(jnp.eye(6))
29
+ # Get the parent array λ(i).
30
+ # Note: λ(0) must not be used, it's initialized to -1.
31
+ λ = model.kin_dyn_parameters.parent_array
32
+
33
+ # Compute the parent-to-child adjoints and the motion subspaces of the joints.
34
+ # These transforms define the relative kinematics of the entire model, including
35
+ # the base transform for both floating-base and fixed-base models.
36
+ i_X_λi, S = model.kin_dyn_parameters.joint_transforms_and_motion_subspaces(
37
+ joint_positions=s, base_transform=jnp.eye(4)
38
+ )
37
39
 
38
- # Parent array mapping: i -> λ(i).
39
- # Exception: λ(0) must not be used, it's initialized to -1.
40
- λ = model.parent
40
+ # Allocate the buffer of transforms link -> base.
41
+ i_X_0 = jnp.zeros(shape=(model.number_of_links(), 6, 6))
42
+ i_X_0 = i_X_0.at[0].set(jnp.eye(6))
41
43
 
42
44
  # ====================
43
45
  # Propagate kinematics
44
46
  # ====================
45
47
 
46
- ForwardPassCarry = Tuple[jtp.MatrixJax, jtp.MatrixJax]
47
- forward_pass_carry = (Xup, i_X_0)
48
+ ForwardPassCarry = tuple[jtp.MatrixJax]
49
+ forward_pass_carry: ForwardPassCarry = (i_X_0,)
48
50
 
49
51
  def propagate_kinematics(
50
52
  carry: ForwardPassCarry, i: jtp.Int
51
- ) -> Tuple[ForwardPassCarry, None]:
52
- Xup, i_X_0 = carry
53
+ ) -> tuple[ForwardPassCarry, None]:
53
54
 
54
- Xup_i = Xj[i] @ Xtree[i]
55
- Xup = Xup.at[i].set(Xup_i)
55
+ (i_X_0,) = carry
56
56
 
57
- i_X_0_i = Xup[i] @ i_X_0[λ[i]]
57
+ i_X_0_i = i_X_λi[i] @ i_X_0[λ[i]]
58
58
  i_X_0 = i_X_0.at[i].set(i_X_0_i)
59
59
 
60
- return (Xup, i_X_0), None
60
+ return (i_X_0,), None
61
61
 
62
- (Xup, i_X_0), _ = jax.lax.scan(
62
+ (i_X_0,), _ = jax.lax.scan(
63
63
  f=propagate_kinematics,
64
64
  init=forward_pass_carry,
65
- xs=np.arange(start=1, stop=model.NB),
65
+ xs=jnp.arange(start=1, stop=model.number_of_links()),
66
66
  )
67
67
 
68
68
  # ===================
@@ -71,16 +71,17 @@ def crba(model: PhysicsModel, q: jtp.Vector) -> jtp.Matrix:
71
71
 
72
72
  M = jnp.zeros(shape=(6 + model.dofs(), 6 + model.dofs()))
73
73
 
74
- BackwardPassCarry = Tuple[jtp.MatrixJax, jtp.MatrixJax]
75
- backward_pass_carry = (Mc, M)
74
+ BackwardPassCarry = tuple[jtp.MatrixJax, jtp.MatrixJax]
75
+ backward_pass_carry: BackwardPassCarry = (Mc, M)
76
76
 
77
77
  def backward_pass(
78
78
  carry: BackwardPassCarry, i: jtp.Int
79
- ) -> Tuple[BackwardPassCarry, None]:
79
+ ) -> tuple[BackwardPassCarry, None]:
80
+
80
81
  ii = i - 1
81
82
  Mc, M = carry
82
83
 
83
- Mc_λi = Mc[λ[i]] + Xup[i].T @ Mc[i] @ Xup[i]
84
+ Mc_λi = Mc[λ[i]] + i_X_λi[i].T @ Mc[i] @ i_X_λi[i]
84
85
  Mc = Mc.at[λ[i]].set(Mc_λi)
85
86
 
86
87
  Fi = Mc[i] @ S[i]
@@ -89,13 +90,13 @@ def crba(model: PhysicsModel, q: jtp.Vector) -> jtp.Matrix:
89
90
 
90
91
  j = i
91
92
 
92
- CarryInnerFn = Tuple[jtp.Int, jtp.MatrixJax, jtp.MatrixJax]
93
+ CarryInnerFn = tuple[jtp.Int, jtp.MatrixJax, jtp.MatrixJax]
93
94
  carry_inner_fn = (j, Fi, M)
94
95
 
95
96
  def while_loop_body(carry: CarryInnerFn) -> CarryInnerFn:
96
97
  j, Fi, M = carry
97
98
 
98
- Fi = Xup[j].T @ Fi
99
+ Fi = i_X_λi[j].T @ Fi
99
100
  j = λ[j]
100
101
  jj = j - 1
101
102
 
@@ -108,8 +109,8 @@ def crba(model: PhysicsModel, q: jtp.Vector) -> jtp.Matrix:
108
109
 
109
110
  # The following functions are part of a (rather messy) workaround for computing
110
111
  # a while loop using a for loop with fixed number of iterations.
111
- def inner_fn(carry: CarryInnerFn, k: jtp.Int) -> Tuple[CarryInnerFn, None]:
112
- def compute_inner(carry: CarryInnerFn) -> Tuple[CarryInnerFn, None]:
112
+ def inner_fn(carry: CarryInnerFn, k: jtp.Int) -> tuple[CarryInnerFn, None]:
113
+ def compute_inner(carry: CarryInnerFn) -> tuple[CarryInnerFn, None]:
113
114
  j, Fi, M = carry
114
115
  out = jax.lax.cond(
115
116
  pred=(λ[j] > 0),
@@ -130,7 +131,7 @@ def crba(model: PhysicsModel, q: jtp.Vector) -> jtp.Matrix:
130
131
  (j, Fi, M), _ = jax.lax.scan(
131
132
  f=inner_fn,
132
133
  init=carry_inner_fn,
133
- xs=np.flip(np.arange(start=1, stop=model.NB)),
134
+ xs=jnp.flip(jnp.arange(start=1, stop=model.number_of_links())),
134
135
  )
135
136
 
136
137
  Fi = i_X_0[j].T @ Fi
@@ -145,10 +146,10 @@ def crba(model: PhysicsModel, q: jtp.Vector) -> jtp.Matrix:
145
146
  (Mc, M), _ = jax.lax.scan(
146
147
  f=backward_pass,
147
148
  init=backward_pass_carry,
148
- xs=np.flip(np.arange(start=1, stop=model.NB)),
149
+ xs=jnp.flip(jnp.arange(start=1, stop=model.number_of_links())),
149
150
  )
150
151
 
151
- # Store the locked 6D rigid-body inertia matrix Mbb ∈ ℝ⁶ˣ⁶
152
+ # Store the locked 6D rigid-body inertia matrix Mbb ∈ ℝ⁶ˣ⁶.
152
153
  M = M.at[0:6, 0:6].set(Mc[0])
153
154
 
154
155
  return M