numkong 7.0.0 → 7.4.2

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (315) hide show
  1. package/README.md +197 -124
  2. package/binding.gyp +34 -484
  3. package/c/dispatch_bf16.c +59 -1
  4. package/c/dispatch_e2m3.c +41 -8
  5. package/c/dispatch_e3m2.c +49 -8
  6. package/c/dispatch_e4m3.c +51 -9
  7. package/c/dispatch_e5m2.c +45 -1
  8. package/c/dispatch_f16.c +79 -26
  9. package/c/dispatch_f16c.c +5 -5
  10. package/c/dispatch_f32.c +56 -0
  11. package/c/dispatch_f64.c +52 -0
  12. package/c/dispatch_i4.c +3 -0
  13. package/c/dispatch_i8.c +62 -3
  14. package/c/dispatch_other.c +18 -0
  15. package/c/dispatch_u1.c +54 -9
  16. package/c/dispatch_u4.c +3 -0
  17. package/c/dispatch_u8.c +64 -3
  18. package/c/numkong.c +3 -0
  19. package/include/README.md +79 -9
  20. package/include/numkong/attention/sapphireamx.h +278 -276
  21. package/include/numkong/attention/sme.h +983 -977
  22. package/include/numkong/attention.h +1 -1
  23. package/include/numkong/capabilities.h +289 -94
  24. package/include/numkong/cast/README.md +40 -40
  25. package/include/numkong/cast/diamond.h +64 -0
  26. package/include/numkong/cast/haswell.h +42 -194
  27. package/include/numkong/cast/icelake.h +42 -37
  28. package/include/numkong/cast/loongsonasx.h +252 -0
  29. package/include/numkong/cast/neon.h +216 -249
  30. package/include/numkong/cast/powervsx.h +449 -0
  31. package/include/numkong/cast/rvv.h +223 -274
  32. package/include/numkong/cast/sapphire.h +18 -18
  33. package/include/numkong/cast/serial.h +1018 -944
  34. package/include/numkong/cast/skylake.h +82 -23
  35. package/include/numkong/cast/v128relaxed.h +462 -105
  36. package/include/numkong/cast.h +24 -0
  37. package/include/numkong/cast.hpp +44 -0
  38. package/include/numkong/curved/README.md +17 -17
  39. package/include/numkong/curved/neon.h +131 -7
  40. package/include/numkong/curved/neonbfdot.h +6 -7
  41. package/include/numkong/curved/rvv.h +26 -26
  42. package/include/numkong/curved/smef64.h +186 -182
  43. package/include/numkong/curved.h +14 -18
  44. package/include/numkong/dot/README.md +154 -137
  45. package/include/numkong/dot/alder.h +43 -43
  46. package/include/numkong/dot/diamond.h +158 -0
  47. package/include/numkong/dot/genoa.h +4 -30
  48. package/include/numkong/dot/haswell.h +215 -180
  49. package/include/numkong/dot/icelake.h +190 -76
  50. package/include/numkong/dot/loongsonasx.h +671 -0
  51. package/include/numkong/dot/neon.h +124 -73
  52. package/include/numkong/dot/neonbfdot.h +11 -12
  53. package/include/numkong/dot/neonfhm.h +44 -46
  54. package/include/numkong/dot/neonfp8.h +323 -0
  55. package/include/numkong/dot/neonsdot.h +190 -76
  56. package/include/numkong/dot/powervsx.h +752 -0
  57. package/include/numkong/dot/rvv.h +92 -84
  58. package/include/numkong/dot/rvvbf16.h +12 -12
  59. package/include/numkong/dot/rvvhalf.h +12 -12
  60. package/include/numkong/dot/sapphire.h +4 -4
  61. package/include/numkong/dot/serial.h +66 -30
  62. package/include/numkong/dot/sierra.h +31 -31
  63. package/include/numkong/dot/skylake.h +142 -110
  64. package/include/numkong/dot/sve.h +217 -177
  65. package/include/numkong/dot/svebfdot.h +10 -10
  66. package/include/numkong/dot/svehalf.h +85 -41
  67. package/include/numkong/dot/svesdot.h +89 -0
  68. package/include/numkong/dot/v128relaxed.h +124 -89
  69. package/include/numkong/dot.h +114 -48
  70. package/include/numkong/dots/README.md +203 -203
  71. package/include/numkong/dots/alder.h +12 -9
  72. package/include/numkong/dots/diamond.h +86 -0
  73. package/include/numkong/dots/genoa.h +10 -4
  74. package/include/numkong/dots/haswell.h +63 -48
  75. package/include/numkong/dots/icelake.h +27 -18
  76. package/include/numkong/dots/loongsonasx.h +176 -0
  77. package/include/numkong/dots/neon.h +14 -11
  78. package/include/numkong/dots/neonbfdot.h +4 -3
  79. package/include/numkong/dots/neonfhm.h +11 -9
  80. package/include/numkong/dots/neonfp8.h +99 -0
  81. package/include/numkong/dots/neonsdot.h +48 -12
  82. package/include/numkong/dots/powervsx.h +194 -0
  83. package/include/numkong/dots/rvv.h +451 -344
  84. package/include/numkong/dots/sapphireamx.h +1028 -984
  85. package/include/numkong/dots/serial.h +213 -197
  86. package/include/numkong/dots/sierra.h +10 -7
  87. package/include/numkong/dots/skylake.h +47 -36
  88. package/include/numkong/dots/sme.h +2001 -2364
  89. package/include/numkong/dots/smebi32.h +175 -162
  90. package/include/numkong/dots/smef64.h +328 -323
  91. package/include/numkong/dots/v128relaxed.h +64 -41
  92. package/include/numkong/dots.h +573 -293
  93. package/include/numkong/dots.hpp +45 -43
  94. package/include/numkong/each/README.md +133 -137
  95. package/include/numkong/each/haswell.h +6 -6
  96. package/include/numkong/each/icelake.h +7 -7
  97. package/include/numkong/each/neon.h +76 -42
  98. package/include/numkong/each/neonbfdot.h +11 -12
  99. package/include/numkong/each/neonhalf.h +24 -116
  100. package/include/numkong/each/rvv.h +28 -28
  101. package/include/numkong/each/sapphire.h +27 -161
  102. package/include/numkong/each/serial.h +6 -6
  103. package/include/numkong/each/skylake.h +7 -7
  104. package/include/numkong/each/v128relaxed.h +562 -0
  105. package/include/numkong/each.h +148 -62
  106. package/include/numkong/each.hpp +2 -2
  107. package/include/numkong/geospatial/README.md +18 -18
  108. package/include/numkong/geospatial/haswell.h +365 -325
  109. package/include/numkong/geospatial/neon.h +350 -306
  110. package/include/numkong/geospatial/rvv.h +4 -4
  111. package/include/numkong/geospatial/skylake.h +376 -340
  112. package/include/numkong/geospatial/v128relaxed.h +366 -327
  113. package/include/numkong/geospatial.h +17 -17
  114. package/include/numkong/matrix.hpp +4 -4
  115. package/include/numkong/maxsim/README.md +14 -14
  116. package/include/numkong/maxsim/alder.h +6 -6
  117. package/include/numkong/maxsim/genoa.h +4 -4
  118. package/include/numkong/maxsim/haswell.h +6 -6
  119. package/include/numkong/maxsim/icelake.h +18 -18
  120. package/include/numkong/maxsim/neonsdot.h +21 -21
  121. package/include/numkong/maxsim/sapphireamx.h +14 -14
  122. package/include/numkong/maxsim/serial.h +6 -6
  123. package/include/numkong/maxsim/sme.h +221 -196
  124. package/include/numkong/maxsim/v128relaxed.h +6 -6
  125. package/include/numkong/mesh/README.md +62 -56
  126. package/include/numkong/mesh/haswell.h +339 -464
  127. package/include/numkong/mesh/neon.h +1100 -519
  128. package/include/numkong/mesh/neonbfdot.h +36 -68
  129. package/include/numkong/mesh/rvv.h +530 -435
  130. package/include/numkong/mesh/serial.h +75 -91
  131. package/include/numkong/mesh/skylake.h +1627 -302
  132. package/include/numkong/mesh/v128relaxed.h +443 -330
  133. package/include/numkong/mesh.h +63 -49
  134. package/include/numkong/mesh.hpp +4 -4
  135. package/include/numkong/numkong.h +3 -3
  136. package/include/numkong/numkong.hpp +1 -0
  137. package/include/numkong/probability/README.md +23 -19
  138. package/include/numkong/probability/neon.h +82 -52
  139. package/include/numkong/probability/rvv.h +28 -23
  140. package/include/numkong/probability/serial.h +51 -39
  141. package/include/numkong/probability.h +20 -23
  142. package/include/numkong/random.h +1 -1
  143. package/include/numkong/reduce/README.md +143 -138
  144. package/include/numkong/reduce/alder.h +81 -77
  145. package/include/numkong/reduce/haswell.h +222 -220
  146. package/include/numkong/reduce/neon.h +629 -519
  147. package/include/numkong/reduce/neonbfdot.h +7 -218
  148. package/include/numkong/reduce/neonfhm.h +9 -381
  149. package/include/numkong/reduce/neonsdot.h +9 -9
  150. package/include/numkong/reduce/rvv.h +928 -802
  151. package/include/numkong/reduce/serial.h +23 -27
  152. package/include/numkong/reduce/sierra.h +20 -20
  153. package/include/numkong/reduce/skylake.h +326 -324
  154. package/include/numkong/reduce/v128relaxed.h +52 -52
  155. package/include/numkong/reduce.h +4 -23
  156. package/include/numkong/reduce.hpp +156 -11
  157. package/include/numkong/scalar/README.md +6 -6
  158. package/include/numkong/scalar/haswell.h +26 -17
  159. package/include/numkong/scalar/loongsonasx.h +74 -0
  160. package/include/numkong/scalar/neon.h +9 -9
  161. package/include/numkong/scalar/powervsx.h +96 -0
  162. package/include/numkong/scalar/rvv.h +2 -2
  163. package/include/numkong/scalar/sapphire.h +21 -10
  164. package/include/numkong/scalar/serial.h +21 -21
  165. package/include/numkong/scalar.h +13 -0
  166. package/include/numkong/set/README.md +28 -28
  167. package/include/numkong/set/haswell.h +12 -12
  168. package/include/numkong/set/icelake.h +14 -14
  169. package/include/numkong/set/loongsonasx.h +181 -0
  170. package/include/numkong/set/neon.h +17 -18
  171. package/include/numkong/set/powervsx.h +326 -0
  172. package/include/numkong/set/rvv.h +4 -4
  173. package/include/numkong/set/serial.h +6 -6
  174. package/include/numkong/set/sve.h +60 -59
  175. package/include/numkong/set/v128relaxed.h +6 -6
  176. package/include/numkong/set.h +21 -7
  177. package/include/numkong/sets/README.md +26 -26
  178. package/include/numkong/sets/loongsonasx.h +52 -0
  179. package/include/numkong/sets/powervsx.h +65 -0
  180. package/include/numkong/sets/smebi32.h +395 -364
  181. package/include/numkong/sets.h +83 -40
  182. package/include/numkong/sparse/README.md +4 -4
  183. package/include/numkong/sparse/icelake.h +101 -101
  184. package/include/numkong/sparse/serial.h +1 -1
  185. package/include/numkong/sparse/sve2.h +137 -141
  186. package/include/numkong/sparse/turin.h +12 -12
  187. package/include/numkong/sparse.h +10 -10
  188. package/include/numkong/spatial/README.md +230 -226
  189. package/include/numkong/spatial/alder.h +113 -116
  190. package/include/numkong/spatial/diamond.h +240 -0
  191. package/include/numkong/spatial/genoa.h +0 -68
  192. package/include/numkong/spatial/haswell.h +74 -55
  193. package/include/numkong/spatial/icelake.h +539 -58
  194. package/include/numkong/spatial/loongsonasx.h +483 -0
  195. package/include/numkong/spatial/neon.h +125 -52
  196. package/include/numkong/spatial/neonbfdot.h +8 -9
  197. package/include/numkong/spatial/neonfp8.h +258 -0
  198. package/include/numkong/spatial/neonsdot.h +180 -12
  199. package/include/numkong/spatial/powervsx.h +738 -0
  200. package/include/numkong/spatial/rvv.h +146 -139
  201. package/include/numkong/spatial/rvvbf16.h +17 -12
  202. package/include/numkong/spatial/rvvhalf.h +13 -10
  203. package/include/numkong/spatial/serial.h +13 -12
  204. package/include/numkong/spatial/sierra.h +232 -39
  205. package/include/numkong/spatial/skylake.h +73 -74
  206. package/include/numkong/spatial/sve.h +93 -72
  207. package/include/numkong/spatial/svebfdot.h +29 -29
  208. package/include/numkong/spatial/svehalf.h +52 -26
  209. package/include/numkong/spatial/svesdot.h +142 -0
  210. package/include/numkong/spatial/v128relaxed.h +293 -41
  211. package/include/numkong/spatial.h +338 -82
  212. package/include/numkong/spatials/README.md +194 -194
  213. package/include/numkong/spatials/diamond.h +82 -0
  214. package/include/numkong/spatials/haswell.h +2 -2
  215. package/include/numkong/spatials/loongsonasx.h +153 -0
  216. package/include/numkong/spatials/neonfp8.h +111 -0
  217. package/include/numkong/spatials/neonsdot.h +34 -0
  218. package/include/numkong/spatials/powervsx.h +153 -0
  219. package/include/numkong/spatials/rvv.h +259 -243
  220. package/include/numkong/spatials/sapphireamx.h +173 -173
  221. package/include/numkong/spatials/serial.h +2 -2
  222. package/include/numkong/spatials/skylake.h +2 -2
  223. package/include/numkong/spatials/sme.h +590 -605
  224. package/include/numkong/spatials/smef64.h +139 -130
  225. package/include/numkong/spatials/v128relaxed.h +2 -2
  226. package/include/numkong/spatials.h +820 -500
  227. package/include/numkong/spatials.hpp +49 -48
  228. package/include/numkong/tensor.hpp +406 -17
  229. package/include/numkong/trigonometry/README.md +19 -19
  230. package/include/numkong/trigonometry/haswell.h +402 -401
  231. package/include/numkong/trigonometry/neon.h +386 -387
  232. package/include/numkong/trigonometry/rvv.h +52 -51
  233. package/include/numkong/trigonometry/serial.h +13 -13
  234. package/include/numkong/trigonometry/skylake.h +373 -369
  235. package/include/numkong/trigonometry/v128relaxed.h +375 -374
  236. package/include/numkong/trigonometry.h +13 -13
  237. package/include/numkong/trigonometry.hpp +2 -2
  238. package/include/numkong/types.h +287 -49
  239. package/include/numkong/types.hpp +436 -12
  240. package/include/numkong/vector.hpp +82 -14
  241. package/javascript/dist/cjs/numkong-wasm.js +6 -12
  242. package/javascript/dist/cjs/numkong.d.ts +7 -1
  243. package/javascript/dist/cjs/numkong.js +37 -11
  244. package/javascript/dist/cjs/types.d.ts +9 -0
  245. package/javascript/dist/cjs/types.js +96 -0
  246. package/javascript/dist/esm/numkong-browser.d.ts +14 -0
  247. package/javascript/dist/esm/numkong-browser.js +23 -0
  248. package/javascript/dist/esm/numkong-wasm.js +6 -12
  249. package/javascript/dist/esm/numkong.d.ts +7 -1
  250. package/javascript/dist/esm/numkong.js +37 -11
  251. package/javascript/dist/esm/types.d.ts +9 -0
  252. package/javascript/dist/esm/types.js +96 -0
  253. package/javascript/node-gyp-build.d.ts +4 -1
  254. package/javascript/numkong-browser.ts +40 -0
  255. package/javascript/numkong-wasm.ts +7 -13
  256. package/javascript/numkong.c +5 -26
  257. package/javascript/numkong.ts +36 -11
  258. package/javascript/tsconfig-base.json +1 -0
  259. package/javascript/tsconfig-cjs.json +6 -1
  260. package/javascript/types.ts +110 -0
  261. package/numkong.gypi +101 -0
  262. package/package.json +34 -13
  263. package/probes/arm_neon.c +8 -0
  264. package/probes/arm_neon_bfdot.c +9 -0
  265. package/probes/arm_neon_fhm.c +9 -0
  266. package/probes/arm_neon_half.c +8 -0
  267. package/probes/arm_neon_sdot.c +9 -0
  268. package/probes/arm_neonfp8.c +9 -0
  269. package/probes/arm_sme.c +16 -0
  270. package/probes/arm_sme2.c +16 -0
  271. package/probes/arm_sme2p1.c +16 -0
  272. package/probes/arm_sme_bf16.c +16 -0
  273. package/probes/arm_sme_bi32.c +16 -0
  274. package/probes/arm_sme_f64.c +16 -0
  275. package/probes/arm_sme_fa64.c +14 -0
  276. package/probes/arm_sme_half.c +16 -0
  277. package/probes/arm_sme_lut2.c +15 -0
  278. package/probes/arm_sve.c +18 -0
  279. package/probes/arm_sve2.c +20 -0
  280. package/probes/arm_sve2p1.c +18 -0
  281. package/probes/arm_sve_bfdot.c +20 -0
  282. package/probes/arm_sve_half.c +18 -0
  283. package/probes/arm_sve_sdot.c +21 -0
  284. package/probes/loongarch_lasx.c +12 -0
  285. package/probes/power_vsx.c +12 -0
  286. package/probes/probe.js +127 -0
  287. package/probes/riscv_rvv.c +14 -0
  288. package/probes/riscv_rvv_bb.c +15 -0
  289. package/probes/riscv_rvv_bf16.c +17 -0
  290. package/probes/riscv_rvv_half.c +14 -0
  291. package/probes/wasm_v128relaxed.c +11 -0
  292. package/probes/x86_alder.c +17 -0
  293. package/probes/x86_diamond.c +17 -0
  294. package/probes/x86_genoa.c +17 -0
  295. package/probes/x86_graniteamx.c +19 -0
  296. package/probes/x86_haswell.c +11 -0
  297. package/probes/x86_icelake.c +17 -0
  298. package/probes/x86_sapphire.c +16 -0
  299. package/probes/x86_sapphireamx.c +18 -0
  300. package/probes/x86_sierra.c +17 -0
  301. package/probes/x86_skylake.c +15 -0
  302. package/probes/x86_turin.c +17 -0
  303. package/wasm/numkong-emscripten.js +2 -0
  304. package/wasm/numkong.d.ts +14 -0
  305. package/wasm/numkong.js +1124 -0
  306. package/wasm/numkong.wasm +0 -0
  307. package/include/numkong/curved/neonhalf.h +0 -212
  308. package/include/numkong/dot/neonhalf.h +0 -198
  309. package/include/numkong/dots/neonhalf.h +0 -57
  310. package/include/numkong/mesh/neonhalf.h +0 -616
  311. package/include/numkong/reduce/neonhalf.h +0 -157
  312. package/include/numkong/spatial/neonhalf.h +0 -118
  313. package/include/numkong/spatial/sapphire.h +0 -343
  314. package/include/numkong/spatials/neonhalf.h +0 -58
  315. package/javascript/README.md +0 -246
