numkong 7.0.0 → 7.4.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (315) hide show
  1. package/README.md +239 -122
  2. package/binding.gyp +25 -491
  3. package/c/dispatch_bf16.c +59 -1
  4. package/c/dispatch_e2m3.c +41 -8
  5. package/c/dispatch_e3m2.c +49 -8
  6. package/c/dispatch_e4m3.c +51 -9
  7. package/c/dispatch_e5m2.c +45 -1
  8. package/c/dispatch_f16.c +79 -26
  9. package/c/dispatch_f16c.c +5 -5
  10. package/c/dispatch_f32.c +56 -0
  11. package/c/dispatch_f64.c +52 -0
  12. package/c/dispatch_i4.c +3 -0
  13. package/c/dispatch_i8.c +62 -3
  14. package/c/dispatch_other.c +18 -0
  15. package/c/dispatch_u1.c +54 -9
  16. package/c/dispatch_u4.c +3 -0
  17. package/c/dispatch_u8.c +64 -3
  18. package/c/numkong.c +3 -0
  19. package/include/README.md +79 -9
  20. package/include/numkong/attention/sapphireamx.h +278 -276
  21. package/include/numkong/attention/sme.h +983 -977
  22. package/include/numkong/attention.h +1 -1
  23. package/include/numkong/capabilities.h +289 -94
  24. package/include/numkong/cast/README.md +40 -40
  25. package/include/numkong/cast/diamond.h +64 -0
  26. package/include/numkong/cast/haswell.h +42 -194
  27. package/include/numkong/cast/icelake.h +42 -37
  28. package/include/numkong/cast/loongsonasx.h +252 -0
  29. package/include/numkong/cast/neon.h +216 -249
  30. package/include/numkong/cast/powervsx.h +449 -0
  31. package/include/numkong/cast/rvv.h +223 -274
  32. package/include/numkong/cast/sapphire.h +18 -18
  33. package/include/numkong/cast/serial.h +1018 -944
  34. package/include/numkong/cast/skylake.h +82 -23
  35. package/include/numkong/cast/v128relaxed.h +462 -105
  36. package/include/numkong/cast.h +24 -0
  37. package/include/numkong/cast.hpp +44 -0
  38. package/include/numkong/curved/README.md +17 -17
  39. package/include/numkong/curved/neon.h +131 -7
  40. package/include/numkong/curved/neonbfdot.h +6 -7
  41. package/include/numkong/curved/rvv.h +26 -26
  42. package/include/numkong/curved/smef64.h +186 -182
  43. package/include/numkong/curved.h +14 -18
  44. package/include/numkong/dot/README.md +154 -137
  45. package/include/numkong/dot/alder.h +43 -43
  46. package/include/numkong/dot/diamond.h +158 -0
  47. package/include/numkong/dot/genoa.h +4 -30
  48. package/include/numkong/dot/haswell.h +215 -180
  49. package/include/numkong/dot/icelake.h +190 -76
  50. package/include/numkong/dot/loongsonasx.h +671 -0
  51. package/include/numkong/dot/neon.h +124 -73
  52. package/include/numkong/dot/neonbfdot.h +11 -12
  53. package/include/numkong/dot/neonfhm.h +44 -46
  54. package/include/numkong/dot/neonfp8.h +323 -0
  55. package/include/numkong/dot/neonsdot.h +190 -76
  56. package/include/numkong/dot/powervsx.h +752 -0
  57. package/include/numkong/dot/rvv.h +92 -84
  58. package/include/numkong/dot/rvvbf16.h +12 -12
  59. package/include/numkong/dot/rvvhalf.h +12 -12
  60. package/include/numkong/dot/sapphire.h +4 -4
  61. package/include/numkong/dot/serial.h +66 -30
  62. package/include/numkong/dot/sierra.h +31 -31
  63. package/include/numkong/dot/skylake.h +142 -110
  64. package/include/numkong/dot/sve.h +217 -177
  65. package/include/numkong/dot/svebfdot.h +10 -10
  66. package/include/numkong/dot/svehalf.h +85 -41
  67. package/include/numkong/dot/svesdot.h +89 -0
  68. package/include/numkong/dot/v128relaxed.h +124 -89
  69. package/include/numkong/dot.h +114 -48
  70. package/include/numkong/dots/README.md +203 -203
  71. package/include/numkong/dots/alder.h +12 -9
  72. package/include/numkong/dots/diamond.h +86 -0
  73. package/include/numkong/dots/genoa.h +10 -4
  74. package/include/numkong/dots/haswell.h +63 -48
  75. package/include/numkong/dots/icelake.h +27 -18
  76. package/include/numkong/dots/loongsonasx.h +176 -0
  77. package/include/numkong/dots/neon.h +14 -11
  78. package/include/numkong/dots/neonbfdot.h +4 -3
  79. package/include/numkong/dots/neonfhm.h +11 -9
  80. package/include/numkong/dots/neonfp8.h +99 -0
  81. package/include/numkong/dots/neonsdot.h +48 -12
  82. package/include/numkong/dots/powervsx.h +194 -0
  83. package/include/numkong/dots/rvv.h +451 -344
  84. package/include/numkong/dots/sapphireamx.h +1028 -984
  85. package/include/numkong/dots/serial.h +213 -197
  86. package/include/numkong/dots/sierra.h +10 -7
  87. package/include/numkong/dots/skylake.h +47 -36
  88. package/include/numkong/dots/sme.h +2001 -2364
  89. package/include/numkong/dots/smebi32.h +175 -162
  90. package/include/numkong/dots/smef64.h +328 -323
  91. package/include/numkong/dots/v128relaxed.h +64 -41
  92. package/include/numkong/dots.h +573 -293
  93. package/include/numkong/dots.hpp +45 -43
  94. package/include/numkong/each/README.md +133 -137
  95. package/include/numkong/each/haswell.h +6 -6
  96. package/include/numkong/each/icelake.h +7 -7
  97. package/include/numkong/each/neon.h +76 -42
  98. package/include/numkong/each/neonbfdot.h +11 -12
  99. package/include/numkong/each/neonhalf.h +24 -116
  100. package/include/numkong/each/rvv.h +28 -28
  101. package/include/numkong/each/sapphire.h +27 -161
  102. package/include/numkong/each/serial.h +6 -6
  103. package/include/numkong/each/skylake.h +7 -7
  104. package/include/numkong/each/v128relaxed.h +562 -0
  105. package/include/numkong/each.h +148 -62
  106. package/include/numkong/each.hpp +2 -2
  107. package/include/numkong/geospatial/README.md +18 -18
  108. package/include/numkong/geospatial/haswell.h +365 -325
  109. package/include/numkong/geospatial/neon.h +350 -306
  110. package/include/numkong/geospatial/rvv.h +4 -4
  111. package/include/numkong/geospatial/skylake.h +376 -340
  112. package/include/numkong/geospatial/v128relaxed.h +366 -327
  113. package/include/numkong/geospatial.h +17 -17
  114. package/include/numkong/matrix.hpp +4 -4
  115. package/include/numkong/maxsim/README.md +14 -14
  116. package/include/numkong/maxsim/alder.h +6 -6
  117. package/include/numkong/maxsim/genoa.h +4 -4
  118. package/include/numkong/maxsim/haswell.h +6 -6
  119. package/include/numkong/maxsim/icelake.h +18 -18
  120. package/include/numkong/maxsim/neonsdot.h +21 -21
  121. package/include/numkong/maxsim/sapphireamx.h +14 -14
  122. package/include/numkong/maxsim/serial.h +6 -6
  123. package/include/numkong/maxsim/sme.h +221 -196
  124. package/include/numkong/maxsim/v128relaxed.h +6 -6
  125. package/include/numkong/mesh/README.md +62 -56
  126. package/include/numkong/mesh/haswell.h +339 -464
  127. package/include/numkong/mesh/neon.h +1100 -519
  128. package/include/numkong/mesh/neonbfdot.h +36 -68
  129. package/include/numkong/mesh/rvv.h +530 -435
  130. package/include/numkong/mesh/serial.h +75 -91
  131. package/include/numkong/mesh/skylake.h +1627 -302
  132. package/include/numkong/mesh/v128relaxed.h +443 -330
  133. package/include/numkong/mesh.h +63 -49
  134. package/include/numkong/mesh.hpp +4 -4
  135. package/include/numkong/numkong.h +3 -3
  136. package/include/numkong/numkong.hpp +1 -0
  137. package/include/numkong/probability/README.md +23 -19
  138. package/include/numkong/probability/neon.h +82 -52
  139. package/include/numkong/probability/rvv.h +28 -23
  140. package/include/numkong/probability/serial.h +51 -39
  141. package/include/numkong/probability.h +20 -23
  142. package/include/numkong/random.h +1 -1
  143. package/include/numkong/reduce/README.md +143 -138
  144. package/include/numkong/reduce/alder.h +81 -77
  145. package/include/numkong/reduce/haswell.h +222 -220
  146. package/include/numkong/reduce/neon.h +629 -519
  147. package/include/numkong/reduce/neonbfdot.h +7 -218
  148. package/include/numkong/reduce/neonfhm.h +9 -381
  149. package/include/numkong/reduce/neonsdot.h +9 -9
  150. package/include/numkong/reduce/rvv.h +928 -802
  151. package/include/numkong/reduce/serial.h +23 -27
  152. package/include/numkong/reduce/sierra.h +20 -20
  153. package/include/numkong/reduce/skylake.h +326 -324
  154. package/include/numkong/reduce/v128relaxed.h +52 -52
  155. package/include/numkong/reduce.h +4 -23
  156. package/include/numkong/reduce.hpp +156 -11
  157. package/include/numkong/scalar/README.md +6 -6
  158. package/include/numkong/scalar/haswell.h +26 -17
  159. package/include/numkong/scalar/loongsonasx.h +74 -0
  160. package/include/numkong/scalar/neon.h +9 -9
  161. package/include/numkong/scalar/powervsx.h +96 -0
  162. package/include/numkong/scalar/rvv.h +2 -2
  163. package/include/numkong/scalar/sapphire.h +21 -10
  164. package/include/numkong/scalar/serial.h +21 -21
  165. package/include/numkong/scalar.h +13 -0
  166. package/include/numkong/set/README.md +28 -28
  167. package/include/numkong/set/haswell.h +12 -12
  168. package/include/numkong/set/icelake.h +14 -14
  169. package/include/numkong/set/loongsonasx.h +181 -0
  170. package/include/numkong/set/neon.h +17 -18
  171. package/include/numkong/set/powervsx.h +326 -0
  172. package/include/numkong/set/rvv.h +4 -4
  173. package/include/numkong/set/serial.h +6 -6
  174. package/include/numkong/set/sve.h +60 -59
  175. package/include/numkong/set/v128relaxed.h +6 -6
  176. package/include/numkong/set.h +21 -7
  177. package/include/numkong/sets/README.md +26 -26
  178. package/include/numkong/sets/loongsonasx.h +52 -0
  179. package/include/numkong/sets/powervsx.h +65 -0
  180. package/include/numkong/sets/smebi32.h +395 -364
  181. package/include/numkong/sets.h +83 -40
  182. package/include/numkong/sparse/README.md +4 -4
  183. package/include/numkong/sparse/icelake.h +101 -101
  184. package/include/numkong/sparse/serial.h +1 -1
  185. package/include/numkong/sparse/sve2.h +137 -141
  186. package/include/numkong/sparse/turin.h +12 -12
  187. package/include/numkong/sparse.h +10 -10
  188. package/include/numkong/spatial/README.md +230 -226
  189. package/include/numkong/spatial/alder.h +113 -116
  190. package/include/numkong/spatial/diamond.h +240 -0
  191. package/include/numkong/spatial/genoa.h +0 -68
  192. package/include/numkong/spatial/haswell.h +74 -55
  193. package/include/numkong/spatial/icelake.h +539 -58
  194. package/include/numkong/spatial/loongsonasx.h +483 -0
  195. package/include/numkong/spatial/neon.h +125 -52
  196. package/include/numkong/spatial/neonbfdot.h +8 -9
  197. package/include/numkong/spatial/neonfp8.h +258 -0
  198. package/include/numkong/spatial/neonsdot.h +180 -12
  199. package/include/numkong/spatial/powervsx.h +738 -0
  200. package/include/numkong/spatial/rvv.h +146 -139
  201. package/include/numkong/spatial/rvvbf16.h +17 -12
  202. package/include/numkong/spatial/rvvhalf.h +13 -10
  203. package/include/numkong/spatial/serial.h +13 -12
  204. package/include/numkong/spatial/sierra.h +232 -39
  205. package/include/numkong/spatial/skylake.h +73 -74
  206. package/include/numkong/spatial/sve.h +93 -72
  207. package/include/numkong/spatial/svebfdot.h +29 -29
  208. package/include/numkong/spatial/svehalf.h +52 -26
  209. package/include/numkong/spatial/svesdot.h +142 -0
  210. package/include/numkong/spatial/v128relaxed.h +293 -41
  211. package/include/numkong/spatial.h +338 -82
  212. package/include/numkong/spatials/README.md +194 -194
  213. package/include/numkong/spatials/diamond.h +82 -0
  214. package/include/numkong/spatials/haswell.h +2 -2
  215. package/include/numkong/spatials/loongsonasx.h +153 -0
  216. package/include/numkong/spatials/neonfp8.h +111 -0
  217. package/include/numkong/spatials/neonsdot.h +34 -0
  218. package/include/numkong/spatials/powervsx.h +153 -0
  219. package/include/numkong/spatials/rvv.h +259 -243
  220. package/include/numkong/spatials/sapphireamx.h +173 -173
  221. package/include/numkong/spatials/serial.h +2 -2
  222. package/include/numkong/spatials/skylake.h +2 -2
  223. package/include/numkong/spatials/sme.h +590 -605
  224. package/include/numkong/spatials/smef64.h +139 -130
  225. package/include/numkong/spatials/v128relaxed.h +2 -2
  226. package/include/numkong/spatials.h +820 -500
  227. package/include/numkong/spatials.hpp +49 -48
  228. package/include/numkong/tensor.hpp +406 -17
  229. package/include/numkong/trigonometry/README.md +19 -19
  230. package/include/numkong/trigonometry/haswell.h +402 -401
  231. package/include/numkong/trigonometry/neon.h +386 -387
  232. package/include/numkong/trigonometry/rvv.h +52 -51
  233. package/include/numkong/trigonometry/serial.h +13 -13
  234. package/include/numkong/trigonometry/skylake.h +373 -369
  235. package/include/numkong/trigonometry/v128relaxed.h +375 -374
  236. package/include/numkong/trigonometry.h +13 -13
  237. package/include/numkong/trigonometry.hpp +2 -2
  238. package/include/numkong/types.h +287 -49
  239. package/include/numkong/types.hpp +436 -12
  240. package/include/numkong/vector.hpp +82 -14
  241. package/javascript/dist/cjs/numkong-wasm.js +6 -12
  242. package/javascript/dist/cjs/numkong.d.ts +7 -1
  243. package/javascript/dist/cjs/numkong.js +37 -11
  244. package/javascript/dist/cjs/types.d.ts +9 -0
  245. package/javascript/dist/cjs/types.js +96 -0
  246. package/javascript/dist/esm/numkong-browser.d.ts +14 -0
  247. package/javascript/dist/esm/numkong-browser.js +23 -0
  248. package/javascript/dist/esm/numkong-wasm.js +6 -12
  249. package/javascript/dist/esm/numkong.d.ts +7 -1
  250. package/javascript/dist/esm/numkong.js +37 -11
  251. package/javascript/dist/esm/types.d.ts +9 -0
  252. package/javascript/dist/esm/types.js +96 -0
  253. package/javascript/node-gyp-build.d.ts +4 -1
  254. package/javascript/numkong-browser.ts +40 -0
  255. package/javascript/numkong-wasm.ts +7 -13
  256. package/javascript/numkong.c +5 -26
  257. package/javascript/numkong.ts +36 -11
  258. package/javascript/tsconfig-base.json +1 -0
  259. package/javascript/tsconfig-cjs.json +6 -1
  260. package/javascript/types.ts +110 -0
  261. package/numkong.gypi +101 -0
  262. package/package.json +34 -13
  263. package/probes/arm_neon.c +8 -0
  264. package/probes/arm_neon_bfdot.c +9 -0
  265. package/probes/arm_neon_fhm.c +9 -0
  266. package/probes/arm_neon_half.c +8 -0
  267. package/probes/arm_neon_sdot.c +9 -0
  268. package/probes/arm_neonfp8.c +9 -0
  269. package/probes/arm_sme.c +16 -0
  270. package/probes/arm_sme2.c +16 -0
  271. package/probes/arm_sme2p1.c +16 -0
  272. package/probes/arm_sme_bf16.c +16 -0
  273. package/probes/arm_sme_bi32.c +16 -0
  274. package/probes/arm_sme_f64.c +16 -0
  275. package/probes/arm_sme_fa64.c +14 -0
  276. package/probes/arm_sme_half.c +16 -0
  277. package/probes/arm_sme_lut2.c +15 -0
  278. package/probes/arm_sve.c +18 -0
  279. package/probes/arm_sve2.c +20 -0
  280. package/probes/arm_sve2p1.c +18 -0
  281. package/probes/arm_sve_bfdot.c +20 -0
  282. package/probes/arm_sve_half.c +18 -0
  283. package/probes/arm_sve_sdot.c +21 -0
  284. package/probes/loongarch_lasx.c +12 -0
  285. package/probes/power_vsx.c +12 -0
  286. package/probes/probe.js +127 -0
  287. package/probes/riscv_rvv.c +14 -0
  288. package/probes/riscv_rvv_bb.c +15 -0
  289. package/probes/riscv_rvv_bf16.c +17 -0
  290. package/probes/riscv_rvv_half.c +14 -0
  291. package/probes/wasm_v128relaxed.c +11 -0
  292. package/probes/x86_alder.c +17 -0
  293. package/probes/x86_diamond.c +17 -0
  294. package/probes/x86_genoa.c +17 -0
  295. package/probes/x86_graniteamx.c +19 -0
  296. package/probes/x86_haswell.c +11 -0
  297. package/probes/x86_icelake.c +17 -0
  298. package/probes/x86_sapphire.c +16 -0
  299. package/probes/x86_sapphireamx.c +18 -0
  300. package/probes/x86_sierra.c +17 -0
  301. package/probes/x86_skylake.c +15 -0
  302. package/probes/x86_turin.c +17 -0
  303. package/wasm/numkong-emscripten.js +2 -0
  304. package/wasm/numkong.d.ts +14 -0
  305. package/wasm/numkong.js +1124 -0
  306. package/wasm/numkong.wasm +0 -0
  307. package/include/numkong/curved/neonhalf.h +0 -212
  308. package/include/numkong/dot/neonhalf.h +0 -198
  309. package/include/numkong/dots/neonhalf.h +0 -57
  310. package/include/numkong/mesh/neonhalf.h +0 -616
  311. package/include/numkong/reduce/neonhalf.h +0 -157
  312. package/include/numkong/spatial/neonhalf.h +0 -118
  313. package/include/numkong/spatial/sapphire.h +0 -343
  314. package/include/numkong/spatials/neonhalf.h +0 -58
  315. package/javascript/README.md +0 -246
