numkong 7.0.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (294) hide show
  1. package/LICENSE +201 -0
  2. package/README.md +495 -0
  3. package/binding.gyp +540 -0
  4. package/c/dispatch.h +512 -0
  5. package/c/dispatch_bf16.c +389 -0
  6. package/c/dispatch_bf16c.c +52 -0
  7. package/c/dispatch_e2m3.c +263 -0
  8. package/c/dispatch_e3m2.c +243 -0
  9. package/c/dispatch_e4m3.c +276 -0
  10. package/c/dispatch_e5m2.c +272 -0
  11. package/c/dispatch_f16.c +376 -0
  12. package/c/dispatch_f16c.c +58 -0
  13. package/c/dispatch_f32.c +378 -0
  14. package/c/dispatch_f32c.c +99 -0
  15. package/c/dispatch_f64.c +296 -0
  16. package/c/dispatch_f64c.c +98 -0
  17. package/c/dispatch_i16.c +96 -0
  18. package/c/dispatch_i32.c +89 -0
  19. package/c/dispatch_i4.c +150 -0
  20. package/c/dispatch_i64.c +86 -0
  21. package/c/dispatch_i8.c +289 -0
  22. package/c/dispatch_other.c +330 -0
  23. package/c/dispatch_u1.c +148 -0
  24. package/c/dispatch_u16.c +124 -0
  25. package/c/dispatch_u32.c +118 -0
  26. package/c/dispatch_u4.c +150 -0
  27. package/c/dispatch_u64.c +102 -0
  28. package/c/dispatch_u8.c +303 -0
  29. package/c/numkong.c +950 -0
  30. package/include/README.md +573 -0
  31. package/include/module.modulemap +129 -0
  32. package/include/numkong/attention/sapphireamx.h +1361 -0
  33. package/include/numkong/attention/sme.h +2066 -0
  34. package/include/numkong/attention.h +49 -0
  35. package/include/numkong/capabilities.h +748 -0
  36. package/include/numkong/cast/README.md +262 -0
  37. package/include/numkong/cast/haswell.h +975 -0
  38. package/include/numkong/cast/icelake.h +470 -0
  39. package/include/numkong/cast/neon.h +1192 -0
  40. package/include/numkong/cast/rvv.h +1021 -0
  41. package/include/numkong/cast/sapphire.h +262 -0
  42. package/include/numkong/cast/serial.h +2262 -0
  43. package/include/numkong/cast/skylake.h +856 -0
  44. package/include/numkong/cast/v128relaxed.h +180 -0
  45. package/include/numkong/cast.h +230 -0
  46. package/include/numkong/curved/README.md +223 -0
  47. package/include/numkong/curved/genoa.h +182 -0
  48. package/include/numkong/curved/haswell.h +276 -0
  49. package/include/numkong/curved/neon.h +205 -0
  50. package/include/numkong/curved/neonbfdot.h +212 -0
  51. package/include/numkong/curved/neonhalf.h +212 -0
  52. package/include/numkong/curved/rvv.h +305 -0
  53. package/include/numkong/curved/serial.h +207 -0
  54. package/include/numkong/curved/skylake.h +457 -0
  55. package/include/numkong/curved/smef64.h +506 -0
  56. package/include/numkong/curved.h +517 -0
  57. package/include/numkong/curved.hpp +144 -0
  58. package/include/numkong/dot/README.md +425 -0
  59. package/include/numkong/dot/alder.h +563 -0
  60. package/include/numkong/dot/genoa.h +315 -0
  61. package/include/numkong/dot/haswell.h +1688 -0
  62. package/include/numkong/dot/icelake.h +883 -0
  63. package/include/numkong/dot/neon.h +818 -0
  64. package/include/numkong/dot/neonbfdot.h +244 -0
  65. package/include/numkong/dot/neonfhm.h +360 -0
  66. package/include/numkong/dot/neonhalf.h +198 -0
  67. package/include/numkong/dot/neonsdot.h +508 -0
  68. package/include/numkong/dot/rvv.h +714 -0
  69. package/include/numkong/dot/rvvbb.h +72 -0
  70. package/include/numkong/dot/rvvbf16.h +123 -0
  71. package/include/numkong/dot/rvvhalf.h +129 -0
  72. package/include/numkong/dot/sapphire.h +141 -0
  73. package/include/numkong/dot/serial.h +838 -0
  74. package/include/numkong/dot/sierra.h +405 -0
  75. package/include/numkong/dot/skylake.h +1084 -0
  76. package/include/numkong/dot/sve.h +379 -0
  77. package/include/numkong/dot/svebfdot.h +74 -0
  78. package/include/numkong/dot/svehalf.h +123 -0
  79. package/include/numkong/dot/v128relaxed.h +1258 -0
  80. package/include/numkong/dot.h +1070 -0
  81. package/include/numkong/dot.hpp +94 -0
  82. package/include/numkong/dots/README.md +496 -0
  83. package/include/numkong/dots/alder.h +114 -0
  84. package/include/numkong/dots/genoa.h +94 -0
  85. package/include/numkong/dots/haswell.h +295 -0
  86. package/include/numkong/dots/icelake.h +171 -0
  87. package/include/numkong/dots/neon.h +120 -0
  88. package/include/numkong/dots/neonbfdot.h +58 -0
  89. package/include/numkong/dots/neonfhm.h +94 -0
  90. package/include/numkong/dots/neonhalf.h +57 -0
  91. package/include/numkong/dots/neonsdot.h +108 -0
  92. package/include/numkong/dots/rvv.h +2486 -0
  93. package/include/numkong/dots/sapphireamx.h +3973 -0
  94. package/include/numkong/dots/serial.h +2844 -0
  95. package/include/numkong/dots/sierra.h +97 -0
  96. package/include/numkong/dots/skylake.h +196 -0
  97. package/include/numkong/dots/sme.h +5372 -0
  98. package/include/numkong/dots/smebi32.h +461 -0
  99. package/include/numkong/dots/smef64.h +1318 -0
  100. package/include/numkong/dots/smehalf.h +47 -0
  101. package/include/numkong/dots/v128relaxed.h +294 -0
  102. package/include/numkong/dots.h +2804 -0
  103. package/include/numkong/dots.hpp +639 -0
  104. package/include/numkong/each/README.md +469 -0
  105. package/include/numkong/each/haswell.h +1658 -0
  106. package/include/numkong/each/icelake.h +272 -0
  107. package/include/numkong/each/neon.h +1104 -0
  108. package/include/numkong/each/neonbfdot.h +212 -0
  109. package/include/numkong/each/neonhalf.h +410 -0
  110. package/include/numkong/each/rvv.h +1121 -0
  111. package/include/numkong/each/sapphire.h +477 -0
  112. package/include/numkong/each/serial.h +260 -0
  113. package/include/numkong/each/skylake.h +1562 -0
  114. package/include/numkong/each.h +2146 -0
  115. package/include/numkong/each.hpp +434 -0
  116. package/include/numkong/geospatial/README.md +147 -0
  117. package/include/numkong/geospatial/haswell.h +593 -0
  118. package/include/numkong/geospatial/neon.h +571 -0
  119. package/include/numkong/geospatial/rvv.h +701 -0
  120. package/include/numkong/geospatial/serial.h +309 -0
  121. package/include/numkong/geospatial/skylake.h +577 -0
  122. package/include/numkong/geospatial/v128relaxed.h +613 -0
  123. package/include/numkong/geospatial.h +453 -0
  124. package/include/numkong/geospatial.hpp +235 -0
  125. package/include/numkong/matrix.hpp +336 -0
  126. package/include/numkong/maxsim/README.md +187 -0
  127. package/include/numkong/maxsim/alder.h +511 -0
  128. package/include/numkong/maxsim/genoa.h +115 -0
  129. package/include/numkong/maxsim/haswell.h +553 -0
  130. package/include/numkong/maxsim/icelake.h +480 -0
  131. package/include/numkong/maxsim/neonsdot.h +394 -0
  132. package/include/numkong/maxsim/sapphireamx.h +877 -0
  133. package/include/numkong/maxsim/serial.h +490 -0
  134. package/include/numkong/maxsim/sme.h +929 -0
  135. package/include/numkong/maxsim/v128relaxed.h +280 -0
  136. package/include/numkong/maxsim.h +571 -0
  137. package/include/numkong/maxsim.hpp +133 -0
  138. package/include/numkong/mesh/README.md +227 -0
  139. package/include/numkong/mesh/haswell.h +2235 -0
  140. package/include/numkong/mesh/neon.h +1329 -0
  141. package/include/numkong/mesh/neonbfdot.h +842 -0
  142. package/include/numkong/mesh/neonhalf.h +616 -0
  143. package/include/numkong/mesh/rvv.h +916 -0
  144. package/include/numkong/mesh/serial.h +742 -0
  145. package/include/numkong/mesh/skylake.h +1135 -0
  146. package/include/numkong/mesh/v128relaxed.h +1052 -0
  147. package/include/numkong/mesh.h +652 -0
  148. package/include/numkong/mesh.hpp +762 -0
  149. package/include/numkong/numkong.h +78 -0
  150. package/include/numkong/numkong.hpp +57 -0
  151. package/include/numkong/probability/README.md +173 -0
  152. package/include/numkong/probability/haswell.h +267 -0
  153. package/include/numkong/probability/neon.h +225 -0
  154. package/include/numkong/probability/rvv.h +409 -0
  155. package/include/numkong/probability/serial.h +169 -0
  156. package/include/numkong/probability/skylake.h +324 -0
  157. package/include/numkong/probability.h +383 -0
  158. package/include/numkong/probability.hpp +120 -0
  159. package/include/numkong/random.h +50 -0
  160. package/include/numkong/random.hpp +285 -0
  161. package/include/numkong/reduce/README.md +547 -0
  162. package/include/numkong/reduce/alder.h +632 -0
  163. package/include/numkong/reduce/genoa.h +201 -0
  164. package/include/numkong/reduce/haswell.h +3783 -0
  165. package/include/numkong/reduce/icelake.h +549 -0
  166. package/include/numkong/reduce/neon.h +3841 -0
  167. package/include/numkong/reduce/neonbfdot.h +353 -0
  168. package/include/numkong/reduce/neonfhm.h +665 -0
  169. package/include/numkong/reduce/neonhalf.h +157 -0
  170. package/include/numkong/reduce/neonsdot.h +357 -0
  171. package/include/numkong/reduce/rvv.h +3407 -0
  172. package/include/numkong/reduce/serial.h +757 -0
  173. package/include/numkong/reduce/sierra.h +338 -0
  174. package/include/numkong/reduce/skylake.h +3792 -0
  175. package/include/numkong/reduce/v128relaxed.h +2302 -0
  176. package/include/numkong/reduce.h +1597 -0
  177. package/include/numkong/reduce.hpp +633 -0
  178. package/include/numkong/scalar/README.md +89 -0
  179. package/include/numkong/scalar/haswell.h +113 -0
  180. package/include/numkong/scalar/neon.h +122 -0
  181. package/include/numkong/scalar/neonhalf.h +70 -0
  182. package/include/numkong/scalar/rvv.h +211 -0
  183. package/include/numkong/scalar/sapphire.h +63 -0
  184. package/include/numkong/scalar/serial.h +332 -0
  185. package/include/numkong/scalar/v128relaxed.h +56 -0
  186. package/include/numkong/scalar.h +683 -0
  187. package/include/numkong/set/README.md +179 -0
  188. package/include/numkong/set/haswell.h +334 -0
  189. package/include/numkong/set/icelake.h +485 -0
  190. package/include/numkong/set/neon.h +364 -0
  191. package/include/numkong/set/rvv.h +226 -0
  192. package/include/numkong/set/rvvbb.h +117 -0
  193. package/include/numkong/set/serial.h +174 -0
  194. package/include/numkong/set/sve.h +185 -0
  195. package/include/numkong/set/v128relaxed.h +240 -0
  196. package/include/numkong/set.h +457 -0
  197. package/include/numkong/set.hpp +114 -0
  198. package/include/numkong/sets/README.md +149 -0
  199. package/include/numkong/sets/haswell.h +63 -0
  200. package/include/numkong/sets/icelake.h +66 -0
  201. package/include/numkong/sets/neon.h +61 -0
  202. package/include/numkong/sets/serial.h +43 -0
  203. package/include/numkong/sets/smebi32.h +1099 -0
  204. package/include/numkong/sets/v128relaxed.h +58 -0
  205. package/include/numkong/sets.h +339 -0
  206. package/include/numkong/sparse/README.md +156 -0
  207. package/include/numkong/sparse/icelake.h +463 -0
  208. package/include/numkong/sparse/neon.h +288 -0
  209. package/include/numkong/sparse/serial.h +117 -0
  210. package/include/numkong/sparse/sve2.h +507 -0
  211. package/include/numkong/sparse/turin.h +322 -0
  212. package/include/numkong/sparse.h +363 -0
  213. package/include/numkong/sparse.hpp +113 -0
  214. package/include/numkong/spatial/README.md +435 -0
  215. package/include/numkong/spatial/alder.h +607 -0
  216. package/include/numkong/spatial/genoa.h +290 -0
  217. package/include/numkong/spatial/haswell.h +960 -0
  218. package/include/numkong/spatial/icelake.h +586 -0
  219. package/include/numkong/spatial/neon.h +773 -0
  220. package/include/numkong/spatial/neonbfdot.h +165 -0
  221. package/include/numkong/spatial/neonhalf.h +118 -0
  222. package/include/numkong/spatial/neonsdot.h +261 -0
  223. package/include/numkong/spatial/rvv.h +984 -0
  224. package/include/numkong/spatial/rvvbf16.h +123 -0
  225. package/include/numkong/spatial/rvvhalf.h +117 -0
  226. package/include/numkong/spatial/sapphire.h +343 -0
  227. package/include/numkong/spatial/serial.h +346 -0
  228. package/include/numkong/spatial/sierra.h +323 -0
  229. package/include/numkong/spatial/skylake.h +606 -0
  230. package/include/numkong/spatial/sve.h +224 -0
  231. package/include/numkong/spatial/svebfdot.h +122 -0
  232. package/include/numkong/spatial/svehalf.h +109 -0
  233. package/include/numkong/spatial/v128relaxed.h +717 -0
  234. package/include/numkong/spatial.h +1425 -0
  235. package/include/numkong/spatial.hpp +183 -0
  236. package/include/numkong/spatials/README.md +580 -0
  237. package/include/numkong/spatials/alder.h +94 -0
  238. package/include/numkong/spatials/genoa.h +94 -0
  239. package/include/numkong/spatials/haswell.h +219 -0
  240. package/include/numkong/spatials/icelake.h +113 -0
  241. package/include/numkong/spatials/neon.h +109 -0
  242. package/include/numkong/spatials/neonbfdot.h +60 -0
  243. package/include/numkong/spatials/neonfhm.h +92 -0
  244. package/include/numkong/spatials/neonhalf.h +58 -0
  245. package/include/numkong/spatials/neonsdot.h +109 -0
  246. package/include/numkong/spatials/rvv.h +1960 -0
  247. package/include/numkong/spatials/sapphireamx.h +1149 -0
  248. package/include/numkong/spatials/serial.h +226 -0
  249. package/include/numkong/spatials/sierra.h +96 -0
  250. package/include/numkong/spatials/skylake.h +184 -0
  251. package/include/numkong/spatials/sme.h +1901 -0
  252. package/include/numkong/spatials/smef64.h +465 -0
  253. package/include/numkong/spatials/v128relaxed.h +240 -0
  254. package/include/numkong/spatials.h +3021 -0
  255. package/include/numkong/spatials.hpp +508 -0
  256. package/include/numkong/tensor.hpp +1592 -0
  257. package/include/numkong/trigonometry/README.md +184 -0
  258. package/include/numkong/trigonometry/haswell.h +652 -0
  259. package/include/numkong/trigonometry/neon.h +639 -0
  260. package/include/numkong/trigonometry/rvv.h +699 -0
  261. package/include/numkong/trigonometry/serial.h +703 -0
  262. package/include/numkong/trigonometry/skylake.h +721 -0
  263. package/include/numkong/trigonometry/v128relaxed.h +666 -0
  264. package/include/numkong/trigonometry.h +467 -0
  265. package/include/numkong/trigonometry.hpp +166 -0
  266. package/include/numkong/types.h +1384 -0
  267. package/include/numkong/types.hpp +5603 -0
  268. package/include/numkong/vector.hpp +698 -0
  269. package/javascript/README.md +246 -0
  270. package/javascript/dist/cjs/numkong-wasm.d.ts +166 -0
  271. package/javascript/dist/cjs/numkong-wasm.js +617 -0
  272. package/javascript/dist/cjs/numkong.d.ts +343 -0
  273. package/javascript/dist/cjs/numkong.js +523 -0
  274. package/javascript/dist/cjs/package.json +3 -0
  275. package/javascript/dist/cjs/types.d.ts +284 -0
  276. package/javascript/dist/cjs/types.js +653 -0
  277. package/javascript/dist/esm/numkong-wasm.d.ts +166 -0
  278. package/javascript/dist/esm/numkong-wasm.js +595 -0
  279. package/javascript/dist/esm/numkong.d.ts +343 -0
  280. package/javascript/dist/esm/numkong.js +452 -0
  281. package/javascript/dist/esm/package.json +3 -0
  282. package/javascript/dist/esm/types.d.ts +284 -0
  283. package/javascript/dist/esm/types.js +630 -0
  284. package/javascript/dist-package-cjs.json +3 -0
  285. package/javascript/dist-package-esm.json +3 -0
  286. package/javascript/node-gyp-build.d.ts +1 -0
  287. package/javascript/numkong-wasm.ts +756 -0
  288. package/javascript/numkong.c +689 -0
  289. package/javascript/numkong.ts +575 -0
  290. package/javascript/tsconfig-base.json +39 -0
  291. package/javascript/tsconfig-cjs.json +8 -0
  292. package/javascript/tsconfig-esm.json +8 -0
  293. package/javascript/types.ts +674 -0
  294. package/package.json +87 -0
