numkong 7.0.0 → 7.4.2

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (315)
  1. package/README.md +197 -124
  2. package/binding.gyp +34 -484
  3. package/c/dispatch_bf16.c +59 -1
  4. package/c/dispatch_e2m3.c +41 -8
  5. package/c/dispatch_e3m2.c +49 -8
  6. package/c/dispatch_e4m3.c +51 -9
  7. package/c/dispatch_e5m2.c +45 -1
  8. package/c/dispatch_f16.c +79 -26
  9. package/c/dispatch_f16c.c +5 -5
  10. package/c/dispatch_f32.c +56 -0
  11. package/c/dispatch_f64.c +52 -0
  12. package/c/dispatch_i4.c +3 -0
  13. package/c/dispatch_i8.c +62 -3
  14. package/c/dispatch_other.c +18 -0
  15. package/c/dispatch_u1.c +54 -9
  16. package/c/dispatch_u4.c +3 -0
  17. package/c/dispatch_u8.c +64 -3
  18. package/c/numkong.c +3 -0
  19. package/include/README.md +79 -9
  20. package/include/numkong/attention/sapphireamx.h +278 -276
  21. package/include/numkong/attention/sme.h +983 -977
  22. package/include/numkong/attention.h +1 -1
  23. package/include/numkong/capabilities.h +289 -94
  24. package/include/numkong/cast/README.md +40 -40
  25. package/include/numkong/cast/diamond.h +64 -0
  26. package/include/numkong/cast/haswell.h +42 -194
  27. package/include/numkong/cast/icelake.h +42 -37
  28. package/include/numkong/cast/loongsonasx.h +252 -0
  29. package/include/numkong/cast/neon.h +216 -249
  30. package/include/numkong/cast/powervsx.h +449 -0
  31. package/include/numkong/cast/rvv.h +223 -274
  32. package/include/numkong/cast/sapphire.h +18 -18
  33. package/include/numkong/cast/serial.h +1018 -944
  34. package/include/numkong/cast/skylake.h +82 -23
  35. package/include/numkong/cast/v128relaxed.h +462 -105
  36. package/include/numkong/cast.h +24 -0
  37. package/include/numkong/cast.hpp +44 -0
  38. package/include/numkong/curved/README.md +17 -17
  39. package/include/numkong/curved/neon.h +131 -7
  40. package/include/numkong/curved/neonbfdot.h +6 -7
  41. package/include/numkong/curved/rvv.h +26 -26
  42. package/include/numkong/curved/smef64.h +186 -182
  43. package/include/numkong/curved.h +14 -18
  44. package/include/numkong/dot/README.md +154 -137
  45. package/include/numkong/dot/alder.h +43 -43
  46. package/include/numkong/dot/diamond.h +158 -0
  47. package/include/numkong/dot/genoa.h +4 -30
  48. package/include/numkong/dot/haswell.h +215 -180
  49. package/include/numkong/dot/icelake.h +190 -76
  50. package/include/numkong/dot/loongsonasx.h +671 -0
  51. package/include/numkong/dot/neon.h +124 -73
  52. package/include/numkong/dot/neonbfdot.h +11 -12
  53. package/include/numkong/dot/neonfhm.h +44 -46
  54. package/include/numkong/dot/neonfp8.h +323 -0
  55. package/include/numkong/dot/neonsdot.h +190 -76
  56. package/include/numkong/dot/powervsx.h +752 -0
  57. package/include/numkong/dot/rvv.h +92 -84
  58. package/include/numkong/dot/rvvbf16.h +12 -12
  59. package/include/numkong/dot/rvvhalf.h +12 -12
  60. package/include/numkong/dot/sapphire.h +4 -4
  61. package/include/numkong/dot/serial.h +66 -30
  62. package/include/numkong/dot/sierra.h +31 -31
  63. package/include/numkong/dot/skylake.h +142 -110
  64. package/include/numkong/dot/sve.h +217 -177
  65. package/include/numkong/dot/svebfdot.h +10 -10
  66. package/include/numkong/dot/svehalf.h +85 -41
  67. package/include/numkong/dot/svesdot.h +89 -0
  68. package/include/numkong/dot/v128relaxed.h +124 -89
  69. package/include/numkong/dot.h +114 -48
  70. package/include/numkong/dots/README.md +203 -203
  71. package/include/numkong/dots/alder.h +12 -9
  72. package/include/numkong/dots/diamond.h +86 -0
  73. package/include/numkong/dots/genoa.h +10 -4
  74. package/include/numkong/dots/haswell.h +63 -48
  75. package/include/numkong/dots/icelake.h +27 -18
  76. package/include/numkong/dots/loongsonasx.h +176 -0
  77. package/include/numkong/dots/neon.h +14 -11
  78. package/include/numkong/dots/neonbfdot.h +4 -3
  79. package/include/numkong/dots/neonfhm.h +11 -9
  80. package/include/numkong/dots/neonfp8.h +99 -0
  81. package/include/numkong/dots/neonsdot.h +48 -12
  82. package/include/numkong/dots/powervsx.h +194 -0
  83. package/include/numkong/dots/rvv.h +451 -344
  84. package/include/numkong/dots/sapphireamx.h +1028 -984
  85. package/include/numkong/dots/serial.h +213 -197
  86. package/include/numkong/dots/sierra.h +10 -7
  87. package/include/numkong/dots/skylake.h +47 -36
  88. package/include/numkong/dots/sme.h +2001 -2364
  89. package/include/numkong/dots/smebi32.h +175 -162
  90. package/include/numkong/dots/smef64.h +328 -323
  91. package/include/numkong/dots/v128relaxed.h +64 -41
  92. package/include/numkong/dots.h +573 -293
  93. package/include/numkong/dots.hpp +45 -43
  94. package/include/numkong/each/README.md +133 -137
  95. package/include/numkong/each/haswell.h +6 -6
  96. package/include/numkong/each/icelake.h +7 -7
  97. package/include/numkong/each/neon.h +76 -42
  98. package/include/numkong/each/neonbfdot.h +11 -12
  99. package/include/numkong/each/neonhalf.h +24 -116
  100. package/include/numkong/each/rvv.h +28 -28
  101. package/include/numkong/each/sapphire.h +27 -161
  102. package/include/numkong/each/serial.h +6 -6
  103. package/include/numkong/each/skylake.h +7 -7
  104. package/include/numkong/each/v128relaxed.h +562 -0
  105. package/include/numkong/each.h +148 -62
  106. package/include/numkong/each.hpp +2 -2
  107. package/include/numkong/geospatial/README.md +18 -18
  108. package/include/numkong/geospatial/haswell.h +365 -325
  109. package/include/numkong/geospatial/neon.h +350 -306
  110. package/include/numkong/geospatial/rvv.h +4 -4
  111. package/include/numkong/geospatial/skylake.h +376 -340
  112. package/include/numkong/geospatial/v128relaxed.h +366 -327
  113. package/include/numkong/geospatial.h +17 -17
  114. package/include/numkong/matrix.hpp +4 -4
  115. package/include/numkong/maxsim/README.md +14 -14
  116. package/include/numkong/maxsim/alder.h +6 -6
  117. package/include/numkong/maxsim/genoa.h +4 -4
  118. package/include/numkong/maxsim/haswell.h +6 -6
  119. package/include/numkong/maxsim/icelake.h +18 -18
  120. package/include/numkong/maxsim/neonsdot.h +21 -21
  121. package/include/numkong/maxsim/sapphireamx.h +14 -14
  122. package/include/numkong/maxsim/serial.h +6 -6
  123. package/include/numkong/maxsim/sme.h +221 -196
  124. package/include/numkong/maxsim/v128relaxed.h +6 -6
  125. package/include/numkong/mesh/README.md +62 -56
  126. package/include/numkong/mesh/haswell.h +339 -464
  127. package/include/numkong/mesh/neon.h +1100 -519
  128. package/include/numkong/mesh/neonbfdot.h +36 -68
  129. package/include/numkong/mesh/rvv.h +530 -435
  130. package/include/numkong/mesh/serial.h +75 -91
  131. package/include/numkong/mesh/skylake.h +1627 -302
  132. package/include/numkong/mesh/v128relaxed.h +443 -330
  133. package/include/numkong/mesh.h +63 -49
  134. package/include/numkong/mesh.hpp +4 -4
  135. package/include/numkong/numkong.h +3 -3
  136. package/include/numkong/numkong.hpp +1 -0
  137. package/include/numkong/probability/README.md +23 -19
  138. package/include/numkong/probability/neon.h +82 -52
  139. package/include/numkong/probability/rvv.h +28 -23
  140. package/include/numkong/probability/serial.h +51 -39
  141. package/include/numkong/probability.h +20 -23
  142. package/include/numkong/random.h +1 -1
  143. package/include/numkong/reduce/README.md +143 -138
  144. package/include/numkong/reduce/alder.h +81 -77
  145. package/include/numkong/reduce/haswell.h +222 -220
  146. package/include/numkong/reduce/neon.h +629 -519
  147. package/include/numkong/reduce/neonbfdot.h +7 -218
  148. package/include/numkong/reduce/neonfhm.h +9 -381
  149. package/include/numkong/reduce/neonsdot.h +9 -9
  150. package/include/numkong/reduce/rvv.h +928 -802
  151. package/include/numkong/reduce/serial.h +23 -27
  152. package/include/numkong/reduce/sierra.h +20 -20
  153. package/include/numkong/reduce/skylake.h +326 -324
  154. package/include/numkong/reduce/v128relaxed.h +52 -52
  155. package/include/numkong/reduce.h +4 -23
  156. package/include/numkong/reduce.hpp +156 -11
  157. package/include/numkong/scalar/README.md +6 -6
  158. package/include/numkong/scalar/haswell.h +26 -17
  159. package/include/numkong/scalar/loongsonasx.h +74 -0
  160. package/include/numkong/scalar/neon.h +9 -9
  161. package/include/numkong/scalar/powervsx.h +96 -0
  162. package/include/numkong/scalar/rvv.h +2 -2
  163. package/include/numkong/scalar/sapphire.h +21 -10
  164. package/include/numkong/scalar/serial.h +21 -21
  165. package/include/numkong/scalar.h +13 -0
  166. package/include/numkong/set/README.md +28 -28
  167. package/include/numkong/set/haswell.h +12 -12
  168. package/include/numkong/set/icelake.h +14 -14
  169. package/include/numkong/set/loongsonasx.h +181 -0
  170. package/include/numkong/set/neon.h +17 -18
  171. package/include/numkong/set/powervsx.h +326 -0
  172. package/include/numkong/set/rvv.h +4 -4
  173. package/include/numkong/set/serial.h +6 -6
  174. package/include/numkong/set/sve.h +60 -59
  175. package/include/numkong/set/v128relaxed.h +6 -6
  176. package/include/numkong/set.h +21 -7
  177. package/include/numkong/sets/README.md +26 -26
  178. package/include/numkong/sets/loongsonasx.h +52 -0
  179. package/include/numkong/sets/powervsx.h +65 -0
  180. package/include/numkong/sets/smebi32.h +395 -364
  181. package/include/numkong/sets.h +83 -40
  182. package/include/numkong/sparse/README.md +4 -4
  183. package/include/numkong/sparse/icelake.h +101 -101
  184. package/include/numkong/sparse/serial.h +1 -1
  185. package/include/numkong/sparse/sve2.h +137 -141
  186. package/include/numkong/sparse/turin.h +12 -12
  187. package/include/numkong/sparse.h +10 -10
  188. package/include/numkong/spatial/README.md +230 -226
  189. package/include/numkong/spatial/alder.h +113 -116
  190. package/include/numkong/spatial/diamond.h +240 -0
  191. package/include/numkong/spatial/genoa.h +0 -68
  192. package/include/numkong/spatial/haswell.h +74 -55
  193. package/include/numkong/spatial/icelake.h +539 -58
  194. package/include/numkong/spatial/loongsonasx.h +483 -0
  195. package/include/numkong/spatial/neon.h +125 -52
  196. package/include/numkong/spatial/neonbfdot.h +8 -9
  197. package/include/numkong/spatial/neonfp8.h +258 -0
  198. package/include/numkong/spatial/neonsdot.h +180 -12
  199. package/include/numkong/spatial/powervsx.h +738 -0
  200. package/include/numkong/spatial/rvv.h +146 -139
  201. package/include/numkong/spatial/rvvbf16.h +17 -12
  202. package/include/numkong/spatial/rvvhalf.h +13 -10
  203. package/include/numkong/spatial/serial.h +13 -12
  204. package/include/numkong/spatial/sierra.h +232 -39
  205. package/include/numkong/spatial/skylake.h +73 -74
  206. package/include/numkong/spatial/sve.h +93 -72
  207. package/include/numkong/spatial/svebfdot.h +29 -29
  208. package/include/numkong/spatial/svehalf.h +52 -26
  209. package/include/numkong/spatial/svesdot.h +142 -0
  210. package/include/numkong/spatial/v128relaxed.h +293 -41
  211. package/include/numkong/spatial.h +338 -82
  212. package/include/numkong/spatials/README.md +194 -194
  213. package/include/numkong/spatials/diamond.h +82 -0
  214. package/include/numkong/spatials/haswell.h +2 -2
  215. package/include/numkong/spatials/loongsonasx.h +153 -0
  216. package/include/numkong/spatials/neonfp8.h +111 -0
  217. package/include/numkong/spatials/neonsdot.h +34 -0
  218. package/include/numkong/spatials/powervsx.h +153 -0
  219. package/include/numkong/spatials/rvv.h +259 -243
  220. package/include/numkong/spatials/sapphireamx.h +173 -173
  221. package/include/numkong/spatials/serial.h +2 -2
  222. package/include/numkong/spatials/skylake.h +2 -2
  223. package/include/numkong/spatials/sme.h +590 -605
  224. package/include/numkong/spatials/smef64.h +139 -130
  225. package/include/numkong/spatials/v128relaxed.h +2 -2
  226. package/include/numkong/spatials.h +820 -500
  227. package/include/numkong/spatials.hpp +49 -48
  228. package/include/numkong/tensor.hpp +406 -17
  229. package/include/numkong/trigonometry/README.md +19 -19
  230. package/include/numkong/trigonometry/haswell.h +402 -401
  231. package/include/numkong/trigonometry/neon.h +386 -387
  232. package/include/numkong/trigonometry/rvv.h +52 -51
  233. package/include/numkong/trigonometry/serial.h +13 -13
  234. package/include/numkong/trigonometry/skylake.h +373 -369
  235. package/include/numkong/trigonometry/v128relaxed.h +375 -374
  236. package/include/numkong/trigonometry.h +13 -13
  237. package/include/numkong/trigonometry.hpp +2 -2
  238. package/include/numkong/types.h +287 -49
  239. package/include/numkong/types.hpp +436 -12
  240. package/include/numkong/vector.hpp +82 -14
  241. package/javascript/dist/cjs/numkong-wasm.js +6 -12
  242. package/javascript/dist/cjs/numkong.d.ts +7 -1
  243. package/javascript/dist/cjs/numkong.js +37 -11
  244. package/javascript/dist/cjs/types.d.ts +9 -0
  245. package/javascript/dist/cjs/types.js +96 -0
  246. package/javascript/dist/esm/numkong-browser.d.ts +14 -0
  247. package/javascript/dist/esm/numkong-browser.js +23 -0
  248. package/javascript/dist/esm/numkong-wasm.js +6 -12
  249. package/javascript/dist/esm/numkong.d.ts +7 -1
  250. package/javascript/dist/esm/numkong.js +37 -11
  251. package/javascript/dist/esm/types.d.ts +9 -0
  252. package/javascript/dist/esm/types.js +96 -0
  253. package/javascript/node-gyp-build.d.ts +4 -1
  254. package/javascript/numkong-browser.ts +40 -0
  255. package/javascript/numkong-wasm.ts +7 -13
  256. package/javascript/numkong.c +5 -26
  257. package/javascript/numkong.ts +36 -11
  258. package/javascript/tsconfig-base.json +1 -0
  259. package/javascript/tsconfig-cjs.json +6 -1
  260. package/javascript/types.ts +110 -0
  261. package/numkong.gypi +101 -0
  262. package/package.json +34 -13
  263. package/probes/arm_neon.c +8 -0
  264. package/probes/arm_neon_bfdot.c +9 -0
  265. package/probes/arm_neon_fhm.c +9 -0
  266. package/probes/arm_neon_half.c +8 -0
  267. package/probes/arm_neon_sdot.c +9 -0
  268. package/probes/arm_neonfp8.c +9 -0
  269. package/probes/arm_sme.c +16 -0
  270. package/probes/arm_sme2.c +16 -0
  271. package/probes/arm_sme2p1.c +16 -0
  272. package/probes/arm_sme_bf16.c +16 -0
  273. package/probes/arm_sme_bi32.c +16 -0
  274. package/probes/arm_sme_f64.c +16 -0
  275. package/probes/arm_sme_fa64.c +14 -0
  276. package/probes/arm_sme_half.c +16 -0
  277. package/probes/arm_sme_lut2.c +15 -0
  278. package/probes/arm_sve.c +18 -0
  279. package/probes/arm_sve2.c +20 -0
  280. package/probes/arm_sve2p1.c +18 -0
  281. package/probes/arm_sve_bfdot.c +20 -0
  282. package/probes/arm_sve_half.c +18 -0
  283. package/probes/arm_sve_sdot.c +21 -0
  284. package/probes/loongarch_lasx.c +12 -0
  285. package/probes/power_vsx.c +12 -0
  286. package/probes/probe.js +127 -0
  287. package/probes/riscv_rvv.c +14 -0
  288. package/probes/riscv_rvv_bb.c +15 -0
  289. package/probes/riscv_rvv_bf16.c +17 -0
  290. package/probes/riscv_rvv_half.c +14 -0
  291. package/probes/wasm_v128relaxed.c +11 -0
  292. package/probes/x86_alder.c +17 -0
  293. package/probes/x86_diamond.c +17 -0
  294. package/probes/x86_genoa.c +17 -0
  295. package/probes/x86_graniteamx.c +19 -0
  296. package/probes/x86_haswell.c +11 -0
  297. package/probes/x86_icelake.c +17 -0
  298. package/probes/x86_sapphire.c +16 -0
  299. package/probes/x86_sapphireamx.c +18 -0
  300. package/probes/x86_sierra.c +17 -0
  301. package/probes/x86_skylake.c +15 -0
  302. package/probes/x86_turin.c +17 -0
  303. package/wasm/numkong-emscripten.js +2 -0
  304. package/wasm/numkong.d.ts +14 -0
  305. package/wasm/numkong.js +1124 -0
  306. package/wasm/numkong.wasm +0 -0
  307. package/include/numkong/curved/neonhalf.h +0 -212
  308. package/include/numkong/dot/neonhalf.h +0 -198
  309. package/include/numkong/dots/neonhalf.h +0 -57
  310. package/include/numkong/mesh/neonhalf.h +0 -616
  311. package/include/numkong/reduce/neonhalf.h +0 -157
  312. package/include/numkong/spatial/neonhalf.h +0 -118
  313. package/include/numkong/spatial/sapphire.h +0 -343
  314. package/include/numkong/spatials/neonhalf.h +0 -58
  315. package/javascript/README.md +0 -246
