numkong 7.0.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (294)
  1. package/LICENSE +201 -0
  2. package/README.md +495 -0
  3. package/binding.gyp +540 -0
  4. package/c/dispatch.h +512 -0
  5. package/c/dispatch_bf16.c +389 -0
  6. package/c/dispatch_bf16c.c +52 -0
  7. package/c/dispatch_e2m3.c +263 -0
  8. package/c/dispatch_e3m2.c +243 -0
  9. package/c/dispatch_e4m3.c +276 -0
  10. package/c/dispatch_e5m2.c +272 -0
  11. package/c/dispatch_f16.c +376 -0
  12. package/c/dispatch_f16c.c +58 -0
  13. package/c/dispatch_f32.c +378 -0
  14. package/c/dispatch_f32c.c +99 -0
  15. package/c/dispatch_f64.c +296 -0
  16. package/c/dispatch_f64c.c +98 -0
  17. package/c/dispatch_i16.c +96 -0
  18. package/c/dispatch_i32.c +89 -0
  19. package/c/dispatch_i4.c +150 -0
  20. package/c/dispatch_i64.c +86 -0
  21. package/c/dispatch_i8.c +289 -0
  22. package/c/dispatch_other.c +330 -0
  23. package/c/dispatch_u1.c +148 -0
  24. package/c/dispatch_u16.c +124 -0
  25. package/c/dispatch_u32.c +118 -0
  26. package/c/dispatch_u4.c +150 -0
  27. package/c/dispatch_u64.c +102 -0
  28. package/c/dispatch_u8.c +303 -0
  29. package/c/numkong.c +950 -0
  30. package/include/README.md +573 -0
  31. package/include/module.modulemap +129 -0
  32. package/include/numkong/attention/sapphireamx.h +1361 -0
  33. package/include/numkong/attention/sme.h +2066 -0
  34. package/include/numkong/attention.h +49 -0
  35. package/include/numkong/capabilities.h +748 -0
  36. package/include/numkong/cast/README.md +262 -0
  37. package/include/numkong/cast/haswell.h +975 -0
  38. package/include/numkong/cast/icelake.h +470 -0
  39. package/include/numkong/cast/neon.h +1192 -0
  40. package/include/numkong/cast/rvv.h +1021 -0
  41. package/include/numkong/cast/sapphire.h +262 -0
  42. package/include/numkong/cast/serial.h +2262 -0
  43. package/include/numkong/cast/skylake.h +856 -0
  44. package/include/numkong/cast/v128relaxed.h +180 -0
  45. package/include/numkong/cast.h +230 -0
  46. package/include/numkong/curved/README.md +223 -0
  47. package/include/numkong/curved/genoa.h +182 -0
  48. package/include/numkong/curved/haswell.h +276 -0
  49. package/include/numkong/curved/neon.h +205 -0
  50. package/include/numkong/curved/neonbfdot.h +212 -0
  51. package/include/numkong/curved/neonhalf.h +212 -0
  52. package/include/numkong/curved/rvv.h +305 -0
  53. package/include/numkong/curved/serial.h +207 -0
  54. package/include/numkong/curved/skylake.h +457 -0
  55. package/include/numkong/curved/smef64.h +506 -0
  56. package/include/numkong/curved.h +517 -0
  57. package/include/numkong/curved.hpp +144 -0
  58. package/include/numkong/dot/README.md +425 -0
  59. package/include/numkong/dot/alder.h +563 -0
  60. package/include/numkong/dot/genoa.h +315 -0
  61. package/include/numkong/dot/haswell.h +1688 -0
  62. package/include/numkong/dot/icelake.h +883 -0
  63. package/include/numkong/dot/neon.h +818 -0
  64. package/include/numkong/dot/neonbfdot.h +244 -0
  65. package/include/numkong/dot/neonfhm.h +360 -0
  66. package/include/numkong/dot/neonhalf.h +198 -0
  67. package/include/numkong/dot/neonsdot.h +508 -0
  68. package/include/numkong/dot/rvv.h +714 -0
  69. package/include/numkong/dot/rvvbb.h +72 -0
  70. package/include/numkong/dot/rvvbf16.h +123 -0
  71. package/include/numkong/dot/rvvhalf.h +129 -0
  72. package/include/numkong/dot/sapphire.h +141 -0
  73. package/include/numkong/dot/serial.h +838 -0
  74. package/include/numkong/dot/sierra.h +405 -0
  75. package/include/numkong/dot/skylake.h +1084 -0
  76. package/include/numkong/dot/sve.h +379 -0
  77. package/include/numkong/dot/svebfdot.h +74 -0
  78. package/include/numkong/dot/svehalf.h +123 -0
  79. package/include/numkong/dot/v128relaxed.h +1258 -0
  80. package/include/numkong/dot.h +1070 -0
  81. package/include/numkong/dot.hpp +94 -0
  82. package/include/numkong/dots/README.md +496 -0
  83. package/include/numkong/dots/alder.h +114 -0
  84. package/include/numkong/dots/genoa.h +94 -0
  85. package/include/numkong/dots/haswell.h +295 -0
  86. package/include/numkong/dots/icelake.h +171 -0
  87. package/include/numkong/dots/neon.h +120 -0
  88. package/include/numkong/dots/neonbfdot.h +58 -0
  89. package/include/numkong/dots/neonfhm.h +94 -0
  90. package/include/numkong/dots/neonhalf.h +57 -0
  91. package/include/numkong/dots/neonsdot.h +108 -0
  92. package/include/numkong/dots/rvv.h +2486 -0
  93. package/include/numkong/dots/sapphireamx.h +3973 -0
  94. package/include/numkong/dots/serial.h +2844 -0
  95. package/include/numkong/dots/sierra.h +97 -0
  96. package/include/numkong/dots/skylake.h +196 -0
  97. package/include/numkong/dots/sme.h +5372 -0
  98. package/include/numkong/dots/smebi32.h +461 -0
  99. package/include/numkong/dots/smef64.h +1318 -0
  100. package/include/numkong/dots/smehalf.h +47 -0
  101. package/include/numkong/dots/v128relaxed.h +294 -0
  102. package/include/numkong/dots.h +2804 -0
  103. package/include/numkong/dots.hpp +639 -0
  104. package/include/numkong/each/README.md +469 -0
  105. package/include/numkong/each/haswell.h +1658 -0
  106. package/include/numkong/each/icelake.h +272 -0
  107. package/include/numkong/each/neon.h +1104 -0
  108. package/include/numkong/each/neonbfdot.h +212 -0
  109. package/include/numkong/each/neonhalf.h +410 -0
  110. package/include/numkong/each/rvv.h +1121 -0
  111. package/include/numkong/each/sapphire.h +477 -0
  112. package/include/numkong/each/serial.h +260 -0
  113. package/include/numkong/each/skylake.h +1562 -0
  114. package/include/numkong/each.h +2146 -0
  115. package/include/numkong/each.hpp +434 -0
  116. package/include/numkong/geospatial/README.md +147 -0
  117. package/include/numkong/geospatial/haswell.h +593 -0
  118. package/include/numkong/geospatial/neon.h +571 -0
  119. package/include/numkong/geospatial/rvv.h +701 -0
  120. package/include/numkong/geospatial/serial.h +309 -0
  121. package/include/numkong/geospatial/skylake.h +577 -0
  122. package/include/numkong/geospatial/v128relaxed.h +613 -0
  123. package/include/numkong/geospatial.h +453 -0
  124. package/include/numkong/geospatial.hpp +235 -0
  125. package/include/numkong/matrix.hpp +336 -0
  126. package/include/numkong/maxsim/README.md +187 -0
  127. package/include/numkong/maxsim/alder.h +511 -0
  128. package/include/numkong/maxsim/genoa.h +115 -0
  129. package/include/numkong/maxsim/haswell.h +553 -0
  130. package/include/numkong/maxsim/icelake.h +480 -0
  131. package/include/numkong/maxsim/neonsdot.h +394 -0
  132. package/include/numkong/maxsim/sapphireamx.h +877 -0
  133. package/include/numkong/maxsim/serial.h +490 -0
  134. package/include/numkong/maxsim/sme.h +929 -0
  135. package/include/numkong/maxsim/v128relaxed.h +280 -0
  136. package/include/numkong/maxsim.h +571 -0
  137. package/include/numkong/maxsim.hpp +133 -0
  138. package/include/numkong/mesh/README.md +227 -0
  139. package/include/numkong/mesh/haswell.h +2235 -0
  140. package/include/numkong/mesh/neon.h +1329 -0
  141. package/include/numkong/mesh/neonbfdot.h +842 -0
  142. package/include/numkong/mesh/neonhalf.h +616 -0
  143. package/include/numkong/mesh/rvv.h +916 -0
  144. package/include/numkong/mesh/serial.h +742 -0
  145. package/include/numkong/mesh/skylake.h +1135 -0
  146. package/include/numkong/mesh/v128relaxed.h +1052 -0
  147. package/include/numkong/mesh.h +652 -0
  148. package/include/numkong/mesh.hpp +762 -0
  149. package/include/numkong/numkong.h +78 -0
  150. package/include/numkong/numkong.hpp +57 -0
  151. package/include/numkong/probability/README.md +173 -0
  152. package/include/numkong/probability/haswell.h +267 -0
  153. package/include/numkong/probability/neon.h +225 -0
  154. package/include/numkong/probability/rvv.h +409 -0
  155. package/include/numkong/probability/serial.h +169 -0
  156. package/include/numkong/probability/skylake.h +324 -0
  157. package/include/numkong/probability.h +383 -0
  158. package/include/numkong/probability.hpp +120 -0
  159. package/include/numkong/random.h +50 -0
  160. package/include/numkong/random.hpp +285 -0
  161. package/include/numkong/reduce/README.md +547 -0
  162. package/include/numkong/reduce/alder.h +632 -0
  163. package/include/numkong/reduce/genoa.h +201 -0
  164. package/include/numkong/reduce/haswell.h +3783 -0
  165. package/include/numkong/reduce/icelake.h +549 -0
  166. package/include/numkong/reduce/neon.h +3841 -0
  167. package/include/numkong/reduce/neonbfdot.h +353 -0
  168. package/include/numkong/reduce/neonfhm.h +665 -0
  169. package/include/numkong/reduce/neonhalf.h +157 -0
  170. package/include/numkong/reduce/neonsdot.h +357 -0
  171. package/include/numkong/reduce/rvv.h +3407 -0
  172. package/include/numkong/reduce/serial.h +757 -0
  173. package/include/numkong/reduce/sierra.h +338 -0
  174. package/include/numkong/reduce/skylake.h +3792 -0
  175. package/include/numkong/reduce/v128relaxed.h +2302 -0
  176. package/include/numkong/reduce.h +1597 -0
  177. package/include/numkong/reduce.hpp +633 -0
  178. package/include/numkong/scalar/README.md +89 -0
  179. package/include/numkong/scalar/haswell.h +113 -0
  180. package/include/numkong/scalar/neon.h +122 -0
  181. package/include/numkong/scalar/neonhalf.h +70 -0
  182. package/include/numkong/scalar/rvv.h +211 -0
  183. package/include/numkong/scalar/sapphire.h +63 -0
  184. package/include/numkong/scalar/serial.h +332 -0
  185. package/include/numkong/scalar/v128relaxed.h +56 -0
  186. package/include/numkong/scalar.h +683 -0
  187. package/include/numkong/set/README.md +179 -0
  188. package/include/numkong/set/haswell.h +334 -0
  189. package/include/numkong/set/icelake.h +485 -0
  190. package/include/numkong/set/neon.h +364 -0
  191. package/include/numkong/set/rvv.h +226 -0
  192. package/include/numkong/set/rvvbb.h +117 -0
  193. package/include/numkong/set/serial.h +174 -0
  194. package/include/numkong/set/sve.h +185 -0
  195. package/include/numkong/set/v128relaxed.h +240 -0
  196. package/include/numkong/set.h +457 -0
  197. package/include/numkong/set.hpp +114 -0
  198. package/include/numkong/sets/README.md +149 -0
  199. package/include/numkong/sets/haswell.h +63 -0
  200. package/include/numkong/sets/icelake.h +66 -0
  201. package/include/numkong/sets/neon.h +61 -0
  202. package/include/numkong/sets/serial.h +43 -0
  203. package/include/numkong/sets/smebi32.h +1099 -0
  204. package/include/numkong/sets/v128relaxed.h +58 -0
  205. package/include/numkong/sets.h +339 -0
  206. package/include/numkong/sparse/README.md +156 -0
  207. package/include/numkong/sparse/icelake.h +463 -0
  208. package/include/numkong/sparse/neon.h +288 -0
  209. package/include/numkong/sparse/serial.h +117 -0
  210. package/include/numkong/sparse/sve2.h +507 -0
  211. package/include/numkong/sparse/turin.h +322 -0
  212. package/include/numkong/sparse.h +363 -0
  213. package/include/numkong/sparse.hpp +113 -0
  214. package/include/numkong/spatial/README.md +435 -0
  215. package/include/numkong/spatial/alder.h +607 -0
  216. package/include/numkong/spatial/genoa.h +290 -0
  217. package/include/numkong/spatial/haswell.h +960 -0
  218. package/include/numkong/spatial/icelake.h +586 -0
  219. package/include/numkong/spatial/neon.h +773 -0
  220. package/include/numkong/spatial/neonbfdot.h +165 -0
  221. package/include/numkong/spatial/neonhalf.h +118 -0
  222. package/include/numkong/spatial/neonsdot.h +261 -0
  223. package/include/numkong/spatial/rvv.h +984 -0
  224. package/include/numkong/spatial/rvvbf16.h +123 -0
  225. package/include/numkong/spatial/rvvhalf.h +117 -0
  226. package/include/numkong/spatial/sapphire.h +343 -0
  227. package/include/numkong/spatial/serial.h +346 -0
  228. package/include/numkong/spatial/sierra.h +323 -0
  229. package/include/numkong/spatial/skylake.h +606 -0
  230. package/include/numkong/spatial/sve.h +224 -0
  231. package/include/numkong/spatial/svebfdot.h +122 -0
  232. package/include/numkong/spatial/svehalf.h +109 -0
  233. package/include/numkong/spatial/v128relaxed.h +717 -0
  234. package/include/numkong/spatial.h +1425 -0
  235. package/include/numkong/spatial.hpp +183 -0
  236. package/include/numkong/spatials/README.md +580 -0
  237. package/include/numkong/spatials/alder.h +94 -0
  238. package/include/numkong/spatials/genoa.h +94 -0
  239. package/include/numkong/spatials/haswell.h +219 -0
  240. package/include/numkong/spatials/icelake.h +113 -0
  241. package/include/numkong/spatials/neon.h +109 -0
  242. package/include/numkong/spatials/neonbfdot.h +60 -0
  243. package/include/numkong/spatials/neonfhm.h +92 -0
  244. package/include/numkong/spatials/neonhalf.h +58 -0
  245. package/include/numkong/spatials/neonsdot.h +109 -0
  246. package/include/numkong/spatials/rvv.h +1960 -0
  247. package/include/numkong/spatials/sapphireamx.h +1149 -0
  248. package/include/numkong/spatials/serial.h +226 -0
  249. package/include/numkong/spatials/sierra.h +96 -0
  250. package/include/numkong/spatials/skylake.h +184 -0
  251. package/include/numkong/spatials/sme.h +1901 -0
  252. package/include/numkong/spatials/smef64.h +465 -0
  253. package/include/numkong/spatials/v128relaxed.h +240 -0
  254. package/include/numkong/spatials.h +3021 -0
  255. package/include/numkong/spatials.hpp +508 -0
  256. package/include/numkong/tensor.hpp +1592 -0
  257. package/include/numkong/trigonometry/README.md +184 -0
  258. package/include/numkong/trigonometry/haswell.h +652 -0
  259. package/include/numkong/trigonometry/neon.h +639 -0
  260. package/include/numkong/trigonometry/rvv.h +699 -0
  261. package/include/numkong/trigonometry/serial.h +703 -0
  262. package/include/numkong/trigonometry/skylake.h +721 -0
  263. package/include/numkong/trigonometry/v128relaxed.h +666 -0
  264. package/include/numkong/trigonometry.h +467 -0
  265. package/include/numkong/trigonometry.hpp +166 -0
  266. package/include/numkong/types.h +1384 -0
  267. package/include/numkong/types.hpp +5603 -0
  268. package/include/numkong/vector.hpp +698 -0
  269. package/javascript/README.md +246 -0
  270. package/javascript/dist/cjs/numkong-wasm.d.ts +166 -0
  271. package/javascript/dist/cjs/numkong-wasm.js +617 -0
  272. package/javascript/dist/cjs/numkong.d.ts +343 -0
  273. package/javascript/dist/cjs/numkong.js +523 -0
  274. package/javascript/dist/cjs/package.json +3 -0
  275. package/javascript/dist/cjs/types.d.ts +284 -0
  276. package/javascript/dist/cjs/types.js +653 -0
  277. package/javascript/dist/esm/numkong-wasm.d.ts +166 -0
  278. package/javascript/dist/esm/numkong-wasm.js +595 -0
  279. package/javascript/dist/esm/numkong.d.ts +343 -0
  280. package/javascript/dist/esm/numkong.js +452 -0
  281. package/javascript/dist/esm/package.json +3 -0
  282. package/javascript/dist/esm/types.d.ts +284 -0
  283. package/javascript/dist/esm/types.js +630 -0
  284. package/javascript/dist-package-cjs.json +3 -0
  285. package/javascript/dist-package-esm.json +3 -0
  286. package/javascript/node-gyp-build.d.ts +1 -0
  287. package/javascript/numkong-wasm.ts +756 -0
  288. package/javascript/numkong.c +689 -0
  289. package/javascript/numkong.ts +575 -0
  290. package/javascript/tsconfig-base.json +39 -0
  291. package/javascript/tsconfig-cjs.json +8 -0
  292. package/javascript/tsconfig-esm.json +8 -0
  293. package/javascript/types.ts +674 -0
  294. package/package.json +87 -0
@@ -0,0 +1,1361 @@
+ /**
+ * @brief FlashAttention-style kernels for Intel Sapphire Rapids AMX.
+ * @file include/numkong/attention/sapphireamx.h
+ * @author Ash Vardanian
+ * @date January 5, 2026
+ *
+ * @sa include/numkong/attention.h
+ *
+ * This file implements FlashAttention-2 style scaled dot-product attention (SDPA) optimized
+ * for Intel AMX instructions on Sapphire Rapids CPUs. The kernel computes:
+ *
+ *     O = softmax(Q × Kᵀ / √d) × V
+ *
+ * Key features:
+ * - Online softmax: Mathematically exact, processes KV blocks incrementally
+ * - Pre-packed KV cache: Amortizes packing cost for repeated inference
+ * - GQA/MQA support: Different num_heads and num_kv_heads for grouped-query attention
+ * - Causal masking: Optional masking for autoregressive generation
+ *
+ * Target models (2025):
+ * - Kimi K2: head_dim=112, 64 heads, MHA, 128K context
+ * - LLaMA 3.1 405B: head_dim=128, 128 heads, 16 KV heads (GQA 8:1), 128K context
+ * - Qwen 2.5 72B: head_dim=128, 64 heads, 8 KV heads (GQA 8:1), 32K context
+ *
+ * Performance comparison with H100 FlashAttention-2:
+ * - H100 SXM5: ~335 TFLOPS (35% of 989 TFLOPS peak), 80GB HBM3
+ * - 100-core SPR: ~40 TFLOPS with FlashAttention (13% of 300 TFLOPS peak)
+ * - CPU advantage: 512GB-2TB DDR5 vs 80GB HBM → supports 10-25⨯ longer contexts
+ *
+ * Expected performance per core:
+ * - Decode (query_len=1, kv_len=4K): 350-450 GOPS (softmax bound)
+ * - Prefill (query_len=64, kv_len=4K): 450-550 GOPS (better AMX utilization)
+ * - Long context (kv_len=64K+): 250-350 GOPS (memory bandwidth bound)
+ *
+ * Block sizes:
+ * - Bᵣ = 16 (query block rows, matches AMX tile height)
+ * - Bᶜ = 16 (KV block columns, fits 16×16 scores in 16 ZMM registers)
+ *
+ * Algorithm (FlashAttention-2 style):
+ *     For each query block:
+ *         Initialize O = 0, rowsum = 0, rowmax = -∞
+ *         For each KV block:
+ *             S = Q × Kᵀ using AMX TDPBF16PS
+ *             Apply online softmax: rescale old values, accumulate new
+ *             O = rescale(O) + P × V using AMX
+ *         Finalize: normalize O by row sums
+ *
+ * @section sapphireamx_attention_instructions Relevant Instructions
+ *
+ *     Intrinsic                  Instruction                      Sapphire
+ *     _tile_dpbf16ps             TDPBF16PS (TMM, TMM, TMM)        ~16cy (16x16x32 BF16)
+ *     _tile_dpbssd               TDPBSSD (TMM, TMM, TMM)          ~16cy (16x16x64 INT8)
+ *     _tile_loadd                TILELOADD (TMM, MEM)             ~10cy @ p23
+ *     _tile_stored               TILESTORED (MEM, TMM)            ~10cy @ p4
+ *     _tile_zero                 TILEZERO (TMM)                   ~1cy
+ *     _mm512_fmadd_ps            VFMADD (ZMM, ZMM, ZMM)           4cy @ p05
+ *     _mm512_mul_ps              VMULPS (ZMM, ZMM, ZMM)           4cy @ p05
+ *     _mm512_max_ps              VMAXPS (ZMM, ZMM, ZMM)           4cy @ p05
+ *     _mm512_reduce_max_ps       (pseudo: VHADDPS chain)          ~8cy
+ *     _mm512_reduce_add_ps       (pseudo: VHADDPS chain)          ~8cy
+ */
+ #ifndef NK_ATTENTION_SAPPHIREAMX_H
+ #define NK_ATTENTION_SAPPHIREAMX_H
+
+ #if NK_TARGET_X86_
+ #if NK_TARGET_SAPPHIREAMX
+
+ #include "numkong/types.h"
+ #include "numkong/dots/sapphireamx.h"
+
+ #if defined(__cplusplus)
+ extern "C" {
+ #endif
+
+ #if defined(__clang__)
+ #pragma clang attribute push( \
+     __attribute__((target("avx2,avx512f,avx512vl,avx512bw,avx512dq,avx512fp16,avx512bf16,f16c,fma,bmi,bmi2"))), \
+     apply_to = function)
+ #elif defined(__GNUC__)
+ #pragma GCC push_options
+ #pragma GCC target("avx2", "avx512f", "avx512vl", "avx512bw", "avx512dq", "avx512fp16", "avx512bf16", "f16c", "fma", \
+                    "bmi", "bmi2")
+ #endif
+
+ /**
+ * @brief Packed KV cache header for attention (64-byte aligned).
+ *
+ * Layout in memory:
+ *     [header: 64 bytes][K tiles: variable][V tiles: variable]
+ *
+ * K and V are packed in AMX tile format for efficient loading.
+ */
+ typedef struct {
+     nk_u32_t num_kv_heads;    ///< Number of K/V heads (for GQA, may differ from Q heads)
+     nk_u32_t head_dim;        ///< Original head dimension (64, 112, 128)
+     nk_u32_t head_dim_padded; ///< Padded to multiple of 32 for AMX tiles
+     nk_u32_t seq_len;         ///< Current sequence length
+     nk_u32_t max_seq_len;     ///< Maximum sequence length (for pre-allocation)
+     nk_u32_t k_offset;        ///< Byte offset to K tiles from header start
+     nk_u32_t v_offset;        ///< Byte offset to V tiles from header start
+     nk_u32_t reserved[9];     ///< Pad to 64 bytes
+ } nk_attention_kv_packed_header_t;
+
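The seven named fields plus reserved[9] come to sixteen 32-bit words, exactly one 64-byte cache line. A compile-time check along these lines (not present in the header; _Static_assert is C11) would pin down the invariant the layout comment relies on:

    _Static_assert(sizeof(nk_attention_kv_packed_header_t) == 64,
                   "packed KV header must stay exactly one cache line");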
+ /**
+ * @brief Fast exp approximation for AVX-512.
+ *
+ * Uses Cody-Waite range reduction + Remez minimax polynomial.
+ * Accuracy: max error < 1 ULP for x ∈ [-87.3, 88.7] (float range).
+ * Performance: ~15-20 cycles for 16 floats.
+ */
+
+ /**
+ * @brief Fast vectorized exp(x) approximation using AVX-512.
+ *
+ * Algorithm:
+ *     1. Range reduction: x = n × ln(2) + r, where |r| < ln(2)/2
+ *     2. Polynomial approximation: exp(r) ≈ 1 + r + r²/2 + ... (degree 6)
+ *     3. Reconstruction: exp(x) = 2ⁿ × exp(r)
+ *
+ * @param x Input vector (16 floats)
+ * @return exp(x) for each element
+ */
+ NK_INTERNAL __m512 nk_exp_ps_avx512_(__m512 x) {
+     // Constants for Cody-Waite range reduction
+     const __m512 log2e = _mm512_set1_ps(1.4426950408889634f);
+     const __m512 ln2_hi = _mm512_set1_ps(0.693145751953125f);
+     const __m512 ln2_lo = _mm512_set1_ps(1.42860682030941723212e-6f);
+
+     // Clamp to avoid overflow/underflow
+     const __m512 max_x = _mm512_set1_ps(88.3762626647949f);
+     const __m512 min_x = _mm512_set1_ps(-87.3365447504021f);
+     x = _mm512_max_ps(_mm512_min_ps(x, max_x), min_x);
+
+     // n = round(x / ln(2))
+     __m512 n = _mm512_roundscale_ps(_mm512_mul_ps(x, log2e), _MM_FROUND_TO_NEAREST_INT);
+
+     // r = x - n × ln(2) using Cody-Waite for precision
+     __m512 r = _mm512_fnmadd_ps(n, ln2_hi, x);
+     r = _mm512_fnmadd_ps(n, ln2_lo, r);
+
+     // Polynomial approximation for exp(r): Remez minimax degree 6
+     // Coefficients optimized for [-ln(2)/2, ln(2)/2]
+     __m512 p = _mm512_set1_ps(1.9875691500e-4f);
+     p = _mm512_fmadd_ps(p, r, _mm512_set1_ps(1.3981999507e-3f));
+     p = _mm512_fmadd_ps(p, r, _mm512_set1_ps(8.3334519073e-3f));
+     p = _mm512_fmadd_ps(p, r, _mm512_set1_ps(4.1665858030e-2f));
+     p = _mm512_fmadd_ps(p, r, _mm512_set1_ps(1.6666665459e-1f));
+     p = _mm512_fmadd_ps(p, r, _mm512_set1_ps(5.0000001201e-1f));
+     p = _mm512_fmadd_ps(p, r, _mm512_set1_ps(1.0f));
+     p = _mm512_fmadd_ps(p, r, _mm512_set1_ps(1.0f));
+
+     // Reconstruct: exp(x) = 2ⁿ × exp(r)
+     // 2ⁿ via IEEE 754 exponent manipulation
+     __m512i ni = _mm512_cvtps_epi32(n);
+     ni = _mm512_add_epi32(ni, _mm512_set1_epi32(127));
+     ni = _mm512_slli_epi32(ni, 23);
+     __m512 pow2n = _mm512_castsi512_ps(ni);
+
+     return _mm512_mul_ps(p, pow2n);
+ }
+
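The reconstruction step is the only bit-level trick in the function: 2ⁿ is materialized by writing the biased exponent straight into an IEEE 754 single. A scalar mirror of that step, as a sketch with a hypothetical helper name (the clamp in the kernel guarantees n + 127 stays inside the normal-exponent range [1, 254]):

    #include <stdint.h>
    #include <string.h>

    // Builds 2^n by placing the biased exponent (n + 127) into bits 30:23
    // of a float; sign and mantissa stay zero. Equivalent to ldexpf(1.0f, n).
    static float pow2_from_exponent_bits(int n) {
        uint32_t bits = (uint32_t)(n + 127) << 23;
        float p;
        memcpy(&p, &bits, sizeof p); // bit-cast without violating aliasing rules
        return p;
    }

The vector code performs the same construction with _mm512_add_epi32 and _mm512_slli_epi32, then multiplies the polynomial value by the result.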
+ /**
+ * @brief Faster exp(x) approximation using degree-4 polynomial.
+ *
+ * Trades accuracy for speed: ~0.1% relative error (vs <0.001% for degree-6).
+ * This is acceptable for softmax where:
+ * - Probabilities sum to 1 (normalization absorbs errors)
+ * - Relative ranking matters more than absolute values
+ *
+ * Performance: ~12-15 cycles for 16 floats (vs ~18-22 for degree-6)
+ *
+ * @param x Input vector (16 floats)
+ * @return exp(x) approximation
+ */
+ NK_INTERNAL __m512 nk_exp_ps_fast_avx512_(__m512 x) {
+     // Constants for Cody-Waite range reduction
+     const __m512 log2e = _mm512_set1_ps(1.4426950408889634f);
+     const __m512 ln2_hi = _mm512_set1_ps(0.693145751953125f);
+     const __m512 ln2_lo = _mm512_set1_ps(1.42860682030941723212e-6f);
+
+     // Clamp to avoid overflow/underflow (same as accurate version)
+     const __m512 max_x = _mm512_set1_ps(88.3762626647949f);
+     const __m512 min_x = _mm512_set1_ps(-87.3365447504021f);
+     x = _mm512_max_ps(_mm512_min_ps(x, max_x), min_x);
+
+     // n = round(x / ln(2))
+     __m512 n = _mm512_roundscale_ps(_mm512_mul_ps(x, log2e), _MM_FROUND_TO_NEAREST_INT);
+
+     // r = x - n × ln(2) using Cody-Waite for precision
+     __m512 r = _mm512_fnmadd_ps(n, ln2_hi, x);
+     r = _mm512_fnmadd_ps(n, ln2_lo, r);
+
+     // Polynomial approximation for exp(r): degree 4
+     // Optimized coefficients for [-ln(2)/2, ln(2)/2]
+     // exp(r) ≈ 1 + r + r²/2 + r³/6 + r⁴/24
+     // Using Horner form: (((c₄ × r + c₃) × r + c₂) × r + c₁) × r + c₀
+     __m512 p = _mm512_set1_ps(4.1666666667e-2f);                 // 1/24
+     p = _mm512_fmadd_ps(p, r, _mm512_set1_ps(1.6666666667e-1f)); // 1/6
+     p = _mm512_fmadd_ps(p, r, _mm512_set1_ps(5.0000000000e-1f)); // 1/2
+     p = _mm512_fmadd_ps(p, r, _mm512_set1_ps(1.0f));             // 1
+     p = _mm512_fmadd_ps(p, r, _mm512_set1_ps(1.0f));             // 1
+
+     // Reconstruct: exp(x) = 2ⁿ × exp(r)
+     __m512i ni = _mm512_cvtps_epi32(n);
+     ni = _mm512_add_epi32(ni, _mm512_set1_epi32(127));
+     ni = _mm512_slli_epi32(ni, 23);
+     __m512 pow2n = _mm512_castsi512_ps(ni);
+
+     return _mm512_mul_ps(p, pow2n);
+ }
+
+ /**
+ * @brief Online softmax primitives.
+ *
+ * These implement the online softmax algorithm from FlashAttention.
+ * Key insight: softmax can be computed incrementally by tracking:
+ * - m: running maximum (for numerical stability)
+ * - l: running sum of exp(x - m)
+ *
+ * When a new block arrives with larger values:
+ * - Rescale old sum: l = l × exp(m_old - m_new)
+ * - Add new contributions: l += Σ exp(x_new - m_new)
+ */
+
+ /**
+ * @brief State for online softmax computation.
+ *
+ * Tracks per-row running maximum and sum for 16 rows.
+ */
+ typedef struct {
+     __m512 row_max; ///< Running max per row (16 values)
+     __m512 row_sum; ///< Running sum of exp(x - max) per row
+ } nk_attention_softmax_row_state_t;
+
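The per-row state above is easiest to sanity-check against a scalar version of the same recurrence. A self-contained reference for one row, assuming nothing beyond libm (the names are illustrative, not part of the header):

    #include <math.h>

    typedef struct { float m, l; } row_state; // init: m = -INFINITY, l = 0.0f

    // Fold one block of n scores into (m, l) so that after every call
    // l == sum_j exp(score_j - m) over all scores seen so far.
    static void softmax_fold_block(row_state *st, float const *scores, int n, float *weights) {
        float m_new = st->m;
        for (int i = 0; i < n; i++)
            if (scores[i] > m_new) m_new = scores[i];
        float l_new = st->l * expf(st->m - m_new); // rescale everything already folded
        for (int i = 0; i < n; i++) {
            weights[i] = expf(scores[i] - m_new);  // un-normalized probabilities, like P
            l_new += weights[i];
        }
        st->m = m_new;
        st->l = l_new;
    }

Dividing the accumulated P × V rows by the final l reproduces an exact softmax, which is what the finalize step in the kernels below does.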
+ /**
+ * @brief Update softmax state with Bᶜ=32 score block (optimized).
+ *
+ * Computes online softmax for 16×32 score block using AVX-512.
+ * Optimizations:
+ * - Process 4 rows at a time for better ILP
+ * - Keep scaled scores in registers to avoid reloading
+ * - Vectorized row sum accumulation
+ */
+ NK_INTERNAL void nk_attention_softmax_update_bc32_(nk_attention_softmax_row_state_t *state,
+                                                    nk_f32_t const *scores, // [16, 32] score block
+                                                    nk_f32_t scale,
+                                                    nk_f32_t *weights_out) { // [16, 32] output weights
+
+     __m512 scale_v = _mm512_set1_ps(scale);
+
+     // Load and scale all scores, compute per-row max
+     // Store in temporary arrays to avoid register pressure
+     __m512 s_scaled[16][2];
+     NK_ALIGN64 float row_maxes[16];
+
+     // Process 4 rows at a time for ILP
+     for (int i = 0; i < 16; i += 4) {
+         // Row i
+         s_scaled[i][0] = _mm512_mul_ps(_mm512_load_ps(scores + i * 32 + 0), scale_v);
+         s_scaled[i][1] = _mm512_mul_ps(_mm512_load_ps(scores + i * 32 + 16), scale_v);
+         __m512 m0 = _mm512_max_ps(s_scaled[i][0], s_scaled[i][1]);
+
+         // Row i+1
+         s_scaled[i + 1][0] = _mm512_mul_ps(_mm512_load_ps(scores + (i + 1) * 32 + 0), scale_v);
+         s_scaled[i + 1][1] = _mm512_mul_ps(_mm512_load_ps(scores + (i + 1) * 32 + 16), scale_v);
+         __m512 m1 = _mm512_max_ps(s_scaled[i + 1][0], s_scaled[i + 1][1]);
+
+         // Row i+2
+         s_scaled[i + 2][0] = _mm512_mul_ps(_mm512_load_ps(scores + (i + 2) * 32 + 0), scale_v);
+         s_scaled[i + 2][1] = _mm512_mul_ps(_mm512_load_ps(scores + (i + 2) * 32 + 16), scale_v);
+         __m512 m2 = _mm512_max_ps(s_scaled[i + 2][0], s_scaled[i + 2][1]);
+
+         // Row i+3
+         s_scaled[i + 3][0] = _mm512_mul_ps(_mm512_load_ps(scores + (i + 3) * 32 + 0), scale_v);
+         s_scaled[i + 3][1] = _mm512_mul_ps(_mm512_load_ps(scores + (i + 3) * 32 + 16), scale_v);
+         __m512 m3 = _mm512_max_ps(s_scaled[i + 3][0], s_scaled[i + 3][1]);
+
+         // Reduce to scalar max
+         row_maxes[i] = _mm512_reduce_max_ps(m0);
+         row_maxes[i + 1] = _mm512_reduce_max_ps(m1);
+         row_maxes[i + 2] = _mm512_reduce_max_ps(m2);
+         row_maxes[i + 3] = _mm512_reduce_max_ps(m3);
+     }
+
+     __m512 row_max_new = _mm512_load_ps(row_maxes);
+     __m512 old_max = state->row_max;
+     __m512 new_max = _mm512_max_ps(old_max, row_max_new);
+
+     // Rescale old sum
+     __m512 correction = nk_exp_ps_avx512_(_mm512_sub_ps(old_max, new_max));
+     __m512 new_sum = _mm512_mul_ps(state->row_sum, correction);
+
+     // Compute P = exp(S - new_max) and accumulate sums
+     NK_ALIGN64 float new_max_arr[16];
+     NK_ALIGN64 float row_sums[16];
+     _mm512_store_ps(new_max_arr, new_max);
+
+     // Process rows
+     for (int i = 0; i < 16; i += 2) {
+         __m512 max_i = _mm512_set1_ps(new_max_arr[i]);
+         __m512 max_i1 = _mm512_set1_ps(new_max_arr[i + 1]);
+
+         // Row i
+         __m512 p0_i = nk_exp_ps_avx512_(_mm512_sub_ps(s_scaled[i][0], max_i));
+         __m512 p1_i = nk_exp_ps_avx512_(_mm512_sub_ps(s_scaled[i][1], max_i));
+         _mm512_store_ps(weights_out + i * 32 + 0, p0_i);
+         _mm512_store_ps(weights_out + i * 32 + 16, p1_i);
+         row_sums[i] = _mm512_reduce_add_ps(p0_i) + _mm512_reduce_add_ps(p1_i);
+
+         // Row i+1
+         __m512 p0_i1 = nk_exp_ps_avx512_(_mm512_sub_ps(s_scaled[i + 1][0], max_i1));
+         __m512 p1_i1 = nk_exp_ps_avx512_(_mm512_sub_ps(s_scaled[i + 1][1], max_i1));
+         _mm512_store_ps(weights_out + (i + 1) * 32 + 0, p0_i1);
+         _mm512_store_ps(weights_out + (i + 1) * 32 + 16, p1_i1);
+         row_sums[i + 1] = _mm512_reduce_add_ps(p0_i1) + _mm512_reduce_add_ps(p1_i1);
+     }
+
+     // Add row sums to running sum vectorially
+     new_sum = _mm512_add_ps(new_sum, _mm512_load_ps(row_sums));
+
+     state->row_max = new_max;
+     state->row_sum = new_sum;
+ }
+
+ /**
+ * @brief Fast softmax update using degree-4 exp polynomial.
+ *
+ * Same algorithm as nk_attention_softmax_update_bc32_ but uses faster exp.
+ * Trades ~0.1% accuracy for ~20% performance improvement.
+ *
+ * Use this for inference where throughput matters more than last-bit accuracy.
+ */
+ NK_INTERNAL void nk_attention_softmax_update_bc32_fast_(nk_attention_softmax_row_state_t *state,
+                                                         nk_f32_t const *scores, // [16, 32] score block
+                                                         nk_f32_t scale,
+                                                         nk_f32_t *weights_out) { // [16, 32] output weights
+
+     __m512 scale_v = _mm512_set1_ps(scale);
+
+     // Load and scale all scores, compute per-row max
+     __m512 s_scaled[16][2];
+     NK_ALIGN64 float row_maxes[16];
+
+     // Process 4 rows at a time for ILP
+     for (int i = 0; i < 16; i += 4) {
+         s_scaled[i][0] = _mm512_mul_ps(_mm512_load_ps(scores + i * 32 + 0), scale_v);
+         s_scaled[i][1] = _mm512_mul_ps(_mm512_load_ps(scores + i * 32 + 16), scale_v);
+         __m512 m0 = _mm512_max_ps(s_scaled[i][0], s_scaled[i][1]);
+
+         s_scaled[i + 1][0] = _mm512_mul_ps(_mm512_load_ps(scores + (i + 1) * 32 + 0), scale_v);
+         s_scaled[i + 1][1] = _mm512_mul_ps(_mm512_load_ps(scores + (i + 1) * 32 + 16), scale_v);
+         __m512 m1 = _mm512_max_ps(s_scaled[i + 1][0], s_scaled[i + 1][1]);
+
+         s_scaled[i + 2][0] = _mm512_mul_ps(_mm512_load_ps(scores + (i + 2) * 32 + 0), scale_v);
+         s_scaled[i + 2][1] = _mm512_mul_ps(_mm512_load_ps(scores + (i + 2) * 32 + 16), scale_v);
+         __m512 m2 = _mm512_max_ps(s_scaled[i + 2][0], s_scaled[i + 2][1]);
+
+         s_scaled[i + 3][0] = _mm512_mul_ps(_mm512_load_ps(scores + (i + 3) * 32 + 0), scale_v);
+         s_scaled[i + 3][1] = _mm512_mul_ps(_mm512_load_ps(scores + (i + 3) * 32 + 16), scale_v);
+         __m512 m3 = _mm512_max_ps(s_scaled[i + 3][0], s_scaled[i + 3][1]);
+
+         row_maxes[i] = _mm512_reduce_max_ps(m0);
+         row_maxes[i + 1] = _mm512_reduce_max_ps(m1);
+         row_maxes[i + 2] = _mm512_reduce_max_ps(m2);
+         row_maxes[i + 3] = _mm512_reduce_max_ps(m3);
+     }
+
+     __m512 row_max_new = _mm512_load_ps(row_maxes);
+     __m512 old_max = state->row_max;
+     __m512 new_max = _mm512_max_ps(old_max, row_max_new);
+
+     // Rescale old sum using fast exp
+     __m512 correction = nk_exp_ps_fast_avx512_(_mm512_sub_ps(old_max, new_max));
+     __m512 new_sum = _mm512_mul_ps(state->row_sum, correction);
+
+     // Compute P = exp(S - new_max) using fast exp
+     NK_ALIGN64 float new_max_arr[16];
+     NK_ALIGN64 float row_sums[16];
+     _mm512_store_ps(new_max_arr, new_max);
+
+     // Process rows with fast exp
+     for (int i = 0; i < 16; i += 2) {
+         __m512 max_i = _mm512_set1_ps(new_max_arr[i]);
+         __m512 max_i1 = _mm512_set1_ps(new_max_arr[i + 1]);
+
+         // Row i
+         __m512 p0_i = nk_exp_ps_fast_avx512_(_mm512_sub_ps(s_scaled[i][0], max_i));
+         __m512 p1_i = nk_exp_ps_fast_avx512_(_mm512_sub_ps(s_scaled[i][1], max_i));
+         _mm512_store_ps(weights_out + i * 32 + 0, p0_i);
+         _mm512_store_ps(weights_out + i * 32 + 16, p1_i);
+         row_sums[i] = _mm512_reduce_add_ps(p0_i) + _mm512_reduce_add_ps(p1_i);
+
+         // Row i+1
+         __m512 p0_i1 = nk_exp_ps_fast_avx512_(_mm512_sub_ps(s_scaled[i + 1][0], max_i1));
+         __m512 p1_i1 = nk_exp_ps_fast_avx512_(_mm512_sub_ps(s_scaled[i + 1][1], max_i1));
+         _mm512_store_ps(weights_out + (i + 1) * 32 + 0, p0_i1);
+         _mm512_store_ps(weights_out + (i + 1) * 32 + 16, p1_i1);
+         row_sums[i + 1] = _mm512_reduce_add_ps(p0_i1) + _mm512_reduce_add_ps(p1_i1);
+     }
+
+     new_sum = _mm512_add_ps(new_sum, _mm512_load_ps(row_sums));
+
+     state->row_max = new_max;
+     state->row_sum = new_sum;
+ }
+
+ /**
+ * @brief Initialize online softmax state.
+ */
+ NK_INTERNAL void nk_attention_softmax_init_(nk_attention_softmax_row_state_t *state) {
+     state->row_max = _mm512_set1_ps(NK_F32_MIN);
+     state->row_sum = _mm512_setzero_ps();
+ }
+
+ /**
+ * @brief Update softmax state with new score block and compute attention weights.
+ *
+ * For a 16×16 score block S[16][16]:
+ *     1. Compute row-wise max of S
+ *     2. Update running max: newₘₐₓ = max(oldₘₐₓ, blockₘₐₓ)
+ *     3. Rescale old sum: oldₛᵤₘ ×= exp(oldₘₐₓ - newₘₐₓ)
+ *     4. Compute P = exp(S - newₘₐₓ), store for P × V
+ *     5. Update sum: newₛᵤₘ = oldₛᵤₘ + row_sum(P)
+ *
+ * @param state Running softmax state (updated in place)
+ * @param scores 16×16 score block in row-major order (256 floats)
+ * @param scale Scaling factor (1/√head_dim)
+ * @param weights_out Output: 16×16 attention weights P (exponentiated, not yet normalized)
+ */
+ NK_INTERNAL void nk_attention_softmax_update_(nk_attention_softmax_row_state_t *state, nk_f32_t const *scores,
+                                               nk_f32_t scale, nk_f32_t *weights_out) {
+
+     __m512 scale_v = _mm512_set1_ps(scale);
+
+     // Load scores into 16 ZMM registers (one per row)
+     __m512 s[16];
+     for (int i = 0; i < 16; i++) { s[i] = _mm512_mul_ps(_mm512_load_ps(scores + i * 16), scale_v); }
+
+     // Per-row max (each row has 16 elements, we need max across those 16)
+     // _mm512_reduce_max_ps returns a float scalar
+     NK_ALIGN64 float row_maxes[16];
+     for (int i = 0; i < 16; i++) { row_maxes[i] = _mm512_reduce_max_ps(s[i]); }
+     __m512 row_max_new = _mm512_load_ps(row_maxes);
+
+     // Update running max
+     __m512 old_max = state->row_max;
+     __m512 new_max = _mm512_max_ps(old_max, row_max_new);
+
+     // Rescale old sum: l = l × exp(oldₘₐₓ - newₘₐₓ)
+     __m512 correction = nk_exp_ps_avx512_(_mm512_sub_ps(old_max, new_max));
+     __m512 old_sum_rescaled = _mm512_mul_ps(state->row_sum, correction);
+
+     // Compute P = exp(S - newₘₐₓ) for each row, accumulate sum
+     __m512 new_sum = old_sum_rescaled;
+     NK_ALIGN64 float new_max_arr[16]; // aligned: _mm512_store_ps requires 64-byte alignment
+     _mm512_store_ps(new_max_arr, new_max);
+
+     for (int i = 0; i < 16; i++) {
+         __m512 max_broadcast = _mm512_set1_ps(new_max_arr[i]);
+         __m512 p = nk_exp_ps_avx512_(_mm512_sub_ps(s[i], max_broadcast));
+         _mm512_store_ps(weights_out + i * 16, p);
+
+         // Add row sum to running sum (at position i)
+         float row_sum = _mm512_reduce_add_ps(p);
+         new_sum = _mm512_mask_add_ps(new_sum, 1u << i, new_sum, _mm512_set1_ps(row_sum));
+     }
+
+     state->row_max = new_max;
+     state->row_sum = new_sum;
+ }
+
+ /**
+ * @brief Rescale output accumulator when max changes.
+ *
+ * When processing a new KV block with larger scores, the previous O accumulator
+ * needs rescaling: O = O × exp(oldₘₐₓ - newₘₐₓ)
+ *
+ * @param output Output accumulator [16][head_dim] in F32
+ * @param head_dim Head dimension
+ * @param old_max Previous running max per row (16 values)
+ * @param new_max New running max per row (16 values)
+ */
+ NK_INTERNAL void nk_attention_rescale_output_(nk_f32_t *output, nk_size_t head_dim, __m512 old_max, __m512 new_max) {
+
+     __m512 correction = nk_exp_ps_avx512_(_mm512_sub_ps(old_max, new_max));
+     NK_ALIGN64 float corr_arr[16]; // aligned: _mm512_store_ps requires 64-byte alignment
+     _mm512_store_ps(corr_arr, correction);
+
+     for (nk_size_t row = 0; row < 16; row++) {
+         __m512 corr_v = _mm512_set1_ps(corr_arr[row]);
+         for (nk_size_t col = 0; col < head_dim; col += 16) {
+             __m512 o = _mm512_load_ps(output + row * head_dim + col);
+             o = _mm512_mul_ps(o, corr_v);
+             _mm512_store_ps(output + row * head_dim + col, o);
+         }
+     }
+ }
+
+ NK_PUBLIC nk_size_t nk_attention_kv_packed_size_sapphireamx(nk_size_t num_kv_heads, nk_size_t head_dim,
+                                                             nk_size_t max_seq_len) {
+
+     // Pad head_dim to multiple of 32 for AMX tiles
+     nk_size_t head_dim_padded = nk_size_round_up_to_multiple_(head_dim, 32);
+
+     // Each head: seq_len × head_dim_padded BF16 values, packed in AMX tile format
+     // with pair-interleaving (1KB per tile). K tiles cover 16 sequence positions,
+     // V tiles cover 32, so the two sides are sized with their own tile geometry.
+     nk_size_t k_tiles_per_head = nk_size_divide_round_up_(max_seq_len, 16) * (head_dim_padded / 32);
+     nk_size_t v_tiles_per_head = nk_size_divide_round_up_(max_seq_len, 32) * nk_size_divide_round_up_(head_dim_padded, 16);
+
+     // K and V each have num_kv_heads heads
+     nk_size_t k_size = num_kv_heads * k_tiles_per_head * 1024; // 1KB per tile
+     nk_size_t v_size = num_kv_heads * v_tiles_per_head * 1024;
+
+     // Header + K + V, all 64-byte aligned
+     return sizeof(nk_attention_kv_packed_header_t) + k_size + v_size;
+ }
+
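To put numbers on the formula, take a hypothetical cache with num_kv_heads = 8, head_dim = 112 (padded to 128), and max_seq_len = 4096. K needs ⌈4096/16⌉ × 128/32 = 256 × 4 = 1024 tiles per head and V needs ⌈4096/32⌉ × 128/16 = 128 × 8 = 1024 tiles per head. At 1 KiB per tile that is 1 MiB for each side of each head, so the whole packed cache is 8 × 2 MiB plus the 64-byte header, roughly 16 MiB.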
+ NK_PUBLIC void nk_attention_pack_k_sapphireamx(nk_bf16_t const *k, void *kv_packed, nk_size_t num_kv_heads,
+                                                nk_size_t seq_len, nk_size_t head_dim) {
+
+     nk_attention_kv_packed_header_t *header = (nk_attention_kv_packed_header_t *)kv_packed;
+
+     // Initialize header
+     nk_size_t head_dim_padded = nk_size_round_up_to_multiple_(head_dim, 32);
+     header->num_kv_heads = (nk_u32_t)num_kv_heads;
+     header->head_dim = (nk_u32_t)head_dim;
+     header->head_dim_padded = (nk_u32_t)head_dim_padded;
+     header->seq_len = (nk_u32_t)seq_len;
+     header->k_offset = sizeof(nk_attention_kv_packed_header_t);
+
+     nk_bf16_t *k_packed = (nk_bf16_t *)((char *)kv_packed + header->k_offset);
+
+     // For Q × Kᵀ, K acts as B matrix but transposed
+     // K[h, s, d] → Kᵀ[h, d, s]
+     // Pack Kᵀ into AMX B tile format with pair-interleaving
+
+     nk_size_t tiles_per_seq = nk_size_divide_round_up_(seq_len, 16);
+     nk_size_t tiles_per_depth = head_dim_padded / 32;
+     nk_size_t tile_size = 512; // BF16 elements per tile
+
+     for (nk_size_t h = 0; h < num_kv_heads; h++) {
+         nk_bf16_t const *k_head = k + h * seq_len * head_dim;
+         nk_bf16_t *k_head_packed = k_packed + h * tiles_per_seq * tiles_per_depth * tile_size;
+
+         // Pack tiles: iterate over seq_len tiles (columns of Kᵀ) and depth tiles
+         for (nk_size_t seq_tile = 0; seq_tile < tiles_per_seq; seq_tile++) {
+             nk_size_t seq_start = seq_tile * 16;
+             nk_size_t valid_seq = (seq_start + 16 <= seq_len) ? 16 : (seq_len - seq_start);
+
+             for (nk_size_t depth_tile = 0; depth_tile < tiles_per_depth; depth_tile++) {
+                 nk_size_t depth_start = depth_tile * 32;
+                 nk_size_t valid_depth = (depth_start + 32 <= head_dim) ? 32 : (head_dim - depth_start);
+
+                 // Tile index in packed format
+                 nk_size_t tile_idx = seq_tile * tiles_per_depth + depth_tile;
+                 nk_bf16_t *tile_ptr = k_head_packed + tile_idx * tile_size;
+
+                 // Pack with pair-interleaving for TDPBF16PS
+                 // B tile layout: data[depth/2][col][depth%2]
+                 // For Kᵀ: depth is original head_dim, col is original seq position
+                 for (nk_size_t d = 0; d < 32; d += 2) {
+                     for (nk_size_t s = 0; s < 16; s++) {
+                         nk_size_t dst_idx = (d / 2) * 32 + s * 2;
+
+                         // K[h, seq_start + s, depth_start + d] and K[h, seq_start + s, depth_start + d + 1]
+                         nk_bf16_t v0 = 0, v1 = 0;
+                         if (s < valid_seq && d < valid_depth) {
+                             v0 = k_head[(seq_start + s) * head_dim + depth_start + d];
+                         }
+                         if (s < valid_seq && d + 1 < valid_depth) {
+                             v1 = k_head[(seq_start + s) * head_dim + depth_start + d + 1];
+                         }
+
+                         tile_ptr[dst_idx] = v0;
+                         tile_ptr[dst_idx + 1] = v1;
+                     }
+                 }
+             }
+         }
+     }
+
+     // Calculate V offset
+     nk_size_t k_size = num_kv_heads * tiles_per_seq * tiles_per_depth * tile_size * sizeof(nk_bf16_t);
+     header->v_offset = header->k_offset + (nk_u32_t)k_size;
+ }
+
+ NK_PUBLIC void nk_attention_pack_v_sapphireamx(nk_bf16_t const *v, void *kv_packed, nk_size_t num_kv_heads,
+                                                nk_size_t seq_len, nk_size_t head_dim) {
+
+     nk_attention_kv_packed_header_t *header = (nk_attention_kv_packed_header_t *)kv_packed;
+     nk_size_t head_dim_padded = header->head_dim_padded;
+
+     nk_bf16_t *v_packed = (nk_bf16_t *)((char *)kv_packed + header->v_offset);
+
+     // For P @ V, P is [query_len, seq_len], V is [seq_len, head_dim]
+     // V acts as B matrix: pack with seq_len as "depth", head_dim as "columns"
+
+     nk_size_t tiles_per_seq = nk_size_divide_round_up_(seq_len, 32);         // seq_len is depth for V
+     nk_size_t tiles_per_head = nk_size_divide_round_up_(head_dim_padded, 16); // head_dim is columns
+     nk_size_t tile_size = 512;
+
+     for (nk_size_t h = 0; h < num_kv_heads; h++) {
+         nk_bf16_t const *v_head = v + h * seq_len * head_dim;
+         nk_bf16_t *v_head_packed = v_packed + h * tiles_per_seq * tiles_per_head * tile_size;
+
+         for (nk_size_t seq_tile = 0; seq_tile < tiles_per_seq; seq_tile++) {
+             nk_size_t seq_start = seq_tile * 32;
+             nk_size_t valid_seq = (seq_start + 32 <= seq_len) ? 32 : (seq_len - seq_start);
+
+             for (nk_size_t head_tile = 0; head_tile < tiles_per_head; head_tile++) {
+                 nk_size_t head_start = head_tile * 16;
+                 // Guard against head_start landing in the padding past head_dim,
+                 // which would otherwise underflow the unsigned subtraction
+                 nk_size_t valid_head = (head_start + 16 <= head_dim) ? 16
+                                        : (head_start < head_dim)     ? (head_dim - head_start)
+                                                                      : 0;
+
+                 nk_size_t tile_idx = seq_tile * tiles_per_head + head_tile;
+                 nk_bf16_t *tile_ptr = v_head_packed + tile_idx * tile_size;
+
+                 // Pack with pair-interleaving
+                 // B tile: data[depth/2][col][depth%2] where depth=seq, col=head_dim
+                 for (nk_size_t s = 0; s < 32; s += 2) {
+                     for (nk_size_t d = 0; d < 16; d++) {
+                         nk_size_t dst_idx = (s / 2) * 32 + d * 2;
+
+                         nk_bf16_t v0 = 0, v1 = 0;
+                         if (s < valid_seq && d < valid_head) {
+                             v0 = v_head[(seq_start + s) * head_dim + head_start + d];
+                         }
+                         if (s + 1 < valid_seq && d < valid_head) {
+                             v1 = v_head[(seq_start + s + 1) * head_dim + head_start + d];
+                         }
+
+                         tile_ptr[dst_idx] = v0;
+                         tile_ptr[dst_idx + 1] = v1;
+                     }
+                 }
+             }
+         }
+     }
+ }
+
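Both pack routines use the same pair-interleaved B-tile layout: a logical 32-deep × 16-wide BF16 tile stores element (d, c) at flat index (d/2)·32 + c·2 + (d%2), so TDPBF16PS reads each depth pair contiguously. A standalone round-trip sketch of that mapping (plain C, with uint16_t standing in for BF16 payloads; not part of the package):

    #include <assert.h>
    #include <stddef.h>
    #include <stdint.h>

    // Forward mapping used by the pack functions above: (depth, column) -> flat index.
    static size_t btile_index(size_t d, size_t c) { return (d / 2) * 32 + c * 2 + (d % 2); }

    int main(void) {
        uint16_t tile[512] = {0}; // one 1KiB AMX tile: 16 rows x 32 BF16 values
        for (size_t d = 0; d < 32; d++)
            for (size_t c = 0; c < 16; c++)
                tile[btile_index(d, c)] = (uint16_t)(d * 16 + c);
        // The extract helpers below invert the layout with the identical expression:
        for (size_t d = 0; d < 32; d++)
            for (size_t c = 0; c < 16; c++)
                assert(tile[(d / 2) * 32 + c * 2 + (d % 2)] == (uint16_t)(d * 16 + c));
        return 0;
    }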
+ /**
+ * @brief Extract K block from packed format: Kᵀ[head_dim, Bᶜ] for a given kv_block.
+ *
+ * K is packed as Kᵀ for Q × Kᵀ, with pair-interleaving.
+ * Output is in row-major F32 format: k_out[d × Bᶜ + kᵢ] = Kᵀ[d, kᵢ]
+ */
+ NK_INTERNAL void nk_attention_extract_k_block_(nk_bf16_t const *k_packed, nk_f32_t *k_out, nk_size_t kv_h,
+                                                nk_size_t kv_block_start, nk_size_t valid_kv, nk_size_t head_dim,
+                                                nk_size_t kv_len) {
+
+     nk_size_t const Bc = 16;
+     nk_size_t head_dim_padded = nk_size_round_up_to_multiple_(head_dim, 32);
+     nk_size_t tiles_per_seq = nk_size_divide_round_up_(kv_len, 16);
+     nk_size_t tiles_per_depth = head_dim_padded / 32;
+     nk_size_t tile_size = 512;
+
+     nk_size_t seq_tile = kv_block_start / 16;
+     nk_size_t base_s = kv_block_start % 16;
+
+     // Get pointer to this head's K data
+     nk_bf16_t const *k_head = k_packed + kv_h * tiles_per_seq * tiles_per_depth * tile_size;
+
+     // Extract each depth tile
+     for (nk_size_t depth_tile = 0; depth_tile < tiles_per_depth; depth_tile++) {
+         nk_size_t depth_start = depth_tile * 32;
+         nk_size_t tile_idx = seq_tile * tiles_per_depth + depth_tile;
+         nk_bf16_t const *tile_ptr = k_head + tile_idx * tile_size;
+
+         // Unpack tile: pair-interleaved layout data[d/2][s][d%2]
+         for (nk_size_t d_in_tile = 0; d_in_tile < 32 && depth_start + d_in_tile < head_dim; d_in_tile++) {
+             nk_size_t d = depth_start + d_in_tile;
+             for (nk_size_t ki = 0; ki < valid_kv; ki++) {
+                 nk_size_t s_in_tile = base_s + ki;
+                 if (s_in_tile >= 16) continue; // Shouldn't happen if kv_block aligned
+
+                 nk_size_t elem_idx = (d_in_tile / 2) * 32 + s_in_tile * 2 + (d_in_tile % 2);
+                 nk_bf16_t bf16_val = tile_ptr[elem_idx];
+                 nk_f32_t f32_val;
+                 nk_bf16_to_f32_serial(&bf16_val, &f32_val);
+                 k_out[d * Bc + ki] = f32_val;
+             }
+         }
+     }
+ }
+
+ /**
+ * @brief Extract V block from packed format: V[Bᶜ, head_dim] for a given kv_block.
+ *
+ * V is packed for P × V, with pair-interleaving.
+ * Output is in row-major F32 format: v_out[kᵢ × head_dim + d] = V[kᵢ, d]
+ */
+ NK_INTERNAL void nk_attention_extract_v_block_(nk_bf16_t const *v_packed, nk_f32_t *v_out, nk_size_t kv_h,
+                                                nk_size_t kv_block_start, nk_size_t valid_kv, nk_size_t head_dim,
+                                                nk_size_t kv_len) {
+
+     nk_size_t head_dim_padded = nk_size_round_up_to_multiple_(head_dim, 32);
+     nk_size_t tiles_per_seq = nk_size_divide_round_up_(kv_len, 32);
+     nk_size_t tiles_per_head = nk_size_divide_round_up_(head_dim_padded, 16);
+     nk_size_t tile_size = 512;
+
+     // Get pointer to this head's V data
+     nk_bf16_t const *v_head = v_packed + kv_h * tiles_per_seq * tiles_per_head * tile_size;
+
+     // For each kv position in the block
+     for (nk_size_t ki = 0; ki < valid_kv; ki++) {
+         nk_size_t kv_pos = kv_block_start + ki;
+         nk_size_t seq_tile = kv_pos / 32;
+         nk_size_t s_in_tile = kv_pos % 32;
+
+         // Extract each head_dim tile
+         for (nk_size_t head_tile = 0; head_tile < tiles_per_head; head_tile++) {
+             nk_size_t head_start = head_tile * 16;
+             nk_size_t tile_idx = seq_tile * tiles_per_head + head_tile;
+             nk_bf16_t const *tile_ptr = v_head + tile_idx * tile_size;
+
+             // Unpack: pair-interleaved layout data[s/2][d][s%2]
+             for (nk_size_t d_in_tile = 0; d_in_tile < 16 && head_start + d_in_tile < head_dim; d_in_tile++) {
+                 nk_size_t d = head_start + d_in_tile;
+                 nk_size_t elem_idx = (s_in_tile / 2) * 32 + d_in_tile * 2 + (s_in_tile % 2);
+                 nk_bf16_t bf16_val = tile_ptr[elem_idx];
+                 nk_f32_t f32_val;
+                 nk_bf16_to_f32_serial(&bf16_val, &f32_val);
+                 v_out[ki * head_dim + d] = f32_val;
+             }
+         }
+     }
+ }
+
+ NK_PUBLIC void nk_attention_bf16_sapphireamx(nk_bf16_t const *q, void const *kv_packed, nk_f32_t *output,
+                                              nk_size_t num_heads, nk_size_t num_kv_heads, nk_size_t query_len,
+                                              nk_size_t kv_len, nk_size_t head_dim, nk_f32_t scale) {
+
+     nk_attention_kv_packed_header_t const *header = (nk_attention_kv_packed_header_t const *)kv_packed;
+     nk_size_t head_dim_padded = header->head_dim_padded;
+     nk_size_t gqa_ratio = num_heads / num_kv_heads;
+
+     // Tile sizes
+     nk_size_t const Br = 16; // Query block rows
+     nk_size_t const Bc = 16; // KV block columns
+
+     // Configure AMX tiles
+     nk_amx_tile_configure_sapphireamx_();
+
+     // Temporary buffers (aligned to 64 bytes)
+     NK_ALIGN64 nk_f32_t scores[16 * 16];  // S = Q × Kᵀ block
+     NK_ALIGN64 nk_f32_t weights[16 * 16]; // P = softmax(S)
+     NK_ALIGN64 nk_f32_t o_acc[16 * 256];  // Output accumulator (max head_dim=256)
+
+     // Packed data pointers
+     nk_bf16_t const *k_packed = (nk_bf16_t const *)((char const *)kv_packed + header->k_offset);
+     nk_bf16_t const *v_packed = (nk_bf16_t const *)((char const *)kv_packed + header->v_offset);
+
+     // Process each head
+     for (nk_size_t h = 0; h < num_heads; h++) {
+         nk_size_t kv_h = h / gqa_ratio;
+
+         nk_bf16_t const *q_head = q + h * query_len * head_dim;
+         nk_f32_t *o_head = output + h * query_len * head_dim;
+
+         // Process query blocks
+         for (nk_size_t qb = 0; qb < query_len; qb += Br) {
+             nk_size_t valid_q = (qb + Br <= query_len) ? Br : (query_len - qb);
+
+             // Initialize softmax state and output accumulator
+             nk_attention_softmax_row_state_t softmax_state;
+             nk_attention_softmax_init_(&softmax_state);
+
+             // Zero all 16 rows: the rescale helper touches every row, not just valid_q
+             for (nk_size_t i = 0; i < 16 * head_dim_padded; i++) { o_acc[i] = 0.0f; }
+
+             // Temporary buffers for extracted K and V blocks
+             NK_ALIGN64 nk_f32_t k_block[16 * 256]; // Kᵀ block [head_dim, 16]
+             NK_ALIGN64 nk_f32_t v_block[16 * 256]; // V block [16, head_dim]
+             NK_ALIGN64 nk_f32_t q_block[16 * 256]; // Q block [16, head_dim]
+
+             // Pre-convert Q block to F32
+             for (nk_size_t qi = 0; qi < valid_q; qi++) {
+                 for (nk_size_t d = 0; d < head_dim; d++) {
+                     nk_bf16_t q_val = q_head[(qb + qi) * head_dim + d];
+                     nk_bf16_to_f32_serial(&q_val, &q_block[qi * head_dim + d]);
+                 }
+             }
+
+             // Process KV blocks
+             for (nk_size_t kvb = 0; kvb < kv_len; kvb += Bc) {
+                 nk_size_t valid_kv = (kvb + Bc <= kv_len) ? Bc : (kv_len - kvb);
+
+                 // Extract K block: Kᵀ[head_dim, valid_kv] using bulk extraction
+                 nk_attention_extract_k_block_(k_packed, k_block, kv_h, kvb, valid_kv, head_dim, kv_len);
+
+                 // Phase 1: Compute S = Q × Kᵀ using AVX-512 FMA
+                 for (nk_size_t qi = 0; qi < valid_q; qi++) {
+                     for (nk_size_t ki = 0; ki < valid_kv; ki++) {
+                         __m512 sum_v = _mm512_setzero_ps();
+                         nk_size_t d = 0;
+                         // Vectorized loop over head_dim
+                         for (; d + 16 <= head_dim; d += 16) {
+                             __m512 q_v = _mm512_loadu_ps(&q_block[qi * head_dim + d]);
+                             // Kᵀ is stored as [head_dim, kv], gather is slow, use scalar for now
+                             __m512 k_v = _mm512_set_ps(
+                                 k_block[(d + 15) * 16 + ki], k_block[(d + 14) * 16 + ki], k_block[(d + 13) * 16 + ki],
+                                 k_block[(d + 12) * 16 + ki], k_block[(d + 11) * 16 + ki], k_block[(d + 10) * 16 + ki],
+                                 k_block[(d + 9) * 16 + ki], k_block[(d + 8) * 16 + ki], k_block[(d + 7) * 16 + ki],
+                                 k_block[(d + 6) * 16 + ki], k_block[(d + 5) * 16 + ki], k_block[(d + 4) * 16 + ki],
+                                 k_block[(d + 3) * 16 + ki], k_block[(d + 2) * 16 + ki], k_block[(d + 1) * 16 + ki],
+                                 k_block[(d + 0) * 16 + ki]);
+                             sum_v = _mm512_fmadd_ps(q_v, k_v, sum_v);
+                         }
+                         nk_f32_t sum = _mm512_reduce_add_ps(sum_v);
+                         // Scalar tail
+                         for (; d < head_dim; d++) { sum += q_block[qi * head_dim + d] * k_block[d * 16 + ki]; }
+                         scores[qi * 16 + ki] = sum;
+                     }
+                     // Mask invalid KV positions with a large negative score
+                     for (nk_size_t ki = valid_kv; ki < 16; ki++) { scores[qi * 16 + ki] = NK_F32_MIN; }
+                 }
+                 // Mask invalid query rows the same way
+                 for (nk_size_t qi = valid_q; qi < 16; qi++) {
+                     for (nk_size_t ki = 0; ki < 16; ki++) { scores[qi * 16 + ki] = NK_F32_MIN; }
+                 }
+
+                 // Phase 2: Online softmax update
+                 __m512 old_max = softmax_state.row_max;
+                 nk_attention_softmax_update_(&softmax_state, scores, scale, weights);
+
+                 // Rescale output accumulator if max changed
+                 nk_attention_rescale_output_(o_acc, head_dim_padded, old_max, softmax_state.row_max);
+
+                 // Extract V block: V[valid_kv, head_dim] using bulk extraction
+                 nk_attention_extract_v_block_(v_packed, v_block, kv_h, kvb, valid_kv, head_dim, kv_len);
+
+                 // Phase 3: Compute O += P × V using AVX-512 FMA
+                 for (nk_size_t qi = 0; qi < valid_q; qi++) {
+                     nk_size_t d = 0;
+                     // Vectorized loop over head_dim
+                     for (; d + 16 <= head_dim; d += 16) {
+                         __m512 acc_v = _mm512_loadu_ps(&o_acc[qi * head_dim_padded + d]);
+                         for (nk_size_t ki = 0; ki < valid_kv; ki++) {
+                             __m512 p_v = _mm512_set1_ps(weights[qi * 16 + ki]);
+                             __m512 v_v = _mm512_loadu_ps(&v_block[ki * head_dim + d]);
+                             acc_v = _mm512_fmadd_ps(p_v, v_v, acc_v);
+                         }
+                         _mm512_storeu_ps(&o_acc[qi * head_dim_padded + d], acc_v);
+                     }
+                     // Scalar tail
+                     for (; d < head_dim; d++) {
+                         nk_f32_t sum = o_acc[qi * head_dim_padded + d];
+                         for (nk_size_t ki = 0; ki < valid_kv; ki++) {
+                             sum += weights[qi * 16 + ki] * v_block[ki * head_dim + d];
+                         }
+                         o_acc[qi * head_dim_padded + d] = sum;
+                     }
+                 }
+             }
+
+             // Finalize: normalize O by row sums
+             NK_ALIGN64 float row_sums[16]; // aligned: _mm512_store_ps requires 64-byte alignment
+             _mm512_store_ps(row_sums, softmax_state.row_sum);
+
+             for (nk_size_t qi = 0; qi < valid_q; qi++) {
+                 nk_f32_t inv_sum = 1.0f / row_sums[qi];
+                 for (nk_size_t d = 0; d < head_dim; d++) {
+                     o_head[(qb + qi) * head_dim + d] = o_acc[qi * head_dim_padded + d] * inv_sum;
+                 }
+             }
+         }
+     }
+ }
+
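Taken together with the packing helpers, a host-side driver might look like the sketch below (hypothetical, not part of the package; error handling omitted). aligned_alloc is C11 and needs a size that is a multiple of the alignment, and pack_k must run before pack_v because it initializes the header and v_offset:

    #include <math.h>
    #include <stdlib.h>

    // Shapes follow the header's conventions: q is [num_heads][query_len][head_dim] BF16,
    // k and v are [num_kv_heads][kv_len][head_dim] BF16, output is F32 with q's shape.
    void run_attention(nk_bf16_t const *q, nk_bf16_t const *k, nk_bf16_t const *v,
                       nk_f32_t *output, nk_size_t num_heads, nk_size_t num_kv_heads,
                       nk_size_t query_len, nk_size_t kv_len, nk_size_t head_dim) {
        nk_size_t bytes = nk_attention_kv_packed_size_sapphireamx(num_kv_heads, head_dim, kv_len);
        void *kv_packed = aligned_alloc(64, (bytes + 63) / 64 * 64); // 64-byte aligned, as required
        nk_attention_pack_k_sapphireamx(k, kv_packed, num_kv_heads, kv_len, head_dim);
        nk_attention_pack_v_sapphireamx(v, kv_packed, num_kv_heads, kv_len, head_dim);
        nk_f32_t scale = 1.0f / sqrtf((nk_f32_t)head_dim); // the 1/√d from the formula up top
        nk_attention_bf16_sapphireamx(q, kv_packed, output, num_heads, num_kv_heads,
                                      query_len, kv_len, head_dim, scale);
        free(kv_packed);
    }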
869
+ NK_PUBLIC void nk_attention_bf16_amx_bc32_sapphireamx(nk_bf16_t const *q, void const *kv_packed, nk_f32_t *output,
870
+ nk_size_t num_heads, nk_size_t num_kv_heads, nk_size_t query_len,
871
+ nk_size_t kv_len, nk_size_t head_dim, nk_f32_t scale) {
872
+
873
+ nk_attention_kv_packed_header_t const *header = (nk_attention_kv_packed_header_t const *)kv_packed;
874
+ nk_size_t head_dim_padded = header->head_dim_padded;
875
+ nk_size_t gqa_ratio = num_heads / num_kv_heads;
876
+
877
+ // Block sizes - Bc=32 matches V tile depth granularity
878
+ nk_size_t const Br = 16;
879
+ nk_size_t const Bc = 32;
880
+
881
+ // Configure AMX tiles
882
+ nk_amx_tile_configure_sapphireamx_();
883
+
884
+ // Buffers
885
+ NK_ALIGN64 nk_f32_t scores[16 * 32]; // S [16, 32]
886
+ NK_ALIGN64 nk_f32_t weights[16 * 32]; // P [16, 32]
887
+ NK_ALIGN64 nk_f32_t o_acc[16 * 256]; // Output accumulator
888
+ NK_ALIGN64 nk_bf16_t q_tile[16][32]; // Q as A-tile format
889
+ NK_ALIGN64 nk_f32_t s_tile[16][16]; // Score tile output (for each half)
890
+ NK_ALIGN64 nk_bf16_t p_tile[16][32]; // P weights as A-tile format
891
+ NK_ALIGN64 nk_f32_t o_tile[16][16]; // Output tile from AMX
892
+
893
+ // K packing layout (16 seq positions per tile)
894
+ nk_size_t k_tiles_per_seq = nk_size_divide_round_up_(kv_len, 16);
895
+ nk_size_t tiles_per_depth = head_dim_padded / 32;
896
+ nk_size_t tile_size = 512; // BF16 elements per tile
897
+
898
+ // V packing layout (32 seq positions per tile)
899
+ nk_size_t v_tiles_per_seq = nk_size_divide_round_up_(kv_len, 32);
900
+ nk_size_t v_tiles_per_head = nk_size_divide_round_up_(head_dim_padded, 16);
901
+
902
+ nk_bf16_t const *k_packed = (nk_bf16_t const *)((char const *)kv_packed + header->k_offset);
903
+ nk_bf16_t const *v_packed = (nk_bf16_t const *)((char const *)kv_packed + header->v_offset);
904
+
+     for (nk_size_t h = 0; h < num_heads; h++) {
+         nk_size_t kv_h = h / gqa_ratio;
+         nk_bf16_t const *q_head = q + h * query_len * head_dim;
+         nk_f32_t *o_head = output + h * query_len * head_dim;
+
+         // Pointer to this KV head's packed data
+         nk_bf16_t const *k_head = k_packed + kv_h * k_tiles_per_seq * tiles_per_depth * tile_size;
+         nk_bf16_t const *v_head = v_packed + kv_h * v_tiles_per_seq * v_tiles_per_head * tile_size;
+
+         for (nk_size_t qb = 0; qb < query_len; qb += Br) {
+             nk_size_t valid_q = (qb + Br <= query_len) ? Br : (query_len - qb);
+
+             nk_attention_softmax_row_state_t softmax_state;
+             nk_attention_softmax_init_(&softmax_state);
+
+             // Zero the output accumulator using SIMD
+             __m512 zero = _mm512_setzero_ps();
+             for (nk_size_t i = 0; i < 16 * head_dim_padded; i += 64) {
+                 _mm512_store_ps(&o_acc[i], zero);
+                 _mm512_store_ps(&o_acc[i + 16], zero);
+                 _mm512_store_ps(&o_acc[i + 32], zero);
+                 _mm512_store_ps(&o_acc[i + 48], zero);
+             }
+
+             // Process KV blocks in chunks of 32
+             for (nk_size_t kvb = 0; kvb < kv_len; kvb += Bc) {
+                 nk_size_t valid_kv = (kvb + Bc <= kv_len) ? Bc : (kv_len - kvb);
+
+                 // Phase 1: S = Q × Kᵀ using AMX,
+                 // two K tiles per block (each K tile covers 16 KV columns)
+                 nk_size_t k_tile_idx0 = kvb / 16;        // First K tile
+                 nk_size_t k_tile_idx1 = (kvb + 16) / 16; // Second K tile
+
+                 _tile_zero(0); // TMM0 = score accumulator for columns 0:16
+                 _tile_zero(3); // TMM3 = score accumulator for columns 16:32
+
+                 for (nk_size_t dt = 0; dt < tiles_per_depth; dt++) {
+                     nk_size_t depth_start = dt * 32;
+
+                     // Load Q[qb:qb+16, depth_start:depth_start+32] in A-tile format,
+                     // with SIMD loads whenever a full 32-element row is available
+                     if (depth_start + 32 <= head_dim) {
+                         // Full tile - fast SIMD copy
+                         for (nk_size_t row = 0; row < valid_q; row++) {
+                             nk_bf16_t const *q_row = q_head + (qb + row) * head_dim + depth_start;
+                             // Load 32 BF16 values (64 bytes) using two 256-bit loads
+                             __m256i q0 = _mm256_loadu_si256((__m256i const *)q_row);
+                             __m256i q1 = _mm256_loadu_si256((__m256i const *)(q_row + 16));
+                             _mm256_store_si256((__m256i *)&q_tile[row][0], q0);
+                             _mm256_store_si256((__m256i *)&q_tile[row][16], q1);
+                         }
+                     }
+                     else {
+                         // Partial tile - element-by-element with zero padding
+                         nk_size_t valid_depth = head_dim - depth_start;
+                         for (nk_size_t row = 0; row < valid_q; row++) {
+                             nk_bf16_t const *q_row = q_head + (qb + row) * head_dim + depth_start;
+                             for (nk_size_t col = 0; col < 32; col++) {
+                                 q_tile[row][col] = (col < valid_depth) ? q_row[col] : 0;
+                             }
+                         }
+                     }
+                     // Zero-pad the remaining rows
+                     for (nk_size_t row = valid_q; row < 16; row++) {
+                         _mm256_store_si256((__m256i *)&q_tile[row][0], _mm256_setzero_si256());
+                         _mm256_store_si256((__m256i *)&q_tile[row][16], _mm256_setzero_si256());
+                     }
+
+                     _tile_loadd(1, q_tile, 64); // A: 16×32 BF16
+
+                     // First K tile (columns 0:16)
+                     nk_bf16_t const *k_tile_ptr0 = k_head + (k_tile_idx0 * tiles_per_depth + dt) * tile_size;
+                     _tile_loadd(2, k_tile_ptr0, 64); // B: 32×16 BF16
+                     _tile_dpbf16ps(0, 1, 2);         // TMM0 += Q × K0
+
+                     // Second K tile (columns 16:32) if within bounds
+                     if (kvb + 16 < kv_len) {
+                         nk_bf16_t const *k_tile_ptr1 = k_head + (k_tile_idx1 * tiles_per_depth + dt) * tile_size;
+                         _tile_loadd(4, k_tile_ptr1, 64); // B: 32×16 BF16
+                         _tile_dpbf16ps(3, 1, 4);         // TMM3 += Q × K1
+                     }
+                 }
+
+                 // Extract scores with SIMD copies: TMM0 here, TMM3 further below
+                 _tile_stored(0, s_tile, 64);
+
+                 __m512 neg_inf = _mm512_set1_ps(NK_F32_MIN);
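+                 // NK_F32_MIN is used as the -inf sentinel here (assuming the
+                 // macro expands to the most negative representable f32): lanes
+                 // blended to it receive ~zero weight from exp() in the softmax
+                 // update, so padded rows and columns contribute nothing.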
+                 if (valid_q == 16 && valid_kv >= 16) {
+                     // Fast path: full first half, plain copy
+                     for (nk_size_t qi = 0; qi < 16; qi++) {
+                         __m512 s0 = _mm512_load_ps(&s_tile[qi][0]);
+                         _mm512_store_ps(&scores[qi * 32], s0);
+                     }
+                 }
+                 else {
+                     // Partial block - mask out-of-range lanes
+                     // (guard the shift: 1u << 32 would be undefined behavior when valid_kv == 32)
+                     __mmask16 kv_mask = (valid_kv >= 16) ? (__mmask16)0xFFFF : (__mmask16)((1u << valid_kv) - 1);
+                     for (nk_size_t qi = 0; qi < 16; qi++) {
+                         __m512 s0 = _mm512_load_ps(&s_tile[qi][0]);
+                         if (qi < valid_q) { s0 = _mm512_mask_blend_ps(kv_mask, neg_inf, s0); }
+                         else { s0 = neg_inf; }
+                         _mm512_store_ps(&scores[qi * 32], s0);
+                     }
+                 }
+
+                 // Second-half scores (columns 16:32)
+                 if (kvb + 16 < kv_len) {
+                     _tile_stored(3, s_tile, 64);
+                     nk_size_t valid_kv2 = (valid_kv > 16) ? (valid_kv - 16) : 0;
+
+                     if (valid_q == 16 && valid_kv2 >= 16) {
+                         // Fast path
+                         for (nk_size_t qi = 0; qi < 16; qi++) {
+                             __m512 s1 = _mm512_load_ps(&s_tile[qi][0]);
+                             _mm512_store_ps(&scores[qi * 32 + 16], s1);
+                         }
+                     }
+                     else {
+                         __mmask16 kv_mask2 = (valid_kv2 >= 16) ? (__mmask16)0xFFFF : (__mmask16)((1u << valid_kv2) - 1);
+                         for (nk_size_t qi = 0; qi < 16; qi++) {
+                             __m512 s1 = _mm512_load_ps(&s_tile[qi][0]);
+                             if (qi < valid_q) { s1 = _mm512_mask_blend_ps(kv_mask2, neg_inf, s1); }
+                             else { s1 = neg_inf; }
+                             _mm512_store_ps(&scores[qi * 32 + 16], s1);
+                         }
+                     }
+                 }
+                 else {
+                     // Mask out the second half entirely
+                     for (nk_size_t qi = 0; qi < 16; qi++) { _mm512_store_ps(&scores[qi * 32 + 16], neg_inf); }
+                 }
+
+                 // Phase 2: online softmax (fast degree-4 exp)
+                 __m512 old_max = softmax_state.row_max;
+                 nk_attention_softmax_update_bc32_fast_(&softmax_state, scores, scale, weights);
+                 nk_attention_rescale_output_(o_acc, head_dim_padded, old_max, softmax_state.row_max);
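+                 // The two calls above maintain the usual flash-attention
+                 // invariant: o_acc holds sum_j exp(scale * s_j - m) * v_j for the
+                 // running row maximum m; when a new block raises m to m', the
+                 // rescale multiplies o_acc by exp(m - m') so earlier contributions
+                 // stay consistent with the new normalizer.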
+
+                 // Phase 3: O += P × V using AMX
+                 // Convert P[16, 32] from F32 to BF16 and pack as an A-tile
+                 for (nk_size_t qi = 0; qi < 16; qi++) {
+                     for (nk_size_t ki = 0; ki < 32; ki += 16) {
+                         __m512 p_f32 = _mm512_loadu_ps(&weights[qi * 32 + ki]);
+                         __m256bh p_bf16 = _mm512_cvtneps_pbh(p_f32);
+                         // Store the BF16 vector straight into the tile buffer
+                         // (the intrinsic vector types are declared may_alias)
+                         *(__m256bh *)&p_tile[qi][ki] = p_bf16;
+                     }
+                 }
+
+                 // V tile index for this block
+                 nk_size_t v_seq_tile = kvb / 32;
+
+                 // For each head_dim chunk of 16
+                 for (nk_size_t ht = 0; ht < v_tiles_per_head; ht++) {
+                     nk_size_t head_start = ht * 16;
+
+                     // V tile is already packed: V[32, 16] in B-tile format
+                     nk_bf16_t const *v_tile_ptr = v_head + (v_seq_tile * v_tiles_per_head + ht) * tile_size;
+
+                     // Zero the output tile
+                     _tile_zero(5);
+
+                     // Load P into TMM6 (A-tile: 16×32)
+                     _tile_loadd(6, p_tile, 64);
+
+                     // Load V into TMM7 (B-tile: 32×16)
+                     _tile_loadd(7, v_tile_ptr, 64);
+
+                     // O_tile = P × V
+                     _tile_dpbf16ps(5, 6, 7);
+
+                     // Store and accumulate
+                     _tile_stored(5, o_tile, 64);
+
+                     // Add to the output accumulator, unrolled over all 16 rows.
+                     // Even if valid_q < 16 we accumulate every row: padded rows carry zero weights.
+                     for (nk_size_t qi = 0; qi < 16; qi += 4) {
+                         __m512 acc0 = _mm512_load_ps(&o_acc[(qi + 0) * head_dim_padded + head_start]);
+                         __m512 acc1 = _mm512_load_ps(&o_acc[(qi + 1) * head_dim_padded + head_start]);
+                         __m512 acc2 = _mm512_load_ps(&o_acc[(qi + 2) * head_dim_padded + head_start]);
+                         __m512 acc3 = _mm512_load_ps(&o_acc[(qi + 3) * head_dim_padded + head_start]);
+
+                         acc0 = _mm512_add_ps(acc0, _mm512_load_ps(&o_tile[qi + 0][0]));
+                         acc1 = _mm512_add_ps(acc1, _mm512_load_ps(&o_tile[qi + 1][0]));
+                         acc2 = _mm512_add_ps(acc2, _mm512_load_ps(&o_tile[qi + 2][0]));
+                         acc3 = _mm512_add_ps(acc3, _mm512_load_ps(&o_tile[qi + 3][0]));
+
+                         _mm512_store_ps(&o_acc[(qi + 0) * head_dim_padded + head_start], acc0);
+                         _mm512_store_ps(&o_acc[(qi + 1) * head_dim_padded + head_start], acc1);
+                         _mm512_store_ps(&o_acc[(qi + 2) * head_dim_padded + head_start], acc2);
+                         _mm512_store_ps(&o_acc[(qi + 3) * head_dim_padded + head_start], acc3);
+                     }
+                 }
+             }
+
+             // Finalize: normalize O by row sums
+             // (aligned buffer: _mm512_store_ps requires a 64-byte-aligned address)
+             NK_ALIGN64 nk_f32_t row_sums[16];
+             _mm512_store_ps(row_sums, softmax_state.row_sum);
+             for (nk_size_t qi = 0; qi < valid_q; qi++) {
+                 nk_f32_t inv_sum = 1.0f / row_sums[qi];
+                 for (nk_size_t d = 0; d < head_dim; d++) {
+                     o_head[(qb + qi) * head_dim + d] = o_acc[qi * head_dim_padded + d] * inv_sum;
+                 }
+             }
+         }
+     }
+ }
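+
+ // A minimal call sketch (illustration only, not shipped code; the sizes below
+ // are hypothetical): the kernel assumes `kv_packed` begins with an
+ // nk_attention_kv_packed_header_t and that the K/V tiles follow the layout
+ // described above.
+ #if 0
+ void example_bc32_call(nk_bf16_t const *q, void const *kv_cache_packed, nk_f32_t *out) {
+     nk_size_t const num_heads = 32, num_kv_heads = 8; // GQA ratio 4
+     nk_size_t const query_len = 16, kv_len = 4096, head_dim = 128;
+     nk_f32_t const scale = 0.0883883476f; // 1 / sqrt(head_dim)
+     nk_attention_bf16_amx_bc32_sapphireamx(q, kv_cache_packed, out, num_heads, num_kv_heads,
+                                            query_len, kv_len, head_dim, scale);
+ }
+ #endif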
+
+ NK_PUBLIC void nk_attention_bf16_amx_optimized_sapphireamx(nk_bf16_t const *q, void const *kv_packed, nk_f32_t *output,
+                                                            nk_size_t num_heads, nk_size_t num_kv_heads,
+                                                            nk_size_t query_len, nk_size_t kv_len, nk_size_t head_dim,
+                                                            nk_f32_t scale) {
+
+     nk_attention_kv_packed_header_t const *header = (nk_attention_kv_packed_header_t const *)kv_packed;
+     nk_size_t head_dim_padded = header->head_dim_padded;
+     nk_size_t gqa_ratio = num_heads / num_kv_heads;
+
+     nk_size_t const Br = 16;
+     nk_size_t const Bc = 32;
+
+     // Configure AMX tiles once
+     nk_amx_tile_configure_sapphireamx_();
+
+     // Tile dimensions
+     nk_size_t tiles_per_depth = head_dim_padded / 32;                           // 4 for d=128
+     nk_size_t v_tiles_per_head = nk_size_divide_round_up_(head_dim_padded, 16); // 8 for d=128
+
+     // K packing layout (16 sequence positions per tile)
+     nk_size_t k_tiles_per_seq = nk_size_divide_round_up_(kv_len, 16);
+     nk_size_t tile_size = 512;
+
+     // V packing layout (32 sequence positions per tile)
+     nk_size_t v_tiles_per_seq = nk_size_divide_round_up_(kv_len, 32);
+
+     nk_bf16_t const *k_packed = (nk_bf16_t const *)((char const *)kv_packed + header->k_offset);
+     nk_bf16_t const *v_packed = (nk_bf16_t const *)((char const *)kv_packed + header->v_offset);
+
+     // Pre-allocated buffers (all L1-resident)
+     NK_ALIGN64 nk_bf16_t q_tiles[8][16][32]; // Q tiles for all depth chunks (8 covers head_dim_padded up to 256, matching o_acc)
+     NK_ALIGN64 nk_f32_t scores[16][32];      // Score buffer (direct tile-store target)
+     NK_ALIGN64 nk_f32_t weights[16][32];     // Softmax output
+     NK_ALIGN64 nk_bf16_t p_tile[16][32];     // P weights in BF16
+     NK_ALIGN64 nk_f32_t o_tile[16][16];      // Output tile buffer
+     NK_ALIGN64 nk_f32_t o_acc[16][256];      // Output accumulator (max d=256)
+
+     __m512 neg_inf = _mm512_set1_ps(NK_F32_MIN);
+
+     for (nk_size_t h = 0; h < num_heads; h++) {
+         nk_size_t kv_h = h / gqa_ratio;
+         nk_bf16_t const *q_head = q + h * query_len * head_dim;
+         nk_f32_t *o_head = output + h * query_len * head_dim;
+
+         nk_bf16_t const *k_head = k_packed + kv_h * k_tiles_per_seq * tiles_per_depth * tile_size;
+         nk_bf16_t const *v_head = v_packed + kv_h * v_tiles_per_seq * v_tiles_per_head * tile_size;
+
+         for (nk_size_t qb = 0; qb < query_len; qb += Br) {
+             nk_size_t valid_q = (qb + Br <= query_len) ? Br : (query_len - qb);
+
+             // Pre-pack Q tiles once for all KV blocks
+             for (nk_size_t dt = 0; dt < tiles_per_depth; dt++) {
+                 nk_size_t depth_start = dt * 32;
+                 if (depth_start + 32 <= head_dim) {
+                     // Full tile - fast SIMD copy
+                     for (nk_size_t row = 0; row < valid_q; row++) {
+                         nk_bf16_t const *q_row = q_head + (qb + row) * head_dim + depth_start;
+                         __m256i q0 = _mm256_loadu_si256((__m256i const *)q_row);
+                         __m256i q1 = _mm256_loadu_si256((__m256i const *)(q_row + 16));
+                         _mm256_store_si256((__m256i *)&q_tiles[dt][row][0], q0);
+                         _mm256_store_si256((__m256i *)&q_tiles[dt][row][16], q1);
+                     }
+                     // Zero the remaining rows
+                     for (nk_size_t row = valid_q; row < 16; row++) {
+                         _mm256_store_si256((__m256i *)&q_tiles[dt][row][0], _mm256_setzero_si256());
+                         _mm256_store_si256((__m256i *)&q_tiles[dt][row][16], _mm256_setzero_si256());
+                     }
+                 }
+                 else {
+                     // Partial tile with zero padding
+                     nk_size_t valid_depth = head_dim - depth_start;
+                     for (nk_size_t row = 0; row < 16; row++) {
+                         for (nk_size_t col = 0; col < 32; col++) {
+                             if (row < valid_q && col < valid_depth) {
+                                 q_tiles[dt][row][col] = q_head[(qb + row) * head_dim + depth_start + col];
+                             }
+                             else { q_tiles[dt][row][col] = 0; }
+                         }
+                     }
+                 }
+             }
+
+             // Initialize the softmax state and output accumulator
+             nk_attention_softmax_row_state_t softmax_state;
+             nk_attention_softmax_init_(&softmax_state);
+
+             __m512 zero = _mm512_setzero_ps();
+             for (nk_size_t i = 0; i < 16 * head_dim_padded; i += 64) {
+                 _mm512_store_ps(&o_acc[0][i], zero);
+                 _mm512_store_ps(&o_acc[0][i + 16], zero);
+                 _mm512_store_ps(&o_acc[0][i + 32], zero);
+                 _mm512_store_ps(&o_acc[0][i + 48], zero);
+             }
+
+             // Process KV blocks
+             for (nk_size_t kvb = 0; kvb < kv_len; kvb += Bc) {
+                 nk_size_t valid_kv = (kvb + Bc <= kv_len) ? Bc : (kv_len - kvb);
+                 nk_size_t k_tile_idx0 = kvb / 16;
+                 nk_size_t k_tile_idx1 = (kvb + 16) / 16;
+
+                 // Phase 1: S = Q × Kᵀ using pre-packed Q tiles
+                 _tile_zero(0); // Score columns 0:16
+                 _tile_zero(3); // Score columns 16:32
+
+                 for (nk_size_t dt = 0; dt < tiles_per_depth; dt++) {
+                     // Load the pre-packed Q tile (already L1-resident, not from global memory)
+                     _tile_loadd(1, q_tiles[dt], 64);
+
+                     // K tiles must still come from the packed buffer
+                     nk_bf16_t const *k_tile_ptr0 = k_head + (k_tile_idx0 * tiles_per_depth + dt) * tile_size;
+                     _tile_loadd(2, k_tile_ptr0, 64);
+                     _tile_dpbf16ps(0, 1, 2);
+
+                     if (kvb + 16 < kv_len) {
+                         nk_bf16_t const *k_tile_ptr1 = k_head + (k_tile_idx1 * tiles_per_depth + dt) * tile_size;
+                         _tile_loadd(4, k_tile_ptr1, 64);
+                         _tile_dpbf16ps(3, 1, 4);
+                     }
+                 }
+
+                 // Store the first 16 columns directly into scores[:, 0:16]
+                 _tile_stored(0, &scores[0][0], 128); // stride = 128 bytes (32 floats)
+
+                 // Store the second 16 columns into scores[:, 16:32]
+                 if (kvb + 16 < kv_len) { _tile_stored(3, &scores[0][16], 128); }
+                 else {
+                     // Mask out the second half
+                     for (nk_size_t qi = 0; qi < 16; qi++) { _mm512_store_ps(&scores[qi][16], neg_inf); }
+                 }
+
+                 // Apply masking for invalid positions (only at block boundaries)
+                 if (valid_q < 16 || valid_kv < 32) {
+                     __mmask16 kv_mask0 = (valid_kv >= 16) ? (__mmask16)0xFFFF : (__mmask16)((1u << valid_kv) - 1);
+                     __mmask16 kv_mask1 = (valid_kv > 16) ? (__mmask16)((1u << (valid_kv - 16)) - 1) : (__mmask16)0;
+                     if (valid_kv >= 32) kv_mask1 = 0xFFFF;
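+                     // Worked example: valid_kv = 20 gives kv_mask0 = 0xFFFF and
+                     // kv_mask1 = (1u << 4) - 1 = 0x000F, so KV columns 20..31
+                     // fall back to the -inf sentinel in the loop below.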
+
+                     for (nk_size_t qi = 0; qi < 16; qi++) {
+                         if (qi >= valid_q) {
+                             _mm512_store_ps(&scores[qi][0], neg_inf);
+                             _mm512_store_ps(&scores[qi][16], neg_inf);
+                         }
+                         else {
+                             __m512 s0 = _mm512_load_ps(&scores[qi][0]);
+                             __m512 s1 = _mm512_load_ps(&scores[qi][16]);
+                             _mm512_store_ps(&scores[qi][0], _mm512_mask_blend_ps(kv_mask0, neg_inf, s0));
+                             _mm512_store_ps(&scores[qi][16], _mm512_mask_blend_ps(kv_mask1, neg_inf, s1));
+                         }
+                     }
+                 }
+
+                 // Phase 2: online softmax (fast degree-4 exp)
+                 __m512 old_max = softmax_state.row_max;
+                 nk_attention_softmax_update_bc32_fast_(&softmax_state, &scores[0][0], scale, &weights[0][0]);
+                 nk_attention_rescale_output_(&o_acc[0][0], head_dim_padded, old_max, softmax_state.row_max);
+
+                 // Phase 3: O += P × V with a hoisted P tile load
+                 // Convert the F32 weights to a BF16 P tile (once per KV block)
+                 for (nk_size_t qi = 0; qi < 16; qi++) {
+                     __m512 p0 = _mm512_load_ps(&weights[qi][0]);
+                     __m512 p1 = _mm512_load_ps(&weights[qi][16]);
+                     __m256bh pb0 = _mm512_cvtneps_pbh(p0);
+                     __m256bh pb1 = _mm512_cvtneps_pbh(p1);
+                     *(__m256bh *)&p_tile[qi][0] = pb0;
+                     *(__m256bh *)&p_tile[qi][16] = pb1;
+                 }
+
+                 // Load the P tile once and reuse it for all V tiles
+                 _tile_loadd(6, p_tile, 64);
+
+                 nk_size_t v_seq_tile = kvb / 32;
+
+                 for (nk_size_t ht = 0; ht < v_tiles_per_head; ht++) {
+                     nk_size_t head_start = ht * 16;
+
+                     // Load the V tile from the packed buffer
+                     nk_bf16_t const *v_tile_ptr = v_head + (v_seq_tile * v_tiles_per_head + ht) * tile_size;
+
+                     _tile_zero(5);
+                     // P is already in TMM6 - no reload needed
+                     _tile_loadd(7, v_tile_ptr, 64);
+                     _tile_dpbf16ps(5, 6, 7);
+
+                     // Store and accumulate
+                     _tile_stored(5, o_tile, 64);
+
+                     // Accumulate into the output (unrolled)
+                     for (nk_size_t qi = 0; qi < 16; qi += 4) {
+                         __m512 acc0 = _mm512_load_ps(&o_acc[qi + 0][head_start]);
+                         __m512 acc1 = _mm512_load_ps(&o_acc[qi + 1][head_start]);
+                         __m512 acc2 = _mm512_load_ps(&o_acc[qi + 2][head_start]);
+                         __m512 acc3 = _mm512_load_ps(&o_acc[qi + 3][head_start]);
+
+                         acc0 = _mm512_add_ps(acc0, _mm512_load_ps(&o_tile[qi + 0][0]));
+                         acc1 = _mm512_add_ps(acc1, _mm512_load_ps(&o_tile[qi + 1][0]));
+                         acc2 = _mm512_add_ps(acc2, _mm512_load_ps(&o_tile[qi + 2][0]));
+                         acc3 = _mm512_add_ps(acc3, _mm512_load_ps(&o_tile[qi + 3][0]));
+
+                         _mm512_store_ps(&o_acc[qi + 0][head_start], acc0);
+                         _mm512_store_ps(&o_acc[qi + 1][head_start], acc1);
+                         _mm512_store_ps(&o_acc[qi + 2][head_start], acc2);
+                         _mm512_store_ps(&o_acc[qi + 3][head_start], acc3);
+                     }
+                 }
+             }
+
+             // Finalize: normalize O by row sums
+             // (aligned buffer: _mm512_store_ps requires a 64-byte-aligned address)
+             NK_ALIGN64 nk_f32_t row_sums[16];
+             _mm512_store_ps(row_sums, softmax_state.row_sum);
+             for (nk_size_t qi = 0; qi < valid_q; qi++) {
+                 __m512 inv_sum = _mm512_set1_ps(1.0f / row_sums[qi]);
+                 // Assumes head_dim is a multiple of 16, so each 16-lane store stays within the row
+                 for (nk_size_t d = 0; d < head_dim; d += 16) {
+                     __m512 o = _mm512_load_ps(&o_acc[qi][d]);
+                     o = _mm512_mul_ps(o, inv_sum);
+                     _mm512_storeu_ps(&o_head[(qb + qi) * head_dim + d], o);
+                 }
+             }
+         }
+     }
+ }
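+
+ // Relative to the bc32 variant above, this kernel hoists three costs out of
+ // the inner loops: Q tiles are packed once per query block rather than once
+ // per KV block, scores are tile-stored straight into the 32-column buffer
+ // with a 128-byte stride instead of being staged through s_tile, and the P
+ // tile is loaded into TMM6 once per KV block rather than once per V tile.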
+
+ NK_PUBLIC void nk_attention_causal_bf16_sapphireamx(nk_bf16_t const *q, void const *kv_packed, nk_f32_t *output,
+                                                     nk_size_t num_heads, nk_size_t num_kv_heads, nk_size_t query_len,
+                                                     nk_size_t kv_len, nk_size_t head_dim, nk_f32_t scale) {
+
+     // Causal attention in autoregressive decoding: query position q_pos may
+     // only attend to KV positions 0..q_pos. If kv_len == query_len (prefill),
+     // proper masking is required; if query_len == 1 (decode), the single query
+     // sees all of KV, so the unmasked fallback below is exact only in that case.
+
+     // Simplified: fall back to full attention for now
+     // TODO: Implement proper causal masking with block skipping
+     nk_attention_bf16_sapphireamx(q, kv_packed, output, num_heads, num_kv_heads, query_len, kv_len, head_dim, scale);
+ }
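+
+ // One possible shape for the TODO above (a sketch, not shipped code): with the
+ // diagonal offset kv_len - query_len, a query block can skip every KV block
+ // that starts past its last visible position, and only blocks straddling the
+ // diagonal need per-element masks.
+ #if 0
+ nk_size_t const offset = kv_len - query_len; // global position of query row 0
+ for (nk_size_t qb = 0; qb < query_len; qb += Br) {
+     nk_size_t q_end = qb + Br - 1 + offset; // last KV position visible to this block
+     for (nk_size_t kvb = 0; kvb < kv_len && kvb <= q_end; kvb += Bc) {
+         // Blocks entirely below the diagonal (kvb + Bc - 1 <= qb + offset) need
+         // no mask; a straddling block masks row qi to columns <= qb + qi + offset.
+     }
+ }
+ #endif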
+
+ #if defined(__clang__)
+ #pragma clang attribute pop
+ #elif defined(__GNUC__)
+ #pragma GCC pop_options
+ #endif
+
+ #if defined(__cplusplus)
+ } // extern "C"
+ #endif
+
+ #endif // NK_TARGET_SAPPHIREAMX
+ #endif // NK_TARGET_X86_
+ #endif // NK_ATTENTION_SAPPHIREAMX_H