numkong 7.0.0 → 7.4.2

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (315)
  1. package/README.md +197 -124
  2. package/binding.gyp +34 -484
  3. package/c/dispatch_bf16.c +59 -1
  4. package/c/dispatch_e2m3.c +41 -8
  5. package/c/dispatch_e3m2.c +49 -8
  6. package/c/dispatch_e4m3.c +51 -9
  7. package/c/dispatch_e5m2.c +45 -1
  8. package/c/dispatch_f16.c +79 -26
  9. package/c/dispatch_f16c.c +5 -5
  10. package/c/dispatch_f32.c +56 -0
  11. package/c/dispatch_f64.c +52 -0
  12. package/c/dispatch_i4.c +3 -0
  13. package/c/dispatch_i8.c +62 -3
  14. package/c/dispatch_other.c +18 -0
  15. package/c/dispatch_u1.c +54 -9
  16. package/c/dispatch_u4.c +3 -0
  17. package/c/dispatch_u8.c +64 -3
  18. package/c/numkong.c +3 -0
  19. package/include/README.md +79 -9
  20. package/include/numkong/attention/sapphireamx.h +278 -276
  21. package/include/numkong/attention/sme.h +983 -977
  22. package/include/numkong/attention.h +1 -1
  23. package/include/numkong/capabilities.h +289 -94
  24. package/include/numkong/cast/README.md +40 -40
  25. package/include/numkong/cast/diamond.h +64 -0
  26. package/include/numkong/cast/haswell.h +42 -194
  27. package/include/numkong/cast/icelake.h +42 -37
  28. package/include/numkong/cast/loongsonasx.h +252 -0
  29. package/include/numkong/cast/neon.h +216 -249
  30. package/include/numkong/cast/powervsx.h +449 -0
  31. package/include/numkong/cast/rvv.h +223 -274
  32. package/include/numkong/cast/sapphire.h +18 -18
  33. package/include/numkong/cast/serial.h +1018 -944
  34. package/include/numkong/cast/skylake.h +82 -23
  35. package/include/numkong/cast/v128relaxed.h +462 -105
  36. package/include/numkong/cast.h +24 -0
  37. package/include/numkong/cast.hpp +44 -0
  38. package/include/numkong/curved/README.md +17 -17
  39. package/include/numkong/curved/neon.h +131 -7
  40. package/include/numkong/curved/neonbfdot.h +6 -7
  41. package/include/numkong/curved/rvv.h +26 -26
  42. package/include/numkong/curved/smef64.h +186 -182
  43. package/include/numkong/curved.h +14 -18
  44. package/include/numkong/dot/README.md +154 -137
  45. package/include/numkong/dot/alder.h +43 -43
  46. package/include/numkong/dot/diamond.h +158 -0
  47. package/include/numkong/dot/genoa.h +4 -30
  48. package/include/numkong/dot/haswell.h +215 -180
  49. package/include/numkong/dot/icelake.h +190 -76
  50. package/include/numkong/dot/loongsonasx.h +671 -0
  51. package/include/numkong/dot/neon.h +124 -73
  52. package/include/numkong/dot/neonbfdot.h +11 -12
  53. package/include/numkong/dot/neonfhm.h +44 -46
  54. package/include/numkong/dot/neonfp8.h +323 -0
  55. package/include/numkong/dot/neonsdot.h +190 -76
  56. package/include/numkong/dot/powervsx.h +752 -0
  57. package/include/numkong/dot/rvv.h +92 -84
  58. package/include/numkong/dot/rvvbf16.h +12 -12
  59. package/include/numkong/dot/rvvhalf.h +12 -12
  60. package/include/numkong/dot/sapphire.h +4 -4
  61. package/include/numkong/dot/serial.h +66 -30
  62. package/include/numkong/dot/sierra.h +31 -31
  63. package/include/numkong/dot/skylake.h +142 -110
  64. package/include/numkong/dot/sve.h +217 -177
  65. package/include/numkong/dot/svebfdot.h +10 -10
  66. package/include/numkong/dot/svehalf.h +85 -41
  67. package/include/numkong/dot/svesdot.h +89 -0
  68. package/include/numkong/dot/v128relaxed.h +124 -89
  69. package/include/numkong/dot.h +114 -48
  70. package/include/numkong/dots/README.md +203 -203
  71. package/include/numkong/dots/alder.h +12 -9
  72. package/include/numkong/dots/diamond.h +86 -0
  73. package/include/numkong/dots/genoa.h +10 -4
  74. package/include/numkong/dots/haswell.h +63 -48
  75. package/include/numkong/dots/icelake.h +27 -18
  76. package/include/numkong/dots/loongsonasx.h +176 -0
  77. package/include/numkong/dots/neon.h +14 -11
  78. package/include/numkong/dots/neonbfdot.h +4 -3
  79. package/include/numkong/dots/neonfhm.h +11 -9
  80. package/include/numkong/dots/neonfp8.h +99 -0
  81. package/include/numkong/dots/neonsdot.h +48 -12
  82. package/include/numkong/dots/powervsx.h +194 -0
  83. package/include/numkong/dots/rvv.h +451 -344
  84. package/include/numkong/dots/sapphireamx.h +1028 -984
  85. package/include/numkong/dots/serial.h +213 -197
  86. package/include/numkong/dots/sierra.h +10 -7
  87. package/include/numkong/dots/skylake.h +47 -36
  88. package/include/numkong/dots/sme.h +2001 -2364
  89. package/include/numkong/dots/smebi32.h +175 -162
  90. package/include/numkong/dots/smef64.h +328 -323
  91. package/include/numkong/dots/v128relaxed.h +64 -41
  92. package/include/numkong/dots.h +573 -293
  93. package/include/numkong/dots.hpp +45 -43
  94. package/include/numkong/each/README.md +133 -137
  95. package/include/numkong/each/haswell.h +6 -6
  96. package/include/numkong/each/icelake.h +7 -7
  97. package/include/numkong/each/neon.h +76 -42
  98. package/include/numkong/each/neonbfdot.h +11 -12
  99. package/include/numkong/each/neonhalf.h +24 -116
  100. package/include/numkong/each/rvv.h +28 -28
  101. package/include/numkong/each/sapphire.h +27 -161
  102. package/include/numkong/each/serial.h +6 -6
  103. package/include/numkong/each/skylake.h +7 -7
  104. package/include/numkong/each/v128relaxed.h +562 -0
  105. package/include/numkong/each.h +148 -62
  106. package/include/numkong/each.hpp +2 -2
  107. package/include/numkong/geospatial/README.md +18 -18
  108. package/include/numkong/geospatial/haswell.h +365 -325
  109. package/include/numkong/geospatial/neon.h +350 -306
  110. package/include/numkong/geospatial/rvv.h +4 -4
  111. package/include/numkong/geospatial/skylake.h +376 -340
  112. package/include/numkong/geospatial/v128relaxed.h +366 -327
  113. package/include/numkong/geospatial.h +17 -17
  114. package/include/numkong/matrix.hpp +4 -4
  115. package/include/numkong/maxsim/README.md +14 -14
  116. package/include/numkong/maxsim/alder.h +6 -6
  117. package/include/numkong/maxsim/genoa.h +4 -4
  118. package/include/numkong/maxsim/haswell.h +6 -6
  119. package/include/numkong/maxsim/icelake.h +18 -18
  120. package/include/numkong/maxsim/neonsdot.h +21 -21
  121. package/include/numkong/maxsim/sapphireamx.h +14 -14
  122. package/include/numkong/maxsim/serial.h +6 -6
  123. package/include/numkong/maxsim/sme.h +221 -196
  124. package/include/numkong/maxsim/v128relaxed.h +6 -6
  125. package/include/numkong/mesh/README.md +62 -56
  126. package/include/numkong/mesh/haswell.h +339 -464
  127. package/include/numkong/mesh/neon.h +1100 -519
  128. package/include/numkong/mesh/neonbfdot.h +36 -68
  129. package/include/numkong/mesh/rvv.h +530 -435
  130. package/include/numkong/mesh/serial.h +75 -91
  131. package/include/numkong/mesh/skylake.h +1627 -302
  132. package/include/numkong/mesh/v128relaxed.h +443 -330
  133. package/include/numkong/mesh.h +63 -49
  134. package/include/numkong/mesh.hpp +4 -4
  135. package/include/numkong/numkong.h +3 -3
  136. package/include/numkong/numkong.hpp +1 -0
  137. package/include/numkong/probability/README.md +23 -19
  138. package/include/numkong/probability/neon.h +82 -52
  139. package/include/numkong/probability/rvv.h +28 -23
  140. package/include/numkong/probability/serial.h +51 -39
  141. package/include/numkong/probability.h +20 -23
  142. package/include/numkong/random.h +1 -1
  143. package/include/numkong/reduce/README.md +143 -138
  144. package/include/numkong/reduce/alder.h +81 -77
  145. package/include/numkong/reduce/haswell.h +222 -220
  146. package/include/numkong/reduce/neon.h +629 -519
  147. package/include/numkong/reduce/neonbfdot.h +7 -218
  148. package/include/numkong/reduce/neonfhm.h +9 -381
  149. package/include/numkong/reduce/neonsdot.h +9 -9
  150. package/include/numkong/reduce/rvv.h +928 -802
  151. package/include/numkong/reduce/serial.h +23 -27
  152. package/include/numkong/reduce/sierra.h +20 -20
  153. package/include/numkong/reduce/skylake.h +326 -324
  154. package/include/numkong/reduce/v128relaxed.h +52 -52
  155. package/include/numkong/reduce.h +4 -23
  156. package/include/numkong/reduce.hpp +156 -11
  157. package/include/numkong/scalar/README.md +6 -6
  158. package/include/numkong/scalar/haswell.h +26 -17
  159. package/include/numkong/scalar/loongsonasx.h +74 -0
  160. package/include/numkong/scalar/neon.h +9 -9
  161. package/include/numkong/scalar/powervsx.h +96 -0
  162. package/include/numkong/scalar/rvv.h +2 -2
  163. package/include/numkong/scalar/sapphire.h +21 -10
  164. package/include/numkong/scalar/serial.h +21 -21
  165. package/include/numkong/scalar.h +13 -0
  166. package/include/numkong/set/README.md +28 -28
  167. package/include/numkong/set/haswell.h +12 -12
  168. package/include/numkong/set/icelake.h +14 -14
  169. package/include/numkong/set/loongsonasx.h +181 -0
  170. package/include/numkong/set/neon.h +17 -18
  171. package/include/numkong/set/powervsx.h +326 -0
  172. package/include/numkong/set/rvv.h +4 -4
  173. package/include/numkong/set/serial.h +6 -6
  174. package/include/numkong/set/sve.h +60 -59
  175. package/include/numkong/set/v128relaxed.h +6 -6
  176. package/include/numkong/set.h +21 -7
  177. package/include/numkong/sets/README.md +26 -26
  178. package/include/numkong/sets/loongsonasx.h +52 -0
  179. package/include/numkong/sets/powervsx.h +65 -0
  180. package/include/numkong/sets/smebi32.h +395 -364
  181. package/include/numkong/sets.h +83 -40
  182. package/include/numkong/sparse/README.md +4 -4
  183. package/include/numkong/sparse/icelake.h +101 -101
  184. package/include/numkong/sparse/serial.h +1 -1
  185. package/include/numkong/sparse/sve2.h +137 -141
  186. package/include/numkong/sparse/turin.h +12 -12
  187. package/include/numkong/sparse.h +10 -10
  188. package/include/numkong/spatial/README.md +230 -226
  189. package/include/numkong/spatial/alder.h +113 -116
  190. package/include/numkong/spatial/diamond.h +240 -0
  191. package/include/numkong/spatial/genoa.h +0 -68
  192. package/include/numkong/spatial/haswell.h +74 -55
  193. package/include/numkong/spatial/icelake.h +539 -58
  194. package/include/numkong/spatial/loongsonasx.h +483 -0
  195. package/include/numkong/spatial/neon.h +125 -52
  196. package/include/numkong/spatial/neonbfdot.h +8 -9
  197. package/include/numkong/spatial/neonfp8.h +258 -0
  198. package/include/numkong/spatial/neonsdot.h +180 -12
  199. package/include/numkong/spatial/powervsx.h +738 -0
  200. package/include/numkong/spatial/rvv.h +146 -139
  201. package/include/numkong/spatial/rvvbf16.h +17 -12
  202. package/include/numkong/spatial/rvvhalf.h +13 -10
  203. package/include/numkong/spatial/serial.h +13 -12
  204. package/include/numkong/spatial/sierra.h +232 -39
  205. package/include/numkong/spatial/skylake.h +73 -74
  206. package/include/numkong/spatial/sve.h +93 -72
  207. package/include/numkong/spatial/svebfdot.h +29 -29
  208. package/include/numkong/spatial/svehalf.h +52 -26
  209. package/include/numkong/spatial/svesdot.h +142 -0
  210. package/include/numkong/spatial/v128relaxed.h +293 -41
  211. package/include/numkong/spatial.h +338 -82
  212. package/include/numkong/spatials/README.md +194 -194
  213. package/include/numkong/spatials/diamond.h +82 -0
  214. package/include/numkong/spatials/haswell.h +2 -2
  215. package/include/numkong/spatials/loongsonasx.h +153 -0
  216. package/include/numkong/spatials/neonfp8.h +111 -0
  217. package/include/numkong/spatials/neonsdot.h +34 -0
  218. package/include/numkong/spatials/powervsx.h +153 -0
  219. package/include/numkong/spatials/rvv.h +259 -243
  220. package/include/numkong/spatials/sapphireamx.h +173 -173
  221. package/include/numkong/spatials/serial.h +2 -2
  222. package/include/numkong/spatials/skylake.h +2 -2
  223. package/include/numkong/spatials/sme.h +590 -605
  224. package/include/numkong/spatials/smef64.h +139 -130
  225. package/include/numkong/spatials/v128relaxed.h +2 -2
  226. package/include/numkong/spatials.h +820 -500
  227. package/include/numkong/spatials.hpp +49 -48
  228. package/include/numkong/tensor.hpp +406 -17
  229. package/include/numkong/trigonometry/README.md +19 -19
  230. package/include/numkong/trigonometry/haswell.h +402 -401
  231. package/include/numkong/trigonometry/neon.h +386 -387
  232. package/include/numkong/trigonometry/rvv.h +52 -51
  233. package/include/numkong/trigonometry/serial.h +13 -13
  234. package/include/numkong/trigonometry/skylake.h +373 -369
  235. package/include/numkong/trigonometry/v128relaxed.h +375 -374
  236. package/include/numkong/trigonometry.h +13 -13
  237. package/include/numkong/trigonometry.hpp +2 -2
  238. package/include/numkong/types.h +287 -49
  239. package/include/numkong/types.hpp +436 -12
  240. package/include/numkong/vector.hpp +82 -14
  241. package/javascript/dist/cjs/numkong-wasm.js +6 -12
  242. package/javascript/dist/cjs/numkong.d.ts +7 -1
  243. package/javascript/dist/cjs/numkong.js +37 -11
  244. package/javascript/dist/cjs/types.d.ts +9 -0
  245. package/javascript/dist/cjs/types.js +96 -0
  246. package/javascript/dist/esm/numkong-browser.d.ts +14 -0
  247. package/javascript/dist/esm/numkong-browser.js +23 -0
  248. package/javascript/dist/esm/numkong-wasm.js +6 -12
  249. package/javascript/dist/esm/numkong.d.ts +7 -1
  250. package/javascript/dist/esm/numkong.js +37 -11
  251. package/javascript/dist/esm/types.d.ts +9 -0
  252. package/javascript/dist/esm/types.js +96 -0
  253. package/javascript/node-gyp-build.d.ts +4 -1
  254. package/javascript/numkong-browser.ts +40 -0
  255. package/javascript/numkong-wasm.ts +7 -13
  256. package/javascript/numkong.c +5 -26
  257. package/javascript/numkong.ts +36 -11
  258. package/javascript/tsconfig-base.json +1 -0
  259. package/javascript/tsconfig-cjs.json +6 -1
  260. package/javascript/types.ts +110 -0
  261. package/numkong.gypi +101 -0
  262. package/package.json +34 -13
  263. package/probes/arm_neon.c +8 -0
  264. package/probes/arm_neon_bfdot.c +9 -0
  265. package/probes/arm_neon_fhm.c +9 -0
  266. package/probes/arm_neon_half.c +8 -0
  267. package/probes/arm_neon_sdot.c +9 -0
  268. package/probes/arm_neonfp8.c +9 -0
  269. package/probes/arm_sme.c +16 -0
  270. package/probes/arm_sme2.c +16 -0
  271. package/probes/arm_sme2p1.c +16 -0
  272. package/probes/arm_sme_bf16.c +16 -0
  273. package/probes/arm_sme_bi32.c +16 -0
  274. package/probes/arm_sme_f64.c +16 -0
  275. package/probes/arm_sme_fa64.c +14 -0
  276. package/probes/arm_sme_half.c +16 -0
  277. package/probes/arm_sme_lut2.c +15 -0
  278. package/probes/arm_sve.c +18 -0
  279. package/probes/arm_sve2.c +20 -0
  280. package/probes/arm_sve2p1.c +18 -0
  281. package/probes/arm_sve_bfdot.c +20 -0
  282. package/probes/arm_sve_half.c +18 -0
  283. package/probes/arm_sve_sdot.c +21 -0
  284. package/probes/loongarch_lasx.c +12 -0
  285. package/probes/power_vsx.c +12 -0
  286. package/probes/probe.js +127 -0
  287. package/probes/riscv_rvv.c +14 -0
  288. package/probes/riscv_rvv_bb.c +15 -0
  289. package/probes/riscv_rvv_bf16.c +17 -0
  290. package/probes/riscv_rvv_half.c +14 -0
  291. package/probes/wasm_v128relaxed.c +11 -0
  292. package/probes/x86_alder.c +17 -0
  293. package/probes/x86_diamond.c +17 -0
  294. package/probes/x86_genoa.c +17 -0
  295. package/probes/x86_graniteamx.c +19 -0
  296. package/probes/x86_haswell.c +11 -0
  297. package/probes/x86_icelake.c +17 -0
  298. package/probes/x86_sapphire.c +16 -0
  299. package/probes/x86_sapphireamx.c +18 -0
  300. package/probes/x86_sierra.c +17 -0
  301. package/probes/x86_skylake.c +15 -0
  302. package/probes/x86_turin.c +17 -0
  303. package/wasm/numkong-emscripten.js +2 -0
  304. package/wasm/numkong.d.ts +14 -0
  305. package/wasm/numkong.js +1124 -0
  306. package/wasm/numkong.wasm +0 -0
  307. package/include/numkong/curved/neonhalf.h +0 -212
  308. package/include/numkong/dot/neonhalf.h +0 -198
  309. package/include/numkong/dots/neonhalf.h +0 -57
  310. package/include/numkong/mesh/neonhalf.h +0 -616
  311. package/include/numkong/reduce/neonhalf.h +0 -157
  312. package/include/numkong/spatial/neonhalf.h +0 -118
  313. package/include/numkong/spatial/sapphire.h +0 -343
  314. package/include/numkong/spatials/neonhalf.h +0 -58
  315. package/javascript/README.md +0 -246