@@ -10,15 +10,18 @@
10
10
  *
11
11
  * Key NEON instructions for dot products:
12
12
  *
13
- * Intrinsic Instruction Latency Throughput
14
- * A76 M4+/V1+/Oryon
15
- * vfmaq_f32 FMLA (V.4S, V.4S, V.4S) 4cy 2/cy 4/cy
16
- * vfmaq_f64 FMLA (V.2D, V.2D, V.2D) 4cy 2/cy 4/cy
17
- * vmulq_f32 FMUL (V.4S, V.4S, V.4S) 3cy 2/cy 4/cy
18
- * vaddvq_f32 FADDP+FADDP (reduce) 5cy 1/cy 1/cy
19
- * vaddvq_f64 FADDP (V.2D to scalar) 3cy 1/cy 1/cy
20
- * vcvt_f64_f32 FCVTL (V.2D, V.2S) 3cy 2/cy 2/cy
21
- * vld2_f32 LD2 ({Vt.2S, Vt2.2S}, [Xn]) 4cy 1/cy 1/cy
13
+ * Intrinsic Instruction A76 M5
14
+ * vfmaq_f32 FMLA (V.4S, V.4S, V.4S) 4cy @ 2p 3cy @ 4p
15
+ * vfmaq_f64 FMLA (V.2D, V.2D, V.2D) 4cy @ 2p 4cy @ 4p
16
+ * vfmsq_f64 FMLS (V.2D, V.2D, V.2D) 4cy @ 2p 4cy @ 4p
17
+ * vmulq_f32 FMUL (V.4S, V.4S, V.4S) 3cy @ 2p 3cy @ 4p
18
+ * vmulq_f64 FMUL (V.2D, V.2D, V.2D) 3cy @ 2p 3cy @ 4p
19
+ * vaddvq_f32 FADDP+FADDP (reduce) 5cy @ 1p 8cy @ 1p
20
+ * vaddvq_f64 FADDP (V.2D to scalar) 3cy @ 1p 3cy @ 1p
21
+ * vpaddq_f32 FADDP (V.4S, V.4S, V.4S) 2cy @ 2p 3cy @ 4p
22
+ * vpaddq_f64 FADDP (V.2D, V.2D, V.2D) 2cy @ 2p 3cy @ 4p
23
+ * vcvt_f64_f32 FCVTL (V.2D, V.2S) 3cy @ 2p 3cy @ 2p
24
+ * vld2_f32 LD2 ({Vt.2S, Vt2.2S}, [Xn]) 4cy @ 1p 4cy @ 1p
22
25
  *
