numkong 7.0.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/LICENSE +201 -0
- package/README.md +495 -0
- package/binding.gyp +540 -0
- package/c/dispatch.h +512 -0
- package/c/dispatch_bf16.c +389 -0
- package/c/dispatch_bf16c.c +52 -0
- package/c/dispatch_e2m3.c +263 -0
- package/c/dispatch_e3m2.c +243 -0
- package/c/dispatch_e4m3.c +276 -0
- package/c/dispatch_e5m2.c +272 -0
- package/c/dispatch_f16.c +376 -0
- package/c/dispatch_f16c.c +58 -0
- package/c/dispatch_f32.c +378 -0
- package/c/dispatch_f32c.c +99 -0
- package/c/dispatch_f64.c +296 -0
- package/c/dispatch_f64c.c +98 -0
- package/c/dispatch_i16.c +96 -0
- package/c/dispatch_i32.c +89 -0
- package/c/dispatch_i4.c +150 -0
- package/c/dispatch_i64.c +86 -0
- package/c/dispatch_i8.c +289 -0
- package/c/dispatch_other.c +330 -0
- package/c/dispatch_u1.c +148 -0
- package/c/dispatch_u16.c +124 -0
- package/c/dispatch_u32.c +118 -0
- package/c/dispatch_u4.c +150 -0
- package/c/dispatch_u64.c +102 -0
- package/c/dispatch_u8.c +303 -0
- package/c/numkong.c +950 -0
- package/include/README.md +573 -0
- package/include/module.modulemap +129 -0
- package/include/numkong/attention/sapphireamx.h +1361 -0
- package/include/numkong/attention/sme.h +2066 -0
- package/include/numkong/attention.h +49 -0
- package/include/numkong/capabilities.h +748 -0
- package/include/numkong/cast/README.md +262 -0
- package/include/numkong/cast/haswell.h +975 -0
- package/include/numkong/cast/icelake.h +470 -0
- package/include/numkong/cast/neon.h +1192 -0
- package/include/numkong/cast/rvv.h +1021 -0
- package/include/numkong/cast/sapphire.h +262 -0
- package/include/numkong/cast/serial.h +2262 -0
- package/include/numkong/cast/skylake.h +856 -0
- package/include/numkong/cast/v128relaxed.h +180 -0
- package/include/numkong/cast.h +230 -0
- package/include/numkong/curved/README.md +223 -0
- package/include/numkong/curved/genoa.h +182 -0
- package/include/numkong/curved/haswell.h +276 -0
- package/include/numkong/curved/neon.h +205 -0
- package/include/numkong/curved/neonbfdot.h +212 -0
- package/include/numkong/curved/neonhalf.h +212 -0
- package/include/numkong/curved/rvv.h +305 -0
- package/include/numkong/curved/serial.h +207 -0
- package/include/numkong/curved/skylake.h +457 -0
- package/include/numkong/curved/smef64.h +506 -0
- package/include/numkong/curved.h +517 -0
- package/include/numkong/curved.hpp +144 -0
- package/include/numkong/dot/README.md +425 -0
- package/include/numkong/dot/alder.h +563 -0
- package/include/numkong/dot/genoa.h +315 -0
- package/include/numkong/dot/haswell.h +1688 -0
- package/include/numkong/dot/icelake.h +883 -0
- package/include/numkong/dot/neon.h +818 -0
- package/include/numkong/dot/neonbfdot.h +244 -0
- package/include/numkong/dot/neonfhm.h +360 -0
- package/include/numkong/dot/neonhalf.h +198 -0
- package/include/numkong/dot/neonsdot.h +508 -0
- package/include/numkong/dot/rvv.h +714 -0
- package/include/numkong/dot/rvvbb.h +72 -0
- package/include/numkong/dot/rvvbf16.h +123 -0
- package/include/numkong/dot/rvvhalf.h +129 -0
- package/include/numkong/dot/sapphire.h +141 -0
- package/include/numkong/dot/serial.h +838 -0
- package/include/numkong/dot/sierra.h +405 -0
- package/include/numkong/dot/skylake.h +1084 -0
- package/include/numkong/dot/sve.h +379 -0
- package/include/numkong/dot/svebfdot.h +74 -0
- package/include/numkong/dot/svehalf.h +123 -0
- package/include/numkong/dot/v128relaxed.h +1258 -0
- package/include/numkong/dot.h +1070 -0
- package/include/numkong/dot.hpp +94 -0
- package/include/numkong/dots/README.md +496 -0
- package/include/numkong/dots/alder.h +114 -0
- package/include/numkong/dots/genoa.h +94 -0
- package/include/numkong/dots/haswell.h +295 -0
- package/include/numkong/dots/icelake.h +171 -0
- package/include/numkong/dots/neon.h +120 -0
- package/include/numkong/dots/neonbfdot.h +58 -0
- package/include/numkong/dots/neonfhm.h +94 -0
- package/include/numkong/dots/neonhalf.h +57 -0
- package/include/numkong/dots/neonsdot.h +108 -0
- package/include/numkong/dots/rvv.h +2486 -0
- package/include/numkong/dots/sapphireamx.h +3973 -0
- package/include/numkong/dots/serial.h +2844 -0
- package/include/numkong/dots/sierra.h +97 -0
- package/include/numkong/dots/skylake.h +196 -0
- package/include/numkong/dots/sme.h +5372 -0
- package/include/numkong/dots/smebi32.h +461 -0
- package/include/numkong/dots/smef64.h +1318 -0
- package/include/numkong/dots/smehalf.h +47 -0
- package/include/numkong/dots/v128relaxed.h +294 -0
- package/include/numkong/dots.h +2804 -0
- package/include/numkong/dots.hpp +639 -0
- package/include/numkong/each/README.md +469 -0
- package/include/numkong/each/haswell.h +1658 -0
- package/include/numkong/each/icelake.h +272 -0
- package/include/numkong/each/neon.h +1104 -0
- package/include/numkong/each/neonbfdot.h +212 -0
- package/include/numkong/each/neonhalf.h +410 -0
- package/include/numkong/each/rvv.h +1121 -0
- package/include/numkong/each/sapphire.h +477 -0
- package/include/numkong/each/serial.h +260 -0
- package/include/numkong/each/skylake.h +1562 -0
- package/include/numkong/each.h +2146 -0
- package/include/numkong/each.hpp +434 -0
- package/include/numkong/geospatial/README.md +147 -0
- package/include/numkong/geospatial/haswell.h +593 -0
- package/include/numkong/geospatial/neon.h +571 -0
- package/include/numkong/geospatial/rvv.h +701 -0
- package/include/numkong/geospatial/serial.h +309 -0
- package/include/numkong/geospatial/skylake.h +577 -0
- package/include/numkong/geospatial/v128relaxed.h +613 -0
- package/include/numkong/geospatial.h +453 -0
- package/include/numkong/geospatial.hpp +235 -0
- package/include/numkong/matrix.hpp +336 -0
- package/include/numkong/maxsim/README.md +187 -0
- package/include/numkong/maxsim/alder.h +511 -0
- package/include/numkong/maxsim/genoa.h +115 -0
- package/include/numkong/maxsim/haswell.h +553 -0
- package/include/numkong/maxsim/icelake.h +480 -0
- package/include/numkong/maxsim/neonsdot.h +394 -0
- package/include/numkong/maxsim/sapphireamx.h +877 -0
- package/include/numkong/maxsim/serial.h +490 -0
- package/include/numkong/maxsim/sme.h +929 -0
- package/include/numkong/maxsim/v128relaxed.h +280 -0
- package/include/numkong/maxsim.h +571 -0
- package/include/numkong/maxsim.hpp +133 -0
- package/include/numkong/mesh/README.md +227 -0
- package/include/numkong/mesh/haswell.h +2235 -0
- package/include/numkong/mesh/neon.h +1329 -0
- package/include/numkong/mesh/neonbfdot.h +842 -0
- package/include/numkong/mesh/neonhalf.h +616 -0
- package/include/numkong/mesh/rvv.h +916 -0
- package/include/numkong/mesh/serial.h +742 -0
- package/include/numkong/mesh/skylake.h +1135 -0
- package/include/numkong/mesh/v128relaxed.h +1052 -0
- package/include/numkong/mesh.h +652 -0
- package/include/numkong/mesh.hpp +762 -0
- package/include/numkong/numkong.h +78 -0
- package/include/numkong/numkong.hpp +57 -0
- package/include/numkong/probability/README.md +173 -0
- package/include/numkong/probability/haswell.h +267 -0
- package/include/numkong/probability/neon.h +225 -0
- package/include/numkong/probability/rvv.h +409 -0
- package/include/numkong/probability/serial.h +169 -0
- package/include/numkong/probability/skylake.h +324 -0
- package/include/numkong/probability.h +383 -0
- package/include/numkong/probability.hpp +120 -0
- package/include/numkong/random.h +50 -0
- package/include/numkong/random.hpp +285 -0
- package/include/numkong/reduce/README.md +547 -0
- package/include/numkong/reduce/alder.h +632 -0
- package/include/numkong/reduce/genoa.h +201 -0
- package/include/numkong/reduce/haswell.h +3783 -0
- package/include/numkong/reduce/icelake.h +549 -0
- package/include/numkong/reduce/neon.h +3841 -0
- package/include/numkong/reduce/neonbfdot.h +353 -0
- package/include/numkong/reduce/neonfhm.h +665 -0
- package/include/numkong/reduce/neonhalf.h +157 -0
- package/include/numkong/reduce/neonsdot.h +357 -0
- package/include/numkong/reduce/rvv.h +3407 -0
- package/include/numkong/reduce/serial.h +757 -0
- package/include/numkong/reduce/sierra.h +338 -0
- package/include/numkong/reduce/skylake.h +3792 -0
- package/include/numkong/reduce/v128relaxed.h +2302 -0
- package/include/numkong/reduce.h +1597 -0
- package/include/numkong/reduce.hpp +633 -0
- package/include/numkong/scalar/README.md +89 -0
- package/include/numkong/scalar/haswell.h +113 -0
- package/include/numkong/scalar/neon.h +122 -0
- package/include/numkong/scalar/neonhalf.h +70 -0
- package/include/numkong/scalar/rvv.h +211 -0
- package/include/numkong/scalar/sapphire.h +63 -0
- package/include/numkong/scalar/serial.h +332 -0
- package/include/numkong/scalar/v128relaxed.h +56 -0
- package/include/numkong/scalar.h +683 -0
- package/include/numkong/set/README.md +179 -0
- package/include/numkong/set/haswell.h +334 -0
- package/include/numkong/set/icelake.h +485 -0
- package/include/numkong/set/neon.h +364 -0
- package/include/numkong/set/rvv.h +226 -0
- package/include/numkong/set/rvvbb.h +117 -0
- package/include/numkong/set/serial.h +174 -0
- package/include/numkong/set/sve.h +185 -0
- package/include/numkong/set/v128relaxed.h +240 -0
- package/include/numkong/set.h +457 -0
- package/include/numkong/set.hpp +114 -0
- package/include/numkong/sets/README.md +149 -0
- package/include/numkong/sets/haswell.h +63 -0
- package/include/numkong/sets/icelake.h +66 -0
- package/include/numkong/sets/neon.h +61 -0
- package/include/numkong/sets/serial.h +43 -0
- package/include/numkong/sets/smebi32.h +1099 -0
- package/include/numkong/sets/v128relaxed.h +58 -0
- package/include/numkong/sets.h +339 -0
- package/include/numkong/sparse/README.md +156 -0
- package/include/numkong/sparse/icelake.h +463 -0
- package/include/numkong/sparse/neon.h +288 -0
- package/include/numkong/sparse/serial.h +117 -0
- package/include/numkong/sparse/sve2.h +507 -0
- package/include/numkong/sparse/turin.h +322 -0
- package/include/numkong/sparse.h +363 -0
- package/include/numkong/sparse.hpp +113 -0
- package/include/numkong/spatial/README.md +435 -0
- package/include/numkong/spatial/alder.h +607 -0
- package/include/numkong/spatial/genoa.h +290 -0
- package/include/numkong/spatial/haswell.h +960 -0
- package/include/numkong/spatial/icelake.h +586 -0
- package/include/numkong/spatial/neon.h +773 -0
- package/include/numkong/spatial/neonbfdot.h +165 -0
- package/include/numkong/spatial/neonhalf.h +118 -0
- package/include/numkong/spatial/neonsdot.h +261 -0
- package/include/numkong/spatial/rvv.h +984 -0
- package/include/numkong/spatial/rvvbf16.h +123 -0
- package/include/numkong/spatial/rvvhalf.h +117 -0
- package/include/numkong/spatial/sapphire.h +343 -0
- package/include/numkong/spatial/serial.h +346 -0
- package/include/numkong/spatial/sierra.h +323 -0
- package/include/numkong/spatial/skylake.h +606 -0
- package/include/numkong/spatial/sve.h +224 -0
- package/include/numkong/spatial/svebfdot.h +122 -0
- package/include/numkong/spatial/svehalf.h +109 -0
- package/include/numkong/spatial/v128relaxed.h +717 -0
- package/include/numkong/spatial.h +1425 -0
- package/include/numkong/spatial.hpp +183 -0
- package/include/numkong/spatials/README.md +580 -0
- package/include/numkong/spatials/alder.h +94 -0
- package/include/numkong/spatials/genoa.h +94 -0
- package/include/numkong/spatials/haswell.h +219 -0
- package/include/numkong/spatials/icelake.h +113 -0
- package/include/numkong/spatials/neon.h +109 -0
- package/include/numkong/spatials/neonbfdot.h +60 -0
- package/include/numkong/spatials/neonfhm.h +92 -0
- package/include/numkong/spatials/neonhalf.h +58 -0
- package/include/numkong/spatials/neonsdot.h +109 -0
- package/include/numkong/spatials/rvv.h +1960 -0
- package/include/numkong/spatials/sapphireamx.h +1149 -0
- package/include/numkong/spatials/serial.h +226 -0
- package/include/numkong/spatials/sierra.h +96 -0
- package/include/numkong/spatials/skylake.h +184 -0
- package/include/numkong/spatials/sme.h +1901 -0
- package/include/numkong/spatials/smef64.h +465 -0
- package/include/numkong/spatials/v128relaxed.h +240 -0
- package/include/numkong/spatials.h +3021 -0
- package/include/numkong/spatials.hpp +508 -0
- package/include/numkong/tensor.hpp +1592 -0
- package/include/numkong/trigonometry/README.md +184 -0
- package/include/numkong/trigonometry/haswell.h +652 -0
- package/include/numkong/trigonometry/neon.h +639 -0
- package/include/numkong/trigonometry/rvv.h +699 -0
- package/include/numkong/trigonometry/serial.h +703 -0
- package/include/numkong/trigonometry/skylake.h +721 -0
- package/include/numkong/trigonometry/v128relaxed.h +666 -0
- package/include/numkong/trigonometry.h +467 -0
- package/include/numkong/trigonometry.hpp +166 -0
- package/include/numkong/types.h +1384 -0
- package/include/numkong/types.hpp +5603 -0
- package/include/numkong/vector.hpp +698 -0
- package/javascript/README.md +246 -0
- package/javascript/dist/cjs/numkong-wasm.d.ts +166 -0
- package/javascript/dist/cjs/numkong-wasm.js +617 -0
- package/javascript/dist/cjs/numkong.d.ts +343 -0
- package/javascript/dist/cjs/numkong.js +523 -0
- package/javascript/dist/cjs/package.json +3 -0
- package/javascript/dist/cjs/types.d.ts +284 -0
- package/javascript/dist/cjs/types.js +653 -0
- package/javascript/dist/esm/numkong-wasm.d.ts +166 -0
- package/javascript/dist/esm/numkong-wasm.js +595 -0
- package/javascript/dist/esm/numkong.d.ts +343 -0
- package/javascript/dist/esm/numkong.js +452 -0
- package/javascript/dist/esm/package.json +3 -0
- package/javascript/dist/esm/types.d.ts +284 -0
- package/javascript/dist/esm/types.js +630 -0
- package/javascript/dist-package-cjs.json +3 -0
- package/javascript/dist-package-esm.json +3 -0
- package/javascript/node-gyp-build.d.ts +1 -0
- package/javascript/numkong-wasm.ts +756 -0
- package/javascript/numkong.c +689 -0
- package/javascript/numkong.ts +575 -0
- package/javascript/tsconfig-base.json +39 -0
- package/javascript/tsconfig-cjs.json +8 -0
- package/javascript/tsconfig-esm.json +8 -0
- package/javascript/types.ts +674 -0
- package/package.json +87 -0
|
@@ -0,0 +1,1361 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* @brief FlashAttention-style kernels for Intel Sapphire Rapids AMX.
|
|
3
|
+
* @file include/numkong/attention/sapphireamx.h
|
|
4
|
+
* @author Ash Vardanian
|
|
5
|
+
* @date January 5, 2026
|
|
6
|
+
*
|
|
7
|
+
* @sa include/numkong/attention.h
|
|
8
|
+
*
|
|
9
|
+
* This file implements FlashAttention-2 style scaled dot-product attention (SDPA) optimized
|
|
10
|
+
* for Intel AMX instructions on Sapphire Rapids CPUs. The kernel computes:
|
|
11
|
+
*
|
|
12
|
+
* O = softmax(Q × Kᵀ / √d) × V
|
|
13
|
+
*
|
|
14
|
+
* Key features:
|
|
15
|
+
* - Online softmax: Mathematically exact, processes KV blocks incrementally
|
|
16
|
+
* - Pre-packed KV cache: Amortizes packing cost for repeated inference
|
|
17
|
+
* - GQA/MQA support: Different num_heads and num_kv_heads for grouped-query attention
|
|
18
|
+
* - Causal masking: Optional masking for autoregressive generation
|
|
19
|
+
*
|
|
20
|
+
* Target models (2025):
|
|
21
|
+
* - Kimi K2: head_dim=112, 64 heads, MHA, 128K context
|
|
22
|
+
* - LLaMA 3.1 405B: head_dim=128, 128 heads, 16 KV heads (GQA 8:1), 128K context
|
|
23
|
+
* - Qwen 2.5 72B: head_dim=128, 64 heads, 8 KV heads (GQA 8:1), 32K context
|
|
24
|
+
*
|
|
25
|
+
* Performance comparison with H100 FlashAttention-2:
|
|
26
|
+
* - H100 SXM5: ~335 TFLOPS (35% of 989 TFLOPS peak), 80GB HBM3
|
|
27
|
+
* - 100-core SPR: ~40 TFLOPS with FlashAttention (13% of 300 TFLOPS peak)
|
|
28
|
+
* - CPU advantage: 512GB-2TB DDR5 vs 80GB HBM → supports 10-25⨯ longer contexts
|
|
29
|
+
*
|
|
30
|
+
* Expected performance per core:
|
|
31
|
+
* - Decode (query_len=1, kv_len=4K): 350-450 GOPS (softmax bound)
|
|
32
|
+
* - Prefill (query_len=64, kv_len=4K): 450-550 GOPS (better AMX utilization)
|
|
33
|
+
* - Long context (kv_len=64K+): 250-350 GOPS (memory bandwidth bound)
|
|
34
|
+
*
|
|
35
|
+
* Block sizes:
|
|
36
|
+
* - Bᵣ = 16 (query block rows, matches AMX tile height)
|
|
37
|
+
* - Bᶜ = 16 (KV block columns, fits 16×16 scores in 16 ZMM registers)
|
|
38
|
+
*
|
|
39
|
+
* Algorithm (FlashAttention-2 style):
|
|
40
|
+
* For each query block:
|
|
41
|
+
* Initialize O = 0, rowsum = 0, rowmax = -∞
|
|
42
|
+
* For each KV block:
|
|
43
|
+
* S = Q × Kᵀ using AMX TDPBF16PS
|
|
44
|
+
* Apply online softmax: rescale old values, accumulate new
|
|
45
|
+
* O = rescale(O) + P × V using AMX
|
|
46
|
+
* Finalize: normalize O by row sums
|
|
47
|
+
*
|
|
48
|
+
* @section sapphireamx_attention_instructions Relevant Instructions
|
|
49
|
+
*
|
|
50
|
+
* Intrinsic Instruction Sapphire
|
|
51
|
+
* _tile_dpbf16ps TDPBF16PS (TMM, TMM, TMM) ~16cy (16x16x32 BF16)
|
|
52
|
+
* _tile_dpbssd TDPBSSD (TMM, TMM, TMM) ~16cy (16x16x64 INT8)
|
|
53
|
+
* _tile_loadd TILELOADD (TMM, MEM) ~10cy @ p23
|
|
54
|
+
* _tile_stored TILESTORED (MEM, TMM) ~10cy @ p4
|
|
55
|
+
* _tile_zero TILEZERO (TMM) ~1cy
|
|
56
|
+
* _mm512_fmadd_ps VFMADD (ZMM, ZMM, ZMM) 4cy @ p05
|
|
57
|
+
* _mm512_mul_ps VMULPS (ZMM, ZMM, ZMM) 4cy @ p05
|
|
58
|
+
* _mm512_max_ps VMAXPS (ZMM, ZMM, ZMM) 4cy @ p05
|
|
59
|
+
* _mm512_reduce_max_ps (sequence: shuffles + VMAXPS) ~8cy
|
|
60
|
+
* _mm512_reduce_add_ps (sequence: shuffles + VADDPS) ~8cy
|
|
61
|
+
*/
|
|
62
|
+
#ifndef NK_ATTENTION_SAPPHIREAMX_H
|
|
63
|
+
#define NK_ATTENTION_SAPPHIREAMX_H
|
|
64
|
+
|
|
65
|
+
#if NK_TARGET_X86_
|
|
66
|
+
#if NK_TARGET_SAPPHIREAMX
|
|
67
|
+
|
|
68
|
+
#include "numkong/types.h"
|
|
69
|
+
#include "numkong/dots/sapphireamx.h"
|
|
70
|
+
|
|
71
|
+
#if defined(__cplusplus)
|
|
72
|
+
extern "C" {
|
|
73
|
+
#endif
|
|
74
|
+
|
|
75
|
+
#if defined(__clang__)
|
|
76
|
+
#pragma clang attribute push( \
|
|
77
|
+
__attribute__((target("avx2,avx512f,avx512vl,avx512bw,avx512dq,avx512fp16,avx512bf16,f16c,fma,bmi,bmi2"))), \
|
|
78
|
+
apply_to = function)
|
|
79
|
+
#elif defined(__GNUC__)
|
|
80
|
+
#pragma GCC push_options
|
|
81
|
+
#pragma GCC target("avx2", "avx512f", "avx512vl", "avx512bw", "avx512dq", "avx512fp16", "avx512bf16", "f16c", "fma", \
|
|
82
|
+
"bmi", "bmi2")
|
|
83
|
+
#endif
|
|
84
|
+
|
|
85
|
+
/**
|
|
86
|
+
* @brief Packed KV cache header for attention (64-byte aligned).
|
|
87
|
+
*
|
|
88
|
+
* Layout in memory:
|
|
89
|
+
* [header: 64 bytes][K tiles: variable][V tiles: variable]
|
|
90
|
+
*
|
|
91
|
+
* K and V are packed in AMX tile format for efficient loading.
|
|
92
|
+
*/
|
|
93
|
+
typedef struct {
|
|
94
|
+
nk_u32_t num_kv_heads; ///< Number of K/V heads (for GQA, may differ from Q heads)
|
|
95
|
+
nk_u32_t head_dim; ///< Original head dimension (64, 112, 128)
|
|
96
|
+
nk_u32_t head_dim_padded; ///< Padded to multiple of 32 for AMX tiles
|
|
97
|
+
nk_u32_t seq_len; ///< Current sequence length
|
|
98
|
+
nk_u32_t max_seq_len; ///< Maximum sequence length (for pre-allocation)
|
|
99
|
+
nk_u32_t k_offset; ///< Byte offset to K tiles from header start
|
|
100
|
+
nk_u32_t v_offset; ///< Byte offset to V tiles from header start
|
|
101
|
+
nk_u32_t reserved[9]; ///< Pad to 64 bytes
|
|
102
|
+
} nk_attention_kv_packed_header_t;
|
|
103
|
+
|
|
104
|
+
/**
|
|
105
|
+
* @brief Fast exp approximation for AVX-512.
|
|
106
|
+
*
|
|
107
|
+
* Uses Cody-Waite range reduction + Remez minimax polynomial.
|
|
108
|
+
* Accuracy: max error < 1 ULP for x ∈ [-87.3, 88.7] (float range).
|
|
109
|
+
* Performance: ~15-20 cycles for 16 floats.
|
|
110
|
+
*/
|
|
111
|
+
|
|
112
|
+
/**
|
|
113
|
+
* @brief Fast vectorized exp(x) approximation using AVX-512.
|
|
114
|
+
*
|
|
115
|
+
* Algorithm:
|
|
116
|
+
* 1. Range reduction: x = n × ln(2) + r, where |r| < ln(2)/2
|
|
117
|
+
* 2. Polynomial approximation: exp(r) ≈ 1 + r + r²/2 + ... (degree 6)
|
|
118
|
+
* 3. Reconstruction: exp(x) = 2ⁿ × exp(r)
|
|
119
|
+
*
|
|
120
|
+
* @param x Input vector (16 floats)
|
|
121
|
+
* @return exp(x) for each element
|
|
122
|
+
*/
|
|
123
|
+
NK_INTERNAL __m512 nk_exp_ps_avx512_(__m512 x) {
|
|
124
|
+
// Constants for Cody-Waite range reduction
|
|
125
|
+
const __m512 log2e = _mm512_set1_ps(1.4426950408889634f);
|
|
126
|
+
const __m512 ln2_hi = _mm512_set1_ps(0.693145751953125f);
|
|
127
|
+
const __m512 ln2_lo = _mm512_set1_ps(1.42860682030941723212e-6f);
|
|
128
|
+
|
|
129
|
+
// Clamp to avoid overflow/underflow
|
|
130
|
+
const __m512 max_x = _mm512_set1_ps(88.3762626647949f);
|
|
131
|
+
const __m512 min_x = _mm512_set1_ps(-87.3365447504021f);
|
|
132
|
+
x = _mm512_max_ps(_mm512_min_ps(x, max_x), min_x);
|
|
133
|
+
|
|
134
|
+
// n = round(x / ln(2))
|
|
135
|
+
__m512 n = _mm512_roundscale_ps(_mm512_mul_ps(x, log2e), _MM_FROUND_TO_NEAREST_INT);
|
|
136
|
+
|
|
137
|
+
// r = x - n × ln(2) using Cody-Waite for precision
|
|
138
|
+
__m512 r = _mm512_fnmadd_ps(n, ln2_hi, x);
|
|
139
|
+
r = _mm512_fnmadd_ps(n, ln2_lo, r);
|
|
140
|
+
|
|
141
|
+
// Polynomial approximation for exp(r): Remez minimax degree 6
|
|
142
|
+
// Coefficients optimized for [-ln(2)/2, ln(2)/2]
|
|
143
|
+
__m512 p = _mm512_set1_ps(1.9875691500e-4f);
|
|
144
|
+
p = _mm512_fmadd_ps(p, r, _mm512_set1_ps(1.3981999507e-3f));
|
|
145
|
+
p = _mm512_fmadd_ps(p, r, _mm512_set1_ps(8.3334519073e-3f));
|
|
146
|
+
p = _mm512_fmadd_ps(p, r, _mm512_set1_ps(4.1665858030e-2f));
|
|
147
|
+
p = _mm512_fmadd_ps(p, r, _mm512_set1_ps(1.6666665459e-1f));
|
|
148
|
+
p = _mm512_fmadd_ps(p, r, _mm512_set1_ps(5.0000001201e-1f));
|
|
149
|
+
p = _mm512_fmadd_ps(p, r, _mm512_set1_ps(1.0f));
|
|
150
|
+
p = _mm512_fmadd_ps(p, r, _mm512_set1_ps(1.0f));
|
|
151
|
+
|
|
152
|
+
// Reconstruct: exp(x) = 2ⁿ × exp(r)
|
|
153
|
+
// 2ⁿ via IEEE 754 exponent manipulation
|
|
154
|
+
__m512i ni = _mm512_cvtps_epi32(n);
|
|
155
|
+
ni = _mm512_add_epi32(ni, _mm512_set1_epi32(127));
|
|
156
|
+
ni = _mm512_slli_epi32(ni, 23);
|
|
157
|
+
__m512 pow2n = _mm512_castsi512_ps(ni);
|
|
158
|
+
|
|
159
|
+
return _mm512_mul_ps(p, pow2n);
|
|
160
|
+
}
|
|
161
|
+
|
|
162
|
+
/**
|
|
163
|
+
* @brief Faster exp(x) approximation using degree-4 polynomial.
|
|
164
|
+
*
|
|
165
|
+
* Trades accuracy for speed: ~0.1% relative error (vs <0.001% for degree-6).
|
|
166
|
+
* This is acceptable for softmax where:
|
|
167
|
+
* - Probabilities sum to 1 (normalization absorbs errors)
|
|
168
|
+
* - Relative ranking matters more than absolute values
|
|
169
|
+
*
|
|
170
|
+
* Performance: ~12-15 cycles for 16 floats (vs ~18-22 for degree-6)
|
|
171
|
+
*
|
|
172
|
+
* @param x Input vector (16 floats)
|
|
173
|
+
* @return exp(x) approximation
|
|
174
|
+
*/
|
|
175
|
+
NK_INTERNAL __m512 nk_exp_ps_fast_avx512_(__m512 x) {
|
|
176
|
+
// Constants for Cody-Waite range reduction
|
|
177
|
+
const __m512 log2e = _mm512_set1_ps(1.4426950408889634f);
|
|
178
|
+
const __m512 ln2_hi = _mm512_set1_ps(0.693145751953125f);
|
|
179
|
+
const __m512 ln2_lo = _mm512_set1_ps(1.42860682030941723212e-6f);
|
|
180
|
+
|
|
181
|
+
// Clamp to avoid overflow/underflow (same as accurate version)
|
|
182
|
+
const __m512 max_x = _mm512_set1_ps(88.3762626647949f);
|
|
183
|
+
const __m512 min_x = _mm512_set1_ps(-87.3365447504021f);
|
|
184
|
+
x = _mm512_max_ps(_mm512_min_ps(x, max_x), min_x);
|
|
185
|
+
|
|
186
|
+
// n = round(x / ln(2))
|
|
187
|
+
__m512 n = _mm512_roundscale_ps(_mm512_mul_ps(x, log2e), _MM_FROUND_TO_NEAREST_INT);
|
|
188
|
+
|
|
189
|
+
// r = x - n × ln(2) using Cody-Waite for precision
|
|
190
|
+
__m512 r = _mm512_fnmadd_ps(n, ln2_hi, x);
|
|
191
|
+
r = _mm512_fnmadd_ps(n, ln2_lo, r);
|
|
192
|
+
|
|
193
|
+
// Polynomial approximation for exp(r): degree 4
|
|
194
|
+
// Optimized coefficients for [-ln(2)/2, ln(2)/2]
|
|
195
|
+
// exp(r) ≈ 1 + r + r²/2 + r³/6 + r⁴/24
|
|
196
|
+
// Using Horner form: ((c₄ × r + c₃) × r + c₂) × r + c₁) × r + c₀
|
|
197
|
+
__m512 p = _mm512_set1_ps(4.1666666667e-2f); // 1/24
|
|
198
|
+
p = _mm512_fmadd_ps(p, r, _mm512_set1_ps(1.6666666667e-1f)); // 1/6
|
|
199
|
+
p = _mm512_fmadd_ps(p, r, _mm512_set1_ps(5.0000000000e-1f)); // 1/2
|
|
200
|
+
p = _mm512_fmadd_ps(p, r, _mm512_set1_ps(1.0f)); // 1
|
|
201
|
+
p = _mm512_fmadd_ps(p, r, _mm512_set1_ps(1.0f)); // 1
|
|
202
|
+
|
|
203
|
+
// Reconstruct: exp(x) = 2ⁿ × exp(r)
|
|
204
|
+
__m512i ni = _mm512_cvtps_epi32(n);
|
|
205
|
+
ni = _mm512_add_epi32(ni, _mm512_set1_epi32(127));
|
|
206
|
+
ni = _mm512_slli_epi32(ni, 23);
|
|
207
|
+
__m512 pow2n = _mm512_castsi512_ps(ni);
|
|
208
|
+
|
|
209
|
+
return _mm512_mul_ps(p, pow2n);
|
|
210
|
+
}
|
|
211
|
+
|
|
212
|
+
/**
 * @brief Online softmax primitives.
 *
 * These implement the online softmax algorithm from FlashAttention.
 * Key insight: softmax can be computed incrementally by tracking, for every row:
 * - m: running maximum (for numerical stability)
 * - l: running sum of exp(x - m)
 *
 * When a new block x_new arrives, first compute m_new = max(m_old, max(x_new)). Then:
 * - Rescale old sum: l = l × exp(m_old - m_new)  (mathematically ×1 when the max did not grow)
 * - Add new contributions: l += Σ exp(x_new - m_new)
 */

