numkong 7.0.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/LICENSE +201 -0
- package/README.md +495 -0
- package/binding.gyp +540 -0
- package/c/dispatch.h +512 -0
- package/c/dispatch_bf16.c +389 -0
- package/c/dispatch_bf16c.c +52 -0
- package/c/dispatch_e2m3.c +263 -0
- package/c/dispatch_e3m2.c +243 -0
- package/c/dispatch_e4m3.c +276 -0
- package/c/dispatch_e5m2.c +272 -0
- package/c/dispatch_f16.c +376 -0
- package/c/dispatch_f16c.c +58 -0
- package/c/dispatch_f32.c +378 -0
- package/c/dispatch_f32c.c +99 -0
- package/c/dispatch_f64.c +296 -0
- package/c/dispatch_f64c.c +98 -0
- package/c/dispatch_i16.c +96 -0
- package/c/dispatch_i32.c +89 -0
- package/c/dispatch_i4.c +150 -0
- package/c/dispatch_i64.c +86 -0
- package/c/dispatch_i8.c +289 -0
- package/c/dispatch_other.c +330 -0
- package/c/dispatch_u1.c +148 -0
- package/c/dispatch_u16.c +124 -0
- package/c/dispatch_u32.c +118 -0
- package/c/dispatch_u4.c +150 -0
- package/c/dispatch_u64.c +102 -0
- package/c/dispatch_u8.c +303 -0
- package/c/numkong.c +950 -0
- package/include/README.md +573 -0
- package/include/module.modulemap +129 -0
- package/include/numkong/attention/sapphireamx.h +1361 -0
- package/include/numkong/attention/sme.h +2066 -0
- package/include/numkong/attention.h +49 -0
- package/include/numkong/capabilities.h +748 -0
- package/include/numkong/cast/README.md +262 -0
- package/include/numkong/cast/haswell.h +975 -0
- package/include/numkong/cast/icelake.h +470 -0
- package/include/numkong/cast/neon.h +1192 -0
- package/include/numkong/cast/rvv.h +1021 -0
- package/include/numkong/cast/sapphire.h +262 -0
- package/include/numkong/cast/serial.h +2262 -0
- package/include/numkong/cast/skylake.h +856 -0
- package/include/numkong/cast/v128relaxed.h +180 -0
- package/include/numkong/cast.h +230 -0
- package/include/numkong/curved/README.md +223 -0
- package/include/numkong/curved/genoa.h +182 -0
- package/include/numkong/curved/haswell.h +276 -0
- package/include/numkong/curved/neon.h +205 -0
- package/include/numkong/curved/neonbfdot.h +212 -0
- package/include/numkong/curved/neonhalf.h +212 -0
- package/include/numkong/curved/rvv.h +305 -0
- package/include/numkong/curved/serial.h +207 -0
- package/include/numkong/curved/skylake.h +457 -0
- package/include/numkong/curved/smef64.h +506 -0
- package/include/numkong/curved.h +517 -0
- package/include/numkong/curved.hpp +144 -0
- package/include/numkong/dot/README.md +425 -0
- package/include/numkong/dot/alder.h +563 -0
- package/include/numkong/dot/genoa.h +315 -0
- package/include/numkong/dot/haswell.h +1688 -0
- package/include/numkong/dot/icelake.h +883 -0
- package/include/numkong/dot/neon.h +818 -0
- package/include/numkong/dot/neonbfdot.h +244 -0
- package/include/numkong/dot/neonfhm.h +360 -0
- package/include/numkong/dot/neonhalf.h +198 -0
- package/include/numkong/dot/neonsdot.h +508 -0
- package/include/numkong/dot/rvv.h +714 -0
- package/include/numkong/dot/rvvbb.h +72 -0
- package/include/numkong/dot/rvvbf16.h +123 -0
- package/include/numkong/dot/rvvhalf.h +129 -0
- package/include/numkong/dot/sapphire.h +141 -0
- package/include/numkong/dot/serial.h +838 -0
- package/include/numkong/dot/sierra.h +405 -0
- package/include/numkong/dot/skylake.h +1084 -0
- package/include/numkong/dot/sve.h +379 -0
- package/include/numkong/dot/svebfdot.h +74 -0
- package/include/numkong/dot/svehalf.h +123 -0
- package/include/numkong/dot/v128relaxed.h +1258 -0
- package/include/numkong/dot.h +1070 -0
- package/include/numkong/dot.hpp +94 -0
- package/include/numkong/dots/README.md +496 -0
- package/include/numkong/dots/alder.h +114 -0
- package/include/numkong/dots/genoa.h +94 -0
- package/include/numkong/dots/haswell.h +295 -0
- package/include/numkong/dots/icelake.h +171 -0
- package/include/numkong/dots/neon.h +120 -0
- package/include/numkong/dots/neonbfdot.h +58 -0
- package/include/numkong/dots/neonfhm.h +94 -0
- package/include/numkong/dots/neonhalf.h +57 -0
- package/include/numkong/dots/neonsdot.h +108 -0
- package/include/numkong/dots/rvv.h +2486 -0
- package/include/numkong/dots/sapphireamx.h +3973 -0
- package/include/numkong/dots/serial.h +2844 -0
- package/include/numkong/dots/sierra.h +97 -0
- package/include/numkong/dots/skylake.h +196 -0
- package/include/numkong/dots/sme.h +5372 -0
- package/include/numkong/dots/smebi32.h +461 -0
- package/include/numkong/dots/smef64.h +1318 -0
- package/include/numkong/dots/smehalf.h +47 -0
- package/include/numkong/dots/v128relaxed.h +294 -0
- package/include/numkong/dots.h +2804 -0
- package/include/numkong/dots.hpp +639 -0
- package/include/numkong/each/README.md +469 -0
- package/include/numkong/each/haswell.h +1658 -0
- package/include/numkong/each/icelake.h +272 -0
- package/include/numkong/each/neon.h +1104 -0
- package/include/numkong/each/neonbfdot.h +212 -0
- package/include/numkong/each/neonhalf.h +410 -0
- package/include/numkong/each/rvv.h +1121 -0
- package/include/numkong/each/sapphire.h +477 -0
- package/include/numkong/each/serial.h +260 -0
- package/include/numkong/each/skylake.h +1562 -0
- package/include/numkong/each.h +2146 -0
- package/include/numkong/each.hpp +434 -0
- package/include/numkong/geospatial/README.md +147 -0
- package/include/numkong/geospatial/haswell.h +593 -0
- package/include/numkong/geospatial/neon.h +571 -0
- package/include/numkong/geospatial/rvv.h +701 -0
- package/include/numkong/geospatial/serial.h +309 -0
- package/include/numkong/geospatial/skylake.h +577 -0
- package/include/numkong/geospatial/v128relaxed.h +613 -0
- package/include/numkong/geospatial.h +453 -0
- package/include/numkong/geospatial.hpp +235 -0
- package/include/numkong/matrix.hpp +336 -0
- package/include/numkong/maxsim/README.md +187 -0
- package/include/numkong/maxsim/alder.h +511 -0
- package/include/numkong/maxsim/genoa.h +115 -0
- package/include/numkong/maxsim/haswell.h +553 -0
- package/include/numkong/maxsim/icelake.h +480 -0
- package/include/numkong/maxsim/neonsdot.h +394 -0
- package/include/numkong/maxsim/sapphireamx.h +877 -0
- package/include/numkong/maxsim/serial.h +490 -0
- package/include/numkong/maxsim/sme.h +929 -0
- package/include/numkong/maxsim/v128relaxed.h +280 -0
- package/include/numkong/maxsim.h +571 -0
- package/include/numkong/maxsim.hpp +133 -0
- package/include/numkong/mesh/README.md +227 -0
- package/include/numkong/mesh/haswell.h +2235 -0
- package/include/numkong/mesh/neon.h +1329 -0
- package/include/numkong/mesh/neonbfdot.h +842 -0
- package/include/numkong/mesh/neonhalf.h +616 -0
- package/include/numkong/mesh/rvv.h +916 -0
- package/include/numkong/mesh/serial.h +742 -0
- package/include/numkong/mesh/skylake.h +1135 -0
- package/include/numkong/mesh/v128relaxed.h +1052 -0
- package/include/numkong/mesh.h +652 -0
- package/include/numkong/mesh.hpp +762 -0
- package/include/numkong/numkong.h +78 -0
- package/include/numkong/numkong.hpp +57 -0
- package/include/numkong/probability/README.md +173 -0
- package/include/numkong/probability/haswell.h +267 -0
- package/include/numkong/probability/neon.h +225 -0
- package/include/numkong/probability/rvv.h +409 -0
- package/include/numkong/probability/serial.h +169 -0
- package/include/numkong/probability/skylake.h +324 -0
- package/include/numkong/probability.h +383 -0
- package/include/numkong/probability.hpp +120 -0
- package/include/numkong/random.h +50 -0
- package/include/numkong/random.hpp +285 -0
- package/include/numkong/reduce/README.md +547 -0
- package/include/numkong/reduce/alder.h +632 -0
- package/include/numkong/reduce/genoa.h +201 -0
- package/include/numkong/reduce/haswell.h +3783 -0
- package/include/numkong/reduce/icelake.h +549 -0
- package/include/numkong/reduce/neon.h +3841 -0
- package/include/numkong/reduce/neonbfdot.h +353 -0
- package/include/numkong/reduce/neonfhm.h +665 -0
- package/include/numkong/reduce/neonhalf.h +157 -0
- package/include/numkong/reduce/neonsdot.h +357 -0
- package/include/numkong/reduce/rvv.h +3407 -0
- package/include/numkong/reduce/serial.h +757 -0
- package/include/numkong/reduce/sierra.h +338 -0
- package/include/numkong/reduce/skylake.h +3792 -0
- package/include/numkong/reduce/v128relaxed.h +2302 -0
- package/include/numkong/reduce.h +1597 -0
- package/include/numkong/reduce.hpp +633 -0
- package/include/numkong/scalar/README.md +89 -0
- package/include/numkong/scalar/haswell.h +113 -0
- package/include/numkong/scalar/neon.h +122 -0
- package/include/numkong/scalar/neonhalf.h +70 -0
- package/include/numkong/scalar/rvv.h +211 -0
- package/include/numkong/scalar/sapphire.h +63 -0
- package/include/numkong/scalar/serial.h +332 -0
- package/include/numkong/scalar/v128relaxed.h +56 -0
- package/include/numkong/scalar.h +683 -0
- package/include/numkong/set/README.md +179 -0
- package/include/numkong/set/haswell.h +334 -0
- package/include/numkong/set/icelake.h +485 -0
- package/include/numkong/set/neon.h +364 -0
- package/include/numkong/set/rvv.h +226 -0
- package/include/numkong/set/rvvbb.h +117 -0
- package/include/numkong/set/serial.h +174 -0
- package/include/numkong/set/sve.h +185 -0
- package/include/numkong/set/v128relaxed.h +240 -0
- package/include/numkong/set.h +457 -0
- package/include/numkong/set.hpp +114 -0
- package/include/numkong/sets/README.md +149 -0
- package/include/numkong/sets/haswell.h +63 -0
- package/include/numkong/sets/icelake.h +66 -0
- package/include/numkong/sets/neon.h +61 -0
- package/include/numkong/sets/serial.h +43 -0
- package/include/numkong/sets/smebi32.h +1099 -0
- package/include/numkong/sets/v128relaxed.h +58 -0
- package/include/numkong/sets.h +339 -0
- package/include/numkong/sparse/README.md +156 -0
- package/include/numkong/sparse/icelake.h +463 -0
- package/include/numkong/sparse/neon.h +288 -0
- package/include/numkong/sparse/serial.h +117 -0
- package/include/numkong/sparse/sve2.h +507 -0
- package/include/numkong/sparse/turin.h +322 -0
- package/include/numkong/sparse.h +363 -0
- package/include/numkong/sparse.hpp +113 -0
- package/include/numkong/spatial/README.md +435 -0
- package/include/numkong/spatial/alder.h +607 -0
- package/include/numkong/spatial/genoa.h +290 -0
- package/include/numkong/spatial/haswell.h +960 -0
- package/include/numkong/spatial/icelake.h +586 -0
- package/include/numkong/spatial/neon.h +773 -0
- package/include/numkong/spatial/neonbfdot.h +165 -0
- package/include/numkong/spatial/neonhalf.h +118 -0
- package/include/numkong/spatial/neonsdot.h +261 -0
- package/include/numkong/spatial/rvv.h +984 -0
- package/include/numkong/spatial/rvvbf16.h +123 -0
- package/include/numkong/spatial/rvvhalf.h +117 -0
- package/include/numkong/spatial/sapphire.h +343 -0
- package/include/numkong/spatial/serial.h +346 -0
- package/include/numkong/spatial/sierra.h +323 -0
- package/include/numkong/spatial/skylake.h +606 -0
- package/include/numkong/spatial/sve.h +224 -0
- package/include/numkong/spatial/svebfdot.h +122 -0
- package/include/numkong/spatial/svehalf.h +109 -0
- package/include/numkong/spatial/v128relaxed.h +717 -0
- package/include/numkong/spatial.h +1425 -0
- package/include/numkong/spatial.hpp +183 -0
- package/include/numkong/spatials/README.md +580 -0
- package/include/numkong/spatials/alder.h +94 -0
- package/include/numkong/spatials/genoa.h +94 -0
- package/include/numkong/spatials/haswell.h +219 -0
- package/include/numkong/spatials/icelake.h +113 -0
- package/include/numkong/spatials/neon.h +109 -0
- package/include/numkong/spatials/neonbfdot.h +60 -0
- package/include/numkong/spatials/neonfhm.h +92 -0
- package/include/numkong/spatials/neonhalf.h +58 -0
- package/include/numkong/spatials/neonsdot.h +109 -0
- package/include/numkong/spatials/rvv.h +1960 -0
- package/include/numkong/spatials/sapphireamx.h +1149 -0
- package/include/numkong/spatials/serial.h +226 -0
- package/include/numkong/spatials/sierra.h +96 -0
- package/include/numkong/spatials/skylake.h +184 -0
- package/include/numkong/spatials/sme.h +1901 -0
- package/include/numkong/spatials/smef64.h +465 -0
- package/include/numkong/spatials/v128relaxed.h +240 -0
- package/include/numkong/spatials.h +3021 -0
- package/include/numkong/spatials.hpp +508 -0
- package/include/numkong/tensor.hpp +1592 -0
- package/include/numkong/trigonometry/README.md +184 -0
- package/include/numkong/trigonometry/haswell.h +652 -0
- package/include/numkong/trigonometry/neon.h +639 -0
- package/include/numkong/trigonometry/rvv.h +699 -0
- package/include/numkong/trigonometry/serial.h +703 -0
- package/include/numkong/trigonometry/skylake.h +721 -0
- package/include/numkong/trigonometry/v128relaxed.h +666 -0
- package/include/numkong/trigonometry.h +467 -0
- package/include/numkong/trigonometry.hpp +166 -0
- package/include/numkong/types.h +1384 -0
- package/include/numkong/types.hpp +5603 -0
- package/include/numkong/vector.hpp +698 -0
- package/javascript/README.md +246 -0
- package/javascript/dist/cjs/numkong-wasm.d.ts +166 -0
- package/javascript/dist/cjs/numkong-wasm.js +617 -0
- package/javascript/dist/cjs/numkong.d.ts +343 -0
- package/javascript/dist/cjs/numkong.js +523 -0
- package/javascript/dist/cjs/package.json +3 -0
- package/javascript/dist/cjs/types.d.ts +284 -0
- package/javascript/dist/cjs/types.js +653 -0
- package/javascript/dist/esm/numkong-wasm.d.ts +166 -0
- package/javascript/dist/esm/numkong-wasm.js +595 -0
- package/javascript/dist/esm/numkong.d.ts +343 -0
- package/javascript/dist/esm/numkong.js +452 -0
- package/javascript/dist/esm/package.json +3 -0
- package/javascript/dist/esm/types.d.ts +284 -0
- package/javascript/dist/esm/types.js +630 -0
- package/javascript/dist-package-cjs.json +3 -0
- package/javascript/dist-package-esm.json +3 -0
- package/javascript/node-gyp-build.d.ts +1 -0
- package/javascript/numkong-wasm.ts +756 -0
- package/javascript/numkong.c +689 -0
- package/javascript/numkong.ts +575 -0
- package/javascript/tsconfig-base.json +39 -0
- package/javascript/tsconfig-cjs.json +8 -0
- package/javascript/tsconfig-esm.json +8 -0
- package/javascript/types.ts +674 -0
- package/package.json +87 -0
|
@@ -0,0 +1,1084 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* @brief SIMD-accelerated Dot Products for Skylake.
|
|
3
|
+
* @file include/numkong/dot/skylake.h
|
|
4
|
+
* @author Ash Vardanian
|
|
5
|
+
* @date December 27, 2025
|
|
6
|
+
*
|
|
7
|
+
* @sa include/numkong/dot.h
|
|
8
|
+
*
|
|
9
|
+
* @section dot_skylake_instructions Key AVX-512 Instructions
|
|
10
|
+
*
|
|
11
|
+
* Intrinsic Instruction Latency Throughput Ports
|
|
12
|
+
* _mm512_madd_epi16 VPMADDWD (ZMM, ZMM, ZMM) 5cy 0.5/cy p05
|
|
13
|
+
* _mm512_add_epi32 VPADDD (ZMM, ZMM, ZMM) 1cy 0.5/cy p05
|
|
14
|
+
* _mm512_fmadd_ps VFMADD132PS (ZMM, ZMM, ZMM) 4cy 0.5/cy p05
|
|
15
|
+
* _mm512_cvtepi8_epi16 VPMOVSXBW (ZMM, YMM) 3cy 1/cy p5
|
|
16
|
+
*
|
|
17
|
+
* Skylake-X server chips feature dual 512-bit FMA units on ports 0 and 5, enabling 0.5cy throughput for
|
|
18
|
+
* VFMADD and arithmetic operations. Lower-tier Skylake-SP SKUs (e.g. Xeon Bronze/Silver) expose only one 512-bit FMA unit with 1cy throughput.
|
|
19
|
+
* Without VNNI support, integer dot products use VPMADDWD for i16 pair multiplication with i32 accumulation.
|
|
20
|
+
*
|
|
21
|
+
* @section dot_skylake_stateful Stateful Streaming Logic
|
|
22
|
+
*
|
|
23
|
+
* To build memory-optimal tiled algorithms, this file defines the following structures and force-inlined
|
|
24
|
+
* `NK_INTERNAL` functions:
|
|
25
|
+
*
|
|
26
|
+
* - nk_dot_f64x8 state with Dot2 stable dot-products,
|
|
27
|
+
* - nk_dot_f32x8 state with double-precision numerics,
|
|
28
|
+
* - nk_dot_through_f32 state for low-precision float inputs (f16, bf16, e4m3, e5m2) with single-precision numerics.
|
|
29
|
+
*
|
|
30
|
+
* @code{c}
|
|
31
|
+
* nk_dot_f64x8_state_skylake_t state_first, state_second, state_third, state_fourth;
|
|
32
|
+
* nk_b512_vec_t query_f64x8, target_first_f64x8, target_second_f64x8, target_third_f64x8, target_fourth_f64x8;
|
|
33
|
+
* nk_dot_f64x8_init_skylake(&state_first);
|
|
34
|
+
* nk_dot_f64x8_init_skylake(&state_second);
|
|
35
|
+
* nk_dot_f64x8_init_skylake(&state_third);
|
|
36
|
+
* nk_dot_f64x8_init_skylake(&state_fourth);
|
|
37
|
+
* for (nk_size_t idx = 0; idx + 8 <= depth; idx += 8) {
|
|
38
|
+
* query_f64x8.zmm_pd = _mm512_loadu_pd(query_ptr + idx);
|
|
39
|
+
* target_first_f64x8.zmm_pd = _mm512_loadu_pd(target_first_ptr + idx);
|
|
40
|
+
* target_second_f64x8.zmm_pd = _mm512_loadu_pd(target_second_ptr + idx);
|
|
41
|
+
* target_third_f64x8.zmm_pd = _mm512_loadu_pd(target_third_ptr + idx);
|
|
42
|
+
* target_fourth_f64x8.zmm_pd = _mm512_loadu_pd(target_fourth_ptr + idx);
|
|
43
|
+
* nk_dot_f64x8_update_skylake(&state_first, query_f64x8, target_first_f64x8, idx, 8);
|
|
44
|
+
* nk_dot_f64x8_update_skylake(&state_second, query_f64x8, target_second_f64x8, idx, 8);
|
|
45
|
+
* nk_dot_f64x8_update_skylake(&state_third, query_f64x8, target_third_f64x8, idx, 8);
|
|
46
|
+
* nk_dot_f64x8_update_skylake(&state_fourth, query_f64x8, target_fourth_f64x8, idx, 8);
|
|
47
|
+
* }
|
|
48
|
+
* nk_b256_vec_t results_f64x4;
|
|
49
|
+
* nk_dot_f64x8_finalize_skylake(&state_first, &state_second, &state_third, &state_fourth, depth, &results_f64x4);
|
|
50
|
+
* @endcode
|
|
51
|
+
*
|
|
52
|
+
* Smaller float types like f16 and bf16 on Skylake use ISA-specific upcasting to f32 combined with native
|
|
53
|
+
* FMA instructions, sharing the `nk_dot_through_f32` accumulation logic:
|
|
54
|
+
*
|
|
55
|
+
* @code{c}
|
|
56
|
+
* nk_dot_f16x16_state_skylake_t state_first, state_second, state_third, state_fourth;
|
|
57
|
+
* nk_b512_vec_t query_f32x16, target_first_f32x16, target_second_f32x16, target_third_f32x16, target_fourth_f32x16;
|
|
58
|
+
* nk_dot_through_f32_init_skylake_(&state_first);
|
|
59
|
+
* nk_dot_through_f32_init_skylake_(&state_second);
|
|
60
|
+
* nk_dot_through_f32_init_skylake_(&state_third);
|
|
61
|
+
* nk_dot_through_f32_init_skylake_(&state_fourth);
|
|
62
|
+
* for (nk_size_t idx = 0; idx + 16 <= depth; idx += 16) {
|
|
63
|
+
* nk_load_f16x16_to_f32x16_skylake_(query_ptr + idx, &query_f32x16);
|
|
64
|
+
* nk_load_f16x16_to_f32x16_skylake_(target_first_ptr + idx, &target_first_f32x16);
|
|
65
|
+
* nk_load_f16x16_to_f32x16_skylake_(target_second_ptr + idx, &target_second_f32x16);
|
|
66
|
+
* nk_load_f16x16_to_f32x16_skylake_(target_third_ptr + idx, &target_third_f32x16);
|
|
67
|
+
* nk_load_f16x16_to_f32x16_skylake_(target_fourth_ptr + idx, &target_fourth_f32x16);
|
|
68
|
+
* nk_dot_through_f32_update_skylake_(&state_first, query_f32x16, target_first_f32x16, idx, 16);
|
|
69
|
+
* nk_dot_through_f32_update_skylake_(&state_second, query_f32x16, target_second_f32x16, idx, 16);
|
|
70
|
+
* nk_dot_through_f32_update_skylake_(&state_third, query_f32x16, target_third_f32x16, idx, 16);
|
|
71
|
+
* nk_dot_through_f32_update_skylake_(&state_fourth, query_f32x16, target_fourth_f32x16, idx, 16);
|
|
72
|
+
* }
|
|
73
|
+
* nk_b128_vec_t results_f32x4;
|
|
74
|
+
* nk_dot_through_f32_finalize_skylake_(&state_first, &state_second, &state_third, &state_fourth,
|
|
75
|
+
* depth, &results_f32x4);
|
|
76
|
+
* @endcode
|
|
77
|
+
*/
|
|
78
|
+
#ifndef NK_DOT_SKYLAKE_H
|
|
79
|
+
#define NK_DOT_SKYLAKE_H
|
|
80
|
+
|
|
81
|
+
#if NK_TARGET_X86_
|
|
82
|
+
#if NK_TARGET_SKYLAKE
|
|
83
|
+
|
|
84
|
+
#include "numkong/cast/skylake.h" // `nk_bf16x16_to_f32x16_skylake_`
|
|
85
|
+
#include "numkong/reduce/skylake.h" // `nk_reduce_add_f32x16_skylake_`
|
|
86
|
+
#include "numkong/dot/haswell.h" // `nk_dot_stable_sum_f64x4_haswell_`
|
|
87
|
+
|
|
88
|
+
#if defined(__cplusplus)
|
|
89
|
+
extern "C" {
|
|
90
|
+
#endif
|
|
91
|
+
|
|
92
|
+
#if defined(__clang__)
|
|
93
|
+
#pragma clang attribute push(__attribute__((target("avx2,avx512f,avx512vl,avx512bw,avx512dq,f16c,fma,bmi,bmi2"))), \
|
|
94
|
+
apply_to = function)
|
|
95
|
+
#elif defined(__GNUC__)
|
|
96
|
+
#pragma GCC push_options
|
|
97
|
+
#pragma GCC target("avx2", "avx512f", "avx512vl", "avx512bw", "avx512dq", "f16c", "fma", "bmi", "bmi2")
|
|
98
|
+
#endif
|
|
99
|
+
|
|
100
|
+
/** @brief Compensated horizontal sum of 8 f64 lanes via TwoSum tree reduction. */
|
|
101
|
+
NK_INTERNAL nk_f64_t nk_dot_stable_sum_f64x8_skylake_(__m512d sum_f64x8, __m512d compensation_f64x8) {
|
|
102
|
+
// Stage 0: TwoSum merge of sum + compensation (8-wide)
|
|
103
|
+
__m512d tentative_sum_f64x8 = _mm512_add_pd(sum_f64x8, compensation_f64x8);
|
|
104
|
+
__m512d virtual_addend_f64x8 = _mm512_sub_pd(tentative_sum_f64x8, sum_f64x8);
|
|
105
|
+
__m512d rounding_error_f64x8 = _mm512_add_pd(
|
|
106
|
+
_mm512_sub_pd(sum_f64x8, _mm512_sub_pd(tentative_sum_f64x8, virtual_addend_f64x8)),
|
|
107
|
+
_mm512_sub_pd(compensation_f64x8, virtual_addend_f64x8));
|
|
108
|
+
|
|
109
|
+
// Stage 1: TwoSum halving 8→4
|
|
110
|
+
__m256d lower_sum_f64x4 = _mm512_castpd512_pd256(tentative_sum_f64x8);
|
|
111
|
+
__m256d upper_sum_f64x4 = _mm512_extractf64x4_pd(tentative_sum_f64x8, 1);
|
|
112
|
+
__m256d tentative_sum_f64x4 = _mm256_add_pd(lower_sum_f64x4, upper_sum_f64x4);
|
|
113
|
+
__m256d virtual_addend_f64x4 = _mm256_sub_pd(tentative_sum_f64x4, lower_sum_f64x4);
|
|
114
|
+
__m256d rounding_error_f64x4 = _mm256_add_pd(
|
|
115
|
+
_mm256_sub_pd(lower_sum_f64x4, _mm256_sub_pd(tentative_sum_f64x4, virtual_addend_f64x4)),
|
|
116
|
+
_mm256_sub_pd(upper_sum_f64x4, virtual_addend_f64x4));
|
|
117
|
+
__m256d lower_error_f64x4 = _mm512_castpd512_pd256(rounding_error_f64x8);
|
|
118
|
+
__m256d upper_error_f64x4 = _mm512_extractf64x4_pd(rounding_error_f64x8, 1);
|
|
119
|
+
__m256d accumulated_error_f64x4 = _mm256_add_pd(_mm256_add_pd(lower_error_f64x4, upper_error_f64x4),
|
|
120
|
+
rounding_error_f64x4);
|
|
121
|
+
|
|
122
|
+
// Stages 2-3: Delegate to Haswell for 4→2→1 reduction
|
|
123
|
+
return nk_dot_stable_sum_f64x4_haswell_(tentative_sum_f64x4, accumulated_error_f64x4);
|
|
124
|
+
}
|
|
125
|
+
|
|
126
|
+
#pragma region - Traditional Floats
|
|
127
|
+
|
|
128
|
+
/**
 *  @brief Accumulator for dot-products of low-precision inputs that are widened to f32 before the FMA.
 *
 *  Holds 16 independent f32 partial sums in a single ZMM register. The update kernel accumulates into
 *  them lane-wise, and the finalize kernel reduces them horizontally.
 *
 *  @sa nk_dot_f16x16_state_skylake_t, nk_dot_bf16x16_state_skylake_t
 *  @sa nk_dot_e4m3x16_state_skylake_t, nk_dot_e5m2x16_state_skylake_t
 */
