numkong 7.0.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (294) hide show
  1. package/LICENSE +201 -0
  2. package/README.md +495 -0
  3. package/binding.gyp +540 -0
  4. package/c/dispatch.h +512 -0
  5. package/c/dispatch_bf16.c +389 -0
  6. package/c/dispatch_bf16c.c +52 -0
  7. package/c/dispatch_e2m3.c +263 -0
  8. package/c/dispatch_e3m2.c +243 -0
  9. package/c/dispatch_e4m3.c +276 -0
  10. package/c/dispatch_e5m2.c +272 -0
  11. package/c/dispatch_f16.c +376 -0
  12. package/c/dispatch_f16c.c +58 -0
  13. package/c/dispatch_f32.c +378 -0
  14. package/c/dispatch_f32c.c +99 -0
  15. package/c/dispatch_f64.c +296 -0
  16. package/c/dispatch_f64c.c +98 -0
  17. package/c/dispatch_i16.c +96 -0
  18. package/c/dispatch_i32.c +89 -0
  19. package/c/dispatch_i4.c +150 -0
  20. package/c/dispatch_i64.c +86 -0
  21. package/c/dispatch_i8.c +289 -0
  22. package/c/dispatch_other.c +330 -0
  23. package/c/dispatch_u1.c +148 -0
  24. package/c/dispatch_u16.c +124 -0
  25. package/c/dispatch_u32.c +118 -0
  26. package/c/dispatch_u4.c +150 -0
  27. package/c/dispatch_u64.c +102 -0
  28. package/c/dispatch_u8.c +303 -0
  29. package/c/numkong.c +950 -0
  30. package/include/README.md +573 -0
  31. package/include/module.modulemap +129 -0
  32. package/include/numkong/attention/sapphireamx.h +1361 -0
  33. package/include/numkong/attention/sme.h +2066 -0
  34. package/include/numkong/attention.h +49 -0
  35. package/include/numkong/capabilities.h +748 -0
  36. package/include/numkong/cast/README.md +262 -0
  37. package/include/numkong/cast/haswell.h +975 -0
  38. package/include/numkong/cast/icelake.h +470 -0
  39. package/include/numkong/cast/neon.h +1192 -0
  40. package/include/numkong/cast/rvv.h +1021 -0
  41. package/include/numkong/cast/sapphire.h +262 -0
  42. package/include/numkong/cast/serial.h +2262 -0
  43. package/include/numkong/cast/skylake.h +856 -0
  44. package/include/numkong/cast/v128relaxed.h +180 -0
  45. package/include/numkong/cast.h +230 -0
  46. package/include/numkong/curved/README.md +223 -0
  47. package/include/numkong/curved/genoa.h +182 -0
  48. package/include/numkong/curved/haswell.h +276 -0
  49. package/include/numkong/curved/neon.h +205 -0
  50. package/include/numkong/curved/neonbfdot.h +212 -0
  51. package/include/numkong/curved/neonhalf.h +212 -0
  52. package/include/numkong/curved/rvv.h +305 -0
  53. package/include/numkong/curved/serial.h +207 -0
  54. package/include/numkong/curved/skylake.h +457 -0
  55. package/include/numkong/curved/smef64.h +506 -0
  56. package/include/numkong/curved.h +517 -0
  57. package/include/numkong/curved.hpp +144 -0
  58. package/include/numkong/dot/README.md +425 -0
  59. package/include/numkong/dot/alder.h +563 -0
  60. package/include/numkong/dot/genoa.h +315 -0
  61. package/include/numkong/dot/haswell.h +1688 -0
  62. package/include/numkong/dot/icelake.h +883 -0
  63. package/include/numkong/dot/neon.h +818 -0
  64. package/include/numkong/dot/neonbfdot.h +244 -0
  65. package/include/numkong/dot/neonfhm.h +360 -0
  66. package/include/numkong/dot/neonhalf.h +198 -0
  67. package/include/numkong/dot/neonsdot.h +508 -0
  68. package/include/numkong/dot/rvv.h +714 -0
  69. package/include/numkong/dot/rvvbb.h +72 -0
  70. package/include/numkong/dot/rvvbf16.h +123 -0
  71. package/include/numkong/dot/rvvhalf.h +129 -0
  72. package/include/numkong/dot/sapphire.h +141 -0
  73. package/include/numkong/dot/serial.h +838 -0
  74. package/include/numkong/dot/sierra.h +405 -0
  75. package/include/numkong/dot/skylake.h +1084 -0
  76. package/include/numkong/dot/sve.h +379 -0
  77. package/include/numkong/dot/svebfdot.h +74 -0
  78. package/include/numkong/dot/svehalf.h +123 -0
  79. package/include/numkong/dot/v128relaxed.h +1258 -0
  80. package/include/numkong/dot.h +1070 -0
  81. package/include/numkong/dot.hpp +94 -0
  82. package/include/numkong/dots/README.md +496 -0
  83. package/include/numkong/dots/alder.h +114 -0
  84. package/include/numkong/dots/genoa.h +94 -0
  85. package/include/numkong/dots/haswell.h +295 -0
  86. package/include/numkong/dots/icelake.h +171 -0
  87. package/include/numkong/dots/neon.h +120 -0
  88. package/include/numkong/dots/neonbfdot.h +58 -0
  89. package/include/numkong/dots/neonfhm.h +94 -0
  90. package/include/numkong/dots/neonhalf.h +57 -0
  91. package/include/numkong/dots/neonsdot.h +108 -0
  92. package/include/numkong/dots/rvv.h +2486 -0
  93. package/include/numkong/dots/sapphireamx.h +3973 -0
  94. package/include/numkong/dots/serial.h +2844 -0
  95. package/include/numkong/dots/sierra.h +97 -0
  96. package/include/numkong/dots/skylake.h +196 -0
  97. package/include/numkong/dots/sme.h +5372 -0
  98. package/include/numkong/dots/smebi32.h +461 -0
  99. package/include/numkong/dots/smef64.h +1318 -0
  100. package/include/numkong/dots/smehalf.h +47 -0
  101. package/include/numkong/dots/v128relaxed.h +294 -0
  102. package/include/numkong/dots.h +2804 -0
  103. package/include/numkong/dots.hpp +639 -0
  104. package/include/numkong/each/README.md +469 -0
  105. package/include/numkong/each/haswell.h +1658 -0
  106. package/include/numkong/each/icelake.h +272 -0
  107. package/include/numkong/each/neon.h +1104 -0
  108. package/include/numkong/each/neonbfdot.h +212 -0
  109. package/include/numkong/each/neonhalf.h +410 -0
  110. package/include/numkong/each/rvv.h +1121 -0
  111. package/include/numkong/each/sapphire.h +477 -0
  112. package/include/numkong/each/serial.h +260 -0
  113. package/include/numkong/each/skylake.h +1562 -0
  114. package/include/numkong/each.h +2146 -0
  115. package/include/numkong/each.hpp +434 -0
  116. package/include/numkong/geospatial/README.md +147 -0
  117. package/include/numkong/geospatial/haswell.h +593 -0
  118. package/include/numkong/geospatial/neon.h +571 -0
  119. package/include/numkong/geospatial/rvv.h +701 -0
  120. package/include/numkong/geospatial/serial.h +309 -0
  121. package/include/numkong/geospatial/skylake.h +577 -0
  122. package/include/numkong/geospatial/v128relaxed.h +613 -0
  123. package/include/numkong/geospatial.h +453 -0
  124. package/include/numkong/geospatial.hpp +235 -0
  125. package/include/numkong/matrix.hpp +336 -0
  126. package/include/numkong/maxsim/README.md +187 -0
  127. package/include/numkong/maxsim/alder.h +511 -0
  128. package/include/numkong/maxsim/genoa.h +115 -0
  129. package/include/numkong/maxsim/haswell.h +553 -0
  130. package/include/numkong/maxsim/icelake.h +480 -0
  131. package/include/numkong/maxsim/neonsdot.h +394 -0
  132. package/include/numkong/maxsim/sapphireamx.h +877 -0
  133. package/include/numkong/maxsim/serial.h +490 -0
  134. package/include/numkong/maxsim/sme.h +929 -0
  135. package/include/numkong/maxsim/v128relaxed.h +280 -0
  136. package/include/numkong/maxsim.h +571 -0
  137. package/include/numkong/maxsim.hpp +133 -0
  138. package/include/numkong/mesh/README.md +227 -0
  139. package/include/numkong/mesh/haswell.h +2235 -0
  140. package/include/numkong/mesh/neon.h +1329 -0
  141. package/include/numkong/mesh/neonbfdot.h +842 -0
  142. package/include/numkong/mesh/neonhalf.h +616 -0
  143. package/include/numkong/mesh/rvv.h +916 -0
  144. package/include/numkong/mesh/serial.h +742 -0
  145. package/include/numkong/mesh/skylake.h +1135 -0
  146. package/include/numkong/mesh/v128relaxed.h +1052 -0
  147. package/include/numkong/mesh.h +652 -0
  148. package/include/numkong/mesh.hpp +762 -0
  149. package/include/numkong/numkong.h +78 -0
  150. package/include/numkong/numkong.hpp +57 -0
  151. package/include/numkong/probability/README.md +173 -0
  152. package/include/numkong/probability/haswell.h +267 -0
  153. package/include/numkong/probability/neon.h +225 -0
  154. package/include/numkong/probability/rvv.h +409 -0
  155. package/include/numkong/probability/serial.h +169 -0
  156. package/include/numkong/probability/skylake.h +324 -0
  157. package/include/numkong/probability.h +383 -0
  158. package/include/numkong/probability.hpp +120 -0
  159. package/include/numkong/random.h +50 -0
  160. package/include/numkong/random.hpp +285 -0
  161. package/include/numkong/reduce/README.md +547 -0
  162. package/include/numkong/reduce/alder.h +632 -0
  163. package/include/numkong/reduce/genoa.h +201 -0
  164. package/include/numkong/reduce/haswell.h +3783 -0
  165. package/include/numkong/reduce/icelake.h +549 -0
  166. package/include/numkong/reduce/neon.h +3841 -0
  167. package/include/numkong/reduce/neonbfdot.h +353 -0
  168. package/include/numkong/reduce/neonfhm.h +665 -0
  169. package/include/numkong/reduce/neonhalf.h +157 -0
  170. package/include/numkong/reduce/neonsdot.h +357 -0
  171. package/include/numkong/reduce/rvv.h +3407 -0
  172. package/include/numkong/reduce/serial.h +757 -0
  173. package/include/numkong/reduce/sierra.h +338 -0
  174. package/include/numkong/reduce/skylake.h +3792 -0
  175. package/include/numkong/reduce/v128relaxed.h +2302 -0
  176. package/include/numkong/reduce.h +1597 -0
  177. package/include/numkong/reduce.hpp +633 -0
  178. package/include/numkong/scalar/README.md +89 -0
  179. package/include/numkong/scalar/haswell.h +113 -0
  180. package/include/numkong/scalar/neon.h +122 -0
  181. package/include/numkong/scalar/neonhalf.h +70 -0
  182. package/include/numkong/scalar/rvv.h +211 -0
  183. package/include/numkong/scalar/sapphire.h +63 -0
  184. package/include/numkong/scalar/serial.h +332 -0
  185. package/include/numkong/scalar/v128relaxed.h +56 -0
  186. package/include/numkong/scalar.h +683 -0
  187. package/include/numkong/set/README.md +179 -0
  188. package/include/numkong/set/haswell.h +334 -0
  189. package/include/numkong/set/icelake.h +485 -0
  190. package/include/numkong/set/neon.h +364 -0
  191. package/include/numkong/set/rvv.h +226 -0
  192. package/include/numkong/set/rvvbb.h +117 -0
  193. package/include/numkong/set/serial.h +174 -0
  194. package/include/numkong/set/sve.h +185 -0
  195. package/include/numkong/set/v128relaxed.h +240 -0
  196. package/include/numkong/set.h +457 -0
  197. package/include/numkong/set.hpp +114 -0
  198. package/include/numkong/sets/README.md +149 -0
  199. package/include/numkong/sets/haswell.h +63 -0
  200. package/include/numkong/sets/icelake.h +66 -0
  201. package/include/numkong/sets/neon.h +61 -0
  202. package/include/numkong/sets/serial.h +43 -0
  203. package/include/numkong/sets/smebi32.h +1099 -0
  204. package/include/numkong/sets/v128relaxed.h +58 -0
  205. package/include/numkong/sets.h +339 -0
  206. package/include/numkong/sparse/README.md +156 -0
  207. package/include/numkong/sparse/icelake.h +463 -0
  208. package/include/numkong/sparse/neon.h +288 -0
  209. package/include/numkong/sparse/serial.h +117 -0
  210. package/include/numkong/sparse/sve2.h +507 -0
  211. package/include/numkong/sparse/turin.h +322 -0
  212. package/include/numkong/sparse.h +363 -0
  213. package/include/numkong/sparse.hpp +113 -0
  214. package/include/numkong/spatial/README.md +435 -0
  215. package/include/numkong/spatial/alder.h +607 -0
  216. package/include/numkong/spatial/genoa.h +290 -0
  217. package/include/numkong/spatial/haswell.h +960 -0
  218. package/include/numkong/spatial/icelake.h +586 -0
  219. package/include/numkong/spatial/neon.h +773 -0
  220. package/include/numkong/spatial/neonbfdot.h +165 -0
  221. package/include/numkong/spatial/neonhalf.h +118 -0
  222. package/include/numkong/spatial/neonsdot.h +261 -0
  223. package/include/numkong/spatial/rvv.h +984 -0
  224. package/include/numkong/spatial/rvvbf16.h +123 -0
  225. package/include/numkong/spatial/rvvhalf.h +117 -0
  226. package/include/numkong/spatial/sapphire.h +343 -0
  227. package/include/numkong/spatial/serial.h +346 -0
  228. package/include/numkong/spatial/sierra.h +323 -0
  229. package/include/numkong/spatial/skylake.h +606 -0
  230. package/include/numkong/spatial/sve.h +224 -0
  231. package/include/numkong/spatial/svebfdot.h +122 -0
  232. package/include/numkong/spatial/svehalf.h +109 -0
  233. package/include/numkong/spatial/v128relaxed.h +717 -0
  234. package/include/numkong/spatial.h +1425 -0
  235. package/include/numkong/spatial.hpp +183 -0
  236. package/include/numkong/spatials/README.md +580 -0
  237. package/include/numkong/spatials/alder.h +94 -0
  238. package/include/numkong/spatials/genoa.h +94 -0
  239. package/include/numkong/spatials/haswell.h +219 -0
  240. package/include/numkong/spatials/icelake.h +113 -0
  241. package/include/numkong/spatials/neon.h +109 -0
  242. package/include/numkong/spatials/neonbfdot.h +60 -0
  243. package/include/numkong/spatials/neonfhm.h +92 -0
  244. package/include/numkong/spatials/neonhalf.h +58 -0
  245. package/include/numkong/spatials/neonsdot.h +109 -0
  246. package/include/numkong/spatials/rvv.h +1960 -0
  247. package/include/numkong/spatials/sapphireamx.h +1149 -0
  248. package/include/numkong/spatials/serial.h +226 -0
  249. package/include/numkong/spatials/sierra.h +96 -0
  250. package/include/numkong/spatials/skylake.h +184 -0
  251. package/include/numkong/spatials/sme.h +1901 -0
  252. package/include/numkong/spatials/smef64.h +465 -0
  253. package/include/numkong/spatials/v128relaxed.h +240 -0
  254. package/include/numkong/spatials.h +3021 -0
  255. package/include/numkong/spatials.hpp +508 -0
  256. package/include/numkong/tensor.hpp +1592 -0
  257. package/include/numkong/trigonometry/README.md +184 -0
  258. package/include/numkong/trigonometry/haswell.h +652 -0
  259. package/include/numkong/trigonometry/neon.h +639 -0
  260. package/include/numkong/trigonometry/rvv.h +699 -0
  261. package/include/numkong/trigonometry/serial.h +703 -0
  262. package/include/numkong/trigonometry/skylake.h +721 -0
  263. package/include/numkong/trigonometry/v128relaxed.h +666 -0
  264. package/include/numkong/trigonometry.h +467 -0
  265. package/include/numkong/trigonometry.hpp +166 -0
  266. package/include/numkong/types.h +1384 -0
  267. package/include/numkong/types.hpp +5603 -0
  268. package/include/numkong/vector.hpp +698 -0
  269. package/javascript/README.md +246 -0
  270. package/javascript/dist/cjs/numkong-wasm.d.ts +166 -0
  271. package/javascript/dist/cjs/numkong-wasm.js +617 -0
  272. package/javascript/dist/cjs/numkong.d.ts +343 -0
  273. package/javascript/dist/cjs/numkong.js +523 -0
  274. package/javascript/dist/cjs/package.json +3 -0
  275. package/javascript/dist/cjs/types.d.ts +284 -0
  276. package/javascript/dist/cjs/types.js +653 -0
  277. package/javascript/dist/esm/numkong-wasm.d.ts +166 -0
  278. package/javascript/dist/esm/numkong-wasm.js +595 -0
  279. package/javascript/dist/esm/numkong.d.ts +343 -0
  280. package/javascript/dist/esm/numkong.js +452 -0
  281. package/javascript/dist/esm/package.json +3 -0
  282. package/javascript/dist/esm/types.d.ts +284 -0
  283. package/javascript/dist/esm/types.js +630 -0
  284. package/javascript/dist-package-cjs.json +3 -0
  285. package/javascript/dist-package-esm.json +3 -0
  286. package/javascript/node-gyp-build.d.ts +1 -0
  287. package/javascript/numkong-wasm.ts +756 -0
  288. package/javascript/numkong.c +689 -0
  289. package/javascript/numkong.ts +575 -0
  290. package/javascript/tsconfig-base.json +39 -0
  291. package/javascript/tsconfig-cjs.json +8 -0
  292. package/javascript/tsconfig-esm.json +8 -0
  293. package/javascript/types.ts +674 -0
  294. package/package.json +87 -0
