numkong 7.0.0 → 7.4.2

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (315)
  1. package/README.md +197 -124
  2. package/binding.gyp +34 -484
  3. package/c/dispatch_bf16.c +59 -1
  4. package/c/dispatch_e2m3.c +41 -8
  5. package/c/dispatch_e3m2.c +49 -8
  6. package/c/dispatch_e4m3.c +51 -9
  7. package/c/dispatch_e5m2.c +45 -1
  8. package/c/dispatch_f16.c +79 -26
  9. package/c/dispatch_f16c.c +5 -5
  10. package/c/dispatch_f32.c +56 -0
  11. package/c/dispatch_f64.c +52 -0
  12. package/c/dispatch_i4.c +3 -0
  13. package/c/dispatch_i8.c +62 -3
  14. package/c/dispatch_other.c +18 -0
  15. package/c/dispatch_u1.c +54 -9
  16. package/c/dispatch_u4.c +3 -0
  17. package/c/dispatch_u8.c +64 -3
  18. package/c/numkong.c +3 -0
  19. package/include/README.md +79 -9
  20. package/include/numkong/attention/sapphireamx.h +278 -276
  21. package/include/numkong/attention/sme.h +983 -977
  22. package/include/numkong/attention.h +1 -1
  23. package/include/numkong/capabilities.h +289 -94
  24. package/include/numkong/cast/README.md +40 -40
  25. package/include/numkong/cast/diamond.h +64 -0
  26. package/include/numkong/cast/haswell.h +42 -194
  27. package/include/numkong/cast/icelake.h +42 -37
  28. package/include/numkong/cast/loongsonasx.h +252 -0
  29. package/include/numkong/cast/neon.h +216 -249
  30. package/include/numkong/cast/powervsx.h +449 -0
  31. package/include/numkong/cast/rvv.h +223 -274
  32. package/include/numkong/cast/sapphire.h +18 -18
  33. package/include/numkong/cast/serial.h +1018 -944
  34. package/include/numkong/cast/skylake.h +82 -23
  35. package/include/numkong/cast/v128relaxed.h +462 -105
  36. package/include/numkong/cast.h +24 -0
  37. package/include/numkong/cast.hpp +44 -0
  38. package/include/numkong/curved/README.md +17 -17
  39. package/include/numkong/curved/neon.h +131 -7
  40. package/include/numkong/curved/neonbfdot.h +6 -7
  41. package/include/numkong/curved/rvv.h +26 -26
  42. package/include/numkong/curved/smef64.h +186 -182
  43. package/include/numkong/curved.h +14 -18
  44. package/include/numkong/dot/README.md +154 -137
  45. package/include/numkong/dot/alder.h +43 -43
  46. package/include/numkong/dot/diamond.h +158 -0
  47. package/include/numkong/dot/genoa.h +4 -30
  48. package/include/numkong/dot/haswell.h +215 -180
  49. package/include/numkong/dot/icelake.h +190 -76
  50. package/include/numkong/dot/loongsonasx.h +671 -0
  51. package/include/numkong/dot/neon.h +124 -73
  52. package/include/numkong/dot/neonbfdot.h +11 -12
  53. package/include/numkong/dot/neonfhm.h +44 -46
  54. package/include/numkong/dot/neonfp8.h +323 -0
  55. package/include/numkong/dot/neonsdot.h +190 -76
  56. package/include/numkong/dot/powervsx.h +752 -0
  57. package/include/numkong/dot/rvv.h +92 -84
  58. package/include/numkong/dot/rvvbf16.h +12 -12
  59. package/include/numkong/dot/rvvhalf.h +12 -12
  60. package/include/numkong/dot/sapphire.h +4 -4
  61. package/include/numkong/dot/serial.h +66 -30
  62. package/include/numkong/dot/sierra.h +31 -31
  63. package/include/numkong/dot/skylake.h +142 -110
  64. package/include/numkong/dot/sve.h +217 -177
  65. package/include/numkong/dot/svebfdot.h +10 -10
  66. package/include/numkong/dot/svehalf.h +85 -41
  67. package/include/numkong/dot/svesdot.h +89 -0
  68. package/include/numkong/dot/v128relaxed.h +124 -89
  69. package/include/numkong/dot.h +114 -48
  70. package/include/numkong/dots/README.md +203 -203
  71. package/include/numkong/dots/alder.h +12 -9
  72. package/include/numkong/dots/diamond.h +86 -0
  73. package/include/numkong/dots/genoa.h +10 -4
  74. package/include/numkong/dots/haswell.h +63 -48
  75. package/include/numkong/dots/icelake.h +27 -18
  76. package/include/numkong/dots/loongsonasx.h +176 -0
  77. package/include/numkong/dots/neon.h +14 -11
  78. package/include/numkong/dots/neonbfdot.h +4 -3
  79. package/include/numkong/dots/neonfhm.h +11 -9
  80. package/include/numkong/dots/neonfp8.h +99 -0
  81. package/include/numkong/dots/neonsdot.h +48 -12
  82. package/include/numkong/dots/powervsx.h +194 -0
  83. package/include/numkong/dots/rvv.h +451 -344
  84. package/include/numkong/dots/sapphireamx.h +1028 -984
  85. package/include/numkong/dots/serial.h +213 -197
  86. package/include/numkong/dots/sierra.h +10 -7
  87. package/include/numkong/dots/skylake.h +47 -36
  88. package/include/numkong/dots/sme.h +2001 -2364
  89. package/include/numkong/dots/smebi32.h +175 -162
  90. package/include/numkong/dots/smef64.h +328 -323
  91. package/include/numkong/dots/v128relaxed.h +64 -41
  92. package/include/numkong/dots.h +573 -293
  93. package/include/numkong/dots.hpp +45 -43
  94. package/include/numkong/each/README.md +133 -137
  95. package/include/numkong/each/haswell.h +6 -6
  96. package/include/numkong/each/icelake.h +7 -7
  97. package/include/numkong/each/neon.h +76 -42
  98. package/include/numkong/each/neonbfdot.h +11 -12
  99. package/include/numkong/each/neonhalf.h +24 -116
  100. package/include/numkong/each/rvv.h +28 -28
  101. package/include/numkong/each/sapphire.h +27 -161
  102. package/include/numkong/each/serial.h +6 -6
  103. package/include/numkong/each/skylake.h +7 -7
  104. package/include/numkong/each/v128relaxed.h +562 -0
  105. package/include/numkong/each.h +148 -62
  106. package/include/numkong/each.hpp +2 -2
  107. package/include/numkong/geospatial/README.md +18 -18
  108. package/include/numkong/geospatial/haswell.h +365 -325
  109. package/include/numkong/geospatial/neon.h +350 -306
  110. package/include/numkong/geospatial/rvv.h +4 -4
  111. package/include/numkong/geospatial/skylake.h +376 -340
  112. package/include/numkong/geospatial/v128relaxed.h +366 -327
  113. package/include/numkong/geospatial.h +17 -17
  114. package/include/numkong/matrix.hpp +4 -4
  115. package/include/numkong/maxsim/README.md +14 -14
  116. package/include/numkong/maxsim/alder.h +6 -6
  117. package/include/numkong/maxsim/genoa.h +4 -4
  118. package/include/numkong/maxsim/haswell.h +6 -6
  119. package/include/numkong/maxsim/icelake.h +18 -18
  120. package/include/numkong/maxsim/neonsdot.h +21 -21
  121. package/include/numkong/maxsim/sapphireamx.h +14 -14
  122. package/include/numkong/maxsim/serial.h +6 -6
  123. package/include/numkong/maxsim/sme.h +221 -196
  124. package/include/numkong/maxsim/v128relaxed.h +6 -6
  125. package/include/numkong/mesh/README.md +62 -56
  126. package/include/numkong/mesh/haswell.h +339 -464
  127. package/include/numkong/mesh/neon.h +1100 -519
  128. package/include/numkong/mesh/neonbfdot.h +36 -68
  129. package/include/numkong/mesh/rvv.h +530 -435
  130. package/include/numkong/mesh/serial.h +75 -91
  131. package/include/numkong/mesh/skylake.h +1627 -302
  132. package/include/numkong/mesh/v128relaxed.h +443 -330
  133. package/include/numkong/mesh.h +63 -49
  134. package/include/numkong/mesh.hpp +4 -4
  135. package/include/numkong/numkong.h +3 -3
  136. package/include/numkong/numkong.hpp +1 -0
  137. package/include/numkong/probability/README.md +23 -19
  138. package/include/numkong/probability/neon.h +82 -52
  139. package/include/numkong/probability/rvv.h +28 -23
  140. package/include/numkong/probability/serial.h +51 -39
  141. package/include/numkong/probability.h +20 -23
  142. package/include/numkong/random.h +1 -1
  143. package/include/numkong/reduce/README.md +143 -138
  144. package/include/numkong/reduce/alder.h +81 -77
  145. package/include/numkong/reduce/haswell.h +222 -220
  146. package/include/numkong/reduce/neon.h +629 -519
  147. package/include/numkong/reduce/neonbfdot.h +7 -218
  148. package/include/numkong/reduce/neonfhm.h +9 -381
  149. package/include/numkong/reduce/neonsdot.h +9 -9
  150. package/include/numkong/reduce/rvv.h +928 -802
  151. package/include/numkong/reduce/serial.h +23 -27
  152. package/include/numkong/reduce/sierra.h +20 -20
  153. package/include/numkong/reduce/skylake.h +326 -324
  154. package/include/numkong/reduce/v128relaxed.h +52 -52
  155. package/include/numkong/reduce.h +4 -23
  156. package/include/numkong/reduce.hpp +156 -11
  157. package/include/numkong/scalar/README.md +6 -6
  158. package/include/numkong/scalar/haswell.h +26 -17
  159. package/include/numkong/scalar/loongsonasx.h +74 -0
  160. package/include/numkong/scalar/neon.h +9 -9
  161. package/include/numkong/scalar/powervsx.h +96 -0
  162. package/include/numkong/scalar/rvv.h +2 -2
  163. package/include/numkong/scalar/sapphire.h +21 -10
  164. package/include/numkong/scalar/serial.h +21 -21
  165. package/include/numkong/scalar.h +13 -0
  166. package/include/numkong/set/README.md +28 -28
  167. package/include/numkong/set/haswell.h +12 -12
  168. package/include/numkong/set/icelake.h +14 -14
  169. package/include/numkong/set/loongsonasx.h +181 -0
  170. package/include/numkong/set/neon.h +17 -18
  171. package/include/numkong/set/powervsx.h +326 -0
  172. package/include/numkong/set/rvv.h +4 -4
  173. package/include/numkong/set/serial.h +6 -6
  174. package/include/numkong/set/sve.h +60 -59
  175. package/include/numkong/set/v128relaxed.h +6 -6
  176. package/include/numkong/set.h +21 -7
  177. package/include/numkong/sets/README.md +26 -26
  178. package/include/numkong/sets/loongsonasx.h +52 -0
  179. package/include/numkong/sets/powervsx.h +65 -0
  180. package/include/numkong/sets/smebi32.h +395 -364
  181. package/include/numkong/sets.h +83 -40
  182. package/include/numkong/sparse/README.md +4 -4
  183. package/include/numkong/sparse/icelake.h +101 -101
  184. package/include/numkong/sparse/serial.h +1 -1
  185. package/include/numkong/sparse/sve2.h +137 -141
  186. package/include/numkong/sparse/turin.h +12 -12
  187. package/include/numkong/sparse.h +10 -10
  188. package/include/numkong/spatial/README.md +230 -226
  189. package/include/numkong/spatial/alder.h +113 -116
  190. package/include/numkong/spatial/diamond.h +240 -0
  191. package/include/numkong/spatial/genoa.h +0 -68
  192. package/include/numkong/spatial/haswell.h +74 -55
  193. package/include/numkong/spatial/icelake.h +539 -58
  194. package/include/numkong/spatial/loongsonasx.h +483 -0
  195. package/include/numkong/spatial/neon.h +125 -52
  196. package/include/numkong/spatial/neonbfdot.h +8 -9
  197. package/include/numkong/spatial/neonfp8.h +258 -0
  198. package/include/numkong/spatial/neonsdot.h +180 -12
  199. package/include/numkong/spatial/powervsx.h +738 -0
  200. package/include/numkong/spatial/rvv.h +146 -139
  201. package/include/numkong/spatial/rvvbf16.h +17 -12
  202. package/include/numkong/spatial/rvvhalf.h +13 -10
  203. package/include/numkong/spatial/serial.h +13 -12
  204. package/include/numkong/spatial/sierra.h +232 -39
  205. package/include/numkong/spatial/skylake.h +73 -74
  206. package/include/numkong/spatial/sve.h +93 -72
  207. package/include/numkong/spatial/svebfdot.h +29 -29
  208. package/include/numkong/spatial/svehalf.h +52 -26
  209. package/include/numkong/spatial/svesdot.h +142 -0
  210. package/include/numkong/spatial/v128relaxed.h +293 -41
  211. package/include/numkong/spatial.h +338 -82
  212. package/include/numkong/spatials/README.md +194 -194
  213. package/include/numkong/spatials/diamond.h +82 -0
  214. package/include/numkong/spatials/haswell.h +2 -2
  215. package/include/numkong/spatials/loongsonasx.h +153 -0
  216. package/include/numkong/spatials/neonfp8.h +111 -0
  217. package/include/numkong/spatials/neonsdot.h +34 -0
  218. package/include/numkong/spatials/powervsx.h +153 -0
  219. package/include/numkong/spatials/rvv.h +259 -243
  220. package/include/numkong/spatials/sapphireamx.h +173 -173
  221. package/include/numkong/spatials/serial.h +2 -2
  222. package/include/numkong/spatials/skylake.h +2 -2
  223. package/include/numkong/spatials/sme.h +590 -605
  224. package/include/numkong/spatials/smef64.h +139 -130
  225. package/include/numkong/spatials/v128relaxed.h +2 -2
  226. package/include/numkong/spatials.h +820 -500
  227. package/include/numkong/spatials.hpp +49 -48
  228. package/include/numkong/tensor.hpp +406 -17
  229. package/include/numkong/trigonometry/README.md +19 -19
  230. package/include/numkong/trigonometry/haswell.h +402 -401
  231. package/include/numkong/trigonometry/neon.h +386 -387
  232. package/include/numkong/trigonometry/rvv.h +52 -51
  233. package/include/numkong/trigonometry/serial.h +13 -13
  234. package/include/numkong/trigonometry/skylake.h +373 -369
  235. package/include/numkong/trigonometry/v128relaxed.h +375 -374
  236. package/include/numkong/trigonometry.h +13 -13
  237. package/include/numkong/trigonometry.hpp +2 -2
  238. package/include/numkong/types.h +287 -49
  239. package/include/numkong/types.hpp +436 -12
  240. package/include/numkong/vector.hpp +82 -14
  241. package/javascript/dist/cjs/numkong-wasm.js +6 -12
  242. package/javascript/dist/cjs/numkong.d.ts +7 -1
  243. package/javascript/dist/cjs/numkong.js +37 -11
  244. package/javascript/dist/cjs/types.d.ts +9 -0
  245. package/javascript/dist/cjs/types.js +96 -0
  246. package/javascript/dist/esm/numkong-browser.d.ts +14 -0
  247. package/javascript/dist/esm/numkong-browser.js +23 -0
  248. package/javascript/dist/esm/numkong-wasm.js +6 -12
  249. package/javascript/dist/esm/numkong.d.ts +7 -1
  250. package/javascript/dist/esm/numkong.js +37 -11
  251. package/javascript/dist/esm/types.d.ts +9 -0
  252. package/javascript/dist/esm/types.js +96 -0
  253. package/javascript/node-gyp-build.d.ts +4 -1
  254. package/javascript/numkong-browser.ts +40 -0
  255. package/javascript/numkong-wasm.ts +7 -13
  256. package/javascript/numkong.c +5 -26
  257. package/javascript/numkong.ts +36 -11
  258. package/javascript/tsconfig-base.json +1 -0
  259. package/javascript/tsconfig-cjs.json +6 -1
  260. package/javascript/types.ts +110 -0
  261. package/numkong.gypi +101 -0
  262. package/package.json +34 -13
  263. package/probes/arm_neon.c +8 -0
  264. package/probes/arm_neon_bfdot.c +9 -0
  265. package/probes/arm_neon_fhm.c +9 -0
  266. package/probes/arm_neon_half.c +8 -0
  267. package/probes/arm_neon_sdot.c +9 -0
  268. package/probes/arm_neonfp8.c +9 -0
  269. package/probes/arm_sme.c +16 -0
  270. package/probes/arm_sme2.c +16 -0
  271. package/probes/arm_sme2p1.c +16 -0
  272. package/probes/arm_sme_bf16.c +16 -0
  273. package/probes/arm_sme_bi32.c +16 -0
  274. package/probes/arm_sme_f64.c +16 -0
  275. package/probes/arm_sme_fa64.c +14 -0
  276. package/probes/arm_sme_half.c +16 -0
  277. package/probes/arm_sme_lut2.c +15 -0
  278. package/probes/arm_sve.c +18 -0
  279. package/probes/arm_sve2.c +20 -0
  280. package/probes/arm_sve2p1.c +18 -0
  281. package/probes/arm_sve_bfdot.c +20 -0
  282. package/probes/arm_sve_half.c +18 -0
  283. package/probes/arm_sve_sdot.c +21 -0
  284. package/probes/loongarch_lasx.c +12 -0
  285. package/probes/power_vsx.c +12 -0
  286. package/probes/probe.js +127 -0
  287. package/probes/riscv_rvv.c +14 -0
  288. package/probes/riscv_rvv_bb.c +15 -0
  289. package/probes/riscv_rvv_bf16.c +17 -0
  290. package/probes/riscv_rvv_half.c +14 -0
  291. package/probes/wasm_v128relaxed.c +11 -0
  292. package/probes/x86_alder.c +17 -0
  293. package/probes/x86_diamond.c +17 -0
  294. package/probes/x86_genoa.c +17 -0
  295. package/probes/x86_graniteamx.c +19 -0
  296. package/probes/x86_haswell.c +11 -0
  297. package/probes/x86_icelake.c +17 -0
  298. package/probes/x86_sapphire.c +16 -0
  299. package/probes/x86_sapphireamx.c +18 -0
  300. package/probes/x86_sierra.c +17 -0
  301. package/probes/x86_skylake.c +15 -0
  302. package/probes/x86_turin.c +17 -0
  303. package/wasm/numkong-emscripten.js +2 -0
  304. package/wasm/numkong.d.ts +14 -0
  305. package/wasm/numkong.js +1124 -0
  306. package/wasm/numkong.wasm +0 -0
  307. package/include/numkong/curved/neonhalf.h +0 -212
  308. package/include/numkong/dot/neonhalf.h +0 -198
  309. package/include/numkong/dots/neonhalf.h +0 -57
  310. package/include/numkong/mesh/neonhalf.h +0 -616
  311. package/include/numkong/reduce/neonhalf.h +0 -157
  312. package/include/numkong/spatial/neonhalf.h +0 -118
  313. package/include/numkong/spatial/sapphire.h +0 -343
  314. package/include/numkong/spatials/neonhalf.h +0 -58
  315. package/javascript/README.md +0 -246