struct nk_dot_through_f32_state_skylake_t_ {
    __m512 sum_f32x16; // Lane-wise running sums of `a * b`.
};
typedef struct nk_dot_through_f32_state_skylake_t_ nk_dot_through_f32_state_skylake_t_;
|
|
136
|
+
|
|
137
|
+
/**
|
|
138
|
+
* @brief Initializes 32-bit accumulators for low-precision dot-products.
|
|
139
|
+
* @sa nk_dot_f16x16_init_skylake, nk_dot_bf16x16_init_skylake
|
|
140
|
+
* @sa nk_dot_e4m3x16_init_skylake, nk_dot_e5m2x16_init_skylake
|
|
141
|
+
*/
|
|
142
|
+
NK_INTERNAL void nk_dot_through_f32_init_skylake_(nk_dot_through_f32_state_skylake_t_ *state) {
|
|
143
|
+
state->sum_f32x16 = _mm512_setzero_ps();
|
|
144
|
+
}
|
|
145
|
+
|
|
146
|
+
/**
|
|
147
|
+
* @brief Fuses 32-bit multiplication and accumulation for low-precision dot-products.
|
|
148
|
+
* @sa nk_dot_f16x16_udpate_skylake, nk_dot_bf16x16_udpate_skylake
|
|
149
|
+
* @sa nk_dot_e4m3x16_udpate_skylake, nk_dot_e5m2x16_udpate_skylake
|
|
150
|
+
*/
|
|
151
|
+
NK_INTERNAL void nk_dot_through_f32_update_skylake_(nk_dot_through_f32_state_skylake_t_ *state, nk_b512_vec_t a,
|
|
152
|
+
nk_b512_vec_t b, nk_size_t depth_offset,
|
|
153
|
+
nk_size_t active_dimensions) {
|
|
154
|
+
nk_unused_(depth_offset);
|
|
155
|
+
nk_unused_(active_dimensions);
|
|
156
|
+
state->sum_f32x16 = _mm512_fmadd_ps(a.zmm_ps, b.zmm_ps, state->sum_f32x16);
|
|
157
|
+
}
|
|
158
|
+
|
|
159
|
+
/**
|
|
160
|
+
* @brief Finalizes 4x low-precision dot-products placing them into 4x consecutive 32-bit slots.
|
|
161
|
+
* @sa nk_dot_f16x16_udpate_skylake, nk_dot_bf16x16_udpate_skylake
|
|
162
|
+
* @sa nk_dot_e4m3x16_udpate_skylake, nk_dot_e5m2x16_udpate_skylake
|
|
163
|
+
*
|
|
164
|
+
* The goal of this kernel is simple - compute 4x horizontal reductions, each involing 16x floats.
|
|
165
|
+
* The lack of vectorized horizontal instruction implies many consecutive shuffles producing a tree-like
|
|
166
|
+
* reduction. This kernel allow combinding some of those operations between different dot products.
|
|
167
|
+
*/
|
|
168
|
+
NK_INTERNAL void nk_dot_through_f32_finalize_skylake_( //
|
|
169
|
+
nk_dot_through_f32_state_skylake_t_ const *state_a, nk_dot_through_f32_state_skylake_t_ const *state_b, //
|
|
170
|
+
nk_dot_through_f32_state_skylake_t_ const *state_c, nk_dot_through_f32_state_skylake_t_ const *state_d, //
|
|
171
|
+
nk_size_t total_dimensions, nk_b128_vec_t *result) {
|
|
172
|
+
nk_unused_(total_dimensions);
|
|
173
|
+
|
|
174
|
+
__m512 const sum_a_f32x16 = state_a->sum_f32x16, sum_b_f32x16 = state_b->sum_f32x16,
|
|
175
|
+
sum_c_f32x16 = state_c->sum_f32x16, sum_d_f32x16 = state_d->sum_f32x16;
|
|
176
|
+
|
|
177
|
+
// ILP-optimized 4-way horizontal reduction for f32x16 in AVX-512
|
|
178
|
+
// Step 1: 16 → 8 for all 4 states (extract high 256-bit half and add to low half)
|
|
179
|
+
__m256 sum_a_f32x8 = _mm256_add_ps(_mm512_castps512_ps256(sum_a_f32x16), _mm512_extractf32x8_ps(sum_a_f32x16, 1));
|
|
180
|
+
__m256 sum_b_f32x8 = _mm256_add_ps(_mm512_castps512_ps256(sum_b_f32x16), _mm512_extractf32x8_ps(sum_b_f32x16, 1));
|
|
181
|
+
__m256 sum_c_f32x8 = _mm256_add_ps(_mm512_castps512_ps256(sum_c_f32x16), _mm512_extractf32x8_ps(sum_c_f32x16, 1));
|
|
182
|
+
__m256 sum_d_f32x8 = _mm256_add_ps(_mm512_castps512_ps256(sum_d_f32x16), _mm512_extractf32x8_ps(sum_d_f32x16, 1));
|
|
183
|
+
// Step 2: 8 → 4 for all 4 states (extract high 128-bit half and add to low half)
|
|
184
|
+
__m128 sum_a_f32x4 = _mm_add_ps(_mm256_castps256_ps128(sum_a_f32x8), _mm256_extractf128_ps(sum_a_f32x8, 1));
|
|
185
|
+
__m128 sum_b_f32x4 = _mm_add_ps(_mm256_castps256_ps128(sum_b_f32x8), _mm256_extractf128_ps(sum_b_f32x8, 1));
|
|
186
|
+
__m128 sum_c_f32x4 = _mm_add_ps(_mm256_castps256_ps128(sum_c_f32x8), _mm256_extractf128_ps(sum_c_f32x8, 1));
|
|
187
|
+
__m128 sum_d_f32x4 = _mm_add_ps(_mm256_castps256_ps128(sum_d_f32x8), _mm256_extractf128_ps(sum_d_f32x8, 1));
|
|
188
|
+
// Step 3: Transpose 4x4 and reduce to get final 4 scalars
|
|
189
|
+
__m128 transpose_ab_low_f32x4 = _mm_unpacklo_ps(sum_a_f32x4, sum_b_f32x4);
|
|
190
|
+
__m128 transpose_cd_low_f32x4 = _mm_unpacklo_ps(sum_c_f32x4, sum_d_f32x4);
|
|
191
|
+
__m128 transpose_ab_high_f32x4 = _mm_unpackhi_ps(sum_a_f32x4, sum_b_f32x4);
|
|
192
|
+
__m128 transpose_cd_high_f32x4 = _mm_unpackhi_ps(sum_c_f32x4, sum_d_f32x4);
|
|
193
|
+
__m128 sum_lane0_f32x4 = _mm_movelh_ps(transpose_ab_low_f32x4, transpose_cd_low_f32x4);
|
|
194
|
+
__m128 sum_lane1_f32x4 = _mm_movehl_ps(transpose_cd_low_f32x4, transpose_ab_low_f32x4);
|
|
195
|
+
__m128 sum_lane2_f32x4 = _mm_movelh_ps(transpose_ab_high_f32x4, transpose_cd_high_f32x4);
|
|
196
|
+
__m128 sum_lane3_f32x4 = _mm_movehl_ps(transpose_cd_high_f32x4, transpose_ab_high_f32x4);
|
|
197
|
+
__m128 final_sum_f32x4 = _mm_add_ps(_mm_add_ps(sum_lane0_f32x4, sum_lane1_f32x4),
|
|
198
|
+
_mm_add_ps(sum_lane2_f32x4, sum_lane3_f32x4));
|
|
199
|
+
result->xmm = _mm_castps_si128(final_sum_f32x4);
|
|
200
|
+
}
|
|
201
|
+
|
|
202
|
+
NK_PUBLIC void nk_dot_f32_skylake(nk_f32_t const *a_scalars, nk_f32_t const *b_scalars, nk_size_t count_scalars,
|
|
203
|
+
nk_f64_t *result) {
|
|
204
|
+
__m256 a_f32x8, b_f32x8;
|
|
205
|
+
__m512d sum_f64x8 = _mm512_setzero_pd();
|
|
206
|
+
|
|
207
|
+
nk_dot_f32_skylake_cycle:
|
|
208
|
+
if (count_scalars < 8) {
|
|
209
|
+
__mmask8 mask = (__mmask8)_bzhi_u32(0xFFFFFFFF, count_scalars);
|
|
210
|
+
a_f32x8 = _mm256_maskz_loadu_ps(mask, a_scalars);
|
|
211
|
+
b_f32x8 = _mm256_maskz_loadu_ps(mask, b_scalars);
|
|
212
|
+
count_scalars = 0;
|
|
213
|
+
}
|
|
214
|
+
else {
|
|
215
|
+
a_f32x8 = _mm256_loadu_ps(a_scalars);
|
|
216
|
+
b_f32x8 = _mm256_loadu_ps(b_scalars);
|
|
217
|
+
a_scalars += 8, b_scalars += 8, count_scalars -= 8;
|
|
218
|
+
}
|
|
219
|
+
sum_f64x8 = _mm512_fmadd_pd(_mm512_cvtps_pd(a_f32x8), _mm512_cvtps_pd(b_f32x8), sum_f64x8);
|
|
220
|
+
if (count_scalars) goto nk_dot_f32_skylake_cycle;
|
|
221
|
+
|
|
222
|
+
*result = _mm512_reduce_add_pd(sum_f64x8);
|
|
223
|
+
}
|
|
224
|
+
|
|
225
|
+
/**
 *  @brief Compensated dot-product of two f64 arrays using the Dot2 scheme (Ogita, Rump & Oishi, 2005).
 *
 *  Each product is split into its rounded value and rounding error (TwoProd via FMA). Each accumulation
 *  is split into the rounded sum and its error (TwoSum). The errors are collected in a separate
 *  compensation vector. `nk_dot_stable_sum_f64x8_skylake_` folds that vector back in during the
 *  horizontal reduction.
 *
 *  NOTE(review): if a product or partial sum overflows to +/-Inf, the error terms evaluate Inf - Inf = NaN,
 *  so the result becomes NaN rather than +/-Inf.
 *
 *  @param[in] a_scalars First input array of `count_scalars` elements.
 *  @param[in] b_scalars Second input array of `count_scalars` elements.
 *  @param[in] count_scalars Number of elements in each array.
 *  @param[out] result Receives the compensated dot-product.
 */
