numkong 7.0.0 → 7.4.2
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +197 -124
- package/binding.gyp +34 -484
- package/c/dispatch_bf16.c +59 -1
- package/c/dispatch_e2m3.c +41 -8
- package/c/dispatch_e3m2.c +49 -8
- package/c/dispatch_e4m3.c +51 -9
- package/c/dispatch_e5m2.c +45 -1
- package/c/dispatch_f16.c +79 -26
- package/c/dispatch_f16c.c +5 -5
- package/c/dispatch_f32.c +56 -0
- package/c/dispatch_f64.c +52 -0
- package/c/dispatch_i4.c +3 -0
- package/c/dispatch_i8.c +62 -3
- package/c/dispatch_other.c +18 -0
- package/c/dispatch_u1.c +54 -9
- package/c/dispatch_u4.c +3 -0
- package/c/dispatch_u8.c +64 -3
- package/c/numkong.c +3 -0
- package/include/README.md +79 -9
- package/include/numkong/attention/sapphireamx.h +278 -276
- package/include/numkong/attention/sme.h +983 -977
- package/include/numkong/attention.h +1 -1
- package/include/numkong/capabilities.h +289 -94
- package/include/numkong/cast/README.md +40 -40
- package/include/numkong/cast/diamond.h +64 -0
- package/include/numkong/cast/haswell.h +42 -194
- package/include/numkong/cast/icelake.h +42 -37
- package/include/numkong/cast/loongsonasx.h +252 -0
- package/include/numkong/cast/neon.h +216 -249
- package/include/numkong/cast/powervsx.h +449 -0
- package/include/numkong/cast/rvv.h +223 -274
- package/include/numkong/cast/sapphire.h +18 -18
- package/include/numkong/cast/serial.h +1018 -944
- package/include/numkong/cast/skylake.h +82 -23
- package/include/numkong/cast/v128relaxed.h +462 -105
- package/include/numkong/cast.h +24 -0
- package/include/numkong/cast.hpp +44 -0
- package/include/numkong/curved/README.md +17 -17
- package/include/numkong/curved/neon.h +131 -7
- package/include/numkong/curved/neonbfdot.h +6 -7
- package/include/numkong/curved/rvv.h +26 -26
- package/include/numkong/curved/smef64.h +186 -182
- package/include/numkong/curved.h +14 -18
- package/include/numkong/dot/README.md +154 -137
- package/include/numkong/dot/alder.h +43 -43
- package/include/numkong/dot/diamond.h +158 -0
- package/include/numkong/dot/genoa.h +4 -30
- package/include/numkong/dot/haswell.h +215 -180
- package/include/numkong/dot/icelake.h +190 -76
- package/include/numkong/dot/loongsonasx.h +671 -0
- package/include/numkong/dot/neon.h +124 -73
- package/include/numkong/dot/neonbfdot.h +11 -12
- package/include/numkong/dot/neonfhm.h +44 -46
- package/include/numkong/dot/neonfp8.h +323 -0
- package/include/numkong/dot/neonsdot.h +190 -76
- package/include/numkong/dot/powervsx.h +752 -0
- package/include/numkong/dot/rvv.h +92 -84
- package/include/numkong/dot/rvvbf16.h +12 -12
- package/include/numkong/dot/rvvhalf.h +12 -12
- package/include/numkong/dot/sapphire.h +4 -4
- package/include/numkong/dot/serial.h +66 -30
- package/include/numkong/dot/sierra.h +31 -31
- package/include/numkong/dot/skylake.h +142 -110
- package/include/numkong/dot/sve.h +217 -177
- package/include/numkong/dot/svebfdot.h +10 -10
- package/include/numkong/dot/svehalf.h +85 -41
- package/include/numkong/dot/svesdot.h +89 -0
- package/include/numkong/dot/v128relaxed.h +124 -89
- package/include/numkong/dot.h +114 -48
- package/include/numkong/dots/README.md +203 -203
- package/include/numkong/dots/alder.h +12 -9
- package/include/numkong/dots/diamond.h +86 -0
- package/include/numkong/dots/genoa.h +10 -4
- package/include/numkong/dots/haswell.h +63 -48
- package/include/numkong/dots/icelake.h +27 -18
- package/include/numkong/dots/loongsonasx.h +176 -0
- package/include/numkong/dots/neon.h +14 -11
- package/include/numkong/dots/neonbfdot.h +4 -3
- package/include/numkong/dots/neonfhm.h +11 -9
- package/include/numkong/dots/neonfp8.h +99 -0
- package/include/numkong/dots/neonsdot.h +48 -12
- package/include/numkong/dots/powervsx.h +194 -0
- package/include/numkong/dots/rvv.h +451 -344
- package/include/numkong/dots/sapphireamx.h +1028 -984
- package/include/numkong/dots/serial.h +213 -197
- package/include/numkong/dots/sierra.h +10 -7
- package/include/numkong/dots/skylake.h +47 -36
- package/include/numkong/dots/sme.h +2001 -2364
- package/include/numkong/dots/smebi32.h +175 -162
- package/include/numkong/dots/smef64.h +328 -323
- package/include/numkong/dots/v128relaxed.h +64 -41
- package/include/numkong/dots.h +573 -293
- package/include/numkong/dots.hpp +45 -43
- package/include/numkong/each/README.md +133 -137
- package/include/numkong/each/haswell.h +6 -6
- package/include/numkong/each/icelake.h +7 -7
- package/include/numkong/each/neon.h +76 -42
- package/include/numkong/each/neonbfdot.h +11 -12
- package/include/numkong/each/neonhalf.h +24 -116
- package/include/numkong/each/rvv.h +28 -28
- package/include/numkong/each/sapphire.h +27 -161
- package/include/numkong/each/serial.h +6 -6
- package/include/numkong/each/skylake.h +7 -7
- package/include/numkong/each/v128relaxed.h +562 -0
- package/include/numkong/each.h +148 -62
- package/include/numkong/each.hpp +2 -2
- package/include/numkong/geospatial/README.md +18 -18
- package/include/numkong/geospatial/haswell.h +365 -325
- package/include/numkong/geospatial/neon.h +350 -306
- package/include/numkong/geospatial/rvv.h +4 -4
- package/include/numkong/geospatial/skylake.h +376 -340
- package/include/numkong/geospatial/v128relaxed.h +366 -327
- package/include/numkong/geospatial.h +17 -17
- package/include/numkong/matrix.hpp +4 -4
- package/include/numkong/maxsim/README.md +14 -14
- package/include/numkong/maxsim/alder.h +6 -6
- package/include/numkong/maxsim/genoa.h +4 -4
- package/include/numkong/maxsim/haswell.h +6 -6
- package/include/numkong/maxsim/icelake.h +18 -18
- package/include/numkong/maxsim/neonsdot.h +21 -21
- package/include/numkong/maxsim/sapphireamx.h +14 -14
- package/include/numkong/maxsim/serial.h +6 -6
- package/include/numkong/maxsim/sme.h +221 -196
- package/include/numkong/maxsim/v128relaxed.h +6 -6
- package/include/numkong/mesh/README.md +62 -56
- package/include/numkong/mesh/haswell.h +339 -464
- package/include/numkong/mesh/neon.h +1100 -519
- package/include/numkong/mesh/neonbfdot.h +36 -68
- package/include/numkong/mesh/rvv.h +530 -435
- package/include/numkong/mesh/serial.h +75 -91
- package/include/numkong/mesh/skylake.h +1627 -302
- package/include/numkong/mesh/v128relaxed.h +443 -330
- package/include/numkong/mesh.h +63 -49
- package/include/numkong/mesh.hpp +4 -4
- package/include/numkong/numkong.h +3 -3
- package/include/numkong/numkong.hpp +1 -0
- package/include/numkong/probability/README.md +23 -19
- package/include/numkong/probability/neon.h +82 -52
- package/include/numkong/probability/rvv.h +28 -23
- package/include/numkong/probability/serial.h +51 -39
- package/include/numkong/probability.h +20 -23
- package/include/numkong/random.h +1 -1
- package/include/numkong/reduce/README.md +143 -138
- package/include/numkong/reduce/alder.h +81 -77
- package/include/numkong/reduce/haswell.h +222 -220
- package/include/numkong/reduce/neon.h +629 -519
- package/include/numkong/reduce/neonbfdot.h +7 -218
- package/include/numkong/reduce/neonfhm.h +9 -381
- package/include/numkong/reduce/neonsdot.h +9 -9
- package/include/numkong/reduce/rvv.h +928 -802
- package/include/numkong/reduce/serial.h +23 -27
- package/include/numkong/reduce/sierra.h +20 -20
- package/include/numkong/reduce/skylake.h +326 -324
- package/include/numkong/reduce/v128relaxed.h +52 -52
- package/include/numkong/reduce.h +4 -23
- package/include/numkong/reduce.hpp +156 -11
- package/include/numkong/scalar/README.md +6 -6
- package/include/numkong/scalar/haswell.h +26 -17
- package/include/numkong/scalar/loongsonasx.h +74 -0
- package/include/numkong/scalar/neon.h +9 -9
- package/include/numkong/scalar/powervsx.h +96 -0
- package/include/numkong/scalar/rvv.h +2 -2
- package/include/numkong/scalar/sapphire.h +21 -10
- package/include/numkong/scalar/serial.h +21 -21
- package/include/numkong/scalar.h +13 -0
- package/include/numkong/set/README.md +28 -28
- package/include/numkong/set/haswell.h +12 -12
- package/include/numkong/set/icelake.h +14 -14
- package/include/numkong/set/loongsonasx.h +181 -0
- package/include/numkong/set/neon.h +17 -18
- package/include/numkong/set/powervsx.h +326 -0
- package/include/numkong/set/rvv.h +4 -4
- package/include/numkong/set/serial.h +6 -6
- package/include/numkong/set/sve.h +60 -59
- package/include/numkong/set/v128relaxed.h +6 -6
- package/include/numkong/set.h +21 -7
- package/include/numkong/sets/README.md +26 -26
- package/include/numkong/sets/loongsonasx.h +52 -0
- package/include/numkong/sets/powervsx.h +65 -0
- package/include/numkong/sets/smebi32.h +395 -364
- package/include/numkong/sets.h +83 -40
- package/include/numkong/sparse/README.md +4 -4
- package/include/numkong/sparse/icelake.h +101 -101
- package/include/numkong/sparse/serial.h +1 -1
- package/include/numkong/sparse/sve2.h +137 -141
- package/include/numkong/sparse/turin.h +12 -12
- package/include/numkong/sparse.h +10 -10
- package/include/numkong/spatial/README.md +230 -226
- package/include/numkong/spatial/alder.h +113 -116
- package/include/numkong/spatial/diamond.h +240 -0
- package/include/numkong/spatial/genoa.h +0 -68
- package/include/numkong/spatial/haswell.h +74 -55
- package/include/numkong/spatial/icelake.h +539 -58
- package/include/numkong/spatial/loongsonasx.h +483 -0
- package/include/numkong/spatial/neon.h +125 -52
- package/include/numkong/spatial/neonbfdot.h +8 -9
- package/include/numkong/spatial/neonfp8.h +258 -0
- package/include/numkong/spatial/neonsdot.h +180 -12
- package/include/numkong/spatial/powervsx.h +738 -0
- package/include/numkong/spatial/rvv.h +146 -139
- package/include/numkong/spatial/rvvbf16.h +17 -12
- package/include/numkong/spatial/rvvhalf.h +13 -10
- package/include/numkong/spatial/serial.h +13 -12
- package/include/numkong/spatial/sierra.h +232 -39
- package/include/numkong/spatial/skylake.h +73 -74
- package/include/numkong/spatial/sve.h +93 -72
- package/include/numkong/spatial/svebfdot.h +29 -29
- package/include/numkong/spatial/svehalf.h +52 -26
- package/include/numkong/spatial/svesdot.h +142 -0
- package/include/numkong/spatial/v128relaxed.h +293 -41
- package/include/numkong/spatial.h +338 -82
- package/include/numkong/spatials/README.md +194 -194
- package/include/numkong/spatials/diamond.h +82 -0
- package/include/numkong/spatials/haswell.h +2 -2
- package/include/numkong/spatials/loongsonasx.h +153 -0
- package/include/numkong/spatials/neonfp8.h +111 -0
- package/include/numkong/spatials/neonsdot.h +34 -0
- package/include/numkong/spatials/powervsx.h +153 -0
- package/include/numkong/spatials/rvv.h +259 -243
- package/include/numkong/spatials/sapphireamx.h +173 -173
- package/include/numkong/spatials/serial.h +2 -2
- package/include/numkong/spatials/skylake.h +2 -2
- package/include/numkong/spatials/sme.h +590 -605
- package/include/numkong/spatials/smef64.h +139 -130
- package/include/numkong/spatials/v128relaxed.h +2 -2
- package/include/numkong/spatials.h +820 -500
- package/include/numkong/spatials.hpp +49 -48
- package/include/numkong/tensor.hpp +406 -17
- package/include/numkong/trigonometry/README.md +19 -19
- package/include/numkong/trigonometry/haswell.h +402 -401
- package/include/numkong/trigonometry/neon.h +386 -387
- package/include/numkong/trigonometry/rvv.h +52 -51
- package/include/numkong/trigonometry/serial.h +13 -13
- package/include/numkong/trigonometry/skylake.h +373 -369
- package/include/numkong/trigonometry/v128relaxed.h +375 -374
- package/include/numkong/trigonometry.h +13 -13
- package/include/numkong/trigonometry.hpp +2 -2
- package/include/numkong/types.h +287 -49
- package/include/numkong/types.hpp +436 -12
- package/include/numkong/vector.hpp +82 -14
- package/javascript/dist/cjs/numkong-wasm.js +6 -12
- package/javascript/dist/cjs/numkong.d.ts +7 -1
- package/javascript/dist/cjs/numkong.js +37 -11
- package/javascript/dist/cjs/types.d.ts +9 -0
- package/javascript/dist/cjs/types.js +96 -0
- package/javascript/dist/esm/numkong-browser.d.ts +14 -0
- package/javascript/dist/esm/numkong-browser.js +23 -0
- package/javascript/dist/esm/numkong-wasm.js +6 -12
- package/javascript/dist/esm/numkong.d.ts +7 -1
- package/javascript/dist/esm/numkong.js +37 -11
- package/javascript/dist/esm/types.d.ts +9 -0
- package/javascript/dist/esm/types.js +96 -0
- package/javascript/node-gyp-build.d.ts +4 -1
- package/javascript/numkong-browser.ts +40 -0
- package/javascript/numkong-wasm.ts +7 -13
- package/javascript/numkong.c +5 -26
- package/javascript/numkong.ts +36 -11
- package/javascript/tsconfig-base.json +1 -0
- package/javascript/tsconfig-cjs.json +6 -1
- package/javascript/types.ts +110 -0
- package/numkong.gypi +101 -0
- package/package.json +34 -13
- package/probes/arm_neon.c +8 -0
- package/probes/arm_neon_bfdot.c +9 -0
- package/probes/arm_neon_fhm.c +9 -0
- package/probes/arm_neon_half.c +8 -0
- package/probes/arm_neon_sdot.c +9 -0
- package/probes/arm_neonfp8.c +9 -0
- package/probes/arm_sme.c +16 -0
- package/probes/arm_sme2.c +16 -0
- package/probes/arm_sme2p1.c +16 -0
- package/probes/arm_sme_bf16.c +16 -0
- package/probes/arm_sme_bi32.c +16 -0
- package/probes/arm_sme_f64.c +16 -0
- package/probes/arm_sme_fa64.c +14 -0
- package/probes/arm_sme_half.c +16 -0
- package/probes/arm_sme_lut2.c +15 -0
- package/probes/arm_sve.c +18 -0
- package/probes/arm_sve2.c +20 -0
- package/probes/arm_sve2p1.c +18 -0
- package/probes/arm_sve_bfdot.c +20 -0
- package/probes/arm_sve_half.c +18 -0
- package/probes/arm_sve_sdot.c +21 -0
- package/probes/loongarch_lasx.c +12 -0
- package/probes/power_vsx.c +12 -0
- package/probes/probe.js +127 -0
- package/probes/riscv_rvv.c +14 -0
- package/probes/riscv_rvv_bb.c +15 -0
- package/probes/riscv_rvv_bf16.c +17 -0
- package/probes/riscv_rvv_half.c +14 -0
- package/probes/wasm_v128relaxed.c +11 -0
- package/probes/x86_alder.c +17 -0
- package/probes/x86_diamond.c +17 -0
- package/probes/x86_genoa.c +17 -0
- package/probes/x86_graniteamx.c +19 -0
- package/probes/x86_haswell.c +11 -0
- package/probes/x86_icelake.c +17 -0
- package/probes/x86_sapphire.c +16 -0
- package/probes/x86_sapphireamx.c +18 -0
- package/probes/x86_sierra.c +17 -0
- package/probes/x86_skylake.c +15 -0
- package/probes/x86_turin.c +17 -0
- package/wasm/numkong-emscripten.js +2 -0
- package/wasm/numkong.d.ts +14 -0
- package/wasm/numkong.js +1124 -0
- package/wasm/numkong.wasm +0 -0
- package/include/numkong/curved/neonhalf.h +0 -212
- package/include/numkong/dot/neonhalf.h +0 -198
- package/include/numkong/dots/neonhalf.h +0 -57
- package/include/numkong/mesh/neonhalf.h +0 -616
- package/include/numkong/reduce/neonhalf.h +0 -157
- package/include/numkong/spatial/neonhalf.h +0 -118
- package/include/numkong/spatial/sapphire.h +0 -343
- package/include/numkong/spatials/neonhalf.h +0 -58
- package/javascript/README.md +0 -246
|
@@ -101,7 +101,7 @@ extern "C" {
|
|
|
101
101
|
#endif
|
|
102
102
|
|
|
103
103
|
#if defined(__clang__)
|
|
104
|
-
#pragma clang attribute push(__attribute__((target("sme
|
|
104
|
+
#pragma clang attribute push(__attribute__((target("sme"))), apply_to = function)
|
|
105
105
|
#elif defined(__GNUC__)
|
|
106
106
|
#pragma GCC push_options
|
|
107
107
|
#pragma GCC target("+sme")
|
|
@@ -116,10 +116,10 @@ extern "C" {
|
|
|
116
116
|
* 3. Shift left by 16 to place in f32 exponent+mantissa position
|
|
117
117
|
* 4. Reinterpret as f32
|
|
118
118
|
*/
|
|
119
|
-
NK_INTERNAL svfloat32_t nk_bf16_to_f32_sve_(svbool_t
|
|
119
|
+
NK_INTERNAL svfloat32_t nk_bf16_to_f32_sve_(svbool_t predicate_b32x, svbfloat16_t x_bf16x) __arm_streaming {
|
|
120
120
|
svuint16_t x_u16x = svreinterpret_u16_bf16(x_bf16x);
|
|
121
121
|
svuint32_t x_u32x = svunpklo_u32(x_u16x);
|
|
122
|
-
x_u32x = svlsl_n_u32_x(
|
|
122
|
+
x_u32x = svlsl_n_u32_x(predicate_b32x, x_u32x, 16);
|
|
123
123
|
return svreinterpret_f32_u32(x_u32x);
|
|
124
124
|
}
|
|
125
125
|
|
|
@@ -131,10 +131,10 @@ NK_INTERNAL svfloat32_t nk_bf16_to_f32_sve_(svbool_t predicate_f32x, svbfloat16_
|
|
|
131
131
|
* 3. Shift right by 16
|
|
132
132
|
* 4. Narrow to u16 and reinterpret as bf16
|
|
133
133
|
*/
|
|
134
|
-
NK_INTERNAL svbfloat16_t nk_f32_to_bf16_sve_(svbool_t
|
|
134
|
+
NK_INTERNAL svbfloat16_t nk_f32_to_bf16_sve_(svbool_t predicate_b32x, svfloat32_t x_f32x) __arm_streaming {
|
|
135
135
|
svuint32_t x_u32x = svreinterpret_u32_f32(x_f32x);
|
|
136
|
-
x_u32x = svadd_n_u32_x(
|
|
137
|
-
x_u32x = svlsr_n_u32_x(
|
|
136
|
+
x_u32x = svadd_n_u32_x(predicate_b32x, x_u32x, 0x8000); // Round to nearest
|
|
137
|
+
x_u32x = svlsr_n_u32_x(predicate_b32x, x_u32x, 16);
|
|
138
138
|
svuint16_t x_u16x = svuzp1_u16(svreinterpret_u16_u32(x_u32x), svreinterpret_u16_u32(x_u32x));
|
|
139
139
|
return svreinterpret_bf16_u16(x_u16x);
|
|
140
140
|
}
|
|
@@ -166,71 +166,71 @@ typedef struct {
|
|
|
166
166
|
* @param x Input vector
|
|
167
167
|
* @return exp(x) approximation
|
|
168
168
|
*/
|
|
169
|
-
NK_INTERNAL svfloat32_t nk_exp_f32_sve_(svbool_t
|
|
169
|
+
NK_INTERNAL svfloat32_t nk_exp_f32_sve_(svbool_t predicate_b32x, svfloat32_t x_f32x) __arm_streaming {
|
|
170
170
|
// Constants for Cody-Waite range reduction
|
|
171
171
|
svfloat32_t log2e_f32x = svdup_f32(1.4426950408889634f);
|
|
172
|
-
svfloat32_t
|
|
173
|
-
svfloat32_t
|
|
172
|
+
svfloat32_t ln2_high_f32x = svdup_f32(0.693145751953125f);
|
|
173
|
+
svfloat32_t ln2_low_f32x = svdup_f32(1.42860682030941723212e-6f);
|
|
174
174
|
|
|
175
175
|
// Clamp to avoid overflow/underflow
|
|
176
176
|
svfloat32_t max_x_f32x = svdup_f32(88.3762626647949f);
|
|
177
177
|
svfloat32_t min_x_f32x = svdup_f32(-87.3365447504021f);
|
|
178
|
-
x_f32x = svmax_f32_m(
|
|
178
|
+
x_f32x = svmax_f32_m(predicate_b32x, svmin_f32_m(predicate_b32x, x_f32x, max_x_f32x), min_x_f32x);
|
|
179
179
|
|
|
180
180
|
// n = round(x / ln(2))
|
|
181
|
-
svfloat32_t n_f32x = svrintn_f32_m(svundef_f32(),
|
|
181
|
+
svfloat32_t n_f32x = svrintn_f32_m(svundef_f32(), predicate_b32x, svmul_f32_m(predicate_b32x, x_f32x, log2e_f32x));
|
|
182
182
|
|
|
183
183
|
// r = x - n × ln(2) using Cody-Waite for precision
|
|
184
|
-
svfloat32_t r_f32x = svmsb_f32_m(
|
|
185
|
-
r_f32x = svmsb_f32_m(
|
|
184
|
+
svfloat32_t r_f32x = svmsb_f32_m(predicate_b32x, n_f32x, ln2_high_f32x, x_f32x);
|
|
185
|
+
r_f32x = svmsb_f32_m(predicate_b32x, n_f32x, ln2_low_f32x, r_f32x);
|
|
186
186
|
|
|
187
187
|
// Polynomial approximation for exp(r): degree 4
|
|
188
188
|
// exp(r) ≈ 1 + r + r²/2 + r³/6 + r⁴/24
|
|
189
189
|
svfloat32_t p_f32x = svdup_f32(4.1666666667e-2f); // 1/24
|
|
190
|
-
p_f32x = svmad_f32_m(
|
|
191
|
-
p_f32x = svmad_f32_m(
|
|
192
|
-
p_f32x = svmad_f32_m(
|
|
193
|
-
p_f32x = svmad_f32_m(
|
|
190
|
+
p_f32x = svmad_f32_m(predicate_b32x, p_f32x, r_f32x, svdup_f32(1.6666666667e-1f)); // 1/6
|
|
191
|
+
p_f32x = svmad_f32_m(predicate_b32x, p_f32x, r_f32x, svdup_f32(5.0000000000e-1f)); // 1/2
|
|
192
|
+
p_f32x = svmad_f32_m(predicate_b32x, p_f32x, r_f32x, svdup_f32(1.0f)); // 1
|
|
193
|
+
p_f32x = svmad_f32_m(predicate_b32x, p_f32x, r_f32x, svdup_f32(1.0f)); // 1
|
|
194
194
|
|
|
195
195
|
// Reconstruct: exp(x) = 2ⁿ × exp(r)
|
|
196
196
|
// 2ⁿ via IEEE 754 exponent manipulation
|
|
197
|
-
svint32_t n_i32x = svcvt_s32_f32_m(svundef_s32(),
|
|
198
|
-
n_i32x = svadd_s32_m(
|
|
199
|
-
n_i32x = svlsl_n_s32_m(
|
|
197
|
+
svint32_t n_i32x = svcvt_s32_f32_m(svundef_s32(), predicate_b32x, n_f32x);
|
|
198
|
+
n_i32x = svadd_s32_m(predicate_b32x, n_i32x, svdup_s32(127));
|
|
199
|
+
n_i32x = svlsl_n_s32_m(predicate_b32x, n_i32x, 23);
|
|
200
200
|
svfloat32_t pow2n_f32x = svreinterpret_f32_s32(n_i32x);
|
|
201
201
|
|
|
202
|
-
return svmul_f32_m(
|
|
202
|
+
return svmul_f32_m(predicate_b32x, p_f32x, pow2n_f32x);
|
|
203
203
|
}
|
|
204
204
|
|
|
205
205
|
/**
|
|
206
206
|
* @brief Degree-3 fast exp approximation. Max relative error ~0.5%.
|
|
207
207
|
* Saves 1 FMA per call vs degree-4 nk_exp_f32_sve_.
|
|
208
208
|
*/
|
|
209
|
-
NK_INTERNAL svfloat32_t nk_exp_fast_f32_sve_(svbool_t
|
|
209
|
+
NK_INTERNAL svfloat32_t nk_exp_fast_f32_sve_(svbool_t predicate_b32x, svfloat32_t x_f32x) __arm_streaming {
|
|
210
210
|
svfloat32_t log2e_f32x = svdup_f32(1.4426950408889634f);
|
|
211
|
-
svfloat32_t
|
|
212
|
-
svfloat32_t
|
|
211
|
+
svfloat32_t ln2_high_f32x = svdup_f32(0.693145751953125f);
|
|
212
|
+
svfloat32_t ln2_low_f32x = svdup_f32(1.42860682030941723212e-6f);
|
|
213
213
|
|
|
214
214
|
svfloat32_t max_x_f32x = svdup_f32(88.3762626647949f);
|
|
215
215
|
svfloat32_t min_x_f32x = svdup_f32(-87.3365447504021f);
|
|
216
|
-
x_f32x = svmax_f32_m(
|
|
216
|
+
x_f32x = svmax_f32_m(predicate_b32x, svmin_f32_m(predicate_b32x, x_f32x, max_x_f32x), min_x_f32x);
|
|
217
217
|
|
|
218
|
-
svfloat32_t n_f32x = svrintn_f32_m(svundef_f32(),
|
|
219
|
-
svfloat32_t r_f32x = svmsb_f32_m(
|
|
220
|
-
r_f32x = svmsb_f32_m(
|
|
218
|
+
svfloat32_t n_f32x = svrintn_f32_m(svundef_f32(), predicate_b32x, svmul_f32_m(predicate_b32x, x_f32x, log2e_f32x));
|
|
219
|
+
svfloat32_t r_f32x = svmsb_f32_m(predicate_b32x, n_f32x, ln2_high_f32x, x_f32x);
|
|
220
|
+
r_f32x = svmsb_f32_m(predicate_b32x, n_f32x, ln2_low_f32x, r_f32x);
|
|
221
221
|
|
|
222
222
|
// Degree-3: exp(r) ~ 1 + r + r^2/2 + r^3/6 (drop 1/24 term)
|
|
223
223
|
svfloat32_t p_f32x = svdup_f32(1.6666666667e-1f); // 1/6
|
|
224
|
-
p_f32x = svmad_f32_m(
|
|
225
|
-
p_f32x = svmad_f32_m(
|
|
226
|
-
p_f32x = svmad_f32_m(
|
|
224
|
+
p_f32x = svmad_f32_m(predicate_b32x, p_f32x, r_f32x, svdup_f32(5.0000000000e-1f)); // 1/2
|
|
225
|
+
p_f32x = svmad_f32_m(predicate_b32x, p_f32x, r_f32x, svdup_f32(1.0f)); // 1
|
|
226
|
+
p_f32x = svmad_f32_m(predicate_b32x, p_f32x, r_f32x, svdup_f32(1.0f)); // 1
|
|
227
227
|
|
|
228
|
-
svint32_t n_i32x = svcvt_s32_f32_m(svundef_s32(),
|
|
229
|
-
n_i32x = svadd_s32_m(
|
|
230
|
-
n_i32x = svlsl_n_s32_m(
|
|
228
|
+
svint32_t n_i32x = svcvt_s32_f32_m(svundef_s32(), predicate_b32x, n_f32x);
|
|
229
|
+
n_i32x = svadd_s32_m(predicate_b32x, n_i32x, svdup_s32(127));
|
|
230
|
+
n_i32x = svlsl_n_s32_m(predicate_b32x, n_i32x, 23);
|
|
231
231
|
svfloat32_t pow2n_f32x = svreinterpret_f32_s32(n_i32x);
|
|
232
232
|
|
|
233
|
-
return svmul_f32_m(
|
|
233
|
+
return svmul_f32_m(predicate_b32x, p_f32x, pow2n_f32x);
|
|
234
234
|
}
|
|
235
235
|
|
|
236
236
|
NK_PUBLIC nk_size_t nk_attention_packed_kv_size_bf16_sme(nk_size_t num_kv_heads, nk_size_t head_dim,
|
|
@@ -410,8 +410,8 @@ __arm_locally_streaming __arm_new("za") static void nk_attention_bf16_sme_stream
|
|
|
410
410
|
nk_size_t query_len, nk_size_t kv_len, nk_size_t head_dim, nk_size_t head_dim_padded, nk_size_t dim_tile_count,
|
|
411
411
|
nk_f32_t scale) {
|
|
412
412
|
|
|
413
|
-
svbool_t const
|
|
414
|
-
svbool_t const
|
|
413
|
+
svbool_t const predicate_all_b32x = svptrue_b32();
|
|
414
|
+
svbool_t const predicate_all_b16x = svptrue_b16();
|
|
415
415
|
nk_size_t const valid_query_count = (query_len < 16) ? query_len : 16;
|
|
416
416
|
|
|
417
417
|
svfloat32_t row_max_f32x = svdup_f32(NK_F32_MIN);
|
|
@@ -420,12 +420,12 @@ __arm_locally_streaming __arm_new("za") static void nk_attention_bf16_sme_stream
|
|
|
420
420
|
NK_ALIGN64 nk_f32_t output_accumulator[16 * 256];
|
|
421
421
|
svfloat32_t zero_f32x = svdup_f32(0.0f);
|
|
422
422
|
for (nk_size_t i = 0; i < 16 * head_dim_padded; i += svcntw()) {
|
|
423
|
-
svst1_f32(
|
|
423
|
+
svst1_f32(predicate_all_b32x, output_accumulator + i, zero_f32x);
|
|
424
424
|
}
|
|
425
425
|
|
|
426
426
|
nk_size_t kv_block_index = 0;
|
|
427
427
|
nk_size_t kv_start = 0;
|
|
428
|
-
svbool_t const
|
|
428
|
+
svbool_t const batch_predicate_b32x = svwhilelt_b32(0u, 16u);
|
|
429
429
|
|
|
430
430
|
nk_size_t const k_depth_step_count = head_dim_padded / 2;
|
|
431
431
|
|
|
@@ -434,11 +434,11 @@ __arm_locally_streaming __arm_new("za") static void nk_attention_bf16_sme_stream
|
|
|
434
434
|
for (nk_size_t batch = 0; batch < head_dim_padded / 32; batch++) {
|
|
435
435
|
svzero_mask_za(nk_sme_zero_za32_tile_0_);
|
|
436
436
|
for (nk_size_t query_index = 0; query_index < valid_query_count; query_index++)
|
|
437
|
-
svld1_hor_za32(0, query_index,
|
|
437
|
+
svld1_hor_za32(0, query_index, batch_predicate_b32x,
|
|
438
438
|
(nk_f32_t const *)(q + query_index * head_dim + batch * 32));
|
|
439
439
|
for (nk_size_t step = 0; step < 16; step++)
|
|
440
|
-
svst1_f32(
|
|
441
|
-
svread_ver_za32_f32_m(svdup_f32(0),
|
|
440
|
+
svst1_f32(predicate_all_b32x, queries_transposed + (batch * 16 + step) * 16,
|
|
441
|
+
svread_ver_za32_f32_m(svdup_f32(0), predicate_all_b32x, 0, step));
|
|
442
442
|
}
|
|
443
443
|
|
|
444
444
|
// Bc=32 main loop (prefill only, skipped for decode)
|
|
@@ -447,14 +447,17 @@ __arm_locally_streaming __arm_new("za") static void nk_attention_bf16_sme_stream
|
|
|
447
447
|
// Q×K^T: pure memory→BFMOPA, no ZA staging for Q or K
|
|
448
448
|
svzero_mask_za(nk_sme_zero_za32_tile_2_);
|
|
449
449
|
svzero_mask_za(nk_sme_zero_za32_tile_3_);
|
|
450
|
-
nk_bf16_t const *
|
|
451
|
-
nk_bf16_t const *
|
|
450
|
+
nk_bf16_t const *keys_block_low = k + kv_block_index * k_depth_step_count * 32;
|
|
451
|
+
nk_bf16_t const *keys_block_high = k + (kv_block_index + 1) * k_depth_step_count * 32;
|
|
452
452
|
for (nk_size_t step = 0; step < k_depth_step_count; step++) {
|
|
453
|
-
svbfloat16_t
|
|
454
|
-
|
|
455
|
-
svbfloat16_t
|
|
456
|
-
|
|
457
|
-
|
|
453
|
+
svbfloat16_t zn_bf16x = svreinterpret_bf16_f32(
|
|
454
|
+
svld1_f32(predicate_all_b32x, queries_transposed + step * 16));
|
|
455
|
+
svbfloat16_t zm0_bf16x = svld1_bf16(predicate_all_b16x,
|
|
456
|
+
(bfloat16_t const *)(keys_block_low + step * 32));
|
|
457
|
+
svbfloat16_t zm1_bf16x = svld1_bf16(predicate_all_b16x,
|
|
458
|
+
(bfloat16_t const *)(keys_block_high + step * 32));
|
|
459
|
+
svmopa_za32_bf16_m(2, predicate_all_b32x, predicate_all_b32x, zn_bf16x, zm0_bf16x);
|
|
460
|
+
svmopa_za32_bf16_m(3, predicate_all_b32x, predicate_all_b32x, zn_bf16x, zm1_bf16x);
|
|
458
461
|
}
|
|
459
462
|
|
|
460
463
|
// Pass 1: Column-wise max (read ZA2/ZA3 columns vertically)
|
|
@@ -462,26 +465,26 @@ __arm_locally_streaming __arm_new("za") static void nk_attention_bf16_sme_stream
|
|
|
462
465
|
svfloat32_t block_max_f32x = svdup_f32(NK_F32_MIN);
|
|
463
466
|
for (nk_size_t column_index = 0; column_index < 16; column_index++) {
|
|
464
467
|
svfloat32_t score_column_f32x = svmul_f32_x(
|
|
465
|
-
|
|
468
|
+
predicate_all_b32x, svread_ver_za32_f32_m(svdup_f32(0), predicate_all_b32x, 2, column_index),
|
|
466
469
|
scale_f32x);
|
|
467
|
-
block_max_f32x = svmax_f32_x(
|
|
470
|
+
block_max_f32x = svmax_f32_x(predicate_all_b32x, block_max_f32x, score_column_f32x);
|
|
468
471
|
}
|
|
469
472
|
for (nk_size_t column_index = 0; column_index < 16; column_index++) {
|
|
470
473
|
svfloat32_t score_column_f32x = svmul_f32_x(
|
|
471
|
-
|
|
474
|
+
predicate_all_b32x, svread_ver_za32_f32_m(svdup_f32(0), predicate_all_b32x, 3, column_index),
|
|
472
475
|
scale_f32x);
|
|
473
|
-
block_max_f32x = svmax_f32_x(
|
|
476
|
+
block_max_f32x = svmax_f32_x(predicate_all_b32x, block_max_f32x, score_column_f32x);
|
|
474
477
|
}
|
|
475
478
|
|
|
476
479
|
// Softmax correction (fully vectorized)
|
|
477
|
-
svfloat32_t new_max_f32x = svmax_f32_x(
|
|
480
|
+
svfloat32_t new_max_f32x = svmax_f32_x(predicate_all_b32x, row_max_f32x, block_max_f32x);
|
|
478
481
|
svfloat32_t correction_f32x = nk_exp_fast_f32_sve_(
|
|
479
|
-
|
|
480
|
-
svbool_t
|
|
481
|
-
nk_u32_t max_was_updated = svptest_any(
|
|
482
|
-
if (max_was_updated) row_sum_f32x = svmul_f32_x(
|
|
482
|
+
predicate_all_b32x, svsub_f32_x(predicate_all_b32x, row_max_f32x, new_max_f32x));
|
|
483
|
+
svbool_t max_changed_b32x = svcmplt_f32(predicate_all_b32x, correction_f32x, svdup_f32(1.0f));
|
|
484
|
+
nk_u32_t max_was_updated = svptest_any(predicate_all_b32x, max_changed_b32x) ? 1 : 0;
|
|
485
|
+
if (max_was_updated) row_sum_f32x = svmul_f32_x(predicate_all_b32x, row_sum_f32x, correction_f32x);
|
|
483
486
|
NK_ALIGN64 nk_f32_t corrections[16];
|
|
484
|
-
svst1_f32(
|
|
487
|
+
svst1_f32(predicate_all_b32x, corrections, correction_f32x);
|
|
485
488
|
|
|
486
489
|
// Pass 2: Column-wise exp + fused P write + sum
|
|
487
490
|
svfloat32_t sum_delta_f32x = svdup_f32(0.0f);
|
|
@@ -489,91 +492,91 @@ __arm_locally_streaming __arm_new("za") static void nk_attention_bf16_sme_stream
|
|
|
489
492
|
// ZA2 columns in pairs → ZA0 columns 0-7
|
|
490
493
|
for (nk_size_t column_index = 0; column_index < 16; column_index += 2) {
|
|
491
494
|
svfloat32_t score_even_f32x = svmul_f32_x(
|
|
492
|
-
|
|
495
|
+
predicate_all_b32x, svread_ver_za32_f32_m(svdup_f32(0), predicate_all_b32x, 2, column_index),
|
|
493
496
|
scale_f32x);
|
|
494
497
|
svfloat32_t score_odd_f32x = svmul_f32_x(
|
|
495
|
-
|
|
498
|
+
predicate_all_b32x, svread_ver_za32_f32_m(svdup_f32(0), predicate_all_b32x, 2, column_index + 1),
|
|
496
499
|
scale_f32x);
|
|
497
500
|
svfloat32_t weight_even_f32x = nk_exp_fast_f32_sve_(
|
|
498
|
-
|
|
501
|
+
predicate_all_b32x, svsub_f32_x(predicate_all_b32x, score_even_f32x, new_max_f32x));
|
|
499
502
|
svfloat32_t weight_odd_f32x = nk_exp_fast_f32_sve_(
|
|
500
|
-
|
|
501
|
-
sum_delta_f32x = svadd_f32_x(
|
|
502
|
-
sum_delta_f32x = svadd_f32_x(
|
|
503
|
-
svbfloat16_t
|
|
504
|
-
|
|
505
|
-
svwrite_ver_za32_f32_m(0, column_index / 2,
|
|
506
|
-
svreinterpret_f32_bf16(
|
|
503
|
+
predicate_all_b32x, svsub_f32_x(predicate_all_b32x, score_odd_f32x, new_max_f32x));
|
|
504
|
+
sum_delta_f32x = svadd_f32_x(predicate_all_b32x, sum_delta_f32x, weight_even_f32x);
|
|
505
|
+
sum_delta_f32x = svadd_f32_x(predicate_all_b32x, sum_delta_f32x, weight_odd_f32x);
|
|
506
|
+
svbfloat16_t weight_pair_bf16x = svzip1_bf16(nk_f32_to_bf16_sve_(predicate_all_b32x, weight_even_f32x),
|
|
507
|
+
nk_f32_to_bf16_sve_(predicate_all_b32x, weight_odd_f32x));
|
|
508
|
+
svwrite_ver_za32_f32_m(0, column_index / 2, predicate_all_b32x,
|
|
509
|
+
svreinterpret_f32_bf16(weight_pair_bf16x));
|
|
507
510
|
}
|
|
508
511
|
// ZA3 columns in pairs → ZA0 columns 8-15
|
|
509
512
|
for (nk_size_t column_index = 0; column_index < 16; column_index += 2) {
|
|
510
513
|
svfloat32_t score_even_f32x = svmul_f32_x(
|
|
511
|
-
|
|
514
|
+
predicate_all_b32x, svread_ver_za32_f32_m(svdup_f32(0), predicate_all_b32x, 3, column_index),
|
|
512
515
|
scale_f32x);
|
|
513
516
|
svfloat32_t score_odd_f32x = svmul_f32_x(
|
|
514
|
-
|
|
517
|
+
predicate_all_b32x, svread_ver_za32_f32_m(svdup_f32(0), predicate_all_b32x, 3, column_index + 1),
|
|
515
518
|
scale_f32x);
|
|
516
519
|
svfloat32_t weight_even_f32x = nk_exp_fast_f32_sve_(
|
|
517
|
-
|
|
520
|
+
predicate_all_b32x, svsub_f32_x(predicate_all_b32x, score_even_f32x, new_max_f32x));
|
|
518
521
|
svfloat32_t weight_odd_f32x = nk_exp_fast_f32_sve_(
|
|
519
|
-
|
|
520
|
-
sum_delta_f32x = svadd_f32_x(
|
|
521
|
-
sum_delta_f32x = svadd_f32_x(
|
|
522
|
-
svbfloat16_t
|
|
523
|
-
|
|
524
|
-
svwrite_ver_za32_f32_m(0, 8 + column_index / 2,
|
|
525
|
-
svreinterpret_f32_bf16(
|
|
522
|
+
predicate_all_b32x, svsub_f32_x(predicate_all_b32x, score_odd_f32x, new_max_f32x));
|
|
523
|
+
sum_delta_f32x = svadd_f32_x(predicate_all_b32x, sum_delta_f32x, weight_even_f32x);
|
|
524
|
+
sum_delta_f32x = svadd_f32_x(predicate_all_b32x, sum_delta_f32x, weight_odd_f32x);
|
|
525
|
+
svbfloat16_t weight_pair_bf16x = svzip1_bf16(nk_f32_to_bf16_sve_(predicate_all_b32x, weight_even_f32x),
|
|
526
|
+
nk_f32_to_bf16_sve_(predicate_all_b32x, weight_odd_f32x));
|
|
527
|
+
svwrite_ver_za32_f32_m(0, 8 + column_index / 2, predicate_all_b32x,
|
|
528
|
+
svreinterpret_f32_bf16(weight_pair_bf16x));
|
|
526
529
|
}
|
|
527
|
-
row_sum_f32x = svadd_f32_x(
|
|
530
|
+
row_sum_f32x = svadd_f32_x(predicate_all_b32x, row_sum_f32x, sum_delta_f32x);
|
|
528
531
|
row_max_f32x = new_max_f32x;
|
|
529
532
|
|
|
530
533
|
// Extract P columns from ZA0
|
|
531
534
|
svbfloat16_t probability_column_0_f32x = svreinterpret_bf16_f32(
|
|
532
|
-
svread_ver_za32_f32_m(svdup_f32(0),
|
|
535
|
+
svread_ver_za32_f32_m(svdup_f32(0), predicate_all_b32x, 0, 0));
|
|
533
536
|
svbfloat16_t probability_column_1_f32x = svreinterpret_bf16_f32(
|
|
534
|
-
svread_ver_za32_f32_m(svdup_f32(0),
|
|
537
|
+
svread_ver_za32_f32_m(svdup_f32(0), predicate_all_b32x, 0, 1));
|
|
535
538
|
svbfloat16_t probability_column_2_f32x = svreinterpret_bf16_f32(
|
|
536
|
-
svread_ver_za32_f32_m(svdup_f32(0),
|
|
539
|
+
svread_ver_za32_f32_m(svdup_f32(0), predicate_all_b32x, 0, 2));
|
|
537
540
|
svbfloat16_t probability_column_3_f32x = svreinterpret_bf16_f32(
|
|
538
|
-
svread_ver_za32_f32_m(svdup_f32(0),
|
|
541
|
+
svread_ver_za32_f32_m(svdup_f32(0), predicate_all_b32x, 0, 3));
|
|
539
542
|
svbfloat16_t probability_column_4_f32x = svreinterpret_bf16_f32(
|
|
540
|
-
svread_ver_za32_f32_m(svdup_f32(0),
|
|
543
|
+
svread_ver_za32_f32_m(svdup_f32(0), predicate_all_b32x, 0, 4));
|
|
541
544
|
svbfloat16_t probability_column_5_f32x = svreinterpret_bf16_f32(
|
|
542
|
-
svread_ver_za32_f32_m(svdup_f32(0),
|
|
545
|
+
svread_ver_za32_f32_m(svdup_f32(0), predicate_all_b32x, 0, 5));
|
|
543
546
|
svbfloat16_t probability_column_6_f32x = svreinterpret_bf16_f32(
|
|
544
|
-
svread_ver_za32_f32_m(svdup_f32(0),
|
|
547
|
+
svread_ver_za32_f32_m(svdup_f32(0), predicate_all_b32x, 0, 6));
|
|
545
548
|
svbfloat16_t probability_column_7_f32x = svreinterpret_bf16_f32(
|
|
546
|
-
svread_ver_za32_f32_m(svdup_f32(0),
|
|
549
|
+
svread_ver_za32_f32_m(svdup_f32(0), predicate_all_b32x, 0, 7));
|
|
547
550
|
svbfloat16_t probability_column_8_f32x = svreinterpret_bf16_f32(
|
|
548
|
-
svread_ver_za32_f32_m(svdup_f32(0),
|
|
551
|
+
svread_ver_za32_f32_m(svdup_f32(0), predicate_all_b32x, 0, 8));
|
|
549
552
|
svbfloat16_t probability_column_9_f32x = svreinterpret_bf16_f32(
|
|
550
|
-
svread_ver_za32_f32_m(svdup_f32(0),
|
|
553
|
+
svread_ver_za32_f32_m(svdup_f32(0), predicate_all_b32x, 0, 9));
|
|
551
554
|
svbfloat16_t probability_column_10_f32x = svreinterpret_bf16_f32(
|
|
552
|
-
svread_ver_za32_f32_m(svdup_f32(0),
|
|
555
|
+
svread_ver_za32_f32_m(svdup_f32(0), predicate_all_b32x, 0, 10));
|
|
553
556
|
svbfloat16_t probability_column_11_f32x = svreinterpret_bf16_f32(
|
|
554
|
-
svread_ver_za32_f32_m(svdup_f32(0),
|
|
557
|
+
svread_ver_za32_f32_m(svdup_f32(0), predicate_all_b32x, 0, 11));
|
|
555
558
|
svbfloat16_t probability_column_12_f32x = svreinterpret_bf16_f32(
|
|
556
|
-
svread_ver_za32_f32_m(svdup_f32(0),
|
|
559
|
+
svread_ver_za32_f32_m(svdup_f32(0), predicate_all_b32x, 0, 12));
|
|
557
560
|
svbfloat16_t probability_column_13_f32x = svreinterpret_bf16_f32(
|
|
558
|
-
svread_ver_za32_f32_m(svdup_f32(0),
|
|
561
|
+
svread_ver_za32_f32_m(svdup_f32(0), predicate_all_b32x, 0, 13));
|
|
559
562
|
svbfloat16_t probability_column_14_f32x = svreinterpret_bf16_f32(
|
|
560
|
-
svread_ver_za32_f32_m(svdup_f32(0),
|
|
563
|
+
svread_ver_za32_f32_m(svdup_f32(0), predicate_all_b32x, 0, 14));
|
|
561
564
|
svbfloat16_t probability_column_15_f32x = svreinterpret_bf16_f32(
|
|
562
|
-
svread_ver_za32_f32_m(svdup_f32(0),
|
|
565
|
+
svread_ver_za32_f32_m(svdup_f32(0), predicate_all_b32x, 0, 15));
|
|
563
566
|
|
|
564
567
|
// Pre-apply correction once before P×V
|
|
565
|
-
svbool_t
|
|
566
|
-
nk_bf16_t const *
|
|
567
|
-
nk_bf16_t const *
|
|
568
|
+
svbool_t query_predicate_b16x = svwhilelt_b16_u64(0u, valid_query_count * 2);
|
|
569
|
+
nk_bf16_t const *values_block_low = v_packed + kv_block_index * dim_tile_count * 8 * 32;
|
|
570
|
+
nk_bf16_t const *values_block_high = v_packed + (kv_block_index + 1) * dim_tile_count * 8 * 32;
|
|
568
571
|
|
|
569
572
|
if (max_was_updated) {
|
|
570
573
|
for (nk_size_t query_index = 0; query_index < valid_query_count; query_index++) {
|
|
571
574
|
svfloat32_t correction_scalar_f32x = svdup_f32(corrections[query_index]);
|
|
572
575
|
for (nk_size_t dim_offset = 0; dim_offset < head_dim_padded; dim_offset += 16)
|
|
573
576
|
svst1_f32(
|
|
574
|
-
|
|
575
|
-
svmul_f32_x(
|
|
576
|
-
svld1_f32(
|
|
577
|
+
predicate_all_b32x, output_accumulator + query_index * head_dim_padded + dim_offset,
|
|
578
|
+
svmul_f32_x(predicate_all_b32x,
|
|
579
|
+
svld1_f32(predicate_all_b32x,
|
|
577
580
|
output_accumulator + query_index * head_dim_padded + dim_offset),
|
|
578
581
|
correction_scalar_f32x));
|
|
579
582
|
}
|
|
@@ -584,284 +587,284 @@ __arm_locally_streaming __arm_new("za") static void nk_attention_bf16_sme_stream
|
|
|
584
587
|
for (; dim_tile + 4 <= dim_tile_count; dim_tile += 4) {
|
|
585
588
|
svzero_za();
|
|
586
589
|
// Block0: 8 depth steps (KV positions 0-15)
|
|
587
|
-
svmopa_za32_bf16_m(0,
|
|
588
|
-
svld1_bf16(
|
|
589
|
-
|
|
590
|
-
svmopa_za32_bf16_m(1,
|
|
591
|
-
svld1_bf16(
|
|
592
|
-
|
|
593
|
-
svmopa_za32_bf16_m(2,
|
|
594
|
-
svld1_bf16(
|
|
595
|
-
|
|
596
|
-
svmopa_za32_bf16_m(3,
|
|
597
|
-
svld1_bf16(
|
|
598
|
-
|
|
599
|
-
svmopa_za32_bf16_m(0,
|
|
600
|
-
svld1_bf16(
|
|
601
|
-
|
|
602
|
-
svmopa_za32_bf16_m(1,
|
|
603
|
-
svld1_bf16(
|
|
604
|
-
|
|
605
|
-
svmopa_za32_bf16_m(2,
|
|
606
|
-
svld1_bf16(
|
|
607
|
-
|
|
608
|
-
svmopa_za32_bf16_m(3,
|
|
609
|
-
svld1_bf16(
|
|
610
|
-
|
|
611
|
-
svmopa_za32_bf16_m(0,
|
|
612
|
-
svld1_bf16(
|
|
613
|
-
|
|
614
|
-
svmopa_za32_bf16_m(1,
|
|
615
|
-
svld1_bf16(
|
|
616
|
-
|
|
617
|
-
svmopa_za32_bf16_m(2,
|
|
618
|
-
svld1_bf16(
|
|
619
|
-
|
|
620
|
-
svmopa_za32_bf16_m(3,
|
|
621
|
-
svld1_bf16(
|
|
622
|
-
|
|
623
|
-
svmopa_za32_bf16_m(0,
|
|
624
|
-
svld1_bf16(
|
|
625
|
-
|
|
626
|
-
svmopa_za32_bf16_m(1,
|
|
627
|
-
svld1_bf16(
|
|
628
|
-
|
|
629
|
-
svmopa_za32_bf16_m(2,
|
|
630
|
-
svld1_bf16(
|
|
631
|
-
|
|
632
|
-
svmopa_za32_bf16_m(3,
|
|
633
|
-
svld1_bf16(
|
|
634
|
-
|
|
635
|
-
svmopa_za32_bf16_m(0,
|
|
636
|
-
svld1_bf16(
|
|
637
|
-
|
|
638
|
-
svmopa_za32_bf16_m(1,
|
|
639
|
-
svld1_bf16(
|
|
640
|
-
|
|
641
|
-
svmopa_za32_bf16_m(2,
|
|
642
|
-
svld1_bf16(
|
|
643
|
-
|
|
644
|
-
svmopa_za32_bf16_m(3,
|
|
645
|
-
svld1_bf16(
|
|
646
|
-
|
|
647
|
-
svmopa_za32_bf16_m(0,
|
|
648
|
-
svld1_bf16(
|
|
649
|
-
|
|
650
|
-
svmopa_za32_bf16_m(1,
|
|
651
|
-
svld1_bf16(
|
|
652
|
-
|
|
653
|
-
svmopa_za32_bf16_m(2,
|
|
654
|
-
svld1_bf16(
|
|
655
|
-
|
|
656
|
-
svmopa_za32_bf16_m(3,
|
|
657
|
-
svld1_bf16(
|
|
658
|
-
|
|
659
|
-
svmopa_za32_bf16_m(0,
|
|
660
|
-
svld1_bf16(
|
|
661
|
-
|
|
662
|
-
svmopa_za32_bf16_m(1,
|
|
663
|
-
svld1_bf16(
|
|
664
|
-
|
|
665
|
-
svmopa_za32_bf16_m(2,
|
|
666
|
-
svld1_bf16(
|
|
667
|
-
|
|
668
|
-
svmopa_za32_bf16_m(3,
|
|
669
|
-
svld1_bf16(
|
|
670
|
-
|
|
671
|
-
svmopa_za32_bf16_m(0,
|
|
672
|
-
svld1_bf16(
|
|
673
|
-
|
|
674
|
-
svmopa_za32_bf16_m(1,
|
|
675
|
-
svld1_bf16(
|
|
676
|
-
|
|
677
|
-
svmopa_za32_bf16_m(2,
|
|
678
|
-
svld1_bf16(
|
|
679
|
-
|
|
680
|
-
svmopa_za32_bf16_m(3,
|
|
681
|
-
svld1_bf16(
|
|
682
|
-
|
|
590
|
+
svmopa_za32_bf16_m(0, query_predicate_b16x, predicate_all_b16x, probability_column_0_f32x,
|
|
591
|
+
svld1_bf16(predicate_all_b16x,
|
|
592
|
+
(bfloat16_t const *)(values_block_low + ((dim_tile + 0) * 8 + 0) * 32)));
|
|
593
|
+
svmopa_za32_bf16_m(1, query_predicate_b16x, predicate_all_b16x, probability_column_0_f32x,
|
|
594
|
+
svld1_bf16(predicate_all_b16x,
|
|
595
|
+
(bfloat16_t const *)(values_block_low + ((dim_tile + 1) * 8 + 0) * 32)));
|
|
596
|
+
svmopa_za32_bf16_m(2, query_predicate_b16x, predicate_all_b16x, probability_column_0_f32x,
|
|
597
|
+
svld1_bf16(predicate_all_b16x,
|
|
598
|
+
(bfloat16_t const *)(values_block_low + ((dim_tile + 2) * 8 + 0) * 32)));
|
|
599
|
+
svmopa_za32_bf16_m(3, query_predicate_b16x, predicate_all_b16x, probability_column_0_f32x,
|
|
600
|
+
svld1_bf16(predicate_all_b16x,
|
|
601
|
+
(bfloat16_t const *)(values_block_low + ((dim_tile + 3) * 8 + 0) * 32)));
|
|
602
|
+
svmopa_za32_bf16_m(0, query_predicate_b16x, predicate_all_b16x, probability_column_1_f32x,
|
|
603
|
+
svld1_bf16(predicate_all_b16x,
|
|
604
|
+
(bfloat16_t const *)(values_block_low + ((dim_tile + 0) * 8 + 1) * 32)));
|
|
605
|
+
svmopa_za32_bf16_m(1, query_predicate_b16x, predicate_all_b16x, probability_column_1_f32x,
|
|
606
|
+
svld1_bf16(predicate_all_b16x,
|
|
607
|
+
(bfloat16_t const *)(values_block_low + ((dim_tile + 1) * 8 + 1) * 32)));
|
|
608
|
+
svmopa_za32_bf16_m(2, query_predicate_b16x, predicate_all_b16x, probability_column_1_f32x,
|
|
609
|
+
svld1_bf16(predicate_all_b16x,
|
|
610
|
+
(bfloat16_t const *)(values_block_low + ((dim_tile + 2) * 8 + 1) * 32)));
|
|
611
|
+
svmopa_za32_bf16_m(3, query_predicate_b16x, predicate_all_b16x, probability_column_1_f32x,
|
|
612
|
+
svld1_bf16(predicate_all_b16x,
|
|
613
|
+
(bfloat16_t const *)(values_block_low + ((dim_tile + 3) * 8 + 1) * 32)));
|
|
614
|
+
svmopa_za32_bf16_m(0, query_predicate_b16x, predicate_all_b16x, probability_column_2_f32x,
|
|
615
|
+
svld1_bf16(predicate_all_b16x,
|
|
616
|
+
(bfloat16_t const *)(values_block_low + ((dim_tile + 0) * 8 + 2) * 32)));
|
|
617
|
+
svmopa_za32_bf16_m(1, query_predicate_b16x, predicate_all_b16x, probability_column_2_f32x,
|
|
618
|
+
svld1_bf16(predicate_all_b16x,
|
|
619
|
+
(bfloat16_t const *)(values_block_low + ((dim_tile + 1) * 8 + 2) * 32)));
|
|
620
|
+
svmopa_za32_bf16_m(2, query_predicate_b16x, predicate_all_b16x, probability_column_2_f32x,
|
|
621
|
+
svld1_bf16(predicate_all_b16x,
|
|
622
|
+
(bfloat16_t const *)(values_block_low + ((dim_tile + 2) * 8 + 2) * 32)));
|
|
623
|
+
svmopa_za32_bf16_m(3, query_predicate_b16x, predicate_all_b16x, probability_column_2_f32x,
|
|
624
|
+
svld1_bf16(predicate_all_b16x,
|
|
625
|
+
(bfloat16_t const *)(values_block_low + ((dim_tile + 3) * 8 + 2) * 32)));
|
|
626
|
+
svmopa_za32_bf16_m(0, query_predicate_b16x, predicate_all_b16x, probability_column_3_f32x,
|
|
627
|
+
svld1_bf16(predicate_all_b16x,
|
|
628
|
+
(bfloat16_t const *)(values_block_low + ((dim_tile + 0) * 8 + 3) * 32)));
|
|
629
|
+
svmopa_za32_bf16_m(1, query_predicate_b16x, predicate_all_b16x, probability_column_3_f32x,
|
|
630
|
+
svld1_bf16(predicate_all_b16x,
|
|
631
|
+
(bfloat16_t const *)(values_block_low + ((dim_tile + 1) * 8 + 3) * 32)));
|
|
632
|
+
svmopa_za32_bf16_m(2, query_predicate_b16x, predicate_all_b16x, probability_column_3_f32x,
|
|
633
|
+
svld1_bf16(predicate_all_b16x,
|
|
634
|
+
(bfloat16_t const *)(values_block_low + ((dim_tile + 2) * 8 + 3) * 32)));
|
|
635
|
+
svmopa_za32_bf16_m(3, query_predicate_b16x, predicate_all_b16x, probability_column_3_f32x,
|
|
636
|
+
svld1_bf16(predicate_all_b16x,
|
|
637
|
+
(bfloat16_t const *)(values_block_low + ((dim_tile + 3) * 8 + 3) * 32)));
|
|
638
|
+
svmopa_za32_bf16_m(0, query_predicate_b16x, predicate_all_b16x, probability_column_4_f32x,
|
|
639
|
+
svld1_bf16(predicate_all_b16x,
|
|
640
|
+
(bfloat16_t const *)(values_block_low + ((dim_tile + 0) * 8 + 4) * 32)));
|
|
641
|
+
svmopa_za32_bf16_m(1, query_predicate_b16x, predicate_all_b16x, probability_column_4_f32x,
|
|
642
|
+
svld1_bf16(predicate_all_b16x,
|
|
643
|
+
(bfloat16_t const *)(values_block_low + ((dim_tile + 1) * 8 + 4) * 32)));
|
|
644
|
+
svmopa_za32_bf16_m(2, query_predicate_b16x, predicate_all_b16x, probability_column_4_f32x,
|
|
645
|
+
svld1_bf16(predicate_all_b16x,
|
|
646
|
+
(bfloat16_t const *)(values_block_low + ((dim_tile + 2) * 8 + 4) * 32)));
|
|
647
|
+
svmopa_za32_bf16_m(3, query_predicate_b16x, predicate_all_b16x, probability_column_4_f32x,
|
|
648
|
+
svld1_bf16(predicate_all_b16x,
|
|
649
|
+
(bfloat16_t const *)(values_block_low + ((dim_tile + 3) * 8 + 4) * 32)));
|
|
650
|
+
svmopa_za32_bf16_m(0, query_predicate_b16x, predicate_all_b16x, probability_column_5_f32x,
|
|
651
|
+
svld1_bf16(predicate_all_b16x,
|
|
652
|
+
(bfloat16_t const *)(values_block_low + ((dim_tile + 0) * 8 + 5) * 32)));
|
|
653
|
+
svmopa_za32_bf16_m(1, query_predicate_b16x, predicate_all_b16x, probability_column_5_f32x,
|
|
654
|
+
svld1_bf16(predicate_all_b16x,
|
|
655
|
+
(bfloat16_t const *)(values_block_low + ((dim_tile + 1) * 8 + 5) * 32)));
|
|
656
|
+
svmopa_za32_bf16_m(2, query_predicate_b16x, predicate_all_b16x, probability_column_5_f32x,
|
|
657
|
+
svld1_bf16(predicate_all_b16x,
|
|
658
|
+
(bfloat16_t const *)(values_block_low + ((dim_tile + 2) * 8 + 5) * 32)));
|
|
659
|
+
svmopa_za32_bf16_m(3, query_predicate_b16x, predicate_all_b16x, probability_column_5_f32x,
|
|
660
|
+
svld1_bf16(predicate_all_b16x,
|
|
661
|
+
(bfloat16_t const *)(values_block_low + ((dim_tile + 3) * 8 + 5) * 32)));
|
|
662
|
+
svmopa_za32_bf16_m(0, query_predicate_b16x, predicate_all_b16x, probability_column_6_f32x,
|
|
663
|
+
svld1_bf16(predicate_all_b16x,
|
|
664
|
+
(bfloat16_t const *)(values_block_low + ((dim_tile + 0) * 8 + 6) * 32)));
|
|
665
|
+
svmopa_za32_bf16_m(1, query_predicate_b16x, predicate_all_b16x, probability_column_6_f32x,
|
|
666
|
+
svld1_bf16(predicate_all_b16x,
|
|
667
|
+
(bfloat16_t const *)(values_block_low + ((dim_tile + 1) * 8 + 6) * 32)));
|
|
668
|
+
svmopa_za32_bf16_m(2, query_predicate_b16x, predicate_all_b16x, probability_column_6_f32x,
|
|
669
|
+
svld1_bf16(predicate_all_b16x,
|
|
670
|
+
(bfloat16_t const *)(values_block_low + ((dim_tile + 2) * 8 + 6) * 32)));
|
|
671
|
+
svmopa_za32_bf16_m(3, query_predicate_b16x, predicate_all_b16x, probability_column_6_f32x,
|
|
672
|
+
svld1_bf16(predicate_all_b16x,
|
|
673
|
+
(bfloat16_t const *)(values_block_low + ((dim_tile + 3) * 8 + 6) * 32)));
|
|
674
|
+
svmopa_za32_bf16_m(0, query_predicate_b16x, predicate_all_b16x, probability_column_7_f32x,
|
|
675
|
+
svld1_bf16(predicate_all_b16x,
|
|
676
|
+
(bfloat16_t const *)(values_block_low + ((dim_tile + 0) * 8 + 7) * 32)));
|
|
677
|
+
svmopa_za32_bf16_m(1, query_predicate_b16x, predicate_all_b16x, probability_column_7_f32x,
|
|
678
|
+
svld1_bf16(predicate_all_b16x,
|
|
679
|
+
(bfloat16_t const *)(values_block_low + ((dim_tile + 1) * 8 + 7) * 32)));
|
|
680
|
+
svmopa_za32_bf16_m(2, query_predicate_b16x, predicate_all_b16x, probability_column_7_f32x,
|
|
681
|
+
svld1_bf16(predicate_all_b16x,
|
|
682
|
+
(bfloat16_t const *)(values_block_low + ((dim_tile + 2) * 8 + 7) * 32)));
|
|
683
|
+
svmopa_za32_bf16_m(3, query_predicate_b16x, predicate_all_b16x, probability_column_7_f32x,
|
|
684
|
+
svld1_bf16(predicate_all_b16x,
|
|
685
|
+
(bfloat16_t const *)(values_block_low + ((dim_tile + 3) * 8 + 7) * 32)));
|
|
683
686
|
// Block1: 8 depth steps (KV positions 16-31)
|
|
684
|
-
svmopa_za32_bf16_m(0,
|
|
685
|
-
svld1_bf16(
|
|
686
|
-
|
|
687
|
-
svmopa_za32_bf16_m(1,
|
|
688
|
-
svld1_bf16(
|
|
689
|
-
|
|
690
|
-
svmopa_za32_bf16_m(2,
|
|
691
|
-
svld1_bf16(
|
|
692
|
-
|
|
693
|
-
svmopa_za32_bf16_m(3,
|
|
694
|
-
svld1_bf16(
|
|
695
|
-
|
|
696
|
-
svmopa_za32_bf16_m(0,
|
|
697
|
-
svld1_bf16(
|
|
698
|
-
|
|
699
|
-
svmopa_za32_bf16_m(1,
|
|
700
|
-
svld1_bf16(
|
|
701
|
-
|
|
702
|
-
svmopa_za32_bf16_m(2,
|
|
703
|
-
svld1_bf16(
|
|
704
|
-
|
|
705
|
-
svmopa_za32_bf16_m(3,
|
|
706
|
-
svld1_bf16(
|
|
707
|
-
|
|
708
|
-
svmopa_za32_bf16_m(0,
|
|
709
|
-
svld1_bf16(
|
|
710
|
-
|
|
711
|
-
svmopa_za32_bf16_m(1,
|
|
712
|
-
svld1_bf16(
|
|
713
|
-
|
|
714
|
-
svmopa_za32_bf16_m(2,
|
|
715
|
-
svld1_bf16(
|
|
716
|
-
|
|
717
|
-
svmopa_za32_bf16_m(3,
|
|
718
|
-
svld1_bf16(
|
|
719
|
-
|
|
720
|
-
svmopa_za32_bf16_m(0,
|
|
721
|
-
svld1_bf16(
|
|
722
|
-
|
|
723
|
-
svmopa_za32_bf16_m(1,
|
|
724
|
-
svld1_bf16(
|
|
725
|
-
|
|
726
|
-
svmopa_za32_bf16_m(2,
|
|
727
|
-
svld1_bf16(
|
|
728
|
-
|
|
729
|
-
svmopa_za32_bf16_m(3,
|
|
730
|
-
svld1_bf16(
|
|
731
|
-
|
|
732
|
-
svmopa_za32_bf16_m(0,
|
|
733
|
-
svld1_bf16(
|
|
734
|
-
|
|
735
|
-
svmopa_za32_bf16_m(1,
|
|
736
|
-
svld1_bf16(
|
|
737
|
-
|
|
738
|
-
svmopa_za32_bf16_m(2,
|
|
739
|
-
svld1_bf16(
|
|
740
|
-
|
|
741
|
-
svmopa_za32_bf16_m(3,
|
|
742
|
-
svld1_bf16(
|
|
743
|
-
|
|
744
|
-
svmopa_za32_bf16_m(0,
|
|
745
|
-
svld1_bf16(
|
|
746
|
-
|
|
747
|
-
svmopa_za32_bf16_m(1,
|
|
748
|
-
svld1_bf16(
|
|
749
|
-
|
|
750
|
-
svmopa_za32_bf16_m(2,
|
|
751
|
-
svld1_bf16(
|
|
752
|
-
|
|
753
|
-
svmopa_za32_bf16_m(3,
|
|
754
|
-
svld1_bf16(
|
|
755
|
-
|
|
756
|
-
svmopa_za32_bf16_m(0,
|
|
757
|
-
svld1_bf16(
|
|
758
|
-
|
|
759
|
-
svmopa_za32_bf16_m(1,
|
|
760
|
-
svld1_bf16(
|
|
761
|
-
|
|
762
|
-
svmopa_za32_bf16_m(2,
|
|
763
|
-
svld1_bf16(
|
|
764
|
-
|
|
765
|
-
svmopa_za32_bf16_m(3,
|
|
766
|
-
svld1_bf16(
|
|
767
|
-
|
|
768
|
-
svmopa_za32_bf16_m(0,
|
|
769
|
-
svld1_bf16(
|
|
770
|
-
|
|
771
|
-
svmopa_za32_bf16_m(1,
|
|
772
|
-
svld1_bf16(
|
|
773
|
-
|
|
774
|
-
svmopa_za32_bf16_m(2,
|
|
775
|
-
svld1_bf16(
|
|
776
|
-
|
|
777
|
-
svmopa_za32_bf16_m(3,
|
|
778
|
-
svld1_bf16(
|
|
779
|
-
|
|
687
|
+
svmopa_za32_bf16_m(0, query_predicate_b16x, predicate_all_b16x, probability_column_8_f32x,
|
|
688
|
+
svld1_bf16(predicate_all_b16x,
|
|
689
|
+
(bfloat16_t const *)(values_block_high + ((dim_tile + 0) * 8 + 0) * 32)));
|
|
690
|
+
svmopa_za32_bf16_m(1, query_predicate_b16x, predicate_all_b16x, probability_column_8_f32x,
|
|
691
|
+
svld1_bf16(predicate_all_b16x,
|
|
692
|
+
(bfloat16_t const *)(values_block_high + ((dim_tile + 1) * 8 + 0) * 32)));
|
|
693
|
+
svmopa_za32_bf16_m(2, query_predicate_b16x, predicate_all_b16x, probability_column_8_f32x,
|
|
694
|
+
svld1_bf16(predicate_all_b16x,
|
|
695
|
+
(bfloat16_t const *)(values_block_high + ((dim_tile + 2) * 8 + 0) * 32)));
|
|
696
|
+
svmopa_za32_bf16_m(3, query_predicate_b16x, predicate_all_b16x, probability_column_8_f32x,
|
|
697
|
+
svld1_bf16(predicate_all_b16x,
|
|
698
|
+
(bfloat16_t const *)(values_block_high + ((dim_tile + 3) * 8 + 0) * 32)));
|
|
699
|
+
svmopa_za32_bf16_m(0, query_predicate_b16x, predicate_all_b16x, probability_column_9_f32x,
|
|
700
|
+
svld1_bf16(predicate_all_b16x,
|
|
701
|
+
(bfloat16_t const *)(values_block_high + ((dim_tile + 0) * 8 + 1) * 32)));
|
|
702
|
+
svmopa_za32_bf16_m(1, query_predicate_b16x, predicate_all_b16x, probability_column_9_f32x,
|
|
703
|
+
svld1_bf16(predicate_all_b16x,
|
|
704
|
+
(bfloat16_t const *)(values_block_high + ((dim_tile + 1) * 8 + 1) * 32)));
|
|
705
|
+
svmopa_za32_bf16_m(2, query_predicate_b16x, predicate_all_b16x, probability_column_9_f32x,
|
|
706
|
+
svld1_bf16(predicate_all_b16x,
|
|
707
|
+
(bfloat16_t const *)(values_block_high + ((dim_tile + 2) * 8 + 1) * 32)));
|
|
708
|
+
svmopa_za32_bf16_m(3, query_predicate_b16x, predicate_all_b16x, probability_column_9_f32x,
|
|
709
|
+
svld1_bf16(predicate_all_b16x,
|
|
710
|
+
(bfloat16_t const *)(values_block_high + ((dim_tile + 3) * 8 + 1) * 32)));
|
|
711
|
+
svmopa_za32_bf16_m(0, query_predicate_b16x, predicate_all_b16x, probability_column_10_f32x,
|
|
712
|
+
svld1_bf16(predicate_all_b16x,
|
|
713
|
+
(bfloat16_t const *)(values_block_high + ((dim_tile + 0) * 8 + 2) * 32)));
|
|
714
|
+
svmopa_za32_bf16_m(1, query_predicate_b16x, predicate_all_b16x, probability_column_10_f32x,
|
|
715
|
+
svld1_bf16(predicate_all_b16x,
|
|
716
|
+
(bfloat16_t const *)(values_block_high + ((dim_tile + 1) * 8 + 2) * 32)));
|
|
717
|
+
svmopa_za32_bf16_m(2, query_predicate_b16x, predicate_all_b16x, probability_column_10_f32x,
|
|
718
|
+
svld1_bf16(predicate_all_b16x,
|
|
719
|
+
(bfloat16_t const *)(values_block_high + ((dim_tile + 2) * 8 + 2) * 32)));
|
|
720
|
+
svmopa_za32_bf16_m(3, query_predicate_b16x, predicate_all_b16x, probability_column_10_f32x,
|
|
721
|
+
svld1_bf16(predicate_all_b16x,
|
|
722
|
+
(bfloat16_t const *)(values_block_high + ((dim_tile + 3) * 8 + 2) * 32)));
|
|
723
|
+
svmopa_za32_bf16_m(0, query_predicate_b16x, predicate_all_b16x, probability_column_11_f32x,
|
|
724
|
+
svld1_bf16(predicate_all_b16x,
|
|
725
|
+
(bfloat16_t const *)(values_block_high + ((dim_tile + 0) * 8 + 3) * 32)));
|
|
726
|
+
svmopa_za32_bf16_m(1, query_predicate_b16x, predicate_all_b16x, probability_column_11_f32x,
|
|
727
|
+
svld1_bf16(predicate_all_b16x,
|
|
728
|
+
(bfloat16_t const *)(values_block_high + ((dim_tile + 1) * 8 + 3) * 32)));
|
|
729
|
+
svmopa_za32_bf16_m(2, query_predicate_b16x, predicate_all_b16x, probability_column_11_f32x,
|
|
730
|
+
svld1_bf16(predicate_all_b16x,
|
|
731
|
+
(bfloat16_t const *)(values_block_high + ((dim_tile + 2) * 8 + 3) * 32)));
|
|
732
|
+
svmopa_za32_bf16_m(3, query_predicate_b16x, predicate_all_b16x, probability_column_11_f32x,
|
|
733
|
+
svld1_bf16(predicate_all_b16x,
|
|
734
|
+
(bfloat16_t const *)(values_block_high + ((dim_tile + 3) * 8 + 3) * 32)));
|
|
735
|
+
svmopa_za32_bf16_m(0, query_predicate_b16x, predicate_all_b16x, probability_column_12_f32x,
|
|
736
|
+
svld1_bf16(predicate_all_b16x,
|
|
737
|
+
(bfloat16_t const *)(values_block_high + ((dim_tile + 0) * 8 + 4) * 32)));
|
|
738
|
+
svmopa_za32_bf16_m(1, query_predicate_b16x, predicate_all_b16x, probability_column_12_f32x,
|
|
739
|
+
svld1_bf16(predicate_all_b16x,
|
|
740
|
+
(bfloat16_t const *)(values_block_high + ((dim_tile + 1) * 8 + 4) * 32)));
|
|
741
|
+
svmopa_za32_bf16_m(2, query_predicate_b16x, predicate_all_b16x, probability_column_12_f32x,
|
|
742
|
+
svld1_bf16(predicate_all_b16x,
|
|
743
|
+
(bfloat16_t const *)(values_block_high + ((dim_tile + 2) * 8 + 4) * 32)));
|
|
744
|
+
svmopa_za32_bf16_m(3, query_predicate_b16x, predicate_all_b16x, probability_column_12_f32x,
|
|
745
|
+
svld1_bf16(predicate_all_b16x,
|
|
746
|
+
(bfloat16_t const *)(values_block_high + ((dim_tile + 3) * 8 + 4) * 32)));
|
|
747
|
+
svmopa_za32_bf16_m(0, query_predicate_b16x, predicate_all_b16x, probability_column_13_f32x,
|
|
748
|
+
svld1_bf16(predicate_all_b16x,
|
|
749
|
+
(bfloat16_t const *)(values_block_high + ((dim_tile + 0) * 8 + 5) * 32)));
|
|
750
|
+
svmopa_za32_bf16_m(1, query_predicate_b16x, predicate_all_b16x, probability_column_13_f32x,
|
|
751
|
+
svld1_bf16(predicate_all_b16x,
|
|
752
|
+
(bfloat16_t const *)(values_block_high + ((dim_tile + 1) * 8 + 5) * 32)));
|
|
753
|
+
svmopa_za32_bf16_m(2, query_predicate_b16x, predicate_all_b16x, probability_column_13_f32x,
|
|
754
|
+
svld1_bf16(predicate_all_b16x,
|
|
755
|
+
(bfloat16_t const *)(values_block_high + ((dim_tile + 2) * 8 + 5) * 32)));
|
|
756
|
+
svmopa_za32_bf16_m(3, query_predicate_b16x, predicate_all_b16x, probability_column_13_f32x,
|
|
757
|
+
svld1_bf16(predicate_all_b16x,
|
|
758
|
+
(bfloat16_t const *)(values_block_high + ((dim_tile + 3) * 8 + 5) * 32)));
|
|
759
|
+
svmopa_za32_bf16_m(0, query_predicate_b16x, predicate_all_b16x, probability_column_14_f32x,
|
|
760
|
+
svld1_bf16(predicate_all_b16x,
|
|
761
|
+
(bfloat16_t const *)(values_block_high + ((dim_tile + 0) * 8 + 6) * 32)));
|
|
762
|
+
svmopa_za32_bf16_m(1, query_predicate_b16x, predicate_all_b16x, probability_column_14_f32x,
|
|
763
|
+
svld1_bf16(predicate_all_b16x,
|
|
764
|
+
(bfloat16_t const *)(values_block_high + ((dim_tile + 1) * 8 + 6) * 32)));
|
|
765
|
+
svmopa_za32_bf16_m(2, query_predicate_b16x, predicate_all_b16x, probability_column_14_f32x,
|
|
766
|
+
svld1_bf16(predicate_all_b16x,
|
|
767
|
+
(bfloat16_t const *)(values_block_high + ((dim_tile + 2) * 8 + 6) * 32)));
|
|
768
|
+
svmopa_za32_bf16_m(3, query_predicate_b16x, predicate_all_b16x, probability_column_14_f32x,
|
|
769
|
+
svld1_bf16(predicate_all_b16x,
|
|
770
|
+
(bfloat16_t const *)(values_block_high + ((dim_tile + 3) * 8 + 6) * 32)));
|
|
771
|
+
svmopa_za32_bf16_m(0, query_predicate_b16x, predicate_all_b16x, probability_column_15_f32x,
|
|
772
|
+
svld1_bf16(predicate_all_b16x,
|
|
773
|
+
(bfloat16_t const *)(values_block_high + ((dim_tile + 0) * 8 + 7) * 32)));
|
|
774
|
+
svmopa_za32_bf16_m(1, query_predicate_b16x, predicate_all_b16x, probability_column_15_f32x,
|
|
775
|
+
svld1_bf16(predicate_all_b16x,
|
|
776
|
+
(bfloat16_t const *)(values_block_high + ((dim_tile + 1) * 8 + 7) * 32)));
|
|
777
|
+
svmopa_za32_bf16_m(2, query_predicate_b16x, predicate_all_b16x, probability_column_15_f32x,
|
|
778
|
+
svld1_bf16(predicate_all_b16x,
|
|
779
|
+
(bfloat16_t const *)(values_block_high + ((dim_tile + 2) * 8 + 7) * 32)));
|
|
780
|
+
svmopa_za32_bf16_m(3, query_predicate_b16x, predicate_all_b16x, probability_column_15_f32x,
|
|
781
|
+
svld1_bf16(predicate_all_b16x,
|
|
782
|
+
(bfloat16_t const *)(values_block_high + ((dim_tile + 3) * 8 + 7) * 32)));
|
|
780
783
|
// Read BFMOPA result and ADD to output_accumulator
|
|
781
784
|
for (nk_size_t query_index = 0; query_index < valid_query_count; query_index++) {
|
|
782
785
|
svst1_f32(
|
|
783
|
-
|
|
784
|
-
svadd_f32_x(
|
|
785
|
-
svld1_f32(
|
|
786
|
+
predicate_all_b32x, output_accumulator + query_index * head_dim_padded + (dim_tile + 0) * 16,
|
|
787
|
+
svadd_f32_x(predicate_all_b32x,
|
|
788
|
+
svld1_f32(predicate_all_b32x,
|
|
786
789
|
output_accumulator + query_index * head_dim_padded + (dim_tile + 0) * 16),
|
|
787
|
-
svread_hor_za32_f32_m(svdup_f32(0),
|
|
790
|
+
svread_hor_za32_f32_m(svdup_f32(0), predicate_all_b32x, 0, query_index)));
|
|
788
791
|
svst1_f32(
|
|
789
|
-
|
|
790
|
-
svadd_f32_x(
|
|
791
|
-
svld1_f32(
|
|
792
|
+
predicate_all_b32x, output_accumulator + query_index * head_dim_padded + (dim_tile + 1) * 16,
|
|
793
|
+
svadd_f32_x(predicate_all_b32x,
|
|
794
|
+
svld1_f32(predicate_all_b32x,
|
|
792
795
|
output_accumulator + query_index * head_dim_padded + (dim_tile + 1) * 16),
|
|
793
|
-
svread_hor_za32_f32_m(svdup_f32(0),
|
|
796
|
+
svread_hor_za32_f32_m(svdup_f32(0), predicate_all_b32x, 1, query_index)));
|
|
794
797
|
svst1_f32(
|
|
795
|
-
|
|
796
|
-
svadd_f32_x(
|
|
797
|
-
svld1_f32(
|
|
798
|
+
predicate_all_b32x, output_accumulator + query_index * head_dim_padded + (dim_tile + 2) * 16,
|
|
799
|
+
svadd_f32_x(predicate_all_b32x,
|
|
800
|
+
svld1_f32(predicate_all_b32x,
|
|
798
801
|
output_accumulator + query_index * head_dim_padded + (dim_tile + 2) * 16),
|
|
799
|
-
svread_hor_za32_f32_m(svdup_f32(0),
|
|
802
|
+
svread_hor_za32_f32_m(svdup_f32(0), predicate_all_b32x, 2, query_index)));
|
|
800
803
|
svst1_f32(
|
|
801
|
-
|
|
802
|
-
svadd_f32_x(
|
|
803
|
-
svld1_f32(
|
|
804
|
+
predicate_all_b32x, output_accumulator + query_index * head_dim_padded + (dim_tile + 3) * 16,
|
|
805
|
+
svadd_f32_x(predicate_all_b32x,
|
|
806
|
+
svld1_f32(predicate_all_b32x,
|
|
804
807
|
output_accumulator + query_index * head_dim_padded + (dim_tile + 3) * 16),
|
|
805
|
-
svread_hor_za32_f32_m(svdup_f32(0),
|
|
808
|
+
svread_hor_za32_f32_m(svdup_f32(0), predicate_all_b32x, 3, query_index)));
|
|
806
809
|
}
|
|
807
810
|
}
|
|
808
811
|
// Remainder: 1 dim_tile at a time using ZA0
|
|
809
812
|
for (; dim_tile < dim_tile_count; dim_tile++) {
|
|
810
813
|
svzero_mask_za(nk_sme_zero_za32_tile_0_);
|
|
811
814
|
svmopa_za32_bf16_m(
|
|
812
|
-
0,
|
|
813
|
-
svld1_bf16(
|
|
815
|
+
0, query_predicate_b16x, predicate_all_b16x, probability_column_0_f32x,
|
|
816
|
+
svld1_bf16(predicate_all_b16x, (bfloat16_t const *)(values_block_low + (dim_tile * 8 + 0) * 32)));
|
|
814
817
|
svmopa_za32_bf16_m(
|
|
815
|
-
0,
|
|
816
|
-
svld1_bf16(
|
|
818
|
+
0, query_predicate_b16x, predicate_all_b16x, probability_column_1_f32x,
|
|
819
|
+
svld1_bf16(predicate_all_b16x, (bfloat16_t const *)(values_block_low + (dim_tile * 8 + 1) * 32)));
|
|
817
820
|
svmopa_za32_bf16_m(
|
|
818
|
-
0,
|
|
819
|
-
svld1_bf16(
|
|
821
|
+
0, query_predicate_b16x, predicate_all_b16x, probability_column_2_f32x,
|
|
822
|
+
svld1_bf16(predicate_all_b16x, (bfloat16_t const *)(values_block_low + (dim_tile * 8 + 2) * 32)));
|
|
820
823
|
svmopa_za32_bf16_m(
|
|
821
|
-
0,
|
|
822
|
-
svld1_bf16(
|
|
824
|
+
0, query_predicate_b16x, predicate_all_b16x, probability_column_3_f32x,
|
|
825
|
+
svld1_bf16(predicate_all_b16x, (bfloat16_t const *)(values_block_low + (dim_tile * 8 + 3) * 32)));
|
|
823
826
|
svmopa_za32_bf16_m(
|
|
824
|
-
0,
|
|
825
|
-
svld1_bf16(
|
|
827
|
+
0, query_predicate_b16x, predicate_all_b16x, probability_column_4_f32x,
|
|
828
|
+
svld1_bf16(predicate_all_b16x, (bfloat16_t const *)(values_block_low + (dim_tile * 8 + 4) * 32)));
|
|
826
829
|
svmopa_za32_bf16_m(
|
|
827
|
-
0,
|
|
828
|
-
svld1_bf16(
|
|
830
|
+
0, query_predicate_b16x, predicate_all_b16x, probability_column_5_f32x,
|
|
831
|
+
svld1_bf16(predicate_all_b16x, (bfloat16_t const *)(values_block_low + (dim_tile * 8 + 5) * 32)));
|
|
829
832
|
svmopa_za32_bf16_m(
|
|
830
|
-
0,
|
|
831
|
-
svld1_bf16(
|
|
833
|
+
0, query_predicate_b16x, predicate_all_b16x, probability_column_6_f32x,
|
|
834
|
+
svld1_bf16(predicate_all_b16x, (bfloat16_t const *)(values_block_low + (dim_tile * 8 + 6) * 32)));
|
|
832
835
|
svmopa_za32_bf16_m(
|
|
833
|
-
0,
|
|
834
|
-
svld1_bf16(
|
|
836
|
+
0, query_predicate_b16x, predicate_all_b16x, probability_column_7_f32x,
|
|
837
|
+
svld1_bf16(predicate_all_b16x, (bfloat16_t const *)(values_block_low + (dim_tile * 8 + 7) * 32)));
|
|
835
838
|
svmopa_za32_bf16_m(
|
|
836
|
-
0,
|
|
837
|
-
svld1_bf16(
|
|
839
|
+
0, query_predicate_b16x, predicate_all_b16x, probability_column_8_f32x,
|
|
840
|
+
svld1_bf16(predicate_all_b16x, (bfloat16_t const *)(values_block_high + (dim_tile * 8 + 0) * 32)));
|
|
838
841
|
svmopa_za32_bf16_m(
|
|
839
|
-
0,
|
|
840
|
-
svld1_bf16(
|
|
842
|
+
0, query_predicate_b16x, predicate_all_b16x, probability_column_9_f32x,
|
|
843
|
+
svld1_bf16(predicate_all_b16x, (bfloat16_t const *)(values_block_high + (dim_tile * 8 + 1) * 32)));
|
|
841
844
|
svmopa_za32_bf16_m(
|
|
842
|
-
0,
|
|
843
|
-
svld1_bf16(
|
|
845
|
+
0, query_predicate_b16x, predicate_all_b16x, probability_column_10_f32x,
|
|
846
|
+
svld1_bf16(predicate_all_b16x, (bfloat16_t const *)(values_block_high + (dim_tile * 8 + 2) * 32)));
|
|
844
847
|
svmopa_za32_bf16_m(
|
|
845
|
-
0,
|
|
846
|
-
svld1_bf16(
|
|
848
|
+
0, query_predicate_b16x, predicate_all_b16x, probability_column_11_f32x,
|
|
849
|
+
svld1_bf16(predicate_all_b16x, (bfloat16_t const *)(values_block_high + (dim_tile * 8 + 3) * 32)));
|
|
847
850
|
svmopa_za32_bf16_m(
|
|
848
|
-
0,
|
|
849
|
-
svld1_bf16(
|
|
851
|
+
0, query_predicate_b16x, predicate_all_b16x, probability_column_12_f32x,
|
|
852
|
+
svld1_bf16(predicate_all_b16x, (bfloat16_t const *)(values_block_high + (dim_tile * 8 + 4) * 32)));
|
|
850
853
|
svmopa_za32_bf16_m(
|
|
851
|
-
0,
|
|
852
|
-
svld1_bf16(
|
|
854
|
+
0, query_predicate_b16x, predicate_all_b16x, probability_column_13_f32x,
|
|
855
|
+
svld1_bf16(predicate_all_b16x, (bfloat16_t const *)(values_block_high + (dim_tile * 8 + 5) * 32)));
|
|
853
856
|
svmopa_za32_bf16_m(
|
|
854
|
-
0,
|
|
855
|
-
svld1_bf16(
|
|
857
|
+
0, query_predicate_b16x, predicate_all_b16x, probability_column_14_f32x,
|
|
858
|
+
svld1_bf16(predicate_all_b16x, (bfloat16_t const *)(values_block_high + (dim_tile * 8 + 6) * 32)));
|
|
856
859
|
svmopa_za32_bf16_m(
|
|
857
|
-
0,
|
|
858
|
-
svld1_bf16(
|
|
860
|
+
0, query_predicate_b16x, predicate_all_b16x, probability_column_15_f32x,
|
|
861
|
+
svld1_bf16(predicate_all_b16x, (bfloat16_t const *)(values_block_high + (dim_tile * 8 + 7) * 32)));
|
|
859
862
|
for (nk_size_t query_index = 0; query_index < valid_query_count; query_index++)
|
|
860
|
-
svst1_f32(
|
|
861
|
-
svadd_f32_x(
|
|
862
|
-
svld1_f32(
|
|
863
|
+
svst1_f32(predicate_all_b32x, output_accumulator + query_index * head_dim_padded + dim_tile * 16,
|
|
864
|
+
svadd_f32_x(predicate_all_b32x,
|
|
865
|
+
svld1_f32(predicate_all_b32x,
|
|
863
866
|
output_accumulator + query_index * head_dim_padded + dim_tile * 16),
|
|
864
|
-
svread_hor_za32_f32_m(svdup_f32(0),
|
|
867
|
+
svread_hor_za32_f32_m(svdup_f32(0), predicate_all_b32x, 0, query_index)));
|
|
865
868
|
}
|
|
866
869
|
}
|
|
867
870
|
}
|
|
@@ -874,9 +877,10 @@ __arm_locally_streaming __arm_new("za") static void nk_attention_bf16_sme_stream
|
|
|
874
877
|
svzero_mask_za(nk_sme_zero_za32_tile_2_);
|
|
875
878
|
nk_bf16_t const *k_block = k + kv_block_index * k_depth_step_count * 32;
|
|
876
879
|
for (nk_size_t step = 0; step < k_depth_step_count; step++) {
|
|
877
|
-
svbfloat16_t
|
|
878
|
-
|
|
879
|
-
|
|
880
|
+
svbfloat16_t zn_bf16x = svreinterpret_bf16_f32(
|
|
881
|
+
svld1_f32(predicate_all_b32x, queries_transposed + step * 16));
|
|
882
|
+
svbfloat16_t zm_bf16x = svld1_bf16(predicate_all_b16x, (bfloat16_t const *)(k_block + step * 32));
|
|
883
|
+
svmopa_za32_bf16_m(2, predicate_all_b32x, predicate_all_b32x, zn_bf16x, zm_bf16x);
|
|
880
884
|
}
|
|
881
885
|
|
|
882
886
|
// Pass 1: Column-wise max (read ZA2 columns vertically)
|
|
@@ -884,55 +888,55 @@ __arm_locally_streaming __arm_new("za") static void nk_attention_bf16_sme_stream
|
|
|
884
888
|
svfloat32_t block_max_16_f32x = svdup_f32(NK_F32_MIN);
|
|
885
889
|
for (nk_size_t column_index = 0; column_index < 16; column_index++) {
|
|
886
890
|
svfloat32_t score_column_f32x = svmul_f32_x(
|
|
887
|
-
|
|
891
|
+
predicate_all_b32x, svread_ver_za32_f32_m(svdup_f32(0), predicate_all_b32x, 2, column_index),
|
|
888
892
|
scale_16_f32x);
|
|
889
|
-
block_max_16_f32x = svmax_f32_x(
|
|
893
|
+
block_max_16_f32x = svmax_f32_x(predicate_all_b32x, block_max_16_f32x, score_column_f32x);
|
|
890
894
|
}
|
|
891
895
|
|
|
892
896
|
// Softmax correction (fully vectorized)
|
|
893
|
-
svfloat32_t new_max_f32x = svmax_f32_x(
|
|
894
|
-
svfloat32_t correction_f32x = nk_exp_fast_f32_sve_(
|
|
895
|
-
svsub_f32_x(
|
|
896
|
-
svbool_t
|
|
897
|
-
nk_u32_t max_was_updated_16 = svptest_any(
|
|
898
|
-
if (max_was_updated_16) row_sum_f32x = svmul_f32_x(
|
|
897
|
+
svfloat32_t new_max_f32x = svmax_f32_x(predicate_all_b32x, row_max_f32x, block_max_16_f32x);
|
|
898
|
+
svfloat32_t correction_f32x = nk_exp_fast_f32_sve_(predicate_all_b32x,
|
|
899
|
+
svsub_f32_x(predicate_all_b32x, row_max_f32x, new_max_f32x));
|
|
900
|
+
svbool_t max_changed_16_b32x = svcmplt_f32(predicate_all_b32x, correction_f32x, svdup_f32(1.0f));
|
|
901
|
+
nk_u32_t max_was_updated_16 = svptest_any(predicate_all_b32x, max_changed_16_b32x) ? 1 : 0;
|
|
902
|
+
if (max_was_updated_16) row_sum_f32x = svmul_f32_x(predicate_all_b32x, row_sum_f32x, correction_f32x);
|
|
899
903
|
NK_ALIGN64 nk_f32_t corrections[16];
|
|
900
|
-
svst1_f32(
|
|
904
|
+
svst1_f32(predicate_all_b32x, corrections, correction_f32x);
|
|
901
905
|
|
|
902
906
|
// Pass 2: Column-wise exp + fused P write + sum (ZA2 → ZA0 columns 0-7)
|
|
903
907
|
svfloat32_t sum_delta_16_f32x = svdup_f32(0.0f);
|
|
904
908
|
svzero_mask_za(nk_sme_zero_za32_tile_0_);
|
|
905
909
|
for (nk_size_t column_index = 0; column_index < 16; column_index += 2) {
|
|
906
910
|
svfloat32_t score_even_f32x = svmul_f32_x(
|
|
907
|
-
|
|
911
|
+
predicate_all_b32x, svread_ver_za32_f32_m(svdup_f32(0), predicate_all_b32x, 2, column_index),
|
|
908
912
|
scale_16_f32x);
|
|
909
913
|
svfloat32_t score_odd_f32x = svmul_f32_x(
|
|
910
|
-
|
|
914
|
+
predicate_all_b32x, svread_ver_za32_f32_m(svdup_f32(0), predicate_all_b32x, 2, column_index + 1),
|
|
911
915
|
scale_16_f32x);
|
|
912
916
|
svfloat32_t weight_even_f32x = nk_exp_fast_f32_sve_(
|
|
913
|
-
|
|
917
|
+
predicate_all_b32x, svsub_f32_x(predicate_all_b32x, score_even_f32x, new_max_f32x));
|
|
914
918
|
svfloat32_t weight_odd_f32x = nk_exp_fast_f32_sve_(
|
|
915
|
-
|
|
916
|
-
sum_delta_16_f32x = svadd_f32_x(
|
|
917
|
-
sum_delta_16_f32x = svadd_f32_x(
|
|
918
|
-
svbfloat16_t
|
|
919
|
-
|
|
920
|
-
svwrite_ver_za32_f32_m(0, column_index / 2,
|
|
919
|
+
predicate_all_b32x, svsub_f32_x(predicate_all_b32x, score_odd_f32x, new_max_f32x));
|
|
920
|
+
sum_delta_16_f32x = svadd_f32_x(predicate_all_b32x, sum_delta_16_f32x, weight_even_f32x);
|
|
921
|
+
sum_delta_16_f32x = svadd_f32_x(predicate_all_b32x, sum_delta_16_f32x, weight_odd_f32x);
|
|
922
|
+
svbfloat16_t weight_pair_bf16x = svzip1_bf16(nk_f32_to_bf16_sve_(predicate_all_b32x, weight_even_f32x),
|
|
923
|
+
nk_f32_to_bf16_sve_(predicate_all_b32x, weight_odd_f32x));
|
|
924
|
+
svwrite_ver_za32_f32_m(0, column_index / 2, predicate_all_b32x, svreinterpret_f32_bf16(weight_pair_bf16x));
|
|
921
925
|
}
|
|
922
|
-
row_sum_f32x = svadd_f32_x(
|
|
926
|
+
row_sum_f32x = svadd_f32_x(predicate_all_b32x, row_sum_f32x, sum_delta_16_f32x);
|
|
923
927
|
row_max_f32x = new_max_f32x;
|
|
924
928
|
|
|
925
929
|
if (valid_query_count == 1) {
|
|
926
930
|
// Decode path: extract f32 weights from ZA0 row 0 using SVE
|
|
927
|
-
svbfloat16_t
|
|
928
|
-
svread_hor_za32_f32_m(svdup_f32(0),
|
|
929
|
-
svbfloat16_t
|
|
930
|
-
svbfloat16_t
|
|
931
|
+
svbfloat16_t row0_bf16x = svreinterpret_bf16_f32(
|
|
932
|
+
svread_hor_za32_f32_m(svdup_f32(0), predicate_all_b32x, 0, 0));
|
|
933
|
+
svbfloat16_t weights_even_bf16x = svuzp1_bf16(row0_bf16x, row0_bf16x);
|
|
934
|
+
svbfloat16_t weights_odd_bf16x = svuzp2_bf16(row0_bf16x, row0_bf16x);
|
|
931
935
|
NK_ALIGN64 nk_f32_t decode_weights[16];
|
|
932
936
|
svst1_f32(svwhilelt_b32(0u, 8u), decode_weights,
|
|
933
|
-
nk_bf16_to_f32_sve_(svwhilelt_b32(0u, 8u),
|
|
937
|
+
nk_bf16_to_f32_sve_(svwhilelt_b32(0u, 8u), weights_even_bf16x));
|
|
934
938
|
svst1_f32(svwhilelt_b32(0u, 8u), decode_weights + 8,
|
|
935
|
-
nk_bf16_to_f32_sve_(svwhilelt_b32(0u, 8u),
|
|
939
|
+
nk_bf16_to_f32_sve_(svwhilelt_b32(0u, 8u), weights_odd_bf16x));
|
|
936
940
|
NK_ALIGN64 nk_f32_t decode_weights_ordered[16];
|
|
937
941
|
for (nk_size_t i = 0; i < 8; i++) {
|
|
938
942
|
decode_weights_ordered[2 * i] = decode_weights[i];
|
|
@@ -940,42 +944,42 @@ __arm_locally_streaming __arm_new("za") static void nk_attention_bf16_sme_stream
|
|
|
940
944
|
}
|
|
941
945
|
svfloat32_t corr_f32x = svdup_f32(corrections[0]);
|
|
942
946
|
for (nk_size_t d = 0; d < head_dim; d += svcntw()) {
|
|
943
|
-
svbool_t
|
|
944
|
-
svfloat32_t acc_f32x = svmul_f32_x(
|
|
947
|
+
svbool_t predicate_b32x = svwhilelt_b32_u64(d, head_dim);
|
|
948
|
+
svfloat32_t acc_f32x = svmul_f32_x(predicate_b32x, svld1_f32(predicate_b32x, output_accumulator + d),
|
|
945
949
|
corr_f32x);
|
|
946
950
|
for (nk_size_t ki = 0; ki < valid_kv; ki++) {
|
|
947
951
|
nk_size_t dim_tile = d / 16, depth_s = ki / 2, sub = ki % 2;
|
|
948
952
|
nk_bf16_t const *v_vec = v_packed +
|
|
949
953
|
(kv_block_index * dim_tile_count * 8 + dim_tile * 8 + depth_s) * 32;
|
|
950
|
-
svbfloat16_t packed_bf16x = svld1_bf16(
|
|
951
|
-
svbfloat16_t
|
|
952
|
-
|
|
953
|
-
acc_f32x = svmla_f32_x(
|
|
954
|
-
nk_bf16_to_f32_sve_(
|
|
954
|
+
svbfloat16_t packed_bf16x = svld1_bf16(predicate_all_b16x, (bfloat16_t const *)v_vec);
|
|
955
|
+
svbfloat16_t v_selected_bf16x = (sub == 0) ? svuzp1_bf16(packed_bf16x, packed_bf16x)
|
|
956
|
+
: svuzp2_bf16(packed_bf16x, packed_bf16x);
|
|
957
|
+
acc_f32x = svmla_f32_x(predicate_b32x, acc_f32x, svdup_f32(decode_weights_ordered[ki]),
|
|
958
|
+
nk_bf16_to_f32_sve_(predicate_b32x, v_selected_bf16x));
|
|
955
959
|
}
|
|
956
|
-
svst1_f32(
|
|
960
|
+
svst1_f32(predicate_b32x, output_accumulator + d, acc_f32x);
|
|
957
961
|
}
|
|
958
962
|
}
|
|
959
963
|
else {
|
|
960
964
|
// Prefill Bc=16: extract P columns, pre-apply correction, add-after P×V
|
|
961
|
-
svbool_t
|
|
965
|
+
svbool_t query_predicate_b16x = svwhilelt_b16_u64(0u, valid_query_count * 2);
|
|
962
966
|
|
|
963
967
|
svbfloat16_t probability_column_0_f32x = svreinterpret_bf16_f32(
|
|
964
|
-
svread_ver_za32_f32_m(svdup_f32(0),
|
|
968
|
+
svread_ver_za32_f32_m(svdup_f32(0), predicate_all_b32x, 0, 0));
|
|
965
969
|
svbfloat16_t probability_column_1_f32x = svreinterpret_bf16_f32(
|
|
966
|
-
svread_ver_za32_f32_m(svdup_f32(0),
|
|
970
|
+
svread_ver_za32_f32_m(svdup_f32(0), predicate_all_b32x, 0, 1));
|
|
967
971
|
svbfloat16_t probability_column_2_f32x = svreinterpret_bf16_f32(
|
|
968
|
-
svread_ver_za32_f32_m(svdup_f32(0),
|
|
972
|
+
svread_ver_za32_f32_m(svdup_f32(0), predicate_all_b32x, 0, 2));
|
|
969
973
|
svbfloat16_t probability_column_3_f32x = svreinterpret_bf16_f32(
|
|
970
|
-
svread_ver_za32_f32_m(svdup_f32(0),
|
|
974
|
+
svread_ver_za32_f32_m(svdup_f32(0), predicate_all_b32x, 0, 3));
|
|
971
975
|
svbfloat16_t probability_column_4_f32x = svreinterpret_bf16_f32(
|
|
972
|
-
svread_ver_za32_f32_m(svdup_f32(0),
|
|
976
|
+
svread_ver_za32_f32_m(svdup_f32(0), predicate_all_b32x, 0, 4));
|
|
973
977
|
svbfloat16_t probability_column_5_f32x = svreinterpret_bf16_f32(
|
|
974
|
-
svread_ver_za32_f32_m(svdup_f32(0),
|
|
978
|
+
svread_ver_za32_f32_m(svdup_f32(0), predicate_all_b32x, 0, 5));
|
|
975
979
|
svbfloat16_t probability_column_6_f32x = svreinterpret_bf16_f32(
|
|
976
|
-
svread_ver_za32_f32_m(svdup_f32(0),
|
|
980
|
+
svread_ver_za32_f32_m(svdup_f32(0), predicate_all_b32x, 0, 6));
|
|
977
981
|
svbfloat16_t probability_column_7_f32x = svreinterpret_bf16_f32(
|
|
978
|
-
svread_ver_za32_f32_m(svdup_f32(0),
|
|
982
|
+
svread_ver_za32_f32_m(svdup_f32(0), predicate_all_b32x, 0, 7));
|
|
979
983
|
|
|
980
984
|
nk_bf16_t const *v_block = v_packed + kv_block_index * dim_tile_count * 8 * 32;
|
|
981
985
|
|
|
@@ -985,9 +989,9 @@ __arm_locally_streaming __arm_new("za") static void nk_attention_bf16_sme_stream
|
|
|
985
989
|
svfloat32_t correction_scalar_f32x = svdup_f32(corrections[query_index]);
|
|
986
990
|
for (nk_size_t dim_offset = 0; dim_offset < head_dim_padded; dim_offset += 16)
|
|
987
991
|
svst1_f32(
|
|
988
|
-
|
|
989
|
-
svmul_f32_x(
|
|
990
|
-
svld1_f32(
|
|
992
|
+
predicate_all_b32x, output_accumulator + query_index * head_dim_padded + dim_offset,
|
|
993
|
+
svmul_f32_x(predicate_all_b32x,
|
|
994
|
+
svld1_f32(predicate_all_b32x,
|
|
991
995
|
output_accumulator + query_index * head_dim_padded + dim_offset),
|
|
992
996
|
correction_scalar_f32x));
|
|
993
997
|
}
|
|
@@ -998,183 +1002,183 @@ __arm_locally_streaming __arm_new("za") static void nk_attention_bf16_sme_stream
|
|
|
998
1002
|
for (; dim_tile + 4 <= dim_tile_count; dim_tile += 4) {
|
|
999
1003
|
svzero_za();
|
|
1000
1004
|
svmopa_za32_bf16_m(
|
|
1001
|
-
0,
|
|
1002
|
-
svld1_bf16(
|
|
1005
|
+
0, query_predicate_b16x, predicate_all_b16x, probability_column_0_f32x,
|
|
1006
|
+
svld1_bf16(predicate_all_b16x, (bfloat16_t const *)(v_block + ((dim_tile + 0) * 8 + 0) * 32)));
|
|
1003
1007
|
svmopa_za32_bf16_m(
|
|
1004
|
-
1,
|
|
1005
|
-
svld1_bf16(
|
|
1008
|
+
1, query_predicate_b16x, predicate_all_b16x, probability_column_0_f32x,
|
|
1009
|
+
svld1_bf16(predicate_all_b16x, (bfloat16_t const *)(v_block + ((dim_tile + 1) * 8 + 0) * 32)));
|
|
1006
1010
|
svmopa_za32_bf16_m(
|
|
1007
|
-
2,
|
|
1008
|
-
svld1_bf16(
|
|
1011
|
+
2, query_predicate_b16x, predicate_all_b16x, probability_column_0_f32x,
|
|
1012
|
+
svld1_bf16(predicate_all_b16x, (bfloat16_t const *)(v_block + ((dim_tile + 2) * 8 + 0) * 32)));
|
|
1009
1013
|
svmopa_za32_bf16_m(
|
|
1010
|
-
3,
|
|
1011
|
-
svld1_bf16(
|
|
1014
|
+
3, query_predicate_b16x, predicate_all_b16x, probability_column_0_f32x,
|
|
1015
|
+
svld1_bf16(predicate_all_b16x, (bfloat16_t const *)(v_block + ((dim_tile + 3) * 8 + 0) * 32)));
|
|
1012
1016
|
svmopa_za32_bf16_m(
|
|
1013
|
-
0,
|
|
1014
|
-
svld1_bf16(
|
|
1017
|
+
0, query_predicate_b16x, predicate_all_b16x, probability_column_1_f32x,
|
|
1018
|
+
svld1_bf16(predicate_all_b16x, (bfloat16_t const *)(v_block + ((dim_tile + 0) * 8 + 1) * 32)));
|
|
1015
1019
|
svmopa_za32_bf16_m(
|
|
1016
|
-
1,
|
|
1017
|
-
svld1_bf16(
|
|
1020
|
+
1, query_predicate_b16x, predicate_all_b16x, probability_column_1_f32x,
|
|
1021
|
+
svld1_bf16(predicate_all_b16x, (bfloat16_t const *)(v_block + ((dim_tile + 1) * 8 + 1) * 32)));
|
|
1018
1022
|
svmopa_za32_bf16_m(
|
|
1019
|
-
2,
|
|
1020
|
-
svld1_bf16(
|
|
1023
|
+
2, query_predicate_b16x, predicate_all_b16x, probability_column_1_f32x,
|
|
1024
|
+
svld1_bf16(predicate_all_b16x, (bfloat16_t const *)(v_block + ((dim_tile + 2) * 8 + 1) * 32)));
|
|
1021
1025
|
svmopa_za32_bf16_m(
|
|
1022
|
-
3,
|
|
1023
|
-
svld1_bf16(
|
|
1026
|
+
3, query_predicate_b16x, predicate_all_b16x, probability_column_1_f32x,
|
|
1027
|
+
svld1_bf16(predicate_all_b16x, (bfloat16_t const *)(v_block + ((dim_tile + 3) * 8 + 1) * 32)));
|
|
1024
1028
|
svmopa_za32_bf16_m(
|
|
1025
|
-
0,
|
|
1026
|
-
svld1_bf16(
|
|
1029
|
+
0, query_predicate_b16x, predicate_all_b16x, probability_column_2_f32x,
|
|
1030
|
+
svld1_bf16(predicate_all_b16x, (bfloat16_t const *)(v_block + ((dim_tile + 0) * 8 + 2) * 32)));
|
|
1027
1031
|
svmopa_za32_bf16_m(
|
|
1028
|
-
1,
|
|
1029
|
-
svld1_bf16(
|
|
1032
|
+
1, query_predicate_b16x, predicate_all_b16x, probability_column_2_f32x,
|
|
1033
|
+
svld1_bf16(predicate_all_b16x, (bfloat16_t const *)(v_block + ((dim_tile + 1) * 8 + 2) * 32)));
|
|
1030
1034
|
svmopa_za32_bf16_m(
|
|
1031
|
-
2,
|
|
1032
|
-
svld1_bf16(
|
|
1035
|
+
2, query_predicate_b16x, predicate_all_b16x, probability_column_2_f32x,
|
|
1036
|
+
svld1_bf16(predicate_all_b16x, (bfloat16_t const *)(v_block + ((dim_tile + 2) * 8 + 2) * 32)));
|
|
1033
1037
|
svmopa_za32_bf16_m(
|
|
1034
|
-
3,
|
|
1035
|
-
svld1_bf16(
|
|
1038
|
+
3, query_predicate_b16x, predicate_all_b16x, probability_column_2_f32x,
|
|
1039
|
+
svld1_bf16(predicate_all_b16x, (bfloat16_t const *)(v_block + ((dim_tile + 3) * 8 + 2) * 32)));
|
|
1036
1040
|
svmopa_za32_bf16_m(
|
|
1037
|
-
0,
|
|
1038
|
-
svld1_bf16(
|
|
1041
|
+
0, query_predicate_b16x, predicate_all_b16x, probability_column_3_f32x,
|
|
1042
|
+
svld1_bf16(predicate_all_b16x, (bfloat16_t const *)(v_block + ((dim_tile + 0) * 8 + 3) * 32)));
|
|
1039
1043
|
svmopa_za32_bf16_m(
|
|
1040
|
-
1,
|
|
1041
|
-
svld1_bf16(
|
|
1044
|
+
1, query_predicate_b16x, predicate_all_b16x, probability_column_3_f32x,
|
|
1045
|
+
svld1_bf16(predicate_all_b16x, (bfloat16_t const *)(v_block + ((dim_tile + 1) * 8 + 3) * 32)));
|
|
1042
1046
|
svmopa_za32_bf16_m(
|
|
1043
|
-
2,
|
|
1044
|
-
svld1_bf16(
|
|
1047
|
+
2, query_predicate_b16x, predicate_all_b16x, probability_column_3_f32x,
|
|
1048
|
+
svld1_bf16(predicate_all_b16x, (bfloat16_t const *)(v_block + ((dim_tile + 2) * 8 + 3) * 32)));
|
|
1045
1049
|
svmopa_za32_bf16_m(
|
|
1046
|
-
3,
|
|
1047
|
-
svld1_bf16(
|
|
1050
|
+
3, query_predicate_b16x, predicate_all_b16x, probability_column_3_f32x,
|
|
1051
|
+
svld1_bf16(predicate_all_b16x, (bfloat16_t const *)(v_block + ((dim_tile + 3) * 8 + 3) * 32)));
|
|
1048
1052
|
svmopa_za32_bf16_m(
|
|
1049
|
-
0,
|
|
1050
|
-
svld1_bf16(
|
|
1053
|
+
0, query_predicate_b16x, predicate_all_b16x, probability_column_4_f32x,
|
|
1054
|
+
svld1_bf16(predicate_all_b16x, (bfloat16_t const *)(v_block + ((dim_tile + 0) * 8 + 4) * 32)));
|
|
1051
1055
|
svmopa_za32_bf16_m(
|
|
1052
|
-
1,
|
|
1053
|
-
svld1_bf16(
|
|
1056
|
+
1, query_predicate_b16x, predicate_all_b16x, probability_column_4_f32x,
|
|
1057
|
+
svld1_bf16(predicate_all_b16x, (bfloat16_t const *)(v_block + ((dim_tile + 1) * 8 + 4) * 32)));
|
|
1054
1058
|
svmopa_za32_bf16_m(
|
|
1055
|
-
2,
|
|
1056
|
-
svld1_bf16(
|
|
1059
|
+
2, query_predicate_b16x, predicate_all_b16x, probability_column_4_f32x,
|
|
1060
|
+
svld1_bf16(predicate_all_b16x, (bfloat16_t const *)(v_block + ((dim_tile + 2) * 8 + 4) * 32)));
|
|
1057
1061
|
svmopa_za32_bf16_m(
|
|
1058
|
-
3,
|
|
1059
|
-
svld1_bf16(
|
|
1062
|
+
3, query_predicate_b16x, predicate_all_b16x, probability_column_4_f32x,
|
|
1063
|
+
svld1_bf16(predicate_all_b16x, (bfloat16_t const *)(v_block + ((dim_tile + 3) * 8 + 4) * 32)));
|
|
1060
1064
|
svmopa_za32_bf16_m(
|
|
1061
|
-
0,
|
|
1062
|
-
svld1_bf16(
|
|
1065
|
+
0, query_predicate_b16x, predicate_all_b16x, probability_column_5_f32x,
|
|
1066
|
+
svld1_bf16(predicate_all_b16x, (bfloat16_t const *)(v_block + ((dim_tile + 0) * 8 + 5) * 32)));
|
|
1063
1067
|
svmopa_za32_bf16_m(
|
|
1064
|
-
1,
|
|
1065
|
-
svld1_bf16(
|
|
1068
|
+
1, query_predicate_b16x, predicate_all_b16x, probability_column_5_f32x,
|
|
1069
|
+
svld1_bf16(predicate_all_b16x, (bfloat16_t const *)(v_block + ((dim_tile + 1) * 8 + 5) * 32)));
|
|
1066
1070
|
svmopa_za32_bf16_m(
|
|
1067
|
-
2,
|
|
1068
|
-
svld1_bf16(
|
|
1071
|
+
2, query_predicate_b16x, predicate_all_b16x, probability_column_5_f32x,
|
|
1072
|
+
svld1_bf16(predicate_all_b16x, (bfloat16_t const *)(v_block + ((dim_tile + 2) * 8 + 5) * 32)));
|
|
1069
1073
|
svmopa_za32_bf16_m(
|
|
1070
|
-
3,
|
|
1071
|
-
svld1_bf16(
|
|
1074
|
+
3, query_predicate_b16x, predicate_all_b16x, probability_column_5_f32x,
|
|
1075
|
+
svld1_bf16(predicate_all_b16x, (bfloat16_t const *)(v_block + ((dim_tile + 3) * 8 + 5) * 32)));
|
|
1072
1076
|
svmopa_za32_bf16_m(
|
|
1073
|
-
0,
|
|
1074
|
-
svld1_bf16(
|
|
1077
|
+
0, query_predicate_b16x, predicate_all_b16x, probability_column_6_f32x,
|
|
1078
|
+
svld1_bf16(predicate_all_b16x, (bfloat16_t const *)(v_block + ((dim_tile + 0) * 8 + 6) * 32)));
|
|
1075
1079
|
svmopa_za32_bf16_m(
|
|
1076
|
-
1,
|
|
1077
|
-
svld1_bf16(
|
|
1080
|
+
1, query_predicate_b16x, predicate_all_b16x, probability_column_6_f32x,
|
|
1081
|
+
svld1_bf16(predicate_all_b16x, (bfloat16_t const *)(v_block + ((dim_tile + 1) * 8 + 6) * 32)));
|
|
1078
1082
|
svmopa_za32_bf16_m(
|
|
1079
|
-
2,
|
|
1080
|
-
svld1_bf16(
|
|
1083
|
+
2, query_predicate_b16x, predicate_all_b16x, probability_column_6_f32x,
|
|
1084
|
+
svld1_bf16(predicate_all_b16x, (bfloat16_t const *)(v_block + ((dim_tile + 2) * 8 + 6) * 32)));
|
|
1081
1085
|
svmopa_za32_bf16_m(
|
|
1082
|
-
3,
|
|
1083
|
-
svld1_bf16(
|
|
1086
|
+
3, query_predicate_b16x, predicate_all_b16x, probability_column_6_f32x,
|
|
1087
|
+
svld1_bf16(predicate_all_b16x, (bfloat16_t const *)(v_block + ((dim_tile + 3) * 8 + 6) * 32)));
|
|
1084
1088
|
svmopa_za32_bf16_m(
|
|
1085
|
-
0,
|
|
1086
|
-
svld1_bf16(
|
|
1089
|
+
0, query_predicate_b16x, predicate_all_b16x, probability_column_7_f32x,
|
|
1090
|
+
svld1_bf16(predicate_all_b16x, (bfloat16_t const *)(v_block + ((dim_tile + 0) * 8 + 7) * 32)));
|
|
1087
1091
|
svmopa_za32_bf16_m(
|
|
1088
|
-
1,
|
|
1089
|
-
svld1_bf16(
|
|
1092
|
+
1, query_predicate_b16x, predicate_all_b16x, probability_column_7_f32x,
|
|
1093
|
+
svld1_bf16(predicate_all_b16x, (bfloat16_t const *)(v_block + ((dim_tile + 1) * 8 + 7) * 32)));
|
|
1090
1094
|
svmopa_za32_bf16_m(
|
|
1091
|
-
2,
|
|
1092
|
-
svld1_bf16(
|
|
1095
|
+
2, query_predicate_b16x, predicate_all_b16x, probability_column_7_f32x,
|
|
1096
|
+
svld1_bf16(predicate_all_b16x, (bfloat16_t const *)(v_block + ((dim_tile + 2) * 8 + 7) * 32)));
|
|
1093
1097
|
svmopa_za32_bf16_m(
|
|
1094
|
-
3,
|
|
1095
|
-
svld1_bf16(
|
|
1098
|
+
3, query_predicate_b16x, predicate_all_b16x, probability_column_7_f32x,
|
|
1099
|
+
svld1_bf16(predicate_all_b16x, (bfloat16_t const *)(v_block + ((dim_tile + 3) * 8 + 7) * 32)));
|
|
1096
1100
|
for (nk_size_t query_index = 0; query_index < valid_query_count; query_index++) {
|
|
1097
1101
|
svst1_f32(
|
|
1098
|
-
|
|
1099
|
-
svadd_f32_x(
|
|
1100
|
-
svld1_f32(
|
|
1102
|
+
predicate_all_b32x, output_accumulator + query_index * head_dim_padded + (dim_tile + 0) * 16,
|
|
1103
|
+
svadd_f32_x(predicate_all_b32x,
|
|
1104
|
+
svld1_f32(predicate_all_b32x,
|
|
1101
1105
|
output_accumulator + query_index * head_dim_padded + (dim_tile + 0) * 16),
|
|
1102
|
-
svread_hor_za32_f32_m(svdup_f32(0),
|
|
1106
|
+
svread_hor_za32_f32_m(svdup_f32(0), predicate_all_b32x, 0, query_index)));
|
|
1103
1107
|
svst1_f32(
|
|
1104
|
-
|
|
1105
|
-
svadd_f32_x(
|
|
1106
|
-
svld1_f32(
|
|
1108
|
+
predicate_all_b32x, output_accumulator + query_index * head_dim_padded + (dim_tile + 1) * 16,
|
|
1109
|
+
svadd_f32_x(predicate_all_b32x,
|
|
1110
|
+
svld1_f32(predicate_all_b32x,
|
|
1107
1111
|
output_accumulator + query_index * head_dim_padded + (dim_tile + 1) * 16),
|
|
1108
|
-
svread_hor_za32_f32_m(svdup_f32(0),
|
|
1112
|
+
svread_hor_za32_f32_m(svdup_f32(0), predicate_all_b32x, 1, query_index)));
|
|
1109
1113
|
svst1_f32(
|
|
1110
|
-
|
|
1111
|
-
svadd_f32_x(
|
|
1112
|
-
svld1_f32(
|
|
1114
|
+
predicate_all_b32x, output_accumulator + query_index * head_dim_padded + (dim_tile + 2) * 16,
|
|
1115
|
+
svadd_f32_x(predicate_all_b32x,
|
|
1116
|
+
svld1_f32(predicate_all_b32x,
|
|
1113
1117
|
output_accumulator + query_index * head_dim_padded + (dim_tile + 2) * 16),
|
|
1114
|
-
svread_hor_za32_f32_m(svdup_f32(0),
|
|
1118
|
+
svread_hor_za32_f32_m(svdup_f32(0), predicate_all_b32x, 2, query_index)));
|
|
1115
1119
|
svst1_f32(
|
|
1116
|
-
|
|
1117
|
-
svadd_f32_x(
|
|
1118
|
-
svld1_f32(
|
|
1120
|
+
predicate_all_b32x, output_accumulator + query_index * head_dim_padded + (dim_tile + 3) * 16,
|
|
1121
|
+
svadd_f32_x(predicate_all_b32x,
|
|
1122
|
+
svld1_f32(predicate_all_b32x,
|
|
1119
1123
|
output_accumulator + query_index * head_dim_padded + (dim_tile + 3) * 16),
|
|
1120
|
-
svread_hor_za32_f32_m(svdup_f32(0),
|
|
1124
|
+
svread_hor_za32_f32_m(svdup_f32(0), predicate_all_b32x, 3, query_index)));
|
|
1121
1125
|
}
|
|
1122
1126
|
}
|
|
1123
1127
|
for (; dim_tile < dim_tile_count; dim_tile++) {
|
|
1124
1128
|
svzero_mask_za(nk_sme_zero_za32_tile_0_);
|
|
1125
1129
|
svmopa_za32_bf16_m(
|
|
1126
|
-
0,
|
|
1127
|
-
svld1_bf16(
|
|
1130
|
+
0, query_predicate_b16x, predicate_all_b16x, probability_column_0_f32x,
|
|
1131
|
+
svld1_bf16(predicate_all_b16x, (bfloat16_t const *)(v_block + (dim_tile * 8 + 0) * 32)));
|
|
1128
1132
|
svmopa_za32_bf16_m(
|
|
1129
|
-
0,
|
|
1130
|
-
svld1_bf16(
|
|
1133
|
+
0, query_predicate_b16x, predicate_all_b16x, probability_column_1_f32x,
|
|
1134
|
+
svld1_bf16(predicate_all_b16x, (bfloat16_t const *)(v_block + (dim_tile * 8 + 1) * 32)));
|
|
1131
1135
|
svmopa_za32_bf16_m(
|
|
1132
|
-
0,
|
|
1133
|
-
svld1_bf16(
|
|
1136
|
+
0, query_predicate_b16x, predicate_all_b16x, probability_column_2_f32x,
|
|
1137
|
+
svld1_bf16(predicate_all_b16x, (bfloat16_t const *)(v_block + (dim_tile * 8 + 2) * 32)));
|
|
1134
1138
|
svmopa_za32_bf16_m(
|
|
1135
|
-
0,
|
|
1136
|
-
svld1_bf16(
|
|
1139
|
+
0, query_predicate_b16x, predicate_all_b16x, probability_column_3_f32x,
|
|
1140
|
+
svld1_bf16(predicate_all_b16x, (bfloat16_t const *)(v_block + (dim_tile * 8 + 3) * 32)));
|
|
1137
1141
|
svmopa_za32_bf16_m(
|
|
1138
|
-
0,
|
|
1139
|
-
svld1_bf16(
|
|
1142
|
+
0, query_predicate_b16x, predicate_all_b16x, probability_column_4_f32x,
|
|
1143
|
+
svld1_bf16(predicate_all_b16x, (bfloat16_t const *)(v_block + (dim_tile * 8 + 4) * 32)));
|
|
1140
1144
|
svmopa_za32_bf16_m(
|
|
1141
|
-
0,
|
|
1142
|
-
svld1_bf16(
|
|
1145
|
+
0, query_predicate_b16x, predicate_all_b16x, probability_column_5_f32x,
|
|
1146
|
+
svld1_bf16(predicate_all_b16x, (bfloat16_t const *)(v_block + (dim_tile * 8 + 5) * 32)));
|
|
1143
1147
|
svmopa_za32_bf16_m(
|
|
1144
|
-
0,
|
|
1145
|
-
svld1_bf16(
|
|
1148
|
+
0, query_predicate_b16x, predicate_all_b16x, probability_column_6_f32x,
|
|
1149
|
+
svld1_bf16(predicate_all_b16x, (bfloat16_t const *)(v_block + (dim_tile * 8 + 6) * 32)));
|
|
1146
1150
|
svmopa_za32_bf16_m(
|
|
1147
|
-
0,
|
|
1148
|
-
svld1_bf16(
|
|
1151
|
+
0, query_predicate_b16x, predicate_all_b16x, probability_column_7_f32x,
|
|
1152
|
+
svld1_bf16(predicate_all_b16x, (bfloat16_t const *)(v_block + (dim_tile * 8 + 7) * 32)));
|
|
1149
1153
|
for (nk_size_t query_index = 0; query_index < valid_query_count; query_index++)
|
|
1150
|
-
svst1_f32(
|
|
1151
|
-
svadd_f32_x(
|
|
1152
|
-
svld1_f32(
|
|
1154
|
+
svst1_f32(predicate_all_b32x, output_accumulator + query_index * head_dim_padded + dim_tile * 16,
|
|
1155
|
+
svadd_f32_x(predicate_all_b32x,
|
|
1156
|
+
svld1_f32(predicate_all_b32x,
|
|
1153
1157
|
output_accumulator + query_index * head_dim_padded + dim_tile * 16),
|
|
1154
|
-
svread_hor_za32_f32_m(svdup_f32(0),
|
|
1158
|
+
svread_hor_za32_f32_m(svdup_f32(0), predicate_all_b32x, 0, query_index)));
|
|
1155
1159
|
}
|
|
1156
1160
|
}
|
|
1157
1161
|
}
|
|
1158
1162
|
|
|
1159
1163
|
// Final normalization
|
|
1160
1164
|
NK_ALIGN64 nk_f32_t final_sums[16];
|
|
1161
|
-
svst1_f32(
|
|
1165
|
+
svst1_f32(predicate_all_b32x, final_sums, row_sum_f32x);
|
|
1162
1166
|
|
|
1163
1167
|
for (nk_size_t query_index = 0; query_index < valid_query_count; query_index++) {
|
|
1164
1168
|
nk_f32_t inv_sum = (final_sums[query_index] > 0.0f) ? (1.0f / final_sums[query_index]) : 0.0f;
|
|
1165
1169
|
svfloat32_t inv_sum_f32x = svdup_f32(inv_sum);
|
|
1166
1170
|
|
|
1167
1171
|
for (nk_size_t dim_offset = 0; dim_offset < head_dim; dim_offset += svcntw()) {
|
|
1168
|
-
svbool_t
|
|
1172
|
+
svbool_t predicate_b32x = svwhilelt_b32_u64(dim_offset, head_dim);
|
|
1169
1173
|
svfloat32_t output_f32x = svmul_f32_x(
|
|
1170
|
-
|
|
1171
|
-
svld1_f32(
|
|
1174
|
+
predicate_b32x,
|
|
1175
|
+
svld1_f32(predicate_b32x, output_accumulator + query_index * head_dim_padded + dim_offset),
|
|
1172
1176
|
inv_sum_f32x);
|
|
1173
|
-
svbfloat16_t output_bf16x = nk_f32_to_bf16_sve_(
|
|
1177
|
+
svbfloat16_t output_bf16x = nk_f32_to_bf16_sve_(predicate_b32x, output_f32x);
|
|
1174
1178
|
nk_size_t store_count = (head_dim - dim_offset) < (nk_size_t)svcntw() ? (head_dim - dim_offset)
|
|
1175
1179
|
: (nk_size_t)svcntw();
|
|
1176
|
-
svbool_t
|
|
1177
|
-
svst1_bf16(
|
|
1180
|
+
svbool_t store_predicate_b16x = svwhilelt_b16_u64(0u, store_count);
|
|
1181
|
+
svst1_bf16(store_predicate_b16x, (bfloat16_t *)(output + query_index * head_dim + dim_offset),
|
|
1178
1182
|
output_bf16x);
|
|
1179
1183
|
}
|
|
1180
1184
|
}
|
|
@@ -1220,24 +1224,24 @@ __arm_locally_streaming __arm_new("za") static void nk_attention_f16_sme_streami
|
|
|
1220
1224
|
nk_size_t query_len, nk_size_t kv_len, nk_size_t head_dim, nk_size_t head_dim_padded, nk_size_t dim_tile_count,
|
|
1221
1225
|
nk_f32_t scale) {
|
|
1222
1226
|
|
|
1223
|
-
svbool_t const
|
|
1224
|
-
svbool_t const
|
|
1227
|
+
svbool_t const predicate_all_b32x = svptrue_b32();
|
|
1228
|
+
svbool_t const predicate_all_b16x = svptrue_b16();
|
|
1225
1229
|
nk_size_t const valid_query_count = (query_len < 16) ? query_len : 16;
|
|
1226
1230
|
|
|
1227
1231
|
NK_ALIGN64 nk_f32_t row_max[16];
|
|
1228
1232
|
NK_ALIGN64 nk_f32_t row_sum[16];
|
|
1229
1233
|
NK_ALIGN64 nk_f32_t output_accumulator[16 * 256];
|
|
1230
1234
|
|
|
1231
|
-
svst1_f32(
|
|
1232
|
-
svst1_f32(
|
|
1235
|
+
svst1_f32(predicate_all_b32x, row_max, svdup_f32(NK_F32_MIN));
|
|
1236
|
+
svst1_f32(predicate_all_b32x, row_sum, svdup_f32(0.0f));
|
|
1233
1237
|
svfloat32_t zero_f32x = svdup_f32(0.0f);
|
|
1234
1238
|
for (nk_size_t i = 0; i < 16 * head_dim_padded; i += svcntw()) {
|
|
1235
|
-
svst1_f32(
|
|
1239
|
+
svst1_f32(predicate_all_b32x, output_accumulator + i, zero_f32x);
|
|
1236
1240
|
}
|
|
1237
1241
|
|
|
1238
1242
|
nk_size_t kv_block_index = 0;
|
|
1239
1243
|
nk_size_t kv_start = 0;
|
|
1240
|
-
svbool_t const
|
|
1244
|
+
svbool_t const batch_predicate_b32x = svwhilelt_b32(0u, 16u);
|
|
1241
1245
|
|
|
1242
1246
|
nk_size_t const k_depth_step_count = head_dim_padded / 2;
|
|
1243
1247
|
|
|
@@ -1248,11 +1252,11 @@ __arm_locally_streaming __arm_new("za") static void nk_attention_f16_sme_streami
|
|
|
1248
1252
|
for (nk_size_t batch = 0; batch < head_dim_padded / 32; batch++) {
|
|
1249
1253
|
svzero_mask_za(nk_sme_zero_za32_tile_0_);
|
|
1250
1254
|
for (nk_size_t query_index = 0; query_index < valid_query_count; query_index++)
|
|
1251
|
-
svld1_hor_za32(0, query_index,
|
|
1255
|
+
svld1_hor_za32(0, query_index, batch_predicate_b32x,
|
|
1252
1256
|
(nk_f32_t const *)(q + query_index * head_dim + batch * 32));
|
|
1253
1257
|
for (nk_size_t step = 0; step < 16; step++)
|
|
1254
|
-
svst1_f32(
|
|
1255
|
-
svread_ver_za32_f32_m(svdup_f32(0),
|
|
1258
|
+
svst1_f32(predicate_all_b32x, queries_transposed + (batch * 16 + step) * 16,
|
|
1259
|
+
svread_ver_za32_f32_m(svdup_f32(0), predicate_all_b32x, 0, step));
|
|
1256
1260
|
}
|
|
1257
1261
|
|
|
1258
1262
|
// === Bc=32 main loop (prefill only, skipped for decode) ===
|
|
@@ -1261,14 +1265,15 @@ __arm_locally_streaming __arm_new("za") static void nk_attention_f16_sme_streami
|
|
|
1261
1265
|
// Q×K^T: pure memory→FMOPA, no ZA staging for Q or K
|
|
1262
1266
|
svzero_mask_za(nk_sme_zero_za32_tile_2_);
|
|
1263
1267
|
svzero_mask_za(nk_sme_zero_za32_tile_3_);
|
|
1264
|
-
nk_f16_t const *
|
|
1265
|
-
nk_f16_t const *
|
|
1268
|
+
nk_f16_t const *keys_block_low = k + kv_block_index * k_depth_step_count * 32;
|
|
1269
|
+
nk_f16_t const *keys_block_high = k + (kv_block_index + 1) * k_depth_step_count * 32;
|
|
1266
1270
|
for (nk_size_t step = 0; step < k_depth_step_count; step++) {
|
|
1267
|
-
svfloat16_t
|
|
1268
|
-
|
|
1269
|
-
svfloat16_t
|
|
1270
|
-
|
|
1271
|
-
svmopa_za32_f16_m(
|
|
1271
|
+
svfloat16_t zn_f16x = svreinterpret_f16_f32(
|
|
1272
|
+
svld1_f32(predicate_all_b32x, queries_transposed + step * 16));
|
|
1273
|
+
svfloat16_t zm0_f16x = svld1_f16(predicate_all_b16x, (float16_t const *)(keys_block_low + step * 32));
|
|
1274
|
+
svfloat16_t zm1_f16x = svld1_f16(predicate_all_b16x, (float16_t const *)(keys_block_high + step * 32));
|
|
1275
|
+
svmopa_za32_f16_m(2, predicate_all_b32x, predicate_all_b32x, zn_f16x, zm0_f16x);
|
|
1276
|
+
svmopa_za32_f16_m(3, predicate_all_b32x, predicate_all_b32x, zn_f16x, zm1_f16x);
|
|
1272
1277
|
}
|
|
1273
1278
|
// ZA2 = scores[query_index][0:15], ZA3 = scores[query_index][16:31]
|
|
1274
1279
|
|
|
@@ -1277,29 +1282,29 @@ __arm_locally_streaming __arm_new("za") static void nk_attention_f16_sme_streami
|
|
|
1277
1282
|
svfloat32_t block_max_f32x = svdup_f32(NK_F32_MIN);
|
|
1278
1283
|
for (nk_size_t column_index = 0; column_index < 16; column_index++) {
|
|
1279
1284
|
svfloat32_t score_column_f32x = svmul_f32_x(
|
|
1280
|
-
|
|
1285
|
+
predicate_all_b32x, svread_ver_za32_f32_m(svdup_f32(0), predicate_all_b32x, 2, column_index),
|
|
1281
1286
|
scale_f32x);
|
|
1282
|
-
block_max_f32x = svmax_f32_x(
|
|
1287
|
+
block_max_f32x = svmax_f32_x(predicate_all_b32x, block_max_f32x, score_column_f32x);
|
|
1283
1288
|
}
|
|
1284
1289
|
for (nk_size_t column_index = 0; column_index < 16; column_index++) {
|
|
1285
1290
|
svfloat32_t score_column_f32x = svmul_f32_x(
|
|
1286
|
-
|
|
1291
|
+
predicate_all_b32x, svread_ver_za32_f32_m(svdup_f32(0), predicate_all_b32x, 3, column_index),
|
|
1287
1292
|
scale_f32x);
|
|
1288
|
-
block_max_f32x = svmax_f32_x(
|
|
1293
|
+
block_max_f32x = svmax_f32_x(predicate_all_b32x, block_max_f32x, score_column_f32x);
|
|
1289
1294
|
}
|
|
1290
1295
|
|
|
1291
1296
|
// Softmax correction (vectorized via array load/store)
|
|
1292
|
-
svfloat32_t old_max_f32x = svld1_f32(
|
|
1293
|
-
svfloat32_t new_max_f32x = svmax_f32_x(
|
|
1297
|
+
svfloat32_t old_max_f32x = svld1_f32(predicate_all_b32x, row_max);
|
|
1298
|
+
svfloat32_t new_max_f32x = svmax_f32_x(predicate_all_b32x, old_max_f32x, block_max_f32x);
|
|
1294
1299
|
svfloat32_t correction_f32x = nk_exp_fast_f32_sve_(
|
|
1295
|
-
|
|
1296
|
-
svbool_t
|
|
1297
|
-
nk_u32_t max_was_updated = svptest_any(
|
|
1298
|
-
svfloat32_t row_sum_corrected_f32x = svld1_f32(
|
|
1300
|
+
predicate_all_b32x, svsub_f32_x(predicate_all_b32x, old_max_f32x, new_max_f32x));
|
|
1301
|
+
svbool_t max_changed_b32x = svcmplt_f32(predicate_all_b32x, correction_f32x, svdup_f32(1.0f));
|
|
1302
|
+
nk_u32_t max_was_updated = svptest_any(predicate_all_b32x, max_changed_b32x) ? 1 : 0;
|
|
1303
|
+
svfloat32_t row_sum_corrected_f32x = svld1_f32(predicate_all_b32x, row_sum);
|
|
1299
1304
|
if (max_was_updated)
|
|
1300
|
-
row_sum_corrected_f32x = svmul_f32_x(
|
|
1305
|
+
row_sum_corrected_f32x = svmul_f32_x(predicate_all_b32x, row_sum_corrected_f32x, correction_f32x);
|
|
1301
1306
|
NK_ALIGN64 nk_f32_t corrections[16];
|
|
1302
|
-
svst1_f32(
|
|
1307
|
+
svst1_f32(predicate_all_b32x, corrections, correction_f32x);
|
|
1303
1308
|
|
|
1304
1309
|
// Pass 2: Column-wise exp + fused P write + sum
|
|
1305
1310
|
svfloat32_t sum_delta_f32x = svdup_f32(0.0f);
|
|
@@ -1307,92 +1312,92 @@ __arm_locally_streaming __arm_new("za") static void nk_attention_f16_sme_streami
|
|
|
1307
1312
|
// ZA2 columns in pairs -> ZA0 columns 0-7
|
|
1308
1313
|
for (nk_size_t column_index = 0; column_index < 16; column_index += 2) {
|
|
1309
1314
|
svfloat32_t score_even_f32x = svmul_f32_x(
|
|
1310
|
-
|
|
1315
|
+
predicate_all_b32x, svread_ver_za32_f32_m(svdup_f32(0), predicate_all_b32x, 2, column_index),
|
|
1311
1316
|
scale_f32x);
|
|
1312
1317
|
svfloat32_t score_odd_f32x = svmul_f32_x(
|
|
1313
|
-
|
|
1318
|
+
predicate_all_b32x, svread_ver_za32_f32_m(svdup_f32(0), predicate_all_b32x, 2, column_index + 1),
|
|
1314
1319
|
scale_f32x);
|
|
1315
1320
|
svfloat32_t weight_even_f32x = nk_exp_fast_f32_sve_(
|
|
1316
|
-
|
|
1321
|
+
predicate_all_b32x, svsub_f32_x(predicate_all_b32x, score_even_f32x, new_max_f32x));
|
|
1317
1322
|
svfloat32_t weight_odd_f32x = nk_exp_fast_f32_sve_(
|
|
1318
|
-
|
|
1319
|
-
sum_delta_f32x = svadd_f32_x(
|
|
1320
|
-
sum_delta_f32x = svadd_f32_x(
|
|
1321
|
-
svfloat16_t weight_pair_f16x = svzip1_f16(svcvt_f16_f32_x(
|
|
1322
|
-
svcvt_f16_f32_x(
|
|
1323
|
-
svwrite_ver_za32_f32_m(0, column_index / 2,
|
|
1323
|
+
predicate_all_b32x, svsub_f32_x(predicate_all_b32x, score_odd_f32x, new_max_f32x));
|
|
1324
|
+
sum_delta_f32x = svadd_f32_x(predicate_all_b32x, sum_delta_f32x, weight_even_f32x);
|
|
1325
|
+
sum_delta_f32x = svadd_f32_x(predicate_all_b32x, sum_delta_f32x, weight_odd_f32x);
|
|
1326
|
+
svfloat16_t weight_pair_f16x = svzip1_f16(svcvt_f16_f32_x(predicate_all_b32x, weight_even_f32x),
|
|
1327
|
+
svcvt_f16_f32_x(predicate_all_b32x, weight_odd_f32x));
|
|
1328
|
+
svwrite_ver_za32_f32_m(0, column_index / 2, predicate_all_b32x,
|
|
1324
1329
|
svreinterpret_f32_f16(weight_pair_f16x));
|
|
1325
1330
|
}
|
|
1326
1331
|
// ZA3 columns in pairs -> ZA0 columns 8-15
|
|
1327
1332
|
for (nk_size_t column_index = 0; column_index < 16; column_index += 2) {
|
|
1328
1333
|
svfloat32_t score_even_f32x = svmul_f32_x(
|
|
1329
|
-
|
|
1334
|
+
predicate_all_b32x, svread_ver_za32_f32_m(svdup_f32(0), predicate_all_b32x, 3, column_index),
|
|
1330
1335
|
scale_f32x);
|
|
1331
1336
|
svfloat32_t score_odd_f32x = svmul_f32_x(
|
|
1332
|
-
|
|
1337
|
+
predicate_all_b32x, svread_ver_za32_f32_m(svdup_f32(0), predicate_all_b32x, 3, column_index + 1),
|
|
1333
1338
|
scale_f32x);
|
|
1334
1339
|
svfloat32_t weight_even_f32x = nk_exp_fast_f32_sve_(
|
|
1335
|
-
|
|
1340
|
+
predicate_all_b32x, svsub_f32_x(predicate_all_b32x, score_even_f32x, new_max_f32x));
|
|
1336
1341
|
svfloat32_t weight_odd_f32x = nk_exp_fast_f32_sve_(
|
|
1337
|
-
|
|
1338
|
-
sum_delta_f32x = svadd_f32_x(
|
|
1339
|
-
sum_delta_f32x = svadd_f32_x(
|
|
1340
|
-
svfloat16_t weight_pair_f16x = svzip1_f16(svcvt_f16_f32_x(
|
|
1341
|
-
svcvt_f16_f32_x(
|
|
1342
|
-
svwrite_ver_za32_f32_m(0, 8 + column_index / 2,
|
|
1342
|
+
predicate_all_b32x, svsub_f32_x(predicate_all_b32x, score_odd_f32x, new_max_f32x));
|
|
1343
|
+
sum_delta_f32x = svadd_f32_x(predicate_all_b32x, sum_delta_f32x, weight_even_f32x);
|
|
1344
|
+
sum_delta_f32x = svadd_f32_x(predicate_all_b32x, sum_delta_f32x, weight_odd_f32x);
|
|
1345
|
+
svfloat16_t weight_pair_f16x = svzip1_f16(svcvt_f16_f32_x(predicate_all_b32x, weight_even_f32x),
|
|
1346
|
+
svcvt_f16_f32_x(predicate_all_b32x, weight_odd_f32x));
|
|
1347
|
+
svwrite_ver_za32_f32_m(0, 8 + column_index / 2, predicate_all_b32x,
|
|
1343
1348
|
svreinterpret_f32_f16(weight_pair_f16x));
|
|
1344
1349
|
}
|
|
1345
|
-
row_sum_corrected_f32x = svadd_f32_x(
|
|
1346
|
-
svst1_f32(
|
|
1347
|
-
svst1_f32(
|
|
1350
|
+
row_sum_corrected_f32x = svadd_f32_x(predicate_all_b32x, row_sum_corrected_f32x, sum_delta_f32x);
|
|
1351
|
+
svst1_f32(predicate_all_b32x, row_sum, row_sum_corrected_f32x);
|
|
1352
|
+
svst1_f32(predicate_all_b32x, row_max, new_max_f32x);
|
|
1348
1353
|
|
|
1349
1354
|
// Extract P columns from ZA0
|
|
1350
1355
|
svfloat16_t probability_column_0_f32x = svreinterpret_f16_f32(
|
|
1351
|
-
svread_ver_za32_f32_m(svdup_f32(0),
|
|
1356
|
+
svread_ver_za32_f32_m(svdup_f32(0), predicate_all_b32x, 0, 0));
|
|
1352
1357
|
svfloat16_t probability_column_1_f32x = svreinterpret_f16_f32(
|
|
1353
|
-
svread_ver_za32_f32_m(svdup_f32(0),
|
|
1358
|
+
svread_ver_za32_f32_m(svdup_f32(0), predicate_all_b32x, 0, 1));
|
|
1354
1359
|
svfloat16_t probability_column_2_f32x = svreinterpret_f16_f32(
|
|
1355
|
-
svread_ver_za32_f32_m(svdup_f32(0),
|
|
1360
|
+
svread_ver_za32_f32_m(svdup_f32(0), predicate_all_b32x, 0, 2));
|
|
1356
1361
|
svfloat16_t probability_column_3_f32x = svreinterpret_f16_f32(
|
|
1357
|
-
svread_ver_za32_f32_m(svdup_f32(0),
|
|
1362
|
+
svread_ver_za32_f32_m(svdup_f32(0), predicate_all_b32x, 0, 3));
|
|
1358
1363
|
svfloat16_t probability_column_4_f32x = svreinterpret_f16_f32(
|
|
1359
|
-
svread_ver_za32_f32_m(svdup_f32(0),
|
|
1364
|
+
svread_ver_za32_f32_m(svdup_f32(0), predicate_all_b32x, 0, 4));
|
|
1360
1365
|
svfloat16_t probability_column_5_f32x = svreinterpret_f16_f32(
|
|
1361
|
-
svread_ver_za32_f32_m(svdup_f32(0),
|
|
1366
|
+
svread_ver_za32_f32_m(svdup_f32(0), predicate_all_b32x, 0, 5));
|
|
1362
1367
|
svfloat16_t probability_column_6_f32x = svreinterpret_f16_f32(
|
|
1363
|
-
svread_ver_za32_f32_m(svdup_f32(0),
|
|
1368
|
+
svread_ver_za32_f32_m(svdup_f32(0), predicate_all_b32x, 0, 6));
|
|
1364
1369
|
svfloat16_t probability_column_7_f32x = svreinterpret_f16_f32(
|
|
1365
|
-
svread_ver_za32_f32_m(svdup_f32(0),
|
|
1370
|
+
svread_ver_za32_f32_m(svdup_f32(0), predicate_all_b32x, 0, 7));
|
|
1366
1371
|
svfloat16_t probability_column_8_f32x = svreinterpret_f16_f32(
|
|
1367
|
-
svread_ver_za32_f32_m(svdup_f32(0),
|
|
1372
|
+
svread_ver_za32_f32_m(svdup_f32(0), predicate_all_b32x, 0, 8));
|
|
1368
1373
|
svfloat16_t probability_column_9_f32x = svreinterpret_f16_f32(
|
|
1369
|
-
svread_ver_za32_f32_m(svdup_f32(0),
|
|
1374
|
+
svread_ver_za32_f32_m(svdup_f32(0), predicate_all_b32x, 0, 9));
|
|
1370
1375
|
svfloat16_t probability_column_10_f32x = svreinterpret_f16_f32(
|
|
1371
|
-
svread_ver_za32_f32_m(svdup_f32(0),
|
|
1376
|
+
svread_ver_za32_f32_m(svdup_f32(0), predicate_all_b32x, 0, 10));
|
|
1372
1377
|
svfloat16_t probability_column_11_f32x = svreinterpret_f16_f32(
|
|
1373
|
-
svread_ver_za32_f32_m(svdup_f32(0),
|
|
1378
|
+
svread_ver_za32_f32_m(svdup_f32(0), predicate_all_b32x, 0, 11));
|
|
1374
1379
|
svfloat16_t probability_column_12_f32x = svreinterpret_f16_f32(
|
|
1375
|
-
svread_ver_za32_f32_m(svdup_f32(0),
|
|
1380
|
+
svread_ver_za32_f32_m(svdup_f32(0), predicate_all_b32x, 0, 12));
|
|
1376
1381
|
svfloat16_t probability_column_13_f32x = svreinterpret_f16_f32(
|
|
1377
|
-
svread_ver_za32_f32_m(svdup_f32(0),
|
|
1382
|
+
svread_ver_za32_f32_m(svdup_f32(0), predicate_all_b32x, 0, 13));
|
|
1378
1383
|
svfloat16_t probability_column_14_f32x = svreinterpret_f16_f32(
|
|
1379
|
-
svread_ver_za32_f32_m(svdup_f32(0),
|
|
1384
|
+
svread_ver_za32_f32_m(svdup_f32(0), predicate_all_b32x, 0, 14));
|
|
1380
1385
|
svfloat16_t probability_column_15_f32x = svreinterpret_f16_f32(
|
|
1381
|
-
svread_ver_za32_f32_m(svdup_f32(0),
|
|
1386
|
+
svread_ver_za32_f32_m(svdup_f32(0), predicate_all_b32x, 0, 15));
|
|
1382
1387
|
|
|
1383
1388
|
// Pre-apply correction once before P×V
|
|
1384
|
-
svbool_t
|
|
1385
|
-
nk_f16_t const *
|
|
1386
|
-
nk_f16_t const *
|
|
1389
|
+
svbool_t query_predicate_b16x = svwhilelt_b16_u64(0u, valid_query_count * 2);
|
|
1390
|
+
nk_f16_t const *values_block_low = v_packed + kv_block_index * dim_tile_count * 8 * 32;
|
|
1391
|
+
nk_f16_t const *values_block_high = v_packed + (kv_block_index + 1) * dim_tile_count * 8 * 32;
|
|
1387
1392
|
|
|
1388
1393
|
if (max_was_updated) {
|
|
1389
1394
|
for (nk_size_t query_index = 0; query_index < valid_query_count; query_index++) {
|
|
1390
1395
|
svfloat32_t correction_scalar_f32x = svdup_f32(corrections[query_index]);
|
|
1391
1396
|
for (nk_size_t dim_offset = 0; dim_offset < head_dim_padded; dim_offset += 16)
|
|
1392
1397
|
svst1_f32(
|
|
1393
|
-
|
|
1394
|
-
svmul_f32_x(
|
|
1395
|
-
svld1_f32(
|
|
1398
|
+
predicate_all_b32x, output_accumulator + query_index * head_dim_padded + dim_offset,
|
|
1399
|
+
svmul_f32_x(predicate_all_b32x,
|
|
1400
|
+
svld1_f32(predicate_all_b32x,
|
|
1396
1401
|
output_accumulator + query_index * head_dim_padded + dim_offset),
|
|
1397
1402
|
correction_scalar_f32x));
|
|
1398
1403
|
}
|
|
@@ -1403,284 +1408,284 @@ __arm_locally_streaming __arm_new("za") static void nk_attention_f16_sme_streami
|
|
|
1403
1408
|
for (; dim_tile + 4 <= dim_tile_count; dim_tile += 4) {
|
|
1404
1409
|
svzero_za();
|
|
1405
1410
|
// Block0: 8 depth steps (KV positions 0-15)
|
|
1406
|
-
svmopa_za32_f16_m(0,
|
|
1407
|
-
svld1_f16(
|
|
1408
|
-
(float16_t const *)(
|
|
1409
|
-
svmopa_za32_f16_m(1,
|
|
1410
|
-
svld1_f16(
|
|
1411
|
-
(float16_t const *)(
|
|
1412
|
-
svmopa_za32_f16_m(2,
|
|
1413
|
-
svld1_f16(
|
|
1414
|
-
(float16_t const *)(
|
|
1415
|
-
svmopa_za32_f16_m(3,
|
|
1416
|
-
svld1_f16(
|
|
1417
|
-
(float16_t const *)(
|
|
1418
|
-
svmopa_za32_f16_m(0,
|
|
1419
|
-
svld1_f16(
|
|
1420
|
-
(float16_t const *)(
|
|
1421
|
-
svmopa_za32_f16_m(1,
|
|
1422
|
-
svld1_f16(
|
|
1423
|
-
(float16_t const *)(
|
|
1424
|
-
svmopa_za32_f16_m(2,
|
|
1425
|
-
svld1_f16(
|
|
1426
|
-
(float16_t const *)(
|
|
1427
|
-
svmopa_za32_f16_m(3,
|
|
1428
|
-
svld1_f16(
|
|
1429
|
-
(float16_t const *)(
|
|
1430
|
-
svmopa_za32_f16_m(0,
|
|
1431
|
-
svld1_f16(
|
|
1432
|
-
(float16_t const *)(
|
|
1433
|
-
svmopa_za32_f16_m(1,
|
|
1434
|
-
svld1_f16(
|
|
1435
|
-
(float16_t const *)(
|
|
1436
|
-
svmopa_za32_f16_m(2,
|
|
1437
|
-
svld1_f16(
|
|
1438
|
-
(float16_t const *)(
|
|
1439
|
-
svmopa_za32_f16_m(3,
|
|
1440
|
-
svld1_f16(
|
|
1441
|
-
(float16_t const *)(
|
|
1442
|
-
svmopa_za32_f16_m(0,
|
|
1443
|
-
svld1_f16(
|
|
1444
|
-
(float16_t const *)(
|
|
1445
|
-
svmopa_za32_f16_m(1,
|
|
1446
|
-
svld1_f16(
|
|
1447
|
-
(float16_t const *)(
|
|
1448
|
-
svmopa_za32_f16_m(2,
|
|
1449
|
-
svld1_f16(
|
|
1450
|
-
(float16_t const *)(
|
|
1451
|
-
svmopa_za32_f16_m(3,
|
|
1452
|
-
svld1_f16(
|
|
1453
|
-
(float16_t const *)(
|
|
1454
|
-
svmopa_za32_f16_m(0,
|
|
1455
|
-
svld1_f16(
|
|
1456
|
-
(float16_t const *)(
|
|
1457
|
-
svmopa_za32_f16_m(1,
|
|
1458
|
-
svld1_f16(
|
|
1459
|
-
(float16_t const *)(
|
|
1460
|
-
svmopa_za32_f16_m(2,
|
|
1461
|
-
svld1_f16(
|
|
1462
|
-
(float16_t const *)(
|
|
1463
|
-
svmopa_za32_f16_m(3,
|
|
1464
|
-
svld1_f16(
|
|
1465
|
-
(float16_t const *)(
|
|
1466
|
-
svmopa_za32_f16_m(0,
|
|
1467
|
-
svld1_f16(
|
|
1468
|
-
(float16_t const *)(
|
|
1469
|
-
svmopa_za32_f16_m(1,
|
|
1470
|
-
svld1_f16(
|
|
1471
|
-
(float16_t const *)(
|
|
1472
|
-
svmopa_za32_f16_m(2,
|
|
1473
|
-
svld1_f16(
|
|
1474
|
-
(float16_t const *)(
|
|
1475
|
-
svmopa_za32_f16_m(3,
|
|
1476
|
-
svld1_f16(
|
|
1477
|
-
(float16_t const *)(
|
|
1478
|
-
svmopa_za32_f16_m(0,
|
|
1479
|
-
svld1_f16(
|
|
1480
|
-
(float16_t const *)(
|
|
1481
|
-
svmopa_za32_f16_m(1,
|
|
1482
|
-
svld1_f16(
|
|
1483
|
-
(float16_t const *)(
|
|
1484
|
-
svmopa_za32_f16_m(2,
|
|
1485
|
-
svld1_f16(
|
|
1486
|
-
(float16_t const *)(
|
|
1487
|
-
svmopa_za32_f16_m(3,
|
|
1488
|
-
svld1_f16(
|
|
1489
|
-
(float16_t const *)(
|
|
1490
|
-
svmopa_za32_f16_m(0,
|
|
1491
|
-
svld1_f16(
|
|
1492
|
-
(float16_t const *)(
|
|
1493
|
-
svmopa_za32_f16_m(1,
|
|
1494
|
-
svld1_f16(
|
|
1495
|
-
(float16_t const *)(
|
|
1496
|
-
svmopa_za32_f16_m(2,
|
|
1497
|
-
svld1_f16(
|
|
1498
|
-
(float16_t const *)(
|
|
1499
|
-
svmopa_za32_f16_m(3,
|
|
1500
|
-
svld1_f16(
|
|
1501
|
-
(float16_t const *)(
|
|
1411
|
+
svmopa_za32_f16_m(0, query_predicate_b16x, predicate_all_b16x, probability_column_0_f32x,
|
|
1412
|
+
svld1_f16(predicate_all_b16x,
|
|
1413
|
+
(float16_t const *)(values_block_low + ((dim_tile + 0) * 8 + 0) * 32)));
|
|
1414
|
+
svmopa_za32_f16_m(1, query_predicate_b16x, predicate_all_b16x, probability_column_0_f32x,
|
|
1415
|
+
svld1_f16(predicate_all_b16x,
|
|
1416
|
+
(float16_t const *)(values_block_low + ((dim_tile + 1) * 8 + 0) * 32)));
|
|
1417
|
+
svmopa_za32_f16_m(2, query_predicate_b16x, predicate_all_b16x, probability_column_0_f32x,
|
|
1418
|
+
svld1_f16(predicate_all_b16x,
|
|
1419
|
+
(float16_t const *)(values_block_low + ((dim_tile + 2) * 8 + 0) * 32)));
|
|
1420
|
+
svmopa_za32_f16_m(3, query_predicate_b16x, predicate_all_b16x, probability_column_0_f32x,
|
|
1421
|
+
svld1_f16(predicate_all_b16x,
|
|
1422
|
+
(float16_t const *)(values_block_low + ((dim_tile + 3) * 8 + 0) * 32)));
|
|
1423
|
+
svmopa_za32_f16_m(0, query_predicate_b16x, predicate_all_b16x, probability_column_1_f32x,
|
|
1424
|
+
svld1_f16(predicate_all_b16x,
|
|
1425
|
+
(float16_t const *)(values_block_low + ((dim_tile + 0) * 8 + 1) * 32)));
|
|
1426
|
+
svmopa_za32_f16_m(1, query_predicate_b16x, predicate_all_b16x, probability_column_1_f32x,
|
|
1427
|
+
svld1_f16(predicate_all_b16x,
|
|
1428
|
+
(float16_t const *)(values_block_low + ((dim_tile + 1) * 8 + 1) * 32)));
|
|
1429
|
+
svmopa_za32_f16_m(2, query_predicate_b16x, predicate_all_b16x, probability_column_1_f32x,
|
|
1430
|
+
svld1_f16(predicate_all_b16x,
|
|
1431
|
+
(float16_t const *)(values_block_low + ((dim_tile + 2) * 8 + 1) * 32)));
|
|
1432
|
+
svmopa_za32_f16_m(3, query_predicate_b16x, predicate_all_b16x, probability_column_1_f32x,
|
|
1433
|
+
svld1_f16(predicate_all_b16x,
|
|
1434
|
+
(float16_t const *)(values_block_low + ((dim_tile + 3) * 8 + 1) * 32)));
|
|
1435
|
+
svmopa_za32_f16_m(0, query_predicate_b16x, predicate_all_b16x, probability_column_2_f32x,
|
|
1436
|
+
svld1_f16(predicate_all_b16x,
|
|
1437
|
+
(float16_t const *)(values_block_low + ((dim_tile + 0) * 8 + 2) * 32)));
|
|
1438
|
+
svmopa_za32_f16_m(1, query_predicate_b16x, predicate_all_b16x, probability_column_2_f32x,
|
|
1439
|
+
svld1_f16(predicate_all_b16x,
|
|
1440
|
+
(float16_t const *)(values_block_low + ((dim_tile + 1) * 8 + 2) * 32)));
|
|
1441
|
+
svmopa_za32_f16_m(2, query_predicate_b16x, predicate_all_b16x, probability_column_2_f32x,
|
|
1442
|
+
svld1_f16(predicate_all_b16x,
|
|
1443
|
+
(float16_t const *)(values_block_low + ((dim_tile + 2) * 8 + 2) * 32)));
|
|
1444
|
+
svmopa_za32_f16_m(3, query_predicate_b16x, predicate_all_b16x, probability_column_2_f32x,
|
|
1445
|
+
svld1_f16(predicate_all_b16x,
|
|
1446
|
+
(float16_t const *)(values_block_low + ((dim_tile + 3) * 8 + 2) * 32)));
|
|
1447
|
+
svmopa_za32_f16_m(0, query_predicate_b16x, predicate_all_b16x, probability_column_3_f32x,
|
|
1448
|
+
svld1_f16(predicate_all_b16x,
|
|
1449
|
+
(float16_t const *)(values_block_low + ((dim_tile + 0) * 8 + 3) * 32)));
|
|
1450
|
+
svmopa_za32_f16_m(1, query_predicate_b16x, predicate_all_b16x, probability_column_3_f32x,
|
|
1451
|
+
svld1_f16(predicate_all_b16x,
|
|
1452
|
+
(float16_t const *)(values_block_low + ((dim_tile + 1) * 8 + 3) * 32)));
|
|
1453
|
+
svmopa_za32_f16_m(2, query_predicate_b16x, predicate_all_b16x, probability_column_3_f32x,
|
|
1454
|
+
svld1_f16(predicate_all_b16x,
|
|
1455
|
+
(float16_t const *)(values_block_low + ((dim_tile + 2) * 8 + 3) * 32)));
|
|
1456
|
+
svmopa_za32_f16_m(3, query_predicate_b16x, predicate_all_b16x, probability_column_3_f32x,
|
|
1457
|
+
svld1_f16(predicate_all_b16x,
|
|
1458
|
+
(float16_t const *)(values_block_low + ((dim_tile + 3) * 8 + 3) * 32)));
|
|
1459
|
+
svmopa_za32_f16_m(0, query_predicate_b16x, predicate_all_b16x, probability_column_4_f32x,
|
|
1460
|
+
svld1_f16(predicate_all_b16x,
|
|
1461
|
+
(float16_t const *)(values_block_low + ((dim_tile + 0) * 8 + 4) * 32)));
|
|
1462
|
+
svmopa_za32_f16_m(1, query_predicate_b16x, predicate_all_b16x, probability_column_4_f32x,
|
|
1463
|
+
svld1_f16(predicate_all_b16x,
|
|
1464
|
+
(float16_t const *)(values_block_low + ((dim_tile + 1) * 8 + 4) * 32)));
|
|
1465
|
+
svmopa_za32_f16_m(2, query_predicate_b16x, predicate_all_b16x, probability_column_4_f32x,
|
|
1466
|
+
svld1_f16(predicate_all_b16x,
|
|
1467
|
+
(float16_t const *)(values_block_low + ((dim_tile + 2) * 8 + 4) * 32)));
|
|
1468
|
+
svmopa_za32_f16_m(3, query_predicate_b16x, predicate_all_b16x, probability_column_4_f32x,
|
|
1469
|
+
svld1_f16(predicate_all_b16x,
|
|
1470
|
+
(float16_t const *)(values_block_low + ((dim_tile + 3) * 8 + 4) * 32)));
|
|
1471
|
+
svmopa_za32_f16_m(0, query_predicate_b16x, predicate_all_b16x, probability_column_5_f32x,
|
|
1472
|
+
svld1_f16(predicate_all_b16x,
|
|
1473
|
+
(float16_t const *)(values_block_low + ((dim_tile + 0) * 8 + 5) * 32)));
|
|
1474
|
+
svmopa_za32_f16_m(1, query_predicate_b16x, predicate_all_b16x, probability_column_5_f32x,
|
|
1475
|
+
svld1_f16(predicate_all_b16x,
|
|
1476
|
+
(float16_t const *)(values_block_low + ((dim_tile + 1) * 8 + 5) * 32)));
|
|
1477
|
+
svmopa_za32_f16_m(2, query_predicate_b16x, predicate_all_b16x, probability_column_5_f32x,
|
|
1478
|
+
svld1_f16(predicate_all_b16x,
|
|
1479
|
+
(float16_t const *)(values_block_low + ((dim_tile + 2) * 8 + 5) * 32)));
|
|
1480
|
+
svmopa_za32_f16_m(3, query_predicate_b16x, predicate_all_b16x, probability_column_5_f32x,
|
|
1481
|
+
svld1_f16(predicate_all_b16x,
|
|
1482
|
+
(float16_t const *)(values_block_low + ((dim_tile + 3) * 8 + 5) * 32)));
|
|
1483
|
+
svmopa_za32_f16_m(0, query_predicate_b16x, predicate_all_b16x, probability_column_6_f32x,
|
|
1484
|
+
svld1_f16(predicate_all_b16x,
|
|
1485
|
+
(float16_t const *)(values_block_low + ((dim_tile + 0) * 8 + 6) * 32)));
|
|
1486
|
+
svmopa_za32_f16_m(1, query_predicate_b16x, predicate_all_b16x, probability_column_6_f32x,
|
|
1487
|
+
svld1_f16(predicate_all_b16x,
|
|
1488
|
+
(float16_t const *)(values_block_low + ((dim_tile + 1) * 8 + 6) * 32)));
|
|
1489
|
+
svmopa_za32_f16_m(2, query_predicate_b16x, predicate_all_b16x, probability_column_6_f32x,
|
|
1490
|
+
svld1_f16(predicate_all_b16x,
|
|
1491
|
+
(float16_t const *)(values_block_low + ((dim_tile + 2) * 8 + 6) * 32)));
|
|
1492
|
+
svmopa_za32_f16_m(3, query_predicate_b16x, predicate_all_b16x, probability_column_6_f32x,
|
|
1493
|
+
svld1_f16(predicate_all_b16x,
|
|
1494
|
+
(float16_t const *)(values_block_low + ((dim_tile + 3) * 8 + 6) * 32)));
|
|
1495
|
+
svmopa_za32_f16_m(0, query_predicate_b16x, predicate_all_b16x, probability_column_7_f32x,
|
|
1496
|
+
svld1_f16(predicate_all_b16x,
|
|
1497
|
+
(float16_t const *)(values_block_low + ((dim_tile + 0) * 8 + 7) * 32)));
|
|
1498
|
+
svmopa_za32_f16_m(1, query_predicate_b16x, predicate_all_b16x, probability_column_7_f32x,
|
|
1499
|
+
svld1_f16(predicate_all_b16x,
|
|
1500
|
+
(float16_t const *)(values_block_low + ((dim_tile + 1) * 8 + 7) * 32)));
|
|
1501
|
+
svmopa_za32_f16_m(2, query_predicate_b16x, predicate_all_b16x, probability_column_7_f32x,
|
|
1502
|
+
svld1_f16(predicate_all_b16x,
|
|
1503
|
+
(float16_t const *)(values_block_low + ((dim_tile + 2) * 8 + 7) * 32)));
|
|
1504
|
+
svmopa_za32_f16_m(3, query_predicate_b16x, predicate_all_b16x, probability_column_7_f32x,
|
|
1505
|
+
svld1_f16(predicate_all_b16x,
|
|
1506
|
+
(float16_t const *)(values_block_low + ((dim_tile + 3) * 8 + 7) * 32)));
|
|
1502
1507
|
// Block1: 8 depth steps (KV positions 16-31)
|
|
1503
|
-
svmopa_za32_f16_m(0,
|
|
1504
|
-
svld1_f16(
|
|
1505
|
-
(float16_t const *)(
|
|
1506
|
-
svmopa_za32_f16_m(1,
|
|
1507
|
-
svld1_f16(
|
|
1508
|
-
(float16_t const *)(
|
|
1509
|
-
svmopa_za32_f16_m(2,
|
|
1510
|
-
svld1_f16(
|
|
1511
|
-
(float16_t const *)(
|
|
1512
|
-
svmopa_za32_f16_m(3,
|
|
1513
|
-
svld1_f16(
|
|
1514
|
-
(float16_t const *)(
|
|
1515
|
-
svmopa_za32_f16_m(0,
|
|
1516
|
-
svld1_f16(
|
|
1517
|
-
(float16_t const *)(
|
|
1518
|
-
svmopa_za32_f16_m(1,
|
|
1519
|
-
svld1_f16(
|
|
1520
|
-
(float16_t const *)(
|
|
1521
|
-
svmopa_za32_f16_m(2,
|
|
1522
|
-
svld1_f16(
|
|
1523
|
-
(float16_t const *)(
|
|
1524
|
-
svmopa_za32_f16_m(3,
|
|
1525
|
-
svld1_f16(
|
|
1526
|
-
(float16_t const *)(
|
|
1527
|
-
svmopa_za32_f16_m(0,
|
|
1528
|
-
svld1_f16(
|
|
1529
|
-
(float16_t const *)(
|
|
1530
|
-
svmopa_za32_f16_m(1,
|
|
1531
|
-
svld1_f16(
|
|
1532
|
-
(float16_t const *)(
|
|
1533
|
-
svmopa_za32_f16_m(2,
|
|
1534
|
-
svld1_f16(
|
|
1535
|
-
(float16_t const *)(
|
|
1536
|
-
svmopa_za32_f16_m(3,
|
|
1537
|
-
svld1_f16(
|
|
1538
|
-
(float16_t const *)(
|
|
1539
|
-
svmopa_za32_f16_m(0,
|
|
1540
|
-
svld1_f16(
|
|
1541
|
-
(float16_t const *)(
|
|
1542
|
-
svmopa_za32_f16_m(1,
|
|
1543
|
-
svld1_f16(
|
|
1544
|
-
(float16_t const *)(
|
|
1545
|
-
svmopa_za32_f16_m(2,
|
|
1546
|
-
svld1_f16(
|
|
1547
|
-
(float16_t const *)(
|
|
1548
|
-
svmopa_za32_f16_m(3,
|
|
1549
|
-
svld1_f16(
|
|
1550
|
-
(float16_t const *)(
|
|
1551
|
-
svmopa_za32_f16_m(0,
|
|
1552
|
-
svld1_f16(
|
|
1553
|
-
(float16_t const *)(
|
|
1554
|
-
svmopa_za32_f16_m(1,
|
|
1555
|
-
svld1_f16(
|
|
1556
|
-
(float16_t const *)(
|
|
1557
|
-
svmopa_za32_f16_m(2,
|
|
1558
|
-
svld1_f16(
|
|
1559
|
-
(float16_t const *)(
|
|
1560
|
-
svmopa_za32_f16_m(3,
|
|
1561
|
-
svld1_f16(
|
|
1562
|
-
(float16_t const *)(
|
|
1563
|
-
svmopa_za32_f16_m(0,
|
|
1564
|
-
svld1_f16(
|
|
1565
|
-
(float16_t const *)(
|
|
1566
|
-
svmopa_za32_f16_m(1,
|
|
1567
|
-
svld1_f16(
|
|
1568
|
-
(float16_t const *)(
|
|
1569
|
-
svmopa_za32_f16_m(2,
|
|
1570
|
-
svld1_f16(
|
|
1571
|
-
(float16_t const *)(
|
|
1572
|
-
svmopa_za32_f16_m(3,
|
|
1573
|
-
svld1_f16(
|
|
1574
|
-
(float16_t const *)(
|
|
1575
|
-
svmopa_za32_f16_m(0,
|
|
1576
|
-
svld1_f16(
|
|
1577
|
-
(float16_t const *)(
|
|
1578
|
-
svmopa_za32_f16_m(1,
|
|
1579
|
-
svld1_f16(
|
|
1580
|
-
(float16_t const *)(
|
|
1581
|
-
svmopa_za32_f16_m(2,
|
|
1582
|
-
svld1_f16(
|
|
1583
|
-
(float16_t const *)(
|
|
1584
|
-
svmopa_za32_f16_m(3,
|
|
1585
|
-
svld1_f16(
|
|
1586
|
-
(float16_t const *)(
|
|
1587
|
-
svmopa_za32_f16_m(0,
|
|
1588
|
-
svld1_f16(
|
|
1589
|
-
(float16_t const *)(
|
|
1590
|
-
svmopa_za32_f16_m(1,
|
|
1591
|
-
svld1_f16(
|
|
1592
|
-
(float16_t const *)(
|
|
1593
|
-
svmopa_za32_f16_m(2,
|
|
1594
|
-
svld1_f16(
|
|
1595
|
-
(float16_t const *)(
|
|
1596
|
-
svmopa_za32_f16_m(3,
|
|
1597
|
-
svld1_f16(
|
|
1598
|
-
(float16_t const *)(
|
|
1508
|
+
svmopa_za32_f16_m(0, query_predicate_b16x, predicate_all_b16x, probability_column_8_f32x,
|
|
1509
|
+
svld1_f16(predicate_all_b16x,
|
|
1510
|
+
(float16_t const *)(values_block_high + ((dim_tile + 0) * 8 + 0) * 32)));
|
|
1511
|
+
svmopa_za32_f16_m(1, query_predicate_b16x, predicate_all_b16x, probability_column_8_f32x,
|
|
1512
|
+
svld1_f16(predicate_all_b16x,
|
|
1513
|
+
(float16_t const *)(values_block_high + ((dim_tile + 1) * 8 + 0) * 32)));
|
|
1514
|
+
svmopa_za32_f16_m(2, query_predicate_b16x, predicate_all_b16x, probability_column_8_f32x,
|
|
1515
|
+
svld1_f16(predicate_all_b16x,
|
|
1516
|
+
(float16_t const *)(values_block_high + ((dim_tile + 2) * 8 + 0) * 32)));
|
|
1517
|
+
svmopa_za32_f16_m(3, query_predicate_b16x, predicate_all_b16x, probability_column_8_f32x,
|
|
1518
|
+
svld1_f16(predicate_all_b16x,
|
|
1519
|
+
(float16_t const *)(values_block_high + ((dim_tile + 3) * 8 + 0) * 32)));
|
|
1520
|
+
svmopa_za32_f16_m(0, query_predicate_b16x, predicate_all_b16x, probability_column_9_f32x,
|
|
1521
|
+
svld1_f16(predicate_all_b16x,
|
|
1522
|
+
(float16_t const *)(values_block_high + ((dim_tile + 0) * 8 + 1) * 32)));
|
|
1523
|
+
svmopa_za32_f16_m(1, query_predicate_b16x, predicate_all_b16x, probability_column_9_f32x,
|
|
1524
|
+
svld1_f16(predicate_all_b16x,
|
|
1525
|
+
(float16_t const *)(values_block_high + ((dim_tile + 1) * 8 + 1) * 32)));
|
|
1526
|
+
svmopa_za32_f16_m(2, query_predicate_b16x, predicate_all_b16x, probability_column_9_f32x,
|
|
1527
|
+
svld1_f16(predicate_all_b16x,
|
|
1528
|
+
(float16_t const *)(values_block_high + ((dim_tile + 2) * 8 + 1) * 32)));
|
|
1529
|
+
svmopa_za32_f16_m(3, query_predicate_b16x, predicate_all_b16x, probability_column_9_f32x,
|
|
1530
|
+
svld1_f16(predicate_all_b16x,
|
|
1531
|
+
(float16_t const *)(values_block_high + ((dim_tile + 3) * 8 + 1) * 32)));
|
|
1532
|
+
svmopa_za32_f16_m(0, query_predicate_b16x, predicate_all_b16x, probability_column_10_f32x,
|
|
1533
|
+
svld1_f16(predicate_all_b16x,
|
|
1534
|
+
(float16_t const *)(values_block_high + ((dim_tile + 0) * 8 + 2) * 32)));
|
|
1535
|
+
svmopa_za32_f16_m(1, query_predicate_b16x, predicate_all_b16x, probability_column_10_f32x,
|
|
1536
|
+
svld1_f16(predicate_all_b16x,
|
|
1537
|
+
(float16_t const *)(values_block_high + ((dim_tile + 1) * 8 + 2) * 32)));
|
|
1538
|
+
svmopa_za32_f16_m(2, query_predicate_b16x, predicate_all_b16x, probability_column_10_f32x,
|
|
1539
|
+
svld1_f16(predicate_all_b16x,
|
|
1540
|
+
(float16_t const *)(values_block_high + ((dim_tile + 2) * 8 + 2) * 32)));
|
|
1541
|
+
svmopa_za32_f16_m(3, query_predicate_b16x, predicate_all_b16x, probability_column_10_f32x,
|
|
1542
|
+
svld1_f16(predicate_all_b16x,
|
|
1543
|
+
(float16_t const *)(values_block_high + ((dim_tile + 3) * 8 + 2) * 32)));
|
|
1544
|
+
svmopa_za32_f16_m(0, query_predicate_b16x, predicate_all_b16x, probability_column_11_f32x,
|
|
1545
|
+
svld1_f16(predicate_all_b16x,
|
|
1546
|
+
(float16_t const *)(values_block_high + ((dim_tile + 0) * 8 + 3) * 32)));
|
|
1547
|
+
svmopa_za32_f16_m(1, query_predicate_b16x, predicate_all_b16x, probability_column_11_f32x,
|
|
1548
|
+
svld1_f16(predicate_all_b16x,
|
|
1549
|
+
(float16_t const *)(values_block_high + ((dim_tile + 1) * 8 + 3) * 32)));
|
|
1550
|
+
svmopa_za32_f16_m(2, query_predicate_b16x, predicate_all_b16x, probability_column_11_f32x,
|
|
1551
|
+
svld1_f16(predicate_all_b16x,
|
|
1552
|
+
(float16_t const *)(values_block_high + ((dim_tile + 2) * 8 + 3) * 32)));
|
|
1553
|
+
svmopa_za32_f16_m(3, query_predicate_b16x, predicate_all_b16x, probability_column_11_f32x,
|
|
1554
|
+
svld1_f16(predicate_all_b16x,
|
|
1555
|
+
(float16_t const *)(values_block_high + ((dim_tile + 3) * 8 + 3) * 32)));
|
|
1556
|
+
svmopa_za32_f16_m(0, query_predicate_b16x, predicate_all_b16x, probability_column_12_f32x,
|
|
1557
|
+
svld1_f16(predicate_all_b16x,
|
|
1558
|
+
(float16_t const *)(values_block_high + ((dim_tile + 0) * 8 + 4) * 32)));
|
|
1559
|
+
svmopa_za32_f16_m(1, query_predicate_b16x, predicate_all_b16x, probability_column_12_f32x,
|
|
1560
|
+
svld1_f16(predicate_all_b16x,
|
|
1561
|
+
(float16_t const *)(values_block_high + ((dim_tile + 1) * 8 + 4) * 32)));
|
|
1562
|
+
svmopa_za32_f16_m(2, query_predicate_b16x, predicate_all_b16x, probability_column_12_f32x,
|
|
1563
|
+
svld1_f16(predicate_all_b16x,
|
|
1564
|
+
(float16_t const *)(values_block_high + ((dim_tile + 2) * 8 + 4) * 32)));
|
|
1565
|
+
svmopa_za32_f16_m(3, query_predicate_b16x, predicate_all_b16x, probability_column_12_f32x,
|
|
1566
|
+
svld1_f16(predicate_all_b16x,
|
|
1567
|
+
(float16_t const *)(values_block_high + ((dim_tile + 3) * 8 + 4) * 32)));
|
|
1568
|
+
svmopa_za32_f16_m(0, query_predicate_b16x, predicate_all_b16x, probability_column_13_f32x,
|
|
1569
|
+
svld1_f16(predicate_all_b16x,
|
|
1570
|
+
(float16_t const *)(values_block_high + ((dim_tile + 0) * 8 + 5) * 32)));
|
|
1571
|
+
svmopa_za32_f16_m(1, query_predicate_b16x, predicate_all_b16x, probability_column_13_f32x,
|
|
1572
|
+
svld1_f16(predicate_all_b16x,
|
|
1573
|
+
(float16_t const *)(values_block_high + ((dim_tile + 1) * 8 + 5) * 32)));
|
|
1574
|
+
svmopa_za32_f16_m(2, query_predicate_b16x, predicate_all_b16x, probability_column_13_f32x,
|
|
1575
|
+
svld1_f16(predicate_all_b16x,
|
|
1576
|
+
(float16_t const *)(values_block_high + ((dim_tile + 2) * 8 + 5) * 32)));
|
|
1577
|
+
svmopa_za32_f16_m(3, query_predicate_b16x, predicate_all_b16x, probability_column_13_f32x,
|
|
1578
|
+
svld1_f16(predicate_all_b16x,
|
|
1579
|
+
(float16_t const *)(values_block_high + ((dim_tile + 3) * 8 + 5) * 32)));
|
|
1580
|
+
svmopa_za32_f16_m(0, query_predicate_b16x, predicate_all_b16x, probability_column_14_f32x,
|
|
1581
|
+
svld1_f16(predicate_all_b16x,
|
|
1582
|
+
(float16_t const *)(values_block_high + ((dim_tile + 0) * 8 + 6) * 32)));
|
|
1583
|
+
svmopa_za32_f16_m(1, query_predicate_b16x, predicate_all_b16x, probability_column_14_f32x,
|
|
1584
|
+
svld1_f16(predicate_all_b16x,
|
|
1585
|
+
(float16_t const *)(values_block_high + ((dim_tile + 1) * 8 + 6) * 32)));
|
|
1586
|
+
svmopa_za32_f16_m(2, query_predicate_b16x, predicate_all_b16x, probability_column_14_f32x,
|
|
1587
|
+
svld1_f16(predicate_all_b16x,
|
|
1588
|
+
(float16_t const *)(values_block_high + ((dim_tile + 2) * 8 + 6) * 32)));
|
|
1589
|
+
svmopa_za32_f16_m(3, query_predicate_b16x, predicate_all_b16x, probability_column_14_f32x,
|
|
1590
|
+
svld1_f16(predicate_all_b16x,
|
|
1591
|
+
(float16_t const *)(values_block_high + ((dim_tile + 3) * 8 + 6) * 32)));
|
|
1592
|
+
svmopa_za32_f16_m(0, query_predicate_b16x, predicate_all_b16x, probability_column_15_f32x,
|
|
1593
|
+
svld1_f16(predicate_all_b16x,
|
|
1594
|
+
(float16_t const *)(values_block_high + ((dim_tile + 0) * 8 + 7) * 32)));
|
|
1595
|
+
svmopa_za32_f16_m(1, query_predicate_b16x, predicate_all_b16x, probability_column_15_f32x,
|
|
1596
|
+
svld1_f16(predicate_all_b16x,
|
|
1597
|
+
(float16_t const *)(values_block_high + ((dim_tile + 1) * 8 + 7) * 32)));
|
|
1598
|
+
svmopa_za32_f16_m(2, query_predicate_b16x, predicate_all_b16x, probability_column_15_f32x,
|
|
1599
|
+
svld1_f16(predicate_all_b16x,
|
|
1600
|
+
(float16_t const *)(values_block_high + ((dim_tile + 2) * 8 + 7) * 32)));
|
|
1601
|
+
svmopa_za32_f16_m(3, query_predicate_b16x, predicate_all_b16x, probability_column_15_f32x,
|
|
1602
|
+
svld1_f16(predicate_all_b16x,
|
|
1603
|
+
(float16_t const *)(values_block_high + ((dim_tile + 3) * 8 + 7) * 32)));
|
|
1599
1604
|
// Read FMOPA result and ADD to output_accumulator
|
|
1600
1605
|
for (nk_size_t query_index = 0; query_index < valid_query_count; query_index++) {
|
|
1601
1606
|
svst1_f32(
|
|
1602
|
-
|
|
1603
|
-
svadd_f32_x(
|
|
1604
|
-
svld1_f32(
|
|
1607
|
+
predicate_all_b32x, output_accumulator + query_index * head_dim_padded + (dim_tile + 0) * 16,
|
|
1608
|
+
svadd_f32_x(predicate_all_b32x,
|
|
1609
|
+
svld1_f32(predicate_all_b32x,
|
|
1605
1610
|
output_accumulator + query_index * head_dim_padded + (dim_tile + 0) * 16),
|
|
1606
|
-
svread_hor_za32_f32_m(svdup_f32(0),
|
|
1611
|
+
svread_hor_za32_f32_m(svdup_f32(0), predicate_all_b32x, 0, query_index)));
|
|
1607
1612
|
svst1_f32(
|
|
1608
|
-
|
|
1609
|
-
svadd_f32_x(
|
|
1610
|
-
svld1_f32(
|
|
1613
|
+
predicate_all_b32x, output_accumulator + query_index * head_dim_padded + (dim_tile + 1) * 16,
|
|
1614
|
+
svadd_f32_x(predicate_all_b32x,
|
|
1615
|
+
svld1_f32(predicate_all_b32x,
|
|
1611
1616
|
output_accumulator + query_index * head_dim_padded + (dim_tile + 1) * 16),
|
|
1612
|
-
svread_hor_za32_f32_m(svdup_f32(0),
|
|
1617
|
+
svread_hor_za32_f32_m(svdup_f32(0), predicate_all_b32x, 1, query_index)));
|
|
1613
1618
|
svst1_f32(
|
|
1614
|
-
|
|
1615
|
-
svadd_f32_x(
|
|
1616
|
-
svld1_f32(
|
|
1619
|
+
predicate_all_b32x, output_accumulator + query_index * head_dim_padded + (dim_tile + 2) * 16,
|
|
1620
|
+
svadd_f32_x(predicate_all_b32x,
|
|
1621
|
+
svld1_f32(predicate_all_b32x,
|
|
1617
1622
|
output_accumulator + query_index * head_dim_padded + (dim_tile + 2) * 16),
|
|
1618
|
-
svread_hor_za32_f32_m(svdup_f32(0),
|
|
1623
|
+
svread_hor_za32_f32_m(svdup_f32(0), predicate_all_b32x, 2, query_index)));
|
|
1619
1624
|
svst1_f32(
|
|
1620
|
-
|
|
1621
|
-
svadd_f32_x(
|
|
1622
|
-
svld1_f32(
|
|
1625
|
+
predicate_all_b32x, output_accumulator + query_index * head_dim_padded + (dim_tile + 3) * 16,
|
|
1626
|
+
svadd_f32_x(predicate_all_b32x,
|
|
1627
|
+
svld1_f32(predicate_all_b32x,
|
|
1623
1628
|
output_accumulator + query_index * head_dim_padded + (dim_tile + 3) * 16),
|
|
1624
|
-
svread_hor_za32_f32_m(svdup_f32(0),
|
|
1629
|
+
svread_hor_za32_f32_m(svdup_f32(0), predicate_all_b32x, 3, query_index)));
|
|
1625
1630
|
}
|
|
1626
1631
|
}
|
|
1627
1632
|
// Remainder: 1 dim_tile at a time using ZA0
|
|
1628
1633
|
for (; dim_tile < dim_tile_count; dim_tile++) {
|
|
1629
1634
|
svzero_mask_za(nk_sme_zero_za32_tile_0_);
|
|
1630
1635
|
svmopa_za32_f16_m(
|
|
1631
|
-
0,
|
|
1632
|
-
svld1_f16(
|
|
1636
|
+
0, query_predicate_b16x, predicate_all_b16x, probability_column_0_f32x,
|
|
1637
|
+
svld1_f16(predicate_all_b16x, (float16_t const *)(values_block_low + (dim_tile * 8 + 0) * 32)));
|
|
1633
1638
|
svmopa_za32_f16_m(
|
|
1634
|
-
0,
|
|
1635
|
-
svld1_f16(
|
|
1639
|
+
0, query_predicate_b16x, predicate_all_b16x, probability_column_1_f32x,
|
|
1640
|
+
svld1_f16(predicate_all_b16x, (float16_t const *)(values_block_low + (dim_tile * 8 + 1) * 32)));
|
|
1636
1641
|
svmopa_za32_f16_m(
|
|
1637
|
-
0,
|
|
1638
|
-
svld1_f16(
|
|
1642
|
+
0, query_predicate_b16x, predicate_all_b16x, probability_column_2_f32x,
|
|
1643
|
+
svld1_f16(predicate_all_b16x, (float16_t const *)(values_block_low + (dim_tile * 8 + 2) * 32)));
|
|
1639
1644
|
svmopa_za32_f16_m(
|
|
1640
|
-
0,
|
|
1641
|
-
svld1_f16(
|
|
1645
|
+
0, query_predicate_b16x, predicate_all_b16x, probability_column_3_f32x,
|
|
1646
|
+
svld1_f16(predicate_all_b16x, (float16_t const *)(values_block_low + (dim_tile * 8 + 3) * 32)));
|
|
1642
1647
|
svmopa_za32_f16_m(
|
|
1643
|
-
0,
|
|
1644
|
-
svld1_f16(
|
|
1648
|
+
0, query_predicate_b16x, predicate_all_b16x, probability_column_4_f32x,
|
|
1649
|
+
svld1_f16(predicate_all_b16x, (float16_t const *)(values_block_low + (dim_tile * 8 + 4) * 32)));
|
|
1645
1650
|
svmopa_za32_f16_m(
|
|
1646
|
-
0,
|
|
1647
|
-
svld1_f16(
|
|
1651
|
+
0, query_predicate_b16x, predicate_all_b16x, probability_column_5_f32x,
|
|
1652
|
+
svld1_f16(predicate_all_b16x, (float16_t const *)(values_block_low + (dim_tile * 8 + 5) * 32)));
|
|
1648
1653
|
svmopa_za32_f16_m(
|
|
1649
|
-
0,
|
|
1650
|
-
svld1_f16(
|
|
1654
|
+
0, query_predicate_b16x, predicate_all_b16x, probability_column_6_f32x,
|
|
1655
|
+
svld1_f16(predicate_all_b16x, (float16_t const *)(values_block_low + (dim_tile * 8 + 6) * 32)));
|
|
1651
1656
|
svmopa_za32_f16_m(
|
|
1652
|
-
0,
|
|
1653
|
-
svld1_f16(
|
|
1657
|
+
0, query_predicate_b16x, predicate_all_b16x, probability_column_7_f32x,
|
|
1658
|
+
svld1_f16(predicate_all_b16x, (float16_t const *)(values_block_low + (dim_tile * 8 + 7) * 32)));
|
|
1654
1659
|
svmopa_za32_f16_m(
|
|
1655
|
-
0,
|
|
1656
|
-
svld1_f16(
|
|
1660
|
+
0, query_predicate_b16x, predicate_all_b16x, probability_column_8_f32x,
|
|
1661
|
+
svld1_f16(predicate_all_b16x, (float16_t const *)(values_block_high + (dim_tile * 8 + 0) * 32)));
|
|
1657
1662
|
svmopa_za32_f16_m(
|
|
1658
|
-
0,
|
|
1659
|
-
svld1_f16(
|
|
1663
|
+
0, query_predicate_b16x, predicate_all_b16x, probability_column_9_f32x,
|
|
1664
|
+
svld1_f16(predicate_all_b16x, (float16_t const *)(values_block_high + (dim_tile * 8 + 1) * 32)));
|
|
1660
1665
|
svmopa_za32_f16_m(
|
|
1661
|
-
0,
|
|
1662
|
-
svld1_f16(
|
|
1666
|
+
0, query_predicate_b16x, predicate_all_b16x, probability_column_10_f32x,
|
|
1667
|
+
svld1_f16(predicate_all_b16x, (float16_t const *)(values_block_high + (dim_tile * 8 + 2) * 32)));
|
|
1663
1668
|
svmopa_za32_f16_m(
|
|
1664
|
-
0,
|
|
1665
|
-
svld1_f16(
|
|
1669
|
+
0, query_predicate_b16x, predicate_all_b16x, probability_column_11_f32x,
|
|
1670
|
+
svld1_f16(predicate_all_b16x, (float16_t const *)(values_block_high + (dim_tile * 8 + 3) * 32)));
|
|
1666
1671
|
svmopa_za32_f16_m(
|
|
1667
|
-
0,
|
|
1668
|
-
svld1_f16(
|
|
1672
|
+
0, query_predicate_b16x, predicate_all_b16x, probability_column_12_f32x,
|
|
1673
|
+
svld1_f16(predicate_all_b16x, (float16_t const *)(values_block_high + (dim_tile * 8 + 4) * 32)));
|
|
1669
1674
|
svmopa_za32_f16_m(
|
|
1670
|
-
0,
|
|
1671
|
-
svld1_f16(
|
|
1675
|
+
0, query_predicate_b16x, predicate_all_b16x, probability_column_13_f32x,
|
|
1676
|
+
svld1_f16(predicate_all_b16x, (float16_t const *)(values_block_high + (dim_tile * 8 + 5) * 32)));
|
|
1672
1677
|
svmopa_za32_f16_m(
|
|
1673
|
-
0,
|
|
1674
|
-
svld1_f16(
|
|
1678
|
+
0, query_predicate_b16x, predicate_all_b16x, probability_column_14_f32x,
|
|
1679
|
+
svld1_f16(predicate_all_b16x, (float16_t const *)(values_block_high + (dim_tile * 8 + 6) * 32)));
|
|
1675
1680
|
svmopa_za32_f16_m(
|
|
1676
|
-
0,
|
|
1677
|
-
svld1_f16(
|
|
1681
|
+
0, query_predicate_b16x, predicate_all_b16x, probability_column_15_f32x,
|
|
1682
|
+
svld1_f16(predicate_all_b16x, (float16_t const *)(values_block_high + (dim_tile * 8 + 7) * 32)));
|
|
1678
1683
|
for (nk_size_t query_index = 0; query_index < valid_query_count; query_index++)
|
|
1679
|
-
svst1_f32(
|
|
1680
|
-
svadd_f32_x(
|
|
1681
|
-
svld1_f32(
|
|
1684
|
+
svst1_f32(predicate_all_b32x, output_accumulator + query_index * head_dim_padded + dim_tile * 16,
|
|
1685
|
+
svadd_f32_x(predicate_all_b32x,
|
|
1686
|
+
svld1_f32(predicate_all_b32x,
|
|
1682
1687
|
output_accumulator + query_index * head_dim_padded + dim_tile * 16),
|
|
1683
|
-
svread_hor_za32_f32_m(svdup_f32(0),
|
|
1688
|
+
svread_hor_za32_f32_m(svdup_f32(0), predicate_all_b32x, 0, query_index)));
|
|
1684
1689
|
}
|
|
1685
1690
|
}
|
|
1686
1691
|
}
|
|
@@ -1693,9 +1698,9 @@ __arm_locally_streaming __arm_new("za") static void nk_attention_f16_sme_streami
|
|
|
1693
1698
|
svzero_mask_za(nk_sme_zero_za32_tile_2_);
|
|
1694
1699
|
nk_f16_t const *k_block = k + kv_block_index * k_depth_step_count * 32;
|
|
1695
1700
|
for (nk_size_t step = 0; step < k_depth_step_count; step++) {
|
|
1696
|
-
svfloat16_t
|
|
1697
|
-
svfloat16_t
|
|
1698
|
-
svmopa_za32_f16_m(2,
|
|
1701
|
+
svfloat16_t zn_f16x = svreinterpret_f16_f32(svld1_f32(predicate_all_b32x, queries_transposed + step * 16));
|
|
1702
|
+
svfloat16_t zm_f16x = svld1_f16(predicate_all_b16x, (float16_t const *)(k_block + step * 32));
|
|
1703
|
+
svmopa_za32_f16_m(2, predicate_all_b32x, predicate_all_b32x, zn_f16x, zm_f16x);
|
|
1699
1704
|
}
|
|
1700
1705
|
|
|
1701
1706
|
// Pass 1: Column-wise max (read ZA2 columns vertically)
|
|
@@ -1703,56 +1708,57 @@ __arm_locally_streaming __arm_new("za") static void nk_attention_f16_sme_streami
|
|
|
1703
1708
|
svfloat32_t block_max_16_f32x = svdup_f32(NK_F32_MIN);
|
|
1704
1709
|
for (nk_size_t column_index = 0; column_index < 16; column_index++) {
|
|
1705
1710
|
svfloat32_t score_column_f32x = svmul_f32_x(
|
|
1706
|
-
|
|
1711
|
+
predicate_all_b32x, svread_ver_za32_f32_m(svdup_f32(0), predicate_all_b32x, 2, column_index),
|
|
1707
1712
|
scale_16_f32x);
|
|
1708
|
-
block_max_16_f32x = svmax_f32_x(
|
|
1713
|
+
block_max_16_f32x = svmax_f32_x(predicate_all_b32x, block_max_16_f32x, score_column_f32x);
|
|
1709
1714
|
}
|
|
1710
1715
|
|
|
1711
|
-
svfloat32_t old_max_f32x = svld1_f32(
|
|
1712
|
-
svfloat32_t new_max_f32x = svmax_f32_x(
|
|
1713
|
-
svfloat32_t correction_f32x = nk_exp_fast_f32_sve_(
|
|
1714
|
-
svsub_f32_x(
|
|
1715
|
-
svbool_t
|
|
1716
|
-
nk_u32_t max_was_updated_16 = svptest_any(
|
|
1717
|
-
svfloat32_t row_sum_corrected_f32x = svld1_f32(
|
|
1716
|
+
svfloat32_t old_max_f32x = svld1_f32(predicate_all_b32x, row_max);
|
|
1717
|
+
svfloat32_t new_max_f32x = svmax_f32_x(predicate_all_b32x, old_max_f32x, block_max_16_f32x);
|
|
1718
|
+
svfloat32_t correction_f32x = nk_exp_fast_f32_sve_(predicate_all_b32x,
|
|
1719
|
+
svsub_f32_x(predicate_all_b32x, old_max_f32x, new_max_f32x));
|
|
1720
|
+
svbool_t max_changed_16_b32x = svcmplt_f32(predicate_all_b32x, correction_f32x, svdup_f32(1.0f));
|
|
1721
|
+
nk_u32_t max_was_updated_16 = svptest_any(predicate_all_b32x, max_changed_16_b32x) ? 1 : 0;
|
|
1722
|
+
svfloat32_t row_sum_corrected_f32x = svld1_f32(predicate_all_b32x, row_sum);
|
|
1718
1723
|
if (max_was_updated_16)
|
|
1719
|
-
row_sum_corrected_f32x = svmul_f32_x(
|
|
1724
|
+
row_sum_corrected_f32x = svmul_f32_x(predicate_all_b32x, row_sum_corrected_f32x, correction_f32x);
|
|
1720
1725
|
NK_ALIGN64 nk_f32_t corrections[16];
|
|
1721
|
-
svst1_f32(
|
|
1726
|
+
svst1_f32(predicate_all_b32x, corrections, correction_f32x);
|
|
1722
1727
|
|
|
1723
1728
|
// Pass 2: Column-wise exp + fused P write + sum (ZA2 → ZA0 columns 0-7)
|
|
1724
1729
|
svfloat32_t sum_delta_16_f32x = svdup_f32(0.0f);
|
|
1725
1730
|
svzero_mask_za(nk_sme_zero_za32_tile_0_);
|
|
1726
1731
|
for (nk_size_t column_index = 0; column_index < 16; column_index += 2) {
|
|
1727
1732
|
svfloat32_t score_even_f32x = svmul_f32_x(
|
|
1728
|
-
|
|
1733
|
+
predicate_all_b32x, svread_ver_za32_f32_m(svdup_f32(0), predicate_all_b32x, 2, column_index),
|
|
1729
1734
|
scale_16_f32x);
|
|
1730
1735
|
svfloat32_t score_odd_f32x = svmul_f32_x(
|
|
1731
|
-
|
|
1736
|
+
predicate_all_b32x, svread_ver_za32_f32_m(svdup_f32(0), predicate_all_b32x, 2, column_index + 1),
|
|
1732
1737
|
scale_16_f32x);
|
|
1733
1738
|
svfloat32_t weight_even_f32x = nk_exp_fast_f32_sve_(
|
|
1734
|
-
|
|
1739
|
+
predicate_all_b32x, svsub_f32_x(predicate_all_b32x, score_even_f32x, new_max_f32x));
|
|
1735
1740
|
svfloat32_t weight_odd_f32x = nk_exp_fast_f32_sve_(
|
|
1736
|
-
|
|
1737
|
-
sum_delta_16_f32x = svadd_f32_x(
|
|
1738
|
-
sum_delta_16_f32x = svadd_f32_x(
|
|
1739
|
-
svfloat16_t weight_pair_f16x = svzip1_f16(svcvt_f16_f32_x(
|
|
1740
|
-
svcvt_f16_f32_x(
|
|
1741
|
-
svwrite_ver_za32_f32_m(0, column_index / 2,
|
|
1741
|
+
predicate_all_b32x, svsub_f32_x(predicate_all_b32x, score_odd_f32x, new_max_f32x));
|
|
1742
|
+
sum_delta_16_f32x = svadd_f32_x(predicate_all_b32x, sum_delta_16_f32x, weight_even_f32x);
|
|
1743
|
+
sum_delta_16_f32x = svadd_f32_x(predicate_all_b32x, sum_delta_16_f32x, weight_odd_f32x);
|
|
1744
|
+
svfloat16_t weight_pair_f16x = svzip1_f16(svcvt_f16_f32_x(predicate_all_b32x, weight_even_f32x),
|
|
1745
|
+
svcvt_f16_f32_x(predicate_all_b32x, weight_odd_f32x));
|
|
1746
|
+
svwrite_ver_za32_f32_m(0, column_index / 2, predicate_all_b32x, svreinterpret_f32_f16(weight_pair_f16x));
|
|
1742
1747
|
}
|
|
1743
|
-
row_sum_corrected_f32x = svadd_f32_x(
|
|
1744
|
-
svst1_f32(
|
|
1745
|
-
svst1_f32(
|
|
1748
|
+
row_sum_corrected_f32x = svadd_f32_x(predicate_all_b32x, row_sum_corrected_f32x, sum_delta_16_f32x);
|
|
1749
|
+
svst1_f32(predicate_all_b32x, row_sum, row_sum_corrected_f32x);
|
|
1750
|
+
svst1_f32(predicate_all_b32x, row_max, new_max_f32x);
|
|
1746
1751
|
|
|
1747
1752
|
if (valid_query_count == 1) {
|
|
1748
1753
|
// Decode path: extract f32 weights from ZA0 row 0 using SVE
|
|
1749
|
-
svfloat16_t
|
|
1750
|
-
|
|
1751
|
-
svfloat16_t
|
|
1754
|
+
svfloat16_t row0_f16x = svreinterpret_f16_f32(
|
|
1755
|
+
svread_hor_za32_f32_m(svdup_f32(0), predicate_all_b32x, 0, 0));
|
|
1756
|
+
svfloat16_t weights_even_f16x = svuzp1_f16(row0_f16x, row0_f16x);
|
|
1757
|
+
svfloat16_t weights_odd_f16x = svuzp2_f16(row0_f16x, row0_f16x);
|
|
1752
1758
|
NK_ALIGN64 nk_f32_t decode_weights[16];
|
|
1753
|
-
svst1_f32(svwhilelt_b32(0u, 8u), decode_weights, svcvt_f32_f16_x(svwhilelt_b32(0u, 8u),
|
|
1759
|
+
svst1_f32(svwhilelt_b32(0u, 8u), decode_weights, svcvt_f32_f16_x(svwhilelt_b32(0u, 8u), weights_even_f16x));
|
|
1754
1760
|
svst1_f32(svwhilelt_b32(0u, 8u), decode_weights + 8,
|
|
1755
|
-
svcvt_f32_f16_x(svwhilelt_b32(0u, 8u),
|
|
1761
|
+
svcvt_f32_f16_x(svwhilelt_b32(0u, 8u), weights_odd_f16x));
|
|
1756
1762
|
NK_ALIGN64 nk_f32_t decode_weights_ordered[16];
|
|
1757
1763
|
for (nk_size_t i = 0; i < 8; i++) {
|
|
1758
1764
|
decode_weights_ordered[2 * i] = decode_weights[i];
|
|
@@ -1760,42 +1766,42 @@ __arm_locally_streaming __arm_new("za") static void nk_attention_f16_sme_streami
|
|
|
1760
1766
|
}
|
|
1761
1767
|
svfloat32_t corr_f32x = svdup_f32(corrections[0]);
|
|
1762
1768
|
for (nk_size_t d = 0; d < head_dim; d += svcntw()) {
|
|
1763
|
-
svbool_t
|
|
1764
|
-
svfloat32_t acc_f32x = svmul_f32_x(
|
|
1769
|
+
svbool_t predicate_b32x = svwhilelt_b32_u64(d, head_dim);
|
|
1770
|
+
svfloat32_t acc_f32x = svmul_f32_x(predicate_b32x, svld1_f32(predicate_b32x, output_accumulator + d),
|
|
1765
1771
|
corr_f32x);
|
|
1766
1772
|
for (nk_size_t ki = 0; ki < valid_kv; ki++) {
|
|
1767
1773
|
nk_size_t dim_tile = d / 16, depth_s = ki / 2, sub = ki % 2;
|
|
1768
1774
|
nk_f16_t const *v_vec = v_packed +
|
|
1769
1775
|
(kv_block_index * dim_tile_count * 8 + dim_tile * 8 + depth_s) * 32;
|
|
1770
|
-
svfloat16_t packed_f16x = svld1_f16(
|
|
1771
|
-
svfloat16_t
|
|
1772
|
-
|
|
1773
|
-
acc_f32x = svmla_f32_x(
|
|
1774
|
-
svcvt_f32_f16_x(
|
|
1776
|
+
svfloat16_t packed_f16x = svld1_f16(predicate_all_b16x, (float16_t const *)v_vec);
|
|
1777
|
+
svfloat16_t v_selected_f16x = (sub == 0) ? svuzp1_f16(packed_f16x, packed_f16x)
|
|
1778
|
+
: svuzp2_f16(packed_f16x, packed_f16x);
|
|
1779
|
+
acc_f32x = svmla_f32_x(predicate_b32x, acc_f32x, svdup_f32(decode_weights_ordered[ki]),
|
|
1780
|
+
svcvt_f32_f16_x(predicate_b32x, v_selected_f16x));
|
|
1775
1781
|
}
|
|
1776
|
-
svst1_f32(
|
|
1782
|
+
svst1_f32(predicate_b32x, output_accumulator + d, acc_f32x);
|
|
1777
1783
|
}
|
|
1778
1784
|
}
|
|
1779
1785
|
else {
|
|
1780
1786
|
// Prefill Bc=16: extract P columns, pre-apply correction, add-after P×V
|
|
1781
|
-
svbool_t
|
|
1787
|
+
svbool_t query_predicate_b16x = svwhilelt_b16_u64(0u, valid_query_count * 2);
|
|
1782
1788
|
|
|
1783
1789
|
svfloat16_t probability_column_0_f32x = svreinterpret_f16_f32(
|
|
1784
|
-
svread_ver_za32_f32_m(svdup_f32(0),
|
|
1790
|
+
svread_ver_za32_f32_m(svdup_f32(0), predicate_all_b32x, 0, 0));
|
|
1785
1791
|
svfloat16_t probability_column_1_f32x = svreinterpret_f16_f32(
|
|
1786
|
-
svread_ver_za32_f32_m(svdup_f32(0),
|
|
1792
|
+
svread_ver_za32_f32_m(svdup_f32(0), predicate_all_b32x, 0, 1));
|
|
1787
1793
|
svfloat16_t probability_column_2_f32x = svreinterpret_f16_f32(
|
|
1788
|
-
svread_ver_za32_f32_m(svdup_f32(0),
|
|
1794
|
+
svread_ver_za32_f32_m(svdup_f32(0), predicate_all_b32x, 0, 2));
|
|
1789
1795
|
svfloat16_t probability_column_3_f32x = svreinterpret_f16_f32(
|
|
1790
|
-
svread_ver_za32_f32_m(svdup_f32(0),
|
|
1796
|
+
svread_ver_za32_f32_m(svdup_f32(0), predicate_all_b32x, 0, 3));
|
|
1791
1797
|
svfloat16_t probability_column_4_f32x = svreinterpret_f16_f32(
|
|
1792
|
-
svread_ver_za32_f32_m(svdup_f32(0),
|
|
1798
|
+
svread_ver_za32_f32_m(svdup_f32(0), predicate_all_b32x, 0, 4));
|
|
1793
1799
|
svfloat16_t probability_column_5_f32x = svreinterpret_f16_f32(
|
|
1794
|
-
svread_ver_za32_f32_m(svdup_f32(0),
|
|
1800
|
+
svread_ver_za32_f32_m(svdup_f32(0), predicate_all_b32x, 0, 5));
|
|
1795
1801
|
svfloat16_t probability_column_6_f32x = svreinterpret_f16_f32(
|
|
1796
|
-
svread_ver_za32_f32_m(svdup_f32(0),
|
|
1802
|
+
svread_ver_za32_f32_m(svdup_f32(0), predicate_all_b32x, 0, 6));
|
|
1797
1803
|
svfloat16_t probability_column_7_f32x = svreinterpret_f16_f32(
|
|
1798
|
-
svread_ver_za32_f32_m(svdup_f32(0),
|
|
1804
|
+
svread_ver_za32_f32_m(svdup_f32(0), predicate_all_b32x, 0, 7));
|
|
1799
1805
|
|
|
1800
1806
|
nk_f16_t const *v_block = v_packed + kv_block_index * dim_tile_count * 8 * 32;
|
|
1801
1807
|
|
|
@@ -1804,9 +1810,9 @@ __arm_locally_streaming __arm_new("za") static void nk_attention_f16_sme_streami
|
|
|
1804
1810
|
svfloat32_t correction_scalar_f32x = svdup_f32(corrections[query_index]);
|
|
1805
1811
|
for (nk_size_t dim_offset = 0; dim_offset < head_dim_padded; dim_offset += 16)
|
|
1806
1812
|
svst1_f32(
|
|
1807
|
-
|
|
1808
|
-
svmul_f32_x(
|
|
1809
|
-
svld1_f32(
|
|
1813
|
+
predicate_all_b32x, output_accumulator + query_index * head_dim_padded + dim_offset,
|
|
1814
|
+
svmul_f32_x(predicate_all_b32x,
|
|
1815
|
+
svld1_f32(predicate_all_b32x,
|
|
1810
1816
|
output_accumulator + query_index * head_dim_padded + dim_offset),
|
|
1811
1817
|
correction_scalar_f32x));
|
|
1812
1818
|
}
|
|
@@ -1816,188 +1822,188 @@ __arm_locally_streaming __arm_new("za") static void nk_attention_f16_sme_streami
|
|
|
1816
1822
|
for (; dim_tile + 4 <= dim_tile_count; dim_tile += 4) {
|
|
1817
1823
|
svzero_za();
|
|
1818
1824
|
svmopa_za32_f16_m(
|
|
1819
|
-
0,
|
|
1820
|
-
svld1_f16(
|
|
1825
|
+
0, query_predicate_b16x, predicate_all_b16x, probability_column_0_f32x,
|
|
1826
|
+
svld1_f16(predicate_all_b16x, (float16_t const *)(v_block + ((dim_tile + 0) * 8 + 0) * 32)));
|
|
1821
1827
|
svmopa_za32_f16_m(
|
|
1822
|
-
1,
|
|
1823
|
-
svld1_f16(
|
|
1828
|
+
1, query_predicate_b16x, predicate_all_b16x, probability_column_0_f32x,
|
|
1829
|
+
svld1_f16(predicate_all_b16x, (float16_t const *)(v_block + ((dim_tile + 1) * 8 + 0) * 32)));
|
|
1824
1830
|
svmopa_za32_f16_m(
|
|
1825
|
-
2,
|
|
1826
|
-
svld1_f16(
|
|
1831
|
+
2, query_predicate_b16x, predicate_all_b16x, probability_column_0_f32x,
|
|
1832
|
+
svld1_f16(predicate_all_b16x, (float16_t const *)(v_block + ((dim_tile + 2) * 8 + 0) * 32)));
|
|
1827
1833
|
svmopa_za32_f16_m(
|
|
1828
|
-
3,
|
|
1829
|
-
svld1_f16(
|
|
1834
|
+
3, query_predicate_b16x, predicate_all_b16x, probability_column_0_f32x,
|
|
1835
|
+
svld1_f16(predicate_all_b16x, (float16_t const *)(v_block + ((dim_tile + 3) * 8 + 0) * 32)));
|
|
1830
1836
|
svmopa_za32_f16_m(
|
|
1831
|
-
0,
|
|
1832
|
-
svld1_f16(
|
|
1837
|
+
0, query_predicate_b16x, predicate_all_b16x, probability_column_1_f32x,
|
|
1838
|
+
svld1_f16(predicate_all_b16x, (float16_t const *)(v_block + ((dim_tile + 0) * 8 + 1) * 32)));
|
|
1833
1839
|
svmopa_za32_f16_m(
|
|
1834
|
-
1,
|
|
1835
|
-
svld1_f16(
|
|
1840
|
+
1, query_predicate_b16x, predicate_all_b16x, probability_column_1_f32x,
|
|
1841
|
+
svld1_f16(predicate_all_b16x, (float16_t const *)(v_block + ((dim_tile + 1) * 8 + 1) * 32)));
|
|
1836
1842
|
svmopa_za32_f16_m(
|
|
1837
|
-
2,
|
|
1838
|
-
svld1_f16(
|
|
1843
|
+
2, query_predicate_b16x, predicate_all_b16x, probability_column_1_f32x,
|
|
1844
|
+
svld1_f16(predicate_all_b16x, (float16_t const *)(v_block + ((dim_tile + 2) * 8 + 1) * 32)));
|
|
1839
1845
|
svmopa_za32_f16_m(
|
|
1840
|
-
3,
|
|
1841
|
-
svld1_f16(
|
|
1846
|
+
3, query_predicate_b16x, predicate_all_b16x, probability_column_1_f32x,
|
|
1847
|
+
svld1_f16(predicate_all_b16x, (float16_t const *)(v_block + ((dim_tile + 3) * 8 + 1) * 32)));
|
|
1842
1848
|
svmopa_za32_f16_m(
|
|
1843
|
-
0,
|
|
1844
|
-
svld1_f16(
|
|
1849
|
+
0, query_predicate_b16x, predicate_all_b16x, probability_column_2_f32x,
|
|
1850
|
+
svld1_f16(predicate_all_b16x, (float16_t const *)(v_block + ((dim_tile + 0) * 8 + 2) * 32)));
|
|
1845
1851
|
svmopa_za32_f16_m(
|
|
1846
|
-
1,
|
|
1847
|
-
svld1_f16(
|
|
1852
|
+
1, query_predicate_b16x, predicate_all_b16x, probability_column_2_f32x,
|
|
1853
|
+
svld1_f16(predicate_all_b16x, (float16_t const *)(v_block + ((dim_tile + 1) * 8 + 2) * 32)));
|
|
1848
1854
|
svmopa_za32_f16_m(
|
|
1849
|
-
2,
|
|
1850
|
-
svld1_f16(
|
|
1855
|
+
2, query_predicate_b16x, predicate_all_b16x, probability_column_2_f32x,
|
|
1856
|
+
svld1_f16(predicate_all_b16x, (float16_t const *)(v_block + ((dim_tile + 2) * 8 + 2) * 32)));
|
|
1851
1857
|
svmopa_za32_f16_m(
|
|
1852
|
-
3,
|
|
1853
|
-
svld1_f16(
|
|
1858
|
+
3, query_predicate_b16x, predicate_all_b16x, probability_column_2_f32x,
|
|
1859
|
+
svld1_f16(predicate_all_b16x, (float16_t const *)(v_block + ((dim_tile + 3) * 8 + 2) * 32)));
|
|
1854
1860
|
svmopa_za32_f16_m(
|
|
1855
|
-
0,
|
|
1856
|
-
svld1_f16(
|
|
1861
|
+
0, query_predicate_b16x, predicate_all_b16x, probability_column_3_f32x,
|
|
1862
|
+
svld1_f16(predicate_all_b16x, (float16_t const *)(v_block + ((dim_tile + 0) * 8 + 3) * 32)));
|
|
1857
1863
|
svmopa_za32_f16_m(
|
|
1858
|
-
1,
|
|
1859
|
-
svld1_f16(
|
|
1864
|
+
1, query_predicate_b16x, predicate_all_b16x, probability_column_3_f32x,
|
|
1865
|
+
svld1_f16(predicate_all_b16x, (float16_t const *)(v_block + ((dim_tile + 1) * 8 + 3) * 32)));
|
|
1860
1866
|
svmopa_za32_f16_m(
|
|
1861
|
-
2,
|
|
1862
|
-
svld1_f16(
|
|
1867
|
+
2, query_predicate_b16x, predicate_all_b16x, probability_column_3_f32x,
|
|
1868
|
+
svld1_f16(predicate_all_b16x, (float16_t const *)(v_block + ((dim_tile + 2) * 8 + 3) * 32)));
|
|
1863
1869
|
svmopa_za32_f16_m(
|
|
1864
|
-
3,
|
|
1865
|
-
svld1_f16(
|
|
1870
|
+
3, query_predicate_b16x, predicate_all_b16x, probability_column_3_f32x,
|
|
1871
|
+
svld1_f16(predicate_all_b16x, (float16_t const *)(v_block + ((dim_tile + 3) * 8 + 3) * 32)));
|
|
1866
1872
|
svmopa_za32_f16_m(
|
|
1867
|
-
0,
|
|
1868
|
-
svld1_f16(
|
|
1873
|
+
0, query_predicate_b16x, predicate_all_b16x, probability_column_4_f32x,
|
|
1874
|
+
svld1_f16(predicate_all_b16x, (float16_t const *)(v_block + ((dim_tile + 0) * 8 + 4) * 32)));
|
|
1869
1875
|
svmopa_za32_f16_m(
|
|
1870
|
-
1,
|
|
1871
|
-
svld1_f16(
|
|
1876
|
+
1, query_predicate_b16x, predicate_all_b16x, probability_column_4_f32x,
|
|
1877
|
+
svld1_f16(predicate_all_b16x, (float16_t const *)(v_block + ((dim_tile + 1) * 8 + 4) * 32)));
|
|
1872
1878
|
svmopa_za32_f16_m(
|
|
1873
|
-
2,
|
|
1874
|
-
svld1_f16(
|
|
1879
|
+
2, query_predicate_b16x, predicate_all_b16x, probability_column_4_f32x,
|
|
1880
|
+
svld1_f16(predicate_all_b16x, (float16_t const *)(v_block + ((dim_tile + 2) * 8 + 4) * 32)));
|
|
1875
1881
|
svmopa_za32_f16_m(
|
|
1876
|
-
3,
|
|
1877
|
-
svld1_f16(
|
|
1882
|
+
3, query_predicate_b16x, predicate_all_b16x, probability_column_4_f32x,
|
|
1883
|
+
svld1_f16(predicate_all_b16x, (float16_t const *)(v_block + ((dim_tile + 3) * 8 + 4) * 32)));
|
|
1878
1884
|
svmopa_za32_f16_m(
|
|
1879
|
-
0,
|
|
1880
|
-
svld1_f16(
|
|
1885
|
+
0, query_predicate_b16x, predicate_all_b16x, probability_column_5_f32x,
|
|
1886
|
+
svld1_f16(predicate_all_b16x, (float16_t const *)(v_block + ((dim_tile + 0) * 8 + 5) * 32)));
|
|
1881
1887
|
svmopa_za32_f16_m(
|
|
1882
|
-
1,
|
|
1883
|
-
svld1_f16(
|
|
1888
|
+
1, query_predicate_b16x, predicate_all_b16x, probability_column_5_f32x,
|
|
1889
|
+
svld1_f16(predicate_all_b16x, (float16_t const *)(v_block + ((dim_tile + 1) * 8 + 5) * 32)));
|
|
1884
1890
|
svmopa_za32_f16_m(
|
|
1885
|
-
2,
|
|
1886
|
-
svld1_f16(
|
|
1891
|
+
2, query_predicate_b16x, predicate_all_b16x, probability_column_5_f32x,
|
|
1892
|
+
svld1_f16(predicate_all_b16x, (float16_t const *)(v_block + ((dim_tile + 2) * 8 + 5) * 32)));
|
|
1887
1893
|
svmopa_za32_f16_m(
|
|
1888
|
-
3,
|
|
1889
|
-
svld1_f16(
|
|
1894
|
+
3, query_predicate_b16x, predicate_all_b16x, probability_column_5_f32x,
|
|
1895
|
+
svld1_f16(predicate_all_b16x, (float16_t const *)(v_block + ((dim_tile + 3) * 8 + 5) * 32)));
|
|
1890
1896
|
svmopa_za32_f16_m(
|
|
1891
|
-
0,
|
|
1892
|
-
svld1_f16(
|
|
1897
|
+
0, query_predicate_b16x, predicate_all_b16x, probability_column_6_f32x,
|
|
1898
|
+
svld1_f16(predicate_all_b16x, (float16_t const *)(v_block + ((dim_tile + 0) * 8 + 6) * 32)));
|
|
1893
1899
|
svmopa_za32_f16_m(
|
|
1894
|
-
1,
|
|
1895
|
-
svld1_f16(
|
|
1900
|
+
1, query_predicate_b16x, predicate_all_b16x, probability_column_6_f32x,
|
|
1901
|
+
svld1_f16(predicate_all_b16x, (float16_t const *)(v_block + ((dim_tile + 1) * 8 + 6) * 32)));
|
|
1896
1902
|
svmopa_za32_f16_m(
|
|
1897
|
-
2,
|
|
1898
|
-
svld1_f16(
|
|
1903
|
+
2, query_predicate_b16x, predicate_all_b16x, probability_column_6_f32x,
|
|
1904
|
+
svld1_f16(predicate_all_b16x, (float16_t const *)(v_block + ((dim_tile + 2) * 8 + 6) * 32)));
|
|
1899
1905
|
svmopa_za32_f16_m(
|
|
1900
|
-
3,
|
|
1901
|
-
svld1_f16(
|
|
1906
|
+
3, query_predicate_b16x, predicate_all_b16x, probability_column_6_f32x,
|
|
1907
|
+
svld1_f16(predicate_all_b16x, (float16_t const *)(v_block + ((dim_tile + 3) * 8 + 6) * 32)));
|
|
1902
1908
|
svmopa_za32_f16_m(
|
|
1903
|
-
0,
|
|
1904
|
-
svld1_f16(
|
|
1909
|
+
0, query_predicate_b16x, predicate_all_b16x, probability_column_7_f32x,
|
|
1910
|
+
svld1_f16(predicate_all_b16x, (float16_t const *)(v_block + ((dim_tile + 0) * 8 + 7) * 32)));
|
|
1905
1911
|
svmopa_za32_f16_m(
|
|
1906
|
-
1,
|
|
1907
|
-
svld1_f16(
|
|
1912
|
+
1, query_predicate_b16x, predicate_all_b16x, probability_column_7_f32x,
|
|
1913
|
+
svld1_f16(predicate_all_b16x, (float16_t const *)(v_block + ((dim_tile + 1) * 8 + 7) * 32)));
|
|
1908
1914
|
svmopa_za32_f16_m(
|
|
1909
|
-
2,
|
|
1910
|
-
svld1_f16(
|
|
1915
|
+
2, query_predicate_b16x, predicate_all_b16x, probability_column_7_f32x,
|
|
1916
|
+
svld1_f16(predicate_all_b16x, (float16_t const *)(v_block + ((dim_tile + 2) * 8 + 7) * 32)));
|
|
1911
1917
|
svmopa_za32_f16_m(
|
|
1912
|
-
3,
|
|
1913
|
-
svld1_f16(
|
|
1918
|
+
3, query_predicate_b16x, predicate_all_b16x, probability_column_7_f32x,
|
|
1919
|
+
svld1_f16(predicate_all_b16x, (float16_t const *)(v_block + ((dim_tile + 3) * 8 + 7) * 32)));
|
|
1914
1920
|
for (nk_size_t query_index = 0; query_index < valid_query_count; query_index++) {
|
|
1915
1921
|
svst1_f32(
|
|
1916
|
-
|
|
1917
|
-
svadd_f32_x(
|
|
1918
|
-
svld1_f32(
|
|
1922
|
+
predicate_all_b32x, output_accumulator + query_index * head_dim_padded + (dim_tile + 0) * 16,
|
|
1923
|
+
svadd_f32_x(predicate_all_b32x,
|
|
1924
|
+
svld1_f32(predicate_all_b32x,
|
|
1919
1925
|
output_accumulator + query_index * head_dim_padded + (dim_tile + 0) * 16),
|
|
1920
|
-
svread_hor_za32_f32_m(svdup_f32(0),
|
|
1926
|
+
svread_hor_za32_f32_m(svdup_f32(0), predicate_all_b32x, 0, query_index)));
|
|
1921
1927
|
svst1_f32(
|
|
1922
|
-
|
|
1923
|
-
svadd_f32_x(
|
|
1924
|
-
svld1_f32(
|
|
1928
|
+
predicate_all_b32x, output_accumulator + query_index * head_dim_padded + (dim_tile + 1) * 16,
|
|
1929
|
+
svadd_f32_x(predicate_all_b32x,
|
|
1930
|
+
svld1_f32(predicate_all_b32x,
|
|
1925
1931
|
output_accumulator + query_index * head_dim_padded + (dim_tile + 1) * 16),
|
|
1926
|
-
svread_hor_za32_f32_m(svdup_f32(0),
|
|
1932
|
+
svread_hor_za32_f32_m(svdup_f32(0), predicate_all_b32x, 1, query_index)));
|
|
1927
1933
|
svst1_f32(
|
|
1928
|
-
|
|
1929
|
-
svadd_f32_x(
|
|
1930
|
-
svld1_f32(
|
|
1934
|
+
predicate_all_b32x, output_accumulator + query_index * head_dim_padded + (dim_tile + 2) * 16,
|
|
1935
|
+
svadd_f32_x(predicate_all_b32x,
|
|
1936
|
+
svld1_f32(predicate_all_b32x,
|
|
1931
1937
|
output_accumulator + query_index * head_dim_padded + (dim_tile + 2) * 16),
|
|
1932
|
-
svread_hor_za32_f32_m(svdup_f32(0),
|
|
1938
|
+
svread_hor_za32_f32_m(svdup_f32(0), predicate_all_b32x, 2, query_index)));
|
|
1933
1939
|
svst1_f32(
|
|
1934
|
-
|
|
1935
|
-
svadd_f32_x(
|
|
1936
|
-
svld1_f32(
|
|
1940
|
+
predicate_all_b32x, output_accumulator + query_index * head_dim_padded + (dim_tile + 3) * 16,
|
|
1941
|
+
svadd_f32_x(predicate_all_b32x,
|
|
1942
|
+
svld1_f32(predicate_all_b32x,
|
|
1937
1943
|
output_accumulator + query_index * head_dim_padded + (dim_tile + 3) * 16),
|
|
1938
|
-
svread_hor_za32_f32_m(svdup_f32(0),
|
|
1944
|
+
svread_hor_za32_f32_m(svdup_f32(0), predicate_all_b32x, 3, query_index)));
|
|
1939
1945
|
}
|
|
1940
1946
|
}
|
|
1941
1947
|
for (; dim_tile < dim_tile_count; dim_tile++) {
|
|
1942
1948
|
svzero_mask_za(nk_sme_zero_za32_tile_0_);
|
|
1943
1949
|
svmopa_za32_f16_m(
|
|
1944
|
-
0,
|
|
1945
|
-
svld1_f16(
|
|
1950
|
+
0, query_predicate_b16x, predicate_all_b16x, probability_column_0_f32x,
|
|
1951
|
+
svld1_f16(predicate_all_b16x, (float16_t const *)(v_block + (dim_tile * 8 + 0) * 32)));
|
|
1946
1952
|
svmopa_za32_f16_m(
|
|
1947
|
-
0,
|
|
1948
|
-
svld1_f16(
|
|
1953
|
+
0, query_predicate_b16x, predicate_all_b16x, probability_column_1_f32x,
|
|
1954
|
+
svld1_f16(predicate_all_b16x, (float16_t const *)(v_block + (dim_tile * 8 + 1) * 32)));
|
|
1949
1955
|
svmopa_za32_f16_m(
|
|
1950
|
-
0,
|
|
1951
|
-
svld1_f16(
|
|
1956
|
+
0, query_predicate_b16x, predicate_all_b16x, probability_column_2_f32x,
|
|
1957
|
+
svld1_f16(predicate_all_b16x, (float16_t const *)(v_block + (dim_tile * 8 + 2) * 32)));
|
|
1952
1958
|
svmopa_za32_f16_m(
|
|
1953
|
-
0,
|
|
1954
|
-
svld1_f16(
|
|
1959
|
+
0, query_predicate_b16x, predicate_all_b16x, probability_column_3_f32x,
|
|
1960
|
+
svld1_f16(predicate_all_b16x, (float16_t const *)(v_block + (dim_tile * 8 + 3) * 32)));
|
|
1955
1961
|
svmopa_za32_f16_m(
|
|
1956
|
-
0,
|
|
1957
|
-
svld1_f16(
|
|
1962
|
+
0, query_predicate_b16x, predicate_all_b16x, probability_column_4_f32x,
|
|
1963
|
+
svld1_f16(predicate_all_b16x, (float16_t const *)(v_block + (dim_tile * 8 + 4) * 32)));
|
|
1958
1964
|
svmopa_za32_f16_m(
|
|
1959
|
-
0,
|
|
1960
|
-
svld1_f16(
|
|
1965
|
+
0, query_predicate_b16x, predicate_all_b16x, probability_column_5_f32x,
|
|
1966
|
+
svld1_f16(predicate_all_b16x, (float16_t const *)(v_block + (dim_tile * 8 + 5) * 32)));
|
|
1961
1967
|
svmopa_za32_f16_m(
|
|
1962
|
-
0,
|
|
1963
|
-
svld1_f16(
|
|
1968
|
+
0, query_predicate_b16x, predicate_all_b16x, probability_column_6_f32x,
|
|
1969
|
+
svld1_f16(predicate_all_b16x, (float16_t const *)(v_block + (dim_tile * 8 + 6) * 32)));
|
|
1964
1970
|
svmopa_za32_f16_m(
|
|
1965
|
-
0,
|
|
1966
|
-
svld1_f16(
|
|
1971
|
+
0, query_predicate_b16x, predicate_all_b16x, probability_column_7_f32x,
|
|
1972
|
+
svld1_f16(predicate_all_b16x, (float16_t const *)(v_block + (dim_tile * 8 + 7) * 32)));
|
|
1967
1973
|
for (nk_size_t query_index = 0; query_index < valid_query_count; query_index++)
|
|
1968
|
-
svst1_f32(
|
|
1969
|
-
svadd_f32_x(
|
|
1970
|
-
svld1_f32(
|
|
1974
|
+
svst1_f32(predicate_all_b32x, output_accumulator + query_index * head_dim_padded + dim_tile * 16,
|
|
1975
|
+
svadd_f32_x(predicate_all_b32x,
|
|
1976
|
+
svld1_f32(predicate_all_b32x,
|
|
1971
1977
|
output_accumulator + query_index * head_dim_padded + dim_tile * 16),
|
|
1972
|
-
svread_hor_za32_f32_m(svdup_f32(0),
|
|
1978
|
+
svread_hor_za32_f32_m(svdup_f32(0), predicate_all_b32x, 0, query_index)));
|
|
1973
1979
|
}
|
|
1974
1980
|
}
|
|
1975
1981
|
}
|
|
1976
1982
|
|
|
1977
1983
|
// Final normalization
|
|
1978
|
-
svfloat32_t final_sum_f32x = svld1_f32(
|
|
1984
|
+
svfloat32_t final_sum_f32x = svld1_f32(predicate_all_b32x, row_sum);
|
|
1979
1985
|
svfloat32_t ones_f32x = svdup_f32(1.0f);
|
|
1980
1986
|
svfloat32_t zeros_f32x = svdup_f32(0.0f);
|
|
1981
|
-
svbool_t
|
|
1982
|
-
svfloat32_t inv_sum_f32x = svsel_f32(
|
|
1987
|
+
svbool_t sum_positive_b32x = svcmpgt_f32(predicate_all_b32x, final_sum_f32x, zeros_f32x);
|
|
1988
|
+
svfloat32_t inv_sum_f32x = svsel_f32(sum_positive_b32x, svdiv_f32_x(predicate_all_b32x, ones_f32x, final_sum_f32x),
|
|
1983
1989
|
zeros_f32x);
|
|
1984
1990
|
|
|
1985
1991
|
NK_ALIGN64 nk_f32_t inv_sums[16];
|
|
1986
|
-
svst1_f32(
|
|
1992
|
+
svst1_f32(predicate_all_b32x, inv_sums, inv_sum_f32x);
|
|
1987
1993
|
|
|
1988
1994
|
for (nk_size_t query_index = 0; query_index < valid_query_count; query_index++) {
|
|
1989
1995
|
svfloat32_t inv_sum_f32x = svdup_f32(inv_sums[query_index]);
|
|
1990
1996
|
for (nk_size_t dim_offset = 0; dim_offset < head_dim; dim_offset += svcntw()) {
|
|
1991
|
-
svbool_t
|
|
1997
|
+
svbool_t predicate_b32x = svwhilelt_b32_u64(dim_offset, head_dim);
|
|
1992
1998
|
svfloat32_t output_f32x = svmul_f32_x(
|
|
1993
|
-
|
|
1994
|
-
svld1_f32(
|
|
1999
|
+
predicate_b32x,
|
|
2000
|
+
svld1_f32(predicate_b32x, output_accumulator + query_index * head_dim_padded + dim_offset),
|
|
1995
2001
|
inv_sum_f32x);
|
|
1996
|
-
svfloat16_t output_f16x = svcvt_f16_f32_x(
|
|
2002
|
+
svfloat16_t output_f16x = svcvt_f16_f32_x(predicate_b32x, output_f32x);
|
|
1997
2003
|
nk_size_t store_count = (head_dim - dim_offset) < (nk_size_t)svcntw() ? (head_dim - dim_offset)
|
|
1998
2004
|
: (nk_size_t)svcntw();
|
|
1999
|
-
svbool_t
|
|
2000
|
-
svst1_f16(
|
|
2005
|
+
svbool_t predicate_b16x = svwhilelt_b16_u64(0u, store_count);
|
|
2006
|
+
svst1_f16(predicate_b16x, (float16_t *)(output + query_index * head_dim + dim_offset), output_f16x);
|
|
2001
2007
|
}
|
|
2002
2008
|
}
|
|
2003
2009
|
}
|