numkong 7.0.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/LICENSE +201 -0
- package/README.md +495 -0
- package/binding.gyp +540 -0
- package/c/dispatch.h +512 -0
- package/c/dispatch_bf16.c +389 -0
- package/c/dispatch_bf16c.c +52 -0
- package/c/dispatch_e2m3.c +263 -0
- package/c/dispatch_e3m2.c +243 -0
- package/c/dispatch_e4m3.c +276 -0
- package/c/dispatch_e5m2.c +272 -0
- package/c/dispatch_f16.c +376 -0
- package/c/dispatch_f16c.c +58 -0
- package/c/dispatch_f32.c +378 -0
- package/c/dispatch_f32c.c +99 -0
- package/c/dispatch_f64.c +296 -0
- package/c/dispatch_f64c.c +98 -0
- package/c/dispatch_i16.c +96 -0
- package/c/dispatch_i32.c +89 -0
- package/c/dispatch_i4.c +150 -0
- package/c/dispatch_i64.c +86 -0
- package/c/dispatch_i8.c +289 -0
- package/c/dispatch_other.c +330 -0
- package/c/dispatch_u1.c +148 -0
- package/c/dispatch_u16.c +124 -0
- package/c/dispatch_u32.c +118 -0
- package/c/dispatch_u4.c +150 -0
- package/c/dispatch_u64.c +102 -0
- package/c/dispatch_u8.c +303 -0
- package/c/numkong.c +950 -0
- package/include/README.md +573 -0
- package/include/module.modulemap +129 -0
- package/include/numkong/attention/sapphireamx.h +1361 -0
- package/include/numkong/attention/sme.h +2066 -0
- package/include/numkong/attention.h +49 -0
- package/include/numkong/capabilities.h +748 -0
- package/include/numkong/cast/README.md +262 -0
- package/include/numkong/cast/haswell.h +975 -0
- package/include/numkong/cast/icelake.h +470 -0
- package/include/numkong/cast/neon.h +1192 -0
- package/include/numkong/cast/rvv.h +1021 -0
- package/include/numkong/cast/sapphire.h +262 -0
- package/include/numkong/cast/serial.h +2262 -0
- package/include/numkong/cast/skylake.h +856 -0
- package/include/numkong/cast/v128relaxed.h +180 -0
- package/include/numkong/cast.h +230 -0
- package/include/numkong/curved/README.md +223 -0
- package/include/numkong/curved/genoa.h +182 -0
- package/include/numkong/curved/haswell.h +276 -0
- package/include/numkong/curved/neon.h +205 -0
- package/include/numkong/curved/neonbfdot.h +212 -0
- package/include/numkong/curved/neonhalf.h +212 -0
- package/include/numkong/curved/rvv.h +305 -0
- package/include/numkong/curved/serial.h +207 -0
- package/include/numkong/curved/skylake.h +457 -0
- package/include/numkong/curved/smef64.h +506 -0
- package/include/numkong/curved.h +517 -0
- package/include/numkong/curved.hpp +144 -0
- package/include/numkong/dot/README.md +425 -0
- package/include/numkong/dot/alder.h +563 -0
- package/include/numkong/dot/genoa.h +315 -0
- package/include/numkong/dot/haswell.h +1688 -0
- package/include/numkong/dot/icelake.h +883 -0
- package/include/numkong/dot/neon.h +818 -0
- package/include/numkong/dot/neonbfdot.h +244 -0
- package/include/numkong/dot/neonfhm.h +360 -0
- package/include/numkong/dot/neonhalf.h +198 -0
- package/include/numkong/dot/neonsdot.h +508 -0
- package/include/numkong/dot/rvv.h +714 -0
- package/include/numkong/dot/rvvbb.h +72 -0
- package/include/numkong/dot/rvvbf16.h +123 -0
- package/include/numkong/dot/rvvhalf.h +129 -0
- package/include/numkong/dot/sapphire.h +141 -0
- package/include/numkong/dot/serial.h +838 -0
- package/include/numkong/dot/sierra.h +405 -0
- package/include/numkong/dot/skylake.h +1084 -0
- package/include/numkong/dot/sve.h +379 -0
- package/include/numkong/dot/svebfdot.h +74 -0
- package/include/numkong/dot/svehalf.h +123 -0
- package/include/numkong/dot/v128relaxed.h +1258 -0
- package/include/numkong/dot.h +1070 -0
- package/include/numkong/dot.hpp +94 -0
- package/include/numkong/dots/README.md +496 -0
- package/include/numkong/dots/alder.h +114 -0
- package/include/numkong/dots/genoa.h +94 -0
- package/include/numkong/dots/haswell.h +295 -0
- package/include/numkong/dots/icelake.h +171 -0
- package/include/numkong/dots/neon.h +120 -0
- package/include/numkong/dots/neonbfdot.h +58 -0
- package/include/numkong/dots/neonfhm.h +94 -0
- package/include/numkong/dots/neonhalf.h +57 -0
- package/include/numkong/dots/neonsdot.h +108 -0
- package/include/numkong/dots/rvv.h +2486 -0
- package/include/numkong/dots/sapphireamx.h +3973 -0
- package/include/numkong/dots/serial.h +2844 -0
- package/include/numkong/dots/sierra.h +97 -0
- package/include/numkong/dots/skylake.h +196 -0
- package/include/numkong/dots/sme.h +5372 -0
- package/include/numkong/dots/smebi32.h +461 -0
- package/include/numkong/dots/smef64.h +1318 -0
- package/include/numkong/dots/smehalf.h +47 -0
- package/include/numkong/dots/v128relaxed.h +294 -0
- package/include/numkong/dots.h +2804 -0
- package/include/numkong/dots.hpp +639 -0
- package/include/numkong/each/README.md +469 -0
- package/include/numkong/each/haswell.h +1658 -0
- package/include/numkong/each/icelake.h +272 -0
- package/include/numkong/each/neon.h +1104 -0
- package/include/numkong/each/neonbfdot.h +212 -0
- package/include/numkong/each/neonhalf.h +410 -0
- package/include/numkong/each/rvv.h +1121 -0
- package/include/numkong/each/sapphire.h +477 -0
- package/include/numkong/each/serial.h +260 -0
- package/include/numkong/each/skylake.h +1562 -0
- package/include/numkong/each.h +2146 -0
- package/include/numkong/each.hpp +434 -0
- package/include/numkong/geospatial/README.md +147 -0
- package/include/numkong/geospatial/haswell.h +593 -0
- package/include/numkong/geospatial/neon.h +571 -0
- package/include/numkong/geospatial/rvv.h +701 -0
- package/include/numkong/geospatial/serial.h +309 -0
- package/include/numkong/geospatial/skylake.h +577 -0
- package/include/numkong/geospatial/v128relaxed.h +613 -0
- package/include/numkong/geospatial.h +453 -0
- package/include/numkong/geospatial.hpp +235 -0
- package/include/numkong/matrix.hpp +336 -0
- package/include/numkong/maxsim/README.md +187 -0
- package/include/numkong/maxsim/alder.h +511 -0
- package/include/numkong/maxsim/genoa.h +115 -0
- package/include/numkong/maxsim/haswell.h +553 -0
- package/include/numkong/maxsim/icelake.h +480 -0
- package/include/numkong/maxsim/neonsdot.h +394 -0
- package/include/numkong/maxsim/sapphireamx.h +877 -0
- package/include/numkong/maxsim/serial.h +490 -0
- package/include/numkong/maxsim/sme.h +929 -0
- package/include/numkong/maxsim/v128relaxed.h +280 -0
- package/include/numkong/maxsim.h +571 -0
- package/include/numkong/maxsim.hpp +133 -0
- package/include/numkong/mesh/README.md +227 -0
- package/include/numkong/mesh/haswell.h +2235 -0
- package/include/numkong/mesh/neon.h +1329 -0
- package/include/numkong/mesh/neonbfdot.h +842 -0
- package/include/numkong/mesh/neonhalf.h +616 -0
- package/include/numkong/mesh/rvv.h +916 -0
- package/include/numkong/mesh/serial.h +742 -0
- package/include/numkong/mesh/skylake.h +1135 -0
- package/include/numkong/mesh/v128relaxed.h +1052 -0
- package/include/numkong/mesh.h +652 -0
- package/include/numkong/mesh.hpp +762 -0
- package/include/numkong/numkong.h +78 -0
- package/include/numkong/numkong.hpp +57 -0
- package/include/numkong/probability/README.md +173 -0
- package/include/numkong/probability/haswell.h +267 -0
- package/include/numkong/probability/neon.h +225 -0
- package/include/numkong/probability/rvv.h +409 -0
- package/include/numkong/probability/serial.h +169 -0
- package/include/numkong/probability/skylake.h +324 -0
- package/include/numkong/probability.h +383 -0
- package/include/numkong/probability.hpp +120 -0
- package/include/numkong/random.h +50 -0
- package/include/numkong/random.hpp +285 -0
- package/include/numkong/reduce/README.md +547 -0
- package/include/numkong/reduce/alder.h +632 -0
- package/include/numkong/reduce/genoa.h +201 -0
- package/include/numkong/reduce/haswell.h +3783 -0
- package/include/numkong/reduce/icelake.h +549 -0
- package/include/numkong/reduce/neon.h +3841 -0
- package/include/numkong/reduce/neonbfdot.h +353 -0
- package/include/numkong/reduce/neonfhm.h +665 -0
- package/include/numkong/reduce/neonhalf.h +157 -0
- package/include/numkong/reduce/neonsdot.h +357 -0
- package/include/numkong/reduce/rvv.h +3407 -0
- package/include/numkong/reduce/serial.h +757 -0
- package/include/numkong/reduce/sierra.h +338 -0
- package/include/numkong/reduce/skylake.h +3792 -0
- package/include/numkong/reduce/v128relaxed.h +2302 -0
- package/include/numkong/reduce.h +1597 -0
- package/include/numkong/reduce.hpp +633 -0
- package/include/numkong/scalar/README.md +89 -0
- package/include/numkong/scalar/haswell.h +113 -0
- package/include/numkong/scalar/neon.h +122 -0
- package/include/numkong/scalar/neonhalf.h +70 -0
- package/include/numkong/scalar/rvv.h +211 -0
- package/include/numkong/scalar/sapphire.h +63 -0
- package/include/numkong/scalar/serial.h +332 -0
- package/include/numkong/scalar/v128relaxed.h +56 -0
- package/include/numkong/scalar.h +683 -0
- package/include/numkong/set/README.md +179 -0
- package/include/numkong/set/haswell.h +334 -0
- package/include/numkong/set/icelake.h +485 -0
- package/include/numkong/set/neon.h +364 -0
- package/include/numkong/set/rvv.h +226 -0
- package/include/numkong/set/rvvbb.h +117 -0
- package/include/numkong/set/serial.h +174 -0
- package/include/numkong/set/sve.h +185 -0
- package/include/numkong/set/v128relaxed.h +240 -0
- package/include/numkong/set.h +457 -0
- package/include/numkong/set.hpp +114 -0
- package/include/numkong/sets/README.md +149 -0
- package/include/numkong/sets/haswell.h +63 -0
- package/include/numkong/sets/icelake.h +66 -0
- package/include/numkong/sets/neon.h +61 -0
- package/include/numkong/sets/serial.h +43 -0
- package/include/numkong/sets/smebi32.h +1099 -0
- package/include/numkong/sets/v128relaxed.h +58 -0
- package/include/numkong/sets.h +339 -0
- package/include/numkong/sparse/README.md +156 -0
- package/include/numkong/sparse/icelake.h +463 -0
- package/include/numkong/sparse/neon.h +288 -0
- package/include/numkong/sparse/serial.h +117 -0
- package/include/numkong/sparse/sve2.h +507 -0
- package/include/numkong/sparse/turin.h +322 -0
- package/include/numkong/sparse.h +363 -0
- package/include/numkong/sparse.hpp +113 -0
- package/include/numkong/spatial/README.md +435 -0
- package/include/numkong/spatial/alder.h +607 -0
- package/include/numkong/spatial/genoa.h +290 -0
- package/include/numkong/spatial/haswell.h +960 -0
- package/include/numkong/spatial/icelake.h +586 -0
- package/include/numkong/spatial/neon.h +773 -0
- package/include/numkong/spatial/neonbfdot.h +165 -0
- package/include/numkong/spatial/neonhalf.h +118 -0
- package/include/numkong/spatial/neonsdot.h +261 -0
- package/include/numkong/spatial/rvv.h +984 -0
- package/include/numkong/spatial/rvvbf16.h +123 -0
- package/include/numkong/spatial/rvvhalf.h +117 -0
- package/include/numkong/spatial/sapphire.h +343 -0
- package/include/numkong/spatial/serial.h +346 -0
- package/include/numkong/spatial/sierra.h +323 -0
- package/include/numkong/spatial/skylake.h +606 -0
- package/include/numkong/spatial/sve.h +224 -0
- package/include/numkong/spatial/svebfdot.h +122 -0
- package/include/numkong/spatial/svehalf.h +109 -0
- package/include/numkong/spatial/v128relaxed.h +717 -0
- package/include/numkong/spatial.h +1425 -0
- package/include/numkong/spatial.hpp +183 -0
- package/include/numkong/spatials/README.md +580 -0
- package/include/numkong/spatials/alder.h +94 -0
- package/include/numkong/spatials/genoa.h +94 -0
- package/include/numkong/spatials/haswell.h +219 -0
- package/include/numkong/spatials/icelake.h +113 -0
- package/include/numkong/spatials/neon.h +109 -0
- package/include/numkong/spatials/neonbfdot.h +60 -0
- package/include/numkong/spatials/neonfhm.h +92 -0
- package/include/numkong/spatials/neonhalf.h +58 -0
- package/include/numkong/spatials/neonsdot.h +109 -0
- package/include/numkong/spatials/rvv.h +1960 -0
- package/include/numkong/spatials/sapphireamx.h +1149 -0
- package/include/numkong/spatials/serial.h +226 -0
- package/include/numkong/spatials/sierra.h +96 -0
- package/include/numkong/spatials/skylake.h +184 -0
- package/include/numkong/spatials/sme.h +1901 -0
- package/include/numkong/spatials/smef64.h +465 -0
- package/include/numkong/spatials/v128relaxed.h +240 -0
- package/include/numkong/spatials.h +3021 -0
- package/include/numkong/spatials.hpp +508 -0
- package/include/numkong/tensor.hpp +1592 -0
- package/include/numkong/trigonometry/README.md +184 -0
- package/include/numkong/trigonometry/haswell.h +652 -0
- package/include/numkong/trigonometry/neon.h +639 -0
- package/include/numkong/trigonometry/rvv.h +699 -0
- package/include/numkong/trigonometry/serial.h +703 -0
- package/include/numkong/trigonometry/skylake.h +721 -0
- package/include/numkong/trigonometry/v128relaxed.h +666 -0
- package/include/numkong/trigonometry.h +467 -0
- package/include/numkong/trigonometry.hpp +166 -0
- package/include/numkong/types.h +1384 -0
- package/include/numkong/types.hpp +5603 -0
- package/include/numkong/vector.hpp +698 -0
- package/javascript/README.md +246 -0
- package/javascript/dist/cjs/numkong-wasm.d.ts +166 -0
- package/javascript/dist/cjs/numkong-wasm.js +617 -0
- package/javascript/dist/cjs/numkong.d.ts +343 -0
- package/javascript/dist/cjs/numkong.js +523 -0
- package/javascript/dist/cjs/package.json +3 -0
- package/javascript/dist/cjs/types.d.ts +284 -0
- package/javascript/dist/cjs/types.js +653 -0
- package/javascript/dist/esm/numkong-wasm.d.ts +166 -0
- package/javascript/dist/esm/numkong-wasm.js +595 -0
- package/javascript/dist/esm/numkong.d.ts +343 -0
- package/javascript/dist/esm/numkong.js +452 -0
- package/javascript/dist/esm/package.json +3 -0
- package/javascript/dist/esm/types.d.ts +284 -0
- package/javascript/dist/esm/types.js +630 -0
- package/javascript/dist-package-cjs.json +3 -0
- package/javascript/dist-package-esm.json +3 -0
- package/javascript/node-gyp-build.d.ts +1 -0
- package/javascript/numkong-wasm.ts +756 -0
- package/javascript/numkong.c +689 -0
- package/javascript/numkong.ts +575 -0
- package/javascript/tsconfig-base.json +39 -0
- package/javascript/tsconfig-cjs.json +8 -0
- package/javascript/tsconfig-esm.json +8 -0
- package/javascript/types.ts +674 -0
- package/package.json +87 -0
|
@@ -0,0 +1,2066 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* @brief FlashAttention-style kernels for SME.
|
|
3
|
+
* @file include/numkong/attention/sme.h
|
|
4
|
+
* @author Ash Vardanian
|
|
5
|
+
* @date January 11, 2026
|
|
6
|
+
*
|
|
7
|
+
* @sa include/numkong/attention.h
|
|
8
|
+
*
|
|
9
|
+
* This file implements FlashAttention-2 style scaled dot-product attention (SDPA) optimized
|
|
10
|
+
* for ARM SME instructions on Apple M4 and similar processors. The kernel computes:
|
|
11
|
+
*
|
|
12
|
+
* O = softmax(Q × Kᵀ / √d) × V
|
|
13
|
+
*
|
|
14
|
+
* Key features:
|
|
15
|
+
* - Online softmax: Mathematically exact, processes KV blocks incrementally
|
|
16
|
+
* - Pre-packed KV cache: BFMOPA/FMOPA-interleaved format amortizes packing for repeated inference
|
|
17
|
+
* - GQA/MQA support: Different `num_heads` and `num_kv_heads` for grouped-query attention
|
|
18
|
+
* - Pure Streaming SVE: No NEON intrinsics for non-linear operations
|
|
19
|
+
*
|
|
20
|
+
* Target models (2025):
|
|
21
|
+
* - Kimi K2: `head_dim`=112, 64 heads, MHA, 128K context
|
|
22
|
+
* - LLaMA 3.1 405B: `head_dim`=128, 128 heads, 16 KV heads (GQA 8:1), 128K context
|
|
23
|
+
* - Qwen 2.5 72B: `head_dim`=128, 64 heads, 8 KV heads (GQA 8:1), 32K context
|
|
24
|
+
*
|
|
25
|
+
* @section attention_sme_architecture Architecture
|
|
26
|
+
*
|
|
27
|
+
* Both Q×Kᵀ and P×V phases use BFMOPA/FMOPA outer products on ZA tiles, eliminating
|
|
28
|
+
* element-wise SVE loops that dominated the original implementation. The Q matrix is
|
|
29
|
+
* pre-transposed once into a buffer matching the interleaving that ZA vertical reads
|
|
30
|
+
* would produce, so Q×Kᵀ runs as pure memory-to-BFMOPA with no per-block ZA staging.
|
|
31
|
+
*
|
|
32
|
+
* Block sizes:
|
|
33
|
+
* - Bᵣ = 16 (query block rows, matches ZA32 tile height)
|
|
34
|
+
* - Bᶜ = 32 (main prefill loop, processes two KV blocks per iteration using ZA2+ZA3)
|
|
35
|
+
* - Bᶜ = 16 (tail loop for remaining KV positions, and decode path)
|
|
36
|
+
*
|
|
37
|
+
* KV packing format:
|
|
38
|
+
* - K is stored in BFMOPA-interleaved format: `K_packed[kv_block][depth_step][32]` where
|
|
39
|
+
* `packed[2*ki + sub] = K[kv_block*16 + ki][2*depth_step + sub]`
|
|
40
|
+
* - V is stored in BFMOPA-interleaved format: `V_packed[kv_block][dim_tile][depth_step][32]`
|
|
41
|
+
* where `packed[2*dj + sub] = V[kv_block*16 + 2*depth_step + sub][dim_tile*16 + dj]`
|
|
42
|
+
* - The `reserved[0]` header field stores `v_dim_tile_count` for efficient V addressing
|
|
43
|
+
*
|
|
44
|
+
* Softmax:
|
|
45
|
+
* - Column-wise max and exp using ZA tile vertical reads (avoids per-row horizontal extracts)
|
|
46
|
+
* - Correction skip: when the block max does not exceed the running max, the output
|
|
47
|
+
* accumulator rescaling is skipped entirely (common in later KV blocks)
|
|
48
|
+
* - Degree-3 fast exp (`nk_exp_fast_f32_sve_`) saves 1 FMA per call vs degree-4
|
|
49
|
+
* - Weights stored directly as bf16/f16 in ZA0 columns via `svzip1` (no f32 round-trip)
|
|
50
|
+
*
|
|
51
|
+
* Decode path (query_len=1):
|
|
52
|
+
* - Uses element-wise SVE with scalar weight broadcasts instead of BFMOPA P×V
|
|
53
|
+
* - BFMOPA overhead too high for single-query case due to ZA setup cost
|
|
54
|
+
*
|
|
55
|
+
* P×V prefill path:
|
|
56
|
+
* - 4-tile BFMOPA processing: 4 dim-tiles × 8 depth steps per KV block = 32 BFMOPA ops
|
|
57
|
+
* - ZA0-ZA3 accumulate simultaneously, read results with MOVA, add to output accumulator
|
|
58
|
+
* - Remainder dim-tiles handled 1-at-a-time using ZA0 only
|
|
59
|
+
*
|
|
60
|
+
* SME tile dimensions (for SVL=512, i.e., Apple M4):
|
|
61
|
+
* - ZA32 tile: 16 × 16 `f32` elements (1KB)
|
|
62
|
+
* - `bf16`/`f16` vectors: 32 elements per SVE vector
|
|
63
|
+
*
|
|
64
|
+
* @section attention_sme_history Optimization History
|
|
65
|
+
*
|
|
66
|
+
* Phase 1 (January 2026): Initial implementation using ZA staging transpose for Q×Kᵀ
|
|
67
|
+
* and element-wise SVE for P×V. Q and K rows were loaded into ZA0/ZA1 horizontally,
|
|
68
|
+
* read back vertically to produce interleaved vectors for BFMOPA. The P×V phase used
|
|
69
|
+
* scalar `svmla_f32_x` loops over head_dim for each query-key pair. Softmax used
|
|
70
|
+
* degree-4 polynomial exp with per-row horizontal max/sum. Performance: ~25-50 GFLOP/s
|
|
71
|
+
* on Apple M4 (bf16, 8 heads, query_len=64, kv_len=4096, head_dim=128).
|
|
72
|
+
*
|
|
73
|
+
* Phase 2 (February 2026): BFMOPA/FMOPA P×V with pre-packed V in interleaved format.
|
|
74
|
+
* Key changes integrated:
|
|
75
|
+
* - Q pre-transposed once into a buffer, eliminating per-block ZA staging for Q
|
|
76
|
+
* - K pre-packed in interleaved format, enabling pure memory-to-BFMOPA Q×Kᵀ
|
|
77
|
+
* - V pre-packed in BFMOPA-interleaved format with dim-tile blocking
|
|
78
|
+
* - P×V uses 4-tile BFMOPA accumulation (ZA0-ZA3) with pre-extracted P columns
|
|
79
|
+
* - Bᶜ=32 main loop for prefill (2 KV blocks per iteration via ZA2+ZA3)
|
|
80
|
+
* - Column-wise softmax: vertical ZA reads for max/exp instead of per-row horizontal
|
|
81
|
+
* - Correction skip when running max is unchanged
|
|
82
|
+
* - Degree-3 fast exp (~0.5% max relative error, saves 1 FMA per call)
|
|
83
|
+
* - Weights stored directly as bf16/f16 via `svzip1` (no f32 quantization round-trip)
|
|
84
|
+
* Performance: ~300-400 GFLOP/s on Apple M4 (same configuration), a 6-14× improvement.
|
|
85
|
+
*
|
|
86
|
+
* Rejected approaches:
|
|
87
|
+
* - BFMOPA P×V for decode (query_len=1): ZA setup overhead exceeds element-wise SVE cost
|
|
88
|
+
* - `svdot_lane` for Q×Kᵀ: lower throughput than BFMOPA on M4
|
|
89
|
+
* - Shared ZA tiles between softmax and P×V: register pressure too high with 4-tile P×V
|
|
90
|
+
*/
|
|
91
|
+
#ifndef NK_ATTENTION_SME_H
|
|
92
|
+
#define NK_ATTENTION_SME_H
|
|
93
|
+
|
|
94
|
+
#if NK_TARGET_ARM_
|
|
95
|
+
#if NK_TARGET_SME
|
|
96
|
+
|
|
97
|
+
#include "numkong/types.h"
|
|
98
|
+
|
|
99
|
+
#if defined(__cplusplus)
|
|
100
|
+
extern "C" {
|
|
101
|
+
#endif
|
|
102
|
+
|
|
103
|
+
#if defined(__clang__)
|
|
104
|
+
#pragma clang attribute push(__attribute__((target("sme,sve"))), apply_to = function)
|
|
105
|
+
#elif defined(__GNUC__)
|
|
106
|
+
#pragma GCC push_options
|
|
107
|
+
#pragma GCC target("+sme")
|
|
108
|
+
#endif
|
|
109
|
+
|
|
110
|
+
/**
|
|
111
|
+
* @brief Convert bf16 vector to f32 in registers (streaming SVE compatible).
|
|
112
|
+
*
|
|
113
|
+
* BF16 is the upper 16 bits of F32, so we:
|
|
114
|
+
* 1. Reinterpret bf16 as u16
|
|
115
|
+
* 2. Zero-extend to u32 (unpklo for lower half)
|
|
116
|
+
* 3. Shift left by 16 to place in f32 exponent+mantissa position
|
|
117
|
+
* 4. Reinterpret as f32
|
|
118
|
+
*/
|
|
119
|
+
NK_INTERNAL svfloat32_t nk_bf16_to_f32_sve_(svbool_t predicate_f32x, svbfloat16_t x_bf16x) __arm_streaming {
|
|
120
|
+
svuint16_t x_u16x = svreinterpret_u16_bf16(x_bf16x);
|
|
121
|
+
svuint32_t x_u32x = svunpklo_u32(x_u16x);
|
|
122
|
+
x_u32x = svlsl_n_u32_x(predicate_f32x, x_u32x, 16);
|
|
123
|
+
return svreinterpret_f32_u32(x_u32x);
|
|
124
|
+
}
|
|
125
|
+
|
|
126
|
+
/**
|
|
127
|
+
* @brief Convert f32 vector to bf16 in registers with rounding (streaming SVE compatible).
|
|
128
|
+
*
|
|
129
|
+
* 1. Reinterpret f32 as u32
|
|
130
|
+
* 2. Add rounding bias (0x8000) for round-to-nearest
|
|
131
|
+
* 3. Shift right by 16
|
|
132
|
+
* 4. Narrow to u16 and reinterpret as bf16
|
|
133
|
+
*/
|
|
134
|
+
NK_INTERNAL svbfloat16_t nk_f32_to_bf16_sve_(svbool_t predicate_f32x, svfloat32_t x_f32x) __arm_streaming {
|
|
135
|
+
svuint32_t x_u32x = svreinterpret_u32_f32(x_f32x);
|
|
136
|
+
x_u32x = svadd_n_u32_x(predicate_f32x, x_u32x, 0x8000); // Round to nearest
|
|
137
|
+
x_u32x = svlsr_n_u32_x(predicate_f32x, x_u32x, 16);
|
|
138
|
+
svuint16_t x_u16x = svuzp1_u16(svreinterpret_u16_u32(x_u32x), svreinterpret_u16_u32(x_u32x));
|
|
139
|
+
return svreinterpret_bf16_u16(x_u16x);
|
|
140
|
+
}
|
|
141
|
+
|
|
142
|
+
/**
|
|
143
|
+
* @brief Packed KV cache header for attention (64-byte aligned).
|
|
144
|
+
*
|
|
145
|
+
* Layout in memory:
|
|
146
|
+
* [header: 64 bytes][K tiles: variable][V tiles: variable]
|
|
147
|
+
*/
|
|
148
|
+
typedef struct {
|
|
149
|
+
nk_u32_t num_kv_heads; ///< Number of K/V heads (for GQA, may differ from Q heads)
|
|
150
|
+
nk_u32_t head_dim; ///< Original head dimension (64, 112, 128)
|
|
151
|
+
nk_u32_t head_dim_padded; ///< Padded to multiple of 32 for SME
|
|
152
|
+
nk_u32_t seq_len; ///< Current sequence length
|
|
153
|
+
nk_u32_t max_seq_len; ///< Maximum sequence length (for pre-allocation)
|
|
154
|
+
nk_u32_t k_offset; ///< Byte offset to K data from header start
|
|
155
|
+
nk_u32_t v_offset; ///< Byte offset to V data from header start
|
|
156
|
+
nk_u32_t reserved[9]; ///< reserved[0] = v_dim_tile_count; remainder pads to 64 bytes
|
|
157
|
+
} nk_attention_sme_packed_header_t;
|
|
158
|
+
|
|
159
|
+
/**
|
|
160
|
+
* @brief Fast exp approximation in Streaming SVE.
|
|
161
|
+
*
|
|
162
|
+
* Uses Cody-Waite range reduction + Horner polynomial (degree 4).
|
|
163
|
+
* Accuracy: ~0.1% relative error, acceptable for softmax normalization.
|
|
164
|
+
*
|
|
165
|
+
* @param pg Active predicate
|
|
166
|
+
* @param x Input vector
|
|
167
|
+
* @return exp(x) approximation
|
|
168
|
+
*/
|
|
169
|
+
NK_INTERNAL svfloat32_t nk_exp_f32_sve_(svbool_t predicate_f32x, svfloat32_t x_f32x) __arm_streaming {
|
|
170
|
+
// Constants for Cody-Waite range reduction
|
|
171
|
+
svfloat32_t log2e_f32x = svdup_f32(1.4426950408889634f);
|
|
172
|
+
svfloat32_t ln2_hi_f32x = svdup_f32(0.693145751953125f);
|
|
173
|
+
svfloat32_t ln2_lo_f32x = svdup_f32(1.42860682030941723212e-6f);
|
|
174
|
+
|
|
175
|
+
// Clamp to avoid overflow/underflow
|
|
176
|
+
svfloat32_t max_x_f32x = svdup_f32(88.3762626647949f);
|
|
177
|
+
svfloat32_t min_x_f32x = svdup_f32(-87.3365447504021f);
|
|
178
|
+
x_f32x = svmax_f32_m(predicate_f32x, svmin_f32_m(predicate_f32x, x_f32x, max_x_f32x), min_x_f32x);
|
|
179
|
+
|
|
180
|
+
// n = round(x / ln(2))
|
|
181
|
+
svfloat32_t n_f32x = svrintn_f32_m(svundef_f32(), predicate_f32x, svmul_f32_m(predicate_f32x, x_f32x, log2e_f32x));
|
|
182
|
+
|
|
183
|
+
// r = x - n × ln(2) using Cody-Waite for precision
|
|
184
|
+
svfloat32_t r_f32x = svmsb_f32_m(predicate_f32x, n_f32x, ln2_hi_f32x, x_f32x);
|
|
185
|
+
r_f32x = svmsb_f32_m(predicate_f32x, n_f32x, ln2_lo_f32x, r_f32x);
|
|
186
|
+
|
|
187
|
+
// Polynomial approximation for exp(r): degree 4
|
|
188
|
+
// exp(r) ≈ 1 + r + r²/2 + r³/6 + r⁴/24
|
|
189
|
+
svfloat32_t p_f32x = svdup_f32(4.1666666667e-2f); // 1/24
|
|
190
|
+
p_f32x = svmad_f32_m(predicate_f32x, p_f32x, r_f32x, svdup_f32(1.6666666667e-1f)); // 1/6
|
|
191
|
+
p_f32x = svmad_f32_m(predicate_f32x, p_f32x, r_f32x, svdup_f32(5.0000000000e-1f)); // 1/2
|
|
192
|
+
p_f32x = svmad_f32_m(predicate_f32x, p_f32x, r_f32x, svdup_f32(1.0f)); // 1
|
|
193
|
+
p_f32x = svmad_f32_m(predicate_f32x, p_f32x, r_f32x, svdup_f32(1.0f)); // 1
|
|
194
|
+
|
|
195
|
+
// Reconstruct: exp(x) = 2ⁿ × exp(r)
|
|
196
|
+
// 2ⁿ via IEEE 754 exponent manipulation
|
|
197
|
+
svint32_t n_i32x = svcvt_s32_f32_m(svundef_s32(), predicate_f32x, n_f32x);
|
|
198
|
+
n_i32x = svadd_s32_m(predicate_f32x, n_i32x, svdup_s32(127));
|
|
199
|
+
n_i32x = svlsl_n_s32_m(predicate_f32x, n_i32x, 23);
|
|
200
|
+
svfloat32_t pow2n_f32x = svreinterpret_f32_s32(n_i32x);
|
|
201
|
+
|
|
202
|
+
return svmul_f32_m(predicate_f32x, p_f32x, pow2n_f32x);
|
|
203
|
+
}
|
|
204
|
+
|
|
205
|
+
/**
|
|
206
|
+
* @brief Degree-3 fast exp approximation. Max relative error ~0.5%.
|
|
207
|
+
* Saves 1 FMA per call vs degree-4 nk_exp_f32_sve_.
|
|
208
|
+
*/
|
|
209
|
+
NK_INTERNAL svfloat32_t nk_exp_fast_f32_sve_(svbool_t predicate_f32x, svfloat32_t x_f32x) __arm_streaming {
|
|
210
|
+
svfloat32_t log2e_f32x = svdup_f32(1.4426950408889634f);
|
|
211
|
+
svfloat32_t ln2_hi_f32x = svdup_f32(0.693145751953125f);
|
|
212
|
+
svfloat32_t ln2_lo_f32x = svdup_f32(1.42860682030941723212e-6f);
|
|
213
|
+
|
|
214
|
+
svfloat32_t max_x_f32x = svdup_f32(88.3762626647949f);
|
|
215
|
+
svfloat32_t min_x_f32x = svdup_f32(-87.3365447504021f);
|
|
216
|
+
x_f32x = svmax_f32_m(predicate_f32x, svmin_f32_m(predicate_f32x, x_f32x, max_x_f32x), min_x_f32x);
|
|
217
|
+
|
|
218
|
+
svfloat32_t n_f32x = svrintn_f32_m(svundef_f32(), predicate_f32x, svmul_f32_m(predicate_f32x, x_f32x, log2e_f32x));
|
|
219
|
+
svfloat32_t r_f32x = svmsb_f32_m(predicate_f32x, n_f32x, ln2_hi_f32x, x_f32x);
|
|
220
|
+
r_f32x = svmsb_f32_m(predicate_f32x, n_f32x, ln2_lo_f32x, r_f32x);
|
|
221
|
+
|
|
222
|
+
// Degree-3: exp(r) ~ 1 + r + r^2/2 + r^3/6 (drop 1/24 term)
|
|
223
|
+
svfloat32_t p_f32x = svdup_f32(1.6666666667e-1f); // 1/6
|
|
224
|
+
p_f32x = svmad_f32_m(predicate_f32x, p_f32x, r_f32x, svdup_f32(5.0000000000e-1f)); // 1/2
|
|
225
|
+
p_f32x = svmad_f32_m(predicate_f32x, p_f32x, r_f32x, svdup_f32(1.0f)); // 1
|
|
226
|
+
p_f32x = svmad_f32_m(predicate_f32x, p_f32x, r_f32x, svdup_f32(1.0f)); // 1
|
|
227
|
+
|
|
228
|
+
svint32_t n_i32x = svcvt_s32_f32_m(svundef_s32(), predicate_f32x, n_f32x);
|
|
229
|
+
n_i32x = svadd_s32_m(predicate_f32x, n_i32x, svdup_s32(127));
|
|
230
|
+
n_i32x = svlsl_n_s32_m(predicate_f32x, n_i32x, 23);
|
|
231
|
+
svfloat32_t pow2n_f32x = svreinterpret_f32_s32(n_i32x);
|
|
232
|
+
|
|
233
|
+
return svmul_f32_m(predicate_f32x, p_f32x, pow2n_f32x);
|
|
234
|
+
}
|
|
235
|
+
|
|
236
|
+
/**
 * @brief Number of bytes needed for a packed bf16 KV cache buffer.
 *
 * The head dimension is rounded up to a multiple of 32 and the sequence length to a
 * multiple of 16. K and V each occupy `num_kv_heads × padded_rows × padded_dim` bf16
 * elements, and the total also includes the 64-byte header.
 *
 * @param num_kv_heads Number of K/V heads.
 * @param head_dim Unpadded head dimension.
 * @param max_seq_len Sequence length to size the buffer for.
 * @return Header size plus the K and V payload sizes, in bytes. No overflow checking is done.
 */
NK_PUBLIC nk_size_t nk_attention_packed_kv_size_bf16_sme(nk_size_t num_kv_heads, nk_size_t head_dim,
                                                         nk_size_t max_seq_len) {
    nk_size_t padded_dim = ((head_dim + 31) / 32) * 32;
    nk_size_t padded_rows = ((max_seq_len + 15) / 16) * 16;
    // K and V share the same BFMOPA-interleaved footprint: [num_kv_heads, kv_blocks, depth_steps, 32].
    nk_size_t bytes_per_tensor = num_kv_heads * padded_rows * padded_dim * sizeof(nk_bf16_t);
    return sizeof(nk_attention_sme_packed_header_t) + bytes_per_tensor + bytes_per_tensor;
}
|
|
246
|
+
|
|
247
|
+
/**
 * @brief Number of bytes needed for a packed f16 KV cache buffer.
 *
 * Forwards to the bf16 variant with the same arguments, so the result is computed with
 * `sizeof(nk_bf16_t)` (presumably because f16 and bf16 elements are the same width —
 * TODO confirm against the type definitions).
 *
 * @param num_kv_heads Number of K/V heads.
 * @param head_dim Unpadded head dimension.
 * @param max_seq_len Sequence length to size the buffer for.
 * @return Required buffer size in bytes.
 */
NK_PUBLIC nk_size_t nk_attention_packed_kv_size_f16_sme(nk_size_t num_kv_heads, nk_size_t head_dim,
                                                        nk_size_t max_seq_len) {
    nk_size_t packed_bytes = nk_attention_packed_kv_size_bf16_sme(num_kv_heads, head_dim, max_seq_len);
    return packed_bytes;
}
|
|
251
|
+
|
|
252
|
+
/**
 * @brief Pack bf16 K and V tensors into the BFMOPA-interleaved KV cache layout.
 *
 * Writes the 64-byte header at the start of `kv_packed`, followed by the packed K data
 * for all heads and then the packed V data for all heads. Positions beyond `seq_len`
 * rows or `head_dim` columns are filled with zeros.
 *
 * Every header word is initialized. This function has no separate capacity argument and
 * the layout it produces (including `v_offset`) is sized for exactly `seq_len` rows, so
 * `max_seq_len` is recorded as `seq_len`; the unused reserved words are zeroed so the
 * header never carries stale memory contents.
 *
 * @param k K tensor; head `h` starts at `k + h * k_stride`, and row `r` of a head starts
 *          `r * head_dim` elements further on.
 * @param v V tensor, addressed the same way using `v_stride`.
 * @param num_kv_heads Number of K/V heads to pack.
 * @param head_dim Unpadded head dimension.
 * @param seq_len Number of KV rows to pack per head.
 * @param k_stride Distance between consecutive K heads, in elements.
 * @param v_stride Distance between consecutive V heads, in elements.
 * @param kv_packed Destination buffer; it must hold the header plus
 *                  `2 × num_kv_heads × ceil(seq_len / 16) × 16 × head_dim_padded` bf16 elements.
 */
__arm_locally_streaming static void nk_attention_pack_kv_bf16_sme_streaming_(nk_bf16_t const *k, nk_bf16_t const *v,
                                                                             nk_size_t num_kv_heads, nk_size_t head_dim,
                                                                             nk_size_t seq_len, nk_size_t k_stride,
                                                                             nk_size_t v_stride, void *kv_packed) {

    nk_attention_sme_packed_header_t *header = (nk_attention_sme_packed_header_t *)kv_packed;
    nk_size_t head_dim_padded = (head_dim + 31) / 32 * 32;
    nk_size_t dim_tile_count = (head_dim_padded + 15) / 16;
    nk_size_t kv_block_count = (seq_len + 15) / 16;

    nk_size_t k_depth_step_count = head_dim_padded / 2;
    nk_size_t head_elems = kv_block_count * 16 * head_dim_padded;

    header->num_kv_heads = (nk_u32_t)num_kv_heads;
    header->head_dim = (nk_u32_t)head_dim;
    header->head_dim_padded = (nk_u32_t)head_dim_padded;
    header->seq_len = (nk_u32_t)seq_len;
    header->max_seq_len = (nk_u32_t)seq_len; // layout below is sized for exactly seq_len rows
    header->k_offset = sizeof(nk_attention_sme_packed_header_t);
    header->v_offset = header->k_offset + (nk_u32_t)(num_kv_heads * head_elems * sizeof(nk_bf16_t));
    header->reserved[0] = (nk_u32_t)dim_tile_count; // v_dim_tile_count
    for (nk_size_t word = 1; word < sizeof(header->reserved) / sizeof(header->reserved[0]); word++)
        header->reserved[word] = 0; // padding words: keep them deterministic

    nk_bf16_t *k_packed = (nk_bf16_t *)((char *)kv_packed + header->k_offset);
    nk_bf16_t *v_packed = (nk_bf16_t *)((char *)kv_packed + header->v_offset);

    // Fill value for padded rows and columns.
    nk_bf16_t zero = {0};

    for (nk_size_t h = 0; h < num_kv_heads; h++) {
        nk_bf16_t const *k_head = k + h * k_stride;
        nk_bf16_t const *v_head = v + h * v_stride;

        // K packing: BFMOPA-interleaved format
        // K_packed[kv_block][depth_step][32] where
        // packed[2*ki + sub] = K[kv_block*16 + ki][2*depth_step + sub]
        nk_bf16_t *k_out = k_packed + h * head_elems;
        for (nk_size_t kv_block = 0; kv_block < kv_block_count; kv_block++) {
            for (nk_size_t depth_step = 0; depth_step < k_depth_step_count; depth_step++) {
                nk_bf16_t *vec_out = k_out + (kv_block * k_depth_step_count + depth_step) * 32;
                for (nk_size_t ki = 0; ki < 16; ki++) {
                    nk_size_t row = kv_block * 16 + ki;
                    for (nk_size_t sub = 0; sub < 2; sub++) {
                        nk_size_t col = 2 * depth_step + sub;
                        vec_out[2 * ki + sub] = (row < seq_len && col < head_dim) ? k_head[row * head_dim + col] : zero;
                    }
                }
            }
        }

        // V packing: BFMOPA-interleaved format
        // V_packed[kv_block][dim_tile][depth_step][32] where
        // packed[2*dj + sub] = V[kv_block*16 + 2*depth_step + sub][dim_tile*16 + dj]
        nk_bf16_t *v_out = v_packed + h * head_elems;
        for (nk_size_t kv_block = 0; kv_block < kv_block_count; kv_block++) {
            for (nk_size_t dim_tile = 0; dim_tile < dim_tile_count; dim_tile++) {
                for (nk_size_t depth_step = 0; depth_step < 8; depth_step++) {
                    nk_bf16_t *vec_out = v_out + (kv_block * dim_tile_count * 8 + dim_tile * 8 + depth_step) * 32;
                    for (nk_size_t dj = 0; dj < 16; dj++) {
                        nk_size_t d = dim_tile * 16 + dj;
                        for (nk_size_t sub = 0; sub < 2; sub++) {
                            nk_size_t ki = kv_block * 16 + 2 * depth_step + sub;
                            vec_out[2 * dj + sub] = (ki < seq_len && d < head_dim) ? v_head[ki * head_dim + d] : zero;
                        }
                    }
                }
            }
        }
    }
}
|
|
317
|
+
|
|
318
|
+
/**
 *  @brief Public entry point for packing bf16 K and V into the SME attention layout.
 *
 *  Hands every argument, unchanged and in the same order, to
 *  nk_attention_pack_kv_bf16_sme_streaming_, which does the actual packing.
 *
 *  @param k Source keys.
 *  @param v Source values.
 *  @param num_kv_heads Number of KV heads to pack.
 *  @param head_dim Unpadded head dimension.
 *  @param seq_len Number of KV positions.
 *  @param k_stride Offset, in elements, between consecutive heads of @p k.
 *  @param v_stride Offset, in elements, between consecutive heads of @p v.
 *  @param kv_packed Destination buffer receiving the packed header and data.
 */
NK_PUBLIC void nk_attention_pack_kv_bf16_sme( //
    nk_bf16_t const *k, nk_bf16_t const *v,   //
    nk_size_t num_kv_heads, nk_size_t head_dim, nk_size_t seq_len, //
    nk_size_t k_stride, nk_size_t v_stride, void *kv_packed) {
    // Pure forwarding: the locally-streaming helper owns all of the packing logic.
    nk_attention_pack_kv_bf16_sme_streaming_( //
        k, v, num_kv_heads, head_dim, seq_len, k_stride, v_stride, kv_packed);
}
|
|
323
|
+
|
|
324
|
+
/**
 *  @brief Packs f16 K and V matrices into the FMOPA-interleaved layout behind a packed header.
 *
 *  The destination starts with a `nk_attention_sme_packed_header_t`, followed by the packed K
 *  region for all heads and then the packed V region for all heads. Each head occupies
 *  `kv_block_count * 16 * head_dim_padded` elements in both regions, where `head_dim_padded`
 *  is @p head_dim rounded up to a multiple of 32 and `kv_block_count` is @p seq_len rounded
 *  up to whole blocks of 16 positions.
 *
 *  K layout, per head: `[kv_block][depth_step][32]`, where
 *  `vector[2 * ki + sub] = K[kv_block * 16 + ki][2 * depth_step + sub]`.
 *
 *  V layout, per head: `[kv_block][dim_tile][8 depth steps][32]`, where
 *  `vector[2 * dj + sub] = V[kv_block * 16 + 2 * depth_step + sub][dim_tile * 16 + dj]`.
 *
 *  Positions at or beyond @p seq_len and columns at or beyond @p head_dim are written as zero;
 *  the source is never read for them.
 *
 *  @param k Source keys; within a head, element (row, col) is read at `row * head_dim + col`.
 *  @param v Source values; indexed the same way as @p k.
 *  @param num_kv_heads Number of KV heads to pack.
 *  @param head_dim Unpadded head dimension.
 *  @param seq_len Number of KV positions.
 *  @param k_stride Offset, in elements, between consecutive heads of @p k.
 *  @param v_stride Offset, in elements, between consecutive heads of @p v.
 *  @param kv_packed Destination buffer: header first, then packed K, then packed V.
 */
__arm_locally_streaming static void nk_attention_pack_kv_f16_sme_streaming_( //
    nk_f16_t const *k, nk_f16_t const *v,                                    //
    nk_size_t num_kv_heads, nk_size_t head_dim, nk_size_t seq_len,           //
    nk_size_t k_stride, nk_size_t v_stride, void *kv_packed) {

    // Derived geometry: depth padded to a multiple of 32, sequence split into 16-row blocks.
    nk_size_t const head_dim_padded = (head_dim + 31) / 32 * 32;
    nk_size_t const dim_tile_count = (head_dim_padded + 15) / 16;
    nk_size_t const kv_block_count = (seq_len + 15) / 16;
    nk_size_t const k_depth_step_count = head_dim_padded / 2;
    nk_size_t const head_elems = kv_block_count * 16 * head_dim_padded;

    // Describe the packed buffer in its header; V begins right after the K region of all heads.
    nk_attention_sme_packed_header_t *header = (nk_attention_sme_packed_header_t *)kv_packed;
    header->num_kv_heads = (nk_u32_t)num_kv_heads;
    header->head_dim = (nk_u32_t)head_dim;
    header->head_dim_padded = (nk_u32_t)head_dim_padded;
    header->seq_len = (nk_u32_t)seq_len;
    header->k_offset = sizeof(nk_attention_sme_packed_header_t);
    header->reserved[0] = (nk_u32_t)dim_tile_count; // v_dim_tile_count
    header->v_offset = header->k_offset + (nk_u32_t)(num_kv_heads * head_elems * sizeof(nk_f16_t));

    char *const packed_bytes = (char *)kv_packed;
    nk_f16_t *k_packed = (nk_f16_t *)(packed_bytes + header->k_offset);
    nk_f16_t *v_packed = (nk_f16_t *)(packed_bytes + header->v_offset);

    for (nk_size_t head = 0; head < num_kv_heads; ++head) {
        nk_f16_t const *k_source = k + head * k_stride;
        nk_f16_t const *v_source = v + head * v_stride;

        // K packing: FMOPA-interleaved format.
        // Each 32-element vector holds 16 KV rows x 2 adjacent depth columns, row-major by row.
        nk_f16_t *k_target = k_packed + head * head_elems;
        for (nk_size_t block = 0; block < kv_block_count; ++block) {
            for (nk_size_t step = 0; step < k_depth_step_count; ++step) {
                nk_f16_t *lanes = k_target + (block * k_depth_step_count + step) * 32;
                for (nk_size_t slot = 0; slot < 32; ++slot) {
                    nk_size_t const row = block * 16 + slot / 2; // KV position
                    nk_size_t const col = 2 * step + slot % 2;   // depth index
                    nk_f16_t value = {0};
                    if (row < seq_len && col < head_dim) value = k_source[row * head_dim + col];
                    lanes[slot] = value;
                }
            }
        }

        // V packing: FMOPA-interleaved format.
        // Each 32-element vector holds 16 output columns x 2 adjacent KV rows, column-major pairs.
        nk_f16_t *v_target = v_packed + head * head_elems;
        for (nk_size_t block = 0; block < kv_block_count; ++block) {
            for (nk_size_t tile = 0; tile < dim_tile_count; ++tile) {
                for (nk_size_t step = 0; step < 8; ++step) {
                    nk_f16_t *lanes = v_target + (block * dim_tile_count * 8 + tile * 8 + step) * 32;
                    for (nk_size_t slot = 0; slot < 32; ++slot) {
                        nk_size_t const row = block * 16 + 2 * step + slot % 2; // KV position
                        nk_size_t const col = tile * 16 + slot / 2;             // head dimension index
                        nk_f16_t value = {0};
                        if (row < seq_len && col < head_dim) value = v_source[row * head_dim + col];
                        lanes[slot] = value;
                    }
                }
            }
        }
    }
}
|
|
387
|
+
|
|
388
|
+
/**
 *  @brief Public entry point for packing f16 K and V into the SME attention layout.
 *
 *  Hands every argument, unchanged and in the same order, to
 *  nk_attention_pack_kv_f16_sme_streaming_, which does the actual packing.
 *
 *  @param k Source keys.
 *  @param v Source values.
 *  @param num_kv_heads Number of KV heads to pack.
 *  @param head_dim Unpadded head dimension.
 *  @param seq_len Number of KV positions.
 *  @param k_stride Offset, in elements, between consecutive heads of @p k.
 *  @param v_stride Offset, in elements, between consecutive heads of @p v.
 *  @param kv_packed Destination buffer receiving the packed header and data.
 */
NK_PUBLIC void nk_attention_pack_kv_f16_sme( //
    nk_f16_t const *k, nk_f16_t const *v,    //
    nk_size_t num_kv_heads, nk_size_t head_dim, nk_size_t seq_len, //
    nk_size_t k_stride, nk_size_t v_stride, void *kv_packed) {
    // Pure forwarding: the locally-streaming helper owns all of the packing logic.
    nk_attention_pack_kv_f16_sme_streaming_( //
        k, v, num_kv_heads, head_dim, seq_len, k_stride, v_stride, kv_packed);
}
|
|
393
|
+
|
|
394
|
+
/**
|
|
395
|
+
* @brief Optimized bf16 attention kernel with BFMOPA P×V.
|
|
396
|
+
*
|
|
397
|
+
* Key design choices:
|
|
398
|
+
* - P×V uses BFMOPA with pre-packed V (4-tile accumulation) instead of element-wise SVE
|
|
399
|
+
* - Scores read via column-wise vertical ZA reads for vectorized max/exp
|
|
400
|
+
* - Weights stored directly as bf16 (no f32 round-trip)
|
|
401
|
+
* - Uses degree-3 fast exp for softmax
|
|
402
|
+
* - Correction skip when running max is unchanged
|
|
403
|
+
* - Decode path (valid_query_count==1) remains element-wise SVE (BFMOPA overhead too high)
|
|
404
|
+
*/
|
|
405
|
+
__arm_locally_streaming __arm_new("za") static void nk_attention_bf16_sme_streaming_(
|
|
406
|
+
nk_bf16_t const *q, // [query_len, head_dim]
|
|
407
|
+
nk_bf16_t const *k, // [kv_len, head_dim_padded] BFMOPA-interleaved
|
|
408
|
+
nk_bf16_t const *v_packed, // BFMOPA-interleaved V for this KV head
|
|
409
|
+
nk_bf16_t *output, // [query_len, head_dim]
|
|
410
|
+
nk_size_t query_len, nk_size_t kv_len, nk_size_t head_dim, nk_size_t head_dim_padded, nk_size_t dim_tile_count,
|
|
411
|
+
nk_f32_t scale) {
|
|
412
|
+
|
|
413
|
+
svbool_t const predicate_all_f32x = svptrue_b32();
|
|
414
|
+
svbool_t const predicate_all_f16x = svptrue_b16();
|
|
415
|
+
nk_size_t const valid_query_count = (query_len < 16) ? query_len : 16;
|
|
416
|
+
|
|
417
|
+
svfloat32_t row_max_f32x = svdup_f32(NK_F32_MIN);
|
|
418
|
+
svfloat32_t row_sum_f32x = svdup_f32(0.0f);
|
|
419
|
+
|
|
420
|
+
NK_ALIGN64 nk_f32_t output_accumulator[16 * 256];
|
|
421
|
+
svfloat32_t zero_f32x = svdup_f32(0.0f);
|
|
422
|
+
for (nk_size_t i = 0; i < 16 * head_dim_padded; i += svcntw()) {
|
|
423
|
+
svst1_f32(predicate_all_f32x, output_accumulator + i, zero_f32x);
|
|
424
|
+
}
|
|
425
|
+
|
|
426
|
+
nk_size_t kv_block_index = 0;
|
|
427
|
+
nk_size_t kv_start = 0;
|
|
428
|
+
svbool_t const batch_predicate_f32x = svwhilelt_b32(0u, 16u);
|
|
429
|
+
|
|
430
|
+
nk_size_t const k_depth_step_count = head_dim_padded / 2;
|
|
431
|
+
|
|
432
|
+
// Pre-transpose Q once: queries_transposed[step][16 f32 words]
|
|
433
|
+
NK_ALIGN64 nk_f32_t queries_transposed[128 * 16]; // max head_dim_padded/2 * 16 = 128 * 16
|
|
434
|
+
for (nk_size_t batch = 0; batch < head_dim_padded / 32; batch++) {
|
|
435
|
+
svzero_mask_za(nk_sme_zero_za32_tile_0_);
|
|
436
|
+
for (nk_size_t query_index = 0; query_index < valid_query_count; query_index++)
|
|
437
|
+
svld1_hor_za32(0, query_index, batch_predicate_f32x,
|
|
438
|
+
(nk_f32_t const *)(q + query_index * head_dim + batch * 32));
|
|
439
|
+
for (nk_size_t step = 0; step < 16; step++)
|
|
440
|
+
svst1_f32(predicate_all_f32x, queries_transposed + (batch * 16 + step) * 16,
|
|
441
|
+
svread_ver_za32_f32_m(svdup_f32(0), predicate_all_f32x, 0, step));
|
|
442
|
+
}
|
|
443
|
+
|
|
444
|
+
// Bc=32 main loop (prefill only, skipped for decode)
|
|
445
|
+
if (valid_query_count > 1) {
|
|
446
|
+
for (; kv_start + 32 <= kv_len; kv_start += 32, kv_block_index += 2) {
|
|
447
|
+
// Q×K^T: pure memory→BFMOPA, no ZA staging for Q or K
|
|
448
|
+
svzero_mask_za(nk_sme_zero_za32_tile_2_);
|
|
449
|
+
svzero_mask_za(nk_sme_zero_za32_tile_3_);
|
|
450
|
+
nk_bf16_t const *keys_block_lower = k + kv_block_index * k_depth_step_count * 32;
|
|
451
|
+
nk_bf16_t const *keys_block_upper = k + (kv_block_index + 1) * k_depth_step_count * 32;
|
|
452
|
+
for (nk_size_t step = 0; step < k_depth_step_count; step++) {
|
|
453
|
+
svbfloat16_t zn = svreinterpret_bf16_f32(svld1_f32(predicate_all_f32x, queries_transposed + step * 16));
|
|
454
|
+
svbfloat16_t zm0 = svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(keys_block_lower + step * 32));
|
|
455
|
+
svbfloat16_t zm1 = svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(keys_block_upper + step * 32));
|
|
456
|
+
svmopa_za32_bf16_m(2, predicate_all_f32x, predicate_all_f32x, zn, zm0);
|
|
457
|
+
svmopa_za32_bf16_m(3, predicate_all_f32x, predicate_all_f32x, zn, zm1);
|
|
458
|
+
}
|
|
459
|
+
|
|
460
|
+
// Pass 1: Column-wise max (read ZA2/ZA3 columns vertically)
|
|
461
|
+
svfloat32_t scale_f32x = svdup_f32(scale);
|
|
462
|
+
svfloat32_t block_max_f32x = svdup_f32(NK_F32_MIN);
|
|
463
|
+
for (nk_size_t column_index = 0; column_index < 16; column_index++) {
|
|
464
|
+
svfloat32_t score_column_f32x = svmul_f32_x(
|
|
465
|
+
predicate_all_f32x, svread_ver_za32_f32_m(svdup_f32(0), predicate_all_f32x, 2, column_index),
|
|
466
|
+
scale_f32x);
|
|
467
|
+
block_max_f32x = svmax_f32_x(predicate_all_f32x, block_max_f32x, score_column_f32x);
|
|
468
|
+
}
|
|
469
|
+
for (nk_size_t column_index = 0; column_index < 16; column_index++) {
|
|
470
|
+
svfloat32_t score_column_f32x = svmul_f32_x(
|
|
471
|
+
predicate_all_f32x, svread_ver_za32_f32_m(svdup_f32(0), predicate_all_f32x, 3, column_index),
|
|
472
|
+
scale_f32x);
|
|
473
|
+
block_max_f32x = svmax_f32_x(predicate_all_f32x, block_max_f32x, score_column_f32x);
|
|
474
|
+
}
|
|
475
|
+
|
|
476
|
+
// Softmax correction (fully vectorized)
|
|
477
|
+
svfloat32_t new_max_f32x = svmax_f32_x(predicate_all_f32x, row_max_f32x, block_max_f32x);
|
|
478
|
+
svfloat32_t correction_f32x = nk_exp_fast_f32_sve_(
|
|
479
|
+
predicate_all_f32x, svsub_f32_x(predicate_all_f32x, row_max_f32x, new_max_f32x));
|
|
480
|
+
svbool_t max_changed = svcmplt_f32(predicate_all_f32x, correction_f32x, svdup_f32(1.0f));
|
|
481
|
+
nk_u32_t max_was_updated = svptest_any(predicate_all_f32x, max_changed) ? 1 : 0;
|
|
482
|
+
if (max_was_updated) row_sum_f32x = svmul_f32_x(predicate_all_f32x, row_sum_f32x, correction_f32x);
|
|
483
|
+
NK_ALIGN64 nk_f32_t corrections[16];
|
|
484
|
+
svst1_f32(predicate_all_f32x, corrections, correction_f32x);
|
|
485
|
+
|
|
486
|
+
// Pass 2: Column-wise exp + fused P write + sum
|
|
487
|
+
svfloat32_t sum_delta_f32x = svdup_f32(0.0f);
|
|
488
|
+
svzero_mask_za(nk_sme_zero_za32_tile_0_);
|
|
489
|
+
// ZA2 columns in pairs → ZA0 columns 0-7
|
|
490
|
+
for (nk_size_t column_index = 0; column_index < 16; column_index += 2) {
|
|
491
|
+
svfloat32_t score_even_f32x = svmul_f32_x(
|
|
492
|
+
predicate_all_f32x, svread_ver_za32_f32_m(svdup_f32(0), predicate_all_f32x, 2, column_index),
|
|
493
|
+
scale_f32x);
|
|
494
|
+
svfloat32_t score_odd_f32x = svmul_f32_x(
|
|
495
|
+
predicate_all_f32x, svread_ver_za32_f32_m(svdup_f32(0), predicate_all_f32x, 2, column_index + 1),
|
|
496
|
+
scale_f32x);
|
|
497
|
+
svfloat32_t weight_even_f32x = nk_exp_fast_f32_sve_(
|
|
498
|
+
predicate_all_f32x, svsub_f32_x(predicate_all_f32x, score_even_f32x, new_max_f32x));
|
|
499
|
+
svfloat32_t weight_odd_f32x = nk_exp_fast_f32_sve_(
|
|
500
|
+
predicate_all_f32x, svsub_f32_x(predicate_all_f32x, score_odd_f32x, new_max_f32x));
|
|
501
|
+
sum_delta_f32x = svadd_f32_x(predicate_all_f32x, sum_delta_f32x, weight_even_f32x);
|
|
502
|
+
sum_delta_f32x = svadd_f32_x(predicate_all_f32x, sum_delta_f32x, weight_odd_f32x);
|
|
503
|
+
svbfloat16_t weight_pair_bf16 = svzip1_bf16(nk_f32_to_bf16_sve_(predicate_all_f32x, weight_even_f32x),
|
|
504
|
+
nk_f32_to_bf16_sve_(predicate_all_f32x, weight_odd_f32x));
|
|
505
|
+
svwrite_ver_za32_f32_m(0, column_index / 2, predicate_all_f32x,
|
|
506
|
+
svreinterpret_f32_bf16(weight_pair_bf16));
|
|
507
|
+
}
|
|
508
|
+
// ZA3 columns in pairs → ZA0 columns 8-15
|
|
509
|
+
for (nk_size_t column_index = 0; column_index < 16; column_index += 2) {
|
|
510
|
+
svfloat32_t score_even_f32x = svmul_f32_x(
|
|
511
|
+
predicate_all_f32x, svread_ver_za32_f32_m(svdup_f32(0), predicate_all_f32x, 3, column_index),
|
|
512
|
+
scale_f32x);
|
|
513
|
+
svfloat32_t score_odd_f32x = svmul_f32_x(
|
|
514
|
+
predicate_all_f32x, svread_ver_za32_f32_m(svdup_f32(0), predicate_all_f32x, 3, column_index + 1),
|
|
515
|
+
scale_f32x);
|
|
516
|
+
svfloat32_t weight_even_f32x = nk_exp_fast_f32_sve_(
|
|
517
|
+
predicate_all_f32x, svsub_f32_x(predicate_all_f32x, score_even_f32x, new_max_f32x));
|
|
518
|
+
svfloat32_t weight_odd_f32x = nk_exp_fast_f32_sve_(
|
|
519
|
+
predicate_all_f32x, svsub_f32_x(predicate_all_f32x, score_odd_f32x, new_max_f32x));
|
|
520
|
+
sum_delta_f32x = svadd_f32_x(predicate_all_f32x, sum_delta_f32x, weight_even_f32x);
|
|
521
|
+
sum_delta_f32x = svadd_f32_x(predicate_all_f32x, sum_delta_f32x, weight_odd_f32x);
|
|
522
|
+
svbfloat16_t weight_pair_bf16 = svzip1_bf16(nk_f32_to_bf16_sve_(predicate_all_f32x, weight_even_f32x),
|
|
523
|
+
nk_f32_to_bf16_sve_(predicate_all_f32x, weight_odd_f32x));
|
|
524
|
+
svwrite_ver_za32_f32_m(0, 8 + column_index / 2, predicate_all_f32x,
|
|
525
|
+
svreinterpret_f32_bf16(weight_pair_bf16));
|
|
526
|
+
}
|
|
527
|
+
row_sum_f32x = svadd_f32_x(predicate_all_f32x, row_sum_f32x, sum_delta_f32x);
|
|
528
|
+
row_max_f32x = new_max_f32x;
|
|
529
|
+
|
|
530
|
+
// Extract P columns from ZA0
|
|
531
|
+
svbfloat16_t probability_column_0_f32x = svreinterpret_bf16_f32(
|
|
532
|
+
svread_ver_za32_f32_m(svdup_f32(0), predicate_all_f32x, 0, 0));
|
|
533
|
+
svbfloat16_t probability_column_1_f32x = svreinterpret_bf16_f32(
|
|
534
|
+
svread_ver_za32_f32_m(svdup_f32(0), predicate_all_f32x, 0, 1));
|
|
535
|
+
svbfloat16_t probability_column_2_f32x = svreinterpret_bf16_f32(
|
|
536
|
+
svread_ver_za32_f32_m(svdup_f32(0), predicate_all_f32x, 0, 2));
|
|
537
|
+
svbfloat16_t probability_column_3_f32x = svreinterpret_bf16_f32(
|
|
538
|
+
svread_ver_za32_f32_m(svdup_f32(0), predicate_all_f32x, 0, 3));
|
|
539
|
+
svbfloat16_t probability_column_4_f32x = svreinterpret_bf16_f32(
|
|
540
|
+
svread_ver_za32_f32_m(svdup_f32(0), predicate_all_f32x, 0, 4));
|
|
541
|
+
svbfloat16_t probability_column_5_f32x = svreinterpret_bf16_f32(
|
|
542
|
+
svread_ver_za32_f32_m(svdup_f32(0), predicate_all_f32x, 0, 5));
|
|
543
|
+
svbfloat16_t probability_column_6_f32x = svreinterpret_bf16_f32(
|
|
544
|
+
svread_ver_za32_f32_m(svdup_f32(0), predicate_all_f32x, 0, 6));
|
|
545
|
+
svbfloat16_t probability_column_7_f32x = svreinterpret_bf16_f32(
|
|
546
|
+
svread_ver_za32_f32_m(svdup_f32(0), predicate_all_f32x, 0, 7));
|
|
547
|
+
svbfloat16_t probability_column_8_f32x = svreinterpret_bf16_f32(
|
|
548
|
+
svread_ver_za32_f32_m(svdup_f32(0), predicate_all_f32x, 0, 8));
|
|
549
|
+
svbfloat16_t probability_column_9_f32x = svreinterpret_bf16_f32(
|
|
550
|
+
svread_ver_za32_f32_m(svdup_f32(0), predicate_all_f32x, 0, 9));
|
|
551
|
+
svbfloat16_t probability_column_10_f32x = svreinterpret_bf16_f32(
|
|
552
|
+
svread_ver_za32_f32_m(svdup_f32(0), predicate_all_f32x, 0, 10));
|
|
553
|
+
svbfloat16_t probability_column_11_f32x = svreinterpret_bf16_f32(
|
|
554
|
+
svread_ver_za32_f32_m(svdup_f32(0), predicate_all_f32x, 0, 11));
|
|
555
|
+
svbfloat16_t probability_column_12_f32x = svreinterpret_bf16_f32(
|
|
556
|
+
svread_ver_za32_f32_m(svdup_f32(0), predicate_all_f32x, 0, 12));
|
|
557
|
+
svbfloat16_t probability_column_13_f32x = svreinterpret_bf16_f32(
|
|
558
|
+
svread_ver_za32_f32_m(svdup_f32(0), predicate_all_f32x, 0, 13));
|
|
559
|
+
svbfloat16_t probability_column_14_f32x = svreinterpret_bf16_f32(
|
|
560
|
+
svread_ver_za32_f32_m(svdup_f32(0), predicate_all_f32x, 0, 14));
|
|
561
|
+
svbfloat16_t probability_column_15_f32x = svreinterpret_bf16_f32(
|
|
562
|
+
svread_ver_za32_f32_m(svdup_f32(0), predicate_all_f32x, 0, 15));
|
|
563
|
+
|
|
564
|
+
// Pre-apply correction once before P×V
|
|
565
|
+
svbool_t query_predicate_f16x = svwhilelt_b16_u64(0u, valid_query_count * 2);
|
|
566
|
+
nk_bf16_t const *values_block_lower = v_packed + kv_block_index * dim_tile_count * 8 * 32;
|
|
567
|
+
nk_bf16_t const *values_block_upper = v_packed + (kv_block_index + 1) * dim_tile_count * 8 * 32;
|
|
568
|
+
|
|
569
|
+
if (max_was_updated) {
|
|
570
|
+
for (nk_size_t query_index = 0; query_index < valid_query_count; query_index++) {
|
|
571
|
+
svfloat32_t correction_scalar_f32x = svdup_f32(corrections[query_index]);
|
|
572
|
+
for (nk_size_t dim_offset = 0; dim_offset < head_dim_padded; dim_offset += 16)
|
|
573
|
+
svst1_f32(
|
|
574
|
+
predicate_all_f32x, output_accumulator + query_index * head_dim_padded + dim_offset,
|
|
575
|
+
svmul_f32_x(predicate_all_f32x,
|
|
576
|
+
svld1_f32(predicate_all_f32x,
|
|
577
|
+
output_accumulator + query_index * head_dim_padded + dim_offset),
|
|
578
|
+
correction_scalar_f32x));
|
|
579
|
+
}
|
|
580
|
+
}
|
|
581
|
+
|
|
582
|
+
// P×V: zero → BFMOPA → read → add (no ZA writes for output_accumulator)
|
|
583
|
+
nk_size_t dim_tile = 0;
|
|
584
|
+
for (; dim_tile + 4 <= dim_tile_count; dim_tile += 4) {
|
|
585
|
+
svzero_za();
|
|
586
|
+
// Block0: 8 depth steps (KV positions 0-15)
|
|
587
|
+
svmopa_za32_bf16_m(0, query_predicate_f16x, predicate_all_f16x, probability_column_0_f32x,
|
|
588
|
+
svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(values_block_lower +
|
|
589
|
+
((dim_tile + 0) * 8 + 0) * 32)));
|
|
590
|
+
svmopa_za32_bf16_m(1, query_predicate_f16x, predicate_all_f16x, probability_column_0_f32x,
|
|
591
|
+
svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(values_block_lower +
|
|
592
|
+
((dim_tile + 1) * 8 + 0) * 32)));
|
|
593
|
+
svmopa_za32_bf16_m(2, query_predicate_f16x, predicate_all_f16x, probability_column_0_f32x,
|
|
594
|
+
svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(values_block_lower +
|
|
595
|
+
((dim_tile + 2) * 8 + 0) * 32)));
|
|
596
|
+
svmopa_za32_bf16_m(3, query_predicate_f16x, predicate_all_f16x, probability_column_0_f32x,
|
|
597
|
+
svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(values_block_lower +
|
|
598
|
+
((dim_tile + 3) * 8 + 0) * 32)));
|
|
599
|
+
svmopa_za32_bf16_m(0, query_predicate_f16x, predicate_all_f16x, probability_column_1_f32x,
|
|
600
|
+
svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(values_block_lower +
|
|
601
|
+
((dim_tile + 0) * 8 + 1) * 32)));
|
|
602
|
+
svmopa_za32_bf16_m(1, query_predicate_f16x, predicate_all_f16x, probability_column_1_f32x,
|
|
603
|
+
svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(values_block_lower +
|
|
604
|
+
((dim_tile + 1) * 8 + 1) * 32)));
|
|
605
|
+
svmopa_za32_bf16_m(2, query_predicate_f16x, predicate_all_f16x, probability_column_1_f32x,
|
|
606
|
+
svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(values_block_lower +
|
|
607
|
+
((dim_tile + 2) * 8 + 1) * 32)));
|
|
608
|
+
svmopa_za32_bf16_m(3, query_predicate_f16x, predicate_all_f16x, probability_column_1_f32x,
|
|
609
|
+
svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(values_block_lower +
|
|
610
|
+
((dim_tile + 3) * 8 + 1) * 32)));
|
|
611
|
+
svmopa_za32_bf16_m(0, query_predicate_f16x, predicate_all_f16x, probability_column_2_f32x,
|
|
612
|
+
svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(values_block_lower +
|
|
613
|
+
((dim_tile + 0) * 8 + 2) * 32)));
|
|
614
|
+
svmopa_za32_bf16_m(1, query_predicate_f16x, predicate_all_f16x, probability_column_2_f32x,
|
|
615
|
+
svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(values_block_lower +
|
|
616
|
+
((dim_tile + 1) * 8 + 2) * 32)));
|
|
617
|
+
svmopa_za32_bf16_m(2, query_predicate_f16x, predicate_all_f16x, probability_column_2_f32x,
|
|
618
|
+
svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(values_block_lower +
|
|
619
|
+
((dim_tile + 2) * 8 + 2) * 32)));
|
|
620
|
+
svmopa_za32_bf16_m(3, query_predicate_f16x, predicate_all_f16x, probability_column_2_f32x,
|
|
621
|
+
svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(values_block_lower +
|
|
622
|
+
((dim_tile + 3) * 8 + 2) * 32)));
|
|
623
|
+
svmopa_za32_bf16_m(0, query_predicate_f16x, predicate_all_f16x, probability_column_3_f32x,
|
|
624
|
+
svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(values_block_lower +
|
|
625
|
+
((dim_tile + 0) * 8 + 3) * 32)));
|
|
626
|
+
svmopa_za32_bf16_m(1, query_predicate_f16x, predicate_all_f16x, probability_column_3_f32x,
|
|
627
|
+
svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(values_block_lower +
|
|
628
|
+
((dim_tile + 1) * 8 + 3) * 32)));
|
|
629
|
+
svmopa_za32_bf16_m(2, query_predicate_f16x, predicate_all_f16x, probability_column_3_f32x,
|
|
630
|
+
svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(values_block_lower +
|
|
631
|
+
((dim_tile + 2) * 8 + 3) * 32)));
|
|
632
|
+
svmopa_za32_bf16_m(3, query_predicate_f16x, predicate_all_f16x, probability_column_3_f32x,
|
|
633
|
+
svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(values_block_lower +
|
|
634
|
+
((dim_tile + 3) * 8 + 3) * 32)));
|
|
635
|
+
svmopa_za32_bf16_m(0, query_predicate_f16x, predicate_all_f16x, probability_column_4_f32x,
|
|
636
|
+
svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(values_block_lower +
|
|
637
|
+
((dim_tile + 0) * 8 + 4) * 32)));
|
|
638
|
+
svmopa_za32_bf16_m(1, query_predicate_f16x, predicate_all_f16x, probability_column_4_f32x,
|
|
639
|
+
svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(values_block_lower +
|
|
640
|
+
((dim_tile + 1) * 8 + 4) * 32)));
|
|
641
|
+
svmopa_za32_bf16_m(2, query_predicate_f16x, predicate_all_f16x, probability_column_4_f32x,
|
|
642
|
+
svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(values_block_lower +
|
|
643
|
+
((dim_tile + 2) * 8 + 4) * 32)));
|
|
644
|
+
svmopa_za32_bf16_m(3, query_predicate_f16x, predicate_all_f16x, probability_column_4_f32x,
|
|
645
|
+
svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(values_block_lower +
|
|
646
|
+
((dim_tile + 3) * 8 + 4) * 32)));
|
|
647
|
+
svmopa_za32_bf16_m(0, query_predicate_f16x, predicate_all_f16x, probability_column_5_f32x,
|
|
648
|
+
svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(values_block_lower +
|
|
649
|
+
((dim_tile + 0) * 8 + 5) * 32)));
|
|
650
|
+
svmopa_za32_bf16_m(1, query_predicate_f16x, predicate_all_f16x, probability_column_5_f32x,
|
|
651
|
+
svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(values_block_lower +
|
|
652
|
+
((dim_tile + 1) * 8 + 5) * 32)));
|
|
653
|
+
svmopa_za32_bf16_m(2, query_predicate_f16x, predicate_all_f16x, probability_column_5_f32x,
|
|
654
|
+
svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(values_block_lower +
|
|
655
|
+
((dim_tile + 2) * 8 + 5) * 32)));
|
|
656
|
+
svmopa_za32_bf16_m(3, query_predicate_f16x, predicate_all_f16x, probability_column_5_f32x,
|
|
657
|
+
svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(values_block_lower +
|
|
658
|
+
((dim_tile + 3) * 8 + 5) * 32)));
|
|
659
|
+
svmopa_za32_bf16_m(0, query_predicate_f16x, predicate_all_f16x, probability_column_6_f32x,
|
|
660
|
+
svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(values_block_lower +
|
|
661
|
+
((dim_tile + 0) * 8 + 6) * 32)));
|
|
662
|
+
svmopa_za32_bf16_m(1, query_predicate_f16x, predicate_all_f16x, probability_column_6_f32x,
|
|
663
|
+
svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(values_block_lower +
|
|
664
|
+
((dim_tile + 1) * 8 + 6) * 32)));
|
|
665
|
+
svmopa_za32_bf16_m(2, query_predicate_f16x, predicate_all_f16x, probability_column_6_f32x,
|
|
666
|
+
svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(values_block_lower +
|
|
667
|
+
((dim_tile + 2) * 8 + 6) * 32)));
|
|
668
|
+
svmopa_za32_bf16_m(3, query_predicate_f16x, predicate_all_f16x, probability_column_6_f32x,
|
|
669
|
+
svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(values_block_lower +
|
|
670
|
+
((dim_tile + 3) * 8 + 6) * 32)));
|
|
671
|
+
svmopa_za32_bf16_m(0, query_predicate_f16x, predicate_all_f16x, probability_column_7_f32x,
|
|
672
|
+
svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(values_block_lower +
|
|
673
|
+
((dim_tile + 0) * 8 + 7) * 32)));
|
|
674
|
+
svmopa_za32_bf16_m(1, query_predicate_f16x, predicate_all_f16x, probability_column_7_f32x,
|
|
675
|
+
svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(values_block_lower +
|
|
676
|
+
((dim_tile + 1) * 8 + 7) * 32)));
|
|
677
|
+
svmopa_za32_bf16_m(2, query_predicate_f16x, predicate_all_f16x, probability_column_7_f32x,
|
|
678
|
+
svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(values_block_lower +
|
|
679
|
+
((dim_tile + 2) * 8 + 7) * 32)));
|
|
680
|
+
svmopa_za32_bf16_m(3, query_predicate_f16x, predicate_all_f16x, probability_column_7_f32x,
|
|
681
|
+
svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(values_block_lower +
|
|
682
|
+
((dim_tile + 3) * 8 + 7) * 32)));
|
|
683
|
+
// Block1: 8 depth steps (KV positions 16-31)
|
|
684
|
+
svmopa_za32_bf16_m(0, query_predicate_f16x, predicate_all_f16x, probability_column_8_f32x,
|
|
685
|
+
svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(values_block_upper +
|
|
686
|
+
((dim_tile + 0) * 8 + 0) * 32)));
|
|
687
|
+
svmopa_za32_bf16_m(1, query_predicate_f16x, predicate_all_f16x, probability_column_8_f32x,
|
|
688
|
+
svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(values_block_upper +
|
|
689
|
+
((dim_tile + 1) * 8 + 0) * 32)));
|
|
690
|
+
svmopa_za32_bf16_m(2, query_predicate_f16x, predicate_all_f16x, probability_column_8_f32x,
|
|
691
|
+
svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(values_block_upper +
|
|
692
|
+
((dim_tile + 2) * 8 + 0) * 32)));
|
|
693
|
+
svmopa_za32_bf16_m(3, query_predicate_f16x, predicate_all_f16x, probability_column_8_f32x,
|
|
694
|
+
svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(values_block_upper +
|
|
695
|
+
((dim_tile + 3) * 8 + 0) * 32)));
|
|
696
|
+
svmopa_za32_bf16_m(0, query_predicate_f16x, predicate_all_f16x, probability_column_9_f32x,
|
|
697
|
+
svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(values_block_upper +
|
|
698
|
+
((dim_tile + 0) * 8 + 1) * 32)));
|
|
699
|
+
svmopa_za32_bf16_m(1, query_predicate_f16x, predicate_all_f16x, probability_column_9_f32x,
|
|
700
|
+
svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(values_block_upper +
|
|
701
|
+
((dim_tile + 1) * 8 + 1) * 32)));
|
|
702
|
+
svmopa_za32_bf16_m(2, query_predicate_f16x, predicate_all_f16x, probability_column_9_f32x,
|
|
703
|
+
svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(values_block_upper +
|
|
704
|
+
((dim_tile + 2) * 8 + 1) * 32)));
|
|
705
|
+
svmopa_za32_bf16_m(3, query_predicate_f16x, predicate_all_f16x, probability_column_9_f32x,
|
|
706
|
+
svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(values_block_upper +
|
|
707
|
+
((dim_tile + 3) * 8 + 1) * 32)));
|
|
708
|
+
svmopa_za32_bf16_m(0, query_predicate_f16x, predicate_all_f16x, probability_column_10_f32x,
|
|
709
|
+
svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(values_block_upper +
|
|
710
|
+
((dim_tile + 0) * 8 + 2) * 32)));
|
|
711
|
+
svmopa_za32_bf16_m(1, query_predicate_f16x, predicate_all_f16x, probability_column_10_f32x,
|
|
712
|
+
svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(values_block_upper +
|
|
713
|
+
((dim_tile + 1) * 8 + 2) * 32)));
|
|
714
|
+
svmopa_za32_bf16_m(2, query_predicate_f16x, predicate_all_f16x, probability_column_10_f32x,
|
|
715
|
+
svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(values_block_upper +
|
|
716
|
+
((dim_tile + 2) * 8 + 2) * 32)));
|
|
717
|
+
svmopa_za32_bf16_m(3, query_predicate_f16x, predicate_all_f16x, probability_column_10_f32x,
|
|
718
|
+
svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(values_block_upper +
|
|
719
|
+
((dim_tile + 3) * 8 + 2) * 32)));
|
|
720
|
+
svmopa_za32_bf16_m(0, query_predicate_f16x, predicate_all_f16x, probability_column_11_f32x,
|
|
721
|
+
svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(values_block_upper +
|
|
722
|
+
((dim_tile + 0) * 8 + 3) * 32)));
|
|
723
|
+
svmopa_za32_bf16_m(1, query_predicate_f16x, predicate_all_f16x, probability_column_11_f32x,
|
|
724
|
+
svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(values_block_upper +
|
|
725
|
+
((dim_tile + 1) * 8 + 3) * 32)));
|
|
726
|
+
svmopa_za32_bf16_m(2, query_predicate_f16x, predicate_all_f16x, probability_column_11_f32x,
|
|
727
|
+
svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(values_block_upper +
|
|
728
|
+
((dim_tile + 2) * 8 + 3) * 32)));
|
|
729
|
+
svmopa_za32_bf16_m(3, query_predicate_f16x, predicate_all_f16x, probability_column_11_f32x,
|
|
730
|
+
svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(values_block_upper +
|
|
731
|
+
((dim_tile + 3) * 8 + 3) * 32)));
|
|
732
|
+
svmopa_za32_bf16_m(0, query_predicate_f16x, predicate_all_f16x, probability_column_12_f32x,
|
|
733
|
+
svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(values_block_upper +
|
|
734
|
+
((dim_tile + 0) * 8 + 4) * 32)));
|
|
735
|
+
svmopa_za32_bf16_m(1, query_predicate_f16x, predicate_all_f16x, probability_column_12_f32x,
|
|
736
|
+
svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(values_block_upper +
|
|
737
|
+
((dim_tile + 1) * 8 + 4) * 32)));
|
|
738
|
+
svmopa_za32_bf16_m(2, query_predicate_f16x, predicate_all_f16x, probability_column_12_f32x,
|
|
739
|
+
svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(values_block_upper +
|
|
740
|
+
((dim_tile + 2) * 8 + 4) * 32)));
|
|
741
|
+
svmopa_za32_bf16_m(3, query_predicate_f16x, predicate_all_f16x, probability_column_12_f32x,
|
|
742
|
+
svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(values_block_upper +
|
|
743
|
+
((dim_tile + 3) * 8 + 4) * 32)));
|
|
744
|
+
svmopa_za32_bf16_m(0, query_predicate_f16x, predicate_all_f16x, probability_column_13_f32x,
|
|
745
|
+
svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(values_block_upper +
|
|
746
|
+
((dim_tile + 0) * 8 + 5) * 32)));
|
|
747
|
+
svmopa_za32_bf16_m(1, query_predicate_f16x, predicate_all_f16x, probability_column_13_f32x,
|
|
748
|
+
svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(values_block_upper +
|
|
749
|
+
((dim_tile + 1) * 8 + 5) * 32)));
|
|
750
|
+
svmopa_za32_bf16_m(2, query_predicate_f16x, predicate_all_f16x, probability_column_13_f32x,
|
|
751
|
+
svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(values_block_upper +
|
|
752
|
+
((dim_tile + 2) * 8 + 5) * 32)));
|
|
753
|
+
svmopa_za32_bf16_m(3, query_predicate_f16x, predicate_all_f16x, probability_column_13_f32x,
|
|
754
|
+
svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(values_block_upper +
|
|
755
|
+
((dim_tile + 3) * 8 + 5) * 32)));
|
|
756
|
+
svmopa_za32_bf16_m(0, query_predicate_f16x, predicate_all_f16x, probability_column_14_f32x,
|
|
757
|
+
svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(values_block_upper +
|
|
758
|
+
((dim_tile + 0) * 8 + 6) * 32)));
|
|
759
|
+
svmopa_za32_bf16_m(1, query_predicate_f16x, predicate_all_f16x, probability_column_14_f32x,
|
|
760
|
+
svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(values_block_upper +
|
|
761
|
+
((dim_tile + 1) * 8 + 6) * 32)));
|
|
762
|
+
svmopa_za32_bf16_m(2, query_predicate_f16x, predicate_all_f16x, probability_column_14_f32x,
|
|
763
|
+
svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(values_block_upper +
|
|
764
|
+
((dim_tile + 2) * 8 + 6) * 32)));
|
|
765
|
+
svmopa_za32_bf16_m(3, query_predicate_f16x, predicate_all_f16x, probability_column_14_f32x,
|
|
766
|
+
svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(values_block_upper +
|
|
767
|
+
((dim_tile + 3) * 8 + 6) * 32)));
|
|
768
|
+
svmopa_za32_bf16_m(0, query_predicate_f16x, predicate_all_f16x, probability_column_15_f32x,
|
|
769
|
+
svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(values_block_upper +
|
|
770
|
+
((dim_tile + 0) * 8 + 7) * 32)));
|
|
771
|
+
svmopa_za32_bf16_m(1, query_predicate_f16x, predicate_all_f16x, probability_column_15_f32x,
|
|
772
|
+
svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(values_block_upper +
|
|
773
|
+
((dim_tile + 1) * 8 + 7) * 32)));
|
|
774
|
+
svmopa_za32_bf16_m(2, query_predicate_f16x, predicate_all_f16x, probability_column_15_f32x,
|
|
775
|
+
svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(values_block_upper +
|
|
776
|
+
((dim_tile + 2) * 8 + 7) * 32)));
|
|
777
|
+
svmopa_za32_bf16_m(3, query_predicate_f16x, predicate_all_f16x, probability_column_15_f32x,
|
|
778
|
+
svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(values_block_upper +
|
|
779
|
+
((dim_tile + 3) * 8 + 7) * 32)));
|
|
780
|
+
// Read BFMOPA result and ADD to output_accumulator
|
|
781
|
+
for (nk_size_t query_index = 0; query_index < valid_query_count; query_index++) {
|
|
782
|
+
svst1_f32(
|
|
783
|
+
predicate_all_f32x, output_accumulator + query_index * head_dim_padded + (dim_tile + 0) * 16,
|
|
784
|
+
svadd_f32_x(predicate_all_f32x,
|
|
785
|
+
svld1_f32(predicate_all_f32x,
|
|
786
|
+
output_accumulator + query_index * head_dim_padded + (dim_tile + 0) * 16),
|
|
787
|
+
svread_hor_za32_f32_m(svdup_f32(0), predicate_all_f32x, 0, query_index)));
|
|
788
|
+
svst1_f32(
|
|
789
|
+
predicate_all_f32x, output_accumulator + query_index * head_dim_padded + (dim_tile + 1) * 16,
|
|
790
|
+
svadd_f32_x(predicate_all_f32x,
|
|
791
|
+
svld1_f32(predicate_all_f32x,
|
|
792
|
+
output_accumulator + query_index * head_dim_padded + (dim_tile + 1) * 16),
|
|
793
|
+
svread_hor_za32_f32_m(svdup_f32(0), predicate_all_f32x, 1, query_index)));
|
|
794
|
+
svst1_f32(
|
|
795
|
+
predicate_all_f32x, output_accumulator + query_index * head_dim_padded + (dim_tile + 2) * 16,
|
|
796
|
+
svadd_f32_x(predicate_all_f32x,
|
|
797
|
+
svld1_f32(predicate_all_f32x,
|
|
798
|
+
output_accumulator + query_index * head_dim_padded + (dim_tile + 2) * 16),
|
|
799
|
+
svread_hor_za32_f32_m(svdup_f32(0), predicate_all_f32x, 2, query_index)));
|
|
800
|
+
svst1_f32(
|
|
801
|
+
predicate_all_f32x, output_accumulator + query_index * head_dim_padded + (dim_tile + 3) * 16,
|
|
802
|
+
svadd_f32_x(predicate_all_f32x,
|
|
803
|
+
svld1_f32(predicate_all_f32x,
|
|
804
|
+
output_accumulator + query_index * head_dim_padded + (dim_tile + 3) * 16),
|
|
805
|
+
svread_hor_za32_f32_m(svdup_f32(0), predicate_all_f32x, 3, query_index)));
|
|
806
|
+
}
|
|
807
|
+
}
|
|
808
|
+
// Remainder: 1 dim_tile at a time using ZA0
|
|
809
|
+
for (; dim_tile < dim_tile_count; dim_tile++) {
|
|
810
|
+
svzero_mask_za(nk_sme_zero_za32_tile_0_);
|
|
811
|
+
svmopa_za32_bf16_m(
|
|
812
|
+
0, query_predicate_f16x, predicate_all_f16x, probability_column_0_f32x,
|
|
813
|
+
svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(values_block_lower + (dim_tile * 8 + 0) * 32)));
|
|
814
|
+
svmopa_za32_bf16_m(
|
|
815
|
+
0, query_predicate_f16x, predicate_all_f16x, probability_column_1_f32x,
|
|
816
|
+
svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(values_block_lower + (dim_tile * 8 + 1) * 32)));
|
|
817
|
+
svmopa_za32_bf16_m(
|
|
818
|
+
0, query_predicate_f16x, predicate_all_f16x, probability_column_2_f32x,
|
|
819
|
+
svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(values_block_lower + (dim_tile * 8 + 2) * 32)));
|
|
820
|
+
svmopa_za32_bf16_m(
|
|
821
|
+
0, query_predicate_f16x, predicate_all_f16x, probability_column_3_f32x,
|
|
822
|
+
svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(values_block_lower + (dim_tile * 8 + 3) * 32)));
|
|
823
|
+
svmopa_za32_bf16_m(
|
|
824
|
+
0, query_predicate_f16x, predicate_all_f16x, probability_column_4_f32x,
|
|
825
|
+
svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(values_block_lower + (dim_tile * 8 + 4) * 32)));
|
|
826
|
+
svmopa_za32_bf16_m(
|
|
827
|
+
0, query_predicate_f16x, predicate_all_f16x, probability_column_5_f32x,
|
|
828
|
+
svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(values_block_lower + (dim_tile * 8 + 5) * 32)));
|
|
829
|
+
svmopa_za32_bf16_m(
|
|
830
|
+
0, query_predicate_f16x, predicate_all_f16x, probability_column_6_f32x,
|
|
831
|
+
svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(values_block_lower + (dim_tile * 8 + 6) * 32)));
|
|
832
|
+
svmopa_za32_bf16_m(
|
|
833
|
+
0, query_predicate_f16x, predicate_all_f16x, probability_column_7_f32x,
|
|
834
|
+
svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(values_block_lower + (dim_tile * 8 + 7) * 32)));
|
|
835
|
+
svmopa_za32_bf16_m(
|
|
836
|
+
0, query_predicate_f16x, predicate_all_f16x, probability_column_8_f32x,
|
|
837
|
+
svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(values_block_upper + (dim_tile * 8 + 0) * 32)));
|
|
838
|
+
svmopa_za32_bf16_m(
|
|
839
|
+
0, query_predicate_f16x, predicate_all_f16x, probability_column_9_f32x,
|
|
840
|
+
svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(values_block_upper + (dim_tile * 8 + 1) * 32)));
|
|
841
|
+
svmopa_za32_bf16_m(
|
|
842
|
+
0, query_predicate_f16x, predicate_all_f16x, probability_column_10_f32x,
|
|
843
|
+
svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(values_block_upper + (dim_tile * 8 + 2) * 32)));
|
|
844
|
+
svmopa_za32_bf16_m(
|
|
845
|
+
0, query_predicate_f16x, predicate_all_f16x, probability_column_11_f32x,
|
|
846
|
+
svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(values_block_upper + (dim_tile * 8 + 3) * 32)));
|
|
847
|
+
svmopa_za32_bf16_m(
|
|
848
|
+
0, query_predicate_f16x, predicate_all_f16x, probability_column_12_f32x,
|
|
849
|
+
svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(values_block_upper + (dim_tile * 8 + 4) * 32)));
|
|
850
|
+
svmopa_za32_bf16_m(
|
|
851
|
+
0, query_predicate_f16x, predicate_all_f16x, probability_column_13_f32x,
|
|
852
|
+
svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(values_block_upper + (dim_tile * 8 + 5) * 32)));
|
|
853
|
+
svmopa_za32_bf16_m(
|
|
854
|
+
0, query_predicate_f16x, predicate_all_f16x, probability_column_14_f32x,
|
|
855
|
+
svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(values_block_upper + (dim_tile * 8 + 6) * 32)));
|
|
856
|
+
svmopa_za32_bf16_m(
|
|
857
|
+
0, query_predicate_f16x, predicate_all_f16x, probability_column_15_f32x,
|
|
858
|
+
svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(values_block_upper + (dim_tile * 8 + 7) * 32)));
|
|
859
|
+
for (nk_size_t query_index = 0; query_index < valid_query_count; query_index++)
|
|
860
|
+
svst1_f32(predicate_all_f32x, output_accumulator + query_index * head_dim_padded + dim_tile * 16,
|
|
861
|
+
svadd_f32_x(predicate_all_f32x,
|
|
862
|
+
svld1_f32(predicate_all_f32x,
|
|
863
|
+
output_accumulator + query_index * head_dim_padded + dim_tile * 16),
|
|
864
|
+
svread_hor_za32_f32_m(svdup_f32(0), predicate_all_f32x, 0, query_index)));
|
|
865
|
+
}
|
|
866
|
+
}
|
|
867
|
+
}
|
|
868
|
+
|
|
869
|
+
// Bc=16 tail loop (handles remaining KV positions and decode path)
|
|
870
|
+
for (; kv_start < kv_len; kv_start += 16, kv_block_index++) {
|
|
871
|
+
nk_size_t const valid_kv = ((kv_start + 16) <= kv_len) ? 16 : (kv_len - kv_start);
|
|
872
|
+
|
|
873
|
+
// Q×K^T: pure memory→BFMOPA, no ZA staging
|
|
874
|
+
svzero_mask_za(nk_sme_zero_za32_tile_2_);
|
|
875
|
+
nk_bf16_t const *k_block = k + kv_block_index * k_depth_step_count * 32;
|
|
876
|
+
for (nk_size_t step = 0; step < k_depth_step_count; step++) {
|
|
877
|
+
svbfloat16_t zn = svreinterpret_bf16_f32(svld1_f32(predicate_all_f32x, queries_transposed + step * 16));
|
|
878
|
+
svbfloat16_t zm = svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(k_block + step * 32));
|
|
879
|
+
svmopa_za32_bf16_m(2, predicate_all_f32x, predicate_all_f32x, zn, zm);
|
|
880
|
+
}
|
|
881
|
+
|
|
882
|
+
// Pass 1: Column-wise max (read ZA2 columns vertically)
|
|
883
|
+
svfloat32_t scale_16_f32x = svdup_f32(scale);
|
|
884
|
+
svfloat32_t block_max_16_f32x = svdup_f32(NK_F32_MIN);
|
|
885
|
+
for (nk_size_t column_index = 0; column_index < 16; column_index++) {
|
|
886
|
+
svfloat32_t score_column_f32x = svmul_f32_x(
|
|
887
|
+
predicate_all_f32x, svread_ver_za32_f32_m(svdup_f32(0), predicate_all_f32x, 2, column_index),
|
|
888
|
+
scale_16_f32x);
|
|
889
|
+
block_max_16_f32x = svmax_f32_x(predicate_all_f32x, block_max_16_f32x, score_column_f32x);
|
|
890
|
+
}
|
|
891
|
+
|
|
892
|
+
// Softmax correction (fully vectorized)
|
|
893
|
+
svfloat32_t new_max_f32x = svmax_f32_x(predicate_all_f32x, row_max_f32x, block_max_16_f32x);
|
|
894
|
+
svfloat32_t correction_f32x = nk_exp_fast_f32_sve_(predicate_all_f32x,
|
|
895
|
+
svsub_f32_x(predicate_all_f32x, row_max_f32x, new_max_f32x));
|
|
896
|
+
svbool_t max_changed_16 = svcmplt_f32(predicate_all_f32x, correction_f32x, svdup_f32(1.0f));
|
|
897
|
+
nk_u32_t max_was_updated_16 = svptest_any(predicate_all_f32x, max_changed_16) ? 1 : 0;
|
|
898
|
+
if (max_was_updated_16) row_sum_f32x = svmul_f32_x(predicate_all_f32x, row_sum_f32x, correction_f32x);
|
|
899
|
+
NK_ALIGN64 nk_f32_t corrections[16];
|
|
900
|
+
svst1_f32(predicate_all_f32x, corrections, correction_f32x);
|
|
901
|
+
|
|
902
|
+
// Pass 2: Column-wise exp + fused P write + sum (ZA2 → ZA0 columns 0-7)
|
|
903
|
+
svfloat32_t sum_delta_16_f32x = svdup_f32(0.0f);
|
|
904
|
+
svzero_mask_za(nk_sme_zero_za32_tile_0_);
|
|
905
|
+
for (nk_size_t column_index = 0; column_index < 16; column_index += 2) {
|
|
906
|
+
svfloat32_t score_even_f32x = svmul_f32_x(
|
|
907
|
+
predicate_all_f32x, svread_ver_za32_f32_m(svdup_f32(0), predicate_all_f32x, 2, column_index),
|
|
908
|
+
scale_16_f32x);
|
|
909
|
+
svfloat32_t score_odd_f32x = svmul_f32_x(
|
|
910
|
+
predicate_all_f32x, svread_ver_za32_f32_m(svdup_f32(0), predicate_all_f32x, 2, column_index + 1),
|
|
911
|
+
scale_16_f32x);
|
|
912
|
+
svfloat32_t weight_even_f32x = nk_exp_fast_f32_sve_(
|
|
913
|
+
predicate_all_f32x, svsub_f32_x(predicate_all_f32x, score_even_f32x, new_max_f32x));
|
|
914
|
+
svfloat32_t weight_odd_f32x = nk_exp_fast_f32_sve_(
|
|
915
|
+
predicate_all_f32x, svsub_f32_x(predicate_all_f32x, score_odd_f32x, new_max_f32x));
|
|
916
|
+
sum_delta_16_f32x = svadd_f32_x(predicate_all_f32x, sum_delta_16_f32x, weight_even_f32x);
|
|
917
|
+
sum_delta_16_f32x = svadd_f32_x(predicate_all_f32x, sum_delta_16_f32x, weight_odd_f32x);
|
|
918
|
+
svbfloat16_t weight_pair_bf16 = svzip1_bf16(nk_f32_to_bf16_sve_(predicate_all_f32x, weight_even_f32x),
|
|
919
|
+
nk_f32_to_bf16_sve_(predicate_all_f32x, weight_odd_f32x));
|
|
920
|
+
svwrite_ver_za32_f32_m(0, column_index / 2, predicate_all_f32x, svreinterpret_f32_bf16(weight_pair_bf16));
|
|
921
|
+
}
|
|
922
|
+
row_sum_f32x = svadd_f32_x(predicate_all_f32x, row_sum_f32x, sum_delta_16_f32x);
|
|
923
|
+
row_max_f32x = new_max_f32x;
|
|
924
|
+
|
|
925
|
+
if (valid_query_count == 1) {
|
|
926
|
+
// Decode path: extract f32 weights from ZA0 row 0 using SVE
|
|
927
|
+
svbfloat16_t row0_bf16 = svreinterpret_bf16_f32(
|
|
928
|
+
svread_hor_za32_f32_m(svdup_f32(0), predicate_all_f32x, 0, 0));
|
|
929
|
+
svbfloat16_t weights_even_bf16 = svuzp1_bf16(row0_bf16, row0_bf16);
|
|
930
|
+
svbfloat16_t weights_odd_bf16 = svuzp2_bf16(row0_bf16, row0_bf16);
|
|
931
|
+
NK_ALIGN64 nk_f32_t decode_weights[16];
|
|
932
|
+
svst1_f32(svwhilelt_b32(0u, 8u), decode_weights,
|
|
933
|
+
nk_bf16_to_f32_sve_(svwhilelt_b32(0u, 8u), weights_even_bf16));
|
|
934
|
+
svst1_f32(svwhilelt_b32(0u, 8u), decode_weights + 8,
|
|
935
|
+
nk_bf16_to_f32_sve_(svwhilelt_b32(0u, 8u), weights_odd_bf16));
|
|
936
|
+
NK_ALIGN64 nk_f32_t decode_weights_ordered[16];
|
|
937
|
+
for (nk_size_t i = 0; i < 8; i++) {
|
|
938
|
+
decode_weights_ordered[2 * i] = decode_weights[i];
|
|
939
|
+
decode_weights_ordered[2 * i + 1] = decode_weights[8 + i];
|
|
940
|
+
}
|
|
941
|
+
svfloat32_t corr_f32x = svdup_f32(corrections[0]);
|
|
942
|
+
for (nk_size_t d = 0; d < head_dim; d += svcntw()) {
|
|
943
|
+
svbool_t predicate_f32x = svwhilelt_b32_u64(d, head_dim);
|
|
944
|
+
svfloat32_t acc_f32x = svmul_f32_x(predicate_f32x, svld1_f32(predicate_f32x, output_accumulator + d),
|
|
945
|
+
corr_f32x);
|
|
946
|
+
for (nk_size_t ki = 0; ki < valid_kv; ki++) {
|
|
947
|
+
nk_size_t dim_tile = d / 16, depth_s = ki / 2, sub = ki % 2;
|
|
948
|
+
nk_bf16_t const *v_vec = v_packed +
|
|
949
|
+
(kv_block_index * dim_tile_count * 8 + dim_tile * 8 + depth_s) * 32;
|
|
950
|
+
svbfloat16_t packed_bf16x = svld1_bf16(predicate_all_f16x, (bfloat16_t const *)v_vec);
|
|
951
|
+
svbfloat16_t v_selected = (sub == 0) ? svuzp1_bf16(packed_bf16x, packed_bf16x)
|
|
952
|
+
: svuzp2_bf16(packed_bf16x, packed_bf16x);
|
|
953
|
+
acc_f32x = svmla_f32_x(predicate_f32x, acc_f32x, svdup_f32(decode_weights_ordered[ki]),
|
|
954
|
+
nk_bf16_to_f32_sve_(predicate_f32x, v_selected));
|
|
955
|
+
}
|
|
956
|
+
svst1_f32(predicate_f32x, output_accumulator + d, acc_f32x);
|
|
957
|
+
}
|
|
958
|
+
}
|
|
959
|
+
else {
|
|
960
|
+
// Prefill Bc=16: extract P columns, pre-apply correction, add-after P×V
|
|
961
|
+
svbool_t query_predicate_f16x = svwhilelt_b16_u64(0u, valid_query_count * 2);
|
|
962
|
+
|
|
963
|
+
svbfloat16_t probability_column_0_f32x = svreinterpret_bf16_f32(
|
|
964
|
+
svread_ver_za32_f32_m(svdup_f32(0), predicate_all_f32x, 0, 0));
|
|
965
|
+
svbfloat16_t probability_column_1_f32x = svreinterpret_bf16_f32(
|
|
966
|
+
svread_ver_za32_f32_m(svdup_f32(0), predicate_all_f32x, 0, 1));
|
|
967
|
+
svbfloat16_t probability_column_2_f32x = svreinterpret_bf16_f32(
|
|
968
|
+
svread_ver_za32_f32_m(svdup_f32(0), predicate_all_f32x, 0, 2));
|
|
969
|
+
svbfloat16_t probability_column_3_f32x = svreinterpret_bf16_f32(
|
|
970
|
+
svread_ver_za32_f32_m(svdup_f32(0), predicate_all_f32x, 0, 3));
|
|
971
|
+
svbfloat16_t probability_column_4_f32x = svreinterpret_bf16_f32(
|
|
972
|
+
svread_ver_za32_f32_m(svdup_f32(0), predicate_all_f32x, 0, 4));
|
|
973
|
+
svbfloat16_t probability_column_5_f32x = svreinterpret_bf16_f32(
|
|
974
|
+
svread_ver_za32_f32_m(svdup_f32(0), predicate_all_f32x, 0, 5));
|
|
975
|
+
svbfloat16_t probability_column_6_f32x = svreinterpret_bf16_f32(
|
|
976
|
+
svread_ver_za32_f32_m(svdup_f32(0), predicate_all_f32x, 0, 6));
|
|
977
|
+
svbfloat16_t probability_column_7_f32x = svreinterpret_bf16_f32(
|
|
978
|
+
svread_ver_za32_f32_m(svdup_f32(0), predicate_all_f32x, 0, 7));
|
|
979
|
+
|
|
980
|
+
nk_bf16_t const *v_block = v_packed + kv_block_index * dim_tile_count * 8 * 32;
|
|
981
|
+
|
|
982
|
+
// Pre-apply correction
|
|
983
|
+
if (max_was_updated_16) {
|
|
984
|
+
for (nk_size_t query_index = 0; query_index < valid_query_count; query_index++) {
|
|
985
|
+
svfloat32_t correction_scalar_f32x = svdup_f32(corrections[query_index]);
|
|
986
|
+
for (nk_size_t dim_offset = 0; dim_offset < head_dim_padded; dim_offset += 16)
|
|
987
|
+
svst1_f32(
|
|
988
|
+
predicate_all_f32x, output_accumulator + query_index * head_dim_padded + dim_offset,
|
|
989
|
+
svmul_f32_x(predicate_all_f32x,
|
|
990
|
+
svld1_f32(predicate_all_f32x,
|
|
991
|
+
output_accumulator + query_index * head_dim_padded + dim_offset),
|
|
992
|
+
correction_scalar_f32x));
|
|
993
|
+
}
|
|
994
|
+
}
|
|
995
|
+
|
|
996
|
+
// P×V: zero → BFMOPA → read → add
|
|
997
|
+
nk_size_t dim_tile = 0;
|
|
998
|
+
for (; dim_tile + 4 <= dim_tile_count; dim_tile += 4) {
|
|
999
|
+
svzero_za();
|
|
1000
|
+
svmopa_za32_bf16_m(
|
|
1001
|
+
0, query_predicate_f16x, predicate_all_f16x, probability_column_0_f32x,
|
|
1002
|
+
svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(v_block + ((dim_tile + 0) * 8 + 0) * 32)));
|
|
1003
|
+
svmopa_za32_bf16_m(
|
|
1004
|
+
1, query_predicate_f16x, predicate_all_f16x, probability_column_0_f32x,
|
|
1005
|
+
svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(v_block + ((dim_tile + 1) * 8 + 0) * 32)));
|
|
1006
|
+
svmopa_za32_bf16_m(
|
|
1007
|
+
2, query_predicate_f16x, predicate_all_f16x, probability_column_0_f32x,
|
|
1008
|
+
svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(v_block + ((dim_tile + 2) * 8 + 0) * 32)));
|
|
1009
|
+
svmopa_za32_bf16_m(
|
|
1010
|
+
3, query_predicate_f16x, predicate_all_f16x, probability_column_0_f32x,
|
|
1011
|
+
svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(v_block + ((dim_tile + 3) * 8 + 0) * 32)));
|
|
1012
|
+
svmopa_za32_bf16_m(
|
|
1013
|
+
0, query_predicate_f16x, predicate_all_f16x, probability_column_1_f32x,
|
|
1014
|
+
svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(v_block + ((dim_tile + 0) * 8 + 1) * 32)));
|
|
1015
|
+
svmopa_za32_bf16_m(
|
|
1016
|
+
1, query_predicate_f16x, predicate_all_f16x, probability_column_1_f32x,
|
|
1017
|
+
svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(v_block + ((dim_tile + 1) * 8 + 1) * 32)));
|
|
1018
|
+
svmopa_za32_bf16_m(
|
|
1019
|
+
2, query_predicate_f16x, predicate_all_f16x, probability_column_1_f32x,
|
|
1020
|
+
svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(v_block + ((dim_tile + 2) * 8 + 1) * 32)));
|
|
1021
|
+
svmopa_za32_bf16_m(
|
|
1022
|
+
3, query_predicate_f16x, predicate_all_f16x, probability_column_1_f32x,
|
|
1023
|
+
svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(v_block + ((dim_tile + 3) * 8 + 1) * 32)));
|
|
1024
|
+
svmopa_za32_bf16_m(
|
|
1025
|
+
0, query_predicate_f16x, predicate_all_f16x, probability_column_2_f32x,
|
|
1026
|
+
svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(v_block + ((dim_tile + 0) * 8 + 2) * 32)));
|
|
1027
|
+
svmopa_za32_bf16_m(
|
|
1028
|
+
1, query_predicate_f16x, predicate_all_f16x, probability_column_2_f32x,
|
|
1029
|
+
svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(v_block + ((dim_tile + 1) * 8 + 2) * 32)));
|
|
1030
|
+
svmopa_za32_bf16_m(
|
|
1031
|
+
2, query_predicate_f16x, predicate_all_f16x, probability_column_2_f32x,
|
|
1032
|
+
svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(v_block + ((dim_tile + 2) * 8 + 2) * 32)));
|
|
1033
|
+
svmopa_za32_bf16_m(
|
|
1034
|
+
3, query_predicate_f16x, predicate_all_f16x, probability_column_2_f32x,
|
|
1035
|
+
svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(v_block + ((dim_tile + 3) * 8 + 2) * 32)));
|
|
1036
|
+
svmopa_za32_bf16_m(
|
|
1037
|
+
0, query_predicate_f16x, predicate_all_f16x, probability_column_3_f32x,
|
|
1038
|
+
svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(v_block + ((dim_tile + 0) * 8 + 3) * 32)));
|
|
1039
|
+
svmopa_za32_bf16_m(
|
|
1040
|
+
1, query_predicate_f16x, predicate_all_f16x, probability_column_3_f32x,
|
|
1041
|
+
svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(v_block + ((dim_tile + 1) * 8 + 3) * 32)));
|
|
1042
|
+
svmopa_za32_bf16_m(
|
|
1043
|
+
2, query_predicate_f16x, predicate_all_f16x, probability_column_3_f32x,
|
|
1044
|
+
svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(v_block + ((dim_tile + 2) * 8 + 3) * 32)));
|
|
1045
|
+
svmopa_za32_bf16_m(
|
|
1046
|
+
3, query_predicate_f16x, predicate_all_f16x, probability_column_3_f32x,
|
|
1047
|
+
svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(v_block + ((dim_tile + 3) * 8 + 3) * 32)));
|
|
1048
|
+
svmopa_za32_bf16_m(
|
|
1049
|
+
0, query_predicate_f16x, predicate_all_f16x, probability_column_4_f32x,
|
|
1050
|
+
svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(v_block + ((dim_tile + 0) * 8 + 4) * 32)));
|
|
1051
|
+
svmopa_za32_bf16_m(
|
|
1052
|
+
1, query_predicate_f16x, predicate_all_f16x, probability_column_4_f32x,
|
|
1053
|
+
svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(v_block + ((dim_tile + 1) * 8 + 4) * 32)));
|
|
1054
|
+
svmopa_za32_bf16_m(
|
|
1055
|
+
2, query_predicate_f16x, predicate_all_f16x, probability_column_4_f32x,
|
|
1056
|
+
svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(v_block + ((dim_tile + 2) * 8 + 4) * 32)));
|
|
1057
|
+
svmopa_za32_bf16_m(
|
|
1058
|
+
3, query_predicate_f16x, predicate_all_f16x, probability_column_4_f32x,
|
|
1059
|
+
svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(v_block + ((dim_tile + 3) * 8 + 4) * 32)));
|
|
1060
|
+
svmopa_za32_bf16_m(
|
|
1061
|
+
0, query_predicate_f16x, predicate_all_f16x, probability_column_5_f32x,
|
|
1062
|
+
svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(v_block + ((dim_tile + 0) * 8 + 5) * 32)));
|
|
1063
|
+
svmopa_za32_bf16_m(
|
|
1064
|
+
1, query_predicate_f16x, predicate_all_f16x, probability_column_5_f32x,
|
|
1065
|
+
svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(v_block + ((dim_tile + 1) * 8 + 5) * 32)));
|
|
1066
|
+
svmopa_za32_bf16_m(
|
|
1067
|
+
2, query_predicate_f16x, predicate_all_f16x, probability_column_5_f32x,
|
|
1068
|
+
svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(v_block + ((dim_tile + 2) * 8 + 5) * 32)));
|
|
1069
|
+
svmopa_za32_bf16_m(
|
|
1070
|
+
3, query_predicate_f16x, predicate_all_f16x, probability_column_5_f32x,
|
|
1071
|
+
svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(v_block + ((dim_tile + 3) * 8 + 5) * 32)));
|
|
1072
|
+
svmopa_za32_bf16_m(
|
|
1073
|
+
0, query_predicate_f16x, predicate_all_f16x, probability_column_6_f32x,
|
|
1074
|
+
svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(v_block + ((dim_tile + 0) * 8 + 6) * 32)));
|
|
1075
|
+
svmopa_za32_bf16_m(
|
|
1076
|
+
1, query_predicate_f16x, predicate_all_f16x, probability_column_6_f32x,
|
|
1077
|
+
svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(v_block + ((dim_tile + 1) * 8 + 6) * 32)));
|
|
1078
|
+
svmopa_za32_bf16_m(
|
|
1079
|
+
2, query_predicate_f16x, predicate_all_f16x, probability_column_6_f32x,
|
|
1080
|
+
svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(v_block + ((dim_tile + 2) * 8 + 6) * 32)));
|
|
1081
|
+
svmopa_za32_bf16_m(
|
|
1082
|
+
3, query_predicate_f16x, predicate_all_f16x, probability_column_6_f32x,
|
|
1083
|
+
svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(v_block + ((dim_tile + 3) * 8 + 6) * 32)));
|
|
1084
|
+
svmopa_za32_bf16_m(
|
|
1085
|
+
0, query_predicate_f16x, predicate_all_f16x, probability_column_7_f32x,
|
|
1086
|
+
svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(v_block + ((dim_tile + 0) * 8 + 7) * 32)));
|
|
1087
|
+
svmopa_za32_bf16_m(
|
|
1088
|
+
1, query_predicate_f16x, predicate_all_f16x, probability_column_7_f32x,
|
|
1089
|
+
svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(v_block + ((dim_tile + 1) * 8 + 7) * 32)));
|
|
1090
|
+
svmopa_za32_bf16_m(
|
|
1091
|
+
2, query_predicate_f16x, predicate_all_f16x, probability_column_7_f32x,
|
|
1092
|
+
svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(v_block + ((dim_tile + 2) * 8 + 7) * 32)));
|
|
1093
|
+
svmopa_za32_bf16_m(
|
|
1094
|
+
3, query_predicate_f16x, predicate_all_f16x, probability_column_7_f32x,
|
|
1095
|
+
svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(v_block + ((dim_tile + 3) * 8 + 7) * 32)));
|
|
1096
|
+
for (nk_size_t query_index = 0; query_index < valid_query_count; query_index++) {
|
|
1097
|
+
svst1_f32(
|
|
1098
|
+
predicate_all_f32x, output_accumulator + query_index * head_dim_padded + (dim_tile + 0) * 16,
|
|
1099
|
+
svadd_f32_x(predicate_all_f32x,
|
|
1100
|
+
svld1_f32(predicate_all_f32x,
|
|
1101
|
+
output_accumulator + query_index * head_dim_padded + (dim_tile + 0) * 16),
|
|
1102
|
+
svread_hor_za32_f32_m(svdup_f32(0), predicate_all_f32x, 0, query_index)));
|
|
1103
|
+
svst1_f32(
|
|
1104
|
+
predicate_all_f32x, output_accumulator + query_index * head_dim_padded + (dim_tile + 1) * 16,
|
|
1105
|
+
svadd_f32_x(predicate_all_f32x,
|
|
1106
|
+
svld1_f32(predicate_all_f32x,
|
|
1107
|
+
output_accumulator + query_index * head_dim_padded + (dim_tile + 1) * 16),
|
|
1108
|
+
svread_hor_za32_f32_m(svdup_f32(0), predicate_all_f32x, 1, query_index)));
|
|
1109
|
+
svst1_f32(
|
|
1110
|
+
predicate_all_f32x, output_accumulator + query_index * head_dim_padded + (dim_tile + 2) * 16,
|
|
1111
|
+
svadd_f32_x(predicate_all_f32x,
|
|
1112
|
+
svld1_f32(predicate_all_f32x,
|
|
1113
|
+
output_accumulator + query_index * head_dim_padded + (dim_tile + 2) * 16),
|
|
1114
|
+
svread_hor_za32_f32_m(svdup_f32(0), predicate_all_f32x, 2, query_index)));
|
|
1115
|
+
svst1_f32(
|
|
1116
|
+
predicate_all_f32x, output_accumulator + query_index * head_dim_padded + (dim_tile + 3) * 16,
|
|
1117
|
+
svadd_f32_x(predicate_all_f32x,
|
|
1118
|
+
svld1_f32(predicate_all_f32x,
|
|
1119
|
+
output_accumulator + query_index * head_dim_padded + (dim_tile + 3) * 16),
|
|
1120
|
+
svread_hor_za32_f32_m(svdup_f32(0), predicate_all_f32x, 3, query_index)));
|
|
1121
|
+
}
|
|
1122
|
+
}
|
|
1123
|
+
for (; dim_tile < dim_tile_count; dim_tile++) {
|
|
1124
|
+
svzero_mask_za(nk_sme_zero_za32_tile_0_);
|
|
1125
|
+
svmopa_za32_bf16_m(
|
|
1126
|
+
0, query_predicate_f16x, predicate_all_f16x, probability_column_0_f32x,
|
|
1127
|
+
svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(v_block + (dim_tile * 8 + 0) * 32)));
|
|
1128
|
+
svmopa_za32_bf16_m(
|
|
1129
|
+
0, query_predicate_f16x, predicate_all_f16x, probability_column_1_f32x,
|
|
1130
|
+
svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(v_block + (dim_tile * 8 + 1) * 32)));
|
|
1131
|
+
svmopa_za32_bf16_m(
|
|
1132
|
+
0, query_predicate_f16x, predicate_all_f16x, probability_column_2_f32x,
|
|
1133
|
+
svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(v_block + (dim_tile * 8 + 2) * 32)));
|
|
1134
|
+
svmopa_za32_bf16_m(
|
|
1135
|
+
0, query_predicate_f16x, predicate_all_f16x, probability_column_3_f32x,
|
|
1136
|
+
svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(v_block + (dim_tile * 8 + 3) * 32)));
|
|
1137
|
+
svmopa_za32_bf16_m(
|
|
1138
|
+
0, query_predicate_f16x, predicate_all_f16x, probability_column_4_f32x,
|
|
1139
|
+
svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(v_block + (dim_tile * 8 + 4) * 32)));
|
|
1140
|
+
svmopa_za32_bf16_m(
|
|
1141
|
+
0, query_predicate_f16x, predicate_all_f16x, probability_column_5_f32x,
|
|
1142
|
+
svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(v_block + (dim_tile * 8 + 5) * 32)));
|
|
1143
|
+
svmopa_za32_bf16_m(
|
|
1144
|
+
0, query_predicate_f16x, predicate_all_f16x, probability_column_6_f32x,
|
|
1145
|
+
svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(v_block + (dim_tile * 8 + 6) * 32)));
|
|
1146
|
+
svmopa_za32_bf16_m(
|
|
1147
|
+
0, query_predicate_f16x, predicate_all_f16x, probability_column_7_f32x,
|
|
1148
|
+
svld1_bf16(predicate_all_f16x, (bfloat16_t const *)(v_block + (dim_tile * 8 + 7) * 32)));
|
|
1149
|
+
for (nk_size_t query_index = 0; query_index < valid_query_count; query_index++)
|
|
1150
|
+
svst1_f32(predicate_all_f32x, output_accumulator + query_index * head_dim_padded + dim_tile * 16,
|
|
1151
|
+
svadd_f32_x(predicate_all_f32x,
|
|
1152
|
+
svld1_f32(predicate_all_f32x,
|
|
1153
|
+
output_accumulator + query_index * head_dim_padded + dim_tile * 16),
|
|
1154
|
+
svread_hor_za32_f32_m(svdup_f32(0), predicate_all_f32x, 0, query_index)));
|
|
1155
|
+
}
|
|
1156
|
+
}
|
|
1157
|
+
}
|
|
1158
|
+
|
|
1159
|
+
// Final normalization
|
|
1160
|
+
NK_ALIGN64 nk_f32_t final_sums[16];
|
|
1161
|
+
svst1_f32(predicate_all_f32x, final_sums, row_sum_f32x);
|
|
1162
|
+
|
|
1163
|
+
for (nk_size_t query_index = 0; query_index < valid_query_count; query_index++) {
|
|
1164
|
+
nk_f32_t inv_sum = (final_sums[query_index] > 0.0f) ? (1.0f / final_sums[query_index]) : 0.0f;
|
|
1165
|
+
svfloat32_t inv_sum_f32x = svdup_f32(inv_sum);
|
|
1166
|
+
|
|
1167
|
+
for (nk_size_t dim_offset = 0; dim_offset < head_dim; dim_offset += svcntw()) {
|
|
1168
|
+
svbool_t predicate_f32x = svwhilelt_b32_u64(dim_offset, head_dim);
|
|
1169
|
+
svfloat32_t output_f32x = svmul_f32_x(
|
|
1170
|
+
predicate_f32x,
|
|
1171
|
+
svld1_f32(predicate_f32x, output_accumulator + query_index * head_dim_padded + dim_offset),
|
|
1172
|
+
inv_sum_f32x);
|
|
1173
|
+
svbfloat16_t output_bf16x = nk_f32_to_bf16_sve_(predicate_f32x, output_f32x);
|
|
1174
|
+
nk_size_t store_count = (head_dim - dim_offset) < (nk_size_t)svcntw() ? (head_dim - dim_offset)
|
|
1175
|
+
: (nk_size_t)svcntw();
|
|
1176
|
+
svbool_t store_predicate_f16x = svwhilelt_b16_u64(0u, store_count);
|
|
1177
|
+
svst1_bf16(store_predicate_f16x, (bfloat16_t *)(output + query_index * head_dim + dim_offset),
|
|
1178
|
+
output_bf16x);
|
|
1179
|
+
}
|
|
1180
|
+
}
|
|
1181
|
+
}
|
|
1182
|
+
|
|
1183
|
+
NK_PUBLIC void nk_attention_bf16_sme(nk_bf16_t const *q, void const *kv_packed, nk_bf16_t *output, nk_size_t num_heads,
|
|
1184
|
+
nk_size_t num_kv_heads, nk_size_t query_len, nk_size_t kv_len, nk_size_t head_dim,
|
|
1185
|
+
nk_f32_t scale) {
|
|
1186
|
+
|
|
1187
|
+
nk_attention_sme_packed_header_t const *header = (nk_attention_sme_packed_header_t const *)kv_packed;
|
|
1188
|
+
nk_size_t head_dim_padded = header->head_dim_padded;
|
|
1189
|
+
nk_size_t dim_tile_count = header->reserved[0]; // v_dim_tile_count
|
|
1190
|
+
nk_size_t kv_blocks = (kv_len + 15) / 16;
|
|
1191
|
+
nk_size_t kv_head_stride = kv_blocks * 16 * head_dim_padded;
|
|
1192
|
+
|
|
1193
|
+
nk_bf16_t const *k_packed = (nk_bf16_t const *)((char const *)kv_packed + header->k_offset);
|
|
1194
|
+
nk_bf16_t const *v_packed = (nk_bf16_t const *)((char const *)kv_packed + header->v_offset);
|
|
1195
|
+
|
|
1196
|
+
nk_size_t group_size = (num_kv_heads > 0) ? num_heads / num_kv_heads : 1;
|
|
1197
|
+
|
|
1198
|
+
for (nk_size_t q_head = 0; q_head < num_heads; q_head++) {
|
|
1199
|
+
nk_size_t kv_head = q_head / group_size;
|
|
1200
|
+
|
|
1201
|
+
nk_bf16_t const *q_ptr = q + q_head * query_len * head_dim;
|
|
1202
|
+
nk_bf16_t const *k_ptr = k_packed + kv_head * kv_head_stride;
|
|
1203
|
+
nk_bf16_t const *v_ptr = v_packed + kv_head * kv_head_stride;
|
|
1204
|
+
nk_bf16_t *out_ptr = output + q_head * query_len * head_dim;
|
|
1205
|
+
|
|
1206
|
+
for (nk_size_t q_start = 0; q_start < query_len; q_start += 16) {
|
|
1207
|
+
nk_size_t q_block_len = (q_start + 16 < query_len) ? 16 : (query_len - q_start);
|
|
1208
|
+
|
|
1209
|
+
nk_attention_bf16_sme_streaming_(q_ptr + q_start * head_dim, k_ptr, v_ptr, out_ptr + q_start * head_dim,
|
|
1210
|
+
q_block_len, kv_len, head_dim, head_dim_padded, dim_tile_count, scale);
|
|
1211
|
+
}
|
|
1212
|
+
}
|
|
1213
|
+
}
|
|
1214
|
+
|
|
1215
|
+
__arm_locally_streaming __arm_new("za") static void nk_attention_f16_sme_streaming_(
|
|
1216
|
+
nk_f16_t const *q, // [query_len, head_dim]
|
|
1217
|
+
nk_f16_t const *k, // [kv_len, head_dim_padded] FMOPA-interleaved
|
|
1218
|
+
nk_f16_t const *v_packed, // FMOPA-interleaved V for this KV head
|
|
1219
|
+
nk_f16_t *output, // [query_len, head_dim]
|
|
1220
|
+
nk_size_t query_len, nk_size_t kv_len, nk_size_t head_dim, nk_size_t head_dim_padded, nk_size_t dim_tile_count,
|
|
1221
|
+
nk_f32_t scale) {
|
|
1222
|
+
|
|
1223
|
+
svbool_t const predicate_all_f32x = svptrue_b32();
|
|
1224
|
+
svbool_t const predicate_all_f16x = svptrue_b16();
|
|
1225
|
+
nk_size_t const valid_query_count = (query_len < 16) ? query_len : 16;
|
|
1226
|
+
|
|
1227
|
+
NK_ALIGN64 nk_f32_t row_max[16];
|
|
1228
|
+
NK_ALIGN64 nk_f32_t row_sum[16];
|
|
1229
|
+
NK_ALIGN64 nk_f32_t output_accumulator[16 * 256];
|
|
1230
|
+
|
|
1231
|
+
svst1_f32(predicate_all_f32x, row_max, svdup_f32(NK_F32_MIN));
|
|
1232
|
+
svst1_f32(predicate_all_f32x, row_sum, svdup_f32(0.0f));
|
|
1233
|
+
svfloat32_t zero_f32x = svdup_f32(0.0f);
|
|
1234
|
+
for (nk_size_t i = 0; i < 16 * head_dim_padded; i += svcntw()) {
|
|
1235
|
+
svst1_f32(predicate_all_f32x, output_accumulator + i, zero_f32x);
|
|
1236
|
+
}
|
|
1237
|
+
|
|
1238
|
+
nk_size_t kv_block_index = 0;
|
|
1239
|
+
nk_size_t kv_start = 0;
|
|
1240
|
+
svbool_t const batch_predicate_f32x = svwhilelt_b32(0u, 16u);
|
|
1241
|
+
|
|
1242
|
+
nk_size_t const k_depth_step_count = head_dim_padded / 2;
|
|
1243
|
+
|
|
1244
|
+
// Pre-transpose Q once: queries_transposed[step][16 f32 words]
|
|
1245
|
+
// queries_transposed[step] reinterpret-as-f16 = {Q[0][2s], Q[0][2s+1], Q[1][2s], Q[1][2s+1], ...}
|
|
1246
|
+
// This is the same interleaving ZA0 vertical reads would produce.
|
|
1247
|
+
NK_ALIGN64 nk_f32_t queries_transposed[128 * 16]; // max head_dim_padded/2 * 16 = 128 * 16
|
|
1248
|
+
for (nk_size_t batch = 0; batch < head_dim_padded / 32; batch++) {
|
|
1249
|
+
svzero_mask_za(nk_sme_zero_za32_tile_0_);
|
|
1250
|
+
for (nk_size_t query_index = 0; query_index < valid_query_count; query_index++)
|
|
1251
|
+
svld1_hor_za32(0, query_index, batch_predicate_f32x,
|
|
1252
|
+
(nk_f32_t const *)(q + query_index * head_dim + batch * 32));
|
|
1253
|
+
for (nk_size_t step = 0; step < 16; step++)
|
|
1254
|
+
svst1_f32(predicate_all_f32x, queries_transposed + (batch * 16 + step) * 16,
|
|
1255
|
+
svread_ver_za32_f32_m(svdup_f32(0), predicate_all_f32x, 0, step));
|
|
1256
|
+
}
|
|
1257
|
+
|
|
1258
|
+
// === Bc=32 main loop (prefill only, skipped for decode) ===
|
|
1259
|
+
if (valid_query_count > 1) {
|
|
1260
|
+
for (; kv_start + 32 <= kv_len; kv_start += 32, kv_block_index += 2) {
|
|
1261
|
+
// Q×K^T: pure memory→FMOPA, no ZA staging for Q or K
|
|
1262
|
+
svzero_mask_za(nk_sme_zero_za32_tile_2_);
|
|
1263
|
+
svzero_mask_za(nk_sme_zero_za32_tile_3_);
|
|
1264
|
+
nk_f16_t const *keys_block_lower = k + kv_block_index * k_depth_step_count * 32;
|
|
1265
|
+
nk_f16_t const *keys_block_upper = k + (kv_block_index + 1) * k_depth_step_count * 32;
|
|
1266
|
+
for (nk_size_t step = 0; step < k_depth_step_count; step++) {
|
|
1267
|
+
svfloat16_t zn = svreinterpret_f16_f32(svld1_f32(predicate_all_f32x, queries_transposed + step * 16));
|
|
1268
|
+
svfloat16_t zm0 = svld1_f16(predicate_all_f16x, (float16_t const *)(keys_block_lower + step * 32));
|
|
1269
|
+
svfloat16_t zm1 = svld1_f16(predicate_all_f16x, (float16_t const *)(keys_block_upper + step * 32));
|
|
1270
|
+
svmopa_za32_f16_m(2, predicate_all_f32x, predicate_all_f32x, zn, zm0);
|
|
1271
|
+
svmopa_za32_f16_m(3, predicate_all_f32x, predicate_all_f32x, zn, zm1);
|
|
1272
|
+
}
|
|
1273
|
+
// ZA2 = scores[query_index][0:15], ZA3 = scores[query_index][16:31]
|
|
1274
|
+
|
|
1275
|
+
// Pass 1: Column-wise max (read ZA2/ZA3 columns vertically)
|
|
1276
|
+
svfloat32_t scale_f32x = svdup_f32(scale);
|
|
1277
|
+
svfloat32_t block_max_f32x = svdup_f32(NK_F32_MIN);
|
|
1278
|
+
for (nk_size_t column_index = 0; column_index < 16; column_index++) {
|
|
1279
|
+
svfloat32_t score_column_f32x = svmul_f32_x(
|
|
1280
|
+
predicate_all_f32x, svread_ver_za32_f32_m(svdup_f32(0), predicate_all_f32x, 2, column_index),
|
|
1281
|
+
scale_f32x);
|
|
1282
|
+
block_max_f32x = svmax_f32_x(predicate_all_f32x, block_max_f32x, score_column_f32x);
|
|
1283
|
+
}
|
|
1284
|
+
for (nk_size_t column_index = 0; column_index < 16; column_index++) {
|
|
1285
|
+
svfloat32_t score_column_f32x = svmul_f32_x(
|
|
1286
|
+
predicate_all_f32x, svread_ver_za32_f32_m(svdup_f32(0), predicate_all_f32x, 3, column_index),
|
|
1287
|
+
scale_f32x);
|
|
1288
|
+
block_max_f32x = svmax_f32_x(predicate_all_f32x, block_max_f32x, score_column_f32x);
|
|
1289
|
+
}
|
|
1290
|
+
|
|
1291
|
+
// Softmax correction (vectorized via array load/store)
|
|
1292
|
+
svfloat32_t old_max_f32x = svld1_f32(predicate_all_f32x, row_max);
|
|
1293
|
+
svfloat32_t new_max_f32x = svmax_f32_x(predicate_all_f32x, old_max_f32x, block_max_f32x);
|
|
1294
|
+
svfloat32_t correction_f32x = nk_exp_fast_f32_sve_(
|
|
1295
|
+
predicate_all_f32x, svsub_f32_x(predicate_all_f32x, old_max_f32x, new_max_f32x));
|
|
1296
|
+
svbool_t max_changed = svcmplt_f32(predicate_all_f32x, correction_f32x, svdup_f32(1.0f));
|
|
1297
|
+
nk_u32_t max_was_updated = svptest_any(predicate_all_f32x, max_changed) ? 1 : 0;
|
|
1298
|
+
svfloat32_t row_sum_corrected_f32x = svld1_f32(predicate_all_f32x, row_sum);
|
|
1299
|
+
if (max_was_updated)
|
|
1300
|
+
row_sum_corrected_f32x = svmul_f32_x(predicate_all_f32x, row_sum_corrected_f32x, correction_f32x);
|
|
1301
|
+
NK_ALIGN64 nk_f32_t corrections[16];
|
|
1302
|
+
svst1_f32(predicate_all_f32x, corrections, correction_f32x);
|
|
1303
|
+
|
|
1304
|
+
// Pass 2: Column-wise exp + fused P write + sum
|
|
1305
|
+
svfloat32_t sum_delta_f32x = svdup_f32(0.0f);
|
|
1306
|
+
svzero_mask_za(nk_sme_zero_za32_tile_0_);
|
|
1307
|
+
// ZA2 columns in pairs -> ZA0 columns 0-7
|
|
1308
|
+
for (nk_size_t column_index = 0; column_index < 16; column_index += 2) {
|
|
1309
|
+
svfloat32_t score_even_f32x = svmul_f32_x(
|
|
1310
|
+
predicate_all_f32x, svread_ver_za32_f32_m(svdup_f32(0), predicate_all_f32x, 2, column_index),
|
|
1311
|
+
scale_f32x);
|
|
1312
|
+
svfloat32_t score_odd_f32x = svmul_f32_x(
|
|
1313
|
+
predicate_all_f32x, svread_ver_za32_f32_m(svdup_f32(0), predicate_all_f32x, 2, column_index + 1),
|
|
1314
|
+
scale_f32x);
|
|
1315
|
+
svfloat32_t weight_even_f32x = nk_exp_fast_f32_sve_(
|
|
1316
|
+
predicate_all_f32x, svsub_f32_x(predicate_all_f32x, score_even_f32x, new_max_f32x));
|
|
1317
|
+
svfloat32_t weight_odd_f32x = nk_exp_fast_f32_sve_(
|
|
1318
|
+
predicate_all_f32x, svsub_f32_x(predicate_all_f32x, score_odd_f32x, new_max_f32x));
|
|
1319
|
+
sum_delta_f32x = svadd_f32_x(predicate_all_f32x, sum_delta_f32x, weight_even_f32x);
|
|
1320
|
+
sum_delta_f32x = svadd_f32_x(predicate_all_f32x, sum_delta_f32x, weight_odd_f32x);
|
|
1321
|
+
svfloat16_t weight_pair_f16x = svzip1_f16(svcvt_f16_f32_x(predicate_all_f32x, weight_even_f32x),
|
|
1322
|
+
svcvt_f16_f32_x(predicate_all_f32x, weight_odd_f32x));
|
|
1323
|
+
svwrite_ver_za32_f32_m(0, column_index / 2, predicate_all_f32x,
|
|
1324
|
+
svreinterpret_f32_f16(weight_pair_f16x));
|
|
1325
|
+
}
|
|
1326
|
+
// ZA3 columns in pairs -> ZA0 columns 8-15
|
|
1327
|
+
for (nk_size_t column_index = 0; column_index < 16; column_index += 2) {
|
|
1328
|
+
svfloat32_t score_even_f32x = svmul_f32_x(
|
|
1329
|
+
predicate_all_f32x, svread_ver_za32_f32_m(svdup_f32(0), predicate_all_f32x, 3, column_index),
|
|
1330
|
+
scale_f32x);
|
|
1331
|
+
svfloat32_t score_odd_f32x = svmul_f32_x(
|
|
1332
|
+
predicate_all_f32x, svread_ver_za32_f32_m(svdup_f32(0), predicate_all_f32x, 3, column_index + 1),
|
|
1333
|
+
scale_f32x);
|
|
1334
|
+
svfloat32_t weight_even_f32x = nk_exp_fast_f32_sve_(
|
|
1335
|
+
predicate_all_f32x, svsub_f32_x(predicate_all_f32x, score_even_f32x, new_max_f32x));
|
|
1336
|
+
svfloat32_t weight_odd_f32x = nk_exp_fast_f32_sve_(
|
|
1337
|
+
predicate_all_f32x, svsub_f32_x(predicate_all_f32x, score_odd_f32x, new_max_f32x));
|
|
1338
|
+
sum_delta_f32x = svadd_f32_x(predicate_all_f32x, sum_delta_f32x, weight_even_f32x);
|
|
1339
|
+
sum_delta_f32x = svadd_f32_x(predicate_all_f32x, sum_delta_f32x, weight_odd_f32x);
|
|
1340
|
+
svfloat16_t weight_pair_f16x = svzip1_f16(svcvt_f16_f32_x(predicate_all_f32x, weight_even_f32x),
|
|
1341
|
+
svcvt_f16_f32_x(predicate_all_f32x, weight_odd_f32x));
|
|
1342
|
+
svwrite_ver_za32_f32_m(0, 8 + column_index / 2, predicate_all_f32x,
|
|
1343
|
+
svreinterpret_f32_f16(weight_pair_f16x));
|
|
1344
|
+
}
|
|
1345
|
+
row_sum_corrected_f32x = svadd_f32_x(predicate_all_f32x, row_sum_corrected_f32x, sum_delta_f32x);
|
|
1346
|
+
svst1_f32(predicate_all_f32x, row_sum, row_sum_corrected_f32x);
|
|
1347
|
+
svst1_f32(predicate_all_f32x, row_max, new_max_f32x);
|
|
1348
|
+
|
|
1349
|
+
// Extract P columns from ZA0
|
|
1350
|
+
svfloat16_t probability_column_0_f32x = svreinterpret_f16_f32(
|
|
1351
|
+
svread_ver_za32_f32_m(svdup_f32(0), predicate_all_f32x, 0, 0));
|
|
1352
|
+
svfloat16_t probability_column_1_f32x = svreinterpret_f16_f32(
|
|
1353
|
+
svread_ver_za32_f32_m(svdup_f32(0), predicate_all_f32x, 0, 1));
|
|
1354
|
+
svfloat16_t probability_column_2_f32x = svreinterpret_f16_f32(
|
|
1355
|
+
svread_ver_za32_f32_m(svdup_f32(0), predicate_all_f32x, 0, 2));
|
|
1356
|
+
svfloat16_t probability_column_3_f32x = svreinterpret_f16_f32(
|
|
1357
|
+
svread_ver_za32_f32_m(svdup_f32(0), predicate_all_f32x, 0, 3));
|
|
1358
|
+
svfloat16_t probability_column_4_f32x = svreinterpret_f16_f32(
|
|
1359
|
+
svread_ver_za32_f32_m(svdup_f32(0), predicate_all_f32x, 0, 4));
|
|
1360
|
+
svfloat16_t probability_column_5_f32x = svreinterpret_f16_f32(
|
|
1361
|
+
svread_ver_za32_f32_m(svdup_f32(0), predicate_all_f32x, 0, 5));
|
|
1362
|
+
svfloat16_t probability_column_6_f32x = svreinterpret_f16_f32(
|
|
1363
|
+
svread_ver_za32_f32_m(svdup_f32(0), predicate_all_f32x, 0, 6));
|
|
1364
|
+
svfloat16_t probability_column_7_f32x = svreinterpret_f16_f32(
|
|
1365
|
+
svread_ver_za32_f32_m(svdup_f32(0), predicate_all_f32x, 0, 7));
|
|
1366
|
+
svfloat16_t probability_column_8_f32x = svreinterpret_f16_f32(
|
|
1367
|
+
svread_ver_za32_f32_m(svdup_f32(0), predicate_all_f32x, 0, 8));
|
|
1368
|
+
svfloat16_t probability_column_9_f32x = svreinterpret_f16_f32(
|
|
1369
|
+
svread_ver_za32_f32_m(svdup_f32(0), predicate_all_f32x, 0, 9));
|
|
1370
|
+
svfloat16_t probability_column_10_f32x = svreinterpret_f16_f32(
|
|
1371
|
+
svread_ver_za32_f32_m(svdup_f32(0), predicate_all_f32x, 0, 10));
|
|
1372
|
+
svfloat16_t probability_column_11_f32x = svreinterpret_f16_f32(
|
|
1373
|
+
svread_ver_za32_f32_m(svdup_f32(0), predicate_all_f32x, 0, 11));
|
|
1374
|
+
svfloat16_t probability_column_12_f32x = svreinterpret_f16_f32(
|
|
1375
|
+
svread_ver_za32_f32_m(svdup_f32(0), predicate_all_f32x, 0, 12));
|
|
1376
|
+
svfloat16_t probability_column_13_f32x = svreinterpret_f16_f32(
|
|
1377
|
+
svread_ver_za32_f32_m(svdup_f32(0), predicate_all_f32x, 0, 13));
|
|
1378
|
+
svfloat16_t probability_column_14_f32x = svreinterpret_f16_f32(
|
|
1379
|
+
svread_ver_za32_f32_m(svdup_f32(0), predicate_all_f32x, 0, 14));
|
|
1380
|
+
svfloat16_t probability_column_15_f32x = svreinterpret_f16_f32(
|
|
1381
|
+
svread_ver_za32_f32_m(svdup_f32(0), predicate_all_f32x, 0, 15));
|
|
1382
|
+
|
|
1383
|
+
// Pre-apply correction once before P×V
|
|
1384
|
+
svbool_t query_predicate_f16x = svwhilelt_b16_u64(0u, valid_query_count * 2);
|
|
1385
|
+
nk_f16_t const *values_block_lower = v_packed + kv_block_index * dim_tile_count * 8 * 32;
|
|
1386
|
+
nk_f16_t const *values_block_upper = v_packed + (kv_block_index + 1) * dim_tile_count * 8 * 32;
|
|
1387
|
+
|
|
1388
|
+
if (max_was_updated) {
|
|
1389
|
+
for (nk_size_t query_index = 0; query_index < valid_query_count; query_index++) {
|
|
1390
|
+
svfloat32_t correction_scalar_f32x = svdup_f32(corrections[query_index]);
|
|
1391
|
+
for (nk_size_t dim_offset = 0; dim_offset < head_dim_padded; dim_offset += 16)
|
|
1392
|
+
svst1_f32(
|
|
1393
|
+
predicate_all_f32x, output_accumulator + query_index * head_dim_padded + dim_offset,
|
|
1394
|
+
svmul_f32_x(predicate_all_f32x,
|
|
1395
|
+
svld1_f32(predicate_all_f32x,
|
|
1396
|
+
output_accumulator + query_index * head_dim_padded + dim_offset),
|
|
1397
|
+
correction_scalar_f32x));
|
|
1398
|
+
}
|
|
1399
|
+
}
|
|
1400
|
+
|
|
1401
|
+
// P×V: zero -> FMOPA -> read -> add (no ZA writes for output_accumulator)
|
|
1402
|
+
nk_size_t dim_tile = 0;
|
|
1403
|
+
for (; dim_tile + 4 <= dim_tile_count; dim_tile += 4) {
|
|
1404
|
+
svzero_za();
|
|
1405
|
+
// Block0: 8 depth steps (KV positions 0-15)
|
|
1406
|
+
svmopa_za32_f16_m(0, query_predicate_f16x, predicate_all_f16x, probability_column_0_f32x,
|
|
1407
|
+
svld1_f16(predicate_all_f16x,
|
|
1408
|
+
(float16_t const *)(values_block_lower + ((dim_tile + 0) * 8 + 0) * 32)));
|
|
1409
|
+
svmopa_za32_f16_m(1, query_predicate_f16x, predicate_all_f16x, probability_column_0_f32x,
|
|
1410
|
+
svld1_f16(predicate_all_f16x,
|
|
1411
|
+
(float16_t const *)(values_block_lower + ((dim_tile + 1) * 8 + 0) * 32)));
|
|
1412
|
+
svmopa_za32_f16_m(2, query_predicate_f16x, predicate_all_f16x, probability_column_0_f32x,
|
|
1413
|
+
svld1_f16(predicate_all_f16x,
|
|
1414
|
+
(float16_t const *)(values_block_lower + ((dim_tile + 2) * 8 + 0) * 32)));
|
|
1415
|
+
svmopa_za32_f16_m(3, query_predicate_f16x, predicate_all_f16x, probability_column_0_f32x,
|
|
1416
|
+
svld1_f16(predicate_all_f16x,
|
|
1417
|
+
(float16_t const *)(values_block_lower + ((dim_tile + 3) * 8 + 0) * 32)));
|
|
1418
|
+
svmopa_za32_f16_m(0, query_predicate_f16x, predicate_all_f16x, probability_column_1_f32x,
|
|
1419
|
+
svld1_f16(predicate_all_f16x,
|
|
1420
|
+
(float16_t const *)(values_block_lower + ((dim_tile + 0) * 8 + 1) * 32)));
|
|
1421
|
+
svmopa_za32_f16_m(1, query_predicate_f16x, predicate_all_f16x, probability_column_1_f32x,
|
|
1422
|
+
svld1_f16(predicate_all_f16x,
|
|
1423
|
+
(float16_t const *)(values_block_lower + ((dim_tile + 1) * 8 + 1) * 32)));
|
|
1424
|
+
svmopa_za32_f16_m(2, query_predicate_f16x, predicate_all_f16x, probability_column_1_f32x,
|
|
1425
|
+
svld1_f16(predicate_all_f16x,
|
|
1426
|
+
(float16_t const *)(values_block_lower + ((dim_tile + 2) * 8 + 1) * 32)));
|
|
1427
|
+
svmopa_za32_f16_m(3, query_predicate_f16x, predicate_all_f16x, probability_column_1_f32x,
|
|
1428
|
+
svld1_f16(predicate_all_f16x,
|
|
1429
|
+
(float16_t const *)(values_block_lower + ((dim_tile + 3) * 8 + 1) * 32)));
|
|
1430
|
+
svmopa_za32_f16_m(0, query_predicate_f16x, predicate_all_f16x, probability_column_2_f32x,
|
|
1431
|
+
svld1_f16(predicate_all_f16x,
|
|
1432
|
+
(float16_t const *)(values_block_lower + ((dim_tile + 0) * 8 + 2) * 32)));
|
|
1433
|
+
svmopa_za32_f16_m(1, query_predicate_f16x, predicate_all_f16x, probability_column_2_f32x,
|
|
1434
|
+
svld1_f16(predicate_all_f16x,
|
|
1435
|
+
(float16_t const *)(values_block_lower + ((dim_tile + 1) * 8 + 2) * 32)));
|
|
1436
|
+
svmopa_za32_f16_m(2, query_predicate_f16x, predicate_all_f16x, probability_column_2_f32x,
|
|
1437
|
+
svld1_f16(predicate_all_f16x,
|
|
1438
|
+
(float16_t const *)(values_block_lower + ((dim_tile + 2) * 8 + 2) * 32)));
|
|
1439
|
+
svmopa_za32_f16_m(3, query_predicate_f16x, predicate_all_f16x, probability_column_2_f32x,
|
|
1440
|
+
svld1_f16(predicate_all_f16x,
|
|
1441
|
+
(float16_t const *)(values_block_lower + ((dim_tile + 3) * 8 + 2) * 32)));
|
|
1442
|
+
svmopa_za32_f16_m(0, query_predicate_f16x, predicate_all_f16x, probability_column_3_f32x,
|
|
1443
|
+
svld1_f16(predicate_all_f16x,
|
|
1444
|
+
(float16_t const *)(values_block_lower + ((dim_tile + 0) * 8 + 3) * 32)));
|
|
1445
|
+
svmopa_za32_f16_m(1, query_predicate_f16x, predicate_all_f16x, probability_column_3_f32x,
|
|
1446
|
+
svld1_f16(predicate_all_f16x,
|
|
1447
|
+
(float16_t const *)(values_block_lower + ((dim_tile + 1) * 8 + 3) * 32)));
|
|
1448
|
+
svmopa_za32_f16_m(2, query_predicate_f16x, predicate_all_f16x, probability_column_3_f32x,
|
|
1449
|
+
svld1_f16(predicate_all_f16x,
|
|
1450
|
+
(float16_t const *)(values_block_lower + ((dim_tile + 2) * 8 + 3) * 32)));
|
|
1451
|
+
svmopa_za32_f16_m(3, query_predicate_f16x, predicate_all_f16x, probability_column_3_f32x,
|
|
1452
|
+
svld1_f16(predicate_all_f16x,
|
|
1453
|
+
(float16_t const *)(values_block_lower + ((dim_tile + 3) * 8 + 3) * 32)));
|
|
1454
|
+
svmopa_za32_f16_m(0, query_predicate_f16x, predicate_all_f16x, probability_column_4_f32x,
|
|
1455
|
+
svld1_f16(predicate_all_f16x,
|
|
1456
|
+
(float16_t const *)(values_block_lower + ((dim_tile + 0) * 8 + 4) * 32)));
|
|
1457
|
+
svmopa_za32_f16_m(1, query_predicate_f16x, predicate_all_f16x, probability_column_4_f32x,
|
|
1458
|
+
svld1_f16(predicate_all_f16x,
|
|
1459
|
+
(float16_t const *)(values_block_lower + ((dim_tile + 1) * 8 + 4) * 32)));
|
|
1460
|
+
svmopa_za32_f16_m(2, query_predicate_f16x, predicate_all_f16x, probability_column_4_f32x,
|
|
1461
|
+
svld1_f16(predicate_all_f16x,
|
|
1462
|
+
(float16_t const *)(values_block_lower + ((dim_tile + 2) * 8 + 4) * 32)));
|
|
1463
|
+
svmopa_za32_f16_m(3, query_predicate_f16x, predicate_all_f16x, probability_column_4_f32x,
|
|
1464
|
+
svld1_f16(predicate_all_f16x,
|
|
1465
|
+
(float16_t const *)(values_block_lower + ((dim_tile + 3) * 8 + 4) * 32)));
|
|
1466
|
+
svmopa_za32_f16_m(0, query_predicate_f16x, predicate_all_f16x, probability_column_5_f32x,
|
|
1467
|
+
svld1_f16(predicate_all_f16x,
|
|
1468
|
+
(float16_t const *)(values_block_lower + ((dim_tile + 0) * 8 + 5) * 32)));
|
|
1469
|
+
svmopa_za32_f16_m(1, query_predicate_f16x, predicate_all_f16x, probability_column_5_f32x,
|
|
1470
|
+
svld1_f16(predicate_all_f16x,
|
|
1471
|
+
(float16_t const *)(values_block_lower + ((dim_tile + 1) * 8 + 5) * 32)));
|
|
1472
|
+
svmopa_za32_f16_m(2, query_predicate_f16x, predicate_all_f16x, probability_column_5_f32x,
|
|
1473
|
+
svld1_f16(predicate_all_f16x,
|
|
1474
|
+
(float16_t const *)(values_block_lower + ((dim_tile + 2) * 8 + 5) * 32)));
|
|
1475
|
+
svmopa_za32_f16_m(3, query_predicate_f16x, predicate_all_f16x, probability_column_5_f32x,
|
|
1476
|
+
svld1_f16(predicate_all_f16x,
|
|
1477
|
+
(float16_t const *)(values_block_lower + ((dim_tile + 3) * 8 + 5) * 32)));
|
|
1478
|
+
svmopa_za32_f16_m(0, query_predicate_f16x, predicate_all_f16x, probability_column_6_f32x,
|
|
1479
|
+
svld1_f16(predicate_all_f16x,
|
|
1480
|
+
(float16_t const *)(values_block_lower + ((dim_tile + 0) * 8 + 6) * 32)));
|
|
1481
|
+
svmopa_za32_f16_m(1, query_predicate_f16x, predicate_all_f16x, probability_column_6_f32x,
|
|
1482
|
+
svld1_f16(predicate_all_f16x,
|
|
1483
|
+
(float16_t const *)(values_block_lower + ((dim_tile + 1) * 8 + 6) * 32)));
|
|
1484
|
+
svmopa_za32_f16_m(2, query_predicate_f16x, predicate_all_f16x, probability_column_6_f32x,
|
|
1485
|
+
svld1_f16(predicate_all_f16x,
|
|
1486
|
+
(float16_t const *)(values_block_lower + ((dim_tile + 2) * 8 + 6) * 32)));
|
|
1487
|
+
svmopa_za32_f16_m(3, query_predicate_f16x, predicate_all_f16x, probability_column_6_f32x,
|
|
1488
|
+
svld1_f16(predicate_all_f16x,
|
|
1489
|
+
(float16_t const *)(values_block_lower + ((dim_tile + 3) * 8 + 6) * 32)));
|
|
1490
|
+
svmopa_za32_f16_m(0, query_predicate_f16x, predicate_all_f16x, probability_column_7_f32x,
|
|
1491
|
+
svld1_f16(predicate_all_f16x,
|
|
1492
|
+
(float16_t const *)(values_block_lower + ((dim_tile + 0) * 8 + 7) * 32)));
|
|
1493
|
+
svmopa_za32_f16_m(1, query_predicate_f16x, predicate_all_f16x, probability_column_7_f32x,
|
|
1494
|
+
svld1_f16(predicate_all_f16x,
|
|
1495
|
+
(float16_t const *)(values_block_lower + ((dim_tile + 1) * 8 + 7) * 32)));
|
|
1496
|
+
svmopa_za32_f16_m(2, query_predicate_f16x, predicate_all_f16x, probability_column_7_f32x,
|
|
1497
|
+
svld1_f16(predicate_all_f16x,
|
|
1498
|
+
(float16_t const *)(values_block_lower + ((dim_tile + 2) * 8 + 7) * 32)));
|
|
1499
|
+
svmopa_za32_f16_m(3, query_predicate_f16x, predicate_all_f16x, probability_column_7_f32x,
|
|
1500
|
+
svld1_f16(predicate_all_f16x,
|
|
1501
|
+
(float16_t const *)(values_block_lower + ((dim_tile + 3) * 8 + 7) * 32)));
|
|
1502
|
+
// Block1: 8 depth steps (KV positions 16-31)
|
|
1503
|
+
svmopa_za32_f16_m(0, query_predicate_f16x, predicate_all_f16x, probability_column_8_f32x,
|
|
1504
|
+
svld1_f16(predicate_all_f16x,
|
|
1505
|
+
(float16_t const *)(values_block_upper + ((dim_tile + 0) * 8 + 0) * 32)));
|
|
1506
|
+
svmopa_za32_f16_m(1, query_predicate_f16x, predicate_all_f16x, probability_column_8_f32x,
|
|
1507
|
+
svld1_f16(predicate_all_f16x,
|
|
1508
|
+
(float16_t const *)(values_block_upper + ((dim_tile + 1) * 8 + 0) * 32)));
|
|
1509
|
+
svmopa_za32_f16_m(2, query_predicate_f16x, predicate_all_f16x, probability_column_8_f32x,
|
|
1510
|
+
svld1_f16(predicate_all_f16x,
|
|
1511
|
+
(float16_t const *)(values_block_upper + ((dim_tile + 2) * 8 + 0) * 32)));
|
|
1512
|
+
svmopa_za32_f16_m(3, query_predicate_f16x, predicate_all_f16x, probability_column_8_f32x,
|
|
1513
|
+
svld1_f16(predicate_all_f16x,
|
|
1514
|
+
(float16_t const *)(values_block_upper + ((dim_tile + 3) * 8 + 0) * 32)));
|
|
1515
|
+
svmopa_za32_f16_m(0, query_predicate_f16x, predicate_all_f16x, probability_column_9_f32x,
|
|
1516
|
+
svld1_f16(predicate_all_f16x,
|
|
1517
|
+
(float16_t const *)(values_block_upper + ((dim_tile + 0) * 8 + 1) * 32)));
|
|
1518
|
+
svmopa_za32_f16_m(1, query_predicate_f16x, predicate_all_f16x, probability_column_9_f32x,
|
|
1519
|
+
svld1_f16(predicate_all_f16x,
|
|
1520
|
+
(float16_t const *)(values_block_upper + ((dim_tile + 1) * 8 + 1) * 32)));
|
|
1521
|
+
svmopa_za32_f16_m(2, query_predicate_f16x, predicate_all_f16x, probability_column_9_f32x,
|
|
1522
|
+
svld1_f16(predicate_all_f16x,
|
|
1523
|
+
(float16_t const *)(values_block_upper + ((dim_tile + 2) * 8 + 1) * 32)));
|
|
1524
|
+
svmopa_za32_f16_m(3, query_predicate_f16x, predicate_all_f16x, probability_column_9_f32x,
|
|
1525
|
+
svld1_f16(predicate_all_f16x,
|
|
1526
|
+
(float16_t const *)(values_block_upper + ((dim_tile + 3) * 8 + 1) * 32)));
|
|
1527
|
+
svmopa_za32_f16_m(0, query_predicate_f16x, predicate_all_f16x, probability_column_10_f32x,
|
|
1528
|
+
svld1_f16(predicate_all_f16x,
|
|
1529
|
+
(float16_t const *)(values_block_upper + ((dim_tile + 0) * 8 + 2) * 32)));
|
|
1530
|
+
svmopa_za32_f16_m(1, query_predicate_f16x, predicate_all_f16x, probability_column_10_f32x,
|
|
1531
|
+
svld1_f16(predicate_all_f16x,
|
|
1532
|
+
(float16_t const *)(values_block_upper + ((dim_tile + 1) * 8 + 2) * 32)));
|
|
1533
|
+
svmopa_za32_f16_m(2, query_predicate_f16x, predicate_all_f16x, probability_column_10_f32x,
|
|
1534
|
+
svld1_f16(predicate_all_f16x,
|
|
1535
|
+
(float16_t const *)(values_block_upper + ((dim_tile + 2) * 8 + 2) * 32)));
|
|
1536
|
+
svmopa_za32_f16_m(3, query_predicate_f16x, predicate_all_f16x, probability_column_10_f32x,
|
|
1537
|
+
svld1_f16(predicate_all_f16x,
|
|
1538
|
+
(float16_t const *)(values_block_upper + ((dim_tile + 3) * 8 + 2) * 32)));
|
|
1539
|
+
svmopa_za32_f16_m(0, query_predicate_f16x, predicate_all_f16x, probability_column_11_f32x,
|
|
1540
|
+
svld1_f16(predicate_all_f16x,
|
|
1541
|
+
(float16_t const *)(values_block_upper + ((dim_tile + 0) * 8 + 3) * 32)));
|
|
1542
|
+
svmopa_za32_f16_m(1, query_predicate_f16x, predicate_all_f16x, probability_column_11_f32x,
|
|
1543
|
+
svld1_f16(predicate_all_f16x,
|
|
1544
|
+
(float16_t const *)(values_block_upper + ((dim_tile + 1) * 8 + 3) * 32)));
|
|
1545
|
+
svmopa_za32_f16_m(2, query_predicate_f16x, predicate_all_f16x, probability_column_11_f32x,
|
|
1546
|
+
svld1_f16(predicate_all_f16x,
|
|
1547
|
+
(float16_t const *)(values_block_upper + ((dim_tile + 2) * 8 + 3) * 32)));
|
|
1548
|
+
svmopa_za32_f16_m(3, query_predicate_f16x, predicate_all_f16x, probability_column_11_f32x,
|
|
1549
|
+
svld1_f16(predicate_all_f16x,
|
|
1550
|
+
(float16_t const *)(values_block_upper + ((dim_tile + 3) * 8 + 3) * 32)));
|
|
1551
|
+
svmopa_za32_f16_m(0, query_predicate_f16x, predicate_all_f16x, probability_column_12_f32x,
|
|
1552
|
+
svld1_f16(predicate_all_f16x,
|
|
1553
|
+
(float16_t const *)(values_block_upper + ((dim_tile + 0) * 8 + 4) * 32)));
|
|
1554
|
+
svmopa_za32_f16_m(1, query_predicate_f16x, predicate_all_f16x, probability_column_12_f32x,
|
|
1555
|
+
svld1_f16(predicate_all_f16x,
|
|
1556
|
+
(float16_t const *)(values_block_upper + ((dim_tile + 1) * 8 + 4) * 32)));
|
|
1557
|
+
svmopa_za32_f16_m(2, query_predicate_f16x, predicate_all_f16x, probability_column_12_f32x,
|
|
1558
|
+
svld1_f16(predicate_all_f16x,
|
|
1559
|
+
(float16_t const *)(values_block_upper + ((dim_tile + 2) * 8 + 4) * 32)));
|
|
1560
|
+
svmopa_za32_f16_m(3, query_predicate_f16x, predicate_all_f16x, probability_column_12_f32x,
|
|
1561
|
+
svld1_f16(predicate_all_f16x,
|
|
1562
|
+
(float16_t const *)(values_block_upper + ((dim_tile + 3) * 8 + 4) * 32)));
|
|
1563
|
+
svmopa_za32_f16_m(0, query_predicate_f16x, predicate_all_f16x, probability_column_13_f32x,
|
|
1564
|
+
svld1_f16(predicate_all_f16x,
|
|
1565
|
+
(float16_t const *)(values_block_upper + ((dim_tile + 0) * 8 + 5) * 32)));
|
|
1566
|
+
svmopa_za32_f16_m(1, query_predicate_f16x, predicate_all_f16x, probability_column_13_f32x,
|
|
1567
|
+
svld1_f16(predicate_all_f16x,
|
|
1568
|
+
(float16_t const *)(values_block_upper + ((dim_tile + 1) * 8 + 5) * 32)));
|
|
1569
|
+
svmopa_za32_f16_m(2, query_predicate_f16x, predicate_all_f16x, probability_column_13_f32x,
|
|
1570
|
+
svld1_f16(predicate_all_f16x,
|
|
1571
|
+
(float16_t const *)(values_block_upper + ((dim_tile + 2) * 8 + 5) * 32)));
|
|
1572
|
+
svmopa_za32_f16_m(3, query_predicate_f16x, predicate_all_f16x, probability_column_13_f32x,
|
|
1573
|
+
svld1_f16(predicate_all_f16x,
|
|
1574
|
+
(float16_t const *)(values_block_upper + ((dim_tile + 3) * 8 + 5) * 32)));
|
|
1575
|
+
svmopa_za32_f16_m(0, query_predicate_f16x, predicate_all_f16x, probability_column_14_f32x,
|
|
1576
|
+
svld1_f16(predicate_all_f16x,
|
|
1577
|
+
(float16_t const *)(values_block_upper + ((dim_tile + 0) * 8 + 6) * 32)));
|
|
1578
|
+
svmopa_za32_f16_m(1, query_predicate_f16x, predicate_all_f16x, probability_column_14_f32x,
|
|
1579
|
+
svld1_f16(predicate_all_f16x,
|
|
1580
|
+
(float16_t const *)(values_block_upper + ((dim_tile + 1) * 8 + 6) * 32)));
|
|
1581
|
+
svmopa_za32_f16_m(2, query_predicate_f16x, predicate_all_f16x, probability_column_14_f32x,
|
|
1582
|
+
svld1_f16(predicate_all_f16x,
|
|
1583
|
+
(float16_t const *)(values_block_upper + ((dim_tile + 2) * 8 + 6) * 32)));
|
|
1584
|
+
svmopa_za32_f16_m(3, query_predicate_f16x, predicate_all_f16x, probability_column_14_f32x,
|
|
1585
|
+
svld1_f16(predicate_all_f16x,
|
|
1586
|
+
(float16_t const *)(values_block_upper + ((dim_tile + 3) * 8 + 6) * 32)));
|
|
1587
|
+
svmopa_za32_f16_m(0, query_predicate_f16x, predicate_all_f16x, probability_column_15_f32x,
|
|
1588
|
+
svld1_f16(predicate_all_f16x,
|
|
1589
|
+
(float16_t const *)(values_block_upper + ((dim_tile + 0) * 8 + 7) * 32)));
|
|
1590
|
+
svmopa_za32_f16_m(1, query_predicate_f16x, predicate_all_f16x, probability_column_15_f32x,
|
|
1591
|
+
svld1_f16(predicate_all_f16x,
|
|
1592
|
+
(float16_t const *)(values_block_upper + ((dim_tile + 1) * 8 + 7) * 32)));
|
|
1593
|
+
svmopa_za32_f16_m(2, query_predicate_f16x, predicate_all_f16x, probability_column_15_f32x,
|
|
1594
|
+
svld1_f16(predicate_all_f16x,
|
|
1595
|
+
(float16_t const *)(values_block_upper + ((dim_tile + 2) * 8 + 7) * 32)));
|
|
1596
|
+
svmopa_za32_f16_m(3, query_predicate_f16x, predicate_all_f16x, probability_column_15_f32x,
|
|
1597
|
+
svld1_f16(predicate_all_f16x,
|
|
1598
|
+
(float16_t const *)(values_block_upper + ((dim_tile + 3) * 8 + 7) * 32)));
|
|
1599
|
+
// Read FMOPA result and ADD to output_accumulator
|
|
1600
|
+
for (nk_size_t query_index = 0; query_index < valid_query_count; query_index++) {
|
|
1601
|
+
svst1_f32(
|
|
1602
|
+
predicate_all_f32x, output_accumulator + query_index * head_dim_padded + (dim_tile + 0) * 16,
|
|
1603
|
+
svadd_f32_x(predicate_all_f32x,
|
|
1604
|
+
svld1_f32(predicate_all_f32x,
|
|
1605
|
+
output_accumulator + query_index * head_dim_padded + (dim_tile + 0) * 16),
|
|
1606
|
+
svread_hor_za32_f32_m(svdup_f32(0), predicate_all_f32x, 0, query_index)));
|
|
1607
|
+
svst1_f32(
|
|
1608
|
+
predicate_all_f32x, output_accumulator + query_index * head_dim_padded + (dim_tile + 1) * 16,
|
|
1609
|
+
svadd_f32_x(predicate_all_f32x,
|
|
1610
|
+
svld1_f32(predicate_all_f32x,
|
|
1611
|
+
output_accumulator + query_index * head_dim_padded + (dim_tile + 1) * 16),
|
|
1612
|
+
svread_hor_za32_f32_m(svdup_f32(0), predicate_all_f32x, 1, query_index)));
|
|
1613
|
+
svst1_f32(
|
|
1614
|
+
predicate_all_f32x, output_accumulator + query_index * head_dim_padded + (dim_tile + 2) * 16,
|
|
1615
|
+
svadd_f32_x(predicate_all_f32x,
|
|
1616
|
+
svld1_f32(predicate_all_f32x,
|
|
1617
|
+
output_accumulator + query_index * head_dim_padded + (dim_tile + 2) * 16),
|
|
1618
|
+
svread_hor_za32_f32_m(svdup_f32(0), predicate_all_f32x, 2, query_index)));
|
|
1619
|
+
svst1_f32(
|
|
1620
|
+
predicate_all_f32x, output_accumulator + query_index * head_dim_padded + (dim_tile + 3) * 16,
|
|
1621
|
+
svadd_f32_x(predicate_all_f32x,
|
|
1622
|
+
svld1_f32(predicate_all_f32x,
|
|
1623
|
+
output_accumulator + query_index * head_dim_padded + (dim_tile + 3) * 16),
|
|
1624
|
+
svread_hor_za32_f32_m(svdup_f32(0), predicate_all_f32x, 3, query_index)));
|
|
1625
|
+
}
|
|
1626
|
+
}
|
|
1627
|
+
// Remainder: 1 dim_tile at a time using ZA0
|
|
1628
|
+
for (; dim_tile < dim_tile_count; dim_tile++) {
|
|
1629
|
+
svzero_mask_za(nk_sme_zero_za32_tile_0_);
|
|
1630
|
+
svmopa_za32_f16_m(
|
|
1631
|
+
0, query_predicate_f16x, predicate_all_f16x, probability_column_0_f32x,
|
|
1632
|
+
svld1_f16(predicate_all_f16x, (float16_t const *)(values_block_lower + (dim_tile * 8 + 0) * 32)));
|
|
1633
|
+
svmopa_za32_f16_m(
|
|
1634
|
+
0, query_predicate_f16x, predicate_all_f16x, probability_column_1_f32x,
|
|
1635
|
+
svld1_f16(predicate_all_f16x, (float16_t const *)(values_block_lower + (dim_tile * 8 + 1) * 32)));
|
|
1636
|
+
svmopa_za32_f16_m(
|
|
1637
|
+
0, query_predicate_f16x, predicate_all_f16x, probability_column_2_f32x,
|
|
1638
|
+
svld1_f16(predicate_all_f16x, (float16_t const *)(values_block_lower + (dim_tile * 8 + 2) * 32)));
|
|
1639
|
+
svmopa_za32_f16_m(
|
|
1640
|
+
0, query_predicate_f16x, predicate_all_f16x, probability_column_3_f32x,
|
|
1641
|
+
svld1_f16(predicate_all_f16x, (float16_t const *)(values_block_lower + (dim_tile * 8 + 3) * 32)));
|
|
1642
|
+
svmopa_za32_f16_m(
|
|
1643
|
+
0, query_predicate_f16x, predicate_all_f16x, probability_column_4_f32x,
|
|
1644
|
+
svld1_f16(predicate_all_f16x, (float16_t const *)(values_block_lower + (dim_tile * 8 + 4) * 32)));
|
|
1645
|
+
svmopa_za32_f16_m(
|
|
1646
|
+
0, query_predicate_f16x, predicate_all_f16x, probability_column_5_f32x,
|
|
1647
|
+
svld1_f16(predicate_all_f16x, (float16_t const *)(values_block_lower + (dim_tile * 8 + 5) * 32)));
|
|
1648
|
+
svmopa_za32_f16_m(
|
|
1649
|
+
0, query_predicate_f16x, predicate_all_f16x, probability_column_6_f32x,
|
|
1650
|
+
svld1_f16(predicate_all_f16x, (float16_t const *)(values_block_lower + (dim_tile * 8 + 6) * 32)));
|
|
1651
|
+
svmopa_za32_f16_m(
|
|
1652
|
+
0, query_predicate_f16x, predicate_all_f16x, probability_column_7_f32x,
|
|
1653
|
+
svld1_f16(predicate_all_f16x, (float16_t const *)(values_block_lower + (dim_tile * 8 + 7) * 32)));
|
|
1654
|
+
svmopa_za32_f16_m(
|
|
1655
|
+
0, query_predicate_f16x, predicate_all_f16x, probability_column_8_f32x,
|
|
1656
|
+
svld1_f16(predicate_all_f16x, (float16_t const *)(values_block_upper + (dim_tile * 8 + 0) * 32)));
|
|
1657
|
+
svmopa_za32_f16_m(
|
|
1658
|
+
0, query_predicate_f16x, predicate_all_f16x, probability_column_9_f32x,
|
|
1659
|
+
svld1_f16(predicate_all_f16x, (float16_t const *)(values_block_upper + (dim_tile * 8 + 1) * 32)));
|
|
1660
|
+
svmopa_za32_f16_m(
|
|
1661
|
+
0, query_predicate_f16x, predicate_all_f16x, probability_column_10_f32x,
|
|
1662
|
+
svld1_f16(predicate_all_f16x, (float16_t const *)(values_block_upper + (dim_tile * 8 + 2) * 32)));
|
|
1663
|
+
svmopa_za32_f16_m(
|
|
1664
|
+
0, query_predicate_f16x, predicate_all_f16x, probability_column_11_f32x,
|
|
1665
|
+
svld1_f16(predicate_all_f16x, (float16_t const *)(values_block_upper + (dim_tile * 8 + 3) * 32)));
|
|
1666
|
+
svmopa_za32_f16_m(
|
|
1667
|
+
0, query_predicate_f16x, predicate_all_f16x, probability_column_12_f32x,
|
|
1668
|
+
svld1_f16(predicate_all_f16x, (float16_t const *)(values_block_upper + (dim_tile * 8 + 4) * 32)));
|
|
1669
|
+
svmopa_za32_f16_m(
|
|
1670
|
+
0, query_predicate_f16x, predicate_all_f16x, probability_column_13_f32x,
|
|
1671
|
+
svld1_f16(predicate_all_f16x, (float16_t const *)(values_block_upper + (dim_tile * 8 + 5) * 32)));
|
|
1672
|
+
svmopa_za32_f16_m(
|
|
1673
|
+
0, query_predicate_f16x, predicate_all_f16x, probability_column_14_f32x,
|
|
1674
|
+
svld1_f16(predicate_all_f16x, (float16_t const *)(values_block_upper + (dim_tile * 8 + 6) * 32)));
|
|
1675
|
+
svmopa_za32_f16_m(
|
|
1676
|
+
0, query_predicate_f16x, predicate_all_f16x, probability_column_15_f32x,
|
|
1677
|
+
svld1_f16(predicate_all_f16x, (float16_t const *)(values_block_upper + (dim_tile * 8 + 7) * 32)));
|
|
1678
|
+
for (nk_size_t query_index = 0; query_index < valid_query_count; query_index++)
|
|
1679
|
+
svst1_f32(predicate_all_f32x, output_accumulator + query_index * head_dim_padded + dim_tile * 16,
|
|
1680
|
+
svadd_f32_x(predicate_all_f32x,
|
|
1681
|
+
svld1_f32(predicate_all_f32x,
|
|
1682
|
+
output_accumulator + query_index * head_dim_padded + dim_tile * 16),
|
|
1683
|
+
svread_hor_za32_f32_m(svdup_f32(0), predicate_all_f32x, 0, query_index)));
|
|
1684
|
+
}
|
|
1685
|
+
}
|
|
1686
|
+
}
|
|
1687
|
+
|
|
1688
|
+
// === Bc=16 tail loop (handles remaining KV positions and decode path) ===
|
|
1689
|
+
for (; kv_start < kv_len; kv_start += 16, kv_block_index++) {
|
|
1690
|
+
nk_size_t const valid_kv = ((kv_start + 16) <= kv_len) ? 16 : (kv_len - kv_start);
|
|
1691
|
+
|
|
1692
|
+
// Q×K^T: pure memory→FMOPA, no ZA staging
|
|
1693
|
+
svzero_mask_za(nk_sme_zero_za32_tile_2_);
|
|
1694
|
+
nk_f16_t const *k_block = k + kv_block_index * k_depth_step_count * 32;
|
|
1695
|
+
for (nk_size_t step = 0; step < k_depth_step_count; step++) {
|
|
1696
|
+
svfloat16_t zn = svreinterpret_f16_f32(svld1_f32(predicate_all_f32x, queries_transposed + step * 16));
|
|
1697
|
+
svfloat16_t zm = svld1_f16(predicate_all_f16x, (float16_t const *)(k_block + step * 32));
|
|
1698
|
+
svmopa_za32_f16_m(2, predicate_all_f32x, predicate_all_f32x, zn, zm);
|
|
1699
|
+
}
|
|
1700
|
+
|
|
1701
|
+
// Pass 1: Column-wise max (read ZA2 columns vertically)
|
|
1702
|
+
svfloat32_t scale_16_f32x = svdup_f32(scale);
|
|
1703
|
+
svfloat32_t block_max_16_f32x = svdup_f32(NK_F32_MIN);
|
|
1704
|
+
for (nk_size_t column_index = 0; column_index < 16; column_index++) {
|
|
1705
|
+
svfloat32_t score_column_f32x = svmul_f32_x(
|
|
1706
|
+
predicate_all_f32x, svread_ver_za32_f32_m(svdup_f32(0), predicate_all_f32x, 2, column_index),
|
|
1707
|
+
scale_16_f32x);
|
|
1708
|
+
block_max_16_f32x = svmax_f32_x(predicate_all_f32x, block_max_16_f32x, score_column_f32x);
|
|
1709
|
+
}
|
|
1710
|
+
|
|
1711
|
+
svfloat32_t old_max_f32x = svld1_f32(predicate_all_f32x, row_max);
|
|
1712
|
+
svfloat32_t new_max_f32x = svmax_f32_x(predicate_all_f32x, old_max_f32x, block_max_16_f32x);
|
|
1713
|
+
svfloat32_t correction_f32x = nk_exp_fast_f32_sve_(predicate_all_f32x,
|
|
1714
|
+
svsub_f32_x(predicate_all_f32x, old_max_f32x, new_max_f32x));
|
|
1715
|
+
svbool_t max_changed_16 = svcmplt_f32(predicate_all_f32x, correction_f32x, svdup_f32(1.0f));
|
|
1716
|
+
nk_u32_t max_was_updated_16 = svptest_any(predicate_all_f32x, max_changed_16) ? 1 : 0;
|
|
1717
|
+
svfloat32_t row_sum_corrected_f32x = svld1_f32(predicate_all_f32x, row_sum);
|
|
1718
|
+
if (max_was_updated_16)
|
|
1719
|
+
row_sum_corrected_f32x = svmul_f32_x(predicate_all_f32x, row_sum_corrected_f32x, correction_f32x);
|
|
1720
|
+
NK_ALIGN64 nk_f32_t corrections[16];
|
|
1721
|
+
svst1_f32(predicate_all_f32x, corrections, correction_f32x);
|
|
1722
|
+
|
|
1723
|
+
// Pass 2: Column-wise exp + fused P write + sum (ZA2 → ZA0 columns 0-7)
|
|
1724
|
+
svfloat32_t sum_delta_16_f32x = svdup_f32(0.0f);
|
|
1725
|
+
svzero_mask_za(nk_sme_zero_za32_tile_0_);
|
|
1726
|
+
for (nk_size_t column_index = 0; column_index < 16; column_index += 2) {
|
|
1727
|
+
svfloat32_t score_even_f32x = svmul_f32_x(
|
|
1728
|
+
predicate_all_f32x, svread_ver_za32_f32_m(svdup_f32(0), predicate_all_f32x, 2, column_index),
|
|
1729
|
+
scale_16_f32x);
|
|
1730
|
+
svfloat32_t score_odd_f32x = svmul_f32_x(
|
|
1731
|
+
predicate_all_f32x, svread_ver_za32_f32_m(svdup_f32(0), predicate_all_f32x, 2, column_index + 1),
|
|
1732
|
+
scale_16_f32x);
|
|
1733
|
+
svfloat32_t weight_even_f32x = nk_exp_fast_f32_sve_(
|
|
1734
|
+
predicate_all_f32x, svsub_f32_x(predicate_all_f32x, score_even_f32x, new_max_f32x));
|
|
1735
|
+
svfloat32_t weight_odd_f32x = nk_exp_fast_f32_sve_(
|
|
1736
|
+
predicate_all_f32x, svsub_f32_x(predicate_all_f32x, score_odd_f32x, new_max_f32x));
|
|
1737
|
+
sum_delta_16_f32x = svadd_f32_x(predicate_all_f32x, sum_delta_16_f32x, weight_even_f32x);
|
|
1738
|
+
sum_delta_16_f32x = svadd_f32_x(predicate_all_f32x, sum_delta_16_f32x, weight_odd_f32x);
|
|
1739
|
+
svfloat16_t weight_pair_f16x = svzip1_f16(svcvt_f16_f32_x(predicate_all_f32x, weight_even_f32x),
|
|
1740
|
+
svcvt_f16_f32_x(predicate_all_f32x, weight_odd_f32x));
|
|
1741
|
+
svwrite_ver_za32_f32_m(0, column_index / 2, predicate_all_f32x, svreinterpret_f32_f16(weight_pair_f16x));
|
|
1742
|
+
}
|
|
1743
|
+
row_sum_corrected_f32x = svadd_f32_x(predicate_all_f32x, row_sum_corrected_f32x, sum_delta_16_f32x);
|
|
1744
|
+
svst1_f32(predicate_all_f32x, row_sum, row_sum_corrected_f32x);
|
|
1745
|
+
svst1_f32(predicate_all_f32x, row_max, new_max_f32x);
|
|
1746
|
+
|
|
1747
|
+
if (valid_query_count == 1) {
|
|
1748
|
+
// Decode path: extract f32 weights from ZA0 row 0 using SVE
|
|
1749
|
+
svfloat16_t row0_f16 = svreinterpret_f16_f32(svread_hor_za32_f32_m(svdup_f32(0), predicate_all_f32x, 0, 0));
|
|
1750
|
+
svfloat16_t weights_even_f16 = svuzp1_f16(row0_f16, row0_f16);
|
|
1751
|
+
svfloat16_t weights_odd_f16 = svuzp2_f16(row0_f16, row0_f16);
|
|
1752
|
+
NK_ALIGN64 nk_f32_t decode_weights[16];
|
|
1753
|
+
svst1_f32(svwhilelt_b32(0u, 8u), decode_weights, svcvt_f32_f16_x(svwhilelt_b32(0u, 8u), weights_even_f16));
|
|
1754
|
+
svst1_f32(svwhilelt_b32(0u, 8u), decode_weights + 8,
|
|
1755
|
+
svcvt_f32_f16_x(svwhilelt_b32(0u, 8u), weights_odd_f16));
|
|
1756
|
+
NK_ALIGN64 nk_f32_t decode_weights_ordered[16];
|
|
1757
|
+
for (nk_size_t i = 0; i < 8; i++) {
|
|
1758
|
+
decode_weights_ordered[2 * i] = decode_weights[i];
|
|
1759
|
+
decode_weights_ordered[2 * i + 1] = decode_weights[8 + i];
|
|
1760
|
+
}
|
|
1761
|
+
svfloat32_t corr_f32x = svdup_f32(corrections[0]);
|
|
1762
|
+
for (nk_size_t d = 0; d < head_dim; d += svcntw()) {
|
|
1763
|
+
svbool_t predicate_f32x = svwhilelt_b32_u64(d, head_dim);
|
|
1764
|
+
svfloat32_t acc_f32x = svmul_f32_x(predicate_f32x, svld1_f32(predicate_f32x, output_accumulator + d),
|
|
1765
|
+
corr_f32x);
|
|
1766
|
+
for (nk_size_t ki = 0; ki < valid_kv; ki++) {
|
|
1767
|
+
nk_size_t dim_tile = d / 16, depth_s = ki / 2, sub = ki % 2;
|
|
1768
|
+
nk_f16_t const *v_vec = v_packed +
|
|
1769
|
+
(kv_block_index * dim_tile_count * 8 + dim_tile * 8 + depth_s) * 32;
|
|
1770
|
+
svfloat16_t packed_f16x = svld1_f16(predicate_all_f16x, (float16_t const *)v_vec);
|
|
1771
|
+
svfloat16_t v_selected = (sub == 0) ? svuzp1_f16(packed_f16x, packed_f16x)
|
|
1772
|
+
: svuzp2_f16(packed_f16x, packed_f16x);
|
|
1773
|
+
acc_f32x = svmla_f32_x(predicate_f32x, acc_f32x, svdup_f32(decode_weights_ordered[ki]),
|
|
1774
|
+
svcvt_f32_f16_x(predicate_f32x, v_selected));
|
|
1775
|
+
}
|
|
1776
|
+
svst1_f32(predicate_f32x, output_accumulator + d, acc_f32x);
|
|
1777
|
+
}
|
|
1778
|
+
}
|
|
1779
|
+
else {
|
|
1780
|
+
// Prefill Bc=16: extract P columns, pre-apply correction, add-after P×V
|
|
1781
|
+
svbool_t query_predicate_f16x = svwhilelt_b16_u64(0u, valid_query_count * 2);
|
|
1782
|
+
|
|
1783
|
+
svfloat16_t probability_column_0_f32x = svreinterpret_f16_f32(
|
|
1784
|
+
svread_ver_za32_f32_m(svdup_f32(0), predicate_all_f32x, 0, 0));
|
|
1785
|
+
svfloat16_t probability_column_1_f32x = svreinterpret_f16_f32(
|
|
1786
|
+
svread_ver_za32_f32_m(svdup_f32(0), predicate_all_f32x, 0, 1));
|
|
1787
|
+
svfloat16_t probability_column_2_f32x = svreinterpret_f16_f32(
|
|
1788
|
+
svread_ver_za32_f32_m(svdup_f32(0), predicate_all_f32x, 0, 2));
|
|
1789
|
+
svfloat16_t probability_column_3_f32x = svreinterpret_f16_f32(
|
|
1790
|
+
svread_ver_za32_f32_m(svdup_f32(0), predicate_all_f32x, 0, 3));
|
|
1791
|
+
svfloat16_t probability_column_4_f32x = svreinterpret_f16_f32(
|
|
1792
|
+
svread_ver_za32_f32_m(svdup_f32(0), predicate_all_f32x, 0, 4));
|
|
1793
|
+
svfloat16_t probability_column_5_f32x = svreinterpret_f16_f32(
|
|
1794
|
+
svread_ver_za32_f32_m(svdup_f32(0), predicate_all_f32x, 0, 5));
|
|
1795
|
+
svfloat16_t probability_column_6_f32x = svreinterpret_f16_f32(
|
|
1796
|
+
svread_ver_za32_f32_m(svdup_f32(0), predicate_all_f32x, 0, 6));
|
|
1797
|
+
svfloat16_t probability_column_7_f32x = svreinterpret_f16_f32(
|
|
1798
|
+
svread_ver_za32_f32_m(svdup_f32(0), predicate_all_f32x, 0, 7));
|
|
1799
|
+
|
|
1800
|
+
nk_f16_t const *v_block = v_packed + kv_block_index * dim_tile_count * 8 * 32;
|
|
1801
|
+
|
|
1802
|
+
if (max_was_updated_16) {
|
|
1803
|
+
for (nk_size_t query_index = 0; query_index < valid_query_count; query_index++) {
|
|
1804
|
+
svfloat32_t correction_scalar_f32x = svdup_f32(corrections[query_index]);
|
|
1805
|
+
for (nk_size_t dim_offset = 0; dim_offset < head_dim_padded; dim_offset += 16)
|
|
1806
|
+
svst1_f32(
|
|
1807
|
+
predicate_all_f32x, output_accumulator + query_index * head_dim_padded + dim_offset,
|
|
1808
|
+
svmul_f32_x(predicate_all_f32x,
|
|
1809
|
+
svld1_f32(predicate_all_f32x,
|
|
1810
|
+
output_accumulator + query_index * head_dim_padded + dim_offset),
|
|
1811
|
+
correction_scalar_f32x));
|
|
1812
|
+
}
|
|
1813
|
+
}
|
|
1814
|
+
|
|
1815
|
+
nk_size_t dim_tile = 0;
|
|
1816
|
+
for (; dim_tile + 4 <= dim_tile_count; dim_tile += 4) {
|
|
1817
|
+
svzero_za();
|
|
1818
|
+
svmopa_za32_f16_m(
|
|
1819
|
+
0, query_predicate_f16x, predicate_all_f16x, probability_column_0_f32x,
|
|
1820
|
+
svld1_f16(predicate_all_f16x, (float16_t const *)(v_block + ((dim_tile + 0) * 8 + 0) * 32)));
|
|
1821
|
+
svmopa_za32_f16_m(
|
|
1822
|
+
1, query_predicate_f16x, predicate_all_f16x, probability_column_0_f32x,
|
|
1823
|
+
svld1_f16(predicate_all_f16x, (float16_t const *)(v_block + ((dim_tile + 1) * 8 + 0) * 32)));
|
|
1824
|
+
svmopa_za32_f16_m(
|
|
1825
|
+
2, query_predicate_f16x, predicate_all_f16x, probability_column_0_f32x,
|
|
1826
|
+
svld1_f16(predicate_all_f16x, (float16_t const *)(v_block + ((dim_tile + 2) * 8 + 0) * 32)));
|
|
1827
|
+
svmopa_za32_f16_m(
|
|
1828
|
+
3, query_predicate_f16x, predicate_all_f16x, probability_column_0_f32x,
|
|
1829
|
+
svld1_f16(predicate_all_f16x, (float16_t const *)(v_block + ((dim_tile + 3) * 8 + 0) * 32)));
|
|
1830
|
+
svmopa_za32_f16_m(
|
|
1831
|
+
0, query_predicate_f16x, predicate_all_f16x, probability_column_1_f32x,
|
|
1832
|
+
svld1_f16(predicate_all_f16x, (float16_t const *)(v_block + ((dim_tile + 0) * 8 + 1) * 32)));
|
|
1833
|
+
svmopa_za32_f16_m(
|
|
1834
|
+
1, query_predicate_f16x, predicate_all_f16x, probability_column_1_f32x,
|
|
1835
|
+
svld1_f16(predicate_all_f16x, (float16_t const *)(v_block + ((dim_tile + 1) * 8 + 1) * 32)));
|
|
1836
|
+
svmopa_za32_f16_m(
|
|
1837
|
+
2, query_predicate_f16x, predicate_all_f16x, probability_column_1_f32x,
|
|
1838
|
+
svld1_f16(predicate_all_f16x, (float16_t const *)(v_block + ((dim_tile + 2) * 8 + 1) * 32)));
|
|
1839
|
+
svmopa_za32_f16_m(
|
|
1840
|
+
3, query_predicate_f16x, predicate_all_f16x, probability_column_1_f32x,
|
|
1841
|
+
svld1_f16(predicate_all_f16x, (float16_t const *)(v_block + ((dim_tile + 3) * 8 + 1) * 32)));
|
|
1842
|
+
svmopa_za32_f16_m(
|
|
1843
|
+
0, query_predicate_f16x, predicate_all_f16x, probability_column_2_f32x,
|
|
1844
|
+
svld1_f16(predicate_all_f16x, (float16_t const *)(v_block + ((dim_tile + 0) * 8 + 2) * 32)));
|
|
1845
|
+
svmopa_za32_f16_m(
|
|
1846
|
+
1, query_predicate_f16x, predicate_all_f16x, probability_column_2_f32x,
|
|
1847
|
+
svld1_f16(predicate_all_f16x, (float16_t const *)(v_block + ((dim_tile + 1) * 8 + 2) * 32)));
|
|
1848
|
+
svmopa_za32_f16_m(
|
|
1849
|
+
2, query_predicate_f16x, predicate_all_f16x, probability_column_2_f32x,
|
|
1850
|
+
svld1_f16(predicate_all_f16x, (float16_t const *)(v_block + ((dim_tile + 2) * 8 + 2) * 32)));
|
|
1851
|
+
svmopa_za32_f16_m(
|
|
1852
|
+
3, query_predicate_f16x, predicate_all_f16x, probability_column_2_f32x,
|
|
1853
|
+
svld1_f16(predicate_all_f16x, (float16_t const *)(v_block + ((dim_tile + 3) * 8 + 2) * 32)));
|
|
1854
|
+
svmopa_za32_f16_m(
|
|
1855
|
+
0, query_predicate_f16x, predicate_all_f16x, probability_column_3_f32x,
|
|
1856
|
+
svld1_f16(predicate_all_f16x, (float16_t const *)(v_block + ((dim_tile + 0) * 8 + 3) * 32)));
|
|
1857
|
+
svmopa_za32_f16_m(
|
|
1858
|
+
1, query_predicate_f16x, predicate_all_f16x, probability_column_3_f32x,
|
|
1859
|
+
svld1_f16(predicate_all_f16x, (float16_t const *)(v_block + ((dim_tile + 1) * 8 + 3) * 32)));
|
|
1860
|
+
svmopa_za32_f16_m(
|
|
1861
|
+
2, query_predicate_f16x, predicate_all_f16x, probability_column_3_f32x,
|
|
1862
|
+
svld1_f16(predicate_all_f16x, (float16_t const *)(v_block + ((dim_tile + 2) * 8 + 3) * 32)));
|
|
1863
|
+
svmopa_za32_f16_m(
|
|
1864
|
+
3, query_predicate_f16x, predicate_all_f16x, probability_column_3_f32x,
|
|
1865
|
+
svld1_f16(predicate_all_f16x, (float16_t const *)(v_block + ((dim_tile + 3) * 8 + 3) * 32)));
|
|
1866
|
+
svmopa_za32_f16_m(
|
|
1867
|
+
0, query_predicate_f16x, predicate_all_f16x, probability_column_4_f32x,
|
|
1868
|
+
svld1_f16(predicate_all_f16x, (float16_t const *)(v_block + ((dim_tile + 0) * 8 + 4) * 32)));
|
|
1869
|
+
svmopa_za32_f16_m(
|
|
1870
|
+
1, query_predicate_f16x, predicate_all_f16x, probability_column_4_f32x,
|
|
1871
|
+
svld1_f16(predicate_all_f16x, (float16_t const *)(v_block + ((dim_tile + 1) * 8 + 4) * 32)));
|
|
1872
|
+
svmopa_za32_f16_m(
|
|
1873
|
+
2, query_predicate_f16x, predicate_all_f16x, probability_column_4_f32x,
|
|
1874
|
+
svld1_f16(predicate_all_f16x, (float16_t const *)(v_block + ((dim_tile + 2) * 8 + 4) * 32)));
|
|
1875
|
+
svmopa_za32_f16_m(
|
|
1876
|
+
3, query_predicate_f16x, predicate_all_f16x, probability_column_4_f32x,
|
|
1877
|
+
svld1_f16(predicate_all_f16x, (float16_t const *)(v_block + ((dim_tile + 3) * 8 + 4) * 32)));
|
|
1878
|
+
svmopa_za32_f16_m(
|
|
1879
|
+
0, query_predicate_f16x, predicate_all_f16x, probability_column_5_f32x,
|
|
1880
|
+
svld1_f16(predicate_all_f16x, (float16_t const *)(v_block + ((dim_tile + 0) * 8 + 5) * 32)));
|
|
1881
|
+
svmopa_za32_f16_m(
|
|
1882
|
+
1, query_predicate_f16x, predicate_all_f16x, probability_column_5_f32x,
|
|
1883
|
+
svld1_f16(predicate_all_f16x, (float16_t const *)(v_block + ((dim_tile + 1) * 8 + 5) * 32)));
|
|
1884
|
+
svmopa_za32_f16_m(
|
|
1885
|
+
2, query_predicate_f16x, predicate_all_f16x, probability_column_5_f32x,
|
|
1886
|
+
svld1_f16(predicate_all_f16x, (float16_t const *)(v_block + ((dim_tile + 2) * 8 + 5) * 32)));
|
|
1887
|
+
svmopa_za32_f16_m(
|
|
1888
|
+
3, query_predicate_f16x, predicate_all_f16x, probability_column_5_f32x,
|
|
1889
|
+
svld1_f16(predicate_all_f16x, (float16_t const *)(v_block + ((dim_tile + 3) * 8 + 5) * 32)));
|
|
1890
|
+
svmopa_za32_f16_m(
|
|
1891
|
+
0, query_predicate_f16x, predicate_all_f16x, probability_column_6_f32x,
|
|
1892
|
+
svld1_f16(predicate_all_f16x, (float16_t const *)(v_block + ((dim_tile + 0) * 8 + 6) * 32)));
|
|
1893
|
+
svmopa_za32_f16_m(
|
|
1894
|
+
1, query_predicate_f16x, predicate_all_f16x, probability_column_6_f32x,
|
|
1895
|
+
svld1_f16(predicate_all_f16x, (float16_t const *)(v_block + ((dim_tile + 1) * 8 + 6) * 32)));
|
|
1896
|
+
svmopa_za32_f16_m(
|
|
1897
|
+
2, query_predicate_f16x, predicate_all_f16x, probability_column_6_f32x,
|
|
1898
|
+
svld1_f16(predicate_all_f16x, (float16_t const *)(v_block + ((dim_tile + 2) * 8 + 6) * 32)));
|
|
1899
|
+
svmopa_za32_f16_m(
|
|
1900
|
+
3, query_predicate_f16x, predicate_all_f16x, probability_column_6_f32x,
|
|
1901
|
+
svld1_f16(predicate_all_f16x, (float16_t const *)(v_block + ((dim_tile + 3) * 8 + 6) * 32)));
|
|
1902
|
+
svmopa_za32_f16_m(
|
|
1903
|
+
0, query_predicate_f16x, predicate_all_f16x, probability_column_7_f32x,
|
|
1904
|
+
svld1_f16(predicate_all_f16x, (float16_t const *)(v_block + ((dim_tile + 0) * 8 + 7) * 32)));
|
|
1905
|
+
svmopa_za32_f16_m(
|
|
1906
|
+
1, query_predicate_f16x, predicate_all_f16x, probability_column_7_f32x,
|
|
1907
|
+
svld1_f16(predicate_all_f16x, (float16_t const *)(v_block + ((dim_tile + 1) * 8 + 7) * 32)));
|
|
1908
|
+
svmopa_za32_f16_m(
|
|
1909
|
+
2, query_predicate_f16x, predicate_all_f16x, probability_column_7_f32x,
|
|
1910
|
+
svld1_f16(predicate_all_f16x, (float16_t const *)(v_block + ((dim_tile + 2) * 8 + 7) * 32)));
|
|
1911
|
+
svmopa_za32_f16_m(
|
|
1912
|
+
3, query_predicate_f16x, predicate_all_f16x, probability_column_7_f32x,
|
|
1913
|
+
svld1_f16(predicate_all_f16x, (float16_t const *)(v_block + ((dim_tile + 3) * 8 + 7) * 32)));
|
|
1914
|
+
for (nk_size_t query_index = 0; query_index < valid_query_count; query_index++) {
|
|
1915
|
+
svst1_f32(
|
|
1916
|
+
predicate_all_f32x, output_accumulator + query_index * head_dim_padded + (dim_tile + 0) * 16,
|
|
1917
|
+
svadd_f32_x(predicate_all_f32x,
|
|
1918
|
+
svld1_f32(predicate_all_f32x,
|
|
1919
|
+
output_accumulator + query_index * head_dim_padded + (dim_tile + 0) * 16),
|
|
1920
|
+
svread_hor_za32_f32_m(svdup_f32(0), predicate_all_f32x, 0, query_index)));
|
|
1921
|
+
svst1_f32(
|
|
1922
|
+
predicate_all_f32x, output_accumulator + query_index * head_dim_padded + (dim_tile + 1) * 16,
|
|
1923
|
+
svadd_f32_x(predicate_all_f32x,
|
|
1924
|
+
svld1_f32(predicate_all_f32x,
|
|
1925
|
+
output_accumulator + query_index * head_dim_padded + (dim_tile + 1) * 16),
|
|
1926
|
+
svread_hor_za32_f32_m(svdup_f32(0), predicate_all_f32x, 1, query_index)));
|
|
1927
|
+
svst1_f32(
|
|
1928
|
+
predicate_all_f32x, output_accumulator + query_index * head_dim_padded + (dim_tile + 2) * 16,
|
|
1929
|
+
svadd_f32_x(predicate_all_f32x,
|
|
1930
|
+
svld1_f32(predicate_all_f32x,
|
|
1931
|
+
output_accumulator + query_index * head_dim_padded + (dim_tile + 2) * 16),
|
|
1932
|
+
svread_hor_za32_f32_m(svdup_f32(0), predicate_all_f32x, 2, query_index)));
|
|
1933
|
+
svst1_f32(
|
|
1934
|
+
predicate_all_f32x, output_accumulator + query_index * head_dim_padded + (dim_tile + 3) * 16,
|
|
1935
|
+
svadd_f32_x(predicate_all_f32x,
|
|
1936
|
+
svld1_f32(predicate_all_f32x,
|
|
1937
|
+
output_accumulator + query_index * head_dim_padded + (dim_tile + 3) * 16),
|
|
1938
|
+
svread_hor_za32_f32_m(svdup_f32(0), predicate_all_f32x, 3, query_index)));
|
|
1939
|
+
}
|
|
1940
|
+
}
|
|
1941
|
+
for (; dim_tile < dim_tile_count; dim_tile++) {
|
|
1942
|
+
svzero_mask_za(nk_sme_zero_za32_tile_0_);
|
|
1943
|
+
svmopa_za32_f16_m(
|
|
1944
|
+
0, query_predicate_f16x, predicate_all_f16x, probability_column_0_f32x,
|
|
1945
|
+
svld1_f16(predicate_all_f16x, (float16_t const *)(v_block + (dim_tile * 8 + 0) * 32)));
|
|
1946
|
+
svmopa_za32_f16_m(
|
|
1947
|
+
0, query_predicate_f16x, predicate_all_f16x, probability_column_1_f32x,
|
|
1948
|
+
svld1_f16(predicate_all_f16x, (float16_t const *)(v_block + (dim_tile * 8 + 1) * 32)));
|
|
1949
|
+
svmopa_za32_f16_m(
|
|
1950
|
+
0, query_predicate_f16x, predicate_all_f16x, probability_column_2_f32x,
|
|
1951
|
+
svld1_f16(predicate_all_f16x, (float16_t const *)(v_block + (dim_tile * 8 + 2) * 32)));
|
|
1952
|
+
svmopa_za32_f16_m(
|
|
1953
|
+
0, query_predicate_f16x, predicate_all_f16x, probability_column_3_f32x,
|
|
1954
|
+
svld1_f16(predicate_all_f16x, (float16_t const *)(v_block + (dim_tile * 8 + 3) * 32)));
|
|
1955
|
+
svmopa_za32_f16_m(
|
|
1956
|
+
0, query_predicate_f16x, predicate_all_f16x, probability_column_4_f32x,
|
|
1957
|
+
svld1_f16(predicate_all_f16x, (float16_t const *)(v_block + (dim_tile * 8 + 4) * 32)));
|
|
1958
|
+
svmopa_za32_f16_m(
|
|
1959
|
+
0, query_predicate_f16x, predicate_all_f16x, probability_column_5_f32x,
|
|
1960
|
+
svld1_f16(predicate_all_f16x, (float16_t const *)(v_block + (dim_tile * 8 + 5) * 32)));
|
|
1961
|
+
svmopa_za32_f16_m(
|
|
1962
|
+
0, query_predicate_f16x, predicate_all_f16x, probability_column_6_f32x,
|
|
1963
|
+
svld1_f16(predicate_all_f16x, (float16_t const *)(v_block + (dim_tile * 8 + 6) * 32)));
|
|
1964
|
+
svmopa_za32_f16_m(
|
|
1965
|
+
0, query_predicate_f16x, predicate_all_f16x, probability_column_7_f32x,
|
|
1966
|
+
svld1_f16(predicate_all_f16x, (float16_t const *)(v_block + (dim_tile * 8 + 7) * 32)));
|
|
1967
|
+
for (nk_size_t query_index = 0; query_index < valid_query_count; query_index++)
|
|
1968
|
+
svst1_f32(predicate_all_f32x, output_accumulator + query_index * head_dim_padded + dim_tile * 16,
|
|
1969
|
+
svadd_f32_x(predicate_all_f32x,
|
|
1970
|
+
svld1_f32(predicate_all_f32x,
|
|
1971
|
+
output_accumulator + query_index * head_dim_padded + dim_tile * 16),
|
|
1972
|
+
svread_hor_za32_f32_m(svdup_f32(0), predicate_all_f32x, 0, query_index)));
|
|
1973
|
+
}
|
|
1974
|
+
}
|
|
1975
|
+
}
|
|
1976
|
+
|
|
1977
|
+
// Final normalization
|
|
1978
|
+
svfloat32_t final_sum_f32x = svld1_f32(predicate_all_f32x, row_sum);
|
|
1979
|
+
svfloat32_t ones_f32x = svdup_f32(1.0f);
|
|
1980
|
+
svfloat32_t zeros_f32x = svdup_f32(0.0f);
|
|
1981
|
+
svbool_t sum_positive = svcmpgt_f32(predicate_all_f32x, final_sum_f32x, zeros_f32x);
|
|
1982
|
+
svfloat32_t inv_sum_f32x = svsel_f32(sum_positive, svdiv_f32_x(predicate_all_f32x, ones_f32x, final_sum_f32x),
|
|
1983
|
+
zeros_f32x);
|
|
1984
|
+
|
|
1985
|
+
NK_ALIGN64 nk_f32_t inv_sums[16];
|
|
1986
|
+
svst1_f32(predicate_all_f32x, inv_sums, inv_sum_f32x);
|
|
1987
|
+
|
|
1988
|
+
for (nk_size_t query_index = 0; query_index < valid_query_count; query_index++) {
|
|
1989
|
+
svfloat32_t inv_sum_f32x = svdup_f32(inv_sums[query_index]);
|
|
1990
|
+
for (nk_size_t dim_offset = 0; dim_offset < head_dim; dim_offset += svcntw()) {
|
|
1991
|
+
svbool_t predicate_f32x = svwhilelt_b32_u64(dim_offset, head_dim);
|
|
1992
|
+
svfloat32_t output_f32x = svmul_f32_x(
|
|
1993
|
+
predicate_f32x,
|
|
1994
|
+
svld1_f32(predicate_f32x, output_accumulator + query_index * head_dim_padded + dim_offset),
|
|
1995
|
+
inv_sum_f32x);
|
|
1996
|
+
svfloat16_t output_f16x = svcvt_f16_f32_x(predicate_f32x, output_f32x);
|
|
1997
|
+
nk_size_t store_count = (head_dim - dim_offset) < (nk_size_t)svcntw() ? (head_dim - dim_offset)
|
|
1998
|
+
: (nk_size_t)svcntw();
|
|
1999
|
+
svbool_t predicate_f16x = svwhilelt_b16_u64(0u, store_count);
|
|
2000
|
+
svst1_f16(predicate_f16x, (float16_t *)(output + query_index * head_dim + dim_offset), output_f16x);
|
|
2001
|
+
}
|
|
2002
|
+
}
|
|
2003
|
+
}
|
|
2004
|
+
|
|
2005
|
+
NK_PUBLIC void nk_attention_f16_sme(nk_f16_t const *q, void const *kv_packed, nk_f16_t *output, nk_size_t num_heads,
|
|
2006
|
+
nk_size_t num_kv_heads, nk_size_t query_len, nk_size_t kv_len, nk_size_t head_dim,
|
|
2007
|
+
nk_f32_t scale) {
|
|
2008
|
+
|
|
2009
|
+
nk_attention_sme_packed_header_t const *header = (nk_attention_sme_packed_header_t const *)kv_packed;
|
|
2010
|
+
nk_size_t head_dim_padded = header->head_dim_padded;
|
|
2011
|
+
nk_size_t dim_tile_count = header->reserved[0];
|
|
2012
|
+
nk_size_t kv_blocks = (kv_len + 15) / 16;
|
|
2013
|
+
// K and V both use interleaved format: kv_blocks * 16 * head_dim_padded elements per head
|
|
2014
|
+
nk_size_t kv_head_stride = kv_blocks * 16 * head_dim_padded;
|
|
2015
|
+
|
|
2016
|
+
nk_f16_t const *k_packed = (nk_f16_t const *)((char const *)kv_packed + header->k_offset);
|
|
2017
|
+
nk_f16_t const *v_packed = (nk_f16_t const *)((char const *)kv_packed + header->v_offset);
|
|
2018
|
+
|
|
2019
|
+
nk_size_t group_size = (num_kv_heads > 0) ? num_heads / num_kv_heads : 1;
|
|
2020
|
+
|
|
2021
|
+
for (nk_size_t q_head = 0; q_head < num_heads; q_head++) {
|
|
2022
|
+
nk_size_t kv_head = q_head / group_size;
|
|
2023
|
+
|
|
2024
|
+
nk_f16_t const *q_ptr = q + q_head * query_len * head_dim;
|
|
2025
|
+
nk_f16_t const *k_ptr = k_packed + kv_head * kv_head_stride;
|
|
2026
|
+
nk_f16_t const *v_ptr = v_packed + kv_head * kv_head_stride;
|
|
2027
|
+
nk_f16_t *out_ptr = output + q_head * query_len * head_dim;
|
|
2028
|
+
|
|
2029
|
+
for (nk_size_t q_start = 0; q_start < query_len; q_start += 16) {
|
|
2030
|
+
nk_size_t q_block_len = (q_start + 16 < query_len) ? 16 : (query_len - q_start);
|
|
2031
|
+
|
|
2032
|
+
nk_attention_f16_sme_streaming_(q_ptr + q_start * head_dim, k_ptr, v_ptr, out_ptr + q_start * head_dim,
|
|
2033
|
+
q_block_len, kv_len, head_dim, head_dim_padded, dim_tile_count, scale);
|
|
2034
|
+
}
|
|
2035
|
+
}
|
|
2036
|
+
}
|
|
2037
|
+
|
|
2038
|
+
/**
 *  @brief Causal-attention entry point for bf16 inputs; currently a plain pass-through.
 *
 *  No causal mask is applied here: all arguments are handed unchanged to `nk_attention_bf16_sme`,
 *  so the result is full (unmasked) attention. Per the original author's note this matches causal
 *  attention for decode, where `query_len == 1`.
 */
