numkong 7.0.0 → 7.4.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (315)
  1. package/README.md +239 -122
  2. package/binding.gyp +25 -491
  3. package/c/dispatch_bf16.c +59 -1
  4. package/c/dispatch_e2m3.c +41 -8
  5. package/c/dispatch_e3m2.c +49 -8
  6. package/c/dispatch_e4m3.c +51 -9
  7. package/c/dispatch_e5m2.c +45 -1
  8. package/c/dispatch_f16.c +79 -26
  9. package/c/dispatch_f16c.c +5 -5
  10. package/c/dispatch_f32.c +56 -0
  11. package/c/dispatch_f64.c +52 -0
  12. package/c/dispatch_i4.c +3 -0
  13. package/c/dispatch_i8.c +62 -3
  14. package/c/dispatch_other.c +18 -0
  15. package/c/dispatch_u1.c +54 -9
  16. package/c/dispatch_u4.c +3 -0
  17. package/c/dispatch_u8.c +64 -3
  18. package/c/numkong.c +3 -0
  19. package/include/README.md +79 -9
  20. package/include/numkong/attention/sapphireamx.h +278 -276
  21. package/include/numkong/attention/sme.h +983 -977
  22. package/include/numkong/attention.h +1 -1
  23. package/include/numkong/capabilities.h +289 -94
  24. package/include/numkong/cast/README.md +40 -40
  25. package/include/numkong/cast/diamond.h +64 -0
  26. package/include/numkong/cast/haswell.h +42 -194
  27. package/include/numkong/cast/icelake.h +42 -37
  28. package/include/numkong/cast/loongsonasx.h +252 -0
  29. package/include/numkong/cast/neon.h +216 -249
  30. package/include/numkong/cast/powervsx.h +449 -0
  31. package/include/numkong/cast/rvv.h +223 -274
  32. package/include/numkong/cast/sapphire.h +18 -18
  33. package/include/numkong/cast/serial.h +1018 -944
  34. package/include/numkong/cast/skylake.h +82 -23
  35. package/include/numkong/cast/v128relaxed.h +462 -105
  36. package/include/numkong/cast.h +24 -0
  37. package/include/numkong/cast.hpp +44 -0
  38. package/include/numkong/curved/README.md +17 -17
  39. package/include/numkong/curved/neon.h +131 -7
  40. package/include/numkong/curved/neonbfdot.h +6 -7
  41. package/include/numkong/curved/rvv.h +26 -26
  42. package/include/numkong/curved/smef64.h +186 -182
  43. package/include/numkong/curved.h +14 -18
  44. package/include/numkong/dot/README.md +154 -137
  45. package/include/numkong/dot/alder.h +43 -43
  46. package/include/numkong/dot/diamond.h +158 -0
  47. package/include/numkong/dot/genoa.h +4 -30
  48. package/include/numkong/dot/haswell.h +215 -180
  49. package/include/numkong/dot/icelake.h +190 -76
  50. package/include/numkong/dot/loongsonasx.h +671 -0
  51. package/include/numkong/dot/neon.h +124 -73
  52. package/include/numkong/dot/neonbfdot.h +11 -12
  53. package/include/numkong/dot/neonfhm.h +44 -46
  54. package/include/numkong/dot/neonfp8.h +323 -0
  55. package/include/numkong/dot/neonsdot.h +190 -76
  56. package/include/numkong/dot/powervsx.h +752 -0
  57. package/include/numkong/dot/rvv.h +92 -84
  58. package/include/numkong/dot/rvvbf16.h +12 -12
  59. package/include/numkong/dot/rvvhalf.h +12 -12
  60. package/include/numkong/dot/sapphire.h +4 -4
  61. package/include/numkong/dot/serial.h +66 -30
  62. package/include/numkong/dot/sierra.h +31 -31
  63. package/include/numkong/dot/skylake.h +142 -110
  64. package/include/numkong/dot/sve.h +217 -177
  65. package/include/numkong/dot/svebfdot.h +10 -10
  66. package/include/numkong/dot/svehalf.h +85 -41
  67. package/include/numkong/dot/svesdot.h +89 -0
  68. package/include/numkong/dot/v128relaxed.h +124 -89
  69. package/include/numkong/dot.h +114 -48
  70. package/include/numkong/dots/README.md +203 -203
  71. package/include/numkong/dots/alder.h +12 -9
  72. package/include/numkong/dots/diamond.h +86 -0
  73. package/include/numkong/dots/genoa.h +10 -4
  74. package/include/numkong/dots/haswell.h +63 -48
  75. package/include/numkong/dots/icelake.h +27 -18
  76. package/include/numkong/dots/loongsonasx.h +176 -0
  77. package/include/numkong/dots/neon.h +14 -11
  78. package/include/numkong/dots/neonbfdot.h +4 -3
  79. package/include/numkong/dots/neonfhm.h +11 -9
  80. package/include/numkong/dots/neonfp8.h +99 -0
  81. package/include/numkong/dots/neonsdot.h +48 -12
  82. package/include/numkong/dots/powervsx.h +194 -0
  83. package/include/numkong/dots/rvv.h +451 -344
  84. package/include/numkong/dots/sapphireamx.h +1028 -984
  85. package/include/numkong/dots/serial.h +213 -197
  86. package/include/numkong/dots/sierra.h +10 -7
  87. package/include/numkong/dots/skylake.h +47 -36
  88. package/include/numkong/dots/sme.h +2001 -2364
  89. package/include/numkong/dots/smebi32.h +175 -162
  90. package/include/numkong/dots/smef64.h +328 -323
  91. package/include/numkong/dots/v128relaxed.h +64 -41
  92. package/include/numkong/dots.h +573 -293
  93. package/include/numkong/dots.hpp +45 -43
  94. package/include/numkong/each/README.md +133 -137
  95. package/include/numkong/each/haswell.h +6 -6
  96. package/include/numkong/each/icelake.h +7 -7
  97. package/include/numkong/each/neon.h +76 -42
  98. package/include/numkong/each/neonbfdot.h +11 -12
  99. package/include/numkong/each/neonhalf.h +24 -116
  100. package/include/numkong/each/rvv.h +28 -28
  101. package/include/numkong/each/sapphire.h +27 -161
  102. package/include/numkong/each/serial.h +6 -6
  103. package/include/numkong/each/skylake.h +7 -7
  104. package/include/numkong/each/v128relaxed.h +562 -0
  105. package/include/numkong/each.h +148 -62
  106. package/include/numkong/each.hpp +2 -2
  107. package/include/numkong/geospatial/README.md +18 -18
  108. package/include/numkong/geospatial/haswell.h +365 -325
  109. package/include/numkong/geospatial/neon.h +350 -306
  110. package/include/numkong/geospatial/rvv.h +4 -4
  111. package/include/numkong/geospatial/skylake.h +376 -340
  112. package/include/numkong/geospatial/v128relaxed.h +366 -327
  113. package/include/numkong/geospatial.h +17 -17
  114. package/include/numkong/matrix.hpp +4 -4
  115. package/include/numkong/maxsim/README.md +14 -14
  116. package/include/numkong/maxsim/alder.h +6 -6
  117. package/include/numkong/maxsim/genoa.h +4 -4
  118. package/include/numkong/maxsim/haswell.h +6 -6
  119. package/include/numkong/maxsim/icelake.h +18 -18
  120. package/include/numkong/maxsim/neonsdot.h +21 -21
  121. package/include/numkong/maxsim/sapphireamx.h +14 -14
  122. package/include/numkong/maxsim/serial.h +6 -6
  123. package/include/numkong/maxsim/sme.h +221 -196
  124. package/include/numkong/maxsim/v128relaxed.h +6 -6
  125. package/include/numkong/mesh/README.md +62 -56
  126. package/include/numkong/mesh/haswell.h +339 -464
  127. package/include/numkong/mesh/neon.h +1100 -519
  128. package/include/numkong/mesh/neonbfdot.h +36 -68
  129. package/include/numkong/mesh/rvv.h +530 -435
  130. package/include/numkong/mesh/serial.h +75 -91
  131. package/include/numkong/mesh/skylake.h +1627 -302
  132. package/include/numkong/mesh/v128relaxed.h +443 -330
  133. package/include/numkong/mesh.h +63 -49
  134. package/include/numkong/mesh.hpp +4 -4
  135. package/include/numkong/numkong.h +3 -3
  136. package/include/numkong/numkong.hpp +1 -0
  137. package/include/numkong/probability/README.md +23 -19
  138. package/include/numkong/probability/neon.h +82 -52
  139. package/include/numkong/probability/rvv.h +28 -23
  140. package/include/numkong/probability/serial.h +51 -39
  141. package/include/numkong/probability.h +20 -23
  142. package/include/numkong/random.h +1 -1
  143. package/include/numkong/reduce/README.md +143 -138
  144. package/include/numkong/reduce/alder.h +81 -77
  145. package/include/numkong/reduce/haswell.h +222 -220
  146. package/include/numkong/reduce/neon.h +629 -519
  147. package/include/numkong/reduce/neonbfdot.h +7 -218
  148. package/include/numkong/reduce/neonfhm.h +9 -381
  149. package/include/numkong/reduce/neonsdot.h +9 -9
  150. package/include/numkong/reduce/rvv.h +928 -802
  151. package/include/numkong/reduce/serial.h +23 -27
  152. package/include/numkong/reduce/sierra.h +20 -20
  153. package/include/numkong/reduce/skylake.h +326 -324
  154. package/include/numkong/reduce/v128relaxed.h +52 -52
  155. package/include/numkong/reduce.h +4 -23
  156. package/include/numkong/reduce.hpp +156 -11
  157. package/include/numkong/scalar/README.md +6 -6
  158. package/include/numkong/scalar/haswell.h +26 -17
  159. package/include/numkong/scalar/loongsonasx.h +74 -0
  160. package/include/numkong/scalar/neon.h +9 -9
  161. package/include/numkong/scalar/powervsx.h +96 -0
  162. package/include/numkong/scalar/rvv.h +2 -2
  163. package/include/numkong/scalar/sapphire.h +21 -10
  164. package/include/numkong/scalar/serial.h +21 -21
  165. package/include/numkong/scalar.h +13 -0
  166. package/include/numkong/set/README.md +28 -28
  167. package/include/numkong/set/haswell.h +12 -12
  168. package/include/numkong/set/icelake.h +14 -14
  169. package/include/numkong/set/loongsonasx.h +181 -0
  170. package/include/numkong/set/neon.h +17 -18
  171. package/include/numkong/set/powervsx.h +326 -0
  172. package/include/numkong/set/rvv.h +4 -4
  173. package/include/numkong/set/serial.h +6 -6
  174. package/include/numkong/set/sve.h +60 -59
  175. package/include/numkong/set/v128relaxed.h +6 -6
  176. package/include/numkong/set.h +21 -7
  177. package/include/numkong/sets/README.md +26 -26
  178. package/include/numkong/sets/loongsonasx.h +52 -0
  179. package/include/numkong/sets/powervsx.h +65 -0
  180. package/include/numkong/sets/smebi32.h +395 -364
  181. package/include/numkong/sets.h +83 -40
  182. package/include/numkong/sparse/README.md +4 -4
  183. package/include/numkong/sparse/icelake.h +101 -101
  184. package/include/numkong/sparse/serial.h +1 -1
  185. package/include/numkong/sparse/sve2.h +137 -141
  186. package/include/numkong/sparse/turin.h +12 -12
  187. package/include/numkong/sparse.h +10 -10
  188. package/include/numkong/spatial/README.md +230 -226
  189. package/include/numkong/spatial/alder.h +113 -116
  190. package/include/numkong/spatial/diamond.h +240 -0
  191. package/include/numkong/spatial/genoa.h +0 -68
  192. package/include/numkong/spatial/haswell.h +74 -55
  193. package/include/numkong/spatial/icelake.h +539 -58
  194. package/include/numkong/spatial/loongsonasx.h +483 -0
  195. package/include/numkong/spatial/neon.h +125 -52
  196. package/include/numkong/spatial/neonbfdot.h +8 -9
  197. package/include/numkong/spatial/neonfp8.h +258 -0
  198. package/include/numkong/spatial/neonsdot.h +180 -12
  199. package/include/numkong/spatial/powervsx.h +738 -0
  200. package/include/numkong/spatial/rvv.h +146 -139
  201. package/include/numkong/spatial/rvvbf16.h +17 -12
  202. package/include/numkong/spatial/rvvhalf.h +13 -10
  203. package/include/numkong/spatial/serial.h +13 -12
  204. package/include/numkong/spatial/sierra.h +232 -39
  205. package/include/numkong/spatial/skylake.h +73 -74
  206. package/include/numkong/spatial/sve.h +93 -72
  207. package/include/numkong/spatial/svebfdot.h +29 -29
  208. package/include/numkong/spatial/svehalf.h +52 -26
  209. package/include/numkong/spatial/svesdot.h +142 -0
  210. package/include/numkong/spatial/v128relaxed.h +293 -41
  211. package/include/numkong/spatial.h +338 -82
  212. package/include/numkong/spatials/README.md +194 -194
  213. package/include/numkong/spatials/diamond.h +82 -0
  214. package/include/numkong/spatials/haswell.h +2 -2
  215. package/include/numkong/spatials/loongsonasx.h +153 -0
  216. package/include/numkong/spatials/neonfp8.h +111 -0
  217. package/include/numkong/spatials/neonsdot.h +34 -0
  218. package/include/numkong/spatials/powervsx.h +153 -0
  219. package/include/numkong/spatials/rvv.h +259 -243
  220. package/include/numkong/spatials/sapphireamx.h +173 -173
  221. package/include/numkong/spatials/serial.h +2 -2
  222. package/include/numkong/spatials/skylake.h +2 -2
  223. package/include/numkong/spatials/sme.h +590 -605
  224. package/include/numkong/spatials/smef64.h +139 -130
  225. package/include/numkong/spatials/v128relaxed.h +2 -2
  226. package/include/numkong/spatials.h +820 -500
  227. package/include/numkong/spatials.hpp +49 -48
  228. package/include/numkong/tensor.hpp +406 -17
  229. package/include/numkong/trigonometry/README.md +19 -19
  230. package/include/numkong/trigonometry/haswell.h +402 -401
  231. package/include/numkong/trigonometry/neon.h +386 -387
  232. package/include/numkong/trigonometry/rvv.h +52 -51
  233. package/include/numkong/trigonometry/serial.h +13 -13
  234. package/include/numkong/trigonometry/skylake.h +373 -369
  235. package/include/numkong/trigonometry/v128relaxed.h +375 -374
  236. package/include/numkong/trigonometry.h +13 -13
  237. package/include/numkong/trigonometry.hpp +2 -2
  238. package/include/numkong/types.h +287 -49
  239. package/include/numkong/types.hpp +436 -12
  240. package/include/numkong/vector.hpp +82 -14
  241. package/javascript/dist/cjs/numkong-wasm.js +6 -12
  242. package/javascript/dist/cjs/numkong.d.ts +7 -1
  243. package/javascript/dist/cjs/numkong.js +37 -11
  244. package/javascript/dist/cjs/types.d.ts +9 -0
  245. package/javascript/dist/cjs/types.js +96 -0
  246. package/javascript/dist/esm/numkong-browser.d.ts +14 -0
  247. package/javascript/dist/esm/numkong-browser.js +23 -0
  248. package/javascript/dist/esm/numkong-wasm.js +6 -12
  249. package/javascript/dist/esm/numkong.d.ts +7 -1
  250. package/javascript/dist/esm/numkong.js +37 -11
  251. package/javascript/dist/esm/types.d.ts +9 -0
  252. package/javascript/dist/esm/types.js +96 -0
  253. package/javascript/node-gyp-build.d.ts +4 -1
  254. package/javascript/numkong-browser.ts +40 -0
  255. package/javascript/numkong-wasm.ts +7 -13
  256. package/javascript/numkong.c +5 -26
  257. package/javascript/numkong.ts +36 -11
  258. package/javascript/tsconfig-base.json +1 -0
  259. package/javascript/tsconfig-cjs.json +6 -1
  260. package/javascript/types.ts +110 -0
  261. package/numkong.gypi +101 -0
  262. package/package.json +34 -13
  263. package/probes/arm_neon.c +8 -0
  264. package/probes/arm_neon_bfdot.c +9 -0
  265. package/probes/arm_neon_fhm.c +9 -0
  266. package/probes/arm_neon_half.c +8 -0
  267. package/probes/arm_neon_sdot.c +9 -0
  268. package/probes/arm_neonfp8.c +9 -0
  269. package/probes/arm_sme.c +16 -0
  270. package/probes/arm_sme2.c +16 -0
  271. package/probes/arm_sme2p1.c +16 -0
  272. package/probes/arm_sme_bf16.c +16 -0
  273. package/probes/arm_sme_bi32.c +16 -0
  274. package/probes/arm_sme_f64.c +16 -0
  275. package/probes/arm_sme_fa64.c +14 -0
  276. package/probes/arm_sme_half.c +16 -0
  277. package/probes/arm_sme_lut2.c +15 -0
  278. package/probes/arm_sve.c +18 -0
  279. package/probes/arm_sve2.c +20 -0
  280. package/probes/arm_sve2p1.c +18 -0
  281. package/probes/arm_sve_bfdot.c +20 -0
  282. package/probes/arm_sve_half.c +18 -0
  283. package/probes/arm_sve_sdot.c +21 -0
  284. package/probes/loongarch_lasx.c +12 -0
  285. package/probes/power_vsx.c +12 -0
  286. package/probes/probe.js +127 -0
  287. package/probes/riscv_rvv.c +14 -0
  288. package/probes/riscv_rvv_bb.c +15 -0
  289. package/probes/riscv_rvv_bf16.c +17 -0
  290. package/probes/riscv_rvv_half.c +14 -0
  291. package/probes/wasm_v128relaxed.c +11 -0
  292. package/probes/x86_alder.c +17 -0
  293. package/probes/x86_diamond.c +17 -0
  294. package/probes/x86_genoa.c +17 -0
  295. package/probes/x86_graniteamx.c +19 -0
  296. package/probes/x86_haswell.c +11 -0
  297. package/probes/x86_icelake.c +17 -0
  298. package/probes/x86_sapphire.c +16 -0
  299. package/probes/x86_sapphireamx.c +18 -0
  300. package/probes/x86_sierra.c +17 -0
  301. package/probes/x86_skylake.c +15 -0
  302. package/probes/x86_turin.c +17 -0
  303. package/wasm/numkong-emscripten.js +2 -0
  304. package/wasm/numkong.d.ts +14 -0
  305. package/wasm/numkong.js +1124 -0
  306. package/wasm/numkong.wasm +0 -0
  307. package/include/numkong/curved/neonhalf.h +0 -212
  308. package/include/numkong/dot/neonhalf.h +0 -198
  309. package/include/numkong/dots/neonhalf.h +0 -57
  310. package/include/numkong/mesh/neonhalf.h +0 -616
  311. package/include/numkong/reduce/neonhalf.h +0 -157
  312. package/include/numkong/spatial/neonhalf.h +0 -118
  313. package/include/numkong/spatial/sapphire.h +0 -343
  314. package/include/numkong/spatials/neonhalf.h +0 -58
  315. package/javascript/README.md +0 -246