/**
 * @brief State for online softmax computation.
 *
 * Tracks per-row running maximum and sum for 16 rows. Lane r of each vector belongs to
 * query row r.
 */
typedef struct {
    __m512 row_max; ///< Running max per row (16 values, lane r = row r)
    __m512 row_sum; ///< Running sum of exp(x - max) per row, expressed relative to row_max
} nk_attention_softmax_row_state_t;
|
|
234
|
+
|
|
235
|
+
/**
|
|
236
|
+
* @brief Update softmax state with Bᶜ=32 score block (optimized).
|
|
237
|
+
*
|
|
238
|
+
* Computes online softmax for 16×32 score block using AVX-512.
|
|
239
|
+
* Optimizations:
|
|
240
|
+
* - Process 4 rows at a time for better ILP
|
|
241
|
+
* - Keep scaled scores in registers to avoid reloading
|
|
242
|
+
* - Vectorized row sum accumulation
|
|
243
|
+
*/
|
|
244
|
+
NK_INTERNAL void nk_attention_softmax_update_bc32_(nk_attention_softmax_row_state_t *state,
|
|
245
|
+
nk_f32_t const *scores, // [16, 32] score block
|
|
246
|
+
nk_f32_t scale,
|
|
247
|
+
nk_f32_t *weights_out) { // [16, 32] output weights
|
|
248
|
+
|
|
249
|
+
__m512 scale_v = _mm512_set1_ps(scale);
|
|
250
|
+
|
|
251
|
+
// Load and scale all scores, compute per-row max
|
|
252
|
+
// Store in temporary arrays to avoid register pressure
|
|
253
|
+
__m512 s_scaled[16][2];
|
|
254
|
+
NK_ALIGN64 float row_maxes[16];
|
|
255
|
+
|
|
256
|
+
// Process 4 rows at a time for ILP
|
|
257
|
+
for (int i = 0; i < 16; i += 4) {
|
|
258
|
+
// Row i
|
|
259
|
+
s_scaled[i][0] = _mm512_mul_ps(_mm512_load_ps(scores + i * 32 + 0), scale_v);
|
|
260
|
+
s_scaled[i][1] = _mm512_mul_ps(_mm512_load_ps(scores + i * 32 + 16), scale_v);
|
|
261
|
+
__m512 m0 = _mm512_max_ps(s_scaled[i][0], s_scaled[i][1]);
|
|
262
|
+
|
|
263
|
+
// Row i+1
|
|
264
|
+
s_scaled[i + 1][0] = _mm512_mul_ps(_mm512_load_ps(scores + (i + 1) * 32 + 0), scale_v);
|
|
265
|
+
s_scaled[i + 1][1] = _mm512_mul_ps(_mm512_load_ps(scores + (i + 1) * 32 + 16), scale_v);
|
|
266
|
+
__m512 m1 = _mm512_max_ps(s_scaled[i + 1][0], s_scaled[i + 1][1]);
|
|
267
|
+
|
|
268
|
+
// Row i+2
|
|
269
|
+
s_scaled[i + 2][0] = _mm512_mul_ps(_mm512_load_ps(scores + (i + 2) * 32 + 0), scale_v);
|
|
270
|
+
s_scaled[i + 2][1] = _mm512_mul_ps(_mm512_load_ps(scores + (i + 2) * 32 + 16), scale_v);
|
|
271
|
+
__m512 m2 = _mm512_max_ps(s_scaled[i + 2][0], s_scaled[i + 2][1]);
|
|
272
|
+
|
|
273
|
+
// Row i+3
|
|
274
|
+
s_scaled[i + 3][0] = _mm512_mul_ps(_mm512_load_ps(scores + (i + 3) * 32 + 0), scale_v);
|
|
275
|
+
s_scaled[i + 3][1] = _mm512_mul_ps(_mm512_load_ps(scores + (i + 3) * 32 + 16), scale_v);
|
|
276
|
+
__m512 m3 = _mm512_max_ps(s_scaled[i + 3][0], s_scaled[i + 3][1]);
|
|
277
|
+
|
|
278
|
+
// Reduce to scalar max
|
|
279
|
+
row_maxes[i] = _mm512_reduce_max_ps(m0);
|
|
280
|
+
row_maxes[i + 1] = _mm512_reduce_max_ps(m1);
|
|
281
|
+
row_maxes[i + 2] = _mm512_reduce_max_ps(m2);
|
|
282
|
+
row_maxes[i + 3] = _mm512_reduce_max_ps(m3);
|
|
283
|
+
}
|
|
284
|
+
|
|
285
|
+
__m512 row_max_new = _mm512_load_ps(row_maxes);
|
|
286
|
+
__m512 old_max = state->row_max;
|
|
287
|
+
__m512 new_max = _mm512_max_ps(old_max, row_max_new);
|
|
288
|
+
|
|
289
|
+
// Rescale old sum
|
|
290
|
+
__m512 correction = nk_exp_ps_avx512_(_mm512_sub_ps(old_max, new_max));
|
|
291
|
+
__m512 new_sum = _mm512_mul_ps(state->row_sum, correction);
|
|
292
|
+
|
|
293
|
+
// Compute P = exp(S - new_max) and accumulate sums
|
|
294
|
+
NK_ALIGN64 float new_max_arr[16];
|
|
295
|
+
NK_ALIGN64 float row_sums[16];
|
|
296
|
+
_mm512_store_ps(new_max_arr, new_max);
|
|
297
|
+
|
|
298
|
+
// Process rows
|
|
299
|
+
for (int i = 0; i < 16; i += 2) {
|
|
300
|
+
__m512 max_i = _mm512_set1_ps(new_max_arr[i]);
|
|
301
|
+
__m512 max_i1 = _mm512_set1_ps(new_max_arr[i + 1]);
|
|
302
|
+
|
|
303
|
+
// Row i
|
|
304
|
+
__m512 p0_i = nk_exp_ps_avx512_(_mm512_sub_ps(s_scaled[i][0], max_i));
|
|
305
|
+
__m512 p1_i = nk_exp_ps_avx512_(_mm512_sub_ps(s_scaled[i][1], max_i));
|
|
306
|
+
_mm512_store_ps(weights_out + i * 32 + 0, p0_i);
|
|
307
|
+
_mm512_store_ps(weights_out + i * 32 + 16, p1_i);
|
|
308
|
+
row_sums[i] = _mm512_reduce_add_ps(p0_i) + _mm512_reduce_add_ps(p1_i);
|
|
309
|
+
|
|
310
|
+
// Row i+1
|
|
311
|
+
__m512 p0_i1 = nk_exp_ps_avx512_(_mm512_sub_ps(s_scaled[i + 1][0], max_i1));
|
|
312
|
+
__m512 p1_i1 = nk_exp_ps_avx512_(_mm512_sub_ps(s_scaled[i + 1][1], max_i1));
|
|
313
|
+
_mm512_store_ps(weights_out + (i + 1) * 32 + 0, p0_i1);
|
|
314
|
+
_mm512_store_ps(weights_out + (i + 1) * 32 + 16, p1_i1);
|
|
315
|
+
row_sums[i + 1] = _mm512_reduce_add_ps(p0_i1) + _mm512_reduce_add_ps(p1_i1);
|
|
316
|
+
}
|
|
317
|
+
|
|
318
|
+
// Add row sums to running sum vectorially
|
|
319
|
+
new_sum = _mm512_add_ps(new_sum, _mm512_load_ps(row_sums));
|
|
320
|
+
|
|
321
|
+
state->row_max = new_max;
|
|
322
|
+
state->row_sum = new_sum;
|
|
323
|
+
}
|
|
324
|
+
|
|
325
|
+
/**
|
|
326
|
+
* @brief Fast softmax update using degree-4 exp polynomial.
|
|
327
|
+
*
|
|
328
|
+
* Same algorithm as nk_attention_softmax_update_bc32_ but uses faster exp.
|
|
329
|
+
* Trades ~0.1% accuracy for ~20% performance improvement.
|
|
330
|
+
*
|
|
331
|
+
* Use this for inference where throughput matters more than last-bit accuracy.
|
|
332
|
+
*/
|
|
333
|
+
NK_INTERNAL void nk_attention_softmax_update_bc32_fast_(nk_attention_softmax_row_state_t *state,
|
|
334
|
+
nk_f32_t const *scores, // [16, 32] score block
|
|
335
|
+
nk_f32_t scale,
|
|
336
|
+
nk_f32_t *weights_out) { // [16, 32] output weights
|
|
337
|
+
|
|
338
|
+
__m512 scale_v = _mm512_set1_ps(scale);
|
|
339
|
+
|
|
340
|
+
// Load and scale all scores, compute per-row max
|
|
341
|
+
__m512 s_scaled[16][2];
|
|
342
|
+
NK_ALIGN64 float row_maxes[16];
|
|
343
|
+
|
|
344
|
+
// Process 4 rows at a time for ILP
|
|
345
|
+
for (int i = 0; i < 16; i += 4) {
|
|
346
|
+
s_scaled[i][0] = _mm512_mul_ps(_mm512_load_ps(scores + i * 32 + 0), scale_v);
|
|
347
|
+
s_scaled[i][1] = _mm512_mul_ps(_mm512_load_ps(scores + i * 32 + 16), scale_v);
|
|
348
|
+
__m512 m0 = _mm512_max_ps(s_scaled[i][0], s_scaled[i][1]);
|
|
349
|
+
|
|
350
|
+
s_scaled[i + 1][0] = _mm512_mul_ps(_mm512_load_ps(scores + (i + 1) * 32 + 0), scale_v);
|
|
351
|
+
s_scaled[i + 1][1] = _mm512_mul_ps(_mm512_load_ps(scores + (i + 1) * 32 + 16), scale_v);
|
|
352
|
+
__m512 m1 = _mm512_max_ps(s_scaled[i + 1][0], s_scaled[i + 1][1]);
|
|
353
|
+
|
|
354
|
+
s_scaled[i + 2][0] = _mm512_mul_ps(_mm512_load_ps(scores + (i + 2) * 32 + 0), scale_v);
|
|
355
|
+
s_scaled[i + 2][1] = _mm512_mul_ps(_mm512_load_ps(scores + (i + 2) * 32 + 16), scale_v);
|
|
356
|
+
__m512 m2 = _mm512_max_ps(s_scaled[i + 2][0], s_scaled[i + 2][1]);
|
|
357
|
+
|
|
358
|
+
s_scaled[i + 3][0] = _mm512_mul_ps(_mm512_load_ps(scores + (i + 3) * 32 + 0), scale_v);
|
|
359
|
+
s_scaled[i + 3][1] = _mm512_mul_ps(_mm512_load_ps(scores + (i + 3) * 32 + 16), scale_v);
|
|
360
|
+
__m512 m3 = _mm512_max_ps(s_scaled[i + 3][0], s_scaled[i + 3][1]);
|
|
361
|
+
|
|
362
|
+
row_maxes[i] = _mm512_reduce_max_ps(m0);
|
|
363
|
+
row_maxes[i + 1] = _mm512_reduce_max_ps(m1);
|
|
364
|
+
row_maxes[i + 2] = _mm512_reduce_max_ps(m2);
|
|
365
|
+
row_maxes[i + 3] = _mm512_reduce_max_ps(m3);
|
|
366
|
+
}
|
|
367
|
+
|
|
368
|
+
__m512 row_max_new = _mm512_load_ps(row_maxes);
|
|
369
|
+
__m512 old_max = state->row_max;
|
|
370
|
+
__m512 new_max = _mm512_max_ps(old_max, row_max_new);
|
|
371
|
+
|
|
372
|
+
// Rescale old sum using fast exp
|
|
373
|
+
__m512 correction = nk_exp_ps_fast_avx512_(_mm512_sub_ps(old_max, new_max));
|
|
374
|
+
__m512 new_sum = _mm512_mul_ps(state->row_sum, correction);
|
|
375
|
+
|
|
376
|
+
// Compute P = exp(S - new_max) using fast exp
|
|
377
|
+
NK_ALIGN64 float new_max_arr[16];
|
|
378
|
+
NK_ALIGN64 float row_sums[16];
|
|
379
|
+
_mm512_store_ps(new_max_arr, new_max);
|
|
380
|
+
|
|
381
|
+
// Process rows with fast exp
|
|
382
|
+
for (int i = 0; i < 16; i += 2) {
|
|
383
|
+
__m512 max_i = _mm512_set1_ps(new_max_arr[i]);
|
|
384
|
+
__m512 max_i1 = _mm512_set1_ps(new_max_arr[i + 1]);
|
|
385
|
+
|
|
386
|
+
// Row i
|
|
387
|
+
__m512 p0_i = nk_exp_ps_fast_avx512_(_mm512_sub_ps(s_scaled[i][0], max_i));
|
|
388
|
+
__m512 p1_i = nk_exp_ps_fast_avx512_(_mm512_sub_ps(s_scaled[i][1], max_i));
|
|
389
|
+
_mm512_store_ps(weights_out + i * 32 + 0, p0_i);
|
|
390
|
+
_mm512_store_ps(weights_out + i * 32 + 16, p1_i);
|
|
391
|
+
row_sums[i] = _mm512_reduce_add_ps(p0_i) + _mm512_reduce_add_ps(p1_i);
|
|
392
|
+
|
|
393
|
+
// Row i+1
|
|
394
|
+
__m512 p0_i1 = nk_exp_ps_fast_avx512_(_mm512_sub_ps(s_scaled[i + 1][0], max_i1));
|
|
395
|
+
__m512 p1_i1 = nk_exp_ps_fast_avx512_(_mm512_sub_ps(s_scaled[i + 1][1], max_i1));
|
|
396
|
+
_mm512_store_ps(weights_out + (i + 1) * 32 + 0, p0_i1);
|
|
397
|
+
_mm512_store_ps(weights_out + (i + 1) * 32 + 16, p1_i1);
|
|
398
|
+
row_sums[i + 1] = _mm512_reduce_add_ps(p0_i1) + _mm512_reduce_add_ps(p1_i1);
|
|
399
|
+
}
|
|
400
|
+
|
|
401
|
+
new_sum = _mm512_add_ps(new_sum, _mm512_load_ps(row_sums));
|
|
402
|
+
|
|
403
|
+
state->row_max = new_max;
|
|
404
|
+
state->row_sum = new_sum;
|
|
405
|
+
}
|
|
406
|
+
|
|
407
|
+
/**
|
|
408
|
+
* @brief Initialize online softmax state.
|
|
409
|
+
*/
|
|
410
|
+
NK_INTERNAL void nk_attention_softmax_init_(nk_attention_softmax_row_state_t *state) {
|
|
411
|
+
state->row_max = _mm512_set1_ps(NK_F32_MIN);
|
|
412
|
+
state->row_sum = _mm512_setzero_ps();
|
|
413
|
+
}
|
|
414
|
+
|
|
415
|
+
/**
|
|
416
|
+
* @brief Update softmax state with new score block and compute attention weights.
|
|
417
|
+
*
|
|
418
|
+
* For a 16×16 score block S[16][16]:
|
|
419
|
+
* 1. Compute row-wise max of S
|
|
420
|
+
* 2. Update running max: newₘₐₓ = max(oldₘₐₓ, blockₘₐₓ)
|
|
421
|
+
* 3. Rescale old sum: oldₛᵤₘ × = exp(oldₘₐₓ - newₘₐₓ)
|
|
422
|
+
* 4. Compute P = exp(S - newₘₐₓ), store for P × V
|
|
423
|
+
* 5. Update sum: newₛᵤₘ = oldₛᵤₘ + row_sum(P)
|
|
424
|
+
*
|
|
425
|
+
* @param state Running softmax state (updated in place)
|
|
426
|
+
* @param scores 16×16 score block in row-major order (256 floats)
|
|
427
|
+
* @param scale Scaling factor (1/√head_dim)
|
|
428
|
+
* @param weights_out Output: 16×16 attention weights P (pre-softmax normalized)
|
|
429
|
+
*/
|
|
430
|
+
NK_INTERNAL void nk_attention_softmax_update_(nk_attention_softmax_row_state_t *state, nk_f32_t const *scores,
|
|
431
|
+
nk_f32_t scale, nk_f32_t *weights_out) {
|
|
432
|
+
|
|
433
|
+
__m512 scale_v = _mm512_set1_ps(scale);
|
|
434
|
+
|
|
435
|
+
// Load scores into 16 ZMM registers (one per row)
|
|
436
|
+
__m512 s[16];
|
|
437
|
+
for (int i = 0; i < 16; i++) { s[i] = _mm512_mul_ps(_mm512_load_ps(scores + i * 16), scale_v); }
|
|
438
|
+
|
|
439
|
+
// Per-row max (each row has 16 elements, we need max across those 16)
|
|
440
|
+
// _mm512_reduce_max_ps returns a float scalar
|
|
441
|
+
NK_ALIGN64 float row_maxes[16];
|
|
442
|
+
for (int i = 0; i < 16; i++) { row_maxes[i] = _mm512_reduce_max_ps(s[i]); }
|
|
443
|
+
__m512 row_max_new = _mm512_load_ps(row_maxes);
|
|
444
|
+
|
|
445
|
+
// Update running max
|
|
446
|
+
__m512 old_max = state->row_max;
|
|
447
|
+
__m512 new_max = _mm512_max_ps(old_max, row_max_new);
|
|
448
|
+
|
|
449
|
+
// Rescale old sum: l = l × exp(oldₘₐₓ - newₘₐₓ)
|
|
450
|
+
__m512 correction = nk_exp_ps_avx512_(_mm512_sub_ps(old_max, new_max));
|
|
451
|
+
__m512 old_sum_rescaled = _mm512_mul_ps(state->row_sum, correction);
|
|
452
|
+
|
|
453
|
+
// Compute P = exp(S - newₘₐₓ) for each row, accumulate sum
|
|
454
|
+
__m512 new_sum = old_sum_rescaled;
|
|
455
|
+
float new_max_arr[16];
|
|
456
|
+
_mm512_store_ps(new_max_arr, new_max);
|
|
457
|
+
|
|
458
|
+
for (int i = 0; i < 16; i++) {
|
|
459
|
+
__m512 max_broadcast = _mm512_set1_ps(new_max_arr[i]);
|
|
460
|
+
__m512 p = nk_exp_ps_avx512_(_mm512_sub_ps(s[i], max_broadcast));
|
|
461
|
+
_mm512_store_ps(weights_out + i * 16, p);
|
|
462
|
+
|
|
463
|
+
// Add row sum to running sum (at position i)
|
|
464
|
+
float row_sum = _mm512_reduce_add_ps(p);
|
|
465
|
+
new_sum = _mm512_mask_add_ps(new_sum, 1u << i, new_sum, _mm512_set1_ps(row_sum));
|
|
466
|
+
}
|
|
467
|
+
|
|
468
|
+
state->row_max = new_max;
|
|
469
|
+
state->row_sum = new_sum;
|
|
470
|
+
}
|
|
471
|
+
|
|
472
|
+
/**
|
|
473
|
+
* @brief Rescale output accumulator when max changes.
|
|
474
|
+
*
|
|
475
|
+
* When processing a new KV block with larger scores, previous O accumulator
|
|
476
|
+
* needs rescaling: O = O × exp(oldₘₐₓ - newₘₐₓ)
|
|
477
|
+
*
|
|
478
|
+
* @param output Output accumulator [16][head_dim] in F32
|
|
479
|
+
* @param head_dim Head dimension
|
|
480
|
+
* @param old_max Previous running max per row (16 values)
|
|
481
|
+
* @param new_max New running max per row (16 values)
|
|
482
|
+
*/
|
|
483
|
+
NK_INTERNAL void nk_attention_rescale_output_(nk_f32_t *output, nk_size_t head_dim, __m512 old_max, __m512 new_max) {
|
|
484
|
+
|
|
485
|
+
__m512 correction = nk_exp_ps_avx512_(_mm512_sub_ps(old_max, new_max));
|
|
486
|
+
float corr_arr[16];
|
|
487
|
+
_mm512_store_ps(corr_arr, correction);
|
|
488
|
+
|
|
489
|
+
for (nk_size_t row = 0; row < 16; row++) {
|
|
490
|
+
__m512 corr_v = _mm512_set1_ps(corr_arr[row]);
|
|
491
|
+
for (nk_size_t col = 0; col < head_dim; col += 16) {
|
|
492
|
+
__m512 o = _mm512_load_ps(output + row * head_dim + col);
|
|
493
|
+
o = _mm512_mul_ps(o, corr_v);
|
|
494
|
+
_mm512_store_ps(output + row * head_dim + col, o);
|
|
495
|
+
}
|
|
496
|
+
}
|
|
497
|
+
}
|
|
498
|
+
|
|
499
|
+
/**
 *  @brief  Bytes needed for a packed KV buffer: header + K tiles + V tiles.
 *
 *  K and V are packed with different AMX B-tile shapes, so their sizes differ:
 *  - K (packed as Kᵀ): 16 sequence positions × 32 head-dim values per tile,
 *    ⌈max_seq_len / 16⌉ × (head_dim_padded / 32) tiles per head.
 *  - V: 32 sequence positions × 16 head-dim values per tile,
 *    ⌈max_seq_len / 32⌉ × (head_dim_padded / 16) tiles per head.
 *  Every tile is 16 × 32 BF16 values = 1 KB. head_dim is padded to a multiple of 32.
 *
 *  @param num_kv_heads  Number of KV heads.
 *  @param head_dim      Head dimension before padding.
 *  @param max_seq_len   Largest sequence length that will be packed into the buffer.
 *  @return              Required buffer size in bytes.
 */
