numkong 7.0.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (294) hide show
  1. package/LICENSE +201 -0
  2. package/README.md +495 -0
  3. package/binding.gyp +540 -0
  4. package/c/dispatch.h +512 -0
  5. package/c/dispatch_bf16.c +389 -0
  6. package/c/dispatch_bf16c.c +52 -0
  7. package/c/dispatch_e2m3.c +263 -0
  8. package/c/dispatch_e3m2.c +243 -0
  9. package/c/dispatch_e4m3.c +276 -0
  10. package/c/dispatch_e5m2.c +272 -0
  11. package/c/dispatch_f16.c +376 -0
  12. package/c/dispatch_f16c.c +58 -0
  13. package/c/dispatch_f32.c +378 -0
  14. package/c/dispatch_f32c.c +99 -0
  15. package/c/dispatch_f64.c +296 -0
  16. package/c/dispatch_f64c.c +98 -0
  17. package/c/dispatch_i16.c +96 -0
  18. package/c/dispatch_i32.c +89 -0
  19. package/c/dispatch_i4.c +150 -0
  20. package/c/dispatch_i64.c +86 -0
  21. package/c/dispatch_i8.c +289 -0
  22. package/c/dispatch_other.c +330 -0
  23. package/c/dispatch_u1.c +148 -0
  24. package/c/dispatch_u16.c +124 -0
  25. package/c/dispatch_u32.c +118 -0
  26. package/c/dispatch_u4.c +150 -0
  27. package/c/dispatch_u64.c +102 -0
  28. package/c/dispatch_u8.c +303 -0
  29. package/c/numkong.c +950 -0
  30. package/include/README.md +573 -0
  31. package/include/module.modulemap +129 -0
  32. package/include/numkong/attention/sapphireamx.h +1361 -0
  33. package/include/numkong/attention/sme.h +2066 -0
  34. package/include/numkong/attention.h +49 -0
  35. package/include/numkong/capabilities.h +748 -0
  36. package/include/numkong/cast/README.md +262 -0
  37. package/include/numkong/cast/haswell.h +975 -0
  38. package/include/numkong/cast/icelake.h +470 -0
  39. package/include/numkong/cast/neon.h +1192 -0
  40. package/include/numkong/cast/rvv.h +1021 -0
  41. package/include/numkong/cast/sapphire.h +262 -0
  42. package/include/numkong/cast/serial.h +2262 -0
  43. package/include/numkong/cast/skylake.h +856 -0
  44. package/include/numkong/cast/v128relaxed.h +180 -0
  45. package/include/numkong/cast.h +230 -0
  46. package/include/numkong/curved/README.md +223 -0
  47. package/include/numkong/curved/genoa.h +182 -0
  48. package/include/numkong/curved/haswell.h +276 -0
  49. package/include/numkong/curved/neon.h +205 -0
  50. package/include/numkong/curved/neonbfdot.h +212 -0
  51. package/include/numkong/curved/neonhalf.h +212 -0
  52. package/include/numkong/curved/rvv.h +305 -0
  53. package/include/numkong/curved/serial.h +207 -0
  54. package/include/numkong/curved/skylake.h +457 -0
  55. package/include/numkong/curved/smef64.h +506 -0
  56. package/include/numkong/curved.h +517 -0
  57. package/include/numkong/curved.hpp +144 -0
  58. package/include/numkong/dot/README.md +425 -0
  59. package/include/numkong/dot/alder.h +563 -0
  60. package/include/numkong/dot/genoa.h +315 -0
  61. package/include/numkong/dot/haswell.h +1688 -0
  62. package/include/numkong/dot/icelake.h +883 -0
  63. package/include/numkong/dot/neon.h +818 -0
  64. package/include/numkong/dot/neonbfdot.h +244 -0
  65. package/include/numkong/dot/neonfhm.h +360 -0
  66. package/include/numkong/dot/neonhalf.h +198 -0
  67. package/include/numkong/dot/neonsdot.h +508 -0
  68. package/include/numkong/dot/rvv.h +714 -0
  69. package/include/numkong/dot/rvvbb.h +72 -0
  70. package/include/numkong/dot/rvvbf16.h +123 -0
  71. package/include/numkong/dot/rvvhalf.h +129 -0
  72. package/include/numkong/dot/sapphire.h +141 -0
  73. package/include/numkong/dot/serial.h +838 -0
  74. package/include/numkong/dot/sierra.h +405 -0
  75. package/include/numkong/dot/skylake.h +1084 -0
  76. package/include/numkong/dot/sve.h +379 -0
  77. package/include/numkong/dot/svebfdot.h +74 -0
  78. package/include/numkong/dot/svehalf.h +123 -0
  79. package/include/numkong/dot/v128relaxed.h +1258 -0
  80. package/include/numkong/dot.h +1070 -0
  81. package/include/numkong/dot.hpp +94 -0
  82. package/include/numkong/dots/README.md +496 -0
  83. package/include/numkong/dots/alder.h +114 -0
  84. package/include/numkong/dots/genoa.h +94 -0
  85. package/include/numkong/dots/haswell.h +295 -0
  86. package/include/numkong/dots/icelake.h +171 -0
  87. package/include/numkong/dots/neon.h +120 -0
  88. package/include/numkong/dots/neonbfdot.h +58 -0
  89. package/include/numkong/dots/neonfhm.h +94 -0
  90. package/include/numkong/dots/neonhalf.h +57 -0
  91. package/include/numkong/dots/neonsdot.h +108 -0
  92. package/include/numkong/dots/rvv.h +2486 -0
  93. package/include/numkong/dots/sapphireamx.h +3973 -0
  94. package/include/numkong/dots/serial.h +2844 -0
  95. package/include/numkong/dots/sierra.h +97 -0
  96. package/include/numkong/dots/skylake.h +196 -0
  97. package/include/numkong/dots/sme.h +5372 -0
  98. package/include/numkong/dots/smebi32.h +461 -0
  99. package/include/numkong/dots/smef64.h +1318 -0
  100. package/include/numkong/dots/smehalf.h +47 -0
  101. package/include/numkong/dots/v128relaxed.h +294 -0
  102. package/include/numkong/dots.h +2804 -0
  103. package/include/numkong/dots.hpp +639 -0
  104. package/include/numkong/each/README.md +469 -0
  105. package/include/numkong/each/haswell.h +1658 -0
  106. package/include/numkong/each/icelake.h +272 -0
  107. package/include/numkong/each/neon.h +1104 -0
  108. package/include/numkong/each/neonbfdot.h +212 -0
  109. package/include/numkong/each/neonhalf.h +410 -0
  110. package/include/numkong/each/rvv.h +1121 -0
  111. package/include/numkong/each/sapphire.h +477 -0
  112. package/include/numkong/each/serial.h +260 -0
  113. package/include/numkong/each/skylake.h +1562 -0
  114. package/include/numkong/each.h +2146 -0
  115. package/include/numkong/each.hpp +434 -0
  116. package/include/numkong/geospatial/README.md +147 -0
  117. package/include/numkong/geospatial/haswell.h +593 -0
  118. package/include/numkong/geospatial/neon.h +571 -0
  119. package/include/numkong/geospatial/rvv.h +701 -0
  120. package/include/numkong/geospatial/serial.h +309 -0
  121. package/include/numkong/geospatial/skylake.h +577 -0
  122. package/include/numkong/geospatial/v128relaxed.h +613 -0
  123. package/include/numkong/geospatial.h +453 -0
  124. package/include/numkong/geospatial.hpp +235 -0
  125. package/include/numkong/matrix.hpp +336 -0
  126. package/include/numkong/maxsim/README.md +187 -0
  127. package/include/numkong/maxsim/alder.h +511 -0
  128. package/include/numkong/maxsim/genoa.h +115 -0
  129. package/include/numkong/maxsim/haswell.h +553 -0
  130. package/include/numkong/maxsim/icelake.h +480 -0
  131. package/include/numkong/maxsim/neonsdot.h +394 -0
  132. package/include/numkong/maxsim/sapphireamx.h +877 -0
  133. package/include/numkong/maxsim/serial.h +490 -0
  134. package/include/numkong/maxsim/sme.h +929 -0
  135. package/include/numkong/maxsim/v128relaxed.h +280 -0
  136. package/include/numkong/maxsim.h +571 -0
  137. package/include/numkong/maxsim.hpp +133 -0
  138. package/include/numkong/mesh/README.md +227 -0
  139. package/include/numkong/mesh/haswell.h +2235 -0
  140. package/include/numkong/mesh/neon.h +1329 -0
  141. package/include/numkong/mesh/neonbfdot.h +842 -0
  142. package/include/numkong/mesh/neonhalf.h +616 -0
  143. package/include/numkong/mesh/rvv.h +916 -0
  144. package/include/numkong/mesh/serial.h +742 -0
  145. package/include/numkong/mesh/skylake.h +1135 -0
  146. package/include/numkong/mesh/v128relaxed.h +1052 -0
  147. package/include/numkong/mesh.h +652 -0
  148. package/include/numkong/mesh.hpp +762 -0
  149. package/include/numkong/numkong.h +78 -0
  150. package/include/numkong/numkong.hpp +57 -0
  151. package/include/numkong/probability/README.md +173 -0
  152. package/include/numkong/probability/haswell.h +267 -0
  153. package/include/numkong/probability/neon.h +225 -0
  154. package/include/numkong/probability/rvv.h +409 -0
  155. package/include/numkong/probability/serial.h +169 -0
  156. package/include/numkong/probability/skylake.h +324 -0
  157. package/include/numkong/probability.h +383 -0
  158. package/include/numkong/probability.hpp +120 -0
  159. package/include/numkong/random.h +50 -0
  160. package/include/numkong/random.hpp +285 -0
  161. package/include/numkong/reduce/README.md +547 -0
  162. package/include/numkong/reduce/alder.h +632 -0
  163. package/include/numkong/reduce/genoa.h +201 -0
  164. package/include/numkong/reduce/haswell.h +3783 -0
  165. package/include/numkong/reduce/icelake.h +549 -0
  166. package/include/numkong/reduce/neon.h +3841 -0
  167. package/include/numkong/reduce/neonbfdot.h +353 -0
  168. package/include/numkong/reduce/neonfhm.h +665 -0
  169. package/include/numkong/reduce/neonhalf.h +157 -0
  170. package/include/numkong/reduce/neonsdot.h +357 -0
  171. package/include/numkong/reduce/rvv.h +3407 -0
  172. package/include/numkong/reduce/serial.h +757 -0
  173. package/include/numkong/reduce/sierra.h +338 -0
  174. package/include/numkong/reduce/skylake.h +3792 -0
  175. package/include/numkong/reduce/v128relaxed.h +2302 -0
  176. package/include/numkong/reduce.h +1597 -0
  177. package/include/numkong/reduce.hpp +633 -0
  178. package/include/numkong/scalar/README.md +89 -0
  179. package/include/numkong/scalar/haswell.h +113 -0
  180. package/include/numkong/scalar/neon.h +122 -0
  181. package/include/numkong/scalar/neonhalf.h +70 -0
  182. package/include/numkong/scalar/rvv.h +211 -0
  183. package/include/numkong/scalar/sapphire.h +63 -0
  184. package/include/numkong/scalar/serial.h +332 -0
  185. package/include/numkong/scalar/v128relaxed.h +56 -0
  186. package/include/numkong/scalar.h +683 -0
  187. package/include/numkong/set/README.md +179 -0
  188. package/include/numkong/set/haswell.h +334 -0
  189. package/include/numkong/set/icelake.h +485 -0
  190. package/include/numkong/set/neon.h +364 -0
  191. package/include/numkong/set/rvv.h +226 -0
  192. package/include/numkong/set/rvvbb.h +117 -0
  193. package/include/numkong/set/serial.h +174 -0
  194. package/include/numkong/set/sve.h +185 -0
  195. package/include/numkong/set/v128relaxed.h +240 -0
  196. package/include/numkong/set.h +457 -0
  197. package/include/numkong/set.hpp +114 -0
  198. package/include/numkong/sets/README.md +149 -0
  199. package/include/numkong/sets/haswell.h +63 -0
  200. package/include/numkong/sets/icelake.h +66 -0
  201. package/include/numkong/sets/neon.h +61 -0
  202. package/include/numkong/sets/serial.h +43 -0
  203. package/include/numkong/sets/smebi32.h +1099 -0
  204. package/include/numkong/sets/v128relaxed.h +58 -0
  205. package/include/numkong/sets.h +339 -0
  206. package/include/numkong/sparse/README.md +156 -0
  207. package/include/numkong/sparse/icelake.h +463 -0
  208. package/include/numkong/sparse/neon.h +288 -0
  209. package/include/numkong/sparse/serial.h +117 -0
  210. package/include/numkong/sparse/sve2.h +507 -0
  211. package/include/numkong/sparse/turin.h +322 -0
  212. package/include/numkong/sparse.h +363 -0
  213. package/include/numkong/sparse.hpp +113 -0
  214. package/include/numkong/spatial/README.md +435 -0
  215. package/include/numkong/spatial/alder.h +607 -0
  216. package/include/numkong/spatial/genoa.h +290 -0
  217. package/include/numkong/spatial/haswell.h +960 -0
  218. package/include/numkong/spatial/icelake.h +586 -0
  219. package/include/numkong/spatial/neon.h +773 -0
  220. package/include/numkong/spatial/neonbfdot.h +165 -0
  221. package/include/numkong/spatial/neonhalf.h +118 -0
  222. package/include/numkong/spatial/neonsdot.h +261 -0
  223. package/include/numkong/spatial/rvv.h +984 -0
  224. package/include/numkong/spatial/rvvbf16.h +123 -0
  225. package/include/numkong/spatial/rvvhalf.h +117 -0
  226. package/include/numkong/spatial/sapphire.h +343 -0
  227. package/include/numkong/spatial/serial.h +346 -0
  228. package/include/numkong/spatial/sierra.h +323 -0
  229. package/include/numkong/spatial/skylake.h +606 -0
  230. package/include/numkong/spatial/sve.h +224 -0
  231. package/include/numkong/spatial/svebfdot.h +122 -0
  232. package/include/numkong/spatial/svehalf.h +109 -0
  233. package/include/numkong/spatial/v128relaxed.h +717 -0
  234. package/include/numkong/spatial.h +1425 -0
  235. package/include/numkong/spatial.hpp +183 -0
  236. package/include/numkong/spatials/README.md +580 -0
  237. package/include/numkong/spatials/alder.h +94 -0
  238. package/include/numkong/spatials/genoa.h +94 -0
  239. package/include/numkong/spatials/haswell.h +219 -0
  240. package/include/numkong/spatials/icelake.h +113 -0
  241. package/include/numkong/spatials/neon.h +109 -0
  242. package/include/numkong/spatials/neonbfdot.h +60 -0
  243. package/include/numkong/spatials/neonfhm.h +92 -0
  244. package/include/numkong/spatials/neonhalf.h +58 -0
  245. package/include/numkong/spatials/neonsdot.h +109 -0
  246. package/include/numkong/spatials/rvv.h +1960 -0
  247. package/include/numkong/spatials/sapphireamx.h +1149 -0
  248. package/include/numkong/spatials/serial.h +226 -0
  249. package/include/numkong/spatials/sierra.h +96 -0
  250. package/include/numkong/spatials/skylake.h +184 -0
  251. package/include/numkong/spatials/sme.h +1901 -0
  252. package/include/numkong/spatials/smef64.h +465 -0
  253. package/include/numkong/spatials/v128relaxed.h +240 -0
  254. package/include/numkong/spatials.h +3021 -0
  255. package/include/numkong/spatials.hpp +508 -0
  256. package/include/numkong/tensor.hpp +1592 -0
  257. package/include/numkong/trigonometry/README.md +184 -0
  258. package/include/numkong/trigonometry/haswell.h +652 -0
  259. package/include/numkong/trigonometry/neon.h +639 -0
  260. package/include/numkong/trigonometry/rvv.h +699 -0
  261. package/include/numkong/trigonometry/serial.h +703 -0
  262. package/include/numkong/trigonometry/skylake.h +721 -0
  263. package/include/numkong/trigonometry/v128relaxed.h +666 -0
  264. package/include/numkong/trigonometry.h +467 -0
  265. package/include/numkong/trigonometry.hpp +166 -0
  266. package/include/numkong/types.h +1384 -0
  267. package/include/numkong/types.hpp +5603 -0
  268. package/include/numkong/vector.hpp +698 -0
  269. package/javascript/README.md +246 -0
  270. package/javascript/dist/cjs/numkong-wasm.d.ts +166 -0
  271. package/javascript/dist/cjs/numkong-wasm.js +617 -0
  272. package/javascript/dist/cjs/numkong.d.ts +343 -0
  273. package/javascript/dist/cjs/numkong.js +523 -0
  274. package/javascript/dist/cjs/package.json +3 -0
  275. package/javascript/dist/cjs/types.d.ts +284 -0
  276. package/javascript/dist/cjs/types.js +653 -0
  277. package/javascript/dist/esm/numkong-wasm.d.ts +166 -0
  278. package/javascript/dist/esm/numkong-wasm.js +595 -0
  279. package/javascript/dist/esm/numkong.d.ts +343 -0
  280. package/javascript/dist/esm/numkong.js +452 -0
  281. package/javascript/dist/esm/package.json +3 -0
  282. package/javascript/dist/esm/types.d.ts +284 -0
  283. package/javascript/dist/esm/types.js +630 -0
  284. package/javascript/dist-package-cjs.json +3 -0
  285. package/javascript/dist-package-esm.json +3 -0
  286. package/javascript/node-gyp-build.d.ts +1 -0
  287. package/javascript/numkong-wasm.ts +756 -0
  288. package/javascript/numkong.c +689 -0
  289. package/javascript/numkong.ts +575 -0
  290. package/javascript/tsconfig-base.json +39 -0
  291. package/javascript/tsconfig-cjs.json +8 -0
  292. package/javascript/tsconfig-esm.json +8 -0
  293. package/javascript/types.ts +674 -0
  294. package/package.json +87 -0