@@ -8,10 +8,10 @@
8
8
  *
9
9
  * @section dot_alder_instructions AVX-VNNI Instructions Performance
10
10
  *
11
- * Intrinsic Instruction Alder Lake Raptor Lake
12
- * _mm256_dpbusd_epi32 VPDPBUSD (YMM, YMM, YMM) 4cy @ p05 4cy @ p05
13
- * _mm256_madd_epi16 VPMADDWD (YMM, YMM, YMM) 4cy @ p05 4cy @ p05
14
- * _mm256_sad_epu8 VPSADBW (YMM, YMM, YMM) 3cy @ p5 3cy @ p5
11
+ * Intrinsic Instruction Alder Lake Raptor Lake
12
+ * _mm256_dpbusd_epi32 VPDPBUSD (YMM, YMM, YMM) 4cy @ p05 4cy @ p05
13
+ * _mm256_madd_epi16 VPMADDWD (YMM, YMM, YMM) 4cy @ p05 4cy @ p05
14
+ * _mm256_sad_epu8 VPSADBW (YMM, YMM, YMM) 3cy @ p5 3cy @ p5
15
15
  *
16
16
  * Alder Lake and Raptor Lake support AVX-VNNI (256-bit VNNI)
17
17
  * for accelerated integer dot products. This is the 256-bit variant of AVX-512 VNNI found on Ice Lake.
