numkong 7.0.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (294) hide show
  1. package/LICENSE +201 -0
  2. package/README.md +495 -0
  3. package/binding.gyp +540 -0
  4. package/c/dispatch.h +512 -0
  5. package/c/dispatch_bf16.c +389 -0
  6. package/c/dispatch_bf16c.c +52 -0
  7. package/c/dispatch_e2m3.c +263 -0
  8. package/c/dispatch_e3m2.c +243 -0
  9. package/c/dispatch_e4m3.c +276 -0
  10. package/c/dispatch_e5m2.c +272 -0
  11. package/c/dispatch_f16.c +376 -0
  12. package/c/dispatch_f16c.c +58 -0
  13. package/c/dispatch_f32.c +378 -0
  14. package/c/dispatch_f32c.c +99 -0
  15. package/c/dispatch_f64.c +296 -0
  16. package/c/dispatch_f64c.c +98 -0
  17. package/c/dispatch_i16.c +96 -0
  18. package/c/dispatch_i32.c +89 -0
  19. package/c/dispatch_i4.c +150 -0
  20. package/c/dispatch_i64.c +86 -0
  21. package/c/dispatch_i8.c +289 -0
  22. package/c/dispatch_other.c +330 -0
  23. package/c/dispatch_u1.c +148 -0
  24. package/c/dispatch_u16.c +124 -0
  25. package/c/dispatch_u32.c +118 -0
  26. package/c/dispatch_u4.c +150 -0
  27. package/c/dispatch_u64.c +102 -0
  28. package/c/dispatch_u8.c +303 -0
  29. package/c/numkong.c +950 -0
  30. package/include/README.md +573 -0
  31. package/include/module.modulemap +129 -0
  32. package/include/numkong/attention/sapphireamx.h +1361 -0
  33. package/include/numkong/attention/sme.h +2066 -0
  34. package/include/numkong/attention.h +49 -0
  35. package/include/numkong/capabilities.h +748 -0
  36. package/include/numkong/cast/README.md +262 -0
  37. package/include/numkong/cast/haswell.h +975 -0
  38. package/include/numkong/cast/icelake.h +470 -0
  39. package/include/numkong/cast/neon.h +1192 -0
  40. package/include/numkong/cast/rvv.h +1021 -0
  41. package/include/numkong/cast/sapphire.h +262 -0
  42. package/include/numkong/cast/serial.h +2262 -0
  43. package/include/numkong/cast/skylake.h +856 -0
  44. package/include/numkong/cast/v128relaxed.h +180 -0
  45. package/include/numkong/cast.h +230 -0
  46. package/include/numkong/curved/README.md +223 -0
  47. package/include/numkong/curved/genoa.h +182 -0
  48. package/include/numkong/curved/haswell.h +276 -0
  49. package/include/numkong/curved/neon.h +205 -0
  50. package/include/numkong/curved/neonbfdot.h +212 -0
  51. package/include/numkong/curved/neonhalf.h +212 -0
  52. package/include/numkong/curved/rvv.h +305 -0
  53. package/include/numkong/curved/serial.h +207 -0
  54. package/include/numkong/curved/skylake.h +457 -0
  55. package/include/numkong/curved/smef64.h +506 -0
  56. package/include/numkong/curved.h +517 -0
  57. package/include/numkong/curved.hpp +144 -0
  58. package/include/numkong/dot/README.md +425 -0
  59. package/include/numkong/dot/alder.h +563 -0
  60. package/include/numkong/dot/genoa.h +315 -0
  61. package/include/numkong/dot/haswell.h +1688 -0
  62. package/include/numkong/dot/icelake.h +883 -0
  63. package/include/numkong/dot/neon.h +818 -0
  64. package/include/numkong/dot/neonbfdot.h +244 -0
  65. package/include/numkong/dot/neonfhm.h +360 -0
  66. package/include/numkong/dot/neonhalf.h +198 -0
  67. package/include/numkong/dot/neonsdot.h +508 -0
  68. package/include/numkong/dot/rvv.h +714 -0
  69. package/include/numkong/dot/rvvbb.h +72 -0
  70. package/include/numkong/dot/rvvbf16.h +123 -0
  71. package/include/numkong/dot/rvvhalf.h +129 -0
  72. package/include/numkong/dot/sapphire.h +141 -0
  73. package/include/numkong/dot/serial.h +838 -0
  74. package/include/numkong/dot/sierra.h +405 -0
  75. package/include/numkong/dot/skylake.h +1084 -0
  76. package/include/numkong/dot/sve.h +379 -0
  77. package/include/numkong/dot/svebfdot.h +74 -0
  78. package/include/numkong/dot/svehalf.h +123 -0
  79. package/include/numkong/dot/v128relaxed.h +1258 -0
  80. package/include/numkong/dot.h +1070 -0
  81. package/include/numkong/dot.hpp +94 -0
  82. package/include/numkong/dots/README.md +496 -0
  83. package/include/numkong/dots/alder.h +114 -0
  84. package/include/numkong/dots/genoa.h +94 -0
  85. package/include/numkong/dots/haswell.h +295 -0
  86. package/include/numkong/dots/icelake.h +171 -0
  87. package/include/numkong/dots/neon.h +120 -0
  88. package/include/numkong/dots/neonbfdot.h +58 -0
  89. package/include/numkong/dots/neonfhm.h +94 -0
  90. package/include/numkong/dots/neonhalf.h +57 -0
  91. package/include/numkong/dots/neonsdot.h +108 -0
  92. package/include/numkong/dots/rvv.h +2486 -0
  93. package/include/numkong/dots/sapphireamx.h +3973 -0
  94. package/include/numkong/dots/serial.h +2844 -0
  95. package/include/numkong/dots/sierra.h +97 -0
  96. package/include/numkong/dots/skylake.h +196 -0
  97. package/include/numkong/dots/sme.h +5372 -0
  98. package/include/numkong/dots/smebi32.h +461 -0
  99. package/include/numkong/dots/smef64.h +1318 -0
  100. package/include/numkong/dots/smehalf.h +47 -0
  101. package/include/numkong/dots/v128relaxed.h +294 -0
  102. package/include/numkong/dots.h +2804 -0
  103. package/include/numkong/dots.hpp +639 -0
  104. package/include/numkong/each/README.md +469 -0
  105. package/include/numkong/each/haswell.h +1658 -0
  106. package/include/numkong/each/icelake.h +272 -0
  107. package/include/numkong/each/neon.h +1104 -0
  108. package/include/numkong/each/neonbfdot.h +212 -0
  109. package/include/numkong/each/neonhalf.h +410 -0
  110. package/include/numkong/each/rvv.h +1121 -0
  111. package/include/numkong/each/sapphire.h +477 -0
  112. package/include/numkong/each/serial.h +260 -0
  113. package/include/numkong/each/skylake.h +1562 -0
  114. package/include/numkong/each.h +2146 -0
  115. package/include/numkong/each.hpp +434 -0
  116. package/include/numkong/geospatial/README.md +147 -0
  117. package/include/numkong/geospatial/haswell.h +593 -0
  118. package/include/numkong/geospatial/neon.h +571 -0
  119. package/include/numkong/geospatial/rvv.h +701 -0
  120. package/include/numkong/geospatial/serial.h +309 -0
  121. package/include/numkong/geospatial/skylake.h +577 -0
  122. package/include/numkong/geospatial/v128relaxed.h +613 -0
  123. package/include/numkong/geospatial.h +453 -0
  124. package/include/numkong/geospatial.hpp +235 -0
  125. package/include/numkong/matrix.hpp +336 -0
  126. package/include/numkong/maxsim/README.md +187 -0
  127. package/include/numkong/maxsim/alder.h +511 -0
  128. package/include/numkong/maxsim/genoa.h +115 -0
  129. package/include/numkong/maxsim/haswell.h +553 -0
  130. package/include/numkong/maxsim/icelake.h +480 -0
  131. package/include/numkong/maxsim/neonsdot.h +394 -0
  132. package/include/numkong/maxsim/sapphireamx.h +877 -0
  133. package/include/numkong/maxsim/serial.h +490 -0
  134. package/include/numkong/maxsim/sme.h +929 -0
  135. package/include/numkong/maxsim/v128relaxed.h +280 -0
  136. package/include/numkong/maxsim.h +571 -0
  137. package/include/numkong/maxsim.hpp +133 -0
  138. package/include/numkong/mesh/README.md +227 -0
  139. package/include/numkong/mesh/haswell.h +2235 -0
  140. package/include/numkong/mesh/neon.h +1329 -0
  141. package/include/numkong/mesh/neonbfdot.h +842 -0
  142. package/include/numkong/mesh/neonhalf.h +616 -0
  143. package/include/numkong/mesh/rvv.h +916 -0
  144. package/include/numkong/mesh/serial.h +742 -0
  145. package/include/numkong/mesh/skylake.h +1135 -0
  146. package/include/numkong/mesh/v128relaxed.h +1052 -0
  147. package/include/numkong/mesh.h +652 -0
  148. package/include/numkong/mesh.hpp +762 -0
  149. package/include/numkong/numkong.h +78 -0
  150. package/include/numkong/numkong.hpp +57 -0
  151. package/include/numkong/probability/README.md +173 -0
  152. package/include/numkong/probability/haswell.h +267 -0
  153. package/include/numkong/probability/neon.h +225 -0
  154. package/include/numkong/probability/rvv.h +409 -0
  155. package/include/numkong/probability/serial.h +169 -0
  156. package/include/numkong/probability/skylake.h +324 -0
  157. package/include/numkong/probability.h +383 -0
  158. package/include/numkong/probability.hpp +120 -0
  159. package/include/numkong/random.h +50 -0
  160. package/include/numkong/random.hpp +285 -0
  161. package/include/numkong/reduce/README.md +547 -0
  162. package/include/numkong/reduce/alder.h +632 -0
  163. package/include/numkong/reduce/genoa.h +201 -0
  164. package/include/numkong/reduce/haswell.h +3783 -0
  165. package/include/numkong/reduce/icelake.h +549 -0
  166. package/include/numkong/reduce/neon.h +3841 -0
  167. package/include/numkong/reduce/neonbfdot.h +353 -0
  168. package/include/numkong/reduce/neonfhm.h +665 -0
  169. package/include/numkong/reduce/neonhalf.h +157 -0
  170. package/include/numkong/reduce/neonsdot.h +357 -0
  171. package/include/numkong/reduce/rvv.h +3407 -0
  172. package/include/numkong/reduce/serial.h +757 -0
  173. package/include/numkong/reduce/sierra.h +338 -0
  174. package/include/numkong/reduce/skylake.h +3792 -0
  175. package/include/numkong/reduce/v128relaxed.h +2302 -0
  176. package/include/numkong/reduce.h +1597 -0
  177. package/include/numkong/reduce.hpp +633 -0
  178. package/include/numkong/scalar/README.md +89 -0
  179. package/include/numkong/scalar/haswell.h +113 -0
  180. package/include/numkong/scalar/neon.h +122 -0
  181. package/include/numkong/scalar/neonhalf.h +70 -0
  182. package/include/numkong/scalar/rvv.h +211 -0
  183. package/include/numkong/scalar/sapphire.h +63 -0
  184. package/include/numkong/scalar/serial.h +332 -0
  185. package/include/numkong/scalar/v128relaxed.h +56 -0
  186. package/include/numkong/scalar.h +683 -0
  187. package/include/numkong/set/README.md +179 -0
  188. package/include/numkong/set/haswell.h +334 -0
  189. package/include/numkong/set/icelake.h +485 -0
  190. package/include/numkong/set/neon.h +364 -0
  191. package/include/numkong/set/rvv.h +226 -0
  192. package/include/numkong/set/rvvbb.h +117 -0
  193. package/include/numkong/set/serial.h +174 -0
  194. package/include/numkong/set/sve.h +185 -0
  195. package/include/numkong/set/v128relaxed.h +240 -0
  196. package/include/numkong/set.h +457 -0
  197. package/include/numkong/set.hpp +114 -0
  198. package/include/numkong/sets/README.md +149 -0
  199. package/include/numkong/sets/haswell.h +63 -0
  200. package/include/numkong/sets/icelake.h +66 -0
  201. package/include/numkong/sets/neon.h +61 -0
  202. package/include/numkong/sets/serial.h +43 -0
  203. package/include/numkong/sets/smebi32.h +1099 -0
  204. package/include/numkong/sets/v128relaxed.h +58 -0
  205. package/include/numkong/sets.h +339 -0
  206. package/include/numkong/sparse/README.md +156 -0
  207. package/include/numkong/sparse/icelake.h +463 -0
  208. package/include/numkong/sparse/neon.h +288 -0
  209. package/include/numkong/sparse/serial.h +117 -0
  210. package/include/numkong/sparse/sve2.h +507 -0
  211. package/include/numkong/sparse/turin.h +322 -0
  212. package/include/numkong/sparse.h +363 -0
  213. package/include/numkong/sparse.hpp +113 -0
  214. package/include/numkong/spatial/README.md +435 -0
  215. package/include/numkong/spatial/alder.h +607 -0
  216. package/include/numkong/spatial/genoa.h +290 -0
  217. package/include/numkong/spatial/haswell.h +960 -0
  218. package/include/numkong/spatial/icelake.h +586 -0
  219. package/include/numkong/spatial/neon.h +773 -0
  220. package/include/numkong/spatial/neonbfdot.h +165 -0
  221. package/include/numkong/spatial/neonhalf.h +118 -0
  222. package/include/numkong/spatial/neonsdot.h +261 -0
  223. package/include/numkong/spatial/rvv.h +984 -0
  224. package/include/numkong/spatial/rvvbf16.h +123 -0
  225. package/include/numkong/spatial/rvvhalf.h +117 -0
  226. package/include/numkong/spatial/sapphire.h +343 -0
  227. package/include/numkong/spatial/serial.h +346 -0
  228. package/include/numkong/spatial/sierra.h +323 -0
  229. package/include/numkong/spatial/skylake.h +606 -0
  230. package/include/numkong/spatial/sve.h +224 -0
  231. package/include/numkong/spatial/svebfdot.h +122 -0
  232. package/include/numkong/spatial/svehalf.h +109 -0
  233. package/include/numkong/spatial/v128relaxed.h +717 -0
  234. package/include/numkong/spatial.h +1425 -0
  235. package/include/numkong/spatial.hpp +183 -0
  236. package/include/numkong/spatials/README.md +580 -0
  237. package/include/numkong/spatials/alder.h +94 -0
  238. package/include/numkong/spatials/genoa.h +94 -0
  239. package/include/numkong/spatials/haswell.h +219 -0
  240. package/include/numkong/spatials/icelake.h +113 -0
  241. package/include/numkong/spatials/neon.h +109 -0
  242. package/include/numkong/spatials/neonbfdot.h +60 -0
  243. package/include/numkong/spatials/neonfhm.h +92 -0
  244. package/include/numkong/spatials/neonhalf.h +58 -0
  245. package/include/numkong/spatials/neonsdot.h +109 -0
  246. package/include/numkong/spatials/rvv.h +1960 -0
  247. package/include/numkong/spatials/sapphireamx.h +1149 -0
  248. package/include/numkong/spatials/serial.h +226 -0
  249. package/include/numkong/spatials/sierra.h +96 -0
  250. package/include/numkong/spatials/skylake.h +184 -0
  251. package/include/numkong/spatials/sme.h +1901 -0
  252. package/include/numkong/spatials/smef64.h +465 -0
  253. package/include/numkong/spatials/v128relaxed.h +240 -0
  254. package/include/numkong/spatials.h +3021 -0
  255. package/include/numkong/spatials.hpp +508 -0
  256. package/include/numkong/tensor.hpp +1592 -0
  257. package/include/numkong/trigonometry/README.md +184 -0
  258. package/include/numkong/trigonometry/haswell.h +652 -0
  259. package/include/numkong/trigonometry/neon.h +639 -0
  260. package/include/numkong/trigonometry/rvv.h +699 -0
  261. package/include/numkong/trigonometry/serial.h +703 -0
  262. package/include/numkong/trigonometry/skylake.h +721 -0
  263. package/include/numkong/trigonometry/v128relaxed.h +666 -0
  264. package/include/numkong/trigonometry.h +467 -0
  265. package/include/numkong/trigonometry.hpp +166 -0
  266. package/include/numkong/types.h +1384 -0
  267. package/include/numkong/types.hpp +5603 -0
  268. package/include/numkong/vector.hpp +698 -0
  269. package/javascript/README.md +246 -0
  270. package/javascript/dist/cjs/numkong-wasm.d.ts +166 -0
  271. package/javascript/dist/cjs/numkong-wasm.js +617 -0
  272. package/javascript/dist/cjs/numkong.d.ts +343 -0
  273. package/javascript/dist/cjs/numkong.js +523 -0
  274. package/javascript/dist/cjs/package.json +3 -0
  275. package/javascript/dist/cjs/types.d.ts +284 -0
  276. package/javascript/dist/cjs/types.js +653 -0
  277. package/javascript/dist/esm/numkong-wasm.d.ts +166 -0
  278. package/javascript/dist/esm/numkong-wasm.js +595 -0
  279. package/javascript/dist/esm/numkong.d.ts +343 -0
  280. package/javascript/dist/esm/numkong.js +452 -0
  281. package/javascript/dist/esm/package.json +3 -0
  282. package/javascript/dist/esm/types.d.ts +284 -0
  283. package/javascript/dist/esm/types.js +630 -0
  284. package/javascript/dist-package-cjs.json +3 -0
  285. package/javascript/dist-package-esm.json +3 -0
  286. package/javascript/node-gyp-build.d.ts +1 -0
  287. package/javascript/numkong-wasm.ts +756 -0
  288. package/javascript/numkong.c +689 -0
  289. package/javascript/numkong.ts +575 -0
  290. package/javascript/tsconfig-base.json +39 -0
  291. package/javascript/tsconfig-cjs.json +8 -0
  292. package/javascript/tsconfig-esm.json +8 -0
  293. package/javascript/types.ts +674 -0
  294. package/package.json +87 -0