@@ -101,7 +101,7 @@ extern "C" {
101
101
  #endif
102
102
 
103
103
  #if defined(__clang__)
104
- #pragma clang attribute push(__attribute__((target("sme,sve"))), apply_to = function)
104
+ #pragma clang attribute push(__attribute__((target("sme"))), apply_to = function)
105
105
  #elif defined(__GNUC__)
106
106
  #pragma GCC push_options
107
107
  #pragma GCC target("+sme")
@@ -116,10 +116,10 @@ extern "C" {
116
116
  * 3. Shift left by 16 to place in f32 exponent+mantissa position
117
117
  * 4. Reinterpret as f32
118
118
  */
119
- NK_INTERNAL svfloat32_t nk_bf16_to_f32_sve_(svbool_t predicate_f32x, svbfloat16_t x_bf16x) __arm_streaming {
119
+ NK_INTERNAL svfloat32_t nk_bf16_to_f32_sve_(svbool_t predicate_b32x, svbfloat16_t x_bf16x) __arm_streaming {
120
120
  svuint16_t x_u16x = svreinterpret_u16_bf16(x_bf16x);
121
121
  svuint32_t x_u32x = svunpklo_u32(x_u16x);
122
- x_u32x = svlsl_n_u32_x(predicate_f32x, x_u32x, 16);
122
+ x_u32x = svlsl_n_u32_x(predicate_b32x, x_u32x, 16);
123
123
  return svreinterpret_f32_u32(x_u32x);
124
124
  }
125
125
 
@@ -131,10 +131,10 @@ NK_INTERNAL svfloat32_t nk_bf16_to_f32_sve_(svbool_t predicate_f32x, svbfloat16_
131
131
  * 3. Shift right by 16
132
132
  * 4. Narrow to u16 and reinterpret as bf16
133
133
  */
134
- NK_INTERNAL svbfloat16_t nk_f32_to_bf16_sve_(svbool_t predicate_f32x, svfloat32_t x_f32x) __arm_streaming {
134
+ NK_INTERNAL svbfloat16_t nk_f32_to_bf16_sve_(svbool_t predicate_b32x, svfloat32_t x_f32x) __arm_streaming {
135
135
  svuint32_t x_u32x = svreinterpret_u32_f32(x_f32x);
136
- x_u32x = svadd_n_u32_x(predicate_f32x, x_u32x, 0x8000); // Round to nearest
137
- x_u32x = svlsr_n_u32_x(predicate_f32x, x_u32x, 16);
136
+ x_u32x = svadd_n_u32_x(predicate_b32x, x_u32x, 0x8000); // Round to nearest
137
+ x_u32x = svlsr_n_u32_x(predicate_b32x, x_u32x, 16);
138
138
  svuint16_t x_u16x = svuzp1_u16(svreinterpret_u16_u32(x_u32x), svreinterpret_u16_u32(x_u32x));
139
139
  return svreinterpret_bf16_u16(x_u16x);
140
140
  }
@@ -166,71 +166,71 @@ typedef struct {
166
166
  * @param x Input vector
167
167
  * @return exp(x) approximation
168
168
  */
169
- NK_INTERNAL svfloat32_t nk_exp_f32_sve_(svbool_t predicate_f32x, svfloat32_t x_f32x) __arm_streaming {
169
+ NK_INTERNAL svfloat32_t nk_exp_f32_sve_(svbool_t predicate_b32x, svfloat32_t x_f32x) __arm_streaming {
170
170
  // Constants for Cody-Waite range reduction
171
171
  svfloat32_t log2e_f32x = svdup_f32(1.4426950408889634f);
172
- svfloat32_t ln2_hi_f32x = svdup_f32(0.693145751953125f);
173
- svfloat32_t ln2_lo_f32x = svdup_f32(1.42860682030941723212e-6f);
172
+ svfloat32_t ln2_high_f32x = svdup_f32(0.693145751953125f);
173
+ svfloat32_t ln2_low_f32x = svdup_f32(1.42860682030941723212e-6f);
174
174
 
175
175
  // Clamp to avoid overflow/underflow
176
176
  svfloat32_t max_x_f32x = svdup_f32(88.3762626647949f);
177
177
  svfloat32_t min_x_f32x = svdup_f32(-87.3365447504021f);
178
- x_f32x = svmax_f32_m(predicate_f32x, svmin_f32_m(predicate_f32x, x_f32x, max_x_f32x), min_x_f32x);
178
+ x_f32x = svmax_f32_m(predicate_b32x, svmin_f32_m(predicate_b32x, x_f32x, max_x_f32x), min_x_f32x);
179
179
 
180
180
  // n = round(x / ln(2))
181
- svfloat32_t n_f32x = svrintn_f32_m(svundef_f32(), predicate_f32x, svmul_f32_m(predicate_f32x, x_f32x, log2e_f32x));
181
+ svfloat32_t n_f32x = svrintn_f32_m(svundef_f32(), predicate_b32x, svmul_f32_m(predicate_b32x, x_f32x, log2e_f32x));
182
182
 
183
183
  // r = x - n × ln(2) using Cody-Waite for precision
184
- svfloat32_t r_f32x = svmsb_f32_m(predicate_f32x, n_f32x, ln2_hi_f32x, x_f32x);
185
- r_f32x = svmsb_f32_m(predicate_f32x, n_f32x, ln2_lo_f32x, r_f32x);
184
+ svfloat32_t r_f32x = svmsb_f32_m(predicate_b32x, n_f32x, ln2_high_f32x, x_f32x);
185
+ r_f32x = svmsb_f32_m(predicate_b32x, n_f32x, ln2_low_f32x, r_f32x);
186
186
 
187
187
  // Polynomial approximation for exp(r): degree 4
188
188
  // exp(r) ≈ 1 + r + r²/2 + r³/6 + r⁴/24
189
189
  svfloat32_t p_f32x = svdup_f32(4.1666666667e-2f); // 1/24
190
- p_f32x = svmad_f32_m(predicate_f32x, p_f32x, r_f32x, svdup_f32(1.6666666667e-1f)); // 1/6
191
- p_f32x = svmad_f32_m(predicate_f32x, p_f32x, r_f32x, svdup_f32(5.0000000000e-1f)); // 1/2
192
- p_f32x = svmad_f32_m(predicate_f32x, p_f32x, r_f32x, svdup_f32(1.0f)); // 1
193
- p_f32x = svmad_f32_m(predicate_f32x, p_f32x, r_f32x, svdup_f32(1.0f)); // 1
190
+ p_f32x = svmad_f32_m(predicate_b32x, p_f32x, r_f32x, svdup_f32(1.6666666667e-1f)); // 1/6
191
+ p_f32x = svmad_f32_m(predicate_b32x, p_f32x, r_f32x, svdup_f32(5.0000000000e-1f)); // 1/2
192
+ p_f32x = svmad_f32_m(predicate_b32x, p_f32x, r_f32x, svdup_f32(1.0f)); // 1
193
+ p_f32x = svmad_f32_m(predicate_b32x, p_f32x, r_f32x, svdup_f32(1.0f)); // 1
194
194
 
195
195
  // Reconstruct: exp(x) = 2ⁿ × exp(r)
196
196
  // 2ⁿ via IEEE 754 exponent manipulation
197
- svint32_t n_i32x = svcvt_s32_f32_m(svundef_s32(), predicate_f32x, n_f32x);
198
- n_i32x = svadd_s32_m(predicate_f32x, n_i32x, svdup_s32(127));
199
- n_i32x = svlsl_n_s32_m(predicate_f32x, n_i32x, 23);
197
+ svint32_t n_i32x = svcvt_s32_f32_m(svundef_s32(), predicate_b32x, n_f32x);
198
+ n_i32x = svadd_s32_m(predicate_b32x, n_i32x, svdup_s32(127));
199
+ n_i32x = svlsl_n_s32_m(predicate_b32x, n_i32x, 23);
200
200
  svfloat32_t pow2n_f32x = svreinterpret_f32_s32(n_i32x);
201
201
 
202
- return svmul_f32_m(predicate_f32x, p_f32x, pow2n_f32x);
202
+ return svmul_f32_m(predicate_b32x, p_f32x, pow2n_f32x);
203
203
  }
204
204
 
205
205
  /**
206
206
  * @brief Degree-3 fast exp approximation. Max relative error ~0.5%.
207
207
  * Saves 1 FMA per call vs degree-4 nk_exp_f32_sve_.
208
208
  */
209
- NK_INTERNAL svfloat32_t nk_exp_fast_f32_sve_(svbool_t predicate_f32x, svfloat32_t x_f32x) __arm_streaming {
209
+ NK_INTERNAL svfloat32_t nk_exp_fast_f32_sve_(svbool_t predicate_b32x, svfloat32_t x_f32x) __arm_streaming {
210
210
  svfloat32_t log2e_f32x = svdup_f32(1.4426950408889634f);
211
- svfloat32_t ln2_hi_f32x = svdup_f32(0.693145751953125f);
212
- svfloat32_t ln2_lo_f32x = svdup_f32(1.42860682030941723212e-6f);
211
+ svfloat32_t ln2_high_f32x = svdup_f32(0.693145751953125f);
212
+ svfloat32_t ln2_low_f32x = svdup_f32(1.42860682030941723212e-6f);
213
213
 
214
214
  svfloat32_t max_x_f32x = svdup_f32(88.3762626647949f);
215
215
  svfloat32_t min_x_f32x = svdup_f32(-87.3365447504021f);
216
- x_f32x = svmax_f32_m(predicate_f32x, svmin_f32_m(predicate_f32x, x_f32x, max_x_f32x), min_x_f32x);
216
+ x_f32x = svmax_f32_m(predicate_b32x, svmin_f32_m(predicate_b32x, x_f32x, max_x_f32x), min_x_f32x);
217
217
 
218
- svfloat32_t n_f32x = svrintn_f32_m(svundef_f32(), predicate_f32x, svmul_f32_m(predicate_f32x, x_f32x, log2e_f32x));
219
- svfloat32_t r_f32x = svmsb_f32_m(predicate_f32x, n_f32x, ln2_hi_f32x, x_f32x);
220
- r_f32x = svmsb_f32_m(predicate_f32x, n_f32x, ln2_lo_f32x, r_f32x);
218
+ svfloat32_t n_f32x = svrintn_f32_m(svundef_f32(), predicate_b32x, svmul_f32_m(predicate_b32x, x_f32x, log2e_f32x));
219
+ svfloat32_t r_f32x = svmsb_f32_m(predicate_b32x, n_f32x, ln2_high_f32x, x_f32x);
220
+ r_f32x = svmsb_f32_m(predicate_b32x, n_f32x, ln2_low_f32x, r_f32x);
221
221
 
222
222
  // Degree-3: exp(r) ~ 1 + r + r^2/2 + r^3/6 (drop 1/24 term)
223
223
  svfloat32_t p_f32x = svdup_f32(1.6666666667e-1f); // 1/6
224
- p_f32x = svmad_f32_m(predicate_f32x, p_f32x, r_f32x, svdup_f32(5.0000000000e-1f)); // 1/2
225
- p_f32x = svmad_f32_m(predicate_f32x, p_f32x, r_f32x, svdup_f32(1.0f)); // 1
226
- p_f32x = svmad_f32_m(predicate_f32x, p_f32x, r_f32x, svdup_f32(1.0f)); // 1
224
+ p_f32x = svmad_f32_m(predicate_b32x, p_f32x, r_f32x, svdup_f32(5.0000000000e-1f)); // 1/2
225
+ p_f32x = svmad_f32_m(predicate_b32x, p_f32x, r_f32x, svdup_f32(1.0f)); // 1
226
+ p_f32x = svmad_f32_m(predicate_b32x, p_f32x, r_f32x, svdup_f32(1.0f)); // 1
227
227
 
228
- svint32_t n_i32x = svcvt_s32_f32_m(svundef_s32(), predicate_f32x, n_f32x);
229
- n_i32x = svadd_s32_m(predicate_f32x, n_i32x, svdup_s32(127));
230
- n_i32x = svlsl_n_s32_m(predicate_f32x, n_i32x, 23);
228
+ svint32_t n_i32x = svcvt_s32_f32_m(svundef_s32(), predicate_b32x, n_f32x);
229
+ n_i32x = svadd_s32_m(predicate_b32x, n_i32x, svdup_s32(127));
230
+ n_i32x = svlsl_n_s32_m(predicate_b32x, n_i32x, 23);
231
231
  svfloat32_t pow2n_f32x = svreinterpret_f32_s32(n_i32x);
232
232
 
233
- return svmul_f32_m(predicate_f32x, p_f32x, pow2n_f32x);
233
+ return svmul_f32_m(predicate_b32x, p_f32x, pow2n_f32x);
234
234
  }
235
235
 
236
236
  NK_PUBLIC nk_size_t nk_attention_packed_kv_size_bf16_sme(nk_size_t num_kv_heads, nk_size_t head_dim,
@@ -410,8 +410,8 @@ __arm_locally_streaming __arm_new("za") static void nk_attention_bf16_sme_stream
410
410
  nk_size_t query_len, nk_size_t kv_len, nk_size_t head_dim, nk_size_t head_dim_padded, nk_size_t dim_tile_count,
411
411
  nk_f32_t scale) {
412
412
 
413
- svbool_t const predicate_all_f32x = svptrue_b32();
414
- svbool_t const predicate_all_f16x = svptrue_b16();
413
+ svbool_t const predicate_all_b32x = svptrue_b32();
414
+ svbool_t const predicate_all_b16x = svptrue_b16();
415
415
  nk_size_t const valid_query_count = (query_len < 16) ? query_len : 16;
416
416
 
417
417
  svfloat32_t row_max_f32x = svdup_f32(NK_F32_MIN);
@@ -420,12 +420,12 @@ __arm_locally_streaming __arm_new("za") static void nk_attention_bf16_sme_stream
420
420
  NK_ALIGN64 nk_f32_t output_accumulator[16 * 256];
421
421
  svfloat32_t zero_f32x = svdup_f32(0.0f);
422
422
  for (nk_size_t i = 0; i < 16 * head_dim_padded; i += svcntw()) {
423
- svst1_f32(predicate_all_f32x, output_accumulator + i, zero_f32x);
423
+ svst1_f32(predicate_all_b32x, output_accumulator + i, zero_f32x);
424
424
  }
425
425
 
426
426
  nk_size_t kv_block_index = 0;
427
427
  nk_size_t kv_start = 0;
428
- svbool_t const batch_predicate_f32x = svwhilelt_b32(0u, 16u);
428
+ svbool_t const batch_predicate_b32x = svwhilelt_b32(0u, 16u);
429
429
 
430
430
  nk_size_t const k_depth_step_count = head_dim_padded / 2;
431
431
 
@@ -434,11 +434,11 @@ __arm_locally_streaming __arm_new("za") static void nk_attention_bf16_sme_stream
434
434
  for (nk_size_t batch = 0; batch < head_dim_padded / 32; batch++) {
435
435
  svzero_mask_za(nk_sme_zero_za32_tile_0_);
436
436
  for (nk_size_t query_index = 0; query_index < valid_query_count; query_index++)
437
- svld1_hor_za32(0, query_index, batch_predicate_f32x,
437
+ svld1_hor_za32(0, query_index, batch_predicate_b32x,
438
438
  (nk_f32_t const *)(q + query_index * head_dim + batch * 32));
439
439
  for (nk_size_t step = 0; step < 16; step++)
440
- svst1_f32(predicate_all_f32x, queries_transposed + (batch * 16 + step) * 16,
441
- svread_ver_za32_f32_m(svdup_f32(0), predicate_all_f32x, 0, step));
440
+ svst1_f32(predicate_all_b32x, queries_transposed + (batch * 16 + step) * 16,
441
+ svread_ver_za32_f32_m(svdup_f32(0), predicate_all_b32x, 0, step));
442
442
  }
443
443
 
444
444
  // Bc=32 main loop (prefill only, skipped for decode)
@@ -447,14 +447,17 @@ __arm_locally_streaming __arm_new("za") static void nk_attention_bf16_sme_stream
447
447
  // Q×K^T: pure memory→BFMOPA, no ZA staging for Q or K
448
448
  svzero_mask_za(nk_sme_zero_za32_tile_2_);
449
449
  svzero_mask_za(nk_sme_zero_za32_tile_3_);
450
- nk_bf16_t const *keys_block_lower = k + kv_block_index * k_depth_step_count * 32;
451
- nk_bf16_t const *keys_block_upper = k + (kv_block_index + 1) * k_depth_step_count * 32;
450
+ nk_bf16_t const *keys_block_low = k + kv_block_index * k_depth_step_count * 32;
451
+ nk_bf16_t const *keys_block_high = k + (kv_block_index + 1) * k_depth_step_count * 32;
452
452
  for (nk_size_t step = 0; step < k_depth_step_count; step++) {
453
- svbfloat16_t zn = svreinterpret_bf16_f32(svld1_f32(predicate_all_f32x, queries_transposed + step * 16));
454
- svbfloat16_t zm0 = svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(keys_block_lower + step * 32));
455
- svbfloat16_t zm1 = svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(keys_block_upper + step * 32));
456
- svmopa_za32_bf16_m(2, predicate_all_f32x, predicate_all_f32x, zn, zm0);
457
- svmopa_za32_bf16_m(3, predicate_all_f32x, predicate_all_f32x, zn, zm1);
453
+ svbfloat16_t zn_bf16x = svreinterpret_bf16_f32(
454
+ svld1_f32(predicate_all_b32x, queries_transposed + step * 16));
455
+ svbfloat16_t zm0_bf16x = svld1_bf16(predicate_all_b16x,
456
+ (bfloat16_t const *)(keys_block_low + step * 32));
457
+ svbfloat16_t zm1_bf16x = svld1_bf16(predicate_all_b16x,
458
+ (bfloat16_t const *)(keys_block_high + step * 32));
459
+ svmopa_za32_bf16_m(2, predicate_all_b32x, predicate_all_b32x, zn_bf16x, zm0_bf16x);
460
+ svmopa_za32_bf16_m(3, predicate_all_b32x, predicate_all_b32x, zn_bf16x, zm1_bf16x);
458
461
  }
459
462
 
460
463
  // Pass 1: Column-wise max (read ZA2/ZA3 columns vertically)
@@ -462,26 +465,26 @@ __arm_locally_streaming __arm_new("za") static void nk_attention_bf16_sme_stream
462
465
  svfloat32_t block_max_f32x = svdup_f32(NK_F32_MIN);
463
466
  for (nk_size_t column_index = 0; column_index < 16; column_index++) {
464
467
  svfloat32_t score_column_f32x = svmul_f32_x(
465
- predicate_all_f32x, svread_ver_za32_f32_m(svdup_f32(0), predicate_all_f32x, 2, column_index),
468
+ predicate_all_b32x, svread_ver_za32_f32_m(svdup_f32(0), predicate_all_b32x, 2, column_index),
466
469
  scale_f32x);
467
- block_max_f32x = svmax_f32_x(predicate_all_f32x, block_max_f32x, score_column_f32x);
470
+ block_max_f32x = svmax_f32_x(predicate_all_b32x, block_max_f32x, score_column_f32x);
468
471
  }
469
472
  for (nk_size_t column_index = 0; column_index < 16; column_index++) {
470
473
  svfloat32_t score_column_f32x = svmul_f32_x(
471
- predicate_all_f32x, svread_ver_za32_f32_m(svdup_f32(0), predicate_all_f32x, 3, column_index),
474
+ predicate_all_b32x, svread_ver_za32_f32_m(svdup_f32(0), predicate_all_b32x, 3, column_index),
472
475
  scale_f32x);
473
- block_max_f32x = svmax_f32_x(predicate_all_f32x, block_max_f32x, score_column_f32x);
476
+ block_max_f32x = svmax_f32_x(predicate_all_b32x, block_max_f32x, score_column_f32x);
474
477
  }
475
478
 
476
479
  // Softmax correction (fully vectorized)
477
- svfloat32_t new_max_f32x = svmax_f32_x(predicate_all_f32x, row_max_f32x, block_max_f32x);
480
+ svfloat32_t new_max_f32x = svmax_f32_x(predicate_all_b32x, row_max_f32x, block_max_f32x);
478
481
  svfloat32_t correction_f32x = nk_exp_fast_f32_sve_(
479
- predicate_all_f32x, svsub_f32_x(predicate_all_f32x, row_max_f32x, new_max_f32x));
480
- svbool_t max_changed = svcmplt_f32(predicate_all_f32x, correction_f32x, svdup_f32(1.0f));
481
- nk_u32_t max_was_updated = svptest_any(predicate_all_f32x, max_changed) ? 1 : 0;
482
- if (max_was_updated) row_sum_f32x = svmul_f32_x(predicate_all_f32x, row_sum_f32x, correction_f32x);
482
+ predicate_all_b32x, svsub_f32_x(predicate_all_b32x, row_max_f32x, new_max_f32x));
483
+ svbool_t max_changed_b32x = svcmplt_f32(predicate_all_b32x, correction_f32x, svdup_f32(1.0f));
484
+ nk_u32_t max_was_updated = svptest_any(predicate_all_b32x, max_changed_b32x) ? 1 : 0;
485
+ if (max_was_updated) row_sum_f32x = svmul_f32_x(predicate_all_b32x, row_sum_f32x, correction_f32x);
483
486
  NK_ALIGN64 nk_f32_t corrections[16];
484
- svst1_f32(predicate_all_f32x, corrections, correction_f32x);
487
+ svst1_f32(predicate_all_b32x, corrections, correction_f32x);
485
488
 
486
489
  // Pass 2: Column-wise exp + fused P write + sum
487
490
  svfloat32_t sum_delta_f32x = svdup_f32(0.0f);
@@ -489,91 +492,91 @@ __arm_locally_streaming __arm_new("za") static void nk_attention_bf16_sme_stream
489
492
  // ZA2 columns in pairs → ZA0 columns 0-7
490
493
  for (nk_size_t column_index = 0; column_index < 16; column_index += 2) {
491
494
  svfloat32_t score_even_f32x = svmul_f32_x(
492
- predicate_all_f32x, svread_ver_za32_f32_m(svdup_f32(0), predicate_all_f32x, 2, column_index),
495
+ predicate_all_b32x, svread_ver_za32_f32_m(svdup_f32(0), predicate_all_b32x, 2, column_index),
493
496
  scale_f32x);
494
497
  svfloat32_t score_odd_f32x = svmul_f32_x(
495
- predicate_all_f32x, svread_ver_za32_f32_m(svdup_f32(0), predicate_all_f32x, 2, column_index + 1),
498
+ predicate_all_b32x, svread_ver_za32_f32_m(svdup_f32(0), predicate_all_b32x, 2, column_index + 1),
496
499
  scale_f32x);
497
500
  svfloat32_t weight_even_f32x = nk_exp_fast_f32_sve_(
498
- predicate_all_f32x, svsub_f32_x(predicate_all_f32x, score_even_f32x, new_max_f32x));
501
+ predicate_all_b32x, svsub_f32_x(predicate_all_b32x, score_even_f32x, new_max_f32x));
499
502
  svfloat32_t weight_odd_f32x = nk_exp_fast_f32_sve_(
500
- predicate_all_f32x, svsub_f32_x(predicate_all_f32x, score_odd_f32x, new_max_f32x));
501
- sum_delta_f32x = svadd_f32_x(predicate_all_f32x, sum_delta_f32x, weight_even_f32x);
502
- sum_delta_f32x = svadd_f32_x(predicate_all_f32x, sum_delta_f32x, weight_odd_f32x);
503
- svbfloat16_t weight_pair_bf16 = svzip1_bf16(nk_f32_to_bf16_sve_(predicate_all_f32x, weight_even_f32x),
504
- nk_f32_to_bf16_sve_(predicate_all_f32x, weight_odd_f32x));
505
- svwrite_ver_za32_f32_m(0, column_index / 2, predicate_all_f32x,
506
- svreinterpret_f32_bf16(weight_pair_bf16));
503
+ predicate_all_b32x, svsub_f32_x(predicate_all_b32x, score_odd_f32x, new_max_f32x));
504
+ sum_delta_f32x = svadd_f32_x(predicate_all_b32x, sum_delta_f32x, weight_even_f32x);
505
+ sum_delta_f32x = svadd_f32_x(predicate_all_b32x, sum_delta_f32x, weight_odd_f32x);
506
+ svbfloat16_t weight_pair_bf16x = svzip1_bf16(nk_f32_to_bf16_sve_(predicate_all_b32x, weight_even_f32x),
507
+ nk_f32_to_bf16_sve_(predicate_all_b32x, weight_odd_f32x));
508
+ svwrite_ver_za32_f32_m(0, column_index / 2, predicate_all_b32x,
509
+ svreinterpret_f32_bf16(weight_pair_bf16x));
507
510
  }
508
511
  // ZA3 columns in pairs → ZA0 columns 8-15
509
512
  for (nk_size_t column_index = 0; column_index < 16; column_index += 2) {
510
513
  svfloat32_t score_even_f32x = svmul_f32_x(
511
- predicate_all_f32x, svread_ver_za32_f32_m(svdup_f32(0), predicate_all_f32x, 3, column_index),
514
+ predicate_all_b32x, svread_ver_za32_f32_m(svdup_f32(0), predicate_all_b32x, 3, column_index),
512
515
  scale_f32x);
513
516
  svfloat32_t score_odd_f32x = svmul_f32_x(
514
- predicate_all_f32x, svread_ver_za32_f32_m(svdup_f32(0), predicate_all_f32x, 3, column_index + 1),
517
+ predicate_all_b32x, svread_ver_za32_f32_m(svdup_f32(0), predicate_all_b32x, 3, column_index + 1),
515
518
  scale_f32x);
516
519
  svfloat32_t weight_even_f32x = nk_exp_fast_f32_sve_(
517
- predicate_all_f32x, svsub_f32_x(predicate_all_f32x, score_even_f32x, new_max_f32x));
520
+ predicate_all_b32x, svsub_f32_x(predicate_all_b32x, score_even_f32x, new_max_f32x));
518
521
  svfloat32_t weight_odd_f32x = nk_exp_fast_f32_sve_(
519
- predicate_all_f32x, svsub_f32_x(predicate_all_f32x, score_odd_f32x, new_max_f32x));
520
- sum_delta_f32x = svadd_f32_x(predicate_all_f32x, sum_delta_f32x, weight_even_f32x);
521
- sum_delta_f32x = svadd_f32_x(predicate_all_f32x, sum_delta_f32x, weight_odd_f32x);
522
- svbfloat16_t weight_pair_bf16 = svzip1_bf16(nk_f32_to_bf16_sve_(predicate_all_f32x, weight_even_f32x),
523
- nk_f32_to_bf16_sve_(predicate_all_f32x, weight_odd_f32x));
524
- svwrite_ver_za32_f32_m(0, 8 + column_index / 2, predicate_all_f32x,
525
- svreinterpret_f32_bf16(weight_pair_bf16));
522
+ predicate_all_b32x, svsub_f32_x(predicate_all_b32x, score_odd_f32x, new_max_f32x));
523
+ sum_delta_f32x = svadd_f32_x(predicate_all_b32x, sum_delta_f32x, weight_even_f32x);
524
+ sum_delta_f32x = svadd_f32_x(predicate_all_b32x, sum_delta_f32x, weight_odd_f32x);
525
+ svbfloat16_t weight_pair_bf16x = svzip1_bf16(nk_f32_to_bf16_sve_(predicate_all_b32x, weight_even_f32x),
526
+ nk_f32_to_bf16_sve_(predicate_all_b32x, weight_odd_f32x));
527
+ svwrite_ver_za32_f32_m(0, 8 + column_index / 2, predicate_all_b32x,
528
+ svreinterpret_f32_bf16(weight_pair_bf16x));
526
529
  }
527
- row_sum_f32x = svadd_f32_x(predicate_all_f32x, row_sum_f32x, sum_delta_f32x);
530
+ row_sum_f32x = svadd_f32_x(predicate_all_b32x, row_sum_f32x, sum_delta_f32x);
528
531
  row_max_f32x = new_max_f32x;
529
532
 
530
533
  // Extract P columns from ZA0
531
534
  svbfloat16_t probability_column_0_f32x = svreinterpret_bf16_f32(
532
- svread_ver_za32_f32_m(svdup_f32(0), predicate_all_f32x, 0, 0));
535
+ svread_ver_za32_f32_m(svdup_f32(0), predicate_all_b32x, 0, 0));
533
536
  svbfloat16_t probability_column_1_f32x = svreinterpret_bf16_f32(
534
- svread_ver_za32_f32_m(svdup_f32(0), predicate_all_f32x, 0, 1));
537
+ svread_ver_za32_f32_m(svdup_f32(0), predicate_all_b32x, 0, 1));
535
538
  svbfloat16_t probability_column_2_f32x = svreinterpret_bf16_f32(
536
- svread_ver_za32_f32_m(svdup_f32(0), predicate_all_f32x, 0, 2));
539
+ svread_ver_za32_f32_m(svdup_f32(0), predicate_all_b32x, 0, 2));
537
540
  svbfloat16_t probability_column_3_f32x = svreinterpret_bf16_f32(
538
- svread_ver_za32_f32_m(svdup_f32(0), predicate_all_f32x, 0, 3));
541
+ svread_ver_za32_f32_m(svdup_f32(0), predicate_all_b32x, 0, 3));
539
542
  svbfloat16_t probability_column_4_f32x = svreinterpret_bf16_f32(
540
- svread_ver_za32_f32_m(svdup_f32(0), predicate_all_f32x, 0, 4));
543
+ svread_ver_za32_f32_m(svdup_f32(0), predicate_all_b32x, 0, 4));
541
544
  svbfloat16_t probability_column_5_f32x = svreinterpret_bf16_f32(
542
- svread_ver_za32_f32_m(svdup_f32(0), predicate_all_f32x, 0, 5));
545
+ svread_ver_za32_f32_m(svdup_f32(0), predicate_all_b32x, 0, 5));
543
546
  svbfloat16_t probability_column_6_f32x = svreinterpret_bf16_f32(
544
- svread_ver_za32_f32_m(svdup_f32(0), predicate_all_f32x, 0, 6));
547
+ svread_ver_za32_f32_m(svdup_f32(0), predicate_all_b32x, 0, 6));
545
548
  svbfloat16_t probability_column_7_f32x = svreinterpret_bf16_f32(
546
- svread_ver_za32_f32_m(svdup_f32(0), predicate_all_f32x, 0, 7));
549
+ svread_ver_za32_f32_m(svdup_f32(0), predicate_all_b32x, 0, 7));
547
550
  svbfloat16_t probability_column_8_f32x = svreinterpret_bf16_f32(
548
- svread_ver_za32_f32_m(svdup_f32(0), predicate_all_f32x, 0, 8));
551
+ svread_ver_za32_f32_m(svdup_f32(0), predicate_all_b32x, 0, 8));
549
552
  svbfloat16_t probability_column_9_f32x = svreinterpret_bf16_f32(
550
- svread_ver_za32_f32_m(svdup_f32(0), predicate_all_f32x, 0, 9));
553
+ svread_ver_za32_f32_m(svdup_f32(0), predicate_all_b32x, 0, 9));
551
554
  svbfloat16_t probability_column_10_f32x = svreinterpret_bf16_f32(
552
- svread_ver_za32_f32_m(svdup_f32(0), predicate_all_f32x, 0, 10));
555
+ svread_ver_za32_f32_m(svdup_f32(0), predicate_all_b32x, 0, 10));
553
556
  svbfloat16_t probability_column_11_f32x = svreinterpret_bf16_f32(
554
- svread_ver_za32_f32_m(svdup_f32(0), predicate_all_f32x, 0, 11));
557
+ svread_ver_za32_f32_m(svdup_f32(0), predicate_all_b32x, 0, 11));
555
558
  svbfloat16_t probability_column_12_f32x = svreinterpret_bf16_f32(
556
- svread_ver_za32_f32_m(svdup_f32(0), predicate_all_f32x, 0, 12));
559
+ svread_ver_za32_f32_m(svdup_f32(0), predicate_all_b32x, 0, 12));
557
560
  svbfloat16_t probability_column_13_f32x = svreinterpret_bf16_f32(
558
- svread_ver_za32_f32_m(svdup_f32(0), predicate_all_f32x, 0, 13));
561
+ svread_ver_za32_f32_m(svdup_f32(0), predicate_all_b32x, 0, 13));
559
562
  svbfloat16_t probability_column_14_f32x = svreinterpret_bf16_f32(
560
- svread_ver_za32_f32_m(svdup_f32(0), predicate_all_f32x, 0, 14));
563
+ svread_ver_za32_f32_m(svdup_f32(0), predicate_all_b32x, 0, 14));
561
564
  svbfloat16_t probability_column_15_f32x = svreinterpret_bf16_f32(
562
- svread_ver_za32_f32_m(svdup_f32(0), predicate_all_f32x, 0, 15));
565
+ svread_ver_za32_f32_m(svdup_f32(0), predicate_all_b32x, 0, 15));
563
566
 
564
567
  // Pre-apply correction once before P×V
565
- svbool_t query_predicate_f16x = svwhilelt_b16_u64(0u, valid_query_count * 2);
566
- nk_bf16_t const *values_block_lower = v_packed + kv_block_index * dim_tile_count * 8 * 32;
567
- nk_bf16_t const *values_block_upper = v_packed + (kv_block_index + 1) * dim_tile_count * 8 * 32;
568
+ svbool_t query_predicate_b16x = svwhilelt_b16_u64(0u, valid_query_count * 2);
569
+ nk_bf16_t const *values_block_low = v_packed + kv_block_index * dim_tile_count * 8 * 32;
570
+ nk_bf16_t const *values_block_high = v_packed + (kv_block_index + 1) * dim_tile_count * 8 * 32;
568
571
 
569
572
  if (max_was_updated) {
570
573
  for (nk_size_t query_index = 0; query_index < valid_query_count; query_index++) {
571
574
  svfloat32_t correction_scalar_f32x = svdup_f32(corrections[query_index]);
572
575
  for (nk_size_t dim_offset = 0; dim_offset < head_dim_padded; dim_offset += 16)
573
576
  svst1_f32(
574
- predicate_all_f32x, output_accumulator + query_index * head_dim_padded + dim_offset,
575
- svmul_f32_x(predicate_all_f32x,
576
- svld1_f32(predicate_all_f32x,
577
+ predicate_all_b32x, output_accumulator + query_index * head_dim_padded + dim_offset,
578
+ svmul_f32_x(predicate_all_b32x,
579
+ svld1_f32(predicate_all_b32x,
577
580
  output_accumulator + query_index * head_dim_padded + dim_offset),
578
581
  correction_scalar_f32x));
579
582
  }
@@ -584,284 +587,284 @@ __arm_locally_streaming __arm_new("za") static void nk_attention_bf16_sme_stream
584
587
  for (; dim_tile + 4 <= dim_tile_count; dim_tile += 4) {
585
588
  svzero_za();
586
589
  // Block0: 8 depth steps (KV positions 0-15)
587
- svmopa_za32_bf16_m(0, query_predicate_f16x, predicate_all_f16x, probability_column_0_f32x,
588
- svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(values_block_lower +
589
- ((dim_tile + 0) * 8 + 0) * 32)));
590
- svmopa_za32_bf16_m(1, query_predicate_f16x, predicate_all_f16x, probability_column_0_f32x,
591
- svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(values_block_lower +
592
- ((dim_tile + 1) * 8 + 0) * 32)));
593
- svmopa_za32_bf16_m(2, query_predicate_f16x, predicate_all_f16x, probability_column_0_f32x,
594
- svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(values_block_lower +
595
- ((dim_tile + 2) * 8 + 0) * 32)));
596
- svmopa_za32_bf16_m(3, query_predicate_f16x, predicate_all_f16x, probability_column_0_f32x,
597
- svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(values_block_lower +
598
- ((dim_tile + 3) * 8 + 0) * 32)));
599
- svmopa_za32_bf16_m(0, query_predicate_f16x, predicate_all_f16x, probability_column_1_f32x,
600
- svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(values_block_lower +
601
- ((dim_tile + 0) * 8 + 1) * 32)));
602
- svmopa_za32_bf16_m(1, query_predicate_f16x, predicate_all_f16x, probability_column_1_f32x,
603
- svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(values_block_lower +
604
- ((dim_tile + 1) * 8 + 1) * 32)));
605
- svmopa_za32_bf16_m(2, query_predicate_f16x, predicate_all_f16x, probability_column_1_f32x,
606
- svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(values_block_lower +
607
- ((dim_tile + 2) * 8 + 1) * 32)));
608
- svmopa_za32_bf16_m(3, query_predicate_f16x, predicate_all_f16x, probability_column_1_f32x,
609
- svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(values_block_lower +
610
- ((dim_tile + 3) * 8 + 1) * 32)));
611
- svmopa_za32_bf16_m(0, query_predicate_f16x, predicate_all_f16x, probability_column_2_f32x,
612
- svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(values_block_lower +
613
- ((dim_tile + 0) * 8 + 2) * 32)));
614
- svmopa_za32_bf16_m(1, query_predicate_f16x, predicate_all_f16x, probability_column_2_f32x,
615
- svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(values_block_lower +
616
- ((dim_tile + 1) * 8 + 2) * 32)));
617
- svmopa_za32_bf16_m(2, query_predicate_f16x, predicate_all_f16x, probability_column_2_f32x,
618
- svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(values_block_lower +
619
- ((dim_tile + 2) * 8 + 2) * 32)));
620
- svmopa_za32_bf16_m(3, query_predicate_f16x, predicate_all_f16x, probability_column_2_f32x,
621
- svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(values_block_lower +
622
- ((dim_tile + 3) * 8 + 2) * 32)));
623
- svmopa_za32_bf16_m(0, query_predicate_f16x, predicate_all_f16x, probability_column_3_f32x,
624
- svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(values_block_lower +
625
- ((dim_tile + 0) * 8 + 3) * 32)));
626
- svmopa_za32_bf16_m(1, query_predicate_f16x, predicate_all_f16x, probability_column_3_f32x,
627
- svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(values_block_lower +
628
- ((dim_tile + 1) * 8 + 3) * 32)));
629
- svmopa_za32_bf16_m(2, query_predicate_f16x, predicate_all_f16x, probability_column_3_f32x,
630
- svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(values_block_lower +
631
- ((dim_tile + 2) * 8 + 3) * 32)));
632
- svmopa_za32_bf16_m(3, query_predicate_f16x, predicate_all_f16x, probability_column_3_f32x,
633
- svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(values_block_lower +
634
- ((dim_tile + 3) * 8 + 3) * 32)));
635
- svmopa_za32_bf16_m(0, query_predicate_f16x, predicate_all_f16x, probability_column_4_f32x,
636
- svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(values_block_lower +
637
- ((dim_tile + 0) * 8 + 4) * 32)));
638
- svmopa_za32_bf16_m(1, query_predicate_f16x, predicate_all_f16x, probability_column_4_f32x,
639
- svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(values_block_lower +
640
- ((dim_tile + 1) * 8 + 4) * 32)));
641
- svmopa_za32_bf16_m(2, query_predicate_f16x, predicate_all_f16x, probability_column_4_f32x,
642
- svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(values_block_lower +
643
- ((dim_tile + 2) * 8 + 4) * 32)));
644
- svmopa_za32_bf16_m(3, query_predicate_f16x, predicate_all_f16x, probability_column_4_f32x,
645
- svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(values_block_lower +
646
- ((dim_tile + 3) * 8 + 4) * 32)));
647
- svmopa_za32_bf16_m(0, query_predicate_f16x, predicate_all_f16x, probability_column_5_f32x,
648
- svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(values_block_lower +
649
- ((dim_tile + 0) * 8 + 5) * 32)));
650
- svmopa_za32_bf16_m(1, query_predicate_f16x, predicate_all_f16x, probability_column_5_f32x,
651
- svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(values_block_lower +
652
- ((dim_tile + 1) * 8 + 5) * 32)));
653
- svmopa_za32_bf16_m(2, query_predicate_f16x, predicate_all_f16x, probability_column_5_f32x,
654
- svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(values_block_lower +
655
- ((dim_tile + 2) * 8 + 5) * 32)));
656
- svmopa_za32_bf16_m(3, query_predicate_f16x, predicate_all_f16x, probability_column_5_f32x,
657
- svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(values_block_lower +
658
- ((dim_tile + 3) * 8 + 5) * 32)));
659
- svmopa_za32_bf16_m(0, query_predicate_f16x, predicate_all_f16x, probability_column_6_f32x,
660
- svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(values_block_lower +
661
- ((dim_tile + 0) * 8 + 6) * 32)));
662
- svmopa_za32_bf16_m(1, query_predicate_f16x, predicate_all_f16x, probability_column_6_f32x,
663
- svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(values_block_lower +
664
- ((dim_tile + 1) * 8 + 6) * 32)));
665
- svmopa_za32_bf16_m(2, query_predicate_f16x, predicate_all_f16x, probability_column_6_f32x,
666
- svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(values_block_lower +
667
- ((dim_tile + 2) * 8 + 6) * 32)));
668
- svmopa_za32_bf16_m(3, query_predicate_f16x, predicate_all_f16x, probability_column_6_f32x,
669
- svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(values_block_lower +
670
- ((dim_tile + 3) * 8 + 6) * 32)));
671
- svmopa_za32_bf16_m(0, query_predicate_f16x, predicate_all_f16x, probability_column_7_f32x,
672
- svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(values_block_lower +
673
- ((dim_tile + 0) * 8 + 7) * 32)));
674
- svmopa_za32_bf16_m(1, query_predicate_f16x, predicate_all_f16x, probability_column_7_f32x,
675
- svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(values_block_lower +
676
- ((dim_tile + 1) * 8 + 7) * 32)));
677
- svmopa_za32_bf16_m(2, query_predicate_f16x, predicate_all_f16x, probability_column_7_f32x,
678
- svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(values_block_lower +
679
- ((dim_tile + 2) * 8 + 7) * 32)));
680
- svmopa_za32_bf16_m(3, query_predicate_f16x, predicate_all_f16x, probability_column_7_f32x,
681
- svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(values_block_lower +
682
- ((dim_tile + 3) * 8 + 7) * 32)));
590
+ svmopa_za32_bf16_m(0, query_predicate_b16x, predicate_all_b16x, probability_column_0_f32x,
591
+ svld1_bf16(predicate_all_b16x,
592
+ (bfloat16_t const *)(values_block_low + ((dim_tile + 0) * 8 + 0) * 32)));
593
+ svmopa_za32_bf16_m(1, query_predicate_b16x, predicate_all_b16x, probability_column_0_f32x,
594
+ svld1_bf16(predicate_all_b16x,
595
+ (bfloat16_t const *)(values_block_low + ((dim_tile + 1) * 8 + 0) * 32)));
596
+ svmopa_za32_bf16_m(2, query_predicate_b16x, predicate_all_b16x, probability_column_0_f32x,
597
+ svld1_bf16(predicate_all_b16x,
598
+ (bfloat16_t const *)(values_block_low + ((dim_tile + 2) * 8 + 0) * 32)));
599
+ svmopa_za32_bf16_m(3, query_predicate_b16x, predicate_all_b16x, probability_column_0_f32x,
600
+ svld1_bf16(predicate_all_b16x,
601
+ (bfloat16_t const *)(values_block_low + ((dim_tile + 3) * 8 + 0) * 32)));
602
+ svmopa_za32_bf16_m(0, query_predicate_b16x, predicate_all_b16x, probability_column_1_f32x,
603
+ svld1_bf16(predicate_all_b16x,
604
+ (bfloat16_t const *)(values_block_low + ((dim_tile + 0) * 8 + 1) * 32)));
605
+ svmopa_za32_bf16_m(1, query_predicate_b16x, predicate_all_b16x, probability_column_1_f32x,
606
+ svld1_bf16(predicate_all_b16x,
607
+ (bfloat16_t const *)(values_block_low + ((dim_tile + 1) * 8 + 1) * 32)));
608
+ svmopa_za32_bf16_m(2, query_predicate_b16x, predicate_all_b16x, probability_column_1_f32x,
609
+ svld1_bf16(predicate_all_b16x,
610
+ (bfloat16_t const *)(values_block_low + ((dim_tile + 2) * 8 + 1) * 32)));
611
+ svmopa_za32_bf16_m(3, query_predicate_b16x, predicate_all_b16x, probability_column_1_f32x,
612
+ svld1_bf16(predicate_all_b16x,
613
+ (bfloat16_t const *)(values_block_low + ((dim_tile + 3) * 8 + 1) * 32)));
614
+ svmopa_za32_bf16_m(0, query_predicate_b16x, predicate_all_b16x, probability_column_2_f32x,
615
+ svld1_bf16(predicate_all_b16x,
616
+ (bfloat16_t const *)(values_block_low + ((dim_tile + 0) * 8 + 2) * 32)));
617
+ svmopa_za32_bf16_m(1, query_predicate_b16x, predicate_all_b16x, probability_column_2_f32x,
618
+ svld1_bf16(predicate_all_b16x,
619
+ (bfloat16_t const *)(values_block_low + ((dim_tile + 1) * 8 + 2) * 32)));
620
+ svmopa_za32_bf16_m(2, query_predicate_b16x, predicate_all_b16x, probability_column_2_f32x,
621
+ svld1_bf16(predicate_all_b16x,
622
+ (bfloat16_t const *)(values_block_low + ((dim_tile + 2) * 8 + 2) * 32)));
623
+ svmopa_za32_bf16_m(3, query_predicate_b16x, predicate_all_b16x, probability_column_2_f32x,
624
+ svld1_bf16(predicate_all_b16x,
625
+ (bfloat16_t const *)(values_block_low + ((dim_tile + 3) * 8 + 2) * 32)));
626
+ svmopa_za32_bf16_m(0, query_predicate_b16x, predicate_all_b16x, probability_column_3_f32x,
627
+ svld1_bf16(predicate_all_b16x,
628
+ (bfloat16_t const *)(values_block_low + ((dim_tile + 0) * 8 + 3) * 32)));
629
+ svmopa_za32_bf16_m(1, query_predicate_b16x, predicate_all_b16x, probability_column_3_f32x,
630
+ svld1_bf16(predicate_all_b16x,
631
+ (bfloat16_t const *)(values_block_low + ((dim_tile + 1) * 8 + 3) * 32)));
632
+ svmopa_za32_bf16_m(2, query_predicate_b16x, predicate_all_b16x, probability_column_3_f32x,
633
+ svld1_bf16(predicate_all_b16x,
634
+ (bfloat16_t const *)(values_block_low + ((dim_tile + 2) * 8 + 3) * 32)));
635
+ svmopa_za32_bf16_m(3, query_predicate_b16x, predicate_all_b16x, probability_column_3_f32x,
636
+ svld1_bf16(predicate_all_b16x,
637
+ (bfloat16_t const *)(values_block_low + ((dim_tile + 3) * 8 + 3) * 32)));
638
+ svmopa_za32_bf16_m(0, query_predicate_b16x, predicate_all_b16x, probability_column_4_f32x,
639
+ svld1_bf16(predicate_all_b16x,
640
+ (bfloat16_t const *)(values_block_low + ((dim_tile + 0) * 8 + 4) * 32)));
641
+ svmopa_za32_bf16_m(1, query_predicate_b16x, predicate_all_b16x, probability_column_4_f32x,
642
+ svld1_bf16(predicate_all_b16x,
643
+ (bfloat16_t const *)(values_block_low + ((dim_tile + 1) * 8 + 4) * 32)));
644
+ svmopa_za32_bf16_m(2, query_predicate_b16x, predicate_all_b16x, probability_column_4_f32x,
645
+ svld1_bf16(predicate_all_b16x,
646
+ (bfloat16_t const *)(values_block_low + ((dim_tile + 2) * 8 + 4) * 32)));
647
+ svmopa_za32_bf16_m(3, query_predicate_b16x, predicate_all_b16x, probability_column_4_f32x,
648
+ svld1_bf16(predicate_all_b16x,
649
+ (bfloat16_t const *)(values_block_low + ((dim_tile + 3) * 8 + 4) * 32)));
650
+ svmopa_za32_bf16_m(0, query_predicate_b16x, predicate_all_b16x, probability_column_5_f32x,
651
+ svld1_bf16(predicate_all_b16x,
652
+ (bfloat16_t const *)(values_block_low + ((dim_tile + 0) * 8 + 5) * 32)));
653
+ svmopa_za32_bf16_m(1, query_predicate_b16x, predicate_all_b16x, probability_column_5_f32x,
654
+ svld1_bf16(predicate_all_b16x,
655
+ (bfloat16_t const *)(values_block_low + ((dim_tile + 1) * 8 + 5) * 32)));
656
+ svmopa_za32_bf16_m(2, query_predicate_b16x, predicate_all_b16x, probability_column_5_f32x,
657
+ svld1_bf16(predicate_all_b16x,
658
+ (bfloat16_t const *)(values_block_low + ((dim_tile + 2) * 8 + 5) * 32)));
659
+ svmopa_za32_bf16_m(3, query_predicate_b16x, predicate_all_b16x, probability_column_5_f32x,
660
+ svld1_bf16(predicate_all_b16x,
661
+ (bfloat16_t const *)(values_block_low + ((dim_tile + 3) * 8 + 5) * 32)));
662
+ svmopa_za32_bf16_m(0, query_predicate_b16x, predicate_all_b16x, probability_column_6_f32x,
663
+ svld1_bf16(predicate_all_b16x,
664
+ (bfloat16_t const *)(values_block_low + ((dim_tile + 0) * 8 + 6) * 32)));
665
+ svmopa_za32_bf16_m(1, query_predicate_b16x, predicate_all_b16x, probability_column_6_f32x,
666
+ svld1_bf16(predicate_all_b16x,
667
+ (bfloat16_t const *)(values_block_low + ((dim_tile + 1) * 8 + 6) * 32)));
668
+ svmopa_za32_bf16_m(2, query_predicate_b16x, predicate_all_b16x, probability_column_6_f32x,
669
+ svld1_bf16(predicate_all_b16x,
670
+ (bfloat16_t const *)(values_block_low + ((dim_tile + 2) * 8 + 6) * 32)));
671
+ svmopa_za32_bf16_m(3, query_predicate_b16x, predicate_all_b16x, probability_column_6_f32x,
672
+ svld1_bf16(predicate_all_b16x,
673
+ (bfloat16_t const *)(values_block_low + ((dim_tile + 3) * 8 + 6) * 32)));
674
+ svmopa_za32_bf16_m(0, query_predicate_b16x, predicate_all_b16x, probability_column_7_f32x,
675
+ svld1_bf16(predicate_all_b16x,
676
+ (bfloat16_t const *)(values_block_low + ((dim_tile + 0) * 8 + 7) * 32)));
677
+ svmopa_za32_bf16_m(1, query_predicate_b16x, predicate_all_b16x, probability_column_7_f32x,
678
+ svld1_bf16(predicate_all_b16x,
679
+ (bfloat16_t const *)(values_block_low + ((dim_tile + 1) * 8 + 7) * 32)));
680
+ svmopa_za32_bf16_m(2, query_predicate_b16x, predicate_all_b16x, probability_column_7_f32x,
681
+ svld1_bf16(predicate_all_b16x,
682
+ (bfloat16_t const *)(values_block_low + ((dim_tile + 2) * 8 + 7) * 32)));
683
+ svmopa_za32_bf16_m(3, query_predicate_b16x, predicate_all_b16x, probability_column_7_f32x,
684
+ svld1_bf16(predicate_all_b16x,
685
+ (bfloat16_t const *)(values_block_low + ((dim_tile + 3) * 8 + 7) * 32)));
683
686
  // Block1: 8 depth steps (KV positions 16-31)
