numkong 7.0.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (294) hide show
  1. package/LICENSE +201 -0
  2. package/README.md +495 -0
  3. package/binding.gyp +540 -0
  4. package/c/dispatch.h +512 -0
  5. package/c/dispatch_bf16.c +389 -0
  6. package/c/dispatch_bf16c.c +52 -0
  7. package/c/dispatch_e2m3.c +263 -0
  8. package/c/dispatch_e3m2.c +243 -0
  9. package/c/dispatch_e4m3.c +276 -0
  10. package/c/dispatch_e5m2.c +272 -0
  11. package/c/dispatch_f16.c +376 -0
  12. package/c/dispatch_f16c.c +58 -0
  13. package/c/dispatch_f32.c +378 -0
  14. package/c/dispatch_f32c.c +99 -0
  15. package/c/dispatch_f64.c +296 -0
  16. package/c/dispatch_f64c.c +98 -0
  17. package/c/dispatch_i16.c +96 -0
  18. package/c/dispatch_i32.c +89 -0
  19. package/c/dispatch_i4.c +150 -0
  20. package/c/dispatch_i64.c +86 -0
  21. package/c/dispatch_i8.c +289 -0
  22. package/c/dispatch_other.c +330 -0
  23. package/c/dispatch_u1.c +148 -0
  24. package/c/dispatch_u16.c +124 -0
  25. package/c/dispatch_u32.c +118 -0
  26. package/c/dispatch_u4.c +150 -0
  27. package/c/dispatch_u64.c +102 -0
  28. package/c/dispatch_u8.c +303 -0
  29. package/c/numkong.c +950 -0
  30. package/include/README.md +573 -0
  31. package/include/module.modulemap +129 -0
  32. package/include/numkong/attention/sapphireamx.h +1361 -0
  33. package/include/numkong/attention/sme.h +2066 -0
  34. package/include/numkong/attention.h +49 -0
  35. package/include/numkong/capabilities.h +748 -0
  36. package/include/numkong/cast/README.md +262 -0
  37. package/include/numkong/cast/haswell.h +975 -0
  38. package/include/numkong/cast/icelake.h +470 -0
  39. package/include/numkong/cast/neon.h +1192 -0
  40. package/include/numkong/cast/rvv.h +1021 -0
  41. package/include/numkong/cast/sapphire.h +262 -0
  42. package/include/numkong/cast/serial.h +2262 -0
  43. package/include/numkong/cast/skylake.h +856 -0
  44. package/include/numkong/cast/v128relaxed.h +180 -0
  45. package/include/numkong/cast.h +230 -0
  46. package/include/numkong/curved/README.md +223 -0
  47. package/include/numkong/curved/genoa.h +182 -0
  48. package/include/numkong/curved/haswell.h +276 -0
  49. package/include/numkong/curved/neon.h +205 -0
  50. package/include/numkong/curved/neonbfdot.h +212 -0
  51. package/include/numkong/curved/neonhalf.h +212 -0
  52. package/include/numkong/curved/rvv.h +305 -0
  53. package/include/numkong/curved/serial.h +207 -0
  54. package/include/numkong/curved/skylake.h +457 -0
  55. package/include/numkong/curved/smef64.h +506 -0
  56. package/include/numkong/curved.h +517 -0
  57. package/include/numkong/curved.hpp +144 -0
  58. package/include/numkong/dot/README.md +425 -0
  59. package/include/numkong/dot/alder.h +563 -0
  60. package/include/numkong/dot/genoa.h +315 -0
  61. package/include/numkong/dot/haswell.h +1688 -0
  62. package/include/numkong/dot/icelake.h +883 -0
  63. package/include/numkong/dot/neon.h +818 -0
  64. package/include/numkong/dot/neonbfdot.h +244 -0
  65. package/include/numkong/dot/neonfhm.h +360 -0
  66. package/include/numkong/dot/neonhalf.h +198 -0
  67. package/include/numkong/dot/neonsdot.h +508 -0
  68. package/include/numkong/dot/rvv.h +714 -0
  69. package/include/numkong/dot/rvvbb.h +72 -0
  70. package/include/numkong/dot/rvvbf16.h +123 -0
  71. package/include/numkong/dot/rvvhalf.h +129 -0
  72. package/include/numkong/dot/sapphire.h +141 -0
  73. package/include/numkong/dot/serial.h +838 -0
  74. package/include/numkong/dot/sierra.h +405 -0
  75. package/include/numkong/dot/skylake.h +1084 -0
  76. package/include/numkong/dot/sve.h +379 -0
  77. package/include/numkong/dot/svebfdot.h +74 -0
  78. package/include/numkong/dot/svehalf.h +123 -0
  79. package/include/numkong/dot/v128relaxed.h +1258 -0
  80. package/include/numkong/dot.h +1070 -0
  81. package/include/numkong/dot.hpp +94 -0
  82. package/include/numkong/dots/README.md +496 -0
  83. package/include/numkong/dots/alder.h +114 -0
  84. package/include/numkong/dots/genoa.h +94 -0
  85. package/include/numkong/dots/haswell.h +295 -0
  86. package/include/numkong/dots/icelake.h +171 -0
  87. package/include/numkong/dots/neon.h +120 -0
  88. package/include/numkong/dots/neonbfdot.h +58 -0
  89. package/include/numkong/dots/neonfhm.h +94 -0
  90. package/include/numkong/dots/neonhalf.h +57 -0
  91. package/include/numkong/dots/neonsdot.h +108 -0
  92. package/include/numkong/dots/rvv.h +2486 -0
  93. package/include/numkong/dots/sapphireamx.h +3973 -0
  94. package/include/numkong/dots/serial.h +2844 -0
  95. package/include/numkong/dots/sierra.h +97 -0
  96. package/include/numkong/dots/skylake.h +196 -0
  97. package/include/numkong/dots/sme.h +5372 -0
  98. package/include/numkong/dots/smebi32.h +461 -0
  99. package/include/numkong/dots/smef64.h +1318 -0
  100. package/include/numkong/dots/smehalf.h +47 -0
  101. package/include/numkong/dots/v128relaxed.h +294 -0
  102. package/include/numkong/dots.h +2804 -0
  103. package/include/numkong/dots.hpp +639 -0
  104. package/include/numkong/each/README.md +469 -0
  105. package/include/numkong/each/haswell.h +1658 -0
  106. package/include/numkong/each/icelake.h +272 -0
  107. package/include/numkong/each/neon.h +1104 -0
  108. package/include/numkong/each/neonbfdot.h +212 -0
  109. package/include/numkong/each/neonhalf.h +410 -0
  110. package/include/numkong/each/rvv.h +1121 -0
  111. package/include/numkong/each/sapphire.h +477 -0
  112. package/include/numkong/each/serial.h +260 -0
  113. package/include/numkong/each/skylake.h +1562 -0
  114. package/include/numkong/each.h +2146 -0
  115. package/include/numkong/each.hpp +434 -0
  116. package/include/numkong/geospatial/README.md +147 -0
  117. package/include/numkong/geospatial/haswell.h +593 -0
  118. package/include/numkong/geospatial/neon.h +571 -0
  119. package/include/numkong/geospatial/rvv.h +701 -0
  120. package/include/numkong/geospatial/serial.h +309 -0
  121. package/include/numkong/geospatial/skylake.h +577 -0
  122. package/include/numkong/geospatial/v128relaxed.h +613 -0
  123. package/include/numkong/geospatial.h +453 -0
  124. package/include/numkong/geospatial.hpp +235 -0
  125. package/include/numkong/matrix.hpp +336 -0
  126. package/include/numkong/maxsim/README.md +187 -0
  127. package/include/numkong/maxsim/alder.h +511 -0
  128. package/include/numkong/maxsim/genoa.h +115 -0
  129. package/include/numkong/maxsim/haswell.h +553 -0
  130. package/include/numkong/maxsim/icelake.h +480 -0
  131. package/include/numkong/maxsim/neonsdot.h +394 -0
  132. package/include/numkong/maxsim/sapphireamx.h +877 -0
  133. package/include/numkong/maxsim/serial.h +490 -0
  134. package/include/numkong/maxsim/sme.h +929 -0
  135. package/include/numkong/maxsim/v128relaxed.h +280 -0
  136. package/include/numkong/maxsim.h +571 -0
  137. package/include/numkong/maxsim.hpp +133 -0
  138. package/include/numkong/mesh/README.md +227 -0
  139. package/include/numkong/mesh/haswell.h +2235 -0
  140. package/include/numkong/mesh/neon.h +1329 -0
  141. package/include/numkong/mesh/neonbfdot.h +842 -0
  142. package/include/numkong/mesh/neonhalf.h +616 -0
  143. package/include/numkong/mesh/rvv.h +916 -0
  144. package/include/numkong/mesh/serial.h +742 -0
  145. package/include/numkong/mesh/skylake.h +1135 -0
  146. package/include/numkong/mesh/v128relaxed.h +1052 -0
  147. package/include/numkong/mesh.h +652 -0
  148. package/include/numkong/mesh.hpp +762 -0
  149. package/include/numkong/numkong.h +78 -0
  150. package/include/numkong/numkong.hpp +57 -0
  151. package/include/numkong/probability/README.md +173 -0
  152. package/include/numkong/probability/haswell.h +267 -0
  153. package/include/numkong/probability/neon.h +225 -0
  154. package/include/numkong/probability/rvv.h +409 -0
  155. package/include/numkong/probability/serial.h +169 -0
  156. package/include/numkong/probability/skylake.h +324 -0
  157. package/include/numkong/probability.h +383 -0
  158. package/include/numkong/probability.hpp +120 -0
  159. package/include/numkong/random.h +50 -0
  160. package/include/numkong/random.hpp +285 -0
  161. package/include/numkong/reduce/README.md +547 -0
  162. package/include/numkong/reduce/alder.h +632 -0
  163. package/include/numkong/reduce/genoa.h +201 -0
  164. package/include/numkong/reduce/haswell.h +3783 -0
  165. package/include/numkong/reduce/icelake.h +549 -0
  166. package/include/numkong/reduce/neon.h +3841 -0
  167. package/include/numkong/reduce/neonbfdot.h +353 -0
  168. package/include/numkong/reduce/neonfhm.h +665 -0
  169. package/include/numkong/reduce/neonhalf.h +157 -0
  170. package/include/numkong/reduce/neonsdot.h +357 -0
  171. package/include/numkong/reduce/rvv.h +3407 -0
  172. package/include/numkong/reduce/serial.h +757 -0
  173. package/include/numkong/reduce/sierra.h +338 -0
  174. package/include/numkong/reduce/skylake.h +3792 -0
  175. package/include/numkong/reduce/v128relaxed.h +2302 -0
  176. package/include/numkong/reduce.h +1597 -0
  177. package/include/numkong/reduce.hpp +633 -0
  178. package/include/numkong/scalar/README.md +89 -0
  179. package/include/numkong/scalar/haswell.h +113 -0
  180. package/include/numkong/scalar/neon.h +122 -0
  181. package/include/numkong/scalar/neonhalf.h +70 -0
  182. package/include/numkong/scalar/rvv.h +211 -0
  183. package/include/numkong/scalar/sapphire.h +63 -0
  184. package/include/numkong/scalar/serial.h +332 -0
  185. package/include/numkong/scalar/v128relaxed.h +56 -0
  186. package/include/numkong/scalar.h +683 -0
  187. package/include/numkong/set/README.md +179 -0
  188. package/include/numkong/set/haswell.h +334 -0
  189. package/include/numkong/set/icelake.h +485 -0
  190. package/include/numkong/set/neon.h +364 -0
  191. package/include/numkong/set/rvv.h +226 -0
  192. package/include/numkong/set/rvvbb.h +117 -0
  193. package/include/numkong/set/serial.h +174 -0
  194. package/include/numkong/set/sve.h +185 -0
  195. package/include/numkong/set/v128relaxed.h +240 -0
  196. package/include/numkong/set.h +457 -0
  197. package/include/numkong/set.hpp +114 -0
  198. package/include/numkong/sets/README.md +149 -0
  199. package/include/numkong/sets/haswell.h +63 -0
  200. package/include/numkong/sets/icelake.h +66 -0
  201. package/include/numkong/sets/neon.h +61 -0
  202. package/include/numkong/sets/serial.h +43 -0
  203. package/include/numkong/sets/smebi32.h +1099 -0
  204. package/include/numkong/sets/v128relaxed.h +58 -0
  205. package/include/numkong/sets.h +339 -0
  206. package/include/numkong/sparse/README.md +156 -0
  207. package/include/numkong/sparse/icelake.h +463 -0
  208. package/include/numkong/sparse/neon.h +288 -0
  209. package/include/numkong/sparse/serial.h +117 -0
  210. package/include/numkong/sparse/sve2.h +507 -0
  211. package/include/numkong/sparse/turin.h +322 -0
  212. package/include/numkong/sparse.h +363 -0
  213. package/include/numkong/sparse.hpp +113 -0
  214. package/include/numkong/spatial/README.md +435 -0
  215. package/include/numkong/spatial/alder.h +607 -0
  216. package/include/numkong/spatial/genoa.h +290 -0
  217. package/include/numkong/spatial/haswell.h +960 -0
  218. package/include/numkong/spatial/icelake.h +586 -0
  219. package/include/numkong/spatial/neon.h +773 -0
  220. package/include/numkong/spatial/neonbfdot.h +165 -0
  221. package/include/numkong/spatial/neonhalf.h +118 -0
  222. package/include/numkong/spatial/neonsdot.h +261 -0
  223. package/include/numkong/spatial/rvv.h +984 -0
  224. package/include/numkong/spatial/rvvbf16.h +123 -0
  225. package/include/numkong/spatial/rvvhalf.h +117 -0
  226. package/include/numkong/spatial/sapphire.h +343 -0
  227. package/include/numkong/spatial/serial.h +346 -0
  228. package/include/numkong/spatial/sierra.h +323 -0
  229. package/include/numkong/spatial/skylake.h +606 -0
  230. package/include/numkong/spatial/sve.h +224 -0
  231. package/include/numkong/spatial/svebfdot.h +122 -0
  232. package/include/numkong/spatial/svehalf.h +109 -0
  233. package/include/numkong/spatial/v128relaxed.h +717 -0
  234. package/include/numkong/spatial.h +1425 -0
  235. package/include/numkong/spatial.hpp +183 -0
  236. package/include/numkong/spatials/README.md +580 -0
  237. package/include/numkong/spatials/alder.h +94 -0
  238. package/include/numkong/spatials/genoa.h +94 -0
  239. package/include/numkong/spatials/haswell.h +219 -0
  240. package/include/numkong/spatials/icelake.h +113 -0
  241. package/include/numkong/spatials/neon.h +109 -0
  242. package/include/numkong/spatials/neonbfdot.h +60 -0
  243. package/include/numkong/spatials/neonfhm.h +92 -0
  244. package/include/numkong/spatials/neonhalf.h +58 -0
  245. package/include/numkong/spatials/neonsdot.h +109 -0
  246. package/include/numkong/spatials/rvv.h +1960 -0
  247. package/include/numkong/spatials/sapphireamx.h +1149 -0
  248. package/include/numkong/spatials/serial.h +226 -0
  249. package/include/numkong/spatials/sierra.h +96 -0
  250. package/include/numkong/spatials/skylake.h +184 -0
  251. package/include/numkong/spatials/sme.h +1901 -0
  252. package/include/numkong/spatials/smef64.h +465 -0
  253. package/include/numkong/spatials/v128relaxed.h +240 -0
  254. package/include/numkong/spatials.h +3021 -0
  255. package/include/numkong/spatials.hpp +508 -0
  256. package/include/numkong/tensor.hpp +1592 -0
  257. package/include/numkong/trigonometry/README.md +184 -0
  258. package/include/numkong/trigonometry/haswell.h +652 -0
  259. package/include/numkong/trigonometry/neon.h +639 -0
  260. package/include/numkong/trigonometry/rvv.h +699 -0
  261. package/include/numkong/trigonometry/serial.h +703 -0
  262. package/include/numkong/trigonometry/skylake.h +721 -0
  263. package/include/numkong/trigonometry/v128relaxed.h +666 -0
  264. package/include/numkong/trigonometry.h +467 -0
  265. package/include/numkong/trigonometry.hpp +166 -0
  266. package/include/numkong/types.h +1384 -0
  267. package/include/numkong/types.hpp +5603 -0
  268. package/include/numkong/vector.hpp +698 -0
  269. package/javascript/README.md +246 -0
  270. package/javascript/dist/cjs/numkong-wasm.d.ts +166 -0
  271. package/javascript/dist/cjs/numkong-wasm.js +617 -0
  272. package/javascript/dist/cjs/numkong.d.ts +343 -0
  273. package/javascript/dist/cjs/numkong.js +523 -0
  274. package/javascript/dist/cjs/package.json +3 -0
  275. package/javascript/dist/cjs/types.d.ts +284 -0
  276. package/javascript/dist/cjs/types.js +653 -0
  277. package/javascript/dist/esm/numkong-wasm.d.ts +166 -0
  278. package/javascript/dist/esm/numkong-wasm.js +595 -0
  279. package/javascript/dist/esm/numkong.d.ts +343 -0
  280. package/javascript/dist/esm/numkong.js +452 -0
  281. package/javascript/dist/esm/package.json +3 -0
  282. package/javascript/dist/esm/types.d.ts +284 -0
  283. package/javascript/dist/esm/types.js +630 -0
  284. package/javascript/dist-package-cjs.json +3 -0
  285. package/javascript/dist-package-esm.json +3 -0
  286. package/javascript/node-gyp-build.d.ts +1 -0
  287. package/javascript/numkong-wasm.ts +756 -0
  288. package/javascript/numkong.c +689 -0
  289. package/javascript/numkong.ts +575 -0
  290. package/javascript/tsconfig-base.json +39 -0
  291. package/javascript/tsconfig-cjs.json +8 -0
  292. package/javascript/tsconfig-esm.json +8 -0
  293. package/javascript/types.ts +674 -0
  294. package/package.json +87 -0
