numkong 7.0.0 → 7.4.2
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +197 -124
- package/binding.gyp +34 -484
- package/c/dispatch_bf16.c +59 -1
- package/c/dispatch_e2m3.c +41 -8
- package/c/dispatch_e3m2.c +49 -8
- package/c/dispatch_e4m3.c +51 -9
- package/c/dispatch_e5m2.c +45 -1
- package/c/dispatch_f16.c +79 -26
- package/c/dispatch_f16c.c +5 -5
- package/c/dispatch_f32.c +56 -0
- package/c/dispatch_f64.c +52 -0
- package/c/dispatch_i4.c +3 -0
- package/c/dispatch_i8.c +62 -3
- package/c/dispatch_other.c +18 -0
- package/c/dispatch_u1.c +54 -9
- package/c/dispatch_u4.c +3 -0
- package/c/dispatch_u8.c +64 -3
- package/c/numkong.c +3 -0
- package/include/README.md +79 -9
- package/include/numkong/attention/sapphireamx.h +278 -276
- package/include/numkong/attention/sme.h +983 -977
- package/include/numkong/attention.h +1 -1
- package/include/numkong/capabilities.h +289 -94
- package/include/numkong/cast/README.md +40 -40
- package/include/numkong/cast/diamond.h +64 -0
- package/include/numkong/cast/haswell.h +42 -194
- package/include/numkong/cast/icelake.h +42 -37
- package/include/numkong/cast/loongsonasx.h +252 -0
- package/include/numkong/cast/neon.h +216 -249
- package/include/numkong/cast/powervsx.h +449 -0
- package/include/numkong/cast/rvv.h +223 -274
- package/include/numkong/cast/sapphire.h +18 -18
- package/include/numkong/cast/serial.h +1018 -944
- package/include/numkong/cast/skylake.h +82 -23
- package/include/numkong/cast/v128relaxed.h +462 -105
- package/include/numkong/cast.h +24 -0
- package/include/numkong/cast.hpp +44 -0
- package/include/numkong/curved/README.md +17 -17
- package/include/numkong/curved/neon.h +131 -7
- package/include/numkong/curved/neonbfdot.h +6 -7
- package/include/numkong/curved/rvv.h +26 -26
- package/include/numkong/curved/smef64.h +186 -182
- package/include/numkong/curved.h +14 -18
- package/include/numkong/dot/README.md +154 -137
- package/include/numkong/dot/alder.h +43 -43
- package/include/numkong/dot/diamond.h +158 -0
- package/include/numkong/dot/genoa.h +4 -30
- package/include/numkong/dot/haswell.h +215 -180
- package/include/numkong/dot/icelake.h +190 -76
- package/include/numkong/dot/loongsonasx.h +671 -0
- package/include/numkong/dot/neon.h +124 -73
- package/include/numkong/dot/neonbfdot.h +11 -12
- package/include/numkong/dot/neonfhm.h +44 -46
- package/include/numkong/dot/neonfp8.h +323 -0
- package/include/numkong/dot/neonsdot.h +190 -76
- package/include/numkong/dot/powervsx.h +752 -0
- package/include/numkong/dot/rvv.h +92 -84
- package/include/numkong/dot/rvvbf16.h +12 -12
- package/include/numkong/dot/rvvhalf.h +12 -12
- package/include/numkong/dot/sapphire.h +4 -4
- package/include/numkong/dot/serial.h +66 -30
- package/include/numkong/dot/sierra.h +31 -31
- package/include/numkong/dot/skylake.h +142 -110
- package/include/numkong/dot/sve.h +217 -177
- package/include/numkong/dot/svebfdot.h +10 -10
- package/include/numkong/dot/svehalf.h +85 -41
- package/include/numkong/dot/svesdot.h +89 -0
- package/include/numkong/dot/v128relaxed.h +124 -89
- package/include/numkong/dot.h +114 -48
- package/include/numkong/dots/README.md +203 -203
- package/include/numkong/dots/alder.h +12 -9
- package/include/numkong/dots/diamond.h +86 -0
- package/include/numkong/dots/genoa.h +10 -4
- package/include/numkong/dots/haswell.h +63 -48
- package/include/numkong/dots/icelake.h +27 -18
- package/include/numkong/dots/loongsonasx.h +176 -0
- package/include/numkong/dots/neon.h +14 -11
- package/include/numkong/dots/neonbfdot.h +4 -3
- package/include/numkong/dots/neonfhm.h +11 -9
- package/include/numkong/dots/neonfp8.h +99 -0
- package/include/numkong/dots/neonsdot.h +48 -12
- package/include/numkong/dots/powervsx.h +194 -0
- package/include/numkong/dots/rvv.h +451 -344
- package/include/numkong/dots/sapphireamx.h +1028 -984
- package/include/numkong/dots/serial.h +213 -197
- package/include/numkong/dots/sierra.h +10 -7
- package/include/numkong/dots/skylake.h +47 -36
- package/include/numkong/dots/sme.h +2001 -2364
- package/include/numkong/dots/smebi32.h +175 -162
- package/include/numkong/dots/smef64.h +328 -323
- package/include/numkong/dots/v128relaxed.h +64 -41
- package/include/numkong/dots.h +573 -293
- package/include/numkong/dots.hpp +45 -43
- package/include/numkong/each/README.md +133 -137
- package/include/numkong/each/haswell.h +6 -6
- package/include/numkong/each/icelake.h +7 -7
- package/include/numkong/each/neon.h +76 -42
- package/include/numkong/each/neonbfdot.h +11 -12
- package/include/numkong/each/neonhalf.h +24 -116
- package/include/numkong/each/rvv.h +28 -28
- package/include/numkong/each/sapphire.h +27 -161
- package/include/numkong/each/serial.h +6 -6
- package/include/numkong/each/skylake.h +7 -7
- package/include/numkong/each/v128relaxed.h +562 -0
- package/include/numkong/each.h +148 -62
- package/include/numkong/each.hpp +2 -2
- package/include/numkong/geospatial/README.md +18 -18
- package/include/numkong/geospatial/haswell.h +365 -325
- package/include/numkong/geospatial/neon.h +350 -306
- package/include/numkong/geospatial/rvv.h +4 -4
- package/include/numkong/geospatial/skylake.h +376 -340
- package/include/numkong/geospatial/v128relaxed.h +366 -327
- package/include/numkong/geospatial.h +17 -17
- package/include/numkong/matrix.hpp +4 -4
- package/include/numkong/maxsim/README.md +14 -14
- package/include/numkong/maxsim/alder.h +6 -6
- package/include/numkong/maxsim/genoa.h +4 -4
- package/include/numkong/maxsim/haswell.h +6 -6
- package/include/numkong/maxsim/icelake.h +18 -18
- package/include/numkong/maxsim/neonsdot.h +21 -21
- package/include/numkong/maxsim/sapphireamx.h +14 -14
- package/include/numkong/maxsim/serial.h +6 -6
- package/include/numkong/maxsim/sme.h +221 -196
- package/include/numkong/maxsim/v128relaxed.h +6 -6
- package/include/numkong/mesh/README.md +62 -56
- package/include/numkong/mesh/haswell.h +339 -464
- package/include/numkong/mesh/neon.h +1100 -519
- package/include/numkong/mesh/neonbfdot.h +36 -68
- package/include/numkong/mesh/rvv.h +530 -435
- package/include/numkong/mesh/serial.h +75 -91
- package/include/numkong/mesh/skylake.h +1627 -302
- package/include/numkong/mesh/v128relaxed.h +443 -330
- package/include/numkong/mesh.h +63 -49
- package/include/numkong/mesh.hpp +4 -4
- package/include/numkong/numkong.h +3 -3
- package/include/numkong/numkong.hpp +1 -0
- package/include/numkong/probability/README.md +23 -19
- package/include/numkong/probability/neon.h +82 -52
- package/include/numkong/probability/rvv.h +28 -23
- package/include/numkong/probability/serial.h +51 -39
- package/include/numkong/probability.h +20 -23
- package/include/numkong/random.h +1 -1
- package/include/numkong/reduce/README.md +143 -138
- package/include/numkong/reduce/alder.h +81 -77
- package/include/numkong/reduce/haswell.h +222 -220
- package/include/numkong/reduce/neon.h +629 -519
- package/include/numkong/reduce/neonbfdot.h +7 -218
- package/include/numkong/reduce/neonfhm.h +9 -381
- package/include/numkong/reduce/neonsdot.h +9 -9
- package/include/numkong/reduce/rvv.h +928 -802
- package/include/numkong/reduce/serial.h +23 -27
- package/include/numkong/reduce/sierra.h +20 -20
- package/include/numkong/reduce/skylake.h +326 -324
- package/include/numkong/reduce/v128relaxed.h +52 -52
- package/include/numkong/reduce.h +4 -23
- package/include/numkong/reduce.hpp +156 -11
- package/include/numkong/scalar/README.md +6 -6
- package/include/numkong/scalar/haswell.h +26 -17
- package/include/numkong/scalar/loongsonasx.h +74 -0
- package/include/numkong/scalar/neon.h +9 -9
- package/include/numkong/scalar/powervsx.h +96 -0
- package/include/numkong/scalar/rvv.h +2 -2
- package/include/numkong/scalar/sapphire.h +21 -10
- package/include/numkong/scalar/serial.h +21 -21
- package/include/numkong/scalar.h +13 -0
- package/include/numkong/set/README.md +28 -28
- package/include/numkong/set/haswell.h +12 -12
- package/include/numkong/set/icelake.h +14 -14
- package/include/numkong/set/loongsonasx.h +181 -0
- package/include/numkong/set/neon.h +17 -18
- package/include/numkong/set/powervsx.h +326 -0
- package/include/numkong/set/rvv.h +4 -4
- package/include/numkong/set/serial.h +6 -6
- package/include/numkong/set/sve.h +60 -59
- package/include/numkong/set/v128relaxed.h +6 -6
- package/include/numkong/set.h +21 -7
- package/include/numkong/sets/README.md +26 -26
- package/include/numkong/sets/loongsonasx.h +52 -0
- package/include/numkong/sets/powervsx.h +65 -0
- package/include/numkong/sets/smebi32.h +395 -364
- package/include/numkong/sets.h +83 -40
- package/include/numkong/sparse/README.md +4 -4
- package/include/numkong/sparse/icelake.h +101 -101
- package/include/numkong/sparse/serial.h +1 -1
- package/include/numkong/sparse/sve2.h +137 -141
- package/include/numkong/sparse/turin.h +12 -12
- package/include/numkong/sparse.h +10 -10
- package/include/numkong/spatial/README.md +230 -226
- package/include/numkong/spatial/alder.h +113 -116
- package/include/numkong/spatial/diamond.h +240 -0
- package/include/numkong/spatial/genoa.h +0 -68
- package/include/numkong/spatial/haswell.h +74 -55
- package/include/numkong/spatial/icelake.h +539 -58
- package/include/numkong/spatial/loongsonasx.h +483 -0
- package/include/numkong/spatial/neon.h +125 -52
- package/include/numkong/spatial/neonbfdot.h +8 -9
- package/include/numkong/spatial/neonfp8.h +258 -0
- package/include/numkong/spatial/neonsdot.h +180 -12
- package/include/numkong/spatial/powervsx.h +738 -0
- package/include/numkong/spatial/rvv.h +146 -139
- package/include/numkong/spatial/rvvbf16.h +17 -12
- package/include/numkong/spatial/rvvhalf.h +13 -10
- package/include/numkong/spatial/serial.h +13 -12
- package/include/numkong/spatial/sierra.h +232 -39
- package/include/numkong/spatial/skylake.h +73 -74
- package/include/numkong/spatial/sve.h +93 -72
- package/include/numkong/spatial/svebfdot.h +29 -29
- package/include/numkong/spatial/svehalf.h +52 -26
- package/include/numkong/spatial/svesdot.h +142 -0
- package/include/numkong/spatial/v128relaxed.h +293 -41
- package/include/numkong/spatial.h +338 -82
- package/include/numkong/spatials/README.md +194 -194
- package/include/numkong/spatials/diamond.h +82 -0
- package/include/numkong/spatials/haswell.h +2 -2
- package/include/numkong/spatials/loongsonasx.h +153 -0
- package/include/numkong/spatials/neonfp8.h +111 -0
- package/include/numkong/spatials/neonsdot.h +34 -0
- package/include/numkong/spatials/powervsx.h +153 -0
- package/include/numkong/spatials/rvv.h +259 -243
- package/include/numkong/spatials/sapphireamx.h +173 -173
- package/include/numkong/spatials/serial.h +2 -2
- package/include/numkong/spatials/skylake.h +2 -2
- package/include/numkong/spatials/sme.h +590 -605
- package/include/numkong/spatials/smef64.h +139 -130
- package/include/numkong/spatials/v128relaxed.h +2 -2
- package/include/numkong/spatials.h +820 -500
- package/include/numkong/spatials.hpp +49 -48
- package/include/numkong/tensor.hpp +406 -17
- package/include/numkong/trigonometry/README.md +19 -19
- package/include/numkong/trigonometry/haswell.h +402 -401
- package/include/numkong/trigonometry/neon.h +386 -387
- package/include/numkong/trigonometry/rvv.h +52 -51
- package/include/numkong/trigonometry/serial.h +13 -13
- package/include/numkong/trigonometry/skylake.h +373 -369
- package/include/numkong/trigonometry/v128relaxed.h +375 -374
- package/include/numkong/trigonometry.h +13 -13
- package/include/numkong/trigonometry.hpp +2 -2
- package/include/numkong/types.h +287 -49
- package/include/numkong/types.hpp +436 -12
- package/include/numkong/vector.hpp +82 -14
- package/javascript/dist/cjs/numkong-wasm.js +6 -12
- package/javascript/dist/cjs/numkong.d.ts +7 -1
- package/javascript/dist/cjs/numkong.js +37 -11
- package/javascript/dist/cjs/types.d.ts +9 -0
- package/javascript/dist/cjs/types.js +96 -0
- package/javascript/dist/esm/numkong-browser.d.ts +14 -0
- package/javascript/dist/esm/numkong-browser.js +23 -0
- package/javascript/dist/esm/numkong-wasm.js +6 -12
- package/javascript/dist/esm/numkong.d.ts +7 -1
- package/javascript/dist/esm/numkong.js +37 -11
- package/javascript/dist/esm/types.d.ts +9 -0
- package/javascript/dist/esm/types.js +96 -0
- package/javascript/node-gyp-build.d.ts +4 -1
- package/javascript/numkong-browser.ts +40 -0
- package/javascript/numkong-wasm.ts +7 -13
- package/javascript/numkong.c +5 -26
- package/javascript/numkong.ts +36 -11
- package/javascript/tsconfig-base.json +1 -0
- package/javascript/tsconfig-cjs.json +6 -1
- package/javascript/types.ts +110 -0
- package/numkong.gypi +101 -0
- package/package.json +34 -13
- package/probes/arm_neon.c +8 -0
- package/probes/arm_neon_bfdot.c +9 -0
- package/probes/arm_neon_fhm.c +9 -0
- package/probes/arm_neon_half.c +8 -0
- package/probes/arm_neon_sdot.c +9 -0
- package/probes/arm_neonfp8.c +9 -0
- package/probes/arm_sme.c +16 -0
- package/probes/arm_sme2.c +16 -0
- package/probes/arm_sme2p1.c +16 -0
- package/probes/arm_sme_bf16.c +16 -0
- package/probes/arm_sme_bi32.c +16 -0
- package/probes/arm_sme_f64.c +16 -0
- package/probes/arm_sme_fa64.c +14 -0
- package/probes/arm_sme_half.c +16 -0
- package/probes/arm_sme_lut2.c +15 -0
- package/probes/arm_sve.c +18 -0
- package/probes/arm_sve2.c +20 -0
- package/probes/arm_sve2p1.c +18 -0
- package/probes/arm_sve_bfdot.c +20 -0
- package/probes/arm_sve_half.c +18 -0
- package/probes/arm_sve_sdot.c +21 -0
- package/probes/loongarch_lasx.c +12 -0
- package/probes/power_vsx.c +12 -0
- package/probes/probe.js +127 -0
- package/probes/riscv_rvv.c +14 -0
- package/probes/riscv_rvv_bb.c +15 -0
- package/probes/riscv_rvv_bf16.c +17 -0
- package/probes/riscv_rvv_half.c +14 -0
- package/probes/wasm_v128relaxed.c +11 -0
- package/probes/x86_alder.c +17 -0
- package/probes/x86_diamond.c +17 -0
- package/probes/x86_genoa.c +17 -0
- package/probes/x86_graniteamx.c +19 -0
- package/probes/x86_haswell.c +11 -0
- package/probes/x86_icelake.c +17 -0
- package/probes/x86_sapphire.c +16 -0
- package/probes/x86_sapphireamx.c +18 -0
- package/probes/x86_sierra.c +17 -0
- package/probes/x86_skylake.c +15 -0
- package/probes/x86_turin.c +17 -0
- package/wasm/numkong-emscripten.js +2 -0
- package/wasm/numkong.d.ts +14 -0
- package/wasm/numkong.js +1124 -0
- package/wasm/numkong.wasm +0 -0
- package/include/numkong/curved/neonhalf.h +0 -212
- package/include/numkong/dot/neonhalf.h +0 -198
- package/include/numkong/dots/neonhalf.h +0 -57
- package/include/numkong/mesh/neonhalf.h +0 -616
- package/include/numkong/reduce/neonhalf.h +0 -157
- package/include/numkong/spatial/neonhalf.h +0 -118
- package/include/numkong/spatial/sapphire.h +0 -343
- package/include/numkong/spatials/neonhalf.h +0 -58
- package/javascript/README.md +0 -246
|
@@ -139,74 +139,6 @@ nk_angular_bf16_genoa_cycle:
|
|
|
139
139
|
*result = nk_angular_normalize_f32_haswell_(dot_product_f32, a_norm_sq_f32, b_norm_sq_f32);
|
|
140
140
|
}
|
|
141
141
|
|
|
142
|
-
NK_PUBLIC void nk_sqeuclidean_e4m3_genoa(nk_e4m3_t const *a, nk_e4m3_t const *b, nk_size_t n, nk_f32_t *result) {
|
|
143
|
-
__m512 a_sq_f32x16 = _mm512_setzero_ps();
|
|
144
|
-
__m512 b_sq_f32x16 = _mm512_setzero_ps();
|
|
145
|
-
__m512 ab_f32x16 = _mm512_setzero_ps();
|
|
146
|
-
__m256i a_e4m3x32, b_e4m3x32;
|
|
147
|
-
|
|
148
|
-
nk_sqeuclidean_e4m3_genoa_cycle:
|
|
149
|
-
if (n < 32) {
|
|
150
|
-
__mmask32 mask = (__mmask32)_bzhi_u32(0xFFFFFFFF, n);
|
|
151
|
-
a_e4m3x32 = _mm256_maskz_loadu_epi8(mask, a);
|
|
152
|
-
b_e4m3x32 = _mm256_maskz_loadu_epi8(mask, b);
|
|
153
|
-
n = 0;
|
|
154
|
-
}
|
|
155
|
-
else {
|
|
156
|
-
a_e4m3x32 = _mm256_loadu_epi8(a);
|
|
157
|
-
b_e4m3x32 = _mm256_loadu_epi8(b);
|
|
158
|
-
a += 32, b += 32, n -= 32;
|
|
159
|
-
}
|
|
160
|
-
__m512i a_bf16x32 = nk_e4m3x32_to_bf16x32_icelake_(a_e4m3x32);
|
|
161
|
-
__m512i b_bf16x32 = nk_e4m3x32_to_bf16x32_icelake_(b_e4m3x32);
|
|
162
|
-
a_sq_f32x16 = _mm512_dpbf16_ps(a_sq_f32x16, nk_m512bh_from_m512i_(a_bf16x32), nk_m512bh_from_m512i_(a_bf16x32));
|
|
163
|
-
b_sq_f32x16 = _mm512_dpbf16_ps(b_sq_f32x16, nk_m512bh_from_m512i_(b_bf16x32), nk_m512bh_from_m512i_(b_bf16x32));
|
|
164
|
-
ab_f32x16 = _mm512_dpbf16_ps(ab_f32x16, nk_m512bh_from_m512i_(a_bf16x32), nk_m512bh_from_m512i_(b_bf16x32));
|
|
165
|
-
if (n) goto nk_sqeuclidean_e4m3_genoa_cycle;
|
|
166
|
-
|
|
167
|
-
// (a-b)² = a² + b² - 2ab
|
|
168
|
-
__m512 sum_sq_f32x16 = _mm512_add_ps(a_sq_f32x16, b_sq_f32x16);
|
|
169
|
-
*result = nk_reduce_add_f32x16_skylake_(_mm512_fnmadd_ps(_mm512_set1_ps(2.0f), ab_f32x16, sum_sq_f32x16));
|
|
170
|
-
}
|
|
171
|
-
|
|
172
|
-
NK_PUBLIC void nk_euclidean_e4m3_genoa(nk_e4m3_t const *a, nk_e4m3_t const *b, nk_size_t n, nk_f32_t *result) {
|
|
173
|
-
nk_sqeuclidean_e4m3_genoa(a, b, n, result);
|
|
174
|
-
*result = nk_f32_sqrt_haswell(*result);
|
|
175
|
-
}
|
|
176
|
-
|
|
177
|
-
NK_PUBLIC void nk_angular_e4m3_genoa(nk_e4m3_t const *a, nk_e4m3_t const *b, nk_size_t n, nk_f32_t *result) {
|
|
178
|
-
__m512 dot_f32x16 = _mm512_setzero_ps();
|
|
179
|
-
__m512 a_norm_sq_f32x16 = _mm512_setzero_ps();
|
|
180
|
-
__m512 b_norm_sq_f32x16 = _mm512_setzero_ps();
|
|
181
|
-
__m256i a_e4m3x32, b_e4m3x32;
|
|
182
|
-
|
|
183
|
-
nk_angular_e4m3_genoa_cycle:
|
|
184
|
-
if (n < 32) {
|
|
185
|
-
__mmask32 mask = (__mmask32)_bzhi_u32(0xFFFFFFFF, n);
|
|
186
|
-
a_e4m3x32 = _mm256_maskz_loadu_epi8(mask, a);
|
|
187
|
-
b_e4m3x32 = _mm256_maskz_loadu_epi8(mask, b);
|
|
188
|
-
n = 0;
|
|
189
|
-
}
|
|
190
|
-
else {
|
|
191
|
-
a_e4m3x32 = _mm256_loadu_epi8(a);
|
|
192
|
-
b_e4m3x32 = _mm256_loadu_epi8(b);
|
|
193
|
-
a += 32, b += 32, n -= 32;
|
|
194
|
-
}
|
|
195
|
-
__m512i a_bf16x32 = nk_e4m3x32_to_bf16x32_icelake_(a_e4m3x32);
|
|
196
|
-
__m512i b_bf16x32 = nk_e4m3x32_to_bf16x32_icelake_(b_e4m3x32);
|
|
197
|
-
dot_f32x16 = _mm512_dpbf16_ps(dot_f32x16, nk_m512bh_from_m512i_(a_bf16x32), nk_m512bh_from_m512i_(b_bf16x32));
|
|
198
|
-
a_norm_sq_f32x16 = _mm512_dpbf16_ps(a_norm_sq_f32x16, nk_m512bh_from_m512i_(a_bf16x32),
|
|
199
|
-
nk_m512bh_from_m512i_(a_bf16x32));
|
|
200
|
-
b_norm_sq_f32x16 = _mm512_dpbf16_ps(b_norm_sq_f32x16, nk_m512bh_from_m512i_(b_bf16x32),
|
|
201
|
-
nk_m512bh_from_m512i_(b_bf16x32));
|
|
202
|
-
if (n) goto nk_angular_e4m3_genoa_cycle;
|
|
203
|
-
|
|
204
|
-
nk_f32_t dot_f32 = nk_reduce_add_f32x16_skylake_(dot_f32x16);
|
|
205
|
-
nk_f32_t a_norm_sq_f32 = nk_reduce_add_f32x16_skylake_(a_norm_sq_f32x16);
|
|
206
|
-
nk_f32_t b_norm_sq_f32 = nk_reduce_add_f32x16_skylake_(b_norm_sq_f32x16);
|
|
207
|
-
*result = nk_angular_normalize_f32_haswell_(dot_f32, a_norm_sq_f32, b_norm_sq_f32);
|
|
208
|
-
}
|
|
209
|
-
|
|
210
142
|
NK_PUBLIC void nk_sqeuclidean_e5m2_genoa(nk_e5m2_t const *a, nk_e5m2_t const *b, nk_size_t n, nk_f32_t *result) {
|
|
211
143
|
__m512 a_sq_f32x16 = _mm512_setzero_ps();
|
|
212
144
|
__m512 b_sq_f32x16 = _mm512_setzero_ps();
|
|
@@ -8,14 +8,14 @@
|
|
|
8
8
|
*
|
|
9
9
|
* @section spatial_haswell_instructions Key AVX2 Spatial Instructions
|
|
10
10
|
*
|
|
11
|
-
* Intrinsic
|
|
12
|
-
* _mm256_fmadd_ps
|
|
13
|
-
* _mm256_mul_ps
|
|
14
|
-
* _mm256_add_ps
|
|
15
|
-
* _mm256_sub_ps
|
|
16
|
-
* _mm_rsqrt_ps
|
|
17
|
-
* _mm_sqrt_ps
|
|
18
|
-
* _mm256_sqrt_ps
|
|
11
|
+
* Intrinsic Instruction Haswell Genoa
|
|
12
|
+
* _mm256_fmadd_ps VFMADD (YMM, YMM, YMM) 5cy @ p01 4cy @ p01
|
|
13
|
+
* _mm256_mul_ps VMULPS (YMM, YMM, YMM) 5cy @ p01 3cy @ p01
|
|
14
|
+
* _mm256_add_ps VADDPS (YMM, YMM, YMM) 3cy @ p01 3cy @ p23
|
|
15
|
+
* _mm256_sub_ps VSUBPS (YMM, YMM, YMM) 3cy @ p01 3cy @ p23
|
|
16
|
+
* _mm_rsqrt_ps VRSQRTPS (XMM, XMM) 5cy @ p0 4cy @ p01
|
|
17
|
+
* _mm_sqrt_ps VSQRTPS (XMM, XMM) 11cy @ p0 15cy @ p01
|
|
18
|
+
* _mm256_sqrt_ps VSQRTPS (YMM, YMM) 19cy @ p0 15cy @ p01
|
|
19
19
|
*
|
|
20
20
|
* For angular distance normalization, `_mm_rsqrt_ps` provides ~12-bit precision (1.5 x 2⁻¹² error).
|
|
21
21
|
* Newton-Raphson refinement doubles precision to ~22-24 bits, sufficient for f32. For f64 we use
|
|
@@ -52,7 +52,7 @@ NK_INTERNAL __m128 nk_rsqrt_f32x4_haswell_(__m128 x) {
|
|
|
52
52
|
}
|
|
53
53
|
|
|
54
54
|
/** @brief Safe square root of 4 floats with zero-clamping for numerical stability. */
|
|
55
|
-
NK_INTERNAL __m128
|
|
55
|
+
NK_INTERNAL __m128 nk_sqrt_f32x4_haswell_(__m128 x) { return _mm_sqrt_ps(_mm_max_ps(x, _mm_setzero_ps())); }
|
|
56
56
|
|
|
57
57
|
/** @brief Angular from_dot: computes 1 − dot × rsqrt(query_sumsq × target_sumsq) for 4 pairs. */
|
|
58
58
|
NK_INTERNAL void nk_angular_through_f32_from_dot_haswell_(nk_b128_vec_t dots, nk_f32_t query_sumsq,
|
|
@@ -73,7 +73,7 @@ NK_INTERNAL void nk_euclidean_through_f32_from_dot_haswell_(nk_b128_vec_t dots,
|
|
|
73
73
|
__m128 query_sumsq_f32x4 = _mm_set1_ps(query_sumsq);
|
|
74
74
|
__m128 sum_sq_f32x4 = _mm_add_ps(query_sumsq_f32x4, target_sumsqs.xmm_ps);
|
|
75
75
|
__m128 dist_sq_f32x4 = _mm_fnmadd_ps(_mm_set1_ps(2.0f), dots_f32x4, sum_sq_f32x4);
|
|
76
|
-
results->xmm_ps =
|
|
76
|
+
results->xmm_ps = nk_sqrt_f32x4_haswell_(dist_sq_f32x4);
|
|
77
77
|
}
|
|
78
78
|
|
|
79
79
|
/** @brief Angular from_dot for native f64: 1 − dot / √(query_sumsq × target_sumsq) for 4 pairs. */
|
|
@@ -117,7 +117,7 @@ NK_INTERNAL void nk_euclidean_through_i32_from_dot_haswell_(nk_b128_vec_t dots,
|
|
|
117
117
|
__m128 query_sumsq_f32x4 = _mm_set1_ps((nk_f32_t)query_sumsq);
|
|
118
118
|
__m128 sum_sq_f32x4 = _mm_add_ps(query_sumsq_f32x4, _mm_cvtepi32_ps(target_sumsqs.xmm));
|
|
119
119
|
__m128 dist_sq_f32x4 = _mm_fnmadd_ps(_mm_set1_ps(2.0f), dots_f32x4, sum_sq_f32x4);
|
|
120
|
-
results->xmm_ps =
|
|
120
|
+
results->xmm_ps = nk_sqrt_f32x4_haswell_(dist_sq_f32x4);
|
|
121
121
|
}
|
|
122
122
|
|
|
123
123
|
/** @brief Angular from_dot for u32 accumulators: cast to f32, rsqrt+NR, clamp. 4 pairs. */
|
|
@@ -139,7 +139,7 @@ NK_INTERNAL void nk_euclidean_through_u32_from_dot_haswell_(nk_b128_vec_t dots,
|
|
|
139
139
|
__m128 query_sumsq_f32x4 = _mm_set1_ps((nk_f32_t)query_sumsq);
|
|
140
140
|
__m128 sum_sq_f32x4 = _mm_add_ps(query_sumsq_f32x4, _mm_cvtepi32_ps(target_sumsqs.xmm));
|
|
141
141
|
__m128 dist_sq_f32x4 = _mm_fnmadd_ps(_mm_set1_ps(2.0f), dots_f32x4, sum_sq_f32x4);
|
|
142
|
-
results->xmm_ps =
|
|
142
|
+
results->xmm_ps = nk_sqrt_f32x4_haswell_(dist_sq_f32x4);
|
|
143
143
|
}
|
|
144
144
|
|
|
145
145
|
NK_INTERNAL nk_f64_t nk_angular_normalize_f64_haswell_(nk_f64_t ab, nk_f64_t a2, nk_f64_t b2) {
|
|
@@ -173,28 +173,30 @@ NK_INTERNAL nk_f32_t nk_angular_normalize_f32_haswell_(nk_f32_t ab, nk_f32_t a2,
|
|
|
173
173
|
else if (ab == 0.0f) return 1.0f;
|
|
174
174
|
|
|
175
175
|
// Load the squares into an __m128 register for single-precision floating-point operations
|
|
176
|
-
__m128
|
|
176
|
+
__m128 squares_f32x4 = _mm_set_ps(a2, b2, a2, b2); // We replicate to make use of full register
|
|
177
177
|
|
|
178
178
|
// Compute the reciprocal square root of the squares using `_mm_rsqrt_ps` (single-precision)
|
|
179
|
-
__m128
|
|
179
|
+
__m128 rsqrts_f32x4 = _mm_rsqrt_ps(squares_f32x4);
|
|
180
180
|
|
|
181
181
|
// Perform one iteration of Newton-Raphson refinement to improve the precision of rsqrt:
|
|
182
182
|
// Formula: y' = y × (1.5 - 0.5 × x × y × y)
|
|
183
|
-
__m128
|
|
184
|
-
__m128
|
|
185
|
-
|
|
186
|
-
|
|
183
|
+
__m128 half_f32x4 = _mm_set1_ps(0.5f);
|
|
184
|
+
__m128 three_halves_f32x4 = _mm_set1_ps(1.5f);
|
|
185
|
+
rsqrts_f32x4 = _mm_mul_ps(
|
|
186
|
+
rsqrts_f32x4,
|
|
187
|
+
_mm_sub_ps(three_halves_f32x4,
|
|
188
|
+
_mm_mul_ps(half_f32x4, _mm_mul_ps(squares_f32x4, _mm_mul_ps(rsqrts_f32x4, rsqrts_f32x4)))));
|
|
187
189
|
|
|
188
190
|
// Extract the reciprocal square roots of a2 and b2 from the __m128 register
|
|
189
|
-
nk_f32_t a2_reciprocal = _mm_cvtss_f32(_mm_shuffle_ps(
|
|
190
|
-
nk_f32_t b2_reciprocal = _mm_cvtss_f32(
|
|
191
|
+
nk_f32_t a2_reciprocal = _mm_cvtss_f32(_mm_shuffle_ps(rsqrts_f32x4, rsqrts_f32x4, _MM_SHUFFLE(0, 0, 0, 1)));
|
|
192
|
+
nk_f32_t b2_reciprocal = _mm_cvtss_f32(rsqrts_f32x4);
|
|
191
193
|
|
|
192
194
|
// Calculate the angular distance: 1 - dot_product × a2_reciprocal × b2_reciprocal
|
|
193
195
|
nk_f32_t result = 1.0f - ab * a2_reciprocal * b2_reciprocal;
|
|
194
196
|
return result > 0 ? result : 0;
|
|
195
197
|
}
|
|
196
198
|
|
|
197
|
-
#pragma region
|
|
199
|
+
#pragma region F16 and BF16 Floats
|
|
198
200
|
|
|
199
201
|
NK_PUBLIC void nk_sqeuclidean_f16_haswell(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n, nk_f32_t *result) {
|
|
200
202
|
__m256 a_f32x8, b_f32x8;
|
|
@@ -257,25 +259,32 @@ nk_angular_f16_haswell_cycle:
|
|
|
257
259
|
}
|
|
258
260
|
|
|
259
261
|
NK_PUBLIC void nk_sqeuclidean_bf16_haswell(nk_bf16_t const *a, nk_bf16_t const *b, nk_size_t n, nk_f32_t *result) {
|
|
260
|
-
|
|
262
|
+
__m256i a_bf16_i16x16, b_bf16_i16x16;
|
|
261
263
|
__m256 distance_sq_f32x8 = _mm256_setzero_ps();
|
|
264
|
+
__m256i mask_high_u32x8 = _mm256_set1_epi32((int)0xFFFF0000);
|
|
262
265
|
|
|
263
266
|
nk_sqeuclidean_bf16_haswell_cycle:
|
|
264
|
-
if (n <
|
|
267
|
+
if (n < 16) {
|
|
265
268
|
nk_b256_vec_t a_vec, b_vec;
|
|
266
|
-
|
|
267
|
-
|
|
268
|
-
|
|
269
|
-
|
|
269
|
+
nk_partial_load_b16x16_serial_(a, &a_vec, n);
|
|
270
|
+
nk_partial_load_b16x16_serial_(b, &b_vec, n);
|
|
271
|
+
a_bf16_i16x16 = a_vec.ymm;
|
|
272
|
+
b_bf16_i16x16 = b_vec.ymm;
|
|
270
273
|
n = 0;
|
|
271
274
|
}
|
|
272
275
|
else {
|
|
273
|
-
|
|
274
|
-
|
|
275
|
-
n -=
|
|
276
|
-
}
|
|
277
|
-
__m256
|
|
278
|
-
|
|
276
|
+
a_bf16_i16x16 = _mm256_loadu_si256((__m256i const *)a);
|
|
277
|
+
b_bf16_i16x16 = _mm256_loadu_si256((__m256i const *)b);
|
|
278
|
+
n -= 16, a += 16, b += 16;
|
|
279
|
+
}
|
|
280
|
+
__m256 a_even_f32x8 = _mm256_castsi256_ps(_mm256_slli_epi32(a_bf16_i16x16, 16));
|
|
281
|
+
__m256 b_even_f32x8 = _mm256_castsi256_ps(_mm256_slli_epi32(b_bf16_i16x16, 16));
|
|
282
|
+
__m256 diff_even_f32x8 = _mm256_sub_ps(a_even_f32x8, b_even_f32x8);
|
|
283
|
+
distance_sq_f32x8 = _mm256_fmadd_ps(diff_even_f32x8, diff_even_f32x8, distance_sq_f32x8);
|
|
284
|
+
__m256 a_odd_f32x8 = _mm256_castsi256_ps(_mm256_and_si256(a_bf16_i16x16, mask_high_u32x8));
|
|
285
|
+
__m256 b_odd_f32x8 = _mm256_castsi256_ps(_mm256_and_si256(b_bf16_i16x16, mask_high_u32x8));
|
|
286
|
+
__m256 diff_odd_f32x8 = _mm256_sub_ps(a_odd_f32x8, b_odd_f32x8);
|
|
287
|
+
distance_sq_f32x8 = _mm256_fmadd_ps(diff_odd_f32x8, diff_odd_f32x8, distance_sq_f32x8);
|
|
279
288
|
if (n) goto nk_sqeuclidean_bf16_haswell_cycle;
|
|
280
289
|
|
|
281
290
|
*result = nk_reduce_add_f32x8_haswell_(distance_sq_f32x8);
|
|
@@ -287,27 +296,35 @@ NK_PUBLIC void nk_euclidean_bf16_haswell(nk_bf16_t const *a, nk_bf16_t const *b,
|
|
|
287
296
|
}
|
|
288
297
|
|
|
289
298
|
NK_PUBLIC void nk_angular_bf16_haswell(nk_bf16_t const *a, nk_bf16_t const *b, nk_size_t n, nk_f32_t *result) {
|
|
290
|
-
|
|
299
|
+
__m256i a_bf16_i16x16, b_bf16_i16x16;
|
|
291
300
|
__m256 dot_product_f32x8 = _mm256_setzero_ps(), a_norm_sq_f32x8 = _mm256_setzero_ps(),
|
|
292
301
|
b_norm_sq_f32x8 = _mm256_setzero_ps();
|
|
302
|
+
__m256i mask_high_u32x8 = _mm256_set1_epi32((int)0xFFFF0000);
|
|
293
303
|
|
|
294
304
|
nk_angular_bf16_haswell_cycle:
|
|
295
|
-
if (n <
|
|
305
|
+
if (n < 16) {
|
|
296
306
|
nk_b256_vec_t a_vec, b_vec;
|
|
297
|
-
|
|
298
|
-
|
|
299
|
-
|
|
300
|
-
|
|
307
|
+
nk_partial_load_b16x16_serial_(a, &a_vec, n);
|
|
308
|
+
nk_partial_load_b16x16_serial_(b, &b_vec, n);
|
|
309
|
+
a_bf16_i16x16 = a_vec.ymm;
|
|
310
|
+
b_bf16_i16x16 = b_vec.ymm;
|
|
301
311
|
n = 0;
|
|
302
312
|
}
|
|
303
313
|
else {
|
|
304
|
-
|
|
305
|
-
|
|
306
|
-
n -=
|
|
307
|
-
}
|
|
308
|
-
|
|
309
|
-
|
|
310
|
-
|
|
314
|
+
a_bf16_i16x16 = _mm256_loadu_si256((__m256i const *)a);
|
|
315
|
+
b_bf16_i16x16 = _mm256_loadu_si256((__m256i const *)b);
|
|
316
|
+
n -= 16, a += 16, b += 16;
|
|
317
|
+
}
|
|
318
|
+
__m256 a_even_f32x8 = _mm256_castsi256_ps(_mm256_slli_epi32(a_bf16_i16x16, 16));
|
|
319
|
+
__m256 b_even_f32x8 = _mm256_castsi256_ps(_mm256_slli_epi32(b_bf16_i16x16, 16));
|
|
320
|
+
dot_product_f32x8 = _mm256_fmadd_ps(a_even_f32x8, b_even_f32x8, dot_product_f32x8);
|
|
321
|
+
a_norm_sq_f32x8 = _mm256_fmadd_ps(a_even_f32x8, a_even_f32x8, a_norm_sq_f32x8);
|
|
322
|
+
b_norm_sq_f32x8 = _mm256_fmadd_ps(b_even_f32x8, b_even_f32x8, b_norm_sq_f32x8);
|
|
323
|
+
__m256 a_odd_f32x8 = _mm256_castsi256_ps(_mm256_and_si256(a_bf16_i16x16, mask_high_u32x8));
|
|
324
|
+
__m256 b_odd_f32x8 = _mm256_castsi256_ps(_mm256_and_si256(b_bf16_i16x16, mask_high_u32x8));
|
|
325
|
+
dot_product_f32x8 = _mm256_fmadd_ps(a_odd_f32x8, b_odd_f32x8, dot_product_f32x8);
|
|
326
|
+
a_norm_sq_f32x8 = _mm256_fmadd_ps(a_odd_f32x8, a_odd_f32x8, a_norm_sq_f32x8);
|
|
327
|
+
b_norm_sq_f32x8 = _mm256_fmadd_ps(b_odd_f32x8, b_odd_f32x8, b_norm_sq_f32x8);
|
|
311
328
|
if (n) goto nk_angular_bf16_haswell_cycle;
|
|
312
329
|
|
|
313
330
|
nk_f32_t dot_product_f32 = nk_reduce_add_f32x8_haswell_(dot_product_f32x8);
|
|
@@ -316,8 +333,8 @@ nk_angular_bf16_haswell_cycle:
|
|
|
316
333
|
*result = nk_angular_normalize_f32_haswell_(dot_product_f32, a_norm_sq_f32, b_norm_sq_f32);
|
|
317
334
|
}
|
|
318
335
|
|
|
319
|
-
#pragma endregion
|
|
320
|
-
#pragma region
|
|
336
|
+
#pragma endregion F16 and BF16 Floats
|
|
337
|
+
#pragma region I8 and U8 Integers
|
|
321
338
|
|
|
322
339
|
NK_PUBLIC void nk_sqeuclidean_i8_haswell(nk_i8_t const *a, nk_i8_t const *b, nk_size_t n, nk_u32_t *result) {
|
|
323
340
|
// Optimized i8 L2-squared using saturating subtract + VPMADDWD
|
|
@@ -433,7 +450,8 @@ NK_PUBLIC void nk_angular_i8_haswell(nk_i8_t const *a, nk_i8_t const *b, nk_size
|
|
|
433
450
|
b_norm_sq_i32 += b_element_i32 * b_element_i32;
|
|
434
451
|
}
|
|
435
452
|
|
|
436
|
-
*result = nk_angular_normalize_f32_haswell_(dot_product_i32, a_norm_sq_i32,
|
|
453
|
+
*result = nk_angular_normalize_f32_haswell_((nk_f32_t)dot_product_i32, (nk_f32_t)a_norm_sq_i32,
|
|
454
|
+
(nk_f32_t)b_norm_sq_i32);
|
|
437
455
|
}
|
|
438
456
|
|
|
439
457
|
NK_PUBLIC void nk_sqeuclidean_u8_haswell(nk_u8_t const *a, nk_u8_t const *b, nk_size_t n, nk_u32_t *result) {
|
|
@@ -539,11 +557,12 @@ NK_PUBLIC void nk_angular_u8_haswell(nk_u8_t const *a, nk_u8_t const *b, nk_size
|
|
|
539
557
|
b_norm_sq_i32 += b_element_i32 * b_element_i32;
|
|
540
558
|
}
|
|
541
559
|
|
|
542
|
-
*result = nk_angular_normalize_f32_haswell_(dot_product_i32, a_norm_sq_i32,
|
|
560
|
+
*result = nk_angular_normalize_f32_haswell_((nk_f32_t)dot_product_i32, (nk_f32_t)a_norm_sq_i32,
|
|
561
|
+
(nk_f32_t)b_norm_sq_i32);
|
|
543
562
|
}
|
|
544
563
|
|
|
545
|
-
#pragma endregion
|
|
546
|
-
#pragma region
|
|
564
|
+
#pragma endregion I8 and U8 Integers
|
|
565
|
+
#pragma region F32 and F64 Floats
|
|
547
566
|
|
|
548
567
|
NK_PUBLIC void nk_sqeuclidean_f32_haswell(nk_f32_t const *a, nk_f32_t const *b, nk_size_t n, nk_f64_t *result) {
|
|
549
568
|
// Upcast to f64 for higher precision accumulation
|
|
@@ -693,8 +712,8 @@ nk_angular_f64_haswell_cycle:
|
|
|
693
712
|
nk_reduce_add_f64x4_haswell_(a_norm_sq_f64x4), nk_reduce_add_f64x4_haswell_(b_norm_sq_f64x4));
|
|
694
713
|
}
|
|
695
714
|
|
|
696
|
-
#pragma endregion
|
|
697
|
-
#pragma region
|
|
715
|
+
#pragma endregion F32 and F64 Floats
|
|
716
|
+
#pragma region FP8 Floats
|
|
698
717
|
|
|
699
718
|
NK_PUBLIC void nk_sqeuclidean_e2m3_haswell(nk_e2m3_t const *a, nk_e2m3_t const *b, nk_size_t n, nk_f32_t *result) {
|
|
700
719
|
__m256 distance_sq_f32x8 = _mm256_setzero_ps();
|
|
@@ -954,7 +973,7 @@ nk_angular_e5m2_haswell_cycle:
|
|
|
954
973
|
} // extern "C"
|
|
955
974
|
#endif
|
|
956
975
|
|
|
957
|
-
#pragma endregion
|
|
976
|
+
#pragma endregion FP8 Floats
|
|
958
977
|
#endif // NK_TARGET_HASWELL
|
|
959
978
|
#endif // NK_TARGET_X86_
|
|
960
979
|
#endif // NK_SPATIAL_HASWELL_H
|