NK_PUBLIC nk_size_t nk_attention_kv_packed_size_sapphireamx(nk_size_t num_kv_heads, nk_size_t head_dim,
                                                            nk_size_t max_seq_len) {

    nk_size_t const tile_bytes = 1024; // 16 rows × 32 BF16 values

    // Pad head_dim to multiple of 32 for AMX tiles
    nk_size_t head_dim_padded = nk_size_round_up_to_multiple_(head_dim, 32);

    // K: 16 sequence positions (columns of Kᵀ) × 32 head-dim values (depth) per tile
    nk_size_t k_tiles_per_head = nk_size_divide_round_up_(max_seq_len, 16) * (head_dim_padded / 32);

    // V: 32 sequence positions (depth) × 16 head-dim values (columns) per tile.
    // This is never smaller than K's count and can be twice as large (e.g. max_seq_len = 16),
    // so it must be sized on its own rather than reusing K's size.
    nk_size_t v_tiles_per_head = nk_size_divide_round_up_(max_seq_len, 32) * (head_dim_padded / 16);

    nk_size_t k_size = num_kv_heads * k_tiles_per_head * tile_bytes;
    nk_size_t v_size = num_kv_heads * v_tiles_per_head * tile_bytes;

    return sizeof(nk_attention_kv_packed_header_t) + k_size + v_size;
}
|
|
518
|
+
|
|
519
|
+
/**
 *  @brief  Fill the packed-KV header and pack K as AMX B tiles of Kᵀ.
 *
 *  K is read as [num_kv_heads, seq_len, head_dim] row-major BF16. For Q × Kᵀ, Kᵀ is the B operand,
 *  so each tile covers 16 sequence positions (tile columns) × 32 head-dim values (tile depth),
 *  pair-interleaved for TDPBF16PS:
 *      tile[d / 2][s][d % 2] = K[seq_start + s][depth_start + d]
 *  Positions beyond `seq_len` or `head_dim` are written as zero. Within a head, tiles are laid out
 *  sequence-tile-major, depth-tile-minor. Afterwards `v_offset` in the header points just past K.
 *
 *  @param k             Source K, [num_kv_heads, seq_len, head_dim] BF16.
 *  @param kv_packed     Destination buffer (header followed by tiles).
 *  @param num_kv_heads  Number of KV heads.
 *  @param seq_len       Number of sequence positions.
 *  @param head_dim      Head dimension (padded internally to a multiple of 32).
 */
