numkong 7.4.4 → 7.5.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +1 -0
- package/binding.gyp +81 -5
- package/c/dispatch_f16.c +23 -0
- package/c/numkong.c +0 -13
- package/include/numkong/attention/sme.h +34 -31
- package/include/numkong/capabilities.h +2 -15
- package/include/numkong/cast/neon.h +15 -0
- package/include/numkong/curved/smef64.h +82 -62
- package/include/numkong/dot/rvvbf16.h +1 -1
- package/include/numkong/dot/rvvhalf.h +1 -1
- package/include/numkong/dot/sve.h +6 -5
- package/include/numkong/dot/svebfdot.h +2 -1
- package/include/numkong/dot/svehalf.h +6 -5
- package/include/numkong/dot/svesdot.h +3 -2
- package/include/numkong/dots/graniteamx.h +733 -0
- package/include/numkong/dots/serial.h +11 -4
- package/include/numkong/dots/sme.h +172 -140
- package/include/numkong/dots/smebi32.h +14 -11
- package/include/numkong/dots/smef64.h +31 -26
- package/include/numkong/dots.h +29 -3
- package/include/numkong/each/serial.h +22 -0
- package/include/numkong/geospatial/haswell.h +1 -1
- package/include/numkong/geospatial/neon.h +1 -1
- package/include/numkong/geospatial/serial.h +1 -1
- package/include/numkong/geospatial/skylake.h +1 -1
- package/include/numkong/maxsim/sme.h +94 -55
- package/include/numkong/mesh/README.md +13 -27
- package/include/numkong/mesh/haswell.h +25 -122
- package/include/numkong/mesh/neon.h +21 -110
- package/include/numkong/mesh/neonbfdot.h +4 -43
- package/include/numkong/mesh/rvv.h +7 -82
- package/include/numkong/mesh/serial.h +48 -53
- package/include/numkong/mesh/skylake.h +7 -123
- package/include/numkong/mesh/v128relaxed.h +9 -93
- package/include/numkong/mesh.h +2 -2
- package/include/numkong/mesh.hpp +35 -96
- package/include/numkong/reduce/neon.h +29 -0
- package/include/numkong/reduce/neonbfdot.h +2 -2
- package/include/numkong/reduce/neonfhm.h +4 -4
- package/include/numkong/reduce/sve.h +52 -0
- package/include/numkong/reduce.h +4 -0
- package/include/numkong/set/sve.h +6 -5
- package/include/numkong/sets/smebi32.h +35 -30
- package/include/numkong/sparse/sve2.h +3 -2
- package/include/numkong/spatial/sve.h +7 -6
- package/include/numkong/spatial/svebfdot.h +7 -4
- package/include/numkong/spatial/svehalf.h +5 -4
- package/include/numkong/spatial/svesdot.h +9 -8
- package/include/numkong/spatials/graniteamx.h +173 -0
- package/include/numkong/spatials/serial.h +22 -0
- package/include/numkong/spatials/sme.h +391 -350
- package/include/numkong/spatials/smef64.h +79 -70
- package/include/numkong/spatials.h +37 -4
- package/include/numkong/types.h +59 -0
- package/javascript/dist/cjs/numkong.js +13 -0
- package/javascript/dist/esm/numkong.js +13 -0
- package/javascript/numkong.c +56 -12
- package/javascript/numkong.ts +13 -0
- package/package.json +7 -7
- package/probes/probe.js +2 -2
- package/wasm/numkong.wasm +0 -0
|
@@ -570,16 +570,10 @@ NK_PUBLIC void nk_rmsd_f32_v128relaxed(nk_f32_t const *a, nk_f32_t const *b, nk_
|
|
|
570
570
|
rotation[0] = 1, rotation[1] = 0, rotation[2] = 0, rotation[3] = 0, rotation[4] = 1, rotation[5] = 0,
|
|
571
571
|
rotation[6] = 0, rotation[7] = 0, rotation[8] = 1;
|
|
572
572
|
if (scale) *scale = 1.0f;
|
|
573
|
+
if (a_centroid) a_centroid[0] = 0, a_centroid[1] = 0, a_centroid[2] = 0;
|
|
574
|
+
if (b_centroid) b_centroid[0] = 0, b_centroid[1] = 0, b_centroid[2] = 0;
|
|
573
575
|
|
|
574
|
-
// Fused single-pass: accumulate centroids and squared differences simultaneously.
|
|
575
|
-
// RMSD = √(E[(a−b)²] − (ā − b̄)²)
|
|
576
576
|
v128_t zero_f64x2 = wasm_f64x2_splat(0.0);
|
|
577
|
-
v128_t sum_a_x_low_f64x2 = zero_f64x2, sum_a_x_high_f64x2 = zero_f64x2;
|
|
578
|
-
v128_t sum_a_y_low_f64x2 = zero_f64x2, sum_a_y_high_f64x2 = zero_f64x2;
|
|
579
|
-
v128_t sum_a_z_low_f64x2 = zero_f64x2, sum_a_z_high_f64x2 = zero_f64x2;
|
|
580
|
-
v128_t sum_b_x_low_f64x2 = zero_f64x2, sum_b_x_high_f64x2 = zero_f64x2;
|
|
581
|
-
v128_t sum_b_y_low_f64x2 = zero_f64x2, sum_b_y_high_f64x2 = zero_f64x2;
|
|
582
|
-
v128_t sum_b_z_low_f64x2 = zero_f64x2, sum_b_z_high_f64x2 = zero_f64x2;
|
|
583
577
|
v128_t sum_sq_x_low_f64x2 = zero_f64x2, sum_sq_x_high_f64x2 = zero_f64x2;
|
|
584
578
|
v128_t sum_sq_y_low_f64x2 = zero_f64x2, sum_sq_y_high_f64x2 = zero_f64x2;
|
|
585
579
|
v128_t sum_sq_z_low_f64x2 = zero_f64x2, sum_sq_z_high_f64x2 = zero_f64x2;
|
|
@@ -590,8 +584,7 @@ NK_PUBLIC void nk_rmsd_f32_v128relaxed(nk_f32_t const *a, nk_f32_t const *b, nk_
|
|
|
590
584
|
nk_deinterleave_f32x4_v128relaxed_(a + index * 3, &a_x_f32x4, &a_y_f32x4, &a_z_f32x4);
|
|
591
585
|
nk_deinterleave_f32x4_v128relaxed_(b + index * 3, &b_x_f32x4, &b_y_f32x4, &b_z_f32x4);
|
|
592
586
|
|
|
593
|
-
// Promote lower and upper halves to f64
|
|
594
|
-
// f32 cancellation in the single-pass formula RMSD = √(E[(a−b)²] − (ā − b̄)²).
|
|
587
|
+
// Promote lower and upper halves to f64 for precision.
|
|
595
588
|
v128_t a_x_low_f64x2 = wasm_f64x2_promote_low_f32x4(a_x_f32x4);
|
|
596
589
|
v128_t a_x_high_f64x2 = wasm_f64x2_promote_low_f32x4(wasm_i32x4_shuffle(a_x_f32x4, a_x_f32x4, 2, 3, 0, 1));
|
|
597
590
|
v128_t a_y_low_f64x2 = wasm_f64x2_promote_low_f32x4(a_y_f32x4);
|
|
@@ -605,21 +598,7 @@ NK_PUBLIC void nk_rmsd_f32_v128relaxed(nk_f32_t const *a, nk_f32_t const *b, nk_
|
|
|
605
598
|
v128_t b_z_low_f64x2 = wasm_f64x2_promote_low_f32x4(b_z_f32x4);
|
|
606
599
|
v128_t b_z_high_f64x2 = wasm_f64x2_promote_low_f32x4(wasm_i32x4_shuffle(b_z_f32x4, b_z_f32x4, 2, 3, 0, 1));
|
|
607
600
|
|
|
608
|
-
// Accumulate
|
|
609
|
-
sum_a_x_low_f64x2 = wasm_f64x2_add(sum_a_x_low_f64x2, a_x_low_f64x2);
|
|
610
|
-
sum_a_x_high_f64x2 = wasm_f64x2_add(sum_a_x_high_f64x2, a_x_high_f64x2);
|
|
611
|
-
sum_a_y_low_f64x2 = wasm_f64x2_add(sum_a_y_low_f64x2, a_y_low_f64x2);
|
|
612
|
-
sum_a_y_high_f64x2 = wasm_f64x2_add(sum_a_y_high_f64x2, a_y_high_f64x2);
|
|
613
|
-
sum_a_z_low_f64x2 = wasm_f64x2_add(sum_a_z_low_f64x2, a_z_low_f64x2);
|
|
614
|
-
sum_a_z_high_f64x2 = wasm_f64x2_add(sum_a_z_high_f64x2, a_z_high_f64x2);
|
|
615
|
-
sum_b_x_low_f64x2 = wasm_f64x2_add(sum_b_x_low_f64x2, b_x_low_f64x2);
|
|
616
|
-
sum_b_x_high_f64x2 = wasm_f64x2_add(sum_b_x_high_f64x2, b_x_high_f64x2);
|
|
617
|
-
sum_b_y_low_f64x2 = wasm_f64x2_add(sum_b_y_low_f64x2, b_y_low_f64x2);
|
|
618
|
-
sum_b_y_high_f64x2 = wasm_f64x2_add(sum_b_y_high_f64x2, b_y_high_f64x2);
|
|
619
|
-
sum_b_z_low_f64x2 = wasm_f64x2_add(sum_b_z_low_f64x2, b_z_low_f64x2);
|
|
620
|
-
sum_b_z_high_f64x2 = wasm_f64x2_add(sum_b_z_high_f64x2, b_z_high_f64x2);
|
|
621
|
-
|
|
622
|
-
// Accumulate squared differences in f64 — deltas computed in f64 for precision.
|
|
601
|
+
// Accumulate squared differences in f64.
|
|
623
602
|
v128_t dx_low_f64x2 = wasm_f64x2_sub(a_x_low_f64x2, b_x_low_f64x2);
|
|
624
603
|
v128_t dx_high_f64x2 = wasm_f64x2_sub(a_x_high_f64x2, b_x_high_f64x2);
|
|
625
604
|
v128_t dy_low_f64x2 = wasm_f64x2_sub(a_y_low_f64x2, b_y_low_f64x2);
|
|
@@ -635,12 +614,6 @@ NK_PUBLIC void nk_rmsd_f32_v128relaxed(nk_f32_t const *a, nk_f32_t const *b, nk_
|
|
|
635
614
|
sum_sq_z_high_f64x2 = wasm_f64x2_relaxed_madd(dz_high_f64x2, dz_high_f64x2, sum_sq_z_high_f64x2);
|
|
636
615
|
}
|
|
637
616
|
|
|
638
|
-
nk_f64_t sum_a_x = nk_hsum_f64x2_v128relaxed_(wasm_f64x2_add(sum_a_x_low_f64x2, sum_a_x_high_f64x2));
|
|
639
|
-
nk_f64_t sum_a_y = nk_hsum_f64x2_v128relaxed_(wasm_f64x2_add(sum_a_y_low_f64x2, sum_a_y_high_f64x2));
|
|
640
|
-
nk_f64_t sum_a_z = nk_hsum_f64x2_v128relaxed_(wasm_f64x2_add(sum_a_z_low_f64x2, sum_a_z_high_f64x2));
|
|
641
|
-
nk_f64_t sum_b_x = nk_hsum_f64x2_v128relaxed_(wasm_f64x2_add(sum_b_x_low_f64x2, sum_b_x_high_f64x2));
|
|
642
|
-
nk_f64_t sum_b_y = nk_hsum_f64x2_v128relaxed_(wasm_f64x2_add(sum_b_y_low_f64x2, sum_b_y_high_f64x2));
|
|
643
|
-
nk_f64_t sum_b_z = nk_hsum_f64x2_v128relaxed_(wasm_f64x2_add(sum_b_z_low_f64x2, sum_b_z_high_f64x2));
|
|
644
617
|
nk_f64_t sum_sq_x = nk_hsum_f64x2_v128relaxed_(wasm_f64x2_add(sum_sq_x_low_f64x2, sum_sq_x_high_f64x2));
|
|
645
618
|
nk_f64_t sum_sq_y = nk_hsum_f64x2_v128relaxed_(wasm_f64x2_add(sum_sq_y_low_f64x2, sum_sq_y_high_f64x2));
|
|
646
619
|
nk_f64_t sum_sq_z = nk_hsum_f64x2_v128relaxed_(wasm_f64x2_add(sum_sq_z_low_f64x2, sum_sq_z_high_f64x2));
|
|
@@ -649,45 +622,25 @@ NK_PUBLIC void nk_rmsd_f32_v128relaxed(nk_f32_t const *a, nk_f32_t const *b, nk_
|
|
|
649
622
|
for (; index < n; ++index) {
|
|
650
623
|
nk_f64_t ax = a[index * 3 + 0], ay = a[index * 3 + 1], az = a[index * 3 + 2];
|
|
651
624
|
nk_f64_t bx = b[index * 3 + 0], by = b[index * 3 + 1], bz = b[index * 3 + 2];
|
|
652
|
-
sum_a_x += ax, sum_a_y += ay, sum_a_z += az;
|
|
653
|
-
sum_b_x += bx, sum_b_y += by, sum_b_z += bz;
|
|
654
625
|
nk_f64_t dx = ax - bx, dy = ay - by, dz = az - bz;
|
|
655
626
|
sum_sq_x += dx * dx, sum_sq_y += dy * dy, sum_sq_z += dz * dz;
|
|
656
627
|
}
|
|
657
628
|
|
|
658
|
-
|
|
659
|
-
nk_f64_t centroid_a_x = sum_a_x * inv_points_count, centroid_a_y = sum_a_y * inv_points_count,
|
|
660
|
-
centroid_a_z = sum_a_z * inv_points_count;
|
|
661
|
-
nk_f64_t centroid_b_x = sum_b_x * inv_points_count, centroid_b_y = sum_b_y * inv_points_count,
|
|
662
|
-
centroid_b_z = sum_b_z * inv_points_count;
|
|
663
|
-
if (a_centroid)
|
|
664
|
-
a_centroid[0] = (nk_f32_t)centroid_a_x, a_centroid[1] = (nk_f32_t)centroid_a_y,
|
|
665
|
-
a_centroid[2] = (nk_f32_t)centroid_a_z;
|
|
666
|
-
if (b_centroid)
|
|
667
|
-
b_centroid[0] = (nk_f32_t)centroid_b_x, b_centroid[1] = (nk_f32_t)centroid_b_y,
|
|
668
|
-
b_centroid[2] = (nk_f32_t)centroid_b_z;
|
|
669
|
-
|
|
670
|
-
nk_f64_t sum_squared = sum_sq_x + sum_sq_y + sum_sq_z;
|
|
671
|
-
nk_f64_t mean_diff_x = centroid_a_x - centroid_b_x;
|
|
672
|
-
nk_f64_t mean_diff_y = centroid_a_y - centroid_b_y;
|
|
673
|
-
nk_f64_t mean_diff_z = centroid_a_z - centroid_b_z;
|
|
674
|
-
nk_f64_t mean_diff_sq = mean_diff_x * mean_diff_x + mean_diff_y * mean_diff_y + mean_diff_z * mean_diff_z;
|
|
675
|
-
*result = nk_f64_sqrt_v128relaxed(sum_squared * inv_points_count - mean_diff_sq);
|
|
629
|
+
*result = nk_f64_sqrt_v128relaxed((sum_sq_x + sum_sq_y + sum_sq_z) / (nk_f64_t)n);
|
|
676
630
|
}
|
|
677
631
|
|
|
678
632
|
NK_PUBLIC void nk_rmsd_f64_v128relaxed(nk_f64_t const *a, nk_f64_t const *b, nk_size_t n, nk_f64_t *a_centroid,
|
|
679
633
|
nk_f64_t *b_centroid, nk_f64_t *rotation, nk_f64_t *scale, nk_f64_t *result) {
|
|
680
|
-
// RMSD uses identity rotation and scale=1.0
|
|
681
634
|
if (rotation)
|
|
682
635
|
rotation[0] = 1, rotation[1] = 0, rotation[2] = 0, rotation[3] = 0, rotation[4] = 1, rotation[5] = 0,
|
|
683
636
|
rotation[6] = 0, rotation[7] = 0, rotation[8] = 1;
|
|
684
637
|
if (scale) *scale = 1.0;
|
|
638
|
+
if (a_centroid) a_centroid[0] = 0, a_centroid[1] = 0, a_centroid[2] = 0;
|
|
639
|
+
if (b_centroid) b_centroid[0] = 0, b_centroid[1] = 0, b_centroid[2] = 0;
|
|
685
640
|
|
|
686
641
|
v128_t const zeros_f64x2 = wasm_f64x2_splat(0);
|
|
687
642
|
|
|
688
|
-
// Accumulators for
|
|
689
|
-
v128_t sum_a_x_f64x2 = zeros_f64x2, sum_a_y_f64x2 = zeros_f64x2, sum_a_z_f64x2 = zeros_f64x2;
|
|
690
|
-
v128_t sum_b_x_f64x2 = zeros_f64x2, sum_b_y_f64x2 = zeros_f64x2, sum_b_z_f64x2 = zeros_f64x2;
|
|
643
|
+
// Accumulators for squared differences
|
|
691
644
|
v128_t sum_squared_x_f64x2 = zeros_f64x2, sum_squared_y_f64x2 = zeros_f64x2, sum_squared_z_f64x2 = zeros_f64x2;
|
|
692
645
|
|
|
693
646
|
v128_t a_x_f64x2, a_y_f64x2, a_z_f64x2, b_x_f64x2, b_y_f64x2, b_z_f64x2;
|
|
@@ -698,13 +651,6 @@ NK_PUBLIC void nk_rmsd_f64_v128relaxed(nk_f64_t const *a, nk_f64_t const *b, nk_
|
|
|
698
651
|
nk_deinterleave_f64x2_v128relaxed_(a + i * 3, &a_x_f64x2, &a_y_f64x2, &a_z_f64x2);
|
|
699
652
|
nk_deinterleave_f64x2_v128relaxed_(b + i * 3, &b_x_f64x2, &b_y_f64x2, &b_z_f64x2);
|
|
700
653
|
|
|
701
|
-
sum_a_x_f64x2 = wasm_f64x2_add(sum_a_x_f64x2, a_x_f64x2);
|
|
702
|
-
sum_a_y_f64x2 = wasm_f64x2_add(sum_a_y_f64x2, a_y_f64x2);
|
|
703
|
-
sum_a_z_f64x2 = wasm_f64x2_add(sum_a_z_f64x2, a_z_f64x2);
|
|
704
|
-
sum_b_x_f64x2 = wasm_f64x2_add(sum_b_x_f64x2, b_x_f64x2);
|
|
705
|
-
sum_b_y_f64x2 = wasm_f64x2_add(sum_b_y_f64x2, b_y_f64x2);
|
|
706
|
-
sum_b_z_f64x2 = wasm_f64x2_add(sum_b_z_f64x2, b_z_f64x2);
|
|
707
|
-
|
|
708
654
|
v128_t delta_x_f64x2 = wasm_f64x2_sub(a_x_f64x2, b_x_f64x2);
|
|
709
655
|
v128_t delta_y_f64x2 = wasm_f64x2_sub(a_y_f64x2, b_y_f64x2);
|
|
710
656
|
v128_t delta_z_f64x2 = wasm_f64x2_sub(a_z_f64x2, b_z_f64x2);
|
|
@@ -715,12 +661,6 @@ NK_PUBLIC void nk_rmsd_f64_v128relaxed(nk_f64_t const *a, nk_f64_t const *b, nk_
|
|
|
715
661
|
}
|
|
716
662
|
|
|
717
663
|
// Reduce vectors to scalars.
|
|
718
|
-
nk_f64_t total_ax = nk_reduce_stable_f64x2_v128relaxed_(sum_a_x_f64x2), total_ax_compensation = 0.0;
|
|
719
|
-
nk_f64_t total_ay = nk_reduce_stable_f64x2_v128relaxed_(sum_a_y_f64x2), total_ay_compensation = 0.0;
|
|
720
|
-
nk_f64_t total_az = nk_reduce_stable_f64x2_v128relaxed_(sum_a_z_f64x2), total_az_compensation = 0.0;
|
|
721
|
-
nk_f64_t total_bx = nk_reduce_stable_f64x2_v128relaxed_(sum_b_x_f64x2), total_bx_compensation = 0.0;
|
|
722
|
-
nk_f64_t total_by = nk_reduce_stable_f64x2_v128relaxed_(sum_b_y_f64x2), total_by_compensation = 0.0;
|
|
723
|
-
nk_f64_t total_bz = nk_reduce_stable_f64x2_v128relaxed_(sum_b_z_f64x2), total_bz_compensation = 0.0;
|
|
724
664
|
nk_f64_t total_squared_x = nk_reduce_stable_f64x2_v128relaxed_(sum_squared_x_f64x2),
|
|
725
665
|
total_squared_x_compensation = 0.0;
|
|
726
666
|
nk_f64_t total_squared_y = nk_reduce_stable_f64x2_v128relaxed_(sum_squared_y_f64x2),
|
|
@@ -732,40 +672,16 @@ NK_PUBLIC void nk_rmsd_f64_v128relaxed(nk_f64_t const *a, nk_f64_t const *b, nk_
|
|
|
732
672
|
for (; i < n; ++i) {
|
|
733
673
|
nk_f64_t ax = a[i * 3 + 0], ay = a[i * 3 + 1], az = a[i * 3 + 2];
|
|
734
674
|
nk_f64_t bx = b[i * 3 + 0], by = b[i * 3 + 1], bz = b[i * 3 + 2];
|
|
735
|
-
nk_accumulate_sum_f64_(&total_ax, &total_ax_compensation, ax);
|
|
736
|
-
nk_accumulate_sum_f64_(&total_ay, &total_ay_compensation, ay);
|
|
737
|
-
nk_accumulate_sum_f64_(&total_az, &total_az_compensation, az);
|
|
738
|
-
nk_accumulate_sum_f64_(&total_bx, &total_bx_compensation, bx);
|
|
739
|
-
nk_accumulate_sum_f64_(&total_by, &total_by_compensation, by);
|
|
740
|
-
nk_accumulate_sum_f64_(&total_bz, &total_bz_compensation, bz);
|
|
741
675
|
nk_f64_t delta_x = ax - bx, delta_y = ay - by, delta_z = az - bz;
|
|
742
676
|
nk_accumulate_square_f64_(&total_squared_x, &total_squared_x_compensation, delta_x);
|
|
743
677
|
nk_accumulate_square_f64_(&total_squared_y, &total_squared_y_compensation, delta_y);
|
|
744
678
|
nk_accumulate_square_f64_(&total_squared_z, &total_squared_z_compensation, delta_z);
|
|
745
679
|
}
|
|
746
680
|
|
|
747
|
-
total_ax += total_ax_compensation, total_ay += total_ay_compensation, total_az += total_az_compensation;
|
|
748
|
-
total_bx += total_bx_compensation, total_by += total_by_compensation, total_bz += total_bz_compensation;
|
|
749
681
|
total_squared_x += total_squared_x_compensation, total_squared_y += total_squared_y_compensation,
|
|
750
682
|
total_squared_z += total_squared_z_compensation;
|
|
751
683
|
|
|
752
|
-
|
|
753
|
-
nk_f64_t inv_points_count = 1.0 / (nk_f64_t)n;
|
|
754
|
-
nk_f64_t centroid_a_x = total_ax * inv_points_count, centroid_a_y = total_ay * inv_points_count,
|
|
755
|
-
centroid_a_z = total_az * inv_points_count;
|
|
756
|
-
nk_f64_t centroid_b_x = total_bx * inv_points_count, centroid_b_y = total_by * inv_points_count,
|
|
757
|
-
centroid_b_z = total_bz * inv_points_count;
|
|
758
|
-
if (a_centroid) a_centroid[0] = centroid_a_x, a_centroid[1] = centroid_a_y, a_centroid[2] = centroid_a_z;
|
|
759
|
-
if (b_centroid) b_centroid[0] = centroid_b_x, b_centroid[1] = centroid_b_y, b_centroid[2] = centroid_b_z;
|
|
760
|
-
|
|
761
|
-
// Compute RMSD
|
|
762
|
-
nk_f64_t mean_diff_x = centroid_a_x - centroid_b_x;
|
|
763
|
-
nk_f64_t mean_diff_y = centroid_a_y - centroid_b_y;
|
|
764
|
-
nk_f64_t mean_diff_z = centroid_a_z - centroid_b_z;
|
|
765
|
-
nk_f64_t sum_squared = total_squared_x + total_squared_y + total_squared_z;
|
|
766
|
-
nk_f64_t mean_diff_sq = mean_diff_x * mean_diff_x + mean_diff_y * mean_diff_y + mean_diff_z * mean_diff_z;
|
|
767
|
-
|
|
768
|
-
*result = nk_f64_sqrt_v128relaxed(sum_squared * inv_points_count - mean_diff_sq);
|
|
684
|
+
*result = nk_f64_sqrt_v128relaxed((total_squared_x + total_squared_y + total_squared_z) / (nk_f64_t)n);
|
|
769
685
|
}
|
|
770
686
|
|
|
771
687
|
NK_PUBLIC void nk_kabsch_f32_v128relaxed(nk_f32_t const *a, nk_f32_t const *b, nk_size_t n, nk_f32_t *a_centroid,
|
package/include/numkong/mesh.h
CHANGED
|
@@ -6,7 +6,7 @@
|
|
|
6
6
|
*
|
|
7
7
|
* Contains:
|
|
8
8
|
*
|
|
9
|
-
* - Root Mean Square Deviation (RMSD)
|
|
9
|
+
* - Root Mean Square Deviation (RMSD) of raw point differences
|
|
10
10
|
* - Kabsch algorithm for optimal rigid body alignment (rotation only)
|
|
11
11
|
* - Umeyama algorithm for similarity transform (rotation + uniform scaling)
|
|
12
12
|
*
|
|
@@ -48,7 +48,7 @@
|
|
|
48
48
|
*
|
|
49
49
|
* @section algorithm_overview Algorithm Overview
|
|
50
50
|
*
|
|
51
|
-
* - RMSD:
|
|
51
|
+
* - RMSD: Raw √(Σ‖aᵢ − bᵢ‖² / n) without centering or alignment. R = identity, scale = 1.0, centroids zeroed
|
|
52
52
|
* - Kabsch: Finds optimal rotation R minimizing ‖R × (a - ā) - (b - b̄)‖. scale = 1.0
|
|
53
53
|
* - Umeyama: Finds optimal rotation R and scale c minimizing ‖c × R × (a - ā) - (b - b̄)‖
|
|
54
54
|
*
|
package/include/numkong/mesh.hpp
CHANGED
|
@@ -354,74 +354,30 @@ void rmsd( //
|
|
|
354
354
|
else if constexpr (std::is_same_v<in_type_, bf16_t> && simd)
|
|
355
355
|
nk_rmsd_bf16(&a->raw_, &b->raw_, n, &a_centroid->raw_, &b_centroid->raw_, &rotation->raw_,
|
|
356
356
|
scale ? &scale->raw_ : nullptr, &metric->raw_);
|
|
357
|
-
// Scalar fallback
|
|
357
|
+
// Scalar fallback: raw √(Σ‖aᵢ − bᵢ‖² / n), no centering
|
|
358
358
|
else {
|
|
359
|
-
// Step 1: Compute centroids
|
|
360
|
-
metric_type_ sum_a_x {}, sum_a_y {}, sum_a_z {};
|
|
361
|
-
metric_type_ sum_b_x {}, sum_b_y {}, sum_b_z {};
|
|
362
|
-
metric_type_ val_a_x, val_a_y, val_a_z, val_b_x, val_b_y, val_b_z;
|
|
363
|
-
|
|
364
|
-
for (std::size_t i = 0; i < n; i++) {
|
|
365
|
-
val_a_x = metric_type_(a[i * 3 + 0]);
|
|
366
|
-
val_a_y = metric_type_(a[i * 3 + 1]);
|
|
367
|
-
val_a_z = metric_type_(a[i * 3 + 2]);
|
|
368
|
-
val_b_x = metric_type_(b[i * 3 + 0]);
|
|
369
|
-
val_b_y = metric_type_(b[i * 3 + 1]);
|
|
370
|
-
val_b_z = metric_type_(b[i * 3 + 2]);
|
|
371
|
-
sum_a_x = sum_a_x + val_a_x;
|
|
372
|
-
sum_a_y = sum_a_y + val_a_y;
|
|
373
|
-
sum_a_z = sum_a_z + val_a_z;
|
|
374
|
-
sum_b_x = sum_b_x + val_b_x;
|
|
375
|
-
sum_b_y = sum_b_y + val_b_y;
|
|
376
|
-
sum_b_z = sum_b_z + val_b_z;
|
|
377
|
-
}
|
|
378
|
-
|
|
379
|
-
metric_type_ inv_n = metric_type_(1.0) / metric_type_(static_cast<double>(n));
|
|
380
|
-
metric_type_ centroid_a_x = sum_a_x * inv_n;
|
|
381
|
-
metric_type_ centroid_a_y = sum_a_y * inv_n;
|
|
382
|
-
metric_type_ centroid_a_z = sum_a_z * inv_n;
|
|
383
|
-
metric_type_ centroid_b_x = sum_b_x * inv_n;
|
|
384
|
-
metric_type_ centroid_b_y = sum_b_y * inv_n;
|
|
385
|
-
metric_type_ centroid_b_z = sum_b_z * inv_n;
|
|
386
|
-
|
|
387
|
-
// Step 2: Store centroids if requested
|
|
388
359
|
if (a_centroid)
|
|
389
|
-
a_centroid[0] = transform_type_(centroid_a_x), a_centroid[1] = transform_type_(centroid_a_y),
|
|
390
|
-
a_centroid[2] = transform_type_(centroid_a_z);
|
|
360
|
+
a_centroid[0] = transform_type_(0.0), a_centroid[1] = transform_type_(0.0),
|
|
361
|
+
a_centroid[2] = transform_type_(0.0);
|
|
391
362
|
if (b_centroid)
|
|
392
|
-
b_centroid[0] = transform_type_(centroid_b_x), b_centroid[1] = transform_type_(centroid_b_y),
|
|
393
|
-
b_centroid[2] = transform_type_(centroid_b_z);
|
|
394
|
-
|
|
395
|
-
// Step 3: RMSD uses identity rotation and scale=1.0
|
|
363
|
+
b_centroid[0] = transform_type_(0.0), b_centroid[1] = transform_type_(0.0),
|
|
364
|
+
b_centroid[2] = transform_type_(0.0);
|
|
396
365
|
if (rotation) {
|
|
397
|
-
rotation[0] = transform_type_(1.0);
|
|
398
|
-
rotation[1] = transform_type_(0.0);
|
|
399
|
-
rotation[2] = transform_type_(0.0);
|
|
400
|
-
rotation[3] = transform_type_(0.0);
|
|
401
|
-
rotation[4] = transform_type_(1.0);
|
|
402
|
-
rotation[5] = transform_type_(0.0);
|
|
403
|
-
rotation[6] = transform_type_(0.0);
|
|
404
|
-
rotation[7] = transform_type_(0.0);
|
|
405
|
-
rotation[8] = transform_type_(1.0);
|
|
366
|
+
rotation[0] = transform_type_(1.0), rotation[1] = transform_type_(0.0), rotation[2] = transform_type_(0.0);
|
|
367
|
+
rotation[3] = transform_type_(0.0), rotation[4] = transform_type_(1.0), rotation[5] = transform_type_(0.0);
|
|
368
|
+
rotation[6] = transform_type_(0.0), rotation[7] = transform_type_(0.0), rotation[8] = transform_type_(1.0);
|
|
406
369
|
}
|
|
407
370
|
if (scale) *scale = transform_type_(1.0);
|
|
408
371
|
|
|
409
|
-
// Step 4: Compute RMSD between centered point clouds
|
|
410
372
|
metric_type_ sum_squared {};
|
|
411
373
|
for (std::size_t i = 0; i < n; i++) {
|
|
412
|
-
|
|
413
|
-
|
|
414
|
-
|
|
415
|
-
val_b_x = metric_type_(b[i * 3 + 0]);
|
|
416
|
-
val_b_y = metric_type_(b[i * 3 + 1]);
|
|
417
|
-
val_b_z = metric_type_(b[i * 3 + 2]);
|
|
418
|
-
metric_type_ dx = (val_a_x - centroid_a_x) - (val_b_x - centroid_b_x);
|
|
419
|
-
metric_type_ dy = (val_a_y - centroid_a_y) - (val_b_y - centroid_b_y);
|
|
420
|
-
metric_type_ dz = (val_a_z - centroid_a_z) - (val_b_z - centroid_b_z);
|
|
374
|
+
metric_type_ dx = metric_type_(a[i * 3 + 0]) - metric_type_(b[i * 3 + 0]);
|
|
375
|
+
metric_type_ dy = metric_type_(a[i * 3 + 1]) - metric_type_(b[i * 3 + 1]);
|
|
376
|
+
metric_type_ dz = metric_type_(a[i * 3 + 2]) - metric_type_(b[i * 3 + 2]);
|
|
421
377
|
sum_squared = sum_squared + dx * dx + dy * dy + dz * dz;
|
|
422
378
|
}
|
|
423
379
|
|
|
424
|
-
*metric = (sum_squared * inv_n).sqrt();
|
|
380
|
+
*metric = (sum_squared / metric_type_(static_cast<double>(n))).sqrt();
|
|
425
381
|
}
|
|
426
382
|
}
|
|
427
383
|
|
|
@@ -470,18 +426,12 @@ void kabsch( //
|
|
|
470
426
|
metric_type_ val_a_x, val_a_y, val_a_z, val_b_x, val_b_y, val_b_z;
|
|
471
427
|
|
|
472
428
|
for (std::size_t i = 0; i < n; i++) {
|
|
473
|
-
val_a_x = metric_type_(a[i * 3 + 0]);
|
|
474
|
-
val_a_y = metric_type_(a[i * 3 + 1]);
|
|
429
|
+
val_a_x = metric_type_(a[i * 3 + 0]), val_a_y = metric_type_(a[i * 3 + 1]),
|
|
475
430
|
val_a_z = metric_type_(a[i * 3 + 2]);
|
|
476
|
-
val_b_x = metric_type_(b[i * 3 + 0]);
|
|
477
|
-
val_b_y = metric_type_(b[i * 3 + 1]);
|
|
431
|
+
val_b_x = metric_type_(b[i * 3 + 0]), val_b_y = metric_type_(b[i * 3 + 1]),
|
|
478
432
|
val_b_z = metric_type_(b[i * 3 + 2]);
|
|
479
|
-
sum_a_x = sum_a_x + val_a_x;
|
|
480
|
-
|
|
481
|
-
sum_a_z = sum_a_z + val_a_z;
|
|
482
|
-
sum_b_x = sum_b_x + val_b_x;
|
|
483
|
-
sum_b_y = sum_b_y + val_b_y;
|
|
484
|
-
sum_b_z = sum_b_z + val_b_z;
|
|
433
|
+
sum_a_x = sum_a_x + val_a_x, sum_a_y = sum_a_y + val_a_y, sum_a_z = sum_a_z + val_a_z;
|
|
434
|
+
sum_b_x = sum_b_x + val_b_x, sum_b_y = sum_b_y + val_b_y, sum_b_z = sum_b_z + val_b_z;
|
|
485
435
|
}
|
|
486
436
|
|
|
487
437
|
metric_type_ inv_n = metric_type_(1.0) / metric_type_(static_cast<double>(n));
|
|
@@ -503,11 +453,9 @@ void kabsch( //
|
|
|
503
453
|
// Step 2: Build 3x3 covariance matrix H = (A - A_bar)^T x (B - B_bar)
|
|
504
454
|
metric_type_ cross_covariance[9] = {};
|
|
505
455
|
for (std::size_t i = 0; i < n; i++) {
|
|
506
|
-
val_a_x = metric_type_(a[i * 3 + 0]) - centroid_a_x;
|
|
507
|
-
val_a_y = metric_type_(a[i * 3 + 1]) - centroid_a_y;
|
|
456
|
+
val_a_x = metric_type_(a[i * 3 + 0]) - centroid_a_x, val_a_y = metric_type_(a[i * 3 + 1]) - centroid_a_y,
|
|
508
457
|
val_a_z = metric_type_(a[i * 3 + 2]) - centroid_a_z;
|
|
509
|
-
val_b_x = metric_type_(b[i * 3 + 0]) - centroid_b_x;
|
|
510
|
-
val_b_y = metric_type_(b[i * 3 + 1]) - centroid_b_y;
|
|
458
|
+
val_b_x = metric_type_(b[i * 3 + 0]) - centroid_b_x, val_b_y = metric_type_(b[i * 3 + 1]) - centroid_b_y,
|
|
511
459
|
val_b_z = metric_type_(b[i * 3 + 2]) - centroid_b_z;
|
|
512
460
|
cross_covariance[0] = cross_covariance[0] + val_a_x * val_b_x;
|
|
513
461
|
cross_covariance[1] = cross_covariance[1] + val_a_x * val_b_y;
|
|
@@ -563,11 +511,11 @@ void kabsch( //
|
|
|
563
511
|
metric_type_ sum_squared {};
|
|
564
512
|
for (std::size_t i = 0; i < n; i++) {
|
|
565
513
|
metric_type_ point_a[3], point_b[3], rotated_point_a[3];
|
|
566
|
-
point_a[0] = metric_type_(a[i * 3 + 0]) - centroid_a_x;
|
|
567
|
-
point_a[1] = metric_type_(a[i * 3 + 1]) - centroid_a_y;
|
|
514
|
+
point_a[0] = metric_type_(a[i * 3 + 0]) - centroid_a_x,
|
|
515
|
+
point_a[1] = metric_type_(a[i * 3 + 1]) - centroid_a_y,
|
|
568
516
|
point_a[2] = metric_type_(a[i * 3 + 2]) - centroid_a_z;
|
|
569
|
-
point_b[0] = metric_type_(b[i * 3 + 0]) - centroid_b_x;
|
|
570
|
-
point_b[1] = metric_type_(b[i * 3 + 1]) - centroid_b_y;
|
|
517
|
+
point_b[0] = metric_type_(b[i * 3 + 0]) - centroid_b_x,
|
|
518
|
+
point_b[1] = metric_type_(b[i * 3 + 1]) - centroid_b_y,
|
|
571
519
|
point_b[2] = metric_type_(b[i * 3 + 2]) - centroid_b_z;
|
|
572
520
|
rotated_point_a[0] = rotation_matrix[0] * point_a[0] + rotation_matrix[1] * point_a[1] +
|
|
573
521
|
rotation_matrix[2] * point_a[2];
|
|
@@ -628,18 +576,12 @@ void umeyama(in_type_ const *a, in_type_ const *b, std::size_t n, transform_type
|
|
|
628
576
|
metric_type_ val_a_x, val_a_y, val_a_z, val_b_x, val_b_y, val_b_z;
|
|
629
577
|
|
|
630
578
|
for (std::size_t i = 0; i < n; i++) {
|
|
631
|
-
val_a_x = metric_type_(a[i * 3 + 0]);
|
|
632
|
-
val_a_y = metric_type_(a[i * 3 + 1]);
|
|
579
|
+
val_a_x = metric_type_(a[i * 3 + 0]), val_a_y = metric_type_(a[i * 3 + 1]),
|
|
633
580
|
val_a_z = metric_type_(a[i * 3 + 2]);
|
|
634
|
-
val_b_x = metric_type_(b[i * 3 + 0]);
|
|
635
|
-
val_b_y = metric_type_(b[i * 3 + 1]);
|
|
581
|
+
val_b_x = metric_type_(b[i * 3 + 0]), val_b_y = metric_type_(b[i * 3 + 1]),
|
|
636
582
|
val_b_z = metric_type_(b[i * 3 + 2]);
|
|
637
|
-
sum_a_x = sum_a_x + val_a_x;
|
|
638
|
-
|
|
639
|
-
sum_a_z = sum_a_z + val_a_z;
|
|
640
|
-
sum_b_x = sum_b_x + val_b_x;
|
|
641
|
-
sum_b_y = sum_b_y + val_b_y;
|
|
642
|
-
sum_b_z = sum_b_z + val_b_z;
|
|
583
|
+
sum_a_x = sum_a_x + val_a_x, sum_a_y = sum_a_y + val_a_y, sum_a_z = sum_a_z + val_a_z;
|
|
584
|
+
sum_b_x = sum_b_x + val_b_x, sum_b_y = sum_b_y + val_b_y, sum_b_z = sum_b_z + val_b_z;
|
|
643
585
|
}
|
|
644
586
|
|
|
645
587
|
metric_type_ inv_n = metric_type_(1.0) / metric_type_(static_cast<double>(n));
|
|
@@ -650,16 +592,13 @@ void umeyama(in_type_ const *a, in_type_ const *b, std::size_t n, transform_type
|
|
|
650
592
|
metric_type_ centroid_b_y = sum_b_y * inv_n;
|
|
651
593
|
metric_type_ centroid_b_z = sum_b_z * inv_n;
|
|
652
594
|
|
|
653
|
-
if (a_centroid)
|
|
654
|
-
a_centroid[0] = transform_type_(centroid_a_x);
|
|
655
|
-
a_centroid[1] = transform_type_(centroid_a_y);
|
|
595
|
+
if (a_centroid)
|
|
596
|
+
a_centroid[0] = transform_type_(centroid_a_x), a_centroid[1] = transform_type_(centroid_a_y),
|
|
656
597
|
a_centroid[2] = transform_type_(centroid_a_z);
|
|
657
|
-
|
|
658
|
-
if (b_centroid)
|
|
659
|
-
b_centroid[0] = transform_type_(centroid_b_x);
|
|
660
|
-
b_centroid[1] = transform_type_(centroid_b_y);
|
|
598
|
+
|
|
599
|
+
if (b_centroid)
|
|
600
|
+
b_centroid[0] = transform_type_(centroid_b_x), b_centroid[1] = transform_type_(centroid_b_y),
|
|
661
601
|
b_centroid[2] = transform_type_(centroid_b_z);
|
|
662
|
-
}
|
|
663
602
|
|
|
664
603
|
// Step 2: Build covariance matrix H and compute variance of A
|
|
665
604
|
metric_type_ cross_covariance[9] = {};
|
|
@@ -733,11 +672,11 @@ void umeyama(in_type_ const *a, in_type_ const *b, std::size_t n, transform_type
|
|
|
733
672
|
metric_type_ sum_squared {};
|
|
734
673
|
for (std::size_t i = 0; i < n; i++) {
|
|
735
674
|
metric_type_ point_a[3], point_b[3], rotated_point_a[3];
|
|
736
|
-
point_a[0] = metric_type_(a[i * 3 + 0]) - centroid_a_x;
|
|
737
|
-
point_a[1] = metric_type_(a[i * 3 + 1]) - centroid_a_y;
|
|
675
|
+
point_a[0] = metric_type_(a[i * 3 + 0]) - centroid_a_x,
|
|
676
|
+
point_a[1] = metric_type_(a[i * 3 + 1]) - centroid_a_y,
|
|
738
677
|
point_a[2] = metric_type_(a[i * 3 + 2]) - centroid_a_z;
|
|
739
|
-
point_b[0] = metric_type_(b[i * 3 + 0]) - centroid_b_x;
|
|
740
|
-
point_b[1] = metric_type_(b[i * 3 + 1]) - centroid_b_y;
|
|
678
|
+
point_b[0] = metric_type_(b[i * 3 + 0]) - centroid_b_x,
|
|
679
|
+
point_b[1] = metric_type_(b[i * 3 + 1]) - centroid_b_y,
|
|
741
680
|
point_b[2] = metric_type_(b[i * 3 + 2]) - centroid_b_z;
|
|
742
681
|
rotated_point_a[0] = scale_factor * (rotation_matrix[0] * point_a[0] + rotation_matrix[1] * point_a[1] +
|
|
743
682
|
rotation_matrix[2] * point_a[2]);
|
|
@@ -3936,6 +3936,35 @@ NK_PUBLIC void nk_reduce_moments_f16_neon( //
|
|
|
3936
3936
|
else nk_reduce_moments_f16_serial(data_ptr, count, stride_bytes, sum_ptr, sumsq_ptr);
|
|
3937
3937
|
}
|
|
3938
3938
|
|
|
3939
|
+
NK_INTERNAL void nk_reduce_moments_u1_neon_contiguous_( //
|
|
3940
|
+
nk_u1x8_t const *data_ptr, nk_size_t count, //
|
|
3941
|
+
nk_u64_t *sum_ptr, nk_u64_t *sumsq_ptr) {
|
|
3942
|
+
nk_size_t byte_count = nk_size_divide_round_up_(count, NK_BITS_PER_BYTE);
|
|
3943
|
+
nk_u64_t sum = 0;
|
|
3944
|
+
nk_size_t idx = 0;
|
|
3945
|
+
// Each vcntq_u8 produces values 0-8 per lane; accumulate at u8 level
|
|
3946
|
+
// for up to 31 iterations (31 × 8 = 248, fits in u8) before widening.
|
|
3947
|
+
while (idx + 16 <= byte_count) {
|
|
3948
|
+
uint8x16_t popcount_u8x16 = vdupq_n_u8(0);
|
|
3949
|
+
for (nk_size_t cycle = 0; cycle < 31 && idx + 16 <= byte_count; ++cycle, idx += 16) {
|
|
3950
|
+
uint8x16_t data_u8x16 = vld1q_u8((nk_u8_t const *)data_ptr + idx);
|
|
3951
|
+
popcount_u8x16 = vaddq_u8(popcount_u8x16, vcntq_u8(data_u8x16));
|
|
3952
|
+
}
|
|
3953
|
+
sum += (nk_u64_t)vaddlvq_u8(popcount_u8x16);
|
|
3954
|
+
}
|
|
3955
|
+
for (; idx < byte_count; ++idx) sum += nk_u1x8_popcount_(((nk_u8_t const *)data_ptr)[idx]);
|
|
3956
|
+
*sum_ptr = sum, *sumsq_ptr = sum;
|
|
3957
|
+
}
|
|
3958
|
+
|
|
3959
|
+
NK_PUBLIC void nk_reduce_moments_u1_neon( //
|
|
3960
|
+
nk_u1x8_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
|
|
3961
|
+
nk_u64_t *sum_ptr, nk_u64_t *sumsq_ptr) {
|
|
3962
|
+
count = nk_size_round_up_to_multiple_(count, 8);
|
|
3963
|
+
if (count == 0) *sum_ptr = 0, *sumsq_ptr = 0;
|
|
3964
|
+
else if (stride_bytes == 1) nk_reduce_moments_u1_neon_contiguous_(data_ptr, count, sum_ptr, sumsq_ptr);
|
|
3965
|
+
else nk_reduce_moments_u1_serial(data_ptr, count, stride_bytes, sum_ptr, sumsq_ptr);
|
|
3966
|
+
}
|
|
3967
|
+
|
|
3939
3968
|
#if defined(__clang__)
|
|
3940
3969
|
#pragma clang attribute pop
|
|
3941
3970
|
#elif defined(__GNUC__)
|
|
@@ -33,7 +33,7 @@ NK_INTERNAL void nk_reduce_moments_bf16_neonbfdot_contiguous_( //
|
|
|
33
33
|
nk_f32_t *sum_ptr, nk_f32_t *sumsq_ptr) {
|
|
34
34
|
|
|
35
35
|
// bf16 representation of 1.0 is 0x3F80 (same as upper 16 bits of f32 1.0)
|
|
36
|
-
bfloat16x8_t ones_bf16x8 = vreinterpretq_bf16_u16(
|
|
36
|
+
bfloat16x8_t ones_bf16x8 = vreinterpretq_bf16_u16(nk_u16x8_splat_(0x3F80));
|
|
37
37
|
float32x4_t sum_f32x4 = vdupq_n_f32(0);
|
|
38
38
|
float32x4_t sumsq_f32x4 = vdupq_n_f32(0);
|
|
39
39
|
nk_size_t idx = 0;
|
|
@@ -61,7 +61,7 @@ NK_INTERNAL void nk_reduce_moments_bf16_neonbfdot_strided_( //
|
|
|
61
61
|
nk_bf16_t const *data_ptr, nk_size_t count, nk_size_t stride_elements, //
|
|
62
62
|
nk_f32_t *sum_ptr, nk_f32_t *sumsq_ptr) {
|
|
63
63
|
|
|
64
|
-
bfloat16x8_t ones_bf16x8 = vreinterpretq_bf16_u16(
|
|
64
|
+
bfloat16x8_t ones_bf16x8 = vreinterpretq_bf16_u16(nk_u16x8_splat_(0x3F80));
|
|
65
65
|
float32x4_t sum_f32x4 = vdupq_n_f32(0);
|
|
66
66
|
float32x4_t sumsq_f32x4 = vdupq_n_f32(0);
|
|
67
67
|
nk_size_t idx = 0;
|
|
@@ -34,7 +34,7 @@ NK_INTERNAL void nk_reduce_moments_e4m3_neonfhm_contiguous_( //
|
|
|
34
34
|
|
|
35
35
|
float32x4_t sum_f32x4 = vdupq_n_f32(0);
|
|
36
36
|
float32x4_t sumsq_f32x4 = vdupq_n_f32(0);
|
|
37
|
-
float16x8_t ones_f16x8 = vreinterpretq_f16_u16(
|
|
37
|
+
float16x8_t ones_f16x8 = vreinterpretq_f16_u16(nk_u16x8_splat_(0x3C00));
|
|
38
38
|
nk_size_t idx = 0;
|
|
39
39
|
|
|
40
40
|
for (; idx + 8 <= count; idx += 8) {
|
|
@@ -67,7 +67,7 @@ NK_INTERNAL void nk_reduce_moments_e4m3_neonfhm_strided_( //
|
|
|
67
67
|
|
|
68
68
|
float32x4_t sum_f32x4 = vdupq_n_f32(0);
|
|
69
69
|
float32x4_t sumsq_f32x4 = vdupq_n_f32(0);
|
|
70
|
-
float16x8_t ones_f16x8 = vreinterpretq_f16_u16(
|
|
70
|
+
float16x8_t ones_f16x8 = vreinterpretq_f16_u16(nk_u16x8_splat_(0x3C00));
|
|
71
71
|
nk_size_t idx = 0;
|
|
72
72
|
|
|
73
73
|
if (stride_elements == 2) {
|
|
@@ -159,7 +159,7 @@ NK_INTERNAL void nk_reduce_moments_e5m2_neonfhm_contiguous_( //
|
|
|
159
159
|
|
|
160
160
|
float32x4_t sum_f32x4 = vdupq_n_f32(0);
|
|
161
161
|
float32x4_t sumsq_f32x4 = vdupq_n_f32(0);
|
|
162
|
-
float16x8_t ones_f16x8 = vreinterpretq_f16_u16(
|
|
162
|
+
float16x8_t ones_f16x8 = vreinterpretq_f16_u16(nk_u16x8_splat_(0x3C00));
|
|
163
163
|
nk_size_t idx = 0;
|
|
164
164
|
|
|
165
165
|
for (; idx + 8 <= count; idx += 8) {
|
|
@@ -192,7 +192,7 @@ NK_INTERNAL void nk_reduce_moments_e5m2_neonfhm_strided_( //
|
|
|
192
192
|
|
|
193
193
|
float32x4_t sum_f32x4 = vdupq_n_f32(0);
|
|
194
194
|
float32x4_t sumsq_f32x4 = vdupq_n_f32(0);
|
|
195
|
-
float16x8_t ones_f16x8 = vreinterpretq_f16_u16(
|
|
195
|
+
float16x8_t ones_f16x8 = vreinterpretq_f16_u16(nk_u16x8_splat_(0x3C00));
|
|
196
196
|
nk_size_t idx = 0;
|
|
197
197
|
|
|
198
198
|
if (stride_elements == 2) {
|
|
@@ -0,0 +1,52 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* @brief SVE horizontal reduction helpers with MSan unpoisoning.
|
|
3
|
+
* @file include/numkong/reduce/sve.h
|
|
4
|
+
* @author Ash Vardanian
|
|
5
|
+
* @date April 12, 2026
|
|
6
|
+
*
|
|
7
|
+
* LLVM's MSan does not instrument ARM SVE intrinsics — `svaddv` moves data
|
|
8
|
+
* from vector to scalar registers via architecture-specific paths invisible
|
|
9
|
+
* to the compiler, causing false-positive uninitialized-value reports.
|
|
10
|
+
* These macros wrap the reduction and unpoison the scalar result.
|
|
11
|
+
*
|
|
12
|
+
* The `svaddv` intrinsic stays inside a macro so it expands in the caller's
|
|
13
|
+
* target context — SVE and SME streaming translation units carry incompatible
|
|
14
|
+
* target attributes. The unpoisoning runs on the already-reduced scalar, so it
|
|
15
|
+
* lives in a target-agnostic `NK_INTERNAL` helper called from the macro.
|
|
16
|
+
*
|
|
17
|
+
* @sa include/numkong/reduce.h
|
|
18
|
+
*/
|
|
19
|
+
#ifndef NK_REDUCE_SVE_H
|
|
20
|
+
#define NK_REDUCE_SVE_H
|
|
21
|
+
|
|
22
|
+
#if NK_TARGET_ARM64_
|
|
23
|
+
#if NK_TARGET_SVE || NK_TARGET_SVE2 || NK_TARGET_SME
|
|
24
|
+
|
|
25
|
+
#include "numkong/types.h"
|
|
26
|
+
|
|
27
|
+
NK_INTERNAL nk_f64_t nk_unpoison_f64_(nk_f64_t v) NK_STREAMING_COMPATIBLE_ {
|
|
28
|
+
nk_unpoison_(&v, sizeof(v));
|
|
29
|
+
return v;
|
|
30
|
+
}
|
|
31
|
+
NK_INTERNAL nk_f32_t nk_unpoison_f32_(nk_f32_t v) NK_STREAMING_COMPATIBLE_ {
|
|
32
|
+
nk_unpoison_(&v, sizeof(v));
|
|
33
|
+
return v;
|
|
34
|
+
}
|
|
35
|
+
NK_INTERNAL nk_u64_t nk_unpoison_u64_(nk_u64_t v) NK_STREAMING_COMPATIBLE_ {
|
|
36
|
+
nk_unpoison_(&v, sizeof(v));
|
|
37
|
+
return v;
|
|
38
|
+
}
|
|
39
|
+
NK_INTERNAL nk_i64_t nk_unpoison_i64_(nk_i64_t v) NK_STREAMING_COMPATIBLE_ {
|
|
40
|
+
nk_unpoison_(&v, sizeof(v));
|
|
41
|
+
return v;
|
|
42
|
+
}
|
|
43
|
+
|
|
44
|
+
#define nk_svaddv_f64_(predicate, vector) nk_unpoison_f64_(svaddv_f64((predicate), (vector)))
|
|
45
|
+
#define nk_svaddv_f32_(predicate, vector) nk_unpoison_f32_(svaddv_f32((predicate), (vector)))
|
|
46
|
+
#define nk_svaddv_u32_(predicate, vector) nk_unpoison_u64_(svaddv_u32((predicate), (vector)))
|
|
47
|
+
#define nk_svaddv_s32_(predicate, vector) nk_unpoison_i64_(svaddv_s32((predicate), (vector)))
|
|
48
|
+
#define nk_svaddv_u8_(predicate, vector) nk_unpoison_u64_(svaddv_u8((predicate), (vector)))
|
|
49
|
+
|
|
50
|
+
#endif // NK_TARGET_SVE || NK_TARGET_SVE2 || NK_TARGET_SME
|
|
51
|
+
#endif // NK_TARGET_ARM64_
|
|
52
|
+
#endif // NK_REDUCE_SVE_H
|
package/include/numkong/reduce.h
CHANGED
|
@@ -389,6 +389,8 @@ NK_PUBLIC void nk_reduce_moments_i16_neon(nk_i16_t const *, nk_size_t, nk_size_t
|
|
|
389
389
|
/** @copydoc nk_reduce_moments_f64 */
|
|
390
390
|
NK_PUBLIC void nk_reduce_moments_u16_neon(nk_u16_t const *, nk_size_t, nk_size_t, nk_u64_t *, nk_u64_t *);
|
|
391
391
|
/** @copydoc nk_reduce_moments_f64 */
|
|
392
|
+
NK_PUBLIC void nk_reduce_moments_u1_neon(nk_u1x8_t const *, nk_size_t, nk_size_t, nk_u64_t *, nk_u64_t *);
|
|
393
|
+
/** @copydoc nk_reduce_moments_f64 */
|
|
392
394
|
NK_PUBLIC void nk_reduce_moments_i32_neon(nk_i32_t const *, nk_size_t, nk_size_t, nk_i64_t *, nk_u64_t *);
|
|
393
395
|
/** @copydoc nk_reduce_moments_f64 */
|
|
394
396
|
NK_PUBLIC void nk_reduce_moments_u32_neon(nk_u32_t const *, nk_size_t, nk_size_t, nk_u64_t *, nk_u64_t *);
|
|
@@ -1559,6 +1561,8 @@ NK_PUBLIC void nk_reduce_moments_u1(nk_u1x8_t const *d, nk_size_t n, nk_size_t s
|
|
|
1559
1561
|
nk_reduce_moments_u1_skylake(d, n, s, sum, sumsq);
|
|
1560
1562
|
#elif NK_TARGET_HASWELL
|
|
1561
1563
|
nk_reduce_moments_u1_haswell(d, n, s, sum, sumsq);
|
|
1564
|
+
#elif NK_TARGET_NEON
|
|
1565
|
+
nk_reduce_moments_u1_neon(d, n, s, sum, sumsq);
|
|
1562
1566
|
#else
|
|
1563
1567
|
nk_reduce_moments_u1_serial(d, n, s, sum, sumsq);
|
|
1564
1568
|
#endif
|
|
@@ -32,8 +32,9 @@
|
|
|
32
32
|
#if NK_TARGET_ARM64_
|
|
33
33
|
#if NK_TARGET_SVE
|
|
34
34
|
|
|
35
|
-
#include "numkong/types.h"
|
|
36
|
-
#include "numkong/
|
|
35
|
+
#include "numkong/types.h" // `nk_u1x8_t`
|
|
36
|
+
#include "numkong/reduce/sve.h" // `nk_svaddv_f64_`
|
|
37
|
+
#include "numkong/set/neon.h" // `nk_hamming_u1_neon`
|
|
37
38
|
|
|
38
39
|
#if defined(__cplusplus)
|
|
39
40
|
extern "C" {
|
|
@@ -73,7 +74,7 @@ NK_PUBLIC void nk_hamming_u1_sve(nk_u1x8_t const *a, nk_u1x8_t const *b, nk_size
|
|
|
73
74
|
i += words_per_register;
|
|
74
75
|
++cycle;
|
|
75
76
|
} while (i < n_bytes && cycle < 31);
|
|
76
|
-
differences +=
|
|
77
|
+
differences += nk_svaddv_u8_(all_predicate_b8x, popcount_u8x);
|
|
77
78
|
popcount_u8x = svdup_n_u8(0);
|
|
78
79
|
cycle = 0; // Reset the cycle counter.
|
|
79
80
|
}
|
|
@@ -110,9 +111,9 @@ NK_PUBLIC void nk_jaccard_u1_sve(nk_u1x8_t const *a, nk_u1x8_t const *b, nk_size
|
|
|
110
111
|
i += words_per_register;
|
|
111
112
|
++cycle;
|
|
112
113
|
} while (i < n_bytes && cycle < 31);
|
|
113
|
-
intersection_count +=
|
|
114
|
+
intersection_count += nk_svaddv_u8_(all_predicate_b8x, intersection_popcount_u8x);
|
|
114
115
|
intersection_popcount_u8x = svdup_n_u8(0);
|
|
115
|
-
union_count +=
|
|
116
|
+
union_count += nk_svaddv_u8_(all_predicate_b8x, union_popcount_u8x);
|
|
116
117
|
union_popcount_u8x = svdup_n_u8(0);
|
|
117
118
|
cycle = 0; // Reset the cycle counter.
|
|
118
119
|
}
|