numkong 7.4.5 → 7.6.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +1 -0
- package/binding.gyp +99 -5
- package/c/dispatch_e5m2.c +23 -3
- package/c/dispatch_f16.c +23 -0
- package/c/numkong.c +0 -13
- package/include/numkong/attention/sme.h +34 -31
- package/include/numkong/capabilities.h +2 -15
- package/include/numkong/cast/README.md +3 -0
- package/include/numkong/cast/haswell.h +28 -64
- package/include/numkong/cast/neon.h +15 -0
- package/include/numkong/cast/serial.h +17 -0
- package/include/numkong/cast/skylake.h +67 -52
- package/include/numkong/cast.h +1 -0
- package/include/numkong/curved/smef64.h +82 -62
- package/include/numkong/dot/README.md +1 -0
- package/include/numkong/dot/haswell.h +92 -13
- package/include/numkong/dot/rvvbf16.h +1 -1
- package/include/numkong/dot/rvvhalf.h +1 -1
- package/include/numkong/dot/serial.h +15 -0
- package/include/numkong/dot/skylake.h +61 -14
- package/include/numkong/dot/sve.h +6 -5
- package/include/numkong/dot/svebfdot.h +2 -1
- package/include/numkong/dot/svehalf.h +6 -5
- package/include/numkong/dot/svesdot.h +3 -2
- package/include/numkong/dots/README.md +2 -0
- package/include/numkong/dots/graniteamx.h +1167 -0
- package/include/numkong/dots/haswell.h +28 -28
- package/include/numkong/dots/sapphireamx.h +1 -1
- package/include/numkong/dots/serial.h +33 -11
- package/include/numkong/dots/skylake.h +28 -23
- package/include/numkong/dots/sme.h +172 -140
- package/include/numkong/dots/smebi32.h +14 -11
- package/include/numkong/dots/smef64.h +31 -26
- package/include/numkong/dots.h +41 -3
- package/include/numkong/each/serial.h +39 -0
- package/include/numkong/geospatial/haswell.h +1 -1
- package/include/numkong/geospatial/neon.h +1 -1
- package/include/numkong/geospatial/serial.h +15 -4
- package/include/numkong/geospatial/skylake.h +1 -1
- package/include/numkong/maxsim/serial.h +15 -0
- package/include/numkong/maxsim/sme.h +34 -33
- package/include/numkong/mesh/README.md +50 -44
- package/include/numkong/mesh/genoa.h +462 -0
- package/include/numkong/mesh/haswell.h +806 -933
- package/include/numkong/mesh/neon.h +871 -943
- package/include/numkong/mesh/neonbfdot.h +382 -522
- package/include/numkong/mesh/neonfhm.h +676 -0
- package/include/numkong/mesh/rvv.h +404 -319
- package/include/numkong/mesh/serial.h +225 -161
- package/include/numkong/mesh/skylake.h +1029 -1585
- package/include/numkong/mesh/v128relaxed.h +403 -377
- package/include/numkong/mesh.h +38 -0
- package/include/numkong/reduce/neon.h +29 -0
- package/include/numkong/reduce/neonbfdot.h +2 -2
- package/include/numkong/reduce/neonfhm.h +4 -4
- package/include/numkong/reduce/serial.h +15 -1
- package/include/numkong/reduce/sve.h +52 -0
- package/include/numkong/reduce.h +4 -0
- package/include/numkong/set/sve.h +6 -5
- package/include/numkong/sets/smebi32.h +35 -30
- package/include/numkong/sparse/serial.h +17 -2
- package/include/numkong/sparse/sve2.h +3 -2
- package/include/numkong/spatial/genoa.h +0 -68
- package/include/numkong/spatial/haswell.h +98 -56
- package/include/numkong/spatial/serial.h +15 -0
- package/include/numkong/spatial/skylake.h +114 -54
- package/include/numkong/spatial/sve.h +7 -6
- package/include/numkong/spatial/svebfdot.h +7 -4
- package/include/numkong/spatial/svehalf.h +5 -4
- package/include/numkong/spatial/svesdot.h +9 -8
- package/include/numkong/spatial.h +0 -12
- package/include/numkong/spatials/graniteamx.h +301 -0
- package/include/numkong/spatials/serial.h +39 -0
- package/include/numkong/spatials/skylake.h +2 -2
- package/include/numkong/spatials/sme.h +391 -350
- package/include/numkong/spatials/smef64.h +79 -70
- package/include/numkong/spatials.h +54 -4
- package/include/numkong/tensor.hpp +107 -23
- package/include/numkong/types.h +59 -0
- package/javascript/dist/cjs/numkong.js +13 -0
- package/javascript/dist/esm/numkong.js +13 -0
- package/javascript/numkong.c +59 -14
- package/javascript/numkong.ts +13 -0
- package/package.json +7 -7
- package/probes/probe.js +2 -2
- package/wasm/numkong.wasm +0 -0
|
@@ -840,28 +840,37 @@ nk_angular_e3m2_haswell_cycle:
|
|
|
840
840
|
}
|
|
841
841
|
|
|
842
842
|
NK_PUBLIC void nk_sqeuclidean_e4m3_haswell(nk_e4m3_t const *a, nk_e4m3_t const *b, nk_size_t n, nk_f32_t *result) {
|
|
843
|
-
|
|
843
|
+
// E4M3 has no free widen shift, so we call the Giesen-based 8-lane cast helper
|
|
844
|
+
// twice per 16-lane iter and run with two F32 accumulators to break the FMA chain.
|
|
845
|
+
__m256 first_acc_f32x8 = _mm256_setzero_ps();
|
|
846
|
+
__m256 second_acc_f32x8 = _mm256_setzero_ps();
|
|
847
|
+
__m128i a_u8x16, b_u8x16;
|
|
844
848
|
|
|
845
849
|
nk_sqeuclidean_e4m3_haswell_cycle:
|
|
846
|
-
if (n <
|
|
850
|
+
if (n < 16) {
|
|
847
851
|
nk_b128_vec_t a_vec, b_vec;
|
|
848
852
|
nk_partial_load_b8x16_serial_(a, &a_vec, n);
|
|
849
853
|
nk_partial_load_b8x16_serial_(b, &b_vec, n);
|
|
850
|
-
|
|
851
|
-
|
|
852
|
-
|
|
853
|
-
distance_sq_f32x8 = _mm256_fmadd_ps(diff_f32x8, diff_f32x8, distance_sq_f32x8);
|
|
854
|
+
a_u8x16 = a_vec.xmm;
|
|
855
|
+
b_u8x16 = b_vec.xmm;
|
|
856
|
+
n = 0;
|
|
854
857
|
}
|
|
855
858
|
else {
|
|
856
|
-
|
|
857
|
-
|
|
858
|
-
|
|
859
|
-
distance_sq_f32x8 = _mm256_fmadd_ps(diff_f32x8, diff_f32x8, distance_sq_f32x8);
|
|
860
|
-
n -= 8, a += 8, b += 8;
|
|
861
|
-
goto nk_sqeuclidean_e4m3_haswell_cycle;
|
|
859
|
+
a_u8x16 = _mm_loadu_si128((__m128i const *)a);
|
|
860
|
+
b_u8x16 = _mm_loadu_si128((__m128i const *)b);
|
|
861
|
+
a += 16, b += 16, n -= 16;
|
|
862
862
|
}
|
|
863
|
+
__m256 a_low_f32x8 = nk_e4m3x8_to_f32x8_haswell_(a_u8x16);
|
|
864
|
+
__m256 a_high_f32x8 = nk_e4m3x8_to_f32x8_haswell_(_mm_unpackhi_epi64(a_u8x16, a_u8x16));
|
|
865
|
+
__m256 b_low_f32x8 = nk_e4m3x8_to_f32x8_haswell_(b_u8x16);
|
|
866
|
+
__m256 b_high_f32x8 = nk_e4m3x8_to_f32x8_haswell_(_mm_unpackhi_epi64(b_u8x16, b_u8x16));
|
|
867
|
+
__m256 diff_low_f32x8 = _mm256_sub_ps(a_low_f32x8, b_low_f32x8);
|
|
868
|
+
__m256 diff_high_f32x8 = _mm256_sub_ps(a_high_f32x8, b_high_f32x8);
|
|
869
|
+
first_acc_f32x8 = _mm256_fmadd_ps(diff_low_f32x8, diff_low_f32x8, first_acc_f32x8);
|
|
870
|
+
second_acc_f32x8 = _mm256_fmadd_ps(diff_high_f32x8, diff_high_f32x8, second_acc_f32x8);
|
|
871
|
+
if (n) goto nk_sqeuclidean_e4m3_haswell_cycle;
|
|
863
872
|
|
|
864
|
-
*result = nk_reduce_add_f32x8_haswell_(
|
|
873
|
+
*result = nk_reduce_add_f32x8_haswell_(_mm256_add_ps(first_acc_f32x8, second_acc_f32x8));
|
|
865
874
|
}
|
|
866
875
|
|
|
867
876
|
NK_PUBLIC void nk_euclidean_e4m3_haswell(nk_e4m3_t const *a, nk_e4m3_t const *b, nk_size_t n, nk_f32_t *result) {
|
|
@@ -873,27 +882,33 @@ NK_PUBLIC void nk_angular_e4m3_haswell(nk_e4m3_t const *a, nk_e4m3_t const *b, n
|
|
|
873
882
|
__m256 dot_product_f32x8 = _mm256_setzero_ps();
|
|
874
883
|
__m256 a_norm_sq_f32x8 = _mm256_setzero_ps();
|
|
875
884
|
__m256 b_norm_sq_f32x8 = _mm256_setzero_ps();
|
|
885
|
+
__m128i a_u8x16, b_u8x16;
|
|
876
886
|
|
|
877
887
|
nk_angular_e4m3_haswell_cycle:
|
|
878
|
-
if (n <
|
|
888
|
+
if (n < 16) {
|
|
879
889
|
nk_b128_vec_t a_vec, b_vec;
|
|
880
890
|
nk_partial_load_b8x16_serial_(a, &a_vec, n);
|
|
881
891
|
nk_partial_load_b8x16_serial_(b, &b_vec, n);
|
|
882
|
-
|
|
883
|
-
|
|
884
|
-
|
|
885
|
-
a_norm_sq_f32x8 = _mm256_fmadd_ps(a_f32x8, a_f32x8, a_norm_sq_f32x8);
|
|
886
|
-
b_norm_sq_f32x8 = _mm256_fmadd_ps(b_f32x8, b_f32x8, b_norm_sq_f32x8);
|
|
892
|
+
a_u8x16 = a_vec.xmm;
|
|
893
|
+
b_u8x16 = b_vec.xmm;
|
|
894
|
+
n = 0;
|
|
887
895
|
}
|
|
888
896
|
else {
|
|
889
|
-
|
|
890
|
-
|
|
891
|
-
|
|
892
|
-
|
|
893
|
-
|
|
894
|
-
|
|
895
|
-
|
|
896
|
-
|
|
897
|
+
a_u8x16 = _mm_loadu_si128((__m128i const *)a);
|
|
898
|
+
b_u8x16 = _mm_loadu_si128((__m128i const *)b);
|
|
899
|
+
a += 16, b += 16, n -= 16;
|
|
900
|
+
}
|
|
901
|
+
__m256 a_low_f32x8 = nk_e4m3x8_to_f32x8_haswell_(a_u8x16);
|
|
902
|
+
__m256 a_high_f32x8 = nk_e4m3x8_to_f32x8_haswell_(_mm_unpackhi_epi64(a_u8x16, a_u8x16));
|
|
903
|
+
__m256 b_low_f32x8 = nk_e4m3x8_to_f32x8_haswell_(b_u8x16);
|
|
904
|
+
__m256 b_high_f32x8 = nk_e4m3x8_to_f32x8_haswell_(_mm_unpackhi_epi64(b_u8x16, b_u8x16));
|
|
905
|
+
dot_product_f32x8 = _mm256_fmadd_ps(a_low_f32x8, b_low_f32x8, dot_product_f32x8);
|
|
906
|
+
dot_product_f32x8 = _mm256_fmadd_ps(a_high_f32x8, b_high_f32x8, dot_product_f32x8);
|
|
907
|
+
a_norm_sq_f32x8 = _mm256_fmadd_ps(a_low_f32x8, a_low_f32x8, a_norm_sq_f32x8);
|
|
908
|
+
a_norm_sq_f32x8 = _mm256_fmadd_ps(a_high_f32x8, a_high_f32x8, a_norm_sq_f32x8);
|
|
909
|
+
b_norm_sq_f32x8 = _mm256_fmadd_ps(b_low_f32x8, b_low_f32x8, b_norm_sq_f32x8);
|
|
910
|
+
b_norm_sq_f32x8 = _mm256_fmadd_ps(b_high_f32x8, b_high_f32x8, b_norm_sq_f32x8);
|
|
911
|
+
if (n) goto nk_angular_e4m3_haswell_cycle;
|
|
897
912
|
|
|
898
913
|
nk_f32_t dot_product_f32 = nk_reduce_add_f32x8_haswell_(dot_product_f32x8);
|
|
899
914
|
nk_f32_t a_norm_sq_f32 = nk_reduce_add_f32x8_haswell_(a_norm_sq_f32x8);
|
|
@@ -902,28 +917,44 @@ nk_angular_e4m3_haswell_cycle:
|
|
|
902
917
|
}
|
|
903
918
|
|
|
904
919
|
NK_PUBLIC void nk_sqeuclidean_e5m2_haswell(nk_e5m2_t const *a, nk_e5m2_t const *b, nk_size_t n, nk_f32_t *result) {
|
|
905
|
-
|
|
920
|
+
// E5M2 shares F16's exponent bias (15): `byte << 8` equals the matching F16 encoding.
|
|
921
|
+
// `vpunpck*bw` against zero is the free widen+shift: zero byte in low half of each
|
|
922
|
+
// 16-bit lane, E5M2 byte in high half. Per-128-bit-lane scrambled; commutative sum
|
|
923
|
+
// reduction is invariant under that.
|
|
924
|
+
__m256 first_acc_f32x8 = _mm256_setzero_ps();
|
|
925
|
+
__m256 second_acc_f32x8 = _mm256_setzero_ps();
|
|
926
|
+
__m128i const zero_u8x16 = _mm_setzero_si128();
|
|
927
|
+
__m128i a_u8x16, b_u8x16;
|
|
906
928
|
|
|
907
929
|
nk_sqeuclidean_e5m2_haswell_cycle:
|
|
908
|
-
if (n <
|
|
930
|
+
if (n < 16) {
|
|
909
931
|
nk_b128_vec_t a_vec, b_vec;
|
|
910
932
|
nk_partial_load_b8x16_serial_(a, &a_vec, n);
|
|
911
933
|
nk_partial_load_b8x16_serial_(b, &b_vec, n);
|
|
912
|
-
|
|
913
|
-
|
|
914
|
-
|
|
915
|
-
distance_sq_f32x8 = _mm256_fmadd_ps(diff_f32x8, diff_f32x8, distance_sq_f32x8);
|
|
934
|
+
a_u8x16 = a_vec.xmm;
|
|
935
|
+
b_u8x16 = b_vec.xmm;
|
|
936
|
+
n = 0;
|
|
916
937
|
}
|
|
917
938
|
else {
|
|
918
|
-
|
|
919
|
-
|
|
920
|
-
|
|
921
|
-
|
|
922
|
-
|
|
923
|
-
|
|
924
|
-
|
|
925
|
-
|
|
926
|
-
|
|
939
|
+
a_u8x16 = _mm_loadu_si128((__m128i const *)a);
|
|
940
|
+
b_u8x16 = _mm_loadu_si128((__m128i const *)b);
|
|
941
|
+
a += 16, b += 16, n -= 16;
|
|
942
|
+
}
|
|
943
|
+
__m128i a_even_f16x8 = _mm_unpacklo_epi8(zero_u8x16, a_u8x16);
|
|
944
|
+
__m128i a_odd_f16x8 = _mm_unpackhi_epi8(zero_u8x16, a_u8x16);
|
|
945
|
+
__m128i b_even_f16x8 = _mm_unpacklo_epi8(zero_u8x16, b_u8x16);
|
|
946
|
+
__m128i b_odd_f16x8 = _mm_unpackhi_epi8(zero_u8x16, b_u8x16);
|
|
947
|
+
__m256 a_first_f32x8 = _mm256_cvtph_ps(a_even_f16x8);
|
|
948
|
+
__m256 a_second_f32x8 = _mm256_cvtph_ps(a_odd_f16x8);
|
|
949
|
+
__m256 b_first_f32x8 = _mm256_cvtph_ps(b_even_f16x8);
|
|
950
|
+
__m256 b_second_f32x8 = _mm256_cvtph_ps(b_odd_f16x8);
|
|
951
|
+
__m256 diff_first_f32x8 = _mm256_sub_ps(a_first_f32x8, b_first_f32x8);
|
|
952
|
+
__m256 diff_second_f32x8 = _mm256_sub_ps(a_second_f32x8, b_second_f32x8);
|
|
953
|
+
first_acc_f32x8 = _mm256_fmadd_ps(diff_first_f32x8, diff_first_f32x8, first_acc_f32x8);
|
|
954
|
+
second_acc_f32x8 = _mm256_fmadd_ps(diff_second_f32x8, diff_second_f32x8, second_acc_f32x8);
|
|
955
|
+
if (n) goto nk_sqeuclidean_e5m2_haswell_cycle;
|
|
956
|
+
|
|
957
|
+
*result = nk_reduce_add_f32x8_haswell_(_mm256_add_ps(first_acc_f32x8, second_acc_f32x8));
|
|
927
958
|
}
|
|
928
959
|
|
|
929
960
|
NK_PUBLIC void nk_euclidean_e5m2_haswell(nk_e5m2_t const *a, nk_e5m2_t const *b, nk_size_t n, nk_f32_t *result) {
|
|
@@ -935,27 +966,38 @@ NK_PUBLIC void nk_angular_e5m2_haswell(nk_e5m2_t const *a, nk_e5m2_t const *b, n
|
|
|
935
966
|
__m256 dot_product_f32x8 = _mm256_setzero_ps();
|
|
936
967
|
__m256 a_norm_sq_f32x8 = _mm256_setzero_ps();
|
|
937
968
|
__m256 b_norm_sq_f32x8 = _mm256_setzero_ps();
|
|
969
|
+
__m128i const zero_u8x16 = _mm_setzero_si128();
|
|
970
|
+
__m128i a_u8x16, b_u8x16;
|
|
938
971
|
|
|
939
972
|
nk_angular_e5m2_haswell_cycle:
|
|
940
|
-
if (n <
|
|
973
|
+
if (n < 16) {
|
|
941
974
|
nk_b128_vec_t a_vec, b_vec;
|
|
942
975
|
nk_partial_load_b8x16_serial_(a, &a_vec, n);
|
|
943
976
|
nk_partial_load_b8x16_serial_(b, &b_vec, n);
|
|
944
|
-
|
|
945
|
-
|
|
946
|
-
|
|
947
|
-
a_norm_sq_f32x8 = _mm256_fmadd_ps(a_f32x8, a_f32x8, a_norm_sq_f32x8);
|
|
948
|
-
b_norm_sq_f32x8 = _mm256_fmadd_ps(b_f32x8, b_f32x8, b_norm_sq_f32x8);
|
|
977
|
+
a_u8x16 = a_vec.xmm;
|
|
978
|
+
b_u8x16 = b_vec.xmm;
|
|
979
|
+
n = 0;
|
|
949
980
|
}
|
|
950
981
|
else {
|
|
951
|
-
|
|
952
|
-
|
|
953
|
-
|
|
954
|
-
|
|
955
|
-
|
|
956
|
-
|
|
957
|
-
|
|
958
|
-
|
|
982
|
+
a_u8x16 = _mm_loadu_si128((__m128i const *)a);
|
|
983
|
+
b_u8x16 = _mm_loadu_si128((__m128i const *)b);
|
|
984
|
+
a += 16, b += 16, n -= 16;
|
|
985
|
+
}
|
|
986
|
+
__m128i a_even_f16x8 = _mm_unpacklo_epi8(zero_u8x16, a_u8x16);
|
|
987
|
+
__m128i a_odd_f16x8 = _mm_unpackhi_epi8(zero_u8x16, a_u8x16);
|
|
988
|
+
__m128i b_even_f16x8 = _mm_unpacklo_epi8(zero_u8x16, b_u8x16);
|
|
989
|
+
__m128i b_odd_f16x8 = _mm_unpackhi_epi8(zero_u8x16, b_u8x16);
|
|
990
|
+
__m256 a_first_f32x8 = _mm256_cvtph_ps(a_even_f16x8);
|
|
991
|
+
__m256 a_second_f32x8 = _mm256_cvtph_ps(a_odd_f16x8);
|
|
992
|
+
__m256 b_first_f32x8 = _mm256_cvtph_ps(b_even_f16x8);
|
|
993
|
+
__m256 b_second_f32x8 = _mm256_cvtph_ps(b_odd_f16x8);
|
|
994
|
+
dot_product_f32x8 = _mm256_fmadd_ps(a_first_f32x8, b_first_f32x8, dot_product_f32x8);
|
|
995
|
+
dot_product_f32x8 = _mm256_fmadd_ps(a_second_f32x8, b_second_f32x8, dot_product_f32x8);
|
|
996
|
+
a_norm_sq_f32x8 = _mm256_fmadd_ps(a_first_f32x8, a_first_f32x8, a_norm_sq_f32x8);
|
|
997
|
+
a_norm_sq_f32x8 = _mm256_fmadd_ps(a_second_f32x8, a_second_f32x8, a_norm_sq_f32x8);
|
|
998
|
+
b_norm_sq_f32x8 = _mm256_fmadd_ps(b_first_f32x8, b_first_f32x8, b_norm_sq_f32x8);
|
|
999
|
+
b_norm_sq_f32x8 = _mm256_fmadd_ps(b_second_f32x8, b_second_f32x8, b_norm_sq_f32x8);
|
|
1000
|
+
if (n) goto nk_angular_e5m2_haswell_cycle;
|
|
959
1001
|
|
|
960
1002
|
nk_f32_t dot_product_f32 = nk_reduce_add_f32x8_haswell_(dot_product_f32x8);
|
|
961
1003
|
nk_f32_t a_norm_sq_f32 = nk_reduce_add_f32x8_haswell_(a_norm_sq_f32x8);
|
|
@@ -108,6 +108,15 @@ extern "C" {
|
|
|
108
108
|
} \
|
|
109
109
|
}
|
|
110
110
|
|
|
111
|
+
/* Keep the serial instantiations below actually scalar, regardless of build type.
|
|
112
|
+
* See dots/serial.h for rationale. */
|
|
113
|
+
#if defined(__clang__)
|
|
114
|
+
#pragma clang attribute push(__attribute__((noinline)), apply_to = function)
|
|
115
|
+
#elif defined(__GNUC__)
|
|
116
|
+
#pragma GCC push_options
|
|
117
|
+
#pragma GCC optimize("no-tree-vectorize", "no-tree-slp-vectorize", "no-ipa-cp-clone", "no-inline")
|
|
118
|
+
#endif
|
|
119
|
+
|
|
111
120
|
nk_define_angular_(f64, f64, f64, nk_assign_from_to_, nk_f64_rsqrt_serial) // nk_angular_f64_serial
|
|
112
121
|
nk_define_sqeuclidean_(f64, f64, f64, nk_assign_from_to_) // nk_sqeuclidean_f64_serial
|
|
113
122
|
nk_define_euclidean_(f64, f64, f64, f64, nk_assign_from_to_, nk_f64_sqrt_serial) // nk_euclidean_f64_serial
|
|
@@ -340,6 +349,12 @@ NK_INTERNAL void nk_euclidean_through_u32_from_dot_serial_(nk_b128_vec_t dots, n
|
|
|
340
349
|
}
|
|
341
350
|
}
|
|
342
351
|
|
|
352
|
+
#if defined(__clang__)
|
|
353
|
+
#pragma clang attribute pop
|
|
354
|
+
#elif defined(__GNUC__)
|
|
355
|
+
#pragma GCC pop_options
|
|
356
|
+
#endif
|
|
357
|
+
|
|
343
358
|
#if defined(__cplusplus)
|
|
344
359
|
} // extern "C"
|
|
345
360
|
#endif
|
|
@@ -346,28 +346,36 @@ nk_angular_f16_skylake_cycle:
|
|
|
346
346
|
}
|
|
347
347
|
|
|
348
348
|
NK_PUBLIC void nk_sqeuclidean_e4m3_skylake(nk_e4m3_t const *a, nk_e4m3_t const *b, nk_size_t n, nk_f32_t *result) {
|
|
349
|
-
|
|
350
|
-
|
|
349
|
+
// E4M3 has no free widen shift (its 4-bit exponent doesn't line up with F16's 5-bit
|
|
350
|
+
// at bit 10), so we call the Giesen-based 16-lane cast helper twice per iter and
|
|
351
|
+
// run with two F32 accumulators to break the FMA dependency chain.
|
|
352
|
+
__m512 first_acc_f32x16 = _mm512_setzero_ps();
|
|
353
|
+
__m512 second_acc_f32x16 = _mm512_setzero_ps();
|
|
354
|
+
__m256i a_u8x32, b_u8x32;
|
|
351
355
|
|
|
352
356
|
nk_sqeuclidean_e4m3_skylake_cycle:
|
|
353
|
-
if (n <
|
|
354
|
-
|
|
355
|
-
|
|
356
|
-
|
|
357
|
+
if (n < 32) {
|
|
358
|
+
__mmask32 mask = (__mmask32)_bzhi_u32(0xFFFFFFFF, (unsigned int)n);
|
|
359
|
+
a_u8x32 = _mm256_maskz_loadu_epi8(mask, a);
|
|
360
|
+
b_u8x32 = _mm256_maskz_loadu_epi8(mask, b);
|
|
357
361
|
n = 0;
|
|
358
362
|
}
|
|
359
363
|
else {
|
|
360
|
-
|
|
361
|
-
|
|
362
|
-
a +=
|
|
364
|
+
a_u8x32 = _mm256_loadu_si256((__m256i const *)a);
|
|
365
|
+
b_u8x32 = _mm256_loadu_si256((__m256i const *)b);
|
|
366
|
+
a += 32, b += 32, n -= 32;
|
|
363
367
|
}
|
|
364
|
-
__m512
|
|
365
|
-
__m512
|
|
366
|
-
__m512
|
|
367
|
-
|
|
368
|
+
__m512 a_low_f32x16 = nk_e4m3x16_to_f32x16_skylake_(_mm256_castsi256_si128(a_u8x32));
|
|
369
|
+
__m512 a_high_f32x16 = nk_e4m3x16_to_f32x16_skylake_(_mm256_extracti128_si256(a_u8x32, 1));
|
|
370
|
+
__m512 b_low_f32x16 = nk_e4m3x16_to_f32x16_skylake_(_mm256_castsi256_si128(b_u8x32));
|
|
371
|
+
__m512 b_high_f32x16 = nk_e4m3x16_to_f32x16_skylake_(_mm256_extracti128_si256(b_u8x32, 1));
|
|
372
|
+
__m512 diff_low_f32x16 = _mm512_sub_ps(a_low_f32x16, b_low_f32x16);
|
|
373
|
+
__m512 diff_high_f32x16 = _mm512_sub_ps(a_high_f32x16, b_high_f32x16);
|
|
374
|
+
first_acc_f32x16 = _mm512_fmadd_ps(diff_low_f32x16, diff_low_f32x16, first_acc_f32x16);
|
|
375
|
+
second_acc_f32x16 = _mm512_fmadd_ps(diff_high_f32x16, diff_high_f32x16, second_acc_f32x16);
|
|
368
376
|
if (n) goto nk_sqeuclidean_e4m3_skylake_cycle;
|
|
369
377
|
|
|
370
|
-
*result = nk_reduce_add_f32x16_skylake_(
|
|
378
|
+
*result = nk_reduce_add_f32x16_skylake_(_mm512_add_ps(first_acc_f32x16, second_acc_f32x16));
|
|
371
379
|
}
|
|
372
380
|
|
|
373
381
|
NK_PUBLIC void nk_euclidean_e4m3_skylake(nk_e4m3_t const *a, nk_e4m3_t const *b, nk_size_t n, nk_f32_t *result) {
|
|
@@ -379,25 +387,30 @@ NK_PUBLIC void nk_angular_e4m3_skylake(nk_e4m3_t const *a, nk_e4m3_t const *b, n
|
|
|
379
387
|
__m512 dot_f32x16 = _mm512_setzero_ps();
|
|
380
388
|
__m512 a_norm_sq_f32x16 = _mm512_setzero_ps();
|
|
381
389
|
__m512 b_norm_sq_f32x16 = _mm512_setzero_ps();
|
|
382
|
-
|
|
390
|
+
__m256i a_u8x32, b_u8x32;
|
|
383
391
|
|
|
384
392
|
nk_angular_e4m3_skylake_cycle:
|
|
385
|
-
if (n <
|
|
386
|
-
|
|
387
|
-
|
|
388
|
-
|
|
393
|
+
if (n < 32) {
|
|
394
|
+
__mmask32 mask = (__mmask32)_bzhi_u32(0xFFFFFFFF, (unsigned int)n);
|
|
395
|
+
a_u8x32 = _mm256_maskz_loadu_epi8(mask, a);
|
|
396
|
+
b_u8x32 = _mm256_maskz_loadu_epi8(mask, b);
|
|
389
397
|
n = 0;
|
|
390
398
|
}
|
|
391
399
|
else {
|
|
392
|
-
|
|
393
|
-
|
|
394
|
-
a +=
|
|
400
|
+
a_u8x32 = _mm256_loadu_si256((__m256i const *)a);
|
|
401
|
+
b_u8x32 = _mm256_loadu_si256((__m256i const *)b);
|
|
402
|
+
a += 32, b += 32, n -= 32;
|
|
395
403
|
}
|
|
396
|
-
__m512
|
|
397
|
-
__m512
|
|
398
|
-
|
|
399
|
-
|
|
400
|
-
|
|
404
|
+
__m512 a_low_f32x16 = nk_e4m3x16_to_f32x16_skylake_(_mm256_castsi256_si128(a_u8x32));
|
|
405
|
+
__m512 a_high_f32x16 = nk_e4m3x16_to_f32x16_skylake_(_mm256_extracti128_si256(a_u8x32, 1));
|
|
406
|
+
__m512 b_low_f32x16 = nk_e4m3x16_to_f32x16_skylake_(_mm256_castsi256_si128(b_u8x32));
|
|
407
|
+
__m512 b_high_f32x16 = nk_e4m3x16_to_f32x16_skylake_(_mm256_extracti128_si256(b_u8x32, 1));
|
|
408
|
+
dot_f32x16 = _mm512_fmadd_ps(a_low_f32x16, b_low_f32x16, dot_f32x16);
|
|
409
|
+
dot_f32x16 = _mm512_fmadd_ps(a_high_f32x16, b_high_f32x16, dot_f32x16);
|
|
410
|
+
a_norm_sq_f32x16 = _mm512_fmadd_ps(a_low_f32x16, a_low_f32x16, a_norm_sq_f32x16);
|
|
411
|
+
a_norm_sq_f32x16 = _mm512_fmadd_ps(a_high_f32x16, a_high_f32x16, a_norm_sq_f32x16);
|
|
412
|
+
b_norm_sq_f32x16 = _mm512_fmadd_ps(b_low_f32x16, b_low_f32x16, b_norm_sq_f32x16);
|
|
413
|
+
b_norm_sq_f32x16 = _mm512_fmadd_ps(b_high_f32x16, b_high_f32x16, b_norm_sq_f32x16);
|
|
401
414
|
if (n) goto nk_angular_e4m3_skylake_cycle;
|
|
402
415
|
|
|
403
416
|
nk_f32_t dot_f32 = nk_reduce_add_f32x16_skylake_(dot_f32x16);
|
|
@@ -407,28 +420,53 @@ nk_angular_e4m3_skylake_cycle:
|
|
|
407
420
|
}
|
|
408
421
|
|
|
409
422
|
NK_PUBLIC void nk_sqeuclidean_e5m2_skylake(nk_e5m2_t const *a, nk_e5m2_t const *b, nk_size_t n, nk_f32_t *result) {
|
|
410
|
-
|
|
411
|
-
|
|
423
|
+
// E5M2 shares F16's exponent bias (15): `byte << 8` equals the matching F16 bit-pattern
|
|
424
|
+
// for normals, subnormals, zero, Inf, and NaN. We expose that shift for free by unpacking
|
|
425
|
+
// against zero — the zero byte lands in the low half of each 16-bit lane, the E5M2 byte
|
|
426
|
+
// in the high half. `vpunpck*bw` is per-128-bit-lane so the F32 outputs are lane-scrambled
|
|
427
|
+
// across 512 bits, but the commutative sum reduction is invariant under that.
|
|
428
|
+
__m512 first_acc_f32x16 = _mm512_setzero_ps();
|
|
429
|
+
__m512 second_acc_f32x16 = _mm512_setzero_ps();
|
|
430
|
+
__m512i const zero_u8x64 = _mm512_setzero_si512();
|
|
431
|
+
__m512i a_u8x64, b_u8x64;
|
|
412
432
|
|
|
413
433
|
nk_sqeuclidean_e5m2_skylake_cycle:
|
|
414
|
-
if (n <
|
|
415
|
-
|
|
416
|
-
|
|
417
|
-
|
|
434
|
+
if (n < 64) {
|
|
435
|
+
__mmask64 mask = _bzhi_u64(0xFFFFFFFFFFFFFFFFULL, (unsigned int)n);
|
|
436
|
+
a_u8x64 = _mm512_maskz_loadu_epi8(mask, a);
|
|
437
|
+
b_u8x64 = _mm512_maskz_loadu_epi8(mask, b);
|
|
418
438
|
n = 0;
|
|
419
439
|
}
|
|
420
440
|
else {
|
|
421
|
-
|
|
422
|
-
|
|
423
|
-
a +=
|
|
441
|
+
a_u8x64 = _mm512_loadu_si512((__m512i const *)a);
|
|
442
|
+
b_u8x64 = _mm512_loadu_si512((__m512i const *)b);
|
|
443
|
+
a += 64, b += 64, n -= 64;
|
|
424
444
|
}
|
|
425
|
-
|
|
426
|
-
|
|
427
|
-
|
|
428
|
-
|
|
445
|
+
__m512i a_even_f16x32 = _mm512_unpacklo_epi8(zero_u8x64, a_u8x64);
|
|
446
|
+
__m512i a_odd_f16x32 = _mm512_unpackhi_epi8(zero_u8x64, a_u8x64);
|
|
447
|
+
__m512i b_even_f16x32 = _mm512_unpacklo_epi8(zero_u8x64, b_u8x64);
|
|
448
|
+
__m512i b_odd_f16x32 = _mm512_unpackhi_epi8(zero_u8x64, b_u8x64);
|
|
449
|
+
|
|
450
|
+
__m512 a_first_f32x16 = _mm512_cvtph_ps(_mm512_castsi512_si256(a_even_f16x32));
|
|
451
|
+
__m512 a_second_f32x16 = _mm512_cvtph_ps(_mm512_extracti64x4_epi64(a_even_f16x32, 1));
|
|
452
|
+
__m512 a_third_f32x16 = _mm512_cvtph_ps(_mm512_castsi512_si256(a_odd_f16x32));
|
|
453
|
+
__m512 a_fourth_f32x16 = _mm512_cvtph_ps(_mm512_extracti64x4_epi64(a_odd_f16x32, 1));
|
|
454
|
+
__m512 b_first_f32x16 = _mm512_cvtph_ps(_mm512_castsi512_si256(b_even_f16x32));
|
|
455
|
+
__m512 b_second_f32x16 = _mm512_cvtph_ps(_mm512_extracti64x4_epi64(b_even_f16x32, 1));
|
|
456
|
+
__m512 b_third_f32x16 = _mm512_cvtph_ps(_mm512_castsi512_si256(b_odd_f16x32));
|
|
457
|
+
__m512 b_fourth_f32x16 = _mm512_cvtph_ps(_mm512_extracti64x4_epi64(b_odd_f16x32, 1));
|
|
458
|
+
|
|
459
|
+
__m512 diff_first_f32x16 = _mm512_sub_ps(a_first_f32x16, b_first_f32x16);
|
|
460
|
+
__m512 diff_second_f32x16 = _mm512_sub_ps(a_second_f32x16, b_second_f32x16);
|
|
461
|
+
__m512 diff_third_f32x16 = _mm512_sub_ps(a_third_f32x16, b_third_f32x16);
|
|
462
|
+
__m512 diff_fourth_f32x16 = _mm512_sub_ps(a_fourth_f32x16, b_fourth_f32x16);
|
|
463
|
+
first_acc_f32x16 = _mm512_fmadd_ps(diff_first_f32x16, diff_first_f32x16, first_acc_f32x16);
|
|
464
|
+
second_acc_f32x16 = _mm512_fmadd_ps(diff_second_f32x16, diff_second_f32x16, second_acc_f32x16);
|
|
465
|
+
first_acc_f32x16 = _mm512_fmadd_ps(diff_third_f32x16, diff_third_f32x16, first_acc_f32x16);
|
|
466
|
+
second_acc_f32x16 = _mm512_fmadd_ps(diff_fourth_f32x16, diff_fourth_f32x16, second_acc_f32x16);
|
|
429
467
|
if (n) goto nk_sqeuclidean_e5m2_skylake_cycle;
|
|
430
468
|
|
|
431
|
-
*result = nk_reduce_add_f32x16_skylake_(
|
|
469
|
+
*result = nk_reduce_add_f32x16_skylake_(_mm512_add_ps(first_acc_f32x16, second_acc_f32x16));
|
|
432
470
|
}
|
|
433
471
|
|
|
434
472
|
NK_PUBLIC void nk_euclidean_e5m2_skylake(nk_e5m2_t const *a, nk_e5m2_t const *b, nk_size_t n, nk_f32_t *result) {
|
|
@@ -440,25 +478,47 @@ NK_PUBLIC void nk_angular_e5m2_skylake(nk_e5m2_t const *a, nk_e5m2_t const *b, n
|
|
|
440
478
|
__m512 dot_f32x16 = _mm512_setzero_ps();
|
|
441
479
|
__m512 a_norm_sq_f32x16 = _mm512_setzero_ps();
|
|
442
480
|
__m512 b_norm_sq_f32x16 = _mm512_setzero_ps();
|
|
443
|
-
|
|
481
|
+
__m512i const zero_u8x64 = _mm512_setzero_si512();
|
|
482
|
+
__m512i a_u8x64, b_u8x64;
|
|
444
483
|
|
|
445
484
|
nk_angular_e5m2_skylake_cycle:
|
|
446
|
-
if (n <
|
|
447
|
-
|
|
448
|
-
|
|
449
|
-
|
|
485
|
+
if (n < 64) {
|
|
486
|
+
__mmask64 mask = _bzhi_u64(0xFFFFFFFFFFFFFFFFULL, (unsigned int)n);
|
|
487
|
+
a_u8x64 = _mm512_maskz_loadu_epi8(mask, a);
|
|
488
|
+
b_u8x64 = _mm512_maskz_loadu_epi8(mask, b);
|
|
450
489
|
n = 0;
|
|
451
490
|
}
|
|
452
491
|
else {
|
|
453
|
-
|
|
454
|
-
|
|
455
|
-
a +=
|
|
492
|
+
a_u8x64 = _mm512_loadu_si512((__m512i const *)a);
|
|
493
|
+
b_u8x64 = _mm512_loadu_si512((__m512i const *)b);
|
|
494
|
+
a += 64, b += 64, n -= 64;
|
|
456
495
|
}
|
|
457
|
-
|
|
458
|
-
|
|
459
|
-
|
|
460
|
-
|
|
461
|
-
|
|
496
|
+
__m512i a_even_f16x32 = _mm512_unpacklo_epi8(zero_u8x64, a_u8x64);
|
|
497
|
+
__m512i a_odd_f16x32 = _mm512_unpackhi_epi8(zero_u8x64, a_u8x64);
|
|
498
|
+
__m512i b_even_f16x32 = _mm512_unpacklo_epi8(zero_u8x64, b_u8x64);
|
|
499
|
+
__m512i b_odd_f16x32 = _mm512_unpackhi_epi8(zero_u8x64, b_u8x64);
|
|
500
|
+
|
|
501
|
+
__m512 a_first_f32x16 = _mm512_cvtph_ps(_mm512_castsi512_si256(a_even_f16x32));
|
|
502
|
+
__m512 a_second_f32x16 = _mm512_cvtph_ps(_mm512_extracti64x4_epi64(a_even_f16x32, 1));
|
|
503
|
+
__m512 a_third_f32x16 = _mm512_cvtph_ps(_mm512_castsi512_si256(a_odd_f16x32));
|
|
504
|
+
__m512 a_fourth_f32x16 = _mm512_cvtph_ps(_mm512_extracti64x4_epi64(a_odd_f16x32, 1));
|
|
505
|
+
__m512 b_first_f32x16 = _mm512_cvtph_ps(_mm512_castsi512_si256(b_even_f16x32));
|
|
506
|
+
__m512 b_second_f32x16 = _mm512_cvtph_ps(_mm512_extracti64x4_epi64(b_even_f16x32, 1));
|
|
507
|
+
__m512 b_third_f32x16 = _mm512_cvtph_ps(_mm512_castsi512_si256(b_odd_f16x32));
|
|
508
|
+
__m512 b_fourth_f32x16 = _mm512_cvtph_ps(_mm512_extracti64x4_epi64(b_odd_f16x32, 1));
|
|
509
|
+
|
|
510
|
+
dot_f32x16 = _mm512_fmadd_ps(a_first_f32x16, b_first_f32x16, dot_f32x16);
|
|
511
|
+
dot_f32x16 = _mm512_fmadd_ps(a_second_f32x16, b_second_f32x16, dot_f32x16);
|
|
512
|
+
dot_f32x16 = _mm512_fmadd_ps(a_third_f32x16, b_third_f32x16, dot_f32x16);
|
|
513
|
+
dot_f32x16 = _mm512_fmadd_ps(a_fourth_f32x16, b_fourth_f32x16, dot_f32x16);
|
|
514
|
+
a_norm_sq_f32x16 = _mm512_fmadd_ps(a_first_f32x16, a_first_f32x16, a_norm_sq_f32x16);
|
|
515
|
+
a_norm_sq_f32x16 = _mm512_fmadd_ps(a_second_f32x16, a_second_f32x16, a_norm_sq_f32x16);
|
|
516
|
+
a_norm_sq_f32x16 = _mm512_fmadd_ps(a_third_f32x16, a_third_f32x16, a_norm_sq_f32x16);
|
|
517
|
+
a_norm_sq_f32x16 = _mm512_fmadd_ps(a_fourth_f32x16, a_fourth_f32x16, a_norm_sq_f32x16);
|
|
518
|
+
b_norm_sq_f32x16 = _mm512_fmadd_ps(b_first_f32x16, b_first_f32x16, b_norm_sq_f32x16);
|
|
519
|
+
b_norm_sq_f32x16 = _mm512_fmadd_ps(b_second_f32x16, b_second_f32x16, b_norm_sq_f32x16);
|
|
520
|
+
b_norm_sq_f32x16 = _mm512_fmadd_ps(b_third_f32x16, b_third_f32x16, b_norm_sq_f32x16);
|
|
521
|
+
b_norm_sq_f32x16 = _mm512_fmadd_ps(b_fourth_f32x16, b_fourth_f32x16, b_norm_sq_f32x16);
|
|
462
522
|
if (n) goto nk_angular_e5m2_skylake_cycle;
|
|
463
523
|
|
|
464
524
|
nk_f32_t dot_f32 = nk_reduce_add_f32x16_skylake_(dot_f32x16);
|
|
@@ -36,6 +36,7 @@
|
|
|
36
36
|
#if NK_TARGET_SVE
|
|
37
37
|
|
|
38
38
|
#include "numkong/types.h"
|
|
39
|
+
#include "numkong/reduce/sve.h" // `nk_svaddv_f64_`
|
|
39
40
|
#include "numkong/spatial/neon.h" // `nk_f64_sqrt_neon`
|
|
40
41
|
#include "numkong/dot/sve.h" // `nk_dot_stable_sum_f64_sve_`
|
|
41
42
|
|
|
@@ -113,7 +114,7 @@ NK_PUBLIC void nk_sqeuclidean_f32_sve(nk_f32_t const *a, nk_f32_t const *b, nk_s
|
|
|
113
114
|
svfloat64_t diff_odd_f64x = svsub_f64_x(pred_odd_b64x, a_odd_f64x, b_odd_f64x);
|
|
114
115
|
dist_sq_f64x = svmla_f64_m(pred_odd_b64x, dist_sq_f64x, diff_odd_f64x, diff_odd_f64x);
|
|
115
116
|
}
|
|
116
|
-
nk_f64_t dist_sq_f64 =
|
|
117
|
+
nk_f64_t dist_sq_f64 = nk_svaddv_f64_(svptrue_b64(), dist_sq_f64x);
|
|
117
118
|
*result = dist_sq_f64;
|
|
118
119
|
}
|
|
119
120
|
|
|
@@ -149,9 +150,9 @@ NK_PUBLIC void nk_angular_f32_sve(nk_f32_t const *a, nk_f32_t const *b, nk_size_
|
|
|
149
150
|
b2_f64x = svmla_f64_m(pred_odd_b64x, b2_f64x, b_odd_f64x, b_odd_f64x);
|
|
150
151
|
}
|
|
151
152
|
|
|
152
|
-
nk_f64_t ab_f64 =
|
|
153
|
-
nk_f64_t a2_f64 =
|
|
154
|
-
nk_f64_t b2_f64 =
|
|
153
|
+
nk_f64_t ab_f64 = nk_svaddv_f64_(svptrue_b64(), ab_f64x);
|
|
154
|
+
nk_f64_t a2_f64 = nk_svaddv_f64_(svptrue_b64(), a2_f64x);
|
|
155
|
+
nk_f64_t b2_f64 = nk_svaddv_f64_(svptrue_b64(), b2_f64x);
|
|
155
156
|
*result = nk_angular_normalize_f64_neon_(ab_f64, a2_f64, b2_f64);
|
|
156
157
|
}
|
|
157
158
|
|
|
@@ -225,8 +226,8 @@ NK_PUBLIC void nk_angular_f64_sve(nk_f64_t const *a, nk_f64_t const *b, nk_size_
|
|
|
225
226
|
} while (i < n);
|
|
226
227
|
|
|
227
228
|
nk_f64_t ab_f64 = nk_dot_stable_sum_f64_sve_(predicate_all_b64x, ab_sum_f64x, ab_compensation_f64x);
|
|
228
|
-
nk_f64_t a2_f64 =
|
|
229
|
-
nk_f64_t b2_f64 =
|
|
229
|
+
nk_f64_t a2_f64 = nk_svaddv_f64_(predicate_all_b64x, a2_f64x);
|
|
230
|
+
nk_f64_t b2_f64 = nk_svaddv_f64_(predicate_all_b64x, b2_f64x);
|
|
230
231
|
*result = nk_angular_normalize_f64_neon_(ab_f64, a2_f64, b2_f64);
|
|
231
232
|
}
|
|
232
233
|
|
|
@@ -36,6 +36,7 @@
|
|
|
36
36
|
#if NK_TARGET_SVEBFDOT
|
|
37
37
|
|
|
38
38
|
#include "numkong/types.h"
|
|
39
|
+
#include "numkong/reduce/sve.h" // `nk_svaddv_f64_`
|
|
39
40
|
#include "numkong/spatial/neon.h" // `nk_f32_sqrt_neon`
|
|
40
41
|
|
|
41
42
|
#if defined(__cplusplus)
|
|
@@ -75,7 +76,9 @@ NK_PUBLIC void nk_sqeuclidean_bf16_svebfdot(nk_bf16_t const *a_enum, nk_bf16_t c
|
|
|
75
76
|
d2_high_f32x = svmla_f32_m(predicate_high_b32x, d2_high_f32x, a_minus_b_high_f32x, a_minus_b_high_f32x);
|
|
76
77
|
i += svcnth();
|
|
77
78
|
} while (i < n);
|
|
78
|
-
nk_f32_t
|
|
79
|
+
nk_f32_t d2_low = nk_svaddv_f32_(svptrue_b32(), d2_low_f32x);
|
|
80
|
+
nk_f32_t d2_high = nk_svaddv_f32_(svptrue_b32(), d2_high_f32x);
|
|
81
|
+
nk_f32_t d2 = d2_low + d2_high;
|
|
79
82
|
*result = d2;
|
|
80
83
|
}
|
|
81
84
|
NK_PUBLIC void nk_euclidean_bf16_svebfdot(nk_bf16_t const *a, nk_bf16_t const *b, nk_size_t n, nk_f32_t *result) {
|
|
@@ -101,9 +104,9 @@ NK_PUBLIC void nk_angular_bf16_svebfdot(nk_bf16_t const *a_enum, nk_bf16_t const
|
|
|
101
104
|
i += svcnth();
|
|
102
105
|
} while (i < n);
|
|
103
106
|
|
|
104
|
-
nk_f32_t ab =
|
|
105
|
-
nk_f32_t a2 =
|
|
106
|
-
nk_f32_t b2 =
|
|
107
|
+
nk_f32_t ab = nk_svaddv_f32_(svptrue_b32(), ab_f32x);
|
|
108
|
+
nk_f32_t a2 = nk_svaddv_f32_(svptrue_b32(), a2_f32x);
|
|
109
|
+
nk_f32_t b2 = nk_svaddv_f32_(svptrue_b32(), b2_f32x);
|
|
107
110
|
*result = nk_angular_normalize_f32_neon_(ab, a2, b2);
|
|
108
111
|
}
|
|
109
112
|
|
|
@@ -32,6 +32,7 @@
|
|
|
32
32
|
#if NK_TARGET_SVEHALF
|
|
33
33
|
|
|
34
34
|
#include "numkong/types.h"
|
|
35
|
+
#include "numkong/reduce/sve.h" // `nk_svaddv_f64_`
|
|
35
36
|
#include "numkong/spatial/neon.h" // `nk_f32_sqrt_neon`
|
|
36
37
|
|
|
37
38
|
#if defined(__cplusplus)
|
|
@@ -74,7 +75,7 @@ NK_PUBLIC void nk_sqeuclidean_f16_svehalf(nk_f16_t const *a_enum, nk_f16_t const
|
|
|
74
75
|
|
|
75
76
|
i += svcnth();
|
|
76
77
|
} while (i < n);
|
|
77
|
-
*result =
|
|
78
|
+
*result = nk_svaddv_f32_(svptrue_b32(), d2_f32x);
|
|
78
79
|
}
|
|
79
80
|
|
|
80
81
|
NK_PUBLIC void nk_euclidean_f16_svehalf(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n, nk_f32_t *result) {
|
|
@@ -114,9 +115,9 @@ NK_PUBLIC void nk_angular_f16_svehalf(nk_f16_t const *a_enum, nk_f16_t const *b_
|
|
|
114
115
|
i += svcnth();
|
|
115
116
|
} while (i < n);
|
|
116
117
|
|
|
117
|
-
nk_f32_t ab_f32 =
|
|
118
|
-
nk_f32_t a2_f32 =
|
|
119
|
-
nk_f32_t b2_f32 =
|
|
118
|
+
nk_f32_t ab_f32 = nk_svaddv_f32_(svptrue_b32(), ab_f32x);
|
|
119
|
+
nk_f32_t a2_f32 = nk_svaddv_f32_(svptrue_b32(), a2_f32x);
|
|
120
|
+
nk_f32_t b2_f32 = nk_svaddv_f32_(svptrue_b32(), b2_f32x);
|
|
120
121
|
*result = nk_angular_normalize_f32_neon_(ab_f32, a2_f32, b2_f32);
|
|
121
122
|
}
|
|
122
123
|
|
|
@@ -34,6 +34,7 @@
|
|
|
34
34
|
#if NK_TARGET_SVESDOT
|
|
35
35
|
|
|
36
36
|
#include "numkong/types.h"
|
|
37
|
+
#include "numkong/reduce/sve.h" // `nk_svaddv_f64_`
|
|
37
38
|
#include "numkong/spatial/neon.h" // `nk_angular_normalize_f32_neon_`, `nk_f32_sqrt_neon`
|
|
38
39
|
|
|
39
40
|
#if defined(__cplusplus)
|
|
@@ -58,7 +59,7 @@ NK_PUBLIC void nk_sqeuclidean_i8_svesdot(nk_i8_t const *a, nk_i8_t const *b, nk_
|
|
|
58
59
|
distance_sq_u32x = svdot_u32(distance_sq_u32x, diff_u8x, diff_u8x);
|
|
59
60
|
i += svcntb();
|
|
60
61
|
} while (i < n);
|
|
61
|
-
*result = (nk_u32_t)
|
|
62
|
+
*result = (nk_u32_t)nk_svaddv_u32_(svptrue_b32(), distance_sq_u32x);
|
|
62
63
|
}
|
|
63
64
|
NK_PUBLIC void nk_euclidean_i8_svesdot(nk_i8_t const *a, nk_i8_t const *b, nk_size_t n, nk_f32_t *result) {
|
|
64
65
|
nk_u32_t distance_sq_u32;
|
|
@@ -81,9 +82,9 @@ NK_PUBLIC void nk_angular_i8_svesdot(nk_i8_t const *a, nk_i8_t const *b, nk_size
|
|
|
81
82
|
i += svcntb();
|
|
82
83
|
} while (i < n);
|
|
83
84
|
|
|
84
|
-
nk_i32_t ab = (nk_i32_t)
|
|
85
|
-
nk_i32_t a2 = (nk_i32_t)
|
|
86
|
-
nk_i32_t b2 = (nk_i32_t)
|
|
85
|
+
nk_i32_t ab = (nk_i32_t)nk_svaddv_s32_(svptrue_b32(), ab_i32x);
|
|
86
|
+
nk_i32_t a2 = (nk_i32_t)nk_svaddv_s32_(svptrue_b32(), a2_i32x);
|
|
87
|
+
nk_i32_t b2 = (nk_i32_t)nk_svaddv_s32_(svptrue_b32(), b2_i32x);
|
|
87
88
|
*result = nk_angular_normalize_f32_neon_((nk_f32_t)ab, (nk_f32_t)a2, (nk_f32_t)b2);
|
|
88
89
|
}
|
|
89
90
|
|
|
@@ -98,7 +99,7 @@ NK_PUBLIC void nk_sqeuclidean_u8_svesdot(nk_u8_t const *a, nk_u8_t const *b, nk_
|
|
|
98
99
|
distance_sq_u32x = svdot_u32(distance_sq_u32x, diff_u8x, diff_u8x);
|
|
99
100
|
i += svcntb();
|
|
100
101
|
} while (i < n);
|
|
101
|
-
*result = (nk_u32_t)
|
|
102
|
+
*result = (nk_u32_t)nk_svaddv_u32_(svptrue_b32(), distance_sq_u32x);
|
|
102
103
|
}
|
|
103
104
|
NK_PUBLIC void nk_euclidean_u8_svesdot(nk_u8_t const *a, nk_u8_t const *b, nk_size_t n, nk_f32_t *result) {
|
|
104
105
|
nk_u32_t distance_sq_u32;
|
|
@@ -121,9 +122,9 @@ NK_PUBLIC void nk_angular_u8_svesdot(nk_u8_t const *a, nk_u8_t const *b, nk_size
|
|
|
121
122
|
i += svcntb();
|
|
122
123
|
} while (i < n);
|
|
123
124
|
|
|
124
|
-
nk_u32_t ab = (nk_u32_t)
|
|
125
|
-
nk_u32_t a2 = (nk_u32_t)
|
|
126
|
-
nk_u32_t b2 = (nk_u32_t)
|
|
125
|
+
nk_u32_t ab = (nk_u32_t)nk_svaddv_u32_(svptrue_b32(), ab_u32x);
|
|
126
|
+
nk_u32_t a2 = (nk_u32_t)nk_svaddv_u32_(svptrue_b32(), a2_u32x);
|
|
127
|
+
nk_u32_t b2 = (nk_u32_t)nk_svaddv_u32_(svptrue_b32(), b2_u32x);
|
|
127
128
|
*result = nk_angular_normalize_f32_neon_((nk_f32_t)ab, (nk_f32_t)a2, (nk_f32_t)b2);
|
|
128
129
|
}
|
|
129
130
|
|