NK_PUBLIC void nk_dot_f64_skylake(nk_f64_t const *a_scalars, nk_f64_t const *b_scalars, nk_size_t count_scalars,
                                  nk_f64_t *result) {
    __m512d running_sum_f64x8 = _mm512_setzero_pd();
    __m512d running_error_f64x8 = _mm512_setzero_pd();

    do {
        __m512d lhs_f64x8, rhs_f64x8;
        if (count_scalars >= 8) {
            lhs_f64x8 = _mm512_loadu_pd(a_scalars);
            rhs_f64x8 = _mm512_loadu_pd(b_scalars);
            a_scalars += 8, b_scalars += 8, count_scalars -= 8;
        }
        else {
            // Tail: keep only the low `count_scalars` lanes; the rest are zero-filled.
            __mmask8 const tail_mask = (__mmask8)_bzhi_u32(0xFFFFFFFF, count_scalars);
            lhs_f64x8 = _mm512_maskz_loadu_pd(tail_mask, a_scalars);
            rhs_f64x8 = _mm512_maskz_loadu_pd(tail_mask, b_scalars);
            count_scalars = 0;
        }

        // TwoProd: `product + product_error` reproduces `lhs * rhs` (barring overflow/underflow).
        __m512d const product_f64x8 = _mm512_mul_pd(lhs_f64x8, rhs_f64x8);
        __m512d const product_error_f64x8 = _mm512_fmsub_pd(lhs_f64x8, rhs_f64x8, product_f64x8);

        // TwoSum: `next_sum + sum_error` reproduces `running_sum + product`.
        __m512d const next_sum_f64x8 = _mm512_add_pd(running_sum_f64x8, product_f64x8);
        __m512d const delta_f64x8 = _mm512_sub_pd(next_sum_f64x8, running_sum_f64x8);
        __m512d const sum_part_f64x8 = _mm512_sub_pd(running_sum_f64x8, _mm512_sub_pd(next_sum_f64x8, delta_f64x8));
        __m512d const product_part_f64x8 = _mm512_sub_pd(product_f64x8, delta_f64x8);
        __m512d const sum_error_f64x8 = _mm512_add_pd(sum_part_f64x8, product_part_f64x8);

        running_sum_f64x8 = next_sum_f64x8;
        running_error_f64x8 =
            _mm512_add_pd(running_error_f64x8, _mm512_add_pd(sum_error_f64x8, product_error_f64x8));
    } while (count_scalars);

    *result = nk_dot_stable_sum_f64x8_skylake_(running_sum_f64x8, running_error_f64x8);
}
|
|
261
|
+
|
|
262
|
+
NK_PUBLIC void nk_dot_f32c_skylake(nk_f32c_t const *a_pairs, nk_f32c_t const *b_pairs, nk_size_t count_pairs,
|
|
263
|
+
nk_f64c_t *result) {
|
|
264
|
+
__m256 a_f32x8, b_f32x8;
|
|
265
|
+
__m512d sum_real_f64x8 = _mm512_setzero_pd();
|
|
266
|
+
__m512d sum_imag_f64x8 = _mm512_setzero_pd();
|
|
267
|
+
|
|
268
|
+
// We take into account, that FMS is the same as FMA with a negative multiplier.
|
|
269
|
+
// To multiply a floating-point value by -1, we can use the `XOR` instruction to flip the sign bit.
|
|
270
|
+
// This way we can avoid the shuffling and the need for separate real and imaginary parts.
|
|
271
|
+
// For the imaginary part of the product, we would need to swap the real and imaginary parts of
|
|
272
|
+
// one of the vectors.
|
|
273
|
+
__m512i const sign_flip_f64x8 = _mm512_set_epi64(0x8000000000000000, 0, 0x8000000000000000, 0, 0x8000000000000000,
|
|
274
|
+
0, 0x8000000000000000, 0);
|
|
275
|
+
nk_dot_f32c_skylake_cycle:
|
|
276
|
+
if (count_pairs < 4) {
|
|
277
|
+
__mmask8 mask = (__mmask8)_bzhi_u32(0xFFFFFFFF, count_pairs * 2);
|
|
278
|
+
a_f32x8 = _mm256_maskz_loadu_ps(mask, (nk_f32_t const *)a_pairs);
|
|
279
|
+
b_f32x8 = _mm256_maskz_loadu_ps(mask, (nk_f32_t const *)b_pairs);
|
|
280
|
+
count_pairs = 0;
|
|
281
|
+
}
|
|
282
|
+
else {
|
|
283
|
+
a_f32x8 = _mm256_loadu_ps((nk_f32_t const *)a_pairs);
|
|
284
|
+
b_f32x8 = _mm256_loadu_ps((nk_f32_t const *)b_pairs);
|
|
285
|
+
a_pairs += 4, b_pairs += 4, count_pairs -= 4;
|
|
286
|
+
}
|
|
287
|
+
__m512d a_f64x8 = _mm512_cvtps_pd(a_f32x8);
|
|
288
|
+
__m512d b_f64x8 = _mm512_cvtps_pd(b_f32x8);
|
|
289
|
+
__m512d b_swapped_f64x8 = _mm512_permute_pd(b_f64x8, 0x55);
|
|
290
|
+
sum_real_f64x8 = _mm512_fmadd_pd(a_f64x8, b_f64x8, sum_real_f64x8);
|
|
291
|
+
sum_imag_f64x8 = _mm512_fmadd_pd(a_f64x8, b_swapped_f64x8, sum_imag_f64x8);
|
|
292
|
+
if (count_pairs) goto nk_dot_f32c_skylake_cycle;
|
|
293
|
+
|
|
294
|
+
// Flip the sign bit in every second f64 before accumulation:
|
|
295
|
+
sum_real_f64x8 = _mm512_castsi512_pd(_mm512_xor_si512(_mm512_castpd_si512(sum_real_f64x8), sign_flip_f64x8));
|
|
296
|
+
|
|
297
|
+
// Reduce horizontal sums:
|
|
298
|
+
result->real = _mm512_reduce_add_pd(sum_real_f64x8);
|
|
299
|
+
result->imag = _mm512_reduce_add_pd(sum_imag_f64x8);
|
|
300
|
+
}
|
|
301
|
+
|
|
302
|
+
NK_PUBLIC void nk_vdot_f32c_skylake(nk_f32c_t const *a_pairs, nk_f32c_t const *b_pairs, nk_size_t count_pairs,
|
|
303
|
+
nk_f64c_t *result) {
|
|
304
|
+
__m256 a_f32x8, b_f32x8;
|
|
305
|
+
__m512d sum_real_f64x8 = _mm512_setzero_pd();
|
|
306
|
+
__m512d sum_imag_f64x8 = _mm512_setzero_pd();
|
|
307
|
+
|
|
308
|
+
// We take into account, that FMS is the same as FMA with a negative multiplier.
|
|
309
|
+
// To multiply a floating-point value by -1, we can use the `XOR` instruction to flip the sign bit.
|
|
310
|
+
// This way we can avoid the shuffling and the need for separate real and imaginary parts.
|
|
311
|
+
// For the imaginary part of the product, we would need to swap the real and imaginary parts of
|
|
312
|
+
// one of the vectors.
|
|
313
|
+
__m512i const sign_flip_f64x8 = _mm512_set_epi64(0x8000000000000000, 0, 0x8000000000000000, 0, 0x8000000000000000,
|
|
314
|
+
0, 0x8000000000000000, 0);
|
|
315
|
+
nk_vdot_f32c_skylake_cycle:
|
|
316
|
+
if (count_pairs < 4) {
|
|
317
|
+
__mmask8 mask = (__mmask8)_bzhi_u32(0xFFFFFFFF, count_pairs * 2);
|
|
318
|
+
a_f32x8 = _mm256_maskz_loadu_ps(mask, (nk_f32_t const *)a_pairs);
|
|
319
|
+
b_f32x8 = _mm256_maskz_loadu_ps(mask, (nk_f32_t const *)b_pairs);
|
|
320
|
+
count_pairs = 0;
|
|
321
|
+
}
|
|
322
|
+
else {
|
|
323
|
+
a_f32x8 = _mm256_loadu_ps((nk_f32_t const *)a_pairs);
|
|
324
|
+
b_f32x8 = _mm256_loadu_ps((nk_f32_t const *)b_pairs);
|
|
325
|
+
a_pairs += 4, b_pairs += 4, count_pairs -= 4;
|
|
326
|
+
}
|
|
327
|
+
__m512d a_f64x8 = _mm512_cvtps_pd(a_f32x8);
|
|
328
|
+
__m512d b_f64x8 = _mm512_cvtps_pd(b_f32x8);
|
|
329
|
+
sum_real_f64x8 = _mm512_fmadd_pd(a_f64x8, b_f64x8, sum_real_f64x8);
|
|
330
|
+
__m512d b_swapped_f64x8 = _mm512_permute_pd(b_f64x8, 0x55);
|
|
331
|
+
sum_imag_f64x8 = _mm512_fmadd_pd(a_f64x8, b_swapped_f64x8, sum_imag_f64x8);
|
|
332
|
+
if (count_pairs) goto nk_vdot_f32c_skylake_cycle;
|
|
333
|
+
|
|
334
|
+
// Flip the sign bit in every second f64 before accumulation:
|
|
335
|
+
sum_imag_f64x8 = _mm512_castsi512_pd(_mm512_xor_si512(_mm512_castpd_si512(sum_imag_f64x8), sign_flip_f64x8));
|
|
336
|
+
|
|
337
|
+
// Reduce horizontal sums:
|
|
338
|
+
result->real = _mm512_reduce_add_pd(sum_real_f64x8);
|
|
339
|
+
result->imag = _mm512_reduce_add_pd(sum_imag_f64x8);
|
|
340
|
+
}
|
|
341
|
+
|
|
342
|
+
/**
 *  Compensated complex dot product Σ a[k]·b[k] of interleaved (re, im) f64 vectors.
 *
 *  Uses the Dot2 scheme (Ogita-Rump-Oishi 2005). Every product is split by TwoProd into
 *  a rounded value plus its exact FMA-recovered error. The rounded value is added to a
 *  running sum with TwoSum, and both error terms go into a separate "carry" accumulator.
 *  - The real-part accumulator gathers lanes {ar·br, ai·bi}.
 *  - The imaginary accumulator gathers {ar·bi, ai·br}, using a re/im-swapped `b`.
 *  After the loop, the odd lanes of the real sum and the real carry are negated, so their
 *  combination yields Σ(ar·br − ai·bi).
 *
 *  @param a_pairs      First complex vector.
 *  @param b_pairs      Second complex vector.
 *  @param count_pairs  Number of complex elements in each vector.
 *  @param result       Receives the compensated complex dot product.
 */
