numkong 7.4.4 → 7.5.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (61) hide show
  1. package/README.md +1 -0
  2. package/binding.gyp +81 -5
  3. package/c/dispatch_f16.c +23 -0
  4. package/c/numkong.c +0 -13
  5. package/include/numkong/attention/sme.h +34 -31
  6. package/include/numkong/capabilities.h +2 -15
  7. package/include/numkong/cast/neon.h +15 -0
  8. package/include/numkong/curved/smef64.h +82 -62
  9. package/include/numkong/dot/rvvbf16.h +1 -1
  10. package/include/numkong/dot/rvvhalf.h +1 -1
  11. package/include/numkong/dot/sve.h +6 -5
  12. package/include/numkong/dot/svebfdot.h +2 -1
  13. package/include/numkong/dot/svehalf.h +6 -5
  14. package/include/numkong/dot/svesdot.h +3 -2
  15. package/include/numkong/dots/graniteamx.h +733 -0
  16. package/include/numkong/dots/serial.h +11 -4
  17. package/include/numkong/dots/sme.h +172 -140
  18. package/include/numkong/dots/smebi32.h +14 -11
  19. package/include/numkong/dots/smef64.h +31 -26
  20. package/include/numkong/dots.h +29 -3
  21. package/include/numkong/each/serial.h +22 -0
  22. package/include/numkong/geospatial/haswell.h +1 -1
  23. package/include/numkong/geospatial/neon.h +1 -1
  24. package/include/numkong/geospatial/serial.h +1 -1
  25. package/include/numkong/geospatial/skylake.h +1 -1
  26. package/include/numkong/maxsim/sme.h +94 -55
  27. package/include/numkong/mesh/README.md +13 -27
  28. package/include/numkong/mesh/haswell.h +25 -122
  29. package/include/numkong/mesh/neon.h +21 -110
  30. package/include/numkong/mesh/neonbfdot.h +4 -43
  31. package/include/numkong/mesh/rvv.h +7 -82
  32. package/include/numkong/mesh/serial.h +48 -53
  33. package/include/numkong/mesh/skylake.h +7 -123
  34. package/include/numkong/mesh/v128relaxed.h +9 -93
  35. package/include/numkong/mesh.h +2 -2
  36. package/include/numkong/mesh.hpp +35 -96
  37. package/include/numkong/reduce/neon.h +29 -0
  38. package/include/numkong/reduce/neonbfdot.h +2 -2
  39. package/include/numkong/reduce/neonfhm.h +4 -4
  40. package/include/numkong/reduce/sve.h +52 -0
  41. package/include/numkong/reduce.h +4 -0
  42. package/include/numkong/set/sve.h +6 -5
  43. package/include/numkong/sets/smebi32.h +35 -30
  44. package/include/numkong/sparse/sve2.h +3 -2
  45. package/include/numkong/spatial/sve.h +7 -6
  46. package/include/numkong/spatial/svebfdot.h +7 -4
  47. package/include/numkong/spatial/svehalf.h +5 -4
  48. package/include/numkong/spatial/svesdot.h +9 -8
  49. package/include/numkong/spatials/graniteamx.h +173 -0
  50. package/include/numkong/spatials/serial.h +22 -0
  51. package/include/numkong/spatials/sme.h +391 -350
  52. package/include/numkong/spatials/smef64.h +79 -70
  53. package/include/numkong/spatials.h +37 -4
  54. package/include/numkong/types.h +59 -0
  55. package/javascript/dist/cjs/numkong.js +13 -0
  56. package/javascript/dist/esm/numkong.js +13 -0
  57. package/javascript/numkong.c +56 -12
  58. package/javascript/numkong.ts +13 -0
  59. package/package.json +7 -7
  60. package/probes/probe.js +2 -2
  61. package/wasm/numkong.wasm +0 -0
