numkong 7.0.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (294) hide show
  1. package/LICENSE +201 -0
  2. package/README.md +495 -0
  3. package/binding.gyp +540 -0
  4. package/c/dispatch.h +512 -0
  5. package/c/dispatch_bf16.c +389 -0
  6. package/c/dispatch_bf16c.c +52 -0
  7. package/c/dispatch_e2m3.c +263 -0
  8. package/c/dispatch_e3m2.c +243 -0
  9. package/c/dispatch_e4m3.c +276 -0
  10. package/c/dispatch_e5m2.c +272 -0
  11. package/c/dispatch_f16.c +376 -0
  12. package/c/dispatch_f16c.c +58 -0
  13. package/c/dispatch_f32.c +378 -0
  14. package/c/dispatch_f32c.c +99 -0
  15. package/c/dispatch_f64.c +296 -0
  16. package/c/dispatch_f64c.c +98 -0
  17. package/c/dispatch_i16.c +96 -0
  18. package/c/dispatch_i32.c +89 -0
  19. package/c/dispatch_i4.c +150 -0
  20. package/c/dispatch_i64.c +86 -0
  21. package/c/dispatch_i8.c +289 -0
  22. package/c/dispatch_other.c +330 -0
  23. package/c/dispatch_u1.c +148 -0
  24. package/c/dispatch_u16.c +124 -0
  25. package/c/dispatch_u32.c +118 -0
  26. package/c/dispatch_u4.c +150 -0
  27. package/c/dispatch_u64.c +102 -0
  28. package/c/dispatch_u8.c +303 -0
  29. package/c/numkong.c +950 -0
  30. package/include/README.md +573 -0
  31. package/include/module.modulemap +129 -0
  32. package/include/numkong/attention/sapphireamx.h +1361 -0
  33. package/include/numkong/attention/sme.h +2066 -0
  34. package/include/numkong/attention.h +49 -0
  35. package/include/numkong/capabilities.h +748 -0
  36. package/include/numkong/cast/README.md +262 -0
  37. package/include/numkong/cast/haswell.h +975 -0
  38. package/include/numkong/cast/icelake.h +470 -0
  39. package/include/numkong/cast/neon.h +1192 -0
  40. package/include/numkong/cast/rvv.h +1021 -0
  41. package/include/numkong/cast/sapphire.h +262 -0
  42. package/include/numkong/cast/serial.h +2262 -0
  43. package/include/numkong/cast/skylake.h +856 -0
  44. package/include/numkong/cast/v128relaxed.h +180 -0
  45. package/include/numkong/cast.h +230 -0
  46. package/include/numkong/curved/README.md +223 -0
  47. package/include/numkong/curved/genoa.h +182 -0
  48. package/include/numkong/curved/haswell.h +276 -0
  49. package/include/numkong/curved/neon.h +205 -0
  50. package/include/numkong/curved/neonbfdot.h +212 -0
  51. package/include/numkong/curved/neonhalf.h +212 -0
  52. package/include/numkong/curved/rvv.h +305 -0
  53. package/include/numkong/curved/serial.h +207 -0
  54. package/include/numkong/curved/skylake.h +457 -0
  55. package/include/numkong/curved/smef64.h +506 -0
  56. package/include/numkong/curved.h +517 -0
  57. package/include/numkong/curved.hpp +144 -0
  58. package/include/numkong/dot/README.md +425 -0
  59. package/include/numkong/dot/alder.h +563 -0
  60. package/include/numkong/dot/genoa.h +315 -0
  61. package/include/numkong/dot/haswell.h +1688 -0
  62. package/include/numkong/dot/icelake.h +883 -0
  63. package/include/numkong/dot/neon.h +818 -0
  64. package/include/numkong/dot/neonbfdot.h +244 -0
  65. package/include/numkong/dot/neonfhm.h +360 -0
  66. package/include/numkong/dot/neonhalf.h +198 -0
  67. package/include/numkong/dot/neonsdot.h +508 -0
  68. package/include/numkong/dot/rvv.h +714 -0
  69. package/include/numkong/dot/rvvbb.h +72 -0
  70. package/include/numkong/dot/rvvbf16.h +123 -0
  71. package/include/numkong/dot/rvvhalf.h +129 -0
  72. package/include/numkong/dot/sapphire.h +141 -0
  73. package/include/numkong/dot/serial.h +838 -0
  74. package/include/numkong/dot/sierra.h +405 -0
  75. package/include/numkong/dot/skylake.h +1084 -0
  76. package/include/numkong/dot/sve.h +379 -0
  77. package/include/numkong/dot/svebfdot.h +74 -0
  78. package/include/numkong/dot/svehalf.h +123 -0
  79. package/include/numkong/dot/v128relaxed.h +1258 -0
  80. package/include/numkong/dot.h +1070 -0
  81. package/include/numkong/dot.hpp +94 -0
  82. package/include/numkong/dots/README.md +496 -0
  83. package/include/numkong/dots/alder.h +114 -0
  84. package/include/numkong/dots/genoa.h +94 -0
  85. package/include/numkong/dots/haswell.h +295 -0
  86. package/include/numkong/dots/icelake.h +171 -0
  87. package/include/numkong/dots/neon.h +120 -0
  88. package/include/numkong/dots/neonbfdot.h +58 -0
  89. package/include/numkong/dots/neonfhm.h +94 -0
  90. package/include/numkong/dots/neonhalf.h +57 -0
  91. package/include/numkong/dots/neonsdot.h +108 -0
  92. package/include/numkong/dots/rvv.h +2486 -0
  93. package/include/numkong/dots/sapphireamx.h +3973 -0
  94. package/include/numkong/dots/serial.h +2844 -0
  95. package/include/numkong/dots/sierra.h +97 -0
  96. package/include/numkong/dots/skylake.h +196 -0
  97. package/include/numkong/dots/sme.h +5372 -0
  98. package/include/numkong/dots/smebi32.h +461 -0
  99. package/include/numkong/dots/smef64.h +1318 -0
  100. package/include/numkong/dots/smehalf.h +47 -0
  101. package/include/numkong/dots/v128relaxed.h +294 -0
  102. package/include/numkong/dots.h +2804 -0
  103. package/include/numkong/dots.hpp +639 -0
  104. package/include/numkong/each/README.md +469 -0
  105. package/include/numkong/each/haswell.h +1658 -0
  106. package/include/numkong/each/icelake.h +272 -0
  107. package/include/numkong/each/neon.h +1104 -0
  108. package/include/numkong/each/neonbfdot.h +212 -0
  109. package/include/numkong/each/neonhalf.h +410 -0
  110. package/include/numkong/each/rvv.h +1121 -0
  111. package/include/numkong/each/sapphire.h +477 -0
  112. package/include/numkong/each/serial.h +260 -0
  113. package/include/numkong/each/skylake.h +1562 -0
  114. package/include/numkong/each.h +2146 -0
  115. package/include/numkong/each.hpp +434 -0
  116. package/include/numkong/geospatial/README.md +147 -0
  117. package/include/numkong/geospatial/haswell.h +593 -0
  118. package/include/numkong/geospatial/neon.h +571 -0
  119. package/include/numkong/geospatial/rvv.h +701 -0
  120. package/include/numkong/geospatial/serial.h +309 -0
  121. package/include/numkong/geospatial/skylake.h +577 -0
  122. package/include/numkong/geospatial/v128relaxed.h +613 -0
  123. package/include/numkong/geospatial.h +453 -0
  124. package/include/numkong/geospatial.hpp +235 -0
  125. package/include/numkong/matrix.hpp +336 -0
  126. package/include/numkong/maxsim/README.md +187 -0
  127. package/include/numkong/maxsim/alder.h +511 -0
  128. package/include/numkong/maxsim/genoa.h +115 -0
  129. package/include/numkong/maxsim/haswell.h +553 -0
  130. package/include/numkong/maxsim/icelake.h +480 -0
  131. package/include/numkong/maxsim/neonsdot.h +394 -0
  132. package/include/numkong/maxsim/sapphireamx.h +877 -0
  133. package/include/numkong/maxsim/serial.h +490 -0
  134. package/include/numkong/maxsim/sme.h +929 -0
  135. package/include/numkong/maxsim/v128relaxed.h +280 -0
  136. package/include/numkong/maxsim.h +571 -0
  137. package/include/numkong/maxsim.hpp +133 -0
  138. package/include/numkong/mesh/README.md +227 -0
  139. package/include/numkong/mesh/haswell.h +2235 -0
  140. package/include/numkong/mesh/neon.h +1329 -0
  141. package/include/numkong/mesh/neonbfdot.h +842 -0
  142. package/include/numkong/mesh/neonhalf.h +616 -0
  143. package/include/numkong/mesh/rvv.h +916 -0
  144. package/include/numkong/mesh/serial.h +742 -0
  145. package/include/numkong/mesh/skylake.h +1135 -0
  146. package/include/numkong/mesh/v128relaxed.h +1052 -0
  147. package/include/numkong/mesh.h +652 -0
  148. package/include/numkong/mesh.hpp +762 -0
  149. package/include/numkong/numkong.h +78 -0
  150. package/include/numkong/numkong.hpp +57 -0
  151. package/include/numkong/probability/README.md +173 -0
  152. package/include/numkong/probability/haswell.h +267 -0
  153. package/include/numkong/probability/neon.h +225 -0
  154. package/include/numkong/probability/rvv.h +409 -0
  155. package/include/numkong/probability/serial.h +169 -0
  156. package/include/numkong/probability/skylake.h +324 -0
  157. package/include/numkong/probability.h +383 -0
  158. package/include/numkong/probability.hpp +120 -0
  159. package/include/numkong/random.h +50 -0
  160. package/include/numkong/random.hpp +285 -0
  161. package/include/numkong/reduce/README.md +547 -0
  162. package/include/numkong/reduce/alder.h +632 -0
  163. package/include/numkong/reduce/genoa.h +201 -0
  164. package/include/numkong/reduce/haswell.h +3783 -0
  165. package/include/numkong/reduce/icelake.h +549 -0
  166. package/include/numkong/reduce/neon.h +3841 -0
  167. package/include/numkong/reduce/neonbfdot.h +353 -0
  168. package/include/numkong/reduce/neonfhm.h +665 -0
  169. package/include/numkong/reduce/neonhalf.h +157 -0
  170. package/include/numkong/reduce/neonsdot.h +357 -0
  171. package/include/numkong/reduce/rvv.h +3407 -0
  172. package/include/numkong/reduce/serial.h +757 -0
  173. package/include/numkong/reduce/sierra.h +338 -0
  174. package/include/numkong/reduce/skylake.h +3792 -0
  175. package/include/numkong/reduce/v128relaxed.h +2302 -0
  176. package/include/numkong/reduce.h +1597 -0
  177. package/include/numkong/reduce.hpp +633 -0
  178. package/include/numkong/scalar/README.md +89 -0
  179. package/include/numkong/scalar/haswell.h +113 -0
  180. package/include/numkong/scalar/neon.h +122 -0
  181. package/include/numkong/scalar/neonhalf.h +70 -0
  182. package/include/numkong/scalar/rvv.h +211 -0
  183. package/include/numkong/scalar/sapphire.h +63 -0
  184. package/include/numkong/scalar/serial.h +332 -0
  185. package/include/numkong/scalar/v128relaxed.h +56 -0
  186. package/include/numkong/scalar.h +683 -0
  187. package/include/numkong/set/README.md +179 -0
  188. package/include/numkong/set/haswell.h +334 -0
  189. package/include/numkong/set/icelake.h +485 -0
  190. package/include/numkong/set/neon.h +364 -0
  191. package/include/numkong/set/rvv.h +226 -0
  192. package/include/numkong/set/rvvbb.h +117 -0
  193. package/include/numkong/set/serial.h +174 -0
  194. package/include/numkong/set/sve.h +185 -0
  195. package/include/numkong/set/v128relaxed.h +240 -0
  196. package/include/numkong/set.h +457 -0
  197. package/include/numkong/set.hpp +114 -0
  198. package/include/numkong/sets/README.md +149 -0
  199. package/include/numkong/sets/haswell.h +63 -0
  200. package/include/numkong/sets/icelake.h +66 -0
  201. package/include/numkong/sets/neon.h +61 -0
  202. package/include/numkong/sets/serial.h +43 -0
  203. package/include/numkong/sets/smebi32.h +1099 -0
  204. package/include/numkong/sets/v128relaxed.h +58 -0
  205. package/include/numkong/sets.h +339 -0
  206. package/include/numkong/sparse/README.md +156 -0
  207. package/include/numkong/sparse/icelake.h +463 -0
  208. package/include/numkong/sparse/neon.h +288 -0
  209. package/include/numkong/sparse/serial.h +117 -0
  210. package/include/numkong/sparse/sve2.h +507 -0
  211. package/include/numkong/sparse/turin.h +322 -0
  212. package/include/numkong/sparse.h +363 -0
  213. package/include/numkong/sparse.hpp +113 -0
  214. package/include/numkong/spatial/README.md +435 -0
  215. package/include/numkong/spatial/alder.h +607 -0
  216. package/include/numkong/spatial/genoa.h +290 -0
  217. package/include/numkong/spatial/haswell.h +960 -0
  218. package/include/numkong/spatial/icelake.h +586 -0
  219. package/include/numkong/spatial/neon.h +773 -0
  220. package/include/numkong/spatial/neonbfdot.h +165 -0
  221. package/include/numkong/spatial/neonhalf.h +118 -0
  222. package/include/numkong/spatial/neonsdot.h +261 -0
  223. package/include/numkong/spatial/rvv.h +984 -0
  224. package/include/numkong/spatial/rvvbf16.h +123 -0
  225. package/include/numkong/spatial/rvvhalf.h +117 -0
  226. package/include/numkong/spatial/sapphire.h +343 -0
  227. package/include/numkong/spatial/serial.h +346 -0
  228. package/include/numkong/spatial/sierra.h +323 -0
  229. package/include/numkong/spatial/skylake.h +606 -0
  230. package/include/numkong/spatial/sve.h +224 -0
  231. package/include/numkong/spatial/svebfdot.h +122 -0
  232. package/include/numkong/spatial/svehalf.h +109 -0
  233. package/include/numkong/spatial/v128relaxed.h +717 -0
  234. package/include/numkong/spatial.h +1425 -0
  235. package/include/numkong/spatial.hpp +183 -0
  236. package/include/numkong/spatials/README.md +580 -0
  237. package/include/numkong/spatials/alder.h +94 -0
  238. package/include/numkong/spatials/genoa.h +94 -0
  239. package/include/numkong/spatials/haswell.h +219 -0
  240. package/include/numkong/spatials/icelake.h +113 -0
  241. package/include/numkong/spatials/neon.h +109 -0
  242. package/include/numkong/spatials/neonbfdot.h +60 -0
  243. package/include/numkong/spatials/neonfhm.h +92 -0
  244. package/include/numkong/spatials/neonhalf.h +58 -0
  245. package/include/numkong/spatials/neonsdot.h +109 -0
  246. package/include/numkong/spatials/rvv.h +1960 -0
  247. package/include/numkong/spatials/sapphireamx.h +1149 -0
  248. package/include/numkong/spatials/serial.h +226 -0
  249. package/include/numkong/spatials/sierra.h +96 -0
  250. package/include/numkong/spatials/skylake.h +184 -0
  251. package/include/numkong/spatials/sme.h +1901 -0
  252. package/include/numkong/spatials/smef64.h +465 -0
  253. package/include/numkong/spatials/v128relaxed.h +240 -0
  254. package/include/numkong/spatials.h +3021 -0
  255. package/include/numkong/spatials.hpp +508 -0
  256. package/include/numkong/tensor.hpp +1592 -0
  257. package/include/numkong/trigonometry/README.md +184 -0
  258. package/include/numkong/trigonometry/haswell.h +652 -0
  259. package/include/numkong/trigonometry/neon.h +639 -0
  260. package/include/numkong/trigonometry/rvv.h +699 -0
  261. package/include/numkong/trigonometry/serial.h +703 -0
  262. package/include/numkong/trigonometry/skylake.h +721 -0
  263. package/include/numkong/trigonometry/v128relaxed.h +666 -0
  264. package/include/numkong/trigonometry.h +467 -0
  265. package/include/numkong/trigonometry.hpp +166 -0
  266. package/include/numkong/types.h +1384 -0
  267. package/include/numkong/types.hpp +5603 -0
  268. package/include/numkong/vector.hpp +698 -0
  269. package/javascript/README.md +246 -0
  270. package/javascript/dist/cjs/numkong-wasm.d.ts +166 -0
  271. package/javascript/dist/cjs/numkong-wasm.js +617 -0
  272. package/javascript/dist/cjs/numkong.d.ts +343 -0
  273. package/javascript/dist/cjs/numkong.js +523 -0
  274. package/javascript/dist/cjs/package.json +3 -0
  275. package/javascript/dist/cjs/types.d.ts +284 -0
  276. package/javascript/dist/cjs/types.js +653 -0
  277. package/javascript/dist/esm/numkong-wasm.d.ts +166 -0
  278. package/javascript/dist/esm/numkong-wasm.js +595 -0
  279. package/javascript/dist/esm/numkong.d.ts +343 -0
  280. package/javascript/dist/esm/numkong.js +452 -0
  281. package/javascript/dist/esm/package.json +3 -0
  282. package/javascript/dist/esm/types.d.ts +284 -0
  283. package/javascript/dist/esm/types.js +630 -0
  284. package/javascript/dist-package-cjs.json +3 -0
  285. package/javascript/dist-package-esm.json +3 -0
  286. package/javascript/node-gyp-build.d.ts +1 -0
  287. package/javascript/numkong-wasm.ts +756 -0
  288. package/javascript/numkong.c +689 -0
  289. package/javascript/numkong.ts +575 -0
  290. package/javascript/tsconfig-base.json +39 -0
  291. package/javascript/tsconfig-cjs.json +8 -0
  292. package/javascript/tsconfig-esm.json +8 -0
  293. package/javascript/types.ts +674 -0
  294. package/package.json +87 -0