NK_PUBLIC void nk_dot_f64c_skylake(nk_f64c_t const *a_pairs, nk_f64c_t const *b_pairs, nk_size_t count_pairs,
                                   nk_f64c_t *result) {
    // Sign bit set only in the odd (imaginary-slot) lanes; XOR with it negates exactly those lanes.
    __m512i const odd_lanes_sign_i64x8 = _mm512_set_epi64( //
        0x8000000000000000, 0x0000000000000000, 0x8000000000000000, 0x0000000000000000, //
        0x8000000000000000, 0x0000000000000000, 0x8000000000000000, 0x0000000000000000  //
    );
    __m512d real_running_f64x8 = _mm512_setzero_pd();
    __m512d real_carry_f64x8 = _mm512_setzero_pd();
    __m512d imag_running_f64x8 = _mm512_setzero_pd();
    __m512d imag_carry_f64x8 = _mm512_setzero_pd();

    do {
        __m512d lhs_f64x8, rhs_f64x8;
        if (count_pairs >= 4) {
            lhs_f64x8 = _mm512_loadu_pd((nk_f64_t const *)a_pairs);
            rhs_f64x8 = _mm512_loadu_pd((nk_f64_t const *)b_pairs);
            a_pairs += 4, b_pairs += 4, count_pairs -= 4;
        }
        else {
            // Tail: two doubles per remaining pair; unused lanes are zero-filled and contribute nothing.
            __mmask8 tail_mask = (__mmask8)_bzhi_u32(0xFFFFFFFF, count_pairs * 2);
            lhs_f64x8 = _mm512_maskz_loadu_pd(tail_mask, (nk_f64_t const *)a_pairs);
            rhs_f64x8 = _mm512_maskz_loadu_pd(tail_mask, (nk_f64_t const *)b_pairs);
            count_pairs = 0;
        }
        __m512d rhs_swapped_f64x8 = _mm512_permute_pd(rhs_f64x8, 0x55); // swap re/im inside every pair

        // Real lanes: TwoProd(lhs, rhs), then TwoSum into the running sum.
        __m512d real_product_f64x8 = _mm512_mul_pd(lhs_f64x8, rhs_f64x8);
        __m512d real_product_lost_f64x8 = _mm512_fmsub_pd(lhs_f64x8, rhs_f64x8, real_product_f64x8);
        __m512d real_next_f64x8 = _mm512_add_pd(real_running_f64x8, real_product_f64x8);
        __m512d real_seen_f64x8 = _mm512_sub_pd(real_next_f64x8, real_running_f64x8);
        __m512d real_addition_lost_f64x8 = _mm512_add_pd(
            _mm512_sub_pd(real_running_f64x8, _mm512_sub_pd(real_next_f64x8, real_seen_f64x8)),
            _mm512_sub_pd(real_product_f64x8, real_seen_f64x8));
        real_carry_f64x8 = _mm512_add_pd(real_carry_f64x8,
                                         _mm512_add_pd(real_addition_lost_f64x8, real_product_lost_f64x8));
        real_running_f64x8 = real_next_f64x8;

        // Imaginary lanes: TwoProd(lhs, rhs_swapped), then TwoSum into the running sum.
        __m512d imag_product_f64x8 = _mm512_mul_pd(lhs_f64x8, rhs_swapped_f64x8);
        __m512d imag_product_lost_f64x8 = _mm512_fmsub_pd(lhs_f64x8, rhs_swapped_f64x8, imag_product_f64x8);
        __m512d imag_next_f64x8 = _mm512_add_pd(imag_running_f64x8, imag_product_f64x8);
        __m512d imag_seen_f64x8 = _mm512_sub_pd(imag_next_f64x8, imag_running_f64x8);
        __m512d imag_addition_lost_f64x8 = _mm512_add_pd(
            _mm512_sub_pd(imag_running_f64x8, _mm512_sub_pd(imag_next_f64x8, imag_seen_f64x8)),
            _mm512_sub_pd(imag_product_f64x8, imag_seen_f64x8));
        imag_carry_f64x8 = _mm512_add_pd(imag_carry_f64x8,
                                         _mm512_add_pd(imag_addition_lost_f64x8, imag_product_lost_f64x8));
        imag_running_f64x8 = imag_next_f64x8;
    } while (count_pairs);

    // Negate the ai·bi lanes (in both sum and carry) so the real part becomes Σ(ar·br − ai·bi).
    real_running_f64x8 = _mm512_castsi512_pd(
        _mm512_xor_si512(_mm512_castpd_si512(real_running_f64x8), odd_lanes_sign_i64x8));
    real_carry_f64x8 = _mm512_castsi512_pd(
        _mm512_xor_si512(_mm512_castpd_si512(real_carry_f64x8), odd_lanes_sign_i64x8));

    // Horizontal reductions that fold the carries back in.
    result->real = nk_dot_stable_sum_f64x8_skylake_(real_running_f64x8, real_carry_f64x8);
    result->imag = nk_dot_stable_sum_f64x8_skylake_(imag_running_f64x8, imag_carry_f64x8);
}
|
|
411
|
+
|
|
412
|
+
/**
 *  Compensated conjugated complex dot product Σ conj(a[k])·b[k] of interleaved (re, im) f64 vectors.
 *
 *  Uses the Dot2 scheme (Ogita-Rump-Oishi 2005). TwoProd captures each product's rounding
 *  error with an FMA, and TwoSum captures each accumulation's rounding error. Both errors
 *  go into per-lane "carry" accumulators.
 *  - The real accumulator gathers {ar·br, ai·bi}, which already sums to ar·br + ai·bi.
 *  - The imaginary accumulator gathers {ar·bi, ai·br}.
 *  After the loop, the odd lanes of the imaginary sum and carry are negated, giving Σ(ar·bi − ai·br).
 *
 *  @param a_pairs      Vector whose elements are conjugated.
 *  @param b_pairs      Second complex vector.
 *  @param count_pairs  Number of complex elements in each vector.
 *  @param result       Receives the compensated complex result.
 */