684
- svmopa_za32_bf16_m(0, query_predicate_f16x, predicate_all_f16x, probability_column_8_f32x,
685
- svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(values_block_upper +
686
- ((dim_tile + 0) * 8 + 0) * 32)));
687
- svmopa_za32_bf16_m(1, query_predicate_f16x, predicate_all_f16x, probability_column_8_f32x,
688
- svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(values_block_upper +
689
- ((dim_tile + 1) * 8 + 0) * 32)));
690
- svmopa_za32_bf16_m(2, query_predicate_f16x, predicate_all_f16x, probability_column_8_f32x,
691
- svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(values_block_upper +
692
- ((dim_tile + 2) * 8 + 0) * 32)));
693
- svmopa_za32_bf16_m(3, query_predicate_f16x, predicate_all_f16x, probability_column_8_f32x,
694
- svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(values_block_upper +
695
- ((dim_tile + 3) * 8 + 0) * 32)));
696
- svmopa_za32_bf16_m(0, query_predicate_f16x, predicate_all_f16x, probability_column_9_f32x,
697
- svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(values_block_upper +
698
- ((dim_tile + 0) * 8 + 1) * 32)));
699
- svmopa_za32_bf16_m(1, query_predicate_f16x, predicate_all_f16x, probability_column_9_f32x,
700
- svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(values_block_upper +
701
- ((dim_tile + 1) * 8 + 1) * 32)));
702
- svmopa_za32_bf16_m(2, query_predicate_f16x, predicate_all_f16x, probability_column_9_f32x,
703
- svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(values_block_upper +
704
- ((dim_tile + 2) * 8 + 1) * 32)));
705
- svmopa_za32_bf16_m(3, query_predicate_f16x, predicate_all_f16x, probability_column_9_f32x,
706
- svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(values_block_upper +
707
- ((dim_tile + 3) * 8 + 1) * 32)));
708
- svmopa_za32_bf16_m(0, query_predicate_f16x, predicate_all_f16x, probability_column_10_f32x,
709
- svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(values_block_upper +
710
- ((dim_tile + 0) * 8 + 2) * 32)));
711
- svmopa_za32_bf16_m(1, query_predicate_f16x, predicate_all_f16x, probability_column_10_f32x,
712
- svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(values_block_upper +
713
- ((dim_tile + 1) * 8 + 2) * 32)));
714
- svmopa_za32_bf16_m(2, query_predicate_f16x, predicate_all_f16x, probability_column_10_f32x,
715
- svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(values_block_upper +
716
- ((dim_tile + 2) * 8 + 2) * 32)));
717
- svmopa_za32_bf16_m(3, query_predicate_f16x, predicate_all_f16x, probability_column_10_f32x,
718
- svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(values_block_upper +
719
- ((dim_tile + 3) * 8 + 2) * 32)));
720
- svmopa_za32_bf16_m(0, query_predicate_f16x, predicate_all_f16x, probability_column_11_f32x,
721
- svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(values_block_upper +
722
- ((dim_tile + 0) * 8 + 3) * 32)));
723
- svmopa_za32_bf16_m(1, query_predicate_f16x, predicate_all_f16x, probability_column_11_f32x,
724
- svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(values_block_upper +
725
- ((dim_tile + 1) * 8 + 3) * 32)));
726
- svmopa_za32_bf16_m(2, query_predicate_f16x, predicate_all_f16x, probability_column_11_f32x,
727
- svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(values_block_upper +
728
- ((dim_tile + 2) * 8 + 3) * 32)));
729
- svmopa_za32_bf16_m(3, query_predicate_f16x, predicate_all_f16x, probability_column_11_f32x,
730
- svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(values_block_upper +
731
- ((dim_tile + 3) * 8 + 3) * 32)));
732
- svmopa_za32_bf16_m(0, query_predicate_f16x, predicate_all_f16x, probability_column_12_f32x,
733
- svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(values_block_upper +
734
- ((dim_tile + 0) * 8 + 4) * 32)));
735
- svmopa_za32_bf16_m(1, query_predicate_f16x, predicate_all_f16x, probability_column_12_f32x,
736
- svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(values_block_upper +
737
- ((dim_tile + 1) * 8 + 4) * 32)));
738
- svmopa_za32_bf16_m(2, query_predicate_f16x, predicate_all_f16x, probability_column_12_f32x,
739
- svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(values_block_upper +
740
- ((dim_tile + 2) * 8 + 4) * 32)));
741
- svmopa_za32_bf16_m(3, query_predicate_f16x, predicate_all_f16x, probability_column_12_f32x,
742
- svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(values_block_upper +
743
- ((dim_tile + 3) * 8 + 4) * 32)));
744
- svmopa_za32_bf16_m(0, query_predicate_f16x, predicate_all_f16x, probability_column_13_f32x,
745
- svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(values_block_upper +
746
- ((dim_tile + 0) * 8 + 5) * 32)));
747
- svmopa_za32_bf16_m(1, query_predicate_f16x, predicate_all_f16x, probability_column_13_f32x,
748
- svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(values_block_upper +
749
- ((dim_tile + 1) * 8 + 5) * 32)));
750
- svmopa_za32_bf16_m(2, query_predicate_f16x, predicate_all_f16x, probability_column_13_f32x,
751
- svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(values_block_upper +
752
- ((dim_tile + 2) * 8 + 5) * 32)));
753
- svmopa_za32_bf16_m(3, query_predicate_f16x, predicate_all_f16x, probability_column_13_f32x,
754
- svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(values_block_upper +
755
- ((dim_tile + 3) * 8 + 5) * 32)));
756
- svmopa_za32_bf16_m(0, query_predicate_f16x, predicate_all_f16x, probability_column_14_f32x,
757
- svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(values_block_upper +
758
- ((dim_tile + 0) * 8 + 6) * 32)));
759
- svmopa_za32_bf16_m(1, query_predicate_f16x, predicate_all_f16x, probability_column_14_f32x,
760
- svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(values_block_upper +
761
- ((dim_tile + 1) * 8 + 6) * 32)));
762
- svmopa_za32_bf16_m(2, query_predicate_f16x, predicate_all_f16x, probability_column_14_f32x,
763
- svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(values_block_upper +
764
- ((dim_tile + 2) * 8 + 6) * 32)));
765
- svmopa_za32_bf16_m(3, query_predicate_f16x, predicate_all_f16x, probability_column_14_f32x,
766
- svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(values_block_upper +
767
- ((dim_tile + 3) * 8 + 6) * 32)));
768
- svmopa_za32_bf16_m(0, query_predicate_f16x, predicate_all_f16x, probability_column_15_f32x,
769
- svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(values_block_upper +
770
- ((dim_tile + 0) * 8 + 7) * 32)));
771
- svmopa_za32_bf16_m(1, query_predicate_f16x, predicate_all_f16x, probability_column_15_f32x,
772
- svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(values_block_upper +
773
- ((dim_tile + 1) * 8 + 7) * 32)));
774
- svmopa_za32_bf16_m(2, query_predicate_f16x, predicate_all_f16x, probability_column_15_f32x,
775
- svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(values_block_upper +
776
- ((dim_tile + 2) * 8 + 7) * 32)));
777
- svmopa_za32_bf16_m(3, query_predicate_f16x, predicate_all_f16x, probability_column_15_f32x,
778
- svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(values_block_upper +
779
- ((dim_tile + 3) * 8 + 7) * 32)));
687
+ svmopa_za32_bf16_m(0, query_predicate_b16x, predicate_all_b16x, probability_column_8_f32x,
688
+ svld1_bf16(predicate_all_b16x,
689
+ (bfloat16_t const *)(values_block_high + ((dim_tile + 0) * 8 + 0) * 32)));
690
+ svmopa_za32_bf16_m(1, query_predicate_b16x, predicate_all_b16x, probability_column_8_f32x,
691
+ svld1_bf16(predicate_all_b16x,
692
+ (bfloat16_t const *)(values_block_high + ((dim_tile + 1) * 8 + 0) * 32)));
693
+ svmopa_za32_bf16_m(2, query_predicate_b16x, predicate_all_b16x, probability_column_8_f32x,
694
+ svld1_bf16(predicate_all_b16x,
695
+ (bfloat16_t const *)(values_block_high + ((dim_tile + 2) * 8 + 0) * 32)));
696
+ svmopa_za32_bf16_m(3, query_predicate_b16x, predicate_all_b16x, probability_column_8_f32x,
697
+ svld1_bf16(predicate_all_b16x,
698
+ (bfloat16_t const *)(values_block_high + ((dim_tile + 3) * 8 + 0) * 32)));
699
+ svmopa_za32_bf16_m(0, query_predicate_b16x, predicate_all_b16x, probability_column_9_f32x,
700
+ svld1_bf16(predicate_all_b16x,
701
+ (bfloat16_t const *)(values_block_high + ((dim_tile + 0) * 8 + 1) * 32)));
702
+ svmopa_za32_bf16_m(1, query_predicate_b16x, predicate_all_b16x, probability_column_9_f32x,
703
+ svld1_bf16(predicate_all_b16x,
704
+ (bfloat16_t const *)(values_block_high + ((dim_tile + 1) * 8 + 1) * 32)));
705
+ svmopa_za32_bf16_m(2, query_predicate_b16x, predicate_all_b16x, probability_column_9_f32x,
706
+ svld1_bf16(predicate_all_b16x,
707
+ (bfloat16_t const *)(values_block_high + ((dim_tile + 2) * 8 + 1) * 32)));
708
+ svmopa_za32_bf16_m(3, query_predicate_b16x, predicate_all_b16x, probability_column_9_f32x,
709
+ svld1_bf16(predicate_all_b16x,
710
+ (bfloat16_t const *)(values_block_high + ((dim_tile + 3) * 8 + 1) * 32)));
711
+ svmopa_za32_bf16_m(0, query_predicate_b16x, predicate_all_b16x, probability_column_10_f32x,
712
+ svld1_bf16(predicate_all_b16x,
713
+ (bfloat16_t const *)(values_block_high + ((dim_tile + 0) * 8 + 2) * 32)));
714
+ svmopa_za32_bf16_m(1, query_predicate_b16x, predicate_all_b16x, probability_column_10_f32x,
715
+ svld1_bf16(predicate_all_b16x,
716
+ (bfloat16_t const *)(values_block_high + ((dim_tile + 1) * 8 + 2) * 32)));
717
+ svmopa_za32_bf16_m(2, query_predicate_b16x, predicate_all_b16x, probability_column_10_f32x,
718
+ svld1_bf16(predicate_all_b16x,
719
+ (bfloat16_t const *)(values_block_high + ((dim_tile + 2) * 8 + 2) * 32)));
720
+ svmopa_za32_bf16_m(3, query_predicate_b16x, predicate_all_b16x, probability_column_10_f32x,
721
+ svld1_bf16(predicate_all_b16x,
722
+ (bfloat16_t const *)(values_block_high + ((dim_tile + 3) * 8 + 2) * 32)));
723
+ svmopa_za32_bf16_m(0, query_predicate_b16x, predicate_all_b16x, probability_column_11_f32x,
724
+ svld1_bf16(predicate_all_b16x,
725
+ (bfloat16_t const *)(values_block_high + ((dim_tile + 0) * 8 + 3) * 32)));
726
+ svmopa_za32_bf16_m(1, query_predicate_b16x, predicate_all_b16x, probability_column_11_f32x,
727
+ svld1_bf16(predicate_all_b16x,
728
+ (bfloat16_t const *)(values_block_high + ((dim_tile + 1) * 8 + 3) * 32)));
729
+ svmopa_za32_bf16_m(2, query_predicate_b16x, predicate_all_b16x, probability_column_11_f32x,
730
+ svld1_bf16(predicate_all_b16x,
731
+ (bfloat16_t const *)(values_block_high + ((dim_tile + 2) * 8 + 3) * 32)));
732
+ svmopa_za32_bf16_m(3, query_predicate_b16x, predicate_all_b16x, probability_column_11_f32x,
733
+ svld1_bf16(predicate_all_b16x,
734
+ (bfloat16_t const *)(values_block_high + ((dim_tile + 3) * 8 + 3) * 32)));
735
+ svmopa_za32_bf16_m(0, query_predicate_b16x, predicate_all_b16x, probability_column_12_f32x,
736
+ svld1_bf16(predicate_all_b16x,
737
+ (bfloat16_t const *)(values_block_high + ((dim_tile + 0) * 8 + 4) * 32)));
738
+ svmopa_za32_bf16_m(1, query_predicate_b16x, predicate_all_b16x, probability_column_12_f32x,
739
+ svld1_bf16(predicate_all_b16x,
740
+ (bfloat16_t const *)(values_block_high + ((dim_tile + 1) * 8 + 4) * 32)));
741
+ svmopa_za32_bf16_m(2, query_predicate_b16x, predicate_all_b16x, probability_column_12_f32x,
742
+ svld1_bf16(predicate_all_b16x,
743
+ (bfloat16_t const *)(values_block_high + ((dim_tile + 2) * 8 + 4) * 32)));
744
+ svmopa_za32_bf16_m(3, query_predicate_b16x, predicate_all_b16x, probability_column_12_f32x,
745
+ svld1_bf16(predicate_all_b16x,
746
+ (bfloat16_t const *)(values_block_high + ((dim_tile + 3) * 8 + 4) * 32)));
747
+ svmopa_za32_bf16_m(0, query_predicate_b16x, predicate_all_b16x, probability_column_13_f32x,
748
+ svld1_bf16(predicate_all_b16x,
749
+ (bfloat16_t const *)(values_block_high + ((dim_tile + 0) * 8 + 5) * 32)));
750
+ svmopa_za32_bf16_m(1, query_predicate_b16x, predicate_all_b16x, probability_column_13_f32x,
751
+ svld1_bf16(predicate_all_b16x,
752
+ (bfloat16_t const *)(values_block_high + ((dim_tile + 1) * 8 + 5) * 32)));
753
+ svmopa_za32_bf16_m(2, query_predicate_b16x, predicate_all_b16x, probability_column_13_f32x,
754
+ svld1_bf16(predicate_all_b16x,
755
+ (bfloat16_t const *)(values_block_high + ((dim_tile + 2) * 8 + 5) * 32)));
756
+ svmopa_za32_bf16_m(3, query_predicate_b16x, predicate_all_b16x, probability_column_13_f32x,
757
+ svld1_bf16(predicate_all_b16x,
758
+ (bfloat16_t const *)(values_block_high + ((dim_tile + 3) * 8 + 5) * 32)));
759
+ svmopa_za32_bf16_m(0, query_predicate_b16x, predicate_all_b16x, probability_column_14_f32x,
760
+ svld1_bf16(predicate_all_b16x,
761
+ (bfloat16_t const *)(values_block_high + ((dim_tile + 0) * 8 + 6) * 32)));
762
+ svmopa_za32_bf16_m(1, query_predicate_b16x, predicate_all_b16x, probability_column_14_f32x,
763
+ svld1_bf16(predicate_all_b16x,
764
+ (bfloat16_t const *)(values_block_high + ((dim_tile + 1) * 8 + 6) * 32)));
765
+ svmopa_za32_bf16_m(2, query_predicate_b16x, predicate_all_b16x, probability_column_14_f32x,
766
+ svld1_bf16(predicate_all_b16x,
767
+ (bfloat16_t const *)(values_block_high + ((dim_tile + 2) * 8 + 6) * 32)));
768
+ svmopa_za32_bf16_m(3, query_predicate_b16x, predicate_all_b16x, probability_column_14_f32x,
769
+ svld1_bf16(predicate_all_b16x,
770
+ (bfloat16_t const *)(values_block_high + ((dim_tile + 3) * 8 + 6) * 32)));
771
+ svmopa_za32_bf16_m(0, query_predicate_b16x, predicate_all_b16x, probability_column_15_f32x,
772
+ svld1_bf16(predicate_all_b16x,
773
+ (bfloat16_t const *)(values_block_high + ((dim_tile + 0) * 8 + 7) * 32)));
774
+ svmopa_za32_bf16_m(1, query_predicate_b16x, predicate_all_b16x, probability_column_15_f32x,
775
+ svld1_bf16(predicate_all_b16x,
776
+ (bfloat16_t const *)(values_block_high + ((dim_tile + 1) * 8 + 7) * 32)));
777
+ svmopa_za32_bf16_m(2, query_predicate_b16x, predicate_all_b16x, probability_column_15_f32x,
778
+ svld1_bf16(predicate_all_b16x,
779
+ (bfloat16_t const *)(values_block_high + ((dim_tile + 2) * 8 + 7) * 32)));
780
+ svmopa_za32_bf16_m(3, query_predicate_b16x, predicate_all_b16x, probability_column_15_f32x,
781
+ svld1_bf16(predicate_all_b16x,
782
+ (bfloat16_t const *)(values_block_high + ((dim_tile + 3) * 8 + 7) * 32)));
780
783
  // Read BFMOPA result and ADD to output_accumulator
781
784
  for (nk_size_t query_index = 0; query_index < valid_query_count; query_index++) {
782
785
  svst1_f32(
783
- predicate_all_f32x, output_accumulator + query_index * head_dim_padded + (dim_tile + 0) * 16,
784
- svadd_f32_x(predicate_all_f32x,
785
- svld1_f32(predicate_all_f32x,
786
+ predicate_all_b32x, output_accumulator + query_index * head_dim_padded + (dim_tile + 0) * 16,
787
+ svadd_f32_x(predicate_all_b32x,
788
+ svld1_f32(predicate_all_b32x,
786
789
  output_accumulator + query_index * head_dim_padded + (dim_tile + 0) * 16),
787
- svread_hor_za32_f32_m(svdup_f32(0), predicate_all_f32x, 0, query_index)));
790
+ svread_hor_za32_f32_m(svdup_f32(0), predicate_all_b32x, 0, query_index)));
788
791
  svst1_f32(
789
- predicate_all_f32x, output_accumulator + query_index * head_dim_padded + (dim_tile + 1) * 16,
790
- svadd_f32_x(predicate_all_f32x,
791
- svld1_f32(predicate_all_f32x,
792
+ predicate_all_b32x, output_accumulator + query_index * head_dim_padded + (dim_tile + 1) * 16,
793
+ svadd_f32_x(predicate_all_b32x,
794
+ svld1_f32(predicate_all_b32x,
792
795
  output_accumulator + query_index * head_dim_padded + (dim_tile + 1) * 16),
793
- svread_hor_za32_f32_m(svdup_f32(0), predicate_all_f32x, 1, query_index)));
796
+ svread_hor_za32_f32_m(svdup_f32(0), predicate_all_b32x, 1, query_index)));
794
797
  svst1_f32(
795
- predicate_all_f32x, output_accumulator + query_index * head_dim_padded + (dim_tile + 2) * 16,
796
- svadd_f32_x(predicate_all_f32x,
797
- svld1_f32(predicate_all_f32x,
798
+ predicate_all_b32x, output_accumulator + query_index * head_dim_padded + (dim_tile + 2) * 16,
799
+ svadd_f32_x(predicate_all_b32x,
800
+ svld1_f32(predicate_all_b32x,
798
801
  output_accumulator + query_index * head_dim_padded + (dim_tile + 2) * 16),
799
- svread_hor_za32_f32_m(svdup_f32(0), predicate_all_f32x, 2, query_index)));
802
+ svread_hor_za32_f32_m(svdup_f32(0), predicate_all_b32x, 2, query_index)));
800
803
  svst1_f32(
801
- predicate_all_f32x, output_accumulator + query_index * head_dim_padded + (dim_tile + 3) * 16,
802
- svadd_f32_x(predicate_all_f32x,
803
- svld1_f32(predicate_all_f32x,
804
+ predicate_all_b32x, output_accumulator + query_index * head_dim_padded + (dim_tile + 3) * 16,
805
+ svadd_f32_x(predicate_all_b32x,
806
+ svld1_f32(predicate_all_b32x,
804
807
  output_accumulator + query_index * head_dim_padded + (dim_tile + 3) * 16),
805
- svread_hor_za32_f32_m(svdup_f32(0), predicate_all_f32x, 3, query_index)));
808
+ svread_hor_za32_f32_m(svdup_f32(0), predicate_all_b32x, 3, query_index)));
806
809
  }
807
810
  }
808
811
  // Remainder: 1 dim_tile at a time using ZA0
809
812
  for (; dim_tile < dim_tile_count; dim_tile++) {
810
813
  svzero_mask_za(nk_sme_zero_za32_tile_0_);
811
814
  svmopa_za32_bf16_m(
812
- 0, query_predicate_f16x, predicate_all_f16x, probability_column_0_f32x,
813
- svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(values_block_lower + (dim_tile * 8 + 0) * 32)));
815
+ 0, query_predicate_b16x, predicate_all_b16x, probability_column_0_f32x,
816
+ svld1_bf16(predicate_all_b16x, (bfloat16_t const *)(values_block_low + (dim_tile * 8 + 0) * 32)));
814
817
  svmopa_za32_bf16_m(
815
- 0, query_predicate_f16x, predicate_all_f16x, probability_column_1_f32x,
816
- svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(values_block_lower + (dim_tile * 8 + 1) * 32)));
818
+ 0, query_predicate_b16x, predicate_all_b16x, probability_column_1_f32x,
819
+ svld1_bf16(predicate_all_b16x, (bfloat16_t const *)(values_block_low + (dim_tile * 8 + 1) * 32)));
817
820
  svmopa_za32_bf16_m(
818
- 0, query_predicate_f16x, predicate_all_f16x, probability_column_2_f32x,
819
- svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(values_block_lower + (dim_tile * 8 + 2) * 32)));
821
+ 0, query_predicate_b16x, predicate_all_b16x, probability_column_2_f32x,
822
+ svld1_bf16(predicate_all_b16x, (bfloat16_t const *)(values_block_low + (dim_tile * 8 + 2) * 32)));
820
823
  svmopa_za32_bf16_m(
821
- 0, query_predicate_f16x, predicate_all_f16x, probability_column_3_f32x,
822
- svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(values_block_lower + (dim_tile * 8 + 3) * 32)));
824
+ 0, query_predicate_b16x, predicate_all_b16x, probability_column_3_f32x,
825
+ svld1_bf16(predicate_all_b16x, (bfloat16_t const *)(values_block_low + (dim_tile * 8 + 3) * 32)));
823
826
  svmopa_za32_bf16_m(
824
- 0, query_predicate_f16x, predicate_all_f16x, probability_column_4_f32x,
825
- svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(values_block_lower + (dim_tile * 8 + 4) * 32)));
827
+ 0, query_predicate_b16x, predicate_all_b16x, probability_column_4_f32x,
828
+ svld1_bf16(predicate_all_b16x, (bfloat16_t const *)(values_block_low + (dim_tile * 8 + 4) * 32)));
826
829
  svmopa_za32_bf16_m(
827
- 0, query_predicate_f16x, predicate_all_f16x, probability_column_5_f32x,
828
- svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(values_block_lower + (dim_tile * 8 + 5) * 32)));
830
+ 0, query_predicate_b16x, predicate_all_b16x, probability_column_5_f32x,
831
+ svld1_bf16(predicate_all_b16x, (bfloat16_t const *)(values_block_low + (dim_tile * 8 + 5) * 32)));
829
832
  svmopa_za32_bf16_m(
830
- 0, query_predicate_f16x, predicate_all_f16x, probability_column_6_f32x,
831
- svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(values_block_lower + (dim_tile * 8 + 6) * 32)));
833
+ 0, query_predicate_b16x, predicate_all_b16x, probability_column_6_f32x,
834
+ svld1_bf16(predicate_all_b16x, (bfloat16_t const *)(values_block_low + (dim_tile * 8 + 6) * 32)));
832
835
  svmopa_za32_bf16_m(
833
- 0, query_predicate_f16x, predicate_all_f16x, probability_column_7_f32x,
834
- svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(values_block_lower + (dim_tile * 8 + 7) * 32)));
836
+ 0, query_predicate_b16x, predicate_all_b16x, probability_column_7_f32x,
837
+ svld1_bf16(predicate_all_b16x, (bfloat16_t const *)(values_block_low + (dim_tile * 8 + 7) * 32)));
835
838
  svmopa_za32_bf16_m(
836
- 0, query_predicate_f16x, predicate_all_f16x, probability_column_8_f32x,
837
- svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(values_block_upper + (dim_tile * 8 + 0) * 32)));
839
+ 0, query_predicate_b16x, predicate_all_b16x, probability_column_8_f32x,
840
+ svld1_bf16(predicate_all_b16x, (bfloat16_t const *)(values_block_high + (dim_tile * 8 + 0) * 32)));
838
841
  svmopa_za32_bf16_m(
839
- 0, query_predicate_f16x, predicate_all_f16x, probability_column_9_f32x,
840
- svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(values_block_upper + (dim_tile * 8 + 1) * 32)));
842
+ 0, query_predicate_b16x, predicate_all_b16x, probability_column_9_f32x,
843
+ svld1_bf16(predicate_all_b16x, (bfloat16_t const *)(values_block_high + (dim_tile * 8 + 1) * 32)));
841
844
  svmopa_za32_bf16_m(
842
- 0, query_predicate_f16x, predicate_all_f16x, probability_column_10_f32x,
843
- svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(values_block_upper + (dim_tile * 8 + 2) * 32)));
845
+ 0, query_predicate_b16x, predicate_all_b16x, probability_column_10_f32x,
846
+ svld1_bf16(predicate_all_b16x, (bfloat16_t const *)(values_block_high + (dim_tile * 8 + 2) * 32)));
844
847
  svmopa_za32_bf16_m(
845
- 0, query_predicate_f16x, predicate_all_f16x, probability_column_11_f32x,
846
- svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(values_block_upper + (dim_tile * 8 + 3) * 32)));
848
+ 0, query_predicate_b16x, predicate_all_b16x, probability_column_11_f32x,
849
+ svld1_bf16(predicate_all_b16x, (bfloat16_t const *)(values_block_high + (dim_tile * 8 + 3) * 32)));
847
850
  svmopa_za32_bf16_m(
848
- 0, query_predicate_f16x, predicate_all_f16x, probability_column_12_f32x,
849
- svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(values_block_upper + (dim_tile * 8 + 4) * 32)));
851
+ 0, query_predicate_b16x, predicate_all_b16x, probability_column_12_f32x,
852
+ svld1_bf16(predicate_all_b16x, (bfloat16_t const *)(values_block_high + (dim_tile * 8 + 4) * 32)));
850
853
  svmopa_za32_bf16_m(
851
- 0, query_predicate_f16x, predicate_all_f16x, probability_column_13_f32x,
852
- svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(values_block_upper + (dim_tile * 8 + 5) * 32)));
854
+ 0, query_predicate_b16x, predicate_all_b16x, probability_column_13_f32x,
855
+ svld1_bf16(predicate_all_b16x, (bfloat16_t const *)(values_block_high + (dim_tile * 8 + 5) * 32)));
853
856
  svmopa_za32_bf16_m(
854
- 0, query_predicate_f16x, predicate_all_f16x, probability_column_14_f32x,
855
- svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(values_block_upper + (dim_tile * 8 + 6) * 32)));
857
+ 0, query_predicate_b16x, predicate_all_b16x, probability_column_14_f32x,
858
+ svld1_bf16(predicate_all_b16x, (bfloat16_t const *)(values_block_high + (dim_tile * 8 + 6) * 32)));
856
859
  svmopa_za32_bf16_m(
857
- 0, query_predicate_f16x, predicate_all_f16x, probability_column_15_f32x,
858
- svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(values_block_upper + (dim_tile * 8 + 7) * 32)));
860
+ 0, query_predicate_b16x, predicate_all_b16x, probability_column_15_f32x,
861
+ svld1_bf16(predicate_all_b16x, (bfloat16_t const *)(values_block_high + (dim_tile * 8 + 7) * 32)));
859
862
  for (nk_size_t query_index = 0; query_index < valid_query_count; query_index++)