@@ -208,13 +208,13 @@ NK_INTERNAL void nk_dot_i8x32_finalize_alder(
208
208
  _mm256_extracti128_si256(state_d->biased_product_sum_i32x8, 1));
209
209
 
210
210
  // 4-way transpose reduce
211
- __m128i t_ab_lo = _mm_unpacklo_epi32(sum_a_i32x4, sum_b_i32x4);
212
- __m128i t_cd_lo = _mm_unpacklo_epi32(sum_c_i32x4, sum_d_i32x4);
213
- __m128i t_ab_hi = _mm_unpackhi_epi32(sum_a_i32x4, sum_b_i32x4);
214
- __m128i t_cd_hi = _mm_unpackhi_epi32(sum_c_i32x4, sum_d_i32x4);
211
+ __m128i t_ab_low = _mm_unpacklo_epi32(sum_a_i32x4, sum_b_i32x4);
212
+ __m128i t_cd_low = _mm_unpacklo_epi32(sum_c_i32x4, sum_d_i32x4);
213
+ __m128i t_ab_high = _mm_unpackhi_epi32(sum_a_i32x4, sum_b_i32x4);
214
+ __m128i t_cd_high = _mm_unpackhi_epi32(sum_c_i32x4, sum_d_i32x4);
215
215
  __m128i biased_i32x4 = _mm_add_epi32(
216
- _mm_add_epi32(_mm_unpacklo_epi64(t_ab_lo, t_cd_lo), _mm_unpackhi_epi64(t_ab_lo, t_cd_lo)),
217
- _mm_add_epi32(_mm_unpacklo_epi64(t_ab_hi, t_cd_hi), _mm_unpackhi_epi64(t_ab_hi, t_cd_hi)));
216
+ _mm_add_epi32(_mm_unpacklo_epi64(t_ab_low, t_cd_low), _mm_unpackhi_epi64(t_ab_low, t_cd_low)),
217
+ _mm_add_epi32(_mm_unpacklo_epi64(t_ab_high, t_cd_high), _mm_unpackhi_epi64(t_ab_high, t_cd_high)));
218
218
 