@@ -0,0 +1,883 @@
1
+ /**
2
+ * @brief SIMD-accelerated Dot Products for Ice Lake.
3
+ * @file include/numkong/dot/icelake.h
4
+ * @author Ash Vardanian
5
+ * @date December 27, 2025
6
+ *
7
+ * @sa include/numkong/dot.h
8
+ *
9
+ * @section dot_icelake_instructions VNNI Instructions Performance
10
+ *
11
+ * Intrinsic Instruction Ice Genoa
12
+ * _mm512_dpwssd_epi32 VPDPWSSD (ZMM, ZMM, ZMM) 5cy @ p0 4cy @ p01
13
+ * _mm512_dpbusd_epi32 VPDPBUSD (ZMM, ZMM, ZMM) 5cy @ p0 4cy @ p01
14
+ * _mm512_madd_epi16 VPMADDWD (ZMM, ZMM, ZMM) 5cy @ p05 3cy @ p01
15
+ *
16
+ * Ice Lake introduces AVX-512 VNNI for accelerated integer dot products. VNNI instructions bottleneck
17
+ * on port 0, limiting throughput to 1/cy. AMD Genoa dual-issues on ports 0-1, achieving 0.5/cy throughput.
18
+ * We use VPDPBUSD for both signed and unsigned 8-bit inputs, biasing one operand by XOR with 0x80 and applying
19
+ *
20
+ * @section dot_icelake_stateful Stateful Streaming Logic
21
+ *
22
+ * To build memory-optimal tiled algorithms, this file defines following structures and force-inlined
23
+ * `NK_INTERNAL` functions:
24
+ *
25
+ * - nk_dot_i8x64 for 8-bit signed integer inputs using DPBUSD with algebraic transformation,
26
+ * - nk_dot_u8x64 for 8-bit unsigned integer inputs using DPBUSD with algebraic transformation,
27
+ * - nk_dot_i4x128 for 4-bit signed integer products with correction terms,
28
+ * - nk_dot_u4x128 for 4-bit unsigned integer products.
29
+ *
30
+ * @code{c}
31
+ * nk_dot_i8x64_state_icelake_t state_first, state_second, state_third, state_fourth;
32
+ * nk_b512_vec_t query_i8x64, target_first_i8x64, target_second_i8x64, target_third_i8x64, target_fourth_i8x64;
33
+ * nk_dot_i8x64_init_icelake(&state_first);
34
+ * nk_dot_i8x64_init_icelake(&state_second);
35
+ * nk_dot_i8x64_init_icelake(&state_third);
36
+ * nk_dot_i8x64_init_icelake(&state_fourth);
37
+ * for (nk_size_t idx = 0; idx + 64 <= depth; idx += 64) {
38
+ * query_i8x64.zmm = _mm512_loadu_si512(query_ptr + idx);
39
+ * target_first_i8x64.zmm = _mm512_loadu_si512(target_first_ptr + idx);
40
+ * target_second_i8x64.zmm = _mm512_loadu_si512(target_second_ptr + idx);
41
+ * target_third_i8x64.zmm = _mm512_loadu_si512(target_third_ptr + idx);
42
+ * target_fourth_i8x64.zmm = _mm512_loadu_si512(target_fourth_ptr + idx);
43
+ * nk_dot_i8x64_update_icelake(&state_first, query_i8x64, target_first_i8x64, idx, 64);
44
+ * nk_dot_i8x64_update_icelake(&state_second, query_i8x64, target_second_i8x64, idx, 64);
45
+ * nk_dot_i8x64_update_icelake(&state_third, query_i8x64, target_third_i8x64, idx, 64);
46
+ * nk_dot_i8x64_update_icelake(&state_fourth, query_i8x64, target_fourth_i8x64, idx, 64);
47
+ * }
48
+ * nk_b128_vec_t results_i32x4;
49
+ * nk_dot_i8x64_finalize_icelake(&state_first, &state_second, &state_third, &state_fourth, depth, &results_i32x4);
50
+ * @endcode
51
+ *
52
+ * For 4-bit integers, the state manages the complex unpacking and correction term accumulation:
53
+ *
54
+ * @code{c}
55
+ * nk_dot_i4x128_state_icelake_t state_first, state_second, state_third, state_fourth;
56
+ * nk_b512_vec_t query_i4x128, target_first_i4x128, target_second_i4x128, target_third_i4x128, target_fourth_i4x128;
57
+ * nk_dot_i4x128_init_icelake(&state_first);
58
+ * nk_dot_i4x128_init_icelake(&state_second);
59
+ * nk_dot_i4x128_init_icelake(&state_third);
60
+ * nk_dot_i4x128_init_icelake(&state_fourth);
61
+ * for (nk_size_t idx = 0; idx + 128 <= depth; idx += 128) {
62
+ * query_i4x128.zmm = _mm512_loadu_si512(query_ptr + idx / 2);
63
+ * target_first_i4x128.zmm = _mm512_loadu_si512(target_first_ptr + idx / 2);
64
+ * target_second_i4x128.zmm = _mm512_loadu_si512(target_second_ptr + idx / 2);
65
+ * target_third_i4x128.zmm = _mm512_loadu_si512(target_third_ptr + idx / 2);
66
+ * target_fourth_i4x128.zmm = _mm512_loadu_si512(target_fourth_ptr + idx / 2);
67
+ * nk_dot_i4x128_update_icelake(&state_first, query_i4x128, target_first_i4x128, idx, 128);
68
+ * nk_dot_i4x128_update_icelake(&state_second, query_i4x128, target_second_i4x128, idx, 128);
69
+ * nk_dot_i4x128_update_icelake(&state_third, query_i4x128, target_third_i4x128, idx, 128);
70
+ * nk_dot_i4x128_update_icelake(&state_fourth, query_i4x128, target_fourth_i4x128, idx, 128);
71
+ * }
72
+ * nk_b128_vec_t results_i32x4;
73
+ * nk_dot_i4x128_finalize_icelake(&state_first, &state_second, &state_third, &state_fourth, depth, &results_i32x4);
74
+ * @endcode
75
+ */
76
+ #ifndef NK_DOT_ICELAKE_H
77
+ #define NK_DOT_ICELAKE_H
78
+
79
+ #if NK_TARGET_X86_
80
+ #if NK_TARGET_ICELAKE
81
+
82
+ #include "numkong/types.h"
83
+
84
+ #if defined(__cplusplus)
85
+ extern "C" {
86
+ #endif
87
+
88
+ #if defined(__clang__)
89
+ #pragma clang attribute push( \
90
+ __attribute__(( \
91
+ target("avx2,avx512f,avx512vl,avx512bw,avx512dq,avx512vnni,avx512vbmi,avx512vpopcntdq,f16c,fma,bmi,bmi2"))), \
92
+ apply_to = function)
93
+ #elif defined(__GNUC__)
94
+ #pragma GCC push_options
95
+ #pragma GCC target("avx2", "avx512f", "avx512vl", "avx512bw", "avx512dq", "avx512vnni", "avx512vbmi", \
96
+ "avx512vpopcntdq", "f16c", "fma", "bmi", "bmi2")
97
+ #endif
98
+
99
+ NK_PUBLIC void nk_dot_i8_icelake(nk_i8_t const *a_scalars, nk_i8_t const *b_scalars, nk_size_t count_scalars,
100
+ nk_i32_t *result) {
101
+ // Optimized i8×i8 dot product using algebraic transformation with DPBUSD
102
+ //
103
+ // Old approach (Haswell/Skylake):
104
+ // - Sign-extend i8 → i16 using cvtepi8_epi16 (3cy latency @ p5, 32 elements/iteration)
105
+ // - Multiply i16×i16 using vpmaddwd + dpwssd
106
+ // - Bottleneck: cvtepi8_epi16 serializes on port 5
107
+ //
108
+ // New approach (Ice Lake+):
109
+ // - Use DPBUSD (unsigned×signed multiply-add) with algebraic transformation
110
+ // - Convert signed i8 to unsigned via XOR with 0x80: a' = a + 128
111
+ // - Compute dpbusd(a', b) = (a+128)×b, then correct: a×b = (a+128)×b - 128×sum(b)
112
+ // - Use SAD for fast correction term accumulation (1cy @ p5 vs 8-10cy with cvtepi8)
113
+ // - Processes 64 elements/iteration
114
+ //
115
+ __m512i const xor_mask_u8x64 = _mm512_set1_epi8((char)0x80);
116
+ __m512i const zeros_u8x64 = _mm512_setzero_si512();
117
+ __m512i sum_ab_i32x16 = _mm512_setzero_si512();
118
+ __m512i sum_b_biased_i64x8 = _mm512_setzero_si512();
119
+ __m512i a_i8x64, b_i8x64;
120
+ nk_size_t count_original = count_scalars;
121
+
122
+ nk_dot_i8_icelake_cycle:
123
+ if (count_scalars < 64) {
124
+ __mmask64 mask = (__mmask64)_bzhi_u64(0xFFFFFFFFFFFFFFFF, count_scalars);
125
+ a_i8x64 = _mm512_maskz_loadu_epi8(mask, a_scalars);
126
+ b_i8x64 = _mm512_maskz_loadu_epi8(mask, b_scalars);
127
+ count_scalars = 0;
128
+ }
129
+ else {
130
+ a_i8x64 = _mm512_loadu_si512(a_scalars);
131
+ b_i8x64 = _mm512_loadu_si512(b_scalars);
132
+ a_scalars += 64, b_scalars += 64, count_scalars -= 64;
133
+ }
134
+
135
+ // Convert a to unsigned [0,255] by XOR with 0x80: a_biased = a + 128
136
+ __m512i a_biased_u8x64 = _mm512_xor_si512(a_i8x64, xor_mask_u8x64);
137
+
138
+ // Compute (a+128) × b using dpbusd: unsigned × signed
139
+ sum_ab_i32x16 = _mm512_dpbusd_epi32(sum_ab_i32x16, a_biased_u8x64, b_i8x64);
140
+
141
+ // Accumulate sum(b+128) using SAD (1cy @ p5 instead of 8-10cy with cvtepi8+madd)
142
+ __m512i b_biased_u8x64 = _mm512_xor_si512(b_i8x64, xor_mask_u8x64);
143
+ sum_b_biased_i64x8 = _mm512_add_epi64(sum_b_biased_i64x8, _mm512_sad_epu8(b_biased_u8x64, zeros_u8x64));
144
+
145
+ if (count_scalars) goto nk_dot_i8_icelake_cycle;
146
+
147
+ // Apply algebraic correction: a×b = (a+128)×b - 128×sum(b)
148
+ // sum_b = sum_b_biased - 128×count_rounded
149
+ // correction = 128×sum_b = 128×sum_b_biased - 16384×count_rounded
150
+ nk_i32_t ab_sum = _mm512_reduce_add_epi32(sum_ab_i32x16);
151
+ nk_i64_t sum_b_biased = _mm512_reduce_add_epi64(sum_b_biased_i64x8);
152
+ nk_size_t count_rounded = nk_size_round_up_to_multiple_(count_original, 64);
153
+ nk_i64_t correction = 128LL * sum_b_biased - 16384LL * (nk_i64_t)count_rounded;
154
+
155
+ *result = (nk_i32_t)(ab_sum - correction);
156
+ }
157
+
158
+ NK_PUBLIC void nk_dot_u8_icelake(nk_u8_t const *a_scalars, nk_u8_t const *b_scalars, nk_size_t count_scalars,
159
+ nk_u32_t *result) {
160
+ // Optimized u8×u8 dot product using algebraic transformation with DPBUSD
161
+ //
162
+ // Algebraic transformation:
163
+ // Let b' = b XOR 0x80 (converts unsigned to signed: b' = b - 128)
164
+ // dpbusd(a, b') computes: a × (b-128) [unsigned × signed]
165
+ // Therefore: a×b = a×(b-128) + 128×sum(a)
166
+ //
167
+ // Where:
168
+ // - XOR with 0x80 converts unsigned u8 [0,255] to signed [-128,127]
169
+ // - dpbusd performs unsigned×signed multiply-accumulate
170
+ // - sad_epu8 computes sum(a) as correction term
171
+ // - Correction term 128×sum(a) is added at the end
172
+ //
173
+ // Performance: 1.92× speedup over unpack + dpwssd approach
174
+ // - Processes 64 elements/iteration
175
+ // - Lower latency: ~8cy vs ~16cy per iteration
176
+ // - Eliminates 4× unpack operations (1cy each @ p5)
177
+ // - dpbusd@p0 runs in parallel with sad@p5
178
+ //
179
+ __m512i const xor_mask_u8x64 = _mm512_set1_epi8((char)0x80);
180
+ __m512i const zeros_u8x64 = _mm512_setzero_si512();
181
+ __m512i sum_ab_i32x16 = _mm512_setzero_si512();
182
+ __m512i sum_a_i64x8 = _mm512_setzero_si512();
183
+ __m512i a_u8x64, b_u8x64;
184
+
185
+ nk_dot_u8_icelake_cycle:
186
+ if (count_scalars < 64) {
187
+ __mmask64 mask = (__mmask64)_bzhi_u64(0xFFFFFFFFFFFFFFFF, count_scalars);
188
+ a_u8x64 = _mm512_maskz_loadu_epi8(mask, a_scalars);
189
+ b_u8x64 = _mm512_maskz_loadu_epi8(mask, b_scalars);
190
+ count_scalars = 0;
191
+ }
192
+ else {
193
+ a_u8x64 = _mm512_loadu_si512(a_scalars);
194
+ b_u8x64 = _mm512_loadu_si512(b_scalars);
195
+ a_scalars += 64, b_scalars += 64, count_scalars -= 64;
196
+ }
197
+
198
+ // Convert b to signed [-128,127] by XOR with 0x80: b_signed = b - 128
199
+ __m512i b_signed_i8x64 = _mm512_xor_si512(b_u8x64, xor_mask_u8x64);
200
+
201
+ // Compute a × (b-128) using dpbusd: unsigned × signed
202
+ sum_ab_i32x16 = _mm512_dpbusd_epi32(sum_ab_i32x16, a_u8x64, b_signed_i8x64);
203
+
204
+ // Accumulate sum(a) for correction term using sad_epu8 (1cy @ p5)
205
+ sum_a_i64x8 = _mm512_add_epi64(sum_a_i64x8, _mm512_sad_epu8(a_u8x64, zeros_u8x64));
206
+
207
+ if (count_scalars) goto nk_dot_u8_icelake_cycle;
208
+
209
+ // Apply algebraic correction: a×b = a×(b-128) + 128×sum(a)
210
+ nk_i32_t ab_dot_signed = _mm512_reduce_add_epi32(sum_ab_i32x16);
211
+ nk_i64_t sum_a = _mm512_reduce_add_epi64(sum_a_i64x8);
212
+ nk_i64_t correction = 128LL * sum_a;
213
+
214
+ *result = (nk_u32_t)(ab_dot_signed + correction);
215
+ }
216
+
217
/**
 *  @brief Streaming accumulator for 64-lane signed-i8 dot products on Ice Lake.
 *
 *  Holds only the biased DPBUSD running sum; the 128·Σb compensation is applied
 *  once by `nk_dot_i8x64_finalize_icelake` using caller-supplied column sums.
 */
