numkong 7.0.0 → 7.4.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +239 -122
- package/binding.gyp +25 -491
- package/c/dispatch_bf16.c +59 -1
- package/c/dispatch_e2m3.c +41 -8
- package/c/dispatch_e3m2.c +49 -8
- package/c/dispatch_e4m3.c +51 -9
- package/c/dispatch_e5m2.c +45 -1
- package/c/dispatch_f16.c +79 -26
- package/c/dispatch_f16c.c +5 -5
- package/c/dispatch_f32.c +56 -0
- package/c/dispatch_f64.c +52 -0
- package/c/dispatch_i4.c +3 -0
- package/c/dispatch_i8.c +62 -3
- package/c/dispatch_other.c +18 -0
- package/c/dispatch_u1.c +54 -9
- package/c/dispatch_u4.c +3 -0
- package/c/dispatch_u8.c +64 -3
- package/c/numkong.c +3 -0
- package/include/README.md +79 -9
- package/include/numkong/attention/sapphireamx.h +278 -276
- package/include/numkong/attention/sme.h +983 -977
- package/include/numkong/attention.h +1 -1
- package/include/numkong/capabilities.h +289 -94
- package/include/numkong/cast/README.md +40 -40
- package/include/numkong/cast/diamond.h +64 -0
- package/include/numkong/cast/haswell.h +42 -194
- package/include/numkong/cast/icelake.h +42 -37
- package/include/numkong/cast/loongsonasx.h +252 -0
- package/include/numkong/cast/neon.h +216 -249
- package/include/numkong/cast/powervsx.h +449 -0
- package/include/numkong/cast/rvv.h +223 -274
- package/include/numkong/cast/sapphire.h +18 -18
- package/include/numkong/cast/serial.h +1018 -944
- package/include/numkong/cast/skylake.h +82 -23
- package/include/numkong/cast/v128relaxed.h +462 -105
- package/include/numkong/cast.h +24 -0
- package/include/numkong/cast.hpp +44 -0
- package/include/numkong/curved/README.md +17 -17
- package/include/numkong/curved/neon.h +131 -7
- package/include/numkong/curved/neonbfdot.h +6 -7
- package/include/numkong/curved/rvv.h +26 -26
- package/include/numkong/curved/smef64.h +186 -182
- package/include/numkong/curved.h +14 -18
- package/include/numkong/dot/README.md +154 -137
- package/include/numkong/dot/alder.h +43 -43
- package/include/numkong/dot/diamond.h +158 -0
- package/include/numkong/dot/genoa.h +4 -30
- package/include/numkong/dot/haswell.h +215 -180
- package/include/numkong/dot/icelake.h +190 -76
- package/include/numkong/dot/loongsonasx.h +671 -0
- package/include/numkong/dot/neon.h +124 -73
- package/include/numkong/dot/neonbfdot.h +11 -12
- package/include/numkong/dot/neonfhm.h +44 -46
- package/include/numkong/dot/neonfp8.h +323 -0
- package/include/numkong/dot/neonsdot.h +190 -76
- package/include/numkong/dot/powervsx.h +752 -0
- package/include/numkong/dot/rvv.h +92 -84
- package/include/numkong/dot/rvvbf16.h +12 -12
- package/include/numkong/dot/rvvhalf.h +12 -12
- package/include/numkong/dot/sapphire.h +4 -4
- package/include/numkong/dot/serial.h +66 -30
- package/include/numkong/dot/sierra.h +31 -31
- package/include/numkong/dot/skylake.h +142 -110
- package/include/numkong/dot/sve.h +217 -177
- package/include/numkong/dot/svebfdot.h +10 -10
- package/include/numkong/dot/svehalf.h +85 -41
- package/include/numkong/dot/svesdot.h +89 -0
- package/include/numkong/dot/v128relaxed.h +124 -89
- package/include/numkong/dot.h +114 -48
- package/include/numkong/dots/README.md +203 -203
- package/include/numkong/dots/alder.h +12 -9
- package/include/numkong/dots/diamond.h +86 -0
- package/include/numkong/dots/genoa.h +10 -4
- package/include/numkong/dots/haswell.h +63 -48
- package/include/numkong/dots/icelake.h +27 -18
- package/include/numkong/dots/loongsonasx.h +176 -0
- package/include/numkong/dots/neon.h +14 -11
- package/include/numkong/dots/neonbfdot.h +4 -3
- package/include/numkong/dots/neonfhm.h +11 -9
- package/include/numkong/dots/neonfp8.h +99 -0
- package/include/numkong/dots/neonsdot.h +48 -12
- package/include/numkong/dots/powervsx.h +194 -0
- package/include/numkong/dots/rvv.h +451 -344
- package/include/numkong/dots/sapphireamx.h +1028 -984
- package/include/numkong/dots/serial.h +213 -197
- package/include/numkong/dots/sierra.h +10 -7
- package/include/numkong/dots/skylake.h +47 -36
- package/include/numkong/dots/sme.h +2001 -2364
- package/include/numkong/dots/smebi32.h +175 -162
- package/include/numkong/dots/smef64.h +328 -323
- package/include/numkong/dots/v128relaxed.h +64 -41
- package/include/numkong/dots.h +573 -293
- package/include/numkong/dots.hpp +45 -43
- package/include/numkong/each/README.md +133 -137
- package/include/numkong/each/haswell.h +6 -6
- package/include/numkong/each/icelake.h +7 -7
- package/include/numkong/each/neon.h +76 -42
- package/include/numkong/each/neonbfdot.h +11 -12
- package/include/numkong/each/neonhalf.h +24 -116
- package/include/numkong/each/rvv.h +28 -28
- package/include/numkong/each/sapphire.h +27 -161
- package/include/numkong/each/serial.h +6 -6
- package/include/numkong/each/skylake.h +7 -7
- package/include/numkong/each/v128relaxed.h +562 -0
- package/include/numkong/each.h +148 -62
- package/include/numkong/each.hpp +2 -2
- package/include/numkong/geospatial/README.md +18 -18
- package/include/numkong/geospatial/haswell.h +365 -325
- package/include/numkong/geospatial/neon.h +350 -306
- package/include/numkong/geospatial/rvv.h +4 -4
- package/include/numkong/geospatial/skylake.h +376 -340
- package/include/numkong/geospatial/v128relaxed.h +366 -327
- package/include/numkong/geospatial.h +17 -17
- package/include/numkong/matrix.hpp +4 -4
- package/include/numkong/maxsim/README.md +14 -14
- package/include/numkong/maxsim/alder.h +6 -6
- package/include/numkong/maxsim/genoa.h +4 -4
- package/include/numkong/maxsim/haswell.h +6 -6
- package/include/numkong/maxsim/icelake.h +18 -18
- package/include/numkong/maxsim/neonsdot.h +21 -21
- package/include/numkong/maxsim/sapphireamx.h +14 -14
- package/include/numkong/maxsim/serial.h +6 -6
- package/include/numkong/maxsim/sme.h +221 -196
- package/include/numkong/maxsim/v128relaxed.h +6 -6
- package/include/numkong/mesh/README.md +62 -56
- package/include/numkong/mesh/haswell.h +339 -464
- package/include/numkong/mesh/neon.h +1100 -519
- package/include/numkong/mesh/neonbfdot.h +36 -68
- package/include/numkong/mesh/rvv.h +530 -435
- package/include/numkong/mesh/serial.h +75 -91
- package/include/numkong/mesh/skylake.h +1627 -302
- package/include/numkong/mesh/v128relaxed.h +443 -330
- package/include/numkong/mesh.h +63 -49
- package/include/numkong/mesh.hpp +4 -4
- package/include/numkong/numkong.h +3 -3
- package/include/numkong/numkong.hpp +1 -0
- package/include/numkong/probability/README.md +23 -19
- package/include/numkong/probability/neon.h +82 -52
- package/include/numkong/probability/rvv.h +28 -23
- package/include/numkong/probability/serial.h +51 -39
- package/include/numkong/probability.h +20 -23
- package/include/numkong/random.h +1 -1
- package/include/numkong/reduce/README.md +143 -138
- package/include/numkong/reduce/alder.h +81 -77
- package/include/numkong/reduce/haswell.h +222 -220
- package/include/numkong/reduce/neon.h +629 -519
- package/include/numkong/reduce/neonbfdot.h +7 -218
- package/include/numkong/reduce/neonfhm.h +9 -381
- package/include/numkong/reduce/neonsdot.h +9 -9
- package/include/numkong/reduce/rvv.h +928 -802
- package/include/numkong/reduce/serial.h +23 -27
- package/include/numkong/reduce/sierra.h +20 -20
- package/include/numkong/reduce/skylake.h +326 -324
- package/include/numkong/reduce/v128relaxed.h +52 -52
- package/include/numkong/reduce.h +4 -23
- package/include/numkong/reduce.hpp +156 -11
- package/include/numkong/scalar/README.md +6 -6
- package/include/numkong/scalar/haswell.h +26 -17
- package/include/numkong/scalar/loongsonasx.h +74 -0
- package/include/numkong/scalar/neon.h +9 -9
- package/include/numkong/scalar/powervsx.h +96 -0
- package/include/numkong/scalar/rvv.h +2 -2
- package/include/numkong/scalar/sapphire.h +21 -10
- package/include/numkong/scalar/serial.h +21 -21
- package/include/numkong/scalar.h +13 -0
- package/include/numkong/set/README.md +28 -28
- package/include/numkong/set/haswell.h +12 -12
- package/include/numkong/set/icelake.h +14 -14
- package/include/numkong/set/loongsonasx.h +181 -0
- package/include/numkong/set/neon.h +17 -18
- package/include/numkong/set/powervsx.h +326 -0
- package/include/numkong/set/rvv.h +4 -4
- package/include/numkong/set/serial.h +6 -6
- package/include/numkong/set/sve.h +60 -59
- package/include/numkong/set/v128relaxed.h +6 -6
- package/include/numkong/set.h +21 -7
- package/include/numkong/sets/README.md +26 -26
- package/include/numkong/sets/loongsonasx.h +52 -0
- package/include/numkong/sets/powervsx.h +65 -0
- package/include/numkong/sets/smebi32.h +395 -364
- package/include/numkong/sets.h +83 -40
- package/include/numkong/sparse/README.md +4 -4
- package/include/numkong/sparse/icelake.h +101 -101
- package/include/numkong/sparse/serial.h +1 -1
- package/include/numkong/sparse/sve2.h +137 -141
- package/include/numkong/sparse/turin.h +12 -12
- package/include/numkong/sparse.h +10 -10
- package/include/numkong/spatial/README.md +230 -226
- package/include/numkong/spatial/alder.h +113 -116
- package/include/numkong/spatial/diamond.h +240 -0
- package/include/numkong/spatial/genoa.h +0 -68
- package/include/numkong/spatial/haswell.h +74 -55
- package/include/numkong/spatial/icelake.h +539 -58
- package/include/numkong/spatial/loongsonasx.h +483 -0
- package/include/numkong/spatial/neon.h +125 -52
- package/include/numkong/spatial/neonbfdot.h +8 -9
- package/include/numkong/spatial/neonfp8.h +258 -0
- package/include/numkong/spatial/neonsdot.h +180 -12
- package/include/numkong/spatial/powervsx.h +738 -0
- package/include/numkong/spatial/rvv.h +146 -139
- package/include/numkong/spatial/rvvbf16.h +17 -12
- package/include/numkong/spatial/rvvhalf.h +13 -10
- package/include/numkong/spatial/serial.h +13 -12
- package/include/numkong/spatial/sierra.h +232 -39
- package/include/numkong/spatial/skylake.h +73 -74
- package/include/numkong/spatial/sve.h +93 -72
- package/include/numkong/spatial/svebfdot.h +29 -29
- package/include/numkong/spatial/svehalf.h +52 -26
- package/include/numkong/spatial/svesdot.h +142 -0
- package/include/numkong/spatial/v128relaxed.h +293 -41
- package/include/numkong/spatial.h +338 -82
- package/include/numkong/spatials/README.md +194 -194
- package/include/numkong/spatials/diamond.h +82 -0
- package/include/numkong/spatials/haswell.h +2 -2
- package/include/numkong/spatials/loongsonasx.h +153 -0
- package/include/numkong/spatials/neonfp8.h +111 -0
- package/include/numkong/spatials/neonsdot.h +34 -0
- package/include/numkong/spatials/powervsx.h +153 -0
- package/include/numkong/spatials/rvv.h +259 -243
- package/include/numkong/spatials/sapphireamx.h +173 -173
- package/include/numkong/spatials/serial.h +2 -2
- package/include/numkong/spatials/skylake.h +2 -2
- package/include/numkong/spatials/sme.h +590 -605
- package/include/numkong/spatials/smef64.h +139 -130
- package/include/numkong/spatials/v128relaxed.h +2 -2
- package/include/numkong/spatials.h +820 -500
- package/include/numkong/spatials.hpp +49 -48
- package/include/numkong/tensor.hpp +406 -17
- package/include/numkong/trigonometry/README.md +19 -19
- package/include/numkong/trigonometry/haswell.h +402 -401
- package/include/numkong/trigonometry/neon.h +386 -387
- package/include/numkong/trigonometry/rvv.h +52 -51
- package/include/numkong/trigonometry/serial.h +13 -13
- package/include/numkong/trigonometry/skylake.h +373 -369
- package/include/numkong/trigonometry/v128relaxed.h +375 -374
- package/include/numkong/trigonometry.h +13 -13
- package/include/numkong/trigonometry.hpp +2 -2
- package/include/numkong/types.h +287 -49
- package/include/numkong/types.hpp +436 -12
- package/include/numkong/vector.hpp +82 -14
- package/javascript/dist/cjs/numkong-wasm.js +6 -12
- package/javascript/dist/cjs/numkong.d.ts +7 -1
- package/javascript/dist/cjs/numkong.js +37 -11
- package/javascript/dist/cjs/types.d.ts +9 -0
- package/javascript/dist/cjs/types.js +96 -0
- package/javascript/dist/esm/numkong-browser.d.ts +14 -0
- package/javascript/dist/esm/numkong-browser.js +23 -0
- package/javascript/dist/esm/numkong-wasm.js +6 -12
- package/javascript/dist/esm/numkong.d.ts +7 -1
- package/javascript/dist/esm/numkong.js +37 -11
- package/javascript/dist/esm/types.d.ts +9 -0
- package/javascript/dist/esm/types.js +96 -0
- package/javascript/node-gyp-build.d.ts +4 -1
- package/javascript/numkong-browser.ts +40 -0
- package/javascript/numkong-wasm.ts +7 -13
- package/javascript/numkong.c +5 -26
- package/javascript/numkong.ts +36 -11
- package/javascript/tsconfig-base.json +1 -0
- package/javascript/tsconfig-cjs.json +6 -1
- package/javascript/types.ts +110 -0
- package/numkong.gypi +101 -0
- package/package.json +34 -13
- package/probes/arm_neon.c +8 -0
- package/probes/arm_neon_bfdot.c +9 -0
- package/probes/arm_neon_fhm.c +9 -0
- package/probes/arm_neon_half.c +8 -0
- package/probes/arm_neon_sdot.c +9 -0
- package/probes/arm_neonfp8.c +9 -0
- package/probes/arm_sme.c +16 -0
- package/probes/arm_sme2.c +16 -0
- package/probes/arm_sme2p1.c +16 -0
- package/probes/arm_sme_bf16.c +16 -0
- package/probes/arm_sme_bi32.c +16 -0
- package/probes/arm_sme_f64.c +16 -0
- package/probes/arm_sme_fa64.c +14 -0
- package/probes/arm_sme_half.c +16 -0
- package/probes/arm_sme_lut2.c +15 -0
- package/probes/arm_sve.c +18 -0
- package/probes/arm_sve2.c +20 -0
- package/probes/arm_sve2p1.c +18 -0
- package/probes/arm_sve_bfdot.c +20 -0
- package/probes/arm_sve_half.c +18 -0
- package/probes/arm_sve_sdot.c +21 -0
- package/probes/loongarch_lasx.c +12 -0
- package/probes/power_vsx.c +12 -0
- package/probes/probe.js +127 -0
- package/probes/riscv_rvv.c +14 -0
- package/probes/riscv_rvv_bb.c +15 -0
- package/probes/riscv_rvv_bf16.c +17 -0
- package/probes/riscv_rvv_half.c +14 -0
- package/probes/wasm_v128relaxed.c +11 -0
- package/probes/x86_alder.c +17 -0
- package/probes/x86_diamond.c +17 -0
- package/probes/x86_genoa.c +17 -0
- package/probes/x86_graniteamx.c +19 -0
- package/probes/x86_haswell.c +11 -0
- package/probes/x86_icelake.c +17 -0
- package/probes/x86_sapphire.c +16 -0
- package/probes/x86_sapphireamx.c +18 -0
- package/probes/x86_sierra.c +17 -0
- package/probes/x86_skylake.c +15 -0
- package/probes/x86_turin.c +17 -0
- package/wasm/numkong-emscripten.js +2 -0
- package/wasm/numkong.d.ts +14 -0
- package/wasm/numkong.js +1124 -0
- package/wasm/numkong.wasm +0 -0
- package/include/numkong/curved/neonhalf.h +0 -212
- package/include/numkong/dot/neonhalf.h +0 -198
- package/include/numkong/dots/neonhalf.h +0 -57
- package/include/numkong/mesh/neonhalf.h +0 -616
- package/include/numkong/reduce/neonhalf.h +0 -157
- package/include/numkong/spatial/neonhalf.h +0 -118
- package/include/numkong/spatial/sapphire.h +0 -343
- package/include/numkong/spatials/neonhalf.h +0 -58
- package/javascript/README.md +0 -246
|
@@ -34,8 +34,8 @@ extern "C" {
|
|
|
34
34
|
NK_PUBLIC void nk_sqeuclidean_f16_rvvhalf(nk_f16_t const *a_scalars, nk_f16_t const *b_scalars, nk_size_t count_scalars,
|
|
35
35
|
nk_f32_t *result) {
|
|
36
36
|
// Per-lane accumulator — deferred horizontal reduction
|
|
37
|
-
nk_size_t
|
|
38
|
-
vfloat32m2_t sum_f32m2 = __riscv_vfmv_v_f_f32m2(0.0f,
|
|
37
|
+
nk_size_t max_vector_length = __riscv_vsetvlmax_e32m2();
|
|
38
|
+
vfloat32m2_t sum_f32m2 = __riscv_vfmv_v_f_f32m2(0.0f, max_vector_length);
|
|
39
39
|
|
|
40
40
|
for (nk_size_t vector_length; count_scalars > 0;
|
|
41
41
|
count_scalars -= vector_length, a_scalars += vector_length, b_scalars += vector_length) {
|
|
@@ -54,7 +54,7 @@ NK_PUBLIC void nk_sqeuclidean_f16_rvvhalf(nk_f16_t const *a_scalars, nk_f16_t co
|
|
|
54
54
|
|
|
55
55
|
// Single horizontal reduction after the loop
|
|
56
56
|
vfloat32m1_t zero_f32m1 = __riscv_vfmv_v_f_f32m1(0.0f, 1);
|
|
57
|
-
*result = __riscv_vfmv_f_s_f32m1_f32(__riscv_vfredusum_vs_f32m2_f32m1(sum_f32m2, zero_f32m1,
|
|
57
|
+
*result = __riscv_vfmv_f_s_f32m1_f32(__riscv_vfredusum_vs_f32m2_f32m1(sum_f32m2, zero_f32m1, max_vector_length));
|
|
58
58
|
}
|
|
59
59
|
|
|
60
60
|
NK_PUBLIC void nk_euclidean_f16_rvvhalf(nk_f16_t const *a_scalars, nk_f16_t const *b_scalars, nk_size_t count_scalars,
|
|
@@ -66,10 +66,10 @@ NK_PUBLIC void nk_euclidean_f16_rvvhalf(nk_f16_t const *a_scalars, nk_f16_t cons
|
|
|
66
66
|
NK_PUBLIC void nk_angular_f16_rvvhalf(nk_f16_t const *a_scalars, nk_f16_t const *b_scalars, nk_size_t count_scalars,
|
|
67
67
|
nk_f32_t *result) {
|
|
68
68
|
// Per-lane accumulators — deferred horizontal reduction
|
|
69
|
-
nk_size_t
|
|
70
|
-
vfloat32m2_t dot_f32m2 = __riscv_vfmv_v_f_f32m2(0.0f,
|
|
71
|
-
vfloat32m2_t a_sq_f32m2 = __riscv_vfmv_v_f_f32m2(0.0f,
|
|
72
|
-
vfloat32m2_t b_sq_f32m2 = __riscv_vfmv_v_f_f32m2(0.0f,
|
|
69
|
+
nk_size_t max_vector_length = __riscv_vsetvlmax_e32m2();
|
|
70
|
+
vfloat32m2_t dot_f32m2 = __riscv_vfmv_v_f_f32m2(0.0f, max_vector_length);
|
|
71
|
+
vfloat32m2_t a_sq_f32m2 = __riscv_vfmv_v_f_f32m2(0.0f, max_vector_length);
|
|
72
|
+
vfloat32m2_t b_sq_f32m2 = __riscv_vfmv_v_f_f32m2(0.0f, max_vector_length);
|
|
73
73
|
|
|
74
74
|
for (nk_size_t vector_length; count_scalars > 0;
|
|
75
75
|
count_scalars -= vector_length, a_scalars += vector_length, b_scalars += vector_length) {
|
|
@@ -89,9 +89,12 @@ NK_PUBLIC void nk_angular_f16_rvvhalf(nk_f16_t const *a_scalars, nk_f16_t const
|
|
|
89
89
|
|
|
90
90
|
// Single horizontal reduction after the loop
|
|
91
91
|
vfloat32m1_t zero_f32m1 = __riscv_vfmv_v_f_f32m1(0.0f, 1);
|
|
92
|
-
nk_f32_t dot = __riscv_vfmv_f_s_f32m1_f32(
|
|
93
|
-
|
|
94
|
-
nk_f32_t
|
|
92
|
+
nk_f32_t dot = __riscv_vfmv_f_s_f32m1_f32(
|
|
93
|
+
__riscv_vfredusum_vs_f32m2_f32m1(dot_f32m2, zero_f32m1, max_vector_length));
|
|
94
|
+
nk_f32_t a_sq = __riscv_vfmv_f_s_f32m1_f32(
|
|
95
|
+
__riscv_vfredusum_vs_f32m2_f32m1(a_sq_f32m2, zero_f32m1, max_vector_length));
|
|
96
|
+
nk_f32_t b_sq = __riscv_vfmv_f_s_f32m1_f32(
|
|
97
|
+
__riscv_vfredusum_vs_f32m2_f32m1(b_sq_f32m2, zero_f32m1, max_vector_length));
|
|
95
98
|
|
|
96
99
|
// Normalize: 1 − dot / sqrt(‖a‖² × ‖b‖²)
|
|
97
100
|
if (a_sq == 0.0f && b_sq == 0.0f) { *result = 0.0f; }
|
|
@@ -40,11 +40,11 @@ extern "C" {
|
|
|
40
40
|
#define nk_define_sqeuclidean_(input_type, accumulator_type, output_type, load_and_convert) \
|
|
41
41
|
NK_PUBLIC void nk_sqeuclidean_##input_type##_serial(nk_##input_type##_t const *a, nk_##input_type##_t const *b, \
|
|
42
42
|
nk_size_t n, nk_##output_type##_t *result) { \
|
|
43
|
-
nk_##accumulator_type##_t sum = 0, compensation = 0,
|
|
43
|
+
nk_##accumulator_type##_t sum = 0, compensation = 0, a_value, b_value; \
|
|
44
44
|
for (nk_size_t i = 0; i != n; ++i) { \
|
|
45
|
-
load_and_convert(a + i, &
|
|
46
|
-
load_and_convert(b + i, &
|
|
47
|
-
nk_##accumulator_type##_t diff =
|
|
45
|
+
load_and_convert(a + i, &a_value); \
|
|
46
|
+
load_and_convert(b + i, &b_value); \
|
|
47
|
+
nk_##accumulator_type##_t diff = a_value - b_value; \
|
|
48
48
|
nk_##accumulator_type##_t term = diff * diff, t = sum + term; \
|
|
49
49
|
compensation += (nk_##accumulator_type##_abs_(sum) >= nk_##accumulator_type##_abs_(term)) \
|
|
50
50
|
? ((sum - t) + term) \
|
|
@@ -74,14 +74,14 @@ extern "C" {
|
|
|
74
74
|
#define nk_define_angular_(input_type, accumulator_type, output_type, load_and_convert, compute_rsqrt) \
|
|
75
75
|
NK_PUBLIC void nk_angular_##input_type##_serial(nk_##input_type##_t const *a, nk_##input_type##_t const *b, \
|
|
76
76
|
nk_size_t n, nk_##output_type##_t *result) { \
|
|
77
|
-
nk_##accumulator_type##_t dot_sum = 0, a_sum = 0, b_sum = 0,
|
|
77
|
+
nk_##accumulator_type##_t dot_sum = 0, a_sum = 0, b_sum = 0, a_value, b_value; \
|
|
78
78
|
nk_##accumulator_type##_t compensation_dot = 0, compensation_a = 0, compensation_b = 0; \
|
|
79
79
|
for (nk_size_t i = 0; i != n; ++i) { \
|
|
80
|
-
load_and_convert(a + i, &
|
|
81
|
-
load_and_convert(b + i, &
|
|
82
|
-
nk_##accumulator_type##_t term_dot =
|
|
83
|
-
nk_##accumulator_type##_t term_a =
|
|
84
|
-
nk_##accumulator_type##_t term_b =
|
|
80
|
+
load_and_convert(a + i, &a_value); \
|
|
81
|
+
load_and_convert(b + i, &b_value); \
|
|
82
|
+
nk_##accumulator_type##_t term_dot = a_value * b_value, t_dot = dot_sum + term_dot; \
|
|
83
|
+
nk_##accumulator_type##_t term_a = a_value * a_value, t_a = a_sum + term_a; \
|
|
84
|
+
nk_##accumulator_type##_t term_b = b_value * b_value, t_b = b_sum + term_b; \
|
|
85
85
|
compensation_dot += (nk_##accumulator_type##_abs_(dot_sum) >= nk_##accumulator_type##_abs_(term_dot)) \
|
|
86
86
|
? ((dot_sum - t_dot) + term_dot) \
|
|
87
87
|
: ((term_dot - t_dot) + dot_sum); \
|
|
@@ -101,8 +101,9 @@ extern "C" {
|
|
|
101
101
|
if (a_norm_sq == 0 && b_norm_sq == 0) { *result = 0; } \
|
|
102
102
|
else if (dot_product == 0) { *result = 1; } \
|
|
103
103
|
else { \
|
|
104
|
-
nk_##output_type##_t unclipped_distance =
|
|
105
|
-
|
|
104
|
+
nk_##output_type##_t unclipped_distance = (nk_##output_type##_t)( \
|
|
105
|
+
1 - (nk_##output_type##_t)dot_product * compute_rsqrt((nk_##output_type##_t)a_norm_sq) * \
|
|
106
|
+
compute_rsqrt((nk_##output_type##_t)b_norm_sq)); \
|
|
106
107
|
*result = unclipped_distance > 0 ? unclipped_distance : 0; \
|
|
107
108
|
} \
|
|
108
109
|
}
|
|
@@ -8,12 +8,12 @@
|
|
|
8
8
|
*
|
|
9
9
|
* @section spatial_sierra_instructions AVXVNNIINT8 Instructions Performance
|
|
10
10
|
*
|
|
11
|
-
* Intrinsic
|
|
12
|
-
* _mm256_dpbssds_epi32
|
|
13
|
-
* _mm256_dpbssd_epi32
|
|
14
|
-
* _mm256_dpbuud_epi32
|
|
15
|
-
* _mm_rsqrt_ps
|
|
16
|
-
* _mm_sqrt_ss
|
|
11
|
+
* Intrinsic Instruction Sierra Forest
|
|
12
|
+
* _mm256_dpbssds_epi32 VPDPBSSDS (YMM, YMM, YMM) 4cy @ p05
|
|
13
|
+
* _mm256_dpbssd_epi32 VPDPBSSD (YMM, YMM, YMM) 4cy @ p05
|
|
14
|
+
* _mm256_dpbuud_epi32 VPDPBUUD (YMM, YMM, YMM) 4cy @ p05
|
|
15
|
+
* _mm_rsqrt_ps VRSQRTPS (XMM, XMM) 5cy @ p0
|
|
16
|
+
* _mm_sqrt_ss VSQRTSS (XMM, XMM, XMM) 12cy @ p0
|
|
17
17
|
*
|
|
18
18
|
* Sierra Forest (AVXVNNIINT8) provides native signed x signed and unsigned x unsigned
|
|
19
19
|
* dot products, eliminating the need for algebraic corrections required on Alder Lake.
|
|
@@ -67,7 +67,8 @@ NK_PUBLIC void nk_angular_i8_sierra(nk_i8_t const *a, nk_i8_t const *b, nk_size_
|
|
|
67
67
|
b_norm_sq_i32 += b_element_i32 * b_element_i32;
|
|
68
68
|
}
|
|
69
69
|
|
|
70
|
-
*result = nk_angular_normalize_f32_haswell_(dot_product_i32, a_norm_sq_i32,
|
|
70
|
+
*result = nk_angular_normalize_f32_haswell_((nk_f32_t)dot_product_i32, (nk_f32_t)a_norm_sq_i32,
|
|
71
|
+
(nk_f32_t)b_norm_sq_i32);
|
|
71
72
|
}
|
|
72
73
|
|
|
73
74
|
NK_PUBLIC void nk_sqeuclidean_i8_sierra(nk_i8_t const *a, nk_i8_t const *b, nk_size_t n, nk_u32_t *result) {
|
|
@@ -132,7 +133,8 @@ NK_PUBLIC void nk_angular_u8_sierra(nk_u8_t const *a, nk_u8_t const *b, nk_size_
|
|
|
132
133
|
b_norm_sq_i32 += b_element_i32 * b_element_i32;
|
|
133
134
|
}
|
|
134
135
|
|
|
135
|
-
*result = nk_angular_normalize_f32_haswell_(dot_product_i32, a_norm_sq_i32,
|
|
136
|
+
*result = nk_angular_normalize_f32_haswell_((nk_f32_t)dot_product_i32, (nk_f32_t)a_norm_sq_i32,
|
|
137
|
+
(nk_f32_t)b_norm_sq_i32);
|
|
136
138
|
}
|
|
137
139
|
|
|
138
140
|
NK_PUBLIC void nk_sqeuclidean_u8_sierra(nk_u8_t const *a, nk_u8_t const *b, nk_size_t n, nk_u32_t *result) {
|
|
@@ -177,15 +179,15 @@ NK_PUBLIC void nk_angular_e2m3_sierra(nk_e2m3_t const *a_scalars, nk_e2m3_t cons
|
|
|
177
179
|
// Every e2m3 value × 16 is an exact integer in [-120, +120].
|
|
178
180
|
// DPBSSD(signed, signed) eliminates the need for unsigned conversion tricks.
|
|
179
181
|
//
|
|
180
|
-
__m256i const
|
|
181
|
-
|
|
182
|
-
__m256i const
|
|
183
|
-
|
|
182
|
+
__m256i const lut_low_u8x32 = _mm256_set_epi8(30, 28, 26, 24, 22, 20, 18, 16, 14, 12, 10, 8, 6, 4, 2, 0, 30, 28, 26,
|
|
183
|
+
24, 22, 20, 18, 16, 14, 12, 10, 8, 6, 4, 2, 0);
|
|
184
|
+
__m256i const lut_high_u8x32 = _mm256_set_epi8(120, 112, 104, 96, 88, 80, 72, 64, 60, 56, 52, 48, 44, 40, 36, 32,
|
|
185
|
+
120, 112, 104, 96, 88, 80, 72, 64, 60, 56, 52, 48, 44, 40, 36, 32);
|
|
184
186
|
__m256i const nibble_mask_u8x32 = _mm256_set1_epi8(0x0F);
|
|
185
187
|
__m256i const magnitude_mask_u8x32 = _mm256_set1_epi8(0x1F);
|
|
186
188
|
__m256i const half_select_u8x32 = _mm256_set1_epi8(0x10);
|
|
187
189
|
__m256i const sign_mask_u8x32 = _mm256_set1_epi8(0x20);
|
|
188
|
-
__m256i
|
|
190
|
+
__m256i ab_i32x8 = _mm256_setzero_si256();
|
|
189
191
|
__m256i a_norm_i32x8 = _mm256_setzero_si256();
|
|
190
192
|
__m256i b_norm_i32x8 = _mm256_setzero_si256();
|
|
191
193
|
__m256i a_e2m3_u8x32, b_e2m3_u8x32;
|
|
@@ -207,35 +209,39 @@ nk_angular_e2m3_sierra_cycle:
|
|
|
207
209
|
|
|
208
210
|
// Decode a: extract magnitude, dual-VPSHUFB LUT, apply sign
|
|
209
211
|
__m256i a_magnitude_u8x32 = _mm256_and_si256(a_e2m3_u8x32, magnitude_mask_u8x32);
|
|
210
|
-
__m256i
|
|
211
|
-
__m256i
|
|
212
|
-
|
|
213
|
-
|
|
212
|
+
__m256i a_shuffle_index_u8x32 = _mm256_and_si256(a_magnitude_u8x32, nibble_mask_u8x32);
|
|
213
|
+
__m256i a_high_select_u8x32 = _mm256_cmpeq_epi8(_mm256_and_si256(a_magnitude_u8x32, half_select_u8x32),
|
|
214
|
+
half_select_u8x32);
|
|
215
|
+
__m256i a_unsigned_u8x32 = _mm256_blendv_epi8(_mm256_shuffle_epi8(lut_low_u8x32, a_shuffle_index_u8x32),
|
|
216
|
+
_mm256_shuffle_epi8(lut_high_u8x32, a_shuffle_index_u8x32),
|
|
217
|
+
a_high_select_u8x32);
|
|
214
218
|
__m256i a_negate = _mm256_cmpeq_epi8(_mm256_and_si256(a_e2m3_u8x32, sign_mask_u8x32), sign_mask_u8x32);
|
|
215
219
|
__m256i a_signed_i8x32 = _mm256_blendv_epi8(a_unsigned_u8x32,
|
|
216
220
|
_mm256_sub_epi8(_mm256_setzero_si256(), a_unsigned_u8x32), a_negate);
|
|
217
221
|
|
|
218
222
|
// Decode b: same LUT decode + sign
|
|
219
223
|
__m256i b_magnitude_u8x32 = _mm256_and_si256(b_e2m3_u8x32, magnitude_mask_u8x32);
|
|
220
|
-
__m256i
|
|
221
|
-
__m256i
|
|
222
|
-
|
|
223
|
-
|
|
224
|
+
__m256i b_shuffle_index_u8x32 = _mm256_and_si256(b_magnitude_u8x32, nibble_mask_u8x32);
|
|
225
|
+
__m256i b_high_select_u8x32 = _mm256_cmpeq_epi8(_mm256_and_si256(b_magnitude_u8x32, half_select_u8x32),
|
|
226
|
+
half_select_u8x32);
|
|
227
|
+
__m256i b_unsigned_u8x32 = _mm256_blendv_epi8(_mm256_shuffle_epi8(lut_low_u8x32, b_shuffle_index_u8x32),
|
|
228
|
+
_mm256_shuffle_epi8(lut_high_u8x32, b_shuffle_index_u8x32),
|
|
229
|
+
b_high_select_u8x32);
|
|
224
230
|
__m256i b_negate = _mm256_cmpeq_epi8(_mm256_and_si256(b_e2m3_u8x32, sign_mask_u8x32), sign_mask_u8x32);
|
|
225
231
|
__m256i b_signed_i8x32 = _mm256_blendv_epi8(b_unsigned_u8x32,
|
|
226
232
|
_mm256_sub_epi8(_mm256_setzero_si256(), b_unsigned_u8x32), b_negate);
|
|
227
233
|
|
|
228
234
|
// VPDPBSSD: signed × signed → i32
|
|
229
|
-
|
|
235
|
+
ab_i32x8 = _mm256_dpbssd_epi32(ab_i32x8, a_signed_i8x32, b_signed_i8x32);
|
|
230
236
|
a_norm_i32x8 = _mm256_dpbssd_epi32(a_norm_i32x8, a_signed_i8x32, a_signed_i8x32);
|
|
231
237
|
b_norm_i32x8 = _mm256_dpbssd_epi32(b_norm_i32x8, b_signed_i8x32, b_signed_i8x32);
|
|
232
238
|
|
|
233
239
|
if (count_scalars) goto nk_angular_e2m3_sierra_cycle;
|
|
234
240
|
|
|
235
|
-
nk_i32_t dot_i32 = nk_reduce_add_i32x8_haswell_(
|
|
241
|
+
nk_i32_t dot_i32 = nk_reduce_add_i32x8_haswell_(ab_i32x8);
|
|
236
242
|
nk_i32_t a_norm_i32 = nk_reduce_add_i32x8_haswell_(a_norm_i32x8);
|
|
237
243
|
nk_i32_t b_norm_i32 = nk_reduce_add_i32x8_haswell_(b_norm_i32x8);
|
|
238
|
-
*result = nk_angular_normalize_f32_haswell_(dot_i32, a_norm_i32, b_norm_i32);
|
|
244
|
+
*result = nk_angular_normalize_f32_haswell_((nk_f32_t)dot_i32, (nk_f32_t)a_norm_i32, (nk_f32_t)b_norm_i32);
|
|
239
245
|
}
|
|
240
246
|
|
|
241
247
|
NK_PUBLIC void nk_sqeuclidean_e2m3_sierra(nk_e2m3_t const *a_scalars, nk_e2m3_t const *b_scalars,
|
|
@@ -243,15 +249,15 @@ NK_PUBLIC void nk_sqeuclidean_e2m3_sierra(nk_e2m3_t const *a_scalars, nk_e2m3_t
|
|
|
243
249
|
// Squared Euclidean distance for e2m3 using norm decomposition + VPDPBSSD.
|
|
244
250
|
// ||a-b||^2 = ||a||^2 + ||b||^2 - 2*dot(a,b)
|
|
245
251
|
//
|
|
246
|
-
__m256i const
|
|
247
|
-
|
|
248
|
-
__m256i const
|
|
249
|
-
|
|
252
|
+
__m256i const lut_low_u8x32 = _mm256_set_epi8(30, 28, 26, 24, 22, 20, 18, 16, 14, 12, 10, 8, 6, 4, 2, 0, 30, 28, 26,
|
|
253
|
+
24, 22, 20, 18, 16, 14, 12, 10, 8, 6, 4, 2, 0);
|
|
254
|
+
__m256i const lut_high_u8x32 = _mm256_set_epi8(120, 112, 104, 96, 88, 80, 72, 64, 60, 56, 52, 48, 44, 40, 36, 32,
|
|
255
|
+
120, 112, 104, 96, 88, 80, 72, 64, 60, 56, 52, 48, 44, 40, 36, 32);
|
|
250
256
|
__m256i const nibble_mask_u8x32 = _mm256_set1_epi8(0x0F);
|
|
251
257
|
__m256i const magnitude_mask_u8x32 = _mm256_set1_epi8(0x1F);
|
|
252
258
|
__m256i const half_select_u8x32 = _mm256_set1_epi8(0x10);
|
|
253
259
|
__m256i const sign_mask_u8x32 = _mm256_set1_epi8(0x20);
|
|
254
|
-
__m256i
|
|
260
|
+
__m256i ab_i32x8 = _mm256_setzero_si256();
|
|
255
261
|
__m256i a_norm_i32x8 = _mm256_setzero_si256();
|
|
256
262
|
__m256i b_norm_i32x8 = _mm256_setzero_si256();
|
|
257
263
|
__m256i a_e2m3_u8x32, b_e2m3_u8x32;
|
|
@@ -273,31 +279,35 @@ nk_sqeuclidean_e2m3_sierra_cycle:
|
|
|
273
279
|
|
|
274
280
|
// Decode a
|
|
275
281
|
__m256i a_magnitude_u8x32 = _mm256_and_si256(a_e2m3_u8x32, magnitude_mask_u8x32);
|
|
276
|
-
__m256i
|
|
277
|
-
__m256i
|
|
278
|
-
|
|
279
|
-
|
|
282
|
+
__m256i a_shuffle_index_u8x32 = _mm256_and_si256(a_magnitude_u8x32, nibble_mask_u8x32);
|
|
283
|
+
__m256i a_high_select_u8x32 = _mm256_cmpeq_epi8(_mm256_and_si256(a_magnitude_u8x32, half_select_u8x32),
|
|
284
|
+
half_select_u8x32);
|
|
285
|
+
__m256i a_unsigned_u8x32 = _mm256_blendv_epi8(_mm256_shuffle_epi8(lut_low_u8x32, a_shuffle_index_u8x32),
|
|
286
|
+
_mm256_shuffle_epi8(lut_high_u8x32, a_shuffle_index_u8x32),
|
|
287
|
+
a_high_select_u8x32);
|
|
280
288
|
__m256i a_negate = _mm256_cmpeq_epi8(_mm256_and_si256(a_e2m3_u8x32, sign_mask_u8x32), sign_mask_u8x32);
|
|
281
289
|
__m256i a_signed_i8x32 = _mm256_blendv_epi8(a_unsigned_u8x32,
|
|
282
290
|
_mm256_sub_epi8(_mm256_setzero_si256(), a_unsigned_u8x32), a_negate);
|
|
283
291
|
|
|
284
292
|
// Decode b
|
|
285
293
|
__m256i b_magnitude_u8x32 = _mm256_and_si256(b_e2m3_u8x32, magnitude_mask_u8x32);
|
|
286
|
-
__m256i
|
|
287
|
-
__m256i
|
|
288
|
-
|
|
289
|
-
|
|
294
|
+
__m256i b_shuffle_index_u8x32 = _mm256_and_si256(b_magnitude_u8x32, nibble_mask_u8x32);
|
|
295
|
+
__m256i b_high_select_u8x32 = _mm256_cmpeq_epi8(_mm256_and_si256(b_magnitude_u8x32, half_select_u8x32),
|
|
296
|
+
half_select_u8x32);
|
|
297
|
+
__m256i b_unsigned_u8x32 = _mm256_blendv_epi8(_mm256_shuffle_epi8(lut_low_u8x32, b_shuffle_index_u8x32),
|
|
298
|
+
_mm256_shuffle_epi8(lut_high_u8x32, b_shuffle_index_u8x32),
|
|
299
|
+
b_high_select_u8x32);
|
|
290
300
|
__m256i b_negate = _mm256_cmpeq_epi8(_mm256_and_si256(b_e2m3_u8x32, sign_mask_u8x32), sign_mask_u8x32);
|
|
291
301
|
__m256i b_signed_i8x32 = _mm256_blendv_epi8(b_unsigned_u8x32,
|
|
292
302
|
_mm256_sub_epi8(_mm256_setzero_si256(), b_unsigned_u8x32), b_negate);
|
|
293
303
|
|
|
294
|
-
|
|
304
|
+
ab_i32x8 = _mm256_dpbssd_epi32(ab_i32x8, a_signed_i8x32, b_signed_i8x32);
|
|
295
305
|
a_norm_i32x8 = _mm256_dpbssd_epi32(a_norm_i32x8, a_signed_i8x32, a_signed_i8x32);
|
|
296
306
|
b_norm_i32x8 = _mm256_dpbssd_epi32(b_norm_i32x8, b_signed_i8x32, b_signed_i8x32);
|
|
297
307
|
|
|
298
308
|
if (count_scalars) goto nk_sqeuclidean_e2m3_sierra_cycle;
|
|
299
309
|
|
|
300
|
-
nk_i32_t dot_i32 = nk_reduce_add_i32x8_haswell_(
|
|
310
|
+
nk_i32_t dot_i32 = nk_reduce_add_i32x8_haswell_(ab_i32x8);
|
|
301
311
|
nk_i32_t a_norm_i32 = nk_reduce_add_i32x8_haswell_(a_norm_i32x8);
|
|
302
312
|
nk_i32_t b_norm_i32 = nk_reduce_add_i32x8_haswell_(b_norm_i32x8);
|
|
303
313
|
*result = (nk_f32_t)(a_norm_i32 + b_norm_i32 - 2 * dot_i32) / 256.0f;
|
|
@@ -308,6 +318,189 @@ NK_PUBLIC void nk_euclidean_e2m3_sierra(nk_e2m3_t const *a, nk_e2m3_t const *b,
|
|
|
308
318
|
*result = nk_f32_sqrt_haswell(*result);
|
|
309
319
|
}
|
|
310
320
|
|
|
321
|
+
NK_PUBLIC void nk_sqeuclidean_e3m2_sierra(nk_e3m2_t const *a_scalars, nk_e3m2_t const *b_scalars,
|
|
322
|
+
nk_size_t count_scalars, nk_f32_t *result) {
|
|
323
|
+
// E3M2 squared Euclidean distance via direct difference squaring.
|
|
324
|
+
__m256i const lut_low_byte_first_u8x32 = _mm256_set_epi8( //
|
|
325
|
+
28, 24, 20, 16, 14, 12, 10, 8, 7, 6, 5, 4, 3, 2, 1, 0, //
|
|
326
|
+
28, 24, 20, 16, 14, 12, 10, 8, 7, 6, 5, 4, 3, 2, 1, 0);
|
|
327
|
+
__m256i const lut_low_byte_second_u8x32 = _mm256_set_epi8( //
|
|
328
|
+
(char)192, (char)128, 64, 0, (char)224, (char)192, (char)160, (char)128, 112, 96, 80, 64, 56, 48, 40, 32, //
|
|
329
|
+
(char)192, (char)128, 64, 0, (char)224, (char)192, (char)160, (char)128, 112, 96, 80, 64, 56, 48, 40, 32);
|
|
330
|
+
__m256i const nibble_mask_u8x32 = _mm256_set1_epi8(0x0F);
|
|
331
|
+
__m256i const magnitude_mask_u8x32 = _mm256_set1_epi8(0x1F);
|
|
332
|
+
__m256i const half_select_u8x32 = _mm256_set1_epi8(0x10);
|
|
333
|
+
__m256i const sign_mask_u8x32 = _mm256_set1_epi8(0x20);
|
|
334
|
+
__m256i const high_threshold_u8x32 = _mm256_set1_epi8(27);
|
|
335
|
+
__m256i const ones_u8x32 = _mm256_set1_epi8(1);
|
|
336
|
+
__m256i const ones_i16x16 = _mm256_set1_epi16(1);
|
|
337
|
+
__m256i sum_i32x8 = _mm256_setzero_si256();
|
|
338
|
+
__m256i a_e3m2_u8x32, b_e3m2_u8x32;
|
|
339
|
+
|
|
340
|
+
nk_sqeuclidean_e3m2_sierra_cycle:
|
|
341
|
+
if (count_scalars < 32) {
|
|
342
|
+
nk_b256_vec_t a_vec, b_vec;
|
|
343
|
+
nk_partial_load_b8x32_serial_(a_scalars, &a_vec, count_scalars);
|
|
344
|
+
nk_partial_load_b8x32_serial_(b_scalars, &b_vec, count_scalars);
|
|
345
|
+
a_e3m2_u8x32 = a_vec.ymm;
|
|
346
|
+
b_e3m2_u8x32 = b_vec.ymm;
|
|
347
|
+
count_scalars = 0;
|
|
348
|
+
}
|
|
349
|
+
else {
|
|
350
|
+
a_e3m2_u8x32 = _mm256_loadu_si256((__m256i const *)a_scalars);
|
|
351
|
+
b_e3m2_u8x32 = _mm256_loadu_si256((__m256i const *)b_scalars);
|
|
352
|
+
a_scalars += 32, b_scalars += 32, count_scalars -= 32;
|
|
353
|
+
}
|
|
354
|
+
|
|
355
|
+
// Decode both to unsigned i16 via dual-VPSHUFB + interleave
|
|
356
|
+
__m256i a_magnitude_u8x32 = _mm256_and_si256(a_e3m2_u8x32, magnitude_mask_u8x32);
|
|
357
|
+
__m256i b_magnitude_u8x32 = _mm256_and_si256(b_e3m2_u8x32, magnitude_mask_u8x32);
|
|
358
|
+
__m256i a_shuffle_index_u8x32 = _mm256_and_si256(a_magnitude_u8x32, nibble_mask_u8x32);
|
|
359
|
+
__m256i b_shuffle_index_u8x32 = _mm256_and_si256(b_magnitude_u8x32, nibble_mask_u8x32);
|
|
360
|
+
__m256i a_high_select_u8x32 = _mm256_cmpeq_epi8(_mm256_and_si256(a_magnitude_u8x32, half_select_u8x32),
|
|
361
|
+
half_select_u8x32);
|
|
362
|
+
__m256i b_high_select_u8x32 = _mm256_cmpeq_epi8(_mm256_and_si256(b_magnitude_u8x32, half_select_u8x32),
|
|
363
|
+
half_select_u8x32);
|
|
364
|
+
__m256i a_low_bytes_u8x32 = _mm256_blendv_epi8(
|
|
365
|
+
_mm256_shuffle_epi8(lut_low_byte_first_u8x32, a_shuffle_index_u8x32),
|
|
366
|
+
_mm256_shuffle_epi8(lut_low_byte_second_u8x32, a_shuffle_index_u8x32), a_high_select_u8x32);
|
|
367
|
+
__m256i b_low_bytes_u8x32 = _mm256_blendv_epi8(
|
|
368
|
+
_mm256_shuffle_epi8(lut_low_byte_first_u8x32, b_shuffle_index_u8x32),
|
|
369
|
+
_mm256_shuffle_epi8(lut_low_byte_second_u8x32, b_shuffle_index_u8x32), b_high_select_u8x32);
|
|
370
|
+
__m256i a_high_bytes_u8x32 = _mm256_and_si256(_mm256_cmpgt_epi8(a_magnitude_u8x32, high_threshold_u8x32),
|
|
371
|
+
ones_u8x32);
|
|
372
|
+
__m256i b_high_bytes_u8x32 = _mm256_and_si256(_mm256_cmpgt_epi8(b_magnitude_u8x32, high_threshold_u8x32),
|
|
373
|
+
ones_u8x32);
|
|
374
|
+
|
|
375
|
+
// Interleave to i16 and apply signs
|
|
376
|
+
__m256i a_low_i16x16 = _mm256_unpacklo_epi8(a_low_bytes_u8x32, a_high_bytes_u8x32);
|
|
377
|
+
__m256i a_high_i16x16 = _mm256_unpackhi_epi8(a_low_bytes_u8x32, a_high_bytes_u8x32);
|
|
378
|
+
__m256i b_low_i16x16 = _mm256_unpacklo_epi8(b_low_bytes_u8x32, b_high_bytes_u8x32);
|
|
379
|
+
__m256i b_high_i16x16 = _mm256_unpackhi_epi8(b_low_bytes_u8x32, b_high_bytes_u8x32);
|
|
380
|
+
|
|
381
|
+
__m256i a_negative_mask_u8x32 = _mm256_cmpeq_epi8(_mm256_and_si256(a_e3m2_u8x32, sign_mask_u8x32), sign_mask_u8x32);
|
|
382
|
+
__m256i b_negative_mask_u8x32 = _mm256_cmpeq_epi8(_mm256_and_si256(b_e3m2_u8x32, sign_mask_u8x32), sign_mask_u8x32);
|
|
383
|
+
__m256i a_sign_low_i16x16 = _mm256_or_si256(_mm256_unpacklo_epi8(a_negative_mask_u8x32, a_negative_mask_u8x32),
|
|
384
|
+
ones_i16x16);
|
|
385
|
+
__m256i a_sign_high_i16x16 = _mm256_or_si256(_mm256_unpackhi_epi8(a_negative_mask_u8x32, a_negative_mask_u8x32),
|
|
386
|
+
ones_i16x16);
|
|
387
|
+
__m256i b_sign_low_i16x16 = _mm256_or_si256(_mm256_unpacklo_epi8(b_negative_mask_u8x32, b_negative_mask_u8x32),
|
|
388
|
+
ones_i16x16);
|
|
389
|
+
__m256i b_sign_high_i16x16 = _mm256_or_si256(_mm256_unpackhi_epi8(b_negative_mask_u8x32, b_negative_mask_u8x32),
|
|
390
|
+
ones_i16x16);
|
|
391
|
+
__m256i a_signed_low_i16x16 = _mm256_sign_epi16(a_low_i16x16, a_sign_low_i16x16);
|
|
392
|
+
__m256i a_signed_high_i16x16 = _mm256_sign_epi16(a_high_i16x16, a_sign_high_i16x16);
|
|
393
|
+
__m256i b_signed_low_i16x16 = _mm256_sign_epi16(b_low_i16x16, b_sign_low_i16x16);
|
|
394
|
+
__m256i b_signed_high_i16x16 = _mm256_sign_epi16(b_high_i16x16, b_sign_high_i16x16);
|
|
395
|
+
|
|
396
|
+
// Direct difference squaring: (a-b)² via VPMADDWD
|
|
397
|
+
__m256i diff_low_i16x16 = _mm256_sub_epi16(a_signed_low_i16x16, b_signed_low_i16x16);
|
|
398
|
+
__m256i diff_high_i16x16 = _mm256_sub_epi16(a_signed_high_i16x16, b_signed_high_i16x16);
|
|
399
|
+
sum_i32x8 = _mm256_add_epi32(sum_i32x8, _mm256_madd_epi16(diff_low_i16x16, diff_low_i16x16));
|
|
400
|
+
sum_i32x8 = _mm256_add_epi32(sum_i32x8, _mm256_madd_epi16(diff_high_i16x16, diff_high_i16x16));
|
|
401
|
+
|
|
402
|
+
if (count_scalars) goto nk_sqeuclidean_e3m2_sierra_cycle;
|
|
403
|
+
*result = (nk_f32_t)nk_reduce_add_i32x8_haswell_(sum_i32x8) / 256.0f;
|
|
404
|
+
}
|
|
405
|
+
|
|
406
|
+
NK_PUBLIC void nk_euclidean_e3m2_sierra(nk_e3m2_t const *a, nk_e3m2_t const *b, nk_size_t n, nk_f32_t *result) {
|
|
407
|
+
nk_sqeuclidean_e3m2_sierra(a, b, n, result);
|
|
408
|
+
*result = nk_f32_sqrt_haswell(*result);
|
|
409
|
+
}
|
|
410
|
+
|
|
411
|
+
NK_PUBLIC void nk_angular_e3m2_sierra(nk_e3m2_t const *a_scalars, nk_e3m2_t const *b_scalars, nk_size_t count_scalars,
|
|
412
|
+
nk_f32_t *result) {
|
|
413
|
+
// E3M2 angular distance via VPMADDWD integer MAC.
|
|
414
|
+
__m256i const lut_low_byte_first_u8x32 = _mm256_set_epi8( //
|
|
415
|
+
28, 24, 20, 16, 14, 12, 10, 8, 7, 6, 5, 4, 3, 2, 1, 0, //
|
|
416
|
+
28, 24, 20, 16, 14, 12, 10, 8, 7, 6, 5, 4, 3, 2, 1, 0);
|
|
417
|
+
__m256i const lut_low_byte_second_u8x32 = _mm256_set_epi8( //
|
|
418
|
+
(char)192, (char)128, 64, 0, (char)224, (char)192, (char)160, (char)128, 112, 96, 80, 64, 56, 48, 40, 32, //
|
|
419
|
+
(char)192, (char)128, 64, 0, (char)224, (char)192, (char)160, (char)128, 112, 96, 80, 64, 56, 48, 40, 32);
|
|
420
|
+
__m256i const nibble_mask_u8x32 = _mm256_set1_epi8(0x0F);
|
|
421
|
+
__m256i const magnitude_mask_u8x32 = _mm256_set1_epi8(0x1F);
|
|
422
|
+
__m256i const half_select_u8x32 = _mm256_set1_epi8(0x10);
|
|
423
|
+
__m256i const sign_mask_u8x32 = _mm256_set1_epi8(0x20);
|
|
424
|
+
__m256i const high_threshold_u8x32 = _mm256_set1_epi8(27);
|
|
425
|
+
__m256i const ones_u8x32 = _mm256_set1_epi8(1);
|
|
426
|
+
__m256i const ones_i16x16 = _mm256_set1_epi16(1);
|
|
427
|
+
__m256i ab_i32x8 = _mm256_setzero_si256();
|
|
428
|
+
__m256i a_norm_i32x8 = _mm256_setzero_si256();
|
|
429
|
+
__m256i b_norm_i32x8 = _mm256_setzero_si256();
|
|
430
|
+
__m256i a_e3m2_u8x32, b_e3m2_u8x32;
|
|
431
|
+
|
|
432
|
+
nk_angular_e3m2_sierra_cycle:
|
|
433
|
+
if (count_scalars < 32) {
|
|
434
|
+
nk_b256_vec_t a_vec, b_vec;
|
|
435
|
+
nk_partial_load_b8x32_serial_(a_scalars, &a_vec, count_scalars);
|
|
436
|
+
nk_partial_load_b8x32_serial_(b_scalars, &b_vec, count_scalars);
|
|
437
|
+
a_e3m2_u8x32 = a_vec.ymm;
|
|
438
|
+
b_e3m2_u8x32 = b_vec.ymm;
|
|
439
|
+
count_scalars = 0;
|
|
440
|
+
}
|
|
441
|
+
else {
|
|
442
|
+
a_e3m2_u8x32 = _mm256_loadu_si256((__m256i const *)a_scalars);
|
|
443
|
+
b_e3m2_u8x32 = _mm256_loadu_si256((__m256i const *)b_scalars);
|
|
444
|
+
a_scalars += 32, b_scalars += 32, count_scalars -= 32;
|
|
445
|
+
}
|
|
446
|
+
|
|
447
|
+
// Decode both to unsigned i16
|
|
448
|
+
__m256i a_magnitude_u8x32 = _mm256_and_si256(a_e3m2_u8x32, magnitude_mask_u8x32);
|
|
449
|
+
__m256i b_magnitude_u8x32 = _mm256_and_si256(b_e3m2_u8x32, magnitude_mask_u8x32);
|
|
450
|
+
__m256i a_shuffle_index_u8x32 = _mm256_and_si256(a_magnitude_u8x32, nibble_mask_u8x32);
|
|
451
|
+
__m256i b_shuffle_index_u8x32 = _mm256_and_si256(b_magnitude_u8x32, nibble_mask_u8x32);
|
|
452
|
+
__m256i a_high_select_u8x32 = _mm256_cmpeq_epi8(_mm256_and_si256(a_magnitude_u8x32, half_select_u8x32),
|
|
453
|
+
half_select_u8x32);
|
|
454
|
+
__m256i b_high_select_u8x32 = _mm256_cmpeq_epi8(_mm256_and_si256(b_magnitude_u8x32, half_select_u8x32),
|
|
455
|
+
half_select_u8x32);
|
|
456
|
+
__m256i a_low_bytes_u8x32 = _mm256_blendv_epi8(
|
|
457
|
+
_mm256_shuffle_epi8(lut_low_byte_first_u8x32, a_shuffle_index_u8x32),
|
|
458
|
+
_mm256_shuffle_epi8(lut_low_byte_second_u8x32, a_shuffle_index_u8x32), a_high_select_u8x32);
|
|
459
|
+
__m256i b_low_bytes_u8x32 = _mm256_blendv_epi8(
|
|
460
|
+
_mm256_shuffle_epi8(lut_low_byte_first_u8x32, b_shuffle_index_u8x32),
|
|
461
|
+
_mm256_shuffle_epi8(lut_low_byte_second_u8x32, b_shuffle_index_u8x32), b_high_select_u8x32);
|
|
462
|
+
__m256i a_high_bytes_u8x32 = _mm256_and_si256(_mm256_cmpgt_epi8(a_magnitude_u8x32, high_threshold_u8x32),
|
|
463
|
+
ones_u8x32);
|
|
464
|
+
__m256i b_high_bytes_u8x32 = _mm256_and_si256(_mm256_cmpgt_epi8(b_magnitude_u8x32, high_threshold_u8x32),
|
|
465
|
+
ones_u8x32);
|
|
466
|
+
__m256i a_low_i16x16 = _mm256_unpacklo_epi8(a_low_bytes_u8x32, a_high_bytes_u8x32);
|
|
467
|
+
__m256i a_high_i16x16 = _mm256_unpackhi_epi8(a_low_bytes_u8x32, a_high_bytes_u8x32);
|
|
468
|
+
__m256i b_low_i16x16 = _mm256_unpacklo_epi8(b_low_bytes_u8x32, b_high_bytes_u8x32);
|
|
469
|
+
__m256i b_high_i16x16 = _mm256_unpackhi_epi8(b_low_bytes_u8x32, b_high_bytes_u8x32);
|
|
470
|
+
|
|
471
|
+
// Apply signs individually
|
|
472
|
+
__m256i a_negative_mask_u8x32 = _mm256_cmpeq_epi8(_mm256_and_si256(a_e3m2_u8x32, sign_mask_u8x32), sign_mask_u8x32);
|
|
473
|
+
__m256i b_negative_mask_u8x32 = _mm256_cmpeq_epi8(_mm256_and_si256(b_e3m2_u8x32, sign_mask_u8x32), sign_mask_u8x32);
|
|
474
|
+
__m256i a_sign_low_i16x16 = _mm256_or_si256(_mm256_unpacklo_epi8(a_negative_mask_u8x32, a_negative_mask_u8x32),
|
|
475
|
+
ones_i16x16);
|
|
476
|
+
__m256i a_sign_high_i16x16 = _mm256_or_si256(_mm256_unpackhi_epi8(a_negative_mask_u8x32, a_negative_mask_u8x32),
|
|
477
|
+
ones_i16x16);
|
|
478
|
+
__m256i b_sign_low_i16x16 = _mm256_or_si256(_mm256_unpacklo_epi8(b_negative_mask_u8x32, b_negative_mask_u8x32),
|
|
479
|
+
ones_i16x16);
|
|
480
|
+
__m256i b_sign_high_i16x16 = _mm256_or_si256(_mm256_unpackhi_epi8(b_negative_mask_u8x32, b_negative_mask_u8x32),
|
|
481
|
+
ones_i16x16);
|
|
482
|
+
__m256i a_signed_low_i16x16 = _mm256_sign_epi16(a_low_i16x16, a_sign_low_i16x16);
|
|
483
|
+
__m256i a_signed_high_i16x16 = _mm256_sign_epi16(a_high_i16x16, a_sign_high_i16x16);
|
|
484
|
+
__m256i b_signed_low_i16x16 = _mm256_sign_epi16(b_low_i16x16, b_sign_low_i16x16);
|
|
485
|
+
__m256i b_signed_high_i16x16 = _mm256_sign_epi16(b_high_i16x16, b_sign_high_i16x16);
|
|
486
|
+
|
|
487
|
+
// dot(a,b) + a² + b² via VPMADDWD
|
|
488
|
+
ab_i32x8 = _mm256_add_epi32(ab_i32x8, _mm256_madd_epi16(a_signed_low_i16x16, b_signed_low_i16x16));
|
|
489
|
+
ab_i32x8 = _mm256_add_epi32(ab_i32x8, _mm256_madd_epi16(a_signed_high_i16x16, b_signed_high_i16x16));
|
|
490
|
+
a_norm_i32x8 = _mm256_add_epi32(a_norm_i32x8, _mm256_madd_epi16(a_low_i16x16, a_low_i16x16));
|
|
491
|
+
a_norm_i32x8 = _mm256_add_epi32(a_norm_i32x8, _mm256_madd_epi16(a_high_i16x16, a_high_i16x16));
|
|
492
|
+
b_norm_i32x8 = _mm256_add_epi32(b_norm_i32x8, _mm256_madd_epi16(b_low_i16x16, b_low_i16x16));
|
|
493
|
+
b_norm_i32x8 = _mm256_add_epi32(b_norm_i32x8, _mm256_madd_epi16(b_high_i16x16, b_high_i16x16));
|
|
494
|
+
|
|
495
|
+
if (count_scalars) goto nk_angular_e3m2_sierra_cycle;
|
|
496
|
+
|
|
497
|
+
nk_i32_t dot_i32 = nk_reduce_add_i32x8_haswell_(ab_i32x8);
|
|
498
|
+
nk_i32_t a_norm_i32 = nk_reduce_add_i32x8_haswell_(a_norm_i32x8);
|
|
499
|
+
nk_i32_t b_norm_i32 = nk_reduce_add_i32x8_haswell_(b_norm_i32x8);
|
|
500
|
+
*result = nk_angular_normalize_f32_haswell_((nk_f32_t)dot_i32 / 256.0f, (nk_f32_t)a_norm_i32 / 256.0f,
|
|
501
|
+
(nk_f32_t)b_norm_i32 / 256.0f);
|
|
502
|
+
}
|
|
503
|
+
|
|
311
504
|
#if defined(__clang__)
|
|
312
505
|
#pragma clang attribute pop
|
|
313
506
|
#elif defined(__GNUC__)
|