219
219
  // Apply compensation: result = biased − 128 × Σb
220
220
  __m128i correction_i32x4 = _mm_slli_epi32(b_sums.xmm, 7); // × 128
@@ -328,13 +328,13 @@ NK_INTERNAL void nk_dot_u8x32_finalize_alder(
328
328
  _mm256_extracti128_si256(state_d->biased_product_sum_i32x8, 1));
329
329
 
330
330
  // 4-way transpose reduce
331
- __m128i t_ab_lo = _mm_unpacklo_epi32(sum_a_i32x4, sum_b_i32x4);
332
- __m128i t_cd_lo = _mm_unpacklo_epi32(sum_c_i32x4, sum_d_i32x4);
333
- __m128i t_ab_hi = _mm_unpackhi_epi32(sum_a_i32x4, sum_b_i32x4);
334
- __m128i t_cd_hi = _mm_unpackhi_epi32(sum_c_i32x4, sum_d_i32x4);
331
+ __m128i t_ab_low = _mm_unpacklo_epi32(sum_a_i32x4, sum_b_i32x4);
332
+ __m128i t_cd_low = _mm_unpacklo_epi32(sum_c_i32x4, sum_d_i32x4);
333
+ __m128i t_ab_high = _mm_unpackhi_epi32(sum_a_i32x4, sum_b_i32x4);
334
+ __m128i t_cd_high = _mm_unpackhi_epi32(sum_c_i32x4, sum_d_i32x4);
335
335
  __m128i biased_i32x4 = _mm_add_epi32(
336
- _mm_add_epi32(_mm_unpacklo_epi64(t_ab_lo, t_cd_lo), _mm_unpackhi_epi64(t_ab_lo, t_cd_lo)),
337
- _mm_add_epi32(_mm_unpacklo_epi64(t_ab_hi, t_cd_hi), _mm_unpackhi_epi64(t_ab_hi, t_cd_hi)));
336
+ _mm_add_epi32(_mm_unpacklo_epi64(t_ab_low, t_cd_low), _mm_unpackhi_epi64(t_ab_low, t_cd_low)),
337
+ _mm_add_epi32(_mm_unpacklo_epi64(t_ab_high, t_cd_high), _mm_unpackhi_epi64(t_ab_high, t_cd_high)));
338
338
 