NK_PUBLIC void nk_vdot_f64c_skylake(nk_f64c_t const *a_pairs, nk_f64c_t const *b_pairs, nk_size_t count_pairs,
                                    nk_f64c_t *result) {
    // Sign bit set only in the odd (imaginary-slot) lanes; XOR with it negates exactly those lanes.
    __m512i const odd_lanes_sign_i64x8 = _mm512_set_epi64( //
        0x8000000000000000, 0x0000000000000000, 0x8000000000000000, 0x0000000000000000, //
        0x8000000000000000, 0x0000000000000000, 0x8000000000000000, 0x0000000000000000  //
    );
    __m512d real_running_f64x8 = _mm512_setzero_pd();
    __m512d real_carry_f64x8 = _mm512_setzero_pd();
    __m512d imag_running_f64x8 = _mm512_setzero_pd();
    __m512d imag_carry_f64x8 = _mm512_setzero_pd();

    do {
        __m512d lhs_f64x8, rhs_f64x8;
        if (count_pairs >= 4) {
            lhs_f64x8 = _mm512_loadu_pd((nk_f64_t const *)a_pairs);
            rhs_f64x8 = _mm512_loadu_pd((nk_f64_t const *)b_pairs);
            a_pairs += 4, b_pairs += 4, count_pairs -= 4;
        }
        else {
            // Tail: two doubles per remaining pair; unused lanes are zero-filled and contribute nothing.
            __mmask8 tail_mask = (__mmask8)_bzhi_u32(0xFFFFFFFF, count_pairs * 2);
            lhs_f64x8 = _mm512_maskz_loadu_pd(tail_mask, (nk_f64_t const *)a_pairs);
            rhs_f64x8 = _mm512_maskz_loadu_pd(tail_mask, (nk_f64_t const *)b_pairs);
            count_pairs = 0;
        }
        __m512d rhs_swapped_f64x8 = _mm512_permute_pd(rhs_f64x8, 0x55); // swap re/im inside every pair

        // Real lanes: TwoProd(lhs, rhs), then TwoSum into the running sum.
        __m512d real_product_f64x8 = _mm512_mul_pd(lhs_f64x8, rhs_f64x8);
        __m512d real_product_lost_f64x8 = _mm512_fmsub_pd(lhs_f64x8, rhs_f64x8, real_product_f64x8);
        __m512d real_next_f64x8 = _mm512_add_pd(real_running_f64x8, real_product_f64x8);
        __m512d real_seen_f64x8 = _mm512_sub_pd(real_next_f64x8, real_running_f64x8);
        __m512d real_addition_lost_f64x8 = _mm512_add_pd(
            _mm512_sub_pd(real_running_f64x8, _mm512_sub_pd(real_next_f64x8, real_seen_f64x8)),
            _mm512_sub_pd(real_product_f64x8, real_seen_f64x8));
        real_carry_f64x8 = _mm512_add_pd(real_carry_f64x8,
                                         _mm512_add_pd(real_addition_lost_f64x8, real_product_lost_f64x8));
        real_running_f64x8 = real_next_f64x8;

        // Imaginary lanes: TwoProd(lhs, rhs_swapped), then TwoSum into the running sum.
        __m512d imag_product_f64x8 = _mm512_mul_pd(lhs_f64x8, rhs_swapped_f64x8);
        __m512d imag_product_lost_f64x8 = _mm512_fmsub_pd(lhs_f64x8, rhs_swapped_f64x8, imag_product_f64x8);
        __m512d imag_next_f64x8 = _mm512_add_pd(imag_running_f64x8, imag_product_f64x8);
        __m512d imag_seen_f64x8 = _mm512_sub_pd(imag_next_f64x8, imag_running_f64x8);
        __m512d imag_addition_lost_f64x8 = _mm512_add_pd(
            _mm512_sub_pd(imag_running_f64x8, _mm512_sub_pd(imag_next_f64x8, imag_seen_f64x8)),
            _mm512_sub_pd(imag_product_f64x8, imag_seen_f64x8));
        imag_carry_f64x8 = _mm512_add_pd(imag_carry_f64x8,
                                         _mm512_add_pd(imag_addition_lost_f64x8, imag_product_lost_f64x8));
        imag_running_f64x8 = imag_next_f64x8;
    } while (count_pairs);

    // Negate the ai·br lanes (in both sum and carry) so the imaginary part becomes Σ(ar·bi − ai·br).
    imag_running_f64x8 = _mm512_castsi512_pd(
        _mm512_xor_si512(_mm512_castpd_si512(imag_running_f64x8), odd_lanes_sign_i64x8));
    imag_carry_f64x8 = _mm512_castsi512_pd(
        _mm512_xor_si512(_mm512_castpd_si512(imag_carry_f64x8), odd_lanes_sign_i64x8));

    // Horizontal reductions that fold the carries back in.
    result->real = nk_dot_stable_sum_f64x8_skylake_(real_running_f64x8, real_carry_f64x8);
    result->imag = nk_dot_stable_sum_f64x8_skylake_(imag_running_f64x8, imag_carry_f64x8);
}
|
|
481
|
+
|
|
482
|
+
#pragma region - Smaller Floats
|
|
483
|
+
|
|
484
|
+
NK_PUBLIC void nk_dot_f16_skylake(nk_f16_t const *a_scalars, nk_f16_t const *b_scalars, nk_size_t count_scalars,
|
|
485
|
+
nk_f32_t *result) {
|
|
486
|
+
__m256i a_f16x16, b_f16x16;
|
|
487
|
+
__m512 sum_f32x16 = _mm512_setzero_ps();
|
|
488
|
+
|
|
489
|
+
nk_dot_f16_skylake_cycle:
|
|
490
|
+
if (count_scalars < 16) {
|
|
491
|
+
__mmask16 mask = (__mmask16)_bzhi_u32(0xFFFF, count_scalars);
|
|
492
|
+
a_f16x16 = _mm256_maskz_loadu_epi16(mask, a_scalars);
|
|
493
|
+
b_f16x16 = _mm256_maskz_loadu_epi16(mask, b_scalars);
|
|
494
|
+
count_scalars = 0;
|
|
495
|
+
}
|
|
496
|
+
else {
|
|
497
|
+
a_f16x16 = _mm256_loadu_si256((__m256i const *)a_scalars);
|
|
498
|
+
b_f16x16 = _mm256_loadu_si256((__m256i const *)b_scalars);
|
|
499
|
+
a_scalars += 16, b_scalars += 16, count_scalars -= 16;
|
|
500
|
+
}
|
|
501
|
+
__m512 a_f32x16 = _mm512_cvtph_ps(a_f16x16);
|
|
502
|
+
__m512 b_f32x16 = _mm512_cvtph_ps(b_f16x16);
|
|
503
|
+
sum_f32x16 = _mm512_fmadd_ps(a_f32x16, b_f32x16, sum_f32x16);
|
|
504
|
+
if (count_scalars) goto nk_dot_f16_skylake_cycle;
|
|
505
|
+
|
|
506
|
+
*result = nk_reduce_add_f32x16_skylake_(sum_f32x16);
|
|
507
|
+
}
|
|
508
|
+
|
|
509
|
+
NK_PUBLIC void nk_dot_bf16_skylake(nk_bf16_t const *a_scalars, nk_bf16_t const *b_scalars, nk_size_t count_scalars,
|
|
510
|
+
nk_f32_t *result) {
|
|
511
|
+
__m256i a_bf16x16, b_bf16x16;
|
|
512
|
+
__m512 sum_f32x16 = _mm512_setzero_ps();
|
|
513
|
+
|
|
514
|
+
nk_dot_bf16_skylake_cycle:
|
|
515
|
+
if (count_scalars < 16) {
|
|
516
|
+
__mmask16 mask = (__mmask16)_bzhi_u32(0xFFFF, count_scalars);
|
|
517
|
+
a_bf16x16 = _mm256_maskz_loadu_epi16(mask, a_scalars);
|
|
518
|
+
b_bf16x16 = _mm256_maskz_loadu_epi16(mask, b_scalars);
|
|
519
|
+
count_scalars = 0;
|
|
520
|
+
}
|
|
521
|
+
else {
|
|
522
|
+
a_bf16x16 = _mm256_loadu_si256((__m256i const *)a_scalars);
|
|
523
|
+
b_bf16x16 = _mm256_loadu_si256((__m256i const *)b_scalars);
|
|
524
|
+
a_scalars += 16, b_scalars += 16, count_scalars -= 16;
|
|
525
|
+
}
|
|
526
|
+
__m512 a_f32x16 = nk_bf16x16_to_f32x16_skylake_(a_bf16x16);
|
|
527
|
+
__m512 b_f32x16 = nk_bf16x16_to_f32x16_skylake_(b_bf16x16);
|
|
528
|
+
sum_f32x16 = _mm512_fmadd_ps(a_f32x16, b_f32x16, sum_f32x16);
|
|
529
|
+
if (count_scalars) goto nk_dot_bf16_skylake_cycle;
|
|
530
|
+
|
|
531
|
+
*result = nk_reduce_add_f32x16_skylake_(sum_f32x16);
|
|
532
|
+
}
|
|
533
|
+
|
|
534
|
+
NK_PUBLIC void nk_dot_e4m3_skylake(nk_e4m3_t const *a_scalars, nk_e4m3_t const *b_scalars, nk_size_t count_scalars,
|
|
535
|
+
nk_f32_t *result) {
|
|
536
|
+
__m128i a_e4m3x16, b_e4m3x16;
|
|
537
|
+
__m512 sum_f32x16 = _mm512_setzero_ps();
|
|
538
|
+
|
|
539
|
+
nk_dot_e4m3_skylake_cycle:
|
|
540
|
+
if (count_scalars < 16) {
|
|
541
|
+
__mmask16 mask = (__mmask16)_bzhi_u32(0xFFFF, count_scalars);
|
|
542
|
+
a_e4m3x16 = _mm_maskz_loadu_epi8(mask, a_scalars);
|
|
543
|
+
b_e4m3x16 = _mm_maskz_loadu_epi8(mask, b_scalars);
|
|
544
|
+
count_scalars = 0;
|
|
545
|
+
}
|
|
546
|
+
else {
|
|
547
|
+
a_e4m3x16 = _mm_loadu_si128((__m128i const *)a_scalars);
|
|
548
|
+
b_e4m3x16 = _mm_loadu_si128((__m128i const *)b_scalars);
|
|
549
|
+
a_scalars += 16, b_scalars += 16, count_scalars -= 16;
|
|
550
|
+
}
|
|
551
|
+
__m512 a_f32x16 = nk_e4m3x16_to_f32x16_skylake_(a_e4m3x16);
|
|
552
|
+
__m512 b_f32x16 = nk_e4m3x16_to_f32x16_skylake_(b_e4m3x16);
|
|
553
|
+
sum_f32x16 = _mm512_fmadd_ps(a_f32x16, b_f32x16, sum_f32x16);
|
|
554
|
+
if (count_scalars) goto nk_dot_e4m3_skylake_cycle;
|
|
555
|
+
|
|
556
|
+
*result = nk_reduce_add_f32x16_skylake_(sum_f32x16);
|
|
557
|
+
}
|
|
558
|
+
|
|
559
|
+
NK_PUBLIC void nk_dot_e5m2_skylake(nk_e5m2_t const *a_scalars, nk_e5m2_t const *b_scalars, nk_size_t count_scalars,
|
|
560
|
+
nk_f32_t *result) {
|
|
561
|
+
__m128i a_e5m2x16, b_e5m2x16;
|
|
562
|
+
__m512 sum_f32x16 = _mm512_setzero_ps();
|
|
563
|
+
|
|
564
|
+
nk_dot_e5m2_skylake_cycle:
|
|
565
|
+
if (count_scalars < 16) {
|
|
566
|
+
__mmask16 mask = (__mmask16)_bzhi_u32(0xFFFF, count_scalars);
|
|
567
|
+
a_e5m2x16 = _mm_maskz_loadu_epi8(mask, a_scalars);
|
|
568
|
+
b_e5m2x16 = _mm_maskz_loadu_epi8(mask, b_scalars);
|
|
569
|
+
count_scalars = 0;
|
|
570
|
+
}
|
|
571
|
+
else {
|
|
572
|
+
a_e5m2x16 = _mm_loadu_si128((__m128i const *)a_scalars);
|
|
573
|
+
b_e5m2x16 = _mm_loadu_si128((__m128i const *)b_scalars);
|
|
574
|
+
a_scalars += 16, b_scalars += 16, count_scalars -= 16;
|
|
575
|
+
}
|
|
576
|
+
__m512 a_f32x16 = nk_e5m2x16_to_f32x16_skylake_(a_e5m2x16);
|
|
577
|
+
__m512 b_f32x16 = nk_e5m2x16_to_f32x16_skylake_(b_e5m2x16);
|
|
578
|
+
sum_f32x16 = _mm512_fmadd_ps(a_f32x16, b_f32x16, sum_f32x16);
|
|
579
|
+
if (count_scalars) goto nk_dot_e5m2_skylake_cycle;
|
|
580
|
+
|
|
581
|
+
*result = nk_reduce_add_f32x16_skylake_(sum_f32x16);
|
|
582
|
+
}
|
|
583
|
+
|
|
584
|
+
NK_PUBLIC void nk_dot_e2m3_skylake(nk_e2m3_t const *a_scalars, nk_e2m3_t const *b_scalars, nk_size_t count_scalars,
|
|
585
|
+
nk_f32_t *result) {
|
|
586
|
+
// Integer dot product for e2m3 using dual-VPSHUFB (LUT) + VPMADDUBSW (unsigned×signed).
|
|
587
|
+
// 64 elements per iteration using AVX-512BW. Result = i32_dot / 256.0f (exact).
|
|
588
|
+
//
|
|
589
|
+
// LUTs replicated 4× for 512-bit VPSHUFB (operates per 128-bit lane):
|
|
590
|
+
__m512i const lut_lower_u8x64 = _mm512_set_epi8( //
|
|
591
|
+
30, 28, 26, 24, 22, 20, 18, 16, 14, 12, 10, 8, 6, 4, 2, 0, //
|
|
592
|
+
30, 28, 26, 24, 22, 20, 18, 16, 14, 12, 10, 8, 6, 4, 2, 0, //
|
|
593
|
+
30, 28, 26, 24, 22, 20, 18, 16, 14, 12, 10, 8, 6, 4, 2, 0, //
|
|
594
|
+
30, 28, 26, 24, 22, 20, 18, 16, 14, 12, 10, 8, 6, 4, 2, 0);
|
|
595
|
+
__m512i const lut_upper_u8x64 = _mm512_set_epi8( //
|
|
596
|
+
120, 112, 104, 96, 88, 80, 72, 64, 60, 56, 52, 48, 44, 40, 36, 32, //
|
|
597
|
+
120, 112, 104, 96, 88, 80, 72, 64, 60, 56, 52, 48, 44, 40, 36, 32, //
|
|
598
|
+
120, 112, 104, 96, 88, 80, 72, 64, 60, 56, 52, 48, 44, 40, 36, 32, //
|
|
599
|
+
120, 112, 104, 96, 88, 80, 72, 64, 60, 56, 52, 48, 44, 40, 36, 32);
|
|
600
|
+
__m512i const nibble_mask_u8x64 = _mm512_set1_epi8(0x0F);
|
|
601
|
+
__m512i const magnitude_mask_u8x64 = _mm512_set1_epi8(0x1F);
|
|
602
|
+
__m512i const half_select_u8x64 = _mm512_set1_epi8(0x10);
|
|
603
|
+
__m512i const sign_mask_u8x64 = _mm512_set1_epi8(0x20);
|
|
604
|
+
__m512i const ones_i16x32 = _mm512_set1_epi16(1);
|
|
605
|
+
__m512i sum_i32x16 = _mm512_setzero_si512();
|
|
606
|
+
__m512i a_e2m3_u8x64, b_e2m3_u8x64;
|
|
607
|
+
|
|
608
|
+
nk_dot_e2m3_skylake_cycle:
|
|
609
|
+
if (count_scalars < 64) {
|
|
610
|
+
__mmask64 mask = (__mmask64)_bzhi_u64(0xFFFFFFFFFFFFFFFF, count_scalars);
|
|
611
|
+
a_e2m3_u8x64 = _mm512_maskz_loadu_epi8(mask, a_scalars);
|
|
612
|
+
b_e2m3_u8x64 = _mm512_maskz_loadu_epi8(mask, b_scalars);
|
|
613
|
+
count_scalars = 0;
|
|
614
|
+
}
|
|
615
|
+
else {
|
|
616
|
+
a_e2m3_u8x64 = _mm512_loadu_si512((__m512i const *)a_scalars);
|
|
617
|
+
b_e2m3_u8x64 = _mm512_loadu_si512((__m512i const *)b_scalars);
|
|
618
|
+
a_scalars += 64, b_scalars += 64, count_scalars -= 64;
|
|
619
|
+
}
|
|
620
|
+
|
|
621
|
+
// Extract 5-bit magnitude, then split into low 4 bits (VPSHUFB index) and bit 4 (hi/lo select)
|
|
622
|
+
__m512i a_magnitude_u8x64 = _mm512_and_si512(a_e2m3_u8x64, magnitude_mask_u8x64);
|
|
623
|
+
__m512i b_magnitude_u8x64 = _mm512_and_si512(b_e2m3_u8x64, magnitude_mask_u8x64);
|
|
624
|
+
__m512i a_shuffle_index_u8x64 = _mm512_and_si512(a_magnitude_u8x64, nibble_mask_u8x64);
|
|
625
|
+
__m512i b_shuffle_index_u8x64 = _mm512_and_si512(b_magnitude_u8x64, nibble_mask_u8x64);
|
|
626
|
+
|
|
627
|
+
// Bit-4 select via kmask (cleaner than Haswell's vector compare)
|
|
628
|
+
__mmask64 a_upper_select = _mm512_test_epi8_mask(a_magnitude_u8x64, half_select_u8x64);
|
|
629
|
+
__mmask64 b_upper_select = _mm512_test_epi8_mask(b_magnitude_u8x64, half_select_u8x64);
|
|
630
|
+
|
|
631
|
+
// Dual VPSHUFB + mask-blend for 32-entry LUT
|
|
632
|
+
__m512i a_unsigned_u8x64 = _mm512_mask_blend_epi8(a_upper_select,
|
|
633
|
+
_mm512_shuffle_epi8(lut_lower_u8x64, a_shuffle_index_u8x64),
|
|
634
|
+
_mm512_shuffle_epi8(lut_upper_u8x64, a_shuffle_index_u8x64));
|
|
635
|
+
__m512i b_unsigned_u8x64 = _mm512_mask_blend_epi8(b_upper_select,
|
|
636
|
+
_mm512_shuffle_epi8(lut_lower_u8x64, b_shuffle_index_u8x64),
|
|
637
|
+
_mm512_shuffle_epi8(lut_upper_u8x64, b_shuffle_index_u8x64));
|
|
638
|
+
|
|
639
|
+
// Combined sign: (a ^ b) & 0x20, negate b where signs differ using kmask
|
|
640
|
+
__m512i sign_combined_u8x64 = _mm512_and_si512(_mm512_xor_si512(a_e2m3_u8x64, b_e2m3_u8x64), sign_mask_u8x64);
|
|
641
|
+
__mmask64 negate_mask = _mm512_test_epi8_mask(sign_combined_u8x64, sign_combined_u8x64);
|
|
642
|
+
__m512i b_signed_i8x64 = _mm512_mask_sub_epi8(b_unsigned_u8x64, negate_mask, _mm512_setzero_si512(),
|
|
643
|
+
b_unsigned_u8x64);
|
|
644
|
+
|
|
645
|
+
// VPMADDUBSW: a_unsigned[u8] × b_signed[i8] → i16 pairs
|
|
646
|
+
__m512i products_i16x32 = _mm512_maddubs_epi16(a_unsigned_u8x64, b_signed_i8x64);
|
|
647
|
+
// VPMADDWD with ones: i16 pairs → i32
|
|
648
|
+
sum_i32x16 = _mm512_add_epi32(sum_i32x16, _mm512_madd_epi16(products_i16x32, ones_i16x32));
|
|
649
|
+
|
|
650
|
+
if (count_scalars) goto nk_dot_e2m3_skylake_cycle;
|
|
651
|
+
*result = (nk_f32_t)_mm512_reduce_add_epi32(sum_i32x16) / 256.0f;
|
|
652
|
+
}
|
|
653
|
+
|
|
654
|
+
NK_PUBLIC void nk_dot_e3m2_skylake(nk_e3m2_t const *a_scalars, nk_e3m2_t const *b_scalars, nk_size_t count_scalars,
|
|
655
|
+
nk_f32_t *result) {
|
|
656
|
+
// Integer dot product for e3m2 using dual-VPSHUFB (low-byte LUT) + VPMADDWD (i16×i16→i32).
|
|
657
|
+
// 64 elements per iteration using AVX-512BW. Magnitudes reach 448, requiring i16.
|
|
658
|
+
// Result = i32_dot / 256.0f (exact, no rounding error).
|
|
659
|
+
//
|
|
660
|
+
__m512i const lut_lo_lower_u8x64 = _mm512_set_epi8( //
|
|
661
|
+
28, 24, 20, 16, 14, 12, 10, 8, 7, 6, 5, 4, 3, 2, 1, 0, //
|
|
662
|
+
28, 24, 20, 16, 14, 12, 10, 8, 7, 6, 5, 4, 3, 2, 1, 0, //
|
|
663
|
+
28, 24, 20, 16, 14, 12, 10, 8, 7, 6, 5, 4, 3, 2, 1, 0, //
|
|
664
|
+
28, 24, 20, 16, 14, 12, 10, 8, 7, 6, 5, 4, 3, 2, 1, 0);
|
|
665
|
+
__m512i const lut_lo_upper_u8x64 = _mm512_set_epi8( //
|
|
666
|
+
(char)192, (char)128, 64, 0, (char)224, (char)192, (char)160, (char)128, 112, 96, 80, 64, 56, 48, 40, 32, //
|
|
667
|
+
(char)192, (char)128, 64, 0, (char)224, (char)192, (char)160, (char)128, 112, 96, 80, 64, 56, 48, 40, 32, //
|
|
668
|
+
(char)192, (char)128, 64, 0, (char)224, (char)192, (char)160, (char)128, 112, 96, 80, 64, 56, 48, 40, 32, //
|
|
669
|
+
(char)192, (char)128, 64, 0, (char)224, (char)192, (char)160, (char)128, 112, 96, 80, 64, 56, 48, 40, 32);
|
|
670
|
+
__m512i const nibble_mask_u8x64 = _mm512_set1_epi8(0x0F);
|
|
671
|
+
__m512i const magnitude_mask_u8x64 = _mm512_set1_epi8(0x1F);
|
|
672
|
+
__m512i const half_select_u8x64 = _mm512_set1_epi8(0x10);
|
|
673
|
+
__m512i const sign_mask_u8x64 = _mm512_set1_epi8(0x20);
|
|
674
|
+
__m512i const ones_u8x64 = _mm512_set1_epi8(1);
|
|
675
|
+
__m512i sum_i32x16 = _mm512_setzero_si512();
|
|
676
|
+
__m512i a_e3m2_u8x64, b_e3m2_u8x64;
|
|
677
|
+
|
|
678
|
+
nk_dot_e3m2_skylake_cycle:
|
|
679
|
+
if (count_scalars < 64) {
|
|
680
|
+
__mmask64 mask = (__mmask64)_bzhi_u64(0xFFFFFFFFFFFFFFFF, count_scalars);
|
|
681
|
+
a_e3m2_u8x64 = _mm512_maskz_loadu_epi8(mask, a_scalars);
|
|
682
|
+
b_e3m2_u8x64 = _mm512_maskz_loadu_epi8(mask, b_scalars);
|
|
683
|
+
count_scalars = 0;
|
|
684
|
+
}
|
|
685
|
+
else {
|
|
686
|
+
a_e3m2_u8x64 = _mm512_loadu_si512((__m512i const *)a_scalars);
|
|
687
|
+
b_e3m2_u8x64 = _mm512_loadu_si512((__m512i const *)b_scalars);
|
|
688
|
+
a_scalars += 64, b_scalars += 64, count_scalars -= 64;
|
|
689
|
+
}
|
|
690
|
+
|
|
691
|
+
// Extract 5-bit magnitude, split into low 4 bits and bit 4
|
|
692
|
+
__m512i a_magnitude_u8x64 = _mm512_and_si512(a_e3m2_u8x64, magnitude_mask_u8x64);
|
|
693
|
+
__m512i b_magnitude_u8x64 = _mm512_and_si512(b_e3m2_u8x64, magnitude_mask_u8x64);
|
|
694
|
+
__m512i a_shuffle_index_u8x64 = _mm512_and_si512(a_magnitude_u8x64, nibble_mask_u8x64);
|
|
695
|
+
__m512i b_shuffle_index_u8x64 = _mm512_and_si512(b_magnitude_u8x64, nibble_mask_u8x64);
|
|
696
|
+
|
|
697
|
+
// Bit-4 select via kmask
|
|
698
|
+
__mmask64 a_upper_select = _mm512_test_epi8_mask(a_magnitude_u8x64, half_select_u8x64);
|
|
699
|
+
__mmask64 b_upper_select = _mm512_test_epi8_mask(b_magnitude_u8x64, half_select_u8x64);
|
|
700
|
+
|
|
701
|
+
// Dual VPSHUFB + mask-blend for low bytes
|
|
702
|
+
__m512i a_lo_bytes_u8x64 = _mm512_mask_blend_epi8(a_upper_select,
|
|
703
|
+
_mm512_shuffle_epi8(lut_lo_lower_u8x64, a_shuffle_index_u8x64),
|
|
704
|
+
_mm512_shuffle_epi8(lut_lo_upper_u8x64, a_shuffle_index_u8x64));
|
|
705
|
+
__m512i b_lo_bytes_u8x64 = _mm512_mask_blend_epi8(b_upper_select,
|
|
706
|
+
_mm512_shuffle_epi8(lut_lo_lower_u8x64, b_shuffle_index_u8x64),
|
|
707
|
+
_mm512_shuffle_epi8(lut_lo_upper_u8x64, b_shuffle_index_u8x64));
|
|
708
|
+
|
|
709
|
+
// High byte: 1 iff magnitude >= 28 (unsigned compare via _mm512_cmpge_epu8_mask)
|
|
710
|
+
__mmask64 a_hi_mask = _mm512_cmpge_epu8_mask(a_magnitude_u8x64, _mm512_set1_epi8(28));
|
|
711
|
+
__mmask64 b_hi_mask = _mm512_cmpge_epu8_mask(b_magnitude_u8x64, _mm512_set1_epi8(28));
|
|
712
|
+
__m512i a_hi_bytes_u8x64 = _mm512_maskz_mov_epi8(a_hi_mask, ones_u8x64);
|
|
713
|
+
__m512i b_hi_bytes_u8x64 = _mm512_maskz_mov_epi8(b_hi_mask, ones_u8x64);
|
|
714
|
+
|
|
715
|
+
// Interleave low and high bytes into i16
|
|
716
|
+
__m512i a_lo_i16x32 = _mm512_unpacklo_epi8(a_lo_bytes_u8x64, a_hi_bytes_u8x64);
|
|
717
|
+
__m512i a_hi_i16x32 = _mm512_unpackhi_epi8(a_lo_bytes_u8x64, a_hi_bytes_u8x64);
|
|
718
|
+
__m512i b_lo_i16x32 = _mm512_unpacklo_epi8(b_lo_bytes_u8x64, b_hi_bytes_u8x64);
|
|
719
|
+
__m512i b_hi_i16x32 = _mm512_unpackhi_epi8(b_lo_bytes_u8x64, b_hi_bytes_u8x64);
|
|
720
|
+
|
|
721
|
+
// Combined sign: (a ^ b) & 0x20, need to apply at i16 level
|
|
722
|
+
// Compute sign mask at u8 level, widen to match unpacklo/unpackhi ordering via PEXT
|
|
723
|
+
__m512i sign_combined_u8x64 = _mm512_and_si512(_mm512_xor_si512(a_e3m2_u8x64, b_e3m2_u8x64), sign_mask_u8x64);
|
|
724
|
+
__mmask64 negate_u8_mask = _mm512_test_epi8_mask(sign_combined_u8x64, sign_combined_u8x64);
|
|
725
|
+
// Extract bits matching unpacklo element ordering (bytes 0-7,16-23,32-39,48-55 per 64-byte vector)
|
|
726
|
+
__mmask32 negate_lo_i16 = (__mmask32)_pext_u64(negate_u8_mask, 0x00FF00FF00FF00FFULL);
|
|
727
|
+
__mmask32 negate_hi_i16 = (__mmask32)_pext_u64(negate_u8_mask, 0xFF00FF00FF00FF00ULL);
|
|
728
|
+
// Negate b at i16 level using mask_sub
|
|
729
|
+
__m512i b_signed_lo_i16x32 = _mm512_mask_sub_epi16(b_lo_i16x32, negate_lo_i16, _mm512_setzero_si512(), b_lo_i16x32);
|
|
730
|
+
__m512i b_signed_hi_i16x32 = _mm512_mask_sub_epi16(b_hi_i16x32, negate_hi_i16, _mm512_setzero_si512(), b_hi_i16x32);
|
|
731
|
+
|
|
732
|
+
// VPMADDWD: a_i16 × b_signed_i16 → i32 accumulator
|
|
733
|
+
sum_i32x16 = _mm512_add_epi32(sum_i32x16, _mm512_madd_epi16(a_lo_i16x32, b_signed_lo_i16x32));
|
|
734
|
+
sum_i32x16 = _mm512_add_epi32(sum_i32x16, _mm512_madd_epi16(a_hi_i16x32, b_signed_hi_i16x32));
|
|
735
|
+
|
|
736
|
+
if (count_scalars) goto nk_dot_e3m2_skylake_cycle;
|
|
737
|
+
*result = (nk_f32_t)_mm512_reduce_add_epi32(sum_i32x16) / 256.0f;
|
|
738
|
+
}
|
|
739
|
+
|
|
740
|
+
#pragma endregion - Smaller Floats
|
|
741
|
+
|
|
742
|
+
#pragma region - Small Integers
|
|
743
|
+
|
|
744
|
+
NK_PUBLIC void nk_dot_i8_skylake(nk_i8_t const *a_scalars, nk_i8_t const *b_scalars, nk_size_t count_scalars,
|
|
745
|
+
nk_i32_t *result) {
|
|
746
|
+
__m512i sum_i32x16 = _mm512_setzero_si512();
|
|
747
|
+
nk_size_t idx_scalars = 0;
|
|
748
|
+
for (; idx_scalars + 32 <= count_scalars; idx_scalars += 32) {
|
|
749
|
+
// Load 32 bytes at a time and widen to i16
|
|
750
|
+
__m256i a_i8x32 = _mm256_loadu_si256((__m256i const *)(a_scalars + idx_scalars));
|
|
751
|
+
__m256i b_i8x32 = _mm256_loadu_si256((__m256i const *)(b_scalars + idx_scalars));
|
|
752
|
+
__m512i a_i16x32 = _mm512_cvtepi8_epi16(a_i8x32);
|
|
753
|
+
__m512i b_i16x32 = _mm512_cvtepi8_epi16(b_i8x32);
|
|
754
|
+
// VPMADDWD: 5cy (0.5/cy) @ p05 - multiply adjacent i16 pairs, add to i32
|
|
755
|
+
sum_i32x16 = _mm512_add_epi32(sum_i32x16, _mm512_madd_epi16(a_i16x32, b_i16x32));
|
|
756
|
+
}
|
|
757
|
+
nk_i32_t sum = _mm512_reduce_add_epi32(sum_i32x16);
|
|
758
|
+
for (; idx_scalars < count_scalars; ++idx_scalars) sum += (nk_i32_t)a_scalars[idx_scalars] * b_scalars[idx_scalars];
|
|
759
|
+
*result = sum;
|
|
760
|
+
}
|
|
761
|
+
|
|
762
|
+
NK_PUBLIC void nk_dot_u8_skylake(nk_u8_t const *a_scalars, nk_u8_t const *b_scalars, nk_size_t count_scalars,
|
|
763
|
+
nk_u32_t *result) {
|
|
764
|
+
__m512i sum_i32x16 = _mm512_setzero_si512();
|
|
765
|
+
nk_size_t idx_scalars = 0;
|
|
766
|
+
for (; idx_scalars + 32 <= count_scalars; idx_scalars += 32) {
|
|
767
|
+
// Load 32 bytes and zero-extend to i16 (u8 → u16 via zero-extension)
|
|
768
|
+
__m256i a_u8x32 = _mm256_loadu_si256((__m256i const *)(a_scalars + idx_scalars));
|
|
769
|
+
__m256i b_u8x32 = _mm256_loadu_si256((__m256i const *)(b_scalars + idx_scalars));
|
|
770
|
+
__m512i a_u16x32 = _mm512_cvtepu8_epi16(a_u8x32);
|
|
771
|
+
__m512i b_u16x32 = _mm512_cvtepu8_epi16(b_u8x32);
|
|
772
|
+
// VPMADDWD: 5cy (0.5/cy) @ p05 - multiply adjacent i16 pairs, add to i32
|
|
773
|
+
sum_i32x16 = _mm512_add_epi32(sum_i32x16, _mm512_madd_epi16(a_u16x32, b_u16x32));
|
|
774
|
+
}
|
|
775
|
+
nk_u32_t sum = (nk_u32_t)_mm512_reduce_add_epi32(sum_i32x16);
|
|
776
|
+
for (; idx_scalars < count_scalars; ++idx_scalars) sum += (nk_u32_t)a_scalars[idx_scalars] * b_scalars[idx_scalars];
|
|
777
|
+
*result = sum;
|
|
778
|
+
}
|
|
779
|
+
|
|
780
|
+
/**
 *  Accumulator state for the compensated (Dot2) f64 dot-product kernels: a per-lane running sum
 *  and a per-lane accumulator of rounding errors that are folded back in at finalization.
 */