@@ -8,20 +8,19 @@
8
8
  *
9
9
  * @section elementwise_neon_instructions ARM NEON Instructions
10
10
  *
11
- * Intrinsic Instruction Latency Throughput
12
- * A76 M4+/V1+/Oryon
13
- * vld1q_f32 LD1 (V.4S) 4cy 2/cy 2/cy
14
- * vst1q_f32 ST1 (V.4S) 2cy 2/cy 2/cy
15
- * vaddq_f32 FADD (V.4S, V.4S, V.4S) 2cy 2/cy 4/cy
16
- * vmulq_f32 FMUL (V.4S, V.4S, V.4S) 3cy 2/cy 4/cy
17
- * vfmaq_f32 FMLA (V.4S, V.4S, V.4S) 4cy 2/cy 4/cy
18
- * vaddq_f64 FADD (V.2D, V.2D, V.2D) 2cy 2/cy 4/cy
19
- * vmulq_f64 FMUL (V.2D, V.2D, V.2D) 3cy 2/cy 4/cy
20
- * vfmaq_f64 FMLA (V.2D, V.2D, V.2D) 4cy 2/cy 4/cy
21
- * vqaddq_s16 SQADD (V.8H, V.8H, V.8H) 2cy 2/cy 4/cy
22
- * vcvtq_f32_s32 SCVTF (V.4S, V.4S) 3cy 2/cy 2/cy
23
- * vcvtnq_s32_f32 FCVTNS (V.4S, V.4S) 3cy 2/cy 2/cy
24
- * vqmovn_s32 SQXTN (V.4H, V.4S) 3cy 2/cy 2/cy
11
+ * Intrinsic Instruction A76 M5
12
+ * vld1q_f32 LD1 (V.4S) 4cy @ 2p 4cy @ 3p
13
+ * vst1q_f32 ST1 (V.4S) 2cy @ 2p 2cy @ 3p
14
+ * vaddq_f32 FADD (V.4S, V.4S, V.4S) 2cy @ 2p 2cy @ 4p
15
+ * vmulq_f32 FMUL (V.4S, V.4S, V.4S) 3cy @ 2p 3cy @ 4p
16
+ * vfmaq_f32 FMLA (V.4S, V.4S, V.4S) 4cy @ 2p 3cy @ 4p
17
+ * vaddq_f64 FADD (V.2D, V.2D, V.2D) 2cy @ 2p 2cy @ 4p
18
+ * vmulq_f64 FMUL (V.2D, V.2D, V.2D) 3cy @ 2p 3cy @ 4p
19
+ * vfmaq_f64 FMLA (V.2D, V.2D, V.2D) 4cy @ 2p 3cy @ 4p
20
+ * vqaddq_s16 SQADD (V.8H, V.8H, V.8H) 2cy @ 2p 3cy @ 2p
21
+ * vcvtq_f32_s32 SCVTF (V.4S, V.4S) 3cy @ 2p 3cy @ 4p
22
+ * vcvtnq_s32_f32 FCVTNS (V.4S, V.4S) 3cy @ 2p 3cy @ 4p
23
+ * vqmovn_s32 SQXTN (V.4H, V.4S) 3cy @ 2p 3cy @ 4p
25
24
  *