860
- svst1_f32(predicate_all_f32x, output_accumulator + query_index * head_dim_padded + dim_tile * 16,
861
- svadd_f32_x(predicate_all_f32x,
862
- svld1_f32(predicate_all_f32x,
863
+ svst1_f32(predicate_all_b32x, output_accumulator + query_index * head_dim_padded + dim_tile * 16,
864
+ svadd_f32_x(predicate_all_b32x,
865
+ svld1_f32(predicate_all_b32x,
863
866
  output_accumulator + query_index * head_dim_padded + dim_tile * 16),
864
- svread_hor_za32_f32_m(svdup_f32(0), predicate_all_f32x, 0, query_index)));
867
+ svread_hor_za32_f32_m(svdup_f32(0), predicate_all_b32x, 0, query_index)));
865
868
  }
866
869
  }
867
870
  }
@@ -874,9 +877,10 @@ __arm_locally_streaming __arm_new("za") static void nk_attention_bf16_sme_stream
874
877
  svzero_mask_za(nk_sme_zero_za32_tile_2_);
875
878
  nk_bf16_t const *k_block = k + kv_block_index * k_depth_step_count * 32;
876
879
  for (nk_size_t step = 0; step < k_depth_step_count; step++) {
877
- svbfloat16_t zn = svreinterpret_bf16_f32(svld1_f32(predicate_all_f32x, queries_transposed + step * 16));
878
- svbfloat16_t zm = svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(k_block + step * 32));
879
- svmopa_za32_bf16_m(2, predicate_all_f32x, predicate_all_f32x, zn, zm);
880
+ svbfloat16_t zn_bf16x = svreinterpret_bf16_f32(
881
+ svld1_f32(predicate_all_b32x, queries_transposed + step * 16));
882
+ svbfloat16_t zm_bf16x = svld1_bf16(predicate_all_b16x, (bfloat16_t const *)(k_block + step * 32));
883
+ svmopa_za32_bf16_m(2, predicate_all_b32x, predicate_all_b32x, zn_bf16x, zm_bf16x);
880
884
  }
881
885
 
882
886
  // Pass 1: Column-wise max (read ZA2 columns vertically)
@@ -884,55 +888,55 @@ __arm_locally_streaming __arm_new("za") static void nk_attention_bf16_sme_stream
884
888
  svfloat32_t block_max_16_f32x = svdup_f32(NK_F32_MIN);
885
889
  for (nk_size_t column_index = 0; column_index < 16; column_index++) {
886
890
  svfloat32_t score_column_f32x = svmul_f32_x(
887
- predicate_all_f32x, svread_ver_za32_f32_m(svdup_f32(0), predicate_all_f32x, 2, column_index),
891
+ predicate_all_b32x, svread_ver_za32_f32_m(svdup_f32(0), predicate_all_b32x, 2, column_index),
888
892
  scale_16_f32x);
889
- block_max_16_f32x = svmax_f32_x(predicate_all_f32x, block_max_16_f32x, score_column_f32x);
893
+ block_max_16_f32x = svmax_f32_x(predicate_all_b32x, block_max_16_f32x, score_column_f32x);
890
894
  }
891
895
 
892
896
  // Softmax correction (fully vectorized)
893
- svfloat32_t new_max_f32x = svmax_f32_x(predicate_all_f32x, row_max_f32x, block_max_16_f32x);
894
- svfloat32_t correction_f32x = nk_exp_fast_f32_sve_(predicate_all_f32x,
895
- svsub_f32_x(predicate_all_f32x, row_max_f32x, new_max_f32x));
896
- svbool_t max_changed_16 = svcmplt_f32(predicate_all_f32x, correction_f32x, svdup_f32(1.0f));
897
- nk_u32_t max_was_updated_16 = svptest_any(predicate_all_f32x, max_changed_16) ? 1 : 0;
898
- if (max_was_updated_16) row_sum_f32x = svmul_f32_x(predicate_all_f32x, row_sum_f32x, correction_f32x);
897
+ svfloat32_t new_max_f32x = svmax_f32_x(predicate_all_b32x, row_max_f32x, block_max_16_f32x);
898
+ svfloat32_t correction_f32x = nk_exp_fast_f32_sve_(predicate_all_b32x,
899
+ svsub_f32_x(predicate_all_b32x, row_max_f32x, new_max_f32x));
900
+ svbool_t max_changed_16_b32x = svcmplt_f32(predicate_all_b32x, correction_f32x, svdup_f32(1.0f));
901
+ nk_u32_t max_was_updated_16 = svptest_any(predicate_all_b32x, max_changed_16_b32x) ? 1 : 0;
902
+ if (max_was_updated_16) row_sum_f32x = svmul_f32_x(predicate_all_b32x, row_sum_f32x, correction_f32x);
899
903
  NK_ALIGN64 nk_f32_t corrections[16];
900
- svst1_f32(predicate_all_f32x, corrections, correction_f32x);
904
+ svst1_f32(predicate_all_b32x, corrections, correction_f32x);
901
905
 
902
906
  // Pass 2: Column-wise exp + fused P write + sum (ZA2 → ZA0 columns 0-7)
903
907
  svfloat32_t sum_delta_16_f32x = svdup_f32(0.0f);
904
908
  svzero_mask_za(nk_sme_zero_za32_tile_0_);
905
909
  for (nk_size_t column_index = 0; column_index < 16; column_index += 2) {
906
910
  svfloat32_t score_even_f32x = svmul_f32_x(
907
- predicate_all_f32x, svread_ver_za32_f32_m(svdup_f32(0), predicate_all_f32x, 2, column_index),
911
+ predicate_all_b32x, svread_ver_za32_f32_m(svdup_f32(0), predicate_all_b32x, 2, column_index),
908
912
  scale_16_f32x);
909
913
  svfloat32_t score_odd_f32x = svmul_f32_x(
910
- predicate_all_f32x, svread_ver_za32_f32_m(svdup_f32(0), predicate_all_f32x, 2, column_index + 1),
914
+ predicate_all_b32x, svread_ver_za32_f32_m(svdup_f32(0), predicate_all_b32x, 2, column_index + 1),
911
915
  scale_16_f32x);
912
916
  svfloat32_t weight_even_f32x = nk_exp_fast_f32_sve_(
913
- predicate_all_f32x, svsub_f32_x(predicate_all_f32x, score_even_f32x, new_max_f32x));
917
+ predicate_all_b32x, svsub_f32_x(predicate_all_b32x, score_even_f32x, new_max_f32x));
914
918
  svfloat32_t weight_odd_f32x = nk_exp_fast_f32_sve_(
915
- predicate_all_f32x, svsub_f32_x(predicate_all_f32x, score_odd_f32x, new_max_f32x));
916
- sum_delta_16_f32x = svadd_f32_x(predicate_all_f32x, sum_delta_16_f32x, weight_even_f32x);
917
- sum_delta_16_f32x = svadd_f32_x(predicate_all_f32x, sum_delta_16_f32x, weight_odd_f32x);
918
- svbfloat16_t weight_pair_bf16 = svzip1_bf16(nk_f32_to_bf16_sve_(predicate_all_f32x, weight_even_f32x),
919
- nk_f32_to_bf16_sve_(predicate_all_f32x, weight_odd_f32x));
920
- svwrite_ver_za32_f32_m(0, column_index / 2, predicate_all_f32x, svreinterpret_f32_bf16(weight_pair_bf16));
919
+ predicate_all_b32x, svsub_f32_x(predicate_all_b32x, score_odd_f32x, new_max_f32x));
920
+ sum_delta_16_f32x = svadd_f32_x(predicate_all_b32x, sum_delta_16_f32x, weight_even_f32x);
921
+ sum_delta_16_f32x = svadd_f32_x(predicate_all_b32x, sum_delta_16_f32x, weight_odd_f32x);
922
+ svbfloat16_t weight_pair_bf16x = svzip1_bf16(nk_f32_to_bf16_sve_(predicate_all_b32x, weight_even_f32x),
923
+ nk_f32_to_bf16_sve_(predicate_all_b32x, weight_odd_f32x));
924
+ svwrite_ver_za32_f32_m(0, column_index / 2, predicate_all_b32x, svreinterpret_f32_bf16(weight_pair_bf16x));
921
925
  }
922
- row_sum_f32x = svadd_f32_x(predicate_all_f32x, row_sum_f32x, sum_delta_16_f32x);
926
+ row_sum_f32x = svadd_f32_x(predicate_all_b32x, row_sum_f32x, sum_delta_16_f32x);
923
927
  row_max_f32x = new_max_f32x;
924
928
 
925
929
  if (valid_query_count == 1) {
926
930
  // Decode path: extract f32 weights from ZA0 row 0 using SVE
927
- svbfloat16_t row0_bf16 = svreinterpret_bf16_f32(
928
- svread_hor_za32_f32_m(svdup_f32(0), predicate_all_f32x, 0, 0));
929
- svbfloat16_t weights_even_bf16 = svuzp1_bf16(row0_bf16, row0_bf16);
930
- svbfloat16_t weights_odd_bf16 = svuzp2_bf16(row0_bf16, row0_bf16);
931
+ svbfloat16_t row0_bf16x = svreinterpret_bf16_f32(
932
+ svread_hor_za32_f32_m(svdup_f32(0), predicate_all_b32x, 0, 0));
933
+ svbfloat16_t weights_even_bf16x = svuzp1_bf16(row0_bf16x, row0_bf16x);
934
+ svbfloat16_t weights_odd_bf16x = svuzp2_bf16(row0_bf16x, row0_bf16x);
931
935
  NK_ALIGN64 nk_f32_t decode_weights[16];
932
936
  svst1_f32(svwhilelt_b32(0u, 8u), decode_weights,
933
- nk_bf16_to_f32_sve_(svwhilelt_b32(0u, 8u), weights_even_bf16));
937
+ nk_bf16_to_f32_sve_(svwhilelt_b32(0u, 8u), weights_even_bf16x));
934
938
  svst1_f32(svwhilelt_b32(0u, 8u), decode_weights + 8,
935
- nk_bf16_to_f32_sve_(svwhilelt_b32(0u, 8u), weights_odd_bf16));
939
+ nk_bf16_to_f32_sve_(svwhilelt_b32(0u, 8u), weights_odd_bf16x));
936
940
  NK_ALIGN64 nk_f32_t decode_weights_ordered[16];
937
941
  for (nk_size_t i = 0; i < 8; i++) {
938
942
  decode_weights_ordered[2 * i] = decode_weights[i];
@@ -940,42 +944,42 @@ __arm_locally_streaming __arm_new("za") static void nk_attention_bf16_sme_stream
940
944
  }
941
945
  svfloat32_t corr_f32x = svdup_f32(corrections[0]);
942
946
  for (nk_size_t d = 0; d < head_dim; d += svcntw()) {
943
- svbool_t predicate_f32x = svwhilelt_b32_u64(d, head_dim);
944
- svfloat32_t acc_f32x = svmul_f32_x(predicate_f32x, svld1_f32(predicate_f32x, output_accumulator + d),
947
+ svbool_t predicate_b32x = svwhilelt_b32_u64(d, head_dim);
948
+ svfloat32_t acc_f32x = svmul_f32_x(predicate_b32x, svld1_f32(predicate_b32x, output_accumulator + d),
945
949
  corr_f32x);
946
950
  for (nk_size_t ki = 0; ki < valid_kv; ki++) {
947
951
  nk_size_t dim_tile = d / 16, depth_s = ki / 2, sub = ki % 2;
948
952
  nk_bf16_t const *v_vec = v_packed +
949
953
  (kv_block_index * dim_tile_count * 8 + dim_tile * 8 + depth_s) * 32;
950
- svbfloat16_t packed_bf16x = svld1_bf16(predicate_all_f16x, (bfloat16_t const *)v_vec);
951
- svbfloat16_t v_selected = (sub == 0) ? svuzp1_bf16(packed_bf16x, packed_bf16x)
952
- : svuzp2_bf16(packed_bf16x, packed_bf16x);
953
- acc_f32x = svmla_f32_x(predicate_f32x, acc_f32x, svdup_f32(decode_weights_ordered[ki]),
954
- nk_bf16_to_f32_sve_(predicate_f32x, v_selected));
954
+ svbfloat16_t packed_bf16x = svld1_bf16(predicate_all_b16x, (bfloat16_t const *)v_vec);
955
+ svbfloat16_t v_selected_bf16x = (sub == 0) ? svuzp1_bf16(packed_bf16x, packed_bf16x)
956
+ : svuzp2_bf16(packed_bf16x, packed_bf16x);
957
+ acc_f32x = svmla_f32_x(predicate_b32x, acc_f32x, svdup_f32(decode_weights_ordered[ki]),
958
+ nk_bf16_to_f32_sve_(predicate_b32x, v_selected_bf16x));
955
959
  }
956
- svst1_f32(predicate_f32x, output_accumulator + d, acc_f32x);
960
+ svst1_f32(predicate_b32x, output_accumulator + d, acc_f32x);
957
961
  }
958
962
  }
959
963
  else {
960
964
  // Prefill Bc=16: extract P columns, pre-apply correction, add-after P×V
961
- svbool_t query_predicate_f16x = svwhilelt_b16_u64(0u, valid_query_count * 2);
965
+ svbool_t query_predicate_b16x = svwhilelt_b16_u64(0u, valid_query_count * 2);
962
966
 
963
967
  svbfloat16_t probability_column_0_f32x = svreinterpret_bf16_f32(
964
- svread_ver_za32_f32_m(svdup_f32(0), predicate_all_f32x, 0, 0));
968
+ svread_ver_za32_f32_m(svdup_f32(0), predicate_all_b32x, 0, 0));
965
969
  svbfloat16_t probability_column_1_f32x = svreinterpret_bf16_f32(
966
- svread_ver_za32_f32_m(svdup_f32(0), predicate_all_f32x, 0, 1));
970
+ svread_ver_za32_f32_m(svdup_f32(0), predicate_all_b32x, 0, 1));
967
971
  svbfloat16_t probability_column_2_f32x = svreinterpret_bf16_f32(
968
- svread_ver_za32_f32_m(svdup_f32(0), predicate_all_f32x, 0, 2));
972
+ svread_ver_za32_f32_m(svdup_f32(0), predicate_all_b32x, 0, 2));
969
973
  svbfloat16_t probability_column_3_f32x = svreinterpret_bf16_f32(
970
- svread_ver_za32_f32_m(svdup_f32(0), predicate_all_f32x, 0, 3));
974
+ svread_ver_za32_f32_m(svdup_f32(0), predicate_all_b32x, 0, 3));
971
975
  svbfloat16_t probability_column_4_f32x = svreinterpret_bf16_f32(
972
- svread_ver_za32_f32_m(svdup_f32(0), predicate_all_f32x, 0, 4));
976
+ svread_ver_za32_f32_m(svdup_f32(0), predicate_all_b32x, 0, 4));
973
977
  svbfloat16_t probability_column_5_f32x = svreinterpret_bf16_f32(
974
- svread_ver_za32_f32_m(svdup_f32(0), predicate_all_f32x, 0, 5));
978
+ svread_ver_za32_f32_m(svdup_f32(0), predicate_all_b32x, 0, 5));
975
979
  svbfloat16_t probability_column_6_f32x = svreinterpret_bf16_f32(
976
- svread_ver_za32_f32_m(svdup_f32(0), predicate_all_f32x, 0, 6));
980
+ svread_ver_za32_f32_m(svdup_f32(0), predicate_all_b32x, 0, 6));
977
981
  svbfloat16_t probability_column_7_f32x = svreinterpret_bf16_f32(
978
- svread_ver_za32_f32_m(svdup_f32(0), predicate_all_f32x, 0, 7));
982
+ svread_ver_za32_f32_m(svdup_f32(0), predicate_all_b32x, 0, 7));
979
983
 
980
984
  nk_bf16_t const *v_block = v_packed + kv_block_index * dim_tile_count * 8 * 32;
981
985
 
@@ -985,9 +989,9 @@ __arm_locally_streaming __arm_new("za") static void nk_attention_bf16_sme_stream
985
989
  svfloat32_t correction_scalar_f32x = svdup_f32(corrections[query_index]);
986
990
  for (nk_size_t dim_offset = 0; dim_offset < head_dim_padded; dim_offset += 16)
987
991
  svst1_f32(
988
- predicate_all_f32x, output_accumulator + query_index * head_dim_padded + dim_offset,
989
- svmul_f32_x(predicate_all_f32x,
990
- svld1_f32(predicate_all_f32x,
992
+ predicate_all_b32x, output_accumulator + query_index * head_dim_padded + dim_offset,
993
+ svmul_f32_x(predicate_all_b32x,
994
+ svld1_f32(predicate_all_b32x,
991
995
  output_accumulator + query_index * head_dim_padded + dim_offset),
992
996
  correction_scalar_f32x));
993
997
  }
@@ -998,183 +1002,183 @@ __arm_locally_streaming __arm_new("za") static void nk_attention_bf16_sme_stream
998
1002
  for (; dim_tile + 4 <= dim_tile_count; dim_tile += 4) {
999
1003
  svzero_za();
1000
1004
  svmopa_za32_bf16_m(
1001
- 0, query_predicate_f16x, predicate_all_f16x, probability_column_0_f32x,
1002
- svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(v_block + ((dim_tile + 0) * 8 + 0) * 32)));
1005
+ 0, query_predicate_b16x, predicate_all_b16x, probability_column_0_f32x,
1006
+ svld1_bf16(predicate_all_b16x, (bfloat16_t const *)(v_block + ((dim_tile + 0) * 8 + 0) * 32)));
1003
1007
  svmopa_za32_bf16_m(
1004
- 1, query_predicate_f16x, predicate_all_f16x, probability_column_0_f32x,
1005
- svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(v_block + ((dim_tile + 1) * 8 + 0) * 32)));
1008
+ 1, query_predicate_b16x, predicate_all_b16x, probability_column_0_f32x,
1009
+ svld1_bf16(predicate_all_b16x, (bfloat16_t const *)(v_block + ((dim_tile + 1) * 8 + 0) * 32)));
1006
1010
  svmopa_za32_bf16_m(
1007
- 2, query_predicate_f16x, predicate_all_f16x, probability_column_0_f32x,
1008
- svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(v_block + ((dim_tile + 2) * 8 + 0) * 32)));
1011
+ 2, query_predicate_b16x, predicate_all_b16x, probability_column_0_f32x,
1012
+ svld1_bf16(predicate_all_b16x, (bfloat16_t const *)(v_block + ((dim_tile + 2) * 8 + 0) * 32)));
1009
1013
  svmopa_za32_bf16_m(
1010
- 3, query_predicate_f16x, predicate_all_f16x, probability_column_0_f32x,
1011
- svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(v_block + ((dim_tile + 3) * 8 + 0) * 32)));
1014
+ 3, query_predicate_b16x, predicate_all_b16x, probability_column_0_f32x,
1015
+ svld1_bf16(predicate_all_b16x, (bfloat16_t const *)(v_block + ((dim_tile + 3) * 8 + 0) * 32)));
1012
1016
  svmopa_za32_bf16_m(
1013
- 0, query_predicate_f16x, predicate_all_f16x, probability_column_1_f32x,
1014
- svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(v_block + ((dim_tile + 0) * 8 + 1) * 32)));
1017
+ 0, query_predicate_b16x, predicate_all_b16x, probability_column_1_f32x,
1018
+ svld1_bf16(predicate_all_b16x, (bfloat16_t const *)(v_block + ((dim_tile + 0) * 8 + 1) * 32)));
1015
1019
  svmopa_za32_bf16_m(
1016
- 1, query_predicate_f16x, predicate_all_f16x, probability_column_1_f32x,
1017
- svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(v_block + ((dim_tile + 1) * 8 + 1) * 32)));
1020
+ 1, query_predicate_b16x, predicate_all_b16x, probability_column_1_f32x,
1021
+ svld1_bf16(predicate_all_b16x, (bfloat16_t const *)(v_block + ((dim_tile + 1) * 8 + 1) * 32)));
1018
1022
  svmopa_za32_bf16_m(
1019
- 2, query_predicate_f16x, predicate_all_f16x, probability_column_1_f32x,
1020
- svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(v_block + ((dim_tile + 2) * 8 + 1) * 32)));
1023
+ 2, query_predicate_b16x, predicate_all_b16x, probability_column_1_f32x,
1024
+ svld1_bf16(predicate_all_b16x, (bfloat16_t const *)(v_block + ((dim_tile + 2) * 8 + 1) * 32)));
1021
1025
  svmopa_za32_bf16_m(
1022
- 3, query_predicate_f16x, predicate_all_f16x, probability_column_1_f32x,
1023
- svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(v_block + ((dim_tile + 3) * 8 + 1) * 32)));
1026
+ 3, query_predicate_b16x, predicate_all_b16x, probability_column_1_f32x,
1027
+ svld1_bf16(predicate_all_b16x, (bfloat16_t const *)(v_block + ((dim_tile + 3) * 8 + 1) * 32)));
1024
1028
  svmopa_za32_bf16_m(
1025
- 0, query_predicate_f16x, predicate_all_f16x, probability_column_2_f32x,
1026
- svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(v_block + ((dim_tile + 0) * 8 + 2) * 32)));
1029
+ 0, query_predicate_b16x, predicate_all_b16x, probability_column_2_f32x,
1030
+ svld1_bf16(predicate_all_b16x, (bfloat16_t const *)(v_block + ((dim_tile + 0) * 8 + 2) * 32)));
1027
1031
  svmopa_za32_bf16_m(
1028
- 1, query_predicate_f16x, predicate_all_f16x, probability_column_2_f32x,
1029
- svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(v_block + ((dim_tile + 1) * 8 + 2) * 32)));
1032
+ 1, query_predicate_b16x, predicate_all_b16x, probability_column_2_f32x,
1033
+ svld1_bf16(predicate_all_b16x, (bfloat16_t const *)(v_block + ((dim_tile + 1) * 8 + 2) * 32)));
1030
1034
  svmopa_za32_bf16_m(
1031
- 2, query_predicate_f16x, predicate_all_f16x, probability_column_2_f32x,
1032
- svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(v_block + ((dim_tile + 2) * 8 + 2) * 32)));
1035
+ 2, query_predicate_b16x, predicate_all_b16x, probability_column_2_f32x,
1036
+ svld1_bf16(predicate_all_b16x, (bfloat16_t const *)(v_block + ((dim_tile + 2) * 8 + 2) * 32)));
1033
1037
  svmopa_za32_bf16_m(
1034
- 3, query_predicate_f16x, predicate_all_f16x, probability_column_2_f32x,
1035
- svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(v_block + ((dim_tile + 3) * 8 + 2) * 32)));
1038
+ 3, query_predicate_b16x, predicate_all_b16x, probability_column_2_f32x,
1039
+ svld1_bf16(predicate_all_b16x, (bfloat16_t const *)(v_block + ((dim_tile + 3) * 8 + 2) * 32)));
1036
1040
  svmopa_za32_bf16_m(
1037
- 0, query_predicate_f16x, predicate_all_f16x, probability_column_3_f32x,
1038
- svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(v_block + ((dim_tile + 0) * 8 + 3) * 32)));
1041
+ 0, query_predicate_b16x, predicate_all_b16x, probability_column_3_f32x,
1042
+ svld1_bf16(predicate_all_b16x, (bfloat16_t const *)(v_block + ((dim_tile + 0) * 8 + 3) * 32)));
1039
1043
  svmopa_za32_bf16_m(
1040
- 1, query_predicate_f16x, predicate_all_f16x, probability_column_3_f32x,
1041
- svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(v_block + ((dim_tile + 1) * 8 + 3) * 32)));
1044
+ 1, query_predicate_b16x, predicate_all_b16x, probability_column_3_f32x,
1045
+ svld1_bf16(predicate_all_b16x, (bfloat16_t const *)(v_block + ((dim_tile + 1) * 8 + 3) * 32)));
1042
1046
  svmopa_za32_bf16_m(
1043
- 2, query_predicate_f16x, predicate_all_f16x, probability_column_3_f32x,
1044
- svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(v_block + ((dim_tile + 2) * 8 + 3) * 32)));
1047
+ 2, query_predicate_b16x, predicate_all_b16x, probability_column_3_f32x,
1048
+ svld1_bf16(predicate_all_b16x, (bfloat16_t const *)(v_block + ((dim_tile + 2) * 8 + 3) * 32)));
1045
1049
  svmopa_za32_bf16_m(
1046
- 3, query_predicate_f16x, predicate_all_f16x, probability_column_3_f32x,
1047
- svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(v_block + ((dim_tile + 3) * 8 + 3) * 32)));
1050
+ 3, query_predicate_b16x, predicate_all_b16x, probability_column_3_f32x,
1051
+ svld1_bf16(predicate_all_b16x, (bfloat16_t const *)(v_block + ((dim_tile + 3) * 8 + 3) * 32)));
1048
1052
  svmopa_za32_bf16_m(
1049
- 0, query_predicate_f16x, predicate_all_f16x, probability_column_4_f32x,
1050
- svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(v_block + ((dim_tile + 0) * 8 + 4) * 32)));
1053
+ 0, query_predicate_b16x, predicate_all_b16x, probability_column_4_f32x,
1054
+ svld1_bf16(predicate_all_b16x, (bfloat16_t const *)(v_block + ((dim_tile + 0) * 8 + 4) * 32)));
1051
1055
  svmopa_za32_bf16_m(
1052
- 1, query_predicate_f16x, predicate_all_f16x, probability_column_4_f32x,
1053
- svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(v_block + ((dim_tile + 1) * 8 + 4) * 32)));
1056
+ 1, query_predicate_b16x, predicate_all_b16x, probability_column_4_f32x,
1057
+ svld1_bf16(predicate_all_b16x, (bfloat16_t const *)(v_block + ((dim_tile + 1) * 8 + 4) * 32)));
1054
1058
  svmopa_za32_bf16_m(
1055
- 2, query_predicate_f16x, predicate_all_f16x, probability_column_4_f32x,
1056
- svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(v_block + ((dim_tile + 2) * 8 + 4) * 32)));
1059
+ 2, query_predicate_b16x, predicate_all_b16x, probability_column_4_f32x,
1060
+ svld1_bf16(predicate_all_b16x, (bfloat16_t const *)(v_block + ((dim_tile + 2) * 8 + 4) * 32)));
1057
1061
  svmopa_za32_bf16_m(
1058
- 3, query_predicate_f16x, predicate_all_f16x, probability_column_4_f32x,
1059
- svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(v_block + ((dim_tile + 3) * 8 + 4) * 32)));
1062
+ 3, query_predicate_b16x, predicate_all_b16x, probability_column_4_f32x,
1063
+ svld1_bf16(predicate_all_b16x, (bfloat16_t const *)(v_block + ((dim_tile + 3) * 8 + 4) * 32)));
1060
1064
  svmopa_za32_bf16_m(
1061
- 0, query_predicate_f16x, predicate_all_f16x, probability_column_5_f32x,
1062
- svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(v_block + ((dim_tile + 0) * 8 + 5) * 32)));
1065
+ 0, query_predicate_b16x, predicate_all_b16x, probability_column_5_f32x,
1066
+ svld1_bf16(predicate_all_b16x, (bfloat16_t const *)(v_block + ((dim_tile + 0) * 8 + 5) * 32)));
1063
1067
  svmopa_za32_bf16_m(
1064
- 1, query_predicate_f16x, predicate_all_f16x, probability_column_5_f32x,
1065
- svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(v_block + ((dim_tile + 1) * 8 + 5) * 32)));
1068
+ 1, query_predicate_b16x, predicate_all_b16x, probability_column_5_f32x,
1069
+ svld1_bf16(predicate_all_b16x, (bfloat16_t const *)(v_block + ((dim_tile + 1) * 8 + 5) * 32)));
1066
1070
  svmopa_za32_bf16_m(
1067
- 2, query_predicate_f16x, predicate_all_f16x, probability_column_5_f32x,
1068
- svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(v_block + ((dim_tile + 2) * 8 + 5) * 32)));
1071
+ 2, query_predicate_b16x, predicate_all_b16x, probability_column_5_f32x,
1072
+ svld1_bf16(predicate_all_b16x, (bfloat16_t const *)(v_block + ((dim_tile + 2) * 8 + 5) * 32)));
1069
1073
  svmopa_za32_bf16_m(
1070
- 3, query_predicate_f16x, predicate_all_f16x, probability_column_5_f32x,
1071
- svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(v_block + ((dim_tile + 3) * 8 + 5) * 32)));
1074
+ 3, query_predicate_b16x, predicate_all_b16x, probability_column_5_f32x,
1075
+ svld1_bf16(predicate_all_b16x, (bfloat16_t const *)(v_block + ((dim_tile + 3) * 8 + 5) * 32)));
1072
1076
  svmopa_za32_bf16_m(
1073
- 0, query_predicate_f16x, predicate_all_f16x, probability_column_6_f32x,
1074
- svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(v_block + ((dim_tile + 0) * 8 + 6) * 32)));
1077
+ 0, query_predicate_b16x, predicate_all_b16x, probability_column_6_f32x,
1078
+ svld1_bf16(predicate_all_b16x, (bfloat16_t const *)(v_block + ((dim_tile + 0) * 8 + 6) * 32)));
1075
1079
  svmopa_za32_bf16_m(
1076
- 1, query_predicate_f16x, predicate_all_f16x, probability_column_6_f32x,
1077
- svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(v_block + ((dim_tile + 1) * 8 + 6) * 32)));
1080
+ 1, query_predicate_b16x, predicate_all_b16x, probability_column_6_f32x,
1081
+ svld1_bf16(predicate_all_b16x, (bfloat16_t const *)(v_block + ((dim_tile + 1) * 8 + 6) * 32)));
1078
1082
  svmopa_za32_bf16_m(
1079
- 2, query_predicate_f16x, predicate_all_f16x, probability_column_6_f32x,
1080
- svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(v_block + ((dim_tile + 2) * 8 + 6) * 32)));
1083
+ 2, query_predicate_b16x, predicate_all_b16x, probability_column_6_f32x,
1084
+ svld1_bf16(predicate_all_b16x, (bfloat16_t const *)(v_block + ((dim_tile + 2) * 8 + 6) * 32)));
1081
1085
  svmopa_za32_bf16_m(
1082
- 3, query_predicate_f16x, predicate_all_f16x, probability_column_6_f32x,
1083
- svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(v_block + ((dim_tile + 3) * 8 + 6) * 32)));
1086
+ 3, query_predicate_b16x, predicate_all_b16x, probability_column_6_f32x,
1087
+ svld1_bf16(predicate_all_b16x, (bfloat16_t const *)(v_block + ((dim_tile + 3) * 8 + 6) * 32)));
1084
1088
  svmopa_za32_bf16_m(
1085
- 0, query_predicate_f16x, predicate_all_f16x, probability_column_7_f32x,
1086
- svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(v_block + ((dim_tile + 0) * 8 + 7) * 32)));
1089
+ 0, query_predicate_b16x, predicate_all_b16x, probability_column_7_f32x,
1090
+ svld1_bf16(predicate_all_b16x, (bfloat16_t const *)(v_block + ((dim_tile + 0) * 8 + 7) * 32)));
1087
1091
  svmopa_za32_bf16_m(
1088
- 1, query_predicate_f16x, predicate_all_f16x, probability_column_7_f32x,
1089
- svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(v_block + ((dim_tile + 1) * 8 + 7) * 32)));
1092
+ 1, query_predicate_b16x, predicate_all_b16x, probability_column_7_f32x,
1093
+ svld1_bf16(predicate_all_b16x, (bfloat16_t const *)(v_block + ((dim_tile + 1) * 8 + 7) * 32)));
1090
1094
  svmopa_za32_bf16_m(
1091
- 2, query_predicate_f16x, predicate_all_f16x, probability_column_7_f32x,
1092
- svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(v_block + ((dim_tile + 2) * 8 + 7) * 32)));
1095
+ 2, query_predicate_b16x, predicate_all_b16x, probability_column_7_f32x,
1096
+ svld1_bf16(predicate_all_b16x, (bfloat16_t const *)(v_block + ((dim_tile + 2) * 8 + 7) * 32)));
1093
1097
  svmopa_za32_bf16_m(
1094
- 3, query_predicate_f16x, predicate_all_f16x, probability_column_7_f32x,
1095
- svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(v_block + ((dim_tile + 3) * 8 + 7) * 32)));
1098
+ 3, query_predicate_b16x, predicate_all_b16x, probability_column_7_f32x,
1099
+ svld1_bf16(predicate_all_b16x, (bfloat16_t const *)(v_block + ((dim_tile + 3) * 8 + 7) * 32)));
1096
1100
  for (nk_size_t query_index = 0; query_index < valid_query_count; query_index++) {
1097
1101
  svst1_f32(
1098
- predicate_all_f32x, output_accumulator + query_index * head_dim_padded + (dim_tile + 0) * 16,
1099
- svadd_f32_x(predicate_all_f32x,
1100
- svld1_f32(predicate_all_f32x,
1102
+ predicate_all_b32x, output_accumulator + query_index * head_dim_padded + (dim_tile + 0) * 16,
1103
+ svadd_f32_x(predicate_all_b32x,
1104
+ svld1_f32(predicate_all_b32x,
1101
1105
  output_accumulator + query_index * head_dim_padded + (dim_tile + 0) * 16),
1102
- svread_hor_za32_f32_m(svdup_f32(0), predicate_all_f32x, 0, query_index)));
1106
+ svread_hor_za32_f32_m(svdup_f32(0), predicate_all_b32x, 0, query_index)));
1103
1107
  svst1_f32(
1104
- predicate_all_f32x, output_accumulator + query_index * head_dim_padded + (dim_tile + 1) * 16,
1105
- svadd_f32_x(predicate_all_f32x,
1106
- svld1_f32(predicate_all_f32x,
1108
+ predicate_all_b32x, output_accumulator + query_index * head_dim_padded + (dim_tile + 1) * 16,
1109
+ svadd_f32_x(predicate_all_b32x,
1110
+ svld1_f32(predicate_all_b32x,
1107
1111
  output_accumulator + query_index * head_dim_padded + (dim_tile + 1) * 16),
1108
- svread_hor_za32_f32_m(svdup_f32(0), predicate_all_f32x, 1, query_index)));
1112
+ svread_hor_za32_f32_m(svdup_f32(0), predicate_all_b32x, 1, query_index)));
1109
1113
  svst1_f32(
1110
- predicate_all_f32x, output_accumulator + query_index * head_dim_padded + (dim_tile + 2) * 16,
1111
- svadd_f32_x(predicate_all_f32x,
1112
- svld1_f32(predicate_all_f32x,
1114
+ predicate_all_b32x, output_accumulator + query_index * head_dim_padded + (dim_tile + 2) * 16,
1115
+ svadd_f32_x(predicate_all_b32x,
1116
+ svld1_f32(predicate_all_b32x,
1113
1117
  output_accumulator + query_index * head_dim_padded + (dim_tile + 2) * 16),
1114
- svread_hor_za32_f32_m(svdup_f32(0), predicate_all_f32x, 2, query_index)));
1118
+ svread_hor_za32_f32_m(svdup_f32(0), predicate_all_b32x, 2, query_index)));
1115
1119
  svst1_f32(
1116
- predicate_all_f32x, output_accumulator + query_index * head_dim_padded + (dim_tile + 3) * 16,
1117
- svadd_f32_x(predicate_all_f32x,
1118
- svld1_f32(predicate_all_f32x,
1120
+ predicate_all_b32x, output_accumulator + query_index * head_dim_padded + (dim_tile + 3) * 16,
1121
+ svadd_f32_x(predicate_all_b32x,
1122
+ svld1_f32(predicate_all_b32x,
1119
1123
  output_accumulator + query_index * head_dim_padded + (dim_tile + 3) * 16),
1120
- svread_hor_za32_f32_m(svdup_f32(0), predicate_all_f32x, 3, query_index)));
1124
+ svread_hor_za32_f32_m(svdup_f32(0), predicate_all_b32x, 3, query_index)));
1121
1125
  }
