numkong 7.0.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/LICENSE +201 -0
- package/README.md +495 -0
- package/binding.gyp +540 -0
- package/c/dispatch.h +512 -0
- package/c/dispatch_bf16.c +389 -0
- package/c/dispatch_bf16c.c +52 -0
- package/c/dispatch_e2m3.c +263 -0
- package/c/dispatch_e3m2.c +243 -0
- package/c/dispatch_e4m3.c +276 -0
- package/c/dispatch_e5m2.c +272 -0
- package/c/dispatch_f16.c +376 -0
- package/c/dispatch_f16c.c +58 -0
- package/c/dispatch_f32.c +378 -0
- package/c/dispatch_f32c.c +99 -0
- package/c/dispatch_f64.c +296 -0
- package/c/dispatch_f64c.c +98 -0
- package/c/dispatch_i16.c +96 -0
- package/c/dispatch_i32.c +89 -0
- package/c/dispatch_i4.c +150 -0
- package/c/dispatch_i64.c +86 -0
- package/c/dispatch_i8.c +289 -0
- package/c/dispatch_other.c +330 -0
- package/c/dispatch_u1.c +148 -0
- package/c/dispatch_u16.c +124 -0
- package/c/dispatch_u32.c +118 -0
- package/c/dispatch_u4.c +150 -0
- package/c/dispatch_u64.c +102 -0
- package/c/dispatch_u8.c +303 -0
- package/c/numkong.c +950 -0
- package/include/README.md +573 -0
- package/include/module.modulemap +129 -0
- package/include/numkong/attention/sapphireamx.h +1361 -0
- package/include/numkong/attention/sme.h +2066 -0
- package/include/numkong/attention.h +49 -0
- package/include/numkong/capabilities.h +748 -0
- package/include/numkong/cast/README.md +262 -0
- package/include/numkong/cast/haswell.h +975 -0
- package/include/numkong/cast/icelake.h +470 -0
- package/include/numkong/cast/neon.h +1192 -0
- package/include/numkong/cast/rvv.h +1021 -0
- package/include/numkong/cast/sapphire.h +262 -0
- package/include/numkong/cast/serial.h +2262 -0
- package/include/numkong/cast/skylake.h +856 -0
- package/include/numkong/cast/v128relaxed.h +180 -0
- package/include/numkong/cast.h +230 -0
- package/include/numkong/curved/README.md +223 -0
- package/include/numkong/curved/genoa.h +182 -0
- package/include/numkong/curved/haswell.h +276 -0
- package/include/numkong/curved/neon.h +205 -0
- package/include/numkong/curved/neonbfdot.h +212 -0
- package/include/numkong/curved/neonhalf.h +212 -0
- package/include/numkong/curved/rvv.h +305 -0
- package/include/numkong/curved/serial.h +207 -0
- package/include/numkong/curved/skylake.h +457 -0
- package/include/numkong/curved/smef64.h +506 -0
- package/include/numkong/curved.h +517 -0
- package/include/numkong/curved.hpp +144 -0
- package/include/numkong/dot/README.md +425 -0
- package/include/numkong/dot/alder.h +563 -0
- package/include/numkong/dot/genoa.h +315 -0
- package/include/numkong/dot/haswell.h +1688 -0
- package/include/numkong/dot/icelake.h +883 -0
- package/include/numkong/dot/neon.h +818 -0
- package/include/numkong/dot/neonbfdot.h +244 -0
- package/include/numkong/dot/neonfhm.h +360 -0
- package/include/numkong/dot/neonhalf.h +198 -0
- package/include/numkong/dot/neonsdot.h +508 -0
- package/include/numkong/dot/rvv.h +714 -0
- package/include/numkong/dot/rvvbb.h +72 -0
- package/include/numkong/dot/rvvbf16.h +123 -0
- package/include/numkong/dot/rvvhalf.h +129 -0
- package/include/numkong/dot/sapphire.h +141 -0
- package/include/numkong/dot/serial.h +838 -0
- package/include/numkong/dot/sierra.h +405 -0
- package/include/numkong/dot/skylake.h +1084 -0
- package/include/numkong/dot/sve.h +379 -0
- package/include/numkong/dot/svebfdot.h +74 -0
- package/include/numkong/dot/svehalf.h +123 -0
- package/include/numkong/dot/v128relaxed.h +1258 -0
- package/include/numkong/dot.h +1070 -0
- package/include/numkong/dot.hpp +94 -0
- package/include/numkong/dots/README.md +496 -0
- package/include/numkong/dots/alder.h +114 -0
- package/include/numkong/dots/genoa.h +94 -0
- package/include/numkong/dots/haswell.h +295 -0
- package/include/numkong/dots/icelake.h +171 -0
- package/include/numkong/dots/neon.h +120 -0
- package/include/numkong/dots/neonbfdot.h +58 -0
- package/include/numkong/dots/neonfhm.h +94 -0
- package/include/numkong/dots/neonhalf.h +57 -0
- package/include/numkong/dots/neonsdot.h +108 -0
- package/include/numkong/dots/rvv.h +2486 -0
- package/include/numkong/dots/sapphireamx.h +3973 -0
- package/include/numkong/dots/serial.h +2844 -0
- package/include/numkong/dots/sierra.h +97 -0
- package/include/numkong/dots/skylake.h +196 -0
- package/include/numkong/dots/sme.h +5372 -0
- package/include/numkong/dots/smebi32.h +461 -0
- package/include/numkong/dots/smef64.h +1318 -0
- package/include/numkong/dots/smehalf.h +47 -0
- package/include/numkong/dots/v128relaxed.h +294 -0
- package/include/numkong/dots.h +2804 -0
- package/include/numkong/dots.hpp +639 -0
- package/include/numkong/each/README.md +469 -0
- package/include/numkong/each/haswell.h +1658 -0
- package/include/numkong/each/icelake.h +272 -0
- package/include/numkong/each/neon.h +1104 -0
- package/include/numkong/each/neonbfdot.h +212 -0
- package/include/numkong/each/neonhalf.h +410 -0
- package/include/numkong/each/rvv.h +1121 -0
- package/include/numkong/each/sapphire.h +477 -0
- package/include/numkong/each/serial.h +260 -0
- package/include/numkong/each/skylake.h +1562 -0
- package/include/numkong/each.h +2146 -0
- package/include/numkong/each.hpp +434 -0
- package/include/numkong/geospatial/README.md +147 -0
- package/include/numkong/geospatial/haswell.h +593 -0
- package/include/numkong/geospatial/neon.h +571 -0
- package/include/numkong/geospatial/rvv.h +701 -0
- package/include/numkong/geospatial/serial.h +309 -0
- package/include/numkong/geospatial/skylake.h +577 -0
- package/include/numkong/geospatial/v128relaxed.h +613 -0
- package/include/numkong/geospatial.h +453 -0
- package/include/numkong/geospatial.hpp +235 -0
- package/include/numkong/matrix.hpp +336 -0
- package/include/numkong/maxsim/README.md +187 -0
- package/include/numkong/maxsim/alder.h +511 -0
- package/include/numkong/maxsim/genoa.h +115 -0
- package/include/numkong/maxsim/haswell.h +553 -0
- package/include/numkong/maxsim/icelake.h +480 -0
- package/include/numkong/maxsim/neonsdot.h +394 -0
- package/include/numkong/maxsim/sapphireamx.h +877 -0
- package/include/numkong/maxsim/serial.h +490 -0
- package/include/numkong/maxsim/sme.h +929 -0
- package/include/numkong/maxsim/v128relaxed.h +280 -0
- package/include/numkong/maxsim.h +571 -0
- package/include/numkong/maxsim.hpp +133 -0
- package/include/numkong/mesh/README.md +227 -0
- package/include/numkong/mesh/haswell.h +2235 -0
- package/include/numkong/mesh/neon.h +1329 -0
- package/include/numkong/mesh/neonbfdot.h +842 -0
- package/include/numkong/mesh/neonhalf.h +616 -0
- package/include/numkong/mesh/rvv.h +916 -0
- package/include/numkong/mesh/serial.h +742 -0
- package/include/numkong/mesh/skylake.h +1135 -0
- package/include/numkong/mesh/v128relaxed.h +1052 -0
- package/include/numkong/mesh.h +652 -0
- package/include/numkong/mesh.hpp +762 -0
- package/include/numkong/numkong.h +78 -0
- package/include/numkong/numkong.hpp +57 -0
- package/include/numkong/probability/README.md +173 -0
- package/include/numkong/probability/haswell.h +267 -0
- package/include/numkong/probability/neon.h +225 -0
- package/include/numkong/probability/rvv.h +409 -0
- package/include/numkong/probability/serial.h +169 -0
- package/include/numkong/probability/skylake.h +324 -0
- package/include/numkong/probability.h +383 -0
- package/include/numkong/probability.hpp +120 -0
- package/include/numkong/random.h +50 -0
- package/include/numkong/random.hpp +285 -0
- package/include/numkong/reduce/README.md +547 -0
- package/include/numkong/reduce/alder.h +632 -0
- package/include/numkong/reduce/genoa.h +201 -0
- package/include/numkong/reduce/haswell.h +3783 -0
- package/include/numkong/reduce/icelake.h +549 -0
- package/include/numkong/reduce/neon.h +3841 -0
- package/include/numkong/reduce/neonbfdot.h +353 -0
- package/include/numkong/reduce/neonfhm.h +665 -0
- package/include/numkong/reduce/neonhalf.h +157 -0
- package/include/numkong/reduce/neonsdot.h +357 -0
- package/include/numkong/reduce/rvv.h +3407 -0
- package/include/numkong/reduce/serial.h +757 -0
- package/include/numkong/reduce/sierra.h +338 -0
- package/include/numkong/reduce/skylake.h +3792 -0
- package/include/numkong/reduce/v128relaxed.h +2302 -0
- package/include/numkong/reduce.h +1597 -0
- package/include/numkong/reduce.hpp +633 -0
- package/include/numkong/scalar/README.md +89 -0
- package/include/numkong/scalar/haswell.h +113 -0
- package/include/numkong/scalar/neon.h +122 -0
- package/include/numkong/scalar/neonhalf.h +70 -0
- package/include/numkong/scalar/rvv.h +211 -0
- package/include/numkong/scalar/sapphire.h +63 -0
- package/include/numkong/scalar/serial.h +332 -0
- package/include/numkong/scalar/v128relaxed.h +56 -0
- package/include/numkong/scalar.h +683 -0
- package/include/numkong/set/README.md +179 -0
- package/include/numkong/set/haswell.h +334 -0
- package/include/numkong/set/icelake.h +485 -0
- package/include/numkong/set/neon.h +364 -0
- package/include/numkong/set/rvv.h +226 -0
- package/include/numkong/set/rvvbb.h +117 -0
- package/include/numkong/set/serial.h +174 -0
- package/include/numkong/set/sve.h +185 -0
- package/include/numkong/set/v128relaxed.h +240 -0
- package/include/numkong/set.h +457 -0
- package/include/numkong/set.hpp +114 -0
- package/include/numkong/sets/README.md +149 -0
- package/include/numkong/sets/haswell.h +63 -0
- package/include/numkong/sets/icelake.h +66 -0
- package/include/numkong/sets/neon.h +61 -0
- package/include/numkong/sets/serial.h +43 -0
- package/include/numkong/sets/smebi32.h +1099 -0
- package/include/numkong/sets/v128relaxed.h +58 -0
- package/include/numkong/sets.h +339 -0
- package/include/numkong/sparse/README.md +156 -0
- package/include/numkong/sparse/icelake.h +463 -0
- package/include/numkong/sparse/neon.h +288 -0
- package/include/numkong/sparse/serial.h +117 -0
- package/include/numkong/sparse/sve2.h +507 -0
- package/include/numkong/sparse/turin.h +322 -0
- package/include/numkong/sparse.h +363 -0
- package/include/numkong/sparse.hpp +113 -0
- package/include/numkong/spatial/README.md +435 -0
- package/include/numkong/spatial/alder.h +607 -0
- package/include/numkong/spatial/genoa.h +290 -0
- package/include/numkong/spatial/haswell.h +960 -0
- package/include/numkong/spatial/icelake.h +586 -0
- package/include/numkong/spatial/neon.h +773 -0
- package/include/numkong/spatial/neonbfdot.h +165 -0
- package/include/numkong/spatial/neonhalf.h +118 -0
- package/include/numkong/spatial/neonsdot.h +261 -0
- package/include/numkong/spatial/rvv.h +984 -0
- package/include/numkong/spatial/rvvbf16.h +123 -0
- package/include/numkong/spatial/rvvhalf.h +117 -0
- package/include/numkong/spatial/sapphire.h +343 -0
- package/include/numkong/spatial/serial.h +346 -0
- package/include/numkong/spatial/sierra.h +323 -0
- package/include/numkong/spatial/skylake.h +606 -0
- package/include/numkong/spatial/sve.h +224 -0
- package/include/numkong/spatial/svebfdot.h +122 -0
- package/include/numkong/spatial/svehalf.h +109 -0
- package/include/numkong/spatial/v128relaxed.h +717 -0
- package/include/numkong/spatial.h +1425 -0
- package/include/numkong/spatial.hpp +183 -0
- package/include/numkong/spatials/README.md +580 -0
- package/include/numkong/spatials/alder.h +94 -0
- package/include/numkong/spatials/genoa.h +94 -0
- package/include/numkong/spatials/haswell.h +219 -0
- package/include/numkong/spatials/icelake.h +113 -0
- package/include/numkong/spatials/neon.h +109 -0
- package/include/numkong/spatials/neonbfdot.h +60 -0
- package/include/numkong/spatials/neonfhm.h +92 -0
- package/include/numkong/spatials/neonhalf.h +58 -0
- package/include/numkong/spatials/neonsdot.h +109 -0
- package/include/numkong/spatials/rvv.h +1960 -0
- package/include/numkong/spatials/sapphireamx.h +1149 -0
- package/include/numkong/spatials/serial.h +226 -0
- package/include/numkong/spatials/sierra.h +96 -0
- package/include/numkong/spatials/skylake.h +184 -0
- package/include/numkong/spatials/sme.h +1901 -0
- package/include/numkong/spatials/smef64.h +465 -0
- package/include/numkong/spatials/v128relaxed.h +240 -0
- package/include/numkong/spatials.h +3021 -0
- package/include/numkong/spatials.hpp +508 -0
- package/include/numkong/tensor.hpp +1592 -0
- package/include/numkong/trigonometry/README.md +184 -0
- package/include/numkong/trigonometry/haswell.h +652 -0
- package/include/numkong/trigonometry/neon.h +639 -0
- package/include/numkong/trigonometry/rvv.h +699 -0
- package/include/numkong/trigonometry/serial.h +703 -0
- package/include/numkong/trigonometry/skylake.h +721 -0
- package/include/numkong/trigonometry/v128relaxed.h +666 -0
- package/include/numkong/trigonometry.h +467 -0
- package/include/numkong/trigonometry.hpp +166 -0
- package/include/numkong/types.h +1384 -0
- package/include/numkong/types.hpp +5603 -0
- package/include/numkong/vector.hpp +698 -0
- package/javascript/README.md +246 -0
- package/javascript/dist/cjs/numkong-wasm.d.ts +166 -0
- package/javascript/dist/cjs/numkong-wasm.js +617 -0
- package/javascript/dist/cjs/numkong.d.ts +343 -0
- package/javascript/dist/cjs/numkong.js +523 -0
- package/javascript/dist/cjs/package.json +3 -0
- package/javascript/dist/cjs/types.d.ts +284 -0
- package/javascript/dist/cjs/types.js +653 -0
- package/javascript/dist/esm/numkong-wasm.d.ts +166 -0
- package/javascript/dist/esm/numkong-wasm.js +595 -0
- package/javascript/dist/esm/numkong.d.ts +343 -0
- package/javascript/dist/esm/numkong.js +452 -0
- package/javascript/dist/esm/package.json +3 -0
- package/javascript/dist/esm/types.d.ts +284 -0
- package/javascript/dist/esm/types.js +630 -0
- package/javascript/dist-package-cjs.json +3 -0
- package/javascript/dist-package-esm.json +3 -0
- package/javascript/node-gyp-build.d.ts +1 -0
- package/javascript/numkong-wasm.ts +756 -0
- package/javascript/numkong.c +689 -0
- package/javascript/numkong.ts +575 -0
- package/javascript/tsconfig-base.json +39 -0
- package/javascript/tsconfig-cjs.json +8 -0
- package/javascript/tsconfig-esm.json +8 -0
- package/javascript/types.ts +674 -0
- package/package.json +87 -0
|
@@ -0,0 +1,405 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* @brief SIMD-accelerated Dot Products for Sierra Forest.
|
|
3
|
+
* @file include/numkong/dot/sierra.h
|
|
4
|
+
* @author Ash Vardanian
|
|
5
|
+
* @date December 27, 2025
|
|
6
|
+
*
|
|
7
|
+
* @sa include/numkong/dot.h
|
|
8
|
+
*
|
|
9
|
+
* @section dot_sierra_instructions AVX-VNNI-INT8 Instructions
|
|
10
|
+
*
|
|
11
|
+
* Intrinsic Instruction
|
|
12
|
+
* _mm256_dpbssd_epi32 VPDPBSSD (YMM, YMM, YMM) i8 x i8 -> i32
|
|
13
|
+
* _mm256_dpbuud_epi32 VPDPBUUD (YMM, YMM, YMM) u8 x u8 -> u32
|
|
14
|
+
*
|
|
15
|
+
* Sierra Forest CPUs support AVX-VNNI-INT8, adding native signed*signed and
|
|
16
|
+
* unsigned*unsigned 8-bit dot products. This eliminates the algebraic sign
|
|
17
|
+
* transformations required on Alder Lake (AVX-VNNI only).
|
|
18
|
+
*
|
|
19
|
+
* @section dot_sierra_stateful Stateful Streaming Logic
|
|
20
|
+
*
|
|
21
|
+
* To build memory-optimal tiled algorithms, this file defines the following structures and force-inlined
|
|
22
|
+
* `NK_INTERNAL` functions:
|
|
23
|
+
*
|
|
24
|
+
* - nk_dot_i8x32 for 8-bit signed integer inputs using native DPBSSD (no algebraic transform),
|
|
25
|
+
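* - nk_dot_u8x32 for 8-bit unsigned integer inputs using native DPBUUD (no algebraic transform),
* - nk_dot_e2m3x32 for e2m3 micro-float inputs, decoded to scaled signed i8 via VPSHUFB lookups, then DPBSSD.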
|
|
26
|
+
*
|
|
27
|
+
* Each state struct contains only a single accumulator field (no correction terms needed).
|
|
28
|
+
*
|
|
29
|
+
* @code{c}
|
|
30
|
+
* nk_dot_i8x32_state_sierra_t state_first, state_second, state_third, state_fourth;
|
|
31
|
+
* nk_b256_vec_t query_i8x32, target_first_i8x32, target_second_i8x32, target_third_i8x32, target_fourth_i8x32;
|
|
32
|
+
* nk_dot_i8x32_init_sierra(&state_first);
|
|
33
|
+
* nk_dot_i8x32_init_sierra(&state_second);
|
|
34
|
+
* nk_dot_i8x32_init_sierra(&state_third);
|
|
35
|
+
* nk_dot_i8x32_init_sierra(&state_fourth);
|
|
36
|
+
* for (nk_size_t idx = 0; idx + 32 <= depth; idx += 32) {
|
|
37
|
+
* query_i8x32.ymm = _mm256_loadu_si256((__m256i const *)(query_ptr + idx));
|
|
38
|
+
* target_first_i8x32.ymm = _mm256_loadu_si256((__m256i const *)(target_first_ptr + idx));
|
|
39
|
+
* target_second_i8x32.ymm = _mm256_loadu_si256((__m256i const *)(target_second_ptr + idx));
|
|
40
|
+
* target_third_i8x32.ymm = _mm256_loadu_si256((__m256i const *)(target_third_ptr + idx));
|
|
41
|
+
* target_fourth_i8x32.ymm = _mm256_loadu_si256((__m256i const *)(target_fourth_ptr + idx));
|
|
42
|
+
* nk_dot_i8x32_update_sierra(&state_first, query_i8x32, target_first_i8x32, idx, 32);
|
|
43
|
+
* nk_dot_i8x32_update_sierra(&state_second, query_i8x32, target_second_i8x32, idx, 32);
|
|
44
|
+
* nk_dot_i8x32_update_sierra(&state_third, query_i8x32, target_third_i8x32, idx, 32);
|
|
45
|
+
* nk_dot_i8x32_update_sierra(&state_fourth, query_i8x32, target_fourth_i8x32, idx, 32);
|
|
46
|
+
* }
|
|
47
|
+
* nk_b128_vec_t results_i32x4;
|
|
48
|
+
* nk_dot_i8x32_finalize_sierra(&state_first, &state_second, &state_third, &state_fourth, depth, &results_i32x4);
|
|
49
|
+
* @endcode
|
|
50
|
+
*
|
|
51
|
+
* The unsigned variant follows the same pattern with appropriate type changes:
|
|
52
|
+
*
|
|
53
|
+
* @code{c}
|
|
54
|
+
* nk_dot_u8x32_state_sierra_t state_first, state_second, state_third, state_fourth;
|
|
55
|
+
* nk_b256_vec_t query_u8x32, target_first_u8x32, target_second_u8x32, target_third_u8x32, target_fourth_u8x32;
|
|
56
|
+
* nk_dot_u8x32_init_sierra(&state_first);
|
|
57
|
+
* nk_dot_u8x32_init_sierra(&state_second);
|
|
58
|
+
* nk_dot_u8x32_init_sierra(&state_third);
|
|
59
|
+
* nk_dot_u8x32_init_sierra(&state_fourth);
|
|
60
|
+
* for (nk_size_t idx = 0; idx + 32 <= depth; idx += 32) {
|
|
61
|
+
* query_u8x32.ymm = _mm256_loadu_si256((__m256i const *)(query_ptr + idx));
|
|
62
|
+
* target_first_u8x32.ymm = _mm256_loadu_si256((__m256i const *)(target_first_ptr + idx));
|
|
63
|
+
* target_second_u8x32.ymm = _mm256_loadu_si256((__m256i const *)(target_second_ptr + idx));
|
|
64
|
+
* target_third_u8x32.ymm = _mm256_loadu_si256((__m256i const *)(target_third_ptr + idx));
|
|
65
|
+
* target_fourth_u8x32.ymm = _mm256_loadu_si256((__m256i const *)(target_fourth_ptr + idx));
|
|
66
|
+
* nk_dot_u8x32_update_sierra(&state_first, query_u8x32, target_first_u8x32, idx, 32);
|
|
67
|
+
* nk_dot_u8x32_update_sierra(&state_second, query_u8x32, target_second_u8x32, idx, 32);
|
|
68
|
+
* nk_dot_u8x32_update_sierra(&state_third, query_u8x32, target_third_u8x32, idx, 32);
|
|
69
|
+
* nk_dot_u8x32_update_sierra(&state_fourth, query_u8x32, target_fourth_u8x32, idx, 32);
|
|
70
|
+
* }
|
|
71
|
+
* nk_b128_vec_t results_u32x4;
|
|
72
|
+
* nk_dot_u8x32_finalize_sierra(&state_first, &state_second, &state_third, &state_fourth, depth, &results_u32x4);
|
|
73
|
+
* @endcode
|
|
74
|
+
*/
|
|
75
|
+
#ifndef NK_DOT_SIERRA_H
|
|
76
|
+
#define NK_DOT_SIERRA_H
|
|
77
|
+
|
|
78
|
+
#if NK_TARGET_X86_
|
|
79
|
+
#if NK_TARGET_SIERRA
|
|
80
|
+
|
|
81
|
+
#include "numkong/types.h"
|
|
82
|
+
#include "numkong/cast/serial.h" // `nk_partial_load_b8x32_serial_`
|
|
83
|
+
#include "numkong/reduce/haswell.h" // `nk_reduce_add_i32x8_haswell_`
|
|
84
|
+
|
|
85
|
+
#if defined(__cplusplus)
|
|
86
|
+
extern "C" {
|
|
87
|
+
#endif
|
|
88
|
+
|
|
89
|
+
#if defined(__clang__)
|
|
90
|
+
#pragma clang attribute push(__attribute__((target("avx2,f16c,fma,bmi,bmi2,avxvnni,avxvnniint8"))), apply_to = function)
|
|
91
|
+
#elif defined(__GNUC__)
|
|
92
|
+
#pragma GCC push_options
|
|
93
|
+
#pragma GCC target("avx2", "f16c", "fma", "bmi", "bmi2", "avxvnni", "avxvnniint8")
|
|
94
|
+
#endif
|
|
95
|
+
|
|
96
|
+
NK_PUBLIC void nk_dot_i8_sierra(nk_i8_t const *a_scalars, nk_i8_t const *b_scalars, nk_size_t count_scalars,
|
|
97
|
+
nk_i32_t *result) {
|
|
98
|
+
// Native i8*i8 dot product using DPBSSD (signed * signed -> i32)
|
|
99
|
+
// No algebraic transformation needed - dpbssd handles signed*signed directly.
|
|
100
|
+
__m256i sum_i32x8 = _mm256_setzero_si256();
|
|
101
|
+
__m256i a_i8x32, b_i8x32;
|
|
102
|
+
|
|
103
|
+
nk_dot_i8_sierra_cycle:
|
|
104
|
+
if (count_scalars < 32) {
|
|
105
|
+
nk_b256_vec_t a_vec, b_vec;
|
|
106
|
+
nk_partial_load_b8x32_serial_(a_scalars, &a_vec, count_scalars);
|
|
107
|
+
nk_partial_load_b8x32_serial_(b_scalars, &b_vec, count_scalars);
|
|
108
|
+
a_i8x32 = _mm256_load_si256(&a_vec.ymm);
|
|
109
|
+
b_i8x32 = _mm256_load_si256(&b_vec.ymm);
|
|
110
|
+
count_scalars = 0;
|
|
111
|
+
}
|
|
112
|
+
else {
|
|
113
|
+
a_i8x32 = _mm256_loadu_si256((__m256i const *)a_scalars);
|
|
114
|
+
b_i8x32 = _mm256_loadu_si256((__m256i const *)b_scalars);
|
|
115
|
+
a_scalars += 32, b_scalars += 32, count_scalars -= 32;
|
|
116
|
+
}
|
|
117
|
+
|
|
118
|
+
// VPDPBSSD: signed i8 * signed i8 -> i32 accumulation
|
|
119
|
+
sum_i32x8 = _mm256_dpbssd_epi32(sum_i32x8, a_i8x32, b_i8x32);
|
|
120
|
+
|
|
121
|
+
if (count_scalars) goto nk_dot_i8_sierra_cycle;
|
|
122
|
+
|
|
123
|
+
*result = nk_reduce_add_i32x8_haswell_(sum_i32x8);
|
|
124
|
+
}
|
|
125
|
+
|
|
126
|
+
/** Streaming state for signed 8-bit dot products: one VPDPBSSD accumulator of eight i32 lanes. */
struct nk_dot_i8x32_state_sierra_t {
    __m256i sum_i32x8; // running i8 * i8 -> i32 partial sums, four products folded per lane per update
};
typedef struct nk_dot_i8x32_state_sierra_t nk_dot_i8x32_state_sierra_t;
|
|
129
|
+
|
|
130
|
+
NK_INTERNAL void nk_dot_i8x32_init_sierra(nk_dot_i8x32_state_sierra_t *state) {
    // Clear all eight i32 partial sums before streaming begins.
    state->sum_i32x8 = _mm256_set1_epi32(0);
}
|
|
133
|
+
|
|
134
|
+
NK_INTERNAL void nk_dot_i8x32_update_sierra(nk_dot_i8x32_state_sierra_t *state, nk_b256_vec_t a, nk_b256_vec_t b,
                                            nk_size_t depth_offset, nk_size_t active_dimensions) {
    // Fold 32 more i8 x i8 products into the running accumulator (four products per 32-bit lane).
    // Neither the depth offset nor the active-dimension count is consulted. NOTE(review): this presumes
    // the caller zero-pads bytes beyond `active_dimensions`; confirm at the call sites.
    __m256i const previous_i32x8 = state->sum_i32x8;
    nk_unused_(active_dimensions);
    nk_unused_(depth_offset);
    state->sum_i32x8 = _mm256_dpbssd_epi32(previous_i32x8, a.ymm, b.ymm);
}
|
|
140
|
+
|
|
141
|
+
NK_INTERNAL void nk_dot_i8x32_finalize_sierra( //
|
|
142
|
+
nk_dot_i8x32_state_sierra_t const *state_a, nk_dot_i8x32_state_sierra_t const *state_b, //
|
|
143
|
+
nk_dot_i8x32_state_sierra_t const *state_c, nk_dot_i8x32_state_sierra_t const *state_d, //
|
|
144
|
+
nk_size_t total_dimensions, nk_b128_vec_t *results) {
|
|
145
|
+
nk_unused_(total_dimensions);
|
|
146
|
+
|
|
147
|
+
// ILP-optimized 4-way horizontal reduction: i32x8 -> scalar i32
|
|
148
|
+
__m128i sum_a_i32x4 = _mm_add_epi32(_mm256_castsi256_si128(state_a->sum_i32x8),
|
|
149
|
+
_mm256_extracti128_si256(state_a->sum_i32x8, 1));
|
|
150
|
+
__m128i sum_b_i32x4 = _mm_add_epi32(_mm256_castsi256_si128(state_b->sum_i32x8),
|
|
151
|
+
_mm256_extracti128_si256(state_b->sum_i32x8, 1));
|
|
152
|
+
__m128i sum_c_i32x4 = _mm_add_epi32(_mm256_castsi256_si128(state_c->sum_i32x8),
|
|
153
|
+
_mm256_extracti128_si256(state_c->sum_i32x8, 1));
|
|
154
|
+
__m128i sum_d_i32x4 = _mm_add_epi32(_mm256_castsi256_si128(state_d->sum_i32x8),
|
|
155
|
+
_mm256_extracti128_si256(state_d->sum_i32x8, 1));
|
|
156
|
+
|
|
157
|
+
// Transpose and reduce
|
|
158
|
+
__m128i transpose_ab_low_i32x4 = _mm_unpacklo_epi32(sum_a_i32x4, sum_b_i32x4);
|
|
159
|
+
__m128i transpose_cd_low_i32x4 = _mm_unpacklo_epi32(sum_c_i32x4, sum_d_i32x4);
|
|
160
|
+
__m128i transpose_ab_high_i32x4 = _mm_unpackhi_epi32(sum_a_i32x4, sum_b_i32x4);
|
|
161
|
+
__m128i transpose_cd_high_i32x4 = _mm_unpackhi_epi32(sum_c_i32x4, sum_d_i32x4);
|
|
162
|
+
__m128i lane0_i32x4 = _mm_unpacklo_epi64(transpose_ab_low_i32x4, transpose_cd_low_i32x4);
|
|
163
|
+
__m128i lane1_i32x4 = _mm_unpackhi_epi64(transpose_ab_low_i32x4, transpose_cd_low_i32x4);
|
|
164
|
+
__m128i lane2_i32x4 = _mm_unpacklo_epi64(transpose_ab_high_i32x4, transpose_cd_high_i32x4);
|
|
165
|
+
__m128i lane3_i32x4 = _mm_unpackhi_epi64(transpose_ab_high_i32x4, transpose_cd_high_i32x4);
|
|
166
|
+
results->xmm = _mm_add_epi32(_mm_add_epi32(lane0_i32x4, lane1_i32x4), _mm_add_epi32(lane2_i32x4, lane3_i32x4));
|
|
167
|
+
}
|
|
168
|
+
|
|
169
|
+
NK_PUBLIC void nk_dot_u8_sierra(nk_u8_t const *a_scalars, nk_u8_t const *b_scalars, nk_size_t count_scalars,
|
|
170
|
+
nk_u32_t *result) {
|
|
171
|
+
// Native u8*u8 dot product using DPBUUD (unsigned * unsigned -> u32)
|
|
172
|
+
// No algebraic transformation needed - dpbuud handles unsigned*unsigned directly.
|
|
173
|
+
__m256i sum_u32x8 = _mm256_setzero_si256();
|
|
174
|
+
__m256i a_u8x32, b_u8x32;
|
|
175
|
+
|
|
176
|
+
nk_dot_u8_sierra_cycle:
|
|
177
|
+
if (count_scalars < 32) {
|
|
178
|
+
nk_b256_vec_t a_vec, b_vec;
|
|
179
|
+
nk_partial_load_b8x32_serial_(a_scalars, &a_vec, count_scalars);
|
|
180
|
+
nk_partial_load_b8x32_serial_(b_scalars, &b_vec, count_scalars);
|
|
181
|
+
a_u8x32 = _mm256_load_si256(&a_vec.ymm);
|
|
182
|
+
b_u8x32 = _mm256_load_si256(&b_vec.ymm);
|
|
183
|
+
count_scalars = 0;
|
|
184
|
+
}
|
|
185
|
+
else {
|
|
186
|
+
a_u8x32 = _mm256_loadu_si256((__m256i const *)a_scalars);
|
|
187
|
+
b_u8x32 = _mm256_loadu_si256((__m256i const *)b_scalars);
|
|
188
|
+
a_scalars += 32, b_scalars += 32, count_scalars -= 32;
|
|
189
|
+
}
|
|
190
|
+
|
|
191
|
+
// VPDPBUUD: unsigned u8 * unsigned u8 -> u32 accumulation
|
|
192
|
+
sum_u32x8 = _mm256_dpbuud_epi32(sum_u32x8, a_u8x32, b_u8x32);
|
|
193
|
+
|
|
194
|
+
if (count_scalars) goto nk_dot_u8_sierra_cycle;
|
|
195
|
+
|
|
196
|
+
// Reduce u32x8 to scalar - reinterpret as i32 for reduction, cast back
|
|
197
|
+
*result = (nk_u32_t)(nk_i32_t)nk_reduce_add_i32x8_haswell_(sum_u32x8);
|
|
198
|
+
}
|
|
199
|
+
|
|
200
|
+
/** Streaming state for unsigned 8-bit dot products: one VPDPBUUD accumulator of eight u32 lanes. */
struct nk_dot_u8x32_state_sierra_t {
    __m256i sum_u32x8; // running u8 * u8 -> u32 partial sums, four products folded per lane per update
};
typedef struct nk_dot_u8x32_state_sierra_t nk_dot_u8x32_state_sierra_t;
|
|
203
|
+
|
|
204
|
+
NK_INTERNAL void nk_dot_u8x32_init_sierra(nk_dot_u8x32_state_sierra_t *state) {
    // Clear all eight u32 partial sums before streaming begins.
    state->sum_u32x8 = _mm256_set1_epi32(0);
}
|
|
207
|
+
|
|
208
|
+
NK_INTERNAL void nk_dot_u8x32_update_sierra(nk_dot_u8x32_state_sierra_t *state, nk_b256_vec_t a, nk_b256_vec_t b,
                                            nk_size_t depth_offset, nk_size_t active_dimensions) {
    // Fold 32 more u8 x u8 products into the running accumulator (four products per 32-bit lane).
    // Neither the depth offset nor the active-dimension count is consulted. NOTE(review): this presumes
    // the caller zero-pads bytes beyond `active_dimensions`; confirm at the call sites.
    __m256i const previous_u32x8 = state->sum_u32x8;
    nk_unused_(active_dimensions);
    nk_unused_(depth_offset);
    state->sum_u32x8 = _mm256_dpbuud_epi32(previous_u32x8, a.ymm, b.ymm);
}
|
|
214
|
+
|
|
215
|
+
NK_INTERNAL void nk_dot_u8x32_finalize_sierra( //
|
|
216
|
+
nk_dot_u8x32_state_sierra_t const *state_a, nk_dot_u8x32_state_sierra_t const *state_b, //
|
|
217
|
+
nk_dot_u8x32_state_sierra_t const *state_c, nk_dot_u8x32_state_sierra_t const *state_d, //
|
|
218
|
+
nk_size_t total_dimensions, nk_b128_vec_t *result) {
|
|
219
|
+
nk_unused_(total_dimensions);
|
|
220
|
+
|
|
221
|
+
// Same transpose+reduce pattern but simpler - no correction term
|
|
222
|
+
__m128i sum_a_i32x4 = _mm_add_epi32(_mm256_castsi256_si128(state_a->sum_u32x8),
|
|
223
|
+
_mm256_extracti128_si256(state_a->sum_u32x8, 1));
|
|
224
|
+
__m128i sum_b_i32x4 = _mm_add_epi32(_mm256_castsi256_si128(state_b->sum_u32x8),
|
|
225
|
+
_mm256_extracti128_si256(state_b->sum_u32x8, 1));
|
|
226
|
+
__m128i sum_c_i32x4 = _mm_add_epi32(_mm256_castsi256_si128(state_c->sum_u32x8),
|
|
227
|
+
_mm256_extracti128_si256(state_c->sum_u32x8, 1));
|
|
228
|
+
__m128i sum_d_i32x4 = _mm_add_epi32(_mm256_castsi256_si128(state_d->sum_u32x8),
|
|
229
|
+
_mm256_extracti128_si256(state_d->sum_u32x8, 1));
|
|
230
|
+
|
|
231
|
+
__m128i transpose_ab_low_i32x4 = _mm_unpacklo_epi32(sum_a_i32x4, sum_b_i32x4);
|
|
232
|
+
__m128i transpose_cd_low_i32x4 = _mm_unpacklo_epi32(sum_c_i32x4, sum_d_i32x4);
|
|
233
|
+
__m128i transpose_ab_high_i32x4 = _mm_unpackhi_epi32(sum_a_i32x4, sum_b_i32x4);
|
|
234
|
+
__m128i transpose_cd_high_i32x4 = _mm_unpackhi_epi32(sum_c_i32x4, sum_d_i32x4);
|
|
235
|
+
__m128i lane0_i32x4 = _mm_unpacklo_epi64(transpose_ab_low_i32x4, transpose_cd_low_i32x4);
|
|
236
|
+
__m128i lane1_i32x4 = _mm_unpackhi_epi64(transpose_ab_low_i32x4, transpose_cd_low_i32x4);
|
|
237
|
+
__m128i lane2_i32x4 = _mm_unpacklo_epi64(transpose_ab_high_i32x4, transpose_cd_high_i32x4);
|
|
238
|
+
__m128i lane3_i32x4 = _mm_unpackhi_epi64(transpose_ab_high_i32x4, transpose_cd_high_i32x4);
|
|
239
|
+
result->xmm = _mm_add_epi32(_mm_add_epi32(lane0_i32x4, lane1_i32x4), _mm_add_epi32(lane2_i32x4, lane3_i32x4));
|
|
240
|
+
}
|
|
241
|
+
|
|
242
|
+
NK_PUBLIC void nk_dot_e2m3_sierra(nk_e2m3_t const *a_scalars, nk_e2m3_t const *b_scalars, nk_size_t count_scalars,
|
|
243
|
+
nk_f32_t *result) {
|
|
244
|
+
// Integer dot product for e2m3 using dual-VPSHUFB (LUT) + VPDPBSSD (signed*signed).
|
|
245
|
+
// Every e2m3 value * 16 is an exact integer in [-120, +120].
|
|
246
|
+
// Result = i32_dot / 256.0f (exact, no rounding error).
|
|
247
|
+
//
|
|
248
|
+
// Uses dpbssd instead of dpbusd — both operands are already signed i8 after
|
|
249
|
+
// LUT + sign application, so no unsigned conversion is needed.
|
|
250
|
+
//
|
|
251
|
+
__m256i const lut_lower_u8x32 = _mm256_set_epi8(30, 28, 26, 24, 22, 20, 18, 16, 14, 12, 10, 8, 6, 4, 2, 0, 30, 28,
|
|
252
|
+
26, 24, 22, 20, 18, 16, 14, 12, 10, 8, 6, 4, 2, 0);
|
|
253
|
+
__m256i const lut_upper_u8x32 = _mm256_set_epi8(120, 112, 104, 96, 88, 80, 72, 64, 60, 56, 52, 48, 44, 40, 36, 32,
|
|
254
|
+
120, 112, 104, 96, 88, 80, 72, 64, 60, 56, 52, 48, 44, 40, 36, 32);
|
|
255
|
+
__m256i const nibble_mask_u8x32 = _mm256_set1_epi8(0x0F);
|
|
256
|
+
__m256i const magnitude_mask_u8x32 = _mm256_set1_epi8(0x1F);
|
|
257
|
+
__m256i const half_select_u8x32 = _mm256_set1_epi8(0x10);
|
|
258
|
+
__m256i const sign_mask_u8x32 = _mm256_set1_epi8(0x20);
|
|
259
|
+
__m256i sum_i32x8 = _mm256_setzero_si256();
|
|
260
|
+
__m256i a_e2m3_u8x32, b_e2m3_u8x32;
|
|
261
|
+
|
|
262
|
+
nk_dot_e2m3_sierra_cycle:
|
|
263
|
+
if (count_scalars < 32) {
|
|
264
|
+
nk_b256_vec_t a_vec, b_vec;
|
|
265
|
+
nk_partial_load_b8x32_serial_(a_scalars, &a_vec, count_scalars);
|
|
266
|
+
nk_partial_load_b8x32_serial_(b_scalars, &b_vec, count_scalars);
|
|
267
|
+
a_e2m3_u8x32 = a_vec.ymm;
|
|
268
|
+
b_e2m3_u8x32 = b_vec.ymm;
|
|
269
|
+
count_scalars = 0;
|
|
270
|
+
}
|
|
271
|
+
else {
|
|
272
|
+
a_e2m3_u8x32 = _mm256_loadu_si256((__m256i const *)a_scalars);
|
|
273
|
+
b_e2m3_u8x32 = _mm256_loadu_si256((__m256i const *)b_scalars);
|
|
274
|
+
a_scalars += 32, b_scalars += 32, count_scalars -= 32;
|
|
275
|
+
}
|
|
276
|
+
|
|
277
|
+
// Decode a: extract magnitude, dual-VPSHUFB LUT, apply sign
|
|
278
|
+
__m256i a_magnitude_u8x32 = _mm256_and_si256(a_e2m3_u8x32, magnitude_mask_u8x32);
|
|
279
|
+
__m256i a_shuffle_index_u8x32 = _mm256_and_si256(a_magnitude_u8x32, nibble_mask_u8x32);
|
|
280
|
+
__m256i a_upper_select_u8x32 = _mm256_cmpeq_epi8(_mm256_and_si256(a_magnitude_u8x32, half_select_u8x32),
|
|
281
|
+
half_select_u8x32);
|
|
282
|
+
__m256i a_unsigned_u8x32 = _mm256_blendv_epi8(_mm256_shuffle_epi8(lut_lower_u8x32, a_shuffle_index_u8x32),
|
|
283
|
+
_mm256_shuffle_epi8(lut_upper_u8x32, a_shuffle_index_u8x32),
|
|
284
|
+
a_upper_select_u8x32);
|
|
285
|
+
__m256i a_negate_mask_u8x32 = _mm256_cmpeq_epi8(_mm256_and_si256(a_e2m3_u8x32, sign_mask_u8x32), sign_mask_u8x32);
|
|
286
|
+
__m256i a_signed_i8x32 = _mm256_blendv_epi8(
|
|
287
|
+
a_unsigned_u8x32, _mm256_sub_epi8(_mm256_setzero_si256(), a_unsigned_u8x32), a_negate_mask_u8x32);
|
|
288
|
+
|
|
289
|
+
// Decode b: same LUT decode + sign
|
|
290
|
+
__m256i b_magnitude_u8x32 = _mm256_and_si256(b_e2m3_u8x32, magnitude_mask_u8x32);
|
|
291
|
+
__m256i b_shuffle_index_u8x32 = _mm256_and_si256(b_magnitude_u8x32, nibble_mask_u8x32);
|
|
292
|
+
__m256i b_upper_select_u8x32 = _mm256_cmpeq_epi8(_mm256_and_si256(b_magnitude_u8x32, half_select_u8x32),
|
|
293
|
+
half_select_u8x32);
|
|
294
|
+
__m256i b_unsigned_u8x32 = _mm256_blendv_epi8(_mm256_shuffle_epi8(lut_lower_u8x32, b_shuffle_index_u8x32),
|
|
295
|
+
_mm256_shuffle_epi8(lut_upper_u8x32, b_shuffle_index_u8x32),
|
|
296
|
+
b_upper_select_u8x32);
|
|
297
|
+
__m256i b_negate_mask_u8x32 = _mm256_cmpeq_epi8(_mm256_and_si256(b_e2m3_u8x32, sign_mask_u8x32), sign_mask_u8x32);
|
|
298
|
+
__m256i b_signed_i8x32 = _mm256_blendv_epi8(
|
|
299
|
+
b_unsigned_u8x32, _mm256_sub_epi8(_mm256_setzero_si256(), b_unsigned_u8x32), b_negate_mask_u8x32);
|
|
300
|
+
|
|
301
|
+
// VPDPBSSD: signed i8 * signed i8 -> i32
|
|
302
|
+
sum_i32x8 = _mm256_dpbssd_epi32(sum_i32x8, a_signed_i8x32, b_signed_i8x32);
|
|
303
|
+
|
|
304
|
+
if (count_scalars) goto nk_dot_e2m3_sierra_cycle;
|
|
305
|
+
*result = (nk_f32_t)nk_reduce_add_i32x8_haswell_(sum_i32x8) / 256.0f;
|
|
306
|
+
}
|
|
307
|
+
|
|
308
|
+
/**
 *  @brief  Running accumulator for the Sierra E2M3 dot-product kernel.
 *
 *  Holds eight independent i32 partial sums produced by VPDPBSSD (signed i8 x signed i8 -> i32).
 *  The finalize routine reduces them horizontally and scales the result by 1/256.
 */
struct nk_dot_e2m3x32_state_sierra_t {
    __m256i sum_i32x8; // eight i32 lanes, each receiving four i8*i8 products per update
};
typedef struct nk_dot_e2m3x32_state_sierra_t nk_dot_e2m3x32_state_sierra_t;
|
|
311
|
+
|
|
312
|
+
/**
 *  @brief  Resets an E2M3 dot-product accumulator so that updates can be added on top of it.
 *  @param  state  Accumulator to clear; all eight i32 lanes become zero.
 */
NK_INTERNAL void nk_dot_e2m3x32_init_sierra(nk_dot_e2m3x32_state_sierra_t *state) {
    __m256i const zero_i32x8 = _mm256_setzero_si256();
    state->sum_i32x8 = zero_i32x8;
}
|
|
315
|
+
|
|
316
|
+
NK_INTERNAL void nk_dot_e2m3x32_update_sierra(nk_dot_e2m3x32_state_sierra_t *state, nk_b256_vec_t a, nk_b256_vec_t b,
|
|
317
|
+
nk_size_t depth_offset, nk_size_t active_dimensions) {
|
|
318
|
+
nk_unused_(depth_offset);
|
|
319
|
+
nk_unused_(active_dimensions);
|
|
320
|
+
// Same LUT constants...
|
|
321
|
+
__m256i const lut_lower_u8x32 = _mm256_set_epi8(30, 28, 26, 24, 22, 20, 18, 16, 14, 12, 10, 8, 6, 4, 2, 0, 30, 28,
|
|
322
|
+
26, 24, 22, 20, 18, 16, 14, 12, 10, 8, 6, 4, 2, 0);
|
|
323
|
+
__m256i const lut_upper_u8x32 = _mm256_set_epi8(120, 112, 104, 96, 88, 80, 72, 64, 60, 56, 52, 48, 44, 40, 36, 32,
|
|
324
|
+
120, 112, 104, 96, 88, 80, 72, 64, 60, 56, 52, 48, 44, 40, 36, 32);
|
|
325
|
+
__m256i const nibble_mask_u8x32 = _mm256_set1_epi8(0x0F);
|
|
326
|
+
__m256i const magnitude_mask_u8x32 = _mm256_set1_epi8(0x1F);
|
|
327
|
+
__m256i const half_select_u8x32 = _mm256_set1_epi8(0x10);
|
|
328
|
+
__m256i const sign_mask_u8x32 = _mm256_set1_epi8(0x20);
|
|
329
|
+
|
|
330
|
+
__m256i a_e2m3_u8x32 = a.ymm;
|
|
331
|
+
__m256i b_e2m3_u8x32 = b.ymm;
|
|
332
|
+
|
|
333
|
+
// Decode a
|
|
334
|
+
__m256i a_magnitude_u8x32 = _mm256_and_si256(a_e2m3_u8x32, magnitude_mask_u8x32);
|
|
335
|
+
__m256i a_shuffle_index_u8x32 = _mm256_and_si256(a_magnitude_u8x32, nibble_mask_u8x32);
|
|
336
|
+
__m256i a_upper_select_u8x32 = _mm256_cmpeq_epi8(_mm256_and_si256(a_magnitude_u8x32, half_select_u8x32),
|
|
337
|
+
half_select_u8x32);
|
|
338
|
+
__m256i a_unsigned_u8x32 = _mm256_blendv_epi8(_mm256_shuffle_epi8(lut_lower_u8x32, a_shuffle_index_u8x32),
|
|
339
|
+
_mm256_shuffle_epi8(lut_upper_u8x32, a_shuffle_index_u8x32),
|
|
340
|
+
a_upper_select_u8x32);
|
|
341
|
+
__m256i a_negate_mask_u8x32 = _mm256_cmpeq_epi8(_mm256_and_si256(a_e2m3_u8x32, sign_mask_u8x32), sign_mask_u8x32);
|
|
342
|
+
__m256i a_signed_i8x32 = _mm256_blendv_epi8(
|
|
343
|
+
a_unsigned_u8x32, _mm256_sub_epi8(_mm256_setzero_si256(), a_unsigned_u8x32), a_negate_mask_u8x32);
|
|
344
|
+
|
|
345
|
+
// Decode b
|
|
346
|
+
__m256i b_magnitude_u8x32 = _mm256_and_si256(b_e2m3_u8x32, magnitude_mask_u8x32);
|
|
347
|
+
__m256i b_shuffle_index_u8x32 = _mm256_and_si256(b_magnitude_u8x32, nibble_mask_u8x32);
|
|
348
|
+
__m256i b_upper_select_u8x32 = _mm256_cmpeq_epi8(_mm256_and_si256(b_magnitude_u8x32, half_select_u8x32),
|
|
349
|
+
half_select_u8x32);
|
|
350
|
+
__m256i b_unsigned_u8x32 = _mm256_blendv_epi8(_mm256_shuffle_epi8(lut_lower_u8x32, b_shuffle_index_u8x32),
|
|
351
|
+
_mm256_shuffle_epi8(lut_upper_u8x32, b_shuffle_index_u8x32),
|
|
352
|
+
b_upper_select_u8x32);
|
|
353
|
+
__m256i b_negate_mask_u8x32 = _mm256_cmpeq_epi8(_mm256_and_si256(b_e2m3_u8x32, sign_mask_u8x32), sign_mask_u8x32);
|
|
354
|
+
__m256i b_signed_i8x32 = _mm256_blendv_epi8(
|
|
355
|
+
b_unsigned_u8x32, _mm256_sub_epi8(_mm256_setzero_si256(), b_unsigned_u8x32), b_negate_mask_u8x32);
|
|
356
|
+
|
|
357
|
+
// VPDPBSSD: signed * signed -> i32
|
|
358
|
+
state->sum_i32x8 = _mm256_dpbssd_epi32(state->sum_i32x8, a_signed_i8x32, b_signed_i8x32);
|
|
359
|
+
}
|
|
360
|
+
|
|
361
|
+
NK_INTERNAL void nk_dot_e2m3x32_finalize_sierra( //
|
|
362
|
+
nk_dot_e2m3x32_state_sierra_t const *state_a, nk_dot_e2m3x32_state_sierra_t const *state_b, //
|
|
363
|
+
nk_dot_e2m3x32_state_sierra_t const *state_c, nk_dot_e2m3x32_state_sierra_t const *state_d, //
|
|
364
|
+
nk_size_t total_dimensions, nk_b128_vec_t *results) {
|
|
365
|
+
nk_unused_(total_dimensions);
|
|
366
|
+
|
|
367
|
+
// ILP-optimized 4-way horizontal reduction: i32x8 -> scalar i32, then -> f32 with /256
|
|
368
|
+
__m128i sum_a_i32x4 = _mm_add_epi32(_mm256_castsi256_si128(state_a->sum_i32x8),
|
|
369
|
+
_mm256_extracti128_si256(state_a->sum_i32x8, 1));
|
|
370
|
+
__m128i sum_b_i32x4 = _mm_add_epi32(_mm256_castsi256_si128(state_b->sum_i32x8),
|
|
371
|
+
_mm256_extracti128_si256(state_b->sum_i32x8, 1));
|
|
372
|
+
__m128i sum_c_i32x4 = _mm_add_epi32(_mm256_castsi256_si128(state_c->sum_i32x8),
|
|
373
|
+
_mm256_extracti128_si256(state_c->sum_i32x8, 1));
|
|
374
|
+
__m128i sum_d_i32x4 = _mm_add_epi32(_mm256_castsi256_si128(state_d->sum_i32x8),
|
|
375
|
+
_mm256_extracti128_si256(state_d->sum_i32x8, 1));
|
|
376
|
+
|
|
377
|
+
// Transpose for SIMD reduction
|
|
378
|
+
__m128i transpose_ab_low_i32x4 = _mm_unpacklo_epi32(sum_a_i32x4, sum_b_i32x4);
|
|
379
|
+
__m128i transpose_cd_low_i32x4 = _mm_unpacklo_epi32(sum_c_i32x4, sum_d_i32x4);
|
|
380
|
+
__m128i transpose_ab_high_i32x4 = _mm_unpackhi_epi32(sum_a_i32x4, sum_b_i32x4);
|
|
381
|
+
__m128i transpose_cd_high_i32x4 = _mm_unpackhi_epi32(sum_c_i32x4, sum_d_i32x4);
|
|
382
|
+
__m128i lane0_i32x4 = _mm_unpacklo_epi64(transpose_ab_low_i32x4, transpose_cd_low_i32x4);
|
|
383
|
+
__m128i lane1_i32x4 = _mm_unpackhi_epi64(transpose_ab_low_i32x4, transpose_cd_low_i32x4);
|
|
384
|
+
__m128i lane2_i32x4 = _mm_unpacklo_epi64(transpose_ab_high_i32x4, transpose_cd_high_i32x4);
|
|
385
|
+
__m128i lane3_i32x4 = _mm_unpackhi_epi64(transpose_ab_high_i32x4, transpose_cd_high_i32x4);
|
|
386
|
+
__m128i sum_i32x4 = _mm_add_epi32(_mm_add_epi32(lane0_i32x4, lane1_i32x4), _mm_add_epi32(lane2_i32x4, lane3_i32x4));
|
|
387
|
+
|
|
388
|
+
// Convert i32 -> f32 and scale by 1/256
|
|
389
|
+
__m128 sum_f32x4 = _mm_mul_ps(_mm_cvtepi32_ps(sum_i32x4), _mm_set1_ps(1.0f / 256.0f));
|
|
390
|
+
results->xmm = _mm_castps_si128(sum_f32x4);
|
|
391
|
+
}
|
|
392
|
+
|
|
393
|
+
#if defined(__clang__)
|
|
394
|
+
#pragma clang attribute pop
|
|
395
|
+
#elif defined(__GNUC__)
|
|
396
|
+
#pragma GCC pop_options
|
|
397
|
+
#endif
|
|
398
|
+
|
|
399
|
+
#if defined(__cplusplus)
|
|
400
|
+
} // extern "C"
|
|
401
|
+
#endif
|
|
402
|
+
|
|
403
|
+
#endif // NK_TARGET_SIERRA
|
|
404
|
+
#endif // NK_TARGET_X86_
|
|
405
|
+
#endif // NK_DOT_SIERRA_H
|