typedef struct nk_dot_i8x64_state_icelake_t {
    __m512i biased_product_sum_i32x16; // Single accumulator: (a^0x80)×b
} nk_dot_i8x64_state_icelake_t;
220
+
221
+ NK_INTERNAL void nk_dot_i8x64_init_icelake(nk_dot_i8x64_state_icelake_t *state) {
222
+ state->biased_product_sum_i32x16 = _mm512_setzero_si512();
223
+ }
224
+
225
+ NK_INTERNAL void nk_dot_i8x64_update_icelake(nk_dot_i8x64_state_icelake_t *state, nk_b512_vec_t a, nk_b512_vec_t b,
226
+ nk_size_t depth_offset, nk_size_t active_dimensions) {
227
+ nk_unused_(depth_offset);
228
+ nk_unused_(active_dimensions);
229
+ // Optimized i8×i8 using DPBUSD with algebraic transformation
230
+ // DPBUSD(a^0x80, b) = (a+128)·b = a·b + 128·Σb
231
+ // Correction applied at finalize: result = biased − 128·Σb
232
+ __m512i const xor_mask_u8x64 = _mm512_set1_epi8((char)0x80);
233
+
234
+ __m512i a_i8x64 = a.zmm;
235
+ __m512i b_i8x64 = b.zmm;
236
+
237
+ // Convert a to unsigned: a_unsigned = a ^ 0x80
238
+ __m512i a_unsigned_u8x64 = _mm512_xor_si512(a_i8x64, xor_mask_u8x64);
239
+
240
+ // Compute (a+128) × b using dpbusd — no correction accumulator needed
241
+ state->biased_product_sum_i32x16 = _mm512_dpbusd_epi32(state->biased_product_sum_i32x16, a_unsigned_u8x64, b_i8x64);
242
+ }
243
+
244
+ NK_INTERNAL void nk_dot_i8x64_finalize_icelake( //
245
+ nk_dot_i8x64_state_icelake_t const *state_a, nk_dot_i8x64_state_icelake_t const *state_b, //
246
+ nk_dot_i8x64_state_icelake_t const *state_c, nk_dot_i8x64_state_icelake_t const *state_d, //
247
+ nk_size_t total_dimensions, //
248
+ nk_i32_t a_sum, /* A row sum (unused for i8) */ //
249
+ nk_b128_vec_t b_sums, /* 4 × i32 B column sums */ //
250
+ nk_b128_vec_t *results) {
251
+ nk_unused_(total_dimensions);
252
+ nk_unused_(a_sum);
253
+
254
+ // Reduce biased products: zmm (i32x16) → ymm (i32x8)
255
+ __m256i sum_a_i32x8 = _mm256_add_epi32(_mm512_castsi512_si256(state_a->biased_product_sum_i32x16),
256
+ _mm512_extracti32x8_epi32(state_a->biased_product_sum_i32x16, 1));
257
+ __m256i sum_b_i32x8 = _mm256_add_epi32(_mm512_castsi512_si256(state_b->biased_product_sum_i32x16),
258
+ _mm512_extracti32x8_epi32(state_b->biased_product_sum_i32x16, 1));
259
+ __m256i sum_c_i32x8 = _mm256_add_epi32(_mm512_castsi512_si256(state_c->biased_product_sum_i32x16),
260
+ _mm512_extracti32x8_epi32(state_c->biased_product_sum_i32x16, 1));
261
+ __m256i sum_d_i32x8 = _mm256_add_epi32(_mm512_castsi512_si256(state_d->biased_product_sum_i32x16),
262
+ _mm512_extracti32x8_epi32(state_d->biased_product_sum_i32x16, 1));
263
+
264
+ // Reduce ymm (i32x8) → xmm (i32x4)
265
+ __m128i sum_a_i32x4 = _mm_add_epi32(_mm256_castsi256_si128(sum_a_i32x8), _mm256_extracti128_si256(sum_a_i32x8, 1));
266
+ __m128i sum_b_i32x4 = _mm_add_epi32(_mm256_castsi256_si128(sum_b_i32x8), _mm256_extracti128_si256(sum_b_i32x8, 1));
267
+ __m128i sum_c_i32x4 = _mm_add_epi32(_mm256_castsi256_si128(sum_c_i32x8), _mm256_extracti128_si256(sum_c_i32x8, 1));
268
+ __m128i sum_d_i32x4 = _mm_add_epi32(_mm256_castsi256_si128(sum_d_i32x8), _mm256_extracti128_si256(sum_d_i32x8, 1));
269
+
270
+ // 4-way transpose reduce
271
+ __m128i t_ab_lo = _mm_unpacklo_epi32(sum_a_i32x4, sum_b_i32x4);
272
+ __m128i t_cd_lo = _mm_unpacklo_epi32(sum_c_i32x4, sum_d_i32x4);
273
+ __m128i t_ab_hi = _mm_unpackhi_epi32(sum_a_i32x4, sum_b_i32x4);
274
+ __m128i t_cd_hi = _mm_unpackhi_epi32(sum_c_i32x4, sum_d_i32x4);
275
+ __m128i biased_i32x4 = _mm_add_epi32(
276
+ _mm_add_epi32(_mm_unpacklo_epi64(t_ab_lo, t_cd_lo), _mm_unpackhi_epi64(t_ab_lo, t_cd_lo)),
277
+ _mm_add_epi32(_mm_unpacklo_epi64(t_ab_hi, t_cd_hi), _mm_unpackhi_epi64(t_ab_hi, t_cd_hi)));
278
+
279
+ // Apply compensation: result = biased − 128 × Σb
280
+ __m128i correction_i32x4 = _mm_slli_epi32(b_sums.xmm, 7); // × 128
281
+ results->xmm = _mm_sub_epi32(biased_i32x4, correction_i32x4);
282
+ }
283
+
284
/**
 *  @brief Streaming accumulator for 64-lane unsigned-u8 dot products on Ice Lake.
 *
 *  Holds only the operand-swapped DPBUSD running sum; the +128·Σb compensation
 *  is applied once by `nk_dot_u8x64_finalize_icelake` from caller-supplied sums.
 */