339
339
  // Apply compensation: result = biased + 128 × Σb
340
340
  __m128i correction_i32x4 = _mm_slli_epi32(b_sums.xmm, 7); // × 128
@@ -355,20 +355,20 @@ NK_INTERNAL void nk_sum_i8x32_init_alder(nk_sum_i8x32_state_alder_t *state) {
355
355
  state->biased_sum_u64x4 = _mm256_setzero_si256();
356
356
  }
357
357
  NK_INTERNAL void nk_sum_i8x32_update_alder(nk_sum_i8x32_state_alder_t *state, nk_b256_vec_t vector) {
358
- /* Convert signed→unsigned via XOR 0x80, then SAD against zero gives sum of unsigned values */
358
+ // Convert signed→unsigned via XOR 0x80, then SAD against zero gives sum of unsigned values
359
359
  __m256i vector_unsigned_u8x32 = _mm256_xor_si256(vector.ymm, _mm256_set1_epi8((char)0x80));
360
360
  __m256i sad_result_u64x4 = _mm256_sad_epu8(vector_unsigned_u8x32, _mm256_setzero_si256());
361
361
  state->biased_sum_u64x4 = _mm256_add_epi64(state->biased_sum_u64x4, sad_result_u64x4);
362
362
  }
363
363
  NK_INTERNAL nk_i32_t nk_sum_i8x32_finalize_alder(nk_sum_i8x32_state_alder_t const *state, nk_size_t count) {
364
- /* Horizontal reduce u64x4 → scalar */
364
+ // Horizontal reduce u64x4 → scalar
365
365
  __m128i low_u64x2 = _mm256_castsi256_si128(state->biased_sum_u64x4);
366
366
  __m128i high_u64x2 = _mm256_extracti128_si256(state->biased_sum_u64x4, 1);
367
367
  __m128i paired_u64x2 = _mm_add_epi64(low_u64x2, high_u64x2);
368
368
  __m128i shuffled_u64x2 = _mm_shuffle_epi32(paired_u64x2, _MM_SHUFFLE(1, 0, 3, 2));
369
369
  __m128i total_u64x2 = _mm_add_epi64(paired_u64x2, shuffled_u64x2);
370
370
  nk_u64_t unsigned_sum = (nk_u64_t)_mm_cvtsi128_si64(total_u64x2);
371
- /* Undo XOR bias: signed_sum = unsigned_sum - 128 * count */
371
+ // Undo XOR bias: signed_sum = unsigned_sum - 128 * count
372
372
  return (nk_i32_t)((nk_i64_t)unsigned_sum - 128 * (nk_i64_t)count);
373
373
  }
374
374
 
