numkong 7.0.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (294) hide show
  1. package/LICENSE +201 -0
  2. package/README.md +495 -0
  3. package/binding.gyp +540 -0
  4. package/c/dispatch.h +512 -0
  5. package/c/dispatch_bf16.c +389 -0
  6. package/c/dispatch_bf16c.c +52 -0
  7. package/c/dispatch_e2m3.c +263 -0
  8. package/c/dispatch_e3m2.c +243 -0
  9. package/c/dispatch_e4m3.c +276 -0
  10. package/c/dispatch_e5m2.c +272 -0
  11. package/c/dispatch_f16.c +376 -0
  12. package/c/dispatch_f16c.c +58 -0
  13. package/c/dispatch_f32.c +378 -0
  14. package/c/dispatch_f32c.c +99 -0
  15. package/c/dispatch_f64.c +296 -0
  16. package/c/dispatch_f64c.c +98 -0
  17. package/c/dispatch_i16.c +96 -0
  18. package/c/dispatch_i32.c +89 -0
  19. package/c/dispatch_i4.c +150 -0
  20. package/c/dispatch_i64.c +86 -0
  21. package/c/dispatch_i8.c +289 -0
  22. package/c/dispatch_other.c +330 -0
  23. package/c/dispatch_u1.c +148 -0
  24. package/c/dispatch_u16.c +124 -0
  25. package/c/dispatch_u32.c +118 -0
  26. package/c/dispatch_u4.c +150 -0
  27. package/c/dispatch_u64.c +102 -0
  28. package/c/dispatch_u8.c +303 -0
  29. package/c/numkong.c +950 -0
  30. package/include/README.md +573 -0
  31. package/include/module.modulemap +129 -0
  32. package/include/numkong/attention/sapphireamx.h +1361 -0
  33. package/include/numkong/attention/sme.h +2066 -0
  34. package/include/numkong/attention.h +49 -0
  35. package/include/numkong/capabilities.h +748 -0
  36. package/include/numkong/cast/README.md +262 -0
  37. package/include/numkong/cast/haswell.h +975 -0
  38. package/include/numkong/cast/icelake.h +470 -0
  39. package/include/numkong/cast/neon.h +1192 -0
  40. package/include/numkong/cast/rvv.h +1021 -0
  41. package/include/numkong/cast/sapphire.h +262 -0
  42. package/include/numkong/cast/serial.h +2262 -0
  43. package/include/numkong/cast/skylake.h +856 -0
  44. package/include/numkong/cast/v128relaxed.h +180 -0
  45. package/include/numkong/cast.h +230 -0
  46. package/include/numkong/curved/README.md +223 -0
  47. package/include/numkong/curved/genoa.h +182 -0
  48. package/include/numkong/curved/haswell.h +276 -0
  49. package/include/numkong/curved/neon.h +205 -0
  50. package/include/numkong/curved/neonbfdot.h +212 -0
  51. package/include/numkong/curved/neonhalf.h +212 -0
  52. package/include/numkong/curved/rvv.h +305 -0
  53. package/include/numkong/curved/serial.h +207 -0
  54. package/include/numkong/curved/skylake.h +457 -0
  55. package/include/numkong/curved/smef64.h +506 -0
  56. package/include/numkong/curved.h +517 -0
  57. package/include/numkong/curved.hpp +144 -0
  58. package/include/numkong/dot/README.md +425 -0
  59. package/include/numkong/dot/alder.h +563 -0
  60. package/include/numkong/dot/genoa.h +315 -0
  61. package/include/numkong/dot/haswell.h +1688 -0
  62. package/include/numkong/dot/icelake.h +883 -0
  63. package/include/numkong/dot/neon.h +818 -0
  64. package/include/numkong/dot/neonbfdot.h +244 -0
  65. package/include/numkong/dot/neonfhm.h +360 -0
  66. package/include/numkong/dot/neonhalf.h +198 -0
  67. package/include/numkong/dot/neonsdot.h +508 -0
  68. package/include/numkong/dot/rvv.h +714 -0
  69. package/include/numkong/dot/rvvbb.h +72 -0
  70. package/include/numkong/dot/rvvbf16.h +123 -0
  71. package/include/numkong/dot/rvvhalf.h +129 -0
  72. package/include/numkong/dot/sapphire.h +141 -0
  73. package/include/numkong/dot/serial.h +838 -0
  74. package/include/numkong/dot/sierra.h +405 -0
  75. package/include/numkong/dot/skylake.h +1084 -0
  76. package/include/numkong/dot/sve.h +379 -0
  77. package/include/numkong/dot/svebfdot.h +74 -0
  78. package/include/numkong/dot/svehalf.h +123 -0
  79. package/include/numkong/dot/v128relaxed.h +1258 -0
  80. package/include/numkong/dot.h +1070 -0
  81. package/include/numkong/dot.hpp +94 -0
  82. package/include/numkong/dots/README.md +496 -0
  83. package/include/numkong/dots/alder.h +114 -0
  84. package/include/numkong/dots/genoa.h +94 -0
  85. package/include/numkong/dots/haswell.h +295 -0
  86. package/include/numkong/dots/icelake.h +171 -0
  87. package/include/numkong/dots/neon.h +120 -0
  88. package/include/numkong/dots/neonbfdot.h +58 -0
  89. package/include/numkong/dots/neonfhm.h +94 -0
  90. package/include/numkong/dots/neonhalf.h +57 -0
  91. package/include/numkong/dots/neonsdot.h +108 -0
  92. package/include/numkong/dots/rvv.h +2486 -0
  93. package/include/numkong/dots/sapphireamx.h +3973 -0
  94. package/include/numkong/dots/serial.h +2844 -0
  95. package/include/numkong/dots/sierra.h +97 -0
  96. package/include/numkong/dots/skylake.h +196 -0
  97. package/include/numkong/dots/sme.h +5372 -0
  98. package/include/numkong/dots/smebi32.h +461 -0
  99. package/include/numkong/dots/smef64.h +1318 -0
  100. package/include/numkong/dots/smehalf.h +47 -0
  101. package/include/numkong/dots/v128relaxed.h +294 -0
  102. package/include/numkong/dots.h +2804 -0
  103. package/include/numkong/dots.hpp +639 -0
  104. package/include/numkong/each/README.md +469 -0
  105. package/include/numkong/each/haswell.h +1658 -0
  106. package/include/numkong/each/icelake.h +272 -0
  107. package/include/numkong/each/neon.h +1104 -0
  108. package/include/numkong/each/neonbfdot.h +212 -0
  109. package/include/numkong/each/neonhalf.h +410 -0
  110. package/include/numkong/each/rvv.h +1121 -0
  111. package/include/numkong/each/sapphire.h +477 -0
  112. package/include/numkong/each/serial.h +260 -0
  113. package/include/numkong/each/skylake.h +1562 -0
  114. package/include/numkong/each.h +2146 -0
  115. package/include/numkong/each.hpp +434 -0
  116. package/include/numkong/geospatial/README.md +147 -0
  117. package/include/numkong/geospatial/haswell.h +593 -0
  118. package/include/numkong/geospatial/neon.h +571 -0
  119. package/include/numkong/geospatial/rvv.h +701 -0
  120. package/include/numkong/geospatial/serial.h +309 -0
  121. package/include/numkong/geospatial/skylake.h +577 -0
  122. package/include/numkong/geospatial/v128relaxed.h +613 -0
  123. package/include/numkong/geospatial.h +453 -0
  124. package/include/numkong/geospatial.hpp +235 -0
  125. package/include/numkong/matrix.hpp +336 -0
  126. package/include/numkong/maxsim/README.md +187 -0
  127. package/include/numkong/maxsim/alder.h +511 -0
  128. package/include/numkong/maxsim/genoa.h +115 -0
  129. package/include/numkong/maxsim/haswell.h +553 -0
  130. package/include/numkong/maxsim/icelake.h +480 -0
  131. package/include/numkong/maxsim/neonsdot.h +394 -0
  132. package/include/numkong/maxsim/sapphireamx.h +877 -0
  133. package/include/numkong/maxsim/serial.h +490 -0
  134. package/include/numkong/maxsim/sme.h +929 -0
  135. package/include/numkong/maxsim/v128relaxed.h +280 -0
  136. package/include/numkong/maxsim.h +571 -0
  137. package/include/numkong/maxsim.hpp +133 -0
  138. package/include/numkong/mesh/README.md +227 -0
  139. package/include/numkong/mesh/haswell.h +2235 -0
  140. package/include/numkong/mesh/neon.h +1329 -0
  141. package/include/numkong/mesh/neonbfdot.h +842 -0
  142. package/include/numkong/mesh/neonhalf.h +616 -0
  143. package/include/numkong/mesh/rvv.h +916 -0
  144. package/include/numkong/mesh/serial.h +742 -0
  145. package/include/numkong/mesh/skylake.h +1135 -0
  146. package/include/numkong/mesh/v128relaxed.h +1052 -0
  147. package/include/numkong/mesh.h +652 -0
  148. package/include/numkong/mesh.hpp +762 -0
  149. package/include/numkong/numkong.h +78 -0
  150. package/include/numkong/numkong.hpp +57 -0
  151. package/include/numkong/probability/README.md +173 -0
  152. package/include/numkong/probability/haswell.h +267 -0
  153. package/include/numkong/probability/neon.h +225 -0
  154. package/include/numkong/probability/rvv.h +409 -0
  155. package/include/numkong/probability/serial.h +169 -0
  156. package/include/numkong/probability/skylake.h +324 -0
  157. package/include/numkong/probability.h +383 -0
  158. package/include/numkong/probability.hpp +120 -0
  159. package/include/numkong/random.h +50 -0
  160. package/include/numkong/random.hpp +285 -0
  161. package/include/numkong/reduce/README.md +547 -0
  162. package/include/numkong/reduce/alder.h +632 -0
  163. package/include/numkong/reduce/genoa.h +201 -0
  164. package/include/numkong/reduce/haswell.h +3783 -0
  165. package/include/numkong/reduce/icelake.h +549 -0
  166. package/include/numkong/reduce/neon.h +3841 -0
  167. package/include/numkong/reduce/neonbfdot.h +353 -0
  168. package/include/numkong/reduce/neonfhm.h +665 -0
  169. package/include/numkong/reduce/neonhalf.h +157 -0
  170. package/include/numkong/reduce/neonsdot.h +357 -0
  171. package/include/numkong/reduce/rvv.h +3407 -0
  172. package/include/numkong/reduce/serial.h +757 -0
  173. package/include/numkong/reduce/sierra.h +338 -0
  174. package/include/numkong/reduce/skylake.h +3792 -0
  175. package/include/numkong/reduce/v128relaxed.h +2302 -0
  176. package/include/numkong/reduce.h +1597 -0
  177. package/include/numkong/reduce.hpp +633 -0
  178. package/include/numkong/scalar/README.md +89 -0
  179. package/include/numkong/scalar/haswell.h +113 -0
  180. package/include/numkong/scalar/neon.h +122 -0
  181. package/include/numkong/scalar/neonhalf.h +70 -0
  182. package/include/numkong/scalar/rvv.h +211 -0
  183. package/include/numkong/scalar/sapphire.h +63 -0
  184. package/include/numkong/scalar/serial.h +332 -0
  185. package/include/numkong/scalar/v128relaxed.h +56 -0
  186. package/include/numkong/scalar.h +683 -0
  187. package/include/numkong/set/README.md +179 -0
  188. package/include/numkong/set/haswell.h +334 -0
  189. package/include/numkong/set/icelake.h +485 -0
  190. package/include/numkong/set/neon.h +364 -0
  191. package/include/numkong/set/rvv.h +226 -0
  192. package/include/numkong/set/rvvbb.h +117 -0
  193. package/include/numkong/set/serial.h +174 -0
  194. package/include/numkong/set/sve.h +185 -0
  195. package/include/numkong/set/v128relaxed.h +240 -0
  196. package/include/numkong/set.h +457 -0
  197. package/include/numkong/set.hpp +114 -0
  198. package/include/numkong/sets/README.md +149 -0
  199. package/include/numkong/sets/haswell.h +63 -0
  200. package/include/numkong/sets/icelake.h +66 -0
  201. package/include/numkong/sets/neon.h +61 -0
  202. package/include/numkong/sets/serial.h +43 -0
  203. package/include/numkong/sets/smebi32.h +1099 -0
  204. package/include/numkong/sets/v128relaxed.h +58 -0
  205. package/include/numkong/sets.h +339 -0
  206. package/include/numkong/sparse/README.md +156 -0
  207. package/include/numkong/sparse/icelake.h +463 -0
  208. package/include/numkong/sparse/neon.h +288 -0
  209. package/include/numkong/sparse/serial.h +117 -0
  210. package/include/numkong/sparse/sve2.h +507 -0
  211. package/include/numkong/sparse/turin.h +322 -0
  212. package/include/numkong/sparse.h +363 -0
  213. package/include/numkong/sparse.hpp +113 -0
  214. package/include/numkong/spatial/README.md +435 -0
  215. package/include/numkong/spatial/alder.h +607 -0
  216. package/include/numkong/spatial/genoa.h +290 -0
  217. package/include/numkong/spatial/haswell.h +960 -0
  218. package/include/numkong/spatial/icelake.h +586 -0
  219. package/include/numkong/spatial/neon.h +773 -0
  220. package/include/numkong/spatial/neonbfdot.h +165 -0
  221. package/include/numkong/spatial/neonhalf.h +118 -0
  222. package/include/numkong/spatial/neonsdot.h +261 -0
  223. package/include/numkong/spatial/rvv.h +984 -0
  224. package/include/numkong/spatial/rvvbf16.h +123 -0
  225. package/include/numkong/spatial/rvvhalf.h +117 -0
  226. package/include/numkong/spatial/sapphire.h +343 -0
  227. package/include/numkong/spatial/serial.h +346 -0
  228. package/include/numkong/spatial/sierra.h +323 -0
  229. package/include/numkong/spatial/skylake.h +606 -0
  230. package/include/numkong/spatial/sve.h +224 -0
  231. package/include/numkong/spatial/svebfdot.h +122 -0
  232. package/include/numkong/spatial/svehalf.h +109 -0
  233. package/include/numkong/spatial/v128relaxed.h +717 -0
  234. package/include/numkong/spatial.h +1425 -0
  235. package/include/numkong/spatial.hpp +183 -0
  236. package/include/numkong/spatials/README.md +580 -0
  237. package/include/numkong/spatials/alder.h +94 -0
  238. package/include/numkong/spatials/genoa.h +94 -0
  239. package/include/numkong/spatials/haswell.h +219 -0
  240. package/include/numkong/spatials/icelake.h +113 -0
  241. package/include/numkong/spatials/neon.h +109 -0
  242. package/include/numkong/spatials/neonbfdot.h +60 -0
  243. package/include/numkong/spatials/neonfhm.h +92 -0
  244. package/include/numkong/spatials/neonhalf.h +58 -0
  245. package/include/numkong/spatials/neonsdot.h +109 -0
  246. package/include/numkong/spatials/rvv.h +1960 -0
  247. package/include/numkong/spatials/sapphireamx.h +1149 -0
  248. package/include/numkong/spatials/serial.h +226 -0
  249. package/include/numkong/spatials/sierra.h +96 -0
  250. package/include/numkong/spatials/skylake.h +184 -0
  251. package/include/numkong/spatials/sme.h +1901 -0
  252. package/include/numkong/spatials/smef64.h +465 -0
  253. package/include/numkong/spatials/v128relaxed.h +240 -0
  254. package/include/numkong/spatials.h +3021 -0
  255. package/include/numkong/spatials.hpp +508 -0
  256. package/include/numkong/tensor.hpp +1592 -0
  257. package/include/numkong/trigonometry/README.md +184 -0
  258. package/include/numkong/trigonometry/haswell.h +652 -0
  259. package/include/numkong/trigonometry/neon.h +639 -0
  260. package/include/numkong/trigonometry/rvv.h +699 -0
  261. package/include/numkong/trigonometry/serial.h +703 -0
  262. package/include/numkong/trigonometry/skylake.h +721 -0
  263. package/include/numkong/trigonometry/v128relaxed.h +666 -0
  264. package/include/numkong/trigonometry.h +467 -0
  265. package/include/numkong/trigonometry.hpp +166 -0
  266. package/include/numkong/types.h +1384 -0
  267. package/include/numkong/types.hpp +5603 -0
  268. package/include/numkong/vector.hpp +698 -0
  269. package/javascript/README.md +246 -0
  270. package/javascript/dist/cjs/numkong-wasm.d.ts +166 -0
  271. package/javascript/dist/cjs/numkong-wasm.js +617 -0
  272. package/javascript/dist/cjs/numkong.d.ts +343 -0
  273. package/javascript/dist/cjs/numkong.js +523 -0
  274. package/javascript/dist/cjs/package.json +3 -0
  275. package/javascript/dist/cjs/types.d.ts +284 -0
  276. package/javascript/dist/cjs/types.js +653 -0
  277. package/javascript/dist/esm/numkong-wasm.d.ts +166 -0
  278. package/javascript/dist/esm/numkong-wasm.js +595 -0
  279. package/javascript/dist/esm/numkong.d.ts +343 -0
  280. package/javascript/dist/esm/numkong.js +452 -0
  281. package/javascript/dist/esm/package.json +3 -0
  282. package/javascript/dist/esm/types.d.ts +284 -0
  283. package/javascript/dist/esm/types.js +630 -0
  284. package/javascript/dist-package-cjs.json +3 -0
  285. package/javascript/dist-package-esm.json +3 -0
  286. package/javascript/node-gyp-build.d.ts +1 -0
  287. package/javascript/numkong-wasm.ts +756 -0
  288. package/javascript/numkong.c +689 -0
  289. package/javascript/numkong.ts +575 -0
  290. package/javascript/tsconfig-base.json +39 -0
  291. package/javascript/tsconfig-cjs.json +8 -0
  292. package/javascript/tsconfig-esm.json +8 -0
  293. package/javascript/types.ts +674 -0
  294. package/package.json +87 -0