26
25
  * Elementwise operations are throughput-bound rather than latency-bound. FP arithmetic
27
26
  * throughput doubles on 4-pipe cores (Apple M4+, Graviton3+, Oryon) from 2/cy to 4/cy.
@@ -37,6 +36,7 @@
37
36
 
38
37
  #include "numkong/types.h"
39
38
  #include "numkong/cast/neon.h"
39
+ #include "numkong/cast/serial.h" // `nk_f32_to_u8_serial`, `nk_f32_to_i8_serial`
40
40
 
41
41
  #if defined(__cplusplus)
42
42
  extern "C" {
@@ -145,10 +145,10 @@ NK_PUBLIC void nk_each_sum_i16_neon(nk_i16_t const *a, nk_i16_t const *b, nk_siz
145
145
  // The main loop:
146
146
  nk_size_t i = 0;
147
147
  for (; i + 8 <= n; i += 8) {
148
- int16x8_t a_s16x8 = vld1q_s16(a + i);
149
- int16x8_t b_s16x8 = vld1q_s16(b + i);
150
- int16x8_t sum_s16x8 = vqaddq_s16(a_s16x8, b_s16x8);
151
- vst1q_s16(result + i, sum_s16x8);
148
+ int16x8_t a_i16x8 = vld1q_s16(a + i);
149
+ int16x8_t b_i16x8 = vld1q_s16(b + i);
150
+ int16x8_t sum_i16x8 = vqaddq_s16(a_i16x8, b_i16x8);
151
+ vst1q_s16(result + i, sum_i16x8);
152
152
  }
153
153
 
154
154
  // The tail:
@@ -291,10 +291,10 @@ NK_PUBLIC void nk_each_sum_i32_neon(nk_i32_t const *a, nk_i32_t const *b, nk_siz
291
291
  // The main loop:
292
292
  nk_size_t i = 0;
293
293
  for (; i + 4 <= n; i += 4) {
294
- int32x4_t a_s32x4 = vld1q_s32(a + i);
295
- int32x4_t b_s32x4 = vld1q_s32(b + i);
296
- int32x4_t sum_s32x4 = vqaddq_s32(a_s32x4, b_s32x4);
297
- vst1q_s32(result + i, sum_s32x4);
294
+ int32x4_t a_i32x4 = vld1q_s32(a + i);
295
+ int32x4_t b_i32x4 = vld1q_s32(b + i);
296
+ int32x4_t sum_i32x4 = vqaddq_s32(a_i32x4, b_i32x4);
297
+ vst1q_s32(result + i, sum_i32x4);
298
298
  }
299
299
 
300
300
  // The tail:
@@ -437,10 +437,10 @@ NK_PUBLIC void nk_each_sum_i64_neon(nk_i64_t const *a, nk_i64_t const *b, nk_siz
437
437
  // The main loop:
438
438
  nk_size_t i = 0;
439
439
  for (; i + 2 <= n; i += 2) {
440
- int64x2_t a_s64x2 = vld1q_s64(a + i);
441
- int64x2_t b_s64x2 = vld1q_s64(b + i);
442
- int64x2_t sum_s64x2 = vqaddq_s64(a_s64x2, b_s64x2);
443
- vst1q_s64(result + i, sum_s64x2);
440
+ int64x2_t a_i64x2 = vld1q_s64(a + i);
441
+ int64x2_t b_i64x2 = vld1q_s64(b + i);
442
+ int64x2_t sum_i64x2 = vqaddq_s64(a_i64x2, b_i64x2);
443
+ vst1q_s64(result + i, sum_i64x2);
444
444
  }
445
445
 
446
446
  // The tail:
@@ -679,9 +679,9 @@ NK_PUBLIC void nk_each_sum_e4m3_neon(nk_e4m3_t const *a, nk_e4m3_t const *b, nk_
679
679
  float16x8_t a_f16x8 = nk_e4m3x8_to_f16x8_neon_(vld1_u8(a + i));
680
680
  float16x8_t b_f16x8 = nk_e4m3x8_to_f16x8_neon_(vld1_u8(b + i));
681
681
  float32x4_t a_low_f32x4 = vcvt_f32_f16(vget_low_f16(a_f16x8));
682
- float32x4_t a_high_f32x4 = vcvt_f32_f16(vget_high_f16(a_f16x8));
682
+ float32x4_t a_high_f32x4 = vcvt_high_f32_f16(a_f16x8);
683
683
  float32x4_t b_low_f32x4 = vcvt_f32_f16(vget_low_f16(b_f16x8));
684
- float32x4_t b_high_f32x4 = vcvt_f32_f16(vget_high_f16(b_f16x8));
684
+ float32x4_t b_high_f32x4 = vcvt_high_f32_f16(b_f16x8);
685
685
  float32x4_t result_low_f32x4 = vaddq_f32(a_low_f32x4, b_low_f32x4);
686
686
  float32x4_t result_high_f32x4 = vaddq_f32(a_high_f32x4, b_high_f32x4);
687
687
  nk_b32_vec_t result_low_vec = nk_f32x4_to_e4m3x4_neon_(result_low_f32x4);
@@ -703,9 +703,9 @@ NK_PUBLIC void nk_each_sum_e5m2_neon(nk_e5m2_t const *a, nk_e5m2_t const *b, nk_
703
703
  float16x8_t a_f16x8 = nk_e5m2x8_to_f16x8_neon_(vld1_u8(a + i));
704
704
  float16x8_t b_f16x8 = nk_e5m2x8_to_f16x8_neon_(vld1_u8(b + i));
705
705
  float32x4_t a_low_f32x4 = vcvt_f32_f16(vget_low_f16(a_f16x8));
706
- float32x4_t a_high_f32x4 = vcvt_f32_f16(vget_high_f16(a_f16x8));
706
+ float32x4_t a_high_f32x4 = vcvt_high_f32_f16(a_f16x8);
707
707
  float32x4_t b_low_f32x4 = vcvt_f32_f16(vget_low_f16(b_f16x8));
708
- float32x4_t b_high_f32x4 = vcvt_f32_f16(vget_high_f16(b_f16x8));
708
+ float32x4_t b_high_f32x4 = vcvt_high_f32_f16(b_f16x8);
709
709
  float32x4_t result_low_f32x4 = vaddq_f32(a_low_f32x4, b_low_f32x4);
710
710
  float32x4_t result_high_f32x4 = vaddq_f32(a_high_f32x4, b_high_f32x4);
711
711
  nk_b32_vec_t result_low_vec = nk_f32x4_to_e5m2x4_neon_(result_low_f32x4);
@@ -729,7 +729,7 @@ NK_PUBLIC void nk_each_scale_e4m3_neon(nk_e4m3_t const *a, nk_size_t n, nk_f32_t
729
729
  for (; i + 8 <= n; i += 8) {
730
730
  float16x8_t a_f16x8 = nk_e4m3x8_to_f16x8_neon_(vld1_u8(a + i));
731
731
  float32x4_t a_low_f32x4 = vcvt_f32_f16(vget_low_f16(a_f16x8));
732
- float32x4_t a_high_f32x4 = vcvt_f32_f16(vget_high_f16(a_f16x8));
732
+ float32x4_t a_high_f32x4 = vcvt_high_f32_f16(a_f16x8);
733
733
  float32x4_t result_low_f32x4 = vfmaq_f32(beta_f32x4, a_low_f32x4, alpha_f32x4);
734
734
  float32x4_t result_high_f32x4 = vfmaq_f32(beta_f32x4, a_high_f32x4, alpha_f32x4);
735
735
  nk_b32_vec_t result_low_vec = nk_f32x4_to_e4m3x4_neon_(result_low_f32x4);
@@ -752,7 +752,7 @@ NK_PUBLIC void nk_each_scale_e5m2_neon(nk_e5m2_t const *a, nk_size_t n, nk_f32_t
752
752
  for (; i + 8 <= n; i += 8) {
753
753
  float16x8_t a_f16x8 = nk_e5m2x8_to_f16x8_neon_(vld1_u8(a + i));
754
754
  float32x4_t a_low_f32x4 = vcvt_f32_f16(vget_low_f16(a_f16x8));
755
- float32x4_t a_high_f32x4 = vcvt_f32_f16(vget_high_f16(a_f16x8));
755
+ float32x4_t a_high_f32x4 = vcvt_high_f32_f16(a_f16x8);
756
756
  float32x4_t result_low_f32x4 = vfmaq_f32(beta_f32x4, a_low_f32x4, alpha_f32x4);
757
757
  float32x4_t result_high_f32x4 = vfmaq_f32(beta_f32x4, a_high_f32x4, alpha_f32x4);
758
758
  nk_b32_vec_t result_low_vec = nk_f32x4_to_e5m2x4_neon_(result_low_f32x4);
@@ -776,9 +776,9 @@ NK_PUBLIC void nk_each_blend_e4m3_neon(nk_e4m3_t const *a, nk_e4m3_t const *b, n
776
776
  float16x8_t a_f16x8 = nk_e4m3x8_to_f16x8_neon_(vld1_u8(a + i));
777
777
  float16x8_t b_f16x8 = nk_e4m3x8_to_f16x8_neon_(vld1_u8(b + i));
778
778
  float32x4_t a_low_f32x4 = vcvt_f32_f16(vget_low_f16(a_f16x8));
779
- float32x4_t a_high_f32x4 = vcvt_f32_f16(vget_high_f16(a_f16x8));
779
+ float32x4_t a_high_f32x4 = vcvt_high_f32_f16(a_f16x8);
780
780
  float32x4_t b_low_f32x4 = vcvt_f32_f16(vget_low_f16(b_f16x8));
781
- float32x4_t b_high_f32x4 = vcvt_f32_f16(vget_high_f16(b_f16x8));
781
+ float32x4_t b_high_f32x4 = vcvt_high_f32_f16(b_f16x8);
782
782
  float32x4_t a_scaled_low_f32x4 = vmulq_f32(a_low_f32x4, alpha_f32x4);
783
783
  float32x4_t a_scaled_high_f32x4 = vmulq_f32(a_high_f32x4, alpha_f32x4);
784
784
  float32x4_t result_low_f32x4 = vfmaq_f32(a_scaled_low_f32x4, b_low_f32x4, beta_f32x4);
@@ -805,9 +805,9 @@ NK_PUBLIC void nk_each_blend_e5m2_neon(nk_e5m2_t const *a, nk_e5m2_t const *b, n
805
805
  float16x8_t a_f16x8 = nk_e5m2x8_to_f16x8_neon_(vld1_u8(a + i));
806
806
  float16x8_t b_f16x8 = nk_e5m2x8_to_f16x8_neon_(vld1_u8(b + i));
807
807
  float32x4_t a_low_f32x4 = vcvt_f32_f16(vget_low_f16(a_f16x8));
808
- float32x4_t a_high_f32x4 = vcvt_f32_f16(vget_high_f16(a_f16x8));
808
+ float32x4_t a_high_f32x4 = vcvt_high_f32_f16(a_f16x8);
809
809
  float32x4_t b_low_f32x4 = vcvt_f32_f16(vget_low_f16(b_f16x8));
810
- float32x4_t b_high_f32x4 = vcvt_f32_f16(vget_high_f16(b_f16x8));
810
+ float32x4_t b_high_f32x4 = vcvt_high_f32_f16(b_f16x8);
811
811
  float32x4_t a_scaled_low_f32x4 = vmulq_f32(a_low_f32x4, alpha_f32x4);
812
812
  float32x4_t a_scaled_high_f32x4 = vmulq_f32(a_high_f32x4, alpha_f32x4);
813
813
  float32x4_t result_low_f32x4 = vfmaq_f32(a_scaled_low_f32x4, b_low_f32x4, beta_f32x4);
@@ -835,11 +835,11 @@ NK_PUBLIC void nk_each_fma_e4m3_neon(nk_e4m3_t const *a, nk_e4m3_t const *b, nk_
835
835
  float16x8_t b_f16x8 = nk_e4m3x8_to_f16x8_neon_(vld1_u8(b + i));
836
836
  float16x8_t c_f16x8 = nk_e4m3x8_to_f16x8_neon_(vld1_u8(c + i));
837
837
  float32x4_t a_low_f32x4 = vcvt_f32_f16(vget_low_f16(a_f16x8));
838
- float32x4_t a_high_f32x4 = vcvt_f32_f16(vget_high_f16(a_f16x8));
838
+ float32x4_t a_high_f32x4 = vcvt_high_f32_f16(a_f16x8);
839
839
  float32x4_t b_low_f32x4 = vcvt_f32_f16(vget_low_f16(b_f16x8));
840
- float32x4_t b_high_f32x4 = vcvt_f32_f16(vget_high_f16(b_f16x8));
840
+ float32x4_t b_high_f32x4 = vcvt_high_f32_f16(b_f16x8);
841
841
  float32x4_t c_low_f32x4 = vcvt_f32_f16(vget_low_f16(c_f16x8));
842
- float32x4_t c_high_f32x4 = vcvt_f32_f16(vget_high_f16(c_f16x8));
842
+ float32x4_t c_high_f32x4 = vcvt_high_f32_f16(c_f16x8);
843
843
  float32x4_t ab_low_f32x4 = vmulq_f32(a_low_f32x4, b_low_f32x4);
844
844
  float32x4_t ab_high_f32x4 = vmulq_f32(a_high_f32x4, b_high_f32x4);
845
845
  float32x4_t ab_scaled_low_f32x4 = vmulq_f32(ab_low_f32x4, alpha_f32x4);
@@ -870,11 +870,11 @@ NK_PUBLIC void nk_each_fma_e5m2_neon(nk_e5m2_t const *a, nk_e5m2_t const *b, nk_
870
870
  float16x8_t b_f16x8 = nk_e5m2x8_to_f16x8_neon_(vld1_u8(b + i));
871
871
  float16x8_t c_f16x8 = nk_e5m2x8_to_f16x8_neon_(vld1_u8(c + i));
872
872
  float32x4_t a_low_f32x4 = vcvt_f32_f16(vget_low_f16(a_f16x8));
873
- float32x4_t a_high_f32x4 = vcvt_f32_f16(vget_high_f16(a_f16x8));
873
+ float32x4_t a_high_f32x4 = vcvt_high_f32_f16(a_f16x8);
874
874
  float32x4_t b_low_f32x4 = vcvt_f32_f16(vget_low_f16(b_f16x8));
875
- float32x4_t b_high_f32x4 = vcvt_f32_f16(vget_high_f16(b_f16x8));
875
+ float32x4_t b_high_f32x4 = vcvt_high_f32_f16(b_f16x8);
876
876
  float32x4_t c_low_f32x4 = vcvt_f32_f16(vget_low_f16(c_f16x8));
877
- float32x4_t c_high_f32x4 = vcvt_f32_f16(vget_high_f16(c_f16x8));
877
+ float32x4_t c_high_f32x4 = vcvt_high_f32_f16(c_f16x8);
878
878
  float32x4_t ab_low_f32x4 = vmulq_f32(a_low_f32x4, b_low_f32x4);
879
879
  float32x4_t ab_high_f32x4 = vmulq_f32(a_high_f32x4, b_high_f32x4);
880
880
  float32x4_t ab_scaled_low_f32x4 = vmulq_f32(ab_low_f32x4, alpha_f32x4);
@@ -1089,6 +1089,40 @@ NK_PUBLIC void nk_each_fma_f64c_neon(nk_f64c_t const *a, nk_f64c_t const *b, nk_
1089
1089
  }
1090
1090
  }
1091
1091
 
1092
+ NK_PUBLIC void nk_each_sum_u8_neon(nk_u8_t const *a, nk_u8_t const *b, nk_size_t n, nk_u8_t *result) {
1093
+ // The main loop:
1094
+ nk_size_t i = 0;
1095
+ for (; i + 16 <= n; i += 16) {
1096
+ uint8x16_t a_vec = vld1q_u8(a + i);
1097
+ uint8x16_t b_vec = vld1q_u8(b + i);
1098
+ uint8x16_t sum_vec = vqaddq_u8(a_vec, b_vec);
1099
+ vst1q_u8(result + i, sum_vec);
1100
+ }
1101
+
1102
+ // The tail:
1103
+ for (; i < n; ++i) {
1104
+ nk_f32_t sum = (nk_f32_t)a[i] + b[i];
1105
+ nk_f32_to_u8_serial(&sum, result + i);
1106
+ }
1107
+ }
1108
+
1109
+ NK_PUBLIC void nk_each_sum_i8_neon(nk_i8_t const *a, nk_i8_t const *b, nk_size_t n, nk_i8_t *result) {
1110
+ // The main loop:
1111
+ nk_size_t i = 0;
1112
+ for (; i + 16 <= n; i += 16) {
1113
+ int8x16_t a_vec = vld1q_s8(a + i);
1114
+ int8x16_t b_vec = vld1q_s8(b + i);
1115
+ int8x16_t sum_vec = vqaddq_s8(a_vec, b_vec);
1116
+ vst1q_s8(result + i, sum_vec);
1117
+ }
1118
+
1119
+ // The tail:
1120
+ for (; i < n; ++i) {
1121
+ nk_f32_t sum = (nk_f32_t)a[i] + b[i];
1122
+ nk_f32_to_i8_serial(&sum, result + i);
1123
+ }
1124
+ }
1125
+
1092
1126
  #if defined(__clang__)
1093
1127
  #pragma clang attribute pop
1094
1128
  #elif defined(__GNUC__)
@@ -8,18 +8,17 @@
8
8
  *
9
9
  * @section elementwise_neonbfdot_instructions ARM NEON BF16 Instructions (ARMv8.6-BF16)
10
10
  *
11
- * Intrinsic Instruction Latency Throughput
12
- * A76 M4+/V1+/Oryon
13
- * vld1_bf16 LD1 (V.4H) 4cy 2/cy 3/cy
14
- * vst1_bf16 ST1 (V.4H) 2cy 2/cy 3/cy
15
- * vcvt_f32_bf16 BFCVTN (V.4H, V.4S) 3cy 2/cy 4/cy
16
- * vcvt_bf16_f32 BFCVT (V.4H, V.4S) 3cy 2/cy 4/cy
17
- * vaddq_f32 FADD (V.4S, V.4S, V.4S) 2cy 2/cy 4/cy
18
- * vmulq_f32 FMUL (V.4S, V.4S, V.4S) 3cy 2/cy 4/cy
19
- * vmulq_n_f32 FMUL (V.4S, V.4S, scalar) 3cy 2/cy 4/cy
20
- * vfmaq_f32 FMLA (V.4S, V.4S, V.4S) 4cy 2/cy 4/cy
21
- * vfmaq_n_f32 FMLA (V.4S, V.4S, scalar) 4cy 2/cy 4/cy
22
- * vdupq_n_f32 DUP (V.4S, scalar) 2cy 2/cy 4/cy
11
+ * Intrinsic Instruction A76 M5
12
+ * vld1_bf16 LD1 (V.4H) 4cy @ 2p 4cy @ 3p
13
+ * vst1_bf16 ST1 (V.4H) 2cy @ 2p 2cy @ 3p
14
+ * vcvt_f32_bf16 BFCVTN (V.4H, V.4S) 3cy @ 2p 3cy @ 4p
15
+ * vcvt_bf16_f32 BFCVT (V.4H, V.4S) 3cy @ 2p 3cy @ 4p
16
+ * vaddq_f32 FADD (V.4S, V.4S, V.4S) 2cy @ 2p 2cy @ 4p
17
+ * vmulq_f32 FMUL (V.4S, V.4S, V.4S) 3cy @ 2p 3cy @ 4p
18
+ * vmulq_n_f32 FMUL (V.4S, V.4S, scalar) 3cy @ 2p 3cy @ 4p
19
+ * vfmaq_f32 FMLA (V.4S, V.4S, V.4S) 4cy @ 2p 3cy @ 4p
20
+ * vfmaq_n_f32 FMLA (V.4S, V.4S, scalar) 4cy @ 2p 3cy @ 4p
21
+ * vdupq_n_f32 DUP (V.4S, scalar) 2cy @ 2p 2cy @ 4p
23
22
  *
24
23
  * The ARMv8.6-BF16 extension provides element-wise operations on BF16 data by converting to F32
25
24
  * for arithmetic, then back to BF16 for storage. This preserves the dynamic range benefits of BF16
@@ -8,28 +8,27 @@
8
8
  *
9
9
  * @section elementwise_neonhalf_instructions ARM NEON FP16 Instructions (ARMv8.2-FP16)
10
10
  *
11
- * Intrinsic Instruction Latency Throughput
12
- * A76 M4+/V1+/Oryon
13
- * vld1q_f16 LD1 (V.8H) 4cy 2/cy 3/cy
14
- * vst1q_f16 ST1 (V.8H) 2cy 2/cy 3/cy
15
- * vaddq_f16 FADD (V.8H, V.8H, V.8H) 2cy 2/cy 4/cy
16
- * vmulq_f16 FMUL (V.8H, V.8H, V.8H) 3cy 2/cy 4/cy
17
- * vmulq_n_f16 FMUL (V.8H, V.8H, scalar) 3cy 2/cy 4/cy
18
- * vfmaq_f16 FMLA (V.8H, V.8H, V.8H) 4cy 2/cy 4/cy
19
- * vfmaq_n_f16 FMLA (V.8H, V.8H, scalar) 4cy 2/cy 4/cy
20
- * vdupq_n_f16 DUP (V.8H, scalar) 2cy 2/cy 4/cy
21
- * vld1_u8 LD1 (V.8B) 4cy 2/cy 3/cy
22
- * vld1_s8 LD1 (V.8B) 4cy 2/cy 3/cy
23
- * vmovl_u8 UXTL (V.8H, V.8B) 2cy 2/cy 4/cy
24
- * vmovl_s8 SXTL (V.8H, V.8B) 2cy 2/cy 4/cy
25
- * vcvtq_f16_u16 UCVTF (V.8H, V.8H) 3cy 2/cy 4/cy
26
- * vcvtq_f16_s16 SCVTF (V.8H, V.8H) 3cy 2/cy 4/cy
27
- * vcvtnq_u16_f16 FCVTNU (V.8H, V.8H) 3cy 2/cy 4/cy
28
- * vcvtnq_s16_f16 FCVTNS (V.8H, V.8H) 3cy 2/cy 4/cy
29
- * vqmovn_u16 UQXTN (V.8B, V.8H) 3cy 2/cy 4/cy
30
- * vqmovn_s16 SQXTN (V.8B, V.8H) 3cy 2/cy 4/cy
31
- * vqaddq_u8 UQADD (V.16B, V.16B, V.16B) 2cy 2/cy 4/cy
32
- * vqaddq_s8 SQADD (V.16B, V.16B, V.16B) 2cy 2/cy 4/cy
11
+ * Intrinsic Instruction A76 M5
12
+ * vld1q_f16 LD1 (V.8H) 4cy @ 2p 4cy @ 3p
13
+ * vst1q_f16 ST1 (V.8H) 2cy @ 2p 2cy @ 3p
14
+ * vaddq_f16 FADD (V.8H, V.8H, V.8H) 2cy @ 2p 2cy @ 4p
15
+ * vmulq_f16 FMUL (V.8H, V.8H, V.8H) 3cy @ 2p 3cy @ 4p
16
+ * vmulq_n_f16 FMUL (V.8H, V.8H, scalar) 3cy @ 2p 3cy @ 4p
17
+ * vfmaq_f16 FMLA (V.8H, V.8H, V.8H) 4cy @ 2p 4cy @ 4p
18
+ * vfmaq_n_f16 FMLA (V.8H, V.8H, scalar) 4cy @ 2p 4cy @ 4p
19
+ * vdupq_n_f16 DUP (V.8H, scalar) 2cy @ 2p 2cy @ 4p
20
+ * vld1_u8 LD1 (V.8B) 4cy @ 2p 4cy @ 3p
21
+ * vld1_s8 LD1 (V.8B) 4cy @ 2p 4cy @ 3p
22
+ * vmovl_u8 UXTL (V.8H, V.8B) 2cy @ 2p 2cy @ 4p
23
+ * vmovl_s8 SXTL (V.8H, V.8B) 2cy @ 2p 2cy @ 4p
24
+ * vcvtq_f16_u16 UCVTF (V.8H, V.8H) 3cy @ 2p 3cy @ 4p
25
+ * vcvtq_f16_s16 SCVTF (V.8H, V.8H) 3cy @ 2p 3cy @ 4p
26
+ * vcvtnq_u16_f16 FCVTNU (V.8H, V.8H) 3cy @ 2p 3cy @ 4p
27
+ * vcvtnq_s16_f16 FCVTNS (V.8H, V.8H) 3cy @ 2p 3cy @ 4p
28
+ * vqmovn_u16 UQXTN (V.8B, V.8H) 3cy @ 2p 3cy @ 4p
29
+ * vqmovn_s16 SQXTN (V.8B, V.8H) 3cy @ 2p 3cy @ 4p
30
+ * vqaddq_u8 UQADD (V.16B, V.16B, V.16B) 2cy @ 2p 3cy @ 2p
31
+ * vqaddq_s8 SQADD (V.16B, V.16B, V.16B) 2cy @ 2p 3cy @ 2p
33
32
  *
34
33
  * The ARMv8.2-FP16 extension enables native half-precision element-wise operations, processing 8
35
34
  * F16 elements per instruction. Operations like sum, scale, blend, and fma work directly in F16,
@@ -46,6 +45,7 @@
46
45
 
47
46
  #include "numkong/types.h"
48
47
  #include "numkong/cast/serial.h" // `nk_f32_to_i8_serial`
48
+ #include "numkong/each/neon.h" // `nk_each_sum_u8_neon`, `nk_each_sum_i8_neon`
49
49
 
50
50
  #if defined(__cplusplus)
51
51
  extern "C" {
@@ -161,23 +161,6 @@ NK_PUBLIC void nk_each_fma_f16_neonhalf( //
161
161
  beta_f16 * ((float16_t const *)c)[i];
162
162
  }
163
163
 
164
- NK_PUBLIC void nk_each_sum_u8_neonhalf(nk_u8_t const *a, nk_u8_t const *b, nk_size_t n, nk_u8_t *result) {
165
- // The main loop:
166
- nk_size_t i = 0;
167
- for (; i + 16 <= n; i += 16) {
168
- uint8x16_t a_vec = vld1q_u8(a + i);
169
- uint8x16_t b_vec = vld1q_u8(b + i);
170
- uint8x16_t sum_vec = vqaddq_u8(a_vec, b_vec);
171
- vst1q_u8(result + i, sum_vec);
172
- }
173
-
174
- // The tail:
175
- for (; i < n; ++i) {
176
- nk_f32_t sum = (nk_f32_t)a[i] + b[i];
177
- nk_f32_to_u8_serial(&sum, result + i);
178
- }
179
- }
180
-
181
164
  NK_PUBLIC void nk_each_scale_u8_neonhalf(nk_u8_t const *a, nk_size_t n, nk_f32_t const *alpha, nk_f32_t const *beta,
182
165
  nk_u8_t *result) {
183
166
  float16_t alpha_f16 = (float16_t)*alpha;
@@ -213,7 +196,7 @@ NK_PUBLIC void nk_each_blend_u8_neonhalf( //
213
196
  // 1. Simple addition, when both weights are equal to 1.0.
214
197
  if (alpha_val == 1 && beta_val == 1) {
215
198
  // In this case we can avoid expensive multiplications.
216
- nk_each_sum_u8_neonhalf(a, b, n, result);
199
+ nk_each_sum_u8_neon(a, b, n, result);
217
200
  return;
218
201
  }
219
202
  // 2. Just scaling, when one of the weights is equal to zero.
@@ -249,52 +232,6 @@ NK_PUBLIC void nk_each_blend_u8_neonhalf( //
249
232
  }
250
233
  }
251
234
 
252
- NK_PUBLIC void nk_each_fma_u8_neonhalf( //
253
- nk_u8_t const *a, nk_u8_t const *b, nk_u8_t const *c, //
254
- nk_size_t n, nk_f32_t const *alpha, nk_f32_t const *beta, nk_u8_t *result) {
255
- float16_t alpha_f16 = (float16_t)*alpha;
256
- float16_t beta_f16 = (float16_t)*beta;
257
-
258
- // The main loop:
259
- nk_size_t i = 0;
260
- for (; i + 8 <= n; i += 8) {
261
- uint8x8_t a_u8x8 = vld1_u8(a + i);
262
- uint8x8_t b_u8x8 = vld1_u8(b + i);
263
- uint8x8_t c_u8x8 = vld1_u8(c + i);
264
- float16x8_t a_f16x8 = vcvtq_f16_u16(vmovl_u8(a_u8x8));
265
- float16x8_t b_f16x8 = vcvtq_f16_u16(vmovl_u8(b_u8x8));
266
- float16x8_t c_f16x8 = vcvtq_f16_u16(vmovl_u8(c_u8x8));
267
- float16x8_t ab_f16x8 = vmulq_f16(a_f16x8, b_f16x8);
268
- float16x8_t ab_scaled_f16x8 = vmulq_n_f16(ab_f16x8, alpha_f16);
269
- float16x8_t result_f16x8 = vfmaq_n_f16(ab_scaled_f16x8, c_f16x8, beta_f16);
270
- uint8x8_t result_u8x8 = vqmovn_u16(vcvtnq_u16_f16(result_f16x8));
271
- vst1_u8(result + i, result_u8x8);
272
- }
273
-
274
- // The tail:
275
- for (; i < n; ++i) {
276
- nk_f32_t sum = alpha_f16 * a[i] * b[i] + beta_f16 * c[i];
277
- nk_f32_to_u8_serial(&sum, result + i);
278
- }
279
- }
280
-
281
- NK_PUBLIC void nk_each_sum_i8_neonhalf(nk_i8_t const *a, nk_i8_t const *b, nk_size_t n, nk_i8_t *result) {
282
- // The main loop:
283
- nk_size_t i = 0;
284
- for (; i + 16 <= n; i += 16) {
285
- int8x16_t a_vec = vld1q_s8(a + i);
286
- int8x16_t b_vec = vld1q_s8(b + i);
287
- int8x16_t sum_vec = vqaddq_s8(a_vec, b_vec);
288
- vst1q_s8(result + i, sum_vec);
289
- }
290
-
291
- // The tail:
292
- for (; i < n; ++i) {
293
- nk_f32_t sum = (nk_f32_t)a[i] + b[i];
294
- nk_f32_to_i8_serial(&sum, result + i);
295
- }
296
- }
297
-
298
235
  NK_PUBLIC void nk_each_scale_i8_neonhalf(nk_i8_t const *a, nk_size_t n, nk_f32_t const *alpha, nk_f32_t const *beta,
299
236
  nk_i8_t *result) {
300
237
  float16_t alpha_f16 = (float16_t)*alpha;
@@ -330,7 +267,7 @@ NK_PUBLIC void nk_each_blend_i8_neonhalf( //
330
267
  // 1. Simple addition, when both weights are equal to 1.0.
331
268
  if (alpha_val == 1 && beta_val == 1) {
332
269
  // In this case we can avoid expensive multiplications.
333
- nk_each_sum_i8_neonhalf(a, b, n, result);
270
+ nk_each_sum_i8_neon(a, b, n, result);
334
271
  return;
335
272
  }
336
273
  // 2. Just scaling, when one of the weights is equal to zero.
@@ -366,35 +303,6 @@ NK_PUBLIC void nk_each_blend_i8_neonhalf( //
366
303
  }
367
304
  }
368
305
 
369
- NK_PUBLIC void nk_each_fma_i8_neonhalf( //
370
- nk_i8_t const *a, nk_i8_t const *b, nk_i8_t const *c, //
371
- nk_size_t n, nk_f32_t const *alpha, nk_f32_t const *beta, nk_i8_t *result) {
372
- float16_t alpha_f16 = (float16_t)*alpha;
373
- float16_t beta_f16 = (float16_t)*beta;
374
-
375
- // The main loop:
376
- nk_size_t i = 0;
377
- for (; i + 8 <= n; i += 8) {
378
- int8x8_t a_i8x8 = vld1_s8(a + i);
379
- int8x8_t b_i8x8 = vld1_s8(b + i);
380
- int8x8_t c_i8x8 = vld1_s8(c + i);
381
- float16x8_t a_f16x8 = vcvtq_f16_s16(vmovl_s8(a_i8x8));
382
- float16x8_t b_f16x8 = vcvtq_f16_s16(vmovl_s8(b_i8x8));
383
- float16x8_t c_f16x8 = vcvtq_f16_s16(vmovl_s8(c_i8x8));
384
- float16x8_t ab_f16x8 = vmulq_f16(a_f16x8, b_f16x8);
385
- float16x8_t ab_scaled_f16x8 = vmulq_n_f16(ab_f16x8, alpha_f16);
386
- float16x8_t result_f16x8 = vfmaq_n_f16(ab_scaled_f16x8, c_f16x8, beta_f16);
387
- int8x8_t result_i8x8 = vqmovn_s16(vcvtnq_s16_f16(result_f16x8));
388
- vst1_s8(result + i, result_i8x8);
389
- }
390
-
391
- // The tail:
392
- for (; i < n; ++i) {
393
- nk_f32_t sum = alpha_f16 * a[i] * b[i] + beta_f16 * c[i];
394
- nk_f32_to_i8_serial(&sum, result + i);
395
- }
396
- }
397
-
398
306
  #if defined(__clang__)
399
307
  #pragma clang attribute pop
400
308
  #elif defined(__GNUC__)
@@ -185,8 +185,8 @@ NK_PUBLIC void nk_each_sum_e5m2_rvv(nk_e5m2_t const *a, nk_e5m2_t const *b, nk_s
185
185
  NK_PUBLIC void nk_each_scale_f64_rvv(nk_f64_t const *a, nk_size_t n, nk_f64_t const *alpha, nk_f64_t const *beta,
186
186
  nk_f64_t *result) {
187
187
  nk_f64_t alpha_val = *alpha, beta_val = *beta;
188
- nk_size_t vlmax = __riscv_vsetvlmax_e64m4();
189
- vfloat64m4_t beta_f64m4 = __riscv_vfmv_v_f_f64m4(beta_val, vlmax);
188
+ nk_size_t max_vector_length = __riscv_vsetvlmax_e64m4();
189
+ vfloat64m4_t beta_f64m4 = __riscv_vfmv_v_f_f64m4(beta_val, max_vector_length);
190
190
  for (nk_size_t vector_length; n > 0; n -= vector_length, a += vector_length, result += vector_length) {
191
191
  vector_length = __riscv_vsetvl_e64m4(n);
192
192
  vfloat64m4_t a_f64m4 = __riscv_vle64_v_f64m4(a, vector_length);
@@ -198,8 +198,8 @@ NK_PUBLIC void nk_each_scale_f64_rvv(nk_f64_t const *a, nk_size_t n, nk_f64_t co
198
198
  NK_PUBLIC void nk_each_scale_f32_rvv(nk_f32_t const *a, nk_size_t n, nk_f32_t const *alpha, nk_f32_t const *beta,
199
199
  nk_f32_t *result) {
200
200
  nk_f32_t alpha_val = *alpha, beta_val = *beta;
201
- nk_size_t vlmax = __riscv_vsetvlmax_e32m4();
202
- vfloat32m4_t beta_f32m4 = __riscv_vfmv_v_f_f32m4(beta_val, vlmax);
201
+ nk_size_t max_vector_length = __riscv_vsetvlmax_e32m4();
202
+ vfloat32m4_t beta_f32m4 = __riscv_vfmv_v_f_f32m4(beta_val, max_vector_length);
203
203
  for (nk_size_t vector_length; n > 0; n -= vector_length, a += vector_length, result += vector_length) {
204
204
  vector_length = __riscv_vsetvl_e32m4(n);
205
205
  vfloat32m4_t a_f32m4 = __riscv_vle32_v_f32m4(a, vector_length);
@@ -211,8 +211,8 @@ NK_PUBLIC void nk_each_scale_f32_rvv(nk_f32_t const *a, nk_size_t n, nk_f32_t co
211
211
  NK_PUBLIC void nk_each_scale_f16_rvv(nk_f16_t const *a, nk_size_t n, nk_f32_t const *alpha, nk_f32_t const *beta,
212
212
  nk_f16_t *result) {
213
213
  nk_f32_t alpha_val = *alpha, beta_val = *beta;
214
- nk_size_t vlmax = __riscv_vsetvlmax_e32m2();
215
- vfloat32m2_t beta_f32m2 = __riscv_vfmv_v_f_f32m2(beta_val, vlmax);
214
+ nk_size_t max_vector_length = __riscv_vsetvlmax_e32m2();
215
+ vfloat32m2_t beta_f32m2 = __riscv_vfmv_v_f_f32m2(beta_val, max_vector_length);
216
216
  for (nk_size_t vector_length; n > 0; n -= vector_length, a += vector_length, result += vector_length) {
217
217
  vector_length = __riscv_vsetvl_e16m1(n);
218
218
  vuint16m1_t a_u16m1 = __riscv_vle16_v_u16m1((nk_u16_t const *)a, vector_length);
@@ -226,8 +226,8 @@ NK_PUBLIC void nk_each_scale_f16_rvv(nk_f16_t const *a, nk_size_t n, nk_f32_t co
226
226
  NK_PUBLIC void nk_each_scale_bf16_rvv(nk_bf16_t const *a, nk_size_t n, nk_f32_t const *alpha, nk_f32_t const *beta,
227
227
  nk_bf16_t *result) {
228
228
  nk_f32_t alpha_val = *alpha, beta_val = *beta;
229
- nk_size_t vlmax = __riscv_vsetvlmax_e32m2();
230
- vfloat32m2_t beta_f32m2 = __riscv_vfmv_v_f_f32m2(beta_val, vlmax);
229
+ nk_size_t max_vector_length = __riscv_vsetvlmax_e32m2();
230
+ vfloat32m2_t beta_f32m2 = __riscv_vfmv_v_f_f32m2(beta_val, max_vector_length);
231
231
  for (nk_size_t vector_length; n > 0; n -= vector_length, a += vector_length, result += vector_length) {
232
232
  vector_length = __riscv_vsetvl_e16m1(n);
233
233
  vuint16m1_t a_u16m1 = __riscv_vle16_v_u16m1((nk_u16_t const *)a, vector_length);
@@ -241,8 +241,8 @@ NK_PUBLIC void nk_each_scale_bf16_rvv(nk_bf16_t const *a, nk_size_t n, nk_f32_t
241
241
  NK_PUBLIC void nk_each_scale_i8_rvv(nk_i8_t const *a, nk_size_t n, nk_f32_t const *alpha, nk_f32_t const *beta,
242
242
  nk_i8_t *result) {
243
243
  nk_f32_t alpha_val = *alpha, beta_val = *beta;
244
- nk_size_t vlmax = __riscv_vsetvlmax_e32m4();
245
- vfloat32m4_t beta_f32m4 = __riscv_vfmv_v_f_f32m4(beta_val, vlmax);
244
+ nk_size_t max_vector_length = __riscv_vsetvlmax_e32m4();
245
+ vfloat32m4_t beta_f32m4 = __riscv_vfmv_v_f_f32m4(beta_val, max_vector_length);
246
246
  for (nk_size_t vector_length; n > 0; n -= vector_length, a += vector_length, result += vector_length) {
247
247
  vector_length = __riscv_vsetvl_e8m1(n);
248
248
  vint8m1_t a_i8m1 = __riscv_vle8_v_i8m1(a, vector_length);
@@ -262,8 +262,8 @@ NK_PUBLIC void nk_each_scale_i8_rvv(nk_i8_t const *a, nk_size_t n, nk_f32_t cons
262
262
  NK_PUBLIC void nk_each_scale_u8_rvv(nk_u8_t const *a, nk_size_t n, nk_f32_t const *alpha, nk_f32_t const *beta,
263
263
  nk_u8_t *result) {
264
264
  nk_f32_t alpha_val = *alpha, beta_val = *beta;
265
- nk_size_t vlmax = __riscv_vsetvlmax_e32m4();
266
- vfloat32m4_t beta_f32m4 = __riscv_vfmv_v_f_f32m4(beta_val, vlmax);
265
+ nk_size_t max_vector_length = __riscv_vsetvlmax_e32m4();
266
+ vfloat32m4_t beta_f32m4 = __riscv_vfmv_v_f_f32m4(beta_val, max_vector_length);
267
267
  for (nk_size_t vector_length; n > 0; n -= vector_length, a += vector_length, result += vector_length) {
268
268
  vector_length = __riscv_vsetvl_e8m1(n);
269
269
  vuint8m1_t a_u8m1 = __riscv_vle8_v_u8m1(a, vector_length);
@@ -283,8 +283,8 @@ NK_PUBLIC void nk_each_scale_u8_rvv(nk_u8_t const *a, nk_size_t n, nk_f32_t cons
283
283
  NK_PUBLIC void nk_each_scale_i16_rvv(nk_i16_t const *a, nk_size_t n, nk_f32_t const *alpha, nk_f32_t const *beta,
284
284
  nk_i16_t *result) {
285
285
  nk_f32_t alpha_val = *alpha, beta_val = *beta;
286
- nk_size_t vlmax = __riscv_vsetvlmax_e32m2();
287
- vfloat32m2_t beta_f32m2 = __riscv_vfmv_v_f_f32m2(beta_val, vlmax);
286
+ nk_size_t max_vector_length = __riscv_vsetvlmax_e32m2();
287
+ vfloat32m2_t beta_f32m2 = __riscv_vfmv_v_f_f32m2(beta_val, max_vector_length);
288
288
  for (nk_size_t vector_length; n > 0; n -= vector_length, a += vector_length, result += vector_length) {
289
289
  vector_length = __riscv_vsetvl_e16m1(n);
290
290
  vint16m1_t a_i16m1 = __riscv_vle16_v_i16m1(a, vector_length);
@@ -302,8 +302,8 @@ NK_PUBLIC void nk_each_scale_i16_rvv(nk_i16_t const *a, nk_size_t n, nk_f32_t co
302
302
  NK_PUBLIC void nk_each_scale_u16_rvv(nk_u16_t const *a, nk_size_t n, nk_f32_t const *alpha, nk_f32_t const *beta,
303
303
  nk_u16_t *result) {
304
304
  nk_f32_t alpha_val = *alpha, beta_val = *beta;
305
- nk_size_t vlmax = __riscv_vsetvlmax_e32m2();
306
- vfloat32m2_t beta_f32m2 = __riscv_vfmv_v_f_f32m2(beta_val, vlmax);
305
+ nk_size_t max_vector_length = __riscv_vsetvlmax_e32m2();
306
+ vfloat32m2_t beta_f32m2 = __riscv_vfmv_v_f_f32m2(beta_val, max_vector_length);
307
307
  for (nk_size_t vector_length; n > 0; n -= vector_length, a += vector_length, result += vector_length) {
308
308
  vector_length = __riscv_vsetvl_e16m1(n);
309
309
  vuint16m1_t a_u16m1 = __riscv_vle16_v_u16m1(a, vector_length);
@@ -321,8 +321,8 @@ NK_PUBLIC void nk_each_scale_u16_rvv(nk_u16_t const *a, nk_size_t n, nk_f32_t co
321
321
  NK_PUBLIC void nk_each_scale_i32_rvv(nk_i32_t const *a, nk_size_t n, nk_f64_t const *alpha, nk_f64_t const *beta,
322
322
  nk_i32_t *result) {
323
323
  nk_f64_t alpha_val = *alpha, beta_val = *beta;
324
- nk_size_t vlmax = __riscv_vsetvlmax_e64m2();
325
- vfloat64m2_t beta_f64m2 = __riscv_vfmv_v_f_f64m2(beta_val, vlmax);
324
+ nk_size_t max_vector_length = __riscv_vsetvlmax_e64m2();
325
+ vfloat64m2_t beta_f64m2 = __riscv_vfmv_v_f_f64m2(beta_val, max_vector_length);
326
326
  for (nk_size_t vector_length; n > 0; n -= vector_length, a += vector_length, result += vector_length) {
327
327
  vector_length = __riscv_vsetvl_e32m1(n);
328
328
  vint32m1_t a_i32m1 = __riscv_vle32_v_i32m1(a, vector_length);
@@ -338,8 +338,8 @@ NK_PUBLIC void nk_each_scale_i32_rvv(nk_i32_t const *a, nk_size_t n, nk_f64_t co
338
338
  NK_PUBLIC void nk_each_scale_u32_rvv(nk_u32_t const *a, nk_size_t n, nk_f64_t const *alpha, nk_f64_t const *beta,
339
339
  nk_u32_t *result) {
340
340
  nk_f64_t alpha_val = *alpha, beta_val = *beta;
341
- nk_size_t vlmax = __riscv_vsetvlmax_e64m2();
342
- vfloat64m2_t beta_f64m2 = __riscv_vfmv_v_f_f64m2(beta_val, vlmax);
341
+ nk_size_t max_vector_length = __riscv_vsetvlmax_e64m2();
342
+ vfloat64m2_t beta_f64m2 = __riscv_vfmv_v_f_f64m2(beta_val, max_vector_length);
343
343
  for (nk_size_t vector_length; n > 0; n -= vector_length, a += vector_length, result += vector_length) {
344
344
  vector_length = __riscv_vsetvl_e32m1(n);
345
345
  vuint32m1_t a_u32m1 = __riscv_vle32_v_u32m1(a, vector_length);
@@ -355,8 +355,8 @@ NK_PUBLIC void nk_each_scale_u32_rvv(nk_u32_t const *a, nk_size_t n, nk_f64_t co
355
355
  NK_PUBLIC void nk_each_scale_i64_rvv(nk_i64_t const *a, nk_size_t n, nk_f64_t const *alpha, nk_f64_t const *beta,
356
356
  nk_i64_t *result) {
357
357
  nk_f64_t alpha_val = *alpha, beta_val = *beta;
358
- nk_size_t vlmax = __riscv_vsetvlmax_e64m4();
359
- vfloat64m4_t beta_f64m4 = __riscv_vfmv_v_f_f64m4(beta_val, vlmax);
358
+ nk_size_t max_vector_length = __riscv_vsetvlmax_e64m4();
359
+ vfloat64m4_t beta_f64m4 = __riscv_vfmv_v_f_f64m4(beta_val, max_vector_length);
360
360
  for (nk_size_t vector_length; n > 0; n -= vector_length, a += vector_length, result += vector_length) {
361
361
  vector_length = __riscv_vsetvl_e64m4(n);
362
362
  vint64m4_t a_i64m4 = __riscv_vle64_v_i64m4(a, vector_length);
@@ -370,8 +370,8 @@ NK_PUBLIC void nk_each_scale_i64_rvv(nk_i64_t const *a, nk_size_t n, nk_f64_t co
370
370
  NK_PUBLIC void nk_each_scale_u64_rvv(nk_u64_t const *a, nk_size_t n, nk_f64_t const *alpha, nk_f64_t const *beta,
371
371
  nk_u64_t *result) {
372
372
  nk_f64_t alpha_val = *alpha, beta_val = *beta;
373
- nk_size_t vlmax = __riscv_vsetvlmax_e64m4();
374
- vfloat64m4_t beta_f64m4 = __riscv_vfmv_v_f_f64m4(beta_val, vlmax);
373
+ nk_size_t max_vector_length = __riscv_vsetvlmax_e64m4();
374
+ vfloat64m4_t beta_f64m4 = __riscv_vfmv_v_f_f64m4(beta_val, max_vector_length);
375
375
  for (nk_size_t vector_length; n > 0; n -= vector_length, a += vector_length, result += vector_length) {
376
376
  vector_length = __riscv_vsetvl_e64m4(n);
377
377
  vuint64m4_t a_u64m4 = __riscv_vle64_v_u64m4(a, vector_length);
@@ -386,8 +386,8 @@ NK_PUBLIC void nk_each_scale_u64_rvv(nk_u64_t const *a, nk_size_t n, nk_f64_t co
386
386
  NK_PUBLIC void nk_each_scale_e4m3_rvv(nk_e4m3_t const *a, nk_size_t n, nk_f32_t const *alpha, nk_f32_t const *beta,
387
387
  nk_e4m3_t *result) {
388
388
  nk_f32_t alpha_val = *alpha, beta_val = *beta;
389
- nk_size_t vlmax = __riscv_vsetvlmax_e32m4();
390
- vfloat32m4_t beta_f32m4 = __riscv_vfmv_v_f_f32m4(beta_val, vlmax);
389
+ nk_size_t max_vector_length = __riscv_vsetvlmax_e32m4();
390
+ vfloat32m4_t beta_f32m4 = __riscv_vfmv_v_f_f32m4(beta_val, max_vector_length);
391
391
  for (nk_size_t vector_length; n > 0; n -= vector_length, a += vector_length, result += vector_length) {
392
392
  vector_length = __riscv_vsetvl_e8m1(n);
393
393
  vuint8m1_t a_u8m1 = __riscv_vle8_v_u8m1((nk_u8_t const *)a, vector_length);
@@ -401,8 +401,8 @@ NK_PUBLIC void nk_each_scale_e4m3_rvv(nk_e4m3_t const *a, nk_size_t n, nk_f32_t
401
401
  NK_PUBLIC void nk_each_scale_e5m2_rvv(nk_e5m2_t const *a, nk_size_t n, nk_f32_t const *alpha, nk_f32_t const *beta,
402
402
  nk_e5m2_t *result) {
403
403
  nk_f32_t alpha_val = *alpha, beta_val = *beta;
404
- nk_size_t vlmax = __riscv_vsetvlmax_e32m4();
405
- vfloat32m4_t beta_f32m4 = __riscv_vfmv_v_f_f32m4(beta_val, vlmax);
404
+ nk_size_t max_vector_length = __riscv_vsetvlmax_e32m4();
405
+ vfloat32m4_t beta_f32m4 = __riscv_vfmv_v_f_f32m4(beta_val, max_vector_length);
406
406
  for (nk_size_t vector_length; n > 0; n -= vector_length, a += vector_length, result += vector_length) {
407
407
  vector_length = __riscv_vsetvl_e8m1(n);
408
408
  vuint8m1_t a_u8m1 = __riscv_vle8_v_u8m1((nk_u8_t const *)a, vector_length);