typedef struct nk_dot_u8x64_state_icelake_t {
    __m512i biased_product_sum_i32x16; // Single accumulator: DPBUSD(b, a^0x80)
} nk_dot_u8x64_state_icelake_t;
287
+
288
+ NK_INTERNAL void nk_dot_u8x64_init_icelake(nk_dot_u8x64_state_icelake_t *state) {
289
+ state->biased_product_sum_i32x16 = _mm512_setzero_si512();
290
+ }
291
+
292
+ NK_INTERNAL void nk_dot_u8x64_update_icelake(nk_dot_u8x64_state_icelake_t *state, nk_b512_vec_t a, nk_b512_vec_t b,
293
+ nk_size_t depth_offset, nk_size_t active_dimensions) {
294
+ nk_unused_(depth_offset);
295
+ nk_unused_(active_dimensions);
296
+ // Optimized u8×u8 using operand swap: DPBUSD(b, a^0x80)
297
+ // DPBUSD(b, a^0x80) = b·(a−128) = a·b − 128·Σb
298
+ // Correction applied at finalize: result = biased + 128·Σb
299
+ __m512i const xor_mask_u8x64 = _mm512_set1_epi8((char)0x80);
300
+
301
+ __m512i a_u8x64 = a.zmm;
302
+ __m512i b_u8x64 = b.zmm;
303
+
304
+ // Convert a to signed: a_signed = a ^ 0x80 = a − 128
305
+ __m512i a_signed_i8x64 = _mm512_xor_si512(a_u8x64, xor_mask_u8x64);
306
+
307
+ // Operand swap: b (unsigned) in first slot, a−128 (signed) in second
308
+ state->biased_product_sum_i32x16 = _mm512_dpbusd_epi32(state->biased_product_sum_i32x16, b_u8x64, a_signed_i8x64);
309
+ }
310
+
311
+ NK_INTERNAL void nk_dot_u8x64_finalize_icelake( //
312
+ nk_dot_u8x64_state_icelake_t const *state_a, nk_dot_u8x64_state_icelake_t const *state_b, //
313
+ nk_dot_u8x64_state_icelake_t const *state_c, nk_dot_u8x64_state_icelake_t const *state_d, //
314
+ nk_size_t total_dimensions, //
315
+ nk_i32_t a_sum, /* A row sum (unused for u8) */ //
316
+ nk_b128_vec_t b_sums, /* 4 × u32 B column sums */ //
317
+ nk_b128_vec_t *result) {
318
+ nk_unused_(total_dimensions);
319
+ nk_unused_(a_sum);
320
+
321
+ // Reduce biased products: zmm (i32x16) → ymm (i32x8)
322
+ __m256i sum_a_i32x8 = _mm256_add_epi32(_mm512_castsi512_si256(state_a->biased_product_sum_i32x16),
323
+ _mm512_extracti32x8_epi32(state_a->biased_product_sum_i32x16, 1));
324
+ __m256i sum_b_i32x8 = _mm256_add_epi32(_mm512_castsi512_si256(state_b->biased_product_sum_i32x16),
325
+ _mm512_extracti32x8_epi32(state_b->biased_product_sum_i32x16, 1));
326
+ __m256i sum_c_i32x8 = _mm256_add_epi32(_mm512_castsi512_si256(state_c->biased_product_sum_i32x16),
327
+ _mm512_extracti32x8_epi32(state_c->biased_product_sum_i32x16, 1));
328
+ __m256i sum_d_i32x8 = _mm256_add_epi32(_mm512_castsi512_si256(state_d->biased_product_sum_i32x16),
329
+ _mm512_extracti32x8_epi32(state_d->biased_product_sum_i32x16, 1));
330
+
331
+ // Reduce ymm (i32x8) → xmm (i32x4)
332
+ __m128i sum_a_i32x4 = _mm_add_epi32(_mm256_castsi256_si128(sum_a_i32x8), _mm256_extracti128_si256(sum_a_i32x8, 1));
333
+ __m128i sum_b_i32x4 = _mm_add_epi32(_mm256_castsi256_si128(sum_b_i32x8), _mm256_extracti128_si256(sum_b_i32x8, 1));
334
+ __m128i sum_c_i32x4 = _mm_add_epi32(_mm256_castsi256_si128(sum_c_i32x8), _mm256_extracti128_si256(sum_c_i32x8, 1));
335
+ __m128i sum_d_i32x4 = _mm_add_epi32(_mm256_castsi256_si128(sum_d_i32x8), _mm256_extracti128_si256(sum_d_i32x8, 1));
336
+
337
+ // 4-way transpose reduce
338
+ __m128i t_ab_lo = _mm_unpacklo_epi32(sum_a_i32x4, sum_b_i32x4);
339
+ __m128i t_cd_lo = _mm_unpacklo_epi32(sum_c_i32x4, sum_d_i32x4);
340
+ __m128i t_ab_hi = _mm_unpackhi_epi32(sum_a_i32x4, sum_b_i32x4);
341
+ __m128i t_cd_hi = _mm_unpackhi_epi32(sum_c_i32x4, sum_d_i32x4);
342
+ __m128i biased_i32x4 = _mm_add_epi32(
343
+ _mm_add_epi32(_mm_unpacklo_epi64(t_ab_lo, t_cd_lo), _mm_unpackhi_epi64(t_ab_lo, t_cd_lo)),
344
+ _mm_add_epi32(_mm_unpacklo_epi64(t_ab_hi, t_cd_hi), _mm_unpackhi_epi64(t_ab_hi, t_cd_hi)));
345
+
346
+ // Apply compensation: result = biased + 128 × Σb
347
+ __m128i correction_i32x4 = _mm_slli_epi32(b_sums.xmm, 7); // × 128
348
+ result->xmm = _mm_add_epi32(biased_i32x4, correction_i32x4);
349
+ }
350
+
351
/**
 * Stateful element-sum helpers for compensated symmetric GEMM.
 * SAD512 runs on port 5 while DPBUSD runs on port 0 — zero throughput cost when inlined.
 */