typedef struct nk_dot_f64x8_state_skylake_t nk_dot_f64x8_state_skylake_t;
struct nk_dot_f64x8_state_skylake_t {
    __m512d sum_f64x8;          // running sum, one partial per lane
    __m512d compensation_f64x8; // accumulated product and addition rounding errors per lane
};
|
|
784
|
+
|
|
785
|
+
/** Resets a compensated f64 dot-product state: zero running sum, zero error compensation. */
NK_INTERNAL void nk_dot_f64x8_init_skylake(nk_dot_f64x8_state_skylake_t *state) {
    __m512d const zero_f64x8 = _mm512_setzero_pd();
    state->compensation_f64x8 = zero_f64x8;
    state->sum_f64x8 = zero_f64x8;
}
|
|
789
|
+
|
|
790
|
+
/**
 *  Folds one 8-lane f64 chunk into a compensated (Dot2) dot-product state.
 *
 *  TwoProd splits a·b into its rounded value and exact FMA-recovered error. TwoSum then adds
 *  the rounded value to the running sum and recovers that addition's error. Both errors are
 *  added to the state's compensation accumulator. `depth_offset` and `active_dimensions`
 *  are accepted for interface uniformity and are unused here.
 */
NK_INTERNAL void nk_dot_f64x8_update_skylake(nk_dot_f64x8_state_skylake_t *state, nk_b512_vec_t a, nk_b512_vec_t b,
                                             nk_size_t depth_offset, nk_size_t active_dimensions) {
    nk_unused_(depth_offset);
    nk_unused_(active_dimensions);
    __m512d const lhs_f64x8 = a.zmm_pd;
    __m512d const rhs_f64x8 = b.zmm_pd;
    __m512d const running_f64x8 = state->sum_f64x8;
    __m512d const carry_f64x8 = state->compensation_f64x8;

    // TwoProd: rounded product and its exact residual.
    __m512d const product_f64x8 = _mm512_mul_pd(lhs_f64x8, rhs_f64x8);
    __m512d const product_lost_f64x8 = _mm512_fmsub_pd(lhs_f64x8, rhs_f64x8, product_f64x8);

    // TwoSum: rounded new sum and the error lost while forming it.
    __m512d const next_f64x8 = _mm512_add_pd(running_f64x8, product_f64x8);
    __m512d const seen_f64x8 = _mm512_sub_pd(next_f64x8, running_f64x8);
    __m512d const running_lost_f64x8 = _mm512_sub_pd(running_f64x8, _mm512_sub_pd(next_f64x8, seen_f64x8));
    __m512d const product_part_lost_f64x8 = _mm512_sub_pd(product_f64x8, seen_f64x8);
    __m512d const addition_lost_f64x8 = _mm512_add_pd(running_lost_f64x8, product_part_lost_f64x8);

    state->compensation_f64x8 = _mm512_add_pd(carry_f64x8, _mm512_add_pd(addition_lost_f64x8, product_lost_f64x8));
    state->sum_f64x8 = next_f64x8;
}
|
|
814
|
+
|
|
815
|
+
NK_INTERNAL void nk_dot_f64x8_finalize_skylake( //
|
|
816
|
+
nk_dot_f64x8_state_skylake_t const *state_a, nk_dot_f64x8_state_skylake_t const *state_b, //
|
|
817
|
+
nk_dot_f64x8_state_skylake_t const *state_c, nk_dot_f64x8_state_skylake_t const *state_d, //
|
|
818
|
+
nk_size_t total_dimensions, nk_b256_vec_t *result) {
|
|
819
|
+
nk_unused_(total_dimensions);
|
|
820
|
+
// Compensated horizontal reduction preserving Dot2 error tracking per state
|
|
821
|
+
result->f64s[0] = nk_dot_stable_sum_f64x8_skylake_(state_a->sum_f64x8, state_a->compensation_f64x8);
|
|
822
|
+
result->f64s[1] = nk_dot_stable_sum_f64x8_skylake_(state_b->sum_f64x8, state_b->compensation_f64x8);
|
|
823
|
+
result->f64s[2] = nk_dot_stable_sum_f64x8_skylake_(state_c->sum_f64x8, state_c->compensation_f64x8);
|
|
824
|
+
result->f64s[3] = nk_dot_stable_sum_f64x8_skylake_(state_d->sum_f64x8, state_d->compensation_f64x8);
|
|
825
|
+
}
|
|
826
|
+
|
|
827
|
+
/** @brief Accumulator for f32 dot products: eight f64 partial sums, one per widened input lane. */
typedef struct nk_dot_f32x8_state_skylake_t nk_dot_f32x8_state_skylake_t;
struct nk_dot_f32x8_state_skylake_t {
    __m512d sum_f64x8;
};
|
|
830
|
+
|
|
831
|
+
/** @brief Resets all eight f64 partial sums of `state` to zero. */
NK_INTERNAL void nk_dot_f32x8_init_skylake(nk_dot_f32x8_state_skylake_t *state) {
    __m512d const zero_f64x8 = _mm512_setzero_pd();
    state->sum_f64x8 = zero_f64x8;
}
|
|
834
|
+
|
|
835
|
+
/**
 *  @brief Adds the lane-wise products of eight f32 pairs to the f64 accumulator.
 *
 *  Both operands are widened to f64 first. f32 -> f64 conversion is exact, so the only rounding
 *  per lane is the one inside the FMA.
 *
 *  @param state              Accumulator updated in place.
 *  @param a                  Eight f32 lanes, read through the `ymm_ps` view.
 *  @param b                  Eight f32 lanes, read through the `ymm_ps` view.
 *  @param depth_offset       Unused by this kernel.
 *  @param active_dimensions  Unused by this kernel.
 */
