numkong 7.0.0 → 7.4.2
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +197 -124
- package/binding.gyp +34 -484
- package/c/dispatch_bf16.c +59 -1
- package/c/dispatch_e2m3.c +41 -8
- package/c/dispatch_e3m2.c +49 -8
- package/c/dispatch_e4m3.c +51 -9
- package/c/dispatch_e5m2.c +45 -1
- package/c/dispatch_f16.c +79 -26
- package/c/dispatch_f16c.c +5 -5
- package/c/dispatch_f32.c +56 -0
- package/c/dispatch_f64.c +52 -0
- package/c/dispatch_i4.c +3 -0
- package/c/dispatch_i8.c +62 -3
- package/c/dispatch_other.c +18 -0
- package/c/dispatch_u1.c +54 -9
- package/c/dispatch_u4.c +3 -0
- package/c/dispatch_u8.c +64 -3
- package/c/numkong.c +3 -0
- package/include/README.md +79 -9
- package/include/numkong/attention/sapphireamx.h +278 -276
- package/include/numkong/attention/sme.h +983 -977
- package/include/numkong/attention.h +1 -1
- package/include/numkong/capabilities.h +289 -94
- package/include/numkong/cast/README.md +40 -40
- package/include/numkong/cast/diamond.h +64 -0
- package/include/numkong/cast/haswell.h +42 -194
- package/include/numkong/cast/icelake.h +42 -37
- package/include/numkong/cast/loongsonasx.h +252 -0
- package/include/numkong/cast/neon.h +216 -249
- package/include/numkong/cast/powervsx.h +449 -0
- package/include/numkong/cast/rvv.h +223 -274
- package/include/numkong/cast/sapphire.h +18 -18
- package/include/numkong/cast/serial.h +1018 -944
- package/include/numkong/cast/skylake.h +82 -23
- package/include/numkong/cast/v128relaxed.h +462 -105
- package/include/numkong/cast.h +24 -0
- package/include/numkong/cast.hpp +44 -0
- package/include/numkong/curved/README.md +17 -17
- package/include/numkong/curved/neon.h +131 -7
- package/include/numkong/curved/neonbfdot.h +6 -7
- package/include/numkong/curved/rvv.h +26 -26
- package/include/numkong/curved/smef64.h +186 -182
- package/include/numkong/curved.h +14 -18
- package/include/numkong/dot/README.md +154 -137
- package/include/numkong/dot/alder.h +43 -43
- package/include/numkong/dot/diamond.h +158 -0
- package/include/numkong/dot/genoa.h +4 -30
- package/include/numkong/dot/haswell.h +215 -180
- package/include/numkong/dot/icelake.h +190 -76
- package/include/numkong/dot/loongsonasx.h +671 -0
- package/include/numkong/dot/neon.h +124 -73
- package/include/numkong/dot/neonbfdot.h +11 -12
- package/include/numkong/dot/neonfhm.h +44 -46
- package/include/numkong/dot/neonfp8.h +323 -0
- package/include/numkong/dot/neonsdot.h +190 -76
- package/include/numkong/dot/powervsx.h +752 -0
- package/include/numkong/dot/rvv.h +92 -84
- package/include/numkong/dot/rvvbf16.h +12 -12
- package/include/numkong/dot/rvvhalf.h +12 -12
- package/include/numkong/dot/sapphire.h +4 -4
- package/include/numkong/dot/serial.h +66 -30
- package/include/numkong/dot/sierra.h +31 -31
- package/include/numkong/dot/skylake.h +142 -110
- package/include/numkong/dot/sve.h +217 -177
- package/include/numkong/dot/svebfdot.h +10 -10
- package/include/numkong/dot/svehalf.h +85 -41
- package/include/numkong/dot/svesdot.h +89 -0
- package/include/numkong/dot/v128relaxed.h +124 -89
- package/include/numkong/dot.h +114 -48
- package/include/numkong/dots/README.md +203 -203
- package/include/numkong/dots/alder.h +12 -9
- package/include/numkong/dots/diamond.h +86 -0
- package/include/numkong/dots/genoa.h +10 -4
- package/include/numkong/dots/haswell.h +63 -48
- package/include/numkong/dots/icelake.h +27 -18
- package/include/numkong/dots/loongsonasx.h +176 -0
- package/include/numkong/dots/neon.h +14 -11
- package/include/numkong/dots/neonbfdot.h +4 -3
- package/include/numkong/dots/neonfhm.h +11 -9
- package/include/numkong/dots/neonfp8.h +99 -0
- package/include/numkong/dots/neonsdot.h +48 -12
- package/include/numkong/dots/powervsx.h +194 -0
- package/include/numkong/dots/rvv.h +451 -344
- package/include/numkong/dots/sapphireamx.h +1028 -984
- package/include/numkong/dots/serial.h +213 -197
- package/include/numkong/dots/sierra.h +10 -7
- package/include/numkong/dots/skylake.h +47 -36
- package/include/numkong/dots/sme.h +2001 -2364
- package/include/numkong/dots/smebi32.h +175 -162
- package/include/numkong/dots/smef64.h +328 -323
- package/include/numkong/dots/v128relaxed.h +64 -41
- package/include/numkong/dots.h +573 -293
- package/include/numkong/dots.hpp +45 -43
- package/include/numkong/each/README.md +133 -137
- package/include/numkong/each/haswell.h +6 -6
- package/include/numkong/each/icelake.h +7 -7
- package/include/numkong/each/neon.h +76 -42
- package/include/numkong/each/neonbfdot.h +11 -12
- package/include/numkong/each/neonhalf.h +24 -116
- package/include/numkong/each/rvv.h +28 -28
- package/include/numkong/each/sapphire.h +27 -161
- package/include/numkong/each/serial.h +6 -6
- package/include/numkong/each/skylake.h +7 -7
- package/include/numkong/each/v128relaxed.h +562 -0
- package/include/numkong/each.h +148 -62
- package/include/numkong/each.hpp +2 -2
- package/include/numkong/geospatial/README.md +18 -18
- package/include/numkong/geospatial/haswell.h +365 -325
- package/include/numkong/geospatial/neon.h +350 -306
- package/include/numkong/geospatial/rvv.h +4 -4
- package/include/numkong/geospatial/skylake.h +376 -340
- package/include/numkong/geospatial/v128relaxed.h +366 -327
- package/include/numkong/geospatial.h +17 -17
- package/include/numkong/matrix.hpp +4 -4
- package/include/numkong/maxsim/README.md +14 -14
- package/include/numkong/maxsim/alder.h +6 -6
- package/include/numkong/maxsim/genoa.h +4 -4
- package/include/numkong/maxsim/haswell.h +6 -6
- package/include/numkong/maxsim/icelake.h +18 -18
- package/include/numkong/maxsim/neonsdot.h +21 -21
- package/include/numkong/maxsim/sapphireamx.h +14 -14
- package/include/numkong/maxsim/serial.h +6 -6
- package/include/numkong/maxsim/sme.h +221 -196
- package/include/numkong/maxsim/v128relaxed.h +6 -6
- package/include/numkong/mesh/README.md +62 -56
- package/include/numkong/mesh/haswell.h +339 -464
- package/include/numkong/mesh/neon.h +1100 -519
- package/include/numkong/mesh/neonbfdot.h +36 -68
- package/include/numkong/mesh/rvv.h +530 -435
- package/include/numkong/mesh/serial.h +75 -91
- package/include/numkong/mesh/skylake.h +1627 -302
- package/include/numkong/mesh/v128relaxed.h +443 -330
- package/include/numkong/mesh.h +63 -49
- package/include/numkong/mesh.hpp +4 -4
- package/include/numkong/numkong.h +3 -3
- package/include/numkong/numkong.hpp +1 -0
- package/include/numkong/probability/README.md +23 -19
- package/include/numkong/probability/neon.h +82 -52
- package/include/numkong/probability/rvv.h +28 -23
- package/include/numkong/probability/serial.h +51 -39
- package/include/numkong/probability.h +20 -23
- package/include/numkong/random.h +1 -1
- package/include/numkong/reduce/README.md +143 -138
- package/include/numkong/reduce/alder.h +81 -77
- package/include/numkong/reduce/haswell.h +222 -220
- package/include/numkong/reduce/neon.h +629 -519
- package/include/numkong/reduce/neonbfdot.h +7 -218
- package/include/numkong/reduce/neonfhm.h +9 -381
- package/include/numkong/reduce/neonsdot.h +9 -9
- package/include/numkong/reduce/rvv.h +928 -802
- package/include/numkong/reduce/serial.h +23 -27
- package/include/numkong/reduce/sierra.h +20 -20
- package/include/numkong/reduce/skylake.h +326 -324
- package/include/numkong/reduce/v128relaxed.h +52 -52
- package/include/numkong/reduce.h +4 -23
- package/include/numkong/reduce.hpp +156 -11
- package/include/numkong/scalar/README.md +6 -6
- package/include/numkong/scalar/haswell.h +26 -17
- package/include/numkong/scalar/loongsonasx.h +74 -0
- package/include/numkong/scalar/neon.h +9 -9
- package/include/numkong/scalar/powervsx.h +96 -0
- package/include/numkong/scalar/rvv.h +2 -2
- package/include/numkong/scalar/sapphire.h +21 -10
- package/include/numkong/scalar/serial.h +21 -21
- package/include/numkong/scalar.h +13 -0
- package/include/numkong/set/README.md +28 -28
- package/include/numkong/set/haswell.h +12 -12
- package/include/numkong/set/icelake.h +14 -14
- package/include/numkong/set/loongsonasx.h +181 -0
- package/include/numkong/set/neon.h +17 -18
- package/include/numkong/set/powervsx.h +326 -0
- package/include/numkong/set/rvv.h +4 -4
- package/include/numkong/set/serial.h +6 -6
- package/include/numkong/set/sve.h +60 -59
- package/include/numkong/set/v128relaxed.h +6 -6
- package/include/numkong/set.h +21 -7
- package/include/numkong/sets/README.md +26 -26
- package/include/numkong/sets/loongsonasx.h +52 -0
- package/include/numkong/sets/powervsx.h +65 -0
- package/include/numkong/sets/smebi32.h +395 -364
- package/include/numkong/sets.h +83 -40
- package/include/numkong/sparse/README.md +4 -4
- package/include/numkong/sparse/icelake.h +101 -101
- package/include/numkong/sparse/serial.h +1 -1
- package/include/numkong/sparse/sve2.h +137 -141
- package/include/numkong/sparse/turin.h +12 -12
- package/include/numkong/sparse.h +10 -10
- package/include/numkong/spatial/README.md +230 -226
- package/include/numkong/spatial/alder.h +113 -116
- package/include/numkong/spatial/diamond.h +240 -0
- package/include/numkong/spatial/genoa.h +0 -68
- package/include/numkong/spatial/haswell.h +74 -55
- package/include/numkong/spatial/icelake.h +539 -58
- package/include/numkong/spatial/loongsonasx.h +483 -0
- package/include/numkong/spatial/neon.h +125 -52
- package/include/numkong/spatial/neonbfdot.h +8 -9
- package/include/numkong/spatial/neonfp8.h +258 -0
- package/include/numkong/spatial/neonsdot.h +180 -12
- package/include/numkong/spatial/powervsx.h +738 -0
- package/include/numkong/spatial/rvv.h +146 -139
- package/include/numkong/spatial/rvvbf16.h +17 -12
- package/include/numkong/spatial/rvvhalf.h +13 -10
- package/include/numkong/spatial/serial.h +13 -12
- package/include/numkong/spatial/sierra.h +232 -39
- package/include/numkong/spatial/skylake.h +73 -74
- package/include/numkong/spatial/sve.h +93 -72
- package/include/numkong/spatial/svebfdot.h +29 -29
- package/include/numkong/spatial/svehalf.h +52 -26
- package/include/numkong/spatial/svesdot.h +142 -0
- package/include/numkong/spatial/v128relaxed.h +293 -41
- package/include/numkong/spatial.h +338 -82
- package/include/numkong/spatials/README.md +194 -194
- package/include/numkong/spatials/diamond.h +82 -0
- package/include/numkong/spatials/haswell.h +2 -2
- package/include/numkong/spatials/loongsonasx.h +153 -0
- package/include/numkong/spatials/neonfp8.h +111 -0
- package/include/numkong/spatials/neonsdot.h +34 -0
- package/include/numkong/spatials/powervsx.h +153 -0
- package/include/numkong/spatials/rvv.h +259 -243
- package/include/numkong/spatials/sapphireamx.h +173 -173
- package/include/numkong/spatials/serial.h +2 -2
- package/include/numkong/spatials/skylake.h +2 -2
- package/include/numkong/spatials/sme.h +590 -605
- package/include/numkong/spatials/smef64.h +139 -130
- package/include/numkong/spatials/v128relaxed.h +2 -2
- package/include/numkong/spatials.h +820 -500
- package/include/numkong/spatials.hpp +49 -48
- package/include/numkong/tensor.hpp +406 -17
- package/include/numkong/trigonometry/README.md +19 -19
- package/include/numkong/trigonometry/haswell.h +402 -401
- package/include/numkong/trigonometry/neon.h +386 -387
- package/include/numkong/trigonometry/rvv.h +52 -51
- package/include/numkong/trigonometry/serial.h +13 -13
- package/include/numkong/trigonometry/skylake.h +373 -369
- package/include/numkong/trigonometry/v128relaxed.h +375 -374
- package/include/numkong/trigonometry.h +13 -13
- package/include/numkong/trigonometry.hpp +2 -2
- package/include/numkong/types.h +287 -49
- package/include/numkong/types.hpp +436 -12
- package/include/numkong/vector.hpp +82 -14
- package/javascript/dist/cjs/numkong-wasm.js +6 -12
- package/javascript/dist/cjs/numkong.d.ts +7 -1
- package/javascript/dist/cjs/numkong.js +37 -11
- package/javascript/dist/cjs/types.d.ts +9 -0
- package/javascript/dist/cjs/types.js +96 -0
- package/javascript/dist/esm/numkong-browser.d.ts +14 -0
- package/javascript/dist/esm/numkong-browser.js +23 -0
- package/javascript/dist/esm/numkong-wasm.js +6 -12
- package/javascript/dist/esm/numkong.d.ts +7 -1
- package/javascript/dist/esm/numkong.js +37 -11
- package/javascript/dist/esm/types.d.ts +9 -0
- package/javascript/dist/esm/types.js +96 -0
- package/javascript/node-gyp-build.d.ts +4 -1
- package/javascript/numkong-browser.ts +40 -0
- package/javascript/numkong-wasm.ts +7 -13
- package/javascript/numkong.c +5 -26
- package/javascript/numkong.ts +36 -11
- package/javascript/tsconfig-base.json +1 -0
- package/javascript/tsconfig-cjs.json +6 -1
- package/javascript/types.ts +110 -0
- package/numkong.gypi +101 -0
- package/package.json +34 -13
- package/probes/arm_neon.c +8 -0
- package/probes/arm_neon_bfdot.c +9 -0
- package/probes/arm_neon_fhm.c +9 -0
- package/probes/arm_neon_half.c +8 -0
- package/probes/arm_neon_sdot.c +9 -0
- package/probes/arm_neonfp8.c +9 -0
- package/probes/arm_sme.c +16 -0
- package/probes/arm_sme2.c +16 -0
- package/probes/arm_sme2p1.c +16 -0
- package/probes/arm_sme_bf16.c +16 -0
- package/probes/arm_sme_bi32.c +16 -0
- package/probes/arm_sme_f64.c +16 -0
- package/probes/arm_sme_fa64.c +14 -0
- package/probes/arm_sme_half.c +16 -0
- package/probes/arm_sme_lut2.c +15 -0
- package/probes/arm_sve.c +18 -0
- package/probes/arm_sve2.c +20 -0
- package/probes/arm_sve2p1.c +18 -0
- package/probes/arm_sve_bfdot.c +20 -0
- package/probes/arm_sve_half.c +18 -0
- package/probes/arm_sve_sdot.c +21 -0
- package/probes/loongarch_lasx.c +12 -0
- package/probes/power_vsx.c +12 -0
- package/probes/probe.js +127 -0
- package/probes/riscv_rvv.c +14 -0
- package/probes/riscv_rvv_bb.c +15 -0
- package/probes/riscv_rvv_bf16.c +17 -0
- package/probes/riscv_rvv_half.c +14 -0
- package/probes/wasm_v128relaxed.c +11 -0
- package/probes/x86_alder.c +17 -0
- package/probes/x86_diamond.c +17 -0
- package/probes/x86_genoa.c +17 -0
- package/probes/x86_graniteamx.c +19 -0
- package/probes/x86_haswell.c +11 -0
- package/probes/x86_icelake.c +17 -0
- package/probes/x86_sapphire.c +16 -0
- package/probes/x86_sapphireamx.c +18 -0
- package/probes/x86_sierra.c +17 -0
- package/probes/x86_skylake.c +15 -0
- package/probes/x86_turin.c +17 -0
- package/wasm/numkong-emscripten.js +2 -0
- package/wasm/numkong.d.ts +14 -0
- package/wasm/numkong.js +1124 -0
- package/wasm/numkong.wasm +0 -0
- package/include/numkong/curved/neonhalf.h +0 -212
- package/include/numkong/dot/neonhalf.h +0 -198
- package/include/numkong/dots/neonhalf.h +0 -57
- package/include/numkong/mesh/neonhalf.h +0 -616
- package/include/numkong/reduce/neonhalf.h +0 -157
- package/include/numkong/spatial/neonhalf.h +0 -118
- package/include/numkong/spatial/sapphire.h +0 -343
- package/include/numkong/spatials/neonhalf.h +0 -58
- package/javascript/README.md +0 -246
|
@@ -120,43 +120,43 @@ typedef struct {
|
|
|
120
120
|
* @param x Input vector (16 floats)
|
|
121
121
|
* @return exp(x) for each element
|
|
122
122
|
*/
|
|
123
|
-
NK_INTERNAL __m512 nk_exp_ps_avx512_(__m512
|
|
123
|
+
NK_INTERNAL __m512 nk_exp_ps_avx512_(__m512 x_f32x16) {
|
|
124
124
|
// Constants for Cody-Waite range reduction
|
|
125
|
-
const __m512
|
|
126
|
-
const __m512
|
|
127
|
-
const __m512
|
|
125
|
+
const __m512 log2e_f32x16 = _mm512_set1_ps(1.4426950408889634f);
|
|
126
|
+
const __m512 ln2_high_f32x16 = _mm512_set1_ps(0.693145751953125f);
|
|
127
|
+
const __m512 ln2_low_f32x16 = _mm512_set1_ps(1.42860682030941723212e-6f);
|
|
128
128
|
|
|
129
129
|
// Clamp to avoid overflow/underflow
|
|
130
|
-
const __m512
|
|
131
|
-
const __m512
|
|
132
|
-
|
|
130
|
+
const __m512 max_x_f32x16 = _mm512_set1_ps(88.3762626647949f);
|
|
131
|
+
const __m512 min_x_f32x16 = _mm512_set1_ps(-87.3365447504021f);
|
|
132
|
+
x_f32x16 = _mm512_max_ps(_mm512_min_ps(x_f32x16, max_x_f32x16), min_x_f32x16);
|
|
133
133
|
|
|
134
|
-
//
|
|
135
|
-
__m512
|
|
134
|
+
// n_f32x16 = round(x / ln(2))
|
|
135
|
+
__m512 n_f32x16 = _mm512_roundscale_ps(_mm512_mul_ps(x_f32x16, log2e_f32x16), _MM_FROUND_TO_NEAREST_INT);
|
|
136
136
|
|
|
137
|
-
//
|
|
138
|
-
__m512
|
|
139
|
-
|
|
137
|
+
// r_f32x16 = x - n_f32x16 × ln(2) using Cody-Waite for precision
|
|
138
|
+
__m512 r_f32x16 = _mm512_fnmadd_ps(n_f32x16, ln2_high_f32x16, x_f32x16);
|
|
139
|
+
r_f32x16 = _mm512_fnmadd_ps(n_f32x16, ln2_low_f32x16, r_f32x16);
|
|
140
140
|
|
|
141
|
-
// Polynomial approximation for exp(
|
|
141
|
+
// Polynomial approximation for exp(r_f32x16): Remez minimax degree 6
|
|
142
142
|
// Coefficients optimized for [-ln(2)/2, ln(2)/2]
|
|
143
|
-
__m512
|
|
144
|
-
|
|
145
|
-
|
|
146
|
-
|
|
147
|
-
|
|
148
|
-
|
|
149
|
-
|
|
150
|
-
|
|
151
|
-
|
|
152
|
-
// Reconstruct: exp(x) = 2ⁿ × exp(
|
|
143
|
+
__m512 p_f32x16 = _mm512_set1_ps(1.9875691500e-4f);
|
|
144
|
+
p_f32x16 = _mm512_fmadd_ps(p_f32x16, r_f32x16, _mm512_set1_ps(1.3981999507e-3f));
|
|
145
|
+
p_f32x16 = _mm512_fmadd_ps(p_f32x16, r_f32x16, _mm512_set1_ps(8.3334519073e-3f));
|
|
146
|
+
p_f32x16 = _mm512_fmadd_ps(p_f32x16, r_f32x16, _mm512_set1_ps(4.1665858030e-2f));
|
|
147
|
+
p_f32x16 = _mm512_fmadd_ps(p_f32x16, r_f32x16, _mm512_set1_ps(1.6666665459e-1f));
|
|
148
|
+
p_f32x16 = _mm512_fmadd_ps(p_f32x16, r_f32x16, _mm512_set1_ps(5.0000001201e-1f));
|
|
149
|
+
p_f32x16 = _mm512_fmadd_ps(p_f32x16, r_f32x16, _mm512_set1_ps(1.0f));
|
|
150
|
+
p_f32x16 = _mm512_fmadd_ps(p_f32x16, r_f32x16, _mm512_set1_ps(1.0f));
|
|
151
|
+
|
|
152
|
+
// Reconstruct: exp(x) = 2ⁿ × exp(r_f32x16)
|
|
153
153
|
// 2ⁿ via IEEE 754 exponent manipulation
|
|
154
|
-
__m512i
|
|
155
|
-
|
|
156
|
-
|
|
157
|
-
__m512
|
|
154
|
+
__m512i ni_i32x16 = _mm512_cvtps_epi32(n_f32x16);
|
|
155
|
+
ni_i32x16 = _mm512_add_epi32(ni_i32x16, _mm512_set1_epi32(127));
|
|
156
|
+
ni_i32x16 = _mm512_slli_epi32(ni_i32x16, 23);
|
|
157
|
+
__m512 pow2n_f32x16 = _mm512_castsi512_ps(ni_i32x16);
|
|
158
158
|
|
|
159
|
-
return _mm512_mul_ps(
|
|
159
|
+
return _mm512_mul_ps(p_f32x16, pow2n_f32x16);
|
|
160
160
|
}
|
|
161
161
|
|
|
162
162
|
/**
|
|
@@ -172,41 +172,41 @@ NK_INTERNAL __m512 nk_exp_ps_avx512_(__m512 x) {
|
|
|
172
172
|
* @param x Input vector (16 floats)
|
|
173
173
|
* @return exp(x) approximation
|
|
174
174
|
*/
|
|
175
|
-
NK_INTERNAL __m512 nk_exp_ps_fast_avx512_(__m512
|
|
175
|
+
NK_INTERNAL __m512 nk_exp_ps_fast_avx512_(__m512 x_f32x16) {
|
|
176
176
|
// Constants for Cody-Waite range reduction
|
|
177
|
-
const __m512
|
|
178
|
-
const __m512
|
|
179
|
-
const __m512
|
|
177
|
+
const __m512 log2e_f32x16 = _mm512_set1_ps(1.4426950408889634f);
|
|
178
|
+
const __m512 ln2_high_f32x16 = _mm512_set1_ps(0.693145751953125f);
|
|
179
|
+
const __m512 ln2_low_f32x16 = _mm512_set1_ps(1.42860682030941723212e-6f);
|
|
180
180
|
|
|
181
181
|
// Clamp to avoid overflow/underflow (same as accurate version)
|
|
182
|
-
const __m512
|
|
183
|
-
const __m512
|
|
184
|
-
|
|
182
|
+
const __m512 max_x_f32x16 = _mm512_set1_ps(88.3762626647949f);
|
|
183
|
+
const __m512 min_x_f32x16 = _mm512_set1_ps(-87.3365447504021f);
|
|
184
|
+
x_f32x16 = _mm512_max_ps(_mm512_min_ps(x_f32x16, max_x_f32x16), min_x_f32x16);
|
|
185
185
|
|
|
186
|
-
//
|
|
187
|
-
__m512
|
|
186
|
+
// n_f32x16 = round(x / ln(2))
|
|
187
|
+
__m512 n_f32x16 = _mm512_roundscale_ps(_mm512_mul_ps(x_f32x16, log2e_f32x16), _MM_FROUND_TO_NEAREST_INT);
|
|
188
188
|
|
|
189
|
-
//
|
|
190
|
-
__m512
|
|
191
|
-
|
|
189
|
+
// r_f32x16 = x - n_f32x16 × ln(2) using Cody-Waite for precision
|
|
190
|
+
__m512 r_f32x16 = _mm512_fnmadd_ps(n_f32x16, ln2_high_f32x16, x_f32x16);
|
|
191
|
+
r_f32x16 = _mm512_fnmadd_ps(n_f32x16, ln2_low_f32x16, r_f32x16);
|
|
192
192
|
|
|
193
|
-
// Polynomial approximation for exp(
|
|
193
|
+
// Polynomial approximation for exp(r_f32x16): degree 4
|
|
194
194
|
// Optimized coefficients for [-ln(2)/2, ln(2)/2]
|
|
195
|
-
// exp(
|
|
196
|
-
// Using Horner form: ((c₄ ×
|
|
197
|
-
__m512
|
|
198
|
-
|
|
199
|
-
|
|
200
|
-
|
|
201
|
-
|
|
202
|
-
|
|
203
|
-
// Reconstruct: exp(x) = 2ⁿ × exp(
|
|
204
|
-
__m512i
|
|
205
|
-
|
|
206
|
-
|
|
207
|
-
__m512
|
|
208
|
-
|
|
209
|
-
return _mm512_mul_ps(
|
|
195
|
+
// exp(r_f32x16) ≈ 1 + r_f32x16 + r²/2 + r³/6 + r⁴/24
|
|
196
|
+
// Using Horner form: ((c₄ × r_f32x16 + c₃) × r_f32x16 + c₂) × r_f32x16 + c₁) × r_f32x16 + c₀
|
|
197
|
+
__m512 p_f32x16 = _mm512_set1_ps(4.1666666667e-2f); // 1/24
|
|
198
|
+
p_f32x16 = _mm512_fmadd_ps(p_f32x16, r_f32x16, _mm512_set1_ps(1.6666666667e-1f)); // 1/6
|
|
199
|
+
p_f32x16 = _mm512_fmadd_ps(p_f32x16, r_f32x16, _mm512_set1_ps(5.0000000000e-1f)); // 1/2
|
|
200
|
+
p_f32x16 = _mm512_fmadd_ps(p_f32x16, r_f32x16, _mm512_set1_ps(1.0f)); // 1
|
|
201
|
+
p_f32x16 = _mm512_fmadd_ps(p_f32x16, r_f32x16, _mm512_set1_ps(1.0f)); // 1
|
|
202
|
+
|
|
203
|
+
// Reconstruct: exp(x) = 2ⁿ × exp(r_f32x16)
|
|
204
|
+
__m512i ni_i32x16 = _mm512_cvtps_epi32(n_f32x16);
|
|
205
|
+
ni_i32x16 = _mm512_add_epi32(ni_i32x16, _mm512_set1_epi32(127));
|
|
206
|
+
ni_i32x16 = _mm512_slli_epi32(ni_i32x16, 23);
|
|
207
|
+
__m512 pow2n_f32x16 = _mm512_castsi512_ps(ni_i32x16);
|
|
208
|
+
|
|
209
|
+
return _mm512_mul_ps(p_f32x16, pow2n_f32x16);
|
|
210
210
|
}
|
|
211
211
|
|
|
212
212
|
/**
|
|
@@ -228,8 +228,8 @@ NK_INTERNAL __m512 nk_exp_ps_fast_avx512_(__m512 x) {
|
|
|
228
228
|
* Tracks per-row running maximum and sum for 16 rows.
|
|
229
229
|
*/
|
|
230
230
|
typedef struct {
|
|
231
|
-
__m512
|
|
232
|
-
__m512
|
|
231
|
+
__m512 row_max_f32x16; ///< Running max per row (16 values)
|
|
232
|
+
__m512 row_sum_f32x16; ///< Running sum of exp(x - max) per row
|
|
233
233
|
} nk_attention_softmax_row_state_t;
|
|
234
234
|
|
|
235
235
|
/**
|
|
@@ -246,80 +246,80 @@ NK_INTERNAL void nk_attention_softmax_update_bc32_(nk_attention_softmax_row_stat
|
|
|
246
246
|
nk_f32_t scale,
|
|
247
247
|
nk_f32_t *weights_out) { // [16, 32] output weights
|
|
248
248
|
|
|
249
|
-
__m512
|
|
249
|
+
__m512 scale_v_f32x16 = _mm512_set1_ps(scale);
|
|
250
250
|
|
|
251
251
|
// Load and scale all scores, compute per-row max
|
|
252
252
|
// Store in temporary arrays to avoid register pressure
|
|
253
|
-
__m512
|
|
253
|
+
__m512 s_scaled_f32x16[16][2];
|
|
254
254
|
NK_ALIGN64 float row_maxes[16];
|
|
255
255
|
|
|
256
256
|
// Process 4 rows at a time for ILP
|
|
257
257
|
for (int i = 0; i < 16; i += 4) {
|
|
258
258
|
// Row i
|
|
259
|
-
|
|
260
|
-
|
|
261
|
-
__m512
|
|
259
|
+
s_scaled_f32x16[i][0] = _mm512_mul_ps(_mm512_load_ps(scores + i * 32 + 0), scale_v_f32x16);
|
|
260
|
+
s_scaled_f32x16[i][1] = _mm512_mul_ps(_mm512_load_ps(scores + i * 32 + 16), scale_v_f32x16);
|
|
261
|
+
__m512 m0_f32x16 = _mm512_max_ps(s_scaled_f32x16[i][0], s_scaled_f32x16[i][1]);
|
|
262
262
|
|
|
263
263
|
// Row i+1
|
|
264
|
-
|
|
265
|
-
|
|
266
|
-
__m512
|
|
264
|
+
s_scaled_f32x16[i + 1][0] = _mm512_mul_ps(_mm512_load_ps(scores + (i + 1) * 32 + 0), scale_v_f32x16);
|
|
265
|
+
s_scaled_f32x16[i + 1][1] = _mm512_mul_ps(_mm512_load_ps(scores + (i + 1) * 32 + 16), scale_v_f32x16);
|
|
266
|
+
__m512 m1_f32x16 = _mm512_max_ps(s_scaled_f32x16[i + 1][0], s_scaled_f32x16[i + 1][1]);
|
|
267
267
|
|
|
268
268
|
// Row i+2
|
|
269
|
-
|
|
270
|
-
|
|
271
|
-
__m512
|
|
269
|
+
s_scaled_f32x16[i + 2][0] = _mm512_mul_ps(_mm512_load_ps(scores + (i + 2) * 32 + 0), scale_v_f32x16);
|
|
270
|
+
s_scaled_f32x16[i + 2][1] = _mm512_mul_ps(_mm512_load_ps(scores + (i + 2) * 32 + 16), scale_v_f32x16);
|
|
271
|
+
__m512 m2_f32x16 = _mm512_max_ps(s_scaled_f32x16[i + 2][0], s_scaled_f32x16[i + 2][1]);
|
|
272
272
|
|
|
273
273
|
// Row i+3
|
|
274
|
-
|
|
275
|
-
|
|
276
|
-
__m512
|
|
274
|
+
s_scaled_f32x16[i + 3][0] = _mm512_mul_ps(_mm512_load_ps(scores + (i + 3) * 32 + 0), scale_v_f32x16);
|
|
275
|
+
s_scaled_f32x16[i + 3][1] = _mm512_mul_ps(_mm512_load_ps(scores + (i + 3) * 32 + 16), scale_v_f32x16);
|
|
276
|
+
__m512 m3_f32x16 = _mm512_max_ps(s_scaled_f32x16[i + 3][0], s_scaled_f32x16[i + 3][1]);
|
|
277
277
|
|
|
278
278
|
// Reduce to scalar max
|
|
279
|
-
row_maxes[i] = _mm512_reduce_max_ps(
|
|
280
|
-
row_maxes[i + 1] = _mm512_reduce_max_ps(
|
|
281
|
-
row_maxes[i + 2] = _mm512_reduce_max_ps(
|
|
282
|
-
row_maxes[i + 3] = _mm512_reduce_max_ps(
|
|
279
|
+
row_maxes[i] = _mm512_reduce_max_ps(m0_f32x16);
|
|
280
|
+
row_maxes[i + 1] = _mm512_reduce_max_ps(m1_f32x16);
|
|
281
|
+
row_maxes[i + 2] = _mm512_reduce_max_ps(m2_f32x16);
|
|
282
|
+
row_maxes[i + 3] = _mm512_reduce_max_ps(m3_f32x16);
|
|
283
283
|
}
|
|
284
284
|
|
|
285
|
-
__m512
|
|
286
|
-
__m512
|
|
287
|
-
__m512
|
|
285
|
+
__m512 row_max_new_f32x16 = _mm512_load_ps(row_maxes);
|
|
286
|
+
__m512 old_max_f32x16 = state->row_max_f32x16;
|
|
287
|
+
__m512 new_max_f32x16 = _mm512_max_ps(old_max_f32x16, row_max_new_f32x16);
|
|
288
288
|
|
|
289
289
|
// Rescale old sum
|
|
290
|
-
__m512
|
|
291
|
-
__m512
|
|
290
|
+
__m512 correction_f32x16 = nk_exp_ps_avx512_(_mm512_sub_ps(old_max_f32x16, new_max_f32x16));
|
|
291
|
+
__m512 new_sum_f32x16 = _mm512_mul_ps(state->row_sum_f32x16, correction_f32x16);
|
|
292
292
|
|
|
293
|
-
// Compute P = exp(S -
|
|
294
|
-
NK_ALIGN64
|
|
295
|
-
NK_ALIGN64
|
|
296
|
-
_mm512_store_ps(new_max_arr,
|
|
293
|
+
// Compute P = exp(S - new_max_f32x16) and accumulate sums
|
|
294
|
+
NK_ALIGN64 nk_f32_t new_max_arr[16];
|
|
295
|
+
NK_ALIGN64 nk_f32_t row_sums[16];
|
|
296
|
+
_mm512_store_ps(new_max_arr, new_max_f32x16);
|
|
297
297
|
|
|
298
298
|
// Process rows
|
|
299
299
|
for (int i = 0; i < 16; i += 2) {
|
|
300
|
-
__m512
|
|
301
|
-
__m512
|
|
300
|
+
__m512 max_i_f32x16 = _mm512_set1_ps(new_max_arr[i]);
|
|
301
|
+
__m512 max_i1_f32x16 = _mm512_set1_ps(new_max_arr[i + 1]);
|
|
302
302
|
|
|
303
303
|
// Row i
|
|
304
|
-
__m512
|
|
305
|
-
__m512
|
|
306
|
-
_mm512_store_ps(weights_out + i * 32 + 0,
|
|
307
|
-
_mm512_store_ps(weights_out + i * 32 + 16,
|
|
308
|
-
row_sums[i] = _mm512_reduce_add_ps(
|
|
304
|
+
__m512 p0_i_f32x16 = nk_exp_ps_avx512_(_mm512_sub_ps(s_scaled_f32x16[i][0], max_i_f32x16));
|
|
305
|
+
__m512 p1_i_f32x16 = nk_exp_ps_avx512_(_mm512_sub_ps(s_scaled_f32x16[i][1], max_i_f32x16));
|
|
306
|
+
_mm512_store_ps(weights_out + i * 32 + 0, p0_i_f32x16);
|
|
307
|
+
_mm512_store_ps(weights_out + i * 32 + 16, p1_i_f32x16);
|
|
308
|
+
row_sums[i] = _mm512_reduce_add_ps(p0_i_f32x16) + _mm512_reduce_add_ps(p1_i_f32x16);
|
|
309
309
|
|
|
310
310
|
// Row i+1
|
|
311
|
-
__m512
|
|
312
|
-
__m512
|
|
313
|
-
_mm512_store_ps(weights_out + (i + 1) * 32 + 0,
|
|
314
|
-
_mm512_store_ps(weights_out + (i + 1) * 32 + 16,
|
|
315
|
-
row_sums[i + 1] = _mm512_reduce_add_ps(
|
|
311
|
+
__m512 p0_i1_f32x16 = nk_exp_ps_avx512_(_mm512_sub_ps(s_scaled_f32x16[i + 1][0], max_i1_f32x16));
|
|
312
|
+
__m512 p1_i1_f32x16 = nk_exp_ps_avx512_(_mm512_sub_ps(s_scaled_f32x16[i + 1][1], max_i1_f32x16));
|
|
313
|
+
_mm512_store_ps(weights_out + (i + 1) * 32 + 0, p0_i1_f32x16);
|
|
314
|
+
_mm512_store_ps(weights_out + (i + 1) * 32 + 16, p1_i1_f32x16);
|
|
315
|
+
row_sums[i + 1] = _mm512_reduce_add_ps(p0_i1_f32x16) + _mm512_reduce_add_ps(p1_i1_f32x16);
|
|
316
316
|
}
|
|
317
317
|
|
|
318
318
|
// Add row sums to running sum vectorially
|
|
319
|
-
|
|
319
|
+
new_sum_f32x16 = _mm512_add_ps(new_sum_f32x16, _mm512_load_ps(row_sums));
|
|
320
320
|
|
|
321
|
-
state->
|
|
322
|
-
state->
|
|
321
|
+
state->row_max_f32x16 = new_max_f32x16;
|
|
322
|
+
state->row_sum_f32x16 = new_sum_f32x16;
|
|
323
323
|
}
|
|
324
324
|
|
|
325
325
|
/**
|
|
@@ -335,81 +335,81 @@ NK_INTERNAL void nk_attention_softmax_update_bc32_fast_(nk_attention_softmax_row
|
|
|
335
335
|
nk_f32_t scale,
|
|
336
336
|
nk_f32_t *weights_out) { // [16, 32] output weights
|
|
337
337
|
|
|
338
|
-
__m512
|
|
338
|
+
__m512 scale_v_f32x16 = _mm512_set1_ps(scale);
|
|
339
339
|
|
|
340
340
|
// Load and scale all scores, compute per-row max
|
|
341
|
-
__m512
|
|
341
|
+
__m512 s_scaled_f32x16[16][2];
|
|
342
342
|
NK_ALIGN64 float row_maxes[16];
|
|
343
343
|
|
|
344
344
|
// Process 4 rows at a time for ILP
|
|
345
345
|
for (int i = 0; i < 16; i += 4) {
|
|
346
|
-
|
|
347
|
-
|
|
348
|
-
__m512
|
|
349
|
-
|
|
350
|
-
|
|
351
|
-
|
|
352
|
-
__m512
|
|
353
|
-
|
|
354
|
-
|
|
355
|
-
|
|
356
|
-
__m512
|
|
357
|
-
|
|
358
|
-
|
|
359
|
-
|
|
360
|
-
__m512
|
|
361
|
-
|
|
362
|
-
row_maxes[i] = _mm512_reduce_max_ps(
|
|
363
|
-
row_maxes[i + 1] = _mm512_reduce_max_ps(
|
|
364
|
-
row_maxes[i + 2] = _mm512_reduce_max_ps(
|
|
365
|
-
row_maxes[i + 3] = _mm512_reduce_max_ps(
|
|
346
|
+
s_scaled_f32x16[i][0] = _mm512_mul_ps(_mm512_load_ps(scores + i * 32 + 0), scale_v_f32x16);
|
|
347
|
+
s_scaled_f32x16[i][1] = _mm512_mul_ps(_mm512_load_ps(scores + i * 32 + 16), scale_v_f32x16);
|
|
348
|
+
__m512 m0_f32x16 = _mm512_max_ps(s_scaled_f32x16[i][0], s_scaled_f32x16[i][1]);
|
|
349
|
+
|
|
350
|
+
s_scaled_f32x16[i + 1][0] = _mm512_mul_ps(_mm512_load_ps(scores + (i + 1) * 32 + 0), scale_v_f32x16);
|
|
351
|
+
s_scaled_f32x16[i + 1][1] = _mm512_mul_ps(_mm512_load_ps(scores + (i + 1) * 32 + 16), scale_v_f32x16);
|
|
352
|
+
__m512 m1_f32x16 = _mm512_max_ps(s_scaled_f32x16[i + 1][0], s_scaled_f32x16[i + 1][1]);
|
|
353
|
+
|
|
354
|
+
s_scaled_f32x16[i + 2][0] = _mm512_mul_ps(_mm512_load_ps(scores + (i + 2) * 32 + 0), scale_v_f32x16);
|
|
355
|
+
s_scaled_f32x16[i + 2][1] = _mm512_mul_ps(_mm512_load_ps(scores + (i + 2) * 32 + 16), scale_v_f32x16);
|
|
356
|
+
__m512 m2_f32x16 = _mm512_max_ps(s_scaled_f32x16[i + 2][0], s_scaled_f32x16[i + 2][1]);
|
|
357
|
+
|
|
358
|
+
s_scaled_f32x16[i + 3][0] = _mm512_mul_ps(_mm512_load_ps(scores + (i + 3) * 32 + 0), scale_v_f32x16);
|
|
359
|
+
s_scaled_f32x16[i + 3][1] = _mm512_mul_ps(_mm512_load_ps(scores + (i + 3) * 32 + 16), scale_v_f32x16);
|
|
360
|
+
__m512 m3_f32x16 = _mm512_max_ps(s_scaled_f32x16[i + 3][0], s_scaled_f32x16[i + 3][1]);
|
|
361
|
+
|
|
362
|
+
row_maxes[i] = _mm512_reduce_max_ps(m0_f32x16);
|
|
363
|
+
row_maxes[i + 1] = _mm512_reduce_max_ps(m1_f32x16);
|
|
364
|
+
row_maxes[i + 2] = _mm512_reduce_max_ps(m2_f32x16);
|
|
365
|
+
row_maxes[i + 3] = _mm512_reduce_max_ps(m3_f32x16);
|
|
366
366
|
}
|
|
367
367
|
|
|
368
|
-
__m512
|
|
369
|
-
__m512
|
|
370
|
-
__m512
|
|
368
|
+
__m512 row_max_new_f32x16 = _mm512_load_ps(row_maxes);
|
|
369
|
+
__m512 old_max_f32x16 = state->row_max_f32x16;
|
|
370
|
+
__m512 new_max_f32x16 = _mm512_max_ps(old_max_f32x16, row_max_new_f32x16);
|
|
371
371
|
|
|
372
372
|
// Rescale old sum using fast exp
|
|
373
|
-
__m512
|
|
374
|
-
__m512
|
|
373
|
+
__m512 correction_f32x16 = nk_exp_ps_fast_avx512_(_mm512_sub_ps(old_max_f32x16, new_max_f32x16));
|
|
374
|
+
__m512 new_sum_f32x16 = _mm512_mul_ps(state->row_sum_f32x16, correction_f32x16);
|
|
375
375
|
|
|
376
|
-
// Compute P = exp(S -
|
|
377
|
-
NK_ALIGN64
|
|
378
|
-
NK_ALIGN64
|
|
379
|
-
_mm512_store_ps(new_max_arr,
|
|
376
|
+
// Compute P = exp(S - new_max_f32x16) using fast exp
|
|
377
|
+
NK_ALIGN64 nk_f32_t new_max_arr[16];
|
|
378
|
+
NK_ALIGN64 nk_f32_t row_sums[16];
|
|
379
|
+
_mm512_store_ps(new_max_arr, new_max_f32x16);
|
|
380
380
|
|
|
381
381
|
// Process rows with fast exp
|
|
382
382
|
for (int i = 0; i < 16; i += 2) {
|
|
383
|
-
__m512
|
|
384
|
-
__m512
|
|
383
|
+
__m512 max_i_f32x16 = _mm512_set1_ps(new_max_arr[i]);
|
|
384
|
+
__m512 max_i1_f32x16 = _mm512_set1_ps(new_max_arr[i + 1]);
|
|
385
385
|
|
|
386
386
|
// Row i
|
|
387
|
-
__m512
|
|
388
|
-
__m512
|
|
389
|
-
_mm512_store_ps(weights_out + i * 32 + 0,
|
|
390
|
-
_mm512_store_ps(weights_out + i * 32 + 16,
|
|
391
|
-
row_sums[i] = _mm512_reduce_add_ps(
|
|
387
|
+
__m512 p0_i_f32x16 = nk_exp_ps_fast_avx512_(_mm512_sub_ps(s_scaled_f32x16[i][0], max_i_f32x16));
|
|
388
|
+
__m512 p1_i_f32x16 = nk_exp_ps_fast_avx512_(_mm512_sub_ps(s_scaled_f32x16[i][1], max_i_f32x16));
|
|
389
|
+
_mm512_store_ps(weights_out + i * 32 + 0, p0_i_f32x16);
|
|
390
|
+
_mm512_store_ps(weights_out + i * 32 + 16, p1_i_f32x16);
|
|
391
|
+
row_sums[i] = _mm512_reduce_add_ps(p0_i_f32x16) + _mm512_reduce_add_ps(p1_i_f32x16);
|
|
392
392
|
|
|
393
393
|
// Row i+1
|
|
394
|
-
__m512
|
|
395
|
-
__m512
|
|
396
|
-
_mm512_store_ps(weights_out + (i + 1) * 32 + 0,
|
|
397
|
-
_mm512_store_ps(weights_out + (i + 1) * 32 + 16,
|
|
398
|
-
row_sums[i + 1] = _mm512_reduce_add_ps(
|
|
394
|
+
__m512 p0_i1_f32x16 = nk_exp_ps_fast_avx512_(_mm512_sub_ps(s_scaled_f32x16[i + 1][0], max_i1_f32x16));
|
|
395
|
+
__m512 p1_i1_f32x16 = nk_exp_ps_fast_avx512_(_mm512_sub_ps(s_scaled_f32x16[i + 1][1], max_i1_f32x16));
|
|
396
|
+
_mm512_store_ps(weights_out + (i + 1) * 32 + 0, p0_i1_f32x16);
|
|
397
|
+
_mm512_store_ps(weights_out + (i + 1) * 32 + 16, p1_i1_f32x16);
|
|
398
|
+
row_sums[i + 1] = _mm512_reduce_add_ps(p0_i1_f32x16) + _mm512_reduce_add_ps(p1_i1_f32x16);
|
|
399
399
|
}
|
|
400
400
|
|
|
401
|
-
|
|
401
|
+
new_sum_f32x16 = _mm512_add_ps(new_sum_f32x16, _mm512_load_ps(row_sums));
|
|
402
402
|
|
|
403
|
-
state->
|
|
404
|
-
state->
|
|
403
|
+
state->row_max_f32x16 = new_max_f32x16;
|
|
404
|
+
state->row_sum_f32x16 = new_sum_f32x16;
|
|
405
405
|
}
|
|
406
406
|
|
|
407
407
|
/**
|
|
408
408
|
* @brief Initialize online softmax state.
|
|
409
409
|
*/
|
|
410
410
|
NK_INTERNAL void nk_attention_softmax_init_(nk_attention_softmax_row_state_t *state) {
|
|
411
|
-
state->
|
|
412
|
-
state->
|
|
411
|
+
state->row_max_f32x16 = _mm512_set1_ps(NK_F32_MIN);
|
|
412
|
+
state->row_sum_f32x16 = _mm512_setzero_ps();
|
|
413
413
|
}
|
|
414
414
|
|
|
415
415
|
/**
|
|
@@ -430,43 +430,43 @@ NK_INTERNAL void nk_attention_softmax_init_(nk_attention_softmax_row_state_t *st
|
|
|
430
430
|
NK_INTERNAL void nk_attention_softmax_update_(nk_attention_softmax_row_state_t *state, nk_f32_t const *scores,
|
|
431
431
|
nk_f32_t scale, nk_f32_t *weights_out) {
|
|
432
432
|
|
|
433
|
-
__m512
|
|
433
|
+
__m512 scale_v_f32x16 = _mm512_set1_ps(scale);
|
|
434
434
|
|
|
435
435
|
// Load scores into 16 ZMM registers (one per row)
|
|
436
|
-
__m512
|
|
437
|
-
for (int i = 0; i < 16; i++) {
|
|
436
|
+
__m512 s_f32x16[16];
|
|
437
|
+
for (int i = 0; i < 16; i++) { s_f32x16[i] = _mm512_mul_ps(_mm512_load_ps(scores + i * 16), scale_v_f32x16); }
|
|
438
438
|
|
|
439
439
|
// Per-row max (each row has 16 elements, we need max across those 16)
|
|
440
440
|
// _mm512_reduce_max_ps returns a float scalar
|
|
441
441
|
NK_ALIGN64 float row_maxes[16];
|
|
442
|
-
for (int i = 0; i < 16; i++) { row_maxes[i] = _mm512_reduce_max_ps(
|
|
443
|
-
__m512
|
|
442
|
+
for (int i = 0; i < 16; i++) { row_maxes[i] = _mm512_reduce_max_ps(s_f32x16[i]); }
|
|
443
|
+
__m512 row_max_new_f32x16 = _mm512_load_ps(row_maxes);
|
|
444
444
|
|
|
445
445
|
// Update running max
|
|
446
|
-
__m512
|
|
447
|
-
__m512
|
|
446
|
+
__m512 old_max_f32x16 = state->row_max_f32x16;
|
|
447
|
+
__m512 new_max_f32x16 = _mm512_max_ps(old_max_f32x16, row_max_new_f32x16);
|
|
448
448
|
|
|
449
449
|
// Rescale old sum: l = l × exp(oldₘₐₓ - newₘₐₓ)
|
|
450
|
-
__m512
|
|
451
|
-
__m512
|
|
450
|
+
__m512 correction_f32x16 = nk_exp_ps_avx512_(_mm512_sub_ps(old_max_f32x16, new_max_f32x16));
|
|
451
|
+
__m512 old_sum_rescaled_f32x16 = _mm512_mul_ps(state->row_sum_f32x16, correction_f32x16);
|
|
452
452
|
|
|
453
453
|
// Compute P = exp(S - newₘₐₓ) for each row, accumulate sum
|
|
454
|
-
__m512
|
|
455
|
-
|
|
456
|
-
_mm512_store_ps(new_max_arr,
|
|
454
|
+
__m512 new_sum_f32x16 = old_sum_rescaled_f32x16;
|
|
455
|
+
nk_f32_t new_max_arr[16];
|
|
456
|
+
_mm512_store_ps(new_max_arr, new_max_f32x16);
|
|
457
457
|
|
|
458
458
|
for (int i = 0; i < 16; i++) {
|
|
459
|
-
__m512
|
|
460
|
-
__m512
|
|
461
|
-
_mm512_store_ps(weights_out + i * 16,
|
|
459
|
+
__m512 max_broadcast_f32x16 = _mm512_set1_ps(new_max_arr[i]);
|
|
460
|
+
__m512 p_f32x16 = nk_exp_ps_avx512_(_mm512_sub_ps(s_f32x16[i], max_broadcast_f32x16));
|
|
461
|
+
_mm512_store_ps(weights_out + i * 16, p_f32x16);
|
|
462
462
|
|
|
463
463
|
// Add row sum to running sum (at position i)
|
|
464
|
-
|
|
465
|
-
|
|
464
|
+
nk_f32_t row_sum = _mm512_reduce_add_ps(p_f32x16);
|
|
465
|
+
new_sum_f32x16 = _mm512_mask_add_ps(new_sum_f32x16, 1u << i, new_sum_f32x16, _mm512_set1_ps(row_sum));
|
|
466
466
|
}
|
|
467
467
|
|
|
468
|
-
state->
|
|
469
|
-
state->
|
|
468
|
+
state->row_max_f32x16 = new_max_f32x16;
|
|
469
|
+
state->row_sum_f32x16 = new_sum_f32x16;
|
|
470
470
|
}
|
|
471
471
|
|
|
472
472
|
/**
|
|
@@ -480,18 +480,19 @@ NK_INTERNAL void nk_attention_softmax_update_(nk_attention_softmax_row_state_t *
|
|
|
480
480
|
* @param old_max Previous running max per row (16 values)
|
|
481
481
|
* @param new_max New running max per row (16 values)
|
|
482
482
|
*/
|
|
483
|
-
NK_INTERNAL void nk_attention_rescale_output_(nk_f32_t *output, nk_size_t head_dim, __m512
|
|
483
|
+
NK_INTERNAL void nk_attention_rescale_output_(nk_f32_t *output, nk_size_t head_dim, __m512 old_max_f32x16,
|
|
484
|
+
__m512 new_max_f32x16) {
|
|
484
485
|
|
|
485
|
-
__m512
|
|
486
|
-
|
|
487
|
-
_mm512_store_ps(corr_arr,
|
|
486
|
+
__m512 correction_f32x16 = nk_exp_ps_avx512_(_mm512_sub_ps(old_max_f32x16, new_max_f32x16));
|
|
487
|
+
nk_f32_t corr_arr[16];
|
|
488
|
+
_mm512_store_ps(corr_arr, correction_f32x16);
|
|
488
489
|
|
|
489
490
|
for (nk_size_t row = 0; row < 16; row++) {
|
|
490
|
-
__m512
|
|
491
|
+
__m512 corr_v_f32x16 = _mm512_set1_ps(corr_arr[row]);
|
|
491
492
|
for (nk_size_t col = 0; col < head_dim; col += 16) {
|
|
492
|
-
__m512
|
|
493
|
-
|
|
494
|
-
_mm512_store_ps(output + row * head_dim + col,
|
|
493
|
+
__m512 o_f32x16 = _mm512_load_ps(output + row * head_dim + col);
|
|
494
|
+
o_f32x16 = _mm512_mul_ps(o_f32x16, corr_v_f32x16);
|
|
495
|
+
_mm512_store_ps(output + row * head_dim + col, o_f32x16);
|
|
495
496
|
}
|
|
496
497
|
}
|
|
497
498
|
}
|
|
@@ -790,22 +791,22 @@ NK_PUBLIC void nk_attention_bf16_sapphireamx(nk_bf16_t const *q, void const *kv_
|
|
|
790
791
|
// Phase 1: Compute S = Q × Kᵀ using AVX-512 FMA
|
|
791
792
|
for (nk_size_t qi = 0; qi < valid_q; qi++) {
|
|
792
793
|
for (nk_size_t ki = 0; ki < valid_kv; ki++) {
|
|
793
|
-
__m512
|
|
794
|
+
__m512 sum_v_f32x16 = _mm512_setzero_ps();
|
|
794
795
|
nk_size_t d = 0;
|
|
795
796
|
// Vectorized loop over head_dim
|
|
796
797
|
for (; d + 16 <= head_dim; d += 16) {
|
|
797
|
-
__m512
|
|
798
|
+
__m512 q_v_f32x16 = _mm512_loadu_ps(&q_block[qi * head_dim + d]);
|
|
798
799
|
// Kᵀ is stored as [head_dim, kv], gather is slow, use scalar for now
|
|
799
|
-
__m512
|
|
800
|
+
__m512 k_v_f32x16 = _mm512_set_ps(
|
|
800
801
|
k_block[(d + 15) * 16 + ki], k_block[(d + 14) * 16 + ki], k_block[(d + 13) * 16 + ki],
|
|
801
802
|
k_block[(d + 12) * 16 + ki], k_block[(d + 11) * 16 + ki], k_block[(d + 10) * 16 + ki],
|
|
802
803
|
k_block[(d + 9) * 16 + ki], k_block[(d + 8) * 16 + ki], k_block[(d + 7) * 16 + ki],
|
|
803
804
|
k_block[(d + 6) * 16 + ki], k_block[(d + 5) * 16 + ki], k_block[(d + 4) * 16 + ki],
|
|
804
805
|
k_block[(d + 3) * 16 + ki], k_block[(d + 2) * 16 + ki], k_block[(d + 1) * 16 + ki],
|
|
805
806
|
k_block[(d + 0) * 16 + ki]);
|
|
806
|
-
|
|
807
|
+
sum_v_f32x16 = _mm512_fmadd_ps(q_v_f32x16, k_v_f32x16, sum_v_f32x16);
|
|
807
808
|
}
|
|
808
|
-
nk_f32_t sum = _mm512_reduce_add_ps(
|
|
809
|
+
nk_f32_t sum = _mm512_reduce_add_ps(sum_v_f32x16);
|
|
809
810
|
// Scalar tail
|
|
810
811
|
for (; d < head_dim; d++) { sum += q_block[qi * head_dim + d] * k_block[d * 16 + ki]; }
|
|
811
812
|
scores[qi * 16 + ki] = sum;
|
|
@@ -819,11 +820,11 @@ NK_PUBLIC void nk_attention_bf16_sapphireamx(nk_bf16_t const *q, void const *kv_
|
|
|
819
820
|
}
|
|
820
821
|
|
|
821
822
|
// Phase 2: Online softmax update
|
|
822
|
-
__m512
|
|
823
|
+
__m512 old_max_f32x16 = softmax_state.row_max_f32x16;
|
|
823
824
|
nk_attention_softmax_update_(&softmax_state, scores, scale, weights);
|
|
824
825
|
|
|
825
826
|
// Rescale output accumulator if max changed
|
|
826
|
-
nk_attention_rescale_output_(o_acc, head_dim_padded,
|
|
827
|
+
nk_attention_rescale_output_(o_acc, head_dim_padded, old_max_f32x16, softmax_state.row_max_f32x16);
|
|
827
828
|
|
|
828
829
|
// Extract V block: V[valid_kv, head_dim] using bulk extraction
|
|
829
830
|
nk_attention_extract_v_block_(v_packed, v_block, kv_h, kvb, valid_kv, head_dim, kv_len);
|
|
@@ -833,13 +834,13 @@ NK_PUBLIC void nk_attention_bf16_sapphireamx(nk_bf16_t const *q, void const *kv_
|
|
|
833
834
|
nk_size_t d = 0;
|
|
834
835
|
// Vectorized loop over head_dim
|
|
835
836
|
for (; d + 16 <= head_dim; d += 16) {
|
|
836
|
-
__m512
|
|
837
|
+
__m512 acc_v_f32x16 = _mm512_loadu_ps(&o_acc[qi * head_dim_padded + d]);
|
|
837
838
|
for (nk_size_t ki = 0; ki < valid_kv; ki++) {
|
|
838
|
-
__m512
|
|
839
|
-
__m512
|
|
840
|
-
|
|
839
|
+
__m512 p_v_f32x16 = _mm512_set1_ps(weights[qi * 16 + ki]);
|
|
840
|
+
__m512 v_v_f32x16 = _mm512_loadu_ps(&v_block[ki * head_dim + d]);
|
|
841
|
+
acc_v_f32x16 = _mm512_fmadd_ps(p_v_f32x16, v_v_f32x16, acc_v_f32x16);
|
|
841
842
|
}
|
|
842
|
-
_mm512_storeu_ps(&o_acc[qi * head_dim_padded + d],
|
|
843
|
+
_mm512_storeu_ps(&o_acc[qi * head_dim_padded + d], acc_v_f32x16);
|
|
843
844
|
}
|
|
844
845
|
// Scalar tail
|
|
845
846
|
for (; d < head_dim; d++) {
|
|
@@ -853,8 +854,8 @@ NK_PUBLIC void nk_attention_bf16_sapphireamx(nk_bf16_t const *q, void const *kv_
|
|
|
853
854
|
}
|
|
854
855
|
|
|
855
856
|
// Finalize: normalize O by row sums
|
|
856
|
-
|
|
857
|
-
_mm512_store_ps(row_sums, softmax_state.
|
|
857
|
+
nk_f32_t row_sums[16];
|
|
858
|
+
_mm512_store_ps(row_sums, softmax_state.row_sum_f32x16);
|
|
858
859
|
|
|
859
860
|
for (nk_size_t qi = 0; qi < valid_q; qi++) {
|
|
860
861
|
nk_f32_t inv_sum = 1.0f / row_sums[qi];
|
|
@@ -918,12 +919,12 @@ NK_PUBLIC void nk_attention_bf16_amx_bc32_sapphireamx(nk_bf16_t const *q, void c
|
|
|
918
919
|
nk_attention_softmax_init_(&softmax_state);
|
|
919
920
|
|
|
920
921
|
// Zero output accumulator using SIMD
|
|
921
|
-
__m512
|
|
922
|
+
__m512 zero_f32x16 = _mm512_setzero_ps();
|
|
922
923
|
for (nk_size_t i = 0; i < 16 * head_dim_padded; i += 64) {
|
|
923
|
-
_mm512_store_ps(&o_acc[i],
|
|
924
|
-
_mm512_store_ps(&o_acc[i + 16],
|
|
925
|
-
_mm512_store_ps(&o_acc[i + 32],
|
|
926
|
-
_mm512_store_ps(&o_acc[i + 48],
|
|
924
|
+
_mm512_store_ps(&o_acc[i], zero_f32x16);
|
|
925
|
+
_mm512_store_ps(&o_acc[i + 16], zero_f32x16);
|
|
926
|
+
_mm512_store_ps(&o_acc[i + 32], zero_f32x16);
|
|
927
|
+
_mm512_store_ps(&o_acc[i + 48], zero_f32x16);
|
|
927
928
|
}
|
|
928
929
|
|
|
929
930
|
// Process KV blocks in chunks of 32
|
|
@@ -949,10 +950,10 @@ NK_PUBLIC void nk_attention_bf16_amx_bc32_sapphireamx(nk_bf16_t const *q, void c
|
|
|
949
950
|
for (nk_size_t row = 0; row < valid_q; row++) {
|
|
950
951
|
nk_bf16_t const *q_row = q_head + (qb + row) * head_dim + depth_start;
|
|
951
952
|
// Load 32 BF16 values (64 bytes) using two 256-bit loads
|
|
952
|
-
__m256i
|
|
953
|
-
__m256i
|
|
954
|
-
_mm256_store_si256((__m256i *)&q_tile[row][0],
|
|
955
|
-
_mm256_store_si256((__m256i *)&q_tile[row][16],
|
|
953
|
+
__m256i q0_bf16x16 = _mm256_loadu_si256((__m256i const *)q_row);
|
|
954
|
+
__m256i q1_bf16x16 = _mm256_loadu_si256((__m256i const *)(q_row + 16));
|
|
955
|
+
_mm256_store_si256((__m256i *)&q_tile[row][0], q0_bf16x16);
|
|
956
|
+
_mm256_store_si256((__m256i *)&q_tile[row][16], q1_bf16x16);
|
|
956
957
|
}
|
|
957
958
|
}
|
|
958
959
|
else {
|
|
@@ -990,23 +991,23 @@ NK_PUBLIC void nk_attention_bf16_amx_bc32_sapphireamx(nk_bf16_t const *q, void c
|
|
|
990
991
|
// Use SIMD for fast extraction
|
|
991
992
|
_tile_stored(0, s_tile, 64);
|
|
992
993
|
|
|
993
|
-
__m512
|
|
994
|
+
__m512 neg_inf_f32x16 = _mm512_set1_ps(NK_F32_MIN);
|
|
994
995
|
|
|
995
996
|
if (valid_q == 16 && valid_kv >= 16) {
|
|
996
997
|
// Fast path: full first half, just copy
|
|
997
998
|
for (nk_size_t qi = 0; qi < 16; qi++) {
|
|
998
|
-
__m512
|
|
999
|
-
_mm512_store_ps(&scores[qi * 32],
|
|
999
|
+
__m512 s0_f32x16 = _mm512_load_ps(&s_tile[qi][0]);
|
|
1000
|
+
_mm512_store_ps(&scores[qi * 32], s0_f32x16);
|
|
1000
1001
|
}
|
|
1001
1002
|
}
|
|
1002
1003
|
else {
|
|
1003
1004
|
// Partial - need masking
|
|
1004
1005
|
__mmask16 kv_mask = (1u << valid_kv) - 1;
|
|
1005
1006
|
for (nk_size_t qi = 0; qi < 16; qi++) {
|
|
1006
|
-
__m512
|
|
1007
|
-
if (qi < valid_q) {
|
|
1008
|
-
else {
|
|
1009
|
-
_mm512_store_ps(&scores[qi * 32],
|
|
1007
|
+
__m512 s0_f32x16 = _mm512_load_ps(&s_tile[qi][0]);
|
|
1008
|
+
if (qi < valid_q) { s0_f32x16 = _mm512_mask_blend_ps(kv_mask, neg_inf_f32x16, s0_f32x16); }
|
|
1009
|
+
else { s0_f32x16 = neg_inf_f32x16; }
|
|
1010
|
+
_mm512_store_ps(&scores[qi * 32], s0_f32x16);
|
|
1010
1011
|
}
|
|
1011
1012
|
}
|
|
1012
1013
|
|
|
@@ -1018,36 +1019,36 @@ NK_PUBLIC void nk_attention_bf16_amx_bc32_sapphireamx(nk_bf16_t const *q, void c
|
|
|
1018
1019
|
if (valid_q == 16 && valid_kv2 >= 16) {
|
|
1019
1020
|
// Fast path
|
|
1020
1021
|
for (nk_size_t qi = 0; qi < 16; qi++) {
|
|
1021
|
-
__m512
|
|
1022
|
-
_mm512_store_ps(&scores[qi * 32 + 16],
|
|
1022
|
+
__m512 s1_f32x16 = _mm512_load_ps(&s_tile[qi][0]);
|
|
1023
|
+
_mm512_store_ps(&scores[qi * 32 + 16], s1_f32x16);
|
|
1023
1024
|
}
|
|
1024
1025
|
}
|
|
1025
1026
|
else {
|
|
1026
1027
|
__mmask16 kv_mask2 = (valid_kv2 >= 16) ? 0xFFFF : ((1u << valid_kv2) - 1);
|
|
1027
1028
|
for (nk_size_t qi = 0; qi < 16; qi++) {
|
|
1028
|
-
__m512
|
|
1029
|
-
if (qi < valid_q) {
|
|
1030
|
-
else {
|
|
1031
|
-
_mm512_store_ps(&scores[qi * 32 + 16],
|
|
1029
|
+
__m512 s1_f32x16 = _mm512_load_ps(&s_tile[qi][0]);
|
|
1030
|
+
if (qi < valid_q) { s1_f32x16 = _mm512_mask_blend_ps(kv_mask2, neg_inf_f32x16, s1_f32x16); }
|
|
1031
|
+
else { s1_f32x16 = neg_inf_f32x16; }
|
|
1032
|
+
_mm512_store_ps(&scores[qi * 32 + 16], s1_f32x16);
|
|
1032
1033
|
}
|
|
1033
1034
|
}
|
|
1034
1035
|
}
|
|
1035
1036
|
else {
|
|
1036
1037
|
// Mask out second half entirely
|
|
1037
|
-
for (nk_size_t qi = 0; qi < 16; qi++) { _mm512_store_ps(&scores[qi * 32 + 16],
|
|
1038
|
+
for (nk_size_t qi = 0; qi < 16; qi++) { _mm512_store_ps(&scores[qi * 32 + 16], neg_inf_f32x16); }
|
|
1038
1039
|
}
|
|
1039
1040
|
|
|
1040
1041
|
// Phase 2: online softmax (fast degree-4 exp)
|
|
1041
|
-
__m512
|
|
1042
|
+
__m512 old_max_f32x16 = softmax_state.row_max_f32x16;
|
|
1042
1043
|
nk_attention_softmax_update_bc32_fast_(&softmax_state, scores, scale, weights);
|
|
1043
|
-
nk_attention_rescale_output_(o_acc, head_dim_padded,
|
|
1044
|
+
nk_attention_rescale_output_(o_acc, head_dim_padded, old_max_f32x16, softmax_state.row_max_f32x16);
|
|
1044
1045
|
|
|
1045
1046
|
// Phase 3: O += P × V using AMX
|
|
1046
1047
|
// Convert P[16, 32] from F32 to BF16 and pack as A-tile
|
|
1047
1048
|
for (nk_size_t qi = 0; qi < 16; qi++) {
|
|
1048
1049
|
for (nk_size_t ki = 0; ki < 32; ki += 16) {
|
|
1049
|
-
__m512
|
|
1050
|
-
__m256bh p_bf16 = _mm512_cvtneps_pbh(
|
|
1050
|
+
__m512 p_f32_f32x16 = _mm512_loadu_ps(&weights[qi * 32 + ki]);
|
|
1051
|
+
__m256bh p_bf16 = _mm512_cvtneps_pbh(p_f32_f32x16);
|
|
1051
1052
|
// Store BF16 vector - cast through union or memory
|
|
1052
1053
|
*(__m256bh *)&p_tile[qi][ki] = p_bf16;
|
|
1053
1054
|
}
|
|
@@ -1079,29 +1080,29 @@ NK_PUBLIC void nk_attention_bf16_amx_bc32_sapphireamx(nk_bf16_t const *q, void c
|
|
|
1079
1080
|
_tile_stored(5, o_tile, 64);
|
|
1080
1081
|
|
|
1081
1082
|
// Add to output accumulator - unrolled for all 16 rows
|
|
1082
|
-
// Even if valid_q < 16, we accumulate all (padded rows have
|
|
1083
|
+
// Even if valid_q < 16, we accumulate all (padded rows have zero_f32x16 weights)
|
|
1083
1084
|
for (nk_size_t qi = 0; qi < 16; qi += 4) {
|
|
1084
|
-
__m512
|
|
1085
|
-
__m512
|
|
1086
|
-
__m512
|
|
1087
|
-
__m512
|
|
1088
|
-
|
|
1089
|
-
|
|
1090
|
-
|
|
1091
|
-
|
|
1092
|
-
|
|
1093
|
-
|
|
1094
|
-
_mm512_store_ps(&o_acc[(qi + 0) * head_dim_padded + head_start],
|
|
1095
|
-
_mm512_store_ps(&o_acc[(qi + 1) * head_dim_padded + head_start],
|
|
1096
|
-
_mm512_store_ps(&o_acc[(qi + 2) * head_dim_padded + head_start],
|
|
1097
|
-
_mm512_store_ps(&o_acc[(qi + 3) * head_dim_padded + head_start],
|
|
1085
|
+
__m512 acc0_f32x16 = _mm512_load_ps(&o_acc[(qi + 0) * head_dim_padded + head_start]);
|
|
1086
|
+
__m512 acc1_f32x16 = _mm512_load_ps(&o_acc[(qi + 1) * head_dim_padded + head_start]);
|
|
1087
|
+
__m512 acc2_f32x16 = _mm512_load_ps(&o_acc[(qi + 2) * head_dim_padded + head_start]);
|
|
1088
|
+
__m512 acc3_f32x16 = _mm512_load_ps(&o_acc[(qi + 3) * head_dim_padded + head_start]);
|
|
1089
|
+
|
|
1090
|
+
acc0_f32x16 = _mm512_add_ps(acc0_f32x16, _mm512_load_ps(&o_tile[qi + 0][0]));
|
|
1091
|
+
acc1_f32x16 = _mm512_add_ps(acc1_f32x16, _mm512_load_ps(&o_tile[qi + 1][0]));
|
|
1092
|
+
acc2_f32x16 = _mm512_add_ps(acc2_f32x16, _mm512_load_ps(&o_tile[qi + 2][0]));
|
|
1093
|
+
acc3_f32x16 = _mm512_add_ps(acc3_f32x16, _mm512_load_ps(&o_tile[qi + 3][0]));
|
|
1094
|
+
|
|
1095
|
+
_mm512_store_ps(&o_acc[(qi + 0) * head_dim_padded + head_start], acc0_f32x16);
|
|
1096
|
+
_mm512_store_ps(&o_acc[(qi + 1) * head_dim_padded + head_start], acc1_f32x16);
|
|
1097
|
+
_mm512_store_ps(&o_acc[(qi + 2) * head_dim_padded + head_start], acc2_f32x16);
|
|
1098
|
+
_mm512_store_ps(&o_acc[(qi + 3) * head_dim_padded + head_start], acc3_f32x16);
|
|
1098
1099
|
}
|
|
1099
1100
|
}
|
|
1100
1101
|
}
|
|
1101
1102
|
|
|
1102
1103
|
// Finalize: normalize O by row sums
|
|
1103
|
-
|
|
1104
|
-
_mm512_store_ps(row_sums, softmax_state.
|
|
1104
|
+
nk_f32_t row_sums[16];
|
|
1105
|
+
_mm512_store_ps(row_sums, softmax_state.row_sum_f32x16);
|
|
1105
1106
|
for (nk_size_t qi = 0; qi < valid_q; qi++) {
|
|
1106
1107
|
nk_f32_t inv_sum = 1.0f / row_sums[qi];
|
|
1107
1108
|
for (nk_size_t d = 0; d < head_dim; d++) {
|
|
@@ -1149,7 +1150,7 @@ NK_PUBLIC void nk_attention_bf16_amx_optimized_sapphireamx(nk_bf16_t const *q, v
|
|
|
1149
1150
|
NK_ALIGN64 nk_f32_t o_tile[16][16]; // Output tile buffer
|
|
1150
1151
|
NK_ALIGN64 nk_f32_t o_acc[16][256]; // Output accumulator (max d=256)
|
|
1151
1152
|
|
|
1152
|
-
__m512
|
|
1153
|
+
__m512 neg_inf_f32x16 = _mm512_set1_ps(NK_F32_MIN);
|
|
1153
1154
|
|
|
1154
1155
|
for (nk_size_t h = 0; h < num_heads; h++) {
|
|
1155
1156
|
nk_size_t kv_h = h / gqa_ratio;
|
|
@@ -1169,10 +1170,10 @@ NK_PUBLIC void nk_attention_bf16_amx_optimized_sapphireamx(nk_bf16_t const *q, v
|
|
|
1169
1170
|
// Full tile - fast SIMD copy
|
|
1170
1171
|
for (nk_size_t row = 0; row < valid_q; row++) {
|
|
1171
1172
|
nk_bf16_t const *q_row = q_head + (qb + row) * head_dim + depth_start;
|
|
1172
|
-
__m256i
|
|
1173
|
-
__m256i
|
|
1174
|
-
_mm256_store_si256((__m256i *)&q_tiles[dt][row][0],
|
|
1175
|
-
_mm256_store_si256((__m256i *)&q_tiles[dt][row][16],
|
|
1173
|
+
__m256i q0_bf16x16 = _mm256_loadu_si256((__m256i const *)q_row);
|
|
1174
|
+
__m256i q1_bf16x16 = _mm256_loadu_si256((__m256i const *)(q_row + 16));
|
|
1175
|
+
_mm256_store_si256((__m256i *)&q_tiles[dt][row][0], q0_bf16x16);
|
|
1176
|
+
_mm256_store_si256((__m256i *)&q_tiles[dt][row][16], q1_bf16x16);
|
|
1176
1177
|
}
|
|
1177
1178
|
// Zero remaining rows
|
|
1178
1179
|
for (nk_size_t row = valid_q; row < 16; row++) {
|
|
@@ -1198,12 +1199,12 @@ NK_PUBLIC void nk_attention_bf16_amx_optimized_sapphireamx(nk_bf16_t const *q, v
|
|
|
1198
1199
|
nk_attention_softmax_row_state_t softmax_state;
|
|
1199
1200
|
nk_attention_softmax_init_(&softmax_state);
|
|
1200
1201
|
|
|
1201
|
-
__m512
|
|
1202
|
+
__m512 zero_f32x16 = _mm512_setzero_ps();
|
|
1202
1203
|
for (nk_size_t i = 0; i < 16 * head_dim_padded; i += 64) {
|
|
1203
|
-
_mm512_store_ps(&o_acc[0][i],
|
|
1204
|
-
_mm512_store_ps(&o_acc[0][i + 16],
|
|
1205
|
-
_mm512_store_ps(&o_acc[0][i + 32],
|
|
1206
|
-
_mm512_store_ps(&o_acc[0][i + 48],
|
|
1204
|
+
_mm512_store_ps(&o_acc[0][i], zero_f32x16);
|
|
1205
|
+
_mm512_store_ps(&o_acc[0][i + 16], zero_f32x16);
|
|
1206
|
+
_mm512_store_ps(&o_acc[0][i + 32], zero_f32x16);
|
|
1207
|
+
_mm512_store_ps(&o_acc[0][i + 48], zero_f32x16);
|
|
1207
1208
|
}
|
|
1208
1209
|
|
|
1209
1210
|
// Process KV blocks
|
|
@@ -1239,7 +1240,7 @@ NK_PUBLIC void nk_attention_bf16_amx_optimized_sapphireamx(nk_bf16_t const *q, v
|
|
|
1239
1240
|
if (kvb + 16 < kv_len) { _tile_stored(3, &scores[0][16], 128); }
|
|
1240
1241
|
else {
|
|
1241
1242
|
// Mask out second half
|
|
1242
|
-
for (nk_size_t qi = 0; qi < 16; qi++) { _mm512_store_ps(&scores[qi][16],
|
|
1243
|
+
for (nk_size_t qi = 0; qi < 16; qi++) { _mm512_store_ps(&scores[qi][16], neg_inf_f32x16); }
|
|
1243
1244
|
}
|
|
1244
1245
|
|
|
1245
1246
|
// Apply masking for invalid positions (only on boundaries)
|
|
@@ -1250,30 +1251,31 @@ NK_PUBLIC void nk_attention_bf16_amx_optimized_sapphireamx(nk_bf16_t const *q, v
|
|
|
1250
1251
|
|
|
1251
1252
|
for (nk_size_t qi = 0; qi < 16; qi++) {
|
|
1252
1253
|
if (qi >= valid_q) {
|
|
1253
|
-
_mm512_store_ps(&scores[qi][0],
|
|
1254
|
-
_mm512_store_ps(&scores[qi][16],
|
|
1254
|
+
_mm512_store_ps(&scores[qi][0], neg_inf_f32x16);
|
|
1255
|
+
_mm512_store_ps(&scores[qi][16], neg_inf_f32x16);
|
|
1255
1256
|
}
|
|
1256
1257
|
else {
|
|
1257
|
-
__m512
|
|
1258
|
-
__m512
|
|
1259
|
-
_mm512_store_ps(&scores[qi][0], _mm512_mask_blend_ps(kv_mask0,
|
|
1260
|
-
_mm512_store_ps(&scores[qi][16], _mm512_mask_blend_ps(kv_mask1,
|
|
1258
|
+
__m512 s0_f32x16 = _mm512_load_ps(&scores[qi][0]);
|
|
1259
|
+
__m512 s1_f32x16 = _mm512_load_ps(&scores[qi][16]);
|
|
1260
|
+
_mm512_store_ps(&scores[qi][0], _mm512_mask_blend_ps(kv_mask0, neg_inf_f32x16, s0_f32x16));
|
|
1261
|
+
_mm512_store_ps(&scores[qi][16], _mm512_mask_blend_ps(kv_mask1, neg_inf_f32x16, s1_f32x16));
|
|
1261
1262
|
}
|
|
1262
1263
|
}
|
|
1263
1264
|
}
|
|
1264
1265
|
|
|
1265
1266
|
// Phase 2: online softmax (fast degree-4 exp)
|
|
1266
|
-
__m512
|
|
1267
|
+
__m512 old_max_f32x16 = softmax_state.row_max_f32x16;
|
|
1267
1268
|
nk_attention_softmax_update_bc32_fast_(&softmax_state, &scores[0][0], scale, &weights[0][0]);
|
|
1268
|
-
nk_attention_rescale_output_(&o_acc[0][0], head_dim_padded,
|
|
1269
|
+
nk_attention_rescale_output_(&o_acc[0][0], head_dim_padded, old_max_f32x16,
|
|
1270
|
+
softmax_state.row_max_f32x16);
|
|
1269
1271
|
|
|
1270
1272
|
// Phase 3: O += P × V with hoisted P tile load
|
|
1271
1273
|
// Convert F32 weights to BF16 P tile (once per KV block)
|
|
1272
1274
|
for (nk_size_t qi = 0; qi < 16; qi++) {
|
|
1273
|
-
__m512
|
|
1274
|
-
__m512
|
|
1275
|
-
__m256bh pb0 = _mm512_cvtneps_pbh(
|
|
1276
|
-
__m256bh pb1 = _mm512_cvtneps_pbh(
|
|
1275
|
+
__m512 p0_f32x16 = _mm512_load_ps(&weights[qi][0]);
|
|
1276
|
+
__m512 p1_f32x16 = _mm512_load_ps(&weights[qi][16]);
|
|
1277
|
+
__m256bh pb0 = _mm512_cvtneps_pbh(p0_f32x16);
|
|
1278
|
+
__m256bh pb1 = _mm512_cvtneps_pbh(p1_f32x16);
|
|
1277
1279
|
*(__m256bh *)&p_tile[qi][0] = pb0;
|
|
1278
1280
|
*(__m256bh *)&p_tile[qi][16] = pb1;
|
|
1279
1281
|
}
|
|
@@ -1299,33 +1301,33 @@ NK_PUBLIC void nk_attention_bf16_amx_optimized_sapphireamx(nk_bf16_t const *q, v
|
|
|
1299
1301
|
|
|
1300
1302
|
// Accumulate into output (unrolled)
|
|
1301
1303
|
for (nk_size_t qi = 0; qi < 16; qi += 4) {
|
|
1302
|
-
__m512
|
|
1303
|
-
__m512
|
|
1304
|
-
__m512
|
|
1305
|
-
__m512
|
|
1306
|
-
|
|
1307
|
-
|
|
1308
|
-
|
|
1309
|
-
|
|
1310
|
-
|
|
1311
|
-
|
|
1312
|
-
_mm512_store_ps(&o_acc[qi + 0][head_start],
|
|
1313
|
-
_mm512_store_ps(&o_acc[qi + 1][head_start],
|
|
1314
|
-
_mm512_store_ps(&o_acc[qi + 2][head_start],
|
|
1315
|
-
_mm512_store_ps(&o_acc[qi + 3][head_start],
|
|
1304
|
+
__m512 acc0_f32x16 = _mm512_load_ps(&o_acc[qi + 0][head_start]);
|
|
1305
|
+
__m512 acc1_f32x16 = _mm512_load_ps(&o_acc[qi + 1][head_start]);
|
|
1306
|
+
__m512 acc2_f32x16 = _mm512_load_ps(&o_acc[qi + 2][head_start]);
|
|
1307
|
+
__m512 acc3_f32x16 = _mm512_load_ps(&o_acc[qi + 3][head_start]);
|
|
1308
|
+
|
|
1309
|
+
acc0_f32x16 = _mm512_add_ps(acc0_f32x16, _mm512_load_ps(&o_tile[qi + 0][0]));
|
|
1310
|
+
acc1_f32x16 = _mm512_add_ps(acc1_f32x16, _mm512_load_ps(&o_tile[qi + 1][0]));
|
|
1311
|
+
acc2_f32x16 = _mm512_add_ps(acc2_f32x16, _mm512_load_ps(&o_tile[qi + 2][0]));
|
|
1312
|
+
acc3_f32x16 = _mm512_add_ps(acc3_f32x16, _mm512_load_ps(&o_tile[qi + 3][0]));
|
|
1313
|
+
|
|
1314
|
+
_mm512_store_ps(&o_acc[qi + 0][head_start], acc0_f32x16);
|
|
1315
|
+
_mm512_store_ps(&o_acc[qi + 1][head_start], acc1_f32x16);
|
|
1316
|
+
_mm512_store_ps(&o_acc[qi + 2][head_start], acc2_f32x16);
|
|
1317
|
+
_mm512_store_ps(&o_acc[qi + 3][head_start], acc3_f32x16);
|
|
1316
1318
|
}
|
|
1317
1319
|
}
|
|
1318
1320
|
}
|
|
1319
1321
|
|
|
1320
1322
|
// Finalize: normalize O by row sums
|
|
1321
|
-
|
|
1322
|
-
_mm512_store_ps(row_sums, softmax_state.
|
|
1323
|
+
nk_f32_t row_sums[16];
|
|
1324
|
+
_mm512_store_ps(row_sums, softmax_state.row_sum_f32x16);
|
|
1323
1325
|
for (nk_size_t qi = 0; qi < valid_q; qi++) {
|
|
1324
|
-
__m512
|
|
1326
|
+
__m512 inv_sum_f32x16 = _mm512_set1_ps(1.0f / row_sums[qi]);
|
|
1325
1327
|
for (nk_size_t d = 0; d < head_dim; d += 16) {
|
|
1326
|
-
__m512
|
|
1327
|
-
|
|
1328
|
-
_mm512_storeu_ps(&o_head[(qb + qi) * head_dim + d],
|
|
1328
|
+
__m512 o_f32x16 = _mm512_load_ps(&o_acc[qi][d]);
|
|
1329
|
+
o_f32x16 = _mm512_mul_ps(o_f32x16, inv_sum_f32x16);
|
|
1330
|
+
_mm512_storeu_ps(&o_head[(qb + qi) * head_dim + d], o_f32x16);
|
|
1329
1331
|
}
|
|
1330
1332
|
}
|
|
1331
1333
|
}
|