@@ -0,0 +1,2804 @@
1
+ /**
2
+ * @brief SIMD-accelerated Batched Dot Products.
3
+ * @file include/numkong/dots.h
4
+ * @author Ash Vardanian
5
+ * @date September 14, 2024
6
+ *
7
+ * Implements batch dot-product kernels computing C[m × n] = A[m × k] × B[n × k]ᵀ
8
+ * with row-major A and arbitrary B, optimized for ML inference and similarity workloads.
9
+ *
10
+ * Primary Use Cases (1-to-N focus):
11
+ *
12
+ * - k-NN search: ‖a-b‖² = ‖a‖² + ‖b‖² - 2(a × b)
13
+ * - Cosine similarity: (a × b) / (‖a‖ × ‖b‖)
14
+ * - Sparse attention patterns
15
+ * - Embedding similarity matrices
16
+ * - k-means clustering, DBSCAN, hierarchical clustering
17
+ *
18
+ * It implements several operations:
19
+ *
20
+ * - "dots_packed" - computing dot-products where the B matrix is pre-packed into optimal form
21
+ * - "dots_packed_size" - which estimates the memory requirements for external `malloc`
22
+ * - "dots_pack" - to perform the pre-processing
23
+ * - "dots_compact" - optional helpers to normalize or downcast into original precision
24
+ * - "dots_symmetric" - for A × Aᵀ Gram matrix multiplication
25
+ *
26
+ * If the original "dots_packed" is analogous to "GEMM" (General Matrix Multiplication) in BLAS,
27
+ * the "dots_symmetric" is similar to the "SYRK" (the Symmetric rank-k update of a matrix).
28
+ *
29
+ * For dtypes:
30
+ *
31
+ * - f64: 64-bit IEEE floating point numbers → 64-bit floats
32
+ * - f32: 32-bit IEEE floating point numbers → 64-bit floats
33
+ * - f16: 16-bit IEEE floating point numbers → 32-bit floats
34
+ * - bf16: 16-bit brain floating point numbers → 32-bit floats
35
+ * - e4m3: 8-bit e4m3 floating point numbers → 32-bit floats
36
+ * - e5m2: 8-bit e5m2 floating point numbers → 32-bit floats
37
+ * - e2m3: 8-bit e2m3 floating point numbers (MX) → 32-bit floats
38
+ * - e3m2: 8-bit e3m2 floating point numbers (MX) → 32-bit floats
39
+ * - i8: 8-bit signed integers → 32-bit signed integers
40
+ * - u8: 8-bit unsigned integers → 32-bit unsigned integers
41
+ * - i4: 4-bit signed integers (packed pairs) → 32-bit signed integers
42
+ * - u4: 4-bit unsigned integers (packed pairs) → 32-bit unsigned integers
43
+ * - u1: 1-bit binary (packed octets) → 32-bit unsigned integers
44
+ *
45
+ * For hardware architectures:
46
+ *
47
+ * - Arm: NEON, NEON+HALF, NEON+FHM, NEON+BF16, NEON+SDOT, SVE, SME, SME+F64, SME+BI32
48
+ * - x86: Haswell, Skylake, Ice Lake, Genoa, Sapphire Rapids (AMX), Sierra Forest
49
+ * - RISC-V: RVV
50
+ *
51
+ * @section numerical_stability Numerical Stability
52
+ *
53
+ * - f64: Dot2 (Ogita-Rump-Oishi) on the accurate backends, otherwise native f64 FMA accumulation.
54
+ * - f32: public outputs widen to f64. Packed and symmetric kernels keep payloads narrow but widen accumulation.
55
+ * - bf16/f16: f32 accumulation. VDPBF16PS on Genoa does bf16×bf16→f32 natively.
56
+ * - e2m3/e3m2: f16 intermediate with flush to f32 every 128 elements (Sapphire).
57
+ * - i8: i32 accumulation. AMX TDPBSSD gives i8×i8→i32 tiles. Overflows at k > ~131K.
58
+ * - u1: Popcount, exact.
59
+ *
60
+ * @section memory_layout Memory Layout and Transpose Semantics
61
+ *
62
+ * All matrices use row-major storage. Column-major is NOT supported.
63
+ * The kernel computes C = A × Bᵀ where:
64
+ *
65
+ * - A is (m × k): m rows, k columns, stride = a_stride bytes between rows
66
+ * - B is (n × k): n rows, k columns, stride = b_stride bytes between rows
67
+ * - C is (m × n): m rows, n columns, stride = c_stride bytes between rows
68
+ *
69
+ * This means C[i,j] = dot(row i of A, row j of B) = Σₗ A[i,l] × B[j,l].
70
+ *
71
+ * All strides are in bytes.
72
+ *
73
+ * To compute standard A × B (where B is k × n), pass Bᵀ to the packing function:
74
+ *
75
+ * @code{.c}
76
+ * // Standard matmul: C[m × n] = A[m × k] × B[k × n]
77
+ * // B is stored row-major as k rows of n elements
78
+ * // Treat it as Bᵀ: n rows of k elements with stride = sizeof(element)
79
+ * nk_dots_pack_bf16(b, width, depth, sizeof(nk_bf16_t), b_packed);
80
+ * nk_dots_packed_bf16(a, b_packed, c, height, width, depth, a_stride, c_stride);
81
+ * // Result: C = A × (Bᵀ)ᵀ = A × B
82
+ * @endcode
83
+ *
84
+ * @section two_phase_api Two-Phase API for Static Weights
85
+ *
86
+ * Matrix multiplication hardware (AMX, SME) requires specific data layouts that differ
87
+ * from standard row-major ordering. Since one matrix (typically weights in neural networks)
88
+ * is often static, we provide a two-phase API: pack once, multiply many times.
89
+ *
90
+ * @code{.c}
91
+ * // Similarity search: C[m × n] = queries[m × k] × database[n × k]ᵀ
92
+ * // Both matrices stored row-major, each row is one vector of dimension k
93
+ * nk_size_t packed_bytes = nk_dots_packed_size_bf16(width, depth);
94
+ * void *b_packed = malloc(packed_bytes);
95
+ * nk_dots_pack_bf16(database, width, depth, depth * sizeof(nk_bf16_t), b_packed);
96
+ * nk_dots_packed_bf16(queries, b_packed, c, height, width, depth, ...);
97
+ * // Result: C[i,j] = dot(query i, database vector j)
98
+ * @endcode
99
+ *
100
+ * The packed format is opaque and backend-specific. AMX expects (16 × 32) tiles with interleaved
101
+ * pairs, while NEON/SVE use arrangements optimized for their vector lengths.
102
+ *
103
+ * @section why_int8 Why INT8 and Not UINT8?
104
+ *
105
+ * Unsigned 8-bit integers were considered but deprioritized. The industry has converged on
106
+ * signed INT8 as the standard for quantized inference:
107
+ *
108
+ * Framework       | Default | Notes
109
+ * PyTorch         | qint8   | New X86 backend uses INT8 via oneDNN
110
+ * TensorFlow Lite | int8    | Actively removing UINT8 support
111
+ * ONNX Runtime    | S8S8    | "Should be the first choice"
112
+ * TensorRT        | INT8    | Symmetric [-128,127], no UINT8 option
113
+ * ARM CMSIS-NN    | int8    | Follows TFLite INT8 spec exactly
114
+ *
115
+ * @section why_no_scaling Why No Alpha/Beta Scaling?
116
+ *
117
+ * BLAS-style `C = α × A × B + β × C` scaling was considered but omitted. While useful for scientific
118
+ * computing (iterative solvers, matrix factorizations), it's rarely used in ML inference where
119
+ * frameworks handle such operations via graph fusion. More importantly, on chips with separate
120
+ * physical registers for vector and matrix operations (like AMX), moving scalars between register
121
+ * files adds transfer latency that negates any benefit.
122
+ *
123
+ * @section why_no_pad Why Not Pad N Dimension to Eliminate Edge Handling?
124
+ *
125
+ * Padding N to a tile-aligned boundary (multiple of 16) during packing was considered to eliminate
126
+ * the separate AVX-512 edge kernel for N remainder rows. While this sounds simpler ("pure AMX"),
127
+ * it actually increases code size by ~125 lines because:
128
+ *
129
+ * - The AVX-512 edge fallback is compact (~40 lines) and handles both full-M × N-edge and
130
+ * M-edge × N-edge cases through a single reusable function
131
+ * - Replacing it with "AMX + masked stores" requires verbose tile handling code duplicated
132
+ * across all 4 multiply functions (aligned/misaligned × BF16/I8)
133
+ * - Each function needs a new "trailing N tile for full M blocks" section (~50 lines each)
134
+ *
135
+ * The current hybrid layout (AMX for full tiles, AVX-512 for edges) is more maintainable despite
136
+ * being conceptually less uniform. Memory overhead of the edge region is negligible (<2% worst case).
137
+ *
138
+ * @section x86_instructions Relevant x86 Instructions
139
+ *
140
+ * Low-precision matmul relies on VPMADD* (AVX2), VNNI dot-products, and BF16 dot-products
141
+ * on AVX-512. Zen4 improves throughput by dual-issuing many integer ops on FP ports.
142
+ *
143
+ * Intrinsic            | Instruction                  | Haswell | Genoa
144
+ * _mm256_maddubs_epi16 | VPMADDUBSW (YMM, YMM, YMM)   | 5c @ p0 | 3c @ p01
145
+ * _mm256_madd_epi16    | VPMADDWD (YMM, YMM, YMM)     | 5c @ p0 | 3c @ p01
146
+ * _mm256_dpbusd_epi32  | VPDPBUSD (YMM, K, YMM, YMM)  | n/a     | 4c @ p01
147
+ * _mm256_dpwssds_epi32 | VPDPWSSDS (YMM, K, YMM, YMM) | n/a     | 4c @ p01
148
+ * _mm256_dpbf16_ps     | VDPBF16PS (YMM, YMM, YMM)    | n/a     | 6c @ p01
149
+ *
150
+ * AMX tile ops (TDPBF16PS/TDPBUSD/TDPBSSD) are not covered by the uops.info 2022 dataset.
151
+ *
152
+ * @section references References
153
+ *
154
+ * - x86 intrinsics: https://www.intel.com/content/www/us/en/docs/intrinsics-guide/
155
+ * - Arm intrinsics: https://developer.arm.com/architectures/instruction-sets/intrinsics/
156
+ * - uops.info: https://uops.info/
157
+ * - Matrix Multiplication in 40 lines: https://en.algorithmica.org/hpc/algorithms/matmul/
158
+ * - LLaMA CPU optimization: https://justine.lol/matmul/
159
+ * - SME outer-product notes: https://github.com/tzakharko/m4-sme-exploration
160
+ *
161
+ */
162
+ #ifndef NK_DOTS_H
163
+ #define NK_DOTS_H
164
+
165
+ #include "numkong/types.h"
166
+
167
+ #if defined(__cplusplus)
168
+ extern "C" {
169
+ #endif
170
+
171
+ /**
172
+ * @brief Returns the number of bytes the caller must allocate to hold the packed form of the second multiplier matrix (B).
173
+ * @param[in] width The number of rows in B, equal to the number of columns in the output C.
174
+ * @param[in] depth The number of columns in B, i.e. the inner dimension shared with A.
175
+ * @note The packed layout is backend-specific and must be produced by the matching pack function.
176
+ */
177
+ NK_DYNAMIC nk_size_t nk_dots_packed_size_bf16(nk_size_t width, nk_size_t depth);
178
+ /** @copydoc nk_dots_packed_size_bf16 */
179
+ NK_DYNAMIC nk_size_t nk_dots_packed_size_f16(nk_size_t width, nk_size_t depth);
180
+ /** @copydoc nk_dots_packed_size_bf16 */
181
+ NK_DYNAMIC nk_size_t nk_dots_packed_size_e4m3(nk_size_t width, nk_size_t depth);
182
+ /** @copydoc nk_dots_packed_size_bf16 */
183
+ NK_DYNAMIC nk_size_t nk_dots_packed_size_e5m2(nk_size_t width, nk_size_t depth);
184
+ /** @copydoc nk_dots_packed_size_bf16 */
185
+ NK_DYNAMIC nk_size_t nk_dots_packed_size_e2m3(nk_size_t width, nk_size_t depth);
186
+ /** @copydoc nk_dots_packed_size_bf16 */
187
+ NK_DYNAMIC nk_size_t nk_dots_packed_size_e3m2(nk_size_t width, nk_size_t depth);
188
+ /** @copydoc nk_dots_packed_size_bf16 */
189
+ NK_DYNAMIC nk_size_t nk_dots_packed_size_f32(nk_size_t width, nk_size_t depth);
190
+ /** @copydoc nk_dots_packed_size_bf16 */
191
+ NK_DYNAMIC nk_size_t nk_dots_packed_size_f64(nk_size_t width, nk_size_t depth);
192
+ /** @copydoc nk_dots_packed_size_bf16 */
193
+ NK_DYNAMIC nk_size_t nk_dots_packed_size_i8(nk_size_t width, nk_size_t depth);
194
+ /** @copydoc nk_dots_packed_size_bf16 */
195
+ NK_DYNAMIC nk_size_t nk_dots_packed_size_u8(nk_size_t width, nk_size_t depth);
196
+ /** @copydoc nk_dots_packed_size_bf16 */
197
+ NK_DYNAMIC nk_size_t nk_dots_packed_size_i4(nk_size_t width, nk_size_t depth);
198
+ /** @copydoc nk_dots_packed_size_bf16 */
199
+ NK_DYNAMIC nk_size_t nk_dots_packed_size_u4(nk_size_t width, nk_size_t depth);
200
+ /** @copydoc nk_dots_packed_size_bf16 */
201
+ NK_DYNAMIC nk_size_t nk_dots_packed_size_u1(nk_size_t width, nk_size_t depth);
202
+
203
+ /**
204
+ * @brief Packs the second multiplier (B) matrix into a backend-specific layout, so it can be packed once and multiplied many times.
205
+ * @param[in] b The input B matrix (width × depth) in row-major order.
206
+ * @param[in] width The number of rows in B, equal to the number of columns in the output C.
207
+ * @param[in] depth The number of columns in B, i.e. the inner dimension shared with A.
208
+ * @param[in] b_stride The row stride in bytes for B.
209
+ * @param[out] b_packed The caller-allocated output buffer, sized by the matching packed-size function (e.g. nk_dots_packed_size_bf16).
210
+ */
211
+ NK_DYNAMIC void nk_dots_pack_bf16(nk_bf16_t const *b, nk_size_t width, nk_size_t depth, nk_size_t b_stride,
212
+ void *b_packed);
213
+ /** @copydoc nk_dots_pack_bf16 */
214
+ NK_DYNAMIC void nk_dots_pack_f16(nk_f16_t const *b, nk_size_t width, nk_size_t depth, nk_size_t b_stride,
215
+ void *b_packed);
216
+ /** @copydoc nk_dots_pack_bf16 */
217
+ NK_DYNAMIC void nk_dots_pack_e4m3(nk_e4m3_t const *b, nk_size_t width, nk_size_t depth, nk_size_t b_stride,
218
+ void *b_packed);
219
+ /** @copydoc nk_dots_pack_bf16 */
220
+ NK_DYNAMIC void nk_dots_pack_e5m2(nk_e5m2_t const *b, nk_size_t width, nk_size_t depth, nk_size_t b_stride,
221
+ void *b_packed);
222
+ /** @copydoc nk_dots_pack_bf16 */
223
+ NK_DYNAMIC void nk_dots_pack_e2m3(nk_e2m3_t const *b, nk_size_t width, nk_size_t depth, nk_size_t b_stride,
224
+ void *b_packed);
225
+ /** @copydoc nk_dots_pack_bf16 */
226
+ NK_DYNAMIC void nk_dots_pack_e3m2(nk_e3m2_t const *b, nk_size_t width, nk_size_t depth, nk_size_t b_stride,
227
+ void *b_packed);
228
+ /** @copydoc nk_dots_pack_bf16 */
229
+ NK_DYNAMIC void nk_dots_pack_f32(nk_f32_t const *b, nk_size_t width, nk_size_t depth, nk_size_t b_stride,
230
+ void *b_packed);
231
+ /** @copydoc nk_dots_pack_bf16 */
232
+ NK_DYNAMIC void nk_dots_pack_f64(nk_f64_t const *b, nk_size_t width, nk_size_t depth, nk_size_t b_stride,
233
+ void *b_packed);
234
+ /** @copydoc nk_dots_pack_bf16 */
235
+ NK_DYNAMIC void nk_dots_pack_i8(nk_i8_t const *b, nk_size_t width, nk_size_t depth, nk_size_t b_stride, void *b_packed);
236
+ /** @copydoc nk_dots_pack_bf16 */
237
+ NK_DYNAMIC void nk_dots_pack_u8(nk_u8_t const *b, nk_size_t width, nk_size_t depth, nk_size_t b_stride, void *b_packed);
238
+ /** @copydoc nk_dots_pack_bf16 */
239
+ NK_DYNAMIC void nk_dots_pack_i4(nk_i4x2_t const *b, nk_size_t width, nk_size_t depth, nk_size_t b_stride,
240
+ void *b_packed);
241
+ /** @copydoc nk_dots_pack_bf16 */
242
+ NK_DYNAMIC void nk_dots_pack_u4(nk_u4x2_t const *b, nk_size_t width, nk_size_t depth, nk_size_t b_stride,
243
+ void *b_packed);
244
+ /** @copydoc nk_dots_pack_bf16 */
245
+ NK_DYNAMIC void nk_dots_pack_u1(nk_u1x8_t const *b, nk_size_t width, nk_size_t depth, nk_size_t b_stride,
246
+ void *b_packed);
247
+
248
+ /**
249
+ * @brief Computes C = A × Bᵀ using packed second multiplier matrix (B), accumulating into C.
250
+ * @param[in] a The input A matrix (height × depth) in row-major order.
251
+ * @param[in] b_packed The packed B matrix produced by the matching pack function (e.g. nk_dots_pack_bf16).
252
+ * @param[out] c The output C matrix (height × width) in row-major order. NOTE(review): the brief says "accumulating", yet C is marked [out] - confirm whether prior contents of C are read.
253
+ * @param[in] height The number of rows in A, equal to the number of rows in C.
254
+ * @param[in] width The number of rows in B, equal to the number of columns in C.
255
+ * @param[in] depth The shared inner dimension: the number of columns in both A and B.
256
+ * @param[in] a_stride The row stride in bytes for A.
257
+ * @param[in] c_stride The row stride in bytes for C.
258
+ */
259
+ NK_DYNAMIC void nk_dots_packed_bf16(nk_bf16_t const *a, void const *b_packed, nk_f32_t *c, nk_size_t height,
260
+ nk_size_t width, nk_size_t depth, nk_size_t a_stride, nk_size_t c_stride);
261
+ /** @copydoc nk_dots_packed_bf16 */
262
+ NK_DYNAMIC void nk_dots_packed_f16(nk_f16_t const *a, void const *b_packed, nk_f32_t *c, nk_size_t height,
263
+ nk_size_t width, nk_size_t depth, nk_size_t a_stride, nk_size_t c_stride);
264
+ /** @copydoc nk_dots_packed_bf16 */
265
+ NK_DYNAMIC void nk_dots_packed_e4m3(nk_e4m3_t const *a, void const *b_packed, nk_f32_t *c, nk_size_t height,
266
+ nk_size_t width, nk_size_t depth, nk_size_t a_stride, nk_size_t c_stride);
267
+ /** @copydoc nk_dots_packed_bf16 */
268
+ NK_DYNAMIC void nk_dots_packed_e5m2(nk_e5m2_t const *a, void const *b_packed, nk_f32_t *c, nk_size_t height,
269
+ nk_size_t width, nk_size_t depth, nk_size_t a_stride, nk_size_t c_stride);
270
+ /** @copydoc nk_dots_packed_bf16 */
271
+ NK_DYNAMIC void nk_dots_packed_e2m3(nk_e2m3_t const *a, void const *b_packed, nk_f32_t *c, nk_size_t height,
272
+ nk_size_t width, nk_size_t depth, nk_size_t a_stride, nk_size_t c_stride);
273
+ /** @copydoc nk_dots_packed_bf16 */
274
+ NK_DYNAMIC void nk_dots_packed_e3m2(nk_e3m2_t const *a, void const *b_packed, nk_f32_t *c, nk_size_t height,
275
+ nk_size_t width, nk_size_t depth, nk_size_t a_stride, nk_size_t c_stride);
276
+ /** @copydoc nk_dots_packed_bf16 */
277
+ NK_DYNAMIC void nk_dots_packed_f32(nk_f32_t const *a, void const *b_packed, nk_f64_t *c, nk_size_t height,
278
+ nk_size_t width, nk_size_t depth, nk_size_t a_stride, nk_size_t c_stride);
279
+ /** @copydoc nk_dots_packed_bf16 */
280
+ NK_DYNAMIC void nk_dots_packed_f64(nk_f64_t const *a, void const *b_packed, nk_f64_t *c, nk_size_t height,
281
+ nk_size_t width, nk_size_t depth, nk_size_t a_stride, nk_size_t c_stride);
282
+ /** @copydoc nk_dots_packed_bf16 */
283
+ NK_DYNAMIC void nk_dots_packed_i8(nk_i8_t const *a, void const *b_packed, nk_i32_t *c, nk_size_t height,
284
+ nk_size_t width, nk_size_t depth, nk_size_t a_stride, nk_size_t c_stride);
285
+ /** @copydoc nk_dots_packed_bf16 */
286
+ NK_DYNAMIC void nk_dots_packed_u8(nk_u8_t const *a, void const *b_packed, nk_u32_t *c, nk_size_t height,
287
+ nk_size_t width, nk_size_t depth, nk_size_t a_stride, nk_size_t c_stride);
288
+ /** @copydoc nk_dots_packed_bf16 */
289
+ NK_DYNAMIC void nk_dots_packed_i4(nk_i4x2_t const *a, void const *b_packed, nk_i32_t *c, nk_size_t height,
290
+ nk_size_t width, nk_size_t depth, nk_size_t a_stride, nk_size_t c_stride);
291
+ /** @copydoc nk_dots_packed_bf16 */
292
+ NK_DYNAMIC void nk_dots_packed_u4(nk_u4x2_t const *a, void const *b_packed, nk_u32_t *c, nk_size_t height,
293
+ nk_size_t width, nk_size_t depth, nk_size_t a_stride, nk_size_t c_stride);
294
+ /** @copydoc nk_dots_packed_bf16 */
295
+ NK_DYNAMIC void nk_dots_packed_u1(nk_u1x8_t const *a, void const *b_packed, nk_u32_t *c, nk_size_t height,
296
+ nk_size_t width, nk_size_t depth, nk_size_t a_stride, nk_size_t c_stride);
297
+
298
+ /**
299
+ * @brief Computes the symmetric Gram matrix C = A × Aᵀ of the input row vectors.
300
+ * @param[in] vectors Input matrix of row vectors in row-major order.
301
+ * @param[in] n_vectors Number of vectors (rows) in the input matrix.
302
+ * @param[in] depth Dimension of each vector (columns).
303
+ * @param[in] stride Row stride in bytes for the input matrix.
304
+ * @param[out] result Output symmetric matrix (n_vectors × n_vectors).
305
+ * @param[in] result_stride Row stride in bytes for the result matrix.
306
+ * @param[in] row_start First row of the result to compute (lets parallel callers split the work by rows).
307
+ * @param[in] row_count Number of result rows to compute, beginning at row_start (lets parallel callers split the work by rows).
308
+ */
309
+ NK_DYNAMIC void nk_dots_symmetric_bf16(nk_bf16_t const *vectors, nk_size_t n_vectors, nk_size_t depth, nk_size_t stride,
310
+ nk_f32_t *result, nk_size_t result_stride, nk_size_t row_start,
311
+ nk_size_t row_count);
312
+ /** @copydoc nk_dots_symmetric_bf16 */
313
+ NK_DYNAMIC void nk_dots_symmetric_f16(nk_f16_t const *vectors, nk_size_t n_vectors, nk_size_t depth, nk_size_t stride,
314
+ nk_f32_t *result, nk_size_t result_stride, nk_size_t row_start,
315
+ nk_size_t row_count);
316
+ /** @copydoc nk_dots_symmetric_bf16 */
317
+ NK_DYNAMIC void nk_dots_symmetric_e4m3(nk_e4m3_t const *vectors, nk_size_t n_vectors, nk_size_t depth, nk_size_t stride,
318
+ nk_f32_t *result, nk_size_t result_stride, nk_size_t row_start,
319
+ nk_size_t row_count);
320
+ /** @copydoc nk_dots_symmetric_bf16 */
321
+ NK_DYNAMIC void nk_dots_symmetric_e5m2(nk_e5m2_t const *vectors, nk_size_t n_vectors, nk_size_t depth, nk_size_t stride,
322
+ nk_f32_t *result, nk_size_t result_stride, nk_size_t row_start,
323
+ nk_size_t row_count);
324
+ /** @copydoc nk_dots_symmetric_bf16 */
325
+ NK_DYNAMIC void nk_dots_symmetric_e2m3(nk_e2m3_t const *vectors, nk_size_t n_vectors, nk_size_t depth, nk_size_t stride,
326
+ nk_f32_t *result, nk_size_t result_stride, nk_size_t row_start,
327
+ nk_size_t row_count);
328
+ /** @copydoc nk_dots_symmetric_bf16 */
329
+ NK_DYNAMIC void nk_dots_symmetric_e3m2(nk_e3m2_t const *vectors, nk_size_t n_vectors, nk_size_t depth, nk_size_t stride,
330
+ nk_f32_t *result, nk_size_t result_stride, nk_size_t row_start,
331
+ nk_size_t row_count);
332
+ /** @copydoc nk_dots_symmetric_bf16 */
333
+ NK_DYNAMIC void nk_dots_symmetric_f32(nk_f32_t const *vectors, nk_size_t n_vectors, nk_size_t depth, nk_size_t stride,
334
+ nk_f64_t *result, nk_size_t result_stride, nk_size_t row_start,
335
+ nk_size_t row_count);
336
+ /** @copydoc nk_dots_symmetric_bf16 */
337
+ NK_DYNAMIC void nk_dots_symmetric_f64(nk_f64_t const *vectors, nk_size_t n_vectors, nk_size_t depth, nk_size_t stride,
338
+ nk_f64_t *result, nk_size_t result_stride, nk_size_t row_start,
339
+ nk_size_t row_count);
340
+ /** @copydoc nk_dots_symmetric_bf16 */
341
+ NK_DYNAMIC void nk_dots_symmetric_i8(nk_i8_t const *vectors, nk_size_t n_vectors, nk_size_t depth, nk_size_t stride,
342
+ nk_i32_t *result, nk_size_t result_stride, nk_size_t row_start,
343
+ nk_size_t row_count);
344
+ /** @copydoc nk_dots_symmetric_bf16 */
345
+ NK_DYNAMIC void nk_dots_symmetric_u8(nk_u8_t const *vectors, nk_size_t n_vectors, nk_size_t depth, nk_size_t stride,
346
+ nk_u32_t *result, nk_size_t result_stride, nk_size_t row_start,
347
+ nk_size_t row_count);
348
+ /** @copydoc nk_dots_symmetric_bf16 */
349
+ NK_DYNAMIC void nk_dots_symmetric_i4(nk_i4x2_t const *vectors, nk_size_t n_vectors, nk_size_t depth, nk_size_t stride,
350
+ nk_i32_t *result, nk_size_t result_stride, nk_size_t row_start,
351
+ nk_size_t row_count);
352
+ /** @copydoc nk_dots_symmetric_bf16 */
353
+ NK_DYNAMIC void nk_dots_symmetric_u4(nk_u4x2_t const *vectors, nk_size_t n_vectors, nk_size_t depth, nk_size_t stride,
354
+ nk_u32_t *result, nk_size_t result_stride, nk_size_t row_start,
355
+ nk_size_t row_count);
356
+ /** @copydoc nk_dots_symmetric_bf16 */
357
+ NK_DYNAMIC void nk_dots_symmetric_u1(nk_u1x8_t const *vectors, nk_size_t n_vectors, nk_size_t depth, nk_size_t stride,
358
+ nk_u32_t *result, nk_size_t result_stride, nk_size_t row_start,
359
+ nk_size_t row_count);
360
+
361
+ /** @copydoc nk_dots_packed_size_f32 */
362
+ NK_PUBLIC nk_size_t nk_dots_packed_size_f32_serial(nk_size_t width, nk_size_t depth);
363
+ /** @copydoc nk_dots_pack_f32 */
364
+ NK_PUBLIC void nk_dots_pack_f32_serial(nk_f32_t const *b, nk_size_t width, nk_size_t depth, nk_size_t b_stride,
365
+ void *b_packed);
366
+ /** @copydoc nk_dots_packed_f32 */
367
+ NK_PUBLIC void nk_dots_packed_f32_serial(nk_f32_t const *a, void const *b_packed, nk_f64_t *c, nk_size_t height,
368
+ nk_size_t width, nk_size_t depth, nk_size_t a_stride, nk_size_t c_stride);
369
+ /** @copydoc nk_dots_symmetric_f32 */
370
+ NK_PUBLIC void nk_dots_symmetric_f32_serial(nk_f32_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
371
+ nk_size_t stride, nk_f64_t *result, nk_size_t result_stride,
372
+ nk_size_t row_start, nk_size_t row_count);
373
+
374
+ /** @copydoc nk_dots_packed_size_f64 */
375
+ NK_PUBLIC nk_size_t nk_dots_packed_size_f64_serial(nk_size_t width, nk_size_t depth);
376
+ /** @copydoc nk_dots_pack_f64 */
377
+ NK_PUBLIC void nk_dots_pack_f64_serial(nk_f64_t const *b, nk_size_t width, nk_size_t depth, nk_size_t b_stride,
378
+ void *b_packed);
379
+ /** @copydoc nk_dots_packed_f64 */
380
+ NK_PUBLIC void nk_dots_packed_f64_serial(nk_f64_t const *a, void const *b_packed, nk_f64_t *c, nk_size_t height,
381
+ nk_size_t width, nk_size_t depth, nk_size_t a_stride, nk_size_t c_stride);
382
+ /** @copydoc nk_dots_symmetric_f64 */
383
+ NK_PUBLIC void nk_dots_symmetric_f64_serial(nk_f64_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
384
+ nk_size_t stride, nk_f64_t *result, nk_size_t result_stride,
385
+ nk_size_t row_start, nk_size_t row_count);
386
+
387
+ /** @copydoc nk_dots_packed_size_f16 */
388
+ NK_PUBLIC nk_size_t nk_dots_packed_size_f16_serial(nk_size_t width, nk_size_t depth);
389
+ /** @copydoc nk_dots_pack_f16 */
390
+ NK_PUBLIC void nk_dots_pack_f16_serial(nk_f16_t const *b, nk_size_t width, nk_size_t depth, nk_size_t b_stride,
391
+ void *b_packed);
392
+ /** @copydoc nk_dots_packed_f16 */
393
+ NK_PUBLIC void nk_dots_packed_f16_serial(nk_f16_t const *a, void const *b_packed, nk_f32_t *c, nk_size_t height,
394
+ nk_size_t width, nk_size_t depth, nk_size_t a_stride, nk_size_t c_stride);
395
+ /** @copydoc nk_dots_symmetric_f16 */
396
+ NK_PUBLIC void nk_dots_symmetric_f16_serial(nk_f16_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
397
+ nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
398
+ nk_size_t row_start, nk_size_t row_count);
399
+
400
+ /** @copydoc nk_dots_packed_size_bf16 */
401
+ NK_PUBLIC nk_size_t nk_dots_packed_size_bf16_serial(nk_size_t width, nk_size_t depth);
402
+ /** @copydoc nk_dots_pack_bf16 */
403
+ NK_PUBLIC void nk_dots_pack_bf16_serial(nk_bf16_t const *b, nk_size_t width, nk_size_t depth, nk_size_t b_stride,
404
+ void *b_packed);
405
+ /** @copydoc nk_dots_packed_bf16 */
406
+ NK_PUBLIC void nk_dots_packed_bf16_serial(nk_bf16_t const *a, void const *b_packed, nk_f32_t *c, nk_size_t height,
407
+ nk_size_t width, nk_size_t depth, nk_size_t a_stride, nk_size_t c_stride);
408
+ /** @copydoc nk_dots_symmetric_bf16 */
409
+ NK_PUBLIC void nk_dots_symmetric_bf16_serial(nk_bf16_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
410
+ nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
411
+ nk_size_t row_start, nk_size_t row_count);
412
+
413
+ /** @copydoc nk_dots_packed_size_i8 */
414
+ NK_PUBLIC nk_size_t nk_dots_packed_size_i8_serial(nk_size_t width, nk_size_t depth);
415
+ /** @copydoc nk_dots_pack_i8 */
416
+ NK_PUBLIC void nk_dots_pack_i8_serial(nk_i8_t const *b, nk_size_t width, nk_size_t depth, nk_size_t b_stride,
417
+ void *b_packed);
418
+ /** @copydoc nk_dots_packed_i8 */
419
+ NK_PUBLIC void nk_dots_packed_i8_serial(nk_i8_t const *a, void const *b_packed, nk_i32_t *c, nk_size_t height,
420
+ nk_size_t width, nk_size_t depth, nk_size_t a_stride, nk_size_t c_stride);
421
+ /** @copydoc nk_dots_symmetric_i8 */
422
+ NK_PUBLIC void nk_dots_symmetric_i8_serial(nk_i8_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
423
+ nk_size_t stride, nk_i32_t *result, nk_size_t result_stride,
424
+ nk_size_t row_start, nk_size_t row_count);
425
+
426
+ /** @copydoc nk_dots_packed_size_u8 */
427
+ NK_PUBLIC nk_size_t nk_dots_packed_size_u8_serial(nk_size_t width, nk_size_t depth);
428
+ /** @copydoc nk_dots_pack_u8 */
429
+ NK_PUBLIC void nk_dots_pack_u8_serial(nk_u8_t const *b, nk_size_t width, nk_size_t depth, nk_size_t b_stride,
430
+ void *b_packed);
431
+ /** @copydoc nk_dots_packed_u8 */
432
+ NK_PUBLIC void nk_dots_packed_u8_serial(nk_u8_t const *a, void const *b_packed, nk_u32_t *c, nk_size_t height,
433
+ nk_size_t width, nk_size_t depth, nk_size_t a_stride, nk_size_t c_stride);
434
+ /** @copydoc nk_dots_symmetric_u8 */
435
+ NK_PUBLIC void nk_dots_symmetric_u8_serial(nk_u8_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
436
+ nk_size_t stride, nk_u32_t *result, nk_size_t result_stride,
437
+ nk_size_t row_start, nk_size_t row_count);
438
+
439
+ /** @copydoc nk_dots_packed_size_u4 */
440
+ NK_PUBLIC nk_size_t nk_dots_packed_size_u4_serial(nk_size_t width, nk_size_t depth);
441
+ /** @copydoc nk_dots_pack_u4 */
442
+ NK_PUBLIC void nk_dots_pack_u4_serial(nk_u4x2_t const *b, nk_size_t width, nk_size_t depth, nk_size_t b_stride,
443
+ void *b_packed);
444
+ /** @copydoc nk_dots_packed_u4 */
445
+ NK_PUBLIC void nk_dots_packed_u4_serial(nk_u4x2_t const *a, void const *b_packed, nk_u32_t *c, nk_size_t height,
446
+ nk_size_t width, nk_size_t depth, nk_size_t a_stride, nk_size_t c_stride);
447
+ /** @copydoc nk_dots_symmetric_u4 */
448
+ NK_PUBLIC void nk_dots_symmetric_u4_serial(nk_u4x2_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
449
+ nk_size_t stride, nk_u32_t *result, nk_size_t result_stride,
450
+ nk_size_t row_start, nk_size_t row_count);
451
+
452
+ /** @copydoc nk_dots_packed_size_u1 */
453
+ NK_PUBLIC nk_size_t nk_dots_packed_size_u1_serial(nk_size_t width, nk_size_t depth);
454
+ /** @copydoc nk_dots_pack_u1 */
455
+ NK_PUBLIC void nk_dots_pack_u1_serial(nk_u1x8_t const *b, nk_size_t width, nk_size_t depth, nk_size_t b_stride,
456
+ void *b_packed);
457
+ /** @copydoc nk_dots_packed_u1 */
458
+ NK_PUBLIC void nk_dots_packed_u1_serial(nk_u1x8_t const *a, void const *b_packed, nk_u32_t *c, nk_size_t height,
459
+ nk_size_t width, nk_size_t depth, nk_size_t a_stride, nk_size_t c_stride);
460
+ /** @copydoc nk_dots_symmetric_u1 */
461
+ NK_PUBLIC void nk_dots_symmetric_u1_serial(nk_u1x8_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
462
+ nk_size_t stride, nk_u32_t *result, nk_size_t result_stride,
463
+ nk_size_t row_start, nk_size_t row_count);
464
+
465
+ /** @copydoc nk_dots_packed_size_i4 */
466
+ NK_PUBLIC nk_size_t nk_dots_packed_size_i4_serial(nk_size_t width, nk_size_t depth);
467
+ /** @copydoc nk_dots_pack_i4 */
468
+ NK_PUBLIC void nk_dots_pack_i4_serial(nk_i4x2_t const *b, nk_size_t width, nk_size_t depth, nk_size_t b_stride,
469
+ void *b_packed);
470
+ /** @copydoc nk_dots_packed_i4 */
471
+ NK_PUBLIC void nk_dots_packed_i4_serial(nk_i4x2_t const *a, void const *b_packed, nk_i32_t *c, nk_size_t height,
472
+ nk_size_t width, nk_size_t depth, nk_size_t a_stride, nk_size_t c_stride);
473
+ /** @copydoc nk_dots_symmetric_i4 */
474
+ NK_PUBLIC void nk_dots_symmetric_i4_serial(nk_i4x2_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
475
+ nk_size_t stride, nk_i32_t *result, nk_size_t result_stride,
476
+ nk_size_t row_start, nk_size_t row_count);
477
+ /** @copydoc nk_dots_symmetric_e4m3 */
478
+ NK_PUBLIC void nk_dots_symmetric_e4m3_serial(nk_e4m3_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
479
+ nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
480
+ nk_size_t row_start, nk_size_t row_count);
481
+ /** @copydoc nk_dots_symmetric_e5m2 */
482
+ NK_PUBLIC void nk_dots_symmetric_e5m2_serial(nk_e5m2_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
483
+ nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
484
+ nk_size_t row_start, nk_size_t row_count);
485
+ /** @copydoc nk_dots_symmetric_e2m3 */
486
+ NK_PUBLIC void nk_dots_symmetric_e2m3_serial(nk_e2m3_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
487
+ nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
488
+ nk_size_t row_start, nk_size_t row_count);
489
+ /** @copydoc nk_dots_symmetric_e3m2 */
490
+ NK_PUBLIC void nk_dots_symmetric_e3m2_serial(nk_e3m2_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
491
+ nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
492
+ nk_size_t row_start, nk_size_t row_count);
493
+ /** @copydoc nk_dots_packed_size_e2m3 */
494
+ NK_PUBLIC nk_size_t nk_dots_packed_size_e2m3_serial(nk_size_t width, nk_size_t depth);
495
+ /** @copydoc nk_dots_pack_e2m3 */
496
+ NK_PUBLIC void nk_dots_pack_e2m3_serial(nk_e2m3_t const *b, nk_size_t width, nk_size_t depth, nk_size_t b_stride,
497
+ void *b_packed);
498
+ /** @copydoc nk_dots_packed_e2m3 */
499
+ NK_PUBLIC void nk_dots_packed_e2m3_serial(nk_e2m3_t const *a, void const *b_packed, nk_f32_t *c, nk_size_t height,
500
+ nk_size_t width, nk_size_t depth, nk_size_t a_stride, nk_size_t c_stride);
501
+ /** @copydoc nk_dots_packed_size_e3m2 */
502
+ NK_PUBLIC nk_size_t nk_dots_packed_size_e3m2_serial(nk_size_t width, nk_size_t depth);
503
+ /** @copydoc nk_dots_pack_e3m2 */
504
+ NK_PUBLIC void nk_dots_pack_e3m2_serial(nk_e3m2_t const *b, nk_size_t width, nk_size_t depth, nk_size_t b_stride,
505
+ void *b_packed);
506
+ /** @copydoc nk_dots_packed_e3m2 */
507
+ NK_PUBLIC void nk_dots_packed_e3m2_serial(nk_e3m2_t const *a, void const *b_packed, nk_f32_t *c, nk_size_t height,
508
+ nk_size_t width, nk_size_t depth, nk_size_t a_stride, nk_size_t c_stride);
509
+
510
+ /* Genoa backends using AVX-512 with BF16 extensions.
511
+ * These use VDPBF16PS for BF16 dot products.
512
+ * Packing interleaves elements for SIMD broadcast patterns.
513
+ */
514
+ #if NK_TARGET_GENOA
515
+ /** @copydoc nk_dots_packed_size_bf16 */
516
+ NK_PUBLIC nk_size_t nk_dots_packed_size_bf16_genoa(nk_size_t width, nk_size_t depth);
517
+ /** @copydoc nk_dots_pack_bf16 */
518
+ NK_PUBLIC void nk_dots_pack_bf16_genoa(nk_bf16_t const *b, nk_size_t width, nk_size_t depth, nk_size_t b_stride,
519
+ void *b_packed);
520
+ /** @copydoc nk_dots_packed_bf16 */
521
+ NK_PUBLIC void nk_dots_packed_bf16_genoa(nk_bf16_t const *a, void const *b_packed, nk_f32_t *c, nk_size_t height,
522
+ nk_size_t width, nk_size_t depth, nk_size_t a_stride, nk_size_t c_stride);
523
+ /** @copydoc nk_dots_symmetric_bf16 */
524
+ NK_PUBLIC void nk_dots_symmetric_bf16_genoa(nk_bf16_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
525
+ nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
526
+ nk_size_t row_start, nk_size_t row_count);
527
+
528
+ /** @copydoc nk_dots_packed_size_e4m3 */
529
+ NK_PUBLIC nk_size_t nk_dots_packed_size_e4m3_genoa(nk_size_t width, nk_size_t depth);
530
+ /** @copydoc nk_dots_pack_e4m3 */
531
+ NK_PUBLIC void nk_dots_pack_e4m3_genoa(nk_e4m3_t const *b, nk_size_t width, nk_size_t depth, nk_size_t b_stride,
532
+ void *b_packed);
533
+ /** @copydoc nk_dots_packed_e4m3 */
534
+ NK_PUBLIC void nk_dots_packed_e4m3_genoa(nk_e4m3_t const *a, void const *b_packed, nk_f32_t *c, nk_size_t height,
535
+ nk_size_t width, nk_size_t depth, nk_size_t a_stride, nk_size_t c_stride);
536
+ /** @copydoc nk_dots_packed_size_e5m2 */
537
+ NK_PUBLIC nk_size_t nk_dots_packed_size_e5m2_genoa(nk_size_t width, nk_size_t depth);
538
+ /** @copydoc nk_dots_pack_e5m2 */
539
+ NK_PUBLIC void nk_dots_pack_e5m2_genoa(nk_e5m2_t const *b, nk_size_t width, nk_size_t depth, nk_size_t b_stride,
540
+ void *b_packed);
541
+ /** @copydoc nk_dots_packed_e5m2 */
542
+ NK_PUBLIC void nk_dots_packed_e5m2_genoa(nk_e5m2_t const *a, void const *b_packed, nk_f32_t *c, nk_size_t height,
543
+ nk_size_t width, nk_size_t depth, nk_size_t a_stride, nk_size_t c_stride);
544
+ /** @copydoc nk_dots_symmetric_e4m3 */
545
+ NK_PUBLIC void nk_dots_symmetric_e4m3_genoa(nk_e4m3_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
546
+ nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
547
+ nk_size_t row_start, nk_size_t row_count);
548
+ /** @copydoc nk_dots_symmetric_e5m2 */
549
+ NK_PUBLIC void nk_dots_symmetric_e5m2_genoa(nk_e5m2_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
550
+ nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
551
+ nk_size_t row_start, nk_size_t row_count);
552
+ #endif // NK_TARGET_GENOA
553
+
554
+ /* Sapphire Rapids backends using Intel AMX (Advanced Matrix Extensions).
555
+ * AMX provides 8 tile registers (TMM0-TMM7), each holding up to 1KB of data.
556
+ * Tiles are configured as 16 rows × 64 bytes, enabling (16 × 32) BF16 or (16 × 64) INT8 tiles.
557
+ * Packing arranges data into AMX-native tile layout with pair interleaving for TDPBF16PS.
558
+ */
559
+ #if NK_TARGET_SAPPHIREAMX
560
+ /** @copydoc nk_dots_packed_size_bf16 */
561
+ NK_PUBLIC nk_size_t nk_dots_packed_size_bf16_sapphireamx(nk_size_t width, nk_size_t depth);
562
+ /** @copydoc nk_dots_pack_bf16 */
563
+ NK_PUBLIC void nk_dots_pack_bf16_sapphireamx(nk_bf16_t const *b, nk_size_t width, nk_size_t depth, nk_size_t b_stride,
564
+ void *b_packed);
565
+ /** @copydoc nk_dots_packed_bf16 */
566
+ NK_PUBLIC void nk_dots_packed_bf16_sapphireamx(nk_bf16_t const *a, void const *b_packed, nk_f32_t *c, nk_size_t height,
567
+ nk_size_t width, nk_size_t depth, nk_size_t a_stride,
568
+ nk_size_t c_stride);
569
+ /** @copydoc nk_dots_symmetric_bf16 */
570
+ NK_PUBLIC void nk_dots_symmetric_bf16_sapphireamx(nk_bf16_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
571
+ nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
572
+ nk_size_t row_start, nk_size_t row_count);
573
+
574
+ /** @copydoc nk_dots_packed_size_i8 */
575
+ NK_PUBLIC nk_size_t nk_dots_packed_size_i8_sapphireamx(nk_size_t width, nk_size_t depth);
576
+ /** @copydoc nk_dots_pack_i8 */
577
+ NK_PUBLIC void nk_dots_pack_i8_sapphireamx(nk_i8_t const *b, nk_size_t width, nk_size_t depth, nk_size_t b_stride,
578
+ void *b_packed);
579
+ /** @copydoc nk_dots_packed_i8 */
580
+ NK_PUBLIC void nk_dots_packed_i8_sapphireamx(nk_i8_t const *a, void const *b_packed, nk_i32_t *c, nk_size_t height,
581
+ nk_size_t width, nk_size_t depth, nk_size_t a_stride, nk_size_t c_stride);
582
+ /** @copydoc nk_dots_symmetric_i8 */
583
+ NK_PUBLIC void nk_dots_symmetric_i8_sapphireamx(nk_i8_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
584
+ nk_size_t stride, nk_i32_t *result, nk_size_t result_stride,
585
+ nk_size_t row_start, nk_size_t row_count);
586
+
587
+ /** @copydoc nk_dots_packed_size_e4m3 */
588
+ NK_PUBLIC nk_size_t nk_dots_packed_size_e4m3_sapphireamx(nk_size_t width, nk_size_t depth);
589
+ /** @copydoc nk_dots_pack_e4m3 */
590
+ NK_PUBLIC void nk_dots_pack_e4m3_sapphireamx(nk_e4m3_t const *b, nk_size_t width, nk_size_t depth, nk_size_t b_stride,
591
+ void *b_packed);
592
+ /** @copydoc nk_dots_packed_e4m3 */
593
+ NK_PUBLIC void nk_dots_packed_e4m3_sapphireamx(nk_e4m3_t const *a, void const *b_packed, nk_f32_t *c, nk_size_t height,
594
+ nk_size_t width, nk_size_t depth, nk_size_t a_stride,
595
+ nk_size_t c_stride);
596
+
597
+ /** @copydoc nk_dots_symmetric_e4m3 */
598
+ NK_PUBLIC void nk_dots_symmetric_e4m3_sapphireamx(nk_e4m3_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
599
+ nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
600
+ nk_size_t row_start, nk_size_t row_count);
601
+
602
+ /** @copydoc nk_dots_packed_size_e5m2 */
603
+ NK_PUBLIC nk_size_t nk_dots_packed_size_e5m2_sapphireamx(nk_size_t width, nk_size_t depth);
604
+ /** @copydoc nk_dots_pack_e5m2 */
605
+ NK_PUBLIC void nk_dots_pack_e5m2_sapphireamx(nk_e5m2_t const *b, nk_size_t width, nk_size_t depth, nk_size_t b_stride,
606
+ void *b_packed);
607
+ /** @copydoc nk_dots_packed_e5m2 */
608
+ NK_PUBLIC void nk_dots_packed_e5m2_sapphireamx(nk_e5m2_t const *a, void const *b_packed, nk_f32_t *c, nk_size_t height,
609
+ nk_size_t width, nk_size_t depth, nk_size_t a_stride,
610
+ nk_size_t c_stride);
611
+ /** @copydoc nk_dots_symmetric_e5m2 */
612
+ NK_PUBLIC void nk_dots_symmetric_e5m2_sapphireamx(nk_e5m2_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
613
+ nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
614
+ nk_size_t row_start, nk_size_t row_count);
615
+ /** @copydoc nk_dots_packed_size_e2m3 */
616
+ NK_PUBLIC nk_size_t nk_dots_packed_size_e2m3_sapphireamx(nk_size_t width, nk_size_t depth);
617
+ /** @copydoc nk_dots_pack_e2m3 */
618
+ NK_PUBLIC void nk_dots_pack_e2m3_sapphireamx(nk_e2m3_t const *b, nk_size_t width, nk_size_t depth, nk_size_t b_stride,
619
+ void *b_packed);
620
+ /** @copydoc nk_dots_packed_e2m3 */
621
+ NK_PUBLIC void nk_dots_packed_e2m3_sapphireamx(nk_e2m3_t const *a, void const *b_packed, nk_f32_t *c, nk_size_t height,
622
+ nk_size_t width, nk_size_t depth, nk_size_t a_stride,
623
+ nk_size_t c_stride);
624
+ /** @copydoc nk_dots_symmetric_e2m3 */
625
+ NK_PUBLIC void nk_dots_symmetric_e2m3_sapphireamx(nk_e2m3_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
626
+ nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
627
+ nk_size_t row_start, nk_size_t row_count);
628
+
629
+ /** @copydoc nk_dots_packed_size_e3m2 */
630
+ NK_PUBLIC nk_size_t nk_dots_packed_size_e3m2_sapphireamx(nk_size_t width, nk_size_t depth);
631
+ /** @copydoc nk_dots_pack_e3m2 */
632
+ NK_PUBLIC void nk_dots_pack_e3m2_sapphireamx(nk_e3m2_t const *b, nk_size_t width, nk_size_t depth, nk_size_t b_stride,
633
+ void *b_packed);
634
+ /** @copydoc nk_dots_packed_e3m2 */
635
+ NK_PUBLIC void nk_dots_packed_e3m2_sapphireamx(nk_e3m2_t const *a, void const *b_packed, nk_f32_t *c, nk_size_t height,
636
+ nk_size_t width, nk_size_t depth, nk_size_t a_stride,
637
+ nk_size_t c_stride);
638
+ /** @copydoc nk_dots_symmetric_e3m2 */
639
+ NK_PUBLIC void nk_dots_symmetric_e3m2_sapphireamx(nk_e3m2_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
640
+ nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
641
+ nk_size_t row_start, nk_size_t row_count);
642
+
643
+ /** @copydoc nk_dots_packed_size_u8 */
644
+ NK_PUBLIC nk_size_t nk_dots_packed_size_u8_sapphireamx(nk_size_t width, nk_size_t depth);
645
+ /** @copydoc nk_dots_pack_u8 */
646
+ NK_PUBLIC void nk_dots_pack_u8_sapphireamx(nk_u8_t const *b, nk_size_t width, nk_size_t depth, nk_size_t b_stride,
647
+ void *b_packed);
648
+ /** @copydoc nk_dots_packed_u8 */
649
+ NK_PUBLIC void nk_dots_packed_u8_sapphireamx(nk_u8_t const *a, void const *b_packed, nk_u32_t *c, nk_size_t height,
650
+ nk_size_t width, nk_size_t depth, nk_size_t a_stride, nk_size_t c_stride);
651
+ /** @copydoc nk_dots_symmetric_u8 */
652
+ NK_PUBLIC void nk_dots_symmetric_u8_sapphireamx(nk_u8_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
653
+ nk_size_t stride, nk_u32_t *result, nk_size_t result_stride,
654
+ nk_size_t row_start, nk_size_t row_count);
655
+ #endif // NK_TARGET_SAPPHIREAMX
656
+
657
+ /* ARM SME backends using Scalable Matrix Extension.
658
+ * SME provides ZA tile registers for outer product operations.
659
+ * F16/BF16/I8/U8/E4M3 use ZA32 tiles; F32/F64 use ZA64 tiles (FEAT_SME_F64F64) and are declared under NK_TARGET_SMEF64.
660
+ */
661
+ #if NK_TARGET_SME
662
+ /** @copydoc nk_dots_packed_size_f16 */
663
+ NK_PUBLIC nk_size_t nk_dots_packed_size_f16_sme(nk_size_t width, nk_size_t depth);
664
+ /** @copydoc nk_dots_pack_f16 */
665
+ NK_PUBLIC void nk_dots_pack_f16_sme(nk_f16_t const *b, nk_size_t width, nk_size_t depth, nk_size_t b_stride,
666
+ void *b_packed);
667
+ /** @copydoc nk_dots_packed_f16 */
668
+ NK_PUBLIC void nk_dots_packed_f16_sme(nk_f16_t const *a, void const *b_packed, nk_f32_t *c, nk_size_t height,
669
+ nk_size_t width, nk_size_t depth, nk_size_t a_stride, nk_size_t c_stride);
670
+ /** @copydoc nk_dots_symmetric_f16 */
671
+ NK_PUBLIC void nk_dots_symmetric_f16_sme(nk_f16_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
672
+ nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
673
+ nk_size_t row_start, nk_size_t row_count);
674
+
675
+ /** @copydoc nk_dots_packed_size_bf16 */
676
+ NK_PUBLIC nk_size_t nk_dots_packed_size_bf16_sme(nk_size_t width, nk_size_t depth);
677
+ /** @copydoc nk_dots_pack_bf16 */
678
+ NK_PUBLIC void nk_dots_pack_bf16_sme(nk_bf16_t const *b, nk_size_t width, nk_size_t depth, nk_size_t b_stride,
679
+ void *b_packed);
680
+ /** @copydoc nk_dots_packed_bf16 */
681
+ NK_PUBLIC void nk_dots_packed_bf16_sme(nk_bf16_t const *a, void const *b_packed, nk_f32_t *c, nk_size_t height,
682
+ nk_size_t width, nk_size_t depth, nk_size_t a_stride, nk_size_t c_stride);
683
+ /** @copydoc nk_dots_symmetric_bf16 */
684
+ NK_PUBLIC void nk_dots_symmetric_bf16_sme(nk_bf16_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
685
+ nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
686
+ nk_size_t row_start, nk_size_t row_count);
687
+
688
+ /** @copydoc nk_dots_packed_size_i8 */
689
+ NK_PUBLIC nk_size_t nk_dots_packed_size_i8_sme(nk_size_t width, nk_size_t depth);
690
+ /** @copydoc nk_dots_pack_i8 */
691
+ NK_PUBLIC void nk_dots_pack_i8_sme(nk_i8_t const *b, nk_size_t width, nk_size_t depth, nk_size_t b_stride,
692
+ void *b_packed);
693
+ /** @copydoc nk_dots_packed_i8 */
694
+ NK_PUBLIC void nk_dots_packed_i8_sme(nk_i8_t const *a, void const *b_packed, nk_i32_t *c, nk_size_t height,
695
+ nk_size_t width, nk_size_t depth, nk_size_t a_stride, nk_size_t c_stride);
696
+ /** @copydoc nk_dots_symmetric_i8 */
697
+ NK_PUBLIC void nk_dots_symmetric_i8_sme(nk_i8_t const *vectors, nk_size_t n_vectors, nk_size_t depth, nk_size_t stride,
698
+ nk_i32_t *result, nk_size_t result_stride, nk_size_t row_start,
699
+ nk_size_t row_count);
700
+
701
+ /** @copydoc nk_dots_packed_size_u8 */
702
+ NK_PUBLIC nk_size_t nk_dots_packed_size_u8_sme(nk_size_t width, nk_size_t depth);
703
+ /** @copydoc nk_dots_pack_u8 */
704
+ NK_PUBLIC void nk_dots_pack_u8_sme(nk_u8_t const *b, nk_size_t width, nk_size_t depth, nk_size_t b_stride,
705
+ void *b_packed);
706
+ /** @copydoc nk_dots_packed_u8 */
707
+ NK_PUBLIC void nk_dots_packed_u8_sme(nk_u8_t const *a, void const *b_packed, nk_u32_t *c, nk_size_t height,
708
+ nk_size_t width, nk_size_t depth, nk_size_t a_stride, nk_size_t c_stride);
709
+ /** @copydoc nk_dots_symmetric_u8 */
710
+ NK_PUBLIC void nk_dots_symmetric_u8_sme(nk_u8_t const *vectors, nk_size_t n_vectors, nk_size_t depth, nk_size_t stride,
711
+ nk_u32_t *result, nk_size_t result_stride, nk_size_t row_start,
712
+ nk_size_t row_count);
713
+
714
+ /** @copydoc nk_dots_packed_size_e4m3 */
715
+ NK_PUBLIC nk_size_t nk_dots_packed_size_e4m3_sme(nk_size_t width, nk_size_t depth);
716
+ /** @copydoc nk_dots_pack_e4m3 */
717
+ NK_PUBLIC void nk_dots_pack_e4m3_sme(nk_e4m3_t const *b, nk_size_t width, nk_size_t depth, nk_size_t b_stride,
718
+ void *b_packed);
719
+ /** @copydoc nk_dots_packed_e4m3 */
720
+ NK_PUBLIC void nk_dots_packed_e4m3_sme(nk_e4m3_t const *a, void const *b_packed, nk_f32_t *c, nk_size_t height,
721
+ nk_size_t width, nk_size_t depth, nk_size_t a_stride, nk_size_t c_stride);
722
+ /** @copydoc nk_dots_symmetric_e4m3 */
723
+ NK_PUBLIC void nk_dots_symmetric_e4m3_sme(nk_e4m3_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
724
+ nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
725
+ nk_size_t row_start, nk_size_t row_count);
726
+
727
+ /** @copydoc nk_dots_packed_size_e5m2 */
728
+ NK_PUBLIC nk_size_t nk_dots_packed_size_e5m2_sme(nk_size_t width, nk_size_t depth);
729
+ /** @copydoc nk_dots_pack_e5m2 */
730
+ NK_PUBLIC void nk_dots_pack_e5m2_sme(nk_e5m2_t const *b, nk_size_t width, nk_size_t depth, nk_size_t b_stride,
731
+ void *b_packed);
732
+ /** @copydoc nk_dots_packed_e5m2 */
733
+ NK_PUBLIC void nk_dots_packed_e5m2_sme(nk_e5m2_t const *a, void const *b_packed, nk_f32_t *c, nk_size_t height,
734
+ nk_size_t width, nk_size_t depth, nk_size_t a_stride, nk_size_t c_stride);
735
+ /** @copydoc nk_dots_symmetric_e5m2 */
736
+ NK_PUBLIC void nk_dots_symmetric_e5m2_sme(nk_e5m2_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
737
+ nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
738
+ nk_size_t row_start, nk_size_t row_count);
739
+
740
+ /** @copydoc nk_dots_packed_size_u4 */
741
+ NK_PUBLIC nk_size_t nk_dots_packed_size_u4_sme(nk_size_t width, nk_size_t depth);
742
+ /** @copydoc nk_dots_pack_u4 */
743
+ NK_PUBLIC void nk_dots_pack_u4_sme(nk_u4x2_t const *b, nk_size_t width, nk_size_t depth, nk_size_t b_stride,
744
+ void *b_packed);
745
+ /** @copydoc nk_dots_packed_u4 */
746
+ NK_PUBLIC void nk_dots_packed_u4_sme(nk_u4x2_t const *a, void const *b_packed, nk_u32_t *c, nk_size_t height,
747
+ nk_size_t width, nk_size_t depth, nk_size_t a_stride, nk_size_t c_stride);
748
+ /** @copydoc nk_dots_symmetric_u4 */
749
+ NK_PUBLIC void nk_dots_symmetric_u4_sme(nk_u4x2_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
750
+ nk_size_t stride, nk_u32_t *result, nk_size_t result_stride,
751
+ nk_size_t row_start, nk_size_t row_count);
752
+
753
+ /** @copydoc nk_dots_packed_size_i4 */
754
+ NK_PUBLIC nk_size_t nk_dots_packed_size_i4_sme(nk_size_t width, nk_size_t depth);
755
+ /** @copydoc nk_dots_pack_i4 */
756
+ NK_PUBLIC void nk_dots_pack_i4_sme(nk_i4x2_t const *b, nk_size_t width, nk_size_t depth, nk_size_t b_stride,
757
+ void *b_packed);
758
+ /** @copydoc nk_dots_packed_i4 */
759
+ NK_PUBLIC void nk_dots_packed_i4_sme(nk_i4x2_t const *a, void const *b_packed, nk_i32_t *c, nk_size_t height,
760
+ nk_size_t width, nk_size_t depth, nk_size_t a_stride, nk_size_t c_stride);
761
+ /** @copydoc nk_dots_symmetric_i4 */
762
+ NK_PUBLIC void nk_dots_symmetric_i4_sme(nk_i4x2_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
763
+ nk_size_t stride, nk_i32_t *result, nk_size_t result_stride,
764
+ nk_size_t row_start, nk_size_t row_count);
765
+
766
+ /** @copydoc nk_dots_packed_size_e2m3 */
767
+ NK_PUBLIC nk_size_t nk_dots_packed_size_e2m3_sme(nk_size_t width, nk_size_t depth);
768
+ /** @copydoc nk_dots_pack_e2m3 */
769
+ NK_PUBLIC void nk_dots_pack_e2m3_sme(nk_e2m3_t const *b, nk_size_t width, nk_size_t depth, nk_size_t b_stride,
770
+ void *b_packed);
771
+ /** @copydoc nk_dots_packed_e2m3 */
772
+ NK_PUBLIC void nk_dots_packed_e2m3_sme(nk_e2m3_t const *a, void const *b_packed, nk_f32_t *c, nk_size_t height,
773
+ nk_size_t width, nk_size_t depth, nk_size_t a_stride, nk_size_t c_stride);
774
+ /** @copydoc nk_dots_symmetric_e2m3 */
775
+ NK_PUBLIC void nk_dots_symmetric_e2m3_sme(nk_e2m3_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
776
+ nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
777
+ nk_size_t row_start, nk_size_t row_count);
778
+
779
+ /** @copydoc nk_dots_packed_size_e3m2 */
780
+ NK_PUBLIC nk_size_t nk_dots_packed_size_e3m2_sme(nk_size_t width, nk_size_t depth);
781
+ /** @copydoc nk_dots_pack_e3m2 */
782
+ NK_PUBLIC void nk_dots_pack_e3m2_sme(nk_e3m2_t const *b, nk_size_t width, nk_size_t depth, nk_size_t b_stride,
783
+ void *b_packed);
784
+ /** @copydoc nk_dots_packed_e3m2 */
785
+ NK_PUBLIC void nk_dots_packed_e3m2_sme(nk_e3m2_t const *a, void const *b_packed, nk_f32_t *c, nk_size_t height,
786
+ nk_size_t width, nk_size_t depth, nk_size_t a_stride, nk_size_t c_stride);
787
+ /** @copydoc nk_dots_symmetric_e3m2 */
788
+ NK_PUBLIC void nk_dots_symmetric_e3m2_sme(nk_e3m2_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
789
+ nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
790
+ nk_size_t row_start, nk_size_t row_count);
791
+ #endif // NK_TARGET_SME
792
+
793
+ /* ARM SME with integer-accumulating binary outer products.
794
+ * Used for packed 1-bit dot products backed by ZA32.
795
+ */
796
+ #if NK_TARGET_SMEBI32
797
+ /** @copydoc nk_dots_packed_size_u1 */
798
+ NK_PUBLIC nk_size_t nk_dots_packed_size_u1_smebi32(nk_size_t width, nk_size_t depth);
799
+ /** @copydoc nk_dots_pack_u1 */
800
+ NK_PUBLIC void nk_dots_pack_u1_smebi32(nk_u1x8_t const *b, nk_size_t width, nk_size_t depth, nk_size_t b_stride,
801
+ void *b_packed);
802
+ /** @copydoc nk_dots_packed_u1 */
803
+ NK_PUBLIC void nk_dots_packed_u1_smebi32(nk_u1x8_t const *a, void const *b_packed, nk_u32_t *c, nk_size_t height,
804
+ nk_size_t width, nk_size_t depth, nk_size_t a_stride, nk_size_t c_stride);
805
+ /** @copydoc nk_dots_symmetric_u1 */
806
+ NK_PUBLIC void nk_dots_symmetric_u1_smebi32(nk_u1x8_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
807
+ nk_size_t stride, nk_u32_t *result, nk_size_t result_stride,
808
+ nk_size_t row_start, nk_size_t row_count);
809
+ #endif // NK_TARGET_SMEBI32
810
+
811
+ /* ARM SME with FEAT_SME_F64F64 (F32/F64 with F64 accumulators).
812
+ * Requires Apple M4 or equivalent with F64 outer product support.
813
+ */
814
+ #if NK_TARGET_SMEF64
815
+ /** @copydoc nk_dots_packed_size_f32 */
816
+ NK_PUBLIC nk_size_t nk_dots_packed_size_f32_smef64(nk_size_t width, nk_size_t depth);
817
+ /** @copydoc nk_dots_pack_f32 */
818
+ NK_PUBLIC void nk_dots_pack_f32_smef64(nk_f32_t const *b, nk_size_t width, nk_size_t depth, nk_size_t b_stride,
819
+ void *b_packed);
820
+ /** @copydoc nk_dots_packed_f32 */
821
+ NK_PUBLIC void nk_dots_packed_f32_smef64(nk_f32_t const *a, void const *b_packed, nk_f64_t *c, nk_size_t height,
822
+ nk_size_t width, nk_size_t depth, nk_size_t a_stride, nk_size_t c_stride);
823
+ /** @copydoc nk_dots_symmetric_f32 */
824
+ NK_PUBLIC void nk_dots_symmetric_f32_smef64(nk_f32_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
825
+ nk_size_t stride, nk_f64_t *result, nk_size_t result_stride,
826
+ nk_size_t row_start, nk_size_t row_count);
827
+
828
+ /** @copydoc nk_dots_packed_size_f64 */
829
+ NK_PUBLIC nk_size_t nk_dots_packed_size_f64_smef64(nk_size_t width, nk_size_t depth);
830
+ /** @copydoc nk_dots_pack_f64 */
831
+ NK_PUBLIC void nk_dots_pack_f64_smef64(nk_f64_t const *b, nk_size_t width, nk_size_t depth, nk_size_t b_stride,
832
+ void *b_packed);
833
+ /** @copydoc nk_dots_packed_f64 */
834
+ NK_PUBLIC void nk_dots_packed_f64_smef64(nk_f64_t const *a, void const *b_packed, nk_f64_t *c, nk_size_t height,
835
+ nk_size_t width, nk_size_t depth, nk_size_t a_stride, nk_size_t c_stride);
836
+ /** @copydoc nk_dots_symmetric_f64 */
837
+ NK_PUBLIC void nk_dots_symmetric_f64_smef64(nk_f64_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
838
+ nk_size_t stride, nk_f64_t *result, nk_size_t result_stride,
839
+ nk_size_t row_start, nk_size_t row_count);
840
+ #endif // NK_TARGET_SMEF64
841
+
842
+ /* Haswell backends using AVX2 (Intel Core 4th gen).
843
+ * Supports F32/F64 via FMA, F16/BF16/FP8 via software emulation, I8/U8 via VPMADDUBSW+VPADDD; U1/I4/U4 kernels are also declared.
844
+ */
845
+ #if NK_TARGET_HASWELL
846
+ /** @copydoc nk_dots_packed_size_f32 */
847
+ NK_PUBLIC nk_size_t nk_dots_packed_size_f32_haswell(nk_size_t width, nk_size_t depth);
848
+ /** @copydoc nk_dots_pack_f32 */
849
+ NK_PUBLIC void nk_dots_pack_f32_haswell(nk_f32_t const *b, nk_size_t width, nk_size_t depth, nk_size_t b_stride,
850
+ void *b_packed);
851
+ /** @copydoc nk_dots_packed_f32 */
852
+ NK_PUBLIC void nk_dots_packed_f32_haswell(nk_f32_t const *a, void const *b_packed, nk_f64_t *c, nk_size_t height,
853
+ nk_size_t width, nk_size_t depth, nk_size_t a_stride, nk_size_t c_stride);
854
+ /** @copydoc nk_dots_symmetric_f32 */
855
+ NK_PUBLIC void nk_dots_symmetric_f32_haswell(nk_f32_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
856
+ nk_size_t stride, nk_f64_t *result, nk_size_t result_stride,
857
+ nk_size_t row_start, nk_size_t row_count);
858
+ /** @copydoc nk_dots_packed_size_f64 */
859
+ NK_PUBLIC nk_size_t nk_dots_packed_size_f64_haswell(nk_size_t width, nk_size_t depth);
860
+ /** @copydoc nk_dots_pack_f64 */
861
+ NK_PUBLIC void nk_dots_pack_f64_haswell(nk_f64_t const *b, nk_size_t width, nk_size_t depth, nk_size_t b_stride,
862
+ void *b_packed);
863
+ /** @copydoc nk_dots_packed_f64 */
864
+ NK_PUBLIC void nk_dots_packed_f64_haswell(nk_f64_t const *a, void const *b_packed, nk_f64_t *c, nk_size_t height,
865
+ nk_size_t width, nk_size_t depth, nk_size_t a_stride, nk_size_t c_stride);
866
+ /** @copydoc nk_dots_symmetric_f64 */
867
+ NK_PUBLIC void nk_dots_symmetric_f64_haswell(nk_f64_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
868
+ nk_size_t stride, nk_f64_t *result, nk_size_t result_stride,
869
+ nk_size_t row_start, nk_size_t row_count);
870
+ /** @copydoc nk_dots_packed_size_f16 */
871
+ NK_PUBLIC nk_size_t nk_dots_packed_size_f16_haswell(nk_size_t width, nk_size_t depth);
872
+ /** @copydoc nk_dots_pack_f16 */
873
+ NK_PUBLIC void nk_dots_pack_f16_haswell(nk_f16_t const *b, nk_size_t width, nk_size_t depth, nk_size_t b_stride,
874
+ void *b_packed);
875
+ /** @copydoc nk_dots_packed_f16 */
876
+ NK_PUBLIC void nk_dots_packed_f16_haswell(nk_f16_t const *a, void const *b_packed, nk_f32_t *c, nk_size_t height,
877
+ nk_size_t width, nk_size_t depth, nk_size_t a_stride, nk_size_t c_stride);
878
+ /** @copydoc nk_dots_symmetric_f16 */
879
+ NK_PUBLIC void nk_dots_symmetric_f16_haswell(nk_f16_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
880
+ nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
881
+ nk_size_t row_start, nk_size_t row_count);
882
+ /** @copydoc nk_dots_packed_size_bf16 */
883
+ NK_PUBLIC nk_size_t nk_dots_packed_size_bf16_haswell(nk_size_t width, nk_size_t depth);
884
+ /** @copydoc nk_dots_pack_bf16 */
885
+ NK_PUBLIC void nk_dots_pack_bf16_haswell(nk_bf16_t const *b, nk_size_t width, nk_size_t depth, nk_size_t b_stride,
886
+ void *b_packed);
887
+ /** @copydoc nk_dots_packed_bf16 */
888
+ NK_PUBLIC void nk_dots_packed_bf16_haswell(nk_bf16_t const *a, void const *b_packed, nk_f32_t *c, nk_size_t height,
889
+ nk_size_t width, nk_size_t depth, nk_size_t a_stride, nk_size_t c_stride);
890
+ /** @copydoc nk_dots_symmetric_bf16 */
891
+ NK_PUBLIC void nk_dots_symmetric_bf16_haswell(nk_bf16_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
892
+ nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
893
+ nk_size_t row_start, nk_size_t row_count);
894
+ /** @copydoc nk_dots_packed_size_e4m3 */
895
+ NK_PUBLIC nk_size_t nk_dots_packed_size_e4m3_haswell(nk_size_t width, nk_size_t depth);
896
+ /** @copydoc nk_dots_pack_e4m3 */
897
+ NK_PUBLIC void nk_dots_pack_e4m3_haswell(nk_e4m3_t const *b, nk_size_t width, nk_size_t depth, nk_size_t b_stride,
898
+ void *b_packed);
899
+ /** @copydoc nk_dots_packed_e4m3 */
900
+ NK_PUBLIC void nk_dots_packed_e4m3_haswell(nk_e4m3_t const *a, void const *b_packed, nk_f32_t *c, nk_size_t height,
901
+ nk_size_t width, nk_size_t depth, nk_size_t a_stride, nk_size_t c_stride);
902
+ /** @copydoc nk_dots_symmetric_e4m3 */
903
+ NK_PUBLIC void nk_dots_symmetric_e4m3_haswell(nk_e4m3_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
904
+ nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
905
+ nk_size_t row_start, nk_size_t row_count);
906
+ /** @copydoc nk_dots_packed_size_e5m2 */
907
+ NK_PUBLIC nk_size_t nk_dots_packed_size_e5m2_haswell(nk_size_t width, nk_size_t depth);
908
+ /** @copydoc nk_dots_pack_e5m2 */
909
+ NK_PUBLIC void nk_dots_pack_e5m2_haswell(nk_e5m2_t const *b, nk_size_t width, nk_size_t depth, nk_size_t b_stride,
910
+ void *b_packed);
911
+ /** @copydoc nk_dots_packed_e5m2 */
912
+ NK_PUBLIC void nk_dots_packed_e5m2_haswell(nk_e5m2_t const *a, void const *b_packed, nk_f32_t *c, nk_size_t height,
913
+ nk_size_t width, nk_size_t depth, nk_size_t a_stride, nk_size_t c_stride);
914
+ /** @copydoc nk_dots_symmetric_e5m2 */
915
+ NK_PUBLIC void nk_dots_symmetric_e5m2_haswell(nk_e5m2_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
916
+ nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
917
+ nk_size_t row_start, nk_size_t row_count);
918
+ /** @copydoc nk_dots_packed_size_e2m3 */
919
+ NK_PUBLIC nk_size_t nk_dots_packed_size_e2m3_haswell(nk_size_t width, nk_size_t depth);
920
+ /** @copydoc nk_dots_pack_e2m3 */
921
+ NK_PUBLIC void nk_dots_pack_e2m3_haswell(nk_e2m3_t const *b, nk_size_t width, nk_size_t depth, nk_size_t b_stride,
922
+ void *b_packed);
923
+ /** @copydoc nk_dots_packed_e2m3 */
924
+ NK_PUBLIC void nk_dots_packed_e2m3_haswell(nk_e2m3_t const *a, void const *b_packed, nk_f32_t *c, nk_size_t height,
925
+ nk_size_t width, nk_size_t depth, nk_size_t a_stride, nk_size_t c_stride);
926
+ /** @copydoc nk_dots_symmetric_e2m3 */
927
+ NK_PUBLIC void nk_dots_symmetric_e2m3_haswell(nk_e2m3_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
928
+ nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
929
+ nk_size_t row_start, nk_size_t row_count);
930
+ /** @copydoc nk_dots_packed_size_e3m2 */
931
+ NK_PUBLIC nk_size_t nk_dots_packed_size_e3m2_haswell(nk_size_t width, nk_size_t depth);
932
+ /** @copydoc nk_dots_pack_e3m2 */
933
+ NK_PUBLIC void nk_dots_pack_e3m2_haswell(nk_e3m2_t const *b, nk_size_t width, nk_size_t depth, nk_size_t b_stride,
934
+ void *b_packed);
935
+ /** @copydoc nk_dots_packed_e3m2 */
936
+ NK_PUBLIC void nk_dots_packed_e3m2_haswell(nk_e3m2_t const *a, void const *b_packed, nk_f32_t *c, nk_size_t height,
937
+ nk_size_t width, nk_size_t depth, nk_size_t a_stride, nk_size_t c_stride);
938
+ /** @copydoc nk_dots_symmetric_e3m2 */
939
+ NK_PUBLIC void nk_dots_symmetric_e3m2_haswell(nk_e3m2_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
940
+ nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
941
+ nk_size_t row_start, nk_size_t row_count);
942
+ /** @copydoc nk_dots_packed_size_i8 */
943
+ NK_PUBLIC nk_size_t nk_dots_packed_size_i8_haswell(nk_size_t width, nk_size_t depth);
944
+ /** @copydoc nk_dots_pack_i8 */
945
+ NK_PUBLIC void nk_dots_pack_i8_haswell(nk_i8_t const *b, nk_size_t width, nk_size_t depth, nk_size_t b_stride,
946
+ void *b_packed);
947
+ /** @copydoc nk_dots_packed_i8 */
948
+ NK_PUBLIC void nk_dots_packed_i8_haswell(nk_i8_t const *a, void const *b_packed, nk_i32_t *c, nk_size_t height,
949
+ nk_size_t width, nk_size_t depth, nk_size_t a_stride, nk_size_t c_stride);
950
+ /** @copydoc nk_dots_symmetric_i8 */
951
+ NK_PUBLIC void nk_dots_symmetric_i8_haswell(nk_i8_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
952
+ nk_size_t stride, nk_i32_t *result, nk_size_t result_stride,
953
+ nk_size_t row_start, nk_size_t row_count);
954
+ /** @copydoc nk_dots_packed_size_u8 */
955
+ NK_PUBLIC nk_size_t nk_dots_packed_size_u8_haswell(nk_size_t width, nk_size_t depth);
956
+ /** @copydoc nk_dots_pack_u8 */
957
+ NK_PUBLIC void nk_dots_pack_u8_haswell(nk_u8_t const *b, nk_size_t width, nk_size_t depth, nk_size_t b_stride,
958
+ void *b_packed);
959
+ /** @copydoc nk_dots_packed_u8 */
960
+ NK_PUBLIC void nk_dots_packed_u8_haswell(nk_u8_t const *a, void const *b_packed, nk_u32_t *c, nk_size_t height,
961
+ nk_size_t width, nk_size_t depth, nk_size_t a_stride, nk_size_t c_stride);
962
+ /** @copydoc nk_dots_symmetric_u8 */
963
+ NK_PUBLIC void nk_dots_symmetric_u8_haswell(nk_u8_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
964
+ nk_size_t stride, nk_u32_t *result, nk_size_t result_stride,
965
+ nk_size_t row_start, nk_size_t row_count);
966
+ /** @copydoc nk_dots_packed_size_u1 */
967
+ NK_PUBLIC nk_size_t nk_dots_packed_size_u1_haswell(nk_size_t width, nk_size_t depth);
968
+ /** @copydoc nk_dots_pack_u1 */
969
+ NK_PUBLIC void nk_dots_pack_u1_haswell(nk_u1x8_t const *b, nk_size_t width, nk_size_t depth, nk_size_t b_stride,
970
+ void *b_packed);
971
+ /** @copydoc nk_dots_packed_u1 */
972
+ NK_PUBLIC void nk_dots_packed_u1_haswell(nk_u1x8_t const *a, void const *b_packed, nk_u32_t *c, nk_size_t height,
973
+ nk_size_t width, nk_size_t depth, nk_size_t a_stride, nk_size_t c_stride);
974
+ /** @copydoc nk_dots_symmetric_u1 */
975
+ NK_PUBLIC void nk_dots_symmetric_u1_haswell(nk_u1x8_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
976
+ nk_size_t stride, nk_u32_t *result, nk_size_t result_stride,
977
+ nk_size_t row_start, nk_size_t row_count);
978
+ /** @copydoc nk_dots_packed_size_i4 */
979
+ NK_PUBLIC nk_size_t nk_dots_packed_size_i4_haswell(nk_size_t width, nk_size_t depth);
980
+ /** @copydoc nk_dots_pack_i4 */
981
+ NK_PUBLIC void nk_dots_pack_i4_haswell(nk_i4x2_t const *b, nk_size_t width, nk_size_t depth, nk_size_t b_stride,
982
+ void *b_packed);
983
+ /** @copydoc nk_dots_packed_i4 */
984
+ NK_PUBLIC void nk_dots_packed_i4_haswell(nk_i4x2_t const *a, void const *b_packed, nk_i32_t *c, nk_size_t height,
985
+ nk_size_t width, nk_size_t depth, nk_size_t a_stride, nk_size_t c_stride);
986
+ /** @copydoc nk_dots_symmetric_i4 */
987
+ NK_PUBLIC void nk_dots_symmetric_i4_haswell(nk_i4x2_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
988
+ nk_size_t stride, nk_i32_t *result, nk_size_t result_stride,
989
+ nk_size_t row_start, nk_size_t row_count);
990
+ /** @copydoc nk_dots_packed_size_u4 */
991
+ NK_PUBLIC nk_size_t nk_dots_packed_size_u4_haswell(nk_size_t width, nk_size_t depth);
992
+ /** @copydoc nk_dots_pack_u4 */
993
+ NK_PUBLIC void nk_dots_pack_u4_haswell(nk_u4x2_t const *b, nk_size_t width, nk_size_t depth, nk_size_t b_stride,
994
+ void *b_packed);
995
+ /** @copydoc nk_dots_packed_u4 */
996
+ NK_PUBLIC void nk_dots_packed_u4_haswell(nk_u4x2_t const *a, void const *b_packed, nk_u32_t *c, nk_size_t height,
997
+ nk_size_t width, nk_size_t depth, nk_size_t a_stride, nk_size_t c_stride);
998
+ /** @copydoc nk_dots_symmetric_u4 */
999
+ NK_PUBLIC void nk_dots_symmetric_u4_haswell(nk_u4x2_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
1000
+ nk_size_t stride, nk_u32_t *result, nk_size_t result_stride,
1001
+ nk_size_t row_start, nk_size_t row_count);
1002
+ #endif // NK_TARGET_HASWELL
1003
+
1004
+ /* Skylake backends using AVX-512 (Skylake-X / Xeon Scalable and newer).
1005
+ * Provides 512-bit vectors (16× f32, 8× f64), supporting F32/F64/F16/BF16/FP8 with FMA.
1006
+ */
1007
+ #if NK_TARGET_SKYLAKE
1008
+ /** @copydoc nk_dots_packed_size_f64 */
1009
+ NK_PUBLIC nk_size_t nk_dots_packed_size_f64_skylake(nk_size_t width, nk_size_t depth);
1010
+ /** @copydoc nk_dots_pack_f64 */
1011
+ NK_PUBLIC void nk_dots_pack_f64_skylake(nk_f64_t const *b, nk_size_t width, nk_size_t depth, nk_size_t b_stride,
1012
+ void *b_packed);
1013
+ /** @copydoc nk_dots_packed_f64 */
1014
+ NK_PUBLIC void nk_dots_packed_f64_skylake(nk_f64_t const *a, void const *b_packed, nk_f64_t *c, nk_size_t height,
1015
+ nk_size_t width, nk_size_t depth, nk_size_t a_stride, nk_size_t c_stride);
1016
+ /** @copydoc nk_dots_symmetric_f64 */
1017
+ NK_PUBLIC void nk_dots_symmetric_f64_skylake(nk_f64_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
1018
+ nk_size_t stride, nk_f64_t *result, nk_size_t result_stride,
1019
+ nk_size_t row_start, nk_size_t row_count);
1020
+ /** @copydoc nk_dots_packed_size_f32 */
1021
+ NK_PUBLIC nk_size_t nk_dots_packed_size_f32_skylake(nk_size_t width, nk_size_t depth);
1022
+ /** @copydoc nk_dots_pack_f32 */
1023
+ NK_PUBLIC void nk_dots_pack_f32_skylake(nk_f32_t const *b, nk_size_t width, nk_size_t depth, nk_size_t b_stride,
1024
+ void *b_packed);
1025
+ /** @copydoc nk_dots_packed_f32 */
1026
+ NK_PUBLIC void nk_dots_packed_f32_skylake(nk_f32_t const *a, void const *b_packed, nk_f64_t *c, nk_size_t height,
1027
+ nk_size_t width, nk_size_t depth, nk_size_t a_stride, nk_size_t c_stride);
1028
+ /** @copydoc nk_dots_symmetric_f32 */
1029
+ NK_PUBLIC void nk_dots_symmetric_f32_skylake(nk_f32_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
1030
+ nk_size_t stride, nk_f64_t *result, nk_size_t result_stride,
1031
+ nk_size_t row_start, nk_size_t row_count);
1032
+ /** @copydoc nk_dots_packed_size_bf16 */
1033
+ NK_PUBLIC nk_size_t nk_dots_packed_size_bf16_skylake(nk_size_t width, nk_size_t depth);
1034
+ /** @copydoc nk_dots_pack_bf16 */
1035
+ NK_PUBLIC void nk_dots_pack_bf16_skylake(nk_bf16_t const *b, nk_size_t width, nk_size_t depth, nk_size_t b_stride,
1036
+ void *b_packed);
1037
+ /** @copydoc nk_dots_packed_bf16 */
1038
+ NK_PUBLIC void nk_dots_packed_bf16_skylake(nk_bf16_t const *a, void const *b_packed, nk_f32_t *c, nk_size_t height,
1039
+ nk_size_t width, nk_size_t depth, nk_size_t a_stride, nk_size_t c_stride);
1040
+ /** @copydoc nk_dots_symmetric_bf16 */
1041
+ NK_PUBLIC void nk_dots_symmetric_bf16_skylake(nk_bf16_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
1042
+ nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
1043
+ nk_size_t row_start, nk_size_t row_count);
1044
+ /** @copydoc nk_dots_packed_size_f16 */
1045
+ NK_PUBLIC nk_size_t nk_dots_packed_size_f16_skylake(nk_size_t width, nk_size_t depth);
1046
+ /** @copydoc nk_dots_pack_f16 */
1047
+ NK_PUBLIC void nk_dots_pack_f16_skylake(nk_f16_t const *b, nk_size_t width, nk_size_t depth, nk_size_t b_stride,
1048
+ void *b_packed);
1049
+ /** @copydoc nk_dots_packed_f16 */
1050
+ NK_PUBLIC void nk_dots_packed_f16_skylake(nk_f16_t const *a, void const *b_packed, nk_f32_t *c, nk_size_t height,
1051
+ nk_size_t width, nk_size_t depth, nk_size_t a_stride, nk_size_t c_stride);
1052
+ /** @copydoc nk_dots_symmetric_f16 */
1053
+ NK_PUBLIC void nk_dots_symmetric_f16_skylake(nk_f16_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
1054
+ nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
1055
+ nk_size_t row_start, nk_size_t row_count);
1056
+ /** @copydoc nk_dots_packed_size_e4m3 */
1057
+ NK_PUBLIC nk_size_t nk_dots_packed_size_e4m3_skylake(nk_size_t width, nk_size_t depth);
1058
+ /** @copydoc nk_dots_pack_e4m3 */
1059
+ NK_PUBLIC void nk_dots_pack_e4m3_skylake(nk_e4m3_t const *b, nk_size_t width, nk_size_t depth, nk_size_t b_stride,
1060
+ void *b_packed);
1061
+ /** @copydoc nk_dots_packed_e4m3 */
1062
+ NK_PUBLIC void nk_dots_packed_e4m3_skylake(nk_e4m3_t const *a, void const *b_packed, nk_f32_t *c, nk_size_t height,
1063
+ nk_size_t width, nk_size_t depth, nk_size_t a_stride, nk_size_t c_stride);
1064
+ /** @copydoc nk_dots_symmetric_e4m3 */
1065
+ NK_PUBLIC void nk_dots_symmetric_e4m3_skylake(nk_e4m3_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
1066
+ nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
1067
+ nk_size_t row_start, nk_size_t row_count);
1068
+ /** @copydoc nk_dots_packed_size_e5m2 */
1069
+ NK_PUBLIC nk_size_t nk_dots_packed_size_e5m2_skylake(nk_size_t width, nk_size_t depth);
1070
+ /** @copydoc nk_dots_pack_e5m2 */
1071
+ NK_PUBLIC void nk_dots_pack_e5m2_skylake(nk_e5m2_t const *b, nk_size_t width, nk_size_t depth, nk_size_t b_stride,
1072
+ void *b_packed);
1073
+ /** @copydoc nk_dots_packed_e5m2 */
1074
+ NK_PUBLIC void nk_dots_packed_e5m2_skylake(nk_e5m2_t const *a, void const *b_packed, nk_f32_t *c, nk_size_t height,
1075
+ nk_size_t width, nk_size_t depth, nk_size_t a_stride, nk_size_t c_stride);
1076
+ /** @copydoc nk_dots_symmetric_e5m2 */
1077
+ NK_PUBLIC void nk_dots_symmetric_e5m2_skylake(nk_e5m2_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
1078
+ nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
1079
+ nk_size_t row_start, nk_size_t row_count);
1080
+ /** @copydoc nk_dots_packed_size_e2m3 */
1081
+ NK_PUBLIC nk_size_t nk_dots_packed_size_e2m3_skylake(nk_size_t width, nk_size_t depth);
1082
+ /** @copydoc nk_dots_pack_e2m3 */
1083
+ NK_PUBLIC void nk_dots_pack_e2m3_skylake(nk_e2m3_t const *b, nk_size_t width, nk_size_t depth, nk_size_t b_stride,
1084
+ void *b_packed);
1085
+ /** @copydoc nk_dots_packed_e2m3 */
1086
+ NK_PUBLIC void nk_dots_packed_e2m3_skylake(nk_e2m3_t const *a, void const *b_packed, nk_f32_t *c, nk_size_t height,
1087
+ nk_size_t width, nk_size_t depth, nk_size_t a_stride, nk_size_t c_stride);
1088
+ /** @copydoc nk_dots_symmetric_e2m3 */
1089
+ NK_PUBLIC void nk_dots_symmetric_e2m3_skylake(nk_e2m3_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
1090
+ nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
1091
+ nk_size_t row_start, nk_size_t row_count);
1092
+ /** @copydoc nk_dots_packed_size_e3m2 */
1093
+ NK_PUBLIC nk_size_t nk_dots_packed_size_e3m2_skylake(nk_size_t width, nk_size_t depth);
1094
+ /** @copydoc nk_dots_pack_e3m2 */
1095
+ NK_PUBLIC void nk_dots_pack_e3m2_skylake(nk_e3m2_t const *b, nk_size_t width, nk_size_t depth, nk_size_t b_stride,
1096
+ void *b_packed);
1097
+ /** @copydoc nk_dots_packed_e3m2 */
1098
+ NK_PUBLIC void nk_dots_packed_e3m2_skylake(nk_e3m2_t const *a, void const *b_packed, nk_f32_t *c, nk_size_t height,
1099
+ nk_size_t width, nk_size_t depth, nk_size_t a_stride, nk_size_t c_stride);
1100
+ /** @copydoc nk_dots_symmetric_e3m2 */
1101
+ NK_PUBLIC void nk_dots_symmetric_e3m2_skylake(nk_e3m2_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
1102
+ nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
1103
+ nk_size_t row_start, nk_size_t row_count);
1104
+ #endif // NK_TARGET_SKYLAKE
1105
+
1106
+ /* Ice Lake backends using AVX-512 with VNNI (Vector Neural Network Instructions).
1107
+ * Adds VPDPBUSD for I8/U8 and VPDPWSSD for I4/U4 to compute dot products efficiently; U1 kernels are also declared.
1108
+ */
1109
+ #if NK_TARGET_ICELAKE
1110
+ /** @copydoc nk_dots_packed_size_i8 */
1111
+ NK_PUBLIC nk_size_t nk_dots_packed_size_i8_icelake(nk_size_t width, nk_size_t depth);
1112
+ /** @copydoc nk_dots_pack_i8 */
1113
+ NK_PUBLIC void nk_dots_pack_i8_icelake(nk_i8_t const *b, nk_size_t width, nk_size_t depth, nk_size_t b_stride,
1114
+ void *b_packed);
1115
+ /** @copydoc nk_dots_packed_i8 */
1116
+ NK_PUBLIC void nk_dots_packed_i8_icelake(nk_i8_t const *a, void const *b_packed, nk_i32_t *c, nk_size_t height,
1117
+ nk_size_t width, nk_size_t depth, nk_size_t a_stride, nk_size_t c_stride);
1118
+ /** @copydoc nk_dots_symmetric_i8 */
1119
+ NK_PUBLIC void nk_dots_symmetric_i8_icelake(nk_i8_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
1120
+ nk_size_t stride, nk_i32_t *result, nk_size_t result_stride,
1121
+ nk_size_t row_start, nk_size_t row_count);
1122
+ /** @copydoc nk_dots_packed_size_u8 */
1123
+ NK_PUBLIC nk_size_t nk_dots_packed_size_u8_icelake(nk_size_t width, nk_size_t depth);
1124
+ /** @copydoc nk_dots_pack_u8 */
1125
+ NK_PUBLIC void nk_dots_pack_u8_icelake(nk_u8_t const *b, nk_size_t width, nk_size_t depth, nk_size_t b_stride,
1126
+ void *b_packed);
1127
+ /** @copydoc nk_dots_packed_u8 */
1128
+ NK_PUBLIC void nk_dots_packed_u8_icelake(nk_u8_t const *a, void const *b_packed, nk_u32_t *c, nk_size_t height,
1129
+ nk_size_t width, nk_size_t depth, nk_size_t a_stride, nk_size_t c_stride);
1130
+ /** @copydoc nk_dots_symmetric_u8 */
1131
+ NK_PUBLIC void nk_dots_symmetric_u8_icelake(nk_u8_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
1132
+ nk_size_t stride, nk_u32_t *result, nk_size_t result_stride,
1133
+ nk_size_t row_start, nk_size_t row_count);
1134
+ /** @copydoc nk_dots_packed_size_i4 */
1135
+ NK_PUBLIC nk_size_t nk_dots_packed_size_i4_icelake(nk_size_t width, nk_size_t depth);
1136
+ /** @copydoc nk_dots_pack_i4 */
1137
+ NK_PUBLIC void nk_dots_pack_i4_icelake(nk_i4x2_t const *b, nk_size_t width, nk_size_t depth, nk_size_t b_stride,
1138
+ void *b_packed);
1139
+ /** @copydoc nk_dots_packed_i4 */
1140
+ NK_PUBLIC void nk_dots_packed_i4_icelake(nk_i4x2_t const *a, void const *b_packed, nk_i32_t *c, nk_size_t height,
1141
+ nk_size_t width, nk_size_t depth, nk_size_t a_stride, nk_size_t c_stride);
1142
+ /** @copydoc nk_dots_symmetric_i4 */
1143
+ NK_PUBLIC void nk_dots_symmetric_i4_icelake(nk_i4x2_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
1144
+ nk_size_t stride, nk_i32_t *result, nk_size_t result_stride,
1145
+ nk_size_t row_start, nk_size_t row_count);
1146
+ /** @copydoc nk_dots_packed_size_u4 */
1147
+ NK_PUBLIC nk_size_t nk_dots_packed_size_u4_icelake(nk_size_t width, nk_size_t depth);
1148
+ /** @copydoc nk_dots_pack_u4 */
1149
+ NK_PUBLIC void nk_dots_pack_u4_icelake(nk_u4x2_t const *b, nk_size_t width, nk_size_t depth, nk_size_t b_stride,
1150
+ void *b_packed);
1151
+ /** @copydoc nk_dots_packed_u4 */
1152
+ NK_PUBLIC void nk_dots_packed_u4_icelake(nk_u4x2_t const *a, void const *b_packed, nk_u32_t *c, nk_size_t height,
1153
+ nk_size_t width, nk_size_t depth, nk_size_t a_stride, nk_size_t c_stride);
1154
+ /** @copydoc nk_dots_symmetric_u4 */
1155
+ NK_PUBLIC void nk_dots_symmetric_u4_icelake(nk_u4x2_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
1156
+ nk_size_t stride, nk_u32_t *result, nk_size_t result_stride,
1157
+ nk_size_t row_start, nk_size_t row_count);
1158
+ /** @copydoc nk_dots_packed_size_u1 */
1159
+ NK_PUBLIC nk_size_t nk_dots_packed_size_u1_icelake(nk_size_t width, nk_size_t depth);
1160
+ /** @copydoc nk_dots_pack_u1 */
1161
+ NK_PUBLIC void nk_dots_pack_u1_icelake(nk_u1x8_t const *b, nk_size_t width, nk_size_t depth, nk_size_t b_stride,
1162
+ void *b_packed);
1163
+ /** @copydoc nk_dots_packed_u1 */
1164
+ NK_PUBLIC void nk_dots_packed_u1_icelake(nk_u1x8_t const *a, void const *b_packed, nk_u32_t *c, nk_size_t height,
1165
+ nk_size_t width, nk_size_t depth, nk_size_t a_stride, nk_size_t c_stride);
1166
+ /** @copydoc nk_dots_symmetric_u1 */
1167
+ NK_PUBLIC void nk_dots_symmetric_u1_icelake(nk_u1x8_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
1168
+ nk_size_t stride, nk_u32_t *result, nk_size_t result_stride,
1169
+ nk_size_t row_start, nk_size_t row_count);
1170
+ #endif // NK_TARGET_ICELAKE
1171
+
1172
+ /* Alder backends using AMX with TDPB[SU]SD / TDPBF16PS.
1173
+ * Optimized for I8/U8 via AMX integer tiles, E2M3 via AMX BF16 tiles. NOTE(review): Alder Lake lacks AMX (it has AVX-VNNI) — confirm.
1174
+ */
1175
+ #if NK_TARGET_ALDER
1176
+ /** @copydoc nk_dots_packed_size_i8 */
1177
+ NK_PUBLIC nk_size_t nk_dots_packed_size_i8_alder(nk_size_t width, nk_size_t depth);
1178
+ /** @copydoc nk_dots_pack_i8 */
1179
+ NK_PUBLIC void nk_dots_pack_i8_alder(nk_i8_t const *b, nk_size_t width, nk_size_t depth, nk_size_t b_stride,
1180
+ void *b_packed);
1181
+ /** @copydoc nk_dots_packed_i8 */
1182
+ NK_PUBLIC void nk_dots_packed_i8_alder(nk_i8_t const *a, void const *b_packed, nk_i32_t *c, nk_size_t height,
1183
+ nk_size_t width, nk_size_t depth, nk_size_t a_stride, nk_size_t c_stride);
1184
+ /** @copydoc nk_dots_symmetric_i8 */
1185
+ NK_PUBLIC void nk_dots_symmetric_i8_alder(nk_i8_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
1186
+ nk_size_t stride, nk_i32_t *result, nk_size_t result_stride,
1187
+ nk_size_t row_start, nk_size_t row_count);
1188
+ /** @copydoc nk_dots_packed_size_u8 */
1189
+ NK_PUBLIC nk_size_t nk_dots_packed_size_u8_alder(nk_size_t width, nk_size_t depth);
1190
+ /** @copydoc nk_dots_pack_u8 */
1191
+ NK_PUBLIC void nk_dots_pack_u8_alder(nk_u8_t const *b, nk_size_t width, nk_size_t depth, nk_size_t b_stride,
1192
+ void *b_packed);
1193
+ /** @copydoc nk_dots_packed_u8 */
1194
+ NK_PUBLIC void nk_dots_packed_u8_alder(nk_u8_t const *a, void const *b_packed, nk_u32_t *c, nk_size_t height,
1195
+ nk_size_t width, nk_size_t depth, nk_size_t a_stride, nk_size_t c_stride);
1196
+ /** @copydoc nk_dots_symmetric_u8 */
1197
+ NK_PUBLIC void nk_dots_symmetric_u8_alder(nk_u8_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
1198
+ nk_size_t stride, nk_u32_t *result, nk_size_t result_stride,
1199
+ nk_size_t row_start, nk_size_t row_count);
1200
+ /** @copydoc nk_dots_packed_size_e2m3 */
1201
+ NK_PUBLIC nk_size_t nk_dots_packed_size_e2m3_alder(nk_size_t width, nk_size_t depth);
1202
+ /** @copydoc nk_dots_pack_e2m3 */
1203
+ NK_PUBLIC void nk_dots_pack_e2m3_alder(nk_e2m3_t const *b, nk_size_t width, nk_size_t depth, nk_size_t b_stride,
1204
+ void *b_packed);
1205
+ /** @copydoc nk_dots_packed_e2m3 */
1206
+ NK_PUBLIC void nk_dots_packed_e2m3_alder(nk_e2m3_t const *a, void const *b_packed, nk_f32_t *c, nk_size_t height,
1207
+ nk_size_t width, nk_size_t depth, nk_size_t a_stride, nk_size_t c_stride);
1208
+ /** @copydoc nk_dots_symmetric_e2m3 */
1209
+ NK_PUBLIC void nk_dots_symmetric_e2m3_alder(nk_e2m3_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
1210
+ nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
1211
+ nk_size_t row_start, nk_size_t row_count);
1212
+ #endif // NK_TARGET_ALDER
1213
+
1214
+ /* Sierra backends using AVX10.2 with VMPSADBW.
1215
+ * Optimized for I8/U8 via VMPSADBW (multiple packed sums of absolute differences); E2M3 kernels are also declared.
1216
+ */
1217
+ #if NK_TARGET_SIERRA
1218
+ /** @copydoc nk_dots_packed_size_i8 */
1219
+ NK_PUBLIC nk_size_t nk_dots_packed_size_i8_sierra(nk_size_t width, nk_size_t depth);
1220
+ /** @copydoc nk_dots_pack_i8 */
1221
+ NK_PUBLIC void nk_dots_pack_i8_sierra(nk_i8_t const *b, nk_size_t width, nk_size_t depth, nk_size_t b_stride,
1222
+ void *b_packed);
1223
+ /** @copydoc nk_dots_packed_i8 */
1224
+ NK_PUBLIC void nk_dots_packed_i8_sierra(nk_i8_t const *a, void const *b_packed, nk_i32_t *c, nk_size_t height,
1225
+ nk_size_t width, nk_size_t depth, nk_size_t a_stride, nk_size_t c_stride);
1226
+ /** @copydoc nk_dots_symmetric_i8 */
1227
+ NK_PUBLIC void nk_dots_symmetric_i8_sierra(nk_i8_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
1228
+ nk_size_t stride, nk_i32_t *result, nk_size_t result_stride,
1229
+ nk_size_t row_start, nk_size_t row_count);
1230
+ /** @copydoc nk_dots_packed_size_u8 */
1231
+ NK_PUBLIC nk_size_t nk_dots_packed_size_u8_sierra(nk_size_t width, nk_size_t depth);
1232
+ /** @copydoc nk_dots_pack_u8 */
1233
+ NK_PUBLIC void nk_dots_pack_u8_sierra(nk_u8_t const *b, nk_size_t width, nk_size_t depth, nk_size_t b_stride,
1234
+ void *b_packed);
1235
+ /** @copydoc nk_dots_packed_u8 */
1236
+ NK_PUBLIC void nk_dots_packed_u8_sierra(nk_u8_t const *a, void const *b_packed, nk_u32_t *c, nk_size_t height,
1237
+ nk_size_t width, nk_size_t depth, nk_size_t a_stride, nk_size_t c_stride);
1238
+ /** @copydoc nk_dots_symmetric_u8 */
1239
+ NK_PUBLIC void nk_dots_symmetric_u8_sierra(nk_u8_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
1240
+ nk_size_t stride, nk_u32_t *result, nk_size_t result_stride,
1241
+ nk_size_t row_start, nk_size_t row_count);
1242
+ /** @copydoc nk_dots_packed_size_e2m3 */
1243
+ NK_PUBLIC nk_size_t nk_dots_packed_size_e2m3_sierra(nk_size_t width, nk_size_t depth);
1244
+ /** @copydoc nk_dots_pack_e2m3 */
1245
+ NK_PUBLIC void nk_dots_pack_e2m3_sierra(nk_e2m3_t const *b, nk_size_t width, nk_size_t depth, nk_size_t b_stride,
1246
+ void *b_packed);
1247
+ /** @copydoc nk_dots_packed_e2m3 */
1248
+ NK_PUBLIC void nk_dots_packed_e2m3_sierra(nk_e2m3_t const *a, void const *b_packed, nk_f32_t *c, nk_size_t height,
1249
+ nk_size_t width, nk_size_t depth, nk_size_t a_stride, nk_size_t c_stride);
1250
+ /** @copydoc nk_dots_symmetric_e2m3 */
1251
+ NK_PUBLIC void nk_dots_symmetric_e2m3_sierra(nk_e2m3_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
1252
+ nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
1253
+ nk_size_t row_start, nk_size_t row_count);
1254
+ #endif // NK_TARGET_SIERRA
1255
+
1256
+ /* WASM Relaxed SIMD backends using wasm_i32x4_relaxed_dot_i8x16_i7x16_add.
1257
+ * Covers I8/U8/E2M3 (depth_simd_dimensions=16), BF16/F32 (4), F64 (2).
1258
+ */
1259
+ #if NK_TARGET_V128RELAXED
1260
+ /** @copydoc nk_dots_packed_size_i8 */
1261
+ NK_PUBLIC nk_size_t nk_dots_packed_size_i8_v128relaxed(nk_size_t width, nk_size_t depth);
1262
+ /** @copydoc nk_dots_pack_i8 */
1263
+ NK_PUBLIC void nk_dots_pack_i8_v128relaxed(nk_i8_t const *b, nk_size_t width, nk_size_t depth, nk_size_t b_stride,
1264
+ void *b_packed);
1265
+ /** @copydoc nk_dots_packed_i8 */
1266
+ NK_PUBLIC void nk_dots_packed_i8_v128relaxed(nk_i8_t const *a, void const *b_packed, nk_i32_t *c, nk_size_t height,
1267
+ nk_size_t width, nk_size_t depth, nk_size_t a_stride, nk_size_t c_stride);
1268
+ /** @copydoc nk_dots_symmetric_i8 */
1269
+ NK_PUBLIC void nk_dots_symmetric_i8_v128relaxed(nk_i8_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
1270
+ nk_size_t stride, nk_i32_t *result, nk_size_t result_stride,
1271
+ nk_size_t row_start, nk_size_t row_count);
1272
+ /** @copydoc nk_dots_packed_size_u8 */
1273
+ NK_PUBLIC nk_size_t nk_dots_packed_size_u8_v128relaxed(nk_size_t width, nk_size_t depth);
1274
+ /** @copydoc nk_dots_pack_u8 */
1275
+ NK_PUBLIC void nk_dots_pack_u8_v128relaxed(nk_u8_t const *b, nk_size_t width, nk_size_t depth, nk_size_t b_stride,
1276
+ void *b_packed);
1277
+ /** @copydoc nk_dots_packed_u8 */
1278
+ NK_PUBLIC void nk_dots_packed_u8_v128relaxed(nk_u8_t const *a, void const *b_packed, nk_u32_t *c, nk_size_t height,
1279
+ nk_size_t width, nk_size_t depth, nk_size_t a_stride, nk_size_t c_stride);
1280
+ /** @copydoc nk_dots_symmetric_u8 */
1281
+ NK_PUBLIC void nk_dots_symmetric_u8_v128relaxed(nk_u8_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
1282
+ nk_size_t stride, nk_u32_t *result, nk_size_t result_stride,
1283
+ nk_size_t row_start, nk_size_t row_count);
1284
+ /** @copydoc nk_dots_packed_size_e2m3 */
1285
+ NK_PUBLIC nk_size_t nk_dots_packed_size_e2m3_v128relaxed(nk_size_t width, nk_size_t depth);
1286
+ /** @copydoc nk_dots_pack_e2m3 */
1287
+ NK_PUBLIC void nk_dots_pack_e2m3_v128relaxed(nk_e2m3_t const *b, nk_size_t width, nk_size_t depth, nk_size_t b_stride,
1288
+ void *b_packed);
1289
+ /** @copydoc nk_dots_packed_e2m3 */
1290
+ NK_PUBLIC void nk_dots_packed_e2m3_v128relaxed(nk_e2m3_t const *a, void const *b_packed, nk_f32_t *c, nk_size_t height,
1291
+ nk_size_t width, nk_size_t depth, nk_size_t a_stride,
1292
+ nk_size_t c_stride);
1293
+ /** @copydoc nk_dots_symmetric_e2m3 */
1294
+ NK_PUBLIC void nk_dots_symmetric_e2m3_v128relaxed(nk_e2m3_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
1295
+ nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
1296
+ nk_size_t row_start, nk_size_t row_count);
1297
+ /** @copydoc nk_dots_packed_size_bf16 */
1298
+ NK_PUBLIC nk_size_t nk_dots_packed_size_bf16_v128relaxed(nk_size_t width, nk_size_t depth);
1299
+ /** @copydoc nk_dots_pack_bf16 */
1300
+ NK_PUBLIC void nk_dots_pack_bf16_v128relaxed(nk_bf16_t const *b, nk_size_t width, nk_size_t depth, nk_size_t b_stride,
1301
+ void *b_packed);
1302
+ /** @copydoc nk_dots_packed_bf16 */
1303
+ NK_PUBLIC void nk_dots_packed_bf16_v128relaxed(nk_bf16_t const *a, void const *b_packed, nk_f32_t *c, nk_size_t height,
1304
+ nk_size_t width, nk_size_t depth, nk_size_t a_stride,
1305
+ nk_size_t c_stride);
1306
+ /** @copydoc nk_dots_symmetric_bf16 */
1307
+ NK_PUBLIC void nk_dots_symmetric_bf16_v128relaxed(nk_bf16_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
1308
+ nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
1309
+ nk_size_t row_start, nk_size_t row_count);
1310
+ /** @copydoc nk_dots_packed_size_f32 */
1311
+ NK_PUBLIC nk_size_t nk_dots_packed_size_f32_v128relaxed(nk_size_t width, nk_size_t depth);
1312
+ /** @copydoc nk_dots_pack_f32 */
1313
+ NK_PUBLIC void nk_dots_pack_f32_v128relaxed(nk_f32_t const *b, nk_size_t width, nk_size_t depth, nk_size_t b_stride,
1314
+ void *b_packed);
1315
+ /** @copydoc nk_dots_packed_f32 */
1316
+ NK_PUBLIC void nk_dots_packed_f32_v128relaxed(nk_f32_t const *a, void const *b_packed, nk_f64_t *c, nk_size_t height,
1317
+ nk_size_t width, nk_size_t depth, nk_size_t a_stride, nk_size_t c_stride);
1318
+ /** @copydoc nk_dots_symmetric_f32 */
1319
+ NK_PUBLIC void nk_dots_symmetric_f32_v128relaxed(nk_f32_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
1320
+ nk_size_t stride, nk_f64_t *result, nk_size_t result_stride,
1321
+ nk_size_t row_start, nk_size_t row_count);
1322
+ /** @copydoc nk_dots_packed_size_f64 */
1323
+ NK_PUBLIC nk_size_t nk_dots_packed_size_f64_v128relaxed(nk_size_t width, nk_size_t depth);
1324
+ /** @copydoc nk_dots_pack_f64 */
1325
+ NK_PUBLIC void nk_dots_pack_f64_v128relaxed(nk_f64_t const *b, nk_size_t width, nk_size_t depth, nk_size_t b_stride,
1326
+ void *b_packed);
1327
+ /** @copydoc nk_dots_packed_f64 */
1328
+ NK_PUBLIC void nk_dots_packed_f64_v128relaxed(nk_f64_t const *a, void const *b_packed, nk_f64_t *c, nk_size_t height,
1329
+ nk_size_t width, nk_size_t depth, nk_size_t a_stride, nk_size_t c_stride);
1330
+ /** @copydoc nk_dots_symmetric_f64 */
1331
+ NK_PUBLIC void nk_dots_symmetric_f64_v128relaxed(nk_f64_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
1332
+ nk_size_t stride, nk_f64_t *result, nk_size_t result_stride,
1333
+ nk_size_t row_start, nk_size_t row_count);
1334
+ /** @copydoc nk_dots_packed_size_e4m3 */
1335
+ NK_PUBLIC nk_size_t nk_dots_packed_size_e4m3_v128relaxed(nk_size_t width, nk_size_t depth);
1336
+ /** @copydoc nk_dots_pack_e4m3 */
1337
+ NK_PUBLIC void nk_dots_pack_e4m3_v128relaxed(nk_e4m3_t const *b, nk_size_t width, nk_size_t depth, nk_size_t b_stride,
1338
+ void *b_packed);
1339
+ /** @copydoc nk_dots_packed_e4m3 */
1340
+ NK_PUBLIC void nk_dots_packed_e4m3_v128relaxed(nk_e4m3_t const *a, void const *b_packed, nk_f32_t *c, nk_size_t height,
1341
+ nk_size_t width, nk_size_t depth, nk_size_t a_stride,
1342
+ nk_size_t c_stride);
1343
+ /** @copydoc nk_dots_symmetric_e4m3 */
1344
+ NK_PUBLIC void nk_dots_symmetric_e4m3_v128relaxed(nk_e4m3_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
1345
+ nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
1346
+ nk_size_t row_start, nk_size_t row_count);
1347
+ /** @copydoc nk_dots_packed_size_e5m2 */
1348
+ NK_PUBLIC nk_size_t nk_dots_packed_size_e5m2_v128relaxed(nk_size_t width, nk_size_t depth);
1349
+ /** @copydoc nk_dots_pack_e5m2 */
1350
+ NK_PUBLIC void nk_dots_pack_e5m2_v128relaxed(nk_e5m2_t const *b, nk_size_t width, nk_size_t depth, nk_size_t b_stride,
1351
+ void *b_packed);
1352
+ /** @copydoc nk_dots_packed_e5m2 */
1353
+ NK_PUBLIC void nk_dots_packed_e5m2_v128relaxed(nk_e5m2_t const *a, void const *b_packed, nk_f32_t *c, nk_size_t height,
1354
+ nk_size_t width, nk_size_t depth, nk_size_t a_stride,
1355
+ nk_size_t c_stride);
1356
+ /** @copydoc nk_dots_symmetric_e5m2 */
1357
+ NK_PUBLIC void nk_dots_symmetric_e5m2_v128relaxed(nk_e5m2_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
1358
+ nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
1359
+ nk_size_t row_start, nk_size_t row_count);
1360
+ /** @copydoc nk_dots_packed_size_u4 */
1361
+ NK_PUBLIC nk_size_t nk_dots_packed_size_u4_v128relaxed(nk_size_t width, nk_size_t depth);
1362
+ /** @copydoc nk_dots_pack_u4 */
1363
+ NK_PUBLIC void nk_dots_pack_u4_v128relaxed(nk_u4x2_t const *b, nk_size_t width, nk_size_t depth, nk_size_t b_stride,
1364
+ void *b_packed);
1365
+ /** @copydoc nk_dots_packed_u4 */
1366
+ NK_PUBLIC void nk_dots_packed_u4_v128relaxed(nk_u4x2_t const *a, void const *b_packed, nk_u32_t *c, nk_size_t height,
1367
+ nk_size_t width, nk_size_t depth, nk_size_t a_stride, nk_size_t c_stride);
1368
+ /** @copydoc nk_dots_symmetric_u4 */
1369
+ NK_PUBLIC void nk_dots_symmetric_u4_v128relaxed(nk_u4x2_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
1370
+ nk_size_t stride, nk_u32_t *result, nk_size_t result_stride,
1371
+ nk_size_t row_start, nk_size_t row_count);
1372
+ /** @copydoc nk_dots_packed_size_i4 */
1373
+ NK_PUBLIC nk_size_t nk_dots_packed_size_i4_v128relaxed(nk_size_t width, nk_size_t depth);
1374
+ /** @copydoc nk_dots_pack_i4 */
1375
+ NK_PUBLIC void nk_dots_pack_i4_v128relaxed(nk_i4x2_t const *b, nk_size_t width, nk_size_t depth, nk_size_t b_stride,
1376
+ void *b_packed);
1377
+ /** @copydoc nk_dots_packed_i4 */
1378
+ NK_PUBLIC void nk_dots_packed_i4_v128relaxed(nk_i4x2_t const *a, void const *b_packed, nk_i32_t *c, nk_size_t height,
1379
+ nk_size_t width, nk_size_t depth, nk_size_t a_stride, nk_size_t c_stride);
1380
+ /** @copydoc nk_dots_symmetric_i4 */
1381
+ NK_PUBLIC void nk_dots_symmetric_i4_v128relaxed(nk_i4x2_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
1382
+ nk_size_t stride, nk_i32_t *result, nk_size_t result_stride,
1383
+ nk_size_t row_start, nk_size_t row_count);
1384
+ /** @copydoc nk_dots_packed_size_u1 */
1385
+ NK_PUBLIC nk_size_t nk_dots_packed_size_u1_v128relaxed(nk_size_t width, nk_size_t depth);
1386
+ /** @copydoc nk_dots_pack_u1 */
1387
+ NK_PUBLIC void nk_dots_pack_u1_v128relaxed(nk_u1x8_t const *b, nk_size_t width, nk_size_t depth, nk_size_t b_stride,
1388
+ void *b_packed);
1389
+ /** @copydoc nk_dots_packed_u1 */
1390
+ NK_PUBLIC void nk_dots_packed_u1_v128relaxed(nk_u1x8_t const *a, void const *b_packed, nk_u32_t *c, nk_size_t height,
1391
+ nk_size_t width, nk_size_t depth, nk_size_t a_stride, nk_size_t c_stride);
1392
+ /** @copydoc nk_dots_symmetric_u1 */
1393
+ NK_PUBLIC void nk_dots_symmetric_u1_v128relaxed(nk_u1x8_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
1394
+ nk_size_t stride, nk_u32_t *result, nk_size_t result_stride,
1395
+ nk_size_t row_start, nk_size_t row_count);
1396
+ #endif // NK_TARGET_V128RELAXED
1397
+
1398
+ /* ARM NEON backends (base NEON with F32/F64 support, plus U1, F16, and BF16 kernels).
1399
+ * Uses FMLA for F32 dots, FMLA (scalar) for F64.
1400
+ */
1401
+ #if NK_TARGET_NEON
1402
+ /** @copydoc nk_dots_packed_size_f32 */
1403
+ NK_PUBLIC nk_size_t nk_dots_packed_size_f32_neon(nk_size_t width, nk_size_t depth);
1404
+ /** @copydoc nk_dots_pack_f32 */
1405
+ NK_PUBLIC void nk_dots_pack_f32_neon(nk_f32_t const *b, nk_size_t width, nk_size_t depth, nk_size_t b_stride,
1406
+ void *b_packed);
1407
+ /** @copydoc nk_dots_packed_f32 */
1408
+ NK_PUBLIC void nk_dots_packed_f32_neon(nk_f32_t const *a, void const *b_packed, nk_f64_t *c, nk_size_t height,
1409
+ nk_size_t width, nk_size_t depth, nk_size_t a_stride, nk_size_t c_stride);
1410
+ /** @copydoc nk_dots_symmetric_f32 */
1411
+ NK_PUBLIC void nk_dots_symmetric_f32_neon(nk_f32_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
1412
+ nk_size_t stride, nk_f64_t *result, nk_size_t result_stride,
1413
+ nk_size_t row_start, nk_size_t row_count);
1414
+ /** @copydoc nk_dots_packed_size_f64 */
1415
+ NK_PUBLIC nk_size_t nk_dots_packed_size_f64_neon(nk_size_t width, nk_size_t depth);
1416
+ /** @copydoc nk_dots_pack_f64 */
1417
+ NK_PUBLIC void nk_dots_pack_f64_neon(nk_f64_t const *b, nk_size_t width, nk_size_t depth, nk_size_t b_stride,
1418
+ void *b_packed);
1419
+ /** @copydoc nk_dots_packed_f64 */
1420
+ NK_PUBLIC void nk_dots_packed_f64_neon(nk_f64_t const *a, void const *b_packed, nk_f64_t *c, nk_size_t height,
1421
+ nk_size_t width, nk_size_t depth, nk_size_t a_stride, nk_size_t c_stride);
1422
+ /** @copydoc nk_dots_symmetric_f64 */
1423
+ NK_PUBLIC void nk_dots_symmetric_f64_neon(nk_f64_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
1424
+ nk_size_t stride, nk_f64_t *result, nk_size_t result_stride,
1425
+ nk_size_t row_start, nk_size_t row_count);
1426
+ /** @copydoc nk_dots_packed_size_u1 */
1427
+ NK_PUBLIC nk_size_t nk_dots_packed_size_u1_neon(nk_size_t width, nk_size_t depth);
1428
+ /** @copydoc nk_dots_pack_u1 */
1429
+ NK_PUBLIC void nk_dots_pack_u1_neon(nk_u1x8_t const *b, nk_size_t width, nk_size_t depth, nk_size_t b_stride,
1430
+ void *b_packed);
1431
+ /** @copydoc nk_dots_packed_u1 */
1432
+ NK_PUBLIC void nk_dots_packed_u1_neon(nk_u1x8_t const *a, void const *b_packed, nk_u32_t *c, nk_size_t height,
1433
+ nk_size_t width, nk_size_t depth, nk_size_t a_stride, nk_size_t c_stride);
1434
+ /** @copydoc nk_dots_symmetric_u1 */
1435
+ NK_PUBLIC void nk_dots_symmetric_u1_neon(nk_u1x8_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
1436
+ nk_size_t stride, nk_u32_t *result, nk_size_t result_stride,
1437
+ nk_size_t row_start, nk_size_t row_count);
1438
+ /** @copydoc nk_dots_packed_size_f16 */
1439
+ NK_PUBLIC nk_size_t nk_dots_packed_size_f16_neon(nk_size_t width, nk_size_t depth);
1440
+ /** @copydoc nk_dots_pack_f16 */
1441
+ NK_PUBLIC void nk_dots_pack_f16_neon(nk_f16_t const *b, nk_size_t width, nk_size_t depth, nk_size_t b_stride,
1442
+ void *b_packed);
1443
+ /** @copydoc nk_dots_packed_f16 */
1444
+ NK_PUBLIC void nk_dots_packed_f16_neon(nk_f16_t const *a, void const *b_packed, nk_f32_t *c, nk_size_t height,
1445
+ nk_size_t width, nk_size_t depth, nk_size_t a_stride, nk_size_t c_stride);
1446
+ /** @copydoc nk_dots_symmetric_f16 */
1447
+ NK_PUBLIC void nk_dots_symmetric_f16_neon(nk_f16_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
1448
+ nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
1449
+ nk_size_t row_start, nk_size_t row_count);
1450
+ /** @copydoc nk_dots_packed_size_bf16 */
1451
+ NK_PUBLIC nk_size_t nk_dots_packed_size_bf16_neon(nk_size_t width, nk_size_t depth);
1452
+ /** @copydoc nk_dots_pack_bf16 */
1453
+ NK_PUBLIC void nk_dots_pack_bf16_neon(nk_bf16_t const *b, nk_size_t width, nk_size_t depth, nk_size_t b_stride,
1454
+ void *b_packed);
1455
+ /** @copydoc nk_dots_packed_bf16 */
1456
+ NK_PUBLIC void nk_dots_packed_bf16_neon(nk_bf16_t const *a, void const *b_packed, nk_f32_t *c, nk_size_t height,
1457
+ nk_size_t width, nk_size_t depth, nk_size_t a_stride, nk_size_t c_stride);
1458
+ /** @copydoc nk_dots_symmetric_bf16 */
1459
+ NK_PUBLIC void nk_dots_symmetric_bf16_neon(nk_bf16_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
1460
+ nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
1461
+ nk_size_t row_start, nk_size_t row_count);
1462
+ #endif // NK_TARGET_NEON
1463
+
1464
+ /* ARM NEON with F16 arithmetic (ARMv8.2-A FP16).
1465
+ * Provides native F16 FMLA for half-precision dot products.
1466
+ */
1467
+ #if NK_TARGET_NEONHALF
1468
+ /** @copydoc nk_dots_packed_size_f16 */
1469
+ NK_PUBLIC nk_size_t nk_dots_packed_size_f16_neonhalf(nk_size_t width, nk_size_t depth);
1470
+ /** @copydoc nk_dots_pack_f16 */
1471
+ NK_PUBLIC void nk_dots_pack_f16_neonhalf(nk_f16_t const *b, nk_size_t width, nk_size_t depth, nk_size_t b_stride,
1472
+ void *b_packed);
1473
+ /** @copydoc nk_dots_packed_f16 */
1474
+ NK_PUBLIC void nk_dots_packed_f16_neonhalf(nk_f16_t const *a, void const *b_packed, nk_f32_t *c, nk_size_t height,
1475
+ nk_size_t width, nk_size_t depth, nk_size_t a_stride, nk_size_t c_stride);
1476
+ /** @copydoc nk_dots_symmetric_f16 */
1477
+ NK_PUBLIC void nk_dots_symmetric_f16_neonhalf(nk_f16_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
1478
+ nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
1479
+ nk_size_t row_start, nk_size_t row_count);
1480
+ #endif // NK_TARGET_NEONHALF
1481
+
1482
+ /* ARM NEON with BF16 dot product (ARMv8.6-A BF16).
1483
+ * Uses BFDOT/BFMMLA for efficient BF16 matrix operations.
1484
+ */
1485
+ #if NK_TARGET_NEONBFDOT
1486
+ /** @copydoc nk_dots_packed_size_bf16 */
1487
+ NK_PUBLIC nk_size_t nk_dots_packed_size_bf16_neonbfdot(nk_size_t width, nk_size_t depth);
1488
+ /** @copydoc nk_dots_pack_bf16 */
1489
+ NK_PUBLIC void nk_dots_pack_bf16_neonbfdot(nk_bf16_t const *b, nk_size_t width, nk_size_t depth, nk_size_t b_stride,
1490
+ void *b_packed);
1491
+ /** @copydoc nk_dots_packed_bf16 */
1492
+ NK_PUBLIC void nk_dots_packed_bf16_neonbfdot(nk_bf16_t const *a, void const *b_packed, nk_f32_t *c, nk_size_t height,
1493
+ nk_size_t width, nk_size_t depth, nk_size_t a_stride, nk_size_t c_stride);
1494
+ /** @copydoc nk_dots_symmetric_bf16 */
1495
+ NK_PUBLIC void nk_dots_symmetric_bf16_neonbfdot(nk_bf16_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
1496
+ nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
1497
+ nk_size_t row_start, nk_size_t row_count);
1498
+ #endif // NK_TARGET_NEONBFDOT
1499
+
1500
+ /* ARM NEON with signed/unsigned dot product (ARMv8.2-A DotProd).
1501
+ * Provides SDOT/UDOT for I8/U8 vector dot products.
1502
+ */
1503
+ #if NK_TARGET_NEONSDOT
1504
+ /** @copydoc nk_dots_packed_size_i8 */
1505
+ NK_PUBLIC nk_size_t nk_dots_packed_size_i8_neonsdot(nk_size_t width, nk_size_t depth);
1506
+ /** @copydoc nk_dots_pack_i8 */
1507
+ NK_PUBLIC void nk_dots_pack_i8_neonsdot(nk_i8_t const *b, nk_size_t width, nk_size_t depth, nk_size_t b_stride,
1508
+ void *b_packed);
1509
+ /** @copydoc nk_dots_packed_i8 */
1510
+ NK_PUBLIC void nk_dots_packed_i8_neonsdot(nk_i8_t const *a, void const *b_packed, nk_i32_t *c, nk_size_t height,
1511
+ nk_size_t width, nk_size_t depth, nk_size_t a_stride, nk_size_t c_stride);
1512
+ /** @copydoc nk_dots_symmetric_i8 */
1513
+ NK_PUBLIC void nk_dots_symmetric_i8_neonsdot(nk_i8_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
1514
+ nk_size_t stride, nk_i32_t *result, nk_size_t result_stride,
1515
+ nk_size_t row_start, nk_size_t row_count);
1516
+ /** @copydoc nk_dots_packed_size_u8 */
1517
+ NK_PUBLIC nk_size_t nk_dots_packed_size_u8_neonsdot(nk_size_t width, nk_size_t depth);
1518
+ /** @copydoc nk_dots_pack_u8 */
1519
+ NK_PUBLIC void nk_dots_pack_u8_neonsdot(nk_u8_t const *b, nk_size_t width, nk_size_t depth, nk_size_t b_stride,
1520
+ void *b_packed);
1521
+ /** @copydoc nk_dots_packed_u8 */
1522
+ NK_PUBLIC void nk_dots_packed_u8_neonsdot(nk_u8_t const *a, void const *b_packed, nk_u32_t *c, nk_size_t height,
1523
+ nk_size_t width, nk_size_t depth, nk_size_t a_stride, nk_size_t c_stride);
1524
+ /** @copydoc nk_dots_symmetric_u8 */
1525
+ NK_PUBLIC void nk_dots_symmetric_u8_neonsdot(nk_u8_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
1526
+ nk_size_t stride, nk_u32_t *result, nk_size_t result_stride,
1527
+ nk_size_t row_start, nk_size_t row_count);
1528
+ #endif // NK_TARGET_NEONSDOT
1529
+
1530
+ /* ARM NEON with FP16 FML (fused multiply-long, ARMv8.2-A FP16FML).
1531
+ * Uses FMLAL/FMLSL for F16 and custom FP8 (E4M3/E5M2) operations.
1532
+ */
1533
+ #if NK_TARGET_NEONFHM
1534
+ /** @copydoc nk_dots_packed_size_f16 */
1535
+ NK_PUBLIC nk_size_t nk_dots_packed_size_f16_neonfhm(nk_size_t width, nk_size_t depth);
1536
+ /** @copydoc nk_dots_pack_f16 */
1537
+ NK_PUBLIC void nk_dots_pack_f16_neonfhm(nk_f16_t const *b, nk_size_t width, nk_size_t depth, nk_size_t b_stride,
1538
+ void *b_packed);
1539
+ /** @copydoc nk_dots_packed_f16 */
1540
+ NK_PUBLIC void nk_dots_packed_f16_neonfhm(nk_f16_t const *a, void const *b_packed, nk_f32_t *c, nk_size_t height,
1541
+ nk_size_t width, nk_size_t depth, nk_size_t a_stride, nk_size_t c_stride);
1542
+ /** @copydoc nk_dots_symmetric_f16 */
1543
+ NK_PUBLIC void nk_dots_symmetric_f16_neonfhm(nk_f16_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
1544
+ nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
1545
+ nk_size_t row_start, nk_size_t row_count);
1546
+ /** @copydoc nk_dots_packed_size_e4m3 */
1547
+ NK_PUBLIC nk_size_t nk_dots_packed_size_e4m3_neonfhm(nk_size_t width, nk_size_t depth);
1548
+ /** @copydoc nk_dots_pack_e4m3 */
1549
+ NK_PUBLIC void nk_dots_pack_e4m3_neonfhm(nk_e4m3_t const *b, nk_size_t width, nk_size_t depth, nk_size_t b_stride,
1550
+ void *b_packed);
1551
+ /** @copydoc nk_dots_packed_e4m3 */
1552
+ NK_PUBLIC void nk_dots_packed_e4m3_neonfhm(nk_e4m3_t const *a, void const *b_packed, nk_f32_t *c, nk_size_t height,
1553
+ nk_size_t width, nk_size_t depth, nk_size_t a_stride, nk_size_t c_stride);
1554
+ /** @copydoc nk_dots_symmetric_e4m3 */
1555
+ NK_PUBLIC void nk_dots_symmetric_e4m3_neonfhm(nk_e4m3_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
1556
+ nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
1557
+ nk_size_t row_start, nk_size_t row_count);
1558
+ /** @copydoc nk_dots_packed_size_e5m2 */
1559
+ NK_PUBLIC nk_size_t nk_dots_packed_size_e5m2_neonfhm(nk_size_t width, nk_size_t depth);
1560
+ /** @copydoc nk_dots_pack_e5m2 */
1561
+ NK_PUBLIC void nk_dots_pack_e5m2_neonfhm(nk_e5m2_t const *b, nk_size_t width, nk_size_t depth, nk_size_t b_stride,
1562
+ void *b_packed);
1563
+ /** @copydoc nk_dots_packed_e5m2 */
1564
+ NK_PUBLIC void nk_dots_packed_e5m2_neonfhm(nk_e5m2_t const *a, void const *b_packed, nk_f32_t *c, nk_size_t height,
1565
+ nk_size_t width, nk_size_t depth, nk_size_t a_stride, nk_size_t c_stride);
1566
+ /** @copydoc nk_dots_symmetric_e5m2 */
1567
+ NK_PUBLIC void nk_dots_symmetric_e5m2_neonfhm(nk_e5m2_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
1568
+ nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
1569
+ nk_size_t row_start, nk_size_t row_count);
1570
+ #endif // NK_TARGET_NEONFHM
1571
+
1572
+ #if NK_TARGET_RVV
1573
+ /** @copydoc nk_dots_packed_size_e2m3 */
1574
+ NK_PUBLIC nk_size_t nk_dots_packed_size_e2m3_rvv(nk_size_t width, nk_size_t depth);
1575
+ /** @copydoc nk_dots_pack_e2m3 */
1576
+ NK_PUBLIC void nk_dots_pack_e2m3_rvv(nk_e2m3_t const *b, nk_size_t width, nk_size_t depth, nk_size_t b_stride,
1577
+ void *b_packed);
1578
+ /** @copydoc nk_dots_packed_e2m3 */
1579
+ NK_PUBLIC void nk_dots_packed_e2m3_rvv(nk_e2m3_t const *a, void const *b_packed, nk_f32_t *c, nk_size_t height,
1580
+ nk_size_t width, nk_size_t depth, nk_size_t a_stride, nk_size_t c_stride);
1581
+ /** @copydoc nk_dots_symmetric_e2m3 */
1582
+ NK_PUBLIC void nk_dots_symmetric_e2m3_rvv(nk_e2m3_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
1583
+ nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
1584
+ nk_size_t row_start, nk_size_t row_count);
1585
+ /** @copydoc nk_dots_packed_size_e3m2 */
1586
+ NK_PUBLIC nk_size_t nk_dots_packed_size_e3m2_rvv(nk_size_t width, nk_size_t depth);
1587
+ /** @copydoc nk_dots_pack_e3m2 */
1588
+ NK_PUBLIC void nk_dots_pack_e3m2_rvv(nk_e3m2_t const *b, nk_size_t width, nk_size_t depth, nk_size_t b_stride,
1589
+ void *b_packed);
1590
+ /** @copydoc nk_dots_packed_e3m2 */
1591
+ NK_PUBLIC void nk_dots_packed_e3m2_rvv(nk_e3m2_t const *a, void const *b_packed, nk_f32_t *c, nk_size_t height,
1592
+ nk_size_t width, nk_size_t depth, nk_size_t a_stride, nk_size_t c_stride);
1593
+ /** @copydoc nk_dots_symmetric_e3m2 */
1594
+ NK_PUBLIC void nk_dots_symmetric_e3m2_rvv(nk_e3m2_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
1595
+ nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
1596
+ nk_size_t row_start, nk_size_t row_count);
1597
+ /** @copydoc nk_dots_packed_size_f32 */
1598
+ NK_PUBLIC nk_size_t nk_dots_packed_size_f32_rvv(nk_size_t width, nk_size_t depth);
1599
+ /** @copydoc nk_dots_pack_f32 */
1600
+ NK_PUBLIC void nk_dots_pack_f32_rvv(nk_f32_t const *b, nk_size_t width, nk_size_t depth, nk_size_t b_stride,
1601
+ void *b_packed);
1602
+ /** @copydoc nk_dots_packed_f32 */
1603
+ NK_PUBLIC void nk_dots_packed_f32_rvv(nk_f32_t const *a, void const *b_packed, nk_f64_t *c, nk_size_t height,
1604
+ nk_size_t width, nk_size_t depth, nk_size_t a_stride, nk_size_t c_stride);
1605
+ /** @copydoc nk_dots_symmetric_f32 */
1606
+ NK_PUBLIC void nk_dots_symmetric_f32_rvv(nk_f32_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
1607
+ nk_size_t stride, nk_f64_t *result, nk_size_t result_stride,
1608
+ nk_size_t row_start, nk_size_t row_count);
1609
+ /** @copydoc nk_dots_packed_size_f64 */
1610
+ NK_PUBLIC nk_size_t nk_dots_packed_size_f64_rvv(nk_size_t width, nk_size_t depth);
1611
+ /** @copydoc nk_dots_pack_f64 */
1612
+ NK_PUBLIC void nk_dots_pack_f64_rvv(nk_f64_t const *b, nk_size_t width, nk_size_t depth, nk_size_t b_stride,
1613
+ void *b_packed);
1614
+ /** @copydoc nk_dots_packed_f64 */
1615
+ NK_PUBLIC void nk_dots_packed_f64_rvv(nk_f64_t const *a, void const *b_packed, nk_f64_t *c, nk_size_t height,
1616
+ nk_size_t width, nk_size_t depth, nk_size_t a_stride, nk_size_t c_stride);
1617
+ /** @copydoc nk_dots_symmetric_f64 */
1618
+ NK_PUBLIC void nk_dots_symmetric_f64_rvv(nk_f64_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
1619
+ nk_size_t stride, nk_f64_t *result, nk_size_t result_stride,
1620
+ nk_size_t row_start, nk_size_t row_count);
1621
+ /** @copydoc nk_dots_packed_size_bf16 */
1622
+ NK_PUBLIC nk_size_t nk_dots_packed_size_bf16_rvv(nk_size_t width, nk_size_t depth);
1623
+ /** @copydoc nk_dots_pack_bf16 */
1624
+ NK_PUBLIC void nk_dots_pack_bf16_rvv(nk_bf16_t const *b, nk_size_t width, nk_size_t depth, nk_size_t b_stride,
1625
+ void *b_packed);
1626
+ /** @copydoc nk_dots_packed_bf16 */
1627
+ NK_PUBLIC void nk_dots_packed_bf16_rvv(nk_bf16_t const *a, void const *b_packed, nk_f32_t *c, nk_size_t height,
1628
+ nk_size_t width, nk_size_t depth, nk_size_t a_stride, nk_size_t c_stride);
1629
+ /** @copydoc nk_dots_symmetric_bf16 */
1630
+ NK_PUBLIC void nk_dots_symmetric_bf16_rvv(nk_bf16_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
1631
+ nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
1632
+ nk_size_t row_start, nk_size_t row_count);
1633
+ /** @copydoc nk_dots_packed_size_f16 */
1634
+ NK_PUBLIC nk_size_t nk_dots_packed_size_f16_rvv(nk_size_t width, nk_size_t depth);
1635
+ /** @copydoc nk_dots_pack_f16 */
1636
+ NK_PUBLIC void nk_dots_pack_f16_rvv(nk_f16_t const *b, nk_size_t width, nk_size_t depth, nk_size_t b_stride,
1637
+ void *b_packed);
1638
+ /** @copydoc nk_dots_packed_f16 */
1639
+ NK_PUBLIC void nk_dots_packed_f16_rvv(nk_f16_t const *a, void const *b_packed, nk_f32_t *c, nk_size_t height,
1640
+ nk_size_t width, nk_size_t depth, nk_size_t a_stride, nk_size_t c_stride);
1641
+ /** @copydoc nk_dots_symmetric_f16 */
1642
+ NK_PUBLIC void nk_dots_symmetric_f16_rvv(nk_f16_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
1643
+ nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
1644
+ nk_size_t row_start, nk_size_t row_count);
1645
+ /** @copydoc nk_dots_packed_size_i8 */
1646
+ NK_PUBLIC nk_size_t nk_dots_packed_size_i8_rvv(nk_size_t width, nk_size_t depth);
1647
+ /** @copydoc nk_dots_pack_i8 */
1648
+ NK_PUBLIC void nk_dots_pack_i8_rvv(nk_i8_t const *b, nk_size_t width, nk_size_t depth, nk_size_t b_stride,
1649
+ void *b_packed);
1650
+ /** @copydoc nk_dots_packed_i8 */
1651
+ NK_PUBLIC void nk_dots_packed_i8_rvv(nk_i8_t const *a, void const *b_packed, nk_i32_t *c, nk_size_t height,
1652
+ nk_size_t width, nk_size_t depth, nk_size_t a_stride, nk_size_t c_stride);
1653
+ /** @copydoc nk_dots_symmetric_i8 */
1654
+ NK_PUBLIC void nk_dots_symmetric_i8_rvv(nk_i8_t const *vectors, nk_size_t n_vectors, nk_size_t depth, nk_size_t stride,
1655
+ nk_i32_t *result, nk_size_t result_stride, nk_size_t row_start,
1656
+ nk_size_t row_count);
1657
+ /** @copydoc nk_dots_packed_size_u8 */
1658
+ NK_PUBLIC nk_size_t nk_dots_packed_size_u8_rvv(nk_size_t width, nk_size_t depth);
1659
+ /** @copydoc nk_dots_pack_u8 */
1660
+ NK_PUBLIC void nk_dots_pack_u8_rvv(nk_u8_t const *b, nk_size_t width, nk_size_t depth, nk_size_t b_stride,
1661
+ void *b_packed);
1662
+ /** @copydoc nk_dots_packed_u8 */
1663
+ NK_PUBLIC void nk_dots_packed_u8_rvv(nk_u8_t const *a, void const *b_packed, nk_u32_t *c, nk_size_t height,
1664
+ nk_size_t width, nk_size_t depth, nk_size_t a_stride, nk_size_t c_stride);
1665
+ /** @copydoc nk_dots_symmetric_u8 */
1666
+ NK_PUBLIC void nk_dots_symmetric_u8_rvv(nk_u8_t const *vectors, nk_size_t n_vectors, nk_size_t depth, nk_size_t stride,
1667
+ nk_u32_t *result, nk_size_t result_stride, nk_size_t row_start,
1668
+ nk_size_t row_count);
1669
+ /** @copydoc nk_dots_packed_size_e4m3 */
1670
+ NK_PUBLIC nk_size_t nk_dots_packed_size_e4m3_rvv(nk_size_t width, nk_size_t depth);
1671
+ /** @copydoc nk_dots_pack_e4m3 */
1672
+ NK_PUBLIC void nk_dots_pack_e4m3_rvv(nk_e4m3_t const *b, nk_size_t width, nk_size_t depth, nk_size_t b_stride,
1673
+ void *b_packed);
1674
+ /** @copydoc nk_dots_packed_e4m3 */
1675
+ NK_PUBLIC void nk_dots_packed_e4m3_rvv(nk_e4m3_t const *a, void const *b_packed, nk_f32_t *c, nk_size_t height,
1676
+ nk_size_t width, nk_size_t depth, nk_size_t a_stride, nk_size_t c_stride);
1677
+ /** @copydoc nk_dots_symmetric_e4m3 */
1678
+ NK_PUBLIC void nk_dots_symmetric_e4m3_rvv(nk_e4m3_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
1679
+ nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
1680
+ nk_size_t row_start, nk_size_t row_count);
1681
+ /** @copydoc nk_dots_packed_size_e5m2 */
1682
+ NK_PUBLIC nk_size_t nk_dots_packed_size_e5m2_rvv(nk_size_t width, nk_size_t depth);
1683
+ /** @copydoc nk_dots_pack_e5m2 */
1684
+ NK_PUBLIC void nk_dots_pack_e5m2_rvv(nk_e5m2_t const *b, nk_size_t width, nk_size_t depth, nk_size_t b_stride,
1685
+ void *b_packed);
1686
+ /** @copydoc nk_dots_packed_e5m2 */
1687
+ NK_PUBLIC void nk_dots_packed_e5m2_rvv(nk_e5m2_t const *a, void const *b_packed, nk_f32_t *c, nk_size_t height,
1688
+ nk_size_t width, nk_size_t depth, nk_size_t a_stride, nk_size_t c_stride);
1689
+ /** @copydoc nk_dots_symmetric_e5m2 */
1690
+ NK_PUBLIC void nk_dots_symmetric_e5m2_rvv(nk_e5m2_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
1691
+ nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
1692
+ nk_size_t row_start, nk_size_t row_count);
1693
+ #endif // NK_TARGET_RVV
1694
+
1695
+ #if defined(__cplusplus)
1696
+ } // extern "C"
1697
+ #endif
1698
+
1699
+ #include "numkong/dots/serial.h"
1700
+ #include "numkong/dots/haswell.h"
1701
+ #include "numkong/dots/skylake.h"
1702
+ #include "numkong/dots/icelake.h"
1703
+ #include "numkong/dots/alder.h"
1704
+ #include "numkong/dots/sierra.h"
1705
+ #include "numkong/dots/genoa.h"
1706
+ #include "numkong/dots/sapphireamx.h"
1707
+ #include "numkong/dots/neon.h"
1708
+ #include "numkong/dots/neonsdot.h"
1709
+ #include "numkong/dots/neonhalf.h"
1710
+ #include "numkong/dots/neonfhm.h"
1711
+ #include "numkong/dots/neonbfdot.h"
1712
+ #include "numkong/dots/sme.h"
1713
+ #include "numkong/dots/smef64.h"
1714
+ #include "numkong/dots/smebi32.h"
1715
+ #include "numkong/dots/rvv.h"
1716
+ #include "numkong/dots/v128relaxed.h"
1717
+
1718
+ #if defined(__cplusplus)
1719
+ extern "C" {
1720
+ #endif
1721
+
1722
+ #if !NK_DYNAMIC_DISPATCH
1723
+
1724
+ NK_PUBLIC nk_size_t nk_dots_packed_size_f32(nk_size_t width, nk_size_t depth) { // Compile-time dispatcher (inside the `#if !NK_DYNAMIC_DISPATCH` region): returns whatever the first enabled backend below returns for (width, depth).
1725
+ #if NK_TARGET_SMEF64 // Tested first, so it takes precedence whenever several NK_TARGET_* macros are non-zero.
1726
+ return nk_dots_packed_size_f32_smef64(width, depth);
1727
+ #elif NK_TARGET_SKYLAKE // Checked before HASWELL, so SKYLAKE wins when both are enabled.
1728
+ return nk_dots_packed_size_f32_skylake(width, depth);
1729
+ #elif NK_TARGET_HASWELL
1730
+ return nk_dots_packed_size_f32_haswell(width, depth);
1731
+ #elif NK_TARGET_NEON
1732
+ return nk_dots_packed_size_f32_neon(width, depth);
1733
+ #elif NK_TARGET_RVV
1734
+ return nk_dots_packed_size_f32_rvv(width, depth);
1735
+ #elif NK_TARGET_V128RELAXED
1736
+ return nk_dots_packed_size_f32_v128relaxed(width, depth);
1737
+ #else // None of the target macros above is non-zero: use the serial implementation.
1738
+ return nk_dots_packed_size_f32_serial(width, depth);
1739
+ #endif // Same backend order as nk_dots_pack_f32 and nk_dots_packed_f32; presumably all three must pick the same backend so sizes and layouts agree - NOTE(review): confirm.
1740
+ }
1741
+
1742
+ NK_PUBLIC void nk_dots_pack_f32(nk_f32_t const *b, nk_size_t width, nk_size_t depth, nk_size_t b_stride, // Compile-time dispatcher (inside the `#if !NK_DYNAMIC_DISPATCH` region): forwards all five arguments unchanged.
1743
+ void *b_packed) { // `b` is const, `b_packed` is not; presumably `b_packed` is the output buffer - NOTE(review): confirm sizing against nk_dots_packed_size_f32.
1744
+ #if NK_TARGET_SMEF64 // Tested first, so it takes precedence whenever several NK_TARGET_* macros are non-zero.
1745
+ nk_dots_pack_f32_smef64(b, width, depth, b_stride, b_packed);
1746
+ #elif NK_TARGET_SKYLAKE // Checked before HASWELL, so SKYLAKE wins when both are enabled.
1747
+ nk_dots_pack_f32_skylake(b, width, depth, b_stride, b_packed);
1748
+ #elif NK_TARGET_HASWELL
1749
+ nk_dots_pack_f32_haswell(b, width, depth, b_stride, b_packed);
1750
+ #elif NK_TARGET_NEON
1751
+ nk_dots_pack_f32_neon(b, width, depth, b_stride, b_packed);
1752
+ #elif NK_TARGET_RVV
1753
+ nk_dots_pack_f32_rvv(b, width, depth, b_stride, b_packed);
1754
+ #elif NK_TARGET_V128RELAXED
1755
+ nk_dots_pack_f32_v128relaxed(b, width, depth, b_stride, b_packed);
1756
+ #else // None of the target macros above is non-zero: use the serial implementation.
1757
+ nk_dots_pack_f32_serial(b, width, depth, b_stride, b_packed);
1758
+ #endif // Same backend order as nk_dots_packed_size_f32 and nk_dots_packed_f32.
1759
+ }
1760
+
1761
+ NK_PUBLIC void nk_dots_packed_f32(nk_f32_t const *a, void const *b_packed, nk_f64_t *c, nk_size_t height, // Compile-time dispatcher (inside the `#if !NK_DYNAMIC_DISPATCH` region). Note `a` is nk_f32_t while `c` is nk_f64_t.
1762
+ nk_size_t width, nk_size_t depth, nk_size_t a_stride, nk_size_t c_stride) { // All eight arguments are forwarded unchanged to the selected backend.
1763
+ #if NK_TARGET_SMEF64 // Tested first, so it takes precedence whenever several NK_TARGET_* macros are non-zero.
1764
+ nk_dots_packed_f32_smef64(a, b_packed, c, height, width, depth, a_stride, c_stride);
1765
+ #elif NK_TARGET_SKYLAKE // Checked before HASWELL, so SKYLAKE wins when both are enabled.
1766
+ nk_dots_packed_f32_skylake(a, b_packed, c, height, width, depth, a_stride, c_stride);
1767
+ #elif NK_TARGET_HASWELL
1768
+ nk_dots_packed_f32_haswell(a, b_packed, c, height, width, depth, a_stride, c_stride);
1769
+ #elif NK_TARGET_NEON
1770
+ nk_dots_packed_f32_neon(a, b_packed, c, height, width, depth, a_stride, c_stride);
1771
+ #elif NK_TARGET_RVV
1772
+ nk_dots_packed_f32_rvv(a, b_packed, c, height, width, depth, a_stride, c_stride);
1773
+ #elif NK_TARGET_V128RELAXED
1774
+ nk_dots_packed_f32_v128relaxed(a, b_packed, c, height, width, depth, a_stride, c_stride);
1775
+ #else // None of the target macros above is non-zero: use the serial implementation.
1776
+ nk_dots_packed_f32_serial(a, b_packed, c, height, width, depth, a_stride, c_stride);
1777
+ #endif // Same backend order as nk_dots_packed_size_f32 and nk_dots_pack_f32; presumably `b_packed` must come from the matching pack routine - NOTE(review): confirm.
1778
+ }
1779
+
1780
+ NK_PUBLIC nk_size_t nk_dots_packed_size_f64(nk_size_t width, nk_size_t depth) {
1781
+ #if NK_TARGET_SMEF64
1782
+ return nk_dots_packed_size_f64_smef64(width, depth);
1783
+ #elif NK_TARGET_SKYLAKE
1784
+ return nk_dots_packed_size_f64_skylake(width, depth);
1785
+ #elif NK_TARGET_HASWELL
1786
+ return nk_dots_packed_size_f64_haswell(width, depth);
1787
+ #elif NK_TARGET_NEON
1788
+ return nk_dots_packed_size_f64_neon(width, depth);
1789
+ #elif NK_TARGET_RVV
1790
+ return nk_dots_packed_size_f64_rvv(width, depth);
1791
+ #elif NK_TARGET_V128RELAXED
1792
+ return nk_dots_packed_size_f64_v128relaxed(width, depth);
1793
+ #else
1794
+ return nk_dots_packed_size_f64_serial(width, depth);
1795
+ #endif
1796
+ }
1797
+
1798
+ NK_PUBLIC void nk_dots_pack_f64(nk_f64_t const *b, nk_size_t width, nk_size_t depth, nk_size_t b_stride,
1799
+ void *b_packed) {
1800
+ #if NK_TARGET_SMEF64
1801
+ nk_dots_pack_f64_smef64(b, width, depth, b_stride, b_packed);
1802
+ #elif NK_TARGET_SKYLAKE
1803
+ nk_dots_pack_f64_skylake(b, width, depth, b_stride, b_packed);
1804
+ #elif NK_TARGET_HASWELL
1805
+ nk_dots_pack_f64_haswell(b, width, depth, b_stride, b_packed);
1806
+ #elif NK_TARGET_NEON
1807
+ nk_dots_pack_f64_neon(b, width, depth, b_stride, b_packed);
1808
+ #elif NK_TARGET_RVV
1809
+ nk_dots_pack_f64_rvv(b, width, depth, b_stride, b_packed);
1810
+ #elif NK_TARGET_V128RELAXED
1811
+ nk_dots_pack_f64_v128relaxed(b, width, depth, b_stride, b_packed);
1812
+ #else
1813
+ nk_dots_pack_f64_serial(b, width, depth, b_stride, b_packed);
1814
+ #endif
1815
+ }
1816
+
1817
+ NK_PUBLIC void nk_dots_packed_f64(nk_f64_t const *a, void const *b_packed, nk_f64_t *c, nk_size_t height,
1818
+ nk_size_t width, nk_size_t depth, nk_size_t a_stride, nk_size_t c_stride) {
1819
+ #if NK_TARGET_SMEF64
1820
+ nk_dots_packed_f64_smef64(a, b_packed, c, height, width, depth, a_stride, c_stride);
1821
+ #elif NK_TARGET_SKYLAKE
1822
+ nk_dots_packed_f64_skylake(a, b_packed, c, height, width, depth, a_stride, c_stride);
1823
+ #elif NK_TARGET_HASWELL
1824
+ nk_dots_packed_f64_haswell(a, b_packed, c, height, width, depth, a_stride, c_stride);
1825
+ #elif NK_TARGET_NEON
1826
+ nk_dots_packed_f64_neon(a, b_packed, c, height, width, depth, a_stride, c_stride);
1827
+ #elif NK_TARGET_RVV
1828
+ nk_dots_packed_f64_rvv(a, b_packed, c, height, width, depth, a_stride, c_stride);
1829
+ #elif NK_TARGET_V128RELAXED
1830
+ nk_dots_packed_f64_v128relaxed(a, b_packed, c, height, width, depth, a_stride, c_stride);
1831
+ #else
1832
+ nk_dots_packed_f64_serial(a, b_packed, c, height, width, depth, a_stride, c_stride);
1833
+ #endif
1834
+ }
1835
+
1836
+ NK_PUBLIC nk_size_t nk_dots_packed_size_f16(nk_size_t width, nk_size_t depth) {
1837
+ #if NK_TARGET_SME
1838
+ return nk_dots_packed_size_f16_sme(width, depth);
1839
+ #elif NK_TARGET_NEONFHM
1840
+ return nk_dots_packed_size_f16_neonfhm(width, depth);
1841
+ #elif NK_TARGET_NEONHALF
1842
+ return nk_dots_packed_size_f16_neonhalf(width, depth);
1843
+ #elif NK_TARGET_NEON
1844
+ return nk_dots_packed_size_f16_neon(width, depth);
1845
+ #elif NK_TARGET_SKYLAKE
1846
+ return nk_dots_packed_size_f16_skylake(width, depth);
1847
+ #elif NK_TARGET_HASWELL
1848
+ return nk_dots_packed_size_f16_haswell(width, depth);
1849
+ #elif NK_TARGET_RVV
1850
+ return nk_dots_packed_size_f16_rvv(width, depth);
1851
+ #else
1852
+ return nk_dots_packed_size_f16_serial(width, depth);
1853
+ #endif
1854
+ }
1855
+
1856
+ NK_PUBLIC void nk_dots_pack_f16(nk_f16_t const *b, nk_size_t width, nk_size_t depth, nk_size_t b_stride,
1857
+ void *b_packed) {
1858
+ #if NK_TARGET_SME
1859
+ nk_dots_pack_f16_sme(b, width, depth, b_stride, b_packed);
1860
+ #elif NK_TARGET_NEONFHM
1861
+ nk_dots_pack_f16_neonfhm(b, width, depth, b_stride, b_packed);
1862
+ #elif NK_TARGET_NEONHALF
1863
+ nk_dots_pack_f16_neonhalf(b, width, depth, b_stride, b_packed);
1864
+ #elif NK_TARGET_NEON
1865
+ nk_dots_pack_f16_neon(b, width, depth, b_stride, b_packed);
1866
+ #elif NK_TARGET_SKYLAKE
1867
+ nk_dots_pack_f16_skylake(b, width, depth, b_stride, b_packed);
1868
+ #elif NK_TARGET_HASWELL
1869
+ nk_dots_pack_f16_haswell(b, width, depth, b_stride, b_packed);
1870
+ #elif NK_TARGET_RVV
1871
+ nk_dots_pack_f16_rvv(b, width, depth, b_stride, b_packed);
1872
+ #else
1873
+ nk_dots_pack_f16_serial(b, width, depth, b_stride, b_packed);
1874
+ #endif
1875
+ }
1876
+
1877
+ NK_PUBLIC void nk_dots_packed_f16(nk_f16_t const *a, void const *b_packed, nk_f32_t *c, nk_size_t height,
1878
+ nk_size_t width, nk_size_t depth, nk_size_t a_stride, nk_size_t c_stride) {
1879
+ #if NK_TARGET_SME
1880
+ nk_dots_packed_f16_sme(a, b_packed, c, height, width, depth, a_stride, c_stride);
1881
+ #elif NK_TARGET_NEONFHM
1882
+ nk_dots_packed_f16_neonfhm(a, b_packed, c, height, width, depth, a_stride, c_stride);
1883
+ #elif NK_TARGET_NEONHALF
1884
+ nk_dots_packed_f16_neonhalf(a, b_packed, c, height, width, depth, a_stride, c_stride);
1885
+ #elif NK_TARGET_NEON
1886
+ nk_dots_packed_f16_neon(a, b_packed, c, height, width, depth, a_stride, c_stride);
1887
+ #elif NK_TARGET_SKYLAKE
1888
+ nk_dots_packed_f16_skylake(a, b_packed, c, height, width, depth, a_stride, c_stride);
1889
+ #elif NK_TARGET_HASWELL
1890
+ nk_dots_packed_f16_haswell(a, b_packed, c, height, width, depth, a_stride, c_stride);
1891
+ #elif NK_TARGET_RVV
1892
+ nk_dots_packed_f16_rvv(a, b_packed, c, height, width, depth, a_stride, c_stride);
1893
+ #else
1894
+ nk_dots_packed_f16_serial(a, b_packed, c, height, width, depth, a_stride, c_stride);
1895
+ #endif
1896
+ }
1897
+
1898
+ NK_PUBLIC nk_size_t nk_dots_packed_size_bf16(nk_size_t width, nk_size_t depth) {
1899
+ #if NK_TARGET_SME
1900
+ return nk_dots_packed_size_bf16_sme(width, depth);
1901
+ #elif NK_TARGET_SAPPHIREAMX
1902
+ return nk_dots_packed_size_bf16_sapphireamx(width, depth);
1903
+ #elif NK_TARGET_NEONBFDOT
1904
+ return nk_dots_packed_size_bf16_neonbfdot(width, depth);
1905
+ #elif NK_TARGET_GENOA
1906
+ return nk_dots_packed_size_bf16_genoa(width, depth);
1907
+ #elif NK_TARGET_SKYLAKE
1908
+ return nk_dots_packed_size_bf16_skylake(width, depth);
1909
+ #elif NK_TARGET_HASWELL
1910
+ return nk_dots_packed_size_bf16_haswell(width, depth);
1911
+ #elif NK_TARGET_RVV
1912
+ return nk_dots_packed_size_bf16_rvv(width, depth);
1913
+ #elif NK_TARGET_V128RELAXED
1914
+ return nk_dots_packed_size_bf16_v128relaxed(width, depth);
1915
+ #else
1916
+ return nk_dots_packed_size_bf16_serial(width, depth);
1917
+ #endif
1918
+ }
1919
+
1920
+ NK_PUBLIC void nk_dots_pack_bf16(nk_bf16_t const *b, nk_size_t width, nk_size_t depth, nk_size_t b_stride,
1921
+ void *b_packed) {
1922
+ #if NK_TARGET_SME
1923
+ nk_dots_pack_bf16_sme(b, width, depth, b_stride, b_packed);
1924
+ #elif NK_TARGET_SAPPHIREAMX
1925
+ nk_dots_pack_bf16_sapphireamx(b, width, depth, b_stride, b_packed);
1926
+ #elif NK_TARGET_NEONBFDOT
1927
+ nk_dots_pack_bf16_neonbfdot(b, width, depth, b_stride, b_packed);
1928
+ #elif NK_TARGET_GENOA
1929
+ nk_dots_pack_bf16_genoa(b, width, depth, b_stride, b_packed);
1930
+ #elif NK_TARGET_SKYLAKE
1931
+ nk_dots_pack_bf16_skylake(b, width, depth, b_stride, b_packed);
1932
+ #elif NK_TARGET_HASWELL
1933
+ nk_dots_pack_bf16_haswell(b, width, depth, b_stride, b_packed);
1934
+ #elif NK_TARGET_RVV
1935
+ nk_dots_pack_bf16_rvv(b, width, depth, b_stride, b_packed);
1936
+ #elif NK_TARGET_V128RELAXED
1937
+ nk_dots_pack_bf16_v128relaxed(b, width, depth, b_stride, b_packed);
1938
+ #else
1939
+ nk_dots_pack_bf16_serial(b, width, depth, b_stride, b_packed);
1940
+ #endif
1941
+ }
1942
+
1943
+ NK_PUBLIC void nk_dots_packed_bf16(nk_bf16_t const *a, void const *b_packed, nk_f32_t *c, nk_size_t height,
1944
+ nk_size_t width, nk_size_t depth, nk_size_t a_stride, nk_size_t c_stride) {
1945
+ #if NK_TARGET_SME
1946
+ nk_dots_packed_bf16_sme(a, b_packed, c, height, width, depth, a_stride, c_stride);
1947
+ #elif NK_TARGET_SAPPHIREAMX
1948
+ nk_dots_packed_bf16_sapphireamx(a, b_packed, c, height, width, depth, a_stride, c_stride);
1949
+ #elif NK_TARGET_NEONBFDOT
1950
+ nk_dots_packed_bf16_neonbfdot(a, b_packed, c, height, width, depth, a_stride, c_stride);
1951
+ #elif NK_TARGET_GENOA
1952
+ nk_dots_packed_bf16_genoa(a, b_packed, c, height, width, depth, a_stride, c_stride);
1953
+ #elif NK_TARGET_SKYLAKE
1954
+ nk_dots_packed_bf16_skylake(a, b_packed, c, height, width, depth, a_stride, c_stride);
1955
+ #elif NK_TARGET_HASWELL
1956
+ nk_dots_packed_bf16_haswell(a, b_packed, c, height, width, depth, a_stride, c_stride);
1957
+ #elif NK_TARGET_RVV
1958
+ nk_dots_packed_bf16_rvv(a, b_packed, c, height, width, depth, a_stride, c_stride);
1959
+ #elif NK_TARGET_V128RELAXED
1960
+ nk_dots_packed_bf16_v128relaxed(a, b_packed, c, height, width, depth, a_stride, c_stride);
1961
+ #else
1962
+ nk_dots_packed_bf16_serial(a, b_packed, c, height, width, depth, a_stride, c_stride);
1963
+ #endif
1964
+ }
1965
+
1966
+ NK_PUBLIC nk_size_t nk_dots_packed_size_i8(nk_size_t width, nk_size_t depth) {
1967
+ #if NK_TARGET_SME
1968
+ return nk_dots_packed_size_i8_sme(width, depth);
1969
+ #elif NK_TARGET_SAPPHIREAMX
1970
+ return nk_dots_packed_size_i8_sapphireamx(width, depth);
1971
+ #elif NK_TARGET_NEONSDOT
1972
+ return nk_dots_packed_size_i8_neonsdot(width, depth);
1973
+ #elif NK_TARGET_ICELAKE
1974
+ return nk_dots_packed_size_i8_icelake(width, depth);
1975
+ #elif NK_TARGET_SIERRA
1976
+ return nk_dots_packed_size_i8_sierra(width, depth);
1977
+ #elif NK_TARGET_ALDER
1978
+ return nk_dots_packed_size_i8_alder(width, depth);
1979
+ #elif NK_TARGET_HASWELL
1980
+ return nk_dots_packed_size_i8_haswell(width, depth);
1981
+ #elif NK_TARGET_RVV
1982
+ return nk_dots_packed_size_i8_rvv(width, depth);
1983
+ #elif NK_TARGET_V128RELAXED
1984
+ return nk_dots_packed_size_i8_v128relaxed(width, depth);
1985
+ #else
1986
+ return nk_dots_packed_size_i8_serial(width, depth);
1987
+ #endif
1988
+ }
1989
+
1990
+ NK_PUBLIC void nk_dots_pack_i8(nk_i8_t const *b, nk_size_t width, nk_size_t depth, nk_size_t b_stride, void *b_packed) {
1991
+ #if NK_TARGET_SME
1992
+ nk_dots_pack_i8_sme(b, width, depth, b_stride, b_packed);
1993
+ #elif NK_TARGET_SAPPHIREAMX
1994
+ nk_dots_pack_i8_sapphireamx(b, width, depth, b_stride, b_packed);
1995
+ #elif NK_TARGET_NEONSDOT
1996
+ nk_dots_pack_i8_neonsdot(b, width, depth, b_stride, b_packed);
1997
+ #elif NK_TARGET_ICELAKE
1998
+ nk_dots_pack_i8_icelake(b, width, depth, b_stride, b_packed);
1999
+ #elif NK_TARGET_SIERRA
2000
+ nk_dots_pack_i8_sierra(b, width, depth, b_stride, b_packed);
2001
+ #elif NK_TARGET_ALDER
2002
+ nk_dots_pack_i8_alder(b, width, depth, b_stride, b_packed);
2003
+ #elif NK_TARGET_HASWELL
2004
+ nk_dots_pack_i8_haswell(b, width, depth, b_stride, b_packed);
2005
+ #elif NK_TARGET_RVV
2006
+ nk_dots_pack_i8_rvv(b, width, depth, b_stride, b_packed);
2007
+ #elif NK_TARGET_V128RELAXED
2008
+ nk_dots_pack_i8_v128relaxed(b, width, depth, b_stride, b_packed);
2009
+ #else
2010
+ nk_dots_pack_i8_serial(b, width, depth, b_stride, b_packed);
2011
+ #endif
2012
+ }
2013
+
2014
+ NK_PUBLIC void nk_dots_packed_i8(nk_i8_t const *a, void const *b_packed, nk_i32_t *c, nk_size_t height, nk_size_t width,
2015
+ nk_size_t depth, nk_size_t a_stride, nk_size_t c_stride) {
2016
+ #if NK_TARGET_SME
2017
+ nk_dots_packed_i8_sme(a, b_packed, c, height, width, depth, a_stride, c_stride);
2018
+ #elif NK_TARGET_SAPPHIREAMX
2019
+ nk_dots_packed_i8_sapphireamx(a, b_packed, c, height, width, depth, a_stride, c_stride);
2020
+ #elif NK_TARGET_NEONSDOT
2021
+ nk_dots_packed_i8_neonsdot(a, b_packed, c, height, width, depth, a_stride, c_stride);
2022
+ #elif NK_TARGET_ICELAKE
2023
+ nk_dots_packed_i8_icelake(a, b_packed, c, height, width, depth, a_stride, c_stride);
2024
+ #elif NK_TARGET_SIERRA
2025
+ nk_dots_packed_i8_sierra(a, b_packed, c, height, width, depth, a_stride, c_stride);
2026
+ #elif NK_TARGET_ALDER
2027
+ nk_dots_packed_i8_alder(a, b_packed, c, height, width, depth, a_stride, c_stride);
2028
+ #elif NK_TARGET_HASWELL
2029
+ nk_dots_packed_i8_haswell(a, b_packed, c, height, width, depth, a_stride, c_stride);
2030
+ #elif NK_TARGET_RVV
2031
+ nk_dots_packed_i8_rvv(a, b_packed, c, height, width, depth, a_stride, c_stride);
2032
+ #elif NK_TARGET_V128RELAXED
2033
+ nk_dots_packed_i8_v128relaxed(a, b_packed, c, height, width, depth, a_stride, c_stride);
2034
+ #else
2035
+ nk_dots_packed_i8_serial(a, b_packed, c, height, width, depth, a_stride, c_stride);
2036
+ #endif
2037
+ }
2038
+
2039
+ NK_PUBLIC nk_size_t nk_dots_packed_size_u8(nk_size_t width, nk_size_t depth) {
2040
+ #if NK_TARGET_SME
2041
+ return nk_dots_packed_size_u8_sme(width, depth);
2042
+ #elif NK_TARGET_SAPPHIREAMX
2043
+ return nk_dots_packed_size_u8_sapphireamx(width, depth);
2044
+ #elif NK_TARGET_NEONSDOT
2045
+ return nk_dots_packed_size_u8_neonsdot(width, depth);
2046
+ #elif NK_TARGET_ICELAKE
2047
+ return nk_dots_packed_size_u8_icelake(width, depth);
2048
+ #elif NK_TARGET_SIERRA
2049
+ return nk_dots_packed_size_u8_sierra(width, depth);
2050
+ #elif NK_TARGET_ALDER
2051
+ return nk_dots_packed_size_u8_alder(width, depth);
2052
+ #elif NK_TARGET_HASWELL
2053
+ return nk_dots_packed_size_u8_haswell(width, depth);
2054
+ #elif NK_TARGET_RVV
2055
+ return nk_dots_packed_size_u8_rvv(width, depth);
2056
+ #elif NK_TARGET_V128RELAXED
2057
+ return nk_dots_packed_size_u8_v128relaxed(width, depth);
2058
+ #else
2059
+ return nk_dots_packed_size_u8_serial(width, depth);
2060
+ #endif
2061
+ }
2062
+
2063
+ NK_PUBLIC void nk_dots_pack_u8(nk_u8_t const *b, nk_size_t width, nk_size_t depth, nk_size_t b_stride, void *b_packed) {
2064
+ #if NK_TARGET_SME
2065
+ nk_dots_pack_u8_sme(b, width, depth, b_stride, b_packed);
2066
+ #elif NK_TARGET_SAPPHIREAMX
2067
+ nk_dots_pack_u8_sapphireamx(b, width, depth, b_stride, b_packed);
2068
+ #elif NK_TARGET_NEONSDOT
2069
+ nk_dots_pack_u8_neonsdot(b, width, depth, b_stride, b_packed);
2070
+ #elif NK_TARGET_ICELAKE
2071
+ nk_dots_pack_u8_icelake(b, width, depth, b_stride, b_packed);
2072
+ #elif NK_TARGET_SIERRA
2073
+ nk_dots_pack_u8_sierra(b, width, depth, b_stride, b_packed);
2074
+ #elif NK_TARGET_ALDER
2075
+ nk_dots_pack_u8_alder(b, width, depth, b_stride, b_packed);
2076
+ #elif NK_TARGET_HASWELL
2077
+ nk_dots_pack_u8_haswell(b, width, depth, b_stride, b_packed);
2078
+ #elif NK_TARGET_RVV
2079
+ nk_dots_pack_u8_rvv(b, width, depth, b_stride, b_packed);
2080
+ #elif NK_TARGET_V128RELAXED
2081
+ nk_dots_pack_u8_v128relaxed(b, width, depth, b_stride, b_packed);
2082
+ #else
2083
+ nk_dots_pack_u8_serial(b, width, depth, b_stride, b_packed);
2084
+ #endif
2085
+ }
2086
+
2087
+ NK_PUBLIC void nk_dots_packed_u8(nk_u8_t const *a, void const *b_packed, nk_u32_t *c, nk_size_t height, nk_size_t width,
2088
+ nk_size_t depth, nk_size_t a_stride, nk_size_t c_stride) {
2089
+ #if NK_TARGET_SME
2090
+ nk_dots_packed_u8_sme(a, b_packed, c, height, width, depth, a_stride, c_stride);
2091
+ #elif NK_TARGET_SAPPHIREAMX
2092
+ nk_dots_packed_u8_sapphireamx(a, b_packed, c, height, width, depth, a_stride, c_stride);
2093
+ #elif NK_TARGET_NEONSDOT
2094
+ nk_dots_packed_u8_neonsdot(a, b_packed, c, height, width, depth, a_stride, c_stride);
2095
+ #elif NK_TARGET_ICELAKE
2096
+ nk_dots_packed_u8_icelake(a, b_packed, c, height, width, depth, a_stride, c_stride);
2097
+ #elif NK_TARGET_SIERRA
2098
+ nk_dots_packed_u8_sierra(a, b_packed, c, height, width, depth, a_stride, c_stride);
2099
+ #elif NK_TARGET_ALDER
2100
+ nk_dots_packed_u8_alder(a, b_packed, c, height, width, depth, a_stride, c_stride);
2101
+ #elif NK_TARGET_HASWELL
2102
+ nk_dots_packed_u8_haswell(a, b_packed, c, height, width, depth, a_stride, c_stride);
2103
+ #elif NK_TARGET_RVV
2104
+ nk_dots_packed_u8_rvv(a, b_packed, c, height, width, depth, a_stride, c_stride);
2105
+ #elif NK_TARGET_V128RELAXED
2106
+ nk_dots_packed_u8_v128relaxed(a, b_packed, c, height, width, depth, a_stride, c_stride);
2107
+ #else
2108
+ nk_dots_packed_u8_serial(a, b_packed, c, height, width, depth, a_stride, c_stride);
2109
+ #endif
2110
+ }
2111
+
2112
+ NK_PUBLIC nk_size_t nk_dots_packed_size_e4m3(nk_size_t width, nk_size_t depth) {
2113
+ #if NK_TARGET_SME
2114
+ return nk_dots_packed_size_e4m3_sme(width, depth);
2115
+ #elif NK_TARGET_SAPPHIREAMX
2116
+ return nk_dots_packed_size_e4m3_sapphireamx(width, depth);
2117
+ #elif NK_TARGET_NEONFHM
2118
+ return nk_dots_packed_size_e4m3_neonfhm(width, depth);
2119
+ #elif NK_TARGET_GENOA
2120
+ return nk_dots_packed_size_e4m3_genoa(width, depth);
2121
+ #elif NK_TARGET_SKYLAKE
2122
+ return nk_dots_packed_size_e4m3_skylake(width, depth);
2123
+ #elif NK_TARGET_HASWELL
2124
+ return nk_dots_packed_size_e4m3_haswell(width, depth);
2125
+ #elif NK_TARGET_RVV
2126
+ return nk_dots_packed_size_e4m3_rvv(width, depth);
2127
+ #elif NK_TARGET_V128RELAXED
2128
+ return nk_dots_packed_size_e4m3_v128relaxed(width, depth);
2129
+ #else
2130
+ return nk_dots_packed_size_e4m3_serial(width, depth);
2131
+ #endif
2132
+ }
2133
+
2134
+ NK_PUBLIC void nk_dots_pack_e4m3(nk_e4m3_t const *b, nk_size_t width, nk_size_t depth, nk_size_t b_stride,
2135
+ void *b_packed) {
2136
+ #if NK_TARGET_SME
2137
+ nk_dots_pack_e4m3_sme(b, width, depth, b_stride, b_packed);
2138
+ #elif NK_TARGET_SAPPHIREAMX
2139
+ nk_dots_pack_e4m3_sapphireamx(b, width, depth, b_stride, b_packed);
2140
+ #elif NK_TARGET_NEONFHM
2141
+ nk_dots_pack_e4m3_neonfhm(b, width, depth, b_stride, b_packed);
2142
+ #elif NK_TARGET_GENOA
2143
+ nk_dots_pack_e4m3_genoa(b, width, depth, b_stride, b_packed);
2144
+ #elif NK_TARGET_SKYLAKE
2145
+ nk_dots_pack_e4m3_skylake(b, width, depth, b_stride, b_packed);
2146
+ #elif NK_TARGET_HASWELL
2147
+ nk_dots_pack_e4m3_haswell(b, width, depth, b_stride, b_packed);
2148
+ #elif NK_TARGET_RVV
2149
+ nk_dots_pack_e4m3_rvv(b, width, depth, b_stride, b_packed);
2150
+ #elif NK_TARGET_V128RELAXED
2151
+ nk_dots_pack_e4m3_v128relaxed(b, width, depth, b_stride, b_packed);
2152
+ #else
2153
+ nk_dots_pack_e4m3_serial(b, width, depth, b_stride, b_packed);
2154
+ #endif
2155
+ }
2156
+
2157
+ NK_PUBLIC void nk_dots_packed_e4m3(nk_e4m3_t const *a, void const *b_packed, nk_f32_t *c, nk_size_t height,
2158
+ nk_size_t width, nk_size_t depth, nk_size_t a_stride, nk_size_t c_stride) {
2159
+ #if NK_TARGET_SME
2160
+ nk_dots_packed_e4m3_sme(a, b_packed, c, height, width, depth, a_stride, c_stride);
2161
+ #elif NK_TARGET_SAPPHIREAMX
2162
+ nk_dots_packed_e4m3_sapphireamx(a, b_packed, c, height, width, depth, a_stride, c_stride);
2163
+ #elif NK_TARGET_NEONFHM
2164
+ nk_dots_packed_e4m3_neonfhm(a, b_packed, c, height, width, depth, a_stride, c_stride);
2165
+ #elif NK_TARGET_GENOA
2166
+ nk_dots_packed_e4m3_genoa(a, b_packed, c, height, width, depth, a_stride, c_stride);
2167
+ #elif NK_TARGET_SKYLAKE
2168
+ nk_dots_packed_e4m3_skylake(a, b_packed, c, height, width, depth, a_stride, c_stride);
2169
+ #elif NK_TARGET_HASWELL
2170
+ nk_dots_packed_e4m3_haswell(a, b_packed, c, height, width, depth, a_stride, c_stride);
2171
+ #elif NK_TARGET_RVV
2172
+ nk_dots_packed_e4m3_rvv(a, b_packed, c, height, width, depth, a_stride, c_stride);
2173
+ #elif NK_TARGET_V128RELAXED
2174
+ nk_dots_packed_e4m3_v128relaxed(a, b_packed, c, height, width, depth, a_stride, c_stride);
2175
+ #else
2176
+ nk_dots_packed_e4m3_serial(a, b_packed, c, height, width, depth, a_stride, c_stride);
2177
+ #endif
2178
+ }
2179
+
2180
+ NK_PUBLIC nk_size_t nk_dots_packed_size_e5m2(nk_size_t width, nk_size_t depth) {
2181
+ #if NK_TARGET_SME
2182
+ return nk_dots_packed_size_e5m2_sme(width, depth);
2183
+ #elif NK_TARGET_SAPPHIREAMX
2184
+ return nk_dots_packed_size_e5m2_sapphireamx(width, depth);
2185
+ #elif NK_TARGET_NEONFHM
2186
+ return nk_dots_packed_size_e5m2_neonfhm(width, depth);
2187
+ #elif NK_TARGET_GENOA
2188
+ return nk_dots_packed_size_e5m2_genoa(width, depth);
2189
+ #elif NK_TARGET_SKYLAKE
2190
+ return nk_dots_packed_size_e5m2_skylake(width, depth);
2191
+ #elif NK_TARGET_HASWELL
2192
+ return nk_dots_packed_size_e5m2_haswell(width, depth);
2193
+ #elif NK_TARGET_RVV
2194
+ return nk_dots_packed_size_e5m2_rvv(width, depth);
2195
+ #elif NK_TARGET_V128RELAXED
2196
+ return nk_dots_packed_size_e5m2_v128relaxed(width, depth);
2197
+ #else
2198
+ return nk_dots_packed_size_e5m2_serial(width, depth);
2199
+ #endif
2200
+ }
2201
+
2202
+ NK_PUBLIC void nk_dots_pack_e5m2(nk_e5m2_t const *b, nk_size_t width, nk_size_t depth, nk_size_t b_stride,
2203
+ void *b_packed) {
2204
+ #if NK_TARGET_SME
2205
+ nk_dots_pack_e5m2_sme(b, width, depth, b_stride, b_packed);
2206
+ #elif NK_TARGET_SAPPHIREAMX
2207
+ nk_dots_pack_e5m2_sapphireamx(b, width, depth, b_stride, b_packed);
2208
+ #elif NK_TARGET_NEONFHM
2209
+ nk_dots_pack_e5m2_neonfhm(b, width, depth, b_stride, b_packed);
2210
+ #elif NK_TARGET_GENOA
2211
+ nk_dots_pack_e5m2_genoa(b, width, depth, b_stride, b_packed);
2212
+ #elif NK_TARGET_SKYLAKE
2213
+ nk_dots_pack_e5m2_skylake(b, width, depth, b_stride, b_packed);
2214
+ #elif NK_TARGET_HASWELL
2215
+ nk_dots_pack_e5m2_haswell(b, width, depth, b_stride, b_packed);
2216
+ #elif NK_TARGET_RVV
2217
+ nk_dots_pack_e5m2_rvv(b, width, depth, b_stride, b_packed);
2218
+ #elif NK_TARGET_V128RELAXED
2219
+ nk_dots_pack_e5m2_v128relaxed(b, width, depth, b_stride, b_packed);
2220
+ #else
2221
+ nk_dots_pack_e5m2_serial(b, width, depth, b_stride, b_packed);
2222
+ #endif
2223
+ }
2224
+
2225
+ NK_PUBLIC void nk_dots_packed_e5m2(nk_e5m2_t const *a, void const *b_packed, nk_f32_t *c, nk_size_t height,
2226
+ nk_size_t width, nk_size_t depth, nk_size_t a_stride, nk_size_t c_stride) {
2227
+ #if NK_TARGET_SME
2228
+ nk_dots_packed_e5m2_sme(a, b_packed, c, height, width, depth, a_stride, c_stride);
2229
+ #elif NK_TARGET_SAPPHIREAMX
2230
+ nk_dots_packed_e5m2_sapphireamx(a, b_packed, c, height, width, depth, a_stride, c_stride);
2231
+ #elif NK_TARGET_NEONFHM
2232
+ nk_dots_packed_e5m2_neonfhm(a, b_packed, c, height, width, depth, a_stride, c_stride);
2233
+ #elif NK_TARGET_GENOA
2234
+ nk_dots_packed_e5m2_genoa(a, b_packed, c, height, width, depth, a_stride, c_stride);
2235
+ #elif NK_TARGET_SKYLAKE
2236
+ nk_dots_packed_e5m2_skylake(a, b_packed, c, height, width, depth, a_stride, c_stride);
2237
+ #elif NK_TARGET_HASWELL
2238
+ nk_dots_packed_e5m2_haswell(a, b_packed, c, height, width, depth, a_stride, c_stride);
2239
+ #elif NK_TARGET_RVV
2240
+ nk_dots_packed_e5m2_rvv(a, b_packed, c, height, width, depth, a_stride, c_stride);
2241
+ #elif NK_TARGET_V128RELAXED
2242
+ nk_dots_packed_e5m2_v128relaxed(a, b_packed, c, height, width, depth, a_stride, c_stride);
2243
+ #else
2244
+ nk_dots_packed_e5m2_serial(a, b_packed, c, height, width, depth, a_stride, c_stride);
2245
+ #endif
2246
+ }
2247
+
2248
+ NK_PUBLIC nk_size_t nk_dots_packed_size_e2m3(nk_size_t width, nk_size_t depth) {
2249
+ #if NK_TARGET_SME
2250
+ return nk_dots_packed_size_e2m3_sme(width, depth);
2251
+ #elif NK_TARGET_SAPPHIREAMX
2252
+ return nk_dots_packed_size_e2m3_sapphireamx(width, depth);
2253
+ #elif NK_TARGET_SKYLAKE
2254
+ return nk_dots_packed_size_e2m3_skylake(width, depth);
2255
+ #elif NK_TARGET_SIERRA
2256
+ return nk_dots_packed_size_e2m3_sierra(width, depth);
2257
+ #elif NK_TARGET_ALDER
2258
+ return nk_dots_packed_size_e2m3_alder(width, depth);
2259
+ #elif NK_TARGET_HASWELL
2260
+ return nk_dots_packed_size_e2m3_haswell(width, depth);
2261
+ #elif NK_TARGET_RVV
2262
+ return nk_dots_packed_size_e2m3_rvv(width, depth);
2263
+ #elif NK_TARGET_V128RELAXED
2264
+ return nk_dots_packed_size_e2m3_v128relaxed(width, depth);
2265
+ #else
2266
+ return nk_dots_packed_size_e2m3_serial(width, depth);
2267
+ #endif
2268
+ }
2269
+
2270
+ NK_PUBLIC void nk_dots_pack_e2m3(nk_e2m3_t const *b, nk_size_t width, nk_size_t depth, nk_size_t b_stride,
2271
+ void *b_packed) {
2272
+ #if NK_TARGET_SME
2273
+ nk_dots_pack_e2m3_sme(b, width, depth, b_stride, b_packed);
2274
+ #elif NK_TARGET_SAPPHIREAMX
2275
+ nk_dots_pack_e2m3_sapphireamx(b, width, depth, b_stride, b_packed);
2276
+ #elif NK_TARGET_SKYLAKE
2277
+ nk_dots_pack_e2m3_skylake(b, width, depth, b_stride, b_packed);
2278
+ #elif NK_TARGET_SIERRA
2279
+ nk_dots_pack_e2m3_sierra(b, width, depth, b_stride, b_packed);
2280
+ #elif NK_TARGET_ALDER
2281
+ nk_dots_pack_e2m3_alder(b, width, depth, b_stride, b_packed);
2282
+ #elif NK_TARGET_HASWELL
2283
+ nk_dots_pack_e2m3_haswell(b, width, depth, b_stride, b_packed);
2284
+ #elif NK_TARGET_RVV
2285
+ nk_dots_pack_e2m3_rvv(b, width, depth, b_stride, b_packed);
2286
+ #elif NK_TARGET_V128RELAXED
2287
+ nk_dots_pack_e2m3_v128relaxed(b, width, depth, b_stride, b_packed);
2288
+ #else
2289
+ nk_dots_pack_e2m3_serial(b, width, depth, b_stride, b_packed);
2290
+ #endif
2291
+ }
2292
+
2293
+ NK_PUBLIC void nk_dots_packed_e2m3(nk_e2m3_t const *a, void const *b_packed, nk_f32_t *c, nk_size_t height,
2294
+ nk_size_t width, nk_size_t depth, nk_size_t a_stride, nk_size_t c_stride) {
2295
+ #if NK_TARGET_SME
2296
+ nk_dots_packed_e2m3_sme(a, b_packed, c, height, width, depth, a_stride, c_stride);
2297
+ #elif NK_TARGET_SAPPHIREAMX
2298
+ nk_dots_packed_e2m3_sapphireamx(a, b_packed, c, height, width, depth, a_stride, c_stride);
2299
+ #elif NK_TARGET_SKYLAKE
2300
+ nk_dots_packed_e2m3_skylake(a, b_packed, c, height, width, depth, a_stride, c_stride);
2301
+ #elif NK_TARGET_SIERRA
2302
+ nk_dots_packed_e2m3_sierra(a, b_packed, c, height, width, depth, a_stride, c_stride);
2303
+ #elif NK_TARGET_ALDER
2304
+ nk_dots_packed_e2m3_alder(a, b_packed, c, height, width, depth, a_stride, c_stride);
2305
+ #elif NK_TARGET_HASWELL
2306
+ nk_dots_packed_e2m3_haswell(a, b_packed, c, height, width, depth, a_stride, c_stride);
2307
+ #elif NK_TARGET_RVV
2308
+ nk_dots_packed_e2m3_rvv(a, b_packed, c, height, width, depth, a_stride, c_stride);
2309
+ #elif NK_TARGET_V128RELAXED
2310
+ nk_dots_packed_e2m3_v128relaxed(a, b_packed, c, height, width, depth, a_stride, c_stride);
2311
+ #else
2312
+ nk_dots_packed_e2m3_serial(a, b_packed, c, height, width, depth, a_stride, c_stride);
2313
+ #endif
2314
+ }
2315
+
2316
+ NK_PUBLIC nk_size_t nk_dots_packed_size_e3m2(nk_size_t width, nk_size_t depth) {
2317
+ #if NK_TARGET_SME
2318
+ return nk_dots_packed_size_e3m2_sme(width, depth);
2319
+ #elif NK_TARGET_SAPPHIREAMX
2320
+ return nk_dots_packed_size_e3m2_sapphireamx(width, depth);
2321
+ #elif NK_TARGET_SKYLAKE
2322
+ return nk_dots_packed_size_e3m2_skylake(width, depth);
2323
+ #elif NK_TARGET_HASWELL
2324
+ return nk_dots_packed_size_e3m2_haswell(width, depth);
2325
+ #elif NK_TARGET_RVV
2326
+ return nk_dots_packed_size_e3m2_rvv(width, depth);
2327
+ #else
2328
+ return nk_dots_packed_size_e3m2_serial(width, depth);
2329
+ #endif
2330
+ }
2331
+
2332
+ NK_PUBLIC void nk_dots_pack_e3m2(nk_e3m2_t const *b, nk_size_t width, nk_size_t depth, nk_size_t b_stride,
2333
+ void *b_packed) {
2334
+ #if NK_TARGET_SME
2335
+ nk_dots_pack_e3m2_sme(b, width, depth, b_stride, b_packed);
2336
+ #elif NK_TARGET_SAPPHIREAMX
2337
+ nk_dots_pack_e3m2_sapphireamx(b, width, depth, b_stride, b_packed);
2338
+ #elif NK_TARGET_SKYLAKE
2339
+ nk_dots_pack_e3m2_skylake(b, width, depth, b_stride, b_packed);
2340
+ #elif NK_TARGET_HASWELL
2341
+ nk_dots_pack_e3m2_haswell(b, width, depth, b_stride, b_packed);
2342
+ #elif NK_TARGET_RVV
2343
+ nk_dots_pack_e3m2_rvv(b, width, depth, b_stride, b_packed);
2344
+ #else
2345
+ nk_dots_pack_e3m2_serial(b, width, depth, b_stride, b_packed);
2346
+ #endif
2347
+ }
2348
+
2349
+ NK_PUBLIC void nk_dots_packed_e3m2(nk_e3m2_t const *a, void const *b_packed, nk_f32_t *c, nk_size_t height,
2350
+ nk_size_t width, nk_size_t depth, nk_size_t a_stride, nk_size_t c_stride) {
2351
+ #if NK_TARGET_SME
2352
+ nk_dots_packed_e3m2_sme(a, b_packed, c, height, width, depth, a_stride, c_stride);
2353
+ #elif NK_TARGET_SAPPHIREAMX
2354
+ nk_dots_packed_e3m2_sapphireamx(a, b_packed, c, height, width, depth, a_stride, c_stride);
2355
+ #elif NK_TARGET_SKYLAKE
2356
+ nk_dots_packed_e3m2_skylake(a, b_packed, c, height, width, depth, a_stride, c_stride);
2357
+ #elif NK_TARGET_HASWELL
2358
+ nk_dots_packed_e3m2_haswell(a, b_packed, c, height, width, depth, a_stride, c_stride);
2359
+ #elif NK_TARGET_RVV
2360
+ nk_dots_packed_e3m2_rvv(a, b_packed, c, height, width, depth, a_stride, c_stride);
2361
+ #else
2362
+ nk_dots_packed_e3m2_serial(a, b_packed, c, height, width, depth, a_stride, c_stride);
2363
+ #endif
2364
+ }
2365
+
2366
+ NK_PUBLIC nk_size_t nk_dots_packed_size_u4(nk_size_t width, nk_size_t depth) {
2367
+ #if NK_TARGET_SME
2368
+ return nk_dots_packed_size_u4_sme(width, depth);
2369
+ #elif NK_TARGET_ICELAKE
2370
+ return nk_dots_packed_size_u4_icelake(width, depth);
2371
+ #elif NK_TARGET_NEONSDOT
2372
+ return nk_dots_packed_size_u4_neonsdot(width, depth);
2373
+ #elif NK_TARGET_HASWELL
2374
+ return nk_dots_packed_size_u4_haswell(width, depth);
2375
+ #elif NK_TARGET_V128RELAXED
2376
+ return nk_dots_packed_size_u4_v128relaxed(width, depth);
2377
+ #else
2378
+ return nk_dots_packed_size_u4_serial(width, depth);
2379
+ #endif
2380
+ }
2381
+
2382
+ NK_PUBLIC void nk_dots_pack_u4(nk_u4x2_t const *b, nk_size_t width, nk_size_t depth, nk_size_t b_stride,
2383
+ void *b_packed) {
2384
+ #if NK_TARGET_SME
2385
+ nk_dots_pack_u4_sme(b, width, depth, b_stride, b_packed);
2386
+ #elif NK_TARGET_ICELAKE
2387
+ nk_dots_pack_u4_icelake(b, width, depth, b_stride, b_packed);
2388
+ #elif NK_TARGET_NEONSDOT
2389
+ nk_dots_pack_u4_neonsdot(b, width, depth, b_stride, b_packed);
2390
+ #elif NK_TARGET_HASWELL
2391
+ nk_dots_pack_u4_haswell(b, width, depth, b_stride, b_packed);
2392
+ #elif NK_TARGET_V128RELAXED
2393
+ nk_dots_pack_u4_v128relaxed(b, width, depth, b_stride, b_packed);
2394
+ #else
2395
+ nk_dots_pack_u4_serial(b, width, depth, b_stride, b_packed);
2396
+ #endif
2397
+ }
2398
+
2399
+ NK_PUBLIC void nk_dots_packed_u4(nk_u4x2_t const *a, void const *b_packed, nk_u32_t *c, nk_size_t height,
2400
+ nk_size_t width, nk_size_t depth, nk_size_t a_stride, nk_size_t c_stride) {
2401
+ #if NK_TARGET_SME
2402
+ nk_dots_packed_u4_sme(a, b_packed, c, height, width, depth, a_stride, c_stride);
2403
+ #elif NK_TARGET_ICELAKE
2404
+ nk_dots_packed_u4_icelake(a, b_packed, c, height, width, depth, a_stride, c_stride);
2405
+ #elif NK_TARGET_NEONSDOT
2406
+ nk_dots_packed_u4_neonsdot(a, b_packed, c, height, width, depth, a_stride, c_stride);
2407
+ #elif NK_TARGET_HASWELL
2408
+ nk_dots_packed_u4_haswell(a, b_packed, c, height, width, depth, a_stride, c_stride);
2409
+ #elif NK_TARGET_V128RELAXED
2410
+ nk_dots_packed_u4_v128relaxed(a, b_packed, c, height, width, depth, a_stride, c_stride);
2411
+ #else
2412
+ nk_dots_packed_u4_serial(a, b_packed, c, height, width, depth, a_stride, c_stride);
2413
+ #endif
2414
+ }
2415
+
2416
+ NK_PUBLIC nk_size_t nk_dots_packed_size_u1(nk_size_t width, nk_size_t depth) {
2417
+ #if NK_TARGET_SMEBI32
2418
+ return nk_dots_packed_size_u1_smebi32(width, depth);
2419
+ #elif NK_TARGET_ICELAKE
2420
+ return nk_dots_packed_size_u1_icelake(width, depth);
2421
+ #elif NK_TARGET_HASWELL
2422
+ return nk_dots_packed_size_u1_haswell(width, depth);
2423
+ #elif NK_TARGET_NEON
2424
+ return nk_dots_packed_size_u1_neon(width, depth);
2425
+ #elif NK_TARGET_V128RELAXED
2426
+ return nk_dots_packed_size_u1_v128relaxed(width, depth);
2427
+ #else
2428
+ return nk_dots_packed_size_u1_serial(width, depth);
2429
+ #endif
2430
+ }
2431
+
2432
+ NK_PUBLIC void nk_dots_pack_u1(nk_u1x8_t const *b, nk_size_t width, nk_size_t depth, nk_size_t b_stride,
2433
+ void *b_packed) {
2434
+ #if NK_TARGET_SMEBI32
2435
+ nk_dots_pack_u1_smebi32(b, width, depth, b_stride, b_packed);
2436
+ #elif NK_TARGET_ICELAKE
2437
+ nk_dots_pack_u1_icelake(b, width, depth, b_stride, b_packed);
2438
+ #elif NK_TARGET_HASWELL
2439
+ nk_dots_pack_u1_haswell(b, width, depth, b_stride, b_packed);
2440
+ #elif NK_TARGET_NEON
2441
+ nk_dots_pack_u1_neon(b, width, depth, b_stride, b_packed);
2442
+ #elif NK_TARGET_V128RELAXED
2443
+ nk_dots_pack_u1_v128relaxed(b, width, depth, b_stride, b_packed);
2444
+ #else
2445
+ nk_dots_pack_u1_serial(b, width, depth, b_stride, b_packed);
2446
+ #endif
2447
+ }
2448
+
2449
+ NK_PUBLIC void nk_dots_packed_u1(nk_u1x8_t const *a, void const *b_packed, nk_u32_t *c, nk_size_t height,
2450
+ nk_size_t width, nk_size_t depth, nk_size_t a_stride, nk_size_t c_stride) {
2451
+ #if NK_TARGET_SMEBI32
2452
+ nk_dots_packed_u1_smebi32(a, b_packed, c, height, width, depth, a_stride, c_stride);
2453
+ #elif NK_TARGET_ICELAKE
2454
+ nk_dots_packed_u1_icelake(a, b_packed, c, height, width, depth, a_stride, c_stride);
2455
+ #elif NK_TARGET_HASWELL
2456
+ nk_dots_packed_u1_haswell(a, b_packed, c, height, width, depth, a_stride, c_stride);
2457
+ #elif NK_TARGET_NEON
2458
+ nk_dots_packed_u1_neon(a, b_packed, c, height, width, depth, a_stride, c_stride);
2459
+ #elif NK_TARGET_V128RELAXED
2460
+ nk_dots_packed_u1_v128relaxed(a, b_packed, c, height, width, depth, a_stride, c_stride);
2461
+ #else
2462
+ nk_dots_packed_u1_serial(a, b_packed, c, height, width, depth, a_stride, c_stride);
2463
+ #endif
2464
+ }
2465
+
2466
+ NK_PUBLIC nk_size_t nk_dots_packed_size_i4(nk_size_t width, nk_size_t depth) {
2467
+ #if NK_TARGET_SME
2468
+ return nk_dots_packed_size_i4_sme(width, depth);
2469
+ #elif NK_TARGET_ICELAKE
2470
+ return nk_dots_packed_size_i4_icelake(width, depth);
2471
+ #elif NK_TARGET_NEONSDOT
2472
+ return nk_dots_packed_size_i4_neonsdot(width, depth);
2473
+ #elif NK_TARGET_HASWELL
2474
+ return nk_dots_packed_size_i4_haswell(width, depth);
2475
+ #elif NK_TARGET_V128RELAXED
2476
+ return nk_dots_packed_size_i4_v128relaxed(width, depth);
2477
+ #else
2478
+ return nk_dots_packed_size_i4_serial(width, depth);
2479
+ #endif
2480
+ }
2481
+
2482
+ NK_PUBLIC void nk_dots_pack_i4(nk_i4x2_t const *b, nk_size_t width, nk_size_t depth, nk_size_t b_stride,
2483
+ void *b_packed) {
2484
+ #if NK_TARGET_SME
2485
+ nk_dots_pack_i4_sme(b, width, depth, b_stride, b_packed);
2486
+ #elif NK_TARGET_ICELAKE
2487
+ nk_dots_pack_i4_icelake(b, width, depth, b_stride, b_packed);
2488
+ #elif NK_TARGET_NEONSDOT
2489
+ nk_dots_pack_i4_neonsdot(b, width, depth, b_stride, b_packed);
2490
+ #elif NK_TARGET_HASWELL
2491
+ nk_dots_pack_i4_haswell(b, width, depth, b_stride, b_packed);
2492
+ #elif NK_TARGET_V128RELAXED
2493
+ nk_dots_pack_i4_v128relaxed(b, width, depth, b_stride, b_packed);
2494
+ #else
2495
+ nk_dots_pack_i4_serial(b, width, depth, b_stride, b_packed);
2496
+ #endif
2497
+ }
2498
+
2499
+ NK_PUBLIC void nk_dots_packed_i4(nk_i4x2_t const *a, void const *b_packed, nk_i32_t *c, nk_size_t height,
2500
+ nk_size_t width, nk_size_t depth, nk_size_t a_stride, nk_size_t c_stride) {
2501
+ #if NK_TARGET_SME
2502
+ nk_dots_packed_i4_sme(a, b_packed, c, height, width, depth, a_stride, c_stride);
2503
+ #elif NK_TARGET_ICELAKE
2504
+ nk_dots_packed_i4_icelake(a, b_packed, c, height, width, depth, a_stride, c_stride);
2505
+ #elif NK_TARGET_NEONSDOT
2506
+ nk_dots_packed_i4_neonsdot(a, b_packed, c, height, width, depth, a_stride, c_stride);
2507
+ #elif NK_TARGET_HASWELL
2508
+ nk_dots_packed_i4_haswell(a, b_packed, c, height, width, depth, a_stride, c_stride);
2509
+ #elif NK_TARGET_V128RELAXED
2510
+ nk_dots_packed_i4_v128relaxed(a, b_packed, c, height, width, depth, a_stride, c_stride);
2511
+ #else
2512
+ nk_dots_packed_i4_serial(a, b_packed, c, height, width, depth, a_stride, c_stride);
2513
+ #endif
2514
+ }
2515
+
2516
+ NK_PUBLIC void nk_dots_symmetric_f16(nk_f16_t const *vectors, nk_size_t n_vectors, nk_size_t depth, nk_size_t stride,
2517
+ nk_f32_t *result, nk_size_t result_stride, nk_size_t row_start,
2518
+ nk_size_t row_count) {
2519
+ #if NK_TARGET_SME
2520
+ nk_dots_symmetric_f16_sme(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
2521
+ #elif NK_TARGET_NEONHALF
2522
+ nk_dots_symmetric_f16_neonhalf(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
2523
+ #elif NK_TARGET_NEON
2524
+ nk_dots_symmetric_f16_neon(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
2525
+ #elif NK_TARGET_NEONFHM
2526
+ nk_dots_symmetric_f16_neonfhm(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
2527
+ #elif NK_TARGET_SKYLAKE
2528
+ nk_dots_symmetric_f16_skylake(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
2529
+ #elif NK_TARGET_HASWELL
2530
+ nk_dots_symmetric_f16_haswell(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
2531
+ #elif NK_TARGET_RVV
2532
+ nk_dots_symmetric_f16_rvv(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
2533
+ #else
2534
+ nk_dots_symmetric_f16_serial(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
2535
+ #endif
2536
+ }
2537
+
2538
+ NK_PUBLIC void nk_dots_symmetric_bf16(nk_bf16_t const *vectors, nk_size_t n_vectors, nk_size_t depth, nk_size_t stride,
2539
+ nk_f32_t *result, nk_size_t result_stride, nk_size_t row_start,
2540
+ nk_size_t row_count) {
2541
+ #if NK_TARGET_SME
2542
+ nk_dots_symmetric_bf16_sme(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
2543
+ #elif NK_TARGET_SAPPHIREAMX
2544
+ nk_dots_symmetric_bf16_sapphireamx(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
2545
+ #elif NK_TARGET_NEONBFDOT
2546
+ nk_dots_symmetric_bf16_neonbfdot(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
2547
+ #elif NK_TARGET_GENOA
2548
+ nk_dots_symmetric_bf16_genoa(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
2549
+ #elif NK_TARGET_SKYLAKE
2550
+ nk_dots_symmetric_bf16_skylake(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
2551
+ #elif NK_TARGET_HASWELL
2552
+ nk_dots_symmetric_bf16_haswell(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
2553
+ #elif NK_TARGET_RVV
2554
+ nk_dots_symmetric_bf16_rvv(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
2555
+ #elif NK_TARGET_V128RELAXED
2556
+ nk_dots_symmetric_bf16_v128relaxed(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
2557
+ #else
2558
+ nk_dots_symmetric_bf16_serial(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
2559
+ #endif
2560
+ }
2561
+
2562
+ NK_PUBLIC void nk_dots_symmetric_i8(nk_i8_t const *vectors, nk_size_t n_vectors, nk_size_t depth, nk_size_t stride,
2563
+ nk_i32_t *result, nk_size_t result_stride, nk_size_t row_start,
2564
+ nk_size_t row_count) {
2565
+ #if NK_TARGET_SME
2566
+ nk_dots_symmetric_i8_sme(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
2567
+ #elif NK_TARGET_SAPPHIREAMX
2568
+ nk_dots_symmetric_i8_sapphireamx(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
2569
+ #elif NK_TARGET_NEONSDOT
2570
+ nk_dots_symmetric_i8_neonsdot(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
2571
+ #elif NK_TARGET_ICELAKE
2572
+ nk_dots_symmetric_i8_icelake(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
2573
+ #elif NK_TARGET_SIERRA
2574
+ nk_dots_symmetric_i8_sierra(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
2575
+ #elif NK_TARGET_ALDER
2576
+ nk_dots_symmetric_i8_alder(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
2577
+ #elif NK_TARGET_HASWELL
2578
+ nk_dots_symmetric_i8_haswell(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
2579
+ #elif NK_TARGET_RVV
2580
+ nk_dots_symmetric_i8_rvv(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
2581
+ #elif NK_TARGET_V128RELAXED
2582
+ nk_dots_symmetric_i8_v128relaxed(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
2583
+ #else
2584
+ nk_dots_symmetric_i8_serial(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
2585
+ #endif
2586
+ }
2587
+
2588
+ NK_PUBLIC void nk_dots_symmetric_u8(nk_u8_t const *vectors, nk_size_t n_vectors, nk_size_t depth, nk_size_t stride,
2589
+ nk_u32_t *result, nk_size_t result_stride, nk_size_t row_start,
2590
+ nk_size_t row_count) {
2591
+ #if NK_TARGET_SME
2592
+ nk_dots_symmetric_u8_sme(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
2593
+ #elif NK_TARGET_SAPPHIREAMX
2594
+ nk_dots_symmetric_u8_sapphireamx(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
2595
+ #elif NK_TARGET_ICELAKE
2596
+ nk_dots_symmetric_u8_icelake(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
2597
+ #elif NK_TARGET_SIERRA
2598
+ nk_dots_symmetric_u8_sierra(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
2599
+ #elif NK_TARGET_ALDER
2600
+ nk_dots_symmetric_u8_alder(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
2601
+ #elif NK_TARGET_NEONSDOT
2602
+ nk_dots_symmetric_u8_neonsdot(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
2603
+ #elif NK_TARGET_HASWELL
2604
+ nk_dots_symmetric_u8_haswell(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
2605
+ #elif NK_TARGET_RVV
2606
+ nk_dots_symmetric_u8_rvv(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
2607
+ #elif NK_TARGET_V128RELAXED
2608
+ nk_dots_symmetric_u8_v128relaxed(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
2609
+ #else
2610
+ nk_dots_symmetric_u8_serial(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
2611
+ #endif
2612
+ }
2613
+
2614
+ NK_PUBLIC void nk_dots_symmetric_e4m3(nk_e4m3_t const *vectors, nk_size_t n_vectors, nk_size_t depth, nk_size_t stride,
2615
+ nk_f32_t *result, nk_size_t result_stride, nk_size_t row_start,
2616
+ nk_size_t row_count) {
2617
+ #if NK_TARGET_SME
2618
+ nk_dots_symmetric_e4m3_sme(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
2619
+ #elif NK_TARGET_NEONFHM
2620
+ nk_dots_symmetric_e4m3_neonfhm(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
2621
+ #elif NK_TARGET_SAPPHIREAMX
2622
+ nk_dots_symmetric_e4m3_sapphireamx(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
2623
+ #elif NK_TARGET_GENOA
2624
+ nk_dots_symmetric_e4m3_genoa(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
2625
+ #elif NK_TARGET_SKYLAKE
2626
+ nk_dots_symmetric_e4m3_skylake(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
2627
+ #elif NK_TARGET_HASWELL
2628
+ nk_dots_symmetric_e4m3_haswell(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
2629
+ #elif NK_TARGET_RVV
2630
+ nk_dots_symmetric_e4m3_rvv(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
2631
+ #elif NK_TARGET_V128RELAXED
2632
+ nk_dots_symmetric_e4m3_v128relaxed(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
2633
+ #else
2634
+ nk_dots_symmetric_e4m3_serial(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
2635
+ #endif
2636
+ }
2637
+
2638
+ NK_PUBLIC void nk_dots_symmetric_e5m2(nk_e5m2_t const *vectors, nk_size_t n_vectors, nk_size_t depth, nk_size_t stride,
2639
+ nk_f32_t *result, nk_size_t result_stride, nk_size_t row_start,
2640
+ nk_size_t row_count) {
2641
+ #if NK_TARGET_SME
2642
+ nk_dots_symmetric_e5m2_sme(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
2643
+ #elif NK_TARGET_NEONFHM
2644
+ nk_dots_symmetric_e5m2_neonfhm(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
2645
+ #elif NK_TARGET_SAPPHIREAMX
2646
+ nk_dots_symmetric_e5m2_sapphireamx(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
2647
+ #elif NK_TARGET_GENOA
2648
+ nk_dots_symmetric_e5m2_genoa(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
2649
+ #elif NK_TARGET_SKYLAKE
2650
+ nk_dots_symmetric_e5m2_skylake(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
2651
+ #elif NK_TARGET_HASWELL
2652
+ nk_dots_symmetric_e5m2_haswell(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
2653
+ #elif NK_TARGET_RVV
2654
+ nk_dots_symmetric_e5m2_rvv(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
2655
+ #elif NK_TARGET_V128RELAXED
2656
+ nk_dots_symmetric_e5m2_v128relaxed(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
2657
+ #else
2658
+ nk_dots_symmetric_e5m2_serial(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
2659
+ #endif
2660
+ }
2661
+
2662
+ NK_PUBLIC void nk_dots_symmetric_e2m3(nk_e2m3_t const *vectors, nk_size_t n_vectors, nk_size_t depth, nk_size_t stride,
2663
+ nk_f32_t *result, nk_size_t result_stride, nk_size_t row_start,
2664
+ nk_size_t row_count) {
2665
+ #if NK_TARGET_SME
2666
+ nk_dots_symmetric_e2m3_sme(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
2667
+ #elif NK_TARGET_SAPPHIREAMX
2668
+ nk_dots_symmetric_e2m3_sapphireamx(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
2669
+ #elif NK_TARGET_SKYLAKE
2670
+ nk_dots_symmetric_e2m3_skylake(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
2671
+ #elif NK_TARGET_SIERRA
2672
+ nk_dots_symmetric_e2m3_sierra(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
2673
+ #elif NK_TARGET_ALDER
2674
+ nk_dots_symmetric_e2m3_alder(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
2675
+ #elif NK_TARGET_HASWELL
2676
+ nk_dots_symmetric_e2m3_haswell(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
2677
+ #elif NK_TARGET_RVV
2678
+ nk_dots_symmetric_e2m3_rvv(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
2679
+ #elif NK_TARGET_V128RELAXED
2680
+ nk_dots_symmetric_e2m3_v128relaxed(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
2681
+ #else
2682
+ nk_dots_symmetric_e2m3_serial(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
2683
+ #endif
2684
+ }
2685
+
2686
+ NK_PUBLIC void nk_dots_symmetric_e3m2(nk_e3m2_t const *vectors, nk_size_t n_vectors, nk_size_t depth, nk_size_t stride,
2687
+ nk_f32_t *result, nk_size_t result_stride, nk_size_t row_start,
2688
+ nk_size_t row_count) {
2689
+ #if NK_TARGET_SME
2690
+ nk_dots_symmetric_e3m2_sme(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
2691
+ #elif NK_TARGET_SAPPHIREAMX
2692
+ nk_dots_symmetric_e3m2_sapphireamx(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
2693
+ #elif NK_TARGET_SKYLAKE
2694
+ nk_dots_symmetric_e3m2_skylake(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
2695
+ #elif NK_TARGET_HASWELL
2696
+ nk_dots_symmetric_e3m2_haswell(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
2697
+ #elif NK_TARGET_RVV
2698
+ nk_dots_symmetric_e3m2_rvv(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
2699
+ #else
2700
+ nk_dots_symmetric_e3m2_serial(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
2701
+ #endif
2702
+ }
2703
+
2704
+ NK_PUBLIC void nk_dots_symmetric_u4(nk_u4x2_t const *vectors, nk_size_t n_vectors, nk_size_t depth, nk_size_t stride,
2705
+ nk_u32_t *result, nk_size_t result_stride, nk_size_t row_start,
2706
+ nk_size_t row_count) {
2707
+ #if NK_TARGET_SME
2708
+ nk_dots_symmetric_u4_sme(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
2709
+ #elif NK_TARGET_ICELAKE
2710
+ nk_dots_symmetric_u4_icelake(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
2711
+ #elif NK_TARGET_NEONSDOT
2712
+ nk_dots_symmetric_u4_neonsdot(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
2713
+ #elif NK_TARGET_HASWELL
2714
+ nk_dots_symmetric_u4_haswell(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
2715
+ #elif NK_TARGET_V128RELAXED
2716
+ nk_dots_symmetric_u4_v128relaxed(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
2717
+ #else
2718
+ nk_dots_symmetric_u4_serial(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
2719
+ #endif
2720
+ }
2721
+
2722
+ NK_PUBLIC void nk_dots_symmetric_u1(nk_u1x8_t const *vectors, nk_size_t n_vectors, nk_size_t depth, nk_size_t stride,
2723
+ nk_u32_t *result, nk_size_t result_stride, nk_size_t row_start,
2724
+ nk_size_t row_count) {
2725
+ #if NK_TARGET_SMEBI32
2726
+ nk_dots_symmetric_u1_smebi32(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
2727
+ #elif NK_TARGET_ICELAKE
2728
+ nk_dots_symmetric_u1_icelake(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
2729
+ #elif NK_TARGET_HASWELL
2730
+ nk_dots_symmetric_u1_haswell(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
2731
+ #elif NK_TARGET_NEON
2732
+ nk_dots_symmetric_u1_neon(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
2733
+ #elif NK_TARGET_V128RELAXED
2734
+ nk_dots_symmetric_u1_v128relaxed(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
2735
+ #else
2736
+ nk_dots_symmetric_u1_serial(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
2737
+ #endif
2738
+ }
2739
+
2740
+ NK_PUBLIC void nk_dots_symmetric_i4(nk_i4x2_t const *vectors, nk_size_t n_vectors, nk_size_t depth, nk_size_t stride,
2741
+ nk_i32_t *result, nk_size_t result_stride, nk_size_t row_start,
2742
+ nk_size_t row_count) {
2743
+ #if NK_TARGET_SME
2744
+ nk_dots_symmetric_i4_sme(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
2745
+ #elif NK_TARGET_ICELAKE
2746
+ nk_dots_symmetric_i4_icelake(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
2747
+ #elif NK_TARGET_NEONSDOT
2748
+ nk_dots_symmetric_i4_neonsdot(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
2749
+ #elif NK_TARGET_HASWELL
2750
+ nk_dots_symmetric_i4_haswell(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
2751
+ #elif NK_TARGET_V128RELAXED
2752
+ nk_dots_symmetric_i4_v128relaxed(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
2753
+ #else
2754
+ nk_dots_symmetric_i4_serial(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
2755
+ #endif
2756
+ }
2757
+
2758
+ NK_PUBLIC void nk_dots_symmetric_f32(nk_f32_t const *vectors, nk_size_t n_vectors, nk_size_t depth, nk_size_t stride,
2759
+ nk_f64_t *result, nk_size_t result_stride, nk_size_t row_start,
2760
+ nk_size_t row_count) {
2761
+ #if NK_TARGET_SMEF64
2762
+ nk_dots_symmetric_f32_smef64(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
2763
+ #elif NK_TARGET_SKYLAKE
2764
+ nk_dots_symmetric_f32_skylake(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
2765
+ #elif NK_TARGET_HASWELL
2766
+ nk_dots_symmetric_f32_haswell(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
2767
+ #elif NK_TARGET_NEON
2768
+ nk_dots_symmetric_f32_neon(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
2769
+ #elif NK_TARGET_RVV
2770
+ nk_dots_symmetric_f32_rvv(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
2771
+ #elif NK_TARGET_V128RELAXED
2772
+ nk_dots_symmetric_f32_v128relaxed(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
2773
+ #else
2774
+ nk_dots_symmetric_f32_serial(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
2775
+ #endif
2776
+ }
2777
+
2778
+ NK_PUBLIC void nk_dots_symmetric_f64(nk_f64_t const *vectors, nk_size_t n_vectors, nk_size_t depth, nk_size_t stride,
2779
+ nk_f64_t *result, nk_size_t result_stride, nk_size_t row_start,
2780
+ nk_size_t row_count) {
2781
+ #if NK_TARGET_SMEF64
2782
+ nk_dots_symmetric_f64_smef64(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
2783
+ #elif NK_TARGET_SKYLAKE
2784
+ nk_dots_symmetric_f64_skylake(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
2785
+ #elif NK_TARGET_HASWELL
2786
+ nk_dots_symmetric_f64_haswell(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
2787
+ #elif NK_TARGET_NEON
2788
+ nk_dots_symmetric_f64_neon(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
2789
+ #elif NK_TARGET_RVV
2790
+ nk_dots_symmetric_f64_rvv(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
2791
+ #elif NK_TARGET_V128RELAXED
2792
+ nk_dots_symmetric_f64_v128relaxed(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
2793
+ #else
2794
+ nk_dots_symmetric_f64_serial(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
2795
+ #endif
2796
+ }
2797
+
2798
+ #endif // !NK_DYNAMIC_DISPATCH
2799
+
2800
+ #if defined(__cplusplus)
2801
+ } // extern "C"
2802
+ #endif
2803
+
2804
+ #endif