@@ -8,12 +8,12 @@
8
8
  *
9
9
  * @section dot_haswell_instructions Key AVX2/FMA Dot Product Instructions
10
10
  *
11
- * Intrinsic Instruction Latency Throughput Ports
12
- * _mm256_fmadd_ps/pd VFMADD (YMM, YMM, YMM) 5cy 0.5/cy p01
13
- * _mm256_mul_ps/pd VMULPS/PD (YMM, YMM, YMM) 5cy 0.5/cy p01
14
- * _mm256_add_ps/pd VADDPS/PD (YMM, YMM, YMM) 3cy 1/cy p01
15
- * _mm256_cvtph_ps VCVTPH2PS (YMM, XMM) 5cy 1/cy p01
16
- * _mm256_cvtps_pd VCVTPS2PD (YMM, XMM) 2cy 1/cy p01
11
+ * Intrinsic Instruction Haswell Genoa
12
+ * _mm256_fmadd_ps/pd VFMADD (YMM, YMM, YMM) 5cy @ p01 4cy @ p01
13
+ * _mm256_mul_ps/pd VMULPS/PD (YMM, YMM, YMM) 5cy @ p01 3cy @ p01
14
+ * _mm256_add_ps/pd VADDPS/PD (YMM, YMM, YMM) 3cy @ p01 3cy @ p23
15
+ * _mm256_cvtph_ps VCVTPH2PS (YMM, XMM) 5cy @ p01 4cy @ p12+p23
16
+ * _mm256_cvtps_pd VCVTPS2PD (YMM, XMM) 2cy @ p01 4cy @ p12+p23
17
17
  *