@@ -0,0 +1,315 @@
1
+ /**
2
+ * @brief SIMD-accelerated Dot Products for Genoa.
3
+ * @file include/numkong/dot/genoa.h
4
+ * @author Ash Vardanian
5
+ * @date December 27, 2025
6
+ *
7
+ * @sa include/numkong/dot.h
8
+ *
9
+ * @section dot_genoa_instructions Key AVX-512 BF16 Instructions
10
+ *
11
+ * Intrinsic Instruction Latency Throughput Ports
12
+ * _mm512_dpbf16_ps VDPBF16PS (ZMM, ZMM, ZMM) 4cy 0.5/cy p01
13
+ * _mm512_fmadd_ps VFMADD132PS (ZMM, ZMM, ZMM) 4cy 0.5/cy p01
14
+ * _mm512_add_ps VADDPS (ZMM, ZMM, ZMM) 4cy 0.5/cy p01
15
+ *
16
+ * AMD Genoa introduces native AVX-512 BF16 support with VDPBF16PS, which computes a two-element BF16 dot product
17
+ * per 32-bit lane (32 BF16 multiplies accumulated into 16 FP32 values per instruction). This provides
18
+ * twice the throughput of FP32 FMA for BF16 workloads, ideal for machine learning inference.
19
+ *
20
+ * @section dot_genoa_stateful Stateful Streaming Logic
21
+ *
22
+ * To build memory-optimal tiled algorithms, this file defines the following structures and force-inlined
23
+ * `NK_INTERNAL` functions:
24
+ *
25
+ * - nk_dot_bf16x32 state with native BF16 dot-products using VDPBF16PS,
26
+ * - nk_dot_through_bf16 state for FP8 inputs (e4m3, e5m2) converted to BF16.
27
+ *
28
+ * @code{c}
29
+ * nk_dot_bf16x32_state_genoa_t state_first, state_second, state_third, state_fourth;
30
+ * nk_b512_vec_t query_bf16x32, target_first_bf16x32, target_second_bf16x32, target_third_bf16x32, target_fourth;
31
+ * nk_dot_bf16x32_init_genoa(&state_first);
32
+ * nk_dot_bf16x32_init_genoa(&state_second);
33
+ * nk_dot_bf16x32_init_genoa(&state_third);
34
+ * nk_dot_bf16x32_init_genoa(&state_fourth);
35
+ * for (nk_size_t idx = 0; idx + 32 <= depth; idx += 32) {
36
+ * query_bf16x32.zmm = _mm512_loadu_si512(query_ptr + idx);
37
+ * target_first_bf16x32.zmm = _mm512_loadu_si512(target_first_ptr + idx);
38
+ * target_second_bf16x32.zmm = _mm512_loadu_si512(target_second_ptr + idx);
39
+ * target_third_bf16x32.zmm = _mm512_loadu_si512(target_third_ptr + idx);
40
+ * target_fourth.zmm = _mm512_loadu_si512(target_fourth_ptr + idx);
41
+ * nk_dot_bf16x32_update_genoa(&state_first, query_bf16x32, target_first_bf16x32, idx, 32);
42
+ * nk_dot_bf16x32_update_genoa(&state_second, query_bf16x32, target_second_bf16x32, idx, 32);
43
+ * nk_dot_bf16x32_update_genoa(&state_third, query_bf16x32, target_third_bf16x32, idx, 32);
44
+ * nk_dot_bf16x32_update_genoa(&state_fourth, query_bf16x32, target_fourth, idx, 32);
45
+ * }
46
+ * nk_b128_vec_t results_f32x4;
47
+ * nk_dot_bf16x32_finalize_genoa(&state_first, &state_second, &state_third, &state_fourth, depth, &results_f32x4);
48
+ * @endcode
49
+ *
50
+ * FP8 types (e4m3, e5m2) are upcast to BF16 using Ice Lake conversion functions, then
51
+ * accumulated using the native BF16 dot-product circuitry:
52
+ *
53
+ * @code{c}
54
+ * nk_dot_through_bf16_state_genoa_t_ state_first, state_second, state_third, state_fourth;
55
+ * nk_b512_vec_t query_bf16x32, target_first_bf16x32, target_second_bf16x32, target_third_bf16x32, target_fourth;
56
+ * nk_dot_through_bf16_init_genoa_(&state_first);
57
+ * nk_dot_through_bf16_init_genoa_(&state_second);
58
+ * nk_dot_through_bf16_init_genoa_(&state_third);
59
+ * nk_dot_through_bf16_init_genoa_(&state_fourth);
60
+ * for (nk_size_t idx = 0; idx + 32 <= depth; idx += 32) {
61
+ * nk_load_e4m3x32_to_bf16x32_icelake_(query_ptr + idx, &query_bf16x32);
62
+ * nk_load_e4m3x32_to_bf16x32_icelake_(target_first_ptr + idx, &target_first_bf16x32);
63
+ * nk_load_e4m3x32_to_bf16x32_icelake_(target_second_ptr + idx, &target_second_bf16x32);
64
+ * nk_load_e4m3x32_to_bf16x32_icelake_(target_third_ptr + idx, &target_third_bf16x32);
65
+ * nk_load_e4m3x32_to_bf16x32_icelake_(target_fourth_ptr + idx, &target_fourth);
66
+ * nk_dot_through_bf16_update_genoa_(&state_first, query_bf16x32, target_first_bf16x32, idx, 32);
67
+ * nk_dot_through_bf16_update_genoa_(&state_second, query_bf16x32, target_second_bf16x32, idx, 32);
68
+ * nk_dot_through_bf16_update_genoa_(&state_third, query_bf16x32, target_third_bf16x32, idx, 32);
69
+ * nk_dot_through_bf16_update_genoa_(&state_fourth, query_bf16x32, target_fourth, idx, 32);
70
+ * }
71
+ * nk_b128_vec_t results_f32x4;
72
+ * nk_dot_through_bf16_finalize_genoa_(&state_first, &state_second, &state_third, &state_fourth,
73
+ * depth, &results_f32x4);
74
+ * @endcode
75
+ */
76
+ #ifndef NK_DOT_GENOA_H
77
+ #define NK_DOT_GENOA_H
78
+
79
+ #if NK_TARGET_X86_
80
+ #if NK_TARGET_GENOA
81
+
82
+ #include "numkong/types.h"
83
+ #include "numkong/cast/icelake.h" // `nk_e4m3x32_to_bf16x32_icelake_`
84
+ #include "numkong/reduce/skylake.h" // `nk_reduce_add_f32x16_skylake_`
85
+ #include "numkong/dot/skylake.h" // `nk_dot_through_f32_finalize_skylake_`
86
+
87
+ #if defined(__cplusplus)
88
+ extern "C" {
89
+ #endif
90
+
91
+ #if defined(__clang__)
92
+ #pragma clang attribute push( \
93
+ __attribute__((target("avx2,avx512f,avx512vl,avx512bw,avx512dq,avx512bf16,f16c,fma,bmi,bmi2"))), \
94
+ apply_to = function)
95
+ #elif defined(__GNUC__)
96
+ #pragma GCC push_options
97
+ #pragma GCC target("avx2", "avx512f", "avx512vl", "avx512bw", "avx512dq", "avx512bf16", "f16c", "fma", "bmi", "bmi2")
98
+ #endif
99
+
100
+ NK_PUBLIC void nk_dot_bf16_genoa(nk_bf16_t const *a_scalars, nk_bf16_t const *b_scalars, nk_size_t count_scalars,
101
+ nk_f32_t *result) {
102
+ __m512i a_bf16x32, b_bf16x32;
103
+ __m512 sum_f32x16 = _mm512_setzero_ps();
104
+
105
+ nk_dot_bf16_genoa_cycle:
106
+ if (count_scalars < 32) {
107
+ __mmask32 mask = (__mmask32)_bzhi_u32(0xFFFFFFFF, count_scalars);
108
+ a_bf16x32 = _mm512_maskz_loadu_epi16(mask, a_scalars);
109
+ b_bf16x32 = _mm512_maskz_loadu_epi16(mask, b_scalars);
110
+ count_scalars = 0;
111
+ }
112
+ else {
113
+ a_bf16x32 = _mm512_loadu_epi16(a_scalars);
114
+ b_bf16x32 = _mm512_loadu_epi16(b_scalars);
115
+ a_scalars += 32, b_scalars += 32, count_scalars -= 32;
116
+ }
117
+ sum_f32x16 = _mm512_dpbf16_ps(sum_f32x16, nk_m512bh_from_m512i_(a_bf16x32), nk_m512bh_from_m512i_(b_bf16x32));
118
+ if (count_scalars) goto nk_dot_bf16_genoa_cycle;
119
+
120
+ *result = nk_reduce_add_f32x16_skylake_(sum_f32x16);
121
+ }
122
+
123
+ NK_PUBLIC void nk_dot_bf16c_genoa(nk_bf16c_t const *a_pairs, nk_bf16c_t const *b_pairs, nk_size_t count_pairs,
124
+ nk_f32c_t *result) {
125
+ __m512i a_bf16x32, b_bf16x32;
126
+ __m512 sum_real_f32x16 = _mm512_setzero_ps();
127
+ __m512 sum_imag_f32x16 = _mm512_setzero_ps();
128
+
129
+ // We take into account, that FMS is the same as FMA with a negative multiplier.
130
+ // To multiply a floating-point value by -1, we can use the `XOR` instruction to flip the sign bit.
131
+ // This way we can avoid the shuffling and the need for separate real and imaginary parts.
132
+ // For the imaginary part of the product, we would need to swap the real and imaginary parts of
133
+ // one of the vectors.
134
+ __m512i const sign_flip_bf16x32 = _mm512_set1_epi32(0x80000000);
135
+ __m512i const swap_adjacent_bf16x32 = _mm512_set_epi8( //
136
+ 61, 60, 63, 62, 57, 56, 59, 58, 53, 52, 55, 54, 49, 48, 51, 50, // 4th 128-bit lane
137
+ 45, 44, 47, 46, 41, 40, 43, 42, 37, 36, 39, 38, 33, 32, 35, 34, // 3rd 128-bit lane
138
+ 29, 28, 31, 30, 25, 24, 27, 26, 21, 20, 23, 22, 17, 16, 19, 18, // 2nd 128-bit lane
139
+ 13, 12, 15, 14, 9, 8, 11, 10, 5, 4, 7, 6, 1, 0, 3, 2 // 1st 128-bit lane
140
+ );
141
+
142
+ nk_dot_bf16c_genoa_cycle:
143
+ if (count_pairs < 16) {
144
+ __mmask32 mask = (__mmask32)_bzhi_u32(0xFFFFFFFF, count_pairs * 2);
145
+ a_bf16x32 = _mm512_maskz_loadu_epi16(mask, (nk_i16_t const *)a_pairs);
146
+ b_bf16x32 = _mm512_maskz_loadu_epi16(mask, (nk_i16_t const *)b_pairs);
147
+ count_pairs = 0;
148
+ }
149
+ else {
150
+ a_bf16x32 = _mm512_loadu_epi16((nk_i16_t const *)a_pairs);
151
+ b_bf16x32 = _mm512_loadu_epi16((nk_i16_t const *)b_pairs);
152
+ a_pairs += 16, b_pairs += 16, count_pairs -= 16;
153
+ }
154
+ sum_real_f32x16 = _mm512_dpbf16_ps(sum_real_f32x16,
155
+ nk_m512bh_from_m512i_(_mm512_xor_si512(b_bf16x32, sign_flip_bf16x32)),
156
+ nk_m512bh_from_m512i_(a_bf16x32));
157
+ sum_imag_f32x16 = _mm512_dpbf16_ps(sum_imag_f32x16,
158
+ nk_m512bh_from_m512i_(_mm512_shuffle_epi8(b_bf16x32, swap_adjacent_bf16x32)),
159
+ nk_m512bh_from_m512i_(a_bf16x32));
160
+ if (count_pairs) goto nk_dot_bf16c_genoa_cycle;
161
+
162
+ // Reduce horizontal sums:
163
+ result->real = nk_reduce_add_f32x16_skylake_(sum_real_f32x16);
164
+ result->imag = nk_reduce_add_f32x16_skylake_(sum_imag_f32x16);
165
+ }
166
+
167
+ NK_PUBLIC void nk_vdot_bf16c_genoa(nk_bf16c_t const *a_pairs, nk_bf16c_t const *b_pairs, nk_size_t count_pairs,
168
+ nk_f32c_t *result) {
169
+ __m512i a_bf16x32, b_bf16x32;
170
+ __m512 sum_real_f32x16 = _mm512_setzero_ps();
171
+ __m512 sum_imag_f32x16 = _mm512_setzero_ps();
172
+
173
+ // We take into account, that FMS is the same as FMA with a negative multiplier.
174
+ // To multiply a floating-point value by -1, we can use the `XOR` instruction to flip the sign bit.
175
+ // This way we can avoid the shuffling and the need for separate real and imaginary parts.
176
+ // For the imaginary part of the product, we would need to swap the real and imaginary parts of
177
+ // one of the vectors.
178
+ __m512i const sign_flip_bf16x32 = _mm512_set1_epi32(0x80000000);
179
+ __m512i const swap_adjacent_bf16x32 = _mm512_set_epi8( //
180
+ 61, 60, 63, 62, 57, 56, 59, 58, 53, 52, 55, 54, 49, 48, 51, 50, // 4th 128-bit lane
181
+ 45, 44, 47, 46, 41, 40, 43, 42, 37, 36, 39, 38, 33, 32, 35, 34, // 3rd 128-bit lane
182
+ 29, 28, 31, 30, 25, 24, 27, 26, 21, 20, 23, 22, 17, 16, 19, 18, // 2nd 128-bit lane
183
+ 13, 12, 15, 14, 9, 8, 11, 10, 5, 4, 7, 6, 1, 0, 3, 2 // 1st 128-bit lane
184
+ );
185
+
186
+ nk_vdot_bf16c_genoa_cycle:
187
+ if (count_pairs < 16) {
188
+ __mmask32 mask = (__mmask32)_bzhi_u32(0xFFFFFFFF, count_pairs * 2);
189
+ a_bf16x32 = _mm512_maskz_loadu_epi16(mask, (nk_i16_t const *)a_pairs);
190
+ b_bf16x32 = _mm512_maskz_loadu_epi16(mask, (nk_i16_t const *)b_pairs);
191
+ count_pairs = 0;
192
+ }
193
+ else {
194
+ a_bf16x32 = _mm512_loadu_epi16((nk_i16_t const *)a_pairs);
195
+ b_bf16x32 = _mm512_loadu_epi16((nk_i16_t const *)b_pairs);
196
+ a_pairs += 16, b_pairs += 16, count_pairs -= 16;
197
+ }
198
+ sum_real_f32x16 = _mm512_dpbf16_ps(sum_real_f32x16, nk_m512bh_from_m512i_(a_bf16x32),
199
+ nk_m512bh_from_m512i_(b_bf16x32));
200
+ a_bf16x32 = _mm512_xor_si512(a_bf16x32, sign_flip_bf16x32);
201
+ b_bf16x32 = _mm512_shuffle_epi8(b_bf16x32, swap_adjacent_bf16x32);
202
+ sum_imag_f32x16 = _mm512_dpbf16_ps(sum_imag_f32x16, nk_m512bh_from_m512i_(a_bf16x32),
203
+ nk_m512bh_from_m512i_(b_bf16x32));
204
+ if (count_pairs) goto nk_vdot_bf16c_genoa_cycle;
205
+
206
+ // Reduce horizontal sums:
207
+ result->real = nk_reduce_add_f32x16_skylake_(sum_real_f32x16);
208
+ result->imag = nk_reduce_add_f32x16_skylake_(sum_imag_f32x16);
209
+ }
210
+
211
+ NK_PUBLIC void nk_dot_e4m3_genoa(nk_e4m3_t const *a_scalars, nk_e4m3_t const *b_scalars, nk_size_t count_scalars,
212
+ nk_f32_t *result) {
213
+ __m256i a_e4m3x32, b_e4m3x32;
214
+ __m512 sum_f32x16 = _mm512_setzero_ps();
215
+
216
+ nk_dot_e4m3_genoa_cycle:
217
+ if (count_scalars < 32) {
218
+ __mmask32 mask = (__mmask32)_bzhi_u32(0xFFFFFFFF, count_scalars);
219
+ a_e4m3x32 = _mm256_maskz_loadu_epi8(mask, a_scalars);
220
+ b_e4m3x32 = _mm256_maskz_loadu_epi8(mask, b_scalars);
221
+ count_scalars = 0;
222
+ }
223
+ else {
224
+ a_e4m3x32 = _mm256_loadu_epi8(a_scalars);
225
+ b_e4m3x32 = _mm256_loadu_epi8(b_scalars);
226
+ a_scalars += 32, b_scalars += 32, count_scalars -= 32;
227
+ }
228
+ // Convert E4M3 to BF16 and compute dot product
229
+ __m512i a_bf16x32 = nk_e4m3x32_to_bf16x32_icelake_(a_e4m3x32);
230
+ __m512i b_bf16x32 = nk_e4m3x32_to_bf16x32_icelake_(b_e4m3x32);
231
+ sum_f32x16 = _mm512_dpbf16_ps(sum_f32x16, nk_m512bh_from_m512i_(a_bf16x32), nk_m512bh_from_m512i_(b_bf16x32));
232
+ if (count_scalars) goto nk_dot_e4m3_genoa_cycle;
233
+
234
+ *result = nk_reduce_add_f32x16_skylake_(sum_f32x16);
235
+ }
236
+
237
+ NK_PUBLIC void nk_dot_e5m2_genoa(nk_e5m2_t const *a_scalars, nk_e5m2_t const *b_scalars, nk_size_t count_scalars,
238
+ nk_f32_t *result) {
239
+ __m256i a_e5m2x32, b_e5m2x32;
240
+ __m512 sum_f32x16 = _mm512_setzero_ps();
241
+
242
+ nk_dot_e5m2_genoa_cycle:
243
+ if (count_scalars < 32) {
244
+ __mmask32 mask = (__mmask32)_bzhi_u32(0xFFFFFFFF, count_scalars);
245
+ a_e5m2x32 = _mm256_maskz_loadu_epi8(mask, a_scalars);
246
+ b_e5m2x32 = _mm256_maskz_loadu_epi8(mask, b_scalars);
247
+ count_scalars = 0;
248
+ }
249
+ else {
250
+ a_e5m2x32 = _mm256_loadu_epi8(a_scalars);
251
+ b_e5m2x32 = _mm256_loadu_epi8(b_scalars);
252
+ a_scalars += 32, b_scalars += 32, count_scalars -= 32;
253
+ }
254
+ // Convert E5M2 to BF16 and compute dot product
255
+ __m512i a_bf16x32 = nk_e5m2x32_to_bf16x32_icelake_(a_e5m2x32);
256
+ __m512i b_bf16x32 = nk_e5m2x32_to_bf16x32_icelake_(b_e5m2x32);
257
+ sum_f32x16 = _mm512_dpbf16_ps(sum_f32x16, nk_m512bh_from_m512i_(a_bf16x32), nk_m512bh_from_m512i_(b_bf16x32));
258
+ if (count_scalars) goto nk_dot_e5m2_genoa_cycle;
259
+
260
+ *result = nk_reduce_add_f32x16_skylake_(sum_f32x16);
261
+ }
262
+
263
+ typedef nk_dot_through_f32_state_skylake_t_ nk_dot_through_bf16_state_genoa_t_;
264
+
265
+ NK_INTERNAL void nk_dot_through_bf16_init_genoa_(nk_dot_through_bf16_state_genoa_t_ *state) {
266
+ state->sum_f32x16 = _mm512_setzero();
267
+ }
268
+
269
+ NK_INTERNAL void nk_dot_through_bf16_update_genoa_(nk_dot_through_bf16_state_genoa_t_ *state, nk_b512_vec_t a,
270
+ nk_b512_vec_t b, nk_size_t depth_offset,
271
+ nk_size_t active_dimensions) {
272
+ nk_unused_(depth_offset);
273
+ nk_unused_(active_dimensions);
274
+ state->sum_f32x16 = _mm512_dpbf16_ps(state->sum_f32x16, nk_m512bh_from_m512i_(a.zmm), nk_m512bh_from_m512i_(b.zmm));
275
+ }
276
+
277
+ NK_INTERNAL void nk_dot_through_bf16_finalize_genoa_( //
278
+ nk_dot_through_bf16_state_genoa_t_ const *state_a, nk_dot_through_bf16_state_genoa_t_ const *state_b, //
279
+ nk_dot_through_bf16_state_genoa_t_ const *state_c, nk_dot_through_bf16_state_genoa_t_ const *state_d, //
280
+ nk_size_t total_dimensions, nk_b128_vec_t *result) {
281
+ nk_dot_through_f32_finalize_skylake_(state_a, state_b, state_c, state_d, total_dimensions, result);
282
+ }
283
+
284
+ typedef nk_dot_through_bf16_state_genoa_t_ nk_dot_bf16x32_state_genoa_t;
285
+
286
+ NK_INTERNAL void nk_dot_bf16x32_init_genoa(nk_dot_bf16x32_state_genoa_t *state) {
287
+ nk_dot_through_bf16_init_genoa_(state);
288
+ }
289
+
290
+ NK_INTERNAL void nk_dot_bf16x32_update_genoa(nk_dot_bf16x32_state_genoa_t *state, nk_b512_vec_t a, nk_b512_vec_t b,
291
+ nk_size_t depth_offset, nk_size_t active_dimensions) {
292
+ nk_dot_through_bf16_update_genoa_(state, a, b, depth_offset, active_dimensions);
293
+ }
294
+
295
+ NK_INTERNAL void nk_dot_bf16x32_finalize_genoa(nk_dot_bf16x32_state_genoa_t const *state_a,
296
+ nk_dot_bf16x32_state_genoa_t const *state_b,
297
+ nk_dot_bf16x32_state_genoa_t const *state_c,
298
+ nk_dot_bf16x32_state_genoa_t const *state_d, nk_size_t total_dimensions,
299
+ nk_b128_vec_t *result) {
300
+ nk_dot_through_bf16_finalize_genoa_(state_a, state_b, state_c, state_d, total_dimensions, result);
301
+ }
302
+
303
+ #if defined(__clang__)
304
+ #pragma clang attribute pop
305
+ #elif defined(__GNUC__)
306
+ #pragma GCC pop_options
307
+ #endif
308
+
309
+ #if defined(__cplusplus)
310
+ } // extern "C"
311
+ #endif
312
+
313
+ #endif // NK_TARGET_GENOA
314
+ #endif // NK_TARGET_X86_
315
+ #endif // NK_DOT_GENOA_H