numkong 7.0.0 → 7.4.2

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (315) hide show
  1. package/README.md +197 -124
  2. package/binding.gyp +34 -484
  3. package/c/dispatch_bf16.c +59 -1
  4. package/c/dispatch_e2m3.c +41 -8
  5. package/c/dispatch_e3m2.c +49 -8
  6. package/c/dispatch_e4m3.c +51 -9
  7. package/c/dispatch_e5m2.c +45 -1
  8. package/c/dispatch_f16.c +79 -26
  9. package/c/dispatch_f16c.c +5 -5
  10. package/c/dispatch_f32.c +56 -0
  11. package/c/dispatch_f64.c +52 -0
  12. package/c/dispatch_i4.c +3 -0
  13. package/c/dispatch_i8.c +62 -3
  14. package/c/dispatch_other.c +18 -0
  15. package/c/dispatch_u1.c +54 -9
  16. package/c/dispatch_u4.c +3 -0
  17. package/c/dispatch_u8.c +64 -3
  18. package/c/numkong.c +3 -0
  19. package/include/README.md +79 -9
  20. package/include/numkong/attention/sapphireamx.h +278 -276
  21. package/include/numkong/attention/sme.h +983 -977
  22. package/include/numkong/attention.h +1 -1
  23. package/include/numkong/capabilities.h +289 -94
  24. package/include/numkong/cast/README.md +40 -40
  25. package/include/numkong/cast/diamond.h +64 -0
  26. package/include/numkong/cast/haswell.h +42 -194
  27. package/include/numkong/cast/icelake.h +42 -37
  28. package/include/numkong/cast/loongsonasx.h +252 -0
  29. package/include/numkong/cast/neon.h +216 -249
  30. package/include/numkong/cast/powervsx.h +449 -0
  31. package/include/numkong/cast/rvv.h +223 -274
  32. package/include/numkong/cast/sapphire.h +18 -18
  33. package/include/numkong/cast/serial.h +1018 -944
  34. package/include/numkong/cast/skylake.h +82 -23
  35. package/include/numkong/cast/v128relaxed.h +462 -105
  36. package/include/numkong/cast.h +24 -0
  37. package/include/numkong/cast.hpp +44 -0
  38. package/include/numkong/curved/README.md +17 -17
  39. package/include/numkong/curved/neon.h +131 -7
  40. package/include/numkong/curved/neonbfdot.h +6 -7
  41. package/include/numkong/curved/rvv.h +26 -26
  42. package/include/numkong/curved/smef64.h +186 -182
  43. package/include/numkong/curved.h +14 -18
  44. package/include/numkong/dot/README.md +154 -137
  45. package/include/numkong/dot/alder.h +43 -43
  46. package/include/numkong/dot/diamond.h +158 -0
  47. package/include/numkong/dot/genoa.h +4 -30
  48. package/include/numkong/dot/haswell.h +215 -180
  49. package/include/numkong/dot/icelake.h +190 -76
  50. package/include/numkong/dot/loongsonasx.h +671 -0
  51. package/include/numkong/dot/neon.h +124 -73
  52. package/include/numkong/dot/neonbfdot.h +11 -12
  53. package/include/numkong/dot/neonfhm.h +44 -46
  54. package/include/numkong/dot/neonfp8.h +323 -0
  55. package/include/numkong/dot/neonsdot.h +190 -76
  56. package/include/numkong/dot/powervsx.h +752 -0
  57. package/include/numkong/dot/rvv.h +92 -84
  58. package/include/numkong/dot/rvvbf16.h +12 -12
  59. package/include/numkong/dot/rvvhalf.h +12 -12
  60. package/include/numkong/dot/sapphire.h +4 -4
  61. package/include/numkong/dot/serial.h +66 -30
  62. package/include/numkong/dot/sierra.h +31 -31
  63. package/include/numkong/dot/skylake.h +142 -110
  64. package/include/numkong/dot/sve.h +217 -177
  65. package/include/numkong/dot/svebfdot.h +10 -10
  66. package/include/numkong/dot/svehalf.h +85 -41
  67. package/include/numkong/dot/svesdot.h +89 -0
  68. package/include/numkong/dot/v128relaxed.h +124 -89
  69. package/include/numkong/dot.h +114 -48
  70. package/include/numkong/dots/README.md +203 -203
  71. package/include/numkong/dots/alder.h +12 -9
  72. package/include/numkong/dots/diamond.h +86 -0
  73. package/include/numkong/dots/genoa.h +10 -4
  74. package/include/numkong/dots/haswell.h +63 -48
  75. package/include/numkong/dots/icelake.h +27 -18
  76. package/include/numkong/dots/loongsonasx.h +176 -0
  77. package/include/numkong/dots/neon.h +14 -11
  78. package/include/numkong/dots/neonbfdot.h +4 -3
  79. package/include/numkong/dots/neonfhm.h +11 -9
  80. package/include/numkong/dots/neonfp8.h +99 -0
  81. package/include/numkong/dots/neonsdot.h +48 -12
  82. package/include/numkong/dots/powervsx.h +194 -0
  83. package/include/numkong/dots/rvv.h +451 -344
  84. package/include/numkong/dots/sapphireamx.h +1028 -984
  85. package/include/numkong/dots/serial.h +213 -197
  86. package/include/numkong/dots/sierra.h +10 -7
  87. package/include/numkong/dots/skylake.h +47 -36
  88. package/include/numkong/dots/sme.h +2001 -2364
  89. package/include/numkong/dots/smebi32.h +175 -162
  90. package/include/numkong/dots/smef64.h +328 -323
  91. package/include/numkong/dots/v128relaxed.h +64 -41
  92. package/include/numkong/dots.h +573 -293
  93. package/include/numkong/dots.hpp +45 -43
  94. package/include/numkong/each/README.md +133 -137
  95. package/include/numkong/each/haswell.h +6 -6
  96. package/include/numkong/each/icelake.h +7 -7
  97. package/include/numkong/each/neon.h +76 -42
  98. package/include/numkong/each/neonbfdot.h +11 -12
  99. package/include/numkong/each/neonhalf.h +24 -116
  100. package/include/numkong/each/rvv.h +28 -28
  101. package/include/numkong/each/sapphire.h +27 -161
  102. package/include/numkong/each/serial.h +6 -6
  103. package/include/numkong/each/skylake.h +7 -7
  104. package/include/numkong/each/v128relaxed.h +562 -0
  105. package/include/numkong/each.h +148 -62
  106. package/include/numkong/each.hpp +2 -2
  107. package/include/numkong/geospatial/README.md +18 -18
  108. package/include/numkong/geospatial/haswell.h +365 -325
  109. package/include/numkong/geospatial/neon.h +350 -306
  110. package/include/numkong/geospatial/rvv.h +4 -4
  111. package/include/numkong/geospatial/skylake.h +376 -340
  112. package/include/numkong/geospatial/v128relaxed.h +366 -327
  113. package/include/numkong/geospatial.h +17 -17
  114. package/include/numkong/matrix.hpp +4 -4
  115. package/include/numkong/maxsim/README.md +14 -14
  116. package/include/numkong/maxsim/alder.h +6 -6
  117. package/include/numkong/maxsim/genoa.h +4 -4
  118. package/include/numkong/maxsim/haswell.h +6 -6
  119. package/include/numkong/maxsim/icelake.h +18 -18
  120. package/include/numkong/maxsim/neonsdot.h +21 -21
  121. package/include/numkong/maxsim/sapphireamx.h +14 -14
  122. package/include/numkong/maxsim/serial.h +6 -6
  123. package/include/numkong/maxsim/sme.h +221 -196
  124. package/include/numkong/maxsim/v128relaxed.h +6 -6
  125. package/include/numkong/mesh/README.md +62 -56
  126. package/include/numkong/mesh/haswell.h +339 -464
  127. package/include/numkong/mesh/neon.h +1100 -519
  128. package/include/numkong/mesh/neonbfdot.h +36 -68
  129. package/include/numkong/mesh/rvv.h +530 -435
  130. package/include/numkong/mesh/serial.h +75 -91
  131. package/include/numkong/mesh/skylake.h +1627 -302
  132. package/include/numkong/mesh/v128relaxed.h +443 -330
  133. package/include/numkong/mesh.h +63 -49
  134. package/include/numkong/mesh.hpp +4 -4
  135. package/include/numkong/numkong.h +3 -3
  136. package/include/numkong/numkong.hpp +1 -0
  137. package/include/numkong/probability/README.md +23 -19
  138. package/include/numkong/probability/neon.h +82 -52
  139. package/include/numkong/probability/rvv.h +28 -23
  140. package/include/numkong/probability/serial.h +51 -39
  141. package/include/numkong/probability.h +20 -23
  142. package/include/numkong/random.h +1 -1
  143. package/include/numkong/reduce/README.md +143 -138
  144. package/include/numkong/reduce/alder.h +81 -77
  145. package/include/numkong/reduce/haswell.h +222 -220
  146. package/include/numkong/reduce/neon.h +629 -519
  147. package/include/numkong/reduce/neonbfdot.h +7 -218
  148. package/include/numkong/reduce/neonfhm.h +9 -381
  149. package/include/numkong/reduce/neonsdot.h +9 -9
  150. package/include/numkong/reduce/rvv.h +928 -802
  151. package/include/numkong/reduce/serial.h +23 -27
  152. package/include/numkong/reduce/sierra.h +20 -20
  153. package/include/numkong/reduce/skylake.h +326 -324
  154. package/include/numkong/reduce/v128relaxed.h +52 -52
  155. package/include/numkong/reduce.h +4 -23
  156. package/include/numkong/reduce.hpp +156 -11
  157. package/include/numkong/scalar/README.md +6 -6
  158. package/include/numkong/scalar/haswell.h +26 -17
  159. package/include/numkong/scalar/loongsonasx.h +74 -0
  160. package/include/numkong/scalar/neon.h +9 -9
  161. package/include/numkong/scalar/powervsx.h +96 -0
  162. package/include/numkong/scalar/rvv.h +2 -2
  163. package/include/numkong/scalar/sapphire.h +21 -10
  164. package/include/numkong/scalar/serial.h +21 -21
  165. package/include/numkong/scalar.h +13 -0
  166. package/include/numkong/set/README.md +28 -28
  167. package/include/numkong/set/haswell.h +12 -12
  168. package/include/numkong/set/icelake.h +14 -14
  169. package/include/numkong/set/loongsonasx.h +181 -0
  170. package/include/numkong/set/neon.h +17 -18
  171. package/include/numkong/set/powervsx.h +326 -0
  172. package/include/numkong/set/rvv.h +4 -4
  173. package/include/numkong/set/serial.h +6 -6
  174. package/include/numkong/set/sve.h +60 -59
  175. package/include/numkong/set/v128relaxed.h +6 -6
  176. package/include/numkong/set.h +21 -7
  177. package/include/numkong/sets/README.md +26 -26
  178. package/include/numkong/sets/loongsonasx.h +52 -0
  179. package/include/numkong/sets/powervsx.h +65 -0
  180. package/include/numkong/sets/smebi32.h +395 -364
  181. package/include/numkong/sets.h +83 -40
  182. package/include/numkong/sparse/README.md +4 -4
  183. package/include/numkong/sparse/icelake.h +101 -101
  184. package/include/numkong/sparse/serial.h +1 -1
  185. package/include/numkong/sparse/sve2.h +137 -141
  186. package/include/numkong/sparse/turin.h +12 -12
  187. package/include/numkong/sparse.h +10 -10
  188. package/include/numkong/spatial/README.md +230 -226
  189. package/include/numkong/spatial/alder.h +113 -116
  190. package/include/numkong/spatial/diamond.h +240 -0
  191. package/include/numkong/spatial/genoa.h +0 -68
  192. package/include/numkong/spatial/haswell.h +74 -55
  193. package/include/numkong/spatial/icelake.h +539 -58
  194. package/include/numkong/spatial/loongsonasx.h +483 -0
  195. package/include/numkong/spatial/neon.h +125 -52
  196. package/include/numkong/spatial/neonbfdot.h +8 -9
  197. package/include/numkong/spatial/neonfp8.h +258 -0
  198. package/include/numkong/spatial/neonsdot.h +180 -12
  199. package/include/numkong/spatial/powervsx.h +738 -0
  200. package/include/numkong/spatial/rvv.h +146 -139
  201. package/include/numkong/spatial/rvvbf16.h +17 -12
  202. package/include/numkong/spatial/rvvhalf.h +13 -10
  203. package/include/numkong/spatial/serial.h +13 -12
  204. package/include/numkong/spatial/sierra.h +232 -39
  205. package/include/numkong/spatial/skylake.h +73 -74
  206. package/include/numkong/spatial/sve.h +93 -72
  207. package/include/numkong/spatial/svebfdot.h +29 -29
  208. package/include/numkong/spatial/svehalf.h +52 -26
  209. package/include/numkong/spatial/svesdot.h +142 -0
  210. package/include/numkong/spatial/v128relaxed.h +293 -41
  211. package/include/numkong/spatial.h +338 -82
  212. package/include/numkong/spatials/README.md +194 -194
  213. package/include/numkong/spatials/diamond.h +82 -0
  214. package/include/numkong/spatials/haswell.h +2 -2
  215. package/include/numkong/spatials/loongsonasx.h +153 -0
  216. package/include/numkong/spatials/neonfp8.h +111 -0
  217. package/include/numkong/spatials/neonsdot.h +34 -0
  218. package/include/numkong/spatials/powervsx.h +153 -0
  219. package/include/numkong/spatials/rvv.h +259 -243
  220. package/include/numkong/spatials/sapphireamx.h +173 -173
  221. package/include/numkong/spatials/serial.h +2 -2
  222. package/include/numkong/spatials/skylake.h +2 -2
  223. package/include/numkong/spatials/sme.h +590 -605
  224. package/include/numkong/spatials/smef64.h +139 -130
  225. package/include/numkong/spatials/v128relaxed.h +2 -2
  226. package/include/numkong/spatials.h +820 -500
  227. package/include/numkong/spatials.hpp +49 -48
  228. package/include/numkong/tensor.hpp +406 -17
  229. package/include/numkong/trigonometry/README.md +19 -19
  230. package/include/numkong/trigonometry/haswell.h +402 -401
  231. package/include/numkong/trigonometry/neon.h +386 -387
  232. package/include/numkong/trigonometry/rvv.h +52 -51
  233. package/include/numkong/trigonometry/serial.h +13 -13
  234. package/include/numkong/trigonometry/skylake.h +373 -369
  235. package/include/numkong/trigonometry/v128relaxed.h +375 -374
  236. package/include/numkong/trigonometry.h +13 -13
  237. package/include/numkong/trigonometry.hpp +2 -2
  238. package/include/numkong/types.h +287 -49
  239. package/include/numkong/types.hpp +436 -12
  240. package/include/numkong/vector.hpp +82 -14
  241. package/javascript/dist/cjs/numkong-wasm.js +6 -12
  242. package/javascript/dist/cjs/numkong.d.ts +7 -1
  243. package/javascript/dist/cjs/numkong.js +37 -11
  244. package/javascript/dist/cjs/types.d.ts +9 -0
  245. package/javascript/dist/cjs/types.js +96 -0
  246. package/javascript/dist/esm/numkong-browser.d.ts +14 -0
  247. package/javascript/dist/esm/numkong-browser.js +23 -0
  248. package/javascript/dist/esm/numkong-wasm.js +6 -12
  249. package/javascript/dist/esm/numkong.d.ts +7 -1
  250. package/javascript/dist/esm/numkong.js +37 -11
  251. package/javascript/dist/esm/types.d.ts +9 -0
  252. package/javascript/dist/esm/types.js +96 -0
  253. package/javascript/node-gyp-build.d.ts +4 -1
  254. package/javascript/numkong-browser.ts +40 -0
  255. package/javascript/numkong-wasm.ts +7 -13
  256. package/javascript/numkong.c +5 -26
  257. package/javascript/numkong.ts +36 -11
  258. package/javascript/tsconfig-base.json +1 -0
  259. package/javascript/tsconfig-cjs.json +6 -1
  260. package/javascript/types.ts +110 -0
  261. package/numkong.gypi +101 -0
  262. package/package.json +34 -13
  263. package/probes/arm_neon.c +8 -0
  264. package/probes/arm_neon_bfdot.c +9 -0
  265. package/probes/arm_neon_fhm.c +9 -0
  266. package/probes/arm_neon_half.c +8 -0
  267. package/probes/arm_neon_sdot.c +9 -0
  268. package/probes/arm_neonfp8.c +9 -0
  269. package/probes/arm_sme.c +16 -0
  270. package/probes/arm_sme2.c +16 -0
  271. package/probes/arm_sme2p1.c +16 -0
  272. package/probes/arm_sme_bf16.c +16 -0
  273. package/probes/arm_sme_bi32.c +16 -0
  274. package/probes/arm_sme_f64.c +16 -0
  275. package/probes/arm_sme_fa64.c +14 -0
  276. package/probes/arm_sme_half.c +16 -0
  277. package/probes/arm_sme_lut2.c +15 -0
  278. package/probes/arm_sve.c +18 -0
  279. package/probes/arm_sve2.c +20 -0
  280. package/probes/arm_sve2p1.c +18 -0
  281. package/probes/arm_sve_bfdot.c +20 -0
  282. package/probes/arm_sve_half.c +18 -0
  283. package/probes/arm_sve_sdot.c +21 -0
  284. package/probes/loongarch_lasx.c +12 -0
  285. package/probes/power_vsx.c +12 -0
  286. package/probes/probe.js +127 -0
  287. package/probes/riscv_rvv.c +14 -0
  288. package/probes/riscv_rvv_bb.c +15 -0
  289. package/probes/riscv_rvv_bf16.c +17 -0
  290. package/probes/riscv_rvv_half.c +14 -0
  291. package/probes/wasm_v128relaxed.c +11 -0
  292. package/probes/x86_alder.c +17 -0
  293. package/probes/x86_diamond.c +17 -0
  294. package/probes/x86_genoa.c +17 -0
  295. package/probes/x86_graniteamx.c +19 -0
  296. package/probes/x86_haswell.c +11 -0
  297. package/probes/x86_icelake.c +17 -0
  298. package/probes/x86_sapphire.c +16 -0
  299. package/probes/x86_sapphireamx.c +18 -0
  300. package/probes/x86_sierra.c +17 -0
  301. package/probes/x86_skylake.c +15 -0
  302. package/probes/x86_turin.c +17 -0
  303. package/wasm/numkong-emscripten.js +2 -0
  304. package/wasm/numkong.d.ts +14 -0
  305. package/wasm/numkong.js +1124 -0
  306. package/wasm/numkong.wasm +0 -0
  307. package/include/numkong/curved/neonhalf.h +0 -212
  308. package/include/numkong/dot/neonhalf.h +0 -198
  309. package/include/numkong/dots/neonhalf.h +0 -57
  310. package/include/numkong/mesh/neonhalf.h +0 -616
  311. package/include/numkong/reduce/neonhalf.h +0 -157
  312. package/include/numkong/spatial/neonhalf.h +0 -118
  313. package/include/numkong/spatial/sapphire.h +0 -343
  314. package/include/numkong/spatials/neonhalf.h +0 -58
  315. package/javascript/README.md +0 -246
