numkong 7.0.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/LICENSE +201 -0
- package/README.md +495 -0
- package/binding.gyp +540 -0
- package/c/dispatch.h +512 -0
- package/c/dispatch_bf16.c +389 -0
- package/c/dispatch_bf16c.c +52 -0
- package/c/dispatch_e2m3.c +263 -0
- package/c/dispatch_e3m2.c +243 -0
- package/c/dispatch_e4m3.c +276 -0
- package/c/dispatch_e5m2.c +272 -0
- package/c/dispatch_f16.c +376 -0
- package/c/dispatch_f16c.c +58 -0
- package/c/dispatch_f32.c +378 -0
- package/c/dispatch_f32c.c +99 -0
- package/c/dispatch_f64.c +296 -0
- package/c/dispatch_f64c.c +98 -0
- package/c/dispatch_i16.c +96 -0
- package/c/dispatch_i32.c +89 -0
- package/c/dispatch_i4.c +150 -0
- package/c/dispatch_i64.c +86 -0
- package/c/dispatch_i8.c +289 -0
- package/c/dispatch_other.c +330 -0
- package/c/dispatch_u1.c +148 -0
- package/c/dispatch_u16.c +124 -0
- package/c/dispatch_u32.c +118 -0
- package/c/dispatch_u4.c +150 -0
- package/c/dispatch_u64.c +102 -0
- package/c/dispatch_u8.c +303 -0
- package/c/numkong.c +950 -0
- package/include/README.md +573 -0
- package/include/module.modulemap +129 -0
- package/include/numkong/attention/sapphireamx.h +1361 -0
- package/include/numkong/attention/sme.h +2066 -0
- package/include/numkong/attention.h +49 -0
- package/include/numkong/capabilities.h +748 -0
- package/include/numkong/cast/README.md +262 -0
- package/include/numkong/cast/haswell.h +975 -0
- package/include/numkong/cast/icelake.h +470 -0
- package/include/numkong/cast/neon.h +1192 -0
- package/include/numkong/cast/rvv.h +1021 -0
- package/include/numkong/cast/sapphire.h +262 -0
- package/include/numkong/cast/serial.h +2262 -0
- package/include/numkong/cast/skylake.h +856 -0
- package/include/numkong/cast/v128relaxed.h +180 -0
- package/include/numkong/cast.h +230 -0
- package/include/numkong/curved/README.md +223 -0
- package/include/numkong/curved/genoa.h +182 -0
- package/include/numkong/curved/haswell.h +276 -0
- package/include/numkong/curved/neon.h +205 -0
- package/include/numkong/curved/neonbfdot.h +212 -0
- package/include/numkong/curved/neonhalf.h +212 -0
- package/include/numkong/curved/rvv.h +305 -0
- package/include/numkong/curved/serial.h +207 -0
- package/include/numkong/curved/skylake.h +457 -0
- package/include/numkong/curved/smef64.h +506 -0
- package/include/numkong/curved.h +517 -0
- package/include/numkong/curved.hpp +144 -0
- package/include/numkong/dot/README.md +425 -0
- package/include/numkong/dot/alder.h +563 -0
- package/include/numkong/dot/genoa.h +315 -0
- package/include/numkong/dot/haswell.h +1688 -0
- package/include/numkong/dot/icelake.h +883 -0
- package/include/numkong/dot/neon.h +818 -0
- package/include/numkong/dot/neonbfdot.h +244 -0
- package/include/numkong/dot/neonfhm.h +360 -0
- package/include/numkong/dot/neonhalf.h +198 -0
- package/include/numkong/dot/neonsdot.h +508 -0
- package/include/numkong/dot/rvv.h +714 -0
- package/include/numkong/dot/rvvbb.h +72 -0
- package/include/numkong/dot/rvvbf16.h +123 -0
- package/include/numkong/dot/rvvhalf.h +129 -0
- package/include/numkong/dot/sapphire.h +141 -0
- package/include/numkong/dot/serial.h +838 -0
- package/include/numkong/dot/sierra.h +405 -0
- package/include/numkong/dot/skylake.h +1084 -0
- package/include/numkong/dot/sve.h +379 -0
- package/include/numkong/dot/svebfdot.h +74 -0
- package/include/numkong/dot/svehalf.h +123 -0
- package/include/numkong/dot/v128relaxed.h +1258 -0
- package/include/numkong/dot.h +1070 -0
- package/include/numkong/dot.hpp +94 -0
- package/include/numkong/dots/README.md +496 -0
- package/include/numkong/dots/alder.h +114 -0
- package/include/numkong/dots/genoa.h +94 -0
- package/include/numkong/dots/haswell.h +295 -0
- package/include/numkong/dots/icelake.h +171 -0
- package/include/numkong/dots/neon.h +120 -0
- package/include/numkong/dots/neonbfdot.h +58 -0
- package/include/numkong/dots/neonfhm.h +94 -0
- package/include/numkong/dots/neonhalf.h +57 -0
- package/include/numkong/dots/neonsdot.h +108 -0
- package/include/numkong/dots/rvv.h +2486 -0
- package/include/numkong/dots/sapphireamx.h +3973 -0
- package/include/numkong/dots/serial.h +2844 -0
- package/include/numkong/dots/sierra.h +97 -0
- package/include/numkong/dots/skylake.h +196 -0
- package/include/numkong/dots/sme.h +5372 -0
- package/include/numkong/dots/smebi32.h +461 -0
- package/include/numkong/dots/smef64.h +1318 -0
- package/include/numkong/dots/smehalf.h +47 -0
- package/include/numkong/dots/v128relaxed.h +294 -0
- package/include/numkong/dots.h +2804 -0
- package/include/numkong/dots.hpp +639 -0
- package/include/numkong/each/README.md +469 -0
- package/include/numkong/each/haswell.h +1658 -0
- package/include/numkong/each/icelake.h +272 -0
- package/include/numkong/each/neon.h +1104 -0
- package/include/numkong/each/neonbfdot.h +212 -0
- package/include/numkong/each/neonhalf.h +410 -0
- package/include/numkong/each/rvv.h +1121 -0
- package/include/numkong/each/sapphire.h +477 -0
- package/include/numkong/each/serial.h +260 -0
- package/include/numkong/each/skylake.h +1562 -0
- package/include/numkong/each.h +2146 -0
- package/include/numkong/each.hpp +434 -0
- package/include/numkong/geospatial/README.md +147 -0
- package/include/numkong/geospatial/haswell.h +593 -0
- package/include/numkong/geospatial/neon.h +571 -0
- package/include/numkong/geospatial/rvv.h +701 -0
- package/include/numkong/geospatial/serial.h +309 -0
- package/include/numkong/geospatial/skylake.h +577 -0
- package/include/numkong/geospatial/v128relaxed.h +613 -0
- package/include/numkong/geospatial.h +453 -0
- package/include/numkong/geospatial.hpp +235 -0
- package/include/numkong/matrix.hpp +336 -0
- package/include/numkong/maxsim/README.md +187 -0
- package/include/numkong/maxsim/alder.h +511 -0
- package/include/numkong/maxsim/genoa.h +115 -0
- package/include/numkong/maxsim/haswell.h +553 -0
- package/include/numkong/maxsim/icelake.h +480 -0
- package/include/numkong/maxsim/neonsdot.h +394 -0
- package/include/numkong/maxsim/sapphireamx.h +877 -0
- package/include/numkong/maxsim/serial.h +490 -0
- package/include/numkong/maxsim/sme.h +929 -0
- package/include/numkong/maxsim/v128relaxed.h +280 -0
- package/include/numkong/maxsim.h +571 -0
- package/include/numkong/maxsim.hpp +133 -0
- package/include/numkong/mesh/README.md +227 -0
- package/include/numkong/mesh/haswell.h +2235 -0
- package/include/numkong/mesh/neon.h +1329 -0
- package/include/numkong/mesh/neonbfdot.h +842 -0
- package/include/numkong/mesh/neonhalf.h +616 -0
- package/include/numkong/mesh/rvv.h +916 -0
- package/include/numkong/mesh/serial.h +742 -0
- package/include/numkong/mesh/skylake.h +1135 -0
- package/include/numkong/mesh/v128relaxed.h +1052 -0
- package/include/numkong/mesh.h +652 -0
- package/include/numkong/mesh.hpp +762 -0
- package/include/numkong/numkong.h +78 -0
- package/include/numkong/numkong.hpp +57 -0
- package/include/numkong/probability/README.md +173 -0
- package/include/numkong/probability/haswell.h +267 -0
- package/include/numkong/probability/neon.h +225 -0
- package/include/numkong/probability/rvv.h +409 -0
- package/include/numkong/probability/serial.h +169 -0
- package/include/numkong/probability/skylake.h +324 -0
- package/include/numkong/probability.h +383 -0
- package/include/numkong/probability.hpp +120 -0
- package/include/numkong/random.h +50 -0
- package/include/numkong/random.hpp +285 -0
- package/include/numkong/reduce/README.md +547 -0
- package/include/numkong/reduce/alder.h +632 -0
- package/include/numkong/reduce/genoa.h +201 -0
- package/include/numkong/reduce/haswell.h +3783 -0
- package/include/numkong/reduce/icelake.h +549 -0
- package/include/numkong/reduce/neon.h +3841 -0
- package/include/numkong/reduce/neonbfdot.h +353 -0
- package/include/numkong/reduce/neonfhm.h +665 -0
- package/include/numkong/reduce/neonhalf.h +157 -0
- package/include/numkong/reduce/neonsdot.h +357 -0
- package/include/numkong/reduce/rvv.h +3407 -0
- package/include/numkong/reduce/serial.h +757 -0
- package/include/numkong/reduce/sierra.h +338 -0
- package/include/numkong/reduce/skylake.h +3792 -0
- package/include/numkong/reduce/v128relaxed.h +2302 -0
- package/include/numkong/reduce.h +1597 -0
- package/include/numkong/reduce.hpp +633 -0
- package/include/numkong/scalar/README.md +89 -0
- package/include/numkong/scalar/haswell.h +113 -0
- package/include/numkong/scalar/neon.h +122 -0
- package/include/numkong/scalar/neonhalf.h +70 -0
- package/include/numkong/scalar/rvv.h +211 -0
- package/include/numkong/scalar/sapphire.h +63 -0
- package/include/numkong/scalar/serial.h +332 -0
- package/include/numkong/scalar/v128relaxed.h +56 -0
- package/include/numkong/scalar.h +683 -0
- package/include/numkong/set/README.md +179 -0
- package/include/numkong/set/haswell.h +334 -0
- package/include/numkong/set/icelake.h +485 -0
- package/include/numkong/set/neon.h +364 -0
- package/include/numkong/set/rvv.h +226 -0
- package/include/numkong/set/rvvbb.h +117 -0
- package/include/numkong/set/serial.h +174 -0
- package/include/numkong/set/sve.h +185 -0
- package/include/numkong/set/v128relaxed.h +240 -0
- package/include/numkong/set.h +457 -0
- package/include/numkong/set.hpp +114 -0
- package/include/numkong/sets/README.md +149 -0
- package/include/numkong/sets/haswell.h +63 -0
- package/include/numkong/sets/icelake.h +66 -0
- package/include/numkong/sets/neon.h +61 -0
- package/include/numkong/sets/serial.h +43 -0
- package/include/numkong/sets/smebi32.h +1099 -0
- package/include/numkong/sets/v128relaxed.h +58 -0
- package/include/numkong/sets.h +339 -0
- package/include/numkong/sparse/README.md +156 -0
- package/include/numkong/sparse/icelake.h +463 -0
- package/include/numkong/sparse/neon.h +288 -0
- package/include/numkong/sparse/serial.h +117 -0
- package/include/numkong/sparse/sve2.h +507 -0
- package/include/numkong/sparse/turin.h +322 -0
- package/include/numkong/sparse.h +363 -0
- package/include/numkong/sparse.hpp +113 -0
- package/include/numkong/spatial/README.md +435 -0
- package/include/numkong/spatial/alder.h +607 -0
- package/include/numkong/spatial/genoa.h +290 -0
- package/include/numkong/spatial/haswell.h +960 -0
- package/include/numkong/spatial/icelake.h +586 -0
- package/include/numkong/spatial/neon.h +773 -0
- package/include/numkong/spatial/neonbfdot.h +165 -0
- package/include/numkong/spatial/neonhalf.h +118 -0
- package/include/numkong/spatial/neonsdot.h +261 -0
- package/include/numkong/spatial/rvv.h +984 -0
- package/include/numkong/spatial/rvvbf16.h +123 -0
- package/include/numkong/spatial/rvvhalf.h +117 -0
- package/include/numkong/spatial/sapphire.h +343 -0
- package/include/numkong/spatial/serial.h +346 -0
- package/include/numkong/spatial/sierra.h +323 -0
- package/include/numkong/spatial/skylake.h +606 -0
- package/include/numkong/spatial/sve.h +224 -0
- package/include/numkong/spatial/svebfdot.h +122 -0
- package/include/numkong/spatial/svehalf.h +109 -0
- package/include/numkong/spatial/v128relaxed.h +717 -0
- package/include/numkong/spatial.h +1425 -0
- package/include/numkong/spatial.hpp +183 -0
- package/include/numkong/spatials/README.md +580 -0
- package/include/numkong/spatials/alder.h +94 -0
- package/include/numkong/spatials/genoa.h +94 -0
- package/include/numkong/spatials/haswell.h +219 -0
- package/include/numkong/spatials/icelake.h +113 -0
- package/include/numkong/spatials/neon.h +109 -0
- package/include/numkong/spatials/neonbfdot.h +60 -0
- package/include/numkong/spatials/neonfhm.h +92 -0
- package/include/numkong/spatials/neonhalf.h +58 -0
- package/include/numkong/spatials/neonsdot.h +109 -0
- package/include/numkong/spatials/rvv.h +1960 -0
- package/include/numkong/spatials/sapphireamx.h +1149 -0
- package/include/numkong/spatials/serial.h +226 -0
- package/include/numkong/spatials/sierra.h +96 -0
- package/include/numkong/spatials/skylake.h +184 -0
- package/include/numkong/spatials/sme.h +1901 -0
- package/include/numkong/spatials/smef64.h +465 -0
- package/include/numkong/spatials/v128relaxed.h +240 -0
- package/include/numkong/spatials.h +3021 -0
- package/include/numkong/spatials.hpp +508 -0
- package/include/numkong/tensor.hpp +1592 -0
- package/include/numkong/trigonometry/README.md +184 -0
- package/include/numkong/trigonometry/haswell.h +652 -0
- package/include/numkong/trigonometry/neon.h +639 -0
- package/include/numkong/trigonometry/rvv.h +699 -0
- package/include/numkong/trigonometry/serial.h +703 -0
- package/include/numkong/trigonometry/skylake.h +721 -0
- package/include/numkong/trigonometry/v128relaxed.h +666 -0
- package/include/numkong/trigonometry.h +467 -0
- package/include/numkong/trigonometry.hpp +166 -0
- package/include/numkong/types.h +1384 -0
- package/include/numkong/types.hpp +5603 -0
- package/include/numkong/vector.hpp +698 -0
- package/javascript/README.md +246 -0
- package/javascript/dist/cjs/numkong-wasm.d.ts +166 -0
- package/javascript/dist/cjs/numkong-wasm.js +617 -0
- package/javascript/dist/cjs/numkong.d.ts +343 -0
- package/javascript/dist/cjs/numkong.js +523 -0
- package/javascript/dist/cjs/package.json +3 -0
- package/javascript/dist/cjs/types.d.ts +284 -0
- package/javascript/dist/cjs/types.js +653 -0
- package/javascript/dist/esm/numkong-wasm.d.ts +166 -0
- package/javascript/dist/esm/numkong-wasm.js +595 -0
- package/javascript/dist/esm/numkong.d.ts +343 -0
- package/javascript/dist/esm/numkong.js +452 -0
- package/javascript/dist/esm/package.json +3 -0
- package/javascript/dist/esm/types.d.ts +284 -0
- package/javascript/dist/esm/types.js +630 -0
- package/javascript/dist-package-cjs.json +3 -0
- package/javascript/dist-package-esm.json +3 -0
- package/javascript/node-gyp-build.d.ts +1 -0
- package/javascript/numkong-wasm.ts +756 -0
- package/javascript/numkong.c +689 -0
- package/javascript/numkong.ts +575 -0
- package/javascript/tsconfig-base.json +39 -0
- package/javascript/tsconfig-cjs.json +8 -0
- package/javascript/tsconfig-esm.json +8 -0
- package/javascript/types.ts +674 -0
- package/package.json +87 -0
|
@@ -0,0 +1,563 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* @brief SIMD-accelerated Dot Products for Alder Lake.
|
|
3
|
+
* @file include/numkong/dot/alder.h
|
|
4
|
+
* @author Ash Vardanian
|
|
5
|
+
* @date March 4, 2026
|
|
6
|
+
*
|
|
7
|
+
* @sa include/numkong/dot.h
|
|
8
|
+
*
|
|
9
|
+
* @section dot_alder_instructions AVX-VNNI Instructions Performance
|
|
10
|
+
*
|
|
11
|
+
* Intrinsic Instruction Alder Lake Raptor Lake
|
|
12
|
+
* _mm256_dpbusd_epi32 VPDPBUSD (YMM, YMM, YMM) 4cy @ p05 4cy @ p05
|
|
13
|
+
* _mm256_madd_epi16 VPMADDWD (YMM, YMM, YMM) 4cy @ p05 4cy @ p05
|
|
14
|
+
* _mm256_sad_epu8 VPSADBW (YMM, YMM, YMM) 3cy @ p5 3cy @ p5
|
|
15
|
+
*
|
|
16
|
+
* Alder Lake and Raptor Lake support AVX-VNNI (256-bit VNNI)
|
|
17
|
+
* for accelerated integer dot products. This is the 256-bit variant of AVX-512 VNNI found on Ice Lake.
|
|
18
|
+
* We use VPDPBUSD for asymmetric unsigned×signed multiplication with algebraic transformations to
|
|
19
|
+
* handle signed×signed (i8) and unsigned×unsigned (u8) cases.
|
|
20
|
+
*
|
|
21
|
+
* Performance improvements over previous approaches:
|
|
22
|
+
* - i8×i8: 1.3-1.4× speedup using dpbusd with XOR transformation (a+128)×b - 128×sum(b)
|
|
23
|
+
* - u8×u8: 1.8-2.0× speedup using dpbusd with XOR transformation a×(b-128) + 128×sum(a)
|
|
24
|
+
* These match the speedups achieved on Ice Lake (AVX-512 VNNI) but with 256-bit vectors.
|
|
25
|
+
*
|
|
26
|
+
* @section dot_alder_stateful Stateful Streaming Logic
|
|
27
|
+
*
|
|
28
|
+
* To build memory-optimal tiled algorithms, this file defines following structures and force-inlined
|
|
29
|
+
* `NK_INTERNAL` functions:
|
|
30
|
+
*
|
|
31
|
+
* - nk_dot_i8x32 for 8-bit signed integer inputs using DPBUSD with algebraic transformation,
|
|
32
|
+
* - nk_dot_u8x32 for 8-bit unsigned integer inputs using DPBUSD with algebraic transformation.
|
|
33
|
+
*
|
|
34
|
+
* @code{c}
|
|
35
|
+
* nk_dot_i8x32_state_alder_t state_first, state_second, state_third, state_fourth;
|
|
36
|
+
* nk_b256_vec_t query_i8x32, target_first_i8x32, target_second_i8x32, target_third_i8x32, target_fourth_i8x32;
|
|
37
|
+
* nk_dot_i8x32_init_alder(&state_first);
|
|
38
|
+
* nk_dot_i8x32_init_alder(&state_second);
|
|
39
|
+
* nk_dot_i8x32_init_alder(&state_third);
|
|
40
|
+
* nk_dot_i8x32_init_alder(&state_fourth);
|
|
41
|
+
* for (nk_size_t idx = 0; idx + 32 <= depth; idx += 32) {
|
|
42
|
+
* query_i8x32.ymm = _mm256_loadu_si256(query_ptr + idx);
|
|
43
|
+
* target_first_i8x32.ymm = _mm256_loadu_si256(target_first_ptr + idx);
|
|
44
|
+
* target_second_i8x32.ymm = _mm256_loadu_si256(target_second_ptr + idx);
|
|
45
|
+
* target_third_i8x32.ymm = _mm256_loadu_si256(target_third_ptr + idx);
|
|
46
|
+
* target_fourth_i8x32.ymm = _mm256_loadu_si256(target_fourth_ptr + idx);
|
|
47
|
+
* nk_dot_i8x32_update_alder(&state_first, query_i8x32, target_first_i8x32, idx, 32);
|
|
48
|
+
* nk_dot_i8x32_update_alder(&state_second, query_i8x32, target_second_i8x32, idx, 32);
|
|
49
|
+
* nk_dot_i8x32_update_alder(&state_third, query_i8x32, target_third_i8x32, idx, 32);
|
|
50
|
+
* nk_dot_i8x32_update_alder(&state_fourth, query_i8x32, target_fourth_i8x32, idx, 32);
|
|
51
|
+
* }
|
|
52
|
+
* nk_b128_vec_t results_i32x4;
|
|
53
|
+
* nk_dot_i8x32_finalize_alder(&state_first, &state_second, &state_third, &state_fourth, depth, &results_i32x4);
|
|
54
|
+
* @endcode
|
|
55
|
+
*
|
|
56
|
+
* The unsigned variant follows the same pattern with appropriate type changes:
|
|
57
|
+
*
|
|
58
|
+
* @code{c}
|
|
59
|
+
* nk_dot_u8x32_state_alder_t state_first, state_second, state_third, state_fourth;
|
|
60
|
+
* nk_b256_vec_t query_u8x32, target_first_u8x32, target_second_u8x32, target_third_u8x32, target_fourth_u8x32;
|
|
61
|
+
* nk_dot_u8x32_init_alder(&state_first);
|
|
62
|
+
* nk_dot_u8x32_init_alder(&state_second);
|
|
63
|
+
* nk_dot_u8x32_init_alder(&state_third);
|
|
64
|
+
* nk_dot_u8x32_init_alder(&state_fourth);
|
|
65
|
+
* for (nk_size_t idx = 0; idx + 32 <= depth; idx += 32) {
|
|
66
|
+
* query_u8x32.ymm = _mm256_loadu_si256(query_ptr + idx);
|
|
67
|
+
* target_first_u8x32.ymm = _mm256_loadu_si256(target_first_ptr + idx);
|
|
68
|
+
* target_second_u8x32.ymm = _mm256_loadu_si256(target_second_ptr + idx);
|
|
69
|
+
* target_third_u8x32.ymm = _mm256_loadu_si256(target_third_ptr + idx);
|
|
70
|
+
* target_fourth_u8x32.ymm = _mm256_loadu_si256(target_fourth_ptr + idx);
|
|
71
|
+
* nk_dot_u8x32_update_alder(&state_first, query_u8x32, target_first_u8x32, idx, 32);
|
|
72
|
+
* nk_dot_u8x32_update_alder(&state_second, query_u8x32, target_second_u8x32, idx, 32);
|
|
73
|
+
* nk_dot_u8x32_update_alder(&state_third, query_u8x32, target_third_u8x32, idx, 32);
|
|
74
|
+
* nk_dot_u8x32_update_alder(&state_fourth, query_u8x32, target_fourth_u8x32, idx, 32);
|
|
75
|
+
* }
|
|
76
|
+
* nk_b128_vec_t results_u32x4;
|
|
77
|
+
* nk_dot_u8x32_finalize_alder(&state_first, &state_second, &state_third, &state_fourth, depth, &results_u32x4);
|
|
78
|
+
* @endcode
|
|
79
|
+
*/
|
|
80
|
+
#ifndef NK_DOT_ALDER_H
|
|
81
|
+
#define NK_DOT_ALDER_H
|
|
82
|
+
|
|
83
|
+
#if NK_TARGET_X86_
|
|
84
|
+
#if NK_TARGET_ALDER
|
|
85
|
+
|
|
86
|
+
#include "numkong/types.h"
|
|
87
|
+
#include "numkong/cast/serial.h" // `nk_partial_load_b8x32_serial_`
|
|
88
|
+
#include "numkong/reduce/haswell.h" // `nk_reduce_add_i32x8_haswell_`
|
|
89
|
+
|
|
90
|
+
#if defined(__cplusplus)
|
|
91
|
+
extern "C" {
|
|
92
|
+
#endif
|
|
93
|
+
|
|
94
|
+
// On GCC/Clang, VEX encoding is handled by target attributes.
|
|
95
|
+
// Alias the MSVC-specific _avx intrinsic names to standard names.
|
|
96
|
+
#if !defined(_MSC_VER)
|
|
97
|
+
#define _mm256_dpbusd_avx_epi32 _mm256_dpbusd_epi32
|
|
98
|
+
#define _mm256_dpwssd_avx_epi32 _mm256_dpwssd_epi32
|
|
99
|
+
#endif
|
|
100
|
+
|
|
101
|
+
#if defined(__clang__)
|
|
102
|
+
#pragma clang attribute push(__attribute__((target("avx2,f16c,fma,bmi,bmi2,avxvnni"))), apply_to = function)
|
|
103
|
+
#elif defined(__GNUC__)
|
|
104
|
+
#pragma GCC push_options
|
|
105
|
+
#pragma GCC target("avx2", "f16c", "fma", "bmi", "bmi2", "avxvnni")
|
|
106
|
+
#endif
|
|
107
|
+
|
|
108
|
+
NK_PUBLIC void nk_dot_i8_alder(nk_i8_t const *a_scalars, nk_i8_t const *b_scalars, nk_size_t count_scalars,
|
|
109
|
+
nk_i32_t *result) {
|
|
110
|
+
// Optimized i8×i8 dot product using algebraic transformation with DPBUSD
|
|
111
|
+
//
|
|
112
|
+
// Algebraic transformation:
|
|
113
|
+
// Let a' = a XOR 0x80 (interpreted as unsigned, gives a+128 mod 256)
|
|
114
|
+
// dpbusd(a', b) computes: (a+128) × b [unsigned × signed]
|
|
115
|
+
// Therefore: a×b = (a+128)×b - 128×sum(b)
|
|
116
|
+
//
|
|
117
|
+
// Where:
|
|
118
|
+
// - XOR with 0x80 converts signed i8 [-128,127] to unsigned [0,255]
|
|
119
|
+
// - dpbusd performs unsigned×signed multiply-accumulate
|
|
120
|
+
// - Correction term 128×sum(b) is computed and subtracted at the end
|
|
121
|
+
//
|
|
122
|
+
// Performance: ~1.3-1.4× speedup expected over cvtepi8_epi16 + dpwssd approach
|
|
123
|
+
// - Processes 32 elements/iteration (AVX2 width)
|
|
124
|
+
// - Lower latency per iteration: 4 cy (VPDPBUSD @ p05) vs 3+4 = 7 cy (VPMOVSXBW @ p5 + VPMADDWD @ p05)
|
|
125
|
+
// - Better port utilization: VPDPBUSD (p05) runs in parallel with VPMOVSXBW (p5) + VPMADDWD (p05) for
|
|
126
|
+
// correction term, enabling dual-issue execution on p0 and p5 simultaneously. Old approach bottlenecked
|
|
127
|
+
// on p5 for sign extension.
|
|
128
|
+
//
|
|
129
|
+
__m256i const xor_mask_u8x32 = _mm256_set1_epi8((char)0x80);
|
|
130
|
+
__m256i const zeros_u8x32 = _mm256_setzero_si256();
|
|
131
|
+
__m256i sum_ab_i32x8 = _mm256_setzero_si256();
|
|
132
|
+
__m256i sum_b_biased_i64x4 = _mm256_setzero_si256();
|
|
133
|
+
__m256i a_i8x32, b_i8x32;
|
|
134
|
+
nk_size_t total_elements = count_scalars;
|
|
135
|
+
|
|
136
|
+
nk_dot_i8_alder_cycle:
|
|
137
|
+
if (count_scalars < 32) {
|
|
138
|
+
nk_b256_vec_t a_vec, b_vec;
|
|
139
|
+
nk_partial_load_b8x32_serial_(a_scalars, &a_vec, count_scalars);
|
|
140
|
+
nk_partial_load_b8x32_serial_(b_scalars, &b_vec, count_scalars);
|
|
141
|
+
a_i8x32 = _mm256_load_si256(&a_vec.ymm);
|
|
142
|
+
b_i8x32 = _mm256_load_si256(&b_vec.ymm);
|
|
143
|
+
count_scalars = 0;
|
|
144
|
+
}
|
|
145
|
+
else {
|
|
146
|
+
a_i8x32 = _mm256_loadu_si256((__m256i const *)a_scalars);
|
|
147
|
+
b_i8x32 = _mm256_loadu_si256((__m256i const *)b_scalars);
|
|
148
|
+
a_scalars += 32, b_scalars += 32, count_scalars -= 32;
|
|
149
|
+
}
|
|
150
|
+
|
|
151
|
+
// Convert a to unsigned [0,255] by XOR with 0x80: a_unsigned = a + 128
|
|
152
|
+
__m256i a_unsigned_u8x32 = _mm256_xor_si256(a_i8x32, xor_mask_u8x32);
|
|
153
|
+
|
|
154
|
+
// Compute (a+128) × b using dpbusd: unsigned × signed
|
|
155
|
+
sum_ab_i32x8 = _mm256_dpbusd_avx_epi32(sum_ab_i32x8, a_unsigned_u8x32, b_i8x32);
|
|
156
|
+
|
|
157
|
+
// Accumulate sum(b+128) using SAD (replaces cvtepi8_epi16 + madd)
|
|
158
|
+
__m256i b_biased_u8x32 = _mm256_xor_si256(b_i8x32, xor_mask_u8x32);
|
|
159
|
+
sum_b_biased_i64x4 = _mm256_add_epi64(sum_b_biased_i64x4, _mm256_sad_epu8(b_biased_u8x32, zeros_u8x32));
|
|
160
|
+
|
|
161
|
+
if (count_scalars) goto nk_dot_i8_alder_cycle;
|
|
162
|
+
|
|
163
|
+
// Apply algebraic correction: a×b = (a+128)×b - 128×sum(b)
|
|
164
|
+
// With biased accumulator: sum(b) = sum_b_biased - 128×count
|
|
165
|
+
// So: correction = 128×sum(b) = 128×sum_b_biased - 16384×count
|
|
166
|
+
nk_i32_t ab_sum = nk_reduce_add_i32x8_haswell_(sum_ab_i32x8);
|
|
167
|
+
nk_i64_t sum_b_biased = nk_reduce_add_i64x4_haswell_(sum_b_biased_i64x4);
|
|
168
|
+
nk_size_t elements_rounded = nk_size_round_up_to_multiple_(total_elements, 32);
|
|
169
|
+
nk_i64_t correction = 128LL * sum_b_biased - 16384LL * (nk_i64_t)elements_rounded;
|
|
170
|
+
|
|
171
|
+
*result = (nk_i32_t)(ab_sum - correction);
|
|
172
|
+
}
|
|
173
|
+
|
|
174
|
+
/**
 *  @brief Streaming state for 32-lane i8 dot products on Alder Lake.
 *
 *  Holds only the DPBUSD accumulator of biased products (a+128)×b; the
 *  −128×Σb compensation is applied once in `nk_dot_i8x32_finalize_alder`.
 */
typedef struct nk_dot_i8x32_state_alder_t {
    __m256i biased_product_sum_i32x8; // Single accumulator: (a+128)×b, correction applied at finalize
} nk_dot_i8x32_state_alder_t;
|
|
177
|
+
|
|
178
|
+
/** @brief Resets the streaming i8 accumulator before a new dot-product pass. */
NK_INTERNAL void nk_dot_i8x32_init_alder(nk_dot_i8x32_state_alder_t *state) {
    state->biased_product_sum_i32x8 = _mm256_setzero_si256();
}
|
|
181
|
+
|
|
182
|
+
/**
 *  @brief Accumulates one 32-lane i8 step: (a+128)×b via DPBUSD.
 *
 *  The XOR with 0x80 flips the sign bit, turning signed `a` lanes into their
 *  unsigned biased form so the unsigned×signed DPBUSD applies. The depth
 *  offset and active-dimension count are irrelevant for a plain running sum.
 */
NK_INTERNAL void nk_dot_i8x32_update_alder(nk_dot_i8x32_state_alder_t *state, nk_b256_vec_t a, nk_b256_vec_t b,
                                           nk_size_t depth_offset, nk_size_t active_dimensions) {
    nk_unused_(depth_offset);
    nk_unused_(active_dimensions);
    __m256i const bias_flip_u8x32 = _mm256_set1_epi8((char)0x80);
    __m256i biased_a_u8x32 = _mm256_xor_si256(a.ymm, bias_flip_u8x32);
    state->biased_product_sum_i32x8 = _mm256_dpbusd_avx_epi32(state->biased_product_sum_i32x8, biased_a_u8x32, b.ymm);
}
|
|
189
|
+
|
|
190
|
+
/**
 *  @brief Folds four streaming i8 dot-product states into 4 × i32 results.
 *
 *  Each state holds a biased accumulator Σ(a+128)×b; this routine reduces each
 *  one 256→128 bits, transposes the four 4-lane partial sums so that lane i of
 *  the output gathers all partials of state i, and subtracts the 128×Σb
 *  compensation using the caller-provided B column sums.
 *
 *  @param total_dimensions Unused here — the compensation depends only on Σb.
 *  @param a_sum Unused for the i8 variant (signature shared with other types).
 *  @param b_sums 4 × i32 column sums of B, one per state.
 *  @param results Receives 4 × i32 corrected dot products.
 */
NK_INTERNAL void nk_dot_i8x32_finalize_alder( //
    nk_dot_i8x32_state_alder_t const *state_a, nk_dot_i8x32_state_alder_t const *state_b, //
    nk_dot_i8x32_state_alder_t const *state_c, nk_dot_i8x32_state_alder_t const *state_d, //
    nk_size_t total_dimensions, //
    nk_i32_t a_sum, /* A row sum (unused for i8) */ //
    nk_b128_vec_t b_sums, /* 4 × i32 B column sums */ //
    nk_b128_vec_t *results) {
    nk_unused_(total_dimensions);
    nk_unused_(a_sum);

    // Reduce biased products: ymm (i32x8) → xmm (i32x4)
    __m128i sum_a_i32x4 = _mm_add_epi32(_mm256_castsi256_si128(state_a->biased_product_sum_i32x8),
                                        _mm256_extracti128_si256(state_a->biased_product_sum_i32x8, 1));
    __m128i sum_b_i32x4 = _mm_add_epi32(_mm256_castsi256_si128(state_b->biased_product_sum_i32x8),
                                        _mm256_extracti128_si256(state_b->biased_product_sum_i32x8, 1));
    __m128i sum_c_i32x4 = _mm_add_epi32(_mm256_castsi256_si128(state_c->biased_product_sum_i32x8),
                                        _mm256_extracti128_si256(state_c->biased_product_sum_i32x8, 1));
    __m128i sum_d_i32x4 = _mm_add_epi32(_mm256_castsi256_si128(state_d->biased_product_sum_i32x8),
                                        _mm256_extracti128_si256(state_d->biased_product_sum_i32x8, 1));

    // 4-way transpose reduce: after the unpack pairs, each unpacklo/hi_epi64
    // term carries one 32-bit partial from every state, so the two adds leave
    // lane i holding the full horizontal sum of state i.
    __m128i t_ab_lo = _mm_unpacklo_epi32(sum_a_i32x4, sum_b_i32x4);
    __m128i t_cd_lo = _mm_unpacklo_epi32(sum_c_i32x4, sum_d_i32x4);
    __m128i t_ab_hi = _mm_unpackhi_epi32(sum_a_i32x4, sum_b_i32x4);
    __m128i t_cd_hi = _mm_unpackhi_epi32(sum_c_i32x4, sum_d_i32x4);
    __m128i biased_i32x4 = _mm_add_epi32(
        _mm_add_epi32(_mm_unpacklo_epi64(t_ab_lo, t_cd_lo), _mm_unpackhi_epi64(t_ab_lo, t_cd_lo)),
        _mm_add_epi32(_mm_unpacklo_epi64(t_ab_hi, t_cd_hi), _mm_unpackhi_epi64(t_ab_hi, t_cd_hi)));

    // Apply compensation: result = biased − 128 × Σb
    __m128i correction_i32x4 = _mm_slli_epi32(b_sums.xmm, 7); // × 128
    results->xmm = _mm_sub_epi32(biased_i32x4, correction_i32x4);
}
|
|
223
|
+
|
|
224
|
+
NK_PUBLIC void nk_dot_u8_alder(nk_u8_t const *a_scalars, nk_u8_t const *b_scalars, nk_size_t count_scalars,
|
|
225
|
+
nk_u32_t *result) {
|
|
226
|
+
// Optimized u8×u8 dot product using algebraic transformation with DPBUSD
|
|
227
|
+
//
|
|
228
|
+
// Algebraic transformation:
|
|
229
|
+
// Let b' = b XOR 0x80 (converts unsigned to signed: b' = b - 128)
|
|
230
|
+
// dpbusd(a, b') computes: a × (b-128) [unsigned × signed]
|
|
231
|
+
// Therefore: a×b = a×(b-128) + 128×sum(a)
|
|
232
|
+
//
|
|
233
|
+
// Where:
|
|
234
|
+
// - XOR with 0x80 converts unsigned u8 [0,255] to signed [-128,127]
|
|
235
|
+
// - dpbusd performs unsigned×signed multiply-accumulate
|
|
236
|
+
// - sad_epu8 computes sum(a) as correction term
|
|
237
|
+
// - Correction term 128×sum(a) is added at the end
|
|
238
|
+
//
|
|
239
|
+
// Performance: ~1.8-2.0× speedup expected over unpack + dpwssd approach
|
|
240
|
+
// - Processes 32 elements/iteration
|
|
241
|
+
// - Lower latency per iteration
|
|
242
|
+
// - Eliminates unpack operations
|
|
243
|
+
// - dpbusd runs in parallel with sad
|
|
244
|
+
//
|
|
245
|
+
__m256i const xor_mask_u8x32 = _mm256_set1_epi8((char)0x80);
|
|
246
|
+
__m256i const zeros_u8x32 = _mm256_setzero_si256();
|
|
247
|
+
__m256i sum_ab_i32x8 = _mm256_setzero_si256();
|
|
248
|
+
__m256i sum_a_i64x4 = _mm256_setzero_si256();
|
|
249
|
+
__m256i a_u8x32, b_u8x32;
|
|
250
|
+
|
|
251
|
+
nk_dot_u8_alder_cycle:
|
|
252
|
+
if (count_scalars < 32) {
|
|
253
|
+
nk_b256_vec_t a_vec, b_vec;
|
|
254
|
+
nk_partial_load_b8x32_serial_(a_scalars, &a_vec, count_scalars);
|
|
255
|
+
nk_partial_load_b8x32_serial_(b_scalars, &b_vec, count_scalars);
|
|
256
|
+
a_u8x32 = _mm256_load_si256(&a_vec.ymm);
|
|
257
|
+
b_u8x32 = _mm256_load_si256(&b_vec.ymm);
|
|
258
|
+
count_scalars = 0;
|
|
259
|
+
}
|
|
260
|
+
else {
|
|
261
|
+
a_u8x32 = _mm256_loadu_si256((__m256i const *)a_scalars);
|
|
262
|
+
b_u8x32 = _mm256_loadu_si256((__m256i const *)b_scalars);
|
|
263
|
+
a_scalars += 32, b_scalars += 32, count_scalars -= 32;
|
|
264
|
+
}
|
|
265
|
+
|
|
266
|
+
// Convert b to signed [-128,127] by XOR with 0x80: b_signed = b - 128
|
|
267
|
+
__m256i b_signed_i8x32 = _mm256_xor_si256(b_u8x32, xor_mask_u8x32);
|
|
268
|
+
|
|
269
|
+
// Compute a × (b-128) using dpbusd: unsigned × signed
|
|
270
|
+
sum_ab_i32x8 = _mm256_dpbusd_avx_epi32(sum_ab_i32x8, a_u8x32, b_signed_i8x32);
|
|
271
|
+
|
|
272
|
+
// Accumulate sum(a) for correction term using sad_epu8 (1cy @ p5)
|
|
273
|
+
sum_a_i64x4 = _mm256_add_epi64(sum_a_i64x4, _mm256_sad_epu8(a_u8x32, zeros_u8x32));
|
|
274
|
+
|
|
275
|
+
if (count_scalars) goto nk_dot_u8_alder_cycle;
|
|
276
|
+
|
|
277
|
+
// Apply algebraic correction: a×b = a×(b-128) + 128×sum(a)
|
|
278
|
+
nk_i32_t ab_dot_signed = nk_reduce_add_i32x8_haswell_(sum_ab_i32x8);
|
|
279
|
+
|
|
280
|
+
// Reduce sum_a from 4 i64 values to scalar
|
|
281
|
+
__m128i sum_a_low_i64x2 = _mm256_castsi256_si128(sum_a_i64x4);
|
|
282
|
+
__m128i sum_a_high_i64x2 = _mm256_extracti128_si256(sum_a_i64x4, 1);
|
|
283
|
+
__m128i sum_a_i64x2 = _mm_add_epi64(sum_a_low_i64x2, sum_a_high_i64x2);
|
|
284
|
+
__m128i sum_a_shuffled = _mm_shuffle_epi32(sum_a_i64x2, _MM_SHUFFLE(1, 0, 3, 2));
|
|
285
|
+
__m128i sum_a_final = _mm_add_epi64(sum_a_i64x2, sum_a_shuffled);
|
|
286
|
+
nk_i64_t sum_a = _mm_cvtsi128_si64(sum_a_final);
|
|
287
|
+
|
|
288
|
+
nk_i64_t correction = 128LL * sum_a;
|
|
289
|
+
|
|
290
|
+
*result = (nk_u32_t)(ab_dot_signed + correction);
|
|
291
|
+
}
|
|
292
|
+
|
|
293
|
+
/**
 *  @brief Streaming state for 32-lane u8 dot products on Alder Lake.
 *
 *  Holds only the DPBUSD accumulator of operand-swapped biased products
 *  b·(a−128); the +128×Σb compensation is applied once at finalize.
 */
typedef struct nk_dot_u8x32_state_alder_t {
    __m256i biased_product_sum_i32x8; // Single accumulator: DPBUSD(b, a^0x80), correction applied at finalize
} nk_dot_u8x32_state_alder_t;
|
|
296
|
+
|
|
297
|
+
/** @brief Resets the streaming u8 accumulator before a new dot-product pass. */
NK_INTERNAL void nk_dot_u8x32_init_alder(nk_dot_u8x32_state_alder_t *state) {
    state->biased_product_sum_i32x8 = _mm256_setzero_si256();
}
|
|
300
|
+
|
|
301
|
+
/**
 *  @brief Accumulates one 32-lane u8 step via operand-swapped DPBUSD.
 *
 *  DPBUSD requires an unsigned first operand, so the operands are swapped and
 *  `a`'s sign bit is flipped: DPBUSD(b, a^0x80) = b·(a−128). The missing
 *  +128·Σb term is restored in the finalize step. Depth offset and the
 *  active-dimension count are irrelevant for a plain running sum.
 */
NK_INTERNAL void nk_dot_u8x32_update_alder(nk_dot_u8x32_state_alder_t *state, nk_b256_vec_t a, nk_b256_vec_t b,
                                           nk_size_t depth_offset, nk_size_t active_dimensions) {
    nk_unused_(depth_offset);
    nk_unused_(active_dimensions);
    __m256i const bias_flip_u8x32 = _mm256_set1_epi8((char)0x80);
    __m256i a_minus_128_i8x32 = _mm256_xor_si256(a.ymm, bias_flip_u8x32);
    state->biased_product_sum_i32x8 = _mm256_dpbusd_avx_epi32(state->biased_product_sum_i32x8, b.ymm, a_minus_128_i8x32);
}
|
|
309
|
+
|
|
310
|
+
NK_INTERNAL void nk_dot_u8x32_finalize_alder( //
|
|
311
|
+
nk_dot_u8x32_state_alder_t const *state_a, nk_dot_u8x32_state_alder_t const *state_b, //
|
|
312
|
+
nk_dot_u8x32_state_alder_t const *state_c, nk_dot_u8x32_state_alder_t const *state_d, //
|
|
313
|
+
nk_size_t total_dimensions, //
|
|
314
|
+
nk_i32_t a_sum, /* A row sum (unused for u8) */ //
|
|
315
|
+
nk_b128_vec_t b_sums, /* 4 × u32 B column sums */ //
|
|
316
|
+
nk_b128_vec_t *result) {
|
|
317
|
+
nk_unused_(total_dimensions);
|
|
318
|
+
nk_unused_(a_sum);
|
|
319
|
+
|
|
320
|
+
// Reduce biased products: ymm (i32x8) → xmm (i32x4)
|
|
321
|
+
__m128i sum_a_i32x4 = _mm_add_epi32(_mm256_castsi256_si128(state_a->biased_product_sum_i32x8),
|
|
322
|
+
_mm256_extracti128_si256(state_a->biased_product_sum_i32x8, 1));
|
|
323
|
+
__m128i sum_b_i32x4 = _mm_add_epi32(_mm256_castsi256_si128(state_b->biased_product_sum_i32x8),
|
|
324
|
+
_mm256_extracti128_si256(state_b->biased_product_sum_i32x8, 1));
|
|
325
|
+
__m128i sum_c_i32x4 = _mm_add_epi32(_mm256_castsi256_si128(state_c->biased_product_sum_i32x8),
|
|
326
|
+
_mm256_extracti128_si256(state_c->biased_product_sum_i32x8, 1));
|
|
327
|
+
__m128i sum_d_i32x4 = _mm_add_epi32(_mm256_castsi256_si128(state_d->biased_product_sum_i32x8),
|
|
328
|
+
_mm256_extracti128_si256(state_d->biased_product_sum_i32x8, 1));
|
|
329
|
+
|
|
330
|
+
// 4-way transpose reduce
|
|
331
|
+
__m128i t_ab_lo = _mm_unpacklo_epi32(sum_a_i32x4, sum_b_i32x4);
|
|
332
|
+
__m128i t_cd_lo = _mm_unpacklo_epi32(sum_c_i32x4, sum_d_i32x4);
|
|
333
|
+
__m128i t_ab_hi = _mm_unpackhi_epi32(sum_a_i32x4, sum_b_i32x4);
|
|
334
|
+
__m128i t_cd_hi = _mm_unpackhi_epi32(sum_c_i32x4, sum_d_i32x4);
|
|
335
|
+
__m128i biased_i32x4 = _mm_add_epi32(
|
|
336
|
+
_mm_add_epi32(_mm_unpacklo_epi64(t_ab_lo, t_cd_lo), _mm_unpackhi_epi64(t_ab_lo, t_cd_lo)),
|
|
337
|
+
_mm_add_epi32(_mm_unpacklo_epi64(t_ab_hi, t_cd_hi), _mm_unpackhi_epi64(t_ab_hi, t_cd_hi)));
|
|
338
|
+
|
|
339
|
+
// Apply compensation: result = biased + 128 × Σb
|
|
340
|
+
__m128i correction_i32x4 = _mm_slli_epi32(b_sums.xmm, 7); // × 128
|
|
341
|
+
result->xmm = _mm_add_epi32(biased_i32x4, correction_i32x4);
|
|
342
|
+
}
|
|
343
|
+
|
|
344
|
+
/**
 * Stateful element-sum helpers for compensated symmetric GEMM.
 * SAD runs on port 5 while DPBUSD runs on ports 0+1 — zero throughput cost when inlined.
 */
|
|
348
|
+
|
|
349
|
+
/* i8x32: signed i8 sum via XOR→unsigned + SAD, bias-corrected at finalize */
typedef struct nk_sum_i8x32_state_alder_t {
    /* Accumulates SAD of (v ^ 0x80); finalize subtracts 128 × count to undo the bias */
    __m256i biased_sum_u64x4;
} nk_sum_i8x32_state_alder_t;
|
|
353
|
+
|
|
354
|
+
NK_INTERNAL void nk_sum_i8x32_init_alder(nk_sum_i8x32_state_alder_t *state) {
    /* Nothing accumulated yet; the 128·count offset is removed in finalize. */
    __m256i const zero_u64x4 = _mm256_setzero_si256();
    state->biased_sum_u64x4 = zero_u64x4;
}
|
|
357
|
+
NK_INTERNAL void nk_sum_i8x32_update_alder(nk_sum_i8x32_state_alder_t *state, nk_b256_vec_t vector) {
    /* XOR with 0x80 maps each signed byte onto an unsigned one (v + 128);
       SAD against zero then yields the sum of those unsigned bytes, spread
       across four u64 lanes. */
    __m256i const bias_flip_u8x32 = _mm256_set1_epi8((char)0x80);
    __m256i biased_u8x32 = _mm256_xor_si256(vector.ymm, bias_flip_u8x32);
    __m256i lane_sums_u64x4 = _mm256_sad_epu8(biased_u8x32, _mm256_setzero_si256());
    state->biased_sum_u64x4 = _mm256_add_epi64(state->biased_sum_u64x4, lane_sums_u64x4);
}
|
|
363
|
+
NK_INTERNAL nk_i32_t nk_sum_i8x32_finalize_alder(nk_sum_i8x32_state_alder_t const *state, nk_size_t count) {
    /* Collapse the four u64 SAD lanes into a single scalar. */
    __m128i folded_u64x2 = _mm_add_epi64(_mm256_castsi256_si128(state->biased_sum_u64x4),
                                         _mm256_extracti128_si256(state->biased_sum_u64x4, 1));
    __m128i swapped_u64x2 = _mm_shuffle_epi32(folded_u64x2, _MM_SHUFFLE(1, 0, 3, 2));
    __m128i total_u64x2 = _mm_add_epi64(folded_u64x2, swapped_u64x2);
    nk_u64_t biased_total = (nk_u64_t)_mm_cvtsi128_si64(total_u64x2);
    /* Remove the XOR bias: each of the `count` bytes was shifted up by 128. */
    return (nk_i32_t)((nk_i64_t)biased_total - 128 * (nk_i64_t)count);
}
|
|
374
|
+
|
|
375
|
+
/* u8x32: unsigned u8 sum via plain SAD — no bias correction needed */
typedef struct nk_sum_u8x32_state_alder_t {
    /* Direct SAD accumulator: four u64 lanes, each summing 8 bytes per update */
    __m256i sum_u64x4;
} nk_sum_u8x32_state_alder_t;
|
|
379
|
+
|
|
380
|
+
NK_INTERNAL void nk_sum_u8x32_init_alder(nk_sum_u8x32_state_alder_t *state) {
    /* Fresh accumulator — nothing summed yet. */
    __m256i const zero_u64x4 = _mm256_setzero_si256();
    state->sum_u64x4 = zero_u64x4;
}
|
|
383
|
+
NK_INTERNAL void nk_sum_u8x32_update_alder(nk_sum_u8x32_state_alder_t *state, nk_b256_vec_t vector) {
    /* SAD against an all-zero vector sums each group of 8 bytes into a u64 lane. */
    __m256i lane_sums_u64x4 = _mm256_sad_epu8(vector.ymm, _mm256_setzero_si256());
    state->sum_u64x4 = _mm256_add_epi64(state->sum_u64x4, lane_sums_u64x4);
}
|
|
387
|
+
NK_INTERNAL nk_u32_t nk_sum_u8x32_finalize_alder(nk_sum_u8x32_state_alder_t const *state, nk_size_t count) {
    nk_unused_(count);
    /* Horizontal reduction of the four u64 SAD lanes into one scalar. */
    __m128i folded_u64x2 = _mm_add_epi64(_mm256_castsi256_si128(state->sum_u64x4),
                                         _mm256_extracti128_si256(state->sum_u64x4, 1));
    __m128i swapped_u64x2 = _mm_shuffle_epi32(folded_u64x2, _MM_SHUFFLE(1, 0, 3, 2));
    __m128i total_u64x2 = _mm_add_epi64(folded_u64x2, swapped_u64x2);
    return (nk_u32_t)_mm_cvtsi128_si64(total_u64x2);
}
|
|
396
|
+
|
|
397
|
+
NK_PUBLIC void nk_dot_e2m3_alder(nk_e2m3_t const *a_scalars, nk_e2m3_t const *b_scalars, nk_size_t count_scalars,
                                 nk_f32_t *result) {
    // Integer dot product for e2m3 using dual-VPSHUFB (LUT) + VPDPBUSD (unsigned×signed).
    // Every e2m3 value × 16 is an exact integer in [-120, +120].
    // Result = i32_dot / 256.0f (exact, no rounding error).
    //
    // This is the Alder Lake (256-bit AVX-VNNI) variant of the Ice Lake kernel.
    // DPBUSD replaces MADDUBS+MADD (2 instructions → 1), accumulating u8×i8→i32 directly.
    //
    // LUTs map the 5-bit magnitude to |value|×16: indices 0..15 hit `lut_lower`,
    // 16..31 hit `lut_upper` (bit 4 selects which half via blendv).
    __m256i const lut_lower_u8x32 = _mm256_set_epi8( //
        30, 28, 26, 24, 22, 20, 18, 16, 14, 12, 10, 8, 6, 4, 2, 0, //
        30, 28, 26, 24, 22, 20, 18, 16, 14, 12, 10, 8, 6, 4, 2, 0);
    __m256i const lut_upper_u8x32 = _mm256_set_epi8( //
        120, 112, 104, 96, 88, 80, 72, 64, 60, 56, 52, 48, 44, 40, 36, 32, //
        120, 112, 104, 96, 88, 80, 72, 64, 60, 56, 52, 48, 44, 40, 36, 32);
    __m256i const nibble_mask_u8x32 = _mm256_set1_epi8(0x0F);    // low 4 magnitude bits → shuffle index
    __m256i const magnitude_mask_u8x32 = _mm256_set1_epi8(0x1F); // 5-bit magnitude (exp + mantissa)
    __m256i const half_select_u8x32 = _mm256_set1_epi8(0x10);    // bit 4: lower vs upper LUT half
    __m256i const sign_mask_u8x32 = _mm256_set1_epi8(0x20);      // e2m3 sign bit
    __m256i sum_i32x8 = _mm256_setzero_si256();
    __m256i a_e2m3_u8x32, b_e2m3_u8x32;

nk_dot_e2m3_alder_cycle:
    if (count_scalars < 32) {
        // Tail: zero-padded partial load, then terminate the loop after this pass.
        nk_b256_vec_t a_vec, b_vec;
        nk_partial_load_b8x32_serial_(a_scalars, &a_vec, count_scalars);
        nk_partial_load_b8x32_serial_(b_scalars, &b_vec, count_scalars);
        a_e2m3_u8x32 = a_vec.ymm;
        b_e2m3_u8x32 = b_vec.ymm;
        count_scalars = 0;
    }
    else {
        a_e2m3_u8x32 = _mm256_loadu_si256((__m256i const *)a_scalars);
        b_e2m3_u8x32 = _mm256_loadu_si256((__m256i const *)b_scalars);
        a_scalars += 32, b_scalars += 32, count_scalars -= 32;
    }

    // Extract 5-bit magnitude, then split into low 4 bits (VPSHUFB index) and bit 4 (hi/lo select)
    __m256i a_magnitude_u8x32 = _mm256_and_si256(a_e2m3_u8x32, magnitude_mask_u8x32);
    __m256i b_magnitude_u8x32 = _mm256_and_si256(b_e2m3_u8x32, magnitude_mask_u8x32);
    __m256i a_shuffle_index_u8x32 = _mm256_and_si256(a_magnitude_u8x32, nibble_mask_u8x32);
    __m256i b_shuffle_index_u8x32 = _mm256_and_si256(b_magnitude_u8x32, nibble_mask_u8x32);
    // cmpeq yields all-ones where bit 4 is set — a valid blendv mask (MSB-driven).
    __m256i a_upper_select_u8x32 = _mm256_cmpeq_epi8(_mm256_and_si256(a_magnitude_u8x32, half_select_u8x32),
                                                     half_select_u8x32);
    __m256i b_upper_select_u8x32 = _mm256_cmpeq_epi8(_mm256_and_si256(b_magnitude_u8x32, half_select_u8x32),
                                                     half_select_u8x32);

    // Dual VPSHUFB: lookup in both halves, blend based on bit 4
    __m256i a_unsigned_u8x32 = _mm256_blendv_epi8(_mm256_shuffle_epi8(lut_lower_u8x32, a_shuffle_index_u8x32),
                                                  _mm256_shuffle_epi8(lut_upper_u8x32, a_shuffle_index_u8x32),
                                                  a_upper_select_u8x32);
    __m256i b_unsigned_u8x32 = _mm256_blendv_epi8(_mm256_shuffle_epi8(lut_lower_u8x32, b_shuffle_index_u8x32),
                                                  _mm256_shuffle_epi8(lut_upper_u8x32, b_shuffle_index_u8x32),
                                                  b_upper_select_u8x32);

    // Combined sign: (a ^ b) & 0x20 is non-zero exactly when signs differ; negate b
    // there so the product sign moves entirely into the signed operand of DPBUSD.
    __m256i sign_combined_u8x32 = _mm256_and_si256(_mm256_xor_si256(a_e2m3_u8x32, b_e2m3_u8x32), sign_mask_u8x32);
    __m256i negate_mask_u8x32 = _mm256_cmpeq_epi8(sign_combined_u8x32, sign_mask_u8x32);
    __m256i b_negated_u8x32 = _mm256_sub_epi8(_mm256_setzero_si256(), b_unsigned_u8x32);
    __m256i b_signed_i8x32 = _mm256_blendv_epi8(b_unsigned_u8x32, b_negated_u8x32, negate_mask_u8x32);

    // VPDPBUSD: a_unsigned[u8] × b_signed[i8] → i32 (replaces VPMADDUBSW + VPMADDWD)
    sum_i32x8 = _mm256_dpbusd_avx_epi32(sum_i32x8, a_unsigned_u8x32, b_signed_i8x32);

    if (count_scalars) goto nk_dot_e2m3_alder_cycle;
    // Both operands were scaled by 16, so the integer dot is the true dot × 256.
    *result = (nk_f32_t)nk_reduce_add_i32x8_haswell_(sum_i32x8) / 256.0f;
}
|
|
464
|
+
|
|
465
|
+
typedef struct nk_dot_e2m3x32_state_alder_t {
    // DPBUSD accumulator: u8 magnitude × i8 signed → i32; holds the dot × 256
    // (both operands LUT-scaled by 16), converted to f32 with ÷256 at finalize.
    __m256i sum_i32x8;
} nk_dot_e2m3x32_state_alder_t;
|
|
468
|
+
|
|
469
|
+
NK_INTERNAL void nk_dot_e2m3x32_init_alder(nk_dot_e2m3x32_state_alder_t *state) {
    // Fresh accumulator: no products folded in yet.
    __m256i const zero_i32x8 = _mm256_setzero_si256();
    state->sum_i32x8 = zero_i32x8;
}
|
|
472
|
+
|
|
473
|
+
NK_INTERNAL void nk_dot_e2m3x32_update_alder(nk_dot_e2m3x32_state_alder_t *state, nk_b256_vec_t a, nk_b256_vec_t b,
                                             nk_size_t depth_offset, nk_size_t active_dimensions) {
    nk_unused_(depth_offset);
    nk_unused_(active_dimensions);
    // Stateful variant of the e2m3 dot kernel: same dual-VPSHUFB LUT + VPDPBUSD
    // scheme as nk_dot_e2m3_alder, one 32-element step per call.
    // LUTs map the 5-bit magnitude to |value|×16 (indices 0..15 → lower, 16..31 → upper).
    __m256i const lut_lower_u8x32 = _mm256_set_epi8( //
        30, 28, 26, 24, 22, 20, 18, 16, 14, 12, 10, 8, 6, 4, 2, 0, //
        30, 28, 26, 24, 22, 20, 18, 16, 14, 12, 10, 8, 6, 4, 2, 0);
    __m256i const lut_upper_u8x32 = _mm256_set_epi8( //
        120, 112, 104, 96, 88, 80, 72, 64, 60, 56, 52, 48, 44, 40, 36, 32, //
        120, 112, 104, 96, 88, 80, 72, 64, 60, 56, 52, 48, 44, 40, 36, 32);
    __m256i const nibble_mask_u8x32 = _mm256_set1_epi8(0x0F);    // low 4 magnitude bits → shuffle index
    __m256i const magnitude_mask_u8x32 = _mm256_set1_epi8(0x1F); // 5-bit magnitude (exp + mantissa)
    __m256i const half_select_u8x32 = _mm256_set1_epi8(0x10);    // bit 4: lower vs upper LUT half
    __m256i const sign_mask_u8x32 = _mm256_set1_epi8(0x20);      // e2m3 sign bit

    __m256i a_e2m3_u8x32 = a.ymm;
    __m256i b_e2m3_u8x32 = b.ymm;

    // Extract 5-bit magnitude, split into low 4 bits and bit 4
    __m256i a_magnitude_u8x32 = _mm256_and_si256(a_e2m3_u8x32, magnitude_mask_u8x32);
    __m256i b_magnitude_u8x32 = _mm256_and_si256(b_e2m3_u8x32, magnitude_mask_u8x32);
    __m256i a_shuffle_index_u8x32 = _mm256_and_si256(a_magnitude_u8x32, nibble_mask_u8x32);
    __m256i b_shuffle_index_u8x32 = _mm256_and_si256(b_magnitude_u8x32, nibble_mask_u8x32);
    // cmpeq produces all-ones where bit 4 is set — a valid MSB-driven blendv mask.
    __m256i a_upper_select_u8x32 = _mm256_cmpeq_epi8(_mm256_and_si256(a_magnitude_u8x32, half_select_u8x32),
                                                     half_select_u8x32);
    __m256i b_upper_select_u8x32 = _mm256_cmpeq_epi8(_mm256_and_si256(b_magnitude_u8x32, half_select_u8x32),
                                                     half_select_u8x32);

    // Dual VPSHUFB + blend
    __m256i a_unsigned_u8x32 = _mm256_blendv_epi8(_mm256_shuffle_epi8(lut_lower_u8x32, a_shuffle_index_u8x32),
                                                  _mm256_shuffle_epi8(lut_upper_u8x32, a_shuffle_index_u8x32),
                                                  a_upper_select_u8x32);
    __m256i b_unsigned_u8x32 = _mm256_blendv_epi8(_mm256_shuffle_epi8(lut_lower_u8x32, b_shuffle_index_u8x32),
                                                  _mm256_shuffle_epi8(lut_upper_u8x32, b_shuffle_index_u8x32),
                                                  b_upper_select_u8x32);

    // Combined sign + conditional negate: product sign moves into the signed operand
    __m256i sign_combined_u8x32 = _mm256_and_si256(_mm256_xor_si256(a_e2m3_u8x32, b_e2m3_u8x32), sign_mask_u8x32);
    __m256i negate_mask_u8x32 = _mm256_cmpeq_epi8(sign_combined_u8x32, sign_mask_u8x32);
    __m256i b_negated_u8x32 = _mm256_sub_epi8(_mm256_setzero_si256(), b_unsigned_u8x32);
    __m256i b_signed_i8x32 = _mm256_blendv_epi8(b_unsigned_u8x32, b_negated_u8x32, negate_mask_u8x32);

    // VPDPBUSD: u8 × i8 → i32
    state->sum_i32x8 = _mm256_dpbusd_avx_epi32(state->sum_i32x8, a_unsigned_u8x32, b_signed_i8x32);
}
|
|
518
|
+
|
|
519
|
+
NK_INTERNAL void nk_dot_e2m3x32_finalize_alder( //
|
|
520
|
+
nk_dot_e2m3x32_state_alder_t const *state_a, nk_dot_e2m3x32_state_alder_t const *state_b, //
|
|
521
|
+
nk_dot_e2m3x32_state_alder_t const *state_c, nk_dot_e2m3x32_state_alder_t const *state_d, //
|
|
522
|
+
nk_size_t total_dimensions, nk_b128_vec_t *results) {
|
|
523
|
+
nk_unused_(total_dimensions);
|
|
524
|
+
|
|
525
|
+
// ILP-optimized 4-way horizontal reduction: i32x8 → scalar i32, then → f32 with ÷256
|
|
526
|
+
__m128i sum_a_i32x4 = _mm_add_epi32(_mm256_castsi256_si128(state_a->sum_i32x8),
|
|
527
|
+
_mm256_extracti128_si256(state_a->sum_i32x8, 1));
|
|
528
|
+
__m128i sum_b_i32x4 = _mm_add_epi32(_mm256_castsi256_si128(state_b->sum_i32x8),
|
|
529
|
+
_mm256_extracti128_si256(state_b->sum_i32x8, 1));
|
|
530
|
+
__m128i sum_c_i32x4 = _mm_add_epi32(_mm256_castsi256_si128(state_c->sum_i32x8),
|
|
531
|
+
_mm256_extracti128_si256(state_c->sum_i32x8, 1));
|
|
532
|
+
__m128i sum_d_i32x4 = _mm_add_epi32(_mm256_castsi256_si128(state_d->sum_i32x8),
|
|
533
|
+
_mm256_extracti128_si256(state_d->sum_i32x8, 1));
|
|
534
|
+
|
|
535
|
+
// Transpose for SIMD reduction
|
|
536
|
+
__m128i transpose_ab_low_i32x4 = _mm_unpacklo_epi32(sum_a_i32x4, sum_b_i32x4);
|
|
537
|
+
__m128i transpose_cd_low_i32x4 = _mm_unpacklo_epi32(sum_c_i32x4, sum_d_i32x4);
|
|
538
|
+
__m128i transpose_ab_high_i32x4 = _mm_unpackhi_epi32(sum_a_i32x4, sum_b_i32x4);
|
|
539
|
+
__m128i transpose_cd_high_i32x4 = _mm_unpackhi_epi32(sum_c_i32x4, sum_d_i32x4);
|
|
540
|
+
__m128i lane0_i32x4 = _mm_unpacklo_epi64(transpose_ab_low_i32x4, transpose_cd_low_i32x4);
|
|
541
|
+
__m128i lane1_i32x4 = _mm_unpackhi_epi64(transpose_ab_low_i32x4, transpose_cd_low_i32x4);
|
|
542
|
+
__m128i lane2_i32x4 = _mm_unpacklo_epi64(transpose_ab_high_i32x4, transpose_cd_high_i32x4);
|
|
543
|
+
__m128i lane3_i32x4 = _mm_unpackhi_epi64(transpose_ab_high_i32x4, transpose_cd_high_i32x4);
|
|
544
|
+
__m128i sum_i32x4 = _mm_add_epi32(_mm_add_epi32(lane0_i32x4, lane1_i32x4), _mm_add_epi32(lane2_i32x4, lane3_i32x4));
|
|
545
|
+
|
|
546
|
+
// Convert i32 → f32 and scale by 1/256
|
|
547
|
+
__m128 sum_f32x4 = _mm_mul_ps(_mm_cvtepi32_ps(sum_i32x4), _mm_set1_ps(1.0f / 256.0f));
|
|
548
|
+
results->xmm = _mm_castps_si128(sum_f32x4);
|
|
549
|
+
}
|
|
550
|
+
|
|
551
|
+
#if defined(__clang__)
|
|
552
|
+
#pragma clang attribute pop
|
|
553
|
+
#elif defined(__GNUC__)
|
|
554
|
+
#pragma GCC pop_options
|
|
555
|
+
#endif
|
|
556
|
+
|
|
557
|
+
#if defined(__cplusplus)
|
|
558
|
+
} // extern "C"
|
|
559
|
+
#endif
|
|
560
|
+
|
|
561
|
+
#endif // NK_TARGET_ALDER
|
|
562
|
+
#endif // NK_TARGET_X86_
|
|
563
|
+
#endif // NK_DOT_ALDER_H
|