NK_PUBLIC void nk_attention_pack_k_sapphireamx(nk_bf16_t const *k, void *kv_packed, nk_size_t num_kv_heads,
                                               nk_size_t seq_len, nk_size_t head_dim) {

    nk_size_t const tile_elems = 512; // BF16 elements per tile
    nk_size_t const dim_padded = nk_size_round_up_to_multiple_(head_dim, 32);
    nk_size_t const seq_tiles = nk_size_divide_round_up_(seq_len, 16);
    nk_size_t const depth_tiles = dim_padded / 32;
    nk_size_t const elems_per_head = seq_tiles * depth_tiles * tile_elems;

    nk_attention_kv_packed_header_t *hdr = (nk_attention_kv_packed_header_t *)kv_packed;
    hdr->num_kv_heads = (nk_u32_t)num_kv_heads;
    hdr->head_dim = (nk_u32_t)head_dim;
    hdr->head_dim_padded = (nk_u32_t)dim_padded;
    hdr->seq_len = (nk_u32_t)seq_len;
    hdr->k_offset = sizeof(nk_attention_kv_packed_header_t);

    nk_bf16_t *dst = (nk_bf16_t *)((char *)kv_packed + hdr->k_offset);

    for (nk_size_t head = 0; head != num_kv_heads; ++head) {
        nk_bf16_t const *src_head = k + head * seq_len * head_dim;
        nk_bf16_t *dst_head = dst + head * elems_per_head;

        for (nk_size_t st = 0; st != seq_tiles; ++st) {
            nk_size_t const s0 = st * 16;
            // Sequence rows actually present in this tile (the last tile may be partial)
            nk_size_t rows_here = seq_len - s0;
            if (rows_here > 16) rows_here = 16;

            for (nk_size_t dt = 0; dt != depth_tiles; ++dt) {
                nk_size_t const d0 = dt * 32;
                // Head-dim values actually present in this tile (the last tile may be partial)
                nk_size_t cols_here = head_dim - d0;
                if (cols_here > 32) cols_here = 32;

                nk_bf16_t *tile = dst_head + (st * depth_tiles + dt) * tile_elems;

                // Each tile row holds one pair of depth values for all 16 sequence columns
                for (nk_size_t pair = 0; pair != 16; ++pair) {
                    nk_size_t const d = pair * 2;
                    for (nk_size_t col = 0; col != 16; ++col) {
                        nk_bf16_t lo = 0, hi = 0;
                        if (col < rows_here) {
                            nk_size_t const src_idx = (s0 + col) * head_dim + d0 + d;
                            if (d < cols_here) lo = src_head[src_idx];
                            if (d + 1 < cols_here) hi = src_head[src_idx + 1];
                        }
                        nk_bf16_t *out = tile + pair * 32 + col * 2;
                        out[0] = lo;
                        out[1] = hi;
                    }
                }
            }
        }
    }

    // V follows immediately after every head's K tiles
    nk_size_t const k_bytes = num_kv_heads * elems_per_head * sizeof(nk_bf16_t);
    hdr->v_offset = hdr->k_offset + (nk_u32_t)k_bytes;
}
|
|
587
|
+
|
|
588
|
+
/**
 *  @brief  Pack V as AMX B tiles for the P × V product.
 *
 *  V is read as [num_kv_heads, seq_len, head_dim] row-major BF16. For P × V, V is the B operand,
 *  with the sequence axis as depth. Each tile covers 32 sequence positions (depth) × 16 head-dim
 *  values (columns), pair-interleaved for TDPBF16PS:
 *      tile[s / 2][d][s % 2] = V[seq_start + s][head_start + d]
 *  Positions beyond `seq_len` or `head_dim` are written as zero.
 *
 *  Reads `head_dim_padded` and `v_offset` from the header, so it must run after
 *  `nk_attention_pack_k_sapphireamx` has filled that header in the same buffer.
 *
 *  @param v             Source V, [num_kv_heads, seq_len, head_dim] BF16.
 *  @param kv_packed     Destination buffer whose header was written by K packing.
 *  @param num_kv_heads  Number of KV heads.
 *  @param seq_len       Number of sequence positions.
 *  @param head_dim      Head dimension (unpadded).
 */