23
26
  * FMA throughput doubles on cores with 4 SIMD pipes (Apple M4+, Graviton3+, Oryon), but
24
27
  * horizontal reductions remain at 1/cy on all cores and become the main bottleneck.
@@ -118,21 +121,25 @@ NK_INTERNAL nk_f64_t nk_dot_stable_sum_f64x2_neon_(float64x2_t sum_f64x2, float6
118
121
  return tentative_sum + (lower_error + upper_error + rounding_error);
119
122
  }
120
123
 
121
- #pragma region - Traditional Floats
124
+ #pragma region F32 and F64 Floats
122
125
 
123
126
  NK_PUBLIC void nk_dot_f32_neon(nk_f32_t const *a_scalars, nk_f32_t const *b_scalars, nk_size_t count_scalars,
124
127
  nk_f64_t *result) {
125
- // Upcast f32 to f64 for accumulation (2 f32s per iteration, avoids slow vget_low/high)
126
- float64x2_t sum_f64x2 = vdupq_n_f64(0);
128
+ // Upcast f32 to f64 via FCVTL/FCVTL2, two independent FMA chains for ILP
129
+ float64x2_t sum_low_f64x2 = vdupq_n_f64(0);
130
+ float64x2_t sum_high_f64x2 = vdupq_n_f64(0);
127
131
  nk_size_t idx_scalars = 0;
128
- for (; idx_scalars + 2 <= count_scalars; idx_scalars += 2) {
129
- float32x2_t a_f32x2 = vld1_f32(a_scalars + idx_scalars);
130
- float32x2_t b_f32x2 = vld1_f32(b_scalars + idx_scalars);
131
- float64x2_t a_f64x2 = vcvt_f64_f32(a_f32x2);
132
- float64x2_t b_f64x2 = vcvt_f64_f32(b_f32x2);
133
- sum_f64x2 = vfmaq_f64(sum_f64x2, a_f64x2, b_f64x2);
132
+ for (; idx_scalars + 4 <= count_scalars; idx_scalars += 4) {
133
+ float32x4_t a_f32x4 = vld1q_f32(a_scalars + idx_scalars);
134
+ float32x4_t b_f32x4 = vld1q_f32(b_scalars + idx_scalars);
135
+ float64x2_t a_low_f64x2 = vcvt_f64_f32(vget_low_f32(a_f32x4));
136
+ float64x2_t a_high_f64x2 = vcvt_high_f64_f32(a_f32x4);
137
+ float64x2_t b_low_f64x2 = vcvt_f64_f32(vget_low_f32(b_f32x4));
138
+ float64x2_t b_high_f64x2 = vcvt_high_f64_f32(b_f32x4);
139
+ sum_low_f64x2 = vfmaq_f64(sum_low_f64x2, a_low_f64x2, b_low_f64x2);
140
+ sum_high_f64x2 = vfmaq_f64(sum_high_f64x2, a_high_f64x2, b_high_f64x2);
134
141
  }
135
- nk_f64_t sum_f64 = vaddvq_f64(sum_f64x2);
142
+ nk_f64_t sum_f64 = vaddvq_f64(vaddq_f64(sum_low_f64x2, sum_high_f64x2));
136
143
  for (; idx_scalars < count_scalars; ++idx_scalars)
137
144
  sum_f64 += (nk_f64_t)a_scalars[idx_scalars] * (nk_f64_t)b_scalars[idx_scalars];
138
145
  *result = sum_f64;
@@ -243,10 +250,10 @@ NK_INTERNAL void nk_dot_f32x2_finalize_neon(
243
250
  nk_dot_f32x2_state_neon_t const *state_c, nk_dot_f32x2_state_neon_t const *state_d, //
244
251
  nk_size_t total_dimensions, nk_b256_vec_t *result) {
245
252
  nk_unused_(total_dimensions);
246
- result->f64s[0] = vaddvq_f64(state_a->sum_f64x2);
247
- result->f64s[1] = vaddvq_f64(state_b->sum_f64x2);
248
- result->f64s[2] = vaddvq_f64(state_c->sum_f64x2);
249
- result->f64s[3] = vaddvq_f64(state_d->sum_f64x2);
253
+ float64x2_t ab_f64x2 = vpaddq_f64(state_a->sum_f64x2, state_b->sum_f64x2);
254
+ float64x2_t cd_f64x2 = vpaddq_f64(state_c->sum_f64x2, state_d->sum_f64x2);
255
+ vst1q_f64(&result->f64s[0], ab_f64x2);
256
+ vst1q_f64(&result->f64s[2], cd_f64x2);
250
257
  }
251
258
 
252
259
  NK_PUBLIC void nk_dot_f64_neon(nk_f64_t const *a_scalars, nk_f64_t const *b_scalars, nk_size_t count_scalars,
@@ -302,11 +309,11 @@ nk_dot_f64c_neon_cycle:
302
309
  nk_b128_vec_t a_tail, b_tail;
303
310
  nk_partial_load_b64x2_serial_(a_pairs, &a_tail, count_pairs * 2);
304
311
  nk_partial_load_b64x2_serial_(b_pairs, &b_tail, count_pairs * 2);
305
- float64x2_t zeros = vdupq_n_f64(0);
306
- a_real_f64x2 = vzip1q_f64(a_tail.f64x2, zeros);
307
- a_imag_f64x2 = vzip2q_f64(a_tail.f64x2, zeros);
308
- b_real_f64x2 = vzip1q_f64(b_tail.f64x2, zeros);
309
- b_imag_f64x2 = vzip2q_f64(b_tail.f64x2, zeros);
312
+ float64x2_t zeros_f64x2 = vdupq_n_f64(0);
313
+ a_real_f64x2 = vzip1q_f64(a_tail.f64x2, zeros_f64x2);
314
+ a_imag_f64x2 = vzip2q_f64(a_tail.f64x2, zeros_f64x2);
315
+ b_real_f64x2 = vzip1q_f64(b_tail.f64x2, zeros_f64x2);
316
+ b_imag_f64x2 = vzip2q_f64(b_tail.f64x2, zeros_f64x2);
310
317
  count_pairs = 0;
311
318
  }
312
319
  else {
@@ -385,11 +392,11 @@ nk_vdot_f64c_neon_cycle:
385
392
  nk_b128_vec_t a_tail, b_tail;
386
393
  nk_partial_load_b64x2_serial_(a_pairs, &a_tail, count_pairs * 2);
387
394
  nk_partial_load_b64x2_serial_(b_pairs, &b_tail, count_pairs * 2);
388
- float64x2_t zeros = vdupq_n_f64(0);
389
- a_real_f64x2 = vzip1q_f64(a_tail.f64x2, zeros);
390
- a_imag_f64x2 = vzip2q_f64(a_tail.f64x2, zeros);
391
- b_real_f64x2 = vzip1q_f64(b_tail.f64x2, zeros);
392
- b_imag_f64x2 = vzip2q_f64(b_tail.f64x2, zeros);
395
+ float64x2_t zeros_f64x2 = vdupq_n_f64(0);
396
+ a_real_f64x2 = vzip1q_f64(a_tail.f64x2, zeros_f64x2);
397
+ a_imag_f64x2 = vzip2q_f64(a_tail.f64x2, zeros_f64x2);
398
+ b_real_f64x2 = vzip1q_f64(b_tail.f64x2, zeros_f64x2);
399
+ b_imag_f64x2 = vzip2q_f64(b_tail.f64x2, zeros_f64x2);
393
400
  count_pairs = 0;
394
401
  }
395
402
  else {
@@ -505,9 +512,9 @@ NK_INTERNAL void nk_dot_f64x2_finalize_neon(
505
512
  result->f64s[3] = nk_dot_stable_sum_f64x2_neon_(state_d->sum_f64x2, state_d->compensation_f64x2);
506
513
  }
507
514
 
508
- #pragma endregion - Traditional Floats
515
+ #pragma endregion F32 and F64 Floats
509
516
 
510
- #pragma region - Smaller Floats
517
+ #pragma region F16 and BF16 Floats
511
518
 
512
519
  NK_PUBLIC void nk_dot_bf16_neon(nk_bf16_t const *a_scalars, nk_bf16_t const *b_scalars, nk_size_t count_scalars,
513
520
  nk_f32_t *result) {
@@ -528,9 +535,9 @@ nk_dot_bf16_neon_cycle:
528
535
  a_scalars += 8, b_scalars += 8, count_scalars -= 8;
529
536
  }
530
537
  float32x4_t a_low_f32x4 = vreinterpretq_f32_u32(vshll_n_u16(vget_low_u16(a_u16x8), 16));
531
- float32x4_t a_high_f32x4 = vreinterpretq_f32_u32(vshll_n_u16(vget_high_u16(a_u16x8), 16));
538
+ float32x4_t a_high_f32x4 = vreinterpretq_f32_u32(vshll_high_n_u16(a_u16x8, 16));
532
539
  float32x4_t b_low_f32x4 = vreinterpretq_f32_u32(vshll_n_u16(vget_low_u16(b_u16x8), 16));
533
- float32x4_t b_high_f32x4 = vreinterpretq_f32_u32(vshll_n_u16(vget_high_u16(b_u16x8), 16));
540
+ float32x4_t b_high_f32x4 = vreinterpretq_f32_u32(vshll_high_n_u16(b_u16x8, 16));
534
541
  sum_f32x4 = vfmaq_f32(sum_f32x4, a_low_f32x4, b_low_f32x4);
535
542
  sum_f32x4 = vfmaq_f32(sum_f32x4, a_high_f32x4, b_high_f32x4);
536
543
  if (count_scalars) goto nk_dot_bf16_neon_cycle;
@@ -555,9 +562,9 @@ NK_INTERNAL void nk_dot_bf16x8_update_neon(nk_dot_bf16x8_state_neon_t *state, nk
555
562
  nk_unused_(active_dimensions);
556
563
  // Convert bf16 to f32 via USHLL shift-16 (low and high halves)
557
564
  float32x4_t a_low_f32x4 = vreinterpretq_f32_u32(vshll_n_u16(vget_low_u16(a.u16x8), 16));
558
- float32x4_t a_high_f32x4 = vreinterpretq_f32_u32(vshll_n_u16(vget_high_u16(a.u16x8), 16));
565
+ float32x4_t a_high_f32x4 = vreinterpretq_f32_u32(vshll_high_n_u16(a.u16x8, 16));
559
566
  float32x4_t b_low_f32x4 = vreinterpretq_f32_u32(vshll_n_u16(vget_low_u16(b.u16x8), 16));
560
- float32x4_t b_high_f32x4 = vreinterpretq_f32_u32(vshll_n_u16(vget_high_u16(b.u16x8), 16));
567
+ float32x4_t b_high_f32x4 = vreinterpretq_f32_u32(vshll_high_n_u16(b.u16x8, 16));
561
568
  state->sum_f32x4 = vfmaq_f32(state->sum_f32x4, a_low_f32x4, b_low_f32x4);
562
569
  state->sum_f32x4 = vfmaq_f32(state->sum_f32x4, a_high_f32x4, b_high_f32x4);
563
570
  }
@@ -567,10 +574,9 @@ NK_INTERNAL void nk_dot_bf16x8_finalize_neon(
567
574
  nk_dot_bf16x8_state_neon_t const *state_c, nk_dot_bf16x8_state_neon_t const *state_d, //
568
575
  nk_size_t total_dimensions, nk_b128_vec_t *result) {
569
576
  nk_unused_(total_dimensions);
570
- result->f32s[0] = vaddvq_f32(state_a->sum_f32x4);
571
- result->f32s[1] = vaddvq_f32(state_b->sum_f32x4);
572
- result->f32s[2] = vaddvq_f32(state_c->sum_f32x4);
573
- result->f32s[3] = vaddvq_f32(state_d->sum_f32x4);
577
+ float32x4_t ab_f32x4 = vpaddq_f32(state_a->sum_f32x4, state_b->sum_f32x4);
578
+ float32x4_t cd_f32x4 = vpaddq_f32(state_c->sum_f32x4, state_d->sum_f32x4);
579
+ result->f32x4 = vpaddq_f32(ab_f32x4, cd_f32x4);
574
580
  }
575
581
 
576
582
  NK_PUBLIC void nk_dot_f16_neon(nk_f16_t const *a_scalars, nk_f16_t const *b_scalars, nk_size_t count_scalars,
@@ -591,10 +597,12 @@ nk_dot_f16_neon_cycle:
591
597
  b_u16x8 = vld1q_u16((nk_u16_t const *)b_scalars);
592
598
  a_scalars += 8, b_scalars += 8, count_scalars -= 8;
593
599
  }
594
- float32x4_t a_low_f32x4 = nk_f16x4_to_f32x4_neon_(vget_low_u16(a_u16x8));
595
- float32x4_t a_high_f32x4 = nk_f16x4_to_f32x4_neon_(vget_high_u16(a_u16x8));
596
- float32x4_t b_low_f32x4 = nk_f16x4_to_f32x4_neon_(vget_low_u16(b_u16x8));
597
- float32x4_t b_high_f32x4 = nk_f16x4_to_f32x4_neon_(vget_high_u16(b_u16x8));
600
+ float16x8_t a_f16x8 = vreinterpretq_f16_u16(a_u16x8);
601
+ float16x8_t b_f16x8 = vreinterpretq_f16_u16(b_u16x8);
602
+ float32x4_t a_low_f32x4 = vcvt_f32_f16(vget_low_f16(a_f16x8));
603
+ float32x4_t a_high_f32x4 = vcvt_high_f32_f16(a_f16x8);
604
+ float32x4_t b_low_f32x4 = vcvt_f32_f16(vget_low_f16(b_f16x8));
605
+ float32x4_t b_high_f32x4 = vcvt_high_f32_f16(b_f16x8);
598
606
  sum_f32x4 = vfmaq_f32(sum_f32x4, a_low_f32x4, b_low_f32x4);
599
607
  sum_f32x4 = vfmaq_f32(sum_f32x4, a_high_f32x4, b_high_f32x4);
600
608
  if (count_scalars) goto nk_dot_f16_neon_cycle;
@@ -604,8 +612,8 @@ nk_dot_f16_neon_cycle:
604
612
  /**
605
613
  * @brief Running state for 128-bit dot accumulation over f16 scalars on plain NEON.
606
614
  *
607
- * Processes 8 f16 values at a time (128 bits), converting to f32 via integer bit
608
- * manipulation for accumulation without requiring the ARMv8.2-A FP16 extension.
615
+ * Processes 8 f16 values at a time (128 bits), converting to f32 via FCVTL
616
+ * for accumulation without requiring the ARMv8.2-A FP16 arithmetic extension.
609
617
  */
610
618
  typedef struct nk_dot_f16x8_state_neon_t {
611
619
  float32x4_t sum_f32x4;
@@ -617,11 +625,13 @@ NK_INTERNAL void nk_dot_f16x8_update_neon(nk_dot_f16x8_state_neon_t *state, nk_b
617
625
  nk_size_t depth_offset, nk_size_t active_dimensions) {
618
626
  nk_unused_(depth_offset);
619
627
  nk_unused_(active_dimensions);
620
- // Convert f16 to f32 via integer bit manipulation (low and high halves)
621
- float32x4_t a_low_f32x4 = nk_f16x4_to_f32x4_neon_(vget_low_u16(a.u16x8));
622
- float32x4_t a_high_f32x4 = nk_f16x4_to_f32x4_neon_(vget_high_u16(a.u16x8));
623
- float32x4_t b_low_f32x4 = nk_f16x4_to_f32x4_neon_(vget_low_u16(b.u16x8));
624
- float32x4_t b_high_f32x4 = nk_f16x4_to_f32x4_neon_(vget_high_u16(b.u16x8));
628
+ // Convert f16 to f32 via FCVTL / FCVTL2 (low and high halves)
629
+ float16x8_t a_f16x8 = vreinterpretq_f16_u16(a.u16x8);
630
+ float16x8_t b_f16x8 = vreinterpretq_f16_u16(b.u16x8);
631
+ float32x4_t a_low_f32x4 = vcvt_f32_f16(vget_low_f16(a_f16x8));
632
+ float32x4_t a_high_f32x4 = vcvt_high_f32_f16(a_f16x8);
633
+ float32x4_t b_low_f32x4 = vcvt_f32_f16(vget_low_f16(b_f16x8));
634
+ float32x4_t b_high_f32x4 = vcvt_high_f32_f16(b_f16x8);
625
635
  state->sum_f32x4 = vfmaq_f32(state->sum_f32x4, a_low_f32x4, b_low_f32x4);
626
636
  state->sum_f32x4 = vfmaq_f32(state->sum_f32x4, a_high_f32x4, b_high_f32x4);
627
637
  }
@@ -631,10 +641,9 @@ NK_INTERNAL void nk_dot_f16x8_finalize_neon(
631
641
  nk_dot_f16x8_state_neon_t const *state_c, nk_dot_f16x8_state_neon_t const *state_d, //
632
642
  nk_size_t total_dimensions, nk_b128_vec_t *result) {
633
643
  nk_unused_(total_dimensions);
634
- result->f32s[0] = vaddvq_f32(state_a->sum_f32x4);
635
- result->f32s[1] = vaddvq_f32(state_b->sum_f32x4);
636
- result->f32s[2] = vaddvq_f32(state_c->sum_f32x4);
637
- result->f32s[3] = vaddvq_f32(state_d->sum_f32x4);
644
+ float32x4_t ab_f32x4 = vpaddq_f32(state_a->sum_f32x4, state_b->sum_f32x4);
645
+ float32x4_t cd_f32x4 = vpaddq_f32(state_c->sum_f32x4, state_d->sum_f32x4);
646
+ result->f32x4 = vpaddq_f32(ab_f32x4, cd_f32x4);
638
647
  }
639
648
 
640
649
  NK_PUBLIC void nk_dot_e4m3_neon(nk_e4m3_t const *a_scalars, nk_e4m3_t const *b_scalars, nk_size_t count_scalars,
@@ -656,9 +665,9 @@ nk_dot_e4m3_neon_cycle:
656
665
  a_scalars += 8, b_scalars += 8, count_scalars -= 8;
657
666
  }
658
667
  float32x4_t a_low_f32x4 = vcvt_f32_f16(vget_low_f16(a_f16x8));
659
- float32x4_t a_high_f32x4 = vcvt_f32_f16(vget_high_f16(a_f16x8));
668
+ float32x4_t a_high_f32x4 = vcvt_high_f32_f16(a_f16x8);
660
669
  float32x4_t b_low_f32x4 = vcvt_f32_f16(vget_low_f16(b_f16x8));
661
- float32x4_t b_high_f32x4 = vcvt_f32_f16(vget_high_f16(b_f16x8));
670
+ float32x4_t b_high_f32x4 = vcvt_high_f32_f16(b_f16x8);
662
671
  sum_f32x4 = vfmaq_f32(sum_f32x4, a_low_f32x4, b_low_f32x4);
663
672
  sum_f32x4 = vfmaq_f32(sum_f32x4, a_high_f32x4, b_high_f32x4);
664
673
  if (count_scalars) goto nk_dot_e4m3_neon_cycle;
@@ -684,9 +693,9 @@ nk_dot_e5m2_neon_cycle:
684
693
  a_scalars += 8, b_scalars += 8, count_scalars -= 8;
685
694
  }
686
695
  float32x4_t a_low_f32x4 = vcvt_f32_f16(vget_low_f16(a_f16x8));
687
- float32x4_t a_high_f32x4 = vcvt_f32_f16(vget_high_f16(a_f16x8));
696
+ float32x4_t a_high_f32x4 = vcvt_high_f32_f16(a_f16x8);
688
697
  float32x4_t b_low_f32x4 = vcvt_f32_f16(vget_low_f16(b_f16x8));
689
- float32x4_t b_high_f32x4 = vcvt_f32_f16(vget_high_f16(b_f16x8));
698
+ float32x4_t b_high_f32x4 = vcvt_high_f32_f16(b_f16x8);
690
699
  sum_f32x4 = vfmaq_f32(sum_f32x4, a_low_f32x4, b_low_f32x4);
691
700
  sum_f32x4 = vfmaq_f32(sum_f32x4, a_high_f32x4, b_high_f32x4);
692
701
  if (count_scalars) goto nk_dot_e5m2_neon_cycle;
@@ -713,12 +722,10 @@ nk_dot_e2m3_neon_cycle:
713
722
  a_scalars += 16, b_scalars += 16, count_scalars -= 16;
714
723
  }
715
724
  sum_f32x4 = vfmaq_f32(sum_f32x4, vcvt_f32_f16(vget_low_f16(a_low_f16x8)), vcvt_f32_f16(vget_low_f16(b_low_f16x8)));
716
- sum_f32x4 = vfmaq_f32(sum_f32x4, vcvt_f32_f16(vget_high_f16(a_low_f16x8)),
717
- vcvt_f32_f16(vget_high_f16(b_low_f16x8)));
725
+ sum_f32x4 = vfmaq_f32(sum_f32x4, vcvt_high_f32_f16(a_low_f16x8), vcvt_high_f32_f16(b_low_f16x8));
718
726
  sum_f32x4 = vfmaq_f32(sum_f32x4, vcvt_f32_f16(vget_low_f16(a_high_f16x8)),
719
727
  vcvt_f32_f16(vget_low_f16(b_high_f16x8)));
720
- sum_f32x4 = vfmaq_f32(sum_f32x4, vcvt_f32_f16(vget_high_f16(a_high_f16x8)),
721
- vcvt_f32_f16(vget_high_f16(b_high_f16x8)));
728
+ sum_f32x4 = vfmaq_f32(sum_f32x4, vcvt_high_f32_f16(a_high_f16x8), vcvt_high_f32_f16(b_high_f16x8));
722
729
  if (count_scalars) goto nk_dot_e2m3_neon_cycle;
723
730
  *result = vaddvq_f32(sum_f32x4);
724
731
  }
@@ -743,19 +750,17 @@ nk_dot_e3m2_neon_cycle:
743
750
  a_scalars += 16, b_scalars += 16, count_scalars -= 16;
744
751
  }