@@ -8,11 +8,11 @@
8
8
  *
9
9
  * @section dot_skylake_instructions Key AVX-512 Instructions
10
10
  *
11
- * Intrinsic Instruction Latency Throughput Ports
12
- * _mm512_madd_epi16 VPMADDWD (ZMM, ZMM, ZMM) 5cy 0.5/cy p05
13
- * _mm512_add_epi32 VPADDD (ZMM, ZMM, ZMM) 1cy 0.5/cy p05
14
- * _mm512_fmadd_ps VFMADD132PS (ZMM, ZMM, ZMM) 4cy 0.5/cy p05
15
- * _mm512_cvtepi8_epi16 VPMOVSXBW (ZMM, YMM) 3cy 1/cy p5
11
+ * Intrinsic Instruction Skylake-X Genoa
12
+ * _mm512_madd_epi16 VPMADDWD (ZMM, ZMM, ZMM) 5cy @ p05 3cy @ p01
13
+ * _mm512_add_epi32 VPADDD (ZMM, ZMM, ZMM) 1cy @ p05 1cy @ p0123
14
+ * _mm512_fmadd_ps VFMADD132PS (ZMM, ZMM, ZMM) 4cy @ p05 4cy @ p01
15
+ * _mm512_cvtepi8_epi16 VPMOVSXBW (ZMM, YMM) 3cy @ p5 4cy @ p12
16
16
  *