18
18
  * For small numeric types (F16, BF16, E4M3, E5M2) we use F32 accumulators. For F32 dot products,
19
19
  * upcasting to F64 and downcasting back is faster than stable summation algorithms. For F64 we
@@ -141,7 +141,7 @@ NK_INTERNAL nk_f64_t nk_dot_stable_sum_f64x4_haswell_(__m256d sum_f64x4, __m256d
141
141
  return tentative_sum + (lower_error + upper_error + rounding_error);
142
142
  }
143
143
 
144
- #pragma region - Traditional Floats
144
+ #pragma region F32 and F64 Floats
145
145
 
146
146
  NK_PUBLIC void nk_dot_f32_haswell(nk_f32_t const *a_scalars, nk_f32_t const *b_scalars, nk_size_t count_scalars,
147
147
  nk_f64_t *result) {
@@ -479,30 +479,35 @@ NK_INTERNAL void nk_dot_f32x4_finalize_haswell(
479
479
  result->ymm_pd = sum_abcd_f64x4;
480
480
  }
481
481
 
482
- #pragma endregion - Traditional Floats
482
+ #pragma endregion F32 and F64 Floats
483
483
 
484
- #pragma region - Smaller Floats
484
+ #pragma region F16 and BF16 Floats
485
485
 
486
486
  NK_PUBLIC void nk_dot_bf16_haswell(nk_bf16_t const *a_scalars, nk_bf16_t const *b_scalars, nk_size_t count_scalars,
487
487
  nk_f32_t *result) {
488
- __m128i a_bf16x8, b_bf16x8;
488
+ __m256i a_bf16_i16x16, b_bf16_i16x16;
489
489
  __m256 sum_f32x8 = _mm256_setzero_ps();
490
+ __m256i mask_high_u32x8 = _mm256_set1_epi32((int)0xFFFF0000);
490
491
  nk_dot_bf16_haswell_cycle:
491
- if (count_scalars < 8) {
492
+ if (count_scalars < 16) {
492
493
  nk_b256_vec_t a_vec, b_vec;
493
494
  nk_partial_load_b16x16_serial_(a_scalars, &a_vec, count_scalars);
494
495
  nk_partial_load_b16x16_serial_(b_scalars, &b_vec, count_scalars);
495
- a_bf16x8 = a_vec.xmms[0];
496
- b_bf16x8 = b_vec.xmms[0];
496
+ a_bf16_i16x16 = a_vec.ymm;
497
+ b_bf16_i16x16 = b_vec.ymm;
497
498
  count_scalars = 0;
498
499
  }
499
500
  else {
500
- a_bf16x8 = _mm_loadu_si128((__m128i const *)a_scalars);
501
- b_bf16x8 = _mm_loadu_si128((__m128i const *)b_scalars);
502
- a_scalars += 8, b_scalars += 8, count_scalars -= 8;
501
+ a_bf16_i16x16 = _mm256_loadu_si256((__m256i const *)a_scalars);
502
+ b_bf16_i16x16 = _mm256_loadu_si256((__m256i const *)b_scalars);
503
+ a_scalars += 16, b_scalars += 16, count_scalars -= 16;
503
504
  }
504
- sum_f32x8 = _mm256_fmadd_ps(nk_bf16x8_to_f32x8_haswell_(a_bf16x8), nk_bf16x8_to_f32x8_haswell_(b_bf16x8),
505
- sum_f32x8);
505
+ __m256 a_even_f32x8 = _mm256_castsi256_ps(_mm256_slli_epi32(a_bf16_i16x16, 16));
506
+ __m256 b_even_f32x8 = _mm256_castsi256_ps(_mm256_slli_epi32(b_bf16_i16x16, 16));
507
+ sum_f32x8 = _mm256_fmadd_ps(a_even_f32x8, b_even_f32x8, sum_f32x8);
508
+ __m256 a_odd_f32x8 = _mm256_castsi256_ps(_mm256_and_si256(a_bf16_i16x16, mask_high_u32x8));
509
+ __m256 b_odd_f32x8 = _mm256_castsi256_ps(_mm256_and_si256(b_bf16_i16x16, mask_high_u32x8));
510
+ sum_f32x8 = _mm256_fmadd_ps(a_odd_f32x8, b_odd_f32x8, sum_f32x8);
506
511
  if (count_scalars) goto nk_dot_bf16_haswell_cycle;
507
512
  *result = (nk_f32_t)nk_reduce_add_f32x8_haswell_(sum_f32x8);
508
513
  }
@@ -534,7 +539,7 @@ NK_PUBLIC void nk_dot_bf16c_haswell(nk_bf16c_t const *a_pairs, nk_bf16c_t const
534
539
  nk_f32c_t *result) {
535
540
  // Convert BF16 to F32, then use F32 complex dot product with sign-flipping optimization.
536
541
  // Uses same XOR trick as f32c to double throughput by deferring sign flips until after loop.
537
- __m128i a_bf16x8, b_bf16x8;
542
+ __m128i a_bf16_i16x8, b_bf16_i16x8;
538
543
  __m256 sum_real_f32x8 = _mm256_setzero_ps();
539
544
  __m256 sum_imag_f32x8 = _mm256_setzero_ps();
540
545
  __m256i const sign_flip_i64x4 = _mm256_set1_epi64x(0x8000000000000000);
@@ -547,19 +552,19 @@ nk_dot_bf16c_haswell_cycle:
547
552
  nk_b256_vec_t a_vec, b_vec;
548
553
  nk_partial_load_b16x16_serial_(a_pairs, &a_vec, count_pairs * 2);
549
554
  nk_partial_load_b16x16_serial_(b_pairs, &b_vec, count_pairs * 2);
550
- a_bf16x8 = a_vec.xmms[0];
551
- b_bf16x8 = b_vec.xmms[0];
555
+ a_bf16_i16x8 = a_vec.xmms[0];
556
+ b_bf16_i16x8 = b_vec.xmms[0];
552
557
  count_pairs = 0;
553
558
  }
554
559
  else {
555
- a_bf16x8 = _mm_loadu_si128((__m128i const *)a_pairs);
556
- b_bf16x8 = _mm_loadu_si128((__m128i const *)b_pairs);
560
+ a_bf16_i16x8 = _mm_loadu_si128((__m128i const *)a_pairs);
561
+ b_bf16_i16x8 = _mm_loadu_si128((__m128i const *)b_pairs);
557
562
  a_pairs += 4, b_pairs += 4, count_pairs -= 4;
558
563
  }
559
564
 
560
565
  // Convert BF16 to F32
561
- __m256 a_f32x8 = nk_bf16x8_to_f32x8_haswell_(a_bf16x8);
562
- __m256 b_f32x8 = nk_bf16x8_to_f32x8_haswell_(b_bf16x8);
566
+ __m256 a_f32x8 = nk_bf16x8_to_f32x8_haswell_(a_bf16_i16x8);
567
+ __m256 b_f32x8 = nk_bf16x8_to_f32x8_haswell_(b_bf16_i16x8);
563
568
 
564
569
  // Complex multiply-accumulate: swap b for imaginary part
565
570
  __m256 b_swapped_f32x8 = _mm256_castsi256_ps(
@@ -579,7 +584,7 @@ nk_dot_bf16c_haswell_cycle:
579
584
  NK_PUBLIC void nk_vdot_bf16c_haswell(nk_bf16c_t const *a_pairs, nk_bf16c_t const *b_pairs, nk_size_t count_pairs,
580
585
  nk_f32c_t *result) {
581
586
  // Conjugate complex dot product: conj(a) * b
582
- __m128i a_bf16x8, b_bf16x8;
587
+ __m128i a_bf16_i16x8, b_bf16_i16x8;
583
588
  __m256 sum_real_f32x8 = _mm256_setzero_ps();
584
589
  __m256 sum_imag_f32x8 = _mm256_setzero_ps();
585
590
  __m256i const sign_flip_i64x4 = _mm256_set1_epi64x(0x8000000000000000);
@@ -592,19 +597,19 @@ nk_vdot_bf16c_haswell_cycle:
592
597
  nk_b256_vec_t a_vec, b_vec;
593
598
  nk_partial_load_b16x16_serial_(a_pairs, &a_vec, count_pairs * 2);
594
599
  nk_partial_load_b16x16_serial_(b_pairs, &b_vec, count_pairs * 2);
595
- a_bf16x8 = a_vec.xmms[0];
596
- b_bf16x8 = b_vec.xmms[0];
600
+ a_bf16_i16x8 = a_vec.xmms[0];
601
+ b_bf16_i16x8 = b_vec.xmms[0];
597
602
  count_pairs = 0;
598
603
  }
599
604
  else {
600
- a_bf16x8 = _mm_loadu_si128((__m128i const *)a_pairs);
601
- b_bf16x8 = _mm_loadu_si128((__m128i const *)b_pairs);
605
+ a_bf16_i16x8 = _mm_loadu_si128((__m128i const *)a_pairs);
606
+ b_bf16_i16x8 = _mm_loadu_si128((__m128i const *)b_pairs);
602
607
  a_pairs += 4, b_pairs += 4, count_pairs -= 4;
603
608
  }
604
609
 
605
610
  // Convert BF16 to F32
606
- __m256 a_f32x8 = nk_bf16x8_to_f32x8_haswell_(a_bf16x8);
607
- __m256 b_f32x8 = nk_bf16x8_to_f32x8_haswell_(b_bf16x8);
611
+ __m256 a_f32x8 = nk_bf16x8_to_f32x8_haswell_(a_bf16_i16x8);
612
+ __m256 b_f32x8 = nk_bf16x8_to_f32x8_haswell_(b_bf16_i16x8);
608
613
 
609
614
  // Conjugate complex multiply-accumulate
610
615
  sum_real_f32x8 = _mm256_fmadd_ps(a_f32x8, b_f32x8, sum_real_f32x8);
@@ -724,10 +729,10 @@ NK_PUBLIC void nk_dot_e2m3_haswell(nk_e2m3_t const *a_scalars, nk_e2m3_t const *
724
729
  // lut_lower[0..15]: {0,2,4,6,8,10,12,14, 16,18,20,22,24,26,28,30}
725
730
  // lut_upper[0..15]: {32,36,40,44,48,52,56,60, 64,72,80,88,96,104,112,120}
726
731
  //
727
- __m256i const lut_lower_u8x32 = _mm256_set_epi8(30, 28, 26, 24, 22, 20, 18, 16, 14, 12, 10, 8, 6, 4, 2, 0, 30, 28,
728
- 26, 24, 22, 20, 18, 16, 14, 12, 10, 8, 6, 4, 2, 0);
729
- __m256i const lut_upper_u8x32 = _mm256_set_epi8(120, 112, 104, 96, 88, 80, 72, 64, 60, 56, 52, 48, 44, 40, 36, 32,
730
- 120, 112, 104, 96, 88, 80, 72, 64, 60, 56, 52, 48, 44, 40, 36, 32);
732
+ __m256i const lut_low_u8x32 = _mm256_set_epi8(30, 28, 26, 24, 22, 20, 18, 16, 14, 12, 10, 8, 6, 4, 2, 0, 30, 28, 26,
733
+ 24, 22, 20, 18, 16, 14, 12, 10, 8, 6, 4, 2, 0);
734
+ __m256i const lut_high_u8x32 = _mm256_set_epi8(120, 112, 104, 96, 88, 80, 72, 64, 60, 56, 52, 48, 44, 40, 36, 32,
735
+ 120, 112, 104, 96, 88, 80, 72, 64, 60, 56, 52, 48, 44, 40, 36, 32);
731
736
  __m256i const nibble_mask_u8x32 = _mm256_set1_epi8(0x0F);
732
737
  __m256i const magnitude_mask_u8x32 = _mm256_set1_epi8(0x1F);
733
738
  __m256i const half_select_u8x32 = _mm256_set1_epi8(0x10);
@@ -756,18 +761,18 @@ nk_dot_e2m3_haswell_cycle:
756
761
  __m256i b_magnitude_u8x32 = _mm256_and_si256(b_e2m3_u8x32, magnitude_mask_u8x32);
757
762
  __m256i a_shuffle_index_u8x32 = _mm256_and_si256(a_magnitude_u8x32, nibble_mask_u8x32);
758
763
  __m256i b_shuffle_index_u8x32 = _mm256_and_si256(b_magnitude_u8x32, nibble_mask_u8x32);
759
- __m256i a_upper_select_u8x32 = _mm256_cmpeq_epi8(_mm256_and_si256(a_magnitude_u8x32, half_select_u8x32),
760
- half_select_u8x32);
761
- __m256i b_upper_select_u8x32 = _mm256_cmpeq_epi8(_mm256_and_si256(b_magnitude_u8x32, half_select_u8x32),
762
- half_select_u8x32);
764
+ __m256i a_high_select_u8x32 = _mm256_cmpeq_epi8(_mm256_and_si256(a_magnitude_u8x32, half_select_u8x32),
765
+ half_select_u8x32);
766
+ __m256i b_high_select_u8x32 = _mm256_cmpeq_epi8(_mm256_and_si256(b_magnitude_u8x32, half_select_u8x32),
767
+ half_select_u8x32);
763
768
 
764
769
  // Dual VPSHUFB: lookup in both halves, blend based on bit 4
765
- __m256i a_unsigned_u8x32 = _mm256_blendv_epi8(_mm256_shuffle_epi8(lut_lower_u8x32, a_shuffle_index_u8x32),
766
- _mm256_shuffle_epi8(lut_upper_u8x32, a_shuffle_index_u8x32),
767
- a_upper_select_u8x32);
768
- __m256i b_unsigned_u8x32 = _mm256_blendv_epi8(_mm256_shuffle_epi8(lut_lower_u8x32, b_shuffle_index_u8x32),
769
- _mm256_shuffle_epi8(lut_upper_u8x32, b_shuffle_index_u8x32),
770
- b_upper_select_u8x32);
770
+ __m256i a_unsigned_u8x32 = _mm256_blendv_epi8(_mm256_shuffle_epi8(lut_low_u8x32, a_shuffle_index_u8x32),
771
+ _mm256_shuffle_epi8(lut_high_u8x32, a_shuffle_index_u8x32),
772
+ a_high_select_u8x32);
773
+ __m256i b_unsigned_u8x32 = _mm256_blendv_epi8(_mm256_shuffle_epi8(lut_low_u8x32, b_shuffle_index_u8x32),
774
+ _mm256_shuffle_epi8(lut_high_u8x32, b_shuffle_index_u8x32),
775
+ b_high_select_u8x32);
771
776
 
772
777
  // Combined sign: (a ^ b) & 0x20, negate b where signs differ
773
778
  __m256i sign_combined_u8x32 = _mm256_and_si256(_mm256_xor_si256(a_e2m3_u8x32, b_e2m3_u8x32), sign_mask_u8x32);
@@ -795,10 +800,10 @@ NK_PUBLIC void nk_dot_e3m2_haswell(nk_e3m2_t const *a_scalars, nk_e3m2_t const *
795
800
  // lut_upper[0..15]: low bytes of {32,40,48,56,64,80,96,112,128,160,192,224,256,320,384,448}
796
801
  // High byte is 1 iff magnitude index >= 28 (values 256-448), else 0.
797
802
  //
798
- __m256i const lut_lo_lower_u8x32 = _mm256_set_epi8( //
803
+ __m256i const lut_low_byte_first_u8x32 = _mm256_set_epi8( //
799
804
  28, 24, 20, 16, 14, 12, 10, 8, 7, 6, 5, 4, 3, 2, 1, 0, //
800
805
  28, 24, 20, 16, 14, 12, 10, 8, 7, 6, 5, 4, 3, 2, 1, 0);
801
- __m256i const lut_lo_upper_u8x32 = _mm256_set_epi8( //
806
+ __m256i const lut_low_byte_second_u8x32 = _mm256_set_epi8( //
802
807
  (char)192, (char)128, 64, 0, (char)224, (char)192, (char)160, (char)128, 112, 96, 80, 64, 56, 48, 40, 32, //
803
808
  (char)192, (char)128, 64, 0, (char)224, (char)192, (char)160, (char)128, 112, 96, 80, 64, 56, 48, 40, 32);
804
809
  __m256i const nibble_mask_u8x32 = _mm256_set1_epi8(0x0F);
@@ -831,42 +836,44 @@ nk_dot_e3m2_haswell_cycle:
831
836
  __m256i b_magnitude_u8x32 = _mm256_and_si256(b_e3m2_u8x32, magnitude_mask_u8x32);
832
837
  __m256i a_shuffle_index_u8x32 = _mm256_and_si256(a_magnitude_u8x32, nibble_mask_u8x32);
833
838
  __m256i b_shuffle_index_u8x32 = _mm256_and_si256(b_magnitude_u8x32, nibble_mask_u8x32);
834
- __m256i a_upper_select_u8x32 = _mm256_cmpeq_epi8(_mm256_and_si256(a_magnitude_u8x32, half_select_u8x32),
835
- half_select_u8x32);
836
- __m256i b_upper_select_u8x32 = _mm256_cmpeq_epi8(_mm256_and_si256(b_magnitude_u8x32, half_select_u8x32),
837
- half_select_u8x32);
839
+ __m256i a_high_select_u8x32 = _mm256_cmpeq_epi8(_mm256_and_si256(a_magnitude_u8x32, half_select_u8x32),
840
+ half_select_u8x32);
841
+ __m256i b_high_select_u8x32 = _mm256_cmpeq_epi8(_mm256_and_si256(b_magnitude_u8x32, half_select_u8x32),
842
+ half_select_u8x32);
838
843
 
839
844
  // Dual VPSHUFB: lookup low bytes in both halves, blend based on bit 4
840
- __m256i a_lo_bytes_u8x32 = _mm256_blendv_epi8(_mm256_shuffle_epi8(lut_lo_lower_u8x32, a_shuffle_index_u8x32),
841
- _mm256_shuffle_epi8(lut_lo_upper_u8x32, a_shuffle_index_u8x32),
842
- a_upper_select_u8x32);
843
- __m256i b_lo_bytes_u8x32 = _mm256_blendv_epi8(_mm256_shuffle_epi8(lut_lo_lower_u8x32, b_shuffle_index_u8x32),
844
- _mm256_shuffle_epi8(lut_lo_upper_u8x32, b_shuffle_index_u8x32),
845
- b_upper_select_u8x32);
845
+ __m256i a_low_byte_u8x32 = _mm256_blendv_epi8(_mm256_shuffle_epi8(lut_low_byte_first_u8x32, a_shuffle_index_u8x32),
846
+ _mm256_shuffle_epi8(lut_low_byte_second_u8x32, a_shuffle_index_u8x32),
847
+ a_high_select_u8x32);
848
+ __m256i b_low_byte_u8x32 = _mm256_blendv_epi8(_mm256_shuffle_epi8(lut_low_byte_first_u8x32, b_shuffle_index_u8x32),
849
+ _mm256_shuffle_epi8(lut_low_byte_second_u8x32, b_shuffle_index_u8x32),
850
+ b_high_select_u8x32);
846
851
 
847
852
  // High byte: 1 iff magnitude >= 28 (signed compare safe: 27 < 128)
848
- __m256i a_hi_bytes_u8x32 = _mm256_and_si256(_mm256_cmpgt_epi8(a_magnitude_u8x32, high_threshold_u8x32), ones_u8x32);
849
- __m256i b_hi_bytes_u8x32 = _mm256_and_si256(_mm256_cmpgt_epi8(b_magnitude_u8x32, high_threshold_u8x32), ones_u8x32);
853
+ __m256i a_high_byte_u8x32 = _mm256_and_si256(_mm256_cmpgt_epi8(a_magnitude_u8x32, high_threshold_u8x32),
854
+ ones_u8x32);
855
+ __m256i b_high_byte_u8x32 = _mm256_and_si256(_mm256_cmpgt_epi8(b_magnitude_u8x32, high_threshold_u8x32),
856
+ ones_u8x32);
850
857
 
851
858
  // Interleave low and high bytes into i16 (little-endian: low byte first)
852
- __m256i a_lo_i16x16 = _mm256_unpacklo_epi8(a_lo_bytes_u8x32, a_hi_bytes_u8x32);
853
- __m256i a_hi_i16x16 = _mm256_unpackhi_epi8(a_lo_bytes_u8x32, a_hi_bytes_u8x32);
854
- __m256i b_lo_i16x16 = _mm256_unpacklo_epi8(b_lo_bytes_u8x32, b_hi_bytes_u8x32);
855
- __m256i b_hi_i16x16 = _mm256_unpackhi_epi8(b_lo_bytes_u8x32, b_hi_bytes_u8x32);
859
+ __m256i a_low_i16x16 = _mm256_unpacklo_epi8(a_low_byte_u8x32, a_high_byte_u8x32);
860
+ __m256i a_high_i16x16 = _mm256_unpackhi_epi8(a_low_byte_u8x32, a_high_byte_u8x32);
861
+ __m256i b_low_i16x16 = _mm256_unpacklo_epi8(b_low_byte_u8x32, b_high_byte_u8x32);
862
+ __m256i b_high_i16x16 = _mm256_unpackhi_epi8(b_low_byte_u8x32, b_high_byte_u8x32);
856
863
 
857
864
  // Combined sign: (a ^ b) & 0x20, widen to i16 via unpack, create +1/-1 sign vector
858
865
  __m256i sign_combined_u8x32 = _mm256_and_si256(_mm256_xor_si256(a_e3m2_u8x32, b_e3m2_u8x32), sign_mask_u8x32);
859
866
  __m256i negate_mask_u8x32 = _mm256_cmpeq_epi8(sign_combined_u8x32, sign_mask_u8x32);
860
- __m256i negate_lo_i16x16 = _mm256_unpacklo_epi8(negate_mask_u8x32, negate_mask_u8x32);
861
- __m256i negate_hi_i16x16 = _mm256_unpackhi_epi8(negate_mask_u8x32, negate_mask_u8x32);
862
- __m256i sign_lo_i16x16 = _mm256_or_si256(negate_lo_i16x16, ones_i16x16);
863
- __m256i sign_hi_i16x16 = _mm256_or_si256(negate_hi_i16x16, ones_i16x16);
864
- __m256i b_signed_lo_i16x16 = _mm256_sign_epi16(b_lo_i16x16, sign_lo_i16x16);
865
- __m256i b_signed_hi_i16x16 = _mm256_sign_epi16(b_hi_i16x16, sign_hi_i16x16);
867
+ __m256i negate_low_i16x16 = _mm256_unpacklo_epi8(negate_mask_u8x32, negate_mask_u8x32);
868
+ __m256i negate_high_i16x16 = _mm256_unpackhi_epi8(negate_mask_u8x32, negate_mask_u8x32);
869
+ __m256i sign_low_i16x16 = _mm256_or_si256(negate_low_i16x16, ones_i16x16);
870
+ __m256i sign_high_i16x16 = _mm256_or_si256(negate_high_i16x16, ones_i16x16);
871
+ __m256i b_signed_low_i16x16 = _mm256_sign_epi16(b_low_i16x16, sign_low_i16x16);
872
+ __m256i b_signed_high_i16x16 = _mm256_sign_epi16(b_high_i16x16, sign_high_i16x16);
866
873
 
867
874
  // VPMADDWD: a_unsigned_i16 × b_signed_i16 → i32 accumulator
868
- sum_i32x8 = _mm256_add_epi32(sum_i32x8, _mm256_madd_epi16(a_lo_i16x16, b_signed_lo_i16x16));
869
- sum_i32x8 = _mm256_add_epi32(sum_i32x8, _mm256_madd_epi16(a_hi_i16x16, b_signed_hi_i16x16));
875
+ sum_i32x8 = _mm256_add_epi32(sum_i32x8, _mm256_madd_epi16(a_low_i16x16, b_signed_low_i16x16));
876
+ sum_i32x8 = _mm256_add_epi32(sum_i32x8, _mm256_madd_epi16(a_high_i16x16, b_signed_high_i16x16));
870
877
 
871
878
  if (count_scalars) goto nk_dot_e3m2_haswell_cycle;
872
879
  *result = (nk_f32_t)nk_reduce_add_i32x8_haswell_(sum_i32x8) / 256.0f;
@@ -946,10 +953,34 @@ NK_INTERNAL void nk_dot_through_f32_finalize_haswell_(
946
953
  typedef struct nk_dot_through_f32_state_haswell_t_ nk_dot_f16x8_state_haswell_t;
947
954
 
948
955
  /**
949
- * @brief Running state for 128-bit dot accumulation over bf16 scalars on Haswell.
950
- * @note Alias of nk_dot_through_f32_state_haswell_t_
956
+ * @brief Running state for 256-bit dot accumulation over bf16 scalars on Haswell.
957
+ * @note Processes 16 bf16 per tile step via unpack(zero, bf16) → 2×8 f32 FMA.
951
958
  */
952
- typedef struct nk_dot_through_f32_state_haswell_t_ nk_dot_bf16x8_state_haswell_t;
959
+ typedef struct nk_dot_through_f32_state_haswell_t_ nk_dot_bf16x16_state_haswell_t;
960
+
961
+ NK_INTERNAL void nk_dot_bf16x16_init_haswell(nk_dot_bf16x16_state_haswell_t *state) {
962
+ nk_dot_through_f32_init_haswell_(state);
963
+ }
964
+
965
+ NK_INTERNAL void nk_dot_bf16x16_update_haswell(nk_dot_bf16x16_state_haswell_t *state, nk_b256_vec_t a, nk_b256_vec_t b,
966
+ nk_size_t depth_offset, nk_size_t active_dimensions) {
967
+ nk_unused_(depth_offset);
968
+ nk_unused_(active_dimensions);
969
+ __m256i mask_high_u32x8 = _mm256_set1_epi32((int)0xFFFF0000);
970
+ __m256 a_even_f32x8 = _mm256_castsi256_ps(_mm256_slli_epi32(a.ymm, 16));
971
+ __m256 b_even_f32x8 = _mm256_castsi256_ps(_mm256_slli_epi32(b.ymm, 16));
972
+ state->sum_f32x8 = _mm256_fmadd_ps(a_even_f32x8, b_even_f32x8, state->sum_f32x8);
973
+ __m256 a_odd_f32x8 = _mm256_castsi256_ps(_mm256_and_si256(a.ymm, mask_high_u32x8));
974
+ __m256 b_odd_f32x8 = _mm256_castsi256_ps(_mm256_and_si256(b.ymm, mask_high_u32x8));
975
+ state->sum_f32x8 = _mm256_fmadd_ps(a_odd_f32x8, b_odd_f32x8, state->sum_f32x8);
976
+ }
977
+
978
+ NK_INTERNAL void nk_dot_bf16x16_finalize_haswell( //
979
+ nk_dot_bf16x16_state_haswell_t const *state_a, nk_dot_bf16x16_state_haswell_t const *state_b, //
980
+ nk_dot_bf16x16_state_haswell_t const *state_c, nk_dot_bf16x16_state_haswell_t const *state_d, //
981
+ nk_size_t total_dimensions, nk_b128_vec_t *result) {
982
+ nk_dot_through_f32_finalize_haswell_(state_a, state_b, state_c, state_d, total_dimensions, result);
983
+ }
953
984
 
954
985
  /**
955
986
  * @brief Running state for 128-bit dot accumulation over e4m3 scalars on Haswell.
@@ -991,10 +1022,10 @@ NK_INTERNAL void nk_dot_e2m3x32_update_haswell(nk_dot_e2m3x32_state_haswell_t *s
991
1022
  nk_size_t depth_offset, nk_size_t active_dimensions) {
992
1023
  nk_unused_(depth_offset);
993
1024
  nk_unused_(active_dimensions);
994
- __m256i const lut_lower_u8x32 = _mm256_set_epi8( //
1025
+ __m256i const lut_low_u8x32 = _mm256_set_epi8( //
995
1026
  30, 28, 26, 24, 22, 20, 18, 16, 14, 12, 10, 8, 6, 4, 2, 0, //
996
1027
  30, 28, 26, 24, 22, 20, 18, 16, 14, 12, 10, 8, 6, 4, 2, 0);
997
- __m256i const lut_upper_u8x32 = _mm256_set_epi8( //
1028
+ __m256i const lut_high_u8x32 = _mm256_set_epi8( //
998
1029
  120, 112, 104, 96, 88, 80, 72, 64, 60, 56, 52, 48, 44, 40, 36, 32, //
999
1030
  120, 112, 104, 96, 88, 80, 72, 64, 60, 56, 52, 48, 44, 40, 36, 32);
1000
1031
  __m256i const nibble_mask_u8x32 = _mm256_set1_epi8(0x0F);
@@ -1011,18 +1042,18 @@ NK_INTERNAL void nk_dot_e2m3x32_update_haswell(nk_dot_e2m3x32_state_haswell_t *s
1011
1042
  __m256i b_magnitude_u8x32 = _mm256_and_si256(b_e2m3_u8x32, magnitude_mask_u8x32);
1012
1043
  __m256i a_shuffle_index_u8x32 = _mm256_and_si256(a_magnitude_u8x32, nibble_mask_u8x32);
1013
1044
  __m256i b_shuffle_index_u8x32 = _mm256_and_si256(b_magnitude_u8x32, nibble_mask_u8x32);
1014
- __m256i a_upper_select_u8x32 = _mm256_cmpeq_epi8(_mm256_and_si256(a_magnitude_u8x32, half_select_u8x32),
1015
- half_select_u8x32);
1016
- __m256i b_upper_select_u8x32 = _mm256_cmpeq_epi8(_mm256_and_si256(b_magnitude_u8x32, half_select_u8x32),
1017
- half_select_u8x32);
1045
+ __m256i a_high_select_u8x32 = _mm256_cmpeq_epi8(_mm256_and_si256(a_magnitude_u8x32, half_select_u8x32),
1046
+ half_select_u8x32);
1047
+ __m256i b_high_select_u8x32 = _mm256_cmpeq_epi8(_mm256_and_si256(b_magnitude_u8x32, half_select_u8x32),
1048
+ half_select_u8x32);
1018
1049
 
1019
1050
  // Dual VPSHUFB + blend
1020
- __m256i a_unsigned_u8x32 = _mm256_blendv_epi8(_mm256_shuffle_epi8(lut_lower_u8x32, a_shuffle_index_u8x32),
1021
- _mm256_shuffle_epi8(lut_upper_u8x32, a_shuffle_index_u8x32),
1022
- a_upper_select_u8x32);
1023
- __m256i b_unsigned_u8x32 = _mm256_blendv_epi8(_mm256_shuffle_epi8(lut_lower_u8x32, b_shuffle_index_u8x32),
1024
- _mm256_shuffle_epi8(lut_upper_u8x32, b_shuffle_index_u8x32),
1025
- b_upper_select_u8x32);
1051
+ __m256i a_unsigned_u8x32 = _mm256_blendv_epi8(_mm256_shuffle_epi8(lut_low_u8x32, a_shuffle_index_u8x32),
1052
+ _mm256_shuffle_epi8(lut_high_u8x32, a_shuffle_index_u8x32),
1053
+ a_high_select_u8x32);
1054
+ __m256i b_unsigned_u8x32 = _mm256_blendv_epi8(_mm256_shuffle_epi8(lut_low_u8x32, b_shuffle_index_u8x32),
1055
+ _mm256_shuffle_epi8(lut_high_u8x32, b_shuffle_index_u8x32),
1056
+ b_high_select_u8x32);
1026
1057
 
1027
1058
  // Combined sign + conditional negate
1028
1059
  __m256i sign_combined_u8x32 = _mm256_and_si256(_mm256_xor_si256(a_e2m3_u8x32, b_e2m3_u8x32), sign_mask_u8x32);
@@ -1086,9 +1117,9 @@ NK_INTERNAL void nk_dot_e3m2x32_update_haswell(nk_dot_e3m2x32_state_haswell_t *s
1086
1117
  nk_size_t depth_offset, nk_size_t active_dimensions) {
1087
1118
  nk_unused_(depth_offset);
1088
1119
  nk_unused_(active_dimensions);
1089
- __m256i const lut_lo_lower_u8x32 = _mm256_set_epi8( //
1120
+ __m256i const lut_low_byte_first_u8x32 = _mm256_set_epi8( //
1090
1121
  28, 24, 20, 16, 14, 12, 10, 8, 7, 6, 5, 4, 3, 2, 1, 0, 28, 24, 20, 16, 14, 12, 10, 8, 7, 6, 5, 4, 3, 2, 1, 0);
1091
- __m256i const lut_lo_upper_u8x32 = _mm256_set_epi8( //
1122
+ __m256i const lut_low_byte_second_u8x32 = _mm256_set_epi8( //
1092
1123
  (char)192, (char)128, 64, 0, (char)224, (char)192, (char)160, (char)128, 112, 96, 80, 64, 56, 48, 40, 32, //
1093
1124
  (char)192, (char)128, 64, 0, (char)224, (char)192, (char)160, (char)128, 112, 96, 80, 64, 56, 48, 40, 32);
1094
1125
  __m256i const nibble_mask_u8x32 = _mm256_set1_epi8(0x0F);
@@ -1107,42 +1138,44 @@ NK_INTERNAL void nk_dot_e3m2x32_update_haswell(nk_dot_e3m2x32_state_haswell_t *s
1107
1138
  __m256i b_magnitude_u8x32 = _mm256_and_si256(b_e3m2_u8x32, magnitude_mask_u8x32);
1108
1139
  __m256i a_shuffle_index_u8x32 = _mm256_and_si256(a_magnitude_u8x32, nibble_mask_u8x32);
1109
1140
  __m256i b_shuffle_index_u8x32 = _mm256_and_si256(b_magnitude_u8x32, nibble_mask_u8x32);
1110
- __m256i a_upper_select_u8x32 = _mm256_cmpeq_epi8(_mm256_and_si256(a_magnitude_u8x32, half_select_u8x32),
1111
- half_select_u8x32);
1112
- __m256i b_upper_select_u8x32 = _mm256_cmpeq_epi8(_mm256_and_si256(b_magnitude_u8x32, half_select_u8x32),
1113
- half_select_u8x32);
1141
+ __m256i a_high_select_u8x32 = _mm256_cmpeq_epi8(_mm256_and_si256(a_magnitude_u8x32, half_select_u8x32),
1142
+ half_select_u8x32);
1143
+ __m256i b_high_select_u8x32 = _mm256_cmpeq_epi8(_mm256_and_si256(b_magnitude_u8x32, half_select_u8x32),
1144
+ half_select_u8x32);
1114
1145
 
1115
1146
  // Dual VPSHUFB for low bytes
1116
- __m256i a_lo_bytes_u8x32 = _mm256_blendv_epi8(_mm256_shuffle_epi8(lut_lo_lower_u8x32, a_shuffle_index_u8x32),
1117
- _mm256_shuffle_epi8(lut_lo_upper_u8x32, a_shuffle_index_u8x32),
1118
- a_upper_select_u8x32);
1119
- __m256i b_lo_bytes_u8x32 = _mm256_blendv_epi8(_mm256_shuffle_epi8(lut_lo_lower_u8x32, b_shuffle_index_u8x32),
1120
- _mm256_shuffle_epi8(lut_lo_upper_u8x32, b_shuffle_index_u8x32),
1121
- b_upper_select_u8x32);
1147
+ __m256i a_low_byte_u8x32 = _mm256_blendv_epi8(_mm256_shuffle_epi8(lut_low_byte_first_u8x32, a_shuffle_index_u8x32),
1148
+ _mm256_shuffle_epi8(lut_low_byte_second_u8x32, a_shuffle_index_u8x32),
1149
+ a_high_select_u8x32);
1150
+ __m256i b_low_byte_u8x32 = _mm256_blendv_epi8(_mm256_shuffle_epi8(lut_low_byte_first_u8x32, b_shuffle_index_u8x32),
1151
+ _mm256_shuffle_epi8(lut_low_byte_second_u8x32, b_shuffle_index_u8x32),
1152
+ b_high_select_u8x32);
1122
1153
 
1123
1154
  // High byte: 1 iff magnitude >= 28
1124
- __m256i a_hi_bytes_u8x32 = _mm256_and_si256(_mm256_cmpgt_epi8(a_magnitude_u8x32, high_threshold_u8x32), ones_u8x32);
1125
- __m256i b_hi_bytes_u8x32 = _mm256_and_si256(_mm256_cmpgt_epi8(b_magnitude_u8x32, high_threshold_u8x32), ones_u8x32);
1155
+ __m256i a_high_byte_u8x32 = _mm256_and_si256(_mm256_cmpgt_epi8(a_magnitude_u8x32, high_threshold_u8x32),
1156
+ ones_u8x32);
1157
+ __m256i b_high_byte_u8x32 = _mm256_and_si256(_mm256_cmpgt_epi8(b_magnitude_u8x32, high_threshold_u8x32),
1158
+ ones_u8x32);
1126
1159
 
1127
1160
  // Interleave low and high bytes into i16
1128
- __m256i a_lo_i16x16 = _mm256_unpacklo_epi8(a_lo_bytes_u8x32, a_hi_bytes_u8x32);
1129
- __m256i a_hi_i16x16 = _mm256_unpackhi_epi8(a_lo_bytes_u8x32, a_hi_bytes_u8x32);
1130
- __m256i b_lo_i16x16 = _mm256_unpacklo_epi8(b_lo_bytes_u8x32, b_hi_bytes_u8x32);
1131
- __m256i b_hi_i16x16 = _mm256_unpackhi_epi8(b_lo_bytes_u8x32, b_hi_bytes_u8x32);
1161
+ __m256i a_low_i16x16 = _mm256_unpacklo_epi8(a_low_byte_u8x32, a_high_byte_u8x32);
1162
+ __m256i a_high_i16x16 = _mm256_unpackhi_epi8(a_low_byte_u8x32, a_high_byte_u8x32);
1163
+ __m256i b_low_i16x16 = _mm256_unpacklo_epi8(b_low_byte_u8x32, b_high_byte_u8x32);
1164
+ __m256i b_high_i16x16 = _mm256_unpackhi_epi8(b_low_byte_u8x32, b_high_byte_u8x32);
1132
1165
 
1133
1166
  // Combined sign: (a ^ b) & 0x20, widen to i16, create +1/-1 sign vector via VPSIGNW
1134
1167
  __m256i sign_combined_u8x32 = _mm256_and_si256(_mm256_xor_si256(a_e3m2_u8x32, b_e3m2_u8x32), sign_mask_u8x32);
1135
1168
  __m256i negate_mask_u8x32 = _mm256_cmpeq_epi8(sign_combined_u8x32, sign_mask_u8x32);
1136
- __m256i negate_lo_i16x16 = _mm256_unpacklo_epi8(negate_mask_u8x32, negate_mask_u8x32);
1137
- __m256i negate_hi_i16x16 = _mm256_unpackhi_epi8(negate_mask_u8x32, negate_mask_u8x32);
1138
- __m256i sign_lo_i16x16 = _mm256_or_si256(negate_lo_i16x16, ones_i16x16);
1139
- __m256i sign_hi_i16x16 = _mm256_or_si256(negate_hi_i16x16, ones_i16x16);
1140
- __m256i b_signed_lo_i16x16 = _mm256_sign_epi16(b_lo_i16x16, sign_lo_i16x16);
1141
- __m256i b_signed_hi_i16x16 = _mm256_sign_epi16(b_hi_i16x16, sign_hi_i16x16);
1169
+ __m256i negate_low_i16x16 = _mm256_unpacklo_epi8(negate_mask_u8x32, negate_mask_u8x32);
1170
+ __m256i negate_high_i16x16 = _mm256_unpackhi_epi8(negate_mask_u8x32, negate_mask_u8x32);
1171
+ __m256i sign_low_i16x16 = _mm256_or_si256(negate_low_i16x16, ones_i16x16);
1172
+ __m256i sign_high_i16x16 = _mm256_or_si256(negate_high_i16x16, ones_i16x16);
1173
+ __m256i b_signed_low_i16x16 = _mm256_sign_epi16(b_low_i16x16, sign_low_i16x16);
1174
+ __m256i b_signed_high_i16x16 = _mm256_sign_epi16(b_high_i16x16, sign_high_i16x16);
1142
1175
 
1143
1176
  // VPMADDWD: a_unsigned_i16 × b_signed_i16 → i32 (two halves → two accumulators)
1144
- state->sum_a_i32x8 = _mm256_add_epi32(state->sum_a_i32x8, _mm256_madd_epi16(a_lo_i16x16, b_signed_lo_i16x16));
1145
- state->sum_b_i32x8 = _mm256_add_epi32(state->sum_b_i32x8, _mm256_madd_epi16(a_hi_i16x16, b_signed_hi_i16x16));
1177
+ state->sum_a_i32x8 = _mm256_add_epi32(state->sum_a_i32x8, _mm256_madd_epi16(a_low_i16x16, b_signed_low_i16x16));
1178
+ state->sum_b_i32x8 = _mm256_add_epi32(state->sum_b_i32x8, _mm256_madd_epi16(a_high_i16x16, b_signed_high_i16x16));
1146
1179
  }
1147
1180
 
1148
1181
  NK_INTERNAL void nk_dot_e3m2x32_finalize_haswell( //
@@ -1176,9 +1209,9 @@ NK_INTERNAL void nk_dot_e3m2x32_finalize_haswell(
1176
1209
  results->xmm = _mm_castps_si128(sum_f32x4);
1177
1210
  }
1178
1211
 
1179
- #pragma endregion - Smaller Floats
1212
+ #pragma endregion F16 and BF16 Floats
1180
1213
 
1181
- #pragma region - Small Integers
1214
+ #pragma region I8 and U8 Integers
1182
1215
 
1183
1216
  NK_PUBLIC void nk_dot_i8_haswell(nk_i8_t const *a_scalars, nk_i8_t const *b_scalars, nk_size_t count_scalars,
1184
1217
  nk_i32_t *result) {
@@ -1275,33 +1308,33 @@ nk_dot_i4_haswell_cycle:
1275
1308
  }
1276
1309
 
1277
1310
  // Extract low and high nibbles
1278
- __m128i a_lo_u8x16 = _mm_and_si128(a_i4x32, nibble_mask_u8x16);
1279
- __m128i a_hi_u8x16 = _mm_and_si128(_mm_srli_epi16(a_i4x32, 4), nibble_mask_u8x16);
1280
- __m128i b_lo_u8x16 = _mm_and_si128(b_i4x32, nibble_mask_u8x16);
1281
- __m128i b_hi_u8x16 = _mm_and_si128(_mm_srli_epi16(b_i4x32, 4), nibble_mask_u8x16);
1311
+ __m128i a_low_u8x16 = _mm_and_si128(a_i4x32, nibble_mask_u8x16);
1312
+ __m128i a_high_u8x16 = _mm_and_si128(_mm_srli_epi16(a_i4x32, 4), nibble_mask_u8x16);
1313
+ __m128i b_low_u8x16 = _mm_and_si128(b_i4x32, nibble_mask_u8x16);
1314
+ __m128i b_high_u8x16 = _mm_and_si128(_mm_srli_epi16(b_i4x32, 4), nibble_mask_u8x16);
1282
1315
 
1283
1316
  // XOR with 8 to get cx, dx values for the algebraic transformation
1284
- __m128i c_lo_u8x16 = _mm_xor_si128(a_lo_u8x16, xor_mask_u8x16);
1285
- __m128i c_hi_u8x16 = _mm_xor_si128(a_hi_u8x16, xor_mask_u8x16);
1286
- __m128i d_lo_u8x16 = _mm_xor_si128(b_lo_u8x16, xor_mask_u8x16);
1287
- __m128i d_hi_u8x16 = _mm_xor_si128(b_hi_u8x16, xor_mask_u8x16);
1317
+ __m128i c_low_u8x16 = _mm_xor_si128(a_low_u8x16, xor_mask_u8x16);
1318
+ __m128i c_high_u8x16 = _mm_xor_si128(a_high_u8x16, xor_mask_u8x16);
1319
+ __m128i d_low_u8x16 = _mm_xor_si128(b_low_u8x16, xor_mask_u8x16);
1320
+ __m128i d_high_u8x16 = _mm_xor_si128(b_high_u8x16, xor_mask_u8x16);
1288
1321
 
1289
1322
  // Widen u8 to i16 and multiply using MADD (2× instead of 4×)
1290
- __m256i c_lo_i16x16 = _mm256_cvtepu8_epi16(c_lo_u8x16);
1291
- __m256i c_hi_i16x16 = _mm256_cvtepu8_epi16(c_hi_u8x16);
1292
- __m256i d_lo_i16x16 = _mm256_cvtepu8_epi16(d_lo_u8x16);
1293
- __m256i d_hi_i16x16 = _mm256_cvtepu8_epi16(d_hi_u8x16);
1323
+ __m256i c_low_i16x16 = _mm256_cvtepu8_epi16(c_low_u8x16);
1324
+ __m256i c_high_i16x16 = _mm256_cvtepu8_epi16(c_high_u8x16);
1325
+ __m256i d_low_i16x16 = _mm256_cvtepu8_epi16(d_low_u8x16);
1326
+ __m256i d_high_i16x16 = _mm256_cvtepu8_epi16(d_high_u8x16);
1294
1327
 
1295
1328
  // Multiply i16×i16 and accumulate to i32 using MADD
1296
- sum_cd_i32x8 = _mm256_add_epi32(sum_cd_i32x8, _mm256_madd_epi16(c_lo_i16x16, d_lo_i16x16));
1297
- sum_cd_i32x8 = _mm256_add_epi32(sum_cd_i32x8, _mm256_madd_epi16(c_hi_i16x16, d_hi_i16x16));
1329
+ sum_cd_i32x8 = _mm256_add_epi32(sum_cd_i32x8, _mm256_madd_epi16(c_low_i16x16, d_low_i16x16));
1330
+ sum_cd_i32x8 = _mm256_add_epi32(sum_cd_i32x8, _mm256_madd_epi16(c_high_i16x16, d_high_i16x16));
1298
1331
 
1299
1332
  // Optimization: Use SAD for correction sums (5cy vs 24cy for 8× widenings)
1300
1333
  // PSADBW sums 8× u8 values to a single i64 in each 64-bit lane
1301
- sum_cx_i64x2 = _mm_add_epi64(sum_cx_i64x2, _mm_sad_epu8(c_lo_u8x16, zeros_u8x16));
1302
- sum_cx_i64x2 = _mm_add_epi64(sum_cx_i64x2, _mm_sad_epu8(c_hi_u8x16, zeros_u8x16));
1303
- sum_dx_i64x2 = _mm_add_epi64(sum_dx_i64x2, _mm_sad_epu8(d_lo_u8x16, zeros_u8x16));
1304
- sum_dx_i64x2 = _mm_add_epi64(sum_dx_i64x2, _mm_sad_epu8(d_hi_u8x16, zeros_u8x16));
1334
+ sum_cx_i64x2 = _mm_add_epi64(sum_cx_i64x2, _mm_sad_epu8(c_low_u8x16, zeros_u8x16));
1335
+ sum_cx_i64x2 = _mm_add_epi64(sum_cx_i64x2, _mm_sad_epu8(c_high_u8x16, zeros_u8x16));
1336
+ sum_dx_i64x2 = _mm_add_epi64(sum_dx_i64x2, _mm_sad_epu8(d_low_u8x16, zeros_u8x16));
1337
+ sum_dx_i64x2 = _mm_add_epi64(sum_dx_i64x2, _mm_sad_epu8(d_high_u8x16, zeros_u8x16));
1305
1338
 
1306
1339
  if (n_bytes) goto nk_dot_i4_haswell_cycle;
1307
1340
 
@@ -1347,20 +1380,20 @@ nk_dot_u4_haswell_cycle:
1347
1380
  }
1348
1381
 
1349
1382
  // Extract low and high nibbles
1350
- __m128i a_lo_u8x16 = _mm_and_si128(a_u4x32, nibble_mask_u8x16);
1351
- __m128i a_hi_u8x16 = _mm_and_si128(_mm_srli_epi16(a_u4x32, 4), nibble_mask_u8x16);
1352
- __m128i b_lo_u8x16 = _mm_and_si128(b_u4x32, nibble_mask_u8x16);
1353
- __m128i b_hi_u8x16 = _mm_and_si128(_mm_srli_epi16(b_u4x32, 4), nibble_mask_u8x16);
1383
+ __m128i a_low_u8x16 = _mm_and_si128(a_u4x32, nibble_mask_u8x16);
1384
+ __m128i a_high_u8x16 = _mm_and_si128(_mm_srli_epi16(a_u4x32, 4), nibble_mask_u8x16);
1385
+ __m128i b_low_u8x16 = _mm_and_si128(b_u4x32, nibble_mask_u8x16);
1386
+ __m128i b_high_u8x16 = _mm_and_si128(_mm_srli_epi16(b_u4x32, 4), nibble_mask_u8x16);
1354
1387
 
1355
1388
  // Widen u8 to i16
1356
- __m256i a_lo_i16x16 = _mm256_cvtepu8_epi16(a_lo_u8x16);
1357
- __m256i a_hi_i16x16 = _mm256_cvtepu8_epi16(a_hi_u8x16);
1358
- __m256i b_lo_i16x16 = _mm256_cvtepu8_epi16(b_lo_u8x16);
1359
- __m256i b_hi_i16x16 = _mm256_cvtepu8_epi16(b_hi_u8x16);
1389
+ __m256i a_low_i16x16 = _mm256_cvtepu8_epi16(a_low_u8x16);
1390
+ __m256i a_high_i16x16 = _mm256_cvtepu8_epi16(a_high_u8x16);
1391
+ __m256i b_low_i16x16 = _mm256_cvtepu8_epi16(b_low_u8x16);
1392
+ __m256i b_high_i16x16 = _mm256_cvtepu8_epi16(b_high_u8x16);
1360
1393
 
1361
1394
  // Multiply i16×i16 and accumulate to i32 using MADD
1362
- sum_i32x8 = _mm256_add_epi32(sum_i32x8, _mm256_madd_epi16(a_lo_i16x16, b_lo_i16x16));
1363
- sum_i32x8 = _mm256_add_epi32(sum_i32x8, _mm256_madd_epi16(a_hi_i16x16, b_hi_i16x16));
1395
+ sum_i32x8 = _mm256_add_epi32(sum_i32x8, _mm256_madd_epi16(a_low_i16x16, b_low_i16x16));
1396
+ sum_i32x8 = _mm256_add_epi32(sum_i32x8, _mm256_madd_epi16(a_high_i16x16, b_high_i16x16));
1364
1397
 
1365
1398
  if (n_bytes) goto nk_dot_u4_haswell_cycle;
1366
1399
 
@@ -1496,28 +1529,28 @@ NK_INTERNAL void nk_dot_i4x32_update_haswell(nk_dot_i4x32_state_haswell_t *state
1496
1529
  __m128i b_i4x32 = b.xmm;
1497
1530
 
1498
1531
  // Extract low and high nibbles
1499
- __m128i a_lo_u8x16 = _mm_and_si128(a_i4x32, nibble_mask_u8x16);
1500
- __m128i a_hi_u8x16 = _mm_and_si128(_mm_srli_epi16(a_i4x32, 4), nibble_mask_u8x16);
1501
- __m128i b_lo_u8x16 = _mm_and_si128(b_i4x32, nibble_mask_u8x16);
1502
- __m128i b_hi_u8x16 = _mm_and_si128(_mm_srli_epi16(b_i4x32, 4), nibble_mask_u8x16);
1532
+ __m128i a_low_u8x16 = _mm_and_si128(a_i4x32, nibble_mask_u8x16);
1533
+ __m128i a_high_u8x16 = _mm_and_si128(_mm_srli_epi16(a_i4x32, 4), nibble_mask_u8x16);
1534
+ __m128i b_low_u8x16 = _mm_and_si128(b_i4x32, nibble_mask_u8x16);
1535
+ __m128i b_high_u8x16 = _mm_and_si128(_mm_srli_epi16(b_i4x32, 4), nibble_mask_u8x16);
1503
1536
 
1504
1537
  // XOR with 8 for algebraic transformation
1505
- __m128i c_lo_u8x16 = _mm_xor_si128(a_lo_u8x16, xor_mask_u8x16);
1506
- __m128i c_hi_u8x16 = _mm_xor_si128(a_hi_u8x16, xor_mask_u8x16);
1507
- __m128i d_lo_u8x16 = _mm_xor_si128(b_lo_u8x16, xor_mask_u8x16);
1508
- __m128i d_hi_u8x16 = _mm_xor_si128(b_hi_u8x16, xor_mask_u8x16);
1538
+ __m128i c_low_u8x16 = _mm_xor_si128(a_low_u8x16, xor_mask_u8x16);
1539
+ __m128i c_high_u8x16 = _mm_xor_si128(a_high_u8x16, xor_mask_u8x16);
1540
+ __m128i d_low_u8x16 = _mm_xor_si128(b_low_u8x16, xor_mask_u8x16);
1541
+ __m128i d_high_u8x16 = _mm_xor_si128(b_high_u8x16, xor_mask_u8x16);
1509
1542
 
1510
1543
  // Widen u8 to i16 and multiply using MADD
1511
- __m256i c_lo_i16x16 = _mm256_cvtepu8_epi16(c_lo_u8x16);
1512
- __m256i c_hi_i16x16 = _mm256_cvtepu8_epi16(c_hi_u8x16);
1513
- __m256i d_lo_i16x16 = _mm256_cvtepu8_epi16(d_lo_u8x16);
1514
- __m256i d_hi_i16x16 = _mm256_cvtepu8_epi16(d_hi_u8x16);
1544
+ __m256i c_low_i16x16 = _mm256_cvtepu8_epi16(c_low_u8x16);
1545
+ __m256i c_high_i16x16 = _mm256_cvtepu8_epi16(c_high_u8x16);
1546
+ __m256i d_low_i16x16 = _mm256_cvtepu8_epi16(d_low_u8x16);
1547
+ __m256i d_high_i16x16 = _mm256_cvtepu8_epi16(d_high_u8x16);
1515
1548
 
1516
1549
  // Multiply and accumulate (no SAD — correction deferred to finalize)
1517
1550
  state->biased_product_sum_i32x8 = _mm256_add_epi32(state->biased_product_sum_i32x8,
1518
- _mm256_madd_epi16(c_lo_i16x16, d_lo_i16x16));
1551
+ _mm256_madd_epi16(c_low_i16x16, d_low_i16x16));
1519
1552
  state->biased_product_sum_i32x8 = _mm256_add_epi32(state->biased_product_sum_i32x8,
1520
- _mm256_madd_epi16(c_hi_i16x16, d_hi_i16x16));
1553
+ _mm256_madd_epi16(c_high_i16x16, d_high_i16x16));
1521
1554
  }
1522
1555
 
1523
1556
  NK_INTERNAL void nk_dot_i4x32_finalize_haswell( //
@@ -1585,20 +1618,22 @@ NK_INTERNAL void nk_dot_u4x32_update_haswell(nk_dot_u4x32_state_haswell_t *state
1585
1618
  __m128i b_u4x32 = b.xmm;
1586
1619
 
1587
1620
  // Extract low and high nibbles
1588
- __m128i a_lo_u8x16 = _mm_and_si128(a_u4x32, nibble_mask_u8x16);
1589
- __m128i a_hi_u8x16 = _mm_and_si128(_mm_srli_epi16(a_u4x32, 4), nibble_mask_u8x16);
1590
- __m128i b_lo_u8x16 = _mm_and_si128(b_u4x32, nibble_mask_u8x16);
1591
- __m128i b_hi_u8x16 = _mm_and_si128(_mm_srli_epi16(b_u4x32, 4), nibble_mask_u8x16);
1621
+ __m128i a_low_u8x16 = _mm_and_si128(a_u4x32, nibble_mask_u8x16);
1622
+ __m128i a_high_u8x16 = _mm_and_si128(_mm_srli_epi16(a_u4x32, 4), nibble_mask_u8x16);
1623
+ __m128i b_low_u8x16 = _mm_and_si128(b_u4x32, nibble_mask_u8x16);
1624
+ __m128i b_high_u8x16 = _mm_and_si128(_mm_srli_epi16(b_u4x32, 4), nibble_mask_u8x16);
1592
1625
 
1593
1626
  // Widen u8 to i16
1594
- __m256i a_lo_i16x16 = _mm256_cvtepu8_epi16(a_lo_u8x16);
1595
- __m256i a_hi_i16x16 = _mm256_cvtepu8_epi16(a_hi_u8x16);
1596
- __m256i b_lo_i16x16 = _mm256_cvtepu8_epi16(b_lo_u8x16);
1597
- __m256i b_hi_i16x16 = _mm256_cvtepu8_epi16(b_hi_u8x16);
1627
+ __m256i a_low_i16x16 = _mm256_cvtepu8_epi16(a_low_u8x16);
1628
+ __m256i a_high_i16x16 = _mm256_cvtepu8_epi16(a_high_u8x16);
1629
+ __m256i b_low_i16x16 = _mm256_cvtepu8_epi16(b_low_u8x16);
1630
+ __m256i b_high_i16x16 = _mm256_cvtepu8_epi16(b_high_u8x16);
1598
1631
 
1599
1632
  // Multiply and accumulate
1600
- state->product_sum_i32x8 = _mm256_add_epi32(state->product_sum_i32x8, _mm256_madd_epi16(a_lo_i16x16, b_lo_i16x16));
1601
- state->product_sum_i32x8 = _mm256_add_epi32(state->product_sum_i32x8, _mm256_madd_epi16(a_hi_i16x16, b_hi_i16x16));
1633
+ state->product_sum_i32x8 = _mm256_add_epi32(state->product_sum_i32x8,
1634
+ _mm256_madd_epi16(a_low_i16x16, b_low_i16x16));
1635
+ state->product_sum_i32x8 = _mm256_add_epi32(state->product_sum_i32x8,
1636
+ _mm256_madd_epi16(a_high_i16x16, b_high_i16x16));
1602
1637
  }
1603
1638
 
1604
1639
  NK_INTERNAL void nk_dot_u4x32_finalize_haswell( //
@@ -1619,23 +1654,23 @@ NK_INTERNAL void nk_dot_u4x32_finalize_haswell(
1619
1654
  _mm256_extracti128_si256(state_d->product_sum_i32x8, 1));
1620
1655
 
1621
1656
  // 4-way transpose to get [a,b,c,d] in lanes
1622
- __m128i transpose_ab_low = _mm_unpacklo_epi32(product_a_i32x4, product_b_i32x4);
1623
- __m128i transpose_cd_low = _mm_unpacklo_epi32(product_c_i32x4, product_d_i32x4);
1624
- __m128i transpose_ab_high = _mm_unpackhi_epi32(product_a_i32x4, product_b_i32x4);
1625
- __m128i transpose_cd_high = _mm_unpackhi_epi32(product_c_i32x4, product_d_i32x4);
1626
- __m128i product_lane0 = _mm_unpacklo_epi64(transpose_ab_low, transpose_cd_low);
1627
- __m128i product_lane1 = _mm_unpackhi_epi64(transpose_ab_low, transpose_cd_low);
1628
- __m128i product_lane2 = _mm_unpacklo_epi64(transpose_ab_high, transpose_cd_high);
1629
- __m128i product_lane3 = _mm_unpackhi_epi64(transpose_ab_high, transpose_cd_high);
1657
+ __m128i transpose_ab_low_i32x4 = _mm_unpacklo_epi32(product_a_i32x4, product_b_i32x4);
1658
+ __m128i transpose_cd_low_i32x4 = _mm_unpacklo_epi32(product_c_i32x4, product_d_i32x4);
1659
+ __m128i transpose_ab_high_i32x4 = _mm_unpackhi_epi32(product_a_i32x4, product_b_i32x4);
1660
+ __m128i transpose_cd_high_i32x4 = _mm_unpackhi_epi32(product_c_i32x4, product_d_i32x4);
1661
+ __m128i product_lane0_i32x4 = _mm_unpacklo_epi64(transpose_ab_low_i32x4, transpose_cd_low_i32x4);
1662
+ __m128i product_lane1_i32x4 = _mm_unpackhi_epi64(transpose_ab_low_i32x4, transpose_cd_low_i32x4);
1663
+ __m128i product_lane2_i32x4 = _mm_unpacklo_epi64(transpose_ab_high_i32x4, transpose_cd_high_i32x4);
1664
+ __m128i product_lane3_i32x4 = _mm_unpackhi_epi64(transpose_ab_high_i32x4, transpose_cd_high_i32x4);
1630
1665
 
1631
1666
  // Sum product lanes
1632
- result->xmm = _mm_add_epi32(_mm_add_epi32(product_lane0, product_lane1),
1633
- _mm_add_epi32(product_lane2, product_lane3));
1667
+ result->xmm = _mm_add_epi32(_mm_add_epi32(product_lane0_i32x4, product_lane1_i32x4),
1668
+ _mm_add_epi32(product_lane2_i32x4, product_lane3_i32x4));
1634
1669
  }
1635
1670
 
1636
- #pragma endregion - Small Integers
1671
+ #pragma endregion I8 and U8 Integers
1637
1672
 
1638
- #pragma region - Binary
1673
+ #pragma region Binary
1639
1674
 
1640
1675
  NK_PUBLIC void nk_dot_u1_haswell(nk_u1x8_t const *a, nk_u1x8_t const *b, nk_size_t n_bits, nk_u32_t *result) {
1641
1676
  nk_size_t n_bytes = nk_size_divide_round_up_(n_bits, NK_BITS_PER_BYTE);
@@ -1671,7 +1706,7 @@ NK_INTERNAL void nk_dot_u1x128_finalize_haswell( //
1671
1706
  result->u32s[3] = state_d->dot_count;
1672
1707
  }
1673
1708
 
1674
- #pragma endregion - Binary
1709
+ #pragma endregion Binary
1675
1710
 
1676
1711
  #if defined(__clang__)
1677
1712
  #pragma clang attribute pop