@@ -0,0 +1,1688 @@
1
+ /**
2
+ * @brief SIMD-accelerated Dot Products for Haswell.
3
+ * @file include/numkong/dot/haswell.h
4
+ * @author Ash Vardanian
5
+ * @date December 27, 2025
6
+ *
7
+ * @sa include/numkong/dot.h
8
+ *
9
+ * @section dot_haswell_instructions Key AVX2/FMA Dot Product Instructions
10
+ *
11
+ * Intrinsic Instruction Latency Throughput Ports
12
+ * _mm256_fmadd_ps/pd VFMADD (YMM, YMM, YMM) 5cy 0.5/cy p01
13
+ * _mm256_mul_ps/pd VMULPS/PD (YMM, YMM, YMM) 5cy 0.5/cy p01
14
+ * _mm256_add_ps/pd VADDPS/PD (YMM, YMM, YMM) 3cy 1/cy p01
15
+ * _mm256_cvtph_ps VCVTPH2PS (YMM, XMM) 5cy 1/cy p01
16
+ * _mm256_cvtps_pd VCVTPS2PD (YMM, XMM) 2cy 1/cy p01
17
+ *
18
+ * For small numeric types (F16, BF16, E4M3, E5M2, E2M3, E3M2) we use F32 accumulators. For F32 dot products,
19
+ * upcasting to F64 and downcasting back is faster than stable summation algorithms. For F64 we
20
+ * use the Dot2 algorithm (Ogita-Rump-Oishi, 2005) for compensated accumulation via TwoSum/TwoProd.
21
+ * For F32 complex dot products, upcasting to F64 absorbs the deferred sign-flip error.
22
+ * BF16c and F16c use the same deferred sign-flip with F32 accumulators.
23
+ *
24
+ * @section dot_haswell_stateful Stateful Streaming Logic
25
+ *
26
+ * To build memory-optimal tiled algorithms, this file defines the following structures and force-inlined
27
+ * `NK_INTERNAL` functions:
28
+ *
29
+ * - nk_dot_f64x4 state with Dot2 stable dot-products,
30
+ * - nk_dot_f32x4 state with double-precision numerics,
31
+ * - nk_dot_through_f32 state for 16-, 8-, and 6-bit float inputs with single-precision numerics,
32
+ * - nk_dot_through_i32 state for 8-bit signed and unsigned integer inputs,
33
+ * - nk_dot_i4x32 for 4-bit signed integer products with 2 correction terms,
34
+ * - nk_dot_u4x32 for 4-bit unsigned integer products.
35
+ *
36
+ * @code{c}
37
+ * nk_dot_through_i32_state_haswell_t_ state_first, state_second, state_third, state_fourth;
38
+ * nk_b128_vec_t query_i8x16, target_first_i8x16, target_second_i8x16, target_third_i8x16, target_fourth_i8x16;
39
+ * nk_dot_through_i32_init_haswell_(&state_first);
40
+ * nk_dot_through_i32_init_haswell_(&state_second);
41
+ * nk_dot_through_i32_init_haswell_(&state_third);
42
+ * nk_dot_through_i32_init_haswell_(&state_fourth);
43
+ * for (nk_size_t idx = 0; idx + 16 <= depth; idx += 16) {
44
+ * query_i8x16.xmm = _mm_loadu_si128(query_ptr + idx);
45
+ * target_first_i8x16.xmm = _mm_loadu_si128(target_first_ptr + idx);
46
+ * target_second_i8x16.xmm = _mm_loadu_si128(target_second_ptr + idx);
47
+ * target_third_i8x16.xmm = _mm_loadu_si128(target_third_ptr + idx);
48
+ * target_fourth_i8x16.xmm = _mm_loadu_si128(target_fourth_ptr + idx);
49
+ * nk_dot_i8x16_update_haswell(&state_first, query_i8x16, target_first_i8x16, idx, 16);
50
+ * nk_dot_i8x16_update_haswell(&state_second, query_i8x16, target_second_i8x16, idx, 16);
51
+ * nk_dot_i8x16_update_haswell(&state_third, query_i8x16, target_third_i8x16, idx, 16);
52
+ * nk_dot_i8x16_update_haswell(&state_fourth, query_i8x16, target_fourth_i8x16, idx, 16);
53
+ * }
54
+ * nk_b128_vec_t results_i32x4;
55
+ * nk_dot_through_i32_finalize_haswell_(&state_first, &state_second, &state_third, &state_fourth,
56
+ * depth, &results_i32x4);
57
+ * @endcode
58
+ *
59
+ * Not every numeric type has dedicated dot-product SIMD circuitry on each ISA. Smaller float types
60
+ * like f16, bf16, e4m3, e5m2, e2m3, and e3m2 on Haswell use ISA-specific upcasting to f32 combined
61
+ * with native FMA instructions, sharing the `nk_dot_through_f32` accumulation logic:
62
+ *
63
+ * @code{c}
64
+ * nk_dot_e4m3x16_state_haswell_t state_first, state_second, state_third, state_fourth;
65
+ * nk_b256_vec_t query_f32x8, target_first_f32x8, target_second_f32x8, target_third_f32x8, target_fourth_f32x8;
66
+ * nk_dot_through_f32_init_haswell_(&state_first);
67
+ * nk_dot_through_f32_init_haswell_(&state_second);
68
+ * nk_dot_through_f32_init_haswell_(&state_third);
69
+ * nk_dot_through_f32_init_haswell_(&state_fourth);
70
+ * for (nk_size_t idx = 0; idx + 8 <= depth; idx += 8) {
71
+ * query_f32x8.ymm_ps = nk_e4m3x8_to_f32x8_haswell_(_mm_loadl_epi64(query_ptr + idx));
72
+ * target_first_f32x8.ymm_ps = nk_e4m3x8_to_f32x8_haswell_(_mm_loadl_epi64(target_first_ptr + idx));
73
+ * target_second_f32x8.ymm_ps = nk_e4m3x8_to_f32x8_haswell_(_mm_loadl_epi64(target_second_ptr + idx));
74
+ * target_third_f32x8.ymm_ps = nk_e4m3x8_to_f32x8_haswell_(_mm_loadl_epi64(target_third_ptr + idx));
75
+ * target_fourth_f32x8.ymm_ps = nk_e4m3x8_to_f32x8_haswell_(_mm_loadl_epi64(target_fourth_ptr + idx));
76
+ * nk_dot_through_f32_update_haswell_(&state_first, query_f32x8, target_first_f32x8, idx, 8);
77
+ * nk_dot_through_f32_update_haswell_(&state_second, query_f32x8, target_second_f32x8, idx, 8);
78
+ * nk_dot_through_f32_update_haswell_(&state_third, query_f32x8, target_third_f32x8, idx, 8);
79
+ * nk_dot_through_f32_update_haswell_(&state_fourth, query_f32x8, target_fourth_f32x8, idx, 8);
80
+ * }
81
+ * nk_b128_vec_t results_f32x4;
82
+ * nk_dot_through_f32_finalize_haswell_(&state_first, &state_second, &state_third, &state_fourth,
83
+ * depth, &results_f32x4);
84
+ * @endcode
85
+ */
86
+ #ifndef NK_DOT_HASWELL_H
87
+ #define NK_DOT_HASWELL_H
88
+
89
+ #if NK_TARGET_X86_
90
+ #if NK_TARGET_HASWELL
91
+
92
+ #include "numkong/types.h"
93
+ #include "numkong/dot/serial.h"
94
+ #include "numkong/reduce/haswell.h"
95
+ #include "numkong/cast/haswell.h" // `nk_f32x8_to_bf16x8_haswell_`
96
+
97
+ #if defined(__cplusplus)
98
+ extern "C" {
99
+ #endif
100
+
101
+ #if defined(__clang__)
102
+ #pragma clang attribute push(__attribute__((target("avx2,f16c,fma,bmi,bmi2"))), apply_to = function)
103
+ #elif defined(__GNUC__)
104
+ #pragma GCC push_options
105
+ #pragma GCC target("avx2", "f16c", "fma", "bmi", "bmi2")
106
+ #endif
107
+
108
+ /** @brief Compensated horizontal sum of 4 f64 lanes via TwoSum tree reduction.
109
+ * @sa nk_reduce_sum_f64_serial_ for the serial equivalent
110
+ */
111
+ NK_INTERNAL nk_f64_t nk_dot_stable_sum_f64x4_haswell_(__m256d sum_f64x4, __m256d compensation_f64x4) {
112
+ // Stage 0: TwoSum merge of sum + compensation (4-wide, parallel)
113
+ __m256d tentative_sum_f64x4 = _mm256_add_pd(sum_f64x4, compensation_f64x4);
114
+ __m256d virtual_addend_f64x4 = _mm256_sub_pd(tentative_sum_f64x4, sum_f64x4);
115
+ __m256d rounding_error_f64x4 = _mm256_add_pd(
116
+ _mm256_sub_pd(sum_f64x4, _mm256_sub_pd(tentative_sum_f64x4, virtual_addend_f64x4)),
117
+ _mm256_sub_pd(compensation_f64x4, virtual_addend_f64x4));
118
+
119
+ // Stage 1: TwoSum halving 4→2
120
+ __m128d lower_sum_f64x2 = _mm256_castpd256_pd128(tentative_sum_f64x4);
121
+ __m128d upper_sum_f64x2 = _mm256_extractf128_pd(tentative_sum_f64x4, 1);
122
+ __m128d tentative_sum_f64x2 = _mm_add_pd(lower_sum_f64x2, upper_sum_f64x2);
123
+ __m128d virtual_addend_f64x2 = _mm_sub_pd(tentative_sum_f64x2, lower_sum_f64x2);
124
+ __m128d rounding_error_f64x2 = _mm_add_pd(
125
+ _mm_sub_pd(lower_sum_f64x2, _mm_sub_pd(tentative_sum_f64x2, virtual_addend_f64x2)),
126
+ _mm_sub_pd(upper_sum_f64x2, virtual_addend_f64x2));
127
+ // Accumulate errors: stage 0 errors (halved) + stage 1 rounding error
128
+ __m128d lower_error_f64x2 = _mm256_castpd256_pd128(rounding_error_f64x4);
129
+ __m128d upper_error_f64x2 = _mm256_extractf128_pd(rounding_error_f64x4, 1);
130
+ __m128d accumulated_error_f64x2 = _mm_add_pd(_mm_add_pd(lower_error_f64x2, upper_error_f64x2),
131
+ rounding_error_f64x2);
132
+
133
+ // Stage 2: Scalar TwoSum 2→1
134
+ nk_f64_t lower_sum = _mm_cvtsd_f64(tentative_sum_f64x2);
135
+ nk_f64_t upper_sum = _mm_cvtsd_f64(_mm_unpackhi_pd(tentative_sum_f64x2, tentative_sum_f64x2));
136
+ nk_f64_t lower_error = _mm_cvtsd_f64(accumulated_error_f64x2);
137
+ nk_f64_t upper_error = _mm_cvtsd_f64(_mm_unpackhi_pd(accumulated_error_f64x2, accumulated_error_f64x2));
138
+ nk_f64_t tentative_sum = lower_sum + upper_sum;
139
+ nk_f64_t virtual_addend = tentative_sum - lower_sum;
140
+ nk_f64_t rounding_error = (lower_sum - (tentative_sum - virtual_addend)) + (upper_sum - virtual_addend);
141
+ return tentative_sum + (lower_error + upper_error + rounding_error);
142
+ }
143
+
144
+ #pragma region - Traditional Floats
145
+
146
+ NK_PUBLIC void nk_dot_f32_haswell(nk_f32_t const *a_scalars, nk_f32_t const *b_scalars, nk_size_t count_scalars,
147
+ nk_f64_t *result) {
148
+ __m256d sum_f64x4 = _mm256_setzero_pd();
149
+ nk_size_t idx_scalars = 0;
150
+ for (; idx_scalars + 4 <= count_scalars; idx_scalars += 4) {
151
+ __m128 a_f32x4 = _mm_loadu_ps(a_scalars + idx_scalars);
152
+ __m128 b_f32x4 = _mm_loadu_ps(b_scalars + idx_scalars);
153
+ __m256d a_f64x4 = _mm256_cvtps_pd(a_f32x4);
154
+ __m256d b_f64x4 = _mm256_cvtps_pd(b_f32x4);
155
+ sum_f64x4 = _mm256_fmadd_pd(a_f64x4, b_f64x4, sum_f64x4);
156
+ }
157
+ nk_f64_t sum = nk_reduce_add_f64x4_haswell_(sum_f64x4);
158
+ for (; idx_scalars < count_scalars; ++idx_scalars) sum += (nk_f64_t)a_scalars[idx_scalars] * b_scalars[idx_scalars];
159
+ *result = sum;
160
+ }
161
+
162
+ NK_PUBLIC void nk_dot_f32c_haswell(nk_f32c_t const *a_pairs, nk_f32c_t const *b_pairs, nk_size_t count_pairs,
163
+ nk_f64c_t *result) {
164
+ // Using XOR to flip sign bits is cheaper than separate FMA/FMS. Throughput doubles from 2.5 GB/s to 5 GB/s.
165
+ __m256d sum_real_f64x4 = _mm256_setzero_pd();
166
+ __m256d sum_imag_f64x4 = _mm256_setzero_pd();
167
+ __m256i sign_flip_i64x4 = _mm256_set_epi64x(0x8000000000000000, 0, 0x8000000000000000, 0);
168
+ nk_size_t idx_pairs = 0;
169
+ for (; idx_pairs + 2 <= count_pairs; idx_pairs += 2) {
170
+ __m128 a_f32x4 = _mm_loadu_ps((nk_f32_t const *)(a_pairs + idx_pairs));
171
+ __m128 b_f32x4 = _mm_loadu_ps((nk_f32_t const *)(b_pairs + idx_pairs));
172
+ __m256d a_f64x4 = _mm256_cvtps_pd(a_f32x4);
173
+ __m256d b_f64x4 = _mm256_cvtps_pd(b_f32x4);
174
+ __m256d b_swapped_f64x4 = _mm256_permute_pd(b_f64x4, 0x5); // 0b0101: swap adjacent pairs
175
+ sum_real_f64x4 = _mm256_fmadd_pd(a_f64x4, b_f64x4, sum_real_f64x4);
176
+ sum_imag_f64x4 = _mm256_fmadd_pd(a_f64x4, b_swapped_f64x4, sum_imag_f64x4);
177
+ }
178
+ // Flip the sign bit in every second f64 before accumulation:
179
+ sum_real_f64x4 = _mm256_castsi256_pd(_mm256_xor_si256(_mm256_castpd_si256(sum_real_f64x4), sign_flip_i64x4));
180
+ nk_f64_t sum_real = nk_reduce_add_f64x4_haswell_(sum_real_f64x4);
181
+ nk_f64_t sum_imag = nk_reduce_add_f64x4_haswell_(sum_imag_f64x4);
182
+ for (; idx_pairs != count_pairs; ++idx_pairs) {
183
+ nk_f32c_t a_pair = a_pairs[idx_pairs], b_pair = b_pairs[idx_pairs];
184
+ sum_real += (nk_f64_t)a_pair.real * b_pair.real - (nk_f64_t)a_pair.imag * b_pair.imag;
185
+ sum_imag += (nk_f64_t)a_pair.real * b_pair.imag + (nk_f64_t)a_pair.imag * b_pair.real;
186
+ }
187
+ result->real = sum_real;
188
+ result->imag = sum_imag;
189
+ }
190
+
191
+ NK_PUBLIC void nk_vdot_f32c_haswell(nk_f32c_t const *a_pairs, nk_f32c_t const *b_pairs, nk_size_t count_pairs,
192
+ nk_f64c_t *result) {
193
+ __m256d sum_real_f64x4 = _mm256_setzero_pd();
194
+ __m256d sum_imag_f64x4 = _mm256_setzero_pd();
195
+ __m256i sign_flip_i64x4 = _mm256_set_epi64x(0x8000000000000000, 0, 0x8000000000000000, 0);
196
+ nk_size_t idx_pairs = 0;
197
+ for (; idx_pairs + 2 <= count_pairs; idx_pairs += 2) {
198
+ __m128 a_f32x4 = _mm_loadu_ps((nk_f32_t const *)(a_pairs + idx_pairs));
199
+ __m128 b_f32x4 = _mm_loadu_ps((nk_f32_t const *)(b_pairs + idx_pairs));
200
+ __m256d a_f64x4 = _mm256_cvtps_pd(a_f32x4);
201
+ __m256d b_f64x4 = _mm256_cvtps_pd(b_f32x4);
202
+ sum_real_f64x4 = _mm256_fmadd_pd(a_f64x4, b_f64x4, sum_real_f64x4);
203
+ __m256d b_swapped_f64x4 = _mm256_permute_pd(b_f64x4, 0x5); // 0b0101: swap adjacent pairs
204
+ sum_imag_f64x4 = _mm256_fmadd_pd(a_f64x4, b_swapped_f64x4, sum_imag_f64x4);
205
+ }
206
+ // Flip the sign bit in every second f64 before accumulation:
207
+ sum_imag_f64x4 = _mm256_castsi256_pd(_mm256_xor_si256(_mm256_castpd_si256(sum_imag_f64x4), sign_flip_i64x4));
208
+ nk_f64_t sum_real = nk_reduce_add_f64x4_haswell_(sum_real_f64x4);
209
+ nk_f64_t sum_imag = nk_reduce_add_f64x4_haswell_(sum_imag_f64x4);
210
+ for (; idx_pairs != count_pairs; ++idx_pairs) {
211
+ nk_f32c_t a_pair = a_pairs[idx_pairs], b_pair = b_pairs[idx_pairs];
212
+ sum_real += (nk_f64_t)a_pair.real * b_pair.real + (nk_f64_t)a_pair.imag * b_pair.imag;
213
+ sum_imag += (nk_f64_t)a_pair.real * b_pair.imag - (nk_f64_t)a_pair.imag * b_pair.real;
214
+ }
215
+ result->real = sum_real;
216
+ result->imag = sum_imag;
217
+ }
218
+
219
+ NK_PUBLIC void nk_dot_f64_haswell(nk_f64_t const *a_scalars, nk_f64_t const *b_scalars, nk_size_t count_scalars,
220
+ nk_f64_t *result) {
221
+ // Dot2 algorithm (Ogita-Rump-Oishi 2005) for compensated dot product
222
+ __m256d sum_f64x4 = _mm256_setzero_pd();
223
+ __m256d compensation_f64x4 = _mm256_setzero_pd();
224
+ __m256d a_f64x4, b_f64x4;
225
+
226
+ nk_dot_f64_haswell_cycle:
227
+ if (count_scalars < 4) {
228
+ nk_b256_vec_t a_tail, b_tail;
229
+ nk_partial_load_b64x4_haswell_(a_scalars, &a_tail, count_scalars);
230
+ nk_partial_load_b64x4_haswell_(b_scalars, &b_tail, count_scalars);
231
+ a_f64x4 = a_tail.ymm_pd;
232
+ b_f64x4 = b_tail.ymm_pd;
233
+ count_scalars = 0;
234
+ }
235
+ else {
236
+ a_f64x4 = _mm256_loadu_pd(a_scalars);
237
+ b_f64x4 = _mm256_loadu_pd(b_scalars);
238
+ a_scalars += 4, b_scalars += 4, count_scalars -= 4;
239
+ }
240
+
241
+ // TwoProd: h = a * b, r = fma(a, b, -h) captures the rounding error
242
+ __m256d product_f64x4 = _mm256_mul_pd(a_f64x4, b_f64x4);
243
+ __m256d product_error_f64x4 = _mm256_fmsub_pd(a_f64x4, b_f64x4, product_f64x4);
244
+ // TwoSum: (t, q) = TwoSum(sum, h) where t = sum + h rounded, q = error
245
+ __m256d tentative_sum_f64x4 = _mm256_add_pd(sum_f64x4, product_f64x4);
246
+ __m256d virtual_addend_f64x4 = _mm256_sub_pd(tentative_sum_f64x4, sum_f64x4);
247
+ __m256d sum_error_f64x4 = _mm256_add_pd(
248
+ _mm256_sub_pd(sum_f64x4, _mm256_sub_pd(tentative_sum_f64x4, virtual_addend_f64x4)),
249
+ _mm256_sub_pd(product_f64x4, virtual_addend_f64x4));
250
+ // Update: sum = t, compensation += q + r
251
+ sum_f64x4 = tentative_sum_f64x4;
252
+ compensation_f64x4 = _mm256_add_pd(compensation_f64x4, _mm256_add_pd(sum_error_f64x4, product_error_f64x4));
253
+
254
+ if (count_scalars) goto nk_dot_f64_haswell_cycle;
255
+ // Compensated horizontal reduction preserving Dot2 error tracking
256
+ *result = nk_dot_stable_sum_f64x4_haswell_(sum_f64x4, compensation_f64x4);
257
+ }
258
+
259
+ NK_PUBLIC void nk_dot_f64c_haswell(nk_f64c_t const *a_pairs, nk_f64c_t const *b_pairs, nk_size_t count_pairs,
260
+ nk_f64c_t *result) {
261
+ // Dot2 algorithm (Ogita-Rump-Oishi 2005) for compensated complex dot product
262
+ __m256d sum_real_f64x4 = _mm256_setzero_pd();
263
+ __m256d sum_imag_f64x4 = _mm256_setzero_pd();
264
+ __m256d compensation_real_f64x4 = _mm256_setzero_pd();
265
+ __m256d compensation_imag_f64x4 = _mm256_setzero_pd();
266
+ __m256i sign_flip_i64x4 = _mm256_set_epi64x(0x8000000000000000, 0, 0x8000000000000000, 0);
267
+ __m256d a_f64x4, b_f64x4;
268
+
269
+ nk_dot_f64c_haswell_cycle:
270
+ if (count_pairs < 2) {
271
+ nk_b256_vec_t a_tail, b_tail;
272
+ nk_partial_load_b64x4_haswell_(a_pairs, &a_tail, count_pairs * 2);
273
+ nk_partial_load_b64x4_haswell_(b_pairs, &b_tail, count_pairs * 2);
274
+ a_f64x4 = a_tail.ymm_pd;
275
+ b_f64x4 = b_tail.ymm_pd;
276
+ count_pairs = 0;
277
+ }
278
+ else {
279
+ a_f64x4 = _mm256_loadu_pd((nk_f64_t const *)a_pairs);
280
+ b_f64x4 = _mm256_loadu_pd((nk_f64_t const *)b_pairs);
281
+ a_pairs += 2, b_pairs += 2, count_pairs -= 2;
282
+ }
283
+
284
+ __m256d b_swapped_f64x4 = _mm256_permute_pd(b_f64x4, 0x5); // 0b0101: swap adjacent pairs
285
+
286
+ // TwoProd for real part: a * b
287
+ __m256d product_real_f64x4 = _mm256_mul_pd(a_f64x4, b_f64x4);
288
+ __m256d product_real_error_f64x4 = _mm256_fmsub_pd(a_f64x4, b_f64x4, product_real_f64x4);
289
+ // TwoSum for real part
290
+ __m256d tentative_sum_real_f64x4 = _mm256_add_pd(sum_real_f64x4, product_real_f64x4);
291
+ __m256d virtual_addend_real_f64x4 = _mm256_sub_pd(tentative_sum_real_f64x4, sum_real_f64x4);
292
+ __m256d sum_real_error_f64x4 = _mm256_add_pd(
293
+ _mm256_sub_pd(sum_real_f64x4, _mm256_sub_pd(tentative_sum_real_f64x4, virtual_addend_real_f64x4)),
294
+ _mm256_sub_pd(product_real_f64x4, virtual_addend_real_f64x4));
295
+ sum_real_f64x4 = tentative_sum_real_f64x4;
296
+ compensation_real_f64x4 = _mm256_add_pd(compensation_real_f64x4,
297
+ _mm256_add_pd(sum_real_error_f64x4, product_real_error_f64x4));
298
+
299
+ // TwoProd for imag part: a * b_swapped
300
+ __m256d product_imag_f64x4 = _mm256_mul_pd(a_f64x4, b_swapped_f64x4);
301
+ __m256d product_imag_error_f64x4 = _mm256_fmsub_pd(a_f64x4, b_swapped_f64x4, product_imag_f64x4);
302
+ // TwoSum for imag part
303
+ __m256d tentative_sum_imag_f64x4 = _mm256_add_pd(sum_imag_f64x4, product_imag_f64x4);
304
+ __m256d virtual_addend_imag_f64x4 = _mm256_sub_pd(tentative_sum_imag_f64x4, sum_imag_f64x4);
305
+ __m256d sum_imag_error_f64x4 = _mm256_add_pd(
306
+ _mm256_sub_pd(sum_imag_f64x4, _mm256_sub_pd(tentative_sum_imag_f64x4, virtual_addend_imag_f64x4)),
307
+ _mm256_sub_pd(product_imag_f64x4, virtual_addend_imag_f64x4));
308
+ sum_imag_f64x4 = tentative_sum_imag_f64x4;
309
+ compensation_imag_f64x4 = _mm256_add_pd(compensation_imag_f64x4,
310
+ _mm256_add_pd(sum_imag_error_f64x4, product_imag_error_f64x4));
311
+
312
+ if (count_pairs) goto nk_dot_f64c_haswell_cycle;
313
+ // Flip sign in every second f64 for real part (to get a_r*b_r - a_i*b_i)
314
+ sum_real_f64x4 = _mm256_castsi256_pd(_mm256_xor_si256(_mm256_castpd_si256(sum_real_f64x4), sign_flip_i64x4));
315
+ compensation_real_f64x4 = _mm256_castsi256_pd(
316
+ _mm256_xor_si256(_mm256_castpd_si256(compensation_real_f64x4), sign_flip_i64x4));
317
+ // Compensated horizontal reduction preserving Dot2 error tracking
318
+ result->real = nk_dot_stable_sum_f64x4_haswell_(sum_real_f64x4, compensation_real_f64x4);
319
+ result->imag = nk_dot_stable_sum_f64x4_haswell_(sum_imag_f64x4, compensation_imag_f64x4);
320
+ }
321
+
322
+ NK_PUBLIC void nk_vdot_f64c_haswell(nk_f64c_t const *a_pairs, nk_f64c_t const *b_pairs, nk_size_t count_pairs,
323
+ nk_f64c_t *result) {
324
+ // Dot2 algorithm (Ogita-Rump-Oishi 2005) for compensated conjugate dot product
325
+ __m256d sum_real_f64x4 = _mm256_setzero_pd();
326
+ __m256d sum_imag_f64x4 = _mm256_setzero_pd();
327
+ __m256d compensation_real_f64x4 = _mm256_setzero_pd();
328
+ __m256d compensation_imag_f64x4 = _mm256_setzero_pd();
329
+ __m256i sign_flip_i64x4 = _mm256_set_epi64x(0x8000000000000000, 0, 0x8000000000000000, 0);
330
+ __m256d a_f64x4, b_f64x4;
331
+
332
+ nk_vdot_f64c_haswell_cycle:
333
+ if (count_pairs < 2) {
334
+ nk_b256_vec_t a_tail, b_tail;
335
+ nk_partial_load_b64x4_haswell_(a_pairs, &a_tail, count_pairs * 2);
336
+ nk_partial_load_b64x4_haswell_(b_pairs, &b_tail, count_pairs * 2);
337
+ a_f64x4 = a_tail.ymm_pd;
338
+ b_f64x4 = b_tail.ymm_pd;
339
+ count_pairs = 0;
340
+ }
341
+ else {
342
+ a_f64x4 = _mm256_loadu_pd((nk_f64_t const *)a_pairs);
343
+ b_f64x4 = _mm256_loadu_pd((nk_f64_t const *)b_pairs);
344
+ a_pairs += 2, b_pairs += 2, count_pairs -= 2;
345
+ }
346
+
347
+ __m256d b_swapped_f64x4 = _mm256_permute_pd(b_f64x4, 0x5); // 0b0101: swap adjacent pairs
348
+
349
+ // TwoProd for real part: a * b
350
+ __m256d product_real_f64x4 = _mm256_mul_pd(a_f64x4, b_f64x4);
351
+ __m256d product_real_error_f64x4 = _mm256_fmsub_pd(a_f64x4, b_f64x4, product_real_f64x4);
352
+ // TwoSum for real part
353
+ __m256d tentative_sum_real_f64x4 = _mm256_add_pd(sum_real_f64x4, product_real_f64x4);
354
+ __m256d virtual_addend_real_f64x4 = _mm256_sub_pd(tentative_sum_real_f64x4, sum_real_f64x4);
355
+ __m256d sum_real_error_f64x4 = _mm256_add_pd(
356
+ _mm256_sub_pd(sum_real_f64x4, _mm256_sub_pd(tentative_sum_real_f64x4, virtual_addend_real_f64x4)),
357
+ _mm256_sub_pd(product_real_f64x4, virtual_addend_real_f64x4));
358
+ sum_real_f64x4 = tentative_sum_real_f64x4;
359
+ compensation_real_f64x4 = _mm256_add_pd(compensation_real_f64x4,
360
+ _mm256_add_pd(sum_real_error_f64x4, product_real_error_f64x4));
361
+
362
+ // TwoProd for imag part: a * b_swapped
363
+ __m256d product_imag_f64x4 = _mm256_mul_pd(a_f64x4, b_swapped_f64x4);
364
+ __m256d product_imag_error_f64x4 = _mm256_fmsub_pd(a_f64x4, b_swapped_f64x4, product_imag_f64x4);
365
+ // TwoSum for imag part
366
+ __m256d tentative_sum_imag_f64x4 = _mm256_add_pd(sum_imag_f64x4, product_imag_f64x4);
367
+ __m256d virtual_addend_imag_f64x4 = _mm256_sub_pd(tentative_sum_imag_f64x4, sum_imag_f64x4);
368
+ __m256d sum_imag_error_f64x4 = _mm256_add_pd(
369
+ _mm256_sub_pd(sum_imag_f64x4, _mm256_sub_pd(tentative_sum_imag_f64x4, virtual_addend_imag_f64x4)),
370
+ _mm256_sub_pd(product_imag_f64x4, virtual_addend_imag_f64x4));
371
+ sum_imag_f64x4 = tentative_sum_imag_f64x4;
372
+ compensation_imag_f64x4 = _mm256_add_pd(compensation_imag_f64x4,
373
+ _mm256_add_pd(sum_imag_error_f64x4, product_imag_error_f64x4));
374
+
375
+ if (count_pairs) goto nk_vdot_f64c_haswell_cycle;
376
+ // Flip sign in every second f64 for imag part (to get a_r*b_i - a_i*b_r)
377
+ sum_imag_f64x4 = _mm256_castsi256_pd(_mm256_xor_si256(_mm256_castpd_si256(sum_imag_f64x4), sign_flip_i64x4));
378
+ compensation_imag_f64x4 = _mm256_castsi256_pd(
379
+ _mm256_xor_si256(_mm256_castpd_si256(compensation_imag_f64x4), sign_flip_i64x4));
380
+ // Compensated horizontal reduction preserving Dot2 error tracking
381
+ result->real = nk_dot_stable_sum_f64x4_haswell_(sum_real_f64x4, compensation_real_f64x4);
382
+ result->imag = nk_dot_stable_sum_f64x4_haswell_(sum_imag_f64x4, compensation_imag_f64x4);
383
+ }
384
+
385
/**
 * @brief Running state for 256-bit dot accumulation over f64 scalars on Haswell.
 *
 * Uses the Dot2 algorithm (Ogita-Rump-Oishi 2005) for compensated dot product: the rounded running
 * sums and the accumulated rounding errors live in separate registers and are combined per state by
 * `nk_dot_stable_sum_f64x4_haswell_` inside `nk_dot_f64x4_finalize_haswell`.
 */
typedef struct nk_dot_f64x4_state_haswell_t {
    __m256d sum_f64x4;          // Four rounded partial sums, one per f64 lane
    __m256d compensation_f64x4; // Error accumulator for Dot2 (TwoProd + TwoSum residuals per lane)
} nk_dot_f64x4_state_haswell_t;
394
+
395
+ NK_INTERNAL void nk_dot_f64x4_init_haswell(nk_dot_f64x4_state_haswell_t *state) {
396
+ state->sum_f64x4 = _mm256_setzero_pd();
397
+ state->compensation_f64x4 = _mm256_setzero_pd();
398
+ }
399
+
400
+ NK_INTERNAL void nk_dot_f64x4_update_haswell(nk_dot_f64x4_state_haswell_t *state, nk_b256_vec_t a, nk_b256_vec_t b,
401
+ nk_size_t depth_offset, nk_size_t active_dimensions) {
402
+ nk_unused_(depth_offset);
403
+ nk_unused_(active_dimensions);
404
+ __m256d sum_f64x4 = state->sum_f64x4;
405
+ __m256d compensation_f64x4 = state->compensation_f64x4;
406
+ __m256d a_f64x4 = a.ymm_pd;
407
+ __m256d b_f64x4 = b.ymm_pd;
408
+
409
+ // TwoProd: h = a * b, r = fma(a, b, -h) captures the rounding error
410
+ __m256d product_f64x4 = _mm256_mul_pd(a_f64x4, b_f64x4);
411
+ __m256d product_error_f64x4 = _mm256_fmsub_pd(a_f64x4, b_f64x4, product_f64x4);
412
+
413
+ // TwoSum: (t, q) = TwoSum(sum, h) where t = sum + h rounded, q = error
414
+ __m256d tentative_sum_f64x4 = _mm256_add_pd(sum_f64x4, product_f64x4);
415
+ __m256d virtual_addend_f64x4 = _mm256_sub_pd(tentative_sum_f64x4, sum_f64x4);
416
+ __m256d sum_error_f64x4 = _mm256_add_pd(
417
+ _mm256_sub_pd(sum_f64x4, _mm256_sub_pd(tentative_sum_f64x4, virtual_addend_f64x4)),
418
+ _mm256_sub_pd(product_f64x4, virtual_addend_f64x4));
419
+
420
+ // Update: sum = t, compensation += q + r
421
+ state->sum_f64x4 = tentative_sum_f64x4;
422
+ state->compensation_f64x4 = _mm256_add_pd(compensation_f64x4, _mm256_add_pd(sum_error_f64x4, product_error_f64x4));
423
+ }
424
+
425
+ NK_INTERNAL void nk_dot_f64x4_finalize_haswell( //
426
+ nk_dot_f64x4_state_haswell_t const *state_a, nk_dot_f64x4_state_haswell_t const *state_b, //
427
+ nk_dot_f64x4_state_haswell_t const *state_c, nk_dot_f64x4_state_haswell_t const *state_d, //
428
+ nk_size_t total_dimensions, nk_b256_vec_t *result) {
429
+ nk_unused_(total_dimensions);
430
+ // Compensated horizontal reduction preserving Dot2 error tracking per state
431
+ result->f64s[0] = nk_dot_stable_sum_f64x4_haswell_(state_a->sum_f64x4, state_a->compensation_f64x4);
432
+ result->f64s[1] = nk_dot_stable_sum_f64x4_haswell_(state_b->sum_f64x4, state_b->compensation_f64x4);
433
+ result->f64s[2] = nk_dot_stable_sum_f64x4_haswell_(state_c->sum_f64x4, state_c->compensation_f64x4);
434
+ result->f64s[3] = nk_dot_stable_sum_f64x4_haswell_(state_d->sum_f64x4, state_d->compensation_f64x4);
435
+ }
436
+
437
/**
 * @brief Running state for dot products of f32 inputs on Haswell, accumulated in f64.
 *
 * Each update widens four f32 lanes to f64 before the FMA, so one 256-bit register holds the state.
 */
typedef struct nk_dot_f32x4_state_haswell_t {
    __m256d sum_f64x4; // Four f64 partial sums, one per widened input lane
} nk_dot_f32x4_state_haswell_t;
440
+
441
+ NK_INTERNAL void nk_dot_f32x4_init_haswell(nk_dot_f32x4_state_haswell_t *state) {
442
+ state->sum_f64x4 = _mm256_setzero_pd();
443
+ }
444
+
445
+ NK_INTERNAL void nk_dot_f32x4_update_haswell(nk_dot_f32x4_state_haswell_t *state, nk_b128_vec_t a, nk_b128_vec_t b,
446
+ nk_size_t depth_offset, nk_size_t active_dimensions) {
447
+ nk_unused_(depth_offset);
448
+ nk_unused_(active_dimensions);
449
+ // Upcast 4 f32s to f64s for high-precision accumulation
450
+ __m256d a_f64x4 = _mm256_cvtps_pd(_mm_castsi128_ps(a.xmm));
451
+ __m256d b_f64x4 = _mm256_cvtps_pd(_mm_castsi128_ps(b.xmm));
452
+ // FMA accumulation in f64
453
+ state->sum_f64x4 = _mm256_fmadd_pd(a_f64x4, b_f64x4, state->sum_f64x4);
454
+ }
455
+
456
+ NK_INTERNAL void nk_dot_f32x4_finalize_haswell( //
457
+ nk_dot_f32x4_state_haswell_t const *state_a, nk_dot_f32x4_state_haswell_t const *state_b, //
458
+ nk_dot_f32x4_state_haswell_t const *state_c, nk_dot_f32x4_state_haswell_t const *state_d, //
459
+ nk_size_t total_dimensions, nk_b256_vec_t *result) {
460
+ nk_unused_(total_dimensions);
461
+ // Horizontal reduction: 4 f64s → 1 f64 for each state
462
+ __m256d sum_a_f64x4 = state_a->sum_f64x4;
463
+ __m256d sum_b_f64x4 = state_b->sum_f64x4;
464
+ __m256d sum_c_f64x4 = state_c->sum_f64x4;
465
+ __m256d sum_d_f64x4 = state_d->sum_f64x4;
466
+
467
+ // 4 → 2: add high 128-bit lane to low lane
468
+ __m128d sum_a_f64x2 = _mm_add_pd(_mm256_castpd256_pd128(sum_a_f64x4), _mm256_extractf128_pd(sum_a_f64x4, 1));
469
+ __m128d sum_b_f64x2 = _mm_add_pd(_mm256_castpd256_pd128(sum_b_f64x4), _mm256_extractf128_pd(sum_b_f64x4, 1));
470
+ __m128d sum_c_f64x2 = _mm_add_pd(_mm256_castpd256_pd128(sum_c_f64x4), _mm256_extractf128_pd(sum_c_f64x4, 1));
471
+ __m128d sum_d_f64x2 = _mm_add_pd(_mm256_castpd256_pd128(sum_d_f64x4), _mm256_extractf128_pd(sum_d_f64x4, 1));
472
+
473
+ // 2 → 1: horizontal add
474
+ __m128d sum_ab_f64x2 = _mm_hadd_pd(sum_a_f64x2, sum_b_f64x2); // [sum_a, sum_b]
475
+ __m128d sum_cd_f64x2 = _mm_hadd_pd(sum_c_f64x2, sum_d_f64x2); // [sum_c, sum_d]
476
+
477
+ // Combine into __m256d and convert to f32
478
+ __m256d sum_abcd_f64x4 = _mm256_set_m128d(sum_cd_f64x2, sum_ab_f64x2);
479
+ result->ymm_pd = sum_abcd_f64x4;
480
+ }
481
+
482
+ #pragma endregion - Traditional Floats
483
+
484
+ #pragma region - Smaller Floats
485
+
486
+ NK_PUBLIC void nk_dot_bf16_haswell(nk_bf16_t const *a_scalars, nk_bf16_t const *b_scalars, nk_size_t count_scalars,
487
+ nk_f32_t *result) {
488
+ __m128i a_bf16x8, b_bf16x8;
489
+ __m256 sum_f32x8 = _mm256_setzero_ps();
490
+ nk_dot_bf16_haswell_cycle:
491
+ if (count_scalars < 8) {
492
+ nk_b256_vec_t a_vec, b_vec;
493
+ nk_partial_load_b16x16_serial_(a_scalars, &a_vec, count_scalars);
494
+ nk_partial_load_b16x16_serial_(b_scalars, &b_vec, count_scalars);
495
+ a_bf16x8 = a_vec.xmms[0];
496
+ b_bf16x8 = b_vec.xmms[0];
497
+ count_scalars = 0;
498
+ }
499
+ else {
500
+ a_bf16x8 = _mm_loadu_si128((__m128i const *)a_scalars);
501
+ b_bf16x8 = _mm_loadu_si128((__m128i const *)b_scalars);
502
+ a_scalars += 8, b_scalars += 8, count_scalars -= 8;
503
+ }
504
+ sum_f32x8 = _mm256_fmadd_ps(nk_bf16x8_to_f32x8_haswell_(a_bf16x8), nk_bf16x8_to_f32x8_haswell_(b_bf16x8),
505
+ sum_f32x8);
506
+ if (count_scalars) goto nk_dot_bf16_haswell_cycle;
507
+ *result = (nk_f32_t)nk_reduce_add_f32x8_haswell_(sum_f32x8);
508
+ }
509
+
510
+ NK_PUBLIC void nk_dot_f16_haswell(nk_f16_t const *a_scalars, nk_f16_t const *b_scalars, nk_size_t count_scalars,
511
+ nk_f32_t *result) {
512
+ __m256 a_f32x8, b_f32x8;
513
+ __m256 sum_f32x8 = _mm256_setzero_ps();
514
+ nk_dot_f16_haswell_cycle:
515
+ if (count_scalars < 8) {
516
+ nk_b256_vec_t a_vec, b_vec;
517
+ nk_partial_load_f16x8_to_f32x8_haswell_(a_scalars, &a_vec, count_scalars);
518
+ nk_partial_load_f16x8_to_f32x8_haswell_(b_scalars, &b_vec, count_scalars);
519
+ a_f32x8 = a_vec.ymm_ps;
520
+ b_f32x8 = b_vec.ymm_ps;
521
+ count_scalars = 0;
522
+ }
523
+ else {
524
+ a_f32x8 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i const *)a_scalars));
525
+ b_f32x8 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i const *)b_scalars));
526
+ count_scalars -= 8, a_scalars += 8, b_scalars += 8;
527
+ }
528
+ sum_f32x8 = _mm256_fmadd_ps(a_f32x8, b_f32x8, sum_f32x8);
529
+ if (count_scalars) goto nk_dot_f16_haswell_cycle;
530
+ *result = (nk_f32_t)nk_reduce_add_f32x8_haswell_(sum_f32x8);
531
+ }
532
+
533
+ NK_PUBLIC void nk_dot_bf16c_haswell(nk_bf16c_t const *a_pairs, nk_bf16c_t const *b_pairs, nk_size_t count_pairs,
534
+ nk_f32c_t *result) {
535
+ // Convert BF16 to F32, then use F32 complex dot product with sign-flipping optimization.
536
+ // Uses same XOR trick as f32c to double throughput by deferring sign flips until after loop.
537
+ __m128i a_bf16x8, b_bf16x8;
538
+ __m256 sum_real_f32x8 = _mm256_setzero_ps();
539
+ __m256 sum_imag_f32x8 = _mm256_setzero_ps();
540
+ __m256i const sign_flip_i64x4 = _mm256_set1_epi64x(0x8000000000000000);
541
+ __m256i const swap_adjacent_i8x32 = _mm256_set_epi8( //
542
+ 11, 10, 9, 8, 15, 14, 13, 12, 3, 2, 1, 0, 7, 6, 5, 4, 11, 10, 9, 8, 15, 14, 13, 12, 3, 2, 1, 0, 7, 6, 5, 4);
543
+
544
+ nk_dot_bf16c_haswell_cycle:
545
+ if (count_pairs < 4) {
546
+ // Partial load using serial helper
547
+ nk_b256_vec_t a_vec, b_vec;
548
+ nk_partial_load_b16x16_serial_(a_pairs, &a_vec, count_pairs * 2);
549
+ nk_partial_load_b16x16_serial_(b_pairs, &b_vec, count_pairs * 2);
550
+ a_bf16x8 = a_vec.xmms[0];
551
+ b_bf16x8 = b_vec.xmms[0];
552
+ count_pairs = 0;
553
+ }
554
+ else {
555
+ a_bf16x8 = _mm_loadu_si128((__m128i const *)a_pairs);
556
+ b_bf16x8 = _mm_loadu_si128((__m128i const *)b_pairs);
557
+ a_pairs += 4, b_pairs += 4, count_pairs -= 4;
558
+ }
559
+
560
+ // Convert BF16 to F32
561
+ __m256 a_f32x8 = nk_bf16x8_to_f32x8_haswell_(a_bf16x8);
562
+ __m256 b_f32x8 = nk_bf16x8_to_f32x8_haswell_(b_bf16x8);
563
+
564
+ // Complex multiply-accumulate: swap b for imaginary part
565
+ __m256 b_swapped_f32x8 = _mm256_castsi256_ps(
566
+ _mm256_shuffle_epi8(_mm256_castps_si256(b_f32x8), swap_adjacent_i8x32));
567
+ sum_real_f32x8 = _mm256_fmadd_ps(a_f32x8, b_f32x8, sum_real_f32x8);
568
+ sum_imag_f32x8 = _mm256_fmadd_ps(a_f32x8, b_swapped_f32x8, sum_imag_f32x8);
569
+
570
+ if (count_pairs) goto nk_dot_bf16c_haswell_cycle;
571
+
572
+ // Flip the sign bit in every second scalar (real part: a_r*b_r - a_i*b_i)
573
+ sum_real_f32x8 = _mm256_castsi256_ps(_mm256_xor_si256(_mm256_castps_si256(sum_real_f32x8), sign_flip_i64x4));
574
+
575
+ result->real = nk_reduce_add_f32x8_haswell_(sum_real_f32x8);
576
+ result->imag = nk_reduce_add_f32x8_haswell_(sum_imag_f32x8);
577
+ }
578
+
579
+ NK_PUBLIC void nk_vdot_bf16c_haswell(nk_bf16c_t const *a_pairs, nk_bf16c_t const *b_pairs, nk_size_t count_pairs,
580
+ nk_f32c_t *result) {
581
+ // Conjugate complex dot product: conj(a) * b
582
+ __m128i a_bf16x8, b_bf16x8;
583
+ __m256 sum_real_f32x8 = _mm256_setzero_ps();
584
+ __m256 sum_imag_f32x8 = _mm256_setzero_ps();
585
+ __m256i const sign_flip_i64x4 = _mm256_set1_epi64x(0x8000000000000000);
586
+ __m256i const swap_adjacent_i8x32 = _mm256_set_epi8( //
587
+ 11, 10, 9, 8, 15, 14, 13, 12, 3, 2, 1, 0, 7, 6, 5, 4, 11, 10, 9, 8, 15, 14, 13, 12, 3, 2, 1, 0, 7, 6, 5, 4);
588
+
589
+ nk_vdot_bf16c_haswell_cycle:
590
+ if (count_pairs < 4) {
591
+ // Partial load using serial helper
592
+ nk_b256_vec_t a_vec, b_vec;
593
+ nk_partial_load_b16x16_serial_(a_pairs, &a_vec, count_pairs * 2);
594
+ nk_partial_load_b16x16_serial_(b_pairs, &b_vec, count_pairs * 2);
595
+ a_bf16x8 = a_vec.xmms[0];
596
+ b_bf16x8 = b_vec.xmms[0];
597
+ count_pairs = 0;
598
+ }
599
+ else {
600
+ a_bf16x8 = _mm_loadu_si128((__m128i const *)a_pairs);
601
+ b_bf16x8 = _mm_loadu_si128((__m128i const *)b_pairs);
602
+ a_pairs += 4, b_pairs += 4, count_pairs -= 4;
603
+ }
604
+
605
+ // Convert BF16 to F32
606
+ __m256 a_f32x8 = nk_bf16x8_to_f32x8_haswell_(a_bf16x8);
607
+ __m256 b_f32x8 = nk_bf16x8_to_f32x8_haswell_(b_bf16x8);
608
+
609
+ // Conjugate complex multiply-accumulate
610
+ sum_real_f32x8 = _mm256_fmadd_ps(a_f32x8, b_f32x8, sum_real_f32x8);
611
+ __m256 b_swapped_f32x8 = _mm256_castsi256_ps(
612
+ _mm256_shuffle_epi8(_mm256_castps_si256(b_f32x8), swap_adjacent_i8x32));
613
+ sum_imag_f32x8 = _mm256_fmadd_ps(a_f32x8, b_swapped_f32x8, sum_imag_f32x8);
614
+
615
+ if (count_pairs) goto nk_vdot_bf16c_haswell_cycle;
616
+
617
+ // Flip the sign bit in every second scalar (imag part: a_r*b_i - a_i*b_r)
618
+ sum_imag_f32x8 = _mm256_castsi256_ps(_mm256_xor_si256(_mm256_castps_si256(sum_imag_f32x8), sign_flip_i64x4));
619
+
620
+ result->real = nk_reduce_add_f32x8_haswell_(sum_real_f32x8);
621
+ result->imag = nk_reduce_add_f32x8_haswell_(sum_imag_f32x8);
622
+ }
623
+
624
+ NK_PUBLIC void nk_dot_f16c_haswell(nk_f16c_t const *a_pairs, nk_f16c_t const *b_pairs, nk_size_t count_pairs,
625
+ nk_f32c_t *result) {
626
+ __m256 sum_real_f32x8 = _mm256_setzero_ps();
627
+ __m256 sum_imag_f32x8 = _mm256_setzero_ps();
628
+ __m256i sign_flip_i64x4 = _mm256_set1_epi64x(0x8000000000000000);
629
+ __m256i swap_adjacent_i8x32 = _mm256_set_epi8( //
630
+ 11, 10, 9, 8, 15, 14, 13, 12, 3, 2, 1, 0, 7, 6, 5, 4, 11, 10, 9, 8, 15, 14, 13, 12, 3, 2, 1, 0, 7, 6, 5, 4);
631
+ while (count_pairs >= 4) {
632
+ __m256 a_f32x8 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i const *)a_pairs));
633
+ __m256 b_f32x8 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i const *)b_pairs));
634
+ __m256 b_swapped_f32x8 = _mm256_castsi256_ps(
635
+ _mm256_shuffle_epi8(_mm256_castps_si256(b_f32x8), swap_adjacent_i8x32));
636
+ sum_real_f32x8 = _mm256_fmadd_ps(a_f32x8, b_f32x8, sum_real_f32x8);
637
+ sum_imag_f32x8 = _mm256_fmadd_ps(a_f32x8, b_swapped_f32x8, sum_imag_f32x8);
638
+ count_pairs -= 4, a_pairs += 4, b_pairs += 4;
639
+ }
640
+ // Flip the sign bit in every second scalar before accumulation:
641
+ sum_real_f32x8 = _mm256_castsi256_ps(_mm256_xor_si256(_mm256_castps_si256(sum_real_f32x8), sign_flip_i64x4));
642
+ nk_f32c_t tail_result;
643
+ nk_dot_f16c_serial(a_pairs, b_pairs, count_pairs, &tail_result);
644
+ result->real = tail_result.real + (nk_f32_t)nk_reduce_add_f32x8_haswell_(sum_real_f32x8);
645
+ result->imag = tail_result.imag + (nk_f32_t)nk_reduce_add_f32x8_haswell_(sum_imag_f32x8);
646
+ }
647
+
648
+ NK_PUBLIC void nk_vdot_f16c_haswell(nk_f16c_t const *a_pairs, nk_f16c_t const *b_pairs, nk_size_t count_pairs,
649
+ nk_f32c_t *result) {
650
+ __m256 sum_real_f32x8 = _mm256_setzero_ps();
651
+ __m256 sum_imag_f32x8 = _mm256_setzero_ps();
652
+ __m256i sign_flip_i64x4 = _mm256_set1_epi64x(0x8000000000000000);
653
+ __m256i swap_adjacent_i8x32 = _mm256_set_epi8( //
654
+ 11, 10, 9, 8, 15, 14, 13, 12, 3, 2, 1, 0, 7, 6, 5, 4, 11, 10, 9, 8, 15, 14, 13, 12, 3, 2, 1, 0, 7, 6, 5, 4);
655
+ while (count_pairs >= 4) {
656
+ __m256 a_f32x8 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i const *)a_pairs));
657
+ __m256 b_f32x8 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i const *)b_pairs));
658
+ sum_real_f32x8 = _mm256_fmadd_ps(a_f32x8, b_f32x8, sum_real_f32x8);
659
+ b_f32x8 = _mm256_castsi256_ps(_mm256_shuffle_epi8(_mm256_castps_si256(b_f32x8), swap_adjacent_i8x32));
660
+ sum_imag_f32x8 = _mm256_fmadd_ps(a_f32x8, b_f32x8, sum_imag_f32x8);
661
+ count_pairs -= 4, a_pairs += 4, b_pairs += 4;
662
+ }
663
+ // Flip the sign bit in every second scalar before accumulation:
664
+ sum_imag_f32x8 = _mm256_castsi256_ps(_mm256_xor_si256(_mm256_castps_si256(sum_imag_f32x8), sign_flip_i64x4));
665
+ nk_f32c_t tail_result;
666
+ nk_vdot_f16c_serial(a_pairs, b_pairs, count_pairs, &tail_result);
667
+ result->real = tail_result.real + (nk_f32_t)nk_reduce_add_f32x8_haswell_(sum_real_f32x8);
668
+ result->imag = tail_result.imag + (nk_f32_t)nk_reduce_add_f32x8_haswell_(sum_imag_f32x8);
669
+ }
670
+
671
+ NK_PUBLIC void nk_dot_e4m3_haswell(nk_e4m3_t const *a_scalars, nk_e4m3_t const *b_scalars, nk_size_t count_scalars,
672
+ nk_f32_t *result) {
673
+ __m256 a_f32x8, b_f32x8;
674
+ __m256 sum_f32x8 = _mm256_setzero_ps();
675
+ nk_dot_e4m3_haswell_cycle:
676
+ if (count_scalars < 8) {
677
+ nk_b256_vec_t a_vec, b_vec;
678
+ nk_partial_load_e4m3x8_to_f32x8_haswell_(a_scalars, &a_vec, count_scalars);
679
+ nk_partial_load_e4m3x8_to_f32x8_haswell_(b_scalars, &b_vec, count_scalars);
680
+ a_f32x8 = a_vec.ymm_ps;
681
+ b_f32x8 = b_vec.ymm_ps;
682
+ count_scalars = 0;
683
+ }
684
+ else {
685
+ a_f32x8 = nk_e4m3x8_to_f32x8_haswell_(_mm_loadl_epi64((__m128i const *)a_scalars));
686
+ b_f32x8 = nk_e4m3x8_to_f32x8_haswell_(_mm_loadl_epi64((__m128i const *)b_scalars));
687
+ a_scalars += 8, b_scalars += 8, count_scalars -= 8;
688
+ }
689
+ sum_f32x8 = _mm256_fmadd_ps(a_f32x8, b_f32x8, sum_f32x8);
690
+ if (count_scalars) goto nk_dot_e4m3_haswell_cycle;
691
+ *result = (nk_f32_t)nk_reduce_add_f32x8_haswell_(sum_f32x8);
692
+ }
693
+
694
+ NK_PUBLIC void nk_dot_e5m2_haswell(nk_e5m2_t const *a_scalars, nk_e5m2_t const *b_scalars, nk_size_t count_scalars,
695
+ nk_f32_t *result) {
696
+ __m256 a_f32x8, b_f32x8;
697
+ __m256 sum_f32x8 = _mm256_setzero_ps();
698
+ nk_dot_e5m2_haswell_cycle:
699
+ if (count_scalars < 8) {
700
+ nk_b256_vec_t a_vec, b_vec;
701
+ nk_partial_load_e5m2x8_to_f32x8_haswell_(a_scalars, &a_vec, count_scalars);
702
+ nk_partial_load_e5m2x8_to_f32x8_haswell_(b_scalars, &b_vec, count_scalars);
703
+ a_f32x8 = a_vec.ymm_ps;
704
+ b_f32x8 = b_vec.ymm_ps;
705
+ count_scalars = 0;
706
+ }
707
+ else {
708
+ a_f32x8 = nk_e5m2x8_to_f32x8_haswell_(_mm_loadl_epi64((__m128i const *)a_scalars));
709
+ b_f32x8 = nk_e5m2x8_to_f32x8_haswell_(_mm_loadl_epi64((__m128i const *)b_scalars));
710
+ a_scalars += 8, b_scalars += 8, count_scalars -= 8;
711
+ }
712
+ sum_f32x8 = _mm256_fmadd_ps(a_f32x8, b_f32x8, sum_f32x8);
713
+ if (count_scalars) goto nk_dot_e5m2_haswell_cycle;
714
+ *result = (nk_f32_t)nk_reduce_add_f32x8_haswell_(sum_f32x8);
715
+ }
716
+
717
+ NK_PUBLIC void nk_dot_e2m3_haswell(nk_e2m3_t const *a_scalars, nk_e2m3_t const *b_scalars, nk_size_t count_scalars,
718
+ nk_f32_t *result) {
719
+ // Integer dot product for e2m3 using dual-VPSHUFB (LUT) + VPMADDUBSW (unsigned×signed).
720
+ // Every e2m3 value × 16 is an exact integer in [-120, +120].
721
+ // Result = i32_dot / 256.0f (exact, no rounding error).
722
+ //
723
+ // 32-entry LUT split into two 16-entry halves for VPSHUFB (which indexes 0-15):
724
+ // lut_lower[0..15]: {0,2,4,6,8,10,12,14, 16,18,20,22,24,26,28,30}
725
+ // lut_upper[0..15]: {32,36,40,44,48,52,56,60, 64,72,80,88,96,104,112,120}
726
+ //
727
+ __m256i const lut_lower_u8x32 = _mm256_set_epi8(30, 28, 26, 24, 22, 20, 18, 16, 14, 12, 10, 8, 6, 4, 2, 0, 30, 28,
728
+ 26, 24, 22, 20, 18, 16, 14, 12, 10, 8, 6, 4, 2, 0);
729
+ __m256i const lut_upper_u8x32 = _mm256_set_epi8(120, 112, 104, 96, 88, 80, 72, 64, 60, 56, 52, 48, 44, 40, 36, 32,
730
+ 120, 112, 104, 96, 88, 80, 72, 64, 60, 56, 52, 48, 44, 40, 36, 32);
731
+ __m256i const nibble_mask_u8x32 = _mm256_set1_epi8(0x0F);
732
+ __m256i const magnitude_mask_u8x32 = _mm256_set1_epi8(0x1F);
733
+ __m256i const half_select_u8x32 = _mm256_set1_epi8(0x10);
734
+ __m256i const sign_mask_u8x32 = _mm256_set1_epi8(0x20);
735
+ __m256i const ones_i16x16 = _mm256_set1_epi16(1);
736
+ __m256i sum_i32x8 = _mm256_setzero_si256();
737
+ __m256i a_e2m3_u8x32, b_e2m3_u8x32;
738
+
739
+ nk_dot_e2m3_haswell_cycle:
740
+ if (count_scalars < 32) {
741
+ nk_b256_vec_t a_vec, b_vec;
742
+ nk_partial_load_b8x32_serial_(a_scalars, &a_vec, count_scalars);
743
+ nk_partial_load_b8x32_serial_(b_scalars, &b_vec, count_scalars);
744
+ a_e2m3_u8x32 = a_vec.ymm;
745
+ b_e2m3_u8x32 = b_vec.ymm;
746
+ count_scalars = 0;
747
+ }
748
+ else {
749
+ a_e2m3_u8x32 = _mm256_loadu_si256((__m256i const *)a_scalars);
750
+ b_e2m3_u8x32 = _mm256_loadu_si256((__m256i const *)b_scalars);
751
+ a_scalars += 32, b_scalars += 32, count_scalars -= 32;
752
+ }
753
+
754
+ // Extract 5-bit magnitude, then split into low 4 bits (VPSHUFB index) and bit 4 (hi/lo select)
755
+ __m256i a_magnitude_u8x32 = _mm256_and_si256(a_e2m3_u8x32, magnitude_mask_u8x32);
756
+ __m256i b_magnitude_u8x32 = _mm256_and_si256(b_e2m3_u8x32, magnitude_mask_u8x32);
757
+ __m256i a_shuffle_index_u8x32 = _mm256_and_si256(a_magnitude_u8x32, nibble_mask_u8x32);
758
+ __m256i b_shuffle_index_u8x32 = _mm256_and_si256(b_magnitude_u8x32, nibble_mask_u8x32);
759
+ __m256i a_upper_select_u8x32 = _mm256_cmpeq_epi8(_mm256_and_si256(a_magnitude_u8x32, half_select_u8x32),
760
+ half_select_u8x32);
761
+ __m256i b_upper_select_u8x32 = _mm256_cmpeq_epi8(_mm256_and_si256(b_magnitude_u8x32, half_select_u8x32),
762
+ half_select_u8x32);
763
+
764
+ // Dual VPSHUFB: lookup in both halves, blend based on bit 4
765
+ __m256i a_unsigned_u8x32 = _mm256_blendv_epi8(_mm256_shuffle_epi8(lut_lower_u8x32, a_shuffle_index_u8x32),
766
+ _mm256_shuffle_epi8(lut_upper_u8x32, a_shuffle_index_u8x32),
767
+ a_upper_select_u8x32);
768
+ __m256i b_unsigned_u8x32 = _mm256_blendv_epi8(_mm256_shuffle_epi8(lut_lower_u8x32, b_shuffle_index_u8x32),
769
+ _mm256_shuffle_epi8(lut_upper_u8x32, b_shuffle_index_u8x32),
770
+ b_upper_select_u8x32);
771
+
772
+ // Combined sign: (a ^ b) & 0x20, negate b where signs differ
773
+ __m256i sign_combined_u8x32 = _mm256_and_si256(_mm256_xor_si256(a_e2m3_u8x32, b_e2m3_u8x32), sign_mask_u8x32);
774
+ __m256i negate_mask_u8x32 = _mm256_cmpeq_epi8(sign_combined_u8x32, sign_mask_u8x32);
775
+ __m256i b_negated_u8x32 = _mm256_sub_epi8(_mm256_setzero_si256(), b_unsigned_u8x32);
776
+ __m256i b_signed_i8x32 = _mm256_blendv_epi8(b_unsigned_u8x32, b_negated_u8x32, negate_mask_u8x32);
777
+
778
+ // VPMADDUBSW: a_unsigned[unsigned] × b_signed[signed] → i16 pairs (max |120×120| = 14400 < 32767, safe)
779
+ __m256i products_i16x16 = _mm256_maddubs_epi16(a_unsigned_u8x32, b_signed_i8x32);
780
+ // VPMADDWD with ones: i16 pairs → i32
781
+ sum_i32x8 = _mm256_add_epi32(sum_i32x8, _mm256_madd_epi16(products_i16x16, ones_i16x16));
782
+
783
+ if (count_scalars) goto nk_dot_e2m3_haswell_cycle;
784
+ *result = (nk_f32_t)nk_reduce_add_i32x8_haswell_(sum_i32x8) / 256.0f;
785
+ }
786
+
787
+ NK_PUBLIC void nk_dot_e3m2_haswell(nk_e3m2_t const *a_scalars, nk_e3m2_t const *b_scalars, nk_size_t count_scalars,
788
+ nk_f32_t *result) {
789
+ // Integer dot product for e3m2 using dual-VPSHUFB (low-byte LUT) + VPMADDWD (i16×i16→i32).
790
+ // Every e3m2 value × 16 is an exact integer, but magnitudes reach 448, requiring i16.
791
+ // Result = i32_dot / 256.0f (exact, no rounding error).
792
+ //
793
+ // 32-entry magnitude LUT split into low bytes for dual VPSHUFB:
794
+ // lut_lower[0..15]: low bytes of {0,1,2,3,4,5,6,7,8,10,12,14,16,20,24,28}
795
+ // lut_upper[0..15]: low bytes of {32,40,48,56,64,80,96,112,128,160,192,224,256,320,384,448}
796
+ // High byte is 1 iff magnitude index >= 28 (values 256-448), else 0.
797
+ //
798
+ __m256i const lut_lo_lower_u8x32 = _mm256_set_epi8( //
799
+ 28, 24, 20, 16, 14, 12, 10, 8, 7, 6, 5, 4, 3, 2, 1, 0, //
800
+ 28, 24, 20, 16, 14, 12, 10, 8, 7, 6, 5, 4, 3, 2, 1, 0);
801
+ __m256i const lut_lo_upper_u8x32 = _mm256_set_epi8( //
802
+ (char)192, (char)128, 64, 0, (char)224, (char)192, (char)160, (char)128, 112, 96, 80, 64, 56, 48, 40, 32, //
803
+ (char)192, (char)128, 64, 0, (char)224, (char)192, (char)160, (char)128, 112, 96, 80, 64, 56, 48, 40, 32);
804
+ __m256i const nibble_mask_u8x32 = _mm256_set1_epi8(0x0F);
805
+ __m256i const magnitude_mask_u8x32 = _mm256_set1_epi8(0x1F);
806
+ __m256i const half_select_u8x32 = _mm256_set1_epi8(0x10);
807
+ __m256i const sign_mask_u8x32 = _mm256_set1_epi8(0x20);
808
+ __m256i const high_threshold_u8x32 = _mm256_set1_epi8(27);
809
+ __m256i const ones_u8x32 = _mm256_set1_epi8(1);
810
+ __m256i const ones_i16x16 = _mm256_set1_epi16(1);
811
+ __m256i sum_i32x8 = _mm256_setzero_si256();
812
+ __m256i a_e3m2_u8x32, b_e3m2_u8x32;
813
+
814
+ nk_dot_e3m2_haswell_cycle:
815
+ if (count_scalars < 32) {
816
+ nk_b256_vec_t a_vec, b_vec;
817
+ nk_partial_load_b8x32_serial_(a_scalars, &a_vec, count_scalars);
818
+ nk_partial_load_b8x32_serial_(b_scalars, &b_vec, count_scalars);
819
+ a_e3m2_u8x32 = a_vec.ymm;
820
+ b_e3m2_u8x32 = b_vec.ymm;
821
+ count_scalars = 0;
822
+ }
823
+ else {
824
+ a_e3m2_u8x32 = _mm256_loadu_si256((__m256i const *)a_scalars);
825
+ b_e3m2_u8x32 = _mm256_loadu_si256((__m256i const *)b_scalars);
826
+ a_scalars += 32, b_scalars += 32, count_scalars -= 32;
827
+ }
828
+
829
+ // Extract 5-bit magnitude, split into low 4 bits and bit 4
830
+ __m256i a_magnitude_u8x32 = _mm256_and_si256(a_e3m2_u8x32, magnitude_mask_u8x32);
831
+ __m256i b_magnitude_u8x32 = _mm256_and_si256(b_e3m2_u8x32, magnitude_mask_u8x32);
832
+ __m256i a_shuffle_index_u8x32 = _mm256_and_si256(a_magnitude_u8x32, nibble_mask_u8x32);
833
+ __m256i b_shuffle_index_u8x32 = _mm256_and_si256(b_magnitude_u8x32, nibble_mask_u8x32);
834
+ __m256i a_upper_select_u8x32 = _mm256_cmpeq_epi8(_mm256_and_si256(a_magnitude_u8x32, half_select_u8x32),
835
+ half_select_u8x32);
836
+ __m256i b_upper_select_u8x32 = _mm256_cmpeq_epi8(_mm256_and_si256(b_magnitude_u8x32, half_select_u8x32),
837
+ half_select_u8x32);
838
+
839
+ // Dual VPSHUFB: lookup low bytes in both halves, blend based on bit 4
840
+ __m256i a_lo_bytes_u8x32 = _mm256_blendv_epi8(_mm256_shuffle_epi8(lut_lo_lower_u8x32, a_shuffle_index_u8x32),
841
+ _mm256_shuffle_epi8(lut_lo_upper_u8x32, a_shuffle_index_u8x32),
842
+ a_upper_select_u8x32);
843
+ __m256i b_lo_bytes_u8x32 = _mm256_blendv_epi8(_mm256_shuffle_epi8(lut_lo_lower_u8x32, b_shuffle_index_u8x32),
844
+ _mm256_shuffle_epi8(lut_lo_upper_u8x32, b_shuffle_index_u8x32),
845
+ b_upper_select_u8x32);
846
+
847
+ // High byte: 1 iff magnitude >= 28 (signed compare safe: 27 < 128)
848
+ __m256i a_hi_bytes_u8x32 = _mm256_and_si256(_mm256_cmpgt_epi8(a_magnitude_u8x32, high_threshold_u8x32), ones_u8x32);
849
+ __m256i b_hi_bytes_u8x32 = _mm256_and_si256(_mm256_cmpgt_epi8(b_magnitude_u8x32, high_threshold_u8x32), ones_u8x32);
850
+
851
+ // Interleave low and high bytes into i16 (little-endian: low byte first)
852
+ __m256i a_lo_i16x16 = _mm256_unpacklo_epi8(a_lo_bytes_u8x32, a_hi_bytes_u8x32);
853
+ __m256i a_hi_i16x16 = _mm256_unpackhi_epi8(a_lo_bytes_u8x32, a_hi_bytes_u8x32);
854
+ __m256i b_lo_i16x16 = _mm256_unpacklo_epi8(b_lo_bytes_u8x32, b_hi_bytes_u8x32);
855
+ __m256i b_hi_i16x16 = _mm256_unpackhi_epi8(b_lo_bytes_u8x32, b_hi_bytes_u8x32);
856
+
857
+ // Combined sign: (a ^ b) & 0x20, widen to i16 via unpack, create +1/-1 sign vector
858
+ __m256i sign_combined_u8x32 = _mm256_and_si256(_mm256_xor_si256(a_e3m2_u8x32, b_e3m2_u8x32), sign_mask_u8x32);
859
+ __m256i negate_mask_u8x32 = _mm256_cmpeq_epi8(sign_combined_u8x32, sign_mask_u8x32);
860
+ __m256i negate_lo_i16x16 = _mm256_unpacklo_epi8(negate_mask_u8x32, negate_mask_u8x32);
861
+ __m256i negate_hi_i16x16 = _mm256_unpackhi_epi8(negate_mask_u8x32, negate_mask_u8x32);
862
+ __m256i sign_lo_i16x16 = _mm256_or_si256(negate_lo_i16x16, ones_i16x16);
863
+ __m256i sign_hi_i16x16 = _mm256_or_si256(negate_hi_i16x16, ones_i16x16);
864
+ __m256i b_signed_lo_i16x16 = _mm256_sign_epi16(b_lo_i16x16, sign_lo_i16x16);
865
+ __m256i b_signed_hi_i16x16 = _mm256_sign_epi16(b_hi_i16x16, sign_hi_i16x16);
866
+
867
+ // VPMADDWD: a_unsigned_i16 × b_signed_i16 → i32 accumulator
868
+ sum_i32x8 = _mm256_add_epi32(sum_i32x8, _mm256_madd_epi16(a_lo_i16x16, b_signed_lo_i16x16));
869
+ sum_i32x8 = _mm256_add_epi32(sum_i32x8, _mm256_madd_epi16(a_hi_i16x16, b_signed_hi_i16x16));
870
+
871
+ if (count_scalars) goto nk_dot_e3m2_haswell_cycle;
872
+ *result = (nk_f32_t)nk_reduce_add_i32x8_haswell_(sum_i32x8) / 256.0f;
873
+ }
874
+
875
/**
 * @brief Internal helper state for dot-products of low-precision types, where 32-bit accumulation is enough.
 * @sa nk_dot_f16x8_state_haswell_t, nk_dot_bf16x8_state_haswell_t
 * @sa nk_dot_e4m3x16_state_haswell_t, nk_dot_e5m2x16_state_haswell_t
 *
 * The update step reads its inputs as eight f32 lanes, so callers pass values already widened to f32.
 */
