numkong 7.0.0 → 7.4.2

This diff shows the differences between publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (315)
  1. package/README.md +197 -124
  2. package/binding.gyp +34 -484
  3. package/c/dispatch_bf16.c +59 -1
  4. package/c/dispatch_e2m3.c +41 -8
  5. package/c/dispatch_e3m2.c +49 -8
  6. package/c/dispatch_e4m3.c +51 -9
  7. package/c/dispatch_e5m2.c +45 -1
  8. package/c/dispatch_f16.c +79 -26
  9. package/c/dispatch_f16c.c +5 -5
  10. package/c/dispatch_f32.c +56 -0
  11. package/c/dispatch_f64.c +52 -0
  12. package/c/dispatch_i4.c +3 -0
  13. package/c/dispatch_i8.c +62 -3
  14. package/c/dispatch_other.c +18 -0
  15. package/c/dispatch_u1.c +54 -9
  16. package/c/dispatch_u4.c +3 -0
  17. package/c/dispatch_u8.c +64 -3
  18. package/c/numkong.c +3 -0
  19. package/include/README.md +79 -9
  20. package/include/numkong/attention/sapphireamx.h +278 -276
  21. package/include/numkong/attention/sme.h +983 -977
  22. package/include/numkong/attention.h +1 -1
  23. package/include/numkong/capabilities.h +289 -94
  24. package/include/numkong/cast/README.md +40 -40
  25. package/include/numkong/cast/diamond.h +64 -0
  26. package/include/numkong/cast/haswell.h +42 -194
  27. package/include/numkong/cast/icelake.h +42 -37
  28. package/include/numkong/cast/loongsonasx.h +252 -0
  29. package/include/numkong/cast/neon.h +216 -249
  30. package/include/numkong/cast/powervsx.h +449 -0
  31. package/include/numkong/cast/rvv.h +223 -274
  32. package/include/numkong/cast/sapphire.h +18 -18
  33. package/include/numkong/cast/serial.h +1018 -944
  34. package/include/numkong/cast/skylake.h +82 -23
  35. package/include/numkong/cast/v128relaxed.h +462 -105
  36. package/include/numkong/cast.h +24 -0
  37. package/include/numkong/cast.hpp +44 -0
  38. package/include/numkong/curved/README.md +17 -17
  39. package/include/numkong/curved/neon.h +131 -7
  40. package/include/numkong/curved/neonbfdot.h +6 -7
  41. package/include/numkong/curved/rvv.h +26 -26
  42. package/include/numkong/curved/smef64.h +186 -182
  43. package/include/numkong/curved.h +14 -18
  44. package/include/numkong/dot/README.md +154 -137
  45. package/include/numkong/dot/alder.h +43 -43
  46. package/include/numkong/dot/diamond.h +158 -0
  47. package/include/numkong/dot/genoa.h +4 -30
  48. package/include/numkong/dot/haswell.h +215 -180
  49. package/include/numkong/dot/icelake.h +190 -76
  50. package/include/numkong/dot/loongsonasx.h +671 -0
  51. package/include/numkong/dot/neon.h +124 -73
  52. package/include/numkong/dot/neonbfdot.h +11 -12
  53. package/include/numkong/dot/neonfhm.h +44 -46
  54. package/include/numkong/dot/neonfp8.h +323 -0
  55. package/include/numkong/dot/neonsdot.h +190 -76
  56. package/include/numkong/dot/powervsx.h +752 -0
  57. package/include/numkong/dot/rvv.h +92 -84
  58. package/include/numkong/dot/rvvbf16.h +12 -12
  59. package/include/numkong/dot/rvvhalf.h +12 -12
  60. package/include/numkong/dot/sapphire.h +4 -4
  61. package/include/numkong/dot/serial.h +66 -30
  62. package/include/numkong/dot/sierra.h +31 -31
  63. package/include/numkong/dot/skylake.h +142 -110
  64. package/include/numkong/dot/sve.h +217 -177
  65. package/include/numkong/dot/svebfdot.h +10 -10
  66. package/include/numkong/dot/svehalf.h +85 -41
  67. package/include/numkong/dot/svesdot.h +89 -0
  68. package/include/numkong/dot/v128relaxed.h +124 -89
  69. package/include/numkong/dot.h +114 -48
  70. package/include/numkong/dots/README.md +203 -203
  71. package/include/numkong/dots/alder.h +12 -9
  72. package/include/numkong/dots/diamond.h +86 -0
  73. package/include/numkong/dots/genoa.h +10 -4
  74. package/include/numkong/dots/haswell.h +63 -48
  75. package/include/numkong/dots/icelake.h +27 -18
  76. package/include/numkong/dots/loongsonasx.h +176 -0
  77. package/include/numkong/dots/neon.h +14 -11
  78. package/include/numkong/dots/neonbfdot.h +4 -3
  79. package/include/numkong/dots/neonfhm.h +11 -9
  80. package/include/numkong/dots/neonfp8.h +99 -0
  81. package/include/numkong/dots/neonsdot.h +48 -12
  82. package/include/numkong/dots/powervsx.h +194 -0
  83. package/include/numkong/dots/rvv.h +451 -344
  84. package/include/numkong/dots/sapphireamx.h +1028 -984
  85. package/include/numkong/dots/serial.h +213 -197
  86. package/include/numkong/dots/sierra.h +10 -7
  87. package/include/numkong/dots/skylake.h +47 -36
  88. package/include/numkong/dots/sme.h +2001 -2364
  89. package/include/numkong/dots/smebi32.h +175 -162
  90. package/include/numkong/dots/smef64.h +328 -323
  91. package/include/numkong/dots/v128relaxed.h +64 -41
  92. package/include/numkong/dots.h +573 -293
  93. package/include/numkong/dots.hpp +45 -43
  94. package/include/numkong/each/README.md +133 -137
  95. package/include/numkong/each/haswell.h +6 -6
  96. package/include/numkong/each/icelake.h +7 -7
  97. package/include/numkong/each/neon.h +76 -42
  98. package/include/numkong/each/neonbfdot.h +11 -12
  99. package/include/numkong/each/neonhalf.h +24 -116
  100. package/include/numkong/each/rvv.h +28 -28
  101. package/include/numkong/each/sapphire.h +27 -161
  102. package/include/numkong/each/serial.h +6 -6
  103. package/include/numkong/each/skylake.h +7 -7
  104. package/include/numkong/each/v128relaxed.h +562 -0
  105. package/include/numkong/each.h +148 -62
  106. package/include/numkong/each.hpp +2 -2
  107. package/include/numkong/geospatial/README.md +18 -18
  108. package/include/numkong/geospatial/haswell.h +365 -325
  109. package/include/numkong/geospatial/neon.h +350 -306
  110. package/include/numkong/geospatial/rvv.h +4 -4
  111. package/include/numkong/geospatial/skylake.h +376 -340
  112. package/include/numkong/geospatial/v128relaxed.h +366 -327
  113. package/include/numkong/geospatial.h +17 -17
  114. package/include/numkong/matrix.hpp +4 -4
  115. package/include/numkong/maxsim/README.md +14 -14
  116. package/include/numkong/maxsim/alder.h +6 -6
  117. package/include/numkong/maxsim/genoa.h +4 -4
  118. package/include/numkong/maxsim/haswell.h +6 -6
  119. package/include/numkong/maxsim/icelake.h +18 -18
  120. package/include/numkong/maxsim/neonsdot.h +21 -21
  121. package/include/numkong/maxsim/sapphireamx.h +14 -14
  122. package/include/numkong/maxsim/serial.h +6 -6
  123. package/include/numkong/maxsim/sme.h +221 -196
  124. package/include/numkong/maxsim/v128relaxed.h +6 -6
  125. package/include/numkong/mesh/README.md +62 -56
  126. package/include/numkong/mesh/haswell.h +339 -464
  127. package/include/numkong/mesh/neon.h +1100 -519
  128. package/include/numkong/mesh/neonbfdot.h +36 -68
  129. package/include/numkong/mesh/rvv.h +530 -435
  130. package/include/numkong/mesh/serial.h +75 -91
  131. package/include/numkong/mesh/skylake.h +1627 -302
  132. package/include/numkong/mesh/v128relaxed.h +443 -330
  133. package/include/numkong/mesh.h +63 -49
  134. package/include/numkong/mesh.hpp +4 -4
  135. package/include/numkong/numkong.h +3 -3
  136. package/include/numkong/numkong.hpp +1 -0
  137. package/include/numkong/probability/README.md +23 -19
  138. package/include/numkong/probability/neon.h +82 -52
  139. package/include/numkong/probability/rvv.h +28 -23
  140. package/include/numkong/probability/serial.h +51 -39
  141. package/include/numkong/probability.h +20 -23
  142. package/include/numkong/random.h +1 -1
  143. package/include/numkong/reduce/README.md +143 -138
  144. package/include/numkong/reduce/alder.h +81 -77
  145. package/include/numkong/reduce/haswell.h +222 -220
  146. package/include/numkong/reduce/neon.h +629 -519
  147. package/include/numkong/reduce/neonbfdot.h +7 -218
  148. package/include/numkong/reduce/neonfhm.h +9 -381
  149. package/include/numkong/reduce/neonsdot.h +9 -9
  150. package/include/numkong/reduce/rvv.h +928 -802
  151. package/include/numkong/reduce/serial.h +23 -27
  152. package/include/numkong/reduce/sierra.h +20 -20
  153. package/include/numkong/reduce/skylake.h +326 -324
  154. package/include/numkong/reduce/v128relaxed.h +52 -52
  155. package/include/numkong/reduce.h +4 -23
  156. package/include/numkong/reduce.hpp +156 -11
  157. package/include/numkong/scalar/README.md +6 -6
  158. package/include/numkong/scalar/haswell.h +26 -17
  159. package/include/numkong/scalar/loongsonasx.h +74 -0
  160. package/include/numkong/scalar/neon.h +9 -9
  161. package/include/numkong/scalar/powervsx.h +96 -0
  162. package/include/numkong/scalar/rvv.h +2 -2
  163. package/include/numkong/scalar/sapphire.h +21 -10
  164. package/include/numkong/scalar/serial.h +21 -21
  165. package/include/numkong/scalar.h +13 -0
  166. package/include/numkong/set/README.md +28 -28
  167. package/include/numkong/set/haswell.h +12 -12
  168. package/include/numkong/set/icelake.h +14 -14
  169. package/include/numkong/set/loongsonasx.h +181 -0
  170. package/include/numkong/set/neon.h +17 -18
  171. package/include/numkong/set/powervsx.h +326 -0
  172. package/include/numkong/set/rvv.h +4 -4
  173. package/include/numkong/set/serial.h +6 -6
  174. package/include/numkong/set/sve.h +60 -59
  175. package/include/numkong/set/v128relaxed.h +6 -6
  176. package/include/numkong/set.h +21 -7
  177. package/include/numkong/sets/README.md +26 -26
  178. package/include/numkong/sets/loongsonasx.h +52 -0
  179. package/include/numkong/sets/powervsx.h +65 -0
  180. package/include/numkong/sets/smebi32.h +395 -364
  181. package/include/numkong/sets.h +83 -40
  182. package/include/numkong/sparse/README.md +4 -4
  183. package/include/numkong/sparse/icelake.h +101 -101
  184. package/include/numkong/sparse/serial.h +1 -1
  185. package/include/numkong/sparse/sve2.h +137 -141
  186. package/include/numkong/sparse/turin.h +12 -12
  187. package/include/numkong/sparse.h +10 -10
  188. package/include/numkong/spatial/README.md +230 -226
  189. package/include/numkong/spatial/alder.h +113 -116
  190. package/include/numkong/spatial/diamond.h +240 -0
  191. package/include/numkong/spatial/genoa.h +0 -68
  192. package/include/numkong/spatial/haswell.h +74 -55
  193. package/include/numkong/spatial/icelake.h +539 -58
  194. package/include/numkong/spatial/loongsonasx.h +483 -0
  195. package/include/numkong/spatial/neon.h +125 -52
  196. package/include/numkong/spatial/neonbfdot.h +8 -9
  197. package/include/numkong/spatial/neonfp8.h +258 -0
  198. package/include/numkong/spatial/neonsdot.h +180 -12
  199. package/include/numkong/spatial/powervsx.h +738 -0
  200. package/include/numkong/spatial/rvv.h +146 -139
  201. package/include/numkong/spatial/rvvbf16.h +17 -12
  202. package/include/numkong/spatial/rvvhalf.h +13 -10
  203. package/include/numkong/spatial/serial.h +13 -12
  204. package/include/numkong/spatial/sierra.h +232 -39
  205. package/include/numkong/spatial/skylake.h +73 -74
  206. package/include/numkong/spatial/sve.h +93 -72
  207. package/include/numkong/spatial/svebfdot.h +29 -29
  208. package/include/numkong/spatial/svehalf.h +52 -26
  209. package/include/numkong/spatial/svesdot.h +142 -0
  210. package/include/numkong/spatial/v128relaxed.h +293 -41
  211. package/include/numkong/spatial.h +338 -82
  212. package/include/numkong/spatials/README.md +194 -194
  213. package/include/numkong/spatials/diamond.h +82 -0
  214. package/include/numkong/spatials/haswell.h +2 -2
  215. package/include/numkong/spatials/loongsonasx.h +153 -0
  216. package/include/numkong/spatials/neonfp8.h +111 -0
  217. package/include/numkong/spatials/neonsdot.h +34 -0
  218. package/include/numkong/spatials/powervsx.h +153 -0
  219. package/include/numkong/spatials/rvv.h +259 -243
  220. package/include/numkong/spatials/sapphireamx.h +173 -173
  221. package/include/numkong/spatials/serial.h +2 -2
  222. package/include/numkong/spatials/skylake.h +2 -2
  223. package/include/numkong/spatials/sme.h +590 -605
  224. package/include/numkong/spatials/smef64.h +139 -130
  225. package/include/numkong/spatials/v128relaxed.h +2 -2
  226. package/include/numkong/spatials.h +820 -500
  227. package/include/numkong/spatials.hpp +49 -48
  228. package/include/numkong/tensor.hpp +406 -17
  229. package/include/numkong/trigonometry/README.md +19 -19
  230. package/include/numkong/trigonometry/haswell.h +402 -401
  231. package/include/numkong/trigonometry/neon.h +386 -387
  232. package/include/numkong/trigonometry/rvv.h +52 -51
  233. package/include/numkong/trigonometry/serial.h +13 -13
  234. package/include/numkong/trigonometry/skylake.h +373 -369
  235. package/include/numkong/trigonometry/v128relaxed.h +375 -374
  236. package/include/numkong/trigonometry.h +13 -13
  237. package/include/numkong/trigonometry.hpp +2 -2
  238. package/include/numkong/types.h +287 -49
  239. package/include/numkong/types.hpp +436 -12
  240. package/include/numkong/vector.hpp +82 -14
  241. package/javascript/dist/cjs/numkong-wasm.js +6 -12
  242. package/javascript/dist/cjs/numkong.d.ts +7 -1
  243. package/javascript/dist/cjs/numkong.js +37 -11
  244. package/javascript/dist/cjs/types.d.ts +9 -0
  245. package/javascript/dist/cjs/types.js +96 -0
  246. package/javascript/dist/esm/numkong-browser.d.ts +14 -0
  247. package/javascript/dist/esm/numkong-browser.js +23 -0
  248. package/javascript/dist/esm/numkong-wasm.js +6 -12
  249. package/javascript/dist/esm/numkong.d.ts +7 -1
  250. package/javascript/dist/esm/numkong.js +37 -11
  251. package/javascript/dist/esm/types.d.ts +9 -0
  252. package/javascript/dist/esm/types.js +96 -0
  253. package/javascript/node-gyp-build.d.ts +4 -1
  254. package/javascript/numkong-browser.ts +40 -0
  255. package/javascript/numkong-wasm.ts +7 -13
  256. package/javascript/numkong.c +5 -26
  257. package/javascript/numkong.ts +36 -11
  258. package/javascript/tsconfig-base.json +1 -0
  259. package/javascript/tsconfig-cjs.json +6 -1
  260. package/javascript/types.ts +110 -0
  261. package/numkong.gypi +101 -0
  262. package/package.json +34 -13
  263. package/probes/arm_neon.c +8 -0
  264. package/probes/arm_neon_bfdot.c +9 -0
  265. package/probes/arm_neon_fhm.c +9 -0
  266. package/probes/arm_neon_half.c +8 -0
  267. package/probes/arm_neon_sdot.c +9 -0
  268. package/probes/arm_neonfp8.c +9 -0
  269. package/probes/arm_sme.c +16 -0
  270. package/probes/arm_sme2.c +16 -0
  271. package/probes/arm_sme2p1.c +16 -0
  272. package/probes/arm_sme_bf16.c +16 -0
  273. package/probes/arm_sme_bi32.c +16 -0
  274. package/probes/arm_sme_f64.c +16 -0
  275. package/probes/arm_sme_fa64.c +14 -0
  276. package/probes/arm_sme_half.c +16 -0
  277. package/probes/arm_sme_lut2.c +15 -0
  278. package/probes/arm_sve.c +18 -0
  279. package/probes/arm_sve2.c +20 -0
  280. package/probes/arm_sve2p1.c +18 -0
  281. package/probes/arm_sve_bfdot.c +20 -0
  282. package/probes/arm_sve_half.c +18 -0
  283. package/probes/arm_sve_sdot.c +21 -0
  284. package/probes/loongarch_lasx.c +12 -0
  285. package/probes/power_vsx.c +12 -0
  286. package/probes/probe.js +127 -0
  287. package/probes/riscv_rvv.c +14 -0
  288. package/probes/riscv_rvv_bb.c +15 -0
  289. package/probes/riscv_rvv_bf16.c +17 -0
  290. package/probes/riscv_rvv_half.c +14 -0
  291. package/probes/wasm_v128relaxed.c +11 -0
  292. package/probes/x86_alder.c +17 -0
  293. package/probes/x86_diamond.c +17 -0
  294. package/probes/x86_genoa.c +17 -0
  295. package/probes/x86_graniteamx.c +19 -0
  296. package/probes/x86_haswell.c +11 -0
  297. package/probes/x86_icelake.c +17 -0
  298. package/probes/x86_sapphire.c +16 -0
  299. package/probes/x86_sapphireamx.c +18 -0
  300. package/probes/x86_sierra.c +17 -0
  301. package/probes/x86_skylake.c +15 -0
  302. package/probes/x86_turin.c +17 -0
  303. package/wasm/numkong-emscripten.js +2 -0
  304. package/wasm/numkong.d.ts +14 -0
  305. package/wasm/numkong.js +1124 -0
  306. package/wasm/numkong.wasm +0 -0
  307. package/include/numkong/curved/neonhalf.h +0 -212
  308. package/include/numkong/dot/neonhalf.h +0 -198
  309. package/include/numkong/dots/neonhalf.h +0 -57
  310. package/include/numkong/mesh/neonhalf.h +0 -616
  311. package/include/numkong/reduce/neonhalf.h +0 -157
  312. package/include/numkong/spatial/neonhalf.h +0 -118
  313. package/include/numkong/spatial/sapphire.h +0 -343
  314. package/include/numkong/spatials/neonhalf.h +0 -58
  315. package/javascript/README.md +0 -246
@@ -214,8 +214,8 @@ NK_INTERNAL void nk_compiler_barrier_sapphireamx_(void) { __asm__ volatile("" ::
214
214
 
215
215
  /* Initialize BF16 output state to zero */
216
216
  NK_INTERNAL void nk_dots_bf16_init_sapphireamx_(nk_dots_bf16_state_sapphireamx_t *state) {
217
- __m512 zero = _mm512_setzero_ps();
218
- for (nk_size_t row_idx = 0; row_idx < 16; row_idx++) { _mm512_store_ps(state->data[row_idx], zero); }
217
+ __m512 zero_f32x16 = _mm512_setzero_ps();
218
+ for (nk_size_t row_idx = 0; row_idx < 16; row_idx++) { _mm512_store_ps(state->data[row_idx], zero_f32x16); }
219
219
  }
220
220
 
221
221
  /* Load A tile from row-major source with masking for edge tiles */
@@ -225,14 +225,14 @@ NK_INTERNAL void nk_dots_bf16_load_a_sapphireamx_( //
225
225
  nk_size_t valid_rows, nk_size_t valid_cols) {
226
226
 
227
227
  __mmask32 column_mask = (valid_cols >= 32) ? 0xFFFFFFFF : ((__mmask32)1 << valid_cols) - 1;
228
- __m512i zero = _mm512_setzero_si512();
228
+ __m512i zero_i16x32 = _mm512_setzero_si512();
229
229
 
230
230
  for (nk_size_t row_idx = 0; row_idx < 16; row_idx++) {
231
231
  if (row_idx < valid_rows) {
232
- __m512i row = _mm512_maskz_loadu_epi16(column_mask, src + row_idx * src_stride_elements);
233
- _mm512_store_si512((__m512i *)a_tile->data[row_idx], row);
232
+ __m512i row_i16x32 = _mm512_maskz_loadu_epi16(column_mask, src + row_idx * src_stride_elements);
233
+ _mm512_store_si512((__m512i *)a_tile->data[row_idx], row_i16x32);
234
234
  }
235
- else { _mm512_store_si512((__m512i *)a_tile->data[row_idx], zero); }
235
+ else { _mm512_store_si512((__m512i *)a_tile->data[row_idx], zero_i16x32); }
236
236
  }
237
237
  nk_compiler_barrier_sapphireamx_();
238
238
  }
@@ -246,8 +246,8 @@ NK_INTERNAL void nk_dots_bf16_store_sapphireamx_( //
246
246
  __mmask16 column_mask = (valid_cols >= 16) ? 0xFFFF : ((__mmask16)1 << valid_cols) - 1;
247
247
 
248
248
  for (nk_size_t row_idx = 0; row_idx < valid_rows; row_idx++) {
249
- __m512 row = _mm512_load_ps(state->data[row_idx]);
250
- _mm512_mask_storeu_ps(dst + row_idx * dst_stride_elements, column_mask, row);
249
+ __m512 row_f32x16 = _mm512_load_ps(state->data[row_idx]);
250
+ _mm512_mask_storeu_ps(dst + row_idx * dst_stride_elements, column_mask, row_f32x16);
251
251
  }
252
252
  }
253
253
 
@@ -281,8 +281,10 @@ NK_INTERNAL void nk_dots_bf16_update_sapphireamx_( //
281
281
 
282
282
  /* Initialize INT8 output state to zero */
283
283
  NK_INTERNAL void nk_dots_i8_init_sapphireamx_(nk_dots_i8_state_sapphireamx_t *state) {
284
- __m512i zero = _mm512_setzero_si512();
285
- for (nk_size_t row_idx = 0; row_idx < 16; row_idx++) { _mm512_store_si512((__m512i *)state->data[row_idx], zero); }
284
+ __m512i zero_i32x16 = _mm512_setzero_si512();
285
+ for (nk_size_t row_idx = 0; row_idx < 16; row_idx++) {
286
+ _mm512_store_si512((__m512i *)state->data[row_idx], zero_i32x16);
287
+ }
286
288
  }
287
289
 
288
290
  /* Load A tile from row-major source with masking for edge tiles */
@@ -292,14 +294,14 @@ NK_INTERNAL void nk_dots_i8_load_a_sapphireamx_( //
292
294
  nk_size_t valid_rows, nk_size_t valid_cols) {
293
295
 
294
296
  __mmask64 column_mask = (valid_cols >= 64) ? 0xFFFFFFFFFFFFFFFFULL : ((__mmask64)1 << valid_cols) - 1;
295
- __m512i zero = _mm512_setzero_si512();
297
+ __m512i zero_i8x64 = _mm512_setzero_si512();
296
298
 
297
299
  for (nk_size_t row_idx = 0; row_idx < 16; row_idx++) {
298
300
  if (row_idx < valid_rows) {
299
- __m512i row = _mm512_maskz_loadu_epi8(column_mask, src + row_idx * src_stride);
300
- _mm512_store_si512((__m512i *)a_tile->data[row_idx], row);
301
+ __m512i row_i8x64 = _mm512_maskz_loadu_epi8(column_mask, src + row_idx * src_stride);
302
+ _mm512_store_si512((__m512i *)a_tile->data[row_idx], row_i8x64);
301
303
  }
302
- else { _mm512_store_si512((__m512i *)a_tile->data[row_idx], zero); }
304
+ else { _mm512_store_si512((__m512i *)a_tile->data[row_idx], zero_i8x64); }
303
305
  }
304
306
  nk_compiler_barrier_sapphireamx_();
305
307
  }
@@ -313,8 +315,8 @@ NK_INTERNAL void nk_dots_i8_store_sapphireamx_( //
313
315
  __mmask16 column_mask = (valid_cols >= 16) ? 0xFFFF : ((__mmask16)1 << valid_cols) - 1;
314
316
 
315
317
  for (nk_size_t row_idx = 0; row_idx < valid_rows; row_idx++) {
316
- __m512i row = _mm512_load_si512((__m512i const *)state->data[row_idx]);
317
- _mm512_mask_storeu_epi32(dst + row_idx * dst_stride_elements, column_mask, row);
318
+ __m512i row_i32x16 = _mm512_load_si512((__m512i const *)state->data[row_idx]);
319
+ _mm512_mask_storeu_epi32(dst + row_idx * dst_stride_elements, column_mask, row_i32x16);
318
320
  }
319
321
  }
320
322
 
@@ -353,24 +355,23 @@ NK_INTERNAL void nk_dots_bf16_output2x2_sapphireamx_( //
353
355
  nk_size_t valid_rows, nk_size_t valid_cols) {
354
356
 
355
357
  // Rows 0-15
356
- nk_size_t const rows_upper = (valid_rows > 16) ? 16 : valid_rows;
358
+ nk_size_t const rows_high = (valid_rows > 16) ? 16 : valid_rows;
357
359
  nk_size_t const cols_left = (valid_cols > 16) ? 16 : valid_cols;
358
360
  nk_size_t const cols_right = (valid_cols > 16) ? valid_cols - 16 : 0;
359
361
 
360
- if (rows_upper > 0 && cols_left > 0)
361
- nk_dots_bf16_store_sapphireamx_(&state->c[0][0], dst, dst_stride_elements, rows_upper, cols_left);
362
- if (rows_upper > 0 && cols_right > 0)
363
- nk_dots_bf16_store_sapphireamx_(&state->c[0][1], dst + 16, dst_stride_elements, rows_upper, cols_right);
362
+ if (rows_high > 0 && cols_left > 0)
363
+ nk_dots_bf16_store_sapphireamx_(&state->c[0][0], dst, dst_stride_elements, rows_high, cols_left);
364
+ if (rows_high > 0 && cols_right > 0)
365
+ nk_dots_bf16_store_sapphireamx_(&state->c[0][1], dst + 16, dst_stride_elements, rows_high, cols_right);
364
366
 
365
367
  // Rows 16-31
366
368
  if (valid_rows > 16) {
367
- nk_size_t const rows_lower = valid_rows - 16;
368
- nk_f32_t *dst_lower = dst + 16 * dst_stride_elements;
369
+ nk_size_t const rows_low = valid_rows - 16;
370
+ nk_f32_t *dst_low = dst + 16 * dst_stride_elements;
369
371
  if (cols_left > 0)
370
- nk_dots_bf16_store_sapphireamx_(&state->c[1][0], dst_lower, dst_stride_elements, rows_lower, cols_left);
372
+ nk_dots_bf16_store_sapphireamx_(&state->c[1][0], dst_low, dst_stride_elements, rows_low, cols_left);
371
373
  if (cols_right > 0)
372
- nk_dots_bf16_store_sapphireamx_(&state->c[1][1], dst_lower + 16, dst_stride_elements, rows_lower,
373
- cols_right);
374
+ nk_dots_bf16_store_sapphireamx_(&state->c[1][1], dst_low + 16, dst_stride_elements, rows_low, cols_right);
374
375
  }
375
376
  }
376
377
 
@@ -380,22 +381,22 @@ NK_INTERNAL void nk_dots_i8_output2x2_sapphireamx_( //
380
381
  nk_i32_t *dst, nk_size_t dst_stride_elements, //
381
382
  nk_size_t valid_rows, nk_size_t valid_cols) {
382
383
 
383
- nk_size_t const rows_upper = (valid_rows > 16) ? 16 : valid_rows;
384
+ nk_size_t const rows_high = (valid_rows > 16) ? 16 : valid_rows;
384
385
  nk_size_t const cols_left = (valid_cols > 16) ? 16 : valid_cols;
385
386
  nk_size_t const cols_right = (valid_cols > 16) ? valid_cols - 16 : 0;
386
387
 
387
- if (rows_upper > 0 && cols_left > 0)
388
- nk_dots_i8_store_sapphireamx_(&state->c[0][0], dst, dst_stride_elements, rows_upper, cols_left);
389
- if (rows_upper > 0 && cols_right > 0)
390
- nk_dots_i8_store_sapphireamx_(&state->c[0][1], dst + 16, dst_stride_elements, rows_upper, cols_right);
388
+ if (rows_high > 0 && cols_left > 0)
389
+ nk_dots_i8_store_sapphireamx_(&state->c[0][0], dst, dst_stride_elements, rows_high, cols_left);
390
+ if (rows_high > 0 && cols_right > 0)
391
+ nk_dots_i8_store_sapphireamx_(&state->c[0][1], dst + 16, dst_stride_elements, rows_high, cols_right);
391
392
 
392
393
  if (valid_rows > 16) {
393
- nk_size_t const rows_lower = valid_rows - 16;
394
- nk_i32_t *dst_lower = dst + 16 * dst_stride_elements;
394
+ nk_size_t const rows_low = valid_rows - 16;
395
+ nk_i32_t *dst_low = dst + 16 * dst_stride_elements;
395
396
  if (cols_left > 0)
396
- nk_dots_i8_store_sapphireamx_(&state->c[1][0], dst_lower, dst_stride_elements, rows_lower, cols_left);
397
+ nk_dots_i8_store_sapphireamx_(&state->c[1][0], dst_low, dst_stride_elements, rows_low, cols_left);
397
398
  if (cols_right > 0)
398
- nk_dots_i8_store_sapphireamx_(&state->c[1][1], dst_lower + 16, dst_stride_elements, rows_lower, cols_right);
399
+ nk_dots_i8_store_sapphireamx_(&state->c[1][1], dst_low + 16, dst_stride_elements, rows_low, cols_right);
399
400
  }
400
401
  }
401
402
 
@@ -441,114 +442,114 @@ NK_INTERNAL void nk_dots_pack_u8_transposed_sapphireamx_( //
441
442
 
442
443
  // Load all 16 rows - each row is 64 UINT8 = 64 bytes = 1 ZMM
443
444
  // Treat as 16 × 32-bit elements per row (each 32-bit = quad of UINT8)
444
- __m512i row00 = _mm512_load_si512(&a_tile->data[0][0]);
445
- __m512i row01 = _mm512_load_si512(&a_tile->data[1][0]);
446
- __m512i row02 = _mm512_load_si512(&a_tile->data[2][0]);
447
- __m512i row03 = _mm512_load_si512(&a_tile->data[3][0]);
448
- __m512i row04 = _mm512_load_si512(&a_tile->data[4][0]);
449
- __m512i row05 = _mm512_load_si512(&a_tile->data[5][0]);
450
- __m512i row06 = _mm512_load_si512(&a_tile->data[6][0]);
451
- __m512i row07 = _mm512_load_si512(&a_tile->data[7][0]);
452
- __m512i row08 = _mm512_load_si512(&a_tile->data[8][0]);
453
- __m512i row09 = _mm512_load_si512(&a_tile->data[9][0]);
454
- __m512i row10 = _mm512_load_si512(&a_tile->data[10][0]);
455
- __m512i row11 = _mm512_load_si512(&a_tile->data[11][0]);
456
- __m512i row12 = _mm512_load_si512(&a_tile->data[12][0]);
457
- __m512i row13 = _mm512_load_si512(&a_tile->data[13][0]);
458
- __m512i row14 = _mm512_load_si512(&a_tile->data[14][0]);
459
- __m512i row15 = _mm512_load_si512(&a_tile->data[15][0]);
445
+ __m512i row00_i32x16 = _mm512_load_si512(&a_tile->data[0][0]);
446
+ __m512i row01_i32x16 = _mm512_load_si512(&a_tile->data[1][0]);
447
+ __m512i row02_i32x16 = _mm512_load_si512(&a_tile->data[2][0]);
448
+ __m512i row03_i32x16 = _mm512_load_si512(&a_tile->data[3][0]);
449
+ __m512i row04_i32x16 = _mm512_load_si512(&a_tile->data[4][0]);
450
+ __m512i row05_i32x16 = _mm512_load_si512(&a_tile->data[5][0]);
451
+ __m512i row06_i32x16 = _mm512_load_si512(&a_tile->data[6][0]);
452
+ __m512i row07_i32x16 = _mm512_load_si512(&a_tile->data[7][0]);
453
+ __m512i row08_i32x16 = _mm512_load_si512(&a_tile->data[8][0]);
454
+ __m512i row09_i32x16 = _mm512_load_si512(&a_tile->data[9][0]);
455
+ __m512i row10_i32x16 = _mm512_load_si512(&a_tile->data[10][0]);
456
+ __m512i row11_i32x16 = _mm512_load_si512(&a_tile->data[11][0]);
457
+ __m512i row12_i32x16 = _mm512_load_si512(&a_tile->data[12][0]);
458
+ __m512i row13_i32x16 = _mm512_load_si512(&a_tile->data[13][0]);
459
+ __m512i row14_i32x16 = _mm512_load_si512(&a_tile->data[14][0]);
460
+ __m512i row15_i32x16 = _mm512_load_si512(&a_tile->data[15][0]);
460
461
 
461
462
  // 16×16 transpose of 32-bit elements using hierarchical unpacks
462
463
  // Stage 1: Unpack adjacent row pairs at 32-bit granularity
463
- __m512i t01_lo = _mm512_unpacklo_epi32(row00, row01);
464
- __m512i t01_hi = _mm512_unpackhi_epi32(row00, row01);
465
- __m512i t23_lo = _mm512_unpacklo_epi32(row02, row03);
466
- __m512i t23_hi = _mm512_unpackhi_epi32(row02, row03);
467
- __m512i t45_lo = _mm512_unpacklo_epi32(row04, row05);
468
- __m512i t45_hi = _mm512_unpackhi_epi32(row04, row05);
469
- __m512i t67_lo = _mm512_unpacklo_epi32(row06, row07);
470
- __m512i t67_hi = _mm512_unpackhi_epi32(row06, row07);
471
- __m512i t89_lo = _mm512_unpacklo_epi32(row08, row09);
472
- __m512i t89_hi = _mm512_unpackhi_epi32(row08, row09);
473
- __m512i tab_lo = _mm512_unpacklo_epi32(row10, row11);
474
- __m512i tab_hi = _mm512_unpackhi_epi32(row10, row11);
475
- __m512i tcd_lo = _mm512_unpacklo_epi32(row12, row13);
476
- __m512i tcd_hi = _mm512_unpackhi_epi32(row12, row13);
477
- __m512i tef_lo = _mm512_unpacklo_epi32(row14, row15);
478
- __m512i tef_hi = _mm512_unpackhi_epi32(row14, row15);
464
+ __m512i t01_low_i32x16 = _mm512_unpacklo_epi32(row00_i32x16, row01_i32x16);
465
+ __m512i t01_high_i32x16 = _mm512_unpackhi_epi32(row00_i32x16, row01_i32x16);
466
+ __m512i t23_low_i32x16 = _mm512_unpacklo_epi32(row02_i32x16, row03_i32x16);
467
+ __m512i t23_high_i32x16 = _mm512_unpackhi_epi32(row02_i32x16, row03_i32x16);
468
+ __m512i t45_low_i32x16 = _mm512_unpacklo_epi32(row04_i32x16, row05_i32x16);
469
+ __m512i t45_high_i32x16 = _mm512_unpackhi_epi32(row04_i32x16, row05_i32x16);
470
+ __m512i t67_low_i32x16 = _mm512_unpacklo_epi32(row06_i32x16, row07_i32x16);
471
+ __m512i t67_high_i32x16 = _mm512_unpackhi_epi32(row06_i32x16, row07_i32x16);
472
+ __m512i t89_low_i32x16 = _mm512_unpacklo_epi32(row08_i32x16, row09_i32x16);
473
+ __m512i t89_high_i32x16 = _mm512_unpackhi_epi32(row08_i32x16, row09_i32x16);
474
+ __m512i tab_low_i32x16 = _mm512_unpacklo_epi32(row10_i32x16, row11_i32x16);
475
+ __m512i tab_high_i32x16 = _mm512_unpackhi_epi32(row10_i32x16, row11_i32x16);
476
+ __m512i tcd_low_i32x16 = _mm512_unpacklo_epi32(row12_i32x16, row13_i32x16);
477
+ __m512i tcd_high_i32x16 = _mm512_unpackhi_epi32(row12_i32x16, row13_i32x16);
478
+ __m512i tef_low_i32x16 = _mm512_unpacklo_epi32(row14_i32x16, row15_i32x16);
479
+ __m512i tef_high_i32x16 = _mm512_unpackhi_epi32(row14_i32x16, row15_i32x16);
479
480
 
480
481
  // Stage 2: Unpack at 64-bit granularity
481
- __m512i u0123_ll = _mm512_unpacklo_epi64(t01_lo, t23_lo);
482
- __m512i u0123_lh = _mm512_unpackhi_epi64(t01_lo, t23_lo);
483
- __m512i u0123_hl = _mm512_unpacklo_epi64(t01_hi, t23_hi);
484
- __m512i u0123_hh = _mm512_unpackhi_epi64(t01_hi, t23_hi);
485
- __m512i u4567_ll = _mm512_unpacklo_epi64(t45_lo, t67_lo);
486
- __m512i u4567_lh = _mm512_unpackhi_epi64(t45_lo, t67_lo);
487
- __m512i u4567_hl = _mm512_unpacklo_epi64(t45_hi, t67_hi);
488
- __m512i u4567_hh = _mm512_unpackhi_epi64(t45_hi, t67_hi);
489
- __m512i u89ab_ll = _mm512_unpacklo_epi64(t89_lo, tab_lo);
490
- __m512i u89ab_lh = _mm512_unpackhi_epi64(t89_lo, tab_lo);
491
- __m512i u89ab_hl = _mm512_unpacklo_epi64(t89_hi, tab_hi);
492
- __m512i u89ab_hh = _mm512_unpackhi_epi64(t89_hi, tab_hi);
493
- __m512i ucdef_ll = _mm512_unpacklo_epi64(tcd_lo, tef_lo);
494
- __m512i ucdef_lh = _mm512_unpackhi_epi64(tcd_lo, tef_lo);
495
- __m512i ucdef_hl = _mm512_unpacklo_epi64(tcd_hi, tef_hi);
496
- __m512i ucdef_hh = _mm512_unpackhi_epi64(tcd_hi, tef_hi);
482
+ __m512i u0123_ll_i32x16 = _mm512_unpacklo_epi64(t01_low_i32x16, t23_low_i32x16);
483
+ __m512i u0123_lh_i32x16 = _mm512_unpackhi_epi64(t01_low_i32x16, t23_low_i32x16);
484
+ __m512i u0123_hl_i32x16 = _mm512_unpacklo_epi64(t01_high_i32x16, t23_high_i32x16);
485
+ __m512i u0123_hh_i32x16 = _mm512_unpackhi_epi64(t01_high_i32x16, t23_high_i32x16);
486
+ __m512i u4567_ll_i32x16 = _mm512_unpacklo_epi64(t45_low_i32x16, t67_low_i32x16);
487
+ __m512i u4567_lh_i32x16 = _mm512_unpackhi_epi64(t45_low_i32x16, t67_low_i32x16);
488
+ __m512i u4567_hl_i32x16 = _mm512_unpacklo_epi64(t45_high_i32x16, t67_high_i32x16);
489
+ __m512i u4567_hh_i32x16 = _mm512_unpackhi_epi64(t45_high_i32x16, t67_high_i32x16);
490
+ __m512i u89ab_ll_i32x16 = _mm512_unpacklo_epi64(t89_low_i32x16, tab_low_i32x16);
491
+ __m512i u89ab_lh_i32x16 = _mm512_unpackhi_epi64(t89_low_i32x16, tab_low_i32x16);
492
+ __m512i u89ab_hl_i32x16 = _mm512_unpacklo_epi64(t89_high_i32x16, tab_high_i32x16);
493
+ __m512i u89ab_hh_i32x16 = _mm512_unpackhi_epi64(t89_high_i32x16, tab_high_i32x16);
494
+ __m512i ucdef_ll_i32x16 = _mm512_unpacklo_epi64(tcd_low_i32x16, tef_low_i32x16);
495
+ __m512i ucdef_lh_i32x16 = _mm512_unpackhi_epi64(tcd_low_i32x16, tef_low_i32x16);
496
+ __m512i ucdef_hl_i32x16 = _mm512_unpacklo_epi64(tcd_high_i32x16, tef_high_i32x16);
497
+ __m512i ucdef_hh_i32x16 = _mm512_unpackhi_epi64(tcd_high_i32x16, tef_high_i32x16);
497
498
 
498
499
  // Stage 3: Shuffle 128-bit lanes
499
- __m512i v0_a = _mm512_shuffle_i32x4(u0123_ll, u4567_ll, 0x88);
500
- __m512i v0_b = _mm512_shuffle_i32x4(u0123_ll, u4567_ll, 0xDD);
501
- __m512i v1_a = _mm512_shuffle_i32x4(u0123_lh, u4567_lh, 0x88);
502
- __m512i v1_b = _mm512_shuffle_i32x4(u0123_lh, u4567_lh, 0xDD);
503
- __m512i v2_a = _mm512_shuffle_i32x4(u0123_hl, u4567_hl, 0x88);
504
- __m512i v2_b = _mm512_shuffle_i32x4(u0123_hl, u4567_hl, 0xDD);
505
- __m512i v3_a = _mm512_shuffle_i32x4(u0123_hh, u4567_hh, 0x88);
506
- __m512i v3_b = _mm512_shuffle_i32x4(u0123_hh, u4567_hh, 0xDD);
507
- __m512i v4_a = _mm512_shuffle_i32x4(u89ab_ll, ucdef_ll, 0x88);
508
- __m512i v4_b = _mm512_shuffle_i32x4(u89ab_ll, ucdef_ll, 0xDD);
509
- __m512i v5_a = _mm512_shuffle_i32x4(u89ab_lh, ucdef_lh, 0x88);
510
- __m512i v5_b = _mm512_shuffle_i32x4(u89ab_lh, ucdef_lh, 0xDD);
511
- __m512i v6_a = _mm512_shuffle_i32x4(u89ab_hl, ucdef_hl, 0x88);
512
- __m512i v6_b = _mm512_shuffle_i32x4(u89ab_hl, ucdef_hl, 0xDD);
513
- __m512i v7_a = _mm512_shuffle_i32x4(u89ab_hh, ucdef_hh, 0x88);
514
- __m512i v7_b = _mm512_shuffle_i32x4(u89ab_hh, ucdef_hh, 0xDD);
500
+ __m512i v0_a_i32x16 = _mm512_shuffle_i32x4(u0123_ll_i32x16, u4567_ll_i32x16, 0x88);
501
+ __m512i v0_b_i32x16 = _mm512_shuffle_i32x4(u0123_ll_i32x16, u4567_ll_i32x16, 0xDD);
502
+ __m512i v1_a_i32x16 = _mm512_shuffle_i32x4(u0123_lh_i32x16, u4567_lh_i32x16, 0x88);
503
+ __m512i v1_b_i32x16 = _mm512_shuffle_i32x4(u0123_lh_i32x16, u4567_lh_i32x16, 0xDD);
504
+ __m512i v2_a_i32x16 = _mm512_shuffle_i32x4(u0123_hl_i32x16, u4567_hl_i32x16, 0x88);
505
+ __m512i v2_b_i32x16 = _mm512_shuffle_i32x4(u0123_hl_i32x16, u4567_hl_i32x16, 0xDD);
506
+ __m512i v3_a_i32x16 = _mm512_shuffle_i32x4(u0123_hh_i32x16, u4567_hh_i32x16, 0x88);
507
+ __m512i v3_b_i32x16 = _mm512_shuffle_i32x4(u0123_hh_i32x16, u4567_hh_i32x16, 0xDD);
508
+ __m512i v4_a_i32x16 = _mm512_shuffle_i32x4(u89ab_ll_i32x16, ucdef_ll_i32x16, 0x88);
509
+ __m512i v4_b_i32x16 = _mm512_shuffle_i32x4(u89ab_ll_i32x16, ucdef_ll_i32x16, 0xDD);
510
+ __m512i v5_a_i32x16 = _mm512_shuffle_i32x4(u89ab_lh_i32x16, ucdef_lh_i32x16, 0x88);
511
+ __m512i v5_b_i32x16 = _mm512_shuffle_i32x4(u89ab_lh_i32x16, ucdef_lh_i32x16, 0xDD);
512
+ __m512i v6_a_i32x16 = _mm512_shuffle_i32x4(u89ab_hl_i32x16, ucdef_hl_i32x16, 0x88);
513
+ __m512i v6_b_i32x16 = _mm512_shuffle_i32x4(u89ab_hl_i32x16, ucdef_hl_i32x16, 0xDD);
514
+ __m512i v7_a_i32x16 = _mm512_shuffle_i32x4(u89ab_hh_i32x16, ucdef_hh_i32x16, 0x88);
515
+ __m512i v7_b_i32x16 = _mm512_shuffle_i32x4(u89ab_hh_i32x16, ucdef_hh_i32x16, 0xDD);
515
516
 
516
517
  // Stage 4: Final 256-bit shuffle to complete transpose
517
- __m512i out00 = _mm512_shuffle_i32x4(v0_a, v4_a, 0x88);
518
- __m512i out01 = _mm512_shuffle_i32x4(v1_a, v5_a, 0x88);
519
- __m512i out02 = _mm512_shuffle_i32x4(v2_a, v6_a, 0x88);
520
- __m512i out03 = _mm512_shuffle_i32x4(v3_a, v7_a, 0x88);
521
- __m512i out04 = _mm512_shuffle_i32x4(v0_a, v4_a, 0xDD);
522
- __m512i out05 = _mm512_shuffle_i32x4(v1_a, v5_a, 0xDD);
523
- __m512i out06 = _mm512_shuffle_i32x4(v2_a, v6_a, 0xDD);
524
- __m512i out07 = _mm512_shuffle_i32x4(v3_a, v7_a, 0xDD);
525
- __m512i out08 = _mm512_shuffle_i32x4(v0_b, v4_b, 0x88);
526
- __m512i out09 = _mm512_shuffle_i32x4(v1_b, v5_b, 0x88);
527
- __m512i out10 = _mm512_shuffle_i32x4(v2_b, v6_b, 0x88);
528
- __m512i out11 = _mm512_shuffle_i32x4(v3_b, v7_b, 0x88);
529
- __m512i out12 = _mm512_shuffle_i32x4(v0_b, v4_b, 0xDD);
530
- __m512i out13 = _mm512_shuffle_i32x4(v1_b, v5_b, 0xDD);
531
- __m512i out14 = _mm512_shuffle_i32x4(v2_b, v6_b, 0xDD);
532
- __m512i out15 = _mm512_shuffle_i32x4(v3_b, v7_b, 0xDD);
518
+ __m512i out00_i32x16 = _mm512_shuffle_i32x4(v0_a_i32x16, v4_a_i32x16, 0x88);
519
+ __m512i out01_i32x16 = _mm512_shuffle_i32x4(v1_a_i32x16, v5_a_i32x16, 0x88);
520
+ __m512i out02_i32x16 = _mm512_shuffle_i32x4(v2_a_i32x16, v6_a_i32x16, 0x88);
521
+ __m512i out03_i32x16 = _mm512_shuffle_i32x4(v3_a_i32x16, v7_a_i32x16, 0x88);
522
+ __m512i out04_i32x16 = _mm512_shuffle_i32x4(v0_a_i32x16, v4_a_i32x16, 0xDD);
523
+ __m512i out05_i32x16 = _mm512_shuffle_i32x4(v1_a_i32x16, v5_a_i32x16, 0xDD);
524
+ __m512i out06_i32x16 = _mm512_shuffle_i32x4(v2_a_i32x16, v6_a_i32x16, 0xDD);
525
+ __m512i out07_i32x16 = _mm512_shuffle_i32x4(v3_a_i32x16, v7_a_i32x16, 0xDD);
526
+ __m512i out08_i32x16 = _mm512_shuffle_i32x4(v0_b_i32x16, v4_b_i32x16, 0x88);
527
+ __m512i out09_i32x16 = _mm512_shuffle_i32x4(v1_b_i32x16, v5_b_i32x16, 0x88);
528
+ __m512i out10_i32x16 = _mm512_shuffle_i32x4(v2_b_i32x16, v6_b_i32x16, 0x88);
529
+ __m512i out11_i32x16 = _mm512_shuffle_i32x4(v3_b_i32x16, v7_b_i32x16, 0x88);
530
+ __m512i out12_i32x16 = _mm512_shuffle_i32x4(v0_b_i32x16, v4_b_i32x16, 0xDD);
531
+ __m512i out13_i32x16 = _mm512_shuffle_i32x4(v1_b_i32x16, v5_b_i32x16, 0xDD);
532
+ __m512i out14_i32x16 = _mm512_shuffle_i32x4(v2_b_i32x16, v6_b_i32x16, 0xDD);
533
+ __m512i out15_i32x16 = _mm512_shuffle_i32x4(v3_b_i32x16, v7_b_i32x16, 0xDD);
533
534
 
534
535
  // Store transposed results - each output row is one depth_group
535
536
  // Output layout: B.data[depth_group][column][quad] = 16 columns × 4 UINT8 = 64 bytes
536
- _mm512_store_si512(&b_tile->data[0][0][0], out00);
537
- _mm512_store_si512(&b_tile->data[1][0][0], out01);
538
- _mm512_store_si512(&b_tile->data[2][0][0], out02);
539
- _mm512_store_si512(&b_tile->data[3][0][0], out03);
540
- _mm512_store_si512(&b_tile->data[4][0][0], out08);
541
- _mm512_store_si512(&b_tile->data[5][0][0], out09);
542
- _mm512_store_si512(&b_tile->data[6][0][0], out10);
543
- _mm512_store_si512(&b_tile->data[7][0][0], out11);
544
- _mm512_store_si512(&b_tile->data[8][0][0], out04);
545
- _mm512_store_si512(&b_tile->data[9][0][0], out05);
546
- _mm512_store_si512(&b_tile->data[10][0][0], out06);
547
- _mm512_store_si512(&b_tile->data[11][0][0], out07);
548
- _mm512_store_si512(&b_tile->data[12][0][0], out12);
549
- _mm512_store_si512(&b_tile->data[13][0][0], out13);
550
- _mm512_store_si512(&b_tile->data[14][0][0], out14);
551
- _mm512_store_si512(&b_tile->data[15][0][0], out15);
537
+ _mm512_store_si512(&b_tile->data[0][0][0], out00_i32x16);
538
+ _mm512_store_si512(&b_tile->data[1][0][0], out01_i32x16);
539
+ _mm512_store_si512(&b_tile->data[2][0][0], out02_i32x16);
540
+ _mm512_store_si512(&b_tile->data[3][0][0], out03_i32x16);
541
+ _mm512_store_si512(&b_tile->data[4][0][0], out08_i32x16);
542
+ _mm512_store_si512(&b_tile->data[5][0][0], out09_i32x16);
543
+ _mm512_store_si512(&b_tile->data[6][0][0], out10_i32x16);
544
+ _mm512_store_si512(&b_tile->data[7][0][0], out11_i32x16);
545
+ _mm512_store_si512(&b_tile->data[8][0][0], out04_i32x16);
546
+ _mm512_store_si512(&b_tile->data[9][0][0], out05_i32x16);
547
+ _mm512_store_si512(&b_tile->data[10][0][0], out06_i32x16);
548
+ _mm512_store_si512(&b_tile->data[11][0][0], out07_i32x16);
549
+ _mm512_store_si512(&b_tile->data[12][0][0], out12_i32x16);
550
+ _mm512_store_si512(&b_tile->data[13][0][0], out13_i32x16);
551
+ _mm512_store_si512(&b_tile->data[14][0][0], out14_i32x16);
552
+ _mm512_store_si512(&b_tile->data[15][0][0], out15_i32x16);
552
553
 
553
554
  nk_compiler_barrier_sapphireamx_();
554
555
  }
@@ -588,17 +589,17 @@ NK_INTERNAL void nk_dots_e4m3_load_a_sapphireamx_( //
588
589
  nk_size_t valid_rows, nk_size_t valid_cols) {
589
590
 
590
591
  __mmask32 column_mask = (valid_cols >= 32) ? 0xFFFFFFFF : ((__mmask32)1 << valid_cols) - 1;
591
- __m512i zero = _mm512_setzero_si512();
592
+ __m512i zero_i16x32 = _mm512_setzero_si512();
592
593
 
593
594
  for (nk_size_t row_idx = 0; row_idx < 16; row_idx++) {
594
595
  if (row_idx < valid_rows) {
595
596
  // Load 32 E4M3 bytes with masking
596
- __m256i e4m3_row = _mm256_maskz_loadu_epi8(column_mask, src + row_idx * src_stride);
597
+ __m256i e4m3_row_u8x32 = _mm256_maskz_loadu_epi8(column_mask, src + row_idx * src_stride);
597
598
  // Convert to 32 BF16 values
598
- __m512i bf16_row = nk_e4m3x32_to_bf16x32_icelake_(e4m3_row);
599
- _mm512_store_si512((__m512i *)a_tile->data[row_idx], bf16_row);
599
+ __m512i bf16_row_i16x32 = nk_e4m3x32_to_bf16x32_icelake_(e4m3_row_u8x32);
600
+ _mm512_store_si512((__m512i *)a_tile->data[row_idx], bf16_row_i16x32);
600
601
  }
601
- else { _mm512_store_si512((__m512i *)a_tile->data[row_idx], zero); }
602
+ else { _mm512_store_si512((__m512i *)a_tile->data[row_idx], zero_i16x32); }
602
603
  }
603
604
  nk_compiler_barrier_sapphireamx_();
604
605
  }
@@ -610,15 +611,15 @@ NK_INTERNAL void nk_dots_e5m2_load_a_sapphireamx_( //
610
611
  nk_size_t valid_rows, nk_size_t valid_cols) {
611
612
 
612
613
  __mmask32 column_mask = (valid_cols >= 32) ? 0xFFFFFFFF : ((__mmask32)1 << valid_cols) - 1;
613
- __m512i zero = _mm512_setzero_si512();
614
+ __m512i zero_i16x32 = _mm512_setzero_si512();
614
615
 
615
616
  for (nk_size_t row_idx = 0; row_idx < 16; row_idx++) {
616
617
  if (row_idx < valid_rows) {
617
- __m256i e5m2_row = _mm256_maskz_loadu_epi8(column_mask, src + row_idx * src_stride);
618
- __m512i bf16_row = nk_e5m2x32_to_bf16x32_icelake_(e5m2_row);
619
- _mm512_store_si512((__m512i *)a_tile->data[row_idx], bf16_row);
618
+ __m256i e5m2_row_u8x32 = _mm256_maskz_loadu_epi8(column_mask, src + row_idx * src_stride);
619
+ __m512i bf16_row_i16x32 = nk_e5m2x32_to_bf16x32_icelake_(e5m2_row_u8x32);
620
+ _mm512_store_si512((__m512i *)a_tile->data[row_idx], bf16_row_i16x32);
620
621
  }
621
- else { _mm512_store_si512((__m512i *)a_tile->data[row_idx], zero); }
622
+ else { _mm512_store_si512((__m512i *)a_tile->data[row_idx], zero_i16x32); }
622
623
  }
623
624
  nk_compiler_barrier_sapphireamx_();
624
625
  }
@@ -630,115 +631,115 @@ NK_INTERNAL void nk_dots_pack_bf16_transposed_sapphireamx_( //
630
631
 
631
632
  // Load all 16 rows - each row is 32 BF16 = 64 bytes = 1 ZMM
632
633
  // Treat as 16 × 32-bit elements per row (each 32-bit = pair of BF16)
633
- __m512i row00 = _mm512_load_si512(&a_tile->data[0][0]);
634
- __m512i row01 = _mm512_load_si512(&a_tile->data[1][0]);
635
- __m512i row02 = _mm512_load_si512(&a_tile->data[2][0]);
636
- __m512i row03 = _mm512_load_si512(&a_tile->data[3][0]);
637
- __m512i row04 = _mm512_load_si512(&a_tile->data[4][0]);
638
- __m512i row05 = _mm512_load_si512(&a_tile->data[5][0]);
639
- __m512i row06 = _mm512_load_si512(&a_tile->data[6][0]);
640
- __m512i row07 = _mm512_load_si512(&a_tile->data[7][0]);
641
- __m512i row08 = _mm512_load_si512(&a_tile->data[8][0]);
642
- __m512i row09 = _mm512_load_si512(&a_tile->data[9][0]);
643
- __m512i row10 = _mm512_load_si512(&a_tile->data[10][0]);
644
- __m512i row11 = _mm512_load_si512(&a_tile->data[11][0]);
645
- __m512i row12 = _mm512_load_si512(&a_tile->data[12][0]);
646
- __m512i row13 = _mm512_load_si512(&a_tile->data[13][0]);
647
- __m512i row14 = _mm512_load_si512(&a_tile->data[14][0]);
648
- __m512i row15 = _mm512_load_si512(&a_tile->data[15][0]);
634
+ __m512i row00_i32x16 = _mm512_load_si512(&a_tile->data[0][0]);
635
+ __m512i row01_i32x16 = _mm512_load_si512(&a_tile->data[1][0]);
636
+ __m512i row02_i32x16 = _mm512_load_si512(&a_tile->data[2][0]);
637
+ __m512i row03_i32x16 = _mm512_load_si512(&a_tile->data[3][0]);
638
+ __m512i row04_i32x16 = _mm512_load_si512(&a_tile->data[4][0]);
639
+ __m512i row05_i32x16 = _mm512_load_si512(&a_tile->data[5][0]);
640
+ __m512i row06_i32x16 = _mm512_load_si512(&a_tile->data[6][0]);
641
+ __m512i row07_i32x16 = _mm512_load_si512(&a_tile->data[7][0]);
642
+ __m512i row08_i32x16 = _mm512_load_si512(&a_tile->data[8][0]);
643
+ __m512i row09_i32x16 = _mm512_load_si512(&a_tile->data[9][0]);
644
+ __m512i row10_i32x16 = _mm512_load_si512(&a_tile->data[10][0]);
645
+ __m512i row11_i32x16 = _mm512_load_si512(&a_tile->data[11][0]);
646
+ __m512i row12_i32x16 = _mm512_load_si512(&a_tile->data[12][0]);
647
+ __m512i row13_i32x16 = _mm512_load_si512(&a_tile->data[13][0]);
648
+ __m512i row14_i32x16 = _mm512_load_si512(&a_tile->data[14][0]);
649
+ __m512i row15_i32x16 = _mm512_load_si512(&a_tile->data[15][0]);
649
650
 
650
651
  // 16×16 transpose of 32-bit elements using hierarchical unpacks
651
652
  // Stage 1: Unpack adjacent row pairs at 32-bit granularity
652
- __m512i t01_lo = _mm512_unpacklo_epi32(row00, row01);
653
- __m512i t01_hi = _mm512_unpackhi_epi32(row00, row01);
654
- __m512i t23_lo = _mm512_unpacklo_epi32(row02, row03);
655
- __m512i t23_hi = _mm512_unpackhi_epi32(row02, row03);
656
- __m512i t45_lo = _mm512_unpacklo_epi32(row04, row05);
657
- __m512i t45_hi = _mm512_unpackhi_epi32(row04, row05);
658
- __m512i t67_lo = _mm512_unpacklo_epi32(row06, row07);
659
- __m512i t67_hi = _mm512_unpackhi_epi32(row06, row07);
660
- __m512i t89_lo = _mm512_unpacklo_epi32(row08, row09);
661
- __m512i t89_hi = _mm512_unpackhi_epi32(row08, row09);
662
- __m512i tab_lo = _mm512_unpacklo_epi32(row10, row11);
663
- __m512i tab_hi = _mm512_unpackhi_epi32(row10, row11);
664
- __m512i tcd_lo = _mm512_unpacklo_epi32(row12, row13);
665
- __m512i tcd_hi = _mm512_unpackhi_epi32(row12, row13);
666
- __m512i tef_lo = _mm512_unpacklo_epi32(row14, row15);
667
- __m512i tef_hi = _mm512_unpackhi_epi32(row14, row15);
653
+ __m512i t01_low_i32x16 = _mm512_unpacklo_epi32(row00_i32x16, row01_i32x16);
654
+ __m512i t01_high_i32x16 = _mm512_unpackhi_epi32(row00_i32x16, row01_i32x16);
655
+ __m512i t23_low_i32x16 = _mm512_unpacklo_epi32(row02_i32x16, row03_i32x16);
656
+ __m512i t23_high_i32x16 = _mm512_unpackhi_epi32(row02_i32x16, row03_i32x16);
657
+ __m512i t45_low_i32x16 = _mm512_unpacklo_epi32(row04_i32x16, row05_i32x16);
658
+ __m512i t45_high_i32x16 = _mm512_unpackhi_epi32(row04_i32x16, row05_i32x16);
659
+ __m512i t67_low_i32x16 = _mm512_unpacklo_epi32(row06_i32x16, row07_i32x16);
660
+ __m512i t67_high_i32x16 = _mm512_unpackhi_epi32(row06_i32x16, row07_i32x16);
661
+ __m512i t89_low_i32x16 = _mm512_unpacklo_epi32(row08_i32x16, row09_i32x16);
662
+ __m512i t89_high_i32x16 = _mm512_unpackhi_epi32(row08_i32x16, row09_i32x16);
663
+ __m512i tab_low_i32x16 = _mm512_unpacklo_epi32(row10_i32x16, row11_i32x16);
664
+ __m512i tab_high_i32x16 = _mm512_unpackhi_epi32(row10_i32x16, row11_i32x16);
665
+ __m512i tcd_low_i32x16 = _mm512_unpacklo_epi32(row12_i32x16, row13_i32x16);
666
+ __m512i tcd_high_i32x16 = _mm512_unpackhi_epi32(row12_i32x16, row13_i32x16);
667
+ __m512i tef_low_i32x16 = _mm512_unpacklo_epi32(row14_i32x16, row15_i32x16);
668
+ __m512i tef_high_i32x16 = _mm512_unpackhi_epi32(row14_i32x16, row15_i32x16);
668
669
 
669
670
  // Stage 2: Unpack at 64-bit granularity
670
- __m512i u0123_ll = _mm512_unpacklo_epi64(t01_lo, t23_lo);
671
- __m512i u0123_lh = _mm512_unpackhi_epi64(t01_lo, t23_lo);
672
- __m512i u0123_hl = _mm512_unpacklo_epi64(t01_hi, t23_hi);
673
- __m512i u0123_hh = _mm512_unpackhi_epi64(t01_hi, t23_hi);
674
- __m512i u4567_ll = _mm512_unpacklo_epi64(t45_lo, t67_lo);
675
- __m512i u4567_lh = _mm512_unpackhi_epi64(t45_lo, t67_lo);
676
- __m512i u4567_hl = _mm512_unpacklo_epi64(t45_hi, t67_hi);
677
- __m512i u4567_hh = _mm512_unpackhi_epi64(t45_hi, t67_hi);
678
- __m512i u89ab_ll = _mm512_unpacklo_epi64(t89_lo, tab_lo);
679
- __m512i u89ab_lh = _mm512_unpackhi_epi64(t89_lo, tab_lo);
680
- __m512i u89ab_hl = _mm512_unpacklo_epi64(t89_hi, tab_hi);
681
- __m512i u89ab_hh = _mm512_unpackhi_epi64(t89_hi, tab_hi);
682
- __m512i ucdef_ll = _mm512_unpacklo_epi64(tcd_lo, tef_lo);
683
- __m512i ucdef_lh = _mm512_unpackhi_epi64(tcd_lo, tef_lo);
684
- __m512i ucdef_hl = _mm512_unpacklo_epi64(tcd_hi, tef_hi);
685
- __m512i ucdef_hh = _mm512_unpackhi_epi64(tcd_hi, tef_hi);
671
+ __m512i u0123_ll_i32x16 = _mm512_unpacklo_epi64(t01_low_i32x16, t23_low_i32x16);
672
+ __m512i u0123_lh_i32x16 = _mm512_unpackhi_epi64(t01_low_i32x16, t23_low_i32x16);
673
+ __m512i u0123_hl_i32x16 = _mm512_unpacklo_epi64(t01_high_i32x16, t23_high_i32x16);
674
+ __m512i u0123_hh_i32x16 = _mm512_unpackhi_epi64(t01_high_i32x16, t23_high_i32x16);
675
+ __m512i u4567_ll_i32x16 = _mm512_unpacklo_epi64(t45_low_i32x16, t67_low_i32x16);
676
+ __m512i u4567_lh_i32x16 = _mm512_unpackhi_epi64(t45_low_i32x16, t67_low_i32x16);
677
+ __m512i u4567_hl_i32x16 = _mm512_unpacklo_epi64(t45_high_i32x16, t67_high_i32x16);
678
+ __m512i u4567_hh_i32x16 = _mm512_unpackhi_epi64(t45_high_i32x16, t67_high_i32x16);
679
+ __m512i u89ab_ll_i32x16 = _mm512_unpacklo_epi64(t89_low_i32x16, tab_low_i32x16);
680
+ __m512i u89ab_lh_i32x16 = _mm512_unpackhi_epi64(t89_low_i32x16, tab_low_i32x16);
681
+ __m512i u89ab_hl_i32x16 = _mm512_unpacklo_epi64(t89_high_i32x16, tab_high_i32x16);
682
+ __m512i u89ab_hh_i32x16 = _mm512_unpackhi_epi64(t89_high_i32x16, tab_high_i32x16);
683
+ __m512i ucdef_ll_i32x16 = _mm512_unpacklo_epi64(tcd_low_i32x16, tef_low_i32x16);
684
+ __m512i ucdef_lh_i32x16 = _mm512_unpackhi_epi64(tcd_low_i32x16, tef_low_i32x16);
685
+ __m512i ucdef_hl_i32x16 = _mm512_unpacklo_epi64(tcd_high_i32x16, tef_high_i32x16);
686
+ __m512i ucdef_hh_i32x16 = _mm512_unpackhi_epi64(tcd_high_i32x16, tef_high_i32x16);
686
687
 
687
688
  // Stage 3: Shuffle 128-bit lanes using permute2x128 equivalent for 512-bit
688
689
  // Use shuffle_i32x4 to move 128-bit chunks
689
- __m512i v0_a = _mm512_shuffle_i32x4(u0123_ll, u4567_ll, 0x88); // lanes 0,2 from each
690
- __m512i v0_b = _mm512_shuffle_i32x4(u0123_ll, u4567_ll, 0xDD); // lanes 1,3 from each
691
- __m512i v1_a = _mm512_shuffle_i32x4(u0123_lh, u4567_lh, 0x88);
692
- __m512i v1_b = _mm512_shuffle_i32x4(u0123_lh, u4567_lh, 0xDD);
693
- __m512i v2_a = _mm512_shuffle_i32x4(u0123_hl, u4567_hl, 0x88);
694
- __m512i v2_b = _mm512_shuffle_i32x4(u0123_hl, u4567_hl, 0xDD);
695
- __m512i v3_a = _mm512_shuffle_i32x4(u0123_hh, u4567_hh, 0x88);
696
- __m512i v3_b = _mm512_shuffle_i32x4(u0123_hh, u4567_hh, 0xDD);
697
- __m512i v4_a = _mm512_shuffle_i32x4(u89ab_ll, ucdef_ll, 0x88);
698
- __m512i v4_b = _mm512_shuffle_i32x4(u89ab_ll, ucdef_ll, 0xDD);
699
- __m512i v5_a = _mm512_shuffle_i32x4(u89ab_lh, ucdef_lh, 0x88);
700
- __m512i v5_b = _mm512_shuffle_i32x4(u89ab_lh, ucdef_lh, 0xDD);
701
- __m512i v6_a = _mm512_shuffle_i32x4(u89ab_hl, ucdef_hl, 0x88);
702
- __m512i v6_b = _mm512_shuffle_i32x4(u89ab_hl, ucdef_hl, 0xDD);
703
- __m512i v7_a = _mm512_shuffle_i32x4(u89ab_hh, ucdef_hh, 0x88);
704
- __m512i v7_b = _mm512_shuffle_i32x4(u89ab_hh, ucdef_hh, 0xDD);
690
+ __m512i v0_a_i32x16 = _mm512_shuffle_i32x4(u0123_ll_i32x16, u4567_ll_i32x16, 0x88); // lanes 0,2 from each
691
+ __m512i v0_b_i32x16 = _mm512_shuffle_i32x4(u0123_ll_i32x16, u4567_ll_i32x16, 0xDD); // lanes 1,3 from each
692
+ __m512i v1_a_i32x16 = _mm512_shuffle_i32x4(u0123_lh_i32x16, u4567_lh_i32x16, 0x88);
693
+ __m512i v1_b_i32x16 = _mm512_shuffle_i32x4(u0123_lh_i32x16, u4567_lh_i32x16, 0xDD);
694
+ __m512i v2_a_i32x16 = _mm512_shuffle_i32x4(u0123_hl_i32x16, u4567_hl_i32x16, 0x88);
695
+ __m512i v2_b_i32x16 = _mm512_shuffle_i32x4(u0123_hl_i32x16, u4567_hl_i32x16, 0xDD);
696
+ __m512i v3_a_i32x16 = _mm512_shuffle_i32x4(u0123_hh_i32x16, u4567_hh_i32x16, 0x88);
697
+ __m512i v3_b_i32x16 = _mm512_shuffle_i32x4(u0123_hh_i32x16, u4567_hh_i32x16, 0xDD);
698
+ __m512i v4_a_i32x16 = _mm512_shuffle_i32x4(u89ab_ll_i32x16, ucdef_ll_i32x16, 0x88);
699
+ __m512i v4_b_i32x16 = _mm512_shuffle_i32x4(u89ab_ll_i32x16, ucdef_ll_i32x16, 0xDD);
700
+ __m512i v5_a_i32x16 = _mm512_shuffle_i32x4(u89ab_lh_i32x16, ucdef_lh_i32x16, 0x88);
701
+ __m512i v5_b_i32x16 = _mm512_shuffle_i32x4(u89ab_lh_i32x16, ucdef_lh_i32x16, 0xDD);
702
+ __m512i v6_a_i32x16 = _mm512_shuffle_i32x4(u89ab_hl_i32x16, ucdef_hl_i32x16, 0x88);
703
+ __m512i v6_b_i32x16 = _mm512_shuffle_i32x4(u89ab_hl_i32x16, ucdef_hl_i32x16, 0xDD);
704
+ __m512i v7_a_i32x16 = _mm512_shuffle_i32x4(u89ab_hh_i32x16, ucdef_hh_i32x16, 0x88);
705
+ __m512i v7_b_i32x16 = _mm512_shuffle_i32x4(u89ab_hh_i32x16, ucdef_hh_i32x16, 0xDD);
705
706
 
706
707
  // Stage 4: Final 256-bit shuffle to complete transpose
707
- __m512i out00 = _mm512_shuffle_i32x4(v0_a, v4_a, 0x88);
708
- __m512i out01 = _mm512_shuffle_i32x4(v1_a, v5_a, 0x88);
709
- __m512i out02 = _mm512_shuffle_i32x4(v2_a, v6_a, 0x88);
710
- __m512i out03 = _mm512_shuffle_i32x4(v3_a, v7_a, 0x88);
711
- __m512i out04 = _mm512_shuffle_i32x4(v0_a, v4_a, 0xDD);
712
- __m512i out05 = _mm512_shuffle_i32x4(v1_a, v5_a, 0xDD);
713
- __m512i out06 = _mm512_shuffle_i32x4(v2_a, v6_a, 0xDD);
714
- __m512i out07 = _mm512_shuffle_i32x4(v3_a, v7_a, 0xDD);
715
- __m512i out08 = _mm512_shuffle_i32x4(v0_b, v4_b, 0x88);
716
- __m512i out09 = _mm512_shuffle_i32x4(v1_b, v5_b, 0x88);
717
- __m512i out10 = _mm512_shuffle_i32x4(v2_b, v6_b, 0x88);
718
- __m512i out11 = _mm512_shuffle_i32x4(v3_b, v7_b, 0x88);
719
- __m512i out12 = _mm512_shuffle_i32x4(v0_b, v4_b, 0xDD);
720
- __m512i out13 = _mm512_shuffle_i32x4(v1_b, v5_b, 0xDD);
721
- __m512i out14 = _mm512_shuffle_i32x4(v2_b, v6_b, 0xDD);
722
- __m512i out15 = _mm512_shuffle_i32x4(v3_b, v7_b, 0xDD);
708
+ __m512i out00_i32x16 = _mm512_shuffle_i32x4(v0_a_i32x16, v4_a_i32x16, 0x88);
709
+ __m512i out01_i32x16 = _mm512_shuffle_i32x4(v1_a_i32x16, v5_a_i32x16, 0x88);
710
+ __m512i out02_i32x16 = _mm512_shuffle_i32x4(v2_a_i32x16, v6_a_i32x16, 0x88);
711
+ __m512i out03_i32x16 = _mm512_shuffle_i32x4(v3_a_i32x16, v7_a_i32x16, 0x88);
712
+ __m512i out04_i32x16 = _mm512_shuffle_i32x4(v0_a_i32x16, v4_a_i32x16, 0xDD);
713
+ __m512i out05_i32x16 = _mm512_shuffle_i32x4(v1_a_i32x16, v5_a_i32x16, 0xDD);
714
+ __m512i out06_i32x16 = _mm512_shuffle_i32x4(v2_a_i32x16, v6_a_i32x16, 0xDD);
715
+ __m512i out07_i32x16 = _mm512_shuffle_i32x4(v3_a_i32x16, v7_a_i32x16, 0xDD);
716
+ __m512i out08_i32x16 = _mm512_shuffle_i32x4(v0_b_i32x16, v4_b_i32x16, 0x88);
717
+ __m512i out09_i32x16 = _mm512_shuffle_i32x4(v1_b_i32x16, v5_b_i32x16, 0x88);
718
+ __m512i out10_i32x16 = _mm512_shuffle_i32x4(v2_b_i32x16, v6_b_i32x16, 0x88);
719
+ __m512i out11_i32x16 = _mm512_shuffle_i32x4(v3_b_i32x16, v7_b_i32x16, 0x88);
720
+ __m512i out12_i32x16 = _mm512_shuffle_i32x4(v0_b_i32x16, v4_b_i32x16, 0xDD);
721
+ __m512i out13_i32x16 = _mm512_shuffle_i32x4(v1_b_i32x16, v5_b_i32x16, 0xDD);
722
+ __m512i out14_i32x16 = _mm512_shuffle_i32x4(v2_b_i32x16, v6_b_i32x16, 0xDD);
723
+ __m512i out15_i32x16 = _mm512_shuffle_i32x4(v3_b_i32x16, v7_b_i32x16, 0xDD);
723
724
 
724
725
  // Store transposed results - each output row is one depth_group
725
726
  // Output layout: B.data[depth_group][column][pair] = 16 columns × 2 BF16 = 64 bytes
726
- _mm512_store_si512(&b_tile->data[0][0][0], out00);
727
- _mm512_store_si512(&b_tile->data[1][0][0], out01);
728
- _mm512_store_si512(&b_tile->data[2][0][0], out02);
729
- _mm512_store_si512(&b_tile->data[3][0][0], out03);
730
- _mm512_store_si512(&b_tile->data[4][0][0], out08);
731
- _mm512_store_si512(&b_tile->data[5][0][0], out09);
732
- _mm512_store_si512(&b_tile->data[6][0][0], out10);
733
- _mm512_store_si512(&b_tile->data[7][0][0], out11);
734
- _mm512_store_si512(&b_tile->data[8][0][0], out04);
735
- _mm512_store_si512(&b_tile->data[9][0][0], out05);
736
- _mm512_store_si512(&b_tile->data[10][0][0], out06);
737
- _mm512_store_si512(&b_tile->data[11][0][0], out07);
738
- _mm512_store_si512(&b_tile->data[12][0][0], out12);
739
- _mm512_store_si512(&b_tile->data[13][0][0], out13);
740
- _mm512_store_si512(&b_tile->data[14][0][0], out14);
741
- _mm512_store_si512(&b_tile->data[15][0][0], out15);
727
+ _mm512_store_si512(&b_tile->data[0][0][0], out00_i32x16);
728
+ _mm512_store_si512(&b_tile->data[1][0][0], out01_i32x16);
729
+ _mm512_store_si512(&b_tile->data[2][0][0], out02_i32x16);
730
+ _mm512_store_si512(&b_tile->data[3][0][0], out03_i32x16);
731
+ _mm512_store_si512(&b_tile->data[4][0][0], out08_i32x16);
732
+ _mm512_store_si512(&b_tile->data[5][0][0], out09_i32x16);
733
+ _mm512_store_si512(&b_tile->data[6][0][0], out10_i32x16);
734
+ _mm512_store_si512(&b_tile->data[7][0][0], out11_i32x16);
735
+ _mm512_store_si512(&b_tile->data[8][0][0], out04_i32x16);
736
+ _mm512_store_si512(&b_tile->data[9][0][0], out05_i32x16);
737
+ _mm512_store_si512(&b_tile->data[10][0][0], out06_i32x16);
738
+ _mm512_store_si512(&b_tile->data[11][0][0], out07_i32x16);
739
+ _mm512_store_si512(&b_tile->data[12][0][0], out12_i32x16);
740
+ _mm512_store_si512(&b_tile->data[13][0][0], out13_i32x16);
741
+ _mm512_store_si512(&b_tile->data[14][0][0], out14_i32x16);
742
+ _mm512_store_si512(&b_tile->data[15][0][0], out15_i32x16);
742
743
 
743
744
  nk_compiler_barrier_sapphireamx_();
744
745
  }
@@ -750,119 +751,119 @@ NK_INTERNAL void nk_dots_pack_i8_transposed_sapphireamx_( //
750
751
 
751
752
  // Load all 16 rows - each row is 64 INT8 = 64 bytes = 1 ZMM
752
753
  // Treat as 16 × 32-bit elements per row (each 32-bit = quad of INT8)
753
- __m512i row00 = _mm512_load_si512(&a_tile->data[0][0]);
754
- __m512i row01 = _mm512_load_si512(&a_tile->data[1][0]);
755
- __m512i row02 = _mm512_load_si512(&a_tile->data[2][0]);
756
- __m512i row03 = _mm512_load_si512(&a_tile->data[3][0]);
757
- __m512i row04 = _mm512_load_si512(&a_tile->data[4][0]);
758
- __m512i row05 = _mm512_load_si512(&a_tile->data[5][0]);
759
- __m512i row06 = _mm512_load_si512(&a_tile->data[6][0]);
760
- __m512i row07 = _mm512_load_si512(&a_tile->data[7][0]);
761
- __m512i row08 = _mm512_load_si512(&a_tile->data[8][0]);
762
- __m512i row09 = _mm512_load_si512(&a_tile->data[9][0]);
763
- __m512i row10 = _mm512_load_si512(&a_tile->data[10][0]);
764
- __m512i row11 = _mm512_load_si512(&a_tile->data[11][0]);
765
- __m512i row12 = _mm512_load_si512(&a_tile->data[12][0]);
766
- __m512i row13 = _mm512_load_si512(&a_tile->data[13][0]);
767
- __m512i row14 = _mm512_load_si512(&a_tile->data[14][0]);
768
- __m512i row15 = _mm512_load_si512(&a_tile->data[15][0]);
754
+ __m512i row00_i32x16 = _mm512_load_si512(&a_tile->data[0][0]);
755
+ __m512i row01_i32x16 = _mm512_load_si512(&a_tile->data[1][0]);
756
+ __m512i row02_i32x16 = _mm512_load_si512(&a_tile->data[2][0]);
757
+ __m512i row03_i32x16 = _mm512_load_si512(&a_tile->data[3][0]);
758
+ __m512i row04_i32x16 = _mm512_load_si512(&a_tile->data[4][0]);
759
+ __m512i row05_i32x16 = _mm512_load_si512(&a_tile->data[5][0]);
760
+ __m512i row06_i32x16 = _mm512_load_si512(&a_tile->data[6][0]);
761
+ __m512i row07_i32x16 = _mm512_load_si512(&a_tile->data[7][0]);
762
+ __m512i row08_i32x16 = _mm512_load_si512(&a_tile->data[8][0]);
763
+ __m512i row09_i32x16 = _mm512_load_si512(&a_tile->data[9][0]);
764
+ __m512i row10_i32x16 = _mm512_load_si512(&a_tile->data[10][0]);
765
+ __m512i row11_i32x16 = _mm512_load_si512(&a_tile->data[11][0]);
766
+ __m512i row12_i32x16 = _mm512_load_si512(&a_tile->data[12][0]);
767
+ __m512i row13_i32x16 = _mm512_load_si512(&a_tile->data[13][0]);
768
+ __m512i row14_i32x16 = _mm512_load_si512(&a_tile->data[14][0]);
769
+ __m512i row15_i32x16 = _mm512_load_si512(&a_tile->data[15][0]);
769
770
 
770
771
  // 16×16 transpose of 32-bit elements using hierarchical unpacks
771
772
  // Stage 1: Unpack adjacent row pairs at 32-bit granularity
772
- __m512i t01_lo = _mm512_unpacklo_epi32(row00, row01);
773
- __m512i t01_hi = _mm512_unpackhi_epi32(row00, row01);
774
- __m512i t23_lo = _mm512_unpacklo_epi32(row02, row03);
775
- __m512i t23_hi = _mm512_unpackhi_epi32(row02, row03);
776
- __m512i t45_lo = _mm512_unpacklo_epi32(row04, row05);
777
- __m512i t45_hi = _mm512_unpackhi_epi32(row04, row05);
778
- __m512i t67_lo = _mm512_unpacklo_epi32(row06, row07);
779
- __m512i t67_hi = _mm512_unpackhi_epi32(row06, row07);
780
- __m512i t89_lo = _mm512_unpacklo_epi32(row08, row09);
781
- __m512i t89_hi = _mm512_unpackhi_epi32(row08, row09);
782
- __m512i tab_lo = _mm512_unpacklo_epi32(row10, row11);
783
- __m512i tab_hi = _mm512_unpackhi_epi32(row10, row11);
784
- __m512i tcd_lo = _mm512_unpacklo_epi32(row12, row13);
785
- __m512i tcd_hi = _mm512_unpackhi_epi32(row12, row13);
786
- __m512i tef_lo = _mm512_unpacklo_epi32(row14, row15);
787
- __m512i tef_hi = _mm512_unpackhi_epi32(row14, row15);
773
+ __m512i t01_low_i32x16 = _mm512_unpacklo_epi32(row00_i32x16, row01_i32x16);
774
+ __m512i t01_high_i32x16 = _mm512_unpackhi_epi32(row00_i32x16, row01_i32x16);
775
+ __m512i t23_low_i32x16 = _mm512_unpacklo_epi32(row02_i32x16, row03_i32x16);
776
+ __m512i t23_high_i32x16 = _mm512_unpackhi_epi32(row02_i32x16, row03_i32x16);
777
+ __m512i t45_low_i32x16 = _mm512_unpacklo_epi32(row04_i32x16, row05_i32x16);
778
+ __m512i t45_high_i32x16 = _mm512_unpackhi_epi32(row04_i32x16, row05_i32x16);
779
+ __m512i t67_low_i32x16 = _mm512_unpacklo_epi32(row06_i32x16, row07_i32x16);
780
+ __m512i t67_high_i32x16 = _mm512_unpackhi_epi32(row06_i32x16, row07_i32x16);
781
+ __m512i t89_low_i32x16 = _mm512_unpacklo_epi32(row08_i32x16, row09_i32x16);
782
+ __m512i t89_high_i32x16 = _mm512_unpackhi_epi32(row08_i32x16, row09_i32x16);
783
+ __m512i tab_low_i32x16 = _mm512_unpacklo_epi32(row10_i32x16, row11_i32x16);
784
+ __m512i tab_high_i32x16 = _mm512_unpackhi_epi32(row10_i32x16, row11_i32x16);
785
+ __m512i tcd_low_i32x16 = _mm512_unpacklo_epi32(row12_i32x16, row13_i32x16);
786
+ __m512i tcd_high_i32x16 = _mm512_unpackhi_epi32(row12_i32x16, row13_i32x16);
787
+ __m512i tef_low_i32x16 = _mm512_unpacklo_epi32(row14_i32x16, row15_i32x16);
788
+ __m512i tef_high_i32x16 = _mm512_unpackhi_epi32(row14_i32x16, row15_i32x16);
788
789
 
789
790
  // Stage 2: Unpack at 64-bit granularity
790
- __m512i u0123_ll = _mm512_unpacklo_epi64(t01_lo, t23_lo);
791
- __m512i u0123_lh = _mm512_unpackhi_epi64(t01_lo, t23_lo);
792
- __m512i u0123_hl = _mm512_unpacklo_epi64(t01_hi, t23_hi);
793
- __m512i u0123_hh = _mm512_unpackhi_epi64(t01_hi, t23_hi);
794
- __m512i u4567_ll = _mm512_unpacklo_epi64(t45_lo, t67_lo);
795
- __m512i u4567_lh = _mm512_unpackhi_epi64(t45_lo, t67_lo);
796
- __m512i u4567_hl = _mm512_unpacklo_epi64(t45_hi, t67_hi);
797
- __m512i u4567_hh = _mm512_unpackhi_epi64(t45_hi, t67_hi);
798
- __m512i u89ab_ll = _mm512_unpacklo_epi64(t89_lo, tab_lo);
799
- __m512i u89ab_lh = _mm512_unpackhi_epi64(t89_lo, tab_lo);
800
- __m512i u89ab_hl = _mm512_unpacklo_epi64(t89_hi, tab_hi);
801
- __m512i u89ab_hh = _mm512_unpackhi_epi64(t89_hi, tab_hi);
802
- __m512i ucdef_ll = _mm512_unpacklo_epi64(tcd_lo, tef_lo);
803
- __m512i ucdef_lh = _mm512_unpackhi_epi64(tcd_lo, tef_lo);
804
- __m512i ucdef_hl = _mm512_unpacklo_epi64(tcd_hi, tef_hi);
805
- __m512i ucdef_hh = _mm512_unpackhi_epi64(tcd_hi, tef_hi);
791
+ __m512i u0123_ll_i32x16 = _mm512_unpacklo_epi64(t01_low_i32x16, t23_low_i32x16);
792
+ __m512i u0123_lh_i32x16 = _mm512_unpackhi_epi64(t01_low_i32x16, t23_low_i32x16);
793
+ __m512i u0123_hl_i32x16 = _mm512_unpacklo_epi64(t01_high_i32x16, t23_high_i32x16);
794
+ __m512i u0123_hh_i32x16 = _mm512_unpackhi_epi64(t01_high_i32x16, t23_high_i32x16);
795
+ __m512i u4567_ll_i32x16 = _mm512_unpacklo_epi64(t45_low_i32x16, t67_low_i32x16);
796
+ __m512i u4567_lh_i32x16 = _mm512_unpackhi_epi64(t45_low_i32x16, t67_low_i32x16);
797
+ __m512i u4567_hl_i32x16 = _mm512_unpacklo_epi64(t45_high_i32x16, t67_high_i32x16);
798
+ __m512i u4567_hh_i32x16 = _mm512_unpackhi_epi64(t45_high_i32x16, t67_high_i32x16);
799
+ __m512i u89ab_ll_i32x16 = _mm512_unpacklo_epi64(t89_low_i32x16, tab_low_i32x16);
800
+ __m512i u89ab_lh_i32x16 = _mm512_unpackhi_epi64(t89_low_i32x16, tab_low_i32x16);
801
+ __m512i u89ab_hl_i32x16 = _mm512_unpacklo_epi64(t89_high_i32x16, tab_high_i32x16);
802
+ __m512i u89ab_hh_i32x16 = _mm512_unpackhi_epi64(t89_high_i32x16, tab_high_i32x16);
803
+ __m512i ucdef_ll_i32x16 = _mm512_unpacklo_epi64(tcd_low_i32x16, tef_low_i32x16);
804
+ __m512i ucdef_lh_i32x16 = _mm512_unpackhi_epi64(tcd_low_i32x16, tef_low_i32x16);
805
+ __m512i ucdef_hl_i32x16 = _mm512_unpacklo_epi64(tcd_high_i32x16, tef_high_i32x16);
806
+ __m512i ucdef_hh_i32x16 = _mm512_unpackhi_epi64(tcd_high_i32x16, tef_high_i32x16);
806
807
 
807
808
  // Stage 3: Shuffle 128-bit lanes
808
- __m512i v0_a = _mm512_shuffle_i32x4(u0123_ll, u4567_ll, 0x88);
809
- __m512i v0_b = _mm512_shuffle_i32x4(u0123_ll, u4567_ll, 0xDD);
810
- __m512i v1_a = _mm512_shuffle_i32x4(u0123_lh, u4567_lh, 0x88);
811
- __m512i v1_b = _mm512_shuffle_i32x4(u0123_lh, u4567_lh, 0xDD);
812
- __m512i v2_a = _mm512_shuffle_i32x4(u0123_hl, u4567_hl, 0x88);
813
- __m512i v2_b = _mm512_shuffle_i32x4(u0123_hl, u4567_hl, 0xDD);
814
- __m512i v3_a = _mm512_shuffle_i32x4(u0123_hh, u4567_hh, 0x88);
815
- __m512i v3_b = _mm512_shuffle_i32x4(u0123_hh, u4567_hh, 0xDD);
816
- __m512i v4_a = _mm512_shuffle_i32x4(u89ab_ll, ucdef_ll, 0x88);
817
- __m512i v4_b = _mm512_shuffle_i32x4(u89ab_ll, ucdef_ll, 0xDD);
818
- __m512i v5_a = _mm512_shuffle_i32x4(u89ab_lh, ucdef_lh, 0x88);
819
- __m512i v5_b = _mm512_shuffle_i32x4(u89ab_lh, ucdef_lh, 0xDD);
820
- __m512i v6_a = _mm512_shuffle_i32x4(u89ab_hl, ucdef_hl, 0x88);
821
- __m512i v6_b = _mm512_shuffle_i32x4(u89ab_hl, ucdef_hl, 0xDD);
822
- __m512i v7_a = _mm512_shuffle_i32x4(u89ab_hh, ucdef_hh, 0x88);
823
- __m512i v7_b = _mm512_shuffle_i32x4(u89ab_hh, ucdef_hh, 0xDD);
809
+ __m512i v0_a_i32x16 = _mm512_shuffle_i32x4(u0123_ll_i32x16, u4567_ll_i32x16, 0x88);
810
+ __m512i v0_b_i32x16 = _mm512_shuffle_i32x4(u0123_ll_i32x16, u4567_ll_i32x16, 0xDD);
811
+ __m512i v1_a_i32x16 = _mm512_shuffle_i32x4(u0123_lh_i32x16, u4567_lh_i32x16, 0x88);
812
+ __m512i v1_b_i32x16 = _mm512_shuffle_i32x4(u0123_lh_i32x16, u4567_lh_i32x16, 0xDD);
813
+ __m512i v2_a_i32x16 = _mm512_shuffle_i32x4(u0123_hl_i32x16, u4567_hl_i32x16, 0x88);
814
+ __m512i v2_b_i32x16 = _mm512_shuffle_i32x4(u0123_hl_i32x16, u4567_hl_i32x16, 0xDD);
815
+ __m512i v3_a_i32x16 = _mm512_shuffle_i32x4(u0123_hh_i32x16, u4567_hh_i32x16, 0x88);
816
+ __m512i v3_b_i32x16 = _mm512_shuffle_i32x4(u0123_hh_i32x16, u4567_hh_i32x16, 0xDD);
817
+ __m512i v4_a_i32x16 = _mm512_shuffle_i32x4(u89ab_ll_i32x16, ucdef_ll_i32x16, 0x88);
818
+ __m512i v4_b_i32x16 = _mm512_shuffle_i32x4(u89ab_ll_i32x16, ucdef_ll_i32x16, 0xDD);
819
+ __m512i v5_a_i32x16 = _mm512_shuffle_i32x4(u89ab_lh_i32x16, ucdef_lh_i32x16, 0x88);
820
+ __m512i v5_b_i32x16 = _mm512_shuffle_i32x4(u89ab_lh_i32x16, ucdef_lh_i32x16, 0xDD);
821
+ __m512i v6_a_i32x16 = _mm512_shuffle_i32x4(u89ab_hl_i32x16, ucdef_hl_i32x16, 0x88);
822
+ __m512i v6_b_i32x16 = _mm512_shuffle_i32x4(u89ab_hl_i32x16, ucdef_hl_i32x16, 0xDD);
823
+ __m512i v7_a_i32x16 = _mm512_shuffle_i32x4(u89ab_hh_i32x16, ucdef_hh_i32x16, 0x88);
824
+ __m512i v7_b_i32x16 = _mm512_shuffle_i32x4(u89ab_hh_i32x16, ucdef_hh_i32x16, 0xDD);
824
825
 
825
826
  // Stage 4: Second 128-bit lane shuffle, pairing rows 0-7 (v0..v3) with rows 8-15 (v4..v7), to complete transpose
826
- __m512i out00 = _mm512_shuffle_i32x4(v0_a, v4_a, 0x88);
827
- __m512i out01 = _mm512_shuffle_i32x4(v1_a, v5_a, 0x88);
828
- __m512i out02 = _mm512_shuffle_i32x4(v2_a, v6_a, 0x88);
829
- __m512i out03 = _mm512_shuffle_i32x4(v3_a, v7_a, 0x88);
830
- __m512i out04 = _mm512_shuffle_i32x4(v0_a, v4_a, 0xDD);
831
- __m512i out05 = _mm512_shuffle_i32x4(v1_a, v5_a, 0xDD);
832
- __m512i out06 = _mm512_shuffle_i32x4(v2_a, v6_a, 0xDD);
833
- __m512i out07 = _mm512_shuffle_i32x4(v3_a, v7_a, 0xDD);
834
- __m512i out08 = _mm512_shuffle_i32x4(v0_b, v4_b, 0x88);
835
- __m512i out09 = _mm512_shuffle_i32x4(v1_b, v5_b, 0x88);
836
- __m512i out10 = _mm512_shuffle_i32x4(v2_b, v6_b, 0x88);
837
- __m512i out11 = _mm512_shuffle_i32x4(v3_b, v7_b, 0x88);
838
- __m512i out12 = _mm512_shuffle_i32x4(v0_b, v4_b, 0xDD);
839
- __m512i out13 = _mm512_shuffle_i32x4(v1_b, v5_b, 0xDD);
840
- __m512i out14 = _mm512_shuffle_i32x4(v2_b, v6_b, 0xDD);
841
- __m512i out15 = _mm512_shuffle_i32x4(v3_b, v7_b, 0xDD);
827
+ __m512i out00_i32x16 = _mm512_shuffle_i32x4(v0_a_i32x16, v4_a_i32x16, 0x88);
828
+ __m512i out01_i32x16 = _mm512_shuffle_i32x4(v1_a_i32x16, v5_a_i32x16, 0x88);
829
+ __m512i out02_i32x16 = _mm512_shuffle_i32x4(v2_a_i32x16, v6_a_i32x16, 0x88);
830
+ __m512i out03_i32x16 = _mm512_shuffle_i32x4(v3_a_i32x16, v7_a_i32x16, 0x88);
831
+ __m512i out04_i32x16 = _mm512_shuffle_i32x4(v0_a_i32x16, v4_a_i32x16, 0xDD);
832
+ __m512i out05_i32x16 = _mm512_shuffle_i32x4(v1_a_i32x16, v5_a_i32x16, 0xDD);
833
+ __m512i out06_i32x16 = _mm512_shuffle_i32x4(v2_a_i32x16, v6_a_i32x16, 0xDD);
834
+ __m512i out07_i32x16 = _mm512_shuffle_i32x4(v3_a_i32x16, v7_a_i32x16, 0xDD);
835
+ __m512i out08_i32x16 = _mm512_shuffle_i32x4(v0_b_i32x16, v4_b_i32x16, 0x88);
836
+ __m512i out09_i32x16 = _mm512_shuffle_i32x4(v1_b_i32x16, v5_b_i32x16, 0x88);
837
+ __m512i out10_i32x16 = _mm512_shuffle_i32x4(v2_b_i32x16, v6_b_i32x16, 0x88);
838
+ __m512i out11_i32x16 = _mm512_shuffle_i32x4(v3_b_i32x16, v7_b_i32x16, 0x88);
839
+ __m512i out12_i32x16 = _mm512_shuffle_i32x4(v0_b_i32x16, v4_b_i32x16, 0xDD);
840
+ __m512i out13_i32x16 = _mm512_shuffle_i32x4(v1_b_i32x16, v5_b_i32x16, 0xDD);
841
+ __m512i out14_i32x16 = _mm512_shuffle_i32x4(v2_b_i32x16, v6_b_i32x16, 0xDD);
842
+ __m512i out15_i32x16 = _mm512_shuffle_i32x4(v3_b_i32x16, v7_b_i32x16, 0xDD);
842
843
 
843
844
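  // Store transposed results - each output row is one depth_group; Stage 4 leaves out04..out07 holding groups 8-11 and out08..out11 holding groups 4-7, hence the permuted store order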
844
845
  // Output layout: B.data[depth_group][column][quad] = 16 columns × 4 INT8 = 64 bytes
845
- _mm512_store_si512(&b_tile->data[0][0][0], out00);
846
- _mm512_store_si512(&b_tile->data[1][0][0], out01);
847
- _mm512_store_si512(&b_tile->data[2][0][0], out02);
848
- _mm512_store_si512(&b_tile->data[3][0][0], out03);
849
- _mm512_store_si512(&b_tile->data[4][0][0], out08);
850
- _mm512_store_si512(&b_tile->data[5][0][0], out09);
851
- _mm512_store_si512(&b_tile->data[6][0][0], out10);
852
- _mm512_store_si512(&b_tile->data[7][0][0], out11);
853
- _mm512_store_si512(&b_tile->data[8][0][0], out04);
854
- _mm512_store_si512(&b_tile->data[9][0][0], out05);
855
- _mm512_store_si512(&b_tile->data[10][0][0], out06);
856
- _mm512_store_si512(&b_tile->data[11][0][0], out07);
857
- _mm512_store_si512(&b_tile->data[12][0][0], out12);
858
- _mm512_store_si512(&b_tile->data[13][0][0], out13);
859
- _mm512_store_si512(&b_tile->data[14][0][0], out14);
860
- _mm512_store_si512(&b_tile->data[15][0][0], out15);
846
+ _mm512_store_si512(&b_tile->data[0][0][0], out00_i32x16);
847
+ _mm512_store_si512(&b_tile->data[1][0][0], out01_i32x16);
848
+ _mm512_store_si512(&b_tile->data[2][0][0], out02_i32x16);
849
+ _mm512_store_si512(&b_tile->data[3][0][0], out03_i32x16);
850
+ _mm512_store_si512(&b_tile->data[4][0][0], out08_i32x16);
851
+ _mm512_store_si512(&b_tile->data[5][0][0], out09_i32x16);
852
+ _mm512_store_si512(&b_tile->data[6][0][0], out10_i32x16);
853
+ _mm512_store_si512(&b_tile->data[7][0][0], out11_i32x16);
854
+ _mm512_store_si512(&b_tile->data[8][0][0], out04_i32x16);
855
+ _mm512_store_si512(&b_tile->data[9][0][0], out05_i32x16);
856
+ _mm512_store_si512(&b_tile->data[10][0][0], out06_i32x16);
857
+ _mm512_store_si512(&b_tile->data[11][0][0], out07_i32x16);
858
+ _mm512_store_si512(&b_tile->data[12][0][0], out12_i32x16);
859
+ _mm512_store_si512(&b_tile->data[13][0][0], out13_i32x16);
860
+ _mm512_store_si512(&b_tile->data[14][0][0], out14_i32x16);
861
+ _mm512_store_si512(&b_tile->data[15][0][0], out15_i32x16);
861
862
 
862
863
  nk_compiler_barrier_sapphireamx_();
863
864
  }
864
865
 
865
- #pragma region Half Precision Floats
866
+ #pragma region F16 Floats
866
867
 
867
868
  NK_PUBLIC nk_size_t nk_dots_packed_size_bf16_sapphireamx(nk_size_t column_count, nk_size_t depth) {
868
869
  nk_size_t const tmm_rows = 16;
@@ -890,14 +891,14 @@ NK_PUBLIC nk_size_t nk_dots_packed_size_bf16_sapphireamx(nk_size_t column_count,
890
891
 
891
892
  NK_PUBLIC void nk_dots_pack_bf16_sapphireamx( //
892
893
  nk_bf16_t const *b, nk_size_t column_count, nk_size_t depth, //
893
- nk_size_t b_stride, void *b_packed) {
894
+ nk_size_t b_stride_in_bytes, void *b_packed) {
894
895
 
895
896
  // AMX BF16 tile dimensions: 16 rows × 32 columns (512 BF16 elements = 1KB)
896
897
  nk_size_t const tmm_rows = 16;
897
898
  nk_size_t const tmm_cols = 32;
898
899
  nk_size_t const tile_elements = 512;
899
900
  nk_size_t const tile_bytes = tile_elements * sizeof(nk_bf16_t);
900
- nk_size_t const b_stride_elements = b_stride / sizeof(nk_bf16_t);
901
+ nk_size_t const b_stride_elements = b_stride_in_bytes / sizeof(nk_bf16_t);
901
902
 
902
903
  // Compute layout dimensions
903
904
  nk_size_t const column_tiles_count = column_count / tmm_rows;
@@ -920,36 +921,40 @@ NK_PUBLIC void nk_dots_pack_bf16_sapphireamx( //
920
921
  nk_bf16_t *tiles_ptr = (nk_bf16_t *)((char *)b_packed + tiles_offset);
921
922
  nk_bf16_t *column_edge_ptr = (nk_bf16_t *)((char *)b_packed + column_edge_offset);
922
923
 
923
- // Zero-initialize all tiles (handles depth remainder padding)
924
- for (nk_size_t idx = 0; idx < total_tiles * tile_elements; idx++) tiles_ptr[idx] = 0;
925
-
926
- // Pack tiles using LINEAR ordering: tile_index = column_tile × depth_tiles_count + depth_tile
927
- // This provides sequential memory access when streaming along depth dimension,
928
- // which is critical for cache efficiency in the compute kernel.
924
+ // Pack tiles using vectorized transposer: gather 16 strided rows into an aligned
925
+ // temporary, transpose via SIMD, then copy the result to the packed buffer.
929
926
  for (nk_size_t column_tile_idx = 0; column_tile_idx < column_tiles_count; column_tile_idx++) {
930
927
  for (nk_size_t depth_tile_idx = 0; depth_tile_idx < depth_tiles_count; depth_tile_idx++) {
931
928
 
932
- // Linear tile index: all depth-tiles for one column-tile are contiguous
933
929
  nk_size_t const tile_index = column_tile_idx * depth_tiles_count + depth_tile_idx;
934
930
  nk_bf16_t *tile_output = tiles_ptr + tile_index * tile_elements;
935
931
 
936
- // Source coordinates in original B matrix
937
932
  nk_size_t const src_row_start = column_tile_idx * tmm_rows;
938
933
  nk_size_t const src_column_start = depth_tile_idx * tmm_cols;
939
934
  nk_size_t const columns_to_pack = (src_column_start + tmm_cols <= depth) ? tmm_cols
940
935
  : (depth - src_column_start);
941
936
 
942
- // Pack with pair-interleaving as required by TDPBF16PS instruction.
943
- // AMX expects: [col0_row0, col1_row0, col0_row1, col1_row1, col2_row0, col3_row0, ...]
944
- // Formula: packed_idx = (column / 2) × 32 + row × 2 + (column % 2)
945
- for (nk_size_t row_idx = 0; row_idx < tmm_rows; row_idx++) {
946
- for (nk_size_t column_idx = 0; column_idx < columns_to_pack; column_idx++) {
947
- nk_size_t const src_idx = (src_row_start + row_idx) * b_stride_elements + src_column_start +
948
- column_idx;
949
- nk_size_t const dst_idx = (column_idx / 2) * 32 + row_idx * 2 + (column_idx % 2);
950
- tile_output[dst_idx] = b[src_idx];
937
+ // Gather 16 strided source rows into a contiguous aligned tile
938
+ nk_dots_bf16_a16x32_sapphireamx_t source_tile;
939
+ if (columns_to_pack == tmm_cols) {
940
+ for (nk_size_t row_idx = 0; row_idx < tmm_rows; row_idx++) {
941
+ nk_bf16_t const *source_row = b + (src_row_start + row_idx) * b_stride_elements + src_column_start;
942
+ _mm512_store_si512(&source_tile.data[row_idx][0], _mm512_loadu_si512(source_row));
943
+ }
944
+ }
945
+ else {
946
+ __mmask32 depth_mask = (__mmask32)((columns_to_pack < 32) ? ((1U << columns_to_pack) - 1) : ~0U);
947
+ for (nk_size_t row_idx = 0; row_idx < tmm_rows; row_idx++) {
948
+ nk_bf16_t const *source_row = b + (src_row_start + row_idx) * b_stride_elements + src_column_start;
949
+ _mm512_store_si512(&source_tile.data[row_idx][0], _mm512_maskz_loadu_epi16(depth_mask, source_row));
951
950
  }
952
951
  }
952
+
953
+ // Transpose into aligned local, then copy to (potentially unaligned) packed buffer
954
+ nk_dots_bf16_b32x16_sapphireamx_t transposed_tile;
955
+ nk_dots_pack_bf16_transposed_sapphireamx_(&source_tile, &transposed_tile);
956
+ for (nk_size_t i = 0; i < tile_bytes; i += 64)
957
+ _mm512_storeu_si512((char *)tile_output + i, _mm512_load_si512((char const *)&transposed_tile + i));
953
958
  }
954
959
  }
955
960
 
@@ -1004,7 +1009,7 @@ NK_PUBLIC void nk_dots_packed_bf16_sapphireamx( //
1004
1009
  if (depth_tiles_count == 0) return;
1005
1010
 
1006
1011
  // Tile buffers for A (only used for edge tiles)
1007
- nk_dots_bf16_a16x32_sapphireamx_t a_tile_upper, a_tile_lower;
1012
+ nk_dots_bf16_a16x32_sapphireamx_t a_tile_top, a_tile_bottom;
1008
1013
  nk_dots_bf16_state2x2_sapphireamx_t c_accum_buffer;
1009
1014
 
1010
1015
  // Precompute: number of full depth-tiles (no masking needed)
@@ -1033,8 +1038,8 @@ NK_PUBLIC void nk_dots_packed_bf16_sapphireamx( //
1033
1038
 
1034
1039
  // Fast path: full row-block with full depth-tiles → direct A load with 2-deep pipelining
1035
1040
  if (is_full_row_block && full_depth_tiles_count > 0) {
1036
- nk_bf16_t const *a_upper_base = a + row_block_start * a_stride_elements;
1037
- nk_bf16_t const *a_lower_base = a + (row_block_start + 16) * a_stride_elements;
1041
+ nk_bf16_t const *a_top_base = a + row_block_start * a_stride_elements;
1042
+ nk_bf16_t const *a_bottom_base = a + (row_block_start + 16) * a_stride_elements;
1038
1043
 
1039
1044
  nk_dots_bf16_b32x16_sapphireamx_t const *b_tile_left =
1040
1045
  (nk_dots_bf16_b32x16_sapphireamx_t const *)(b_tiles_base + b_column_left_base * tile_size);
@@ -1042,8 +1047,8 @@ NK_PUBLIC void nk_dots_packed_bf16_sapphireamx( //
1042
1047
  (nk_dots_bf16_b32x16_sapphireamx_t const *)(b_tiles_base + b_column_right_base * tile_size);
1043
1048
 
1044
1049
  // Prologue: load first depth tile
1045
- _tile_loadd(0, a_upper_base, a_stride_bytes);
1046
- _tile_loadd(1, a_lower_base, a_stride_bytes);
1050
+ _tile_loadd(0, a_top_base, a_stride_bytes);
1051
+ _tile_loadd(1, a_bottom_base, a_stride_bytes);
1047
1052
  _tile_loadd(2, b_tile_left->data, 64);
1048
1053
  _tile_loadd(3, b_tile_right->data, 64);
1049
1054
 
@@ -1056,8 +1061,8 @@ NK_PUBLIC void nk_dots_packed_bf16_sapphireamx( //
1056
1061
  _tile_dpbf16ps(6, 1, 2);
1057
1062
  _tile_dpbf16ps(7, 1, 3);
1058
1063
 
1059
- _tile_loadd(0, a_upper_base + next_depth_offset, a_stride_bytes);
1060
- _tile_loadd(1, a_lower_base + next_depth_offset, a_stride_bytes);
1064
+ _tile_loadd(0, a_top_base + next_depth_offset, a_stride_bytes);
1065
+ _tile_loadd(1, a_bottom_base + next_depth_offset, a_stride_bytes);
1061
1066
  b_tile_left = (nk_dots_bf16_b32x16_sapphireamx_t const *)(b_tiles_base + (b_column_left_base +
1062
1067
  depth_tile_idx + 1) *
1063
1068
  tile_size);
@@ -1078,10 +1083,10 @@ NK_PUBLIC void nk_dots_packed_bf16_sapphireamx( //
1078
1083
  if (depth_remainder > 0) {
1079
1084
  nk_size_t const depth_offset = full_depth_tiles_count * tile_depth;
1080
1085
 
1081
- nk_dots_bf16_load_a_sapphireamx_(&a_tile_upper, a_upper_base + depth_offset, a_stride_elements, 16,
1082
- depth_remainder);
1083
- nk_dots_bf16_load_a_sapphireamx_(&a_tile_lower, a_lower_base + depth_offset, a_stride_elements, 16,
1086
+ nk_dots_bf16_load_a_sapphireamx_(&a_tile_top, a_top_base + depth_offset, a_stride_elements, 16,
1084
1087
  depth_remainder);
1088
+ nk_dots_bf16_load_a_sapphireamx_(&a_tile_bottom, a_bottom_base + depth_offset, a_stride_elements,
1089
+ 16, depth_remainder);
1085
1090
 
1086
1091
  b_tile_left = (nk_dots_bf16_b32x16_sapphireamx_t const *)(b_tiles_base + (b_column_left_base +
1087
1092
  full_depth_tiles_count) *
@@ -1090,8 +1095,8 @@ NK_PUBLIC void nk_dots_packed_bf16_sapphireamx( //
1090
1095
  full_depth_tiles_count) *
1091
1096
  tile_size);
1092
1097
 
1093
- _tile_loadd(0, a_tile_upper.data, 64);
1094
- _tile_loadd(1, a_tile_lower.data, 64);
1098
+ _tile_loadd(0, a_tile_top.data, 64);
1099
+ _tile_loadd(1, a_tile_bottom.data, 64);
1095
1100
  _tile_loadd(2, b_tile_left->data, 64);
1096
1101
  _tile_loadd(3, b_tile_right->data, 64);
1097
1102
 
@@ -1103,19 +1108,19 @@ NK_PUBLIC void nk_dots_packed_bf16_sapphireamx( //
1103
1108
  }
1104
1109
  // Full row-block but only partial depth tile (depth < tile_depth)
1105
1110
  else if (is_full_row_block) {
1106
- nk_bf16_t const *a_upper_base = a + row_block_start * a_stride_elements;
1107
- nk_bf16_t const *a_lower_base = a + (row_block_start + 16) * a_stride_elements;
1111
+ nk_bf16_t const *a_top_base = a + row_block_start * a_stride_elements;
1112
+ nk_bf16_t const *a_bottom_base = a + (row_block_start + 16) * a_stride_elements;
1108
1113
 
1109
- nk_dots_bf16_load_a_sapphireamx_(&a_tile_upper, a_upper_base, a_stride_elements, 16, depth_remainder);
1110
- nk_dots_bf16_load_a_sapphireamx_(&a_tile_lower, a_lower_base, a_stride_elements, 16, depth_remainder);
1114
+ nk_dots_bf16_load_a_sapphireamx_(&a_tile_top, a_top_base, a_stride_elements, 16, depth_remainder);
1115
+ nk_dots_bf16_load_a_sapphireamx_(&a_tile_bottom, a_bottom_base, a_stride_elements, 16, depth_remainder);
1111
1116
 
1112
1117
  nk_dots_bf16_b32x16_sapphireamx_t const *b_tile_left =
1113
1118
  (nk_dots_bf16_b32x16_sapphireamx_t const *)(b_tiles_base + b_column_left_base * tile_size);
1114
1119
  nk_dots_bf16_b32x16_sapphireamx_t const *b_tile_right =
1115
1120
  (nk_dots_bf16_b32x16_sapphireamx_t const *)(b_tiles_base + b_column_right_base * tile_size);
1116
1121
 
1117
- _tile_loadd(0, a_tile_upper.data, 64);
1118
- _tile_loadd(1, a_tile_lower.data, 64);
1122
+ _tile_loadd(0, a_tile_top.data, 64);
1123
+ _tile_loadd(1, a_tile_bottom.data, 64);
1119
1124
  _tile_loadd(2, b_tile_left->data, 64);
1120
1125
  _tile_loadd(3, b_tile_right->data, 64);
1121
1126
 
@@ -1126,21 +1131,21 @@ NK_PUBLIC void nk_dots_packed_bf16_sapphireamx( //
1126
1131
  }
1127
1132
  // Slow path: edge row-block → buffered load with masking
1128
1133
  else {
1129
- nk_size_t const rows_in_upper_tile = (valid_rows_count > 16) ? 16 : valid_rows_count;
1130
- nk_size_t const rows_in_lower_tile = (valid_rows_count > 16) ? valid_rows_count - 16 : 0;
1134
+ nk_size_t const rows_in_high_tile = (valid_rows_count > 16) ? 16 : valid_rows_count;
1135
+ nk_size_t const rows_in_low_tile = (valid_rows_count > 16) ? valid_rows_count - 16 : 0;
1131
1136
 
1132
1137
  for (nk_size_t depth_tile_idx = 0; depth_tile_idx < depth_tiles_count; depth_tile_idx++) {
1133
1138
  nk_size_t const depth_offset = depth_tile_idx * tile_depth;
1134
1139
  nk_size_t const valid_depth = (depth_tile_idx < full_depth_tiles_count) ? tile_depth
1135
1140
  : depth_remainder;
1136
1141
 
1137
- nk_dots_bf16_load_a_sapphireamx_(&a_tile_upper,
1142
+ nk_dots_bf16_load_a_sapphireamx_(&a_tile_top,
1138
1143
  a + row_block_start * a_stride_elements + depth_offset,
1139
- a_stride_elements, rows_in_upper_tile, valid_depth);
1140
- if (rows_in_lower_tile > 0) {
1141
- nk_dots_bf16_load_a_sapphireamx_(&a_tile_lower,
1144
+ a_stride_elements, rows_in_high_tile, valid_depth);
1145
+ if (rows_in_low_tile > 0) {
1146
+ nk_dots_bf16_load_a_sapphireamx_(&a_tile_bottom,
1142
1147
  a + (row_block_start + 16) * a_stride_elements + depth_offset,
1143
- a_stride_elements, rows_in_lower_tile, valid_depth);
1148
+ a_stride_elements, rows_in_low_tile, valid_depth);
1144
1149
  }
1145
1150
 
1146
1151
  nk_dots_bf16_b32x16_sapphireamx_t const *b_tile_left =
@@ -1150,8 +1155,8 @@ NK_PUBLIC void nk_dots_packed_bf16_sapphireamx( //
1150
1155
  (nk_dots_bf16_b32x16_sapphireamx_t const *)(b_tiles_base +
1151
1156
  (b_column_right_base + depth_tile_idx) * tile_size);
1152
1157
 
1153
- _tile_loadd(0, a_tile_upper.data, 64);
1154
- _tile_loadd(1, a_tile_lower.data, 64);
1158
+ _tile_loadd(0, a_tile_top.data, 64);
1159
+ _tile_loadd(1, a_tile_bottom.data, 64);
1155
1160
  _tile_loadd(2, b_tile_left->data, 64);
1156
1161
  _tile_loadd(3, b_tile_right->data, 64);
1157
1162
 
@@ -1192,10 +1197,10 @@ NK_PUBLIC void nk_dots_packed_bf16_sapphireamx( //
1192
1197
  nk_size_t const row_block_start = row_block_idx * 32;
1193
1198
  nk_size_t const valid_rows_count = (row_block_start + 32 <= rows_count) ? 32
1194
1199
  : (rows_count - row_block_start);
1195
- nk_size_t const rows_in_upper_tile = (valid_rows_count > 16) ? 16 : valid_rows_count;
1196
- nk_size_t const rows_in_lower_tile = (valid_rows_count > 16) ? valid_rows_count - 16 : 0;
1200
+ nk_size_t const rows_in_high_tile = (valid_rows_count > 16) ? 16 : valid_rows_count;
1201
+ nk_size_t const rows_in_low_tile = (valid_rows_count > 16) ? valid_rows_count - 16 : 0;
1197
1202
 
1198
- nk_dots_bf16_state_sapphireamx_t c_upper_state, c_lower_state;
1203
+ nk_dots_bf16_state_sapphireamx_t c_high_state, c_low_state;
1199
1204
 
1200
1205
  _tile_zero(4);
1201
1206
  _tile_zero(6);
@@ -1204,35 +1209,35 @@ NK_PUBLIC void nk_dots_packed_bf16_sapphireamx( //
1204
1209
  nk_size_t const depth_offset = depth_tile_idx * tile_depth;
1205
1210
  nk_size_t const valid_depth = (depth_tile_idx < full_depth_tiles_count) ? tile_depth : depth_remainder;
1206
1211
 
1207
- nk_dots_bf16_load_a_sapphireamx_(&a_tile_upper, a + row_block_start * a_stride_elements + depth_offset,
1208
- a_stride_elements, rows_in_upper_tile, valid_depth);
1209
- if (rows_in_lower_tile > 0) {
1210
- nk_dots_bf16_load_a_sapphireamx_(&a_tile_lower,
1212
+ nk_dots_bf16_load_a_sapphireamx_(&a_tile_top, a + row_block_start * a_stride_elements + depth_offset,
1213
+ a_stride_elements, rows_in_high_tile, valid_depth);
1214
+ if (rows_in_low_tile > 0) {
1215
+ nk_dots_bf16_load_a_sapphireamx_(&a_tile_bottom,
1211
1216
  a + (row_block_start + 16) * a_stride_elements + depth_offset,
1212
- a_stride_elements, rows_in_lower_tile, valid_depth);
1217
+ a_stride_elements, rows_in_low_tile, valid_depth);
1213
1218
  }
1214
1219
 
1215
1220
  nk_dots_bf16_b32x16_sapphireamx_t const *b_tile =
1216
1221
  (nk_dots_bf16_b32x16_sapphireamx_t const *)(b_tiles_base +
1217
1222
  (b_column_base + depth_tile_idx) * tile_size);
1218
1223
 
1219
- _tile_loadd(0, a_tile_upper.data, 64);
1220
- _tile_loadd(1, a_tile_lower.data, 64);
1224
+ _tile_loadd(0, a_tile_top.data, 64);
1225
+ _tile_loadd(1, a_tile_bottom.data, 64);
1221
1226
  _tile_loadd(2, b_tile->data, 64);
1222
1227
 
1223
1228
  _tile_dpbf16ps(4, 0, 2);
1224
1229
  _tile_dpbf16ps(6, 1, 2);
1225
1230
  }
1226
1231
 
1227
- _tile_stored(4, c_upper_state.data, 64);
1228
- _tile_stored(6, c_lower_state.data, 64);
1232
+ _tile_stored(4, c_high_state.data, 64);
1233
+ _tile_stored(6, c_low_state.data, 64);
1229
1234
 
1230
- nk_dots_bf16_store_sapphireamx_(&c_upper_state, c + row_block_start * c_stride_elements + col_start,
1231
- c_stride_elements, rows_in_upper_tile, 16);
1232
- if (rows_in_lower_tile > 0) {
1233
- nk_dots_bf16_store_sapphireamx_(&c_lower_state,
1235
+ nk_dots_bf16_store_sapphireamx_(&c_high_state, c + row_block_start * c_stride_elements + col_start,
1236
+ c_stride_elements, rows_in_high_tile, 16);
1237
+ if (rows_in_low_tile > 0) {
1238
+ nk_dots_bf16_store_sapphireamx_(&c_low_state,
1234
1239
  c + (row_block_start + 16) * c_stride_elements + col_start,
1235
- c_stride_elements, rows_in_lower_tile, 16);
1240
+ c_stride_elements, rows_in_low_tile, 16);
1236
1241
  }
1237
1242
  }
1238
1243
  }
@@ -1243,10 +1248,10 @@ NK_PUBLIC void nk_dots_packed_bf16_sapphireamx( //
1243
1248
  nk_size_t const row_block_start = row_block_idx * 32;
1244
1249
  nk_size_t const valid_rows_count = (row_block_start + 32 <= rows_count) ? 32
1245
1250
  : (rows_count - row_block_start);
1246
- nk_size_t const rows_in_upper_tile = (valid_rows_count > 16) ? 16 : valid_rows_count;
1247
- nk_size_t const rows_in_lower_tile = (valid_rows_count > 16) ? valid_rows_count - 16 : 0;
1251
+ nk_size_t const rows_in_high_tile = (valid_rows_count > 16) ? 16 : valid_rows_count;
1252
+ nk_size_t const rows_in_low_tile = (valid_rows_count > 16) ? valid_rows_count - 16 : 0;
1248
1253
 
1249
- nk_dots_bf16_state_sapphireamx_t c_upper_state, c_lower_state;
1254
+ nk_dots_bf16_state_sapphireamx_t c_high_state, c_low_state;
1250
1255
  nk_dots_bf16_a16x32_sapphireamx_t b_as_a;
1251
1256
  nk_dots_bf16_b32x16_sapphireamx_t b_tile;
1252
1257
 
@@ -1257,35 +1262,35 @@ NK_PUBLIC void nk_dots_packed_bf16_sapphireamx( //
1257
1262
  nk_size_t const depth_offset = depth_tile_idx * tile_depth;
1258
1263
  nk_size_t const valid_depth = (depth_tile_idx < full_depth_tiles_count) ? tile_depth : depth_remainder;
1259
1264
 
1260
- nk_dots_bf16_load_a_sapphireamx_(&a_tile_upper, a + row_block_start * a_stride_elements + depth_offset,
1261
- a_stride_elements, rows_in_upper_tile, valid_depth);
1262
- if (rows_in_lower_tile > 0) {
1263
- nk_dots_bf16_load_a_sapphireamx_(&a_tile_lower,
1265
+ nk_dots_bf16_load_a_sapphireamx_(&a_tile_top, a + row_block_start * a_stride_elements + depth_offset,
1266
+ a_stride_elements, rows_in_high_tile, valid_depth);
1267
+ if (rows_in_low_tile > 0) {
1268
+ nk_dots_bf16_load_a_sapphireamx_(&a_tile_bottom,
1264
1269
  a + (row_block_start + 16) * a_stride_elements + depth_offset,
1265
- a_stride_elements, rows_in_lower_tile, valid_depth);
1270
+ a_stride_elements, rows_in_low_tile, valid_depth);
1266
1271
  }
1267
1272
 
1268
1273
  nk_dots_bf16_load_a_sapphireamx_(&b_as_a, col_edge_ptr + depth_offset, depth, column_remainder_count,
1269
1274
  valid_depth);
1270
1275
  nk_dots_pack_bf16_transposed_sapphireamx_(&b_as_a, &b_tile);
1271
1276
 
1272
- _tile_loadd(0, a_tile_upper.data, 64);
1273
- _tile_loadd(1, a_tile_lower.data, 64);
1277
+ _tile_loadd(0, a_tile_top.data, 64);
1278
+ _tile_loadd(1, a_tile_bottom.data, 64);
1274
1279
  _tile_loadd(2, b_tile.data, 64);
1275
1280
 
1276
1281
  _tile_dpbf16ps(4, 0, 2);
1277
1282
  _tile_dpbf16ps(6, 1, 2);
1278
1283
  }
1279
1284
 
1280
- _tile_stored(4, c_upper_state.data, 64);
1281
- _tile_stored(6, c_lower_state.data, 64);
1285
+ _tile_stored(4, c_high_state.data, 64);
1286
+ _tile_stored(6, c_low_state.data, 64);
1282
1287
 
1283
- nk_dots_bf16_store_sapphireamx_(&c_upper_state, c + row_block_start * c_stride_elements + full_cols,
1284
- c_stride_elements, rows_in_upper_tile, column_remainder_count);
1285
- if (rows_in_lower_tile > 0) {
1286
- nk_dots_bf16_store_sapphireamx_(&c_lower_state,
1288
+ nk_dots_bf16_store_sapphireamx_(&c_high_state, c + row_block_start * c_stride_elements + full_cols,
1289
+ c_stride_elements, rows_in_high_tile, column_remainder_count);
1290
+ if (rows_in_low_tile > 0) {
1291
+ nk_dots_bf16_store_sapphireamx_(&c_low_state,
1287
1292
  c + (row_block_start + 16) * c_stride_elements + full_cols,
1288
- c_stride_elements, rows_in_lower_tile, column_remainder_count);
1293
+ c_stride_elements, rows_in_low_tile, column_remainder_count);
1289
1294
  }
1290
1295
  }
1291
1296
  }
@@ -1294,9 +1299,9 @@ NK_PUBLIC void nk_dots_packed_bf16_sapphireamx( //
1294
1299
  }
1295
1300
 
1296
1301
  NK_PUBLIC void nk_dots_compact_bf16_sapphireamx( //
1297
- void *c, nk_size_t row_count, nk_size_t column_count, nk_size_t c_stride) {
1302
+ void *c, nk_size_t row_count, nk_size_t column_count, nk_size_t c_stride_in_bytes) {
1298
1303
 
1299
- nk_size_t const c_stride_f32 = c_stride / sizeof(nk_f32_t);
1304
+ nk_size_t const c_stride_f32 = c_stride_in_bytes / sizeof(nk_f32_t);
1300
1305
  nk_f32_t const *c_f32 = (nk_f32_t const *)c;
1301
1306
  nk_bf16_t *c_bf16 = (nk_bf16_t *)c;
1302
1307
 
@@ -1322,18 +1327,18 @@ NK_PUBLIC void nk_dots_compact_bf16_sapphireamx( //
1322
1327
  }
1323
1328
  }
1324
1329
 
1325
- NK_PUBLIC void nk_dots_symmetric_bf16_sapphireamx( //
1326
- nk_bf16_t const *vectors, nk_size_t n_vectors, nk_size_t depth, //
1327
- nk_size_t stride, nk_f32_t *result, nk_size_t result_stride, //
1330
+ NK_PUBLIC void nk_dots_symmetric_bf16_sapphireamx( //
1331
+ nk_bf16_t const *vectors, nk_size_t vectors_count, nk_size_t depth, //
1332
+ nk_size_t stride_in_bytes, nk_f32_t *result, nk_size_t result_stride_in_bytes, //
1328
1333
  nk_size_t row_start, nk_size_t row_count) {
1329
1334
 
1330
- nk_size_t const stride_elements = stride / sizeof(nk_bf16_t);
1331
- nk_size_t const result_stride_elements = result_stride / sizeof(nk_f32_t);
1335
+ nk_size_t const stride_elements = stride_in_bytes / sizeof(nk_bf16_t);
1336
+ nk_size_t const result_stride_elements = result_stride_in_bytes / sizeof(nk_f32_t);
1332
1337
 
1333
1338
  // Handle row slicing: compute rows [row_start, row_end)
1334
1339
  nk_size_t const row_end = (row_count == 0)
1335
- ? n_vectors
1336
- : (row_start + row_count < n_vectors ? row_start + row_count : n_vectors);
1340
+ ? vectors_count
1341
+ : (row_start + row_count < vectors_count ? row_start + row_count : vectors_count);
1337
1342
 
1338
1343
  // Round depth up to multiple of 96 (3 tiles × 32 elements)
1339
1344
  nk_size_t const depth_tiles = nk_size_divide_round_up_(depth, 32);
@@ -1349,8 +1354,8 @@ NK_PUBLIC void nk_dots_symmetric_bf16_sapphireamx( //
1349
1354
  for (nk_size_t row_tile = row_start; row_tile < row_end; row_tile += 16) {
1350
1355
  nk_size_t const valid_rows = (row_tile + 16 <= row_end) ? 16 : (row_end - row_tile);
1351
1356
 
1352
- for (nk_size_t col_tile = 0; col_tile < n_vectors; col_tile += 16) {
1353
- nk_size_t const valid_cols = (col_tile + 16 <= n_vectors) ? 16 : (n_vectors - col_tile);
1357
+ for (nk_size_t col_tile = 0; col_tile < vectors_count; col_tile += 16) {
1358
+ nk_size_t const valid_cols = (col_tile + 16 <= vectors_count) ? 16 : (vectors_count - col_tile);
1354
1359
 
1355
1360
  nk_dots_bf16_init_sapphireamx_(&state);
1356
1361
 
@@ -1391,7 +1396,7 @@ NK_PUBLIC void nk_dots_symmetric_bf16_sapphireamx( //
1391
1396
  }
1392
1397
  }
1393
1398
 
1394
- #pragma endregion // Half Precision Floats
1399
+ #pragma endregion F16 Floats
1395
1400
 
1396
1401
  #pragma region Signed Integers
1397
1402
 
@@ -1421,7 +1426,7 @@ NK_PUBLIC nk_size_t nk_dots_packed_size_i8_sapphireamx(nk_size_t column_count, n
1421
1426
 
1422
1427
  NK_PUBLIC void nk_dots_pack_i8_sapphireamx( //
1423
1428
  nk_i8_t const *b, nk_size_t column_count, nk_size_t depth, //
1424
- nk_size_t b_stride, void *b_packed) {
1429
+ nk_size_t b_stride_in_bytes, void *b_packed) {
1425
1430
 
1426
1431
  // AMX I8 tile dimensions: 16 rows × 64 columns (1024 I8 elements = 1KB)
1427
1432
  nk_size_t const tmm_rows = 16;
@@ -1450,34 +1455,45 @@ NK_PUBLIC void nk_dots_pack_i8_sapphireamx( //
1450
1455
  nk_i8_t *tiles_ptr = (nk_i8_t *)((char *)b_packed + tiles_offset);
1451
1456
  nk_i8_t *column_edge_ptr = (nk_i8_t *)((char *)b_packed + column_edge_offset);
1452
1457
 
1453
- // Zero-initialize all tiles (handles depth remainder padding)
1454
- for (nk_size_t idx = 0; idx < total_tiles * tile_elements; idx++) tiles_ptr[idx] = 0;
1455
-
1456
- // Pack tiles using LINEAR ordering: tile_index = column_tile × depth_tiles_count + depth_tile
1457
- // This provides sequential memory access when streaming along depth dimension.
1458
+ // Pack tiles using vectorized transposer: gather 16 strided rows into an aligned
1459
+ // temporary, transpose via SIMD, then copy the result to the packed buffer.
1460
+ // Stack-local aligned tiles are needed because the packed buffer may not be 64-byte aligned.
1458
1461
  for (nk_size_t column_tile_idx = 0; column_tile_idx < column_tiles_count; column_tile_idx++) {
1459
1462
  for (nk_size_t depth_tile_idx = 0; depth_tile_idx < depth_tiles_count; depth_tile_idx++) {
1460
1463
 
1461
- // Linear tile index: all depth-tiles for one column-tile are contiguous
1462
1464
  nk_size_t const tile_index = column_tile_idx * depth_tiles_count + depth_tile_idx;
1463
1465
  nk_i8_t *tile_output = tiles_ptr + tile_index * tile_elements;
1464
1466
 
1465
- // Source coordinates in original B matrix
1466
1467
  nk_size_t const src_row_start = column_tile_idx * tmm_rows;
1467
1468
  nk_size_t const src_column_start = depth_tile_idx * tmm_cols;
1468
1469
  nk_size_t const columns_to_pack = (src_column_start + tmm_cols <= depth) ? tmm_cols
1469
1470
  : (depth - src_column_start);
1470
1471
 
1471
- // Pack with quad-interleaving as required by TDPBSSD instruction.
1472
- // AMX expects: [col0_row0, col1_row0, col2_row0, col3_row0, col0_row1, ...]
1473
- // Formula: packed_idx = (column / 4) × 64 + row × 4 + (column % 4)
1474
- for (nk_size_t row_idx = 0; row_idx < tmm_rows; row_idx++) {
1475
- for (nk_size_t column_idx = 0; column_idx < columns_to_pack; column_idx++) {
1476
- nk_size_t const src_idx = (src_row_start + row_idx) * b_stride + src_column_start + column_idx;
1477
- nk_size_t const dst_idx = (column_idx / 4) * 64 + row_idx * 4 + (column_idx % 4);
1478
- tile_output[dst_idx] = b[src_idx];
1472
+ // Gather 16 strided source rows into a contiguous aligned tile
1473
+ nk_dots_i8_a16x64_sapphireamx_t source_tile;
1474
+ if (columns_to_pack == tmm_cols) {
1475
+ for (nk_size_t row_idx = 0; row_idx < tmm_rows; row_idx++) {
1476
+ nk_i8_t const *source_row = (nk_i8_t const *)((char const *)b +
1477
+ (src_row_start + row_idx) * b_stride_in_bytes) +
1478
+ src_column_start;
1479
+ _mm512_store_si512(&source_tile.data[row_idx][0], _mm512_loadu_si512(source_row));
1480
+ }
1481
+ }
1482
+ else {
1483
+ __mmask64 depth_mask = (__mmask64)((columns_to_pack < 64) ? ((1ULL << columns_to_pack) - 1) : ~0ULL);
1484
+ for (nk_size_t row_idx = 0; row_idx < tmm_rows; row_idx++) {
1485
+ nk_i8_t const *source_row = (nk_i8_t const *)((char const *)b +
1486
+ (src_row_start + row_idx) * b_stride_in_bytes) +
1487
+ src_column_start;
1488
+ _mm512_store_si512(&source_tile.data[row_idx][0], _mm512_maskz_loadu_epi8(depth_mask, source_row));
1479
1489
  }
1480
1490
  }
1491
+
1492
+ // Transpose into aligned local, then copy to (potentially unaligned) packed buffer
1493
+ nk_dots_i8_b64x16_sapphireamx_t transposed_tile;
1494
+ nk_dots_pack_i8_transposed_sapphireamx_(&source_tile, &transposed_tile);
1495
+ for (nk_size_t i = 0; i < tile_elements; i += 64)
1496
+ _mm512_storeu_si512(tile_output + i, _mm512_load_si512((char const *)&transposed_tile + i));
1481
1497
  }
1482
1498
  }
1483
1499
 
@@ -1487,7 +1503,7 @@ NK_PUBLIC void nk_dots_pack_i8_sapphireamx( //
1487
1503
  for (nk_size_t row_idx = 0; row_idx < column_remainder_count; row_idx++) {
1488
1504
  for (nk_size_t column_idx = 0; column_idx < depth; column_idx++) {
1489
1505
  column_edge_ptr[row_idx * depth + column_idx] =
1490
- b[(remainder_start_row + row_idx) * b_stride + column_idx];
1506
+ b[(remainder_start_row + row_idx) * b_stride_in_bytes + column_idx];
1491
1507
  }
1492
1508
  }
1493
1509
  }
@@ -1497,7 +1513,8 @@ NK_PUBLIC void nk_dots_pack_i8_sapphireamx( //
1497
1513
  (column_remainder_count > 0 ? column_remainder_count * depth * sizeof(nk_i8_t) : 0);
1498
1514
  header->norms_byte_offset = (nk_u32_t)norms_offset;
1499
1515
  nk_u32_t *norms = (nk_u32_t *)((char *)b_packed + norms_offset);
1500
- for (nk_size_t col = 0; col < column_count; col++) norms[col] = nk_dots_reduce_sumsq_i8_(b + col * b_stride, depth);
1516
+ for (nk_size_t col = 0; col < column_count; col++)
1517
+ norms[col] = nk_dots_reduce_sumsq_i8_(b + col * b_stride_in_bytes, depth);
1501
1518
  }
1502
1519
 
1503
1520
  NK_PUBLIC void nk_dots_packed_i8_sapphireamx( //
@@ -1530,7 +1547,7 @@ NK_PUBLIC void nk_dots_packed_i8_sapphireamx( //
1530
1547
  if (depth_tiles_count == 0) return;
1531
1548
 
1532
1549
  // Tile buffers for A (only used for edge tiles)
1533
- nk_dots_i8_a16x64_sapphireamx_t a_tile_upper, a_tile_lower;
1550
+ nk_dots_i8_a16x64_sapphireamx_t a_tile_top, a_tile_bottom;
1534
1551
  nk_dots_i8_state2x2_sapphireamx_t c_accum_buffer;
1535
1552
 
1536
1553
  // Precompute: number of full depth-tiles (no masking needed)
@@ -1562,8 +1579,8 @@ NK_PUBLIC void nk_dots_packed_i8_sapphireamx( //
1562
1579
  // Fast path: full row-block with full depth-tiles → direct A load with 2-deep pipelining
1563
1580
  if (is_full_row_block && full_depth_tiles_count > 0) {
1564
1581
  // A row pointers for direct load
1565
- nk_i8_t const *a_upper_base = a + row_block_start * a_stride_bytes;
1566
- nk_i8_t const *a_lower_base = a + (row_block_start + 16) * a_stride_bytes;
1582
+ nk_i8_t const *a_top_base = a + row_block_start * a_stride_bytes;
1583
+ nk_i8_t const *a_bottom_base = a + (row_block_start + 16) * a_stride_bytes;
1567
1584
 
1568
1585
  // B tile pointers
1569
1586
  nk_dots_i8_b64x16_sapphireamx_t const *b_tile_left =
@@ -1572,8 +1589,8 @@ NK_PUBLIC void nk_dots_packed_i8_sapphireamx( //
1572
1589
  (nk_dots_i8_b64x16_sapphireamx_t const *)(b_tiles_base + b_column_right_base * tile_size);
1573
1590
 
1574
1591
  // Prologue: load first depth tile into TMM0-3
1575
- _tile_loadd(0, a_upper_base, a_stride_bytes);
1576
- _tile_loadd(1, a_lower_base, a_stride_bytes);
1592
+ _tile_loadd(0, a_top_base, a_stride_bytes);
1593
+ _tile_loadd(1, a_bottom_base, a_stride_bytes);
1577
1594
  _tile_loadd(2, b_tile_left->data, 64);
1578
1595
  _tile_loadd(3, b_tile_right->data, 64);
1579
1596
 
@@ -1586,8 +1603,8 @@ NK_PUBLIC void nk_dots_packed_i8_sapphireamx( //
1586
1603
  _tile_dpbssd(6, 1, 2);
1587
1604
  _tile_dpbssd(7, 1, 3);
1588
1605
 
1589
- _tile_loadd(0, a_upper_base + next_depth_offset, a_stride_bytes);
1590
- _tile_loadd(1, a_lower_base + next_depth_offset, a_stride_bytes);
1606
+ _tile_loadd(0, a_top_base + next_depth_offset, a_stride_bytes);
1607
+ _tile_loadd(1, a_bottom_base + next_depth_offset, a_stride_bytes);
1591
1608
  b_tile_left = (nk_dots_i8_b64x16_sapphireamx_t const *)(b_tiles_base +
1592
1609
  (b_column_left_base + depth_tile_idx + 1) *
1593
1610
  tile_size);
@@ -1608,9 +1625,9 @@ NK_PUBLIC void nk_dots_packed_i8_sapphireamx( //
1608
1625
  if (depth_remainder > 0) {
1609
1626
  nk_size_t const depth_offset = full_depth_tiles_count * tile_depth;
1610
1627
 
1611
- nk_dots_i8_load_a_sapphireamx_(&a_tile_upper, a_upper_base + depth_offset, a_stride_bytes, 16,
1628
+ nk_dots_i8_load_a_sapphireamx_(&a_tile_top, a_top_base + depth_offset, a_stride_bytes, 16,
1612
1629
  depth_remainder);
1613
- nk_dots_i8_load_a_sapphireamx_(&a_tile_lower, a_lower_base + depth_offset, a_stride_bytes, 16,
1630
+ nk_dots_i8_load_a_sapphireamx_(&a_tile_bottom, a_bottom_base + depth_offset, a_stride_bytes, 16,
1614
1631
  depth_remainder);
1615
1632
 
1616
1633
  b_tile_left = (nk_dots_i8_b64x16_sapphireamx_t const *)(b_tiles_base + (b_column_left_base +
@@ -1620,8 +1637,8 @@ NK_PUBLIC void nk_dots_packed_i8_sapphireamx( //
1620
1637
  full_depth_tiles_count) *
1621
1638
  tile_size);
1622
1639
 
1623
- _tile_loadd(0, a_tile_upper.data, 64);
1624
- _tile_loadd(1, a_tile_lower.data, 64);
1640
+ _tile_loadd(0, a_tile_top.data, 64);
1641
+ _tile_loadd(1, a_tile_bottom.data, 64);
1625
1642
  _tile_loadd(2, b_tile_left->data, 64);
1626
1643
  _tile_loadd(3, b_tile_right->data, 64);
1627
1644
 
@@ -1633,19 +1650,19 @@ NK_PUBLIC void nk_dots_packed_i8_sapphireamx( //
1633
1650
  }
1634
1651
  // Full row-block but only partial depth tile (depth < tile_depth)
1635
1652
  else if (is_full_row_block) {
1636
- nk_i8_t const *a_upper_base = a + row_block_start * a_stride_bytes;
1637
- nk_i8_t const *a_lower_base = a + (row_block_start + 16) * a_stride_bytes;
1653
+ nk_i8_t const *a_top_base = a + row_block_start * a_stride_bytes;
1654
+ nk_i8_t const *a_bottom_base = a + (row_block_start + 16) * a_stride_bytes;
1638
1655
 
1639
- nk_dots_i8_load_a_sapphireamx_(&a_tile_upper, a_upper_base, a_stride_bytes, 16, depth_remainder);
1640
- nk_dots_i8_load_a_sapphireamx_(&a_tile_lower, a_lower_base, a_stride_bytes, 16, depth_remainder);
1656
+ nk_dots_i8_load_a_sapphireamx_(&a_tile_top, a_top_base, a_stride_bytes, 16, depth_remainder);
1657
+ nk_dots_i8_load_a_sapphireamx_(&a_tile_bottom, a_bottom_base, a_stride_bytes, 16, depth_remainder);
1641
1658
 
1642
1659
  nk_dots_i8_b64x16_sapphireamx_t const *b_tile_left =
1643
1660
  (nk_dots_i8_b64x16_sapphireamx_t const *)(b_tiles_base + b_column_left_base * tile_size);
1644
1661
  nk_dots_i8_b64x16_sapphireamx_t const *b_tile_right =
1645
1662
  (nk_dots_i8_b64x16_sapphireamx_t const *)(b_tiles_base + b_column_right_base * tile_size);
1646
1663
 
1647
- _tile_loadd(0, a_tile_upper.data, 64);
1648
- _tile_loadd(1, a_tile_lower.data, 64);
1664
+ _tile_loadd(0, a_tile_top.data, 64);
1665
+ _tile_loadd(1, a_tile_bottom.data, 64);
1649
1666
  _tile_loadd(2, b_tile_left->data, 64);
1650
1667
  _tile_loadd(3, b_tile_right->data, 64);
1651
1668
 
@@ -1656,20 +1673,20 @@ NK_PUBLIC void nk_dots_packed_i8_sapphireamx( //
1656
1673
  }
1657
1674
  // Slow path: edge row-block → always use buffered load with masking
1658
1675
  else {
1659
- nk_size_t const rows_in_upper_tile = (valid_rows_count > 16) ? 16 : valid_rows_count;
1660
- nk_size_t const rows_in_lower_tile = (valid_rows_count > 16) ? valid_rows_count - 16 : 0;
1676
+ nk_size_t const rows_in_high_tile = (valid_rows_count > 16) ? 16 : valid_rows_count;
1677
+ nk_size_t const rows_in_low_tile = (valid_rows_count > 16) ? valid_rows_count - 16 : 0;
1661
1678
 
1662
1679
  for (nk_size_t depth_tile_idx = 0; depth_tile_idx < depth_tiles_count; depth_tile_idx++) {
1663
1680
  nk_size_t const depth_offset = depth_tile_idx * tile_depth;
1664
1681
  nk_size_t const valid_depth = (depth_tile_idx < full_depth_tiles_count) ? tile_depth
1665
1682
  : depth_remainder;
1666
1683
 
1667
- nk_dots_i8_load_a_sapphireamx_(&a_tile_upper, a + row_block_start * a_stride_bytes + depth_offset,
1668
- a_stride_bytes, rows_in_upper_tile, valid_depth);
1669
- if (rows_in_lower_tile > 0) {
1670
- nk_dots_i8_load_a_sapphireamx_(&a_tile_lower,
1684
+ nk_dots_i8_load_a_sapphireamx_(&a_tile_top, a + row_block_start * a_stride_bytes + depth_offset,
1685
+ a_stride_bytes, rows_in_high_tile, valid_depth);
1686
+ if (rows_in_low_tile > 0) {
1687
+ nk_dots_i8_load_a_sapphireamx_(&a_tile_bottom,
1671
1688
  a + (row_block_start + 16) * a_stride_bytes + depth_offset,
1672
- a_stride_bytes, rows_in_lower_tile, valid_depth);
1689
+ a_stride_bytes, rows_in_low_tile, valid_depth);
1673
1690
  }
1674
1691
 
1675
1692
  nk_dots_i8_b64x16_sapphireamx_t const *b_tile_left =
@@ -1679,8 +1696,8 @@ NK_PUBLIC void nk_dots_packed_i8_sapphireamx( //
1679
1696
  (nk_dots_i8_b64x16_sapphireamx_t const *)(b_tiles_base +
1680
1697
  (b_column_right_base + depth_tile_idx) * tile_size);
1681
1698
 
1682
- _tile_loadd(0, a_tile_upper.data, 64);
1683
- _tile_loadd(1, a_tile_lower.data, 64);
1699
+ _tile_loadd(0, a_tile_top.data, 64);
1700
+ _tile_loadd(1, a_tile_bottom.data, 64);
1684
1701
  _tile_loadd(2, b_tile_left->data, 64);
1685
1702
  _tile_loadd(3, b_tile_right->data, 64);
1686
1703
 
@@ -1716,11 +1733,11 @@ NK_PUBLIC void nk_dots_packed_i8_sapphireamx( //
1716
1733
  nk_size_t const column_tile_idx = column_tiles_count - 1;
1717
1734
  nk_size_t const col_start = column_tile_idx * 16;
1718
1735
  nk_size_t const b_column_base = column_tile_idx * depth_tiles_count;
1719
- nk_size_t const rows_in_upper_tile = (valid_rows_count > 16) ? 16 : valid_rows_count;
1720
- nk_size_t const rows_in_lower_tile = (valid_rows_count > 16) ? valid_rows_count - 16 : 0;
1736
+ nk_size_t const rows_in_high_tile = (valid_rows_count > 16) ? 16 : valid_rows_count;
1737
+ nk_size_t const rows_in_low_tile = (valid_rows_count > 16) ? valid_rows_count - 16 : 0;
1721
1738
 
1722
1739
  // Use 1 × 2 blocking for single column-tile (2 row-tiles × 1 column-tile)
1723
- nk_dots_i8_state_sapphireamx_t c_upper_state, c_lower_state;
1740
+ nk_dots_i8_state_sapphireamx_t c_high_state, c_low_state;
1724
1741
 
1725
1742
  _tile_zero(4);
1726
1743
  _tile_zero(6);
@@ -1729,44 +1746,43 @@ NK_PUBLIC void nk_dots_packed_i8_sapphireamx( //
1729
1746
  nk_size_t const depth_offset = depth_tile_idx * tile_depth;
1730
1747
  nk_size_t const valid_depth = (depth_tile_idx < full_depth_tiles_count) ? tile_depth : depth_remainder;
1731
1748
 
1732
- nk_dots_i8_load_a_sapphireamx_(&a_tile_upper, a + row_block_start * a_stride_bytes + depth_offset,
1733
- a_stride_bytes, rows_in_upper_tile, valid_depth);
1734
- if (rows_in_lower_tile > 0) {
1735
- nk_dots_i8_load_a_sapphireamx_(&a_tile_lower,
1749
+ nk_dots_i8_load_a_sapphireamx_(&a_tile_top, a + row_block_start * a_stride_bytes + depth_offset,
1750
+ a_stride_bytes, rows_in_high_tile, valid_depth);
1751
+ if (rows_in_low_tile > 0) {
1752
+ nk_dots_i8_load_a_sapphireamx_(&a_tile_bottom,
1736
1753
  a + (row_block_start + 16) * a_stride_bytes + depth_offset,
1737
- a_stride_bytes, rows_in_lower_tile, valid_depth);
1754
+ a_stride_bytes, rows_in_low_tile, valid_depth);
1738
1755
  }
1739
1756
 
1740
1757
  nk_dots_i8_b64x16_sapphireamx_t const *b_tile =
1741
1758
  (nk_dots_i8_b64x16_sapphireamx_t const *)(b_tiles_base +
1742
1759
  (b_column_base + depth_tile_idx) * tile_size);
1743
1760
 
1744
- _tile_loadd(0, a_tile_upper.data, 64);
1745
- _tile_loadd(1, a_tile_lower.data, 64);
1761
+ _tile_loadd(0, a_tile_top.data, 64);
1762
+ _tile_loadd(1, a_tile_bottom.data, 64);
1746
1763
  _tile_loadd(2, b_tile->data, 64);
1747
1764
 
1748
1765
  _tile_dpbssd(4, 0, 2);
1749
1766
  _tile_dpbssd(6, 1, 2);
1750
1767
  }
1751
1768
 
1752
- _tile_stored(4, c_upper_state.data, 64);
1753
- _tile_stored(6, c_lower_state.data, 64);
1769
+ _tile_stored(4, c_high_state.data, 64);
1770
+ _tile_stored(6, c_low_state.data, 64);
1754
1771
 
1755
- nk_dots_i8_store_sapphireamx_(&c_upper_state, c + row_block_start * c_stride_elements + col_start,
1756
- c_stride_elements, rows_in_upper_tile, 16);
1757
- if (rows_in_lower_tile > 0) {
1758
- nk_dots_i8_store_sapphireamx_(&c_lower_state,
1759
- c + (row_block_start + 16) * c_stride_elements + col_start,
1760
- c_stride_elements, rows_in_lower_tile, 16);
1772
+ nk_dots_i8_store_sapphireamx_(&c_high_state, c + row_block_start * c_stride_elements + col_start,
1773
+ c_stride_elements, rows_in_high_tile, 16);
1774
+ if (rows_in_low_tile > 0) {
1775
+ nk_dots_i8_store_sapphireamx_(&c_low_state, c + (row_block_start + 16) * c_stride_elements + col_start,
1776
+ c_stride_elements, rows_in_low_tile, 16);
1761
1777
  }
1762
1778
  }
1763
1779
 
1764
1780
  // Handle column-edge (remaining columns < 16) using AMX with partial tiles
1765
1781
  if (column_remainder_count > 0) {
1766
- nk_size_t const rows_in_upper_tile = (valid_rows_count > 16) ? 16 : valid_rows_count;
1767
- nk_size_t const rows_in_lower_tile = (valid_rows_count > 16) ? valid_rows_count - 16 : 0;
1782
+ nk_size_t const rows_in_high_tile = (valid_rows_count > 16) ? 16 : valid_rows_count;
1783
+ nk_size_t const rows_in_low_tile = (valid_rows_count > 16) ? valid_rows_count - 16 : 0;
1768
1784
 
1769
- nk_dots_i8_state_sapphireamx_t c_upper_state, c_lower_state;
1785
+ nk_dots_i8_state_sapphireamx_t c_high_state, c_low_state;
1770
1786
  nk_dots_i8_a16x64_sapphireamx_t b_as_a;
1771
1787
  nk_dots_i8_b64x16_sapphireamx_t b_tile;
1772
1788
 
@@ -1778,12 +1794,12 @@ NK_PUBLIC void nk_dots_packed_i8_sapphireamx( //
1778
1794
  nk_size_t const valid_depth = (depth_tile_idx < full_depth_tiles_count) ? tile_depth : depth_remainder;
1779
1795
 
1780
1796
  // Load A tiles
1781
- nk_dots_i8_load_a_sapphireamx_(&a_tile_upper, a + row_block_start * a_stride_bytes + depth_offset,
1782
- a_stride_bytes, rows_in_upper_tile, valid_depth);
1783
- if (rows_in_lower_tile > 0) {
1784
- nk_dots_i8_load_a_sapphireamx_(&a_tile_lower,
1797
+ nk_dots_i8_load_a_sapphireamx_(&a_tile_top, a + row_block_start * a_stride_bytes + depth_offset,
1798
+ a_stride_bytes, rows_in_high_tile, valid_depth);
1799
+ if (rows_in_low_tile > 0) {
1800
+ nk_dots_i8_load_a_sapphireamx_(&a_tile_bottom,
1785
1801
  a + (row_block_start + 16) * a_stride_bytes + depth_offset,
1786
- a_stride_bytes, rows_in_lower_tile, valid_depth);
1802
+ a_stride_bytes, rows_in_low_tile, valid_depth);
1787
1803
  }
1788
1804
 
1789
1805
  // Load B edge data (row-major: b_edge[row × depth + column]) and pack into B tile
@@ -1792,23 +1808,22 @@ NK_PUBLIC void nk_dots_packed_i8_sapphireamx( //
1792
1808
  valid_depth);
1793
1809
  nk_dots_pack_i8_transposed_sapphireamx_(&b_as_a, &b_tile);
1794
1810
 
1795
- _tile_loadd(0, a_tile_upper.data, 64);
1796
- _tile_loadd(1, a_tile_lower.data, 64);
1811
+ _tile_loadd(0, a_tile_top.data, 64);
1812
+ _tile_loadd(1, a_tile_bottom.data, 64);
1797
1813
  _tile_loadd(2, b_tile.data, 64);
1798
1814
 
1799
1815
  _tile_dpbssd(4, 0, 2);
1800
1816
  _tile_dpbssd(6, 1, 2);
1801
1817
  }
1802
1818
 
1803
- _tile_stored(4, c_upper_state.data, 64);
1804
- _tile_stored(6, c_lower_state.data, 64);
1819
+ _tile_stored(4, c_high_state.data, 64);
1820
+ _tile_stored(6, c_low_state.data, 64);
1805
1821
 
1806
- nk_dots_i8_store_sapphireamx_(&c_upper_state, c + row_block_start * c_stride_elements + full_cols,
1807
- c_stride_elements, rows_in_upper_tile, column_remainder_count);
1808
- if (rows_in_lower_tile > 0) {
1809
- nk_dots_i8_store_sapphireamx_(&c_lower_state,
1810
- c + (row_block_start + 16) * c_stride_elements + full_cols,
1811
- c_stride_elements, rows_in_lower_tile, column_remainder_count);
1822
+ nk_dots_i8_store_sapphireamx_(&c_high_state, c + row_block_start * c_stride_elements + full_cols,
1823
+ c_stride_elements, rows_in_high_tile, column_remainder_count);
1824
+ if (rows_in_low_tile > 0) {
1825
+ nk_dots_i8_store_sapphireamx_(&c_low_state, c + (row_block_start + 16) * c_stride_elements + full_cols,
1826
+ c_stride_elements, rows_in_low_tile, column_remainder_count);
1812
1827
  }
1813
1828
  }
1814
1829
  }
@@ -1817,10 +1832,10 @@ NK_PUBLIC void nk_dots_packed_i8_sapphireamx( //
1817
1832
  }
1818
1833
 
1819
1834
  NK_PUBLIC void nk_dots_compact_i8_sapphireamx( //
1820
- void *c, nk_size_t row_count, nk_size_t column_count, nk_size_t c_stride, nk_i32_t const *a_squared_norms,
1835
+ void *c, nk_size_t row_count, nk_size_t column_count, nk_size_t c_stride_in_bytes, nk_i32_t const *a_squared_norms,
1821
1836
  nk_i32_t const *b_squared_norms) {
1822
1837
 
1823
- nk_size_t const c_stride_i32 = c_stride / sizeof(nk_i32_t);
1838
+ nk_size_t const c_stride_i32 = c_stride_in_bytes / sizeof(nk_i32_t);
1824
1839
  nk_i32_t const *c_i32 = (nk_i32_t const *)c;
1825
1840
  nk_i8_t *c_i8 = (nk_i8_t *)c;
1826
1841
 
@@ -1828,41 +1843,45 @@ NK_PUBLIC void nk_dots_compact_i8_sapphireamx( //
1828
1843
  nk_f32_t *b_rsqrt = (nk_f32_t *)(c_i8 + row_count * column_count);
1829
1844
 
1830
1845
  // Precompute rsqrt of all b_norms using AVX512 (16 at a time)
1831
- __m512 half_vec = _mm512_set1_ps(0.5f);
1832
- __m512 three_halves_vec = _mm512_set1_ps(1.5f);
1846
+ __m512 half_vec_f32x16 = _mm512_set1_ps(0.5f);
1847
+ __m512 three_halves_vec_f32x16 = _mm512_set1_ps(1.5f);
1833
1848
  nk_size_t column_idx = 0;
1834
1849
 
1835
1850
  for (; column_idx + 16 <= column_count; column_idx += 16) {
1836
- __m512i b_norms_i32 = _mm512_loadu_si512(b_squared_norms + column_idx);
1837
- __m512 b_norms_f32 = _mm512_cvtepi32_ps(b_norms_i32);
1838
- __m512 rsqrt_vec = _mm512_rsqrt14_ps(b_norms_f32);
1851
+ __m512i b_norms_i32x16 = _mm512_loadu_si512(b_squared_norms + column_idx);
1852
+ __m512 b_norms_f32x16 = _mm512_cvtepi32_ps(b_norms_i32x16);
1853
+ __m512 rsqrt_vec_f32x16 = _mm512_rsqrt14_ps(b_norms_f32x16);
1839
1854
  // Newton-Raphson refinement
1840
- rsqrt_vec = _mm512_mul_ps(
1841
- rsqrt_vec,
1842
- _mm512_sub_ps(three_halves_vec,
1843
- _mm512_mul_ps(half_vec, _mm512_mul_ps(b_norms_f32, _mm512_mul_ps(rsqrt_vec, rsqrt_vec)))));
1855
+ rsqrt_vec_f32x16 = _mm512_mul_ps(
1856
+ rsqrt_vec_f32x16,
1857
+ _mm512_sub_ps(
1858
+ three_halves_vec_f32x16,
1859
+ _mm512_mul_ps(half_vec_f32x16,
1860
+ _mm512_mul_ps(b_norms_f32x16, _mm512_mul_ps(rsqrt_vec_f32x16, rsqrt_vec_f32x16)))));
1844
1861
  // Zero out rsqrt where norm was zero
1845
- __mmask16 nonzero_mask = _mm512_cmpneq_epi32_mask(b_norms_i32, _mm512_setzero_si512());
1846
- rsqrt_vec = _mm512_maskz_mov_ps(nonzero_mask, rsqrt_vec);
1847
- _mm512_storeu_ps(b_rsqrt + column_idx, rsqrt_vec);
1862
+ __mmask16 nonzero_mask = _mm512_cmpneq_epi32_mask(b_norms_i32x16, _mm512_setzero_si512());
1863
+ rsqrt_vec_f32x16 = _mm512_maskz_mov_ps(nonzero_mask, rsqrt_vec_f32x16);
1864
+ _mm512_storeu_ps(b_rsqrt + column_idx, rsqrt_vec_f32x16);
1848
1865
  }
1849
1866
 
1850
1867
  // Handle remaining b_norms with masked operations
1851
1868
  if (column_idx < column_count) {
1852
1869
  __mmask16 tail_mask = (__mmask16)((1u << (column_count - column_idx)) - 1);
1853
- __m512i b_norms_i32 = _mm512_maskz_loadu_epi32(tail_mask, b_squared_norms + column_idx);
1854
- __m512 b_norms_f32 = _mm512_cvtepi32_ps(b_norms_i32);
1855
- __m512 rsqrt_vec = _mm512_rsqrt14_ps(b_norms_f32);
1856
- rsqrt_vec = _mm512_mul_ps(
1857
- rsqrt_vec,
1858
- _mm512_sub_ps(three_halves_vec,
1859
- _mm512_mul_ps(half_vec, _mm512_mul_ps(b_norms_f32, _mm512_mul_ps(rsqrt_vec, rsqrt_vec)))));
1860
- __mmask16 nonzero_mask = _mm512_cmpneq_epi32_mask(b_norms_i32, _mm512_setzero_si512());
1861
- rsqrt_vec = _mm512_maskz_mov_ps(nonzero_mask & tail_mask, rsqrt_vec);
1862
- _mm512_mask_storeu_ps(b_rsqrt + column_idx, tail_mask, rsqrt_vec);
1870
+ __m512i b_norms_i32x16 = _mm512_maskz_loadu_epi32(tail_mask, b_squared_norms + column_idx);
1871
+ __m512 b_norms_f32x16 = _mm512_cvtepi32_ps(b_norms_i32x16);
1872
+ __m512 rsqrt_vec_f32x16 = _mm512_rsqrt14_ps(b_norms_f32x16);
1873
+ rsqrt_vec_f32x16 = _mm512_mul_ps(
1874
+ rsqrt_vec_f32x16,
1875
+ _mm512_sub_ps(
1876
+ three_halves_vec_f32x16,
1877
+ _mm512_mul_ps(half_vec_f32x16,
1878
+ _mm512_mul_ps(b_norms_f32x16, _mm512_mul_ps(rsqrt_vec_f32x16, rsqrt_vec_f32x16)))));
1879
+ __mmask16 nonzero_mask = _mm512_cmpneq_epi32_mask(b_norms_i32x16, _mm512_setzero_si512());
1880
+ rsqrt_vec_f32x16 = _mm512_maskz_mov_ps(nonzero_mask & tail_mask, rsqrt_vec_f32x16);
1881
+ _mm512_mask_storeu_ps(b_rsqrt + column_idx, tail_mask, rsqrt_vec_f32x16);
1863
1882
  }
1864
1883
 
1865
- __m512 scale_vec = _mm512_set1_ps(127.0f);
1884
+ __m512 scale_vec_f32x16 = _mm512_set1_ps(127.0f);
1866
1885
 
1867
1886
  for (nk_size_t row_idx = 0; row_idx < row_count; row_idx++) {
1868
1887
  nk_i32_t const *src_row = c_i32 + row_idx * c_stride_i32;
@@ -1872,55 +1891,57 @@ NK_PUBLIC void nk_dots_compact_i8_sapphireamx( //
1872
1891
  nk_f32_t a_norm_f32 = (nk_f32_t)a_squared_norms[row_idx];
1873
1892
  nk_f32_t a_rsqrt_val = 0.0f;
1874
1893
  if (a_norm_f32 > 0.0f) {
1875
- __m128 a_vec = _mm_set_ss(a_norm_f32);
1876
- __m128 rsqrt_s = _mm_rsqrt_ss(a_vec);
1877
- rsqrt_s = _mm_mul_ss(
1878
- rsqrt_s, _mm_sub_ss(_mm_set_ss(1.5f),
1879
- _mm_mul_ss(_mm_set_ss(0.5f), _mm_mul_ss(a_vec, _mm_mul_ss(rsqrt_s, rsqrt_s)))));
1880
- a_rsqrt_val = _mm_cvtss_f32(rsqrt_s);
1894
+ __m128 a_vec_f32x4 = _mm_set_ss(a_norm_f32);
1895
+ __m128 rsqrt_s_f32x4 = _mm_rsqrt_ss(a_vec_f32x4);
1896
+ rsqrt_s_f32x4 = _mm_mul_ss(
1897
+ rsqrt_s_f32x4,
1898
+ _mm_sub_ss(
1899
+ _mm_set_ss(1.5f),
1900
+ _mm_mul_ss(_mm_set_ss(0.5f), _mm_mul_ss(a_vec_f32x4, _mm_mul_ss(rsqrt_s_f32x4, rsqrt_s_f32x4)))));
1901
+ a_rsqrt_val = _mm_cvtss_f32(rsqrt_s_f32x4);
1881
1902
  }
1882
- __m512 a_rsqrt_vec = _mm512_set1_ps(a_rsqrt_val);
1883
- __m512 row_scale = _mm512_mul_ps(a_rsqrt_vec, scale_vec);
1903
+ __m512 a_rsqrt_vec_f32x16 = _mm512_set1_ps(a_rsqrt_val);
1904
+ __m512 row_scale_f32x16 = _mm512_mul_ps(a_rsqrt_vec_f32x16, scale_vec_f32x16);
1884
1905
 
1885
1906
  column_idx = 0;
1886
1907
 
1887
1908
  // Process 16 elements at a time
1888
1909
  for (; column_idx + 16 <= column_count; column_idx += 16) {
1889
- __m512i c_vals = _mm512_loadu_si512(src_row + column_idx);
1890
- __m512 c_f32 = _mm512_cvtepi32_ps(c_vals);
1891
- __m512 b_rsqrt_vec = _mm512_loadu_ps(b_rsqrt + column_idx);
1892
- __m512 normalized = _mm512_mul_ps(_mm512_mul_ps(c_f32, row_scale), b_rsqrt_vec);
1893
- __m512i result_i32 = _mm512_cvtps_epi32(normalized);
1910
+ __m512i c_vals_i32x16 = _mm512_loadu_si512(src_row + column_idx);
1911
+ __m512 c_f32_f32x16 = _mm512_cvtepi32_ps(c_vals_i32x16);
1912
+ __m512 b_rsqrt_vec_f32x16 = _mm512_loadu_ps(b_rsqrt + column_idx);
1913
+ __m512 normalized_f32x16 = _mm512_mul_ps(_mm512_mul_ps(c_f32_f32x16, row_scale_f32x16), b_rsqrt_vec_f32x16);
1914
+ __m512i result_i32x16 = _mm512_cvtps_epi32(normalized_f32x16);
1894
1915
  // Saturating pack I32 → I8 (16 values → 16 bytes in low 128 bits)
1895
- __m128i result_i8 = _mm512_cvtsepi32_epi8(result_i32);
1896
- _mm_storeu_si128((__m128i *)(dst_row + column_idx), result_i8);
1916
+ __m128i result_i8x16 = _mm512_cvtsepi32_epi8(result_i32x16);
1917
+ _mm_storeu_si128((__m128i *)(dst_row + column_idx), result_i8x16);
1897
1918
  }
1898
1919
 
1899
1920
  // Handle remaining elements with masked operations
1900
1921
  if (column_idx < column_count) {
1901
1922
  __mmask16 tail_mask = (__mmask16)((1u << (column_count - column_idx)) - 1);
1902
- __m512i c_vals = _mm512_maskz_loadu_epi32(tail_mask, src_row + column_idx);
1903
- __m512 c_f32 = _mm512_cvtepi32_ps(c_vals);
1904
- __m512 b_rsqrt_vec = _mm512_maskz_loadu_ps(tail_mask, b_rsqrt + column_idx);
1905
- __m512 normalized = _mm512_mul_ps(_mm512_mul_ps(c_f32, row_scale), b_rsqrt_vec);
1906
- __m512i result_i32 = _mm512_cvtps_epi32(normalized);
1907
- __m128i result_i8 = _mm512_cvtsepi32_epi8(result_i32);
1908
- _mm_mask_storeu_epi8(dst_row + column_idx, tail_mask, result_i8);
1923
+ __m512i c_vals_i32x16 = _mm512_maskz_loadu_epi32(tail_mask, src_row + column_idx);
1924
+ __m512 c_f32_f32x16 = _mm512_cvtepi32_ps(c_vals_i32x16);
1925
+ __m512 b_rsqrt_vec_f32x16 = _mm512_maskz_loadu_ps(tail_mask, b_rsqrt + column_idx);
1926
+ __m512 normalized_f32x16 = _mm512_mul_ps(_mm512_mul_ps(c_f32_f32x16, row_scale_f32x16), b_rsqrt_vec_f32x16);
1927
+ __m512i result_i32x16 = _mm512_cvtps_epi32(normalized_f32x16);
1928
+ __m128i result_i8x16 = _mm512_cvtsepi32_epi8(result_i32x16);
1929
+ _mm_mask_storeu_epi8(dst_row + column_idx, tail_mask, result_i8x16);
1909
1930
  }
1910
1931
  }
1911
1932
  }
1912
1933
 
1913
- NK_PUBLIC void nk_dots_symmetric_i8_sapphireamx( //
1914
- nk_i8_t const *vectors, nk_size_t n_vectors, nk_size_t depth, //
1915
- nk_size_t stride, nk_i32_t *result, nk_size_t result_stride, //
1934
+ NK_PUBLIC void nk_dots_symmetric_i8_sapphireamx( //
1935
+ nk_i8_t const *vectors, nk_size_t vectors_count, nk_size_t depth, //
1936
+ nk_size_t stride_in_bytes, nk_i32_t *result, nk_size_t result_stride_in_bytes, //
1916
1937
  nk_size_t row_start, nk_size_t row_count) {
1917
1938
 
1918
- nk_size_t const result_stride_elements = result_stride / sizeof(nk_i32_t);
1939
+ nk_size_t const result_stride_elements = result_stride_in_bytes / sizeof(nk_i32_t);
1919
1940
 
1920
1941
  // Handle row slicing: compute rows [row_start, row_end)
1921
1942
  nk_size_t const row_end = (row_count == 0)
1922
- ? n_vectors
1923
- : (row_start + row_count < n_vectors ? row_start + row_count : n_vectors);
1943
+ ? vectors_count
1944
+ : (row_start + row_count < vectors_count ? row_start + row_count : vectors_count);
1924
1945
 
1925
1946
  // Round depth up to multiple of 192 (3 tiles × 64 elements)
1926
1947
  nk_size_t const depth_tiles = nk_size_divide_round_up_(depth, 64);
@@ -1936,8 +1957,8 @@ NK_PUBLIC void nk_dots_symmetric_i8_sapphireamx( //
1936
1957
  for (nk_size_t row_tile = row_start; row_tile < row_end; row_tile += 16) {
1937
1958
  nk_size_t const valid_rows = (row_tile + 16 <= row_end) ? 16 : (row_end - row_tile);
1938
1959
 
1939
- for (nk_size_t col_tile = 0; col_tile < n_vectors; col_tile += 16) {
1940
- nk_size_t const valid_cols = (col_tile + 16 <= n_vectors) ? 16 : (n_vectors - col_tile);
1960
+ for (nk_size_t col_tile = 0; col_tile < vectors_count; col_tile += 16) {
1961
+ nk_size_t const valid_cols = (col_tile + 16 <= vectors_count) ? 16 : (vectors_count - col_tile);
1941
1962
 
1942
1963
  nk_dots_i8_init_sapphireamx_(&state);
1943
1964
 
@@ -1950,19 +1971,19 @@ NK_PUBLIC void nk_dots_symmetric_i8_sapphireamx( //
1950
1971
  ? 64
1951
1972
  : (depth > depth_start ? depth - depth_start : 0);
1952
1973
 
1953
- nk_dots_i8_load_a_sapphireamx_( //
1954
- &a_tiles[tile_idx], //
1955
- vectors + row_tile * stride + depth_start, //
1956
- stride, valid_rows, valid_depth);
1974
+ nk_dots_i8_load_a_sapphireamx_( //
1975
+ &a_tiles[tile_idx], //
1976
+ vectors + row_tile * stride_in_bytes + depth_start, //
1977
+ stride_in_bytes, valid_rows, valid_depth);
1957
1978
 
1958
1979
  if (row_tile == col_tile) {
1959
1980
  nk_dots_pack_i8_transposed_sapphireamx_(&a_tiles[tile_idx], &b_tiles[tile_idx]);
1960
1981
  }
1961
1982
  else {
1962
- nk_dots_i8_load_a_sapphireamx_( //
1963
- &b_src_tiles[tile_idx], //
1964
- vectors + col_tile * stride + depth_start, //
1965
- stride, valid_cols, valid_depth);
1983
+ nk_dots_i8_load_a_sapphireamx_( //
1984
+ &b_src_tiles[tile_idx], //
1985
+ vectors + col_tile * stride_in_bytes + depth_start, //
1986
+ stride_in_bytes, valid_cols, valid_depth);
1966
1987
  nk_dots_pack_i8_transposed_sapphireamx_(&b_src_tiles[tile_idx], &b_tiles[tile_idx]);
1967
1988
  }
1968
1989
  }
@@ -1978,7 +1999,7 @@ NK_PUBLIC void nk_dots_symmetric_i8_sapphireamx( //
1978
1999
  }
1979
2000
  }
1980
2001
 
1981
- #pragma endregion // Signed Integers
2002
+ #pragma endregion Signed Integers
1982
2003
 
1983
2004
  #pragma region Unsigned Integers
1984
2005
 
@@ -1989,7 +2010,7 @@ NK_PUBLIC nk_size_t nk_dots_packed_size_u8_sapphireamx(nk_size_t column_count, n
1989
2010
 
1990
2011
  NK_PUBLIC void nk_dots_pack_u8_sapphireamx( //
1991
2012
  nk_u8_t const *b, nk_size_t column_count, nk_size_t depth, //
1992
- nk_size_t b_stride, void *b_packed) {
2013
+ nk_size_t b_stride_in_bytes, void *b_packed) {
1993
2014
 
1994
2015
  nk_size_t const tmm_rows = 16;
1995
2016
  nk_size_t const tmm_cols = 64;
@@ -2013,8 +2034,9 @@ NK_PUBLIC void nk_dots_pack_u8_sapphireamx( //
2013
2034
  nk_u8_t *tiles_ptr = (nk_u8_t *)((char *)b_packed + tiles_offset);
2014
2035
  nk_u8_t *column_edge_ptr = (nk_u8_t *)((char *)b_packed + column_edge_offset);
2015
2036
 
2016
- for (nk_size_t idx = 0; idx < total_tiles * tile_elements; idx++) tiles_ptr[idx] = 0;
2017
-
2037
+ // Pack tiles using vectorized transposer: gather 16 strided rows into an aligned
2038
+ // temporary, transpose via SIMD, then copy the result to the packed buffer.
2039
+ // Stack-local aligned tiles are needed because the packed buffer may not be 64-byte aligned.
2018
2040
  for (nk_size_t column_tile_idx = 0; column_tile_idx < column_tiles_count; column_tile_idx++) {
2019
2041
  for (nk_size_t depth_tile_idx = 0; depth_tile_idx < depth_tiles_count; depth_tile_idx++) {
2020
2042
 
@@ -2026,14 +2048,31 @@ NK_PUBLIC void nk_dots_pack_u8_sapphireamx( //
2026
2048
  nk_size_t const columns_to_pack = (src_column_start + tmm_cols <= depth) ? tmm_cols
2027
2049
  : (depth - src_column_start);
2028
2050
 
2029
- // Pack with quad-interleaving as required by TDPBUUD instruction.
2030
- for (nk_size_t row_idx = 0; row_idx < tmm_rows; row_idx++) {
2031
- for (nk_size_t column_idx = 0; column_idx < columns_to_pack; column_idx++) {
2032
- nk_size_t const src_idx = (src_row_start + row_idx) * b_stride + src_column_start + column_idx;
2033
- nk_size_t const dst_idx = (column_idx / 4) * 64 + row_idx * 4 + (column_idx % 4);
2034
- tile_output[dst_idx] = b[src_idx];
2051
+ // Gather 16 strided source rows into a contiguous aligned tile
2052
+ nk_dots_u8_a16x64_sapphireamx_t source_tile;
2053
+ if (columns_to_pack == tmm_cols) {
2054
+ for (nk_size_t row_idx = 0; row_idx < tmm_rows; row_idx++) {
2055
+ nk_u8_t const *source_row = (nk_u8_t const *)((char const *)b +
2056
+ (src_row_start + row_idx) * b_stride_in_bytes) +
2057
+ src_column_start;
2058
+ _mm512_store_si512(&source_tile.data[row_idx][0], _mm512_loadu_si512(source_row));
2059
+ }
2060
+ }
2061
+ else {
2062
+ __mmask64 depth_mask = (__mmask64)((columns_to_pack < 64) ? ((1ULL << columns_to_pack) - 1) : ~0ULL);
2063
+ for (nk_size_t row_idx = 0; row_idx < tmm_rows; row_idx++) {
2064
+ nk_u8_t const *source_row = (nk_u8_t const *)((char const *)b +
2065
+ (src_row_start + row_idx) * b_stride_in_bytes) +
2066
+ src_column_start;
2067
+ _mm512_store_si512(&source_tile.data[row_idx][0], _mm512_maskz_loadu_epi8(depth_mask, source_row));
2035
2068
  }
2036
2069
  }
2070
+
2071
+ // Transpose into aligned local, then copy to (potentially unaligned) packed buffer
2072
+ nk_dots_u8_b64x16_sapphireamx_t transposed_tile;
2073
+ nk_dots_pack_u8_transposed_sapphireamx_(&source_tile, &transposed_tile);
2074
+ for (nk_size_t i = 0; i < tile_elements; i += 64)
2075
+ _mm512_storeu_si512(tile_output + i, _mm512_load_si512((char const *)&transposed_tile + i));
2037
2076
  }
2038
2077
  }
2039
2078
 
@@ -2042,7 +2081,7 @@ NK_PUBLIC void nk_dots_pack_u8_sapphireamx( //
2042
2081
  for (nk_size_t row_idx = 0; row_idx < column_remainder_count; row_idx++) {
2043
2082
  for (nk_size_t column_idx = 0; column_idx < depth; column_idx++) {
2044
2083
  column_edge_ptr[row_idx * depth + column_idx] =
2045
- b[(remainder_start_row + row_idx) * b_stride + column_idx];
2084
+ b[(remainder_start_row + row_idx) * b_stride_in_bytes + column_idx];
2046
2085
  }
2047
2086
  }
2048
2087
  }
@@ -2052,7 +2091,8 @@ NK_PUBLIC void nk_dots_pack_u8_sapphireamx( //
2052
2091
  (column_remainder_count > 0 ? column_remainder_count * depth * sizeof(nk_u8_t) : 0);
2053
2092
  header->norms_byte_offset = (nk_u32_t)norms_offset;
2054
2093
  nk_u32_t *norms = (nk_u32_t *)((char *)b_packed + norms_offset);
2055
- for (nk_size_t col = 0; col < column_count; col++) norms[col] = nk_dots_reduce_sumsq_u8_(b + col * b_stride, depth);
2094
+ for (nk_size_t col = 0; col < column_count; col++)
2095
+ norms[col] = nk_dots_reduce_sumsq_u8_(b + col * b_stride_in_bytes, depth);
2056
2096
  }
2057
2097
 
2058
2098
  NK_PUBLIC void nk_dots_packed_u8_sapphireamx( //
@@ -2085,7 +2125,7 @@ NK_PUBLIC void nk_dots_packed_u8_sapphireamx( //
2085
2125
  if (depth_tiles_count == 0) return;
2086
2126
 
2087
2127
  // Tile buffers for A (only used for edge tiles)
2088
- nk_dots_u8_a16x64_sapphireamx_t a_tile_upper, a_tile_lower;
2128
+ nk_dots_u8_a16x64_sapphireamx_t a_tile_top, a_tile_bottom;
2089
2129
  nk_dots_u8_state2x2_sapphireamx_t c_accum_buffer;
2090
2130
 
2091
2131
  // Precompute: number of full depth-tiles
@@ -2116,8 +2156,8 @@ NK_PUBLIC void nk_dots_packed_u8_sapphireamx( //
2116
2156
 
2117
2157
  // Fast path: full row-block with full depth-tiles → direct A load with 2-deep pipelining
2118
2158
  if (is_full_row_block && full_depth_tiles_count > 0) {
2119
- nk_u8_t const *a_upper_base = a + row_block_start * a_stride_bytes;
2120
- nk_u8_t const *a_lower_base = a + (row_block_start + 16) * a_stride_bytes;
2159
+ nk_u8_t const *a_top_base = a + row_block_start * a_stride_bytes;
2160
+ nk_u8_t const *a_bottom_base = a + (row_block_start + 16) * a_stride_bytes;
2121
2161
 
2122
2162
  nk_dots_u8_b64x16_sapphireamx_t const *b_tile_left =
2123
2163
  (nk_dots_u8_b64x16_sapphireamx_t const *)(b_tiles_base + b_column_left_base * tile_size);
@@ -2125,8 +2165,8 @@ NK_PUBLIC void nk_dots_packed_u8_sapphireamx( //
2125
2165
  (nk_dots_u8_b64x16_sapphireamx_t const *)(b_tiles_base + b_column_right_base * tile_size);
2126
2166
 
2127
2167
  // Prologue: load first depth tile into TMM0-3
2128
- _tile_loadd(0, a_upper_base, a_stride_bytes);
2129
- _tile_loadd(1, a_lower_base, a_stride_bytes);
2168
+ _tile_loadd(0, a_top_base, a_stride_bytes);
2169
+ _tile_loadd(1, a_bottom_base, a_stride_bytes);
2130
2170
  _tile_loadd(2, b_tile_left->data, 64);
2131
2171
  _tile_loadd(3, b_tile_right->data, 64);
2132
2172
 
@@ -2139,8 +2179,8 @@ NK_PUBLIC void nk_dots_packed_u8_sapphireamx( //
2139
2179
  _tile_dpbuud(6, 1, 2);
2140
2180
  _tile_dpbuud(7, 1, 3);
2141
2181
 
2142
- _tile_loadd(0, a_upper_base + next_depth_offset, a_stride_bytes);
2143
- _tile_loadd(1, a_lower_base + next_depth_offset, a_stride_bytes);
2182
+ _tile_loadd(0, a_top_base + next_depth_offset, a_stride_bytes);
2183
+ _tile_loadd(1, a_bottom_base + next_depth_offset, a_stride_bytes);
2144
2184
  b_tile_left = (nk_dots_u8_b64x16_sapphireamx_t const *)(b_tiles_base +
2145
2185
  (b_column_left_base + depth_tile_idx + 1) *
2146
2186
  tile_size);
@@ -2161,9 +2201,9 @@ NK_PUBLIC void nk_dots_packed_u8_sapphireamx( //
2161
2201
  if (depth_remainder > 0) {
2162
2202
  nk_size_t const depth_offset = full_depth_tiles_count * tile_depth;
2163
2203
 
2164
- nk_dots_u8_load_a_sapphireamx_(&a_tile_upper, a_upper_base + depth_offset, a_stride_bytes, 16,
2204
+ nk_dots_u8_load_a_sapphireamx_(&a_tile_top, a_top_base + depth_offset, a_stride_bytes, 16,
2165
2205
  depth_remainder);
2166
- nk_dots_u8_load_a_sapphireamx_(&a_tile_lower, a_lower_base + depth_offset, a_stride_bytes, 16,
2206
+ nk_dots_u8_load_a_sapphireamx_(&a_tile_bottom, a_bottom_base + depth_offset, a_stride_bytes, 16,
2167
2207
  depth_remainder);
2168
2208
 
2169
2209
  b_tile_left = (nk_dots_u8_b64x16_sapphireamx_t const *)(b_tiles_base + (b_column_left_base +
@@ -2173,8 +2213,8 @@ NK_PUBLIC void nk_dots_packed_u8_sapphireamx( //
2173
2213
  full_depth_tiles_count) *
2174
2214
  tile_size);
2175
2215
 
2176
- _tile_loadd(0, a_tile_upper.data, 64);
2177
- _tile_loadd(1, a_tile_lower.data, 64);
2216
+ _tile_loadd(0, a_tile_top.data, 64);
2217
+ _tile_loadd(1, a_tile_bottom.data, 64);
2178
2218
  _tile_loadd(2, b_tile_left->data, 64);
2179
2219
  _tile_loadd(3, b_tile_right->data, 64);
2180
2220
 
@@ -2186,19 +2226,19 @@ NK_PUBLIC void nk_dots_packed_u8_sapphireamx( //
2186
2226
  }
2187
2227
  // Full row-block but only partial depth tile (depth < tile_depth)
2188
2228
  else if (is_full_row_block) {
2189
- nk_u8_t const *a_upper_base = a + row_block_start * a_stride_bytes;
2190
- nk_u8_t const *a_lower_base = a + (row_block_start + 16) * a_stride_bytes;
2229
+ nk_u8_t const *a_top_base = a + row_block_start * a_stride_bytes;
2230
+ nk_u8_t const *a_bottom_base = a + (row_block_start + 16) * a_stride_bytes;
2191
2231
 
2192
- nk_dots_u8_load_a_sapphireamx_(&a_tile_upper, a_upper_base, a_stride_bytes, 16, depth_remainder);
2193
- nk_dots_u8_load_a_sapphireamx_(&a_tile_lower, a_lower_base, a_stride_bytes, 16, depth_remainder);
2232
+ nk_dots_u8_load_a_sapphireamx_(&a_tile_top, a_top_base, a_stride_bytes, 16, depth_remainder);
2233
+ nk_dots_u8_load_a_sapphireamx_(&a_tile_bottom, a_bottom_base, a_stride_bytes, 16, depth_remainder);
2194
2234
 
2195
2235
  nk_dots_u8_b64x16_sapphireamx_t const *b_tile_left =
2196
2236
  (nk_dots_u8_b64x16_sapphireamx_t const *)(b_tiles_base + b_column_left_base * tile_size);
2197
2237
  nk_dots_u8_b64x16_sapphireamx_t const *b_tile_right =
2198
2238
  (nk_dots_u8_b64x16_sapphireamx_t const *)(b_tiles_base + b_column_right_base * tile_size);
2199
2239
 
2200
- _tile_loadd(0, a_tile_upper.data, 64);
2201
- _tile_loadd(1, a_tile_lower.data, 64);
2240
+ _tile_loadd(0, a_tile_top.data, 64);
2241
+ _tile_loadd(1, a_tile_bottom.data, 64);
2202
2242
  _tile_loadd(2, b_tile_left->data, 64);
2203
2243
  _tile_loadd(3, b_tile_right->data, 64);
2204
2244
 
@@ -2209,20 +2249,20 @@ NK_PUBLIC void nk_dots_packed_u8_sapphireamx( //
2209
2249
  }
2210
2250
  // Slow path: edge row-block → always use buffered load
2211
2251
  else {
2212
- nk_size_t const rows_in_upper_tile = (valid_rows_count > 16) ? 16 : valid_rows_count;
2213
- nk_size_t const rows_in_lower_tile = (valid_rows_count > 16) ? valid_rows_count - 16 : 0;
2252
+ nk_size_t const rows_in_high_tile = (valid_rows_count > 16) ? 16 : valid_rows_count;
2253
+ nk_size_t const rows_in_low_tile = (valid_rows_count > 16) ? valid_rows_count - 16 : 0;
2214
2254
 
2215
2255
  for (nk_size_t depth_tile_idx = 0; depth_tile_idx < depth_tiles_count; depth_tile_idx++) {
2216
2256
  nk_size_t const depth_offset = depth_tile_idx * tile_depth;
2217
2257
  nk_size_t const valid_depth = (depth_tile_idx < full_depth_tiles_count) ? tile_depth
2218
2258
  : depth_remainder;
2219
2259
 
2220
- nk_dots_u8_load_a_sapphireamx_(&a_tile_upper, a + row_block_start * a_stride_bytes + depth_offset,
2221
- a_stride_bytes, rows_in_upper_tile, valid_depth);
2222
- if (rows_in_lower_tile > 0) {
2223
- nk_dots_u8_load_a_sapphireamx_(&a_tile_lower,
2260
+ nk_dots_u8_load_a_sapphireamx_(&a_tile_top, a + row_block_start * a_stride_bytes + depth_offset,
2261
+ a_stride_bytes, rows_in_high_tile, valid_depth);
2262
+ if (rows_in_low_tile > 0) {
2263
+ nk_dots_u8_load_a_sapphireamx_(&a_tile_bottom,
2224
2264
  a + (row_block_start + 16) * a_stride_bytes + depth_offset,
2225
- a_stride_bytes, rows_in_lower_tile, valid_depth);
2265
+ a_stride_bytes, rows_in_low_tile, valid_depth);
2226
2266
  }
2227
2267
 
2228
2268
  nk_dots_u8_b64x16_sapphireamx_t const *b_tile_left =
@@ -2232,8 +2272,8 @@ NK_PUBLIC void nk_dots_packed_u8_sapphireamx( //
2232
2272
  (nk_dots_u8_b64x16_sapphireamx_t const *)(b_tiles_base +
2233
2273
  (b_column_right_base + depth_tile_idx) * tile_size);
2234
2274
 
2235
- _tile_loadd(0, a_tile_upper.data, 64);
2236
- _tile_loadd(1, a_tile_lower.data, 64);
2275
+ _tile_loadd(0, a_tile_top.data, 64);
2276
+ _tile_loadd(1, a_tile_bottom.data, 64);
2237
2277
  _tile_loadd(2, b_tile_left->data, 64);
2238
2278
  _tile_loadd(3, b_tile_right->data, 64);
2239
2279
 
@@ -2268,10 +2308,10 @@ NK_PUBLIC void nk_dots_packed_u8_sapphireamx( //
2268
2308
  nk_size_t const column_tile_idx = column_tiles_count - 1;
2269
2309
  nk_size_t const col_start = column_tile_idx * 16;
2270
2310
  nk_size_t const b_column_base = column_tile_idx * depth_tiles_count;
2271
- nk_size_t const rows_in_upper_tile = (valid_rows_count > 16) ? 16 : valid_rows_count;
2272
- nk_size_t const rows_in_lower_tile = (valid_rows_count > 16) ? valid_rows_count - 16 : 0;
2311
+ nk_size_t const rows_in_high_tile = (valid_rows_count > 16) ? 16 : valid_rows_count;
2312
+ nk_size_t const rows_in_low_tile = (valid_rows_count > 16) ? valid_rows_count - 16 : 0;
2273
2313
 
2274
- nk_dots_u8_state_sapphireamx_t c_upper_state, c_lower_state;
2314
+ nk_dots_u8_state_sapphireamx_t c_high_state, c_low_state;
2275
2315
 
2276
2316
  _tile_zero(4);
2277
2317
  _tile_zero(6);
@@ -2280,44 +2320,43 @@ NK_PUBLIC void nk_dots_packed_u8_sapphireamx( //
2280
2320
  nk_size_t const depth_offset = depth_tile_idx * tile_depth;
2281
2321
  nk_size_t const valid_depth = (depth_tile_idx < full_depth_tiles_count) ? tile_depth : depth_remainder;
2282
2322
 
2283
- nk_dots_u8_load_a_sapphireamx_(&a_tile_upper, a + row_block_start * a_stride_bytes + depth_offset,
2284
- a_stride_bytes, rows_in_upper_tile, valid_depth);
2285
- if (rows_in_lower_tile > 0) {
2286
- nk_dots_u8_load_a_sapphireamx_(&a_tile_lower,
2323
+ nk_dots_u8_load_a_sapphireamx_(&a_tile_top, a + row_block_start * a_stride_bytes + depth_offset,
2324
+ a_stride_bytes, rows_in_high_tile, valid_depth);
2325
+ if (rows_in_low_tile > 0) {
2326
+ nk_dots_u8_load_a_sapphireamx_(&a_tile_bottom,
2287
2327
  a + (row_block_start + 16) * a_stride_bytes + depth_offset,
2288
- a_stride_bytes, rows_in_lower_tile, valid_depth);
2328
+ a_stride_bytes, rows_in_low_tile, valid_depth);
2289
2329
  }
2290
2330
 
2291
2331
  nk_dots_u8_b64x16_sapphireamx_t const *b_tile =
2292
2332
  (nk_dots_u8_b64x16_sapphireamx_t const *)(b_tiles_base +
2293
2333
  (b_column_base + depth_tile_idx) * tile_size);
2294
2334
 
2295
- _tile_loadd(0, a_tile_upper.data, 64);
2296
- _tile_loadd(1, a_tile_lower.data, 64);
2335
+ _tile_loadd(0, a_tile_top.data, 64);
2336
+ _tile_loadd(1, a_tile_bottom.data, 64);
2297
2337
  _tile_loadd(2, b_tile->data, 64);
2298
2338
 
2299
2339
  _tile_dpbuud(4, 0, 2);
2300
2340
  _tile_dpbuud(6, 1, 2);
2301
2341
  }
2302
2342
 
2303
- _tile_stored(4, c_upper_state.data, 64);
2304
- _tile_stored(6, c_lower_state.data, 64);
2343
+ _tile_stored(4, c_high_state.data, 64);
2344
+ _tile_stored(6, c_low_state.data, 64);
2305
2345
 
2306
- nk_dots_u8_store_sapphireamx_(&c_upper_state, c + row_block_start * c_stride_elements + col_start,
2307
- c_stride_elements, rows_in_upper_tile, 16);
2308
- if (rows_in_lower_tile > 0) {
2309
- nk_dots_u8_store_sapphireamx_(&c_lower_state,
2310
- c + (row_block_start + 16) * c_stride_elements + col_start,
2311
- c_stride_elements, rows_in_lower_tile, 16);
2346
+ nk_dots_u8_store_sapphireamx_(&c_high_state, c + row_block_start * c_stride_elements + col_start,
2347
+ c_stride_elements, rows_in_high_tile, 16);
2348
+ if (rows_in_low_tile > 0) {
2349
+ nk_dots_u8_store_sapphireamx_(&c_low_state, c + (row_block_start + 16) * c_stride_elements + col_start,
2350
+ c_stride_elements, rows_in_low_tile, 16);
2312
2351
  }
2313
2352
  }
2314
2353
 
2315
2354
  // Handle column-edge (remaining columns < 16) using AMX with partial tiles
2316
2355
  if (column_remainder_count > 0) {
2317
- nk_size_t const rows_in_upper_tile = (valid_rows_count > 16) ? 16 : valid_rows_count;
2318
- nk_size_t const rows_in_lower_tile = (valid_rows_count > 16) ? valid_rows_count - 16 : 0;
2356
+ nk_size_t const rows_in_high_tile = (valid_rows_count > 16) ? 16 : valid_rows_count;
2357
+ nk_size_t const rows_in_low_tile = (valid_rows_count > 16) ? valid_rows_count - 16 : 0;
2319
2358
 
2320
- nk_dots_u8_state_sapphireamx_t c_upper_state, c_lower_state;
2359
+ nk_dots_u8_state_sapphireamx_t c_high_state, c_low_state;
2321
2360
  nk_dots_u8_a16x64_sapphireamx_t b_as_a;
2322
2361
  nk_dots_u8_b64x16_sapphireamx_t b_tile;
2323
2362
 
@@ -2328,35 +2367,34 @@ NK_PUBLIC void nk_dots_packed_u8_sapphireamx( //
2328
2367
  nk_size_t const depth_offset = depth_tile_idx * tile_depth;
2329
2368
  nk_size_t const valid_depth = (depth_tile_idx < full_depth_tiles_count) ? tile_depth : depth_remainder;
2330
2369
 
2331
- nk_dots_u8_load_a_sapphireamx_(&a_tile_upper, a + row_block_start * a_stride_bytes + depth_offset,
2332
- a_stride_bytes, rows_in_upper_tile, valid_depth);
2333
- if (rows_in_lower_tile > 0) {
2334
- nk_dots_u8_load_a_sapphireamx_(&a_tile_lower,
2370
+ nk_dots_u8_load_a_sapphireamx_(&a_tile_top, a + row_block_start * a_stride_bytes + depth_offset,
2371
+ a_stride_bytes, rows_in_high_tile, valid_depth);
2372
+ if (rows_in_low_tile > 0) {
2373
+ nk_dots_u8_load_a_sapphireamx_(&a_tile_bottom,
2335
2374
  a + (row_block_start + 16) * a_stride_bytes + depth_offset,
2336
- a_stride_bytes, rows_in_lower_tile, valid_depth);
2375
+ a_stride_bytes, rows_in_low_tile, valid_depth);
2337
2376
  }
2338
2377
 
2339
2378
  nk_dots_u8_load_a_sapphireamx_(&b_as_a, col_edge_ptr + depth_offset, depth, column_remainder_count,
2340
2379
  valid_depth);
2341
2380
  nk_dots_pack_u8_transposed_sapphireamx_(&b_as_a, &b_tile);
2342
2381
 
2343
- _tile_loadd(0, a_tile_upper.data, 64);
2344
- _tile_loadd(1, a_tile_lower.data, 64);
2382
+ _tile_loadd(0, a_tile_top.data, 64);
2383
+ _tile_loadd(1, a_tile_bottom.data, 64);
2345
2384
  _tile_loadd(2, b_tile.data, 64);
2346
2385
 
2347
2386
  _tile_dpbuud(4, 0, 2);
2348
2387
  _tile_dpbuud(6, 1, 2);
2349
2388
  }
2350
2389
 
2351
- _tile_stored(4, c_upper_state.data, 64);
2352
- _tile_stored(6, c_lower_state.data, 64);
2390
+ _tile_stored(4, c_high_state.data, 64);
2391
+ _tile_stored(6, c_low_state.data, 64);
2353
2392
 
2354
- nk_dots_u8_store_sapphireamx_(&c_upper_state, c + row_block_start * c_stride_elements + full_cols,
2355
- c_stride_elements, rows_in_upper_tile, column_remainder_count);
2356
- if (rows_in_lower_tile > 0) {
2357
- nk_dots_u8_store_sapphireamx_(&c_lower_state,
2358
- c + (row_block_start + 16) * c_stride_elements + full_cols,
2359
- c_stride_elements, rows_in_lower_tile, column_remainder_count);
2393
+ nk_dots_u8_store_sapphireamx_(&c_high_state, c + row_block_start * c_stride_elements + full_cols,
2394
+ c_stride_elements, rows_in_high_tile, column_remainder_count);
2395
+ if (rows_in_low_tile > 0) {
2396
+ nk_dots_u8_store_sapphireamx_(&c_low_state, c + (row_block_start + 16) * c_stride_elements + full_cols,
2397
+ c_stride_elements, rows_in_low_tile, column_remainder_count);
2360
2398
  }
2361
2399
  }
2362
2400
  }
@@ -2364,17 +2402,17 @@ NK_PUBLIC void nk_dots_packed_u8_sapphireamx( //
2364
2402
  _tile_release();
2365
2403
  }
2366
2404
 
2367
- NK_PUBLIC void nk_dots_symmetric_u8_sapphireamx( //
2368
- nk_u8_t const *vectors, nk_size_t n_vectors, nk_size_t depth, //
2369
- nk_size_t stride, nk_u32_t *result, nk_size_t result_stride, //
2405
+ NK_PUBLIC void nk_dots_symmetric_u8_sapphireamx( //
2406
+ nk_u8_t const *vectors, nk_size_t vectors_count, nk_size_t depth, //
2407
+ nk_size_t stride_in_bytes, nk_u32_t *result, nk_size_t result_stride_in_bytes, //
2370
2408
  nk_size_t row_start, nk_size_t row_count) {
2371
2409
 
2372
- nk_size_t const result_stride_elements = result_stride / sizeof(nk_u32_t);
2410
+ nk_size_t const result_stride_elements = result_stride_in_bytes / sizeof(nk_u32_t);
2373
2411
 
2374
2412
  // Handle row slicing: compute rows [row_start, row_end)
2375
2413
  nk_size_t const row_end = (row_count == 0)
2376
- ? n_vectors
2377
- : (row_start + row_count < n_vectors ? row_start + row_count : n_vectors);
2414
+ ? vectors_count
2415
+ : (row_start + row_count < vectors_count ? row_start + row_count : vectors_count);
2378
2416
 
2379
2417
  // Round depth up to multiple of 192 (3 tiles × 64 elements)
2380
2418
  nk_size_t const depth_tiles = nk_size_divide_round_up_(depth, 64);
@@ -2390,8 +2428,8 @@ NK_PUBLIC void nk_dots_symmetric_u8_sapphireamx( //
2390
2428
  for (nk_size_t row_tile = row_start; row_tile < row_end; row_tile += 16) {
2391
2429
  nk_size_t const valid_rows = (row_tile + 16 <= row_end) ? 16 : (row_end - row_tile);
2392
2430
 
2393
- for (nk_size_t col_tile = 0; col_tile < n_vectors; col_tile += 16) {
2394
- nk_size_t const valid_cols = (col_tile + 16 <= n_vectors) ? 16 : (n_vectors - col_tile);
2431
+ for (nk_size_t col_tile = 0; col_tile < vectors_count; col_tile += 16) {
2432
+ nk_size_t const valid_cols = (col_tile + 16 <= vectors_count) ? 16 : (vectors_count - col_tile);
2395
2433
 
2396
2434
  nk_dots_u8_init_sapphireamx_(&state);
2397
2435
 
@@ -2404,19 +2442,19 @@ NK_PUBLIC void nk_dots_symmetric_u8_sapphireamx( //
2404
2442
  ? 64
2405
2443
  : (depth > depth_start ? depth - depth_start : 0);
2406
2444
 
2407
- nk_dots_u8_load_a_sapphireamx_( //
2408
- &a_tiles[tile_idx], //
2409
- vectors + row_tile * stride + depth_start, //
2410
- stride, valid_rows, valid_depth);
2445
+ nk_dots_u8_load_a_sapphireamx_( //
2446
+ &a_tiles[tile_idx], //
2447
+ vectors + row_tile * stride_in_bytes + depth_start, //
2448
+ stride_in_bytes, valid_rows, valid_depth);
2411
2449
 
2412
2450
  if (row_tile == col_tile) {
2413
2451
  nk_dots_pack_u8_transposed_sapphireamx_(&a_tiles[tile_idx], &b_tiles[tile_idx]);
2414
2452
  }
2415
2453
  else {
2416
- nk_dots_u8_load_a_sapphireamx_( //
2417
- &b_src_tiles[tile_idx], //
2418
- vectors + col_tile * stride + depth_start, //
2419
- stride, valid_cols, valid_depth);
2454
+ nk_dots_u8_load_a_sapphireamx_( //
2455
+ &b_src_tiles[tile_idx], //
2456
+ vectors + col_tile * stride_in_bytes + depth_start, //
2457
+ stride_in_bytes, valid_cols, valid_depth);
2420
2458
  nk_dots_pack_u8_transposed_sapphireamx_(&b_src_tiles[tile_idx], &b_tiles[tile_idx]);
2421
2459
  }
2422
2460
  }
@@ -2432,9 +2470,9 @@ NK_PUBLIC void nk_dots_symmetric_u8_sapphireamx( //
2432
2470
  }
2433
2471
  }
2434
2472
 
2435
- #pragma endregion // Unsigned Integers
2473
+ #pragma endregion Unsigned Integers
2436
2474
 
2437
- #pragma region Quarter Precision E4M3
2475
+ #pragma region E4M3 Floats
2438
2476
 
2439
2477
  NK_PUBLIC nk_size_t nk_dots_packed_size_e4m3_sapphireamx(nk_size_t column_count, nk_size_t depth) {
2440
2478
  // FP8 uses BF16 tile layout after conversion (same element count: 32 per row)
@@ -2443,7 +2481,7 @@ NK_PUBLIC nk_size_t nk_dots_packed_size_e4m3_sapphireamx(nk_size_t column_count,
2443
2481
 
2444
2482
  NK_PUBLIC void nk_dots_pack_e4m3_sapphireamx( //
2445
2483
  nk_e4m3_t const *b, nk_size_t column_count, nk_size_t depth, //
2446
- nk_size_t b_stride, void *b_packed) {
2484
+ nk_size_t b_stride_in_bytes, void *b_packed) {
2447
2485
 
2448
2486
  nk_size_t const tmm_rows = 16;
2449
2487
  nk_size_t const tmm_cols = 32; // Same depth granularity as BF16
@@ -2467,8 +2505,7 @@ NK_PUBLIC void nk_dots_pack_e4m3_sapphireamx( //
2467
2505
  nk_bf16_t *tiles_ptr = (nk_bf16_t *)((char *)b_packed + tiles_offset);
2468
2506
  nk_bf16_t *column_edge_ptr = (nk_bf16_t *)((char *)b_packed + column_edge_offset);
2469
2507
 
2470
- for (nk_size_t idx = 0; idx < total_tiles * tile_elements; idx++) tiles_ptr[idx] = 0;
2471
-
2508
+ // Pack tiles using vectorized convert + SIMD transpose
2472
2509
  for (nk_size_t column_tile_idx = 0; column_tile_idx < column_tiles_count; column_tile_idx++) {
2473
2510
  for (nk_size_t depth_tile_idx = 0; depth_tile_idx < depth_tiles_count; depth_tile_idx++) {
2474
2511
  nk_size_t const tile_index = column_tile_idx * depth_tiles_count + depth_tile_idx;
@@ -2479,21 +2516,19 @@ NK_PUBLIC void nk_dots_pack_e4m3_sapphireamx( //
2479
2516
  nk_size_t const columns_to_pack = (src_column_start + tmm_cols <= depth) ? tmm_cols
2480
2517
  : (depth - src_column_start);
2481
2518
 
2482
- // Convert E4M3 to BF16 and pack with pair-interleaving
2519
+ // Convert E4M3 to BF16 and gather into aligned source tile
2520
+ __mmask32 column_mask = (columns_to_pack >= 32) ? 0xFFFFFFFF : ((__mmask32)1 << columns_to_pack) - 1;
2521
+ nk_dots_bf16_a16x32_sapphireamx_t source_tile;
2483
2522
  for (nk_size_t row_idx = 0; row_idx < tmm_rows; row_idx++) {
2484
- nk_size_t src_row = src_row_start + row_idx;
2485
- // Load 32 E4M3 bytes and convert to BF16
2486
- __mmask32 column_mask = (columns_to_pack >= 32) ? 0xFFFFFFFF : ((__mmask32)1 << columns_to_pack) - 1;
2487
- __m256i e4m3_row = _mm256_maskz_loadu_epi8(column_mask, b + src_row * b_stride + src_column_start);
2488
- __m512i bf16_row = nk_e4m3x32_to_bf16x32_icelake_(e4m3_row);
2489
- // Store with pair-interleaving
2490
- nk_bf16_t bf16_buf[32];
2491
- _mm512_storeu_si512((__m512i *)bf16_buf, bf16_row);
2492
- for (nk_size_t column_idx = 0; column_idx < columns_to_pack; column_idx++) {
2493
- nk_size_t const dst_idx = (column_idx / 2) * 32 + row_idx * 2 + (column_idx % 2);
2494
- tile_output[dst_idx] = bf16_buf[column_idx];
2495
- }
2523
+ __m256i e4m3_row_u8x32 = _mm256_maskz_loadu_epi8(
2524
+ column_mask, b + (src_row_start + row_idx) * b_stride_in_bytes + src_column_start);
2525
+ _mm512_store_si512(&source_tile.data[row_idx][0], nk_e4m3x32_to_bf16x32_icelake_(e4m3_row_u8x32));
2496
2526
  }
2527
+
2528
+ nk_dots_bf16_b32x16_sapphireamx_t transposed_tile;
2529
+ nk_dots_pack_bf16_transposed_sapphireamx_(&source_tile, &transposed_tile);
2530
+ for (nk_size_t i = 0; i < tile_bytes; i += 64)
2531
+ _mm512_storeu_si512((char *)tile_output + i, _mm512_load_si512((char const *)&transposed_tile + i));
2497
2532
  }
2498
2533
  }
2499
2534
 
@@ -2504,10 +2539,11 @@ NK_PUBLIC void nk_dots_pack_e4m3_sapphireamx( //
2504
2539
  for (nk_size_t column_idx = 0; column_idx < depth; column_idx += 32) {
2505
2540
  nk_size_t columns = (column_idx + 32 <= depth) ? 32 : (depth - column_idx);
2506
2541
  __mmask32 column_mask = (columns >= 32) ? 0xFFFFFFFF : ((__mmask32)1 << columns) - 1;
2507
- __m256i e4m3_chunk = _mm256_maskz_loadu_epi8(
2508
- column_mask, b + (remainder_start_row + row_idx) * b_stride + column_idx);
2509
- __m512i bf16_chunk = nk_e4m3x32_to_bf16x32_icelake_(e4m3_chunk);
2510
- _mm512_mask_storeu_epi16(column_edge_ptr + row_idx * depth + column_idx, column_mask, bf16_chunk);
2542
+ __m256i e4m3_chunk_u8x32 = _mm256_maskz_loadu_epi8(
2543
+ column_mask, b + (remainder_start_row + row_idx) * b_stride_in_bytes + column_idx);
2544
+ __m512i bf16_chunk_i16x32 = nk_e4m3x32_to_bf16x32_icelake_(e4m3_chunk_u8x32);
2545
+ _mm512_mask_storeu_epi16(column_edge_ptr + row_idx * depth + column_idx, column_mask,
2546
+ bf16_chunk_i16x32);
2511
2547
  }
2512
2548
  }
2513
2549
  }
@@ -2518,7 +2554,7 @@ NK_PUBLIC void nk_dots_pack_e4m3_sapphireamx( //
2518
2554
  header->norms_byte_offset = (nk_u32_t)norms_offset;
2519
2555
  nk_f32_t *norms = (nk_f32_t *)((char *)b_packed + norms_offset);
2520
2556
  for (nk_size_t col = 0; col < column_count; col++)
2521
- norms[col] = nk_dots_reduce_sumsq_e4m3_(b + col * b_stride, depth);
2557
+ norms[col] = nk_dots_reduce_sumsq_e4m3_(b + col * b_stride_in_bytes, depth);
2522
2558
  }
2523
2559
 
2524
2560
  NK_PUBLIC void nk_dots_packed_e4m3_sapphireamx( //
@@ -2545,7 +2581,7 @@ NK_PUBLIC void nk_dots_packed_e4m3_sapphireamx( //
2545
2581
 
2546
2582
  if (depth_tiles_count == 0) return;
2547
2583
 
2548
- nk_dots_bf16_a16x32_sapphireamx_t a_tile_upper, a_tile_lower;
2584
+ nk_dots_bf16_a16x32_sapphireamx_t a_tile_top, a_tile_bottom;
2549
2585
  nk_dots_bf16_state2x2_sapphireamx_t c_accum_buffer;
2550
2586
 
2551
2587
  nk_size_t const full_depth_tiles_count = depth / tile_depth;
@@ -2558,8 +2594,8 @@ NK_PUBLIC void nk_dots_packed_e4m3_sapphireamx( //
2558
2594
  nk_size_t const row_block_start = row_block_idx * 32;
2559
2595
  nk_size_t const valid_rows_count = (row_block_start + 32 <= rows_count) ? 32 : (rows_count - row_block_start);
2560
2596
  nk_size_t const is_full_row_block = (valid_rows_count == 32);
2561
- nk_size_t const rows_in_upper_tile = (valid_rows_count > 16) ? 16 : valid_rows_count;
2562
- nk_size_t const rows_in_lower_tile = (valid_rows_count > 16) ? valid_rows_count - 16 : 0;
2597
+ nk_size_t const rows_in_high_tile = (valid_rows_count > 16) ? 16 : valid_rows_count;
2598
+ nk_size_t const rows_in_low_tile = (valid_rows_count > 16) ? valid_rows_count - 16 : 0;
2563
2599
 
2564
2600
  for (nk_size_t column_block_idx = 0; column_block_idx < col_blocks_count; column_block_idx++) {
2565
2601
  nk_size_t const col_block_start = column_block_idx * 32;
@@ -2578,12 +2614,12 @@ NK_PUBLIC void nk_dots_packed_e4m3_sapphireamx( //
2578
2614
  nk_size_t const valid_depth = (depth_tile_idx < full_depth_tiles_count) ? tile_depth : depth_remainder;
2579
2615
 
2580
2616
  // Load A with FP8 → BF16 conversion
2581
- nk_dots_e4m3_load_a_sapphireamx_(&a_tile_upper, a + row_block_start * a_stride_bytes + depth_offset,
2582
- a_stride_bytes, rows_in_upper_tile, valid_depth);
2583
- if (rows_in_lower_tile > 0) {
2584
- nk_dots_e4m3_load_a_sapphireamx_(&a_tile_lower,
2617
+ nk_dots_e4m3_load_a_sapphireamx_(&a_tile_top, a + row_block_start * a_stride_bytes + depth_offset,
2618
+ a_stride_bytes, rows_in_high_tile, valid_depth);
2619
+ if (rows_in_low_tile > 0) {
2620
+ nk_dots_e4m3_load_a_sapphireamx_(&a_tile_bottom,
2585
2621
  a + (row_block_start + 16) * a_stride_bytes + depth_offset,
2586
- a_stride_bytes, rows_in_lower_tile, valid_depth);
2622
+ a_stride_bytes, rows_in_low_tile, valid_depth);
2587
2623
  }
2588
2624
 
2589
2625
  nk_dots_bf16_b32x16_sapphireamx_t const *b_tile_left =
@@ -2593,8 +2629,8 @@ NK_PUBLIC void nk_dots_packed_e4m3_sapphireamx( //
2593
2629
  (nk_dots_bf16_b32x16_sapphireamx_t const *)(b_tiles_base +
2594
2630
  (b_column_right_base + depth_tile_idx) * tile_size);
2595
2631
 
2596
- _tile_loadd(0, a_tile_upper.data, 64);
2597
- _tile_loadd(1, a_tile_lower.data, 64);
2632
+ _tile_loadd(0, a_tile_top.data, 64);
2633
+ _tile_loadd(1, a_tile_bottom.data, 64);
2598
2634
  _tile_loadd(2, b_tile_left->data, 64);
2599
2635
  _tile_loadd(3, b_tile_right->data, 64);
2600
2636
 
@@ -2629,7 +2665,7 @@ NK_PUBLIC void nk_dots_packed_e4m3_sapphireamx( //
2629
2665
  nk_size_t const col_start = column_tile_idx * 16;
2630
2666
  nk_size_t const b_column_base = column_tile_idx * depth_tiles_count;
2631
2667
 
2632
- nk_dots_bf16_state_sapphireamx_t c_upper_state, c_lower_state;
2668
+ nk_dots_bf16_state_sapphireamx_t c_high_state, c_low_state;
2633
2669
  _tile_zero(4);
2634
2670
  _tile_zero(6);
2635
2671
 
@@ -2637,41 +2673,41 @@ NK_PUBLIC void nk_dots_packed_e4m3_sapphireamx( //
2637
2673
  nk_size_t const depth_offset = depth_tile_idx * tile_depth;
2638
2674
  nk_size_t const valid_depth = (depth_tile_idx < full_depth_tiles_count) ? tile_depth : depth_remainder;
2639
2675
 
2640
- nk_dots_e4m3_load_a_sapphireamx_(&a_tile_upper, a + row_block_start * a_stride_bytes + depth_offset,
2641
- a_stride_bytes, rows_in_upper_tile, valid_depth);
2642
- if (rows_in_lower_tile > 0) {
2643
- nk_dots_e4m3_load_a_sapphireamx_(&a_tile_lower,
2676
+ nk_dots_e4m3_load_a_sapphireamx_(&a_tile_top, a + row_block_start * a_stride_bytes + depth_offset,
2677
+ a_stride_bytes, rows_in_high_tile, valid_depth);
2678
+ if (rows_in_low_tile > 0) {
2679
+ nk_dots_e4m3_load_a_sapphireamx_(&a_tile_bottom,
2644
2680
  a + (row_block_start + 16) * a_stride_bytes + depth_offset,
2645
- a_stride_bytes, rows_in_lower_tile, valid_depth);
2681
+ a_stride_bytes, rows_in_low_tile, valid_depth);
2646
2682
  }
2647
2683
 
2648
2684
  nk_dots_bf16_b32x16_sapphireamx_t const *b_tile =
2649
2685
  (nk_dots_bf16_b32x16_sapphireamx_t const *)(b_tiles_base +
2650
2686
  (b_column_base + depth_tile_idx) * tile_size);
2651
2687
 
2652
- _tile_loadd(0, a_tile_upper.data, 64);
2653
- _tile_loadd(1, a_tile_lower.data, 64);
2688
+ _tile_loadd(0, a_tile_top.data, 64);
2689
+ _tile_loadd(1, a_tile_bottom.data, 64);
2654
2690
  _tile_loadd(2, b_tile->data, 64);
2655
2691
 
2656
2692
  _tile_dpbf16ps(4, 0, 2);
2657
2693
  _tile_dpbf16ps(6, 1, 2);
2658
2694
  }
2659
2695
 
2660
- _tile_stored(4, c_upper_state.data, 64);
2661
- _tile_stored(6, c_lower_state.data, 64);
2696
+ _tile_stored(4, c_high_state.data, 64);
2697
+ _tile_stored(6, c_low_state.data, 64);
2662
2698
 
2663
- nk_dots_bf16_store_sapphireamx_(&c_upper_state, c + row_block_start * c_stride_elements + col_start,
2664
- c_stride_elements, rows_in_upper_tile, 16);
2665
- if (rows_in_lower_tile > 0) {
2666
- nk_dots_bf16_store_sapphireamx_(&c_lower_state,
2699
+ nk_dots_bf16_store_sapphireamx_(&c_high_state, c + row_block_start * c_stride_elements + col_start,
2700
+ c_stride_elements, rows_in_high_tile, 16);
2701
+ if (rows_in_low_tile > 0) {
2702
+ nk_dots_bf16_store_sapphireamx_(&c_low_state,
2667
2703
  c + (row_block_start + 16) * c_stride_elements + col_start,
2668
- c_stride_elements, rows_in_lower_tile, 16);
2704
+ c_stride_elements, rows_in_low_tile, 16);
2669
2705
  }
2670
2706
  }
2671
2707
 
2672
2708
  // Handle column-edge (remaining columns < 16) using AMX with partial tiles
2673
2709
  if (column_remainder_count > 0) {
2674
- nk_dots_bf16_state_sapphireamx_t c_upper_state, c_lower_state;
2710
+ nk_dots_bf16_state_sapphireamx_t c_high_state, c_low_state;
2675
2711
  nk_dots_bf16_a16x32_sapphireamx_t b_as_a;
2676
2712
  nk_dots_bf16_b32x16_sapphireamx_t b_tile;
2677
2713
 
@@ -2682,12 +2718,12 @@ NK_PUBLIC void nk_dots_packed_e4m3_sapphireamx( //
2682
2718
  nk_size_t const depth_offset = depth_tile_idx * tile_depth;
2683
2719
  nk_size_t const valid_depth = (depth_tile_idx < full_depth_tiles_count) ? tile_depth : depth_remainder;
2684
2720
 
2685
- nk_dots_e4m3_load_a_sapphireamx_(&a_tile_upper, a + row_block_start * a_stride_bytes + depth_offset,
2686
- a_stride_bytes, rows_in_upper_tile, valid_depth);
2687
- if (rows_in_lower_tile > 0) {
2688
- nk_dots_e4m3_load_a_sapphireamx_(&a_tile_lower,
2721
+ nk_dots_e4m3_load_a_sapphireamx_(&a_tile_top, a + row_block_start * a_stride_bytes + depth_offset,
2722
+ a_stride_bytes, rows_in_high_tile, valid_depth);
2723
+ if (rows_in_low_tile > 0) {
2724
+ nk_dots_e4m3_load_a_sapphireamx_(&a_tile_bottom,
2689
2725
  a + (row_block_start + 16) * a_stride_bytes + depth_offset,
2690
- a_stride_bytes, rows_in_lower_tile, valid_depth);
2726
+ a_stride_bytes, rows_in_low_tile, valid_depth);
2691
2727
  }
2692
2728
 
2693
2729
  // B edge data is already in BF16 format
@@ -2695,23 +2731,23 @@ NK_PUBLIC void nk_dots_packed_e4m3_sapphireamx( //
2695
2731
  valid_depth);
2696
2732
  nk_dots_pack_bf16_transposed_sapphireamx_(&b_as_a, &b_tile);
2697
2733
 
2698
- _tile_loadd(0, a_tile_upper.data, 64);
2699
- _tile_loadd(1, a_tile_lower.data, 64);
2734
+ _tile_loadd(0, a_tile_top.data, 64);
2735
+ _tile_loadd(1, a_tile_bottom.data, 64);
2700
2736
  _tile_loadd(2, b_tile.data, 64);
2701
2737
 
2702
2738
  _tile_dpbf16ps(4, 0, 2);
2703
2739
  _tile_dpbf16ps(6, 1, 2);
2704
2740
  }
2705
2741
 
2706
- _tile_stored(4, c_upper_state.data, 64);
2707
- _tile_stored(6, c_lower_state.data, 64);
2742
+ _tile_stored(4, c_high_state.data, 64);
2743
+ _tile_stored(6, c_low_state.data, 64);
2708
2744
 
2709
- nk_dots_bf16_store_sapphireamx_(&c_upper_state, c + row_block_start * c_stride_elements + full_cols,
2710
- c_stride_elements, rows_in_upper_tile, column_remainder_count);
2711
- if (rows_in_lower_tile > 0) {
2712
- nk_dots_bf16_store_sapphireamx_(&c_lower_state,
2745
+ nk_dots_bf16_store_sapphireamx_(&c_high_state, c + row_block_start * c_stride_elements + full_cols,
2746
+ c_stride_elements, rows_in_high_tile, column_remainder_count);
2747
+ if (rows_in_low_tile > 0) {
2748
+ nk_dots_bf16_store_sapphireamx_(&c_low_state,
2713
2749
  c + (row_block_start + 16) * c_stride_elements + full_cols,
2714
- c_stride_elements, rows_in_lower_tile, column_remainder_count);
2750
+ c_stride_elements, rows_in_low_tile, column_remainder_count);
2715
2751
  }
2716
2752
  }
2717
2753
  }
@@ -2719,9 +2755,9 @@ NK_PUBLIC void nk_dots_packed_e4m3_sapphireamx( //
2719
2755
  _tile_release();
2720
2756
  }
2721
2757
 
2722
- #pragma endregion // Quarter Precision E4M3
2758
+ #pragma endregion E4M3 Floats
2723
2759
 
2724
- #pragma region Quarter Precision E5M2
2760
+ #pragma region E5M2 Floats
2725
2761
 
2726
2762
  NK_PUBLIC nk_size_t nk_dots_packed_size_e5m2_sapphireamx(nk_size_t column_count, nk_size_t depth) {
2727
2763
  return nk_dots_packed_size_bf16_sapphireamx(column_count, depth);
@@ -2729,7 +2765,7 @@ NK_PUBLIC nk_size_t nk_dots_packed_size_e5m2_sapphireamx(nk_size_t column_count,
2729
2765
 
2730
2766
  NK_PUBLIC void nk_dots_pack_e5m2_sapphireamx( //
2731
2767
  nk_e5m2_t const *b, nk_size_t column_count, nk_size_t depth, //
2732
- nk_size_t b_stride, void *b_packed) {
2768
+ nk_size_t b_stride_in_bytes, void *b_packed) {
2733
2769
 
2734
2770
  nk_size_t const tmm_rows = 16;
2735
2771
  nk_size_t const tmm_cols = 32;
@@ -2753,8 +2789,7 @@ NK_PUBLIC void nk_dots_pack_e5m2_sapphireamx( //
2753
2789
  nk_bf16_t *tiles_ptr = (nk_bf16_t *)((char *)b_packed + tiles_offset);
2754
2790
  nk_bf16_t *column_edge_ptr = (nk_bf16_t *)((char *)b_packed + column_edge_offset);
2755
2791
 
2756
- for (nk_size_t idx = 0; idx < total_tiles * tile_elements; idx++) tiles_ptr[idx] = 0;
2757
-
2792
+ // Pack tiles using vectorized convert + SIMD transpose
2758
2793
  for (nk_size_t column_tile_idx = 0; column_tile_idx < column_tiles_count; column_tile_idx++) {
2759
2794
  for (nk_size_t depth_tile_idx = 0; depth_tile_idx < depth_tiles_count; depth_tile_idx++) {
2760
2795
  nk_size_t const tile_index = column_tile_idx * depth_tiles_count + depth_tile_idx;
@@ -2765,18 +2800,18 @@ NK_PUBLIC void nk_dots_pack_e5m2_sapphireamx( //
2765
2800
  nk_size_t const columns_to_pack = (src_column_start + tmm_cols <= depth) ? tmm_cols
2766
2801
  : (depth - src_column_start);
2767
2802
 
2803
+ __mmask32 column_mask = (columns_to_pack >= 32) ? 0xFFFFFFFF : ((__mmask32)1 << columns_to_pack) - 1;
2804
+ nk_dots_bf16_a16x32_sapphireamx_t source_tile;
2768
2805
  for (nk_size_t row_idx = 0; row_idx < tmm_rows; row_idx++) {
2769
- nk_size_t src_row = src_row_start + row_idx;
2770
- __mmask32 column_mask = (columns_to_pack >= 32) ? 0xFFFFFFFF : ((__mmask32)1 << columns_to_pack) - 1;
2771
- __m256i e5m2_row = _mm256_maskz_loadu_epi8(column_mask, b + src_row * b_stride + src_column_start);
2772
- __m512i bf16_row = nk_e5m2x32_to_bf16x32_icelake_(e5m2_row);
2773
- nk_bf16_t bf16_buf[32];
2774
- _mm512_storeu_si512((__m512i *)bf16_buf, bf16_row);
2775
- for (nk_size_t column_idx = 0; column_idx < columns_to_pack; column_idx++) {
2776
- nk_size_t const dst_idx = (column_idx / 2) * 32 + row_idx * 2 + (column_idx % 2);
2777
- tile_output[dst_idx] = bf16_buf[column_idx];
2778
- }
2806
+ __m256i e5m2_row_u8x32 = _mm256_maskz_loadu_epi8(
2807
+ column_mask, b + (src_row_start + row_idx) * b_stride_in_bytes + src_column_start);
2808
+ _mm512_store_si512(&source_tile.data[row_idx][0], nk_e5m2x32_to_bf16x32_icelake_(e5m2_row_u8x32));
2779
2809
  }
2810
+
2811
+ nk_dots_bf16_b32x16_sapphireamx_t transposed_tile;
2812
+ nk_dots_pack_bf16_transposed_sapphireamx_(&source_tile, &transposed_tile);
2813
+ for (nk_size_t i = 0; i < tile_bytes; i += 64)
2814
+ _mm512_storeu_si512((char *)tile_output + i, _mm512_load_si512((char const *)&transposed_tile + i));
2780
2815
  }
2781
2816
  }
2782
2817
 
@@ -2786,10 +2821,11 @@ NK_PUBLIC void nk_dots_pack_e5m2_sapphireamx( //
2786
2821
  for (nk_size_t column_idx = 0; column_idx < depth; column_idx += 32) {
2787
2822
  nk_size_t columns = (column_idx + 32 <= depth) ? 32 : (depth - column_idx);
2788
2823
  __mmask32 column_mask = (columns >= 32) ? 0xFFFFFFFF : ((__mmask32)1 << columns) - 1;
2789
- __m256i e5m2_chunk = _mm256_maskz_loadu_epi8(
2790
- column_mask, b + (remainder_start_row + row_idx) * b_stride + column_idx);
2791
- __m512i bf16_chunk = nk_e5m2x32_to_bf16x32_icelake_(e5m2_chunk);
2792
- _mm512_mask_storeu_epi16(column_edge_ptr + row_idx * depth + column_idx, column_mask, bf16_chunk);
2824
+ __m256i e5m2_chunk_u8x32 = _mm256_maskz_loadu_epi8(
2825
+ column_mask, b + (remainder_start_row + row_idx) * b_stride_in_bytes + column_idx);
2826
+ __m512i bf16_chunk_i16x32 = nk_e5m2x32_to_bf16x32_icelake_(e5m2_chunk_u8x32);
2827
+ _mm512_mask_storeu_epi16(column_edge_ptr + row_idx * depth + column_idx, column_mask,
2828
+ bf16_chunk_i16x32);
2793
2829
  }
2794
2830
  }
2795
2831
  }
@@ -2800,7 +2836,7 @@ NK_PUBLIC void nk_dots_pack_e5m2_sapphireamx( //
2800
2836
  header->norms_byte_offset = (nk_u32_t)norms_offset;
2801
2837
  nk_f32_t *norms = (nk_f32_t *)((char *)b_packed + norms_offset);
2802
2838
  for (nk_size_t col = 0; col < column_count; col++)
2803
- norms[col] = nk_dots_reduce_sumsq_e5m2_(b + col * b_stride, depth);
2839
+ norms[col] = nk_dots_reduce_sumsq_e5m2_(b + col * b_stride_in_bytes, depth);
2804
2840
  }
2805
2841
 
2806
2842
  NK_PUBLIC void nk_dots_packed_e5m2_sapphireamx( //
@@ -2826,7 +2862,7 @@ NK_PUBLIC void nk_dots_packed_e5m2_sapphireamx( //
2826
2862
 
2827
2863
  if (depth_tiles_count == 0) return;
2828
2864
 
2829
- nk_dots_bf16_a16x32_sapphireamx_t a_tile_upper, a_tile_lower;
2865
+ nk_dots_bf16_a16x32_sapphireamx_t a_tile_top, a_tile_bottom;
2830
2866
  nk_dots_bf16_state2x2_sapphireamx_t c_accum_buffer;
2831
2867
 
2832
2868
  nk_size_t const full_depth_tiles_count = depth / tile_depth;
@@ -2839,8 +2875,8 @@ NK_PUBLIC void nk_dots_packed_e5m2_sapphireamx( //
2839
2875
  nk_size_t const row_block_start = row_block_idx * 32;
2840
2876
  nk_size_t const valid_rows_count = (row_block_start + 32 <= rows_count) ? 32 : (rows_count - row_block_start);
2841
2877
  nk_size_t const is_full_row_block = (valid_rows_count == 32);
2842
- nk_size_t const rows_in_upper_tile = (valid_rows_count > 16) ? 16 : valid_rows_count;
2843
- nk_size_t const rows_in_lower_tile = (valid_rows_count > 16) ? valid_rows_count - 16 : 0;
2878
+ nk_size_t const rows_in_high_tile = (valid_rows_count > 16) ? 16 : valid_rows_count;
2879
+ nk_size_t const rows_in_low_tile = (valid_rows_count > 16) ? valid_rows_count - 16 : 0;
2844
2880
 
2845
2881
  for (nk_size_t column_block_idx = 0; column_block_idx < col_blocks_count; column_block_idx++) {
2846
2882
  nk_size_t const col_block_start = column_block_idx * 32;
@@ -2859,12 +2895,12 @@ NK_PUBLIC void nk_dots_packed_e5m2_sapphireamx( //
2859
2895
  nk_size_t const valid_depth = (depth_tile_idx < full_depth_tiles_count) ? tile_depth : depth_remainder;
2860
2896
 
2861
2897
  // Load A with FP8 → BF16 conversion
2862
- nk_dots_e5m2_load_a_sapphireamx_(&a_tile_upper, a + row_block_start * a_stride_bytes + depth_offset,
2863
- a_stride_bytes, rows_in_upper_tile, valid_depth);
2864
- if (rows_in_lower_tile > 0) {
2865
- nk_dots_e5m2_load_a_sapphireamx_(&a_tile_lower,
2898
+ nk_dots_e5m2_load_a_sapphireamx_(&a_tile_top, a + row_block_start * a_stride_bytes + depth_offset,
2899
+ a_stride_bytes, rows_in_high_tile, valid_depth);
2900
+ if (rows_in_low_tile > 0) {
2901
+ nk_dots_e5m2_load_a_sapphireamx_(&a_tile_bottom,
2866
2902
  a + (row_block_start + 16) * a_stride_bytes + depth_offset,
2867
- a_stride_bytes, rows_in_lower_tile, valid_depth);
2903
+ a_stride_bytes, rows_in_low_tile, valid_depth);
2868
2904
  }
2869
2905
 
2870
2906
  nk_dots_bf16_b32x16_sapphireamx_t const *b_tile_left =
@@ -2874,8 +2910,8 @@ NK_PUBLIC void nk_dots_packed_e5m2_sapphireamx( //
2874
2910
  (nk_dots_bf16_b32x16_sapphireamx_t const *)(b_tiles_base +
2875
2911
  (b_column_right_base + depth_tile_idx) * tile_size);
2876
2912
 
2877
- _tile_loadd(0, a_tile_upper.data, 64);
2878
- _tile_loadd(1, a_tile_lower.data, 64);
2913
+ _tile_loadd(0, a_tile_top.data, 64);
2914
+ _tile_loadd(1, a_tile_bottom.data, 64);
2879
2915
  _tile_loadd(2, b_tile_left->data, 64);
2880
2916
  _tile_loadd(3, b_tile_right->data, 64);
2881
2917
 
@@ -2910,7 +2946,7 @@ NK_PUBLIC void nk_dots_packed_e5m2_sapphireamx( //
2910
2946
  nk_size_t const col_start = column_tile_idx * 16;
2911
2947
  nk_size_t const b_column_base = column_tile_idx * depth_tiles_count;
2912
2948
 
2913
- nk_dots_bf16_state_sapphireamx_t c_upper_state, c_lower_state;
2949
+ nk_dots_bf16_state_sapphireamx_t c_high_state, c_low_state;
2914
2950
  _tile_zero(4);
2915
2951
  _tile_zero(6);
2916
2952
 
@@ -2918,41 +2954,41 @@ NK_PUBLIC void nk_dots_packed_e5m2_sapphireamx( //
2918
2954
  nk_size_t const depth_offset = depth_tile_idx * tile_depth;
2919
2955
  nk_size_t const valid_depth = (depth_tile_idx < full_depth_tiles_count) ? tile_depth : depth_remainder;
2920
2956
 
2921
- nk_dots_e5m2_load_a_sapphireamx_(&a_tile_upper, a + row_block_start * a_stride_bytes + depth_offset,
2922
- a_stride_bytes, rows_in_upper_tile, valid_depth);
2923
- if (rows_in_lower_tile > 0) {
2924
- nk_dots_e5m2_load_a_sapphireamx_(&a_tile_lower,
2957
+ nk_dots_e5m2_load_a_sapphireamx_(&a_tile_top, a + row_block_start * a_stride_bytes + depth_offset,
2958
+ a_stride_bytes, rows_in_high_tile, valid_depth);
2959
+ if (rows_in_low_tile > 0) {
2960
+ nk_dots_e5m2_load_a_sapphireamx_(&a_tile_bottom,
2925
2961
  a + (row_block_start + 16) * a_stride_bytes + depth_offset,
2926
- a_stride_bytes, rows_in_lower_tile, valid_depth);
2962
+ a_stride_bytes, rows_in_low_tile, valid_depth);
2927
2963
  }
2928
2964
 
2929
2965
  nk_dots_bf16_b32x16_sapphireamx_t const *b_tile =
2930
2966
  (nk_dots_bf16_b32x16_sapphireamx_t const *)(b_tiles_base +
2931
2967
  (b_column_base + depth_tile_idx) * tile_size);
2932
2968
 
2933
- _tile_loadd(0, a_tile_upper.data, 64);
2934
- _tile_loadd(1, a_tile_lower.data, 64);
2969
+ _tile_loadd(0, a_tile_top.data, 64);
2970
+ _tile_loadd(1, a_tile_bottom.data, 64);
2935
2971
  _tile_loadd(2, b_tile->data, 64);
2936
2972
 
2937
2973
  _tile_dpbf16ps(4, 0, 2);
2938
2974
  _tile_dpbf16ps(6, 1, 2);
2939
2975
  }
2940
2976
 
2941
- _tile_stored(4, c_upper_state.data, 64);
2942
- _tile_stored(6, c_lower_state.data, 64);
2977
+ _tile_stored(4, c_high_state.data, 64);
2978
+ _tile_stored(6, c_low_state.data, 64);
2943
2979
 
2944
- nk_dots_bf16_store_sapphireamx_(&c_upper_state, c + row_block_start * c_stride_elements + col_start,
2945
- c_stride_elements, rows_in_upper_tile, 16);
2946
- if (rows_in_lower_tile > 0) {
2947
- nk_dots_bf16_store_sapphireamx_(&c_lower_state,
2980
+ nk_dots_bf16_store_sapphireamx_(&c_high_state, c + row_block_start * c_stride_elements + col_start,
2981
+ c_stride_elements, rows_in_high_tile, 16);
2982
+ if (rows_in_low_tile > 0) {
2983
+ nk_dots_bf16_store_sapphireamx_(&c_low_state,
2948
2984
  c + (row_block_start + 16) * c_stride_elements + col_start,
2949
- c_stride_elements, rows_in_lower_tile, 16);
2985
+ c_stride_elements, rows_in_low_tile, 16);
2950
2986
  }
2951
2987
  }
2952
2988
 
2953
2989
  // Handle column-edge (remaining columns < 16) using AMX with partial tiles
2954
2990
  if (column_remainder_count > 0) {
2955
- nk_dots_bf16_state_sapphireamx_t c_upper_state, c_lower_state;
2991
+ nk_dots_bf16_state_sapphireamx_t c_high_state, c_low_state;
2956
2992
  nk_dots_bf16_a16x32_sapphireamx_t b_as_a;
2957
2993
  nk_dots_bf16_b32x16_sapphireamx_t b_tile;
2958
2994
 
@@ -2963,35 +2999,35 @@ NK_PUBLIC void nk_dots_packed_e5m2_sapphireamx( //
2963
2999
  nk_size_t const depth_offset = depth_tile_idx * tile_depth;
2964
3000
  nk_size_t const valid_depth = (depth_tile_idx < full_depth_tiles_count) ? tile_depth : depth_remainder;
2965
3001
 
2966
- nk_dots_e5m2_load_a_sapphireamx_(&a_tile_upper, a + row_block_start * a_stride_bytes + depth_offset,
2967
- a_stride_bytes, rows_in_upper_tile, valid_depth);
2968
- if (rows_in_lower_tile > 0) {
2969
- nk_dots_e5m2_load_a_sapphireamx_(&a_tile_lower,
3002
+ nk_dots_e5m2_load_a_sapphireamx_(&a_tile_top, a + row_block_start * a_stride_bytes + depth_offset,
3003
+ a_stride_bytes, rows_in_high_tile, valid_depth);
3004
+ if (rows_in_low_tile > 0) {
3005
+ nk_dots_e5m2_load_a_sapphireamx_(&a_tile_bottom,
2970
3006
  a + (row_block_start + 16) * a_stride_bytes + depth_offset,
2971
- a_stride_bytes, rows_in_lower_tile, valid_depth);
3007
+ a_stride_bytes, rows_in_low_tile, valid_depth);
2972
3008
  }
2973
3009
 
2974
3010
  nk_dots_bf16_load_a_sapphireamx_(&b_as_a, col_edge_ptr + depth_offset, depth, column_remainder_count,
2975
3011
  valid_depth);
2976
3012
  nk_dots_pack_bf16_transposed_sapphireamx_(&b_as_a, &b_tile);
2977
3013
 
2978
- _tile_loadd(0, a_tile_upper.data, 64);
2979
- _tile_loadd(1, a_tile_lower.data, 64);
3014
+ _tile_loadd(0, a_tile_top.data, 64);
3015
+ _tile_loadd(1, a_tile_bottom.data, 64);
2980
3016
  _tile_loadd(2, b_tile.data, 64);
2981
3017
 
2982
3018
  _tile_dpbf16ps(4, 0, 2);
2983
3019
  _tile_dpbf16ps(6, 1, 2);
2984
3020
  }
2985
3021
 
2986
- _tile_stored(4, c_upper_state.data, 64);
2987
- _tile_stored(6, c_lower_state.data, 64);
3022
+ _tile_stored(4, c_high_state.data, 64);
3023
+ _tile_stored(6, c_low_state.data, 64);
2988
3024
 
2989
- nk_dots_bf16_store_sapphireamx_(&c_upper_state, c + row_block_start * c_stride_elements + full_cols,
2990
- c_stride_elements, rows_in_upper_tile, column_remainder_count);
2991
- if (rows_in_lower_tile > 0) {
2992
- nk_dots_bf16_store_sapphireamx_(&c_lower_state,
3025
+ nk_dots_bf16_store_sapphireamx_(&c_high_state, c + row_block_start * c_stride_elements + full_cols,
3026
+ c_stride_elements, rows_in_high_tile, column_remainder_count);
3027
+ if (rows_in_low_tile > 0) {
3028
+ nk_dots_bf16_store_sapphireamx_(&c_low_state,
2993
3029
  c + (row_block_start + 16) * c_stride_elements + full_cols,
2994
- c_stride_elements, rows_in_lower_tile, column_remainder_count);
3030
+ c_stride_elements, rows_in_low_tile, column_remainder_count);
2995
3031
  }
2996
3032
  }
2997
3033
  }
@@ -2999,17 +3035,17 @@ NK_PUBLIC void nk_dots_packed_e5m2_sapphireamx( //
2999
3035
  _tile_release();
3000
3036
  }
3001
3037
 
3002
- NK_PUBLIC void nk_dots_symmetric_e5m2_sapphireamx( //
3003
- nk_e5m2_t const *vectors, nk_size_t n_vectors, nk_size_t depth, //
3004
- nk_size_t stride, nk_f32_t *result, nk_size_t result_stride, //
3038
+ NK_PUBLIC void nk_dots_symmetric_e5m2_sapphireamx( //
3039
+ nk_e5m2_t const *vectors, nk_size_t vectors_count, nk_size_t depth, //
3040
+ nk_size_t stride_in_bytes, nk_f32_t *result, nk_size_t result_stride_in_bytes, //
3005
3041
  nk_size_t row_start, nk_size_t row_count) {
3006
3042
 
3007
- nk_size_t const result_stride_elements = result_stride / sizeof(nk_f32_t);
3043
+ nk_size_t const result_stride_elements = result_stride_in_bytes / sizeof(nk_f32_t);
3008
3044
 
3009
3045
  // Handle row slicing: compute rows [row_start, row_end)
3010
3046
  nk_size_t const row_end = (row_count == 0)
3011
- ? n_vectors
3012
- : (row_start + row_count < n_vectors ? row_start + row_count : n_vectors);
3047
+ ? vectors_count
3048
+ : (row_start + row_count < vectors_count ? row_start + row_count : vectors_count);
3013
3049
 
3014
3050
  // Round depth up to multiple of 96 (3 tiles × 32 elements)
3015
3051
  nk_size_t const depth_tiles = nk_size_divide_round_up_(depth, 32);
@@ -3025,8 +3061,8 @@ NK_PUBLIC void nk_dots_symmetric_e5m2_sapphireamx( //
3025
3061
  for (nk_size_t row_tile = row_start; row_tile < row_end; row_tile += 16) {
3026
3062
  nk_size_t const valid_rows = (row_tile + 16 <= row_end) ? 16 : (row_end - row_tile);
3027
3063
 
3028
- for (nk_size_t col_tile = 0; col_tile < n_vectors; col_tile += 16) {
3029
- nk_size_t const valid_cols = (col_tile + 16 <= n_vectors) ? 16 : (n_vectors - col_tile);
3064
+ for (nk_size_t col_tile = 0; col_tile < vectors_count; col_tile += 16) {
3065
+ nk_size_t const valid_cols = (col_tile + 16 <= vectors_count) ? 16 : (vectors_count - col_tile);
3030
3066
 
3031
3067
  nk_dots_bf16_init_sapphireamx_(&state);
3032
3068
 
@@ -3039,19 +3075,19 @@ NK_PUBLIC void nk_dots_symmetric_e5m2_sapphireamx( //
3039
3075
  ? 32
3040
3076
  : (depth > depth_start ? depth - depth_start : 0);
3041
3077
 
3042
- nk_dots_e5m2_load_a_sapphireamx_( //
3043
- &a_tiles[tile_idx], //
3044
- vectors + row_tile * stride + depth_start, //
3045
- stride, valid_rows, valid_depth);
3078
+ nk_dots_e5m2_load_a_sapphireamx_( //
3079
+ &a_tiles[tile_idx], //
3080
+ vectors + row_tile * stride_in_bytes + depth_start, //
3081
+ stride_in_bytes, valid_rows, valid_depth);
3046
3082
 
3047
3083
  if (row_tile == col_tile) {
3048
3084
  nk_dots_pack_bf16_transposed_sapphireamx_(&a_tiles[tile_idx], &b_tiles[tile_idx]);
3049
3085
  }
3050
3086
  else {
3051
- nk_dots_e5m2_load_a_sapphireamx_( //
3052
- &b_src_tiles[tile_idx], //
3053
- vectors + col_tile * stride + depth_start, //
3054
- stride, valid_cols, valid_depth);
3087
+ nk_dots_e5m2_load_a_sapphireamx_( //
3088
+ &b_src_tiles[tile_idx], //
3089
+ vectors + col_tile * stride_in_bytes + depth_start, //
3090
+ stride_in_bytes, valid_cols, valid_depth);
3055
3091
  nk_dots_pack_bf16_transposed_sapphireamx_(&b_src_tiles[tile_idx], &b_tiles[tile_idx]);
3056
3092
  }
3057
3093
  }
@@ -3067,17 +3103,17 @@ NK_PUBLIC void nk_dots_symmetric_e5m2_sapphireamx( //
3067
3103
  }
3068
3104
  }
3069
3105
 
3070
- NK_PUBLIC void nk_dots_symmetric_e4m3_sapphireamx( //
3071
- nk_e4m3_t const *vectors, nk_size_t n_vectors, nk_size_t depth, //
3072
- nk_size_t stride, nk_f32_t *result, nk_size_t result_stride, //
3106
+ NK_PUBLIC void nk_dots_symmetric_e4m3_sapphireamx( //
3107
+ nk_e4m3_t const *vectors, nk_size_t vectors_count, nk_size_t depth, //
3108
+ nk_size_t stride_in_bytes, nk_f32_t *result, nk_size_t result_stride_in_bytes, //
3073
3109
  nk_size_t row_start, nk_size_t row_count) {
3074
3110
 
3075
- nk_size_t const result_stride_elements = result_stride / sizeof(nk_f32_t);
3111
+ nk_size_t const result_stride_elements = result_stride_in_bytes / sizeof(nk_f32_t);
3076
3112
 
3077
3113
  // Handle row slicing: compute rows [row_start, row_end)
3078
3114
  nk_size_t const row_end = (row_count == 0)
3079
- ? n_vectors
3080
- : (row_start + row_count < n_vectors ? row_start + row_count : n_vectors);
3115
+ ? vectors_count
3116
+ : (row_start + row_count < vectors_count ? row_start + row_count : vectors_count);
3081
3117
 
3082
3118
  // Round depth up to multiple of 96 (3 tiles × 32 elements)
3083
3119
  nk_size_t const depth_tiles = nk_size_divide_round_up_(depth, 32);
@@ -3093,8 +3129,8 @@ NK_PUBLIC void nk_dots_symmetric_e4m3_sapphireamx( //
3093
3129
  for (nk_size_t row_tile = row_start; row_tile < row_end; row_tile += 16) {
3094
3130
  nk_size_t const valid_rows = (row_tile + 16 <= row_end) ? 16 : (row_end - row_tile);
3095
3131
 
3096
- for (nk_size_t col_tile = 0; col_tile < n_vectors; col_tile += 16) {
3097
- nk_size_t const valid_cols = (col_tile + 16 <= n_vectors) ? 16 : (n_vectors - col_tile);
3132
+ for (nk_size_t col_tile = 0; col_tile < vectors_count; col_tile += 16) {
3133
+ nk_size_t const valid_cols = (col_tile + 16 <= vectors_count) ? 16 : (vectors_count - col_tile);
3098
3134
 
3099
3135
  nk_dots_bf16_init_sapphireamx_(&state);
3100
3136
 
@@ -3107,19 +3143,19 @@ NK_PUBLIC void nk_dots_symmetric_e4m3_sapphireamx( //
3107
3143
  ? 32
3108
3144
  : (depth > depth_start ? depth - depth_start : 0);
3109
3145
 
3110
- nk_dots_e4m3_load_a_sapphireamx_( //
3111
- &a_tiles[tile_idx], //
3112
- vectors + row_tile * stride + depth_start, //
3113
- stride, valid_rows, valid_depth);
3146
+ nk_dots_e4m3_load_a_sapphireamx_( //
3147
+ &a_tiles[tile_idx], //
3148
+ vectors + row_tile * stride_in_bytes + depth_start, //
3149
+ stride_in_bytes, valid_rows, valid_depth);
3114
3150
 
3115
3151
  if (row_tile == col_tile) {
3116
3152
  nk_dots_pack_bf16_transposed_sapphireamx_(&a_tiles[tile_idx], &b_tiles[tile_idx]);
3117
3153
  }
3118
3154
  else {
3119
- nk_dots_e4m3_load_a_sapphireamx_( //
3120
- &b_src_tiles[tile_idx], //
3121
- vectors + col_tile * stride + depth_start, //
3122
- stride, valid_cols, valid_depth);
3155
+ nk_dots_e4m3_load_a_sapphireamx_( //
3156
+ &b_src_tiles[tile_idx], //
3157
+ vectors + col_tile * stride_in_bytes + depth_start, //
3158
+ stride_in_bytes, valid_cols, valid_depth);
3123
3159
  nk_dots_pack_bf16_transposed_sapphireamx_(&b_src_tiles[tile_idx], &b_tiles[tile_idx]);
3124
3160
  }
3125
3161
  }
@@ -3135,9 +3171,9 @@ NK_PUBLIC void nk_dots_symmetric_e4m3_sapphireamx( //
3135
3171
  }
3136
3172
  }
3137
3173
 
3138
- #pragma endregion // Quarter Precision E5M2
3174
+ #pragma endregion E5M2 Floats
3139
3175
 
3140
- #pragma region Micro Precision E2M3
3176
+ #pragma region E2M3 Floats
3141
3177
 
3142
3178
  /* Load E2M3 A tile with E2M3 to signed I8 conversion via VPERMB LUT.
3143
3179
  * Each E2M3 byte encodes: bit 5 = sign, bits 4:0 = magnitude (5-bit index).
@@ -3194,12 +3230,12 @@ NK_INTERNAL void nk_dots_e2m3_store_sapphireamx_( //
3194
3230
  nk_size_t valid_rows, nk_size_t valid_cols) {
3195
3231
 
3196
3232
  __mmask16 column_mask = (valid_cols >= 16) ? 0xFFFF : ((__mmask16)1 << valid_cols) - 1;
3197
- __m512 scale = _mm512_set1_ps(1.0f / 256.0f);
3233
+ __m512 scale_f32x16 = _mm512_set1_ps(1.0f / 256.0f);
3198
3234
 
3199
3235
  for (nk_size_t row = 0; row < valid_rows; row++) {
3200
- __m512i i32_row = _mm512_load_si512(state->data[row]);
3201
- __m512 f32_row = _mm512_mul_ps(_mm512_cvtepi32_ps(i32_row), scale);
3202
- _mm512_mask_storeu_ps(dst + row * dst_stride_elements, column_mask, f32_row);
3236
+ __m512i i32_row_i32x16 = _mm512_load_si512(state->data[row]);
3237
+ __m512 f32_row_f32x16 = _mm512_mul_ps(_mm512_cvtepi32_ps(i32_row_i32x16), scale_f32x16);
3238
+ _mm512_mask_storeu_ps(dst + row * dst_stride_elements, column_mask, f32_row_f32x16);
3203
3239
  }
3204
3240
  }
3205
3241
 
@@ -3209,23 +3245,22 @@ NK_INTERNAL void nk_dots_e2m3_output2x2_sapphireamx_( //
3209
3245
  nk_f32_t *dst, nk_size_t dst_stride_elements, //
3210
3246
  nk_size_t valid_rows, nk_size_t valid_cols) {
3211
3247
 
3212
- nk_size_t const rows_upper = (valid_rows > 16) ? 16 : valid_rows;
3248
+ nk_size_t const rows_high = (valid_rows > 16) ? 16 : valid_rows;
3213
3249
  nk_size_t const cols_left = (valid_cols > 16) ? 16 : valid_cols;
3214
3250
  nk_size_t const cols_right = (valid_cols > 16) ? valid_cols - 16 : 0;
3215
3251
 
3216
- if (rows_upper > 0 && cols_left > 0)
3217
- nk_dots_e2m3_store_sapphireamx_(&state->c[0][0], dst, dst_stride_elements, rows_upper, cols_left);
3218
- if (rows_upper > 0 && cols_right > 0)
3219
- nk_dots_e2m3_store_sapphireamx_(&state->c[0][1], dst + 16, dst_stride_elements, rows_upper, cols_right);
3252
+ if (rows_high > 0 && cols_left > 0)
3253
+ nk_dots_e2m3_store_sapphireamx_(&state->c[0][0], dst, dst_stride_elements, rows_high, cols_left);
3254
+ if (rows_high > 0 && cols_right > 0)
3255
+ nk_dots_e2m3_store_sapphireamx_(&state->c[0][1], dst + 16, dst_stride_elements, rows_high, cols_right);
3220
3256
 
3221
3257
  if (valid_rows > 16) {
3222
- nk_size_t const rows_lower = valid_rows - 16;
3223
- nk_f32_t *dst_lower = dst + 16 * dst_stride_elements;
3258
+ nk_size_t const rows_low = valid_rows - 16;
3259
+ nk_f32_t *dst_low = dst + 16 * dst_stride_elements;
3224
3260
  if (cols_left > 0)
3225
- nk_dots_e2m3_store_sapphireamx_(&state->c[1][0], dst_lower, dst_stride_elements, rows_lower, cols_left);
3261
+ nk_dots_e2m3_store_sapphireamx_(&state->c[1][0], dst_low, dst_stride_elements, rows_low, cols_left);
3226
3262
  if (cols_right > 0)
3227
- nk_dots_e2m3_store_sapphireamx_(&state->c[1][1], dst_lower + 16, dst_stride_elements, rows_lower,
3228
- cols_right);
3263
+ nk_dots_e2m3_store_sapphireamx_(&state->c[1][1], dst_low + 16, dst_stride_elements, rows_low, cols_right);
3229
3264
  }
3230
3265
  }
3231
3266
 
@@ -3236,7 +3271,7 @@ NK_PUBLIC nk_size_t nk_dots_packed_size_e2m3_sapphireamx(nk_size_t column_count,
3236
3271
 
3237
3272
  NK_PUBLIC void nk_dots_pack_e2m3_sapphireamx( //
3238
3273
  nk_e2m3_t const *b, nk_size_t column_count, nk_size_t depth, //
3239
- nk_size_t b_stride, void *b_packed) {
3274
+ nk_size_t b_stride_in_bytes, void *b_packed) {
3240
3275
 
3241
3276
  // AMX I8 tile dimensions: 16 rows x 64 columns (1024 I8 elements = 1KB)
3242
3277
  nk_size_t const tmm_rows = 16;
@@ -3261,16 +3296,7 @@ NK_PUBLIC void nk_dots_pack_e2m3_sapphireamx( //
3261
3296
  nk_i8_t *tiles_ptr = (nk_i8_t *)((char *)b_packed + tiles_offset);
3262
3297
  nk_i8_t *column_edge_ptr = (nk_i8_t *)((char *)b_packed + column_edge_offset);
3263
3298
 
3264
- // Zero-initialize all tiles (handles depth remainder padding)
3265
- for (nk_size_t idx = 0; idx < total_tiles * tile_elements; idx++) tiles_ptr[idx] = 0;
3266
-
3267
- // E2M3 magnitude-to-value LUT (value * 16)
3268
- static nk_u8_t const lut_magnitude[32] = {
3269
- 0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30, //
3270
- 32, 36, 40, 44, 48, 52, 56, 60, 64, 72, 80, 88, 96, 104, 112, 120, //
3271
- };
3272
-
3273
- // Pack tiles with E2M3 -> I8 conversion and quad-interleaving
3299
+ // Pack tiles using vectorized E2M3 → I8 conversion + SIMD transpose
3274
3300
  for (nk_size_t column_tile_idx = 0; column_tile_idx < column_tiles_count; column_tile_idx++) {
3275
3301
  for (nk_size_t depth_tile_idx = 0; depth_tile_idx < depth_tiles_count; depth_tile_idx++) {
3276
3302
  nk_size_t const tile_index = column_tile_idx * depth_tiles_count + depth_tile_idx;
@@ -3281,26 +3307,44 @@ NK_PUBLIC void nk_dots_pack_e2m3_sapphireamx( //
3281
3307
  nk_size_t const columns_to_pack = (src_column_start + tmm_cols <= depth) ? tmm_cols
3282
3308
  : (depth - src_column_start);
3283
3309
 
3284
- for (nk_size_t row_idx = 0; row_idx < tmm_rows; row_idx++) {
3285
- for (nk_size_t column_idx = 0; column_idx < columns_to_pack; column_idx++) {
3286
- nk_size_t const src_idx = (src_row_start + row_idx) * b_stride + src_column_start + column_idx;
3287
- nk_size_t const dst_idx = (column_idx / 4) * 64 + row_idx * 4 + (column_idx % 4);
3288
- nk_u8_t raw = b[src_idx];
3289
- nk_u8_t magnitude = raw & 0x1F;
3290
- nk_i8_t val = (nk_i8_t)lut_magnitude[magnitude];
3291
- if (raw & 0x20) val = -val;
3292
- tile_output[dst_idx] = val;
3310
+ // Convert E2M3 I8 and gather into aligned source tile
3311
+ nk_dots_i8_a16x64_sapphireamx_t source_tile;
3312
+ if (columns_to_pack == tmm_cols) {
3313
+ for (nk_size_t row_idx = 0; row_idx < tmm_rows; row_idx++) {
3314
+ __m512i raw_row = _mm512_loadu_si512(
3315
+ (nk_e2m3_t const *)((char const *)b + (src_row_start + row_idx) * b_stride_in_bytes) +
3316
+ src_column_start);
3317
+ _mm512_store_si512(&source_tile.data[row_idx][0], nk_e2m3x64_to_i8x64_skylake_(raw_row));
3318
+ }
3319
+ }
3320
+ else {
3321
+ __mmask64 depth_mask = (__mmask64)((columns_to_pack < 64) ? ((1ULL << columns_to_pack) - 1) : ~0ULL);
3322
+ for (nk_size_t row_idx = 0; row_idx < tmm_rows; row_idx++) {
3323
+ __m512i raw_row = _mm512_maskz_loadu_epi8(
3324
+ depth_mask,
3325
+ (nk_e2m3_t const *)((char const *)b + (src_row_start + row_idx) * b_stride_in_bytes) +
3326
+ src_column_start);
3327
+ _mm512_store_si512(&source_tile.data[row_idx][0], nk_e2m3x64_to_i8x64_skylake_(raw_row));
3293
3328
  }
3294
3329
  }
3330
+
3331
+ nk_dots_i8_b64x16_sapphireamx_t transposed_tile;
3332
+ nk_dots_pack_i8_transposed_sapphireamx_(&source_tile, &transposed_tile);
3333
+ for (nk_size_t i = 0; i < tile_elements; i += 64)
3334
+ _mm512_storeu_si512(tile_output + i, _mm512_load_si512((char const *)&transposed_tile + i));
3295
3335
  }
3296
3336
  }
3297
3337
 
3298
- // Pack column-remainder rows (convert E2M3 to I8)
3338
+ // Pack column-remainder rows (convert E2M3 to I8) using scalar LUT
3339
+ static nk_u8_t const lut_magnitude[32] = {
3340
+ 0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30, //
3341
+ 32, 36, 40, 44, 48, 52, 56, 60, 64, 72, 80, 88, 96, 104, 112, 120, //
3342
+ };
3299
3343
  if (column_remainder_count > 0) {
3300
3344
  nk_size_t const remainder_start_row = column_tiles_count * tmm_rows;
3301
3345
  for (nk_size_t row_idx = 0; row_idx < column_remainder_count; row_idx++) {
3302
3346
  for (nk_size_t column_idx = 0; column_idx < depth; column_idx++) {
3303
- nk_u8_t raw = b[(remainder_start_row + row_idx) * b_stride + column_idx];
3347
+ nk_u8_t raw = b[(remainder_start_row + row_idx) * b_stride_in_bytes + column_idx];
3304
3348
  nk_u8_t magnitude = raw & 0x1F;
3305
3349
  nk_i8_t val = (nk_i8_t)lut_magnitude[magnitude];
3306
3350
  if (raw & 0x20) val = -val;
@@ -3315,7 +3359,7 @@ NK_PUBLIC void nk_dots_pack_e2m3_sapphireamx( //
3315
3359
  header->norms_byte_offset = (nk_u32_t)norms_offset;
3316
3360
  nk_f32_t *norms = (nk_f32_t *)((char *)b_packed + norms_offset);
3317
3361
  for (nk_size_t col = 0; col < column_count; col++)
3318
- norms[col] = nk_dots_reduce_sumsq_e2m3_(b + col * b_stride, depth);
3362
+ norms[col] = nk_dots_reduce_sumsq_e2m3_(b + col * b_stride_in_bytes, depth);
3319
3363
  }
3320
3364
 
3321
3365
  NK_PUBLIC void nk_dots_packed_e2m3_sapphireamx( //
@@ -3342,7 +3386,7 @@ NK_PUBLIC void nk_dots_packed_e2m3_sapphireamx( //
3342
3386
 
3343
3387
  if (depth_tiles_count == 0) return;
3344
3388
 
3345
- nk_dots_i8_a16x64_sapphireamx_t a_tile_upper, a_tile_lower;
3389
+ nk_dots_i8_a16x64_sapphireamx_t a_tile_top, a_tile_bottom;
3346
3390
  nk_dots_i8_state2x2_sapphireamx_t c_accum_buffer;
3347
3391
 
3348
3392
  nk_size_t const full_depth_tiles_count = depth / tile_depth;
@@ -3355,8 +3399,8 @@ NK_PUBLIC void nk_dots_packed_e2m3_sapphireamx( //
3355
3399
  nk_size_t const row_block_start = row_block_idx * 32;
3356
3400
  nk_size_t const valid_rows_count = (row_block_start + 32 <= rows_count) ? 32 : (rows_count - row_block_start);
3357
3401
  nk_size_t const is_full_row_block = (valid_rows_count == 32);
3358
- nk_size_t const rows_in_upper_tile = (valid_rows_count > 16) ? 16 : valid_rows_count;
3359
- nk_size_t const rows_in_lower_tile = (valid_rows_count > 16) ? valid_rows_count - 16 : 0;
3402
+ nk_size_t const rows_in_high_tile = (valid_rows_count > 16) ? 16 : valid_rows_count;
3403
+ nk_size_t const rows_in_low_tile = (valid_rows_count > 16) ? valid_rows_count - 16 : 0;
3360
3404
 
3361
3405
  for (nk_size_t column_block_idx = 0; column_block_idx < col_blocks_count; column_block_idx++) {
3362
3406
  nk_size_t const col_block_start = column_block_idx * 32;
@@ -3375,12 +3419,12 @@ NK_PUBLIC void nk_dots_packed_e2m3_sapphireamx( //
3375
3419
  nk_size_t const valid_depth = (depth_tile_idx < full_depth_tiles_count) ? tile_depth : depth_remainder;
3376
3420
 
3377
3421
  // Load A with E2M3 -> I8 conversion
3378
- nk_dots_e2m3_load_a_sapphireamx_(&a_tile_upper, a + row_block_start * a_stride_bytes + depth_offset,
3379
- a_stride_bytes, rows_in_upper_tile, valid_depth);
3380
- if (rows_in_lower_tile > 0) {
3381
- nk_dots_e2m3_load_a_sapphireamx_(&a_tile_lower,
3422
+ nk_dots_e2m3_load_a_sapphireamx_(&a_tile_top, a + row_block_start * a_stride_bytes + depth_offset,
3423
+ a_stride_bytes, rows_in_high_tile, valid_depth);
3424
+ if (rows_in_low_tile > 0) {
3425
+ nk_dots_e2m3_load_a_sapphireamx_(&a_tile_bottom,
3382
3426
  a + (row_block_start + 16) * a_stride_bytes + depth_offset,
3383
- a_stride_bytes, rows_in_lower_tile, valid_depth);
3427
+ a_stride_bytes, rows_in_low_tile, valid_depth);
3384
3428
  }
3385
3429
 
3386
3430
  nk_dots_i8_b64x16_sapphireamx_t const *b_tile_left =
@@ -3390,8 +3434,8 @@ NK_PUBLIC void nk_dots_packed_e2m3_sapphireamx( //
3390
3434
  (nk_dots_i8_b64x16_sapphireamx_t const *)(b_tiles_base +
3391
3435
  (b_column_right_base + depth_tile_idx) * tile_size);
3392
3436
 
3393
- _tile_loadd(0, a_tile_upper.data, 64);
3394
- _tile_loadd(1, a_tile_lower.data, 64);
3437
+ _tile_loadd(0, a_tile_top.data, 64);
3438
+ _tile_loadd(1, a_tile_bottom.data, 64);
3395
3439
  _tile_loadd(2, b_tile_left->data, 64);
3396
3440
  _tile_loadd(3, b_tile_right->data, 64);
3397
3441
 
@@ -3429,7 +3473,7 @@ NK_PUBLIC void nk_dots_packed_e2m3_sapphireamx( //
3429
3473
  nk_size_t const col_start = column_tile_idx * 16;
3430
3474
  nk_size_t const b_column_base = column_tile_idx * depth_tiles_count;
3431
3475
 
3432
- nk_dots_i8_state_sapphireamx_t c_upper_state, c_lower_state;
3476
+ nk_dots_i8_state_sapphireamx_t c_high_state, c_low_state;
3433
3477
  _tile_zero(4);
3434
3478
  _tile_zero(6);
3435
3479
 
@@ -3437,41 +3481,41 @@ NK_PUBLIC void nk_dots_packed_e2m3_sapphireamx( //
3437
3481
  nk_size_t const depth_offset = depth_tile_idx * tile_depth;
3438
3482
  nk_size_t const valid_depth = (depth_tile_idx < full_depth_tiles_count) ? tile_depth : depth_remainder;
3439
3483
 
3440
- nk_dots_e2m3_load_a_sapphireamx_(&a_tile_upper, a + row_block_start * a_stride_bytes + depth_offset,
3441
- a_stride_bytes, rows_in_upper_tile, valid_depth);
3442
- if (rows_in_lower_tile > 0) {
3443
- nk_dots_e2m3_load_a_sapphireamx_(&a_tile_lower,
3484
+ nk_dots_e2m3_load_a_sapphireamx_(&a_tile_top, a + row_block_start * a_stride_bytes + depth_offset,
3485
+ a_stride_bytes, rows_in_high_tile, valid_depth);
3486
+ if (rows_in_low_tile > 0) {
3487
+ nk_dots_e2m3_load_a_sapphireamx_(&a_tile_bottom,
3444
3488
  a + (row_block_start + 16) * a_stride_bytes + depth_offset,
3445
- a_stride_bytes, rows_in_lower_tile, valid_depth);
3489
+ a_stride_bytes, rows_in_low_tile, valid_depth);
3446
3490
  }
3447
3491
 
3448
3492
  nk_dots_i8_b64x16_sapphireamx_t const *b_tile =
3449
3493
  (nk_dots_i8_b64x16_sapphireamx_t const *)(b_tiles_base +
3450
3494
  (b_column_base + depth_tile_idx) * tile_size);
3451
3495
 
3452
- _tile_loadd(0, a_tile_upper.data, 64);
3453
- _tile_loadd(1, a_tile_lower.data, 64);
3496
+ _tile_loadd(0, a_tile_top.data, 64);
3497
+ _tile_loadd(1, a_tile_bottom.data, 64);
3454
3498
  _tile_loadd(2, b_tile->data, 64);
3455
3499
 
3456
3500
  _tile_dpbssd(4, 0, 2);
3457
3501
  _tile_dpbssd(6, 1, 2);
3458
3502
  }
3459
3503
 
3460
- _tile_stored(4, c_upper_state.data, 64);
3461
- _tile_stored(6, c_lower_state.data, 64);
3504
+ _tile_stored(4, c_high_state.data, 64);
3505
+ _tile_stored(6, c_low_state.data, 64);
3462
3506
 
3463
- nk_dots_e2m3_store_sapphireamx_(&c_upper_state, c + row_block_start * c_stride_elements + col_start,
3464
- c_stride_elements, rows_in_upper_tile, 16);
3465
- if (rows_in_lower_tile > 0) {
3466
- nk_dots_e2m3_store_sapphireamx_(&c_lower_state,
3507
+ nk_dots_e2m3_store_sapphireamx_(&c_high_state, c + row_block_start * c_stride_elements + col_start,
3508
+ c_stride_elements, rows_in_high_tile, 16);
3509
+ if (rows_in_low_tile > 0) {
3510
+ nk_dots_e2m3_store_sapphireamx_(&c_low_state,
3467
3511
  c + (row_block_start + 16) * c_stride_elements + col_start,
3468
- c_stride_elements, rows_in_lower_tile, 16);
3512
+ c_stride_elements, rows_in_low_tile, 16);
3469
3513
  }
3470
3514
  }
3471
3515
 
3472
3516
  // Handle column-edge (remaining columns < 16) using AMX with partial tiles
3473
3517
  if (column_remainder_count > 0) {
3474
- nk_dots_i8_state_sapphireamx_t c_upper_state, c_lower_state;
3518
+ nk_dots_i8_state_sapphireamx_t c_high_state, c_low_state;
3475
3519
  nk_dots_i8_a16x64_sapphireamx_t b_as_a;
3476
3520
  nk_dots_i8_b64x16_sapphireamx_t b_tile;
3477
3521
 
@@ -3482,12 +3526,12 @@ NK_PUBLIC void nk_dots_packed_e2m3_sapphireamx( //
3482
3526
  nk_size_t const depth_offset = depth_tile_idx * tile_depth;
3483
3527
  nk_size_t const valid_depth = (depth_tile_idx < full_depth_tiles_count) ? tile_depth : depth_remainder;
3484
3528
 
3485
- nk_dots_e2m3_load_a_sapphireamx_(&a_tile_upper, a + row_block_start * a_stride_bytes + depth_offset,
3486
- a_stride_bytes, rows_in_upper_tile, valid_depth);
3487
- if (rows_in_lower_tile > 0) {
3488
- nk_dots_e2m3_load_a_sapphireamx_(&a_tile_lower,
3529
+ nk_dots_e2m3_load_a_sapphireamx_(&a_tile_top, a + row_block_start * a_stride_bytes + depth_offset,
3530
+ a_stride_bytes, rows_in_high_tile, valid_depth);
3531
+ if (rows_in_low_tile > 0) {
3532
+ nk_dots_e2m3_load_a_sapphireamx_(&a_tile_bottom,
3489
3533
  a + (row_block_start + 16) * a_stride_bytes + depth_offset,
3490
- a_stride_bytes, rows_in_lower_tile, valid_depth);
3534
+ a_stride_bytes, rows_in_low_tile, valid_depth);
3491
3535
  }
3492
3536
 
3493
3537
  // B edge data is already in I8 format
@@ -3495,23 +3539,23 @@ NK_PUBLIC void nk_dots_packed_e2m3_sapphireamx( //
3495
3539
  valid_depth);
3496
3540
  nk_dots_pack_i8_transposed_sapphireamx_(&b_as_a, &b_tile);
3497
3541
 
3498
- _tile_loadd(0, a_tile_upper.data, 64);
3499
- _tile_loadd(1, a_tile_lower.data, 64);
3542
+ _tile_loadd(0, a_tile_top.data, 64);
3543
+ _tile_loadd(1, a_tile_bottom.data, 64);
3500
3544
  _tile_loadd(2, b_tile.data, 64);
3501
3545
 
3502
3546
  _tile_dpbssd(4, 0, 2);
3503
3547
  _tile_dpbssd(6, 1, 2);
3504
3548
  }
3505
3549
 
3506
- _tile_stored(4, c_upper_state.data, 64);
3507
- _tile_stored(6, c_lower_state.data, 64);
3550
+ _tile_stored(4, c_high_state.data, 64);
3551
+ _tile_stored(6, c_low_state.data, 64);
3508
3552
 
3509
- nk_dots_e2m3_store_sapphireamx_(&c_upper_state, c + row_block_start * c_stride_elements + full_cols,
3510
- c_stride_elements, rows_in_upper_tile, column_remainder_count);
3511
- if (rows_in_lower_tile > 0) {
3512
- nk_dots_e2m3_store_sapphireamx_(&c_lower_state,
3553
+ nk_dots_e2m3_store_sapphireamx_(&c_high_state, c + row_block_start * c_stride_elements + full_cols,
3554
+ c_stride_elements, rows_in_high_tile, column_remainder_count);
3555
+ if (rows_in_low_tile > 0) {
3556
+ nk_dots_e2m3_store_sapphireamx_(&c_low_state,
3513
3557
  c + (row_block_start + 16) * c_stride_elements + full_cols,
3514
- c_stride_elements, rows_in_lower_tile, column_remainder_count);
3558
+ c_stride_elements, rows_in_low_tile, column_remainder_count);
3515
3559
  }
3516
3560
  }
3517
3561
  }
@@ -3519,17 +3563,17 @@ NK_PUBLIC void nk_dots_packed_e2m3_sapphireamx( //
3519
3563
  _tile_release();
3520
3564
  }
3521
3565
 
3522
- NK_PUBLIC void nk_dots_symmetric_e2m3_sapphireamx( //
3523
- nk_e2m3_t const *vectors, nk_size_t n_vectors, nk_size_t depth, //
3524
- nk_size_t stride, nk_f32_t *result, nk_size_t result_stride, //
3566
+ NK_PUBLIC void nk_dots_symmetric_e2m3_sapphireamx( //
3567
+ nk_e2m3_t const *vectors, nk_size_t vectors_count, nk_size_t depth, //
3568
+ nk_size_t stride_in_bytes, nk_f32_t *result, nk_size_t result_stride_in_bytes, //
3525
3569
  nk_size_t row_start, nk_size_t row_count) {
3526
3570
 
3527
- nk_size_t const result_stride_elements = result_stride / sizeof(nk_f32_t);
3571
+ nk_size_t const result_stride_elements = result_stride_in_bytes / sizeof(nk_f32_t);
3528
3572
 
3529
3573
  // Handle row slicing: compute rows [row_start, row_end)
3530
3574
  nk_size_t const row_end = (row_count == 0)
3531
- ? n_vectors
3532
- : (row_start + row_count < n_vectors ? row_start + row_count : n_vectors);
3575
+ ? vectors_count
3576
+ : (row_start + row_count < vectors_count ? row_start + row_count : vectors_count);
3533
3577
 
3534
3578
  // Round depth up to multiple of 192 (3 tiles x 64 elements)
3535
3579
  nk_size_t const depth_tiles = nk_size_divide_round_up_(depth, 64);
@@ -3545,8 +3589,8 @@ NK_PUBLIC void nk_dots_symmetric_e2m3_sapphireamx( //
3545
3589
  for (nk_size_t row_tile = row_start; row_tile < row_end; row_tile += 16) {
3546
3590
  nk_size_t const valid_rows = (row_tile + 16 <= row_end) ? 16 : (row_end - row_tile);
3547
3591
 
3548
- for (nk_size_t col_tile = 0; col_tile < n_vectors; col_tile += 16) {
3549
- nk_size_t const valid_cols = (col_tile + 16 <= n_vectors) ? 16 : (n_vectors - col_tile);
3592
+ for (nk_size_t col_tile = 0; col_tile < vectors_count; col_tile += 16) {
3593
+ nk_size_t const valid_cols = (col_tile + 16 <= vectors_count) ? 16 : (vectors_count - col_tile);
3550
3594
 
3551
3595
  nk_dots_i8_init_sapphireamx_(&state);
3552
3596
 
@@ -3559,19 +3603,19 @@ NK_PUBLIC void nk_dots_symmetric_e2m3_sapphireamx( //
3559
3603
  ? 64
3560
3604
  : (depth > depth_start ? depth - depth_start : 0);
3561
3605
 
3562
- nk_dots_e2m3_load_a_sapphireamx_( //
3563
- &a_tiles[tile_idx], //
3564
- vectors + row_tile * stride + depth_start, //
3565
- stride, valid_rows, valid_depth);
3606
+ nk_dots_e2m3_load_a_sapphireamx_( //
3607
+ &a_tiles[tile_idx], //
3608
+ vectors + row_tile * stride_in_bytes + depth_start, //
3609
+ stride_in_bytes, valid_rows, valid_depth);
3566
3610
 
3567
3611
  if (row_tile == col_tile) {
3568
3612
  nk_dots_pack_i8_transposed_sapphireamx_(&a_tiles[tile_idx], &b_tiles[tile_idx]);
3569
3613
  }
3570
3614
  else {
3571
- nk_dots_e2m3_load_a_sapphireamx_( //
3572
- &b_src_tiles[tile_idx], //
3573
- vectors + col_tile * stride + depth_start, //
3574
- stride, valid_cols, valid_depth);
3615
+ nk_dots_e2m3_load_a_sapphireamx_( //
3616
+ &b_src_tiles[tile_idx], //
3617
+ vectors + col_tile * stride_in_bytes + depth_start, //
3618
+ stride_in_bytes, valid_cols, valid_depth);
3575
3619
  nk_dots_pack_i8_transposed_sapphireamx_(&b_src_tiles[tile_idx], &b_tiles[tile_idx]);
3576
3620
  }
3577
3621
  }
@@ -3587,9 +3631,9 @@ NK_PUBLIC void nk_dots_symmetric_e2m3_sapphireamx( //
3587
3631
  }
3588
3632
  }
3589
3633
 
3590
- #pragma endregion // Micro Precision E2M3
3634
+ #pragma endregion E2M3 Floats
3591
3635
 
3592
- #pragma region Micro Precision E3M2
3636
+ #pragma region E3M2 Floats
3593
3637
 
3594
3638
  /* Load E3M2 A tile with FP8 to BF16 conversion */
3595
3639
  NK_INTERNAL void nk_dots_e3m2_load_a_sapphireamx_( //
@@ -3598,15 +3642,15 @@ NK_INTERNAL void nk_dots_e3m2_load_a_sapphireamx_( //
3598
3642
  nk_size_t valid_rows, nk_size_t valid_cols) {
3599
3643
 
3600
3644
  __mmask32 column_mask = (valid_cols >= 32) ? 0xFFFFFFFF : ((__mmask32)1 << valid_cols) - 1;
3601
- __m512i zero = _mm512_setzero_si512();
3645
+ __m512i zero_i16x32 = _mm512_setzero_si512();
3602
3646
 
3603
3647
  for (nk_size_t row_idx = 0; row_idx < 16; row_idx++) {
3604
3648
  if (row_idx < valid_rows) {
3605
- __m256i e3m2_row = _mm256_maskz_loadu_epi8(column_mask, src + row_idx * src_stride);
3606
- __m512i bf16_row = nk_e3m2x32_to_bf16x32_icelake_(e3m2_row);
3607
- _mm512_store_si512((__m512i *)a_tile->data[row_idx], bf16_row);
3649
+ __m256i e3m2_row_u8x32 = _mm256_maskz_loadu_epi8(column_mask, src + row_idx * src_stride);
3650
+ __m512i bf16_row_i16x32 = nk_e3m2x32_to_bf16x32_icelake_(e3m2_row_u8x32);
3651
+ _mm512_store_si512((__m512i *)a_tile->data[row_idx], bf16_row_i16x32);
3608
3652
  }
3609
- else { _mm512_store_si512((__m512i *)a_tile->data[row_idx], zero); }
3653
+ else { _mm512_store_si512((__m512i *)a_tile->data[row_idx], zero_i16x32); }
3610
3654
  }
3611
3655
  nk_compiler_barrier_sapphireamx_();
3612
3656
  }
@@ -3617,7 +3661,7 @@ NK_PUBLIC nk_size_t nk_dots_packed_size_e3m2_sapphireamx(nk_size_t column_count,
3617
3661
 
3618
3662
  NK_PUBLIC void nk_dots_pack_e3m2_sapphireamx( //
3619
3663
  nk_e3m2_t const *b, nk_size_t column_count, nk_size_t depth, //
3620
- nk_size_t b_stride, void *b_packed) {
3664
+ nk_size_t b_stride_in_bytes, void *b_packed) {
3621
3665
 
3622
3666
  nk_size_t const tmm_rows = 16;
3623
3667
  nk_size_t const tmm_cols = 32;
@@ -3641,8 +3685,7 @@ NK_PUBLIC void nk_dots_pack_e3m2_sapphireamx( //
3641
3685
  nk_bf16_t *tiles_ptr = (nk_bf16_t *)((char *)b_packed + tiles_offset);
3642
3686
  nk_bf16_t *column_edge_ptr = (nk_bf16_t *)((char *)b_packed + column_edge_offset);
3643
3687
 
3644
- for (nk_size_t idx = 0; idx < total_tiles * tile_elements; idx++) tiles_ptr[idx] = 0;
3645
-
3688
+ // Pack tiles using vectorized convert + SIMD transpose
3646
3689
  for (nk_size_t column_tile_idx = 0; column_tile_idx < column_tiles_count; column_tile_idx++) {
3647
3690
  for (nk_size_t depth_tile_idx = 0; depth_tile_idx < depth_tiles_count; depth_tile_idx++) {
3648
3691
  nk_size_t const tile_index = column_tile_idx * depth_tiles_count + depth_tile_idx;
@@ -3653,18 +3696,18 @@ NK_PUBLIC void nk_dots_pack_e3m2_sapphireamx( //
3653
3696
  nk_size_t const columns_to_pack = (src_column_start + tmm_cols <= depth) ? tmm_cols
3654
3697
  : (depth - src_column_start);
3655
3698
 
3699
+ __mmask32 column_mask = (columns_to_pack >= 32) ? 0xFFFFFFFF : ((__mmask32)1 << columns_to_pack) - 1;
3700
+ nk_dots_bf16_a16x32_sapphireamx_t source_tile;
3656
3701
  for (nk_size_t row_idx = 0; row_idx < tmm_rows; row_idx++) {
3657
- nk_size_t src_row = src_row_start + row_idx;
3658
- __mmask32 column_mask = (columns_to_pack >= 32) ? 0xFFFFFFFF : ((__mmask32)1 << columns_to_pack) - 1;
3659
- __m256i e3m2_row = _mm256_maskz_loadu_epi8(column_mask, b + src_row * b_stride + src_column_start);
3660
- __m512i bf16_row = nk_e3m2x32_to_bf16x32_icelake_(e3m2_row);
3661
- nk_bf16_t bf16_buf[32];
3662
- _mm512_storeu_si512((__m512i *)bf16_buf, bf16_row);
3663
- for (nk_size_t column_idx = 0; column_idx < columns_to_pack; column_idx++) {
3664
- nk_size_t const dst_idx = (column_idx / 2) * 32 + row_idx * 2 + (column_idx % 2);
3665
- tile_output[dst_idx] = bf16_buf[column_idx];
3666
- }
3702
+ __m256i e3m2_row_u8x32 = _mm256_maskz_loadu_epi8(
3703
+ column_mask, b + (src_row_start + row_idx) * b_stride_in_bytes + src_column_start);
3704
+ _mm512_store_si512(&source_tile.data[row_idx][0], nk_e3m2x32_to_bf16x32_icelake_(e3m2_row_u8x32));
3667
3705
  }
3706
+
3707
+ nk_dots_bf16_b32x16_sapphireamx_t transposed_tile;
3708
+ nk_dots_pack_bf16_transposed_sapphireamx_(&source_tile, &transposed_tile);
3709
+ for (nk_size_t i = 0; i < tile_bytes; i += 64)
3710
+ _mm512_storeu_si512((char *)tile_output + i, _mm512_load_si512((char const *)&transposed_tile + i));
3668
3711
  }
3669
3712
  }
3670
3713
 
@@ -3674,10 +3717,11 @@ NK_PUBLIC void nk_dots_pack_e3m2_sapphireamx( //
3674
3717
  for (nk_size_t column_idx = 0; column_idx < depth; column_idx += 32) {
3675
3718
  nk_size_t columns = (column_idx + 32 <= depth) ? 32 : (depth - column_idx);
3676
3719
  __mmask32 column_mask = (columns >= 32) ? 0xFFFFFFFF : ((__mmask32)1 << columns) - 1;
3677
- __m256i e3m2_chunk = _mm256_maskz_loadu_epi8(
3678
- column_mask, b + (remainder_start_row + row_idx) * b_stride + column_idx);
3679
- __m512i bf16_chunk = nk_e3m2x32_to_bf16x32_icelake_(e3m2_chunk);
3680
- _mm512_mask_storeu_epi16(column_edge_ptr + row_idx * depth + column_idx, column_mask, bf16_chunk);
3720
+ __m256i e3m2_chunk_u8x32 = _mm256_maskz_loadu_epi8(
3721
+ column_mask, b + (remainder_start_row + row_idx) * b_stride_in_bytes + column_idx);
3722
+ __m512i bf16_chunk_i16x32 = nk_e3m2x32_to_bf16x32_icelake_(e3m2_chunk_u8x32);
3723
+ _mm512_mask_storeu_epi16(column_edge_ptr + row_idx * depth + column_idx, column_mask,
3724
+ bf16_chunk_i16x32);
3681
3725
  }
3682
3726
  }
3683
3727
  }
@@ -3688,7 +3732,7 @@ NK_PUBLIC void nk_dots_pack_e3m2_sapphireamx( //
3688
3732
  header->norms_byte_offset = (nk_u32_t)norms_offset;
3689
3733
  nk_f32_t *norms = (nk_f32_t *)((char *)b_packed + norms_offset);
3690
3734
  for (nk_size_t col = 0; col < column_count; col++)
3691
- norms[col] = nk_dots_reduce_sumsq_e3m2_(b + col * b_stride, depth);
3735
+ norms[col] = nk_dots_reduce_sumsq_e3m2_(b + col * b_stride_in_bytes, depth);
3692
3736
  }
3693
3737
 
3694
3738
  NK_PUBLIC void nk_dots_packed_e3m2_sapphireamx( //
@@ -3714,7 +3758,7 @@ NK_PUBLIC void nk_dots_packed_e3m2_sapphireamx( //
3714
3758
 
3715
3759
  if (depth_tiles_count == 0) return;
3716
3760
 
3717
- nk_dots_bf16_a16x32_sapphireamx_t a_tile_upper, a_tile_lower;
3761
+ nk_dots_bf16_a16x32_sapphireamx_t a_tile_top, a_tile_bottom;
3718
3762
  nk_dots_bf16_state2x2_sapphireamx_t c_accum_buffer;
3719
3763
 
3720
3764
  nk_size_t const full_depth_tiles_count = depth / tile_depth;
@@ -3727,8 +3771,8 @@ NK_PUBLIC void nk_dots_packed_e3m2_sapphireamx( //
3727
3771
  nk_size_t const row_block_start = row_block_idx * 32;
3728
3772
  nk_size_t const valid_rows_count = (row_block_start + 32 <= rows_count) ? 32 : (rows_count - row_block_start);
3729
3773
  nk_size_t const is_full_row_block = (valid_rows_count == 32);
3730
- nk_size_t const rows_in_upper_tile = (valid_rows_count > 16) ? 16 : valid_rows_count;
3731
- nk_size_t const rows_in_lower_tile = (valid_rows_count > 16) ? valid_rows_count - 16 : 0;
3774
+ nk_size_t const rows_in_high_tile = (valid_rows_count > 16) ? 16 : valid_rows_count;
3775
+ nk_size_t const rows_in_low_tile = (valid_rows_count > 16) ? valid_rows_count - 16 : 0;
3732
3776
 
3733
3777
  for (nk_size_t column_block_idx = 0; column_block_idx < col_blocks_count; column_block_idx++) {
3734
3778
  nk_size_t const col_block_start = column_block_idx * 32;
@@ -3747,12 +3791,12 @@ NK_PUBLIC void nk_dots_packed_e3m2_sapphireamx( //
3747
3791
  nk_size_t const valid_depth = (depth_tile_idx < full_depth_tiles_count) ? tile_depth : depth_remainder;
3748
3792
 
3749
3793
  // Load A with FP8 -> BF16 conversion
3750
- nk_dots_e3m2_load_a_sapphireamx_(&a_tile_upper, a + row_block_start * a_stride_bytes + depth_offset,
3751
- a_stride_bytes, rows_in_upper_tile, valid_depth);
3752
- if (rows_in_lower_tile > 0) {
3753
- nk_dots_e3m2_load_a_sapphireamx_(&a_tile_lower,
3794
+ nk_dots_e3m2_load_a_sapphireamx_(&a_tile_top, a + row_block_start * a_stride_bytes + depth_offset,
3795
+ a_stride_bytes, rows_in_high_tile, valid_depth);
3796
+ if (rows_in_low_tile > 0) {
3797
+ nk_dots_e3m2_load_a_sapphireamx_(&a_tile_bottom,
3754
3798
  a + (row_block_start + 16) * a_stride_bytes + depth_offset,
3755
- a_stride_bytes, rows_in_lower_tile, valid_depth);
3799
+ a_stride_bytes, rows_in_low_tile, valid_depth);
3756
3800
  }
3757
3801
 
3758
3802
  nk_dots_bf16_b32x16_sapphireamx_t const *b_tile_left =
@@ -3762,8 +3806,8 @@ NK_PUBLIC void nk_dots_packed_e3m2_sapphireamx( //
3762
3806
  (nk_dots_bf16_b32x16_sapphireamx_t const *)(b_tiles_base +
3763
3807
  (b_column_right_base + depth_tile_idx) * tile_size);
3764
3808
 
3765
- _tile_loadd(0, a_tile_upper.data, 64);
3766
- _tile_loadd(1, a_tile_lower.data, 64);
3809
+ _tile_loadd(0, a_tile_top.data, 64);
3810
+ _tile_loadd(1, a_tile_bottom.data, 64);
3767
3811
  _tile_loadd(2, b_tile_left->data, 64);
3768
3812
  _tile_loadd(3, b_tile_right->data, 64);
3769
3813
 
@@ -3798,7 +3842,7 @@ NK_PUBLIC void nk_dots_packed_e3m2_sapphireamx( //
3798
3842
  nk_size_t const col_start = column_tile_idx * 16;
3799
3843
  nk_size_t const b_column_base = column_tile_idx * depth_tiles_count;
3800
3844
 
3801
- nk_dots_bf16_state_sapphireamx_t c_upper_state, c_lower_state;
3845
+ nk_dots_bf16_state_sapphireamx_t c_high_state, c_low_state;
3802
3846
  _tile_zero(4);
3803
3847
  _tile_zero(6);
3804
3848
 
@@ -3806,41 +3850,41 @@ NK_PUBLIC void nk_dots_packed_e3m2_sapphireamx( //
3806
3850
  nk_size_t const depth_offset = depth_tile_idx * tile_depth;
3807
3851
  nk_size_t const valid_depth = (depth_tile_idx < full_depth_tiles_count) ? tile_depth : depth_remainder;
3808
3852
 
3809
- nk_dots_e3m2_load_a_sapphireamx_(&a_tile_upper, a + row_block_start * a_stride_bytes + depth_offset,
3810
- a_stride_bytes, rows_in_upper_tile, valid_depth);
3811
- if (rows_in_lower_tile > 0) {
3812
- nk_dots_e3m2_load_a_sapphireamx_(&a_tile_lower,
3853
+ nk_dots_e3m2_load_a_sapphireamx_(&a_tile_top, a + row_block_start * a_stride_bytes + depth_offset,
3854
+ a_stride_bytes, rows_in_high_tile, valid_depth);
3855
+ if (rows_in_low_tile > 0) {
3856
+ nk_dots_e3m2_load_a_sapphireamx_(&a_tile_bottom,
3813
3857
  a + (row_block_start + 16) * a_stride_bytes + depth_offset,
3814
- a_stride_bytes, rows_in_lower_tile, valid_depth);
3858
+ a_stride_bytes, rows_in_low_tile, valid_depth);
3815
3859
  }
3816
3860
 
3817
3861
  nk_dots_bf16_b32x16_sapphireamx_t const *b_tile =
3818
3862
  (nk_dots_bf16_b32x16_sapphireamx_t const *)(b_tiles_base +
3819
3863
  (b_column_base + depth_tile_idx) * tile_size);
3820
3864
 
3821
- _tile_loadd(0, a_tile_upper.data, 64);
3822
- _tile_loadd(1, a_tile_lower.data, 64);
3865
+ _tile_loadd(0, a_tile_top.data, 64);
3866
+ _tile_loadd(1, a_tile_bottom.data, 64);
3823
3867
  _tile_loadd(2, b_tile->data, 64);
3824
3868
 
3825
3869
  _tile_dpbf16ps(4, 0, 2);
3826
3870
  _tile_dpbf16ps(6, 1, 2);
3827
3871
  }
3828
3872
 
3829
- _tile_stored(4, c_upper_state.data, 64);
3830
- _tile_stored(6, c_lower_state.data, 64);
3873
+ _tile_stored(4, c_high_state.data, 64);
3874
+ _tile_stored(6, c_low_state.data, 64);
3831
3875
 
3832
- nk_dots_bf16_store_sapphireamx_(&c_upper_state, c + row_block_start * c_stride_elements + col_start,
3833
- c_stride_elements, rows_in_upper_tile, 16);
3834
- if (rows_in_lower_tile > 0) {
3835
- nk_dots_bf16_store_sapphireamx_(&c_lower_state,
3876
+ nk_dots_bf16_store_sapphireamx_(&c_high_state, c + row_block_start * c_stride_elements + col_start,
3877
+ c_stride_elements, rows_in_high_tile, 16);
3878
+ if (rows_in_low_tile > 0) {
3879
+ nk_dots_bf16_store_sapphireamx_(&c_low_state,
3836
3880
  c + (row_block_start + 16) * c_stride_elements + col_start,
3837
- c_stride_elements, rows_in_lower_tile, 16);
3881
+ c_stride_elements, rows_in_low_tile, 16);
3838
3882
  }
3839
3883
  }
3840
3884
 
3841
3885
  // Handle column-edge (remaining columns < 16) using AMX with partial tiles
3842
3886
  if (column_remainder_count > 0) {
3843
- nk_dots_bf16_state_sapphireamx_t c_upper_state, c_lower_state;
3887
+ nk_dots_bf16_state_sapphireamx_t c_high_state, c_low_state;
3844
3888
  nk_dots_bf16_a16x32_sapphireamx_t b_as_a;
3845
3889
  nk_dots_bf16_b32x16_sapphireamx_t b_tile;
3846
3890
 
@@ -3851,35 +3895,35 @@ NK_PUBLIC void nk_dots_packed_e3m2_sapphireamx( //
3851
3895
  nk_size_t const depth_offset = depth_tile_idx * tile_depth;
3852
3896
  nk_size_t const valid_depth = (depth_tile_idx < full_depth_tiles_count) ? tile_depth : depth_remainder;
3853
3897
 
3854
- nk_dots_e3m2_load_a_sapphireamx_(&a_tile_upper, a + row_block_start * a_stride_bytes + depth_offset,
3855
- a_stride_bytes, rows_in_upper_tile, valid_depth);
3856
- if (rows_in_lower_tile > 0) {
3857
- nk_dots_e3m2_load_a_sapphireamx_(&a_tile_lower,
3898
+ nk_dots_e3m2_load_a_sapphireamx_(&a_tile_top, a + row_block_start * a_stride_bytes + depth_offset,
3899
+ a_stride_bytes, rows_in_high_tile, valid_depth);
3900
+ if (rows_in_low_tile > 0) {
3901
+ nk_dots_e3m2_load_a_sapphireamx_(&a_tile_bottom,
3858
3902
  a + (row_block_start + 16) * a_stride_bytes + depth_offset,
3859
- a_stride_bytes, rows_in_lower_tile, valid_depth);
3903
+ a_stride_bytes, rows_in_low_tile, valid_depth);
3860
3904
  }
3861
3905
 
3862
3906
  nk_dots_bf16_load_a_sapphireamx_(&b_as_a, col_edge_ptr + depth_offset, depth, column_remainder_count,
3863
3907
  valid_depth);
3864
3908
  nk_dots_pack_bf16_transposed_sapphireamx_(&b_as_a, &b_tile);
3865
3909
 
3866
- _tile_loadd(0, a_tile_upper.data, 64);
3867
- _tile_loadd(1, a_tile_lower.data, 64);
3910
+ _tile_loadd(0, a_tile_top.data, 64);
3911
+ _tile_loadd(1, a_tile_bottom.data, 64);
3868
3912
  _tile_loadd(2, b_tile.data, 64);
3869
3913
 
3870
3914
  _tile_dpbf16ps(4, 0, 2);
3871
3915
  _tile_dpbf16ps(6, 1, 2);
3872
3916
  }
3873
3917
 
3874
- _tile_stored(4, c_upper_state.data, 64);
3875
- _tile_stored(6, c_lower_state.data, 64);
3918
+ _tile_stored(4, c_high_state.data, 64);
3919
+ _tile_stored(6, c_low_state.data, 64);
3876
3920
 
3877
- nk_dots_bf16_store_sapphireamx_(&c_upper_state, c + row_block_start * c_stride_elements + full_cols,
3878
- c_stride_elements, rows_in_upper_tile, column_remainder_count);
3879
- if (rows_in_lower_tile > 0) {
3880
- nk_dots_bf16_store_sapphireamx_(&c_lower_state,
3921
+ nk_dots_bf16_store_sapphireamx_(&c_high_state, c + row_block_start * c_stride_elements + full_cols,
3922
+ c_stride_elements, rows_in_high_tile, column_remainder_count);
3923
+ if (rows_in_low_tile > 0) {
3924
+ nk_dots_bf16_store_sapphireamx_(&c_low_state,
3881
3925
  c + (row_block_start + 16) * c_stride_elements + full_cols,
3882
- c_stride_elements, rows_in_lower_tile, column_remainder_count);
3926
+ c_stride_elements, rows_in_low_tile, column_remainder_count);
3883
3927
  }
3884
3928
  }
3885
3929
  }
@@ -3887,18 +3931,18 @@ NK_PUBLIC void nk_dots_packed_e3m2_sapphireamx( //
3887
3931
  _tile_release();
3888
3932
  }
3889
3933
 
3890
- NK_PUBLIC void nk_dots_symmetric_e3m2_sapphireamx( //
3891
- nk_e3m2_t const *vectors, nk_size_t n_vectors, nk_size_t depth, //
3892
- nk_size_t stride, nk_f32_t *result, nk_size_t result_stride, //
3934
+ NK_PUBLIC void nk_dots_symmetric_e3m2_sapphireamx( //
3935
+ nk_e3m2_t const *vectors, nk_size_t vectors_count, nk_size_t depth, //
3936
+ nk_size_t stride_in_bytes, nk_f32_t *result, nk_size_t result_stride_in_bytes, //
3893
3937
  nk_size_t row_start, nk_size_t row_count) {
3894
3938
 
3895
- nk_size_t const stride_elements = stride; // sizeof(nk_e3m2_t) == 1, so bytes == elements
3896
- nk_size_t const result_stride_elements = result_stride / sizeof(nk_f32_t);
3939
+ nk_size_t const stride_elements = stride_in_bytes; // sizeof(nk_e3m2_t) == 1, so bytes == elements
3940
+ nk_size_t const result_stride_elements = result_stride_in_bytes / sizeof(nk_f32_t);
3897
3941
 
3898
3942
  // Handle row slicing: compute rows [row_start, row_end)
3899
3943
  nk_size_t const row_end = (row_count == 0)
3900
- ? n_vectors
3901
- : (row_start + row_count < n_vectors ? row_start + row_count : n_vectors);
3944
+ ? vectors_count
3945
+ : (row_start + row_count < vectors_count ? row_start + row_count : vectors_count);
3902
3946
 
3903
3947
  // Round depth up to multiple of 96 (3 tiles x 32 bf16 elements)
3904
3948
  nk_size_t const depth_tiles = nk_size_divide_round_up_(depth, 32);
@@ -3914,8 +3958,8 @@ NK_PUBLIC void nk_dots_symmetric_e3m2_sapphireamx( //
3914
3958
  for (nk_size_t row_tile = row_start; row_tile < row_end; row_tile += 16) {
3915
3959
  nk_size_t const valid_rows = (row_tile + 16 <= row_end) ? 16 : (row_end - row_tile);
3916
3960
 
3917
- for (nk_size_t col_tile = 0; col_tile < n_vectors; col_tile += 16) {
3918
- nk_size_t const valid_cols = (col_tile + 16 <= n_vectors) ? 16 : (n_vectors - col_tile);
3961
+ for (nk_size_t col_tile = 0; col_tile < vectors_count; col_tile += 16) {
3962
+ nk_size_t const valid_cols = (col_tile + 16 <= vectors_count) ? 16 : (vectors_count - col_tile);
3919
3963
 
3920
3964
  nk_dots_bf16_init_sapphireamx_(&state);
3921
3965
 
@@ -3956,7 +4000,7 @@ NK_PUBLIC void nk_dots_symmetric_e3m2_sapphireamx( //
3956
4000
  }
3957
4001
  }
3958
4002
 
3959
- #pragma endregion // Micro Precision E3M2
4003
+ #pragma endregion E3M2 Floats
3960
4004
 
3961
4005
  #if defined(__clang__)
3962
4006
  #pragma clang attribute pop