NK_PUBLIC void nk_attention_pack_v_sapphireamx(nk_bf16_t const *v, void *kv_packed, nk_size_t num_kv_heads,
                                               nk_size_t seq_len, nk_size_t head_dim) {

    nk_attention_kv_packed_header_t *header = (nk_attention_kv_packed_header_t *)kv_packed;
    nk_size_t head_dim_padded = header->head_dim_padded;

    nk_bf16_t *v_packed = (nk_bf16_t *)((char *)kv_packed + header->v_offset);

    // For P × V, P is [query_len, seq_len] and V is [seq_len, head_dim]:
    // V acts as the B matrix with seq_len as "depth" and head_dim as "columns".
    nk_size_t tiles_per_seq = nk_size_divide_round_up_(seq_len, 32);          // 32 depth rows per tile
    nk_size_t tiles_per_head = nk_size_divide_round_up_(head_dim_padded, 16); // 16 columns per tile
    nk_size_t tile_size = 512;                                                // BF16 elements per tile

    for (nk_size_t h = 0; h < num_kv_heads; h++) {
        nk_bf16_t const *v_head = v + h * seq_len * head_dim;
        nk_bf16_t *v_head_packed = v_packed + h * tiles_per_seq * tiles_per_head * tile_size;

        for (nk_size_t seq_tile = 0; seq_tile < tiles_per_seq; seq_tile++) {
            nk_size_t seq_start = seq_tile * 32;
            nk_size_t valid_seq = (seq_start + 32 <= seq_len) ? 32 : (seq_len - seq_start);

            for (nk_size_t head_tile = 0; head_tile < tiles_per_head; head_tile++) {
                nk_size_t head_start = head_tile * 16;
                // head_dim_padded is a multiple of 32 but tiles are only 16 columns wide, so a whole tile
                // can start at or past head_dim (e.g. head_dim = 8 → the tile at column 16). Clamp to zero
                // valid columns instead of letting `head_dim - head_start` wrap around and read out of bounds.
                nk_size_t valid_head = 0;
                if (head_start < head_dim) {
                    valid_head = (head_dim - head_start < 16) ? (head_dim - head_start) : 16;
                }

                nk_size_t tile_idx = seq_tile * tiles_per_head + head_tile;
                nk_bf16_t *tile_ptr = v_head_packed + tile_idx * tile_size;

                // Pair-interleave: B tile data[depth/2][col][depth%2] where depth = seq, col = head_dim
                for (nk_size_t s = 0; s < 32; s += 2) {
                    for (nk_size_t d = 0; d < 16; d++) {
                        nk_size_t dst_idx = (s / 2) * 32 + d * 2;

                        nk_bf16_t v0 = 0, v1 = 0;
                        if (s < valid_seq && d < valid_head) {
                            v0 = v_head[(seq_start + s) * head_dim + head_start + d];
                        }
                        if (s + 1 < valid_seq && d < valid_head) {
                            v1 = v_head[(seq_start + s + 1) * head_dim + head_start + d];
                        }

                        tile_ptr[dst_idx] = v0;
                        tile_ptr[dst_idx + 1] = v1;
                    }
                }
            }
        }
    }
}
|
|
640
|
+
|
|
641
|
+
/**
|
|
642
|
+
* @brief Extract K block from packed format: Kᵀ[head_dim, Bᶜ] for a given kv_block.
|
|
643
|
+
*
|
|
644
|
+
* K is packed as Kᵀ for Q × Kᵀ, with pair-interleaving.
|
|
645
|
+
* Output is in row-major F32 format: k_out[d × Bᶜ + kᵢ] = Kᵀ[d, kᵢ]
|
|
646
|
+
*/
|
|
647
|
+
NK_INTERNAL void nk_attention_extract_k_block_(nk_bf16_t const *k_packed, nk_f32_t *k_out, nk_size_t kv_h,
|
|
648
|
+
nk_size_t kv_block_start, nk_size_t valid_kv, nk_size_t head_dim,
|
|
649
|
+
nk_size_t kv_len) {
|
|
650
|
+
|
|
651
|
+
nk_size_t const Bc = 16;
|
|
652
|
+
nk_size_t head_dim_padded = nk_size_round_up_to_multiple_(head_dim, 32);
|
|
653
|
+
nk_size_t tiles_per_seq = nk_size_divide_round_up_(kv_len, 16);
|
|
654
|
+
nk_size_t tiles_per_depth = head_dim_padded / 32;
|
|
655
|
+
nk_size_t tile_size = 512;
|
|
656
|
+
|
|
657
|
+
nk_size_t seq_tile = kv_block_start / 16;
|
|
658
|
+
nk_size_t base_s = kv_block_start % 16;
|
|
659
|
+
|
|
660
|
+
// Get pointer to this head's K data
|
|
661
|
+
nk_bf16_t const *k_head = k_packed + kv_h * tiles_per_seq * tiles_per_depth * tile_size;
|
|
662
|
+
|
|
663
|
+
// Extract each depth tile
|
|
664
|
+
for (nk_size_t depth_tile = 0; depth_tile < tiles_per_depth; depth_tile++) {
|
|
665
|
+
nk_size_t depth_start = depth_tile * 32;
|
|
666
|
+
nk_size_t tile_idx = seq_tile * tiles_per_depth + depth_tile;
|
|
667
|
+
nk_bf16_t const *tile_ptr = k_head + tile_idx * tile_size;
|
|
668
|
+
|
|
669
|
+
// Unpack tile: pair-interleaved layout data[d/2][s][d%2]
|
|
670
|
+
for (nk_size_t d_in_tile = 0; d_in_tile < 32 && depth_start + d_in_tile < head_dim; d_in_tile++) {
|
|
671
|
+
nk_size_t d = depth_start + d_in_tile;
|
|
672
|
+
for (nk_size_t ki = 0; ki < valid_kv; ki++) {
|
|
673
|
+
nk_size_t s_in_tile = base_s + ki;
|
|
674
|
+
if (s_in_tile >= 16) continue; // Shouldn't happen if kv_block aligned
|
|
675
|
+
|
|
676
|
+
nk_size_t elem_idx = (d_in_tile / 2) * 32 + s_in_tile * 2 + (d_in_tile % 2);
|
|
677
|
+
nk_bf16_t bf16_val = tile_ptr[elem_idx];
|
|
678
|
+
nk_f32_t f32_val;
|
|
679
|
+
nk_bf16_to_f32_serial(&bf16_val, &f32_val);
|
|
680
|
+
k_out[d * Bc + ki] = f32_val;
|
|
681
|
+
}
|
|
682
|
+
}
|
|
683
|
+
}
|
|
684
|
+
}
|
|
685
|
+
|
|
686
|
+
/**
|
|
687
|
+
* @brief Extract V block from packed format: V[Bᶜ, head_dim] for a given kv_block.
|
|
688
|
+
*
|
|
689
|
+
* V is packed for P × V, with pair-interleaving.
|
|
690
|
+
* Output is in row-major F32 format: v_out[kᵢ × head_dim + d] = V[kᵢ, d]
|
|
691
|
+
*/
|
|
692
|
+
NK_INTERNAL void nk_attention_extract_v_block_(nk_bf16_t const *v_packed, nk_f32_t *v_out, nk_size_t kv_h,
|
|
693
|
+
nk_size_t kv_block_start, nk_size_t valid_kv, nk_size_t head_dim,
|
|
694
|
+
nk_size_t kv_len) {
|
|
695
|
+
|
|
696
|
+
nk_size_t head_dim_padded = nk_size_round_up_to_multiple_(head_dim, 32);
|
|
697
|
+
nk_size_t tiles_per_seq = nk_size_divide_round_up_(kv_len, 32);
|
|
698
|
+
nk_size_t tiles_per_head = nk_size_divide_round_up_(head_dim_padded, 16);
|
|
699
|
+
nk_size_t tile_size = 512;
|
|
700
|
+
|
|
701
|
+
// Get pointer to this head's V data
|
|
702
|
+
nk_bf16_t const *v_head = v_packed + kv_h * tiles_per_seq * tiles_per_head * tile_size;
|
|
703
|
+
|
|
704
|
+
// For each kv position in the block
|
|
705
|
+
for (nk_size_t ki = 0; ki < valid_kv; ki++) {
|
|
706
|
+
nk_size_t kv_pos = kv_block_start + ki;
|
|
707
|
+
nk_size_t seq_tile = kv_pos / 32;
|
|
708
|
+
nk_size_t s_in_tile = kv_pos % 32;
|
|
709
|
+
|
|
710
|
+
// Extract each head_dim tile
|
|
711
|
+
for (nk_size_t head_tile = 0; head_tile < tiles_per_head; head_tile++) {
|
|
712
|
+
nk_size_t head_start = head_tile * 16;
|
|
713
|
+
nk_size_t tile_idx = seq_tile * tiles_per_head + head_tile;
|
|
714
|
+
nk_bf16_t const *tile_ptr = v_head + tile_idx * tile_size;
|
|
715
|
+
|
|
716
|
+
// Unpack: pair-interleaved layout data[s/2][d][s%2]
|
|
717
|
+
for (nk_size_t d_in_tile = 0; d_in_tile < 16 && head_start + d_in_tile < head_dim; d_in_tile++) {
|
|
718
|
+
nk_size_t d = head_start + d_in_tile;
|
|
719
|
+
nk_size_t elem_idx = (s_in_tile / 2) * 32 + d_in_tile * 2 + (s_in_tile % 2);
|
|
720
|
+
nk_bf16_t bf16_val = tile_ptr[elem_idx];
|
|
721
|
+
nk_f32_t f32_val;
|
|
722
|
+
nk_bf16_to_f32_serial(&bf16_val, &f32_val);
|
|
723
|
+
v_out[ki * head_dim + d] = f32_val;
|
|
724
|
+
}
|
|
725
|
+
}
|
|
726
|
+
}
|
|
727
|
+
}
|
|
728
|
+
|
|
729
|
+
/**
 *  @brief  Multi-head attention over a packed BF16 KV cache with F32 output (AVX-512 FMA path).
 *
 *  Queries are processed in blocks of 16 rows and KV in blocks of 16 positions, using an online
 *  softmax so the full score matrix is never materialized. K and V blocks are unpacked from the
 *  AMX tile layout into F32 before the dot products.
 *
 *  @param q             Queries, [num_heads, query_len, head_dim] BF16.
 *  @param kv_packed     Buffer produced by `nk_attention_pack_k_sapphireamx` / `nk_attention_pack_v_sapphireamx`.
 *  @param output        Output, [num_heads, query_len, head_dim] F32.
 *  @param num_heads     Number of query heads; query head h reads KV head h / (num_heads / num_kv_heads).
 *  @param num_kv_heads  Number of KV heads.
 *  @param query_len     Query positions per head.
 *  @param kv_len        KV positions; presumably must equal the `seq_len` used when packing — TODO confirm.
 *  @param head_dim      Head dimension; at most 256 because the accumulators are fixed-size stack arrays.
 *  @param scale         Score multiplier applied before softmax (e.g. 1/√head_dim).
 */
