numkong 7.0.0 → 7.4.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +239 -122
- package/binding.gyp +25 -491
- package/c/dispatch_bf16.c +59 -1
- package/c/dispatch_e2m3.c +41 -8
- package/c/dispatch_e3m2.c +49 -8
- package/c/dispatch_e4m3.c +51 -9
- package/c/dispatch_e5m2.c +45 -1
- package/c/dispatch_f16.c +79 -26
- package/c/dispatch_f16c.c +5 -5
- package/c/dispatch_f32.c +56 -0
- package/c/dispatch_f64.c +52 -0
- package/c/dispatch_i4.c +3 -0
- package/c/dispatch_i8.c +62 -3
- package/c/dispatch_other.c +18 -0
- package/c/dispatch_u1.c +54 -9
- package/c/dispatch_u4.c +3 -0
- package/c/dispatch_u8.c +64 -3
- package/c/numkong.c +3 -0
- package/include/README.md +79 -9
- package/include/numkong/attention/sapphireamx.h +278 -276
- package/include/numkong/attention/sme.h +983 -977
- package/include/numkong/attention.h +1 -1
- package/include/numkong/capabilities.h +289 -94
- package/include/numkong/cast/README.md +40 -40
- package/include/numkong/cast/diamond.h +64 -0
- package/include/numkong/cast/haswell.h +42 -194
- package/include/numkong/cast/icelake.h +42 -37
- package/include/numkong/cast/loongsonasx.h +252 -0
- package/include/numkong/cast/neon.h +216 -249
- package/include/numkong/cast/powervsx.h +449 -0
- package/include/numkong/cast/rvv.h +223 -274
- package/include/numkong/cast/sapphire.h +18 -18
- package/include/numkong/cast/serial.h +1018 -944
- package/include/numkong/cast/skylake.h +82 -23
- package/include/numkong/cast/v128relaxed.h +462 -105
- package/include/numkong/cast.h +24 -0
- package/include/numkong/cast.hpp +44 -0
- package/include/numkong/curved/README.md +17 -17
- package/include/numkong/curved/neon.h +131 -7
- package/include/numkong/curved/neonbfdot.h +6 -7
- package/include/numkong/curved/rvv.h +26 -26
- package/include/numkong/curved/smef64.h +186 -182
- package/include/numkong/curved.h +14 -18
- package/include/numkong/dot/README.md +154 -137
- package/include/numkong/dot/alder.h +43 -43
- package/include/numkong/dot/diamond.h +158 -0
- package/include/numkong/dot/genoa.h +4 -30
- package/include/numkong/dot/haswell.h +215 -180
- package/include/numkong/dot/icelake.h +190 -76
- package/include/numkong/dot/loongsonasx.h +671 -0
- package/include/numkong/dot/neon.h +124 -73
- package/include/numkong/dot/neonbfdot.h +11 -12
- package/include/numkong/dot/neonfhm.h +44 -46
- package/include/numkong/dot/neonfp8.h +323 -0
- package/include/numkong/dot/neonsdot.h +190 -76
- package/include/numkong/dot/powervsx.h +752 -0
- package/include/numkong/dot/rvv.h +92 -84
- package/include/numkong/dot/rvvbf16.h +12 -12
- package/include/numkong/dot/rvvhalf.h +12 -12
- package/include/numkong/dot/sapphire.h +4 -4
- package/include/numkong/dot/serial.h +66 -30
- package/include/numkong/dot/sierra.h +31 -31
- package/include/numkong/dot/skylake.h +142 -110
- package/include/numkong/dot/sve.h +217 -177
- package/include/numkong/dot/svebfdot.h +10 -10
- package/include/numkong/dot/svehalf.h +85 -41
- package/include/numkong/dot/svesdot.h +89 -0
- package/include/numkong/dot/v128relaxed.h +124 -89
- package/include/numkong/dot.h +114 -48
- package/include/numkong/dots/README.md +203 -203
- package/include/numkong/dots/alder.h +12 -9
- package/include/numkong/dots/diamond.h +86 -0
- package/include/numkong/dots/genoa.h +10 -4
- package/include/numkong/dots/haswell.h +63 -48
- package/include/numkong/dots/icelake.h +27 -18
- package/include/numkong/dots/loongsonasx.h +176 -0
- package/include/numkong/dots/neon.h +14 -11
- package/include/numkong/dots/neonbfdot.h +4 -3
- package/include/numkong/dots/neonfhm.h +11 -9
- package/include/numkong/dots/neonfp8.h +99 -0
- package/include/numkong/dots/neonsdot.h +48 -12
- package/include/numkong/dots/powervsx.h +194 -0
- package/include/numkong/dots/rvv.h +451 -344
- package/include/numkong/dots/sapphireamx.h +1028 -984
- package/include/numkong/dots/serial.h +213 -197
- package/include/numkong/dots/sierra.h +10 -7
- package/include/numkong/dots/skylake.h +47 -36
- package/include/numkong/dots/sme.h +2001 -2364
- package/include/numkong/dots/smebi32.h +175 -162
- package/include/numkong/dots/smef64.h +328 -323
- package/include/numkong/dots/v128relaxed.h +64 -41
- package/include/numkong/dots.h +573 -293
- package/include/numkong/dots.hpp +45 -43
- package/include/numkong/each/README.md +133 -137
- package/include/numkong/each/haswell.h +6 -6
- package/include/numkong/each/icelake.h +7 -7
- package/include/numkong/each/neon.h +76 -42
- package/include/numkong/each/neonbfdot.h +11 -12
- package/include/numkong/each/neonhalf.h +24 -116
- package/include/numkong/each/rvv.h +28 -28
- package/include/numkong/each/sapphire.h +27 -161
- package/include/numkong/each/serial.h +6 -6
- package/include/numkong/each/skylake.h +7 -7
- package/include/numkong/each/v128relaxed.h +562 -0
- package/include/numkong/each.h +148 -62
- package/include/numkong/each.hpp +2 -2
- package/include/numkong/geospatial/README.md +18 -18
- package/include/numkong/geospatial/haswell.h +365 -325
- package/include/numkong/geospatial/neon.h +350 -306
- package/include/numkong/geospatial/rvv.h +4 -4
- package/include/numkong/geospatial/skylake.h +376 -340
- package/include/numkong/geospatial/v128relaxed.h +366 -327
- package/include/numkong/geospatial.h +17 -17
- package/include/numkong/matrix.hpp +4 -4
- package/include/numkong/maxsim/README.md +14 -14
- package/include/numkong/maxsim/alder.h +6 -6
- package/include/numkong/maxsim/genoa.h +4 -4
- package/include/numkong/maxsim/haswell.h +6 -6
- package/include/numkong/maxsim/icelake.h +18 -18
- package/include/numkong/maxsim/neonsdot.h +21 -21
- package/include/numkong/maxsim/sapphireamx.h +14 -14
- package/include/numkong/maxsim/serial.h +6 -6
- package/include/numkong/maxsim/sme.h +221 -196
- package/include/numkong/maxsim/v128relaxed.h +6 -6
- package/include/numkong/mesh/README.md +62 -56
- package/include/numkong/mesh/haswell.h +339 -464
- package/include/numkong/mesh/neon.h +1100 -519
- package/include/numkong/mesh/neonbfdot.h +36 -68
- package/include/numkong/mesh/rvv.h +530 -435
- package/include/numkong/mesh/serial.h +75 -91
- package/include/numkong/mesh/skylake.h +1627 -302
- package/include/numkong/mesh/v128relaxed.h +443 -330
- package/include/numkong/mesh.h +63 -49
- package/include/numkong/mesh.hpp +4 -4
- package/include/numkong/numkong.h +3 -3
- package/include/numkong/numkong.hpp +1 -0
- package/include/numkong/probability/README.md +23 -19
- package/include/numkong/probability/neon.h +82 -52
- package/include/numkong/probability/rvv.h +28 -23
- package/include/numkong/probability/serial.h +51 -39
- package/include/numkong/probability.h +20 -23
- package/include/numkong/random.h +1 -1
- package/include/numkong/reduce/README.md +143 -138
- package/include/numkong/reduce/alder.h +81 -77
- package/include/numkong/reduce/haswell.h +222 -220
- package/include/numkong/reduce/neon.h +629 -519
- package/include/numkong/reduce/neonbfdot.h +7 -218
- package/include/numkong/reduce/neonfhm.h +9 -381
- package/include/numkong/reduce/neonsdot.h +9 -9
- package/include/numkong/reduce/rvv.h +928 -802
- package/include/numkong/reduce/serial.h +23 -27
- package/include/numkong/reduce/sierra.h +20 -20
- package/include/numkong/reduce/skylake.h +326 -324
- package/include/numkong/reduce/v128relaxed.h +52 -52
- package/include/numkong/reduce.h +4 -23
- package/include/numkong/reduce.hpp +156 -11
- package/include/numkong/scalar/README.md +6 -6
- package/include/numkong/scalar/haswell.h +26 -17
- package/include/numkong/scalar/loongsonasx.h +74 -0
- package/include/numkong/scalar/neon.h +9 -9
- package/include/numkong/scalar/powervsx.h +96 -0
- package/include/numkong/scalar/rvv.h +2 -2
- package/include/numkong/scalar/sapphire.h +21 -10
- package/include/numkong/scalar/serial.h +21 -21
- package/include/numkong/scalar.h +13 -0
- package/include/numkong/set/README.md +28 -28
- package/include/numkong/set/haswell.h +12 -12
- package/include/numkong/set/icelake.h +14 -14
- package/include/numkong/set/loongsonasx.h +181 -0
- package/include/numkong/set/neon.h +17 -18
- package/include/numkong/set/powervsx.h +326 -0
- package/include/numkong/set/rvv.h +4 -4
- package/include/numkong/set/serial.h +6 -6
- package/include/numkong/set/sve.h +60 -59
- package/include/numkong/set/v128relaxed.h +6 -6
- package/include/numkong/set.h +21 -7
- package/include/numkong/sets/README.md +26 -26
- package/include/numkong/sets/loongsonasx.h +52 -0
- package/include/numkong/sets/powervsx.h +65 -0
- package/include/numkong/sets/smebi32.h +395 -364
- package/include/numkong/sets.h +83 -40
- package/include/numkong/sparse/README.md +4 -4
- package/include/numkong/sparse/icelake.h +101 -101
- package/include/numkong/sparse/serial.h +1 -1
- package/include/numkong/sparse/sve2.h +137 -141
- package/include/numkong/sparse/turin.h +12 -12
- package/include/numkong/sparse.h +10 -10
- package/include/numkong/spatial/README.md +230 -226
- package/include/numkong/spatial/alder.h +113 -116
- package/include/numkong/spatial/diamond.h +240 -0
- package/include/numkong/spatial/genoa.h +0 -68
- package/include/numkong/spatial/haswell.h +74 -55
- package/include/numkong/spatial/icelake.h +539 -58
- package/include/numkong/spatial/loongsonasx.h +483 -0
- package/include/numkong/spatial/neon.h +125 -52
- package/include/numkong/spatial/neonbfdot.h +8 -9
- package/include/numkong/spatial/neonfp8.h +258 -0
- package/include/numkong/spatial/neonsdot.h +180 -12
- package/include/numkong/spatial/powervsx.h +738 -0
- package/include/numkong/spatial/rvv.h +146 -139
- package/include/numkong/spatial/rvvbf16.h +17 -12
- package/include/numkong/spatial/rvvhalf.h +13 -10
- package/include/numkong/spatial/serial.h +13 -12
- package/include/numkong/spatial/sierra.h +232 -39
- package/include/numkong/spatial/skylake.h +73 -74
- package/include/numkong/spatial/sve.h +93 -72
- package/include/numkong/spatial/svebfdot.h +29 -29
- package/include/numkong/spatial/svehalf.h +52 -26
- package/include/numkong/spatial/svesdot.h +142 -0
- package/include/numkong/spatial/v128relaxed.h +293 -41
- package/include/numkong/spatial.h +338 -82
- package/include/numkong/spatials/README.md +194 -194
- package/include/numkong/spatials/diamond.h +82 -0
- package/include/numkong/spatials/haswell.h +2 -2
- package/include/numkong/spatials/loongsonasx.h +153 -0
- package/include/numkong/spatials/neonfp8.h +111 -0
- package/include/numkong/spatials/neonsdot.h +34 -0
- package/include/numkong/spatials/powervsx.h +153 -0
- package/include/numkong/spatials/rvv.h +259 -243
- package/include/numkong/spatials/sapphireamx.h +173 -173
- package/include/numkong/spatials/serial.h +2 -2
- package/include/numkong/spatials/skylake.h +2 -2
- package/include/numkong/spatials/sme.h +590 -605
- package/include/numkong/spatials/smef64.h +139 -130
- package/include/numkong/spatials/v128relaxed.h +2 -2
- package/include/numkong/spatials.h +820 -500
- package/include/numkong/spatials.hpp +49 -48
- package/include/numkong/tensor.hpp +406 -17
- package/include/numkong/trigonometry/README.md +19 -19
- package/include/numkong/trigonometry/haswell.h +402 -401
- package/include/numkong/trigonometry/neon.h +386 -387
- package/include/numkong/trigonometry/rvv.h +52 -51
- package/include/numkong/trigonometry/serial.h +13 -13
- package/include/numkong/trigonometry/skylake.h +373 -369
- package/include/numkong/trigonometry/v128relaxed.h +375 -374
- package/include/numkong/trigonometry.h +13 -13
- package/include/numkong/trigonometry.hpp +2 -2
- package/include/numkong/types.h +287 -49
- package/include/numkong/types.hpp +436 -12
- package/include/numkong/vector.hpp +82 -14
- package/javascript/dist/cjs/numkong-wasm.js +6 -12
- package/javascript/dist/cjs/numkong.d.ts +7 -1
- package/javascript/dist/cjs/numkong.js +37 -11
- package/javascript/dist/cjs/types.d.ts +9 -0
- package/javascript/dist/cjs/types.js +96 -0
- package/javascript/dist/esm/numkong-browser.d.ts +14 -0
- package/javascript/dist/esm/numkong-browser.js +23 -0
- package/javascript/dist/esm/numkong-wasm.js +6 -12
- package/javascript/dist/esm/numkong.d.ts +7 -1
- package/javascript/dist/esm/numkong.js +37 -11
- package/javascript/dist/esm/types.d.ts +9 -0
- package/javascript/dist/esm/types.js +96 -0
- package/javascript/node-gyp-build.d.ts +4 -1
- package/javascript/numkong-browser.ts +40 -0
- package/javascript/numkong-wasm.ts +7 -13
- package/javascript/numkong.c +5 -26
- package/javascript/numkong.ts +36 -11
- package/javascript/tsconfig-base.json +1 -0
- package/javascript/tsconfig-cjs.json +6 -1
- package/javascript/types.ts +110 -0
- package/numkong.gypi +101 -0
- package/package.json +34 -13
- package/probes/arm_neon.c +8 -0
- package/probes/arm_neon_bfdot.c +9 -0
- package/probes/arm_neon_fhm.c +9 -0
- package/probes/arm_neon_half.c +8 -0
- package/probes/arm_neon_sdot.c +9 -0
- package/probes/arm_neonfp8.c +9 -0
- package/probes/arm_sme.c +16 -0
- package/probes/arm_sme2.c +16 -0
- package/probes/arm_sme2p1.c +16 -0
- package/probes/arm_sme_bf16.c +16 -0
- package/probes/arm_sme_bi32.c +16 -0
- package/probes/arm_sme_f64.c +16 -0
- package/probes/arm_sme_fa64.c +14 -0
- package/probes/arm_sme_half.c +16 -0
- package/probes/arm_sme_lut2.c +15 -0
- package/probes/arm_sve.c +18 -0
- package/probes/arm_sve2.c +20 -0
- package/probes/arm_sve2p1.c +18 -0
- package/probes/arm_sve_bfdot.c +20 -0
- package/probes/arm_sve_half.c +18 -0
- package/probes/arm_sve_sdot.c +21 -0
- package/probes/loongarch_lasx.c +12 -0
- package/probes/power_vsx.c +12 -0
- package/probes/probe.js +127 -0
- package/probes/riscv_rvv.c +14 -0
- package/probes/riscv_rvv_bb.c +15 -0
- package/probes/riscv_rvv_bf16.c +17 -0
- package/probes/riscv_rvv_half.c +14 -0
- package/probes/wasm_v128relaxed.c +11 -0
- package/probes/x86_alder.c +17 -0
- package/probes/x86_diamond.c +17 -0
- package/probes/x86_genoa.c +17 -0
- package/probes/x86_graniteamx.c +19 -0
- package/probes/x86_haswell.c +11 -0
- package/probes/x86_icelake.c +17 -0
- package/probes/x86_sapphire.c +16 -0
- package/probes/x86_sapphireamx.c +18 -0
- package/probes/x86_sierra.c +17 -0
- package/probes/x86_skylake.c +15 -0
- package/probes/x86_turin.c +17 -0
- package/wasm/numkong-emscripten.js +2 -0
- package/wasm/numkong.d.ts +14 -0
- package/wasm/numkong.js +1124 -0
- package/wasm/numkong.wasm +0 -0
- package/include/numkong/curved/neonhalf.h +0 -212
- package/include/numkong/dot/neonhalf.h +0 -198
- package/include/numkong/dots/neonhalf.h +0 -57
- package/include/numkong/mesh/neonhalf.h +0 -616
- package/include/numkong/reduce/neonhalf.h +0 -157
- package/include/numkong/spatial/neonhalf.h +0 -118
- package/include/numkong/spatial/sapphire.h +0 -343
- package/include/numkong/spatials/neonhalf.h +0 -58
- package/javascript/README.md +0 -246
|
@@ -8,15 +8,13 @@
|
|
|
8
8
|
*
|
|
9
9
|
* @section spatial_icelake_instructions Key AVX-512 VNNI Spatial Instructions
|
|
10
10
|
*
|
|
11
|
-
* Intrinsic
|
|
12
|
-
* _mm512_dpwssd_epi32
|
|
13
|
-
* _mm512_cvtepi8_epi16
|
|
14
|
-
* _mm512_sub_epi16
|
|
15
|
-
* _mm512_reduce_add_epi32 (pseudo: shuffle chain) ~8cy ~8cy
|
|
11
|
+
* Intrinsic Instruction Icelake Genoa
|
|
12
|
+
* _mm512_dpwssd_epi32 VPDPWSSD (ZMM, ZMM, ZMM) 5cy @ p0 4cy @ p01
|
|
13
|
+
* _mm512_cvtepi8_epi16 VPMOVSXBW (ZMM, YMM) 3cy @ p5 3cy @ p12
|
|
14
|
+
* _mm512_sub_epi16 VPSUBW (ZMM, ZMM, ZMM) 1cy @ p05 1cy @ p0123
|
|
16
15
|
*
|
|
17
16
|
* Ice Lake's VNNI enables efficient i8 distance computations via VPDPWSSD for squared differences.
|
|
18
17
|
* After widening i8 to i16, the same instruction computes both multiply and horizontal pair addition.
|
|
19
|
-
* This approach avoids the asymmetric VPDPBUSD issues with signed values like -128.
|
|
20
18
|
*/
|
|
21
19
|
#ifndef NK_SPATIAL_ICELAKE_H
|
|
22
20
|
#define NK_SPATIAL_ICELAKE_H
|
|
@@ -25,18 +23,21 @@
|
|
|
25
23
|
#if NK_TARGET_ICELAKE
|
|
26
24
|
|
|
27
25
|
#include "numkong/types.h"
|
|
26
|
+
#include "numkong/spatial/haswell.h" // `nk_angular_normalize_f32_haswell_`, `nk_f32_sqrt_haswell`
|
|
27
|
+
#include "numkong/reduce/skylake.h" // `nk_reduce_add_f32x16_skylake_`
|
|
28
28
|
|
|
29
29
|
#if defined(__cplusplus)
|
|
30
30
|
extern "C" {
|
|
31
31
|
#endif
|
|
32
32
|
|
|
33
33
|
#if defined(__clang__)
|
|
34
|
-
#pragma clang attribute push(
|
|
35
|
-
__attribute__((target("avx2,avx512f,avx512vl,avx512bw,avx512dq,avx512vnni,f16c,fma,bmi,bmi2"))), \
|
|
34
|
+
#pragma clang attribute push( \
|
|
35
|
+
__attribute__((target("avx2,avx512f,avx512vl,avx512bw,avx512dq,avx512vnni,avx512vbmi,f16c,fma,bmi,bmi2"))), \
|
|
36
36
|
apply_to = function)
|
|
37
37
|
#elif defined(__GNUC__)
|
|
38
38
|
#pragma GCC push_options
|
|
39
|
-
#pragma GCC target("avx2", "avx512f", "avx512vl", "avx512bw", "avx512dq", "avx512vnni", "
|
|
39
|
+
#pragma GCC target("avx2", "avx512f", "avx512vl", "avx512bw", "avx512dq", "avx512vnni", "avx512vbmi", "f16c", "fma", \
|
|
40
|
+
"bmi", "bmi2")
|
|
40
41
|
#endif
|
|
41
42
|
|
|
42
43
|
NK_PUBLIC void nk_sqeuclidean_i8_icelake(nk_i8_t const *a, nk_i8_t const *b, nk_size_t n, nk_u32_t *result) {
|
|
@@ -142,7 +143,7 @@ nk_angular_i8_icelake_cycle:
|
|
|
142
143
|
//
|
|
143
144
|
// VNNI instruction performance (Ice Lake vs Zen4 Genoa):
|
|
144
145
|
//
|
|
145
|
-
// Instruction
|
|
146
|
+
// Instruction Icelake Genoa
|
|
146
147
|
// VPDPBUSDS (ZMM, ZMM, ZMM) 5cy @ p0 4cy @ p01
|
|
147
148
|
// VPDPWSSDS (ZMM, ZMM, ZMM) 5cy @ p0 4cy @ p01
|
|
148
149
|
// VPMADDWD (ZMM, ZMM, ZMM) 5cy @ p05 3cy @ p01
|
|
@@ -173,7 +174,8 @@ nk_angular_i8_icelake_cycle:
|
|
|
173
174
|
nk_i32_t dot_product_i32 = _mm512_reduce_add_epi32(dot_product_i32x16);
|
|
174
175
|
nk_i32_t a_norm_sq_i32 = _mm512_reduce_add_epi32(a_norm_sq_i32x16);
|
|
175
176
|
nk_i32_t b_norm_sq_i32 = _mm512_reduce_add_epi32(b_norm_sq_i32x16);
|
|
176
|
-
*result = nk_angular_normalize_f32_haswell_(dot_product_i32, a_norm_sq_i32,
|
|
177
|
+
*result = nk_angular_normalize_f32_haswell_((nk_f32_t)dot_product_i32, (nk_f32_t)a_norm_sq_i32,
|
|
178
|
+
(nk_f32_t)b_norm_sq_i32);
|
|
177
179
|
}
|
|
178
180
|
NK_PUBLIC void nk_sqeuclidean_u8_icelake(nk_u8_t const *a, nk_u8_t const *b, nk_size_t n, nk_u32_t *result) {
|
|
179
181
|
__m512i distance_sq_low_i32x16 = _mm512_setzero_si512();
|
|
@@ -258,7 +260,8 @@ nk_angular_u8_icelake_cycle:
|
|
|
258
260
|
_mm512_add_epi32(dot_product_low_i32x16, dot_product_high_i32x16));
|
|
259
261
|
nk_i32_t a_norm_sq_i32 = _mm512_reduce_add_epi32(_mm512_add_epi32(a_norm_sq_low_i32x16, a_norm_sq_high_i32x16));
|
|
260
262
|
nk_i32_t b_norm_sq_i32 = _mm512_reduce_add_epi32(_mm512_add_epi32(b_norm_sq_low_i32x16, b_norm_sq_high_i32x16));
|
|
261
|
-
*result = nk_angular_normalize_f32_haswell_(dot_product_i32, a_norm_sq_i32,
|
|
263
|
+
*result = nk_angular_normalize_f32_haswell_((nk_f32_t)dot_product_i32, (nk_f32_t)a_norm_sq_i32,
|
|
264
|
+
(nk_f32_t)b_norm_sq_i32);
|
|
262
265
|
}
|
|
263
266
|
|
|
264
267
|
NK_PUBLIC void nk_sqeuclidean_i4_icelake(nk_i4x2_t const *a, nk_i4x2_t const *b, nk_size_t n, nk_u32_t *result) {
|
|
@@ -285,7 +288,7 @@ NK_PUBLIC void nk_sqeuclidean_i4_icelake(nk_i4x2_t const *a, nk_i4x2_t const *b,
|
|
|
285
288
|
__m512i const nibble_mask_u8x64 = _mm512_set1_epi8(0x0F);
|
|
286
289
|
__m512i const eight_i8x64 = _mm512_set1_epi8(8);
|
|
287
290
|
|
|
288
|
-
__m512i
|
|
291
|
+
__m512i a_i4_u8x64, b_i4_u8x64;
|
|
289
292
|
__m512i a_low_u8x64, a_high_u8x64, b_low_u8x64, b_high_u8x64;
|
|
290
293
|
__m512i a_low_i8x64, a_high_i8x64, b_low_i8x64, b_high_i8x64;
|
|
291
294
|
__m512i diff_low_u8x64, diff_high_u8x64;
|
|
@@ -294,22 +297,22 @@ NK_PUBLIC void nk_sqeuclidean_i4_icelake(nk_i4x2_t const *a, nk_i4x2_t const *b,
|
|
|
294
297
|
nk_sqeuclidean_i4_icelake_cycle:
|
|
295
298
|
if (n_bytes < 64) {
|
|
296
299
|
__mmask64 mask = (__mmask64)_bzhi_u64(0xFFFFFFFFFFFFFFFF, n_bytes);
|
|
297
|
-
|
|
298
|
-
|
|
300
|
+
a_i4_u8x64 = _mm512_maskz_loadu_epi8(mask, a);
|
|
301
|
+
b_i4_u8x64 = _mm512_maskz_loadu_epi8(mask, b);
|
|
299
302
|
n_bytes = 0;
|
|
300
303
|
}
|
|
301
304
|
else {
|
|
302
|
-
|
|
303
|
-
|
|
305
|
+
a_i4_u8x64 = _mm512_loadu_epi8(a);
|
|
306
|
+
b_i4_u8x64 = _mm512_loadu_epi8(b);
|
|
304
307
|
a += 64, b += 64, n_bytes -= 64;
|
|
305
308
|
}
|
|
306
309
|
|
|
307
310
|
// Extract nibbles as unsigned [0,15]. VPSHUFB ignores high 4 bits of index,
|
|
308
311
|
// so no AND needed for low nibbles when used with lookup, but we need it here.
|
|
309
|
-
a_low_u8x64 = _mm512_and_si512(
|
|
310
|
-
a_high_u8x64 = _mm512_and_si512(_mm512_srli_epi16(
|
|
311
|
-
b_low_u8x64 = _mm512_and_si512(
|
|
312
|
-
b_high_u8x64 = _mm512_and_si512(_mm512_srli_epi16(
|
|
312
|
+
a_low_u8x64 = _mm512_and_si512(a_i4_u8x64, nibble_mask_u8x64);
|
|
313
|
+
a_high_u8x64 = _mm512_and_si512(_mm512_srli_epi16(a_i4_u8x64, 4), nibble_mask_u8x64);
|
|
314
|
+
b_low_u8x64 = _mm512_and_si512(b_i4_u8x64, nibble_mask_u8x64);
|
|
315
|
+
b_high_u8x64 = _mm512_and_si512(_mm512_srli_epi16(b_i4_u8x64, 4), nibble_mask_u8x64);
|
|
313
316
|
|
|
314
317
|
// Sign extend using XOR trick: signed = (nibble ^ 8) - 8
|
|
315
318
|
a_low_i8x64 = _mm512_sub_epi8(_mm512_xor_si512(a_low_u8x64, eight_i8x64), eight_i8x64);
|
|
@@ -363,7 +366,7 @@ NK_PUBLIC void nk_angular_i4_icelake(nk_i4x2_t const *a, nk_i4x2_t const *b, nk_
|
|
|
363
366
|
__m512i const eight_i8x64 = _mm512_set1_epi8(8);
|
|
364
367
|
__m512i const zeros_i8x64 = _mm512_setzero_si512();
|
|
365
368
|
|
|
366
|
-
__m512i
|
|
369
|
+
__m512i a_i4_u8x64, b_i4_u8x64;
|
|
367
370
|
__m512i a_low_u8x64, a_high_u8x64, b_low_u8x64, b_high_u8x64;
|
|
368
371
|
__m512i ax_low_u8x64, ax_high_u8x64, bx_low_u8x64, bx_high_u8x64;
|
|
369
372
|
__m512i a_low_i8x64, a_high_i8x64, b_low_i8x64, b_high_i8x64;
|
|
@@ -379,21 +382,21 @@ NK_PUBLIC void nk_angular_i4_icelake(nk_i4x2_t const *a, nk_i4x2_t const *b, nk_
|
|
|
379
382
|
nk_angular_i4_icelake_cycle:
|
|
380
383
|
if (n_bytes < 64) {
|
|
381
384
|
__mmask64 mask = (__mmask64)_bzhi_u64(0xFFFFFFFFFFFFFFFF, n_bytes);
|
|
382
|
-
|
|
383
|
-
|
|
385
|
+
a_i4_u8x64 = _mm512_mask_loadu_epi8(_mm512_set1_epi8((char)0x88), mask, a);
|
|
386
|
+
b_i4_u8x64 = _mm512_mask_loadu_epi8(_mm512_set1_epi8((char)0x88), mask, b);
|
|
384
387
|
n_bytes = 0;
|
|
385
388
|
}
|
|
386
389
|
else {
|
|
387
|
-
|
|
388
|
-
|
|
390
|
+
a_i4_u8x64 = _mm512_loadu_epi8(a);
|
|
391
|
+
b_i4_u8x64 = _mm512_loadu_epi8(b);
|
|
389
392
|
a += 64, b += 64, n_bytes -= 64;
|
|
390
393
|
}
|
|
391
394
|
|
|
392
395
|
// Extract nibbles as unsigned [0,15]
|
|
393
|
-
a_low_u8x64 = _mm512_and_si512(
|
|
394
|
-
a_high_u8x64 = _mm512_and_si512(_mm512_srli_epi16(
|
|
395
|
-
b_low_u8x64 = _mm512_and_si512(
|
|
396
|
-
b_high_u8x64 = _mm512_and_si512(_mm512_srli_epi16(
|
|
396
|
+
a_low_u8x64 = _mm512_and_si512(a_i4_u8x64, nibble_mask_u8x64);
|
|
397
|
+
a_high_u8x64 = _mm512_and_si512(_mm512_srli_epi16(a_i4_u8x64, 4), nibble_mask_u8x64);
|
|
398
|
+
b_low_u8x64 = _mm512_and_si512(b_i4_u8x64, nibble_mask_u8x64);
|
|
399
|
+
b_high_u8x64 = _mm512_and_si512(_mm512_srli_epi16(b_i4_u8x64, 4), nibble_mask_u8x64);
|
|
397
400
|
|
|
398
401
|
// Compute biased values: ax = a ^ 8 (still ∈ [0,15], just reordered)
|
|
399
402
|
ax_low_u8x64 = _mm512_xor_si512(a_low_u8x64, eight_i8x64);
|
|
@@ -440,7 +443,7 @@ nk_angular_i4_icelake_cycle:
|
|
|
440
443
|
nk_i32_t norm_excess = 128 * (nk_i32_t)(nk_size_round_up_to_multiple_(n_bytes_total, 64) - n_bytes_total);
|
|
441
444
|
nk_i32_t a2 = _mm512_reduce_add_epi32(a2_i32x16) - norm_excess;
|
|
442
445
|
nk_i32_t b2 = _mm512_reduce_add_epi32(b2_i32x16) - norm_excess;
|
|
443
|
-
*result = nk_angular_normalize_f32_haswell_(ab, (nk_f32_t)a2, (nk_f32_t)b2);
|
|
446
|
+
*result = nk_angular_normalize_f32_haswell_((nk_f32_t)ab, (nk_f32_t)a2, (nk_f32_t)b2);
|
|
444
447
|
}
|
|
445
448
|
|
|
446
449
|
NK_PUBLIC void nk_sqeuclidean_u4_icelake(nk_u4x2_t const *a, nk_u4x2_t const *b, nk_size_t n, nk_u32_t *result) {
|
|
@@ -457,7 +460,7 @@ NK_PUBLIC void nk_sqeuclidean_u4_icelake(nk_u4x2_t const *a, nk_u4x2_t const *b,
|
|
|
457
460
|
// No sign extension needed since values are unsigned.
|
|
458
461
|
__m512i const nibble_mask_u8x64 = _mm512_set1_epi8(0x0F);
|
|
459
462
|
|
|
460
|
-
__m512i
|
|
463
|
+
__m512i a_u4_u8x64, b_u4_u8x64;
|
|
461
464
|
__m512i a_low_u8x64, a_high_u8x64, b_low_u8x64, b_high_u8x64;
|
|
462
465
|
__m512i diff_low_u8x64, diff_high_u8x64;
|
|
463
466
|
__m512i d2_i32x16 = _mm512_setzero_si512();
|
|
@@ -465,21 +468,21 @@ NK_PUBLIC void nk_sqeuclidean_u4_icelake(nk_u4x2_t const *a, nk_u4x2_t const *b,
|
|
|
465
468
|
nk_sqeuclidean_u4_icelake_cycle:
|
|
466
469
|
if (n_bytes < 64) {
|
|
467
470
|
__mmask64 mask = (__mmask64)_bzhi_u64(0xFFFFFFFFFFFFFFFF, n_bytes);
|
|
468
|
-
|
|
469
|
-
|
|
471
|
+
a_u4_u8x64 = _mm512_maskz_loadu_epi8(mask, a);
|
|
472
|
+
b_u4_u8x64 = _mm512_maskz_loadu_epi8(mask, b);
|
|
470
473
|
n_bytes = 0;
|
|
471
474
|
}
|
|
472
475
|
else {
|
|
473
|
-
|
|
474
|
-
|
|
476
|
+
a_u4_u8x64 = _mm512_loadu_epi8(a);
|
|
477
|
+
b_u4_u8x64 = _mm512_loadu_epi8(b);
|
|
475
478
|
a += 64, b += 64, n_bytes -= 64;
|
|
476
479
|
}
|
|
477
480
|
|
|
478
481
|
// Extract nibbles as unsigned [0,15]
|
|
479
|
-
a_low_u8x64 = _mm512_and_si512(
|
|
480
|
-
a_high_u8x64 = _mm512_and_si512(_mm512_srli_epi16(
|
|
481
|
-
b_low_u8x64 = _mm512_and_si512(
|
|
482
|
-
b_high_u8x64 = _mm512_and_si512(_mm512_srli_epi16(
|
|
482
|
+
a_low_u8x64 = _mm512_and_si512(a_u4_u8x64, nibble_mask_u8x64);
|
|
483
|
+
a_high_u8x64 = _mm512_and_si512(_mm512_srli_epi16(a_u4_u8x64, 4), nibble_mask_u8x64);
|
|
484
|
+
b_low_u8x64 = _mm512_and_si512(b_u4_u8x64, nibble_mask_u8x64);
|
|
485
|
+
b_high_u8x64 = _mm512_and_si512(_mm512_srli_epi16(b_u4_u8x64, 4), nibble_mask_u8x64);
|
|
483
486
|
|
|
484
487
|
// Absolute difference for unsigned: |a-b| = (a ⊖ b) | (b ⊖ a) where ⊖ is saturating sub
|
|
485
488
|
diff_low_u8x64 = _mm512_or_si512(_mm512_subs_epu8(a_low_u8x64, b_low_u8x64),
|
|
@@ -515,7 +518,7 @@ NK_PUBLIC void nk_angular_u4_icelake(nk_u4x2_t const *a, nk_u4x2_t const *b, nk_
|
|
|
515
518
|
__m512i const nibble_mask_u8x64 = _mm512_set1_epi8(0x0F);
|
|
516
519
|
__m512i const zeros_i8x64 = _mm512_setzero_si512();
|
|
517
520
|
|
|
518
|
-
__m512i
|
|
521
|
+
__m512i a_u4_u8x64, b_u4_u8x64;
|
|
519
522
|
__m512i a_low_u8x64, a_high_u8x64, b_low_u8x64, b_high_u8x64;
|
|
520
523
|
|
|
521
524
|
__m512i ab_i32x16 = zeros_i8x64;
|
|
@@ -525,21 +528,21 @@ NK_PUBLIC void nk_angular_u4_icelake(nk_u4x2_t const *a, nk_u4x2_t const *b, nk_
|
|
|
525
528
|
nk_angular_u4_icelake_cycle:
|
|
526
529
|
if (n_bytes < 64) {
|
|
527
530
|
__mmask64 mask = (__mmask64)_bzhi_u64(0xFFFFFFFFFFFFFFFF, n_bytes);
|
|
528
|
-
|
|
529
|
-
|
|
531
|
+
a_u4_u8x64 = _mm512_maskz_loadu_epi8(mask, a);
|
|
532
|
+
b_u4_u8x64 = _mm512_maskz_loadu_epi8(mask, b);
|
|
530
533
|
n_bytes = 0;
|
|
531
534
|
}
|
|
532
535
|
else {
|
|
533
|
-
|
|
534
|
-
|
|
536
|
+
a_u4_u8x64 = _mm512_loadu_epi8(a);
|
|
537
|
+
b_u4_u8x64 = _mm512_loadu_epi8(b);
|
|
535
538
|
a += 64, b += 64, n_bytes -= 64;
|
|
536
539
|
}
|
|
537
540
|
|
|
538
541
|
// Extract nibbles as unsigned [0,15]
|
|
539
|
-
a_low_u8x64 = _mm512_and_si512(
|
|
540
|
-
a_high_u8x64 = _mm512_and_si512(_mm512_srli_epi16(
|
|
541
|
-
b_low_u8x64 = _mm512_and_si512(
|
|
542
|
-
b_high_u8x64 = _mm512_and_si512(_mm512_srli_epi16(
|
|
542
|
+
a_low_u8x64 = _mm512_and_si512(a_u4_u8x64, nibble_mask_u8x64);
|
|
543
|
+
a_high_u8x64 = _mm512_and_si512(_mm512_srli_epi16(a_u4_u8x64, 4), nibble_mask_u8x64);
|
|
544
|
+
b_low_u8x64 = _mm512_and_si512(b_u4_u8x64, nibble_mask_u8x64);
|
|
545
|
+
b_high_u8x64 = _mm512_and_si512(_mm512_srli_epi16(b_u4_u8x64, 4), nibble_mask_u8x64);
|
|
543
546
|
|
|
544
547
|
// Dot product with DPBUSD (safe for unsigned [0,15])
|
|
545
548
|
ab_i32x16 = _mm512_dpbusd_epi32(ab_i32x16, a_low_u8x64, b_low_u8x64);
|
|
@@ -553,22 +556,500 @@ nk_angular_u4_icelake_cycle:
|
|
|
553
556
|
(char)225, (char)196, (char)169, (char)144, 121, 100, 81, 64, 49, 36, 25, 16, 9, 4, 1, 0, //
|
|
554
557
|
(char)225, (char)196, (char)169, (char)144, 121, 100, 81, 64, 49, 36, 25, 16, 9, 4, 1, 0);
|
|
555
558
|
|
|
556
|
-
__m512i
|
|
557
|
-
__m512i
|
|
558
|
-
__m512i
|
|
559
|
-
__m512i
|
|
559
|
+
__m512i a2_low_u8x64 = _mm512_shuffle_epi8(u4_squares_lookup_u8x64, a_low_u8x64);
|
|
560
|
+
__m512i a2_high_u8x64 = _mm512_shuffle_epi8(u4_squares_lookup_u8x64, a_high_u8x64);
|
|
561
|
+
__m512i b2_low_u8x64 = _mm512_shuffle_epi8(u4_squares_lookup_u8x64, b_low_u8x64);
|
|
562
|
+
__m512i b2_high_u8x64 = _mm512_shuffle_epi8(u4_squares_lookup_u8x64, b_high_u8x64);
|
|
560
563
|
|
|
561
564
|
// Accumulate low and high squares separately using SAD to avoid u8 overflow
|
|
562
|
-
a2_i64x8 = _mm512_add_epi64(a2_i64x8, _mm512_sad_epu8(
|
|
563
|
-
a2_i64x8 = _mm512_add_epi64(a2_i64x8, _mm512_sad_epu8(
|
|
564
|
-
b2_i64x8 = _mm512_add_epi64(b2_i64x8, _mm512_sad_epu8(
|
|
565
|
-
b2_i64x8 = _mm512_add_epi64(b2_i64x8, _mm512_sad_epu8(
|
|
565
|
+
a2_i64x8 = _mm512_add_epi64(a2_i64x8, _mm512_sad_epu8(a2_low_u8x64, zeros_i8x64));
|
|
566
|
+
a2_i64x8 = _mm512_add_epi64(a2_i64x8, _mm512_sad_epu8(a2_high_u8x64, zeros_i8x64));
|
|
567
|
+
b2_i64x8 = _mm512_add_epi64(b2_i64x8, _mm512_sad_epu8(b2_low_u8x64, zeros_i8x64));
|
|
568
|
+
b2_i64x8 = _mm512_add_epi64(b2_i64x8, _mm512_sad_epu8(b2_high_u8x64, zeros_i8x64));
|
|
566
569
|
if (n_bytes) goto nk_angular_u4_icelake_cycle;
|
|
567
570
|
|
|
568
571
|
nk_i32_t ab = _mm512_reduce_add_epi32(ab_i32x16);
|
|
569
572
|
nk_i64_t a2 = _mm512_reduce_add_epi64(a2_i64x8);
|
|
570
573
|
nk_i64_t b2 = _mm512_reduce_add_epi64(b2_i64x8);
|
|
571
|
-
*result = nk_angular_normalize_f32_haswell_(ab, (nk_f32_t)a2, (nk_f32_t)b2);
|
|
574
|
+
*result = nk_angular_normalize_f32_haswell_((nk_f32_t)ab, (nk_f32_t)a2, (nk_f32_t)b2);
|
|
575
|
+
}
|
|
576
|
+
|
|
577
|
+
/**
 *  @brief  Squared Euclidean distance between two E4M3 vectors via "octave"-split VNNI dot products.
 *
 *  Each magnitude (bits 6:0) is mapped through byte LUTs to an integer `q` in [0, 120] such that
 *  `|value| == q * 2^(4*octave - 10)`, where `octave = magnitude >> 5` (the top two exponent bits).
 *  A product of an octave-`i` and an octave-`j` value therefore carries the scale `2^(4*(i + j) - 20)`. The 16 cross
 *  products are accumulated with VPDPBUSD into 7 integer accumulators, one per `k = i + j`. The squared norms go
 *  into 4 accumulators, because self products only produce even `k`. Every accumulator is rescaled to `f32` once,
 *  after the loop.
 *
 *  The distance is assembled as `||a||² + ||b||² - 2*<a, b>`. That expansion cancels catastrophically for
 *  near-identical inputs and can round to a small negative number, which would become NaN under a square root.
 *  The final value is therefore clamped at zero.
 *
 *  @param[in]  a       First vector of `n` E4M3 values.
 *  @param[in]  b       Second vector of `n` E4M3 values.
 *  @param[in]  n       Number of elements in each vector.
 *  @param[out] result  Squared Euclidean distance, never negative.
 *
 *  @note  Accumulators are 32-bit integers. A lane gains at most 4 * 120² = 57600 per 64-element step, so extremely
 *         long inputs dominated by near-maximal magnitudes could overflow them.
 *  @note  The NaN encodings (0x7F / 0xFF) are not special-cased; they are looked up like any other magnitude.
 */
NK_PUBLIC void nk_sqeuclidean_e4m3_icelake(nk_e4m3_t const *a, nk_e4m3_t const *b, nk_size_t n, nk_f32_t *result) {
    // Normal magnitudes map to `(8 + mantissa) << (exponent & 3)`. VPERMB only reads the low 6 index bits, so the
    // 32-entry pattern is stored twice and the octave itself is recovered from the thresholds below.
    __m512i const lut_normal_u8x64 = _mm512_set_epi8( //
        120, 112, 104, 96, 88, 80, 72, 64, 60, 56, 52, 48, 44, 40, 36, 32, //
        30, 28, 26, 24, 22, 20, 18, 16, 15, 14, 13, 12, 11, 10, 9, 8,      //
        120, 112, 104, 96, 88, 80, 72, 64, 60, 56, 52, 48, 44, 40, 36, 32, //
        30, 28, 26, 24, 22, 20, 18, 16, 15, 14, 13, 12, 11, 10, 9, 8);     //
    // Subnormal magnitudes (exponent 0) map to `2 * mantissa`, on the same octave-0 scale as the normal table.
    __m512i const lut_subnorm_u8x64 = _mm512_set_epi8( //
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,     //
        0, 0, 0, 0, 0, 0, 0, 0, 14, 12, 10, 8, 6, 4, 2, 0,  //
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,     //
        0, 0, 0, 0, 0, 0, 0, 0, 14, 12, 10, 8, 6, 4, 2, 0); //
    __m512i const magnitude_mask_u8x64 = _mm512_set1_epi8(0x7F);
    __m512i const subnorm_threshold_u8x64 = _mm512_set1_epi8(0x08);
    // Magnitude thresholds between octaves 0 | 1 | 2 | 3 (exponent fields 0-3 | 4-7 | 8-11 | 12-15).
    __m512i const oct_threshold_20_u8x64 = _mm512_set1_epi8(0x20);
    __m512i const oct_threshold_40_u8x64 = _mm512_set1_epi8(0x40);
    __m512i const oct_threshold_60_u8x64 = _mm512_set1_epi8(0x60);

    // `abK` accumulates cross products whose octaves sum to K; `a2_K` / `b2_K` hold self products (even K only).
    __m512i ab0_i32x16 = _mm512_setzero_si512(), ab1_i32x16 = _mm512_setzero_si512();
    __m512i ab2_i32x16 = _mm512_setzero_si512(), ab3_i32x16 = _mm512_setzero_si512();
    __m512i ab4_i32x16 = _mm512_setzero_si512(), ab5_i32x16 = _mm512_setzero_si512();
    __m512i ab6_i32x16 = _mm512_setzero_si512();
    __m512i a2_0_i32x16 = _mm512_setzero_si512(), a2_2_i32x16 = _mm512_setzero_si512();
    __m512i a2_4_i32x16 = _mm512_setzero_si512(), a2_6_i32x16 = _mm512_setzero_si512();
    __m512i b2_0_i32x16 = _mm512_setzero_si512(), b2_2_i32x16 = _mm512_setzero_si512();
    __m512i b2_4_i32x16 = _mm512_setzero_si512(), b2_6_i32x16 = _mm512_setzero_si512();
    __m512i a_e4m3_u8x64, b_e4m3_u8x64;

nk_sqeuclidean_e4m3_icelake_cycle:
    if (n < 64) {
        // Tail: masked-off bytes load as zero, which maps to magnitude 0 and contributes nothing.
        __mmask64 mask = (__mmask64)_bzhi_u64(0xFFFFFFFFFFFFFFFF, n);
        a_e4m3_u8x64 = _mm512_maskz_loadu_epi8(mask, a);
        b_e4m3_u8x64 = _mm512_maskz_loadu_epi8(mask, b);
        n = 0;
    }
    else {
        a_e4m3_u8x64 = _mm512_loadu_si512(a);
        b_e4m3_u8x64 = _mm512_loadu_si512(b);
        a += 64, b += 64, n -= 64;
    }

    __m512i a_magnitude_u8x64 = _mm512_and_si512(a_e4m3_u8x64, magnitude_mask_u8x64);
    __m512i b_magnitude_u8x64 = _mm512_and_si512(b_e4m3_u8x64, magnitude_mask_u8x64);
    __m512i a_base_u8x64 = _mm512_permutexvar_epi8(a_magnitude_u8x64, lut_normal_u8x64);
    __m512i b_base_u8x64 = _mm512_permutexvar_epi8(b_magnitude_u8x64, lut_normal_u8x64);
    a_base_u8x64 = _mm512_mask_permutexvar_epi8(a_base_u8x64,
                                                _mm512_cmplt_epu8_mask(a_magnitude_u8x64, subnorm_threshold_u8x64),
                                                a_magnitude_u8x64, lut_subnorm_u8x64);
    b_base_u8x64 = _mm512_mask_permutexvar_epi8(b_base_u8x64,
                                                _mm512_cmplt_epu8_mask(b_magnitude_u8x64, subnorm_threshold_u8x64),
                                                b_magnitude_u8x64, lut_subnorm_u8x64);

    // Product sign is `(a ^ b) & ~0x7F` (ternary-logic table 0x14); negate `b` wherever it is set.
    __m512i sign_diff_u8x64 = _mm512_ternarylogic_epi64(a_e4m3_u8x64, b_e4m3_u8x64, magnitude_mask_u8x64, 0x14);
    __m512i b_signed_i8x64 = _mm512_mask_sub_epi8(b_base_u8x64, _mm512_test_epi8_mask(sign_diff_u8x64, sign_diff_u8x64),
                                                  _mm512_setzero_si512(), b_base_u8x64);

    __mmask64 ka_lt20 = _mm512_cmplt_epu8_mask(a_magnitude_u8x64, oct_threshold_20_u8x64);
    __mmask64 ka_lt40 = _mm512_cmplt_epu8_mask(a_magnitude_u8x64, oct_threshold_40_u8x64);
    __mmask64 ka_lt60 = _mm512_cmplt_epu8_mask(a_magnitude_u8x64, oct_threshold_60_u8x64);
    __mmask64 kb_lt20 = _mm512_cmplt_epu8_mask(b_magnitude_u8x64, oct_threshold_20_u8x64);
    __mmask64 kb_lt40 = _mm512_cmplt_epu8_mask(b_magnitude_u8x64, oct_threshold_40_u8x64);
    __mmask64 kb_lt60 = _mm512_cmplt_epu8_mask(b_magnitude_u8x64, oct_threshold_60_u8x64);

    // Split each operand into four octave planes; every byte is non-zero in at most one plane.
    __m512i a0_u8x64 = _mm512_maskz_mov_epi8(ka_lt20, a_base_u8x64);
    __m512i a1_u8x64 = _mm512_maskz_mov_epi8(ka_lt40 & ~ka_lt20, a_base_u8x64);
    __m512i a2_u8x64 = _mm512_maskz_mov_epi8(ka_lt60 & ~ka_lt40, a_base_u8x64);
    __m512i a3_u8x64 = _mm512_maskz_mov_epi8(~ka_lt60, a_base_u8x64);

    __m512i b0_i8x64 = _mm512_maskz_mov_epi8(kb_lt20, b_signed_i8x64);
    __m512i b1_i8x64 = _mm512_maskz_mov_epi8(kb_lt40 & ~kb_lt20, b_signed_i8x64);
    __m512i b2_i8x64 = _mm512_maskz_mov_epi8(kb_lt60 & ~kb_lt40, b_signed_i8x64);
    __m512i b3_i8x64 = _mm512_maskz_mov_epi8(~kb_lt60, b_signed_i8x64);

    // dot(a,b): 16 VPDPBUSD, grouped by the octave sum `k = i + j`
    ab0_i32x16 = _mm512_dpbusd_epi32(ab0_i32x16, a0_u8x64, b0_i8x64);
    ab1_i32x16 = _mm512_dpbusd_epi32(_mm512_dpbusd_epi32(ab1_i32x16, a0_u8x64, b1_i8x64), a1_u8x64, b0_i8x64);
    ab2_i32x16 = _mm512_dpbusd_epi32(
        _mm512_dpbusd_epi32(_mm512_dpbusd_epi32(ab2_i32x16, a0_u8x64, b2_i8x64), a1_u8x64, b1_i8x64), a2_u8x64,
        b0_i8x64);
    ab3_i32x16 = _mm512_dpbusd_epi32(
        _mm512_dpbusd_epi32(
            _mm512_dpbusd_epi32(_mm512_dpbusd_epi32(ab3_i32x16, a0_u8x64, b3_i8x64), a1_u8x64, b2_i8x64), a2_u8x64,
            b1_i8x64),
        a3_u8x64, b0_i8x64);
    ab4_i32x16 = _mm512_dpbusd_epi32(
        _mm512_dpbusd_epi32(_mm512_dpbusd_epi32(ab4_i32x16, a1_u8x64, b3_i8x64), a2_u8x64, b2_i8x64), a3_u8x64,
        b1_i8x64);
    ab5_i32x16 = _mm512_dpbusd_epi32(_mm512_dpbusd_epi32(ab5_i32x16, a2_u8x64, b3_i8x64), a3_u8x64, b2_i8x64);
    ab6_i32x16 = _mm512_dpbusd_epi32(ab6_i32x16, a3_u8x64, b3_i8x64);

    // ||a||²: 4 VPDPBUSD (self-dot, same-octave only). Values are <= 120, so they are also valid as signed bytes.
    a2_0_i32x16 = _mm512_dpbusd_epi32(a2_0_i32x16, a0_u8x64, a0_u8x64);
    a2_2_i32x16 = _mm512_dpbusd_epi32(a2_2_i32x16, a1_u8x64, a1_u8x64);
    a2_4_i32x16 = _mm512_dpbusd_epi32(a2_4_i32x16, a2_u8x64, a2_u8x64);
    a2_6_i32x16 = _mm512_dpbusd_epi32(a2_6_i32x16, a3_u8x64, a3_u8x64);

    // ||b||²: 4 VPDPBUSD on the unsigned `b` planes, not the sign-adjusted ones
    __m512i b0_u8x64 = _mm512_maskz_mov_epi8(kb_lt20, b_base_u8x64);
    __m512i b1_u8x64 = _mm512_maskz_mov_epi8(kb_lt40 & ~kb_lt20, b_base_u8x64);
    __m512i b2_u8x64 = _mm512_maskz_mov_epi8(kb_lt60 & ~kb_lt40, b_base_u8x64);
    __m512i b3_u8x64 = _mm512_maskz_mov_epi8(~kb_lt60, b_base_u8x64);
    b2_0_i32x16 = _mm512_dpbusd_epi32(b2_0_i32x16, b0_u8x64, b0_u8x64);
    b2_2_i32x16 = _mm512_dpbusd_epi32(b2_2_i32x16, b1_u8x64, b1_u8x64);
    b2_4_i32x16 = _mm512_dpbusd_epi32(b2_4_i32x16, b2_u8x64, b2_u8x64);
    b2_6_i32x16 = _mm512_dpbusd_epi32(b2_6_i32x16, b3_u8x64, b3_u8x64);

    if (n) goto nk_sqeuclidean_e4m3_icelake_cycle;

    // Reduce dot(a,b): accumulator `k` is scaled by 2^(4·k − 20)
    __m512 ab_f32x16 = _mm512_mul_ps(_mm512_cvtepi32_ps(ab0_i32x16), _mm512_set1_ps(9.5367431640625e-07f));
    ab_f32x16 = _mm512_fmadd_ps(_mm512_cvtepi32_ps(ab1_i32x16), _mm512_set1_ps(1.52587890625e-05f), ab_f32x16);
    ab_f32x16 = _mm512_fmadd_ps(_mm512_cvtepi32_ps(ab2_i32x16), _mm512_set1_ps(2.44140625e-04f), ab_f32x16);
    ab_f32x16 = _mm512_fmadd_ps(_mm512_cvtepi32_ps(ab3_i32x16), _mm512_set1_ps(3.90625e-03f), ab_f32x16);
    ab_f32x16 = _mm512_fmadd_ps(_mm512_cvtepi32_ps(ab4_i32x16), _mm512_set1_ps(6.25e-02f), ab_f32x16);
    ab_f32x16 = _mm512_fmadd_ps(_mm512_cvtepi32_ps(ab5_i32x16), _mm512_set1_ps(1.0f), ab_f32x16);
    ab_f32x16 = _mm512_fmadd_ps(_mm512_cvtepi32_ps(ab6_i32x16), _mm512_set1_ps(16.0f), ab_f32x16);

    // Reduce ||a||² and ||b||² (even-k only: scale = 2^(8·oct − 20))
    __m512 a2_f32x16 = _mm512_mul_ps(_mm512_cvtepi32_ps(a2_0_i32x16), _mm512_set1_ps(9.5367431640625e-07f));
    a2_f32x16 = _mm512_fmadd_ps(_mm512_cvtepi32_ps(a2_2_i32x16), _mm512_set1_ps(2.44140625e-04f), a2_f32x16);
    a2_f32x16 = _mm512_fmadd_ps(_mm512_cvtepi32_ps(a2_4_i32x16), _mm512_set1_ps(6.25e-02f), a2_f32x16);
    a2_f32x16 = _mm512_fmadd_ps(_mm512_cvtepi32_ps(a2_6_i32x16), _mm512_set1_ps(16.0f), a2_f32x16);

    __m512 b2_f32x16 = _mm512_mul_ps(_mm512_cvtepi32_ps(b2_0_i32x16), _mm512_set1_ps(9.5367431640625e-07f));
    b2_f32x16 = _mm512_fmadd_ps(_mm512_cvtepi32_ps(b2_2_i32x16), _mm512_set1_ps(2.44140625e-04f), b2_f32x16);
    b2_f32x16 = _mm512_fmadd_ps(_mm512_cvtepi32_ps(b2_4_i32x16), _mm512_set1_ps(6.25e-02f), b2_f32x16);
    b2_f32x16 = _mm512_fmadd_ps(_mm512_cvtepi32_ps(b2_6_i32x16), _mm512_set1_ps(16.0f), b2_f32x16);

    // (a-b)² = ||a||² + ||b||² - 2·dot(a,b). Cancellation can leave a tiny negative rounding residue, so clamp at 0.
    __m512 sum_sq_f32x16 = _mm512_add_ps(a2_f32x16, b2_f32x16);
    nk_f32_t distance_sq_f32 =
        nk_reduce_add_f32x16_skylake_(_mm512_fnmadd_ps(_mm512_set1_ps(2.0f), ab_f32x16, sum_sq_f32x16));
    *result = distance_sq_f32 > 0.0f ? distance_sq_f32 : 0.0f;
}
|
|
710
|
+
|
|
711
|
+
/**
 *  @brief  Euclidean (L2) distance between two E4M3 vectors.
 *
 *  Takes the square root of the squared distance computed by `nk_sqeuclidean_e4m3_icelake`.
 */
NK_PUBLIC void nk_euclidean_e4m3_icelake(nk_e4m3_t const *a, nk_e4m3_t const *b, nk_size_t n, nk_f32_t *result) {
    nk_f32_t squared_distance;
    nk_sqeuclidean_e4m3_icelake(a, b, n, &squared_distance);
    *result = nk_f32_sqrt_haswell(squared_distance);
}
|
|
715
|
+
|
|
716
|
+
/**
 *  @brief  Angular distance between two E4M3 vectors via "octave"-split VNNI dot products.
 *
 *  Uses the same decomposition as the squared-Euclidean E4M3 kernel. Each magnitude (bits 6:0) becomes an integer
 *  `q` in [0, 120] with `|value| == q * 2^(4*octave - 10)`, where `octave = magnitude >> 5`. Cross products are
 *  accumulated per octave sum `k` and rescaled by `2^(4*k - 20)` after the loop.
 *  The dot product and both squared norms are then passed to `nk_angular_normalize_f32_haswell_`. That helper is
 *  defined elsewhere; presumably it computes `1 - ab / sqrt(a2 * b2)` with zero-norm handling (TODO confirm).
 *
 *  @param[in]  a       First vector of `n` E4M3 values.
 *  @param[in]  b       Second vector of `n` E4M3 values.
 *  @param[in]  n       Number of elements in each vector.
 *  @param[out] result  Angular distance as produced by `nk_angular_normalize_f32_haswell_`.
 *
 *  @note  Integer accumulators are 32-bit and gain at most 57600 per lane per 64-element step.
 *  @note  NaN encodings (0x7F / 0xFF) are looked up like any other magnitude.
 */
NK_PUBLIC void nk_angular_e4m3_icelake(nk_e4m3_t const *a, nk_e4m3_t const *b, nk_size_t n, nk_f32_t *result) {
    // E4M3 angular distance via octave VNNI.

    // Normal magnitudes map to `(8 + mantissa) << (exponent & 3)`. The table is duplicated because VPERMB only reads
    // the low 6 index bits; the octave is recovered separately from the thresholds below.
    __m512i const lut_normal_u8x64 = _mm512_set_epi8( //
        120, 112, 104, 96, 88, 80, 72, 64, 60, 56, 52, 48, 44, 40, 36, 32, //
        30, 28, 26, 24, 22, 20, 18, 16, 15, 14, 13, 12, 11, 10, 9, 8,      //
        120, 112, 104, 96, 88, 80, 72, 64, 60, 56, 52, 48, 44, 40, 36, 32, //
        30, 28, 26, 24, 22, 20, 18, 16, 15, 14, 13, 12, 11, 10, 9, 8);     //
    // Subnormal magnitudes (exponent 0) map to `2 * mantissa`, on the octave-0 scale.
    __m512i const lut_subnorm_u8x64 = _mm512_set_epi8( //
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,     //
        0, 0, 0, 0, 0, 0, 0, 0, 14, 12, 10, 8, 6, 4, 2, 0,  //
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,     //
        0, 0, 0, 0, 0, 0, 0, 0, 14, 12, 10, 8, 6, 4, 2, 0); //
    __m512i const magnitude_mask_u8x64 = _mm512_set1_epi8(0x7F);
    __m512i const subnorm_threshold_u8x64 = _mm512_set1_epi8(0x08);
    // Magnitude thresholds between octaves 0 | 1 | 2 | 3 (exponent fields 0-3 | 4-7 | 8-11 | 12-15).
    __m512i const oct_threshold_20_u8x64 = _mm512_set1_epi8(0x20);
    __m512i const oct_threshold_40_u8x64 = _mm512_set1_epi8(0x40);
    __m512i const oct_threshold_60_u8x64 = _mm512_set1_epi8(0x60);

    // `abK` holds cross products whose octaves sum to K; `a2_K` / `b2_K` hold self products (even K only).
    __m512i ab0_i32x16 = _mm512_setzero_si512(), ab1_i32x16 = _mm512_setzero_si512();
    __m512i ab2_i32x16 = _mm512_setzero_si512(), ab3_i32x16 = _mm512_setzero_si512();
    __m512i ab4_i32x16 = _mm512_setzero_si512(), ab5_i32x16 = _mm512_setzero_si512();
    __m512i ab6_i32x16 = _mm512_setzero_si512();
    __m512i a2_0_i32x16 = _mm512_setzero_si512(), a2_2_i32x16 = _mm512_setzero_si512();
    __m512i a2_4_i32x16 = _mm512_setzero_si512(), a2_6_i32x16 = _mm512_setzero_si512();
    __m512i b2_0_i32x16 = _mm512_setzero_si512(), b2_2_i32x16 = _mm512_setzero_si512();
    __m512i b2_4_i32x16 = _mm512_setzero_si512(), b2_6_i32x16 = _mm512_setzero_si512();
    __m512i a_e4m3_u8x64, b_e4m3_u8x64;

nk_angular_e4m3_icelake_cycle:
    if (n < 64) {
        // Tail: masked-off bytes load as zero, i.e. magnitude 0, and contribute nothing.
        __mmask64 mask = (__mmask64)_bzhi_u64(0xFFFFFFFFFFFFFFFF, n);
        a_e4m3_u8x64 = _mm512_maskz_loadu_epi8(mask, a);
        b_e4m3_u8x64 = _mm512_maskz_loadu_epi8(mask, b);
        n = 0;
    }
    else {
        a_e4m3_u8x64 = _mm512_loadu_si512(a);
        b_e4m3_u8x64 = _mm512_loadu_si512(b);
        a += 64, b += 64, n -= 64;
    }

    __m512i a_magnitude_u8x64 = _mm512_and_si512(a_e4m3_u8x64, magnitude_mask_u8x64);
    __m512i b_magnitude_u8x64 = _mm512_and_si512(b_e4m3_u8x64, magnitude_mask_u8x64);
    __m512i a_base_u8x64 = _mm512_permutexvar_epi8(a_magnitude_u8x64, lut_normal_u8x64);
    __m512i b_base_u8x64 = _mm512_permutexvar_epi8(b_magnitude_u8x64, lut_normal_u8x64);
    // Overwrite subnormal lanes (magnitude < 0x08) with the subnormal table.
    a_base_u8x64 = _mm512_mask_permutexvar_epi8(a_base_u8x64,
                                                _mm512_cmplt_epu8_mask(a_magnitude_u8x64, subnorm_threshold_u8x64),
                                                a_magnitude_u8x64, lut_subnorm_u8x64);
    b_base_u8x64 = _mm512_mask_permutexvar_epi8(b_base_u8x64,
                                                _mm512_cmplt_epu8_mask(b_magnitude_u8x64, subnorm_threshold_u8x64),
                                                b_magnitude_u8x64, lut_subnorm_u8x64);

    // Product sign is `(a ^ b) & ~0x7F` (ternary-logic table 0x14); negate `b` wherever it is set.
    __m512i sign_diff_u8x64 = _mm512_ternarylogic_epi64(a_e4m3_u8x64, b_e4m3_u8x64, magnitude_mask_u8x64, 0x14);
    __m512i b_signed_i8x64 = _mm512_mask_sub_epi8(b_base_u8x64, _mm512_test_epi8_mask(sign_diff_u8x64, sign_diff_u8x64),
                                                  _mm512_setzero_si512(), b_base_u8x64);

    __mmask64 ka_lt20 = _mm512_cmplt_epu8_mask(a_magnitude_u8x64, oct_threshold_20_u8x64);
    __mmask64 ka_lt40 = _mm512_cmplt_epu8_mask(a_magnitude_u8x64, oct_threshold_40_u8x64);
    __mmask64 ka_lt60 = _mm512_cmplt_epu8_mask(a_magnitude_u8x64, oct_threshold_60_u8x64);
    __mmask64 kb_lt20 = _mm512_cmplt_epu8_mask(b_magnitude_u8x64, oct_threshold_20_u8x64);
    __mmask64 kb_lt40 = _mm512_cmplt_epu8_mask(b_magnitude_u8x64, oct_threshold_40_u8x64);
    __mmask64 kb_lt60 = _mm512_cmplt_epu8_mask(b_magnitude_u8x64, oct_threshold_60_u8x64);

    // Octave planes: each byte is non-zero in at most one plane per operand.
    __m512i a0_u8x64 = _mm512_maskz_mov_epi8(ka_lt20, a_base_u8x64);
    __m512i a1_u8x64 = _mm512_maskz_mov_epi8(ka_lt40 & ~ka_lt20, a_base_u8x64);
    __m512i a2_u8x64 = _mm512_maskz_mov_epi8(ka_lt60 & ~ka_lt40, a_base_u8x64);
    __m512i a3_u8x64 = _mm512_maskz_mov_epi8(~ka_lt60, a_base_u8x64);

    __m512i b0_i8x64 = _mm512_maskz_mov_epi8(kb_lt20, b_signed_i8x64);
    __m512i b1_i8x64 = _mm512_maskz_mov_epi8(kb_lt40 & ~kb_lt20, b_signed_i8x64);
    __m512i b2_i8x64 = _mm512_maskz_mov_epi8(kb_lt60 & ~kb_lt40, b_signed_i8x64);
    __m512i b3_i8x64 = _mm512_maskz_mov_epi8(~kb_lt60, b_signed_i8x64);

    // dot(a,b): 16 VPDPBUSD
    ab0_i32x16 = _mm512_dpbusd_epi32(ab0_i32x16, a0_u8x64, b0_i8x64);
    ab1_i32x16 = _mm512_dpbusd_epi32(_mm512_dpbusd_epi32(ab1_i32x16, a0_u8x64, b1_i8x64), a1_u8x64, b0_i8x64);
    ab2_i32x16 = _mm512_dpbusd_epi32(
        _mm512_dpbusd_epi32(_mm512_dpbusd_epi32(ab2_i32x16, a0_u8x64, b2_i8x64), a1_u8x64, b1_i8x64), a2_u8x64,
        b0_i8x64);
    ab3_i32x16 = _mm512_dpbusd_epi32(
        _mm512_dpbusd_epi32(
            _mm512_dpbusd_epi32(_mm512_dpbusd_epi32(ab3_i32x16, a0_u8x64, b3_i8x64), a1_u8x64, b2_i8x64), a2_u8x64,
            b1_i8x64),
        a3_u8x64, b0_i8x64);
    ab4_i32x16 = _mm512_dpbusd_epi32(
        _mm512_dpbusd_epi32(_mm512_dpbusd_epi32(ab4_i32x16, a1_u8x64, b3_i8x64), a2_u8x64, b2_i8x64), a3_u8x64,
        b1_i8x64);
    ab5_i32x16 = _mm512_dpbusd_epi32(_mm512_dpbusd_epi32(ab5_i32x16, a2_u8x64, b3_i8x64), a3_u8x64, b2_i8x64);
    ab6_i32x16 = _mm512_dpbusd_epi32(ab6_i32x16, a3_u8x64, b3_i8x64);

    // ||a||²: 4 VPDPBUSD
    a2_0_i32x16 = _mm512_dpbusd_epi32(a2_0_i32x16, a0_u8x64, a0_u8x64);
    a2_2_i32x16 = _mm512_dpbusd_epi32(a2_2_i32x16, a1_u8x64, a1_u8x64);
    a2_4_i32x16 = _mm512_dpbusd_epi32(a2_4_i32x16, a2_u8x64, a2_u8x64);
    a2_6_i32x16 = _mm512_dpbusd_epi32(a2_6_i32x16, a3_u8x64, a3_u8x64);

    // ||b||²: 4 VPDPBUSD (unsigned b)
    __m512i b0_u8x64 = _mm512_maskz_mov_epi8(kb_lt20, b_base_u8x64);
    __m512i b1_u8x64 = _mm512_maskz_mov_epi8(kb_lt40 & ~kb_lt20, b_base_u8x64);
    __m512i b2_u8x64 = _mm512_maskz_mov_epi8(kb_lt60 & ~kb_lt40, b_base_u8x64);
    __m512i b3_u8x64 = _mm512_maskz_mov_epi8(~kb_lt60, b_base_u8x64);
    b2_0_i32x16 = _mm512_dpbusd_epi32(b2_0_i32x16, b0_u8x64, b0_u8x64);
    b2_2_i32x16 = _mm512_dpbusd_epi32(b2_2_i32x16, b1_u8x64, b1_u8x64);
    b2_4_i32x16 = _mm512_dpbusd_epi32(b2_4_i32x16, b2_u8x64, b2_u8x64);
    b2_6_i32x16 = _mm512_dpbusd_epi32(b2_6_i32x16, b3_u8x64, b3_u8x64);

    if (n) goto nk_angular_e4m3_icelake_cycle;

    // Reduce dot(a,b): accumulator `k` is scaled by 2^(4·k − 20)
    __m512 ab_f32x16 = _mm512_mul_ps(_mm512_cvtepi32_ps(ab0_i32x16), _mm512_set1_ps(9.5367431640625e-07f));
    ab_f32x16 = _mm512_fmadd_ps(_mm512_cvtepi32_ps(ab1_i32x16), _mm512_set1_ps(1.52587890625e-05f), ab_f32x16);
    ab_f32x16 = _mm512_fmadd_ps(_mm512_cvtepi32_ps(ab2_i32x16), _mm512_set1_ps(2.44140625e-04f), ab_f32x16);
    ab_f32x16 = _mm512_fmadd_ps(_mm512_cvtepi32_ps(ab3_i32x16), _mm512_set1_ps(3.90625e-03f), ab_f32x16);
    ab_f32x16 = _mm512_fmadd_ps(_mm512_cvtepi32_ps(ab4_i32x16), _mm512_set1_ps(6.25e-02f), ab_f32x16);
    ab_f32x16 = _mm512_fmadd_ps(_mm512_cvtepi32_ps(ab5_i32x16), _mm512_set1_ps(1.0f), ab_f32x16);
    ab_f32x16 = _mm512_fmadd_ps(_mm512_cvtepi32_ps(ab6_i32x16), _mm512_set1_ps(16.0f), ab_f32x16);

    // Squared norms only have even `k`, i.e. scale 2^(8·oct − 20)
    __m512 a2_f32x16 = _mm512_mul_ps(_mm512_cvtepi32_ps(a2_0_i32x16), _mm512_set1_ps(9.5367431640625e-07f));
    a2_f32x16 = _mm512_fmadd_ps(_mm512_cvtepi32_ps(a2_2_i32x16), _mm512_set1_ps(2.44140625e-04f), a2_f32x16);
    a2_f32x16 = _mm512_fmadd_ps(_mm512_cvtepi32_ps(a2_4_i32x16), _mm512_set1_ps(6.25e-02f), a2_f32x16);
    a2_f32x16 = _mm512_fmadd_ps(_mm512_cvtepi32_ps(a2_6_i32x16), _mm512_set1_ps(16.0f), a2_f32x16);

    __m512 b2_f32x16 = _mm512_mul_ps(_mm512_cvtepi32_ps(b2_0_i32x16), _mm512_set1_ps(9.5367431640625e-07f));
    b2_f32x16 = _mm512_fmadd_ps(_mm512_cvtepi32_ps(b2_2_i32x16), _mm512_set1_ps(2.44140625e-04f), b2_f32x16);
    b2_f32x16 = _mm512_fmadd_ps(_mm512_cvtepi32_ps(b2_4_i32x16), _mm512_set1_ps(6.25e-02f), b2_f32x16);
    b2_f32x16 = _mm512_fmadd_ps(_mm512_cvtepi32_ps(b2_6_i32x16), _mm512_set1_ps(16.0f), b2_f32x16);

    // Horizontal sums happen in f32, so they cannot wrap like a 32-bit integer reduction would.
    nk_f32_t ab_f32 = nk_reduce_add_f32x16_skylake_(ab_f32x16);
    nk_f32_t a_norm_sq_f32 = nk_reduce_add_f32x16_skylake_(a2_f32x16);
    nk_f32_t b_norm_sq_f32 = nk_reduce_add_f32x16_skylake_(b2_f32x16);
    *result = nk_angular_normalize_f32_haswell_(ab_f32, a_norm_sq_f32, b_norm_sq_f32);
}
|
|
849
|
+
|
|
850
|
+
/**
 *  @brief  Squared Euclidean distance between two E2M3 vectors via VPDPBUSD integer multiply-accumulate.
 *
 *  Each 5-bit magnitude (bits 4:0) is mapped through a byte LUT to `16 * |value|`, an exact integer in [0, 120].
 *  The sign of each product (bit 5 of `a ^ b`) is folded into the signed `b` operand. The distance is assembled
 *  as `a² + b² - 2*ab` and divided by 16² = 256 to undo the LUT scaling.
 *
 *  Once a lane's integer sum exceeds 2^24, its f32 conversion rounds, and the expansion can then land slightly
 *  below zero for near-identical inputs. The result is therefore clamped at zero.
 *
 *  @param[in]  a       First vector of `n` E2M3 values (one per byte).
 *  @param[in]  b       Second vector of `n` E2M3 values (one per byte).
 *  @param[in]  n       Number of elements in each vector.
 *  @param[out] result  Squared Euclidean distance, never negative.
 *
 *  @note  Integer accumulators are 32-bit; a lane gains at most 4 * 120² = 57600 per 64-element step.
 */
NK_PUBLIC void nk_sqeuclidean_e2m3_icelake(nk_e2m3_t const *a, nk_e2m3_t const *b, nk_size_t n, nk_f32_t *result) {
    // Magnitude -> 16·|value|. Indices are masked to 5 bits, so only the lower copy of the table is reached.
    __m512i const lut_magnitude_u8x64 = _mm512_set_epi8(120, 112, 104, 96, 88, 80, 72, 64, 60, 56, 52, 48, 44, 40, 36,
                                                        32, 30, 28, 26, 24, 22, 20, 18, 16, 14, 12, 10, 8, 6, 4, 2, 0,
                                                        120, 112, 104, 96, 88, 80, 72, 64, 60, 56, 52, 48, 44, 40, 36,
                                                        32, 30, 28, 26, 24, 22, 20, 18, 16, 14, 12, 10, 8, 6, 4, 2, 0);
    __m512i const magnitude_mask_u8x64 = _mm512_set1_epi8(0x1F);
    __m512i const sign_mask_u8x64 = _mm512_set1_epi8(0x20);
    __m512i ab_i32x16 = _mm512_setzero_si512();
    __m512i a2_i32x16 = _mm512_setzero_si512();
    __m512i b2_i32x16 = _mm512_setzero_si512();
    __m512i a_e2m3_u8x64, b_e2m3_u8x64;

nk_sqeuclidean_e2m3_icelake_cycle:
    if (n < 64) {
        // Tail: masked-off bytes load as zero and contribute nothing.
        __mmask64 mask = (__mmask64)_bzhi_u64(0xFFFFFFFFFFFFFFFF, n);
        a_e2m3_u8x64 = _mm512_maskz_loadu_epi8(mask, a);
        b_e2m3_u8x64 = _mm512_maskz_loadu_epi8(mask, b);
        n = 0;
    }
    else {
        a_e2m3_u8x64 = _mm512_loadu_si512(a);
        b_e2m3_u8x64 = _mm512_loadu_si512(b);
        a += 64, b += 64, n -= 64;
    }

    __m512i a_magnitude_u8x64 = _mm512_and_si512(a_e2m3_u8x64, magnitude_mask_u8x64);
    __m512i b_magnitude_u8x64 = _mm512_and_si512(b_e2m3_u8x64, magnitude_mask_u8x64);
    __m512i a_unsigned_u8x64 = _mm512_permutexvar_epi8(a_magnitude_u8x64, lut_magnitude_u8x64);
    __m512i b_unsigned_u8x64 = _mm512_permutexvar_epi8(b_magnitude_u8x64, lut_magnitude_u8x64);

    // Negate `b` where the signs of `a` and `b` differ, so `a_unsigned * b_signed` carries the product's sign.
    __m512i sign_combined_u8x64 = _mm512_and_si512(_mm512_xor_si512(a_e2m3_u8x64, b_e2m3_u8x64), sign_mask_u8x64);
    __mmask64 negate_mask = _mm512_test_epi8_mask(sign_combined_u8x64, sign_combined_u8x64);
    __m512i b_signed_i8x64 = _mm512_mask_sub_epi8(b_unsigned_u8x64, negate_mask, _mm512_setzero_si512(),
                                                  b_unsigned_u8x64);

    ab_i32x16 = _mm512_dpbusd_epi32(ab_i32x16, a_unsigned_u8x64, b_signed_i8x64);
    a2_i32x16 = _mm512_dpbusd_epi32(a2_i32x16, a_unsigned_u8x64, a_unsigned_u8x64);
    b2_i32x16 = _mm512_dpbusd_epi32(b2_i32x16, b_unsigned_u8x64, b_unsigned_u8x64);

    if (n) goto nk_sqeuclidean_e2m3_icelake_cycle;

    // (a-b)² = a² + b² − 2·ab, scaled by 256 (16² from LUT)
    __m512 a2_f32x16 = _mm512_cvtepi32_ps(a2_i32x16);
    __m512 b2_f32x16 = _mm512_cvtepi32_ps(b2_i32x16);
    __m512 ab_f32x16 = _mm512_cvtepi32_ps(ab_i32x16);
    __m512 sum_sq_f32x16 = _mm512_add_ps(a2_f32x16, b2_f32x16);
    nk_f32_t distance_sq_f32 =
        nk_reduce_add_f32x16_skylake_(_mm512_fnmadd_ps(_mm512_set1_ps(2.0f), ab_f32x16, sum_sq_f32x16)) / 256.0f;
    // A squared distance cannot be negative; drop any rounding residue from the expansion.
    *result = distance_sq_f32 > 0.0f ? distance_sq_f32 : 0.0f;
}
|
|
899
|
+
|
|
900
|
+
/**
 *  @brief  Euclidean (L2) distance between two E2M3 vectors.
 *
 *  Takes the square root of the squared distance computed by `nk_sqeuclidean_e2m3_icelake`.
 */
NK_PUBLIC void nk_euclidean_e2m3_icelake(nk_e2m3_t const *a, nk_e2m3_t const *b, nk_size_t n, nk_f32_t *result) {
    nk_f32_t squared_distance;
    nk_sqeuclidean_e2m3_icelake(a, b, n, &squared_distance);
    *result = nk_f32_sqrt_haswell(squared_distance);
}
|
|
904
|
+
|
|
905
|
+
/**
 *  @brief  Angular distance between two E2M3 vectors via VPDPBUSD integer multiply-accumulate.
 *
 *  Magnitudes are mapped to `16 * |value|` (exact integers in [0, 120]). The dot product and both squared norms
 *  are accumulated in 32-bit integer lanes. After removing the 16² = 256 scaling, they are handed to
 *  `nk_angular_normalize_f32_haswell_`, which is defined elsewhere.
 *
 *  The horizontal sums are taken in 64-bit integers. `_mm512_reduce_add_epi32` wraps as soon as the total over
 *  all 16 lanes exceeds INT32_MAX; for the norms that happens after ~149k near-maximal elements, roughly 16x
 *  earlier than any single lane would overflow. The norm lanes are non-negative, so they are widened as unsigned.
 *
 *  @param[in]  a       First vector of `n` E2M3 values (one per byte).
 *  @param[in]  b       Second vector of `n` E2M3 values (one per byte).
 *  @param[in]  n       Number of elements in each vector.
 *  @param[out] result  Angular distance as produced by `nk_angular_normalize_f32_haswell_`.
 */
NK_PUBLIC void nk_angular_e2m3_icelake(nk_e2m3_t const *a, nk_e2m3_t const *b, nk_size_t n, nk_f32_t *result) {
    // Magnitude -> 16·|value|. Indices are masked to 5 bits, so only the lower copy of the table is reached.
    __m512i const lut_magnitude_u8x64 = _mm512_set_epi8(120, 112, 104, 96, 88, 80, 72, 64, 60, 56, 52, 48, 44, 40, 36,
                                                        32, 30, 28, 26, 24, 22, 20, 18, 16, 14, 12, 10, 8, 6, 4, 2, 0,
                                                        120, 112, 104, 96, 88, 80, 72, 64, 60, 56, 52, 48, 44, 40, 36,
                                                        32, 30, 28, 26, 24, 22, 20, 18, 16, 14, 12, 10, 8, 6, 4, 2, 0);
    __m512i const magnitude_mask_u8x64 = _mm512_set1_epi8(0x1F);
    __m512i const sign_mask_u8x64 = _mm512_set1_epi8(0x20);
    __m512i ab_i32x16 = _mm512_setzero_si512();
    __m512i a2_i32x16 = _mm512_setzero_si512();
    __m512i b2_i32x16 = _mm512_setzero_si512();
    __m512i a_e2m3_u8x64, b_e2m3_u8x64;

nk_angular_e2m3_icelake_cycle:
    if (n < 64) {
        // Tail: masked-off bytes load as zero and contribute nothing.
        __mmask64 mask = (__mmask64)_bzhi_u64(0xFFFFFFFFFFFFFFFF, n);
        a_e2m3_u8x64 = _mm512_maskz_loadu_epi8(mask, a);
        b_e2m3_u8x64 = _mm512_maskz_loadu_epi8(mask, b);
        n = 0;
    }
    else {
        a_e2m3_u8x64 = _mm512_loadu_si512(a);
        b_e2m3_u8x64 = _mm512_loadu_si512(b);
        a += 64, b += 64, n -= 64;
    }

    __m512i a_magnitude_u8x64 = _mm512_and_si512(a_e2m3_u8x64, magnitude_mask_u8x64);
    __m512i b_magnitude_u8x64 = _mm512_and_si512(b_e2m3_u8x64, magnitude_mask_u8x64);
    __m512i a_unsigned_u8x64 = _mm512_permutexvar_epi8(a_magnitude_u8x64, lut_magnitude_u8x64);
    __m512i b_unsigned_u8x64 = _mm512_permutexvar_epi8(b_magnitude_u8x64, lut_magnitude_u8x64);

    // Negate `b` where the signs of `a` and `b` differ, so `a_unsigned * b_signed` carries the product's sign.
    __m512i sign_combined_u8x64 = _mm512_and_si512(_mm512_xor_si512(a_e2m3_u8x64, b_e2m3_u8x64), sign_mask_u8x64);
    __mmask64 negate_mask = _mm512_test_epi8_mask(sign_combined_u8x64, sign_combined_u8x64);
    __m512i b_signed_i8x64 = _mm512_mask_sub_epi8(b_unsigned_u8x64, negate_mask, _mm512_setzero_si512(),
                                                  b_unsigned_u8x64);

    ab_i32x16 = _mm512_dpbusd_epi32(ab_i32x16, a_unsigned_u8x64, b_signed_i8x64);
    a2_i32x16 = _mm512_dpbusd_epi32(a2_i32x16, a_unsigned_u8x64, a_unsigned_u8x64);
    b2_i32x16 = _mm512_dpbusd_epi32(b2_i32x16, b_unsigned_u8x64, b_unsigned_u8x64);

    if (n) goto nk_angular_e2m3_icelake_cycle;

    // Widen lanes to 64 bits before the horizontal sums so the cross-lane totals cannot wrap.
    __m512i ab_i64x8 = _mm512_add_epi64(_mm512_cvtepi32_epi64(_mm512_castsi512_si256(ab_i32x16)),
                                        _mm512_cvtepi32_epi64(_mm512_extracti64x4_epi64(ab_i32x16, 1)));
    __m512i a2_i64x8 = _mm512_add_epi64(_mm512_cvtepu32_epi64(_mm512_castsi512_si256(a2_i32x16)),
                                        _mm512_cvtepu32_epi64(_mm512_extracti64x4_epi64(a2_i32x16, 1)));
    __m512i b2_i64x8 = _mm512_add_epi64(_mm512_cvtepu32_epi64(_mm512_castsi512_si256(b2_i32x16)),
                                        _mm512_cvtepu32_epi64(_mm512_extracti64x4_epi64(b2_i32x16, 1)));

    nk_f32_t ab_f32 = (nk_f32_t)_mm512_reduce_add_epi64(ab_i64x8) / 256.0f;
    nk_f32_t a_norm_sq_f32 = (nk_f32_t)_mm512_reduce_add_epi64(a2_i64x8) / 256.0f;
    nk_f32_t b_norm_sq_f32 = (nk_f32_t)_mm512_reduce_add_epi64(b2_i64x8) / 256.0f;
    *result = nk_angular_normalize_f32_haswell_(ab_f32, a_norm_sq_f32, b_norm_sq_f32);
}
|
|
952
|
+
|
|
953
|
+
/**
 *  @brief  Squared Euclidean distance between two E3M2 vectors via direct difference squaring.
 *
 *  Each 5-bit magnitude (bits 4:0) is mapped through a 16-bit LUT to `16 * |value|`, an exact integer in [0, 448].
 *  The sign (bit 5) is then applied to each operand. These values do not fit in 8 bits, so the kernel processes
 *  32 elements at a time in 16-bit lanes. The exact differences (within ±896) are squared and pair-summed by
 *  VPMADDWD, and the total is divided by 16² = 256.
 *
 *  The per-element terms are large: up to 896² = 802816. A 32-bit horizontal sum over all lanes would therefore
 *  wrap after only ~2.7k extreme elements (e.g. +28 vs -28 everywhere). Instead, the non-negative lanes are widened
 *  as unsigned 64-bit values before summing; the per-lane 32-bit accumulators remain the limiting factor.
 *
 *  @param[in]  a       First vector of `n` E3M2 values (one per byte).
 *  @param[in]  b       Second vector of `n` E3M2 values (one per byte).
 *  @param[in]  n       Number of elements in each vector.
 *  @param[out] result  Squared Euclidean distance.
 */
NK_PUBLIC void nk_sqeuclidean_e3m2_icelake(nk_e3m2_t const *a, nk_e3m2_t const *b, nk_size_t n, nk_f32_t *result) {
    // Magnitude -> 16·|value|, one 16-bit entry per 5-bit magnitude.
    __m512i const lut_magnitude_i16x32 = _mm512_set_epi16( //
        448, 384, 320, 256, 224, 192, 160, 128, 112, 96, 80, 64, 56, 48, 40, 32, //
        28, 24, 20, 16, 14, 12, 10, 8, 7, 6, 5, 4, 3, 2, 1, 0);
    __m512i const magnitude_mask_i16x32 = _mm512_set1_epi16(0x1F);
    __m512i const sign_mask_i16x32 = _mm512_set1_epi16(0x20);
    __m512i sum_i32x16 = _mm512_setzero_si512();
    __m256i a_e3m2_u8x32, b_e3m2_u8x32;

nk_sqeuclidean_e3m2_icelake_cycle:
    if (n < 32) {
        // Tail: masked-off bytes load as zero and contribute nothing.
        __mmask32 mask = (__mmask32)_bzhi_u32(0xFFFFFFFF, (unsigned int)n);
        a_e3m2_u8x32 = _mm256_maskz_loadu_epi8(mask, a);
        b_e3m2_u8x32 = _mm256_maskz_loadu_epi8(mask, b);
        n = 0;
    }
    else {
        a_e3m2_u8x32 = _mm256_loadu_si256((__m256i const *)a);
        b_e3m2_u8x32 = _mm256_loadu_si256((__m256i const *)b);
        a += 32, b += 32, n -= 32;
    }

    __m512i a_u16x32 = _mm512_cvtepu8_epi16(a_e3m2_u8x32);
    __m512i b_u16x32 = _mm512_cvtepu8_epi16(b_e3m2_u8x32);
    __m512i a_unsigned_i16x32 = _mm512_permutexvar_epi16(_mm512_and_si512(a_u16x32, magnitude_mask_i16x32),
                                                         lut_magnitude_i16x32);
    __m512i b_unsigned_i16x32 = _mm512_permutexvar_epi16(_mm512_and_si512(b_u16x32, magnitude_mask_i16x32),
                                                         lut_magnitude_i16x32);

    // Apply signs individually
    __mmask32 a_negative_mask = _mm512_test_epi16_mask(a_u16x32, sign_mask_i16x32);
    __mmask32 b_negative_mask = _mm512_test_epi16_mask(b_u16x32, sign_mask_i16x32);
    __m512i a_signed_i16x32 = _mm512_mask_sub_epi16(a_unsigned_i16x32, a_negative_mask, _mm512_setzero_si512(),
                                                    a_unsigned_i16x32);
    __m512i b_signed_i16x32 = _mm512_mask_sub_epi16(b_unsigned_i16x32, b_negative_mask, _mm512_setzero_si512(),
                                                    b_unsigned_i16x32);

    // Direct difference squaring: (a-b)² via VPMADDWD
    __m512i diff_i16x32 = _mm512_sub_epi16(a_signed_i16x32, b_signed_i16x32);
    sum_i32x16 = _mm512_add_epi32(sum_i32x16, _mm512_madd_epi16(diff_i16x32, diff_i16x32));

    if (n) goto nk_sqeuclidean_e3m2_icelake_cycle;

    // Lanes hold non-negative sums, so widen them as unsigned and reduce in 64 bits to avoid cross-lane wrap-around.
    __m512i sum_i64x8 = _mm512_add_epi64(_mm512_cvtepu32_epi64(_mm512_castsi512_si256(sum_i32x16)),
                                         _mm512_cvtepu32_epi64(_mm512_extracti64x4_epi64(sum_i32x16, 1)));
    *result = (nk_f32_t)_mm512_reduce_add_epi64(sum_i64x8) / 256.0f;
}
|
|
998
|
+
|
|
999
|
+
/**
 *  @brief  Euclidean (L2) distance between two E3M2 vectors.
 *
 *  Takes the square root of the squared distance computed by `nk_sqeuclidean_e3m2_icelake`.
 */
NK_PUBLIC void nk_euclidean_e3m2_icelake(nk_e3m2_t const *a, nk_e3m2_t const *b, nk_size_t n, nk_f32_t *result) {
    nk_f32_t squared_distance;
    nk_sqeuclidean_e3m2_icelake(a, b, n, &squared_distance);
    *result = nk_f32_sqrt_haswell(squared_distance);
}
|
|
1003
|
+
|
|
1004
|
+
/**
 *  @brief  Angular distance between two E3M2 vectors via VPMADDWD integer multiply-accumulate.
 *
 *  Magnitudes are mapped to `16 * |value|` (exact integers in [0, 448]) in 16-bit lanes. The signed dot product
 *  and both squared norms are accumulated in 32-bit lanes. After removing the 16² = 256 scaling, they are handed
 *  to `nk_angular_normalize_f32_haswell_`, which is defined elsewhere.
 *
 *  The horizontal sums are taken in 64-bit integers. With squares up to 448² = 200704, a 32-bit total over all
 *  16 lanes would wrap after only ~10.7k near-maximal elements. The norm lanes are non-negative, so they are
 *  widened as unsigned; the dot-product lanes are widened as signed.
 *
 *  @param[in]  a       First vector of `n` E3M2 values (one per byte).
 *  @param[in]  b       Second vector of `n` E3M2 values (one per byte).
 *  @param[in]  n       Number of elements in each vector.
 *  @param[out] result  Angular distance as produced by `nk_angular_normalize_f32_haswell_`.
 */
NK_PUBLIC void nk_angular_e3m2_icelake(nk_e3m2_t const *a, nk_e3m2_t const *b, nk_size_t n, nk_f32_t *result) {
    // Magnitude -> 16·|value|, one 16-bit entry per 5-bit magnitude.
    __m512i const lut_magnitude_i16x32 = _mm512_set_epi16( //
        448, 384, 320, 256, 224, 192, 160, 128, 112, 96, 80, 64, 56, 48, 40, 32, //
        28, 24, 20, 16, 14, 12, 10, 8, 7, 6, 5, 4, 3, 2, 1, 0);
    __m512i const magnitude_mask_i16x32 = _mm512_set1_epi16(0x1F);
    __m512i const sign_mask_i16x32 = _mm512_set1_epi16(0x20);
    __m512i ab_i32x16 = _mm512_setzero_si512();
    __m512i a2_i32x16 = _mm512_setzero_si512();
    __m512i b2_i32x16 = _mm512_setzero_si512();
    __m256i a_e3m2_u8x32, b_e3m2_u8x32;

nk_angular_e3m2_icelake_cycle:
    if (n < 32) {
        // Tail: masked-off bytes load as zero and contribute nothing.
        __mmask32 mask = (__mmask32)_bzhi_u32(0xFFFFFFFF, (unsigned int)n);
        a_e3m2_u8x32 = _mm256_maskz_loadu_epi8(mask, a);
        b_e3m2_u8x32 = _mm256_maskz_loadu_epi8(mask, b);
        n = 0;
    }
    else {
        a_e3m2_u8x32 = _mm256_loadu_si256((__m256i const *)a);
        b_e3m2_u8x32 = _mm256_loadu_si256((__m256i const *)b);
        a += 32, b += 32, n -= 32;
    }

    __m512i a_u16x32 = _mm512_cvtepu8_epi16(a_e3m2_u8x32);
    __m512i b_u16x32 = _mm512_cvtepu8_epi16(b_e3m2_u8x32);
    __m512i a_unsigned_i16x32 = _mm512_permutexvar_epi16(_mm512_and_si512(a_u16x32, magnitude_mask_i16x32),
                                                         lut_magnitude_i16x32);
    __m512i b_unsigned_i16x32 = _mm512_permutexvar_epi16(_mm512_and_si512(b_u16x32, magnitude_mask_i16x32),
                                                         lut_magnitude_i16x32);

    __mmask32 a_negative_mask = _mm512_test_epi16_mask(a_u16x32, sign_mask_i16x32);
    __mmask32 b_negative_mask = _mm512_test_epi16_mask(b_u16x32, sign_mask_i16x32);
    __m512i a_signed_i16x32 = _mm512_mask_sub_epi16(a_unsigned_i16x32, a_negative_mask, _mm512_setzero_si512(),
                                                    a_unsigned_i16x32);
    __m512i b_signed_i16x32 = _mm512_mask_sub_epi16(b_unsigned_i16x32, b_negative_mask, _mm512_setzero_si512(),
                                                    b_unsigned_i16x32);

    ab_i32x16 = _mm512_add_epi32(ab_i32x16, _mm512_madd_epi16(a_signed_i16x32, b_signed_i16x32));
    a2_i32x16 = _mm512_add_epi32(a2_i32x16, _mm512_madd_epi16(a_unsigned_i16x32, a_unsigned_i16x32));
    b2_i32x16 = _mm512_add_epi32(b2_i32x16, _mm512_madd_epi16(b_unsigned_i16x32, b_unsigned_i16x32));

    if (n) goto nk_angular_e3m2_icelake_cycle;

    // Widen lanes to 64 bits before the horizontal sums so the cross-lane totals cannot wrap.
    __m512i ab_i64x8 = _mm512_add_epi64(_mm512_cvtepi32_epi64(_mm512_castsi512_si256(ab_i32x16)),
                                        _mm512_cvtepi32_epi64(_mm512_extracti64x4_epi64(ab_i32x16, 1)));
    __m512i a2_i64x8 = _mm512_add_epi64(_mm512_cvtepu32_epi64(_mm512_castsi512_si256(a2_i32x16)),
                                        _mm512_cvtepu32_epi64(_mm512_extracti64x4_epi64(a2_i32x16, 1)));
    __m512i b2_i64x8 = _mm512_add_epi64(_mm512_cvtepu32_epi64(_mm512_castsi512_si256(b2_i32x16)),
                                        _mm512_cvtepu32_epi64(_mm512_extracti64x4_epi64(b2_i32x16, 1)));

    nk_f32_t ab_f32 = (nk_f32_t)_mm512_reduce_add_epi64(ab_i64x8) / 256.0f;
    nk_f32_t a_norm_sq_f32 = (nk_f32_t)_mm512_reduce_add_epi64(a2_i64x8) / 256.0f;
    nk_f32_t b_norm_sq_f32 = (nk_f32_t)_mm512_reduce_add_epi64(b2_i64x8) / 256.0f;
    *result = nk_angular_normalize_f32_haswell_(ab_f32, a_norm_sq_f32, b_norm_sq_f32);
}
|
|
573
1054
|
|
|
574
1055
|
#if defined(__clang__)
|