1122
1126
  }
1123
1127
  for (; dim_tile < dim_tile_count; dim_tile++) {
1124
1128
  svzero_mask_za(nk_sme_zero_za32_tile_0_);
1125
1129
  svmopa_za32_bf16_m(
1126
- 0, query_predicate_f16x, predicate_all_f16x, probability_column_0_f32x,
1127
- svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(v_block + (dim_tile * 8 + 0) * 32)));
1130
+ 0, query_predicate_b16x, predicate_all_b16x, probability_column_0_f32x,
1131
+ svld1_bf16(predicate_all_b16x, (bfloat16_t const *)(v_block + (dim_tile * 8 + 0) * 32)));
1128
1132
  svmopa_za32_bf16_m(
1129
- 0, query_predicate_f16x, predicate_all_f16x, probability_column_1_f32x,
1130
- svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(v_block + (dim_tile * 8 + 1) * 32)));
1133
+ 0, query_predicate_b16x, predicate_all_b16x, probability_column_1_f32x,
1134
+ svld1_bf16(predicate_all_b16x, (bfloat16_t const *)(v_block + (dim_tile * 8 + 1) * 32)));
1131
1135
  svmopa_za32_bf16_m(
1132
- 0, query_predicate_f16x, predicate_all_f16x, probability_column_2_f32x,
1133
- svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(v_block + (dim_tile * 8 + 2) * 32)));
1136
+ 0, query_predicate_b16x, predicate_all_b16x, probability_column_2_f32x,
1137
+ svld1_bf16(predicate_all_b16x, (bfloat16_t const *)(v_block + (dim_tile * 8 + 2) * 32)));
1134
1138
  svmopa_za32_bf16_m(
1135
- 0, query_predicate_f16x, predicate_all_f16x, probability_column_3_f32x,
1136
- svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(v_block + (dim_tile * 8 + 3) * 32)));
1139
+ 0, query_predicate_b16x, predicate_all_b16x, probability_column_3_f32x,
1140
+ svld1_bf16(predicate_all_b16x, (bfloat16_t const *)(v_block + (dim_tile * 8 + 3) * 32)));
1137
1141
  svmopa_za32_bf16_m(
1138
- 0, query_predicate_f16x, predicate_all_f16x, probability_column_4_f32x,
1139
- svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(v_block + (dim_tile * 8 + 4) * 32)));
1142
+ 0, query_predicate_b16x, predicate_all_b16x, probability_column_4_f32x,
1143
+ svld1_bf16(predicate_all_b16x, (bfloat16_t const *)(v_block + (dim_tile * 8 + 4) * 32)));
1140
1144
  svmopa_za32_bf16_m(
1141
- 0, query_predicate_f16x, predicate_all_f16x, probability_column_5_f32x,
1142
- svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(v_block + (dim_tile * 8 + 5) * 32)));
1145
+ 0, query_predicate_b16x, predicate_all_b16x, probability_column_5_f32x,
1146
+ svld1_bf16(predicate_all_b16x, (bfloat16_t const *)(v_block + (dim_tile * 8 + 5) * 32)));
1143
1147
  svmopa_za32_bf16_m(
1144
- 0, query_predicate_f16x, predicate_all_f16x, probability_column_6_f32x,
1145
- svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(v_block + (dim_tile * 8 + 6) * 32)));
1148
+ 0, query_predicate_b16x, predicate_all_b16x, probability_column_6_f32x,
1149
+ svld1_bf16(predicate_all_b16x, (bfloat16_t const *)(v_block + (dim_tile * 8 + 6) * 32)));
1146
1150
  svmopa_za32_bf16_m(
1147
- 0, query_predicate_f16x, predicate_all_f16x, probability_column_7_f32x,
1148
- svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(v_block + (dim_tile * 8 + 7) * 32)));
1151
+ 0, query_predicate_b16x, predicate_all_b16x, probability_column_7_f32x,
1152
+ svld1_bf16(predicate_all_b16x, (bfloat16_t const *)(v_block + (dim_tile * 8 + 7) * 32)));
1149
1153
  for (nk_size_t query_index = 0; query_index < valid_query_count; query_index++)
1150
- svst1_f32(predicate_all_f32x, output_accumulator + query_index * head_dim_padded + dim_tile * 16,
1151
- svadd_f32_x(predicate_all_f32x,
1152
- svld1_f32(predicate_all_f32x,
1154
+ svst1_f32(predicate_all_b32x, output_accumulator + query_index * head_dim_padded + dim_tile * 16,
1155
+ svadd_f32_x(predicate_all_b32x,
1156
+ svld1_f32(predicate_all_b32x,
1153
1157
  output_accumulator + query_index * head_dim_padded + dim_tile * 16),
1154
- svread_hor_za32_f32_m(svdup_f32(0), predicate_all_f32x, 0, query_index)));
1158
+ svread_hor_za32_f32_m(svdup_f32(0), predicate_all_b32x, 0, query_index)));
1155
1159
  }
1156
1160
  }
1157
1161
  }
1158
1162
 
1159
1163
  // Final normalization
1160
1164
  NK_ALIGN64 nk_f32_t final_sums[16];
1161
- svst1_f32(predicate_all_f32x, final_sums, row_sum_f32x);
1165
+ svst1_f32(predicate_all_b32x, final_sums, row_sum_f32x);
1162
1166
 
1163
1167
  for (nk_size_t query_index = 0; query_index < valid_query_count; query_index++) {
1164
1168
  nk_f32_t inv_sum = (final_sums[query_index] > 0.0f) ? (1.0f / final_sums[query_index]) : 0.0f;
1165
1169
  svfloat32_t inv_sum_f32x = svdup_f32(inv_sum);
1166
1170
 
1167
1171
  for (nk_size_t dim_offset = 0; dim_offset < head_dim; dim_offset += svcntw()) {
1168
- svbool_t predicate_f32x = svwhilelt_b32_u64(dim_offset, head_dim);
1172
+ svbool_t predicate_b32x = svwhilelt_b32_u64(dim_offset, head_dim);
1169
1173
  svfloat32_t output_f32x = svmul_f32_x(
1170
- predicate_f32x,
1171
- svld1_f32(predicate_f32x, output_accumulator + query_index * head_dim_padded + dim_offset),
1174
+ predicate_b32x,
1175
+ svld1_f32(predicate_b32x, output_accumulator + query_index * head_dim_padded + dim_offset),
1172
1176
  inv_sum_f32x);
1173
- svbfloat16_t output_bf16x = nk_f32_to_bf16_sve_(predicate_f32x, output_f32x);
1177
+ svbfloat16_t output_bf16x = nk_f32_to_bf16_sve_(predicate_b32x, output_f32x);
1174
1178
  nk_size_t store_count = (head_dim - dim_offset) < (nk_size_t)svcntw() ? (head_dim - dim_offset)
1175
1179
  : (nk_size_t)svcntw();
1176
- svbool_t store_predicate_f16x = svwhilelt_b16_u64(0u, store_count);
1177
- svst1_bf16(store_predicate_f16x, (bfloat16_t *)(output + query_index * head_dim + dim_offset),
1180
+ svbool_t store_predicate_b16x = svwhilelt_b16_u64(0u, store_count);
1181
+ svst1_bf16(store_predicate_b16x, (bfloat16_t *)(output + query_index * head_dim + dim_offset),
1178
1182
  output_bf16x);
1179
1183
  }
1180
1184
  }
@@ -1220,24 +1224,24 @@ __arm_locally_streaming __arm_new("za") static void nk_attention_f16_sme_streami
1220
1224
  nk_size_t query_len, nk_size_t kv_len, nk_size_t head_dim, nk_size_t head_dim_padded, nk_size_t dim_tile_count,
1221
1225
  nk_f32_t scale) {
1222
1226
 
1223
- svbool_t const predicate_all_f32x = svptrue_b32();
1224
- svbool_t const predicate_all_f16x = svptrue_b16();
1227
+ svbool_t const predicate_all_b32x = svptrue_b32();
1228
+ svbool_t const predicate_all_b16x = svptrue_b16();
1225
1229
  nk_size_t const valid_query_count = (query_len < 16) ? query_len : 16;
1226
1230
 
1227
1231
  NK_ALIGN64 nk_f32_t row_max[16];
1228
1232
  NK_ALIGN64 nk_f32_t row_sum[16];
1229
1233
  NK_ALIGN64 nk_f32_t output_accumulator[16 * 256];
1230
1234
 
1231
- svst1_f32(predicate_all_f32x, row_max, svdup_f32(NK_F32_MIN));
1232
- svst1_f32(predicate_all_f32x, row_sum, svdup_f32(0.0f));
1235
+ svst1_f32(predicate_all_b32x, row_max, svdup_f32(NK_F32_MIN));
1236
+ svst1_f32(predicate_all_b32x, row_sum, svdup_f32(0.0f));
1233
1237
  svfloat32_t zero_f32x = svdup_f32(0.0f);
1234
1238
  for (nk_size_t i = 0; i < 16 * head_dim_padded; i += svcntw()) {
1235
- svst1_f32(predicate_all_f32x, output_accumulator + i, zero_f32x);
1239
+ svst1_f32(predicate_all_b32x, output_accumulator + i, zero_f32x);
1236
1240
  }
1237
1241
 
1238
1242
  nk_size_t kv_block_index = 0;
1239
1243
  nk_size_t kv_start = 0;
1240
- svbool_t const batch_predicate_f32x = svwhilelt_b32(0u, 16u);
1244
+ svbool_t const batch_predicate_b32x = svwhilelt_b32(0u, 16u);
1241
1245
 
1242
1246
  nk_size_t const k_depth_step_count = head_dim_padded / 2;
1243
1247
 
@@ -1248,11 +1252,11 @@ __arm_locally_streaming __arm_new("za") static void nk_attention_f16_sme_streami
1248
1252
  for (nk_size_t batch = 0; batch < head_dim_padded / 32; batch++) {
1249
1253
  svzero_mask_za(nk_sme_zero_za32_tile_0_);
1250
1254
  for (nk_size_t query_index = 0; query_index < valid_query_count; query_index++)
1251
- svld1_hor_za32(0, query_index, batch_predicate_f32x,
1255
+ svld1_hor_za32(0, query_index, batch_predicate_b32x,
1252
1256
  (nk_f32_t const *)(q + query_index * head_dim + batch * 32));
1253
1257
  for (nk_size_t step = 0; step < 16; step++)
1254
- svst1_f32(predicate_all_f32x, queries_transposed + (batch * 16 + step) * 16,
1255
- svread_ver_za32_f32_m(svdup_f32(0), predicate_all_f32x, 0, step));
1258
+ svst1_f32(predicate_all_b32x, queries_transposed + (batch * 16 + step) * 16,
1259
+ svread_ver_za32_f32_m(svdup_f32(0), predicate_all_b32x, 0, step));
1256
1260
  }
1257
1261
 
1258
1262
  // === Bc=32 main loop (prefill only, skipped for decode) ===
@@ -1261,14 +1265,15 @@ __arm_locally_streaming __arm_new("za") static void nk_attention_f16_sme_streami
1261
1265
  // Q×K^T: pure memory→FMOPA, no ZA staging for Q or K
1262
1266
  svzero_mask_za(nk_sme_zero_za32_tile_2_);
1263
1267
  svzero_mask_za(nk_sme_zero_za32_tile_3_);
1264
- nk_f16_t const *keys_block_lower = k + kv_block_index * k_depth_step_count * 32;
1265
- nk_f16_t const *keys_block_upper = k + (kv_block_index + 1) * k_depth_step_count * 32;
1268
+ nk_f16_t const *keys_block_low = k + kv_block_index * k_depth_step_count * 32;
1269
+ nk_f16_t const *keys_block_high = k + (kv_block_index + 1) * k_depth_step_count * 32;
1266
1270
  for (nk_size_t step = 0; step < k_depth_step_count; step++) {
1267
- svfloat16_t zn = svreinterpret_f16_f32(svld1_f32(predicate_all_f32x, queries_transposed + step * 16));
1268
- svfloat16_t zm0 = svld1_f16(predicate_all_f16x, (float16_t const *)(keys_block_lower + step * 32));
1269
- svfloat16_t zm1 = svld1_f16(predicate_all_f16x, (float16_t const *)(keys_block_upper + step * 32));
1270
- svmopa_za32_f16_m(2, predicate_all_f32x, predicate_all_f32x, zn, zm0);
1271
- svmopa_za32_f16_m(3, predicate_all_f32x, predicate_all_f32x, zn, zm1);
1271
+ svfloat16_t zn_f16x = svreinterpret_f16_f32(
1272
+ svld1_f32(predicate_all_b32x, queries_transposed + step * 16));
1273
+ svfloat16_t zm0_f16x = svld1_f16(predicate_all_b16x, (float16_t const *)(keys_block_low + step * 32));
1274
+ svfloat16_t zm1_f16x = svld1_f16(predicate_all_b16x, (float16_t const *)(keys_block_high + step * 32));
1275
+ svmopa_za32_f16_m(2, predicate_all_b32x, predicate_all_b32x, zn_f16x, zm0_f16x);
1276
+ svmopa_za32_f16_m(3, predicate_all_b32x, predicate_all_b32x, zn_f16x, zm1_f16x);
1272
1277
  }
1273
1278
  // ZA2 = scores[query_index][0:15], ZA3 = scores[query_index][16:31]
1274
1279
 
@@ -1277,29 +1282,29 @@ __arm_locally_streaming __arm_new("za") static void nk_attention_f16_sme_streami
1277
1282
  svfloat32_t block_max_f32x = svdup_f32(NK_F32_MIN);
1278
1283
  for (nk_size_t column_index = 0; column_index < 16; column_index++) {
1279
1284
  svfloat32_t score_column_f32x = svmul_f32_x(
1280
- predicate_all_f32x, svread_ver_za32_f32_m(svdup_f32(0), predicate_all_f32x, 2, column_index),
1285
+ predicate_all_b32x, svread_ver_za32_f32_m(svdup_f32(0), predicate_all_b32x, 2, column_index),
1281
1286
  scale_f32x);
1282
- block_max_f32x = svmax_f32_x(predicate_all_f32x, block_max_f32x, score_column_f32x);
1287
+ block_max_f32x = svmax_f32_x(predicate_all_b32x, block_max_f32x, score_column_f32x);
1283
1288
  }
1284
1289
  for (nk_size_t column_index = 0; column_index < 16; column_index++) {
1285
1290
  svfloat32_t score_column_f32x = svmul_f32_x(
1286
- predicate_all_f32x, svread_ver_za32_f32_m(svdup_f32(0), predicate_all_f32x, 3, column_index),
1291
+ predicate_all_b32x, svread_ver_za32_f32_m(svdup_f32(0), predicate_all_b32x, 3, column_index),
1287
1292
  scale_f32x);
1288
- block_max_f32x = svmax_f32_x(predicate_all_f32x, block_max_f32x, score_column_f32x);
1293
+ block_max_f32x = svmax_f32_x(predicate_all_b32x, block_max_f32x, score_column_f32x);
1289
1294
  }
1290
1295
 
1291
1296
  // Softmax correction (vectorized via array load/store)
1292
- svfloat32_t old_max_f32x = svld1_f32(predicate_all_f32x, row_max);
1293
- svfloat32_t new_max_f32x = svmax_f32_x(predicate_all_f32x, old_max_f32x, block_max_f32x);
1297
+ svfloat32_t old_max_f32x = svld1_f32(predicate_all_b32x, row_max);
1298
+ svfloat32_t new_max_f32x = svmax_f32_x(predicate_all_b32x, old_max_f32x, block_max_f32x);
1294
1299
  svfloat32_t correction_f32x = nk_exp_fast_f32_sve_(
1295
- predicate_all_f32x, svsub_f32_x(predicate_all_f32x, old_max_f32x, new_max_f32x));
1296
- svbool_t max_changed = svcmplt_f32(predicate_all_f32x, correction_f32x, svdup_f32(1.0f));
1297
- nk_u32_t max_was_updated = svptest_any(predicate_all_f32x, max_changed) ? 1 : 0;
1298
- svfloat32_t row_sum_corrected_f32x = svld1_f32(predicate_all_f32x, row_sum);
1300
+ predicate_all_b32x, svsub_f32_x(predicate_all_b32x, old_max_f32x, new_max_f32x));
1301
+ svbool_t max_changed_b32x = svcmplt_f32(predicate_all_b32x, correction_f32x, svdup_f32(1.0f));
1302
+ nk_u32_t max_was_updated = svptest_any(predicate_all_b32x, max_changed_b32x) ? 1 : 0;
1303
+ svfloat32_t row_sum_corrected_f32x = svld1_f32(predicate_all_b32x, row_sum);
1299
1304
  if (max_was_updated)
1300
- row_sum_corrected_f32x = svmul_f32_x(predicate_all_f32x, row_sum_corrected_f32x, correction_f32x);
1305
+ row_sum_corrected_f32x = svmul_f32_x(predicate_all_b32x, row_sum_corrected_f32x, correction_f32x);
1301
1306
  NK_ALIGN64 nk_f32_t corrections[16];
1302
- svst1_f32(predicate_all_f32x, corrections, correction_f32x);
1307
+ svst1_f32(predicate_all_b32x, corrections, correction_f32x);
1303
1308
 
1304
1309
  // Pass 2: Column-wise exp + fused P write + sum
1305
1310
  svfloat32_t sum_delta_f32x = svdup_f32(0.0f);
@@ -1307,92 +1312,92 @@ __arm_locally_streaming __arm_new("za") static void nk_attention_f16_sme_streami
1307
1312
  // ZA2 columns in pairs -> ZA0 columns 0-7
1308
1313
  for (nk_size_t column_index = 0; column_index < 16; column_index += 2) {
1309
1314
  svfloat32_t score_even_f32x = svmul_f32_x(
1310
- predicate_all_f32x, svread_ver_za32_f32_m(svdup_f32(0), predicate_all_f32x, 2, column_index),
1315
+ predicate_all_b32x, svread_ver_za32_f32_m(svdup_f32(0), predicate_all_b32x, 2, column_index),
1311
1316
  scale_f32x);
1312
1317
  svfloat32_t score_odd_f32x = svmul_f32_x(
1313
- predicate_all_f32x, svread_ver_za32_f32_m(svdup_f32(0), predicate_all_f32x, 2, column_index + 1),
1318
+ predicate_all_b32x, svread_ver_za32_f32_m(svdup_f32(0), predicate_all_b32x, 2, column_index + 1),
1314
1319
  scale_f32x);
1315
1320
  svfloat32_t weight_even_f32x = nk_exp_fast_f32_sve_(
1316
- predicate_all_f32x, svsub_f32_x(predicate_all_f32x, score_even_f32x, new_max_f32x));
1321
+ predicate_all_b32x, svsub_f32_x(predicate_all_b32x, score_even_f32x, new_max_f32x));
1317
1322
  svfloat32_t weight_odd_f32x = nk_exp_fast_f32_sve_(
1318
- predicate_all_f32x, svsub_f32_x(predicate_all_f32x, score_odd_f32x, new_max_f32x));
1319
- sum_delta_f32x = svadd_f32_x(predicate_all_f32x, sum_delta_f32x, weight_even_f32x);
1320
- sum_delta_f32x = svadd_f32_x(predicate_all_f32x, sum_delta_f32x, weight_odd_f32x);
1321
- svfloat16_t weight_pair_f16x = svzip1_f16(svcvt_f16_f32_x(predicate_all_f32x, weight_even_f32x),
1322
- svcvt_f16_f32_x(predicate_all_f32x, weight_odd_f32x));
1323
- svwrite_ver_za32_f32_m(0, column_index / 2, predicate_all_f32x,
1323
+ predicate_all_b32x, svsub_f32_x(predicate_all_b32x, score_odd_f32x, new_max_f32x));
1324
+ sum_delta_f32x = svadd_f32_x(predicate_all_b32x, sum_delta_f32x, weight_even_f32x);
1325
+ sum_delta_f32x = svadd_f32_x(predicate_all_b32x, sum_delta_f32x, weight_odd_f32x);
1326
+ svfloat16_t weight_pair_f16x = svzip1_f16(svcvt_f16_f32_x(predicate_all_b32x, weight_even_f32x),
1327
+ svcvt_f16_f32_x(predicate_all_b32x, weight_odd_f32x));
1328
+ svwrite_ver_za32_f32_m(0, column_index / 2, predicate_all_b32x,
1324
1329
  svreinterpret_f32_f16(weight_pair_f16x));
1325
1330
  }
1326
1331
  // ZA3 columns in pairs -> ZA0 columns 8-15
1327
1332
  for (nk_size_t column_index = 0; column_index < 16; column_index += 2) {
1328
1333
  svfloat32_t score_even_f32x = svmul_f32_x(
1329
- predicate_all_f32x, svread_ver_za32_f32_m(svdup_f32(0), predicate_all_f32x, 3, column_index),
1334
+ predicate_all_b32x, svread_ver_za32_f32_m(svdup_f32(0), predicate_all_b32x, 3, column_index),
1330
1335
  scale_f32x);
1331
1336
  svfloat32_t score_odd_f32x = svmul_f32_x(
1332
- predicate_all_f32x, svread_ver_za32_f32_m(svdup_f32(0), predicate_all_f32x, 3, column_index + 1),
1337
+ predicate_all_b32x, svread_ver_za32_f32_m(svdup_f32(0), predicate_all_b32x, 3, column_index + 1),
1333
1338
  scale_f32x);
1334
1339
  svfloat32_t weight_even_f32x = nk_exp_fast_f32_sve_(
1335
- predicate_all_f32x, svsub_f32_x(predicate_all_f32x, score_even_f32x, new_max_f32x));
1340
+ predicate_all_b32x, svsub_f32_x(predicate_all_b32x, score_even_f32x, new_max_f32x));
1336
1341
  svfloat32_t weight_odd_f32x = nk_exp_fast_f32_sve_(
1337
- predicate_all_f32x, svsub_f32_x(predicate_all_f32x, score_odd_f32x, new_max_f32x));
1338
- sum_delta_f32x = svadd_f32_x(predicate_all_f32x, sum_delta_f32x, weight_even_f32x);
1339
- sum_delta_f32x = svadd_f32_x(predicate_all_f32x, sum_delta_f32x, weight_odd_f32x);
1340
- svfloat16_t weight_pair_f16x = svzip1_f16(svcvt_f16_f32_x(predicate_all_f32x, weight_even_f32x),
1341
- svcvt_f16_f32_x(predicate_all_f32x, weight_odd_f32x));
1342
- svwrite_ver_za32_f32_m(0, 8 + column_index / 2, predicate_all_f32x,
1342
+ predicate_all_b32x, svsub_f32_x(predicate_all_b32x, score_odd_f32x, new_max_f32x));
1343
+ sum_delta_f32x = svadd_f32_x(predicate_all_b32x, sum_delta_f32x, weight_even_f32x);
1344
+ sum_delta_f32x = svadd_f32_x(predicate_all_b32x, sum_delta_f32x, weight_odd_f32x);
1345
+ svfloat16_t weight_pair_f16x = svzip1_f16(svcvt_f16_f32_x(predicate_all_b32x, weight_even_f32x),
1346
+ svcvt_f16_f32_x(predicate_all_b32x, weight_odd_f32x));
1347
+ svwrite_ver_za32_f32_m(0, 8 + column_index / 2, predicate_all_b32x,
1343
1348
  svreinterpret_f32_f16(weight_pair_f16x));
1344
1349
  }
1345
- row_sum_corrected_f32x = svadd_f32_x(predicate_all_f32x, row_sum_corrected_f32x, sum_delta_f32x);
1346
- svst1_f32(predicate_all_f32x, row_sum, row_sum_corrected_f32x);
1347
- svst1_f32(predicate_all_f32x, row_max, new_max_f32x);
1350
+ row_sum_corrected_f32x = svadd_f32_x(predicate_all_b32x, row_sum_corrected_f32x, sum_delta_f32x);
1351
+ svst1_f32(predicate_all_b32x, row_sum, row_sum_corrected_f32x);
1352
+ svst1_f32(predicate_all_b32x, row_max, new_max_f32x);
1348
1353
 
1349
1354
  // Extract P columns from ZA0
1350
1355
  svfloat16_t probability_column_0_f32x = svreinterpret_f16_f32(
1351
- svread_ver_za32_f32_m(svdup_f32(0), predicate_all_f32x, 0, 0));
1356
+ svread_ver_za32_f32_m(svdup_f32(0), predicate_all_b32x, 0, 0));
1352
1357
  svfloat16_t probability_column_1_f32x = svreinterpret_f16_f32(
1353
- svread_ver_za32_f32_m(svdup_f32(0), predicate_all_f32x, 0, 1));
1358
+ svread_ver_za32_f32_m(svdup_f32(0), predicate_all_b32x, 0, 1));
1354
1359
  svfloat16_t probability_column_2_f32x = svreinterpret_f16_f32(
1355
- svread_ver_za32_f32_m(svdup_f32(0), predicate_all_f32x, 0, 2));
1360
+ svread_ver_za32_f32_m(svdup_f32(0), predicate_all_b32x, 0, 2));
1356
1361
  svfloat16_t probability_column_3_f32x = svreinterpret_f16_f32(
1357
- svread_ver_za32_f32_m(svdup_f32(0), predicate_all_f32x, 0, 3));
1362
+ svread_ver_za32_f32_m(svdup_f32(0), predicate_all_b32x, 0, 3));
1358
1363
  svfloat16_t probability_column_4_f32x = svreinterpret_f16_f32(
1359
- svread_ver_za32_f32_m(svdup_f32(0), predicate_all_f32x, 0, 4));
1364
+ svread_ver_za32_f32_m(svdup_f32(0), predicate_all_b32x, 0, 4));
1360
1365
  svfloat16_t probability_column_5_f32x = svreinterpret_f16_f32(
1361
- svread_ver_za32_f32_m(svdup_f32(0), predicate_all_f32x, 0, 5));
1366
+ svread_ver_za32_f32_m(svdup_f32(0), predicate_all_b32x, 0, 5));
1362
1367
  svfloat16_t probability_column_6_f32x = svreinterpret_f16_f32(
1363
- svread_ver_za32_f32_m(svdup_f32(0), predicate_all_f32x, 0, 6));
1368
+ svread_ver_za32_f32_m(svdup_f32(0), predicate_all_b32x, 0, 6));
1364
1369
  svfloat16_t probability_column_7_f32x = svreinterpret_f16_f32(
1365
- svread_ver_za32_f32_m(svdup_f32(0), predicate_all_f32x, 0, 7));
1370
+ svread_ver_za32_f32_m(svdup_f32(0), predicate_all_b32x, 0, 7));
1366
1371
  svfloat16_t probability_column_8_f32x = svreinterpret_f16_f32(
1367
- svread_ver_za32_f32_m(svdup_f32(0), predicate_all_f32x, 0, 8));
1372
+ svread_ver_za32_f32_m(svdup_f32(0), predicate_all_b32x, 0, 8));
1368
1373
  svfloat16_t probability_column_9_f32x = svreinterpret_f16_f32(
1369
- svread_ver_za32_f32_m(svdup_f32(0), predicate_all_f32x, 0, 9));
1374
+ svread_ver_za32_f32_m(svdup_f32(0), predicate_all_b32x, 0, 9));
1370
1375
  svfloat16_t probability_column_10_f32x = svreinterpret_f16_f32(
1371
- svread_ver_za32_f32_m(svdup_f32(0), predicate_all_f32x, 0, 10));
1376
+ svread_ver_za32_f32_m(svdup_f32(0), predicate_all_b32x, 0, 10));
1372
1377
  svfloat16_t probability_column_11_f32x = svreinterpret_f16_f32(
1373
- svread_ver_za32_f32_m(svdup_f32(0), predicate_all_f32x, 0, 11));
1378
+ svread_ver_za32_f32_m(svdup_f32(0), predicate_all_b32x, 0, 11));
1374
1379
  svfloat16_t probability_column_12_f32x = svreinterpret_f16_f32(
1375
- svread_ver_za32_f32_m(svdup_f32(0), predicate_all_f32x, 0, 12));
1380
+ svread_ver_za32_f32_m(svdup_f32(0), predicate_all_b32x, 0, 12));
1376
1381
  svfloat16_t probability_column_13_f32x = svreinterpret_f16_f32(
1377
- svread_ver_za32_f32_m(svdup_f32(0), predicate_all_f32x, 0, 13));
1382
+ svread_ver_za32_f32_m(svdup_f32(0), predicate_all_b32x, 0, 13));
1378
1383
  svfloat16_t probability_column_14_f32x = svreinterpret_f16_f32(
1379
- svread_ver_za32_f32_m(svdup_f32(0), predicate_all_f32x, 0, 14));
1384
+ svread_ver_za32_f32_m(svdup_f32(0), predicate_all_b32x, 0, 14));
1380
1385
  svfloat16_t probability_column_15_f32x = svreinterpret_f16_f32(
1381
- svread_ver_za32_f32_m(svdup_f32(0), predicate_all_f32x, 0, 15));
1386
+ svread_ver_za32_f32_m(svdup_f32(0), predicate_all_b32x, 0, 15));
1382
1387
 