NK_PUBLIC void nk_attention_bf16_sapphireamx(nk_bf16_t const *q, void const *kv_packed, nk_f32_t *output,
                                             nk_size_t num_heads, nk_size_t num_kv_heads, nk_size_t query_len,
                                             nk_size_t kv_len, nk_size_t head_dim, nk_f32_t scale) {

    nk_attention_kv_packed_header_t const *header = (nk_attention_kv_packed_header_t const *)kv_packed;
    nk_size_t head_dim_padded = header->head_dim_padded;
    nk_size_t gqa_ratio = num_heads / num_kv_heads;

    // Tile sizes
    nk_size_t const Br = 16; // Query block rows
    nk_size_t const Bc = 16; // KV block columns

    // NOTE(review): nothing below issues AMX tile instructions; the configuration call is kept unchanged.
    nk_amx_tile_configure_sapphireamx_();

    // Temporary buffers. All are 64-byte aligned because they are accessed with aligned AVX-512 stores.
    NK_ALIGN64 nk_f32_t scores[16 * 16];  // S = Q × Kᵀ block
    NK_ALIGN64 nk_f32_t weights[16 * 16]; // P = exp(S − max), not yet normalized
    NK_ALIGN64 nk_f32_t o_acc[16 * 256];  // Output accumulator (max head_dim=256)
    NK_ALIGN64 nk_f32_t row_sums[16];     // Softmax denominators, one per query row

    // Packed data pointers
    nk_bf16_t const *k_packed = (nk_bf16_t const *)((char const *)kv_packed + header->k_offset);
    nk_bf16_t const *v_packed = (nk_bf16_t const *)((char const *)kv_packed + header->v_offset);

    // Process each head
    for (nk_size_t h = 0; h < num_heads; h++) {
        nk_size_t kv_h = h / gqa_ratio;

        nk_bf16_t const *q_head = q + h * query_len * head_dim;
        nk_f32_t *o_head = output + h * query_len * head_dim;

        // Process query blocks
        for (nk_size_t qb = 0; qb < query_len; qb += Br) {
            nk_size_t valid_q = (qb + Br <= query_len) ? Br : (query_len - qb);

            // Initialize softmax state and output accumulator
            nk_attention_softmax_row_state_t softmax_state;
            nk_attention_softmax_init_(&softmax_state);

            // Zero all 16 accumulator rows, not just the valid ones: the rescale step below multiplies
            // every row, so padding rows must not contain uninitialized stack data.
            for (nk_size_t i = 0; i < 16 * head_dim_padded; i++) { o_acc[i] = 0.0f; }

            // Temporary buffers for extracted K and V blocks
            NK_ALIGN64 nk_f32_t k_block[16 * 256]; // Kᵀ block [head_dim, 16]
            NK_ALIGN64 nk_f32_t v_block[16 * 256]; // V block [16, head_dim]
            NK_ALIGN64 nk_f32_t q_block[16 * 256]; // Q block [16, head_dim]

            // Pre-convert Q block to F32
            for (nk_size_t qi = 0; qi < valid_q; qi++) {
                for (nk_size_t d = 0; d < head_dim; d++) {
                    nk_bf16_t q_val = q_head[(qb + qi) * head_dim + d];
                    nk_bf16_to_f32_serial(&q_val, &q_block[qi * head_dim + d]);
                }
            }

            // Process KV blocks
            for (nk_size_t kvb = 0; kvb < kv_len; kvb += Bc) {
                nk_size_t valid_kv = (kvb + Bc <= kv_len) ? Bc : (kv_len - kvb);

                // Extract K block: Kᵀ[head_dim, valid_kv]
                nk_attention_extract_k_block_(k_packed, k_block, kv_h, kvb, valid_kv, head_dim, kv_len);

                // Phase 1: Compute S = Q × Kᵀ using AVX-512 FMA
                for (nk_size_t qi = 0; qi < valid_q; qi++) {
                    for (nk_size_t ki = 0; ki < valid_kv; ki++) {
                        __m512 sum_v = _mm512_setzero_ps();
                        nk_size_t d = 0;
                        // Vectorized loop over head_dim
                        for (; d + 16 <= head_dim; d += 16) {
                            __m512 q_v = _mm512_loadu_ps(&q_block[qi * head_dim + d]);
                            // Kᵀ is stored as [head_dim, 16], so column `ki` is strided; assemble it lane by lane
                            __m512 k_v = _mm512_set_ps(
                                k_block[(d + 15) * 16 + ki], k_block[(d + 14) * 16 + ki], k_block[(d + 13) * 16 + ki],
                                k_block[(d + 12) * 16 + ki], k_block[(d + 11) * 16 + ki], k_block[(d + 10) * 16 + ki],
                                k_block[(d + 9) * 16 + ki], k_block[(d + 8) * 16 + ki], k_block[(d + 7) * 16 + ki],
                                k_block[(d + 6) * 16 + ki], k_block[(d + 5) * 16 + ki], k_block[(d + 4) * 16 + ki],
                                k_block[(d + 3) * 16 + ki], k_block[(d + 2) * 16 + ki], k_block[(d + 1) * 16 + ki],
                                k_block[(d + 0) * 16 + ki]);
                            sum_v = _mm512_fmadd_ps(q_v, k_v, sum_v);
                        }
                        nk_f32_t sum = _mm512_reduce_add_ps(sum_v);
                        // Scalar tail
                        for (; d < head_dim; d++) { sum += q_block[qi * head_dim + d] * k_block[d * 16 + ki]; }
                        scores[qi * 16 + ki] = sum;
                    }
                    // Mask out invalid KV positions with NK_F32_MIN
                    for (nk_size_t ki = valid_kv; ki < 16; ki++) { scores[qi * 16 + ki] = NK_F32_MIN; }
                }
                // Mask out invalid query rows with NK_F32_MIN
                for (nk_size_t qi = valid_q; qi < 16; qi++) {
                    for (nk_size_t ki = 0; ki < 16; ki++) { scores[qi * 16 + ki] = NK_F32_MIN; }
                }

                // Phase 2: Online softmax update
                __m512 old_max = softmax_state.row_max;
                nk_attention_softmax_update_(&softmax_state, scores, scale, weights);

                // Rescale output accumulator by exp(oldₘₐₓ − newₘₐₓ)
                nk_attention_rescale_output_(o_acc, head_dim_padded, old_max, softmax_state.row_max);

                // Extract V block: V[valid_kv, head_dim]
                nk_attention_extract_v_block_(v_packed, v_block, kv_h, kvb, valid_kv, head_dim, kv_len);

                // Phase 3: Compute O += P × V using AVX-512 FMA
                for (nk_size_t qi = 0; qi < valid_q; qi++) {
                    nk_size_t d = 0;
                    // Vectorized loop over head_dim
                    for (; d + 16 <= head_dim; d += 16) {
                        __m512 acc_v = _mm512_loadu_ps(&o_acc[qi * head_dim_padded + d]);
                        for (nk_size_t ki = 0; ki < valid_kv; ki++) {
                            __m512 p_v = _mm512_set1_ps(weights[qi * 16 + ki]);
                            __m512 v_v = _mm512_loadu_ps(&v_block[ki * head_dim + d]);
                            acc_v = _mm512_fmadd_ps(p_v, v_v, acc_v);
                        }
                        _mm512_storeu_ps(&o_acc[qi * head_dim_padded + d], acc_v);
                    }
                    // Scalar tail
                    for (; d < head_dim; d++) {
                        nk_f32_t sum = o_acc[qi * head_dim_padded + d];
                        for (nk_size_t ki = 0; ki < valid_kv; ki++) {
                            sum += weights[qi * 16 + ki] * v_block[ki * head_dim + d];
                        }
                        o_acc[qi * head_dim_padded + d] = sum;
                    }
                }
            }

            // Finalize: normalize O by row sums (`row_sums` is 64-byte aligned for the aligned store)
            _mm512_store_ps(row_sums, softmax_state.row_sum);

            for (nk_size_t qi = 0; qi < valid_q; qi++) {
                nk_f32_t inv_sum = 1.0f / row_sums[qi];
                for (nk_size_t d = 0; d < head_dim; d++) {
                    o_head[(qb + qi) * head_dim + d] = o_acc[qi * head_dim_padded + d] * inv_sum;
                }
            }
        }
    }
}
|
|
868
|
+
|
|
869
|
+
NK_PUBLIC void nk_attention_bf16_amx_bc32_sapphireamx(nk_bf16_t const *q, void const *kv_packed, nk_f32_t *output,
|
|
870
|
+
nk_size_t num_heads, nk_size_t num_kv_heads, nk_size_t query_len,
|
|
871
|
+
nk_size_t kv_len, nk_size_t head_dim, nk_f32_t scale) {
|
|
872
|
+
|
|
873
|
+
nk_attention_kv_packed_header_t const *header = (nk_attention_kv_packed_header_t const *)kv_packed;
|
|
874
|
+
nk_size_t head_dim_padded = header->head_dim_padded;
|
|
875
|
+
nk_size_t gqa_ratio = num_heads / num_kv_heads;
|
|
876
|
+
|
|
877
|
+
// Block sizes - Bc=32 matches V tile depth granularity
|
|
878
|
+
nk_size_t const Br = 16;
|
|
879
|
+
nk_size_t const Bc = 32;
|
|
880
|
+
|
|
881
|
+
// Configure AMX tiles
|
|
882
|
+
nk_amx_tile_configure_sapphireamx_();
|
|
883
|
+
|
|
884
|
+
// Buffers
|
|
885
|
+
NK_ALIGN64 nk_f32_t scores[16 * 32]; // S [16, 32]
|
|
886
|
+
NK_ALIGN64 nk_f32_t weights[16 * 32]; // P [16, 32]
|
|
887
|
+
NK_ALIGN64 nk_f32_t o_acc[16 * 256]; // Output accumulator
|
|
888
|
+
NK_ALIGN64 nk_bf16_t q_tile[16][32]; // Q as A-tile format
|
|
889
|
+
NK_ALIGN64 nk_f32_t s_tile[16][16]; // Score tile output (for each half)
|
|
890
|
+
NK_ALIGN64 nk_bf16_t p_tile[16][32]; // P weights as A-tile format
|
|
891
|
+
NK_ALIGN64 nk_f32_t o_tile[16][16]; // Output tile from AMX
|
|
892
|
+
|
|
893
|
+
// K packing layout (16 seq positions per tile)
|
|
894
|
+
nk_size_t k_tiles_per_seq = nk_size_divide_round_up_(kv_len, 16);
|
|
895
|
+
nk_size_t tiles_per_depth = head_dim_padded / 32;
|
|
896
|
+
nk_size_t tile_size = 512; // BF16 elements per tile
|
|
897
|
+
|
|
898
|
+
// V packing layout (32 seq positions per tile)
|
|
899
|
+
nk_size_t v_tiles_per_seq = nk_size_divide_round_up_(kv_len, 32);
|
|
900
|
+
nk_size_t v_tiles_per_head = nk_size_divide_round_up_(head_dim_padded, 16);
|
|
901
|
+
|
|
902
|
+
nk_bf16_t const *k_packed = (nk_bf16_t const *)((char const *)kv_packed + header->k_offset);
|
|
903
|
+
nk_bf16_t const *v_packed = (nk_bf16_t const *)((char const *)kv_packed + header->v_offset);
|
|
904
|
+
|
|
905
|
+
for (nk_size_t h = 0; h < num_heads; h++) {
|
|
906
|
+
nk_size_t kv_h = h / gqa_ratio;
|
|
907
|
+
nk_bf16_t const *q_head = q + h * query_len * head_dim;
|
|
908
|
+
nk_f32_t *o_head = output + h * query_len * head_dim;
|
|
909
|
+
|
|
910
|
+
// Pointer to this KV head's packed data
|
|
911
|
+
nk_bf16_t const *k_head = k_packed + kv_h * k_tiles_per_seq * tiles_per_depth * tile_size;
|
|
912
|
+
nk_bf16_t const *v_head = v_packed + kv_h * v_tiles_per_seq * v_tiles_per_head * tile_size;
|
|
913
|
+
|
|
914
|
+
for (nk_size_t qb = 0; qb < query_len; qb += Br) {
|
|
915
|
+
nk_size_t valid_q = (qb + Br <= query_len) ? Br : (query_len - qb);
|
|
916
|
+
|
|
917
|
+
nk_attention_softmax_row_state_t softmax_state;
|
|
918
|
+
nk_attention_softmax_init_(&softmax_state);
|
|
919
|
+
|
|
920
|
+
// Zero output accumulator using SIMD
|
|
921
|
+
__m512 zero = _mm512_setzero_ps();
|
|
922
|
+
for (nk_size_t i = 0; i < 16 * head_dim_padded; i += 64) {
|
|
923
|
+
_mm512_store_ps(&o_acc[i], zero);
|
|
924
|
+
_mm512_store_ps(&o_acc[i + 16], zero);
|
|
925
|
+
_mm512_store_ps(&o_acc[i + 32], zero);
|
|
926
|
+
_mm512_store_ps(&o_acc[i + 48], zero);
|
|
927
|
+
}
|
|
928
|
+
|
|
929
|
+
// Process KV blocks in chunks of 32
|
|
930
|
+
for (nk_size_t kvb = 0; kvb < kv_len; kvb += Bc) {
|
|
931
|
+
nk_size_t valid_kv = (kvb + Bc <= kv_len) ? Bc : (kv_len - kvb);
|
|
932
|
+
|
|
933
|
+
// Phase 1: S = Q × Kᵀ using AMX
|
|
934
|
+
// Need 2 K tiles per block (each K tile has 16 columns)
|
|
935
|
+
nk_size_t k_tile_idx0 = kvb / 16; // First K tile
|
|
936
|
+
nk_size_t k_tile_idx1 = (kvb + 16) / 16; // Second K tile
|
|
937
|
+
|
|
938
|
+
// Process first half: S[0:16, 0:16]
|
|
939
|
+
_tile_zero(0); // TMM0 = score accumulator for first 16 columns
|
|
940
|
+
_tile_zero(3); // TMM3 = score accumulator for second 16 columns
|
|
941
|
+
|
|
942
|
+
for (nk_size_t dt = 0; dt < tiles_per_depth; dt++) {
|
|
943
|
+
nk_size_t depth_start = dt * 32;
|
|
944
|
+
|
|
945
|
+
// Load Q[qb:qb+16, depth_start:depth_start+32] into A-tile format
|
|
946
|
+
// Use SIMD loads when possible (full 32 elements per row)
|
|
947
|
+
if (depth_start + 32 <= head_dim) {
|
|
948
|
+
// Full tile - use fast SIMD copy
|
|
949
|
+
for (nk_size_t row = 0; row < valid_q; row++) {
|
|
950
|
+
nk_bf16_t const *q_row = q_head + (qb + row) * head_dim + depth_start;
|
|
951
|
+
// Load 32 BF16 values (64 bytes) using two 256-bit loads
|
|
952
|
+
__m256i q0 = _mm256_loadu_si256((__m256i const *)q_row);
|
|
953
|
+
__m256i q1 = _mm256_loadu_si256((__m256i const *)(q_row + 16));
|
|
954
|
+
_mm256_store_si256((__m256i *)&q_tile[row][0], q0);
|
|
955
|
+
_mm256_store_si256((__m256i *)&q_tile[row][16], q1);
|
|
956
|
+
}
|
|
957
|
+
}
|
|
958
|
+
else {
|
|
959
|
+
// Partial tile - element-by-element with padding
|
|
960
|
+
nk_size_t valid_depth = head_dim - depth_start;
|
|
961
|
+
for (nk_size_t row = 0; row < valid_q; row++) {
|
|
962
|
+
nk_bf16_t const *q_row = q_head + (qb + row) * head_dim + depth_start;
|
|
963
|
+
for (nk_size_t col = 0; col < 32; col++) {
|
|
964
|
+
q_tile[row][col] = (col < valid_depth) ? q_row[col] : 0;
|
|
965
|
+
}
|
|
966
|
+
}
|
|
967
|
+
}
|
|
968
|
+
// Zero pad remaining rows
|
|
969
|
+
for (nk_size_t row = valid_q; row < 16; row++) {
|
|
970
|
+
_mm256_store_si256((__m256i *)&q_tile[row][0], _mm256_setzero_si256());
|
|
971
|
+
_mm256_store_si256((__m256i *)&q_tile[row][16], _mm256_setzero_si256());
|
|
972
|
+
}
|
|
973
|
+
|
|
974
|
+
_tile_loadd(1, q_tile, 64); // A: 16×32 BF16
|
|
975
|
+
|
|
976
|
+
// First K tile (columns 0:16)
|
|
977
|
+
nk_bf16_t const *k_tile_ptr0 = k_head + (k_tile_idx0 * tiles_per_depth + dt) * tile_size;
|
|
978
|
+
_tile_loadd(2, k_tile_ptr0, 64); // B: 32×16 BF16
|
|
979
|
+
_tile_dpbf16ps(0, 1, 2); // TMM0 += Q × K0
|
|
980
|
+
|
|
981
|
+
// Second K tile (columns 16:32) if within bounds
|
|
982
|
+
if (kvb + 16 < kv_len) {
|
|
983
|
+
nk_bf16_t const *k_tile_ptr1 = k_head + (k_tile_idx1 * tiles_per_depth + dt) * tile_size;
|
|
984
|
+
_tile_loadd(4, k_tile_ptr1, 64); // B: 32×16 BF16
|
|
985
|
+
_tile_dpbf16ps(3, 1, 4); // TMM3 += Q × K1
|
|
986
|
+
}
|
|
987
|
+
}
|
|
988
|
+
|
|
989
|
+
// Store scores from TMM0 and TMM3
|
|
990
|
+
// Use SIMD for fast extraction
|
|
991
|
+
_tile_stored(0, s_tile, 64);
|
|
992
|
+
|
|
993
|
+
__m512 neg_inf = _mm512_set1_ps(NK_F32_MIN);
|
|
994
|
+
|
|
995
|
+
if (valid_q == 16 && valid_kv >= 16) {
|
|
996
|
+
// Fast path: full first half, just copy
|
|
997
|
+
for (nk_size_t qi = 0; qi < 16; qi++) {
|
|
998
|
+
__m512 s0 = _mm512_load_ps(&s_tile[qi][0]);
|
|
999
|
+
_mm512_store_ps(&scores[qi * 32], s0);
|
|
1000
|
+
}
|
|
1001
|
+
}
|
|
1002
|
+
else {
|
|
1003
|
+
// Partial - need masking
|
|
1004
|
+
__mmask16 kv_mask = (1u << valid_kv) - 1;
|
|
1005
|
+
for (nk_size_t qi = 0; qi < 16; qi++) {
|
|
1006
|
+
__m512 s0 = _mm512_load_ps(&s_tile[qi][0]);
|
|
1007
|
+
if (qi < valid_q) { s0 = _mm512_mask_blend_ps(kv_mask, neg_inf, s0); }
|
|
1008
|
+
else { s0 = neg_inf; }
|
|
1009
|
+
_mm512_store_ps(&scores[qi * 32], s0);
|
|
1010
|
+
}
|
|
1011
|
+
}
|
|
1012
|
+
|
|
1013
|
+
// Second half scores (columns 16:32)
|
|
1014
|
+
if (kvb + 16 < kv_len) {
|
|
1015
|
+
_tile_stored(3, s_tile, 64);
|
|
1016
|
+
nk_size_t valid_kv2 = (valid_kv > 16) ? (valid_kv - 16) : 0;
|
|
1017
|
+
|
|
1018
|
+
if (valid_q == 16 && valid_kv2 >= 16) {
|
|
1019
|
+
// Fast path
|
|
1020
|
+
for (nk_size_t qi = 0; qi < 16; qi++) {
|
|
1021
|
+
__m512 s1 = _mm512_load_ps(&s_tile[qi][0]);
|
|
1022
|
+
_mm512_store_ps(&scores[qi * 32 + 16], s1);
|
|
1023
|
+
}
|
|
1024
|
+
}
|
|
1025
|
+
else {
|
|
1026
|
+
__mmask16 kv_mask2 = (valid_kv2 >= 16) ? 0xFFFF : ((1u << valid_kv2) - 1);
|
|
1027
|
+
for (nk_size_t qi = 0; qi < 16; qi++) {
|
|
1028
|
+
__m512 s1 = _mm512_load_ps(&s_tile[qi][0]);
|
|
1029
|
+
if (qi < valid_q) { s1 = _mm512_mask_blend_ps(kv_mask2, neg_inf, s1); }
|
|
1030
|
+
else { s1 = neg_inf; }
|
|
1031
|
+
_mm512_store_ps(&scores[qi * 32 + 16], s1);
|
|
1032
|
+
}
|
|
1033
|
+
}
|
|
1034
|
+
}
|
|
1035
|
+
else {
|
|
1036
|
+
// Mask out second half entirely
|
|
1037
|
+
for (nk_size_t qi = 0; qi < 16; qi++) { _mm512_store_ps(&scores[qi * 32 + 16], neg_inf); }
|
|
1038
|
+
}
|
|
1039
|
+
|
|
1040
|
+
// Phase 2: online softmax (fast degree-4 exp)
|
|
1041
|
+
__m512 old_max = softmax_state.row_max;
|
|
1042
|
+
nk_attention_softmax_update_bc32_fast_(&softmax_state, scores, scale, weights);
|
|
1043
|
+
nk_attention_rescale_output_(o_acc, head_dim_padded, old_max, softmax_state.row_max);
|
|
1044
|
+
|
|
1045
|
+
// Phase 3: O += P × V using AMX
|
|
1046
|
+
// Convert P[16, 32] from F32 to BF16 and pack as A-tile
|
|
1047
|
+
for (nk_size_t qi = 0; qi < 16; qi++) {
|
|
1048
|
+
for (nk_size_t ki = 0; ki < 32; ki += 16) {
|
|
1049
|
+
__m512 p_f32 = _mm512_loadu_ps(&weights[qi * 32 + ki]);
|
|
1050
|
+
__m256bh p_bf16 = _mm512_cvtneps_pbh(p_f32);
|
|
1051
|
+
// Store BF16 vector - cast through union or memory
|
|
1052
|
+
*(__m256bh *)&p_tile[qi][ki] = p_bf16;
|
|
1053
|
+
}
|
|
1054
|
+
}
|
|
1055
|
+
|
|
1056
|
+
// V tile index for this block
|
|
1057
|
+
nk_size_t v_seq_tile = kvb / 32;
|
|
1058
|
+
|
|
1059
|
+
// For each head_dim chunk of 16
|
|
1060
|
+
for (nk_size_t ht = 0; ht < v_tiles_per_head; ht++) {
|
|
1061
|
+
nk_size_t head_start = ht * 16;
|
|
1062
|
+
|
|
1063
|
+
// V tile is already packed: V[32, 16] in B-tile format
|
|
1064
|
+
nk_bf16_t const *v_tile_ptr = v_head + (v_seq_tile * v_tiles_per_head + ht) * tile_size;
|
|
1065
|
+
|
|
1066
|
+
// Zero output tile
|
|
1067
|
+
_tile_zero(5);
|
|
1068
|
+
|
|
1069
|
+
// Load P into TMM6 (A-tile: 16×32)
|
|
1070
|
+
_tile_loadd(6, p_tile, 64);
|
|
1071
|
+
|
|
1072
|
+
// Load V into TMM7 (B-tile: 32×16)
|
|
1073
|
+
_tile_loadd(7, v_tile_ptr, 64);
|
|
1074
|
+
|
|
1075
|
+
// O_tile = P × V
|
|
1076
|
+
_tile_dpbf16ps(5, 6, 7);
|
|
1077
|
+
|
|
1078
|
+
// Store and accumulate
|
|
1079
|
+
_tile_stored(5, o_tile, 64);
|
|
1080
|
+
|
|
1081
|
+
// Add to output accumulator - unrolled for all 16 rows
|
|
1082
|
+
// Even if valid_q < 16, we accumulate all (padded rows have zero weights)
|
|
1083
|
+
for (nk_size_t qi = 0; qi < 16; qi += 4) {
|
|
1084
|
+
__m512 acc0 = _mm512_load_ps(&o_acc[(qi + 0) * head_dim_padded + head_start]);
|
|
1085
|
+
__m512 acc1 = _mm512_load_ps(&o_acc[(qi + 1) * head_dim_padded + head_start]);
|
|
1086
|
+
__m512 acc2 = _mm512_load_ps(&o_acc[(qi + 2) * head_dim_padded + head_start]);
|
|
1087
|
+
__m512 acc3 = _mm512_load_ps(&o_acc[(qi + 3) * head_dim_padded + head_start]);
|
|
1088
|
+
|
|
1089
|
+
acc0 = _mm512_add_ps(acc0, _mm512_load_ps(&o_tile[qi + 0][0]));
|
|
1090
|
+
acc1 = _mm512_add_ps(acc1, _mm512_load_ps(&o_tile[qi + 1][0]));
|
|
1091
|
+
acc2 = _mm512_add_ps(acc2, _mm512_load_ps(&o_tile[qi + 2][0]));
|
|
1092
|
+
acc3 = _mm512_add_ps(acc3, _mm512_load_ps(&o_tile[qi + 3][0]));
|
|
1093
|
+
|
|
1094
|
+
_mm512_store_ps(&o_acc[(qi + 0) * head_dim_padded + head_start], acc0);
|
|
1095
|
+
_mm512_store_ps(&o_acc[(qi + 1) * head_dim_padded + head_start], acc1);
|
|
1096
|
+
_mm512_store_ps(&o_acc[(qi + 2) * head_dim_padded + head_start], acc2);
|
|
1097
|
+
_mm512_store_ps(&o_acc[(qi + 3) * head_dim_padded + head_start], acc3);
|
|
1098
|
+
}
|
|
1099
|
+
}
|
|
1100
|
+
}
|
|
1101
|
+
|
|
1102
|
+
// Finalize: normalize O by row sums
|
|
1103
|
+
float row_sums[16];
|
|
1104
|
+
_mm512_store_ps(row_sums, softmax_state.row_sum);
|
|
1105
|
+
for (nk_size_t qi = 0; qi < valid_q; qi++) {
|
|
1106
|
+
nk_f32_t inv_sum = 1.0f / row_sums[qi];
|
|
1107
|
+
for (nk_size_t d = 0; d < head_dim; d++) {
|
|
1108
|
+
o_head[(qb + qi) * head_dim + d] = o_acc[qi * head_dim_padded + d] * inv_sum;
|
|
1109
|
+
}
|
|
1110
|
+
}
|
|
1111
|
+
}
|
|
1112
|
+
}
|
|
1113
|
+
}
|
|
1114
|
+
|
|
1115
|
+
/* Largest `head_dim_padded` that the fixed-size stack buffers of the AMX kernel below can hold. */
#define NK_ATTENTION_AMX_MAX_HEAD_DIM_ 256