745
752
  sum_f32x4 = vfmaq_f32(sum_f32x4, vcvt_f32_f16(vget_low_f16(a_low_f16x8)), vcvt_f32_f16(vget_low_f16(b_low_f16x8)));
746
- sum_f32x4 = vfmaq_f32(sum_f32x4, vcvt_f32_f16(vget_high_f16(a_low_f16x8)),
747
- vcvt_f32_f16(vget_high_f16(b_low_f16x8)));
753
+ sum_f32x4 = vfmaq_f32(sum_f32x4, vcvt_high_f32_f16(a_low_f16x8), vcvt_high_f32_f16(b_low_f16x8));
748
754
  sum_f32x4 = vfmaq_f32(sum_f32x4, vcvt_f32_f16(vget_low_f16(a_high_f16x8)),
749
755
  vcvt_f32_f16(vget_low_f16(b_high_f16x8)));
750
- sum_f32x4 = vfmaq_f32(sum_f32x4, vcvt_f32_f16(vget_high_f16(a_high_f16x8)),
751
- vcvt_f32_f16(vget_high_f16(b_high_f16x8)));
756
+ sum_f32x4 = vfmaq_f32(sum_f32x4, vcvt_high_f32_f16(a_high_f16x8), vcvt_high_f32_f16(b_high_f16x8));
752
757
  if (count_scalars) goto nk_dot_e3m2_neon_cycle;