1383
1388
  // Pre-apply correction once before P×V
1384
- svbool_t query_predicate_f16x = svwhilelt_b16_u64(0u, valid_query_count * 2);
1385
- nk_f16_t const *values_block_lower = v_packed + kv_block_index * dim_tile_count * 8 * 32;
1386
- nk_f16_t const *values_block_upper = v_packed + (kv_block_index + 1) * dim_tile_count * 8 * 32;
1389
+ svbool_t query_predicate_b16x = svwhilelt_b16_u64(0u, valid_query_count * 2);
1390
+ nk_f16_t const *values_block_low = v_packed + kv_block_index * dim_tile_count * 8 * 32;
1391
+ nk_f16_t const *values_block_high = v_packed + (kv_block_index + 1) * dim_tile_count * 8 * 32;
1387
1392
 
1388
1393
  if (max_was_updated) {
1389
1394
  for (nk_size_t query_index = 0; query_index < valid_query_count; query_index++) {
1390
1395
  svfloat32_t correction_scalar_f32x = svdup_f32(corrections[query_index]);
1391
1396
  for (nk_size_t dim_offset = 0; dim_offset < head_dim_padded; dim_offset += 16)
1392
1397
  svst1_f32(
1393
- predicate_all_f32x, output_accumulator + query_index * head_dim_padded + dim_offset,
1394
- svmul_f32_x(predicate_all_f32x,
1395
- svld1_f32(predicate_all_f32x,
1398
+ predicate_all_b32x, output_accumulator + query_index * head_dim_padded + dim_offset,
1399
+ svmul_f32_x(predicate_all_b32x,
1400
+ svld1_f32(predicate_all_b32x,
1396
1401
  output_accumulator + query_index * head_dim_padded + dim_offset),
1397
1402
  correction_scalar_f32x));
1398
1403
  }
@@ -1403,284 +1408,284 @@ __arm_locally_streaming __arm_new("za") static void nk_attention_f16_sme_streami
1403
1408
  for (; dim_tile + 4 <= dim_tile_count; dim_tile += 4) {
1404
1409
  svzero_za();
1405
1410
  // Block0: 8 depth steps (KV positions 0-15)
1406
- svmopa_za32_f16_m(0, query_predicate_f16x, predicate_all_f16x, probability_column_0_f32x,
1407
- svld1_f16(predicate_all_f16x,
1408
- (float16_t const *)(values_block_lower + ((dim_tile + 0) * 8 + 0) * 32)));
1409
- svmopa_za32_f16_m(1, query_predicate_f16x, predicate_all_f16x, probability_column_0_f32x,
1410
- svld1_f16(predicate_all_f16x,
1411
- (float16_t const *)(values_block_lower + ((dim_tile + 1) * 8 + 0) * 32)));
1412
- svmopa_za32_f16_m(2, query_predicate_f16x, predicate_all_f16x, probability_column_0_f32x,
1413
- svld1_f16(predicate_all_f16x,
1414
- (float16_t const *)(values_block_lower + ((dim_tile + 2) * 8 + 0) * 32)));
1415
- svmopa_za32_f16_m(3, query_predicate_f16x, predicate_all_f16x, probability_column_0_f32x,
1416
- svld1_f16(predicate_all_f16x,
1417
- (float16_t const *)(values_block_lower + ((dim_tile + 3) * 8 + 0) * 32)));
1418
- svmopa_za32_f16_m(0, query_predicate_f16x, predicate_all_f16x, probability_column_1_f32x,
1419
- svld1_f16(predicate_all_f16x,
1420
- (float16_t const *)(values_block_lower + ((dim_tile + 0) * 8 + 1) * 32)));
1421
- svmopa_za32_f16_m(1, query_predicate_f16x, predicate_all_f16x, probability_column_1_f32x,
1422
- svld1_f16(predicate_all_f16x,
1423
- (float16_t const *)(values_block_lower + ((dim_tile + 1) * 8 + 1) * 32)));
1424
- svmopa_za32_f16_m(2, query_predicate_f16x, predicate_all_f16x, probability_column_1_f32x,
1425
- svld1_f16(predicate_all_f16x,
1426
- (float16_t const *)(values_block_lower + ((dim_tile + 2) * 8 + 1) * 32)));
1427
- svmopa_za32_f16_m(3, query_predicate_f16x, predicate_all_f16x, probability_column_1_f32x,
1428
- svld1_f16(predicate_all_f16x,
1429
- (float16_t const *)(values_block_lower + ((dim_tile + 3) * 8 + 1) * 32)));
1430
- svmopa_za32_f16_m(0, query_predicate_f16x, predicate_all_f16x, probability_column_2_f32x,
1431
- svld1_f16(predicate_all_f16x,
1432
- (float16_t const *)(values_block_lower + ((dim_tile + 0) * 8 + 2) * 32)));
1433
- svmopa_za32_f16_m(1, query_predicate_f16x, predicate_all_f16x, probability_column_2_f32x,
1434
- svld1_f16(predicate_all_f16x,
1435
- (float16_t const *)(values_block_lower + ((dim_tile + 1) * 8 + 2) * 32)));
1436
- svmopa_za32_f16_m(2, query_predicate_f16x, predicate_all_f16x, probability_column_2_f32x,
1437
- svld1_f16(predicate_all_f16x,
1438
- (float16_t const *)(values_block_lower + ((dim_tile + 2) * 8 + 2) * 32)));
1439
- svmopa_za32_f16_m(3, query_predicate_f16x, predicate_all_f16x, probability_column_2_f32x,
1440
- svld1_f16(predicate_all_f16x,
1441
- (float16_t const *)(values_block_lower + ((dim_tile + 3) * 8 + 2) * 32)));
1442
- svmopa_za32_f16_m(0, query_predicate_f16x, predicate_all_f16x, probability_column_3_f32x,
1443
- svld1_f16(predicate_all_f16x,
1444
- (float16_t const *)(values_block_lower + ((dim_tile + 0) * 8 + 3) * 32)));
1445
- svmopa_za32_f16_m(1, query_predicate_f16x, predicate_all_f16x, probability_column_3_f32x,
1446
- svld1_f16(predicate_all_f16x,
1447
- (float16_t const *)(values_block_lower + ((dim_tile + 1) * 8 + 3) * 32)));
1448
- svmopa_za32_f16_m(2, query_predicate_f16x, predicate_all_f16x, probability_column_3_f32x,
1449
- svld1_f16(predicate_all_f16x,
1450
- (float16_t const *)(values_block_lower + ((dim_tile + 2) * 8 + 3) * 32)));
1451
- svmopa_za32_f16_m(3, query_predicate_f16x, predicate_all_f16x, probability_column_3_f32x,
1452
- svld1_f16(predicate_all_f16x,
1453
- (float16_t const *)(values_block_lower + ((dim_tile + 3) * 8 + 3) * 32)));
1454
- svmopa_za32_f16_m(0, query_predicate_f16x, predicate_all_f16x, probability_column_4_f32x,
1455
- svld1_f16(predicate_all_f16x,
1456
- (float16_t const *)(values_block_lower + ((dim_tile + 0) * 8 + 4) * 32)));
1457
- svmopa_za32_f16_m(1, query_predicate_f16x, predicate_all_f16x, probability_column_4_f32x,
1458
- svld1_f16(predicate_all_f16x,
1459
- (float16_t const *)(values_block_lower + ((dim_tile + 1) * 8 + 4) * 32)));
1460
- svmopa_za32_f16_m(2, query_predicate_f16x, predicate_all_f16x, probability_column_4_f32x,
1461
- svld1_f16(predicate_all_f16x,
1462
- (float16_t const *)(values_block_lower + ((dim_tile + 2) * 8 + 4) * 32)));
1463
- svmopa_za32_f16_m(3, query_predicate_f16x, predicate_all_f16x, probability_column_4_f32x,
1464
- svld1_f16(predicate_all_f16x,
1465
- (float16_t const *)(values_block_lower + ((dim_tile + 3) * 8 + 4) * 32)));
1466
- svmopa_za32_f16_m(0, query_predicate_f16x, predicate_all_f16x, probability_column_5_f32x,
1467
- svld1_f16(predicate_all_f16x,
1468
- (float16_t const *)(values_block_lower + ((dim_tile + 0) * 8 + 5) * 32)));
1469
- svmopa_za32_f16_m(1, query_predicate_f16x, predicate_all_f16x, probability_column_5_f32x,
1470
- svld1_f16(predicate_all_f16x,
1471
- (float16_t const *)(values_block_lower + ((dim_tile + 1) * 8 + 5) * 32)));
1472
- svmopa_za32_f16_m(2, query_predicate_f16x, predicate_all_f16x, probability_column_5_f32x,
1473
- svld1_f16(predicate_all_f16x,
1474
- (float16_t const *)(values_block_lower + ((dim_tile + 2) * 8 + 5) * 32)));
1475
- svmopa_za32_f16_m(3, query_predicate_f16x, predicate_all_f16x, probability_column_5_f32x,
1476
- svld1_f16(predicate_all_f16x,
1477
- (float16_t const *)(values_block_lower + ((dim_tile + 3) * 8 + 5) * 32)));
1478
- svmopa_za32_f16_m(0, query_predicate_f16x, predicate_all_f16x, probability_column_6_f32x,
1479
- svld1_f16(predicate_all_f16x,
1480
- (float16_t const *)(values_block_lower + ((dim_tile + 0) * 8 + 6) * 32)));
1481
- svmopa_za32_f16_m(1, query_predicate_f16x, predicate_all_f16x, probability_column_6_f32x,
1482
- svld1_f16(predicate_all_f16x,
1483
- (float16_t const *)(values_block_lower + ((dim_tile + 1) * 8 + 6) * 32)));
1484
- svmopa_za32_f16_m(2, query_predicate_f16x, predicate_all_f16x, probability_column_6_f32x,
1485
- svld1_f16(predicate_all_f16x,
1486
- (float16_t const *)(values_block_lower + ((dim_tile + 2) * 8 + 6) * 32)));
1487
- svmopa_za32_f16_m(3, query_predicate_f16x, predicate_all_f16x, probability_column_6_f32x,
1488
- svld1_f16(predicate_all_f16x,
1489
- (float16_t const *)(values_block_lower + ((dim_tile + 3) * 8 + 6) * 32)));
1490
- svmopa_za32_f16_m(0, query_predicate_f16x, predicate_all_f16x, probability_column_7_f32x,
1491
- svld1_f16(predicate_all_f16x,
1492
- (float16_t const *)(values_block_lower + ((dim_tile + 0) * 8 + 7) * 32)));
1493
- svmopa_za32_f16_m(1, query_predicate_f16x, predicate_all_f16x, probability_column_7_f32x,
1494
- svld1_f16(predicate_all_f16x,
1495
- (float16_t const *)(values_block_lower + ((dim_tile + 1) * 8 + 7) * 32)));
1496
- svmopa_za32_f16_m(2, query_predicate_f16x, predicate_all_f16x, probability_column_7_f32x,
1497
- svld1_f16(predicate_all_f16x,
1498
- (float16_t const *)(values_block_lower + ((dim_tile + 2) * 8 + 7) * 32)));
1499
- svmopa_za32_f16_m(3, query_predicate_f16x, predicate_all_f16x, probability_column_7_f32x,
1500
- svld1_f16(predicate_all_f16x,
1501
- (float16_t const *)(values_block_lower + ((dim_tile + 3) * 8 + 7) * 32)));
1411
+ svmopa_za32_f16_m(0, query_predicate_b16x, predicate_all_b16x, probability_column_0_f32x,
1412
+ svld1_f16(predicate_all_b16x,
1413
+ (float16_t const *)(values_block_low + ((dim_tile + 0) * 8 + 0) * 32)));
1414
+ svmopa_za32_f16_m(1, query_predicate_b16x, predicate_all_b16x, probability_column_0_f32x,
1415
+ svld1_f16(predicate_all_b16x,
1416
+ (float16_t const *)(values_block_low + ((dim_tile + 1) * 8 + 0) * 32)));
1417
+ svmopa_za32_f16_m(2, query_predicate_b16x, predicate_all_b16x, probability_column_0_f32x,
1418
+ svld1_f16(predicate_all_b16x,
1419
+ (float16_t const *)(values_block_low + ((dim_tile + 2) * 8 + 0) * 32)));
1420
+ svmopa_za32_f16_m(3, query_predicate_b16x, predicate_all_b16x, probability_column_0_f32x,
1421
+ svld1_f16(predicate_all_b16x,
1422
+ (float16_t const *)(values_block_low + ((dim_tile + 3) * 8 + 0) * 32)));
1423
+ svmopa_za32_f16_m(0, query_predicate_b16x, predicate_all_b16x, probability_column_1_f32x,
1424
+ svld1_f16(predicate_all_b16x,
1425
+ (float16_t const *)(values_block_low + ((dim_tile + 0) * 8 + 1) * 32)));
1426
+ svmopa_za32_f16_m(1, query_predicate_b16x, predicate_all_b16x, probability_column_1_f32x,
1427
+ svld1_f16(predicate_all_b16x,
1428
+ (float16_t const *)(values_block_low + ((dim_tile + 1) * 8 + 1) * 32)));
1429
+ svmopa_za32_f16_m(2, query_predicate_b16x, predicate_all_b16x, probability_column_1_f32x,
1430
+ svld1_f16(predicate_all_b16x,
1431
+ (float16_t const *)(values_block_low + ((dim_tile + 2) * 8 + 1) * 32)));
1432
+ svmopa_za32_f16_m(3, query_predicate_b16x, predicate_all_b16x, probability_column_1_f32x,
1433
+ svld1_f16(predicate_all_b16x,
1434
+ (float16_t const *)(values_block_low + ((dim_tile + 3) * 8 + 1) * 32)));
1435
+ svmopa_za32_f16_m(0, query_predicate_b16x, predicate_all_b16x, probability_column_2_f32x,
1436
+ svld1_f16(predicate_all_b16x,
1437
+ (float16_t const *)(values_block_low + ((dim_tile + 0) * 8 + 2) * 32)));
1438
+ svmopa_za32_f16_m(1, query_predicate_b16x, predicate_all_b16x, probability_column_2_f32x,
1439
+ svld1_f16(predicate_all_b16x,
1440
+ (float16_t const *)(values_block_low + ((dim_tile + 1) * 8 + 2) * 32)));
1441
+ svmopa_za32_f16_m(2, query_predicate_b16x, predicate_all_b16x, probability_column_2_f32x,
1442
+ svld1_f16(predicate_all_b16x,
1443
+ (float16_t const *)(values_block_low + ((dim_tile + 2) * 8 + 2) * 32)));
1444
+ svmopa_za32_f16_m(3, query_predicate_b16x, predicate_all_b16x, probability_column_2_f32x,
1445
+ svld1_f16(predicate_all_b16x,
1446
+ (float16_t const *)(values_block_low + ((dim_tile + 3) * 8 + 2) * 32)));
1447
+ svmopa_za32_f16_m(0, query_predicate_b16x, predicate_all_b16x, probability_column_3_f32x,
1448
+ svld1_f16(predicate_all_b16x,
1449
+ (float16_t const *)(values_block_low + ((dim_tile + 0) * 8 + 3) * 32)));
1450
+ svmopa_za32_f16_m(1, query_predicate_b16x, predicate_all_b16x, probability_column_3_f32x,
1451
+ svld1_f16(predicate_all_b16x,
1452
+ (float16_t const *)(values_block_low + ((dim_tile + 1) * 8 + 3) * 32)));
1453
+ svmopa_za32_f16_m(2, query_predicate_b16x, predicate_all_b16x, probability_column_3_f32x,
1454
+ svld1_f16(predicate_all_b16x,
1455
+ (float16_t const *)(values_block_low + ((dim_tile + 2) * 8 + 3) * 32)));
1456
+ svmopa_za32_f16_m(3, query_predicate_b16x, predicate_all_b16x, probability_column_3_f32x,
1457
+ svld1_f16(predicate_all_b16x,
1458
+ (float16_t const *)(values_block_low + ((dim_tile + 3) * 8 + 3) * 32)));
1459
+ svmopa_za32_f16_m(0, query_predicate_b16x, predicate_all_b16x, probability_column_4_f32x,
1460
+ svld1_f16(predicate_all_b16x,
1461
+ (float16_t const *)(values_block_low + ((dim_tile + 0) * 8 + 4) * 32)));
1462
+ svmopa_za32_f16_m(1, query_predicate_b16x, predicate_all_b16x, probability_column_4_f32x,
1463
+ svld1_f16(predicate_all_b16x,
1464
+ (float16_t const *)(values_block_low + ((dim_tile + 1) * 8 + 4) * 32)));
1465
+ svmopa_za32_f16_m(2, query_predicate_b16x, predicate_all_b16x, probability_column_4_f32x,
1466
+ svld1_f16(predicate_all_b16x,
1467
+ (float16_t const *)(values_block_low + ((dim_tile + 2) * 8 + 4) * 32)));
1468
+ svmopa_za32_f16_m(3, query_predicate_b16x, predicate_all_b16x, probability_column_4_f32x,
1469
+ svld1_f16(predicate_all_b16x,
1470
+ (float16_t const *)(values_block_low + ((dim_tile + 3) * 8 + 4) * 32)));
1471
+ svmopa_za32_f16_m(0, query_predicate_b16x, predicate_all_b16x, probability_column_5_f32x,
1472
+ svld1_f16(predicate_all_b16x,
1473
+ (float16_t const *)(values_block_low + ((dim_tile + 0) * 8 + 5) * 32)));
1474
+ svmopa_za32_f16_m(1, query_predicate_b16x, predicate_all_b16x, probability_column_5_f32x,
1475
+ svld1_f16(predicate_all_b16x,
1476
+ (float16_t const *)(values_block_low + ((dim_tile + 1) * 8 + 5) * 32)));
1477
+ svmopa_za32_f16_m(2, query_predicate_b16x, predicate_all_b16x, probability_column_5_f32x,
1478
+ svld1_f16(predicate_all_b16x,
1479
+ (float16_t const *)(values_block_low + ((dim_tile + 2) * 8 + 5) * 32)));
1480
+ svmopa_za32_f16_m(3, query_predicate_b16x, predicate_all_b16x, probability_column_5_f32x,
1481
+ svld1_f16(predicate_all_b16x,
1482
+ (float16_t const *)(values_block_low + ((dim_tile + 3) * 8 + 5) * 32)));
1483
+ svmopa_za32_f16_m(0, query_predicate_b16x, predicate_all_b16x, probability_column_6_f32x,
1484
+ svld1_f16(predicate_all_b16x,
1485
+ (float16_t const *)(values_block_low + ((dim_tile + 0) * 8 + 6) * 32)));
1486
+ svmopa_za32_f16_m(1, query_predicate_b16x, predicate_all_b16x, probability_column_6_f32x,
1487
+ svld1_f16(predicate_all_b16x,
1488
+ (float16_t const *)(values_block_low + ((dim_tile + 1) * 8 + 6) * 32)));
1489
+ svmopa_za32_f16_m(2, query_predicate_b16x, predicate_all_b16x, probability_column_6_f32x,
1490
+ svld1_f16(predicate_all_b16x,
1491
+ (float16_t const *)(values_block_low + ((dim_tile + 2) * 8 + 6) * 32)));
1492
+ svmopa_za32_f16_m(3, query_predicate_b16x, predicate_all_b16x, probability_column_6_f32x,
1493
+ svld1_f16(predicate_all_b16x,
1494
+ (float16_t const *)(values_block_low + ((dim_tile + 3) * 8 + 6) * 32)));
1495
+ svmopa_za32_f16_m(0, query_predicate_b16x, predicate_all_b16x, probability_column_7_f32x,
1496
+ svld1_f16(predicate_all_b16x,
1497
+ (float16_t const *)(values_block_low + ((dim_tile + 0) * 8 + 7) * 32)));
1498
+ svmopa_za32_f16_m(1, query_predicate_b16x, predicate_all_b16x, probability_column_7_f32x,
1499
+ svld1_f16(predicate_all_b16x,
1500
+ (float16_t const *)(values_block_low + ((dim_tile + 1) * 8 + 7) * 32)));
1501
+ svmopa_za32_f16_m(2, query_predicate_b16x, predicate_all_b16x, probability_column_7_f32x,
1502
+ svld1_f16(predicate_all_b16x,
1503
+ (float16_t const *)(values_block_low + ((dim_tile + 2) * 8 + 7) * 32)));
1504
+ svmopa_za32_f16_m(3, query_predicate_b16x, predicate_all_b16x, probability_column_7_f32x,
1505
+ svld1_f16(predicate_all_b16x,
1506
+ (float16_t const *)(values_block_low + ((dim_tile + 3) * 8 + 7) * 32)));
1502
1507
  // Block1: 8 depth steps (KV positions 16-31)
1503
- svmopa_za32_f16_m(0, query_predicate_f16x, predicate_all_f16x, probability_column_8_f32x,
1504
- svld1_f16(predicate_all_f16x,
1505
- (float16_t const *)(values_block_upper + ((dim_tile + 0) * 8 + 0) * 32)));
1506
- svmopa_za32_f16_m(1, query_predicate_f16x, predicate_all_f16x, probability_column_8_f32x,
1507
- svld1_f16(predicate_all_f16x,
1508
- (float16_t const *)(values_block_upper + ((dim_tile + 1) * 8 + 0) * 32)));
1509
- svmopa_za32_f16_m(2, query_predicate_f16x, predicate_all_f16x, probability_column_8_f32x,
1510
- svld1_f16(predicate_all_f16x,
1511
- (float16_t const *)(values_block_upper + ((dim_tile + 2) * 8 + 0) * 32)));
1512
- svmopa_za32_f16_m(3, query_predicate_f16x, predicate_all_f16x, probability_column_8_f32x,
1513
- svld1_f16(predicate_all_f16x,
1514
- (float16_t const *)(values_block_upper + ((dim_tile + 3) * 8 + 0) * 32)));
1515
- svmopa_za32_f16_m(0, query_predicate_f16x, predicate_all_f16x, probability_column_9_f32x,
1516
- svld1_f16(predicate_all_f16x,
1517
- (float16_t const *)(values_block_upper + ((dim_tile + 0) * 8 + 1) * 32)));
1518
- svmopa_za32_f16_m(1, query_predicate_f16x, predicate_all_f16x, probability_column_9_f32x,
1519
- svld1_f16(predicate_all_f16x,
1520
- (float16_t const *)(values_block_upper + ((dim_tile + 1) * 8 + 1) * 32)));
1521
- svmopa_za32_f16_m(2, query_predicate_f16x, predicate_all_f16x, probability_column_9_f32x,
1522
- svld1_f16(predicate_all_f16x,
1523
- (float16_t const *)(values_block_upper + ((dim_tile + 2) * 8 + 1) * 32)));
1524
- svmopa_za32_f16_m(3, query_predicate_f16x, predicate_all_f16x, probability_column_9_f32x,
1525
- svld1_f16(predicate_all_f16x,
1526
- (float16_t const *)(values_block_upper + ((dim_tile + 3) * 8 + 1) * 32)));
1527
- svmopa_za32_f16_m(0, query_predicate_f16x, predicate_all_f16x, probability_column_10_f32x,
1528
- svld1_f16(predicate_all_f16x,
1529
- (float16_t const *)(values_block_upper + ((dim_tile + 0) * 8 + 2) * 32)));
1530
- svmopa_za32_f16_m(1, query_predicate_f16x, predicate_all_f16x, probability_column_10_f32x,
1531
- svld1_f16(predicate_all_f16x,
1532
- (float16_t const *)(values_block_upper + ((dim_tile + 1) * 8 + 2) * 32)));
1533
- svmopa_za32_f16_m(2, query_predicate_f16x, predicate_all_f16x, probability_column_10_f32x,
1534
- svld1_f16(predicate_all_f16x,
1535
- (float16_t const *)(values_block_upper + ((dim_tile + 2) * 8 + 2) * 32)));
1536
- svmopa_za32_f16_m(3, query_predicate_f16x, predicate_all_f16x, probability_column_10_f32x,
1537
- svld1_f16(predicate_all_f16x,
1538
- (float16_t const *)(values_block_upper + ((dim_tile + 3) * 8 + 2) * 32)));
1539
- svmopa_za32_f16_m(0, query_predicate_f16x, predicate_all_f16x, probability_column_11_f32x,
1540
- svld1_f16(predicate_all_f16x,
1541
- (float16_t const *)(values_block_upper + ((dim_tile + 0) * 8 + 3) * 32)));
1542
- svmopa_za32_f16_m(1, query_predicate_f16x, predicate_all_f16x, probability_column_11_f32x,
1543
- svld1_f16(predicate_all_f16x,
1544
- (float16_t const *)(values_block_upper + ((dim_tile + 1) * 8 + 3) * 32)));
1545
- svmopa_za32_f16_m(2, query_predicate_f16x, predicate_all_f16x, probability_column_11_f32x,
1546
- svld1_f16(predicate_all_f16x,
1547
- (float16_t const *)(values_block_upper + ((dim_tile + 2) * 8 + 3) * 32)));
1548
- svmopa_za32_f16_m(3, query_predicate_f16x, predicate_all_f16x, probability_column_11_f32x,
1549
- svld1_f16(predicate_all_f16x,
1550
- (float16_t const *)(values_block_upper + ((dim_tile + 3) * 8 + 3) * 32)));
1551
- svmopa_za32_f16_m(0, query_predicate_f16x, predicate_all_f16x, probability_column_12_f32x,
1552
- svld1_f16(predicate_all_f16x,
1553
- (float16_t const *)(values_block_upper + ((dim_tile + 0) * 8 + 4) * 32)));
1554
- svmopa_za32_f16_m(1, query_predicate_f16x, predicate_all_f16x, probability_column_12_f32x,
1555
- svld1_f16(predicate_all_f16x,
1556
- (float16_t const *)(values_block_upper + ((dim_tile + 1) * 8 + 4) * 32)));
1557
- svmopa_za32_f16_m(2, query_predicate_f16x, predicate_all_f16x, probability_column_12_f32x,
1558
- svld1_f16(predicate_all_f16x,
1559
- (float16_t const *)(values_block_upper + ((dim_tile + 2) * 8 + 4) * 32)));
1560
- svmopa_za32_f16_m(3, query_predicate_f16x, predicate_all_f16x, probability_column_12_f32x,
1561
- svld1_f16(predicate_all_f16x,
1562
- (float16_t const *)(values_block_upper + ((dim_tile + 3) * 8 + 4) * 32)));
1563
- svmopa_za32_f16_m(0, query_predicate_f16x, predicate_all_f16x, probability_column_13_f32x,
1564
- svld1_f16(predicate_all_f16x,
1565
- (float16_t const *)(values_block_upper + ((dim_tile + 0) * 8 + 5) * 32)));
1566
- svmopa_za32_f16_m(1, query_predicate_f16x, predicate_all_f16x, probability_column_13_f32x,
1567
- svld1_f16(predicate_all_f16x,
1568
- (float16_t const *)(values_block_upper + ((dim_tile + 1) * 8 + 5) * 32)));
1569
- svmopa_za32_f16_m(2, query_predicate_f16x, predicate_all_f16x, probability_column_13_f32x,
1570
- svld1_f16(predicate_all_f16x,
1571
- (float16_t const *)(values_block_upper + ((dim_tile + 2) * 8 + 5) * 32)));
1572
- svmopa_za32_f16_m(3, query_predicate_f16x, predicate_all_f16x, probability_column_13_f32x,
1573
- svld1_f16(predicate_all_f16x,
1574
- (float16_t const *)(values_block_upper + ((dim_tile + 3) * 8 + 5) * 32)));
1575
- svmopa_za32_f16_m(0, query_predicate_f16x, predicate_all_f16x, probability_column_14_f32x,
1576
- svld1_f16(predicate_all_f16x,
1577
- (float16_t const *)(values_block_upper + ((dim_tile + 0) * 8 + 6) * 32)));
1578
- svmopa_za32_f16_m(1, query_predicate_f16x, predicate_all_f16x, probability_column_14_f32x,
1579
- svld1_f16(predicate_all_f16x,
1580
- (float16_t const *)(values_block_upper + ((dim_tile + 1) * 8 + 6) * 32)));
1581
- svmopa_za32_f16_m(2, query_predicate_f16x, predicate_all_f16x, probability_column_14_f32x,
1582
- svld1_f16(predicate_all_f16x,
1583
- (float16_t const *)(values_block_upper + ((dim_tile + 2) * 8 + 6) * 32)));
1584
- svmopa_za32_f16_m(3, query_predicate_f16x, predicate_all_f16x, probability_column_14_f32x,
1585
- svld1_f16(predicate_all_f16x,
1586
- (float16_t const *)(values_block_upper + ((dim_tile + 3) * 8 + 6) * 32)));
1587
- svmopa_za32_f16_m(0, query_predicate_f16x, predicate_all_f16x, probability_column_15_f32x,
1588
- svld1_f16(predicate_all_f16x,
1589
- (float16_t const *)(values_block_upper + ((dim_tile + 0) * 8 + 7) * 32)));
1590
- svmopa_za32_f16_m(1, query_predicate_f16x, predicate_all_f16x, probability_column_15_f32x,
1591
- svld1_f16(predicate_all_f16x,
1592
- (float16_t const *)(values_block_upper + ((dim_tile + 1) * 8 + 7) * 32)));
1593
- svmopa_za32_f16_m(2, query_predicate_f16x, predicate_all_f16x, probability_column_15_f32x,
1594
- svld1_f16(predicate_all_f16x,
1595
- (float16_t const *)(values_block_upper + ((dim_tile + 2) * 8 + 7) * 32)));
1596
- svmopa_za32_f16_m(3, query_predicate_f16x, predicate_all_f16x, probability_column_15_f32x,
1597
- svld1_f16(predicate_all_f16x,
1598
- (float16_t const *)(values_block_upper + ((dim_tile + 3) * 8 + 7) * 32)));
1508
+ svmopa_za32_f16_m(0, query_predicate_b16x, predicate_all_b16x, probability_column_8_f32x,
1509
+ svld1_f16(predicate_all_b16x,
1510
+ (float16_t const *)(values_block_high + ((dim_tile + 0) * 8 + 0) * 32)));
1511
+ svmopa_za32_f16_m(1, query_predicate_b16x, predicate_all_b16x, probability_column_8_f32x,
1512
+ svld1_f16(predicate_all_b16x,
1513
+ (float16_t const *)(values_block_high + ((dim_tile + 1) * 8 + 0) * 32)));
1514
+ svmopa_za32_f16_m(2, query_predicate_b16x, predicate_all_b16x, probability_column_8_f32x,
1515
+ svld1_f16(predicate_all_b16x,
1516
+ (float16_t const *)(values_block_high + ((dim_tile + 2) * 8 + 0) * 32)));
1517
+ svmopa_za32_f16_m(3, query_predicate_b16x, predicate_all_b16x, probability_column_8_f32x,
1518
+ svld1_f16(predicate_all_b16x,
1519
+ (float16_t const *)(values_block_high + ((dim_tile + 3) * 8 + 0) * 32)));
1520
+ svmopa_za32_f16_m(0, query_predicate_b16x, predicate_all_b16x, probability_column_9_f32x,
1521
+ svld1_f16(predicate_all_b16x,
1522
+ (float16_t const *)(values_block_high + ((dim_tile + 0) * 8 + 1) * 32)));
1523
+ svmopa_za32_f16_m(1, query_predicate_b16x, predicate_all_b16x, probability_column_9_f32x,
1524
+ svld1_f16(predicate_all_b16x,
1525
+ (float16_t const *)(values_block_high + ((dim_tile + 1) * 8 + 1) * 32)));
1526
+ svmopa_za32_f16_m(2, query_predicate_b16x, predicate_all_b16x, probability_column_9_f32x,
1527
+ svld1_f16(predicate_all_b16x,
1528
+ (float16_t const *)(values_block_high + ((dim_tile + 2) * 8 + 1) * 32)));
1529
+ svmopa_za32_f16_m(3, query_predicate_b16x, predicate_all_b16x, probability_column_9_f32x,
1530
+ svld1_f16(predicate_all_b16x,
1531
+ (float16_t const *)(values_block_high + ((dim_tile + 3) * 8 + 1) * 32)));
1532
+ svmopa_za32_f16_m(0, query_predicate_b16x, predicate_all_b16x, probability_column_10_f32x,
1533
+ svld1_f16(predicate_all_b16x,
1534
+ (float16_t const *)(values_block_high + ((dim_tile + 0) * 8 + 2) * 32)));
1535
+ svmopa_za32_f16_m(1, query_predicate_b16x, predicate_all_b16x, probability_column_10_f32x,
1536
+ svld1_f16(predicate_all_b16x,
1537
+ (float16_t const *)(values_block_high + ((dim_tile + 1) * 8 + 2) * 32)));
1538
+ svmopa_za32_f16_m(2, query_predicate_b16x, predicate_all_b16x, probability_column_10_f32x,
1539
+ svld1_f16(predicate_all_b16x,
1540
+ (float16_t const *)(values_block_high + ((dim_tile + 2) * 8 + 2) * 32)));
1541
+ svmopa_za32_f16_m(3, query_predicate_b16x, predicate_all_b16x, probability_column_10_f32x,
1542
+ svld1_f16(predicate_all_b16x,
1543
+ (float16_t const *)(values_block_high + ((dim_tile + 3) * 8 + 2) * 32)));
1544
+ svmopa_za32_f16_m(0, query_predicate_b16x, predicate_all_b16x, probability_column_11_f32x,
1545
+ svld1_f16(predicate_all_b16x,
1546
+ (float16_t const *)(values_block_high + ((dim_tile + 0) * 8 + 3) * 32)));
1547
+ svmopa_za32_f16_m(1, query_predicate_b16x, predicate_all_b16x, probability_column_11_f32x,
1548
+ svld1_f16(predicate_all_b16x,
1549
+ (float16_t const *)(values_block_high + ((dim_tile + 1) * 8 + 3) * 32)));
1550
+ svmopa_za32_f16_m(2, query_predicate_b16x, predicate_all_b16x, probability_column_11_f32x,
1551
+ svld1_f16(predicate_all_b16x,
1552
+ (float16_t const *)(values_block_high + ((dim_tile + 2) * 8 + 3) * 32)));
1553
+ svmopa_za32_f16_m(3, query_predicate_b16x, predicate_all_b16x, probability_column_11_f32x,
1554
+ svld1_f16(predicate_all_b16x,
1555
+ (float16_t const *)(values_block_high + ((dim_tile + 3) * 8 + 3) * 32)));
1556
+ svmopa_za32_f16_m(0, query_predicate_b16x, predicate_all_b16x, probability_column_12_f32x,
1557
+ svld1_f16(predicate_all_b16x,
1558
+ (float16_t const *)(values_block_high + ((dim_tile + 0) * 8 + 4) * 32)));
1559
+ svmopa_za32_f16_m(1, query_predicate_b16x, predicate_all_b16x, probability_column_12_f32x,
1560
+ svld1_f16(predicate_all_b16x,
1561
+ (float16_t const *)(values_block_high + ((dim_tile + 1) * 8 + 4) * 32)));
1562
+ svmopa_za32_f16_m(2, query_predicate_b16x, predicate_all_b16x, probability_column_12_f32x,
1563
+ svld1_f16(predicate_all_b16x,
1564
+ (float16_t const *)(values_block_high + ((dim_tile + 2) * 8 + 4) * 32)));
1565
+ svmopa_za32_f16_m(3, query_predicate_b16x, predicate_all_b16x, probability_column_12_f32x,
1566
+ svld1_f16(predicate_all_b16x,
1567
+ (float16_t const *)(values_block_high + ((dim_tile + 3) * 8 + 4) * 32)));
1568
+ svmopa_za32_f16_m(0, query_predicate_b16x, predicate_all_b16x, probability_column_13_f32x,
1569
+ svld1_f16(predicate_all_b16x,
1570
+ (float16_t const *)(values_block_high + ((dim_tile + 0) * 8 + 5) * 32)));
1571
+ svmopa_za32_f16_m(1, query_predicate_b16x, predicate_all_b16x, probability_column_13_f32x,
1572
+ svld1_f16(predicate_all_b16x,
1573
+ (float16_t const *)(values_block_high + ((dim_tile + 1) * 8 + 5) * 32)));
1574
+ svmopa_za32_f16_m(2, query_predicate_b16x, predicate_all_b16x, probability_column_13_f32x,
1575
+ svld1_f16(predicate_all_b16x,
1576
+ (float16_t const *)(values_block_high + ((dim_tile + 2) * 8 + 5) * 32)));
1577
+ svmopa_za32_f16_m(3, query_predicate_b16x, predicate_all_b16x, probability_column_13_f32x,
1578
+ svld1_f16(predicate_all_b16x,
1579
+ (float16_t const *)(values_block_high + ((dim_tile + 3) * 8 + 5) * 32)));
1580
+ svmopa_za32_f16_m(0, query_predicate_b16x, predicate_all_b16x, probability_column_14_f32x,
1581
+ svld1_f16(predicate_all_b16x,
1582
+ (float16_t const *)(values_block_high + ((dim_tile + 0) * 8 + 6) * 32)));
1583
+ svmopa_za32_f16_m(1, query_predicate_b16x, predicate_all_b16x, probability_column_14_f32x,
1584
+ svld1_f16(predicate_all_b16x,
1585
+ (float16_t const *)(values_block_high + ((dim_tile + 1) * 8 + 6) * 32)));
1586
+ svmopa_za32_f16_m(2, query_predicate_b16x, predicate_all_b16x, probability_column_14_f32x,
1587
+ svld1_f16(predicate_all_b16x,
1588
+ (float16_t const *)(values_block_high + ((dim_tile + 2) * 8 + 6) * 32)));
1589
+ svmopa_za32_f16_m(3, query_predicate_b16x, predicate_all_b16x, probability_column_14_f32x,
1590
+ svld1_f16(predicate_all_b16x,
1591
+ (float16_t const *)(values_block_high + ((dim_tile + 3) * 8 + 6) * 32)));
1592
+ svmopa_za32_f16_m(0, query_predicate_b16x, predicate_all_b16x, probability_column_15_f32x,
1593
+ svld1_f16(predicate_all_b16x,
1594
+ (float16_t const *)(values_block_high + ((dim_tile + 0) * 8 + 7) * 32)));
1595
+ svmopa_za32_f16_m(1, query_predicate_b16x, predicate_all_b16x, probability_column_15_f32x,
1596
+ svld1_f16(predicate_all_b16x,
1597
+ (float16_t const *)(values_block_high + ((dim_tile + 1) * 8 + 7) * 32)));
1598
+ svmopa_za32_f16_m(2, query_predicate_b16x, predicate_all_b16x, probability_column_15_f32x,
1599
+ svld1_f16(predicate_all_b16x,
1600
+ (float16_t const *)(values_block_high + ((dim_tile + 2) * 8 + 7) * 32)));
1601
+ svmopa_za32_f16_m(3, query_predicate_b16x, predicate_all_b16x, probability_column_15_f32x,
1602
+ svld1_f16(predicate_all_b16x,
1603
+ (float16_t const *)(values_block_high + ((dim_tile + 3) * 8 + 7) * 32)));
1599
1604
  // Read FMOPA result and ADD to output_accumulator
1600
1605
  for (nk_size_t query_index = 0; query_index < valid_query_count; query_index++) {
1601
1606
  svst1_f32(
1602
- predicate_all_f32x, output_accumulator + query_index * head_dim_padded + (dim_tile + 0) * 16,
1603
- svadd_f32_x(predicate_all_f32x,
1604
- svld1_f32(predicate_all_f32x,
1607
+ predicate_all_b32x, output_accumulator + query_index * head_dim_padded + (dim_tile + 0) * 16,
1608
+ svadd_f32_x(predicate_all_b32x,
1609
+ svld1_f32(predicate_all_b32x,
1605
1610
  output_accumulator + query_index * head_dim_padded + (dim_tile + 0) * 16),
1606
- svread_hor_za32_f32_m(svdup_f32(0), predicate_all_f32x, 0, query_index)));
1611
+ svread_hor_za32_f32_m(svdup_f32(0), predicate_all_b32x, 0, query_index)));
1607
1612
  svst1_f32(
1608
- predicate_all_f32x, output_accumulator + query_index * head_dim_padded + (dim_tile + 1) * 16,
1609
- svadd_f32_x(predicate_all_f32x,
1610
- svld1_f32(predicate_all_f32x,
1613
+ predicate_all_b32x, output_accumulator + query_index * head_dim_padded + (dim_tile + 1) * 16,
1614
+ svadd_f32_x(predicate_all_b32x,
1615
+ svld1_f32(predicate_all_b32x,
1611
1616
  output_accumulator + query_index * head_dim_padded + (dim_tile + 1) * 16),
1612
- svread_hor_za32_f32_m(svdup_f32(0), predicate_all_f32x, 1, query_index)));
1617
+ svread_hor_za32_f32_m(svdup_f32(0), predicate_all_b32x, 1, query_index)));
1613
1618
  svst1_f32(
1614
- predicate_all_f32x, output_accumulator + query_index * head_dim_padded + (dim_tile + 2) * 16,
1615
- svadd_f32_x(predicate_all_f32x,
1616
- svld1_f32(predicate_all_f32x,
1619
+ predicate_all_b32x, output_accumulator + query_index * head_dim_padded + (dim_tile + 2) * 16,
1620
+ svadd_f32_x(predicate_all_b32x,
1621
+ svld1_f32(predicate_all_b32x,
1617
1622
  output_accumulator + query_index * head_dim_padded + (dim_tile + 2) * 16),
1618
- svread_hor_za32_f32_m(svdup_f32(0), predicate_all_f32x, 2, query_index)));
1623
+ svread_hor_za32_f32_m(svdup_f32(0), predicate_all_b32x, 2, query_index)));
1619
1624
  svst1_f32(
1620
- predicate_all_f32x, output_accumulator + query_index * head_dim_padded + (dim_tile + 3) * 16,
1621
- svadd_f32_x(predicate_all_f32x,
1622
- svld1_f32(predicate_all_f32x,
1625
+ predicate_all_b32x, output_accumulator + query_index * head_dim_padded + (dim_tile + 3) * 16,
1626
+ svadd_f32_x(predicate_all_b32x,
1627
+ svld1_f32(predicate_all_b32x,
1623
1628
  output_accumulator + query_index * head_dim_padded + (dim_tile + 3) * 16),
1624
- svread_hor_za32_f32_m(svdup_f32(0), predicate_all_f32x, 3, query_index)));
1629
+ svread_hor_za32_f32_m(svdup_f32(0), predicate_all_b32x, 3, query_index)));
1625
1630
  }
1626
1631
  }
1627
1632
  // Remainder: 1 dim_tile at a time using ZA0
1628
1633
  for (; dim_tile < dim_tile_count; dim_tile++) {
1629
1634
  svzero_mask_za(nk_sme_zero_za32_tile_0_);
1630
1635
  svmopa_za32_f16_m(
1631
- 0, query_predicate_f16x, predicate_all_f16x, probability_column_0_f32x,
1632
- svld1_f16(predicate_all_f16x, (float16_t const *)(values_block_lower + (dim_tile * 8 + 0) * 32)));
1636
+ 0, query_predicate_b16x, predicate_all_b16x, probability_column_0_f32x,
1637
+ svld1_f16(predicate_all_b16x, (float16_t const *)(values_block_low + (dim_tile * 8 + 0) * 32)));
1633
1638
  svmopa_za32_f16_m(
1634
- 0, query_predicate_f16x, predicate_all_f16x, probability_column_1_f32x,
1635
- svld1_f16(predicate_all_f16x, (float16_t const *)(values_block_lower + (dim_tile * 8 + 1) * 32)));
1639
+ 0, query_predicate_b16x, predicate_all_b16x, probability_column_1_f32x,
1640
+ svld1_f16(predicate_all_b16x, (float16_t const *)(values_block_low + (dim_tile * 8 + 1) * 32)));
1636
1641
  svmopa_za32_f16_m(
1637
- 0, query_predicate_f16x, predicate_all_f16x, probability_column_2_f32x,
1638
- svld1_f16(predicate_all_f16x, (float16_t const *)(values_block_lower + (dim_tile * 8 + 2) * 32)));
1642
+ 0, query_predicate_b16x, predicate_all_b16x, probability_column_2_f32x,
1643
+ svld1_f16(predicate_all_b16x, (float16_t const *)(values_block_low + (dim_tile * 8 + 2) * 32)));
1639
1644
  svmopa_za32_f16_m(
1640
- 0, query_predicate_f16x, predicate_all_f16x, probability_column_3_f32x,
1641
- svld1_f16(predicate_all_f16x, (float16_t const *)(values_block_lower + (dim_tile * 8 + 3) * 32)));
1645
+ 0, query_predicate_b16x, predicate_all_b16x, probability_column_3_f32x,
1646
+ svld1_f16(predicate_all_b16x, (float16_t const *)(values_block_low + (dim_tile * 8 + 3) * 32)));
1642
1647
  svmopa_za32_f16_m(
1643
- 0, query_predicate_f16x, predicate_all_f16x, probability_column_4_f32x,
1644
- svld1_f16(predicate_all_f16x, (float16_t const *)(values_block_lower + (dim_tile * 8 + 4) * 32)));
1648
+ 0, query_predicate_b16x, predicate_all_b16x, probability_column_4_f32x,
1649
+ svld1_f16(predicate_all_b16x, (float16_t const *)(values_block_low + (dim_tile * 8 + 4) * 32)));
1645
1650
  svmopa_za32_f16_m(
1646
- 0, query_predicate_f16x, predicate_all_f16x, probability_column_5_f32x,
1647
- svld1_f16(predicate_all_f16x, (float16_t const *)(values_block_lower + (dim_tile * 8 + 5) * 32)));
1651
+ 0, query_predicate_b16x, predicate_all_b16x, probability_column_5_f32x,
1652
+ svld1_f16(predicate_all_b16x, (float16_t const *)(values_block_low + (dim_tile * 8 + 5) * 32)));
1648
1653
  svmopa_za32_f16_m(
1649
- 0, query_predicate_f16x, predicate_all_f16x, probability_column_6_f32x,
1650
- svld1_f16(predicate_all_f16x, (float16_t const *)(values_block_lower + (dim_tile * 8 + 6) * 32)));
1654
+ 0, query_predicate_b16x, predicate_all_b16x, probability_column_6_f32x,
1655
+ svld1_f16(predicate_all_b16x, (float16_t const *)(values_block_low + (dim_tile * 8 + 6) * 32)));
1651
1656
  svmopa_za32_f16_m(
1652
- 0, query_predicate_f16x, predicate_all_f16x, probability_column_7_f32x,
1653
- svld1_f16(predicate_all_f16x, (float16_t const *)(values_block_lower + (dim_tile * 8 + 7) * 32)));
1657
+ 0, query_predicate_b16x, predicate_all_b16x, probability_column_7_f32x,
1658
+ svld1_f16(predicate_all_b16x, (float16_t const *)(values_block_low + (dim_tile * 8 + 7) * 32)));
1654
1659
  svmopa_za32_f16_m(
1655
- 0, query_predicate_f16x, predicate_all_f16x, probability_column_8_f32x,
1656
- svld1_f16(predicate_all_f16x, (float16_t const *)(values_block_upper + (dim_tile * 8 + 0) * 32)));
1660
+ 0, query_predicate_b16x, predicate_all_b16x, probability_column_8_f32x,
1661
+ svld1_f16(predicate_all_b16x, (float16_t const *)(values_block_high + (dim_tile * 8 + 0) * 32)));
1657
1662
  svmopa_za32_f16_m(
1658
- 0, query_predicate_f16x, predicate_all_f16x, probability_column_9_f32x,
1659
- svld1_f16(predicate_all_f16x, (float16_t const *)(values_block_upper + (dim_tile * 8 + 1) * 32)));
1663
+ 0, query_predicate_b16x, predicate_all_b16x, probability_column_9_f32x,
1664
+ svld1_f16(predicate_all_b16x, (float16_t const *)(values_block_high + (dim_tile * 8 + 1) * 32)));
1660
1665
  svmopa_za32_f16_m(
1661
- 0, query_predicate_f16x, predicate_all_f16x, probability_column_10_f32x,
1662
- svld1_f16(predicate_all_f16x, (float16_t const *)(values_block_upper + (dim_tile * 8 + 2) * 32)));
1666
+ 0, query_predicate_b16x, predicate_all_b16x, probability_column_10_f32x,
1667
+ svld1_f16(predicate_all_b16x, (float16_t const *)(values_block_high + (dim_tile * 8 + 2) * 32)));
1663
1668
  svmopa_za32_f16_m(
1664
- 0, query_predicate_f16x, predicate_all_f16x, probability_column_11_f32x,
1665
- svld1_f16(predicate_all_f16x, (float16_t const *)(values_block_upper + (dim_tile * 8 + 3) * 32)));
1669
+ 0, query_predicate_b16x, predicate_all_b16x, probability_column_11_f32x,
1670
+ svld1_f16(predicate_all_b16x, (float16_t const *)(values_block_high + (dim_tile * 8 + 3) * 32)));
1666
1671
  svmopa_za32_f16_m(
1667
- 0, query_predicate_f16x, predicate_all_f16x, probability_column_12_f32x,
1668
- svld1_f16(predicate_all_f16x, (float16_t const *)(values_block_upper + (dim_tile * 8 + 4) * 32)));
1672
+ 0, query_predicate_b16x, predicate_all_b16x, probability_column_12_f32x,
1673
+ svld1_f16(predicate_all_b16x, (float16_t const *)(values_block_high + (dim_tile * 8 + 4) * 32)));
1669
1674
  svmopa_za32_f16_m(
1670
- 0, query_predicate_f16x, predicate_all_f16x, probability_column_13_f32x,
1671
- svld1_f16(predicate_all_f16x, (float16_t const *)(values_block_upper + (dim_tile * 8 + 5) * 32)));
1675
+ 0, query_predicate_b16x, predicate_all_b16x, probability_column_13_f32x,
1676
+ svld1_f16(predicate_all_b16x, (float16_t const *)(values_block_high + (dim_tile * 8 + 5) * 32)));
1672
1677
  svmopa_za32_f16_m(
1673
- 0, query_predicate_f16x, predicate_all_f16x, probability_column_14_f32x,
1674
- svld1_f16(predicate_all_f16x, (float16_t const *)(values_block_upper + (dim_tile * 8 + 6) * 32)));
1678
+ 0, query_predicate_b16x, predicate_all_b16x, probability_column_14_f32x,
1679
+ svld1_f16(predicate_all_b16x, (float16_t const *)(values_block_high + (dim_tile * 8 + 6) * 32)));
1675
1680
  svmopa_za32_f16_m(
1676
- 0, query_predicate_f16x, predicate_all_f16x, probability_column_15_f32x,
1677
- svld1_f16(predicate_all_f16x, (float16_t const *)(values_block_upper + (dim_tile * 8 + 7) * 32)));
1681
+ 0, query_predicate_b16x, predicate_all_b16x, probability_column_15_f32x,
1682
+ svld1_f16(predicate_all_b16x, (float16_t const *)(values_block_high + (dim_tile * 8 + 7) * 32)));
1678
1683
  for (nk_size_t query_index = 0; query_index < valid_query_count; query_index++)