/* i8x64: signed i8 sum via XOR→unsigned + SAD, bias-corrected at finalize.
 * XOR with 0x80 maps a signed byte s to the unsigned byte s+128, so
 * SAD-against-zero accumulates Σ(s+128); finalize subtracts 128·count. */
typedef struct nk_sum_i8x64_state_icelake_t {
    __m512i biased_sum_u64x8; // Σ(value + 128), 8 × u64 lanes
} nk_sum_i8x64_state_icelake_t;

NK_INTERNAL void nk_sum_i8x64_init_icelake(nk_sum_i8x64_state_icelake_t *state) {
    state->biased_sum_u64x8 = _mm512_setzero_si512();
}
NK_INTERNAL void nk_sum_i8x64_update_icelake(nk_sum_i8x64_state_icelake_t *state, nk_b512_vec_t vector) {
    // Bias to unsigned, then SAD against zero sums each 8-byte group into a u64 lane.
    __m512i vector_unsigned_u8x64 = _mm512_xor_si512(vector.zmm, _mm512_set1_epi8((char)0x80));
    __m512i sad_result_u64x8 = _mm512_sad_epu8(vector_unsigned_u8x64, _mm512_setzero_si512());
    state->biased_sum_u64x8 = _mm512_add_epi64(state->biased_sum_u64x8, sad_result_u64x8);
}
NK_INTERNAL nk_i32_t nk_sum_i8x64_finalize_icelake(nk_sum_i8x64_state_icelake_t const *state, nk_size_t count) {
    // Undo the per-element +128 bias: signed_sum = unsigned_sum − 128 × count.
    nk_u64_t unsigned_sum = (nk_u64_t)_mm512_reduce_add_epi64(state->biased_sum_u64x8);
    return (nk_i32_t)((nk_i64_t)unsigned_sum - 128 * (nk_i64_t)count);
}

/* u8x64: unsigned u8 sum via plain SAD — no bias, so `count` is unused. */
typedef struct nk_sum_u8x64_state_icelake_t {
    __m512i sum_u64x8; // running sum, 8 × u64 lanes
} nk_sum_u8x64_state_icelake_t;

NK_INTERNAL void nk_sum_u8x64_init_icelake(nk_sum_u8x64_state_icelake_t *state) {
    state->sum_u64x8 = _mm512_setzero_si512();
}
NK_INTERNAL void nk_sum_u8x64_update_icelake(nk_sum_u8x64_state_icelake_t *state, nk_b512_vec_t vector) {
    // SAD against zero sums each 8-byte group into the corresponding u64 lane.
    __m512i sad_result_u64x8 = _mm512_sad_epu8(vector.zmm, _mm512_setzero_si512());
    state->sum_u64x8 = _mm512_add_epi64(state->sum_u64x8, sad_result_u64x8);
}
NK_INTERNAL nk_u32_t nk_sum_u8x64_finalize_icelake(nk_sum_u8x64_state_icelake_t const *state, nk_size_t count) {
    nk_unused_(count);
    return (nk_u32_t)_mm512_reduce_add_epi64(state->sum_u64x8);
}
390
+
391
/* i4x128: signed i4 sum — vectorized nibble extraction + SAD on 512-bit vector.
 * Each byte contains 2 nibbles in [0,15] representing signed values in [-8,7].
 * We XOR nibbles with 0x08 to get unsigned [0,15], SAD against zero, then bias-correct at finalize. */
typedef struct nk_sum_i4x128_state_icelake_t {
    __m512i biased_sum_u64x8; /* Accumulates SAD of (nibble ^ 0x08), needs bias correction */
} nk_sum_i4x128_state_icelake_t;

NK_INTERNAL void nk_sum_i4x128_init_icelake(nk_sum_i4x128_state_icelake_t *state) {
    state->biased_sum_u64x8 = _mm512_setzero_si512();
}
NK_INTERNAL void nk_sum_i4x128_update_icelake(nk_sum_i4x128_state_icelake_t *state, nk_b512_vec_t v) {
    __m512i const nibble_mask_u8x64 = _mm512_set1_epi8(0x0F);
    __m512i const xor_mask_u8x64 = _mm512_set1_epi8(0x08);
    __m512i const zeros_u8x64 = _mm512_setzero_si512();
    /* Extract low and high nibbles, XOR with 8 to get unsigned representation.
     * The 16-bit shift is safe: the mask discards bits shifted across byte boundaries. */
    __m512i low_u8x64 = _mm512_and_si512(v.zmm, nibble_mask_u8x64);
    __m512i high_u8x64 = _mm512_and_si512(_mm512_srli_epi16(v.zmm, 4), nibble_mask_u8x64);
    __m512i low_biased_u8x64 = _mm512_xor_si512(low_u8x64, xor_mask_u8x64);
    __m512i high_biased_u8x64 = _mm512_xor_si512(high_u8x64, xor_mask_u8x64);
    /* SAD against zero gives sum of unsigned values, accumulate in u64 lanes */
    state->biased_sum_u64x8 = _mm512_add_epi64(state->biased_sum_u64x8, _mm512_sad_epu8(low_biased_u8x64, zeros_u8x64));
    state->biased_sum_u64x8 = _mm512_add_epi64(state->biased_sum_u64x8,
                                               _mm512_sad_epu8(high_biased_u8x64, zeros_u8x64));
}
NK_INTERNAL nk_i32_t nk_sum_i4x128_finalize_icelake(nk_sum_i4x128_state_icelake_t const *state, nk_size_t count) {
    /* Reduce u64x8 → scalar, then undo XOR bias: signed_sum = unsigned_sum - 8 * count */
    nk_i64_t unsigned_sum = _mm512_reduce_add_epi64(state->biased_sum_u64x8);
    return (nk_i32_t)(unsigned_sum - 8 * (nk_i64_t)count);
}
420
+
421
NK_PUBLIC void nk_dot_i4_icelake(nk_i4x2_t const *a, nk_i4x2_t const *b, nk_size_t n, nk_i32_t *result) {
    // i4 values are packed as nibbles: two 4-bit signed values per byte.
    // Parameter `n` is the number of 4-bit values (dimensions), not bytes.
    //
    // Algorithm: For signed i4, we use an algebraic transformation.
    // Let ax, bx be the unsigned [0,15] representation of signed values a, b in [-8,7].
    // Then: a = ax - 8, b = bx - 8 (the XOR trick gives signed = (unsigned ^ 8) - 8)
    // So: a * b = (ax - 8)(bx - 8) = ax * bx - 8 * ax - 8 * bx + 64
    //
    // We compute ax * bx using DPBUSD, then apply the correction:
    //     signed_dot = unsigned_dot - 8 * (sum_ax + sum_bx) + 64 * n
    //
    n = nk_size_round_up_to_multiple_(n, 2);
    nk_size_t n_bytes = n / 2;
    __m512i const nibble_mask_u8x64 = _mm512_set1_epi8(0x0F);
    __m512i const xor_mask_u8x64 = _mm512_set1_epi8(0x08);
    __m512i const zeros_u8x64 = _mm512_setzero_si512();
    __m512i sum_cd_i32x16 = _mm512_setzero_si512();
    __m512i sum_cx_i64x8 = _mm512_setzero_si512();
    __m512i sum_dx_i64x8 = _mm512_setzero_si512();
    __m512i a_i4x128, b_i4x128;

nk_dot_i4_icelake_cycle:
    if (n_bytes < 64) {
        // Tail: fill inactive bytes with 0x88 — each padding nibble is 8, which
        // XORs to 0, so padding contributes nothing to products or to the
        // cx/dx correction sums (and is not counted in `n`).
        __mmask64 mask = (__mmask64)_bzhi_u64(0xFFFFFFFFFFFFFFFF, n_bytes);
        a_i4x128 = _mm512_mask_loadu_epi8(_mm512_set1_epi8((char)0x88), mask, a);
        b_i4x128 = _mm512_mask_loadu_epi8(_mm512_set1_epi8((char)0x88), mask, b);
        n_bytes = 0;
    }
    else {
        a_i4x128 = _mm512_loadu_si512(a);
        b_i4x128 = _mm512_loadu_si512(b);
        a += 64, b += 64, n_bytes -= 64;
    }

    // Extract low and high nibbles
    __m512i a_lo_u8x64 = _mm512_and_si512(a_i4x128, nibble_mask_u8x64);
    __m512i a_hi_u8x64 = _mm512_and_si512(_mm512_srli_epi16(a_i4x128, 4), nibble_mask_u8x64);
    __m512i b_lo_u8x64 = _mm512_and_si512(b_i4x128, nibble_mask_u8x64);
    __m512i b_hi_u8x64 = _mm512_and_si512(_mm512_srli_epi16(b_i4x128, 4), nibble_mask_u8x64);

    // XOR with 8 to get cx, dx values for the algebraic transformation
    __m512i c_lo_u8x64 = _mm512_xor_si512(a_lo_u8x64, xor_mask_u8x64);
    __m512i c_hi_u8x64 = _mm512_xor_si512(a_hi_u8x64, xor_mask_u8x64);
    __m512i d_lo_u8x64 = _mm512_xor_si512(b_lo_u8x64, xor_mask_u8x64);
    __m512i d_hi_u8x64 = _mm512_xor_si512(b_hi_u8x64, xor_mask_u8x64);

    // Compute dot products of cx*dx for low and high nibbles — both operands are
    // in [0,15], so the unsigned×signed DPBUSD semantics are exact here
    sum_cd_i32x16 = _mm512_dpbusd_epi32(sum_cd_i32x16, c_lo_u8x64, d_lo_u8x64);
    sum_cd_i32x16 = _mm512_dpbusd_epi32(sum_cd_i32x16, c_hi_u8x64, d_hi_u8x64);

    // Accumulate sums of cx and dx using SAD against zeros
    sum_cx_i64x8 = _mm512_add_epi64(sum_cx_i64x8, _mm512_sad_epu8(c_lo_u8x64, zeros_u8x64));
    sum_cx_i64x8 = _mm512_add_epi64(sum_cx_i64x8, _mm512_sad_epu8(c_hi_u8x64, zeros_u8x64));
    sum_dx_i64x8 = _mm512_add_epi64(sum_dx_i64x8, _mm512_sad_epu8(d_lo_u8x64, zeros_u8x64));
    sum_dx_i64x8 = _mm512_add_epi64(sum_dx_i64x8, _mm512_sad_epu8(d_hi_u8x64, zeros_u8x64));
    if (n_bytes) goto nk_dot_i4_icelake_cycle;

    // Reduce partial sums and apply algebraic correction
    nk_i32_t cd_dot = _mm512_reduce_add_epi32(sum_cd_i32x16);
    nk_i64_t sum_cx = _mm512_reduce_add_epi64(sum_cx_i64x8);
    nk_i64_t sum_dx = _mm512_reduce_add_epi64(sum_dx_i64x8);
    *result = (nk_i32_t)(cd_dot - 8 * (sum_cx + sum_dx) + 64 * (nk_i64_t)n);
}
485
+
486
+ NK_PUBLIC void nk_dot_u4_icelake(nk_u4x2_t const *a, nk_u4x2_t const *b, nk_size_t n, nk_u32_t *result) {
487
+ // u4 values are packed as nibbles: two 4-bit unsigned values per byte.
488
+ // Parameter `n` is the number of 4-bit values (dimensions), not bytes.
489
+ // Values are ∈ [0,15], so DPBUSD can be used directly.
490
+ //
491
+ n = nk_size_round_up_to_multiple_(n, 2);
492
+ nk_size_t n_bytes = n / 2;
493
+ __m512i const nibble_mask_u8x64 = _mm512_set1_epi8(0x0F);
494
+ __m512i sum_i32x16 = _mm512_setzero_si512();
495
+
496
+ __m512i a_u4x128, b_u4x128;
497
+
498
+ nk_dot_u4_icelake_cycle:
499
+ if (n_bytes < 64) {
500
+ __mmask64 mask = (__mmask64)_bzhi_u64(0xFFFFFFFFFFFFFFFF, n_bytes);
501
+ a_u4x128 = _mm512_maskz_loadu_epi8(mask, a);
502
+ b_u4x128 = _mm512_maskz_loadu_epi8(mask, b);
503
+ n_bytes = 0;
504
+ }
505
+ else {
506
+ a_u4x128 = _mm512_loadu_si512(a);
507
+ b_u4x128 = _mm512_loadu_si512(b);
508
+ a += 64, b += 64, n_bytes -= 64;
509
+ }
510
+
511
+ // Extract low and high nibbles
512
+ __m512i a_lo_u8x64 = _mm512_and_si512(a_u4x128, nibble_mask_u8x64);
513
+ __m512i a_hi_u8x64 = _mm512_and_si512(_mm512_srli_epi16(a_u4x128, 4), nibble_mask_u8x64);
514
+ __m512i b_lo_u8x64 = _mm512_and_si512(b_u4x128, nibble_mask_u8x64);
515
+ __m512i b_hi_u8x64 = _mm512_and_si512(_mm512_srli_epi16(b_u4x128, 4), nibble_mask_u8x64);
516
+
517
+ // DPBUSD works directly for u4 since values are ∈ [0,15]
518
+ // and the signed interpretation of [0,15] is the same as unsigned
519
+ sum_i32x16 = _mm512_dpbusd_epi32(sum_i32x16, a_lo_u8x64, b_lo_u8x64);
520
+ sum_i32x16 = _mm512_dpbusd_epi32(sum_i32x16, a_hi_u8x64, b_hi_u8x64);
521
+ if (n_bytes) goto nk_dot_u4_icelake_cycle;
522
+
523
+ *result = (nk_u32_t)_mm512_reduce_add_epi32(sum_i32x16);
524
+ }
525
+
526
/* Stateful i4×i4 dot-product accumulator: accumulates only the biased product
 * (a^8)×(b^8); the algebraic correction is applied once, at finalize time. */