@@ -120,43 +120,43 @@ typedef struct {
120
120
  * @param x Input vector (16 floats)
121
121
  * @return exp(x) for each element
122
122
  */
123
- NK_INTERNAL __m512 nk_exp_ps_avx512_(__m512 x) {
123
+ NK_INTERNAL __m512 nk_exp_ps_avx512_(__m512 x_f32x16) {
124
124
  // Constants for Cody-Waite range reduction
125
- const __m512 log2e = _mm512_set1_ps(1.4426950408889634f);
126
- const __m512 ln2_hi = _mm512_set1_ps(0.693145751953125f);
127
- const __m512 ln2_lo = _mm512_set1_ps(1.42860682030941723212e-6f);
125
+ const __m512 log2e_f32x16 = _mm512_set1_ps(1.4426950408889634f);
126
+ const __m512 ln2_high_f32x16 = _mm512_set1_ps(0.693145751953125f);
127
+ const __m512 ln2_low_f32x16 = _mm512_set1_ps(1.42860682030941723212e-6f);
128
128
 
129
129
  // Clamp to avoid overflow/underflow
130
- const __m512 max_x = _mm512_set1_ps(88.3762626647949f);
131
- const __m512 min_x = _mm512_set1_ps(-87.3365447504021f);
132
- x = _mm512_max_ps(_mm512_min_ps(x, max_x), min_x);
130
+ const __m512 max_x_f32x16 = _mm512_set1_ps(88.3762626647949f);
131
+ const __m512 min_x_f32x16 = _mm512_set1_ps(-87.3365447504021f);
132
+ x_f32x16 = _mm512_max_ps(_mm512_min_ps(x_f32x16, max_x_f32x16), min_x_f32x16);
133
133
 
134
- // n = round(x / ln(2))
135
- __m512 n = _mm512_roundscale_ps(_mm512_mul_ps(x, log2e), _MM_FROUND_TO_NEAREST_INT);
134
+ // n_f32x16 = round(x / ln(2))
135
+ __m512 n_f32x16 = _mm512_roundscale_ps(_mm512_mul_ps(x_f32x16, log2e_f32x16), _MM_FROUND_TO_NEAREST_INT);
136
136
 
137
- // r = x - n × ln(2) using Cody-Waite for precision
138
- __m512 r = _mm512_fnmadd_ps(n, ln2_hi, x);
139
- r = _mm512_fnmadd_ps(n, ln2_lo, r);
137
+ // r_f32x16 = x - n_f32x16 × ln(2) using Cody-Waite for precision
138
+ __m512 r_f32x16 = _mm512_fnmadd_ps(n_f32x16, ln2_high_f32x16, x_f32x16);
139
+ r_f32x16 = _mm512_fnmadd_ps(n_f32x16, ln2_low_f32x16, r_f32x16);
140
140
 
141
- // Polynomial approximation for exp(r): Remez minimax degree 6
141
+ // Polynomial approximation for exp(r_f32x16): Remez minimax degree 6
142
142
  // Coefficients optimized for [-ln(2)/2, ln(2)/2]
143
- __m512 p = _mm512_set1_ps(1.9875691500e-4f);
144
- p = _mm512_fmadd_ps(p, r, _mm512_set1_ps(1.3981999507e-3f));
145
- p = _mm512_fmadd_ps(p, r, _mm512_set1_ps(8.3334519073e-3f));
146
- p = _mm512_fmadd_ps(p, r, _mm512_set1_ps(4.1665858030e-2f));
147
- p = _mm512_fmadd_ps(p, r, _mm512_set1_ps(1.6666665459e-1f));
148
- p = _mm512_fmadd_ps(p, r, _mm512_set1_ps(5.0000001201e-1f));
149
- p = _mm512_fmadd_ps(p, r, _mm512_set1_ps(1.0f));
150
- p = _mm512_fmadd_ps(p, r, _mm512_set1_ps(1.0f));
151
-
152
- // Reconstruct: exp(x) = 2ⁿ × exp(r)
143
+ __m512 p_f32x16 = _mm512_set1_ps(1.9875691500e-4f);
144
+ p_f32x16 = _mm512_fmadd_ps(p_f32x16, r_f32x16, _mm512_set1_ps(1.3981999507e-3f));
145
+ p_f32x16 = _mm512_fmadd_ps(p_f32x16, r_f32x16, _mm512_set1_ps(8.3334519073e-3f));
146
+ p_f32x16 = _mm512_fmadd_ps(p_f32x16, r_f32x16, _mm512_set1_ps(4.1665858030e-2f));
147
+ p_f32x16 = _mm512_fmadd_ps(p_f32x16, r_f32x16, _mm512_set1_ps(1.6666665459e-1f));
148
+ p_f32x16 = _mm512_fmadd_ps(p_f32x16, r_f32x16, _mm512_set1_ps(5.0000001201e-1f));
149
+ p_f32x16 = _mm512_fmadd_ps(p_f32x16, r_f32x16, _mm512_set1_ps(1.0f));
150
+ p_f32x16 = _mm512_fmadd_ps(p_f32x16, r_f32x16, _mm512_set1_ps(1.0f));
151
+
152
+ // Reconstruct: exp(x) = 2ⁿ × exp(r_f32x16)
153
153
  // 2ⁿ via IEEE 754 exponent manipulation
154
- __m512i ni = _mm512_cvtps_epi32(n);
155
- ni = _mm512_add_epi32(ni, _mm512_set1_epi32(127));
156
- ni = _mm512_slli_epi32(ni, 23);
157
- __m512 pow2n = _mm512_castsi512_ps(ni);
154
+ __m512i ni_i32x16 = _mm512_cvtps_epi32(n_f32x16);
155
+ ni_i32x16 = _mm512_add_epi32(ni_i32x16, _mm512_set1_epi32(127));
156
+ ni_i32x16 = _mm512_slli_epi32(ni_i32x16, 23);
157
+ __m512 pow2n_f32x16 = _mm512_castsi512_ps(ni_i32x16);
158
158
 
159
- return _mm512_mul_ps(p, pow2n);
159
+ return _mm512_mul_ps(p_f32x16, pow2n_f32x16);
160
160
  }
161
161
 
162
162
  /**
@@ -172,41 +172,41 @@ NK_INTERNAL __m512 nk_exp_ps_avx512_(__m512 x) {
172
172
  * @param x Input vector (16 floats)
173
173
  * @return exp(x) approximation
174
174
  */
175
- NK_INTERNAL __m512 nk_exp_ps_fast_avx512_(__m512 x) {
175
+ NK_INTERNAL __m512 nk_exp_ps_fast_avx512_(__m512 x_f32x16) {
176
176
  // Constants for Cody-Waite range reduction
177
- const __m512 log2e = _mm512_set1_ps(1.4426950408889634f);
178
- const __m512 ln2_hi = _mm512_set1_ps(0.693145751953125f);
179
- const __m512 ln2_lo = _mm512_set1_ps(1.42860682030941723212e-6f);
177
+ const __m512 log2e_f32x16 = _mm512_set1_ps(1.4426950408889634f);
178
+ const __m512 ln2_high_f32x16 = _mm512_set1_ps(0.693145751953125f);
179
+ const __m512 ln2_low_f32x16 = _mm512_set1_ps(1.42860682030941723212e-6f);
180
180
 
181
181
  // Clamp to avoid overflow/underflow (same as accurate version)
182
- const __m512 max_x = _mm512_set1_ps(88.3762626647949f);
183
- const __m512 min_x = _mm512_set1_ps(-87.3365447504021f);
184
- x = _mm512_max_ps(_mm512_min_ps(x, max_x), min_x);
182
+ const __m512 max_x_f32x16 = _mm512_set1_ps(88.3762626647949f);
183
+ const __m512 min_x_f32x16 = _mm512_set1_ps(-87.3365447504021f);
184
+ x_f32x16 = _mm512_max_ps(_mm512_min_ps(x_f32x16, max_x_f32x16), min_x_f32x16);
185
185
 
186
- // n = round(x / ln(2))
187
- __m512 n = _mm512_roundscale_ps(_mm512_mul_ps(x, log2e), _MM_FROUND_TO_NEAREST_INT);
186
+ // n_f32x16 = round(x / ln(2))
187
+ __m512 n_f32x16 = _mm512_roundscale_ps(_mm512_mul_ps(x_f32x16, log2e_f32x16), _MM_FROUND_TO_NEAREST_INT);
188
188
 
189
- // r = x - n × ln(2) using Cody-Waite for precision
190
- __m512 r = _mm512_fnmadd_ps(n, ln2_hi, x);
191
- r = _mm512_fnmadd_ps(n, ln2_lo, r);
189
+ // r_f32x16 = x - n_f32x16 × ln(2) using Cody-Waite for precision
190
+ __m512 r_f32x16 = _mm512_fnmadd_ps(n_f32x16, ln2_high_f32x16, x_f32x16);
191
+ r_f32x16 = _mm512_fnmadd_ps(n_f32x16, ln2_low_f32x16, r_f32x16);
192
192
 
193
- // Polynomial approximation for exp(r): degree 4
193
+ // Polynomial approximation for exp(r_f32x16): degree 4
194
194
  // Optimized coefficients for [-ln(2)/2, ln(2)/2]
195
- // exp(r) ≈ 1 + r + r²/2 + r³/6 + r⁴/24
196
- // Using Horner form: ((c₄ × r + c₃) × r + c₂) × r + c₁) × r + c₀
197
- __m512 p = _mm512_set1_ps(4.1666666667e-2f); // 1/24
198
- p = _mm512_fmadd_ps(p, r, _mm512_set1_ps(1.6666666667e-1f)); // 1/6
199
- p = _mm512_fmadd_ps(p, r, _mm512_set1_ps(5.0000000000e-1f)); // 1/2
200
- p = _mm512_fmadd_ps(p, r, _mm512_set1_ps(1.0f)); // 1
201
- p = _mm512_fmadd_ps(p, r, _mm512_set1_ps(1.0f)); // 1
202
-
203
- // Reconstruct: exp(x) = 2ⁿ × exp(r)
204
- __m512i ni = _mm512_cvtps_epi32(n);
205
- ni = _mm512_add_epi32(ni, _mm512_set1_epi32(127));
206
- ni = _mm512_slli_epi32(ni, 23);
207
- __m512 pow2n = _mm512_castsi512_ps(ni);
208
-
209
- return _mm512_mul_ps(p, pow2n);
195
+ // exp(r_f32x16) ≈ 1 + r_f32x16 + r²/2 + r³/6 + r⁴/24
196
+ // Using Horner form: ((c₄ × r_f32x16 + c₃) × r_f32x16 + c₂) × r_f32x16 + c₁) × r_f32x16 + c₀
197
+ __m512 p_f32x16 = _mm512_set1_ps(4.1666666667e-2f); // 1/24
198
+ p_f32x16 = _mm512_fmadd_ps(p_f32x16, r_f32x16, _mm512_set1_ps(1.6666666667e-1f)); // 1/6
199
+ p_f32x16 = _mm512_fmadd_ps(p_f32x16, r_f32x16, _mm512_set1_ps(5.0000000000e-1f)); // 1/2
200
+ p_f32x16 = _mm512_fmadd_ps(p_f32x16, r_f32x16, _mm512_set1_ps(1.0f)); // 1
201
+ p_f32x16 = _mm512_fmadd_ps(p_f32x16, r_f32x16, _mm512_set1_ps(1.0f)); // 1
202
+
203
+ // Reconstruct: exp(x) = 2ⁿ × exp(r_f32x16)
204
+ __m512i ni_i32x16 = _mm512_cvtps_epi32(n_f32x16);
205
+ ni_i32x16 = _mm512_add_epi32(ni_i32x16, _mm512_set1_epi32(127));
206
+ ni_i32x16 = _mm512_slli_epi32(ni_i32x16, 23);
207
+ __m512 pow2n_f32x16 = _mm512_castsi512_ps(ni_i32x16);
208
+
209
+ return _mm512_mul_ps(p_f32x16, pow2n_f32x16);
210
210
  }
211
211
 
212
212
  /**
@@ -228,8 +228,8 @@ NK_INTERNAL __m512 nk_exp_ps_fast_avx512_(__m512 x) {
228
228
  * Tracks per-row running maximum and sum for 16 rows.
229
229
  */
230
230
  typedef struct {
231
- __m512 row_max; ///< Running max per row (16 values)
232
- __m512 row_sum; ///< Running sum of exp(x - max) per row
231
+ __m512 row_max_f32x16; ///< Running max per row (16 values)
232
+ __m512 row_sum_f32x16; ///< Running sum of exp(x - max) per row
233
233
  } nk_attention_softmax_row_state_t;
234
234
 
235
235
  /**
@@ -246,80 +246,80 @@ NK_INTERNAL void nk_attention_softmax_update_bc32_(nk_attention_softmax_row_stat
246
246
  nk_f32_t scale,
247
247
  nk_f32_t *weights_out) { // [16, 32] output weights
248
248
 
249
- __m512 scale_v = _mm512_set1_ps(scale);
249
+ __m512 scale_v_f32x16 = _mm512_set1_ps(scale);
250
250
 
251
251
  // Load and scale all scores, compute per-row max
252
252
  // Store in temporary arrays to avoid register pressure
253
- __m512 s_scaled[16][2];
253
+ __m512 s_scaled_f32x16[16][2];
254
254
  NK_ALIGN64 float row_maxes[16];
255
255
 
256
256
  // Process 4 rows at a time for ILP
257
257
  for (int i = 0; i < 16; i += 4) {
258
258
  // Row i
259
- s_scaled[i][0] = _mm512_mul_ps(_mm512_load_ps(scores + i * 32 + 0), scale_v);
260
- s_scaled[i][1] = _mm512_mul_ps(_mm512_load_ps(scores + i * 32 + 16), scale_v);
261
- __m512 m0 = _mm512_max_ps(s_scaled[i][0], s_scaled[i][1]);
259
+ s_scaled_f32x16[i][0] = _mm512_mul_ps(_mm512_load_ps(scores + i * 32 + 0), scale_v_f32x16);
260
+ s_scaled_f32x16[i][1] = _mm512_mul_ps(_mm512_load_ps(scores + i * 32 + 16), scale_v_f32x16);
261
+ __m512 m0_f32x16 = _mm512_max_ps(s_scaled_f32x16[i][0], s_scaled_f32x16[i][1]);
262
262
 
263
263
  // Row i+1
264
- s_scaled[i + 1][0] = _mm512_mul_ps(_mm512_load_ps(scores + (i + 1) * 32 + 0), scale_v);
265
- s_scaled[i + 1][1] = _mm512_mul_ps(_mm512_load_ps(scores + (i + 1) * 32 + 16), scale_v);
266
- __m512 m1 = _mm512_max_ps(s_scaled[i + 1][0], s_scaled[i + 1][1]);
264
+ s_scaled_f32x16[i + 1][0] = _mm512_mul_ps(_mm512_load_ps(scores + (i + 1) * 32 + 0), scale_v_f32x16);
265
+ s_scaled_f32x16[i + 1][1] = _mm512_mul_ps(_mm512_load_ps(scores + (i + 1) * 32 + 16), scale_v_f32x16);
266
+ __m512 m1_f32x16 = _mm512_max_ps(s_scaled_f32x16[i + 1][0], s_scaled_f32x16[i + 1][1]);
267
267
 
268
268
  // Row i+2
269
- s_scaled[i + 2][0] = _mm512_mul_ps(_mm512_load_ps(scores + (i + 2) * 32 + 0), scale_v);
270
- s_scaled[i + 2][1] = _mm512_mul_ps(_mm512_load_ps(scores + (i + 2) * 32 + 16), scale_v);
271
- __m512 m2 = _mm512_max_ps(s_scaled[i + 2][0], s_scaled[i + 2][1]);
269
+ s_scaled_f32x16[i + 2][0] = _mm512_mul_ps(_mm512_load_ps(scores + (i + 2) * 32 + 0), scale_v_f32x16);
270
+ s_scaled_f32x16[i + 2][1] = _mm512_mul_ps(_mm512_load_ps(scores + (i + 2) * 32 + 16), scale_v_f32x16);
271
+ __m512 m2_f32x16 = _mm512_max_ps(s_scaled_f32x16[i + 2][0], s_scaled_f32x16[i + 2][1]);
272
272
 
273
273
  // Row i+3
274
- s_scaled[i + 3][0] = _mm512_mul_ps(_mm512_load_ps(scores + (i + 3) * 32 + 0), scale_v);
275
- s_scaled[i + 3][1] = _mm512_mul_ps(_mm512_load_ps(scores + (i + 3) * 32 + 16), scale_v);
276
- __m512 m3 = _mm512_max_ps(s_scaled[i + 3][0], s_scaled[i + 3][1]);
274
+ s_scaled_f32x16[i + 3][0] = _mm512_mul_ps(_mm512_load_ps(scores + (i + 3) * 32 + 0), scale_v_f32x16);
275
+ s_scaled_f32x16[i + 3][1] = _mm512_mul_ps(_mm512_load_ps(scores + (i + 3) * 32 + 16), scale_v_f32x16);
276
+ __m512 m3_f32x16 = _mm512_max_ps(s_scaled_f32x16[i + 3][0], s_scaled_f32x16[i + 3][1]);
277
277
 
278
278
  // Reduce to scalar max
279
- row_maxes[i] = _mm512_reduce_max_ps(m0);
280
- row_maxes[i + 1] = _mm512_reduce_max_ps(m1);
281
- row_maxes[i + 2] = _mm512_reduce_max_ps(m2);
282
- row_maxes[i + 3] = _mm512_reduce_max_ps(m3);
279
+ row_maxes[i] = _mm512_reduce_max_ps(m0_f32x16);
280
+ row_maxes[i + 1] = _mm512_reduce_max_ps(m1_f32x16);
281
+ row_maxes[i + 2] = _mm512_reduce_max_ps(m2_f32x16);
282
+ row_maxes[i + 3] = _mm512_reduce_max_ps(m3_f32x16);
283
283
  }
284
284
 
285
- __m512 row_max_new = _mm512_load_ps(row_maxes);
286
- __m512 old_max = state->row_max;
287
- __m512 new_max = _mm512_max_ps(old_max, row_max_new);
285
+ __m512 row_max_new_f32x16 = _mm512_load_ps(row_maxes);
286
+ __m512 old_max_f32x16 = state->row_max_f32x16;
287
+ __m512 new_max_f32x16 = _mm512_max_ps(old_max_f32x16, row_max_new_f32x16);
288
288
 
289
289
  // Rescale old sum
290
- __m512 correction = nk_exp_ps_avx512_(_mm512_sub_ps(old_max, new_max));
291
- __m512 new_sum = _mm512_mul_ps(state->row_sum, correction);
290
+ __m512 correction_f32x16 = nk_exp_ps_avx512_(_mm512_sub_ps(old_max_f32x16, new_max_f32x16));
291
+ __m512 new_sum_f32x16 = _mm512_mul_ps(state->row_sum_f32x16, correction_f32x16);
292
292
 
293
- // Compute P = exp(S - new_max) and accumulate sums
294
- NK_ALIGN64 float new_max_arr[16];
295
- NK_ALIGN64 float row_sums[16];
296
- _mm512_store_ps(new_max_arr, new_max);
293
+ // Compute P = exp(S - new_max_f32x16) and accumulate sums
294
+ NK_ALIGN64 nk_f32_t new_max_arr[16];
295
+ NK_ALIGN64 nk_f32_t row_sums[16];
296
+ _mm512_store_ps(new_max_arr, new_max_f32x16);
297
297
 
298
298
  // Process rows
299
299
  for (int i = 0; i < 16; i += 2) {
300
- __m512 max_i = _mm512_set1_ps(new_max_arr[i]);
301
- __m512 max_i1 = _mm512_set1_ps(new_max_arr[i + 1]);
300
+ __m512 max_i_f32x16 = _mm512_set1_ps(new_max_arr[i]);
301
+ __m512 max_i1_f32x16 = _mm512_set1_ps(new_max_arr[i + 1]);
302
302
 
303
303
  // Row i
304
- __m512 p0_i = nk_exp_ps_avx512_(_mm512_sub_ps(s_scaled[i][0], max_i));
305
- __m512 p1_i = nk_exp_ps_avx512_(_mm512_sub_ps(s_scaled[i][1], max_i));
306
- _mm512_store_ps(weights_out + i * 32 + 0, p0_i);
307
- _mm512_store_ps(weights_out + i * 32 + 16, p1_i);
308
- row_sums[i] = _mm512_reduce_add_ps(p0_i) + _mm512_reduce_add_ps(p1_i);
304
+ __m512 p0_i_f32x16 = nk_exp_ps_avx512_(_mm512_sub_ps(s_scaled_f32x16[i][0], max_i_f32x16));
305
+ __m512 p1_i_f32x16 = nk_exp_ps_avx512_(_mm512_sub_ps(s_scaled_f32x16[i][1], max_i_f32x16));
306
+ _mm512_store_ps(weights_out + i * 32 + 0, p0_i_f32x16);
307
+ _mm512_store_ps(weights_out + i * 32 + 16, p1_i_f32x16);
308
+ row_sums[i] = _mm512_reduce_add_ps(p0_i_f32x16) + _mm512_reduce_add_ps(p1_i_f32x16);
309
309
 
310
310
  // Row i+1
311
- __m512 p0_i1 = nk_exp_ps_avx512_(_mm512_sub_ps(s_scaled[i + 1][0], max_i1));
312
- __m512 p1_i1 = nk_exp_ps_avx512_(_mm512_sub_ps(s_scaled[i + 1][1], max_i1));
313
- _mm512_store_ps(weights_out + (i + 1) * 32 + 0, p0_i1);
314
- _mm512_store_ps(weights_out + (i + 1) * 32 + 16, p1_i1);
315
- row_sums[i + 1] = _mm512_reduce_add_ps(p0_i1) + _mm512_reduce_add_ps(p1_i1);
311
+ __m512 p0_i1_f32x16 = nk_exp_ps_avx512_(_mm512_sub_ps(s_scaled_f32x16[i + 1][0], max_i1_f32x16));
312
+ __m512 p1_i1_f32x16 = nk_exp_ps_avx512_(_mm512_sub_ps(s_scaled_f32x16[i + 1][1], max_i1_f32x16));
313
+ _mm512_store_ps(weights_out + (i + 1) * 32 + 0, p0_i1_f32x16);
314
+ _mm512_store_ps(weights_out + (i + 1) * 32 + 16, p1_i1_f32x16);
315
+ row_sums[i + 1] = _mm512_reduce_add_ps(p0_i1_f32x16) + _mm512_reduce_add_ps(p1_i1_f32x16);
316
316
  }
317
317
 
318
318
  // Add row sums to running sum vectorially
319
- new_sum = _mm512_add_ps(new_sum, _mm512_load_ps(row_sums));
319
+ new_sum_f32x16 = _mm512_add_ps(new_sum_f32x16, _mm512_load_ps(row_sums));
320
320
 
321
- state->row_max = new_max;
322
- state->row_sum = new_sum;
321
+ state->row_max_f32x16 = new_max_f32x16;
322
+ state->row_sum_f32x16 = new_sum_f32x16;
323
323
  }
324
324
 
325
325
  /**
@@ -335,81 +335,81 @@ NK_INTERNAL void nk_attention_softmax_update_bc32_fast_(nk_attention_softmax_row
335
335
  nk_f32_t scale,
336
336
  nk_f32_t *weights_out) { // [16, 32] output weights
337
337
 
338
- __m512 scale_v = _mm512_set1_ps(scale);
338
+ __m512 scale_v_f32x16 = _mm512_set1_ps(scale);
339
339
 
340
340
  // Load and scale all scores, compute per-row max
341
- __m512 s_scaled[16][2];
341
+ __m512 s_scaled_f32x16[16][2];
342
342
  NK_ALIGN64 float row_maxes[16];
343
343
 
344
344
  // Process 4 rows at a time for ILP
345
345
  for (int i = 0; i < 16; i += 4) {
346
- s_scaled[i][0] = _mm512_mul_ps(_mm512_load_ps(scores + i * 32 + 0), scale_v);
347
- s_scaled[i][1] = _mm512_mul_ps(_mm512_load_ps(scores + i * 32 + 16), scale_v);
348
- __m512 m0 = _mm512_max_ps(s_scaled[i][0], s_scaled[i][1]);
349
-
350
- s_scaled[i + 1][0] = _mm512_mul_ps(_mm512_load_ps(scores + (i + 1) * 32 + 0), scale_v);
351
- s_scaled[i + 1][1] = _mm512_mul_ps(_mm512_load_ps(scores + (i + 1) * 32 + 16), scale_v);
352
- __m512 m1 = _mm512_max_ps(s_scaled[i + 1][0], s_scaled[i + 1][1]);
353
-
354
- s_scaled[i + 2][0] = _mm512_mul_ps(_mm512_load_ps(scores + (i + 2) * 32 + 0), scale_v);
355
- s_scaled[i + 2][1] = _mm512_mul_ps(_mm512_load_ps(scores + (i + 2) * 32 + 16), scale_v);
356
- __m512 m2 = _mm512_max_ps(s_scaled[i + 2][0], s_scaled[i + 2][1]);
357
-
358
- s_scaled[i + 3][0] = _mm512_mul_ps(_mm512_load_ps(scores + (i + 3) * 32 + 0), scale_v);
359
- s_scaled[i + 3][1] = _mm512_mul_ps(_mm512_load_ps(scores + (i + 3) * 32 + 16), scale_v);
360
- __m512 m3 = _mm512_max_ps(s_scaled[i + 3][0], s_scaled[i + 3][1]);
361
-
362
- row_maxes[i] = _mm512_reduce_max_ps(m0);
363
- row_maxes[i + 1] = _mm512_reduce_max_ps(m1);
364
- row_maxes[i + 2] = _mm512_reduce_max_ps(m2);
365
- row_maxes[i + 3] = _mm512_reduce_max_ps(m3);
346
+ s_scaled_f32x16[i][0] = _mm512_mul_ps(_mm512_load_ps(scores + i * 32 + 0), scale_v_f32x16);
347
+ s_scaled_f32x16[i][1] = _mm512_mul_ps(_mm512_load_ps(scores + i * 32 + 16), scale_v_f32x16);
348
+ __m512 m0_f32x16 = _mm512_max_ps(s_scaled_f32x16[i][0], s_scaled_f32x16[i][1]);
349
+
350
+ s_scaled_f32x16[i + 1][0] = _mm512_mul_ps(_mm512_load_ps(scores + (i + 1) * 32 + 0), scale_v_f32x16);
351
+ s_scaled_f32x16[i + 1][1] = _mm512_mul_ps(_mm512_load_ps(scores + (i + 1) * 32 + 16), scale_v_f32x16);
352
+ __m512 m1_f32x16 = _mm512_max_ps(s_scaled_f32x16[i + 1][0], s_scaled_f32x16[i + 1][1]);
353
+
354
+ s_scaled_f32x16[i + 2][0] = _mm512_mul_ps(_mm512_load_ps(scores + (i + 2) * 32 + 0), scale_v_f32x16);
355
+ s_scaled_f32x16[i + 2][1] = _mm512_mul_ps(_mm512_load_ps(scores + (i + 2) * 32 + 16), scale_v_f32x16);
356
+ __m512 m2_f32x16 = _mm512_max_ps(s_scaled_f32x16[i + 2][0], s_scaled_f32x16[i + 2][1]);
357
+
358
+ s_scaled_f32x16[i + 3][0] = _mm512_mul_ps(_mm512_load_ps(scores + (i + 3) * 32 + 0), scale_v_f32x16);
359
+ s_scaled_f32x16[i + 3][1] = _mm512_mul_ps(_mm512_load_ps(scores + (i + 3) * 32 + 16), scale_v_f32x16);
360
+ __m512 m3_f32x16 = _mm512_max_ps(s_scaled_f32x16[i + 3][0], s_scaled_f32x16[i + 3][1]);
361
+
362
+ row_maxes[i] = _mm512_reduce_max_ps(m0_f32x16);
363
+ row_maxes[i + 1] = _mm512_reduce_max_ps(m1_f32x16);
364
+ row_maxes[i + 2] = _mm512_reduce_max_ps(m2_f32x16);
365
+ row_maxes[i + 3] = _mm512_reduce_max_ps(m3_f32x16);
366
366
  }
367
367
 
368
- __m512 row_max_new = _mm512_load_ps(row_maxes);
369
- __m512 old_max = state->row_max;
370
- __m512 new_max = _mm512_max_ps(old_max, row_max_new);
368
+ __m512 row_max_new_f32x16 = _mm512_load_ps(row_maxes);
369
+ __m512 old_max_f32x16 = state->row_max_f32x16;
370
+ __m512 new_max_f32x16 = _mm512_max_ps(old_max_f32x16, row_max_new_f32x16);
371
371
 
372
372
  // Rescale old sum using fast exp
373
- __m512 correction = nk_exp_ps_fast_avx512_(_mm512_sub_ps(old_max, new_max));
374
- __m512 new_sum = _mm512_mul_ps(state->row_sum, correction);
373
+ __m512 correction_f32x16 = nk_exp_ps_fast_avx512_(_mm512_sub_ps(old_max_f32x16, new_max_f32x16));
374
+ __m512 new_sum_f32x16 = _mm512_mul_ps(state->row_sum_f32x16, correction_f32x16);
375
375
 
376
- // Compute P = exp(S - new_max) using fast exp
377
- NK_ALIGN64 float new_max_arr[16];
378
- NK_ALIGN64 float row_sums[16];
379
- _mm512_store_ps(new_max_arr, new_max);
376
+ // Compute P = exp(S - new_max_f32x16) using fast exp
377
+ NK_ALIGN64 nk_f32_t new_max_arr[16];
378
+ NK_ALIGN64 nk_f32_t row_sums[16];
379
+ _mm512_store_ps(new_max_arr, new_max_f32x16);
380
380
 
381
381
  // Process rows with fast exp
382
382
  for (int i = 0; i < 16; i += 2) {
383
- __m512 max_i = _mm512_set1_ps(new_max_arr[i]);
384
- __m512 max_i1 = _mm512_set1_ps(new_max_arr[i + 1]);
383
+ __m512 max_i_f32x16 = _mm512_set1_ps(new_max_arr[i]);
384
+ __m512 max_i1_f32x16 = _mm512_set1_ps(new_max_arr[i + 1]);
385
385
 
386
386
  // Row i
387
- __m512 p0_i = nk_exp_ps_fast_avx512_(_mm512_sub_ps(s_scaled[i][0], max_i));
388
- __m512 p1_i = nk_exp_ps_fast_avx512_(_mm512_sub_ps(s_scaled[i][1], max_i));
389
- _mm512_store_ps(weights_out + i * 32 + 0, p0_i);
390
- _mm512_store_ps(weights_out + i * 32 + 16, p1_i);
391
- row_sums[i] = _mm512_reduce_add_ps(p0_i) + _mm512_reduce_add_ps(p1_i);
387
+ __m512 p0_i_f32x16 = nk_exp_ps_fast_avx512_(_mm512_sub_ps(s_scaled_f32x16[i][0], max_i_f32x16));
388
+ __m512 p1_i_f32x16 = nk_exp_ps_fast_avx512_(_mm512_sub_ps(s_scaled_f32x16[i][1], max_i_f32x16));
389
+ _mm512_store_ps(weights_out + i * 32 + 0, p0_i_f32x16);
390
+ _mm512_store_ps(weights_out + i * 32 + 16, p1_i_f32x16);
391
+ row_sums[i] = _mm512_reduce_add_ps(p0_i_f32x16) + _mm512_reduce_add_ps(p1_i_f32x16);
392
392
 
393
393
  // Row i+1
394
- __m512 p0_i1 = nk_exp_ps_fast_avx512_(_mm512_sub_ps(s_scaled[i + 1][0], max_i1));
395
- __m512 p1_i1 = nk_exp_ps_fast_avx512_(_mm512_sub_ps(s_scaled[i + 1][1], max_i1));
396
- _mm512_store_ps(weights_out + (i + 1) * 32 + 0, p0_i1);
397
- _mm512_store_ps(weights_out + (i + 1) * 32 + 16, p1_i1);
398
- row_sums[i + 1] = _mm512_reduce_add_ps(p0_i1) + _mm512_reduce_add_ps(p1_i1);
394
+ __m512 p0_i1_f32x16 = nk_exp_ps_fast_avx512_(_mm512_sub_ps(s_scaled_f32x16[i + 1][0], max_i1_f32x16));
395
+ __m512 p1_i1_f32x16 = nk_exp_ps_fast_avx512_(_mm512_sub_ps(s_scaled_f32x16[i + 1][1], max_i1_f32x16));
396
+ _mm512_store_ps(weights_out + (i + 1) * 32 + 0, p0_i1_f32x16);
397
+ _mm512_store_ps(weights_out + (i + 1) * 32 + 16, p1_i1_f32x16);
398
+ row_sums[i + 1] = _mm512_reduce_add_ps(p0_i1_f32x16) + _mm512_reduce_add_ps(p1_i1_f32x16);
399
399
  }
400
400
 
401
- new_sum = _mm512_add_ps(new_sum, _mm512_load_ps(row_sums));
401
+ new_sum_f32x16 = _mm512_add_ps(new_sum_f32x16, _mm512_load_ps(row_sums));
402
402
 
403
- state->row_max = new_max;
404
- state->row_sum = new_sum;
403
+ state->row_max_f32x16 = new_max_f32x16;
404
+ state->row_sum_f32x16 = new_sum_f32x16;
405
405
  }
406
406
 
407
407
  /**
408
408
  * @brief Initialize online softmax state.
409
409
  */
410
410
  NK_INTERNAL void nk_attention_softmax_init_(nk_attention_softmax_row_state_t *state) {
411
- state->row_max = _mm512_set1_ps(NK_F32_MIN);
412
- state->row_sum = _mm512_setzero_ps();
411
+ state->row_max_f32x16 = _mm512_set1_ps(NK_F32_MIN);
412
+ state->row_sum_f32x16 = _mm512_setzero_ps();
413
413
  }
414
414
 
415
415
  /**
@@ -430,43 +430,43 @@ NK_INTERNAL void nk_attention_softmax_init_(nk_attention_softmax_row_state_t *st
430
430
  NK_INTERNAL void nk_attention_softmax_update_(nk_attention_softmax_row_state_t *state, nk_f32_t const *scores,
431
431
  nk_f32_t scale, nk_f32_t *weights_out) {
432
432
 
433
- __m512 scale_v = _mm512_set1_ps(scale);
433
+ __m512 scale_v_f32x16 = _mm512_set1_ps(scale);
434
434
 
435
435
  // Load scores into 16 ZMM registers (one per row)
436
- __m512 s[16];
437
- for (int i = 0; i < 16; i++) { s[i] = _mm512_mul_ps(_mm512_load_ps(scores + i * 16), scale_v); }
436
+ __m512 s_f32x16[16];
437
+ for (int i = 0; i < 16; i++) { s_f32x16[i] = _mm512_mul_ps(_mm512_load_ps(scores + i * 16), scale_v_f32x16); }
438
438
 
439
439
  // Per-row max (each row has 16 elements, we need max across those 16)
440
440
  // _mm512_reduce_max_ps returns a float scalar
441
441
  NK_ALIGN64 float row_maxes[16];
442
- for (int i = 0; i < 16; i++) { row_maxes[i] = _mm512_reduce_max_ps(s[i]); }
443
- __m512 row_max_new = _mm512_load_ps(row_maxes);
442
+ for (int i = 0; i < 16; i++) { row_maxes[i] = _mm512_reduce_max_ps(s_f32x16[i]); }
443
+ __m512 row_max_new_f32x16 = _mm512_load_ps(row_maxes);
444
444
 
445
445
  // Update running max
446
- __m512 old_max = state->row_max;
447
- __m512 new_max = _mm512_max_ps(old_max, row_max_new);
446
+ __m512 old_max_f32x16 = state->row_max_f32x16;
447
+ __m512 new_max_f32x16 = _mm512_max_ps(old_max_f32x16, row_max_new_f32x16);
448
448
 
449
449
  // Rescale old sum: l = l × exp(oldₘₐₓ - newₘₐₓ)
450
- __m512 correction = nk_exp_ps_avx512_(_mm512_sub_ps(old_max, new_max));
451
- __m512 old_sum_rescaled = _mm512_mul_ps(state->row_sum, correction);
450
+ __m512 correction_f32x16 = nk_exp_ps_avx512_(_mm512_sub_ps(old_max_f32x16, new_max_f32x16));
451
+ __m512 old_sum_rescaled_f32x16 = _mm512_mul_ps(state->row_sum_f32x16, correction_f32x16);
452
452
 
453
453
  // Compute P = exp(S - newₘₐₓ) for each row, accumulate sum
454
- __m512 new_sum = old_sum_rescaled;
455
- float new_max_arr[16];
456
- _mm512_store_ps(new_max_arr, new_max);
454
+ __m512 new_sum_f32x16 = old_sum_rescaled_f32x16;
455
+ nk_f32_t new_max_arr[16];
456
+ _mm512_store_ps(new_max_arr, new_max_f32x16);
457
457
 
458
458
  for (int i = 0; i < 16; i++) {
459
- __m512 max_broadcast = _mm512_set1_ps(new_max_arr[i]);
460
- __m512 p = nk_exp_ps_avx512_(_mm512_sub_ps(s[i], max_broadcast));
461
- _mm512_store_ps(weights_out + i * 16, p);
459
+ __m512 max_broadcast_f32x16 = _mm512_set1_ps(new_max_arr[i]);
460
+ __m512 p_f32x16 = nk_exp_ps_avx512_(_mm512_sub_ps(s_f32x16[i], max_broadcast_f32x16));
461
+ _mm512_store_ps(weights_out + i * 16, p_f32x16);
462
462
 
463
463
  // Add row sum to running sum (at position i)
464
- float row_sum = _mm512_reduce_add_ps(p);
465
- new_sum = _mm512_mask_add_ps(new_sum, 1u << i, new_sum, _mm512_set1_ps(row_sum));
464
+ nk_f32_t row_sum = _mm512_reduce_add_ps(p_f32x16);
465
+ new_sum_f32x16 = _mm512_mask_add_ps(new_sum_f32x16, 1u << i, new_sum_f32x16, _mm512_set1_ps(row_sum));
466
466
  }
467
467
 
468
- state->row_max = new_max;
469
- state->row_sum = new_sum;
468
+ state->row_max_f32x16 = new_max_f32x16;
469
+ state->row_sum_f32x16 = new_sum_f32x16;
470
470
  }
471
471
 
472
472
  /**
@@ -480,18 +480,19 @@ NK_INTERNAL void nk_attention_softmax_update_(nk_attention_softmax_row_state_t *
480
480
  * @param old_max Previous running max per row (16 values)
481
481
  * @param new_max New running max per row (16 values)
482
482
  */
483
- NK_INTERNAL void nk_attention_rescale_output_(nk_f32_t *output, nk_size_t head_dim, __m512 old_max, __m512 new_max) {
483
+ NK_INTERNAL void nk_attention_rescale_output_(nk_f32_t *output, nk_size_t head_dim, __m512 old_max_f32x16,
484
+ __m512 new_max_f32x16) {
484
485
 
485
- __m512 correction = nk_exp_ps_avx512_(_mm512_sub_ps(old_max, new_max));
486
- float corr_arr[16];
487
- _mm512_store_ps(corr_arr, correction);
486
+ __m512 correction_f32x16 = nk_exp_ps_avx512_(_mm512_sub_ps(old_max_f32x16, new_max_f32x16));
487
+ nk_f32_t corr_arr[16];
488
+ _mm512_store_ps(corr_arr, correction_f32x16);
488
489
 
489
490
  for (nk_size_t row = 0; row < 16; row++) {
490
- __m512 corr_v = _mm512_set1_ps(corr_arr[row]);
491
+ __m512 corr_v_f32x16 = _mm512_set1_ps(corr_arr[row]);
491
492
  for (nk_size_t col = 0; col < head_dim; col += 16) {
492
- __m512 o = _mm512_load_ps(output + row * head_dim + col);
493
- o = _mm512_mul_ps(o, corr_v);
494
- _mm512_store_ps(output + row * head_dim + col, o);
493
+ __m512 o_f32x16 = _mm512_load_ps(output + row * head_dim + col);
494
+ o_f32x16 = _mm512_mul_ps(o_f32x16, corr_v_f32x16);
495
+ _mm512_store_ps(output + row * head_dim + col, o_f32x16);
495
496
  }
496
497
  }
497
498
  }
@@ -790,22 +791,22 @@ NK_PUBLIC void nk_attention_bf16_sapphireamx(nk_bf16_t const *q, void const *kv_
790
791
  // Phase 1: Compute S = Q × Kᵀ using AVX-512 FMA
791
792
  for (nk_size_t qi = 0; qi < valid_q; qi++) {
792
793
  for (nk_size_t ki = 0; ki < valid_kv; ki++) {
793
- __m512 sum_v = _mm512_setzero_ps();
794
+ __m512 sum_v_f32x16 = _mm512_setzero_ps();
794
795
  nk_size_t d = 0;
795
796
  // Vectorized loop over head_dim
796
797
  for (; d + 16 <= head_dim; d += 16) {
797
- __m512 q_v = _mm512_loadu_ps(&q_block[qi * head_dim + d]);
798
+ __m512 q_v_f32x16 = _mm512_loadu_ps(&q_block[qi * head_dim + d]);
798
799
  // Kᵀ is stored as [head_dim, kv], gather is slow, use scalar for now
799
- __m512 k_v = _mm512_set_ps(
800
+ __m512 k_v_f32x16 = _mm512_set_ps(
800
801
  k_block[(d + 15) * 16 + ki], k_block[(d + 14) * 16 + ki], k_block[(d + 13) * 16 + ki],
801
802
  k_block[(d + 12) * 16 + ki], k_block[(d + 11) * 16 + ki], k_block[(d + 10) * 16 + ki],
802
803
  k_block[(d + 9) * 16 + ki], k_block[(d + 8) * 16 + ki], k_block[(d + 7) * 16 + ki],
803
804
  k_block[(d + 6) * 16 + ki], k_block[(d + 5) * 16 + ki], k_block[(d + 4) * 16 + ki],
804
805
  k_block[(d + 3) * 16 + ki], k_block[(d + 2) * 16 + ki], k_block[(d + 1) * 16 + ki],
805
806
  k_block[(d + 0) * 16 + ki]);
806
- sum_v = _mm512_fmadd_ps(q_v, k_v, sum_v);
807
+ sum_v_f32x16 = _mm512_fmadd_ps(q_v_f32x16, k_v_f32x16, sum_v_f32x16);
807
808
  }
808
- nk_f32_t sum = _mm512_reduce_add_ps(sum_v);
809
+ nk_f32_t sum = _mm512_reduce_add_ps(sum_v_f32x16);
809
810
  // Scalar tail
810
811
  for (; d < head_dim; d++) { sum += q_block[qi * head_dim + d] * k_block[d * 16 + ki]; }
811
812
  scores[qi * 16 + ki] = sum;
@@ -819,11 +820,11 @@ NK_PUBLIC void nk_attention_bf16_sapphireamx(nk_bf16_t const *q, void const *kv_
819
820
  }
820
821
 
821
822
  // Phase 2: Online softmax update
822
- __m512 old_max = softmax_state.row_max;
823
+ __m512 old_max_f32x16 = softmax_state.row_max_f32x16;
823
824
  nk_attention_softmax_update_(&softmax_state, scores, scale, weights);
824
825
 
825
826
  // Rescale output accumulator if max changed
826
- nk_attention_rescale_output_(o_acc, head_dim_padded, old_max, softmax_state.row_max);
827
+ nk_attention_rescale_output_(o_acc, head_dim_padded, old_max_f32x16, softmax_state.row_max_f32x16);
827
828
 
828
829
  // Extract V block: V[valid_kv, head_dim] using bulk extraction
829
830
  nk_attention_extract_v_block_(v_packed, v_block, kv_h, kvb, valid_kv, head_dim, kv_len);
@@ -833,13 +834,13 @@ NK_PUBLIC void nk_attention_bf16_sapphireamx(nk_bf16_t const *q, void const *kv_
833
834
  nk_size_t d = 0;
834
835
  // Vectorized loop over head_dim
835
836
  for (; d + 16 <= head_dim; d += 16) {
836
- __m512 acc_v = _mm512_loadu_ps(&o_acc[qi * head_dim_padded + d]);
837
+ __m512 acc_v_f32x16 = _mm512_loadu_ps(&o_acc[qi * head_dim_padded + d]);
837
838
  for (nk_size_t ki = 0; ki < valid_kv; ki++) {
838
- __m512 p_v = _mm512_set1_ps(weights[qi * 16 + ki]);
839
- __m512 v_v = _mm512_loadu_ps(&v_block[ki * head_dim + d]);
840
- acc_v = _mm512_fmadd_ps(p_v, v_v, acc_v);
839
+ __m512 p_v_f32x16 = _mm512_set1_ps(weights[qi * 16 + ki]);
840
+ __m512 v_v_f32x16 = _mm512_loadu_ps(&v_block[ki * head_dim + d]);
841
+ acc_v_f32x16 = _mm512_fmadd_ps(p_v_f32x16, v_v_f32x16, acc_v_f32x16);
841
842
  }
842
- _mm512_storeu_ps(&o_acc[qi * head_dim_padded + d], acc_v);
843
+ _mm512_storeu_ps(&o_acc[qi * head_dim_padded + d], acc_v_f32x16);
843
844
  }
844
845
  // Scalar tail
845
846
  for (; d < head_dim; d++) {
@@ -853,8 +854,8 @@ NK_PUBLIC void nk_attention_bf16_sapphireamx(nk_bf16_t const *q, void const *kv_
853
854
  }
854
855
 
855
856
  // Finalize: normalize O by row sums
856
- float row_sums[16];
857
- _mm512_store_ps(row_sums, softmax_state.row_sum);
857
+ nk_f32_t row_sums[16];
858
+ _mm512_store_ps(row_sums, softmax_state.row_sum_f32x16);
858
859
 
859
860
  for (nk_size_t qi = 0; qi < valid_q; qi++) {
860
861
  nk_f32_t inv_sum = 1.0f / row_sums[qi];
@@ -918,12 +919,12 @@ NK_PUBLIC void nk_attention_bf16_amx_bc32_sapphireamx(nk_bf16_t const *q, void c
918
919
  nk_attention_softmax_init_(&softmax_state);
919
920
 
920
921
  // Zero output accumulator using SIMD
921
- __m512 zero = _mm512_setzero_ps();
922
+ __m512 zero_f32x16 = _mm512_setzero_ps();
922
923
  for (nk_size_t i = 0; i < 16 * head_dim_padded; i += 64) {
923
- _mm512_store_ps(&o_acc[i], zero);
924
- _mm512_store_ps(&o_acc[i + 16], zero);
925
- _mm512_store_ps(&o_acc[i + 32], zero);
926
- _mm512_store_ps(&o_acc[i + 48], zero);
924
+ _mm512_store_ps(&o_acc[i], zero_f32x16);
925
+ _mm512_store_ps(&o_acc[i + 16], zero_f32x16);
926
+ _mm512_store_ps(&o_acc[i + 32], zero_f32x16);
927
+ _mm512_store_ps(&o_acc[i + 48], zero_f32x16);
927
928
  }
928
929
 
929
930
  // Process KV blocks in chunks of 32
@@ -949,10 +950,10 @@ NK_PUBLIC void nk_attention_bf16_amx_bc32_sapphireamx(nk_bf16_t const *q, void c
949
950
  for (nk_size_t row = 0; row < valid_q; row++) {
950
951
  nk_bf16_t const *q_row = q_head + (qb + row) * head_dim + depth_start;
951
952
  // Load 32 BF16 values (64 bytes) using two 256-bit loads
952
- __m256i q0 = _mm256_loadu_si256((__m256i const *)q_row);
953
- __m256i q1 = _mm256_loadu_si256((__m256i const *)(q_row + 16));
954
- _mm256_store_si256((__m256i *)&q_tile[row][0], q0);
955
- _mm256_store_si256((__m256i *)&q_tile[row][16], q1);
953
+ __m256i q0_bf16x16 = _mm256_loadu_si256((__m256i const *)q_row);
954
+ __m256i q1_bf16x16 = _mm256_loadu_si256((__m256i const *)(q_row + 16));
955
+ _mm256_store_si256((__m256i *)&q_tile[row][0], q0_bf16x16);
956
+ _mm256_store_si256((__m256i *)&q_tile[row][16], q1_bf16x16);
956
957
  }
957
958
  }
958
959
  else {
@@ -990,23 +991,23 @@ NK_PUBLIC void nk_attention_bf16_amx_bc32_sapphireamx(nk_bf16_t const *q, void c
990
991
  // Use SIMD for fast extraction
991
992
  _tile_stored(0, s_tile, 64);
992
993
 
993
- __m512 neg_inf = _mm512_set1_ps(NK_F32_MIN);
994
+ __m512 neg_inf_f32x16 = _mm512_set1_ps(NK_F32_MIN);
994
995
 
995
996
  if (valid_q == 16 && valid_kv >= 16) {
996
997
  // Fast path: full first half, just copy
997
998
  for (nk_size_t qi = 0; qi < 16; qi++) {
998
- __m512 s0 = _mm512_load_ps(&s_tile[qi][0]);
999
- _mm512_store_ps(&scores[qi * 32], s0);
999
+ __m512 s0_f32x16 = _mm512_load_ps(&s_tile[qi][0]);
1000
+ _mm512_store_ps(&scores[qi * 32], s0_f32x16);
1000
1001
  }
1001
1002
  }
1002
1003
  else {
1003
1004
  // Partial - need masking
1004
1005
  __mmask16 kv_mask = (1u << valid_kv) - 1;
1005
1006
  for (nk_size_t qi = 0; qi < 16; qi++) {
1006
- __m512 s0 = _mm512_load_ps(&s_tile[qi][0]);
1007
- if (qi < valid_q) { s0 = _mm512_mask_blend_ps(kv_mask, neg_inf, s0); }
1008
- else { s0 = neg_inf; }
1009
- _mm512_store_ps(&scores[qi * 32], s0);
1007
+ __m512 s0_f32x16 = _mm512_load_ps(&s_tile[qi][0]);
1008
+ if (qi < valid_q) { s0_f32x16 = _mm512_mask_blend_ps(kv_mask, neg_inf_f32x16, s0_f32x16); }
1009
+ else { s0_f32x16 = neg_inf_f32x16; }
1010
+ _mm512_store_ps(&scores[qi * 32], s0_f32x16);
1010
1011
  }
1011
1012
  }
1012
1013
 
@@ -1018,36 +1019,36 @@ NK_PUBLIC void nk_attention_bf16_amx_bc32_sapphireamx(nk_bf16_t const *q, void c
1018
1019
  if (valid_q == 16 && valid_kv2 >= 16) {
1019
1020
  // Fast path
1020
1021
  for (nk_size_t qi = 0; qi < 16; qi++) {
1021
- __m512 s1 = _mm512_load_ps(&s_tile[qi][0]);
1022
- _mm512_store_ps(&scores[qi * 32 + 16], s1);
1022
+ __m512 s1_f32x16 = _mm512_load_ps(&s_tile[qi][0]);
1023
+ _mm512_store_ps(&scores[qi * 32 + 16], s1_f32x16);
1023
1024
  }
1024
1025
  }
1025
1026
  else {
1026
1027
  __mmask16 kv_mask2 = (valid_kv2 >= 16) ? 0xFFFF : ((1u << valid_kv2) - 1);
1027
1028
  for (nk_size_t qi = 0; qi < 16; qi++) {
1028
- __m512 s1 = _mm512_load_ps(&s_tile[qi][0]);
1029
- if (qi < valid_q) { s1 = _mm512_mask_blend_ps(kv_mask2, neg_inf, s1); }
1030
- else { s1 = neg_inf; }
1031
- _mm512_store_ps(&scores[qi * 32 + 16], s1);
1029
+ __m512 s1_f32x16 = _mm512_load_ps(&s_tile[qi][0]);
1030
+ if (qi < valid_q) { s1_f32x16 = _mm512_mask_blend_ps(kv_mask2, neg_inf_f32x16, s1_f32x16); }
1031
+ else { s1_f32x16 = neg_inf_f32x16; }
1032
+ _mm512_store_ps(&scores[qi * 32 + 16], s1_f32x16);
1032
1033
  }
1033
1034
  }
1034
1035
  }
1035
1036
  else {
1036
1037
  // Mask out second half entirely
1037
- for (nk_size_t qi = 0; qi < 16; qi++) { _mm512_store_ps(&scores[qi * 32 + 16], neg_inf); }
1038
+ for (nk_size_t qi = 0; qi < 16; qi++) { _mm512_store_ps(&scores[qi * 32 + 16], neg_inf_f32x16); }
1038
1039
  }
1039
1040
 
1040
1041
  // Phase 2: online softmax (fast degree-4 exp)
1041
- __m512 old_max = softmax_state.row_max;
1042
+ __m512 old_max_f32x16 = softmax_state.row_max_f32x16;
1042
1043
  nk_attention_softmax_update_bc32_fast_(&softmax_state, scores, scale, weights);
1043
- nk_attention_rescale_output_(o_acc, head_dim_padded, old_max, softmax_state.row_max);
1044
+ nk_attention_rescale_output_(o_acc, head_dim_padded, old_max_f32x16, softmax_state.row_max_f32x16);
1044
1045
 
1045
1046
  // Phase 3: O += P × V using AMX
1046
1047
  // Convert P[16, 32] from F32 to BF16 and pack as A-tile
1047
1048
  for (nk_size_t qi = 0; qi < 16; qi++) {
1048
1049
  for (nk_size_t ki = 0; ki < 32; ki += 16) {
1049
- __m512 p_f32 = _mm512_loadu_ps(&weights[qi * 32 + ki]);
1050
- __m256bh p_bf16 = _mm512_cvtneps_pbh(p_f32);
1050
+ __m512 p_f32_f32x16 = _mm512_loadu_ps(&weights[qi * 32 + ki]);
1051
+ __m256bh p_bf16 = _mm512_cvtneps_pbh(p_f32_f32x16);
1051
1052
  // Store BF16 vector - cast through union or memory
1052
1053
  *(__m256bh *)&p_tile[qi][ki] = p_bf16;
1053
1054
  }
@@ -1079,29 +1080,29 @@ NK_PUBLIC void nk_attention_bf16_amx_bc32_sapphireamx(nk_bf16_t const *q, void c
1079
1080
  _tile_stored(5, o_tile, 64);
1080
1081
 
1081
1082
  // Add to output accumulator - unrolled for all 16 rows
1082
- // Even if valid_q < 16, we accumulate all (padded rows have zero weights)
1083
+ // Even if valid_q < 16, we accumulate all 16 rows; padded rows are ignored when the output is normalized
1083
1084
  for (nk_size_t qi = 0; qi < 16; qi += 4) {
1084
- __m512 acc0 = _mm512_load_ps(&o_acc[(qi + 0) * head_dim_padded + head_start]);
1085
- __m512 acc1 = _mm512_load_ps(&o_acc[(qi + 1) * head_dim_padded + head_start]);
1086
- __m512 acc2 = _mm512_load_ps(&o_acc[(qi + 2) * head_dim_padded + head_start]);
1087
- __m512 acc3 = _mm512_load_ps(&o_acc[(qi + 3) * head_dim_padded + head_start]);
1088
-
1089
- acc0 = _mm512_add_ps(acc0, _mm512_load_ps(&o_tile[qi + 0][0]));
1090
- acc1 = _mm512_add_ps(acc1, _mm512_load_ps(&o_tile[qi + 1][0]));
1091
- acc2 = _mm512_add_ps(acc2, _mm512_load_ps(&o_tile[qi + 2][0]));
1092
- acc3 = _mm512_add_ps(acc3, _mm512_load_ps(&o_tile[qi + 3][0]));
1093
-
1094
- _mm512_store_ps(&o_acc[(qi + 0) * head_dim_padded + head_start], acc0);
1095
- _mm512_store_ps(&o_acc[(qi + 1) * head_dim_padded + head_start], acc1);
1096
- _mm512_store_ps(&o_acc[(qi + 2) * head_dim_padded + head_start], acc2);
1097
- _mm512_store_ps(&o_acc[(qi + 3) * head_dim_padded + head_start], acc3);
1085
+ __m512 acc0_f32x16 = _mm512_load_ps(&o_acc[(qi + 0) * head_dim_padded + head_start]);
1086
+ __m512 acc1_f32x16 = _mm512_load_ps(&o_acc[(qi + 1) * head_dim_padded + head_start]);
1087
+ __m512 acc2_f32x16 = _mm512_load_ps(&o_acc[(qi + 2) * head_dim_padded + head_start]);
1088
+ __m512 acc3_f32x16 = _mm512_load_ps(&o_acc[(qi + 3) * head_dim_padded + head_start]);
1089
+
1090
+ acc0_f32x16 = _mm512_add_ps(acc0_f32x16, _mm512_load_ps(&o_tile[qi + 0][0]));
1091
+ acc1_f32x16 = _mm512_add_ps(acc1_f32x16, _mm512_load_ps(&o_tile[qi + 1][0]));
1092
+ acc2_f32x16 = _mm512_add_ps(acc2_f32x16, _mm512_load_ps(&o_tile[qi + 2][0]));
1093
+ acc3_f32x16 = _mm512_add_ps(acc3_f32x16, _mm512_load_ps(&o_tile[qi + 3][0]));
1094
+
1095
+ _mm512_store_ps(&o_acc[(qi + 0) * head_dim_padded + head_start], acc0_f32x16);
1096
+ _mm512_store_ps(&o_acc[(qi + 1) * head_dim_padded + head_start], acc1_f32x16);
1097
+ _mm512_store_ps(&o_acc[(qi + 2) * head_dim_padded + head_start], acc2_f32x16);
1098
+ _mm512_store_ps(&o_acc[(qi + 3) * head_dim_padded + head_start], acc3_f32x16);
1098
1099
  }
1099
1100
  }
1100
1101
  }
1101
1102
 
1102
1103
  // Finalize: normalize O by row sums
1103
- float row_sums[16];
1104
- _mm512_store_ps(row_sums, softmax_state.row_sum);
1104
+ nk_f32_t row_sums[16];
1105
+ _mm512_store_ps(row_sums, softmax_state.row_sum_f32x16);
1105
1106
  for (nk_size_t qi = 0; qi < valid_q; qi++) {
1106
1107
  nk_f32_t inv_sum = 1.0f / row_sums[qi];
1107
1108
  for (nk_size_t d = 0; d < head_dim; d++) {
@@ -1149,7 +1150,7 @@ NK_PUBLIC void nk_attention_bf16_amx_optimized_sapphireamx(nk_bf16_t const *q, v
1149
1150
  NK_ALIGN64 nk_f32_t o_tile[16][16]; // Output tile buffer
1150
1151
  NK_ALIGN64 nk_f32_t o_acc[16][256]; // Output accumulator (max d=256)
1151
1152
 
1152
- __m512 neg_inf = _mm512_set1_ps(NK_F32_MIN);
1153
+ __m512 neg_inf_f32x16 = _mm512_set1_ps(NK_F32_MIN);
1153
1154
 
1154
1155
  for (nk_size_t h = 0; h < num_heads; h++) {
1155
1156
  nk_size_t kv_h = h / gqa_ratio;
@@ -1169,10 +1170,10 @@ NK_PUBLIC void nk_attention_bf16_amx_optimized_sapphireamx(nk_bf16_t const *q, v
1169
1170
  // Full tile - fast SIMD copy
1170
1171
  for (nk_size_t row = 0; row < valid_q; row++) {
1171
1172
  nk_bf16_t const *q_row = q_head + (qb + row) * head_dim + depth_start;
1172
- __m256i q0 = _mm256_loadu_si256((__m256i const *)q_row);
1173
- __m256i q1 = _mm256_loadu_si256((__m256i const *)(q_row + 16));
1174
- _mm256_store_si256((__m256i *)&q_tiles[dt][row][0], q0);
1175
- _mm256_store_si256((__m256i *)&q_tiles[dt][row][16], q1);
1173
+ __m256i q0_bf16x16 = _mm256_loadu_si256((__m256i const *)q_row);
1174
+ __m256i q1_bf16x16 = _mm256_loadu_si256((__m256i const *)(q_row + 16));
1175
+ _mm256_store_si256((__m256i *)&q_tiles[dt][row][0], q0_bf16x16);
1176
+ _mm256_store_si256((__m256i *)&q_tiles[dt][row][16], q1_bf16x16);
1176
1177
  }
1177
1178
  // Zero remaining rows
1178
1179
  for (nk_size_t row = valid_q; row < 16; row++) {
@@ -1198,12 +1199,12 @@ NK_PUBLIC void nk_attention_bf16_amx_optimized_sapphireamx(nk_bf16_t const *q, v
1198
1199
  nk_attention_softmax_row_state_t softmax_state;
1199
1200
  nk_attention_softmax_init_(&softmax_state);
1200
1201
 
1201
- __m512 zero = _mm512_setzero_ps();
1202
+ __m512 zero_f32x16 = _mm512_setzero_ps();
1202
1203
  for (nk_size_t i = 0; i < 16 * head_dim_padded; i += 64) {
1203
- _mm512_store_ps(&o_acc[0][i], zero);
1204
- _mm512_store_ps(&o_acc[0][i + 16], zero);
1205
- _mm512_store_ps(&o_acc[0][i + 32], zero);
1206
- _mm512_store_ps(&o_acc[0][i + 48], zero);
1204
+ _mm512_store_ps(&o_acc[0][i], zero_f32x16);
1205
+ _mm512_store_ps(&o_acc[0][i + 16], zero_f32x16);
1206
+ _mm512_store_ps(&o_acc[0][i + 32], zero_f32x16);
1207
+ _mm512_store_ps(&o_acc[0][i + 48], zero_f32x16);
1207
1208
  }
1208
1209
 
1209
1210
  // Process KV blocks
@@ -1239,7 +1240,7 @@ NK_PUBLIC void nk_attention_bf16_amx_optimized_sapphireamx(nk_bf16_t const *q, v
1239
1240
  if (kvb + 16 < kv_len) { _tile_stored(3, &scores[0][16], 128); }
1240
1241
  else {
1241
1242
  // Mask out second half
1242
- for (nk_size_t qi = 0; qi < 16; qi++) { _mm512_store_ps(&scores[qi][16], neg_inf); }
1243
+ for (nk_size_t qi = 0; qi < 16; qi++) { _mm512_store_ps(&scores[qi][16], neg_inf_f32x16); }
1243
1244
  }
1244
1245
 
1245
1246
  // Apply masking for invalid positions (only on boundaries)
@@ -1250,30 +1251,31 @@ NK_PUBLIC void nk_attention_bf16_amx_optimized_sapphireamx(nk_bf16_t const *q, v
1250
1251
 
1251
1252
  for (nk_size_t qi = 0; qi < 16; qi++) {
1252
1253
  if (qi >= valid_q) {
1253
- _mm512_store_ps(&scores[qi][0], neg_inf);
1254
- _mm512_store_ps(&scores[qi][16], neg_inf);
1254
+ _mm512_store_ps(&scores[qi][0], neg_inf_f32x16);
1255
+ _mm512_store_ps(&scores[qi][16], neg_inf_f32x16);
1255
1256
  }
1256
1257
  else {
1257
- __m512 s0 = _mm512_load_ps(&scores[qi][0]);
1258
- __m512 s1 = _mm512_load_ps(&scores[qi][16]);
1259
- _mm512_store_ps(&scores[qi][0], _mm512_mask_blend_ps(kv_mask0, neg_inf, s0));
1260
- _mm512_store_ps(&scores[qi][16], _mm512_mask_blend_ps(kv_mask1, neg_inf, s1));
1258
+ __m512 s0_f32x16 = _mm512_load_ps(&scores[qi][0]);
1259
+ __m512 s1_f32x16 = _mm512_load_ps(&scores[qi][16]);
1260
+ _mm512_store_ps(&scores[qi][0], _mm512_mask_blend_ps(kv_mask0, neg_inf_f32x16, s0_f32x16));
1261
+ _mm512_store_ps(&scores[qi][16], _mm512_mask_blend_ps(kv_mask1, neg_inf_f32x16, s1_f32x16));
1261
1262
  }
1262
1263
  }
1263
1264
  }
1264
1265
 
1265
1266
  // Phase 2: online softmax (fast degree-4 exp)
1266
- __m512 old_max = softmax_state.row_max;
1267
+ __m512 old_max_f32x16 = softmax_state.row_max_f32x16;
1267
1268
  nk_attention_softmax_update_bc32_fast_(&softmax_state, &scores[0][0], scale, &weights[0][0]);
1268
- nk_attention_rescale_output_(&o_acc[0][0], head_dim_padded, old_max, softmax_state.row_max);
1269
+ nk_attention_rescale_output_(&o_acc[0][0], head_dim_padded, old_max_f32x16,
1270
+ softmax_state.row_max_f32x16);
1269
1271
 
1270
1272
  // Phase 3: O += P × V with hoisted P tile load
1271
1273
  // Convert F32 weights to BF16 P tile (once per KV block)
1272
1274
  for (nk_size_t qi = 0; qi < 16; qi++) {
1273
- __m512 p0 = _mm512_load_ps(&weights[qi][0]);
1274
- __m512 p1 = _mm512_load_ps(&weights[qi][16]);
1275
- __m256bh pb0 = _mm512_cvtneps_pbh(p0);
1276
- __m256bh pb1 = _mm512_cvtneps_pbh(p1);
1275
+ __m512 p0_f32x16 = _mm512_load_ps(&weights[qi][0]);
1276
+ __m512 p1_f32x16 = _mm512_load_ps(&weights[qi][16]);
1277
+ __m256bh pb0 = _mm512_cvtneps_pbh(p0_f32x16);
1278
+ __m256bh pb1 = _mm512_cvtneps_pbh(p1_f32x16);
1277
1279
  *(__m256bh *)&p_tile[qi][0] = pb0;
1278
1280
  *(__m256bh *)&p_tile[qi][16] = pb1;
1279
1281
  }
@@ -1299,33 +1301,33 @@ NK_PUBLIC void nk_attention_bf16_amx_optimized_sapphireamx(nk_bf16_t const *q, v
1299
1301
 
1300
1302
  // Accumulate into output (unrolled)
1301
1303
  for (nk_size_t qi = 0; qi < 16; qi += 4) {
1302
- __m512 acc0 = _mm512_load_ps(&o_acc[qi + 0][head_start]);
1303
- __m512 acc1 = _mm512_load_ps(&o_acc[qi + 1][head_start]);
1304
- __m512 acc2 = _mm512_load_ps(&o_acc[qi + 2][head_start]);
1305
- __m512 acc3 = _mm512_load_ps(&o_acc[qi + 3][head_start]);
1306
-
1307
- acc0 = _mm512_add_ps(acc0, _mm512_load_ps(&o_tile[qi + 0][0]));
1308
- acc1 = _mm512_add_ps(acc1, _mm512_load_ps(&o_tile[qi + 1][0]));
1309
- acc2 = _mm512_add_ps(acc2, _mm512_load_ps(&o_tile[qi + 2][0]));
1310
- acc3 = _mm512_add_ps(acc3, _mm512_load_ps(&o_tile[qi + 3][0]));
1311
-
1312
- _mm512_store_ps(&o_acc[qi + 0][head_start], acc0);
1313
- _mm512_store_ps(&o_acc[qi + 1][head_start], acc1);
1314
- _mm512_store_ps(&o_acc[qi + 2][head_start], acc2);
1315
- _mm512_store_ps(&o_acc[qi + 3][head_start], acc3);
1304
+ __m512 acc0_f32x16 = _mm512_load_ps(&o_acc[qi + 0][head_start]);
1305
+ __m512 acc1_f32x16 = _mm512_load_ps(&o_acc[qi + 1][head_start]);
1306
+ __m512 acc2_f32x16 = _mm512_load_ps(&o_acc[qi + 2][head_start]);
1307
+ __m512 acc3_f32x16 = _mm512_load_ps(&o_acc[qi + 3][head_start]);
1308
+
1309
+ acc0_f32x16 = _mm512_add_ps(acc0_f32x16, _mm512_load_ps(&o_tile[qi + 0][0]));
1310
+ acc1_f32x16 = _mm512_add_ps(acc1_f32x16, _mm512_load_ps(&o_tile[qi + 1][0]));
1311
+ acc2_f32x16 = _mm512_add_ps(acc2_f32x16, _mm512_load_ps(&o_tile[qi + 2][0]));
1312
+ acc3_f32x16 = _mm512_add_ps(acc3_f32x16, _mm512_load_ps(&o_tile[qi + 3][0]));
1313
+
1314
+ _mm512_store_ps(&o_acc[qi + 0][head_start], acc0_f32x16);
1315
+ _mm512_store_ps(&o_acc[qi + 1][head_start], acc1_f32x16);
1316
+ _mm512_store_ps(&o_acc[qi + 2][head_start], acc2_f32x16);
1317
+ _mm512_store_ps(&o_acc[qi + 3][head_start], acc3_f32x16);
1316
1318
  }
1317
1319
  }
1318
1320
  }
1319
1321
 
1320
1322
  // Finalize: normalize O by row sums
1321
- float row_sums[16];
1322
- _mm512_store_ps(row_sums, softmax_state.row_sum);
1323
+ nk_f32_t row_sums[16];
1324
+ _mm512_store_ps(row_sums, softmax_state.row_sum_f32x16);
1323
1325
  for (nk_size_t qi = 0; qi < valid_q; qi++) {
1324
- __m512 inv_sum = _mm512_set1_ps(1.0f / row_sums[qi]);
1326
+ __m512 inv_sum_f32x16 = _mm512_set1_ps(1.0f / row_sums[qi]);
1325
1327
  for (nk_size_t d = 0; d < head_dim; d += 16) {
1326
- __m512 o = _mm512_load_ps(&o_acc[qi][d]);
1327
- o = _mm512_mul_ps(o, inv_sum);
1328
- _mm512_storeu_ps(&o_head[(qb + qi) * head_dim + d], o);
1328
+ __m512 o_f32x16 = _mm512_load_ps(&o_acc[qi][d]);
1329
+ o_f32x16 = _mm512_mul_ps(o_f32x16, inv_sum_f32x16);
1330
+ _mm512_storeu_ps(&o_head[(qb + qi) * head_dim + d], o_f32x16);
1329
1331
  }
1330
1332
  }
1331
1333
  }