@@ -570,16 +570,10 @@ NK_PUBLIC void nk_rmsd_f32_v128relaxed(nk_f32_t const *a, nk_f32_t const *b, nk_
570
570
  rotation[0] = 1, rotation[1] = 0, rotation[2] = 0, rotation[3] = 0, rotation[4] = 1, rotation[5] = 0,
571
571
  rotation[6] = 0, rotation[7] = 0, rotation[8] = 1;
572
572
  if (scale) *scale = 1.0f;
573
+ if (a_centroid) a_centroid[0] = 0, a_centroid[1] = 0, a_centroid[2] = 0;
574
+ if (b_centroid) b_centroid[0] = 0, b_centroid[1] = 0, b_centroid[2] = 0;
573
575
 
574
- // Fused single-pass: accumulate centroids and squared differences simultaneously.
575
- // RMSD = √(E[(a−b)²] − (ā − b̄)²)
576
576
  v128_t zero_f64x2 = wasm_f64x2_splat(0.0);
577
- v128_t sum_a_x_low_f64x2 = zero_f64x2, sum_a_x_high_f64x2 = zero_f64x2;
578
- v128_t sum_a_y_low_f64x2 = zero_f64x2, sum_a_y_high_f64x2 = zero_f64x2;
579
- v128_t sum_a_z_low_f64x2 = zero_f64x2, sum_a_z_high_f64x2 = zero_f64x2;
580
- v128_t sum_b_x_low_f64x2 = zero_f64x2, sum_b_x_high_f64x2 = zero_f64x2;
581
- v128_t sum_b_y_low_f64x2 = zero_f64x2, sum_b_y_high_f64x2 = zero_f64x2;
582
- v128_t sum_b_z_low_f64x2 = zero_f64x2, sum_b_z_high_f64x2 = zero_f64x2;
583
577
  v128_t sum_sq_x_low_f64x2 = zero_f64x2, sum_sq_x_high_f64x2 = zero_f64x2;
584
578
  v128_t sum_sq_y_low_f64x2 = zero_f64x2, sum_sq_y_high_f64x2 = zero_f64x2;
585
579
  v128_t sum_sq_z_low_f64x2 = zero_f64x2, sum_sq_z_high_f64x2 = zero_f64x2;
@@ -590,8 +584,7 @@ NK_PUBLIC void nk_rmsd_f32_v128relaxed(nk_f32_t const *a, nk_f32_t const *b, nk_
590
584
  nk_deinterleave_f32x4_v128relaxed_(a + index * 3, &a_x_f32x4, &a_y_f32x4, &a_z_f32x4);
591
585
  nk_deinterleave_f32x4_v128relaxed_(b + index * 3, &b_x_f32x4, &b_y_f32x4, &b_z_f32x4);
592
586
 
593
- // Promote lower and upper halves to f64. Deltas computed in f64 to avoid
594
- // f32 cancellation in the single-pass formula RMSD = √(E[(a−b)²] − (ā − b̄)²).
587
+ // Promote lower and upper halves to f64 for precision.
595
588
  v128_t a_x_low_f64x2 = wasm_f64x2_promote_low_f32x4(a_x_f32x4);
596
589
  v128_t a_x_high_f64x2 = wasm_f64x2_promote_low_f32x4(wasm_i32x4_shuffle(a_x_f32x4, a_x_f32x4, 2, 3, 0, 1));
597
590
  v128_t a_y_low_f64x2 = wasm_f64x2_promote_low_f32x4(a_y_f32x4);
@@ -605,21 +598,7 @@ NK_PUBLIC void nk_rmsd_f32_v128relaxed(nk_f32_t const *a, nk_f32_t const *b, nk_
605
598
  v128_t b_z_low_f64x2 = wasm_f64x2_promote_low_f32x4(b_z_f32x4);
606
599
  v128_t b_z_high_f64x2 = wasm_f64x2_promote_low_f32x4(wasm_i32x4_shuffle(b_z_f32x4, b_z_f32x4, 2, 3, 0, 1));
607
600
 
608
- // Accumulate centroids.
609
- sum_a_x_low_f64x2 = wasm_f64x2_add(sum_a_x_low_f64x2, a_x_low_f64x2);
610
- sum_a_x_high_f64x2 = wasm_f64x2_add(sum_a_x_high_f64x2, a_x_high_f64x2);
611
- sum_a_y_low_f64x2 = wasm_f64x2_add(sum_a_y_low_f64x2, a_y_low_f64x2);
612
- sum_a_y_high_f64x2 = wasm_f64x2_add(sum_a_y_high_f64x2, a_y_high_f64x2);
613
- sum_a_z_low_f64x2 = wasm_f64x2_add(sum_a_z_low_f64x2, a_z_low_f64x2);
614
- sum_a_z_high_f64x2 = wasm_f64x2_add(sum_a_z_high_f64x2, a_z_high_f64x2);
615
- sum_b_x_low_f64x2 = wasm_f64x2_add(sum_b_x_low_f64x2, b_x_low_f64x2);
616
- sum_b_x_high_f64x2 = wasm_f64x2_add(sum_b_x_high_f64x2, b_x_high_f64x2);
617
- sum_b_y_low_f64x2 = wasm_f64x2_add(sum_b_y_low_f64x2, b_y_low_f64x2);
618
- sum_b_y_high_f64x2 = wasm_f64x2_add(sum_b_y_high_f64x2, b_y_high_f64x2);
619
- sum_b_z_low_f64x2 = wasm_f64x2_add(sum_b_z_low_f64x2, b_z_low_f64x2);
620
- sum_b_z_high_f64x2 = wasm_f64x2_add(sum_b_z_high_f64x2, b_z_high_f64x2);
621
-
622
- // Accumulate squared differences in f64 — deltas computed in f64 for precision.
601
+ // Accumulate squared differences in f64.
623
602
  v128_t dx_low_f64x2 = wasm_f64x2_sub(a_x_low_f64x2, b_x_low_f64x2);
624
603
  v128_t dx_high_f64x2 = wasm_f64x2_sub(a_x_high_f64x2, b_x_high_f64x2);
625
604
  v128_t dy_low_f64x2 = wasm_f64x2_sub(a_y_low_f64x2, b_y_low_f64x2);
@@ -635,12 +614,6 @@ NK_PUBLIC void nk_rmsd_f32_v128relaxed(nk_f32_t const *a, nk_f32_t const *b, nk_
635
614
  sum_sq_z_high_f64x2 = wasm_f64x2_relaxed_madd(dz_high_f64x2, dz_high_f64x2, sum_sq_z_high_f64x2);
636
615
  }
637
616
 
638
- nk_f64_t sum_a_x = nk_hsum_f64x2_v128relaxed_(wasm_f64x2_add(sum_a_x_low_f64x2, sum_a_x_high_f64x2));
639
- nk_f64_t sum_a_y = nk_hsum_f64x2_v128relaxed_(wasm_f64x2_add(sum_a_y_low_f64x2, sum_a_y_high_f64x2));
640
- nk_f64_t sum_a_z = nk_hsum_f64x2_v128relaxed_(wasm_f64x2_add(sum_a_z_low_f64x2, sum_a_z_high_f64x2));
641
- nk_f64_t sum_b_x = nk_hsum_f64x2_v128relaxed_(wasm_f64x2_add(sum_b_x_low_f64x2, sum_b_x_high_f64x2));
642
- nk_f64_t sum_b_y = nk_hsum_f64x2_v128relaxed_(wasm_f64x2_add(sum_b_y_low_f64x2, sum_b_y_high_f64x2));
643
- nk_f64_t sum_b_z = nk_hsum_f64x2_v128relaxed_(wasm_f64x2_add(sum_b_z_low_f64x2, sum_b_z_high_f64x2));
644
617
  nk_f64_t sum_sq_x = nk_hsum_f64x2_v128relaxed_(wasm_f64x2_add(sum_sq_x_low_f64x2, sum_sq_x_high_f64x2));
645
618
  nk_f64_t sum_sq_y = nk_hsum_f64x2_v128relaxed_(wasm_f64x2_add(sum_sq_y_low_f64x2, sum_sq_y_high_f64x2));
646
619
  nk_f64_t sum_sq_z = nk_hsum_f64x2_v128relaxed_(wasm_f64x2_add(sum_sq_z_low_f64x2, sum_sq_z_high_f64x2));
@@ -649,45 +622,25 @@ NK_PUBLIC void nk_rmsd_f32_v128relaxed(nk_f32_t const *a, nk_f32_t const *b, nk_
649
622
  for (; index < n; ++index) {
650
623
  nk_f64_t ax = a[index * 3 + 0], ay = a[index * 3 + 1], az = a[index * 3 + 2];
651
624
  nk_f64_t bx = b[index * 3 + 0], by = b[index * 3 + 1], bz = b[index * 3 + 2];
652
- sum_a_x += ax, sum_a_y += ay, sum_a_z += az;
653
- sum_b_x += bx, sum_b_y += by, sum_b_z += bz;
654
625
  nk_f64_t dx = ax - bx, dy = ay - by, dz = az - bz;
655
626
  sum_sq_x += dx * dx, sum_sq_y += dy * dy, sum_sq_z += dz * dz;
656
627
  }
657
628
 
658
- nk_f64_t inv_points_count = 1.0 / (nk_f64_t)n;
659
- nk_f64_t centroid_a_x = sum_a_x * inv_points_count, centroid_a_y = sum_a_y * inv_points_count,
660
- centroid_a_z = sum_a_z * inv_points_count;
661
- nk_f64_t centroid_b_x = sum_b_x * inv_points_count, centroid_b_y = sum_b_y * inv_points_count,
662
- centroid_b_z = sum_b_z * inv_points_count;
663
- if (a_centroid)
664
- a_centroid[0] = (nk_f32_t)centroid_a_x, a_centroid[1] = (nk_f32_t)centroid_a_y,
665
- a_centroid[2] = (nk_f32_t)centroid_a_z;
666
- if (b_centroid)
667
- b_centroid[0] = (nk_f32_t)centroid_b_x, b_centroid[1] = (nk_f32_t)centroid_b_y,
668
- b_centroid[2] = (nk_f32_t)centroid_b_z;
669
-
670
- nk_f64_t sum_squared = sum_sq_x + sum_sq_y + sum_sq_z;
671
- nk_f64_t mean_diff_x = centroid_a_x - centroid_b_x;
672
- nk_f64_t mean_diff_y = centroid_a_y - centroid_b_y;
673
- nk_f64_t mean_diff_z = centroid_a_z - centroid_b_z;
674
- nk_f64_t mean_diff_sq = mean_diff_x * mean_diff_x + mean_diff_y * mean_diff_y + mean_diff_z * mean_diff_z;
675
- *result = nk_f64_sqrt_v128relaxed(sum_squared * inv_points_count - mean_diff_sq);
629
+ *result = nk_f64_sqrt_v128relaxed((sum_sq_x + sum_sq_y + sum_sq_z) / (nk_f64_t)n);
676
630
  }
677
631
 
678
632
  NK_PUBLIC void nk_rmsd_f64_v128relaxed(nk_f64_t const *a, nk_f64_t const *b, nk_size_t n, nk_f64_t *a_centroid,
679
633
  nk_f64_t *b_centroid, nk_f64_t *rotation, nk_f64_t *scale, nk_f64_t *result) {
680
- // RMSD uses identity rotation and scale=1.0
681
634
  if (rotation)
682
635
  rotation[0] = 1, rotation[1] = 0, rotation[2] = 0, rotation[3] = 0, rotation[4] = 1, rotation[5] = 0,
683
636
  rotation[6] = 0, rotation[7] = 0, rotation[8] = 1;
684
637
  if (scale) *scale = 1.0;
638
+ if (a_centroid) a_centroid[0] = 0, a_centroid[1] = 0, a_centroid[2] = 0;
639
+ if (b_centroid) b_centroid[0] = 0, b_centroid[1] = 0, b_centroid[2] = 0;
685
640
 
686
641
  v128_t const zeros_f64x2 = wasm_f64x2_splat(0);
687
642
 
688
- // Accumulators for centroids and squared differences
689
- v128_t sum_a_x_f64x2 = zeros_f64x2, sum_a_y_f64x2 = zeros_f64x2, sum_a_z_f64x2 = zeros_f64x2;
690
- v128_t sum_b_x_f64x2 = zeros_f64x2, sum_b_y_f64x2 = zeros_f64x2, sum_b_z_f64x2 = zeros_f64x2;
643
+ // Accumulators for squared differences
691
644
  v128_t sum_squared_x_f64x2 = zeros_f64x2, sum_squared_y_f64x2 = zeros_f64x2, sum_squared_z_f64x2 = zeros_f64x2;
692
645
 
693
646
  v128_t a_x_f64x2, a_y_f64x2, a_z_f64x2, b_x_f64x2, b_y_f64x2, b_z_f64x2;
@@ -698,13 +651,6 @@ NK_PUBLIC void nk_rmsd_f64_v128relaxed(nk_f64_t const *a, nk_f64_t const *b, nk_
698
651
  nk_deinterleave_f64x2_v128relaxed_(a + i * 3, &a_x_f64x2, &a_y_f64x2, &a_z_f64x2);
699
652
  nk_deinterleave_f64x2_v128relaxed_(b + i * 3, &b_x_f64x2, &b_y_f64x2, &b_z_f64x2);
700
653
 
701
- sum_a_x_f64x2 = wasm_f64x2_add(sum_a_x_f64x2, a_x_f64x2);
702
- sum_a_y_f64x2 = wasm_f64x2_add(sum_a_y_f64x2, a_y_f64x2);
703
- sum_a_z_f64x2 = wasm_f64x2_add(sum_a_z_f64x2, a_z_f64x2);
704
- sum_b_x_f64x2 = wasm_f64x2_add(sum_b_x_f64x2, b_x_f64x2);
705
- sum_b_y_f64x2 = wasm_f64x2_add(sum_b_y_f64x2, b_y_f64x2);
706
- sum_b_z_f64x2 = wasm_f64x2_add(sum_b_z_f64x2, b_z_f64x2);
707
-
708
654
  v128_t delta_x_f64x2 = wasm_f64x2_sub(a_x_f64x2, b_x_f64x2);
709
655
  v128_t delta_y_f64x2 = wasm_f64x2_sub(a_y_f64x2, b_y_f64x2);
710
656
  v128_t delta_z_f64x2 = wasm_f64x2_sub(a_z_f64x2, b_z_f64x2);
@@ -715,12 +661,6 @@ NK_PUBLIC void nk_rmsd_f64_v128relaxed(nk_f64_t const *a, nk_f64_t const *b, nk_
715
661
  }
716
662
 
717
663
  // Reduce vectors to scalars.
718
- nk_f64_t total_ax = nk_reduce_stable_f64x2_v128relaxed_(sum_a_x_f64x2), total_ax_compensation = 0.0;
719
- nk_f64_t total_ay = nk_reduce_stable_f64x2_v128relaxed_(sum_a_y_f64x2), total_ay_compensation = 0.0;
720
- nk_f64_t total_az = nk_reduce_stable_f64x2_v128relaxed_(sum_a_z_f64x2), total_az_compensation = 0.0;
721
- nk_f64_t total_bx = nk_reduce_stable_f64x2_v128relaxed_(sum_b_x_f64x2), total_bx_compensation = 0.0;
722
- nk_f64_t total_by = nk_reduce_stable_f64x2_v128relaxed_(sum_b_y_f64x2), total_by_compensation = 0.0;
723
- nk_f64_t total_bz = nk_reduce_stable_f64x2_v128relaxed_(sum_b_z_f64x2), total_bz_compensation = 0.0;
724
664
  nk_f64_t total_squared_x = nk_reduce_stable_f64x2_v128relaxed_(sum_squared_x_f64x2),
725
665
  total_squared_x_compensation = 0.0;
726
666
  nk_f64_t total_squared_y = nk_reduce_stable_f64x2_v128relaxed_(sum_squared_y_f64x2),
@@ -732,40 +672,16 @@ NK_PUBLIC void nk_rmsd_f64_v128relaxed(nk_f64_t const *a, nk_f64_t const *b, nk_
732
672
  for (; i < n; ++i) {
733
673
  nk_f64_t ax = a[i * 3 + 0], ay = a[i * 3 + 1], az = a[i * 3 + 2];
734
674
  nk_f64_t bx = b[i * 3 + 0], by = b[i * 3 + 1], bz = b[i * 3 + 2];
735
- nk_accumulate_sum_f64_(&total_ax, &total_ax_compensation, ax);
736
- nk_accumulate_sum_f64_(&total_ay, &total_ay_compensation, ay);
737
- nk_accumulate_sum_f64_(&total_az, &total_az_compensation, az);
738
- nk_accumulate_sum_f64_(&total_bx, &total_bx_compensation, bx);
739
- nk_accumulate_sum_f64_(&total_by, &total_by_compensation, by);
740
- nk_accumulate_sum_f64_(&total_bz, &total_bz_compensation, bz);
741
675
  nk_f64_t delta_x = ax - bx, delta_y = ay - by, delta_z = az - bz;
742
676
  nk_accumulate_square_f64_(&total_squared_x, &total_squared_x_compensation, delta_x);
743
677
  nk_accumulate_square_f64_(&total_squared_y, &total_squared_y_compensation, delta_y);
744
678
  nk_accumulate_square_f64_(&total_squared_z, &total_squared_z_compensation, delta_z);
745
679
  }
746
680
 
747
- total_ax += total_ax_compensation, total_ay += total_ay_compensation, total_az += total_az_compensation;
748
- total_bx += total_bx_compensation, total_by += total_by_compensation, total_bz += total_bz_compensation;
749
681
  total_squared_x += total_squared_x_compensation, total_squared_y += total_squared_y_compensation,
750
682
  total_squared_z += total_squared_z_compensation;
751
683
 
752
- // Compute centroids
753
- nk_f64_t inv_points_count = 1.0 / (nk_f64_t)n;
754
- nk_f64_t centroid_a_x = total_ax * inv_points_count, centroid_a_y = total_ay * inv_points_count,
755
- centroid_a_z = total_az * inv_points_count;
756
- nk_f64_t centroid_b_x = total_bx * inv_points_count, centroid_b_y = total_by * inv_points_count,
757
- centroid_b_z = total_bz * inv_points_count;
758
- if (a_centroid) a_centroid[0] = centroid_a_x, a_centroid[1] = centroid_a_y, a_centroid[2] = centroid_a_z;
759
- if (b_centroid) b_centroid[0] = centroid_b_x, b_centroid[1] = centroid_b_y, b_centroid[2] = centroid_b_z;
760
-
761
- // Compute RMSD
762
- nk_f64_t mean_diff_x = centroid_a_x - centroid_b_x;
763
- nk_f64_t mean_diff_y = centroid_a_y - centroid_b_y;
764
- nk_f64_t mean_diff_z = centroid_a_z - centroid_b_z;
765
- nk_f64_t sum_squared = total_squared_x + total_squared_y + total_squared_z;
766
- nk_f64_t mean_diff_sq = mean_diff_x * mean_diff_x + mean_diff_y * mean_diff_y + mean_diff_z * mean_diff_z;
767
-
768
- *result = nk_f64_sqrt_v128relaxed(sum_squared * inv_points_count - mean_diff_sq);
684
+ *result = nk_f64_sqrt_v128relaxed((total_squared_x + total_squared_y + total_squared_z) / (nk_f64_t)n);
769
685
  }
770
686
 
771
687
  NK_PUBLIC void nk_kabsch_f32_v128relaxed(nk_f32_t const *a, nk_f32_t const *b, nk_size_t n, nk_f32_t *a_centroid,
@@ -6,7 +6,7 @@
6
6
  *
7
7
  * Contains:
8
8
  *
9
- * - Root Mean Square Deviation (RMSD) for rigid body superposition
9
+ * - Root Mean Square Deviation (RMSD) of raw point differences
10
10
  * - Kabsch algorithm for optimal rigid body alignment (rotation only)
11
11
  * - Umeyama algorithm for similarity transform (rotation + uniform scaling)
12
12
  *
@@ -48,7 +48,7 @@
48
48
  *
49
49
  * @section algorithm_overview Algorithm Overview
50
50
  *
51
- * - RMSD: Simple root mean square deviation without alignment. R = identity, scale = 1.0
51
+ * - RMSD: Raw √(Σ‖aᵢ − bᵢ‖² / n) without centering or alignment. R = identity, scale = 1.0, centroids zeroed
52
52
  * - Kabsch: Finds optimal rotation R minimizing ‖R × (a - ā) - (b - b̄)‖. scale = 1.0
53
53
  * - Umeyama: Finds optimal rotation R and scale c minimizing ‖c × R × (a - ā) - (b - b̄)‖
54
54
  *
@@ -354,74 +354,30 @@ void rmsd( //
354
354
  else if constexpr (std::is_same_v<in_type_, bf16_t> && simd)
355
355
  nk_rmsd_bf16(&a->raw_, &b->raw_, n, &a_centroid->raw_, &b_centroid->raw_, &rotation->raw_,
356
356
  scale ? &scale->raw_ : nullptr, &metric->raw_);
357
- // Scalar fallback
357
+ // Scalar fallback: raw √(Σ‖aᵢ − bᵢ‖² / n), no centering
358
358
  else {
359
- // Step 1: Compute centroids
360
- metric_type_ sum_a_x {}, sum_a_y {}, sum_a_z {};
361
- metric_type_ sum_b_x {}, sum_b_y {}, sum_b_z {};
362
- metric_type_ val_a_x, val_a_y, val_a_z, val_b_x, val_b_y, val_b_z;
363
-
364
- for (std::size_t i = 0; i < n; i++) {
365
- val_a_x = metric_type_(a[i * 3 + 0]);
366
- val_a_y = metric_type_(a[i * 3 + 1]);
367
- val_a_z = metric_type_(a[i * 3 + 2]);
368
- val_b_x = metric_type_(b[i * 3 + 0]);
369
- val_b_y = metric_type_(b[i * 3 + 1]);
370
- val_b_z = metric_type_(b[i * 3 + 2]);
371
- sum_a_x = sum_a_x + val_a_x;
372
- sum_a_y = sum_a_y + val_a_y;
373
- sum_a_z = sum_a_z + val_a_z;
374
- sum_b_x = sum_b_x + val_b_x;
375
- sum_b_y = sum_b_y + val_b_y;
376
- sum_b_z = sum_b_z + val_b_z;
377
- }
378
-
379
- metric_type_ inv_n = metric_type_(1.0) / metric_type_(static_cast<double>(n));
380
- metric_type_ centroid_a_x = sum_a_x * inv_n;
381
- metric_type_ centroid_a_y = sum_a_y * inv_n;
382
- metric_type_ centroid_a_z = sum_a_z * inv_n;
383
- metric_type_ centroid_b_x = sum_b_x * inv_n;
384
- metric_type_ centroid_b_y = sum_b_y * inv_n;
385
- metric_type_ centroid_b_z = sum_b_z * inv_n;
386
-
387
- // Step 2: Store centroids if requested
388
359
  if (a_centroid)
389
- a_centroid[0] = transform_type_(centroid_a_x), a_centroid[1] = transform_type_(centroid_a_y),
390
- a_centroid[2] = transform_type_(centroid_a_z);
360
+ a_centroid[0] = transform_type_(0.0), a_centroid[1] = transform_type_(0.0),
361
+ a_centroid[2] = transform_type_(0.0);
391
362
  if (b_centroid)
392
- b_centroid[0] = transform_type_(centroid_b_x), b_centroid[1] = transform_type_(centroid_b_y),
393
- b_centroid[2] = transform_type_(centroid_b_z);
394
-
395
- // Step 3: RMSD uses identity rotation and scale=1.0
363
+ b_centroid[0] = transform_type_(0.0), b_centroid[1] = transform_type_(0.0),
364
+ b_centroid[2] = transform_type_(0.0);
396
365
  if (rotation) {
397
- rotation[0] = transform_type_(1.0);
398
- rotation[1] = transform_type_(0.0);
399
- rotation[2] = transform_type_(0.0);
400
- rotation[3] = transform_type_(0.0);
401
- rotation[4] = transform_type_(1.0);
402
- rotation[5] = transform_type_(0.0);
403
- rotation[6] = transform_type_(0.0);
404
- rotation[7] = transform_type_(0.0);
405
- rotation[8] = transform_type_(1.0);
366
+ rotation[0] = transform_type_(1.0), rotation[1] = transform_type_(0.0), rotation[2] = transform_type_(0.0);
367
+ rotation[3] = transform_type_(0.0), rotation[4] = transform_type_(1.0), rotation[5] = transform_type_(0.0);
368
+ rotation[6] = transform_type_(0.0), rotation[7] = transform_type_(0.0), rotation[8] = transform_type_(1.0);
406
369
  }
407
370
  if (scale) *scale = transform_type_(1.0);
408
371
 
409
- // Step 4: Compute RMSD between centered point clouds
410
372
  metric_type_ sum_squared {};
411
373
  for (std::size_t i = 0; i < n; i++) {
412
- val_a_x = metric_type_(a[i * 3 + 0]);
413
- val_a_y = metric_type_(a[i * 3 + 1]);
414
- val_a_z = metric_type_(a[i * 3 + 2]);
415
- val_b_x = metric_type_(b[i * 3 + 0]);
416
- val_b_y = metric_type_(b[i * 3 + 1]);
417
- val_b_z = metric_type_(b[i * 3 + 2]);
418
- metric_type_ dx = (val_a_x - centroid_a_x) - (val_b_x - centroid_b_x);
419
- metric_type_ dy = (val_a_y - centroid_a_y) - (val_b_y - centroid_b_y);
420
- metric_type_ dz = (val_a_z - centroid_a_z) - (val_b_z - centroid_b_z);
374
+ metric_type_ dx = metric_type_(a[i * 3 + 0]) - metric_type_(b[i * 3 + 0]);
375
+ metric_type_ dy = metric_type_(a[i * 3 + 1]) - metric_type_(b[i * 3 + 1]);
376
+ metric_type_ dz = metric_type_(a[i * 3 + 2]) - metric_type_(b[i * 3 + 2]);
421
377
  sum_squared = sum_squared + dx * dx + dy * dy + dz * dz;
422
378
  }
423
379
 
424
- *metric = (sum_squared * inv_n).sqrt();
380
+ *metric = (sum_squared / metric_type_(static_cast<double>(n))).sqrt();
425
381
  }
426
382
  }
427
383
 
@@ -470,18 +426,12 @@ void kabsch( //
470
426
  metric_type_ val_a_x, val_a_y, val_a_z, val_b_x, val_b_y, val_b_z;
471
427
 
472
428
  for (std::size_t i = 0; i < n; i++) {
473
- val_a_x = metric_type_(a[i * 3 + 0]);
474
- val_a_y = metric_type_(a[i * 3 + 1]);
429
+ val_a_x = metric_type_(a[i * 3 + 0]), val_a_y = metric_type_(a[i * 3 + 1]),
475
430
  val_a_z = metric_type_(a[i * 3 + 2]);
476
- val_b_x = metric_type_(b[i * 3 + 0]);
477
- val_b_y = metric_type_(b[i * 3 + 1]);
431
+ val_b_x = metric_type_(b[i * 3 + 0]), val_b_y = metric_type_(b[i * 3 + 1]),
478
432
  val_b_z = metric_type_(b[i * 3 + 2]);
479
- sum_a_x = sum_a_x + val_a_x;
480
- sum_a_y = sum_a_y + val_a_y;
481
- sum_a_z = sum_a_z + val_a_z;
482
- sum_b_x = sum_b_x + val_b_x;
483
- sum_b_y = sum_b_y + val_b_y;
484
- sum_b_z = sum_b_z + val_b_z;
433
+ sum_a_x = sum_a_x + val_a_x, sum_a_y = sum_a_y + val_a_y, sum_a_z = sum_a_z + val_a_z;
434
+ sum_b_x = sum_b_x + val_b_x, sum_b_y = sum_b_y + val_b_y, sum_b_z = sum_b_z + val_b_z;
485
435
  }
486
436
 
487
437
  metric_type_ inv_n = metric_type_(1.0) / metric_type_(static_cast<double>(n));
@@ -503,11 +453,9 @@ void kabsch( //
503
453
  // Step 2: Build 3x3 covariance matrix H = (A - A_bar)^T x (B - B_bar)
504
454
  metric_type_ cross_covariance[9] = {};
505
455
  for (std::size_t i = 0; i < n; i++) {
506
- val_a_x = metric_type_(a[i * 3 + 0]) - centroid_a_x;
507
- val_a_y = metric_type_(a[i * 3 + 1]) - centroid_a_y;
456
+ val_a_x = metric_type_(a[i * 3 + 0]) - centroid_a_x, val_a_y = metric_type_(a[i * 3 + 1]) - centroid_a_y,
508
457
  val_a_z = metric_type_(a[i * 3 + 2]) - centroid_a_z;
509
- val_b_x = metric_type_(b[i * 3 + 0]) - centroid_b_x;
510
- val_b_y = metric_type_(b[i * 3 + 1]) - centroid_b_y;
458
+ val_b_x = metric_type_(b[i * 3 + 0]) - centroid_b_x, val_b_y = metric_type_(b[i * 3 + 1]) - centroid_b_y,
511
459
  val_b_z = metric_type_(b[i * 3 + 2]) - centroid_b_z;
512
460
  cross_covariance[0] = cross_covariance[0] + val_a_x * val_b_x;
513
461
  cross_covariance[1] = cross_covariance[1] + val_a_x * val_b_y;
@@ -563,11 +511,11 @@ void kabsch( //
563
511
  metric_type_ sum_squared {};
564
512
  for (std::size_t i = 0; i < n; i++) {
565
513
  metric_type_ point_a[3], point_b[3], rotated_point_a[3];
566
- point_a[0] = metric_type_(a[i * 3 + 0]) - centroid_a_x;
567
- point_a[1] = metric_type_(a[i * 3 + 1]) - centroid_a_y;
514
+ point_a[0] = metric_type_(a[i * 3 + 0]) - centroid_a_x,
515
+ point_a[1] = metric_type_(a[i * 3 + 1]) - centroid_a_y,
568
516
  point_a[2] = metric_type_(a[i * 3 + 2]) - centroid_a_z;
569
- point_b[0] = metric_type_(b[i * 3 + 0]) - centroid_b_x;
570
- point_b[1] = metric_type_(b[i * 3 + 1]) - centroid_b_y;
517
+ point_b[0] = metric_type_(b[i * 3 + 0]) - centroid_b_x,
518
+ point_b[1] = metric_type_(b[i * 3 + 1]) - centroid_b_y,
571
519
  point_b[2] = metric_type_(b[i * 3 + 2]) - centroid_b_z;
572
520
  rotated_point_a[0] = rotation_matrix[0] * point_a[0] + rotation_matrix[1] * point_a[1] +
573
521
  rotation_matrix[2] * point_a[2];
@@ -628,18 +576,12 @@ void umeyama(in_type_ const *a, in_type_ const *b, std::size_t n, transform_type
628
576
  metric_type_ val_a_x, val_a_y, val_a_z, val_b_x, val_b_y, val_b_z;
629
577
 
630
578
  for (std::size_t i = 0; i < n; i++) {
631
- val_a_x = metric_type_(a[i * 3 + 0]);
632
- val_a_y = metric_type_(a[i * 3 + 1]);
579
+ val_a_x = metric_type_(a[i * 3 + 0]), val_a_y = metric_type_(a[i * 3 + 1]),
633
580
  val_a_z = metric_type_(a[i * 3 + 2]);
634
- val_b_x = metric_type_(b[i * 3 + 0]);
635
- val_b_y = metric_type_(b[i * 3 + 1]);
581
+ val_b_x = metric_type_(b[i * 3 + 0]), val_b_y = metric_type_(b[i * 3 + 1]),
636
582
  val_b_z = metric_type_(b[i * 3 + 2]);
637
- sum_a_x = sum_a_x + val_a_x;
638
- sum_a_y = sum_a_y + val_a_y;
639
- sum_a_z = sum_a_z + val_a_z;
640
- sum_b_x = sum_b_x + val_b_x;
641
- sum_b_y = sum_b_y + val_b_y;
642
- sum_b_z = sum_b_z + val_b_z;
583
+ sum_a_x = sum_a_x + val_a_x, sum_a_y = sum_a_y + val_a_y, sum_a_z = sum_a_z + val_a_z;
584
+ sum_b_x = sum_b_x + val_b_x, sum_b_y = sum_b_y + val_b_y, sum_b_z = sum_b_z + val_b_z;
643
585
  }
644
586
 
645
587
  metric_type_ inv_n = metric_type_(1.0) / metric_type_(static_cast<double>(n));
@@ -650,16 +592,13 @@ void umeyama(in_type_ const *a, in_type_ const *b, std::size_t n, transform_type
650
592
  metric_type_ centroid_b_y = sum_b_y * inv_n;
651
593
  metric_type_ centroid_b_z = sum_b_z * inv_n;
652
594
 
653
- if (a_centroid) {
654
- a_centroid[0] = transform_type_(centroid_a_x);
655
- a_centroid[1] = transform_type_(centroid_a_y);
595
+ if (a_centroid)
596
+ a_centroid[0] = transform_type_(centroid_a_x), a_centroid[1] = transform_type_(centroid_a_y),
656
597
  a_centroid[2] = transform_type_(centroid_a_z);
657
- }
658
- if (b_centroid) {
659
- b_centroid[0] = transform_type_(centroid_b_x);
660
- b_centroid[1] = transform_type_(centroid_b_y);
598
+
599
+ if (b_centroid)
600
+ b_centroid[0] = transform_type_(centroid_b_x), b_centroid[1] = transform_type_(centroid_b_y),
661
601
  b_centroid[2] = transform_type_(centroid_b_z);
662
- }
663
602
 
664
603
  // Step 2: Build covariance matrix H and compute variance of A
665
604
  metric_type_ cross_covariance[9] = {};
@@ -733,11 +672,11 @@ void umeyama(in_type_ const *a, in_type_ const *b, std::size_t n, transform_type
733
672
  metric_type_ sum_squared {};
734
673
  for (std::size_t i = 0; i < n; i++) {
735
674
  metric_type_ point_a[3], point_b[3], rotated_point_a[3];
736
- point_a[0] = metric_type_(a[i * 3 + 0]) - centroid_a_x;
737
- point_a[1] = metric_type_(a[i * 3 + 1]) - centroid_a_y;
675
+ point_a[0] = metric_type_(a[i * 3 + 0]) - centroid_a_x,
676
+ point_a[1] = metric_type_(a[i * 3 + 1]) - centroid_a_y,
738
677
  point_a[2] = metric_type_(a[i * 3 + 2]) - centroid_a_z;
739
- point_b[0] = metric_type_(b[i * 3 + 0]) - centroid_b_x;
740
- point_b[1] = metric_type_(b[i * 3 + 1]) - centroid_b_y;
678
+ point_b[0] = metric_type_(b[i * 3 + 0]) - centroid_b_x,
679
+ point_b[1] = metric_type_(b[i * 3 + 1]) - centroid_b_y,
741
680
  point_b[2] = metric_type_(b[i * 3 + 2]) - centroid_b_z;
742
681
  rotated_point_a[0] = scale_factor * (rotation_matrix[0] * point_a[0] + rotation_matrix[1] * point_a[1] +
743
682
  rotation_matrix[2] * point_a[2]);
@@ -3936,6 +3936,35 @@ NK_PUBLIC void nk_reduce_moments_f16_neon( //
3936
3936
  else nk_reduce_moments_f16_serial(data_ptr, count, stride_bytes, sum_ptr, sumsq_ptr);
3937
3937
  }
3938
3938
 
3939
+ NK_INTERNAL void nk_reduce_moments_u1_neon_contiguous_( //
3940
+ nk_u1x8_t const *data_ptr, nk_size_t count, //
3941
+ nk_u64_t *sum_ptr, nk_u64_t *sumsq_ptr) {
3942
+ nk_size_t byte_count = nk_size_divide_round_up_(count, NK_BITS_PER_BYTE);
3943
+ nk_u64_t sum = 0;
3944
+ nk_size_t idx = 0;
3945
+ // Each vcntq_u8 produces values 0-8 per lane; accumulate at u8 level
3946
+ // for up to 31 iterations (31 × 8 = 248, fits in u8) before widening.
3947
+ while (idx + 16 <= byte_count) {
3948
+ uint8x16_t popcount_u8x16 = vdupq_n_u8(0);
3949
+ for (nk_size_t cycle = 0; cycle < 31 && idx + 16 <= byte_count; ++cycle, idx += 16) {
3950
+ uint8x16_t data_u8x16 = vld1q_u8((nk_u8_t const *)data_ptr + idx);
3951
+ popcount_u8x16 = vaddq_u8(popcount_u8x16, vcntq_u8(data_u8x16));
3952
+ }
3953
+ sum += (nk_u64_t)vaddlvq_u8(popcount_u8x16);
3954
+ }
3955
+ for (; idx < byte_count; ++idx) sum += nk_u1x8_popcount_(((nk_u8_t const *)data_ptr)[idx]);
3956
+ *sum_ptr = sum, *sumsq_ptr = sum;
3957
+ }
3958
+
3959
+ NK_PUBLIC void nk_reduce_moments_u1_neon( //
3960
+ nk_u1x8_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
3961
+ nk_u64_t *sum_ptr, nk_u64_t *sumsq_ptr) {
3962
+ count = nk_size_round_up_to_multiple_(count, 8);
3963
+ if (count == 0) *sum_ptr = 0, *sumsq_ptr = 0;
3964
+ else if (stride_bytes == 1) nk_reduce_moments_u1_neon_contiguous_(data_ptr, count, sum_ptr, sumsq_ptr);
3965
+ else nk_reduce_moments_u1_serial(data_ptr, count, stride_bytes, sum_ptr, sumsq_ptr);
3966
+ }
3967
+
3939
3968
  #if defined(__clang__)
3940
3969
  #pragma clang attribute pop
3941
3970
  #elif defined(__GNUC__)
@@ -33,7 +33,7 @@ NK_INTERNAL void nk_reduce_moments_bf16_neonbfdot_contiguous_( //
33
33
  nk_f32_t *sum_ptr, nk_f32_t *sumsq_ptr) {
34
34
 
35
35
  // bf16 representation of 1.0 is 0x3F80 (same as upper 16 bits of f32 1.0)
36
- bfloat16x8_t ones_bf16x8 = vreinterpretq_bf16_u16(vdupq_n_u16(0x3F80));
36
+ bfloat16x8_t ones_bf16x8 = vreinterpretq_bf16_u16(nk_u16x8_splat_(0x3F80));
37
37
  float32x4_t sum_f32x4 = vdupq_n_f32(0);
38
38
  float32x4_t sumsq_f32x4 = vdupq_n_f32(0);
39
39
  nk_size_t idx = 0;
@@ -61,7 +61,7 @@ NK_INTERNAL void nk_reduce_moments_bf16_neonbfdot_strided_( //
61
61
  nk_bf16_t const *data_ptr, nk_size_t count, nk_size_t stride_elements, //
62
62
  nk_f32_t *sum_ptr, nk_f32_t *sumsq_ptr) {
63
63
 
64
- bfloat16x8_t ones_bf16x8 = vreinterpretq_bf16_u16(vdupq_n_u16(0x3F80));
64
+ bfloat16x8_t ones_bf16x8 = vreinterpretq_bf16_u16(nk_u16x8_splat_(0x3F80));
65
65
  float32x4_t sum_f32x4 = vdupq_n_f32(0);
66
66
  float32x4_t sumsq_f32x4 = vdupq_n_f32(0);
67
67
  nk_size_t idx = 0;
@@ -34,7 +34,7 @@ NK_INTERNAL void nk_reduce_moments_e4m3_neonfhm_contiguous_( //
34
34
 
35
35
  float32x4_t sum_f32x4 = vdupq_n_f32(0);
36
36
  float32x4_t sumsq_f32x4 = vdupq_n_f32(0);
37
- float16x8_t ones_f16x8 = vreinterpretq_f16_u16(vdupq_n_u16(0x3C00));
37
+ float16x8_t ones_f16x8 = vreinterpretq_f16_u16(nk_u16x8_splat_(0x3C00));
38
38
  nk_size_t idx = 0;
39
39
 
40
40
  for (; idx + 8 <= count; idx += 8) {
@@ -67,7 +67,7 @@ NK_INTERNAL void nk_reduce_moments_e4m3_neonfhm_strided_( //
67
67
 
68
68
  float32x4_t sum_f32x4 = vdupq_n_f32(0);
69
69
  float32x4_t sumsq_f32x4 = vdupq_n_f32(0);
70
- float16x8_t ones_f16x8 = vreinterpretq_f16_u16(vdupq_n_u16(0x3C00));
70
+ float16x8_t ones_f16x8 = vreinterpretq_f16_u16(nk_u16x8_splat_(0x3C00));
71
71
  nk_size_t idx = 0;
72
72
 
73
73
  if (stride_elements == 2) {
@@ -159,7 +159,7 @@ NK_INTERNAL void nk_reduce_moments_e5m2_neonfhm_contiguous_( //
159
159
 
160
160
  float32x4_t sum_f32x4 = vdupq_n_f32(0);
161
161
  float32x4_t sumsq_f32x4 = vdupq_n_f32(0);
162
- float16x8_t ones_f16x8 = vreinterpretq_f16_u16(vdupq_n_u16(0x3C00));
162
+ float16x8_t ones_f16x8 = vreinterpretq_f16_u16(nk_u16x8_splat_(0x3C00));
163
163
  nk_size_t idx = 0;
164
164
 
165
165
  for (; idx + 8 <= count; idx += 8) {
@@ -192,7 +192,7 @@ NK_INTERNAL void nk_reduce_moments_e5m2_neonfhm_strided_( //
192
192
 
193
193
  float32x4_t sum_f32x4 = vdupq_n_f32(0);
194
194
  float32x4_t sumsq_f32x4 = vdupq_n_f32(0);
195
- float16x8_t ones_f16x8 = vreinterpretq_f16_u16(vdupq_n_u16(0x3C00));
195
+ float16x8_t ones_f16x8 = vreinterpretq_f16_u16(nk_u16x8_splat_(0x3C00));
196
196
  nk_size_t idx = 0;
197
197
 
198
198
  if (stride_elements == 2) {
@@ -0,0 +1,52 @@
1
+ /**
2
+ * @brief SVE horizontal reduction helpers with MSan unpoisoning.
3
+ * @file include/numkong/reduce/sve.h
4
+ * @author Ash Vardanian
5
+ * @date April 12, 2026
6
+ *
7
+ * LLVM's MSan does not instrument ARM SVE intrinsics — `svaddv` moves data
8
+ * from vector to scalar registers via architecture-specific paths invisible
9
+ * to the compiler, causing false-positive uninitialized-value reports.
10
+ * These macros wrap the reduction and unpoison the scalar result.
11
+ *
12
+ * The `svaddv` intrinsic stays inside a macro so it expands in the caller's
13
+ * target context — SVE and SME streaming translation units carry incompatible
14
+ * target attributes. The unpoisoning runs on the already-reduced scalar, so it
15
+ * lives in a target-agnostic `NK_INTERNAL` helper called from the macro.
16
+ *
17
+ * @sa include/numkong/reduce.h
18
+ */
19
+ #ifndef NK_REDUCE_SVE_H
20
+ #define NK_REDUCE_SVE_H
21
+
22
+ #if NK_TARGET_ARM64_
23
+ #if NK_TARGET_SVE || NK_TARGET_SVE2 || NK_TARGET_SME
24
+
25
+ #include "numkong/types.h"
26
+
27
+ NK_INTERNAL nk_f64_t nk_unpoison_f64_(nk_f64_t v) NK_STREAMING_COMPATIBLE_ {
28
+ nk_unpoison_(&v, sizeof(v));
29
+ return v;
30
+ }
31
+ NK_INTERNAL nk_f32_t nk_unpoison_f32_(nk_f32_t v) NK_STREAMING_COMPATIBLE_ {
32
+ nk_unpoison_(&v, sizeof(v));
33
+ return v;
34
+ }
35
+ NK_INTERNAL nk_u64_t nk_unpoison_u64_(nk_u64_t v) NK_STREAMING_COMPATIBLE_ {
36
+ nk_unpoison_(&v, sizeof(v));
37
+ return v;
38
+ }
39
+ NK_INTERNAL nk_i64_t nk_unpoison_i64_(nk_i64_t v) NK_STREAMING_COMPATIBLE_ {
40
+ nk_unpoison_(&v, sizeof(v));
41
+ return v;
42
+ }
43
+
44
+ #define nk_svaddv_f64_(predicate, vector) nk_unpoison_f64_(svaddv_f64((predicate), (vector)))
45
+ #define nk_svaddv_f32_(predicate, vector) nk_unpoison_f32_(svaddv_f32((predicate), (vector)))
46
+ #define nk_svaddv_u32_(predicate, vector) nk_unpoison_u64_(svaddv_u32((predicate), (vector)))
47
+ #define nk_svaddv_s32_(predicate, vector) nk_unpoison_i64_(svaddv_s32((predicate), (vector)))
48
+ #define nk_svaddv_u8_(predicate, vector) nk_unpoison_u64_(svaddv_u8((predicate), (vector)))
49
+
50
+ #endif // NK_TARGET_SVE || NK_TARGET_SVE2 || NK_TARGET_SME
51
+ #endif // NK_TARGET_ARM64_
52
+ #endif // NK_REDUCE_SVE_H
@@ -389,6 +389,8 @@ NK_PUBLIC void nk_reduce_moments_i16_neon(nk_i16_t const *, nk_size_t, nk_size_t
389
389
  /** @copydoc nk_reduce_moments_f64 */
390
390
  NK_PUBLIC void nk_reduce_moments_u16_neon(nk_u16_t const *, nk_size_t, nk_size_t, nk_u64_t *, nk_u64_t *);
391
391
  /** @copydoc nk_reduce_moments_f64 */
392
+ NK_PUBLIC void nk_reduce_moments_u1_neon(nk_u1x8_t const *, nk_size_t, nk_size_t, nk_u64_t *, nk_u64_t *);
393
+ /** @copydoc nk_reduce_moments_f64 */
392
394
  NK_PUBLIC void nk_reduce_moments_i32_neon(nk_i32_t const *, nk_size_t, nk_size_t, nk_i64_t *, nk_u64_t *);
393
395
  /** @copydoc nk_reduce_moments_f64 */
394
396
  NK_PUBLIC void nk_reduce_moments_u32_neon(nk_u32_t const *, nk_size_t, nk_size_t, nk_u64_t *, nk_u64_t *);
@@ -1559,6 +1561,8 @@ NK_PUBLIC void nk_reduce_moments_u1(nk_u1x8_t const *d, nk_size_t n, nk_size_t s
1559
1561
  nk_reduce_moments_u1_skylake(d, n, s, sum, sumsq);
1560
1562
  #elif NK_TARGET_HASWELL
1561
1563
  nk_reduce_moments_u1_haswell(d, n, s, sum, sumsq);
1564
+ #elif NK_TARGET_NEON
1565
+ nk_reduce_moments_u1_neon(d, n, s, sum, sumsq);
1562
1566
  #else
1563
1567
  nk_reduce_moments_u1_serial(d, n, s, sum, sumsq);
1564
1568
  #endif
@@ -32,8 +32,9 @@
32
32
  #if NK_TARGET_ARM64_
33
33
  #if NK_TARGET_SVE
34
34
 
35
- #include "numkong/types.h" // `nk_u1x8_t`
36
- #include "numkong/set/neon.h" // `nk_hamming_u1_neon`
35
+ #include "numkong/types.h" // `nk_u1x8_t`
36
+ #include "numkong/reduce/sve.h" // `nk_svaddv_f64_`
37
+ #include "numkong/set/neon.h" // `nk_hamming_u1_neon`
37
38
 
38
39
  #if defined(__cplusplus)
39
40
  extern "C" {
@@ -73,7 +74,7 @@ NK_PUBLIC void nk_hamming_u1_sve(nk_u1x8_t const *a, nk_u1x8_t const *b, nk_size
73
74
  i += words_per_register;
74
75
  ++cycle;
75
76
  } while (i < n_bytes && cycle < 31);
76
- differences += svaddv_u8(all_predicate_b8x, popcount_u8x);
77
+ differences += nk_svaddv_u8_(all_predicate_b8x, popcount_u8x);
77
78
  popcount_u8x = svdup_n_u8(0);
78
79
  cycle = 0; // Reset the cycle counter.
79
80
  }
@@ -110,9 +111,9 @@ NK_PUBLIC void nk_jaccard_u1_sve(nk_u1x8_t const *a, nk_u1x8_t const *b, nk_size
110
111
  i += words_per_register;
111
112
  ++cycle;
112
113
  } while (i < n_bytes && cycle < 31);
113
- intersection_count += svaddv_u8(all_predicate_b8x, intersection_popcount_u8x);
114
+ intersection_count += nk_svaddv_u8_(all_predicate_b8x, intersection_popcount_u8x);
114
115
  intersection_popcount_u8x = svdup_n_u8(0);
115
- union_count += svaddv_u8(all_predicate_b8x, union_popcount_u8x);
116
+ union_count += nk_svaddv_u8_(all_predicate_b8x, union_popcount_u8x);
116
117
  union_popcount_u8x = svdup_n_u8(0);
117
118
  cycle = 0; // Reset the cycle counter.
118
119
  }