/*
 * Shared AMX BF16 flash-attention kernel behind the optimized and causal public entry points.
 *
 * Layouts, as indexed below:
 *   q         : [num_heads][query_len][head_dim] BF16, row-major.
 *   output    : [num_heads][query_len][head_dim] F32, row-major.
 *   kv_packed : header, then per KV head the pre-packed K tiles (16 sequence positions x 32 depth,
 *               `tile_size` BF16 each) at `header->k_offset`, and V tiles (32 sequence positions x
 *               16 head columns, B-tile format) at `header->v_offset`.
 *
 * Query heads map onto KV heads by `h / (num_heads / num_kv_heads)` (grouped-query attention).
 * Queries are processed 16 rows at a time against KV blocks of 32 positions with an online softmax.
 *
 * When `causal` is non-zero, query row `i` is treated as sitting at absolute position
 * `kv_len - query_len + i` and may only attend to KV positions `0..kv_len - query_len + i`.
 * KV blocks beyond the frontier of a whole query block are skipped.
 *
 * Rows with no visible KV position (empty cache, or a causal query whose position precedes the
 * first KV entry) are written as zeros.
 *
 * Precondition: `header->head_dim_padded <= NK_ATTENTION_AMX_MAX_HEAD_DIM_`. The public wrappers
 * check this before calling.
 */