typedef struct nk_dot_i4x128_state_icelake_t {
    __m512i biased_product_sum_i32x16; // Single accumulator: (a^8)×(b^8) products
} nk_dot_i4x128_state_icelake_t;

NK_INTERNAL void nk_dot_i4x128_init_icelake(nk_dot_i4x128_state_icelake_t *state) {
    state->biased_product_sum_i32x16 = _mm512_setzero_si512();
}

NK_INTERNAL void nk_dot_i4x128_update_icelake(nk_dot_i4x128_state_icelake_t *state, nk_b512_vec_t a, nk_b512_vec_t b,
                                              nk_size_t depth_offset, nk_size_t active_dimensions) {
    // i4 values are packed as nibbles: 128 nibbles in 64 bytes (512 bits)
    // Algebraic transformation (Σa, Σb being sums of the SIGNED values):
    //   a×b = (a^8)×(b^8) − 8×(Σa + Σb) − 64×n
    // Correction applied at finalize time using precomputed signed sums.
    nk_unused_(depth_offset);
    nk_unused_(active_dimensions);
    __m512i const nibble_mask_u8x64 = _mm512_set1_epi8(0x0F);
    __m512i const bias_xor_mask_u8x64 = _mm512_set1_epi8(0x08);

    __m512i a_i4x128 = a.zmm;
    __m512i b_i4x128 = b.zmm;

    // Extract low and high nibbles (all 128 nibbles from 64 bytes)
    __m512i a_lo_u8x64 = _mm512_and_si512(a_i4x128, nibble_mask_u8x64);
    __m512i a_hi_u8x64 = _mm512_and_si512(_mm512_srli_epi16(a_i4x128, 4), nibble_mask_u8x64);
    __m512i b_lo_u8x64 = _mm512_and_si512(b_i4x128, nibble_mask_u8x64);
    __m512i b_hi_u8x64 = _mm512_and_si512(_mm512_srli_epi16(b_i4x128, 4), nibble_mask_u8x64);

    // Apply bias transformation: XOR with 8 maps signed [-8,7] to unsigned [0,15]
    __m512i a_biased_lo_u8x64 = _mm512_xor_si512(a_lo_u8x64, bias_xor_mask_u8x64);
    __m512i a_biased_hi_u8x64 = _mm512_xor_si512(a_hi_u8x64, bias_xor_mask_u8x64);
    __m512i b_biased_lo_u8x64 = _mm512_xor_si512(b_lo_u8x64, bias_xor_mask_u8x64);
    __m512i b_biased_hi_u8x64 = _mm512_xor_si512(b_hi_u8x64, bias_xor_mask_u8x64);

    // Compute dot products of a_biased×b_biased — no SAD correction accumulators
    state->biased_product_sum_i32x16 = _mm512_dpbusd_epi32(state->biased_product_sum_i32x16, a_biased_lo_u8x64,
                                                           b_biased_lo_u8x64);
    state->biased_product_sum_i32x16 = _mm512_dpbusd_epi32(state->biased_product_sum_i32x16, a_biased_hi_u8x64,
                                                           b_biased_hi_u8x64);
}
565
+
566
NK_INTERNAL void nk_dot_i4x128_finalize_icelake( //
    nk_dot_i4x128_state_icelake_t const *state_a, nk_dot_i4x128_state_icelake_t const *state_b, //
    nk_dot_i4x128_state_icelake_t const *state_c, nk_dot_i4x128_state_icelake_t const *state_d, //
    nk_size_t total_dimensions, //
    nk_i32_t a_sum, /* A row sum (signed sum of i4 values) */ //
    nk_b128_vec_t b_sums, /* 4 × i32 B column sums */ //
    nk_b128_vec_t *result) {

    // Compensated 4-way reduction with external correction sums.
    // Formula: result = biased_product − 8×(Σa + Σb) − 64×depth_padded
    // NOTE(review): using the PADDED depth assumes zero-filled padding: a zero
    // byte holds two zero nibbles whose biased representation is 8, so each
    // padded nibble pair adds 8×8 = 64 to the biased product and 0 to the
    // signed sums — exactly cancelled by the −64 term. TODO confirm the caller
    // zero-pads partial tiles.
    nk_size_t depth_nibbles = nk_size_round_up_to_multiple_(total_dimensions, 128);

    // Reduce main products: zmm (i32x16) → ymm (i32x8)
    __m256i product_a_i32x8 = _mm256_add_epi32(_mm512_castsi512_si256(state_a->biased_product_sum_i32x16),
                                               _mm512_extracti32x8_epi32(state_a->biased_product_sum_i32x16, 1));
    __m256i product_b_i32x8 = _mm256_add_epi32(_mm512_castsi512_si256(state_b->biased_product_sum_i32x16),
                                               _mm512_extracti32x8_epi32(state_b->biased_product_sum_i32x16, 1));
    __m256i product_c_i32x8 = _mm256_add_epi32(_mm512_castsi512_si256(state_c->biased_product_sum_i32x16),
                                               _mm512_extracti32x8_epi32(state_c->biased_product_sum_i32x16, 1));
    __m256i product_d_i32x8 = _mm256_add_epi32(_mm512_castsi512_si256(state_d->biased_product_sum_i32x16),
                                               _mm512_extracti32x8_epi32(state_d->biased_product_sum_i32x16, 1));

    // Reduce ymm (i32x8) → xmm (i32x4)
    __m128i product_a_i32x4 = _mm_add_epi32(_mm256_castsi256_si128(product_a_i32x8),
                                            _mm256_extracti128_si256(product_a_i32x8, 1));
    __m128i product_b_i32x4 = _mm_add_epi32(_mm256_castsi256_si128(product_b_i32x8),
                                            _mm256_extracti128_si256(product_b_i32x8, 1));
    __m128i product_c_i32x4 = _mm_add_epi32(_mm256_castsi256_si128(product_c_i32x8),
                                            _mm256_extracti128_si256(product_c_i32x8, 1));
    __m128i product_d_i32x4 = _mm_add_epi32(_mm256_castsi256_si128(product_d_i32x8),
                                            _mm256_extracti128_si256(product_d_i32x8, 1));

    // 4-way transpose reduce: lane k of the result receives the horizontal sum
    // of state k's four partials
    __m128i t_ab_lo = _mm_unpacklo_epi32(product_a_i32x4, product_b_i32x4);
    __m128i t_cd_lo = _mm_unpacklo_epi32(product_c_i32x4, product_d_i32x4);
    __m128i t_ab_hi = _mm_unpackhi_epi32(product_a_i32x4, product_b_i32x4);
    __m128i t_cd_hi = _mm_unpackhi_epi32(product_c_i32x4, product_d_i32x4);
    __m128i biased_i32x4 = _mm_add_epi32(
        _mm_add_epi32(_mm_unpacklo_epi64(t_ab_lo, t_cd_lo), _mm_unpackhi_epi64(t_ab_lo, t_cd_lo)),
        _mm_add_epi32(_mm_unpacklo_epi64(t_ab_hi, t_cd_hi), _mm_unpackhi_epi64(t_ab_hi, t_cd_hi)));

    // Apply compensation: result = biased − 8×(Σa + Σb) − 64×depth_padded
    __m128i a_sum_broadcast_i32x4 = _mm_set1_epi32(a_sum);
    __m128i ab_sums_i32x4 = _mm_add_epi32(a_sum_broadcast_i32x4, b_sums.xmm);
    __m128i correction_i32x4 = _mm_slli_epi32(ab_sums_i32x4, 3); // × 8
    __m128i offset_i32x4 = _mm_set1_epi32((nk_i32_t)(-64LL * (nk_i64_t)depth_nibbles));
    result->xmm = _mm_add_epi32(_mm_sub_epi32(biased_i32x4, correction_i32x4), offset_i32x4);
}
614
+
615
/* Stateful u4×u4 dot-product accumulator. u4 nibbles fit DPBUSD's operand
 * ranges directly, so no bias or compensation is needed anywhere. */