@@ -403,10 +403,10 @@ NK_PUBLIC void nk_dot_e2m3_alder(nk_e2m3_t const *a_scalars, nk_e2m3_t const *b_
403
403
  // This is the Alder Lake (256-bit AVX-VNNI) variant of the Ice Lake kernel.
404
404
  // DPBUSD replaces MADDUBS+MADD (2 instructions → 1), accumulating u8×i8→i32 directly.
405
405
  //
406
- __m256i const lut_lower_u8x32 = _mm256_set_epi8( //
406
+ __m256i const lut_low_u8x32 = _mm256_set_epi8( //
407
407
  30, 28, 26, 24, 22, 20, 18, 16, 14, 12, 10, 8, 6, 4, 2, 0, //
408
408
  30, 28, 26, 24, 22, 20, 18, 16, 14, 12, 10, 8, 6, 4, 2, 0);
409
- __m256i const lut_upper_u8x32 = _mm256_set_epi8( //
409
+ __m256i const lut_high_u8x32 = _mm256_set_epi8( //
410
410
  120, 112, 104, 96, 88, 80, 72, 64, 60, 56, 52, 48, 44, 40, 36, 32, //
411
411
  120, 112, 104, 96, 88, 80, 72, 64, 60, 56, 52, 48, 44, 40, 36, 32);
412
412
  __m256i const nibble_mask_u8x32 = _mm256_set1_epi8(0x0F);
@@ -436,18 +436,18 @@ nk_dot_e2m3_alder_cycle:
436
436
  __m256i b_magnitude_u8x32 = _mm256_and_si256(b_e2m3_u8x32, magnitude_mask_u8x32);
437
437
  __m256i a_shuffle_index_u8x32 = _mm256_and_si256(a_magnitude_u8x32, nibble_mask_u8x32);
438
438
  __m256i b_shuffle_index_u8x32 = _mm256_and_si256(b_magnitude_u8x32, nibble_mask_u8x32);
439
- __m256i a_upper_select_u8x32 = _mm256_cmpeq_epi8(_mm256_and_si256(a_magnitude_u8x32, half_select_u8x32),
440
- half_select_u8x32);
441
- __m256i b_upper_select_u8x32 = _mm256_cmpeq_epi8(_mm256_and_si256(b_magnitude_u8x32, half_select_u8x32),
442
- half_select_u8x32);
439
+ __m256i a_high_select_u8x32 = _mm256_cmpeq_epi8(_mm256_and_si256(a_magnitude_u8x32, half_select_u8x32),
440
+ half_select_u8x32);
441
+ __m256i b_high_select_u8x32 = _mm256_cmpeq_epi8(_mm256_and_si256(b_magnitude_u8x32, half_select_u8x32),
442
+ half_select_u8x32);
443
443
 
444
444
  // Dual VPSHUFB: lookup in both halves, blend based on bit 4
445
- __m256i a_unsigned_u8x32 = _mm256_blendv_epi8(_mm256_shuffle_epi8(lut_lower_u8x32, a_shuffle_index_u8x32),
446
- _mm256_shuffle_epi8(lut_upper_u8x32, a_shuffle_index_u8x32),
447
- a_upper_select_u8x32);
448
- __m256i b_unsigned_u8x32 = _mm256_blendv_epi8(_mm256_shuffle_epi8(lut_lower_u8x32, b_shuffle_index_u8x32),
449
- _mm256_shuffle_epi8(lut_upper_u8x32, b_shuffle_index_u8x32),
450
- b_upper_select_u8x32);
445
+ __m256i a_unsigned_u8x32 = _mm256_blendv_epi8(_mm256_shuffle_epi8(lut_low_u8x32, a_shuffle_index_u8x32),
446
+ _mm256_shuffle_epi8(lut_high_u8x32, a_shuffle_index_u8x32),
447
+ a_high_select_u8x32);
448
+ __m256i b_unsigned_u8x32 = _mm256_blendv_epi8(_mm256_shuffle_epi8(lut_low_u8x32, b_shuffle_index_u8x32),
449
+ _mm256_shuffle_epi8(lut_high_u8x32, b_shuffle_index_u8x32),
450
+ b_high_select_u8x32);
451
451
 
452
452
  // Combined sign: (a ^ b) & 0x20, negate b where signs differ
453
453
  __m256i sign_combined_u8x32 = _mm256_and_si256(_mm256_xor_si256(a_e2m3_u8x32, b_e2m3_u8x32), sign_mask_u8x32);
@@ -474,10 +474,10 @@ NK_INTERNAL void nk_dot_e2m3x32_update_alder(nk_dot_e2m3x32_state_alder_t *state
474
474
  nk_size_t depth_offset, nk_size_t active_dimensions) {
475
475
  nk_unused_(depth_offset);
476
476
  nk_unused_(active_dimensions);
477
- __m256i const lut_lower_u8x32 = _mm256_set_epi8( //
477
+ __m256i const lut_low_u8x32 = _mm256_set_epi8( //
478
478
  30, 28, 26, 24, 22, 20, 18, 16, 14, 12, 10, 8, 6, 4, 2, 0, //
479
479
  30, 28, 26, 24, 22, 20, 18, 16, 14, 12, 10, 8, 6, 4, 2, 0);
480
- __m256i const lut_upper_u8x32 = _mm256_set_epi8( //
480
+ __m256i const lut_high_u8x32 = _mm256_set_epi8( //
481
481
  120, 112, 104, 96, 88, 80, 72, 64, 60, 56, 52, 48, 44, 40, 36, 32, //
482
482
  120, 112, 104, 96, 88, 80, 72, 64, 60, 56, 52, 48, 44, 40, 36, 32);
483
483
  __m256i const nibble_mask_u8x32 = _mm256_set1_epi8(0x0F);
@@ -493,18 +493,18 @@ NK_INTERNAL void nk_dot_e2m3x32_update_alder(nk_dot_e2m3x32_state_alder_t *state
493
493
  __m256i b_magnitude_u8x32 = _mm256_and_si256(b_e2m3_u8x32, magnitude_mask_u8x32);
494
494
  __m256i a_shuffle_index_u8x32 = _mm256_and_si256(a_magnitude_u8x32, nibble_mask_u8x32);
495
495
  __m256i b_shuffle_index_u8x32 = _mm256_and_si256(b_magnitude_u8x32, nibble_mask_u8x32);
496
- __m256i a_upper_select_u8x32 = _mm256_cmpeq_epi8(_mm256_and_si256(a_magnitude_u8x32, half_select_u8x32),
497
- half_select_u8x32);
498
- __m256i b_upper_select_u8x32 = _mm256_cmpeq_epi8(_mm256_and_si256(b_magnitude_u8x32, half_select_u8x32),
499
- half_select_u8x32);
496
+ __m256i a_high_select_u8x32 = _mm256_cmpeq_epi8(_mm256_and_si256(a_magnitude_u8x32, half_select_u8x32),
497
+ half_select_u8x32);
498
+ __m256i b_high_select_u8x32 = _mm256_cmpeq_epi8(_mm256_and_si256(b_magnitude_u8x32, half_select_u8x32),
499
+ half_select_u8x32);
500
500
 
501
501
  // Dual VPSHUFB + blend
502
- __m256i a_unsigned_u8x32 = _mm256_blendv_epi8(_mm256_shuffle_epi8(lut_lower_u8x32, a_shuffle_index_u8x32),
503
- _mm256_shuffle_epi8(lut_upper_u8x32, a_shuffle_index_u8x32),
504
- a_upper_select_u8x32);
505
- __m256i b_unsigned_u8x32 = _mm256_blendv_epi8(_mm256_shuffle_epi8(lut_lower_u8x32, b_shuffle_index_u8x32),
506
- _mm256_shuffle_epi8(lut_upper_u8x32, b_shuffle_index_u8x32),
507
- b_upper_select_u8x32);
502
+ __m256i a_unsigned_u8x32 = _mm256_blendv_epi8(_mm256_shuffle_epi8(lut_low_u8x32, a_shuffle_index_u8x32),
503
+ _mm256_shuffle_epi8(lut_high_u8x32, a_shuffle_index_u8x32),
504
+ a_high_select_u8x32);
505
+ __m256i b_unsigned_u8x32 = _mm256_blendv_epi8(_mm256_shuffle_epi8(lut_low_u8x32, b_shuffle_index_u8x32),
506
+ _mm256_shuffle_epi8(lut_high_u8x32, b_shuffle_index_u8x32),
507
+ b_high_select_u8x32);
508
508
 
509
509
  // Combined sign + conditional negate
510
510
  __m256i sign_combined_u8x32 = _mm256_and_si256(_mm256_xor_si256(a_e2m3_u8x32, b_e2m3_u8x32), sign_mask_u8x32);
@@ -0,0 +1,158 @@
1
+ /**
2
+ * @brief SIMD-accelerated Dot Products for Diamond Rapids.
3
+ * @file include/numkong/dot/diamond.h
4
+ * @author Ash Vardanian
5
+ * @date March 23, 2026
6
+ *
7
+ * @sa include/numkong/dot.h
8
+ *
9
+ * @section dot_diamond_instructions Key AVX10.2 FP8 + FP16 VNNI Instructions
10
+ *
11
+ * Intrinsic Instruction Diamond Rapids
12
+ * _mm512_cvthf8_ph VCVTHF82PH (ZMM, YMM) ~3cy (estimated)
13
+ * _mm512_cvtbf8_ph VCVTBF82PH (ZMM, YMM) ~3cy (estimated)
14
+ * _mm512_dpph_ps VDPPHPS (ZMM, ZMM, ZMM) ~6cy (estimated)
15
+ *
16
+ * Diamond Rapids (AVX10.2) introduces native FP8→FP16 conversion via VCVTHF82PH (E4M3→FP16)
17
+ * and VCVTBF82PH (E5M2→FP16), replacing the multi-instruction arithmetic conversion used by
18
+ * Genoa's BF16 path. VDPPHPS then computes two FP16 dot products per 32-bit lane, accumulating
19
+ * into FP32 — providing the same 32-element throughput as Genoa's VDPBF16PS but with FP16
20
+ * intermediate precision (10-bit mantissa vs BF16's 7-bit).
21
+ *
22
+ * @section dot_diamond_stateful Stateful Streaming Logic
23
+ *
24
+ * Defines stateful init/update/finalize helpers for tiled GEMM via the dots/ macros:
25
+ * - nk_dot_through_f16_state_diamond_t_ shared by both E4M3 and E5M2 (FP16→VDPPHPS→FP32)
26
+ */
27
+ #ifndef NK_DOT_DIAMOND_H
28
+ #define NK_DOT_DIAMOND_H
29
+
30
+ #if NK_TARGET_X86_
31
+ #if NK_TARGET_DIAMOND
32
+
33
+ #include "numkong/types.h"
34
+ #include "numkong/cast/diamond.h" // `nk_load_e4m3x32_to_f16x32_diamond_`
35
+ #include "numkong/reduce/skylake.h" // `nk_reduce_add_f32x16_skylake_`
36
+ #include "numkong/dot/skylake.h" // `nk_dot_through_f32_finalize_skylake_`
37
+
38
+ #if defined(__cplusplus)
39
+ extern "C" {
40
+ #endif
41
+
42
+ #if defined(__clang__)
43
+ #pragma clang attribute push( \
44
+ __attribute__((target("avx2,avx512f,avx512vl,avx512bw,avx512dq,avx512fp16,avx10.2-512,f16c,fma,bmi,bmi2"))), \
45
+ apply_to = function)
46
+ #elif defined(__GNUC__)
47
+ #pragma GCC push_options
48
+ #pragma GCC target("avx2", "avx512f", "avx512vl", "avx512bw", "avx512dq", "avx512fp16", "avx10.2-512", "f16c", "fma", \
49
+ "bmi", "bmi2")
50
+ #endif
51
+
52
+ NK_PUBLIC void nk_dot_e4m3_diamond(nk_e4m3_t const *a_scalars, nk_e4m3_t const *b_scalars, nk_size_t count_scalars,
53
+ nk_f32_t *result) {
54
+ __m256i a_e4m3x32, b_e4m3x32;
55
+ __m512 sum_f32x16 = _mm512_setzero_ps();
56
+
57
+ nk_dot_e4m3_diamond_cycle:
58
+ if (count_scalars < 32) {
59
+ __mmask32 mask = (__mmask32)_bzhi_u32(0xFFFFFFFF, count_scalars);
60
+ a_e4m3x32 = _mm256_maskz_loadu_epi8(mask, a_scalars);
61
+ b_e4m3x32 = _mm256_maskz_loadu_epi8(mask, b_scalars);
62
+ count_scalars = 0;
63
+ }
64
+ else {
65
+ a_e4m3x32 = _mm256_loadu_epi8(a_scalars);
66
+ b_e4m3x32 = _mm256_loadu_epi8(b_scalars);
67
+ a_scalars += 32, b_scalars += 32, count_scalars -= 32;
68
+ }
69
+ __m512h a_f16x32 = _mm512_cvthf8_ph(a_e4m3x32);
70
+ __m512h b_f16x32 = _mm512_cvthf8_ph(b_e4m3x32);
71
+ sum_f32x16 = _mm512_dpph_ps(sum_f32x16, a_f16x32, b_f16x32);
72
+ if (count_scalars) goto nk_dot_e4m3_diamond_cycle;
73
+
74
+ *result = nk_reduce_add_f32x16_skylake_(sum_f32x16);
75
+ }
76
+
77
+ NK_PUBLIC void nk_dot_e5m2_diamond(nk_e5m2_t const *a_scalars, nk_e5m2_t const *b_scalars, nk_size_t count_scalars,
78
+ nk_f32_t *result) {
79
+ __m256i a_e5m2x32, b_e5m2x32;
80
+ __m512 sum_f32x16 = _mm512_setzero_ps();
81
+
82
+ nk_dot_e5m2_diamond_cycle:
83
+ if (count_scalars < 32) {
84
+ __mmask32 mask = (__mmask32)_bzhi_u32(0xFFFFFFFF, count_scalars);
85
+ a_e5m2x32 = _mm256_maskz_loadu_epi8(mask, a_scalars);
86
+ b_e5m2x32 = _mm256_maskz_loadu_epi8(mask, b_scalars);
87
+ count_scalars = 0;
88
+ }
89
+ else {
90
+ a_e5m2x32 = _mm256_loadu_epi8(a_scalars);
91
+ b_e5m2x32 = _mm256_loadu_epi8(b_scalars);
92
+ a_scalars += 32, b_scalars += 32, count_scalars -= 32;
93
+ }
94
+ __m512h a_f16x32 = _mm512_cvtbf8_ph(a_e5m2x32);
95
+ __m512h b_f16x32 = _mm512_cvtbf8_ph(b_e5m2x32);
96
+ sum_f32x16 = _mm512_dpph_ps(sum_f32x16, a_f16x32, b_f16x32);
97
+ if (count_scalars) goto nk_dot_e5m2_diamond_cycle;
98
+
99
+ *result = nk_reduce_add_f32x16_skylake_(sum_f32x16);
100
+ }
101
+
102
+ NK_PUBLIC void nk_dot_f16_diamond(nk_f16_t const *a_scalars, nk_f16_t const *b_scalars, nk_size_t count_scalars,
103
+ nk_f32_t *result) {
104
+ __m512h a_f16x32, b_f16x32;
105
+ __m512 sum_f32x16 = _mm512_setzero_ps();
106
+
107
+ nk_dot_f16_diamond_cycle:
108
+ if (count_scalars < 32) {
109
+ __mmask32 mask = (__mmask32)_bzhi_u32(0xFFFFFFFF, count_scalars);
110
+ a_f16x32 = _mm512_castsi512_ph(_mm512_maskz_loadu_epi16(mask, a_scalars));
111
+ b_f16x32 = _mm512_castsi512_ph(_mm512_maskz_loadu_epi16(mask, b_scalars));
112
+ count_scalars = 0;
113
+ }
114
+ else {
115
+ a_f16x32 = _mm512_castsi512_ph(_mm512_loadu_epi16(a_scalars));
116
+ b_f16x32 = _mm512_castsi512_ph(_mm512_loadu_epi16(b_scalars));
117
+ a_scalars += 32, b_scalars += 32, count_scalars -= 32;
118
+ }
119
+ sum_f32x16 = _mm512_dpph_ps(sum_f32x16, a_f16x32, b_f16x32);
120
+ if (count_scalars) goto nk_dot_f16_diamond_cycle;
121
+
122
+ *result = nk_reduce_add_f32x16_skylake_(sum_f32x16);
123
+ }
124
+
125
+ typedef nk_dot_through_f32_state_skylake_t_ nk_dot_through_f16_state_diamond_t_;
126
+
127
+ NK_INTERNAL void nk_dot_through_f16_init_diamond_(nk_dot_through_f16_state_diamond_t_ *state) {
128
+ state->sum_f32x16 = _mm512_setzero();
129
+ }
130
+
131
+ NK_INTERNAL void nk_dot_through_f16_update_diamond_(nk_dot_through_f16_state_diamond_t_ *state, nk_b512_vec_t a,
132
+ nk_b512_vec_t b, nk_size_t depth_offset,
133
+ nk_size_t active_dimensions) {
134
+ nk_unused_(depth_offset);
135
+ nk_unused_(active_dimensions);
136
+ state->sum_f32x16 = _mm512_dpph_ps(state->sum_f32x16, a.zmm_ph, b.zmm_ph);
137
+ }
138
+
139
+ NK_INTERNAL void nk_dot_through_f16_finalize_diamond_( //
140
+ nk_dot_through_f16_state_diamond_t_ const *state_a, nk_dot_through_f16_state_diamond_t_ const *state_b, //
141
+ nk_dot_through_f16_state_diamond_t_ const *state_c, nk_dot_through_f16_state_diamond_t_ const *state_d, //
142
+ nk_size_t total_dimensions, nk_b128_vec_t *result) {
143
+ nk_dot_through_f32_finalize_skylake_(state_a, state_b, state_c, state_d, total_dimensions, result);
144
+ }
145
+
146
+ #if defined(__clang__)
147
+ #pragma clang attribute pop
148
+ #elif defined(__GNUC__)
149
+ #pragma GCC pop_options
150
+ #endif
151
+
152
+ #if defined(__cplusplus)
153
+ } // extern "C"
154
+ #endif
155
+
156
+ #endif // NK_TARGET_DIAMOND
157
+ #endif // NK_TARGET_X86_
158
+ #endif // NK_DOT_DIAMOND_H
@@ -8,10 +8,10 @@
8
8
  *
9
9
  * @section dot_genoa_instructions Key AVX-512 BF16 Instructions
10
10
  *
11
- * Intrinsic Instruction Latency Throughput Ports
12
- * _mm512_dpbf16_ps VDPBF16PS (ZMM, ZMM, ZMM) 4cy 0.5/cy p01
13
- * _mm512_fmadd_ps VFMADD132PS (ZMM, ZMM, ZMM) 4cy 0.5/cy p01
14
- * _mm512_add_ps VADDPS (ZMM, ZMM, ZMM) 4cy 0.5/cy p01
11
+ * Intrinsic Instruction Genoa Alder Lake
12
+ * _mm512_dpbf16_ps VDPBF16PS (ZMM, ZMM, ZMM) 6cy @ p01 8cy @ p0+p0+p5+p5
13
+ * _mm512_fmadd_ps VFMADD132PS (ZMM, ZMM, ZMM) 4cy @ p01 4cy @ p0
14
+ * _mm512_add_ps VADDPS (ZMM, ZMM, ZMM) 4cy @ p01 3cy @ p05
15
15
  *
16
16
  * AMD Genoa introduces native AVX-512 BF16 support with VDPBF16PS, which computes two BF16 dot products
17
17
  * per 32-bit lane (32 BF16 multiplies accumulated into 16 FP32 values per instruction). This provides
@@ -208,32 +208,6 @@ nk_vdot_bf16c_genoa_cycle:
208
208
  result->imag = nk_reduce_add_f32x16_skylake_(sum_imag_f32x16);
209
209
  }
210
210
 
211
- NK_PUBLIC void nk_dot_e4m3_genoa(nk_e4m3_t const *a_scalars, nk_e4m3_t const *b_scalars, nk_size_t count_scalars,
212
- nk_f32_t *result) {
213
- __m256i a_e4m3x32, b_e4m3x32;
214
- __m512 sum_f32x16 = _mm512_setzero_ps();
215
-
216
- nk_dot_e4m3_genoa_cycle:
217
- if (count_scalars < 32) {
218
- __mmask32 mask = (__mmask32)_bzhi_u32(0xFFFFFFFF, count_scalars);
219
- a_e4m3x32 = _mm256_maskz_loadu_epi8(mask, a_scalars);
220
- b_e4m3x32 = _mm256_maskz_loadu_epi8(mask, b_scalars);
221
- count_scalars = 0;
222
- }
223
- else {
224
- a_e4m3x32 = _mm256_loadu_epi8(a_scalars);
225
- b_e4m3x32 = _mm256_loadu_epi8(b_scalars);
226
- a_scalars += 32, b_scalars += 32, count_scalars -= 32;
227
- }
228
- // Convert E4M3 to BF16 and compute dot product
229
- __m512i a_bf16x32 = nk_e4m3x32_to_bf16x32_icelake_(a_e4m3x32);
230
- __m512i b_bf16x32 = nk_e4m3x32_to_bf16x32_icelake_(b_e4m3x32);
231
- sum_f32x16 = _mm512_dpbf16_ps(sum_f32x16, nk_m512bh_from_m512i_(a_bf16x32), nk_m512bh_from_m512i_(b_bf16x32));
232
- if (count_scalars) goto nk_dot_e4m3_genoa_cycle;
233
-
234
- *result = nk_reduce_add_f32x16_skylake_(sum_f32x16);
235
- }
236
-
237
211
  NK_PUBLIC void nk_dot_e5m2_genoa(nk_e5m2_t const *a_scalars, nk_e5m2_t const *b_scalars, nk_size_t count_scalars,
238
212
  nk_f32_t *result) {
239
213
  __m256i a_e5m2x32, b_e5m2x32;