numkong 7.5.0 → 7.6.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (48)
  1. package/binding.gyp +18 -0
  2. package/c/dispatch_e5m2.c +23 -3
  3. package/include/numkong/capabilities.h +1 -1
  4. package/include/numkong/cast/README.md +3 -0
  5. package/include/numkong/cast/haswell.h +28 -64
  6. package/include/numkong/cast/serial.h +17 -0
  7. package/include/numkong/cast/skylake.h +67 -52
  8. package/include/numkong/cast.h +1 -0
  9. package/include/numkong/dot/README.md +1 -0
  10. package/include/numkong/dot/haswell.h +92 -13
  11. package/include/numkong/dot/serial.h +15 -0
  12. package/include/numkong/dot/skylake.h +61 -14
  13. package/include/numkong/dots/README.md +2 -0
  14. package/include/numkong/dots/graniteamx.h +434 -0
  15. package/include/numkong/dots/haswell.h +28 -28
  16. package/include/numkong/dots/sapphireamx.h +1 -1
  17. package/include/numkong/dots/serial.h +23 -8
  18. package/include/numkong/dots/skylake.h +28 -23
  19. package/include/numkong/dots.h +12 -0
  20. package/include/numkong/each/serial.h +18 -1
  21. package/include/numkong/geospatial/serial.h +14 -3
  22. package/include/numkong/maxsim/serial.h +15 -0
  23. package/include/numkong/mesh/README.md +50 -44
  24. package/include/numkong/mesh/genoa.h +462 -0
  25. package/include/numkong/mesh/haswell.h +806 -933
  26. package/include/numkong/mesh/neon.h +871 -943
  27. package/include/numkong/mesh/neonbfdot.h +382 -522
  28. package/include/numkong/mesh/neonfhm.h +676 -0
  29. package/include/numkong/mesh/rvv.h +404 -319
  30. package/include/numkong/mesh/serial.h +204 -162
  31. package/include/numkong/mesh/skylake.h +1029 -1585
  32. package/include/numkong/mesh/v128relaxed.h +403 -377
  33. package/include/numkong/mesh.h +38 -0
  34. package/include/numkong/reduce/serial.h +15 -1
  35. package/include/numkong/sparse/serial.h +17 -2
  36. package/include/numkong/spatial/genoa.h +0 -68
  37. package/include/numkong/spatial/haswell.h +98 -56
  38. package/include/numkong/spatial/serial.h +15 -0
  39. package/include/numkong/spatial/skylake.h +114 -54
  40. package/include/numkong/spatial.h +0 -12
  41. package/include/numkong/spatials/graniteamx.h +128 -0
  42. package/include/numkong/spatials/serial.h +18 -1
  43. package/include/numkong/spatials/skylake.h +2 -2
  44. package/include/numkong/spatials.h +17 -0
  45. package/include/numkong/tensor.hpp +107 -23
  46. package/javascript/numkong.c +3 -2
  47. package/package.json +7 -7
  48. package/wasm/numkong.wasm +0 -0
@@ -8,24 +8,28 @@
  *
  * @section mesh_neonbfdot_instructions ARM NEON BF16 Instructions (ARMv8.6-BF16)
  *
- * Intrinsic      Instruction                  A76       M5
- * vld3_u16       LD3 (V.4H x 3)               4cy @ 1p  4cy @ 1p
- * vshll_n_u16    USHLL (V.4S, V.4H, #16)      2cy @ 2p  2cy @ 4p
- * vfmaq_f32      FMLA (V.4S, V.4S, V.4S)      4cy @ 2p  3cy @ 4p
- * vaddq_f32      FADD (V.4S, V.4S, V.4S)      2cy @ 2p  2cy @ 4p
- * vsubq_f32      FSUB (V.4S, V.4S, V.4S)      2cy @ 2p  2cy @ 4p
- * vmulq_f32      FMUL (V.4S, V.4S, V.4S)      3cy @ 2p  3cy @ 4p
- * vdupq_n_f32    DUP (V.4S, scalar)           2cy @ 2p  2cy @ 4p
- * vaddvq_f32     FADDP+FADDP (V.4S)           5cy @ 1p  8cy @ 1p
+ * Intrinsic      Instruction                  A76       M5
+ * vld3q_u16      LD3 (V.8H x 3)               4cy @ 1p  4cy @ 1p
+ * vld3_u16       LD3 (V.4H x 3)               4cy @ 1p  4cy @ 1p
+ * vbfdotq_f32    BFDOT (V.4S, V.8H, V.8H)     3cy @ 2p  2cy @ 1p
+ * vshll_n_u16    USHLL (V.4S, V.4H, #16)      2cy @ 2p  2cy @ 4p
+ * vfmaq_f32      FMLA (V.4S, V.4S, V.4S)      4cy @ 2p  3cy @ 4p
+ * vaddq_f32      FADD (V.4S, V.4S, V.4S)      2cy @ 2p  2cy @ 4p
+ * vsubq_f32      FSUB (V.4S, V.4S, V.4S)      2cy @ 2p  2cy @ 4p
+ * vmulq_f32      FMUL (V.4S, V.4S, V.4S)      3cy @ 2p  3cy @ 4p
+ * vdupq_n_f32    DUP (V.4S, scalar)           2cy @ 2p  2cy @ 4p
+ * vaddvq_f32     FADDP+FADDP (V.4S)           5cy @ 1p  8cy @ 1p
  *
  * The ARMv8.6-BF16 extension enables BF16 storage with F32 computation for 3D mesh alignment
  * operations. BF16's wider exponent range (matching F32) prevents overflow in geometric calculations
  * while halving memory bandwidth compared to F32.
  *
- * For point cloud registration (RMSD, Kabsch, Umeyama), BF16 data is loaded using VLD3 de-interleave
- * operations, converted to F32 via bit-shift widening, then processed with F32 FMA chains. The 2x
- * unrolling with dual accumulators hides the 4-cycle FMA latency, achieving near-peak throughput
- * for covariance matrix and centroid computations.
+ * For point cloud registration (Kabsch, Umeyama), BF16 data is loaded using VLD3 de-interleave
+ * operations and processed directly with BFDOT (`vbfdotq_f32`), which computes two BF16 products
+ * per 32-bit lane with FP32 accumulation. This skips the explicit bf16→f32 widening that the
+ * prior vshll+fmaq approach required, halving the front-end pressure on the covariance/centroid
+ * stats pass. RMSD keeps the widen+subtract+fmaq pipeline because it needs the (a - b) difference
+ * before squaring, which BFDOT can't express directly.
  */
  #ifndef NK_MESH_NEONBFDOT_H
  #define NK_MESH_NEONBFDOT_H
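
For reference, the per-lane BFDOT behavior the comment above relies on can be modeled in scalar C as below. This is a sketch under simplifying assumptions: `bf16_to_f32` is an illustrative helper (a bf16 value is the top half of an IEEE-754 binary32), and the model applies ordinary f32 rounding to each product, whereas the hardware's intermediate rounding depends on the BF16 extension variant.

#include <stdint.h>

/* Illustrative helper: a bf16 value occupies the top 16 bits of a binary32. */
static inline float bf16_to_f32(uint16_t bits) {
    union { uint32_t u; float f; } pun = {(uint32_t)bits << 16};
    return pun.f;
}

/* Scalar model of vbfdotq_f32(acc, p, q): each of the four 32-bit accumulator
 * lanes absorbs the dot product of one adjacent pair of bf16 elements. */
static void vbfdotq_f32_model(float acc[4], uint16_t const p[8], uint16_t const q[8]) {
    for (int lane = 0; lane < 4; ++lane)
        acc[lane] += bf16_to_f32(p[2 * lane]) * bf16_to_f32(q[2 * lane]) +
                     bf16_to_f32(p[2 * lane + 1]) * bf16_to_f32(q[2 * lane + 1]);
}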
@@ -34,6 +38,7 @@
  #if NK_TARGET_NEONBFDOT

  #include "numkong/types.h"
+ #include "numkong/cast/neon.h"    // `nk_u16x8_splat_`
  #include "numkong/spatial/neon.h" // `nk_f32_sqrt_neon`

  #if defined(__cplusplus)
@@ -75,191 +80,6 @@ NK_INTERNAL void nk_partial_deinterleave_bf16_to_f32x4_neonbfdot_(nk_bf16_t cons
  nk_deinterleave_bf16x4_to_f32x4_neonbfdot_((nk_bf16_t const *)buf, x_out, y_out, z_out);
  }

- /* Compute sum of squared distances for bf16 data after applying rotation (and optional scale).
- * Loads bf16 data, converts to f32 during processing.
- * Note: rotation matrix r is f32 (from SVD), scale and computation done in f32.
- */
- NK_INTERNAL nk_f32_t nk_transformed_ssd_bf16_neonbfdot_(nk_bf16_t const *a, nk_bf16_t const *b, nk_size_t n,
- nk_f32_t const *r, nk_f32_t scale, nk_f32_t centroid_a_x,
- nk_f32_t centroid_a_y, nk_f32_t centroid_a_z,
- nk_f32_t centroid_b_x, nk_f32_t centroid_b_y,
- nk_f32_t centroid_b_z) {
- // Broadcast scaled rotation matrix elements
- float32x4_t scaled_rotation_x_x_f32x4 = vdupq_n_f32(scale * r[0]);
- float32x4_t scaled_rotation_x_y_f32x4 = vdupq_n_f32(scale * r[1]);
- float32x4_t scaled_rotation_x_z_f32x4 = vdupq_n_f32(scale * r[2]);
- float32x4_t scaled_rotation_y_x_f32x4 = vdupq_n_f32(scale * r[3]);
- float32x4_t scaled_rotation_y_y_f32x4 = vdupq_n_f32(scale * r[4]);
- float32x4_t scaled_rotation_y_z_f32x4 = vdupq_n_f32(scale * r[5]);
- float32x4_t scaled_rotation_z_x_f32x4 = vdupq_n_f32(scale * r[6]);
- float32x4_t scaled_rotation_z_y_f32x4 = vdupq_n_f32(scale * r[7]);
- float32x4_t scaled_rotation_z_z_f32x4 = vdupq_n_f32(scale * r[8]);
-
- // Broadcast centroids
- float32x4_t centroid_a_x_f32x4 = vdupq_n_f32(centroid_a_x);
- float32x4_t centroid_a_y_f32x4 = vdupq_n_f32(centroid_a_y);
- float32x4_t centroid_a_z_f32x4 = vdupq_n_f32(centroid_a_z);
- float32x4_t centroid_b_x_f32x4 = vdupq_n_f32(centroid_b_x);
- float32x4_t centroid_b_y_f32x4 = vdupq_n_f32(centroid_b_y);
- float32x4_t centroid_b_z_f32x4 = vdupq_n_f32(centroid_b_z);
-
- // Two independent accumulators to hide FMA latency
- float32x4_t sum_squared_a_f32x4 = vdupq_n_f32(0);
- float32x4_t sum_squared_b_f32x4 = vdupq_n_f32(0);
- nk_size_t j = 0;
-
- // Main loop: process 8 points per iteration (2x unrolled)
- for (; j + 8 <= n; j += 8) {
- // First batch of 4 points
- float32x4_t a1_x_f32x4, a1_y_f32x4, a1_z_f32x4, b1_x_f32x4, b1_y_f32x4, b1_z_f32x4;
- nk_deinterleave_bf16x4_to_f32x4_neonbfdot_(a + j * 3, &a1_x_f32x4, &a1_y_f32x4, &a1_z_f32x4);
- nk_deinterleave_bf16x4_to_f32x4_neonbfdot_(b + j * 3, &b1_x_f32x4, &b1_y_f32x4, &b1_z_f32x4);
-
- // Second batch of 4 points
- float32x4_t a2_x_f32x4, a2_y_f32x4, a2_z_f32x4, b2_x_f32x4, b2_y_f32x4, b2_z_f32x4;
- nk_deinterleave_bf16x4_to_f32x4_neonbfdot_(a + (j + 4) * 3, &a2_x_f32x4, &a2_y_f32x4, &a2_z_f32x4);
- nk_deinterleave_bf16x4_to_f32x4_neonbfdot_(b + (j + 4) * 3, &b2_x_f32x4, &b2_y_f32x4, &b2_z_f32x4);
-
- // Center first batch
- float32x4_t pa1_x_f32x4 = vsubq_f32(a1_x_f32x4, centroid_a_x_f32x4);
- float32x4_t pa1_y_f32x4 = vsubq_f32(a1_y_f32x4, centroid_a_y_f32x4);
- float32x4_t pa1_z_f32x4 = vsubq_f32(a1_z_f32x4, centroid_a_z_f32x4);
- float32x4_t pb1_x_f32x4 = vsubq_f32(b1_x_f32x4, centroid_b_x_f32x4);
- float32x4_t pb1_y_f32x4 = vsubq_f32(b1_y_f32x4, centroid_b_y_f32x4);
- float32x4_t pb1_z_f32x4 = vsubq_f32(b1_z_f32x4, centroid_b_z_f32x4);
-
- // Center second batch
- float32x4_t pa2_x_f32x4 = vsubq_f32(a2_x_f32x4, centroid_a_x_f32x4);
- float32x4_t pa2_y_f32x4 = vsubq_f32(a2_y_f32x4, centroid_a_y_f32x4);
- float32x4_t pa2_z_f32x4 = vsubq_f32(a2_z_f32x4, centroid_a_z_f32x4);
- float32x4_t pb2_x_f32x4 = vsubq_f32(b2_x_f32x4, centroid_b_x_f32x4);
- float32x4_t pb2_y_f32x4 = vsubq_f32(b2_y_f32x4, centroid_b_y_f32x4);
- float32x4_t pb2_z_f32x4 = vsubq_f32(b2_z_f32x4, centroid_b_z_f32x4);
-
- // Rotate and scale first batch: ra1 = scale * R * pa1
- float32x4_t ra1_x_f32x4 = vfmaq_f32(
- vfmaq_f32(vmulq_f32(scaled_rotation_x_x_f32x4, pa1_x_f32x4), scaled_rotation_x_y_f32x4, pa1_y_f32x4),
- scaled_rotation_x_z_f32x4, pa1_z_f32x4);
- float32x4_t ra1_y_f32x4 = vfmaq_f32(
- vfmaq_f32(vmulq_f32(scaled_rotation_y_x_f32x4, pa1_x_f32x4), scaled_rotation_y_y_f32x4, pa1_y_f32x4),
- scaled_rotation_y_z_f32x4, pa1_z_f32x4);
- float32x4_t ra1_z_f32x4 = vfmaq_f32(
- vfmaq_f32(vmulq_f32(scaled_rotation_z_x_f32x4, pa1_x_f32x4), scaled_rotation_z_y_f32x4, pa1_y_f32x4),
- scaled_rotation_z_z_f32x4, pa1_z_f32x4);
-
- // Rotate and scale second batch: ra2 = scale * R * pa2
- float32x4_t ra2_x_f32x4 = vfmaq_f32(
- vfmaq_f32(vmulq_f32(scaled_rotation_x_x_f32x4, pa2_x_f32x4), scaled_rotation_x_y_f32x4, pa2_y_f32x4),
- scaled_rotation_x_z_f32x4, pa2_z_f32x4);
- float32x4_t ra2_y_f32x4 = vfmaq_f32(
- vfmaq_f32(vmulq_f32(scaled_rotation_y_x_f32x4, pa2_x_f32x4), scaled_rotation_y_y_f32x4, pa2_y_f32x4),
- scaled_rotation_y_z_f32x4, pa2_z_f32x4);
- float32x4_t ra2_z_f32x4 = vfmaq_f32(
- vfmaq_f32(vmulq_f32(scaled_rotation_z_x_f32x4, pa2_x_f32x4), scaled_rotation_z_y_f32x4, pa2_y_f32x4),
- scaled_rotation_z_z_f32x4, pa2_z_f32x4);
-
- // Deltas
- float32x4_t delta1_x_f32x4 = vsubq_f32(ra1_x_f32x4, pb1_x_f32x4);
- float32x4_t delta1_y_f32x4 = vsubq_f32(ra1_y_f32x4, pb1_y_f32x4);
- float32x4_t delta1_z_f32x4 = vsubq_f32(ra1_z_f32x4, pb1_z_f32x4);
- float32x4_t delta2_x_f32x4 = vsubq_f32(ra2_x_f32x4, pb2_x_f32x4);
- float32x4_t delta2_y_f32x4 = vsubq_f32(ra2_y_f32x4, pb2_y_f32x4);
- float32x4_t delta2_z_f32x4 = vsubq_f32(ra2_z_f32x4, pb2_z_f32x4);
-
- // Accumulate to independent accumulators
- sum_squared_a_f32x4 = vfmaq_f32(sum_squared_a_f32x4, delta1_x_f32x4, delta1_x_f32x4);
- sum_squared_b_f32x4 = vfmaq_f32(sum_squared_b_f32x4, delta2_x_f32x4, delta2_x_f32x4);
- sum_squared_a_f32x4 = vfmaq_f32(sum_squared_a_f32x4, delta1_y_f32x4, delta1_y_f32x4);
- sum_squared_b_f32x4 = vfmaq_f32(sum_squared_b_f32x4, delta2_y_f32x4, delta2_y_f32x4);
- sum_squared_a_f32x4 = vfmaq_f32(sum_squared_a_f32x4, delta1_z_f32x4, delta1_z_f32x4);
- sum_squared_b_f32x4 = vfmaq_f32(sum_squared_b_f32x4, delta2_z_f32x4, delta2_z_f32x4);
- }
-
- // Handle remaining 4 points
- if (j + 4 <= n) {
- float32x4_t a_x_f32x4, a_y_f32x4, a_z_f32x4, b_x_f32x4, b_y_f32x4, b_z_f32x4;
- nk_deinterleave_bf16x4_to_f32x4_neonbfdot_(a + j * 3, &a_x_f32x4, &a_y_f32x4, &a_z_f32x4);
- nk_deinterleave_bf16x4_to_f32x4_neonbfdot_(b + j * 3, &b_x_f32x4, &b_y_f32x4, &b_z_f32x4);
-
- float32x4_t pa_x_f32x4 = vsubq_f32(a_x_f32x4, centroid_a_x_f32x4);
- float32x4_t pa_y_f32x4 = vsubq_f32(a_y_f32x4, centroid_a_y_f32x4);
- float32x4_t pa_z_f32x4 = vsubq_f32(a_z_f32x4, centroid_a_z_f32x4);
- float32x4_t pb_x_f32x4 = vsubq_f32(b_x_f32x4, centroid_b_x_f32x4);
- float32x4_t pb_y_f32x4 = vsubq_f32(b_y_f32x4, centroid_b_y_f32x4);
- float32x4_t pb_z_f32x4 = vsubq_f32(b_z_f32x4, centroid_b_z_f32x4);
-
- float32x4_t ra_x_f32x4 = vfmaq_f32(
- vfmaq_f32(vmulq_f32(scaled_rotation_x_x_f32x4, pa_x_f32x4), scaled_rotation_x_y_f32x4, pa_y_f32x4),
- scaled_rotation_x_z_f32x4, pa_z_f32x4);
- float32x4_t ra_y_f32x4 = vfmaq_f32(
- vfmaq_f32(vmulq_f32(scaled_rotation_y_x_f32x4, pa_x_f32x4), scaled_rotation_y_y_f32x4, pa_y_f32x4),
- scaled_rotation_y_z_f32x4, pa_z_f32x4);
- float32x4_t ra_z_f32x4 = vfmaq_f32(
- vfmaq_f32(vmulq_f32(scaled_rotation_z_x_f32x4, pa_x_f32x4), scaled_rotation_z_y_f32x4, pa_y_f32x4),
- scaled_rotation_z_z_f32x4, pa_z_f32x4);
-
- float32x4_t delta_x_f32x4 = vsubq_f32(ra_x_f32x4, pb_x_f32x4);
- float32x4_t delta_y_f32x4 = vsubq_f32(ra_y_f32x4, pb_y_f32x4);
- float32x4_t delta_z_f32x4 = vsubq_f32(ra_z_f32x4, pb_z_f32x4);
-
- sum_squared_a_f32x4 = vfmaq_f32(sum_squared_a_f32x4, delta_x_f32x4, delta_x_f32x4);
- sum_squared_a_f32x4 = vfmaq_f32(sum_squared_a_f32x4, delta_y_f32x4, delta_y_f32x4);
- sum_squared_a_f32x4 = vfmaq_f32(sum_squared_a_f32x4, delta_z_f32x4, delta_z_f32x4);
- j += 4;
- }
-
- // Partial tail: handle remaining 1-3 points with vectorized partial deinterleave
- if (j < n) {
- float32x4_t a_x_f32x4, a_y_f32x4, a_z_f32x4, b_x_f32x4, b_y_f32x4, b_z_f32x4;
- nk_partial_deinterleave_bf16_to_f32x4_neonbfdot_(a + j * 3, n - j, &a_x_f32x4, &a_y_f32x4, &a_z_f32x4);
- nk_partial_deinterleave_bf16_to_f32x4_neonbfdot_(b + j * 3, n - j, &b_x_f32x4, &b_y_f32x4, &b_z_f32x4);
-
- // Mask invalid lanes to zero BEFORE centering
- uint32x4_t lane_u32x4 = vcombine_u32(vreinterpret_u32_u64(vcreate_u64(0x0000000100000000ULL)),
- vreinterpret_u32_u64(vcreate_u64(0x0000000300000002ULL)));
- uint32x4_t valid_u32x4 = vcltq_u32(lane_u32x4, vdupq_n_u32((nk_u32_t)(n - j)));
- float32x4_t zero_f32x4 = vdupq_n_f32(0);
- a_x_f32x4 = vbslq_f32(valid_u32x4, a_x_f32x4, zero_f32x4);
- a_y_f32x4 = vbslq_f32(valid_u32x4, a_y_f32x4, zero_f32x4);
- a_z_f32x4 = vbslq_f32(valid_u32x4, a_z_f32x4, zero_f32x4);
- b_x_f32x4 = vbslq_f32(valid_u32x4, b_x_f32x4, zero_f32x4);
- b_y_f32x4 = vbslq_f32(valid_u32x4, b_y_f32x4, zero_f32x4);
- b_z_f32x4 = vbslq_f32(valid_u32x4, b_z_f32x4, zero_f32x4);
-
- // Same centering + rotation + delta + FMA as body
- float32x4_t pa_x_f32x4 = vsubq_f32(a_x_f32x4, centroid_a_x_f32x4);
- float32x4_t pa_y_f32x4 = vsubq_f32(a_y_f32x4, centroid_a_y_f32x4);
- float32x4_t pa_z_f32x4 = vsubq_f32(a_z_f32x4, centroid_a_z_f32x4);
- float32x4_t pb_x_f32x4 = vsubq_f32(b_x_f32x4, centroid_b_x_f32x4);
- float32x4_t pb_y_f32x4 = vsubq_f32(b_y_f32x4, centroid_b_y_f32x4);
- float32x4_t pb_z_f32x4 = vsubq_f32(b_z_f32x4, centroid_b_z_f32x4);
-
- float32x4_t ra_x_f32x4 = vfmaq_f32(
- vfmaq_f32(vmulq_f32(scaled_rotation_x_x_f32x4, pa_x_f32x4), scaled_rotation_x_y_f32x4, pa_y_f32x4),
- scaled_rotation_x_z_f32x4, pa_z_f32x4);
- float32x4_t ra_y_f32x4 = vfmaq_f32(
- vfmaq_f32(vmulq_f32(scaled_rotation_y_x_f32x4, pa_x_f32x4), scaled_rotation_y_y_f32x4, pa_y_f32x4),
- scaled_rotation_y_z_f32x4, pa_z_f32x4);
- float32x4_t ra_z_f32x4 = vfmaq_f32(
- vfmaq_f32(vmulq_f32(scaled_rotation_z_x_f32x4, pa_x_f32x4), scaled_rotation_z_y_f32x4, pa_y_f32x4),
- scaled_rotation_z_z_f32x4, pa_z_f32x4);
-
- float32x4_t delta_x_f32x4 = vsubq_f32(ra_x_f32x4, pb_x_f32x4);
- float32x4_t delta_y_f32x4 = vsubq_f32(ra_y_f32x4, pb_y_f32x4);
- float32x4_t delta_z_f32x4 = vsubq_f32(ra_z_f32x4, pb_z_f32x4);
-
- sum_squared_a_f32x4 = vfmaq_f32(sum_squared_a_f32x4, delta_x_f32x4, delta_x_f32x4);
- sum_squared_a_f32x4 = vfmaq_f32(sum_squared_a_f32x4, delta_y_f32x4, delta_y_f32x4);
- sum_squared_a_f32x4 = vfmaq_f32(sum_squared_a_f32x4, delta_z_f32x4, delta_z_f32x4);
- }
-
- // Combine accumulators and reduce
- float32x4_t sum_squared_f32x4 = vaddq_f32(sum_squared_a_f32x4, sum_squared_b_f32x4);
- nk_f32_t sum_squared = vaddvq_f32(sum_squared_f32x4);
-
- return sum_squared;
- }
-
  NK_PUBLIC void nk_rmsd_bf16_neonbfdot(nk_bf16_t const *a, nk_bf16_t const *b, nk_size_t n, nk_f32_t *a_centroid,
  nk_f32_t *b_centroid, nk_f32_t *rotation, nk_f32_t *scale, nk_f32_t *result) {
  // RMSD uses identity rotation and scale=1.0
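
The helper deleted above becomes unnecessary because of a standard algebraic fold, which the rewritten `nk_kabsch_bf16_neonbfdot` below exploits. A sketch of the algebra, with centered points p_i = a_i - ā and q_i = b_i - b̄, an orthogonal rotation R (so ‖R p_i‖ = ‖p_i‖), and scale fixed at 1:

\sum_{i} \lVert R p_i - q_i \rVert^2
  = \sum_{i} \lVert p_i \rVert^2 + \sum_{i} \lVert q_i \rVert^2 - 2\,\operatorname{tr}(R H),
\qquad H = \sum_{i} p_i q_i^{\top},

using q_i^T R p_i = tr(R p_i q_i^T). The sum of squared deviations therefore follows from the centered norms and the cross-covariance already gathered in the stats pass, with no second sweep over the points.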
@@ -318,124 +138,117 @@ NK_PUBLIC void nk_rmsd_bf16_neonbfdot(nk_bf16_t const *a, nk_bf16_t const *b, nk
  NK_PUBLIC void nk_kabsch_bf16_neonbfdot(nk_bf16_t const *a, nk_bf16_t const *b, nk_size_t n, nk_f32_t *a_centroid,
  nk_f32_t *b_centroid, nk_f32_t *rotation, nk_f32_t *scale, nk_f32_t *result) {
  float32x4_t const zeros_f32x4 = vdupq_n_f32(0);
-
- // 2x unrolling with dual accumulators to hide FMA latency.
- float32x4_t sum_a_x_a_f32x4 = zeros_f32x4, sum_a_y_a_f32x4 = zeros_f32x4, sum_a_z_a_f32x4 = zeros_f32x4;
- float32x4_t sum_b_x_a_f32x4 = zeros_f32x4, sum_b_y_a_f32x4 = zeros_f32x4, sum_b_z_a_f32x4 = zeros_f32x4;
- float32x4_t sum_a_x_b_f32x4 = zeros_f32x4, sum_a_y_b_f32x4 = zeros_f32x4, sum_a_z_b_f32x4 = zeros_f32x4;
- float32x4_t sum_b_x_b_f32x4 = zeros_f32x4, sum_b_y_b_f32x4 = zeros_f32x4, sum_b_z_b_f32x4 = zeros_f32x4;
-
- float32x4_t cov_xx_a_f32x4 = zeros_f32x4, cov_xy_a_f32x4 = zeros_f32x4, cov_xz_a_f32x4 = zeros_f32x4;
- float32x4_t cov_yx_a_f32x4 = zeros_f32x4, cov_yy_a_f32x4 = zeros_f32x4, cov_yz_a_f32x4 = zeros_f32x4;
- float32x4_t cov_zx_a_f32x4 = zeros_f32x4, cov_zy_a_f32x4 = zeros_f32x4, cov_zz_a_f32x4 = zeros_f32x4;
- float32x4_t cov_xx_b_f32x4 = zeros_f32x4, cov_xy_b_f32x4 = zeros_f32x4, cov_xz_b_f32x4 = zeros_f32x4;
- float32x4_t cov_yx_b_f32x4 = zeros_f32x4, cov_yy_b_f32x4 = zeros_f32x4, cov_yz_b_f32x4 = zeros_f32x4;
- float32x4_t cov_zx_b_f32x4 = zeros_f32x4, cov_zy_b_f32x4 = zeros_f32x4, cov_zz_b_f32x4 = zeros_f32x4;
+ // bf16 representation of 1.0 is 0x3F80; splatted across 8 lanes for BFDOT-based horizontal sums.
+ // The `nk_u16x8_splat_` wrapper prevents GCC from lowering this to `fmov v.8h, #1.0` (a FEAT_FP16
+ // encoding) that fails to assemble under a `+bf16`-only pragma.
+ bfloat16x8_t const ones_bf16x8 = vreinterpretq_bf16_u16(nk_u16x8_splat_(0x3F80));
+
+ // Centroid numerators, norm-squared, and 3x3 cross-covariance accumulators.
+ float32x4_t sum_a_x_f32x4 = zeros_f32x4, sum_a_y_f32x4 = zeros_f32x4, sum_a_z_f32x4 = zeros_f32x4;
+ float32x4_t sum_b_x_f32x4 = zeros_f32x4, sum_b_y_f32x4 = zeros_f32x4, sum_b_z_f32x4 = zeros_f32x4;
+ float32x4_t covariance_xx_f32x4 = zeros_f32x4, covariance_xy_f32x4 = zeros_f32x4, covariance_xz_f32x4 = zeros_f32x4;
+ float32x4_t covariance_yx_f32x4 = zeros_f32x4, covariance_yy_f32x4 = zeros_f32x4, covariance_yz_f32x4 = zeros_f32x4;
+ float32x4_t covariance_zx_f32x4 = zeros_f32x4, covariance_zy_f32x4 = zeros_f32x4, covariance_zz_f32x4 = zeros_f32x4;
+ float32x4_t norm_squared_a_f32x4 = zeros_f32x4, norm_squared_b_f32x4 = zeros_f32x4;

  nk_size_t i = 0;
- float32x4_t a1_x_f32x4, a1_y_f32x4, a1_z_f32x4, b1_x_f32x4, b1_y_f32x4, b1_z_f32x4;
- float32x4_t a2_x_f32x4, a2_y_f32x4, a2_z_f32x4, b2_x_f32x4, b2_y_f32x4, b2_z_f32x4;

- // Main loop: 8 points per iteration (2x unrolled)
+ // Main loop: 8 triplets per iteration via vld3q_u16 + vbfdotq_f32.
+ // Each vld3q_u16 de-interleaves 24 bf16 values into three 8-lane channel vectors.
+ // vbfdotq_f32(acc, p, q) computes per-32bit-lane: acc[l] += p[2l]*q[2l] + p[2l+1]*q[2l+1]
+ // with bf16 inputs and f32 accumulation. Summing the 4 lanes at the end yields the scalar.
  for (; i + 8 <= n; i += 8) {
- nk_deinterleave_bf16x4_to_f32x4_neonbfdot_(a + i * 3, &a1_x_f32x4, &a1_y_f32x4, &a1_z_f32x4);
- nk_deinterleave_bf16x4_to_f32x4_neonbfdot_(b + i * 3, &b1_x_f32x4, &b1_y_f32x4, &b1_z_f32x4);
- nk_deinterleave_bf16x4_to_f32x4_neonbfdot_(a + (i + 4) * 3, &a2_x_f32x4, &a2_y_f32x4, &a2_z_f32x4);
- nk_deinterleave_bf16x4_to_f32x4_neonbfdot_(b + (i + 4) * 3, &b2_x_f32x4, &b2_y_f32x4, &b2_z_f32x4);
-
- // Interleaved accumulation to hide FMA latency
- sum_a_x_a_f32x4 = vaddq_f32(sum_a_x_a_f32x4, a1_x_f32x4);
- sum_a_x_b_f32x4 = vaddq_f32(sum_a_x_b_f32x4, a2_x_f32x4);
- sum_a_y_a_f32x4 = vaddq_f32(sum_a_y_a_f32x4, a1_y_f32x4);
- sum_a_y_b_f32x4 = vaddq_f32(sum_a_y_b_f32x4, a2_y_f32x4);
- sum_a_z_a_f32x4 = vaddq_f32(sum_a_z_a_f32x4, a1_z_f32x4);
- sum_a_z_b_f32x4 = vaddq_f32(sum_a_z_b_f32x4, a2_z_f32x4);
- sum_b_x_a_f32x4 = vaddq_f32(sum_b_x_a_f32x4, b1_x_f32x4);
- sum_b_x_b_f32x4 = vaddq_f32(sum_b_x_b_f32x4, b2_x_f32x4);
- sum_b_y_a_f32x4 = vaddq_f32(sum_b_y_a_f32x4, b1_y_f32x4);
- sum_b_y_b_f32x4 = vaddq_f32(sum_b_y_b_f32x4, b2_y_f32x4);
- sum_b_z_a_f32x4 = vaddq_f32(sum_b_z_a_f32x4, b1_z_f32x4);
- sum_b_z_b_f32x4 = vaddq_f32(sum_b_z_b_f32x4, b2_z_f32x4);
-
- cov_xx_a_f32x4 = vfmaq_f32(cov_xx_a_f32x4, a1_x_f32x4, b1_x_f32x4);
- cov_xx_b_f32x4 = vfmaq_f32(cov_xx_b_f32x4, a2_x_f32x4, b2_x_f32x4);
- cov_xy_a_f32x4 = vfmaq_f32(cov_xy_a_f32x4, a1_x_f32x4, b1_y_f32x4);
- cov_xy_b_f32x4 = vfmaq_f32(cov_xy_b_f32x4, a2_x_f32x4, b2_y_f32x4);
- cov_xz_a_f32x4 = vfmaq_f32(cov_xz_a_f32x4, a1_x_f32x4, b1_z_f32x4);
- cov_xz_b_f32x4 = vfmaq_f32(cov_xz_b_f32x4, a2_x_f32x4, b2_z_f32x4);
- cov_yx_a_f32x4 = vfmaq_f32(cov_yx_a_f32x4, a1_y_f32x4, b1_x_f32x4);
- cov_yx_b_f32x4 = vfmaq_f32(cov_yx_b_f32x4, a2_y_f32x4, b2_x_f32x4);
- cov_yy_a_f32x4 = vfmaq_f32(cov_yy_a_f32x4, a1_y_f32x4, b1_y_f32x4);
- cov_yy_b_f32x4 = vfmaq_f32(cov_yy_b_f32x4, a2_y_f32x4, b2_y_f32x4);
- cov_yz_a_f32x4 = vfmaq_f32(cov_yz_a_f32x4, a1_y_f32x4, b1_z_f32x4);
- cov_yz_b_f32x4 = vfmaq_f32(cov_yz_b_f32x4, a2_y_f32x4, b2_z_f32x4);
- cov_zx_a_f32x4 = vfmaq_f32(cov_zx_a_f32x4, a1_z_f32x4, b1_x_f32x4);
- cov_zx_b_f32x4 = vfmaq_f32(cov_zx_b_f32x4, a2_z_f32x4, b2_x_f32x4);
- cov_zy_a_f32x4 = vfmaq_f32(cov_zy_a_f32x4, a1_z_f32x4, b1_y_f32x4);
- cov_zy_b_f32x4 = vfmaq_f32(cov_zy_b_f32x4, a2_z_f32x4, b2_y_f32x4);
- cov_zz_a_f32x4 = vfmaq_f32(cov_zz_a_f32x4, a1_z_f32x4, b1_z_f32x4);
- cov_zz_b_f32x4 = vfmaq_f32(cov_zz_b_f32x4, a2_z_f32x4, b2_z_f32x4);
+ uint16x8x3_t a_xyz_u16x8x3 = vld3q_u16((nk_u16_t const *)(a + i * 3));
+ uint16x8x3_t b_xyz_u16x8x3 = vld3q_u16((nk_u16_t const *)(b + i * 3));
+ bfloat16x8_t a_x_bf16x8 = vreinterpretq_bf16_u16(a_xyz_u16x8x3.val[0]);
+ bfloat16x8_t a_y_bf16x8 = vreinterpretq_bf16_u16(a_xyz_u16x8x3.val[1]);
+ bfloat16x8_t a_z_bf16x8 = vreinterpretq_bf16_u16(a_xyz_u16x8x3.val[2]);
+ bfloat16x8_t b_x_bf16x8 = vreinterpretq_bf16_u16(b_xyz_u16x8x3.val[0]);
+ bfloat16x8_t b_y_bf16x8 = vreinterpretq_bf16_u16(b_xyz_u16x8x3.val[1]);
+ bfloat16x8_t b_z_bf16x8 = vreinterpretq_bf16_u16(b_xyz_u16x8x3.val[2]);
+
+ // Centroid numerators: Σ channel values (pairwise via BFDOT against bf16 1.0).
+ sum_a_x_f32x4 = vbfdotq_f32(sum_a_x_f32x4, a_x_bf16x8, ones_bf16x8);
+ sum_a_y_f32x4 = vbfdotq_f32(sum_a_y_f32x4, a_y_bf16x8, ones_bf16x8);
+ sum_a_z_f32x4 = vbfdotq_f32(sum_a_z_f32x4, a_z_bf16x8, ones_bf16x8);
+ sum_b_x_f32x4 = vbfdotq_f32(sum_b_x_f32x4, b_x_bf16x8, ones_bf16x8);
+ sum_b_y_f32x4 = vbfdotq_f32(sum_b_y_f32x4, b_y_bf16x8, ones_bf16x8);
+ sum_b_z_f32x4 = vbfdotq_f32(sum_b_z_f32x4, b_z_bf16x8, ones_bf16x8);
+
+ // 3x3 cross-covariance H cells: Σ a_j · b_k.
+ covariance_xx_f32x4 = vbfdotq_f32(covariance_xx_f32x4, a_x_bf16x8, b_x_bf16x8);
+ covariance_xy_f32x4 = vbfdotq_f32(covariance_xy_f32x4, a_x_bf16x8, b_y_bf16x8);
+ covariance_xz_f32x4 = vbfdotq_f32(covariance_xz_f32x4, a_x_bf16x8, b_z_bf16x8);
+ covariance_yx_f32x4 = vbfdotq_f32(covariance_yx_f32x4, a_y_bf16x8, b_x_bf16x8);
+ covariance_yy_f32x4 = vbfdotq_f32(covariance_yy_f32x4, a_y_bf16x8, b_y_bf16x8);
+ covariance_yz_f32x4 = vbfdotq_f32(covariance_yz_f32x4, a_y_bf16x8, b_z_bf16x8);
+ covariance_zx_f32x4 = vbfdotq_f32(covariance_zx_f32x4, a_z_bf16x8, b_x_bf16x8);
+ covariance_zy_f32x4 = vbfdotq_f32(covariance_zy_f32x4, a_z_bf16x8, b_y_bf16x8);
+ covariance_zz_f32x4 = vbfdotq_f32(covariance_zz_f32x4, a_z_bf16x8, b_z_bf16x8);
+
+ // Norm-squared per point set: Σ (x² + y² + z²).
+ norm_squared_a_f32x4 = vbfdotq_f32(norm_squared_a_f32x4, a_x_bf16x8, a_x_bf16x8);
+ norm_squared_a_f32x4 = vbfdotq_f32(norm_squared_a_f32x4, a_y_bf16x8, a_y_bf16x8);
+ norm_squared_a_f32x4 = vbfdotq_f32(norm_squared_a_f32x4, a_z_bf16x8, a_z_bf16x8);
+ norm_squared_b_f32x4 = vbfdotq_f32(norm_squared_b_f32x4, b_x_bf16x8, b_x_bf16x8);
+ norm_squared_b_f32x4 = vbfdotq_f32(norm_squared_b_f32x4, b_y_bf16x8, b_y_bf16x8);
+ norm_squared_b_f32x4 = vbfdotq_f32(norm_squared_b_f32x4, b_z_bf16x8, b_z_bf16x8);
  }

- // 4-point tail
+ // 4-point and partial (1-3) tails: keep the widen+fmaq path on f32x4 channel vectors.
+ // These branches run at most once each, so we skip another vbfdotq variant for them.
+ float32x4_t a_x_f32x4, a_y_f32x4, a_z_f32x4, b_x_f32x4, b_y_f32x4, b_z_f32x4;
  for (; i + 4 <= n; i += 4) {
- nk_deinterleave_bf16x4_to_f32x4_neonbfdot_(a + i * 3, &a1_x_f32x4, &a1_y_f32x4, &a1_z_f32x4);
- nk_deinterleave_bf16x4_to_f32x4_neonbfdot_(b + i * 3, &b1_x_f32x4, &b1_y_f32x4, &b1_z_f32x4);
- sum_a_x_a_f32x4 = vaddq_f32(sum_a_x_a_f32x4, a1_x_f32x4);
- sum_a_y_a_f32x4 = vaddq_f32(sum_a_y_a_f32x4, a1_y_f32x4);
- sum_a_z_a_f32x4 = vaddq_f32(sum_a_z_a_f32x4, a1_z_f32x4);
- sum_b_x_a_f32x4 = vaddq_f32(sum_b_x_a_f32x4, b1_x_f32x4);
- sum_b_y_a_f32x4 = vaddq_f32(sum_b_y_a_f32x4, b1_y_f32x4);
- sum_b_z_a_f32x4 = vaddq_f32(sum_b_z_a_f32x4, b1_z_f32x4);
- cov_xx_a_f32x4 = vfmaq_f32(cov_xx_a_f32x4, a1_x_f32x4, b1_x_f32x4);
- cov_xy_a_f32x4 = vfmaq_f32(cov_xy_a_f32x4, a1_x_f32x4, b1_y_f32x4);
- cov_xz_a_f32x4 = vfmaq_f32(cov_xz_a_f32x4, a1_x_f32x4, b1_z_f32x4);
- cov_yx_a_f32x4 = vfmaq_f32(cov_yx_a_f32x4, a1_y_f32x4, b1_x_f32x4);
- cov_yy_a_f32x4 = vfmaq_f32(cov_yy_a_f32x4, a1_y_f32x4, b1_y_f32x4);
- cov_yz_a_f32x4 = vfmaq_f32(cov_yz_a_f32x4, a1_y_f32x4, b1_z_f32x4);
- cov_zx_a_f32x4 = vfmaq_f32(cov_zx_a_f32x4, a1_z_f32x4, b1_x_f32x4);
- cov_zy_a_f32x4 = vfmaq_f32(cov_zy_a_f32x4, a1_z_f32x4, b1_y_f32x4);
- cov_zz_a_f32x4 = vfmaq_f32(cov_zz_a_f32x4, a1_z_f32x4, b1_z_f32x4);
+ nk_deinterleave_bf16x4_to_f32x4_neonbfdot_(a + i * 3, &a_x_f32x4, &a_y_f32x4, &a_z_f32x4);
+ nk_deinterleave_bf16x4_to_f32x4_neonbfdot_(b + i * 3, &b_x_f32x4, &b_y_f32x4, &b_z_f32x4);
+ sum_a_x_f32x4 = vaddq_f32(sum_a_x_f32x4, a_x_f32x4);
+ sum_a_y_f32x4 = vaddq_f32(sum_a_y_f32x4, a_y_f32x4);
+ sum_a_z_f32x4 = vaddq_f32(sum_a_z_f32x4, a_z_f32x4);
+ sum_b_x_f32x4 = vaddq_f32(sum_b_x_f32x4, b_x_f32x4);
+ sum_b_y_f32x4 = vaddq_f32(sum_b_y_f32x4, b_y_f32x4);
+ sum_b_z_f32x4 = vaddq_f32(sum_b_z_f32x4, b_z_f32x4);
+ covariance_xx_f32x4 = vfmaq_f32(covariance_xx_f32x4, a_x_f32x4, b_x_f32x4);
+ covariance_xy_f32x4 = vfmaq_f32(covariance_xy_f32x4, a_x_f32x4, b_y_f32x4);
+ covariance_xz_f32x4 = vfmaq_f32(covariance_xz_f32x4, a_x_f32x4, b_z_f32x4);
+ covariance_yx_f32x4 = vfmaq_f32(covariance_yx_f32x4, a_y_f32x4, b_x_f32x4);
+ covariance_yy_f32x4 = vfmaq_f32(covariance_yy_f32x4, a_y_f32x4, b_y_f32x4);
+ covariance_yz_f32x4 = vfmaq_f32(covariance_yz_f32x4, a_y_f32x4, b_z_f32x4);
+ covariance_zx_f32x4 = vfmaq_f32(covariance_zx_f32x4, a_z_f32x4, b_x_f32x4);
+ covariance_zy_f32x4 = vfmaq_f32(covariance_zy_f32x4, a_z_f32x4, b_y_f32x4);
+ covariance_zz_f32x4 = vfmaq_f32(covariance_zz_f32x4, a_z_f32x4, b_z_f32x4);
+ norm_squared_a_f32x4 = vfmaq_f32(norm_squared_a_f32x4, a_x_f32x4, a_x_f32x4);
+ norm_squared_a_f32x4 = vfmaq_f32(norm_squared_a_f32x4, a_y_f32x4, a_y_f32x4);
+ norm_squared_a_f32x4 = vfmaq_f32(norm_squared_a_f32x4, a_z_f32x4, a_z_f32x4);
+ norm_squared_b_f32x4 = vfmaq_f32(norm_squared_b_f32x4, b_x_f32x4, b_x_f32x4);
+ norm_squared_b_f32x4 = vfmaq_f32(norm_squared_b_f32x4, b_y_f32x4, b_y_f32x4);
+ norm_squared_b_f32x4 = vfmaq_f32(norm_squared_b_f32x4, b_z_f32x4, b_z_f32x4);
  }
-
- // Partial tail: handle remaining 1-3 points with vectorized partial deinterleave
  if (i < n) {
- nk_partial_deinterleave_bf16_to_f32x4_neonbfdot_(a + i * 3, n - i, &a1_x_f32x4, &a1_y_f32x4, &a1_z_f32x4);
- nk_partial_deinterleave_bf16_to_f32x4_neonbfdot_(b + i * 3, n - i, &b1_x_f32x4, &b1_y_f32x4, &b1_z_f32x4);
- sum_a_x_a_f32x4 = vaddq_f32(sum_a_x_a_f32x4, a1_x_f32x4);
- sum_a_y_a_f32x4 = vaddq_f32(sum_a_y_a_f32x4, a1_y_f32x4);
- sum_a_z_a_f32x4 = vaddq_f32(sum_a_z_a_f32x4, a1_z_f32x4);
- sum_b_x_a_f32x4 = vaddq_f32(sum_b_x_a_f32x4, b1_x_f32x4);
- sum_b_y_a_f32x4 = vaddq_f32(sum_b_y_a_f32x4, b1_y_f32x4);
- sum_b_z_a_f32x4 = vaddq_f32(sum_b_z_a_f32x4, b1_z_f32x4);
- cov_xx_a_f32x4 = vfmaq_f32(cov_xx_a_f32x4, a1_x_f32x4, b1_x_f32x4);
- cov_xy_a_f32x4 = vfmaq_f32(cov_xy_a_f32x4, a1_x_f32x4, b1_y_f32x4);
- cov_xz_a_f32x4 = vfmaq_f32(cov_xz_a_f32x4, a1_x_f32x4, b1_z_f32x4);
- cov_yx_a_f32x4 = vfmaq_f32(cov_yx_a_f32x4, a1_y_f32x4, b1_x_f32x4);
- cov_yy_a_f32x4 = vfmaq_f32(cov_yy_a_f32x4, a1_y_f32x4, b1_y_f32x4);
- cov_yz_a_f32x4 = vfmaq_f32(cov_yz_a_f32x4, a1_y_f32x4, b1_z_f32x4);
- cov_zx_a_f32x4 = vfmaq_f32(cov_zx_a_f32x4, a1_z_f32x4, b1_x_f32x4);
- cov_zy_a_f32x4 = vfmaq_f32(cov_zy_a_f32x4, a1_z_f32x4, b1_y_f32x4);
- cov_zz_a_f32x4 = vfmaq_f32(cov_zz_a_f32x4, a1_z_f32x4, b1_z_f32x4);
+ nk_partial_deinterleave_bf16_to_f32x4_neonbfdot_(a + i * 3, n - i, &a_x_f32x4, &a_y_f32x4, &a_z_f32x4);
+ nk_partial_deinterleave_bf16_to_f32x4_neonbfdot_(b + i * 3, n - i, &b_x_f32x4, &b_y_f32x4, &b_z_f32x4);
+ sum_a_x_f32x4 = vaddq_f32(sum_a_x_f32x4, a_x_f32x4);
+ sum_a_y_f32x4 = vaddq_f32(sum_a_y_f32x4, a_y_f32x4);
+ sum_a_z_f32x4 = vaddq_f32(sum_a_z_f32x4, a_z_f32x4);
+ sum_b_x_f32x4 = vaddq_f32(sum_b_x_f32x4, b_x_f32x4);
+ sum_b_y_f32x4 = vaddq_f32(sum_b_y_f32x4, b_y_f32x4);
+ sum_b_z_f32x4 = vaddq_f32(sum_b_z_f32x4, b_z_f32x4);
+ covariance_xx_f32x4 = vfmaq_f32(covariance_xx_f32x4, a_x_f32x4, b_x_f32x4);
+ covariance_xy_f32x4 = vfmaq_f32(covariance_xy_f32x4, a_x_f32x4, b_y_f32x4);
+ covariance_xz_f32x4 = vfmaq_f32(covariance_xz_f32x4, a_x_f32x4, b_z_f32x4);
+ covariance_yx_f32x4 = vfmaq_f32(covariance_yx_f32x4, a_y_f32x4, b_x_f32x4);
+ covariance_yy_f32x4 = vfmaq_f32(covariance_yy_f32x4, a_y_f32x4, b_y_f32x4);
+ covariance_yz_f32x4 = vfmaq_f32(covariance_yz_f32x4, a_y_f32x4, b_z_f32x4);
+ covariance_zx_f32x4 = vfmaq_f32(covariance_zx_f32x4, a_z_f32x4, b_x_f32x4);
+ covariance_zy_f32x4 = vfmaq_f32(covariance_zy_f32x4, a_z_f32x4, b_y_f32x4);
+ covariance_zz_f32x4 = vfmaq_f32(covariance_zz_f32x4, a_z_f32x4, b_z_f32x4);
+ norm_squared_a_f32x4 = vfmaq_f32(norm_squared_a_f32x4, a_x_f32x4, a_x_f32x4);
+ norm_squared_a_f32x4 = vfmaq_f32(norm_squared_a_f32x4, a_y_f32x4, a_y_f32x4);
+ norm_squared_a_f32x4 = vfmaq_f32(norm_squared_a_f32x4, a_z_f32x4, a_z_f32x4);
+ norm_squared_b_f32x4 = vfmaq_f32(norm_squared_b_f32x4, b_x_f32x4, b_x_f32x4);
+ norm_squared_b_f32x4 = vfmaq_f32(norm_squared_b_f32x4, b_y_f32x4, b_y_f32x4);
+ norm_squared_b_f32x4 = vfmaq_f32(norm_squared_b_f32x4, b_z_f32x4, b_z_f32x4);
  }

- // Combine dual accumulators
- float32x4_t sum_a_x_f32x4 = vaddq_f32(sum_a_x_a_f32x4, sum_a_x_b_f32x4);
- float32x4_t sum_a_y_f32x4 = vaddq_f32(sum_a_y_a_f32x4, sum_a_y_b_f32x4);
- float32x4_t sum_a_z_f32x4 = vaddq_f32(sum_a_z_a_f32x4, sum_a_z_b_f32x4);
- float32x4_t sum_b_x_f32x4 = vaddq_f32(sum_b_x_a_f32x4, sum_b_x_b_f32x4);
- float32x4_t sum_b_y_f32x4 = vaddq_f32(sum_b_y_a_f32x4, sum_b_y_b_f32x4);
- float32x4_t sum_b_z_f32x4 = vaddq_f32(sum_b_z_a_f32x4, sum_b_z_b_f32x4);
- float32x4_t cov_xx_f32x4 = vaddq_f32(cov_xx_a_f32x4, cov_xx_b_f32x4);
- float32x4_t cov_xy_f32x4 = vaddq_f32(cov_xy_a_f32x4, cov_xy_b_f32x4);
- float32x4_t cov_xz_f32x4 = vaddq_f32(cov_xz_a_f32x4, cov_xz_b_f32x4);
- float32x4_t cov_yx_f32x4 = vaddq_f32(cov_yx_a_f32x4, cov_yx_b_f32x4);
- float32x4_t cov_yy_f32x4 = vaddq_f32(cov_yy_a_f32x4, cov_yy_b_f32x4);
- float32x4_t cov_yz_f32x4 = vaddq_f32(cov_yz_a_f32x4, cov_yz_b_f32x4);
- float32x4_t cov_zx_f32x4 = vaddq_f32(cov_zx_a_f32x4, cov_zx_b_f32x4);
- float32x4_t cov_zy_f32x4 = vaddq_f32(cov_zy_a_f32x4, cov_zy_b_f32x4);
- float32x4_t cov_zz_f32x4 = vaddq_f32(cov_zz_a_f32x4, cov_zz_b_f32x4);
-
  // Reduce vector accumulators
  nk_f32_t sum_a_x = vaddvq_f32(sum_a_x_f32x4);
  nk_f32_t sum_a_y = vaddvq_f32(sum_a_y_f32x4);
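
The `ones_bf16x8` trick in the loop above is worth isolating: dotting a bf16 vector against a vector of bf16 1.0 values (bit pattern 0x3F80) performs a pairwise horizontal add into f32 lanes. A minimal sketch, assuming a toolchain with the bf16 NEON intrinsics enabled; `vdupq_n_u16` stands in here for the diff's `nk_u16x8_splat_` wrapper:

#include <arm_neon.h>

/* Fold eight bf16 lanes into four f32 partial sums: acc[l] += v[2l] + v[2l+1].
 * 0x3F80 is bf16 1.0, so each BFDOT product is just the widened input value. */
float32x4_t sum_bf16x8_sketch(float32x4_t acc, bfloat16x8_t v) {
    bfloat16x8_t const ones = vreinterpretq_bf16_u16(vdupq_n_u16(0x3F80));
    return vbfdotq_f32(acc, v, ones);
}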
@@ -444,15 +257,17 @@ NK_PUBLIC void nk_kabsch_bf16_neonbfdot(nk_bf16_t const *a, nk_bf16_t const *b,
  nk_f32_t sum_b_y = vaddvq_f32(sum_b_y_f32x4);
  nk_f32_t sum_b_z = vaddvq_f32(sum_b_z_f32x4);

- nk_f32_t covariance_x_x = vaddvq_f32(cov_xx_f32x4);
- nk_f32_t covariance_x_y = vaddvq_f32(cov_xy_f32x4);
- nk_f32_t covariance_x_z = vaddvq_f32(cov_xz_f32x4);
- nk_f32_t covariance_y_x = vaddvq_f32(cov_yx_f32x4);
- nk_f32_t covariance_y_y = vaddvq_f32(cov_yy_f32x4);
- nk_f32_t covariance_y_z = vaddvq_f32(cov_yz_f32x4);
- nk_f32_t covariance_z_x = vaddvq_f32(cov_zx_f32x4);
- nk_f32_t covariance_z_y = vaddvq_f32(cov_zy_f32x4);
- nk_f32_t covariance_z_z = vaddvq_f32(cov_zz_f32x4);
+ nk_f32_t covariance_x_x = vaddvq_f32(covariance_xx_f32x4);
+ nk_f32_t covariance_x_y = vaddvq_f32(covariance_xy_f32x4);
+ nk_f32_t covariance_x_z = vaddvq_f32(covariance_xz_f32x4);
+ nk_f32_t covariance_y_x = vaddvq_f32(covariance_yx_f32x4);
+ nk_f32_t covariance_y_y = vaddvq_f32(covariance_yy_f32x4);
+ nk_f32_t covariance_y_z = vaddvq_f32(covariance_yz_f32x4);
+ nk_f32_t covariance_z_x = vaddvq_f32(covariance_zx_f32x4);
+ nk_f32_t covariance_z_y = vaddvq_f32(covariance_zy_f32x4);
+ nk_f32_t covariance_z_z = vaddvq_f32(covariance_zz_f32x4);
+ nk_f32_t norm_squared_a = vaddvq_f32(norm_squared_a_f32x4);
+ nk_f32_t norm_squared_b = vaddvq_f32(norm_squared_b_f32x4);

  // Compute centroids
  nk_f32_t inv_n = 1.0f / (nk_f32_t)n;
@@ -466,6 +281,16 @@ NK_PUBLIC void nk_kabsch_bf16_neonbfdot(nk_bf16_t const *a, nk_bf16_t const *b,
  if (a_centroid) a_centroid[0] = centroid_a_x, a_centroid[1] = centroid_a_y, a_centroid[2] = centroid_a_z;
  if (b_centroid) b_centroid[0] = centroid_b_x, b_centroid[1] = centroid_b_y, b_centroid[2] = centroid_b_z;

+ // Centered norm-squared via parallel-axis identity; clamp at zero for numeric safety.
+ nk_f32_t centered_norm_squared_a = norm_squared_a -
+ (nk_f32_t)n * (centroid_a_x * centroid_a_x + centroid_a_y * centroid_a_y +
+ centroid_a_z * centroid_a_z);
+ nk_f32_t centered_norm_squared_b = norm_squared_b -
+ (nk_f32_t)n * (centroid_b_x * centroid_b_x + centroid_b_y * centroid_b_y +
+ centroid_b_z * centroid_b_z);
+ if (centered_norm_squared_a < 0.0f) centered_norm_squared_a = 0.0f;
+ if (centered_norm_squared_b < 0.0f) centered_norm_squared_b = 0.0f;
+
  // Apply centering correction: H_centered = H - n * centroid_a * centroid_bᵀ
  covariance_x_x -= n * centroid_a_x * centroid_b_x;
  covariance_x_y -= n * centroid_a_x * centroid_b_y;
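
Both the clamp-guarded step and the centering correction in this hunk are instances of the same standard expansion, with ā = (1/n)·Σᵢ aᵢ and b̄ defined likewise:

\sum_{i=1}^{n} \lVert a_i - \bar a \rVert^2 = \sum_{i=1}^{n} \lVert a_i \rVert^2 - n \lVert \bar a \rVert^2,
\qquad
\sum_{i=1}^{n} (a_i - \bar a)(b_i - \bar b)^{\top} = \sum_{i=1}^{n} a_i b_i^{\top} - n\,\bar a\,\bar b^{\top}

In exact arithmetic the first left-hand side is non-negative, but after f32 cancellation the computed difference can dip slightly below zero, which is why the code clamps it.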
@@ -480,187 +305,186 @@ NK_PUBLIC void nk_kabsch_bf16_neonbfdot(nk_bf16_t const *a, nk_bf16_t const *b,
  // Compute SVD and optimal rotation
  nk_f32_t cross_covariance[9] = {covariance_x_x, covariance_x_y, covariance_x_z, covariance_y_x, covariance_y_y,
  covariance_y_z, covariance_z_x, covariance_z_y, covariance_z_z};
- nk_f32_t svd_u[9], svd_s[9], svd_v[9];
- nk_svd3x3_f32_(cross_covariance, svd_u, svd_s, svd_v);
-
- // R = V * Uᵀ
- nk_f32_t r[9];
- r[0] = svd_v[0] * svd_u[0] + svd_v[1] * svd_u[1] + svd_v[2] * svd_u[2];
- r[1] = svd_v[0] * svd_u[3] + svd_v[1] * svd_u[4] + svd_v[2] * svd_u[5];
- r[2] = svd_v[0] * svd_u[6] + svd_v[1] * svd_u[7] + svd_v[2] * svd_u[8];
- r[3] = svd_v[3] * svd_u[0] + svd_v[4] * svd_u[1] + svd_v[5] * svd_u[2];
- r[4] = svd_v[3] * svd_u[3] + svd_v[4] * svd_u[4] + svd_v[5] * svd_u[5];
- r[5] = svd_v[3] * svd_u[6] + svd_v[4] * svd_u[7] + svd_v[5] * svd_u[8];
- r[6] = svd_v[6] * svd_u[0] + svd_v[7] * svd_u[1] + svd_v[8] * svd_u[2];
- r[7] = svd_v[6] * svd_u[3] + svd_v[7] * svd_u[4] + svd_v[8] * svd_u[5];
- r[8] = svd_v[6] * svd_u[6] + svd_v[7] * svd_u[7] + svd_v[8] * svd_u[8];
-
- // Handle reflection: if det(R) < 0, negate third column of V and recompute R
- if (nk_det3x3_f32_(r) < 0) {
- svd_v[2] = -svd_v[2], svd_v[5] = -svd_v[5], svd_v[8] = -svd_v[8];
- r[0] = svd_v[0] * svd_u[0] + svd_v[1] * svd_u[1] + svd_v[2] * svd_u[2];
- r[1] = svd_v[0] * svd_u[3] + svd_v[1] * svd_u[4] + svd_v[2] * svd_u[5];
- r[2] = svd_v[0] * svd_u[6] + svd_v[1] * svd_u[7] + svd_v[2] * svd_u[8];
- r[3] = svd_v[3] * svd_u[0] + svd_v[4] * svd_u[1] + svd_v[5] * svd_u[2];
- r[4] = svd_v[3] * svd_u[3] + svd_v[4] * svd_u[4] + svd_v[5] * svd_u[5];
- r[5] = svd_v[3] * svd_u[6] + svd_v[4] * svd_u[7] + svd_v[5] * svd_u[8];
- r[6] = svd_v[6] * svd_u[0] + svd_v[7] * svd_u[1] + svd_v[8] * svd_u[2];
- r[7] = svd_v[6] * svd_u[3] + svd_v[7] * svd_u[4] + svd_v[8] * svd_u[5];
- r[8] = svd_v[6] * svd_u[6] + svd_v[7] * svd_u[7] + svd_v[8] * svd_u[8];
+
+ // Identity-dominant short-circuit: if H ≈ diag(positive entries), R = I and trace(R·H) = trace(H).
+ nk_f32_t covariance_diagonal_norm_squared = cross_covariance[0] * cross_covariance[0] +
+ cross_covariance[4] * cross_covariance[4] +
+ cross_covariance[8] * cross_covariance[8];
+ nk_f32_t covariance_offdiagonal_norm_squared =
+ cross_covariance[1] * cross_covariance[1] + cross_covariance[2] * cross_covariance[2] +
+ cross_covariance[3] * cross_covariance[3] + cross_covariance[5] * cross_covariance[5] +
+ cross_covariance[6] * cross_covariance[6] + cross_covariance[7] * cross_covariance[7];
+ nk_f32_t optimal_rotation[9];
+ nk_f32_t trace_rotation_covariance;
+ if (covariance_offdiagonal_norm_squared < 1e-12f * covariance_diagonal_norm_squared && cross_covariance[0] > 0.0f &&
+ cross_covariance[4] > 0.0f && cross_covariance[8] > 0.0f) {
+ optimal_rotation[0] = 1, optimal_rotation[1] = 0, optimal_rotation[2] = 0, optimal_rotation[3] = 0,
+ optimal_rotation[4] = 1, optimal_rotation[5] = 0, optimal_rotation[6] = 0, optimal_rotation[7] = 0,
+ optimal_rotation[8] = 1;
+ trace_rotation_covariance = cross_covariance[0] + cross_covariance[4] + cross_covariance[8];
+ }
+ else {
+ nk_f32_t svd_left[9], svd_diagonal[9], svd_right[9];
+ nk_svd3x3_f32_(cross_covariance, svd_left, svd_diagonal, svd_right);
+
+ // R = V * Uᵀ
+ optimal_rotation[0] = svd_right[0] * svd_left[0] + svd_right[1] * svd_left[1] + svd_right[2] * svd_left[2];
+ optimal_rotation[1] = svd_right[0] * svd_left[3] + svd_right[1] * svd_left[4] + svd_right[2] * svd_left[5];
+ optimal_rotation[2] = svd_right[0] * svd_left[6] + svd_right[1] * svd_left[7] + svd_right[2] * svd_left[8];
+ optimal_rotation[3] = svd_right[3] * svd_left[0] + svd_right[4] * svd_left[1] + svd_right[5] * svd_left[2];
+ optimal_rotation[4] = svd_right[3] * svd_left[3] + svd_right[4] * svd_left[4] + svd_right[5] * svd_left[5];
+ optimal_rotation[5] = svd_right[3] * svd_left[6] + svd_right[4] * svd_left[7] + svd_right[5] * svd_left[8];
+ optimal_rotation[6] = svd_right[6] * svd_left[0] + svd_right[7] * svd_left[1] + svd_right[8] * svd_left[2];
+ optimal_rotation[7] = svd_right[6] * svd_left[3] + svd_right[7] * svd_left[4] + svd_right[8] * svd_left[5];
+ optimal_rotation[8] = svd_right[6] * svd_left[6] + svd_right[7] * svd_left[7] + svd_right[8] * svd_left[8];
+
+ // Handle reflection: if det(R) < 0, negate third column of V and recompute R
+ if (nk_det3x3_f32_(optimal_rotation) < 0) {
+ svd_right[2] = -svd_right[2], svd_right[5] = -svd_right[5], svd_right[8] = -svd_right[8];
+ optimal_rotation[0] = svd_right[0] * svd_left[0] + svd_right[1] * svd_left[1] + svd_right[2] * svd_left[2];
+ optimal_rotation[1] = svd_right[0] * svd_left[3] + svd_right[1] * svd_left[4] + svd_right[2] * svd_left[5];
+ optimal_rotation[2] = svd_right[0] * svd_left[6] + svd_right[1] * svd_left[7] + svd_right[2] * svd_left[8];
+ optimal_rotation[3] = svd_right[3] * svd_left[0] + svd_right[4] * svd_left[1] + svd_right[5] * svd_left[2];
+ optimal_rotation[4] = svd_right[3] * svd_left[3] + svd_right[4] * svd_left[4] + svd_right[5] * svd_left[5];
+ optimal_rotation[5] = svd_right[3] * svd_left[6] + svd_right[4] * svd_left[7] + svd_right[5] * svd_left[8];
+ optimal_rotation[6] = svd_right[6] * svd_left[0] + svd_right[7] * svd_left[1] + svd_right[8] * svd_left[2];
+ optimal_rotation[7] = svd_right[6] * svd_left[3] + svd_right[7] * svd_left[4] + svd_right[8] * svd_left[5];
+ optimal_rotation[8] = svd_right[6] * svd_left[6] + svd_right[7] * svd_left[7] + svd_right[8] * svd_left[8];
+ }
+
+ trace_rotation_covariance =
+ optimal_rotation[0] * cross_covariance[0] + optimal_rotation[1] * cross_covariance[3] +
+ optimal_rotation[2] * cross_covariance[6] + optimal_rotation[3] * cross_covariance[1] +
+ optimal_rotation[4] * cross_covariance[4] + optimal_rotation[5] * cross_covariance[7] +
+ optimal_rotation[6] * cross_covariance[2] + optimal_rotation[7] * cross_covariance[5] +
+ optimal_rotation[8] * cross_covariance[8];
  }

  // Output rotation matrix and scale=1.0
  if (rotation)
- for (int j = 0; j < 9; ++j) rotation[j] = r[j];
+ for (int j = 0; j < 9; ++j) rotation[j] = optimal_rotation[j];
  if (scale) *scale = 1.0f;

- // Compute RMSD after optimal rotation
- nk_f32_t sum_squared = nk_transformed_ssd_bf16_neonbfdot_(a, b, n, r, 1.0f, centroid_a_x, centroid_a_y,
- centroid_a_z, centroid_b_x, centroid_b_y, centroid_b_z);
+ // Folded SSD via trace identity: SSD = ‖a-ā‖² + ‖b-b̄‖² − 2·trace(R · H_centered).
+ nk_f32_t sum_squared = centered_norm_squared_a + centered_norm_squared_b - 2.0f * trace_rotation_covariance;
+ if (sum_squared < 0.0f) sum_squared = 0.0f;
  *result = nk_f32_sqrt_neon(sum_squared * inv_n);
  }

  NK_PUBLIC void nk_umeyama_bf16_neonbfdot(nk_bf16_t const *a, nk_bf16_t const *b, nk_size_t n, nk_f32_t *a_centroid,
  nk_f32_t *b_centroid, nk_f32_t *rotation, nk_f32_t *scale, nk_f32_t *result) {
  float32x4_t const zeros_f32x4 = vdupq_n_f32(0);
-
- // 2x unrolling with dual accumulators to hide FMA latency.
- float32x4_t sum_a_x_a_f32x4 = zeros_f32x4, sum_a_y_a_f32x4 = zeros_f32x4, sum_a_z_a_f32x4 = zeros_f32x4;
- float32x4_t sum_b_x_a_f32x4 = zeros_f32x4, sum_b_y_a_f32x4 = zeros_f32x4, sum_b_z_a_f32x4 = zeros_f32x4;
- float32x4_t sum_a_x_b_f32x4 = zeros_f32x4, sum_a_y_b_f32x4 = zeros_f32x4, sum_a_z_b_f32x4 = zeros_f32x4;
- float32x4_t sum_b_x_b_f32x4 = zeros_f32x4, sum_b_y_b_f32x4 = zeros_f32x4, sum_b_z_b_f32x4 = zeros_f32x4;
-
- float32x4_t cov_xx_a_f32x4 = zeros_f32x4, cov_xy_a_f32x4 = zeros_f32x4, cov_xz_a_f32x4 = zeros_f32x4;
- float32x4_t cov_yx_a_f32x4 = zeros_f32x4, cov_yy_a_f32x4 = zeros_f32x4, cov_yz_a_f32x4 = zeros_f32x4;
- float32x4_t cov_zx_a_f32x4 = zeros_f32x4, cov_zy_a_f32x4 = zeros_f32x4, cov_zz_a_f32x4 = zeros_f32x4;
- float32x4_t cov_xx_b_f32x4 = zeros_f32x4, cov_xy_b_f32x4 = zeros_f32x4, cov_xz_b_f32x4 = zeros_f32x4;
- float32x4_t cov_yx_b_f32x4 = zeros_f32x4, cov_yy_b_f32x4 = zeros_f32x4, cov_yz_b_f32x4 = zeros_f32x4;
- float32x4_t cov_zx_b_f32x4 = zeros_f32x4, cov_zy_b_f32x4 = zeros_f32x4, cov_zz_b_f32x4 = zeros_f32x4;
-
- // Variance of A accumulators
- float32x4_t variance_a_a_f32x4 = zeros_f32x4;
- float32x4_t variance_a_b_f32x4 = zeros_f32x4;
+ // bf16 representation of 1.0 is 0x3F80; splatted across 8 lanes for BFDOT-based horizontal sums.
+ // The `nk_u16x8_splat_` wrapper prevents GCC from lowering this to `fmov v.8h, #1.0` (a FEAT_FP16
+ // encoding) that fails to assemble under a `+bf16`-only pragma.
+ bfloat16x8_t const ones_bf16x8 = vreinterpretq_bf16_u16(nk_u16x8_splat_(0x3F80));
+
+ // Centroid numerators, norm-squared, and 3x3 cross-covariance accumulators.
+ float32x4_t sum_a_x_f32x4 = zeros_f32x4, sum_a_y_f32x4 = zeros_f32x4, sum_a_z_f32x4 = zeros_f32x4;
+ float32x4_t sum_b_x_f32x4 = zeros_f32x4, sum_b_y_f32x4 = zeros_f32x4, sum_b_z_f32x4 = zeros_f32x4;
+ float32x4_t covariance_xx_f32x4 = zeros_f32x4, covariance_xy_f32x4 = zeros_f32x4, covariance_xz_f32x4 = zeros_f32x4;
+ float32x4_t covariance_yx_f32x4 = zeros_f32x4, covariance_yy_f32x4 = zeros_f32x4, covariance_yz_f32x4 = zeros_f32x4;
+ float32x4_t covariance_zx_f32x4 = zeros_f32x4, covariance_zy_f32x4 = zeros_f32x4, covariance_zz_f32x4 = zeros_f32x4;
+ float32x4_t norm_squared_a_f32x4 = zeros_f32x4, norm_squared_b_f32x4 = zeros_f32x4;

  nk_size_t i = 0;
- float32x4_t a1_x_f32x4, a1_y_f32x4, a1_z_f32x4, b1_x_f32x4, b1_y_f32x4, b1_z_f32x4;
- float32x4_t a2_x_f32x4, a2_y_f32x4, a2_z_f32x4, b2_x_f32x4, b2_y_f32x4, b2_z_f32x4;

- // Main loop: 8 points per iteration (2x unrolled)
+ // Main loop: 8 triplets per iteration via vld3q_u16 + vbfdotq_f32.
+ // Each vld3q_u16 de-interleaves 24 bf16 values into three 8-lane channel vectors.
+ // vbfdotq_f32(acc, p, q) computes per-32bit-lane: acc[l] += p[2l]*q[2l] + p[2l+1]*q[2l+1]
+ // with bf16 inputs and f32 accumulation. Summing the 4 lanes at the end yields the scalar.
  for (; i + 8 <= n; i += 8) {
- nk_deinterleave_bf16x4_to_f32x4_neonbfdot_(a + i * 3, &a1_x_f32x4, &a1_y_f32x4, &a1_z_f32x4);
- nk_deinterleave_bf16x4_to_f32x4_neonbfdot_(b + i * 3, &b1_x_f32x4, &b1_y_f32x4, &b1_z_f32x4);
- nk_deinterleave_bf16x4_to_f32x4_neonbfdot_(a + (i + 4) * 3, &a2_x_f32x4, &a2_y_f32x4, &a2_z_f32x4);
- nk_deinterleave_bf16x4_to_f32x4_neonbfdot_(b + (i + 4) * 3, &b2_x_f32x4, &b2_y_f32x4, &b2_z_f32x4);
-
- // Interleaved accumulation to hide FMA latency
- sum_a_x_a_f32x4 = vaddq_f32(sum_a_x_a_f32x4, a1_x_f32x4);
- sum_a_x_b_f32x4 = vaddq_f32(sum_a_x_b_f32x4, a2_x_f32x4);
- sum_a_y_a_f32x4 = vaddq_f32(sum_a_y_a_f32x4, a1_y_f32x4);
- sum_a_y_b_f32x4 = vaddq_f32(sum_a_y_b_f32x4, a2_y_f32x4);
- sum_a_z_a_f32x4 = vaddq_f32(sum_a_z_a_f32x4, a1_z_f32x4);
- sum_a_z_b_f32x4 = vaddq_f32(sum_a_z_b_f32x4, a2_z_f32x4);
- sum_b_x_a_f32x4 = vaddq_f32(sum_b_x_a_f32x4, b1_x_f32x4);
- sum_b_x_b_f32x4 = vaddq_f32(sum_b_x_b_f32x4, b2_x_f32x4);
- sum_b_y_a_f32x4 = vaddq_f32(sum_b_y_a_f32x4, b1_y_f32x4);
- sum_b_y_b_f32x4 = vaddq_f32(sum_b_y_b_f32x4, b2_y_f32x4);
- sum_b_z_a_f32x4 = vaddq_f32(sum_b_z_a_f32x4, b1_z_f32x4);
- sum_b_z_b_f32x4 = vaddq_f32(sum_b_z_b_f32x4, b2_z_f32x4);
-
- // Covariance matrix
- cov_xx_a_f32x4 = vfmaq_f32(cov_xx_a_f32x4, a1_x_f32x4, b1_x_f32x4);
- cov_xx_b_f32x4 = vfmaq_f32(cov_xx_b_f32x4, a2_x_f32x4, b2_x_f32x4);
- cov_xy_a_f32x4 = vfmaq_f32(cov_xy_a_f32x4, a1_x_f32x4, b1_y_f32x4);
- cov_xy_b_f32x4 = vfmaq_f32(cov_xy_b_f32x4, a2_x_f32x4, b2_y_f32x4);
- cov_xz_a_f32x4 = vfmaq_f32(cov_xz_a_f32x4, a1_x_f32x4, b1_z_f32x4);
- cov_xz_b_f32x4 = vfmaq_f32(cov_xz_b_f32x4, a2_x_f32x4, b2_z_f32x4);
- cov_yx_a_f32x4 = vfmaq_f32(cov_yx_a_f32x4, a1_y_f32x4, b1_x_f32x4);
- cov_yx_b_f32x4 = vfmaq_f32(cov_yx_b_f32x4, a2_y_f32x4, b2_x_f32x4);
- cov_yy_a_f32x4 = vfmaq_f32(cov_yy_a_f32x4, a1_y_f32x4, b1_y_f32x4);
- cov_yy_b_f32x4 = vfmaq_f32(cov_yy_b_f32x4, a2_y_f32x4, b2_y_f32x4);
- cov_yz_a_f32x4 = vfmaq_f32(cov_yz_a_f32x4, a1_y_f32x4, b1_z_f32x4);
- cov_yz_b_f32x4 = vfmaq_f32(cov_yz_b_f32x4, a2_y_f32x4, b2_z_f32x4);
- cov_zx_a_f32x4 = vfmaq_f32(cov_zx_a_f32x4, a1_z_f32x4, b1_x_f32x4);
- cov_zx_b_f32x4 = vfmaq_f32(cov_zx_b_f32x4, a2_z_f32x4, b2_x_f32x4);
- cov_zy_a_f32x4 = vfmaq_f32(cov_zy_a_f32x4, a1_z_f32x4, b1_y_f32x4);
- cov_zy_b_f32x4 = vfmaq_f32(cov_zy_b_f32x4, a2_z_f32x4, b2_y_f32x4);
- cov_zz_a_f32x4 = vfmaq_f32(cov_zz_a_f32x4, a1_z_f32x4, b1_z_f32x4);
- cov_zz_b_f32x4 = vfmaq_f32(cov_zz_b_f32x4, a2_z_f32x4, b2_z_f32x4);
-
- // Variance of A
- variance_a_a_f32x4 = vfmaq_f32(variance_a_a_f32x4, a1_x_f32x4, a1_x_f32x4);
- variance_a_b_f32x4 = vfmaq_f32(variance_a_b_f32x4, a2_x_f32x4, a2_x_f32x4);
- variance_a_a_f32x4 = vfmaq_f32(variance_a_a_f32x4, a1_y_f32x4, a1_y_f32x4);
- variance_a_b_f32x4 = vfmaq_f32(variance_a_b_f32x4, a2_y_f32x4, a2_y_f32x4);
- variance_a_a_f32x4 = vfmaq_f32(variance_a_a_f32x4, a1_z_f32x4, a1_z_f32x4);
- variance_a_b_f32x4 = vfmaq_f32(variance_a_b_f32x4, a2_z_f32x4, a2_z_f32x4);
+ uint16x8x3_t a_xyz_u16x8x3 = vld3q_u16((nk_u16_t const *)(a + i * 3));
+ uint16x8x3_t b_xyz_u16x8x3 = vld3q_u16((nk_u16_t const *)(b + i * 3));
+ bfloat16x8_t a_x_bf16x8 = vreinterpretq_bf16_u16(a_xyz_u16x8x3.val[0]);
+ bfloat16x8_t a_y_bf16x8 = vreinterpretq_bf16_u16(a_xyz_u16x8x3.val[1]);
+ bfloat16x8_t a_z_bf16x8 = vreinterpretq_bf16_u16(a_xyz_u16x8x3.val[2]);
+ bfloat16x8_t b_x_bf16x8 = vreinterpretq_bf16_u16(b_xyz_u16x8x3.val[0]);
+ bfloat16x8_t b_y_bf16x8 = vreinterpretq_bf16_u16(b_xyz_u16x8x3.val[1]);
+ bfloat16x8_t b_z_bf16x8 = vreinterpretq_bf16_u16(b_xyz_u16x8x3.val[2]);
+
+ // Centroid numerators: Σ channel values (pairwise via BFDOT against bf16 1.0).
+ sum_a_x_f32x4 = vbfdotq_f32(sum_a_x_f32x4, a_x_bf16x8, ones_bf16x8);
+ sum_a_y_f32x4 = vbfdotq_f32(sum_a_y_f32x4, a_y_bf16x8, ones_bf16x8);
+ sum_a_z_f32x4 = vbfdotq_f32(sum_a_z_f32x4, a_z_bf16x8, ones_bf16x8);
+ sum_b_x_f32x4 = vbfdotq_f32(sum_b_x_f32x4, b_x_bf16x8, ones_bf16x8);
+ sum_b_y_f32x4 = vbfdotq_f32(sum_b_y_f32x4, b_y_bf16x8, ones_bf16x8);
+ sum_b_z_f32x4 = vbfdotq_f32(sum_b_z_f32x4, b_z_bf16x8, ones_bf16x8);
+
+ // 3x3 cross-covariance H cells: Σ a_j · b_k.
+ covariance_xx_f32x4 = vbfdotq_f32(covariance_xx_f32x4, a_x_bf16x8, b_x_bf16x8);
+ covariance_xy_f32x4 = vbfdotq_f32(covariance_xy_f32x4, a_x_bf16x8, b_y_bf16x8);
+ covariance_xz_f32x4 = vbfdotq_f32(covariance_xz_f32x4, a_x_bf16x8, b_z_bf16x8);
+ covariance_yx_f32x4 = vbfdotq_f32(covariance_yx_f32x4, a_y_bf16x8, b_x_bf16x8);
+ covariance_yy_f32x4 = vbfdotq_f32(covariance_yy_f32x4, a_y_bf16x8, b_y_bf16x8);
+ covariance_yz_f32x4 = vbfdotq_f32(covariance_yz_f32x4, a_y_bf16x8, b_z_bf16x8);
+ covariance_zx_f32x4 = vbfdotq_f32(covariance_zx_f32x4, a_z_bf16x8, b_x_bf16x8);
+ covariance_zy_f32x4 = vbfdotq_f32(covariance_zy_f32x4, a_z_bf16x8, b_y_bf16x8);
+ covariance_zz_f32x4 = vbfdotq_f32(covariance_zz_f32x4, a_z_bf16x8, b_z_bf16x8);
+
+ // Norm-squared per point set: Σ (x² + y² + z²).
426
+ norm_squared_a_f32x4 = vbfdotq_f32(norm_squared_a_f32x4, a_x_bf16x8, a_x_bf16x8);
427
+ norm_squared_a_f32x4 = vbfdotq_f32(norm_squared_a_f32x4, a_y_bf16x8, a_y_bf16x8);
428
+ norm_squared_a_f32x4 = vbfdotq_f32(norm_squared_a_f32x4, a_z_bf16x8, a_z_bf16x8);
429
+ norm_squared_b_f32x4 = vbfdotq_f32(norm_squared_b_f32x4, b_x_bf16x8, b_x_bf16x8);
430
+ norm_squared_b_f32x4 = vbfdotq_f32(norm_squared_b_f32x4, b_y_bf16x8, b_y_bf16x8);
431
+ norm_squared_b_f32x4 = vbfdotq_f32(norm_squared_b_f32x4, b_z_bf16x8, b_z_bf16x8);
596
432
  }
597
433
 
598
- // 4-point tail
434
+ // 4-point and partial (1-3) tails: keep the widen+fmaq path on f32x4 channel vectors.
435
+ // These branches run at most once each, so we skip another vbfdotq variant for them.
436
+ float32x4_t a_x_f32x4, a_y_f32x4, a_z_f32x4, b_x_f32x4, b_y_f32x4, b_z_f32x4;
599
437
  for (; i + 4 <= n; i += 4) {
600
- nk_deinterleave_bf16x4_to_f32x4_neonbfdot_(a + i * 3, &a1_x_f32x4, &a1_y_f32x4, &a1_z_f32x4);
601
- nk_deinterleave_bf16x4_to_f32x4_neonbfdot_(b + i * 3, &b1_x_f32x4, &b1_y_f32x4, &b1_z_f32x4);
602
- sum_a_x_a_f32x4 = vaddq_f32(sum_a_x_a_f32x4, a1_x_f32x4);
603
- sum_a_y_a_f32x4 = vaddq_f32(sum_a_y_a_f32x4, a1_y_f32x4);
604
- sum_a_z_a_f32x4 = vaddq_f32(sum_a_z_a_f32x4, a1_z_f32x4);
605
- sum_b_x_a_f32x4 = vaddq_f32(sum_b_x_a_f32x4, b1_x_f32x4);
606
- sum_b_y_a_f32x4 = vaddq_f32(sum_b_y_a_f32x4, b1_y_f32x4);
607
- sum_b_z_a_f32x4 = vaddq_f32(sum_b_z_a_f32x4, b1_z_f32x4);
608
- cov_xx_a_f32x4 = vfmaq_f32(cov_xx_a_f32x4, a1_x_f32x4, b1_x_f32x4);
609
- cov_xy_a_f32x4 = vfmaq_f32(cov_xy_a_f32x4, a1_x_f32x4, b1_y_f32x4);
610
- cov_xz_a_f32x4 = vfmaq_f32(cov_xz_a_f32x4, a1_x_f32x4, b1_z_f32x4);
611
- cov_yx_a_f32x4 = vfmaq_f32(cov_yx_a_f32x4, a1_y_f32x4, b1_x_f32x4);
612
- cov_yy_a_f32x4 = vfmaq_f32(cov_yy_a_f32x4, a1_y_f32x4, b1_y_f32x4);
613
- cov_yz_a_f32x4 = vfmaq_f32(cov_yz_a_f32x4, a1_y_f32x4, b1_z_f32x4);
614
- cov_zx_a_f32x4 = vfmaq_f32(cov_zx_a_f32x4, a1_z_f32x4, b1_x_f32x4);
615
- cov_zy_a_f32x4 = vfmaq_f32(cov_zy_a_f32x4, a1_z_f32x4, b1_y_f32x4);
616
- cov_zz_a_f32x4 = vfmaq_f32(cov_zz_a_f32x4, a1_z_f32x4, b1_z_f32x4);
617
- variance_a_a_f32x4 = vfmaq_f32(variance_a_a_f32x4, a1_x_f32x4, a1_x_f32x4);
618
- variance_a_a_f32x4 = vfmaq_f32(variance_a_a_f32x4, a1_y_f32x4, a1_y_f32x4);
619
- variance_a_a_f32x4 = vfmaq_f32(variance_a_a_f32x4, a1_z_f32x4, a1_z_f32x4);
438
+ nk_deinterleave_bf16x4_to_f32x4_neonbfdot_(a + i * 3, &a_x_f32x4, &a_y_f32x4, &a_z_f32x4);
439
+ nk_deinterleave_bf16x4_to_f32x4_neonbfdot_(b + i * 3, &b_x_f32x4, &b_y_f32x4, &b_z_f32x4);
440
+ sum_a_x_f32x4 = vaddq_f32(sum_a_x_f32x4, a_x_f32x4);
441
+ sum_a_y_f32x4 = vaddq_f32(sum_a_y_f32x4, a_y_f32x4);
442
+ sum_a_z_f32x4 = vaddq_f32(sum_a_z_f32x4, a_z_f32x4);
443
+ sum_b_x_f32x4 = vaddq_f32(sum_b_x_f32x4, b_x_f32x4);
444
+ sum_b_y_f32x4 = vaddq_f32(sum_b_y_f32x4, b_y_f32x4);
445
+ sum_b_z_f32x4 = vaddq_f32(sum_b_z_f32x4, b_z_f32x4);
446
+ covariance_xx_f32x4 = vfmaq_f32(covariance_xx_f32x4, a_x_f32x4, b_x_f32x4);
447
+ covariance_xy_f32x4 = vfmaq_f32(covariance_xy_f32x4, a_x_f32x4, b_y_f32x4);
448
+ covariance_xz_f32x4 = vfmaq_f32(covariance_xz_f32x4, a_x_f32x4, b_z_f32x4);
449
+ covariance_yx_f32x4 = vfmaq_f32(covariance_yx_f32x4, a_y_f32x4, b_x_f32x4);
450
+ covariance_yy_f32x4 = vfmaq_f32(covariance_yy_f32x4, a_y_f32x4, b_y_f32x4);
451
+ covariance_yz_f32x4 = vfmaq_f32(covariance_yz_f32x4, a_y_f32x4, b_z_f32x4);
452
+ covariance_zx_f32x4 = vfmaq_f32(covariance_zx_f32x4, a_z_f32x4, b_x_f32x4);
453
+ covariance_zy_f32x4 = vfmaq_f32(covariance_zy_f32x4, a_z_f32x4, b_y_f32x4);
454
+ covariance_zz_f32x4 = vfmaq_f32(covariance_zz_f32x4, a_z_f32x4, b_z_f32x4);
455
+ norm_squared_a_f32x4 = vfmaq_f32(norm_squared_a_f32x4, a_x_f32x4, a_x_f32x4);
456
+ norm_squared_a_f32x4 = vfmaq_f32(norm_squared_a_f32x4, a_y_f32x4, a_y_f32x4);
457
+ norm_squared_a_f32x4 = vfmaq_f32(norm_squared_a_f32x4, a_z_f32x4, a_z_f32x4);
458
+ norm_squared_b_f32x4 = vfmaq_f32(norm_squared_b_f32x4, b_x_f32x4, b_x_f32x4);
459
+ norm_squared_b_f32x4 = vfmaq_f32(norm_squared_b_f32x4, b_y_f32x4, b_y_f32x4);
460
+ norm_squared_b_f32x4 = vfmaq_f32(norm_squared_b_f32x4, b_z_f32x4, b_z_f32x4);
620
461
  }
-
- // Partial tail: handle remaining 1-3 points with vectorized partial deinterleave
  if (i < n) {
- nk_partial_deinterleave_bf16_to_f32x4_neonbfdot_(a + i * 3, n - i, &a1_x_f32x4, &a1_y_f32x4, &a1_z_f32x4);
- nk_partial_deinterleave_bf16_to_f32x4_neonbfdot_(b + i * 3, n - i, &b1_x_f32x4, &b1_y_f32x4, &b1_z_f32x4);
- sum_a_x_a_f32x4 = vaddq_f32(sum_a_x_a_f32x4, a1_x_f32x4);
- sum_a_y_a_f32x4 = vaddq_f32(sum_a_y_a_f32x4, a1_y_f32x4);
- sum_a_z_a_f32x4 = vaddq_f32(sum_a_z_a_f32x4, a1_z_f32x4);
- sum_b_x_a_f32x4 = vaddq_f32(sum_b_x_a_f32x4, b1_x_f32x4);
- sum_b_y_a_f32x4 = vaddq_f32(sum_b_y_a_f32x4, b1_y_f32x4);
- sum_b_z_a_f32x4 = vaddq_f32(sum_b_z_a_f32x4, b1_z_f32x4);
- cov_xx_a_f32x4 = vfmaq_f32(cov_xx_a_f32x4, a1_x_f32x4, b1_x_f32x4);
- cov_xy_a_f32x4 = vfmaq_f32(cov_xy_a_f32x4, a1_x_f32x4, b1_y_f32x4);
- cov_xz_a_f32x4 = vfmaq_f32(cov_xz_a_f32x4, a1_x_f32x4, b1_z_f32x4);
- cov_yx_a_f32x4 = vfmaq_f32(cov_yx_a_f32x4, a1_y_f32x4, b1_x_f32x4);
- cov_yy_a_f32x4 = vfmaq_f32(cov_yy_a_f32x4, a1_y_f32x4, b1_y_f32x4);
- cov_yz_a_f32x4 = vfmaq_f32(cov_yz_a_f32x4, a1_y_f32x4, b1_z_f32x4);
- cov_zx_a_f32x4 = vfmaq_f32(cov_zx_a_f32x4, a1_z_f32x4, b1_x_f32x4);
- cov_zy_a_f32x4 = vfmaq_f32(cov_zy_a_f32x4, a1_z_f32x4, b1_y_f32x4);
- cov_zz_a_f32x4 = vfmaq_f32(cov_zz_a_f32x4, a1_z_f32x4, b1_z_f32x4);
- variance_a_a_f32x4 = vfmaq_f32(variance_a_a_f32x4, a1_x_f32x4, a1_x_f32x4);
- variance_a_a_f32x4 = vfmaq_f32(variance_a_a_f32x4, a1_y_f32x4, a1_y_f32x4);
- variance_a_a_f32x4 = vfmaq_f32(variance_a_a_f32x4, a1_z_f32x4, a1_z_f32x4);
+ nk_partial_deinterleave_bf16_to_f32x4_neonbfdot_(a + i * 3, n - i, &a_x_f32x4, &a_y_f32x4, &a_z_f32x4);
+ nk_partial_deinterleave_bf16_to_f32x4_neonbfdot_(b + i * 3, n - i, &b_x_f32x4, &b_y_f32x4, &b_z_f32x4);
+ sum_a_x_f32x4 = vaddq_f32(sum_a_x_f32x4, a_x_f32x4);
+ sum_a_y_f32x4 = vaddq_f32(sum_a_y_f32x4, a_y_f32x4);
+ sum_a_z_f32x4 = vaddq_f32(sum_a_z_f32x4, a_z_f32x4);
+ sum_b_x_f32x4 = vaddq_f32(sum_b_x_f32x4, b_x_f32x4);
+ sum_b_y_f32x4 = vaddq_f32(sum_b_y_f32x4, b_y_f32x4);
+ sum_b_z_f32x4 = vaddq_f32(sum_b_z_f32x4, b_z_f32x4);
+ covariance_xx_f32x4 = vfmaq_f32(covariance_xx_f32x4, a_x_f32x4, b_x_f32x4);
+ covariance_xy_f32x4 = vfmaq_f32(covariance_xy_f32x4, a_x_f32x4, b_y_f32x4);
+ covariance_xz_f32x4 = vfmaq_f32(covariance_xz_f32x4, a_x_f32x4, b_z_f32x4);
+ covariance_yx_f32x4 = vfmaq_f32(covariance_yx_f32x4, a_y_f32x4, b_x_f32x4);
+ covariance_yy_f32x4 = vfmaq_f32(covariance_yy_f32x4, a_y_f32x4, b_y_f32x4);
+ covariance_yz_f32x4 = vfmaq_f32(covariance_yz_f32x4, a_y_f32x4, b_z_f32x4);
+ covariance_zx_f32x4 = vfmaq_f32(covariance_zx_f32x4, a_z_f32x4, b_x_f32x4);
+ covariance_zy_f32x4 = vfmaq_f32(covariance_zy_f32x4, a_z_f32x4, b_y_f32x4);
+ covariance_zz_f32x4 = vfmaq_f32(covariance_zz_f32x4, a_z_f32x4, b_z_f32x4);
+ norm_squared_a_f32x4 = vfmaq_f32(norm_squared_a_f32x4, a_x_f32x4, a_x_f32x4);
+ norm_squared_a_f32x4 = vfmaq_f32(norm_squared_a_f32x4, a_y_f32x4, a_y_f32x4);
+ norm_squared_a_f32x4 = vfmaq_f32(norm_squared_a_f32x4, a_z_f32x4, a_z_f32x4);
+ norm_squared_b_f32x4 = vfmaq_f32(norm_squared_b_f32x4, b_x_f32x4, b_x_f32x4);
+ norm_squared_b_f32x4 = vfmaq_f32(norm_squared_b_f32x4, b_y_f32x4, b_y_f32x4);
+ norm_squared_b_f32x4 = vfmaq_f32(norm_squared_b_f32x4, b_z_f32x4, b_z_f32x4);
  }
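
For the vaddq_f32/vfmaq_f32 updates above to stay correct on the tail, lanes past the remaining point count must contribute exactly zero. A scalar model of that contract, using plain float in place of bf16 (hypothetical helper, illustrative only):

// Partial 3-channel deinterleave: zero-fill all four lanes, then copy only
// the `remaining` (1..3) interleaved points, so padded lanes add nothing to
// the sums and FMAs downstream.
static void partial_deinterleave_f32(float const *xyz, int remaining,
                                     float x[4], float y[4], float z[4]) {
    for (int lane = 0; lane < 4; ++lane) x[lane] = y[lane] = z[lane] = 0.0f;
    for (int lane = 0; lane < remaining; ++lane) {
        x[lane] = xyz[lane * 3 + 0];
        y[lane] = xyz[lane * 3 + 1];
        z[lane] = xyz[lane * 3 + 2];
    }
}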
 
- // Combine dual accumulators
- float32x4_t sum_a_x_f32x4 = vaddq_f32(sum_a_x_a_f32x4, sum_a_x_b_f32x4);
- float32x4_t sum_a_y_f32x4 = vaddq_f32(sum_a_y_a_f32x4, sum_a_y_b_f32x4);
- float32x4_t sum_a_z_f32x4 = vaddq_f32(sum_a_z_a_f32x4, sum_a_z_b_f32x4);
- float32x4_t sum_b_x_f32x4 = vaddq_f32(sum_b_x_a_f32x4, sum_b_x_b_f32x4);
- float32x4_t sum_b_y_f32x4 = vaddq_f32(sum_b_y_a_f32x4, sum_b_y_b_f32x4);
- float32x4_t sum_b_z_f32x4 = vaddq_f32(sum_b_z_a_f32x4, sum_b_z_b_f32x4);
- float32x4_t cov_xx_f32x4 = vaddq_f32(cov_xx_a_f32x4, cov_xx_b_f32x4);
- float32x4_t cov_xy_f32x4 = vaddq_f32(cov_xy_a_f32x4, cov_xy_b_f32x4);
- float32x4_t cov_xz_f32x4 = vaddq_f32(cov_xz_a_f32x4, cov_xz_b_f32x4);
- float32x4_t cov_yx_f32x4 = vaddq_f32(cov_yx_a_f32x4, cov_yx_b_f32x4);
- float32x4_t cov_yy_f32x4 = vaddq_f32(cov_yy_a_f32x4, cov_yy_b_f32x4);
- float32x4_t cov_yz_f32x4 = vaddq_f32(cov_yz_a_f32x4, cov_yz_b_f32x4);
- float32x4_t cov_zx_f32x4 = vaddq_f32(cov_zx_a_f32x4, cov_zx_b_f32x4);
- float32x4_t cov_zy_f32x4 = vaddq_f32(cov_zy_a_f32x4, cov_zy_b_f32x4);
- float32x4_t cov_zz_f32x4 = vaddq_f32(cov_zz_a_f32x4, cov_zz_b_f32x4);
- float32x4_t variance_a_f32x4 = vaddq_f32(variance_a_a_f32x4, variance_a_b_f32x4);
-
  // Reduce vector accumulators
  nk_f32_t sum_a_x = vaddvq_f32(sum_a_x_f32x4);
  nk_f32_t sum_a_y = vaddvq_f32(sum_a_y_f32x4);
@@ -669,16 +493,17 @@ NK_PUBLIC void nk_umeyama_bf16_neonbfdot(nk_bf16_t const *a, nk_bf16_t const *b,
  nk_f32_t sum_b_y = vaddvq_f32(sum_b_y_f32x4);
  nk_f32_t sum_b_z = vaddvq_f32(sum_b_z_f32x4);
 
- nk_f32_t covariance_x_x = vaddvq_f32(cov_xx_f32x4);
- nk_f32_t covariance_x_y = vaddvq_f32(cov_xy_f32x4);
- nk_f32_t covariance_x_z = vaddvq_f32(cov_xz_f32x4);
- nk_f32_t covariance_y_x = vaddvq_f32(cov_yx_f32x4);
- nk_f32_t covariance_y_y = vaddvq_f32(cov_yy_f32x4);
- nk_f32_t covariance_y_z = vaddvq_f32(cov_yz_f32x4);
- nk_f32_t covariance_z_x = vaddvq_f32(cov_zx_f32x4);
- nk_f32_t covariance_z_y = vaddvq_f32(cov_zy_f32x4);
- nk_f32_t covariance_z_z = vaddvq_f32(cov_zz_f32x4);
- nk_f32_t variance_a_sum = vaddvq_f32(variance_a_f32x4);
+ nk_f32_t covariance_x_x = vaddvq_f32(covariance_xx_f32x4);
+ nk_f32_t covariance_x_y = vaddvq_f32(covariance_xy_f32x4);
+ nk_f32_t covariance_x_z = vaddvq_f32(covariance_xz_f32x4);
+ nk_f32_t covariance_y_x = vaddvq_f32(covariance_yx_f32x4);
+ nk_f32_t covariance_y_y = vaddvq_f32(covariance_yy_f32x4);
+ nk_f32_t covariance_y_z = vaddvq_f32(covariance_yz_f32x4);
+ nk_f32_t covariance_z_x = vaddvq_f32(covariance_zx_f32x4);
+ nk_f32_t covariance_z_y = vaddvq_f32(covariance_zy_f32x4);
+ nk_f32_t covariance_z_z = vaddvq_f32(covariance_zz_f32x4);
+ nk_f32_t norm_squared_a = vaddvq_f32(norm_squared_a_f32x4);
+ nk_f32_t norm_squared_b = vaddvq_f32(norm_squared_b_f32x4);
 
  // Compute centroids
  nk_f32_t inv_n = 1.0f / (nk_f32_t)n;
@@ -692,9 +517,15 @@ NK_PUBLIC void nk_umeyama_bf16_neonbfdot(nk_bf16_t const *a, nk_bf16_t const *b,
  if (a_centroid) a_centroid[0] = centroid_a_x, a_centroid[1] = centroid_a_y, a_centroid[2] = centroid_a_z;
  if (b_centroid) b_centroid[0] = centroid_b_x, b_centroid[1] = centroid_b_y, b_centroid[2] = centroid_b_z;
 
- // Compute centered variance of A
- nk_f32_t variance_a = variance_a_sum * inv_n -
- (centroid_a_x * centroid_a_x + centroid_a_y * centroid_a_y + centroid_a_z * centroid_a_z);
+ // Centered norm-squared via parallel-axis identity; clamp at zero for numeric safety.
+ nk_f32_t centered_norm_squared_a = norm_squared_a -
+ (nk_f32_t)n * (centroid_a_x * centroid_a_x + centroid_a_y * centroid_a_y +
+ centroid_a_z * centroid_a_z);
+ nk_f32_t centered_norm_squared_b = norm_squared_b -
+ (nk_f32_t)n * (centroid_b_x * centroid_b_x + centroid_b_y * centroid_b_y +
+ centroid_b_z * centroid_b_z);
+ if (centered_norm_squared_a < 0.0f) centered_norm_squared_a = 0.0f;
+ if (centered_norm_squared_b < 0.0f) centered_norm_squared_b = 0.0f;
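
The parallel-axis identity the new comment invokes relates uncentered and centered sums of squares:

\sum_{i=1}^{n} \lVert a_i - \bar a \rVert^2 \;=\; \sum_{i=1}^{n} \lVert a_i \rVert^2 \;-\; n \, \lVert \bar a \rVert^2

so the centered norm-squared falls out of the already-accumulated norm_squared_a (likewise norm_squared_b) and the centroid, with no extra pass over the points. Floating-point cancellation can push the difference slightly below zero, hence the clamps just above.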
 
  // Apply centering correction: H_centered = H - n * centroid_a * centroid_bᵀ
  covariance_x_x -= n * centroid_a_x * centroid_b_x;
@@ -710,49 +541,78 @@ NK_PUBLIC void nk_umeyama_bf16_neonbfdot(nk_bf16_t const *a, nk_bf16_t const *b,
  // Compute SVD
  nk_f32_t cross_covariance[9] = {covariance_x_x, covariance_x_y, covariance_x_z, covariance_y_x, covariance_y_y,
  covariance_y_z, covariance_z_x, covariance_z_y, covariance_z_z};
- nk_f32_t svd_u[9], svd_s[9], svd_v[9];
- nk_svd3x3_f32_(cross_covariance, svd_u, svd_s, svd_v);
-
- // R = V * Uᵀ
- nk_f32_t r[9];
- r[0] = svd_v[0] * svd_u[0] + svd_v[1] * svd_u[1] + svd_v[2] * svd_u[2];
- r[1] = svd_v[0] * svd_u[3] + svd_v[1] * svd_u[4] + svd_v[2] * svd_u[5];
- r[2] = svd_v[0] * svd_u[6] + svd_v[1] * svd_u[7] + svd_v[2] * svd_u[8];
- r[3] = svd_v[3] * svd_u[0] + svd_v[4] * svd_u[1] + svd_v[5] * svd_u[2];
- r[4] = svd_v[3] * svd_u[3] + svd_v[4] * svd_u[4] + svd_v[5] * svd_u[5];
- r[5] = svd_v[3] * svd_u[6] + svd_v[4] * svd_u[7] + svd_v[5] * svd_u[8];
- r[6] = svd_v[6] * svd_u[0] + svd_v[7] * svd_u[1] + svd_v[8] * svd_u[2];
- r[7] = svd_v[6] * svd_u[3] + svd_v[7] * svd_u[4] + svd_v[8] * svd_u[5];
- r[8] = svd_v[6] * svd_u[6] + svd_v[7] * svd_u[7] + svd_v[8] * svd_u[8];
-
- // Handle reflection and compute scale: c = trace(D × S) / variance(a)
- // D = diag(1, 1, det(R)), svd_s contains proper positive singular values on diagonal
- nk_f32_t rotation_det = nk_det3x3_f32_(r);
- nk_f32_t sign_det = rotation_det < 0 ? -1.0f : 1.0f;
- nk_f32_t trace_scaled_s = svd_s[0] + svd_s[4] + sign_det * svd_s[8];
- nk_f32_t c = trace_scaled_s / ((nk_f32_t)n * variance_a);
- if (scale) *scale = c;
 
- if (rotation_det < 0) {
- svd_v[2] = -svd_v[2], svd_v[5] = -svd_v[5], svd_v[8] = -svd_v[8];
- r[0] = svd_v[0] * svd_u[0] + svd_v[1] * svd_u[1] + svd_v[2] * svd_u[2];
- r[1] = svd_v[0] * svd_u[3] + svd_v[1] * svd_u[4] + svd_v[2] * svd_u[5];
- r[2] = svd_v[0] * svd_u[6] + svd_v[1] * svd_u[7] + svd_v[2] * svd_u[8];
- r[3] = svd_v[3] * svd_u[0] + svd_v[4] * svd_u[1] + svd_v[5] * svd_u[2];
- r[4] = svd_v[3] * svd_u[3] + svd_v[4] * svd_u[4] + svd_v[5] * svd_u[5];
- r[5] = svd_v[3] * svd_u[6] + svd_v[4] * svd_u[7] + svd_v[5] * svd_u[8];
- r[6] = svd_v[6] * svd_u[0] + svd_v[7] * svd_u[1] + svd_v[8] * svd_u[2];
- r[7] = svd_v[6] * svd_u[3] + svd_v[7] * svd_u[4] + svd_v[8] * svd_u[5];
- r[8] = svd_v[6] * svd_u[6] + svd_v[7] * svd_u[7] + svd_v[8] * svd_u[8];
+ // Identity-dominant short-circuit: if H ≈ diag(positive entries), R = I and trace(R·H) = trace(H).
+ nk_f32_t covariance_diagonal_norm_squared = cross_covariance[0] * cross_covariance[0] +
+ cross_covariance[4] * cross_covariance[4] +
+ cross_covariance[8] * cross_covariance[8];
+ nk_f32_t covariance_offdiagonal_norm_squared =
+ cross_covariance[1] * cross_covariance[1] + cross_covariance[2] * cross_covariance[2] +
+ cross_covariance[3] * cross_covariance[3] + cross_covariance[5] * cross_covariance[5] +
+ cross_covariance[6] * cross_covariance[6] + cross_covariance[7] * cross_covariance[7];
+ nk_f32_t optimal_rotation[9];
+ nk_f32_t trace_rotation_covariance;
+ nk_f32_t c;
+ if (covariance_offdiagonal_norm_squared < 1e-12f * covariance_diagonal_norm_squared && cross_covariance[0] > 0.0f &&
+ cross_covariance[4] > 0.0f && cross_covariance[8] > 0.0f) {
+ optimal_rotation[0] = 1, optimal_rotation[1] = 0, optimal_rotation[2] = 0, optimal_rotation[3] = 0,
+ optimal_rotation[4] = 1, optimal_rotation[5] = 0, optimal_rotation[6] = 0, optimal_rotation[7] = 0,
+ optimal_rotation[8] = 1;
+ trace_rotation_covariance = cross_covariance[0] + cross_covariance[4] + cross_covariance[8];
+ c = centered_norm_squared_a > 0.0f ? trace_rotation_covariance / centered_norm_squared_a : 0.0f;
  }
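
Why the short-circuit is sound: if H = diag(d1, d2, d3) with every di > 0, one valid SVD is U = V = I with S = H, so R = V·Uᵀ = I, det(R) = 1, and trace(R·H) = d1 + d2 + d3, which is exactly what the branch assigns. The guard admits only matrices whose off-diagonal Frobenius mass is negligible relative to the diagonal:

\sum_{r \neq c} H_{rc}^2 \;<\; 10^{-12} \left( H_{00}^2 + H_{11}^2 + H_{22}^2 \right)

covering the common pre-aligned case while skipping the iterative 3×3 SVD entirely.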
+ else {
+ nk_f32_t svd_left[9], svd_diagonal[9], svd_right[9];
+ nk_svd3x3_f32_(cross_covariance, svd_left, svd_diagonal, svd_right);
+
+ // R = V * Uᵀ
+ optimal_rotation[0] = svd_right[0] * svd_left[0] + svd_right[1] * svd_left[1] + svd_right[2] * svd_left[2];
+ optimal_rotation[1] = svd_right[0] * svd_left[3] + svd_right[1] * svd_left[4] + svd_right[2] * svd_left[5];
+ optimal_rotation[2] = svd_right[0] * svd_left[6] + svd_right[1] * svd_left[7] + svd_right[2] * svd_left[8];
+ optimal_rotation[3] = svd_right[3] * svd_left[0] + svd_right[4] * svd_left[1] + svd_right[5] * svd_left[2];
+ optimal_rotation[4] = svd_right[3] * svd_left[3] + svd_right[4] * svd_left[4] + svd_right[5] * svd_left[5];
+ optimal_rotation[5] = svd_right[3] * svd_left[6] + svd_right[4] * svd_left[7] + svd_right[5] * svd_left[8];
+ optimal_rotation[6] = svd_right[6] * svd_left[0] + svd_right[7] * svd_left[1] + svd_right[8] * svd_left[2];
+ optimal_rotation[7] = svd_right[6] * svd_left[3] + svd_right[7] * svd_left[4] + svd_right[8] * svd_left[5];
+ optimal_rotation[8] = svd_right[6] * svd_left[6] + svd_right[7] * svd_left[7] + svd_right[8] * svd_left[8];
+
+ // Handle reflection and compute scale: c = trace(D · S) / ‖a-ā‖²
+ // D = diag(1, 1, det(R)), svd_diagonal contains proper positive singular values on diagonal
+ nk_f32_t rotation_det = nk_det3x3_f32_(optimal_rotation);
+ nk_f32_t sign_det = rotation_det < 0 ? -1.0f : 1.0f;
+ nk_f32_t trace_scaled_s = svd_diagonal[0] + svd_diagonal[4] + sign_det * svd_diagonal[8];
+ c = centered_norm_squared_a > 0.0f ? trace_scaled_s / centered_norm_squared_a : 0.0f;
+
+ if (rotation_det < 0) {
+ svd_right[2] = -svd_right[2], svd_right[5] = -svd_right[5], svd_right[8] = -svd_right[8];
+ optimal_rotation[0] = svd_right[0] * svd_left[0] + svd_right[1] * svd_left[1] + svd_right[2] * svd_left[2];
+ optimal_rotation[1] = svd_right[0] * svd_left[3] + svd_right[1] * svd_left[4] + svd_right[2] * svd_left[5];
+ optimal_rotation[2] = svd_right[0] * svd_left[6] + svd_right[1] * svd_left[7] + svd_right[2] * svd_left[8];
+ optimal_rotation[3] = svd_right[3] * svd_left[0] + svd_right[4] * svd_left[1] + svd_right[5] * svd_left[2];
+ optimal_rotation[4] = svd_right[3] * svd_left[3] + svd_right[4] * svd_left[4] + svd_right[5] * svd_left[5];
+ optimal_rotation[5] = svd_right[3] * svd_left[6] + svd_right[4] * svd_left[7] + svd_right[5] * svd_left[8];
+ optimal_rotation[6] = svd_right[6] * svd_left[0] + svd_right[7] * svd_left[1] + svd_right[8] * svd_left[2];
+ optimal_rotation[7] = svd_right[6] * svd_left[3] + svd_right[7] * svd_left[4] + svd_right[8] * svd_left[5];
+ optimal_rotation[8] = svd_right[6] * svd_left[6] + svd_right[7] * svd_left[7] + svd_right[8] * svd_left[8];
+ }
+
+ trace_rotation_covariance =
+ optimal_rotation[0] * cross_covariance[0] + optimal_rotation[1] * cross_covariance[3] +
+ optimal_rotation[2] * cross_covariance[6] + optimal_rotation[3] * cross_covariance[1] +
+ optimal_rotation[4] * cross_covariance[4] + optimal_rotation[5] * cross_covariance[7] +
+ optimal_rotation[6] * cross_covariance[2] + optimal_rotation[7] * cross_covariance[5] +
+ optimal_rotation[8] * cross_covariance[8];
+ }
+ if (scale) *scale = c;
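
Both unrolled nine-statement products in the else branch compute R = V·Uᵀ over row-major 3×3 arrays. A loop form of the same product, as a hypothetical helper (a sketch for readability; the unrolled version lets the compiler keep everything in registers):

// Loop form of the unrolled R = V * Uᵀ above; all matrices are row-major 3x3:
// r[row*3 + col] = sum over k of v[row*3 + k] * u[col*3 + k].
static void mat3_multiply_transposed_(nk_f32_t const v[9], nk_f32_t const u[9], nk_f32_t r[9]) {
    for (int row = 0; row < 3; ++row)
        for (int col = 0; col < 3; ++col) {
            nk_f32_t acc = 0;
            for (int k = 0; k < 3; ++k) acc += v[row * 3 + k] * u[col * 3 + k];
            r[row * 3 + col] = acc;
        }
}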
 
  // Output rotation matrix
  if (rotation)
- for (int j = 0; j < 9; ++j) rotation[j] = r[j];
+ for (int j = 0; j < 9; ++j) rotation[j] = optimal_rotation[j];
 
- // Compute RMSD after similarity transform: ‖c × R × a - b‖
- nk_f32_t sum_squared = nk_transformed_ssd_bf16_neonbfdot_(a, b, n, r, c, centroid_a_x, centroid_a_y, centroid_a_z,
- centroid_b_x, centroid_b_y, centroid_b_z);
+ // Folded SSD with scale: c²·‖a-ā‖² + ‖b-b̄‖² - 2c·trace(R · H_centered).
+ nk_f32_t sum_squared = c * c * centered_norm_squared_a + centered_norm_squared_b -
+ 2.0f * c * trace_rotation_covariance;
+ if (sum_squared < 0.0f) sum_squared = 0.0f;
  *result = nk_f32_sqrt_neon(sum_squared * inv_n);
  }
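
The folded form replaces the removed second pass (nk_transformed_ssd_bf16_neonbfdot_) with a closed-form expansion. Writing a'ᵢ = aᵢ - ā and b'ᵢ = bᵢ - b̄, and using that the rotation R preserves norms:

\sum_i \lVert c R a'_i - b'_i \rVert^2 \;=\; c^2 \sum_i \lVert a'_i \rVert^2 + \sum_i \lVert b'_i \rVert^2 - 2c \sum_i \langle R a'_i, b'_i \rangle

The cross term sums to trace(R · H_centered), which is trace_rotation_covariance above (hence the transposed indexing of cross_covariance in that sum). All three ingredients come out of the single accumulation loop, and the clamp guards the square root against small negative values from rounding.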