753
758
  *result = vaddvq_f32(sum_f32x4);
754
759
  }
755
760
 
756
- #pragma endregion - Smaller Floats
761
+ #pragma endregion F16 and BF16 Floats
757
762
 
758
- #pragma region - Binary
763
+ #pragma region Binary
759
764
 
760
765
  NK_PUBLIC void nk_dot_u1_neon(nk_u1x8_t const *a, nk_u1x8_t const *b, nk_size_t n_bits, nk_u32_t *result) {
761
766
  nk_size_t n_bytes = nk_size_divide_round_up_(n_bits, NK_BITS_PER_BYTE);
@@ -801,7 +806,53 @@ NK_INTERNAL void nk_dot_u1x128_finalize_neon( //
801
806
  result->u32x4 = vpaddq_u32(ab_sum_u32x4, cd_sum_u32x4);
802
807
  }
803
808
 
804
- #pragma endregion - Binary
809
+ #pragma endregion Binary
810
+
811
+ NK_PUBLIC void nk_dot_f16c_neon(nk_f16c_t const *a_pairs, nk_f16c_t const *b_pairs, nk_size_t count_pairs,
812
+ nk_f32c_t *result) {
813
+ float32x4_t sum_real_f32x4 = vdupq_n_f32(0);
814
+ float32x4_t sum_imag_f32x4 = vdupq_n_f32(0);
815
+ while (count_pairs >= 4) {
816
+ int16x4x2_t a_i16x4x2 = vld2_s16((short *)a_pairs);
817
+ int16x4x2_t b_i16x4x2 = vld2_s16((short *)b_pairs);
818
+ float32x4_t a_real_f32x4 = vcvt_f32_f16(vreinterpret_f16_s16(a_i16x4x2.val[0]));
819
+ float32x4_t a_imag_f32x4 = vcvt_f32_f16(vreinterpret_f16_s16(a_i16x4x2.val[1]));
820
+ float32x4_t b_real_f32x4 = vcvt_f32_f16(vreinterpret_f16_s16(b_i16x4x2.val[0]));
821
+ float32x4_t b_imag_f32x4 = vcvt_f32_f16(vreinterpret_f16_s16(b_i16x4x2.val[1]));
822
+ sum_real_f32x4 = vfmaq_f32(sum_real_f32x4, a_real_f32x4, b_real_f32x4);
823
+ sum_real_f32x4 = vfmsq_f32(sum_real_f32x4, a_imag_f32x4, b_imag_f32x4);
824
+ sum_imag_f32x4 = vfmaq_f32(sum_imag_f32x4, a_real_f32x4, b_imag_f32x4);
825
+ sum_imag_f32x4 = vfmaq_f32(sum_imag_f32x4, a_imag_f32x4, b_real_f32x4);
826
+ count_pairs -= 4, a_pairs += 4, b_pairs += 4;
827
+ }
828
+ nk_f32c_t tail_result;
829
+ nk_dot_f16c_serial(a_pairs, b_pairs, count_pairs, &tail_result);
830
+ result->real = tail_result.real + vaddvq_f32(sum_real_f32x4);
831
+ result->imag = tail_result.imag + vaddvq_f32(sum_imag_f32x4);
832
+ }
833
+
834
+ NK_PUBLIC void nk_vdot_f16c_neon(nk_f16c_t const *a_pairs, nk_f16c_t const *b_pairs, nk_size_t count_pairs,
835
+ nk_f32c_t *result) {
836
+ float32x4_t sum_real_f32x4 = vdupq_n_f32(0);
837
+ float32x4_t sum_imag_f32x4 = vdupq_n_f32(0);
838
+ while (count_pairs >= 4) {
839
+ int16x4x2_t a_i16x4x2 = vld2_s16((short *)a_pairs);
840
+ int16x4x2_t b_i16x4x2 = vld2_s16((short *)b_pairs);
841
+ float32x4_t a_real_f32x4 = vcvt_f32_f16(vreinterpret_f16_s16(a_i16x4x2.val[0]));
842
+ float32x4_t a_imag_f32x4 = vcvt_f32_f16(vreinterpret_f16_s16(a_i16x4x2.val[1]));
843
+ float32x4_t b_real_f32x4 = vcvt_f32_f16(vreinterpret_f16_s16(b_i16x4x2.val[0]));
844
+ float32x4_t b_imag_f32x4 = vcvt_f32_f16(vreinterpret_f16_s16(b_i16x4x2.val[1]));
845
+ sum_real_f32x4 = vfmaq_f32(sum_real_f32x4, a_real_f32x4, b_real_f32x4);
846
+ sum_real_f32x4 = vfmaq_f32(sum_real_f32x4, a_imag_f32x4, b_imag_f32x4);
847
+ sum_imag_f32x4 = vfmaq_f32(sum_imag_f32x4, a_real_f32x4, b_imag_f32x4);
848
+ sum_imag_f32x4 = vfmsq_f32(sum_imag_f32x4, a_imag_f32x4, b_real_f32x4);
849
+ count_pairs -= 4, a_pairs += 4, b_pairs += 4;
850
+ }
851
+ nk_f32c_t tail_result;
852
+ nk_vdot_f16c_serial(a_pairs, b_pairs, count_pairs, &tail_result);
853
+ result->real = tail_result.real + vaddvq_f32(sum_real_f32x4);
854
+ result->imag = tail_result.imag + vaddvq_f32(sum_imag_f32x4);
855
+ }
805
856
 
806
857
  #if defined(__clang__)
807
858
  #pragma clang attribute pop
@@ -8,14 +8,14 @@
8
8
  *
9
9
  * @section dot_neonbfdot_instructions ARM NEON BF16 Instructions (ARMv8.6-BF16)
10
10
  *
11
- * Intrinsic Instruction Latency Throughput
12
- * A76 M4+/V1+/Oryon
13
- * vbfdotq_f32 BFDOT (V.4S, V.8H, V.8H) 3cy 2/cy 4/cy
14
- * vcvt_f32_bf16 BFCVTN (V.4H, V.4S) 3cy 2/cy 4/cy
15
- * vld1q_bf16 LD1 (V.8H) 4cy 2/cy 3/cy
16
- * vaddvq_f32 FADDP+FADDP (V.4S) 4cy 1/cy 2/cy
17
- * vfmaq_f32 FMLA (V.4S, V.4S, V.4S) 4cy 2/cy 4/cy
18
- * vfmsq_f32 FMLS (V.4S, V.4S, V.4S) 4cy 2/cy 4/cy
11
+ * Intrinsic Instruction A76 M5
12
+ * vbfdotq_f32 BFDOT (V.4S, V.8H, V.8H) 3cy @ 2p 2cy @ 1p
13
+ * vcvt_f32_bf16 BFCVTN (V.4H, V.4S) 3cy @ 2p 3cy @ 4p
14
+ * vld1q_bf16 LD1 (V.8H) 4cy @ 2p 4cy @ 3p
15
+ * vaddvq_f32 FADDP+FADDP (V.4S) 4cy @ 1p 8cy @ 1p
16
+ * vpaddq_f32 FADDP (V.4S, V.4S, V.4S) 2cy @ 2p 3cy @ 4p
17
+ * vfmaq_f32 FMLA (V.4S, V.4S, V.4S) 4cy @ 2p 3cy @ 4p
18
+ * vfmsq_f32 FMLS (V.4S, V.4S, V.4S) 4cy @ 2p 3cy @ 4p
19
19
  *
20
20
  * The ARMv8.6-BF16 extension provides the BFDOT instruction for accelerated BF16 dot products,
21
21
  * targeting machine learning inference workloads. BF16 trades mantissa precision (7 bits vs 10 in
@@ -223,10 +223,9 @@ NK_INTERNAL void nk_dot_bf16x8_finalize_neonbfdot(
223
223
  nk_dot_bf16x8_state_neonbfdot_t const *state_c, nk_dot_bf16x8_state_neonbfdot_t const *state_d, //
224
224
  nk_size_t total_dimensions, nk_b128_vec_t *result) {
225
225
  nk_unused_(total_dimensions);
226
- result->f32s[0] = vaddvq_f32(state_a->sum_f32x4);
227
- result->f32s[1] = vaddvq_f32(state_b->sum_f32x4);
228
- result->f32s[2] = vaddvq_f32(state_c->sum_f32x4);
229
- result->f32s[3] = vaddvq_f32(state_d->sum_f32x4);
226
+ float32x4_t ab_f32x4 = vpaddq_f32(state_a->sum_f32x4, state_b->sum_f32x4);
227
+ float32x4_t cd_f32x4 = vpaddq_f32(state_c->sum_f32x4, state_d->sum_f32x4);
228
+ result->f32x4 = vpaddq_f32(ab_f32x4, cd_f32x4);
230
229
  }
231
230
 
232
231
  #if defined(__clang__)
@@ -8,14 +8,15 @@
8
8
  *
9
9
  * @section dot_neonfhm_instructions ARM NEON FP16 Matrix Instructions (ARMv8.4-FHM)
10
10
  *
11
- * Intrinsic Instruction Latency Throughput
12
- * A76 M4+/V1+/Oryon
13
- * vfmlalq_low_f16 FMLAL (V.4S, V.8H, V.8H) 4cy 2/cy 4/cy
14
- * vfmlalq_high_f16 FMLAL2 (V.4S, V.8H, V.8H) 4cy 2/cy 4/cy
15
- * vfmlslq_low_f16 FMLSL (V.4S, V.8H, V.8H) 4cy 2/cy 4/cy
16
- * vfmlslq_high_f16 FMLSL2 (V.4S, V.8H, V.8H) 4cy 2/cy 4/cy
17
- * vld1q_f16 LD1 (V.8H) 4cy 2/cy 3/cy
18
- * vaddvq_f32 FADDP+FADDP (V.4S) 4cy 1/cy 2/cy
11
+ * Intrinsic Instruction A76 M5
12
+ * vfmlalq_low_f16 FMLAL (V.4S, V.8H, V.8H) 4cy @ 2p 4cy @ 4p
13
+ * vfmlalq_high_f16 FMLAL2 (V.4S, V.8H, V.8H) 4cy @ 2p 4cy @ 4p
14
+ * vfmlslq_low_f16 FMLSL (V.4S, V.8H, V.8H) 4cy @ 2p 4cy @ 4p
15
+ * vfmlslq_high_f16 FMLSL2 (V.4S, V.8H, V.8H) 4cy @ 2p 4cy @ 4p
16
+ * vld1q_f16 LD1 (V.8H) 4cy @ 2p 4cy @ 3p
17
+ * vaddvq_f32 FADDP+FADDP (V.4S) 4cy @ 1p 8cy @ 1p
18
+ * vpaddq_f32 FADDP (V.4S, V.4S, V.4S) 2cy @ 2p 3cy @ 4p
19
+ * vshll_n_u8 SHLL (V.8H, V.8B, #8) 2cy @ 2p 2cy @ 4p
19
20
  *
20
21
  * The ARMv8.4-FHM extension (FEAT_FHM) provides FMLAL/FMLSL instructions that fuse FP16 to FP32
21
22
  * widening with multiply-accumulate in a single operation. FMLAL executes as a single fused op
@@ -90,8 +91,8 @@ nk_dot_f16_neonfhm_cycle:
90
91
  count_scalars = 0;
91
92
  }
92
93
  else {
93
- a_f16x8 = vld1q_f16((nk_f16_for_arm_simd_t const *)(a_scalars));
94
- b_f16x8 = vld1q_f16((nk_f16_for_arm_simd_t const *)(b_scalars));
94
+ a_f16x8 = vreinterpretq_f16_u16(vld1q_u16((nk_u16_t const *)(a_scalars)));
95
+ b_f16x8 = vreinterpretq_f16_u16(vld1q_u16((nk_u16_t const *)(b_scalars)));
95
96
  a_scalars += 8, b_scalars += 8, count_scalars -= 8;
96
97
  }
97
98
  // FMLAL: widening multiply-accumulate fp16 → f32
@@ -124,10 +125,9 @@ NK_INTERNAL void nk_dot_f16x8_finalize_neonfhm(
124
125
  nk_dot_f16x8_state_neonfhm_t const *state_c, nk_dot_f16x8_state_neonfhm_t const *state_d, //
125
126
  nk_size_t total_dimensions, nk_b128_vec_t *result) {
126
127
  nk_unused_(total_dimensions);
127
- result->f32s[0] = vaddvq_f32(state_a->sum_f32x4);
128
- result->f32s[1] = vaddvq_f32(state_b->sum_f32x4);
129
- result->f32s[2] = vaddvq_f32(state_c->sum_f32x4);
130
- result->f32s[3] = vaddvq_f32(state_d->sum_f32x4);
128
+ float32x4_t ab_f32x4 = vpaddq_f32(state_a->sum_f32x4, state_b->sum_f32x4);
129
+ float32x4_t cd_f32x4 = vpaddq_f32(state_c->sum_f32x4, state_d->sum_f32x4);
130
+ result->f32x4 = vpaddq_f32(ab_f32x4, cd_f32x4);
131
131
  }
132
132
 
133
133
  NK_PUBLIC void nk_dot_f16c_neonfhm(nk_f16c_t const *a_pairs, nk_f16c_t const *b_pairs, nk_size_t count_pairs,
@@ -220,58 +220,58 @@ NK_PUBLIC void nk_vdot_f16c_neonfhm(nk_f16c_t const *a_pairs, nk_f16c_t const *b
220
220
 
221
221
  NK_PUBLIC void nk_dot_e4m3_neonfhm(nk_e4m3_t const *a_scalars, nk_e4m3_t const *b_scalars, nk_size_t count_scalars,
222
222
  nk_f32_t *result) {
223
- float16x8_t a_low, a_high, b_low, b_high;
223
+ float16x8_t a_low_f16x8, a_high_f16x8, b_low_f16x8, b_high_f16x8;
224
224
  float32x4_t sum_f32x4 = vdupq_n_f32(0);
225
225
  nk_dot_e4m3_neonfhm_cycle:
226
226
  if (count_scalars < 16) {
227
227
  nk_b128_vec_t a_vec, b_vec;
228
228
  nk_partial_load_b8x16_serial_(a_scalars, &a_vec, count_scalars);
229
229
  nk_partial_load_b8x16_serial_(b_scalars, &b_vec, count_scalars);
230
- nk_e4m3x16_to_f16x8x2_neon_(a_vec.u8x16, &a_low, &a_high);
231
- nk_e4m3x16_to_f16x8x2_neon_(b_vec.u8x16, &b_low, &b_high);
230
+ nk_e4m3x16_to_f16x8x2_neon_(a_vec.u8x16, &a_low_f16x8, &a_high_f16x8);
231
+ nk_e4m3x16_to_f16x8x2_neon_(b_vec.u8x16, &b_low_f16x8, &b_high_f16x8);
232
232
  count_scalars = 0;
233
233
  }
234
234
  else {
235
- nk_e4m3x16_to_f16x8x2_neon_(vld1q_u8(a_scalars), &a_low, &a_high);
236
- nk_e4m3x16_to_f16x8x2_neon_(vld1q_u8(b_scalars), &b_low, &b_high);
235
+ nk_e4m3x16_to_f16x8x2_neon_(vld1q_u8(a_scalars), &a_low_f16x8, &a_high_f16x8);
236
+ nk_e4m3x16_to_f16x8x2_neon_(vld1q_u8(b_scalars), &b_low_f16x8, &b_high_f16x8);
237
237
  a_scalars += 16, b_scalars += 16, count_scalars -= 16;
238
238
  }
239
- sum_f32x4 = vfmlalq_low_f16(sum_f32x4, a_low, b_low);
240
- sum_f32x4 = vfmlalq_high_f16(sum_f32x4, a_low, b_low);
241
- sum_f32x4 = vfmlalq_low_f16(sum_f32x4, a_high, b_high);
242
- sum_f32x4 = vfmlalq_high_f16(sum_f32x4, a_high, b_high);
239
+ sum_f32x4 = vfmlalq_low_f16(sum_f32x4, a_low_f16x8, b_low_f16x8);
240
+ sum_f32x4 = vfmlalq_high_f16(sum_f32x4, a_low_f16x8, b_low_f16x8);
241
+ sum_f32x4 = vfmlalq_low_f16(sum_f32x4, a_high_f16x8, b_high_f16x8);
242
+ sum_f32x4 = vfmlalq_high_f16(sum_f32x4, a_high_f16x8, b_high_f16x8);
243
243
  if (count_scalars) goto nk_dot_e4m3_neonfhm_cycle;
244
244
  *result = vaddvq_f32(sum_f32x4);
245
245
  }
246
246
 
247
247
  NK_PUBLIC void nk_dot_e5m2_neonfhm(nk_e5m2_t const *a_scalars, nk_e5m2_t const *b_scalars, nk_size_t count_scalars,
248
248
  nk_f32_t *result) {
249
- float16x8_t a_low, a_high, b_low, b_high;
249
+ float16x8_t a_low_f16x8, a_high_f16x8, b_low_f16x8, b_high_f16x8;
250
250
  float32x4_t sum_f32x4 = vdupq_n_f32(0);
251
251
  nk_dot_e5m2_neonfhm_cycle:
252
252
  if (count_scalars < 16) {
253
253
  nk_b128_vec_t a_vec, b_vec;
254
254
  nk_partial_load_b8x16_serial_(a_scalars, &a_vec, count_scalars);
255
255
  nk_partial_load_b8x16_serial_(b_scalars, &b_vec, count_scalars);
256
- a_low = vreinterpretq_f16_u16(vshll_n_u8(vget_low_u8(a_vec.u8x16), 8));
257
- a_high = vreinterpretq_f16_u16(vshll_n_u8(vget_high_u8(a_vec.u8x16), 8));
258
- b_low = vreinterpretq_f16_u16(vshll_n_u8(vget_low_u8(b_vec.u8x16), 8));
259
- b_high = vreinterpretq_f16_u16(vshll_n_u8(vget_high_u8(b_vec.u8x16), 8));
256
+ a_low_f16x8 = vreinterpretq_f16_u16(vshll_n_u8(vget_low_u8(a_vec.u8x16), 8));
257
+ a_high_f16x8 = vreinterpretq_f16_u16(vshll_high_n_u8(a_vec.u8x16, 8));
258
+ b_low_f16x8 = vreinterpretq_f16_u16(vshll_n_u8(vget_low_u8(b_vec.u8x16), 8));
259
+ b_high_f16x8 = vreinterpretq_f16_u16(vshll_high_n_u8(b_vec.u8x16, 8));
260
260
  count_scalars = 0;
261
261
  }
262
262
  else {
263
263
  uint8x16_t a_u8x16 = vld1q_u8(a_scalars);
264
264
  uint8x16_t b_u8x16 = vld1q_u8(b_scalars);
265
- a_low = vreinterpretq_f16_u16(vshll_n_u8(vget_low_u8(a_u8x16), 8));
266
- a_high = vreinterpretq_f16_u16(vshll_n_u8(vget_high_u8(a_u8x16), 8));
267
- b_low = vreinterpretq_f16_u16(vshll_n_u8(vget_low_u8(b_u8x16), 8));
268
- b_high = vreinterpretq_f16_u16(vshll_n_u8(vget_high_u8(b_u8x16), 8));
265
+ a_low_f16x8 = vreinterpretq_f16_u16(vshll_n_u8(vget_low_u8(a_u8x16), 8));
266
+ a_high_f16x8 = vreinterpretq_f16_u16(vshll_high_n_u8(a_u8x16, 8));
267
+ b_low_f16x8 = vreinterpretq_f16_u16(vshll_n_u8(vget_low_u8(b_u8x16), 8));
268
+ b_high_f16x8 = vreinterpretq_f16_u16(vshll_high_n_u8(b_u8x16, 8));
269
269
  a_scalars += 16, b_scalars += 16, count_scalars -= 16;
270
270
  }
271
- sum_f32x4 = vfmlalq_low_f16(sum_f32x4, a_low, b_low);
272
- sum_f32x4 = vfmlalq_high_f16(sum_f32x4, a_low, b_low);
273
- sum_f32x4 = vfmlalq_low_f16(sum_f32x4, a_high, b_high);
274
- sum_f32x4 = vfmlalq_high_f16(sum_f32x4, a_high, b_high);
271
+ sum_f32x4 = vfmlalq_low_f16(sum_f32x4, a_low_f16x8, b_low_f16x8);
272
+ sum_f32x4 = vfmlalq_high_f16(sum_f32x4, a_low_f16x8, b_low_f16x8);
273
+ sum_f32x4 = vfmlalq_low_f16(sum_f32x4, a_high_f16x8, b_high_f16x8);
274
+ sum_f32x4 = vfmlalq_high_f16(sum_f32x4, a_high_f16x8, b_high_f16x8);
275
275
  if (count_scalars) goto nk_dot_e5m2_neonfhm_cycle;
276
276
  *result = vaddvq_f32(sum_f32x4);
277
277
  }
@@ -304,10 +304,9 @@ NK_INTERNAL void nk_dot_e4m3x16_finalize_neonfhm(
304
304
  nk_dot_e4m3x16_state_neonfhm_t const *state_c, nk_dot_e4m3x16_state_neonfhm_t const *state_d, //
305
305
  nk_size_t total_dimensions, nk_b128_vec_t *result) {
306
306
  nk_unused_(total_dimensions);
307
- result->f32s[0] = vaddvq_f32(state_a->sum_f32x4);
308
- result->f32s[1] = vaddvq_f32(state_b->sum_f32x4);
309
- result->f32s[2] = vaddvq_f32(state_c->sum_f32x4);
310
- result->f32s[3] = vaddvq_f32(state_d->sum_f32x4);
307
+ float32x4_t ab_f32x4 = vpaddq_f32(state_a->sum_f32x4, state_b->sum_f32x4);
308
+ float32x4_t cd_f32x4 = vpaddq_f32(state_c->sum_f32x4, state_d->sum_f32x4);
309
+ result->f32x4 = vpaddq_f32(ab_f32x4, cd_f32x4);
311
310
  }
312
311
 
313
312
  typedef struct nk_dot_e5m2x16_state_neonfhm_t {
@@ -324,9 +323,9 @@ NK_INTERNAL void nk_dot_e5m2x16_update_neonfhm(nk_dot_e5m2x16_state_neonfhm_t *s
324
323
  nk_unused_(active_dimensions);
325
324
  // Convert e5m2 → f16 via SHLL: widen u8→u16 and shift left 8 in one instruction
326
325
  float16x8_t a_low_f16x8 = vreinterpretq_f16_u16(vshll_n_u8(vget_low_u8(a.u8x16), 8));
327
- float16x8_t a_high_f16x8 = vreinterpretq_f16_u16(vshll_n_u8(vget_high_u8(a.u8x16), 8));
326
+ float16x8_t a_high_f16x8 = vreinterpretq_f16_u16(vshll_high_n_u8(a.u8x16, 8));
328
327
  float16x8_t b_low_f16x8 = vreinterpretq_f16_u16(vshll_n_u8(vget_low_u8(b.u8x16), 8));
329
- float16x8_t b_high_f16x8 = vreinterpretq_f16_u16(vshll_n_u8(vget_high_u8(b.u8x16), 8));
328
+ float16x8_t b_high_f16x8 = vreinterpretq_f16_u16(vshll_high_n_u8(b.u8x16, 8));
330
329
  // FMLAL: widening multiply-accumulate fp16 → f32
331
330
  state->sum_f32x4 = vfmlalq_low_f16(state->sum_f32x4, a_low_f16x8, b_low_f16x8);
332
331
  state->sum_f32x4 = vfmlalq_high_f16(state->sum_f32x4, a_low_f16x8, b_low_f16x8);
@@ -339,10 +338,9 @@ NK_INTERNAL void nk_dot_e5m2x16_finalize_neonfhm(
339
338
  nk_dot_e5m2x16_state_neonfhm_t const *state_c, nk_dot_e5m2x16_state_neonfhm_t const *state_d, //
340
339
  nk_size_t total_dimensions, nk_b128_vec_t *result) {
341
340
  nk_unused_(total_dimensions);
342
- result->f32s[0] = vaddvq_f32(state_a->sum_f32x4);
343
- result->f32s[1] = vaddvq_f32(state_b->sum_f32x4);
344
- result->f32s[2] = vaddvq_f32(state_c->sum_f32x4);
345
- result->f32s[3] = vaddvq_f32(state_d->sum_f32x4);
341
+ float32x4_t ab_f32x4 = vpaddq_f32(state_a->sum_f32x4, state_b->sum_f32x4);
342
+ float32x4_t cd_f32x4 = vpaddq_f32(state_c->sum_f32x4, state_d->sum_f32x4);
343
+ result->f32x4 = vpaddq_f32(ab_f32x4, cd_f32x4);
346
344
  }
347
345
 
348
346
  #if defined(__clang__)