numkong 7.0.0 → 7.4.2

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (315) hide show
  1. package/README.md +197 -124
  2. package/binding.gyp +34 -484
  3. package/c/dispatch_bf16.c +59 -1
  4. package/c/dispatch_e2m3.c +41 -8
  5. package/c/dispatch_e3m2.c +49 -8
  6. package/c/dispatch_e4m3.c +51 -9
  7. package/c/dispatch_e5m2.c +45 -1
  8. package/c/dispatch_f16.c +79 -26
  9. package/c/dispatch_f16c.c +5 -5
  10. package/c/dispatch_f32.c +56 -0
  11. package/c/dispatch_f64.c +52 -0
  12. package/c/dispatch_i4.c +3 -0
  13. package/c/dispatch_i8.c +62 -3
  14. package/c/dispatch_other.c +18 -0
  15. package/c/dispatch_u1.c +54 -9
  16. package/c/dispatch_u4.c +3 -0
  17. package/c/dispatch_u8.c +64 -3
  18. package/c/numkong.c +3 -0
  19. package/include/README.md +79 -9
  20. package/include/numkong/attention/sapphireamx.h +278 -276
  21. package/include/numkong/attention/sme.h +983 -977
  22. package/include/numkong/attention.h +1 -1
  23. package/include/numkong/capabilities.h +289 -94
  24. package/include/numkong/cast/README.md +40 -40
  25. package/include/numkong/cast/diamond.h +64 -0
  26. package/include/numkong/cast/haswell.h +42 -194
  27. package/include/numkong/cast/icelake.h +42 -37
  28. package/include/numkong/cast/loongsonasx.h +252 -0
  29. package/include/numkong/cast/neon.h +216 -249
  30. package/include/numkong/cast/powervsx.h +449 -0
  31. package/include/numkong/cast/rvv.h +223 -274
  32. package/include/numkong/cast/sapphire.h +18 -18
  33. package/include/numkong/cast/serial.h +1018 -944
  34. package/include/numkong/cast/skylake.h +82 -23
  35. package/include/numkong/cast/v128relaxed.h +462 -105
  36. package/include/numkong/cast.h +24 -0
  37. package/include/numkong/cast.hpp +44 -0
  38. package/include/numkong/curved/README.md +17 -17
  39. package/include/numkong/curved/neon.h +131 -7
  40. package/include/numkong/curved/neonbfdot.h +6 -7
  41. package/include/numkong/curved/rvv.h +26 -26
  42. package/include/numkong/curved/smef64.h +186 -182
  43. package/include/numkong/curved.h +14 -18
  44. package/include/numkong/dot/README.md +154 -137
  45. package/include/numkong/dot/alder.h +43 -43
  46. package/include/numkong/dot/diamond.h +158 -0
  47. package/include/numkong/dot/genoa.h +4 -30
  48. package/include/numkong/dot/haswell.h +215 -180
  49. package/include/numkong/dot/icelake.h +190 -76
  50. package/include/numkong/dot/loongsonasx.h +671 -0
  51. package/include/numkong/dot/neon.h +124 -73
  52. package/include/numkong/dot/neonbfdot.h +11 -12
  53. package/include/numkong/dot/neonfhm.h +44 -46
  54. package/include/numkong/dot/neonfp8.h +323 -0
  55. package/include/numkong/dot/neonsdot.h +190 -76
  56. package/include/numkong/dot/powervsx.h +752 -0
  57. package/include/numkong/dot/rvv.h +92 -84
  58. package/include/numkong/dot/rvvbf16.h +12 -12
  59. package/include/numkong/dot/rvvhalf.h +12 -12
  60. package/include/numkong/dot/sapphire.h +4 -4
  61. package/include/numkong/dot/serial.h +66 -30
  62. package/include/numkong/dot/sierra.h +31 -31
  63. package/include/numkong/dot/skylake.h +142 -110
  64. package/include/numkong/dot/sve.h +217 -177
  65. package/include/numkong/dot/svebfdot.h +10 -10
  66. package/include/numkong/dot/svehalf.h +85 -41
  67. package/include/numkong/dot/svesdot.h +89 -0
  68. package/include/numkong/dot/v128relaxed.h +124 -89
  69. package/include/numkong/dot.h +114 -48
  70. package/include/numkong/dots/README.md +203 -203
  71. package/include/numkong/dots/alder.h +12 -9
  72. package/include/numkong/dots/diamond.h +86 -0
  73. package/include/numkong/dots/genoa.h +10 -4
  74. package/include/numkong/dots/haswell.h +63 -48
  75. package/include/numkong/dots/icelake.h +27 -18
  76. package/include/numkong/dots/loongsonasx.h +176 -0
  77. package/include/numkong/dots/neon.h +14 -11
  78. package/include/numkong/dots/neonbfdot.h +4 -3
  79. package/include/numkong/dots/neonfhm.h +11 -9
  80. package/include/numkong/dots/neonfp8.h +99 -0
  81. package/include/numkong/dots/neonsdot.h +48 -12
  82. package/include/numkong/dots/powervsx.h +194 -0
  83. package/include/numkong/dots/rvv.h +451 -344
  84. package/include/numkong/dots/sapphireamx.h +1028 -984
  85. package/include/numkong/dots/serial.h +213 -197
  86. package/include/numkong/dots/sierra.h +10 -7
  87. package/include/numkong/dots/skylake.h +47 -36
  88. package/include/numkong/dots/sme.h +2001 -2364
  89. package/include/numkong/dots/smebi32.h +175 -162
  90. package/include/numkong/dots/smef64.h +328 -323
  91. package/include/numkong/dots/v128relaxed.h +64 -41
  92. package/include/numkong/dots.h +573 -293
  93. package/include/numkong/dots.hpp +45 -43
  94. package/include/numkong/each/README.md +133 -137
  95. package/include/numkong/each/haswell.h +6 -6
  96. package/include/numkong/each/icelake.h +7 -7
  97. package/include/numkong/each/neon.h +76 -42
  98. package/include/numkong/each/neonbfdot.h +11 -12
  99. package/include/numkong/each/neonhalf.h +24 -116
  100. package/include/numkong/each/rvv.h +28 -28
  101. package/include/numkong/each/sapphire.h +27 -161
  102. package/include/numkong/each/serial.h +6 -6
  103. package/include/numkong/each/skylake.h +7 -7
  104. package/include/numkong/each/v128relaxed.h +562 -0
  105. package/include/numkong/each.h +148 -62
  106. package/include/numkong/each.hpp +2 -2
  107. package/include/numkong/geospatial/README.md +18 -18
  108. package/include/numkong/geospatial/haswell.h +365 -325
  109. package/include/numkong/geospatial/neon.h +350 -306
  110. package/include/numkong/geospatial/rvv.h +4 -4
  111. package/include/numkong/geospatial/skylake.h +376 -340
  112. package/include/numkong/geospatial/v128relaxed.h +366 -327
  113. package/include/numkong/geospatial.h +17 -17
  114. package/include/numkong/matrix.hpp +4 -4
  115. package/include/numkong/maxsim/README.md +14 -14
  116. package/include/numkong/maxsim/alder.h +6 -6
  117. package/include/numkong/maxsim/genoa.h +4 -4
  118. package/include/numkong/maxsim/haswell.h +6 -6
  119. package/include/numkong/maxsim/icelake.h +18 -18
  120. package/include/numkong/maxsim/neonsdot.h +21 -21
  121. package/include/numkong/maxsim/sapphireamx.h +14 -14
  122. package/include/numkong/maxsim/serial.h +6 -6
  123. package/include/numkong/maxsim/sme.h +221 -196
  124. package/include/numkong/maxsim/v128relaxed.h +6 -6
  125. package/include/numkong/mesh/README.md +62 -56
  126. package/include/numkong/mesh/haswell.h +339 -464
  127. package/include/numkong/mesh/neon.h +1100 -519
  128. package/include/numkong/mesh/neonbfdot.h +36 -68
  129. package/include/numkong/mesh/rvv.h +530 -435
  130. package/include/numkong/mesh/serial.h +75 -91
  131. package/include/numkong/mesh/skylake.h +1627 -302
  132. package/include/numkong/mesh/v128relaxed.h +443 -330
  133. package/include/numkong/mesh.h +63 -49
  134. package/include/numkong/mesh.hpp +4 -4
  135. package/include/numkong/numkong.h +3 -3
  136. package/include/numkong/numkong.hpp +1 -0
  137. package/include/numkong/probability/README.md +23 -19
  138. package/include/numkong/probability/neon.h +82 -52
  139. package/include/numkong/probability/rvv.h +28 -23
  140. package/include/numkong/probability/serial.h +51 -39
  141. package/include/numkong/probability.h +20 -23
  142. package/include/numkong/random.h +1 -1
  143. package/include/numkong/reduce/README.md +143 -138
  144. package/include/numkong/reduce/alder.h +81 -77
  145. package/include/numkong/reduce/haswell.h +222 -220
  146. package/include/numkong/reduce/neon.h +629 -519
  147. package/include/numkong/reduce/neonbfdot.h +7 -218
  148. package/include/numkong/reduce/neonfhm.h +9 -381
  149. package/include/numkong/reduce/neonsdot.h +9 -9
  150. package/include/numkong/reduce/rvv.h +928 -802
  151. package/include/numkong/reduce/serial.h +23 -27
  152. package/include/numkong/reduce/sierra.h +20 -20
  153. package/include/numkong/reduce/skylake.h +326 -324
  154. package/include/numkong/reduce/v128relaxed.h +52 -52
  155. package/include/numkong/reduce.h +4 -23
  156. package/include/numkong/reduce.hpp +156 -11
  157. package/include/numkong/scalar/README.md +6 -6
  158. package/include/numkong/scalar/haswell.h +26 -17
  159. package/include/numkong/scalar/loongsonasx.h +74 -0
  160. package/include/numkong/scalar/neon.h +9 -9
  161. package/include/numkong/scalar/powervsx.h +96 -0
  162. package/include/numkong/scalar/rvv.h +2 -2
  163. package/include/numkong/scalar/sapphire.h +21 -10
  164. package/include/numkong/scalar/serial.h +21 -21
  165. package/include/numkong/scalar.h +13 -0
  166. package/include/numkong/set/README.md +28 -28
  167. package/include/numkong/set/haswell.h +12 -12
  168. package/include/numkong/set/icelake.h +14 -14
  169. package/include/numkong/set/loongsonasx.h +181 -0
  170. package/include/numkong/set/neon.h +17 -18
  171. package/include/numkong/set/powervsx.h +326 -0
  172. package/include/numkong/set/rvv.h +4 -4
  173. package/include/numkong/set/serial.h +6 -6
  174. package/include/numkong/set/sve.h +60 -59
  175. package/include/numkong/set/v128relaxed.h +6 -6
  176. package/include/numkong/set.h +21 -7
  177. package/include/numkong/sets/README.md +26 -26
  178. package/include/numkong/sets/loongsonasx.h +52 -0
  179. package/include/numkong/sets/powervsx.h +65 -0
  180. package/include/numkong/sets/smebi32.h +395 -364
  181. package/include/numkong/sets.h +83 -40
  182. package/include/numkong/sparse/README.md +4 -4
  183. package/include/numkong/sparse/icelake.h +101 -101
  184. package/include/numkong/sparse/serial.h +1 -1
  185. package/include/numkong/sparse/sve2.h +137 -141
  186. package/include/numkong/sparse/turin.h +12 -12
  187. package/include/numkong/sparse.h +10 -10
  188. package/include/numkong/spatial/README.md +230 -226
  189. package/include/numkong/spatial/alder.h +113 -116
  190. package/include/numkong/spatial/diamond.h +240 -0
  191. package/include/numkong/spatial/genoa.h +0 -68
  192. package/include/numkong/spatial/haswell.h +74 -55
  193. package/include/numkong/spatial/icelake.h +539 -58
  194. package/include/numkong/spatial/loongsonasx.h +483 -0
  195. package/include/numkong/spatial/neon.h +125 -52
  196. package/include/numkong/spatial/neonbfdot.h +8 -9
  197. package/include/numkong/spatial/neonfp8.h +258 -0
  198. package/include/numkong/spatial/neonsdot.h +180 -12
  199. package/include/numkong/spatial/powervsx.h +738 -0
  200. package/include/numkong/spatial/rvv.h +146 -139
  201. package/include/numkong/spatial/rvvbf16.h +17 -12
  202. package/include/numkong/spatial/rvvhalf.h +13 -10
  203. package/include/numkong/spatial/serial.h +13 -12
  204. package/include/numkong/spatial/sierra.h +232 -39
  205. package/include/numkong/spatial/skylake.h +73 -74
  206. package/include/numkong/spatial/sve.h +93 -72
  207. package/include/numkong/spatial/svebfdot.h +29 -29
  208. package/include/numkong/spatial/svehalf.h +52 -26
  209. package/include/numkong/spatial/svesdot.h +142 -0
  210. package/include/numkong/spatial/v128relaxed.h +293 -41
  211. package/include/numkong/spatial.h +338 -82
  212. package/include/numkong/spatials/README.md +194 -194
  213. package/include/numkong/spatials/diamond.h +82 -0
  214. package/include/numkong/spatials/haswell.h +2 -2
  215. package/include/numkong/spatials/loongsonasx.h +153 -0
  216. package/include/numkong/spatials/neonfp8.h +111 -0
  217. package/include/numkong/spatials/neonsdot.h +34 -0
  218. package/include/numkong/spatials/powervsx.h +153 -0
  219. package/include/numkong/spatials/rvv.h +259 -243
  220. package/include/numkong/spatials/sapphireamx.h +173 -173
  221. package/include/numkong/spatials/serial.h +2 -2
  222. package/include/numkong/spatials/skylake.h +2 -2
  223. package/include/numkong/spatials/sme.h +590 -605
  224. package/include/numkong/spatials/smef64.h +139 -130
  225. package/include/numkong/spatials/v128relaxed.h +2 -2
  226. package/include/numkong/spatials.h +820 -500
  227. package/include/numkong/spatials.hpp +49 -48
  228. package/include/numkong/tensor.hpp +406 -17
  229. package/include/numkong/trigonometry/README.md +19 -19
  230. package/include/numkong/trigonometry/haswell.h +402 -401
  231. package/include/numkong/trigonometry/neon.h +386 -387
  232. package/include/numkong/trigonometry/rvv.h +52 -51
  233. package/include/numkong/trigonometry/serial.h +13 -13
  234. package/include/numkong/trigonometry/skylake.h +373 -369
  235. package/include/numkong/trigonometry/v128relaxed.h +375 -374
  236. package/include/numkong/trigonometry.h +13 -13
  237. package/include/numkong/trigonometry.hpp +2 -2
  238. package/include/numkong/types.h +287 -49
  239. package/include/numkong/types.hpp +436 -12
  240. package/include/numkong/vector.hpp +82 -14
  241. package/javascript/dist/cjs/numkong-wasm.js +6 -12
  242. package/javascript/dist/cjs/numkong.d.ts +7 -1
  243. package/javascript/dist/cjs/numkong.js +37 -11
  244. package/javascript/dist/cjs/types.d.ts +9 -0
  245. package/javascript/dist/cjs/types.js +96 -0
  246. package/javascript/dist/esm/numkong-browser.d.ts +14 -0
  247. package/javascript/dist/esm/numkong-browser.js +23 -0
  248. package/javascript/dist/esm/numkong-wasm.js +6 -12
  249. package/javascript/dist/esm/numkong.d.ts +7 -1
  250. package/javascript/dist/esm/numkong.js +37 -11
  251. package/javascript/dist/esm/types.d.ts +9 -0
  252. package/javascript/dist/esm/types.js +96 -0
  253. package/javascript/node-gyp-build.d.ts +4 -1
  254. package/javascript/numkong-browser.ts +40 -0
  255. package/javascript/numkong-wasm.ts +7 -13
  256. package/javascript/numkong.c +5 -26
  257. package/javascript/numkong.ts +36 -11
  258. package/javascript/tsconfig-base.json +1 -0
  259. package/javascript/tsconfig-cjs.json +6 -1
  260. package/javascript/types.ts +110 -0
  261. package/numkong.gypi +101 -0
  262. package/package.json +34 -13
  263. package/probes/arm_neon.c +8 -0
  264. package/probes/arm_neon_bfdot.c +9 -0
  265. package/probes/arm_neon_fhm.c +9 -0
  266. package/probes/arm_neon_half.c +8 -0
  267. package/probes/arm_neon_sdot.c +9 -0
  268. package/probes/arm_neonfp8.c +9 -0
  269. package/probes/arm_sme.c +16 -0
  270. package/probes/arm_sme2.c +16 -0
  271. package/probes/arm_sme2p1.c +16 -0
  272. package/probes/arm_sme_bf16.c +16 -0
  273. package/probes/arm_sme_bi32.c +16 -0
  274. package/probes/arm_sme_f64.c +16 -0
  275. package/probes/arm_sme_fa64.c +14 -0
  276. package/probes/arm_sme_half.c +16 -0
  277. package/probes/arm_sme_lut2.c +15 -0
  278. package/probes/arm_sve.c +18 -0
  279. package/probes/arm_sve2.c +20 -0
  280. package/probes/arm_sve2p1.c +18 -0
  281. package/probes/arm_sve_bfdot.c +20 -0
  282. package/probes/arm_sve_half.c +18 -0
  283. package/probes/arm_sve_sdot.c +21 -0
  284. package/probes/loongarch_lasx.c +12 -0
  285. package/probes/power_vsx.c +12 -0
  286. package/probes/probe.js +127 -0
  287. package/probes/riscv_rvv.c +14 -0
  288. package/probes/riscv_rvv_bb.c +15 -0
  289. package/probes/riscv_rvv_bf16.c +17 -0
  290. package/probes/riscv_rvv_half.c +14 -0
  291. package/probes/wasm_v128relaxed.c +11 -0
  292. package/probes/x86_alder.c +17 -0
  293. package/probes/x86_diamond.c +17 -0
  294. package/probes/x86_genoa.c +17 -0
  295. package/probes/x86_graniteamx.c +19 -0
  296. package/probes/x86_haswell.c +11 -0
  297. package/probes/x86_icelake.c +17 -0
  298. package/probes/x86_sapphire.c +16 -0
  299. package/probes/x86_sapphireamx.c +18 -0
  300. package/probes/x86_sierra.c +17 -0
  301. package/probes/x86_skylake.c +15 -0
  302. package/probes/x86_turin.c +17 -0
  303. package/wasm/numkong-emscripten.js +2 -0
  304. package/wasm/numkong.d.ts +14 -0
  305. package/wasm/numkong.js +1124 -0
  306. package/wasm/numkong.wasm +0 -0
  307. package/include/numkong/curved/neonhalf.h +0 -212
  308. package/include/numkong/dot/neonhalf.h +0 -198
  309. package/include/numkong/dots/neonhalf.h +0 -57
  310. package/include/numkong/mesh/neonhalf.h +0 -616
  311. package/include/numkong/reduce/neonhalf.h +0 -157
  312. package/include/numkong/spatial/neonhalf.h +0 -118
  313. package/include/numkong/spatial/sapphire.h +0 -343
  314. package/include/numkong/spatials/neonhalf.h +0 -58
  315. package/javascript/README.md +0 -246
