numkong 7.4.4 → 7.5.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +1 -0
- package/binding.gyp +81 -5
- package/c/dispatch_f16.c +23 -0
- package/c/numkong.c +0 -13
- package/include/numkong/attention/sme.h +34 -31
- package/include/numkong/capabilities.h +2 -15
- package/include/numkong/cast/neon.h +15 -0
- package/include/numkong/curved/smef64.h +82 -62
- package/include/numkong/dot/rvvbf16.h +1 -1
- package/include/numkong/dot/rvvhalf.h +1 -1
- package/include/numkong/dot/sve.h +6 -5
- package/include/numkong/dot/svebfdot.h +2 -1
- package/include/numkong/dot/svehalf.h +6 -5
- package/include/numkong/dot/svesdot.h +3 -2
- package/include/numkong/dots/graniteamx.h +733 -0
- package/include/numkong/dots/serial.h +11 -4
- package/include/numkong/dots/sme.h +172 -140
- package/include/numkong/dots/smebi32.h +14 -11
- package/include/numkong/dots/smef64.h +31 -26
- package/include/numkong/dots.h +29 -3
- package/include/numkong/each/serial.h +22 -0
- package/include/numkong/geospatial/haswell.h +1 -1
- package/include/numkong/geospatial/neon.h +1 -1
- package/include/numkong/geospatial/serial.h +1 -1
- package/include/numkong/geospatial/skylake.h +1 -1
- package/include/numkong/maxsim/sme.h +94 -55
- package/include/numkong/mesh/README.md +13 -27
- package/include/numkong/mesh/haswell.h +25 -122
- package/include/numkong/mesh/neon.h +21 -110
- package/include/numkong/mesh/neonbfdot.h +4 -43
- package/include/numkong/mesh/rvv.h +7 -82
- package/include/numkong/mesh/serial.h +48 -53
- package/include/numkong/mesh/skylake.h +7 -123
- package/include/numkong/mesh/v128relaxed.h +9 -93
- package/include/numkong/mesh.h +2 -2
- package/include/numkong/mesh.hpp +35 -96
- package/include/numkong/reduce/neon.h +29 -0
- package/include/numkong/reduce/neonbfdot.h +2 -2
- package/include/numkong/reduce/neonfhm.h +4 -4
- package/include/numkong/reduce/sve.h +52 -0
- package/include/numkong/reduce.h +4 -0
- package/include/numkong/set/sve.h +6 -5
- package/include/numkong/sets/smebi32.h +35 -30
- package/include/numkong/sparse/sve2.h +3 -2
- package/include/numkong/spatial/sve.h +7 -6
- package/include/numkong/spatial/svebfdot.h +7 -4
- package/include/numkong/spatial/svehalf.h +5 -4
- package/include/numkong/spatial/svesdot.h +9 -8
- package/include/numkong/spatials/graniteamx.h +173 -0
- package/include/numkong/spatials/serial.h +22 -0
- package/include/numkong/spatials/sme.h +391 -350
- package/include/numkong/spatials/smef64.h +79 -70
- package/include/numkong/spatials.h +37 -4
- package/include/numkong/types.h +59 -0
- package/javascript/dist/cjs/numkong.js +13 -0
- package/javascript/dist/esm/numkong.js +13 -0
- package/javascript/numkong.c +56 -12
- package/javascript/numkong.ts +13 -0
- package/package.json +7 -7
- package/probes/probe.js +2 -2
- package/wasm/numkong.wasm +0 -0
|
@@ -701,16 +701,10 @@ NK_PUBLIC void nk_rmsd_f32_rvv(nk_f32_t const *a, nk_f32_t const *b, nk_size_t p
|
|
|
701
701
|
rotation[0] = 1, rotation[1] = 0, rotation[2] = 0, rotation[3] = 0, rotation[4] = 1, rotation[5] = 0,
|
|
702
702
|
rotation[6] = 0, rotation[7] = 0, rotation[8] = 1;
|
|
703
703
|
if (scale) *scale = 1.0f;
|
|
704
|
+
if (a_centroid) a_centroid[0] = 0, a_centroid[1] = 0, a_centroid[2] = 0;
|
|
705
|
+
if (b_centroid) b_centroid[0] = 0, b_centroid[1] = 0, b_centroid[2] = 0;
|
|
704
706
|
|
|
705
|
-
// Fused single-pass: accumulate centroids and squared differences simultaneously.
|
|
706
|
-
// RMSD = √(E[(a−b)²] − (ā − b̄)²)
|
|
707
707
|
nk_size_t max_vector_length = __riscv_vsetvlmax_e64m2();
|
|
708
|
-
vfloat64m2_t sum_a_x_f64m2 = __riscv_vfmv_v_f_f64m2(0.0, max_vector_length);
|
|
709
|
-
vfloat64m2_t sum_a_y_f64m2 = __riscv_vfmv_v_f_f64m2(0.0, max_vector_length);
|
|
710
|
-
vfloat64m2_t sum_a_z_f64m2 = __riscv_vfmv_v_f_f64m2(0.0, max_vector_length);
|
|
711
|
-
vfloat64m2_t sum_b_x_f64m2 = __riscv_vfmv_v_f_f64m2(0.0, max_vector_length);
|
|
712
|
-
vfloat64m2_t sum_b_y_f64m2 = __riscv_vfmv_v_f_f64m2(0.0, max_vector_length);
|
|
713
|
-
vfloat64m2_t sum_b_z_f64m2 = __riscv_vfmv_v_f_f64m2(0.0, max_vector_length);
|
|
714
708
|
vfloat64m2_t sum_squared_f64m2 = __riscv_vfmv_v_f_f64m2(0.0, max_vector_length);
|
|
715
709
|
nk_f32_t const *a_ptr = a, *b_ptr = b;
|
|
716
710
|
nk_size_t remaining = points_count;
|
|
@@ -725,15 +719,7 @@ NK_PUBLIC void nk_rmsd_f32_rvv(nk_f32_t const *a, nk_f32_t const *b, nk_size_t p
|
|
|
725
719
|
vfloat32m1_t b_x_f32m1 = __riscv_vget_v_f32m1x3_f32m1(b_f32m1x3, 0);
|
|
726
720
|
vfloat32m1_t b_y_f32m1 = __riscv_vget_v_f32m1x3_f32m1(b_f32m1x3, 1);
|
|
727
721
|
vfloat32m1_t b_z_f32m1 = __riscv_vget_v_f32m1x3_f32m1(b_f32m1x3, 2);
|
|
728
|
-
// Accumulate
|
|
729
|
-
sum_a_x_f64m2 = __riscv_vfwadd_wv_f64m2_tu(sum_a_x_f64m2, sum_a_x_f64m2, a_x_f32m1, vector_length);
|
|
730
|
-
sum_a_y_f64m2 = __riscv_vfwadd_wv_f64m2_tu(sum_a_y_f64m2, sum_a_y_f64m2, a_y_f32m1, vector_length);
|
|
731
|
-
sum_a_z_f64m2 = __riscv_vfwadd_wv_f64m2_tu(sum_a_z_f64m2, sum_a_z_f64m2, a_z_f32m1, vector_length);
|
|
732
|
-
sum_b_x_f64m2 = __riscv_vfwadd_wv_f64m2_tu(sum_b_x_f64m2, sum_b_x_f64m2, b_x_f32m1, vector_length);
|
|
733
|
-
sum_b_y_f64m2 = __riscv_vfwadd_wv_f64m2_tu(sum_b_y_f64m2, sum_b_y_f64m2, b_y_f32m1, vector_length);
|
|
734
|
-
sum_b_z_f64m2 = __riscv_vfwadd_wv_f64m2_tu(sum_b_z_f64m2, sum_b_z_f64m2, b_z_f32m1, vector_length);
|
|
735
|
-
// Accumulate (a−b)² per component. Widen a,b to f64 before subtracting to avoid f32
|
|
736
|
-
// cancellation in the single-pass formula RMSD = √(E[(a−b)²] − (ā − b̄)²).
|
|
722
|
+
// Accumulate (a−b)² per component, widening to f64.
|
|
737
723
|
vfloat64m2_t a_x_f64m2 = __riscv_vfwcvt_f_f_v_f64m2(a_x_f32m1, vector_length);
|
|
738
724
|
vfloat64m2_t b_x_f64m2 = __riscv_vfwcvt_f_f_v_f64m2(b_x_f32m1, vector_length);
|
|
739
725
|
vfloat64m2_t a_y_f64m2 = __riscv_vfwcvt_f_f_v_f64m2(a_y_f32m1, vector_length);
|
|
@@ -748,38 +734,9 @@ NK_PUBLIC void nk_rmsd_f32_rvv(nk_f32_t const *a, nk_f32_t const *b, nk_size_t p
|
|
|
748
734
|
sum_squared_f64m2 = __riscv_vfmacc_vv_f64m2_tu(sum_squared_f64m2, delta_z_f64m2, delta_z_f64m2, vector_length);
|
|
749
735
|
}
|
|
750
736
|
vfloat64m1_t zero_f64m1 = __riscv_vfmv_v_f_f64m1(0.0, 1);
|
|
751
|
-
nk_f64_t inv_points_count = 1.0 / (nk_f64_t)points_count;
|
|
752
|
-
nk_f64_t centroid_a_x = __riscv_vfmv_f_s_f64m1_f64(
|
|
753
|
-
__riscv_vfredusum_vs_f64m2_f64m1(sum_a_x_f64m2, zero_f64m1, max_vector_length)) *
|
|
754
|
-
inv_points_count;
|
|
755
|
-
nk_f64_t centroid_a_y = __riscv_vfmv_f_s_f64m1_f64(
|
|
756
|
-
__riscv_vfredusum_vs_f64m2_f64m1(sum_a_y_f64m2, zero_f64m1, max_vector_length)) *
|
|
757
|
-
inv_points_count;
|
|
758
|
-
nk_f64_t centroid_a_z = __riscv_vfmv_f_s_f64m1_f64(
|
|
759
|
-
__riscv_vfredusum_vs_f64m2_f64m1(sum_a_z_f64m2, zero_f64m1, max_vector_length)) *
|
|
760
|
-
inv_points_count;
|
|
761
|
-
nk_f64_t centroid_b_x = __riscv_vfmv_f_s_f64m1_f64(
|
|
762
|
-
__riscv_vfredusum_vs_f64m2_f64m1(sum_b_x_f64m2, zero_f64m1, max_vector_length)) *
|
|
763
|
-
inv_points_count;
|
|
764
|
-
nk_f64_t centroid_b_y = __riscv_vfmv_f_s_f64m1_f64(
|
|
765
|
-
__riscv_vfredusum_vs_f64m2_f64m1(sum_b_y_f64m2, zero_f64m1, max_vector_length)) *
|
|
766
|
-
inv_points_count;
|
|
767
|
-
nk_f64_t centroid_b_z = __riscv_vfmv_f_s_f64m1_f64(
|
|
768
|
-
__riscv_vfredusum_vs_f64m2_f64m1(sum_b_z_f64m2, zero_f64m1, max_vector_length)) *
|
|
769
|
-
inv_points_count;
|
|
770
|
-
if (a_centroid)
|
|
771
|
-
a_centroid[0] = (nk_f32_t)centroid_a_x, a_centroid[1] = (nk_f32_t)centroid_a_y,
|
|
772
|
-
a_centroid[2] = (nk_f32_t)centroid_a_z;
|
|
773
|
-
if (b_centroid)
|
|
774
|
-
b_centroid[0] = (nk_f32_t)centroid_b_x, b_centroid[1] = (nk_f32_t)centroid_b_y,
|
|
775
|
-
b_centroid[2] = (nk_f32_t)centroid_b_z;
|
|
776
|
-
|
|
777
737
|
nk_f64_t sum_squared = __riscv_vfmv_f_s_f64m1_f64(
|
|
778
738
|
__riscv_vfredusum_vs_f64m2_f64m1(sum_squared_f64m2, zero_f64m1, max_vector_length));
|
|
779
|
-
|
|
780
|
-
mean_diff_z = centroid_a_z - centroid_b_z;
|
|
781
|
-
nk_f64_t mean_diff_sq = mean_diff_x * mean_diff_x + mean_diff_y * mean_diff_y + mean_diff_z * mean_diff_z;
|
|
782
|
-
*result = nk_f64_sqrt_rvv(sum_squared * inv_points_count - mean_diff_sq);
|
|
739
|
+
*result = nk_f64_sqrt_rvv(sum_squared / (nk_f64_t)points_count);
|
|
783
740
|
}
|
|
784
741
|
|
|
785
742
|
NK_PUBLIC void nk_rmsd_f64_rvv(nk_f64_t const *a, nk_f64_t const *b, nk_size_t points_count, nk_f64_t *a_centroid,
|
|
@@ -788,22 +745,10 @@ NK_PUBLIC void nk_rmsd_f64_rvv(nk_f64_t const *a, nk_f64_t const *b, nk_size_t p
|
|
|
788
745
|
rotation[0] = 1, rotation[1] = 0, rotation[2] = 0, rotation[3] = 0, rotation[4] = 1, rotation[5] = 0,
|
|
789
746
|
rotation[6] = 0, rotation[7] = 0, rotation[8] = 1;
|
|
790
747
|
if (scale) *scale = 1.0;
|
|
748
|
+
if (a_centroid) a_centroid[0] = 0, a_centroid[1] = 0, a_centroid[2] = 0;
|
|
749
|
+
if (b_centroid) b_centroid[0] = 0, b_centroid[1] = 0, b_centroid[2] = 0;
|
|
791
750
|
|
|
792
|
-
// Fused single-pass: accumulate centroids and squared differences simultaneously.
|
|
793
|
-
// RMSD = √(E[(a−b)²] − (ā − b̄)²)
|
|
794
751
|
nk_size_t max_vector_length = __riscv_vsetvlmax_e64m1();
|
|
795
|
-
vfloat64m1_t sum_a_x_f64m1 = __riscv_vfmv_v_f_f64m1(0.0, max_vector_length);
|
|
796
|
-
vfloat64m1_t sum_a_y_f64m1 = __riscv_vfmv_v_f_f64m1(0.0, max_vector_length);
|
|
797
|
-
vfloat64m1_t sum_a_z_f64m1 = __riscv_vfmv_v_f_f64m1(0.0, max_vector_length);
|
|
798
|
-
vfloat64m1_t sum_b_x_f64m1 = __riscv_vfmv_v_f_f64m1(0.0, max_vector_length);
|
|
799
|
-
vfloat64m1_t sum_b_y_f64m1 = __riscv_vfmv_v_f_f64m1(0.0, max_vector_length);
|
|
800
|
-
vfloat64m1_t sum_b_z_f64m1 = __riscv_vfmv_v_f_f64m1(0.0, max_vector_length);
|
|
801
|
-
vfloat64m1_t compensation_a_x_f64m1 = __riscv_vfmv_v_f_f64m1(0.0, max_vector_length);
|
|
802
|
-
vfloat64m1_t compensation_a_y_f64m1 = __riscv_vfmv_v_f_f64m1(0.0, max_vector_length);
|
|
803
|
-
vfloat64m1_t compensation_a_z_f64m1 = __riscv_vfmv_v_f_f64m1(0.0, max_vector_length);
|
|
804
|
-
vfloat64m1_t compensation_b_x_f64m1 = __riscv_vfmv_v_f_f64m1(0.0, max_vector_length);
|
|
805
|
-
vfloat64m1_t compensation_b_y_f64m1 = __riscv_vfmv_v_f_f64m1(0.0, max_vector_length);
|
|
806
|
-
vfloat64m1_t compensation_b_z_f64m1 = __riscv_vfmv_v_f_f64m1(0.0, max_vector_length);
|
|
807
752
|
vfloat64m1_t sum_squared_f64m1 = __riscv_vfmv_v_f_f64m1(0.0, max_vector_length);
|
|
808
753
|
vfloat64m1_t compensation_squared_f64m1 = __riscv_vfmv_v_f_f64m1(0.0, max_vector_length);
|
|
809
754
|
nk_f64_t const *a_ptr = a, *b_ptr = b;
|
|
@@ -819,13 +764,6 @@ NK_PUBLIC void nk_rmsd_f64_rvv(nk_f64_t const *a, nk_f64_t const *b, nk_size_t p
|
|
|
819
764
|
vfloat64m1_t b_x_f64m1 = __riscv_vget_v_f64m1x3_f64m1(b_f64m1x3, 0);
|
|
820
765
|
vfloat64m1_t b_y_f64m1 = __riscv_vget_v_f64m1x3_f64m1(b_f64m1x3, 1);
|
|
821
766
|
vfloat64m1_t b_z_f64m1 = __riscv_vget_v_f64m1x3_f64m1(b_f64m1x3, 2);
|
|
822
|
-
// Accumulate centroids with Kahan compensation.
|
|
823
|
-
nk_accumulate_sum_f64m1_rvv_(&sum_a_x_f64m1, &compensation_a_x_f64m1, a_x_f64m1, vector_length);
|
|
824
|
-
nk_accumulate_sum_f64m1_rvv_(&sum_a_y_f64m1, &compensation_a_y_f64m1, a_y_f64m1, vector_length);
|
|
825
|
-
nk_accumulate_sum_f64m1_rvv_(&sum_a_z_f64m1, &compensation_a_z_f64m1, a_z_f64m1, vector_length);
|
|
826
|
-
nk_accumulate_sum_f64m1_rvv_(&sum_b_x_f64m1, &compensation_b_x_f64m1, b_x_f64m1, vector_length);
|
|
827
|
-
nk_accumulate_sum_f64m1_rvv_(&sum_b_y_f64m1, &compensation_b_y_f64m1, b_y_f64m1, vector_length);
|
|
828
|
-
nk_accumulate_sum_f64m1_rvv_(&sum_b_z_f64m1, &compensation_b_z_f64m1, b_z_f64m1, vector_length);
|
|
829
767
|
// Accumulate (a-b)^2 per component.
|
|
830
768
|
vfloat64m1_t delta_x_f64m1 = __riscv_vfsub_vv_f64m1(a_x_f64m1, b_x_f64m1, vector_length);
|
|
831
769
|
vfloat64m1_t delta_y_f64m1 = __riscv_vfsub_vv_f64m1(a_y_f64m1, b_y_f64m1, vector_length);
|
|
@@ -835,21 +773,8 @@ NK_PUBLIC void nk_rmsd_f64_rvv(nk_f64_t const *a, nk_f64_t const *b, nk_size_t p
|
|
|
835
773
|
dist_sq_f64m1 = __riscv_vfmacc_vv_f64m1(dist_sq_f64m1, delta_z_f64m1, delta_z_f64m1, vector_length);
|
|
836
774
|
nk_accumulate_sum_f64m1_rvv_(&sum_squared_f64m1, &compensation_squared_f64m1, dist_sq_f64m1, vector_length);
|
|
837
775
|
}
|
|
838
|
-
nk_f64_t inv_points_count = 1.0 / (nk_f64_t)points_count;
|
|
839
|
-
nk_f64_t centroid_a_x = nk_dot_stable_sum_f64m1_rvv_(sum_a_x_f64m1, compensation_a_x_f64m1) * inv_points_count;
|
|
840
|
-
nk_f64_t centroid_a_y = nk_dot_stable_sum_f64m1_rvv_(sum_a_y_f64m1, compensation_a_y_f64m1) * inv_points_count;
|
|
841
|
-
nk_f64_t centroid_a_z = nk_dot_stable_sum_f64m1_rvv_(sum_a_z_f64m1, compensation_a_z_f64m1) * inv_points_count;
|
|
842
|
-
nk_f64_t centroid_b_x = nk_dot_stable_sum_f64m1_rvv_(sum_b_x_f64m1, compensation_b_x_f64m1) * inv_points_count;
|
|
843
|
-
nk_f64_t centroid_b_y = nk_dot_stable_sum_f64m1_rvv_(sum_b_y_f64m1, compensation_b_y_f64m1) * inv_points_count;
|
|
844
|
-
nk_f64_t centroid_b_z = nk_dot_stable_sum_f64m1_rvv_(sum_b_z_f64m1, compensation_b_z_f64m1) * inv_points_count;
|
|
845
|
-
if (a_centroid) a_centroid[0] = centroid_a_x, a_centroid[1] = centroid_a_y, a_centroid[2] = centroid_a_z;
|
|
846
|
-
if (b_centroid) b_centroid[0] = centroid_b_x, b_centroid[1] = centroid_b_y, b_centroid[2] = centroid_b_z;
|
|
847
|
-
|
|
848
776
|
nk_f64_t sum_squared = nk_dot_stable_sum_f64m1_rvv_(sum_squared_f64m1, compensation_squared_f64m1);
|
|
849
|
-
|
|
850
|
-
mean_diff_z = centroid_a_z - centroid_b_z;
|
|
851
|
-
nk_f64_t mean_diff_sq = mean_diff_x * mean_diff_x + mean_diff_y * mean_diff_y + mean_diff_z * mean_diff_z;
|
|
852
|
-
*result = nk_f64_sqrt_rvv(sum_squared * inv_points_count - mean_diff_sq);
|
|
777
|
+
*result = nk_f64_sqrt_rvv(sum_squared / (nk_f64_t)points_count);
|
|
853
778
|
}
|
|
854
779
|
|
|
855
780
|
NK_PUBLIC void nk_kabsch_f32_rvv(nk_f32_t const *a, nk_f32_t const *b, nk_size_t points_count, nk_f32_t *a_centroid,
|
|
@@ -289,6 +289,18 @@ extern "C" {
|
|
|
289
289
|
m[2] * (m[3] * m[7] - m[4] * m[6]); \
|
|
290
290
|
}
|
|
291
291
|
|
|
292
|
+
/* Optimize serial fallbacks for size — see dots/serial.h for rationale. */
|
|
293
|
+
#if defined(NDEBUG)
|
|
294
|
+
#if defined(_MSC_VER)
|
|
295
|
+
#pragma optimize("s", on)
|
|
296
|
+
#elif defined(__clang__)
|
|
297
|
+
#pragma clang attribute push(__attribute__((minsize)), apply_to = function)
|
|
298
|
+
#elif defined(__GNUC__)
|
|
299
|
+
#pragma GCC push_options
|
|
300
|
+
#pragma GCC optimize("Os")
|
|
301
|
+
#endif
|
|
302
|
+
#endif
|
|
303
|
+
|
|
292
304
|
NK_INTERNAL nk_f32_t nk_sum_three_products_f32_(nk_f32_t left_0, nk_f32_t right_0, nk_f32_t left_1, nk_f32_t right_1,
|
|
293
305
|
nk_f32_t left_2, nk_f32_t right_2) {
|
|
294
306
|
return left_0 * right_0 + left_1 * right_1 + left_2 * right_2;
|
|
@@ -392,59 +404,32 @@ nk_define_det3x3_(f64)
|
|
|
392
404
|
/* RMSD (Root Mean Square Deviation) without optimal superposition.
|
|
393
405
|
* Simply computes the RMS of distances between corresponding points.
|
|
394
406
|
*/
|
|
395
|
-
#define nk_define_rmsd_(input_type, accumulator_type, output_type, result_type, load_and_convert, compute_sqrt)
|
|
396
|
-
NK_PUBLIC void nk_rmsd_##input_type##_serial(nk_##input_type##_t const *a, nk_##input_type##_t const *b,
|
|
397
|
-
nk_size_t n, nk_##output_type##_t *a_centroid,
|
|
398
|
-
nk_##output_type##_t *b_centroid, nk_##output_type##_t *rotation,
|
|
399
|
-
nk_##output_type##_t *scale, nk_##result_type##_t *result) {
|
|
400
|
-
|
|
401
|
-
|
|
402
|
-
|
|
403
|
-
|
|
404
|
-
|
|
405
|
-
|
|
406
|
-
|
|
407
|
-
|
|
408
|
-
|
|
409
|
-
|
|
410
|
-
|
|
411
|
-
|
|
412
|
-
|
|
413
|
-
|
|
414
|
-
|
|
415
|
-
|
|
416
|
-
|
|
417
|
-
|
|
418
|
-
|
|
419
|
-
nk_##accumulator_type##_t
|
|
420
|
-
|
|
421
|
-
nk_##accumulator_type##_t centroid_b_y = (sum_b_y + sum_b_y_compensation) * inv_n; \
|
|
422
|
-
nk_##accumulator_type##_t centroid_b_z = (sum_b_z + sum_b_z_compensation) * inv_n; \
|
|
423
|
-
if (a_centroid) \
|
|
424
|
-
a_centroid[0] = (nk_##output_type##_t)centroid_a_x, a_centroid[1] = (nk_##output_type##_t)centroid_a_y, \
|
|
425
|
-
a_centroid[2] = (nk_##output_type##_t)centroid_a_z; \
|
|
426
|
-
if (b_centroid) \
|
|
427
|
-
b_centroid[0] = (nk_##output_type##_t)centroid_b_x, b_centroid[1] = (nk_##output_type##_t)centroid_b_y, \
|
|
428
|
-
b_centroid[2] = (nk_##output_type##_t)centroid_b_z; \
|
|
429
|
-
/* RMSD uses identity rotation and scale=1.0 */ \
|
|
430
|
-
if (rotation) \
|
|
431
|
-
rotation[0] = 1, rotation[1] = 0, rotation[2] = 0, rotation[3] = 0, rotation[4] = 1, rotation[5] = 0, \
|
|
432
|
-
rotation[6] = 0, rotation[7] = 0, rotation[8] = 1; \
|
|
433
|
-
if (scale) *scale = (nk_##output_type##_t)1; \
|
|
434
|
-
nk_##accumulator_type##_t sum_squared = 0, sum_squared_compensation = 0; \
|
|
435
|
-
for (nk_size_t i = 0; i < n; ++i) { \
|
|
436
|
-
load_and_convert(a + i * 3 + 0, &val_a_x), load_and_convert(b + i * 3 + 0, &val_b_x); \
|
|
437
|
-
load_and_convert(a + i * 3 + 1, &val_a_y), load_and_convert(b + i * 3 + 1, &val_b_y); \
|
|
438
|
-
load_and_convert(a + i * 3 + 2, &val_a_z), load_and_convert(b + i * 3 + 2, &val_b_z); \
|
|
439
|
-
nk_##accumulator_type##_t dx = (val_a_x - centroid_a_x) - (val_b_x - centroid_b_x); \
|
|
440
|
-
nk_##accumulator_type##_t dy = (val_a_y - centroid_a_y) - (val_b_y - centroid_b_y); \
|
|
441
|
-
nk_##accumulator_type##_t dz = (val_a_z - centroid_a_z) - (val_b_z - centroid_b_z); \
|
|
442
|
-
nk_accumulate_square_##accumulator_type##_(&sum_squared, &sum_squared_compensation, dx); \
|
|
443
|
-
nk_accumulate_square_##accumulator_type##_(&sum_squared, &sum_squared_compensation, dy); \
|
|
444
|
-
nk_accumulate_square_##accumulator_type##_(&sum_squared, &sum_squared_compensation, dz); \
|
|
445
|
-
} \
|
|
446
|
-
nk_##accumulator_type##_t msd = (sum_squared + sum_squared_compensation) * inv_n; \
|
|
447
|
-
*result = msd > 0 ? (nk_##result_type##_t)compute_sqrt(msd) : 0; \
|
|
407
|
+
#define nk_define_rmsd_(input_type, accumulator_type, output_type, result_type, load_and_convert, compute_sqrt) \
|
|
408
|
+
NK_PUBLIC void nk_rmsd_##input_type##_serial(nk_##input_type##_t const *a, nk_##input_type##_t const *b, \
|
|
409
|
+
nk_size_t n, nk_##output_type##_t *a_centroid, \
|
|
410
|
+
nk_##output_type##_t *b_centroid, nk_##output_type##_t *rotation, \
|
|
411
|
+
nk_##output_type##_t *scale, nk_##result_type##_t *result) { \
|
|
412
|
+
if (a_centroid) a_centroid[0] = 0, a_centroid[1] = 0, a_centroid[2] = 0; \
|
|
413
|
+
if (b_centroid) b_centroid[0] = 0, b_centroid[1] = 0, b_centroid[2] = 0; \
|
|
414
|
+
if (rotation) \
|
|
415
|
+
rotation[0] = 1, rotation[1] = 0, rotation[2] = 0, rotation[3] = 0, rotation[4] = 1, rotation[5] = 0, \
|
|
416
|
+
rotation[6] = 0, rotation[7] = 0, rotation[8] = 1; \
|
|
417
|
+
if (scale) *scale = (nk_##output_type##_t)1; \
|
|
418
|
+
nk_##accumulator_type##_t sum_squared = 0, sum_squared_compensation = 0; \
|
|
419
|
+
nk_##accumulator_type##_t val_a_x, val_a_y, val_a_z, val_b_x, val_b_y, val_b_z; \
|
|
420
|
+
for (nk_size_t i = 0; i < n; ++i) { \
|
|
421
|
+
load_and_convert(a + i * 3 + 0, &val_a_x), load_and_convert(b + i * 3 + 0, &val_b_x); \
|
|
422
|
+
load_and_convert(a + i * 3 + 1, &val_a_y), load_and_convert(b + i * 3 + 1, &val_b_y); \
|
|
423
|
+
load_and_convert(a + i * 3 + 2, &val_a_z), load_and_convert(b + i * 3 + 2, &val_b_z); \
|
|
424
|
+
nk_##accumulator_type##_t dx = val_a_x - val_b_x; \
|
|
425
|
+
nk_##accumulator_type##_t dy = val_a_y - val_b_y; \
|
|
426
|
+
nk_##accumulator_type##_t dz = val_a_z - val_b_z; \
|
|
427
|
+
nk_accumulate_square_##accumulator_type##_(&sum_squared, &sum_squared_compensation, dx); \
|
|
428
|
+
nk_accumulate_square_##accumulator_type##_(&sum_squared, &sum_squared_compensation, dy); \
|
|
429
|
+
nk_accumulate_square_##accumulator_type##_(&sum_squared, &sum_squared_compensation, dz); \
|
|
430
|
+
} \
|
|
431
|
+
nk_##accumulator_type##_t msd = (sum_squared + sum_squared_compensation) / (nk_##accumulator_type##_t)n; \
|
|
432
|
+
*result = msd > 0 ? (nk_##result_type##_t)compute_sqrt(msd) : 0; \
|
|
448
433
|
}
|
|
449
434
|
|
|
450
435
|
/* Kabsch algorithm for optimal rigid body superposition.
|
|
@@ -719,6 +704,16 @@ nk_define_umeyama_(bf16, f32, f32, f32, f32, nk_bf16_to_f32_serial, nk_f32_sqrt_
|
|
|
719
704
|
#undef nk_define_kabsch_
|
|
720
705
|
#undef nk_define_umeyama_
|
|
721
706
|
|
|
707
|
+
#if defined(NDEBUG)
|
|
708
|
+
#if defined(_MSC_VER)
|
|
709
|
+
#pragma optimize("", on)
|
|
710
|
+
#elif defined(__clang__)
|
|
711
|
+
#pragma clang attribute pop
|
|
712
|
+
#elif defined(__GNUC__)
|
|
713
|
+
#pragma GCC pop_options
|
|
714
|
+
#endif
|
|
715
|
+
#endif
|
|
716
|
+
|
|
722
717
|
#if defined(__cplusplus)
|
|
723
718
|
} // extern "C"
|
|
724
719
|
#endif
|
|
@@ -644,12 +644,10 @@ NK_PUBLIC void nk_rmsd_f32_skylake(nk_f32_t const *a, nk_f32_t const *b, nk_size
|
|
|
644
644
|
rotation[0] = 1, rotation[1] = 0, rotation[2] = 0, rotation[3] = 0, rotation[4] = 1, rotation[5] = 0,
|
|
645
645
|
rotation[6] = 0, rotation[7] = 0, rotation[8] = 1;
|
|
646
646
|
if (scale) *scale = 1.0f;
|
|
647
|
+
if (a_centroid) a_centroid[0] = 0, a_centroid[1] = 0, a_centroid[2] = 0;
|
|
648
|
+
if (b_centroid) b_centroid[0] = 0, b_centroid[1] = 0, b_centroid[2] = 0;
|
|
647
649
|
|
|
648
|
-
// Fused single-pass: centroids + squared differences in f64, using the identity:
|
|
649
|
-
// RMSD = √(E[(a-b)²] - (ā - b̄)²)
|
|
650
650
|
__m512d const zeros_f64x8 = _mm512_setzero_pd();
|
|
651
|
-
__m512d sum_a_x_f64x8 = zeros_f64x8, sum_a_y_f64x8 = zeros_f64x8, sum_a_z_f64x8 = zeros_f64x8;
|
|
652
|
-
__m512d sum_b_x_f64x8 = zeros_f64x8, sum_b_y_f64x8 = zeros_f64x8, sum_b_z_f64x8 = zeros_f64x8;
|
|
653
651
|
__m512d sum_squared_x_f64x8 = zeros_f64x8, sum_squared_y_f64x8 = zeros_f64x8, sum_squared_z_f64x8 = zeros_f64x8;
|
|
654
652
|
__m512 a_x_f32x16, a_y_f32x16, a_z_f32x16, b_x_f32x16, b_y_f32x16, b_z_f32x16;
|
|
655
653
|
nk_size_t i = 0;
|
|
@@ -672,13 +670,6 @@ NK_PUBLIC void nk_rmsd_f32_skylake(nk_f32_t const *a, nk_f32_t const *b, nk_size
|
|
|
672
670
|
__m512d b_z_low_f64x8 = _mm512_cvtps_pd(_mm512_castps512_ps256(b_z_f32x16));
|
|
673
671
|
__m512d b_z_high_f64x8 = _mm512_cvtps_pd(_mm512_extractf32x8_ps(b_z_f32x16, 1));
|
|
674
672
|
|
|
675
|
-
sum_a_x_f64x8 = _mm512_add_pd(sum_a_x_f64x8, _mm512_add_pd(a_x_low_f64x8, a_x_high_f64x8));
|
|
676
|
-
sum_a_y_f64x8 = _mm512_add_pd(sum_a_y_f64x8, _mm512_add_pd(a_y_low_f64x8, a_y_high_f64x8));
|
|
677
|
-
sum_a_z_f64x8 = _mm512_add_pd(sum_a_z_f64x8, _mm512_add_pd(a_z_low_f64x8, a_z_high_f64x8));
|
|
678
|
-
sum_b_x_f64x8 = _mm512_add_pd(sum_b_x_f64x8, _mm512_add_pd(b_x_low_f64x8, b_x_high_f64x8));
|
|
679
|
-
sum_b_y_f64x8 = _mm512_add_pd(sum_b_y_f64x8, _mm512_add_pd(b_y_low_f64x8, b_y_high_f64x8));
|
|
680
|
-
sum_b_z_f64x8 = _mm512_add_pd(sum_b_z_f64x8, _mm512_add_pd(b_z_low_f64x8, b_z_high_f64x8));
|
|
681
|
-
|
|
682
673
|
__m512d delta_x_low_f64x8 = _mm512_sub_pd(a_x_low_f64x8, b_x_low_f64x8);
|
|
683
674
|
__m512d delta_x_high_f64x8 = _mm512_sub_pd(a_x_high_f64x8, b_x_high_f64x8);
|
|
684
675
|
__m512d delta_y_low_f64x8 = _mm512_sub_pd(a_y_low_f64x8, b_y_low_f64x8);
|
|
@@ -708,13 +699,6 @@ NK_PUBLIC void nk_rmsd_f32_skylake(nk_f32_t const *a, nk_f32_t const *b, nk_size
|
|
|
708
699
|
b_z_low_f64x8 = _mm512_cvtps_pd(_mm512_castps512_ps256(b_z_f32x16));
|
|
709
700
|
b_z_high_f64x8 = _mm512_cvtps_pd(_mm512_extractf32x8_ps(b_z_f32x16, 1));
|
|
710
701
|
|
|
711
|
-
sum_a_x_f64x8 = _mm512_add_pd(sum_a_x_f64x8, _mm512_add_pd(a_x_low_f64x8, a_x_high_f64x8));
|
|
712
|
-
sum_a_y_f64x8 = _mm512_add_pd(sum_a_y_f64x8, _mm512_add_pd(a_y_low_f64x8, a_y_high_f64x8));
|
|
713
|
-
sum_a_z_f64x8 = _mm512_add_pd(sum_a_z_f64x8, _mm512_add_pd(a_z_low_f64x8, a_z_high_f64x8));
|
|
714
|
-
sum_b_x_f64x8 = _mm512_add_pd(sum_b_x_f64x8, _mm512_add_pd(b_x_low_f64x8, b_x_high_f64x8));
|
|
715
|
-
sum_b_y_f64x8 = _mm512_add_pd(sum_b_y_f64x8, _mm512_add_pd(b_y_low_f64x8, b_y_high_f64x8));
|
|
716
|
-
sum_b_z_f64x8 = _mm512_add_pd(sum_b_z_f64x8, _mm512_add_pd(b_z_low_f64x8, b_z_high_f64x8));
|
|
717
|
-
|
|
718
702
|
delta_x_low_f64x8 = _mm512_sub_pd(a_x_low_f64x8, b_x_low_f64x8);
|
|
719
703
|
delta_x_high_f64x8 = _mm512_sub_pd(a_x_high_f64x8, b_x_high_f64x8);
|
|
720
704
|
delta_y_low_f64x8 = _mm512_sub_pd(a_y_low_f64x8, b_y_low_f64x8);
|
|
@@ -746,13 +730,6 @@ NK_PUBLIC void nk_rmsd_f32_skylake(nk_f32_t const *a, nk_f32_t const *b, nk_size
|
|
|
746
730
|
__m512d b_z_low_f64x8 = _mm512_cvtps_pd(_mm512_castps512_ps256(b_z_f32x16));
|
|
747
731
|
__m512d b_z_high_f64x8 = _mm512_cvtps_pd(_mm512_extractf32x8_ps(b_z_f32x16, 1));
|
|
748
732
|
|
|
749
|
-
sum_a_x_f64x8 = _mm512_add_pd(sum_a_x_f64x8, _mm512_add_pd(a_x_low_f64x8, a_x_high_f64x8));
|
|
750
|
-
sum_a_y_f64x8 = _mm512_add_pd(sum_a_y_f64x8, _mm512_add_pd(a_y_low_f64x8, a_y_high_f64x8));
|
|
751
|
-
sum_a_z_f64x8 = _mm512_add_pd(sum_a_z_f64x8, _mm512_add_pd(a_z_low_f64x8, a_z_high_f64x8));
|
|
752
|
-
sum_b_x_f64x8 = _mm512_add_pd(sum_b_x_f64x8, _mm512_add_pd(b_x_low_f64x8, b_x_high_f64x8));
|
|
753
|
-
sum_b_y_f64x8 = _mm512_add_pd(sum_b_y_f64x8, _mm512_add_pd(b_y_low_f64x8, b_y_high_f64x8));
|
|
754
|
-
sum_b_z_f64x8 = _mm512_add_pd(sum_b_z_f64x8, _mm512_add_pd(b_z_low_f64x8, b_z_high_f64x8));
|
|
755
|
-
|
|
756
733
|
__m512d delta_x_low_f64x8 = _mm512_sub_pd(a_x_low_f64x8, b_x_low_f64x8);
|
|
757
734
|
__m512d delta_x_high_f64x8 = _mm512_sub_pd(a_x_high_f64x8, b_x_high_f64x8);
|
|
758
735
|
__m512d delta_y_low_f64x8 = _mm512_sub_pd(a_y_low_f64x8, b_y_low_f64x8);
|
|
@@ -796,13 +773,6 @@ NK_PUBLIC void nk_rmsd_f32_skylake(nk_f32_t const *a, nk_f32_t const *b, nk_size
|
|
|
796
773
|
__m512d b_z_low_f64x8 = _mm512_cvtps_pd(_mm512_castps512_ps256(b_z_f32x16));
|
|
797
774
|
__m512d b_z_high_f64x8 = _mm512_cvtps_pd(_mm512_extractf32x8_ps(b_z_f32x16, 1));
|
|
798
775
|
|
|
799
|
-
sum_a_x_f64x8 = _mm512_add_pd(sum_a_x_f64x8, _mm512_add_pd(a_x_low_f64x8, a_x_high_f64x8));
|
|
800
|
-
sum_a_y_f64x8 = _mm512_add_pd(sum_a_y_f64x8, _mm512_add_pd(a_y_low_f64x8, a_y_high_f64x8));
|
|
801
|
-
sum_a_z_f64x8 = _mm512_add_pd(sum_a_z_f64x8, _mm512_add_pd(a_z_low_f64x8, a_z_high_f64x8));
|
|
802
|
-
sum_b_x_f64x8 = _mm512_add_pd(sum_b_x_f64x8, _mm512_add_pd(b_x_low_f64x8, b_x_high_f64x8));
|
|
803
|
-
sum_b_y_f64x8 = _mm512_add_pd(sum_b_y_f64x8, _mm512_add_pd(b_y_low_f64x8, b_y_high_f64x8));
|
|
804
|
-
sum_b_z_f64x8 = _mm512_add_pd(sum_b_z_f64x8, _mm512_add_pd(b_z_low_f64x8, b_z_high_f64x8));
|
|
805
|
-
|
|
806
776
|
__m512d delta_x_low_f64x8 = _mm512_sub_pd(a_x_low_f64x8, b_x_low_f64x8);
|
|
807
777
|
__m512d delta_x_high_f64x8 = _mm512_sub_pd(a_x_high_f64x8, b_x_high_f64x8);
|
|
808
778
|
__m512d delta_y_low_f64x8 = _mm512_sub_pd(a_y_low_f64x8, b_y_low_f64x8);
|
|
@@ -817,32 +787,10 @@ NK_PUBLIC void nk_rmsd_f32_skylake(nk_f32_t const *a, nk_f32_t const *b, nk_size
|
|
|
817
787
|
sum_squared_z_f64x8 = _mm512_fmadd_pd(delta_z_high_f64x8, delta_z_high_f64x8, sum_squared_z_f64x8);
|
|
818
788
|
}
|
|
819
789
|
|
|
820
|
-
// Reduce and compute centroids
|
|
821
|
-
nk_f64_t inv_n = 1.0 / (nk_f64_t)n;
|
|
822
|
-
nk_f64_t total_ax = _mm512_reduce_add_pd(sum_a_x_f64x8);
|
|
823
|
-
nk_f64_t total_ay = _mm512_reduce_add_pd(sum_a_y_f64x8);
|
|
824
|
-
nk_f64_t total_az = _mm512_reduce_add_pd(sum_a_z_f64x8);
|
|
825
|
-
nk_f64_t total_bx = _mm512_reduce_add_pd(sum_b_x_f64x8);
|
|
826
|
-
nk_f64_t total_by = _mm512_reduce_add_pd(sum_b_y_f64x8);
|
|
827
|
-
nk_f64_t total_bz = _mm512_reduce_add_pd(sum_b_z_f64x8);
|
|
828
790
|
nk_f64_t total_sq_x = _mm512_reduce_add_pd(sum_squared_x_f64x8);
|
|
829
791
|
nk_f64_t total_sq_y = _mm512_reduce_add_pd(sum_squared_y_f64x8);
|
|
830
792
|
nk_f64_t total_sq_z = _mm512_reduce_add_pd(sum_squared_z_f64x8);
|
|
831
|
-
|
|
832
|
-
nk_f64_t centroid_a_x = total_ax * inv_n, centroid_a_y = total_ay * inv_n, centroid_a_z = total_az * inv_n;
|
|
833
|
-
nk_f64_t centroid_b_x = total_bx * inv_n, centroid_b_y = total_by * inv_n, centroid_b_z = total_bz * inv_n;
|
|
834
|
-
if (a_centroid)
|
|
835
|
-
a_centroid[0] = (nk_f32_t)centroid_a_x, a_centroid[1] = (nk_f32_t)centroid_a_y,
|
|
836
|
-
a_centroid[2] = (nk_f32_t)centroid_a_z;
|
|
837
|
-
if (b_centroid)
|
|
838
|
-
b_centroid[0] = (nk_f32_t)centroid_b_x, b_centroid[1] = (nk_f32_t)centroid_b_y,
|
|
839
|
-
b_centroid[2] = (nk_f32_t)centroid_b_z;
|
|
840
|
-
|
|
841
|
-
nk_f64_t mean_diff_x = centroid_a_x - centroid_b_x, mean_diff_y = centroid_a_y - centroid_b_y,
|
|
842
|
-
mean_diff_z = centroid_a_z - centroid_b_z;
|
|
843
|
-
nk_f64_t sum_squared = total_sq_x + total_sq_y + total_sq_z;
|
|
844
|
-
nk_f64_t mean_diff_sq = mean_diff_x * mean_diff_x + mean_diff_y * mean_diff_y + mean_diff_z * mean_diff_z;
|
|
845
|
-
*result = nk_f64_sqrt_haswell(sum_squared * inv_n - mean_diff_sq);
|
|
793
|
+
*result = nk_f64_sqrt_haswell((total_sq_x + total_sq_y + total_sq_z) / (nk_f64_t)n);
|
|
846
794
|
}
|
|
847
795
|
|
|
848
796
|
NK_PUBLIC void nk_kabsch_f32_skylake(nk_f32_t const *a, nk_f32_t const *b, nk_size_t n, nk_f32_t *a_centroid,
|
|
@@ -1008,21 +956,15 @@ NK_PUBLIC void nk_kabsch_f32_skylake(nk_f32_t const *a, nk_f32_t const *b, nk_si
|
|
|
1008
956
|
|
|
1009
957
|
NK_PUBLIC void nk_rmsd_f64_skylake(nk_f64_t const *a, nk_f64_t const *b, nk_size_t n, nk_f64_t *a_centroid,
|
|
1010
958
|
nk_f64_t *b_centroid, nk_f64_t *rotation, nk_f64_t *scale, nk_f64_t *result) {
|
|
1011
|
-
// RMSD uses identity rotation and scale=1.0.
|
|
1012
959
|
if (rotation)
|
|
1013
960
|
rotation[0] = 1, rotation[1] = 0, rotation[2] = 0, rotation[3] = 0, rotation[4] = 1, rotation[5] = 0,
|
|
1014
961
|
rotation[6] = 0, rotation[7] = 0, rotation[8] = 1;
|
|
1015
962
|
if (scale) *scale = 1.0;
|
|
1016
|
-
|
|
1017
|
-
|
|
1018
|
-
|
|
1019
|
-
// = √(E[(a-b)²] - (ā - b̄)²)
|
|
963
|
+
if (a_centroid) a_centroid[0] = 0, a_centroid[1] = 0, a_centroid[2] = 0;
|
|
964
|
+
if (b_centroid) b_centroid[0] = 0, b_centroid[1] = 0, b_centroid[2] = 0;
|
|
965
|
+
|
|
1020
966
|
__m512i const gather_idx_i64x8 = _mm512_setr_epi64(0, 3, 6, 9, 12, 15, 18, 21);
|
|
1021
967
|
__m512d const zeros_f64x8 = _mm512_setzero_pd();
|
|
1022
|
-
|
|
1023
|
-
// Accumulators for centroids and squared differences
|
|
1024
|
-
__m512d sum_a_x_f64x8 = zeros_f64x8, sum_a_y_f64x8 = zeros_f64x8, sum_a_z_f64x8 = zeros_f64x8;
|
|
1025
|
-
__m512d sum_b_x_f64x8 = zeros_f64x8, sum_b_y_f64x8 = zeros_f64x8, sum_b_z_f64x8 = zeros_f64x8;
|
|
1026
968
|
__m512d sum_squared_x_f64x8 = zeros_f64x8, sum_squared_y_f64x8 = zeros_f64x8, sum_squared_z_f64x8 = zeros_f64x8;
|
|
1027
969
|
|
|
1028
970
|
__m512d a_x_f64x8, a_y_f64x8, a_z_f64x8, b_x_f64x8, b_y_f64x8, b_z_f64x8;
|
|
@@ -1034,13 +976,6 @@ NK_PUBLIC void nk_rmsd_f64_skylake(nk_f64_t const *a, nk_f64_t const *b, nk_size
|
|
|
1034
976
|
nk_deinterleave_f64x8_skylake_(a + i * 3, &a_x_f64x8, &a_y_f64x8, &a_z_f64x8);
|
|
1035
977
|
nk_deinterleave_f64x8_skylake_(b + i * 3, &b_x_f64x8, &b_y_f64x8, &b_z_f64x8);
|
|
1036
978
|
|
|
1037
|
-
sum_a_x_f64x8 = _mm512_add_pd(sum_a_x_f64x8, a_x_f64x8),
|
|
1038
|
-
sum_a_y_f64x8 = _mm512_add_pd(sum_a_y_f64x8, a_y_f64x8),
|
|
1039
|
-
sum_a_z_f64x8 = _mm512_add_pd(sum_a_z_f64x8, a_z_f64x8);
|
|
1040
|
-
sum_b_x_f64x8 = _mm512_add_pd(sum_b_x_f64x8, b_x_f64x8),
|
|
1041
|
-
sum_b_y_f64x8 = _mm512_add_pd(sum_b_y_f64x8, b_y_f64x8),
|
|
1042
|
-
sum_b_z_f64x8 = _mm512_add_pd(sum_b_z_f64x8, b_z_f64x8);
|
|
1043
|
-
|
|
1044
979
|
__m512d delta_x_f64x8 = _mm512_sub_pd(a_x_f64x8, b_x_f64x8),
|
|
1045
980
|
delta_y_f64x8 = _mm512_sub_pd(a_y_f64x8, b_y_f64x8),
|
|
1046
981
|
delta_z_f64x8 = _mm512_sub_pd(a_z_f64x8, b_z_f64x8);
|
|
@@ -1053,13 +988,6 @@ NK_PUBLIC void nk_rmsd_f64_skylake(nk_f64_t const *a, nk_f64_t const *b, nk_size
|
|
|
1053
988
|
nk_deinterleave_f64x8_skylake_(a + (i + 8) * 3, &a_x1_f64x8, &a_y1_f64x8, &a_z1_f64x8);
|
|
1054
989
|
nk_deinterleave_f64x8_skylake_(b + (i + 8) * 3, &b_x1_f64x8, &b_y1_f64x8, &b_z1_f64x8);
|
|
1055
990
|
|
|
1056
|
-
sum_a_x_f64x8 = _mm512_add_pd(sum_a_x_f64x8, a_x1_f64x8),
|
|
1057
|
-
sum_a_y_f64x8 = _mm512_add_pd(sum_a_y_f64x8, a_y1_f64x8),
|
|
1058
|
-
sum_a_z_f64x8 = _mm512_add_pd(sum_a_z_f64x8, a_z1_f64x8);
|
|
1059
|
-
sum_b_x_f64x8 = _mm512_add_pd(sum_b_x_f64x8, b_x1_f64x8),
|
|
1060
|
-
sum_b_y_f64x8 = _mm512_add_pd(sum_b_y_f64x8, b_y1_f64x8),
|
|
1061
|
-
sum_b_z_f64x8 = _mm512_add_pd(sum_b_z_f64x8, b_z1_f64x8);
|
|
1062
|
-
|
|
1063
991
|
__m512d delta_x1_f64x8 = _mm512_sub_pd(a_x1_f64x8, b_x1_f64x8),
|
|
1064
992
|
delta_y1_f64x8 = _mm512_sub_pd(a_y1_f64x8, b_y1_f64x8),
|
|
1065
993
|
delta_z1_f64x8 = _mm512_sub_pd(a_z1_f64x8, b_z1_f64x8);
|
|
@@ -1073,13 +1001,6 @@ NK_PUBLIC void nk_rmsd_f64_skylake(nk_f64_t const *a, nk_f64_t const *b, nk_size
|
|
|
1073
1001
|
nk_deinterleave_f64x8_skylake_(a + i * 3, &a_x_f64x8, &a_y_f64x8, &a_z_f64x8);
|
|
1074
1002
|
nk_deinterleave_f64x8_skylake_(b + i * 3, &b_x_f64x8, &b_y_f64x8, &b_z_f64x8);
|
|
1075
1003
|
|
|
1076
|
-
sum_a_x_f64x8 = _mm512_add_pd(sum_a_x_f64x8, a_x_f64x8),
|
|
1077
|
-
sum_a_y_f64x8 = _mm512_add_pd(sum_a_y_f64x8, a_y_f64x8),
|
|
1078
|
-
sum_a_z_f64x8 = _mm512_add_pd(sum_a_z_f64x8, a_z_f64x8);
|
|
1079
|
-
sum_b_x_f64x8 = _mm512_add_pd(sum_b_x_f64x8, b_x_f64x8),
|
|
1080
|
-
sum_b_y_f64x8 = _mm512_add_pd(sum_b_y_f64x8, b_y_f64x8),
|
|
1081
|
-
sum_b_z_f64x8 = _mm512_add_pd(sum_b_z_f64x8, b_z_f64x8);
|
|
1082
|
-
|
|
1083
1004
|
__m512d delta_x_f64x8 = _mm512_sub_pd(a_x_f64x8, b_x_f64x8),
|
|
1084
1005
|
delta_y_f64x8 = _mm512_sub_pd(a_y_f64x8, b_y_f64x8),
|
|
1085
1006
|
delta_z_f64x8 = _mm512_sub_pd(a_z_f64x8, b_z_f64x8);
|
|
@@ -1102,13 +1023,6 @@ NK_PUBLIC void nk_rmsd_f64_skylake(nk_f64_t const *a, nk_f64_t const *b, nk_size
|
|
|
1102
1023
|
b_y_f64x8 = _mm512_mask_i64gather_pd(zeros_f64x8, mask, gather_idx_i64x8, b_tail + 1, 8);
|
|
1103
1024
|
b_z_f64x8 = _mm512_mask_i64gather_pd(zeros_f64x8, mask, gather_idx_i64x8, b_tail + 2, 8);
|
|
1104
1025
|
|
|
1105
|
-
sum_a_x_f64x8 = _mm512_add_pd(sum_a_x_f64x8, a_x_f64x8),
|
|
1106
|
-
sum_a_y_f64x8 = _mm512_add_pd(sum_a_y_f64x8, a_y_f64x8),
|
|
1107
|
-
sum_a_z_f64x8 = _mm512_add_pd(sum_a_z_f64x8, a_z_f64x8);
|
|
1108
|
-
sum_b_x_f64x8 = _mm512_add_pd(sum_b_x_f64x8, b_x_f64x8),
|
|
1109
|
-
sum_b_y_f64x8 = _mm512_add_pd(sum_b_y_f64x8, b_y_f64x8),
|
|
1110
|
-
sum_b_z_f64x8 = _mm512_add_pd(sum_b_z_f64x8, b_z_f64x8);
|
|
1111
|
-
|
|
1112
1026
|
__m512d delta_x_f64x8 = _mm512_sub_pd(a_x_f64x8, b_x_f64x8),
|
|
1113
1027
|
delta_y_f64x8 = _mm512_sub_pd(a_y_f64x8, b_y_f64x8),
|
|
1114
1028
|
delta_z_f64x8 = _mm512_sub_pd(a_z_f64x8, b_z_f64x8);
|
|
@@ -1118,14 +1032,6 @@ NK_PUBLIC void nk_rmsd_f64_skylake(nk_f64_t const *a, nk_f64_t const *b, nk_size
|
|
|
1118
1032
|
i = n;
|
|
1119
1033
|
}
|
|
1120
1034
|
|
|
1121
|
-
// Reduce and compute centroids.
|
|
1122
|
-
nk_f64_t inv_n = 1.0 / (nk_f64_t)n;
|
|
1123
|
-
nk_f64_t total_ax = nk_reduce_stable_f64x8_skylake_(sum_a_x_f64x8), total_ax_compensation = 0.0;
|
|
1124
|
-
nk_f64_t total_ay = nk_reduce_stable_f64x8_skylake_(sum_a_y_f64x8), total_ay_compensation = 0.0;
|
|
1125
|
-
nk_f64_t total_az = nk_reduce_stable_f64x8_skylake_(sum_a_z_f64x8), total_az_compensation = 0.0;
|
|
1126
|
-
nk_f64_t total_bx = nk_reduce_stable_f64x8_skylake_(sum_b_x_f64x8), total_bx_compensation = 0.0;
|
|
1127
|
-
nk_f64_t total_by = nk_reduce_stable_f64x8_skylake_(sum_b_y_f64x8), total_by_compensation = 0.0;
|
|
1128
|
-
nk_f64_t total_bz = nk_reduce_stable_f64x8_skylake_(sum_b_z_f64x8), total_bz_compensation = 0.0;
|
|
1129
1035
|
nk_f64_t total_squared_x = nk_reduce_stable_f64x8_skylake_(sum_squared_x_f64x8), total_squared_x_compensation = 0.0;
|
|
1130
1036
|
nk_f64_t total_squared_y = nk_reduce_stable_f64x8_skylake_(sum_squared_y_f64x8), total_squared_y_compensation = 0.0;
|
|
1131
1037
|
nk_f64_t total_squared_z = nk_reduce_stable_f64x8_skylake_(sum_squared_z_f64x8), total_squared_z_compensation = 0.0;
|
|
@@ -1133,37 +1039,15 @@ NK_PUBLIC void nk_rmsd_f64_skylake(nk_f64_t const *a, nk_f64_t const *b, nk_size
|
|
|
1133
1039
|
for (; i < n; ++i) {
|
|
1134
1040
|
nk_f64_t ax = a[i * 3 + 0], ay = a[i * 3 + 1], az = a[i * 3 + 2];
|
|
1135
1041
|
nk_f64_t bx = b[i * 3 + 0], by = b[i * 3 + 1], bz = b[i * 3 + 2];
|
|
1136
|
-
nk_accumulate_sum_f64_(&total_ax, &total_ax_compensation, ax);
|
|
1137
|
-
nk_accumulate_sum_f64_(&total_ay, &total_ay_compensation, ay);
|
|
1138
|
-
nk_accumulate_sum_f64_(&total_az, &total_az_compensation, az);
|
|
1139
|
-
nk_accumulate_sum_f64_(&total_bx, &total_bx_compensation, bx);
|
|
1140
|
-
nk_accumulate_sum_f64_(&total_by, &total_by_compensation, by);
|
|
1141
|
-
nk_accumulate_sum_f64_(&total_bz, &total_bz_compensation, bz);
|
|
1142
1042
|
nk_f64_t delta_x = ax - bx, delta_y = ay - by, delta_z = az - bz;
|
|
1143
1043
|
nk_accumulate_square_f64_(&total_squared_x, &total_squared_x_compensation, delta_x);
|
|
1144
1044
|
nk_accumulate_square_f64_(&total_squared_y, &total_squared_y_compensation, delta_y);
|
|
1145
1045
|
nk_accumulate_square_f64_(&total_squared_z, &total_squared_z_compensation, delta_z);
|
|
1146
1046
|
}
|
|
1147
1047
|
|
|
1148
|
-
total_ax += total_ax_compensation, total_ay += total_ay_compensation, total_az += total_az_compensation;
|
|
1149
|
-
total_bx += total_bx_compensation, total_by += total_by_compensation, total_bz += total_bz_compensation;
|
|
1150
1048
|
total_squared_x += total_squared_x_compensation, total_squared_y += total_squared_y_compensation,
|
|
1151
1049
|
total_squared_z += total_squared_z_compensation;
|
|
1152
|
-
|
|
1153
|
-
nk_f64_t centroid_a_x = total_ax * inv_n, centroid_a_y = total_ay * inv_n, centroid_a_z = total_az * inv_n;
|
|
1154
|
-
nk_f64_t centroid_b_x = total_bx * inv_n, centroid_b_y = total_by * inv_n, centroid_b_z = total_bz * inv_n;
|
|
1155
|
-
|
|
1156
|
-
if (a_centroid) a_centroid[0] = centroid_a_x, a_centroid[1] = centroid_a_y, a_centroid[2] = centroid_a_z;
|
|
1157
|
-
if (b_centroid) b_centroid[0] = centroid_b_x, b_centroid[1] = centroid_b_y, b_centroid[2] = centroid_b_z;
|
|
1158
|
-
|
|
1159
|
-
// Compute RMSD using the formula:
|
|
1160
|
-
// RMSD = √(E[(a-b)²] - (ā - b̄)²).
|
|
1161
|
-
nk_f64_t mean_diff_x = centroid_a_x - centroid_b_x, mean_diff_y = centroid_a_y - centroid_b_y,
|
|
1162
|
-
mean_diff_z = centroid_a_z - centroid_b_z;
|
|
1163
|
-
nk_f64_t sum_squared = total_squared_x + total_squared_y + total_squared_z;
|
|
1164
|
-
nk_f64_t mean_diff_sq = mean_diff_x * mean_diff_x + mean_diff_y * mean_diff_y + mean_diff_z * mean_diff_z;
|
|
1165
|
-
|
|
1166
|
-
*result = nk_f64_sqrt_haswell(sum_squared * inv_n - mean_diff_sq);
|
|
1050
|
+
*result = nk_f64_sqrt_haswell((total_squared_x + total_squared_y + total_squared_z) / (nk_f64_t)n);
|
|
1167
1051
|
}
|
|
1168
1052
|
|
|
1169
1053
|
NK_PUBLIC void nk_kabsch_f64_skylake(nk_f64_t const *a, nk_f64_t const *b, nk_size_t n, nk_f64_t *a_centroid,
|