1679
- svst1_f32(predicate_all_f32x, output_accumulator + query_index * head_dim_padded + dim_tile * 16,
1680
- svadd_f32_x(predicate_all_f32x,
1681
- svld1_f32(predicate_all_f32x,
1684
+ svst1_f32(predicate_all_b32x, output_accumulator + query_index * head_dim_padded + dim_tile * 16,
1685
+ svadd_f32_x(predicate_all_b32x,
1686
+ svld1_f32(predicate_all_b32x,
1682
1687
  output_accumulator + query_index * head_dim_padded + dim_tile * 16),
1683
- svread_hor_za32_f32_m(svdup_f32(0), predicate_all_f32x, 0, query_index)));
1688
+ svread_hor_za32_f32_m(svdup_f32(0), predicate_all_b32x, 0, query_index)));
1684
1689
  }
1685
1690
  }
1686
1691
  }
@@ -1693,9 +1698,9 @@ __arm_locally_streaming __arm_new("za") static void nk_attention_f16_sme_streami
1693
1698
  svzero_mask_za(nk_sme_zero_za32_tile_2_);
1694
1699
  nk_f16_t const *k_block = k + kv_block_index * k_depth_step_count * 32;
1695
1700
  for (nk_size_t step = 0; step < k_depth_step_count; step++) {
1696
- svfloat16_t zn = svreinterpret_f16_f32(svld1_f32(predicate_all_f32x, queries_transposed + step * 16));
1697
- svfloat16_t zm = svld1_f16(predicate_all_f16x, (float16_t const *)(k_block + step * 32));
1698
- svmopa_za32_f16_m(2, predicate_all_f32x, predicate_all_f32x, zn, zm);
1701
+ svfloat16_t zn_f16x = svreinterpret_f16_f32(svld1_f32(predicate_all_b32x, queries_transposed + step * 16));
1702
+ svfloat16_t zm_f16x = svld1_f16(predicate_all_b16x, (float16_t const *)(k_block + step * 32));
1703
+ svmopa_za32_f16_m(2, predicate_all_b32x, predicate_all_b32x, zn_f16x, zm_f16x);
1699
1704
  }
1700
1705
 
1701
1706
  // Pass 1: Column-wise max (read ZA2 columns vertically)
@@ -1703,56 +1708,57 @@ __arm_locally_streaming __arm_new("za") static void nk_attention_f16_sme_streami
1703
1708
  svfloat32_t block_max_16_f32x = svdup_f32(NK_F32_MIN);
1704
1709
  for (nk_size_t column_index = 0; column_index < 16; column_index++) {
1705
1710
  svfloat32_t score_column_f32x = svmul_f32_x(
1706
- predicate_all_f32x, svread_ver_za32_f32_m(svdup_f32(0), predicate_all_f32x, 2, column_index),
1711
+ predicate_all_b32x, svread_ver_za32_f32_m(svdup_f32(0), predicate_all_b32x, 2, column_index),
1707
1712
  scale_16_f32x);
1708
- block_max_16_f32x = svmax_f32_x(predicate_all_f32x, block_max_16_f32x, score_column_f32x);
1713
+ block_max_16_f32x = svmax_f32_x(predicate_all_b32x, block_max_16_f32x, score_column_f32x);
1709
1714
  }
1710
1715
 
1711
- svfloat32_t old_max_f32x = svld1_f32(predicate_all_f32x, row_max);
1712
- svfloat32_t new_max_f32x = svmax_f32_x(predicate_all_f32x, old_max_f32x, block_max_16_f32x);
1713
- svfloat32_t correction_f32x = nk_exp_fast_f32_sve_(predicate_all_f32x,
1714
- svsub_f32_x(predicate_all_f32x, old_max_f32x, new_max_f32x));
1715
- svbool_t max_changed_16 = svcmplt_f32(predicate_all_f32x, correction_f32x, svdup_f32(1.0f));
1716
- nk_u32_t max_was_updated_16 = svptest_any(predicate_all_f32x, max_changed_16) ? 1 : 0;
1717
- svfloat32_t row_sum_corrected_f32x = svld1_f32(predicate_all_f32x, row_sum);
1716
+ svfloat32_t old_max_f32x = svld1_f32(predicate_all_b32x, row_max);
1717
+ svfloat32_t new_max_f32x = svmax_f32_x(predicate_all_b32x, old_max_f32x, block_max_16_f32x);
1718
+ svfloat32_t correction_f32x = nk_exp_fast_f32_sve_(predicate_all_b32x,
1719
+ svsub_f32_x(predicate_all_b32x, old_max_f32x, new_max_f32x));
1720
+ svbool_t max_changed_16_b32x = svcmplt_f32(predicate_all_b32x, correction_f32x, svdup_f32(1.0f));
1721
+ nk_u32_t max_was_updated_16 = svptest_any(predicate_all_b32x, max_changed_16_b32x) ? 1 : 0;
1722
+ svfloat32_t row_sum_corrected_f32x = svld1_f32(predicate_all_b32x, row_sum);
1718
1723
  if (max_was_updated_16)
1719
- row_sum_corrected_f32x = svmul_f32_x(predicate_all_f32x, row_sum_corrected_f32x, correction_f32x);
1724
+ row_sum_corrected_f32x = svmul_f32_x(predicate_all_b32x, row_sum_corrected_f32x, correction_f32x);
1720
1725
  NK_ALIGN64 nk_f32_t corrections[16];
1721
- svst1_f32(predicate_all_f32x, corrections, correction_f32x);
1726
+ svst1_f32(predicate_all_b32x, corrections, correction_f32x);
1722
1727
 
1723
1728
  // Pass 2: Column-wise exp + fused P write + sum (ZA2 → ZA0 columns 0-7)
1724
1729
  svfloat32_t sum_delta_16_f32x = svdup_f32(0.0f);
1725
1730
  svzero_mask_za(nk_sme_zero_za32_tile_0_);
1726
1731
  for (nk_size_t column_index = 0; column_index < 16; column_index += 2) {
1727
1732
  svfloat32_t score_even_f32x = svmul_f32_x(
1728
- predicate_all_f32x, svread_ver_za32_f32_m(svdup_f32(0), predicate_all_f32x, 2, column_index),
1733
+ predicate_all_b32x, svread_ver_za32_f32_m(svdup_f32(0), predicate_all_b32x, 2, column_index),
1729
1734
  scale_16_f32x);
1730
1735
  svfloat32_t score_odd_f32x = svmul_f32_x(
1731
- predicate_all_f32x, svread_ver_za32_f32_m(svdup_f32(0), predicate_all_f32x, 2, column_index + 1),
1736
+ predicate_all_b32x, svread_ver_za32_f32_m(svdup_f32(0), predicate_all_b32x, 2, column_index + 1),
1732
1737
  scale_16_f32x);
1733
1738
  svfloat32_t weight_even_f32x = nk_exp_fast_f32_sve_(
1734
- predicate_all_f32x, svsub_f32_x(predicate_all_f32x, score_even_f32x, new_max_f32x));
1739
+ predicate_all_b32x, svsub_f32_x(predicate_all_b32x, score_even_f32x, new_max_f32x));
1735
1740
  svfloat32_t weight_odd_f32x = nk_exp_fast_f32_sve_(
1736
- predicate_all_f32x, svsub_f32_x(predicate_all_f32x, score_odd_f32x, new_max_f32x));
1737
- sum_delta_16_f32x = svadd_f32_x(predicate_all_f32x, sum_delta_16_f32x, weight_even_f32x);
1738
- sum_delta_16_f32x = svadd_f32_x(predicate_all_f32x, sum_delta_16_f32x, weight_odd_f32x);
1739
- svfloat16_t weight_pair_f16x = svzip1_f16(svcvt_f16_f32_x(predicate_all_f32x, weight_even_f32x),
1740
- svcvt_f16_f32_x(predicate_all_f32x, weight_odd_f32x));
1741
- svwrite_ver_za32_f32_m(0, column_index / 2, predicate_all_f32x, svreinterpret_f32_f16(weight_pair_f16x));
1741
+ predicate_all_b32x, svsub_f32_x(predicate_all_b32x, score_odd_f32x, new_max_f32x));
1742
+ sum_delta_16_f32x = svadd_f32_x(predicate_all_b32x, sum_delta_16_f32x, weight_even_f32x);
1743
+ sum_delta_16_f32x = svadd_f32_x(predicate_all_b32x, sum_delta_16_f32x, weight_odd_f32x);
1744
+ svfloat16_t weight_pair_f16x = svzip1_f16(svcvt_f16_f32_x(predicate_all_b32x, weight_even_f32x),
1745
+ svcvt_f16_f32_x(predicate_all_b32x, weight_odd_f32x));
1746
+ svwrite_ver_za32_f32_m(0, column_index / 2, predicate_all_b32x, svreinterpret_f32_f16(weight_pair_f16x));
1742
1747
  }
1743
- row_sum_corrected_f32x = svadd_f32_x(predicate_all_f32x, row_sum_corrected_f32x, sum_delta_16_f32x);
1744
- svst1_f32(predicate_all_f32x, row_sum, row_sum_corrected_f32x);
1745
- svst1_f32(predicate_all_f32x, row_max, new_max_f32x);
1748
+ row_sum_corrected_f32x = svadd_f32_x(predicate_all_b32x, row_sum_corrected_f32x, sum_delta_16_f32x);
1749
+ svst1_f32(predicate_all_b32x, row_sum, row_sum_corrected_f32x);
1750
+ svst1_f32(predicate_all_b32x, row_max, new_max_f32x);
1746
1751
 
1747
1752
  if (valid_query_count == 1) {
1748
1753
  // Decode path: extract f32 weights from ZA0 row 0 using SVE
1749
- svfloat16_t row0_f16 = svreinterpret_f16_f32(svread_hor_za32_f32_m(svdup_f32(0), predicate_all_f32x, 0, 0));
1750
- svfloat16_t weights_even_f16 = svuzp1_f16(row0_f16, row0_f16);
1751
- svfloat16_t weights_odd_f16 = svuzp2_f16(row0_f16, row0_f16);
1754
+ svfloat16_t row0_f16x = svreinterpret_f16_f32(
1755
+ svread_hor_za32_f32_m(svdup_f32(0), predicate_all_b32x, 0, 0));
1756
+ svfloat16_t weights_even_f16x = svuzp1_f16(row0_f16x, row0_f16x);
1757
+ svfloat16_t weights_odd_f16x = svuzp2_f16(row0_f16x, row0_f16x);
1752
1758
  NK_ALIGN64 nk_f32_t decode_weights[16];
1753
- svst1_f32(svwhilelt_b32(0u, 8u), decode_weights, svcvt_f32_f16_x(svwhilelt_b32(0u, 8u), weights_even_f16));
1759
+ svst1_f32(svwhilelt_b32(0u, 8u), decode_weights, svcvt_f32_f16_x(svwhilelt_b32(0u, 8u), weights_even_f16x));
1754
1760
  svst1_f32(svwhilelt_b32(0u, 8u), decode_weights + 8,
1755
- svcvt_f32_f16_x(svwhilelt_b32(0u, 8u), weights_odd_f16));
1761
+ svcvt_f32_f16_x(svwhilelt_b32(0u, 8u), weights_odd_f16x));
1756
1762
  NK_ALIGN64 nk_f32_t decode_weights_ordered[16];
1757
1763
  for (nk_size_t i = 0; i < 8; i++) {
1758
1764
  decode_weights_ordered[2 * i] = decode_weights[i];
@@ -1760,42 +1766,42 @@ __arm_locally_streaming __arm_new("za") static void nk_attention_f16_sme_streami
1760
1766
  }
1761
1767
  svfloat32_t corr_f32x = svdup_f32(corrections[0]);
1762
1768
  for (nk_size_t d = 0; d < head_dim; d += svcntw()) {
1763
- svbool_t predicate_f32x = svwhilelt_b32_u64(d, head_dim);
1764
- svfloat32_t acc_f32x = svmul_f32_x(predicate_f32x, svld1_f32(predicate_f32x, output_accumulator + d),
1769
+ svbool_t predicate_b32x = svwhilelt_b32_u64(d, head_dim);
1770
+ svfloat32_t acc_f32x = svmul_f32_x(predicate_b32x, svld1_f32(predicate_b32x, output_accumulator + d),
1765
1771
  corr_f32x);
1766
1772
  for (nk_size_t ki = 0; ki < valid_kv; ki++) {
1767
1773
  nk_size_t dim_tile = d / 16, depth_s = ki / 2, sub = ki % 2;
1768
1774
  nk_f16_t const *v_vec = v_packed +
1769
1775
  (kv_block_index * dim_tile_count * 8 + dim_tile * 8 + depth_s) * 32;
1770
- svfloat16_t packed_f16x = svld1_f16(predicate_all_f16x, (float16_t const *)v_vec);
1771
- svfloat16_t v_selected = (sub == 0) ? svuzp1_f16(packed_f16x, packed_f16x)
1772
- : svuzp2_f16(packed_f16x, packed_f16x);
1773
- acc_f32x = svmla_f32_x(predicate_f32x, acc_f32x, svdup_f32(decode_weights_ordered[ki]),
1774
- svcvt_f32_f16_x(predicate_f32x, v_selected));
1776
+ svfloat16_t packed_f16x = svld1_f16(predicate_all_b16x, (float16_t const *)v_vec);
1777
+ svfloat16_t v_selected_f16x = (sub == 0) ? svuzp1_f16(packed_f16x, packed_f16x)
1778
+ : svuzp2_f16(packed_f16x, packed_f16x);
1779
+ acc_f32x = svmla_f32_x(predicate_b32x, acc_f32x, svdup_f32(decode_weights_ordered[ki]),
1780
+ svcvt_f32_f16_x(predicate_b32x, v_selected_f16x));
1775
1781
  }
1776
- svst1_f32(predicate_f32x, output_accumulator + d, acc_f32x);
1782
+ svst1_f32(predicate_b32x, output_accumulator + d, acc_f32x);
1777
1783
  }
1778
1784
  }