NK_INTERNAL void nk_dot_f32x8_update_skylake(nk_dot_f32x8_state_skylake_t *state, nk_b256_vec_t a, nk_b256_vec_t b,
                                             nk_size_t depth_offset, nk_size_t active_dimensions) {
    nk_unused_(depth_offset);
    nk_unused_(active_dimensions);
    state->sum_f64x8 = _mm512_fmadd_pd(_mm512_cvtps_pd(a.ymm_ps), _mm512_cvtps_pd(b.ymm_ps), state->sum_f64x8);
}
|
|
845
|
+
|
|
846
|
+
NK_INTERNAL void nk_dot_f32x8_finalize_skylake( //
|
|
847
|
+
nk_dot_f32x8_state_skylake_t const *state_a, nk_dot_f32x8_state_skylake_t const *state_b, //
|
|
848
|
+
nk_dot_f32x8_state_skylake_t const *state_c, nk_dot_f32x8_state_skylake_t const *state_d, //
|
|
849
|
+
nk_size_t total_dimensions, nk_b256_vec_t *result) {
|
|
850
|
+
nk_unused_(total_dimensions);
|
|
851
|
+
// ILP-optimized 4-way horizontal reduction for f64
|
|
852
|
+
// Step 1: 8->4 for all 4 states (extract high 256-bit half and add to low half)
|
|
853
|
+
__m256d sum_a_f64x4 = _mm256_add_pd(_mm512_castpd512_pd256(state_a->sum_f64x8),
|
|
854
|
+
_mm512_extractf64x4_pd(state_a->sum_f64x8, 1));
|
|
855
|
+
__m256d sum_b_f64x4 = _mm256_add_pd(_mm512_castpd512_pd256(state_b->sum_f64x8),
|
|
856
|
+
_mm512_extractf64x4_pd(state_b->sum_f64x8, 1));
|
|
857
|
+
__m256d sum_c_f64x4 = _mm256_add_pd(_mm512_castpd512_pd256(state_c->sum_f64x8),
|
|
858
|
+
_mm512_extractf64x4_pd(state_c->sum_f64x8, 1));
|
|
859
|
+
__m256d sum_d_f64x4 = _mm256_add_pd(_mm512_castpd512_pd256(state_d->sum_f64x8),
|
|
860
|
+
_mm512_extractf64x4_pd(state_d->sum_f64x8, 1));
|
|
861
|
+
// Step 2: 4->2 for all 4 states (extract high 128-bit half and add to low half)
|
|
862
|
+
__m128d sum_a_f64x2 = _mm_add_pd(_mm256_castpd256_pd128(sum_a_f64x4), _mm256_extractf128_pd(sum_a_f64x4, 1));
|
|
863
|
+
__m128d sum_b_f64x2 = _mm_add_pd(_mm256_castpd256_pd128(sum_b_f64x4), _mm256_extractf128_pd(sum_b_f64x4, 1));
|
|
864
|
+
__m128d sum_c_f64x2 = _mm_add_pd(_mm256_castpd256_pd128(sum_c_f64x4), _mm256_extractf128_pd(sum_c_f64x4, 1));
|
|
865
|
+
__m128d sum_d_f64x2 = _mm_add_pd(_mm256_castpd256_pd128(sum_d_f64x4), _mm256_extractf128_pd(sum_d_f64x4, 1));
|
|
866
|
+
// Step 3: Horizontal add pairs: [a0+a1, b0+b1] and [c0+c1, d0+d1]
|
|
867
|
+
__m128d sum_ab_f64x2 = _mm_hadd_pd(sum_a_f64x2, sum_b_f64x2);
|
|
868
|
+
__m128d sum_cd_f64x2 = _mm_hadd_pd(sum_c_f64x2, sum_d_f64x2);
|
|
869
|
+
result->ymm_pd = _mm256_set_m128d(sum_cd_f64x2, sum_ab_f64x2);
|
|
870
|
+
}
|
|
871
|
+
|
|
872
|
+
#pragma endregion - Traditional Floats
|
|
873
|
+
|
|
874
|
+
// bf16 dot products reuse the shared `nk_dot_through_f32_state_skylake_t_` accumulator type.
// NOTE(review): going by its name it accumulates after converting inputs to f32 — confirm against its definition.
typedef nk_dot_through_f32_state_skylake_t_ nk_dot_bf16x16_state_skylake_t;

// f16 dot products alias the same shared accumulator type as bf16.
typedef nk_dot_through_f32_state_skylake_t_ nk_dot_f16x16_state_skylake_t;
|
|
877
|
+
|
|
878
|
+
/**
 *  @brief Accumulator for E2M3 dot products: sixteen i32 lanes of integer partial sums.
 *
 *  The sums carry a 256x scale (16x per decoded operand), which the finalize step divides out.
 */
typedef struct nk_dot_e2m3x64_state_skylake_t nk_dot_e2m3x64_state_skylake_t;
struct nk_dot_e2m3x64_state_skylake_t {
    __m512i sum_i32x16;
};
|
|
881
|
+
|
|
882
|
+
/** @brief Resets all sixteen i32 partial sums of `state` to zero. */
NK_INTERNAL void nk_dot_e2m3x64_init_skylake(nk_dot_e2m3x64_state_skylake_t *state) {
    __m512i const zero_i32x16 = _mm512_setzero_si512();
    state->sum_i32x16 = zero_i32x16;
}
|
|
885
|
+
|
|
886
|
+
/**
 *  @brief Accumulates 64 E2M3 pairs into the sixteen i32 lanes of `state`.
 *
 *  Byte layout: bit 5 is the sign and bits 0-4 are the magnitude code; higher bits are ignored.
 *  Each magnitude code is decoded through two 16-entry byte tables into an integer equal to 16x
 *  its E2M3 value (0..120), so every product carries a 256x scale that finalize removes.
 *  The product sign (XOR of both sign bits) is folded into `b`. This lets `_mm512_maddubs_epi16`
 *  treat `a` as unsigned and `b` as signed. Its pair sums are bounded by 2 * 120 * 120 and cannot
 *  saturate i16.
 *
 *  @param state              Accumulator updated in place.
 *  @param a                  64 E2M3 codes, one per byte of the `zmm` view.
 *  @param b                  64 E2M3 codes, one per byte of the `zmm` view.
 *  @param depth_offset       Unused by this kernel.
 *  @param active_dimensions  Unused by this kernel.
 */
NK_INTERNAL void nk_dot_e2m3x64_update_skylake(nk_dot_e2m3x64_state_skylake_t *state, nk_b512_vec_t a, nk_b512_vec_t b,
                                               nk_size_t depth_offset, nk_size_t active_dimensions) {
    nk_unused_(depth_offset);
    nk_unused_(active_dimensions);

    // Byte k of every 128-bit lane: decoded value of magnitude code k (first table) or k + 16 (second table)
    __m512i const codes_0_to_15_u8x64 =
        _mm512_broadcast_i32x4(_mm_setr_epi8(0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30));
    __m512i const codes_16_to_31_u8x64 =
        _mm512_broadcast_i32x4(_mm_setr_epi8(32, 36, 40, 44, 48, 52, 56, 60, 64, 72, 80, 88, 96, 104, 112, 120));
    __m512i const table_index_mask_u8x64 = _mm512_set1_epi8(0x0F);
    __m512i const second_table_bit_u8x64 = _mm512_set1_epi8(0x10);
    __m512i const sign_bit_u8x64 = _mm512_set1_epi8(0x20);
    __m512i const ones_i16x32 = _mm512_set1_epi16(1);

    // |a|: look everything up in the first table, then overwrite bytes whose code is >= 16 from the second
    __m512i const a_index_u8x64 = _mm512_and_si512(a.zmm, table_index_mask_u8x64);
    __mmask64 const a_in_second_table = _mm512_test_epi8_mask(a.zmm, second_table_bit_u8x64);
    __m512i const a_abs_u8x64 = _mm512_mask_shuffle_epi8(_mm512_shuffle_epi8(codes_0_to_15_u8x64, a_index_u8x64),
                                                          a_in_second_table, codes_16_to_31_u8x64, a_index_u8x64);

    // |b|: same two-table decode
    __m512i const b_index_u8x64 = _mm512_and_si512(b.zmm, table_index_mask_u8x64);
    __mmask64 const b_in_second_table = _mm512_test_epi8_mask(b.zmm, second_table_bit_u8x64);
    __m512i const b_abs_u8x64 = _mm512_mask_shuffle_epi8(_mm512_shuffle_epi8(codes_0_to_15_u8x64, b_index_u8x64),
                                                          b_in_second_table, codes_16_to_31_u8x64, b_index_u8x64);

    // Negate b wherever exactly one operand is negative
    __mmask64 const signs_differ = _mm512_test_epi8_mask(_mm512_xor_si512(a.zmm, b.zmm), sign_bit_u8x64);
    __m512i const b_with_sign_i8x64 =
        _mm512_mask_sub_epi8(b_abs_u8x64, signs_differ, _mm512_setzero_si512(), b_abs_u8x64);

    // u8 x i8 -> pairwise i16 sums -> pairwise i32 sums -> accumulate
    __m512i const pair_sums_i16x32 = _mm512_maddubs_epi16(a_abs_u8x64, b_with_sign_i8x64);
    __m512i const quad_sums_i32x16 = _mm512_madd_epi16(pair_sums_i16x32, ones_i16x32);
    state->sum_i32x16 = _mm512_add_epi32(state->sum_i32x16, quad_sums_i32x16);
}
|
|
927
|
+
|
|
928
|
+
NK_INTERNAL void nk_dot_e2m3x64_finalize_skylake( //
|
|
929
|
+
nk_dot_e2m3x64_state_skylake_t const *state_a, nk_dot_e2m3x64_state_skylake_t const *state_b, //
|
|
930
|
+
nk_dot_e2m3x64_state_skylake_t const *state_c, nk_dot_e2m3x64_state_skylake_t const *state_d, //
|
|
931
|
+
nk_size_t total_dimensions, nk_b128_vec_t *results) {
|
|
932
|
+
nk_unused_(total_dimensions);
|
|
933
|
+
|
|
934
|
+
// 16→8 for all 4 states (extract high 256-bit half and add to low half)
|
|
935
|
+
__m256i sum_a_i32x8 = _mm256_add_epi32(_mm512_castsi512_si256(state_a->sum_i32x16),
|
|
936
|
+
_mm512_extracti32x8_epi32(state_a->sum_i32x16, 1));
|
|
937
|
+
__m256i sum_b_i32x8 = _mm256_add_epi32(_mm512_castsi512_si256(state_b->sum_i32x16),
|
|
938
|
+
_mm512_extracti32x8_epi32(state_b->sum_i32x16, 1));
|
|
939
|
+
__m256i sum_c_i32x8 = _mm256_add_epi32(_mm512_castsi512_si256(state_c->sum_i32x16),
|
|
940
|
+
_mm512_extracti32x8_epi32(state_c->sum_i32x16, 1));
|
|
941
|
+
__m256i sum_d_i32x8 = _mm256_add_epi32(_mm512_castsi512_si256(state_d->sum_i32x16),
|
|
942
|
+
_mm512_extracti32x8_epi32(state_d->sum_i32x16, 1));
|
|
943
|
+
|
|
944
|
+
// 8→4: extract high 128-bit half and add to low half
|
|
945
|
+
__m128i sum_a_i32x4 = _mm_add_epi32(_mm256_castsi256_si128(sum_a_i32x8), _mm256_extracti128_si256(sum_a_i32x8, 1));
|
|
946
|
+
__m128i sum_b_i32x4 = _mm_add_epi32(_mm256_castsi256_si128(sum_b_i32x8), _mm256_extracti128_si256(sum_b_i32x8, 1));
|
|
947
|
+
__m128i sum_c_i32x4 = _mm_add_epi32(_mm256_castsi256_si128(sum_c_i32x8), _mm256_extracti128_si256(sum_c_i32x8, 1));
|
|
948
|
+
__m128i sum_d_i32x4 = _mm_add_epi32(_mm256_castsi256_si128(sum_d_i32x8), _mm256_extracti128_si256(sum_d_i32x8, 1));
|
|
949
|
+
|
|
950
|
+
// 4×4 transpose and reduce (same as Sierra/Haswell integer finalize)
|
|
951
|
+
__m128i transpose_ab_low_i32x4 = _mm_unpacklo_epi32(sum_a_i32x4, sum_b_i32x4);
|
|
952
|
+
__m128i transpose_cd_low_i32x4 = _mm_unpacklo_epi32(sum_c_i32x4, sum_d_i32x4);
|
|
953
|
+
__m128i transpose_ab_high_i32x4 = _mm_unpackhi_epi32(sum_a_i32x4, sum_b_i32x4);
|
|
954
|
+
__m128i transpose_cd_high_i32x4 = _mm_unpackhi_epi32(sum_c_i32x4, sum_d_i32x4);
|
|
955
|
+
__m128i lane0_i32x4 = _mm_unpacklo_epi64(transpose_ab_low_i32x4, transpose_cd_low_i32x4);
|
|
956
|
+
__m128i lane1_i32x4 = _mm_unpackhi_epi64(transpose_ab_low_i32x4, transpose_cd_low_i32x4);
|
|
957
|
+
__m128i lane2_i32x4 = _mm_unpacklo_epi64(transpose_ab_high_i32x4, transpose_cd_high_i32x4);
|
|
958
|
+
__m128i lane3_i32x4 = _mm_unpackhi_epi64(transpose_ab_high_i32x4, transpose_cd_high_i32x4);
|
|
959
|
+
__m128i sum_i32x4 = _mm_add_epi32(_mm_add_epi32(lane0_i32x4, lane1_i32x4), _mm_add_epi32(lane2_i32x4, lane3_i32x4));
|
|
960
|
+
|
|
961
|
+
__m128 sum_f32x4 = _mm_mul_ps(_mm_cvtepi32_ps(sum_i32x4), _mm_set1_ps(1.0f / 256.0f));
|
|
962
|
+
results->xmm = _mm_castps_si128(sum_f32x4);
|
|
963
|
+
}
|
|
964
|
+
|
|
965
|
+
/**
 *  @brief Accumulator for E3M2 dot products: two sets of sixteen i32 partial sums.
 *
 *  The update step adds products from bytes 0-7 of each 128-bit lane to `sum_a_i32x16` and
 *  products from bytes 8-15 to `sum_b_i32x16`. Finalize merges both and removes the 256x scale.
 */