typedef struct nk_dot_through_f32_state_haswell_t_ {
    __m256 sum_f32x8; // Eight f32 partial sums, advanced with FMA on every update
} nk_dot_through_f32_state_haswell_t_;
883
+
884
+ /**
885
+ * @brief Initializes 32-bit accumulators for low-precision dot-products.
886
+ * @sa nk_dot_f16x8_init_haswell, nk_dot_bf16x8_init_haswell
887
+ * @sa nk_dot_e4m3x16_init_haswell, nk_dot_e5m2x16_init_haswell
888
+ */
889
+ NK_INTERNAL void nk_dot_through_f32_init_haswell_(nk_dot_through_f32_state_haswell_t_ *state) {
890
+ state->sum_f32x8 = _mm256_setzero_ps();
891
+ }
892
+
893
+ /**
894
+ * @brief Fuses 32-bit multiplication and accumulation for low-precision dot-products.
895
+ * @sa nk_dot_f16x8_update_haswell, nk_dot_bf16x8_update_haswell
896
+ * @sa nk_dot_e4m3x16_update_haswell, nk_dot_e5m2x16_update_haswell
897
+ */
898
+ NK_INTERNAL void nk_dot_through_f32_update_haswell_(nk_dot_through_f32_state_haswell_t_ *state, nk_b256_vec_t a,
899
+ nk_b256_vec_t b, nk_size_t depth_offset,
900
+ nk_size_t active_dimensions) {
901
+ nk_unused_(depth_offset);
902
+ nk_unused_(active_dimensions);
903
+ state->sum_f32x8 = _mm256_fmadd_ps(a.ymm_ps, b.ymm_ps, state->sum_f32x8);
904
+ }
905
+
906
+ /**
907
+ * @brief Finalizes 4x low-precision dot-products placing them into 4x consecutive 32-bit slots.
908
+ * @sa nk_dot_f16x8_finalize_haswell, nk_dot_bf16x8_finalize_haswell
909
+ * @sa nk_dot_e4m3x16_finalize_haswell, nk_dot_e5m2x16_finalize_haswell
910
+ *
911
+ * The goal of this kernel is simple - compute 4x horizontal reductions, each involving 8x floats.
912
+ * The lack of vectorized horizontal instruction implies many consecutive shuffles producing a tree-like
913
+ * reduction. This kernel allows combining some of those operations between different dot products.
914
+ */
915
+ NK_INTERNAL void nk_dot_through_f32_finalize_haswell_( //
916
+ nk_dot_through_f32_state_haswell_t_ const *state_a, nk_dot_through_f32_state_haswell_t_ const *state_b, //
917
+ nk_dot_through_f32_state_haswell_t_ const *state_c, nk_dot_through_f32_state_haswell_t_ const *state_d, //
918
+ nk_size_t total_dimensions, nk_b128_vec_t *result) {
919
+ nk_unused_(total_dimensions);
920
+
921
+ __m256 const sum_a_f32x8 = state_a->sum_f32x8, sum_b_f32x8 = state_b->sum_f32x8, sum_c_f32x8 = state_c->sum_f32x8,
922
+ sum_d_f32x8 = state_d->sum_f32x8;
923
+
924
+ // ILP-optimized 4-way horizontal reduction for f32 in AVX2
925
+ __m128 sum_a_f32x4 = _mm_add_ps(_mm256_castps256_ps128(sum_a_f32x8), _mm256_extractf128_ps(sum_a_f32x8, 1));
926
+ __m128 sum_b_f32x4 = _mm_add_ps(_mm256_castps256_ps128(sum_b_f32x8), _mm256_extractf128_ps(sum_b_f32x8, 1));
927
+ __m128 sum_c_f32x4 = _mm_add_ps(_mm256_castps256_ps128(sum_c_f32x8), _mm256_extractf128_ps(sum_c_f32x8, 1));
928
+ __m128 sum_d_f32x4 = _mm_add_ps(_mm256_castps256_ps128(sum_d_f32x8), _mm256_extractf128_ps(sum_d_f32x8, 1));
929
+ __m128 transpose_ab_low_f32x4 = _mm_unpacklo_ps(sum_a_f32x4, sum_b_f32x4);
930
+ __m128 transpose_cd_low_f32x4 = _mm_unpacklo_ps(sum_c_f32x4, sum_d_f32x4);
931
+ __m128 transpose_ab_high_f32x4 = _mm_unpackhi_ps(sum_a_f32x4, sum_b_f32x4);
932
+ __m128 transpose_cd_high_f32x4 = _mm_unpackhi_ps(sum_c_f32x4, sum_d_f32x4);
933
+ __m128 sum_lane0_f32x4 = _mm_movelh_ps(transpose_ab_low_f32x4, transpose_cd_low_f32x4);
934
+ __m128 sum_lane1_f32x4 = _mm_movehl_ps(transpose_cd_low_f32x4, transpose_ab_low_f32x4);
935
+ __m128 sum_lane2_f32x4 = _mm_movelh_ps(transpose_ab_high_f32x4, transpose_cd_high_f32x4);
936
+ __m128 sum_lane3_f32x4 = _mm_movehl_ps(transpose_cd_high_f32x4, transpose_ab_high_f32x4);
937
+ __m128 final_sum_f32x4 = _mm_add_ps(_mm_add_ps(sum_lane0_f32x4, sum_lane1_f32x4),
938
+ _mm_add_ps(sum_lane2_f32x4, sum_lane3_f32x4));
939
+ result->xmm = _mm_castps_si128(final_sum_f32x4);
940
+ }
941
+
942
/*
 *  Running states for Haswell dot-products whose products are accumulated in f32.
 *  Each one covers a 128-bit slice of its input format per step: 8x f16 or bf16, or 16x one-byte mini-floats.
 *  All of them alias nk_dot_through_f32_state_haswell_t_. Its update helper reads already-widened f32 lanes.
 */