17
17
  * Skylake-X server chips feature dual 512-bit FMA units on ports 0 and 5, enabling 0.5cy throughput for
18
18
  * VFMADD and arithmetic operations. Client Skylake variants have only one FMA unit with 1cy throughput.
@@ -123,7 +123,7 @@ NK_INTERNAL nk_f64_t nk_dot_stable_sum_f64x8_skylake_(__m512d sum_f64x8, __m512d
123
123
  return nk_dot_stable_sum_f64x4_haswell_(tentative_sum_f64x4, accumulated_error_f64x4);
124
124
  }
125
125
 
126
- #pragma region - Traditional Floats
126
+ #pragma region F32 and F64 Floats
127
127
 
128
128
  /**
129
129
  * @brief Internal helper state for dot-products of low-precision types, where 32-bit accumulation is enough.
@@ -479,7 +479,8 @@ nk_vdot_f64c_skylake_cycle:
479
479
  result->imag = nk_dot_stable_sum_f64x8_skylake_(sum_imag_f64x8, compensation_imag_f64x8);
480
480
  }
481
481
 
482
- #pragma region - Smaller Floats
482
+ #pragma endregion F32 and F64 Floats
483
+ #pragma region F16 and BF16 Floats
483
484
 
484
485
  NK_PUBLIC void nk_dot_f16_skylake(nk_f16_t const *a_scalars, nk_f16_t const *b_scalars, nk_size_t count_scalars,
485
486
  nk_f32_t *result) {
@@ -508,24 +509,28 @@ nk_dot_f16_skylake_cycle:
508
509
 
509
510
  NK_PUBLIC void nk_dot_bf16_skylake(nk_bf16_t const *a_scalars, nk_bf16_t const *b_scalars, nk_size_t count_scalars,
510
511
  nk_f32_t *result) {
511
- __m256i a_bf16x16, b_bf16x16;
512
+ __m512i a_bf16_i16x32, b_bf16_i16x32;
512
513
  __m512 sum_f32x16 = _mm512_setzero_ps();
514
+ __m512i mask_high_u32x16 = _mm512_set1_epi32((int)0xFFFF0000);
513
515
 
514
516
  nk_dot_bf16_skylake_cycle:
515
- if (count_scalars < 16) {
516
- __mmask16 mask = (__mmask16)_bzhi_u32(0xFFFF, count_scalars);
517
- a_bf16x16 = _mm256_maskz_loadu_epi16(mask, a_scalars);
518
- b_bf16x16 = _mm256_maskz_loadu_epi16(mask, b_scalars);
517
+ if (count_scalars < 32) {
518
+ __mmask32 mask = (__mmask32)_bzhi_u32(0xFFFFFFFF, count_scalars);
519
+ a_bf16_i16x32 = _mm512_maskz_loadu_epi16(mask, a_scalars);
520
+ b_bf16_i16x32 = _mm512_maskz_loadu_epi16(mask, b_scalars);
519
521
  count_scalars = 0;
520
522
  }
521
523
  else {
522
- a_bf16x16 = _mm256_loadu_si256((__m256i const *)a_scalars);
523
- b_bf16x16 = _mm256_loadu_si256((__m256i const *)b_scalars);
524
- a_scalars += 16, b_scalars += 16, count_scalars -= 16;
524
+ a_bf16_i16x32 = _mm512_loadu_si512(a_scalars);
525
+ b_bf16_i16x32 = _mm512_loadu_si512(b_scalars);
526
+ a_scalars += 32, b_scalars += 32, count_scalars -= 32;
525
527
  }
526
- __m512 a_f32x16 = nk_bf16x16_to_f32x16_skylake_(a_bf16x16);
527
- __m512 b_f32x16 = nk_bf16x16_to_f32x16_skylake_(b_bf16x16);
528
- sum_f32x16 = _mm512_fmadd_ps(a_f32x16, b_f32x16, sum_f32x16);
528
+ __m512 a_even_f32x16 = _mm512_castsi512_ps(_mm512_slli_epi32(a_bf16_i16x32, 16));
529
+ __m512 b_even_f32x16 = _mm512_castsi512_ps(_mm512_slli_epi32(b_bf16_i16x32, 16));
530
+ sum_f32x16 = _mm512_fmadd_ps(a_even_f32x16, b_even_f32x16, sum_f32x16);
531
+ __m512 a_odd_f32x16 = _mm512_castsi512_ps(_mm512_and_si512(a_bf16_i16x32, mask_high_u32x16));
532
+ __m512 b_odd_f32x16 = _mm512_castsi512_ps(_mm512_and_si512(b_bf16_i16x32, mask_high_u32x16));
533
+ sum_f32x16 = _mm512_fmadd_ps(a_odd_f32x16, b_odd_f32x16, sum_f32x16);
529
534
  if (count_scalars) goto nk_dot_bf16_skylake_cycle;
530
535
 
531
536
  *result = nk_reduce_add_f32x16_skylake_(sum_f32x16);
@@ -533,23 +538,23 @@ nk_dot_bf16_skylake_cycle:
533
538
 
534
539
  NK_PUBLIC void nk_dot_e4m3_skylake(nk_e4m3_t const *a_scalars, nk_e4m3_t const *b_scalars, nk_size_t count_scalars,
535
540
  nk_f32_t *result) {
536
- __m128i a_e4m3x16, b_e4m3x16;
541
+ __m128i a_e4m3_u8x16, b_e4m3_u8x16;
537
542
  __m512 sum_f32x16 = _mm512_setzero_ps();
538
543
 
539
544
  nk_dot_e4m3_skylake_cycle:
540
545
  if (count_scalars < 16) {
541
546
  __mmask16 mask = (__mmask16)_bzhi_u32(0xFFFF, count_scalars);
542
- a_e4m3x16 = _mm_maskz_loadu_epi8(mask, a_scalars);
543
- b_e4m3x16 = _mm_maskz_loadu_epi8(mask, b_scalars);
547
+ a_e4m3_u8x16 = _mm_maskz_loadu_epi8(mask, a_scalars);
548
+ b_e4m3_u8x16 = _mm_maskz_loadu_epi8(mask, b_scalars);
544
549
  count_scalars = 0;
545
550
  }
546
551
  else {
547
- a_e4m3x16 = _mm_loadu_si128((__m128i const *)a_scalars);
548
- b_e4m3x16 = _mm_loadu_si128((__m128i const *)b_scalars);
552
+ a_e4m3_u8x16 = _mm_loadu_si128((__m128i const *)a_scalars);
553
+ b_e4m3_u8x16 = _mm_loadu_si128((__m128i const *)b_scalars);
549
554
  a_scalars += 16, b_scalars += 16, count_scalars -= 16;
550
555
  }
551
- __m512 a_f32x16 = nk_e4m3x16_to_f32x16_skylake_(a_e4m3x16);
552
- __m512 b_f32x16 = nk_e4m3x16_to_f32x16_skylake_(b_e4m3x16);
556
+ __m512 a_f32x16 = nk_e4m3x16_to_f32x16_skylake_(a_e4m3_u8x16);
557
+ __m512 b_f32x16 = nk_e4m3x16_to_f32x16_skylake_(b_e4m3_u8x16);
553
558
  sum_f32x16 = _mm512_fmadd_ps(a_f32x16, b_f32x16, sum_f32x16);
554
559
  if (count_scalars) goto nk_dot_e4m3_skylake_cycle;
555
560
 
@@ -558,23 +563,23 @@ nk_dot_e4m3_skylake_cycle:
558
563
 
559
564
  NK_PUBLIC void nk_dot_e5m2_skylake(nk_e5m2_t const *a_scalars, nk_e5m2_t const *b_scalars, nk_size_t count_scalars,
560
565
  nk_f32_t *result) {
561
- __m128i a_e5m2x16, b_e5m2x16;
566
+ __m128i a_e5m2_u8x16, b_e5m2_u8x16;
562
567
  __m512 sum_f32x16 = _mm512_setzero_ps();
563
568
 
564
569
  nk_dot_e5m2_skylake_cycle:
565
570
  if (count_scalars < 16) {
566
571
  __mmask16 mask = (__mmask16)_bzhi_u32(0xFFFF, count_scalars);
567
- a_e5m2x16 = _mm_maskz_loadu_epi8(mask, a_scalars);
568
- b_e5m2x16 = _mm_maskz_loadu_epi8(mask, b_scalars);
572
+ a_e5m2_u8x16 = _mm_maskz_loadu_epi8(mask, a_scalars);
573
+ b_e5m2_u8x16 = _mm_maskz_loadu_epi8(mask, b_scalars);
569
574
  count_scalars = 0;
570
575
  }
571
576
  else {
572
- a_e5m2x16 = _mm_loadu_si128((__m128i const *)a_scalars);
573
- b_e5m2x16 = _mm_loadu_si128((__m128i const *)b_scalars);
577
+ a_e5m2_u8x16 = _mm_loadu_si128((__m128i const *)a_scalars);
578
+ b_e5m2_u8x16 = _mm_loadu_si128((__m128i const *)b_scalars);
574
579
  a_scalars += 16, b_scalars += 16, count_scalars -= 16;
575
580
  }
576
- __m512 a_f32x16 = nk_e5m2x16_to_f32x16_skylake_(a_e5m2x16);
577
- __m512 b_f32x16 = nk_e5m2x16_to_f32x16_skylake_(b_e5m2x16);
581
+ __m512 a_f32x16 = nk_e5m2x16_to_f32x16_skylake_(a_e5m2_u8x16);
582
+ __m512 b_f32x16 = nk_e5m2x16_to_f32x16_skylake_(b_e5m2_u8x16);
578
583
  sum_f32x16 = _mm512_fmadd_ps(a_f32x16, b_f32x16, sum_f32x16);
579
584
  if (count_scalars) goto nk_dot_e5m2_skylake_cycle;
580
585
 
@@ -587,12 +592,12 @@ NK_PUBLIC void nk_dot_e2m3_skylake(nk_e2m3_t const *a_scalars, nk_e2m3_t const *
587
592
  // 64 elements per iteration using AVX-512BW. Result = i32_dot / 256.0f (exact).
588
593
  //
589
594
  // LUTs replicated 4× for 512-bit VPSHUFB (operates per 128-bit lane):
590
- __m512i const lut_lower_u8x64 = _mm512_set_epi8( //
595
+ __m512i const lut_low_u8x64 = _mm512_set_epi8( //
591
596
  30, 28, 26, 24, 22, 20, 18, 16, 14, 12, 10, 8, 6, 4, 2, 0, //
592
597
  30, 28, 26, 24, 22, 20, 18, 16, 14, 12, 10, 8, 6, 4, 2, 0, //
593
598
  30, 28, 26, 24, 22, 20, 18, 16, 14, 12, 10, 8, 6, 4, 2, 0, //
594
599
  30, 28, 26, 24, 22, 20, 18, 16, 14, 12, 10, 8, 6, 4, 2, 0);
595
- __m512i const lut_upper_u8x64 = _mm512_set_epi8( //
600
+ __m512i const lut_high_u8x64 = _mm512_set_epi8( //
596
601
  120, 112, 104, 96, 88, 80, 72, 64, 60, 56, 52, 48, 44, 40, 36, 32, //
597
602
  120, 112, 104, 96, 88, 80, 72, 64, 60, 56, 52, 48, 44, 40, 36, 32, //
598
603
  120, 112, 104, 96, 88, 80, 72, 64, 60, 56, 52, 48, 44, 40, 36, 32, //
@@ -625,16 +630,16 @@ nk_dot_e2m3_skylake_cycle:
625
630
  __m512i b_shuffle_index_u8x64 = _mm512_and_si512(b_magnitude_u8x64, nibble_mask_u8x64);
626
631
 
627
632
  // Bit-4 select via kmask (cleaner than Haswell's vector compare)
628
- __mmask64 a_upper_select = _mm512_test_epi8_mask(a_magnitude_u8x64, half_select_u8x64);
629
- __mmask64 b_upper_select = _mm512_test_epi8_mask(b_magnitude_u8x64, half_select_u8x64);
633
+ __mmask64 a_high_select = _mm512_test_epi8_mask(a_magnitude_u8x64, half_select_u8x64);
634
+ __mmask64 b_high_select = _mm512_test_epi8_mask(b_magnitude_u8x64, half_select_u8x64);
630
635
 
631
636
  // Dual VPSHUFB + mask-blend for 32-entry LUT
632
- __m512i a_unsigned_u8x64 = _mm512_mask_blend_epi8(a_upper_select,
633
- _mm512_shuffle_epi8(lut_lower_u8x64, a_shuffle_index_u8x64),
634
- _mm512_shuffle_epi8(lut_upper_u8x64, a_shuffle_index_u8x64));
635
- __m512i b_unsigned_u8x64 = _mm512_mask_blend_epi8(b_upper_select,
636
- _mm512_shuffle_epi8(lut_lower_u8x64, b_shuffle_index_u8x64),
637
- _mm512_shuffle_epi8(lut_upper_u8x64, b_shuffle_index_u8x64));
637
+ __m512i a_unsigned_u8x64 = _mm512_mask_blend_epi8(a_high_select,
638
+ _mm512_shuffle_epi8(lut_low_u8x64, a_shuffle_index_u8x64),
639
+ _mm512_shuffle_epi8(lut_high_u8x64, a_shuffle_index_u8x64));
640
+ __m512i b_unsigned_u8x64 = _mm512_mask_blend_epi8(b_high_select,
641
+ _mm512_shuffle_epi8(lut_low_u8x64, b_shuffle_index_u8x64),
642
+ _mm512_shuffle_epi8(lut_high_u8x64, b_shuffle_index_u8x64));
638
643
 
639
644
  // Combined sign: (a ^ b) & 0x20, negate b where signs differ using kmask
640
645
  __m512i sign_combined_u8x64 = _mm512_and_si512(_mm512_xor_si512(a_e2m3_u8x64, b_e2m3_u8x64), sign_mask_u8x64);
@@ -657,12 +662,12 @@ NK_PUBLIC void nk_dot_e3m2_skylake(nk_e3m2_t const *a_scalars, nk_e3m2_t const *
657
662
  // 64 elements per iteration using AVX-512BW. Magnitudes reach 448, requiring i16.
658
663
  // Result = i32_dot / 256.0f (exact, no rounding error).
659
664
  //
660
- __m512i const lut_lo_lower_u8x64 = _mm512_set_epi8( //
665
+ __m512i const lut_low_byte_first_u8x64 = _mm512_set_epi8( //
661
666
  28, 24, 20, 16, 14, 12, 10, 8, 7, 6, 5, 4, 3, 2, 1, 0, //
662
667
  28, 24, 20, 16, 14, 12, 10, 8, 7, 6, 5, 4, 3, 2, 1, 0, //
663
668
  28, 24, 20, 16, 14, 12, 10, 8, 7, 6, 5, 4, 3, 2, 1, 0, //
664
669
  28, 24, 20, 16, 14, 12, 10, 8, 7, 6, 5, 4, 3, 2, 1, 0);
665
- __m512i const lut_lo_upper_u8x64 = _mm512_set_epi8( //
670
+ __m512i const lut_low_byte_second_u8x64 = _mm512_set_epi8( //
666
671
  (char)192, (char)128, 64, 0, (char)224, (char)192, (char)160, (char)128, 112, 96, 80, 64, 56, 48, 40, 32, //
667
672
  (char)192, (char)128, 64, 0, (char)224, (char)192, (char)160, (char)128, 112, 96, 80, 64, 56, 48, 40, 32, //
668
673
  (char)192, (char)128, 64, 0, (char)224, (char)192, (char)160, (char)128, 112, 96, 80, 64, 56, 48, 40, 32, //
@@ -695,51 +700,53 @@ nk_dot_e3m2_skylake_cycle:
695
700
  __m512i b_shuffle_index_u8x64 = _mm512_and_si512(b_magnitude_u8x64, nibble_mask_u8x64);
696
701
 
697
702
  // Bit-4 select via kmask
698
- __mmask64 a_upper_select = _mm512_test_epi8_mask(a_magnitude_u8x64, half_select_u8x64);
699
- __mmask64 b_upper_select = _mm512_test_epi8_mask(b_magnitude_u8x64, half_select_u8x64);
703
+ __mmask64 a_high_select = _mm512_test_epi8_mask(a_magnitude_u8x64, half_select_u8x64);
704
+ __mmask64 b_high_select = _mm512_test_epi8_mask(b_magnitude_u8x64, half_select_u8x64);
700
705
 
701
706
  // Dual VPSHUFB + mask-blend for low bytes
702
- __m512i a_lo_bytes_u8x64 = _mm512_mask_blend_epi8(a_upper_select,
703
- _mm512_shuffle_epi8(lut_lo_lower_u8x64, a_shuffle_index_u8x64),
704
- _mm512_shuffle_epi8(lut_lo_upper_u8x64, a_shuffle_index_u8x64));
705
- __m512i b_lo_bytes_u8x64 = _mm512_mask_blend_epi8(b_upper_select,
706
- _mm512_shuffle_epi8(lut_lo_lower_u8x64, b_shuffle_index_u8x64),
707
- _mm512_shuffle_epi8(lut_lo_upper_u8x64, b_shuffle_index_u8x64));
707
+ __m512i a_low_byte_u8x64 = _mm512_mask_blend_epi8(
708
+ a_high_select, _mm512_shuffle_epi8(lut_low_byte_first_u8x64, a_shuffle_index_u8x64),
709
+ _mm512_shuffle_epi8(lut_low_byte_second_u8x64, a_shuffle_index_u8x64));
710
+ __m512i b_low_byte_u8x64 = _mm512_mask_blend_epi8(
711
+ b_high_select, _mm512_shuffle_epi8(lut_low_byte_first_u8x64, b_shuffle_index_u8x64),
712
+ _mm512_shuffle_epi8(lut_low_byte_second_u8x64, b_shuffle_index_u8x64));
708
713
 
709
714
  // High byte: 1 iff magnitude >= 28 (unsigned compare via _mm512_cmpge_epu8_mask)
710
- __mmask64 a_hi_mask = _mm512_cmpge_epu8_mask(a_magnitude_u8x64, _mm512_set1_epi8(28));
711
- __mmask64 b_hi_mask = _mm512_cmpge_epu8_mask(b_magnitude_u8x64, _mm512_set1_epi8(28));
712
- __m512i a_hi_bytes_u8x64 = _mm512_maskz_mov_epi8(a_hi_mask, ones_u8x64);
713
- __m512i b_hi_bytes_u8x64 = _mm512_maskz_mov_epi8(b_hi_mask, ones_u8x64);
715
+ __mmask64 a_high_mask = _mm512_cmpge_epu8_mask(a_magnitude_u8x64, _mm512_set1_epi8(28));
716
+ __mmask64 b_high_mask = _mm512_cmpge_epu8_mask(b_magnitude_u8x64, _mm512_set1_epi8(28));
717
+ __m512i a_high_byte_u8x64 = _mm512_maskz_mov_epi8(a_high_mask, ones_u8x64);
718
+ __m512i b_high_byte_u8x64 = _mm512_maskz_mov_epi8(b_high_mask, ones_u8x64);
714
719
 
715
720
  // Interleave low and high bytes into i16
716
- __m512i a_lo_i16x32 = _mm512_unpacklo_epi8(a_lo_bytes_u8x64, a_hi_bytes_u8x64);
717
- __m512i a_hi_i16x32 = _mm512_unpackhi_epi8(a_lo_bytes_u8x64, a_hi_bytes_u8x64);
718
- __m512i b_lo_i16x32 = _mm512_unpacklo_epi8(b_lo_bytes_u8x64, b_hi_bytes_u8x64);
719
- __m512i b_hi_i16x32 = _mm512_unpackhi_epi8(b_lo_bytes_u8x64, b_hi_bytes_u8x64);
721
+ __m512i a_low_i16x32 = _mm512_unpacklo_epi8(a_low_byte_u8x64, a_high_byte_u8x64);
722
+ __m512i a_high_i16x32 = _mm512_unpackhi_epi8(a_low_byte_u8x64, a_high_byte_u8x64);
723
+ __m512i b_low_i16x32 = _mm512_unpacklo_epi8(b_low_byte_u8x64, b_high_byte_u8x64);
724
+ __m512i b_high_i16x32 = _mm512_unpackhi_epi8(b_low_byte_u8x64, b_high_byte_u8x64);
720
725
 
721
726
  // Combined sign: (a ^ b) & 0x20, need to apply at i16 level
722
727
  // Compute sign mask at u8 level, widen to match unpacklo/unpackhi ordering via PEXT
723
728
  __m512i sign_combined_u8x64 = _mm512_and_si512(_mm512_xor_si512(a_e3m2_u8x64, b_e3m2_u8x64), sign_mask_u8x64);
724
729
  __mmask64 negate_u8_mask = _mm512_test_epi8_mask(sign_combined_u8x64, sign_combined_u8x64);
725
730
  // Extract bits matching unpacklo element ordering (bytes 0-7,16-23,32-39,48-55 per 64-byte vector)
726
- __mmask32 negate_lo_i16 = (__mmask32)_pext_u64(negate_u8_mask, 0x00FF00FF00FF00FFULL);
727
- __mmask32 negate_hi_i16 = (__mmask32)_pext_u64(negate_u8_mask, 0xFF00FF00FF00FF00ULL);
731
+ __mmask32 negate_low_i16 = (__mmask32)_pext_u64(negate_u8_mask, 0x00FF00FF00FF00FFULL);
732
+ __mmask32 negate_high_i16 = (__mmask32)_pext_u64(negate_u8_mask, 0xFF00FF00FF00FF00ULL);
728
733
  // Negate b at i16 level using mask_sub
729
- __m512i b_signed_lo_i16x32 = _mm512_mask_sub_epi16(b_lo_i16x32, negate_lo_i16, _mm512_setzero_si512(), b_lo_i16x32);
730
- __m512i b_signed_hi_i16x32 = _mm512_mask_sub_epi16(b_hi_i16x32, negate_hi_i16, _mm512_setzero_si512(), b_hi_i16x32);
734
+ __m512i b_signed_low_i16x32 = _mm512_mask_sub_epi16(b_low_i16x32, negate_low_i16, _mm512_setzero_si512(),
735
+ b_low_i16x32);
736
+ __m512i b_signed_high_i16x32 = _mm512_mask_sub_epi16(b_high_i16x32, negate_high_i16, _mm512_setzero_si512(),
737
+ b_high_i16x32);
731
738
 
732
739
  // VPMADDWD: a_i16 × b_signed_i16 → i32 accumulator
733
- sum_i32x16 = _mm512_add_epi32(sum_i32x16, _mm512_madd_epi16(a_lo_i16x32, b_signed_lo_i16x32));
734
- sum_i32x16 = _mm512_add_epi32(sum_i32x16, _mm512_madd_epi16(a_hi_i16x32, b_signed_hi_i16x32));
740
+ sum_i32x16 = _mm512_add_epi32(sum_i32x16, _mm512_madd_epi16(a_low_i16x32, b_signed_low_i16x32));
741
+ sum_i32x16 = _mm512_add_epi32(sum_i32x16, _mm512_madd_epi16(a_high_i16x32, b_signed_high_i16x32));
735
742
 
736
743
  if (count_scalars) goto nk_dot_e3m2_skylake_cycle;
737
744
  *result = (nk_f32_t)_mm512_reduce_add_epi32(sum_i32x16) / 256.0f;
738
745
  }
739
746
 
740
- #pragma endregion - Smaller Floats
747
+ #pragma endregion F16 and BF16 Floats
741
748
 
742
- #pragma region - Small Integers
749
+ #pragma region I8 and U8 Integers
743
750
 
744
751
  NK_PUBLIC void nk_dot_i8_skylake(nk_i8_t const *a_scalars, nk_i8_t const *b_scalars, nk_size_t count_scalars,
745
752
  nk_i32_t *result) {
@@ -869,10 +876,34 @@ NK_INTERNAL void nk_dot_f32x8_finalize_skylake(
869
876
  result->ymm_pd = _mm256_set_m128d(sum_cd_f64x2, sum_ab_f64x2);
870
877
  }
871
878
 
872
- #pragma endregion - Traditional Floats
873
-
874
879
  typedef nk_dot_through_f32_state_skylake_t_ nk_dot_bf16x16_state_skylake_t;
875
880
 
881
+ typedef nk_dot_through_f32_state_skylake_t_ nk_dot_bf16x32_state_skylake_t;
882
+
883
+ NK_INTERNAL void nk_dot_bf16x32_init_skylake(nk_dot_bf16x32_state_skylake_t *state) {
884
+ nk_dot_through_f32_init_skylake_(state);
885
+ }
886
+
887
+ NK_INTERNAL void nk_dot_bf16x32_update_skylake(nk_dot_bf16x32_state_skylake_t *state, nk_b512_vec_t a, nk_b512_vec_t b,
888
+ nk_size_t depth_offset, nk_size_t active_dimensions) {
889
+ nk_unused_(depth_offset);
890
+ nk_unused_(active_dimensions);
891
+ __m512i mask_high_u32x16 = _mm512_set1_epi32((int)0xFFFF0000);
892
+ __m512 a_even_f32x16 = _mm512_castsi512_ps(_mm512_slli_epi32(a.zmm, 16));
893
+ __m512 b_even_f32x16 = _mm512_castsi512_ps(_mm512_slli_epi32(b.zmm, 16));
894
+ state->sum_f32x16 = _mm512_fmadd_ps(a_even_f32x16, b_even_f32x16, state->sum_f32x16);
895
+ __m512 a_odd_f32x16 = _mm512_castsi512_ps(_mm512_and_si512(a.zmm, mask_high_u32x16));
896
+ __m512 b_odd_f32x16 = _mm512_castsi512_ps(_mm512_and_si512(b.zmm, mask_high_u32x16));
897
+ state->sum_f32x16 = _mm512_fmadd_ps(a_odd_f32x16, b_odd_f32x16, state->sum_f32x16);
898
+ }
899
+
900
+ NK_INTERNAL void nk_dot_bf16x32_finalize_skylake( //
901
+ nk_dot_bf16x32_state_skylake_t const *state_a, nk_dot_bf16x32_state_skylake_t const *state_b, //
902
+ nk_dot_bf16x32_state_skylake_t const *state_c, nk_dot_bf16x32_state_skylake_t const *state_d, //
903
+ nk_size_t total_dimensions, nk_b128_vec_t *result) {
904
+ nk_dot_through_f32_finalize_skylake_(state_a, state_b, state_c, state_d, total_dimensions, result);
905
+ }
906
+
876
907
  typedef nk_dot_through_f32_state_skylake_t_ nk_dot_f16x16_state_skylake_t;
877
908
 
878
909
  typedef struct nk_dot_e2m3x64_state_skylake_t {
@@ -887,14 +918,14 @@ NK_INTERNAL void nk_dot_e2m3x64_update_skylake(nk_dot_e2m3x64_state_skylake_t *s
887
918
  nk_size_t depth_offset, nk_size_t active_dimensions) {
888
919
  nk_unused_(depth_offset);
889
920
  nk_unused_(active_dimensions);
890
- __m512i const lut_lower_u8x64 = _mm512_set_epi8(30, 28, 26, 24, 22, 20, 18, 16, 14, 12, 10, 8, 6, 4, 2, 0, 30, 28,
891
- 26, 24, 22, 20, 18, 16, 14, 12, 10, 8, 6, 4, 2, 0, 30, 28, 26, 24,
892
- 22, 20, 18, 16, 14, 12, 10, 8, 6, 4, 2, 0, 30, 28, 26, 24, 22, 20,
893
- 18, 16, 14, 12, 10, 8, 6, 4, 2, 0);
894
- __m512i const lut_upper_u8x64 = _mm512_set_epi8(120, 112, 104, 96, 88, 80, 72, 64, 60, 56, 52, 48, 44, 40, 36, 32,
895
- 120, 112, 104, 96, 88, 80, 72, 64, 60, 56, 52, 48, 44, 40, 36, 32,
896
- 120, 112, 104, 96, 88, 80, 72, 64, 60, 56, 52, 48, 44, 40, 36, 32,
897
- 120, 112, 104, 96, 88, 80, 72, 64, 60, 56, 52, 48, 44, 40, 36, 32);
921
+ __m512i const lut_low_u8x64 = _mm512_set_epi8(30, 28, 26, 24, 22, 20, 18, 16, 14, 12, 10, 8, 6, 4, 2, 0, 30, 28, 26,
922
+ 24, 22, 20, 18, 16, 14, 12, 10, 8, 6, 4, 2, 0, 30, 28, 26, 24, 22, 20,
923
+ 18, 16, 14, 12, 10, 8, 6, 4, 2, 0, 30, 28, 26, 24, 22, 20, 18, 16, 14,
924
+ 12, 10, 8, 6, 4, 2, 0);
925
+ __m512i const lut_high_u8x64 = _mm512_set_epi8(120, 112, 104, 96, 88, 80, 72, 64, 60, 56, 52, 48, 44, 40, 36, 32,
926
+ 120, 112, 104, 96, 88, 80, 72, 64, 60, 56, 52, 48, 44, 40, 36, 32,
927
+ 120, 112, 104, 96, 88, 80, 72, 64, 60, 56, 52, 48, 44, 40, 36, 32,
928
+ 120, 112, 104, 96, 88, 80, 72, 64, 60, 56, 52, 48, 44, 40, 36, 32);
898
929
  __m512i const nibble_mask_u8x64 = _mm512_set1_epi8(0x0F);
899
930
  __m512i const magnitude_mask_u8x64 = _mm512_set1_epi8(0x1F);
900
931
  __m512i const half_select_u8x64 = _mm512_set1_epi8(0x10);
@@ -909,19 +940,19 @@ NK_INTERNAL void nk_dot_e2m3x64_update_skylake(nk_dot_e2m3x64_state_skylake_t *s
909
940
  __m512i a_shuffle_idx = _mm512_and_si512(a_magnitude, nibble_mask_u8x64);
910
941
  __m512i b_shuffle_idx = _mm512_and_si512(b_magnitude, nibble_mask_u8x64);
911
942
 
912
- __mmask64 a_upper = _mm512_test_epi8_mask(a_magnitude, half_select_u8x64);
913
- __mmask64 b_upper = _mm512_test_epi8_mask(b_magnitude, half_select_u8x64);
943
+ __mmask64 a_high = _mm512_test_epi8_mask(a_magnitude, half_select_u8x64);
944
+ __mmask64 b_high = _mm512_test_epi8_mask(b_magnitude, half_select_u8x64);
914
945
 
915
- __m512i a_unsigned = _mm512_mask_blend_epi8(a_upper, _mm512_shuffle_epi8(lut_lower_u8x64, a_shuffle_idx),
916
- _mm512_shuffle_epi8(lut_upper_u8x64, a_shuffle_idx));
917
- __m512i b_unsigned = _mm512_mask_blend_epi8(b_upper, _mm512_shuffle_epi8(lut_lower_u8x64, b_shuffle_idx),
918
- _mm512_shuffle_epi8(lut_upper_u8x64, b_shuffle_idx));
946
+ __m512i a_unsigned = _mm512_mask_blend_epi8(a_high, _mm512_shuffle_epi8(lut_low_u8x64, a_shuffle_idx),
947
+ _mm512_shuffle_epi8(lut_high_u8x64, a_shuffle_idx));
948
+ __m512i b_unsigned = _mm512_mask_blend_epi8(b_high, _mm512_shuffle_epi8(lut_low_u8x64, b_shuffle_idx),
949
+ _mm512_shuffle_epi8(lut_high_u8x64, b_shuffle_idx));
919
950
 
920
951
  __m512i sign_combined = _mm512_and_si512(_mm512_xor_si512(a_u8x64, b_u8x64), sign_mask_u8x64);
921
952
  __mmask64 negate_mask = _mm512_test_epi8_mask(sign_combined, sign_combined);
922
- __m512i b_signed = _mm512_mask_sub_epi8(b_unsigned, negate_mask, _mm512_setzero_si512(), b_unsigned);
953
+ __m512i b_signed_i8x64 = _mm512_mask_sub_epi8(b_unsigned, negate_mask, _mm512_setzero_si512(), b_unsigned);
923
954
 
924
- __m512i products_i16x32 = _mm512_maddubs_epi16(a_unsigned, b_signed);
955
+ __m512i products_i16x32 = _mm512_maddubs_epi16(a_unsigned, b_signed_i8x64);
925
956
  state->sum_i32x16 = _mm512_add_epi32(state->sum_i32x16, _mm512_madd_epi16(products_i16x32, ones_i16x32));
926
957
  }
927
958
 
@@ -976,10 +1007,10 @@ NK_INTERNAL void nk_dot_e3m2x64_update_skylake(nk_dot_e3m2x64_state_skylake_t *s
976
1007
  nk_size_t depth_offset, nk_size_t active_dimensions) {
977
1008
  nk_unused_(depth_offset);
978
1009
  nk_unused_(active_dimensions);
979
- __m512i const lut_lo_lower_u8x64 = _mm512_set_epi8( //
1010
+ __m512i const lut_low_byte_first_u8x64 = _mm512_set_epi8( //
980
1011
  28, 24, 20, 16, 14, 12, 10, 8, 7, 6, 5, 4, 3, 2, 1, 0, 28, 24, 20, 16, 14, 12, 10, 8, 7, 6, 5, 4, 3, 2, 1, 0, //
981
1012
  28, 24, 20, 16, 14, 12, 10, 8, 7, 6, 5, 4, 3, 2, 1, 0, 28, 24, 20, 16, 14, 12, 10, 8, 7, 6, 5, 4, 3, 2, 1, 0);
982
- __m512i const lut_lo_upper_u8x64 = _mm512_set_epi8( //
1013
+ __m512i const lut_low_byte_second_u8x64 = _mm512_set_epi8( //
983
1014
  (char)192, (char)128, 64, 0, (char)224, (char)192, (char)160, (char)128, 112, 96, 80, 64, 56, 48, 40, 32, //
984
1015
  (char)192, (char)128, 64, 0, (char)224, (char)192, (char)160, (char)128, 112, 96, 80, 64, 56, 48, 40, 32, //
985
1016
  (char)192, (char)128, 64, 0, (char)224, (char)192, (char)160, (char)128, 112, 96, 80, 64, 56, 48, 40, 32, //
@@ -998,34 +1029,35 @@ NK_INTERNAL void nk_dot_e3m2x64_update_skylake(nk_dot_e3m2x64_state_skylake_t *s
998
1029
  __m512i a_shuffle_idx = _mm512_and_si512(a_magnitude, nibble_mask_u8x64);
999
1030
  __m512i b_shuffle_idx = _mm512_and_si512(b_magnitude, nibble_mask_u8x64);
1000
1031
 
1001
- __mmask64 a_upper = _mm512_test_epi8_mask(a_magnitude, half_select_u8x64);
1002
- __mmask64 b_upper = _mm512_test_epi8_mask(b_magnitude, half_select_u8x64);
1032
+ __mmask64 a_high = _mm512_test_epi8_mask(a_magnitude, half_select_u8x64);
1033
+ __mmask64 b_high = _mm512_test_epi8_mask(b_magnitude, half_select_u8x64);
1003
1034
 
1004
- __m512i a_lo_bytes = _mm512_mask_blend_epi8(a_upper, _mm512_shuffle_epi8(lut_lo_lower_u8x64, a_shuffle_idx),
1005
- _mm512_shuffle_epi8(lut_lo_upper_u8x64, a_shuffle_idx));
1006
- __m512i b_lo_bytes = _mm512_mask_blend_epi8(b_upper, _mm512_shuffle_epi8(lut_lo_lower_u8x64, b_shuffle_idx),
1007
- _mm512_shuffle_epi8(lut_lo_upper_u8x64, b_shuffle_idx));
1035
+ __m512i a_low_byte = _mm512_mask_blend_epi8(a_high, _mm512_shuffle_epi8(lut_low_byte_first_u8x64, a_shuffle_idx),
1036
+ _mm512_shuffle_epi8(lut_low_byte_second_u8x64, a_shuffle_idx));
1037
+ __m512i b_low_byte = _mm512_mask_blend_epi8(b_high, _mm512_shuffle_epi8(lut_low_byte_first_u8x64, b_shuffle_idx),
1038
+ _mm512_shuffle_epi8(lut_low_byte_second_u8x64, b_shuffle_idx));
1008
1039
 
1009
- __mmask64 a_hi_mask = _mm512_cmpge_epu8_mask(a_magnitude, _mm512_set1_epi8(28));
1010
- __mmask64 b_hi_mask = _mm512_cmpge_epu8_mask(b_magnitude, _mm512_set1_epi8(28));
1011
- __m512i a_hi_bytes = _mm512_maskz_mov_epi8(a_hi_mask, ones_u8x64);
1012
- __m512i b_hi_bytes = _mm512_maskz_mov_epi8(b_hi_mask, ones_u8x64);
1040
+ __mmask64 a_high_mask = _mm512_cmpge_epu8_mask(a_magnitude, _mm512_set1_epi8(28));
1041
+ __mmask64 b_high_mask = _mm512_cmpge_epu8_mask(b_magnitude, _mm512_set1_epi8(28));
1042
+ __m512i a_high_byte = _mm512_maskz_mov_epi8(a_high_mask, ones_u8x64);
1043
+ __m512i b_high_byte = _mm512_maskz_mov_epi8(b_high_mask, ones_u8x64);
1013
1044
 
1014
- __m512i a_lo_i16 = _mm512_unpacklo_epi8(a_lo_bytes, a_hi_bytes);
1015
- __m512i a_hi_i16 = _mm512_unpackhi_epi8(a_lo_bytes, a_hi_bytes);
1016
- __m512i b_lo_i16 = _mm512_unpacklo_epi8(b_lo_bytes, b_hi_bytes);
1017
- __m512i b_hi_i16 = _mm512_unpackhi_epi8(b_lo_bytes, b_hi_bytes);
1045
+ __m512i a_low_i16x32 = _mm512_unpacklo_epi8(a_low_byte, a_high_byte);
1046
+ __m512i a_high_i16x32 = _mm512_unpackhi_epi8(a_low_byte, a_high_byte);
1047
+ __m512i b_low_i16x32 = _mm512_unpacklo_epi8(b_low_byte, b_high_byte);
1048
+ __m512i b_high_i16x32 = _mm512_unpackhi_epi8(b_low_byte, b_high_byte);
1018
1049
 
1019
1050
  // Combined sign: negate b at i16 level via PEXT + mask_sub
1020
1051
  __m512i sign_combined = _mm512_and_si512(_mm512_xor_si512(a_u8x64, b_u8x64), sign_mask_u8x64);
1021
1052
  __mmask64 negate_u8 = _mm512_test_epi8_mask(sign_combined, sign_combined);
1022
- __mmask32 negate_lo = (__mmask32)_pext_u64(negate_u8, 0x00FF00FF00FF00FFULL);
1023
- __mmask32 negate_hi = (__mmask32)_pext_u64(negate_u8, 0xFF00FF00FF00FF00ULL);
1024
- __m512i b_signed_lo = _mm512_mask_sub_epi16(b_lo_i16, negate_lo, _mm512_setzero_si512(), b_lo_i16);
1025
- __m512i b_signed_hi = _mm512_mask_sub_epi16(b_hi_i16, negate_hi, _mm512_setzero_si512(), b_hi_i16);
1026
-
1027
- state->sum_a_i32x16 = _mm512_add_epi32(state->sum_a_i32x16, _mm512_madd_epi16(a_lo_i16, b_signed_lo));
1028
- state->sum_b_i32x16 = _mm512_add_epi32(state->sum_b_i32x16, _mm512_madd_epi16(a_hi_i16, b_signed_hi));
1053
+ __mmask32 negate_low = (__mmask32)_pext_u64(negate_u8, 0x00FF00FF00FF00FFULL);
1054
+ __mmask32 negate_high = (__mmask32)_pext_u64(negate_u8, 0xFF00FF00FF00FF00ULL);
1055
+ __m512i b_signed_low_i16x32 = _mm512_mask_sub_epi16(b_low_i16x32, negate_low, _mm512_setzero_si512(), b_low_i16x32);
1056
+ __m512i b_signed_high_i16x32 = _mm512_mask_sub_epi16(b_high_i16x32, negate_high, _mm512_setzero_si512(),
1057
+ b_high_i16x32);
1058
+
1059
+ state->sum_a_i32x16 = _mm512_add_epi32(state->sum_a_i32x16, _mm512_madd_epi16(a_low_i16x32, b_signed_low_i16x32));
1060
+ state->sum_b_i32x16 = _mm512_add_epi32(state->sum_b_i32x16, _mm512_madd_epi16(a_high_i16x32, b_signed_high_i16x32));
1029
1061
  }
1030
1062
 
1031
1063
  NK_INTERNAL void nk_dot_e3m2x64_finalize_skylake( //
@@ -1067,7 +1099,7 @@ NK_INTERNAL void nk_dot_e3m2x64_finalize_skylake(
1067
1099
  results->xmm = _mm_castps_si128(sum_f32x4);
1068
1100
  }
1069
1101
 
1070
- #pragma endregion - Small Integers
1102
+ #pragma endregion I8 and U8 Integers
1071
1103
 
1072
1104
  #if defined(__clang__)
1073
1105
  #pragma clang attribute pop