@@ -0,0 +1,2066 @@
1
+ /**
2
+ * @brief FlashAttention-style kernels for SME.
3
+ * @file include/numkong/attention/sme.h
4
+ * @author Ash Vardanian
5
+ * @date January 11, 2026
6
+ *
7
+ * @sa include/numkong/attention.h
8
+ *
9
+ * This file implements FlashAttention-2 style scaled dot-product attention (SDPA) optimized
10
+ * for ARM SME instructions on Apple M4 and similar processors. The kernel computes:
11
+ *
12
+ * O = softmax(Q × Kᵀ / √d) × V
13
+ *
14
+ * Key features:
15
+ * - Online softmax: Mathematically exact, processes KV blocks incrementally
16
+ * - Pre-packed KV cache: BFMOPA/FMOPA-interleaved format amortizes packing for repeated inference
17
+ * - GQA/MQA support: Different `num_heads` and `num_kv_heads` for grouped-query attention
18
+ * - Pure Streaming SVE: No NEON intrinsics for non-linear operations
19
+ *
20
+ * Target models (2025):
21
+ * - Kimi K2: `head_dim`=112, 64 heads, MHA, 128K context
22
+ * - LLaMA 3.1 405B: `head_dim`=128, 128 heads, 16 KV heads (GQA 8:1), 128K context
23
+ * - Qwen 2.5 72B: `head_dim`=128, 64 heads, 8 KV heads (GQA 8:1), 32K context
24
+ *
25
+ * @section attention_sme_architecture Architecture
26
+ *
27
+ * Both Q×Kᵀ and P×V phases use BFMOPA/FMOPA outer products on ZA tiles, eliminating
28
+ * element-wise SVE loops that dominated the original implementation. The Q matrix is
29
+ * pre-transposed once into a buffer matching the interleaving that ZA vertical reads
30
+ * would produce, so Q×Kᵀ runs as pure memory-to-BFMOPA with no per-block ZA staging.
31
+ *
32
+ * Block sizes:
33
+ * - Bᵣ = 16 (query block rows, matches ZA32 tile height)
34
+ * - Bᶜ = 32 (main prefill loop, processes two KV blocks per iteration using ZA2+ZA3)
35
+ * - Bᶜ = 16 (tail loop for remaining KV positions, and decode path)
36
+ *
37
+ * KV packing format:
38
+ * - K is stored in BFMOPA-interleaved format: `K_packed[kv_block][depth_step][32]` where
39
+ * `packed[2*ki + sub] = K[kv_block*16 + ki][2*depth_step + sub]`
40
+ * - V is stored in BFMOPA-interleaved format: `V_packed[kv_block][dim_tile][depth_step][32]`
41
+ * where `packed[2*dj + sub] = V[kv_block*16 + 2*depth_step + sub][dim_tile*16 + dj]`
42
+ * - The `reserved[0]` header field stores `v_dim_tile_count` for efficient V addressing
43
+ *
44
+ * Softmax:
45
+ * - Column-wise max and exp using ZA tile vertical reads (avoids per-row horizontal extracts)
46
+ * - Correction skip: when the block max does not exceed the running max, the output
47
+ * accumulator rescaling is skipped entirely (common in later KV blocks)
48
+ * - Degree-3 fast exp (`nk_exp_fast_f32_sve_`) saves 1 FMA per call vs degree-4
49
+ * - Weights stored directly as bf16/f16 in ZA0 columns via `svzip1` (no f32 round-trip)
50
+ *
51
+ * Decode path (query_len=1):
52
+ * - Uses element-wise SVE with scalar weight broadcasts instead of BFMOPA P×V
53
+ * - BFMOPA overhead too high for single-query case due to ZA setup cost
54
+ *
55
+ * P×V prefill path:
56
+ * - 4-tile BFMOPA processing: 4 dim-tiles × 8 depth steps per KV block = 32 BFMOPA ops
57
+ * - ZA0-ZA3 accumulate simultaneously, read results with MOVA, add to output accumulator
58
+ * - Remainder dim-tiles handled 1-at-a-time using ZA0 only
59
+ *
60
+ * SME tile dimensions (for SVL=512, i.e., Apple M4):
61
+ * - ZA32 tile: 16 × 16 `f32` elements (1KB)
62
+ * - `bf16`/`f16` vectors: 32 elements per SVE vector
63
+ *
64
+ * @section attention_sme_history Optimization History
65
+ *
66
+ * Phase 1 (January 2026): Initial implementation using ZA staging transpose for Q×Kᵀ
67
+ * and element-wise SVE for P×V. Q and K rows were loaded into ZA0/ZA1 horizontally,
68
+ * read back vertically to produce interleaved vectors for BFMOPA. The P×V phase used
69
+ * scalar `svmla_f32_x` loops over head_dim for each query-key pair. Softmax used
70
+ * degree-4 polynomial exp with per-row horizontal max/sum. Performance: ~25-50 GFLOP/s
71
+ * on Apple M4 (bf16, 8 heads, query_len=64, kv_len=4096, head_dim=128).
72
+ *
73
+ * Phase 2 (February 2026): BFMOPA/FMOPA P×V with pre-packed V in interleaved format.
74
+ * Key changes integrated:
75
+ * - Q pre-transposed once into a buffer, eliminating per-block ZA staging for Q
76
+ * - K pre-packed in interleaved format, enabling pure memory-to-BFMOPA Q×Kᵀ
77
+ * - V pre-packed in BFMOPA-interleaved format with dim-tile blocking
78
+ * - P×V uses 4-tile BFMOPA accumulation (ZA0-ZA3) with pre-extracted P columns
79
+ * - Bᶜ=32 main loop for prefill (2 KV blocks per iteration via ZA2+ZA3)
80
+ * - Column-wise softmax: vertical ZA reads for max/exp instead of per-row horizontal
81
+ * - Correction skip when running max is unchanged
82
+ * - Degree-3 fast exp (~0.5% max relative error, saves 1 FMA per call)
83
+ * - Weights stored directly as bf16/f16 via `svzip1` (no f32 quantization round-trip)
84
+ * Performance: ~300-400 GFLOP/s on Apple M4 (same configuration), a 6-14× improvement.
85
+ *
86
+ * Rejected approaches:
87
+ * - BFMOPA P×V for decode (query_len=1): ZA setup overhead exceeds element-wise SVE cost
88
+ * - `svdot_lane` for Q×Kᵀ: lower throughput than BFMOPA on M4
89
+ * - Shared ZA tiles between softmax and P×V: register pressure too high with 4-tile P×V
90
+ */
91
+ #ifndef NK_ATTENTION_SME_H
92
+ #define NK_ATTENTION_SME_H
93
+
94
+ #if NK_TARGET_ARM_
95
+ #if NK_TARGET_SME
96
+
97
+ #include "numkong/types.h"
98
+
99
+ #if defined(__cplusplus)
100
+ extern "C" {
101
+ #endif
102
+
103
+ #if defined(__clang__)
104
+ #pragma clang attribute push(__attribute__((target("sme,sve"))), apply_to = function)
105
+ #elif defined(__GNUC__)
106
+ #pragma GCC push_options
107
+ #pragma GCC target("+sme")
108
+ #endif
109
+
110
+ /**
111
+ * @brief Convert bf16 vector to f32 in registers (streaming SVE compatible).
112
+ *
113
+ * BF16 is the upper 16 bits of F32, so we:
114
+ * 1. Reinterpret bf16 as u16
115
+ * 2. Zero-extend to u32 (unpklo for lower half)
116
+ * 3. Shift left by 16 to place in f32 exponent+mantissa position
117
+ * 4. Reinterpret as f32
118
+ */
119
+ NK_INTERNAL svfloat32_t nk_bf16_to_f32_sve_(svbool_t predicate_f32x, svbfloat16_t x_bf16x) __arm_streaming {
120
+ svuint16_t x_u16x = svreinterpret_u16_bf16(x_bf16x);
121
+ svuint32_t x_u32x = svunpklo_u32(x_u16x);
122
+ x_u32x = svlsl_n_u32_x(predicate_f32x, x_u32x, 16);
123
+ return svreinterpret_f32_u32(x_u32x);
124
+ }
125
+
126
+ /**
127
+ * @brief Convert f32 vector to bf16 in registers with rounding (streaming SVE compatible).
128
+ *
129
+ * 1. Reinterpret f32 as u32
130
+ * 2. Add rounding bias (0x8000) for round-to-nearest
131
+ * 3. Shift right by 16
132
+ * 4. Narrow to u16 and reinterpret as bf16
133
+ */
134
+ NK_INTERNAL svbfloat16_t nk_f32_to_bf16_sve_(svbool_t predicate_f32x, svfloat32_t x_f32x) __arm_streaming {
135
+ svuint32_t x_u32x = svreinterpret_u32_f32(x_f32x);
136
+ x_u32x = svadd_n_u32_x(predicate_f32x, x_u32x, 0x8000); // Round to nearest
137
+ x_u32x = svlsr_n_u32_x(predicate_f32x, x_u32x, 16);
138
+ svuint16_t x_u16x = svuzp1_u16(svreinterpret_u16_u32(x_u32x), svreinterpret_u16_u32(x_u32x));
139
+ return svreinterpret_bf16_u16(x_u16x);
140
+ }
141
+
142
+ /**
143
+ * @brief Packed KV cache header for attention (64-byte aligned).
144
+ *
145
+ * Layout in memory:
146
+ * [header: 64 bytes][K tiles: variable][V tiles: variable]
147
+ */
148
+ typedef struct {
149
+ nk_u32_t num_kv_heads; ///< Number of K/V heads (for GQA, may differ from Q heads)
150
+ nk_u32_t head_dim; ///< Original head dimension (64, 112, 128)
151
+ nk_u32_t head_dim_padded; ///< Padded to multiple of 32 for SME
152
+ nk_u32_t seq_len; ///< Current sequence length
153
+ nk_u32_t max_seq_len; ///< Maximum sequence length (for pre-allocation)
154
+ nk_u32_t k_offset; ///< Byte offset to K data from header start
155
+ nk_u32_t v_offset; ///< Byte offset to V data from header start
156
+ nk_u32_t reserved[9]; ///< reserved[0] = v_dim_tile_count; remainder pads to 64 bytes
157
+ } nk_attention_sme_packed_header_t;
158
+
159
+ /**
160
+ * @brief Fast exp approximation in Streaming SVE.
161
+ *
162
+ * Uses Cody-Waite range reduction + Horner polynomial (degree 4).
163
+ * Accuracy: ~0.1% relative error, acceptable for softmax normalization.
164
+ *
165
+ * @param pg Active predicate
166
+ * @param x Input vector
167
+ * @return exp(x) approximation
168
+ */
169
+ NK_INTERNAL svfloat32_t nk_exp_f32_sve_(svbool_t predicate_f32x, svfloat32_t x_f32x) __arm_streaming {
170
+ // Constants for Cody-Waite range reduction
171
+ svfloat32_t log2e_f32x = svdup_f32(1.4426950408889634f);
172
+ svfloat32_t ln2_hi_f32x = svdup_f32(0.693145751953125f);
173
+ svfloat32_t ln2_lo_f32x = svdup_f32(1.42860682030941723212e-6f);
174
+
175
+ // Clamp to avoid overflow/underflow
176
+ svfloat32_t max_x_f32x = svdup_f32(88.3762626647949f);
177
+ svfloat32_t min_x_f32x = svdup_f32(-87.3365447504021f);
178
+ x_f32x = svmax_f32_m(predicate_f32x, svmin_f32_m(predicate_f32x, x_f32x, max_x_f32x), min_x_f32x);
179
+
180
+ // n = round(x / ln(2))
181
+ svfloat32_t n_f32x = svrintn_f32_m(svundef_f32(), predicate_f32x, svmul_f32_m(predicate_f32x, x_f32x, log2e_f32x));
182
+
183
+ // r = x - n × ln(2) using Cody-Waite for precision
184
+ svfloat32_t r_f32x = svmsb_f32_m(predicate_f32x, n_f32x, ln2_hi_f32x, x_f32x);
185
+ r_f32x = svmsb_f32_m(predicate_f32x, n_f32x, ln2_lo_f32x, r_f32x);
186
+
187
+ // Polynomial approximation for exp(r): degree 4
188
+ // exp(r) ≈ 1 + r + r²/2 + r³/6 + r⁴/24
189
+ svfloat32_t p_f32x = svdup_f32(4.1666666667e-2f); // 1/24
190
+ p_f32x = svmad_f32_m(predicate_f32x, p_f32x, r_f32x, svdup_f32(1.6666666667e-1f)); // 1/6
191
+ p_f32x = svmad_f32_m(predicate_f32x, p_f32x, r_f32x, svdup_f32(5.0000000000e-1f)); // 1/2
192
+ p_f32x = svmad_f32_m(predicate_f32x, p_f32x, r_f32x, svdup_f32(1.0f)); // 1
193
+ p_f32x = svmad_f32_m(predicate_f32x, p_f32x, r_f32x, svdup_f32(1.0f)); // 1
194
+
195
+ // Reconstruct: exp(x) = 2ⁿ × exp(r)
196
+ // 2ⁿ via IEEE 754 exponent manipulation
197
+ svint32_t n_i32x = svcvt_s32_f32_m(svundef_s32(), predicate_f32x, n_f32x);
198
+ n_i32x = svadd_s32_m(predicate_f32x, n_i32x, svdup_s32(127));
199
+ n_i32x = svlsl_n_s32_m(predicate_f32x, n_i32x, 23);
200
+ svfloat32_t pow2n_f32x = svreinterpret_f32_s32(n_i32x);
201
+
202
+ return svmul_f32_m(predicate_f32x, p_f32x, pow2n_f32x);
203
+ }
204
+
205
+ /**
206
+ * @brief Degree-3 fast exp approximation. Max relative error ~0.5%.
207
+ * Saves 1 FMA per call vs degree-4 nk_exp_f32_sve_.
208
+ */
209
+ NK_INTERNAL svfloat32_t nk_exp_fast_f32_sve_(svbool_t predicate_f32x, svfloat32_t x_f32x) __arm_streaming {
210
+ svfloat32_t log2e_f32x = svdup_f32(1.4426950408889634f);
211
+ svfloat32_t ln2_hi_f32x = svdup_f32(0.693145751953125f);
212
+ svfloat32_t ln2_lo_f32x = svdup_f32(1.42860682030941723212e-6f);
213
+
214
+ svfloat32_t max_x_f32x = svdup_f32(88.3762626647949f);
215
+ svfloat32_t min_x_f32x = svdup_f32(-87.3365447504021f);
216
+ x_f32x = svmax_f32_m(predicate_f32x, svmin_f32_m(predicate_f32x, x_f32x, max_x_f32x), min_x_f32x);
217
+
218
+ svfloat32_t n_f32x = svrintn_f32_m(svundef_f32(), predicate_f32x, svmul_f32_m(predicate_f32x, x_f32x, log2e_f32x));
219
+ svfloat32_t r_f32x = svmsb_f32_m(predicate_f32x, n_f32x, ln2_hi_f32x, x_f32x);
220
+ r_f32x = svmsb_f32_m(predicate_f32x, n_f32x, ln2_lo_f32x, r_f32x);
221
+
222
+ // Degree-3: exp(r) ~ 1 + r + r^2/2 + r^3/6 (drop 1/24 term)
223
+ svfloat32_t p_f32x = svdup_f32(1.6666666667e-1f); // 1/6
224
+ p_f32x = svmad_f32_m(predicate_f32x, p_f32x, r_f32x, svdup_f32(5.0000000000e-1f)); // 1/2
225
+ p_f32x = svmad_f32_m(predicate_f32x, p_f32x, r_f32x, svdup_f32(1.0f)); // 1
226
+ p_f32x = svmad_f32_m(predicate_f32x, p_f32x, r_f32x, svdup_f32(1.0f)); // 1
227
+
228
+ svint32_t n_i32x = svcvt_s32_f32_m(svundef_s32(), predicate_f32x, n_f32x);
229
+ n_i32x = svadd_s32_m(predicate_f32x, n_i32x, svdup_s32(127));
230
+ n_i32x = svlsl_n_s32_m(predicate_f32x, n_i32x, 23);
231
+ svfloat32_t pow2n_f32x = svreinterpret_f32_s32(n_i32x);
232
+
233
+ return svmul_f32_m(predicate_f32x, p_f32x, pow2n_f32x);
234
+ }
235
+
236
+ NK_PUBLIC nk_size_t nk_attention_packed_kv_size_bf16_sme(nk_size_t num_kv_heads, nk_size_t head_dim,
237
+ nk_size_t max_seq_len) {
238
+ nk_size_t head_dim_padded = (head_dim + 31) / 32 * 32;
239
+ nk_size_t kv_blocks = (max_seq_len + 15) / 16;
240
+ nk_size_t seq_padded = kv_blocks * 16;
241
+ // K and V both use BFMOPA-interleaved format: [num_kv_heads, kv_blocks, depth_steps, 32]
242
+ nk_size_t k_size = num_kv_heads * seq_padded * head_dim_padded * sizeof(nk_bf16_t);
243
+ nk_size_t v_size = k_size;
244
+ return sizeof(nk_attention_sme_packed_header_t) + k_size + v_size;
245
+ }
246
+
247
+ NK_PUBLIC nk_size_t nk_attention_packed_kv_size_f16_sme(nk_size_t num_kv_heads, nk_size_t head_dim,
248
+ nk_size_t max_seq_len) {
249
+ return nk_attention_packed_kv_size_bf16_sme(num_kv_heads, head_dim, max_seq_len);
250
+ }
251
+
252
+ __arm_locally_streaming static void nk_attention_pack_kv_bf16_sme_streaming_(nk_bf16_t const *k, nk_bf16_t const *v,
253
+ nk_size_t num_kv_heads, nk_size_t head_dim,
254
+ nk_size_t seq_len, nk_size_t k_stride,
255
+ nk_size_t v_stride, void *kv_packed) {
256
+
257
+ nk_attention_sme_packed_header_t *header = (nk_attention_sme_packed_header_t *)kv_packed;
258
+ nk_size_t head_dim_padded = (head_dim + 31) / 32 * 32;
259
+ nk_size_t dim_tile_count = (head_dim_padded + 15) / 16;
260
+ nk_size_t kv_block_count = (seq_len + 15) / 16;
261
+
262
+ nk_size_t k_depth_step_count = head_dim_padded / 2;
263
+ nk_size_t head_elems = kv_block_count * 16 * head_dim_padded;
264
+
265
+ header->num_kv_heads = (nk_u32_t)num_kv_heads;
266
+ header->head_dim = (nk_u32_t)head_dim;
267
+ header->head_dim_padded = (nk_u32_t)head_dim_padded;
268
+ header->seq_len = (nk_u32_t)seq_len;
269
+ header->k_offset = sizeof(nk_attention_sme_packed_header_t);
270
+ header->reserved[0] = (nk_u32_t)dim_tile_count; // v_dim_tile_count
271
+ header->v_offset = header->k_offset + (nk_u32_t)(num_kv_heads * head_elems * sizeof(nk_bf16_t));
272
+
273
+ nk_bf16_t *k_packed = (nk_bf16_t *)((char *)kv_packed + header->k_offset);
274
+ nk_bf16_t *v_packed = (nk_bf16_t *)((char *)kv_packed + header->v_offset);
275
+
276
+ for (nk_size_t h = 0; h < num_kv_heads; h++) {
277
+ nk_bf16_t const *k_head = k + h * k_stride;
278
+ nk_bf16_t const *v_head = v + h * v_stride;
279
+
280
+ // K packing: BFMOPA-interleaved format
281
+ // K_packed[kv_block][depth_step][32] where
282
+ // packed[2*ki + sub] = K[kv_block*16 + ki][2*depth_step + sub]
283
+ nk_bf16_t *k_out = k_packed + h * head_elems;
284
+ for (nk_size_t kv_block = 0; kv_block < kv_block_count; kv_block++) {
285
+ for (nk_size_t depth_step = 0; depth_step < k_depth_step_count; depth_step++) {
286
+ nk_bf16_t *vec_out = k_out + (kv_block * k_depth_step_count + depth_step) * 32;
287
+ for (nk_size_t ki = 0; ki < 16; ki++) {
288
+ for (nk_size_t sub = 0; sub < 2; sub++) {
289
+ nk_size_t row = kv_block * 16 + ki;
290
+ nk_size_t col = 2 * depth_step + sub;
291
+ nk_bf16_t zero = {0};
292
+ vec_out[2 * ki + sub] = (row < seq_len && col < head_dim) ? k_head[row * head_dim + col] : zero;
293
+ }
294
+ }
295
+ }
296
+ }
297
+
298
+ // V packing: BFMOPA-interleaved format
299
+ nk_bf16_t *v_out = v_packed + h * head_elems;
300
+ for (nk_size_t kv_block = 0; kv_block < kv_block_count; kv_block++) {
301
+ for (nk_size_t dim_tile = 0; dim_tile < dim_tile_count; dim_tile++) {
302
+ for (nk_size_t depth_step = 0; depth_step < 8; depth_step++) {
303
+ nk_bf16_t *vec_out = v_out + (kv_block * dim_tile_count * 8 + dim_tile * 8 + depth_step) * 32;
304
+ for (nk_size_t dj = 0; dj < 16; dj++) {
305
+ for (nk_size_t sub = 0; sub < 2; sub++) {
306
+ nk_size_t ki = kv_block * 16 + 2 * depth_step + sub;
307
+ nk_size_t d = dim_tile * 16 + dj;
308
+ nk_bf16_t zero = {0};
309
+ vec_out[2 * dj + sub] = (ki < seq_len && d < head_dim) ? v_head[ki * head_dim + d] : zero;
310
+ }
311
+ }
312
+ }
313
+ }
314
+ }
315
+ }
316
+ }
317
+
318
+ NK_PUBLIC void nk_attention_pack_kv_bf16_sme(nk_bf16_t const *k, nk_bf16_t const *v, nk_size_t num_kv_heads,
319
+ nk_size_t head_dim, nk_size_t seq_len, nk_size_t k_stride,
320
+ nk_size_t v_stride, void *kv_packed) {
321
+ nk_attention_pack_kv_bf16_sme_streaming_(k, v, num_kv_heads, head_dim, seq_len, k_stride, v_stride, kv_packed);
322
+ }
323
+
324
+ __arm_locally_streaming static void nk_attention_pack_kv_f16_sme_streaming_(nk_f16_t const *k, nk_f16_t const *v,
325
+ nk_size_t num_kv_heads, nk_size_t head_dim,
326
+ nk_size_t seq_len, nk_size_t k_stride,
327
+ nk_size_t v_stride, void *kv_packed) {
328
+
329
+ nk_attention_sme_packed_header_t *header = (nk_attention_sme_packed_header_t *)kv_packed;
330
+ nk_size_t head_dim_padded = (head_dim + 31) / 32 * 32;
331
+ nk_size_t dim_tile_count = (head_dim_padded + 15) / 16;
332
+ nk_size_t kv_block_count = (seq_len + 15) / 16;
333
+
334
+ nk_size_t k_depth_step_count = head_dim_padded / 2;
335
+ nk_size_t head_elems = kv_block_count * 16 * head_dim_padded;
336
+
337
+ header->num_kv_heads = (nk_u32_t)num_kv_heads;
338
+ header->head_dim = (nk_u32_t)head_dim;
339
+ header->head_dim_padded = (nk_u32_t)head_dim_padded;
340
+ header->seq_len = (nk_u32_t)seq_len;
341
+ header->k_offset = sizeof(nk_attention_sme_packed_header_t);
342
+ header->reserved[0] = (nk_u32_t)dim_tile_count; // v_dim_tile_count
343
+ header->v_offset = header->k_offset + (nk_u32_t)(num_kv_heads * head_elems * sizeof(nk_f16_t));
344
+
345
+ nk_f16_t *k_packed = (nk_f16_t *)((char *)kv_packed + header->k_offset);
346
+ nk_f16_t *v_packed = (nk_f16_t *)((char *)kv_packed + header->v_offset);
347
+
348
+ for (nk_size_t h = 0; h < num_kv_heads; h++) {
349
+ nk_f16_t const *k_head = k + h * k_stride;
350
+ nk_f16_t const *v_head = v + h * v_stride;
351
+
352
+ // K packing: FMOPA-interleaved format
353
+ nk_f16_t *k_out = k_packed + h * head_elems;
354
+ for (nk_size_t kv_block = 0; kv_block < kv_block_count; kv_block++) {
355
+ for (nk_size_t depth_step = 0; depth_step < k_depth_step_count; depth_step++) {
356
+ nk_f16_t *vec_out = k_out + (kv_block * k_depth_step_count + depth_step) * 32;
357
+ for (nk_size_t ki = 0; ki < 16; ki++) {
358
+ for (nk_size_t sub = 0; sub < 2; sub++) {
359
+ nk_size_t row = kv_block * 16 + ki;
360
+ nk_size_t col = 2 * depth_step + sub;
361
+ nk_f16_t zero = {0};
362
+ vec_out[2 * ki + sub] = (row < seq_len && col < head_dim) ? k_head[row * head_dim + col] : zero;
363
+ }
364
+ }
365
+ }
366
+ }
367
+
368
+ // V packing: FMOPA-interleaved format
369
+ nk_f16_t *v_out = v_packed + h * head_elems;
370
+ for (nk_size_t kv_block = 0; kv_block < kv_block_count; kv_block++) {
371
+ for (nk_size_t dim_tile = 0; dim_tile < dim_tile_count; dim_tile++) {
372
+ for (nk_size_t depth_step = 0; depth_step < 8; depth_step++) {
373
+ nk_f16_t *vec_out = v_out + (kv_block * dim_tile_count * 8 + dim_tile * 8 + depth_step) * 32;
374
+ for (nk_size_t dj = 0; dj < 16; dj++) {
375
+ for (nk_size_t sub = 0; sub < 2; sub++) {
376
+ nk_size_t ki = kv_block * 16 + 2 * depth_step + sub;
377
+ nk_size_t d = dim_tile * 16 + dj;
378
+ nk_f16_t zero = {0};
379
+ vec_out[2 * dj + sub] = (ki < seq_len && d < head_dim) ? v_head[ki * head_dim + d] : zero;
380
+ }
381
+ }
382
+ }
383
+ }
384
+ }
385
+ }
386
+ }
387
+
388
+ NK_PUBLIC void nk_attention_pack_kv_f16_sme(nk_f16_t const *k, nk_f16_t const *v, nk_size_t num_kv_heads,
389
+ nk_size_t head_dim, nk_size_t seq_len, nk_size_t k_stride,
390
+ nk_size_t v_stride, void *kv_packed) {
391
+ nk_attention_pack_kv_f16_sme_streaming_(k, v, num_kv_heads, head_dim, seq_len, k_stride, v_stride, kv_packed);
392
+ }
393
+
394
+ /**
395
+ * @brief Optimized bf16 attention kernel with BFMOPA P×V.
396
+ *
397
+ * Key design choices:
398
+ * - P×V uses BFMOPA with pre-packed V (4-tile accumulation) instead of element-wise SVE
399
+ * - Scores read via column-wise vertical ZA reads for vectorized max/exp
400
+ * - Weights stored directly as bf16 (no f32 round-trip)
401
+ * - Uses degree-3 fast exp for softmax
402
+ * - Correction skip when running max is unchanged
403
+ * - Decode path (valid_query_count==1) remains element-wise SVE (BFMOPA overhead too high)
404
+ */
405
+ __arm_locally_streaming __arm_new("za") static void nk_attention_bf16_sme_streaming_(
406
+ nk_bf16_t const *q, // [query_len, head_dim]
407
+ nk_bf16_t const *k, // [kv_len, head_dim_padded] BFMOPA-interleaved
408
+ nk_bf16_t const *v_packed, // BFMOPA-interleaved V for this KV head
409
+ nk_bf16_t *output, // [query_len, head_dim]
410
+ nk_size_t query_len, nk_size_t kv_len, nk_size_t head_dim, nk_size_t head_dim_padded, nk_size_t dim_tile_count,
411
+ nk_f32_t scale) {
412
+
413
+ svbool_t const predicate_all_f32x = svptrue_b32();
414
+ svbool_t const predicate_all_f16x = svptrue_b16();
415
+ nk_size_t const valid_query_count = (query_len < 16) ? query_len : 16;
416
+
417
+ svfloat32_t row_max_f32x = svdup_f32(NK_F32_MIN);
418
+ svfloat32_t row_sum_f32x = svdup_f32(0.0f);
419
+
420
+ NK_ALIGN64 nk_f32_t output_accumulator[16 * 256];
421
+ svfloat32_t zero_f32x = svdup_f32(0.0f);
422
+ for (nk_size_t i = 0; i < 16 * head_dim_padded; i += svcntw()) {
423
+ svst1_f32(predicate_all_f32x, output_accumulator + i, zero_f32x);
424
+ }
425
+
426
+ nk_size_t kv_block_index = 0;
427
+ nk_size_t kv_start = 0;
428
+ svbool_t const batch_predicate_f32x = svwhilelt_b32(0u, 16u);
429
+
430
+ nk_size_t const k_depth_step_count = head_dim_padded / 2;
431
+
432
+ // Pre-transpose Q once: queries_transposed[step][16 f32 words]
433
+ NK_ALIGN64 nk_f32_t queries_transposed[128 * 16]; // max head_dim_padded/2 * 16 = 128 * 16
434
+ for (nk_size_t batch = 0; batch < head_dim_padded / 32; batch++) {
435
+ svzero_mask_za(nk_sme_zero_za32_tile_0_);
436
+ for (nk_size_t query_index = 0; query_index < valid_query_count; query_index++)
437
+ svld1_hor_za32(0, query_index, batch_predicate_f32x,
438
+ (nk_f32_t const *)(q + query_index * head_dim + batch * 32));
439
+ for (nk_size_t step = 0; step < 16; step++)
440
+ svst1_f32(predicate_all_f32x, queries_transposed + (batch * 16 + step) * 16,
441
+ svread_ver_za32_f32_m(svdup_f32(0), predicate_all_f32x, 0, step));
442
+ }
443
+
444
+ // Bc=32 main loop (prefill only, skipped for decode)
445
+ if (valid_query_count > 1) {
446
+ for (; kv_start + 32 <= kv_len; kv_start += 32, kv_block_index += 2) {
447
+ // Q×K^T: pure memory→BFMOPA, no ZA staging for Q or K
448
+ svzero_mask_za(nk_sme_zero_za32_tile_2_);
449
+ svzero_mask_za(nk_sme_zero_za32_tile_3_);
450
+ nk_bf16_t const *keys_block_lower = k + kv_block_index * k_depth_step_count * 32;
451
+ nk_bf16_t const *keys_block_upper = k + (kv_block_index + 1) * k_depth_step_count * 32;
452
+ for (nk_size_t step = 0; step < k_depth_step_count; step++) {
453
+ svbfloat16_t zn = svreinterpret_bf16_f32(svld1_f32(predicate_all_f32x, queries_transposed + step * 16));
454
+ svbfloat16_t zm0 = svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(keys_block_lower + step * 32));
455
+ svbfloat16_t zm1 = svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(keys_block_upper + step * 32));
456
+ svmopa_za32_bf16_m(2, predicate_all_f32x, predicate_all_f32x, zn, zm0);
457
+ svmopa_za32_bf16_m(3, predicate_all_f32x, predicate_all_f32x, zn, zm1);
458
+ }
459
+
460
+ // Pass 1: Column-wise max (read ZA2/ZA3 columns vertically)
461
+ svfloat32_t scale_f32x = svdup_f32(scale);
462
+ svfloat32_t block_max_f32x = svdup_f32(NK_F32_MIN);
463
+ for (nk_size_t column_index = 0; column_index < 16; column_index++) {
464
+ svfloat32_t score_column_f32x = svmul_f32_x(
465
+ predicate_all_f32x, svread_ver_za32_f32_m(svdup_f32(0), predicate_all_f32x, 2, column_index),
466
+ scale_f32x);
467
+ block_max_f32x = svmax_f32_x(predicate_all_f32x, block_max_f32x, score_column_f32x);
468
+ }
469
+ for (nk_size_t column_index = 0; column_index < 16; column_index++) {
470
+ svfloat32_t score_column_f32x = svmul_f32_x(
471
+ predicate_all_f32x, svread_ver_za32_f32_m(svdup_f32(0), predicate_all_f32x, 3, column_index),
472
+ scale_f32x);
473
+ block_max_f32x = svmax_f32_x(predicate_all_f32x, block_max_f32x, score_column_f32x);
474
+ }
475
+
476
+ // Softmax correction (fully vectorized)
477
+ svfloat32_t new_max_f32x = svmax_f32_x(predicate_all_f32x, row_max_f32x, block_max_f32x);
478
+ svfloat32_t correction_f32x = nk_exp_fast_f32_sve_(
479
+ predicate_all_f32x, svsub_f32_x(predicate_all_f32x, row_max_f32x, new_max_f32x));
480
+ svbool_t max_changed = svcmplt_f32(predicate_all_f32x, correction_f32x, svdup_f32(1.0f));
481
+ nk_u32_t max_was_updated = svptest_any(predicate_all_f32x, max_changed) ? 1 : 0;
482
+ if (max_was_updated) row_sum_f32x = svmul_f32_x(predicate_all_f32x, row_sum_f32x, correction_f32x);
483
+ NK_ALIGN64 nk_f32_t corrections[16];
484
+ svst1_f32(predicate_all_f32x, corrections, correction_f32x);
485
+
486
+ // Pass 2: Column-wise exp + fused P write + sum
487
+ svfloat32_t sum_delta_f32x = svdup_f32(0.0f);
488
+ svzero_mask_za(nk_sme_zero_za32_tile_0_);
489
+ // ZA2 columns in pairs → ZA0 columns 0-7
490
+ for (nk_size_t column_index = 0; column_index < 16; column_index += 2) {
491
+ svfloat32_t score_even_f32x = svmul_f32_x(
492
+ predicate_all_f32x, svread_ver_za32_f32_m(svdup_f32(0), predicate_all_f32x, 2, column_index),
493
+ scale_f32x);
494
+ svfloat32_t score_odd_f32x = svmul_f32_x(
495
+ predicate_all_f32x, svread_ver_za32_f32_m(svdup_f32(0), predicate_all_f32x, 2, column_index + 1),
496
+ scale_f32x);
497
+ svfloat32_t weight_even_f32x = nk_exp_fast_f32_sve_(
498
+ predicate_all_f32x, svsub_f32_x(predicate_all_f32x, score_even_f32x, new_max_f32x));
499
+ svfloat32_t weight_odd_f32x = nk_exp_fast_f32_sve_(
500
+ predicate_all_f32x, svsub_f32_x(predicate_all_f32x, score_odd_f32x, new_max_f32x));
501
+ sum_delta_f32x = svadd_f32_x(predicate_all_f32x, sum_delta_f32x, weight_even_f32x);
502
+ sum_delta_f32x = svadd_f32_x(predicate_all_f32x, sum_delta_f32x, weight_odd_f32x);
503
+ svbfloat16_t weight_pair_bf16 = svzip1_bf16(nk_f32_to_bf16_sve_(predicate_all_f32x, weight_even_f32x),
504
+ nk_f32_to_bf16_sve_(predicate_all_f32x, weight_odd_f32x));
505
+ svwrite_ver_za32_f32_m(0, column_index / 2, predicate_all_f32x,
506
+ svreinterpret_f32_bf16(weight_pair_bf16));
507
+ }
508
+ // ZA3 columns in pairs → ZA0 columns 8-15
509
+ for (nk_size_t column_index = 0; column_index < 16; column_index += 2) {
510
+ svfloat32_t score_even_f32x = svmul_f32_x(
511
+ predicate_all_f32x, svread_ver_za32_f32_m(svdup_f32(0), predicate_all_f32x, 3, column_index),
512
+ scale_f32x);
513
+ svfloat32_t score_odd_f32x = svmul_f32_x(
514
+ predicate_all_f32x, svread_ver_za32_f32_m(svdup_f32(0), predicate_all_f32x, 3, column_index + 1),
515
+ scale_f32x);
516
+ svfloat32_t weight_even_f32x = nk_exp_fast_f32_sve_(
517
+ predicate_all_f32x, svsub_f32_x(predicate_all_f32x, score_even_f32x, new_max_f32x));
518
+ svfloat32_t weight_odd_f32x = nk_exp_fast_f32_sve_(
519
+ predicate_all_f32x, svsub_f32_x(predicate_all_f32x, score_odd_f32x, new_max_f32x));
520
+ sum_delta_f32x = svadd_f32_x(predicate_all_f32x, sum_delta_f32x, weight_even_f32x);
521
+ sum_delta_f32x = svadd_f32_x(predicate_all_f32x, sum_delta_f32x, weight_odd_f32x);
522
+ svbfloat16_t weight_pair_bf16 = svzip1_bf16(nk_f32_to_bf16_sve_(predicate_all_f32x, weight_even_f32x),
523
+ nk_f32_to_bf16_sve_(predicate_all_f32x, weight_odd_f32x));
524
+ svwrite_ver_za32_f32_m(0, 8 + column_index / 2, predicate_all_f32x,
525
+ svreinterpret_f32_bf16(weight_pair_bf16));
526
+ }
527
+ row_sum_f32x = svadd_f32_x(predicate_all_f32x, row_sum_f32x, sum_delta_f32x);
528
+ row_max_f32x = new_max_f32x;
529
+
530
+ // Extract P columns from ZA0
531
+ svbfloat16_t probability_column_0_f32x = svreinterpret_bf16_f32(
532
+ svread_ver_za32_f32_m(svdup_f32(0), predicate_all_f32x, 0, 0));
533
+ svbfloat16_t probability_column_1_f32x = svreinterpret_bf16_f32(
534
+ svread_ver_za32_f32_m(svdup_f32(0), predicate_all_f32x, 0, 1));
535
+ svbfloat16_t probability_column_2_f32x = svreinterpret_bf16_f32(
536
+ svread_ver_za32_f32_m(svdup_f32(0), predicate_all_f32x, 0, 2));
537
+ svbfloat16_t probability_column_3_f32x = svreinterpret_bf16_f32(
538
+ svread_ver_za32_f32_m(svdup_f32(0), predicate_all_f32x, 0, 3));
539
+ svbfloat16_t probability_column_4_f32x = svreinterpret_bf16_f32(
540
+ svread_ver_za32_f32_m(svdup_f32(0), predicate_all_f32x, 0, 4));
541
+ svbfloat16_t probability_column_5_f32x = svreinterpret_bf16_f32(
542
+ svread_ver_za32_f32_m(svdup_f32(0), predicate_all_f32x, 0, 5));
543
+ svbfloat16_t probability_column_6_f32x = svreinterpret_bf16_f32(
544
+ svread_ver_za32_f32_m(svdup_f32(0), predicate_all_f32x, 0, 6));
545
+ svbfloat16_t probability_column_7_f32x = svreinterpret_bf16_f32(
546
+ svread_ver_za32_f32_m(svdup_f32(0), predicate_all_f32x, 0, 7));
547
+ svbfloat16_t probability_column_8_f32x = svreinterpret_bf16_f32(
548
+ svread_ver_za32_f32_m(svdup_f32(0), predicate_all_f32x, 0, 8));
549
+ svbfloat16_t probability_column_9_f32x = svreinterpret_bf16_f32(
550
+ svread_ver_za32_f32_m(svdup_f32(0), predicate_all_f32x, 0, 9));
551
+ svbfloat16_t probability_column_10_f32x = svreinterpret_bf16_f32(
552
+ svread_ver_za32_f32_m(svdup_f32(0), predicate_all_f32x, 0, 10));
553
+ svbfloat16_t probability_column_11_f32x = svreinterpret_bf16_f32(
554
+ svread_ver_za32_f32_m(svdup_f32(0), predicate_all_f32x, 0, 11));
555
+ svbfloat16_t probability_column_12_f32x = svreinterpret_bf16_f32(
556
+ svread_ver_za32_f32_m(svdup_f32(0), predicate_all_f32x, 0, 12));
557
+ svbfloat16_t probability_column_13_f32x = svreinterpret_bf16_f32(
558
+ svread_ver_za32_f32_m(svdup_f32(0), predicate_all_f32x, 0, 13));
559
+ svbfloat16_t probability_column_14_f32x = svreinterpret_bf16_f32(
560
+ svread_ver_za32_f32_m(svdup_f32(0), predicate_all_f32x, 0, 14));
561
+ svbfloat16_t probability_column_15_f32x = svreinterpret_bf16_f32(
562
+ svread_ver_za32_f32_m(svdup_f32(0), predicate_all_f32x, 0, 15));
563
+
564
+ // Pre-apply correction once before P×V
565
+ svbool_t query_predicate_f16x = svwhilelt_b16_u64(0u, valid_query_count * 2);
566
+ nk_bf16_t const *values_block_lower = v_packed + kv_block_index * dim_tile_count * 8 * 32;
567
+ nk_bf16_t const *values_block_upper = v_packed + (kv_block_index + 1) * dim_tile_count * 8 * 32;
568
+
569
+ if (max_was_updated) {
570
+ for (nk_size_t query_index = 0; query_index < valid_query_count; query_index++) {
571
+ svfloat32_t correction_scalar_f32x = svdup_f32(corrections[query_index]);
572
+ for (nk_size_t dim_offset = 0; dim_offset < head_dim_padded; dim_offset += 16)
573
+ svst1_f32(
574
+ predicate_all_f32x, output_accumulator + query_index * head_dim_padded + dim_offset,
575
+ svmul_f32_x(predicate_all_f32x,
576
+ svld1_f32(predicate_all_f32x,
577
+ output_accumulator + query_index * head_dim_padded + dim_offset),
578
+ correction_scalar_f32x));
579
+ }
580
+ }
581
+
582
+ // P×V: zero → BFMOPA → read → add (no ZA writes for output_accumulator)
583
+ nk_size_t dim_tile = 0;
584
+ for (; dim_tile + 4 <= dim_tile_count; dim_tile += 4) {
585
+ svzero_za();
586
+ // Block0: 8 depth steps (KV positions 0-15)
587
+ svmopa_za32_bf16_m(0, query_predicate_f16x, predicate_all_f16x, probability_column_0_f32x,
588
+ svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(values_block_lower +
589
+ ((dim_tile + 0) * 8 + 0) * 32)));
590
+ svmopa_za32_bf16_m(1, query_predicate_f16x, predicate_all_f16x, probability_column_0_f32x,
591
+ svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(values_block_lower +
592
+ ((dim_tile + 1) * 8 + 0) * 32)));
593
+ svmopa_za32_bf16_m(2, query_predicate_f16x, predicate_all_f16x, probability_column_0_f32x,
594
+ svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(values_block_lower +
595
+ ((dim_tile + 2) * 8 + 0) * 32)));
596
+ svmopa_za32_bf16_m(3, query_predicate_f16x, predicate_all_f16x, probability_column_0_f32x,
597
+ svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(values_block_lower +
598
+ ((dim_tile + 3) * 8 + 0) * 32)));
599
+ svmopa_za32_bf16_m(0, query_predicate_f16x, predicate_all_f16x, probability_column_1_f32x,
600
+ svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(values_block_lower +
601
+ ((dim_tile + 0) * 8 + 1) * 32)));
602
+ svmopa_za32_bf16_m(1, query_predicate_f16x, predicate_all_f16x, probability_column_1_f32x,
603
+ svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(values_block_lower +
604
+ ((dim_tile + 1) * 8 + 1) * 32)));
605
+ svmopa_za32_bf16_m(2, query_predicate_f16x, predicate_all_f16x, probability_column_1_f32x,
606
+ svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(values_block_lower +
607
+ ((dim_tile + 2) * 8 + 1) * 32)));
608
+ svmopa_za32_bf16_m(3, query_predicate_f16x, predicate_all_f16x, probability_column_1_f32x,
609
+ svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(values_block_lower +
610
+ ((dim_tile + 3) * 8 + 1) * 32)));
611
+ svmopa_za32_bf16_m(0, query_predicate_f16x, predicate_all_f16x, probability_column_2_f32x,
612
+ svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(values_block_lower +
613
+ ((dim_tile + 0) * 8 + 2) * 32)));
614
+ svmopa_za32_bf16_m(1, query_predicate_f16x, predicate_all_f16x, probability_column_2_f32x,
615
+ svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(values_block_lower +
616
+ ((dim_tile + 1) * 8 + 2) * 32)));
617
+ svmopa_za32_bf16_m(2, query_predicate_f16x, predicate_all_f16x, probability_column_2_f32x,
618
+ svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(values_block_lower +
619
+ ((dim_tile + 2) * 8 + 2) * 32)));
620
+ svmopa_za32_bf16_m(3, query_predicate_f16x, predicate_all_f16x, probability_column_2_f32x,
621
+ svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(values_block_lower +
622
+ ((dim_tile + 3) * 8 + 2) * 32)));
623
+ svmopa_za32_bf16_m(0, query_predicate_f16x, predicate_all_f16x, probability_column_3_f32x,
624
+ svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(values_block_lower +
625
+ ((dim_tile + 0) * 8 + 3) * 32)));
626
+ svmopa_za32_bf16_m(1, query_predicate_f16x, predicate_all_f16x, probability_column_3_f32x,
627
+ svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(values_block_lower +
628
+ ((dim_tile + 1) * 8 + 3) * 32)));
629
+ svmopa_za32_bf16_m(2, query_predicate_f16x, predicate_all_f16x, probability_column_3_f32x,
630
+ svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(values_block_lower +
631
+ ((dim_tile + 2) * 8 + 3) * 32)));
632
+ svmopa_za32_bf16_m(3, query_predicate_f16x, predicate_all_f16x, probability_column_3_f32x,
633
+ svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(values_block_lower +
634
+ ((dim_tile + 3) * 8 + 3) * 32)));
635
+ svmopa_za32_bf16_m(0, query_predicate_f16x, predicate_all_f16x, probability_column_4_f32x,
636
+ svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(values_block_lower +
637
+ ((dim_tile + 0) * 8 + 4) * 32)));
638
+ svmopa_za32_bf16_m(1, query_predicate_f16x, predicate_all_f16x, probability_column_4_f32x,
639
+ svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(values_block_lower +
640
+ ((dim_tile + 1) * 8 + 4) * 32)));
641
+ svmopa_za32_bf16_m(2, query_predicate_f16x, predicate_all_f16x, probability_column_4_f32x,
642
+ svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(values_block_lower +
643
+ ((dim_tile + 2) * 8 + 4) * 32)));
644
+ svmopa_za32_bf16_m(3, query_predicate_f16x, predicate_all_f16x, probability_column_4_f32x,
645
+ svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(values_block_lower +
646
+ ((dim_tile + 3) * 8 + 4) * 32)));
647
+ svmopa_za32_bf16_m(0, query_predicate_f16x, predicate_all_f16x, probability_column_5_f32x,
648
+ svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(values_block_lower +
649
+ ((dim_tile + 0) * 8 + 5) * 32)));
650
+ svmopa_za32_bf16_m(1, query_predicate_f16x, predicate_all_f16x, probability_column_5_f32x,
651
+ svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(values_block_lower +
652
+ ((dim_tile + 1) * 8 + 5) * 32)));
653
+ svmopa_za32_bf16_m(2, query_predicate_f16x, predicate_all_f16x, probability_column_5_f32x,
654
+ svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(values_block_lower +
655
+ ((dim_tile + 2) * 8 + 5) * 32)));
656
+ svmopa_za32_bf16_m(3, query_predicate_f16x, predicate_all_f16x, probability_column_5_f32x,
657
+ svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(values_block_lower +
658
+ ((dim_tile + 3) * 8 + 5) * 32)));
659
+ svmopa_za32_bf16_m(0, query_predicate_f16x, predicate_all_f16x, probability_column_6_f32x,
660
+ svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(values_block_lower +
661
+ ((dim_tile + 0) * 8 + 6) * 32)));
662
+ svmopa_za32_bf16_m(1, query_predicate_f16x, predicate_all_f16x, probability_column_6_f32x,
663
+ svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(values_block_lower +
664
+ ((dim_tile + 1) * 8 + 6) * 32)));
665
+ svmopa_za32_bf16_m(2, query_predicate_f16x, predicate_all_f16x, probability_column_6_f32x,
666
+ svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(values_block_lower +
667
+ ((dim_tile + 2) * 8 + 6) * 32)));
668
+ svmopa_za32_bf16_m(3, query_predicate_f16x, predicate_all_f16x, probability_column_6_f32x,
669
+ svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(values_block_lower +
670
+ ((dim_tile + 3) * 8 + 6) * 32)));
671
+ svmopa_za32_bf16_m(0, query_predicate_f16x, predicate_all_f16x, probability_column_7_f32x,
672
+ svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(values_block_lower +
673
+ ((dim_tile + 0) * 8 + 7) * 32)));
674
+ svmopa_za32_bf16_m(1, query_predicate_f16x, predicate_all_f16x, probability_column_7_f32x,
675
+ svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(values_block_lower +
676
+ ((dim_tile + 1) * 8 + 7) * 32)));
677
+ svmopa_za32_bf16_m(2, query_predicate_f16x, predicate_all_f16x, probability_column_7_f32x,
678
+ svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(values_block_lower +
679
+ ((dim_tile + 2) * 8 + 7) * 32)));
680
+ svmopa_za32_bf16_m(3, query_predicate_f16x, predicate_all_f16x, probability_column_7_f32x,
681
+ svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(values_block_lower +
682
+ ((dim_tile + 3) * 8 + 7) * 32)));
683
+ // Block1: 8 depth steps (KV positions 16-31)
684
+ svmopa_za32_bf16_m(0, query_predicate_f16x, predicate_all_f16x, probability_column_8_f32x,
685
+ svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(values_block_upper +
686
+ ((dim_tile + 0) * 8 + 0) * 32)));
687
+ svmopa_za32_bf16_m(1, query_predicate_f16x, predicate_all_f16x, probability_column_8_f32x,
688
+ svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(values_block_upper +
689
+ ((dim_tile + 1) * 8 + 0) * 32)));
690
+ svmopa_za32_bf16_m(2, query_predicate_f16x, predicate_all_f16x, probability_column_8_f32x,
691
+ svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(values_block_upper +
692
+ ((dim_tile + 2) * 8 + 0) * 32)));
693
+ svmopa_za32_bf16_m(3, query_predicate_f16x, predicate_all_f16x, probability_column_8_f32x,
694
+ svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(values_block_upper +
695
+ ((dim_tile + 3) * 8 + 0) * 32)));
696
+ svmopa_za32_bf16_m(0, query_predicate_f16x, predicate_all_f16x, probability_column_9_f32x,
697
+ svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(values_block_upper +
698
+ ((dim_tile + 0) * 8 + 1) * 32)));
699
+ svmopa_za32_bf16_m(1, query_predicate_f16x, predicate_all_f16x, probability_column_9_f32x,
700
+ svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(values_block_upper +
701
+ ((dim_tile + 1) * 8 + 1) * 32)));
702
+ svmopa_za32_bf16_m(2, query_predicate_f16x, predicate_all_f16x, probability_column_9_f32x,
703
+ svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(values_block_upper +
704
+ ((dim_tile + 2) * 8 + 1) * 32)));
705
+ svmopa_za32_bf16_m(3, query_predicate_f16x, predicate_all_f16x, probability_column_9_f32x,
706
+ svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(values_block_upper +
707
+ ((dim_tile + 3) * 8 + 1) * 32)));
708
+ svmopa_za32_bf16_m(0, query_predicate_f16x, predicate_all_f16x, probability_column_10_f32x,
709
+ svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(values_block_upper +
710
+ ((dim_tile + 0) * 8 + 2) * 32)));
711
+ svmopa_za32_bf16_m(1, query_predicate_f16x, predicate_all_f16x, probability_column_10_f32x,
712
+ svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(values_block_upper +
713
+ ((dim_tile + 1) * 8 + 2) * 32)));
714
+ svmopa_za32_bf16_m(2, query_predicate_f16x, predicate_all_f16x, probability_column_10_f32x,
715
+ svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(values_block_upper +
716
+ ((dim_tile + 2) * 8 + 2) * 32)));
717
+ svmopa_za32_bf16_m(3, query_predicate_f16x, predicate_all_f16x, probability_column_10_f32x,
718
+ svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(values_block_upper +
719
+ ((dim_tile + 3) * 8 + 2) * 32)));
720
+ svmopa_za32_bf16_m(0, query_predicate_f16x, predicate_all_f16x, probability_column_11_f32x,
721
+ svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(values_block_upper +
722
+ ((dim_tile + 0) * 8 + 3) * 32)));
723
+ svmopa_za32_bf16_m(1, query_predicate_f16x, predicate_all_f16x, probability_column_11_f32x,
724
+ svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(values_block_upper +
725
+ ((dim_tile + 1) * 8 + 3) * 32)));
726
+ svmopa_za32_bf16_m(2, query_predicate_f16x, predicate_all_f16x, probability_column_11_f32x,
727
+ svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(values_block_upper +
728
+ ((dim_tile + 2) * 8 + 3) * 32)));
729
+ svmopa_za32_bf16_m(3, query_predicate_f16x, predicate_all_f16x, probability_column_11_f32x,
730
+ svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(values_block_upper +
731
+ ((dim_tile + 3) * 8 + 3) * 32)));
732
+ svmopa_za32_bf16_m(0, query_predicate_f16x, predicate_all_f16x, probability_column_12_f32x,
733
+ svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(values_block_upper +
734
+ ((dim_tile + 0) * 8 + 4) * 32)));
735
+ svmopa_za32_bf16_m(1, query_predicate_f16x, predicate_all_f16x, probability_column_12_f32x,
736
+ svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(values_block_upper +
737
+ ((dim_tile + 1) * 8 + 4) * 32)));
738
+ svmopa_za32_bf16_m(2, query_predicate_f16x, predicate_all_f16x, probability_column_12_f32x,
739
+ svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(values_block_upper +
740
+ ((dim_tile + 2) * 8 + 4) * 32)));
741
+ svmopa_za32_bf16_m(3, query_predicate_f16x, predicate_all_f16x, probability_column_12_f32x,
742
+ svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(values_block_upper +
743
+ ((dim_tile + 3) * 8 + 4) * 32)));
744
+ svmopa_za32_bf16_m(0, query_predicate_f16x, predicate_all_f16x, probability_column_13_f32x,
745
+ svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(values_block_upper +
746
+ ((dim_tile + 0) * 8 + 5) * 32)));
747
+ svmopa_za32_bf16_m(1, query_predicate_f16x, predicate_all_f16x, probability_column_13_f32x,
748
+ svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(values_block_upper +
749
+ ((dim_tile + 1) * 8 + 5) * 32)));
750
+ svmopa_za32_bf16_m(2, query_predicate_f16x, predicate_all_f16x, probability_column_13_f32x,
751
+ svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(values_block_upper +
752
+ ((dim_tile + 2) * 8 + 5) * 32)));
753
+ svmopa_za32_bf16_m(3, query_predicate_f16x, predicate_all_f16x, probability_column_13_f32x,
754
+ svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(values_block_upper +
755
+ ((dim_tile + 3) * 8 + 5) * 32)));
756
+ svmopa_za32_bf16_m(0, query_predicate_f16x, predicate_all_f16x, probability_column_14_f32x,
757
+ svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(values_block_upper +
758
+ ((dim_tile + 0) * 8 + 6) * 32)));
759
+ svmopa_za32_bf16_m(1, query_predicate_f16x, predicate_all_f16x, probability_column_14_f32x,
760
+ svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(values_block_upper +
761
+ ((dim_tile + 1) * 8 + 6) * 32)));
762
+ svmopa_za32_bf16_m(2, query_predicate_f16x, predicate_all_f16x, probability_column_14_f32x,
763
+ svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(values_block_upper +
764
+ ((dim_tile + 2) * 8 + 6) * 32)));
765
+ svmopa_za32_bf16_m(3, query_predicate_f16x, predicate_all_f16x, probability_column_14_f32x,
766
+ svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(values_block_upper +
767
+ ((dim_tile + 3) * 8 + 6) * 32)));
768
+ svmopa_za32_bf16_m(0, query_predicate_f16x, predicate_all_f16x, probability_column_15_f32x,
769
+ svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(values_block_upper +
770
+ ((dim_tile + 0) * 8 + 7) * 32)));
771
+ svmopa_za32_bf16_m(1, query_predicate_f16x, predicate_all_f16x, probability_column_15_f32x,
772
+ svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(values_block_upper +
773
+ ((dim_tile + 1) * 8 + 7) * 32)));
774
+ svmopa_za32_bf16_m(2, query_predicate_f16x, predicate_all_f16x, probability_column_15_f32x,
775
+ svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(values_block_upper +
776
+ ((dim_tile + 2) * 8 + 7) * 32)));
777
+ svmopa_za32_bf16_m(3, query_predicate_f16x, predicate_all_f16x, probability_column_15_f32x,
778
+ svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(values_block_upper +
779
+ ((dim_tile + 3) * 8 + 7) * 32)));
780
+ // Read BFMOPA result and ADD to output_accumulator
781
+ for (nk_size_t query_index = 0; query_index < valid_query_count; query_index++) {
782
+ svst1_f32(
783
+ predicate_all_f32x, output_accumulator + query_index * head_dim_padded + (dim_tile + 0) * 16,
784
+ svadd_f32_x(predicate_all_f32x,
785
+ svld1_f32(predicate_all_f32x,
786
+ output_accumulator + query_index * head_dim_padded + (dim_tile + 0) * 16),
787
+ svread_hor_za32_f32_m(svdup_f32(0), predicate_all_f32x, 0, query_index)));
788
+ svst1_f32(
789
+ predicate_all_f32x, output_accumulator + query_index * head_dim_padded + (dim_tile + 1) * 16,
790
+ svadd_f32_x(predicate_all_f32x,
791
+ svld1_f32(predicate_all_f32x,
792
+ output_accumulator + query_index * head_dim_padded + (dim_tile + 1) * 16),
793
+ svread_hor_za32_f32_m(svdup_f32(0), predicate_all_f32x, 1, query_index)));
794
+ svst1_f32(
795
+ predicate_all_f32x, output_accumulator + query_index * head_dim_padded + (dim_tile + 2) * 16,
796
+ svadd_f32_x(predicate_all_f32x,
797
+ svld1_f32(predicate_all_f32x,
798
+ output_accumulator + query_index * head_dim_padded + (dim_tile + 2) * 16),
799
+ svread_hor_za32_f32_m(svdup_f32(0), predicate_all_f32x, 2, query_index)));
800
+ svst1_f32(
801
+ predicate_all_f32x, output_accumulator + query_index * head_dim_padded + (dim_tile + 3) * 16,
802
+ svadd_f32_x(predicate_all_f32x,
803
+ svld1_f32(predicate_all_f32x,
804
+ output_accumulator + query_index * head_dim_padded + (dim_tile + 3) * 16),
805
+ svread_hor_za32_f32_m(svdup_f32(0), predicate_all_f32x, 3, query_index)));
806
+ }
807
+ }
808
+ // Remainder: 1 dim_tile at a time using ZA0
809
+ for (; dim_tile < dim_tile_count; dim_tile++) {
810
+ svzero_mask_za(nk_sme_zero_za32_tile_0_);
811
+ svmopa_za32_bf16_m(
812
+ 0, query_predicate_f16x, predicate_all_f16x, probability_column_0_f32x,
813
+ svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(values_block_lower + (dim_tile * 8 + 0) * 32)));
814
+ svmopa_za32_bf16_m(
815
+ 0, query_predicate_f16x, predicate_all_f16x, probability_column_1_f32x,
816
+ svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(values_block_lower + (dim_tile * 8 + 1) * 32)));
817
+ svmopa_za32_bf16_m(
818
+ 0, query_predicate_f16x, predicate_all_f16x, probability_column_2_f32x,
819
+ svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(values_block_lower + (dim_tile * 8 + 2) * 32)));
820
+ svmopa_za32_bf16_m(
821
+ 0, query_predicate_f16x, predicate_all_f16x, probability_column_3_f32x,
822
+ svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(values_block_lower + (dim_tile * 8 + 3) * 32)));
823
+ svmopa_za32_bf16_m(
824
+ 0, query_predicate_f16x, predicate_all_f16x, probability_column_4_f32x,
825
+ svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(values_block_lower + (dim_tile * 8 + 4) * 32)));
826
+ svmopa_za32_bf16_m(
827
+ 0, query_predicate_f16x, predicate_all_f16x, probability_column_5_f32x,
828
+ svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(values_block_lower + (dim_tile * 8 + 5) * 32)));
829
+ svmopa_za32_bf16_m(
830
+ 0, query_predicate_f16x, predicate_all_f16x, probability_column_6_f32x,
831
+ svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(values_block_lower + (dim_tile * 8 + 6) * 32)));
832
+ svmopa_za32_bf16_m(
833
+ 0, query_predicate_f16x, predicate_all_f16x, probability_column_7_f32x,
834
+ svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(values_block_lower + (dim_tile * 8 + 7) * 32)));
835
+ svmopa_za32_bf16_m(
836
+ 0, query_predicate_f16x, predicate_all_f16x, probability_column_8_f32x,
837
+ svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(values_block_upper + (dim_tile * 8 + 0) * 32)));
838
+ svmopa_za32_bf16_m(
839
+ 0, query_predicate_f16x, predicate_all_f16x, probability_column_9_f32x,
840
+ svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(values_block_upper + (dim_tile * 8 + 1) * 32)));
841
+ svmopa_za32_bf16_m(
842
+ 0, query_predicate_f16x, predicate_all_f16x, probability_column_10_f32x,
843
+ svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(values_block_upper + (dim_tile * 8 + 2) * 32)));
844
+ svmopa_za32_bf16_m(
845
+ 0, query_predicate_f16x, predicate_all_f16x, probability_column_11_f32x,
846
+ svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(values_block_upper + (dim_tile * 8 + 3) * 32)));
847
+ svmopa_za32_bf16_m(
848
+ 0, query_predicate_f16x, predicate_all_f16x, probability_column_12_f32x,
849
+ svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(values_block_upper + (dim_tile * 8 + 4) * 32)));
850
+ svmopa_za32_bf16_m(
851
+ 0, query_predicate_f16x, predicate_all_f16x, probability_column_13_f32x,
852
+ svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(values_block_upper + (dim_tile * 8 + 5) * 32)));
853
+ svmopa_za32_bf16_m(
854
+ 0, query_predicate_f16x, predicate_all_f16x, probability_column_14_f32x,
855
+ svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(values_block_upper + (dim_tile * 8 + 6) * 32)));
856
+ svmopa_za32_bf16_m(
857
+ 0, query_predicate_f16x, predicate_all_f16x, probability_column_15_f32x,
858
+ svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(values_block_upper + (dim_tile * 8 + 7) * 32)));
859
+ for (nk_size_t query_index = 0; query_index < valid_query_count; query_index++)
860
+ svst1_f32(predicate_all_f32x, output_accumulator + query_index * head_dim_padded + dim_tile * 16,
861
+ svadd_f32_x(predicate_all_f32x,
862
+ svld1_f32(predicate_all_f32x,
863
+ output_accumulator + query_index * head_dim_padded + dim_tile * 16),
864
+ svread_hor_za32_f32_m(svdup_f32(0), predicate_all_f32x, 0, query_index)));
865
+ }
866
+ }
867
+ }
868
+
869
+ // Bc=16 tail loop (handles remaining KV positions and decode path)
870
+ for (; kv_start < kv_len; kv_start += 16, kv_block_index++) {
871
+ nk_size_t const valid_kv = ((kv_start + 16) <= kv_len) ? 16 : (kv_len - kv_start);
872
+
873
+ // Q×K^T: pure memory→BFMOPA, no ZA staging
874
+ svzero_mask_za(nk_sme_zero_za32_tile_2_);
875
+ nk_bf16_t const *k_block = k + kv_block_index * k_depth_step_count * 32;
876
+ for (nk_size_t step = 0; step < k_depth_step_count; step++) {
877
+ svbfloat16_t zn = svreinterpret_bf16_f32(svld1_f32(predicate_all_f32x, queries_transposed + step * 16));
878
+ svbfloat16_t zm = svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(k_block + step * 32));
879
+ svmopa_za32_bf16_m(2, predicate_all_f32x, predicate_all_f32x, zn, zm);
880
+ }
881
+
882
+ // Pass 1: Column-wise max (read ZA2 columns vertically)
883
+ svfloat32_t scale_16_f32x = svdup_f32(scale);
884
+ svfloat32_t block_max_16_f32x = svdup_f32(NK_F32_MIN);
885
+ for (nk_size_t column_index = 0; column_index < 16; column_index++) {
886
+ svfloat32_t score_column_f32x = svmul_f32_x(
887
+ predicate_all_f32x, svread_ver_za32_f32_m(svdup_f32(0), predicate_all_f32x, 2, column_index),
888
+ scale_16_f32x);
889
+ block_max_16_f32x = svmax_f32_x(predicate_all_f32x, block_max_16_f32x, score_column_f32x);
890
+ }
891
+
892
+ // Softmax correction (fully vectorized)
893
+ svfloat32_t new_max_f32x = svmax_f32_x(predicate_all_f32x, row_max_f32x, block_max_16_f32x);
894
+ svfloat32_t correction_f32x = nk_exp_fast_f32_sve_(predicate_all_f32x,
895
+ svsub_f32_x(predicate_all_f32x, row_max_f32x, new_max_f32x));
896
+ svbool_t max_changed_16 = svcmplt_f32(predicate_all_f32x, correction_f32x, svdup_f32(1.0f));
897
+ nk_u32_t max_was_updated_16 = svptest_any(predicate_all_f32x, max_changed_16) ? 1 : 0;
898
+ if (max_was_updated_16) row_sum_f32x = svmul_f32_x(predicate_all_f32x, row_sum_f32x, correction_f32x);
899
+ NK_ALIGN64 nk_f32_t corrections[16];
900
+ svst1_f32(predicate_all_f32x, corrections, correction_f32x);
901
+
902
+ // Pass 2: Column-wise exp + fused P write + sum (ZA2 → ZA0 columns 0-7)
903
+ svfloat32_t sum_delta_16_f32x = svdup_f32(0.0f);
904
+ svzero_mask_za(nk_sme_zero_za32_tile_0_);
905
+ for (nk_size_t column_index = 0; column_index < 16; column_index += 2) {
906
+ svfloat32_t score_even_f32x = svmul_f32_x(
907
+ predicate_all_f32x, svread_ver_za32_f32_m(svdup_f32(0), predicate_all_f32x, 2, column_index),
908
+ scale_16_f32x);
909
+ svfloat32_t score_odd_f32x = svmul_f32_x(
910
+ predicate_all_f32x, svread_ver_za32_f32_m(svdup_f32(0), predicate_all_f32x, 2, column_index + 1),
911
+ scale_16_f32x);
912
+ svfloat32_t weight_even_f32x = nk_exp_fast_f32_sve_(
913
+ predicate_all_f32x, svsub_f32_x(predicate_all_f32x, score_even_f32x, new_max_f32x));
914
+ svfloat32_t weight_odd_f32x = nk_exp_fast_f32_sve_(
915
+ predicate_all_f32x, svsub_f32_x(predicate_all_f32x, score_odd_f32x, new_max_f32x));
916
+ sum_delta_16_f32x = svadd_f32_x(predicate_all_f32x, sum_delta_16_f32x, weight_even_f32x);
917
+ sum_delta_16_f32x = svadd_f32_x(predicate_all_f32x, sum_delta_16_f32x, weight_odd_f32x);
918
+ svbfloat16_t weight_pair_bf16 = svzip1_bf16(nk_f32_to_bf16_sve_(predicate_all_f32x, weight_even_f32x),
919
+ nk_f32_to_bf16_sve_(predicate_all_f32x, weight_odd_f32x));
920
+ svwrite_ver_za32_f32_m(0, column_index / 2, predicate_all_f32x, svreinterpret_f32_bf16(weight_pair_bf16));
921
+ }
922
+ row_sum_f32x = svadd_f32_x(predicate_all_f32x, row_sum_f32x, sum_delta_16_f32x);
923
+ row_max_f32x = new_max_f32x;
924
+
925
+ if (valid_query_count == 1) {
926
+ // Decode path: extract f32 weights from ZA0 row 0 using SVE
927
+ svbfloat16_t row0_bf16 = svreinterpret_bf16_f32(
928
+ svread_hor_za32_f32_m(svdup_f32(0), predicate_all_f32x, 0, 0));
929
+ svbfloat16_t weights_even_bf16 = svuzp1_bf16(row0_bf16, row0_bf16);
930
+ svbfloat16_t weights_odd_bf16 = svuzp2_bf16(row0_bf16, row0_bf16);
931
+ NK_ALIGN64 nk_f32_t decode_weights[16];
932
+ svst1_f32(svwhilelt_b32(0u, 8u), decode_weights,
933
+ nk_bf16_to_f32_sve_(svwhilelt_b32(0u, 8u), weights_even_bf16));
934
+ svst1_f32(svwhilelt_b32(0u, 8u), decode_weights + 8,
935
+ nk_bf16_to_f32_sve_(svwhilelt_b32(0u, 8u), weights_odd_bf16));
936
+ NK_ALIGN64 nk_f32_t decode_weights_ordered[16];
937
+ for (nk_size_t i = 0; i < 8; i++) {
938
+ decode_weights_ordered[2 * i] = decode_weights[i];
939
+ decode_weights_ordered[2 * i + 1] = decode_weights[8 + i];
940
+ }
941
+ svfloat32_t corr_f32x = svdup_f32(corrections[0]);
942
+ for (nk_size_t d = 0; d < head_dim; d += svcntw()) {
943
+ svbool_t predicate_f32x = svwhilelt_b32_u64(d, head_dim);
944
+ svfloat32_t acc_f32x = svmul_f32_x(predicate_f32x, svld1_f32(predicate_f32x, output_accumulator + d),
945
+ corr_f32x);
946
+ for (nk_size_t ki = 0; ki < valid_kv; ki++) {
947
+ nk_size_t dim_tile = d / 16, depth_s = ki / 2, sub = ki % 2;
948
+ nk_bf16_t const *v_vec = v_packed +
949
+ (kv_block_index * dim_tile_count * 8 + dim_tile * 8 + depth_s) * 32;
950
+ svbfloat16_t packed_bf16x = svld1_bf16(predicate_all_f16x, (bfloat16_t const *)v_vec);
951
+ svbfloat16_t v_selected = (sub == 0) ? svuzp1_bf16(packed_bf16x, packed_bf16x)
952
+ : svuzp2_bf16(packed_bf16x, packed_bf16x);
953
+ acc_f32x = svmla_f32_x(predicate_f32x, acc_f32x, svdup_f32(decode_weights_ordered[ki]),
954
+ nk_bf16_to_f32_sve_(predicate_f32x, v_selected));
955
+ }
956
+ svst1_f32(predicate_f32x, output_accumulator + d, acc_f32x);
957
+ }
958
+ }
959
+ else {
960
+ // Prefill Bc=16: extract P columns, pre-apply correction, add-after P×V
961
+ svbool_t query_predicate_f16x = svwhilelt_b16_u64(0u, valid_query_count * 2);
962
+
963
+ svbfloat16_t probability_column_0_f32x = svreinterpret_bf16_f32(
964
+ svread_ver_za32_f32_m(svdup_f32(0), predicate_all_f32x, 0, 0));
965
+ svbfloat16_t probability_column_1_f32x = svreinterpret_bf16_f32(
966
+ svread_ver_za32_f32_m(svdup_f32(0), predicate_all_f32x, 0, 1));
967
+ svbfloat16_t probability_column_2_f32x = svreinterpret_bf16_f32(
968
+ svread_ver_za32_f32_m(svdup_f32(0), predicate_all_f32x, 0, 2));
969
+ svbfloat16_t probability_column_3_f32x = svreinterpret_bf16_f32(
970
+ svread_ver_za32_f32_m(svdup_f32(0), predicate_all_f32x, 0, 3));
971
+ svbfloat16_t probability_column_4_f32x = svreinterpret_bf16_f32(
972
+ svread_ver_za32_f32_m(svdup_f32(0), predicate_all_f32x, 0, 4));
973
+ svbfloat16_t probability_column_5_f32x = svreinterpret_bf16_f32(
974
+ svread_ver_za32_f32_m(svdup_f32(0), predicate_all_f32x, 0, 5));
975
+ svbfloat16_t probability_column_6_f32x = svreinterpret_bf16_f32(
976
+ svread_ver_za32_f32_m(svdup_f32(0), predicate_all_f32x, 0, 6));
977
+ svbfloat16_t probability_column_7_f32x = svreinterpret_bf16_f32(
978
+ svread_ver_za32_f32_m(svdup_f32(0), predicate_all_f32x, 0, 7));
979
+
980
+ nk_bf16_t const *v_block = v_packed + kv_block_index * dim_tile_count * 8 * 32;
981
+
982
+ // Pre-apply correction
983
+ if (max_was_updated_16) {
984
+ for (nk_size_t query_index = 0; query_index < valid_query_count; query_index++) {
985
+ svfloat32_t correction_scalar_f32x = svdup_f32(corrections[query_index]);
986
+ for (nk_size_t dim_offset = 0; dim_offset < head_dim_padded; dim_offset += 16)
987
+ svst1_f32(
988
+ predicate_all_f32x, output_accumulator + query_index * head_dim_padded + dim_offset,
989
+ svmul_f32_x(predicate_all_f32x,
990
+ svld1_f32(predicate_all_f32x,
991
+ output_accumulator + query_index * head_dim_padded + dim_offset),
992
+ correction_scalar_f32x));
993
+ }
994
+ }
995
+
996
+ // P×V: zero → BFMOPA → read → add
997
+ nk_size_t dim_tile = 0;
998
+ for (; dim_tile + 4 <= dim_tile_count; dim_tile += 4) {
999
+ svzero_za();
1000
+ svmopa_za32_bf16_m(
1001
+ 0, query_predicate_f16x, predicate_all_f16x, probability_column_0_f32x,
1002
+ svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(v_block + ((dim_tile + 0) * 8 + 0) * 32)));
1003
+ svmopa_za32_bf16_m(
1004
+ 1, query_predicate_f16x, predicate_all_f16x, probability_column_0_f32x,
1005
+ svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(v_block + ((dim_tile + 1) * 8 + 0) * 32)));
1006
+ svmopa_za32_bf16_m(
1007
+ 2, query_predicate_f16x, predicate_all_f16x, probability_column_0_f32x,
1008
+ svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(v_block + ((dim_tile + 2) * 8 + 0) * 32)));
1009
+ svmopa_za32_bf16_m(
1010
+ 3, query_predicate_f16x, predicate_all_f16x, probability_column_0_f32x,
1011
+ svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(v_block + ((dim_tile + 3) * 8 + 0) * 32)));
1012
+ svmopa_za32_bf16_m(
1013
+ 0, query_predicate_f16x, predicate_all_f16x, probability_column_1_f32x,
1014
+ svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(v_block + ((dim_tile + 0) * 8 + 1) * 32)));
1015
+ svmopa_za32_bf16_m(
1016
+ 1, query_predicate_f16x, predicate_all_f16x, probability_column_1_f32x,
1017
+ svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(v_block + ((dim_tile + 1) * 8 + 1) * 32)));
1018
+ svmopa_za32_bf16_m(
1019
+ 2, query_predicate_f16x, predicate_all_f16x, probability_column_1_f32x,
1020
+ svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(v_block + ((dim_tile + 2) * 8 + 1) * 32)));
1021
+ svmopa_za32_bf16_m(
1022
+ 3, query_predicate_f16x, predicate_all_f16x, probability_column_1_f32x,
1023
+ svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(v_block + ((dim_tile + 3) * 8 + 1) * 32)));
1024
+ svmopa_za32_bf16_m(
1025
+ 0, query_predicate_f16x, predicate_all_f16x, probability_column_2_f32x,
1026
+ svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(v_block + ((dim_tile + 0) * 8 + 2) * 32)));
1027
+ svmopa_za32_bf16_m(
1028
+ 1, query_predicate_f16x, predicate_all_f16x, probability_column_2_f32x,
1029
+ svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(v_block + ((dim_tile + 1) * 8 + 2) * 32)));
1030
+ svmopa_za32_bf16_m(
1031
+ 2, query_predicate_f16x, predicate_all_f16x, probability_column_2_f32x,
1032
+ svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(v_block + ((dim_tile + 2) * 8 + 2) * 32)));
1033
+ svmopa_za32_bf16_m(
1034
+ 3, query_predicate_f16x, predicate_all_f16x, probability_column_2_f32x,
1035
+ svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(v_block + ((dim_tile + 3) * 8 + 2) * 32)));
1036
+ svmopa_za32_bf16_m(
1037
+ 0, query_predicate_f16x, predicate_all_f16x, probability_column_3_f32x,
1038
+ svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(v_block + ((dim_tile + 0) * 8 + 3) * 32)));
1039
+ svmopa_za32_bf16_m(
1040
+ 1, query_predicate_f16x, predicate_all_f16x, probability_column_3_f32x,
1041
+ svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(v_block + ((dim_tile + 1) * 8 + 3) * 32)));
1042
+ svmopa_za32_bf16_m(
1043
+ 2, query_predicate_f16x, predicate_all_f16x, probability_column_3_f32x,
1044
+ svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(v_block + ((dim_tile + 2) * 8 + 3) * 32)));
1045
+ svmopa_za32_bf16_m(
1046
+ 3, query_predicate_f16x, predicate_all_f16x, probability_column_3_f32x,
1047
+ svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(v_block + ((dim_tile + 3) * 8 + 3) * 32)));
1048
+ svmopa_za32_bf16_m(
1049
+ 0, query_predicate_f16x, predicate_all_f16x, probability_column_4_f32x,
1050
+ svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(v_block + ((dim_tile + 0) * 8 + 4) * 32)));
1051
+ svmopa_za32_bf16_m(
1052
+ 1, query_predicate_f16x, predicate_all_f16x, probability_column_4_f32x,
1053
+ svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(v_block + ((dim_tile + 1) * 8 + 4) * 32)));
1054
+ svmopa_za32_bf16_m(
1055
+ 2, query_predicate_f16x, predicate_all_f16x, probability_column_4_f32x,
1056
+ svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(v_block + ((dim_tile + 2) * 8 + 4) * 32)));
1057
+ svmopa_za32_bf16_m(
1058
+ 3, query_predicate_f16x, predicate_all_f16x, probability_column_4_f32x,
1059
+ svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(v_block + ((dim_tile + 3) * 8 + 4) * 32)));
1060
+ svmopa_za32_bf16_m(
1061
+ 0, query_predicate_f16x, predicate_all_f16x, probability_column_5_f32x,
1062
+ svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(v_block + ((dim_tile + 0) * 8 + 5) * 32)));
1063
+ svmopa_za32_bf16_m(
1064
+ 1, query_predicate_f16x, predicate_all_f16x, probability_column_5_f32x,
1065
+ svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(v_block + ((dim_tile + 1) * 8 + 5) * 32)));
1066
+ svmopa_za32_bf16_m(
1067
+ 2, query_predicate_f16x, predicate_all_f16x, probability_column_5_f32x,
1068
+ svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(v_block + ((dim_tile + 2) * 8 + 5) * 32)));
1069
+ svmopa_za32_bf16_m(
1070
+ 3, query_predicate_f16x, predicate_all_f16x, probability_column_5_f32x,
1071
+ svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(v_block + ((dim_tile + 3) * 8 + 5) * 32)));
1072
+ svmopa_za32_bf16_m(
1073
+ 0, query_predicate_f16x, predicate_all_f16x, probability_column_6_f32x,
1074
+ svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(v_block + ((dim_tile + 0) * 8 + 6) * 32)));
1075
+ svmopa_za32_bf16_m(
1076
+ 1, query_predicate_f16x, predicate_all_f16x, probability_column_6_f32x,
1077
+ svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(v_block + ((dim_tile + 1) * 8 + 6) * 32)));
1078
+ svmopa_za32_bf16_m(
1079
+ 2, query_predicate_f16x, predicate_all_f16x, probability_column_6_f32x,
1080
+ svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(v_block + ((dim_tile + 2) * 8 + 6) * 32)));
1081
+ svmopa_za32_bf16_m(
1082
+ 3, query_predicate_f16x, predicate_all_f16x, probability_column_6_f32x,
1083
+ svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(v_block + ((dim_tile + 3) * 8 + 6) * 32)));
1084
+ svmopa_za32_bf16_m(
1085
+ 0, query_predicate_f16x, predicate_all_f16x, probability_column_7_f32x,
1086
+ svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(v_block + ((dim_tile + 0) * 8 + 7) * 32)));
1087
+ svmopa_za32_bf16_m(
1088
+ 1, query_predicate_f16x, predicate_all_f16x, probability_column_7_f32x,
1089
+ svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(v_block + ((dim_tile + 1) * 8 + 7) * 32)));
1090
+ svmopa_za32_bf16_m(
1091
+ 2, query_predicate_f16x, predicate_all_f16x, probability_column_7_f32x,
1092
+ svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(v_block + ((dim_tile + 2) * 8 + 7) * 32)));
1093
+ svmopa_za32_bf16_m(
1094
+ 3, query_predicate_f16x, predicate_all_f16x, probability_column_7_f32x,
1095
+ svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(v_block + ((dim_tile + 3) * 8 + 7) * 32)));
1096
+ for (nk_size_t query_index = 0; query_index < valid_query_count; query_index++) {
1097
+ svst1_f32(
1098
+ predicate_all_f32x, output_accumulator + query_index * head_dim_padded + (dim_tile + 0) * 16,
1099
+ svadd_f32_x(predicate_all_f32x,
1100
+ svld1_f32(predicate_all_f32x,
1101
+ output_accumulator + query_index * head_dim_padded + (dim_tile + 0) * 16),
1102
+ svread_hor_za32_f32_m(svdup_f32(0), predicate_all_f32x, 0, query_index)));
1103
+ svst1_f32(
1104
+ predicate_all_f32x, output_accumulator + query_index * head_dim_padded + (dim_tile + 1) * 16,
1105
+ svadd_f32_x(predicate_all_f32x,
1106
+ svld1_f32(predicate_all_f32x,
1107
+ output_accumulator + query_index * head_dim_padded + (dim_tile + 1) * 16),
1108
+ svread_hor_za32_f32_m(svdup_f32(0), predicate_all_f32x, 1, query_index)));
1109
+ svst1_f32(
1110
+ predicate_all_f32x, output_accumulator + query_index * head_dim_padded + (dim_tile + 2) * 16,
1111
+ svadd_f32_x(predicate_all_f32x,
1112
+ svld1_f32(predicate_all_f32x,
1113
+ output_accumulator + query_index * head_dim_padded + (dim_tile + 2) * 16),
1114
+ svread_hor_za32_f32_m(svdup_f32(0), predicate_all_f32x, 2, query_index)));
1115
+ svst1_f32(
1116
+ predicate_all_f32x, output_accumulator + query_index * head_dim_padded + (dim_tile + 3) * 16,
1117
+ svadd_f32_x(predicate_all_f32x,
1118
+ svld1_f32(predicate_all_f32x,
1119
+ output_accumulator + query_index * head_dim_padded + (dim_tile + 3) * 16),
1120
+ svread_hor_za32_f32_m(svdup_f32(0), predicate_all_f32x, 3, query_index)));
1121
+ }
1122
+ }
1123
+ for (; dim_tile < dim_tile_count; dim_tile++) {
1124
+ svzero_mask_za(nk_sme_zero_za32_tile_0_);
1125
+ svmopa_za32_bf16_m(
1126
+ 0, query_predicate_f16x, predicate_all_f16x, probability_column_0_f32x,
1127
+ svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(v_block + (dim_tile * 8 + 0) * 32)));
1128
+ svmopa_za32_bf16_m(
1129
+ 0, query_predicate_f16x, predicate_all_f16x, probability_column_1_f32x,
1130
+ svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(v_block + (dim_tile * 8 + 1) * 32)));
1131
+ svmopa_za32_bf16_m(
1132
+ 0, query_predicate_f16x, predicate_all_f16x, probability_column_2_f32x,
1133
+ svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(v_block + (dim_tile * 8 + 2) * 32)));
1134
+ svmopa_za32_bf16_m(
1135
+ 0, query_predicate_f16x, predicate_all_f16x, probability_column_3_f32x,
1136
+ svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(v_block + (dim_tile * 8 + 3) * 32)));
1137
+ svmopa_za32_bf16_m(
1138
+ 0, query_predicate_f16x, predicate_all_f16x, probability_column_4_f32x,
1139
+ svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(v_block + (dim_tile * 8 + 4) * 32)));
1140
+ svmopa_za32_bf16_m(
1141
+ 0, query_predicate_f16x, predicate_all_f16x, probability_column_5_f32x,
1142
+ svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(v_block + (dim_tile * 8 + 5) * 32)));
1143
+ svmopa_za32_bf16_m(
1144
+ 0, query_predicate_f16x, predicate_all_f16x, probability_column_6_f32x,
1145
+ svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(v_block + (dim_tile * 8 + 6) * 32)));
1146
+ svmopa_za32_bf16_m(
1147
+ 0, query_predicate_f16x, predicate_all_f16x, probability_column_7_f32x,
1148
+ svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(v_block + (dim_tile * 8 + 7) * 32)));
1149
+ for (nk_size_t query_index = 0; query_index < valid_query_count; query_index++)
1150
+ svst1_f32(predicate_all_f32x, output_accumulator + query_index * head_dim_padded + dim_tile * 16,
1151
+ svadd_f32_x(predicate_all_f32x,
1152
+ svld1_f32(predicate_all_f32x,
1153
+ output_accumulator + query_index * head_dim_padded + dim_tile * 16),
1154
+ svread_hor_za32_f32_m(svdup_f32(0), predicate_all_f32x, 0, query_index)));
1155
+ }
1156
+ }
1157
+ }
1158
+
1159
+ // Final normalization
1160
+ NK_ALIGN64 nk_f32_t final_sums[16];
1161
+ svst1_f32(predicate_all_f32x, final_sums, row_sum_f32x);
1162
+
1163
+ for (nk_size_t query_index = 0; query_index < valid_query_count; query_index++) {
1164
+ nk_f32_t inv_sum = (final_sums[query_index] > 0.0f) ? (1.0f / final_sums[query_index]) : 0.0f;
1165
+ svfloat32_t inv_sum_f32x = svdup_f32(inv_sum);
1166
+
1167
+ for (nk_size_t dim_offset = 0; dim_offset < head_dim; dim_offset += svcntw()) {
1168
+ svbool_t predicate_f32x = svwhilelt_b32_u64(dim_offset, head_dim);
1169
+ svfloat32_t output_f32x = svmul_f32_x(
1170
+ predicate_f32x,
1171
+ svld1_f32(predicate_f32x, output_accumulator + query_index * head_dim_padded + dim_offset),
1172
+ inv_sum_f32x);
1173
+ svbfloat16_t output_bf16x = nk_f32_to_bf16_sve_(predicate_f32x, output_f32x);
1174
+ nk_size_t store_count = (head_dim - dim_offset) < (nk_size_t)svcntw() ? (head_dim - dim_offset)
1175
+ : (nk_size_t)svcntw();
1176
+ svbool_t store_predicate_f16x = svwhilelt_b16_u64(0u, store_count);
1177
+ svst1_bf16(store_predicate_f16x, (bfloat16_t *)(output + query_index * head_dim + dim_offset),
1178
+ output_bf16x);
1179
+ }
1180
+ }
1181
+ }
1182
+
1183
+ NK_PUBLIC void nk_attention_bf16_sme(nk_bf16_t const *q, void const *kv_packed, nk_bf16_t *output, nk_size_t num_heads,
1184
+ nk_size_t num_kv_heads, nk_size_t query_len, nk_size_t kv_len, nk_size_t head_dim,
1185
+ nk_f32_t scale) {
1186
+
1187
+ nk_attention_sme_packed_header_t const *header = (nk_attention_sme_packed_header_t const *)kv_packed;
1188
+ nk_size_t head_dim_padded = header->head_dim_padded;
1189
+ nk_size_t dim_tile_count = header->reserved[0]; // v_dim_tile_count
1190
+ nk_size_t kv_blocks = (kv_len + 15) / 16;
1191
+ nk_size_t kv_head_stride = kv_blocks * 16 * head_dim_padded;
1192
+
1193
+ nk_bf16_t const *k_packed = (nk_bf16_t const *)((char const *)kv_packed + header->k_offset);
1194
+ nk_bf16_t const *v_packed = (nk_bf16_t const *)((char const *)kv_packed + header->v_offset);
1195
+
1196
+ nk_size_t group_size = (num_kv_heads > 0) ? num_heads / num_kv_heads : 1;
1197
+
1198
+ for (nk_size_t q_head = 0; q_head < num_heads; q_head++) {
1199
+ nk_size_t kv_head = q_head / group_size;
1200
+
1201
+ nk_bf16_t const *q_ptr = q + q_head * query_len * head_dim;
1202
+ nk_bf16_t const *k_ptr = k_packed + kv_head * kv_head_stride;
1203
+ nk_bf16_t const *v_ptr = v_packed + kv_head * kv_head_stride;
1204
+ nk_bf16_t *out_ptr = output + q_head * query_len * head_dim;
1205
+
1206
+ for (nk_size_t q_start = 0; q_start < query_len; q_start += 16) {
1207
+ nk_size_t q_block_len = (q_start + 16 < query_len) ? 16 : (query_len - q_start);
1208
+
1209
+ nk_attention_bf16_sme_streaming_(q_ptr + q_start * head_dim, k_ptr, v_ptr, out_ptr + q_start * head_dim,
1210
+ q_block_len, kv_len, head_dim, head_dim_padded, dim_tile_count, scale);
1211
+ }
1212
+ }
1213
+ }
1214
+
1215
+ __arm_locally_streaming __arm_new("za") static void nk_attention_f16_sme_streaming_(
1216
+ nk_f16_t const *q, // [query_len, head_dim]
1217
+ nk_f16_t const *k, // [kv_len, head_dim_padded] FMOPA-interleaved
1218
+ nk_f16_t const *v_packed, // FMOPA-interleaved V for this KV head
1219
+ nk_f16_t *output, // [query_len, head_dim]
1220
+ nk_size_t query_len, nk_size_t kv_len, nk_size_t head_dim, nk_size_t head_dim_padded, nk_size_t dim_tile_count,
1221
+ nk_f32_t scale) {
1222
+
1223
+ svbool_t const predicate_all_f32x = svptrue_b32();
1224
+ svbool_t const predicate_all_f16x = svptrue_b16();
1225
+ nk_size_t const valid_query_count = (query_len < 16) ? query_len : 16;
1226
+
1227
+ NK_ALIGN64 nk_f32_t row_max[16];
1228
+ NK_ALIGN64 nk_f32_t row_sum[16];
1229
+ NK_ALIGN64 nk_f32_t output_accumulator[16 * 256];
1230
+
1231
+ svst1_f32(predicate_all_f32x, row_max, svdup_f32(NK_F32_MIN));
1232
+ svst1_f32(predicate_all_f32x, row_sum, svdup_f32(0.0f));
1233
+ svfloat32_t zero_f32x = svdup_f32(0.0f);
1234
+ for (nk_size_t i = 0; i < 16 * head_dim_padded; i += svcntw()) {
1235
+ svst1_f32(predicate_all_f32x, output_accumulator + i, zero_f32x);
1236
+ }
1237
+
1238
+ nk_size_t kv_block_index = 0;
1239
+ nk_size_t kv_start = 0;
1240
+ svbool_t const batch_predicate_f32x = svwhilelt_b32(0u, 16u);
1241
+
1242
+ nk_size_t const k_depth_step_count = head_dim_padded / 2;
1243
+
1244
+ // Pre-transpose Q once: queries_transposed[step][16 f32 words]
1245
+ // queries_transposed[step] reinterpret-as-f16 = {Q[0][2s], Q[0][2s+1], Q[1][2s], Q[1][2s+1], ...}
1246
+ // This is the same interleaving ZA0 vertical reads would produce.
1247
+ NK_ALIGN64 nk_f32_t queries_transposed[128 * 16]; // max head_dim_padded/2 * 16 = 128 * 16
1248
+ for (nk_size_t batch = 0; batch < head_dim_padded / 32; batch++) {
1249
+ svzero_mask_za(nk_sme_zero_za32_tile_0_);
1250
+ for (nk_size_t query_index = 0; query_index < valid_query_count; query_index++)
1251
+ svld1_hor_za32(0, query_index, batch_predicate_f32x,
1252
+ (nk_f32_t const *)(q + query_index * head_dim + batch * 32));
1253
+ for (nk_size_t step = 0; step < 16; step++)
1254
+ svst1_f32(predicate_all_f32x, queries_transposed + (batch * 16 + step) * 16,
1255
+ svread_ver_za32_f32_m(svdup_f32(0), predicate_all_f32x, 0, step));
1256
+ }
1257
+
1258
+ // === Bc=32 main loop (prefill only, skipped for decode) ===
1259
+ if (valid_query_count > 1) {
1260
+ for (; kv_start + 32 <= kv_len; kv_start += 32, kv_block_index += 2) {
1261
+ // Q×K^T: pure memory→FMOPA, no ZA staging for Q or K
1262
+ svzero_mask_za(nk_sme_zero_za32_tile_2_);
1263
+ svzero_mask_za(nk_sme_zero_za32_tile_3_);
1264
+ nk_f16_t const *keys_block_lower = k + kv_block_index * k_depth_step_count * 32;
1265
+ nk_f16_t const *keys_block_upper = k + (kv_block_index + 1) * k_depth_step_count * 32;
1266
+ for (nk_size_t step = 0; step < k_depth_step_count; step++) {
1267
+ svfloat16_t zn = svreinterpret_f16_f32(svld1_f32(predicate_all_f32x, queries_transposed + step * 16));
1268
+ svfloat16_t zm0 = svld1_f16(predicate_all_f16x, (float16_t const *)(keys_block_lower + step * 32));
1269
+ svfloat16_t zm1 = svld1_f16(predicate_all_f16x, (float16_t const *)(keys_block_upper + step * 32));
1270
+ svmopa_za32_f16_m(2, predicate_all_f32x, predicate_all_f32x, zn, zm0);
1271
+ svmopa_za32_f16_m(3, predicate_all_f32x, predicate_all_f32x, zn, zm1);
1272
+ }
1273
+ // ZA2 = scores[query_index][0:15], ZA3 = scores[query_index][16:31]
1274
+
1275
+ // Pass 1: Column-wise max (read ZA2/ZA3 columns vertically)
1276
+ svfloat32_t scale_f32x = svdup_f32(scale);
1277
+ svfloat32_t block_max_f32x = svdup_f32(NK_F32_MIN);
1278
+ for (nk_size_t column_index = 0; column_index < 16; column_index++) {
1279
+ svfloat32_t score_column_f32x = svmul_f32_x(
1280
+ predicate_all_f32x, svread_ver_za32_f32_m(svdup_f32(0), predicate_all_f32x, 2, column_index),
1281
+ scale_f32x);
1282
+ block_max_f32x = svmax_f32_x(predicate_all_f32x, block_max_f32x, score_column_f32x);
1283
+ }
1284
+ for (nk_size_t column_index = 0; column_index < 16; column_index++) {
1285
+ svfloat32_t score_column_f32x = svmul_f32_x(
1286
+ predicate_all_f32x, svread_ver_za32_f32_m(svdup_f32(0), predicate_all_f32x, 3, column_index),
1287
+ scale_f32x);
1288
+ block_max_f32x = svmax_f32_x(predicate_all_f32x, block_max_f32x, score_column_f32x);
1289
+ }
1290
+
1291
+ // Softmax correction (vectorized via array load/store)
1292
+ svfloat32_t old_max_f32x = svld1_f32(predicate_all_f32x, row_max);
1293
+ svfloat32_t new_max_f32x = svmax_f32_x(predicate_all_f32x, old_max_f32x, block_max_f32x);
1294
+ svfloat32_t correction_f32x = nk_exp_fast_f32_sve_(
1295
+ predicate_all_f32x, svsub_f32_x(predicate_all_f32x, old_max_f32x, new_max_f32x));
1296
+ svbool_t max_changed = svcmplt_f32(predicate_all_f32x, correction_f32x, svdup_f32(1.0f));
1297
+ nk_u32_t max_was_updated = svptest_any(predicate_all_f32x, max_changed) ? 1 : 0;
1298
+ svfloat32_t row_sum_corrected_f32x = svld1_f32(predicate_all_f32x, row_sum);
1299
+ if (max_was_updated)
1300
+ row_sum_corrected_f32x = svmul_f32_x(predicate_all_f32x, row_sum_corrected_f32x, correction_f32x);
1301
+ NK_ALIGN64 nk_f32_t corrections[16];
1302
+ svst1_f32(predicate_all_f32x, corrections, correction_f32x);
1303
+
1304
+ // Pass 2: Column-wise exp + fused P write + sum
1305
+ svfloat32_t sum_delta_f32x = svdup_f32(0.0f);
1306
+ svzero_mask_za(nk_sme_zero_za32_tile_0_);
1307
+ // ZA2 columns in pairs -> ZA0 columns 0-7
1308
+ for (nk_size_t column_index = 0; column_index < 16; column_index += 2) {
1309
+ svfloat32_t score_even_f32x = svmul_f32_x(
1310
+ predicate_all_f32x, svread_ver_za32_f32_m(svdup_f32(0), predicate_all_f32x, 2, column_index),
1311
+ scale_f32x);
1312
+ svfloat32_t score_odd_f32x = svmul_f32_x(
1313
+ predicate_all_f32x, svread_ver_za32_f32_m(svdup_f32(0), predicate_all_f32x, 2, column_index + 1),
1314
+ scale_f32x);
1315
+ svfloat32_t weight_even_f32x = nk_exp_fast_f32_sve_(
1316
+ predicate_all_f32x, svsub_f32_x(predicate_all_f32x, score_even_f32x, new_max_f32x));
1317
+ svfloat32_t weight_odd_f32x = nk_exp_fast_f32_sve_(
1318
+ predicate_all_f32x, svsub_f32_x(predicate_all_f32x, score_odd_f32x, new_max_f32x));
1319
+ sum_delta_f32x = svadd_f32_x(predicate_all_f32x, sum_delta_f32x, weight_even_f32x);
1320
+ sum_delta_f32x = svadd_f32_x(predicate_all_f32x, sum_delta_f32x, weight_odd_f32x);
1321
+ svfloat16_t weight_pair_f16x = svzip1_f16(svcvt_f16_f32_x(predicate_all_f32x, weight_even_f32x),
1322
+ svcvt_f16_f32_x(predicate_all_f32x, weight_odd_f32x));
1323
+ svwrite_ver_za32_f32_m(0, column_index / 2, predicate_all_f32x,
1324
+ svreinterpret_f32_f16(weight_pair_f16x));
1325
+ }
1326
+ // ZA3 columns in pairs -> ZA0 columns 8-15
1327
+ for (nk_size_t column_index = 0; column_index < 16; column_index += 2) {
1328
+ svfloat32_t score_even_f32x = svmul_f32_x(
1329
+ predicate_all_f32x, svread_ver_za32_f32_m(svdup_f32(0), predicate_all_f32x, 3, column_index),
1330
+ scale_f32x);
1331
+ svfloat32_t score_odd_f32x = svmul_f32_x(
1332
+ predicate_all_f32x, svread_ver_za32_f32_m(svdup_f32(0), predicate_all_f32x, 3, column_index + 1),
1333
+ scale_f32x);
1334
+ svfloat32_t weight_even_f32x = nk_exp_fast_f32_sve_(
1335
+ predicate_all_f32x, svsub_f32_x(predicate_all_f32x, score_even_f32x, new_max_f32x));
1336
+ svfloat32_t weight_odd_f32x = nk_exp_fast_f32_sve_(
1337
+ predicate_all_f32x, svsub_f32_x(predicate_all_f32x, score_odd_f32x, new_max_f32x));
1338
+ sum_delta_f32x = svadd_f32_x(predicate_all_f32x, sum_delta_f32x, weight_even_f32x);
1339
+ sum_delta_f32x = svadd_f32_x(predicate_all_f32x, sum_delta_f32x, weight_odd_f32x);
1340
+ svfloat16_t weight_pair_f16x = svzip1_f16(svcvt_f16_f32_x(predicate_all_f32x, weight_even_f32x),
1341
+ svcvt_f16_f32_x(predicate_all_f32x, weight_odd_f32x));
1342
+ svwrite_ver_za32_f32_m(0, 8 + column_index / 2, predicate_all_f32x,
1343
+ svreinterpret_f32_f16(weight_pair_f16x));
1344
+ }
1345
+ row_sum_corrected_f32x = svadd_f32_x(predicate_all_f32x, row_sum_corrected_f32x, sum_delta_f32x);
1346
+ svst1_f32(predicate_all_f32x, row_sum, row_sum_corrected_f32x);
1347
+ svst1_f32(predicate_all_f32x, row_max, new_max_f32x);
1348
+
1349
+ // Extract P columns from ZA0
1350
+ svfloat16_t probability_column_0_f32x = svreinterpret_f16_f32(
1351
+ svread_ver_za32_f32_m(svdup_f32(0), predicate_all_f32x, 0, 0));
1352
+ svfloat16_t probability_column_1_f32x = svreinterpret_f16_f32(
1353
+ svread_ver_za32_f32_m(svdup_f32(0), predicate_all_f32x, 0, 1));
1354
+ svfloat16_t probability_column_2_f32x = svreinterpret_f16_f32(
1355
+ svread_ver_za32_f32_m(svdup_f32(0), predicate_all_f32x, 0, 2));
1356
+ svfloat16_t probability_column_3_f32x = svreinterpret_f16_f32(
1357
+ svread_ver_za32_f32_m(svdup_f32(0), predicate_all_f32x, 0, 3));
1358
+ svfloat16_t probability_column_4_f32x = svreinterpret_f16_f32(
1359
+ svread_ver_za32_f32_m(svdup_f32(0), predicate_all_f32x, 0, 4));
1360
+ svfloat16_t probability_column_5_f32x = svreinterpret_f16_f32(
1361
+ svread_ver_za32_f32_m(svdup_f32(0), predicate_all_f32x, 0, 5));
1362
+ svfloat16_t probability_column_6_f32x = svreinterpret_f16_f32(
1363
+ svread_ver_za32_f32_m(svdup_f32(0), predicate_all_f32x, 0, 6));
1364
+ svfloat16_t probability_column_7_f32x = svreinterpret_f16_f32(
1365
+ svread_ver_za32_f32_m(svdup_f32(0), predicate_all_f32x, 0, 7));
1366
+ svfloat16_t probability_column_8_f32x = svreinterpret_f16_f32(
1367
+ svread_ver_za32_f32_m(svdup_f32(0), predicate_all_f32x, 0, 8));
1368
+ svfloat16_t probability_column_9_f32x = svreinterpret_f16_f32(
1369
+ svread_ver_za32_f32_m(svdup_f32(0), predicate_all_f32x, 0, 9));
1370
+ svfloat16_t probability_column_10_f32x = svreinterpret_f16_f32(
1371
+ svread_ver_za32_f32_m(svdup_f32(0), predicate_all_f32x, 0, 10));
1372
+ svfloat16_t probability_column_11_f32x = svreinterpret_f16_f32(
1373
+ svread_ver_za32_f32_m(svdup_f32(0), predicate_all_f32x, 0, 11));
1374
+ svfloat16_t probability_column_12_f32x = svreinterpret_f16_f32(
1375
+ svread_ver_za32_f32_m(svdup_f32(0), predicate_all_f32x, 0, 12));
1376
+ svfloat16_t probability_column_13_f32x = svreinterpret_f16_f32(
1377
+ svread_ver_za32_f32_m(svdup_f32(0), predicate_all_f32x, 0, 13));
1378
+ svfloat16_t probability_column_14_f32x = svreinterpret_f16_f32(
1379
+ svread_ver_za32_f32_m(svdup_f32(0), predicate_all_f32x, 0, 14));
1380
+ svfloat16_t probability_column_15_f32x = svreinterpret_f16_f32(
1381
+ svread_ver_za32_f32_m(svdup_f32(0), predicate_all_f32x, 0, 15));
1382
+
1383
+ // Pre-apply correction once before P×V
1384
+ svbool_t query_predicate_f16x = svwhilelt_b16_u64(0u, valid_query_count * 2);
1385
+ nk_f16_t const *values_block_lower = v_packed + kv_block_index * dim_tile_count * 8 * 32;
1386
+ nk_f16_t const *values_block_upper = v_packed + (kv_block_index + 1) * dim_tile_count * 8 * 32;
1387
+
1388
+ if (max_was_updated) {
1389
+ for (nk_size_t query_index = 0; query_index < valid_query_count; query_index++) {
1390
+ svfloat32_t correction_scalar_f32x = svdup_f32(corrections[query_index]);
1391
+ for (nk_size_t dim_offset = 0; dim_offset < head_dim_padded; dim_offset += 16)
1392
+ svst1_f32(
1393
+ predicate_all_f32x, output_accumulator + query_index * head_dim_padded + dim_offset,
1394
+ svmul_f32_x(predicate_all_f32x,
1395
+ svld1_f32(predicate_all_f32x,
1396
+ output_accumulator + query_index * head_dim_padded + dim_offset),
1397
+ correction_scalar_f32x));
1398
+ }
1399
+ }
1400
+
1401
+ // P×V: zero -> FMOPA -> read -> add (no ZA writes for output_accumulator)
1402
+ nk_size_t dim_tile = 0;
1403
+ for (; dim_tile + 4 <= dim_tile_count; dim_tile += 4) {
1404
+ svzero_za();
1405
+ // Block0: 8 depth steps (KV positions 0-15)
1406
+ svmopa_za32_f16_m(0, query_predicate_f16x, predicate_all_f16x, probability_column_0_f32x,
1407
+ svld1_f16(predicate_all_f16x,
1408
+ (float16_t const *)(values_block_lower + ((dim_tile + 0) * 8 + 0) * 32)));
1409
+ svmopa_za32_f16_m(1, query_predicate_f16x, predicate_all_f16x, probability_column_0_f32x,
1410
+ svld1_f16(predicate_all_f16x,
1411
+ (float16_t const *)(values_block_lower + ((dim_tile + 1) * 8 + 0) * 32)));
1412
+ svmopa_za32_f16_m(2, query_predicate_f16x, predicate_all_f16x, probability_column_0_f32x,
1413
+ svld1_f16(predicate_all_f16x,
1414
+ (float16_t const *)(values_block_lower + ((dim_tile + 2) * 8 + 0) * 32)));
1415
+ svmopa_za32_f16_m(3, query_predicate_f16x, predicate_all_f16x, probability_column_0_f32x,
1416
+ svld1_f16(predicate_all_f16x,
1417
+ (float16_t const *)(values_block_lower + ((dim_tile + 3) * 8 + 0) * 32)));
1418
+ svmopa_za32_f16_m(0, query_predicate_f16x, predicate_all_f16x, probability_column_1_f32x,
1419
+ svld1_f16(predicate_all_f16x,
1420
+ (float16_t const *)(values_block_lower + ((dim_tile + 0) * 8 + 1) * 32)));
1421
+ svmopa_za32_f16_m(1, query_predicate_f16x, predicate_all_f16x, probability_column_1_f32x,
1422
+ svld1_f16(predicate_all_f16x,
1423
+ (float16_t const *)(values_block_lower + ((dim_tile + 1) * 8 + 1) * 32)));
1424
+ svmopa_za32_f16_m(2, query_predicate_f16x, predicate_all_f16x, probability_column_1_f32x,
1425
+ svld1_f16(predicate_all_f16x,
1426
+ (float16_t const *)(values_block_lower + ((dim_tile + 2) * 8 + 1) * 32)));
1427
+ svmopa_za32_f16_m(3, query_predicate_f16x, predicate_all_f16x, probability_column_1_f32x,
1428
+ svld1_f16(predicate_all_f16x,
1429
+ (float16_t const *)(values_block_lower + ((dim_tile + 3) * 8 + 1) * 32)));
1430
+ svmopa_za32_f16_m(0, query_predicate_f16x, predicate_all_f16x, probability_column_2_f32x,
1431
+ svld1_f16(predicate_all_f16x,
1432
+ (float16_t const *)(values_block_lower + ((dim_tile + 0) * 8 + 2) * 32)));
1433
+ svmopa_za32_f16_m(1, query_predicate_f16x, predicate_all_f16x, probability_column_2_f32x,
1434
+ svld1_f16(predicate_all_f16x,
1435
+ (float16_t const *)(values_block_lower + ((dim_tile + 1) * 8 + 2) * 32)));
1436
+ svmopa_za32_f16_m(2, query_predicate_f16x, predicate_all_f16x, probability_column_2_f32x,
1437
+ svld1_f16(predicate_all_f16x,
1438
+ (float16_t const *)(values_block_lower + ((dim_tile + 2) * 8 + 2) * 32)));
1439
+ svmopa_za32_f16_m(3, query_predicate_f16x, predicate_all_f16x, probability_column_2_f32x,
1440
+ svld1_f16(predicate_all_f16x,
1441
+ (float16_t const *)(values_block_lower + ((dim_tile + 3) * 8 + 2) * 32)));
1442
+ svmopa_za32_f16_m(0, query_predicate_f16x, predicate_all_f16x, probability_column_3_f32x,
1443
+ svld1_f16(predicate_all_f16x,
1444
+ (float16_t const *)(values_block_lower + ((dim_tile + 0) * 8 + 3) * 32)));
1445
+ svmopa_za32_f16_m(1, query_predicate_f16x, predicate_all_f16x, probability_column_3_f32x,
1446
+ svld1_f16(predicate_all_f16x,
1447
+ (float16_t const *)(values_block_lower + ((dim_tile + 1) * 8 + 3) * 32)));
1448
+ svmopa_za32_f16_m(2, query_predicate_f16x, predicate_all_f16x, probability_column_3_f32x,
1449
+ svld1_f16(predicate_all_f16x,
1450
+ (float16_t const *)(values_block_lower + ((dim_tile + 2) * 8 + 3) * 32)));
1451
+ svmopa_za32_f16_m(3, query_predicate_f16x, predicate_all_f16x, probability_column_3_f32x,
1452
+ svld1_f16(predicate_all_f16x,
1453
+ (float16_t const *)(values_block_lower + ((dim_tile + 3) * 8 + 3) * 32)));
1454
+ svmopa_za32_f16_m(0, query_predicate_f16x, predicate_all_f16x, probability_column_4_f32x,
1455
+ svld1_f16(predicate_all_f16x,
1456
+ (float16_t const *)(values_block_lower + ((dim_tile + 0) * 8 + 4) * 32)));
1457
+ svmopa_za32_f16_m(1, query_predicate_f16x, predicate_all_f16x, probability_column_4_f32x,
1458
+ svld1_f16(predicate_all_f16x,
1459
+ (float16_t const *)(values_block_lower + ((dim_tile + 1) * 8 + 4) * 32)));
1460
+ svmopa_za32_f16_m(2, query_predicate_f16x, predicate_all_f16x, probability_column_4_f32x,
1461
+ svld1_f16(predicate_all_f16x,
1462
+ (float16_t const *)(values_block_lower + ((dim_tile + 2) * 8 + 4) * 32)));
1463
+ svmopa_za32_f16_m(3, query_predicate_f16x, predicate_all_f16x, probability_column_4_f32x,
1464
+ svld1_f16(predicate_all_f16x,
1465
+ (float16_t const *)(values_block_lower + ((dim_tile + 3) * 8 + 4) * 32)));
1466
+ svmopa_za32_f16_m(0, query_predicate_f16x, predicate_all_f16x, probability_column_5_f32x,
1467
+ svld1_f16(predicate_all_f16x,
1468
+ (float16_t const *)(values_block_lower + ((dim_tile + 0) * 8 + 5) * 32)));
1469
+ svmopa_za32_f16_m(1, query_predicate_f16x, predicate_all_f16x, probability_column_5_f32x,
1470
+ svld1_f16(predicate_all_f16x,
1471
+ (float16_t const *)(values_block_lower + ((dim_tile + 1) * 8 + 5) * 32)));
1472
+ svmopa_za32_f16_m(2, query_predicate_f16x, predicate_all_f16x, probability_column_5_f32x,
1473
+ svld1_f16(predicate_all_f16x,
1474
+ (float16_t const *)(values_block_lower + ((dim_tile + 2) * 8 + 5) * 32)));
1475
+ svmopa_za32_f16_m(3, query_predicate_f16x, predicate_all_f16x, probability_column_5_f32x,
1476
+ svld1_f16(predicate_all_f16x,
1477
+ (float16_t const *)(values_block_lower + ((dim_tile + 3) * 8 + 5) * 32)));
1478
+ svmopa_za32_f16_m(0, query_predicate_f16x, predicate_all_f16x, probability_column_6_f32x,
1479
+ svld1_f16(predicate_all_f16x,
1480
+ (float16_t const *)(values_block_lower + ((dim_tile + 0) * 8 + 6) * 32)));
1481
+ svmopa_za32_f16_m(1, query_predicate_f16x, predicate_all_f16x, probability_column_6_f32x,
1482
+ svld1_f16(predicate_all_f16x,
1483
+ (float16_t const *)(values_block_lower + ((dim_tile + 1) * 8 + 6) * 32)));
1484
+ svmopa_za32_f16_m(2, query_predicate_f16x, predicate_all_f16x, probability_column_6_f32x,
1485
+ svld1_f16(predicate_all_f16x,
1486
+ (float16_t const *)(values_block_lower + ((dim_tile + 2) * 8 + 6) * 32)));
1487
+ svmopa_za32_f16_m(3, query_predicate_f16x, predicate_all_f16x, probability_column_6_f32x,
1488
+ svld1_f16(predicate_all_f16x,
1489
+ (float16_t const *)(values_block_lower + ((dim_tile + 3) * 8 + 6) * 32)));
1490
+ svmopa_za32_f16_m(0, query_predicate_f16x, predicate_all_f16x, probability_column_7_f32x,
1491
+ svld1_f16(predicate_all_f16x,
1492
+ (float16_t const *)(values_block_lower + ((dim_tile + 0) * 8 + 7) * 32)));
1493
+ svmopa_za32_f16_m(1, query_predicate_f16x, predicate_all_f16x, probability_column_7_f32x,
1494
+ svld1_f16(predicate_all_f16x,
1495
+ (float16_t const *)(values_block_lower + ((dim_tile + 1) * 8 + 7) * 32)));
1496
+ svmopa_za32_f16_m(2, query_predicate_f16x, predicate_all_f16x, probability_column_7_f32x,
1497
+ svld1_f16(predicate_all_f16x,
1498
+ (float16_t const *)(values_block_lower + ((dim_tile + 2) * 8 + 7) * 32)));
1499
+ svmopa_za32_f16_m(3, query_predicate_f16x, predicate_all_f16x, probability_column_7_f32x,
1500
+ svld1_f16(predicate_all_f16x,
1501
+ (float16_t const *)(values_block_lower + ((dim_tile + 3) * 8 + 7) * 32)));
1502
+ // Block1: 8 depth steps (KV positions 16-31)
1503
+ svmopa_za32_f16_m(0, query_predicate_f16x, predicate_all_f16x, probability_column_8_f32x,
1504
+ svld1_f16(predicate_all_f16x,
1505
+ (float16_t const *)(values_block_upper + ((dim_tile + 0) * 8 + 0) * 32)));
1506
+ svmopa_za32_f16_m(1, query_predicate_f16x, predicate_all_f16x, probability_column_8_f32x,
1507
+ svld1_f16(predicate_all_f16x,
1508
+ (float16_t const *)(values_block_upper + ((dim_tile + 1) * 8 + 0) * 32)));
1509
+ svmopa_za32_f16_m(2, query_predicate_f16x, predicate_all_f16x, probability_column_8_f32x,
1510
+ svld1_f16(predicate_all_f16x,
1511
+ (float16_t const *)(values_block_upper + ((dim_tile + 2) * 8 + 0) * 32)));
1512
+ svmopa_za32_f16_m(3, query_predicate_f16x, predicate_all_f16x, probability_column_8_f32x,
1513
+ svld1_f16(predicate_all_f16x,
1514
+ (float16_t const *)(values_block_upper + ((dim_tile + 3) * 8 + 0) * 32)));
1515
+ svmopa_za32_f16_m(0, query_predicate_f16x, predicate_all_f16x, probability_column_9_f32x,
1516
+ svld1_f16(predicate_all_f16x,
1517
+ (float16_t const *)(values_block_upper + ((dim_tile + 0) * 8 + 1) * 32)));
1518
+ svmopa_za32_f16_m(1, query_predicate_f16x, predicate_all_f16x, probability_column_9_f32x,
1519
+ svld1_f16(predicate_all_f16x,
1520
+ (float16_t const *)(values_block_upper + ((dim_tile + 1) * 8 + 1) * 32)));
1521
+ svmopa_za32_f16_m(2, query_predicate_f16x, predicate_all_f16x, probability_column_9_f32x,
1522
+ svld1_f16(predicate_all_f16x,
1523
+ (float16_t const *)(values_block_upper + ((dim_tile + 2) * 8 + 1) * 32)));
1524
+ svmopa_za32_f16_m(3, query_predicate_f16x, predicate_all_f16x, probability_column_9_f32x,
1525
+ svld1_f16(predicate_all_f16x,
1526
+ (float16_t const *)(values_block_upper + ((dim_tile + 3) * 8 + 1) * 32)));
1527
+ svmopa_za32_f16_m(0, query_predicate_f16x, predicate_all_f16x, probability_column_10_f32x,
1528
+ svld1_f16(predicate_all_f16x,
1529
+ (float16_t const *)(values_block_upper + ((dim_tile + 0) * 8 + 2) * 32)));
1530
+ svmopa_za32_f16_m(1, query_predicate_f16x, predicate_all_f16x, probability_column_10_f32x,
1531
+ svld1_f16(predicate_all_f16x,
1532
+ (float16_t const *)(values_block_upper + ((dim_tile + 1) * 8 + 2) * 32)));
1533
+ svmopa_za32_f16_m(2, query_predicate_f16x, predicate_all_f16x, probability_column_10_f32x,
1534
+ svld1_f16(predicate_all_f16x,
1535
+ (float16_t const *)(values_block_upper + ((dim_tile + 2) * 8 + 2) * 32)));
1536
+ svmopa_za32_f16_m(3, query_predicate_f16x, predicate_all_f16x, probability_column_10_f32x,
1537
+ svld1_f16(predicate_all_f16x,
1538
+ (float16_t const *)(values_block_upper + ((dim_tile + 3) * 8 + 2) * 32)));
1539
+ svmopa_za32_f16_m(0, query_predicate_f16x, predicate_all_f16x, probability_column_11_f32x,
1540
+ svld1_f16(predicate_all_f16x,
1541
+ (float16_t const *)(values_block_upper + ((dim_tile + 0) * 8 + 3) * 32)));
1542
+ svmopa_za32_f16_m(1, query_predicate_f16x, predicate_all_f16x, probability_column_11_f32x,
1543
+ svld1_f16(predicate_all_f16x,
1544
+ (float16_t const *)(values_block_upper + ((dim_tile + 1) * 8 + 3) * 32)));
1545
+ svmopa_za32_f16_m(2, query_predicate_f16x, predicate_all_f16x, probability_column_11_f32x,
1546
+ svld1_f16(predicate_all_f16x,
1547
+ (float16_t const *)(values_block_upper + ((dim_tile + 2) * 8 + 3) * 32)));
1548
+ svmopa_za32_f16_m(3, query_predicate_f16x, predicate_all_f16x, probability_column_11_f32x,
1549
+ svld1_f16(predicate_all_f16x,
1550
+ (float16_t const *)(values_block_upper + ((dim_tile + 3) * 8 + 3) * 32)));
1551
+ svmopa_za32_f16_m(0, query_predicate_f16x, predicate_all_f16x, probability_column_12_f32x,
1552
+ svld1_f16(predicate_all_f16x,
1553
+ (float16_t const *)(values_block_upper + ((dim_tile + 0) * 8 + 4) * 32)));
1554
+ svmopa_za32_f16_m(1, query_predicate_f16x, predicate_all_f16x, probability_column_12_f32x,
1555
+ svld1_f16(predicate_all_f16x,
1556
+ (float16_t const *)(values_block_upper + ((dim_tile + 1) * 8 + 4) * 32)));
1557
+ svmopa_za32_f16_m(2, query_predicate_f16x, predicate_all_f16x, probability_column_12_f32x,
1558
+ svld1_f16(predicate_all_f16x,
1559
+ (float16_t const *)(values_block_upper + ((dim_tile + 2) * 8 + 4) * 32)));
1560
+ svmopa_za32_f16_m(3, query_predicate_f16x, predicate_all_f16x, probability_column_12_f32x,
1561
+ svld1_f16(predicate_all_f16x,
1562
+ (float16_t const *)(values_block_upper + ((dim_tile + 3) * 8 + 4) * 32)));
1563
+ svmopa_za32_f16_m(0, query_predicate_f16x, predicate_all_f16x, probability_column_13_f32x,
1564
+ svld1_f16(predicate_all_f16x,
1565
+ (float16_t const *)(values_block_upper + ((dim_tile + 0) * 8 + 5) * 32)));
1566
+ svmopa_za32_f16_m(1, query_predicate_f16x, predicate_all_f16x, probability_column_13_f32x,
1567
+ svld1_f16(predicate_all_f16x,
1568
+ (float16_t const *)(values_block_upper + ((dim_tile + 1) * 8 + 5) * 32)));
1569
+ svmopa_za32_f16_m(2, query_predicate_f16x, predicate_all_f16x, probability_column_13_f32x,
1570
+ svld1_f16(predicate_all_f16x,
1571
+ (float16_t const *)(values_block_upper + ((dim_tile + 2) * 8 + 5) * 32)));
1572
+ svmopa_za32_f16_m(3, query_predicate_f16x, predicate_all_f16x, probability_column_13_f32x,
1573
+ svld1_f16(predicate_all_f16x,
1574
+ (float16_t const *)(values_block_upper + ((dim_tile + 3) * 8 + 5) * 32)));
1575
+ svmopa_za32_f16_m(0, query_predicate_f16x, predicate_all_f16x, probability_column_14_f32x,
1576
+ svld1_f16(predicate_all_f16x,
1577
+ (float16_t const *)(values_block_upper + ((dim_tile + 0) * 8 + 6) * 32)));
1578
+ svmopa_za32_f16_m(1, query_predicate_f16x, predicate_all_f16x, probability_column_14_f32x,
1579
+ svld1_f16(predicate_all_f16x,
1580
+ (float16_t const *)(values_block_upper + ((dim_tile + 1) * 8 + 6) * 32)));
1581
+ svmopa_za32_f16_m(2, query_predicate_f16x, predicate_all_f16x, probability_column_14_f32x,
1582
+ svld1_f16(predicate_all_f16x,
1583
+ (float16_t const *)(values_block_upper + ((dim_tile + 2) * 8 + 6) * 32)));
1584
+ svmopa_za32_f16_m(3, query_predicate_f16x, predicate_all_f16x, probability_column_14_f32x,
1585
+ svld1_f16(predicate_all_f16x,
1586
+ (float16_t const *)(values_block_upper + ((dim_tile + 3) * 8 + 6) * 32)));
1587
+ svmopa_za32_f16_m(0, query_predicate_f16x, predicate_all_f16x, probability_column_15_f32x,
1588
+ svld1_f16(predicate_all_f16x,
1589
+ (float16_t const *)(values_block_upper + ((dim_tile + 0) * 8 + 7) * 32)));
1590
+ svmopa_za32_f16_m(1, query_predicate_f16x, predicate_all_f16x, probability_column_15_f32x,
1591
+ svld1_f16(predicate_all_f16x,
1592
+ (float16_t const *)(values_block_upper + ((dim_tile + 1) * 8 + 7) * 32)));
1593
+ svmopa_za32_f16_m(2, query_predicate_f16x, predicate_all_f16x, probability_column_15_f32x,
1594
+ svld1_f16(predicate_all_f16x,
1595
+ (float16_t const *)(values_block_upper + ((dim_tile + 2) * 8 + 7) * 32)));
1596
+ svmopa_za32_f16_m(3, query_predicate_f16x, predicate_all_f16x, probability_column_15_f32x,
1597
+ svld1_f16(predicate_all_f16x,
1598
+ (float16_t const *)(values_block_upper + ((dim_tile + 3) * 8 + 7) * 32)));
1599
+ // Read FMOPA result and ADD to output_accumulator
1600
+ for (nk_size_t query_index = 0; query_index < valid_query_count; query_index++) {
1601
+ svst1_f32(
1602
+ predicate_all_f32x, output_accumulator + query_index * head_dim_padded + (dim_tile + 0) * 16,
1603
+ svadd_f32_x(predicate_all_f32x,
1604
+ svld1_f32(predicate_all_f32x,
1605
+ output_accumulator + query_index * head_dim_padded + (dim_tile + 0) * 16),
1606
+ svread_hor_za32_f32_m(svdup_f32(0), predicate_all_f32x, 0, query_index)));
1607
+ svst1_f32(
1608
+ predicate_all_f32x, output_accumulator + query_index * head_dim_padded + (dim_tile + 1) * 16,
1609
+ svadd_f32_x(predicate_all_f32x,
1610
+ svld1_f32(predicate_all_f32x,
1611
+ output_accumulator + query_index * head_dim_padded + (dim_tile + 1) * 16),
1612
+ svread_hor_za32_f32_m(svdup_f32(0), predicate_all_f32x, 1, query_index)));
1613
+ svst1_f32(
1614
+ predicate_all_f32x, output_accumulator + query_index * head_dim_padded + (dim_tile + 2) * 16,
1615
+ svadd_f32_x(predicate_all_f32x,
1616
+ svld1_f32(predicate_all_f32x,
1617
+ output_accumulator + query_index * head_dim_padded + (dim_tile + 2) * 16),
1618
+ svread_hor_za32_f32_m(svdup_f32(0), predicate_all_f32x, 2, query_index)));
1619
+ svst1_f32(
1620
+ predicate_all_f32x, output_accumulator + query_index * head_dim_padded + (dim_tile + 3) * 16,
1621
+ svadd_f32_x(predicate_all_f32x,
1622
+ svld1_f32(predicate_all_f32x,
1623
+ output_accumulator + query_index * head_dim_padded + (dim_tile + 3) * 16),
1624
+ svread_hor_za32_f32_m(svdup_f32(0), predicate_all_f32x, 3, query_index)));
1625
+ }
1626
+ }
1627
+ // Remainder: 1 dim_tile at a time using ZA0
1628
+ for (; dim_tile < dim_tile_count; dim_tile++) {
1629
+ svzero_mask_za(nk_sme_zero_za32_tile_0_);
1630
+ svmopa_za32_f16_m(
1631
+ 0, query_predicate_f16x, predicate_all_f16x, probability_column_0_f32x,
1632
+ svld1_f16(predicate_all_f16x, (float16_t const *)(values_block_lower + (dim_tile * 8 + 0) * 32)));
1633
+ svmopa_za32_f16_m(
1634
+ 0, query_predicate_f16x, predicate_all_f16x, probability_column_1_f32x,
1635
+ svld1_f16(predicate_all_f16x, (float16_t const *)(values_block_lower + (dim_tile * 8 + 1) * 32)));
1636
+ svmopa_za32_f16_m(
1637
+ 0, query_predicate_f16x, predicate_all_f16x, probability_column_2_f32x,
1638
+ svld1_f16(predicate_all_f16x, (float16_t const *)(values_block_lower + (dim_tile * 8 + 2) * 32)));
1639
+ svmopa_za32_f16_m(
1640
+ 0, query_predicate_f16x, predicate_all_f16x, probability_column_3_f32x,
1641
+ svld1_f16(predicate_all_f16x, (float16_t const *)(values_block_lower + (dim_tile * 8 + 3) * 32)));
1642
+ svmopa_za32_f16_m(
1643
+ 0, query_predicate_f16x, predicate_all_f16x, probability_column_4_f32x,
1644
+ svld1_f16(predicate_all_f16x, (float16_t const *)(values_block_lower + (dim_tile * 8 + 4) * 32)));
1645
+ svmopa_za32_f16_m(
1646
+ 0, query_predicate_f16x, predicate_all_f16x, probability_column_5_f32x,
1647
+ svld1_f16(predicate_all_f16x, (float16_t const *)(values_block_lower + (dim_tile * 8 + 5) * 32)));
1648
+ svmopa_za32_f16_m(
1649
+ 0, query_predicate_f16x, predicate_all_f16x, probability_column_6_f32x,
1650
+ svld1_f16(predicate_all_f16x, (float16_t const *)(values_block_lower + (dim_tile * 8 + 6) * 32)));
1651
+ svmopa_za32_f16_m(
1652
+ 0, query_predicate_f16x, predicate_all_f16x, probability_column_7_f32x,
1653
+ svld1_f16(predicate_all_f16x, (float16_t const *)(values_block_lower + (dim_tile * 8 + 7) * 32)));
1654
+ svmopa_za32_f16_m(
1655
+ 0, query_predicate_f16x, predicate_all_f16x, probability_column_8_f32x,
1656
+ svld1_f16(predicate_all_f16x, (float16_t const *)(values_block_upper + (dim_tile * 8 + 0) * 32)));
1657
+ svmopa_za32_f16_m(
1658
+ 0, query_predicate_f16x, predicate_all_f16x, probability_column_9_f32x,
1659
+ svld1_f16(predicate_all_f16x, (float16_t const *)(values_block_upper + (dim_tile * 8 + 1) * 32)));
1660
+ svmopa_za32_f16_m(
1661
+ 0, query_predicate_f16x, predicate_all_f16x, probability_column_10_f32x,
1662
+ svld1_f16(predicate_all_f16x, (float16_t const *)(values_block_upper + (dim_tile * 8 + 2) * 32)));
1663
+ svmopa_za32_f16_m(
1664
+ 0, query_predicate_f16x, predicate_all_f16x, probability_column_11_f32x,
1665
+ svld1_f16(predicate_all_f16x, (float16_t const *)(values_block_upper + (dim_tile * 8 + 3) * 32)));
1666
+ svmopa_za32_f16_m(
1667
+ 0, query_predicate_f16x, predicate_all_f16x, probability_column_12_f32x,
1668
+ svld1_f16(predicate_all_f16x, (float16_t const *)(values_block_upper + (dim_tile * 8 + 4) * 32)));
1669
+ svmopa_za32_f16_m(
1670
+ 0, query_predicate_f16x, predicate_all_f16x, probability_column_13_f32x,
1671
+ svld1_f16(predicate_all_f16x, (float16_t const *)(values_block_upper + (dim_tile * 8 + 5) * 32)));
1672
+ svmopa_za32_f16_m(
1673
+ 0, query_predicate_f16x, predicate_all_f16x, probability_column_14_f32x,
1674
+ svld1_f16(predicate_all_f16x, (float16_t const *)(values_block_upper + (dim_tile * 8 + 6) * 32)));
1675
+ svmopa_za32_f16_m(
1676
+ 0, query_predicate_f16x, predicate_all_f16x, probability_column_15_f32x,
1677
+ svld1_f16(predicate_all_f16x, (float16_t const *)(values_block_upper + (dim_tile * 8 + 7) * 32)));
1678
+ for (nk_size_t query_index = 0; query_index < valid_query_count; query_index++)
1679
+ svst1_f32(predicate_all_f32x, output_accumulator + query_index * head_dim_padded + dim_tile * 16,
1680
+ svadd_f32_x(predicate_all_f32x,
1681
+ svld1_f32(predicate_all_f32x,
1682
+ output_accumulator + query_index * head_dim_padded + dim_tile * 16),
1683
+ svread_hor_za32_f32_m(svdup_f32(0), predicate_all_f32x, 0, query_index)));
1684
+ }
1685
+ }
1686
+ }
1687
+
1688
+ // === Bc=16 tail loop (handles remaining KV positions and decode path) ===
1689
+ for (; kv_start < kv_len; kv_start += 16, kv_block_index++) {
1690
+ nk_size_t const valid_kv = ((kv_start + 16) <= kv_len) ? 16 : (kv_len - kv_start);
1691
+
1692
+ // Q×K^T: pure memory→FMOPA, no ZA staging
1693
+ svzero_mask_za(nk_sme_zero_za32_tile_2_);
1694
+ nk_f16_t const *k_block = k + kv_block_index * k_depth_step_count * 32;
1695
+ for (nk_size_t step = 0; step < k_depth_step_count; step++) {
1696
+ svfloat16_t zn = svreinterpret_f16_f32(svld1_f32(predicate_all_f32x, queries_transposed + step * 16));
1697
+ svfloat16_t zm = svld1_f16(predicate_all_f16x, (float16_t const *)(k_block + step * 32));
1698
+ svmopa_za32_f16_m(2, predicate_all_f32x, predicate_all_f32x, zn, zm);
1699
+ }
1700
+
1701
+ // Pass 1: Column-wise max (read ZA2 columns vertically)
1702
+ svfloat32_t scale_16_f32x = svdup_f32(scale);
1703
+ svfloat32_t block_max_16_f32x = svdup_f32(NK_F32_MIN);
1704
+ for (nk_size_t column_index = 0; column_index < 16; column_index++) {
1705
+ svfloat32_t score_column_f32x = svmul_f32_x(
1706
+ predicate_all_f32x, svread_ver_za32_f32_m(svdup_f32(0), predicate_all_f32x, 2, column_index),
1707
+ scale_16_f32x);
1708
+ block_max_16_f32x = svmax_f32_x(predicate_all_f32x, block_max_16_f32x, score_column_f32x);
1709
+ }
1710
+
1711
+ svfloat32_t old_max_f32x = svld1_f32(predicate_all_f32x, row_max);
1712
+ svfloat32_t new_max_f32x = svmax_f32_x(predicate_all_f32x, old_max_f32x, block_max_16_f32x);
1713
+ svfloat32_t correction_f32x = nk_exp_fast_f32_sve_(predicate_all_f32x,
1714
+ svsub_f32_x(predicate_all_f32x, old_max_f32x, new_max_f32x));
1715
+ svbool_t max_changed_16 = svcmplt_f32(predicate_all_f32x, correction_f32x, svdup_f32(1.0f));
1716
+ nk_u32_t max_was_updated_16 = svptest_any(predicate_all_f32x, max_changed_16) ? 1 : 0;
1717
+ svfloat32_t row_sum_corrected_f32x = svld1_f32(predicate_all_f32x, row_sum);
1718
+ if (max_was_updated_16)
1719
+ row_sum_corrected_f32x = svmul_f32_x(predicate_all_f32x, row_sum_corrected_f32x, correction_f32x);
1720
+ NK_ALIGN64 nk_f32_t corrections[16];
1721
+ svst1_f32(predicate_all_f32x, corrections, correction_f32x);
1722
+
1723
+ // Pass 2: Column-wise exp + fused P write + sum (ZA2 → ZA0 columns 0-7)
1724
+ svfloat32_t sum_delta_16_f32x = svdup_f32(0.0f);
1725
+ svzero_mask_za(nk_sme_zero_za32_tile_0_);
1726
+ for (nk_size_t column_index = 0; column_index < 16; column_index += 2) {
1727
+ svfloat32_t score_even_f32x = svmul_f32_x(
1728
+ predicate_all_f32x, svread_ver_za32_f32_m(svdup_f32(0), predicate_all_f32x, 2, column_index),
1729
+ scale_16_f32x);
1730
+ svfloat32_t score_odd_f32x = svmul_f32_x(
1731
+ predicate_all_f32x, svread_ver_za32_f32_m(svdup_f32(0), predicate_all_f32x, 2, column_index + 1),
1732
+ scale_16_f32x);
1733
+ svfloat32_t weight_even_f32x = nk_exp_fast_f32_sve_(
1734
+ predicate_all_f32x, svsub_f32_x(predicate_all_f32x, score_even_f32x, new_max_f32x));
1735
+ svfloat32_t weight_odd_f32x = nk_exp_fast_f32_sve_(
1736
+ predicate_all_f32x, svsub_f32_x(predicate_all_f32x, score_odd_f32x, new_max_f32x));
1737
+ sum_delta_16_f32x = svadd_f32_x(predicate_all_f32x, sum_delta_16_f32x, weight_even_f32x);
1738
+ sum_delta_16_f32x = svadd_f32_x(predicate_all_f32x, sum_delta_16_f32x, weight_odd_f32x);
1739
+ svfloat16_t weight_pair_f16x = svzip1_f16(svcvt_f16_f32_x(predicate_all_f32x, weight_even_f32x),
1740
+ svcvt_f16_f32_x(predicate_all_f32x, weight_odd_f32x));
1741
+ svwrite_ver_za32_f32_m(0, column_index / 2, predicate_all_f32x, svreinterpret_f32_f16(weight_pair_f16x));
1742
+ }
1743
+ row_sum_corrected_f32x = svadd_f32_x(predicate_all_f32x, row_sum_corrected_f32x, sum_delta_16_f32x);
1744
+ svst1_f32(predicate_all_f32x, row_sum, row_sum_corrected_f32x);
1745
+ svst1_f32(predicate_all_f32x, row_max, new_max_f32x);
1746
+
1747
+ if (valid_query_count == 1) {
1748
+ // Decode path: extract f32 weights from ZA0 row 0 using SVE
1749
+ svfloat16_t row0_f16 = svreinterpret_f16_f32(svread_hor_za32_f32_m(svdup_f32(0), predicate_all_f32x, 0, 0));
1750
+ svfloat16_t weights_even_f16 = svuzp1_f16(row0_f16, row0_f16);
1751
+ svfloat16_t weights_odd_f16 = svuzp2_f16(row0_f16, row0_f16);
1752
+ NK_ALIGN64 nk_f32_t decode_weights[16];
1753
+ svst1_f32(svwhilelt_b32(0u, 8u), decode_weights, svcvt_f32_f16_x(svwhilelt_b32(0u, 8u), weights_even_f16));
1754
+ svst1_f32(svwhilelt_b32(0u, 8u), decode_weights + 8,
1755
+ svcvt_f32_f16_x(svwhilelt_b32(0u, 8u), weights_odd_f16));
1756
+ NK_ALIGN64 nk_f32_t decode_weights_ordered[16];
1757
+ for (nk_size_t i = 0; i < 8; i++) {
1758
+ decode_weights_ordered[2 * i] = decode_weights[i];
1759
+ decode_weights_ordered[2 * i + 1] = decode_weights[8 + i];
1760
+ }
1761
+ svfloat32_t corr_f32x = svdup_f32(corrections[0]);
1762
+ for (nk_size_t d = 0; d < head_dim; d += svcntw()) {
1763
+ svbool_t predicate_f32x = svwhilelt_b32_u64(d, head_dim);
1764
+ svfloat32_t acc_f32x = svmul_f32_x(predicate_f32x, svld1_f32(predicate_f32x, output_accumulator + d),
1765
+ corr_f32x);
1766
+ for (nk_size_t ki = 0; ki < valid_kv; ki++) {
1767
+ nk_size_t dim_tile = d / 16, depth_s = ki / 2, sub = ki % 2;
1768
+ nk_f16_t const *v_vec = v_packed +
1769
+ (kv_block_index * dim_tile_count * 8 + dim_tile * 8 + depth_s) * 32;
1770
+ svfloat16_t packed_f16x = svld1_f16(predicate_all_f16x, (float16_t const *)v_vec);
1771
+ svfloat16_t v_selected = (sub == 0) ? svuzp1_f16(packed_f16x, packed_f16x)
1772
+ : svuzp2_f16(packed_f16x, packed_f16x);
1773
+ acc_f32x = svmla_f32_x(predicate_f32x, acc_f32x, svdup_f32(decode_weights_ordered[ki]),
1774
+ svcvt_f32_f16_x(predicate_f32x, v_selected));
1775
+ }
1776
+ svst1_f32(predicate_f32x, output_accumulator + d, acc_f32x);
1777
+ }
1778
+ }
1779
+ else {
1780
+ // Prefill Bc=16: extract P columns, pre-apply correction, add-after P×V
1781
+ svbool_t query_predicate_f16x = svwhilelt_b16_u64(0u, valid_query_count * 2);
1782
+
1783
+ svfloat16_t probability_column_0_f32x = svreinterpret_f16_f32(
1784
+ svread_ver_za32_f32_m(svdup_f32(0), predicate_all_f32x, 0, 0));
1785
+ svfloat16_t probability_column_1_f32x = svreinterpret_f16_f32(
1786
+ svread_ver_za32_f32_m(svdup_f32(0), predicate_all_f32x, 0, 1));
1787
+ svfloat16_t probability_column_2_f32x = svreinterpret_f16_f32(
1788
+ svread_ver_za32_f32_m(svdup_f32(0), predicate_all_f32x, 0, 2));
1789
+ svfloat16_t probability_column_3_f32x = svreinterpret_f16_f32(
1790
+ svread_ver_za32_f32_m(svdup_f32(0), predicate_all_f32x, 0, 3));
1791
+ svfloat16_t probability_column_4_f32x = svreinterpret_f16_f32(
1792
+ svread_ver_za32_f32_m(svdup_f32(0), predicate_all_f32x, 0, 4));
1793
+ svfloat16_t probability_column_5_f32x = svreinterpret_f16_f32(
1794
+ svread_ver_za32_f32_m(svdup_f32(0), predicate_all_f32x, 0, 5));
1795
+ svfloat16_t probability_column_6_f32x = svreinterpret_f16_f32(
1796
+ svread_ver_za32_f32_m(svdup_f32(0), predicate_all_f32x, 0, 6));
1797
+ svfloat16_t probability_column_7_f32x = svreinterpret_f16_f32(
1798
+ svread_ver_za32_f32_m(svdup_f32(0), predicate_all_f32x, 0, 7));
1799
+
1800
+ nk_f16_t const *v_block = v_packed + kv_block_index * dim_tile_count * 8 * 32;
1801
+
1802
+ if (max_was_updated_16) {
1803
+ for (nk_size_t query_index = 0; query_index < valid_query_count; query_index++) {
1804
+ svfloat32_t correction_scalar_f32x = svdup_f32(corrections[query_index]);
1805
+ for (nk_size_t dim_offset = 0; dim_offset < head_dim_padded; dim_offset += 16)
1806
+ svst1_f32(
1807
+ predicate_all_f32x, output_accumulator + query_index * head_dim_padded + dim_offset,
1808
+ svmul_f32_x(predicate_all_f32x,
1809
+ svld1_f32(predicate_all_f32x,
1810
+ output_accumulator + query_index * head_dim_padded + dim_offset),
1811
+ correction_scalar_f32x));
1812
+ }
1813
+ }
1814
+
1815
+ nk_size_t dim_tile = 0;
1816
+ for (; dim_tile + 4 <= dim_tile_count; dim_tile += 4) {
1817
+ svzero_za();
1818
+ svmopa_za32_f16_m(
1819
+ 0, query_predicate_f16x, predicate_all_f16x, probability_column_0_f32x,
1820
+ svld1_f16(predicate_all_f16x, (float16_t const *)(v_block + ((dim_tile + 0) * 8 + 0) * 32)));
1821
+ svmopa_za32_f16_m(
1822
+ 1, query_predicate_f16x, predicate_all_f16x, probability_column_0_f32x,
1823
+ svld1_f16(predicate_all_f16x, (float16_t const *)(v_block + ((dim_tile + 1) * 8 + 0) * 32)));
1824
+ svmopa_za32_f16_m(
1825
+ 2, query_predicate_f16x, predicate_all_f16x, probability_column_0_f32x,
1826
+ svld1_f16(predicate_all_f16x, (float16_t const *)(v_block + ((dim_tile + 2) * 8 + 0) * 32)));
1827
+ svmopa_za32_f16_m(
1828
+ 3, query_predicate_f16x, predicate_all_f16x, probability_column_0_f32x,
1829
+ svld1_f16(predicate_all_f16x, (float16_t const *)(v_block + ((dim_tile + 3) * 8 + 0) * 32)));
1830
+ svmopa_za32_f16_m(
1831
+ 0, query_predicate_f16x, predicate_all_f16x, probability_column_1_f32x,
1832
+ svld1_f16(predicate_all_f16x, (float16_t const *)(v_block + ((dim_tile + 0) * 8 + 1) * 32)));
1833
+ svmopa_za32_f16_m(
1834
+ 1, query_predicate_f16x, predicate_all_f16x, probability_column_1_f32x,
1835
+ svld1_f16(predicate_all_f16x, (float16_t const *)(v_block + ((dim_tile + 1) * 8 + 1) * 32)));
1836
+ svmopa_za32_f16_m(
1837
+ 2, query_predicate_f16x, predicate_all_f16x, probability_column_1_f32x,
1838
+ svld1_f16(predicate_all_f16x, (float16_t const *)(v_block + ((dim_tile + 2) * 8 + 1) * 32)));
1839
+ svmopa_za32_f16_m(
1840
+ 3, query_predicate_f16x, predicate_all_f16x, probability_column_1_f32x,
1841
+ svld1_f16(predicate_all_f16x, (float16_t const *)(v_block + ((dim_tile + 3) * 8 + 1) * 32)));
1842
+ svmopa_za32_f16_m(
1843
+ 0, query_predicate_f16x, predicate_all_f16x, probability_column_2_f32x,
1844
+ svld1_f16(predicate_all_f16x, (float16_t const *)(v_block + ((dim_tile + 0) * 8 + 2) * 32)));
1845
+ svmopa_za32_f16_m(
1846
+ 1, query_predicate_f16x, predicate_all_f16x, probability_column_2_f32x,
1847
+ svld1_f16(predicate_all_f16x, (float16_t const *)(v_block + ((dim_tile + 1) * 8 + 2) * 32)));
1848
+ svmopa_za32_f16_m(
1849
+ 2, query_predicate_f16x, predicate_all_f16x, probability_column_2_f32x,
1850
+ svld1_f16(predicate_all_f16x, (float16_t const *)(v_block + ((dim_tile + 2) * 8 + 2) * 32)));
1851
+ svmopa_za32_f16_m(
1852
+ 3, query_predicate_f16x, predicate_all_f16x, probability_column_2_f32x,
1853
+ svld1_f16(predicate_all_f16x, (float16_t const *)(v_block + ((dim_tile + 3) * 8 + 2) * 32)));
1854
+ svmopa_za32_f16_m(
1855
+ 0, query_predicate_f16x, predicate_all_f16x, probability_column_3_f32x,
1856
+ svld1_f16(predicate_all_f16x, (float16_t const *)(v_block + ((dim_tile + 0) * 8 + 3) * 32)));
1857
+ svmopa_za32_f16_m(
1858
+ 1, query_predicate_f16x, predicate_all_f16x, probability_column_3_f32x,
1859
+ svld1_f16(predicate_all_f16x, (float16_t const *)(v_block + ((dim_tile + 1) * 8 + 3) * 32)));
1860
+ svmopa_za32_f16_m(
1861
+ 2, query_predicate_f16x, predicate_all_f16x, probability_column_3_f32x,
1862
+ svld1_f16(predicate_all_f16x, (float16_t const *)(v_block + ((dim_tile + 2) * 8 + 3) * 32)));
1863
+ svmopa_za32_f16_m(
1864
+ 3, query_predicate_f16x, predicate_all_f16x, probability_column_3_f32x,
1865
+ svld1_f16(predicate_all_f16x, (float16_t const *)(v_block + ((dim_tile + 3) * 8 + 3) * 32)));
1866
+ svmopa_za32_f16_m(
1867
+ 0, query_predicate_f16x, predicate_all_f16x, probability_column_4_f32x,
1868
+ svld1_f16(predicate_all_f16x, (float16_t const *)(v_block + ((dim_tile + 0) * 8 + 4) * 32)));
1869
+ svmopa_za32_f16_m(
1870
+ 1, query_predicate_f16x, predicate_all_f16x, probability_column_4_f32x,
1871
+ svld1_f16(predicate_all_f16x, (float16_t const *)(v_block + ((dim_tile + 1) * 8 + 4) * 32)));
1872
+ svmopa_za32_f16_m(
1873
+ 2, query_predicate_f16x, predicate_all_f16x, probability_column_4_f32x,
1874
+ svld1_f16(predicate_all_f16x, (float16_t const *)(v_block + ((dim_tile + 2) * 8 + 4) * 32)));
1875
+ svmopa_za32_f16_m(
1876
+ 3, query_predicate_f16x, predicate_all_f16x, probability_column_4_f32x,
1877
+ svld1_f16(predicate_all_f16x, (float16_t const *)(v_block + ((dim_tile + 3) * 8 + 4) * 32)));
1878
+ svmopa_za32_f16_m(
1879
+ 0, query_predicate_f16x, predicate_all_f16x, probability_column_5_f32x,
1880
+ svld1_f16(predicate_all_f16x, (float16_t const *)(v_block + ((dim_tile + 0) * 8 + 5) * 32)));
1881
+ svmopa_za32_f16_m(
1882
+ 1, query_predicate_f16x, predicate_all_f16x, probability_column_5_f32x,
1883
+ svld1_f16(predicate_all_f16x, (float16_t const *)(v_block + ((dim_tile + 1) * 8 + 5) * 32)));
1884
+ svmopa_za32_f16_m(
1885
+ 2, query_predicate_f16x, predicate_all_f16x, probability_column_5_f32x,
1886
+ svld1_f16(predicate_all_f16x, (float16_t const *)(v_block + ((dim_tile + 2) * 8 + 5) * 32)));
1887
+ svmopa_za32_f16_m(
1888
+ 3, query_predicate_f16x, predicate_all_f16x, probability_column_5_f32x,
1889
+ svld1_f16(predicate_all_f16x, (float16_t const *)(v_block + ((dim_tile + 3) * 8 + 5) * 32)));
1890
+ svmopa_za32_f16_m(
1891
+ 0, query_predicate_f16x, predicate_all_f16x, probability_column_6_f32x,
1892
+ svld1_f16(predicate_all_f16x, (float16_t const *)(v_block + ((dim_tile + 0) * 8 + 6) * 32)));
1893
+ svmopa_za32_f16_m(
1894
+ 1, query_predicate_f16x, predicate_all_f16x, probability_column_6_f32x,
1895
+ svld1_f16(predicate_all_f16x, (float16_t const *)(v_block + ((dim_tile + 1) * 8 + 6) * 32)));
1896
+ svmopa_za32_f16_m(
1897
+ 2, query_predicate_f16x, predicate_all_f16x, probability_column_6_f32x,
1898
+ svld1_f16(predicate_all_f16x, (float16_t const *)(v_block + ((dim_tile + 2) * 8 + 6) * 32)));
1899
+ svmopa_za32_f16_m(
1900
+ 3, query_predicate_f16x, predicate_all_f16x, probability_column_6_f32x,
1901
+ svld1_f16(predicate_all_f16x, (float16_t const *)(v_block + ((dim_tile + 3) * 8 + 6) * 32)));
1902
+ svmopa_za32_f16_m(
1903
+ 0, query_predicate_f16x, predicate_all_f16x, probability_column_7_f32x,
1904
+ svld1_f16(predicate_all_f16x, (float16_t const *)(v_block + ((dim_tile + 0) * 8 + 7) * 32)));
1905
+ svmopa_za32_f16_m(
1906
+ 1, query_predicate_f16x, predicate_all_f16x, probability_column_7_f32x,
1907
+ svld1_f16(predicate_all_f16x, (float16_t const *)(v_block + ((dim_tile + 1) * 8 + 7) * 32)));
1908
+ svmopa_za32_f16_m(
1909
+ 2, query_predicate_f16x, predicate_all_f16x, probability_column_7_f32x,
1910
+ svld1_f16(predicate_all_f16x, (float16_t const *)(v_block + ((dim_tile + 2) * 8 + 7) * 32)));
1911
+ svmopa_za32_f16_m(
1912
+ 3, query_predicate_f16x, predicate_all_f16x, probability_column_7_f32x,
1913
+ svld1_f16(predicate_all_f16x, (float16_t const *)(v_block + ((dim_tile + 3) * 8 + 7) * 32)));
1914
+ for (nk_size_t query_index = 0; query_index < valid_query_count; query_index++) {
1915
+ svst1_f32(
1916
+ predicate_all_f32x, output_accumulator + query_index * head_dim_padded + (dim_tile + 0) * 16,
1917
+ svadd_f32_x(predicate_all_f32x,
1918
+ svld1_f32(predicate_all_f32x,
1919
+ output_accumulator + query_index * head_dim_padded + (dim_tile + 0) * 16),
1920
+ svread_hor_za32_f32_m(svdup_f32(0), predicate_all_f32x, 0, query_index)));
1921
+ svst1_f32(
1922
+ predicate_all_f32x, output_accumulator + query_index * head_dim_padded + (dim_tile + 1) * 16,
1923
+ svadd_f32_x(predicate_all_f32x,
1924
+ svld1_f32(predicate_all_f32x,
1925
+ output_accumulator + query_index * head_dim_padded + (dim_tile + 1) * 16),
1926
+ svread_hor_za32_f32_m(svdup_f32(0), predicate_all_f32x, 1, query_index)));
1927
+ svst1_f32(
1928
+ predicate_all_f32x, output_accumulator + query_index * head_dim_padded + (dim_tile + 2) * 16,
1929
+ svadd_f32_x(predicate_all_f32x,
1930
+ svld1_f32(predicate_all_f32x,
1931
+ output_accumulator + query_index * head_dim_padded + (dim_tile + 2) * 16),
1932
+ svread_hor_za32_f32_m(svdup_f32(0), predicate_all_f32x, 2, query_index)));
1933
+ svst1_f32(
1934
+ predicate_all_f32x, output_accumulator + query_index * head_dim_padded + (dim_tile + 3) * 16,
1935
+ svadd_f32_x(predicate_all_f32x,
1936
+ svld1_f32(predicate_all_f32x,
1937
+ output_accumulator + query_index * head_dim_padded + (dim_tile + 3) * 16),
1938
+ svread_hor_za32_f32_m(svdup_f32(0), predicate_all_f32x, 3, query_index)));
1939
+ }
1940
+ }
1941
+ for (; dim_tile < dim_tile_count; dim_tile++) {
1942
+ svzero_mask_za(nk_sme_zero_za32_tile_0_);
1943
+ svmopa_za32_f16_m(
1944
+ 0, query_predicate_f16x, predicate_all_f16x, probability_column_0_f32x,
1945
+ svld1_f16(predicate_all_f16x, (float16_t const *)(v_block + (dim_tile * 8 + 0) * 32)));
1946
+ svmopa_za32_f16_m(
1947
+ 0, query_predicate_f16x, predicate_all_f16x, probability_column_1_f32x,
1948
+ svld1_f16(predicate_all_f16x, (float16_t const *)(v_block + (dim_tile * 8 + 1) * 32)));
1949
+ svmopa_za32_f16_m(
1950
+ 0, query_predicate_f16x, predicate_all_f16x, probability_column_2_f32x,
1951
+ svld1_f16(predicate_all_f16x, (float16_t const *)(v_block + (dim_tile * 8 + 2) * 32)));
1952
+ svmopa_za32_f16_m(
1953
+ 0, query_predicate_f16x, predicate_all_f16x, probability_column_3_f32x,
1954
+ svld1_f16(predicate_all_f16x, (float16_t const *)(v_block + (dim_tile * 8 + 3) * 32)));
1955
+ svmopa_za32_f16_m(
1956
+ 0, query_predicate_f16x, predicate_all_f16x, probability_column_4_f32x,
1957
+ svld1_f16(predicate_all_f16x, (float16_t const *)(v_block + (dim_tile * 8 + 4) * 32)));
1958
+ svmopa_za32_f16_m(
1959
+ 0, query_predicate_f16x, predicate_all_f16x, probability_column_5_f32x,
1960
+ svld1_f16(predicate_all_f16x, (float16_t const *)(v_block + (dim_tile * 8 + 5) * 32)));
1961
+ svmopa_za32_f16_m(
1962
+ 0, query_predicate_f16x, predicate_all_f16x, probability_column_6_f32x,
1963
+ svld1_f16(predicate_all_f16x, (float16_t const *)(v_block + (dim_tile * 8 + 6) * 32)));
1964
+ svmopa_za32_f16_m(
1965
+ 0, query_predicate_f16x, predicate_all_f16x, probability_column_7_f32x,
1966
+ svld1_f16(predicate_all_f16x, (float16_t const *)(v_block + (dim_tile * 8 + 7) * 32)));
1967
+ for (nk_size_t query_index = 0; query_index < valid_query_count; query_index++)
1968
+ svst1_f32(predicate_all_f32x, output_accumulator + query_index * head_dim_padded + dim_tile * 16,
1969
+ svadd_f32_x(predicate_all_f32x,
1970
+ svld1_f32(predicate_all_f32x,
1971
+ output_accumulator + query_index * head_dim_padded + dim_tile * 16),
1972
+ svread_hor_za32_f32_m(svdup_f32(0), predicate_all_f32x, 0, query_index)));
1973
+ }
1974
+ }
1975
+ }
1976
+
1977
+ // Final normalization
1978
+ svfloat32_t final_sum_f32x = svld1_f32(predicate_all_f32x, row_sum);
1979
+ svfloat32_t ones_f32x = svdup_f32(1.0f);
1980
+ svfloat32_t zeros_f32x = svdup_f32(0.0f);
1981
+ svbool_t sum_positive = svcmpgt_f32(predicate_all_f32x, final_sum_f32x, zeros_f32x);
1982
+ svfloat32_t inv_sum_f32x = svsel_f32(sum_positive, svdiv_f32_x(predicate_all_f32x, ones_f32x, final_sum_f32x),
1983
+ zeros_f32x);
1984
+
1985
+ NK_ALIGN64 nk_f32_t inv_sums[16];
1986
+ svst1_f32(predicate_all_f32x, inv_sums, inv_sum_f32x);
1987
+
1988
+ for (nk_size_t query_index = 0; query_index < valid_query_count; query_index++) {
1989
+ svfloat32_t inv_sum_f32x = svdup_f32(inv_sums[query_index]);
1990
+ for (nk_size_t dim_offset = 0; dim_offset < head_dim; dim_offset += svcntw()) {
1991
+ svbool_t predicate_f32x = svwhilelt_b32_u64(dim_offset, head_dim);
1992
+ svfloat32_t output_f32x = svmul_f32_x(
1993
+ predicate_f32x,
1994
+ svld1_f32(predicate_f32x, output_accumulator + query_index * head_dim_padded + dim_offset),
1995
+ inv_sum_f32x);
1996
+ svfloat16_t output_f16x = svcvt_f16_f32_x(predicate_f32x, output_f32x);
1997
+ nk_size_t store_count = (head_dim - dim_offset) < (nk_size_t)svcntw() ? (head_dim - dim_offset)
1998
+ : (nk_size_t)svcntw();
1999
+ svbool_t predicate_f16x = svwhilelt_b16_u64(0u, store_count);
2000
+ svst1_f16(predicate_f16x, (float16_t *)(output + query_index * head_dim + dim_offset), output_f16x);
2001
+ }
2002
+ }
2003
+ }
2004
+
2005
+ NK_PUBLIC void nk_attention_f16_sme(nk_f16_t const *q, void const *kv_packed, nk_f16_t *output, nk_size_t num_heads,
2006
+ nk_size_t num_kv_heads, nk_size_t query_len, nk_size_t kv_len, nk_size_t head_dim,
2007
+ nk_f32_t scale) {
2008
+
2009
+ nk_attention_sme_packed_header_t const *header = (nk_attention_sme_packed_header_t const *)kv_packed;
2010
+ nk_size_t head_dim_padded = header->head_dim_padded;
2011
+ nk_size_t dim_tile_count = header->reserved[0];
2012
+ nk_size_t kv_blocks = (kv_len + 15) / 16;
2013
+ // K and V both use interleaved format: kv_blocks * 16 * head_dim_padded elements per head
2014
+ nk_size_t kv_head_stride = kv_blocks * 16 * head_dim_padded;
2015
+
2016
+ nk_f16_t const *k_packed = (nk_f16_t const *)((char const *)kv_packed + header->k_offset);
2017
+ nk_f16_t const *v_packed = (nk_f16_t const *)((char const *)kv_packed + header->v_offset);
2018
+
2019
+ nk_size_t group_size = (num_kv_heads > 0) ? num_heads / num_kv_heads : 1;
2020
+
2021
+ for (nk_size_t q_head = 0; q_head < num_heads; q_head++) {
2022
+ nk_size_t kv_head = q_head / group_size;
2023
+
2024
+ nk_f16_t const *q_ptr = q + q_head * query_len * head_dim;
2025
+ nk_f16_t const *k_ptr = k_packed + kv_head * kv_head_stride;
2026
+ nk_f16_t const *v_ptr = v_packed + kv_head * kv_head_stride;
2027
+ nk_f16_t *out_ptr = output + q_head * query_len * head_dim;
2028
+
2029
+ for (nk_size_t q_start = 0; q_start < query_len; q_start += 16) {
2030
+ nk_size_t q_block_len = (q_start + 16 < query_len) ? 16 : (query_len - q_start);
2031
+
2032
+ nk_attention_f16_sme_streaming_(q_ptr + q_start * head_dim, k_ptr, v_ptr, out_ptr + q_start * head_dim,
2033
+ q_block_len, kv_len, head_dim, head_dim_padded, dim_tile_count, scale);
2034
+ }
2035
+ }
2036
+ }
2037
+
2038
+ NK_PUBLIC void nk_attention_causal_bf16_sme(nk_bf16_t const *q, void const *kv_packed, nk_bf16_t *output,
2039
+ nk_size_t num_heads, nk_size_t num_kv_heads, nk_size_t query_len,
2040
+ nk_size_t kv_len, nk_size_t head_dim, nk_f32_t scale) {
2041
+ // TODO: Implement proper causal masking with block skipping
2042
+ // For now, delegate to full attention (correct for decode where query_len=1)
2043
+ nk_attention_bf16_sme(q, kv_packed, output, num_heads, num_kv_heads, query_len, kv_len, head_dim, scale);
2044
+ }
2045
+
2046
+ NK_PUBLIC void nk_attention_causal_f16_sme(nk_f16_t const *q, void const *kv_packed, nk_f16_t *output,
2047
+ nk_size_t num_heads, nk_size_t num_kv_heads, nk_size_t query_len,
2048
+ nk_size_t kv_len, nk_size_t head_dim, nk_f32_t scale) {
2049
+ // TODO: Implement proper causal masking with block skipping
2050
+ // For now, delegate to full attention (correct for decode where query_len=1)
2051
+ nk_attention_f16_sme(q, kv_packed, output, num_heads, num_kv_heads, query_len, kv_len, head_dim, scale);
2052
+ }
2053
+
2054
+ #if defined(__clang__)
2055
+ #pragma clang attribute pop
2056
+ #elif defined(__GNUC__)
2057
+ #pragma GCC pop_options
2058
+ #endif
2059
+
2060
+ #if defined(__cplusplus)
2061
+ } // extern "C"
2062
+ #endif
2063
+
2064
+ #endif // NK_TARGET_SME
2065
+ #endif // NK_TARGET_ARM_
2066
+ #endif // NK_ATTENTION_SME_H