@@ -34,9 +34,10 @@ extern "C" {
34
34
  nk_define_cross_compensated_pack_size_(dots, i8, alder, i8, i8,
35
35
  /*sum_value_type=*/i32, /*norm_value_type=*/u32,
36
36
  /*depth_simd_dimensions=*/32, /*dimensions_per_value=*/1)
37
- nk_define_cross_compensated_pack_(dots, i8, alder, i8, i8, nk_assign_from_to_,
38
- /*sum_value_type=*/i32, /*norm_value_type=*/u32, nk_dots_reduce_moments_i8_,
39
- /*depth_simd_dimensions=*/32, /*dimensions_per_value=*/1)
37
+ nk_define_cross_compensated_pack_(dots, i8, alder, i8, i8, nk_b128_vec_t, nk_load_b128_haswell_,
38
+ nk_partial_load_b8x16_serial_, nk_store_b128_haswell_, nk_partial_store_b8x16_serial_,
39
+ /*simd_width=*/16, /*sum_value_type=*/i32, /*norm_value_type=*/u32,
40
+ nk_dots_reduce_moments_i8_, /*depth_simd_dimensions=*/16, /*dimensions_per_value=*/1)
40
41
  nk_define_cross_compensated_symmetric_(dots, i8, alder, i8, i32,
41
42
  /*sum_value_type=*/i32, /*norm_value_type=*/u32, nk_b256_vec_t,
42
43
  nk_dot_i8x32_state_alder_t, nk_b128_vec_t, nk_dot_i8x32_init_alder,
@@ -60,9 +61,10 @@ nk_define_cross_compensated_packed_(dots, i8, alder, i8, i8, i32,
60
61
  nk_define_cross_compensated_pack_size_(dots, u8, alder, u8, u8,
61
62
  /*sum_value_type=*/u32, /*norm_value_type=*/u32,
62
63
  /*depth_simd_dimensions=*/32, /*dimensions_per_value=*/1)
63
- nk_define_cross_compensated_pack_(dots, u8, alder, u8, u8, nk_assign_from_to_,
64
- /*sum_value_type=*/u32, /*norm_value_type=*/u32, nk_dots_reduce_moments_u8_,
65
- /*depth_simd_dimensions=*/32, /*dimensions_per_value=*/1)
64
+ nk_define_cross_compensated_pack_(dots, u8, alder, u8, u8, nk_b128_vec_t, nk_load_b128_haswell_,
65
+ nk_partial_load_b8x16_serial_, nk_store_b128_haswell_, nk_partial_store_b8x16_serial_,
66
+ /*simd_width=*/16, /*sum_value_type=*/u32, /*norm_value_type=*/u32,
67
+ nk_dots_reduce_moments_u8_, /*depth_simd_dimensions=*/16, /*dimensions_per_value=*/1)
66
68
  nk_define_cross_compensated_symmetric_(dots, u8, alder, u8, u32,
67
69
  /*sum_value_type=*/u32, /*norm_value_type=*/u32, nk_b256_vec_t,
68
70
  nk_dot_u8x32_state_alder_t, nk_b128_vec_t, nk_dot_u8x32_init_alder,
@@ -85,9 +87,10 @@ nk_define_cross_compensated_packed_(dots, u8, alder, u8, u8, u32,
85
87
  /* E2M3 GEMM via DPBUSD integer path: depth_simd_dimensions=32 (32 e2m3s = 32 bytes = AVX2 register width) */
86
88
  nk_define_cross_pack_size_(dots, e2m3, alder, e2m3, e2m3, /*norm_value_type=*/f32, /*depth_simd_dimensions=*/32,
87
89
  /*dimensions_per_value=*/1)
88
- nk_define_cross_pack_(dots, e2m3, alder, e2m3, e2m3, nk_assign_from_to_, /*norm_value_type=*/f32,
89
- nk_dots_reduce_sumsq_e2m3_, /*depth_simd_dimensions=*/32,
90
- /*dimensions_per_value=*/1)
90
+ nk_define_cross_pack_(dots, e2m3, alder, e2m3, e2m3, nk_b256_vec_t, nk_load_b256_haswell_,
91
+ nk_partial_load_b8x32_serial_, nk_store_b256_haswell_, nk_partial_store_b8x32_serial_,
92
+ /*simd_width=*/32, /*norm_value_type=*/f32, nk_dots_reduce_sumsq_e2m3_,
93
+ /*depth_simd_dimensions=*/32, /*dimensions_per_value=*/1)
91
94
  nk_define_cross_symmetric_(dots, e2m3, alder, e2m3, f32, nk_b256_vec_t, nk_dot_e2m3x32_state_alder_t, nk_b128_vec_t,
92
95
  nk_dot_e2m3x32_init_alder, nk_load_b256_haswell_, nk_partial_load_b8x32_serial_,
93
96
  nk_dot_e2m3x32_update_alder, nk_dot_e2m3x32_finalize_alder, nk_store_b128_haswell_,
@@ -0,0 +1,86 @@
1
+ /**
2
+ * @brief SIMD-accelerated Batched Dot Products for Diamond Rapids.
3
+ * @file include/numkong/dots/diamond.h
4
+ * @author Ash Vardanian
5
+ * @date March 23, 2026
6
+ *
7
+ * @sa include/numkong/dots.h
8
+ *
9
+ * Uses VCVTHF82PH/VCVTBF82PH for native FP8→FP16 conversion, then VDPPHPS for
10
+ * FP16-pair dot products accumulating into FP32. Processes 32 FP8 elements per iteration.
11
+ */
12
+ #ifndef NK_DOTS_DIAMOND_H
13
+ #define NK_DOTS_DIAMOND_H
14
+
15
+ #if NK_TARGET_X86_
16
+ #if NK_TARGET_DIAMOND
17
+
18
+ #include "numkong/dot/diamond.h"
19
+
20
+ #if defined(__cplusplus)
21
+ extern "C" {
22
+ #endif
23
+
24
+ #if defined(__clang__)
25
+ #pragma clang attribute push( \
26
+ __attribute__((target("avx2,avx512f,avx512vl,avx512bw,avx512dq,avx512fp16,avx10.2-512,f16c,fma,bmi,bmi2"))), \
27
+ apply_to = function)
28
+ #elif defined(__GNUC__)
29
+ #pragma GCC push_options
30
+ #pragma GCC target("avx2", "avx512f", "avx512vl", "avx512bw", "avx512dq", "avx512fp16", "avx10.2-512", "f16c", "fma", \
31
+ "bmi", "bmi2")
32
+ #endif
33
+
34
+ /* E4M3 GEMM: depth_simd_dimensions=32 (32 e4m3s = 32 bytes), FP16 intermediate, F32 accumulator */
35
+ nk_define_cross_pack_size_(dots, e4m3, diamond, e4m3, e4m3, /*norm_value_type=*/f32, /*depth_simd_dimensions=*/32,
36
+ /*dimensions_per_value=*/1)
37
+ nk_define_cross_pack_(dots, e4m3, diamond, e4m3, e4m3, nk_b512_vec_t, nk_load_b512_skylake_,
38
+ nk_partial_load_b8x64_skylake_, nk_store_b512_skylake_, nk_partial_store_b8x64_skylake_,
39
+ /*simd_width=*/64, /*norm_value_type=*/f32, nk_dots_reduce_sumsq_e4m3_,
40
+ /*depth_simd_dimensions=*/32, /*dimensions_per_value=*/1)
41
+ nk_define_cross_symmetric_(dots, e4m3, diamond, e4m3, f32, nk_b512_vec_t, nk_dot_through_f16_state_diamond_t_,
42
+ nk_b128_vec_t, nk_dot_through_f16_init_diamond_, nk_load_e4m3x32_to_f16x32_diamond_,
43
+ nk_partial_load_e4m3x32_to_f16x32_diamond_, nk_dot_through_f16_update_diamond_,
44
+ nk_dot_through_f16_finalize_diamond_, nk_store_b128_haswell_,
45
+ nk_partial_store_b32x4_skylake_,
46
+ /*depth_simd_dimensions=*/32, /*dimensions_per_value=*/1)
47
+ nk_define_cross_packed_(dots, e4m3, diamond, e4m3, e4m3, f32, nk_b512_vec_t, nk_dot_through_f16_state_diamond_t_,
48
+ nk_b128_vec_t, nk_dot_through_f16_init_diamond_, nk_load_e4m3x32_to_f16x32_diamond_,
49
+ nk_partial_load_e4m3x32_to_f16x32_diamond_, nk_load_e4m3x32_to_f16x32_diamond_,
50
+ nk_partial_load_e4m3x32_to_f16x32_diamond_, nk_dot_through_f16_update_diamond_,
51
+ nk_dot_through_f16_finalize_diamond_, nk_store_b128_haswell_, nk_partial_store_b32x4_skylake_,
52
+ /*depth_simd_dimensions=*/32, /*dimensions_per_value=*/1)
53
+
54
+ /* E5M2 GEMM: depth_simd_dimensions=32 (32 e5m2s = 32 bytes), FP16 intermediate, F32 accumulator */
55
+ nk_define_cross_pack_size_(dots, e5m2, diamond, e5m2, e5m2, /*norm_value_type=*/f32, /*depth_simd_dimensions=*/32,
56
+ /*dimensions_per_value=*/1)
57
+ nk_define_cross_pack_(dots, e5m2, diamond, e5m2, e5m2, nk_b512_vec_t, nk_load_b512_skylake_,
58
+ nk_partial_load_b8x64_skylake_, nk_store_b512_skylake_, nk_partial_store_b8x64_skylake_,
59
+ /*simd_width=*/64, /*norm_value_type=*/f32, nk_dots_reduce_sumsq_e5m2_,
60
+ /*depth_simd_dimensions=*/32, /*dimensions_per_value=*/1)
61
+ nk_define_cross_symmetric_(dots, e5m2, diamond, e5m2, f32, nk_b512_vec_t, nk_dot_through_f16_state_diamond_t_,
62
+ nk_b128_vec_t, nk_dot_through_f16_init_diamond_, nk_load_e5m2x32_to_f16x32_diamond_,
63
+ nk_partial_load_e5m2x32_to_f16x32_diamond_, nk_dot_through_f16_update_diamond_,
64
+ nk_dot_through_f16_finalize_diamond_, nk_store_b128_haswell_,
65
+ nk_partial_store_b32x4_skylake_,
66
+ /*depth_simd_dimensions=*/32, /*dimensions_per_value=*/1)
67
+ nk_define_cross_packed_(dots, e5m2, diamond, e5m2, e5m2, f32, nk_b512_vec_t, nk_dot_through_f16_state_diamond_t_,
68
+ nk_b128_vec_t, nk_dot_through_f16_init_diamond_, nk_load_e5m2x32_to_f16x32_diamond_,
69
+ nk_partial_load_e5m2x32_to_f16x32_diamond_, nk_load_e5m2x32_to_f16x32_diamond_,
70
+ nk_partial_load_e5m2x32_to_f16x32_diamond_, nk_dot_through_f16_update_diamond_,
71
+ nk_dot_through_f16_finalize_diamond_, nk_store_b128_haswell_, nk_partial_store_b32x4_skylake_,
72
+ /*depth_simd_dimensions=*/32, /*dimensions_per_value=*/1)
73
+
74
+ #if defined(__clang__)
75
+ #pragma clang attribute pop
76
+ #elif defined(__GNUC__)
77
+ #pragma GCC pop_options
78
+ #endif
79
+
80
+ #if defined(__cplusplus)
81
+ } // extern "C"
82
+ #endif
83
+
84
+ #endif // NK_TARGET_DIAMOND
85
+ #endif // NK_TARGET_X86_
86
+ #endif // NK_DOTS_DIAMOND_H
@@ -31,8 +31,10 @@ extern "C" {
31
31
  /* BF16 GEMM: depth_simd_dimensions=32 (32 bf16s = 64 bytes = 1 cache line) */
32
32
  nk_define_cross_pack_size_(dots, bf16, genoa, bf16, bf16, /*norm_value_type=*/f32, /*depth_simd_dimensions=*/32,
33
33
  /*dimensions_per_value=*/1)
34
- nk_define_cross_pack_(dots, bf16, genoa, bf16, bf16, nk_assign_from_to_, /*norm_value_type=*/f32,
35
- nk_dots_reduce_sumsq_bf16_, /*depth_simd_dimensions=*/32, /*dimensions_per_value=*/1)
34
+ nk_define_cross_pack_(dots, bf16, genoa, bf16, bf16, nk_b512_vec_t, nk_load_b512_skylake_,
35
+ nk_partial_load_b16x32_skylake_, nk_store_b512_skylake_, nk_partial_store_b16x32_skylake_,
36
+ /*simd_width=*/32, /*norm_value_type=*/f32, nk_dots_reduce_sumsq_bf16_,
37
+ /*depth_simd_dimensions=*/32, /*dimensions_per_value=*/1)
36
38
  nk_define_cross_symmetric_(dots, bf16, genoa, bf16, f32, nk_b512_vec_t, nk_dot_through_bf16_state_genoa_t_,
37
39
  nk_b128_vec_t, nk_dot_through_bf16_init_genoa_, nk_load_b512_skylake_,
38
40
  nk_partial_load_b16x32_skylake_, nk_dot_through_bf16_update_genoa_,
@@ -48,7 +50,9 @@ nk_define_cross_packed_(dots, bf16, genoa, bf16, bf16, f32, nk_b512_vec_t, nk_do
48
50
  /* E4M3 GEMM: depth_simd_dimensions=32 (32 e4m3s = 32 bytes = half cache line), F32 accumulator */
49
51
  nk_define_cross_pack_size_(dots, e4m3, genoa, e4m3, bf16, /*norm_value_type=*/f32, /*depth_simd_dimensions=*/32,
50
52
  /*dimensions_per_value=*/1)
51
- nk_define_cross_pack_(dots, e4m3, genoa, e4m3, bf16, nk_e4m3_to_bf16, /*norm_value_type=*/f32,
53
+ nk_define_cross_pack_(dots, e4m3, genoa, e4m3, bf16, nk_b512_vec_t, nk_load_e4m3x32_to_bf16x32_icelake_,
54
+ nk_partial_load_e4m3x32_to_bf16x32_icelake_, nk_store_b512_skylake_,
55
+ nk_partial_store_b16x32_skylake_, /*simd_width=*/32, /*norm_value_type=*/f32,
52
56
  nk_dots_reduce_sumsq_e4m3_, /*depth_simd_dimensions=*/32, /*dimensions_per_value=*/1)
53
57
  nk_define_cross_symmetric_(dots, e4m3, genoa, e4m3, f32, nk_b512_vec_t, nk_dot_through_bf16_state_genoa_t_,
54
58
  nk_b128_vec_t, nk_dot_through_bf16_init_genoa_, nk_load_e4m3x32_to_bf16x32_icelake_,
@@ -65,7 +69,9 @@ nk_define_cross_packed_(dots, e4m3, genoa, e4m3, bf16, f32, nk_b512_vec_t, nk_do
65
69
  /* E5M2 GEMM: depth_simd_dimensions=32 (32 e5m2s = 32 bytes = half cache line), F32 accumulator */
66
70
  nk_define_cross_pack_size_(dots, e5m2, genoa, e5m2, bf16, /*norm_value_type=*/f32, /*depth_simd_dimensions=*/32,
67
71
  /*dimensions_per_value=*/1)
68
- nk_define_cross_pack_(dots, e5m2, genoa, e5m2, bf16, nk_e5m2_to_bf16, /*norm_value_type=*/f32,
72
+ nk_define_cross_pack_(dots, e5m2, genoa, e5m2, bf16, nk_b512_vec_t, nk_load_e5m2x32_to_bf16x32_icelake_,
73
+ nk_partial_load_e5m2x32_to_bf16x32_icelake_, nk_store_b512_skylake_,
74
+ nk_partial_store_b16x32_skylake_, /*simd_width=*/32, /*norm_value_type=*/f32,
69
75
  nk_dots_reduce_sumsq_e5m2_, /*depth_simd_dimensions=*/32, /*dimensions_per_value=*/1)
70
76
  nk_define_cross_symmetric_(dots, e5m2, genoa, e5m2, f32, nk_b512_vec_t, nk_dot_through_bf16_state_genoa_t_,
71
77
  nk_b128_vec_t, nk_dot_through_bf16_init_genoa_, nk_load_e5m2x32_to_bf16x32_icelake_,
@@ -8,12 +8,12 @@
8
8
  *
9
9
  * @section haswell_dots_instructions Key AVX2/FMA GEMM Instructions
10
10
  *
11
- * Intrinsic Instruction Latency Throughput Ports
12
- * _mm256_fmadd_ps/pd VFMADD (YMM, YMM, YMM) 5cy 0.5/cy p01
13
- * _mm256_mul_ps VMULPS (YMM, YMM, YMM) 5cy 0.5/cy p01
14
- * _mm256_add_ps VADDPS (YMM, YMM, YMM) 3cy 1/cy p01
15
- * _mm256_cvtph_ps VCVTPH2PS (YMM, XMM) 5cy 1/cy p01
16
- * _mm256_madd_epi16 VPMADDWD (YMM, YMM, YMM) 5cy 1/cy p0
11
+ * Intrinsic Instruction Haswell Genoa
12
+ * _mm256_fmadd_ps/pd VFMADD (YMM, YMM, YMM) 5cy @ p01 4cy @ p01
13
+ * _mm256_mul_ps VMULPS (YMM, YMM, YMM) 5cy @ p01 3cy @ p01
14
+ * _mm256_add_ps VADDPS (YMM, YMM, YMM) 3cy @ p01 3cy @ p23
15
+ * _mm256_cvtph_ps VCVTPH2PS (YMM, XMM) 5cy @ p01 4cy @ p12+p23
16
+ * _mm256_madd_epi16 VPMADDWD (YMM, YMM, YMM) 5cy @ p01 3cy @ p01
17
17
  *
18
18
  * GEMM kernels use tiled dot products with 4-way parallel accumulation to hide FMA latency.
19
19
  * Type-specific tile sizes: f32/f64 use depth_simd_dimensions=4, f16/bf16 use depth_simd_dimensions=8,
@@ -43,9 +43,9 @@ extern "C" {
43
43
  /* F32 GEMM: depth_simd_dimensions=4 (4 f32s = 16 bytes for f32->f64 upcast accumulation) */
44
44
  nk_define_cross_pack_size_(dots, f32, haswell, f32, f32, /*norm_value_type=*/f64, /*depth_simd_dimensions=*/4,
45
45
  /*dimensions_per_value=*/1)
46
- nk_define_cross_pack_(dots, f32, haswell, f32, f32, nk_assign_from_to_, /*norm_value_type=*/f64,
47
- nk_dots_reduce_sumsq_f32_,
48
- /*depth_simd_dimensions=*/4, /*dimensions_per_value=*/1)
46
+ nk_define_cross_pack_(dots, f32, haswell, f32, f32, nk_b256_vec_t, nk_load_b256_haswell_, nk_partial_load_b32x8_serial_,
47
+ nk_store_b256_haswell_, nk_partial_store_b32x8_serial_, /*simd_width=*/8, /*norm_value_type=*/f64,
48
+ nk_dots_reduce_sumsq_f32_, /*depth_simd_dimensions=*/4, /*dimensions_per_value=*/1)
49
49
  nk_define_cross_symmetric_(dots, f32, haswell, f32, f64, nk_b128_vec_t, nk_dot_f32x4_state_haswell_t, nk_b256_vec_t,
50
50
  nk_dot_f32x4_init_haswell, nk_load_b128_haswell_, nk_partial_load_b32x4_haswell_,
51
51
  nk_dot_f32x4_update_haswell, nk_dot_f32x4_finalize_haswell, nk_store_b256_haswell_,
@@ -60,9 +60,10 @@ nk_define_cross_packed_(dots, f32, haswell, f32, f32, f64, nk_b128_vec_t, nk_dot
60
60
  /* F64 GEMM: depth_simd_dimensions=4 (4 f64s = 32 bytes = AVX2 register width) */
61
61
  nk_define_cross_pack_size_(dots, f64, haswell, f64, f64, /*norm_value_type=*/f64, /*depth_simd_dimensions=*/4,
62
62
  /*dimensions_per_value=*/1)
63
- nk_define_cross_pack_(dots, f64, haswell, f64, f64, nk_assign_from_to_, /*norm_value_type=*/f64,
64
- nk_dots_reduce_sumsq_f64_,
65
- /*depth_simd_dimensions=*/4, /*dimensions_per_value=*/1)
63
+ nk_define_cross_pack_(dots, f64, haswell, f64, f64, nk_b256_vec_t, nk_load_b256_haswell_,
64
+ nk_partial_load_b64x4_haswell_, nk_store_b256_haswell_, nk_partial_store_b64x4_haswell_,
65
+ /*simd_width=*/4, /*norm_value_type=*/f64, nk_dots_reduce_sumsq_f64_, /*depth_simd_dimensions=*/4,
66
+ /*dimensions_per_value=*/1)
66
67
  nk_define_cross_symmetric_(dots, f64, haswell, f64, f64, nk_b256_vec_t, nk_dot_f64x4_state_haswell_t, nk_b256_vec_t,
67
68
  nk_dot_f64x4_init_haswell, nk_load_b256_haswell_, nk_partial_load_b64x4_haswell_,
68
69
  nk_dot_f64x4_update_haswell, nk_dot_f64x4_finalize_haswell, nk_store_b256_haswell_,
@@ -77,8 +78,10 @@ nk_define_cross_packed_(dots, f64, haswell, f64, f64, f64, nk_b256_vec_t, nk_dot
77
78
  /* F16 GEMM: depth_simd_dimensions=8 (8 f16s = 16 bytes = 128-bit input) → upcasted to 8×f32 (256-bit) */
78
79
  nk_define_cross_pack_size_(dots, f16, haswell, f16, f32, /*norm_value_type=*/f32, /*depth_simd_dimensions=*/8,
79
80
  /*dimensions_per_value=*/1)
80
- nk_define_cross_pack_(dots, f16, haswell, f16, f32, nk_f16_to_f32, /*norm_value_type=*/f32, nk_dots_reduce_sumsq_f16_,
81
- /*depth_simd_dimensions=*/8, /*dimensions_per_value=*/1) // Store as F32
81
+ nk_define_cross_pack_(dots, f16, haswell, f16, f32, nk_b256_vec_t, nk_load_f16x8_to_f32x8_haswell_,
82
+ nk_partial_load_f16x8_to_f32x8_haswell_, nk_store_b256_haswell_, nk_partial_store_b32x8_serial_,
83
+ /*simd_width=*/8, /*norm_value_type=*/f32, nk_dots_reduce_sumsq_f16_, /*depth_simd_dimensions=*/8,
84
+ /*dimensions_per_value=*/1)
82
85
  nk_define_cross_symmetric_(dots, f16, haswell, f16, f32, nk_b256_vec_t, nk_dot_through_f32_state_haswell_t_,
83
86
  nk_b128_vec_t, nk_dot_through_f32_init_haswell_, nk_load_f16x8_to_f32x8_haswell_,
84
87
  nk_partial_load_f16x8_to_f32x8_haswell_, nk_dot_through_f32_update_haswell_,
@@ -93,29 +96,31 @@ nk_define_cross_packed_(dots, f16, haswell, f16, f32, f32, nk_b256_vec_t, nk_dot
93
96
  /*depth_simd_dimensions=*/8, /*dimensions_per_value=*/1)
94
97
 
95
98
  /* BF16 GEMM: depth_simd_dimensions=8 (8 bf16s = 16 bytes = 128-bit input) → upcasted to 8×f32 (256-bit) */
96
- nk_define_cross_pack_size_(dots, bf16, haswell, bf16, f32, /*norm_value_type=*/f32, /*depth_simd_dimensions=*/8,
99
+ /* BF16 GEMM: depth_simd_dimensions=16, raw bf16 storage, unpack(zero, bf16) f32 inline */
100
+ nk_define_cross_pack_size_(dots, bf16, haswell, bf16, bf16, /*norm_value_type=*/f32, /*depth_simd_dimensions=*/16,
97
101
  /*dimensions_per_value=*/1)
98
- nk_define_cross_pack_(dots, bf16, haswell, bf16, f32, nk_bf16_to_f32, /*norm_value_type=*/f32,
99
- nk_dots_reduce_sumsq_bf16_,
100
- /*depth_simd_dimensions=*/8, /*dimensions_per_value=*/1) // Store as F32
101
- nk_define_cross_symmetric_(dots, bf16, haswell, bf16, f32, nk_b256_vec_t, nk_dot_through_f32_state_haswell_t_,
102
- nk_b128_vec_t, nk_dot_through_f32_init_haswell_, nk_load_bf16x8_to_f32x8_haswell_,
103
- nk_partial_load_bf16x8_to_f32x8_haswell_, nk_dot_through_f32_update_haswell_,
104
- nk_dot_through_f32_finalize_haswell_, nk_store_b128_haswell_,
102
+ nk_define_cross_pack_(dots, bf16, haswell, bf16, bf16, nk_b256_vec_t, nk_load_b256_haswell_,
103
+ nk_partial_load_b16x16_serial_, nk_store_b256_haswell_, nk_partial_store_b16x16_serial_,
104
+ /*simd_width=*/16, /*norm_value_type=*/f32, nk_dots_reduce_sumsq_bf16_,
105
+ /*depth_simd_dimensions=*/16, /*dimensions_per_value=*/1)
106
+ nk_define_cross_symmetric_(dots, bf16, haswell, bf16, f32, nk_b256_vec_t, nk_dot_bf16x16_state_haswell_t, nk_b128_vec_t,
107
+ nk_dot_bf16x16_init_haswell, nk_load_b256_haswell_, nk_partial_load_b16x16_serial_,
108
+ nk_dot_bf16x16_update_haswell, nk_dot_bf16x16_finalize_haswell, nk_store_b128_haswell_,
105
109
  nk_partial_store_b32x4_haswell_,
106
- /*depth_simd_dimensions=*/8, /*dimensions_per_value=*/1)
107
- nk_define_cross_packed_(dots, bf16, haswell, bf16, f32, f32, nk_b256_vec_t, nk_dot_through_f32_state_haswell_t_,
108
- nk_b128_vec_t, nk_dot_through_f32_init_haswell_, nk_load_bf16x8_to_f32x8_haswell_,
109
- nk_partial_load_bf16x8_to_f32x8_haswell_, nk_load_b256_haswell_, nk_partial_load_b32x8_serial_,
110
- nk_dot_through_f32_update_haswell_, nk_dot_through_f32_finalize_haswell_,
111
- nk_store_b128_haswell_, nk_partial_store_b32x4_haswell_,
112
- /*depth_simd_dimensions=*/8, /*dimensions_per_value=*/1)
110
+ /*depth_simd_dimensions=*/16, /*dimensions_per_value=*/1)
111
+ nk_define_cross_packed_(dots, bf16, haswell, bf16, bf16, f32, nk_b256_vec_t, nk_dot_bf16x16_state_haswell_t,
112
+ nk_b128_vec_t, nk_dot_bf16x16_init_haswell, nk_load_b256_haswell_,
113
+ nk_partial_load_b16x16_serial_, nk_load_b256_haswell_, nk_partial_load_b16x16_serial_,
114
+ nk_dot_bf16x16_update_haswell, nk_dot_bf16x16_finalize_haswell, nk_store_b128_haswell_,
115
+ nk_partial_store_b32x4_haswell_,
116
+ /*depth_simd_dimensions=*/16, /*dimensions_per_value=*/1)
113
117
 
114
118
  /* E4M3 GEMM: depth_simd_dimensions=8 (8 e4m3s = 8 bytes) → upcasted to 8×f32 (256-bit) */
115
119
  nk_define_cross_pack_size_(dots, e4m3, haswell, e4m3, f32, /*norm_value_type=*/f32, /*depth_simd_dimensions=*/8,
116
120
  /*dimensions_per_value=*/1)
117
- nk_define_cross_pack_(dots, e4m3, haswell, e4m3, f32, nk_e4m3_to_f32, /*norm_value_type=*/f32,
118
- nk_dots_reduce_sumsq_e4m3_,
121
+ nk_define_cross_pack_(dots, e4m3, haswell, e4m3, f32, nk_b256_vec_t, nk_load_e4m3x8_to_f32x8_haswell_,
122
+ nk_partial_load_e4m3x8_to_f32x8_haswell_, nk_store_b256_haswell_, nk_partial_store_b32x8_serial_,
123
+ /*simd_width=*/8, /*norm_value_type=*/f32, nk_dots_reduce_sumsq_e4m3_,
119
124
  /*depth_simd_dimensions=*/8, /*dimensions_per_value=*/1)
120
125
  nk_define_cross_symmetric_(dots, e4m3, haswell, e4m3, f32, nk_b256_vec_t, nk_dot_through_f32_state_haswell_t_,
121
126
  nk_b128_vec_t, nk_dot_through_f32_init_haswell_, nk_load_e4m3x8_to_f32x8_haswell_,
@@ -133,8 +138,9 @@ nk_define_cross_packed_(dots, e4m3, haswell, e4m3, f32, f32, nk_b256_vec_t, nk_d
133
138
  /* E5M2 GEMM: depth_simd_dimensions=8 (8 e5m2s = 8 bytes) → upcasted to 8×f32 (256-bit) */
134
139
  nk_define_cross_pack_size_(dots, e5m2, haswell, e5m2, f32, /*norm_value_type=*/f32, /*depth_simd_dimensions=*/8,
135
140
  /*dimensions_per_value=*/1)
136
- nk_define_cross_pack_(dots, e5m2, haswell, e5m2, f32, nk_e5m2_to_f32, /*norm_value_type=*/f32,
137
- nk_dots_reduce_sumsq_e5m2_,
141
+ nk_define_cross_pack_(dots, e5m2, haswell, e5m2, f32, nk_b256_vec_t, nk_load_e5m2x8_to_f32x8_haswell_,
142
+ nk_partial_load_e5m2x8_to_f32x8_haswell_, nk_store_b256_haswell_, nk_partial_store_b32x8_serial_,
143
+ /*simd_width=*/8, /*norm_value_type=*/f32, nk_dots_reduce_sumsq_e5m2_,
138
144
  /*depth_simd_dimensions=*/8, /*dimensions_per_value=*/1)
139
145
  nk_define_cross_symmetric_(dots, e5m2, haswell, e5m2, f32, nk_b256_vec_t, nk_dot_through_f32_state_haswell_t_,
140
146
  nk_b128_vec_t, nk_dot_through_f32_init_haswell_, nk_load_e5m2x8_to_f32x8_haswell_,
@@ -152,8 +158,9 @@ nk_define_cross_packed_(dots, e5m2, haswell, e5m2, f32, f32, nk_b256_vec_t, nk_d
152
158
  /* E2M3 GEMM: integer LUT path, depth_simd_dimensions=32 (32 e2m3s = 32 bytes = AVX2 register width) */
153
159
  nk_define_cross_pack_size_(dots, e2m3, haswell, e2m3, e2m3, /*norm_value_type=*/f32, /*depth_simd_dimensions=*/32,
154
160
  /*dimensions_per_value=*/1)
155
- nk_define_cross_pack_(dots, e2m3, haswell, e2m3, e2m3, nk_assign_from_to_, /*norm_value_type=*/f32,
156
- nk_dots_reduce_sumsq_e2m3_,
161
+ nk_define_cross_pack_(dots, e2m3, haswell, e2m3, e2m3, nk_b256_vec_t, nk_load_b256_haswell_,
162
+ nk_partial_load_b8x32_serial_, nk_store_b256_haswell_, nk_partial_store_b8x32_serial_,
163
+ /*simd_width=*/32, /*norm_value_type=*/f32, nk_dots_reduce_sumsq_e2m3_,
157
164
  /*depth_simd_dimensions=*/32, /*dimensions_per_value=*/1)
158
165
  nk_define_cross_symmetric_(dots, e2m3, haswell, e2m3, f32, nk_b256_vec_t, nk_dot_e2m3x32_state_haswell_t, nk_b128_vec_t,
159
166
  nk_dot_e2m3x32_init_haswell, nk_load_b256_haswell_, nk_partial_load_b8x32_serial_,
@@ -170,8 +177,9 @@ nk_define_cross_packed_(dots, e2m3, haswell, e2m3, e2m3, f32, nk_b256_vec_t, nk_
170
177
  /* E3M2 GEMM: integer LUT path, depth_simd_dimensions=32 (32 e3m2s = 32 bytes = AVX2 register width) */
171
178
  nk_define_cross_pack_size_(dots, e3m2, haswell, e3m2, e3m2, /*norm_value_type=*/f32, /*depth_simd_dimensions=*/32,
172
179
  /*dimensions_per_value=*/1)
173
- nk_define_cross_pack_(dots, e3m2, haswell, e3m2, e3m2, nk_assign_from_to_, /*norm_value_type=*/f32,
174
- nk_dots_reduce_sumsq_e3m2_,
180
+ nk_define_cross_pack_(dots, e3m2, haswell, e3m2, e3m2, nk_b256_vec_t, nk_load_b256_haswell_,
181
+ nk_partial_load_b8x32_serial_, nk_store_b256_haswell_, nk_partial_store_b8x32_serial_,
182
+ /*simd_width=*/32, /*norm_value_type=*/f32, nk_dots_reduce_sumsq_e3m2_,
175
183
  /*depth_simd_dimensions=*/32, /*dimensions_per_value=*/1)
176
184
  nk_define_cross_symmetric_(dots, e3m2, haswell, e3m2, f32, nk_b256_vec_t, nk_dot_e3m2x32_state_haswell_t, nk_b128_vec_t,
177
185
  nk_dot_e3m2x32_init_haswell, nk_load_b256_haswell_, nk_partial_load_b8x32_serial_,
@@ -188,8 +196,10 @@ nk_define_cross_packed_(dots, e3m2, haswell, e3m2, e3m2, f32, nk_b256_vec_t, nk_
188
196
  /* I8 GEMM: depth_simd_dimensions=16 (16 i8s = 16 bytes = 128-bit input) */
189
197
  nk_define_cross_pack_size_(dots, i8, haswell, i8, i8, /*norm_value_type=*/u32, /*depth_simd_dimensions=*/16,
190
198
  /*dimensions_per_value=*/1)
191
- nk_define_cross_pack_(dots, i8, haswell, i8, i8, nk_assign_from_to_, /*norm_value_type=*/u32, nk_dots_reduce_sumsq_i8_,
192
- /*depth_simd_dimensions=*/16, /*dimensions_per_value=*/1)
199
+ nk_define_cross_pack_(dots, i8, haswell, i8, i8, nk_b128_vec_t, nk_load_b128_haswell_, nk_partial_load_b8x16_serial_,
200
+ nk_store_b128_haswell_, nk_partial_store_b8x16_serial_, /*simd_width=*/16,
201
+ /*norm_value_type=*/u32, nk_dots_reduce_sumsq_i8_, /*depth_simd_dimensions=*/16,
202
+ /*dimensions_per_value=*/1)
193
203
  nk_define_cross_symmetric_(dots, i8, haswell, i8, i32, nk_b128_vec_t, nk_dot_i8x16_state_haswell_t, nk_b128_vec_t,
194
204
  nk_dot_i8x16_init_haswell, nk_load_b128_haswell_, nk_partial_load_b8x16_serial_,
195
205
  nk_dot_i8x16_update_haswell, nk_dot_i8x16_finalize_haswell, nk_store_b128_haswell_,
@@ -204,8 +214,10 @@ nk_define_cross_packed_(dots, i8, haswell, i8, i8, i32, nk_b128_vec_t, nk_dot_i8
204
214
  /* U8 GEMM: depth_simd_dimensions=16 (16 u8s = 16 bytes = 128-bit input) */
205
215
  nk_define_cross_pack_size_(dots, u8, haswell, u8, u8, /*norm_value_type=*/u32, /*depth_simd_dimensions=*/16,
206
216
  /*dimensions_per_value=*/1)
207
- nk_define_cross_pack_(dots, u8, haswell, u8, u8, nk_assign_from_to_, /*norm_value_type=*/u32, nk_dots_reduce_sumsq_u8_,
208
- /*depth_simd_dimensions=*/16, /*dimensions_per_value=*/1)
217
+ nk_define_cross_pack_(dots, u8, haswell, u8, u8, nk_b128_vec_t, nk_load_b128_haswell_, nk_partial_load_b8x16_serial_,
218
+ nk_store_b128_haswell_, nk_partial_store_b8x16_serial_, /*simd_width=*/16,
219
+ /*norm_value_type=*/u32, nk_dots_reduce_sumsq_u8_, /*depth_simd_dimensions=*/16,
220
+ /*dimensions_per_value=*/1)
209
221
  nk_define_cross_symmetric_(dots, u8, haswell, u8, u32, nk_b128_vec_t, nk_dot_u8x16_state_haswell_t, nk_b128_vec_t,
210
222
  nk_dot_u8x16_init_haswell, nk_load_b128_haswell_, nk_partial_load_b8x16_serial_,
211
223
  nk_dot_u8x16_update_haswell, nk_dot_u8x16_finalize_haswell, nk_store_b128_haswell_,
@@ -222,9 +234,10 @@ nk_define_cross_packed_(dots, u8, haswell, u8, u8, u32, nk_b128_vec_t, nk_dot_u8
222
234
  nk_define_cross_compensated_pack_size_(dots, i4, haswell, i4x2, i4x2,
223
235
  /*sum_value_type=*/i32, /*norm_value_type=*/u32,
224
236
  /*depth_simd_dimensions=*/32, /*dimensions_per_value=*/2)
225
- nk_define_cross_compensated_pack_(dots, i4, haswell, i4x2, i4x2, nk_assign_from_to_,
226
- /*sum_value_type=*/i32, /*norm_value_type=*/u32, nk_dots_reduce_moments_i4_,
227
- /*depth_simd_dimensions=*/32, /*dimensions_per_value=*/2)
237
+ nk_define_cross_compensated_pack_(dots, i4, haswell, i4x2, i4x2, nk_b128_vec_t, nk_load_b128_haswell_,
238
+ nk_partial_load_b8x16_serial_, nk_store_b128_haswell_, nk_partial_store_b8x16_serial_,
239
+ /*simd_width=*/16, /*sum_value_type=*/i32, /*norm_value_type=*/u32,
240
+ nk_dots_reduce_moments_i4_, /*depth_simd_dimensions=*/32, /*dimensions_per_value=*/2)
228
241
  nk_define_cross_compensated_symmetric_(dots, i4, haswell, i4x2, i32,
229
242
  /*sum_value_type=*/i32, /*norm_value_type=*/u32, nk_b128_vec_t,
230
243
  nk_dot_i4x32_state_haswell_t, nk_b128_vec_t, nk_dot_i4x32_init_haswell,
@@ -249,8 +262,9 @@ nk_define_cross_compensated_packed_(dots, i4, haswell, i4x2, i4x2, i32,
249
262
  * Note: dimensions_per_value=2 because 2 nibbles (u4 values) are packed per byte */
250
263
  nk_define_cross_pack_size_(dots, u4, haswell, u4x2, u4x2, /*norm_value_type=*/u32, /*depth_simd_dimensions=*/32,
251
264
  /*dimensions_per_value=*/2)
252
- nk_define_cross_pack_(dots, u4, haswell, u4x2, u4x2, nk_assign_from_to_, /*norm_value_type=*/u32,
253
- nk_dots_reduce_sumsq_u4_,
265
+ nk_define_cross_pack_(dots, u4, haswell, u4x2, u4x2, nk_b128_vec_t, nk_load_b128_haswell_,
266
+ nk_partial_load_b8x16_serial_, nk_store_b128_haswell_, nk_partial_store_b8x16_serial_,
267
+ /*simd_width=*/16, /*norm_value_type=*/u32, nk_dots_reduce_sumsq_u4_,
254
268
  /*depth_simd_dimensions=*/32, /*dimensions_per_value=*/2)
255
269
  nk_define_cross_symmetric_(dots, u4, haswell, u4x2, u32, nk_b128_vec_t, nk_dot_u4x32_state_haswell_t, nk_b128_vec_t,
256
270
  nk_dot_u4x32_init_haswell, nk_load_b128_haswell_, nk_partial_load_b4x32_serial_,
@@ -266,8 +280,9 @@ nk_define_cross_packed_(dots, u4, haswell, u4x2, u4x2, u32, nk_b128_vec_t, nk_do
266
280
  /* U1 GEMM: depth_simd_dimensions=128 (128 bits = 16 bytes) */
267
281
  nk_define_cross_pack_size_(dots, u1, haswell, u1x8, u1x8, /*norm_value_type=*/u32, /*depth_simd_dimensions=*/128,
268
282
  /*dimensions_per_value=*/8)
269
- nk_define_cross_pack_(dots, u1, haswell, u1x8, u1x8, nk_assign_from_to_, /*norm_value_type=*/u32,
270
- nk_dots_reduce_sum_u1_,
283
+ nk_define_cross_pack_(dots, u1, haswell, u1x8, u1x8, nk_b128_vec_t, nk_load_b128_haswell_,
284
+ nk_partial_load_b8x16_serial_, nk_store_b128_haswell_, nk_partial_store_b8x16_serial_,
285
+ /*simd_width=*/16, /*norm_value_type=*/u32, nk_dots_reduce_sum_u1_,
271
286
  /*depth_simd_dimensions=*/128, /*dimensions_per_value=*/8)
272
287
  nk_define_cross_symmetric_(dots, u1, haswell, u1x8, u32, nk_b128_vec_t, nk_dot_u1x128_state_haswell_t, nk_b128_vec_t,
273
288
  nk_dot_u1x128_init_haswell, nk_load_b128_haswell_, nk_partial_load_b1x128_serial_,
@@ -8,11 +8,11 @@
8
8
  *
9
9
  * @section ice_dots_instructions Relevant Instructions
10
10
  *
11
- * Intrinsic Instruction Ice Genoa
12
- * _mm512_dpbusd_epi32 VPDPBUSD (ZMM, ZMM, ZMM) 5cy @ p0 4cy @ p01
13
- * _mm512_dpwssd_epi32 VPDPWSSD (ZMM, ZMM, ZMM) 5cy @ p0 4cy @ p01
14
- * _mm512_cvtepi8_epi32 VPMOVSXBD (ZMM, XMM) 3cy @ p5 3cy @ p12
15
- * _mm512_loadu_si512 VMOVDQU64 (ZMM, M512) 7cy @ p23 7cy @ p23
11
+ * Intrinsic Instruction Icelake Genoa
12
+ * _mm512_dpbusd_epi32 VPDPBUSD (ZMM, ZMM, ZMM) 5cy @ p0 4cy @ p01
13
+ * _mm512_dpwssd_epi32 VPDPWSSD (ZMM, ZMM, ZMM) 5cy @ p0 4cy @ p01
14
+ * _mm512_cvtepi8_epi32 VPMOVSXBD (ZMM, XMM) 3cy @ p5 3cy @ p12
15
+ * _mm512_loadu_si512 VMOVDQU64 (ZMM, M512) 7cy @ p23 7cy @ p23
16
16
  *
17
17
  * Ice Lake's VNNI instructions accelerate int8 GEMM by computing 4-element dot products per lane.
18
18
  * VPDPBUSD/VPDPWSSD bottleneck on port 0, limiting throughput to 1/cy. AMD Genoa achieves 0.5/cy
@@ -45,9 +45,11 @@ extern "C" {
45
45
  nk_define_cross_compensated_pack_size_(dots, i8, icelake, i8, i8,
46
46
  /*sum_value_type=*/i32, /*norm_value_type=*/u32,
47
47
  /*depth_simd_dimensions=*/64, /*dimensions_per_value=*/1)
48
- nk_define_cross_compensated_pack_(dots, i8, icelake, i8, i8, nk_assign_from_to_,
49
- /*sum_value_type=*/i32, /*norm_value_type=*/u32, nk_dots_reduce_moments_i8_,
50
- /*depth_simd_dimensions=*/64, /*dimensions_per_value=*/1)
48
+ nk_define_cross_compensated_pack_(dots, i8, icelake, i8, i8, nk_b512_vec_t, nk_load_b512_skylake_,
49
+ nk_partial_load_b8x64_skylake_, nk_store_b512_skylake_,
50
+ nk_partial_store_b8x64_skylake_, /*simd_width=*/64, /*sum_value_type=*/i32,
51
+ /*norm_value_type=*/u32, nk_dots_reduce_moments_i8_, /*depth_simd_dimensions=*/64,
52
+ /*dimensions_per_value=*/1)
51
53
  nk_define_cross_compensated_symmetric_(dots, i8, icelake, i8, i32,
52
54
  /*sum_value_type=*/i32, /*norm_value_type=*/u32, nk_b512_vec_t,
53
55
  nk_dot_i8x64_state_icelake_t, nk_b128_vec_t, nk_dot_i8x64_init_icelake,
@@ -72,9 +74,11 @@ nk_define_cross_compensated_packed_(dots, i8, icelake, i8, i8, i32,
72
74
  nk_define_cross_compensated_pack_size_(dots, u8, icelake, u8, u8,
73
75
  /*sum_value_type=*/u32, /*norm_value_type=*/u32,
74
76
  /*depth_simd_dimensions=*/64, /*dimensions_per_value=*/1)
75
- nk_define_cross_compensated_pack_(dots, u8, icelake, u8, u8, nk_assign_from_to_,
76
- /*sum_value_type=*/u32, /*norm_value_type=*/u32, nk_dots_reduce_moments_u8_,
77
- /*depth_simd_dimensions=*/64, /*dimensions_per_value=*/1)
77
+ nk_define_cross_compensated_pack_(dots, u8, icelake, u8, u8, nk_b512_vec_t, nk_load_b512_skylake_,
78
+ nk_partial_load_b8x64_skylake_, nk_store_b512_skylake_,
79
+ nk_partial_store_b8x64_skylake_, /*simd_width=*/64, /*sum_value_type=*/u32,
80
+ /*norm_value_type=*/u32, nk_dots_reduce_moments_u8_, /*depth_simd_dimensions=*/64,
81
+ /*dimensions_per_value=*/1)
78
82
  nk_define_cross_compensated_symmetric_(dots, u8, icelake, u8, u32,
79
83
  /*sum_value_type=*/u32, /*norm_value_type=*/u32, nk_b512_vec_t,
80
84
  nk_dot_u8x64_state_icelake_t, nk_b128_vec_t, nk_dot_u8x64_init_icelake,
@@ -99,9 +103,11 @@ nk_define_cross_compensated_packed_(dots, u8, icelake, u8, u8, u32,
99
103
  nk_define_cross_compensated_pack_size_(dots, i4, icelake, i4x2, i4x2,
100
104
  /*sum_value_type=*/i32, /*norm_value_type=*/u32,
101
105
  /*depth_simd_dimensions=*/128, /*dimensions_per_value=*/2)
102
- nk_define_cross_compensated_pack_(dots, i4, icelake, i4x2, i4x2, nk_assign_from_to_,
103
- /*sum_value_type=*/i32, /*norm_value_type=*/u32, nk_dots_reduce_moments_i4_,
104
- /*depth_simd_dimensions=*/128, /*dimensions_per_value=*/2)
106
+ nk_define_cross_compensated_pack_(dots, i4, icelake, i4x2, i4x2, nk_b512_vec_t, nk_load_b512_skylake_,
107
+ nk_partial_load_b8x64_skylake_, nk_store_b512_skylake_,
108
+ nk_partial_store_b8x64_skylake_, /*simd_width=*/64, /*sum_value_type=*/i32,
109
+ /*norm_value_type=*/u32, nk_dots_reduce_moments_i4_, /*depth_simd_dimensions=*/64,
110
+ /*dimensions_per_value=*/2)
105
111
  nk_define_cross_compensated_symmetric_(dots, i4, icelake, i4x2, i32,
106
112
  /*sum_value_type=*/i32, /*norm_value_type=*/u32, nk_b512_vec_t,
107
113
  nk_dot_i4x128_state_icelake_t, nk_b128_vec_t, nk_dot_i4x128_init_icelake,
@@ -125,8 +131,10 @@ nk_define_cross_compensated_packed_(dots, i4, icelake, i4x2, i4x2, i32,
125
131
  /* U4 GEMM: depth_simd_dimensions=128 (128 nibbles = 64 bytes = full cache line) */
126
132
  nk_define_cross_pack_size_(dots, u4, icelake, u4x2, u4x2, /*norm_value_type=*/u32, /*depth_simd_dimensions=*/128,
127
133
  /*dimensions_per_value=*/2)
128
- nk_define_cross_pack_(dots, u4, icelake, u4x2, u4x2, nk_assign_from_to_, /*norm_value_type=*/u32,
129
- nk_dots_reduce_sumsq_u4_, /*depth_simd_dimensions=*/128, /*dimensions_per_value=*/2)
134
+ nk_define_cross_pack_(dots, u4, icelake, u4x2, u4x2, nk_b512_vec_t, nk_load_b512_skylake_,
135
+ nk_partial_load_b8x64_skylake_, nk_store_b512_skylake_, nk_partial_store_b8x64_skylake_,
136
+ /*simd_width=*/64, /*norm_value_type=*/u32, nk_dots_reduce_sumsq_u4_,
137
+ /*depth_simd_dimensions=*/128, /*dimensions_per_value=*/2)
130
138
 
131
139
  nk_define_cross_symmetric_(dots, u4, icelake, u4x2, u32, nk_b512_vec_t, nk_dot_u4x128_state_icelake_t, nk_b128_vec_t,
132
140
  nk_dot_u4x128_init_icelake, nk_load_b512_skylake_, nk_partial_load_b4x128_skylake_,
@@ -142,8 +150,9 @@ nk_define_cross_packed_(dots, u4, icelake, u4x2, u4x2, u32, nk_b512_vec_t, nk_do
142
150
  /* U1 GEMM: depth_simd_dimensions=512 (512 bits = 64 bytes = full cache line) */
143
151
  nk_define_cross_pack_size_(dots, u1, icelake, u1x8, u1x8, /*norm_value_type=*/u32, /*depth_simd_dimensions=*/512,
144
152
  /*dimensions_per_value=*/8)
145
- nk_define_cross_pack_(dots, u1, icelake, u1x8, u1x8, nk_assign_from_to_, /*norm_value_type=*/u32,
146
- nk_dots_reduce_sum_u1_,
153
+ nk_define_cross_pack_(dots, u1, icelake, u1x8, u1x8, nk_b512_vec_t, nk_load_b512_skylake_,
154
+ nk_partial_load_b8x64_skylake_, nk_store_b512_skylake_, nk_partial_store_b8x64_skylake_,
155
+ /*simd_width=*/64, /*norm_value_type=*/u32, nk_dots_reduce_sum_u1_,
147
156
  /*depth_simd_dimensions=*/512, /*dimensions_per_value=*/8)
148
157
  nk_define_cross_symmetric_(dots, u1, icelake, u1x8, u32, nk_b512_vec_t, nk_dot_u1x512_state_icelake_t, nk_b128_vec_t,
149
158
  nk_dot_u1x512_init_icelake, nk_load_b512_skylake_, nk_partial_load_b1x512_skylake_,