NK_PUBLIC void nk_attention_causal_bf16_sme(nk_bf16_t const *q, void const *kv_packed, nk_bf16_t *output,
                                            nk_size_t num_heads, nk_size_t num_kv_heads, nk_size_t query_len,
                                            nk_size_t kv_len, nk_size_t head_dim, nk_f32_t scale) {
    // TODO: add real causal masking, skipping KV blocks that are entirely masked out.
    // Until then the unmasked kernel does all the work.
    nk_attention_bf16_sme(q, kv_packed, output, //
                          num_heads, num_kv_heads, //
                          query_len, kv_len, head_dim, //
                          scale);
}
|
|
2045
|
+
|
|
2046
|
+
/**
 *  @brief Causal-attention entry point for f16 inputs; currently a plain pass-through.
 *
 *  No causal mask is applied here: all arguments are handed unchanged to `nk_attention_f16_sme`,
 *  so the result is full (unmasked) attention. Per the original author's note this matches causal
 *  attention for decode, where `query_len == 1`.
 */
NK_PUBLIC void nk_attention_causal_f16_sme(nk_f16_t const *q, void const *kv_packed, nk_f16_t *output,
                                           nk_size_t num_heads, nk_size_t num_kv_heads, nk_size_t query_len,
                                           nk_size_t kv_len, nk_size_t head_dim, nk_f32_t scale) {
    // TODO: add real causal masking, skipping KV blocks that are entirely masked out.
    // Until then the unmasked kernel does all the work.
    nk_attention_f16_sme(q, kv_packed, output, //
                         num_heads, num_kv_heads, //
                         query_len, kv_len, head_dim, //
                         scale);
}
|
|
2053
|
+
|
|
2054
|
+
#if defined(__clang__)
|
|
2055
|
+
#pragma clang attribute pop
|
|
2056
|
+
#elif defined(__GNUC__)
|
|
2057
|
+
#pragma GCC pop_options
|
|
2058
|
+
#endif
|
|
2059
|
+
|
|
2060
|
+
#if defined(__cplusplus)
|
|
2061
|
+
} // extern "C"
|
|
2062
|
+
#endif
|
|
2063
|
+
|
|
2064
|
+
#endif // NK_TARGET_SME
|
|
2065
|
+
#endif // NK_TARGET_ARM_
|
|
2066
|
+
#endif // NK_ATTENTION_SME_H
|