1779
1785
  else {
1780
1786
  // Prefill Bc=16: extract P columns, pre-apply correction, add-after P×V
1781
- svbool_t query_predicate_f16x = svwhilelt_b16_u64(0u, valid_query_count * 2);
1787
+ svbool_t query_predicate_b16x = svwhilelt_b16_u64(0u, valid_query_count * 2);
1782
1788
 
1783
1789
  svfloat16_t probability_column_0_f32x = svreinterpret_f16_f32(
1784
- svread_ver_za32_f32_m(svdup_f32(0), predicate_all_f32x, 0, 0));
1790
+ svread_ver_za32_f32_m(svdup_f32(0), predicate_all_b32x, 0, 0));
1785
1791
  svfloat16_t probability_column_1_f32x = svreinterpret_f16_f32(
1786
- svread_ver_za32_f32_m(svdup_f32(0), predicate_all_f32x, 0, 1));
1792
+ svread_ver_za32_f32_m(svdup_f32(0), predicate_all_b32x, 0, 1));
1787
1793
  svfloat16_t probability_column_2_f32x = svreinterpret_f16_f32(
1788
- svread_ver_za32_f32_m(svdup_f32(0), predicate_all_f32x, 0, 2));
1794
+ svread_ver_za32_f32_m(svdup_f32(0), predicate_all_b32x, 0, 2));
1789
1795
  svfloat16_t probability_column_3_f32x = svreinterpret_f16_f32(
1790
- svread_ver_za32_f32_m(svdup_f32(0), predicate_all_f32x, 0, 3));
1796
+ svread_ver_za32_f32_m(svdup_f32(0), predicate_all_b32x, 0, 3));
1791
1797
  svfloat16_t probability_column_4_f32x = svreinterpret_f16_f32(
1792
- svread_ver_za32_f32_m(svdup_f32(0), predicate_all_f32x, 0, 4));
1798
+ svread_ver_za32_f32_m(svdup_f32(0), predicate_all_b32x, 0, 4));
1793
1799
  svfloat16_t probability_column_5_f32x = svreinterpret_f16_f32(
1794
- svread_ver_za32_f32_m(svdup_f32(0), predicate_all_f32x, 0, 5));
1800
+ svread_ver_za32_f32_m(svdup_f32(0), predicate_all_b32x, 0, 5));
1795
1801
  svfloat16_t probability_column_6_f32x = svreinterpret_f16_f32(
1796
- svread_ver_za32_f32_m(svdup_f32(0), predicate_all_f32x, 0, 6));
1802
+ svread_ver_za32_f32_m(svdup_f32(0), predicate_all_b32x, 0, 6));
1797
1803
  svfloat16_t probability_column_7_f32x = svreinterpret_f16_f32(
1798
- svread_ver_za32_f32_m(svdup_f32(0), predicate_all_f32x, 0, 7));
1804
+ svread_ver_za32_f32_m(svdup_f32(0), predicate_all_b32x, 0, 7));
1799
1805
 
1800
1806
  nk_f16_t const *v_block = v_packed + kv_block_index * dim_tile_count * 8 * 32;
1801
1807
 
@@ -1804,9 +1810,9 @@ __arm_locally_streaming __arm_new("za") static void nk_attention_f16_sme_streami
1804
1810
  svfloat32_t correction_scalar_f32x = svdup_f32(corrections[query_index]);
1805
1811
  for (nk_size_t dim_offset = 0; dim_offset < head_dim_padded; dim_offset += 16)
1806
1812
  svst1_f32(
1807
- predicate_all_f32x, output_accumulator + query_index * head_dim_padded + dim_offset,
1808
- svmul_f32_x(predicate_all_f32x,
1809
- svld1_f32(predicate_all_f32x,
1813
+ predicate_all_b32x, output_accumulator + query_index * head_dim_padded + dim_offset,
1814
+ svmul_f32_x(predicate_all_b32x,
1815
+ svld1_f32(predicate_all_b32x,
1810
1816
  output_accumulator + query_index * head_dim_padded + dim_offset),
1811
1817
  correction_scalar_f32x));
1812
1818
  }
@@ -1816,188 +1822,188 @@ __arm_locally_streaming __arm_new("za") static void nk_attention_f16_sme_streami
1816
1822
  for (; dim_tile + 4 <= dim_tile_count; dim_tile += 4) {
1817
1823
  svzero_za();
1818
1824
  svmopa_za32_f16_m(
1819
- 0, query_predicate_f16x, predicate_all_f16x, probability_column_0_f32x,
1820
- svld1_f16(predicate_all_f16x, (float16_t const *)(v_block + ((dim_tile + 0) * 8 + 0) * 32)));
1825
+ 0, query_predicate_b16x, predicate_all_b16x, probability_column_0_f32x,
1826
+ svld1_f16(predicate_all_b16x, (float16_t const *)(v_block + ((dim_tile + 0) * 8 + 0) * 32)));
1821
1827
  svmopa_za32_f16_m(
1822
- 1, query_predicate_f16x, predicate_all_f16x, probability_column_0_f32x,
1823
- svld1_f16(predicate_all_f16x, (float16_t const *)(v_block + ((dim_tile + 1) * 8 + 0) * 32)));
1828
+ 1, query_predicate_b16x, predicate_all_b16x, probability_column_0_f32x,
1829
+ svld1_f16(predicate_all_b16x, (float16_t const *)(v_block + ((dim_tile + 1) * 8 + 0) * 32)));
1824
1830
  svmopa_za32_f16_m(
1825
- 2, query_predicate_f16x, predicate_all_f16x, probability_column_0_f32x,
1826
- svld1_f16(predicate_all_f16x, (float16_t const *)(v_block + ((dim_tile + 2) * 8 + 0) * 32)));
1831
+ 2, query_predicate_b16x, predicate_all_b16x, probability_column_0_f32x,
1832
+ svld1_f16(predicate_all_b16x, (float16_t const *)(v_block + ((dim_tile + 2) * 8 + 0) * 32)));
1827
1833
  svmopa_za32_f16_m(
1828
- 3, query_predicate_f16x, predicate_all_f16x, probability_column_0_f32x,
1829
- svld1_f16(predicate_all_f16x, (float16_t const *)(v_block + ((dim_tile + 3) * 8 + 0) * 32)));
1834
+ 3, query_predicate_b16x, predicate_all_b16x, probability_column_0_f32x,
1835
+ svld1_f16(predicate_all_b16x, (float16_t const *)(v_block + ((dim_tile + 3) * 8 + 0) * 32)));
1830
1836
  svmopa_za32_f16_m(
1831
- 0, query_predicate_f16x, predicate_all_f16x, probability_column_1_f32x,
1832
- svld1_f16(predicate_all_f16x, (float16_t const *)(v_block + ((dim_tile + 0) * 8 + 1) * 32)));
1837
+ 0, query_predicate_b16x, predicate_all_b16x, probability_column_1_f32x,
1838
+ svld1_f16(predicate_all_b16x, (float16_t const *)(v_block + ((dim_tile + 0) * 8 + 1) * 32)));
1833
1839
  svmopa_za32_f16_m(
1834
- 1, query_predicate_f16x, predicate_all_f16x, probability_column_1_f32x,
1835
- svld1_f16(predicate_all_f16x, (float16_t const *)(v_block + ((dim_tile + 1) * 8 + 1) * 32)));
1840
+ 1, query_predicate_b16x, predicate_all_b16x, probability_column_1_f32x,
1841
+ svld1_f16(predicate_all_b16x, (float16_t const *)(v_block + ((dim_tile + 1) * 8 + 1) * 32)));
1836
1842
  svmopa_za32_f16_m(
1837
- 2, query_predicate_f16x, predicate_all_f16x, probability_column_1_f32x,
1838
- svld1_f16(predicate_all_f16x, (float16_t const *)(v_block + ((dim_tile + 2) * 8 + 1) * 32)));
1843
+ 2, query_predicate_b16x, predicate_all_b16x, probability_column_1_f32x,
1844
+ svld1_f16(predicate_all_b16x, (float16_t const *)(v_block + ((dim_tile + 2) * 8 + 1) * 32)));
1839
1845
  svmopa_za32_f16_m(
1840
- 3, query_predicate_f16x, predicate_all_f16x, probability_column_1_f32x,
1841
- svld1_f16(predicate_all_f16x, (float16_t const *)(v_block + ((dim_tile + 3) * 8 + 1) * 32)));
1846
+ 3, query_predicate_b16x, predicate_all_b16x, probability_column_1_f32x,
1847
+ svld1_f16(predicate_all_b16x, (float16_t const *)(v_block + ((dim_tile + 3) * 8 + 1) * 32)));
1842
1848
  svmopa_za32_f16_m(
1843
- 0, query_predicate_f16x, predicate_all_f16x, probability_column_2_f32x,
1844
- svld1_f16(predicate_all_f16x, (float16_t const *)(v_block + ((dim_tile + 0) * 8 + 2) * 32)));
1849
+ 0, query_predicate_b16x, predicate_all_b16x, probability_column_2_f32x,
1850
+ svld1_f16(predicate_all_b16x, (float16_t const *)(v_block + ((dim_tile + 0) * 8 + 2) * 32)));
1845
1851
  svmopa_za32_f16_m(
1846
- 1, query_predicate_f16x, predicate_all_f16x, probability_column_2_f32x,
1847
- svld1_f16(predicate_all_f16x, (float16_t const *)(v_block + ((dim_tile + 1) * 8 + 2) * 32)));
1852
+ 1, query_predicate_b16x, predicate_all_b16x, probability_column_2_f32x,
1853
+ svld1_f16(predicate_all_b16x, (float16_t const *)(v_block + ((dim_tile + 1) * 8 + 2) * 32)));
1848
1854
  svmopa_za32_f16_m(
1849
- 2, query_predicate_f16x, predicate_all_f16x, probability_column_2_f32x,
1850
- svld1_f16(predicate_all_f16x, (float16_t const *)(v_block + ((dim_tile + 2) * 8 + 2) * 32)));
1855
+ 2, query_predicate_b16x, predicate_all_b16x, probability_column_2_f32x,
1856
+ svld1_f16(predicate_all_b16x, (float16_t const *)(v_block + ((dim_tile + 2) * 8 + 2) * 32)));
1851
1857
  svmopa_za32_f16_m(
1852
- 3, query_predicate_f16x, predicate_all_f16x, probability_column_2_f32x,
1853
- svld1_f16(predicate_all_f16x, (float16_t const *)(v_block + ((dim_tile + 3) * 8 + 2) * 32)));
1858
+ 3, query_predicate_b16x, predicate_all_b16x, probability_column_2_f32x,
1859
+ svld1_f16(predicate_all_b16x, (float16_t const *)(v_block + ((dim_tile + 3) * 8 + 2) * 32)));
1854
1860
  svmopa_za32_f16_m(
1855
- 0, query_predicate_f16x, predicate_all_f16x, probability_column_3_f32x,
1856
- svld1_f16(predicate_all_f16x, (float16_t const *)(v_block + ((dim_tile + 0) * 8 + 3) * 32)));
1861
+ 0, query_predicate_b16x, predicate_all_b16x, probability_column_3_f32x,
1862
+ svld1_f16(predicate_all_b16x, (float16_t const *)(v_block + ((dim_tile + 0) * 8 + 3) * 32)));
1857
1863
  svmopa_za32_f16_m(
1858
- 1, query_predicate_f16x, predicate_all_f16x, probability_column_3_f32x,
1859
- svld1_f16(predicate_all_f16x, (float16_t const *)(v_block + ((dim_tile + 1) * 8 + 3) * 32)));
1864
+ 1, query_predicate_b16x, predicate_all_b16x, probability_column_3_f32x,
1865
+ svld1_f16(predicate_all_b16x, (float16_t const *)(v_block + ((dim_tile + 1) * 8 + 3) * 32)));
1860
1866
  svmopa_za32_f16_m(
1861
- 2, query_predicate_f16x, predicate_all_f16x, probability_column_3_f32x,
1862
- svld1_f16(predicate_all_f16x, (float16_t const *)(v_block + ((dim_tile + 2) * 8 + 3) * 32)));
1867
+ 2, query_predicate_b16x, predicate_all_b16x, probability_column_3_f32x,
1868
+ svld1_f16(predicate_all_b16x, (float16_t const *)(v_block + ((dim_tile + 2) * 8 + 3) * 32)));
1863
1869
  svmopa_za32_f16_m(
1864
- 3, query_predicate_f16x, predicate_all_f16x, probability_column_3_f32x,
1865
- svld1_f16(predicate_all_f16x, (float16_t const *)(v_block + ((dim_tile + 3) * 8 + 3) * 32)));
1870
+ 3, query_predicate_b16x, predicate_all_b16x, probability_column_3_f32x,
1871
+ svld1_f16(predicate_all_b16x, (float16_t const *)(v_block + ((dim_tile + 3) * 8 + 3) * 32)));
1866
1872
  svmopa_za32_f16_m(
1867
- 0, query_predicate_f16x, predicate_all_f16x, probability_column_4_f32x,
1868
- svld1_f16(predicate_all_f16x, (float16_t const *)(v_block + ((dim_tile + 0) * 8 + 4) * 32)));
1873
+ 0, query_predicate_b16x, predicate_all_b16x, probability_column_4_f32x,
1874
+ svld1_f16(predicate_all_b16x, (float16_t const *)(v_block + ((dim_tile + 0) * 8 + 4) * 32)));
1869
1875
  svmopa_za32_f16_m(
1870
- 1, query_predicate_f16x, predicate_all_f16x, probability_column_4_f32x,
1871
- svld1_f16(predicate_all_f16x, (float16_t const *)(v_block + ((dim_tile + 1) * 8 + 4) * 32)));
1876
+ 1, query_predicate_b16x, predicate_all_b16x, probability_column_4_f32x,
1877
+ svld1_f16(predicate_all_b16x, (float16_t const *)(v_block + ((dim_tile + 1) * 8 + 4) * 32)));
1872
1878
  svmopa_za32_f16_m(
1873
- 2, query_predicate_f16x, predicate_all_f16x, probability_column_4_f32x,
1874
- svld1_f16(predicate_all_f16x, (float16_t const *)(v_block + ((dim_tile + 2) * 8 + 4) * 32)));
1879
+ 2, query_predicate_b16x, predicate_all_b16x, probability_column_4_f32x,
1880
+ svld1_f16(predicate_all_b16x, (float16_t const *)(v_block + ((dim_tile + 2) * 8 + 4) * 32)));
1875
1881
  svmopa_za32_f16_m(
1876
- 3, query_predicate_f16x, predicate_all_f16x, probability_column_4_f32x,
1877
- svld1_f16(predicate_all_f16x, (float16_t const *)(v_block + ((dim_tile + 3) * 8 + 4) * 32)));
1882
+ 3, query_predicate_b16x, predicate_all_b16x, probability_column_4_f32x,
1883
+ svld1_f16(predicate_all_b16x, (float16_t const *)(v_block + ((dim_tile + 3) * 8 + 4) * 32)));
1878
1884
  svmopa_za32_f16_m(
1879
- 0, query_predicate_f16x, predicate_all_f16x, probability_column_5_f32x,
1880
- svld1_f16(predicate_all_f16x, (float16_t const *)(v_block + ((dim_tile + 0) * 8 + 5) * 32)));
1885
+ 0, query_predicate_b16x, predicate_all_b16x, probability_column_5_f32x,
1886
+ svld1_f16(predicate_all_b16x, (float16_t const *)(v_block + ((dim_tile + 0) * 8 + 5) * 32)));
1881
1887
  svmopa_za32_f16_m(
1882
- 1, query_predicate_f16x, predicate_all_f16x, probability_column_5_f32x,
1883
- svld1_f16(predicate_all_f16x, (float16_t const *)(v_block + ((dim_tile + 1) * 8 + 5) * 32)));
1888
+ 1, query_predicate_b16x, predicate_all_b16x, probability_column_5_f32x,
1889
+ svld1_f16(predicate_all_b16x, (float16_t const *)(v_block + ((dim_tile + 1) * 8 + 5) * 32)));
1884
1890
  svmopa_za32_f16_m(
1885
- 2, query_predicate_f16x, predicate_all_f16x, probability_column_5_f32x,
1886
- svld1_f16(predicate_all_f16x, (float16_t const *)(v_block + ((dim_tile + 2) * 8 + 5) * 32)));
1891
+ 2, query_predicate_b16x, predicate_all_b16x, probability_column_5_f32x,
1892
+ svld1_f16(predicate_all_b16x, (float16_t const *)(v_block + ((dim_tile + 2) * 8 + 5) * 32)));
1887
1893
  svmopa_za32_f16_m(
1888
- 3, query_predicate_f16x, predicate_all_f16x, probability_column_5_f32x,
1889
- svld1_f16(predicate_all_f16x, (float16_t const *)(v_block + ((dim_tile + 3) * 8 + 5) * 32)));
1894
+ 3, query_predicate_b16x, predicate_all_b16x, probability_column_5_f32x,
1895
+ svld1_f16(predicate_all_b16x, (float16_t const *)(v_block + ((dim_tile + 3) * 8 + 5) * 32)));
1890
1896
  svmopa_za32_f16_m(
1891
- 0, query_predicate_f16x, predicate_all_f16x, probability_column_6_f32x,
1892
- svld1_f16(predicate_all_f16x, (float16_t const *)(v_block + ((dim_tile + 0) * 8 + 6) * 32)));
1897
+ 0, query_predicate_b16x, predicate_all_b16x, probability_column_6_f32x,
1898
+ svld1_f16(predicate_all_b16x, (float16_t const *)(v_block + ((dim_tile + 0) * 8 + 6) * 32)));
1893
1899
  svmopa_za32_f16_m(
1894
- 1, query_predicate_f16x, predicate_all_f16x, probability_column_6_f32x,
1895
- svld1_f16(predicate_all_f16x, (float16_t const *)(v_block + ((dim_tile + 1) * 8 + 6) * 32)));
1900
+ 1, query_predicate_b16x, predicate_all_b16x, probability_column_6_f32x,
1901
+ svld1_f16(predicate_all_b16x, (float16_t const *)(v_block + ((dim_tile + 1) * 8 + 6) * 32)));
1896
1902
  svmopa_za32_f16_m(
1897
- 2, query_predicate_f16x, predicate_all_f16x, probability_column_6_f32x,
1898
- svld1_f16(predicate_all_f16x, (float16_t const *)(v_block + ((dim_tile + 2) * 8 + 6) * 32)));
1903
+ 2, query_predicate_b16x, predicate_all_b16x, probability_column_6_f32x,
1904
+ svld1_f16(predicate_all_b16x, (float16_t const *)(v_block + ((dim_tile + 2) * 8 + 6) * 32)));
1899
1905
  svmopa_za32_f16_m(
1900
- 3, query_predicate_f16x, predicate_all_f16x, probability_column_6_f32x,
1901
- svld1_f16(predicate_all_f16x, (float16_t const *)(v_block + ((dim_tile + 3) * 8 + 6) * 32)));
1906
+ 3, query_predicate_b16x, predicate_all_b16x, probability_column_6_f32x,
1907
+ svld1_f16(predicate_all_b16x, (float16_t const *)(v_block + ((dim_tile + 3) * 8 + 6) * 32)));
1902
1908
  svmopa_za32_f16_m(
1903
- 0, query_predicate_f16x, predicate_all_f16x, probability_column_7_f32x,
1904
- svld1_f16(predicate_all_f16x, (float16_t const *)(v_block + ((dim_tile + 0) * 8 + 7) * 32)));
1909
+ 0, query_predicate_b16x, predicate_all_b16x, probability_column_7_f32x,
1910
+ svld1_f16(predicate_all_b16x, (float16_t const *)(v_block + ((dim_tile + 0) * 8 + 7) * 32)));
1905
1911
  svmopa_za32_f16_m(
1906
- 1, query_predicate_f16x, predicate_all_f16x, probability_column_7_f32x,
1907
- svld1_f16(predicate_all_f16x, (float16_t const *)(v_block + ((dim_tile + 1) * 8 + 7) * 32)));
1912
+ 1, query_predicate_b16x, predicate_all_b16x, probability_column_7_f32x,
1913
+ svld1_f16(predicate_all_b16x, (float16_t const *)(v_block + ((dim_tile + 1) * 8 + 7) * 32)));
1908
1914
  svmopa_za32_f16_m(
1909
- 2, query_predicate_f16x, predicate_all_f16x, probability_column_7_f32x,
1910
- svld1_f16(predicate_all_f16x, (float16_t const *)(v_block + ((dim_tile + 2) * 8 + 7) * 32)));
1915
+ 2, query_predicate_b16x, predicate_all_b16x, probability_column_7_f32x,
1916
+ svld1_f16(predicate_all_b16x, (float16_t const *)(v_block + ((dim_tile + 2) * 8 + 7) * 32)));
1911
1917
  svmopa_za32_f16_m(
1912
- 3, query_predicate_f16x, predicate_all_f16x, probability_column_7_f32x,
1913
- svld1_f16(predicate_all_f16x, (float16_t const *)(v_block + ((dim_tile + 3) * 8 + 7) * 32)));
1918
+ 3, query_predicate_b16x, predicate_all_b16x, probability_column_7_f32x,
1919
+ svld1_f16(predicate_all_b16x, (float16_t const *)(v_block + ((dim_tile + 3) * 8 + 7) * 32)));
1914
1920
  for (nk_size_t query_index = 0; query_index < valid_query_count; query_index++) {
1915
1921
  svst1_f32(
1916
- predicate_all_f32x, output_accumulator + query_index * head_dim_padded + (dim_tile + 0) * 16,
1917
- svadd_f32_x(predicate_all_f32x,
1918
- svld1_f32(predicate_all_f32x,
1922
+ predicate_all_b32x, output_accumulator + query_index * head_dim_padded + (dim_tile + 0) * 16,
1923
+ svadd_f32_x(predicate_all_b32x,
1924
+ svld1_f32(predicate_all_b32x,
1919
1925
  output_accumulator + query_index * head_dim_padded + (dim_tile + 0) * 16),
1920
- svread_hor_za32_f32_m(svdup_f32(0), predicate_all_f32x, 0, query_index)));
1926
+ svread_hor_za32_f32_m(svdup_f32(0), predicate_all_b32x, 0, query_index)));
1921
1927
  svst1_f32(
1922
- predicate_all_f32x, output_accumulator + query_index * head_dim_padded + (dim_tile + 1) * 16,
1923
- svadd_f32_x(predicate_all_f32x,
1924
- svld1_f32(predicate_all_f32x,
1928
+ predicate_all_b32x, output_accumulator + query_index * head_dim_padded + (dim_tile + 1) * 16,
1929
+ svadd_f32_x(predicate_all_b32x,
1930
+ svld1_f32(predicate_all_b32x,
1925
1931
  output_accumulator + query_index * head_dim_padded + (dim_tile + 1) * 16),
1926
- svread_hor_za32_f32_m(svdup_f32(0), predicate_all_f32x, 1, query_index)));
1932
+ svread_hor_za32_f32_m(svdup_f32(0), predicate_all_b32x, 1, query_index)));
1927
1933
  svst1_f32(
1928
- predicate_all_f32x, output_accumulator + query_index * head_dim_padded + (dim_tile + 2) * 16,
1929
- svadd_f32_x(predicate_all_f32x,
1930
- svld1_f32(predicate_all_f32x,
1934
+ predicate_all_b32x, output_accumulator + query_index * head_dim_padded + (dim_tile + 2) * 16,
1935
+ svadd_f32_x(predicate_all_b32x,
1936
+ svld1_f32(predicate_all_b32x,
1931
1937
  output_accumulator + query_index * head_dim_padded + (dim_tile + 2) * 16),
1932
- svread_hor_za32_f32_m(svdup_f32(0), predicate_all_f32x, 2, query_index)));
1938
+ svread_hor_za32_f32_m(svdup_f32(0), predicate_all_b32x, 2, query_index)));
1933
1939
  svst1_f32(
1934
- predicate_all_f32x, output_accumulator + query_index * head_dim_padded + (dim_tile + 3) * 16,
1935
- svadd_f32_x(predicate_all_f32x,
1936
- svld1_f32(predicate_all_f32x,
1940
+ predicate_all_b32x, output_accumulator + query_index * head_dim_padded + (dim_tile + 3) * 16,
1941
+ svadd_f32_x(predicate_all_b32x,
1942
+ svld1_f32(predicate_all_b32x,
1937
1943
  output_accumulator + query_index * head_dim_padded + (dim_tile + 3) * 16),
1938
- svread_hor_za32_f32_m(svdup_f32(0), predicate_all_f32x, 3, query_index)));
1944
+ svread_hor_za32_f32_m(svdup_f32(0), predicate_all_b32x, 3, query_index)));
1939
1945
  }
1940
1946
  }
1941
1947
  for (; dim_tile < dim_tile_count; dim_tile++) {
1942
1948
  svzero_mask_za(nk_sme_zero_za32_tile_0_);
1943
1949
  svmopa_za32_f16_m(
1944
- 0, query_predicate_f16x, predicate_all_f16x, probability_column_0_f32x,
1945
- svld1_f16(predicate_all_f16x, (float16_t const *)(v_block + (dim_tile * 8 + 0) * 32)));
1950
+ 0, query_predicate_b16x, predicate_all_b16x, probability_column_0_f32x,
1951
+ svld1_f16(predicate_all_b16x, (float16_t const *)(v_block + (dim_tile * 8 + 0) * 32)));
1946
1952
  svmopa_za32_f16_m(
1947
- 0, query_predicate_f16x, predicate_all_f16x, probability_column_1_f32x,
1948
- svld1_f16(predicate_all_f16x, (float16_t const *)(v_block + (dim_tile * 8 + 1) * 32)));
1953
+ 0, query_predicate_b16x, predicate_all_b16x, probability_column_1_f32x,
1954
+ svld1_f16(predicate_all_b16x, (float16_t const *)(v_block + (dim_tile * 8 + 1) * 32)));
1949
1955
  svmopa_za32_f16_m(
1950
- 0, query_predicate_f16x, predicate_all_f16x, probability_column_2_f32x,
1951
- svld1_f16(predicate_all_f16x, (float16_t const *)(v_block + (dim_tile * 8 + 2) * 32)));
1956
+ 0, query_predicate_b16x, predicate_all_b16x, probability_column_2_f32x,
1957
+ svld1_f16(predicate_all_b16x, (float16_t const *)(v_block + (dim_tile * 8 + 2) * 32)));
1952
1958
  svmopa_za32_f16_m(
1953
- 0, query_predicate_f16x, predicate_all_f16x, probability_column_3_f32x,
1954
- svld1_f16(predicate_all_f16x, (float16_t const *)(v_block + (dim_tile * 8 + 3) * 32)));
1959
+ 0, query_predicate_b16x, predicate_all_b16x, probability_column_3_f32x,
1960
+ svld1_f16(predicate_all_b16x, (float16_t const *)(v_block + (dim_tile * 8 + 3) * 32)));
1955
1961
  svmopa_za32_f16_m(
1956
- 0, query_predicate_f16x, predicate_all_f16x, probability_column_4_f32x,
1957
- svld1_f16(predicate_all_f16x, (float16_t const *)(v_block + (dim_tile * 8 + 4) * 32)));
1962
+ 0, query_predicate_b16x, predicate_all_b16x, probability_column_4_f32x,
1963
+ svld1_f16(predicate_all_b16x, (float16_t const *)(v_block + (dim_tile * 8 + 4) * 32)));
1958
1964
  svmopa_za32_f16_m(
1959
- 0, query_predicate_f16x, predicate_all_f16x, probability_column_5_f32x,
1960
- svld1_f16(predicate_all_f16x, (float16_t const *)(v_block + (dim_tile * 8 + 5) * 32)));
1965
+ 0, query_predicate_b16x, predicate_all_b16x, probability_column_5_f32x,
1966
+ svld1_f16(predicate_all_b16x, (float16_t const *)(v_block + (dim_tile * 8 + 5) * 32)));
1961
1967
  svmopa_za32_f16_m(
1962
- 0, query_predicate_f16x, predicate_all_f16x, probability_column_6_f32x,
1963
- svld1_f16(predicate_all_f16x, (float16_t const *)(v_block + (dim_tile * 8 + 6) * 32)));
1968
+ 0, query_predicate_b16x, predicate_all_b16x, probability_column_6_f32x,
1969
+ svld1_f16(predicate_all_b16x, (float16_t const *)(v_block + (dim_tile * 8 + 6) * 32)));
1964
1970
  svmopa_za32_f16_m(
1965
- 0, query_predicate_f16x, predicate_all_f16x, probability_column_7_f32x,
1966
- svld1_f16(predicate_all_f16x, (float16_t const *)(v_block + (dim_tile * 8 + 7) * 32)));
1971
+ 0, query_predicate_b16x, predicate_all_b16x, probability_column_7_f32x,
1972
+ svld1_f16(predicate_all_b16x, (float16_t const *)(v_block + (dim_tile * 8 + 7) * 32)));
1967
1973
  for (nk_size_t query_index = 0; query_index < valid_query_count; query_index++)
1968
- svst1_f32(predicate_all_f32x, output_accumulator + query_index * head_dim_padded + dim_tile * 16,
1969
- svadd_f32_x(predicate_all_f32x,
1970
- svld1_f32(predicate_all_f32x,
1974
+ svst1_f32(predicate_all_b32x, output_accumulator + query_index * head_dim_padded + dim_tile * 16,
1975
+ svadd_f32_x(predicate_all_b32x,
1976
+ svld1_f32(predicate_all_b32x,
1971
1977
  output_accumulator + query_index * head_dim_padded + dim_tile * 16),
1972
- svread_hor_za32_f32_m(svdup_f32(0), predicate_all_f32x, 0, query_index)));
1978
+ svread_hor_za32_f32_m(svdup_f32(0), predicate_all_b32x, 0, query_index)));
1973
1979
  }
1974
1980
  }
1975
1981
  }
1976
1982
 
1977
1983
  // Final normalization
1978
- svfloat32_t final_sum_f32x = svld1_f32(predicate_all_f32x, row_sum);
1984
+ svfloat32_t final_sum_f32x = svld1_f32(predicate_all_b32x, row_sum);
1979
1985
  svfloat32_t ones_f32x = svdup_f32(1.0f);
1980
1986
  svfloat32_t zeros_f32x = svdup_f32(0.0f);
1981
- svbool_t sum_positive = svcmpgt_f32(predicate_all_f32x, final_sum_f32x, zeros_f32x);
1982
- svfloat32_t inv_sum_f32x = svsel_f32(sum_positive, svdiv_f32_x(predicate_all_f32x, ones_f32x, final_sum_f32x),
1987
+ svbool_t sum_positive_b32x = svcmpgt_f32(predicate_all_b32x, final_sum_f32x, zeros_f32x);
1988
+ svfloat32_t inv_sum_f32x = svsel_f32(sum_positive_b32x, svdiv_f32_x(predicate_all_b32x, ones_f32x, final_sum_f32x),
1983
1989
  zeros_f32x);
1984
1990
 
1985
1991
  NK_ALIGN64 nk_f32_t inv_sums[16];
1986
- svst1_f32(predicate_all_f32x, inv_sums, inv_sum_f32x);
1992
+ svst1_f32(predicate_all_b32x, inv_sums, inv_sum_f32x);
1987
1993
 
1988
1994
  for (nk_size_t query_index = 0; query_index < valid_query_count; query_index++) {
1989
1995
  svfloat32_t inv_sum_f32x = svdup_f32(inv_sums[query_index]);
1990
1996
  for (nk_size_t dim_offset = 0; dim_offset < head_dim; dim_offset += svcntw()) {
1991
- svbool_t predicate_f32x = svwhilelt_b32_u64(dim_offset, head_dim);
1997
+ svbool_t predicate_b32x = svwhilelt_b32_u64(dim_offset, head_dim);
1992
1998
  svfloat32_t output_f32x = svmul_f32_x(
1993
- predicate_f32x,
1994
- svld1_f32(predicate_f32x, output_accumulator + query_index * head_dim_padded + dim_offset),
1999
+ predicate_b32x,
2000
+ svld1_f32(predicate_b32x, output_accumulator + query_index * head_dim_padded + dim_offset),
1995
2001
  inv_sum_f32x);
1996
- svfloat16_t output_f16x = svcvt_f16_f32_x(predicate_f32x, output_f32x);
2002
+ svfloat16_t output_f16x = svcvt_f16_f32_x(predicate_b32x, output_f32x);
1997
2003
  nk_size_t store_count = (head_dim - dim_offset) < (nk_size_t)svcntw() ? (head_dim - dim_offset)
1998
2004
  : (nk_size_t)svcntw();
1999
- svbool_t predicate_f16x = svwhilelt_b16_u64(0u, store_count);
2000
- svst1_f16(predicate_f16x, (float16_t *)(output + query_index * head_dim + dim_offset), output_f16x);
2005
+ svbool_t predicate_b16x = svwhilelt_b16_u64(0u, store_count);
2006
+ svst1_f16(predicate_b16x, (float16_t *)(output + query_index * head_dim + dim_offset), output_f16x);
2001
2007
  }
2002
2008
  }
2003
2009
  }