typedef struct nk_dot_through_f32_state_haswell_t_ nk_dot_f16x8_state_haswell_t;   ///< 8x f16 per slice
typedef struct nk_dot_through_f32_state_haswell_t_ nk_dot_bf16x8_state_haswell_t;  ///< 8x bf16 per slice
typedef struct nk_dot_through_f32_state_haswell_t_ nk_dot_e4m3x16_state_haswell_t; ///< 16x e4m3 per slice
typedef struct nk_dot_through_f32_state_haswell_t_ nk_dot_e5m2x16_state_haswell_t; ///< 16x e5m2 per slice
typedef struct nk_dot_through_f32_state_haswell_t_ nk_dot_e2m3x16_state_haswell_t; ///< 16x e2m3 per slice
typedef struct nk_dot_through_f32_state_haswell_t_ nk_dot_e3m2x16_state_haswell_t; ///< 16x e3m2 per slice
977
+
978
+ /**
979
+ * @brief Integer LUT batch state for e2m3 dot-products on Haswell (AVX2).
980
+ * Uses VPMADDUBSW (u8×i8→i16) + VPMADDWD (i16→i32) instead of Sierra's VPDPBUSD.
981
+ */
982
+ typedef struct nk_dot_e2m3x32_state_haswell_t {
983
+ __m256i sum_i32x8;
984
+ } nk_dot_e2m3x32_state_haswell_t;
985
+
986
+ NK_INTERNAL void nk_dot_e2m3x32_init_haswell(nk_dot_e2m3x32_state_haswell_t *state) {
987
+ state->sum_i32x8 = _mm256_setzero_si256();
988
+ }
989
+
990
+ NK_INTERNAL void nk_dot_e2m3x32_update_haswell(nk_dot_e2m3x32_state_haswell_t *state, nk_b256_vec_t a, nk_b256_vec_t b,
991
+ nk_size_t depth_offset, nk_size_t active_dimensions) {
992
+ nk_unused_(depth_offset);
993
+ nk_unused_(active_dimensions);
994
+ __m256i const lut_lower_u8x32 = _mm256_set_epi8( //
995
+ 30, 28, 26, 24, 22, 20, 18, 16, 14, 12, 10, 8, 6, 4, 2, 0, //
996
+ 30, 28, 26, 24, 22, 20, 18, 16, 14, 12, 10, 8, 6, 4, 2, 0);
997
+ __m256i const lut_upper_u8x32 = _mm256_set_epi8( //
998
+ 120, 112, 104, 96, 88, 80, 72, 64, 60, 56, 52, 48, 44, 40, 36, 32, //
999
+ 120, 112, 104, 96, 88, 80, 72, 64, 60, 56, 52, 48, 44, 40, 36, 32);
1000
+ __m256i const nibble_mask_u8x32 = _mm256_set1_epi8(0x0F);
1001
+ __m256i const magnitude_mask_u8x32 = _mm256_set1_epi8(0x1F);
1002
+ __m256i const half_select_u8x32 = _mm256_set1_epi8(0x10);
1003
+ __m256i const sign_mask_u8x32 = _mm256_set1_epi8(0x20);
1004
+ __m256i const ones_i16x16 = _mm256_set1_epi16(1);
1005
+
1006
+ __m256i a_e2m3_u8x32 = a.ymm;
1007
+ __m256i b_e2m3_u8x32 = b.ymm;
1008
+
1009
+ // Extract 5-bit magnitude, split into low 4 bits and bit 4
1010
+ __m256i a_magnitude_u8x32 = _mm256_and_si256(a_e2m3_u8x32, magnitude_mask_u8x32);
1011
+ __m256i b_magnitude_u8x32 = _mm256_and_si256(b_e2m3_u8x32, magnitude_mask_u8x32);
1012
+ __m256i a_shuffle_index_u8x32 = _mm256_and_si256(a_magnitude_u8x32, nibble_mask_u8x32);
1013
+ __m256i b_shuffle_index_u8x32 = _mm256_and_si256(b_magnitude_u8x32, nibble_mask_u8x32);
1014
+ __m256i a_upper_select_u8x32 = _mm256_cmpeq_epi8(_mm256_and_si256(a_magnitude_u8x32, half_select_u8x32),
1015
+ half_select_u8x32);
1016
+ __m256i b_upper_select_u8x32 = _mm256_cmpeq_epi8(_mm256_and_si256(b_magnitude_u8x32, half_select_u8x32),
1017
+ half_select_u8x32);
1018
+
1019
+ // Dual VPSHUFB + blend
1020
+ __m256i a_unsigned_u8x32 = _mm256_blendv_epi8(_mm256_shuffle_epi8(lut_lower_u8x32, a_shuffle_index_u8x32),
1021
+ _mm256_shuffle_epi8(lut_upper_u8x32, a_shuffle_index_u8x32),
1022
+ a_upper_select_u8x32);
1023
+ __m256i b_unsigned_u8x32 = _mm256_blendv_epi8(_mm256_shuffle_epi8(lut_lower_u8x32, b_shuffle_index_u8x32),
1024
+ _mm256_shuffle_epi8(lut_upper_u8x32, b_shuffle_index_u8x32),
1025
+ b_upper_select_u8x32);
1026
+
1027
+ // Combined sign + conditional negate
1028
+ __m256i sign_combined_u8x32 = _mm256_and_si256(_mm256_xor_si256(a_e2m3_u8x32, b_e2m3_u8x32), sign_mask_u8x32);
1029
+ __m256i negate_mask_u8x32 = _mm256_cmpeq_epi8(sign_combined_u8x32, sign_mask_u8x32);
1030
+ __m256i b_negated_u8x32 = _mm256_sub_epi8(_mm256_setzero_si256(), b_unsigned_u8x32);
1031
+ __m256i b_signed_i8x32 = _mm256_blendv_epi8(b_unsigned_u8x32, b_negated_u8x32, negate_mask_u8x32);
1032
+
1033
+ // VPMADDUBSW + VPMADDWD: u8×i8→i16→i32
1034
+ __m256i products_i16x16 = _mm256_maddubs_epi16(a_unsigned_u8x32, b_signed_i8x32);
1035
+ __m256i products_i32x8 = _mm256_madd_epi16(products_i16x16, ones_i16x16);
1036
+ state->sum_i32x8 = _mm256_add_epi32(state->sum_i32x8, products_i32x8);
1037
+ }
1038
+
1039
+ NK_INTERNAL void nk_dot_e2m3x32_finalize_haswell( //
1040
+ nk_dot_e2m3x32_state_haswell_t const *state_a, nk_dot_e2m3x32_state_haswell_t const *state_b, //
1041
+ nk_dot_e2m3x32_state_haswell_t const *state_c, nk_dot_e2m3x32_state_haswell_t const *state_d, //
1042
+ nk_size_t total_dimensions, nk_b128_vec_t *results) {
1043
+ nk_unused_(total_dimensions);
1044
+
1045
+ // ILP-optimized 4-way horizontal reduction: i32x8 → scalar i32, then → f32 with ÷256
1046
+ __m128i sum_a_i32x4 = _mm_add_epi32(_mm256_castsi256_si128(state_a->sum_i32x8),
1047
+ _mm256_extracti128_si256(state_a->sum_i32x8, 1));
1048
+ __m128i sum_b_i32x4 = _mm_add_epi32(_mm256_castsi256_si128(state_b->sum_i32x8),
1049
+ _mm256_extracti128_si256(state_b->sum_i32x8, 1));
1050
+ __m128i sum_c_i32x4 = _mm_add_epi32(_mm256_castsi256_si128(state_c->sum_i32x8),
1051
+ _mm256_extracti128_si256(state_c->sum_i32x8, 1));
1052
+ __m128i sum_d_i32x4 = _mm_add_epi32(_mm256_castsi256_si128(state_d->sum_i32x8),
1053
+ _mm256_extracti128_si256(state_d->sum_i32x8, 1));
1054
+
1055
+ // Transpose for SIMD reduction
1056
+ __m128i transpose_ab_low_i32x4 = _mm_unpacklo_epi32(sum_a_i32x4, sum_b_i32x4);
1057
+ __m128i transpose_cd_low_i32x4 = _mm_unpacklo_epi32(sum_c_i32x4, sum_d_i32x4);
1058
+ __m128i transpose_ab_high_i32x4 = _mm_unpackhi_epi32(sum_a_i32x4, sum_b_i32x4);
1059
+ __m128i transpose_cd_high_i32x4 = _mm_unpackhi_epi32(sum_c_i32x4, sum_d_i32x4);
1060
+ __m128i lane0_i32x4 = _mm_unpacklo_epi64(transpose_ab_low_i32x4, transpose_cd_low_i32x4);
1061
+ __m128i lane1_i32x4 = _mm_unpackhi_epi64(transpose_ab_low_i32x4, transpose_cd_low_i32x4);
1062
+ __m128i lane2_i32x4 = _mm_unpacklo_epi64(transpose_ab_high_i32x4, transpose_cd_high_i32x4);
1063
+ __m128i lane3_i32x4 = _mm_unpackhi_epi64(transpose_ab_high_i32x4, transpose_cd_high_i32x4);
1064
+ __m128i sum_i32x4 = _mm_add_epi32(_mm_add_epi32(lane0_i32x4, lane1_i32x4), _mm_add_epi32(lane2_i32x4, lane3_i32x4));
1065
+
1066
+ // Convert i32 → f32 and scale by 1/256
1067
+ __m128 sum_f32x4 = _mm_mul_ps(_mm_cvtepi32_ps(sum_i32x4), _mm_set1_ps(1.0f / 256.0f));
1068
+ results->xmm = _mm_castps_si128(sum_f32x4);
1069
+ }
1070
+
1071
+ /**
1072
+ * @brief Integer LUT batch state for e3m2 dot-products on Haswell (AVX2).
1073
+ * Uses i16 widening via VPMADDWD (i16×i16→i32) with two accumulators for lo/hi halves.
1074
+ */
1075
+ typedef struct nk_dot_e3m2x32_state_haswell_t {
1076
+ __m256i sum_a_i32x8;
1077
+ __m256i sum_b_i32x8;
1078
+ } nk_dot_e3m2x32_state_haswell_t;
1079
+
1080
+ NK_INTERNAL void nk_dot_e3m2x32_init_haswell(nk_dot_e3m2x32_state_haswell_t *state) {
1081
+ state->sum_a_i32x8 = _mm256_setzero_si256();
1082
+ state->sum_b_i32x8 = _mm256_setzero_si256();
1083
+ }
1084
+
1085
+ NK_INTERNAL void nk_dot_e3m2x32_update_haswell(nk_dot_e3m2x32_state_haswell_t *state, nk_b256_vec_t a, nk_b256_vec_t b,
1086
+ nk_size_t depth_offset, nk_size_t active_dimensions) {
1087
+ nk_unused_(depth_offset);
1088
+ nk_unused_(active_dimensions);
1089
+ __m256i const lut_lo_lower_u8x32 = _mm256_set_epi8( //
1090
+ 28, 24, 20, 16, 14, 12, 10, 8, 7, 6, 5, 4, 3, 2, 1, 0, 28, 24, 20, 16, 14, 12, 10, 8, 7, 6, 5, 4, 3, 2, 1, 0);
1091
+ __m256i const lut_lo_upper_u8x32 = _mm256_set_epi8( //
1092
+ (char)192, (char)128, 64, 0, (char)224, (char)192, (char)160, (char)128, 112, 96, 80, 64, 56, 48, 40, 32, //
1093
+ (char)192, (char)128, 64, 0, (char)224, (char)192, (char)160, (char)128, 112, 96, 80, 64, 56, 48, 40, 32);
1094
+ __m256i const nibble_mask_u8x32 = _mm256_set1_epi8(0x0F);
1095
+ __m256i const magnitude_mask_u8x32 = _mm256_set1_epi8(0x1F);
1096
+ __m256i const half_select_u8x32 = _mm256_set1_epi8(0x10);
1097
+ __m256i const sign_mask_u8x32 = _mm256_set1_epi8(0x20);
1098
+ __m256i const high_threshold_u8x32 = _mm256_set1_epi8(27);
1099
+ __m256i const ones_u8x32 = _mm256_set1_epi8(1);
1100
+ __m256i const ones_i16x16 = _mm256_set1_epi16(1);
1101
+
1102
+ __m256i a_e3m2_u8x32 = a.ymm;
1103
+ __m256i b_e3m2_u8x32 = b.ymm;
1104
+
1105
+ // Extract 5-bit magnitude, split into low 4 bits and bit 4
1106
+ __m256i a_magnitude_u8x32 = _mm256_and_si256(a_e3m2_u8x32, magnitude_mask_u8x32);
1107
+ __m256i b_magnitude_u8x32 = _mm256_and_si256(b_e3m2_u8x32, magnitude_mask_u8x32);
1108
+ __m256i a_shuffle_index_u8x32 = _mm256_and_si256(a_magnitude_u8x32, nibble_mask_u8x32);
1109
+ __m256i b_shuffle_index_u8x32 = _mm256_and_si256(b_magnitude_u8x32, nibble_mask_u8x32);
1110
+ __m256i a_upper_select_u8x32 = _mm256_cmpeq_epi8(_mm256_and_si256(a_magnitude_u8x32, half_select_u8x32),
1111
+ half_select_u8x32);
1112
+ __m256i b_upper_select_u8x32 = _mm256_cmpeq_epi8(_mm256_and_si256(b_magnitude_u8x32, half_select_u8x32),
1113
+ half_select_u8x32);
1114
+
1115
+ // Dual VPSHUFB for low bytes
1116
+ __m256i a_lo_bytes_u8x32 = _mm256_blendv_epi8(_mm256_shuffle_epi8(lut_lo_lower_u8x32, a_shuffle_index_u8x32),
1117
+ _mm256_shuffle_epi8(lut_lo_upper_u8x32, a_shuffle_index_u8x32),
1118
+ a_upper_select_u8x32);
1119
+ __m256i b_lo_bytes_u8x32 = _mm256_blendv_epi8(_mm256_shuffle_epi8(lut_lo_lower_u8x32, b_shuffle_index_u8x32),
1120
+ _mm256_shuffle_epi8(lut_lo_upper_u8x32, b_shuffle_index_u8x32),
1121
+ b_upper_select_u8x32);
1122
+
1123
+ // High byte: 1 iff magnitude >= 28
1124
+ __m256i a_hi_bytes_u8x32 = _mm256_and_si256(_mm256_cmpgt_epi8(a_magnitude_u8x32, high_threshold_u8x32), ones_u8x32);
1125
+ __m256i b_hi_bytes_u8x32 = _mm256_and_si256(_mm256_cmpgt_epi8(b_magnitude_u8x32, high_threshold_u8x32), ones_u8x32);
1126
+
1127
+ // Interleave low and high bytes into i16
1128
+ __m256i a_lo_i16x16 = _mm256_unpacklo_epi8(a_lo_bytes_u8x32, a_hi_bytes_u8x32);
1129
+ __m256i a_hi_i16x16 = _mm256_unpackhi_epi8(a_lo_bytes_u8x32, a_hi_bytes_u8x32);
1130
+ __m256i b_lo_i16x16 = _mm256_unpacklo_epi8(b_lo_bytes_u8x32, b_hi_bytes_u8x32);
1131
+ __m256i b_hi_i16x16 = _mm256_unpackhi_epi8(b_lo_bytes_u8x32, b_hi_bytes_u8x32);
1132
+
1133
+ // Combined sign: (a ^ b) & 0x20, widen to i16, create +1/-1 sign vector via VPSIGNW
1134
+ __m256i sign_combined_u8x32 = _mm256_and_si256(_mm256_xor_si256(a_e3m2_u8x32, b_e3m2_u8x32), sign_mask_u8x32);
1135
+ __m256i negate_mask_u8x32 = _mm256_cmpeq_epi8(sign_combined_u8x32, sign_mask_u8x32);
1136
+ __m256i negate_lo_i16x16 = _mm256_unpacklo_epi8(negate_mask_u8x32, negate_mask_u8x32);
1137
+ __m256i negate_hi_i16x16 = _mm256_unpackhi_epi8(negate_mask_u8x32, negate_mask_u8x32);
1138
+ __m256i sign_lo_i16x16 = _mm256_or_si256(negate_lo_i16x16, ones_i16x16);
1139
+ __m256i sign_hi_i16x16 = _mm256_or_si256(negate_hi_i16x16, ones_i16x16);
1140
+ __m256i b_signed_lo_i16x16 = _mm256_sign_epi16(b_lo_i16x16, sign_lo_i16x16);
1141
+ __m256i b_signed_hi_i16x16 = _mm256_sign_epi16(b_hi_i16x16, sign_hi_i16x16);
1142
+
1143
+ // VPMADDWD: a_unsigned_i16 × b_signed_i16 → i32 (two halves → two accumulators)
1144
+ state->sum_a_i32x8 = _mm256_add_epi32(state->sum_a_i32x8, _mm256_madd_epi16(a_lo_i16x16, b_signed_lo_i16x16));
1145
+ state->sum_b_i32x8 = _mm256_add_epi32(state->sum_b_i32x8, _mm256_madd_epi16(a_hi_i16x16, b_signed_hi_i16x16));
1146
+ }
1147
+
1148
+ NK_INTERNAL void nk_dot_e3m2x32_finalize_haswell( //
1149
+ nk_dot_e3m2x32_state_haswell_t const *state_a, nk_dot_e3m2x32_state_haswell_t const *state_b, //
1150
+ nk_dot_e3m2x32_state_haswell_t const *state_c, nk_dot_e3m2x32_state_haswell_t const *state_d, //
1151
+ nk_size_t total_dimensions, nk_b128_vec_t *results) {
1152
+ nk_unused_(total_dimensions);
1153
+
1154
+ // Merge two accumulators per state, then same 4-way transpose-reduce as Sierra
1155
+ __m256i merged_a = _mm256_add_epi32(state_a->sum_a_i32x8, state_a->sum_b_i32x8);
1156
+ __m256i merged_b = _mm256_add_epi32(state_b->sum_a_i32x8, state_b->sum_b_i32x8);
1157
+ __m256i merged_c = _mm256_add_epi32(state_c->sum_a_i32x8, state_c->sum_b_i32x8);
1158
+ __m256i merged_d = _mm256_add_epi32(state_d->sum_a_i32x8, state_d->sum_b_i32x8);
1159
+
1160
+ __m128i sum_a_i32x4 = _mm_add_epi32(_mm256_castsi256_si128(merged_a), _mm256_extracti128_si256(merged_a, 1));
1161
+ __m128i sum_b_i32x4 = _mm_add_epi32(_mm256_castsi256_si128(merged_b), _mm256_extracti128_si256(merged_b, 1));
1162
+ __m128i sum_c_i32x4 = _mm_add_epi32(_mm256_castsi256_si128(merged_c), _mm256_extracti128_si256(merged_c, 1));
1163
+ __m128i sum_d_i32x4 = _mm_add_epi32(_mm256_castsi256_si128(merged_d), _mm256_extracti128_si256(merged_d, 1));
1164
+
1165
+ __m128i transpose_ab_low_i32x4 = _mm_unpacklo_epi32(sum_a_i32x4, sum_b_i32x4);
1166
+ __m128i transpose_cd_low_i32x4 = _mm_unpacklo_epi32(sum_c_i32x4, sum_d_i32x4);
1167
+ __m128i transpose_ab_high_i32x4 = _mm_unpackhi_epi32(sum_a_i32x4, sum_b_i32x4);
1168
+ __m128i transpose_cd_high_i32x4 = _mm_unpackhi_epi32(sum_c_i32x4, sum_d_i32x4);
1169
+ __m128i lane0_i32x4 = _mm_unpacklo_epi64(transpose_ab_low_i32x4, transpose_cd_low_i32x4);
1170
+ __m128i lane1_i32x4 = _mm_unpackhi_epi64(transpose_ab_low_i32x4, transpose_cd_low_i32x4);
1171
+ __m128i lane2_i32x4 = _mm_unpacklo_epi64(transpose_ab_high_i32x4, transpose_cd_high_i32x4);
1172
+ __m128i lane3_i32x4 = _mm_unpackhi_epi64(transpose_ab_high_i32x4, transpose_cd_high_i32x4);
1173
+ __m128i sum_i32x4 = _mm_add_epi32(_mm_add_epi32(lane0_i32x4, lane1_i32x4), _mm_add_epi32(lane2_i32x4, lane3_i32x4));
1174
+
1175
+ __m128 sum_f32x4 = _mm_mul_ps(_mm_cvtepi32_ps(sum_i32x4), _mm_set1_ps(1.0f / 256.0f));
1176
+ results->xmm = _mm_castps_si128(sum_f32x4);
1177
+ }
1178
+
1179
+ #pragma endregion - Smaller Floats
1180
+
1181
+ #pragma region - Small Integers
1182
+
1183
+ NK_PUBLIC void nk_dot_i8_haswell(nk_i8_t const *a_scalars, nk_i8_t const *b_scalars, nk_size_t count_scalars,
1184
+ nk_i32_t *result) {
1185
+ __m256i sum_low_i32x8 = _mm256_setzero_si256();
1186
+ __m256i sum_high_i32x8 = _mm256_setzero_si256();
1187
+ nk_size_t idx_scalars = 0;
1188
+ // Use two 128-bit loads instead of 256-bit load + extract to avoid Port 5 contention.
1189
+ // VEXTRACTI128 uses Port 5; two smaller loads use Port 2/3 (2 ports available).
1190
+ for (; idx_scalars + 32 <= count_scalars; idx_scalars += 32) {
1191
+ __m128i a_low_i8x16 = _mm_loadu_si128((__m128i const *)(a_scalars + idx_scalars));
1192
+ __m128i a_high_i8x16 = _mm_loadu_si128((__m128i const *)(a_scalars + idx_scalars + 16));
1193
+ __m128i b_low_i8x16 = _mm_loadu_si128((__m128i const *)(b_scalars + idx_scalars));
1194
+ __m128i b_high_i8x16 = _mm_loadu_si128((__m128i const *)(b_scalars + idx_scalars + 16));
1195
+ // Upcast `int8` to `int16` - no extracts needed
1196
+ __m256i a_low_i16x16 = _mm256_cvtepi8_epi16(a_low_i8x16);
1197
+ __m256i a_high_i16x16 = _mm256_cvtepi8_epi16(a_high_i8x16);
1198
+ __m256i b_low_i16x16 = _mm256_cvtepi8_epi16(b_low_i8x16);
1199
+ __m256i b_high_i16x16 = _mm256_cvtepi8_epi16(b_high_i8x16);
1200
+ // Multiply and accumulate at `int16` level, accumulate at `int32` level
1201
+ sum_low_i32x8 = _mm256_add_epi32(sum_low_i32x8, _mm256_madd_epi16(a_low_i16x16, b_low_i16x16));
1202
+ sum_high_i32x8 = _mm256_add_epi32(sum_high_i32x8, _mm256_madd_epi16(a_high_i16x16, b_high_i16x16));
1203
+ }
1204
+ nk_i32_t sum = nk_reduce_add_i32x8_haswell_(_mm256_add_epi32(sum_low_i32x8, sum_high_i32x8));
1205
+ for (; idx_scalars < count_scalars; ++idx_scalars) sum += (nk_i32_t)a_scalars[idx_scalars] * b_scalars[idx_scalars];
1206
+ *result = sum;
1207
+ }
1208
+
1209
+ NK_PUBLIC void nk_dot_u8_haswell(nk_u8_t const *a_scalars, nk_u8_t const *b_scalars, nk_size_t count_scalars,
1210
+ nk_u32_t *result) {
1211
+ __m256i sum_low_i32x8 = _mm256_setzero_si256();
1212
+ __m256i sum_high_i32x8 = _mm256_setzero_si256();
1213
+ __m256i const zeros_i8x32 = _mm256_setzero_si256();
1214
+ nk_size_t idx_scalars = 0;
1215
+ for (; idx_scalars + 32 <= count_scalars; idx_scalars += 32) {
1216
+ __m256i a_u8x32 = _mm256_loadu_si256((__m256i const *)(a_scalars + idx_scalars));
1217
+ __m256i b_u8x32 = _mm256_loadu_si256((__m256i const *)(b_scalars + idx_scalars));
1218
+ // Upcast `uint8` to `int16`. Unpacking is faster than extracts.
1219
+ __m256i a_low_i16x16 = _mm256_unpacklo_epi8(a_u8x32, zeros_i8x32);
1220
+ __m256i a_high_i16x16 = _mm256_unpackhi_epi8(a_u8x32, zeros_i8x32);
1221
+ __m256i b_low_i16x16 = _mm256_unpacklo_epi8(b_u8x32, zeros_i8x32);
1222
+ __m256i b_high_i16x16 = _mm256_unpackhi_epi8(b_u8x32, zeros_i8x32);
1223
+ // Multiply and accumulate at `int16` level, accumulate at `int32` level
1224
+ sum_low_i32x8 = _mm256_add_epi32(sum_low_i32x8, _mm256_madd_epi16(a_low_i16x16, b_low_i16x16));
1225
+ sum_high_i32x8 = _mm256_add_epi32(sum_high_i32x8, _mm256_madd_epi16(a_high_i16x16, b_high_i16x16));
1226
+ }
1227
+ nk_u32_t sum = (nk_u32_t)nk_reduce_add_i32x8_haswell_(_mm256_add_epi32(sum_low_i32x8, sum_high_i32x8));
1228
+ for (; idx_scalars < count_scalars; ++idx_scalars) sum += (nk_u32_t)a_scalars[idx_scalars] * b_scalars[idx_scalars];
1229
+ *result = sum;
1230
+ }
1231
+
1232
+ NK_PUBLIC void nk_dot_i4_haswell(nk_i4x2_t const *a, nk_i4x2_t const *b, nk_size_t n, nk_i32_t *result) {
1233
+ // i4 values are packed as nibbles: two 4-bit signed values per byte.
1234
+ // Parameter `n` is the number of 4-bit values (dimensions), not bytes.
1235
+ //
1236
+ // Algorithm: For signed i4, we use an algebraic transformation (similar to Ice Lake).
1237
+ // Let ax, bx be the unsigned [0,15] representation of signed values a, b in [-8,7].
1238
+ // Then: a = ax - 8, b = bx - 8 (the XOR trick gives signed = (unsigned ^ 8) - 8)
1239
+ // So: a * b = (ax - 8)(bx - 8) = ax * bx - 8 * ax - 8 * bx + 64
1240
+ //
1241
+ // We compute ax * bx using widening multiply, then apply the correction:
1242
+ // signed_dot = unsigned_dot - 8 * (sum_ax + sum_bx) + 64 * n
1243
+ //
1244
+ // Optimization: Process 16 bytes (32 nibbles) per iteration and use SAD for correction sums.
1245
+ // Benchmark shows 16-byte approach is 2× faster than 8-byte (10.7 GB/s vs 5.3 GB/s).
1246
+ // Better ILP and amortized loop overhead with wider operations.
1247
+ //
1248
+ n = nk_size_round_up_to_multiple_(n, 2);
1249
+ nk_size_t n_bytes = n / 2;
1250
+ __m128i const nibble_mask_u8x16 = _mm_set1_epi8(0x0F);
1251
+ __m128i const xor_mask_u8x16 = _mm_set1_epi8(0x08);
1252
+ __m128i const zeros_u8x16 = _mm_setzero_si128();
1253
+ __m256i sum_cd_i32x8 = _mm256_setzero_si256();
1254
+ __m128i sum_cx_i64x2 = _mm_setzero_si128(); // Use i64 for SAD results
1255
+ __m128i sum_dx_i64x2 = _mm_setzero_si128();
1256
+ __m128i a_i4x32, b_i4x32;
1257
+
1258
+ nk_dot_i4_haswell_cycle:
1259
+ // Process 16 bytes (32 nibbles) per iteration
1260
+ if (n_bytes < 16) {
1261
+ // Zero-padded partial load: zero nibbles XOR 8 = 8, which contributes to cd_dot
1262
+ // and correction sums. Extend `n` to cover padding so `+64*n` cancels it out.
1263
+ nk_b128_vec_t a_vec, b_vec;
1264
+ nk_partial_load_b8x16_serial_(a, &a_vec, n_bytes);
1265
+ nk_partial_load_b8x16_serial_(b, &b_vec, n_bytes);
1266
+ a_i4x32 = a_vec.xmm;
1267
+ b_i4x32 = b_vec.xmm;
1268
+ n += (16 - n_bytes) * 2;
1269
+ n_bytes = 0;
1270
+ }
1271
+ else {
1272
+ a_i4x32 = _mm_loadu_si128((__m128i const *)a); // Load full 16 bytes
1273
+ b_i4x32 = _mm_loadu_si128((__m128i const *)b);
1274
+ a += 16, b += 16, n_bytes -= 16;
1275
+ }
1276
+
1277
+ // Extract low and high nibbles
1278
+ __m128i a_lo_u8x16 = _mm_and_si128(a_i4x32, nibble_mask_u8x16);
1279
+ __m128i a_hi_u8x16 = _mm_and_si128(_mm_srli_epi16(a_i4x32, 4), nibble_mask_u8x16);
1280
+ __m128i b_lo_u8x16 = _mm_and_si128(b_i4x32, nibble_mask_u8x16);
1281
+ __m128i b_hi_u8x16 = _mm_and_si128(_mm_srli_epi16(b_i4x32, 4), nibble_mask_u8x16);
1282
+
1283
+ // XOR with 8 to get cx, dx values for the algebraic transformation
1284
+ __m128i c_lo_u8x16 = _mm_xor_si128(a_lo_u8x16, xor_mask_u8x16);
1285
+ __m128i c_hi_u8x16 = _mm_xor_si128(a_hi_u8x16, xor_mask_u8x16);
1286
+ __m128i d_lo_u8x16 = _mm_xor_si128(b_lo_u8x16, xor_mask_u8x16);
1287
+ __m128i d_hi_u8x16 = _mm_xor_si128(b_hi_u8x16, xor_mask_u8x16);
1288
+
1289
+ // Widen u8 to i16 and multiply using MADD (2× instead of 4×)
1290
+ __m256i c_lo_i16x16 = _mm256_cvtepu8_epi16(c_lo_u8x16);
1291
+ __m256i c_hi_i16x16 = _mm256_cvtepu8_epi16(c_hi_u8x16);
1292
+ __m256i d_lo_i16x16 = _mm256_cvtepu8_epi16(d_lo_u8x16);
1293
+ __m256i d_hi_i16x16 = _mm256_cvtepu8_epi16(d_hi_u8x16);
1294
+
1295
+ // Multiply i16×i16 and accumulate to i32 using MADD
1296
+ sum_cd_i32x8 = _mm256_add_epi32(sum_cd_i32x8, _mm256_madd_epi16(c_lo_i16x16, d_lo_i16x16));
1297
+ sum_cd_i32x8 = _mm256_add_epi32(sum_cd_i32x8, _mm256_madd_epi16(c_hi_i16x16, d_hi_i16x16));
1298
+
1299
+ // Optimization: Use SAD for correction sums (5cy vs 24cy for 8× widenings)
1300
+ // PSADBW sums 8× u8 values to a single i64 in each 64-bit lane
1301
+ sum_cx_i64x2 = _mm_add_epi64(sum_cx_i64x2, _mm_sad_epu8(c_lo_u8x16, zeros_u8x16));
1302
+ sum_cx_i64x2 = _mm_add_epi64(sum_cx_i64x2, _mm_sad_epu8(c_hi_u8x16, zeros_u8x16));
1303
+ sum_dx_i64x2 = _mm_add_epi64(sum_dx_i64x2, _mm_sad_epu8(d_lo_u8x16, zeros_u8x16));
1304
+ sum_dx_i64x2 = _mm_add_epi64(sum_dx_i64x2, _mm_sad_epu8(d_hi_u8x16, zeros_u8x16));
1305
+
1306
+ if (n_bytes) goto nk_dot_i4_haswell_cycle;
1307
+
1308
+ // Reduce and apply algebraic correction
1309
+ nk_i32_t cd_dot = nk_reduce_add_i32x8_haswell_(sum_cd_i32x8);
1310
+
1311
+ // Extract SAD results (already summed across 8 bytes per lane)
1312
+ nk_i64_t cx_sum = (nk_i64_t)_mm_extract_epi64(sum_cx_i64x2, 0) + (nk_i64_t)_mm_extract_epi64(sum_cx_i64x2, 1);
1313
+ nk_i64_t dx_sum = (nk_i64_t)_mm_extract_epi64(sum_dx_i64x2, 0) + (nk_i64_t)_mm_extract_epi64(sum_dx_i64x2, 1);
1314
+
1315
+ *result = (nk_i32_t)(cd_dot - 8 * (cx_sum + dx_sum) + 64 * (nk_i64_t)n);
1316
+ }
1317
+
1318
+ NK_PUBLIC void nk_dot_u4_haswell(nk_u4x2_t const *a, nk_u4x2_t const *b, nk_size_t n, nk_u32_t *result) {
1319
+ // u4 values are packed as nibbles: two 4-bit unsigned values per byte.
1320
+ // Parameter `n` is the number of 4-bit values (dimensions), not bytes.
1321
+ // Values are ∈ [0,15], so we can use direct unpacking and multiplication.
1322
+ //
1323
+ // Optimization: Process 16 bytes (32 nibbles) per iteration for better ILP.
1324
+ // Benchmark shows 16-byte approach provides best performance.
1325
+ //
1326
+ n = nk_size_round_up_to_multiple_(n, 2);
1327
+ nk_size_t n_bytes = n / 2;
1328
+ __m128i const nibble_mask_u8x16 = _mm_set1_epi8(0x0F);
1329
+ __m256i sum_i32x8 = _mm256_setzero_si256();
1330
+ __m128i a_u4x32, b_u4x32;
1331
+
1332
+ nk_dot_u4_haswell_cycle:
1333
+ // Process 16 bytes (32 nibbles) per iteration
1334
+ if (n_bytes < 16) {
1335
+ // Partial load using serial helper
1336
+ nk_b128_vec_t a_vec, b_vec;
1337
+ nk_partial_load_b8x16_serial_(a, &a_vec, n_bytes);
1338
+ nk_partial_load_b8x16_serial_(b, &b_vec, n_bytes);
1339
+ a_u4x32 = a_vec.xmm;
1340
+ b_u4x32 = b_vec.xmm;
1341
+ n_bytes = 0;
1342
+ }
1343
+ else {
1344
+ a_u4x32 = _mm_loadu_si128((__m128i const *)a); // Load full 16 bytes
1345
+ b_u4x32 = _mm_loadu_si128((__m128i const *)b);
1346
+ a += 16, b += 16, n_bytes -= 16;
1347
+ }
1348
+
1349
+ // Extract low and high nibbles
1350
+ __m128i a_lo_u8x16 = _mm_and_si128(a_u4x32, nibble_mask_u8x16);
1351
+ __m128i a_hi_u8x16 = _mm_and_si128(_mm_srli_epi16(a_u4x32, 4), nibble_mask_u8x16);
1352
+ __m128i b_lo_u8x16 = _mm_and_si128(b_u4x32, nibble_mask_u8x16);
1353
+ __m128i b_hi_u8x16 = _mm_and_si128(_mm_srli_epi16(b_u4x32, 4), nibble_mask_u8x16);
1354
+
1355
+ // Widen u8 to i16
1356
+ __m256i a_lo_i16x16 = _mm256_cvtepu8_epi16(a_lo_u8x16);
1357
+ __m256i a_hi_i16x16 = _mm256_cvtepu8_epi16(a_hi_u8x16);
1358
+ __m256i b_lo_i16x16 = _mm256_cvtepu8_epi16(b_lo_u8x16);
1359
+ __m256i b_hi_i16x16 = _mm256_cvtepu8_epi16(b_hi_u8x16);
1360
+
1361
+ // Multiply i16×i16 and accumulate to i32 using MADD
1362
+ sum_i32x8 = _mm256_add_epi32(sum_i32x8, _mm256_madd_epi16(a_lo_i16x16, b_lo_i16x16));
1363
+ sum_i32x8 = _mm256_add_epi32(sum_i32x8, _mm256_madd_epi16(a_hi_i16x16, b_hi_i16x16));
1364
+
1365
+ if (n_bytes) goto nk_dot_u4_haswell_cycle;
1366
+
1367
+ *result = (nk_u32_t)nk_reduce_add_i32x8_haswell_(sum_i32x8);
1368
+ }
1369
+
1370
+ /**
1371
+ * @brief Internal helper state for dot-products of integer types, where 32-bit accumulation is enough.
1372
+ * @sa nk_dot_i8x16_state_haswell_t, nk_dot_u8x16_state_haswell_t
1373
+ */
1374
+ typedef struct nk_dot_through_i32_state_haswell_t_ {
1375
+ __m256i sum_i32x8;
1376
+ } nk_dot_through_i32_state_haswell_t_;
1377
+
1378
+ /**
1379
+ * @brief Initializes 32-bit accumulators for integer dot-products.
1380
+ * @sa nk_dot_i8x16_update_haswell, nk_dot_u8x16_update_haswell
1381
+ */
1382
+ NK_INTERNAL void nk_dot_through_i32_init_haswell_(nk_dot_through_i32_state_haswell_t_ *state) {
1383
+ state->sum_i32x8 = _mm256_setzero_si256();
1384
+ }
1385
+
1386
+ /**
1387
+ * @brief Finalizes 4x integer dot-products placing them into 4x consecutive 32-bit slots.
1388
+ * @sa nk_dot_i8x16_update_haswell, nk_dot_u8x16_update_haswell
1389
+ */
1390
+ NK_INTERNAL void nk_dot_through_i32_finalize_haswell_( //
1391
+ nk_dot_through_i32_state_haswell_t_ const *state_a, nk_dot_through_i32_state_haswell_t_ const *state_b, //
1392
+ nk_dot_through_i32_state_haswell_t_ const *state_c, nk_dot_through_i32_state_haswell_t_ const *state_d, //
1393
+ nk_size_t total_dimensions, nk_b128_vec_t *result) {
1394
+ nk_unused_(total_dimensions);
1395
+ // ILP-optimized 4-way horizontal reduction for i32 in AVX2
1396
+ // Step 1: 8->4 for all 4 states
1397
+ __m128i sum_a_i32x4 = _mm_add_epi32(_mm256_castsi256_si128(state_a->sum_i32x8),
1398
+ _mm256_extracti128_si256(state_a->sum_i32x8, 1));
1399
+ __m128i sum_b_i32x4 = _mm_add_epi32(_mm256_castsi256_si128(state_b->sum_i32x8),
1400
+ _mm256_extracti128_si256(state_b->sum_i32x8, 1));
1401
+ __m128i sum_c_i32x4 = _mm_add_epi32(_mm256_castsi256_si128(state_c->sum_i32x8),
1402
+ _mm256_extracti128_si256(state_c->sum_i32x8, 1));
1403
+ __m128i sum_d_i32x4 = _mm_add_epi32(_mm256_castsi256_si128(state_d->sum_i32x8),
1404
+ _mm256_extracti128_si256(state_d->sum_i32x8, 1));
1405
+ // Step 2: Transpose 4×4 matrix
1406
+ __m128i transpose_ab_low_i32x4 = _mm_unpacklo_epi32(sum_a_i32x4, sum_b_i32x4);
1407
+ __m128i transpose_cd_low_i32x4 = _mm_unpacklo_epi32(sum_c_i32x4, sum_d_i32x4);
1408
+ __m128i transpose_ab_high_i32x4 = _mm_unpackhi_epi32(sum_a_i32x4, sum_b_i32x4);
1409
+ __m128i transpose_cd_high_i32x4 = _mm_unpackhi_epi32(sum_c_i32x4, sum_d_i32x4);
1410
+ __m128i sum_lane0_i32x4 = _mm_unpacklo_epi64(transpose_ab_low_i32x4, transpose_cd_low_i32x4);
1411
+ __m128i sum_lane1_i32x4 = _mm_unpackhi_epi64(transpose_ab_low_i32x4, transpose_cd_low_i32x4);
1412
+ __m128i sum_lane2_i32x4 = _mm_unpacklo_epi64(transpose_ab_high_i32x4, transpose_cd_high_i32x4);
1413
+ __m128i sum_lane3_i32x4 = _mm_unpackhi_epi64(transpose_ab_high_i32x4, transpose_cd_high_i32x4);
1414
+ // Step 3: Vertical sum and store as i32
1415
+ __m128i sum_i32x4 = _mm_add_epi32(_mm_add_epi32(sum_lane0_i32x4, sum_lane1_i32x4),
1416
+ _mm_add_epi32(sum_lane2_i32x4, sum_lane3_i32x4));
1417
+ result->xmm = sum_i32x4;
1418
+ }
1419
+
1420
+ /**
1421
+ * @brief Running state for 128-bit dot accumulation over i8 scalars on Haswell.
1422
+ * @note Alias of nk_dot_through_i32_state_haswell_t_
1423
+ */
1424
+ typedef struct nk_dot_through_i32_state_haswell_t_ nk_dot_i8x16_state_haswell_t;
1425
+
1426
+ NK_INTERNAL void nk_dot_i8x16_init_haswell(nk_dot_i8x16_state_haswell_t *state) {
1427
+ nk_dot_through_i32_init_haswell_(state);
1428
+ }
1429
+
1430
+ NK_INTERNAL void nk_dot_i8x16_update_haswell(nk_dot_i8x16_state_haswell_t *state, nk_b128_vec_t a, nk_b128_vec_t b,
1431
+ nk_size_t depth_offset, nk_size_t active_dimensions) {
1432
+ nk_unused_(depth_offset);
1433
+ nk_unused_(active_dimensions);
1434
+ __m256i a_i16x16 = _mm256_cvtepi8_epi16(a.xmm);
1435
+ __m256i b_i16x16 = _mm256_cvtepi8_epi16(b.xmm);
1436
+ state->sum_i32x8 = _mm256_add_epi32(state->sum_i32x8, _mm256_madd_epi16(a_i16x16, b_i16x16));
1437
+ }
1438
+
1439
+ NK_INTERNAL void nk_dot_i8x16_finalize_haswell( //
1440
+ nk_dot_i8x16_state_haswell_t const *state_a, nk_dot_i8x16_state_haswell_t const *state_b, //
1441
+ nk_dot_i8x16_state_haswell_t const *state_c, nk_dot_i8x16_state_haswell_t const *state_d, //
1442
+ nk_size_t total_dimensions, nk_b128_vec_t *result) {
1443
+ nk_dot_through_i32_finalize_haswell_(state_a, state_b, state_c, state_d, total_dimensions, result);
1444
+ }
1445
+
1446
+ /**
1447
+ * @brief Running state for 128-bit dot accumulation over u8 scalars on Haswell.
1448
+ * @note Alias of nk_dot_through_i32_state_haswell_t_
1449
+ */
1450
+ typedef struct nk_dot_through_i32_state_haswell_t_ nk_dot_u8x16_state_haswell_t;
1451
+
1452
+ NK_INTERNAL void nk_dot_u8x16_init_haswell(nk_dot_u8x16_state_haswell_t *state) {
1453
+ nk_dot_through_i32_init_haswell_(state);
1454
+ }
1455
+
1456
+ NK_INTERNAL void nk_dot_u8x16_update_haswell(nk_dot_u8x16_state_haswell_t *state, nk_b128_vec_t a, nk_b128_vec_t b,
1457
+ nk_size_t depth_offset, nk_size_t active_dimensions) {
1458
+ nk_unused_(depth_offset);
1459
+ nk_unused_(active_dimensions);
1460
+ __m256i a_i16x16 = _mm256_cvtepu8_epi16(a.xmm);
1461
+ __m256i b_i16x16 = _mm256_cvtepu8_epi16(b.xmm);
1462
+ state->sum_i32x8 = _mm256_add_epi32(state->sum_i32x8, _mm256_madd_epi16(a_i16x16, b_i16x16));
1463
+ }
1464
+
1465
+ NK_INTERNAL void nk_dot_u8x16_finalize_haswell( //
1466
+ nk_dot_u8x16_state_haswell_t const *state_a, nk_dot_u8x16_state_haswell_t const *state_b, //
1467
+ nk_dot_u8x16_state_haswell_t const *state_c, nk_dot_u8x16_state_haswell_t const *state_d, //
1468
+ nk_size_t total_dimensions, nk_b128_vec_t *result) {
1469
+ nk_dot_through_i32_finalize_haswell_(state_a, state_b, state_c, state_d, total_dimensions, result);
1470
+ }
1471
+
1472
+ /**
1473
+ * @brief State for batched i4 dot products on Haswell.
1474
+ * Processes 32 nibbles (16 bytes) per update iteration for optimal ILP.
1475
+ */
1476
+ typedef struct nk_dot_i4x32_state_haswell_t {
1477
+ __m256i biased_product_sum_i32x8; // Single accumulator: (a^8)×(b^8) products
1478
+ } nk_dot_i4x32_state_haswell_t;
1479
+
1480
+ NK_INTERNAL void nk_dot_i4x32_init_haswell(nk_dot_i4x32_state_haswell_t *state) {
1481
+ state->biased_product_sum_i32x8 = _mm256_setzero_si256();
1482
+ }
1483
+
1484
+ NK_INTERNAL void nk_dot_i4x32_update_haswell(nk_dot_i4x32_state_haswell_t *state, nk_b128_vec_t a, nk_b128_vec_t b,
1485
+ nk_size_t depth_offset, nk_size_t active_dimensions) {
1486
+ // Process 32 nibbles (16 bytes) from the full 128-bit vector
1487
+ // Algebraic transformation: a×b = (a^8)×(b^8) − 8×(Σa + Σb) − 64×n
1488
+ // Correction applied at finalize time using precomputed sums.
1489
+ nk_unused_(depth_offset);
1490
+ nk_unused_(active_dimensions);
1491
+
1492
+ __m128i const nibble_mask_u8x16 = _mm_set1_epi8(0x0F);
1493
+ __m128i const xor_mask_u8x16 = _mm_set1_epi8(0x08);
1494
+
1495
+ __m128i a_i4x32 = a.xmm;
1496
+ __m128i b_i4x32 = b.xmm;
1497
+
1498
+ // Extract low and high nibbles
1499
+ __m128i a_lo_u8x16 = _mm_and_si128(a_i4x32, nibble_mask_u8x16);
1500
+ __m128i a_hi_u8x16 = _mm_and_si128(_mm_srli_epi16(a_i4x32, 4), nibble_mask_u8x16);
1501
+ __m128i b_lo_u8x16 = _mm_and_si128(b_i4x32, nibble_mask_u8x16);
1502
+ __m128i b_hi_u8x16 = _mm_and_si128(_mm_srli_epi16(b_i4x32, 4), nibble_mask_u8x16);
1503
+
1504
+ // XOR with 8 for algebraic transformation
1505
+ __m128i c_lo_u8x16 = _mm_xor_si128(a_lo_u8x16, xor_mask_u8x16);
1506
+ __m128i c_hi_u8x16 = _mm_xor_si128(a_hi_u8x16, xor_mask_u8x16);
1507
+ __m128i d_lo_u8x16 = _mm_xor_si128(b_lo_u8x16, xor_mask_u8x16);
1508
+ __m128i d_hi_u8x16 = _mm_xor_si128(b_hi_u8x16, xor_mask_u8x16);
1509
+
1510
+ // Widen u8 to i16 and multiply using MADD
1511
+ __m256i c_lo_i16x16 = _mm256_cvtepu8_epi16(c_lo_u8x16);
1512
+ __m256i c_hi_i16x16 = _mm256_cvtepu8_epi16(c_hi_u8x16);
1513
+ __m256i d_lo_i16x16 = _mm256_cvtepu8_epi16(d_lo_u8x16);
1514
+ __m256i d_hi_i16x16 = _mm256_cvtepu8_epi16(d_hi_u8x16);
1515
+
1516
+ // Multiply and accumulate (no SAD — correction deferred to finalize)
1517
+ state->biased_product_sum_i32x8 = _mm256_add_epi32(state->biased_product_sum_i32x8,
1518
+ _mm256_madd_epi16(c_lo_i16x16, d_lo_i16x16));
1519
+ state->biased_product_sum_i32x8 = _mm256_add_epi32(state->biased_product_sum_i32x8,
1520
+ _mm256_madd_epi16(c_hi_i16x16, d_hi_i16x16));
1521
+ }
1522
+
1523
+ NK_INTERNAL void nk_dot_i4x32_finalize_haswell( //
1524
+ nk_dot_i4x32_state_haswell_t const *state_a, nk_dot_i4x32_state_haswell_t const *state_b, //
1525
+ nk_dot_i4x32_state_haswell_t const *state_c, nk_dot_i4x32_state_haswell_t const *state_d, //
1526
+ nk_size_t total_dimensions, //
1527
+ nk_i32_t a_sum, /* A row sum (signed sum of i4 values) */ //
1528
+ nk_b128_vec_t b_sums, /* 4 × i32 B column sums */ //
1529
+ nk_b128_vec_t *result) {
1530
+
1531
+ // Compensated 4-way reduction with external correction sums.
1532
+ // Formula: result = biased_product − 8×(Σa + Σb) − 64×depth_padded
1533
+ nk_size_t depth_nibbles = nk_size_round_up_to_multiple_(total_dimensions, 32);
1534
+
1535
+ // Reduce main products from ymm (i32x8) to xmm (i32x4)
1536
+ __m128i product_a_i32x4 = _mm_add_epi32(_mm256_castsi256_si128(state_a->biased_product_sum_i32x8),
1537
+ _mm256_extracti128_si256(state_a->biased_product_sum_i32x8, 1));
1538
+ __m128i product_b_i32x4 = _mm_add_epi32(_mm256_castsi256_si128(state_b->biased_product_sum_i32x8),
1539
+ _mm256_extracti128_si256(state_b->biased_product_sum_i32x8, 1));
1540
+ __m128i product_c_i32x4 = _mm_add_epi32(_mm256_castsi256_si128(state_c->biased_product_sum_i32x8),
1541
+ _mm256_extracti128_si256(state_c->biased_product_sum_i32x8, 1));
1542
+ __m128i product_d_i32x4 = _mm_add_epi32(_mm256_castsi256_si128(state_d->biased_product_sum_i32x8),
1543
+ _mm256_extracti128_si256(state_d->biased_product_sum_i32x8, 1));
1544
+
1545
+ // 4-way transpose to get [a,b,c,d] in lanes
1546
+ __m128i transpose_ab_low = _mm_unpacklo_epi32(product_a_i32x4, product_b_i32x4);
1547
+ __m128i transpose_cd_low = _mm_unpacklo_epi32(product_c_i32x4, product_d_i32x4);
1548
+ __m128i transpose_ab_high = _mm_unpackhi_epi32(product_a_i32x4, product_b_i32x4);
1549
+ __m128i transpose_cd_high = _mm_unpackhi_epi32(product_c_i32x4, product_d_i32x4);
1550
+ __m128i biased_i32x4 = _mm_add_epi32(_mm_add_epi32(_mm_unpacklo_epi64(transpose_ab_low, transpose_cd_low),
1551
+ _mm_unpackhi_epi64(transpose_ab_low, transpose_cd_low)),
1552
+ _mm_add_epi32(_mm_unpacklo_epi64(transpose_ab_high, transpose_cd_high),
1553
+ _mm_unpackhi_epi64(transpose_ab_high, transpose_cd_high)));
1554
+
1555
+ // Apply compensation: result = biased − 8×(Σa + Σb) − 64×depth_padded
1556
+ __m128i a_sum_broadcast_i32x4 = _mm_set1_epi32(a_sum);
1557
+ __m128i ab_sums_i32x4 = _mm_add_epi32(a_sum_broadcast_i32x4, b_sums.xmm);
1558
+ __m128i correction_i32x4 = _mm_slli_epi32(ab_sums_i32x4, 3); // × 8
1559
+ __m128i offset_i32x4 = _mm_set1_epi32((nk_i32_t)(-64LL * (nk_i64_t)depth_nibbles));
1560
+ result->xmm = _mm_add_epi32(_mm_sub_epi32(biased_i32x4, correction_i32x4), offset_i32x4);
1561
+ }
1562
+
1563
+ /**
1564
+ * @brief State for batched u4 dot products on Haswell.
1565
+ * Processes 32 nibbles (16 bytes) per update iteration for optimal ILP.
1566
+ */
1567
+ typedef struct nk_dot_u4x32_state_haswell_t {
1568
+ __m256i product_sum_i32x8; // Main product accumulator
1569
+ } nk_dot_u4x32_state_haswell_t;
1570
+
1571
+ NK_INTERNAL void nk_dot_u4x32_init_haswell(nk_dot_u4x32_state_haswell_t *state) {
1572
+ state->product_sum_i32x8 = _mm256_setzero_si256();
1573
+ }
1574
+
1575
+ NK_INTERNAL void nk_dot_u4x32_update_haswell(nk_dot_u4x32_state_haswell_t *state, nk_b128_vec_t a, nk_b128_vec_t b,
1576
+ nk_size_t depth_offset, nk_size_t active_dimensions) {
1577
+ // Process 32 nibbles (16 bytes) from the full 128-bit vector
1578
+ // No algebraic transformation needed for unsigned values
1579
+ nk_unused_(depth_offset);
1580
+ nk_unused_(active_dimensions);
1581
+
1582
+ __m128i const nibble_mask_u8x16 = _mm_set1_epi8(0x0F);
1583
+
1584
+ __m128i a_u4x32 = a.xmm;
1585
+ __m128i b_u4x32 = b.xmm;
1586
+
1587
+ // Extract low and high nibbles
1588
+ __m128i a_lo_u8x16 = _mm_and_si128(a_u4x32, nibble_mask_u8x16);
1589
+ __m128i a_hi_u8x16 = _mm_and_si128(_mm_srli_epi16(a_u4x32, 4), nibble_mask_u8x16);
1590
+ __m128i b_lo_u8x16 = _mm_and_si128(b_u4x32, nibble_mask_u8x16);
1591
+ __m128i b_hi_u8x16 = _mm_and_si128(_mm_srli_epi16(b_u4x32, 4), nibble_mask_u8x16);
1592
+
1593
+ // Widen u8 to i16
1594
+ __m256i a_lo_i16x16 = _mm256_cvtepu8_epi16(a_lo_u8x16);
1595
+ __m256i a_hi_i16x16 = _mm256_cvtepu8_epi16(a_hi_u8x16);
1596
+ __m256i b_lo_i16x16 = _mm256_cvtepu8_epi16(b_lo_u8x16);
1597
+ __m256i b_hi_i16x16 = _mm256_cvtepu8_epi16(b_hi_u8x16);
1598
+
1599
+ // Multiply and accumulate
1600
+ state->product_sum_i32x8 = _mm256_add_epi32(state->product_sum_i32x8, _mm256_madd_epi16(a_lo_i16x16, b_lo_i16x16));
1601
+ state->product_sum_i32x8 = _mm256_add_epi32(state->product_sum_i32x8, _mm256_madd_epi16(a_hi_i16x16, b_hi_i16x16));
1602
+ }
1603
+
1604
+ NK_INTERNAL void nk_dot_u4x32_finalize_haswell( //
1605
+ nk_dot_u4x32_state_haswell_t const *state_a, nk_dot_u4x32_state_haswell_t const *state_b, //
1606
+ nk_dot_u4x32_state_haswell_t const *state_c, nk_dot_u4x32_state_haswell_t const *state_d, //
1607
+ nk_size_t total_dimensions, nk_b128_vec_t *result) {
1608
+ nk_unused_(total_dimensions);
1609
+
1610
+ // 4-way ILP-optimized reduction (no algebraic correction needed for unsigned)
1611
+ // Reduce main products from ymm (i32x8) to xmm (i32x4)
1612
+ __m128i product_a_i32x4 = _mm_add_epi32(_mm256_castsi256_si128(state_a->product_sum_i32x8),
1613
+ _mm256_extracti128_si256(state_a->product_sum_i32x8, 1));
1614
+ __m128i product_b_i32x4 = _mm_add_epi32(_mm256_castsi256_si128(state_b->product_sum_i32x8),
1615
+ _mm256_extracti128_si256(state_b->product_sum_i32x8, 1));
1616
+ __m128i product_c_i32x4 = _mm_add_epi32(_mm256_castsi256_si128(state_c->product_sum_i32x8),
1617
+ _mm256_extracti128_si256(state_c->product_sum_i32x8, 1));
1618
+ __m128i product_d_i32x4 = _mm_add_epi32(_mm256_castsi256_si128(state_d->product_sum_i32x8),
1619
+ _mm256_extracti128_si256(state_d->product_sum_i32x8, 1));
1620
+
1621
+ // 4-way transpose to get [a,b,c,d] in lanes
1622
+ __m128i transpose_ab_low = _mm_unpacklo_epi32(product_a_i32x4, product_b_i32x4);
1623
+ __m128i transpose_cd_low = _mm_unpacklo_epi32(product_c_i32x4, product_d_i32x4);
1624
+ __m128i transpose_ab_high = _mm_unpackhi_epi32(product_a_i32x4, product_b_i32x4);
1625
+ __m128i transpose_cd_high = _mm_unpackhi_epi32(product_c_i32x4, product_d_i32x4);
1626
+ __m128i product_lane0 = _mm_unpacklo_epi64(transpose_ab_low, transpose_cd_low);
1627
+ __m128i product_lane1 = _mm_unpackhi_epi64(transpose_ab_low, transpose_cd_low);
1628
+ __m128i product_lane2 = _mm_unpacklo_epi64(transpose_ab_high, transpose_cd_high);
1629
+ __m128i product_lane3 = _mm_unpackhi_epi64(transpose_ab_high, transpose_cd_high);
1630
+
1631
+ // Sum product lanes
1632
+ result->xmm = _mm_add_epi32(_mm_add_epi32(product_lane0, product_lane1),
1633
+ _mm_add_epi32(product_lane2, product_lane3));
1634
+ }
1635
+
1636
+ #pragma endregion - Small Integers
1637
+
1638
+ #pragma region - Binary
1639
+
1640
+ NK_PUBLIC void nk_dot_u1_haswell(nk_u1x8_t const *a, nk_u1x8_t const *b, nk_size_t n_bits, nk_u32_t *result) {
1641
+ nk_size_t n_bytes = nk_size_divide_round_up_(n_bits, NK_BITS_PER_BYTE);
1642
+ nk_u32_t dot = 0;
1643
+ for (; n_bytes >= 8; n_bytes -= 8, a += 8, b += 8)
1644
+ dot += (nk_u32_t)_mm_popcnt_u64(*(nk_u64_t const *)a & *(nk_u64_t const *)b);
1645
+ for (; n_bytes; --n_bytes, ++a, ++b) dot += (nk_u32_t)_mm_popcnt_u32(*a & *b);
1646
+ *result = dot;
1647
+ }
1648
+
1649
+ typedef struct nk_dot_u1x128_state_haswell_t {
1650
+ nk_u32_t dot_count;
1651
+ } nk_dot_u1x128_state_haswell_t;
1652
+
1653
+ NK_INTERNAL void nk_dot_u1x128_init_haswell(nk_dot_u1x128_state_haswell_t *state) { state->dot_count = 0; }
1654
+
1655
+ NK_INTERNAL void nk_dot_u1x128_update_haswell(nk_dot_u1x128_state_haswell_t *state, nk_b128_vec_t a, nk_b128_vec_t b,
1656
+ nk_size_t depth_offset, nk_size_t active_dimensions) {
1657
+ nk_unused_(depth_offset);
1658
+ nk_unused_(active_dimensions);
1659
+ state->dot_count += (nk_u32_t)_mm_popcnt_u64(a.u64s[0] & b.u64s[0]);
1660
+ state->dot_count += (nk_u32_t)_mm_popcnt_u64(a.u64s[1] & b.u64s[1]);
1661
+ }
1662
+
1663
+ NK_INTERNAL void nk_dot_u1x128_finalize_haswell( //
1664
+ nk_dot_u1x128_state_haswell_t const *state_a, nk_dot_u1x128_state_haswell_t const *state_b,
1665
+ nk_dot_u1x128_state_haswell_t const *state_c, nk_dot_u1x128_state_haswell_t const *state_d,
1666
+ nk_size_t total_dimensions, nk_b128_vec_t *result) {
1667
+ nk_unused_(total_dimensions);
1668
+ result->u32s[0] = state_a->dot_count;
1669
+ result->u32s[1] = state_b->dot_count;
1670
+ result->u32s[2] = state_c->dot_count;
1671
+ result->u32s[3] = state_d->dot_count;
1672
+ }
1673
+
1674
+ #pragma endregion - Binary
1675
+
1676
+ #if defined(__clang__)
1677
+ #pragma clang attribute pop
1678
+ #elif defined(__GNUC__)
1679
+ #pragma GCC pop_options
1680
+ #endif
1681
+
1682
+ #if defined(__cplusplus)
1683
+ } // extern "C"
1684
+ #endif
1685
+
1686
+ #endif // NK_TARGET_HASWELL
1687
+ #endif // NK_TARGET_X86_
1688
+ #endif // NK_DOT_HASWELL_H