typedef struct nk_dot_e3m2x64_state_skylake_t nk_dot_e3m2x64_state_skylake_t;
struct nk_dot_e3m2x64_state_skylake_t {
    __m512i sum_a_i32x16;
    __m512i sum_b_i32x16;
};
|
|
969
|
+
|
|
970
|
+
/** @brief Resets both sets of i32 partial sums of `state` to zero. */
NK_INTERNAL void nk_dot_e3m2x64_init_skylake(nk_dot_e3m2x64_state_skylake_t *state) {
    __m512i const zero_i32x16 = _mm512_setzero_si512();
    state->sum_a_i32x16 = zero_i32x16;
    state->sum_b_i32x16 = zero_i32x16;
}
|
|
974
|
+
|
|
975
|
+
/**
 *  @brief Accumulates 64 E3M2 pairs into the two i32 accumulators of `state`.
 *
 *  Byte layout: bit 5 is the sign and bits 0-4 are the magnitude code; higher bits are ignored.
 *  Each magnitude decodes to an integer equal to 16x its E3M2 value (0..448), built as follows:
 *  - the low byte comes from two 16-entry tables;
 *  - the high byte is 1 only for codes 28..31 (values 256..448).
 *
 *  Bytes are widened to i16 by interleaving low and high bytes. The product sign (XOR of both
 *  sign bits) is applied to the i16 `b` values, and `_mm512_madd_epi16` forms i32 pair sums.
 *  Bytes 0-7 of each 128-bit lane feed `sum_a_i32x16`; bytes 8-15 feed `sum_b_i32x16`.
 *  Results carry a 256x scale that finalize removes.
 *
 *  @param state              Accumulator updated in place.
 *  @param a                  64 E3M2 codes, one per byte of the `zmm` view.
 *  @param b                  64 E3M2 codes, one per byte of the `zmm` view.
 *  @param depth_offset       Unused by this kernel.
 *  @param active_dimensions  Unused by this kernel.
 */
NK_INTERNAL void nk_dot_e3m2x64_update_skylake(nk_dot_e3m2x64_state_skylake_t *state, nk_b512_vec_t a, nk_b512_vec_t b,
                                               nk_size_t depth_offset, nk_size_t active_dimensions) {
    nk_unused_(depth_offset);
    nk_unused_(active_dimensions);

    // Byte k of every 128-bit lane: low byte of the decoded value for code k (first table) or k + 16 (second)
    __m512i const low_bytes_0_to_15_u8x64 =
        _mm512_broadcast_i32x4(_mm_setr_epi8(0, 1, 2, 3, 4, 5, 6, 7, 8, 10, 12, 14, 16, 20, 24, 28));
    __m512i const low_bytes_16_to_31_u8x64 = _mm512_broadcast_i32x4(_mm_setr_epi8( //
        32, 40, 48, 56, 64, 80, 96, 112, (char)128, (char)160, (char)192, (char)224, 0, 64, (char)128, (char)192));
    __m512i const table_index_mask_u8x64 = _mm512_set1_epi8(0x0F);
    __m512i const magnitude_mask_u8x64 = _mm512_set1_epi8(0x1F);
    __m512i const second_table_bit_u8x64 = _mm512_set1_epi8(0x10);
    __m512i const first_two_byte_code_u8x64 = _mm512_set1_epi8(28);
    __m512i const ones_u8x64 = _mm512_set1_epi8(1);
    __m512i const sign_bit_i16x32 = _mm512_set1_epi16(0x20);
    __m512i const zeros_i32x16 = _mm512_setzero_si512();

    // Low bytes of |a| and |b|: first-table lookup, overwritten from the second table where code >= 16
    __m512i const a_index_u8x64 = _mm512_and_si512(a.zmm, table_index_mask_u8x64);
    __m512i const b_index_u8x64 = _mm512_and_si512(b.zmm, table_index_mask_u8x64);
    __mmask64 const a_in_second_table = _mm512_test_epi8_mask(a.zmm, second_table_bit_u8x64);
    __mmask64 const b_in_second_table = _mm512_test_epi8_mask(b.zmm, second_table_bit_u8x64);
    __m512i const a_low_u8x64 = _mm512_mask_shuffle_epi8(_mm512_shuffle_epi8(low_bytes_0_to_15_u8x64, a_index_u8x64),
                                                          a_in_second_table, low_bytes_16_to_31_u8x64, a_index_u8x64);
    __m512i const b_low_u8x64 = _mm512_mask_shuffle_epi8(_mm512_shuffle_epi8(low_bytes_0_to_15_u8x64, b_index_u8x64),
                                                          b_in_second_table, low_bytes_16_to_31_u8x64, b_index_u8x64);

    // High bytes: 1 for magnitude codes 28..31, 0 otherwise
    __mmask64 const a_two_bytes =
        _mm512_cmpge_epu8_mask(_mm512_and_si512(a.zmm, magnitude_mask_u8x64), first_two_byte_code_u8x64);
    __mmask64 const b_two_bytes =
        _mm512_cmpge_epu8_mask(_mm512_and_si512(b.zmm, magnitude_mask_u8x64), first_two_byte_code_u8x64);
    __m512i const a_high_u8x64 = _mm512_maskz_mov_epi8(a_two_bytes, ones_u8x64);
    __m512i const b_high_u8x64 = _mm512_maskz_mov_epi8(b_two_bytes, ones_u8x64);

    // Widen to i16: within each 128-bit lane bytes 0-7 go to the "front" vector, bytes 8-15 to the "back" one
    __m512i const a_front_i16x32 = _mm512_unpacklo_epi8(a_low_u8x64, a_high_u8x64);
    __m512i const a_back_i16x32 = _mm512_unpackhi_epi8(a_low_u8x64, a_high_u8x64);
    __m512i const b_front_i16x32 = _mm512_unpacklo_epi8(b_low_u8x64, b_high_u8x64);
    __m512i const b_back_i16x32 = _mm512_unpackhi_epi8(b_low_u8x64, b_high_u8x64);

    // Product sign per i16 lane: widen a ^ b with the same interleave, then test bit 5 of the low byte
    __m512i const sign_xor_u8x64 = _mm512_xor_si512(a.zmm, b.zmm);
    __mmask32 const flip_front =
        _mm512_test_epi16_mask(_mm512_unpacklo_epi8(sign_xor_u8x64, zeros_i32x16), sign_bit_i16x32);
    __mmask32 const flip_back =
        _mm512_test_epi16_mask(_mm512_unpackhi_epi8(sign_xor_u8x64, zeros_i32x16), sign_bit_i16x32);
    __m512i const b_front_signed_i16x32 = _mm512_mask_sub_epi16(b_front_i16x32, flip_front, zeros_i32x16, b_front_i16x32);
    __m512i const b_back_signed_i16x32 = _mm512_mask_sub_epi16(b_back_i16x32, flip_back, zeros_i32x16, b_back_i16x32);

    // i16 x i16 -> pairwise i32 sums, accumulated separately for the two halves
    state->sum_a_i32x16 = _mm512_add_epi32(state->sum_a_i32x16, _mm512_madd_epi16(a_front_i16x32, b_front_signed_i16x32));
    state->sum_b_i32x16 = _mm512_add_epi32(state->sum_b_i32x16, _mm512_madd_epi16(a_back_i16x32, b_back_signed_i16x32));
}
|
|
1030
|
+
|
|
1031
|
+
NK_INTERNAL void nk_dot_e3m2x64_finalize_skylake( //
|
|
1032
|
+
nk_dot_e3m2x64_state_skylake_t const *state_a, nk_dot_e3m2x64_state_skylake_t const *state_b, //
|
|
1033
|
+
nk_dot_e3m2x64_state_skylake_t const *state_c, nk_dot_e3m2x64_state_skylake_t const *state_d, //
|
|
1034
|
+
nk_size_t total_dimensions, nk_b128_vec_t *results) {
|
|
1035
|
+
nk_unused_(total_dimensions);
|
|
1036
|
+
|
|
1037
|
+
// Merge two accumulators per state
|
|
1038
|
+
__m512i merged_a = _mm512_add_epi32(state_a->sum_a_i32x16, state_a->sum_b_i32x16);
|
|
1039
|
+
__m512i merged_b = _mm512_add_epi32(state_b->sum_a_i32x16, state_b->sum_b_i32x16);
|
|
1040
|
+
__m512i merged_c = _mm512_add_epi32(state_c->sum_a_i32x16, state_c->sum_b_i32x16);
|
|
1041
|
+
__m512i merged_d = _mm512_add_epi32(state_d->sum_a_i32x16, state_d->sum_b_i32x16);
|
|
1042
|
+
|
|
1043
|
+
// 16→8
|
|
1044
|
+
__m256i sum_a_i32x8 = _mm256_add_epi32(_mm512_castsi512_si256(merged_a), _mm512_extracti32x8_epi32(merged_a, 1));
|
|
1045
|
+
__m256i sum_b_i32x8 = _mm256_add_epi32(_mm512_castsi512_si256(merged_b), _mm512_extracti32x8_epi32(merged_b, 1));
|
|
1046
|
+
__m256i sum_c_i32x8 = _mm256_add_epi32(_mm512_castsi512_si256(merged_c), _mm512_extracti32x8_epi32(merged_c, 1));
|
|
1047
|
+
__m256i sum_d_i32x8 = _mm256_add_epi32(_mm512_castsi512_si256(merged_d), _mm512_extracti32x8_epi32(merged_d, 1));
|
|
1048
|
+
|
|
1049
|
+
// 8→4
|
|
1050
|
+
__m128i sum_a_i32x4 = _mm_add_epi32(_mm256_castsi256_si128(sum_a_i32x8), _mm256_extracti128_si256(sum_a_i32x8, 1));
|
|
1051
|
+
__m128i sum_b_i32x4 = _mm_add_epi32(_mm256_castsi256_si128(sum_b_i32x8), _mm256_extracti128_si256(sum_b_i32x8, 1));
|
|
1052
|
+
__m128i sum_c_i32x4 = _mm_add_epi32(_mm256_castsi256_si128(sum_c_i32x8), _mm256_extracti128_si256(sum_c_i32x8, 1));
|
|
1053
|
+
__m128i sum_d_i32x4 = _mm_add_epi32(_mm256_castsi256_si128(sum_d_i32x8), _mm256_extracti128_si256(sum_d_i32x8, 1));
|
|
1054
|
+
|
|
1055
|
+
// 4×4 transpose and reduce
|
|
1056
|
+
__m128i transpose_ab_low_i32x4 = _mm_unpacklo_epi32(sum_a_i32x4, sum_b_i32x4);
|
|
1057
|
+
__m128i transpose_cd_low_i32x4 = _mm_unpacklo_epi32(sum_c_i32x4, sum_d_i32x4);
|
|
1058
|
+
__m128i transpose_ab_high_i32x4 = _mm_unpackhi_epi32(sum_a_i32x4, sum_b_i32x4);
|
|
1059
|
+
__m128i transpose_cd_high_i32x4 = _mm_unpackhi_epi32(sum_c_i32x4, sum_d_i32x4);
|
|
1060
|
+
__m128i lane0_i32x4 = _mm_unpacklo_epi64(transpose_ab_low_i32x4, transpose_cd_low_i32x4);
|
|
1061
|
+
__m128i lane1_i32x4 = _mm_unpackhi_epi64(transpose_ab_low_i32x4, transpose_cd_low_i32x4);
|
|
1062
|
+
__m128i lane2_i32x4 = _mm_unpacklo_epi64(transpose_ab_high_i32x4, transpose_cd_high_i32x4);
|
|
1063
|
+
__m128i lane3_i32x4 = _mm_unpackhi_epi64(transpose_ab_high_i32x4, transpose_cd_high_i32x4);
|
|
1064
|
+
__m128i sum_i32x4 = _mm_add_epi32(_mm_add_epi32(lane0_i32x4, lane1_i32x4), _mm_add_epi32(lane2_i32x4, lane3_i32x4));
|
|
1065
|
+
|
|
1066
|
+
__m128 sum_f32x4 = _mm_mul_ps(_mm_cvtepi32_ps(sum_i32x4), _mm_set1_ps(1.0f / 256.0f));
|
|
1067
|
+
results->xmm = _mm_castps_si128(sum_f32x4);
|
|
1068
|
+
}
|
|
1069
|
+
|
|
1070
|
+
#pragma endregion - Small Integers
|
|
1071
|
+
|
|
1072
|
+
#if defined(__clang__)
|
|
1073
|
+
#pragma clang attribute pop
|
|
1074
|
+
#elif defined(__GNUC__)
|
|
1075
|
+
#pragma GCC pop_options
|
|
1076
|
+
#endif
|
|
1077
|
+
|
|
1078
|
+
#if defined(__cplusplus)
|
|
1079
|
+
} // extern "C"
|
|
1080
|
+
#endif
|
|
1081
|
+
|
|
1082
|
+
#endif // NK_TARGET_SKYLAKE
|
|
1083
|
+
#endif // NK_TARGET_X86_
|
|
1084
|
+
#endif // NK_DOT_SKYLAKE_H
|