@@ -0,0 +1,1070 @@
1
+ /**
2
+ * @brief SIMD-accelerated Dot Products for Real and Complex Numbers.
3
+ * @file include/numkong/dot.h
4
+ * @author Ash Vardanian
5
+ * @date February 24, 2024
6
+ *
7
+ * Contains:
8
+ *
9
+ * - Dot Product for Real and Complex vectors
10
+ * - Conjugate Dot Product for Complex vectors
11
+ *
12
+ * For dtypes:
13
+ *
14
+ * - f64: 64-bit IEEE floating point numbers → 64-bit floats
15
+ * - f32: 32-bit IEEE floating point numbers → 64-bit floats
16
+ * - f16: 16-bit IEEE floating point numbers → 32-bit floats
17
+ * - bf16: 16-bit brain floating point numbers → 32-bit floats
18
+ * - e4m3: 8-bit e4m3 floating point numbers → 32-bit floats
19
+ * - e5m2: 8-bit e5m2 floating point numbers → 32-bit floats
20
+ * - e2m3: 8-bit e2m3 floating point numbers (MX) → 32-bit floats
21
+ * - e3m2: 8-bit e3m2 floating point numbers (MX) → 32-bit floats
22
+ * - i8: 8-bit signed integers → 32-bit signed integers
23
+ * - u8: 8-bit unsigned integers → 32-bit unsigned integers
24
+ * - i4: 4-bit signed integers (packed nibble pairs) → 32-bit signed integers
25
+ * - u4: 4-bit unsigned integers (packed nibble pairs) → 32-bit unsigned integers
26
+ * - u1: 1-bit binary (packed octets) → 32-bit unsigned integers
27
+ *
28
+ * Complex dot product variants:
29
+ *
30
+ * - f64c: 64-bit complex pairs → 64-bit complex
31
+ * - f32c: 32-bit complex pairs → 64-bit complex
32
+ * - f16c: 16-bit complex pairs → 32-bit complex
33
+ * - bf16c: 16-bit brain complex pairs → 32-bit complex
34
+ *
35
+ * For hardware architectures:
36
+ *
37
+ * - Arm: NEON, NEON+I8, NEON+F16, NEON+FHM, NEON+BF16, SVE, SVE+F16
38
+ * - x86: Haswell, Skylake, Ice Lake, Genoa, Sapphire Rapids, Sierra Forest
39
+ * - RISC-V: RVV, RVV+BF16, RVV+HALF, RVV+BB
40
+ * - WASM: V128Relaxed
41
+ *
42
+ * @section numerical_stability Numerical Stability
43
+ *
44
+ * - f64: Dot2/Ogita-Rump-Oishi style compensated summation across serial and SIMD stateful paths.
45
+ * - f32: Public outputs widen to f64/f64c. Arithmetic widens before the first lossy reduction step.
46
+ * - f16/bf16: Promoted to f32 accumulator.
47
+ * - e4m3/e5m2: Promoted to f32. On Sapphire, e2m3/e3m2 use f16 intermediate with periodic
48
+ * flush to f32 every 128 elements to avoid f16 overflow (max lane sum ~225 / ~3136).
49
+ * - i8: i32 accumulator. Max product |(-128)²| = 16,384. Overflows at n > 2^31/16,384 ≈ 131K.
50
+ * - u8: u32 accumulator. Max product 255² = 65,025. Overflows at n > 2^32/65,025 ≈ 66K.
51
+ * - i4: i32 accumulator. Max product |(-8)²| = 64. Safe for n ≤ ~33M.
52
+ * - u4: u32 accumulator. Max product 15² = 225. Safe for n ≤ ~19M.
53
+ * - u1: Popcount of AND into u32. Safe for n_bits < 2^32 (the u32 maximum is 2^32 - 1).
54
+ * - Complex: Components accumulated independently; same guarantees as real counterpart.
55
+ *
56
+ * @section streaming_api Streaming API
57
+ *
58
+ * For compile-time dispatch and vector-at-a-time accumulation, we provide streaming helpers
59
+ * that accept two `nk_b512_vec_t` blocks and update a running sum for non-complex dot
60
+ * products. The `<count>` suffix reflects how many scalars of that type fit in a 512-bit block.
61
+ * The helpers are exposed per scalar type as:
62
+ *
63
+ * - nk_dot_<type>x<count>_state_<isa>_t
64
+ * - nk_dot_<type>x<count>_init_<isa>
65
+ * - nk_dot_<type>x<count>_update_<isa>
66
+ * - nk_dot_<type>x<count>_finalize_<isa>
67
+ *
68
+ * @section x86_instructions Relevant x86 Instructions
69
+ *
70
+ * Floating-point dot products use FMA (VFMADD231PS/PD) for sum += a[i]*b[i] accumulation.
71
+ * Integer i8 dot products use VPMADDUBSW (u8 × i8 → i16) + VPMADDWD (i16 × 1 → i32) on Haswell,
72
+ * or the newer VNNI instructions VPDPBUSD/VPDPWSSD on Ice Lake+ for direct u8 × i8 → i32.
73
+ * BF16 dot products (VDPBF16PS) are Genoa-only, accumulating bf16 pairs directly to f32.
74
+ * Genoa shows 40% lower integer multiply-add latency (3c vs 5c) than Ice Lake.
75
+ *
76
+ * Intrinsic Instruction Haswell Ice Genoa
77
+ * _mm256_fmadd_ps VFMADD231PS (YMM, YMM, YMM) 5c @ p01 4c @ p01 4c @ p01
78
+ * _mm256_fmadd_pd VFMADD231PD (YMM, YMM, YMM) 5c @ p01 4c @ p01 4c @ p01
79
+ * _mm256_maddubs_epi16 VPMADDUBSW (YMM, YMM, YMM) 5c @ p0 5c @ p01 3c @ p01
80
+ * _mm256_madd_epi16 VPMADDWD (YMM, YMM, YMM) 5c @ p0 5c @ p01 3c @ p01
81
+ * _mm256_dpbusd_epi32 VPDPBUSD (YMM, YMM, YMM) N/A 5c @ p01 4c @ p01
82
+ * _mm512_dpwssd_epi32 VPDPWSSD (ZMM, ZMM, ZMM) N/A 5c @ p0 4c @ p01
83
+ * _mm512_dpbf16_ps VDPBF16PS (ZMM, ZMM, ZMM) N/A N/A 6c @ p01
84
+ *
85
+ * @section arm_neon_instructions Relevant ARM NEON Instructions
86
+ *
87
+ * NEON integer dot products use SDOT/UDOT (ARMv8.2 dotprod) for direct i8 × i8 → i32 or u8 × u8 → u32
88
+ * accumulation - 4x faster than the multiply-add sequence on older cores. BFDOT (ARMv8.6 bf16)
89
+ * provides native bf16 dot products on Graviton 3+. Complex dot products use LD2 for deinterleaved
90
+ * loads of real/imag pairs, though its L01+V throughput can bottleneck on memory-bound workloads.
91
+ *
92
+ * Intrinsic Instruction M1 Firestorm Graviton 3 Graviton 4
93
+ * vfmaq_f32 FMLA.S (vec) 4c @ V0123 4c @ V0123 4c @ V0123
94
+ * vfmaq_f64 FMLA.D (vec) 4c @ V0123 4c @ V0123 4c @ V0123
95
+ * vdotq_s32 SDOT (vec) 3c @ V0123 3c @ V0123 3c @ V0123
96
+ * vdotq_u32 UDOT (vec) 3c @ V0123 3c @ V0123 3c @ V0123
97
+ * vbfdotq_f32 BFDOT (vec) N/A 4c @ V0123 5c @ V0123
98
+ * vld2q_f32 LD2 (Q-form) 5c @ L01+V 8c @ L01+V 8c @ L01+V
99
+ *
100
+ * @section arm_sve_instructions Relevant ARM SVE Instructions
101
+ *
102
+ * SVE implementations use predicated FMA (svmla_f32_x) with WHILELT for tail masking, avoiding
103
+ * scalar cleanup loops. FADDV performs horizontal reduction; its latency is about 45% lower on Graviton 4
104
+ * (6c) than Graviton 3 (11c). SVE complex dot products use svld2 for structure loads.
105
+ *
106
+ * Intrinsic Instruction Graviton 3 Graviton 4
107
+ * svmla_f32_x FMLA (pred) 4c @ V0123 4c @ V0123
108
+ * svmls_f32_x FMLS (pred) 4c @ V0123 4c @ V0123
109
+ * svwhilelt_b32 WHILELT 3c @ M0 3c @ M0
110
+ * svld2_f32 LD2 (SVE) 8c @ L01+V 8c @ L01+V
111
+ * svaddv_f32 FADDV 11c @ V0123 6c @ V0123
112
+ *
113
+ * @section complex_instructions Complex Number Optimizations
114
+ *
115
+ * Standard complex multiplication involves subtraction for the real part.
116
+ * Instead of using subtracting variants of FMA for every element, we accumulate real
117
+ * and imaginary products positively and apply a single bitwise XOR to flip the sign
118
+ * bits before the final horizontal reduction. This delayed application of the sign
119
+ * flip doubles the throughput on older x86 architectures like Haswell by maximizing
120
+ * FMA unit utilization and reducing execution dependency chains.
121
+ *
122
+ * @section references References
123
+ *
124
+ * - x86 intrinsics: https://www.intel.com/content/www/us/en/docs/intrinsics-guide/
125
+ * - Arm intrinsics: https://developer.arm.com/architectures/instruction-sets/intrinsics/
126
+ *
127
+ */
128
+ #ifndef NK_DOT_H
129
+ #define NK_DOT_H
130
+
131
+ #include "numkong/types.h"
132
+
133
+ #if defined(__cplusplus)
134
+ extern "C" {
135
+ #endif
136
+
137
+ /**
138
+ * @brief Dot product computing the sum of elementwise products between two vectors.
139
+ *
140
+ * @param[in] a The first vector.
141
+ * @param[in] b The second vector.
142
+ * @param[in] n The number of elements in the vectors.
143
+ * @param[out] result The output dot product value.
144
+ *
145
+ * @note The output value can be negative.
146
+ * @note The output value is zero if and only if the two vectors are orthogonal.
147
+ * @note Defined for floating-point, integer, and binary data types.
148
+ */
149
+ NK_DYNAMIC void nk_dot_f32(nk_f32_t const *a, nk_f32_t const *b, nk_size_t n, nk_f64_t *result);
150
+ /** @copydoc nk_dot_f32 */
151
+ NK_DYNAMIC void nk_dot_f64(nk_f64_t const *a, nk_f64_t const *b, nk_size_t n, nk_f64_t *result);
152
+ /** @copydoc nk_dot_f32 */
153
+ NK_DYNAMIC void nk_dot_f16(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n, nk_f32_t *result);
154
+ /** @copydoc nk_dot_f32 */
155
+ NK_DYNAMIC void nk_dot_bf16(nk_bf16_t const *a, nk_bf16_t const *b, nk_size_t n, nk_f32_t *result);
156
+ /** @copydoc nk_dot_f32 */
157
+ NK_DYNAMIC void nk_dot_i8(nk_i8_t const *a, nk_i8_t const *b, nk_size_t n, nk_i32_t *result);
158
+ /** @copydoc nk_dot_f32 */
159
+ NK_DYNAMIC void nk_dot_u8(nk_u8_t const *a, nk_u8_t const *b, nk_size_t n, nk_u32_t *result);
160
+ /** @copydoc nk_dot_f32 */
161
+ NK_DYNAMIC void nk_dot_i4(nk_i4x2_t const *a, nk_i4x2_t const *b, nk_size_t n, nk_i32_t *result);
162
+ /** @copydoc nk_dot_f32 */
163
+ NK_DYNAMIC void nk_dot_u4(nk_u4x2_t const *a, nk_u4x2_t const *b, nk_size_t n, nk_u32_t *result);
164
+ /** @copydoc nk_dot_f32 */
165
+ NK_DYNAMIC void nk_dot_u1(nk_u1x8_t const *a, nk_u1x8_t const *b, nk_size_t n_bits, nk_u32_t *result);
166
+ /** @copydoc nk_dot_f32 */
167
+ NK_DYNAMIC void nk_dot_e4m3(nk_e4m3_t const *a, nk_e4m3_t const *b, nk_size_t n, nk_f32_t *result);
168
+ /** @copydoc nk_dot_f32 */
169
+ NK_DYNAMIC void nk_dot_e5m2(nk_e5m2_t const *a, nk_e5m2_t const *b, nk_size_t n, nk_f32_t *result);
170
+ /** @copydoc nk_dot_f32 */
171
+ NK_DYNAMIC void nk_dot_e2m3(nk_e2m3_t const *a, nk_e2m3_t const *b, nk_size_t n, nk_f32_t *result);
172
+ /** @copydoc nk_dot_f32 */
173
+ NK_DYNAMIC void nk_dot_e3m2(nk_e3m2_t const *a, nk_e3m2_t const *b, nk_size_t n, nk_f32_t *result);
174
+
175
+ /**
176
+ * @brief Complex dot product computing the sum of elementwise products between two complex vectors.
177
+ *
178
+ * @param[in] a_pairs The first complex vector.
179
+ * @param[in] b_pairs The second complex vector.
180
+ * @param[in] count_pairs The number of complex pairs in the vectors.
181
+ * @param[out] result The output complex value as {real, imag}.
182
+ */
183
+ NK_DYNAMIC void nk_dot_f32c(nk_f32c_t const *a_pairs, nk_f32c_t const *b_pairs, nk_size_t count_pairs,
184
+ nk_f64c_t *result);
185
+ /** @copydoc nk_dot_f32c */
186
+ NK_DYNAMIC void nk_dot_f64c(nk_f64c_t const *a_pairs, nk_f64c_t const *b_pairs, nk_size_t count_pairs,
187
+ nk_f64c_t *result);
188
+ /** @copydoc nk_dot_f32c */
189
+ NK_DYNAMIC void nk_dot_f16c(nk_f16c_t const *a_pairs, nk_f16c_t const *b_pairs, nk_size_t count_pairs,
190
+ nk_f32c_t *result);
191
+ /** @copydoc nk_dot_f32c */
192
+ NK_DYNAMIC void nk_dot_bf16c(nk_bf16c_t const *a_pairs, nk_bf16c_t const *b_pairs, nk_size_t count_pairs,
193
+ nk_f32c_t *result);
194
+
195
+ /**
196
+ * @brief Complex conjugate dot product between two complex vectors.
197
+ *
198
+ * @param[in] a_pairs The first complex vector.
199
+ * @param[in] b_pairs The second complex vector.
200
+ * @param[in] count_pairs The number of complex pairs in the vectors.
201
+ * @param[out] result The output complex value as {real, imag}.
202
+ */
203
+ NK_DYNAMIC void nk_vdot_f32c(nk_f32c_t const *a_pairs, nk_f32c_t const *b_pairs, nk_size_t count_pairs,
204
+ nk_f64c_t *result);
205
+ /** @copydoc nk_vdot_f32c */
206
+ NK_DYNAMIC void nk_vdot_f64c(nk_f64c_t const *a_pairs, nk_f64c_t const *b_pairs, nk_size_t count_pairs,
207
+ nk_f64c_t *result);
208
+ /** @copydoc nk_vdot_f32c */
209
+ NK_DYNAMIC void nk_vdot_f16c(nk_f16c_t const *a_pairs, nk_f16c_t const *b_pairs, nk_size_t count_pairs,
210
+ nk_f32c_t *result);
211
+ /** @copydoc nk_vdot_f32c */
212
+ NK_DYNAMIC void nk_vdot_bf16c(nk_bf16c_t const *a_pairs, nk_bf16c_t const *b_pairs, nk_size_t count_pairs,
213
+ nk_f32c_t *result);
214
+
215
+ /** @copydoc nk_dot_f64 */
216
+ NK_PUBLIC void nk_dot_f64_serial(nk_f64_t const *a, nk_f64_t const *b, nk_size_t n, nk_f64_t *result);
217
+ /** @copydoc nk_dot_f64c */
218
+ NK_PUBLIC void nk_dot_f64c_serial(nk_f64c_t const *a, nk_f64c_t const *b, nk_size_t n, nk_f64c_t *result);
219
+ /** @copydoc nk_vdot_f64c */
220
+ NK_PUBLIC void nk_vdot_f64c_serial(nk_f64c_t const *a, nk_f64c_t const *b, nk_size_t n, nk_f64c_t *result);
221
+
222
+ /** @copydoc nk_dot_f32 */
223
+ NK_PUBLIC void nk_dot_f32_serial(nk_f32_t const *a, nk_f32_t const *b, nk_size_t n, nk_f64_t *result);
224
+ /** @copydoc nk_dot_f32c */
225
+ NK_PUBLIC void nk_dot_f32c_serial(nk_f32c_t const *a, nk_f32c_t const *b, nk_size_t n, nk_f64c_t *result);
226
+ /** @copydoc nk_vdot_f32c */
227
+ NK_PUBLIC void nk_vdot_f32c_serial(nk_f32c_t const *a, nk_f32c_t const *b, nk_size_t n, nk_f64c_t *result);
228
+
229
+ /** @copydoc nk_dot_f16 */
230
+ NK_PUBLIC void nk_dot_f16_serial(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n, nk_f32_t *result);
231
+ /** @copydoc nk_dot_f16c */
232
+ NK_PUBLIC void nk_dot_f16c_serial(nk_f16c_t const *a, nk_f16c_t const *b, nk_size_t n, nk_f32c_t *result);
233
+ /** @copydoc nk_vdot_f16c */
234
+ NK_PUBLIC void nk_vdot_f16c_serial(nk_f16c_t const *a, nk_f16c_t const *b, nk_size_t n, nk_f32c_t *result);
235
+
236
+ /** @copydoc nk_dot_bf16 */
237
+ NK_PUBLIC void nk_dot_bf16_serial(nk_bf16_t const *a, nk_bf16_t const *b, nk_size_t n, nk_f32_t *result);
238
+ /** @copydoc nk_dot_bf16c */
239
+ NK_PUBLIC void nk_dot_bf16c_serial(nk_bf16c_t const *a, nk_bf16c_t const *b, nk_size_t n, nk_f32c_t *result);
240
+ /** @copydoc nk_vdot_bf16c */
241
+ NK_PUBLIC void nk_vdot_bf16c_serial(nk_bf16c_t const *a, nk_bf16c_t const *b, nk_size_t n, nk_f32c_t *result);
242
+
243
+ /** @copydoc nk_dot_i8 */
244
+ NK_PUBLIC void nk_dot_i8_serial(nk_i8_t const *a, nk_i8_t const *b, nk_size_t n, nk_i32_t *result);
245
+ /** @copydoc nk_dot_u8 */
246
+ NK_PUBLIC void nk_dot_u8_serial(nk_u8_t const *a, nk_u8_t const *b, nk_size_t n, nk_u32_t *result);
247
+ /** @copydoc nk_dot_i4 */
248
+ NK_PUBLIC void nk_dot_i4_serial(nk_i4x2_t const *a, nk_i4x2_t const *b, nk_size_t n, nk_i32_t *result);
249
+ /** @copydoc nk_dot_u4 */
250
+ NK_PUBLIC void nk_dot_u4_serial(nk_u4x2_t const *a, nk_u4x2_t const *b, nk_size_t n, nk_u32_t *result);
251
+ /** @copydoc nk_dot_u1 */
252
+ NK_PUBLIC void nk_dot_u1_serial(nk_u1x8_t const *a, nk_u1x8_t const *b, nk_size_t n_bits, nk_u32_t *result);
253
+
254
+ /** @copydoc nk_dot_e4m3 */
255
+ NK_PUBLIC void nk_dot_e4m3_serial(nk_e4m3_t const *a, nk_e4m3_t const *b, nk_size_t n, nk_f32_t *result);
256
+ /** @copydoc nk_dot_e5m2 */
257
+ NK_PUBLIC void nk_dot_e5m2_serial(nk_e5m2_t const *a, nk_e5m2_t const *b, nk_size_t n, nk_f32_t *result);
258
+ /** @copydoc nk_dot_e2m3 */
259
+ NK_PUBLIC void nk_dot_e2m3_serial(nk_e2m3_t const *a, nk_e2m3_t const *b, nk_size_t n, nk_f32_t *result);
260
+ /** @copydoc nk_dot_e3m2 */
261
+ NK_PUBLIC void nk_dot_e3m2_serial(nk_e3m2_t const *a, nk_e3m2_t const *b, nk_size_t n, nk_f32_t *result);
262
+
263
+ #if NK_TARGET_NEON
264
+ /** @copydoc nk_dot_f32 */
265
+ NK_PUBLIC void nk_dot_f32_neon(nk_f32_t const *a, nk_f32_t const *b, nk_size_t n, nk_f64_t *result);
266
+ /** @copydoc nk_dot_f32c */
267
+ NK_PUBLIC void nk_dot_f32c_neon(nk_f32c_t const *a, nk_f32c_t const *b, nk_size_t n, nk_f64c_t *result);
268
+ /** @copydoc nk_vdot_f32c */
269
+ NK_PUBLIC void nk_vdot_f32c_neon(nk_f32c_t const *a, nk_f32c_t const *b, nk_size_t n, nk_f64c_t *result);
270
+
271
+ /** @copydoc nk_dot_f64 */
272
+ NK_PUBLIC void nk_dot_f64_neon(nk_f64_t const *a, nk_f64_t const *b, nk_size_t n, nk_f64_t *result);
273
+ /** @copydoc nk_dot_f64c */
274
+ NK_PUBLIC void nk_dot_f64c_neon(nk_f64c_t const *a, nk_f64c_t const *b, nk_size_t n, nk_f64c_t *result);
275
+ /** @copydoc nk_vdot_f64c */
276
+ NK_PUBLIC void nk_vdot_f64c_neon(nk_f64c_t const *a, nk_f64c_t const *b, nk_size_t n, nk_f64c_t *result);
277
+
278
+ /** @copydoc nk_dot_bf16 */
279
+ NK_PUBLIC void nk_dot_bf16_neon(nk_bf16_t const *a, nk_bf16_t const *b, nk_size_t n, nk_f32_t *result);
280
+
281
+ /** @copydoc nk_dot_e4m3 */
282
+ NK_PUBLIC void nk_dot_e4m3_neon(nk_e4m3_t const *a, nk_e4m3_t const *b, nk_size_t n, nk_f32_t *result);
283
+ /** @copydoc nk_dot_e5m2 */
284
+ NK_PUBLIC void nk_dot_e5m2_neon(nk_e5m2_t const *a, nk_e5m2_t const *b, nk_size_t n, nk_f32_t *result);
285
+ /** @copydoc nk_dot_e2m3 */
286
+ NK_PUBLIC void nk_dot_e2m3_neon(nk_e2m3_t const *a, nk_e2m3_t const *b, nk_size_t n, nk_f32_t *result);
287
+ /** @copydoc nk_dot_e3m2 */
288
+ NK_PUBLIC void nk_dot_e3m2_neon(nk_e3m2_t const *a, nk_e3m2_t const *b, nk_size_t n, nk_f32_t *result);
289
+
290
+ /** @copydoc nk_dot_u1 */
291
+ NK_PUBLIC void nk_dot_u1_neon(nk_u1x8_t const *a, nk_u1x8_t const *b, nk_size_t n_bits, nk_u32_t *result);
292
+
293
+ /** @copydoc nk_dot_f16 */
294
+ NK_PUBLIC void nk_dot_f16_neon(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n, nk_f32_t *result);
295
+
296
+ #endif // NK_TARGET_NEON
297
+
298
+ #if NK_TARGET_NEONHALF
299
+ /** @copydoc nk_dot_f16 */
300
+ NK_PUBLIC void nk_dot_f16_neonhalf(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n, nk_f32_t *result);
301
+ /** @copydoc nk_dot_f16c */
302
+ NK_PUBLIC void nk_dot_f16c_neonhalf(nk_f16c_t const *a, nk_f16c_t const *b, nk_size_t n, nk_f32c_t *result);
303
+ /** @copydoc nk_vdot_f16c */
304
+ NK_PUBLIC void nk_vdot_f16c_neonhalf(nk_f16c_t const *a, nk_f16c_t const *b, nk_size_t n, nk_f32c_t *result);
305
+ #endif // NK_TARGET_NEONHALF
306
+
307
+ #if NK_TARGET_NEONFHM
308
+ /** @copydoc nk_dot_f16 */
309
+ NK_PUBLIC void nk_dot_f16_neonfhm(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n, nk_f32_t *result);
310
+ /** @copydoc nk_dot_e4m3 */
311
+ NK_PUBLIC void nk_dot_e4m3_neonfhm(nk_e4m3_t const *a, nk_e4m3_t const *b, nk_size_t n, nk_f32_t *result);
312
+ /** @copydoc nk_dot_e5m2 */
313
+ NK_PUBLIC void nk_dot_e5m2_neonfhm(nk_e5m2_t const *a, nk_e5m2_t const *b, nk_size_t n, nk_f32_t *result);
314
+ /** @copydoc nk_dot_f16c */
315
+ NK_PUBLIC void nk_dot_f16c_neonfhm(nk_f16c_t const *a, nk_f16c_t const *b, nk_size_t n, nk_f32c_t *result);
316
+ /** @copydoc nk_vdot_f16c */
317
+ NK_PUBLIC void nk_vdot_f16c_neonfhm(nk_f16c_t const *a, nk_f16c_t const *b, nk_size_t n, nk_f32c_t *result);
318
+ #endif // NK_TARGET_NEONFHM
319
+
320
+ #if NK_TARGET_NEONSDOT
321
+ /** @copydoc nk_dot_i8 */
322
+ NK_PUBLIC void nk_dot_i8_neonsdot(nk_i8_t const *a, nk_i8_t const *b, nk_size_t n, nk_i32_t *result);
323
+ /** @copydoc nk_dot_u8 */
324
+ NK_PUBLIC void nk_dot_u8_neonsdot(nk_u8_t const *a, nk_u8_t const *b, nk_size_t n, nk_u32_t *result);
325
+ /** @copydoc nk_dot_i4 */
326
+ NK_PUBLIC void nk_dot_i4_neonsdot(nk_i4x2_t const *a, nk_i4x2_t const *b, nk_size_t n, nk_i32_t *result);
327
+ /** @copydoc nk_dot_u4 */
328
+ NK_PUBLIC void nk_dot_u4_neonsdot(nk_u4x2_t const *a, nk_u4x2_t const *b, nk_size_t n, nk_u32_t *result);
329
+ /** @copydoc nk_dot_e2m3 */
330
+ NK_PUBLIC void nk_dot_e2m3_neonsdot(nk_e2m3_t const *a, nk_e2m3_t const *b, nk_size_t n, nk_f32_t *result);
331
+ /** @copydoc nk_dot_e3m2 */
332
+ NK_PUBLIC void nk_dot_e3m2_neonsdot(nk_e3m2_t const *a, nk_e3m2_t const *b, nk_size_t n, nk_f32_t *result);
333
+ #endif // NK_TARGET_NEONSDOT
334
+
335
+ #if NK_TARGET_NEONBFDOT
336
+ /** @copydoc nk_dot_bf16 */
337
+ NK_PUBLIC void nk_dot_bf16_neonbfdot(nk_bf16_t const *a, nk_bf16_t const *b, nk_size_t n, nk_f32_t *result);
338
+ /** @copydoc nk_dot_e4m3 */
339
+ NK_PUBLIC void nk_dot_e4m3_neonbfdot(nk_e4m3_t const *a, nk_e4m3_t const *b, nk_size_t n, nk_f32_t *result);
340
+ /** @copydoc nk_dot_e5m2 */
341
+ NK_PUBLIC void nk_dot_e5m2_neonbfdot(nk_e5m2_t const *a, nk_e5m2_t const *b, nk_size_t n, nk_f32_t *result);
342
+ /** @copydoc nk_dot_bf16c */
343
+ NK_PUBLIC void nk_dot_bf16c_neonbfdot(nk_bf16c_t const *a, nk_bf16c_t const *b, nk_size_t n, nk_f32c_t *result);
344
+ /** @copydoc nk_vdot_bf16c */
345
+ NK_PUBLIC void nk_vdot_bf16c_neonbfdot(nk_bf16c_t const *a, nk_bf16c_t const *b, nk_size_t n, nk_f32c_t *result);
346
+ #endif // NK_TARGET_NEONBFDOT
347
+
348
+ #if NK_TARGET_SVEBFDOT
349
+ /** @copydoc nk_dot_bf16 */
350
+ NK_PUBLIC void nk_dot_bf16_svebfdot(nk_bf16_t const *a, nk_bf16_t const *b, nk_size_t n, nk_f32_t *result);
351
+ #endif // NK_TARGET_SVEBFDOT
352
+
353
+ #if NK_TARGET_SVE
354
+ /** @copydoc nk_dot_f32 */
355
+ NK_PUBLIC void nk_dot_f32_sve(nk_f32_t const *a, nk_f32_t const *b, nk_size_t n, nk_f64_t *result);
356
+ /** @copydoc nk_dot_f32c */
357
+ NK_PUBLIC void nk_dot_f32c_sve(nk_f32c_t const *a, nk_f32c_t const *b, nk_size_t n, nk_f64c_t *result);
358
+ /** @copydoc nk_vdot_f32c */
359
+ NK_PUBLIC void nk_vdot_f32c_sve(nk_f32c_t const *a, nk_f32c_t const *b, nk_size_t n, nk_f64c_t *result);
360
+ /** @copydoc nk_dot_f64 */
361
+ NK_PUBLIC void nk_dot_f64_sve(nk_f64_t const *a, nk_f64_t const *b, nk_size_t n, nk_f64_t *result);
362
+ /** @copydoc nk_dot_f64c */
363
+ NK_PUBLIC void nk_dot_f64c_sve(nk_f64c_t const *a, nk_f64c_t const *b, nk_size_t n, nk_f64c_t *result);
364
+ /** @copydoc nk_vdot_f64c */
365
+ NK_PUBLIC void nk_vdot_f64c_sve(nk_f64c_t const *a, nk_f64c_t const *b, nk_size_t n, nk_f64c_t *result);
366
+ #endif // NK_TARGET_SVE
367
+
368
+ #if NK_TARGET_SVEHALF
369
+ /** @copydoc nk_dot_f16 */
370
+ NK_PUBLIC void nk_dot_f16_svehalf(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n, nk_f32_t *result);
371
+ /** @copydoc nk_dot_f16c */
372
+ NK_PUBLIC void nk_dot_f16c_svehalf(nk_f16c_t const *a, nk_f16c_t const *b, nk_size_t n, nk_f32c_t *result);
373
+ /** @copydoc nk_vdot_f16c */
374
+ NK_PUBLIC void nk_vdot_f16c_svehalf(nk_f16c_t const *a, nk_f16c_t const *b, nk_size_t n, nk_f32c_t *result);
375
+ #endif // NK_TARGET_SVEHALF
376
+
377
+ #if NK_TARGET_HASWELL
378
+ /** @copydoc nk_dot_f32 */
379
+ NK_PUBLIC void nk_dot_f32_haswell(nk_f32_t const *a, nk_f32_t const *b, nk_size_t n, nk_f64_t *result);
380
+ /** @copydoc nk_dot_f64 */
381
+ NK_PUBLIC void nk_dot_f64_haswell(nk_f64_t const *a, nk_f64_t const *b, nk_size_t n, nk_f64_t *result);
382
+ /** @copydoc nk_dot_f32c */
383
+ NK_PUBLIC void nk_dot_f32c_haswell(nk_f32c_t const *a, nk_f32c_t const *b, nk_size_t n, nk_f64c_t *result);
384
+ /** @copydoc nk_vdot_f32c */
385
+ NK_PUBLIC void nk_vdot_f32c_haswell(nk_f32c_t const *a, nk_f32c_t const *b, nk_size_t n, nk_f64c_t *result);
386
+ /** @copydoc nk_dot_f64c */
387
+ NK_PUBLIC void nk_dot_f64c_haswell(nk_f64c_t const *a, nk_f64c_t const *b, nk_size_t n, nk_f64c_t *result);
388
+ /** @copydoc nk_vdot_f64c */
389
+ NK_PUBLIC void nk_vdot_f64c_haswell(nk_f64c_t const *a, nk_f64c_t const *b, nk_size_t n, nk_f64c_t *result);
390
+
391
+ /** @copydoc nk_dot_f16 */
392
+ NK_PUBLIC void nk_dot_f16_haswell(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n, nk_f32_t *result);
393
+ /** @copydoc nk_dot_f16c */
394
+ NK_PUBLIC void nk_dot_f16c_haswell(nk_f16c_t const *a, nk_f16c_t const *b, nk_size_t n, nk_f32c_t *result);
395
+ /** @copydoc nk_vdot_f16c */
396
+ NK_PUBLIC void nk_vdot_f16c_haswell(nk_f16c_t const *a, nk_f16c_t const *b, nk_size_t n, nk_f32c_t *result);
397
+
398
+ /** @copydoc nk_dot_bf16 */
399
+ NK_PUBLIC void nk_dot_bf16_haswell(nk_bf16_t const *a, nk_bf16_t const *b, nk_size_t n, nk_f32_t *result);
400
+ /** @copydoc nk_dot_bf16c */
401
+ NK_PUBLIC void nk_dot_bf16c_haswell(nk_bf16c_t const *a, nk_bf16c_t const *b, nk_size_t n, nk_f32c_t *result);
402
+ /** @copydoc nk_vdot_bf16c */
403
+ NK_PUBLIC void nk_vdot_bf16c_haswell(nk_bf16c_t const *a, nk_bf16c_t const *b, nk_size_t n, nk_f32c_t *result);
404
+
405
+ /** @copydoc nk_dot_e4m3 */
406
+ NK_PUBLIC void nk_dot_e4m3_haswell(nk_e4m3_t const *a, nk_e4m3_t const *b, nk_size_t n, nk_f32_t *result);
407
+ /** @copydoc nk_dot_e5m2 */
408
+ NK_PUBLIC void nk_dot_e5m2_haswell(nk_e5m2_t const *a, nk_e5m2_t const *b, nk_size_t n, nk_f32_t *result);
409
+ /** @copydoc nk_dot_e2m3 */
410
+ NK_PUBLIC void nk_dot_e2m3_haswell(nk_e2m3_t const *a, nk_e2m3_t const *b, nk_size_t n, nk_f32_t *result);
411
+ /** @copydoc nk_dot_e3m2 */
412
+ NK_PUBLIC void nk_dot_e3m2_haswell(nk_e3m2_t const *a, nk_e3m2_t const *b, nk_size_t n, nk_f32_t *result);
413
+
414
+ /** @copydoc nk_dot_i8 */
415
+ NK_PUBLIC void nk_dot_i8_haswell(nk_i8_t const *a, nk_i8_t const *b, nk_size_t n, nk_i32_t *result);
416
+ /** @copydoc nk_dot_u8 */
417
+ NK_PUBLIC void nk_dot_u8_haswell(nk_u8_t const *a, nk_u8_t const *b, nk_size_t n, nk_u32_t *result);
418
+ /** @copydoc nk_dot_i4 */
419
+ NK_PUBLIC void nk_dot_i4_haswell(nk_i4x2_t const *a, nk_i4x2_t const *b, nk_size_t n, nk_i32_t *result);
420
+ /** @copydoc nk_dot_u4 */
421
+ NK_PUBLIC void nk_dot_u4_haswell(nk_u4x2_t const *a, nk_u4x2_t const *b, nk_size_t n, nk_u32_t *result);
422
+ /** @copydoc nk_dot_u1 */
423
+ NK_PUBLIC void nk_dot_u1_haswell(nk_u1x8_t const *a, nk_u1x8_t const *b, nk_size_t n_bits, nk_u32_t *result);
424
+
425
+ #endif // NK_TARGET_HASWELL
426
+
427
+ #if NK_TARGET_SKYLAKE
428
+ /** @copydoc nk_dot_f64 */
429
+ NK_PUBLIC void nk_dot_f64_skylake(nk_f64_t const *a, nk_f64_t const *b, nk_size_t n, nk_f64_t *result);
430
+ /** @copydoc nk_dot_f64c */
431
+ NK_PUBLIC void nk_dot_f64c_skylake(nk_f64c_t const *a, nk_f64c_t const *b, nk_size_t n, nk_f64c_t *result);
432
+ /** @copydoc nk_vdot_f64c */
433
+ NK_PUBLIC void nk_vdot_f64c_skylake(nk_f64c_t const *a, nk_f64c_t const *b, nk_size_t n, nk_f64c_t *result);
434
+
435
+ /** @copydoc nk_dot_f32 */
436
+ NK_PUBLIC void nk_dot_f32_skylake(nk_f32_t const *a, nk_f32_t const *b, nk_size_t n, nk_f64_t *result);
437
+ /** @copydoc nk_dot_f32c */
438
+ NK_PUBLIC void nk_dot_f32c_skylake(nk_f32c_t const *a, nk_f32c_t const *b, nk_size_t n, nk_f64c_t *result);
439
+ /** @copydoc nk_vdot_f32c */
440
+ NK_PUBLIC void nk_vdot_f32c_skylake(nk_f32c_t const *a, nk_f32c_t const *b, nk_size_t n, nk_f64c_t *result);
441
+
442
+ /** @copydoc nk_dot_f16 */
443
+ NK_PUBLIC void nk_dot_f16_skylake(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n, nk_f32_t *result);
444
+ /** @copydoc nk_dot_bf16 */
445
+ NK_PUBLIC void nk_dot_bf16_skylake(nk_bf16_t const *a, nk_bf16_t const *b, nk_size_t n, nk_f32_t *result);
446
+
447
+ /** @copydoc nk_dot_e4m3 */
448
+ NK_PUBLIC void nk_dot_e4m3_skylake(nk_e4m3_t const *a, nk_e4m3_t const *b, nk_size_t n, nk_f32_t *result);
449
+ /** @copydoc nk_dot_e5m2 */
450
+ NK_PUBLIC void nk_dot_e5m2_skylake(nk_e5m2_t const *a, nk_e5m2_t const *b, nk_size_t n, nk_f32_t *result);
451
+ /** @copydoc nk_dot_e2m3 */
452
+ NK_PUBLIC void nk_dot_e2m3_skylake(nk_e2m3_t const *a, nk_e2m3_t const *b, nk_size_t n, nk_f32_t *result);
453
+ /** @copydoc nk_dot_e3m2 */
454
+ NK_PUBLIC void nk_dot_e3m2_skylake(nk_e3m2_t const *a, nk_e3m2_t const *b, nk_size_t n, nk_f32_t *result);
455
+
456
+ /** @copydoc nk_dot_i8 */
457
+ NK_PUBLIC void nk_dot_i8_skylake(nk_i8_t const *a, nk_i8_t const *b, nk_size_t n, nk_i32_t *result);
458
+ /** @copydoc nk_dot_u8 */
459
+ NK_PUBLIC void nk_dot_u8_skylake(nk_u8_t const *a, nk_u8_t const *b, nk_size_t n, nk_u32_t *result);
460
+ #endif // NK_TARGET_SKYLAKE
461
+
462
+ #if NK_TARGET_ICELAKE
463
+ /** @copydoc nk_dot_i8 */
464
+ NK_PUBLIC void nk_dot_i8_icelake(nk_i8_t const *a, nk_i8_t const *b, nk_size_t n, nk_i32_t *result);
465
+ /** @copydoc nk_dot_u8 */
466
+ NK_PUBLIC void nk_dot_u8_icelake(nk_u8_t const *a, nk_u8_t const *b, nk_size_t n, nk_u32_t *result);
467
+ /** @copydoc nk_dot_i4 */
468
+ NK_PUBLIC void nk_dot_i4_icelake(nk_i4x2_t const *a, nk_i4x2_t const *b, nk_size_t n, nk_i32_t *result);
469
+ /** @copydoc nk_dot_u4 */
470
+ NK_PUBLIC void nk_dot_u4_icelake(nk_u4x2_t const *a, nk_u4x2_t const *b, nk_size_t n, nk_u32_t *result);
471
+ /** @copydoc nk_dot_e2m3 */
472
+ NK_PUBLIC void nk_dot_e2m3_icelake(nk_e2m3_t const *a, nk_e2m3_t const *b, nk_size_t n, nk_f32_t *result);
473
+ /** @copydoc nk_dot_e3m2 */
474
+ NK_PUBLIC void nk_dot_e3m2_icelake(nk_e3m2_t const *a, nk_e3m2_t const *b, nk_size_t n, nk_f32_t *result);
475
+ /** @copydoc nk_dot_u1 */
476
+ NK_PUBLIC void nk_dot_u1_icelake(nk_u1x8_t const *a, nk_u1x8_t const *b, nk_size_t n_bits, nk_u32_t *result);
477
+ #endif // NK_TARGET_ICELAKE
478
+
479
+ #if NK_TARGET_GENOA
480
+ /** @copydoc nk_dot_bf16 */
481
+ NK_PUBLIC void nk_dot_bf16_genoa(nk_bf16_t const *a, nk_bf16_t const *b, nk_size_t n, nk_f32_t *result);
482
+ /** @copydoc nk_dot_bf16c */
483
+ NK_PUBLIC void nk_dot_bf16c_genoa(nk_bf16c_t const *a, nk_bf16c_t const *b, nk_size_t n, nk_f32c_t *result);
484
+ /** @copydoc nk_vdot_bf16c */
485
+ NK_PUBLIC void nk_vdot_bf16c_genoa(nk_bf16c_t const *a, nk_bf16c_t const *b, nk_size_t n, nk_f32c_t *result);
486
+
487
+ /** @copydoc nk_dot_e4m3 */
488
+ NK_PUBLIC void nk_dot_e4m3_genoa(nk_e4m3_t const *a, nk_e4m3_t const *b, nk_size_t n, nk_f32_t *result);
489
+ /** @copydoc nk_dot_e5m2 */
490
+ NK_PUBLIC void nk_dot_e5m2_genoa(nk_e5m2_t const *a, nk_e5m2_t const *b, nk_size_t n, nk_f32_t *result);
491
+ #endif // NK_TARGET_GENOA
492
+
493
+ #if NK_TARGET_ALDER
494
+ /** @copydoc nk_dot_i8 */
495
+ NK_PUBLIC void nk_dot_i8_alder(nk_i8_t const *a, nk_i8_t const *b, nk_size_t n, nk_i32_t *result);
496
+ /** @copydoc nk_dot_u8 */
497
+ NK_PUBLIC void nk_dot_u8_alder(nk_u8_t const *a, nk_u8_t const *b, nk_size_t n, nk_u32_t *result);
498
+ /** @copydoc nk_dot_e2m3 */
499
+ NK_PUBLIC void nk_dot_e2m3_alder(nk_e2m3_t const *a, nk_e2m3_t const *b, nk_size_t n, nk_f32_t *result);
500
+ #endif // NK_TARGET_ALDER
501
+
502
+ #if NK_TARGET_SIERRA
503
+ /** @copydoc nk_dot_i8 */
504
+ NK_PUBLIC void nk_dot_i8_sierra(nk_i8_t const *a, nk_i8_t const *b, nk_size_t n, nk_i32_t *result);
505
+ /** @copydoc nk_dot_u8 */
506
+ NK_PUBLIC void nk_dot_u8_sierra(nk_u8_t const *a, nk_u8_t const *b, nk_size_t n, nk_u32_t *result);
507
+ /** @copydoc nk_dot_e2m3 */
508
+ NK_PUBLIC void nk_dot_e2m3_sierra(nk_e2m3_t const *a, nk_e2m3_t const *b, nk_size_t n, nk_f32_t *result);
509
+ #endif // NK_TARGET_SIERRA
510
+
511
+ #if NK_TARGET_RVV
512
+ /** @copydoc nk_dot_f32 */
513
+ NK_PUBLIC void nk_dot_f32_rvv(nk_f32_t const *a, nk_f32_t const *b, nk_size_t n, nk_f64_t *result);
514
+ /** @copydoc nk_dot_f64 */
515
+ NK_PUBLIC void nk_dot_f64_rvv(nk_f64_t const *a, nk_f64_t const *b, nk_size_t n, nk_f64_t *result);
516
+ /** @copydoc nk_dot_f16 */
517
+ NK_PUBLIC void nk_dot_f16_rvv(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n, nk_f32_t *result);
518
+ /** @copydoc nk_dot_bf16 */
519
+ NK_PUBLIC void nk_dot_bf16_rvv(nk_bf16_t const *a, nk_bf16_t const *b, nk_size_t n, nk_f32_t *result);
520
+ /** @copydoc nk_dot_i8 */
521
+ NK_PUBLIC void nk_dot_i8_rvv(nk_i8_t const *a, nk_i8_t const *b, nk_size_t n, nk_i32_t *result);
522
+ /** @copydoc nk_dot_u8 */
523
+ NK_PUBLIC void nk_dot_u8_rvv(nk_u8_t const *a, nk_u8_t const *b, nk_size_t n, nk_u32_t *result);
524
+ /** @copydoc nk_dot_e4m3 */
525
+ NK_PUBLIC void nk_dot_e4m3_rvv(nk_e4m3_t const *a, nk_e4m3_t const *b, nk_size_t n, nk_f32_t *result);
526
+ /** @copydoc nk_dot_e5m2 */
527
+ NK_PUBLIC void nk_dot_e5m2_rvv(nk_e5m2_t const *a, nk_e5m2_t const *b, nk_size_t n, nk_f32_t *result);
528
+ /** @copydoc nk_dot_e2m3 */
529
+ NK_PUBLIC void nk_dot_e2m3_rvv(nk_e2m3_t const *a, nk_e2m3_t const *b, nk_size_t n, nk_f32_t *result);
530
+ /** @copydoc nk_dot_e3m2 */
531
+ NK_PUBLIC void nk_dot_e3m2_rvv(nk_e3m2_t const *a, nk_e3m2_t const *b, nk_size_t n, nk_f32_t *result);
532
+ /** @copydoc nk_dot_i4 */
533
+ NK_PUBLIC void nk_dot_i4_rvv(nk_i4x2_t const *a, nk_i4x2_t const *b, nk_size_t n, nk_i32_t *result);
534
+ /** @copydoc nk_dot_u4 */
535
+ NK_PUBLIC void nk_dot_u4_rvv(nk_u4x2_t const *a, nk_u4x2_t const *b, nk_size_t n, nk_u32_t *result);
536
+ /** @copydoc nk_dot_u1 */
537
+ NK_PUBLIC void nk_dot_u1_rvv(nk_u1x8_t const *a, nk_u1x8_t const *b, nk_size_t n_bits, nk_u32_t *result);
538
+ /** @copydoc nk_dot_f32c */
539
+ NK_PUBLIC void nk_dot_f32c_rvv(nk_f32c_t const *a, nk_f32c_t const *b, nk_size_t n, nk_f64c_t *result);
540
+ /** @copydoc nk_vdot_f32c */
541
+ NK_PUBLIC void nk_vdot_f32c_rvv(nk_f32c_t const *a, nk_f32c_t const *b, nk_size_t n, nk_f64c_t *result);
542
+ /** @copydoc nk_dot_f64c */
543
+ NK_PUBLIC void nk_dot_f64c_rvv(nk_f64c_t const *a, nk_f64c_t const *b, nk_size_t n, nk_f64c_t *result);
544
+ /** @copydoc nk_vdot_f64c */
545
+ NK_PUBLIC void nk_vdot_f64c_rvv(nk_f64c_t const *a, nk_f64c_t const *b, nk_size_t n, nk_f64c_t *result);
546
+ #endif // NK_TARGET_RVV
547
+
548
+ #if NK_TARGET_RVVHALF
549
+ /** @copydoc nk_dot_f16 */
550
+ NK_PUBLIC void nk_dot_f16_rvvhalf(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n, nk_f32_t *result);
551
+ /** @copydoc nk_dot_e4m3 */
552
+ NK_PUBLIC void nk_dot_e4m3_rvvhalf(nk_e4m3_t const *a, nk_e4m3_t const *b, nk_size_t n, nk_f32_t *result);
553
+ /** @copydoc nk_dot_e5m2 */
554
+ NK_PUBLIC void nk_dot_e5m2_rvvhalf(nk_e5m2_t const *a, nk_e5m2_t const *b, nk_size_t n, nk_f32_t *result);
555
+ #endif // NK_TARGET_RVVHALF
556
+
557
+ #if NK_TARGET_RVVBF16
558
+ /** @copydoc nk_dot_bf16 */
559
+ NK_PUBLIC void nk_dot_bf16_rvvbf16(nk_bf16_t const *a, nk_bf16_t const *b, nk_size_t n, nk_f32_t *result);
560
+ /** @copydoc nk_dot_e4m3 */
561
+ NK_PUBLIC void nk_dot_e4m3_rvvbf16(nk_e4m3_t const *a, nk_e4m3_t const *b, nk_size_t n, nk_f32_t *result);
562
+ /** @copydoc nk_dot_e5m2 */
563
+ NK_PUBLIC void nk_dot_e5m2_rvvbf16(nk_e5m2_t const *a, nk_e5m2_t const *b, nk_size_t n, nk_f32_t *result);
564
+ #endif // NK_TARGET_RVVBF16
565
+
566
+ #if NK_TARGET_RVVBB
567
+ /** @copydoc nk_dot_u1 */
568
+ NK_PUBLIC void nk_dot_u1_rvvbb(nk_u1x8_t const *a, nk_u1x8_t const *b, nk_size_t n_bits, nk_u32_t *result);
569
+ #endif // NK_TARGET_RVVBB
570
+
571
+ #if NK_TARGET_V128RELAXED
572
+ /** @copydoc nk_dot_f32 */
573
+ NK_PUBLIC void nk_dot_f32_v128relaxed(nk_f32_t const *a, nk_f32_t const *b, nk_size_t n, nk_f64_t *result);
574
+ /** @copydoc nk_dot_f64 */
575
+ NK_PUBLIC void nk_dot_f64_v128relaxed(nk_f64_t const *a, nk_f64_t const *b, nk_size_t n, nk_f64_t *result);
576
+ /** @copydoc nk_dot_f16 */
577
+ NK_PUBLIC void nk_dot_f16_v128relaxed(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n, nk_f32_t *result);
578
+ /** @copydoc nk_dot_bf16 */
579
+ NK_PUBLIC void nk_dot_bf16_v128relaxed(nk_bf16_t const *a, nk_bf16_t const *b, nk_size_t n, nk_f32_t *result);
580
+ /** @copydoc nk_dot_i8 */
581
+ NK_PUBLIC void nk_dot_i8_v128relaxed(nk_i8_t const *a, nk_i8_t const *b, nk_size_t n, nk_i32_t *result);
582
+ /** @copydoc nk_dot_u8 */
583
+ NK_PUBLIC void nk_dot_u8_v128relaxed(nk_u8_t const *a, nk_u8_t const *b, nk_size_t n, nk_u32_t *result);
584
+ /** @copydoc nk_dot_e2m3 */
585
+ NK_PUBLIC void nk_dot_e2m3_v128relaxed(nk_e2m3_t const *a, nk_e2m3_t const *b, nk_size_t n, nk_f32_t *result);
586
+ /** @copydoc nk_dot_e3m2 */
587
+ NK_PUBLIC void nk_dot_e3m2_v128relaxed(nk_e3m2_t const *a, nk_e3m2_t const *b, nk_size_t n, nk_f32_t *result);
588
+ /** @copydoc nk_dot_u1 */
589
+ NK_PUBLIC void nk_dot_u1_v128relaxed(nk_u1x8_t const *a, nk_u1x8_t const *b, nk_size_t n_bits, nk_u32_t *result);
590
+ /** @copydoc nk_dot_e4m3 */
591
+ NK_PUBLIC void nk_dot_e4m3_v128relaxed(nk_e4m3_t const *a, nk_e4m3_t const *b, nk_size_t n, nk_f32_t *result);
592
+ /** @copydoc nk_dot_e5m2 */
593
+ NK_PUBLIC void nk_dot_e5m2_v128relaxed(nk_e5m2_t const *a, nk_e5m2_t const *b, nk_size_t n, nk_f32_t *result);
594
+ /** @copydoc nk_dot_i4 */
595
+ NK_PUBLIC void nk_dot_i4_v128relaxed(nk_i4x2_t const *a, nk_i4x2_t const *b, nk_size_t n, nk_i32_t *result);
596
+ /** @copydoc nk_dot_u4 */
597
+ NK_PUBLIC void nk_dot_u4_v128relaxed(nk_u4x2_t const *a, nk_u4x2_t const *b, nk_size_t n, nk_u32_t *result);
598
+ /** @copydoc nk_dot_f32c */
599
+ NK_PUBLIC void nk_dot_f32c_v128relaxed(nk_f32c_t const *a, nk_f32c_t const *b, nk_size_t n, nk_f64c_t *result);
600
+ /** @copydoc nk_vdot_f32c */
601
+ NK_PUBLIC void nk_vdot_f32c_v128relaxed(nk_f32c_t const *a, nk_f32c_t const *b, nk_size_t n, nk_f64c_t *result);
602
+ /** @copydoc nk_dot_f64c */
603
+ NK_PUBLIC void nk_dot_f64c_v128relaxed(nk_f64c_t const *a, nk_f64c_t const *b, nk_size_t n, nk_f64c_t *result);
604
+ /** @copydoc nk_vdot_f64c */
605
+ NK_PUBLIC void nk_vdot_f64c_v128relaxed(nk_f64c_t const *a, nk_f64c_t const *b, nk_size_t n, nk_f64c_t *result);
606
+ #endif // NK_TARGET_V128RELAXED
607
+
608
+ /**
609
+ * @brief Returns the output dtype for dot products.
610
+ */
611
+ NK_INTERNAL nk_dtype_t nk_dot_output_dtype(nk_dtype_t dtype) { // maps an input dtype to the dtype used for the dot-product result
612
+ switch (dtype) {
613
+ case nk_f64_k: return nk_f64_k; // f64 keeps its own type
614
+ case nk_f32_k: return nk_f64_k; // f32 input -> f64 output
615
+ case nk_f16_k: return nk_f32_k; // f16, bf16 and the e4m3/e5m2/e2m3/e3m2 dtypes all map to f32
616
+ case nk_bf16_k: return nk_f32_k;
617
+ case nk_e4m3_k: return nk_f32_k;
618
+ case nk_e5m2_k: return nk_f32_k;
619
+ case nk_e2m3_k: return nk_f32_k;
620
+ case nk_e3m2_k: return nk_f32_k;
621
+ case nk_f64c_k: return nk_f64c_k; // the "c" dtypes mirror their non-"c" counterparts above
622
+ case nk_f32c_k: return nk_f64c_k;
623
+ case nk_f16c_k: return nk_f32c_k;
624
+ case nk_bf16c_k: return nk_f32c_k;
625
+ case nk_i8_k: return nk_i32_k; // signed integer inputs (i8, i4) -> i32
626
+ case nk_u8_k: return nk_u32_k; // unsigned integer inputs (u8, u4, u1) -> u32
627
+ case nk_i4_k: return nk_i32_k;
628
+ case nk_u4_k: return nk_u32_k;
629
+ case nk_u1_k: return nk_u32_k;
630
+ default: return nk_dtype_unknown_k; // any dtype not listed above
631
+ }
632
+ }
633
+
634
+ #if defined(__cplusplus)
635
+ } // extern "C"
636
+ #endif
637
+
638
+ #include "numkong/dot/serial.h"
639
+ #include "numkong/dot/neon.h"
640
+ #include "numkong/dot/neonsdot.h"
641
+ #include "numkong/dot/neonhalf.h"
642
+ #include "numkong/dot/neonfhm.h"
643
+ #include "numkong/dot/neonbfdot.h"
644
+ #include "numkong/dot/sve.h"
645
+ #include "numkong/dot/svehalf.h"
646
+ #include "numkong/dot/svebfdot.h"
647
+ #include "numkong/dot/haswell.h"
648
+ #include "numkong/dot/skylake.h"
649
+ #include "numkong/dot/icelake.h"
650
+ #include "numkong/dot/genoa.h"
651
+ #include "numkong/dot/sapphire.h"
652
+ #include "numkong/dot/alder.h"
653
+ #include "numkong/dot/sierra.h"
654
+ #include "numkong/dot/rvv.h"
655
+ #include "numkong/dot/rvvbb.h"
656
+ #include "numkong/dot/rvvhalf.h"
657
+ #include "numkong/dot/rvvbf16.h"
658
+ #include "numkong/dot/v128relaxed.h"
659
+
660
+ #if defined(__cplusplus)
661
+ extern "C" {
662
+ #endif
663
+
664
+ #if !NK_DYNAMIC_DISPATCH
665
+
666
+ NK_PUBLIC void nk_dot_i8(nk_i8_t const *a, nk_i8_t const *b, nk_size_t n, nk_i32_t *result) { // static dispatch (section is under #if !NK_DYNAMIC_DISPATCH): exactly one call below is compiled in
667
+ #if NK_TARGET_V128RELAXED // priority: V128RELAXED > RVV > NEONSDOT > ICELAKE > SKYLAKE > SIERRA > ALDER > HASWELL
668
+ nk_dot_i8_v128relaxed(a, b, n, result);
669
+ #elif NK_TARGET_RVV
670
+ nk_dot_i8_rvv(a, b, n, result);
671
+ #elif NK_TARGET_NEONSDOT
672
+ nk_dot_i8_neonsdot(a, b, n, result);
673
+ #elif NK_TARGET_ICELAKE
674
+ nk_dot_i8_icelake(a, b, n, result);
675
+ #elif NK_TARGET_SKYLAKE
676
+ nk_dot_i8_skylake(a, b, n, result);
677
+ #elif NK_TARGET_SIERRA
678
+ nk_dot_i8_sierra(a, b, n, result);
679
+ #elif NK_TARGET_ALDER
680
+ nk_dot_i8_alder(a, b, n, result);
681
+ #elif NK_TARGET_HASWELL
682
+ nk_dot_i8_haswell(a, b, n, result);
683
+ #else // none of the targets above is enabled: use the serial kernel
684
+ nk_dot_i8_serial(a, b, n, result);
685
+ #endif
686
+ }
687
+
688
+ NK_PUBLIC void nk_dot_u8(nk_u8_t const *a, nk_u8_t const *b, nk_size_t n, nk_u32_t *result) { // static dispatch (section is under #if !NK_DYNAMIC_DISPATCH): exactly one call below is compiled in
689
+ #if NK_TARGET_V128RELAXED // priority: V128RELAXED > RVV > NEONSDOT > ICELAKE > SKYLAKE > SIERRA > ALDER > HASWELL
690
+ nk_dot_u8_v128relaxed(a, b, n, result);
691
+ #elif NK_TARGET_RVV
692
+ nk_dot_u8_rvv(a, b, n, result);
693
+ #elif NK_TARGET_NEONSDOT
694
+ nk_dot_u8_neonsdot(a, b, n, result);
695
+ #elif NK_TARGET_ICELAKE
696
+ nk_dot_u8_icelake(a, b, n, result);
697
+ #elif NK_TARGET_SKYLAKE
698
+ nk_dot_u8_skylake(a, b, n, result);
699
+ #elif NK_TARGET_SIERRA
700
+ nk_dot_u8_sierra(a, b, n, result);
701
+ #elif NK_TARGET_ALDER
702
+ nk_dot_u8_alder(a, b, n, result);
703
+ #elif NK_TARGET_HASWELL
704
+ nk_dot_u8_haswell(a, b, n, result);
705
+ #else // none of the targets above is enabled: use the serial kernel
706
+ nk_dot_u8_serial(a, b, n, result);
707
+ #endif
708
+ }
709
+
710
+ NK_PUBLIC void nk_dot_i4(nk_i4x2_t const *a, nk_i4x2_t const *b, nk_size_t n, nk_i32_t *result) { // static dispatch (section is under #if !NK_DYNAMIC_DISPATCH): exactly one call below is compiled in
711
+ #if NK_TARGET_ICELAKE // priority: ICELAKE > NEONSDOT > RVV > HASWELL > V128RELAXED
712
+ nk_dot_i4_icelake(a, b, n, result);
713
+ #elif NK_TARGET_NEONSDOT
714
+ nk_dot_i4_neonsdot(a, b, n, result);
715
+ #elif NK_TARGET_RVV
716
+ nk_dot_i4_rvv(a, b, n, result);
717
+ #elif NK_TARGET_HASWELL
718
+ nk_dot_i4_haswell(a, b, n, result);
719
+ #elif NK_TARGET_V128RELAXED
720
+ nk_dot_i4_v128relaxed(a, b, n, result);
721
+ #else // none of the targets above is enabled: use the serial kernel
722
+ nk_dot_i4_serial(a, b, n, result);
723
+ #endif
724
+ }
725
+
726
+ NK_PUBLIC void nk_dot_u4(nk_u4x2_t const *a, nk_u4x2_t const *b, nk_size_t n, nk_u32_t *result) { // static dispatch (section is under #if !NK_DYNAMIC_DISPATCH): exactly one call below is compiled in
727
+ #if NK_TARGET_ICELAKE // priority: ICELAKE > NEONSDOT > RVV > HASWELL > V128RELAXED
728
+ nk_dot_u4_icelake(a, b, n, result);
729
+ #elif NK_TARGET_NEONSDOT
730
+ nk_dot_u4_neonsdot(a, b, n, result);
731
+ #elif NK_TARGET_RVV
732
+ nk_dot_u4_rvv(a, b, n, result);
733
+ #elif NK_TARGET_HASWELL
734
+ nk_dot_u4_haswell(a, b, n, result);
735
+ #elif NK_TARGET_V128RELAXED
736
+ nk_dot_u4_v128relaxed(a, b, n, result);
737
+ #else // none of the targets above is enabled: use the serial kernel
738
+ nk_dot_u4_serial(a, b, n, result);
739
+ #endif
740
+ }
741
+
742
+ NK_PUBLIC void nk_dot_u1(nk_u1x8_t const *a, nk_u1x8_t const *b, nk_size_t n_bits, nk_u32_t *result) { // static dispatch (section is under #if !NK_DYNAMIC_DISPATCH); n_bits is forwarded unchanged to the one call compiled in
743
+ #if NK_TARGET_ICELAKE // priority: ICELAKE > HASWELL > V128RELAXED > RVVBB > RVV > NEON
744
+ nk_dot_u1_icelake(a, b, n_bits, result);
745
+ #elif NK_TARGET_HASWELL
746
+ nk_dot_u1_haswell(a, b, n_bits, result);
747
+ #elif NK_TARGET_V128RELAXED
748
+ nk_dot_u1_v128relaxed(a, b, n_bits, result);
749
+ #elif NK_TARGET_RVVBB
750
+ nk_dot_u1_rvvbb(a, b, n_bits, result);
751
+ #elif NK_TARGET_RVV
752
+ nk_dot_u1_rvv(a, b, n_bits, result);
753
+ #elif NK_TARGET_NEON
754
+ nk_dot_u1_neon(a, b, n_bits, result);
755
+ #else // none of the targets above is enabled: use the serial kernel
756
+ nk_dot_u1_serial(a, b, n_bits, result);
757
+ #endif
758
+ }
759
+
760
+ NK_PUBLIC void nk_dot_f16(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n, nk_f32_t *result) { // static dispatch (section is under #if !NK_DYNAMIC_DISPATCH): exactly one call below is compiled in
761
+ #if NK_TARGET_V128RELAXED // priority: V128RELAXED > RVVHALF > RVV > SVEHALF > NEONFHM > NEONHALF > NEON > SKYLAKE > HASWELL
762
+ nk_dot_f16_v128relaxed(a, b, n, result);
763
+ #elif NK_TARGET_RVVHALF
764
+ nk_dot_f16_rvvhalf(a, b, n, result);
765
+ #elif NK_TARGET_RVV
766
+ nk_dot_f16_rvv(a, b, n, result);
767
+ #elif NK_TARGET_SVEHALF
768
+ nk_dot_f16_svehalf(a, b, n, result);
769
+ #elif NK_TARGET_NEONFHM
770
+ nk_dot_f16_neonfhm(a, b, n, result);
771
+ #elif NK_TARGET_NEONHALF
772
+ nk_dot_f16_neonhalf(a, b, n, result);
773
+ #elif NK_TARGET_NEON
774
+ nk_dot_f16_neon(a, b, n, result);
775
+ #elif NK_TARGET_SKYLAKE
776
+ nk_dot_f16_skylake(a, b, n, result);
777
+ #elif NK_TARGET_HASWELL
778
+ nk_dot_f16_haswell(a, b, n, result);
779
+ #else // none of the targets above is enabled: use the serial kernel
780
+ nk_dot_f16_serial(a, b, n, result);
781
+ #endif
782
+ }
783
+
784
+ NK_PUBLIC void nk_dot_bf16(nk_bf16_t const *a, nk_bf16_t const *b, nk_size_t n, nk_f32_t *result) { // static dispatch (section is under #if !NK_DYNAMIC_DISPATCH): exactly one call below is compiled in
785
+ #if NK_TARGET_V128RELAXED // priority: V128RELAXED > GENOA > RVVBF16 > RVV > SKYLAKE > HASWELL > SVEBFDOT > NEONBFDOT > NEON
786
+ nk_dot_bf16_v128relaxed(a, b, n, result);
787
+ #elif NK_TARGET_GENOA
788
+ nk_dot_bf16_genoa(a, b, n, result);
789
+ #elif NK_TARGET_RVVBF16
790
+ nk_dot_bf16_rvvbf16(a, b, n, result);
791
+ #elif NK_TARGET_RVV
792
+ nk_dot_bf16_rvv(a, b, n, result);
793
+ #elif NK_TARGET_SKYLAKE
794
+ nk_dot_bf16_skylake(a, b, n, result);
795
+ #elif NK_TARGET_HASWELL
796
+ nk_dot_bf16_haswell(a, b, n, result);
797
+ #elif NK_TARGET_SVEBFDOT
798
+ nk_dot_bf16_svebfdot(a, b, n, result);
799
+ #elif NK_TARGET_NEONBFDOT
800
+ nk_dot_bf16_neonbfdot(a, b, n, result);
801
+ #elif NK_TARGET_NEON
802
+ nk_dot_bf16_neon(a, b, n, result);
803
+ #else // none of the targets above is enabled: use the serial kernel
804
+ nk_dot_bf16_serial(a, b, n, result);
805
+ #endif
806
+ }
807
+
808
+ NK_PUBLIC void nk_dot_e4m3(nk_e4m3_t const *a, nk_e4m3_t const *b, nk_size_t n, nk_f32_t *result) { // static dispatch (section is under #if !NK_DYNAMIC_DISPATCH): exactly one call below is compiled in
809
+ #if NK_TARGET_GENOA // priority: GENOA > NEONBFDOT > NEONFHM > RVVHALF > RVVBF16 > RVV > V128RELAXED > SKYLAKE > HASWELL > NEON
810
+ nk_dot_e4m3_genoa(a, b, n, result);
811
+ #elif NK_TARGET_NEONBFDOT
812
+ nk_dot_e4m3_neonbfdot(a, b, n, result);
813
+ #elif NK_TARGET_NEONFHM
814
+ nk_dot_e4m3_neonfhm(a, b, n, result);
815
+ #elif NK_TARGET_RVVHALF
816
+ nk_dot_e4m3_rvvhalf(a, b, n, result);
817
+ #elif NK_TARGET_RVVBF16
818
+ nk_dot_e4m3_rvvbf16(a, b, n, result);
819
+ #elif NK_TARGET_RVV
820
+ nk_dot_e4m3_rvv(a, b, n, result);
821
+ #elif NK_TARGET_V128RELAXED
822
+ nk_dot_e4m3_v128relaxed(a, b, n, result);
823
+ #elif NK_TARGET_SKYLAKE
824
+ nk_dot_e4m3_skylake(a, b, n, result);
825
+ #elif NK_TARGET_HASWELL
826
+ nk_dot_e4m3_haswell(a, b, n, result);
827
+ #elif NK_TARGET_NEON
828
+ nk_dot_e4m3_neon(a, b, n, result);
829
+ #else // none of the targets above is enabled: use the serial kernel
830
+ nk_dot_e4m3_serial(a, b, n, result);
831
+ #endif
832
+ }
833
+
834
+ NK_PUBLIC void nk_dot_e5m2(nk_e5m2_t const *a, nk_e5m2_t const *b, nk_size_t n, nk_f32_t *result) { // static dispatch (section is under #if !NK_DYNAMIC_DISPATCH): exactly one call below is compiled in
835
+ #if NK_TARGET_GENOA // priority: GENOA > NEONBFDOT > NEONFHM > RVVHALF > RVVBF16 > RVV > V128RELAXED > SKYLAKE > HASWELL > NEON
836
+ nk_dot_e5m2_genoa(a, b, n, result);
837
+ #elif NK_TARGET_NEONBFDOT
838
+ nk_dot_e5m2_neonbfdot(a, b, n, result);
839
+ #elif NK_TARGET_NEONFHM
840
+ nk_dot_e5m2_neonfhm(a, b, n, result);
841
+ #elif NK_TARGET_RVVHALF
842
+ nk_dot_e5m2_rvvhalf(a, b, n, result);
843
+ #elif NK_TARGET_RVVBF16
844
+ nk_dot_e5m2_rvvbf16(a, b, n, result);
845
+ #elif NK_TARGET_RVV
846
+ nk_dot_e5m2_rvv(a, b, n, result);
847
+ #elif NK_TARGET_V128RELAXED
848
+ nk_dot_e5m2_v128relaxed(a, b, n, result);
849
+ #elif NK_TARGET_SKYLAKE
850
+ nk_dot_e5m2_skylake(a, b, n, result);
851
+ #elif NK_TARGET_HASWELL
852
+ nk_dot_e5m2_haswell(a, b, n, result);
853
+ #elif NK_TARGET_NEON
854
+ nk_dot_e5m2_neon(a, b, n, result);
855
+ #else // none of the targets above is enabled: use the serial kernel
856
+ nk_dot_e5m2_serial(a, b, n, result);
857
+ #endif
858
+ }
859
+
860
+ NK_PUBLIC void nk_dot_e2m3(nk_e2m3_t const *a, nk_e2m3_t const *b, nk_size_t n, nk_f32_t *result) { // static dispatch (section is under #if !NK_DYNAMIC_DISPATCH): exactly one call below is compiled in
861
+ #if NK_TARGET_ICELAKE // priority: ICELAKE > SKYLAKE > SIERRA > ALDER > RVV > HASWELL > NEONSDOT > NEON > V128RELAXED
862
+ nk_dot_e2m3_icelake(a, b, n, result);
863
+ #elif NK_TARGET_SKYLAKE
864
+ nk_dot_e2m3_skylake(a, b, n, result);
865
+ #elif NK_TARGET_SIERRA
866
+ nk_dot_e2m3_sierra(a, b, n, result);
867
+ #elif NK_TARGET_ALDER
868
+ nk_dot_e2m3_alder(a, b, n, result);
869
+ #elif NK_TARGET_RVV
870
+ nk_dot_e2m3_rvv(a, b, n, result);
871
+ #elif NK_TARGET_HASWELL
872
+ nk_dot_e2m3_haswell(a, b, n, result);
873
+ #elif NK_TARGET_NEONSDOT
874
+ nk_dot_e2m3_neonsdot(a, b, n, result);
875
+ #elif NK_TARGET_NEON
876
+ nk_dot_e2m3_neon(a, b, n, result);
877
+ #elif NK_TARGET_V128RELAXED
878
+ nk_dot_e2m3_v128relaxed(a, b, n, result);
879
+ #else // none of the targets above is enabled: use the serial kernel
880
+ nk_dot_e2m3_serial(a, b, n, result);
881
+ #endif
882
+ }
883
+
884
+ NK_PUBLIC void nk_dot_e3m2(nk_e3m2_t const *a, nk_e3m2_t const *b, nk_size_t n, nk_f32_t *result) { // static dispatch (section is under #if !NK_DYNAMIC_DISPATCH): exactly one call below is compiled in
885
+ #if NK_TARGET_ICELAKE // priority: ICELAKE > NEONSDOT > V128RELAXED > RVV > SKYLAKE > HASWELL > NEON
886
+ nk_dot_e3m2_icelake(a, b, n, result);
887
+ #elif NK_TARGET_NEONSDOT
888
+ nk_dot_e3m2_neonsdot(a, b, n, result);
889
+ #elif NK_TARGET_V128RELAXED
890
+ nk_dot_e3m2_v128relaxed(a, b, n, result);
891
+ #elif NK_TARGET_RVV
892
+ nk_dot_e3m2_rvv(a, b, n, result);
893
+ #elif NK_TARGET_SKYLAKE
894
+ nk_dot_e3m2_skylake(a, b, n, result);
895
+ #elif NK_TARGET_HASWELL
896
+ nk_dot_e3m2_haswell(a, b, n, result);
897
+ #elif NK_TARGET_NEON
898
+ nk_dot_e3m2_neon(a, b, n, result);
899
+ #else // none of the targets above is enabled: use the serial kernel
900
+ nk_dot_e3m2_serial(a, b, n, result);
901
+ #endif
902
+ }
903
+
904
+ NK_PUBLIC void nk_dot_f32(nk_f32_t const *a, nk_f32_t const *b, nk_size_t n, nk_f64_t *result) { // static dispatch (section is under #if !NK_DYNAMIC_DISPATCH); note the result pointer is nk_f64_t, not nk_f32_t
905
+ #if NK_TARGET_V128RELAXED // priority: V128RELAXED > RVV > SVE > NEON > SKYLAKE > HASWELL
906
+ nk_dot_f32_v128relaxed(a, b, n, result);
907
+ #elif NK_TARGET_RVV
908
+ nk_dot_f32_rvv(a, b, n, result);
909
+ #elif NK_TARGET_SVE
910
+ nk_dot_f32_sve(a, b, n, result);
911
+ #elif NK_TARGET_NEON
912
+ nk_dot_f32_neon(a, b, n, result);
913
+ #elif NK_TARGET_SKYLAKE
914
+ nk_dot_f32_skylake(a, b, n, result);
915
+ #elif NK_TARGET_HASWELL
916
+ nk_dot_f32_haswell(a, b, n, result);
917
+ #else // none of the targets above is enabled: use the serial kernel
918
+ nk_dot_f32_serial(a, b, n, result);
919
+ #endif
920
+ }
921
+
922
+ NK_PUBLIC void nk_dot_f64(nk_f64_t const *a, nk_f64_t const *b, nk_size_t n, nk_f64_t *result) { // static dispatch (section is under #if !NK_DYNAMIC_DISPATCH): exactly one call below is compiled in
923
+ #if NK_TARGET_V128RELAXED // priority: V128RELAXED > RVV > SVE > NEON > SKYLAKE > HASWELL
924
+ nk_dot_f64_v128relaxed(a, b, n, result);
925
+ #elif NK_TARGET_RVV
926
+ nk_dot_f64_rvv(a, b, n, result);
927
+ #elif NK_TARGET_SVE
928
+ nk_dot_f64_sve(a, b, n, result);
929
+ #elif NK_TARGET_NEON
930
+ nk_dot_f64_neon(a, b, n, result);
931
+ #elif NK_TARGET_SKYLAKE
932
+ nk_dot_f64_skylake(a, b, n, result);
933
+ #elif NK_TARGET_HASWELL
934
+ nk_dot_f64_haswell(a, b, n, result);
935
+ #else // none of the targets above is enabled: use the serial kernel
936
+ nk_dot_f64_serial(a, b, n, result);
937
+ #endif
938
+ }
939
+
940
+ NK_PUBLIC void nk_dot_f16c(nk_f16c_t const *a, nk_f16c_t const *b, nk_size_t n, nk_f32c_t *result) { // static dispatch (section is under #if !NK_DYNAMIC_DISPATCH): exactly one call below is compiled in
941
+ #if NK_TARGET_SVEHALF // priority: SVEHALF > NEONFHM > NEONHALF > HASWELL
942
+ nk_dot_f16c_svehalf(a, b, n, result);
943
+ #elif NK_TARGET_NEONFHM
944
+ nk_dot_f16c_neonfhm(a, b, n, result);
945
+ #elif NK_TARGET_NEONHALF
946
+ nk_dot_f16c_neonhalf(a, b, n, result);
947
+ #elif NK_TARGET_HASWELL
948
+ nk_dot_f16c_haswell(a, b, n, result);
949
+ #else // none of the targets above is enabled: use the serial kernel
950
+ nk_dot_f16c_serial(a, b, n, result);
951
+ #endif
952
+ }
953
+
954
+ NK_PUBLIC void nk_dot_bf16c(nk_bf16c_t const *a, nk_bf16c_t const *b, nk_size_t n, nk_f32c_t *result) { // static dispatch (section is under #if !NK_DYNAMIC_DISPATCH): exactly one call below is compiled in
955
+ #if NK_TARGET_GENOA // priority: GENOA > NEONBFDOT > HASWELL
956
+ nk_dot_bf16c_genoa(a, b, n, result);
957
+ #elif NK_TARGET_NEONBFDOT
958
+ nk_dot_bf16c_neonbfdot(a, b, n, result);
959
+ #elif NK_TARGET_HASWELL
960
+ nk_dot_bf16c_haswell(a, b, n, result);
961
+ #else // none of the targets above is enabled: use the serial kernel
962
+ nk_dot_bf16c_serial(a, b, n, result);
963
+ #endif
964
+ }
965
+
966
+ NK_PUBLIC void nk_dot_f32c(nk_f32c_t const *a, nk_f32c_t const *b, nk_size_t n, nk_f64c_t *result) { // static dispatch (section is under #if !NK_DYNAMIC_DISPATCH); note the result pointer is nk_f64c_t, not nk_f32c_t
967
+ #if NK_TARGET_SVE // priority: SVE > NEON > RVV > SKYLAKE > HASWELL > V128RELAXED
968
+ nk_dot_f32c_sve(a, b, n, result);
969
+ #elif NK_TARGET_NEON
970
+ nk_dot_f32c_neon(a, b, n, result);
971
+ #elif NK_TARGET_RVV
972
+ nk_dot_f32c_rvv(a, b, n, result);
973
+ #elif NK_TARGET_SKYLAKE
974
+ nk_dot_f32c_skylake(a, b, n, result);
975
+ #elif NK_TARGET_HASWELL
976
+ nk_dot_f32c_haswell(a, b, n, result);
977
+ #elif NK_TARGET_V128RELAXED
978
+ nk_dot_f32c_v128relaxed(a, b, n, result);
979
+ #else // none of the targets above is enabled: use the serial kernel
980
+ nk_dot_f32c_serial(a, b, n, result);
981
+ #endif
982
+ }
983
+
984
+ NK_PUBLIC void nk_dot_f64c(nk_f64c_t const *a, nk_f64c_t const *b, nk_size_t n, nk_f64c_t *result) { // Compile-time dispatch for nk_f64c_t inputs; the result pointer has the same nk_f64c_t type.
985
+ #if NK_TARGET_SVE // Branches are tried top to bottom; the first non-zero flag selects the backend.
986
+ nk_dot_f64c_sve(a, b, n, result);
987
+ #elif NK_TARGET_NEON
988
+ nk_dot_f64c_neon(a, b, n, result);
989
+ #elif NK_TARGET_RVV
990
+ nk_dot_f64c_rvv(a, b, n, result);
991
+ #elif NK_TARGET_SKYLAKE // Checked before HASWELL, so it wins when both flags are set.
992
+ nk_dot_f64c_skylake(a, b, n, result);
993
+ #elif NK_TARGET_HASWELL
994
+ nk_dot_f64c_haswell(a, b, n, result);
995
+ #elif NK_TARGET_V128RELAXED
996
+ nk_dot_f64c_v128relaxed(a, b, n, result);
997
+ #else // None of the flags above is set: use the serial implementation.
998
+ nk_dot_f64c_serial(a, b, n, result);
999
+ #endif
1000
+ }
1001
+
1002
+ NK_PUBLIC void nk_vdot_f16c(nk_f16c_t const *a, nk_f16c_t const *b, nk_size_t n, nk_f32c_t *result) { // Compile-time dispatch to an nk_vdot_f16c_* backend; presumably the conjugating dot product - TODO confirm against the kernel docs.
1003
+ #if NK_TARGET_SVEHALF // Same branch order as nk_dot_f16c; the first non-zero flag selects the backend.
1004
+ nk_vdot_f16c_svehalf(a, b, n, result);
1005
+ #elif NK_TARGET_NEONFHM // Checked before NEONHALF, so it wins when both flags are set.
1006
+ nk_vdot_f16c_neonfhm(a, b, n, result);
1007
+ #elif NK_TARGET_NEONHALF
1008
+ nk_vdot_f16c_neonhalf(a, b, n, result);
1009
+ #elif NK_TARGET_HASWELL
1010
+ nk_vdot_f16c_haswell(a, b, n, result);
1011
+ #else // None of the flags above is set: use the serial implementation.
1012
+ nk_vdot_f16c_serial(a, b, n, result);
1013
+ #endif
1014
+ }
1015
+
1016
+ NK_PUBLIC void nk_vdot_bf16c(nk_bf16c_t const *a, nk_bf16c_t const *b, nk_size_t n, nk_f32c_t *result) { // Compile-time dispatch to an nk_vdot_bf16c_* backend; presumably the conjugating dot product - TODO confirm against the kernel docs.
1017
+ #if NK_TARGET_GENOA // Same branch order as nk_dot_bf16c; the first non-zero flag selects the backend.
1018
+ nk_vdot_bf16c_genoa(a, b, n, result);
1019
+ #elif NK_TARGET_NEONBFDOT
1020
+ nk_vdot_bf16c_neonbfdot(a, b, n, result);
1021
+ #elif NK_TARGET_HASWELL
1022
+ nk_vdot_bf16c_haswell(a, b, n, result);
1023
+ #else // None of the flags above is set: use the serial implementation.
1024
+ nk_vdot_bf16c_serial(a, b, n, result);
1025
+ #endif
1026
+ }
1027
+
1028
+ NK_PUBLIC void nk_vdot_f32c(nk_f32c_t const *a, nk_f32c_t const *b, nk_size_t n, nk_f64c_t *result) { // Compile-time dispatch to an nk_vdot_f32c_* backend; presumably the conjugating dot product - TODO confirm. Result pointer type is nk_f64c_t.
1029
+ #if NK_TARGET_SVE // Same branch order as nk_dot_f32c; the first non-zero flag selects the backend.
1030
+ nk_vdot_f32c_sve(a, b, n, result);
1031
+ #elif NK_TARGET_NEON
1032
+ nk_vdot_f32c_neon(a, b, n, result);
1033
+ #elif NK_TARGET_RVV
1034
+ nk_vdot_f32c_rvv(a, b, n, result);
1035
+ #elif NK_TARGET_SKYLAKE // Checked before HASWELL, so it wins when both flags are set.
1036
+ nk_vdot_f32c_skylake(a, b, n, result);
1037
+ #elif NK_TARGET_HASWELL
1038
+ nk_vdot_f32c_haswell(a, b, n, result);
1039
+ #elif NK_TARGET_V128RELAXED
1040
+ nk_vdot_f32c_v128relaxed(a, b, n, result);
1041
+ #else // None of the flags above is set: use the serial implementation.
1042
+ nk_vdot_f32c_serial(a, b, n, result);
1043
+ #endif
1044
+ }
1045
+
1046
+ NK_PUBLIC void nk_vdot_f64c(nk_f64c_t const *a, nk_f64c_t const *b, nk_size_t n, nk_f64c_t *result) { // Compile-time dispatch to an nk_vdot_f64c_* backend; presumably the conjugating dot product - TODO confirm against the kernel docs.
1047
+ #if NK_TARGET_SVE // Same branch order as nk_dot_f64c; the first non-zero flag selects the backend.
1048
+ nk_vdot_f64c_sve(a, b, n, result);
1049
+ #elif NK_TARGET_NEON
1050
+ nk_vdot_f64c_neon(a, b, n, result);
1051
+ #elif NK_TARGET_RVV
1052
+ nk_vdot_f64c_rvv(a, b, n, result);
1053
+ #elif NK_TARGET_SKYLAKE // Checked before HASWELL, so it wins when both flags are set.
1054
+ nk_vdot_f64c_skylake(a, b, n, result);
1055
+ #elif NK_TARGET_HASWELL
1056
+ nk_vdot_f64c_haswell(a, b, n, result);
1057
+ #elif NK_TARGET_V128RELAXED
1058
+ nk_vdot_f64c_v128relaxed(a, b, n, result);
1059
+ #else // None of the flags above is set: use the serial implementation.
1060
+ nk_vdot_f64c_serial(a, b, n, result);
1061
+ #endif
1062
+ }
1063
+
1064
+ #endif // !NK_DYNAMIC_DISPATCH
1065
+
1066
+ #if defined(__cplusplus)
1067
+ } // extern "C"
1068
+ #endif
1069
+
1070
+ #endif