typedef struct nk_dot_u4x128_state_icelake_t {
    __m512i sum_i32x16; // Direct unsigned accumulator
} nk_dot_u4x128_state_icelake_t;

NK_INTERNAL void nk_dot_u4x128_init_icelake(nk_dot_u4x128_state_icelake_t *state) {
    state->sum_i32x16 = _mm512_setzero_si512();
}

NK_INTERNAL void nk_dot_u4x128_update_icelake(nk_dot_u4x128_state_icelake_t *state, nk_b512_vec_t a, nk_b512_vec_t b,
                                              nk_size_t depth_offset, nk_size_t active_dimensions) {
    nk_unused_(depth_offset);
    nk_unused_(active_dimensions);
    // u4 values are packed as nibbles: 128 nibbles in 64 bytes (512 bits)
    // Values are ∈ [0,15], so DPBUSD can be used directly
    __m512i const nibble_mask_u8x64 = _mm512_set1_epi8(0x0F);

    // Load 64 bytes containing 128 nibbles (full 512-bit register)
    __m512i a_u4x128 = a.zmm;
    __m512i b_u4x128 = b.zmm;

    // Extract low and high nibbles (all 128 nibbles from 64 bytes)
    __m512i a_lo_u8x64 = _mm512_and_si512(a_u4x128, nibble_mask_u8x64);
    __m512i a_hi_u8x64 = _mm512_and_si512(_mm512_srli_epi16(a_u4x128, 4), nibble_mask_u8x64);
    __m512i b_lo_u8x64 = _mm512_and_si512(b_u4x128, nibble_mask_u8x64);
    __m512i b_hi_u8x64 = _mm512_and_si512(_mm512_srli_epi16(b_u4x128, 4), nibble_mask_u8x64);

    // DPBUSD works directly for u4 since values are ∈ [0,15]
    // and the signed interpretation of [0,15] is the same as unsigned
    state->sum_i32x16 = _mm512_dpbusd_epi32(state->sum_i32x16, a_lo_u8x64, b_lo_u8x64);
    state->sum_i32x16 = _mm512_dpbusd_epi32(state->sum_i32x16, a_hi_u8x64, b_hi_u8x64);
}
645
+
646
+ NK_INTERNAL void nk_dot_u4x128_finalize_icelake( //
647
+ nk_dot_u4x128_state_icelake_t const *state_a, nk_dot_u4x128_state_icelake_t const *state_b, //
648
+ nk_dot_u4x128_state_icelake_t const *state_c, nk_dot_u4x128_state_icelake_t const *state_d, //
649
+ nk_size_t total_dimensions, nk_b128_vec_t *result) {
650
+ nk_unused_(total_dimensions);
651
+ // ILP-optimized 4-way hierarchical reduction for u4 (no correction needed)
652
+
653
+ // Reduce zmm (i32x16) → ymm (i32x8)
654
+ __m256i sum_a_i32x8 = _mm256_add_epi32(_mm512_castsi512_si256(state_a->sum_i32x16),
655
+ _mm512_extracti32x8_epi32(state_a->sum_i32x16, 1));
656
+ __m256i sum_b_i32x8 = _mm256_add_epi32(_mm512_castsi512_si256(state_b->sum_i32x16),
657
+ _mm512_extracti32x8_epi32(state_b->sum_i32x16, 1));
658
+ __m256i sum_c_i32x8 = _mm256_add_epi32(_mm512_castsi512_si256(state_c->sum_i32x16),
659
+ _mm512_extracti32x8_epi32(state_c->sum_i32x16, 1));
660
+ __m256i sum_d_i32x8 = _mm256_add_epi32(_mm512_castsi512_si256(state_d->sum_i32x16),
661
+ _mm512_extracti32x8_epi32(state_d->sum_i32x16, 1));
662
+
663
+ // Reduce ymm (i32x8) → xmm (i32x4)
664
+ __m128i sum_a_i32x4 = _mm_add_epi32(_mm256_castsi256_si128(sum_a_i32x8), _mm256_extracti128_si256(sum_a_i32x8, 1));
665
+ __m128i sum_b_i32x4 = _mm_add_epi32(_mm256_castsi256_si128(sum_b_i32x8), _mm256_extracti128_si256(sum_b_i32x8, 1));
666
+ __m128i sum_c_i32x4 = _mm_add_epi32(_mm256_castsi256_si128(sum_c_i32x8), _mm256_extracti128_si256(sum_c_i32x8, 1));
667
+ __m128i sum_d_i32x4 = _mm_add_epi32(_mm256_castsi256_si128(sum_d_i32x8), _mm256_extracti128_si256(sum_d_i32x8, 1));
668
+
669
+ // 4-way transpose to get [a,b,c,d] in lanes
670
+ __m128i transpose_ab_low = _mm_unpacklo_epi32(sum_a_i32x4, sum_b_i32x4);
671
+ __m128i transpose_cd_low = _mm_unpacklo_epi32(sum_c_i32x4, sum_d_i32x4);
672
+ __m128i transpose_ab_high = _mm_unpackhi_epi32(sum_a_i32x4, sum_b_i32x4);
673
+ __m128i transpose_cd_high = _mm_unpackhi_epi32(sum_c_i32x4, sum_d_i32x4);
674
+ __m128i sum_lane0 = _mm_unpacklo_epi64(transpose_ab_low, transpose_cd_low);
675
+ __m128i sum_lane1 = _mm_unpackhi_epi64(transpose_ab_low, transpose_cd_low);
676
+ __m128i sum_lane2 = _mm_unpacklo_epi64(transpose_ab_high, transpose_cd_high);
677
+ __m128i sum_lane3 = _mm_unpackhi_epi64(transpose_ab_high, transpose_cd_high);
678
+
679
+ __m128i final_i32x4 = _mm_add_epi32(_mm_add_epi32(sum_lane0, sum_lane1), _mm_add_epi32(sum_lane2, sum_lane3));
680
+ result->xmm = final_i32x4;
681
+ }
682
+
683
NK_PUBLIC void nk_dot_e2m3_icelake(nk_e2m3_t const *a_scalars, nk_e2m3_t const *b_scalars, nk_size_t count_scalars,
                                   nk_f32_t *result) {
    // Integer dot product for e2m3 using VPERMB (LUT) + VPDPBUSD (unsigned×signed multiply-add).
    // Every e2m3 value × 16 is an exact integer in [-120, +120].
    // Result = i32_dot / 256.0f (exact, no rounding error).
    //
    // LUT maps 5-bit unsigned magnitude to (value × 16):
    //     exp=0 (sub): 2*mant, exp=1: 16+2*mant
    //     exp=2: 32+4*mant,    exp=3: 64+8*mant
    //
    // VPERMB uses bits [5:0] of the index, so we need a 64-byte LUT with entries 0-31
    // replicated in the upper 32 bytes (VPERMB indexes mod 64, our indices are 0-31).
    // _mm512_set_epi8 lists bytes HIGH→LOW: byte63, byte62, ..., byte0
    __m512i const lut_magnitude_u8x64 = _mm512_set_epi8(120, 112, 104, 96, 88, 80, 72, 64, 60, 56, 52, 48, 44, 40, 36,
                                                        32, 30, 28, 26, 24, 22, 20, 18, 16, 14, 12, 10, 8, 6, 4, 2, 0,
                                                        120, 112, 104, 96, 88, 80, 72, 64, 60, 56, 52, 48, 44, 40, 36,
                                                        32, 30, 28, 26, 24, 22, 20, 18, 16, 14, 12, 10, 8, 6, 4, 2, 0);
    __m512i const magnitude_mask_u8x64 = _mm512_set1_epi8(0x1F);
    __m512i const sign_mask_u8x64 = _mm512_set1_epi8(0x20);
    __m512i sum_i32x16 = _mm512_setzero_si512();
    __m512i a_e2m3_u8x64, b_e2m3_u8x64;

nk_dot_e2m3_icelake_cycle:
    if (count_scalars < 64) {
        // Tail: zero-fill — a zero byte maps to magnitude 0 and contributes nothing.
        __mmask64 mask = (__mmask64)_bzhi_u64(0xFFFFFFFFFFFFFFFF, count_scalars);
        a_e2m3_u8x64 = _mm512_maskz_loadu_epi8(mask, a_scalars);
        b_e2m3_u8x64 = _mm512_maskz_loadu_epi8(mask, b_scalars);
        count_scalars = 0;
    }
    else {
        a_e2m3_u8x64 = _mm512_loadu_si512(a_scalars);
        b_e2m3_u8x64 = _mm512_loadu_si512(b_scalars);
        a_scalars += 64, b_scalars += 64, count_scalars -= 64;
    }

    // Extract 5-bit magnitude indices
    __m512i a_magnitude_u8x64 = _mm512_and_si512(a_e2m3_u8x64, magnitude_mask_u8x64);
    __m512i b_magnitude_u8x64 = _mm512_and_si512(b_e2m3_u8x64, magnitude_mask_u8x64);

    // VPERMB LUT lookup: unsigned magnitudes × 16
    __m512i a_unsigned_u8x64 = _mm512_permutexvar_epi8(a_magnitude_u8x64, lut_magnitude_u8x64);
    __m512i b_unsigned_u8x64 = _mm512_permutexvar_epi8(b_magnitude_u8x64, lut_magnitude_u8x64);

    // Combined sign: (a ^ b) & 0x20 — nonzero means negative product
    __m512i sign_combined_u8x64 = _mm512_and_si512(_mm512_xor_si512(a_e2m3_u8x64, b_e2m3_u8x64), sign_mask_u8x64);
    __mmask64 negate_mask = _mm512_test_epi8_mask(sign_combined_u8x64, sign_combined_u8x64);

    // Negate b where signs differ: b_signed = negate_mask ? (0 - b_unsigned) : b_unsigned
    // For VPDPBUSD: a=unsigned [0,120], b=signed [-120,+120]
    __m512i b_signed_i8x64 = _mm512_mask_sub_epi8(b_unsigned_u8x64, negate_mask, _mm512_setzero_si512(),
                                                  b_unsigned_u8x64);

    // VPDPBUSD: a_unsigned[unsigned] × b_signed[signed], 4 bytes → i32
    sum_i32x16 = _mm512_dpbusd_epi32(sum_i32x16, a_unsigned_u8x64, b_signed_i8x64);

    if (count_scalars) goto nk_dot_e2m3_icelake_cycle;
    // Each operand carried a ×16 scale, so the integer dot is 256× the true value.
    *result = (nk_f32_t)_mm512_reduce_add_epi32(sum_i32x16) / 256.0f;
}
741
+
742
NK_PUBLIC void nk_dot_e3m2_icelake(nk_e3m2_t const *a_scalars, nk_e3m2_t const *b_scalars, nk_size_t count_scalars,
                                   nk_f32_t *result) {
    // Integer dot product for e3m2 using VPERMW (i16 LUT) + VPMADDWD (i16×i16→i32).
    // Every e3m2 value × 16 is an exact integer, but magnitudes reach 448, requiring i16.
    // Result = i32_dot / 256.0f (exact, no rounding error).
    //
    // 32-entry i16 LUT for magnitude × 16:
    //     exp=0 (sub): mant,    exp=1: 4+mant
    //     exp=2: 8+2*mant,      exp=3: 16+4*mant
    //     exp=4: 32+8*mant,     exp=5: 64+16*mant
    //     exp=6: 128+32*mant,   exp=7: 256+64*mant
    //
    // VPERMW uses bits [4:0] of the index (mod 32), so 32 entries fit exactly in one ZMM.
    // _mm512_set_epi16 lists words HIGH→LOW: word31, word30, ..., word0
    __m512i const lut_magnitude_i16x32 = _mm512_set_epi16( //
        448, 384, 320, 256, 224, 192, 160, 128, 112, 96, 80, 64, 56, 48, 40, 32, //
        28, 24, 20, 16, 14, 12, 10, 8, 7, 6, 5, 4, 3, 2, 1, 0);
    __m512i const magnitude_mask_i16x32 = _mm512_set1_epi16(0x1F);
    __m512i const sign_mask_i16x32 = _mm512_set1_epi16(0x20);
    __m512i sum_i32x16 = _mm512_setzero_si512();
    __m256i a_e3m2_u8x32, b_e3m2_u8x32;

nk_dot_e3m2_icelake_cycle:
    if (count_scalars < 32) {
        // Tail: zero-fill — a zero byte maps to magnitude 0 and contributes nothing.
        __mmask32 mask = (__mmask32)_bzhi_u32(0xFFFFFFFF, (unsigned int)count_scalars);
        a_e3m2_u8x32 = _mm256_maskz_loadu_epi8(mask, a_scalars);
        b_e3m2_u8x32 = _mm256_maskz_loadu_epi8(mask, b_scalars);
        count_scalars = 0;
    }
    else {
        a_e3m2_u8x32 = _mm256_loadu_si256((__m256i const *)a_scalars);
        b_e3m2_u8x32 = _mm256_loadu_si256((__m256i const *)b_scalars);
        a_scalars += 32, b_scalars += 32, count_scalars -= 32;
    }

    // Zero-extend u8x32 → u16x32
    __m512i a_u16x32 = _mm512_cvtepu8_epi16(a_e3m2_u8x32);
    __m512i b_u16x32 = _mm512_cvtepu8_epi16(b_e3m2_u8x32);

    // Extract 5-bit magnitude indices
    __m512i a_magnitude_u16x32 = _mm512_and_si512(a_u16x32, magnitude_mask_i16x32);
    __m512i b_magnitude_u16x32 = _mm512_and_si512(b_u16x32, magnitude_mask_i16x32);

    // VPERMW LUT lookup: unsigned magnitudes × 16
    __m512i a_unsigned_i16x32 = _mm512_permutexvar_epi16(a_magnitude_u16x32, lut_magnitude_i16x32);
    __m512i b_unsigned_i16x32 = _mm512_permutexvar_epi16(b_magnitude_u16x32, lut_magnitude_i16x32);

    // Apply signs: negate if bit 5 (the sign bit of the original byte) is set
    __mmask32 a_negate = _mm512_test_epi16_mask(a_u16x32, sign_mask_i16x32);
    __mmask32 b_negate = _mm512_test_epi16_mask(b_u16x32, sign_mask_i16x32);
    __m512i a_signed_i16x32 = _mm512_mask_sub_epi16(a_unsigned_i16x32, a_negate, _mm512_setzero_si512(),
                                                    a_unsigned_i16x32);
    __m512i b_signed_i16x32 = _mm512_mask_sub_epi16(b_unsigned_i16x32, b_negate, _mm512_setzero_si512(),
                                                    b_unsigned_i16x32);

    // VPMADDWD: i16×i16→i32, multiplies adjacent pairs and adds
    sum_i32x16 = _mm512_add_epi32(sum_i32x16, _mm512_madd_epi16(a_signed_i16x32, b_signed_i16x32));

    if (count_scalars) goto nk_dot_e3m2_icelake_cycle;
    // Each operand carried a ×16 scale, so the integer dot is 256× the true value.
    *result = (nk_f32_t)_mm512_reduce_add_epi32(sum_i32x16) / 256.0f;
}
803
+
804
+ #pragma region - Binary
805
+
806
+ NK_PUBLIC void nk_dot_u1_icelake(nk_u1x8_t const *a, nk_u1x8_t const *b, nk_size_t n_bits, nk_u32_t *result) {
807
+ nk_size_t n_bytes = nk_size_divide_round_up_(n_bits, NK_BITS_PER_BYTE);
808
+ __m512i and_popcount_u64x8 = _mm512_setzero_si512();
809
+ __m512i a_u8x64, b_u8x64;
810
+
811
+ nk_dot_u1_icelake_cycle:
812
+ if (n_bytes < 64) {
813
+ __mmask64 mask = (__mmask64)_bzhi_u64(0xFFFFFFFFFFFFFFFF, n_bytes);
814
+ a_u8x64 = _mm512_maskz_loadu_epi8(mask, a);
815
+ b_u8x64 = _mm512_maskz_loadu_epi8(mask, b);
816
+ n_bytes = 0;
817
+ }
818
+ else {
819
+ a_u8x64 = _mm512_loadu_epi8(a);
820
+ b_u8x64 = _mm512_loadu_epi8(b);
821
+ a += 64, b += 64, n_bytes -= 64;
822
+ }
823
+ and_popcount_u64x8 = _mm512_add_epi64(and_popcount_u64x8, _mm512_popcnt_epi64(_mm512_and_si512(a_u8x64, b_u8x64)));
824
+ if (n_bytes) goto nk_dot_u1_icelake_cycle;
825
+
826
+ *result = (nk_u32_t)_mm512_reduce_add_epi64(and_popcount_u64x8);
827
+ }
828
+
829
/* Stateful binary dot-product accumulator: popcount of AND, per 64-bit lane. */
typedef struct nk_dot_u1x512_state_icelake_t {
    __m512i dot_count_i64x8;
} nk_dot_u1x512_state_icelake_t;