static void nk_attention_bf16_amx_kernel_sapphireamx_(nk_bf16_t const *q, void const *kv_packed, nk_f32_t *output,
                                                      nk_size_t num_heads, nk_size_t num_kv_heads,
                                                      nk_size_t query_len, nk_size_t kv_len, nk_size_t head_dim,
                                                      nk_f32_t scale, int causal) {
    // Degenerate or inconsistent shapes: nothing sensible to compute. Without this guard,
    // `h / gqa_ratio` below would divide by zero.
    if (num_heads == 0 || query_len == 0 || num_kv_heads == 0 || num_heads < num_kv_heads) { return; }

    nk_attention_kv_packed_header_t const *header = (nk_attention_kv_packed_header_t const *)kv_packed;
    nk_size_t const head_dim_padded = header->head_dim_padded;
    nk_size_t const gqa_ratio = num_heads / num_kv_heads;

    nk_size_t const Br = 16;         // query rows per block (one AMX tile of rows)
    nk_size_t const Bc = 32;         // KV positions per block (two 16-column score tiles)
    nk_size_t const tile_size = 512; // BF16 elements per packed K or V tile

    nk_size_t const tiles_per_depth = head_dim_padded / 32;                           // K tiles along head_dim
    nk_size_t const v_tiles_per_head = nk_size_divide_round_up_(head_dim_padded, 16); // V tiles along head_dim
    nk_size_t const k_tiles_per_seq = nk_size_divide_round_up_(kv_len, 16);           // 16 positions per K tile
    nk_size_t const v_tiles_per_seq = nk_size_divide_round_up_(kv_len, 32);           // 32 positions per V tile

    nk_bf16_t const *k_packed = (nk_bf16_t const *)((char const *)kv_packed + header->k_offset);
    nk_bf16_t const *v_packed = (nk_bf16_t const *)((char const *)kv_packed + header->v_offset);

    // Scratch buffers, all sized for the largest supported head dimension.
    NK_ALIGN64 nk_bf16_t q_tiles[NK_ATTENTION_AMX_MAX_HEAD_DIM_ / 32][16][32]; // pre-packed Q, one tile per depth chunk
    NK_ALIGN64 nk_f32_t scores[16][32];                                       // S = Q x K^T for one KV block
    NK_ALIGN64 nk_f32_t weights[16][32];                                      // softmax output for one KV block
    NK_ALIGN64 nk_bf16_t p_tile[16][32];                                      // weights converted to BF16
    NK_ALIGN64 nk_f32_t o_tile[16][16];                                       // one P x V product tile
    // Flat [16][head_dim_padded] accumulator. The same row stride is used for zeroing, accumulation,
    // `nk_attention_rescale_output_` and normalization.
    NK_ALIGN64 nk_f32_t o_acc[16 * NK_ATTENTION_AMX_MAX_HEAD_DIM_];
    NK_ALIGN64 nk_f32_t row_sums[16]; // `_mm512_store_ps` requires 64-byte alignment

    __m512 const neg_inf = _mm512_set1_ps(NK_F32_MIN);
    __m512 const zero = _mm512_setzero_ps();

    // Configure AMX tiles once for the whole call.
    nk_amx_tile_configure_sapphireamx_();

    for (nk_size_t h = 0; h < num_heads; h++) {
        nk_size_t const kv_h = h / gqa_ratio;
        nk_bf16_t const *q_head = q + h * query_len * head_dim;
        nk_f32_t *o_head = output + h * query_len * head_dim;
        nk_bf16_t const *k_head = k_packed + kv_h * k_tiles_per_seq * tiles_per_depth * tile_size;
        nk_bf16_t const *v_head = v_packed + kv_h * v_tiles_per_seq * v_tiles_per_head * tile_size;

        for (nk_size_t qb = 0; qb < query_len; qb += Br) {
            nk_size_t const valid_q = (qb + Br <= query_len) ? Br : (query_len - qb);

            // Exclusive end of the KV range each row of this block may attend to. Padding rows
            // (qi >= valid_q) see nothing. `min_limit` decides whether a KV block needs masking;
            // `max_limit` is where the KV loop can stop.
            nk_size_t row_limit[16];
            nk_size_t min_limit = kv_len, max_limit = 0;
            for (nk_size_t qi = 0; qi < 16; qi++) {
                nk_size_t limit = 0;
                if (qi < valid_q) {
                    if (!causal) { limit = kv_len; }
                    else {
                        // (kv_len - query_len + i) + 1, computed without going negative.
                        nk_size_t const shifted_end = kv_len + qb + qi + 1;
                        limit = (shifted_end > query_len) ? (shifted_end - query_len) : 0;
                    }
                }
                row_limit[qi] = limit;
                if (limit < min_limit) { min_limit = limit; }
                if (limit > max_limit) { max_limit = limit; }
            }

            // Pre-pack Q tiles once; they are reused for every KV block.
            for (nk_size_t dt = 0; dt < tiles_per_depth; dt++) {
                nk_size_t const depth_start = dt * 32;
                if (depth_start + 32 <= head_dim) {
                    // Full-depth chunk: two 256-bit copies per row, zero the padding rows.
                    for (nk_size_t row = 0; row < valid_q; row++) {
                        nk_bf16_t const *q_row = q_head + (qb + row) * head_dim + depth_start;
                        _mm256_store_si256((__m256i *)&q_tiles[dt][row][0],
                                           _mm256_loadu_si256((__m256i const *)q_row));
                        _mm256_store_si256((__m256i *)&q_tiles[dt][row][16],
                                           _mm256_loadu_si256((__m256i const *)(q_row + 16)));
                    }
                    for (nk_size_t row = valid_q; row < 16; row++) {
                        _mm256_store_si256((__m256i *)&q_tiles[dt][row][0], _mm256_setzero_si256());
                        _mm256_store_si256((__m256i *)&q_tiles[dt][row][16], _mm256_setzero_si256());
                    }
                }
                else {
                    // Partial (or fully padded) depth chunk: element-wise copy with zero padding.
                    nk_size_t const valid_depth = (head_dim > depth_start) ? (head_dim - depth_start) : 0;
                    for (nk_size_t row = 0; row < 16; row++) {
                        for (nk_size_t col = 0; col < 32; col++) {
                            if (row < valid_q && col < valid_depth) {
                                q_tiles[dt][row][col] = q_head[(qb + row) * head_dim + depth_start + col];
                            }
                            else { q_tiles[dt][row][col] = 0; }
                        }
                    }
                }
            }

            // Reset the softmax state and output accumulator for this query block.
            nk_attention_softmax_row_state_t softmax_state;
            nk_attention_softmax_init_(&softmax_state);
            for (nk_size_t i = 0; i < 16 * head_dim_padded; i += 16) { _mm512_store_ps(&o_acc[i], zero); }

            // KV blocks past `max_limit` are invisible to every row of this query block.
            for (nk_size_t kvb = 0; kvb < max_limit; kvb += Bc) {
                nk_size_t const k_tile_idx0 = kvb / 16;
                nk_size_t const k_tile_idx1 = k_tile_idx0 + 1;
                int const has_second_half = kvb + 16 < max_limit; // any row sees columns 16:32?

                // Phase 1: S = Q x K^T from the pre-packed Q tiles.
                _tile_zero(0); // score columns 0:16
                _tile_zero(3); // score columns 16:32
                for (nk_size_t dt = 0; dt < tiles_per_depth; dt++) {
                    _tile_loadd(1, q_tiles[dt], 64);
                    _tile_loadd(2, k_head + (k_tile_idx0 * tiles_per_depth + dt) * tile_size, 64);
                    _tile_dpbf16ps(0, 1, 2);
                    if (has_second_half) {
                        _tile_loadd(4, k_head + (k_tile_idx1 * tiles_per_depth + dt) * tile_size, 64);
                        _tile_dpbf16ps(3, 1, 4);
                    }
                }

                _tile_stored(0, &scores[0][0], 128); // row stride of 32 floats
                if (has_second_half) { _tile_stored(3, &scores[0][16], 128); }
                else {
                    for (nk_size_t qi = 0; qi < 16; qi++) { _mm512_store_ps(&scores[qi][16], neg_inf); }
                }

                // Mask positions beyond each row's limit. This covers the KV tail, padding rows and
                // the causal frontier; it is skipped when every row sees the whole block.
                if (kvb + Bc > min_limit) {
                    for (nk_size_t qi = 0; qi < 16; qi++) {
                        nk_size_t const visible = (row_limit[qi] > kvb) ? (row_limit[qi] - kvb) : 0;
                        __mmask16 const mask0 = (visible >= 16) ? (__mmask16)0xFFFF
                                                                : (__mmask16)((1u << visible) - 1u);
                        __mmask16 const mask1 = (visible >= 32)  ? (__mmask16)0xFFFF
                                                : (visible > 16) ? (__mmask16)((1u << (visible - 16)) - 1u)
                                                                 : (__mmask16)0;
                        __m512 const s0 = _mm512_load_ps(&scores[qi][0]);
                        __m512 const s1 = _mm512_load_ps(&scores[qi][16]);
                        _mm512_store_ps(&scores[qi][0], _mm512_mask_blend_ps(mask0, neg_inf, s0));
                        _mm512_store_ps(&scores[qi][16], _mm512_mask_blend_ps(mask1, neg_inf, s1));
                    }
                }

                // Phase 2: online softmax, then rescale the running output to the new row maxima.
                __m512 const old_max = softmax_state.row_max;
                nk_attention_softmax_update_bc32_fast_(&softmax_state, &scores[0][0], scale, &weights[0][0]);
                nk_attention_rescale_output_(&o_acc[0], head_dim_padded, old_max, softmax_state.row_max);

                // Phase 3: O += P x V. Convert P to BF16 once and keep it resident in TMM6.
                for (nk_size_t qi = 0; qi < 16; qi++) {
                    *(__m256bh *)&p_tile[qi][0] = _mm512_cvtneps_pbh(_mm512_load_ps(&weights[qi][0]));
                    *(__m256bh *)&p_tile[qi][16] = _mm512_cvtneps_pbh(_mm512_load_ps(&weights[qi][16]));
                }
                _tile_loadd(6, p_tile, 64);

                nk_size_t const v_seq_tile = kvb / 32;
                for (nk_size_t ht = 0; ht < v_tiles_per_head; ht++) {
                    nk_size_t const head_start = ht * 16;
                    _tile_zero(5);
                    _tile_loadd(7, v_head + (v_seq_tile * v_tiles_per_head + ht) * tile_size, 64);
                    _tile_dpbf16ps(5, 6, 7);
                    _tile_stored(5, o_tile, 64);

                    // All 16 rows are accumulated. Padding rows only ever feed discarded outputs.
                    for (nk_size_t qi = 0; qi < 16; qi++) {
                        nk_f32_t *acc = &o_acc[qi * head_dim_padded + head_start];
                        _mm512_store_ps(acc, _mm512_add_ps(_mm512_load_ps(acc), _mm512_load_ps(&o_tile[qi][0])));
                    }
                }
            }

            // Finalize: normalize by the softmax row sums. The tail is stored with a mask so that
            // rows whose `head_dim` is not a multiple of 16 never write past their end.
            _mm512_store_ps(row_sums, softmax_state.row_sum);
            for (nk_size_t qi = 0; qi < valid_q; qi++) {
                nk_f32_t *o_row = o_head + (qb + qi) * head_dim;
                if (row_limit[qi] == 0) {
                    for (nk_size_t d = 0; d < head_dim; d++) { o_row[d] = 0.0f; }
                    continue;
                }
                nk_f32_t const *acc_row = &o_acc[qi * head_dim_padded];
                __m512 const inv_sum = _mm512_set1_ps(1.0f / row_sums[qi]);
                for (nk_size_t d = 0; d < head_dim; d += 16) {
                    nk_size_t const remaining = head_dim - d;
                    __mmask16 const tail = (remaining >= 16) ? (__mmask16)0xFFFF
                                                             : (__mmask16)((1u << remaining) - 1u);
                    _mm512_mask_storeu_ps(o_row + d, tail, _mm512_mul_ps(_mm512_load_ps(acc_row + d), inv_sum));
                }
            }
        }
    }
}

/*
 * Non-causal BF16 attention over a pre-packed KV cache, using AMX with Q tiles packed once per query
 * block and the P tile loaded once per KV block.
 *
 * Inputs with `head_dim_padded` wider than this kernel's buffers are routed to
 * `nk_attention_bf16_sapphireamx` instead of overflowing the stack. That kernel is presumably not
 * bound by these buffers (verify).
 */
NK_PUBLIC void nk_attention_bf16_amx_optimized_sapphireamx(nk_bf16_t const *q, void const *kv_packed, nk_f32_t *output,
                                                           nk_size_t num_heads, nk_size_t num_kv_heads,
                                                           nk_size_t query_len, nk_size_t kv_len, nk_size_t head_dim,
                                                           nk_f32_t scale) {
    nk_attention_kv_packed_header_t const *header = (nk_attention_kv_packed_header_t const *)kv_packed;
    if (header->head_dim_padded > NK_ATTENTION_AMX_MAX_HEAD_DIM_) {
        nk_attention_bf16_sapphireamx(q, kv_packed, output, num_heads, num_kv_heads, query_len, kv_len, head_dim,
                                      scale);
        return;
    }
    nk_attention_bf16_amx_kernel_sapphireamx_(q, kv_packed, output, num_heads, num_kv_heads, query_len, kv_len,
                                              head_dim, scale, 0);
}

/*
 * Causal BF16 attention over a pre-packed KV cache.
 *
 * The queries are the last `query_len` positions of the KV sequence. Query row `i` sits at position
 * `kv_len - query_len + i` and attends only to KV positions up to and including its own.
 *   - Decode (`query_len <= 1`): the single query sees the whole cache, so plain attention is exact.
 *   - Prefill (`kv_len == query_len`): row `i` sees KV positions `0..i`.
 *
 * TODO: there is no masked kernel yet for `head_dim_padded` wider than
 * NK_ATTENTION_AMX_MAX_HEAD_DIM_. Such inputs still fall back to unmasked attention.
 */
NK_PUBLIC void nk_attention_causal_bf16_sapphireamx(nk_bf16_t const *q, void const *kv_packed, nk_f32_t *output,
                                                    nk_size_t num_heads, nk_size_t num_kv_heads, nk_size_t query_len,
                                                    nk_size_t kv_len, nk_size_t head_dim, nk_f32_t scale) {
    nk_attention_kv_packed_header_t const *header = (nk_attention_kv_packed_header_t const *)kv_packed;
    if (query_len <= 1 || header->head_dim_padded > NK_ATTENTION_AMX_MAX_HEAD_DIM_) {
        nk_attention_bf16_sapphireamx(q, kv_packed, output, num_heads, num_kv_heads, query_len, kv_len, head_dim,
                                      scale);
        return;
    }
    nk_attention_bf16_amx_kernel_sapphireamx_(q, kv_packed, output, num_heads, num_kv_heads, query_len, kv_len,
                                              head_dim, scale, 1);
}
|
|
1348
|
+
|
|
1349
|
+
#if defined(__clang__)
|
|
1350
|
+
#pragma clang attribute pop
|
|
1351
|
+
#elif defined(__GNUC__)
|
|
1352
|
+
#pragma GCC pop_options
|
|
1353
|
+
#endif
|
|
1354
|
+
|
|
1355
|
+
#if defined(__cplusplus)
|
|
1356
|
+
} // extern "C"
|
|
1357
|
+
#endif
|
|
1358
|
+
|
|
1359
|
+
#endif // NK_TARGET_SAPPHIREAMX
|
|
1360
|
+
#endif // NK_TARGET_X86_
|
|
1361
|
+
#endif // NK_ATTENTION_SAPPHIREAMX_H
|