@@ -0,0 +1,2146 @@
1
+ /**
2
+ * @brief SIMD-accelerated Elementwise Arithmetic.
3
+ * @file include/numkong/each.h
4
+ * @author Ash Vardanian
5
+ * @date October 16, 2024
6
+ *
7
+ * Contains following element-wise operations:
8
+ *
9
+ * - Scale (Multiply) with shift: result[i] = alpha * a[i] + beta
10
+ * - Sum (Add): result[i] = a[i] + b[i]
11
+ * - Blend: result[i] = alpha * a[i] + beta * b[i]
12
+ * - FMA (Fused Multiply-Add): result[i] = alpha * a[i] * b[i] + beta * c[i]
13
+ *
14
+ * Beyond their obvious usecases, those can be reused for vector-scalar math and other operations:
15
+ *
16
+ * - Scale with beta = 0 for a pure multiply.
17
+ * - Sum is equivalent to Blend with alpha = beta = 1.
18
+ * - Average is Blend with alpha = beta = 0.5.
19
+ * - Elementwise multiply is FMA with beta = 0.
20
+ *
21
+ * For dtypes:
22
+ *
23
+ * - f64: 64-bit IEEE floating point numbers × 64-bit scales
24
+ * - f32: 32-bit IEEE floating point numbers × 32-bit scales
25
+ * - f16: 16-bit IEEE floating point numbers × 32-bit scales
26
+ * - bf16: 16-bit brain floating point numbers × 32-bit scales
27
+ * - e4m3: 8-bit e4m3 floating point numbers × 32-bit scales
28
+ * - e5m2: 8-bit e5m2 floating point numbers × 32-bit scales
29
+ * - e2m3: 8-bit e2m3 floating point numbers (MX) × 32-bit scales
30
+ * - e3m2: 8-bit e3m2 floating point numbers (MX) × 32-bit scales
31
+ * - i8/u8: 8-bit signed and unsigned integers × 32-bit scales
32
+ * - i16/u16: 16-bit signed and unsigned integers × 32-bit scales
33
+ * - i32/u32: 32-bit signed and unsigned integers × 64-bit scales
34
+ * - i64/u64: 64-bit signed and unsigned integers × 64-bit scales
35
+ *
36
+ * For hardware architectures:
37
+ *
38
+ * - Arm: NEON, NEON+F16, NEON+BF16
39
+ * - x86: Haswell, Skylake, Ice Lake, Sapphire Rapids
40
+ * - RISC-V: RVV
41
+ *
42
+ *
43
+ * @section numerical_stability Numerical Stability
44
+ *
45
+ * Integer sum is elementwise a[i]+b[i] clamped to the type's range. Serial widens to
46
+ * i64 then clamps on store. NEON uses hardware saturating adds (SQADD/UQADD).
47
+ * f16/bf16/FP8 sum: promoted to f32, added, truncated back — double rounding possible.
48
+ * Scale/blend/fma: float alpha/beta arithmetic, result rounded to nearest, ties to even, then clamped.
49
+ * f32/f64 operations are native precision with no widening.
50
+ *
51
+ * @section x86_instructions Relevant x86 Instructions
52
+ *
53
+ * FP16 conversions (VCVTPH2PS/VCVTPS2PH) are used for f16 scale/sum/blend/fma operations, converting
54
+ * to f32 for arithmetic then back. The 6-7 cycle latency is amortized over vector-width elements.
55
+ * Saturating integer adds (VPADDSW/VPADDUSW) provide overflow protection for i16/u16 sums without
56
+ * branching. FMA (VFMADD231PS) is the workhorse for scale (alpha*x+beta) and blend (alpha*a+beta*b).
57
+ *
58
+ * Intrinsic Instruction Ice Genoa
59
+ * _mm512_cvtph_ps VCVTPH2PS (ZMM, YMM) 7c @ p0+p5 6c @ p12+p23
60
+ * _mm512_cvtps_ph VCVTPS2PH (YMM, ZMM, I8) 7c @ p0+p5 7c @ p12+p23
61
+ * _mm256_adds_epi16 VPADDSW (YMM, YMM, YMM) 1c @ p01 N/A
62
+ * _mm256_adds_epu16 VPADDUSW (YMM, YMM, YMM) 1c @ p01 N/A
63
+ * _mm512_fpclass_ps_mask VFPCLASSPS (K, ZMM, I8) 3c @ p5 5c @ p01
64
+ * _mm256_fmadd_ps VFMADD231PS (YMM, YMM, YMM) 4c @ p01 4c @ p01
65
+ *
66
+ * @section arm_instructions Relevant ARM NEON/SVE Instructions
67
+ *
68
+ * On ARM, i8/u8 elementwise operations convert to f16 intermediates using FCVT to maintain high
69
+ * vector throughput (8 elements per 128-bit register vs 4 for f32). Saturating adds (SQADD/UQADD)
70
+ * handle integer overflow. FMLA provides fused multiply-add for floating-point scale/blend/fma.
71
+ *
72
+ * Intrinsic Instruction M1 Firestorm Graviton 3 Graviton 4
73
+ * vfmaq_f32 FMLA.S (vec) 4c @ V0123 4c @ V0123 4c @ V0123
74
+ * vqaddq_s16 SQADD (vec) 3c @ V0123 2c @ V0123 2c @ V0123
75
+ * vqaddq_u16 UQADD (vec) 3c @ V0123 2c @ V0123 2c @ V0123
76
+ * vcvtq_f32_s32 SCVTF (vec) 3c @ V0123 3c @ V01 3c @ V01
77
+ * vcvtnq_s32_f32 FCVTNS (vec) 3c @ V0123 3c @ V01 3c @ V01
78
+ *
79
+ * @section references References
80
+ *
81
+ * - x86 intrinsics: https://www.intel.com/content/www/us/en/docs/intrinsics-guide/
82
+ * - Arm intrinsics: https://developer.arm.com/architectures/instruction-sets/intrinsics/
83
+ *
84
+ */
85
+ #ifndef NK_EACH_H
86
+ #define NK_EACH_H
87
+
88
+ #include "numkong/types.h"
89
+
90
+ #if defined(__cplusplus)
91
+ extern "C" {
92
+ #endif
93
+
94
+ /**
95
+ * @brief Element-wise scale with shift: result[i] = alpha * a[i] + beta.
96
+ *
97
+ * @param[in] a The input vector.
98
+ * @param[in] n The number of elements in the vector.
99
+ * @param[in] alpha Pointer to the scaling factor (type depends on input precision).
100
+ * @param[in] beta Pointer to the shift (bias) value (type depends on input precision).
101
+ * @param[out] result The output vector.
102
+ */
103
+ NK_DYNAMIC void nk_each_scale_f64(nk_f64_t const *a, nk_size_t n, nk_f64_t const *alpha, nk_f64_t const *beta,
104
+ nk_f64_t *result);
105
+ /** @copydoc nk_each_scale_f64 */
106
+ NK_DYNAMIC void nk_each_scale_f32(nk_f32_t const *a, nk_size_t n, nk_f32_t const *alpha, nk_f32_t const *beta,
107
+ nk_f32_t *result);
108
+ /** @copydoc nk_each_scale_f64 */
109
+ NK_DYNAMIC void nk_each_scale_f16(nk_f16_t const *a, nk_size_t n, nk_f32_t const *alpha, nk_f32_t const *beta,
110
+ nk_f16_t *result);
111
+ /** @copydoc nk_each_scale_f64 */
112
+ NK_DYNAMIC void nk_each_scale_bf16(nk_bf16_t const *a, nk_size_t n, nk_f32_t const *alpha, nk_f32_t const *beta,
113
+ nk_bf16_t *result);
114
+ /** @copydoc nk_each_scale_f64 */
115
+ NK_DYNAMIC void nk_each_scale_i8(nk_i8_t const *a, nk_size_t n, nk_f32_t const *alpha, nk_f32_t const *beta,
116
+ nk_i8_t *result);
117
+ /** @copydoc nk_each_scale_f64 */
118
+ NK_DYNAMIC void nk_each_scale_u8(nk_u8_t const *a, nk_size_t n, nk_f32_t const *alpha, nk_f32_t const *beta,
119
+ nk_u8_t *result);
120
+ /** @copydoc nk_each_scale_f64 */
121
+ NK_DYNAMIC void nk_each_scale_i16(nk_i16_t const *a, nk_size_t n, nk_f32_t const *alpha, nk_f32_t const *beta,
122
+ nk_i16_t *result);
123
+ /** @copydoc nk_each_scale_f64 */
124
+ NK_DYNAMIC void nk_each_scale_u16(nk_u16_t const *a, nk_size_t n, nk_f32_t const *alpha, nk_f32_t const *beta,
125
+ nk_u16_t *result);
126
+ /** @copydoc nk_each_scale_f64 */
127
+ NK_DYNAMIC void nk_each_scale_i32(nk_i32_t const *a, nk_size_t n, nk_f64_t const *alpha, nk_f64_t const *beta,
128
+ nk_i32_t *result);
129
+ /** @copydoc nk_each_scale_f64 */
130
+ NK_DYNAMIC void nk_each_scale_u32(nk_u32_t const *a, nk_size_t n, nk_f64_t const *alpha, nk_f64_t const *beta,
131
+ nk_u32_t *result);
132
+ /** @copydoc nk_each_scale_f64 */
133
+ NK_DYNAMIC void nk_each_scale_i64(nk_i64_t const *a, nk_size_t n, nk_f64_t const *alpha, nk_f64_t const *beta,
134
+ nk_i64_t *result);
135
+ /** @copydoc nk_each_scale_f64 */
136
+ NK_DYNAMIC void nk_each_scale_u64(nk_u64_t const *a, nk_size_t n, nk_f64_t const *alpha, nk_f64_t const *beta,
137
+ nk_u64_t *result);
138
+
139
+ /**
140
+ * @brief Element-wise sum: result[i] = a[i] + b[i].
141
+ *
142
+ * @param[in] a The first input vector.
143
+ * @param[in] b The second input vector.
144
+ * @param[in] n The number of elements in the vectors.
145
+ * @param[out] result The output vector.
146
+ */
147
+ NK_DYNAMIC void nk_each_sum_f64(nk_f64_t const *a, nk_f64_t const *b, nk_size_t n, nk_f64_t *result);
148
+ /** @copydoc nk_each_sum_f64 */
149
+ NK_DYNAMIC void nk_each_sum_f32(nk_f32_t const *a, nk_f32_t const *b, nk_size_t n, nk_f32_t *result);
150
+ /** @copydoc nk_each_sum_f64 */
151
+ NK_DYNAMIC void nk_each_sum_f16(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n, nk_f16_t *result);
152
+ /** @copydoc nk_each_sum_f64 */
153
+ NK_DYNAMIC void nk_each_sum_bf16(nk_bf16_t const *a, nk_bf16_t const *b, nk_size_t n, nk_bf16_t *result);
154
+ /** @copydoc nk_each_sum_f64 */
155
+ NK_DYNAMIC void nk_each_sum_i8(nk_i8_t const *a, nk_i8_t const *b, nk_size_t n, nk_i8_t *result);
156
+ /** @copydoc nk_each_sum_f64 */
157
+ NK_DYNAMIC void nk_each_sum_u8(nk_u8_t const *a, nk_u8_t const *b, nk_size_t n, nk_u8_t *result);
158
+ /** @copydoc nk_each_sum_f64 */
159
+ NK_DYNAMIC void nk_each_sum_i16(nk_i16_t const *a, nk_i16_t const *b, nk_size_t n, nk_i16_t *result);
160
+ /** @copydoc nk_each_sum_f64 */
161
+ NK_DYNAMIC void nk_each_sum_u16(nk_u16_t const *a, nk_u16_t const *b, nk_size_t n, nk_u16_t *result);
162
+ /** @copydoc nk_each_sum_f64 */
163
+ NK_DYNAMIC void nk_each_sum_i32(nk_i32_t const *a, nk_i32_t const *b, nk_size_t n, nk_i32_t *result);
164
+ /** @copydoc nk_each_sum_f64 */
165
+ NK_DYNAMIC void nk_each_sum_u32(nk_u32_t const *a, nk_u32_t const *b, nk_size_t n, nk_u32_t *result);
166
+ /** @copydoc nk_each_sum_f64 */
167
+ NK_DYNAMIC void nk_each_sum_i64(nk_i64_t const *a, nk_i64_t const *b, nk_size_t n, nk_i64_t *result);
168
+ /** @copydoc nk_each_sum_f64 */
169
+ NK_DYNAMIC void nk_each_sum_u64(nk_u64_t const *a, nk_u64_t const *b, nk_size_t n, nk_u64_t *result);
170
+
171
+ /**
172
+ * @brief Weighted sum: result[i] = alpha * a[i] + beta * b[i].
173
+ *
174
+ * @param[in] a The first input vector.
175
+ * @param[in] b The second input vector.
176
+ * @param[in] n The number of elements in the vectors.
177
+ * @param[in] alpha Pointer to the first weight (type depends on input precision).
178
+ * @param[in] beta Pointer to the second weight (type depends on input precision).
179
+ * @param[out] result The output vector.
180
+ */
181
+ NK_DYNAMIC void nk_each_blend_f64(nk_f64_t const *a, nk_f64_t const *b, nk_size_t n, nk_f64_t const *alpha,
182
+ nk_f64_t const *beta, nk_f64_t *result);
183
+ /** @copydoc nk_each_blend_f64 */
184
+ NK_DYNAMIC void nk_each_blend_f32(nk_f32_t const *a, nk_f32_t const *b, nk_size_t n, nk_f32_t const *alpha,
185
+ nk_f32_t const *beta, nk_f32_t *result);
186
+ /** @copydoc nk_each_blend_f64 */
187
+ NK_DYNAMIC void nk_each_blend_f16(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n, nk_f32_t const *alpha,
188
+ nk_f32_t const *beta, nk_f16_t *result);
189
+ /** @copydoc nk_each_blend_f64 */
190
+ NK_DYNAMIC void nk_each_blend_bf16(nk_bf16_t const *a, nk_bf16_t const *b, nk_size_t n, nk_f32_t const *alpha,
191
+ nk_f32_t const *beta, nk_bf16_t *result);
192
+ /** @copydoc nk_each_blend_f64 */
193
+ NK_DYNAMIC void nk_each_blend_i8(nk_i8_t const *a, nk_i8_t const *b, nk_size_t n, nk_f32_t const *alpha,
194
+ nk_f32_t const *beta, nk_i8_t *result);
195
+ /** @copydoc nk_each_blend_f64 */
196
+ NK_DYNAMIC void nk_each_blend_u8(nk_u8_t const *a, nk_u8_t const *b, nk_size_t n, nk_f32_t const *alpha,
197
+ nk_f32_t const *beta, nk_u8_t *result);
198
+ /** @copydoc nk_each_blend_f64 */
199
+ NK_DYNAMIC void nk_each_blend_i16(nk_i16_t const *a, nk_i16_t const *b, nk_size_t n, nk_f32_t const *alpha,
200
+ nk_f32_t const *beta, nk_i16_t *result);
201
+ /** @copydoc nk_each_blend_f64 */
202
+ NK_DYNAMIC void nk_each_blend_u16(nk_u16_t const *a, nk_u16_t const *b, nk_size_t n, nk_f32_t const *alpha,
203
+ nk_f32_t const *beta, nk_u16_t *result);
204
+ /** @copydoc nk_each_blend_f64 */
205
+ NK_DYNAMIC void nk_each_blend_i32(nk_i32_t const *a, nk_i32_t const *b, nk_size_t n, nk_f64_t const *alpha,
206
+ nk_f64_t const *beta, nk_i32_t *result);
207
+ /** @copydoc nk_each_blend_f64 */
208
+ NK_DYNAMIC void nk_each_blend_u32(nk_u32_t const *a, nk_u32_t const *b, nk_size_t n, nk_f64_t const *alpha,
209
+ nk_f64_t const *beta, nk_u32_t *result);
210
+ /** @copydoc nk_each_blend_f64 */
211
+ NK_DYNAMIC void nk_each_blend_i64(nk_i64_t const *a, nk_i64_t const *b, nk_size_t n, nk_f64_t const *alpha,
212
+ nk_f64_t const *beta, nk_i64_t *result);
213
+ /** @copydoc nk_each_blend_f64 */
214
+ NK_DYNAMIC void nk_each_blend_u64(nk_u64_t const *a, nk_u64_t const *b, nk_size_t n, nk_f64_t const *alpha,
215
+ nk_f64_t const *beta, nk_u64_t *result);
216
+
217
+ /**
218
+ * @brief Fused multiply-add: result[i] = alpha * a[i] * b[i] + beta * c[i].
219
+ *
220
+ * @param[in] a The first input vector.
221
+ * @param[in] b The second input vector.
222
+ * @param[in] c The third input vector.
223
+ * @param[in] n The number of elements in the vectors.
224
+ * @param[in] alpha Pointer to the scaling factor for a[i] * b[i] (type depends on input precision).
225
+ * @param[in] beta Pointer to the scaling factor for c[i] (type depends on input precision).
226
+ * @param[out] result The output vector.
227
+ */
228
+ NK_DYNAMIC void nk_each_fma_f64(nk_f64_t const *a, nk_f64_t const *b, nk_f64_t const *c, nk_size_t n,
229
+ nk_f64_t const *alpha, nk_f64_t const *beta, nk_f64_t *result);
230
+ /** @copydoc nk_each_fma_f64 */
231
+ NK_DYNAMIC void nk_each_fma_f32(nk_f32_t const *a, nk_f32_t const *b, nk_f32_t const *c, nk_size_t n,
232
+ nk_f32_t const *alpha, nk_f32_t const *beta, nk_f32_t *result);
233
+ /** @copydoc nk_each_fma_f64 */
234
+ NK_DYNAMIC void nk_each_fma_f16(nk_f16_t const *a, nk_f16_t const *b, nk_f16_t const *c, nk_size_t n,
235
+ nk_f32_t const *alpha, nk_f32_t const *beta, nk_f16_t *result);
236
+ /** @copydoc nk_each_fma_f64 */
237
+ NK_DYNAMIC void nk_each_fma_bf16(nk_bf16_t const *a, nk_bf16_t const *b, nk_bf16_t const *c, nk_size_t n,
238
+ nk_f32_t const *alpha, nk_f32_t const *beta, nk_bf16_t *result);
239
+ /** @copydoc nk_each_fma_f64 */
240
+ NK_DYNAMIC void nk_each_fma_i8(nk_i8_t const *a, nk_i8_t const *b, nk_i8_t const *c, nk_size_t n, nk_f32_t const *alpha,
241
+ nk_f32_t const *beta, nk_i8_t *result);
242
+ /** @copydoc nk_each_fma_f64 */
243
+ NK_DYNAMIC void nk_each_fma_u8(nk_u8_t const *a, nk_u8_t const *b, nk_u8_t const *c, nk_size_t n, nk_f32_t const *alpha,
244
+ nk_f32_t const *beta, nk_u8_t *result);
245
+ /** @copydoc nk_each_fma_f64 */
246
+ NK_DYNAMIC void nk_each_fma_i16(nk_i16_t const *a, nk_i16_t const *b, nk_i16_t const *c, nk_size_t n,
247
+ nk_f32_t const *alpha, nk_f32_t const *beta, nk_i16_t *result);
248
+
249
+ /** @copydoc nk_each_sum_f64 */
250
+ NK_DYNAMIC void nk_each_sum_e4m3(nk_e4m3_t const *a, nk_e4m3_t const *b, nk_size_t n, nk_e4m3_t *result);
251
+ /** @copydoc nk_each_sum_f64 */
252
+ NK_DYNAMIC void nk_each_sum_e5m2(nk_e5m2_t const *a, nk_e5m2_t const *b, nk_size_t n, nk_e5m2_t *result);
253
+ /** @copydoc nk_each_scale_f64 */
254
+ NK_DYNAMIC void nk_each_scale_e4m3(nk_e4m3_t const *a, nk_size_t n, nk_f32_t const *alpha, nk_f32_t const *beta,
255
+ nk_e4m3_t *result);
256
+ /** @copydoc nk_each_scale_f64 */
257
+ NK_DYNAMIC void nk_each_scale_e5m2(nk_e5m2_t const *a, nk_size_t n, nk_f32_t const *alpha, nk_f32_t const *beta,
258
+ nk_e5m2_t *result);
259
+ /** @copydoc nk_each_blend_f64 */
260
+ NK_DYNAMIC void nk_each_blend_e4m3(nk_e4m3_t const *a, nk_e4m3_t const *b, nk_size_t n, nk_f32_t const *alpha,
261
+ nk_f32_t const *beta, nk_e4m3_t *result);
262
+ /** @copydoc nk_each_blend_f64 */
263
+ NK_DYNAMIC void nk_each_blend_e5m2(nk_e5m2_t const *a, nk_e5m2_t const *b, nk_size_t n, nk_f32_t const *alpha,
264
+ nk_f32_t const *beta, nk_e5m2_t *result);
265
+ /** @copydoc nk_each_fma_f64 */
266
+ NK_DYNAMIC void nk_each_fma_e4m3(nk_e4m3_t const *a, nk_e4m3_t const *b, nk_e4m3_t const *c, nk_size_t n,
267
+ nk_f32_t const *alpha, nk_f32_t const *beta, nk_e4m3_t *result);
268
+ /** @copydoc nk_each_fma_f64 */
269
+ NK_DYNAMIC void nk_each_fma_e5m2(nk_e5m2_t const *a, nk_e5m2_t const *b, nk_e5m2_t const *c, nk_size_t n,
270
+ nk_f32_t const *alpha, nk_f32_t const *beta, nk_e5m2_t *result);
271
+ /** @copydoc nk_each_sum_f64 */
272
+ NK_DYNAMIC void nk_each_sum_e2m3(nk_e2m3_t const *a, nk_e2m3_t const *b, nk_size_t n, nk_e2m3_t *result);
273
+ /** @copydoc nk_each_sum_f64 */
274
+ NK_DYNAMIC void nk_each_sum_e3m2(nk_e3m2_t const *a, nk_e3m2_t const *b, nk_size_t n, nk_e3m2_t *result);
275
+ /** @copydoc nk_each_scale_f64 */
276
+ NK_DYNAMIC void nk_each_scale_e2m3(nk_e2m3_t const *a, nk_size_t n, nk_f32_t const *alpha, nk_f32_t const *beta,
277
+ nk_e2m3_t *result);
278
+ /** @copydoc nk_each_scale_f64 */
279
+ NK_DYNAMIC void nk_each_scale_e3m2(nk_e3m2_t const *a, nk_size_t n, nk_f32_t const *alpha, nk_f32_t const *beta,
280
+ nk_e3m2_t *result);
281
+ /** @copydoc nk_each_blend_f64 */
282
+ NK_DYNAMIC void nk_each_blend_e2m3(nk_e2m3_t const *a, nk_e2m3_t const *b, nk_size_t n, nk_f32_t const *alpha,
283
+ nk_f32_t const *beta, nk_e2m3_t *result);
284
+ /** @copydoc nk_each_blend_f64 */
285
+ NK_DYNAMIC void nk_each_blend_e3m2(nk_e3m2_t const *a, nk_e3m2_t const *b, nk_size_t n, nk_f32_t const *alpha,
286
+ nk_f32_t const *beta, nk_e3m2_t *result);
287
+ /** @copydoc nk_each_fma_f64 */
288
+ NK_DYNAMIC void nk_each_fma_e2m3(nk_e2m3_t const *a, nk_e2m3_t const *b, nk_e2m3_t const *c, nk_size_t n,
289
+ nk_f32_t const *alpha, nk_f32_t const *beta, nk_e2m3_t *result);
290
+ /** @copydoc nk_each_fma_f64 */
291
+ NK_DYNAMIC void nk_each_fma_e3m2(nk_e3m2_t const *a, nk_e3m2_t const *b, nk_e3m2_t const *c, nk_size_t n,
292
+ nk_f32_t const *alpha, nk_f32_t const *beta, nk_e3m2_t *result);
293
+ /** @copydoc nk_each_fma_f64 */
294
+ NK_DYNAMIC void nk_each_fma_u16(nk_u16_t const *a, nk_u16_t const *b, nk_u16_t const *c, nk_size_t n,
295
+ nk_f32_t const *alpha, nk_f32_t const *beta, nk_u16_t *result);
296
+ /** @copydoc nk_each_fma_f64 */
297
+ NK_DYNAMIC void nk_each_fma_i32(nk_i32_t const *a, nk_i32_t const *b, nk_i32_t const *c, nk_size_t n,
298
+ nk_f64_t const *alpha, nk_f64_t const *beta, nk_i32_t *result);
299
+ /** @copydoc nk_each_fma_f64 */
300
+ NK_DYNAMIC void nk_each_fma_u32(nk_u32_t const *a, nk_u32_t const *b, nk_u32_t const *c, nk_size_t n,
301
+ nk_f64_t const *alpha, nk_f64_t const *beta, nk_u32_t *result);
302
+ /** @copydoc nk_each_fma_f64 */
303
+ NK_DYNAMIC void nk_each_fma_i64(nk_i64_t const *a, nk_i64_t const *b, nk_i64_t const *c, nk_size_t n,
304
+ nk_f64_t const *alpha, nk_f64_t const *beta, nk_i64_t *result);
305
+ /** @copydoc nk_each_fma_f64 */
306
+ NK_DYNAMIC void nk_each_fma_u64(nk_u64_t const *a, nk_u64_t const *b, nk_u64_t const *c, nk_size_t n,
307
+ nk_f64_t const *alpha, nk_f64_t const *beta, nk_u64_t *result);
308
+
309
+ /** @copydoc nk_each_sum_f64 */
310
+ NK_DYNAMIC void nk_each_sum_f32c(nk_f32c_t const *a, nk_f32c_t const *b, nk_size_t n, nk_f32c_t *result);
311
+ /** @copydoc nk_each_sum_f64 */
312
+ NK_DYNAMIC void nk_each_sum_f64c(nk_f64c_t const *a, nk_f64c_t const *b, nk_size_t n, nk_f64c_t *result);
313
+ /** @copydoc nk_each_scale_f64 */
314
+ NK_DYNAMIC void nk_each_scale_f32c(nk_f32c_t const *a, nk_size_t n, nk_f32c_t const *alpha, nk_f32c_t const *beta,
315
+ nk_f32c_t *result);
316
+ /** @copydoc nk_each_scale_f64 */
317
+ NK_DYNAMIC void nk_each_scale_f64c(nk_f64c_t const *a, nk_size_t n, nk_f64c_t const *alpha, nk_f64c_t const *beta,
318
+ nk_f64c_t *result);
319
+ /** @copydoc nk_each_blend_f64 */
320
+ NK_DYNAMIC void nk_each_blend_f32c(nk_f32c_t const *a, nk_f32c_t const *b, nk_size_t n, nk_f32c_t const *alpha,
321
+ nk_f32c_t const *beta, nk_f32c_t *result);
322
+ /** @copydoc nk_each_blend_f64 */
323
+ NK_DYNAMIC void nk_each_blend_f64c(nk_f64c_t const *a, nk_f64c_t const *b, nk_size_t n, nk_f64c_t const *alpha,
324
+ nk_f64c_t const *beta, nk_f64c_t *result);
325
+ /** @copydoc nk_each_fma_f64 */
326
+ NK_DYNAMIC void nk_each_fma_f32c(nk_f32c_t const *a, nk_f32c_t const *b, nk_f32c_t const *c, nk_size_t n,
327
+ nk_f32c_t const *alpha, nk_f32c_t const *beta, nk_f32c_t *result);
328
+ /** @copydoc nk_each_fma_f64 */
329
+ NK_DYNAMIC void nk_each_fma_f64c(nk_f64c_t const *a, nk_f64c_t const *b, nk_f64c_t const *c, nk_size_t n,
330
+ nk_f64c_t const *alpha, nk_f64c_t const *beta, nk_f64c_t *result);
331
+
332
+ /** @copydoc nk_each_scale_f64 */
333
+ NK_PUBLIC void nk_each_scale_f64_serial(nk_f64_t const *a, nk_size_t n, nk_f64_t const *alpha, nk_f64_t const *beta,
334
+ nk_f64_t *result);
335
+ /** @copydoc nk_each_scale_f64 */
336
+ NK_PUBLIC void nk_each_scale_f32_serial(nk_f32_t const *a, nk_size_t n, nk_f32_t const *alpha, nk_f32_t const *beta,
337
+ nk_f32_t *result);
338
+ /** @copydoc nk_each_scale_f64 */
339
+ NK_PUBLIC void nk_each_scale_f16_serial(nk_f16_t const *a, nk_size_t n, nk_f32_t const *alpha, nk_f32_t const *beta,
340
+ nk_f16_t *result);
341
+ /** @copydoc nk_each_scale_f64 */
342
+ NK_PUBLIC void nk_each_scale_bf16_serial(nk_bf16_t const *a, nk_size_t n, nk_f32_t const *alpha, nk_f32_t const *beta,
343
+ nk_bf16_t *result);
344
+ /** @copydoc nk_each_scale_f64 */
345
+ NK_PUBLIC void nk_each_scale_i8_serial(nk_i8_t const *a, nk_size_t n, nk_f32_t const *alpha, nk_f32_t const *beta,
346
+ nk_i8_t *result);
347
+ /** @copydoc nk_each_scale_f64 */
348
+ NK_PUBLIC void nk_each_scale_u8_serial(nk_u8_t const *a, nk_size_t n, nk_f32_t const *alpha, nk_f32_t const *beta,
349
+ nk_u8_t *result);
350
+ /** @copydoc nk_each_scale_f64 */
351
+ NK_PUBLIC void nk_each_scale_i16_serial(nk_i16_t const *a, nk_size_t n, nk_f32_t const *alpha, nk_f32_t const *beta,
352
+ nk_i16_t *result);
353
+ /** @copydoc nk_each_scale_f64 */
354
+ NK_PUBLIC void nk_each_scale_u16_serial(nk_u16_t const *a, nk_size_t n, nk_f32_t const *alpha, nk_f32_t const *beta,
355
+ nk_u16_t *result);
356
+ /** @copydoc nk_each_scale_f64 */
357
+ NK_PUBLIC void nk_each_scale_i32_serial(nk_i32_t const *a, nk_size_t n, nk_f64_t const *alpha, nk_f64_t const *beta,
358
+ nk_i32_t *result);
359
+ /** @copydoc nk_each_scale_f64 */
360
+ NK_PUBLIC void nk_each_scale_u32_serial(nk_u32_t const *a, nk_size_t n, nk_f64_t const *alpha, nk_f64_t const *beta,
361
+ nk_u32_t *result);
362
+ /** @copydoc nk_each_scale_f64 */
363
+ NK_PUBLIC void nk_each_scale_i64_serial(nk_i64_t const *a, nk_size_t n, nk_f64_t const *alpha, nk_f64_t const *beta,
364
+ nk_i64_t *result);
365
+ /** @copydoc nk_each_scale_f64 */
366
+ NK_PUBLIC void nk_each_scale_u64_serial(nk_u64_t const *a, nk_size_t n, nk_f64_t const *alpha, nk_f64_t const *beta,
367
+ nk_u64_t *result);
368
+
369
+ /** @copydoc nk_each_sum_f64 */
370
+ NK_PUBLIC void nk_each_sum_f64_serial(nk_f64_t const *a, nk_f64_t const *b, nk_size_t n, nk_f64_t *result);
371
+ /** @copydoc nk_each_sum_f64 */
372
+ NK_PUBLIC void nk_each_sum_f32_serial(nk_f32_t const *a, nk_f32_t const *b, nk_size_t n, nk_f32_t *result);
373
+ /** @copydoc nk_each_sum_f64 */
374
+ NK_PUBLIC void nk_each_sum_f16_serial(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n, nk_f16_t *result);
375
+ /** @copydoc nk_each_sum_f64 */
376
+ NK_PUBLIC void nk_each_sum_bf16_serial(nk_bf16_t const *a, nk_bf16_t const *b, nk_size_t n, nk_bf16_t *result);
377
+ /** @copydoc nk_each_sum_f64 */
378
+ NK_PUBLIC void nk_each_sum_i8_serial(nk_i8_t const *a, nk_i8_t const *b, nk_size_t n, nk_i8_t *result);
379
+ /** @copydoc nk_each_sum_f64 */
380
+ NK_PUBLIC void nk_each_sum_u8_serial(nk_u8_t const *a, nk_u8_t const *b, nk_size_t n, nk_u8_t *result);
381
+ /** @copydoc nk_each_sum_f64 */
382
+ NK_PUBLIC void nk_each_sum_i16_serial(nk_i16_t const *a, nk_i16_t const *b, nk_size_t n, nk_i16_t *result);
383
+ /** @copydoc nk_each_sum_f64 */
384
+ NK_PUBLIC void nk_each_sum_u16_serial(nk_u16_t const *a, nk_u16_t const *b, nk_size_t n, nk_u16_t *result);
385
+ /** @copydoc nk_each_sum_f64 */
386
+ NK_PUBLIC void nk_each_sum_i32_serial(nk_i32_t const *a, nk_i32_t const *b, nk_size_t n, nk_i32_t *result);
387
+ /** @copydoc nk_each_sum_f64 */
388
+ NK_PUBLIC void nk_each_sum_u32_serial(nk_u32_t const *a, nk_u32_t const *b, nk_size_t n, nk_u32_t *result);
389
+ /** @copydoc nk_each_sum_f64 */
390
+ NK_PUBLIC void nk_each_sum_i64_serial(nk_i64_t const *a, nk_i64_t const *b, nk_size_t n, nk_i64_t *result);
391
+ /** @copydoc nk_each_sum_f64 */
392
+ NK_PUBLIC void nk_each_sum_u64_serial(nk_u64_t const *a, nk_u64_t const *b, nk_size_t n, nk_u64_t *result);
393
+
394
+ /** @copydoc nk_each_blend_f64 */
395
+ NK_PUBLIC void nk_each_blend_f64_serial(nk_f64_t const *a, nk_f64_t const *b, nk_size_t n, nk_f64_t const *alpha,
396
+ nk_f64_t const *beta, nk_f64_t *result);
397
+ /** @copydoc nk_each_blend_f64 */
398
+ NK_PUBLIC void nk_each_blend_f32_serial(nk_f32_t const *a, nk_f32_t const *b, nk_size_t n, nk_f32_t const *alpha,
399
+ nk_f32_t const *beta, nk_f32_t *result);
400
+ /** @copydoc nk_each_blend_f64 */
401
+ NK_PUBLIC void nk_each_blend_f16_serial(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n, nk_f32_t const *alpha,
402
+ nk_f32_t const *beta, nk_f16_t *result);
403
+ /** @copydoc nk_each_blend_f64 */
404
+ NK_PUBLIC void nk_each_blend_bf16_serial(nk_bf16_t const *a, nk_bf16_t const *b, nk_size_t n, nk_f32_t const *alpha,
405
+ nk_f32_t const *beta, nk_bf16_t *result);
406
+ /** @copydoc nk_each_blend_f64 */
407
+ NK_PUBLIC void nk_each_blend_i8_serial(nk_i8_t const *a, nk_i8_t const *b, nk_size_t n, nk_f32_t const *alpha,
408
+ nk_f32_t const *beta, nk_i8_t *result);
409
+ /** @copydoc nk_each_blend_f64 */
410
+ NK_PUBLIC void nk_each_blend_u8_serial(nk_u8_t const *a, nk_u8_t const *b, nk_size_t n, nk_f32_t const *alpha,
411
+ nk_f32_t const *beta, nk_u8_t *result);
412
+ /** @copydoc nk_each_blend_f64 */
413
+ NK_PUBLIC void nk_each_blend_i16_serial(nk_i16_t const *a, nk_i16_t const *b, nk_size_t n, nk_f32_t const *alpha,
414
+ nk_f32_t const *beta, nk_i16_t *result);
415
+ /** @copydoc nk_each_blend_f64 */
416
+ NK_PUBLIC void nk_each_blend_u16_serial(nk_u16_t const *a, nk_u16_t const *b, nk_size_t n, nk_f32_t const *alpha,
417
+ nk_f32_t const *beta, nk_u16_t *result);
418
+ /** @copydoc nk_each_blend_f64 */
419
+ NK_PUBLIC void nk_each_blend_i32_serial(nk_i32_t const *a, nk_i32_t const *b, nk_size_t n, nk_f64_t const *alpha,
420
+ nk_f64_t const *beta, nk_i32_t *result);
421
+ /** @copydoc nk_each_blend_f64 */
422
+ NK_PUBLIC void nk_each_blend_u32_serial(nk_u32_t const *a, nk_u32_t const *b, nk_size_t n, nk_f64_t const *alpha,
423
+ nk_f64_t const *beta, nk_u32_t *result);
424
+ /** @copydoc nk_each_blend_f64 */
425
+ NK_PUBLIC void nk_each_blend_i64_serial(nk_i64_t const *a, nk_i64_t const *b, nk_size_t n, nk_f64_t const *alpha,
426
+ nk_f64_t const *beta, nk_i64_t *result);
427
+ /** @copydoc nk_each_blend_f64 */
428
+ NK_PUBLIC void nk_each_blend_u64_serial(nk_u64_t const *a, nk_u64_t const *b, nk_size_t n, nk_f64_t const *alpha,
429
+ nk_f64_t const *beta, nk_u64_t *result);
430
+
431
+ /** @copydoc nk_each_fma_f64 */
432
+ NK_PUBLIC void nk_each_fma_f64_serial(nk_f64_t const *a, nk_f64_t const *b, nk_f64_t const *c, nk_size_t n,
433
+ nk_f64_t const *alpha, nk_f64_t const *beta, nk_f64_t *result);
434
+ /** @copydoc nk_each_fma_f64 */
435
+ NK_PUBLIC void nk_each_fma_f32_serial(nk_f32_t const *a, nk_f32_t const *b, nk_f32_t const *c, nk_size_t n,
436
+ nk_f32_t const *alpha, nk_f32_t const *beta, nk_f32_t *result);
437
+ /** @copydoc nk_each_fma_f64 */
438
+ NK_PUBLIC void nk_each_fma_f16_serial(nk_f16_t const *a, nk_f16_t const *b, nk_f16_t const *c, nk_size_t n,
439
+ nk_f32_t const *alpha, nk_f32_t const *beta, nk_f16_t *result);
440
+ /** @copydoc nk_each_fma_f64 */
441
+ NK_PUBLIC void nk_each_fma_bf16_serial(nk_bf16_t const *a, nk_bf16_t const *b, nk_bf16_t const *c, nk_size_t n,
442
+ nk_f32_t const *alpha, nk_f32_t const *beta, nk_bf16_t *result);
443
+ /** @copydoc nk_each_fma_f64 */
444
+ NK_PUBLIC void nk_each_fma_i8_serial(nk_i8_t const *a, nk_i8_t const *b, nk_i8_t const *c, nk_size_t n,
445
+ nk_f32_t const *alpha, nk_f32_t const *beta, nk_i8_t *result);
446
+ /** @copydoc nk_each_fma_f64 */
447
+ NK_PUBLIC void nk_each_fma_u8_serial(nk_u8_t const *a, nk_u8_t const *b, nk_u8_t const *c, nk_size_t n,
448
+ nk_f32_t const *alpha, nk_f32_t const *beta, nk_u8_t *result);
449
+ /** @copydoc nk_each_fma_f64 */
450
+ NK_PUBLIC void nk_each_fma_i16_serial(nk_i16_t const *a, nk_i16_t const *b, nk_i16_t const *c, nk_size_t n,
451
+ nk_f32_t const *alpha, nk_f32_t const *beta, nk_i16_t *result);
452
+ /** @copydoc nk_each_fma_f64 */
453
+ NK_PUBLIC void nk_each_fma_u16_serial(nk_u16_t const *a, nk_u16_t const *b, nk_u16_t const *c, nk_size_t n,
454
+ nk_f32_t const *alpha, nk_f32_t const *beta, nk_u16_t *result);
455
+ /** @copydoc nk_each_fma_f64 */
456
+ NK_PUBLIC void nk_each_fma_i32_serial(nk_i32_t const *a, nk_i32_t const *b, nk_i32_t const *c, nk_size_t n,
457
+ nk_f64_t const *alpha, nk_f64_t const *beta, nk_i32_t *result);
458
+ /** @copydoc nk_each_fma_f64 */
459
+ NK_PUBLIC void nk_each_fma_u32_serial(nk_u32_t const *a, nk_u32_t const *b, nk_u32_t const *c, nk_size_t n,
460
+ nk_f64_t const *alpha, nk_f64_t const *beta, nk_u32_t *result);
461
+ /** @copydoc nk_each_fma_f64 */
462
+ NK_PUBLIC void nk_each_fma_i64_serial(nk_i64_t const *a, nk_i64_t const *b, nk_i64_t const *c, nk_size_t n,
463
+ nk_f64_t const *alpha, nk_f64_t const *beta, nk_i64_t *result);
464
+ /** @copydoc nk_each_fma_f64 */
465
+ NK_PUBLIC void nk_each_fma_u64_serial(nk_u64_t const *a, nk_u64_t const *b, nk_u64_t const *c, nk_size_t n,
466
+ nk_f64_t const *alpha, nk_f64_t const *beta, nk_u64_t *result);
467
+
468
+ /** @copydoc nk_each_sum_e4m3 */
469
+ NK_PUBLIC void nk_each_sum_e4m3_serial(nk_e4m3_t const *a, nk_e4m3_t const *b, nk_size_t n, nk_e4m3_t *result);
470
+ /** @copydoc nk_each_sum_e5m2 */
471
+ NK_PUBLIC void nk_each_sum_e5m2_serial(nk_e5m2_t const *a, nk_e5m2_t const *b, nk_size_t n, nk_e5m2_t *result);
472
+ /** @copydoc nk_each_scale_e4m3 */
473
+ NK_PUBLIC void nk_each_scale_e4m3_serial(nk_e4m3_t const *a, nk_size_t n, nk_f32_t const *alpha, nk_f32_t const *beta,
474
+ nk_e4m3_t *result);
475
+ /** @copydoc nk_each_scale_e5m2 */
476
+ NK_PUBLIC void nk_each_scale_e5m2_serial(nk_e5m2_t const *a, nk_size_t n, nk_f32_t const *alpha, nk_f32_t const *beta,
477
+ nk_e5m2_t *result);
478
+ /** @copydoc nk_each_blend_e4m3 */
479
+ NK_PUBLIC void nk_each_blend_e4m3_serial(nk_e4m3_t const *a, nk_e4m3_t const *b, nk_size_t n, nk_f32_t const *alpha,
480
+ nk_f32_t const *beta, nk_e4m3_t *result);
481
+ /** @copydoc nk_each_blend_e5m2 */
482
+ NK_PUBLIC void nk_each_blend_e5m2_serial(nk_e5m2_t const *a, nk_e5m2_t const *b, nk_size_t n, nk_f32_t const *alpha,
483
+ nk_f32_t const *beta, nk_e5m2_t *result);
484
+ /** @copydoc nk_each_fma_e4m3 */
485
+ NK_PUBLIC void nk_each_fma_e4m3_serial(nk_e4m3_t const *a, nk_e4m3_t const *b, nk_e4m3_t const *c, nk_size_t n,
486
+ nk_f32_t const *alpha, nk_f32_t const *beta, nk_e4m3_t *result);
487
+ /** @copydoc nk_each_fma_e5m2 */
488
+ NK_PUBLIC void nk_each_fma_e5m2_serial(nk_e5m2_t const *a, nk_e5m2_t const *b, nk_e5m2_t const *c, nk_size_t n,
489
+ nk_f32_t const *alpha, nk_f32_t const *beta, nk_e5m2_t *result);
490
+
491
+ /** @copydoc nk_each_sum_e2m3 */
492
+ NK_PUBLIC void nk_each_sum_e2m3_serial(nk_e2m3_t const *a, nk_e2m3_t const *b, nk_size_t n, nk_e2m3_t *result);
493
+ /** @copydoc nk_each_sum_e3m2 */
494
+ NK_PUBLIC void nk_each_sum_e3m2_serial(nk_e3m2_t const *a, nk_e3m2_t const *b, nk_size_t n, nk_e3m2_t *result);
495
+ /** @copydoc nk_each_scale_e2m3 */
496
+ NK_PUBLIC void nk_each_scale_e2m3_serial(nk_e2m3_t const *a, nk_size_t n, nk_f32_t const *alpha, nk_f32_t const *beta,
497
+ nk_e2m3_t *result);
498
+ /** @copydoc nk_each_scale_e3m2 */
499
+ NK_PUBLIC void nk_each_scale_e3m2_serial(nk_e3m2_t const *a, nk_size_t n, nk_f32_t const *alpha, nk_f32_t const *beta,
500
+ nk_e3m2_t *result);
501
+ /** @copydoc nk_each_blend_e2m3 */
502
+ NK_PUBLIC void nk_each_blend_e2m3_serial(nk_e2m3_t const *a, nk_e2m3_t const *b, nk_size_t n, nk_f32_t const *alpha,
503
+ nk_f32_t const *beta, nk_e2m3_t *result);
504
+ /** @copydoc nk_each_blend_e3m2 */
505
+ NK_PUBLIC void nk_each_blend_e3m2_serial(nk_e3m2_t const *a, nk_e3m2_t const *b, nk_size_t n, nk_f32_t const *alpha,
506
+ nk_f32_t const *beta, nk_e3m2_t *result);
507
+ /** @copydoc nk_each_fma_e2m3 */
508
+ NK_PUBLIC void nk_each_fma_e2m3_serial(nk_e2m3_t const *a, nk_e2m3_t const *b, nk_e2m3_t const *c, nk_size_t n,
509
+ nk_f32_t const *alpha, nk_f32_t const *beta, nk_e2m3_t *result);
510
+ /** @copydoc nk_each_fma_e3m2 */
511
+ NK_PUBLIC void nk_each_fma_e3m2_serial(nk_e3m2_t const *a, nk_e3m2_t const *b, nk_e3m2_t const *c, nk_size_t n,
512
+ nk_f32_t const *alpha, nk_f32_t const *beta, nk_e3m2_t *result);
513
+
514
+ /** @copydoc nk_each_sum_f64 */
515
+ NK_PUBLIC void nk_each_sum_f32c_serial(nk_f32c_t const *a, nk_f32c_t const *b, nk_size_t n, nk_f32c_t *result);
516
+ /** @copydoc nk_each_sum_f64 */
517
+ NK_PUBLIC void nk_each_sum_f64c_serial(nk_f64c_t const *a, nk_f64c_t const *b, nk_size_t n, nk_f64c_t *result);
518
+ /** @copydoc nk_each_scale_f64 */
519
+ NK_PUBLIC void nk_each_scale_f32c_serial(nk_f32c_t const *a, nk_size_t n, nk_f32c_t const *alpha, nk_f32c_t const *beta,
520
+ nk_f32c_t *result);
521
+ /** @copydoc nk_each_scale_f64 */
522
+ NK_PUBLIC void nk_each_scale_f64c_serial(nk_f64c_t const *a, nk_size_t n, nk_f64c_t const *alpha, nk_f64c_t const *beta,
523
+ nk_f64c_t *result);
524
+ /** @copydoc nk_each_blend_f64 */
525
+ NK_PUBLIC void nk_each_blend_f32c_serial(nk_f32c_t const *a, nk_f32c_t const *b, nk_size_t n, nk_f32c_t const *alpha,
526
+ nk_f32c_t const *beta, nk_f32c_t *result);
527
+ /** @copydoc nk_each_blend_f64 */
528
+ NK_PUBLIC void nk_each_blend_f64c_serial(nk_f64c_t const *a, nk_f64c_t const *b, nk_size_t n, nk_f64c_t const *alpha,
529
+ nk_f64c_t const *beta, nk_f64c_t *result);
530
+ /** @copydoc nk_each_fma_f64 */
531
+ NK_PUBLIC void nk_each_fma_f32c_serial(nk_f32c_t const *a, nk_f32c_t const *b, nk_f32c_t const *c, nk_size_t n,
532
+ nk_f32c_t const *alpha, nk_f32c_t const *beta, nk_f32c_t *result);
533
+ /** @copydoc nk_each_fma_f64 */
534
+ NK_PUBLIC void nk_each_fma_f64c_serial(nk_f64c_t const *a, nk_f64c_t const *b, nk_f64c_t const *c, nk_size_t n,
535
+ nk_f64c_t const *alpha, nk_f64c_t const *beta, nk_f64c_t *result);
536
+
537
+ #if NK_TARGET_NEON
538
+ /** @copydoc nk_each_scale_f32 */
539
+ NK_PUBLIC void nk_each_scale_f32_neon(nk_f32_t const *a, nk_size_t n, nk_f32_t const *alpha, nk_f32_t const *beta,
540
+ nk_f32_t *result);
541
+ /** @copydoc nk_each_scale_i16 */
542
+ NK_PUBLIC void nk_each_scale_i16_neon(nk_i16_t const *a, nk_size_t n, nk_f32_t const *alpha, nk_f32_t const *beta,
543
+ nk_i16_t *result);
544
+ /** @copydoc nk_each_scale_u16 */
545
+ NK_PUBLIC void nk_each_scale_u16_neon(nk_u16_t const *a, nk_size_t n, nk_f32_t const *alpha, nk_f32_t const *beta,
546
+ nk_u16_t *result);
547
+ /** @copydoc nk_each_scale_i32 */
548
+ NK_PUBLIC void nk_each_scale_i32_neon(nk_i32_t const *a, nk_size_t n, nk_f64_t const *alpha, nk_f64_t const *beta,
549
+ nk_i32_t *result);
550
+ /** @copydoc nk_each_scale_u32 */
551
+ NK_PUBLIC void nk_each_scale_u32_neon(nk_u32_t const *a, nk_size_t n, nk_f64_t const *alpha, nk_f64_t const *beta,
552
+ nk_u32_t *result);
553
+ /** @copydoc nk_each_scale_i64 */
554
+ NK_PUBLIC void nk_each_scale_i64_neon(nk_i64_t const *a, nk_size_t n, nk_f64_t const *alpha, nk_f64_t const *beta,
555
+ nk_i64_t *result);
556
+ /** @copydoc nk_each_scale_u64 */
557
+ NK_PUBLIC void nk_each_scale_u64_neon(nk_u64_t const *a, nk_size_t n, nk_f64_t const *alpha, nk_f64_t const *beta,
558
+ nk_u64_t *result);
559
+
560
+ /** @copydoc nk_each_sum_f32 */
561
+ NK_PUBLIC void nk_each_sum_f32_neon(nk_f32_t const *a, nk_f32_t const *b, nk_size_t n, nk_f32_t *result);
562
+ /** @copydoc nk_each_sum_i16 */
563
+ NK_PUBLIC void nk_each_sum_i16_neon(nk_i16_t const *a, nk_i16_t const *b, nk_size_t n, nk_i16_t *result);
564
+ /** @copydoc nk_each_sum_u16 */
565
+ NK_PUBLIC void nk_each_sum_u16_neon(nk_u16_t const *a, nk_u16_t const *b, nk_size_t n, nk_u16_t *result);
566
+ /** @copydoc nk_each_sum_i32 */
567
+ NK_PUBLIC void nk_each_sum_i32_neon(nk_i32_t const *a, nk_i32_t const *b, nk_size_t n, nk_i32_t *result);
568
+ /** @copydoc nk_each_sum_u32 */
569
+ NK_PUBLIC void nk_each_sum_u32_neon(nk_u32_t const *a, nk_u32_t const *b, nk_size_t n, nk_u32_t *result);
570
+ /** @copydoc nk_each_sum_i64 */
571
+ NK_PUBLIC void nk_each_sum_i64_neon(nk_i64_t const *a, nk_i64_t const *b, nk_size_t n, nk_i64_t *result);
572
+ /** @copydoc nk_each_sum_u64 */
573
+ NK_PUBLIC void nk_each_sum_u64_neon(nk_u64_t const *a, nk_u64_t const *b, nk_size_t n, nk_u64_t *result);
574
+
575
+ /** @copydoc nk_each_blend_f32 */
576
+ NK_PUBLIC void nk_each_blend_f32_neon(nk_f32_t const *a, nk_f32_t const *b, nk_size_t n, nk_f32_t const *alpha,
577
+ nk_f32_t const *beta, nk_f32_t *result);
578
+
579
+ /** @copydoc nk_each_fma_f32 */
580
+ NK_PUBLIC void nk_each_fma_f32_neon(nk_f32_t const *a, nk_f32_t const *b, nk_f32_t const *c, nk_size_t n,
581
+ nk_f32_t const *alpha, nk_f32_t const *beta, nk_f32_t *result);
582
+ /** @copydoc nk_each_fma_i16 */
583
+ NK_PUBLIC void nk_each_fma_i16_neon(nk_i16_t const *a, nk_i16_t const *b, nk_i16_t const *c, nk_size_t n,
584
+ nk_f32_t const *alpha, nk_f32_t const *beta, nk_i16_t *result);
585
+ /** @copydoc nk_each_fma_u16 */
586
+ NK_PUBLIC void nk_each_fma_u16_neon(nk_u16_t const *a, nk_u16_t const *b, nk_u16_t const *c, nk_size_t n,
587
+ nk_f32_t const *alpha, nk_f32_t const *beta, nk_u16_t *result);
588
+ /** @copydoc nk_each_fma_i32 */
589
+ NK_PUBLIC void nk_each_fma_i32_neon(nk_i32_t const *a, nk_i32_t const *b, nk_i32_t const *c, nk_size_t n,
590
+ nk_f64_t const *alpha, nk_f64_t const *beta, nk_i32_t *result);
591
+ /** @copydoc nk_each_fma_u32 */
592
+ NK_PUBLIC void nk_each_fma_u32_neon(nk_u32_t const *a, nk_u32_t const *b, nk_u32_t const *c, nk_size_t n,
593
+ nk_f64_t const *alpha, nk_f64_t const *beta, nk_u32_t *result);
594
+ /** @copydoc nk_each_fma_i64 */
595
+ NK_PUBLIC void nk_each_fma_i64_neon(nk_i64_t const *a, nk_i64_t const *b, nk_i64_t const *c, nk_size_t n,
596
+ nk_f64_t const *alpha, nk_f64_t const *beta, nk_i64_t *result);
597
+ /** @copydoc nk_each_fma_u64 */
598
+ NK_PUBLIC void nk_each_fma_u64_neon(nk_u64_t const *a, nk_u64_t const *b, nk_u64_t const *c, nk_size_t n,
599
+ nk_f64_t const *alpha, nk_f64_t const *beta, nk_u64_t *result);
600
+
601
+ /** @copydoc nk_each_sum_f64 */
602
+ NK_PUBLIC void nk_each_sum_f64_neon(nk_f64_t const *a, nk_f64_t const *b, nk_size_t n, nk_f64_t *result);
603
+ /** @copydoc nk_each_scale_f64 */
604
+ NK_PUBLIC void nk_each_scale_f64_neon(nk_f64_t const *a, nk_size_t n, nk_f64_t const *alpha, nk_f64_t const *beta,
605
+ nk_f64_t *result);
606
+ /** @copydoc nk_each_blend_f64 */
607
+ NK_PUBLIC void nk_each_blend_f64_neon(nk_f64_t const *a, nk_f64_t const *b, nk_size_t n, nk_f64_t const *alpha,
608
+ nk_f64_t const *beta, nk_f64_t *result);
609
+ /** @copydoc nk_each_fma_f64 */
610
+ NK_PUBLIC void nk_each_fma_f64_neon(nk_f64_t const *a, nk_f64_t const *b, nk_f64_t const *c, nk_size_t n,
611
+ nk_f64_t const *alpha, nk_f64_t const *beta, nk_f64_t *result);
612
+
613
+ /** @copydoc nk_each_sum_e4m3 */
614
+ NK_PUBLIC void nk_each_sum_e4m3_neon(nk_e4m3_t const *a, nk_e4m3_t const *b, nk_size_t n, nk_e4m3_t *result);
615
+ /** @copydoc nk_each_sum_e5m2 */
616
+ NK_PUBLIC void nk_each_sum_e5m2_neon(nk_e5m2_t const *a, nk_e5m2_t const *b, nk_size_t n, nk_e5m2_t *result);
617
+ /** @copydoc nk_each_scale_e4m3 */
618
+ NK_PUBLIC void nk_each_scale_e4m3_neon(nk_e4m3_t const *a, nk_size_t n, nk_f32_t const *alpha, nk_f32_t const *beta,
619
+ nk_e4m3_t *result);
620
+ /** @copydoc nk_each_scale_e5m2 */
621
+ NK_PUBLIC void nk_each_scale_e5m2_neon(nk_e5m2_t const *a, nk_size_t n, nk_f32_t const *alpha, nk_f32_t const *beta,
622
+ nk_e5m2_t *result);
623
+ /** @copydoc nk_each_blend_e4m3 */
624
+ NK_PUBLIC void nk_each_blend_e4m3_neon(nk_e4m3_t const *a, nk_e4m3_t const *b, nk_size_t n, nk_f32_t const *alpha,
625
+ nk_f32_t const *beta, nk_e4m3_t *result);
626
+ /** @copydoc nk_each_blend_e5m2 */
627
+ NK_PUBLIC void nk_each_blend_e5m2_neon(nk_e5m2_t const *a, nk_e5m2_t const *b, nk_size_t n, nk_f32_t const *alpha,
628
+ nk_f32_t const *beta, nk_e5m2_t *result);
629
+ /** @copydoc nk_each_fma_e4m3 */
630
+ NK_PUBLIC void nk_each_fma_e4m3_neon(nk_e4m3_t const *a, nk_e4m3_t const *b, nk_e4m3_t const *c, nk_size_t n,
631
+ nk_f32_t const *alpha, nk_f32_t const *beta, nk_e4m3_t *result);
632
+ /** @copydoc nk_each_fma_e5m2 */
633
+ NK_PUBLIC void nk_each_fma_e5m2_neon(nk_e5m2_t const *a, nk_e5m2_t const *b, nk_e5m2_t const *c, nk_size_t n,
634
+ nk_f32_t const *alpha, nk_f32_t const *beta, nk_e5m2_t *result);
635
+
636
+ /** @copydoc nk_each_scale_f64 */
637
+ NK_PUBLIC void nk_each_scale_f32c_neon(nk_f32c_t const *a, nk_size_t n, nk_f32c_t const *alpha, nk_f32c_t const *beta,
638
+ nk_f32c_t *result);
639
+ /** @copydoc nk_each_scale_f64 */
640
+ NK_PUBLIC void nk_each_scale_f64c_neon(nk_f64c_t const *a, nk_size_t n, nk_f64c_t const *alpha, nk_f64c_t const *beta,
641
+ nk_f64c_t *result);
642
+ /** @copydoc nk_each_blend_f64 */
643
+ NK_PUBLIC void nk_each_blend_f32c_neon(nk_f32c_t const *a, nk_f32c_t const *b, nk_size_t n, nk_f32c_t const *alpha,
644
+ nk_f32c_t const *beta, nk_f32c_t *result);
645
+ /** @copydoc nk_each_blend_f64 */
646
+ NK_PUBLIC void nk_each_blend_f64c_neon(nk_f64c_t const *a, nk_f64c_t const *b, nk_size_t n, nk_f64c_t const *alpha,
647
+ nk_f64c_t const *beta, nk_f64c_t *result);
648
+ /** @copydoc nk_each_fma_f64 */
649
+ NK_PUBLIC void nk_each_fma_f32c_neon(nk_f32c_t const *a, nk_f32c_t const *b, nk_f32c_t const *c, nk_size_t n,
650
+ nk_f32c_t const *alpha, nk_f32c_t const *beta, nk_f32c_t *result);
651
+ /** @copydoc nk_each_fma_f64 */
652
+ NK_PUBLIC void nk_each_fma_f64c_neon(nk_f64c_t const *a, nk_f64c_t const *b, nk_f64c_t const *c, nk_size_t n,
653
+ nk_f64c_t const *alpha, nk_f64c_t const *beta, nk_f64c_t *result);
654
+ #endif // NK_TARGET_NEON
655
+
656
+ #if NK_TARGET_NEONBFDOT
657
+ /** @copydoc nk_each_sum_bf16 */
658
+ NK_PUBLIC void nk_each_sum_bf16_neonbfdot(nk_bf16_t const *a, nk_bf16_t const *b, nk_size_t n, nk_bf16_t *result);
659
+ /** @copydoc nk_each_scale_bf16 */
660
+ NK_PUBLIC void nk_each_scale_bf16_neonbfdot(nk_bf16_t const *a, nk_size_t n, nk_f32_t const *alpha,
661
+ nk_f32_t const *beta, nk_bf16_t *result);
662
+ /** @copydoc nk_each_blend_bf16 */
663
+ NK_PUBLIC void nk_each_blend_bf16_neonbfdot(nk_bf16_t const *a, nk_bf16_t const *b, nk_size_t n, nk_f32_t const *alpha,
664
+ nk_f32_t const *beta, nk_bf16_t *result);
665
+ /** @copydoc nk_each_fma_bf16 */
666
+ NK_PUBLIC void nk_each_fma_bf16_neonbfdot(nk_bf16_t const *a, nk_bf16_t const *b, nk_bf16_t const *c, nk_size_t n,
667
+ nk_f32_t const *alpha, nk_f32_t const *beta, nk_bf16_t *result);
668
+ #endif // NK_TARGET_NEONBFDOT
669
+
670
+ #if NK_TARGET_NEONHALF
671
+ /** @copydoc nk_each_sum_f16 */
672
+ NK_PUBLIC void nk_each_sum_f16_neonhalf(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n, nk_f16_t *result);
673
+ /** @copydoc nk_each_scale_f16 */
674
+ NK_PUBLIC void nk_each_scale_f16_neonhalf(nk_f16_t const *a, nk_size_t n, nk_f32_t const *alpha, nk_f32_t const *beta,
675
+ nk_f16_t *result);
676
+ /** @copydoc nk_each_blend_f16 */
677
+ NK_PUBLIC void nk_each_blend_f16_neonhalf(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n, nk_f32_t const *alpha,
678
+ nk_f32_t const *beta, nk_f16_t *result);
679
+ /** @copydoc nk_each_fma_f16 */
680
+ NK_PUBLIC void nk_each_fma_f16_neonhalf(nk_f16_t const *a, nk_f16_t const *b, nk_f16_t const *c, nk_size_t n,
681
+ nk_f32_t const *alpha, nk_f32_t const *beta, nk_f16_t *result);
682
+
683
+ /** @copydoc nk_each_sum_i8 */
684
+ NK_PUBLIC void nk_each_sum_i8_neonhalf(nk_i8_t const *a, nk_i8_t const *b, nk_size_t n, nk_i8_t *result);
685
+ /** @copydoc nk_each_sum_u8 */
686
+ NK_PUBLIC void nk_each_sum_u8_neonhalf(nk_u8_t const *a, nk_u8_t const *b, nk_size_t n, nk_u8_t *result);
687
+ /** @copydoc nk_each_scale_i8 */
688
+ NK_PUBLIC void nk_each_scale_i8_neonhalf(nk_i8_t const *a, nk_size_t n, nk_f32_t const *alpha, nk_f32_t const *beta,
689
+ nk_i8_t *result);
690
+ /** @copydoc nk_each_scale_u8 */
691
+ NK_PUBLIC void nk_each_scale_u8_neonhalf(nk_u8_t const *a, nk_size_t n, nk_f32_t const *alpha, nk_f32_t const *beta,
692
+ nk_u8_t *result);
693
+ /** @copydoc nk_each_blend_i8 */
694
+ NK_PUBLIC void nk_each_blend_i8_neonhalf(nk_i8_t const *a, nk_i8_t const *b, nk_size_t n, nk_f32_t const *alpha,
695
+ nk_f32_t const *beta, nk_i8_t *result);
696
+ /** @copydoc nk_each_blend_u8 */
697
+ NK_PUBLIC void nk_each_blend_u8_neonhalf(nk_u8_t const *a, nk_u8_t const *b, nk_size_t n, nk_f32_t const *alpha,
698
+ nk_f32_t const *beta, nk_u8_t *result);
699
+ /** @copydoc nk_each_fma_i8 */
700
+ NK_PUBLIC void nk_each_fma_i8_neonhalf(nk_i8_t const *a, nk_i8_t const *b, nk_i8_t const *c, nk_size_t n,
701
+ nk_f32_t const *alpha, nk_f32_t const *beta, nk_i8_t *result);
702
+ /** @copydoc nk_each_fma_u8 */
703
+ NK_PUBLIC void nk_each_fma_u8_neonhalf(nk_u8_t const *a, nk_u8_t const *b, nk_u8_t const *c, nk_size_t n,
704
+ nk_f32_t const *alpha, nk_f32_t const *beta, nk_u8_t *result);
705
+ #endif // NK_TARGET_NEONHALF
706
+
707
+ #if NK_TARGET_HASWELL
708
+ /** @copydoc nk_each_scale_f64 */
709
+ NK_PUBLIC void nk_each_scale_f64_haswell(nk_f64_t const *a, nk_size_t n, nk_f64_t const *alpha, nk_f64_t const *beta,
710
+ nk_f64_t *result);
711
+ /** @copydoc nk_each_scale_f32 */
712
+ NK_PUBLIC void nk_each_scale_f32_haswell(nk_f32_t const *a, nk_size_t n, nk_f32_t const *alpha, nk_f32_t const *beta,
713
+ nk_f32_t *result);
714
+ /** @copydoc nk_each_scale_f16 */
715
+ NK_PUBLIC void nk_each_scale_f16_haswell(nk_f16_t const *a, nk_size_t n, nk_f32_t const *alpha, nk_f32_t const *beta,
716
+ nk_f16_t *result);
717
+ /** @copydoc nk_each_scale_bf16 */
718
+ NK_PUBLIC void nk_each_scale_bf16_haswell(nk_bf16_t const *a, nk_size_t n, nk_f32_t const *alpha, nk_f32_t const *beta,
719
+ nk_bf16_t *result);
720
+ /** @copydoc nk_each_scale_i8 */
721
+ NK_PUBLIC void nk_each_scale_i8_haswell(nk_i8_t const *a, nk_size_t n, nk_f32_t const *alpha, nk_f32_t const *beta,
722
+ nk_i8_t *result);
723
+ /** @copydoc nk_each_scale_u8 */
724
+ NK_PUBLIC void nk_each_scale_u8_haswell(nk_u8_t const *a, nk_size_t n, nk_f32_t const *alpha, nk_f32_t const *beta,
725
+ nk_u8_t *result);
726
+ /** @copydoc nk_each_scale_i16 */
727
+ NK_PUBLIC void nk_each_scale_i16_haswell(nk_i16_t const *a, nk_size_t n, nk_f32_t const *alpha, nk_f32_t const *beta,
728
+ nk_i16_t *result);
729
+ /** @copydoc nk_each_scale_u16 */
730
+ NK_PUBLIC void nk_each_scale_u16_haswell(nk_u16_t const *a, nk_size_t n, nk_f32_t const *alpha, nk_f32_t const *beta,
731
+ nk_u16_t *result);
732
+ /** @copydoc nk_each_scale_i32 */
733
+ NK_PUBLIC void nk_each_scale_i32_haswell(nk_i32_t const *a, nk_size_t n, nk_f64_t const *alpha, nk_f64_t const *beta,
734
+ nk_i32_t *result);
735
+ /** @copydoc nk_each_scale_u32 */
736
+ NK_PUBLIC void nk_each_scale_u32_haswell(nk_u32_t const *a, nk_size_t n, nk_f64_t const *alpha, nk_f64_t const *beta,
737
+ nk_u32_t *result);
738
+
739
+ /** @copydoc nk_each_sum_f64 */
740
+ NK_PUBLIC void nk_each_sum_f64_haswell(nk_f64_t const *a, nk_f64_t const *b, nk_size_t n, nk_f64_t *result);
741
+ /** @copydoc nk_each_sum_f32 */
742
+ NK_PUBLIC void nk_each_sum_f32_haswell(nk_f32_t const *a, nk_f32_t const *b, nk_size_t n, nk_f32_t *result);
743
+ /** @copydoc nk_each_sum_f16 */
744
+ NK_PUBLIC void nk_each_sum_f16_haswell(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n, nk_f16_t *result);
745
+ /** @copydoc nk_each_sum_bf16 */
746
+ NK_PUBLIC void nk_each_sum_bf16_haswell(nk_bf16_t const *a, nk_bf16_t const *b, nk_size_t n, nk_bf16_t *result);
747
+ /** @copydoc nk_each_sum_i8 */
748
+ NK_PUBLIC void nk_each_sum_i8_haswell(nk_i8_t const *a, nk_i8_t const *b, nk_size_t n, nk_i8_t *result);
749
+ /** @copydoc nk_each_sum_u8 */
750
+ NK_PUBLIC void nk_each_sum_u8_haswell(nk_u8_t const *a, nk_u8_t const *b, nk_size_t n, nk_u8_t *result);
751
+ /** @copydoc nk_each_sum_i16 */
752
+ NK_PUBLIC void nk_each_sum_i16_haswell(nk_i16_t const *a, nk_i16_t const *b, nk_size_t n, nk_i16_t *result);
753
+ /** @copydoc nk_each_sum_u16 */
754
+ NK_PUBLIC void nk_each_sum_u16_haswell(nk_u16_t const *a, nk_u16_t const *b, nk_size_t n, nk_u16_t *result);
755
+ /** @copydoc nk_each_sum_i32 */
756
+ NK_PUBLIC void nk_each_sum_i32_haswell(nk_i32_t const *a, nk_i32_t const *b, nk_size_t n, nk_i32_t *result);
757
+ /** @copydoc nk_each_sum_u32 */
758
+ NK_PUBLIC void nk_each_sum_u32_haswell(nk_u32_t const *a, nk_u32_t const *b, nk_size_t n, nk_u32_t *result);
759
+
760
+ /** @copydoc nk_each_blend_f64 */
761
+ NK_PUBLIC void nk_each_blend_f64_haswell(nk_f64_t const *a, nk_f64_t const *b, nk_size_t n, nk_f64_t const *alpha,
762
+ nk_f64_t const *beta, nk_f64_t *result);
763
+ /** @copydoc nk_each_blend_f32 */
764
+ NK_PUBLIC void nk_each_blend_f32_haswell(nk_f32_t const *a, nk_f32_t const *b, nk_size_t n, nk_f32_t const *alpha,
765
+ nk_f32_t const *beta, nk_f32_t *result);
766
+ /** @copydoc nk_each_blend_f16 */
767
+ NK_PUBLIC void nk_each_blend_f16_haswell(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n, nk_f32_t const *alpha,
768
+ nk_f32_t const *beta, nk_f16_t *result);
769
+ /** @copydoc nk_each_blend_bf16 */
770
+ NK_PUBLIC void nk_each_blend_bf16_haswell(nk_bf16_t const *a, nk_bf16_t const *b, nk_size_t n, nk_f32_t const *alpha,
771
+ nk_f32_t const *beta, nk_bf16_t *result);
772
+ /** @copydoc nk_each_blend_i8 */
773
+ NK_PUBLIC void nk_each_blend_i8_haswell(nk_i8_t const *a, nk_i8_t const *b, nk_size_t n, nk_f32_t const *alpha,
774
+ nk_f32_t const *beta, nk_i8_t *result);
775
+ /** @copydoc nk_each_blend_u8 */
776
+ NK_PUBLIC void nk_each_blend_u8_haswell(nk_u8_t const *a, nk_u8_t const *b, nk_size_t n, nk_f32_t const *alpha,
777
+ nk_f32_t const *beta, nk_u8_t *result);
778
+
779
+ /** @copydoc nk_each_fma_f64 */
780
+ NK_PUBLIC void nk_each_fma_f64_haswell(nk_f64_t const *a, nk_f64_t const *b, nk_f64_t const *c, nk_size_t n,
781
+ nk_f64_t const *alpha, nk_f64_t const *beta, nk_f64_t *result);
782
+ /** @copydoc nk_each_fma_f32 */
783
+ NK_PUBLIC void nk_each_fma_f32_haswell(nk_f32_t const *a, nk_f32_t const *b, nk_f32_t const *c, nk_size_t n,
784
+ nk_f32_t const *alpha, nk_f32_t const *beta, nk_f32_t *result);
785
+ /** @copydoc nk_each_fma_f16 */
786
+ NK_PUBLIC void nk_each_fma_f16_haswell(nk_f16_t const *a, nk_f16_t const *b, nk_f16_t const *c, nk_size_t n,
787
+ nk_f32_t const *alpha, nk_f32_t const *beta, nk_f16_t *result);
788
+ /** @copydoc nk_each_fma_bf16 */
789
+ NK_PUBLIC void nk_each_fma_bf16_haswell(nk_bf16_t const *a, nk_bf16_t const *b, nk_bf16_t const *c, nk_size_t n,
790
+ nk_f32_t const *alpha, nk_f32_t const *beta, nk_bf16_t *result);
791
+ /** @copydoc nk_each_fma_i8 */
792
+ NK_PUBLIC void nk_each_fma_i8_haswell(nk_i8_t const *a, nk_i8_t const *b, nk_i8_t const *c, nk_size_t n,
793
+ nk_f32_t const *alpha, nk_f32_t const *beta, nk_i8_t *result);
794
+ /** @copydoc nk_each_fma_u8 */
795
+ NK_PUBLIC void nk_each_fma_u8_haswell(nk_u8_t const *a, nk_u8_t const *b, nk_u8_t const *c, nk_size_t n,
796
+ nk_f32_t const *alpha, nk_f32_t const *beta, nk_u8_t *result);
797
+ /** @copydoc nk_each_fma_i16 */
798
+ NK_PUBLIC void nk_each_fma_i16_haswell(nk_i16_t const *a, nk_i16_t const *b, nk_i16_t const *c, nk_size_t n,
799
+ nk_f32_t const *alpha, nk_f32_t const *beta, nk_i16_t *result);
800
+ /** @copydoc nk_each_fma_u16 */
801
+ NK_PUBLIC void nk_each_fma_u16_haswell(nk_u16_t const *a, nk_u16_t const *b, nk_u16_t const *c, nk_size_t n,
802
+ nk_f32_t const *alpha, nk_f32_t const *beta, nk_u16_t *result);
803
+ /** @copydoc nk_each_fma_i32 */
804
+ NK_PUBLIC void nk_each_fma_i32_haswell(nk_i32_t const *a, nk_i32_t const *b, nk_i32_t const *c, nk_size_t n,
805
+ nk_f64_t const *alpha, nk_f64_t const *beta, nk_i32_t *result);
806
+ /** @copydoc nk_each_fma_u32 */
807
+ NK_PUBLIC void nk_each_fma_u32_haswell(nk_u32_t const *a, nk_u32_t const *b, nk_u32_t const *c, nk_size_t n,
808
+ nk_f64_t const *alpha, nk_f64_t const *beta, nk_u32_t *result);
809
+
810
+ /** @copydoc nk_each_sum_e4m3 */
811
+ NK_PUBLIC void nk_each_sum_e4m3_haswell(nk_e4m3_t const *a, nk_e4m3_t const *b, nk_size_t n, nk_e4m3_t *result);
812
+ /** @copydoc nk_each_sum_e5m2 */
813
+ NK_PUBLIC void nk_each_sum_e5m2_haswell(nk_e5m2_t const *a, nk_e5m2_t const *b, nk_size_t n, nk_e5m2_t *result);
814
+ /** @copydoc nk_each_scale_e4m3 */
815
+ NK_PUBLIC void nk_each_scale_e4m3_haswell(nk_e4m3_t const *a, nk_size_t n, nk_f32_t const *alpha, nk_f32_t const *beta,
816
+ nk_e4m3_t *result);
817
+ /** @copydoc nk_each_scale_e5m2 */
818
+ NK_PUBLIC void nk_each_scale_e5m2_haswell(nk_e5m2_t const *a, nk_size_t n, nk_f32_t const *alpha, nk_f32_t const *beta,
819
+ nk_e5m2_t *result);
820
+ /** @copydoc nk_each_blend_e4m3 */
821
+ NK_PUBLIC void nk_each_blend_e4m3_haswell(nk_e4m3_t const *a, nk_e4m3_t const *b, nk_size_t n, nk_f32_t const *alpha,
822
+ nk_f32_t const *beta, nk_e4m3_t *result);
823
+ /** @copydoc nk_each_blend_e5m2 */
824
+ NK_PUBLIC void nk_each_blend_e5m2_haswell(nk_e5m2_t const *a, nk_e5m2_t const *b, nk_size_t n, nk_f32_t const *alpha,
825
+ nk_f32_t const *beta, nk_e5m2_t *result);
826
+ /** @copydoc nk_each_fma_e4m3 */
827
+ NK_PUBLIC void nk_each_fma_e4m3_haswell(nk_e4m3_t const *a, nk_e4m3_t const *b, nk_e4m3_t const *c, nk_size_t n,
828
+ nk_f32_t const *alpha, nk_f32_t const *beta, nk_e4m3_t *result);
829
+ /** @copydoc nk_each_fma_e5m2 */
830
+ NK_PUBLIC void nk_each_fma_e5m2_haswell(nk_e5m2_t const *a, nk_e5m2_t const *b, nk_e5m2_t const *c, nk_size_t n,
831
+ nk_f32_t const *alpha, nk_f32_t const *beta, nk_e5m2_t *result);
832
+
833
+ /** @copydoc nk_each_scale_f32c */
834
+ NK_PUBLIC void nk_each_scale_f32c_haswell(nk_f32c_t const *a, nk_size_t n, nk_f32c_t const *alpha,
835
+ nk_f32c_t const *beta, nk_f32c_t *result);
836
+ /** @copydoc nk_each_scale_f64c */
837
+ NK_PUBLIC void nk_each_scale_f64c_haswell(nk_f64c_t const *a, nk_size_t n, nk_f64c_t const *alpha,
838
+ nk_f64c_t const *beta, nk_f64c_t *result);
839
+ /** @copydoc nk_each_blend_f32c */
840
+ NK_PUBLIC void nk_each_blend_f32c_haswell(nk_f32c_t const *a, nk_f32c_t const *b, nk_size_t n, nk_f32c_t const *alpha,
841
+ nk_f32c_t const *beta, nk_f32c_t *result);
842
+ /** @copydoc nk_each_blend_f64c */
843
+ NK_PUBLIC void nk_each_blend_f64c_haswell(nk_f64c_t const *a, nk_f64c_t const *b, nk_size_t n, nk_f64c_t const *alpha,
844
+ nk_f64c_t const *beta, nk_f64c_t *result);
845
+ /** @copydoc nk_each_fma_f32c */
846
+ NK_PUBLIC void nk_each_fma_f32c_haswell(nk_f32c_t const *a, nk_f32c_t const *b, nk_f32c_t const *c, nk_size_t n,
847
+ nk_f32c_t const *alpha, nk_f32c_t const *beta, nk_f32c_t *result);
848
+ /** @copydoc nk_each_fma_f64c */
849
+ NK_PUBLIC void nk_each_fma_f64c_haswell(nk_f64c_t const *a, nk_f64c_t const *b, nk_f64c_t const *c, nk_size_t n,
850
+ nk_f64c_t const *alpha, nk_f64c_t const *beta, nk_f64c_t *result);
851
+ #endif // NK_TARGET_HASWELL
852
+
853
+ #if NK_TARGET_SKYLAKE
854
+ /** @copydoc nk_each_scale_f64 */
855
+ NK_PUBLIC void nk_each_scale_f64_skylake(nk_f64_t const *a, nk_size_t n, nk_f64_t const *alpha, nk_f64_t const *beta,
856
+ nk_f64_t *result);
857
+ /** @copydoc nk_each_scale_f32 */
858
+ NK_PUBLIC void nk_each_scale_f32_skylake(nk_f32_t const *a, nk_size_t n, nk_f32_t const *alpha, nk_f32_t const *beta,
859
+ nk_f32_t *result);
860
+ /** @copydoc nk_each_scale_f16 */
861
+ NK_PUBLIC void nk_each_scale_f16_skylake(nk_f16_t const *a, nk_size_t n, nk_f32_t const *alpha, nk_f32_t const *beta,
862
+ nk_f16_t *result);
863
+ /** @copydoc nk_each_scale_bf16 */
864
+ NK_PUBLIC void nk_each_scale_bf16_skylake(nk_bf16_t const *a, nk_size_t n, nk_f32_t const *alpha, nk_f32_t const *beta,
865
+ nk_bf16_t *result);
866
+ /** @copydoc nk_each_scale_i8 */
867
+ NK_PUBLIC void nk_each_scale_i8_skylake(nk_i8_t const *a, nk_size_t n, nk_f32_t const *alpha, nk_f32_t const *beta,
868
+ nk_i8_t *result);
869
+ /** @copydoc nk_each_scale_u8 */
870
+ NK_PUBLIC void nk_each_scale_u8_skylake(nk_u8_t const *a, nk_size_t n, nk_f32_t const *alpha, nk_f32_t const *beta,
871
+ nk_u8_t *result);
872
+ /** @copydoc nk_each_scale_i16 */
873
+ NK_PUBLIC void nk_each_scale_i16_skylake(nk_i16_t const *a, nk_size_t n, nk_f32_t const *alpha, nk_f32_t const *beta,
874
+ nk_i16_t *result);
875
+ /** @copydoc nk_each_scale_u16 */
876
+ NK_PUBLIC void nk_each_scale_u16_skylake(nk_u16_t const *a, nk_size_t n, nk_f32_t const *alpha, nk_f32_t const *beta,
877
+ nk_u16_t *result);
878
+ /** @copydoc nk_each_scale_i32 */
879
+ NK_PUBLIC void nk_each_scale_i32_skylake(nk_i32_t const *a, nk_size_t n, nk_f64_t const *alpha, nk_f64_t const *beta,
880
+ nk_i32_t *result);
881
+ /** @copydoc nk_each_scale_u32 */
882
+ NK_PUBLIC void nk_each_scale_u32_skylake(nk_u32_t const *a, nk_size_t n, nk_f64_t const *alpha, nk_f64_t const *beta,
883
+ nk_u32_t *result);
884
+ /** @copydoc nk_each_scale_i64 */
885
+ NK_PUBLIC void nk_each_scale_i64_skylake(nk_i64_t const *a, nk_size_t n, nk_f64_t const *alpha, nk_f64_t const *beta,
886
+ nk_i64_t *result);
887
+ /** @copydoc nk_each_scale_u64 */
888
+ NK_PUBLIC void nk_each_scale_u64_skylake(nk_u64_t const *a, nk_size_t n, nk_f64_t const *alpha, nk_f64_t const *beta,
889
+ nk_u64_t *result);
890
+
891
+ /** @copydoc nk_each_sum_f64 */
892
+ NK_PUBLIC void nk_each_sum_f64_skylake(nk_f64_t const *a, nk_f64_t const *b, nk_size_t n, nk_f64_t *result);
893
+ /** @copydoc nk_each_sum_f32 */
894
+ NK_PUBLIC void nk_each_sum_f32_skylake(nk_f32_t const *a, nk_f32_t const *b, nk_size_t n, nk_f32_t *result);
895
+ /** @copydoc nk_each_sum_bf16 */
896
+ NK_PUBLIC void nk_each_sum_bf16_skylake(nk_bf16_t const *a, nk_bf16_t const *b, nk_size_t n, nk_bf16_t *result);
897
+
898
+ /** @copydoc nk_each_blend_f64 */
899
+ NK_PUBLIC void nk_each_blend_f64_skylake(nk_f64_t const *a, nk_f64_t const *b, nk_size_t n, nk_f64_t const *alpha,
900
+ nk_f64_t const *beta, nk_f64_t *result);
901
+ /** @copydoc nk_each_blend_f32 */
902
+ NK_PUBLIC void nk_each_blend_f32_skylake(nk_f32_t const *a, nk_f32_t const *b, nk_size_t n, nk_f32_t const *alpha,
903
+ nk_f32_t const *beta, nk_f32_t *result);
904
+ /** @copydoc nk_each_blend_f16 */
905
+ NK_PUBLIC void nk_each_blend_f16_skylake(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n, nk_f32_t const *alpha,
906
+ nk_f32_t const *beta, nk_f16_t *result);
907
+ /** @copydoc nk_each_blend_bf16 */
908
+ NK_PUBLIC void nk_each_blend_bf16_skylake(nk_bf16_t const *a, nk_bf16_t const *b, nk_size_t n, nk_f32_t const *alpha,
909
+ nk_f32_t const *beta, nk_bf16_t *result);
910
+
911
+ /** @copydoc nk_each_fma_f64 */
912
+ NK_PUBLIC void nk_each_fma_f64_skylake(nk_f64_t const *a, nk_f64_t const *b, nk_f64_t const *c, nk_size_t n,
913
+ nk_f64_t const *alpha, nk_f64_t const *beta, nk_f64_t *result);
914
+ /** @copydoc nk_each_fma_f32 */
915
+ NK_PUBLIC void nk_each_fma_f32_skylake(nk_f32_t const *a, nk_f32_t const *b, nk_f32_t const *c, nk_size_t n,
916
+ nk_f32_t const *alpha, nk_f32_t const *beta, nk_f32_t *result);
917
+ /** @copydoc nk_each_fma_f16 */
918
+ NK_PUBLIC void nk_each_fma_f16_skylake(nk_f16_t const *a, nk_f16_t const *b, nk_f16_t const *c, nk_size_t n,
919
+ nk_f32_t const *alpha, nk_f32_t const *beta, nk_f16_t *result);
920
+ /** @copydoc nk_each_fma_bf16 */
921
+ NK_PUBLIC void nk_each_fma_bf16_skylake(nk_bf16_t const *a, nk_bf16_t const *b, nk_bf16_t const *c, nk_size_t n,
922
+ nk_f32_t const *alpha, nk_f32_t const *beta, nk_bf16_t *result);
923
+ /** @copydoc nk_each_fma_i8 */
924
+ NK_PUBLIC void nk_each_fma_i8_skylake(nk_i8_t const *a, nk_i8_t const *b, nk_i8_t const *c, nk_size_t n,
925
+ nk_f32_t const *alpha, nk_f32_t const *beta, nk_i8_t *result);
926
+ /** @copydoc nk_each_fma_u8 */
927
+ NK_PUBLIC void nk_each_fma_u8_skylake(nk_u8_t const *a, nk_u8_t const *b, nk_u8_t const *c, nk_size_t n,
928
+ nk_f32_t const *alpha, nk_f32_t const *beta, nk_u8_t *result);
929
+ /** @copydoc nk_each_fma_i16 */
930
+ NK_PUBLIC void nk_each_fma_i16_skylake(nk_i16_t const *a, nk_i16_t const *b, nk_i16_t const *c, nk_size_t n,
931
+ nk_f32_t const *alpha, nk_f32_t const *beta, nk_i16_t *result);
932
+ /** @copydoc nk_each_fma_u16 */
933
+ NK_PUBLIC void nk_each_fma_u16_skylake(nk_u16_t const *a, nk_u16_t const *b, nk_u16_t const *c, nk_size_t n,
934
+ nk_f32_t const *alpha, nk_f32_t const *beta, nk_u16_t *result);
935
+ /** @copydoc nk_each_fma_i32 */
936
+ NK_PUBLIC void nk_each_fma_i32_skylake(nk_i32_t const *a, nk_i32_t const *b, nk_i32_t const *c, nk_size_t n,
937
+ nk_f64_t const *alpha, nk_f64_t const *beta, nk_i32_t *result);
938
+ /** @copydoc nk_each_fma_u32 */
939
+ NK_PUBLIC void nk_each_fma_u32_skylake(nk_u32_t const *a, nk_u32_t const *b, nk_u32_t const *c, nk_size_t n,
940
+ nk_f64_t const *alpha, nk_f64_t const *beta, nk_u32_t *result);
941
+ /** @copydoc nk_each_fma_i64 */
942
+ NK_PUBLIC void nk_each_fma_i64_skylake(nk_i64_t const *a, nk_i64_t const *b, nk_i64_t const *c, nk_size_t n,
943
+ nk_f64_t const *alpha, nk_f64_t const *beta, nk_i64_t *result);
944
+ /** @copydoc nk_each_fma_u64 */
945
+ NK_PUBLIC void nk_each_fma_u64_skylake(nk_u64_t const *a, nk_u64_t const *b, nk_u64_t const *c, nk_size_t n,
946
+ nk_f64_t const *alpha, nk_f64_t const *beta, nk_u64_t *result);
947
+ /** @copydoc nk_each_sum_e4m3 */
948
+ NK_PUBLIC void nk_each_sum_e4m3_skylake(nk_e4m3_t const *a, nk_e4m3_t const *b, nk_size_t n, nk_e4m3_t *result);
949
+ /** @copydoc nk_each_sum_e5m2 */
950
+ NK_PUBLIC void nk_each_sum_e5m2_skylake(nk_e5m2_t const *a, nk_e5m2_t const *b, nk_size_t n, nk_e5m2_t *result);
951
+ /** @copydoc nk_each_scale_e4m3 */
952
+ NK_PUBLIC void nk_each_scale_e4m3_skylake(nk_e4m3_t const *a, nk_size_t n, nk_f32_t const *alpha, nk_f32_t const *beta,
953
+ nk_e4m3_t *result);
954
+ /** @copydoc nk_each_scale_e5m2 */
955
+ NK_PUBLIC void nk_each_scale_e5m2_skylake(nk_e5m2_t const *a, nk_size_t n, nk_f32_t const *alpha, nk_f32_t const *beta,
956
+ nk_e5m2_t *result);
957
+ /** @copydoc nk_each_blend_e4m3 */
958
+ NK_PUBLIC void nk_each_blend_e4m3_skylake(nk_e4m3_t const *a, nk_e4m3_t const *b, nk_size_t n, nk_f32_t const *alpha,
959
+ nk_f32_t const *beta, nk_e4m3_t *result);
960
+ /** @copydoc nk_each_blend_e5m2 */
961
+ NK_PUBLIC void nk_each_blend_e5m2_skylake(nk_e5m2_t const *a, nk_e5m2_t const *b, nk_size_t n, nk_f32_t const *alpha,
962
+ nk_f32_t const *beta, nk_e5m2_t *result);
963
+ /** @copydoc nk_each_fma_e4m3 */
964
+ NK_PUBLIC void nk_each_fma_e4m3_skylake(nk_e4m3_t const *a, nk_e4m3_t const *b, nk_e4m3_t const *c, nk_size_t n,
965
+ nk_f32_t const *alpha, nk_f32_t const *beta, nk_e4m3_t *result);
966
+ /** @copydoc nk_each_fma_e5m2 */
967
+ NK_PUBLIC void nk_each_fma_e5m2_skylake(nk_e5m2_t const *a, nk_e5m2_t const *b, nk_e5m2_t const *c, nk_size_t n,
968
+ nk_f32_t const *alpha, nk_f32_t const *beta, nk_e5m2_t *result);
969
+
970
+ /** @copydoc nk_each_scale_f32c */
971
+ NK_PUBLIC void nk_each_scale_f32c_skylake(nk_f32c_t const *a, nk_size_t n, nk_f32c_t const *alpha,
972
+ nk_f32c_t const *beta, nk_f32c_t *result);
973
+ /** @copydoc nk_each_scale_f64c */
974
+ NK_PUBLIC void nk_each_scale_f64c_skylake(nk_f64c_t const *a, nk_size_t n, nk_f64c_t const *alpha,
975
+ nk_f64c_t const *beta, nk_f64c_t *result);
976
+ /** @copydoc nk_each_blend_f32c */
977
+ NK_PUBLIC void nk_each_blend_f32c_skylake(nk_f32c_t const *a, nk_f32c_t const *b, nk_size_t n, nk_f32c_t const *alpha,
978
+ nk_f32c_t const *beta, nk_f32c_t *result);
979
+ /** @copydoc nk_each_blend_f64c */
980
+ NK_PUBLIC void nk_each_blend_f64c_skylake(nk_f64c_t const *a, nk_f64c_t const *b, nk_size_t n, nk_f64c_t const *alpha,
981
+ nk_f64c_t const *beta, nk_f64c_t *result);
982
+ /** @copydoc nk_each_fma_f32c */
983
+ NK_PUBLIC void nk_each_fma_f32c_skylake(nk_f32c_t const *a, nk_f32c_t const *b, nk_f32c_t const *c, nk_size_t n,
984
+ nk_f32c_t const *alpha, nk_f32c_t const *beta, nk_f32c_t *result);
985
+ /** @copydoc nk_each_fma_f64c */
986
+ NK_PUBLIC void nk_each_fma_f64c_skylake(nk_f64c_t const *a, nk_f64c_t const *b, nk_f64c_t const *c, nk_size_t n,
987
+ nk_f64c_t const *alpha, nk_f64c_t const *beta, nk_f64c_t *result);
988
+ #endif // NK_TARGET_SKYLAKE
989
+
990
+ #if NK_TARGET_ICELAKE
991
+ /** @copydoc nk_each_sum_i8 */
992
+ NK_PUBLIC void nk_each_sum_i8_icelake(nk_i8_t const *a, nk_i8_t const *b, nk_size_t n, nk_i8_t *result);
993
+ /** @copydoc nk_each_sum_u8 */
994
+ NK_PUBLIC void nk_each_sum_u8_icelake(nk_u8_t const *a, nk_u8_t const *b, nk_size_t n, nk_u8_t *result);
995
+ /** @copydoc nk_each_sum_i16 */
996
+ NK_PUBLIC void nk_each_sum_i16_icelake(nk_i16_t const *a, nk_i16_t const *b, nk_size_t n, nk_i16_t *result);
997
+ /** @copydoc nk_each_sum_u16 */
998
+ NK_PUBLIC void nk_each_sum_u16_icelake(nk_u16_t const *a, nk_u16_t const *b, nk_size_t n, nk_u16_t *result);
999
+ /** @copydoc nk_each_sum_i32 */
1000
+ NK_PUBLIC void nk_each_sum_i32_icelake(nk_i32_t const *a, nk_i32_t const *b, nk_size_t n, nk_i32_t *result);
1001
+ /** @copydoc nk_each_sum_u32 */
1002
+ NK_PUBLIC void nk_each_sum_u32_icelake(nk_u32_t const *a, nk_u32_t const *b, nk_size_t n, nk_u32_t *result);
1003
+ /** @copydoc nk_each_sum_i64 */
1004
+ NK_PUBLIC void nk_each_sum_i64_icelake(nk_i64_t const *a, nk_i64_t const *b, nk_size_t n, nk_i64_t *result);
1005
+ /** @copydoc nk_each_sum_u64 */
1006
+ NK_PUBLIC void nk_each_sum_u64_icelake(nk_u64_t const *a, nk_u64_t const *b, nk_size_t n, nk_u64_t *result);
1007
+ #endif // NK_TARGET_ICELAKE
1008
+
1009
+ #if NK_TARGET_SAPPHIRE
1010
+ /** @copydoc nk_each_scale_i8 */
1011
+ NK_PUBLIC void nk_each_scale_i8_sapphire(nk_i8_t const *a, nk_size_t n, nk_f32_t const *alpha, nk_f32_t const *beta,
1012
+ nk_i8_t *result);
1013
+ /** @copydoc nk_each_scale_u8 */
1014
+ NK_PUBLIC void nk_each_scale_u8_sapphire(nk_u8_t const *a, nk_size_t n, nk_f32_t const *alpha, nk_f32_t const *beta,
1015
+ nk_u8_t *result);
1016
+
1017
+ /** @copydoc nk_each_sum_f16 */
1018
+ NK_PUBLIC void nk_each_sum_f16_sapphire(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n, nk_f16_t *result);
1019
+ /** @copydoc nk_each_sum_e4m3 */
1020
+ NK_PUBLIC void nk_each_sum_e4m3_sapphire(nk_e4m3_t const *a, nk_e4m3_t const *b, nk_size_t n, nk_e4m3_t *result);
1021
+
1022
+ /** @copydoc nk_each_blend_i8 */
1023
+ NK_PUBLIC void nk_each_blend_i8_sapphire(nk_i8_t const *a, nk_i8_t const *b, nk_size_t n, nk_f32_t const *alpha,
1024
+ nk_f32_t const *beta, nk_i8_t *result);
1025
+ /** @copydoc nk_each_blend_u8 */
1026
+ NK_PUBLIC void nk_each_blend_u8_sapphire(nk_u8_t const *a, nk_u8_t const *b, nk_size_t n, nk_f32_t const *alpha,
1027
+ nk_f32_t const *beta, nk_u8_t *result);
1028
+
1029
+ /** @copydoc nk_each_fma_i8 */
1030
+ NK_PUBLIC void nk_each_fma_i8_sapphire(nk_i8_t const *a, nk_i8_t const *b, nk_i8_t const *c, nk_size_t n,
1031
+ nk_f32_t const *alpha, nk_f32_t const *beta, nk_i8_t *result);
1032
+ /** @copydoc nk_each_fma_u8 */
1033
+ NK_PUBLIC void nk_each_fma_u8_sapphire(nk_u8_t const *a, nk_u8_t const *b, nk_u8_t const *c, nk_size_t n,
1034
+ nk_f32_t const *alpha, nk_f32_t const *beta, nk_u8_t *result);
1035
+ #endif // NK_TARGET_SAPPHIRE
1036
+
1037
+ #if NK_TARGET_RVV
1038
+ /** @copydoc nk_each_sum_f64 */
1039
+ NK_PUBLIC void nk_each_sum_f64_rvv(nk_f64_t const *a, nk_f64_t const *b, nk_size_t n, nk_f64_t *result);
1040
+ /** @copydoc nk_each_sum_f32 */
1041
+ NK_PUBLIC void nk_each_sum_f32_rvv(nk_f32_t const *a, nk_f32_t const *b, nk_size_t n, nk_f32_t *result);
1042
+ /** @copydoc nk_each_sum_f16 */
1043
+ NK_PUBLIC void nk_each_sum_f16_rvv(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n, nk_f16_t *result);
1044
+ /** @copydoc nk_each_sum_bf16 */
1045
+ NK_PUBLIC void nk_each_sum_bf16_rvv(nk_bf16_t const *a, nk_bf16_t const *b, nk_size_t n, nk_bf16_t *result);
1046
+ /** @copydoc nk_each_sum_i8 */
1047
+ NK_PUBLIC void nk_each_sum_i8_rvv(nk_i8_t const *a, nk_i8_t const *b, nk_size_t n, nk_i8_t *result);
1048
+ /** @copydoc nk_each_sum_u8 */
1049
+ NK_PUBLIC void nk_each_sum_u8_rvv(nk_u8_t const *a, nk_u8_t const *b, nk_size_t n, nk_u8_t *result);
1050
+ /** @copydoc nk_each_sum_i16 */
1051
+ NK_PUBLIC void nk_each_sum_i16_rvv(nk_i16_t const *a, nk_i16_t const *b, nk_size_t n, nk_i16_t *result);
1052
+ /** @copydoc nk_each_sum_u16 */
1053
+ NK_PUBLIC void nk_each_sum_u16_rvv(nk_u16_t const *a, nk_u16_t const *b, nk_size_t n, nk_u16_t *result);
1054
+ /** @copydoc nk_each_sum_i32 */
1055
+ NK_PUBLIC void nk_each_sum_i32_rvv(nk_i32_t const *a, nk_i32_t const *b, nk_size_t n, nk_i32_t *result);
1056
+ /** @copydoc nk_each_sum_u32 */
1057
+ NK_PUBLIC void nk_each_sum_u32_rvv(nk_u32_t const *a, nk_u32_t const *b, nk_size_t n, nk_u32_t *result);
1058
+ /** @copydoc nk_each_sum_i64 */
1059
+ NK_PUBLIC void nk_each_sum_i64_rvv(nk_i64_t const *a, nk_i64_t const *b, nk_size_t n, nk_i64_t *result);
1060
+ /** @copydoc nk_each_sum_u64 */
1061
+ NK_PUBLIC void nk_each_sum_u64_rvv(nk_u64_t const *a, nk_u64_t const *b, nk_size_t n, nk_u64_t *result);
1062
+ /** @copydoc nk_each_sum_e4m3 */
1063
+ NK_PUBLIC void nk_each_sum_e4m3_rvv(nk_e4m3_t const *a, nk_e4m3_t const *b, nk_size_t n, nk_e4m3_t *result);
1064
+ /** @copydoc nk_each_sum_e5m2 */
1065
+ NK_PUBLIC void nk_each_sum_e5m2_rvv(nk_e5m2_t const *a, nk_e5m2_t const *b, nk_size_t n, nk_e5m2_t *result);
1066
+
1067
+ /** @copydoc nk_each_scale_f64 */
1068
+ NK_PUBLIC void nk_each_scale_f64_rvv(nk_f64_t const *a, nk_size_t n, nk_f64_t const *alpha, nk_f64_t const *beta,
1069
+ nk_f64_t *result);
1070
+ /** @copydoc nk_each_scale_f32 */
1071
+ NK_PUBLIC void nk_each_scale_f32_rvv(nk_f32_t const *a, nk_size_t n, nk_f32_t const *alpha, nk_f32_t const *beta,
1072
+ nk_f32_t *result);
1073
+ /** @copydoc nk_each_scale_f16 */
1074
+ NK_PUBLIC void nk_each_scale_f16_rvv(nk_f16_t const *a, nk_size_t n, nk_f32_t const *alpha, nk_f32_t const *beta,
1075
+ nk_f16_t *result);
1076
+ /** @copydoc nk_each_scale_bf16 */
1077
+ NK_PUBLIC void nk_each_scale_bf16_rvv(nk_bf16_t const *a, nk_size_t n, nk_f32_t const *alpha, nk_f32_t const *beta,
1078
+ nk_bf16_t *result);
1079
+ /** @copydoc nk_each_scale_i8 */
1080
+ NK_PUBLIC void nk_each_scale_i8_rvv(nk_i8_t const *a, nk_size_t n, nk_f32_t const *alpha, nk_f32_t const *beta,
1081
+ nk_i8_t *result);
1082
+ /** @copydoc nk_each_scale_u8 */
1083
+ NK_PUBLIC void nk_each_scale_u8_rvv(nk_u8_t const *a, nk_size_t n, nk_f32_t const *alpha, nk_f32_t const *beta,
1084
+ nk_u8_t *result);
1085
+ /** @copydoc nk_each_scale_i16 */
1086
+ NK_PUBLIC void nk_each_scale_i16_rvv(nk_i16_t const *a, nk_size_t n, nk_f32_t const *alpha, nk_f32_t const *beta,
1087
+ nk_i16_t *result);
1088
+ /** @copydoc nk_each_scale_u16 */
1089
+ NK_PUBLIC void nk_each_scale_u16_rvv(nk_u16_t const *a, nk_size_t n, nk_f32_t const *alpha, nk_f32_t const *beta,
1090
+ nk_u16_t *result);
1091
+ /** @copydoc nk_each_scale_i32 */
1092
+ NK_PUBLIC void nk_each_scale_i32_rvv(nk_i32_t const *a, nk_size_t n, nk_f64_t const *alpha, nk_f64_t const *beta,
1093
+ nk_i32_t *result);
1094
+ /** @copydoc nk_each_scale_u32 */
1095
+ NK_PUBLIC void nk_each_scale_u32_rvv(nk_u32_t const *a, nk_size_t n, nk_f64_t const *alpha, nk_f64_t const *beta,
1096
+ nk_u32_t *result);
1097
+ /** @copydoc nk_each_scale_i64 */
1098
+ NK_PUBLIC void nk_each_scale_i64_rvv(nk_i64_t const *a, nk_size_t n, nk_f64_t const *alpha, nk_f64_t const *beta,
1099
+ nk_i64_t *result);
1100
+ /** @copydoc nk_each_scale_u64 */
1101
+ NK_PUBLIC void nk_each_scale_u64_rvv(nk_u64_t const *a, nk_size_t n, nk_f64_t const *alpha, nk_f64_t const *beta,
1102
+ nk_u64_t *result);
1103
+ /** @copydoc nk_each_scale_e4m3 */
1104
+ NK_PUBLIC void nk_each_scale_e4m3_rvv(nk_e4m3_t const *a, nk_size_t n, nk_f32_t const *alpha, nk_f32_t const *beta,
1105
+ nk_e4m3_t *result);
1106
+ /** @copydoc nk_each_scale_e5m2 */
1107
+ NK_PUBLIC void nk_each_scale_e5m2_rvv(nk_e5m2_t const *a, nk_size_t n, nk_f32_t const *alpha, nk_f32_t const *beta,
1108
+ nk_e5m2_t *result);
1109
+
1110
+ /** @copydoc nk_each_blend_f64 */
1111
+ NK_PUBLIC void nk_each_blend_f64_rvv(nk_f64_t const *a, nk_f64_t const *b, nk_size_t n, nk_f64_t const *alpha,
1112
+ nk_f64_t const *beta, nk_f64_t *result);
1113
+ /** @copydoc nk_each_blend_f32 */
1114
+ NK_PUBLIC void nk_each_blend_f32_rvv(nk_f32_t const *a, nk_f32_t const *b, nk_size_t n, nk_f32_t const *alpha,
1115
+ nk_f32_t const *beta, nk_f32_t *result);
1116
+ /** @copydoc nk_each_blend_f16 */
1117
+ NK_PUBLIC void nk_each_blend_f16_rvv(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n, nk_f32_t const *alpha,
1118
+ nk_f32_t const *beta, nk_f16_t *result);
1119
+ /** @copydoc nk_each_blend_bf16 */
1120
+ NK_PUBLIC void nk_each_blend_bf16_rvv(nk_bf16_t const *a, nk_bf16_t const *b, nk_size_t n, nk_f32_t const *alpha,
1121
+ nk_f32_t const *beta, nk_bf16_t *result);
1122
+ /** @copydoc nk_each_blend_i8 */
1123
+ NK_PUBLIC void nk_each_blend_i8_rvv(nk_i8_t const *a, nk_i8_t const *b, nk_size_t n, nk_f32_t const *alpha,
1124
+ nk_f32_t const *beta, nk_i8_t *result);
1125
+ /** @copydoc nk_each_blend_u8 */
1126
+ NK_PUBLIC void nk_each_blend_u8_rvv(nk_u8_t const *a, nk_u8_t const *b, nk_size_t n, nk_f32_t const *alpha,
1127
+ nk_f32_t const *beta, nk_u8_t *result);
1128
+ /** @copydoc nk_each_blend_e4m3 */
1129
+ NK_PUBLIC void nk_each_blend_e4m3_rvv(nk_e4m3_t const *a, nk_e4m3_t const *b, nk_size_t n, nk_f32_t const *alpha,
1130
+ nk_f32_t const *beta, nk_e4m3_t *result);
1131
+ /** @copydoc nk_each_blend_e5m2 */
1132
+ NK_PUBLIC void nk_each_blend_e5m2_rvv(nk_e5m2_t const *a, nk_e5m2_t const *b, nk_size_t n, nk_f32_t const *alpha,
1133
+ nk_f32_t const *beta, nk_e5m2_t *result);
1134
+
1135
+ /** @copydoc nk_each_fma_f64 */
1136
+ NK_PUBLIC void nk_each_fma_f64_rvv(nk_f64_t const *a, nk_f64_t const *b, nk_f64_t const *c, nk_size_t n,
1137
+ nk_f64_t const *alpha, nk_f64_t const *beta, nk_f64_t *result);
1138
+ /** @copydoc nk_each_fma_f32 */
1139
+ NK_PUBLIC void nk_each_fma_f32_rvv(nk_f32_t const *a, nk_f32_t const *b, nk_f32_t const *c, nk_size_t n,
1140
+ nk_f32_t const *alpha, nk_f32_t const *beta, nk_f32_t *result);
1141
+ /** @copydoc nk_each_fma_f16 */
1142
+ NK_PUBLIC void nk_each_fma_f16_rvv(nk_f16_t const *a, nk_f16_t const *b, nk_f16_t const *c, nk_size_t n,
1143
+ nk_f32_t const *alpha, nk_f32_t const *beta, nk_f16_t *result);
1144
+ /** @copydoc nk_each_fma_bf16 */
1145
+ NK_PUBLIC void nk_each_fma_bf16_rvv(nk_bf16_t const *a, nk_bf16_t const *b, nk_bf16_t const *c, nk_size_t n,
1146
+ nk_f32_t const *alpha, nk_f32_t const *beta, nk_bf16_t *result);
1147
+ /** @copydoc nk_each_fma_i8 */
1148
+ NK_PUBLIC void nk_each_fma_i8_rvv(nk_i8_t const *a, nk_i8_t const *b, nk_i8_t const *c, nk_size_t n,
1149
+ nk_f32_t const *alpha, nk_f32_t const *beta, nk_i8_t *result);
1150
+ /** @copydoc nk_each_fma_u8 */
1151
+ NK_PUBLIC void nk_each_fma_u8_rvv(nk_u8_t const *a, nk_u8_t const *b, nk_u8_t const *c, nk_size_t n,
1152
+ nk_f32_t const *alpha, nk_f32_t const *beta, nk_u8_t *result);
1153
+ /** @copydoc nk_each_fma_i16 */
1154
+ NK_PUBLIC void nk_each_fma_i16_rvv(nk_i16_t const *a, nk_i16_t const *b, nk_i16_t const *c, nk_size_t n,
1155
+ nk_f32_t const *alpha, nk_f32_t const *beta, nk_i16_t *result);
1156
+ /** @copydoc nk_each_fma_u16 */
1157
+ NK_PUBLIC void nk_each_fma_u16_rvv(nk_u16_t const *a, nk_u16_t const *b, nk_u16_t const *c, nk_size_t n,
1158
+ nk_f32_t const *alpha, nk_f32_t const *beta, nk_u16_t *result);
1159
+ /** @copydoc nk_each_fma_i32 */
1160
+ NK_PUBLIC void nk_each_fma_i32_rvv(nk_i32_t const *a, nk_i32_t const *b, nk_i32_t const *c, nk_size_t n,
1161
+ nk_f64_t const *alpha, nk_f64_t const *beta, nk_i32_t *result);
1162
+ /** @copydoc nk_each_fma_u32 */
1163
+ NK_PUBLIC void nk_each_fma_u32_rvv(nk_u32_t const *a, nk_u32_t const *b, nk_u32_t const *c, nk_size_t n,
1164
+ nk_f64_t const *alpha, nk_f64_t const *beta, nk_u32_t *result);
1165
+ /** @copydoc nk_each_fma_i64 */
1166
+ NK_PUBLIC void nk_each_fma_i64_rvv(nk_i64_t const *a, nk_i64_t const *b, nk_i64_t const *c, nk_size_t n,
1167
+ nk_f64_t const *alpha, nk_f64_t const *beta, nk_i64_t *result);
1168
+ /** @copydoc nk_each_fma_u64 */
1169
+ NK_PUBLIC void nk_each_fma_u64_rvv(nk_u64_t const *a, nk_u64_t const *b, nk_u64_t const *c, nk_size_t n,
1170
+ nk_f64_t const *alpha, nk_f64_t const *beta, nk_u64_t *result);
1171
+ /** @copydoc nk_each_fma_e4m3 */
1172
+ NK_PUBLIC void nk_each_fma_e4m3_rvv(nk_e4m3_t const *a, nk_e4m3_t const *b, nk_e4m3_t const *c, nk_size_t n,
1173
+ nk_f32_t const *alpha, nk_f32_t const *beta, nk_e4m3_t *result);
1174
+ /** @copydoc nk_each_fma_e5m2 */
1175
+ NK_PUBLIC void nk_each_fma_e5m2_rvv(nk_e5m2_t const *a, nk_e5m2_t const *b, nk_e5m2_t const *c, nk_size_t n,
1176
+ nk_f32_t const *alpha, nk_f32_t const *beta, nk_e5m2_t *result);
1177
+ /** @copydoc nk_each_scale_f32c */
1178
+ NK_PUBLIC void nk_each_scale_f32c_rvv(nk_f32c_t const *a, nk_size_t n, nk_f32c_t const *alpha, nk_f32c_t const *beta,
1179
+ nk_f32c_t *result);
1180
+ /** @copydoc nk_each_scale_f64c */
1181
+ NK_PUBLIC void nk_each_scale_f64c_rvv(nk_f64c_t const *a, nk_size_t n, nk_f64c_t const *alpha, nk_f64c_t const *beta,
1182
+ nk_f64c_t *result);
1183
+ /** @copydoc nk_each_blend_f32c */
1184
+ NK_PUBLIC void nk_each_blend_f32c_rvv(nk_f32c_t const *a, nk_f32c_t const *b, nk_size_t n, nk_f32c_t const *alpha,
1185
+ nk_f32c_t const *beta, nk_f32c_t *result);
1186
+ /** @copydoc nk_each_blend_f64c */
1187
+ NK_PUBLIC void nk_each_blend_f64c_rvv(nk_f64c_t const *a, nk_f64c_t const *b, nk_size_t n, nk_f64c_t const *alpha,
1188
+ nk_f64c_t const *beta, nk_f64c_t *result);
1189
+ /** @copydoc nk_each_fma_f32c */
1190
+ NK_PUBLIC void nk_each_fma_f32c_rvv(nk_f32c_t const *a, nk_f32c_t const *b, nk_f32c_t const *c, nk_size_t n,
1191
+ nk_f32c_t const *alpha, nk_f32c_t const *beta, nk_f32c_t *result);
1192
+ /** @copydoc nk_each_fma_f64c */
1193
+ NK_PUBLIC void nk_each_fma_f64c_rvv(nk_f64c_t const *a, nk_f64c_t const *b, nk_f64c_t const *c, nk_size_t n,
1194
+ nk_f64c_t const *alpha, nk_f64c_t const *beta, nk_f64c_t *result);
1195
+ #endif // NK_TARGET_RVV
1196
+
1197
+ /**
1198
+ * @brief Returns the scalar parameter dtype for elementwise scale/blend/fma operations.
1199
+ */
1200
+ NK_INTERNAL nk_dtype_t nk_each_scale_input_dtype(nk_dtype_t dtype) {
1201
+ switch (dtype) {
1202
+ case nk_f64c_k: return nk_f64c_k;
1203
+ case nk_f32c_k: return nk_f32c_k;
1204
+ case nk_f64_k: return nk_f64_k;
1205
+ case nk_f32_k: return nk_f32_k;
1206
+ case nk_f16_k: return nk_f32_k;
1207
+ case nk_bf16_k: return nk_f32_k;
1208
+ case nk_i64_k: return nk_f64_k;
1209
+ case nk_u64_k: return nk_f64_k;
1210
+ case nk_i32_k: return nk_f64_k;
1211
+ case nk_u32_k: return nk_f64_k;
1212
+ case nk_i16_k: return nk_f32_k;
1213
+ case nk_u16_k: return nk_f32_k;
1214
+ case nk_i8_k: return nk_f32_k;
1215
+ case nk_u8_k: return nk_f32_k;
1216
+ default: return nk_dtype_unknown_k;
1217
+ }
1218
+ }
1219
+
1220
+ #if defined(__cplusplus)
1221
+ } // extern "C"
1222
+ #endif
1223
+
1224
+ #include "numkong/each/serial.h"
1225
+ #include "numkong/each/neon.h"
1226
+ #include "numkong/each/neonhalf.h"
1227
+ #include "numkong/each/neonbfdot.h"
1228
+ #include "numkong/each/haswell.h"
1229
+ #include "numkong/each/skylake.h"
1230
+ #include "numkong/each/icelake.h"
1231
+ #include "numkong/each/sapphire.h"
1232
+ #include "numkong/each/rvv.h"
1233
+
1234
+ #if defined(__cplusplus)
1235
+ extern "C" {
1236
+ #endif
1237
+
1238
+ #if !NK_DYNAMIC_DISPATCH
1239
+
1240
/** @brief Elementwise sum `r[i] = a[i] + b[i]` over `n` f64 values.
 *  Compile-time dispatch for static builds (`!NK_DYNAMIC_DISPATCH`):
 *  the branch order encodes backend preference, widest SIMD first,
 *  falling back to the portable serial kernel. */
NK_PUBLIC void nk_each_sum_f64(nk_f64_t const *a, nk_f64_t const *b, nk_size_t n, nk_f64_t *r) {
#if NK_TARGET_SKYLAKE
    nk_each_sum_f64_skylake(a, b, n, r);
#elif NK_TARGET_HASWELL
    nk_each_sum_f64_haswell(a, b, n, r);
#elif NK_TARGET_NEON
    nk_each_sum_f64_neon(a, b, n, r);
#elif NK_TARGET_RVV
    nk_each_sum_f64_rvv(a, b, n, r);
#else
    nk_each_sum_f64_serial(a, b, n, r);
#endif
}
1253
+
1254
/** @brief Elementwise sum `r[i] = a[i] + b[i]` over `n` f32 values.
 *  Compile-time dispatch: picks the most capable backend enabled for this
 *  build, falling back to the portable serial kernel. */
NK_PUBLIC void nk_each_sum_f32(nk_f32_t const *a, nk_f32_t const *b, nk_size_t n, nk_f32_t *r) {
#if NK_TARGET_SKYLAKE
    nk_each_sum_f32_skylake(a, b, n, r);
#elif NK_TARGET_HASWELL
    nk_each_sum_f32_haswell(a, b, n, r);
#elif NK_TARGET_NEON
    nk_each_sum_f32_neon(a, b, n, r);
#elif NK_TARGET_RVV
    nk_each_sum_f32_rvv(a, b, n, r);
#else
    nk_each_sum_f32_serial(a, b, n, r);
#endif
}
1267
+
1268
/** @brief Elementwise sum `r[i] = a[i] + b[i]` over `n` bf16 values.
 *  Compile-time dispatch; on Arm the bf16 path requires the BFDOT feature
 *  (NEONBFDOT) rather than plain NEON. Falls back to the serial kernel. */
NK_PUBLIC void nk_each_sum_bf16(nk_bf16_t const *a, nk_bf16_t const *b, nk_size_t n, nk_bf16_t *r) {
#if NK_TARGET_SKYLAKE
    nk_each_sum_bf16_skylake(a, b, n, r);
#elif NK_TARGET_HASWELL
    nk_each_sum_bf16_haswell(a, b, n, r);
#elif NK_TARGET_NEONBFDOT
    nk_each_sum_bf16_neonbfdot(a, b, n, r);
#elif NK_TARGET_RVV
    nk_each_sum_bf16_rvv(a, b, n, r);
#else
    nk_each_sum_bf16_serial(a, b, n, r);
#endif
}
1281
+
1282
/** @brief Elementwise sum `r[i] = a[i] + b[i]` over `n` f16 values.
 *  Compile-time dispatch; native f16 arithmetic needs Sapphire Rapids
 *  (AVX512-FP16) on x86 or the half-precision extension (NEONHALF) on Arm;
 *  Haswell handles f16 via conversion. Falls back to the serial kernel. */
NK_PUBLIC void nk_each_sum_f16(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n, nk_f16_t *r) {
#if NK_TARGET_SAPPHIRE
    nk_each_sum_f16_sapphire(a, b, n, r);
#elif NK_TARGET_HASWELL
    nk_each_sum_f16_haswell(a, b, n, r);
#elif NK_TARGET_NEONHALF
    nk_each_sum_f16_neonhalf(a, b, n, r);
#elif NK_TARGET_RVV
    nk_each_sum_f16_rvv(a, b, n, r);
#else
    nk_each_sum_f16_serial(a, b, n, r);
#endif
}
1295
+
1296
+ NK_PUBLIC void nk_each_sum_i8(nk_i8_t const *a, nk_i8_t const *b, nk_size_t n, nk_i8_t *r) {
1297
+ #if NK_TARGET_ICELAKE
1298
+ nk_each_sum_i8_icelake(a, b, n, r);
1299
+ #elif NK_TARGET_HASWELL
1300
+ nk_each_sum_i8_haswell(a, b, n, r);
1301
+ #elif NK_TARGET_NEONHALF
1302
+ nk_each_sum_i8_neonhalf(a, b, n, r);
1303
+ #elif NK_TARGET_RVV
1304
+ nk_each_sum_i8_rvv(a, b, n, r);
1305
+ #else
1306
+ nk_each_sum_i8_serial(a, b, n, r);
1307
+ #endif
1308
+ }
1309
+
1310
+ NK_PUBLIC void nk_each_sum_u8(nk_u8_t const *a, nk_u8_t const *b, nk_size_t n, nk_u8_t *r) {
1311
+ #if NK_TARGET_ICELAKE
1312
+ nk_each_sum_u8_icelake(a, b, n, r);
1313
+ #elif NK_TARGET_HASWELL
1314
+ nk_each_sum_u8_haswell(a, b, n, r);
1315
+ #elif NK_TARGET_NEONHALF
1316
+ nk_each_sum_u8_neonhalf(a, b, n, r);
1317
+ #elif NK_TARGET_RVV
1318
+ nk_each_sum_u8_rvv(a, b, n, r);
1319
+ #else
1320
+ nk_each_sum_u8_serial(a, b, n, r);
1321
+ #endif
1322
+ }
1323
+
1324
+ NK_PUBLIC void nk_each_sum_i16(nk_i16_t const *a, nk_i16_t const *b, nk_size_t n, nk_i16_t *r) {
1325
+ #if NK_TARGET_ICELAKE
1326
+ nk_each_sum_i16_icelake(a, b, n, r);
1327
+ #elif NK_TARGET_HASWELL
1328
+ nk_each_sum_i16_haswell(a, b, n, r);
1329
+ #elif NK_TARGET_NEON
1330
+ nk_each_sum_i16_neon(a, b, n, r);
1331
+ #elif NK_TARGET_RVV
1332
+ nk_each_sum_i16_rvv(a, b, n, r);
1333
+ #else
1334
+ nk_each_sum_i16_serial(a, b, n, r);
1335
+ #endif
1336
+ }
1337
+
1338
+ NK_PUBLIC void nk_each_sum_u16(nk_u16_t const *a, nk_u16_t const *b, nk_size_t n, nk_u16_t *r) {
1339
+ #if NK_TARGET_ICELAKE
1340
+ nk_each_sum_u16_icelake(a, b, n, r);
1341
+ #elif NK_TARGET_HASWELL
1342
+ nk_each_sum_u16_haswell(a, b, n, r);
1343
+ #elif NK_TARGET_NEON
1344
+ nk_each_sum_u16_neon(a, b, n, r);
1345
+ #elif NK_TARGET_RVV
1346
+ nk_each_sum_u16_rvv(a, b, n, r);
1347
+ #else
1348
+ nk_each_sum_u16_serial(a, b, n, r);
1349
+ #endif
1350
+ }
1351
+
1352
+ NK_PUBLIC void nk_each_sum_i32(nk_i32_t const *a, nk_i32_t const *b, nk_size_t n, nk_i32_t *r) {
1353
+ #if NK_TARGET_ICELAKE
1354
+ nk_each_sum_i32_icelake(a, b, n, r);
1355
+ #elif NK_TARGET_HASWELL
1356
+ nk_each_sum_i32_haswell(a, b, n, r);
1357
+ #elif NK_TARGET_NEON
1358
+ nk_each_sum_i32_neon(a, b, n, r);
1359
+ #elif NK_TARGET_RVV
1360
+ nk_each_sum_i32_rvv(a, b, n, r);
1361
+ #else
1362
+ nk_each_sum_i32_serial(a, b, n, r);
1363
+ #endif
1364
+ }
1365
+
1366
+ NK_PUBLIC void nk_each_sum_u32(nk_u32_t const *a, nk_u32_t const *b, nk_size_t n, nk_u32_t *r) {
1367
+ #if NK_TARGET_ICELAKE
1368
+ nk_each_sum_u32_icelake(a, b, n, r);
1369
+ #elif NK_TARGET_HASWELL
1370
+ nk_each_sum_u32_haswell(a, b, n, r);
1371
+ #elif NK_TARGET_NEON
1372
+ nk_each_sum_u32_neon(a, b, n, r);
1373
+ #elif NK_TARGET_RVV
1374
+ nk_each_sum_u32_rvv(a, b, n, r);
1375
+ #else
1376
+ nk_each_sum_u32_serial(a, b, n, r);
1377
+ #endif
1378
+ }
1379
+
1380
+ NK_PUBLIC void nk_each_sum_i64(nk_i64_t const *a, nk_i64_t const *b, nk_size_t n, nk_i64_t *r) {
1381
+ #if NK_TARGET_ICELAKE
1382
+ nk_each_sum_i64_icelake(a, b, n, r);
1383
+ #elif NK_TARGET_NEON
1384
+ nk_each_sum_i64_neon(a, b, n, r);
1385
+ #elif NK_TARGET_RVV
1386
+ nk_each_sum_i64_rvv(a, b, n, r);
1387
+ #else
1388
+ nk_each_sum_i64_serial(a, b, n, r);
1389
+ #endif
1390
+ }
1391
+
1392
+ NK_PUBLIC void nk_each_sum_u64(nk_u64_t const *a, nk_u64_t const *b, nk_size_t n, nk_u64_t *r) {
1393
+ #if NK_TARGET_ICELAKE
1394
+ nk_each_sum_u64_icelake(a, b, n, r);
1395
+ #elif NK_TARGET_NEON
1396
+ nk_each_sum_u64_neon(a, b, n, r);
1397
+ #elif NK_TARGET_RVV
1398
+ nk_each_sum_u64_rvv(a, b, n, r);
1399
+ #else
1400
+ nk_each_sum_u64_serial(a, b, n, r);
1401
+ #endif
1402
+ }
1403
+
1404
/*  Compile-time dispatch for element-wise `scale` kernels (presumably the
 *  affine transform r[i] = alpha * a[i] + beta — confirm against the kernel
 *  headers).  `alpha`/`beta` are passed by pointer; for narrow element types
 *  the coefficients are wider floats (f32 for 8/16-bit lanes, f64 for
 *  32/64-bit lanes), presumably to keep the intermediate math precise.  */

NK_PUBLIC void nk_each_scale_f64(nk_f64_t const *a, nk_size_t n, nk_f64_t const *alpha, nk_f64_t const *beta,
                                 nk_f64_t *r) {
#if NK_TARGET_SKYLAKE
    nk_each_scale_f64_skylake(a, n, alpha, beta, r);
#elif NK_TARGET_HASWELL
    nk_each_scale_f64_haswell(a, n, alpha, beta, r);
#elif NK_TARGET_NEON
    nk_each_scale_f64_neon(a, n, alpha, beta, r);
#elif NK_TARGET_RVV
    nk_each_scale_f64_rvv(a, n, alpha, beta, r);
#else
    nk_each_scale_f64_serial(a, n, alpha, beta, r);
#endif
}

NK_PUBLIC void nk_each_scale_f32(nk_f32_t const *a, nk_size_t n, nk_f32_t const *alpha, nk_f32_t const *beta,
                                 nk_f32_t *r) {
#if NK_TARGET_SKYLAKE
    nk_each_scale_f32_skylake(a, n, alpha, beta, r);
#elif NK_TARGET_HASWELL
    nk_each_scale_f32_haswell(a, n, alpha, beta, r);
#elif NK_TARGET_NEON
    nk_each_scale_f32_neon(a, n, alpha, beta, r);
#elif NK_TARGET_RVV
    nk_each_scale_f32_rvv(a, n, alpha, beta, r);
#else
    nk_each_scale_f32_serial(a, n, alpha, beta, r);
#endif
}

NK_PUBLIC void nk_each_scale_bf16(nk_bf16_t const *a, nk_size_t n, nk_f32_t const *alpha, nk_f32_t const *beta,
                                  nk_bf16_t *r) {
#if NK_TARGET_SKYLAKE
    nk_each_scale_bf16_skylake(a, n, alpha, beta, r);
#elif NK_TARGET_HASWELL
    nk_each_scale_bf16_haswell(a, n, alpha, beta, r);
#elif NK_TARGET_NEONBFDOT
    nk_each_scale_bf16_neonbfdot(a, n, alpha, beta, r);
#elif NK_TARGET_RVV
    nk_each_scale_bf16_rvv(a, n, alpha, beta, r);
#else
    nk_each_scale_bf16_serial(a, n, alpha, beta, r);
#endif
}

NK_PUBLIC void nk_each_scale_f16(nk_f16_t const *a, nk_size_t n, nk_f32_t const *alpha, nk_f32_t const *beta,
                                 nk_f16_t *r) {
#if NK_TARGET_SKYLAKE
    nk_each_scale_f16_skylake(a, n, alpha, beta, r);
#elif NK_TARGET_HASWELL
    nk_each_scale_f16_haswell(a, n, alpha, beta, r);
#elif NK_TARGET_NEONHALF
    nk_each_scale_f16_neonhalf(a, n, alpha, beta, r);
#elif NK_TARGET_RVV
    nk_each_scale_f16_rvv(a, n, alpha, beta, r);
#else
    nk_each_scale_f16_serial(a, n, alpha, beta, r);
#endif
}

/* 8-bit scaling has an extra SAPPHIRE tier ahead of SKYLAKE. */
NK_PUBLIC void nk_each_scale_i8(nk_i8_t const *a, nk_size_t n, nk_f32_t const *alpha, nk_f32_t const *beta,
                                nk_i8_t *r) {
#if NK_TARGET_SAPPHIRE
    nk_each_scale_i8_sapphire(a, n, alpha, beta, r);
#elif NK_TARGET_SKYLAKE
    nk_each_scale_i8_skylake(a, n, alpha, beta, r);
#elif NK_TARGET_HASWELL
    nk_each_scale_i8_haswell(a, n, alpha, beta, r);
#elif NK_TARGET_NEONHALF
    nk_each_scale_i8_neonhalf(a, n, alpha, beta, r);
#elif NK_TARGET_RVV
    nk_each_scale_i8_rvv(a, n, alpha, beta, r);
#else
    nk_each_scale_i8_serial(a, n, alpha, beta, r);
#endif
}

NK_PUBLIC void nk_each_scale_u8(nk_u8_t const *a, nk_size_t n, nk_f32_t const *alpha, nk_f32_t const *beta,
                                nk_u8_t *r) {
#if NK_TARGET_SAPPHIRE
    nk_each_scale_u8_sapphire(a, n, alpha, beta, r);
#elif NK_TARGET_SKYLAKE
    nk_each_scale_u8_skylake(a, n, alpha, beta, r);
#elif NK_TARGET_HASWELL
    nk_each_scale_u8_haswell(a, n, alpha, beta, r);
#elif NK_TARGET_NEONHALF
    nk_each_scale_u8_neonhalf(a, n, alpha, beta, r);
#elif NK_TARGET_RVV
    nk_each_scale_u8_rvv(a, n, alpha, beta, r);
#else
    nk_each_scale_u8_serial(a, n, alpha, beta, r);
#endif
}

NK_PUBLIC void nk_each_scale_i16(nk_i16_t const *a, nk_size_t n, nk_f32_t const *alpha, nk_f32_t const *beta,
                                 nk_i16_t *r) {
#if NK_TARGET_SKYLAKE
    nk_each_scale_i16_skylake(a, n, alpha, beta, r);
#elif NK_TARGET_HASWELL
    nk_each_scale_i16_haswell(a, n, alpha, beta, r);
#elif NK_TARGET_NEON
    nk_each_scale_i16_neon(a, n, alpha, beta, r);
#elif NK_TARGET_RVV
    nk_each_scale_i16_rvv(a, n, alpha, beta, r);
#else
    nk_each_scale_i16_serial(a, n, alpha, beta, r);
#endif
}

NK_PUBLIC void nk_each_scale_u16(nk_u16_t const *a, nk_size_t n, nk_f32_t const *alpha, nk_f32_t const *beta,
                                 nk_u16_t *r) {
#if NK_TARGET_SKYLAKE
    nk_each_scale_u16_skylake(a, n, alpha, beta, r);
#elif NK_TARGET_HASWELL
    nk_each_scale_u16_haswell(a, n, alpha, beta, r);
#elif NK_TARGET_NEON
    nk_each_scale_u16_neon(a, n, alpha, beta, r);
#elif NK_TARGET_RVV
    nk_each_scale_u16_rvv(a, n, alpha, beta, r);
#else
    nk_each_scale_u16_serial(a, n, alpha, beta, r);
#endif
}

/* 32/64-bit integer lanes take f64 coefficients. */
NK_PUBLIC void nk_each_scale_i32(nk_i32_t const *a, nk_size_t n, nk_f64_t const *alpha, nk_f64_t const *beta,
                                 nk_i32_t *r) {
#if NK_TARGET_SKYLAKE
    nk_each_scale_i32_skylake(a, n, alpha, beta, r);
#elif NK_TARGET_HASWELL
    nk_each_scale_i32_haswell(a, n, alpha, beta, r);
#elif NK_TARGET_NEON
    nk_each_scale_i32_neon(a, n, alpha, beta, r);
#elif NK_TARGET_RVV
    nk_each_scale_i32_rvv(a, n, alpha, beta, r);
#else
    nk_each_scale_i32_serial(a, n, alpha, beta, r);
#endif
}

NK_PUBLIC void nk_each_scale_u32(nk_u32_t const *a, nk_size_t n, nk_f64_t const *alpha, nk_f64_t const *beta,
                                 nk_u32_t *r) {
#if NK_TARGET_SKYLAKE
    nk_each_scale_u32_skylake(a, n, alpha, beta, r);
#elif NK_TARGET_HASWELL
    nk_each_scale_u32_haswell(a, n, alpha, beta, r);
#elif NK_TARGET_NEON
    nk_each_scale_u32_neon(a, n, alpha, beta, r);
#elif NK_TARGET_RVV
    nk_each_scale_u32_rvv(a, n, alpha, beta, r);
#else
    nk_each_scale_u32_serial(a, n, alpha, beta, r);
#endif
}

NK_PUBLIC void nk_each_scale_i64(nk_i64_t const *a, nk_size_t n, nk_f64_t const *alpha, nk_f64_t const *beta,
                                 nk_i64_t *r) {
#if NK_TARGET_SKYLAKE
    nk_each_scale_i64_skylake(a, n, alpha, beta, r);
#elif NK_TARGET_NEON
    nk_each_scale_i64_neon(a, n, alpha, beta, r);
#elif NK_TARGET_RVV
    nk_each_scale_i64_rvv(a, n, alpha, beta, r);
#else
    nk_each_scale_i64_serial(a, n, alpha, beta, r);
#endif
}

NK_PUBLIC void nk_each_scale_u64(nk_u64_t const *a, nk_size_t n, nk_f64_t const *alpha, nk_f64_t const *beta,
                                 nk_u64_t *r) {
#if NK_TARGET_SKYLAKE
    nk_each_scale_u64_skylake(a, n, alpha, beta, r);
#elif NK_TARGET_NEON
    nk_each_scale_u64_neon(a, n, alpha, beta, r);
#elif NK_TARGET_RVV
    nk_each_scale_u64_rvv(a, n, alpha, beta, r);
#else
    nk_each_scale_u64_serial(a, n, alpha, beta, r);
#endif
}
1583
+
1584
/*  Compile-time dispatch for element-wise `blend` kernels (presumably
 *  r[i] = alpha * a[i] + beta * b[i] — confirm against the kernel headers).
 *  The wider integer types (16/32/64-bit) have only serial implementations
 *  here.
 *  NOTE(review): blend_i8/u8 skip the SKYLAKE tier that scale_i8/fma_i8
 *  have — presumably no skylake blend kernel exists; verify before adding. */

NK_PUBLIC void nk_each_blend_f64(nk_f64_t const *a, nk_f64_t const *b, nk_size_t n, nk_f64_t const *alpha,
                                 nk_f64_t const *beta, nk_f64_t *r) {
#if NK_TARGET_SKYLAKE
    nk_each_blend_f64_skylake(a, b, n, alpha, beta, r);
#elif NK_TARGET_HASWELL
    nk_each_blend_f64_haswell(a, b, n, alpha, beta, r);
#elif NK_TARGET_NEON
    nk_each_blend_f64_neon(a, b, n, alpha, beta, r);
#elif NK_TARGET_RVV
    nk_each_blend_f64_rvv(a, b, n, alpha, beta, r);
#else
    nk_each_blend_f64_serial(a, b, n, alpha, beta, r);
#endif
}

NK_PUBLIC void nk_each_blend_f32(nk_f32_t const *a, nk_f32_t const *b, nk_size_t n, nk_f32_t const *alpha,
                                 nk_f32_t const *beta, nk_f32_t *r) {
#if NK_TARGET_SKYLAKE
    nk_each_blend_f32_skylake(a, b, n, alpha, beta, r);
#elif NK_TARGET_HASWELL
    nk_each_blend_f32_haswell(a, b, n, alpha, beta, r);
#elif NK_TARGET_NEON
    nk_each_blend_f32_neon(a, b, n, alpha, beta, r);
#elif NK_TARGET_RVV
    nk_each_blend_f32_rvv(a, b, n, alpha, beta, r);
#else
    nk_each_blend_f32_serial(a, b, n, alpha, beta, r);
#endif
}

NK_PUBLIC void nk_each_blend_bf16(nk_bf16_t const *a, nk_bf16_t const *b, nk_size_t n, nk_f32_t const *alpha,
                                  nk_f32_t const *beta, nk_bf16_t *r) {
#if NK_TARGET_SKYLAKE
    nk_each_blend_bf16_skylake(a, b, n, alpha, beta, r);
#elif NK_TARGET_HASWELL
    nk_each_blend_bf16_haswell(a, b, n, alpha, beta, r);
#elif NK_TARGET_NEONBFDOT
    nk_each_blend_bf16_neonbfdot(a, b, n, alpha, beta, r);
#elif NK_TARGET_RVV
    nk_each_blend_bf16_rvv(a, b, n, alpha, beta, r);
#else
    nk_each_blend_bf16_serial(a, b, n, alpha, beta, r);
#endif
}

NK_PUBLIC void nk_each_blend_f16(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n, nk_f32_t const *alpha,
                                 nk_f32_t const *beta, nk_f16_t *r) {
#if NK_TARGET_SKYLAKE
    nk_each_blend_f16_skylake(a, b, n, alpha, beta, r);
#elif NK_TARGET_HASWELL
    nk_each_blend_f16_haswell(a, b, n, alpha, beta, r);
#elif NK_TARGET_NEONHALF
    nk_each_blend_f16_neonhalf(a, b, n, alpha, beta, r);
#elif NK_TARGET_RVV
    nk_each_blend_f16_rvv(a, b, n, alpha, beta, r);
#else
    nk_each_blend_f16_serial(a, b, n, alpha, beta, r);
#endif
}

NK_PUBLIC void nk_each_blend_i8(nk_i8_t const *a, nk_i8_t const *b, nk_size_t n, nk_f32_t const *alpha,
                                nk_f32_t const *beta, nk_i8_t *r) {
#if NK_TARGET_SAPPHIRE
    nk_each_blend_i8_sapphire(a, b, n, alpha, beta, r);
#elif NK_TARGET_HASWELL
    nk_each_blend_i8_haswell(a, b, n, alpha, beta, r);
#elif NK_TARGET_NEONHALF
    nk_each_blend_i8_neonhalf(a, b, n, alpha, beta, r);
#elif NK_TARGET_RVV
    nk_each_blend_i8_rvv(a, b, n, alpha, beta, r);
#else
    nk_each_blend_i8_serial(a, b, n, alpha, beta, r);
#endif
}

NK_PUBLIC void nk_each_blend_u8(nk_u8_t const *a, nk_u8_t const *b, nk_size_t n, nk_f32_t const *alpha,
                                nk_f32_t const *beta, nk_u8_t *r) {
#if NK_TARGET_SAPPHIRE
    nk_each_blend_u8_sapphire(a, b, n, alpha, beta, r);
#elif NK_TARGET_HASWELL
    nk_each_blend_u8_haswell(a, b, n, alpha, beta, r);
#elif NK_TARGET_NEONHALF
    nk_each_blend_u8_neonhalf(a, b, n, alpha, beta, r);
#elif NK_TARGET_RVV
    nk_each_blend_u8_rvv(a, b, n, alpha, beta, r);
#else
    nk_each_blend_u8_serial(a, b, n, alpha, beta, r);
#endif
}

/* Serial-only blends below: no SIMD kernels for these element widths yet. */
NK_PUBLIC void nk_each_blend_i16(nk_i16_t const *a, nk_i16_t const *b, nk_size_t n, nk_f32_t const *alpha,
                                 nk_f32_t const *beta, nk_i16_t *r) {
    nk_each_blend_i16_serial(a, b, n, alpha, beta, r);
}

NK_PUBLIC void nk_each_blend_u16(nk_u16_t const *a, nk_u16_t const *b, nk_size_t n, nk_f32_t const *alpha,
                                 nk_f32_t const *beta, nk_u16_t *r) {
    nk_each_blend_u16_serial(a, b, n, alpha, beta, r);
}

NK_PUBLIC void nk_each_blend_i32(nk_i32_t const *a, nk_i32_t const *b, nk_size_t n, nk_f64_t const *alpha,
                                 nk_f64_t const *beta, nk_i32_t *r) {
    nk_each_blend_i32_serial(a, b, n, alpha, beta, r);
}

NK_PUBLIC void nk_each_blend_u32(nk_u32_t const *a, nk_u32_t const *b, nk_size_t n, nk_f64_t const *alpha,
                                 nk_f64_t const *beta, nk_u32_t *r) {
    nk_each_blend_u32_serial(a, b, n, alpha, beta, r);
}

NK_PUBLIC void nk_each_blend_i64(nk_i64_t const *a, nk_i64_t const *b, nk_size_t n, nk_f64_t const *alpha,
                                 nk_f64_t const *beta, nk_i64_t *r) {
    nk_each_blend_i64_serial(a, b, n, alpha, beta, r);
}

NK_PUBLIC void nk_each_blend_u64(nk_u64_t const *a, nk_u64_t const *b, nk_size_t n, nk_f64_t const *alpha,
                                 nk_f64_t const *beta, nk_u64_t *r) {
    nk_each_blend_u64_serial(a, b, n, alpha, beta, r);
}
1703
+
1704
/*  Compile-time dispatch for element-wise `fma` kernels over three operands
 *  (presumably r[i] = alpha * a[i] * b[i] + beta * c[i] — confirm against
 *  the kernel headers).  Same target-selection scheme as the sum/scale/blend
 *  dispatchers above.  */

NK_PUBLIC void nk_each_fma_f64(nk_f64_t const *a, nk_f64_t const *b, nk_f64_t const *c, nk_size_t n,
                               nk_f64_t const *alpha, nk_f64_t const *beta, nk_f64_t *r) {
#if NK_TARGET_SKYLAKE
    nk_each_fma_f64_skylake(a, b, c, n, alpha, beta, r);
#elif NK_TARGET_HASWELL
    nk_each_fma_f64_haswell(a, b, c, n, alpha, beta, r);
#elif NK_TARGET_NEON
    nk_each_fma_f64_neon(a, b, c, n, alpha, beta, r);
#elif NK_TARGET_RVV
    nk_each_fma_f64_rvv(a, b, c, n, alpha, beta, r);
#else
    nk_each_fma_f64_serial(a, b, c, n, alpha, beta, r);
#endif
}

NK_PUBLIC void nk_each_fma_f32(nk_f32_t const *a, nk_f32_t const *b, nk_f32_t const *c, nk_size_t n,
                               nk_f32_t const *alpha, nk_f32_t const *beta, nk_f32_t *r) {
#if NK_TARGET_SKYLAKE
    nk_each_fma_f32_skylake(a, b, c, n, alpha, beta, r);
#elif NK_TARGET_HASWELL
    nk_each_fma_f32_haswell(a, b, c, n, alpha, beta, r);
#elif NK_TARGET_NEON
    nk_each_fma_f32_neon(a, b, c, n, alpha, beta, r);
#elif NK_TARGET_RVV
    nk_each_fma_f32_rvv(a, b, c, n, alpha, beta, r);
#else
    nk_each_fma_f32_serial(a, b, c, n, alpha, beta, r);
#endif
}

NK_PUBLIC void nk_each_fma_bf16(nk_bf16_t const *a, nk_bf16_t const *b, nk_bf16_t const *c, nk_size_t n,
                                nk_f32_t const *alpha, nk_f32_t const *beta, nk_bf16_t *r) {
#if NK_TARGET_SKYLAKE
    nk_each_fma_bf16_skylake(a, b, c, n, alpha, beta, r);
#elif NK_TARGET_HASWELL
    nk_each_fma_bf16_haswell(a, b, c, n, alpha, beta, r);
#elif NK_TARGET_NEONBFDOT
    nk_each_fma_bf16_neonbfdot(a, b, c, n, alpha, beta, r);
#elif NK_TARGET_RVV
    nk_each_fma_bf16_rvv(a, b, c, n, alpha, beta, r);
#else
    nk_each_fma_bf16_serial(a, b, c, n, alpha, beta, r);
#endif
}

NK_PUBLIC void nk_each_fma_f16(nk_f16_t const *a, nk_f16_t const *b, nk_f16_t const *c, nk_size_t n,
                               nk_f32_t const *alpha, nk_f32_t const *beta, nk_f16_t *r) {
#if NK_TARGET_SKYLAKE
    nk_each_fma_f16_skylake(a, b, c, n, alpha, beta, r);
#elif NK_TARGET_HASWELL
    nk_each_fma_f16_haswell(a, b, c, n, alpha, beta, r);
#elif NK_TARGET_NEONHALF
    nk_each_fma_f16_neonhalf(a, b, c, n, alpha, beta, r);
#elif NK_TARGET_RVV
    nk_each_fma_f16_rvv(a, b, c, n, alpha, beta, r);
#else
    nk_each_fma_f16_serial(a, b, c, n, alpha, beta, r);
#endif
}

/* 8-bit FMA has an extra SAPPHIRE tier ahead of SKYLAKE. */
NK_PUBLIC void nk_each_fma_i8(nk_i8_t const *a, nk_i8_t const *b, nk_i8_t const *c, nk_size_t n, nk_f32_t const *alpha,
                              nk_f32_t const *beta, nk_i8_t *r) {
#if NK_TARGET_SAPPHIRE
    nk_each_fma_i8_sapphire(a, b, c, n, alpha, beta, r);
#elif NK_TARGET_SKYLAKE
    nk_each_fma_i8_skylake(a, b, c, n, alpha, beta, r);
#elif NK_TARGET_HASWELL
    nk_each_fma_i8_haswell(a, b, c, n, alpha, beta, r);
#elif NK_TARGET_NEONHALF
    nk_each_fma_i8_neonhalf(a, b, c, n, alpha, beta, r);
#elif NK_TARGET_RVV
    nk_each_fma_i8_rvv(a, b, c, n, alpha, beta, r);
#else
    nk_each_fma_i8_serial(a, b, c, n, alpha, beta, r);
#endif
}

NK_PUBLIC void nk_each_fma_u8(nk_u8_t const *a, nk_u8_t const *b, nk_u8_t const *c, nk_size_t n, nk_f32_t const *alpha,
                              nk_f32_t const *beta, nk_u8_t *r) {
#if NK_TARGET_SAPPHIRE
    nk_each_fma_u8_sapphire(a, b, c, n, alpha, beta, r);
#elif NK_TARGET_SKYLAKE
    nk_each_fma_u8_skylake(a, b, c, n, alpha, beta, r);
#elif NK_TARGET_HASWELL
    nk_each_fma_u8_haswell(a, b, c, n, alpha, beta, r);
#elif NK_TARGET_NEONHALF
    nk_each_fma_u8_neonhalf(a, b, c, n, alpha, beta, r);
#elif NK_TARGET_RVV
    nk_each_fma_u8_rvv(a, b, c, n, alpha, beta, r);
#else
    nk_each_fma_u8_serial(a, b, c, n, alpha, beta, r);
#endif
}

NK_PUBLIC void nk_each_fma_i16(nk_i16_t const *a, nk_i16_t const *b, nk_i16_t const *c, nk_size_t n,
                               nk_f32_t const *alpha, nk_f32_t const *beta, nk_i16_t *r) {
#if NK_TARGET_SKYLAKE
    nk_each_fma_i16_skylake(a, b, c, n, alpha, beta, r);
#elif NK_TARGET_HASWELL
    nk_each_fma_i16_haswell(a, b, c, n, alpha, beta, r);
#elif NK_TARGET_NEON
    nk_each_fma_i16_neon(a, b, c, n, alpha, beta, r);
#elif NK_TARGET_RVV
    nk_each_fma_i16_rvv(a, b, c, n, alpha, beta, r);
#else
    nk_each_fma_i16_serial(a, b, c, n, alpha, beta, r);
#endif
}

NK_PUBLIC void nk_each_fma_u16(nk_u16_t const *a, nk_u16_t const *b, nk_u16_t const *c, nk_size_t n,
                               nk_f32_t const *alpha, nk_f32_t const *beta, nk_u16_t *r) {
#if NK_TARGET_SKYLAKE
    nk_each_fma_u16_skylake(a, b, c, n, alpha, beta, r);
#elif NK_TARGET_HASWELL
    nk_each_fma_u16_haswell(a, b, c, n, alpha, beta, r);
#elif NK_TARGET_NEON
    nk_each_fma_u16_neon(a, b, c, n, alpha, beta, r);
#elif NK_TARGET_RVV
    nk_each_fma_u16_rvv(a, b, c, n, alpha, beta, r);
#else
    nk_each_fma_u16_serial(a, b, c, n, alpha, beta, r);
#endif
}

/* 32/64-bit integer lanes take f64 coefficients. */
NK_PUBLIC void nk_each_fma_i32(nk_i32_t const *a, nk_i32_t const *b, nk_i32_t const *c, nk_size_t n,
                               nk_f64_t const *alpha, nk_f64_t const *beta, nk_i32_t *r) {
#if NK_TARGET_SKYLAKE
    nk_each_fma_i32_skylake(a, b, c, n, alpha, beta, r);
#elif NK_TARGET_HASWELL
    nk_each_fma_i32_haswell(a, b, c, n, alpha, beta, r);
#elif NK_TARGET_NEON
    nk_each_fma_i32_neon(a, b, c, n, alpha, beta, r);
#elif NK_TARGET_RVV
    nk_each_fma_i32_rvv(a, b, c, n, alpha, beta, r);
#else
    nk_each_fma_i32_serial(a, b, c, n, alpha, beta, r);
#endif
}

NK_PUBLIC void nk_each_fma_u32(nk_u32_t const *a, nk_u32_t const *b, nk_u32_t const *c, nk_size_t n,
                               nk_f64_t const *alpha, nk_f64_t const *beta, nk_u32_t *r) {
#if NK_TARGET_SKYLAKE
    nk_each_fma_u32_skylake(a, b, c, n, alpha, beta, r);
#elif NK_TARGET_HASWELL
    nk_each_fma_u32_haswell(a, b, c, n, alpha, beta, r);
#elif NK_TARGET_NEON
    nk_each_fma_u32_neon(a, b, c, n, alpha, beta, r);
#elif NK_TARGET_RVV
    nk_each_fma_u32_rvv(a, b, c, n, alpha, beta, r);
#else
    nk_each_fma_u32_serial(a, b, c, n, alpha, beta, r);
#endif
}

NK_PUBLIC void nk_each_fma_i64(nk_i64_t const *a, nk_i64_t const *b, nk_i64_t const *c, nk_size_t n,
                               nk_f64_t const *alpha, nk_f64_t const *beta, nk_i64_t *r) {
#if NK_TARGET_SKYLAKE
    nk_each_fma_i64_skylake(a, b, c, n, alpha, beta, r);
#elif NK_TARGET_NEON
    nk_each_fma_i64_neon(a, b, c, n, alpha, beta, r);
#elif NK_TARGET_RVV
    nk_each_fma_i64_rvv(a, b, c, n, alpha, beta, r);
#else
    nk_each_fma_i64_serial(a, b, c, n, alpha, beta, r);
#endif
}

NK_PUBLIC void nk_each_fma_u64(nk_u64_t const *a, nk_u64_t const *b, nk_u64_t const *c, nk_size_t n,
                               nk_f64_t const *alpha, nk_f64_t const *beta, nk_u64_t *r) {
#if NK_TARGET_SKYLAKE
    nk_each_fma_u64_skylake(a, b, c, n, alpha, beta, r);
#elif NK_TARGET_NEON
    nk_each_fma_u64_neon(a, b, c, n, alpha, beta, r);
#elif NK_TARGET_RVV
    nk_each_fma_u64_rvv(a, b, c, n, alpha, beta, r);
#else
    nk_each_fma_u64_serial(a, b, c, n, alpha, beta, r);
#endif
}
1883
+
1884
/*  Compile-time dispatch for the 8-bit float (fp8) formats e4m3 and e5m2.
 *  NOTE(review): only sum_e4m3 has a SAPPHIRE tier; the remaining fp8
 *  dispatchers top out at SKYLAKE — presumably only that one sapphire
 *  kernel exists; verify against the kernel headers.  */

NK_PUBLIC void nk_each_sum_e4m3(nk_e4m3_t const *a, nk_e4m3_t const *b, nk_size_t n, nk_e4m3_t *result) {
#if NK_TARGET_SAPPHIRE
    nk_each_sum_e4m3_sapphire(a, b, n, result);
#elif NK_TARGET_SKYLAKE
    nk_each_sum_e4m3_skylake(a, b, n, result);
#elif NK_TARGET_HASWELL
    nk_each_sum_e4m3_haswell(a, b, n, result);
#elif NK_TARGET_NEON
    nk_each_sum_e4m3_neon(a, b, n, result);
#elif NK_TARGET_RVV
    nk_each_sum_e4m3_rvv(a, b, n, result);
#else
    nk_each_sum_e4m3_serial(a, b, n, result);
#endif
}

NK_PUBLIC void nk_each_sum_e5m2(nk_e5m2_t const *a, nk_e5m2_t const *b, nk_size_t n, nk_e5m2_t *result) {
#if NK_TARGET_SKYLAKE
    nk_each_sum_e5m2_skylake(a, b, n, result);
#elif NK_TARGET_HASWELL
    nk_each_sum_e5m2_haswell(a, b, n, result);
#elif NK_TARGET_NEON
    nk_each_sum_e5m2_neon(a, b, n, result);
#elif NK_TARGET_RVV
    nk_each_sum_e5m2_rvv(a, b, n, result);
#else
    nk_each_sum_e5m2_serial(a, b, n, result);
#endif
}

NK_PUBLIC void nk_each_scale_e4m3(nk_e4m3_t const *a, nk_size_t n, nk_f32_t const *alpha, nk_f32_t const *beta,
                                  nk_e4m3_t *result) {
#if NK_TARGET_SKYLAKE
    nk_each_scale_e4m3_skylake(a, n, alpha, beta, result);
#elif NK_TARGET_HASWELL
    nk_each_scale_e4m3_haswell(a, n, alpha, beta, result);
#elif NK_TARGET_NEON
    nk_each_scale_e4m3_neon(a, n, alpha, beta, result);
#elif NK_TARGET_RVV
    nk_each_scale_e4m3_rvv(a, n, alpha, beta, result);
#else
    nk_each_scale_e4m3_serial(a, n, alpha, beta, result);
#endif
}

NK_PUBLIC void nk_each_scale_e5m2(nk_e5m2_t const *a, nk_size_t n, nk_f32_t const *alpha, nk_f32_t const *beta,
                                  nk_e5m2_t *result) {
#if NK_TARGET_SKYLAKE
    nk_each_scale_e5m2_skylake(a, n, alpha, beta, result);
#elif NK_TARGET_HASWELL
    nk_each_scale_e5m2_haswell(a, n, alpha, beta, result);
#elif NK_TARGET_NEON
    nk_each_scale_e5m2_neon(a, n, alpha, beta, result);
#elif NK_TARGET_RVV
    nk_each_scale_e5m2_rvv(a, n, alpha, beta, result);
#else
    nk_each_scale_e5m2_serial(a, n, alpha, beta, result);
#endif
}

NK_PUBLIC void nk_each_blend_e4m3(nk_e4m3_t const *a, nk_e4m3_t const *b, nk_size_t n, nk_f32_t const *alpha,
                                  nk_f32_t const *beta, nk_e4m3_t *result) {
#if NK_TARGET_SKYLAKE
    nk_each_blend_e4m3_skylake(a, b, n, alpha, beta, result);
#elif NK_TARGET_HASWELL
    nk_each_blend_e4m3_haswell(a, b, n, alpha, beta, result);
#elif NK_TARGET_NEON
    nk_each_blend_e4m3_neon(a, b, n, alpha, beta, result);
#elif NK_TARGET_RVV
    nk_each_blend_e4m3_rvv(a, b, n, alpha, beta, result);
#else
    nk_each_blend_e4m3_serial(a, b, n, alpha, beta, result);
#endif
}

NK_PUBLIC void nk_each_blend_e5m2(nk_e5m2_t const *a, nk_e5m2_t const *b, nk_size_t n, nk_f32_t const *alpha,
                                  nk_f32_t const *beta, nk_e5m2_t *result) {
#if NK_TARGET_SKYLAKE
    nk_each_blend_e5m2_skylake(a, b, n, alpha, beta, result);
#elif NK_TARGET_HASWELL
    nk_each_blend_e5m2_haswell(a, b, n, alpha, beta, result);
#elif NK_TARGET_NEON
    nk_each_blend_e5m2_neon(a, b, n, alpha, beta, result);
#elif NK_TARGET_RVV
    nk_each_blend_e5m2_rvv(a, b, n, alpha, beta, result);
#else
    nk_each_blend_e5m2_serial(a, b, n, alpha, beta, result);
#endif
}

NK_PUBLIC void nk_each_fma_e4m3(nk_e4m3_t const *a, nk_e4m3_t const *b, nk_e4m3_t const *c, nk_size_t n,
                                nk_f32_t const *alpha, nk_f32_t const *beta, nk_e4m3_t *result) {
#if NK_TARGET_SKYLAKE
    nk_each_fma_e4m3_skylake(a, b, c, n, alpha, beta, result);
#elif NK_TARGET_HASWELL
    nk_each_fma_e4m3_haswell(a, b, c, n, alpha, beta, result);
#elif NK_TARGET_NEON
    nk_each_fma_e4m3_neon(a, b, c, n, alpha, beta, result);
#elif NK_TARGET_RVV
    nk_each_fma_e4m3_rvv(a, b, c, n, alpha, beta, result);
#else
    nk_each_fma_e4m3_serial(a, b, c, n, alpha, beta, result);
#endif
}

NK_PUBLIC void nk_each_fma_e5m2(nk_e5m2_t const *a, nk_e5m2_t const *b, nk_e5m2_t const *c, nk_size_t n,
                                nk_f32_t const *alpha, nk_f32_t const *beta, nk_e5m2_t *result) {
#if NK_TARGET_SKYLAKE
    nk_each_fma_e5m2_skylake(a, b, c, n, alpha, beta, result);
#elif NK_TARGET_HASWELL
    nk_each_fma_e5m2_haswell(a, b, c, n, alpha, beta, result);
#elif NK_TARGET_NEON
    nk_each_fma_e5m2_neon(a, b, c, n, alpha, beta, result);
#elif NK_TARGET_RVV
    nk_each_fma_e5m2_rvv(a, b, c, n, alpha, beta, result);
#else
    nk_each_fma_e5m2_serial(a, b, c, n, alpha, beta, result);
#endif
}
2003
+
2004
/*  6-bit float (e2m3 / e3m2) element-wise operations: serial-only for now —
 *  no SIMD kernels are wired up for these formats, so every public entry
 *  forwards straight to the `_serial` implementation.  */

NK_PUBLIC void nk_each_sum_e2m3(nk_e2m3_t const *a, nk_e2m3_t const *b, nk_size_t n, nk_e2m3_t *result) {
    nk_each_sum_e2m3_serial(a, b, n, result);
}

NK_PUBLIC void nk_each_sum_e3m2(nk_e3m2_t const *a, nk_e3m2_t const *b, nk_size_t n, nk_e3m2_t *result) {
    nk_each_sum_e3m2_serial(a, b, n, result);
}

NK_PUBLIC void nk_each_scale_e2m3(nk_e2m3_t const *a, nk_size_t n, nk_f32_t const *alpha, nk_f32_t const *beta,
                                  nk_e2m3_t *result) {
    nk_each_scale_e2m3_serial(a, n, alpha, beta, result);
}

NK_PUBLIC void nk_each_scale_e3m2(nk_e3m2_t const *a, nk_size_t n, nk_f32_t const *alpha, nk_f32_t const *beta,
                                  nk_e3m2_t *result) {
    nk_each_scale_e3m2_serial(a, n, alpha, beta, result);
}

NK_PUBLIC void nk_each_blend_e2m3(nk_e2m3_t const *a, nk_e2m3_t const *b, nk_size_t n, nk_f32_t const *alpha,
                                  nk_f32_t const *beta, nk_e2m3_t *result) {
    nk_each_blend_e2m3_serial(a, b, n, alpha, beta, result);
}

NK_PUBLIC void nk_each_blend_e3m2(nk_e3m2_t const *a, nk_e3m2_t const *b, nk_size_t n, nk_f32_t const *alpha,
                                  nk_f32_t const *beta, nk_e3m2_t *result) {
    nk_each_blend_e3m2_serial(a, b, n, alpha, beta, result);
}

NK_PUBLIC void nk_each_fma_e2m3(nk_e2m3_t const *a, nk_e2m3_t const *b, nk_e2m3_t const *c, nk_size_t n,
                                nk_f32_t const *alpha, nk_f32_t const *beta, nk_e2m3_t *result) {
    nk_each_fma_e2m3_serial(a, b, c, n, alpha, beta, result);
}

NK_PUBLIC void nk_each_fma_e3m2(nk_e3m2_t const *a, nk_e3m2_t const *b, nk_e3m2_t const *c, nk_size_t n,
                                nk_f32_t const *alpha, nk_f32_t const *beta, nk_e3m2_t *result) {
    nk_each_fma_e3m2_serial(a, b, c, n, alpha, beta, result);
}
2041
+
2042
+ NK_PUBLIC void nk_each_sum_f32c(nk_f32c_t const *a, nk_f32c_t const *b, nk_size_t n, nk_f32c_t *r) {
2043
+ nk_each_sum_f32c_serial(a, b, n, r);
2044
+ }
2045
+
2046
+ NK_PUBLIC void nk_each_sum_f64c(nk_f64c_t const *a, nk_f64c_t const *b, nk_size_t n, nk_f64c_t *r) {
2047
+ nk_each_sum_f64c_serial(a, b, n, r);
2048
+ }
2049
+
2050
+ NK_PUBLIC void nk_each_scale_f32c(nk_f32c_t const *a, nk_size_t n, nk_f32c_t const *alpha, nk_f32c_t const *beta,
2051
+ nk_f32c_t *r) {
2052
+ #if NK_TARGET_SKYLAKE
2053
+ nk_each_scale_f32c_skylake(a, n, alpha, beta, r);
2054
+ #elif NK_TARGET_HASWELL
2055
+ nk_each_scale_f32c_haswell(a, n, alpha, beta, r);
2056
+ #elif NK_TARGET_NEON
2057
+ nk_each_scale_f32c_neon(a, n, alpha, beta, r);
2058
+ #elif NK_TARGET_RVV
2059
+ nk_each_scale_f32c_rvv(a, n, alpha, beta, r);
2060
+ #else
2061
+ nk_each_scale_f32c_serial(a, n, alpha, beta, r);
2062
+ #endif
2063
+ }
2064
+
2065
+ NK_PUBLIC void nk_each_scale_f64c(nk_f64c_t const *a, nk_size_t n, nk_f64c_t const *alpha, nk_f64c_t const *beta,
2066
+ nk_f64c_t *r) {
2067
+ #if NK_TARGET_SKYLAKE
2068
+ nk_each_scale_f64c_skylake(a, n, alpha, beta, r);
2069
+ #elif NK_TARGET_HASWELL
2070
+ nk_each_scale_f64c_haswell(a, n, alpha, beta, r);
2071
+ #elif NK_TARGET_NEON
2072
+ nk_each_scale_f64c_neon(a, n, alpha, beta, r);
2073
+ #elif NK_TARGET_RVV
2074
+ nk_each_scale_f64c_rvv(a, n, alpha, beta, r);
2075
+ #else
2076
+ nk_each_scale_f64c_serial(a, n, alpha, beta, r);
2077
+ #endif
2078
+ }
2079
+
2080
/*  Blend two single-precision complex vectors of length n into r with complex
 *  weights (presumably r[i] = alpha * a[i] + beta * b[i] — TODO confirm against
 *  the serial kernel).
 *  Static dispatch; #elif order is backend priority:
 *  Skylake (AVX-512) > Haswell (AVX2) > NEON > RVV > portable serial fallback.
 */
NK_PUBLIC void nk_each_blend_f32c(nk_f32c_t const *a, nk_f32c_t const *b, nk_size_t n, nk_f32c_t const *alpha,
                                  nk_f32c_t const *beta, nk_f32c_t *r) {
#if NK_TARGET_SKYLAKE
    nk_each_blend_f32c_skylake(a, b, n, alpha, beta, r);
#elif NK_TARGET_HASWELL
    nk_each_blend_f32c_haswell(a, b, n, alpha, beta, r);
#elif NK_TARGET_NEON
    nk_each_blend_f32c_neon(a, b, n, alpha, beta, r);
#elif NK_TARGET_RVV
    nk_each_blend_f32c_rvv(a, b, n, alpha, beta, r);
#else
    nk_each_blend_f32c_serial(a, b, n, alpha, beta, r);
#endif
}
2094
+
2095
/*  Double-precision complex counterpart of nk_each_blend_f32c
 *  (presumably r[i] = alpha * a[i] + beta * b[i] — TODO confirm against the
 *  serial kernel).
 *  Static dispatch; #elif order is backend priority:
 *  Skylake (AVX-512) > Haswell (AVX2) > NEON > RVV > portable serial fallback.
 */
NK_PUBLIC void nk_each_blend_f64c(nk_f64c_t const *a, nk_f64c_t const *b, nk_size_t n, nk_f64c_t const *alpha,
                                  nk_f64c_t const *beta, nk_f64c_t *r) {
#if NK_TARGET_SKYLAKE
    nk_each_blend_f64c_skylake(a, b, n, alpha, beta, r);
#elif NK_TARGET_HASWELL
    nk_each_blend_f64c_haswell(a, b, n, alpha, beta, r);
#elif NK_TARGET_NEON
    nk_each_blend_f64c_neon(a, b, n, alpha, beta, r);
#elif NK_TARGET_RVV
    nk_each_blend_f64c_rvv(a, b, n, alpha, beta, r);
#else
    nk_each_blend_f64c_serial(a, b, n, alpha, beta, r);
#endif
}
2109
+
2110
/*  Fused multiply-add over single-precision complex vectors of length n
 *  (presumably r[i] = alpha * a[i] * b[i] + beta * c[i] — TODO confirm against
 *  the serial kernel).
 *  Static dispatch; #elif order is backend priority:
 *  Skylake (AVX-512) > Haswell (AVX2) > NEON > RVV > portable serial fallback.
 */
NK_PUBLIC void nk_each_fma_f32c(nk_f32c_t const *a, nk_f32c_t const *b, nk_f32c_t const *c, nk_size_t n,
                                nk_f32c_t const *alpha, nk_f32c_t const *beta, nk_f32c_t *r) {
#if NK_TARGET_SKYLAKE
    nk_each_fma_f32c_skylake(a, b, c, n, alpha, beta, r);
#elif NK_TARGET_HASWELL
    nk_each_fma_f32c_haswell(a, b, c, n, alpha, beta, r);
#elif NK_TARGET_NEON
    nk_each_fma_f32c_neon(a, b, c, n, alpha, beta, r);
#elif NK_TARGET_RVV
    nk_each_fma_f32c_rvv(a, b, c, n, alpha, beta, r);
#else
    nk_each_fma_f32c_serial(a, b, c, n, alpha, beta, r);
#endif
}
2124
+
2125
/*  Double-precision complex counterpart of nk_each_fma_f32c
 *  (presumably r[i] = alpha * a[i] * b[i] + beta * c[i] — TODO confirm against
 *  the serial kernel).
 *  Static dispatch; #elif order is backend priority:
 *  Skylake (AVX-512) > Haswell (AVX2) > NEON > RVV > portable serial fallback.
 */
NK_PUBLIC void nk_each_fma_f64c(nk_f64c_t const *a, nk_f64c_t const *b, nk_f64c_t const *c, nk_size_t n,
                                nk_f64c_t const *alpha, nk_f64c_t const *beta, nk_f64c_t *r) {
#if NK_TARGET_SKYLAKE
    nk_each_fma_f64c_skylake(a, b, c, n, alpha, beta, r);
#elif NK_TARGET_HASWELL
    nk_each_fma_f64c_haswell(a, b, c, n, alpha, beta, r);
#elif NK_TARGET_NEON
    nk_each_fma_f64c_neon(a, b, c, n, alpha, beta, r);
#elif NK_TARGET_RVV
    nk_each_fma_f64c_rvv(a, b, c, n, alpha, beta, r);
#else
    nk_each_fma_f64c_serial(a, b, c, n, alpha, beta, r);
#endif
}
2139
+
2140
+ #endif // !NK_DYNAMIC_DISPATCH
2141
+
2142
+ #if defined(__cplusplus)
2143
+ } // extern "C"
2144
+ #endif
2145
+
2146
+ #endif // NK_EACH_H