NK_INTERNAL void nk_dot_u1x512_init_icelake(nk_dot_u1x512_state_icelake_t *state) {
    state->dot_count_i64x8 = _mm512_setzero_si512();
}

/* Fold one 512-bit slice of each operand into the accumulator.
 * Padding bits must be zero so they add no matches. */
NK_INTERNAL void nk_dot_u1x512_update_icelake(nk_dot_u1x512_state_icelake_t *state, nk_b512_vec_t a, nk_b512_vec_t b,
                                              nk_size_t depth_offset, nk_size_t active_dimensions) {
    nk_unused_(depth_offset);
    nk_unused_(active_dimensions);
    state->dot_count_i64x8 = _mm512_add_epi64(state->dot_count_i64x8,
                                              _mm512_popcnt_epi64(_mm512_and_si512(a.zmm, b.zmm)));
}

/* Reduce four accumulators into four i32 result lanes (lane k = state k's total). */
NK_INTERNAL void nk_dot_u1x512_finalize_icelake( //
    nk_dot_u1x512_state_icelake_t const *state_a, nk_dot_u1x512_state_icelake_t const *state_b,
    nk_dot_u1x512_state_icelake_t const *state_c, nk_dot_u1x512_state_icelake_t const *state_d,
    nk_size_t total_dimensions, nk_b128_vec_t *result) {
    nk_unused_(total_dimensions);

    // VPMOVQD: truncate 8×i64 → 8×i32 per state (popcounts are far below 2³¹)
    __m256i a_i32x8 = _mm512_cvtepi64_epi32(state_a->dot_count_i64x8);
    __m256i b_i32x8 = _mm512_cvtepi64_epi32(state_b->dot_count_i64x8);
    __m256i c_i32x8 = _mm512_cvtepi64_epi32(state_c->dot_count_i64x8);
    __m256i d_i32x8 = _mm512_cvtepi64_epi32(state_d->dot_count_i64x8);

    // Fold 8×i32 → 4×i32 (add high 128-bit lane to low)
    __m128i a_i32x4 = _mm_add_epi32(_mm256_castsi256_si128(a_i32x8), _mm256_extracti128_si256(a_i32x8, 1));
    __m128i b_i32x4 = _mm_add_epi32(_mm256_castsi256_si128(b_i32x8), _mm256_extracti128_si256(b_i32x8, 1));
    __m128i c_i32x4 = _mm_add_epi32(_mm256_castsi256_si128(c_i32x8), _mm256_extracti128_si256(c_i32x8, 1));
    __m128i d_i32x4 = _mm_add_epi32(_mm256_castsi256_si128(d_i32x8), _mm256_extracti128_si256(d_i32x8, 1));

    // VPHADDD cascade: 4×i32 → 2×i32 → 1×i32 per state;
    // hadd(ab, cd) interleaves so the final vector is [a, b, c, d]
    __m128i ab_i32x4 = _mm_hadd_epi32(a_i32x4, b_i32x4);
    __m128i cd_i32x4 = _mm_hadd_epi32(c_i32x4, d_i32x4);
    result->xmm = _mm_hadd_epi32(ab_i32x4, cd_i32x4);
}
868
+
869
+ #pragma endregion - Binary
870
+
871
+ #if defined(__clang__)
872
+ #pragma clang attribute pop
873
+ #elif defined(__GNUC__)
874
+ #pragma GCC pop_options
875
+ #endif
876
+
877
+ #if defined(__cplusplus)
878
+ } // extern "C"
879
+ #endif
880
+
881
+ #endif // NK_TARGET_ICELAKE
882
+ #endif // NK_TARGET_X86_
883
+ #endif // NK_DOT_ICELAKE_H