numkong 7.5.0 → 7.6.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (48)
  1. package/binding.gyp +18 -0
  2. package/c/dispatch_e5m2.c +23 -3
  3. package/include/numkong/capabilities.h +1 -1
  4. package/include/numkong/cast/README.md +3 -0
  5. package/include/numkong/cast/haswell.h +28 -64
  6. package/include/numkong/cast/serial.h +17 -0
  7. package/include/numkong/cast/skylake.h +67 -52
  8. package/include/numkong/cast.h +1 -0
  9. package/include/numkong/dot/README.md +1 -0
  10. package/include/numkong/dot/haswell.h +92 -13
  11. package/include/numkong/dot/serial.h +15 -0
  12. package/include/numkong/dot/skylake.h +61 -14
  13. package/include/numkong/dots/README.md +2 -0
  14. package/include/numkong/dots/graniteamx.h +434 -0
  15. package/include/numkong/dots/haswell.h +28 -28
  16. package/include/numkong/dots/sapphireamx.h +1 -1
  17. package/include/numkong/dots/serial.h +23 -8
  18. package/include/numkong/dots/skylake.h +28 -23
  19. package/include/numkong/dots.h +12 -0
  20. package/include/numkong/each/serial.h +18 -1
  21. package/include/numkong/geospatial/serial.h +14 -3
  22. package/include/numkong/maxsim/serial.h +15 -0
  23. package/include/numkong/mesh/README.md +50 -44
  24. package/include/numkong/mesh/genoa.h +462 -0
  25. package/include/numkong/mesh/haswell.h +806 -933
  26. package/include/numkong/mesh/neon.h +871 -943
  27. package/include/numkong/mesh/neonbfdot.h +382 -522
  28. package/include/numkong/mesh/neonfhm.h +676 -0
  29. package/include/numkong/mesh/rvv.h +404 -319
  30. package/include/numkong/mesh/serial.h +204 -162
  31. package/include/numkong/mesh/skylake.h +1029 -1585
  32. package/include/numkong/mesh/v128relaxed.h +403 -377
  33. package/include/numkong/mesh.h +38 -0
  34. package/include/numkong/reduce/serial.h +15 -1
  35. package/include/numkong/sparse/serial.h +17 -2
  36. package/include/numkong/spatial/genoa.h +0 -68
  37. package/include/numkong/spatial/haswell.h +98 -56
  38. package/include/numkong/spatial/serial.h +15 -0
  39. package/include/numkong/spatial/skylake.h +114 -54
  40. package/include/numkong/spatial.h +0 -12
  41. package/include/numkong/spatials/graniteamx.h +128 -0
  42. package/include/numkong/spatials/serial.h +18 -1
  43. package/include/numkong/spatials/skylake.h +2 -2
  44. package/include/numkong/spatials.h +17 -0
  45. package/include/numkong/tensor.hpp +107 -23
  46. package/javascript/numkong.c +3 -2
  47. package/package.json +7 -7
  48. package/wasm/numkong.wasm +0 -0
@@ -14,9 +14,12 @@
  * _mm512_permutex2var_ps VPERMT2PS (ZMM, ZMM, ZMM) 3cy @ p5 4cy @ p12
  * _mm512_extractf32x8_ps VEXTRACTF32X8 (YMM, ZMM, I8) 3cy @ p5 1cy @ p0123
  *
- * Point cloud operations use VPERMT2PS for stride-3 deinterleaving of xyz coordinates, avoiding
- * expensive gather instructions. This achieves ~1.8x speedup over scalar deinterleaving. Dual FMA
- * accumulators on Skylake-X server chips hide the 4cy latency for centroid and covariance computation.
+ * Most `*_f32` mesh kernels use a 15-lane stride-3 chunk layout: 5 xyz triplets per ZMM (lane 15
+ * masked to zero) so the xyz phase is identical across all chunks and no per-chunk deinterleave is
+ * needed. The 9 cross-covariance cells come from three accumulators a*b, a*rot1(b), a*rot2(b)
+ * demuxed per channel post-loop, where rot1/rot2 are cheap within-triplet permutexvar rotations.
+ * `*_f64`, `*_f16`, `*_bf16` kernels still use VPERMT2PS deinterleave (helpers retained below).
+ * Dual FMA accumulators on Skylake-X hide the 4cy latency for centroid and covariance computation.
  */
 #ifndef NK_MESH_SKYLAKE_H
 #define NK_MESH_SKYLAKE_H
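
For orientation, here is a scalar model of the three-accumulator covariance trick the rewritten header comment describes. It is illustrative only, not part of the package; the demux table at the end matches the H-cell assignments in the new streaming-stats kernel further down in this diff.

```c
#include <stddef.h>

/* Scalar sketch of the a*b / a*rot1(b) / a*rot2(b) accumulation scheme.
 * diag pairs a-channel c with b-channel c, rot1 with b-channel (c+1)%3,
 * rot2 with b-channel (c+2)%3 -- one slot per a-channel, as in the SIMD
 * kernel's post-loop demux. Names here are illustrative. */
void cross_covariance_3acc(float const *a, float const *b, size_t n, double h[9]) {
    double diag[3] = {0}, rot1[3] = {0}, rot2[3] = {0};
    for (size_t i = 0; i < n; ++i)
        for (int c = 0; c < 3; ++c) {
            double ac = a[i * 3 + c];
            diag[c] += ac * b[i * 3 + c];
            rot1[c] += ac * b[i * 3 + (c + 1) % 3];
            rot2[c] += ac * b[i * 3 + (c + 2) % 3];
        }
    // Demux into row-major H[j][k] = Sum(a_j * b_k); this is exactly the
    // mapping the kernel's raw-covariance stores use (e.g. H[y,x] <- rot2_y).
    h[0] = diag[0], h[1] = rot1[0], h[2] = rot2[0]; // H[x,x] H[x,y] H[x,z]
    h[3] = rot2[1], h[4] = diag[1], h[5] = rot1[1]; // H[y,x] H[y,y] H[y,z]
    h[6] = rot1[2], h[7] = rot2[2], h[8] = diag[2]; // H[z,x] H[z,y] H[z,z]
}
```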
@@ -231,10 +234,6 @@ NK_INTERNAL nk_f64_t nk_reduce_stable_f64x8_skylake_(__m512d values_f64x8) {
     return sum + compensation;
 }
 
-NK_INTERNAL void nk_rotation_from_svd_f64_skylake_(nk_f64_t const *svd_u, nk_f64_t const *svd_v, nk_f64_t *rotation) {
-    nk_rotation_from_svd_f64_serial_(svd_u, svd_v, rotation);
-}
-
 NK_INTERNAL void nk_accumulate_square_f64x8_skylake_(__m512d *sum_f64x8, __m512d *compensation_f64x8,
                                                      __m512d values_f64x8) {
     __m512d product_f64x8 = _mm512_mul_pd(values_f64x8, values_f64x8);
@@ -248,394 +247,181 @@ NK_INTERNAL void nk_accumulate_square_f64x8_skylake_(__m512d *sum_f64x8, __m512d
     *compensation_f64x8 = _mm512_add_pd(*compensation_f64x8, _mm512_add_pd(sum_error_f64x8, product_error_f64x8));
 }
 
-/* Compute sum of squared distances after applying rotation (and optional scale).
- * Used by kabsch (scale=1.0) and umeyama (scale=computed_scale).
- * Returns sum_squared, caller computes √(sum_squared / n).
- */
-NK_INTERNAL nk_f64_t nk_transformed_ssd_f32_skylake_(nk_f32_t const *a, nk_f32_t const *b, nk_size_t n,
-                                                     nk_f64_t const *r, nk_f64_t scale, nk_f64_t centroid_a_x,
-                                                     nk_f64_t centroid_a_y, nk_f64_t centroid_a_z,
-                                                     nk_f64_t centroid_b_x, nk_f64_t centroid_b_y,
-                                                     nk_f64_t centroid_b_z) {
-    __m512d scaled_rotation_x_x_f64x8 = _mm512_set1_pd(scale * r[0]);
-    __m512d scaled_rotation_x_y_f64x8 = _mm512_set1_pd(scale * r[1]);
-    __m512d scaled_rotation_x_z_f64x8 = _mm512_set1_pd(scale * r[2]);
-    __m512d scaled_rotation_y_x_f64x8 = _mm512_set1_pd(scale * r[3]);
-    __m512d scaled_rotation_y_y_f64x8 = _mm512_set1_pd(scale * r[4]);
-    __m512d scaled_rotation_y_z_f64x8 = _mm512_set1_pd(scale * r[5]);
-    __m512d scaled_rotation_z_x_f64x8 = _mm512_set1_pd(scale * r[6]);
-    __m512d scaled_rotation_z_y_f64x8 = _mm512_set1_pd(scale * r[7]);
-    __m512d scaled_rotation_z_z_f64x8 = _mm512_set1_pd(scale * r[8]);
-    __m512d centroid_a_x_f64x8 = _mm512_set1_pd(centroid_a_x), centroid_a_y_f64x8 = _mm512_set1_pd(centroid_a_y);
-    __m512d centroid_a_z_f64x8 = _mm512_set1_pd(centroid_a_z), centroid_b_x_f64x8 = _mm512_set1_pd(centroid_b_x);
-    __m512d centroid_b_y_f64x8 = _mm512_set1_pd(centroid_b_y), centroid_b_z_f64x8 = _mm512_set1_pd(centroid_b_z);
-    __m512d sum_squared_f64x8 = _mm512_setzero_pd();
-    __m512 a_x_f32x16, a_y_f32x16, a_z_f32x16, b_x_f32x16, b_y_f32x16, b_z_f32x16;
-    nk_size_t index = 0;
-
-    for (; index + 16 <= n; index += 16) {
-        nk_deinterleave_f32x16_skylake_(a + index * 3, &a_x_f32x16, &a_y_f32x16, &a_z_f32x16),
-            nk_deinterleave_f32x16_skylake_(b + index * 3, &b_x_f32x16, &b_y_f32x16, &b_z_f32x16);
-        __m512d a_x_low_f64x8 = _mm512_cvtps_pd(_mm512_castps512_ps256(a_x_f32x16));
-        __m512d a_x_high_f64x8 = _mm512_cvtps_pd(_mm512_extractf32x8_ps(a_x_f32x16, 1));
-        __m512d a_y_low_f64x8 = _mm512_cvtps_pd(_mm512_castps512_ps256(a_y_f32x16));
-        __m512d a_y_high_f64x8 = _mm512_cvtps_pd(_mm512_extractf32x8_ps(a_y_f32x16, 1));
-        __m512d a_z_low_f64x8 = _mm512_cvtps_pd(_mm512_castps512_ps256(a_z_f32x16));
-        __m512d a_z_high_f64x8 = _mm512_cvtps_pd(_mm512_extractf32x8_ps(a_z_f32x16, 1));
-        __m512d b_x_low_f64x8 = _mm512_cvtps_pd(_mm512_castps512_ps256(b_x_f32x16));
-        __m512d b_x_high_f64x8 = _mm512_cvtps_pd(_mm512_extractf32x8_ps(b_x_f32x16, 1));
-        __m512d b_y_low_f64x8 = _mm512_cvtps_pd(_mm512_castps512_ps256(b_y_f32x16));
-        __m512d b_y_high_f64x8 = _mm512_cvtps_pd(_mm512_extractf32x8_ps(b_y_f32x16, 1));
-        __m512d b_z_low_f64x8 = _mm512_cvtps_pd(_mm512_castps512_ps256(b_z_f32x16));
-        __m512d b_z_high_f64x8 = _mm512_cvtps_pd(_mm512_extractf32x8_ps(b_z_f32x16, 1));
-
-        __m512d centered_a_x_low_f64x8 = _mm512_sub_pd(a_x_low_f64x8, centroid_a_x_f64x8);
-        __m512d centered_a_x_high_f64x8 = _mm512_sub_pd(a_x_high_f64x8, centroid_a_x_f64x8);
-        __m512d centered_a_y_low_f64x8 = _mm512_sub_pd(a_y_low_f64x8, centroid_a_y_f64x8);
-        __m512d centered_a_y_high_f64x8 = _mm512_sub_pd(a_y_high_f64x8, centroid_a_y_f64x8);
-        __m512d centered_a_z_low_f64x8 = _mm512_sub_pd(a_z_low_f64x8, centroid_a_z_f64x8);
-        __m512d centered_a_z_high_f64x8 = _mm512_sub_pd(a_z_high_f64x8, centroid_a_z_f64x8);
-        __m512d centered_b_x_low_f64x8 = _mm512_sub_pd(b_x_low_f64x8, centroid_b_x_f64x8);
-        __m512d centered_b_x_high_f64x8 = _mm512_sub_pd(b_x_high_f64x8, centroid_b_x_f64x8);
-        __m512d centered_b_y_low_f64x8 = _mm512_sub_pd(b_y_low_f64x8, centroid_b_y_f64x8);
-        __m512d centered_b_y_high_f64x8 = _mm512_sub_pd(b_y_high_f64x8, centroid_b_y_f64x8);
-        __m512d centered_b_z_low_f64x8 = _mm512_sub_pd(b_z_low_f64x8, centroid_b_z_f64x8);
-        __m512d centered_b_z_high_f64x8 = _mm512_sub_pd(b_z_high_f64x8, centroid_b_z_f64x8);
-
-        __m512d rotated_a_x_low_f64x8 = _mm512_fmadd_pd(
-            scaled_rotation_x_z_f64x8, centered_a_z_low_f64x8,
-            _mm512_fmadd_pd(scaled_rotation_x_y_f64x8, centered_a_y_low_f64x8,
-                            _mm512_mul_pd(scaled_rotation_x_x_f64x8, centered_a_x_low_f64x8)));
-        __m512d rotated_a_x_high_f64x8 = _mm512_fmadd_pd(
-            scaled_rotation_x_z_f64x8, centered_a_z_high_f64x8,
-            _mm512_fmadd_pd(scaled_rotation_x_y_f64x8, centered_a_y_high_f64x8,
-                            _mm512_mul_pd(scaled_rotation_x_x_f64x8, centered_a_x_high_f64x8)));
-        __m512d rotated_a_y_low_f64x8 = _mm512_fmadd_pd(
-            scaled_rotation_y_z_f64x8, centered_a_z_low_f64x8,
-            _mm512_fmadd_pd(scaled_rotation_y_y_f64x8, centered_a_y_low_f64x8,
-                            _mm512_mul_pd(scaled_rotation_y_x_f64x8, centered_a_x_low_f64x8)));
-        __m512d rotated_a_y_high_f64x8 = _mm512_fmadd_pd(
-            scaled_rotation_y_z_f64x8, centered_a_z_high_f64x8,
-            _mm512_fmadd_pd(scaled_rotation_y_y_f64x8, centered_a_y_high_f64x8,
-                            _mm512_mul_pd(scaled_rotation_y_x_f64x8, centered_a_x_high_f64x8)));
-        __m512d rotated_a_z_low_f64x8 = _mm512_fmadd_pd(
-            scaled_rotation_z_z_f64x8, centered_a_z_low_f64x8,
-            _mm512_fmadd_pd(scaled_rotation_z_y_f64x8, centered_a_y_low_f64x8,
-                            _mm512_mul_pd(scaled_rotation_z_x_f64x8, centered_a_x_low_f64x8)));
-        __m512d rotated_a_z_high_f64x8 = _mm512_fmadd_pd(
-            scaled_rotation_z_z_f64x8, centered_a_z_high_f64x8,
-            _mm512_fmadd_pd(scaled_rotation_z_y_f64x8, centered_a_y_high_f64x8,
-                            _mm512_mul_pd(scaled_rotation_z_x_f64x8, centered_a_x_high_f64x8)));
-
-        __m512d delta_x_low_f64x8 = _mm512_sub_pd(rotated_a_x_low_f64x8, centered_b_x_low_f64x8);
-        __m512d delta_x_high_f64x8 = _mm512_sub_pd(rotated_a_x_high_f64x8, centered_b_x_high_f64x8);
-        __m512d delta_y_low_f64x8 = _mm512_sub_pd(rotated_a_y_low_f64x8, centered_b_y_low_f64x8);
-        __m512d delta_y_high_f64x8 = _mm512_sub_pd(rotated_a_y_high_f64x8, centered_b_y_high_f64x8);
-        __m512d delta_z_low_f64x8 = _mm512_sub_pd(rotated_a_z_low_f64x8, centered_b_z_low_f64x8);
-        __m512d delta_z_high_f64x8 = _mm512_sub_pd(rotated_a_z_high_f64x8, centered_b_z_high_f64x8);
-
-        __m512d batch_sum_squared_f64x8 = _mm512_add_pd(_mm512_mul_pd(delta_x_low_f64x8, delta_x_low_f64x8),
-                                                        _mm512_mul_pd(delta_x_high_f64x8, delta_x_high_f64x8));
-        batch_sum_squared_f64x8 = _mm512_fmadd_pd(delta_y_low_f64x8, delta_y_low_f64x8, batch_sum_squared_f64x8);
-        batch_sum_squared_f64x8 = _mm512_fmadd_pd(delta_y_high_f64x8, delta_y_high_f64x8, batch_sum_squared_f64x8);
-        batch_sum_squared_f64x8 = _mm512_fmadd_pd(delta_z_low_f64x8, delta_z_low_f64x8, batch_sum_squared_f64x8);
-        batch_sum_squared_f64x8 = _mm512_fmadd_pd(delta_z_high_f64x8, delta_z_high_f64x8, batch_sum_squared_f64x8);
-        sum_squared_f64x8 = _mm512_add_pd(sum_squared_f64x8, batch_sum_squared_f64x8);
-    }
-
-    nk_f64_t sum_squared = _mm512_reduce_add_pd(sum_squared_f64x8);
-    for (; index < n; ++index) {
-        nk_f64_t centered_a_x = (nk_f64_t)a[index * 3 + 0] - centroid_a_x,
-                 centered_a_y = (nk_f64_t)a[index * 3 + 1] - centroid_a_y,
-                 centered_a_z = (nk_f64_t)a[index * 3 + 2] - centroid_a_z;
-        nk_f64_t centered_b_x = (nk_f64_t)b[index * 3 + 0] - centroid_b_x,
-                 centered_b_y = (nk_f64_t)b[index * 3 + 1] - centroid_b_y,
-                 centered_b_z = (nk_f64_t)b[index * 3 + 2] - centroid_b_z;
-        nk_f64_t rotated_a_x = scale * (r[0] * centered_a_x + r[1] * centered_a_y + r[2] * centered_a_z),
-                 rotated_a_y = scale * (r[3] * centered_a_x + r[4] * centered_a_y + r[5] * centered_a_z),
-                 rotated_a_z = scale * (r[6] * centered_a_x + r[7] * centered_a_y + r[8] * centered_a_z);
-        nk_f64_t delta_x = rotated_a_x - centered_b_x, delta_y = rotated_a_y - centered_b_y,
-                 delta_z = rotated_a_z - centered_b_z;
-        sum_squared += delta_x * delta_x + delta_y * delta_y + delta_z * delta_z;
-    }
-
-    return sum_squared;
-}
-
-/* Compute sum of squared distances for f64 after applying rotation (and optional scale).
- * Rotation matrix, scale and data are all f64 for full precision.
- */
-NK_INTERNAL nk_f64_t nk_transformed_ssd_f64_skylake_(nk_f64_t const *a, nk_f64_t const *b, nk_size_t n,
-                                                     nk_f64_t const *r, nk_f64_t scale, nk_f64_t centroid_a_x,
-                                                     nk_f64_t centroid_a_y, nk_f64_t centroid_a_z,
-                                                     nk_f64_t centroid_b_x, nk_f64_t centroid_b_y,
-                                                     nk_f64_t centroid_b_z) {
-    // Broadcast scaled rotation matrix elements
-    __m512d scaled_rotation_x_x_f64x8 = _mm512_set1_pd(scale * r[0]);
-    __m512d scaled_rotation_x_y_f64x8 = _mm512_set1_pd(scale * r[1]);
-    __m512d scaled_rotation_x_z_f64x8 = _mm512_set1_pd(scale * r[2]);
-    __m512d scaled_rotation_y_x_f64x8 = _mm512_set1_pd(scale * r[3]);
-    __m512d scaled_rotation_y_y_f64x8 = _mm512_set1_pd(scale * r[4]);
-    __m512d scaled_rotation_y_z_f64x8 = _mm512_set1_pd(scale * r[5]);
-    __m512d scaled_rotation_z_x_f64x8 = _mm512_set1_pd(scale * r[6]);
-    __m512d scaled_rotation_z_y_f64x8 = _mm512_set1_pd(scale * r[7]);
-    __m512d scaled_rotation_z_z_f64x8 = _mm512_set1_pd(scale * r[8]);
-
-    // Broadcast centroids
-    __m512d centroid_a_x_f64x8 = _mm512_set1_pd(centroid_a_x);
-    __m512d centroid_a_y_f64x8 = _mm512_set1_pd(centroid_a_y);
-    __m512d centroid_a_z_f64x8 = _mm512_set1_pd(centroid_a_z);
-    __m512d centroid_b_x_f64x8 = _mm512_set1_pd(centroid_b_x);
-    __m512d centroid_b_y_f64x8 = _mm512_set1_pd(centroid_b_y);
-    __m512d centroid_b_z_f64x8 = _mm512_set1_pd(centroid_b_z);
-
-    __m512d sum_squared_f64x8 = _mm512_setzero_pd();
-    __m512d sum_squared_compensation_f64x8 = _mm512_setzero_pd();
-    __m512d a_x_f64x8, a_y_f64x8, a_z_f64x8, b_x_f64x8, b_y_f64x8, b_z_f64x8;
-    nk_size_t j = 0;
-
-    for (; j + 8 <= n; j += 8) {
-        nk_deinterleave_f64x8_skylake_(a + j * 3, &a_x_f64x8, &a_y_f64x8, &a_z_f64x8);
-        nk_deinterleave_f64x8_skylake_(b + j * 3, &b_x_f64x8, &b_y_f64x8, &b_z_f64x8);
-
-        // Center points
-        __m512d pa_x_f64x8 = _mm512_sub_pd(a_x_f64x8, centroid_a_x_f64x8);
-        __m512d pa_y_f64x8 = _mm512_sub_pd(a_y_f64x8, centroid_a_y_f64x8);
-        __m512d pa_z_f64x8 = _mm512_sub_pd(a_z_f64x8, centroid_a_z_f64x8);
-        __m512d pb_x_f64x8 = _mm512_sub_pd(b_x_f64x8, centroid_b_x_f64x8);
-        __m512d pb_y_f64x8 = _mm512_sub_pd(b_y_f64x8, centroid_b_y_f64x8);
-        __m512d pb_z_f64x8 = _mm512_sub_pd(b_z_f64x8, centroid_b_z_f64x8);
-
-        // Rotate and scale: ra = scale * R * pa
-        __m512d ra_x_f64x8 = _mm512_fmadd_pd(scaled_rotation_x_z_f64x8, pa_z_f64x8,
-                                             _mm512_fmadd_pd(scaled_rotation_x_y_f64x8, pa_y_f64x8,
-                                                             _mm512_mul_pd(scaled_rotation_x_x_f64x8, pa_x_f64x8)));
-        __m512d ra_y_f64x8 = _mm512_fmadd_pd(scaled_rotation_y_z_f64x8, pa_z_f64x8,
-                                             _mm512_fmadd_pd(scaled_rotation_y_y_f64x8, pa_y_f64x8,
-                                                             _mm512_mul_pd(scaled_rotation_y_x_f64x8, pa_x_f64x8)));
-        __m512d ra_z_f64x8 = _mm512_fmadd_pd(scaled_rotation_z_z_f64x8, pa_z_f64x8,
-                                             _mm512_fmadd_pd(scaled_rotation_z_y_f64x8, pa_y_f64x8,
-                                                             _mm512_mul_pd(scaled_rotation_z_x_f64x8, pa_x_f64x8)));
-
-        // Delta and accumulate
-        __m512d delta_x_f64x8 = _mm512_sub_pd(ra_x_f64x8, pb_x_f64x8);
-        __m512d delta_y_f64x8 = _mm512_sub_pd(ra_y_f64x8, pb_y_f64x8);
-        __m512d delta_z_f64x8 = _mm512_sub_pd(ra_z_f64x8, pb_z_f64x8);
-
-        nk_accumulate_square_f64x8_skylake_(&sum_squared_f64x8, &sum_squared_compensation_f64x8, delta_x_f64x8);
-        nk_accumulate_square_f64x8_skylake_(&sum_squared_f64x8, &sum_squared_compensation_f64x8, delta_y_f64x8);
-        nk_accumulate_square_f64x8_skylake_(&sum_squared_f64x8, &sum_squared_compensation_f64x8, delta_z_f64x8);
-    }
-
-    nk_f64_t sum_squared = nk_dot_stable_sum_f64x8_skylake_(sum_squared_f64x8, sum_squared_compensation_f64x8);
-    nk_f64_t sum_squared_compensation = 0.0;
-
-    // Scalar tail
-    for (; j < n; ++j) {
-        nk_f64_t pa_x = a[j * 3 + 0] - centroid_a_x, pa_y = a[j * 3 + 1] - centroid_a_y,
-                 pa_z = a[j * 3 + 2] - centroid_a_z;
-        nk_f64_t pb_x = b[j * 3 + 0] - centroid_b_x, pb_y = b[j * 3 + 1] - centroid_b_y,
-                 pb_z = b[j * 3 + 2] - centroid_b_z;
-
-        nk_f64_t ra_x = scale * (r[0] * pa_x + r[1] * pa_y + r[2] * pa_z),
-                 ra_y = scale * (r[3] * pa_x + r[4] * pa_y + r[5] * pa_z),
-                 ra_z = scale * (r[6] * pa_x + r[7] * pa_y + r[8] * pa_z);
-
-        nk_f64_t delta_x = ra_x - pb_x, delta_y = ra_y - pb_y, delta_z = ra_z - pb_z;
-        nk_accumulate_square_f64_(&sum_squared, &sum_squared_compensation, delta_x);
-        nk_accumulate_square_f64_(&sum_squared, &sum_squared_compensation, delta_y);
-        nk_accumulate_square_f64_(&sum_squared, &sum_squared_compensation, delta_z);
-    }
-
-    return sum_squared + sum_squared_compensation;
-}
-
-/* Compute sum of squared distances for f16 data after applying rotation (and optional scale).
- * Loads f16, converts to f32 for computation. Rotation matrix, scale, and centroids are f32.
+/* Single-pass streaming statistics over an f32 xyz point-cloud pair.
+ * Processes 5 xyz triplets per chunk (15 fp32 lanes, lane 15 masked to zero) so the stride-3
+ * phase is identical across all chunks and no deinterleave is needed. All accumulators are f64.
+ * Outputs via pointers:
+ *   sum_a_out[3] / sum_b_out[3]             - per-channel Sum(a), Sum(b)
+ *   raw_covariance_out[9]                   - row-major uncentered Sum(a_j * b_k)
+ *   norm_squared_a_out / norm_squared_b_out - Sum(||a||^2), Sum(||b||^2) across all three channels
+ *
+ * The 9 H-cells come from three product accumulators prod_{diag,rot1,rot2} demuxed post-loop
+ * by a-channel. Rotations of b happen in fp64 via permutex2var_pd on the already-widened
+ * halves: widening the rotated fp32 vector would add two extra cvtps_pd per chunk, which we
+ * skip. Post-loop, each (accumulator, channel) pair is gathered into a single 8-lane vector
+ * via one maskz-permutex2var_pd and reduced once: 17 horizontal reductions total (the
+ * theoretical minimum for 17 scalar outputs) instead of 32 masked ones.
  */
-NK_INTERNAL nk_f32_t nk_transformed_ssd_f16_skylake_(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n,
-                                                     nk_f32_t const *r, nk_f32_t scale, nk_f32_t centroid_a_x,
-                                                     nk_f32_t centroid_a_y, nk_f32_t centroid_a_z,
-                                                     nk_f32_t centroid_b_x, nk_f32_t centroid_b_y,
-                                                     nk_f32_t centroid_b_z) {
-    __m512 scaled_rotation_x_x_f32x16 = _mm512_set1_ps(scale * r[0]);
-    __m512 scaled_rotation_x_y_f32x16 = _mm512_set1_ps(scale * r[1]);
-    __m512 scaled_rotation_x_z_f32x16 = _mm512_set1_ps(scale * r[2]);
-    __m512 scaled_rotation_y_x_f32x16 = _mm512_set1_ps(scale * r[3]);
-    __m512 scaled_rotation_y_y_f32x16 = _mm512_set1_ps(scale * r[4]);
-    __m512 scaled_rotation_y_z_f32x16 = _mm512_set1_ps(scale * r[5]);
-    __m512 scaled_rotation_z_x_f32x16 = _mm512_set1_ps(scale * r[6]);
-    __m512 scaled_rotation_z_y_f32x16 = _mm512_set1_ps(scale * r[7]);
-    __m512 scaled_rotation_z_z_f32x16 = _mm512_set1_ps(scale * r[8]);
-
-    __m512 centroid_a_x_f32x16 = _mm512_set1_ps(centroid_a_x);
-    __m512 centroid_a_y_f32x16 = _mm512_set1_ps(centroid_a_y);
-    __m512 centroid_a_z_f32x16 = _mm512_set1_ps(centroid_a_z);
-    __m512 centroid_b_x_f32x16 = _mm512_set1_ps(centroid_b_x);
-    __m512 centroid_b_y_f32x16 = _mm512_set1_ps(centroid_b_y);
-    __m512 centroid_b_z_f32x16 = _mm512_set1_ps(centroid_b_z);
-
-    __m512 sum_squared_f32x16 = _mm512_setzero_ps();
-    __m512 a_x_f32x16, a_y_f32x16, a_z_f32x16, b_x_f32x16, b_y_f32x16, b_z_f32x16;
-    nk_size_t j = 0;
-
-    for (; j + 16 <= n; j += 16) {
-        nk_deinterleave_f16x16_to_f32x16_skylake_(a + j * 3, &a_x_f32x16, &a_y_f32x16, &a_z_f32x16);
-        nk_deinterleave_f16x16_to_f32x16_skylake_(b + j * 3, &b_x_f32x16, &b_y_f32x16, &b_z_f32x16);
-
-        __m512 pa_x_f32x16 = _mm512_sub_ps(a_x_f32x16, centroid_a_x_f32x16);
-        __m512 pa_y_f32x16 = _mm512_sub_ps(a_y_f32x16, centroid_a_y_f32x16);
-        __m512 pa_z_f32x16 = _mm512_sub_ps(a_z_f32x16, centroid_a_z_f32x16);
-        __m512 pb_x_f32x16 = _mm512_sub_ps(b_x_f32x16, centroid_b_x_f32x16);
-        __m512 pb_y_f32x16 = _mm512_sub_ps(b_y_f32x16, centroid_b_y_f32x16);
-        __m512 pb_z_f32x16 = _mm512_sub_ps(b_z_f32x16, centroid_b_z_f32x16);
-
-        __m512 ra_x_f32x16 = _mm512_fmadd_ps(scaled_rotation_x_z_f32x16, pa_z_f32x16,
-                                             _mm512_fmadd_ps(scaled_rotation_x_y_f32x16, pa_y_f32x16,
-                                                             _mm512_mul_ps(scaled_rotation_x_x_f32x16, pa_x_f32x16)));
-        __m512 ra_y_f32x16 = _mm512_fmadd_ps(scaled_rotation_y_z_f32x16, pa_z_f32x16,
-                                             _mm512_fmadd_ps(scaled_rotation_y_y_f32x16, pa_y_f32x16,
-                                                             _mm512_mul_ps(scaled_rotation_y_x_f32x16, pa_x_f32x16)));
-        __m512 ra_z_f32x16 = _mm512_fmadd_ps(scaled_rotation_z_z_f32x16, pa_z_f32x16,
-                                             _mm512_fmadd_ps(scaled_rotation_z_y_f32x16, pa_y_f32x16,
-                                                             _mm512_mul_ps(scaled_rotation_z_x_f32x16, pa_x_f32x16)));
-
-        __m512 delta_x_f32x16 = _mm512_sub_ps(ra_x_f32x16, pb_x_f32x16);
-        __m512 delta_y_f32x16 = _mm512_sub_ps(ra_y_f32x16, pb_y_f32x16);
-        __m512 delta_z_f32x16 = _mm512_sub_ps(ra_z_f32x16, pb_z_f32x16);
-
-        sum_squared_f32x16 = _mm512_fmadd_ps(delta_x_f32x16, delta_x_f32x16, sum_squared_f32x16);
-        sum_squared_f32x16 = _mm512_fmadd_ps(delta_y_f32x16, delta_y_f32x16, sum_squared_f32x16);
-        sum_squared_f32x16 = _mm512_fmadd_ps(delta_z_f32x16, delta_z_f32x16, sum_squared_f32x16);
-    }
-
-    // Tail: deinterleave remaining points into zero-initialized vectors
-    if (j < n) {
-        nk_size_t tail = n - j;
-        nk_deinterleave_f16_tail_to_f32x16_skylake_(a + j * 3, tail, &a_x_f32x16, &a_y_f32x16, &a_z_f32x16);
-        nk_deinterleave_f16_tail_to_f32x16_skylake_(b + j * 3, tail, &b_x_f32x16, &b_y_f32x16, &b_z_f32x16);
-
-        __m512 pa_x_f32x16 = _mm512_sub_ps(a_x_f32x16, centroid_a_x_f32x16);
-        __m512 pa_y_f32x16 = _mm512_sub_ps(a_y_f32x16, centroid_a_y_f32x16);
-        __m512 pa_z_f32x16 = _mm512_sub_ps(a_z_f32x16, centroid_a_z_f32x16);
-        __m512 pb_x_f32x16 = _mm512_sub_ps(b_x_f32x16, centroid_b_x_f32x16);
-        __m512 pb_y_f32x16 = _mm512_sub_ps(b_y_f32x16, centroid_b_y_f32x16);
-        __m512 pb_z_f32x16 = _mm512_sub_ps(b_z_f32x16, centroid_b_z_f32x16);
-
-        __m512 ra_x_f32x16 = _mm512_fmadd_ps(scaled_rotation_x_z_f32x16, pa_z_f32x16,
-                                             _mm512_fmadd_ps(scaled_rotation_x_y_f32x16, pa_y_f32x16,
-                                                             _mm512_mul_ps(scaled_rotation_x_x_f32x16, pa_x_f32x16)));
-        __m512 ra_y_f32x16 = _mm512_fmadd_ps(scaled_rotation_y_z_f32x16, pa_z_f32x16,
-                                             _mm512_fmadd_ps(scaled_rotation_y_y_f32x16, pa_y_f32x16,
-                                                             _mm512_mul_ps(scaled_rotation_y_x_f32x16, pa_x_f32x16)));
-        __m512 ra_z_f32x16 = _mm512_fmadd_ps(scaled_rotation_z_z_f32x16, pa_z_f32x16,
-                                             _mm512_fmadd_ps(scaled_rotation_z_y_f32x16, pa_y_f32x16,
-                                                             _mm512_mul_ps(scaled_rotation_z_x_f32x16, pa_x_f32x16)));
-
-        __m512 delta_x_f32x16 = _mm512_sub_ps(ra_x_f32x16, pb_x_f32x16);
-        __m512 delta_y_f32x16 = _mm512_sub_ps(ra_y_f32x16, pb_y_f32x16);
-        __m512 delta_z_f32x16 = _mm512_sub_ps(ra_z_f32x16, pb_z_f32x16);
-
-        sum_squared_f32x16 = _mm512_fmadd_ps(delta_x_f32x16, delta_x_f32x16, sum_squared_f32x16);
-        sum_squared_f32x16 = _mm512_fmadd_ps(delta_y_f32x16, delta_y_f32x16, sum_squared_f32x16);
-        sum_squared_f32x16 = _mm512_fmadd_ps(delta_z_f32x16, delta_z_f32x16, sum_squared_f32x16);
-    }
+NK_INTERNAL void nk_mesh_streaming_stats_f32_skylake_( //
+    nk_f32_t const *a, nk_f32_t const *b, nk_size_t n, nk_f64_t *sum_a_out, nk_f64_t *sum_b_out,
+    nk_f64_t *raw_covariance_out, nk_f64_t *norm_squared_a_out, nk_f64_t *norm_squared_b_out) {
+
+    // Within-triplet rotation indices for fp64 permutex2var across (b_low, b_high) as the two
+    // sources. Indices 0..7 pull from b_low, 8..15 pull from b_high. Derived from the fp32
+    // rotation pattern {1,2,0,4,5,3,7,8,6,10,11,9,13,14,12,15} (rot1) and {2,0,1,5,3,4,8,6,
+    // 7,11,9,10,14,12,13,15} (rot2) split at fp32 lane 8.
+    __m512i const idx_rotation_1_low_i64x8 = _mm512_setr_epi64(1, 2, 0, 4, 5, 3, 7, 8);
+    __m512i const idx_rotation_1_high_i64x8 = _mm512_setr_epi64(6, 10, 11, 9, 13, 14, 12, 15);
+    __m512i const idx_rotation_2_low_i64x8 = _mm512_setr_epi64(2, 0, 1, 5, 3, 4, 8, 6);
+    __m512i const idx_rotation_2_high_i64x8 = _mm512_setr_epi64(7, 11, 9, 10, 14, 12, 13, 15);
+
+    // Per-channel gather indices packing the 5 contributing fp64 lanes (across both halves)
+    // into lanes 0..4 of the output, with lanes 5..7 zeroed by maskz so the subsequent
+    // _mm512_reduce_add_pd is exact without needing a mask-reduce variant.
+    //   x -> low {0,3,6} + high {1,4} = idx [0, 3, 6, 9, 12, _, _, _]
+    //   y -> low {1,4,7} + high {2,5} = idx [1, 4, 7, 10, 13, _, _, _]
+    //   z -> low {2,5} + high {0,3,6} = idx [2, 5, 8, 11, 14, _, _, _]
+    __m512i const idx_channel_x_i64x8 = _mm512_setr_epi64(0, 3, 6, 9, 12, 0, 0, 0);
+    __m512i const idx_channel_y_i64x8 = _mm512_setr_epi64(1, 4, 7, 10, 13, 0, 0, 0);
+    __m512i const idx_channel_z_i64x8 = _mm512_setr_epi64(2, 5, 8, 11, 14, 0, 0, 0);
+    __mmask8 const channel_lanes_mask = 0x1F;
 
-    return _mm512_reduce_add_ps(sum_squared_f32x16);
-}
-
-/* Compute sum of squared distances for bf16 data after applying rotation (and optional scale).
- * Loads bf16, converts to f32 for computation. Rotation matrix, scale, and centroids are f32.
- */
-NK_INTERNAL nk_f32_t nk_transformed_ssd_bf16_skylake_(nk_bf16_t const *a, nk_bf16_t const *b, nk_size_t n,
-                                                      nk_f32_t const *r, nk_f32_t scale, nk_f32_t centroid_a_x,
-                                                      nk_f32_t centroid_a_y, nk_f32_t centroid_a_z,
-                                                      nk_f32_t centroid_b_x, nk_f32_t centroid_b_y,
-                                                      nk_f32_t centroid_b_z) {
-    __m512 scaled_rotation_x_x_f32x16 = _mm512_set1_ps(scale * r[0]);
-    __m512 scaled_rotation_x_y_f32x16 = _mm512_set1_ps(scale * r[1]);
-    __m512 scaled_rotation_x_z_f32x16 = _mm512_set1_ps(scale * r[2]);
-    __m512 scaled_rotation_y_x_f32x16 = _mm512_set1_ps(scale * r[3]);
-    __m512 scaled_rotation_y_y_f32x16 = _mm512_set1_ps(scale * r[4]);
-    __m512 scaled_rotation_y_z_f32x16 = _mm512_set1_ps(scale * r[5]);
-    __m512 scaled_rotation_z_x_f32x16 = _mm512_set1_ps(scale * r[6]);
-    __m512 scaled_rotation_z_y_f32x16 = _mm512_set1_ps(scale * r[7]);
-    __m512 scaled_rotation_z_z_f32x16 = _mm512_set1_ps(scale * r[8]);
-
-    __m512 centroid_a_x_f32x16 = _mm512_set1_ps(centroid_a_x);
-    __m512 centroid_a_y_f32x16 = _mm512_set1_ps(centroid_a_y);
-    __m512 centroid_a_z_f32x16 = _mm512_set1_ps(centroid_a_z);
-    __m512 centroid_b_x_f32x16 = _mm512_set1_ps(centroid_b_x);
-    __m512 centroid_b_y_f32x16 = _mm512_set1_ps(centroid_b_y);
-    __m512 centroid_b_z_f32x16 = _mm512_set1_ps(centroid_b_z);
+    __m512d const zeros_f64x8 = _mm512_setzero_pd();
+    __m512d sum_a_low_f64x8 = zeros_f64x8, sum_a_high_f64x8 = zeros_f64x8;
+    __m512d sum_b_low_f64x8 = zeros_f64x8, sum_b_high_f64x8 = zeros_f64x8;
+    __m512d norm_squared_a_low_f64x8 = zeros_f64x8, norm_squared_a_high_f64x8 = zeros_f64x8;
+    __m512d norm_squared_b_low_f64x8 = zeros_f64x8, norm_squared_b_high_f64x8 = zeros_f64x8;
+    __m512d product_diagonal_low_f64x8 = zeros_f64x8, product_diagonal_high_f64x8 = zeros_f64x8;
+    __m512d product_rotation_1_low_f64x8 = zeros_f64x8, product_rotation_1_high_f64x8 = zeros_f64x8;
+    __m512d product_rotation_2_low_f64x8 = zeros_f64x8, product_rotation_2_high_f64x8 = zeros_f64x8;
 
-    __m512 sum_squared_f32x16 = _mm512_setzero_ps();
-    __m512 a_x_f32x16, a_y_f32x16, a_z_f32x16, b_x_f32x16, b_y_f32x16, b_z_f32x16;
-    nk_size_t j = 0;
-
-    for (; j + 16 <= n; j += 16) {
-        nk_deinterleave_bf16x16_to_f32x16_skylake_(a + j * 3, &a_x_f32x16, &a_y_f32x16, &a_z_f32x16);
-        nk_deinterleave_bf16x16_to_f32x16_skylake_(b + j * 3, &b_x_f32x16, &b_y_f32x16, &b_z_f32x16);
-
-        __m512 pa_x_f32x16 = _mm512_sub_ps(a_x_f32x16, centroid_a_x_f32x16);
-        __m512 pa_y_f32x16 = _mm512_sub_ps(a_y_f32x16, centroid_a_y_f32x16);
-        __m512 pa_z_f32x16 = _mm512_sub_ps(a_z_f32x16, centroid_a_z_f32x16);
-        __m512 pb_x_f32x16 = _mm512_sub_ps(b_x_f32x16, centroid_b_x_f32x16);
-        __m512 pb_y_f32x16 = _mm512_sub_ps(b_y_f32x16, centroid_b_y_f32x16);
-        __m512 pb_z_f32x16 = _mm512_sub_ps(b_z_f32x16, centroid_b_z_f32x16);
-
-        __m512 ra_x_f32x16 = _mm512_fmadd_ps(scaled_rotation_x_z_f32x16, pa_z_f32x16,
-                                             _mm512_fmadd_ps(scaled_rotation_x_y_f32x16, pa_y_f32x16,
-                                                             _mm512_mul_ps(scaled_rotation_x_x_f32x16, pa_x_f32x16)));
-        __m512 ra_y_f32x16 = _mm512_fmadd_ps(scaled_rotation_y_z_f32x16, pa_z_f32x16,
-                                             _mm512_fmadd_ps(scaled_rotation_y_y_f32x16, pa_y_f32x16,
-                                                             _mm512_mul_ps(scaled_rotation_y_x_f32x16, pa_x_f32x16)));
-        __m512 ra_z_f32x16 = _mm512_fmadd_ps(scaled_rotation_z_z_f32x16, pa_z_f32x16,
-                                             _mm512_fmadd_ps(scaled_rotation_z_y_f32x16, pa_y_f32x16,
-                                                             _mm512_mul_ps(scaled_rotation_z_x_f32x16, pa_x_f32x16)));
-
-        __m512 delta_x_f32x16 = _mm512_sub_ps(ra_x_f32x16, pb_x_f32x16);
-        __m512 delta_y_f32x16 = _mm512_sub_ps(ra_y_f32x16, pb_y_f32x16);
-        __m512 delta_z_f32x16 = _mm512_sub_ps(ra_z_f32x16, pb_z_f32x16);
-
-        sum_squared_f32x16 = _mm512_fmadd_ps(delta_x_f32x16, delta_x_f32x16, sum_squared_f32x16);
-        sum_squared_f32x16 = _mm512_fmadd_ps(delta_y_f32x16, delta_y_f32x16, sum_squared_f32x16);
-        sum_squared_f32x16 = _mm512_fmadd_ps(delta_z_f32x16, delta_z_f32x16, sum_squared_f32x16);
+    nk_size_t index = 0;
+    // Main loop: 5 points (15 fp32) per chunk, lane 15 zeroed by mask 0x7FFF.
+    for (; index + 5 <= n; index += 5) {
+        __m512 a_f32x16 = _mm512_maskz_loadu_ps(0x7FFF, a + index * 3);
+        __m512 b_f32x16 = _mm512_maskz_loadu_ps(0x7FFF, b + index * 3);
+
+        __m512d a_low_f64x8 = _mm512_cvtps_pd(_mm512_castps512_ps256(a_f32x16));
+        __m512d a_high_f64x8 = _mm512_cvtps_pd(_mm512_extractf32x8_ps(a_f32x16, 1));
+        __m512d b_low_f64x8 = _mm512_cvtps_pd(_mm512_castps512_ps256(b_f32x16));
+        __m512d b_high_f64x8 = _mm512_cvtps_pd(_mm512_extractf32x8_ps(b_f32x16, 1));
+
+        __m512d b_rot1_low_f64x8 = _mm512_permutex2var_pd(b_low_f64x8, idx_rotation_1_low_i64x8, b_high_f64x8);
+        __m512d b_rot1_high_f64x8 = _mm512_permutex2var_pd(b_low_f64x8, idx_rotation_1_high_i64x8, b_high_f64x8);
+        __m512d b_rot2_low_f64x8 = _mm512_permutex2var_pd(b_low_f64x8, idx_rotation_2_low_i64x8, b_high_f64x8);
+        __m512d b_rot2_high_f64x8 = _mm512_permutex2var_pd(b_low_f64x8, idx_rotation_2_high_i64x8, b_high_f64x8);
+
+        sum_a_low_f64x8 = _mm512_add_pd(sum_a_low_f64x8, a_low_f64x8);
+        sum_a_high_f64x8 = _mm512_add_pd(sum_a_high_f64x8, a_high_f64x8);
+        sum_b_low_f64x8 = _mm512_add_pd(sum_b_low_f64x8, b_low_f64x8);
+        sum_b_high_f64x8 = _mm512_add_pd(sum_b_high_f64x8, b_high_f64x8);
+
+        norm_squared_a_low_f64x8 = _mm512_fmadd_pd(a_low_f64x8, a_low_f64x8, norm_squared_a_low_f64x8);
+        norm_squared_a_high_f64x8 = _mm512_fmadd_pd(a_high_f64x8, a_high_f64x8, norm_squared_a_high_f64x8);
+        norm_squared_b_low_f64x8 = _mm512_fmadd_pd(b_low_f64x8, b_low_f64x8, norm_squared_b_low_f64x8);
+        norm_squared_b_high_f64x8 = _mm512_fmadd_pd(b_high_f64x8, b_high_f64x8, norm_squared_b_high_f64x8);
+
+        product_diagonal_low_f64x8 = _mm512_fmadd_pd(a_low_f64x8, b_low_f64x8, product_diagonal_low_f64x8);
+        product_diagonal_high_f64x8 = _mm512_fmadd_pd(a_high_f64x8, b_high_f64x8, product_diagonal_high_f64x8);
+        product_rotation_1_low_f64x8 = _mm512_fmadd_pd(a_low_f64x8, b_rot1_low_f64x8, product_rotation_1_low_f64x8);
+        product_rotation_1_high_f64x8 = _mm512_fmadd_pd(a_high_f64x8, b_rot1_high_f64x8, product_rotation_1_high_f64x8);
+        product_rotation_2_low_f64x8 = _mm512_fmadd_pd(a_low_f64x8, b_rot2_low_f64x8, product_rotation_2_low_f64x8);
+        product_rotation_2_high_f64x8 = _mm512_fmadd_pd(a_high_f64x8, b_rot2_high_f64x8, product_rotation_2_high_f64x8);
     }
 
-    // Tail: deinterleave remaining points into zero-initialized vectors
-    if (j < n) {
-        nk_size_t tail = n - j;
-        nk_deinterleave_bf16_tail_to_f32x16_skylake_(a + j * 3, tail, &a_x_f32x16, &a_y_f32x16, &a_z_f32x16);
-        nk_deinterleave_bf16_tail_to_f32x16_skylake_(b + j * 3, tail, &b_x_f32x16, &b_y_f32x16, &b_z_f32x16);
-
-        __m512 pa_x_f32x16 = _mm512_sub_ps(a_x_f32x16, centroid_a_x_f32x16);
-        __m512 pa_y_f32x16 = _mm512_sub_ps(a_y_f32x16, centroid_a_y_f32x16);
-        __m512 pa_z_f32x16 = _mm512_sub_ps(a_z_f32x16, centroid_a_z_f32x16);
-        __m512 pb_x_f32x16 = _mm512_sub_ps(b_x_f32x16, centroid_b_x_f32x16);
-        __m512 pb_y_f32x16 = _mm512_sub_ps(b_y_f32x16, centroid_b_y_f32x16);
-        __m512 pb_z_f32x16 = _mm512_sub_ps(b_z_f32x16, centroid_b_z_f32x16);
-
-        __m512 ra_x_f32x16 = _mm512_fmadd_ps(scaled_rotation_x_z_f32x16, pa_z_f32x16,
-                                             _mm512_fmadd_ps(scaled_rotation_x_y_f32x16, pa_y_f32x16,
-                                                             _mm512_mul_ps(scaled_rotation_x_x_f32x16, pa_x_f32x16)));
-        __m512 ra_y_f32x16 = _mm512_fmadd_ps(scaled_rotation_y_z_f32x16, pa_z_f32x16,
-                                             _mm512_fmadd_ps(scaled_rotation_y_y_f32x16, pa_y_f32x16,
-                                                             _mm512_mul_ps(scaled_rotation_y_x_f32x16, pa_x_f32x16)));
-        __m512 ra_z_f32x16 = _mm512_fmadd_ps(scaled_rotation_z_z_f32x16, pa_z_f32x16,
-                                             _mm512_fmadd_ps(scaled_rotation_z_y_f32x16, pa_y_f32x16,
-                                                             _mm512_mul_ps(scaled_rotation_z_x_f32x16, pa_x_f32x16)));
-
-        __m512 delta_x_f32x16 = _mm512_sub_ps(ra_x_f32x16, pb_x_f32x16);
-        __m512 delta_y_f32x16 = _mm512_sub_ps(ra_y_f32x16, pb_y_f32x16);
-        __m512 delta_z_f32x16 = _mm512_sub_ps(ra_z_f32x16, pb_z_f32x16);
-
-        sum_squared_f32x16 = _mm512_fmadd_ps(delta_x_f32x16, delta_x_f32x16, sum_squared_f32x16);
-        sum_squared_f32x16 = _mm512_fmadd_ps(delta_y_f32x16, delta_y_f32x16, sum_squared_f32x16);
-        sum_squared_f32x16 = _mm512_fmadd_ps(delta_z_f32x16, delta_z_f32x16, sum_squared_f32x16);
+    // Tail: 1..4 points (3..12 fp32) via narrower mask; identical body.
+    if (index < n) {
+        nk_size_t tail_floats = (n - index) * 3;
+        __mmask16 tail_mask = (__mmask16)_bzhi_u32(0x7FFF, tail_floats);
+        __m512 a_f32x16 = _mm512_maskz_loadu_ps(tail_mask, a + index * 3);
+        __m512 b_f32x16 = _mm512_maskz_loadu_ps(tail_mask, b + index * 3);
+
+        __m512d a_low_f64x8 = _mm512_cvtps_pd(_mm512_castps512_ps256(a_f32x16));
+        __m512d a_high_f64x8 = _mm512_cvtps_pd(_mm512_extractf32x8_ps(a_f32x16, 1));
+        __m512d b_low_f64x8 = _mm512_cvtps_pd(_mm512_castps512_ps256(b_f32x16));
+        __m512d b_high_f64x8 = _mm512_cvtps_pd(_mm512_extractf32x8_ps(b_f32x16, 1));
+
+        __m512d b_rot1_low_f64x8 = _mm512_permutex2var_pd(b_low_f64x8, idx_rotation_1_low_i64x8, b_high_f64x8);
+        __m512d b_rot1_high_f64x8 = _mm512_permutex2var_pd(b_low_f64x8, idx_rotation_1_high_i64x8, b_high_f64x8);
+        __m512d b_rot2_low_f64x8 = _mm512_permutex2var_pd(b_low_f64x8, idx_rotation_2_low_i64x8, b_high_f64x8);
+        __m512d b_rot2_high_f64x8 = _mm512_permutex2var_pd(b_low_f64x8, idx_rotation_2_high_i64x8, b_high_f64x8);
+
+        sum_a_low_f64x8 = _mm512_add_pd(sum_a_low_f64x8, a_low_f64x8);
+        sum_a_high_f64x8 = _mm512_add_pd(sum_a_high_f64x8, a_high_f64x8);
+        sum_b_low_f64x8 = _mm512_add_pd(sum_b_low_f64x8, b_low_f64x8);
+        sum_b_high_f64x8 = _mm512_add_pd(sum_b_high_f64x8, b_high_f64x8);
+
+        norm_squared_a_low_f64x8 = _mm512_fmadd_pd(a_low_f64x8, a_low_f64x8, norm_squared_a_low_f64x8);
+        norm_squared_a_high_f64x8 = _mm512_fmadd_pd(a_high_f64x8, a_high_f64x8, norm_squared_a_high_f64x8);
+        norm_squared_b_low_f64x8 = _mm512_fmadd_pd(b_low_f64x8, b_low_f64x8, norm_squared_b_low_f64x8);
+        norm_squared_b_high_f64x8 = _mm512_fmadd_pd(b_high_f64x8, b_high_f64x8, norm_squared_b_high_f64x8);
+
+        product_diagonal_low_f64x8 = _mm512_fmadd_pd(a_low_f64x8, b_low_f64x8, product_diagonal_low_f64x8);
+        product_diagonal_high_f64x8 = _mm512_fmadd_pd(a_high_f64x8, b_high_f64x8, product_diagonal_high_f64x8);
+        product_rotation_1_low_f64x8 = _mm512_fmadd_pd(a_low_f64x8, b_rot1_low_f64x8, product_rotation_1_low_f64x8);
+        product_rotation_1_high_f64x8 = _mm512_fmadd_pd(a_high_f64x8, b_rot1_high_f64x8, product_rotation_1_high_f64x8);
+        product_rotation_2_low_f64x8 = _mm512_fmadd_pd(a_low_f64x8, b_rot2_low_f64x8, product_rotation_2_low_f64x8);
+        product_rotation_2_high_f64x8 = _mm512_fmadd_pd(a_high_f64x8, b_rot2_high_f64x8, product_rotation_2_high_f64x8);
     }
 
-    return _mm512_reduce_add_ps(sum_squared_f32x16);
+    // Post-loop: gather each (accumulator, a-channel) pair into a single 8-lane vector via one
+    // maskz-permutex2var_pd across (low, high) halves, then one _mm512_reduce_add_pd per scalar
+    // output. 17 reductions total (6 sums + 9 H cells + 2 norms) = the scalar-output floor.
+
+    __m512d sum_a_x_f64x8 = _mm512_maskz_permutex2var_pd(channel_lanes_mask, sum_a_low_f64x8, idx_channel_x_i64x8,
+                                                         sum_a_high_f64x8);
+    __m512d sum_a_y_f64x8 = _mm512_maskz_permutex2var_pd(channel_lanes_mask, sum_a_low_f64x8, idx_channel_y_i64x8,
+                                                         sum_a_high_f64x8);
+    __m512d sum_a_z_f64x8 = _mm512_maskz_permutex2var_pd(channel_lanes_mask, sum_a_low_f64x8, idx_channel_z_i64x8,
+                                                         sum_a_high_f64x8);
+    sum_a_out[0] = _mm512_reduce_add_pd(sum_a_x_f64x8);
+    sum_a_out[1] = _mm512_reduce_add_pd(sum_a_y_f64x8);
+    sum_a_out[2] = _mm512_reduce_add_pd(sum_a_z_f64x8);
+
+    __m512d sum_b_x_f64x8 = _mm512_maskz_permutex2var_pd(channel_lanes_mask, sum_b_low_f64x8, idx_channel_x_i64x8,
+                                                         sum_b_high_f64x8);
+    __m512d sum_b_y_f64x8 = _mm512_maskz_permutex2var_pd(channel_lanes_mask, sum_b_low_f64x8, idx_channel_y_i64x8,
+                                                         sum_b_high_f64x8);
+    __m512d sum_b_z_f64x8 = _mm512_maskz_permutex2var_pd(channel_lanes_mask, sum_b_low_f64x8, idx_channel_z_i64x8,
+                                                         sum_b_high_f64x8);
+    sum_b_out[0] = _mm512_reduce_add_pd(sum_b_x_f64x8);
+    sum_b_out[1] = _mm512_reduce_add_pd(sum_b_y_f64x8);
+    sum_b_out[2] = _mm512_reduce_add_pd(sum_b_z_f64x8);
+
+    // H cells: a-channel picks which demux mask applies; prod-vector picks which b-channel the
+    // product pairs a with (diag -> same, rot1 -> +1, rot2 -> +2 mod 3).
+    __m512d product_diagonal_x_f64x8 = _mm512_maskz_permutex2var_pd( //
+        channel_lanes_mask, product_diagonal_low_f64x8, idx_channel_x_i64x8, product_diagonal_high_f64x8);
+    __m512d product_diagonal_y_f64x8 = _mm512_maskz_permutex2var_pd( //
+        channel_lanes_mask, product_diagonal_low_f64x8, idx_channel_y_i64x8, product_diagonal_high_f64x8);
+    __m512d product_diagonal_z_f64x8 = _mm512_maskz_permutex2var_pd( //
+        channel_lanes_mask, product_diagonal_low_f64x8, idx_channel_z_i64x8, product_diagonal_high_f64x8);
+    __m512d product_rotation_1_x_f64x8 = _mm512_maskz_permutex2var_pd( //
+        channel_lanes_mask, product_rotation_1_low_f64x8, idx_channel_x_i64x8, product_rotation_1_high_f64x8);
+    __m512d product_rotation_1_y_f64x8 = _mm512_maskz_permutex2var_pd( //
+        channel_lanes_mask, product_rotation_1_low_f64x8, idx_channel_y_i64x8, product_rotation_1_high_f64x8);
+    __m512d product_rotation_1_z_f64x8 = _mm512_maskz_permutex2var_pd( //
+        channel_lanes_mask, product_rotation_1_low_f64x8, idx_channel_z_i64x8, product_rotation_1_high_f64x8);
+    __m512d product_rotation_2_x_f64x8 = _mm512_maskz_permutex2var_pd( //
+        channel_lanes_mask, product_rotation_2_low_f64x8, idx_channel_x_i64x8, product_rotation_2_high_f64x8);
+    __m512d product_rotation_2_y_f64x8 = _mm512_maskz_permutex2var_pd( //
+        channel_lanes_mask, product_rotation_2_low_f64x8, idx_channel_y_i64x8, product_rotation_2_high_f64x8);
+    __m512d product_rotation_2_z_f64x8 = _mm512_maskz_permutex2var_pd( //
+        channel_lanes_mask, product_rotation_2_low_f64x8, idx_channel_z_i64x8, product_rotation_2_high_f64x8);
+
+    raw_covariance_out[0] = _mm512_reduce_add_pd(product_diagonal_x_f64x8);   // H[x,x]
+    raw_covariance_out[1] = _mm512_reduce_add_pd(product_rotation_1_x_f64x8); // H[x,y]
+    raw_covariance_out[2] = _mm512_reduce_add_pd(product_rotation_2_x_f64x8); // H[x,z]
+    raw_covariance_out[3] = _mm512_reduce_add_pd(product_rotation_2_y_f64x8); // H[y,x]
+    raw_covariance_out[4] = _mm512_reduce_add_pd(product_diagonal_y_f64x8);   // H[y,y]
+    raw_covariance_out[5] = _mm512_reduce_add_pd(product_rotation_1_y_f64x8); // H[y,z]
+    raw_covariance_out[6] = _mm512_reduce_add_pd(product_rotation_1_z_f64x8); // H[z,x]
+    raw_covariance_out[7] = _mm512_reduce_add_pd(product_rotation_2_z_f64x8); // H[z,y]
+    raw_covariance_out[8] = _mm512_reduce_add_pd(product_diagonal_z_f64x8);   // H[z,z]
+
+    // Norms collapse all three channels, no demux.
+    *norm_squared_a_out = _mm512_reduce_add_pd(_mm512_add_pd(norm_squared_a_low_f64x8, norm_squared_a_high_f64x8));
+    *norm_squared_b_out = _mm512_reduce_add_pd(_mm512_add_pd(norm_squared_b_low_f64x8, norm_squared_b_high_f64x8));
 }
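
The 17 scalars this kernel emits are uncentered moments. One plausible way a caller recovers the centered quantities Kabsch/Umeyama need (an illustrative helper, not shown anywhere in this diff) is via the standard moment identities, since Sum((a_j - ca_j)(b_k - cb_k)) = Sum(a_j * b_k) - sum_a[j] * sum_b[k] / n and Sum(||a - ca||^2) = Sum(||a||^2) - ||sum_a||^2 / n:

```c
#include <stddef.h>

/* Illustrative consumer of the streaming-stats outputs; names assumed. */
void centered_moments_from_streaming(double const sum_a[3], double const sum_b[3],
                                     double const raw_cov[9], double norm_sq_a, double norm_sq_b,
                                     size_t n, double centroid_a[3], double centroid_b[3],
                                     double h[9], double *var_a, double *var_b) {
    double inv_n = 1.0 / (double)n;
    for (int j = 0; j < 3; ++j) centroid_a[j] = sum_a[j] * inv_n, centroid_b[j] = sum_b[j] * inv_n;
    // Centered cross-covariance: Sum(a_j b_k) - n * ca_j * cb_k.
    for (int j = 0; j < 3; ++j)
        for (int k = 0; k < 3; ++k) h[j * 3 + k] = raw_cov[j * 3 + k] - sum_a[j] * sum_b[k] * inv_n;
    // Sums of squared deviations from the centroids.
    *var_a = norm_sq_a - (sum_a[0] * sum_a[0] + sum_a[1] * sum_a[1] + sum_a[2] * sum_a[2]) * inv_n;
    *var_b = norm_sq_b - (sum_b[0] * sum_b[0] + sum_b[1] * sum_b[1] + sum_b[2] * sum_b[2]) * inv_n;
}
```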
 
 NK_PUBLIC void nk_rmsd_f32_skylake(nk_f32_t const *a, nk_f32_t const *b, nk_size_t n, nk_f32_t *a_centroid,
@@ -647,277 +433,55 @@ NK_PUBLIC void nk_rmsd_f32_skylake(nk_f32_t const *a, nk_f32_t const *b, nk_size
     if (a_centroid) a_centroid[0] = 0, a_centroid[1] = 0, a_centroid[2] = 0;
     if (b_centroid) b_centroid[0] = 0, b_centroid[1] = 0, b_centroid[2] = 0;
 
-    __m512d const zeros_f64x8 = _mm512_setzero_pd();
-    __m512d sum_squared_x_f64x8 = zeros_f64x8, sum_squared_y_f64x8 = zeros_f64x8, sum_squared_z_f64x8 = zeros_f64x8;
-    __m512 a_x_f32x16, a_y_f32x16, a_z_f32x16, b_x_f32x16, b_y_f32x16, b_z_f32x16;
-    nk_size_t i = 0;
-
-    // Main loop with 2x unrolling (32 points per iteration)
-    for (; i + 32 <= n; i += 32) {
-        // Iteration 0: points i..i+15
-        nk_deinterleave_f32x16_skylake_(a + i * 3, &a_x_f32x16, &a_y_f32x16, &a_z_f32x16);
-        nk_deinterleave_f32x16_skylake_(b + i * 3, &b_x_f32x16, &b_y_f32x16, &b_z_f32x16);
-        __m512d a_x_low_f64x8 = _mm512_cvtps_pd(_mm512_castps512_ps256(a_x_f32x16));
-        __m512d a_x_high_f64x8 = _mm512_cvtps_pd(_mm512_extractf32x8_ps(a_x_f32x16, 1));
-        __m512d a_y_low_f64x8 = _mm512_cvtps_pd(_mm512_castps512_ps256(a_y_f32x16));
-        __m512d a_y_high_f64x8 = _mm512_cvtps_pd(_mm512_extractf32x8_ps(a_y_f32x16, 1));
-        __m512d a_z_low_f64x8 = _mm512_cvtps_pd(_mm512_castps512_ps256(a_z_f32x16));
-        __m512d a_z_high_f64x8 = _mm512_cvtps_pd(_mm512_extractf32x8_ps(a_z_f32x16, 1));
-        __m512d b_x_low_f64x8 = _mm512_cvtps_pd(_mm512_castps512_ps256(b_x_f32x16));
-        __m512d b_x_high_f64x8 = _mm512_cvtps_pd(_mm512_extractf32x8_ps(b_x_f32x16, 1));
-        __m512d b_y_low_f64x8 = _mm512_cvtps_pd(_mm512_castps512_ps256(b_y_f32x16));
-        __m512d b_y_high_f64x8 = _mm512_cvtps_pd(_mm512_extractf32x8_ps(b_y_f32x16, 1));
-        __m512d b_z_low_f64x8 = _mm512_cvtps_pd(_mm512_castps512_ps256(b_z_f32x16));
-        __m512d b_z_high_f64x8 = _mm512_cvtps_pd(_mm512_extractf32x8_ps(b_z_f32x16, 1));
-
-        __m512d delta_x_low_f64x8 = _mm512_sub_pd(a_x_low_f64x8, b_x_low_f64x8);
-        __m512d delta_x_high_f64x8 = _mm512_sub_pd(a_x_high_f64x8, b_x_high_f64x8);
-        __m512d delta_y_low_f64x8 = _mm512_sub_pd(a_y_low_f64x8, b_y_low_f64x8);
-        __m512d delta_y_high_f64x8 = _mm512_sub_pd(a_y_high_f64x8, b_y_high_f64x8);
-        __m512d delta_z_low_f64x8 = _mm512_sub_pd(a_z_low_f64x8, b_z_low_f64x8);
-        __m512d delta_z_high_f64x8 = _mm512_sub_pd(a_z_high_f64x8, b_z_high_f64x8);
-        sum_squared_x_f64x8 = _mm512_fmadd_pd(delta_x_low_f64x8, delta_x_low_f64x8, sum_squared_x_f64x8);
-        sum_squared_x_f64x8 = _mm512_fmadd_pd(delta_x_high_f64x8, delta_x_high_f64x8, sum_squared_x_f64x8);
-        sum_squared_y_f64x8 = _mm512_fmadd_pd(delta_y_low_f64x8, delta_y_low_f64x8, sum_squared_y_f64x8);
-        sum_squared_y_f64x8 = _mm512_fmadd_pd(delta_y_high_f64x8, delta_y_high_f64x8, sum_squared_y_f64x8);
-        sum_squared_z_f64x8 = _mm512_fmadd_pd(delta_z_low_f64x8, delta_z_low_f64x8, sum_squared_z_f64x8);
-        sum_squared_z_f64x8 = _mm512_fmadd_pd(delta_z_high_f64x8, delta_z_high_f64x8, sum_squared_z_f64x8);
-
-        // Iteration 1: points i+16..i+31
-        nk_deinterleave_f32x16_skylake_(a + (i + 16) * 3, &a_x_f32x16, &a_y_f32x16, &a_z_f32x16);
-        nk_deinterleave_f32x16_skylake_(b + (i + 16) * 3, &b_x_f32x16, &b_y_f32x16, &b_z_f32x16);
-        a_x_low_f64x8 = _mm512_cvtps_pd(_mm512_castps512_ps256(a_x_f32x16));
-        a_x_high_f64x8 = _mm512_cvtps_pd(_mm512_extractf32x8_ps(a_x_f32x16, 1));
-        a_y_low_f64x8 = _mm512_cvtps_pd(_mm512_castps512_ps256(a_y_f32x16));
-        a_y_high_f64x8 = _mm512_cvtps_pd(_mm512_extractf32x8_ps(a_y_f32x16, 1));
-        a_z_low_f64x8 = _mm512_cvtps_pd(_mm512_castps512_ps256(a_z_f32x16));
-        a_z_high_f64x8 = _mm512_cvtps_pd(_mm512_extractf32x8_ps(a_z_f32x16, 1));
-        b_x_low_f64x8 = _mm512_cvtps_pd(_mm512_castps512_ps256(b_x_f32x16));
-        b_x_high_f64x8 = _mm512_cvtps_pd(_mm512_extractf32x8_ps(b_x_f32x16, 1));
-        b_y_low_f64x8 = _mm512_cvtps_pd(_mm512_castps512_ps256(b_y_f32x16));
-        b_y_high_f64x8 = _mm512_cvtps_pd(_mm512_extractf32x8_ps(b_y_f32x16, 1));
-        b_z_low_f64x8 = _mm512_cvtps_pd(_mm512_castps512_ps256(b_z_f32x16));
-        b_z_high_f64x8 = _mm512_cvtps_pd(_mm512_extractf32x8_ps(b_z_f32x16, 1));
-
-        delta_x_low_f64x8 = _mm512_sub_pd(a_x_low_f64x8, b_x_low_f64x8);
-        delta_x_high_f64x8 = _mm512_sub_pd(a_x_high_f64x8, b_x_high_f64x8);
-        delta_y_low_f64x8 = _mm512_sub_pd(a_y_low_f64x8, b_y_low_f64x8);
-        delta_y_high_f64x8 = _mm512_sub_pd(a_y_high_f64x8, b_y_high_f64x8);
-        delta_z_low_f64x8 = _mm512_sub_pd(a_z_low_f64x8, b_z_low_f64x8);
-        delta_z_high_f64x8 = _mm512_sub_pd(a_z_high_f64x8, b_z_high_f64x8);
-        sum_squared_x_f64x8 = _mm512_fmadd_pd(delta_x_low_f64x8, delta_x_low_f64x8, sum_squared_x_f64x8);
-        sum_squared_x_f64x8 = _mm512_fmadd_pd(delta_x_high_f64x8, delta_x_high_f64x8, sum_squared_x_f64x8);
-        sum_squared_y_f64x8 = _mm512_fmadd_pd(delta_y_low_f64x8, delta_y_low_f64x8, sum_squared_y_f64x8);
-        sum_squared_y_f64x8 = _mm512_fmadd_pd(delta_y_high_f64x8, delta_y_high_f64x8, sum_squared_y_f64x8);
-        sum_squared_z_f64x8 = _mm512_fmadd_pd(delta_z_low_f64x8, delta_z_low_f64x8, sum_squared_z_f64x8);
-        sum_squared_z_f64x8 = _mm512_fmadd_pd(delta_z_high_f64x8, delta_z_high_f64x8, sum_squared_z_f64x8);
-    }
+    // 15-lane stride-3 chunks: 5 points (15 fp32) per iteration, lane 15 zeroed by mask.
+    // Identity rotation + zero centroid (per commit 1a83ab4f) make this a single (a-b)^2 sum.
+    __m512d sum_squared_low_f64x8 = _mm512_setzero_pd();
+    __m512d sum_squared_high_f64x8 = _mm512_setzero_pd();
+    nk_size_t index = 0;
 
-    // Handle 16-point remainder
-    for (; i + 16 <= n; i += 16) {
-        nk_deinterleave_f32x16_skylake_(a + i * 3, &a_x_f32x16, &a_y_f32x16, &a_z_f32x16);
-        nk_deinterleave_f32x16_skylake_(b + i * 3, &b_x_f32x16, &b_y_f32x16, &b_z_f32x16);
-        __m512d a_x_low_f64x8 = _mm512_cvtps_pd(_mm512_castps512_ps256(a_x_f32x16));
-        __m512d a_x_high_f64x8 = _mm512_cvtps_pd(_mm512_extractf32x8_ps(a_x_f32x16, 1));
-        __m512d a_y_low_f64x8 = _mm512_cvtps_pd(_mm512_castps512_ps256(a_y_f32x16));
-        __m512d a_y_high_f64x8 = _mm512_cvtps_pd(_mm512_extractf32x8_ps(a_y_f32x16, 1));
-        __m512d a_z_low_f64x8 = _mm512_cvtps_pd(_mm512_castps512_ps256(a_z_f32x16));
-        __m512d a_z_high_f64x8 = _mm512_cvtps_pd(_mm512_extractf32x8_ps(a_z_f32x16, 1));
-        __m512d b_x_low_f64x8 = _mm512_cvtps_pd(_mm512_castps512_ps256(b_x_f32x16));
-        __m512d b_x_high_f64x8 = _mm512_cvtps_pd(_mm512_extractf32x8_ps(b_x_f32x16, 1));
-        __m512d b_y_low_f64x8 = _mm512_cvtps_pd(_mm512_castps512_ps256(b_y_f32x16));
-        __m512d b_y_high_f64x8 = _mm512_cvtps_pd(_mm512_extractf32x8_ps(b_y_f32x16, 1));
-        __m512d b_z_low_f64x8 = _mm512_cvtps_pd(_mm512_castps512_ps256(b_z_f32x16));
-        __m512d b_z_high_f64x8 = _mm512_cvtps_pd(_mm512_extractf32x8_ps(b_z_f32x16, 1));
-
-        __m512d delta_x_low_f64x8 = _mm512_sub_pd(a_x_low_f64x8, b_x_low_f64x8);
-        __m512d delta_x_high_f64x8 = _mm512_sub_pd(a_x_high_f64x8, b_x_high_f64x8);
-        __m512d delta_y_low_f64x8 = _mm512_sub_pd(a_y_low_f64x8, b_y_low_f64x8);
-        __m512d delta_y_high_f64x8 = _mm512_sub_pd(a_y_high_f64x8, b_y_high_f64x8);
-        __m512d delta_z_low_f64x8 = _mm512_sub_pd(a_z_low_f64x8, b_z_low_f64x8);
-        __m512d delta_z_high_f64x8 = _mm512_sub_pd(a_z_high_f64x8, b_z_high_f64x8);
-        sum_squared_x_f64x8 = _mm512_fmadd_pd(delta_x_low_f64x8, delta_x_low_f64x8, sum_squared_x_f64x8);
-        sum_squared_x_f64x8 = _mm512_fmadd_pd(delta_x_high_f64x8, delta_x_high_f64x8, sum_squared_x_f64x8);
-        sum_squared_y_f64x8 = _mm512_fmadd_pd(delta_y_low_f64x8, delta_y_low_f64x8, sum_squared_y_f64x8);
-        sum_squared_y_f64x8 = _mm512_fmadd_pd(delta_y_high_f64x8, delta_y_high_f64x8, sum_squared_y_f64x8);
-        sum_squared_z_f64x8 = _mm512_fmadd_pd(delta_z_low_f64x8, delta_z_low_f64x8, sum_squared_z_f64x8);
-        sum_squared_z_f64x8 = _mm512_fmadd_pd(delta_z_high_f64x8, delta_z_high_f64x8, sum_squared_z_f64x8);
+    for (; index + 5 <= n; index += 5) {
+        __m512 a_f32x16 = _mm512_maskz_loadu_ps(0x7FFF, a + index * 3);
+        __m512 b_f32x16 = _mm512_maskz_loadu_ps(0x7FFF, b + index * 3);
+        // Widen before subtracting: fp32 subtraction catastrophically cancels when a ~ b.
+        __m512d a_low_f64x8 = _mm512_cvtps_pd(_mm512_castps512_ps256(a_f32x16));
+        __m512d a_high_f64x8 = _mm512_cvtps_pd(_mm512_extractf32x8_ps(a_f32x16, 1));
+        __m512d b_low_f64x8 = _mm512_cvtps_pd(_mm512_castps512_ps256(b_f32x16));
+        __m512d b_high_f64x8 = _mm512_cvtps_pd(_mm512_extractf32x8_ps(b_f32x16, 1));
+        __m512d delta_low_f64x8 = _mm512_sub_pd(a_low_f64x8, b_low_f64x8);
+        __m512d delta_high_f64x8 = _mm512_sub_pd(a_high_f64x8, b_high_f64x8);
+        sum_squared_low_f64x8 = _mm512_fmadd_pd(delta_low_f64x8, delta_low_f64x8, sum_squared_low_f64x8);
+        sum_squared_high_f64x8 = _mm512_fmadd_pd(delta_high_f64x8, delta_high_f64x8, sum_squared_high_f64x8);
     }
 
-    // Tail: use masked gather for remaining < 16 points
-    if (i < n) {
-        nk_size_t tail = n - i;
-        __mmask16 mask = (__mmask16)_bzhi_u32(0xFFFF, tail);
-        __m512i const gather_idx_i32x16 = _mm512_setr_epi32(0, 3, 6, 9, 12, 15, 18, 21, 24, 27, 30, 33, 36, 39, 42, 45);
-        __m512 zeros_f32x16 = _mm512_setzero_ps();
-        nk_f32_t const *a_tail = a + i * 3;
-        nk_f32_t const *b_tail = b + i * 3;
-
-        a_x_f32x16 = _mm512_mask_i32gather_ps(zeros_f32x16, mask, gather_idx_i32x16, a_tail + 0, 4);
-        a_y_f32x16 = _mm512_mask_i32gather_ps(zeros_f32x16, mask, gather_idx_i32x16, a_tail + 1, 4);
-        a_z_f32x16 = _mm512_mask_i32gather_ps(zeros_f32x16, mask, gather_idx_i32x16, a_tail + 2, 4);
-        b_x_f32x16 = _mm512_mask_i32gather_ps(zeros_f32x16, mask, gather_idx_i32x16, b_tail + 0, 4);
-        b_y_f32x16 = _mm512_mask_i32gather_ps(zeros_f32x16, mask, gather_idx_i32x16, b_tail + 1, 4);
-        b_z_f32x16 = _mm512_mask_i32gather_ps(zeros_f32x16, mask, gather_idx_i32x16, b_tail + 2, 4);
-
-        __m512d a_x_low_f64x8 = _mm512_cvtps_pd(_mm512_castps512_ps256(a_x_f32x16));
-        __m512d a_x_high_f64x8 = _mm512_cvtps_pd(_mm512_extractf32x8_ps(a_x_f32x16, 1));
-        __m512d a_y_low_f64x8 = _mm512_cvtps_pd(_mm512_castps512_ps256(a_y_f32x16));
-        __m512d a_y_high_f64x8 = _mm512_cvtps_pd(_mm512_extractf32x8_ps(a_y_f32x16, 1));
-        __m512d a_z_low_f64x8 = _mm512_cvtps_pd(_mm512_castps512_ps256(a_z_f32x16));
-        __m512d a_z_high_f64x8 = _mm512_cvtps_pd(_mm512_extractf32x8_ps(a_z_f32x16, 1));
-        __m512d b_x_low_f64x8 = _mm512_cvtps_pd(_mm512_castps512_ps256(b_x_f32x16));
-        __m512d b_x_high_f64x8 = _mm512_cvtps_pd(_mm512_extractf32x8_ps(b_x_f32x16, 1));
-        __m512d b_y_low_f64x8 = _mm512_cvtps_pd(_mm512_castps512_ps256(b_y_f32x16));
-        __m512d b_y_high_f64x8 = _mm512_cvtps_pd(_mm512_extractf32x8_ps(b_y_f32x16, 1));
-        __m512d b_z_low_f64x8 = _mm512_cvtps_pd(_mm512_castps512_ps256(b_z_f32x16));
-        __m512d b_z_high_f64x8 = _mm512_cvtps_pd(_mm512_extractf32x8_ps(b_z_f32x16, 1));
-
-        __m512d delta_x_low_f64x8 = _mm512_sub_pd(a_x_low_f64x8, b_x_low_f64x8);
-        __m512d delta_x_high_f64x8 = _mm512_sub_pd(a_x_high_f64x8, b_x_high_f64x8);
-        __m512d delta_y_low_f64x8 = _mm512_sub_pd(a_y_low_f64x8, b_y_low_f64x8);
-        __m512d delta_y_high_f64x8 = _mm512_sub_pd(a_y_high_f64x8, b_y_high_f64x8);
-        __m512d delta_z_low_f64x8 = _mm512_sub_pd(a_z_low_f64x8, b_z_low_f64x8);
-        __m512d delta_z_high_f64x8 = _mm512_sub_pd(a_z_high_f64x8, b_z_high_f64x8);
-        sum_squared_x_f64x8 = _mm512_fmadd_pd(delta_x_low_f64x8, delta_x_low_f64x8, sum_squared_x_f64x8);
-        sum_squared_x_f64x8 = _mm512_fmadd_pd(delta_x_high_f64x8, delta_x_high_f64x8, sum_squared_x_f64x8);
-        sum_squared_y_f64x8 = _mm512_fmadd_pd(delta_y_low_f64x8, delta_y_low_f64x8, sum_squared_y_f64x8);
-        sum_squared_y_f64x8 = _mm512_fmadd_pd(delta_y_high_f64x8, delta_y_high_f64x8, sum_squared_y_f64x8);
-        sum_squared_z_f64x8 = _mm512_fmadd_pd(delta_z_low_f64x8, delta_z_low_f64x8, sum_squared_z_f64x8);
-        sum_squared_z_f64x8 = _mm512_fmadd_pd(delta_z_high_f64x8, delta_z_high_f64x8, sum_squared_z_f64x8);
+    if (index < n) {
+        nk_size_t tail_floats = (n - index) * 3;
+        __mmask16 tail_mask = (__mmask16)_bzhi_u32(0x7FFF, tail_floats);
+        __m512 a_f32x16 = _mm512_maskz_loadu_ps(tail_mask, a + index * 3);
+        __m512 b_f32x16 = _mm512_maskz_loadu_ps(tail_mask, b + index * 3);
+        __m512d a_low_f64x8 = _mm512_cvtps_pd(_mm512_castps512_ps256(a_f32x16));
+        __m512d a_high_f64x8 = _mm512_cvtps_pd(_mm512_extractf32x8_ps(a_f32x16, 1));
+        __m512d b_low_f64x8 = _mm512_cvtps_pd(_mm512_castps512_ps256(b_f32x16));
+        __m512d b_high_f64x8 = _mm512_cvtps_pd(_mm512_extractf32x8_ps(b_f32x16, 1));
+        __m512d delta_low_f64x8 = _mm512_sub_pd(a_low_f64x8, b_low_f64x8);
+        __m512d delta_high_f64x8 = _mm512_sub_pd(a_high_f64x8, b_high_f64x8);
+        sum_squared_low_f64x8 = _mm512_fmadd_pd(delta_low_f64x8, delta_low_f64x8, sum_squared_low_f64x8);
+        sum_squared_high_f64x8 = _mm512_fmadd_pd(delta_high_f64x8, delta_high_f64x8, sum_squared_high_f64x8);
     }
 
-    nk_f64_t total_sq_x = _mm512_reduce_add_pd(sum_squared_x_f64x8);
-    nk_f64_t total_sq_y = _mm512_reduce_add_pd(sum_squared_y_f64x8);
-    nk_f64_t total_sq_z = _mm512_reduce_add_pd(sum_squared_z_f64x8);
-    *result = nk_f64_sqrt_haswell((total_sq_x + total_sq_y + total_sq_z) / (nk_f64_t)n);
+    nk_f64_t sum_squared = _mm512_reduce_add_pd(_mm512_add_pd(sum_squared_low_f64x8, sum_squared_high_f64x8));
+    *result = nk_f64_sqrt_haswell(sum_squared / (nk_f64_t)n);
 }
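
A quick check, illustrative and not from the package, of the `_bzhi_u32` tail-mask math both new tails share: a leftover of k points spans 3k fp32 lanes, so the mask keeps the low 3k bits of 0x7FFF and can never reach lane 15.

```c
#include <immintrin.h> // _bzhi_u32 (BMI2), __mmask16 (AVX-512)

/* Hypothetical helper mirroring the kernels' mask computation. */
static __mmask16 tail_mask_for_points(unsigned tail_points) { // tail_points in 1..4
    // _bzhi_u32 zeroes all bits at or above the given index.
    return (__mmask16)_bzhi_u32(0x7FFF, tail_points * 3);
}

// tail_mask_for_points(1) == 0x0007, (2) == 0x003F,
// (3) == 0x01FF, (4) == 0x0FFF -- lane 15 stays clear in every case.
```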
 
796
475
  NK_PUBLIC void nk_kabsch_f32_skylake(nk_f32_t const *a, nk_f32_t const *b, nk_size_t n, nk_f32_t *a_centroid,
797
476
  nk_f32_t *b_centroid, nk_f32_t *rotation, nk_f32_t *scale, nk_f64_t *result) {
- // Fused single-pass: centroids + covariance in f64
- __m512d const zeros_f64x8 = _mm512_setzero_pd();
- __m512d sum_a_x_f64x8 = zeros_f64x8, sum_a_y_f64x8 = zeros_f64x8, sum_a_z_f64x8 = zeros_f64x8;
- __m512d sum_b_x_f64x8 = zeros_f64x8, sum_b_y_f64x8 = zeros_f64x8, sum_b_z_f64x8 = zeros_f64x8;
- __m512d cov_xx_f64x8 = zeros_f64x8, cov_xy_f64x8 = zeros_f64x8, cov_xz_f64x8 = zeros_f64x8;
- __m512d cov_yx_f64x8 = zeros_f64x8, cov_yy_f64x8 = zeros_f64x8, cov_yz_f64x8 = zeros_f64x8;
- __m512d cov_zx_f64x8 = zeros_f64x8, cov_zy_f64x8 = zeros_f64x8, cov_zz_f64x8 = zeros_f64x8;
- __m512 a_x_f32x16, a_y_f32x16, a_z_f32x16, b_x_f32x16, b_y_f32x16, b_z_f32x16;
- nk_size_t i = 0;
-
- for (; i + 16 <= n; i += 16) {
- nk_deinterleave_f32x16_skylake_(a + i * 3, &a_x_f32x16, &a_y_f32x16, &a_z_f32x16);
- nk_deinterleave_f32x16_skylake_(b + i * 3, &b_x_f32x16, &b_y_f32x16, &b_z_f32x16);
- __m512d a_x_low_f64x8 = _mm512_cvtps_pd(_mm512_castps512_ps256(a_x_f32x16));
- __m512d a_x_high_f64x8 = _mm512_cvtps_pd(_mm512_extractf32x8_ps(a_x_f32x16, 1));
- __m512d a_y_low_f64x8 = _mm512_cvtps_pd(_mm512_castps512_ps256(a_y_f32x16));
- __m512d a_y_high_f64x8 = _mm512_cvtps_pd(_mm512_extractf32x8_ps(a_y_f32x16, 1));
- __m512d a_z_low_f64x8 = _mm512_cvtps_pd(_mm512_castps512_ps256(a_z_f32x16));
- __m512d a_z_high_f64x8 = _mm512_cvtps_pd(_mm512_extractf32x8_ps(a_z_f32x16, 1));
- __m512d b_x_low_f64x8 = _mm512_cvtps_pd(_mm512_castps512_ps256(b_x_f32x16));
- __m512d b_x_high_f64x8 = _mm512_cvtps_pd(_mm512_extractf32x8_ps(b_x_f32x16, 1));
- __m512d b_y_low_f64x8 = _mm512_cvtps_pd(_mm512_castps512_ps256(b_y_f32x16));
- __m512d b_y_high_f64x8 = _mm512_cvtps_pd(_mm512_extractf32x8_ps(b_y_f32x16, 1));
- __m512d b_z_low_f64x8 = _mm512_cvtps_pd(_mm512_castps512_ps256(b_z_f32x16));
- __m512d b_z_high_f64x8 = _mm512_cvtps_pd(_mm512_extractf32x8_ps(b_z_f32x16, 1));
-
- sum_a_x_f64x8 = _mm512_add_pd(sum_a_x_f64x8, _mm512_add_pd(a_x_low_f64x8, a_x_high_f64x8));
- sum_a_y_f64x8 = _mm512_add_pd(sum_a_y_f64x8, _mm512_add_pd(a_y_low_f64x8, a_y_high_f64x8));
- sum_a_z_f64x8 = _mm512_add_pd(sum_a_z_f64x8, _mm512_add_pd(a_z_low_f64x8, a_z_high_f64x8));
- sum_b_x_f64x8 = _mm512_add_pd(sum_b_x_f64x8, _mm512_add_pd(b_x_low_f64x8, b_x_high_f64x8));
- sum_b_y_f64x8 = _mm512_add_pd(sum_b_y_f64x8, _mm512_add_pd(b_y_low_f64x8, b_y_high_f64x8));
- sum_b_z_f64x8 = _mm512_add_pd(sum_b_z_f64x8, _mm512_add_pd(b_z_low_f64x8, b_z_high_f64x8));
-
- cov_xx_f64x8 = _mm512_add_pd(cov_xx_f64x8, _mm512_add_pd(_mm512_mul_pd(a_x_low_f64x8, b_x_low_f64x8),
- _mm512_mul_pd(a_x_high_f64x8, b_x_high_f64x8)));
- cov_xy_f64x8 = _mm512_add_pd(cov_xy_f64x8, _mm512_add_pd(_mm512_mul_pd(a_x_low_f64x8, b_y_low_f64x8),
- _mm512_mul_pd(a_x_high_f64x8, b_y_high_f64x8)));
- cov_xz_f64x8 = _mm512_add_pd(cov_xz_f64x8, _mm512_add_pd(_mm512_mul_pd(a_x_low_f64x8, b_z_low_f64x8),
- _mm512_mul_pd(a_x_high_f64x8, b_z_high_f64x8)));
- cov_yx_f64x8 = _mm512_add_pd(cov_yx_f64x8, _mm512_add_pd(_mm512_mul_pd(a_y_low_f64x8, b_x_low_f64x8),
- _mm512_mul_pd(a_y_high_f64x8, b_x_high_f64x8)));
- cov_yy_f64x8 = _mm512_add_pd(cov_yy_f64x8, _mm512_add_pd(_mm512_mul_pd(a_y_low_f64x8, b_y_low_f64x8),
- _mm512_mul_pd(a_y_high_f64x8, b_y_high_f64x8)));
- cov_yz_f64x8 = _mm512_add_pd(cov_yz_f64x8, _mm512_add_pd(_mm512_mul_pd(a_y_low_f64x8, b_z_low_f64x8),
- _mm512_mul_pd(a_y_high_f64x8, b_z_high_f64x8)));
- cov_zx_f64x8 = _mm512_add_pd(cov_zx_f64x8, _mm512_add_pd(_mm512_mul_pd(a_z_low_f64x8, b_x_low_f64x8),
- _mm512_mul_pd(a_z_high_f64x8, b_x_high_f64x8)));
- cov_zy_f64x8 = _mm512_add_pd(cov_zy_f64x8, _mm512_add_pd(_mm512_mul_pd(a_z_low_f64x8, b_y_low_f64x8),
- _mm512_mul_pd(a_z_high_f64x8, b_y_high_f64x8)));
- cov_zz_f64x8 = _mm512_add_pd(cov_zz_f64x8, _mm512_add_pd(_mm512_mul_pd(a_z_low_f64x8, b_z_low_f64x8),
- _mm512_mul_pd(a_z_high_f64x8, b_z_high_f64x8)));
- }
-
- // Tail: use masked gather for remaining < 16 points
- if (i < n) {
- nk_size_t tail = n - i;
- __mmask16 mask = (__mmask16)_bzhi_u32(0xFFFF, tail);
- __m512i const gather_idx_i32x16 = _mm512_setr_epi32(0, 3, 6, 9, 12, 15, 18, 21, 24, 27, 30, 33, 36, 39, 42, 45);
- __m512 zeros_f32x16 = _mm512_setzero_ps();
- nk_f32_t const *a_tail = a + i * 3;
- nk_f32_t const *b_tail = b + i * 3;
-
- a_x_f32x16 = _mm512_mask_i32gather_ps(zeros_f32x16, mask, gather_idx_i32x16, a_tail + 0, 4);
- a_y_f32x16 = _mm512_mask_i32gather_ps(zeros_f32x16, mask, gather_idx_i32x16, a_tail + 1, 4);
- a_z_f32x16 = _mm512_mask_i32gather_ps(zeros_f32x16, mask, gather_idx_i32x16, a_tail + 2, 4);
- b_x_f32x16 = _mm512_mask_i32gather_ps(zeros_f32x16, mask, gather_idx_i32x16, b_tail + 0, 4);
- b_y_f32x16 = _mm512_mask_i32gather_ps(zeros_f32x16, mask, gather_idx_i32x16, b_tail + 1, 4);
- b_z_f32x16 = _mm512_mask_i32gather_ps(zeros_f32x16, mask, gather_idx_i32x16, b_tail + 2, 4);
-
- __m512d a_x_low_f64x8 = _mm512_cvtps_pd(_mm512_castps512_ps256(a_x_f32x16));
- __m512d a_x_high_f64x8 = _mm512_cvtps_pd(_mm512_extractf32x8_ps(a_x_f32x16, 1));
- __m512d a_y_low_f64x8 = _mm512_cvtps_pd(_mm512_castps512_ps256(a_y_f32x16));
- __m512d a_y_high_f64x8 = _mm512_cvtps_pd(_mm512_extractf32x8_ps(a_y_f32x16, 1));
- __m512d a_z_low_f64x8 = _mm512_cvtps_pd(_mm512_castps512_ps256(a_z_f32x16));
- __m512d a_z_high_f64x8 = _mm512_cvtps_pd(_mm512_extractf32x8_ps(a_z_f32x16, 1));
- __m512d b_x_low_f64x8 = _mm512_cvtps_pd(_mm512_castps512_ps256(b_x_f32x16));
- __m512d b_x_high_f64x8 = _mm512_cvtps_pd(_mm512_extractf32x8_ps(b_x_f32x16, 1));
- __m512d b_y_low_f64x8 = _mm512_cvtps_pd(_mm512_castps512_ps256(b_y_f32x16));
- __m512d b_y_high_f64x8 = _mm512_cvtps_pd(_mm512_extractf32x8_ps(b_y_f32x16, 1));
- __m512d b_z_low_f64x8 = _mm512_cvtps_pd(_mm512_castps512_ps256(b_z_f32x16));
- __m512d b_z_high_f64x8 = _mm512_cvtps_pd(_mm512_extractf32x8_ps(b_z_f32x16, 1));
-
- sum_a_x_f64x8 = _mm512_add_pd(sum_a_x_f64x8, _mm512_add_pd(a_x_low_f64x8, a_x_high_f64x8));
- sum_a_y_f64x8 = _mm512_add_pd(sum_a_y_f64x8, _mm512_add_pd(a_y_low_f64x8, a_y_high_f64x8));
- sum_a_z_f64x8 = _mm512_add_pd(sum_a_z_f64x8, _mm512_add_pd(a_z_low_f64x8, a_z_high_f64x8));
- sum_b_x_f64x8 = _mm512_add_pd(sum_b_x_f64x8, _mm512_add_pd(b_x_low_f64x8, b_x_high_f64x8));
- sum_b_y_f64x8 = _mm512_add_pd(sum_b_y_f64x8, _mm512_add_pd(b_y_low_f64x8, b_y_high_f64x8));
- sum_b_z_f64x8 = _mm512_add_pd(sum_b_z_f64x8, _mm512_add_pd(b_z_low_f64x8, b_z_high_f64x8));
-
- cov_xx_f64x8 = _mm512_add_pd(cov_xx_f64x8, _mm512_add_pd(_mm512_mul_pd(a_x_low_f64x8, b_x_low_f64x8),
- _mm512_mul_pd(a_x_high_f64x8, b_x_high_f64x8)));
- cov_xy_f64x8 = _mm512_add_pd(cov_xy_f64x8, _mm512_add_pd(_mm512_mul_pd(a_x_low_f64x8, b_y_low_f64x8),
- _mm512_mul_pd(a_x_high_f64x8, b_y_high_f64x8)));
- cov_xz_f64x8 = _mm512_add_pd(cov_xz_f64x8, _mm512_add_pd(_mm512_mul_pd(a_x_low_f64x8, b_z_low_f64x8),
- _mm512_mul_pd(a_x_high_f64x8, b_z_high_f64x8)));
- cov_yx_f64x8 = _mm512_add_pd(cov_yx_f64x8, _mm512_add_pd(_mm512_mul_pd(a_y_low_f64x8, b_x_low_f64x8),
- _mm512_mul_pd(a_y_high_f64x8, b_x_high_f64x8)));
- cov_yy_f64x8 = _mm512_add_pd(cov_yy_f64x8, _mm512_add_pd(_mm512_mul_pd(a_y_low_f64x8, b_y_low_f64x8),
- _mm512_mul_pd(a_y_high_f64x8, b_y_high_f64x8)));
- cov_yz_f64x8 = _mm512_add_pd(cov_yz_f64x8, _mm512_add_pd(_mm512_mul_pd(a_y_low_f64x8, b_z_low_f64x8),
- _mm512_mul_pd(a_y_high_f64x8, b_z_high_f64x8)));
- cov_zx_f64x8 = _mm512_add_pd(cov_zx_f64x8, _mm512_add_pd(_mm512_mul_pd(a_z_low_f64x8, b_x_low_f64x8),
- _mm512_mul_pd(a_z_high_f64x8, b_x_high_f64x8)));
- cov_zy_f64x8 = _mm512_add_pd(cov_zy_f64x8, _mm512_add_pd(_mm512_mul_pd(a_z_low_f64x8, b_y_low_f64x8),
- _mm512_mul_pd(a_z_high_f64x8, b_y_high_f64x8)));
- cov_zz_f64x8 = _mm512_add_pd(cov_zz_f64x8, _mm512_add_pd(_mm512_mul_pd(a_z_low_f64x8, b_z_low_f64x8),
- _mm512_mul_pd(a_z_high_f64x8, b_z_high_f64x8)));
- }
+ // Single pass over (a, b) via streaming-stats helper — no deinterleave, no second SSD pass.
+ nk_f64_t sum_a[3], sum_b[3], raw_covariance[9], norm_squared_a, norm_squared_b;
+ nk_mesh_streaming_stats_f32_skylake_(a, b, n, sum_a, sum_b, raw_covariance, &norm_squared_a, &norm_squared_b);

- nk_f64_t sum_a_x = _mm512_reduce_add_pd(sum_a_x_f64x8), sum_a_y = _mm512_reduce_add_pd(sum_a_y_f64x8),
- sum_a_z = _mm512_reduce_add_pd(sum_a_z_f64x8);
- nk_f64_t sum_b_x = _mm512_reduce_add_pd(sum_b_x_f64x8), sum_b_y = _mm512_reduce_add_pd(sum_b_y_f64x8),
- sum_b_z = _mm512_reduce_add_pd(sum_b_z_f64x8);
- nk_f64_t covariance_x_x = _mm512_reduce_add_pd(cov_xx_f64x8), covariance_x_y = _mm512_reduce_add_pd(cov_xy_f64x8),
- covariance_x_z = _mm512_reduce_add_pd(cov_xz_f64x8);
- nk_f64_t covariance_y_x = _mm512_reduce_add_pd(cov_yx_f64x8), covariance_y_y = _mm512_reduce_add_pd(cov_yy_f64x8),
- covariance_y_z = _mm512_reduce_add_pd(cov_yz_f64x8);
- nk_f64_t covariance_z_x = _mm512_reduce_add_pd(cov_zx_f64x8), covariance_z_y = _mm512_reduce_add_pd(cov_zy_f64x8),
- covariance_z_z = _mm512_reduce_add_pd(cov_zz_f64x8);
-
- nk_f64_t inv_n = 1.0 / (nk_f64_t)n;
- nk_f64_t centroid_a_x = sum_a_x * inv_n, centroid_a_y = sum_a_y * inv_n, centroid_a_z = sum_a_z * inv_n;
- nk_f64_t centroid_b_x = sum_b_x * inv_n, centroid_b_y = sum_b_y * inv_n, centroid_b_z = sum_b_z * inv_n;
+ nk_f64_t n_f64 = (nk_f64_t)n;
+ nk_f64_t inv_n = 1.0 / n_f64;
+ nk_f64_t centroid_a_x = sum_a[0] * inv_n, centroid_a_y = sum_a[1] * inv_n, centroid_a_z = sum_a[2] * inv_n;
+ nk_f64_t centroid_b_x = sum_b[0] * inv_n, centroid_b_y = sum_b[1] * inv_n, centroid_b_z = sum_b[2] * inv_n;
  if (a_centroid)
  a_centroid[0] = (nk_f32_t)centroid_a_x, a_centroid[1] = (nk_f32_t)centroid_a_y,
  a_centroid[2] = (nk_f32_t)centroid_a_z;
@@ -926,32 +490,69 @@ NK_PUBLIC void nk_kabsch_f32_skylake(nk_f32_t const *a, nk_f32_t const *b, nk_si
  b_centroid[2] = (nk_f32_t)centroid_b_z;
  if (scale) *scale = 1.0f;

- nk_f64_t n_f64 = (nk_f64_t)n;
+ // Parallel-axis correction: H_centered[j,k] = Sum(a_j * b_k) - n * centroid_a[j] * centroid_b[k].
  nk_f64_t cross_covariance[9];
- cross_covariance[0] = covariance_x_x - n_f64 * centroid_a_x * centroid_b_x;
- cross_covariance[1] = covariance_x_y - n_f64 * centroid_a_x * centroid_b_y;
- cross_covariance[2] = covariance_x_z - n_f64 * centroid_a_x * centroid_b_z;
- cross_covariance[3] = covariance_y_x - n_f64 * centroid_a_y * centroid_b_x;
- cross_covariance[4] = covariance_y_y - n_f64 * centroid_a_y * centroid_b_y;
- cross_covariance[5] = covariance_y_z - n_f64 * centroid_a_y * centroid_b_z;
- cross_covariance[6] = covariance_z_x - n_f64 * centroid_a_z * centroid_b_x;
- cross_covariance[7] = covariance_z_y - n_f64 * centroid_a_z * centroid_b_y;
- cross_covariance[8] = covariance_z_z - n_f64 * centroid_a_z * centroid_b_z;
-
- nk_f64_t svd_u[9], svd_s[9], svd_v[9];
- nk_svd3x3_f64_(cross_covariance, svd_u, svd_s, svd_v);
- nk_f64_t r[9];
- nk_rotation_from_svd_f64_skylake_(svd_u, svd_v, r);
- if (nk_det3x3_f64_(r) < 0) {
- svd_v[2] = -svd_v[2], svd_v[5] = -svd_v[5], svd_v[8] = -svd_v[8];
- nk_rotation_from_svd_f64_skylake_(svd_u, svd_v, r);
+ cross_covariance[0] = raw_covariance[0] - n_f64 * centroid_a_x * centroid_b_x;
+ cross_covariance[1] = raw_covariance[1] - n_f64 * centroid_a_x * centroid_b_y;
+ cross_covariance[2] = raw_covariance[2] - n_f64 * centroid_a_x * centroid_b_z;
+ cross_covariance[3] = raw_covariance[3] - n_f64 * centroid_a_y * centroid_b_x;
+ cross_covariance[4] = raw_covariance[4] - n_f64 * centroid_a_y * centroid_b_y;
+ cross_covariance[5] = raw_covariance[5] - n_f64 * centroid_a_y * centroid_b_z;
+ cross_covariance[6] = raw_covariance[6] - n_f64 * centroid_a_z * centroid_b_x;
+ cross_covariance[7] = raw_covariance[7] - n_f64 * centroid_a_z * centroid_b_y;
+ cross_covariance[8] = raw_covariance[8] - n_f64 * centroid_a_z * centroid_b_z;
+
+ // Identity-dominant short-circuit: skip SVD + rotation_from_svd when H is near-diagonal
+ // positive-definite. `optimal_rotation` is set to identity and trace(R * H) collapses to
+ // H[0]+H[4]+H[8]. Saves ~500 cycles on aligned/pre-registered inputs; zero cost when
+ // inputs are random (branch is well-predicted in practice).
+ nk_f64_t covariance_diagonal_norm_squared = cross_covariance[0] * cross_covariance[0] +
+ cross_covariance[4] * cross_covariance[4] +
+ cross_covariance[8] * cross_covariance[8];
+ nk_f64_t covariance_offdiagonal_norm_squared =
+ cross_covariance[1] * cross_covariance[1] + cross_covariance[2] * cross_covariance[2] +
+ cross_covariance[3] * cross_covariance[3] + cross_covariance[5] * cross_covariance[5] +
+ cross_covariance[6] * cross_covariance[6] + cross_covariance[7] * cross_covariance[7];
+ nk_f64_t optimal_rotation[9];
+ nk_f64_t trace_rotation_covariance;
+ if (covariance_offdiagonal_norm_squared < 1e-20 * covariance_diagonal_norm_squared && cross_covariance[0] > 0.0 &&
+ cross_covariance[4] > 0.0 && cross_covariance[8] > 0.0) {
+ optimal_rotation[0] = 1, optimal_rotation[1] = 0, optimal_rotation[2] = 0, optimal_rotation[3] = 0,
+ optimal_rotation[4] = 1, optimal_rotation[5] = 0, optimal_rotation[6] = 0, optimal_rotation[7] = 0,
+ optimal_rotation[8] = 1;
+ trace_rotation_covariance = cross_covariance[0] + cross_covariance[4] + cross_covariance[8];
+ }
+ else {
+ nk_f64_t svd_left[9], svd_diagonal[9], svd_right[9];
+ nk_svd3x3_f64_(cross_covariance, svd_left, svd_diagonal, svd_right);
+ nk_rotation_from_svd_f64_serial_(svd_left, svd_right, optimal_rotation);
+ if (nk_det3x3_f64_(optimal_rotation) < 0) {
+ svd_right[2] = -svd_right[2], svd_right[5] = -svd_right[5], svd_right[8] = -svd_right[8];
+ nk_rotation_from_svd_f64_serial_(svd_left, svd_right, optimal_rotation);
+ }
+ trace_rotation_covariance =
+ optimal_rotation[0] * cross_covariance[0] + optimal_rotation[1] * cross_covariance[3] +
+ optimal_rotation[2] * cross_covariance[6] + optimal_rotation[3] * cross_covariance[1] +
+ optimal_rotation[4] * cross_covariance[4] + optimal_rotation[5] * cross_covariance[7] +
+ optimal_rotation[6] * cross_covariance[2] + optimal_rotation[7] * cross_covariance[5] +
+ optimal_rotation[8] * cross_covariance[8];
  }
  if (rotation)
- for (int j = 0; j < 9; ++j) rotation[j] = (nk_f32_t)r[j];
- *result = nk_f64_sqrt_haswell(nk_transformed_ssd_f32_skylake_(a, b, n, r, 1.0, centroid_a_x, centroid_a_y,
- centroid_a_z, centroid_b_x, centroid_b_y,
- centroid_b_z) /
- n_f64);
+ for (int j = 0; j < 9; ++j) rotation[j] = (nk_f32_t)optimal_rotation[j];
+
+ // Folded SSD via trace identity: SSD = ‖a-ā‖² + ‖b-b̄‖² − 2·trace(R · H_centered).
+ nk_f64_t centered_norm_squared_a = norm_squared_a -
+ n_f64 * (centroid_a_x * centroid_a_x + centroid_a_y * centroid_a_y +
+ centroid_a_z * centroid_a_z);
+ nk_f64_t centered_norm_squared_b = norm_squared_b -
+ n_f64 * (centroid_b_x * centroid_b_x + centroid_b_y * centroid_b_y +
+ centroid_b_z * centroid_b_z);
+ if (centered_norm_squared_a < 0.0) centered_norm_squared_a = 0.0;
+ if (centered_norm_squared_b < 0.0) centered_norm_squared_b = 0.0;
+
+ nk_f64_t sum_squared = centered_norm_squared_a + centered_norm_squared_b - 2.0 * trace_rotation_covariance;
+ if (sum_squared < 0.0) sum_squared = 0.0;
+ *result = nk_f64_sqrt_haswell(sum_squared / n_f64);
  }
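Why the second pass can be dropped: with ā and b̄ the centroids and H_centered the 3x3 matrix assembled above, the post-rotation sum of squared deviations folds into quantities the single streaming pass already produced. The identity (standard Kabsch algebra, restated here for reference):

    Σᵢ ‖R(aᵢ − ā) − (bᵢ − b̄)‖² = Σᵢ ‖aᵢ − ā‖² + Σᵢ ‖bᵢ − b̄‖² − 2·tr(R · H_centered),
    H_centered = Σᵢ (aᵢ − ā)(bᵢ − b̄)ᵀ

It follows from expanding the square and using that a rotation preserves norms; the clamps to zero above only absorb tiny negative values caused by floating-point cancellation. The nine-term `trace_rotation_covariance` expression is exactly tr(R · H_centered) for row-major 3x3 matrices.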

  NK_PUBLIC void nk_rmsd_f64_skylake(nk_f64_t const *a, nk_f64_t const *b, nk_size_t n, nk_f64_t *a_centroid,
@@ -1064,14 +665,15 @@ NK_PUBLIC void nk_kabsch_f64_skylake(nk_f64_t const *a, nk_f64_t const *b, nk_si
  __m512d sum_b_x_f64x8 = zeros_f64x8, sum_b_y_f64x8 = zeros_f64x8, sum_b_z_f64x8 = zeros_f64x8;

  // Accumulators for covariance matrix (sum of outer products)
- __m512d cov_xx_f64x8 = zeros_f64x8, cov_xy_f64x8 = zeros_f64x8, cov_xz_f64x8 = zeros_f64x8;
- __m512d cov_yx_f64x8 = zeros_f64x8, cov_yy_f64x8 = zeros_f64x8, cov_yz_f64x8 = zeros_f64x8;
- __m512d cov_zx_f64x8 = zeros_f64x8, cov_zy_f64x8 = zeros_f64x8, cov_zz_f64x8 = zeros_f64x8;
+ __m512d covariance_xx_f64x8 = zeros_f64x8, covariance_xy_f64x8 = zeros_f64x8, covariance_xz_f64x8 = zeros_f64x8;
+ __m512d covariance_yx_f64x8 = zeros_f64x8, covariance_yy_f64x8 = zeros_f64x8, covariance_yz_f64x8 = zeros_f64x8;
+ __m512d covariance_zx_f64x8 = zeros_f64x8, covariance_zy_f64x8 = zeros_f64x8, covariance_zz_f64x8 = zeros_f64x8;
+ __m512d norm_squared_a_f64x8 = zeros_f64x8, norm_squared_b_f64x8 = zeros_f64x8;

  nk_size_t i = 0;
  __m512d a_x_f64x8, a_y_f64x8, a_z_f64x8, b_x_f64x8, b_y_f64x8, b_z_f64x8;

- // Fused single-pass: accumulate sums and outer products together
+ // Fused single-pass: accumulate sums, outer products, and norms^2 together
  for (; i + 8 <= n; i += 8) {
  nk_deinterleave_f64x8_skylake_(a + i * 3, &a_x_f64x8, &a_y_f64x8, &a_z_f64x8);
  nk_deinterleave_f64x8_skylake_(b + i * 3, &b_x_f64x8, &b_y_f64x8, &b_z_f64x8);
@@ -1085,15 +687,21 @@ NK_PUBLIC void nk_kabsch_f64_skylake(nk_f64_t const *a, nk_f64_t const *b, nk_si
  sum_b_z_f64x8 = _mm512_add_pd(sum_b_z_f64x8, b_z_f64x8);

  // Accumulate outer products (raw, not centered)
- cov_xx_f64x8 = _mm512_fmadd_pd(a_x_f64x8, b_x_f64x8, cov_xx_f64x8),
- cov_xy_f64x8 = _mm512_fmadd_pd(a_x_f64x8, b_y_f64x8, cov_xy_f64x8),
- cov_xz_f64x8 = _mm512_fmadd_pd(a_x_f64x8, b_z_f64x8, cov_xz_f64x8);
- cov_yx_f64x8 = _mm512_fmadd_pd(a_y_f64x8, b_x_f64x8, cov_yx_f64x8),
- cov_yy_f64x8 = _mm512_fmadd_pd(a_y_f64x8, b_y_f64x8, cov_yy_f64x8),
- cov_yz_f64x8 = _mm512_fmadd_pd(a_y_f64x8, b_z_f64x8, cov_yz_f64x8);
- cov_zx_f64x8 = _mm512_fmadd_pd(a_z_f64x8, b_x_f64x8, cov_zx_f64x8),
- cov_zy_f64x8 = _mm512_fmadd_pd(a_z_f64x8, b_y_f64x8, cov_zy_f64x8),
- cov_zz_f64x8 = _mm512_fmadd_pd(a_z_f64x8, b_z_f64x8, cov_zz_f64x8);
+ covariance_xx_f64x8 = _mm512_fmadd_pd(a_x_f64x8, b_x_f64x8, covariance_xx_f64x8),
+ covariance_xy_f64x8 = _mm512_fmadd_pd(a_x_f64x8, b_y_f64x8, covariance_xy_f64x8),
+ covariance_xz_f64x8 = _mm512_fmadd_pd(a_x_f64x8, b_z_f64x8, covariance_xz_f64x8);
+ covariance_yx_f64x8 = _mm512_fmadd_pd(a_y_f64x8, b_x_f64x8, covariance_yx_f64x8),
+ covariance_yy_f64x8 = _mm512_fmadd_pd(a_y_f64x8, b_y_f64x8, covariance_yy_f64x8),
+ covariance_yz_f64x8 = _mm512_fmadd_pd(a_y_f64x8, b_z_f64x8, covariance_yz_f64x8);
+ covariance_zx_f64x8 = _mm512_fmadd_pd(a_z_f64x8, b_x_f64x8, covariance_zx_f64x8),
+ covariance_zy_f64x8 = _mm512_fmadd_pd(a_z_f64x8, b_y_f64x8, covariance_zy_f64x8),
+ covariance_zz_f64x8 = _mm512_fmadd_pd(a_z_f64x8, b_z_f64x8, covariance_zz_f64x8);
+ norm_squared_a_f64x8 = _mm512_fmadd_pd(a_x_f64x8, a_x_f64x8, norm_squared_a_f64x8);
+ norm_squared_a_f64x8 = _mm512_fmadd_pd(a_y_f64x8, a_y_f64x8, norm_squared_a_f64x8);
+ norm_squared_a_f64x8 = _mm512_fmadd_pd(a_z_f64x8, a_z_f64x8, norm_squared_a_f64x8);
+ norm_squared_b_f64x8 = _mm512_fmadd_pd(b_x_f64x8, b_x_f64x8, norm_squared_b_f64x8);
+ norm_squared_b_f64x8 = _mm512_fmadd_pd(b_y_f64x8, b_y_f64x8, norm_squared_b_f64x8);
+ norm_squared_b_f64x8 = _mm512_fmadd_pd(b_z_f64x8, b_z_f64x8, norm_squared_b_f64x8);
  }

  // Tail: masked gather for remaining points
@@ -1117,15 +725,21 @@ NK_PUBLIC void nk_kabsch_f64_skylake(nk_f64_t const *a, nk_f64_t const *b, nk_si
  sum_b_y_f64x8 = _mm512_add_pd(sum_b_y_f64x8, b_y_f64x8),
  sum_b_z_f64x8 = _mm512_add_pd(sum_b_z_f64x8, b_z_f64x8);

- cov_xx_f64x8 = _mm512_fmadd_pd(a_x_f64x8, b_x_f64x8, cov_xx_f64x8),
- cov_xy_f64x8 = _mm512_fmadd_pd(a_x_f64x8, b_y_f64x8, cov_xy_f64x8),
- cov_xz_f64x8 = _mm512_fmadd_pd(a_x_f64x8, b_z_f64x8, cov_xz_f64x8);
- cov_yx_f64x8 = _mm512_fmadd_pd(a_y_f64x8, b_x_f64x8, cov_yx_f64x8),
- cov_yy_f64x8 = _mm512_fmadd_pd(a_y_f64x8, b_y_f64x8, cov_yy_f64x8),
- cov_yz_f64x8 = _mm512_fmadd_pd(a_y_f64x8, b_z_f64x8, cov_yz_f64x8);
- cov_zx_f64x8 = _mm512_fmadd_pd(a_z_f64x8, b_x_f64x8, cov_zx_f64x8),
- cov_zy_f64x8 = _mm512_fmadd_pd(a_z_f64x8, b_y_f64x8, cov_zy_f64x8),
- cov_zz_f64x8 = _mm512_fmadd_pd(a_z_f64x8, b_z_f64x8, cov_zz_f64x8);
+ covariance_xx_f64x8 = _mm512_fmadd_pd(a_x_f64x8, b_x_f64x8, covariance_xx_f64x8),
+ covariance_xy_f64x8 = _mm512_fmadd_pd(a_x_f64x8, b_y_f64x8, covariance_xy_f64x8),
+ covariance_xz_f64x8 = _mm512_fmadd_pd(a_x_f64x8, b_z_f64x8, covariance_xz_f64x8);
+ covariance_yx_f64x8 = _mm512_fmadd_pd(a_y_f64x8, b_x_f64x8, covariance_yx_f64x8),
+ covariance_yy_f64x8 = _mm512_fmadd_pd(a_y_f64x8, b_y_f64x8, covariance_yy_f64x8),
+ covariance_yz_f64x8 = _mm512_fmadd_pd(a_y_f64x8, b_z_f64x8, covariance_yz_f64x8);
+ covariance_zx_f64x8 = _mm512_fmadd_pd(a_z_f64x8, b_x_f64x8, covariance_zx_f64x8),
+ covariance_zy_f64x8 = _mm512_fmadd_pd(a_z_f64x8, b_y_f64x8, covariance_zy_f64x8),
+ covariance_zz_f64x8 = _mm512_fmadd_pd(a_z_f64x8, b_z_f64x8, covariance_zz_f64x8);
+ norm_squared_a_f64x8 = _mm512_fmadd_pd(a_x_f64x8, a_x_f64x8, norm_squared_a_f64x8);
+ norm_squared_a_f64x8 = _mm512_fmadd_pd(a_y_f64x8, a_y_f64x8, norm_squared_a_f64x8);
+ norm_squared_a_f64x8 = _mm512_fmadd_pd(a_z_f64x8, a_z_f64x8, norm_squared_a_f64x8);
+ norm_squared_b_f64x8 = _mm512_fmadd_pd(b_x_f64x8, b_x_f64x8, norm_squared_b_f64x8);
+ norm_squared_b_f64x8 = _mm512_fmadd_pd(b_y_f64x8, b_y_f64x8, norm_squared_b_f64x8);
+ norm_squared_b_f64x8 = _mm512_fmadd_pd(b_z_f64x8, b_z_f64x8, norm_squared_b_f64x8);
  i = n;
  }

@@ -1137,15 +751,19 @@ NK_PUBLIC void nk_kabsch_f64_skylake(nk_f64_t const *a, nk_f64_t const *b, nk_si
  nk_f64_t sum_b_x = nk_reduce_stable_f64x8_skylake_(sum_b_x_f64x8), sum_b_x_compensation = 0.0;
  nk_f64_t sum_b_y = nk_reduce_stable_f64x8_skylake_(sum_b_y_f64x8), sum_b_y_compensation = 0.0;
  nk_f64_t sum_b_z = nk_reduce_stable_f64x8_skylake_(sum_b_z_f64x8), sum_b_z_compensation = 0.0;
- nk_f64_t covariance_x_x = nk_reduce_stable_f64x8_skylake_(cov_xx_f64x8), covariance_x_x_compensation = 0.0;
- nk_f64_t covariance_x_y = nk_reduce_stable_f64x8_skylake_(cov_xy_f64x8), covariance_x_y_compensation = 0.0;
- nk_f64_t covariance_x_z = nk_reduce_stable_f64x8_skylake_(cov_xz_f64x8), covariance_x_z_compensation = 0.0;
- nk_f64_t covariance_y_x = nk_reduce_stable_f64x8_skylake_(cov_yx_f64x8), covariance_y_x_compensation = 0.0;
- nk_f64_t covariance_y_y = nk_reduce_stable_f64x8_skylake_(cov_yy_f64x8), covariance_y_y_compensation = 0.0;
- nk_f64_t covariance_y_z = nk_reduce_stable_f64x8_skylake_(cov_yz_f64x8), covariance_y_z_compensation = 0.0;
- nk_f64_t covariance_z_x = nk_reduce_stable_f64x8_skylake_(cov_zx_f64x8), covariance_z_x_compensation = 0.0;
- nk_f64_t covariance_z_y = nk_reduce_stable_f64x8_skylake_(cov_zy_f64x8), covariance_z_y_compensation = 0.0;
- nk_f64_t covariance_z_z = nk_reduce_stable_f64x8_skylake_(cov_zz_f64x8), covariance_z_z_compensation = 0.0;
+ nk_f64_t covariance_x_x = nk_reduce_stable_f64x8_skylake_(covariance_xx_f64x8), covariance_x_x_compensation = 0.0;
+ nk_f64_t covariance_x_y = nk_reduce_stable_f64x8_skylake_(covariance_xy_f64x8), covariance_x_y_compensation = 0.0;
+ nk_f64_t covariance_x_z = nk_reduce_stable_f64x8_skylake_(covariance_xz_f64x8), covariance_x_z_compensation = 0.0;
+ nk_f64_t covariance_y_x = nk_reduce_stable_f64x8_skylake_(covariance_yx_f64x8), covariance_y_x_compensation = 0.0;
+ nk_f64_t covariance_y_y = nk_reduce_stable_f64x8_skylake_(covariance_yy_f64x8), covariance_y_y_compensation = 0.0;
+ nk_f64_t covariance_y_z = nk_reduce_stable_f64x8_skylake_(covariance_yz_f64x8), covariance_y_z_compensation = 0.0;
+ nk_f64_t covariance_z_x = nk_reduce_stable_f64x8_skylake_(covariance_zx_f64x8), covariance_z_x_compensation = 0.0;
+ nk_f64_t covariance_z_y = nk_reduce_stable_f64x8_skylake_(covariance_zy_f64x8), covariance_z_y_compensation = 0.0;
+ nk_f64_t covariance_z_z = nk_reduce_stable_f64x8_skylake_(covariance_zz_f64x8), covariance_z_z_compensation = 0.0;
+ nk_f64_t norm_squared_a_sum = nk_reduce_stable_f64x8_skylake_(norm_squared_a_f64x8),
+ norm_squared_a_compensation = 0.0;
+ nk_f64_t norm_squared_b_sum = nk_reduce_stable_f64x8_skylake_(norm_squared_b_f64x8),
+ norm_squared_b_compensation = 0.0;

  for (; i < n; ++i) {
  nk_f64_t ax = a[i * 3 + 0], ay = a[i * 3 + 1], az = a[i * 3 + 2];
@@ -1165,6 +783,12 @@ NK_PUBLIC void nk_kabsch_f64_skylake(nk_f64_t const *a, nk_f64_t const *b, nk_si
  nk_accumulate_product_f64_(&covariance_z_x, &covariance_z_x_compensation, az, bx);
  nk_accumulate_product_f64_(&covariance_z_y, &covariance_z_y_compensation, az, by);
  nk_accumulate_product_f64_(&covariance_z_z, &covariance_z_z_compensation, az, bz);
+ nk_accumulate_square_f64_(&norm_squared_a_sum, &norm_squared_a_compensation, ax);
+ nk_accumulate_square_f64_(&norm_squared_a_sum, &norm_squared_a_compensation, ay);
+ nk_accumulate_square_f64_(&norm_squared_a_sum, &norm_squared_a_compensation, az);
+ nk_accumulate_square_f64_(&norm_squared_b_sum, &norm_squared_b_compensation, bx);
+ nk_accumulate_square_f64_(&norm_squared_b_sum, &norm_squared_b_compensation, by);
+ nk_accumulate_square_f64_(&norm_squared_b_sum, &norm_squared_b_compensation, bz);
  }
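The scalar tail above routes every term through `nk_accumulate_product_f64_` / `nk_accumulate_square_f64_` and folds the `*_compensation` terms back in afterwards. Those helpers are defined elsewhere in the library; the sketch below shows one standard way to implement that contract (Neumaier-style compensated summation plus an FMA-recovered product error; names are illustrative, not the library's definitions), consistent with the `sum += compensation` finalization used here:

    #include <math.h>

    /* Kahan-Babuska (Neumaier) update: the low-order bits lost by sum + value
       are collected in *compensation and added back once at the end. */
    static inline void accumulate_f64(double *sum, double *compensation, double value) {
        double updated = *sum + value;
        if (fabs(*sum) >= fabs(value)) *compensation += (*sum - updated) + value;
        else *compensation += (value - updated) + *sum;
        *sum = updated;
    }

    /* Compensated x*y: fma(x, y, -product) recovers the exact rounding error
       of the product, and both halves feed the compensated running sum. */
    static inline void accumulate_product_f64(double *sum, double *compensation, double x, double y) {
        double product = x * y;
        double product_error = fma(x, y, -product);
        accumulate_f64(sum, compensation, product);
        accumulate_f64(sum, compensation, product_error);
    }

    static inline void accumulate_square_f64(double *sum, double *compensation, double x) {
        accumulate_product_f64(sum, compensation, x, x);
    }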

  sum_a_x += sum_a_x_compensation, sum_a_y += sum_a_y_compensation, sum_a_z += sum_a_z_compensation;
@@ -1175,6 +799,8 @@ NK_PUBLIC void nk_kabsch_f64_skylake(nk_f64_t const *a, nk_f64_t const *b, nk_si
  covariance_y_z += covariance_y_z_compensation;
  covariance_z_x += covariance_z_x_compensation, covariance_z_y += covariance_z_y_compensation,
  covariance_z_z += covariance_z_z_compensation;
+ norm_squared_a_sum += norm_squared_a_compensation;
+ norm_squared_b_sum += norm_squared_b_compensation;

  nk_f64_t centroid_a_x = sum_a_x * inv_n, centroid_a_y = sum_a_y * inv_n, centroid_a_z = sum_a_z * inv_n;
  nk_f64_t centroid_b_x = sum_b_x * inv_n, centroid_b_y = sum_b_y * inv_n, centroid_b_z = sum_b_z * inv_n;
@@ -1194,177 +820,70 @@ NK_PUBLIC void nk_kabsch_f64_skylake(nk_f64_t const *a, nk_f64_t const *b, nk_si
  cross_covariance[7] = covariance_z_y - sum_a_z * sum_b_y * inv_n;
  cross_covariance[8] = covariance_z_z - sum_a_z * sum_b_z * inv_n;

- // SVD using f64 for full precision
- nk_f64_t svd_u[9], svd_s[9], svd_v[9];
- nk_svd3x3_f64_(cross_covariance, svd_u, svd_s, svd_v);
-
- nk_f64_t r[9];
- nk_rotation_from_svd_f64_skylake_(svd_u, svd_v, r);
-
- // Handle reflection
- if (nk_det3x3_f64_(r) < 0) {
- svd_v[2] = -svd_v[2], svd_v[5] = -svd_v[5], svd_v[8] = -svd_v[8];
- nk_rotation_from_svd_f64_skylake_(svd_u, svd_v, r);
+ // Identity-dominant short-circuit: if H_centered is near-diagonal positive-definite,
+ // R = I and trace(R * H) = H[0] + H[4] + H[8]. Saves ~500 cycles on aligned inputs.
+ nk_f64_t covariance_diagonal_norm_squared = cross_covariance[0] * cross_covariance[0] +
+ cross_covariance[4] * cross_covariance[4] +
+ cross_covariance[8] * cross_covariance[8];
+ nk_f64_t covariance_offdiagonal_norm_squared =
+ cross_covariance[1] * cross_covariance[1] + cross_covariance[2] * cross_covariance[2] +
+ cross_covariance[3] * cross_covariance[3] + cross_covariance[5] * cross_covariance[5] +
+ cross_covariance[6] * cross_covariance[6] + cross_covariance[7] * cross_covariance[7];
+ nk_f64_t optimal_rotation[9];
+ nk_f64_t trace_rotation_covariance;
+ if (covariance_offdiagonal_norm_squared < 1e-20 * covariance_diagonal_norm_squared && cross_covariance[0] > 0.0 &&
+ cross_covariance[4] > 0.0 && cross_covariance[8] > 0.0) {
+ optimal_rotation[0] = 1, optimal_rotation[1] = 0, optimal_rotation[2] = 0, optimal_rotation[3] = 0,
+ optimal_rotation[4] = 1, optimal_rotation[5] = 0, optimal_rotation[6] = 0, optimal_rotation[7] = 0,
+ optimal_rotation[8] = 1;
+ trace_rotation_covariance = cross_covariance[0] + cross_covariance[4] + cross_covariance[8];
+ }
+ else {
+ nk_f64_t svd_left[9], svd_diagonal[9], svd_right[9];
+ nk_svd3x3_f64_(cross_covariance, svd_left, svd_diagonal, svd_right);
+ nk_rotation_from_svd_f64_serial_(svd_left, svd_right, optimal_rotation);
+ if (nk_det3x3_f64_(optimal_rotation) < 0) {
+ svd_right[2] = -svd_right[2], svd_right[5] = -svd_right[5], svd_right[8] = -svd_right[8];
+ nk_rotation_from_svd_f64_serial_(svd_left, svd_right, optimal_rotation);
+ }
+ trace_rotation_covariance =
+ optimal_rotation[0] * cross_covariance[0] + optimal_rotation[1] * cross_covariance[3] +
+ optimal_rotation[2] * cross_covariance[6] + optimal_rotation[3] * cross_covariance[1] +
+ optimal_rotation[4] * cross_covariance[4] + optimal_rotation[5] * cross_covariance[7] +
+ optimal_rotation[6] * cross_covariance[2] + optimal_rotation[7] * cross_covariance[5] +
+ optimal_rotation[8] * cross_covariance[8];
  }

  // Output rotation matrix and scale=1.0.
  if (rotation)
- for (int j = 0; j < 9; ++j) rotation[j] = (nk_f64_t)r[j];
+ for (int j = 0; j < 9; ++j) rotation[j] = (nk_f64_t)optimal_rotation[j];
  if (scale) *scale = 1.0;

- // Compute RMSD after optimal rotation
- nk_f64_t sum_squared = nk_transformed_ssd_f64_skylake_(a, b, n, r, 1.0, centroid_a_x, centroid_a_y, centroid_a_z,
- centroid_b_x, centroid_b_y, centroid_b_z);
+ // Folded SSD via trace identity - no second pass over the buffers:
+ // SSD = ‖a-ā‖² + ‖b-b̄‖² − 2·trace(R · H_centered).
+ nk_f64_t centered_norm_squared_a = norm_squared_a_sum -
+ (nk_f64_t)n * (centroid_a_x * centroid_a_x + centroid_a_y * centroid_a_y +
+ centroid_a_z * centroid_a_z);
+ nk_f64_t centered_norm_squared_b = norm_squared_b_sum -
+ (nk_f64_t)n * (centroid_b_x * centroid_b_x + centroid_b_y * centroid_b_y +
+ centroid_b_z * centroid_b_z);
+ if (centered_norm_squared_a < 0.0) centered_norm_squared_a = 0.0;
+ if (centered_norm_squared_b < 0.0) centered_norm_squared_b = 0.0;
+ nk_f64_t sum_squared = centered_norm_squared_a + centered_norm_squared_b - 2.0 * trace_rotation_covariance;
+ if (sum_squared < 0.0) sum_squared = 0.0;
  *result = nk_f64_sqrt_haswell(sum_squared * inv_n);
  }
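The same near-diagonal gate appears in all three rewritten kernels (`nk_kabsch_f32_skylake` above, this function, and the `nk_umeyama_*` kernels below). Restated as a standalone predicate (illustrative only; the kernels inline it): both conditions matter, since a near-diagonal H with a negative diagonal entry still describes a reflection and must take the SVD path.

    #include <stdbool.h>

    // h is the row-major 3x3 centered cross-covariance H.
    static bool rotation_short_circuits_to_identity(double const h[9]) {
        double diagonal_norm_squared = h[0] * h[0] + h[4] * h[4] + h[8] * h[8];
        double offdiagonal_norm_squared = h[1] * h[1] + h[2] * h[2] + h[3] * h[3] +
                                          h[5] * h[5] + h[6] * h[6] + h[7] * h[7];
        return offdiagonal_norm_squared < 1e-20 * diagonal_norm_squared &&
               h[0] > 0.0 && h[4] > 0.0 && h[8] > 0.0;
    }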

  NK_PUBLIC void nk_umeyama_f32_skylake(nk_f32_t const *a, nk_f32_t const *b, nk_size_t n, nk_f32_t *a_centroid,
  nk_f32_t *b_centroid, nk_f32_t *rotation, nk_f32_t *scale, nk_f64_t *result) {
- // Fused single-pass: centroids + covariance + variance of A, all in f64
- __m512d const zeros_f64x8 = _mm512_setzero_pd();
- __m512d sum_a_x_f64x8 = zeros_f64x8, sum_a_y_f64x8 = zeros_f64x8, sum_a_z_f64x8 = zeros_f64x8;
- __m512d sum_b_x_f64x8 = zeros_f64x8, sum_b_y_f64x8 = zeros_f64x8, sum_b_z_f64x8 = zeros_f64x8;
- __m512d cov_xx_f64x8 = zeros_f64x8, cov_xy_f64x8 = zeros_f64x8, cov_xz_f64x8 = zeros_f64x8;
- __m512d cov_yx_f64x8 = zeros_f64x8, cov_yy_f64x8 = zeros_f64x8, cov_yz_f64x8 = zeros_f64x8;
- __m512d cov_zx_f64x8 = zeros_f64x8, cov_zy_f64x8 = zeros_f64x8, cov_zz_f64x8 = zeros_f64x8;
- __m512d variance_a_f64x8 = zeros_f64x8;
- __m512 a_x_f32x16, a_y_f32x16, a_z_f32x16, b_x_f32x16, b_y_f32x16, b_z_f32x16;
- nk_size_t i = 0;
-
- for (; i + 16 <= n; i += 16) {
- nk_deinterleave_f32x16_skylake_(a + i * 3, &a_x_f32x16, &a_y_f32x16, &a_z_f32x16);
- nk_deinterleave_f32x16_skylake_(b + i * 3, &b_x_f32x16, &b_y_f32x16, &b_z_f32x16);
- __m512d a_x_low_f64x8 = _mm512_cvtps_pd(_mm512_castps512_ps256(a_x_f32x16));
- __m512d a_x_high_f64x8 = _mm512_cvtps_pd(_mm512_extractf32x8_ps(a_x_f32x16, 1));
- __m512d a_y_low_f64x8 = _mm512_cvtps_pd(_mm512_castps512_ps256(a_y_f32x16));
- __m512d a_y_high_f64x8 = _mm512_cvtps_pd(_mm512_extractf32x8_ps(a_y_f32x16, 1));
- __m512d a_z_low_f64x8 = _mm512_cvtps_pd(_mm512_castps512_ps256(a_z_f32x16));
- __m512d a_z_high_f64x8 = _mm512_cvtps_pd(_mm512_extractf32x8_ps(a_z_f32x16, 1));
- __m512d b_x_low_f64x8 = _mm512_cvtps_pd(_mm512_castps512_ps256(b_x_f32x16));
- __m512d b_x_high_f64x8 = _mm512_cvtps_pd(_mm512_extractf32x8_ps(b_x_f32x16, 1));
- __m512d b_y_low_f64x8 = _mm512_cvtps_pd(_mm512_castps512_ps256(b_y_f32x16));
- __m512d b_y_high_f64x8 = _mm512_cvtps_pd(_mm512_extractf32x8_ps(b_y_f32x16, 1));
- __m512d b_z_low_f64x8 = _mm512_cvtps_pd(_mm512_castps512_ps256(b_z_f32x16));
- __m512d b_z_high_f64x8 = _mm512_cvtps_pd(_mm512_extractf32x8_ps(b_z_f32x16, 1));
-
- sum_a_x_f64x8 = _mm512_add_pd(sum_a_x_f64x8, _mm512_add_pd(a_x_low_f64x8, a_x_high_f64x8));
- sum_a_y_f64x8 = _mm512_add_pd(sum_a_y_f64x8, _mm512_add_pd(a_y_low_f64x8, a_y_high_f64x8));
- sum_a_z_f64x8 = _mm512_add_pd(sum_a_z_f64x8, _mm512_add_pd(a_z_low_f64x8, a_z_high_f64x8));
- sum_b_x_f64x8 = _mm512_add_pd(sum_b_x_f64x8, _mm512_add_pd(b_x_low_f64x8, b_x_high_f64x8));
- sum_b_y_f64x8 = _mm512_add_pd(sum_b_y_f64x8, _mm512_add_pd(b_y_low_f64x8, b_y_high_f64x8));
- sum_b_z_f64x8 = _mm512_add_pd(sum_b_z_f64x8, _mm512_add_pd(b_z_low_f64x8, b_z_high_f64x8));
-
- cov_xx_f64x8 = _mm512_add_pd(cov_xx_f64x8, _mm512_add_pd(_mm512_mul_pd(a_x_low_f64x8, b_x_low_f64x8),
- _mm512_mul_pd(a_x_high_f64x8, b_x_high_f64x8)));
- cov_xy_f64x8 = _mm512_add_pd(cov_xy_f64x8, _mm512_add_pd(_mm512_mul_pd(a_x_low_f64x8, b_y_low_f64x8),
- _mm512_mul_pd(a_x_high_f64x8, b_y_high_f64x8)));
- cov_xz_f64x8 = _mm512_add_pd(cov_xz_f64x8, _mm512_add_pd(_mm512_mul_pd(a_x_low_f64x8, b_z_low_f64x8),
- _mm512_mul_pd(a_x_high_f64x8, b_z_high_f64x8)));
- cov_yx_f64x8 = _mm512_add_pd(cov_yx_f64x8, _mm512_add_pd(_mm512_mul_pd(a_y_low_f64x8, b_x_low_f64x8),
- _mm512_mul_pd(a_y_high_f64x8, b_x_high_f64x8)));
- cov_yy_f64x8 = _mm512_add_pd(cov_yy_f64x8, _mm512_add_pd(_mm512_mul_pd(a_y_low_f64x8, b_y_low_f64x8),
- _mm512_mul_pd(a_y_high_f64x8, b_y_high_f64x8)));
- cov_yz_f64x8 = _mm512_add_pd(cov_yz_f64x8, _mm512_add_pd(_mm512_mul_pd(a_y_low_f64x8, b_z_low_f64x8),
- _mm512_mul_pd(a_y_high_f64x8, b_z_high_f64x8)));
- cov_zx_f64x8 = _mm512_add_pd(cov_zx_f64x8, _mm512_add_pd(_mm512_mul_pd(a_z_low_f64x8, b_x_low_f64x8),
- _mm512_mul_pd(a_z_high_f64x8, b_x_high_f64x8)));
- cov_zy_f64x8 = _mm512_add_pd(cov_zy_f64x8, _mm512_add_pd(_mm512_mul_pd(a_z_low_f64x8, b_y_low_f64x8),
- _mm512_mul_pd(a_z_high_f64x8, b_y_high_f64x8)));
- cov_zz_f64x8 = _mm512_add_pd(cov_zz_f64x8, _mm512_add_pd(_mm512_mul_pd(a_z_low_f64x8, b_z_low_f64x8),
- _mm512_mul_pd(a_z_high_f64x8, b_z_high_f64x8)));
-
- variance_a_f64x8 = _mm512_add_pd(
- variance_a_f64x8,
- _mm512_add_pd(_mm512_mul_pd(a_x_low_f64x8, a_x_low_f64x8), _mm512_mul_pd(a_x_high_f64x8, a_x_high_f64x8)));
- variance_a_f64x8 = _mm512_add_pd(
- variance_a_f64x8,
- _mm512_add_pd(_mm512_mul_pd(a_y_low_f64x8, a_y_low_f64x8), _mm512_mul_pd(a_y_high_f64x8, a_y_high_f64x8)));
- variance_a_f64x8 = _mm512_add_pd(
- variance_a_f64x8,
- _mm512_add_pd(_mm512_mul_pd(a_z_low_f64x8, a_z_low_f64x8), _mm512_mul_pd(a_z_high_f64x8, a_z_high_f64x8)));
- }
-
- // Tail: use masked gather for remaining < 16 points
- if (i < n) {
- nk_size_t tail = n - i;
- __mmask16 mask = (__mmask16)_bzhi_u32(0xFFFF, tail);
- __m512i const gather_idx_i32x16 = _mm512_setr_epi32(0, 3, 6, 9, 12, 15, 18, 21, 24, 27, 30, 33, 36, 39, 42, 45);
- __m512 zeros_f32x16 = _mm512_setzero_ps();
- nk_f32_t const *a_tail = a + i * 3;
- nk_f32_t const *b_tail = b + i * 3;
-
- a_x_f32x16 = _mm512_mask_i32gather_ps(zeros_f32x16, mask, gather_idx_i32x16, a_tail + 0, 4);
- a_y_f32x16 = _mm512_mask_i32gather_ps(zeros_f32x16, mask, gather_idx_i32x16, a_tail + 1, 4);
- a_z_f32x16 = _mm512_mask_i32gather_ps(zeros_f32x16, mask, gather_idx_i32x16, a_tail + 2, 4);
- b_x_f32x16 = _mm512_mask_i32gather_ps(zeros_f32x16, mask, gather_idx_i32x16, b_tail + 0, 4);
- b_y_f32x16 = _mm512_mask_i32gather_ps(zeros_f32x16, mask, gather_idx_i32x16, b_tail + 1, 4);
- b_z_f32x16 = _mm512_mask_i32gather_ps(zeros_f32x16, mask, gather_idx_i32x16, b_tail + 2, 4);
-
- __m512d a_x_low_f64x8 = _mm512_cvtps_pd(_mm512_castps512_ps256(a_x_f32x16));
- __m512d a_x_high_f64x8 = _mm512_cvtps_pd(_mm512_extractf32x8_ps(a_x_f32x16, 1));
- __m512d a_y_low_f64x8 = _mm512_cvtps_pd(_mm512_castps512_ps256(a_y_f32x16));
- __m512d a_y_high_f64x8 = _mm512_cvtps_pd(_mm512_extractf32x8_ps(a_y_f32x16, 1));
- __m512d a_z_low_f64x8 = _mm512_cvtps_pd(_mm512_castps512_ps256(a_z_f32x16));
- __m512d a_z_high_f64x8 = _mm512_cvtps_pd(_mm512_extractf32x8_ps(a_z_f32x16, 1));
- __m512d b_x_low_f64x8 = _mm512_cvtps_pd(_mm512_castps512_ps256(b_x_f32x16));
- __m512d b_x_high_f64x8 = _mm512_cvtps_pd(_mm512_extractf32x8_ps(b_x_f32x16, 1));
- __m512d b_y_low_f64x8 = _mm512_cvtps_pd(_mm512_castps512_ps256(b_y_f32x16));
- __m512d b_y_high_f64x8 = _mm512_cvtps_pd(_mm512_extractf32x8_ps(b_y_f32x16, 1));
- __m512d b_z_low_f64x8 = _mm512_cvtps_pd(_mm512_castps512_ps256(b_z_f32x16));
- __m512d b_z_high_f64x8 = _mm512_cvtps_pd(_mm512_extractf32x8_ps(b_z_f32x16, 1));
-
- sum_a_x_f64x8 = _mm512_add_pd(sum_a_x_f64x8, _mm512_add_pd(a_x_low_f64x8, a_x_high_f64x8));
- sum_a_y_f64x8 = _mm512_add_pd(sum_a_y_f64x8, _mm512_add_pd(a_y_low_f64x8, a_y_high_f64x8));
- sum_a_z_f64x8 = _mm512_add_pd(sum_a_z_f64x8, _mm512_add_pd(a_z_low_f64x8, a_z_high_f64x8));
- sum_b_x_f64x8 = _mm512_add_pd(sum_b_x_f64x8, _mm512_add_pd(b_x_low_f64x8, b_x_high_f64x8));
- sum_b_y_f64x8 = _mm512_add_pd(sum_b_y_f64x8, _mm512_add_pd(b_y_low_f64x8, b_y_high_f64x8));
- sum_b_z_f64x8 = _mm512_add_pd(sum_b_z_f64x8, _mm512_add_pd(b_z_low_f64x8, b_z_high_f64x8));
-
- cov_xx_f64x8 = _mm512_add_pd(cov_xx_f64x8, _mm512_add_pd(_mm512_mul_pd(a_x_low_f64x8, b_x_low_f64x8),
- _mm512_mul_pd(a_x_high_f64x8, b_x_high_f64x8)));
- cov_xy_f64x8 = _mm512_add_pd(cov_xy_f64x8, _mm512_add_pd(_mm512_mul_pd(a_x_low_f64x8, b_y_low_f64x8),
- _mm512_mul_pd(a_x_high_f64x8, b_y_high_f64x8)));
- cov_xz_f64x8 = _mm512_add_pd(cov_xz_f64x8, _mm512_add_pd(_mm512_mul_pd(a_x_low_f64x8, b_z_low_f64x8),
- _mm512_mul_pd(a_x_high_f64x8, b_z_high_f64x8)));
- cov_yx_f64x8 = _mm512_add_pd(cov_yx_f64x8, _mm512_add_pd(_mm512_mul_pd(a_y_low_f64x8, b_x_low_f64x8),
- _mm512_mul_pd(a_y_high_f64x8, b_x_high_f64x8)));
- cov_yy_f64x8 = _mm512_add_pd(cov_yy_f64x8, _mm512_add_pd(_mm512_mul_pd(a_y_low_f64x8, b_y_low_f64x8),
- _mm512_mul_pd(a_y_high_f64x8, b_y_high_f64x8)));
- cov_yz_f64x8 = _mm512_add_pd(cov_yz_f64x8, _mm512_add_pd(_mm512_mul_pd(a_y_low_f64x8, b_z_low_f64x8),
- _mm512_mul_pd(a_y_high_f64x8, b_z_high_f64x8)));
- cov_zx_f64x8 = _mm512_add_pd(cov_zx_f64x8, _mm512_add_pd(_mm512_mul_pd(a_z_low_f64x8, b_x_low_f64x8),
- _mm512_mul_pd(a_z_high_f64x8, b_x_high_f64x8)));
- cov_zy_f64x8 = _mm512_add_pd(cov_zy_f64x8, _mm512_add_pd(_mm512_mul_pd(a_z_low_f64x8, b_y_low_f64x8),
- _mm512_mul_pd(a_z_high_f64x8, b_y_high_f64x8)));
- cov_zz_f64x8 = _mm512_add_pd(cov_zz_f64x8, _mm512_add_pd(_mm512_mul_pd(a_z_low_f64x8, b_z_low_f64x8),
- _mm512_mul_pd(a_z_high_f64x8, b_z_high_f64x8)));
-
- variance_a_f64x8 = _mm512_add_pd(
- variance_a_f64x8,
- _mm512_add_pd(_mm512_mul_pd(a_x_low_f64x8, a_x_low_f64x8), _mm512_mul_pd(a_x_high_f64x8, a_x_high_f64x8)));
- variance_a_f64x8 = _mm512_add_pd(
- variance_a_f64x8,
- _mm512_add_pd(_mm512_mul_pd(a_y_low_f64x8, a_y_low_f64x8), _mm512_mul_pd(a_y_high_f64x8, a_y_high_f64x8)));
- variance_a_f64x8 = _mm512_add_pd(
- variance_a_f64x8,
- _mm512_add_pd(_mm512_mul_pd(a_z_low_f64x8, a_z_low_f64x8), _mm512_mul_pd(a_z_high_f64x8, a_z_high_f64x8)));
- }
+ // Single pass over (a, b) via streaming-stats helper — no deinterleave, no second SSD pass.
+ nk_f64_t sum_a[3], sum_b[3], raw_covariance[9], norm_squared_a, norm_squared_b;
+ nk_mesh_streaming_stats_f32_skylake_(a, b, n, sum_a, sum_b, raw_covariance, &norm_squared_a, &norm_squared_b);

- nk_f64_t sum_a_x = _mm512_reduce_add_pd(sum_a_x_f64x8), sum_a_y = _mm512_reduce_add_pd(sum_a_y_f64x8),
- sum_a_z = _mm512_reduce_add_pd(sum_a_z_f64x8);
- nk_f64_t sum_b_x = _mm512_reduce_add_pd(sum_b_x_f64x8), sum_b_y = _mm512_reduce_add_pd(sum_b_y_f64x8),
- sum_b_z = _mm512_reduce_add_pd(sum_b_z_f64x8);
- nk_f64_t covariance_x_x = _mm512_reduce_add_pd(cov_xx_f64x8), covariance_x_y = _mm512_reduce_add_pd(cov_xy_f64x8),
- covariance_x_z = _mm512_reduce_add_pd(cov_xz_f64x8);
- nk_f64_t covariance_y_x = _mm512_reduce_add_pd(cov_yx_f64x8), covariance_y_y = _mm512_reduce_add_pd(cov_yy_f64x8),
- covariance_y_z = _mm512_reduce_add_pd(cov_yz_f64x8);
- nk_f64_t covariance_z_x = _mm512_reduce_add_pd(cov_zx_f64x8), covariance_z_y = _mm512_reduce_add_pd(cov_zy_f64x8),
- covariance_z_z = _mm512_reduce_add_pd(cov_zz_f64x8);
- nk_f64_t variance_a_sum = _mm512_reduce_add_pd(variance_a_f64x8);
-
- nk_f64_t inv_n = 1.0 / (nk_f64_t)n;
- nk_f64_t centroid_a_x = sum_a_x * inv_n, centroid_a_y = sum_a_y * inv_n, centroid_a_z = sum_a_z * inv_n;
- nk_f64_t centroid_b_x = sum_b_x * inv_n, centroid_b_y = sum_b_y * inv_n, centroid_b_z = sum_b_z * inv_n;
+ nk_f64_t n_f64 = (nk_f64_t)n;
+ nk_f64_t inv_n = 1.0 / n_f64;
+ nk_f64_t centroid_a_x = sum_a[0] * inv_n, centroid_a_y = sum_a[1] * inv_n, centroid_a_z = sum_a[2] * inv_n;
+ nk_f64_t centroid_b_x = sum_b[0] * inv_n, centroid_b_y = sum_b[1] * inv_n, centroid_b_z = sum_b[2] * inv_n;
  if (a_centroid)
  a_centroid[0] = (nk_f32_t)centroid_a_x, a_centroid[1] = (nk_f32_t)centroid_a_y,
  a_centroid[2] = (nk_f32_t)centroid_a_z;
@@ -1372,49 +891,81 @@ NK_PUBLIC void nk_umeyama_f32_skylake(nk_f32_t const *a, nk_f32_t const *b, nk_s
  b_centroid[0] = (nk_f32_t)centroid_b_x, b_centroid[1] = (nk_f32_t)centroid_b_y,
  b_centroid[2] = (nk_f32_t)centroid_b_z;

- // Compute centered covariance and variance
- nk_f64_t variance_a = variance_a_sum * inv_n -
- (centroid_a_x * centroid_a_x + centroid_a_y * centroid_a_y + centroid_a_z * centroid_a_z);
+ // Centered norms and centered covariance via parallel-axis identity.
+ nk_f64_t centered_norm_squared_a = norm_squared_a -
+ n_f64 * (centroid_a_x * centroid_a_x + centroid_a_y * centroid_a_y +
+ centroid_a_z * centroid_a_z);
+ nk_f64_t centered_norm_squared_b = norm_squared_b -
+ n_f64 * (centroid_b_x * centroid_b_x + centroid_b_y * centroid_b_y +
+ centroid_b_z * centroid_b_z);
+ if (centered_norm_squared_a < 0.0) centered_norm_squared_a = 0.0;
+ if (centered_norm_squared_b < 0.0) centered_norm_squared_b = 0.0;
+ nk_f64_t variance_a = centered_norm_squared_a * inv_n;

- // Compute centered covariance matrix: Hᵢⱼ = Σ(aᵢ×bⱼ) - Σaᵢ × Σbⱼ / n
- nk_f64_t n_f64 = (nk_f64_t)n;
  nk_f64_t cross_covariance[9];
- cross_covariance[0] = covariance_x_x - n_f64 * centroid_a_x * centroid_b_x;
- cross_covariance[1] = covariance_x_y - n_f64 * centroid_a_x * centroid_b_y;
- cross_covariance[2] = covariance_x_z - n_f64 * centroid_a_x * centroid_b_z;
- cross_covariance[3] = covariance_y_x - n_f64 * centroid_a_y * centroid_b_x;
- cross_covariance[4] = covariance_y_y - n_f64 * centroid_a_y * centroid_b_y;
- cross_covariance[5] = covariance_y_z - n_f64 * centroid_a_y * centroid_b_z;
- cross_covariance[6] = covariance_z_x - n_f64 * centroid_a_z * centroid_b_x;
- cross_covariance[7] = covariance_z_y - n_f64 * centroid_a_z * centroid_b_y;
- cross_covariance[8] = covariance_z_z - n_f64 * centroid_a_z * centroid_b_z;
-
- // SVD using f64 for full precision
- nk_f64_t svd_u[9], svd_s[9], svd_v[9];
- nk_svd3x3_f64_(cross_covariance, svd_u, svd_s, svd_v);
-
- nk_f64_t r[9];
- nk_rotation_from_svd_f64_skylake_(svd_u, svd_v, r);
-
- // Scale factor: c = trace(D × S) / (n × variance(a))
- nk_f64_t det = nk_det3x3_f64_(r);
- nk_f64_t d3 = det < 0 ? -1.0 : 1.0;
- nk_f64_t trace_ds = nk_sum_three_products_f64_(svd_s[0], 1.0, svd_s[4], 1.0, svd_s[8], d3);
- nk_f64_t applied_scale = trace_ds / ((nk_f64_t)n * variance_a);
- if (scale) *scale = (nk_f32_t)applied_scale;
-
- // Handle reflection
- if (det < 0) {
- svd_v[2] = -svd_v[2], svd_v[5] = -svd_v[5], svd_v[8] = -svd_v[8];
- nk_rotation_from_svd_f64_skylake_(svd_u, svd_v, r);
+ cross_covariance[0] = raw_covariance[0] - n_f64 * centroid_a_x * centroid_b_x;
+ cross_covariance[1] = raw_covariance[1] - n_f64 * centroid_a_x * centroid_b_y;
+ cross_covariance[2] = raw_covariance[2] - n_f64 * centroid_a_x * centroid_b_z;
+ cross_covariance[3] = raw_covariance[3] - n_f64 * centroid_a_y * centroid_b_x;
+ cross_covariance[4] = raw_covariance[4] - n_f64 * centroid_a_y * centroid_b_y;
+ cross_covariance[5] = raw_covariance[5] - n_f64 * centroid_a_y * centroid_b_z;
+ cross_covariance[6] = raw_covariance[6] - n_f64 * centroid_a_z * centroid_b_x;
+ cross_covariance[7] = raw_covariance[7] - n_f64 * centroid_a_z * centroid_b_y;
+ cross_covariance[8] = raw_covariance[8] - n_f64 * centroid_a_z * centroid_b_z;
+
+ // Identity-dominant short-circuit: when H_centered is near-diagonal positive-definite, R = I
+ // and trace(R * H) collapses to H[0]+H[4]+H[8]. Also d3 = +1, so trace_ds = sum of diagonal,
+ // and applied_scale = trace_ds / (n * variance_a). Skips SVD + two rotation_from_svd calls.
+ nk_f64_t covariance_diagonal_norm_squared = cross_covariance[0] * cross_covariance[0] +
+ cross_covariance[4] * cross_covariance[4] +
+ cross_covariance[8] * cross_covariance[8];
+ nk_f64_t covariance_offdiagonal_norm_squared =
+ cross_covariance[1] * cross_covariance[1] + cross_covariance[2] * cross_covariance[2] +
+ cross_covariance[3] * cross_covariance[3] + cross_covariance[5] * cross_covariance[5] +
+ cross_covariance[6] * cross_covariance[6] + cross_covariance[7] * cross_covariance[7];
+ nk_f64_t optimal_rotation[9];
+ nk_f64_t applied_scale;
+ nk_f64_t trace_rotation_covariance;
+ if (covariance_offdiagonal_norm_squared < 1e-20 * covariance_diagonal_norm_squared && cross_covariance[0] > 0.0 &&
+ cross_covariance[4] > 0.0 && cross_covariance[8] > 0.0) {
+ optimal_rotation[0] = 1, optimal_rotation[1] = 0, optimal_rotation[2] = 0, optimal_rotation[3] = 0,
+ optimal_rotation[4] = 1, optimal_rotation[5] = 0, optimal_rotation[6] = 0, optimal_rotation[7] = 0,
+ optimal_rotation[8] = 1;
+ trace_rotation_covariance = cross_covariance[0] + cross_covariance[4] + cross_covariance[8];
+ applied_scale = trace_rotation_covariance / (n_f64 * variance_a);
  }
-
+ else {
+ nk_f64_t svd_left[9], svd_diagonal[9], svd_right[9];
+ nk_svd3x3_f64_(cross_covariance, svd_left, svd_diagonal, svd_right);
+ nk_rotation_from_svd_f64_serial_(svd_left, svd_right, optimal_rotation);
+
+ // Scale factor: c = trace(D · S) / (n * variance_a), with reflection sign via d3.
+ nk_f64_t det = nk_det3x3_f64_(optimal_rotation);
+ nk_f64_t d3 = det < 0 ? -1.0 : 1.0;
+ nk_f64_t trace_ds = nk_sum_three_products_f64_(svd_diagonal[0], 1.0, svd_diagonal[4], 1.0, svd_diagonal[8], d3);
+ applied_scale = trace_ds / (n_f64 * variance_a);
+
+ if (det < 0) {
+ svd_right[2] = -svd_right[2], svd_right[5] = -svd_right[5], svd_right[8] = -svd_right[8];
+ nk_rotation_from_svd_f64_serial_(svd_left, svd_right, optimal_rotation);
+ }
+ trace_rotation_covariance =
+ optimal_rotation[0] * cross_covariance[0] + optimal_rotation[1] * cross_covariance[3] +
+ optimal_rotation[2] * cross_covariance[6] + optimal_rotation[3] * cross_covariance[1] +
+ optimal_rotation[4] * cross_covariance[4] + optimal_rotation[5] * cross_covariance[7] +
+ optimal_rotation[6] * cross_covariance[2] + optimal_rotation[7] * cross_covariance[5] +
+ optimal_rotation[8] * cross_covariance[8];
+ }
+ if (scale) *scale = (nk_f32_t)applied_scale;
  if (rotation)
- for (int j = 0; j < 9; ++j) rotation[j] = (nk_f32_t)r[j];
- *result = nk_f64_sqrt_haswell(nk_transformed_ssd_f32_skylake_(a, b, n, r, applied_scale, centroid_a_x, centroid_a_y,
- centroid_a_z, centroid_b_x, centroid_b_y,
- centroid_b_z) /
- n_f64);
+ for (int j = 0; j < 9; ++j) rotation[j] = (nk_f32_t)optimal_rotation[j];
+
+ // Folded SSD with scale: sum(‖s·R·(a-ā) − (b-b̄)‖²)
+ // = s²·‖a-ā‖² + ‖b-b̄‖² − 2s·trace(R · H_centered).
+ nk_f64_t sum_squared = applied_scale * applied_scale * centered_norm_squared_a + centered_norm_squared_b -
+ 2.0 * applied_scale * trace_rotation_covariance;
+ if (sum_squared < 0.0) sum_squared = 0.0;
+ *result = nk_f64_sqrt_haswell(sum_squared / n_f64);
  }
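For reference, the scale handling above follows Umeyama's classical estimator, matching the in-code comment c = trace(D · S) / (n * variance_a): S holds the singular values of the unnormalized H_centered, and D = diag(1, 1, d3) with d3 = sign(det) absorbing a possible reflection. The scaled residual then folds the same way as the unscaled one:

    Σᵢ ‖s·R(aᵢ − ā) − (bᵢ − b̄)‖² = s²·‖a − ā‖² + ‖b − b̄‖² − 2s·tr(R · H_centered)

In the identity branch the diagonal entries of a near-diagonal positive-definite H_centered approximate its singular values, which is why `applied_scale` collapses to (H[0] + H[4] + H[8]) / (n · variance_a) without running the SVD.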

  NK_PUBLIC void nk_umeyama_f64_skylake(nk_f64_t const *a, nk_f64_t const *b, nk_size_t n, nk_f64_t *a_centroid,
@@ -1425,10 +976,10 @@ NK_PUBLIC void nk_umeyama_f64_skylake(nk_f64_t const *a, nk_f64_t const *b, nk_s

  __m512d sum_a_x_f64x8 = zeros_f64x8, sum_a_y_f64x8 = zeros_f64x8, sum_a_z_f64x8 = zeros_f64x8;
  __m512d sum_b_x_f64x8 = zeros_f64x8, sum_b_y_f64x8 = zeros_f64x8, sum_b_z_f64x8 = zeros_f64x8;
- __m512d cov_xx_f64x8 = zeros_f64x8, cov_xy_f64x8 = zeros_f64x8, cov_xz_f64x8 = zeros_f64x8;
- __m512d cov_yx_f64x8 = zeros_f64x8, cov_yy_f64x8 = zeros_f64x8, cov_yz_f64x8 = zeros_f64x8;
- __m512d cov_zx_f64x8 = zeros_f64x8, cov_zy_f64x8 = zeros_f64x8, cov_zz_f64x8 = zeros_f64x8;
- __m512d variance_a_f64x8 = zeros_f64x8;
+ __m512d covariance_xx_f64x8 = zeros_f64x8, covariance_xy_f64x8 = zeros_f64x8, covariance_xz_f64x8 = zeros_f64x8;
+ __m512d covariance_yx_f64x8 = zeros_f64x8, covariance_yy_f64x8 = zeros_f64x8, covariance_yz_f64x8 = zeros_f64x8;
+ __m512d covariance_zx_f64x8 = zeros_f64x8, covariance_zy_f64x8 = zeros_f64x8, covariance_zz_f64x8 = zeros_f64x8;
+ __m512d norm_squared_a_f64x8 = zeros_f64x8, norm_squared_b_f64x8 = zeros_f64x8;

  nk_size_t i = 0;
  __m512d a_x_f64x8, a_y_f64x8, a_z_f64x8, b_x_f64x8, b_y_f64x8, b_z_f64x8;
@@ -1444,18 +995,21 @@ NK_PUBLIC void nk_umeyama_f64_skylake(nk_f64_t const *a, nk_f64_t const *b, nk_s
  sum_b_y_f64x8 = _mm512_add_pd(sum_b_y_f64x8, b_y_f64x8);
  sum_b_z_f64x8 = _mm512_add_pd(sum_b_z_f64x8, b_z_f64x8);

- cov_xx_f64x8 = _mm512_fmadd_pd(a_x_f64x8, b_x_f64x8, cov_xx_f64x8),
- cov_xy_f64x8 = _mm512_fmadd_pd(a_x_f64x8, b_y_f64x8, cov_xy_f64x8);
- cov_xz_f64x8 = _mm512_fmadd_pd(a_x_f64x8, b_z_f64x8, cov_xz_f64x8);
- cov_yx_f64x8 = _mm512_fmadd_pd(a_y_f64x8, b_x_f64x8, cov_yx_f64x8),
- cov_yy_f64x8 = _mm512_fmadd_pd(a_y_f64x8, b_y_f64x8, cov_yy_f64x8);
- cov_yz_f64x8 = _mm512_fmadd_pd(a_y_f64x8, b_z_f64x8, cov_yz_f64x8);
- cov_zx_f64x8 = _mm512_fmadd_pd(a_z_f64x8, b_x_f64x8, cov_zx_f64x8),
- cov_zy_f64x8 = _mm512_fmadd_pd(a_z_f64x8, b_y_f64x8, cov_zy_f64x8);
- cov_zz_f64x8 = _mm512_fmadd_pd(a_z_f64x8, b_z_f64x8, cov_zz_f64x8);
- variance_a_f64x8 = _mm512_fmadd_pd(a_x_f64x8, a_x_f64x8, variance_a_f64x8);
- variance_a_f64x8 = _mm512_fmadd_pd(a_y_f64x8, a_y_f64x8, variance_a_f64x8);
- variance_a_f64x8 = _mm512_fmadd_pd(a_z_f64x8, a_z_f64x8, variance_a_f64x8);
+ covariance_xx_f64x8 = _mm512_fmadd_pd(a_x_f64x8, b_x_f64x8, covariance_xx_f64x8),
+ covariance_xy_f64x8 = _mm512_fmadd_pd(a_x_f64x8, b_y_f64x8, covariance_xy_f64x8);
+ covariance_xz_f64x8 = _mm512_fmadd_pd(a_x_f64x8, b_z_f64x8, covariance_xz_f64x8);
+ covariance_yx_f64x8 = _mm512_fmadd_pd(a_y_f64x8, b_x_f64x8, covariance_yx_f64x8),
+ covariance_yy_f64x8 = _mm512_fmadd_pd(a_y_f64x8, b_y_f64x8, covariance_yy_f64x8);
+ covariance_yz_f64x8 = _mm512_fmadd_pd(a_y_f64x8, b_z_f64x8, covariance_yz_f64x8);
+ covariance_zx_f64x8 = _mm512_fmadd_pd(a_z_f64x8, b_x_f64x8, covariance_zx_f64x8),
+ covariance_zy_f64x8 = _mm512_fmadd_pd(a_z_f64x8, b_y_f64x8, covariance_zy_f64x8);
+ covariance_zz_f64x8 = _mm512_fmadd_pd(a_z_f64x8, b_z_f64x8, covariance_zz_f64x8);
+ norm_squared_a_f64x8 = _mm512_fmadd_pd(a_x_f64x8, a_x_f64x8, norm_squared_a_f64x8);
+ norm_squared_a_f64x8 = _mm512_fmadd_pd(a_y_f64x8, a_y_f64x8, norm_squared_a_f64x8);
+ norm_squared_a_f64x8 = _mm512_fmadd_pd(a_z_f64x8, a_z_f64x8, norm_squared_a_f64x8);
+ norm_squared_b_f64x8 = _mm512_fmadd_pd(b_x_f64x8, b_x_f64x8, norm_squared_b_f64x8);
+ norm_squared_b_f64x8 = _mm512_fmadd_pd(b_y_f64x8, b_y_f64x8, norm_squared_b_f64x8);
+ norm_squared_b_f64x8 = _mm512_fmadd_pd(b_z_f64x8, b_z_f64x8, norm_squared_b_f64x8);
  }

  if (i < n) {
@@ -1478,18 +1032,21 @@ NK_PUBLIC void nk_umeyama_f64_skylake(nk_f64_t const *a, nk_f64_t const *b, nk_s
  sum_b_y_f64x8 = _mm512_add_pd(sum_b_y_f64x8, b_y_f64x8);
  sum_b_z_f64x8 = _mm512_add_pd(sum_b_z_f64x8, b_z_f64x8);

- cov_xx_f64x8 = _mm512_fmadd_pd(a_x_f64x8, b_x_f64x8, cov_xx_f64x8),
- cov_xy_f64x8 = _mm512_fmadd_pd(a_x_f64x8, b_y_f64x8, cov_xy_f64x8);
- cov_xz_f64x8 = _mm512_fmadd_pd(a_x_f64x8, b_z_f64x8, cov_xz_f64x8);
- cov_yx_f64x8 = _mm512_fmadd_pd(a_y_f64x8, b_x_f64x8, cov_yx_f64x8),
- cov_yy_f64x8 = _mm512_fmadd_pd(a_y_f64x8, b_y_f64x8, cov_yy_f64x8);
- cov_yz_f64x8 = _mm512_fmadd_pd(a_y_f64x8, b_z_f64x8, cov_yz_f64x8);
- cov_zx_f64x8 = _mm512_fmadd_pd(a_z_f64x8, b_x_f64x8, cov_zx_f64x8),
- cov_zy_f64x8 = _mm512_fmadd_pd(a_z_f64x8, b_y_f64x8, cov_zy_f64x8);
- cov_zz_f64x8 = _mm512_fmadd_pd(a_z_f64x8, b_z_f64x8, cov_zz_f64x8);
- variance_a_f64x8 = _mm512_fmadd_pd(a_x_f64x8, a_x_f64x8, variance_a_f64x8);
- variance_a_f64x8 = _mm512_fmadd_pd(a_y_f64x8, a_y_f64x8, variance_a_f64x8);
- variance_a_f64x8 = _mm512_fmadd_pd(a_z_f64x8, a_z_f64x8, variance_a_f64x8);
+ covariance_xx_f64x8 = _mm512_fmadd_pd(a_x_f64x8, b_x_f64x8, covariance_xx_f64x8),
+ covariance_xy_f64x8 = _mm512_fmadd_pd(a_x_f64x8, b_y_f64x8, covariance_xy_f64x8);
+ covariance_xz_f64x8 = _mm512_fmadd_pd(a_x_f64x8, b_z_f64x8, covariance_xz_f64x8);
+ covariance_yx_f64x8 = _mm512_fmadd_pd(a_y_f64x8, b_x_f64x8, covariance_yx_f64x8),
+ covariance_yy_f64x8 = _mm512_fmadd_pd(a_y_f64x8, b_y_f64x8, covariance_yy_f64x8);
+ covariance_yz_f64x8 = _mm512_fmadd_pd(a_y_f64x8, b_z_f64x8, covariance_yz_f64x8);
+ covariance_zx_f64x8 = _mm512_fmadd_pd(a_z_f64x8, b_x_f64x8, covariance_zx_f64x8),
+ covariance_zy_f64x8 = _mm512_fmadd_pd(a_z_f64x8, b_y_f64x8, covariance_zy_f64x8);
+ covariance_zz_f64x8 = _mm512_fmadd_pd(a_z_f64x8, b_z_f64x8, covariance_zz_f64x8);
+ norm_squared_a_f64x8 = _mm512_fmadd_pd(a_x_f64x8, a_x_f64x8, norm_squared_a_f64x8);
+ norm_squared_a_f64x8 = _mm512_fmadd_pd(a_y_f64x8, a_y_f64x8, norm_squared_a_f64x8);
+ norm_squared_a_f64x8 = _mm512_fmadd_pd(a_z_f64x8, a_z_f64x8, norm_squared_a_f64x8);
+ norm_squared_b_f64x8 = _mm512_fmadd_pd(b_x_f64x8, b_x_f64x8, norm_squared_b_f64x8);
+ norm_squared_b_f64x8 = _mm512_fmadd_pd(b_y_f64x8, b_y_f64x8, norm_squared_b_f64x8);
+ norm_squared_b_f64x8 = _mm512_fmadd_pd(b_z_f64x8, b_z_f64x8, norm_squared_b_f64x8);
  i = n;
  }

@@ -1501,16 +1058,19 @@ NK_PUBLIC void nk_umeyama_f64_skylake(nk_f64_t const *a, nk_f64_t const *b, nk_s
  nk_f64_t sum_b_x = nk_reduce_stable_f64x8_skylake_(sum_b_x_f64x8), sum_b_x_compensation = 0.0;
  nk_f64_t sum_b_y = nk_reduce_stable_f64x8_skylake_(sum_b_y_f64x8), sum_b_y_compensation = 0.0;
  nk_f64_t sum_b_z = nk_reduce_stable_f64x8_skylake_(sum_b_z_f64x8), sum_b_z_compensation = 0.0;
- nk_f64_t covariance_x_x = nk_reduce_stable_f64x8_skylake_(cov_xx_f64x8), covariance_x_x_compensation = 0.0;
- nk_f64_t covariance_x_y = nk_reduce_stable_f64x8_skylake_(cov_xy_f64x8), covariance_x_y_compensation = 0.0;
- nk_f64_t covariance_x_z = nk_reduce_stable_f64x8_skylake_(cov_xz_f64x8), covariance_x_z_compensation = 0.0;
- nk_f64_t covariance_y_x = nk_reduce_stable_f64x8_skylake_(cov_yx_f64x8), covariance_y_x_compensation = 0.0;
- nk_f64_t covariance_y_y = nk_reduce_stable_f64x8_skylake_(cov_yy_f64x8), covariance_y_y_compensation = 0.0;
- nk_f64_t covariance_y_z = nk_reduce_stable_f64x8_skylake_(cov_yz_f64x8), covariance_y_z_compensation = 0.0;
- nk_f64_t covariance_z_x = nk_reduce_stable_f64x8_skylake_(cov_zx_f64x8), covariance_z_x_compensation = 0.0;
- nk_f64_t covariance_z_y = nk_reduce_stable_f64x8_skylake_(cov_zy_f64x8), covariance_z_y_compensation = 0.0;
- nk_f64_t covariance_z_z = nk_reduce_stable_f64x8_skylake_(cov_zz_f64x8), covariance_z_z_compensation = 0.0;
- nk_f64_t variance_a_sum = nk_reduce_stable_f64x8_skylake_(variance_a_f64x8), variance_a_compensation = 0.0;
+ nk_f64_t covariance_x_x = nk_reduce_stable_f64x8_skylake_(covariance_xx_f64x8), covariance_x_x_compensation = 0.0;
+ nk_f64_t covariance_x_y = nk_reduce_stable_f64x8_skylake_(covariance_xy_f64x8), covariance_x_y_compensation = 0.0;
+ nk_f64_t covariance_x_z = nk_reduce_stable_f64x8_skylake_(covariance_xz_f64x8), covariance_x_z_compensation = 0.0;
+ nk_f64_t covariance_y_x = nk_reduce_stable_f64x8_skylake_(covariance_yx_f64x8), covariance_y_x_compensation = 0.0;
+ nk_f64_t covariance_y_y = nk_reduce_stable_f64x8_skylake_(covariance_yy_f64x8), covariance_y_y_compensation = 0.0;
+ nk_f64_t covariance_y_z = nk_reduce_stable_f64x8_skylake_(covariance_yz_f64x8), covariance_y_z_compensation = 0.0;
+ nk_f64_t covariance_z_x = nk_reduce_stable_f64x8_skylake_(covariance_zx_f64x8), covariance_z_x_compensation = 0.0;
+ nk_f64_t covariance_z_y = nk_reduce_stable_f64x8_skylake_(covariance_zy_f64x8), covariance_z_y_compensation = 0.0;
+ nk_f64_t covariance_z_z = nk_reduce_stable_f64x8_skylake_(covariance_zz_f64x8), covariance_z_z_compensation = 0.0;
+ nk_f64_t norm_squared_a_sum = nk_reduce_stable_f64x8_skylake_(norm_squared_a_f64x8),
+ norm_squared_a_compensation = 0.0;
+ nk_f64_t norm_squared_b_sum = nk_reduce_stable_f64x8_skylake_(norm_squared_b_f64x8),
+ norm_squared_b_compensation = 0.0;

  for (; i < n; ++i) {
  nk_f64_t ax = a[i * 3 + 0], ay = a[i * 3 + 1], az = a[i * 3 + 2];
@@ -1530,9 +1090,12 @@ NK_PUBLIC void nk_umeyama_f64_skylake(nk_f64_t const *a, nk_f64_t const *b, nk_s
1530
1090
  nk_accumulate_product_f64_(&covariance_z_x, &covariance_z_x_compensation, az, bx);
1531
1091
  nk_accumulate_product_f64_(&covariance_z_y, &covariance_z_y_compensation, az, by);
1532
1092
  nk_accumulate_product_f64_(&covariance_z_z, &covariance_z_z_compensation, az, bz);
1533
- nk_accumulate_square_f64_(&variance_a_sum, &variance_a_compensation, ax);
1534
- nk_accumulate_square_f64_(&variance_a_sum, &variance_a_compensation, ay);
1535
- nk_accumulate_square_f64_(&variance_a_sum, &variance_a_compensation, az);
1093
+ nk_accumulate_square_f64_(&norm_squared_a_sum, &norm_squared_a_compensation, ax);
1094
+ nk_accumulate_square_f64_(&norm_squared_a_sum, &norm_squared_a_compensation, ay);
1095
+ nk_accumulate_square_f64_(&norm_squared_a_sum, &norm_squared_a_compensation, az);
1096
+ nk_accumulate_square_f64_(&norm_squared_b_sum, &norm_squared_b_compensation, bx);
1097
+ nk_accumulate_square_f64_(&norm_squared_b_sum, &norm_squared_b_compensation, by);
1098
+ nk_accumulate_square_f64_(&norm_squared_b_sum, &norm_squared_b_compensation, bz);
1536
1099
  }
1537
1100
 
1538
1101
  sum_a_x += sum_a_x_compensation, sum_a_y += sum_a_y_compensation, sum_a_z += sum_a_z_compensation;
@@ -1543,7 +1106,8 @@ NK_PUBLIC void nk_umeyama_f64_skylake(nk_f64_t const *a, nk_f64_t const *b, nk_s
1543
1106
  covariance_y_z += covariance_y_z_compensation;
1544
1107
  covariance_z_x += covariance_z_x_compensation, covariance_z_y += covariance_z_y_compensation,
1545
1108
  covariance_z_z += covariance_z_z_compensation;
1546
- variance_a_sum += variance_a_compensation;
1109
+ norm_squared_a_sum += norm_squared_a_compensation;
1110
+ norm_squared_b_sum += norm_squared_b_compensation;
1547
1111
 
1548
1112
  nk_f64_t centroid_a_x = sum_a_x * inv_n, centroid_a_y = sum_a_y * inv_n, centroid_a_z = sum_a_z * inv_n;
1549
1113
  nk_f64_t centroid_b_x = sum_b_x * inv_n, centroid_b_y = sum_b_y * inv_n, centroid_b_z = sum_b_z * inv_n;
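The scalar tail above calls `nk_accumulate_product_f64_` and `nk_accumulate_square_f64_`, whose definitions sit outside this hunk. Judging from the call sites and the `*_compensation` naming, they are compensated accumulators; a sketch under that assumption (Neumaier-style, not the library's actual definition):

```c
#include <math.h>

// accumulate_compensated(s, c, v): add v to *s, banking the rounding error in *c.
static void accumulate_compensated(double *sum, double *compensation, double value) {
    double t = *sum + value;
    if (fabs(*sum) >= fabs(value)) *compensation += (*sum - t) + value;
    else *compensation += (value - t) + *sum;
    *sum = t;
}
```

Under this reading, `nk_accumulate_product_f64_(s, c, a, b)` is `accumulate_compensated(s, c, a * b)`, the square variant passes `a * a`, and the `sum += compensation` lines after the loop fold the banked error back in exactly once.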
@@ -1551,9 +1115,15 @@ NK_PUBLIC void nk_umeyama_f64_skylake(nk_f64_t const *a, nk_f64_t const *b, nk_s
1551
1115
  if (a_centroid) a_centroid[0] = centroid_a_x, a_centroid[1] = centroid_a_y, a_centroid[2] = centroid_a_z;
1552
1116
  if (b_centroid) b_centroid[0] = centroid_b_x, b_centroid[1] = centroid_b_y, b_centroid[2] = centroid_b_z;
1553
1117
 
1554
- // Compute centered covariance and variance.
1555
- nk_f64_t variance_a = variance_a_sum * inv_n -
1556
- (centroid_a_x * centroid_a_x + centroid_a_y * centroid_a_y + centroid_a_z * centroid_a_z);
1118
+ // Centered norm squared via parallel-axis identity (clamped for numerical safety).
1119
+ nk_f64_t centered_norm_squared_a = norm_squared_a_sum -
1120
+ (nk_f64_t)n * (centroid_a_x * centroid_a_x + centroid_a_y * centroid_a_y +
1121
+ centroid_a_z * centroid_a_z);
1122
+ nk_f64_t centered_norm_squared_b = norm_squared_b_sum -
1123
+ (nk_f64_t)n * (centroid_b_x * centroid_b_x + centroid_b_y * centroid_b_y +
1124
+ centroid_b_z * centroid_b_z);
1125
+ if (centered_norm_squared_a < 0.0) centered_norm_squared_a = 0.0;
1126
+ if (centered_norm_squared_b < 0.0) centered_norm_squared_b = 0.0;
1557
1127
 
1558
1128
  // Compute centered covariance matrix: Hᵢⱼ = Σ(aᵢbⱼ) − (Σaᵢ)(Σbⱼ)/n.
1559
1129
  nk_f64_t cross_covariance[9];
@@ -1568,32 +1138,58 @@ NK_PUBLIC void nk_umeyama_f64_skylake(nk_f64_t const *a, nk_f64_t const *b, nk_s
1568
1138
  cross_covariance[8] = covariance_z_z - sum_a_z * sum_b_z * inv_n;
1569
1139
 
1570
1140
  // SVD using f64 for full precision
1571
- nk_f64_t svd_u[9], svd_s[9], svd_v[9];
1572
- nk_svd3x3_f64_(cross_covariance, svd_u, svd_s, svd_v);
1573
-
1574
- nk_f64_t r[9];
1575
- nk_rotation_from_svd_f64_skylake_(svd_u, svd_v, r);
1576
-
1577
- // Scale factor: c = trace(D × S) / (n × variance(a))
1578
- nk_f64_t det = nk_det3x3_f64_(r);
1579
- nk_f64_t d3 = det < 0 ? -1.0 : 1.0;
1580
- nk_f64_t trace_ds = nk_sum_three_products_f64_(svd_s[0], 1.0, svd_s[4], 1.0, svd_s[8], d3);
1581
- nk_f64_t c = trace_ds / ((nk_f64_t)n * variance_a);
1582
- if (scale) *scale = c;
1583
-
1584
- // Handle reflection
1585
- if (det < 0) {
1586
- svd_v[2] = -svd_v[2], svd_v[5] = -svd_v[5], svd_v[8] = -svd_v[8];
1587
- nk_rotation_from_svd_f64_skylake_(svd_u, svd_v, r);
1141
+ // Identity-dominant short-circuit: when H_centered is near-diagonal with positive diagonal
1142
+ // entries, R = I, trace(R · H) = H[0]+H[4]+H[8] (== trace_ds with d3=+1), and the scale
1143
+ // derivation collapses, skipping the SVD and both rotation_from_svd calls.
1144
+ nk_f64_t covariance_diagonal_norm_squared = cross_covariance[0] * cross_covariance[0] +
1145
+ cross_covariance[4] * cross_covariance[4] +
1146
+ cross_covariance[8] * cross_covariance[8];
1147
+ nk_f64_t covariance_offdiagonal_norm_squared =
1148
+ cross_covariance[1] * cross_covariance[1] + cross_covariance[2] * cross_covariance[2] +
1149
+ cross_covariance[3] * cross_covariance[3] + cross_covariance[5] * cross_covariance[5] +
1150
+ cross_covariance[6] * cross_covariance[6] + cross_covariance[7] * cross_covariance[7];
1151
+ nk_f64_t optimal_rotation[9];
1152
+ nk_f64_t c;
1153
+ nk_f64_t trace_rotation_covariance;
1154
+ if (covariance_offdiagonal_norm_squared < 1e-20 * covariance_diagonal_norm_squared && cross_covariance[0] > 0.0 &&
1155
+ cross_covariance[4] > 0.0 && cross_covariance[8] > 0.0) {
1156
+ optimal_rotation[0] = 1, optimal_rotation[1] = 0, optimal_rotation[2] = 0, optimal_rotation[3] = 0,
1157
+ optimal_rotation[4] = 1, optimal_rotation[5] = 0, optimal_rotation[6] = 0, optimal_rotation[7] = 0,
1158
+ optimal_rotation[8] = 1;
1159
+ trace_rotation_covariance = cross_covariance[0] + cross_covariance[4] + cross_covariance[8];
1160
+ c = centered_norm_squared_a > 0.0 ? trace_rotation_covariance / centered_norm_squared_a : 0.0;
1588
1161
  }
1589
-
1590
- // Output rotation matrix.
1162
+ else {
1163
+ nk_f64_t svd_left[9], svd_diagonal[9], svd_right[9];
1164
+ nk_svd3x3_f64_(cross_covariance, svd_left, svd_diagonal, svd_right);
1165
+ nk_rotation_from_svd_f64_serial_(svd_left, svd_right, optimal_rotation);
1166
+
1167
+ // Scale factor: c = trace(D · S) / ‖a-ā‖², where ‖a-ā‖² == n·variance(a); reflection sign via d3.
1168
+ nk_f64_t det = nk_det3x3_f64_(optimal_rotation);
1169
+ nk_f64_t d3 = det < 0 ? -1.0 : 1.0;
1170
+ nk_f64_t trace_ds = nk_sum_three_products_f64_(svd_diagonal[0], 1.0, svd_diagonal[4], 1.0, svd_diagonal[8], d3);
1171
+ c = centered_norm_squared_a > 0.0 ? trace_ds / centered_norm_squared_a : 0.0;
1172
+
1173
+ if (det < 0) {
1174
+ svd_right[2] = -svd_right[2], svd_right[5] = -svd_right[5], svd_right[8] = -svd_right[8];
1175
+ nk_rotation_from_svd_f64_serial_(svd_left, svd_right, optimal_rotation);
1176
+ }
1177
+ trace_rotation_covariance =
1178
+ optimal_rotation[0] * cross_covariance[0] + optimal_rotation[1] * cross_covariance[3] +
1179
+ optimal_rotation[2] * cross_covariance[6] + optimal_rotation[3] * cross_covariance[1] +
1180
+ optimal_rotation[4] * cross_covariance[4] + optimal_rotation[5] * cross_covariance[7] +
1181
+ optimal_rotation[6] * cross_covariance[2] + optimal_rotation[7] * cross_covariance[5] +
1182
+ optimal_rotation[8] * cross_covariance[8];
1183
+ }
1184
+ if (scale) *scale = c;
1591
1185
  if (rotation)
1592
- for (int j = 0; j < 9; ++j) rotation[j] = (nk_f64_t)r[j];
1186
+ for (int j = 0; j < 9; ++j) rotation[j] = (nk_f64_t)optimal_rotation[j];
1593
1187
 
1594
- // Compute RMSD with scaling
1595
- nk_f64_t sum_squared = nk_transformed_ssd_f64_skylake_(a, b, n, r, c, centroid_a_x, centroid_a_y, centroid_a_z,
1596
- centroid_b_x, centroid_b_y, centroid_b_z);
1188
+ // Folded SSD with scale: Σ‖c·R·(a-ā) − (b-b̄)‖²
1190
+ // = c²·‖a-ā‖² + ‖b-b̄‖² − 2c·trace(R · H_centered).
1190
+ nk_f64_t sum_squared = c * c * centered_norm_squared_a + centered_norm_squared_b -
1191
+ 2.0 * c * trace_rotation_covariance;
1192
+ if (sum_squared < 0.0) sum_squared = 0.0;
1597
1193
  *result = nk_f64_sqrt_haswell(sum_squared * inv_n);
1598
1194
  }
1599
1195
 
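For readers tracing the new branch structure, the algebra behind the folded SSD and the identity short-circuit is the standard Umeyama result, restated here in this kernel's variable names (reader's note, not new material from the diff):

```latex
% Centered points \tilde a_i = a_i - \bar a, \tilde b_i = b_i - \bar b, and
% H = \sum_i \tilde a_i \tilde b_i^\top, built via the parallel-axis identity.
\begin{aligned}
\sum_i \lVert c\,R\,\tilde a_i - \tilde b_i \rVert^2
  &= c^2 \sum_i \lVert \tilde a_i \rVert^2 + \sum_i \lVert \tilde b_i \rVert^2
   - 2c\,\operatorname{tr}(R H) \\
c^\ast &= \frac{\operatorname{tr}(R H)}{\sum_i \lVert \tilde a_i \rVert^2}
        = \frac{\operatorname{tr}(D S)}{\mathrm{centered\_norm\_squared\_a}},
\qquad D = \operatorname{diag}(1, 1, \operatorname{sign}\det R)
\end{aligned}
% When H is numerically diagonal with positive entries, R = I and
% tr(R H) = H[0] + H[4] + H[8]: exactly the short-circuit branch.
```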
@@ -1603,85 +1199,34 @@ NK_PUBLIC void nk_rmsd_f16_skylake(nk_f16_t const *a, nk_f16_t const *b, nk_size
1603
1199
  rotation[0] = 1, rotation[1] = 0, rotation[2] = 0, rotation[3] = 0, rotation[4] = 1, rotation[5] = 0,
1604
1200
  rotation[6] = 0, rotation[7] = 0, rotation[8] = 1;
1605
1201
  if (scale) *scale = 1.0f;
1202
+ if (a_centroid) a_centroid[0] = 0, a_centroid[1] = 0, a_centroid[2] = 0;
1203
+ if (b_centroid) b_centroid[0] = 0, b_centroid[1] = 0, b_centroid[2] = 0;
1606
1204
 
1607
- __m512 const zeros_f32x16 = _mm512_setzero_ps();
1608
- __m512 sum_a_x_f32x16 = zeros_f32x16, sum_a_y_f32x16 = zeros_f32x16, sum_a_z_f32x16 = zeros_f32x16;
1609
- __m512 sum_b_x_f32x16 = zeros_f32x16, sum_b_y_f32x16 = zeros_f32x16, sum_b_z_f32x16 = zeros_f32x16;
1610
- __m512 sum_squared_x_f32x16 = zeros_f32x16, sum_squared_y_f32x16 = zeros_f32x16;
1611
- __m512 sum_squared_z_f32x16 = zeros_f32x16;
1612
- __m512 a_x_f32x16, a_y_f32x16, a_z_f32x16, b_x_f32x16, b_y_f32x16, b_z_f32x16;
1613
- nk_size_t i = 0;
1205
+ // 15-lane stride-3 layout: mask at the f16 level so the last chunk stays in-bounds.
1206
+ __m512 sum_squared_f32x16 = _mm512_setzero_ps();
1207
+ nk_size_t index = 0;
1614
1208
 
1615
- for (; i + 16 <= n; i += 16) {
1616
- nk_deinterleave_f16x16_to_f32x16_skylake_(a + i * 3, &a_x_f32x16, &a_y_f32x16, &a_z_f32x16);
1617
- nk_deinterleave_f16x16_to_f32x16_skylake_(b + i * 3, &b_x_f32x16, &b_y_f32x16, &b_z_f32x16);
1618
-
1619
- sum_a_x_f32x16 = _mm512_add_ps(sum_a_x_f32x16, a_x_f32x16);
1620
- sum_a_y_f32x16 = _mm512_add_ps(sum_a_y_f32x16, a_y_f32x16);
1621
- sum_a_z_f32x16 = _mm512_add_ps(sum_a_z_f32x16, a_z_f32x16);
1622
- sum_b_x_f32x16 = _mm512_add_ps(sum_b_x_f32x16, b_x_f32x16);
1623
- sum_b_y_f32x16 = _mm512_add_ps(sum_b_y_f32x16, b_y_f32x16);
1624
- sum_b_z_f32x16 = _mm512_add_ps(sum_b_z_f32x16, b_z_f32x16);
1625
-
1626
- __m512 delta_x_f32x16 = _mm512_sub_ps(a_x_f32x16, b_x_f32x16);
1627
- __m512 delta_y_f32x16 = _mm512_sub_ps(a_y_f32x16, b_y_f32x16);
1628
- __m512 delta_z_f32x16 = _mm512_sub_ps(a_z_f32x16, b_z_f32x16);
1629
-
1630
- sum_squared_x_f32x16 = _mm512_fmadd_ps(delta_x_f32x16, delta_x_f32x16, sum_squared_x_f32x16);
1631
- sum_squared_y_f32x16 = _mm512_fmadd_ps(delta_y_f32x16, delta_y_f32x16, sum_squared_y_f32x16);
1632
- sum_squared_z_f32x16 = _mm512_fmadd_ps(delta_z_f32x16, delta_z_f32x16, sum_squared_z_f32x16);
1209
+ for (; index + 5 <= n; index += 5) {
1210
+ __m256i a_f16x16 = _mm256_maskz_loadu_epi16(0x7FFF, (__m256i const *)(a + index * 3));
1211
+ __m256i b_f16x16 = _mm256_maskz_loadu_epi16(0x7FFF, (__m256i const *)(b + index * 3));
1212
+ __m512 a_f32x16 = _mm512_cvtph_ps(a_f16x16);
1213
+ __m512 b_f32x16 = _mm512_cvtph_ps(b_f16x16);
1214
+ __m512 delta_f32x16 = _mm512_sub_ps(a_f32x16, b_f32x16);
1215
+ sum_squared_f32x16 = _mm512_fmadd_ps(delta_f32x16, delta_f32x16, sum_squared_f32x16);
1633
1216
  }
1634
1217
 
1635
- // Tail: deinterleave remaining points into zero-initialized vectors
1636
- if (i < n) {
1637
- nk_size_t tail = n - i;
1638
- nk_deinterleave_f16_tail_to_f32x16_skylake_(a + i * 3, tail, &a_x_f32x16, &a_y_f32x16, &a_z_f32x16);
1639
- nk_deinterleave_f16_tail_to_f32x16_skylake_(b + i * 3, tail, &b_x_f32x16, &b_y_f32x16, &b_z_f32x16);
1640
-
1641
- sum_a_x_f32x16 = _mm512_add_ps(sum_a_x_f32x16, a_x_f32x16);
1642
- sum_a_y_f32x16 = _mm512_add_ps(sum_a_y_f32x16, a_y_f32x16);
1643
- sum_a_z_f32x16 = _mm512_add_ps(sum_a_z_f32x16, a_z_f32x16);
1644
- sum_b_x_f32x16 = _mm512_add_ps(sum_b_x_f32x16, b_x_f32x16);
1645
- sum_b_y_f32x16 = _mm512_add_ps(sum_b_y_f32x16, b_y_f32x16);
1646
- sum_b_z_f32x16 = _mm512_add_ps(sum_b_z_f32x16, b_z_f32x16);
1647
-
1648
- __m512 delta_x_f32x16 = _mm512_sub_ps(a_x_f32x16, b_x_f32x16);
1649
- __m512 delta_y_f32x16 = _mm512_sub_ps(a_y_f32x16, b_y_f32x16);
1650
- __m512 delta_z_f32x16 = _mm512_sub_ps(a_z_f32x16, b_z_f32x16);
1651
-
1652
- sum_squared_x_f32x16 = _mm512_fmadd_ps(delta_x_f32x16, delta_x_f32x16, sum_squared_x_f32x16);
1653
- sum_squared_y_f32x16 = _mm512_fmadd_ps(delta_y_f32x16, delta_y_f32x16, sum_squared_y_f32x16);
1654
- sum_squared_z_f32x16 = _mm512_fmadd_ps(delta_z_f32x16, delta_z_f32x16, sum_squared_z_f32x16);
1218
+ if (index < n) {
1219
+ __mmask16 tail_mask = (__mmask16)_bzhi_u32(0x7FFF, (nk_u32_t)((n - index) * 3));
1220
+ __m256i a_f16x16 = _mm256_maskz_loadu_epi16(tail_mask, (__m256i const *)(a + index * 3));
1221
+ __m256i b_f16x16 = _mm256_maskz_loadu_epi16(tail_mask, (__m256i const *)(b + index * 3));
1222
+ __m512 a_f32x16 = _mm512_cvtph_ps(a_f16x16);
1223
+ __m512 b_f32x16 = _mm512_cvtph_ps(b_f16x16);
1224
+ __m512 delta_f32x16 = _mm512_sub_ps(a_f32x16, b_f32x16);
1225
+ sum_squared_f32x16 = _mm512_fmadd_ps(delta_f32x16, delta_f32x16, sum_squared_f32x16);
1655
1226
  }
1656
1227
 
1657
- nk_f32_t total_ax = _mm512_reduce_add_ps(sum_a_x_f32x16);
1658
- nk_f32_t total_ay = _mm512_reduce_add_ps(sum_a_y_f32x16);
1659
- nk_f32_t total_az = _mm512_reduce_add_ps(sum_a_z_f32x16);
1660
- nk_f32_t total_bx = _mm512_reduce_add_ps(sum_b_x_f32x16);
1661
- nk_f32_t total_by = _mm512_reduce_add_ps(sum_b_y_f32x16);
1662
- nk_f32_t total_bz = _mm512_reduce_add_ps(sum_b_z_f32x16);
1663
- nk_f32_t total_sq_x = _mm512_reduce_add_ps(sum_squared_x_f32x16);
1664
- nk_f32_t total_sq_y = _mm512_reduce_add_ps(sum_squared_y_f32x16);
1665
- nk_f32_t total_sq_z = _mm512_reduce_add_ps(sum_squared_z_f32x16);
1666
-
1667
- nk_f32_t inv_n = 1.0f / (nk_f32_t)n;
1668
- nk_f32_t centroid_a_x = total_ax * inv_n;
1669
- nk_f32_t centroid_a_y = total_ay * inv_n;
1670
- nk_f32_t centroid_a_z = total_az * inv_n;
1671
- nk_f32_t centroid_b_x = total_bx * inv_n;
1672
- nk_f32_t centroid_b_y = total_by * inv_n;
1673
- nk_f32_t centroid_b_z = total_bz * inv_n;
1674
-
1675
- if (a_centroid) a_centroid[0] = centroid_a_x, a_centroid[1] = centroid_a_y, a_centroid[2] = centroid_a_z;
1676
- if (b_centroid) b_centroid[0] = centroid_b_x, b_centroid[1] = centroid_b_y, b_centroid[2] = centroid_b_z;
1677
-
1678
- nk_f32_t mean_diff_x = centroid_a_x - centroid_b_x;
1679
- nk_f32_t mean_diff_y = centroid_a_y - centroid_b_y;
1680
- nk_f32_t mean_diff_z = centroid_a_z - centroid_b_z;
1681
- nk_f32_t sum_squared = total_sq_x + total_sq_y + total_sq_z;
1682
- nk_f32_t mean_diff_sq = mean_diff_x * mean_diff_x + mean_diff_y * mean_diff_y + mean_diff_z * mean_diff_z;
1683
-
1684
- *result = nk_f32_sqrt_haswell(sum_squared * inv_n - mean_diff_sq);
1228
+ nk_f32_t sum_squared = _mm512_reduce_add_ps(sum_squared_f32x16);
1229
+ *result = nk_f32_sqrt_haswell(sum_squared / (nk_f32_t)n);
1685
1230
  }
1686
1231
 
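The `0x7FFF` loads above fetch 15 f16 scalars (5 xyz points) per iteration, and the tail derives its mask from the leftover point count, so no load ever touches memory past `3n` elements. A portable model of that mask arithmetic (illustrative only; the kernel itself uses `_bzhi_u32`):

```c
#include <stdint.h>
#include <stdio.h>

// bzhi(0x7FFF, 3*r) keeps the low 3*r bits: exactly r trailing xyz points, r in 1..4.
static uint32_t tail_mask(uint32_t remaining_points) {
    uint32_t bits = remaining_points * 3; // 3 scalars per point
    return 0x7FFFu & ((1u << bits) - 1u); // portable equivalent of _bzhi_u32
}

int main(void) {
    for (uint32_t r = 1; r < 5; ++r) printf("r=%u -> mask=0x%04X\n", r, tail_mask(r));
    // Prints: r=1 -> 0x0007, r=2 -> 0x003F, r=3 -> 0x01FF, r=4 -> 0x0FFF
    return 0;
}
```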
1687
1232
  NK_PUBLIC void nk_rmsd_bf16_skylake(nk_bf16_t const *a, nk_bf16_t const *b, nk_size_t n, nk_f32_t *a_centroid,
@@ -1690,642 +1235,541 @@ NK_PUBLIC void nk_rmsd_bf16_skylake(nk_bf16_t const *a, nk_bf16_t const *b, nk_s
1690
1235
  rotation[0] = 1, rotation[1] = 0, rotation[2] = 0, rotation[3] = 0, rotation[4] = 1, rotation[5] = 0,
1691
1236
  rotation[6] = 0, rotation[7] = 0, rotation[8] = 1;
1692
1237
  if (scale) *scale = 1.0f;
1238
+ if (a_centroid) a_centroid[0] = 0, a_centroid[1] = 0, a_centroid[2] = 0;
1239
+ if (b_centroid) b_centroid[0] = 0, b_centroid[1] = 0, b_centroid[2] = 0;
1693
1240
 
1694
- __m512 const zeros_f32x16 = _mm512_setzero_ps();
1695
- __m512 sum_a_x_f32x16 = zeros_f32x16, sum_a_y_f32x16 = zeros_f32x16, sum_a_z_f32x16 = zeros_f32x16;
1696
- __m512 sum_b_x_f32x16 = zeros_f32x16, sum_b_y_f32x16 = zeros_f32x16, sum_b_z_f32x16 = zeros_f32x16;
1697
- __m512 sum_squared_x_f32x16 = zeros_f32x16, sum_squared_y_f32x16 = zeros_f32x16;
1698
- __m512 sum_squared_z_f32x16 = zeros_f32x16;
1699
- __m512 a_x_f32x16, a_y_f32x16, a_z_f32x16, b_x_f32x16, b_y_f32x16, b_z_f32x16;
1700
- nk_size_t i = 0;
1241
+ // 15-lane stride-3 layout: mask at the bf16 level so the last chunk stays in-bounds.
1242
+ __m512 sum_squared_f32x16 = _mm512_setzero_ps();
1243
+ nk_size_t index = 0;
1701
1244
 
1702
- for (; i + 16 <= n; i += 16) {
1703
- nk_deinterleave_bf16x16_to_f32x16_skylake_(a + i * 3, &a_x_f32x16, &a_y_f32x16, &a_z_f32x16);
1704
- nk_deinterleave_bf16x16_to_f32x16_skylake_(b + i * 3, &b_x_f32x16, &b_y_f32x16, &b_z_f32x16);
1705
-
1706
- sum_a_x_f32x16 = _mm512_add_ps(sum_a_x_f32x16, a_x_f32x16);
1707
- sum_a_y_f32x16 = _mm512_add_ps(sum_a_y_f32x16, a_y_f32x16);
1708
- sum_a_z_f32x16 = _mm512_add_ps(sum_a_z_f32x16, a_z_f32x16);
1709
- sum_b_x_f32x16 = _mm512_add_ps(sum_b_x_f32x16, b_x_f32x16);
1710
- sum_b_y_f32x16 = _mm512_add_ps(sum_b_y_f32x16, b_y_f32x16);
1711
- sum_b_z_f32x16 = _mm512_add_ps(sum_b_z_f32x16, b_z_f32x16);
1712
-
1713
- __m512 delta_x_f32x16 = _mm512_sub_ps(a_x_f32x16, b_x_f32x16);
1714
- __m512 delta_y_f32x16 = _mm512_sub_ps(a_y_f32x16, b_y_f32x16);
1715
- __m512 delta_z_f32x16 = _mm512_sub_ps(a_z_f32x16, b_z_f32x16);
1716
-
1717
- sum_squared_x_f32x16 = _mm512_fmadd_ps(delta_x_f32x16, delta_x_f32x16, sum_squared_x_f32x16);
1718
- sum_squared_y_f32x16 = _mm512_fmadd_ps(delta_y_f32x16, delta_y_f32x16, sum_squared_y_f32x16);
1719
- sum_squared_z_f32x16 = _mm512_fmadd_ps(delta_z_f32x16, delta_z_f32x16, sum_squared_z_f32x16);
1245
+ for (; index + 5 <= n; index += 5) {
1246
+ __m256i a_bf16x16 = _mm256_maskz_loadu_epi16(0x7FFF, (__m256i const *)(a + index * 3));
1247
+ __m256i b_bf16x16 = _mm256_maskz_loadu_epi16(0x7FFF, (__m256i const *)(b + index * 3));
1248
+ __m512 a_f32x16 = nk_bf16x16_to_f32x16_skylake_(a_bf16x16);
1249
+ __m512 b_f32x16 = nk_bf16x16_to_f32x16_skylake_(b_bf16x16);
1250
+ __m512 delta_f32x16 = _mm512_sub_ps(a_f32x16, b_f32x16);
1251
+ sum_squared_f32x16 = _mm512_fmadd_ps(delta_f32x16, delta_f32x16, sum_squared_f32x16);
1720
1252
  }
1721
1253
 
1722
- // Tail: deinterleave remaining points into zero-initialized vectors
1723
- if (i < n) {
1724
- nk_size_t tail = n - i;
1725
- nk_deinterleave_bf16_tail_to_f32x16_skylake_(a + i * 3, tail, &a_x_f32x16, &a_y_f32x16, &a_z_f32x16);
1726
- nk_deinterleave_bf16_tail_to_f32x16_skylake_(b + i * 3, tail, &b_x_f32x16, &b_y_f32x16, &b_z_f32x16);
1727
-
1728
- sum_a_x_f32x16 = _mm512_add_ps(sum_a_x_f32x16, a_x_f32x16);
1729
- sum_a_y_f32x16 = _mm512_add_ps(sum_a_y_f32x16, a_y_f32x16);
1730
- sum_a_z_f32x16 = _mm512_add_ps(sum_a_z_f32x16, a_z_f32x16);
1731
- sum_b_x_f32x16 = _mm512_add_ps(sum_b_x_f32x16, b_x_f32x16);
1732
- sum_b_y_f32x16 = _mm512_add_ps(sum_b_y_f32x16, b_y_f32x16);
1733
- sum_b_z_f32x16 = _mm512_add_ps(sum_b_z_f32x16, b_z_f32x16);
1734
-
1735
- __m512 delta_x_f32x16 = _mm512_sub_ps(a_x_f32x16, b_x_f32x16);
1736
- __m512 delta_y_f32x16 = _mm512_sub_ps(a_y_f32x16, b_y_f32x16);
1737
- __m512 delta_z_f32x16 = _mm512_sub_ps(a_z_f32x16, b_z_f32x16);
1738
-
1739
- sum_squared_x_f32x16 = _mm512_fmadd_ps(delta_x_f32x16, delta_x_f32x16, sum_squared_x_f32x16);
1740
- sum_squared_y_f32x16 = _mm512_fmadd_ps(delta_y_f32x16, delta_y_f32x16, sum_squared_y_f32x16);
1741
- sum_squared_z_f32x16 = _mm512_fmadd_ps(delta_z_f32x16, delta_z_f32x16, sum_squared_z_f32x16);
1254
+ if (index < n) {
1255
+ __mmask16 tail_mask = (__mmask16)_bzhi_u32(0x7FFF, (nk_u32_t)((n - index) * 3));
1256
+ __m256i a_bf16x16 = _mm256_maskz_loadu_epi16(tail_mask, (__m256i const *)(a + index * 3));
1257
+ __m256i b_bf16x16 = _mm256_maskz_loadu_epi16(tail_mask, (__m256i const *)(b + index * 3));
1258
+ __m512 a_f32x16 = nk_bf16x16_to_f32x16_skylake_(a_bf16x16);
1259
+ __m512 b_f32x16 = nk_bf16x16_to_f32x16_skylake_(b_bf16x16);
1260
+ __m512 delta_f32x16 = _mm512_sub_ps(a_f32x16, b_f32x16);
1261
+ sum_squared_f32x16 = _mm512_fmadd_ps(delta_f32x16, delta_f32x16, sum_squared_f32x16);
1742
1262
  }
1743
1263
 
1744
- nk_f32_t total_ax = _mm512_reduce_add_ps(sum_a_x_f32x16);
1745
- nk_f32_t total_ay = _mm512_reduce_add_ps(sum_a_y_f32x16);
1746
- nk_f32_t total_az = _mm512_reduce_add_ps(sum_a_z_f32x16);
1747
- nk_f32_t total_bx = _mm512_reduce_add_ps(sum_b_x_f32x16);
1748
- nk_f32_t total_by = _mm512_reduce_add_ps(sum_b_y_f32x16);
1749
- nk_f32_t total_bz = _mm512_reduce_add_ps(sum_b_z_f32x16);
1750
- nk_f32_t total_sq_x = _mm512_reduce_add_ps(sum_squared_x_f32x16);
1751
- nk_f32_t total_sq_y = _mm512_reduce_add_ps(sum_squared_y_f32x16);
1752
- nk_f32_t total_sq_z = _mm512_reduce_add_ps(sum_squared_z_f32x16);
1753
-
1754
- nk_f32_t inv_n = 1.0f / (nk_f32_t)n;
1755
- nk_f32_t centroid_a_x = total_ax * inv_n;
1756
- nk_f32_t centroid_a_y = total_ay * inv_n;
1757
- nk_f32_t centroid_a_z = total_az * inv_n;
1758
- nk_f32_t centroid_b_x = total_bx * inv_n;
1759
- nk_f32_t centroid_b_y = total_by * inv_n;
1760
- nk_f32_t centroid_b_z = total_bz * inv_n;
1761
-
1762
- if (a_centroid) a_centroid[0] = centroid_a_x, a_centroid[1] = centroid_a_y, a_centroid[2] = centroid_a_z;
1763
- if (b_centroid) b_centroid[0] = centroid_b_x, b_centroid[1] = centroid_b_y, b_centroid[2] = centroid_b_z;
1764
-
1765
- nk_f32_t mean_diff_x = centroid_a_x - centroid_b_x;
1766
- nk_f32_t mean_diff_y = centroid_a_y - centroid_b_y;
1767
- nk_f32_t mean_diff_z = centroid_a_z - centroid_b_z;
1768
- nk_f32_t sum_squared = total_sq_x + total_sq_y + total_sq_z;
1769
- nk_f32_t mean_diff_sq = mean_diff_x * mean_diff_x + mean_diff_y * mean_diff_y + mean_diff_z * mean_diff_z;
1770
-
1771
- *result = nk_f32_sqrt_haswell(sum_squared * inv_n - mean_diff_sq);
1264
+ nk_f32_t sum_squared = _mm512_reduce_add_ps(sum_squared_f32x16);
1265
+ *result = nk_f32_sqrt_haswell(sum_squared / (nk_f32_t)n);
1772
1266
  }
1773
1267
 
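`nk_bf16x16_to_f32x16_skylake_` is defined elsewhere in this header. The conventional AVX-512 widening, which the helper presumably resembles since bf16 is the top half of an f32, is a zero-extend plus shift (a sketch, not the library's code):

```c
#include <immintrin.h>

static inline __m512 bf16x16_to_f32x16_sketch(__m256i bf16x16) {
    __m512i widened = _mm512_cvtepu16_epi32(bf16x16); // zero-extend 16 x u16 -> u32
    __m512i shifted = _mm512_slli_epi32(widened, 16); // bf16 bits land in the f32 high half
    return _mm512_castsi512_ps(shifted);              // pure reinterpret, no conversion
}
```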
1774
1268
  NK_PUBLIC void nk_kabsch_f16_skylake(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n, nk_f32_t *a_centroid,
1775
1269
  nk_f32_t *b_centroid, nk_f32_t *rotation, nk_f32_t *scale, nk_f32_t *result) {
1776
- __m512 const zeros_f32x16 = _mm512_setzero_ps();
1777
-
1778
- __m512 sum_a_x_f32x16 = zeros_f32x16, sum_a_y_f32x16 = zeros_f32x16, sum_a_z_f32x16 = zeros_f32x16;
1779
- __m512 sum_b_x_f32x16 = zeros_f32x16, sum_b_y_f32x16 = zeros_f32x16, sum_b_z_f32x16 = zeros_f32x16;
1780
- __m512 cov_xx_f32x16 = zeros_f32x16, cov_xy_f32x16 = zeros_f32x16, cov_xz_f32x16 = zeros_f32x16;
1781
- __m512 cov_yx_f32x16 = zeros_f32x16, cov_yy_f32x16 = zeros_f32x16, cov_yz_f32x16 = zeros_f32x16;
1782
- __m512 cov_zx_f32x16 = zeros_f32x16, cov_zy_f32x16 = zeros_f32x16, cov_zz_f32x16 = zeros_f32x16;
1270
+ // 15-lane stride-3 layout: one masked epi16 load + widen gives {a_f32x16, b_f32x16} with
1271
+ // channel phase [x,y,z, x,y,z, x,y,z, x,y,z, x,y,z, _] constant across all chunks. The 9
1272
+ // H-cells come from three product accumulators a*b, a*rot1(b), a*rot2(b) demuxed per channel.
1273
+ __m512i const idx_rotation_1_i32x16 = _mm512_setr_epi32(1, 2, 0, 4, 5, 3, 7, 8, 6, 10, 11, 9, 13, 14, 12, 15);
1274
+ __m512i const idx_rotation_2_i32x16 = _mm512_setr_epi32(2, 0, 1, 5, 3, 4, 8, 6, 7, 11, 9, 10, 14, 12, 13, 15);
1783
1275
 
1784
- nk_size_t i = 0;
1785
- __m512 a_x_f32x16, a_y_f32x16, a_z_f32x16, b_x_f32x16, b_y_f32x16, b_z_f32x16;
1276
+ __m512 const zeros_f32x16 = _mm512_setzero_ps();
1277
+ __m512 sum_a_f32x16 = zeros_f32x16, sum_b_f32x16 = zeros_f32x16;
1278
+ __m512 norm_squared_a_f32x16 = zeros_f32x16, norm_squared_b_f32x16 = zeros_f32x16;
1279
+ __m512 product_diagonal_f32x16 = zeros_f32x16;
1280
+ __m512 product_rotation_1_f32x16 = zeros_f32x16;
1281
+ __m512 product_rotation_2_f32x16 = zeros_f32x16;
1786
1282
 
1787
- for (; i + 16 <= n; i += 16) {
1788
- nk_deinterleave_f16x16_to_f32x16_skylake_(a + i * 3, &a_x_f32x16, &a_y_f32x16, &a_z_f32x16);
1789
- nk_deinterleave_f16x16_to_f32x16_skylake_(b + i * 3, &b_x_f32x16, &b_y_f32x16, &b_z_f32x16);
1790
-
1791
- sum_a_x_f32x16 = _mm512_add_ps(sum_a_x_f32x16, a_x_f32x16);
1792
- sum_a_y_f32x16 = _mm512_add_ps(sum_a_y_f32x16, a_y_f32x16);
1793
- sum_a_z_f32x16 = _mm512_add_ps(sum_a_z_f32x16, a_z_f32x16);
1794
- sum_b_x_f32x16 = _mm512_add_ps(sum_b_x_f32x16, b_x_f32x16);
1795
- sum_b_y_f32x16 = _mm512_add_ps(sum_b_y_f32x16, b_y_f32x16);
1796
- sum_b_z_f32x16 = _mm512_add_ps(sum_b_z_f32x16, b_z_f32x16);
1797
-
1798
- cov_xx_f32x16 = _mm512_fmadd_ps(a_x_f32x16, b_x_f32x16, cov_xx_f32x16);
1799
- cov_xy_f32x16 = _mm512_fmadd_ps(a_x_f32x16, b_y_f32x16, cov_xy_f32x16);
1800
- cov_xz_f32x16 = _mm512_fmadd_ps(a_x_f32x16, b_z_f32x16, cov_xz_f32x16);
1801
- cov_yx_f32x16 = _mm512_fmadd_ps(a_y_f32x16, b_x_f32x16, cov_yx_f32x16);
1802
- cov_yy_f32x16 = _mm512_fmadd_ps(a_y_f32x16, b_y_f32x16, cov_yy_f32x16);
1803
- cov_yz_f32x16 = _mm512_fmadd_ps(a_y_f32x16, b_z_f32x16, cov_yz_f32x16);
1804
- cov_zx_f32x16 = _mm512_fmadd_ps(a_z_f32x16, b_x_f32x16, cov_zx_f32x16);
1805
- cov_zy_f32x16 = _mm512_fmadd_ps(a_z_f32x16, b_y_f32x16, cov_zy_f32x16);
1806
- cov_zz_f32x16 = _mm512_fmadd_ps(a_z_f32x16, b_z_f32x16, cov_zz_f32x16);
1283
+ nk_size_t index = 0;
1284
+ for (; index + 5 <= n; index += 5) {
1285
+ __m256i a_f16x16 = _mm256_maskz_loadu_epi16(0x7FFF, (__m256i const *)(a + index * 3));
1286
+ __m256i b_f16x16 = _mm256_maskz_loadu_epi16(0x7FFF, (__m256i const *)(b + index * 3));
1287
+ __m512 a_f32x16 = _mm512_cvtph_ps(a_f16x16);
1288
+ __m512 b_f32x16 = _mm512_cvtph_ps(b_f16x16);
1289
+ __m512 b_rotation_1_f32x16 = _mm512_permutexvar_ps(idx_rotation_1_i32x16, b_f32x16);
1290
+ __m512 b_rotation_2_f32x16 = _mm512_permutexvar_ps(idx_rotation_2_i32x16, b_f32x16);
1291
+ sum_a_f32x16 = _mm512_add_ps(sum_a_f32x16, a_f32x16);
1292
+ sum_b_f32x16 = _mm512_add_ps(sum_b_f32x16, b_f32x16);
1293
+ norm_squared_a_f32x16 = _mm512_fmadd_ps(a_f32x16, a_f32x16, norm_squared_a_f32x16);
1294
+ norm_squared_b_f32x16 = _mm512_fmadd_ps(b_f32x16, b_f32x16, norm_squared_b_f32x16);
1295
+ product_diagonal_f32x16 = _mm512_fmadd_ps(a_f32x16, b_f32x16, product_diagonal_f32x16);
1296
+ product_rotation_1_f32x16 = _mm512_fmadd_ps(a_f32x16, b_rotation_1_f32x16, product_rotation_1_f32x16);
1297
+ product_rotation_2_f32x16 = _mm512_fmadd_ps(a_f32x16, b_rotation_2_f32x16, product_rotation_2_f32x16);
1807
1298
  }
1808
1299
 
1809
- // Tail: deinterleave remaining points into zero-initialized vectors
1810
- if (i < n) {
1811
- nk_size_t tail = n - i;
1812
- nk_deinterleave_f16_tail_to_f32x16_skylake_(a + i * 3, tail, &a_x_f32x16, &a_y_f32x16, &a_z_f32x16);
1813
- nk_deinterleave_f16_tail_to_f32x16_skylake_(b + i * 3, tail, &b_x_f32x16, &b_y_f32x16, &b_z_f32x16);
1814
-
1815
- sum_a_x_f32x16 = _mm512_add_ps(sum_a_x_f32x16, a_x_f32x16);
1816
- sum_a_y_f32x16 = _mm512_add_ps(sum_a_y_f32x16, a_y_f32x16);
1817
- sum_a_z_f32x16 = _mm512_add_ps(sum_a_z_f32x16, a_z_f32x16);
1818
- sum_b_x_f32x16 = _mm512_add_ps(sum_b_x_f32x16, b_x_f32x16);
1819
- sum_b_y_f32x16 = _mm512_add_ps(sum_b_y_f32x16, b_y_f32x16);
1820
- sum_b_z_f32x16 = _mm512_add_ps(sum_b_z_f32x16, b_z_f32x16);
1821
-
1822
- cov_xx_f32x16 = _mm512_fmadd_ps(a_x_f32x16, b_x_f32x16, cov_xx_f32x16);
1823
- cov_xy_f32x16 = _mm512_fmadd_ps(a_x_f32x16, b_y_f32x16, cov_xy_f32x16);
1824
- cov_xz_f32x16 = _mm512_fmadd_ps(a_x_f32x16, b_z_f32x16, cov_xz_f32x16);
1825
- cov_yx_f32x16 = _mm512_fmadd_ps(a_y_f32x16, b_x_f32x16, cov_yx_f32x16);
1826
- cov_yy_f32x16 = _mm512_fmadd_ps(a_y_f32x16, b_y_f32x16, cov_yy_f32x16);
1827
- cov_yz_f32x16 = _mm512_fmadd_ps(a_y_f32x16, b_z_f32x16, cov_yz_f32x16);
1828
- cov_zx_f32x16 = _mm512_fmadd_ps(a_z_f32x16, b_x_f32x16, cov_zx_f32x16);
1829
- cov_zy_f32x16 = _mm512_fmadd_ps(a_z_f32x16, b_y_f32x16, cov_zy_f32x16);
1830
- cov_zz_f32x16 = _mm512_fmadd_ps(a_z_f32x16, b_z_f32x16, cov_zz_f32x16);
1300
+ if (index < n) {
1301
+ __mmask16 tail_mask = (__mmask16)_bzhi_u32(0x7FFF, (nk_u32_t)((n - index) * 3));
1302
+ __m256i a_f16x16 = _mm256_maskz_loadu_epi16(tail_mask, (__m256i const *)(a + index * 3));
1303
+ __m256i b_f16x16 = _mm256_maskz_loadu_epi16(tail_mask, (__m256i const *)(b + index * 3));
1304
+ __m512 a_f32x16 = _mm512_cvtph_ps(a_f16x16);
1305
+ __m512 b_f32x16 = _mm512_cvtph_ps(b_f16x16);
1306
+ __m512 b_rotation_1_f32x16 = _mm512_permutexvar_ps(idx_rotation_1_i32x16, b_f32x16);
1307
+ __m512 b_rotation_2_f32x16 = _mm512_permutexvar_ps(idx_rotation_2_i32x16, b_f32x16);
1308
+ sum_a_f32x16 = _mm512_add_ps(sum_a_f32x16, a_f32x16);
1309
+ sum_b_f32x16 = _mm512_add_ps(sum_b_f32x16, b_f32x16);
1310
+ norm_squared_a_f32x16 = _mm512_fmadd_ps(a_f32x16, a_f32x16, norm_squared_a_f32x16);
1311
+ norm_squared_b_f32x16 = _mm512_fmadd_ps(b_f32x16, b_f32x16, norm_squared_b_f32x16);
1312
+ product_diagonal_f32x16 = _mm512_fmadd_ps(a_f32x16, b_f32x16, product_diagonal_f32x16);
1313
+ product_rotation_1_f32x16 = _mm512_fmadd_ps(a_f32x16, b_rotation_1_f32x16, product_rotation_1_f32x16);
1314
+ product_rotation_2_f32x16 = _mm512_fmadd_ps(a_f32x16, b_rotation_2_f32x16, product_rotation_2_f32x16);
1831
1315
  }
1832
1316
 
1833
- nk_f32_t sum_a_x = _mm512_reduce_add_ps(sum_a_x_f32x16);
1834
- nk_f32_t sum_a_y = _mm512_reduce_add_ps(sum_a_y_f32x16);
1835
- nk_f32_t sum_a_z = _mm512_reduce_add_ps(sum_a_z_f32x16);
1836
- nk_f32_t sum_b_x = _mm512_reduce_add_ps(sum_b_x_f32x16);
1837
- nk_f32_t sum_b_y = _mm512_reduce_add_ps(sum_b_y_f32x16);
1838
- nk_f32_t sum_b_z = _mm512_reduce_add_ps(sum_b_z_f32x16);
1839
- nk_f32_t covariance_x_x = _mm512_reduce_add_ps(cov_xx_f32x16);
1840
- nk_f32_t covariance_x_y = _mm512_reduce_add_ps(cov_xy_f32x16);
1841
- nk_f32_t covariance_x_z = _mm512_reduce_add_ps(cov_xz_f32x16);
1842
- nk_f32_t covariance_y_x = _mm512_reduce_add_ps(cov_yx_f32x16);
1843
- nk_f32_t covariance_y_y = _mm512_reduce_add_ps(cov_yy_f32x16);
1844
- nk_f32_t covariance_y_z = _mm512_reduce_add_ps(cov_yz_f32x16);
1845
- nk_f32_t covariance_z_x = _mm512_reduce_add_ps(cov_zx_f32x16);
1846
- nk_f32_t covariance_z_y = _mm512_reduce_add_ps(cov_zy_f32x16);
1847
- nk_f32_t covariance_z_z = _mm512_reduce_add_ps(cov_zz_f32x16);
1317
+ // Per-channel demux via mask-reduce on the fp32 accumulators (lane i carries channel i%3).
1318
+ __mmask16 const mask_channel_x_f32 = 0x1249; // lanes {0, 3, 6, 9, 12}
1319
+ __mmask16 const mask_channel_y_f32 = 0x2492; // lanes {1, 4, 7, 10, 13}
1320
+ __mmask16 const mask_channel_z_f32 = 0x4924; // lanes {2, 5, 8, 11, 14}
1321
+
1322
+ nk_f32_t sum_a_x = _mm512_mask_reduce_add_ps(mask_channel_x_f32, sum_a_f32x16);
1323
+ nk_f32_t sum_a_y = _mm512_mask_reduce_add_ps(mask_channel_y_f32, sum_a_f32x16);
1324
+ nk_f32_t sum_a_z = _mm512_mask_reduce_add_ps(mask_channel_z_f32, sum_a_f32x16);
1325
+ nk_f32_t sum_b_x = _mm512_mask_reduce_add_ps(mask_channel_x_f32, sum_b_f32x16);
1326
+ nk_f32_t sum_b_y = _mm512_mask_reduce_add_ps(mask_channel_y_f32, sum_b_f32x16);
1327
+ nk_f32_t sum_b_z = _mm512_mask_reduce_add_ps(mask_channel_z_f32, sum_b_f32x16);
1328
+ nk_f32_t norm_squared_a = _mm512_reduce_add_ps(norm_squared_a_f32x16);
1329
+ nk_f32_t norm_squared_b = _mm512_reduce_add_ps(norm_squared_b_f32x16);
1330
+
1331
+ nk_f32_t covariance_x_x = _mm512_mask_reduce_add_ps(mask_channel_x_f32, product_diagonal_f32x16);
1332
+ nk_f32_t covariance_x_y = _mm512_mask_reduce_add_ps(mask_channel_x_f32, product_rotation_1_f32x16);
1333
+ nk_f32_t covariance_x_z = _mm512_mask_reduce_add_ps(mask_channel_x_f32, product_rotation_2_f32x16);
1334
+ nk_f32_t covariance_y_x = _mm512_mask_reduce_add_ps(mask_channel_y_f32, product_rotation_2_f32x16);
1335
+ nk_f32_t covariance_y_y = _mm512_mask_reduce_add_ps(mask_channel_y_f32, product_diagonal_f32x16);
1336
+ nk_f32_t covariance_y_z = _mm512_mask_reduce_add_ps(mask_channel_y_f32, product_rotation_1_f32x16);
1337
+ nk_f32_t covariance_z_x = _mm512_mask_reduce_add_ps(mask_channel_z_f32, product_rotation_1_f32x16);
1338
+ nk_f32_t covariance_z_y = _mm512_mask_reduce_add_ps(mask_channel_z_f32, product_rotation_2_f32x16);
1339
+ nk_f32_t covariance_z_z = _mm512_mask_reduce_add_ps(mask_channel_z_f32, product_diagonal_f32x16);
1848
1340
 
1849
1341
  nk_f32_t inv_n = 1.0f / (nk_f32_t)n;
1850
1342
  nk_f32_t centroid_a_x = sum_a_x * inv_n, centroid_a_y = sum_a_y * inv_n, centroid_a_z = sum_a_z * inv_n;
1851
1343
  nk_f32_t centroid_b_x = sum_b_x * inv_n, centroid_b_y = sum_b_y * inv_n, centroid_b_z = sum_b_z * inv_n;
1852
-
1853
1344
  if (a_centroid) a_centroid[0] = centroid_a_x, a_centroid[1] = centroid_a_y, a_centroid[2] = centroid_a_z;
1854
1345
  if (b_centroid) b_centroid[0] = centroid_b_x, b_centroid[1] = centroid_b_y, b_centroid[2] = centroid_b_z;
1855
1346
 
1856
- covariance_x_x -= (nk_f32_t)n * centroid_a_x * centroid_b_x;
1857
- covariance_x_y -= (nk_f32_t)n * centroid_a_x * centroid_b_y;
1858
- covariance_x_z -= (nk_f32_t)n * centroid_a_x * centroid_b_z;
1859
- covariance_y_x -= (nk_f32_t)n * centroid_a_y * centroid_b_x;
1860
- covariance_y_y -= (nk_f32_t)n * centroid_a_y * centroid_b_y;
1861
- covariance_y_z -= (nk_f32_t)n * centroid_a_y * centroid_b_z;
1862
- covariance_z_x -= (nk_f32_t)n * centroid_a_z * centroid_b_x;
1863
- covariance_z_y -= (nk_f32_t)n * centroid_a_z * centroid_b_y;
1864
- covariance_z_z -= (nk_f32_t)n * centroid_a_z * centroid_b_z;
1865
-
1866
- nk_f32_t cross_covariance[9] = {covariance_x_x, covariance_x_y, covariance_x_z, covariance_y_x, covariance_y_y,
1867
- covariance_y_z, covariance_z_x, covariance_z_y, covariance_z_z};
1868
- nk_f32_t svd_u[9], svd_s[9], svd_v[9];
1869
- nk_svd3x3_f32_(cross_covariance, svd_u, svd_s, svd_v);
1870
-
1871
- nk_f32_t r[9];
1872
- r[0] = svd_v[0] * svd_u[0] + svd_v[1] * svd_u[1] + svd_v[2] * svd_u[2];
1873
- r[1] = svd_v[0] * svd_u[3] + svd_v[1] * svd_u[4] + svd_v[2] * svd_u[5];
1874
- r[2] = svd_v[0] * svd_u[6] + svd_v[1] * svd_u[7] + svd_v[2] * svd_u[8];
1875
- r[3] = svd_v[3] * svd_u[0] + svd_v[4] * svd_u[1] + svd_v[5] * svd_u[2];
1876
- r[4] = svd_v[3] * svd_u[3] + svd_v[4] * svd_u[4] + svd_v[5] * svd_u[5];
1877
- r[5] = svd_v[3] * svd_u[6] + svd_v[4] * svd_u[7] + svd_v[5] * svd_u[8];
1878
- r[6] = svd_v[6] * svd_u[0] + svd_v[7] * svd_u[1] + svd_v[8] * svd_u[2];
1879
- r[7] = svd_v[6] * svd_u[3] + svd_v[7] * svd_u[4] + svd_v[8] * svd_u[5];
1880
- r[8] = svd_v[6] * svd_u[6] + svd_v[7] * svd_u[7] + svd_v[8] * svd_u[8];
1881
-
1882
- if (nk_det3x3_f32_(r) < 0) {
1883
- svd_v[2] = -svd_v[2], svd_v[5] = -svd_v[5], svd_v[8] = -svd_v[8];
1884
- r[0] = svd_v[0] * svd_u[0] + svd_v[1] * svd_u[1] + svd_v[2] * svd_u[2];
1885
- r[1] = svd_v[0] * svd_u[3] + svd_v[1] * svd_u[4] + svd_v[2] * svd_u[5];
1886
- r[2] = svd_v[0] * svd_u[6] + svd_v[1] * svd_u[7] + svd_v[2] * svd_u[8];
1887
- r[3] = svd_v[3] * svd_u[0] + svd_v[4] * svd_u[1] + svd_v[5] * svd_u[2];
1888
- r[4] = svd_v[3] * svd_u[3] + svd_v[4] * svd_u[4] + svd_v[5] * svd_u[5];
1889
- r[5] = svd_v[3] * svd_u[6] + svd_v[4] * svd_u[7] + svd_v[5] * svd_u[8];
1890
- r[6] = svd_v[6] * svd_u[0] + svd_v[7] * svd_u[1] + svd_v[8] * svd_u[2];
1891
- r[7] = svd_v[6] * svd_u[3] + svd_v[7] * svd_u[4] + svd_v[8] * svd_u[5];
1892
- r[8] = svd_v[6] * svd_u[6] + svd_v[7] * svd_u[7] + svd_v[8] * svd_u[8];
1347
+ // Parallel-axis correction.
1348
+ nk_f32_t cross_covariance[9];
1349
+ cross_covariance[0] = covariance_x_x - (nk_f32_t)n * centroid_a_x * centroid_b_x;
1350
+ cross_covariance[1] = covariance_x_y - (nk_f32_t)n * centroid_a_x * centroid_b_y;
1351
+ cross_covariance[2] = covariance_x_z - (nk_f32_t)n * centroid_a_x * centroid_b_z;
1352
+ cross_covariance[3] = covariance_y_x - (nk_f32_t)n * centroid_a_y * centroid_b_x;
1353
+ cross_covariance[4] = covariance_y_y - (nk_f32_t)n * centroid_a_y * centroid_b_y;
1354
+ cross_covariance[5] = covariance_y_z - (nk_f32_t)n * centroid_a_y * centroid_b_z;
1355
+ cross_covariance[6] = covariance_z_x - (nk_f32_t)n * centroid_a_z * centroid_b_x;
1356
+ cross_covariance[7] = covariance_z_y - (nk_f32_t)n * centroid_a_z * centroid_b_y;
1357
+ cross_covariance[8] = covariance_z_z - (nk_f32_t)n * centroid_a_z * centroid_b_z;
1358
+
1359
+ nk_f32_t svd_left[9], svd_diagonal[9], svd_right[9];
1360
+ nk_svd3x3_f32_(cross_covariance, svd_left, svd_diagonal, svd_right);
1361
+ nk_f32_t optimal_rotation[9];
1362
+ nk_rotation_from_svd_f32_serial_(svd_left, svd_right, optimal_rotation);
1363
+ if (nk_det3x3_f32_(optimal_rotation) < 0) {
1364
+ svd_right[2] = -svd_right[2], svd_right[5] = -svd_right[5], svd_right[8] = -svd_right[8];
1365
+ nk_rotation_from_svd_f32_serial_(svd_left, svd_right, optimal_rotation);
1893
1366
  }
1894
-
1895
1367
  if (rotation)
1896
- for (int j = 0; j < 9; ++j) rotation[j] = r[j];
1368
+ for (int j = 0; j < 9; ++j) rotation[j] = optimal_rotation[j];
1897
1369
  if (scale) *scale = 1.0f;
1898
1370
 
1899
- nk_f32_t sum_squared = nk_transformed_ssd_f16_skylake_(a, b, n, r, 1.0f, centroid_a_x, centroid_a_y, centroid_a_z,
1900
- centroid_b_x, centroid_b_y, centroid_b_z);
1371
+ // Folded SSD via trace identity:
1372
+ // SSD = ‖a-ā‖² + ‖b-b̄‖² − 2·trace(R · H_centered)
1373
+ // trace(R · H_centered) = Σⱼₖ R[j,k] · H[k,j] (note transpose on H).
1374
+ nk_f32_t centered_norm_squared_a = norm_squared_a -
1375
+ (nk_f32_t)n * (centroid_a_x * centroid_a_x + centroid_a_y * centroid_a_y +
1376
+ centroid_a_z * centroid_a_z);
1377
+ nk_f32_t centered_norm_squared_b = norm_squared_b -
1378
+ (nk_f32_t)n * (centroid_b_x * centroid_b_x + centroid_b_y * centroid_b_y +
1379
+ centroid_b_z * centroid_b_z);
1380
+ if (centered_norm_squared_a < 0.0f) centered_norm_squared_a = 0.0f;
1381
+ if (centered_norm_squared_b < 0.0f) centered_norm_squared_b = 0.0f;
1382
+ nk_f32_t trace_rotation_covariance =
1383
+ optimal_rotation[0] * cross_covariance[0] + optimal_rotation[1] * cross_covariance[3] +
1384
+ optimal_rotation[2] * cross_covariance[6] + optimal_rotation[3] * cross_covariance[1] +
1385
+ optimal_rotation[4] * cross_covariance[4] + optimal_rotation[5] * cross_covariance[7] +
1386
+ optimal_rotation[6] * cross_covariance[2] + optimal_rotation[7] * cross_covariance[5] +
1387
+ optimal_rotation[8] * cross_covariance[8];
1388
+ nk_f32_t sum_squared = centered_norm_squared_a + centered_norm_squared_b - 2.0f * trace_rotation_covariance;
1389
+ if (sum_squared < 0.0f) sum_squared = 0.0f;
1901
1390
  *result = nk_f32_sqrt_haswell(sum_squared * inv_n);
1902
1391
  }
1903
1392
 
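The permutation indices above rotate each xyz triplet of `b` in place, which is why three FMA accumulators cover all nine H-cells. A scalar model of the demux (reader's note, hypothetical name) makes the lane bookkeeping explicit:

```c
// Within each triplet, rot1 maps b's (x,y,z) -> (y,z,x) and rot2 -> (z,x,y), so
// per a-channel the three accumulators hold:
//   channel x: a_x*b_x (diag), a_x*b_y (rot1), a_x*b_z (rot2) -> H_xx, H_xy, H_xz
//   channel y: a_y*b_y (diag), a_y*b_z (rot1), a_y*b_x (rot2) -> H_yy, H_yz, H_yx
//   channel z: a_z*b_z (diag), a_z*b_x (rot1), a_z*b_y (rot2) -> H_zz, H_zx, H_zy
// matching the mask_channel_* / product_* pairing in the reductions above.
void accumulate_h_cells_scalar(float const a[3], float const b[3], float h[9]) {
    float b_rot1[3] = {b[1], b[2], b[0]};
    float b_rot2[3] = {b[2], b[0], b[1]};
    for (int ch = 0; ch < 3; ++ch) {
        h[ch * 3 + ch] += a[ch] * b[ch];                // diagonal accumulator
        h[ch * 3 + (ch + 1) % 3] += a[ch] * b_rot1[ch]; // rotation-1 accumulator
        h[ch * 3 + (ch + 2) % 3] += a[ch] * b_rot2[ch]; // rotation-2 accumulator
    }
}
```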
1904
1393
  NK_PUBLIC void nk_kabsch_bf16_skylake(nk_bf16_t const *a, nk_bf16_t const *b, nk_size_t n, nk_f32_t *a_centroid,
1905
1394
  nk_f32_t *b_centroid, nk_f32_t *rotation, nk_f32_t *scale, nk_f32_t *result) {
1906
- __m512 const zeros_f32x16 = _mm512_setzero_ps();
1907
-
1908
- __m512 sum_a_x_f32x16 = zeros_f32x16, sum_a_y_f32x16 = zeros_f32x16, sum_a_z_f32x16 = zeros_f32x16;
1909
- __m512 sum_b_x_f32x16 = zeros_f32x16, sum_b_y_f32x16 = zeros_f32x16, sum_b_z_f32x16 = zeros_f32x16;
1910
- __m512 cov_xx_f32x16 = zeros_f32x16, cov_xy_f32x16 = zeros_f32x16, cov_xz_f32x16 = zeros_f32x16;
1911
- __m512 cov_yx_f32x16 = zeros_f32x16, cov_yy_f32x16 = zeros_f32x16, cov_yz_f32x16 = zeros_f32x16;
1912
- __m512 cov_zx_f32x16 = zeros_f32x16, cov_zy_f32x16 = zeros_f32x16, cov_zz_f32x16 = zeros_f32x16;
1395
+ // 15-lane stride-3 layout: one masked epi16 load + widen gives {a_f32x16, b_f32x16} with
1396
+ // channel phase [x,y,z, x,y,z, x,y,z, x,y,z, x,y,z, _] constant across all chunks. The 9
1397
+ // H-cells come from three product accumulators a*b, a*rot1(b), a*rot2(b) demuxed per channel.
1398
+ __m512i const idx_rotation_1_i32x16 = _mm512_setr_epi32(1, 2, 0, 4, 5, 3, 7, 8, 6, 10, 11, 9, 13, 14, 12, 15);
1399
+ __m512i const idx_rotation_2_i32x16 = _mm512_setr_epi32(2, 0, 1, 5, 3, 4, 8, 6, 7, 11, 9, 10, 14, 12, 13, 15);
1913
1400
 
1914
- nk_size_t i = 0;
1915
- __m512 a_x_f32x16, a_y_f32x16, a_z_f32x16, b_x_f32x16, b_y_f32x16, b_z_f32x16;
1401
+ __m512 const zeros_f32x16 = _mm512_setzero_ps();
1402
+ __m512 sum_a_f32x16 = zeros_f32x16, sum_b_f32x16 = zeros_f32x16;
1403
+ __m512 norm_squared_a_f32x16 = zeros_f32x16, norm_squared_b_f32x16 = zeros_f32x16;
1404
+ __m512 product_diagonal_f32x16 = zeros_f32x16;
1405
+ __m512 product_rotation_1_f32x16 = zeros_f32x16;
1406
+ __m512 product_rotation_2_f32x16 = zeros_f32x16;
1916
1407
 
1917
- for (; i + 16 <= n; i += 16) {
1918
- nk_deinterleave_bf16x16_to_f32x16_skylake_(a + i * 3, &a_x_f32x16, &a_y_f32x16, &a_z_f32x16);
1919
- nk_deinterleave_bf16x16_to_f32x16_skylake_(b + i * 3, &b_x_f32x16, &b_y_f32x16, &b_z_f32x16);
1920
-
1921
- sum_a_x_f32x16 = _mm512_add_ps(sum_a_x_f32x16, a_x_f32x16);
1922
- sum_a_y_f32x16 = _mm512_add_ps(sum_a_y_f32x16, a_y_f32x16);
1923
- sum_a_z_f32x16 = _mm512_add_ps(sum_a_z_f32x16, a_z_f32x16);
1924
- sum_b_x_f32x16 = _mm512_add_ps(sum_b_x_f32x16, b_x_f32x16);
1925
- sum_b_y_f32x16 = _mm512_add_ps(sum_b_y_f32x16, b_y_f32x16);
1926
- sum_b_z_f32x16 = _mm512_add_ps(sum_b_z_f32x16, b_z_f32x16);
1927
-
1928
- cov_xx_f32x16 = _mm512_fmadd_ps(a_x_f32x16, b_x_f32x16, cov_xx_f32x16);
1929
- cov_xy_f32x16 = _mm512_fmadd_ps(a_x_f32x16, b_y_f32x16, cov_xy_f32x16);
1930
- cov_xz_f32x16 = _mm512_fmadd_ps(a_x_f32x16, b_z_f32x16, cov_xz_f32x16);
1931
- cov_yx_f32x16 = _mm512_fmadd_ps(a_y_f32x16, b_x_f32x16, cov_yx_f32x16);
1932
- cov_yy_f32x16 = _mm512_fmadd_ps(a_y_f32x16, b_y_f32x16, cov_yy_f32x16);
1933
- cov_yz_f32x16 = _mm512_fmadd_ps(a_y_f32x16, b_z_f32x16, cov_yz_f32x16);
1934
- cov_zx_f32x16 = _mm512_fmadd_ps(a_z_f32x16, b_x_f32x16, cov_zx_f32x16);
1935
- cov_zy_f32x16 = _mm512_fmadd_ps(a_z_f32x16, b_y_f32x16, cov_zy_f32x16);
1936
- cov_zz_f32x16 = _mm512_fmadd_ps(a_z_f32x16, b_z_f32x16, cov_zz_f32x16);
1408
+ nk_size_t index = 0;
1409
+ for (; index + 5 <= n; index += 5) {
1410
+ __m256i a_bf16x16 = _mm256_maskz_loadu_epi16(0x7FFF, (__m256i const *)(a + index * 3));
1411
+ __m256i b_bf16x16 = _mm256_maskz_loadu_epi16(0x7FFF, (__m256i const *)(b + index * 3));
1412
+ __m512 a_f32x16 = nk_bf16x16_to_f32x16_skylake_(a_bf16x16);
1413
+ __m512 b_f32x16 = nk_bf16x16_to_f32x16_skylake_(b_bf16x16);
1414
+ __m512 b_rotation_1_f32x16 = _mm512_permutexvar_ps(idx_rotation_1_i32x16, b_f32x16);
1415
+ __m512 b_rotation_2_f32x16 = _mm512_permutexvar_ps(idx_rotation_2_i32x16, b_f32x16);
1416
+ sum_a_f32x16 = _mm512_add_ps(sum_a_f32x16, a_f32x16);
1417
+ sum_b_f32x16 = _mm512_add_ps(sum_b_f32x16, b_f32x16);
1418
+ norm_squared_a_f32x16 = _mm512_fmadd_ps(a_f32x16, a_f32x16, norm_squared_a_f32x16);
1419
+ norm_squared_b_f32x16 = _mm512_fmadd_ps(b_f32x16, b_f32x16, norm_squared_b_f32x16);
1420
+ product_diagonal_f32x16 = _mm512_fmadd_ps(a_f32x16, b_f32x16, product_diagonal_f32x16);
1421
+ product_rotation_1_f32x16 = _mm512_fmadd_ps(a_f32x16, b_rotation_1_f32x16, product_rotation_1_f32x16);
1422
+ product_rotation_2_f32x16 = _mm512_fmadd_ps(a_f32x16, b_rotation_2_f32x16, product_rotation_2_f32x16);
1937
1423
  }
1938
1424
 
1939
- // Tail: deinterleave remaining points into zero-initialized vectors
1940
- if (i < n) {
1941
- nk_size_t tail = n - i;
1942
- nk_deinterleave_bf16_tail_to_f32x16_skylake_(a + i * 3, tail, &a_x_f32x16, &a_y_f32x16, &a_z_f32x16);
1943
- nk_deinterleave_bf16_tail_to_f32x16_skylake_(b + i * 3, tail, &b_x_f32x16, &b_y_f32x16, &b_z_f32x16);
1944
-
1945
- sum_a_x_f32x16 = _mm512_add_ps(sum_a_x_f32x16, a_x_f32x16);
1946
- sum_a_y_f32x16 = _mm512_add_ps(sum_a_y_f32x16, a_y_f32x16);
1947
- sum_a_z_f32x16 = _mm512_add_ps(sum_a_z_f32x16, a_z_f32x16);
1948
- sum_b_x_f32x16 = _mm512_add_ps(sum_b_x_f32x16, b_x_f32x16);
1949
- sum_b_y_f32x16 = _mm512_add_ps(sum_b_y_f32x16, b_y_f32x16);
1950
- sum_b_z_f32x16 = _mm512_add_ps(sum_b_z_f32x16, b_z_f32x16);
1951
-
1952
- cov_xx_f32x16 = _mm512_fmadd_ps(a_x_f32x16, b_x_f32x16, cov_xx_f32x16);
1953
- cov_xy_f32x16 = _mm512_fmadd_ps(a_x_f32x16, b_y_f32x16, cov_xy_f32x16);
1954
- cov_xz_f32x16 = _mm512_fmadd_ps(a_x_f32x16, b_z_f32x16, cov_xz_f32x16);
1955
- cov_yx_f32x16 = _mm512_fmadd_ps(a_y_f32x16, b_x_f32x16, cov_yx_f32x16);
1956
- cov_yy_f32x16 = _mm512_fmadd_ps(a_y_f32x16, b_y_f32x16, cov_yy_f32x16);
1957
- cov_yz_f32x16 = _mm512_fmadd_ps(a_y_f32x16, b_z_f32x16, cov_yz_f32x16);
1958
- cov_zx_f32x16 = _mm512_fmadd_ps(a_z_f32x16, b_x_f32x16, cov_zx_f32x16);
1959
- cov_zy_f32x16 = _mm512_fmadd_ps(a_z_f32x16, b_y_f32x16, cov_zy_f32x16);
1960
- cov_zz_f32x16 = _mm512_fmadd_ps(a_z_f32x16, b_z_f32x16, cov_zz_f32x16);
1425
+ if (index < n) {
1426
+ __mmask16 tail_mask = (__mmask16)_bzhi_u32(0x7FFF, (nk_u32_t)((n - index) * 3));
1427
+ __m256i a_bf16x16 = _mm256_maskz_loadu_epi16(tail_mask, (__m256i const *)(a + index * 3));
1428
+ __m256i b_bf16x16 = _mm256_maskz_loadu_epi16(tail_mask, (__m256i const *)(b + index * 3));
1429
+ __m512 a_f32x16 = nk_bf16x16_to_f32x16_skylake_(a_bf16x16);
1430
+ __m512 b_f32x16 = nk_bf16x16_to_f32x16_skylake_(b_bf16x16);
1431
+ __m512 b_rotation_1_f32x16 = _mm512_permutexvar_ps(idx_rotation_1_i32x16, b_f32x16);
1432
+ __m512 b_rotation_2_f32x16 = _mm512_permutexvar_ps(idx_rotation_2_i32x16, b_f32x16);
1433
+ sum_a_f32x16 = _mm512_add_ps(sum_a_f32x16, a_f32x16);
1434
+ sum_b_f32x16 = _mm512_add_ps(sum_b_f32x16, b_f32x16);
1435
+ norm_squared_a_f32x16 = _mm512_fmadd_ps(a_f32x16, a_f32x16, norm_squared_a_f32x16);
1436
+ norm_squared_b_f32x16 = _mm512_fmadd_ps(b_f32x16, b_f32x16, norm_squared_b_f32x16);
1437
+ product_diagonal_f32x16 = _mm512_fmadd_ps(a_f32x16, b_f32x16, product_diagonal_f32x16);
1438
+ product_rotation_1_f32x16 = _mm512_fmadd_ps(a_f32x16, b_rotation_1_f32x16, product_rotation_1_f32x16);
1439
+ product_rotation_2_f32x16 = _mm512_fmadd_ps(a_f32x16, b_rotation_2_f32x16, product_rotation_2_f32x16);
1961
1440
  }
1962
1441
 
1963
- nk_f32_t sum_a_x = _mm512_reduce_add_ps(sum_a_x_f32x16);
1964
- nk_f32_t sum_a_y = _mm512_reduce_add_ps(sum_a_y_f32x16);
1965
- nk_f32_t sum_a_z = _mm512_reduce_add_ps(sum_a_z_f32x16);
1966
- nk_f32_t sum_b_x = _mm512_reduce_add_ps(sum_b_x_f32x16);
1967
- nk_f32_t sum_b_y = _mm512_reduce_add_ps(sum_b_y_f32x16);
1968
- nk_f32_t sum_b_z = _mm512_reduce_add_ps(sum_b_z_f32x16);
1969
- nk_f32_t covariance_x_x = _mm512_reduce_add_ps(cov_xx_f32x16);
1970
- nk_f32_t covariance_x_y = _mm512_reduce_add_ps(cov_xy_f32x16);
1971
- nk_f32_t covariance_x_z = _mm512_reduce_add_ps(cov_xz_f32x16);
1972
- nk_f32_t covariance_y_x = _mm512_reduce_add_ps(cov_yx_f32x16);
1973
- nk_f32_t covariance_y_y = _mm512_reduce_add_ps(cov_yy_f32x16);
1974
- nk_f32_t covariance_y_z = _mm512_reduce_add_ps(cov_yz_f32x16);
1975
- nk_f32_t covariance_z_x = _mm512_reduce_add_ps(cov_zx_f32x16);
1976
- nk_f32_t covariance_z_y = _mm512_reduce_add_ps(cov_zy_f32x16);
1977
- nk_f32_t covariance_z_z = _mm512_reduce_add_ps(cov_zz_f32x16);
1442
+ // Per-channel demux via mask-reduce on the fp32 accumulators (lane i carries channel i%3).
1443
+ __mmask16 const mask_channel_x_f32 = 0x1249; // lanes {0, 3, 6, 9, 12}
1444
+ __mmask16 const mask_channel_y_f32 = 0x2492; // lanes {1, 4, 7, 10, 13}
1445
+ __mmask16 const mask_channel_z_f32 = 0x4924; // lanes {2, 5, 8, 11, 14}
1446
+
1447
+ nk_f32_t sum_a_x = _mm512_mask_reduce_add_ps(mask_channel_x_f32, sum_a_f32x16);
1448
+ nk_f32_t sum_a_y = _mm512_mask_reduce_add_ps(mask_channel_y_f32, sum_a_f32x16);
1449
+ nk_f32_t sum_a_z = _mm512_mask_reduce_add_ps(mask_channel_z_f32, sum_a_f32x16);
1450
+ nk_f32_t sum_b_x = _mm512_mask_reduce_add_ps(mask_channel_x_f32, sum_b_f32x16);
1451
+ nk_f32_t sum_b_y = _mm512_mask_reduce_add_ps(mask_channel_y_f32, sum_b_f32x16);
1452
+ nk_f32_t sum_b_z = _mm512_mask_reduce_add_ps(mask_channel_z_f32, sum_b_f32x16);
1453
+ nk_f32_t norm_squared_a = _mm512_reduce_add_ps(norm_squared_a_f32x16);
1454
+ nk_f32_t norm_squared_b = _mm512_reduce_add_ps(norm_squared_b_f32x16);
1455
+
1456
+ nk_f32_t covariance_x_x = _mm512_mask_reduce_add_ps(mask_channel_x_f32, product_diagonal_f32x16);
1457
+ nk_f32_t covariance_x_y = _mm512_mask_reduce_add_ps(mask_channel_x_f32, product_rotation_1_f32x16);
1458
+ nk_f32_t covariance_x_z = _mm512_mask_reduce_add_ps(mask_channel_x_f32, product_rotation_2_f32x16);
1459
+ nk_f32_t covariance_y_x = _mm512_mask_reduce_add_ps(mask_channel_y_f32, product_rotation_2_f32x16);
1460
+ nk_f32_t covariance_y_y = _mm512_mask_reduce_add_ps(mask_channel_y_f32, product_diagonal_f32x16);
1461
+ nk_f32_t covariance_y_z = _mm512_mask_reduce_add_ps(mask_channel_y_f32, product_rotation_1_f32x16);
1462
+ nk_f32_t covariance_z_x = _mm512_mask_reduce_add_ps(mask_channel_z_f32, product_rotation_1_f32x16);
1463
+ nk_f32_t covariance_z_y = _mm512_mask_reduce_add_ps(mask_channel_z_f32, product_rotation_2_f32x16);
1464
+ nk_f32_t covariance_z_z = _mm512_mask_reduce_add_ps(mask_channel_z_f32, product_diagonal_f32x16);
1978
1465
 
1979
1466
  nk_f32_t inv_n = 1.0f / (nk_f32_t)n;
1980
1467
  nk_f32_t centroid_a_x = sum_a_x * inv_n, centroid_a_y = sum_a_y * inv_n, centroid_a_z = sum_a_z * inv_n;
1981
1468
  nk_f32_t centroid_b_x = sum_b_x * inv_n, centroid_b_y = sum_b_y * inv_n, centroid_b_z = sum_b_z * inv_n;
1982
-
1983
1469
  if (a_centroid) a_centroid[0] = centroid_a_x, a_centroid[1] = centroid_a_y, a_centroid[2] = centroid_a_z;
1984
1470
  if (b_centroid) b_centroid[0] = centroid_b_x, b_centroid[1] = centroid_b_y, b_centroid[2] = centroid_b_z;
1985
1471
 
1986
- covariance_x_x -= (nk_f32_t)n * centroid_a_x * centroid_b_x;
1987
- covariance_x_y -= (nk_f32_t)n * centroid_a_x * centroid_b_y;
1988
- covariance_x_z -= (nk_f32_t)n * centroid_a_x * centroid_b_z;
1989
- covariance_y_x -= (nk_f32_t)n * centroid_a_y * centroid_b_x;
1990
- covariance_y_y -= (nk_f32_t)n * centroid_a_y * centroid_b_y;
1991
- covariance_y_z -= (nk_f32_t)n * centroid_a_y * centroid_b_z;
1992
- covariance_z_x -= (nk_f32_t)n * centroid_a_z * centroid_b_x;
1993
- covariance_z_y -= (nk_f32_t)n * centroid_a_z * centroid_b_y;
1994
- covariance_z_z -= (nk_f32_t)n * centroid_a_z * centroid_b_z;
1995
-
1996
- nk_f32_t cross_covariance[9] = {covariance_x_x, covariance_x_y, covariance_x_z, covariance_y_x, covariance_y_y,
1997
- covariance_y_z, covariance_z_x, covariance_z_y, covariance_z_z};
1998
- nk_f32_t svd_u[9], svd_s[9], svd_v[9];
1999
- nk_svd3x3_f32_(cross_covariance, svd_u, svd_s, svd_v);
2000
-
2001
- nk_f32_t r[9];
2002
- r[0] = svd_v[0] * svd_u[0] + svd_v[1] * svd_u[1] + svd_v[2] * svd_u[2];
2003
- r[1] = svd_v[0] * svd_u[3] + svd_v[1] * svd_u[4] + svd_v[2] * svd_u[5];
2004
- r[2] = svd_v[0] * svd_u[6] + svd_v[1] * svd_u[7] + svd_v[2] * svd_u[8];
2005
- r[3] = svd_v[3] * svd_u[0] + svd_v[4] * svd_u[1] + svd_v[5] * svd_u[2];
2006
- r[4] = svd_v[3] * svd_u[3] + svd_v[4] * svd_u[4] + svd_v[5] * svd_u[5];
2007
- r[5] = svd_v[3] * svd_u[6] + svd_v[4] * svd_u[7] + svd_v[5] * svd_u[8];
2008
- r[6] = svd_v[6] * svd_u[0] + svd_v[7] * svd_u[1] + svd_v[8] * svd_u[2];
2009
- r[7] = svd_v[6] * svd_u[3] + svd_v[7] * svd_u[4] + svd_v[8] * svd_u[5];
2010
- r[8] = svd_v[6] * svd_u[6] + svd_v[7] * svd_u[7] + svd_v[8] * svd_u[8];
2011
-
2012
- if (nk_det3x3_f32_(r) < 0) {
2013
- svd_v[2] = -svd_v[2], svd_v[5] = -svd_v[5], svd_v[8] = -svd_v[8];
2014
- r[0] = svd_v[0] * svd_u[0] + svd_v[1] * svd_u[1] + svd_v[2] * svd_u[2];
2015
- r[1] = svd_v[0] * svd_u[3] + svd_v[1] * svd_u[4] + svd_v[2] * svd_u[5];
2016
- r[2] = svd_v[0] * svd_u[6] + svd_v[1] * svd_u[7] + svd_v[2] * svd_u[8];
2017
- r[3] = svd_v[3] * svd_u[0] + svd_v[4] * svd_u[1] + svd_v[5] * svd_u[2];
2018
- r[4] = svd_v[3] * svd_u[3] + svd_v[4] * svd_u[4] + svd_v[5] * svd_u[5];
2019
- r[5] = svd_v[3] * svd_u[6] + svd_v[4] * svd_u[7] + svd_v[5] * svd_u[8];
2020
- r[6] = svd_v[6] * svd_u[0] + svd_v[7] * svd_u[1] + svd_v[8] * svd_u[2];
2021
- r[7] = svd_v[6] * svd_u[3] + svd_v[7] * svd_u[4] + svd_v[8] * svd_u[5];
2022
- r[8] = svd_v[6] * svd_u[6] + svd_v[7] * svd_u[7] + svd_v[8] * svd_u[8];
1472
+ // Parallel-axis correction.
1473
+ nk_f32_t cross_covariance[9];
1474
+ cross_covariance[0] = covariance_x_x - (nk_f32_t)n * centroid_a_x * centroid_b_x;
1475
+ cross_covariance[1] = covariance_x_y - (nk_f32_t)n * centroid_a_x * centroid_b_y;
1476
+ cross_covariance[2] = covariance_x_z - (nk_f32_t)n * centroid_a_x * centroid_b_z;
1477
+ cross_covariance[3] = covariance_y_x - (nk_f32_t)n * centroid_a_y * centroid_b_x;
1478
+ cross_covariance[4] = covariance_y_y - (nk_f32_t)n * centroid_a_y * centroid_b_y;
1479
+ cross_covariance[5] = covariance_y_z - (nk_f32_t)n * centroid_a_y * centroid_b_z;
1480
+ cross_covariance[6] = covariance_z_x - (nk_f32_t)n * centroid_a_z * centroid_b_x;
1481
+ cross_covariance[7] = covariance_z_y - (nk_f32_t)n * centroid_a_z * centroid_b_y;
1482
+ cross_covariance[8] = covariance_z_z - (nk_f32_t)n * centroid_a_z * centroid_b_z;
1483
+
1484
+ nk_f32_t svd_left[9], svd_diagonal[9], svd_right[9];
1485
+ nk_svd3x3_f32_(cross_covariance, svd_left, svd_diagonal, svd_right);
1486
+ nk_f32_t optimal_rotation[9];
1487
+ nk_rotation_from_svd_f32_serial_(svd_left, svd_right, optimal_rotation);
1488
+ if (nk_det3x3_f32_(optimal_rotation) < 0) {
1489
+ svd_right[2] = -svd_right[2], svd_right[5] = -svd_right[5], svd_right[8] = -svd_right[8];
1490
+ nk_rotation_from_svd_f32_serial_(svd_left, svd_right, optimal_rotation);
2023
1491
  }
2024
-
2025
1492
  if (rotation)
2026
- for (int j = 0; j < 9; ++j) rotation[j] = r[j];
1493
+ for (int j = 0; j < 9; ++j) rotation[j] = optimal_rotation[j];
2027
1494
  if (scale) *scale = 1.0f;
2028
1495
 
2029
- nk_f32_t sum_squared = nk_transformed_ssd_bf16_skylake_(a, b, n, r, 1.0f, centroid_a_x, centroid_a_y, centroid_a_z,
2030
- centroid_b_x, centroid_b_y, centroid_b_z);
1496
+ // Folded SSD via trace identity:
1497
+ // SSD = ‖a-ā‖² + ‖b-b̄‖² − 2·trace(R · H_centered)
1498
+ // trace(R · H_centered) = Σⱼₖ R[j,k] · H[k,j] (note transpose on H).
1499
+ nk_f32_t centered_norm_squared_a = norm_squared_a -
1500
+ (nk_f32_t)n * (centroid_a_x * centroid_a_x + centroid_a_y * centroid_a_y +
1501
+ centroid_a_z * centroid_a_z);
1502
+ nk_f32_t centered_norm_squared_b = norm_squared_b -
1503
+ (nk_f32_t)n * (centroid_b_x * centroid_b_x + centroid_b_y * centroid_b_y +
1504
+ centroid_b_z * centroid_b_z);
1505
+ if (centered_norm_squared_a < 0.0f) centered_norm_squared_a = 0.0f;
1506
+ if (centered_norm_squared_b < 0.0f) centered_norm_squared_b = 0.0f;
1507
+ nk_f32_t trace_rotation_covariance =
1508
+ optimal_rotation[0] * cross_covariance[0] + optimal_rotation[1] * cross_covariance[3] +
1509
+ optimal_rotation[2] * cross_covariance[6] + optimal_rotation[3] * cross_covariance[1] +
1510
+ optimal_rotation[4] * cross_covariance[4] + optimal_rotation[5] * cross_covariance[7] +
1511
+ optimal_rotation[6] * cross_covariance[2] + optimal_rotation[7] * cross_covariance[5] +
1512
+ optimal_rotation[8] * cross_covariance[8];
1513
+ nk_f32_t sum_squared = centered_norm_squared_a + centered_norm_squared_b - 2.0f * trace_rotation_covariance;
1514
+ if (sum_squared < 0.0f) sum_squared = 0.0f;
2031
1515
  *result = nk_f32_sqrt_haswell(sum_squared * inv_n);
2032
1516
  }
2033
1517
 
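Both kabsch kernels above apply the same parallel-axis correction before the SVD; stated once in full (a restatement of the standard identity, not new material):

```latex
% With \bar a = \tfrac{1}{n}\sum_i a_i and \bar b defined likewise:
\begin{aligned}
\sum_i \lVert a_i - \bar a \rVert^2 &= \sum_i \lVert a_i \rVert^2 - n\,\lVert \bar a \rVert^2 \\
H_{jk} = \sum_i (a_i - \bar a)_j\,(b_i - \bar b)_k &= \sum_i a_{i,j}\,b_{i,k} - n\,\bar a_j\,\bar b_k
\end{aligned}
% In floating point the subtractions can leave the left-hand sides slightly
% negative, hence the clamps to zero before they feed the scale and the sqrt.
```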
2034
1518
  NK_PUBLIC void nk_umeyama_f16_skylake(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n, nk_f32_t *a_centroid,
2035
1519
  nk_f32_t *b_centroid, nk_f32_t *rotation, nk_f32_t *scale, nk_f32_t *result) {
2036
- __m512 const zeros_f32x16 = _mm512_setzero_ps();
1520
+ // Same 15-lane streaming-stats pattern as kabsch_f16_skylake; adds the Umeyama scale.
1521
+ __m512i const idx_rotation_1_i32x16 = _mm512_setr_epi32(1, 2, 0, 4, 5, 3, 7, 8, 6, 10, 11, 9, 13, 14, 12, 15);
1522
+ __m512i const idx_rotation_2_i32x16 = _mm512_setr_epi32(2, 0, 1, 5, 3, 4, 8, 6, 7, 11, 9, 10, 14, 12, 13, 15);

- __m512 sum_a_x_f32x16 = zeros_f32x16, sum_a_y_f32x16 = zeros_f32x16, sum_a_z_f32x16 = zeros_f32x16;
- __m512 sum_b_x_f32x16 = zeros_f32x16, sum_b_y_f32x16 = zeros_f32x16, sum_b_z_f32x16 = zeros_f32x16;
- __m512 cov_xx_f32x16 = zeros_f32x16, cov_xy_f32x16 = zeros_f32x16, cov_xz_f32x16 = zeros_f32x16;
- __m512 cov_yx_f32x16 = zeros_f32x16, cov_yy_f32x16 = zeros_f32x16, cov_yz_f32x16 = zeros_f32x16;
- __m512 cov_zx_f32x16 = zeros_f32x16, cov_zy_f32x16 = zeros_f32x16, cov_zz_f32x16 = zeros_f32x16;
- __m512 variance_a_f32x16 = zeros_f32x16;
-
- nk_size_t i = 0;
- __m512 a_x_f32x16, a_y_f32x16, a_z_f32x16, b_x_f32x16, b_y_f32x16, b_z_f32x16;
+ __m512 const zeros_f32x16 = _mm512_setzero_ps();
+ __m512 sum_a_f32x16 = zeros_f32x16, sum_b_f32x16 = zeros_f32x16;
+ __m512 norm_squared_a_f32x16 = zeros_f32x16, norm_squared_b_f32x16 = zeros_f32x16;
+ __m512 product_diagonal_f32x16 = zeros_f32x16;
+ __m512 product_rotation_1_f32x16 = zeros_f32x16;
+ __m512 product_rotation_2_f32x16 = zeros_f32x16;

- for (; i + 16 <= n; i += 16) {
- nk_deinterleave_f16x16_to_f32x16_skylake_(a + i * 3, &a_x_f32x16, &a_y_f32x16, &a_z_f32x16);
- nk_deinterleave_f16x16_to_f32x16_skylake_(b + i * 3, &b_x_f32x16, &b_y_f32x16, &b_z_f32x16);
-
- sum_a_x_f32x16 = _mm512_add_ps(sum_a_x_f32x16, a_x_f32x16);
- sum_a_y_f32x16 = _mm512_add_ps(sum_a_y_f32x16, a_y_f32x16);
- sum_a_z_f32x16 = _mm512_add_ps(sum_a_z_f32x16, a_z_f32x16);
- sum_b_x_f32x16 = _mm512_add_ps(sum_b_x_f32x16, b_x_f32x16);
- sum_b_y_f32x16 = _mm512_add_ps(sum_b_y_f32x16, b_y_f32x16);
- sum_b_z_f32x16 = _mm512_add_ps(sum_b_z_f32x16, b_z_f32x16);
-
- cov_xx_f32x16 = _mm512_fmadd_ps(a_x_f32x16, b_x_f32x16, cov_xx_f32x16);
- cov_xy_f32x16 = _mm512_fmadd_ps(a_x_f32x16, b_y_f32x16, cov_xy_f32x16);
- cov_xz_f32x16 = _mm512_fmadd_ps(a_x_f32x16, b_z_f32x16, cov_xz_f32x16);
- cov_yx_f32x16 = _mm512_fmadd_ps(a_y_f32x16, b_x_f32x16, cov_yx_f32x16);
- cov_yy_f32x16 = _mm512_fmadd_ps(a_y_f32x16, b_y_f32x16, cov_yy_f32x16);
- cov_yz_f32x16 = _mm512_fmadd_ps(a_y_f32x16, b_z_f32x16, cov_yz_f32x16);
- cov_zx_f32x16 = _mm512_fmadd_ps(a_z_f32x16, b_x_f32x16, cov_zx_f32x16);
- cov_zy_f32x16 = _mm512_fmadd_ps(a_z_f32x16, b_y_f32x16, cov_zy_f32x16);
- cov_zz_f32x16 = _mm512_fmadd_ps(a_z_f32x16, b_z_f32x16, cov_zz_f32x16);
-
- variance_a_f32x16 = _mm512_fmadd_ps(a_x_f32x16, a_x_f32x16, variance_a_f32x16);
- variance_a_f32x16 = _mm512_fmadd_ps(a_y_f32x16, a_y_f32x16, variance_a_f32x16);
- variance_a_f32x16 = _mm512_fmadd_ps(a_z_f32x16, a_z_f32x16, variance_a_f32x16);
+ nk_size_t index = 0;
+ for (; index + 5 <= n; index += 5) {
+ __m256i a_f16x16 = _mm256_maskz_loadu_epi16(0x7FFF, (__m256i const *)(a + index * 3));
+ __m256i b_f16x16 = _mm256_maskz_loadu_epi16(0x7FFF, (__m256i const *)(b + index * 3));
+ __m512 a_f32x16 = _mm512_cvtph_ps(a_f16x16);
+ __m512 b_f32x16 = _mm512_cvtph_ps(b_f16x16);
+ __m512 b_rotation_1_f32x16 = _mm512_permutexvar_ps(idx_rotation_1_i32x16, b_f32x16);
+ __m512 b_rotation_2_f32x16 = _mm512_permutexvar_ps(idx_rotation_2_i32x16, b_f32x16);
+ sum_a_f32x16 = _mm512_add_ps(sum_a_f32x16, a_f32x16);
+ sum_b_f32x16 = _mm512_add_ps(sum_b_f32x16, b_f32x16);
+ norm_squared_a_f32x16 = _mm512_fmadd_ps(a_f32x16, a_f32x16, norm_squared_a_f32x16);
+ norm_squared_b_f32x16 = _mm512_fmadd_ps(b_f32x16, b_f32x16, norm_squared_b_f32x16);
+ product_diagonal_f32x16 = _mm512_fmadd_ps(a_f32x16, b_f32x16, product_diagonal_f32x16);
+ product_rotation_1_f32x16 = _mm512_fmadd_ps(a_f32x16, b_rotation_1_f32x16, product_rotation_1_f32x16);
+ product_rotation_2_f32x16 = _mm512_fmadd_ps(a_f32x16, b_rotation_2_f32x16, product_rotation_2_f32x16);
  }

- // Tail: deinterleave remaining points into zero-initialized vectors
- if (i < n) {
- nk_size_t tail = n - i;
- nk_deinterleave_f16_tail_to_f32x16_skylake_(a + i * 3, tail, &a_x_f32x16, &a_y_f32x16, &a_z_f32x16);
- nk_deinterleave_f16_tail_to_f32x16_skylake_(b + i * 3, tail, &b_x_f32x16, &b_y_f32x16, &b_z_f32x16);
-
- sum_a_x_f32x16 = _mm512_add_ps(sum_a_x_f32x16, a_x_f32x16);
- sum_a_y_f32x16 = _mm512_add_ps(sum_a_y_f32x16, a_y_f32x16);
- sum_a_z_f32x16 = _mm512_add_ps(sum_a_z_f32x16, a_z_f32x16);
- sum_b_x_f32x16 = _mm512_add_ps(sum_b_x_f32x16, b_x_f32x16);
- sum_b_y_f32x16 = _mm512_add_ps(sum_b_y_f32x16, b_y_f32x16);
- sum_b_z_f32x16 = _mm512_add_ps(sum_b_z_f32x16, b_z_f32x16);
-
- cov_xx_f32x16 = _mm512_fmadd_ps(a_x_f32x16, b_x_f32x16, cov_xx_f32x16);
- cov_xy_f32x16 = _mm512_fmadd_ps(a_x_f32x16, b_y_f32x16, cov_xy_f32x16);
- cov_xz_f32x16 = _mm512_fmadd_ps(a_x_f32x16, b_z_f32x16, cov_xz_f32x16);
- cov_yx_f32x16 = _mm512_fmadd_ps(a_y_f32x16, b_x_f32x16, cov_yx_f32x16);
- cov_yy_f32x16 = _mm512_fmadd_ps(a_y_f32x16, b_y_f32x16, cov_yy_f32x16);
- cov_yz_f32x16 = _mm512_fmadd_ps(a_y_f32x16, b_z_f32x16, cov_yz_f32x16);
- cov_zx_f32x16 = _mm512_fmadd_ps(a_z_f32x16, b_x_f32x16, cov_zx_f32x16);
- cov_zy_f32x16 = _mm512_fmadd_ps(a_z_f32x16, b_y_f32x16, cov_zy_f32x16);
- cov_zz_f32x16 = _mm512_fmadd_ps(a_z_f32x16, b_z_f32x16, cov_zz_f32x16);
-
- variance_a_f32x16 = _mm512_fmadd_ps(a_x_f32x16, a_x_f32x16, variance_a_f32x16);
- variance_a_f32x16 = _mm512_fmadd_ps(a_y_f32x16, a_y_f32x16, variance_a_f32x16);
- variance_a_f32x16 = _mm512_fmadd_ps(a_z_f32x16, a_z_f32x16, variance_a_f32x16);
+ if (index < n) {
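+ // _bzhi_u32 clears all bits at or above position 3·(n−index), so the mask
+ // keeps exactly the remaining 1–4 triplets (at most 12 of the 15 data lanes);
+ // the masked load zero-fills the rest, and zero lanes are inert in every
+ // sum and FMA below.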
+ __mmask16 tail_mask = (__mmask16)_bzhi_u32(0x7FFF, (nk_u32_t)((n - index) * 3));
+ __m256i a_f16x16 = _mm256_maskz_loadu_epi16(tail_mask, (__m256i const *)(a + index * 3));
+ __m256i b_f16x16 = _mm256_maskz_loadu_epi16(tail_mask, (__m256i const *)(b + index * 3));
+ __m512 a_f32x16 = _mm512_cvtph_ps(a_f16x16);
+ __m512 b_f32x16 = _mm512_cvtph_ps(b_f16x16);
+ __m512 b_rotation_1_f32x16 = _mm512_permutexvar_ps(idx_rotation_1_i32x16, b_f32x16);
+ __m512 b_rotation_2_f32x16 = _mm512_permutexvar_ps(idx_rotation_2_i32x16, b_f32x16);
+ sum_a_f32x16 = _mm512_add_ps(sum_a_f32x16, a_f32x16);
+ sum_b_f32x16 = _mm512_add_ps(sum_b_f32x16, b_f32x16);
+ norm_squared_a_f32x16 = _mm512_fmadd_ps(a_f32x16, a_f32x16, norm_squared_a_f32x16);
+ norm_squared_b_f32x16 = _mm512_fmadd_ps(b_f32x16, b_f32x16, norm_squared_b_f32x16);
+ product_diagonal_f32x16 = _mm512_fmadd_ps(a_f32x16, b_f32x16, product_diagonal_f32x16);
+ product_rotation_1_f32x16 = _mm512_fmadd_ps(a_f32x16, b_rotation_1_f32x16, product_rotation_1_f32x16);
+ product_rotation_2_f32x16 = _mm512_fmadd_ps(a_f32x16, b_rotation_2_f32x16, product_rotation_2_f32x16);
  }

- nk_f32_t sum_a_x = _mm512_reduce_add_ps(sum_a_x_f32x16);
- nk_f32_t sum_a_y = _mm512_reduce_add_ps(sum_a_y_f32x16);
- nk_f32_t sum_a_z = _mm512_reduce_add_ps(sum_a_z_f32x16);
- nk_f32_t sum_b_x = _mm512_reduce_add_ps(sum_b_x_f32x16);
- nk_f32_t sum_b_y = _mm512_reduce_add_ps(sum_b_y_f32x16);
- nk_f32_t sum_b_z = _mm512_reduce_add_ps(sum_b_z_f32x16);
- nk_f32_t covariance_x_x = _mm512_reduce_add_ps(cov_xx_f32x16);
- nk_f32_t covariance_x_y = _mm512_reduce_add_ps(cov_xy_f32x16);
- nk_f32_t covariance_x_z = _mm512_reduce_add_ps(cov_xz_f32x16);
- nk_f32_t covariance_y_x = _mm512_reduce_add_ps(cov_yx_f32x16);
- nk_f32_t covariance_y_y = _mm512_reduce_add_ps(cov_yy_f32x16);
- nk_f32_t covariance_y_z = _mm512_reduce_add_ps(cov_yz_f32x16);
- nk_f32_t covariance_z_x = _mm512_reduce_add_ps(cov_zx_f32x16);
- nk_f32_t covariance_z_y = _mm512_reduce_add_ps(cov_zy_f32x16);
- nk_f32_t covariance_z_z = _mm512_reduce_add_ps(cov_zz_f32x16);
- nk_f32_t variance_a_sum = _mm512_reduce_add_ps(variance_a_f32x16);
+ __mmask16 const mask_channel_x_f32 = 0x1249;
+ __mmask16 const mask_channel_y_f32 = 0x2492;
+ __mmask16 const mask_channel_z_f32 = 0x4924;
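+ // Every third bit set: 0x1249 = 0b001…001001 selects lanes {0,3,6,9,12} (x),
+ // 0x2492 lanes {1,4,7,10,13} (y), and 0x4924 lanes {2,5,8,11,14} (z), so a
+ // masked reduce folds one coordinate channel out of the interleaved accumulators.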
+
+ nk_f32_t sum_a_x = _mm512_mask_reduce_add_ps(mask_channel_x_f32, sum_a_f32x16);
+ nk_f32_t sum_a_y = _mm512_mask_reduce_add_ps(mask_channel_y_f32, sum_a_f32x16);
+ nk_f32_t sum_a_z = _mm512_mask_reduce_add_ps(mask_channel_z_f32, sum_a_f32x16);
+ nk_f32_t sum_b_x = _mm512_mask_reduce_add_ps(mask_channel_x_f32, sum_b_f32x16);
+ nk_f32_t sum_b_y = _mm512_mask_reduce_add_ps(mask_channel_y_f32, sum_b_f32x16);
+ nk_f32_t sum_b_z = _mm512_mask_reduce_add_ps(mask_channel_z_f32, sum_b_f32x16);
+ nk_f32_t norm_squared_a = _mm512_reduce_add_ps(norm_squared_a_f32x16);
+ nk_f32_t norm_squared_b = _mm512_reduce_add_ps(norm_squared_b_f32x16);
+
+ nk_f32_t covariance_x_x = _mm512_mask_reduce_add_ps(mask_channel_x_f32, product_diagonal_f32x16);
+ nk_f32_t covariance_x_y = _mm512_mask_reduce_add_ps(mask_channel_x_f32, product_rotation_1_f32x16);
+ nk_f32_t covariance_x_z = _mm512_mask_reduce_add_ps(mask_channel_x_f32, product_rotation_2_f32x16);
+ nk_f32_t covariance_y_x = _mm512_mask_reduce_add_ps(mask_channel_y_f32, product_rotation_2_f32x16);
+ nk_f32_t covariance_y_y = _mm512_mask_reduce_add_ps(mask_channel_y_f32, product_diagonal_f32x16);
+ nk_f32_t covariance_y_z = _mm512_mask_reduce_add_ps(mask_channel_y_f32, product_rotation_1_f32x16);
+ nk_f32_t covariance_z_x = _mm512_mask_reduce_add_ps(mask_channel_z_f32, product_rotation_1_f32x16);
+ nk_f32_t covariance_z_y = _mm512_mask_reduce_add_ps(mask_channel_z_f32, product_rotation_2_f32x16);
+ nk_f32_t covariance_z_z = _mm512_mask_reduce_add_ps(mask_channel_z_f32, product_diagonal_f32x16);

  nk_f32_t inv_n = 1.0f / (nk_f32_t)n;
  nk_f32_t centroid_a_x = sum_a_x * inv_n, centroid_a_y = sum_a_y * inv_n, centroid_a_z = sum_a_z * inv_n;
  nk_f32_t centroid_b_x = sum_b_x * inv_n, centroid_b_y = sum_b_y * inv_n, centroid_b_z = sum_b_z * inv_n;
-
  if (a_centroid) a_centroid[0] = centroid_a_x, a_centroid[1] = centroid_a_y, a_centroid[2] = centroid_a_z;
  if (b_centroid) b_centroid[0] = centroid_b_x, b_centroid[1] = centroid_b_y, b_centroid[2] = centroid_b_z;

- nk_f32_t variance_a = variance_a_sum * inv_n -
- (centroid_a_x * centroid_a_x + centroid_a_y * centroid_a_y + centroid_a_z * centroid_a_z);
-
- covariance_x_x -= (nk_f32_t)n * centroid_a_x * centroid_b_x;
- covariance_x_y -= (nk_f32_t)n * centroid_a_x * centroid_b_y;
- covariance_x_z -= (nk_f32_t)n * centroid_a_x * centroid_b_z;
- covariance_y_x -= (nk_f32_t)n * centroid_a_y * centroid_b_x;
- covariance_y_y -= (nk_f32_t)n * centroid_a_y * centroid_b_y;
- covariance_y_z -= (nk_f32_t)n * centroid_a_y * centroid_b_z;
- covariance_z_x -= (nk_f32_t)n * centroid_a_z * centroid_b_x;
- covariance_z_y -= (nk_f32_t)n * centroid_a_z * centroid_b_y;
- covariance_z_z -= (nk_f32_t)n * centroid_a_z * centroid_b_z;
-
- nk_f32_t cross_covariance[9] = {covariance_x_x, covariance_x_y, covariance_x_z, covariance_y_x, covariance_y_y,
- covariance_y_z, covariance_z_x, covariance_z_y, covariance_z_z};
-
- nk_f32_t svd_u[9], svd_s[9], svd_v[9];
- nk_svd3x3_f32_(cross_covariance, svd_u, svd_s, svd_v);
-
- nk_f32_t r[9];
- r[0] = svd_v[0] * svd_u[0] + svd_v[1] * svd_u[1] + svd_v[2] * svd_u[2];
- r[1] = svd_v[0] * svd_u[3] + svd_v[1] * svd_u[4] + svd_v[2] * svd_u[5];
- r[2] = svd_v[0] * svd_u[6] + svd_v[1] * svd_u[7] + svd_v[2] * svd_u[8];
- r[3] = svd_v[3] * svd_u[0] + svd_v[4] * svd_u[1] + svd_v[5] * svd_u[2];
- r[4] = svd_v[3] * svd_u[3] + svd_v[4] * svd_u[4] + svd_v[5] * svd_u[5];
- r[5] = svd_v[3] * svd_u[6] + svd_v[4] * svd_u[7] + svd_v[5] * svd_u[8];
- r[6] = svd_v[6] * svd_u[0] + svd_v[7] * svd_u[1] + svd_v[8] * svd_u[2];
- r[7] = svd_v[6] * svd_u[3] + svd_v[7] * svd_u[4] + svd_v[8] * svd_u[5];
- r[8] = svd_v[6] * svd_u[6] + svd_v[7] * svd_u[7] + svd_v[8] * svd_u[8];
-
- nk_f32_t det = nk_det3x3_f32_(r);
- nk_f32_t d3 = det < 0 ? -1.0f : 1.0f;
- nk_f32_t trace_ds = svd_s[0] + svd_s[4] + d3 * svd_s[8];
- nk_f32_t c = trace_ds / ((nk_f32_t)n * variance_a);
+ nk_f32_t centered_norm_squared_a = norm_squared_a -
+ (nk_f32_t)n * (centroid_a_x * centroid_a_x + centroid_a_y * centroid_a_y +
+ centroid_a_z * centroid_a_z);
+ nk_f32_t centered_norm_squared_b = norm_squared_b -
+ (nk_f32_t)n * (centroid_b_x * centroid_b_x + centroid_b_y * centroid_b_y +
+ centroid_b_z * centroid_b_z);
+ if (centered_norm_squared_a < 0.0f) centered_norm_squared_a = 0.0f;
+ if (centered_norm_squared_b < 0.0f) centered_norm_squared_b = 0.0f;
+
+ nk_f32_t cross_covariance[9];
+ cross_covariance[0] = covariance_x_x - (nk_f32_t)n * centroid_a_x * centroid_b_x;
+ cross_covariance[1] = covariance_x_y - (nk_f32_t)n * centroid_a_x * centroid_b_y;
+ cross_covariance[2] = covariance_x_z - (nk_f32_t)n * centroid_a_x * centroid_b_z;
+ cross_covariance[3] = covariance_y_x - (nk_f32_t)n * centroid_a_y * centroid_b_x;
+ cross_covariance[4] = covariance_y_y - (nk_f32_t)n * centroid_a_y * centroid_b_y;
+ cross_covariance[5] = covariance_y_z - (nk_f32_t)n * centroid_a_y * centroid_b_z;
+ cross_covariance[6] = covariance_z_x - (nk_f32_t)n * centroid_a_z * centroid_b_x;
+ cross_covariance[7] = covariance_z_y - (nk_f32_t)n * centroid_a_z * centroid_b_y;
+ cross_covariance[8] = covariance_z_z - (nk_f32_t)n * centroid_a_z * centroid_b_z;
+
+ nk_f32_t svd_left[9], svd_diagonal[9], svd_right[9];
+ nk_svd3x3_f32_(cross_covariance, svd_left, svd_diagonal, svd_right);
+ nk_f32_t optimal_rotation[9];
+ nk_rotation_from_svd_f32_serial_(svd_left, svd_right, optimal_rotation);
+
+ // Scale factor: c = trace(D · S) / ‖a-ā‖², with reflection sign via d3.
+ nk_f32_t det = nk_det3x3_f32_(optimal_rotation);
+ nk_f32_t d3 = det < 0.0f ? -1.0f : 1.0f;
+ nk_f32_t trace_ds = nk_sum_three_products_f32_(svd_diagonal[0], 1.0f, svd_diagonal[4], 1.0f, svd_diagonal[8], d3);
+ nk_f32_t c = centered_norm_squared_a > 0.0f ? trace_ds / centered_norm_squared_a : 0.0f;
  if (scale) *scale = c;

- if (det < 0) {
- svd_v[2] = -svd_v[2], svd_v[5] = -svd_v[5], svd_v[8] = -svd_v[8];
- r[0] = svd_v[0] * svd_u[0] + svd_v[1] * svd_u[1] + svd_v[2] * svd_u[2];
- r[1] = svd_v[0] * svd_u[3] + svd_v[1] * svd_u[4] + svd_v[2] * svd_u[5];
- r[2] = svd_v[0] * svd_u[6] + svd_v[1] * svd_u[7] + svd_v[2] * svd_u[8];
- r[3] = svd_v[3] * svd_u[0] + svd_v[4] * svd_u[1] + svd_v[5] * svd_u[2];
- r[4] = svd_v[3] * svd_u[3] + svd_v[4] * svd_u[4] + svd_v[5] * svd_u[5];
- r[5] = svd_v[3] * svd_u[6] + svd_v[4] * svd_u[7] + svd_v[5] * svd_u[8];
- r[6] = svd_v[6] * svd_u[0] + svd_v[7] * svd_u[1] + svd_v[8] * svd_u[2];
- r[7] = svd_v[6] * svd_u[3] + svd_v[7] * svd_u[4] + svd_v[8] * svd_u[5];
- r[8] = svd_v[6] * svd_u[6] + svd_v[7] * svd_u[7] + svd_v[8] * svd_u[8];
+ if (det < 0.0f) {
+ svd_right[2] = -svd_right[2], svd_right[5] = -svd_right[5], svd_right[8] = -svd_right[8];
+ nk_rotation_from_svd_f32_serial_(svd_left, svd_right, optimal_rotation);
  }
-
  if (rotation)
- for (int j = 0; j < 9; ++j) rotation[j] = r[j];
-
- nk_f32_t sum_squared = nk_transformed_ssd_f16_skylake_(a, b, n, r, c, centroid_a_x, centroid_a_y, centroid_a_z,
- centroid_b_x, centroid_b_y, centroid_b_z);
+ for (int j = 0; j < 9; ++j) rotation[j] = optimal_rotation[j];
+
+ // Folded SSD with scale:
+ // SSD = c²·‖a-ā‖² + ‖b-b̄‖² − 2c·trace(R · H_centered).
+ nk_f32_t trace_rotation_covariance =
+ optimal_rotation[0] * cross_covariance[0] + optimal_rotation[1] * cross_covariance[3] +
+ optimal_rotation[2] * cross_covariance[6] + optimal_rotation[3] * cross_covariance[1] +
+ optimal_rotation[4] * cross_covariance[4] + optimal_rotation[5] * cross_covariance[7] +
+ optimal_rotation[6] * cross_covariance[2] + optimal_rotation[7] * cross_covariance[5] +
+ optimal_rotation[8] * cross_covariance[8];
+ nk_f32_t sum_squared = c * c * centered_norm_squared_a + centered_norm_squared_b -
+ 2.0f * c * trace_rotation_covariance;
+ if (sum_squared < 0.0f) sum_squared = 0.0f;
  *result = nk_f32_sqrt_haswell(sum_squared * inv_n);
  }
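Call-side view for the function above: a minimal usage sketch, assuming the declaration is reachable through the package's `include/numkong/mesh.h` umbrella header (the `align_points` wrapper below is hypothetical, not part of numkong). All output pointers except `result` may be NULL, as the guards in the kernel show.

    #include <numkong/mesh.h> /* assumed to declare nk_umeyama_f16_skylake */

    /* Hypothetical wrapper: estimate the similarity transform b ≈ c·R·(a - ā) + b̄
     * for two xyz-interleaved f16 clouds and return the per-point RMSD. */
    static nk_f32_t align_points(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n) {
        nk_f32_t rotation[9], scale, rmsd;
        nk_umeyama_f16_skylake(a, b, n, /*a_centroid=*/0, /*b_centroid=*/0, rotation, &scale, &rmsd);
        return rmsd;
    }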

  NK_PUBLIC void nk_umeyama_bf16_skylake(nk_bf16_t const *a, nk_bf16_t const *b, nk_size_t n, nk_f32_t *a_centroid,
  nk_f32_t *b_centroid, nk_f32_t *rotation, nk_f32_t *scale, nk_f32_t *result) {
- __m512 const zeros_f32x16 = _mm512_setzero_ps();
+ // Same 15-lane streaming-stats pattern as kabsch_bf16_skylake; adds the Umeyama scale.
+ __m512i const idx_rotation_1_i32x16 = _mm512_setr_epi32(1, 2, 0, 4, 5, 3, 7, 8, 6, 10, 11, 9, 13, 14, 12, 15);
+ __m512i const idx_rotation_2_i32x16 = _mm512_setr_epi32(2, 0, 1, 5, 3, 4, 8, 6, 7, 11, 9, 10, 14, 12, 13, 15);

- __m512 sum_a_x_f32x16 = zeros_f32x16, sum_a_y_f32x16 = zeros_f32x16, sum_a_z_f32x16 = zeros_f32x16;
- __m512 sum_b_x_f32x16 = zeros_f32x16, sum_b_y_f32x16 = zeros_f32x16, sum_b_z_f32x16 = zeros_f32x16;
- __m512 cov_xx_f32x16 = zeros_f32x16, cov_xy_f32x16 = zeros_f32x16, cov_xz_f32x16 = zeros_f32x16;
- __m512 cov_yx_f32x16 = zeros_f32x16, cov_yy_f32x16 = zeros_f32x16, cov_yz_f32x16 = zeros_f32x16;
- __m512 cov_zx_f32x16 = zeros_f32x16, cov_zy_f32x16 = zeros_f32x16, cov_zz_f32x16 = zeros_f32x16;
- __m512 variance_a_f32x16 = zeros_f32x16;
-
- nk_size_t i = 0;
- __m512 a_x_f32x16, a_y_f32x16, a_z_f32x16, b_x_f32x16, b_y_f32x16, b_z_f32x16;
+ __m512 const zeros_f32x16 = _mm512_setzero_ps();
+ __m512 sum_a_f32x16 = zeros_f32x16, sum_b_f32x16 = zeros_f32x16;
+ __m512 norm_squared_a_f32x16 = zeros_f32x16, norm_squared_b_f32x16 = zeros_f32x16;
+ __m512 product_diagonal_f32x16 = zeros_f32x16;
+ __m512 product_rotation_1_f32x16 = zeros_f32x16;
+ __m512 product_rotation_2_f32x16 = zeros_f32x16;

- for (; i + 16 <= n; i += 16) {
- nk_deinterleave_bf16x16_to_f32x16_skylake_(a + i * 3, &a_x_f32x16, &a_y_f32x16, &a_z_f32x16);
- nk_deinterleave_bf16x16_to_f32x16_skylake_(b + i * 3, &b_x_f32x16, &b_y_f32x16, &b_z_f32x16);
-
- sum_a_x_f32x16 = _mm512_add_ps(sum_a_x_f32x16, a_x_f32x16);
- sum_a_y_f32x16 = _mm512_add_ps(sum_a_y_f32x16, a_y_f32x16);
- sum_a_z_f32x16 = _mm512_add_ps(sum_a_z_f32x16, a_z_f32x16);
- sum_b_x_f32x16 = _mm512_add_ps(sum_b_x_f32x16, b_x_f32x16);
- sum_b_y_f32x16 = _mm512_add_ps(sum_b_y_f32x16, b_y_f32x16);
- sum_b_z_f32x16 = _mm512_add_ps(sum_b_z_f32x16, b_z_f32x16);
-
- cov_xx_f32x16 = _mm512_fmadd_ps(a_x_f32x16, b_x_f32x16, cov_xx_f32x16);
- cov_xy_f32x16 = _mm512_fmadd_ps(a_x_f32x16, b_y_f32x16, cov_xy_f32x16);
- cov_xz_f32x16 = _mm512_fmadd_ps(a_x_f32x16, b_z_f32x16, cov_xz_f32x16);
- cov_yx_f32x16 = _mm512_fmadd_ps(a_y_f32x16, b_x_f32x16, cov_yx_f32x16);
- cov_yy_f32x16 = _mm512_fmadd_ps(a_y_f32x16, b_y_f32x16, cov_yy_f32x16);
- cov_yz_f32x16 = _mm512_fmadd_ps(a_y_f32x16, b_z_f32x16, cov_yz_f32x16);
- cov_zx_f32x16 = _mm512_fmadd_ps(a_z_f32x16, b_x_f32x16, cov_zx_f32x16);
- cov_zy_f32x16 = _mm512_fmadd_ps(a_z_f32x16, b_y_f32x16, cov_zy_f32x16);
- cov_zz_f32x16 = _mm512_fmadd_ps(a_z_f32x16, b_z_f32x16, cov_zz_f32x16);
-
- variance_a_f32x16 = _mm512_fmadd_ps(a_x_f32x16, a_x_f32x16, variance_a_f32x16);
- variance_a_f32x16 = _mm512_fmadd_ps(a_y_f32x16, a_y_f32x16, variance_a_f32x16);
- variance_a_f32x16 = _mm512_fmadd_ps(a_z_f32x16, a_z_f32x16, variance_a_f32x16);
+ nk_size_t index = 0;
+ for (; index + 5 <= n; index += 5) {
+ __m256i a_bf16x16 = _mm256_maskz_loadu_epi16(0x7FFF, (__m256i const *)(a + index * 3));
+ __m256i b_bf16x16 = _mm256_maskz_loadu_epi16(0x7FFF, (__m256i const *)(b + index * 3));
+ __m512 a_f32x16 = nk_bf16x16_to_f32x16_skylake_(a_bf16x16);
+ __m512 b_f32x16 = nk_bf16x16_to_f32x16_skylake_(b_bf16x16);
+ __m512 b_rotation_1_f32x16 = _mm512_permutexvar_ps(idx_rotation_1_i32x16, b_f32x16);
+ __m512 b_rotation_2_f32x16 = _mm512_permutexvar_ps(idx_rotation_2_i32x16, b_f32x16);
+ sum_a_f32x16 = _mm512_add_ps(sum_a_f32x16, a_f32x16);
+ sum_b_f32x16 = _mm512_add_ps(sum_b_f32x16, b_f32x16);
+ norm_squared_a_f32x16 = _mm512_fmadd_ps(a_f32x16, a_f32x16, norm_squared_a_f32x16);
+ norm_squared_b_f32x16 = _mm512_fmadd_ps(b_f32x16, b_f32x16, norm_squared_b_f32x16);
+ product_diagonal_f32x16 = _mm512_fmadd_ps(a_f32x16, b_f32x16, product_diagonal_f32x16);
+ product_rotation_1_f32x16 = _mm512_fmadd_ps(a_f32x16, b_rotation_1_f32x16, product_rotation_1_f32x16);
+ product_rotation_2_f32x16 = _mm512_fmadd_ps(a_f32x16, b_rotation_2_f32x16, product_rotation_2_f32x16);
  }

- // Tail: deinterleave remaining points into zero-initialized vectors
- if (i < n) {
- nk_size_t tail = n - i;
- nk_deinterleave_bf16_tail_to_f32x16_skylake_(a + i * 3, tail, &a_x_f32x16, &a_y_f32x16, &a_z_f32x16);
- nk_deinterleave_bf16_tail_to_f32x16_skylake_(b + i * 3, tail, &b_x_f32x16, &b_y_f32x16, &b_z_f32x16);
-
- sum_a_x_f32x16 = _mm512_add_ps(sum_a_x_f32x16, a_x_f32x16);
- sum_a_y_f32x16 = _mm512_add_ps(sum_a_y_f32x16, a_y_f32x16);
- sum_a_z_f32x16 = _mm512_add_ps(sum_a_z_f32x16, a_z_f32x16);
- sum_b_x_f32x16 = _mm512_add_ps(sum_b_x_f32x16, b_x_f32x16);
- sum_b_y_f32x16 = _mm512_add_ps(sum_b_y_f32x16, b_y_f32x16);
- sum_b_z_f32x16 = _mm512_add_ps(sum_b_z_f32x16, b_z_f32x16);
-
- cov_xx_f32x16 = _mm512_fmadd_ps(a_x_f32x16, b_x_f32x16, cov_xx_f32x16);
- cov_xy_f32x16 = _mm512_fmadd_ps(a_x_f32x16, b_y_f32x16, cov_xy_f32x16);
- cov_xz_f32x16 = _mm512_fmadd_ps(a_x_f32x16, b_z_f32x16, cov_xz_f32x16);
- cov_yx_f32x16 = _mm512_fmadd_ps(a_y_f32x16, b_x_f32x16, cov_yx_f32x16);
- cov_yy_f32x16 = _mm512_fmadd_ps(a_y_f32x16, b_y_f32x16, cov_yy_f32x16);
- cov_yz_f32x16 = _mm512_fmadd_ps(a_y_f32x16, b_z_f32x16, cov_yz_f32x16);
- cov_zx_f32x16 = _mm512_fmadd_ps(a_z_f32x16, b_x_f32x16, cov_zx_f32x16);
- cov_zy_f32x16 = _mm512_fmadd_ps(a_z_f32x16, b_y_f32x16, cov_zy_f32x16);
- cov_zz_f32x16 = _mm512_fmadd_ps(a_z_f32x16, b_z_f32x16, cov_zz_f32x16);
-
- variance_a_f32x16 = _mm512_fmadd_ps(a_x_f32x16, a_x_f32x16, variance_a_f32x16);
- variance_a_f32x16 = _mm512_fmadd_ps(a_y_f32x16, a_y_f32x16, variance_a_f32x16);
- variance_a_f32x16 = _mm512_fmadd_ps(a_z_f32x16, a_z_f32x16, variance_a_f32x16);
+ if (index < n) {
+ __mmask16 tail_mask = (__mmask16)_bzhi_u32(0x7FFF, (nk_u32_t)((n - index) * 3));
+ __m256i a_bf16x16 = _mm256_maskz_loadu_epi16(tail_mask, (__m256i const *)(a + index * 3));
+ __m256i b_bf16x16 = _mm256_maskz_loadu_epi16(tail_mask, (__m256i const *)(b + index * 3));
+ __m512 a_f32x16 = nk_bf16x16_to_f32x16_skylake_(a_bf16x16);
+ __m512 b_f32x16 = nk_bf16x16_to_f32x16_skylake_(b_bf16x16);
+ __m512 b_rotation_1_f32x16 = _mm512_permutexvar_ps(idx_rotation_1_i32x16, b_f32x16);
+ __m512 b_rotation_2_f32x16 = _mm512_permutexvar_ps(idx_rotation_2_i32x16, b_f32x16);
+ sum_a_f32x16 = _mm512_add_ps(sum_a_f32x16, a_f32x16);
+ sum_b_f32x16 = _mm512_add_ps(sum_b_f32x16, b_f32x16);
+ norm_squared_a_f32x16 = _mm512_fmadd_ps(a_f32x16, a_f32x16, norm_squared_a_f32x16);
+ norm_squared_b_f32x16 = _mm512_fmadd_ps(b_f32x16, b_f32x16, norm_squared_b_f32x16);
+ product_diagonal_f32x16 = _mm512_fmadd_ps(a_f32x16, b_f32x16, product_diagonal_f32x16);
+ product_rotation_1_f32x16 = _mm512_fmadd_ps(a_f32x16, b_rotation_1_f32x16, product_rotation_1_f32x16);
+ product_rotation_2_f32x16 = _mm512_fmadd_ps(a_f32x16, b_rotation_2_f32x16, product_rotation_2_f32x16);
  }

- nk_f32_t sum_a_x = _mm512_reduce_add_ps(sum_a_x_f32x16);
- nk_f32_t sum_a_y = _mm512_reduce_add_ps(sum_a_y_f32x16);
- nk_f32_t sum_a_z = _mm512_reduce_add_ps(sum_a_z_f32x16);
- nk_f32_t sum_b_x = _mm512_reduce_add_ps(sum_b_x_f32x16);
- nk_f32_t sum_b_y = _mm512_reduce_add_ps(sum_b_y_f32x16);
- nk_f32_t sum_b_z = _mm512_reduce_add_ps(sum_b_z_f32x16);
- nk_f32_t covariance_x_x = _mm512_reduce_add_ps(cov_xx_f32x16);
- nk_f32_t covariance_x_y = _mm512_reduce_add_ps(cov_xy_f32x16);
- nk_f32_t covariance_x_z = _mm512_reduce_add_ps(cov_xz_f32x16);
- nk_f32_t covariance_y_x = _mm512_reduce_add_ps(cov_yx_f32x16);
- nk_f32_t covariance_y_y = _mm512_reduce_add_ps(cov_yy_f32x16);
- nk_f32_t covariance_y_z = _mm512_reduce_add_ps(cov_yz_f32x16);
- nk_f32_t covariance_z_x = _mm512_reduce_add_ps(cov_zx_f32x16);
- nk_f32_t covariance_z_y = _mm512_reduce_add_ps(cov_zy_f32x16);
- nk_f32_t covariance_z_z = _mm512_reduce_add_ps(cov_zz_f32x16);
- nk_f32_t variance_a_sum = _mm512_reduce_add_ps(variance_a_f32x16);
+ __mmask16 const mask_channel_x_f32 = 0x1249;
+ __mmask16 const mask_channel_y_f32 = 0x2492;
+ __mmask16 const mask_channel_z_f32 = 0x4924;
+
+ nk_f32_t sum_a_x = _mm512_mask_reduce_add_ps(mask_channel_x_f32, sum_a_f32x16);
+ nk_f32_t sum_a_y = _mm512_mask_reduce_add_ps(mask_channel_y_f32, sum_a_f32x16);
+ nk_f32_t sum_a_z = _mm512_mask_reduce_add_ps(mask_channel_z_f32, sum_a_f32x16);
+ nk_f32_t sum_b_x = _mm512_mask_reduce_add_ps(mask_channel_x_f32, sum_b_f32x16);
+ nk_f32_t sum_b_y = _mm512_mask_reduce_add_ps(mask_channel_y_f32, sum_b_f32x16);
+ nk_f32_t sum_b_z = _mm512_mask_reduce_add_ps(mask_channel_z_f32, sum_b_f32x16);
+ nk_f32_t norm_squared_a = _mm512_reduce_add_ps(norm_squared_a_f32x16);
+ nk_f32_t norm_squared_b = _mm512_reduce_add_ps(norm_squared_b_f32x16);
+
+ nk_f32_t covariance_x_x = _mm512_mask_reduce_add_ps(mask_channel_x_f32, product_diagonal_f32x16);
+ nk_f32_t covariance_x_y = _mm512_mask_reduce_add_ps(mask_channel_x_f32, product_rotation_1_f32x16);
+ nk_f32_t covariance_x_z = _mm512_mask_reduce_add_ps(mask_channel_x_f32, product_rotation_2_f32x16);
+ nk_f32_t covariance_y_x = _mm512_mask_reduce_add_ps(mask_channel_y_f32, product_rotation_2_f32x16);
+ nk_f32_t covariance_y_y = _mm512_mask_reduce_add_ps(mask_channel_y_f32, product_diagonal_f32x16);
+ nk_f32_t covariance_y_z = _mm512_mask_reduce_add_ps(mask_channel_y_f32, product_rotation_1_f32x16);
+ nk_f32_t covariance_z_x = _mm512_mask_reduce_add_ps(mask_channel_z_f32, product_rotation_1_f32x16);
+ nk_f32_t covariance_z_y = _mm512_mask_reduce_add_ps(mask_channel_z_f32, product_rotation_2_f32x16);
+ nk_f32_t covariance_z_z = _mm512_mask_reduce_add_ps(mask_channel_z_f32, product_diagonal_f32x16);

  nk_f32_t inv_n = 1.0f / (nk_f32_t)n;
  nk_f32_t centroid_a_x = sum_a_x * inv_n, centroid_a_y = sum_a_y * inv_n, centroid_a_z = sum_a_z * inv_n;
  nk_f32_t centroid_b_x = sum_b_x * inv_n, centroid_b_y = sum_b_y * inv_n, centroid_b_z = sum_b_z * inv_n;
-
  if (a_centroid) a_centroid[0] = centroid_a_x, a_centroid[1] = centroid_a_y, a_centroid[2] = centroid_a_z;
  if (b_centroid) b_centroid[0] = centroid_b_x, b_centroid[1] = centroid_b_y, b_centroid[2] = centroid_b_z;

- nk_f32_t variance_a = variance_a_sum * inv_n -
- (centroid_a_x * centroid_a_x + centroid_a_y * centroid_a_y + centroid_a_z * centroid_a_z);
-
- covariance_x_x -= (nk_f32_t)n * centroid_a_x * centroid_b_x;
- covariance_x_y -= (nk_f32_t)n * centroid_a_x * centroid_b_y;
- covariance_x_z -= (nk_f32_t)n * centroid_a_x * centroid_b_z;
- covariance_y_x -= (nk_f32_t)n * centroid_a_y * centroid_b_x;
- covariance_y_y -= (nk_f32_t)n * centroid_a_y * centroid_b_y;
- covariance_y_z -= (nk_f32_t)n * centroid_a_y * centroid_b_z;
- covariance_z_x -= (nk_f32_t)n * centroid_a_z * centroid_b_x;
- covariance_z_y -= (nk_f32_t)n * centroid_a_z * centroid_b_y;
- covariance_z_z -= (nk_f32_t)n * centroid_a_z * centroid_b_z;
-
- nk_f32_t cross_covariance[9] = {covariance_x_x, covariance_x_y, covariance_x_z, covariance_y_x, covariance_y_y,
- covariance_y_z, covariance_z_x, covariance_z_y, covariance_z_z};
-
- nk_f32_t svd_u[9], svd_s[9], svd_v[9];
- nk_svd3x3_f32_(cross_covariance, svd_u, svd_s, svd_v);
-
- nk_f32_t r[9];
- r[0] = svd_v[0] * svd_u[0] + svd_v[1] * svd_u[1] + svd_v[2] * svd_u[2];
- r[1] = svd_v[0] * svd_u[3] + svd_v[1] * svd_u[4] + svd_v[2] * svd_u[5];
- r[2] = svd_v[0] * svd_u[6] + svd_v[1] * svd_u[7] + svd_v[2] * svd_u[8];
- r[3] = svd_v[3] * svd_u[0] + svd_v[4] * svd_u[1] + svd_v[5] * svd_u[2];
- r[4] = svd_v[3] * svd_u[3] + svd_v[4] * svd_u[4] + svd_v[5] * svd_u[5];
- r[5] = svd_v[3] * svd_u[6] + svd_v[4] * svd_u[7] + svd_v[5] * svd_u[8];
- r[6] = svd_v[6] * svd_u[0] + svd_v[7] * svd_u[1] + svd_v[8] * svd_u[2];
- r[7] = svd_v[6] * svd_u[3] + svd_v[7] * svd_u[4] + svd_v[8] * svd_u[5];
- r[8] = svd_v[6] * svd_u[6] + svd_v[7] * svd_u[7] + svd_v[8] * svd_u[8];
-
- nk_f32_t det = nk_det3x3_f32_(r);
- nk_f32_t d3 = det < 0 ? -1.0f : 1.0f;
- nk_f32_t trace_ds = svd_s[0] + svd_s[4] + d3 * svd_s[8];
- nk_f32_t c = trace_ds / ((nk_f32_t)n * variance_a);
+ nk_f32_t centered_norm_squared_a = norm_squared_a -
+ (nk_f32_t)n * (centroid_a_x * centroid_a_x + centroid_a_y * centroid_a_y +
+ centroid_a_z * centroid_a_z);
+ nk_f32_t centered_norm_squared_b = norm_squared_b -
+ (nk_f32_t)n * (centroid_b_x * centroid_b_x + centroid_b_y * centroid_b_y +
+ centroid_b_z * centroid_b_z);
+ if (centered_norm_squared_a < 0.0f) centered_norm_squared_a = 0.0f;
+ if (centered_norm_squared_b < 0.0f) centered_norm_squared_b = 0.0f;
+
+ nk_f32_t cross_covariance[9];
+ cross_covariance[0] = covariance_x_x - (nk_f32_t)n * centroid_a_x * centroid_b_x;
+ cross_covariance[1] = covariance_x_y - (nk_f32_t)n * centroid_a_x * centroid_b_y;
+ cross_covariance[2] = covariance_x_z - (nk_f32_t)n * centroid_a_x * centroid_b_z;
+ cross_covariance[3] = covariance_y_x - (nk_f32_t)n * centroid_a_y * centroid_b_x;
+ cross_covariance[4] = covariance_y_y - (nk_f32_t)n * centroid_a_y * centroid_b_y;
+ cross_covariance[5] = covariance_y_z - (nk_f32_t)n * centroid_a_y * centroid_b_z;
+ cross_covariance[6] = covariance_z_x - (nk_f32_t)n * centroid_a_z * centroid_b_x;
+ cross_covariance[7] = covariance_z_y - (nk_f32_t)n * centroid_a_z * centroid_b_y;
+ cross_covariance[8] = covariance_z_z - (nk_f32_t)n * centroid_a_z * centroid_b_z;
+
+ nk_f32_t svd_left[9], svd_diagonal[9], svd_right[9];
+ nk_svd3x3_f32_(cross_covariance, svd_left, svd_diagonal, svd_right);
+ nk_f32_t optimal_rotation[9];
+ nk_rotation_from_svd_f32_serial_(svd_left, svd_right, optimal_rotation);
+
+ // Scale factor: c = trace(D · S) / ‖a-ā‖², with reflection sign via d3.
+ nk_f32_t det = nk_det3x3_f32_(optimal_rotation);
+ nk_f32_t d3 = det < 0.0f ? -1.0f : 1.0f;
+ nk_f32_t trace_ds = nk_sum_three_products_f32_(svd_diagonal[0], 1.0f, svd_diagonal[4], 1.0f, svd_diagonal[8], d3);
+ nk_f32_t c = centered_norm_squared_a > 0.0f ? trace_ds / centered_norm_squared_a : 0.0f;
  if (scale) *scale = c;

- if (det < 0) {
- svd_v[2] = -svd_v[2], svd_v[5] = -svd_v[5], svd_v[8] = -svd_v[8];
- r[0] = svd_v[0] * svd_u[0] + svd_v[1] * svd_u[1] + svd_v[2] * svd_u[2];
- r[1] = svd_v[0] * svd_u[3] + svd_v[1] * svd_u[4] + svd_v[2] * svd_u[5];
- r[2] = svd_v[0] * svd_u[6] + svd_v[1] * svd_u[7] + svd_v[2] * svd_u[8];
- r[3] = svd_v[3] * svd_u[0] + svd_v[4] * svd_u[1] + svd_v[5] * svd_u[2];
- r[4] = svd_v[3] * svd_u[3] + svd_v[4] * svd_u[4] + svd_v[5] * svd_u[5];
- r[5] = svd_v[3] * svd_u[6] + svd_v[4] * svd_u[7] + svd_v[5] * svd_u[8];
- r[6] = svd_v[6] * svd_u[0] + svd_v[7] * svd_u[1] + svd_v[8] * svd_u[2];
- r[7] = svd_v[6] * svd_u[3] + svd_v[7] * svd_u[4] + svd_v[8] * svd_u[5];
- r[8] = svd_v[6] * svd_u[6] + svd_v[7] * svd_u[7] + svd_v[8] * svd_u[8];
+ if (det < 0.0f) {
+ svd_right[2] = -svd_right[2], svd_right[5] = -svd_right[5], svd_right[8] = -svd_right[8];
+ nk_rotation_from_svd_f32_serial_(svd_left, svd_right, optimal_rotation);
  }
-
  if (rotation)
- for (int j = 0; j < 9; ++j) rotation[j] = r[j];
-
- nk_f32_t sum_squared = nk_transformed_ssd_bf16_skylake_(a, b, n, r, c, centroid_a_x, centroid_a_y, centroid_a_z,
- centroid_b_x, centroid_b_y, centroid_b_z);
+ for (int j = 0; j < 9; ++j) rotation[j] = optimal_rotation[j];
+
+ // Folded SSD with scale:
+ // SSD = c²·‖a-ā‖² + ‖b-b̄‖² − 2c·trace(R · H_centered).
+ nk_f32_t trace_rotation_covariance =
+ optimal_rotation[0] * cross_covariance[0] + optimal_rotation[1] * cross_covariance[3] +
+ optimal_rotation[2] * cross_covariance[6] + optimal_rotation[3] * cross_covariance[1] +
+ optimal_rotation[4] * cross_covariance[4] + optimal_rotation[5] * cross_covariance[7] +
+ optimal_rotation[6] * cross_covariance[2] + optimal_rotation[7] * cross_covariance[5] +
+ optimal_rotation[8] * cross_covariance[8];
+ nk_f32_t sum_squared = c * c * centered_norm_squared_a + centered_norm_squared_b -
+ 2.0f * c * trace_rotation_covariance;
+ if (sum_squared < 0.0f) sum_squared = 0.0f;
  *result = nk_f32_sqrt_haswell(sum_squared * inv_n);
  }