numkong 7.0.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (294) hide show
  1. package/LICENSE +201 -0
  2. package/README.md +495 -0
  3. package/binding.gyp +540 -0
  4. package/c/dispatch.h +512 -0
  5. package/c/dispatch_bf16.c +389 -0
  6. package/c/dispatch_bf16c.c +52 -0
  7. package/c/dispatch_e2m3.c +263 -0
  8. package/c/dispatch_e3m2.c +243 -0
  9. package/c/dispatch_e4m3.c +276 -0
  10. package/c/dispatch_e5m2.c +272 -0
  11. package/c/dispatch_f16.c +376 -0
  12. package/c/dispatch_f16c.c +58 -0
  13. package/c/dispatch_f32.c +378 -0
  14. package/c/dispatch_f32c.c +99 -0
  15. package/c/dispatch_f64.c +296 -0
  16. package/c/dispatch_f64c.c +98 -0
  17. package/c/dispatch_i16.c +96 -0
  18. package/c/dispatch_i32.c +89 -0
  19. package/c/dispatch_i4.c +150 -0
  20. package/c/dispatch_i64.c +86 -0
  21. package/c/dispatch_i8.c +289 -0
  22. package/c/dispatch_other.c +330 -0
  23. package/c/dispatch_u1.c +148 -0
  24. package/c/dispatch_u16.c +124 -0
  25. package/c/dispatch_u32.c +118 -0
  26. package/c/dispatch_u4.c +150 -0
  27. package/c/dispatch_u64.c +102 -0
  28. package/c/dispatch_u8.c +303 -0
  29. package/c/numkong.c +950 -0
  30. package/include/README.md +573 -0
  31. package/include/module.modulemap +129 -0
  32. package/include/numkong/attention/sapphireamx.h +1361 -0
  33. package/include/numkong/attention/sme.h +2066 -0
  34. package/include/numkong/attention.h +49 -0
  35. package/include/numkong/capabilities.h +748 -0
  36. package/include/numkong/cast/README.md +262 -0
  37. package/include/numkong/cast/haswell.h +975 -0
  38. package/include/numkong/cast/icelake.h +470 -0
  39. package/include/numkong/cast/neon.h +1192 -0
  40. package/include/numkong/cast/rvv.h +1021 -0
  41. package/include/numkong/cast/sapphire.h +262 -0
  42. package/include/numkong/cast/serial.h +2262 -0
  43. package/include/numkong/cast/skylake.h +856 -0
  44. package/include/numkong/cast/v128relaxed.h +180 -0
  45. package/include/numkong/cast.h +230 -0
  46. package/include/numkong/curved/README.md +223 -0
  47. package/include/numkong/curved/genoa.h +182 -0
  48. package/include/numkong/curved/haswell.h +276 -0
  49. package/include/numkong/curved/neon.h +205 -0
  50. package/include/numkong/curved/neonbfdot.h +212 -0
  51. package/include/numkong/curved/neonhalf.h +212 -0
  52. package/include/numkong/curved/rvv.h +305 -0
  53. package/include/numkong/curved/serial.h +207 -0
  54. package/include/numkong/curved/skylake.h +457 -0
  55. package/include/numkong/curved/smef64.h +506 -0
  56. package/include/numkong/curved.h +517 -0
  57. package/include/numkong/curved.hpp +144 -0
  58. package/include/numkong/dot/README.md +425 -0
  59. package/include/numkong/dot/alder.h +563 -0
  60. package/include/numkong/dot/genoa.h +315 -0
  61. package/include/numkong/dot/haswell.h +1688 -0
  62. package/include/numkong/dot/icelake.h +883 -0
  63. package/include/numkong/dot/neon.h +818 -0
  64. package/include/numkong/dot/neonbfdot.h +244 -0
  65. package/include/numkong/dot/neonfhm.h +360 -0
  66. package/include/numkong/dot/neonhalf.h +198 -0
  67. package/include/numkong/dot/neonsdot.h +508 -0
  68. package/include/numkong/dot/rvv.h +714 -0
  69. package/include/numkong/dot/rvvbb.h +72 -0
  70. package/include/numkong/dot/rvvbf16.h +123 -0
  71. package/include/numkong/dot/rvvhalf.h +129 -0
  72. package/include/numkong/dot/sapphire.h +141 -0
  73. package/include/numkong/dot/serial.h +838 -0
  74. package/include/numkong/dot/sierra.h +405 -0
  75. package/include/numkong/dot/skylake.h +1084 -0
  76. package/include/numkong/dot/sve.h +379 -0
  77. package/include/numkong/dot/svebfdot.h +74 -0
  78. package/include/numkong/dot/svehalf.h +123 -0
  79. package/include/numkong/dot/v128relaxed.h +1258 -0
  80. package/include/numkong/dot.h +1070 -0
  81. package/include/numkong/dot.hpp +94 -0
  82. package/include/numkong/dots/README.md +496 -0
  83. package/include/numkong/dots/alder.h +114 -0
  84. package/include/numkong/dots/genoa.h +94 -0
  85. package/include/numkong/dots/haswell.h +295 -0
  86. package/include/numkong/dots/icelake.h +171 -0
  87. package/include/numkong/dots/neon.h +120 -0
  88. package/include/numkong/dots/neonbfdot.h +58 -0
  89. package/include/numkong/dots/neonfhm.h +94 -0
  90. package/include/numkong/dots/neonhalf.h +57 -0
  91. package/include/numkong/dots/neonsdot.h +108 -0
  92. package/include/numkong/dots/rvv.h +2486 -0
  93. package/include/numkong/dots/sapphireamx.h +3973 -0
  94. package/include/numkong/dots/serial.h +2844 -0
  95. package/include/numkong/dots/sierra.h +97 -0
  96. package/include/numkong/dots/skylake.h +196 -0
  97. package/include/numkong/dots/sme.h +5372 -0
  98. package/include/numkong/dots/smebi32.h +461 -0
  99. package/include/numkong/dots/smef64.h +1318 -0
  100. package/include/numkong/dots/smehalf.h +47 -0
  101. package/include/numkong/dots/v128relaxed.h +294 -0
  102. package/include/numkong/dots.h +2804 -0
  103. package/include/numkong/dots.hpp +639 -0
  104. package/include/numkong/each/README.md +469 -0
  105. package/include/numkong/each/haswell.h +1658 -0
  106. package/include/numkong/each/icelake.h +272 -0
  107. package/include/numkong/each/neon.h +1104 -0
  108. package/include/numkong/each/neonbfdot.h +212 -0
  109. package/include/numkong/each/neonhalf.h +410 -0
  110. package/include/numkong/each/rvv.h +1121 -0
  111. package/include/numkong/each/sapphire.h +477 -0
  112. package/include/numkong/each/serial.h +260 -0
  113. package/include/numkong/each/skylake.h +1562 -0
  114. package/include/numkong/each.h +2146 -0
  115. package/include/numkong/each.hpp +434 -0
  116. package/include/numkong/geospatial/README.md +147 -0
  117. package/include/numkong/geospatial/haswell.h +593 -0
  118. package/include/numkong/geospatial/neon.h +571 -0
  119. package/include/numkong/geospatial/rvv.h +701 -0
  120. package/include/numkong/geospatial/serial.h +309 -0
  121. package/include/numkong/geospatial/skylake.h +577 -0
  122. package/include/numkong/geospatial/v128relaxed.h +613 -0
  123. package/include/numkong/geospatial.h +453 -0
  124. package/include/numkong/geospatial.hpp +235 -0
  125. package/include/numkong/matrix.hpp +336 -0
  126. package/include/numkong/maxsim/README.md +187 -0
  127. package/include/numkong/maxsim/alder.h +511 -0
  128. package/include/numkong/maxsim/genoa.h +115 -0
  129. package/include/numkong/maxsim/haswell.h +553 -0
  130. package/include/numkong/maxsim/icelake.h +480 -0
  131. package/include/numkong/maxsim/neonsdot.h +394 -0
  132. package/include/numkong/maxsim/sapphireamx.h +877 -0
  133. package/include/numkong/maxsim/serial.h +490 -0
  134. package/include/numkong/maxsim/sme.h +929 -0
  135. package/include/numkong/maxsim/v128relaxed.h +280 -0
  136. package/include/numkong/maxsim.h +571 -0
  137. package/include/numkong/maxsim.hpp +133 -0
  138. package/include/numkong/mesh/README.md +227 -0
  139. package/include/numkong/mesh/haswell.h +2235 -0
  140. package/include/numkong/mesh/neon.h +1329 -0
  141. package/include/numkong/mesh/neonbfdot.h +842 -0
  142. package/include/numkong/mesh/neonhalf.h +616 -0
  143. package/include/numkong/mesh/rvv.h +916 -0
  144. package/include/numkong/mesh/serial.h +742 -0
  145. package/include/numkong/mesh/skylake.h +1135 -0
  146. package/include/numkong/mesh/v128relaxed.h +1052 -0
  147. package/include/numkong/mesh.h +652 -0
  148. package/include/numkong/mesh.hpp +762 -0
  149. package/include/numkong/numkong.h +78 -0
  150. package/include/numkong/numkong.hpp +57 -0
  151. package/include/numkong/probability/README.md +173 -0
  152. package/include/numkong/probability/haswell.h +267 -0
  153. package/include/numkong/probability/neon.h +225 -0
  154. package/include/numkong/probability/rvv.h +409 -0
  155. package/include/numkong/probability/serial.h +169 -0
  156. package/include/numkong/probability/skylake.h +324 -0
  157. package/include/numkong/probability.h +383 -0
  158. package/include/numkong/probability.hpp +120 -0
  159. package/include/numkong/random.h +50 -0
  160. package/include/numkong/random.hpp +285 -0
  161. package/include/numkong/reduce/README.md +547 -0
  162. package/include/numkong/reduce/alder.h +632 -0
  163. package/include/numkong/reduce/genoa.h +201 -0
  164. package/include/numkong/reduce/haswell.h +3783 -0
  165. package/include/numkong/reduce/icelake.h +549 -0
  166. package/include/numkong/reduce/neon.h +3841 -0
  167. package/include/numkong/reduce/neonbfdot.h +353 -0
  168. package/include/numkong/reduce/neonfhm.h +665 -0
  169. package/include/numkong/reduce/neonhalf.h +157 -0
  170. package/include/numkong/reduce/neonsdot.h +357 -0
  171. package/include/numkong/reduce/rvv.h +3407 -0
  172. package/include/numkong/reduce/serial.h +757 -0
  173. package/include/numkong/reduce/sierra.h +338 -0
  174. package/include/numkong/reduce/skylake.h +3792 -0
  175. package/include/numkong/reduce/v128relaxed.h +2302 -0
  176. package/include/numkong/reduce.h +1597 -0
  177. package/include/numkong/reduce.hpp +633 -0
  178. package/include/numkong/scalar/README.md +89 -0
  179. package/include/numkong/scalar/haswell.h +113 -0
  180. package/include/numkong/scalar/neon.h +122 -0
  181. package/include/numkong/scalar/neonhalf.h +70 -0
  182. package/include/numkong/scalar/rvv.h +211 -0
  183. package/include/numkong/scalar/sapphire.h +63 -0
  184. package/include/numkong/scalar/serial.h +332 -0
  185. package/include/numkong/scalar/v128relaxed.h +56 -0
  186. package/include/numkong/scalar.h +683 -0
  187. package/include/numkong/set/README.md +179 -0
  188. package/include/numkong/set/haswell.h +334 -0
  189. package/include/numkong/set/icelake.h +485 -0
  190. package/include/numkong/set/neon.h +364 -0
  191. package/include/numkong/set/rvv.h +226 -0
  192. package/include/numkong/set/rvvbb.h +117 -0
  193. package/include/numkong/set/serial.h +174 -0
  194. package/include/numkong/set/sve.h +185 -0
  195. package/include/numkong/set/v128relaxed.h +240 -0
  196. package/include/numkong/set.h +457 -0
  197. package/include/numkong/set.hpp +114 -0
  198. package/include/numkong/sets/README.md +149 -0
  199. package/include/numkong/sets/haswell.h +63 -0
  200. package/include/numkong/sets/icelake.h +66 -0
  201. package/include/numkong/sets/neon.h +61 -0
  202. package/include/numkong/sets/serial.h +43 -0
  203. package/include/numkong/sets/smebi32.h +1099 -0
  204. package/include/numkong/sets/v128relaxed.h +58 -0
  205. package/include/numkong/sets.h +339 -0
  206. package/include/numkong/sparse/README.md +156 -0
  207. package/include/numkong/sparse/icelake.h +463 -0
  208. package/include/numkong/sparse/neon.h +288 -0
  209. package/include/numkong/sparse/serial.h +117 -0
  210. package/include/numkong/sparse/sve2.h +507 -0
  211. package/include/numkong/sparse/turin.h +322 -0
  212. package/include/numkong/sparse.h +363 -0
  213. package/include/numkong/sparse.hpp +113 -0
  214. package/include/numkong/spatial/README.md +435 -0
  215. package/include/numkong/spatial/alder.h +607 -0
  216. package/include/numkong/spatial/genoa.h +290 -0
  217. package/include/numkong/spatial/haswell.h +960 -0
  218. package/include/numkong/spatial/icelake.h +586 -0
  219. package/include/numkong/spatial/neon.h +773 -0
  220. package/include/numkong/spatial/neonbfdot.h +165 -0
  221. package/include/numkong/spatial/neonhalf.h +118 -0
  222. package/include/numkong/spatial/neonsdot.h +261 -0
  223. package/include/numkong/spatial/rvv.h +984 -0
  224. package/include/numkong/spatial/rvvbf16.h +123 -0
  225. package/include/numkong/spatial/rvvhalf.h +117 -0
  226. package/include/numkong/spatial/sapphire.h +343 -0
  227. package/include/numkong/spatial/serial.h +346 -0
  228. package/include/numkong/spatial/sierra.h +323 -0
  229. package/include/numkong/spatial/skylake.h +606 -0
  230. package/include/numkong/spatial/sve.h +224 -0
  231. package/include/numkong/spatial/svebfdot.h +122 -0
  232. package/include/numkong/spatial/svehalf.h +109 -0
  233. package/include/numkong/spatial/v128relaxed.h +717 -0
  234. package/include/numkong/spatial.h +1425 -0
  235. package/include/numkong/spatial.hpp +183 -0
  236. package/include/numkong/spatials/README.md +580 -0
  237. package/include/numkong/spatials/alder.h +94 -0
  238. package/include/numkong/spatials/genoa.h +94 -0
  239. package/include/numkong/spatials/haswell.h +219 -0
  240. package/include/numkong/spatials/icelake.h +113 -0
  241. package/include/numkong/spatials/neon.h +109 -0
  242. package/include/numkong/spatials/neonbfdot.h +60 -0
  243. package/include/numkong/spatials/neonfhm.h +92 -0
  244. package/include/numkong/spatials/neonhalf.h +58 -0
  245. package/include/numkong/spatials/neonsdot.h +109 -0
  246. package/include/numkong/spatials/rvv.h +1960 -0
  247. package/include/numkong/spatials/sapphireamx.h +1149 -0
  248. package/include/numkong/spatials/serial.h +226 -0
  249. package/include/numkong/spatials/sierra.h +96 -0
  250. package/include/numkong/spatials/skylake.h +184 -0
  251. package/include/numkong/spatials/sme.h +1901 -0
  252. package/include/numkong/spatials/smef64.h +465 -0
  253. package/include/numkong/spatials/v128relaxed.h +240 -0
  254. package/include/numkong/spatials.h +3021 -0
  255. package/include/numkong/spatials.hpp +508 -0
  256. package/include/numkong/tensor.hpp +1592 -0
  257. package/include/numkong/trigonometry/README.md +184 -0
  258. package/include/numkong/trigonometry/haswell.h +652 -0
  259. package/include/numkong/trigonometry/neon.h +639 -0
  260. package/include/numkong/trigonometry/rvv.h +699 -0
  261. package/include/numkong/trigonometry/serial.h +703 -0
  262. package/include/numkong/trigonometry/skylake.h +721 -0
  263. package/include/numkong/trigonometry/v128relaxed.h +666 -0
  264. package/include/numkong/trigonometry.h +467 -0
  265. package/include/numkong/trigonometry.hpp +166 -0
  266. package/include/numkong/types.h +1384 -0
  267. package/include/numkong/types.hpp +5603 -0
  268. package/include/numkong/vector.hpp +698 -0
  269. package/javascript/README.md +246 -0
  270. package/javascript/dist/cjs/numkong-wasm.d.ts +166 -0
  271. package/javascript/dist/cjs/numkong-wasm.js +617 -0
  272. package/javascript/dist/cjs/numkong.d.ts +343 -0
  273. package/javascript/dist/cjs/numkong.js +523 -0
  274. package/javascript/dist/cjs/package.json +3 -0
  275. package/javascript/dist/cjs/types.d.ts +284 -0
  276. package/javascript/dist/cjs/types.js +653 -0
  277. package/javascript/dist/esm/numkong-wasm.d.ts +166 -0
  278. package/javascript/dist/esm/numkong-wasm.js +595 -0
  279. package/javascript/dist/esm/numkong.d.ts +343 -0
  280. package/javascript/dist/esm/numkong.js +452 -0
  281. package/javascript/dist/esm/package.json +3 -0
  282. package/javascript/dist/esm/types.d.ts +284 -0
  283. package/javascript/dist/esm/types.js +630 -0
  284. package/javascript/dist-package-cjs.json +3 -0
  285. package/javascript/dist-package-esm.json +3 -0
  286. package/javascript/node-gyp-build.d.ts +1 -0
  287. package/javascript/numkong-wasm.ts +756 -0
  288. package/javascript/numkong.c +689 -0
  289. package/javascript/numkong.ts +575 -0
  290. package/javascript/tsconfig-base.json +39 -0
  291. package/javascript/tsconfig-cjs.json +8 -0
  292. package/javascript/tsconfig-esm.json +8 -0
  293. package/javascript/types.ts +674 -0
  294. package/package.json +87 -0
@@ -0,0 +1,3973 @@
1
+ /**
2
+ * @brief SIMD-accelerated Batched Dot Products for Sapphire Rapids.
3
+ * @file include/numkong/dots/sapphireamx.h
4
+ * @author Ash Vardanian
5
+ * @date December 27, 2025
6
+ *
7
+ * @sa include/numkong/dots.h
8
+ *
9
+ * This file contains tiled matrix-multiplication kernels optimized for Intel AMX instructions,
10
+ * leveraging the new TMM registers on Intel Sapphire Rapids CPUs. Those are much larger than ZMM:
11
+ *
12
+ * - BF16 tiles: 16 rows × 32 elements = 512 BF16 values = 1KB per tile
13
+ * - INT8 tiles: 16 rows × 64 elements = 1024 INT8 values = 1KB per tile
14
+ *
15
+ * We typically use 4 registers for the 2 × 2 tile output for the matrix C accumulators, leaving
16
+ * 4 other registers for parts of A and B matrices:
17
+ *
18
+ * - TMM0, TMM1: A matrix tiles (row blocks i and i+16)
19
+ * - TMM2, TMM3: B matrix tiles (column blocks j and j+16)
20
+ * - TMM4-7: C accumulator tiles (2 × 2 output grid)
21
+ *
22
+ * In most synthetic benchmarks there seems to be no major difference between aggregating into 1 or 4
23
+ * output tiles, implying the CPU's ability to internally pipeline the accumulation; so using 2 × 2 for
24
+ * outputs is more of a memory-bandwidth-saving measure.
25
+ *
26
+ * Lacking High Bandwidth Memory, the performance in GEMM-like BLAS workloads is dominated by memory
27
+ * bandwidth. Latency hiding is also extremely hard, heavily affecting performance numbers. For reference,
28
+ * Intel MKL SGEMM for FP32 inputs yields around 250 GigaOPS per core on Intel Sapphire Rapids, leveraging
29
+ * AVX-512. At the same time, for AMX:
30
+ *
31
+ * - BF16 peak: ≈ 3 TeraOPS per core in theory, ≈ 500 GigaOPS per core in practice
32
+ * - INT8 peak: ≈ 6 TeraOPS per core in theory, ≈ 1000 GigaOPS per core in practice
33
+ *
34
+ * Several optimizations are used across the file:
35
+ *
36
+ * - Pre-pack B matrix once for repeated inference (avoids runtime reordering)
37
+ * - Morton Z-curve tile ordering improves L2 cache hit rate by 5-25%
38
+ * - Use streaming stores for large C matrices to avoid cache pollution
39
+ *
40
+ * @section amx_instructions Intel AMX Instructions (Sapphire Rapids+)
41
+ *
42
+ * Tile configuration and data movement:
43
+ *
44
+ * Intrinsic Instruction Notes
45
+ * _tile_loadconfig LDTILECFG (mem64) Configure tile palette
46
+ * _tile_loadd TILELOADD (TMM, mem, stride) Load tile from memory
47
+ * _tile_stored TILESTORED (mem, TMM, stride) Store tile to memory
48
+ * _tile_zero TILEZERO (TMM) Zero a tile register
49
+ *
50
+ * BF16 matrix multiply (AMX-BF16):
51
+ *
52
+ * Intrinsic Instruction Operation
53
+ * _tile_dpbf16ps TDPBF16PS (TMM, TMM, TMM) C += A × B (bf16 → f32)
54
+ *
55
+ * INT8 matrix multiply (AMX-INT8):
56
+ *
57
+ * Intrinsic Instruction Operation
58
+ * _tile_dpbssd TDPBSSD (TMM, TMM, TMM) C += A × B (i8 × i8 → i32)
59
+ * _tile_dpbsud TDPBSUD (TMM, TMM, TMM) C += A × B (i8 × u8 → i32)
60
+ * _tile_dpbusd TDPBUSD (TMM, TMM, TMM) C += A × B (u8 × i8 → i32)
61
+ * _tile_dpbuud TDPBUUD (TMM, TMM, TMM) C += A × B (u8 × u8 → u32)
62
+ *
63
+ * AMX performance characteristics:
64
+ * - TDPBF16PS: 16 × 16 × 32 = 8192 BF16 MACs per instruction
65
+ * - TDPBSSD: 16 × 16 × 64 = 16384 INT8 MACs per instruction
66
+ * - Tile load latency is ~20-30 cycles; software pipelining essential
67
+ * - PDEP/PEXT used for Morton Z-curve encoding (BMI2): 2-3cy @ p1
68
+ */
69
+ #ifndef NK_DOTS_SAPPHIREAMX_H
70
+ #define NK_DOTS_SAPPHIREAMX_H
71
+
72
+ #if NK_TARGET_X86_
73
+ #if NK_TARGET_SAPPHIREAMX
74
+
75
+ #include "numkong/cast/icelake.h" // For FP8 ↔ BF16 conversions
76
+ #include "numkong/dots/serial.h" // For nk_dots_reduce_sumsq_bf16_
77
+
78
+ #if defined(__cplusplus)
79
+ extern "C" {
80
+ #endif
81
+
82
+ #if defined(__clang__)
83
+ #pragma clang attribute push( \
84
+ __attribute__((target( \
85
+ "avx2,avx512f,avx512vl,avx512bw,avx512dq,avx512fp16,avx512vbmi,f16c,fma,bmi,bmi2,amx-tile,amx-bf16,amx-int8"))), \
86
+ apply_to = function)
87
+ #elif defined(__GNUC__)
88
+ #pragma GCC push_options
89
+ #pragma GCC target("avx2", "avx512f", "avx512vl", "avx512bw", "avx512dq", "avx512fp16", "avx512vbmi", "f16c", "fma", \
90
+ "bmi", "bmi2", "amx-tile", "amx-bf16", "amx-int8")
91
+ #endif
92
+
93
/* AMX-specific packed buffer header (64-byte aligned).
 * Describes the tile-based layout AMX kernels expect: full 16-row column tiles
 * first, then an edge region for leftover rows, then per-column norms.
 * NOTE(review): the original comment said "Different from
 * nk_dots_amx_packed_header_t", i.e. different from itself — presumably it
 * meant the generic (non-AMX) packed header defined elsewhere; confirm against dots.h.
 */
typedef struct {
    nk_u32_t full_column_tiles;      // Number of full column tiles (16 rows each)
    nk_u32_t full_depth_tiles;       // Number of depth tiles (32 columns for BF16, 64 for I8)
    nk_u32_t column_remainder_count; // Remaining rows after full tiles (0-15)
    nk_u32_t column_edge_offset;     // Byte offset to edge data region
    nk_u32_t norms_byte_offset;      // Byte offset to per-column norms (for angular/euclidean)
    nk_u32_t reserved[11];           // Padding: 5 + 11 = 16 u32 fields = 64 bytes total
} nk_dots_amx_packed_header_t;
104
+
105
+ /* Composable tile structures for AMX operations.
106
+ * These enable reusable primitives and cross-correlation (A × Aᵀ) use cases.
107
+ */
108
+
109
/* BF16 A tile: 16 rows × 32 depth-elements, row-major layout.
 * Filled by nk_dots_bf16_load_a_sapphireamx_ and used as the left operand
 * of TDPBF16PS. Sized and aligned to occupy exactly one 1KB TMM register.
 */
typedef struct {
    NK_ALIGN64 nk_bf16_t data[16][32]; // 16 rows × 32 columns × 2 B = 1KB
} nk_dots_bf16_a16x32_sapphireamx_t;
115
+
116
/* BF16 B tile: 32 depth × 16 columns, pair-interleaved as TDPBF16PS expects.
 * Logical element B[depth, column] lives at data[depth / 2][column][depth % 2].
 * Pre-packed once from a column-major or transposed source, then reused.
 */
typedef struct {
    NK_ALIGN64 nk_bf16_t data[16][16][2]; // 16 depth-groups × 16 columns × 2 = 1KB
} nk_dots_bf16_b32x16_sapphireamx_t;
123
+
124
/* BF16 output state: 16 × 16 F32 accumulator tile.
 * Holds partial sums while accumulating along the depth dimension;
 * round-trips through a TMM register in nk_dots_bf16_update_sapphireamx_.
 */
typedef struct {
    NK_ALIGN64 nk_f32_t data[16][16]; // 16 × 16 × 4 B = 1KB
} nk_dots_bf16_state_sapphireamx_t;
130
+
131
/* INT8 A tile: 16 rows × 64 depth-elements, row-major layout.
 * Left operand of TDPBSSD; one 1KB TMM register.
 */
typedef struct {
    NK_ALIGN64 nk_i8_t data[16][64]; // 16 rows × 64 columns × 1 B = 1KB
} nk_dots_i8_a16x64_sapphireamx_t;
136
+
137
/* INT8 B tile: 64 depth × 16 columns, quad-interleaved as TDPBSSD expects.
 * Logical element B[depth, column] lives at data[depth / 4][column][depth % 4].
 */
typedef struct {
    NK_ALIGN64 nk_i8_t data[16][16][4]; // 16 depth-groups × 16 columns × 4 = 1KB
} nk_dots_i8_b64x16_sapphireamx_t;
143
+
144
/* INT8 output state: 16 × 16 I32 accumulator tile.
 * Holds partial dot-product sums during depth accumulation.
 */
typedef struct {
    NK_ALIGN64 nk_i32_t data[16][16]; // 16 × 16 × 4 B = 1KB
} nk_dots_i8_state_sapphireamx_t;
149
+
150
/* BF16 2 × 2 output state: 32 × 32 F32 output (4 accumulator tiles).
 * Used for the GEMM kernels' 2 × 2 output-blocking pattern:
 * c[i][j] covers output rows i*16..i*16+15 and columns j*16..j*16+15.
 */
typedef struct {
    nk_dots_bf16_state_sapphireamx_t c[2][2]; // 4 × 1KB = 4KB total
} nk_dots_bf16_state2x2_sapphireamx_t;
156
+
157
/* INT8 2 × 2 output state: 32 × 32 I32 output (4 accumulator tiles).
 * Same quadrant layout as the BF16 2 × 2 state.
 */
typedef struct {
    nk_dots_i8_state_sapphireamx_t c[2][2]; // 4 × 1KB = 4KB total
} nk_dots_i8_state2x2_sapphireamx_t;
162
+
163
/* UINT8 A tile: 16 rows × 64 depth-elements, row-major layout.
 * Identical memory layout to the I8 A tile; only the signedness
 * interpretation (TDPBUUD vs TDPBSSD) differs.
 */
typedef struct {
    NK_ALIGN64 nk_u8_t data[16][64]; // 16 rows × 64 columns × 1 B = 1KB
} nk_dots_u8_a16x64_sapphireamx_t;
169
+
170
/* UINT8 B tile: 64 depth × 16 columns, quad-interleaved as TDPBUUD expects.
 * Logical element B[depth, column] lives at data[depth / 4][column][depth % 4].
 */
typedef struct {
    NK_ALIGN64 nk_u8_t data[16][16][4]; // 16 depth-groups × 16 columns × 4 = 1KB
} nk_dots_u8_b64x16_sapphireamx_t;
175
+
176
/* UINT8 output state: 16 × 16 U32 accumulator tile.
 * Bit-compatible with nk_dots_i8_state_sapphireamx_t.
 */
typedef struct {
    NK_ALIGN64 nk_u32_t data[16][16]; // 16 × 16 × 4 B = 1KB
} nk_dots_u8_state_sapphireamx_t;
181
+
182
/* UINT8 2 × 2 output state: 32 × 32 U32 output (4 accumulator tiles).
 * Same quadrant layout as the BF16/INT8 2 × 2 states.
 */
typedef struct {
    nk_dots_u8_state_sapphireamx_t c[2][2]; // 4 × 1KB = 4KB total
} nk_dots_u8_state2x2_sapphireamx_t;
187
+
188
+ /* Morton Z-curve encoding for cache-friendly tile traversal */
189
+ NK_INTERNAL nk_u64_t nk_morton_encode_sapphireamx_(nk_u32_t tile_row, nk_u32_t tile_col) {
190
+ return _pdep_u64(tile_row, 0x5555555555555555ULL) | _pdep_u64(tile_col, 0xAAAAAAAAAAAAAAAAULL);
191
+ }
192
+
193
+ /* Configure AMX tile registers */
194
+ NK_INTERNAL void nk_amx_tile_configure_sapphireamx_(void) {
195
+ NK_ALIGN64 nk_u8_t tile_config[64] = {0};
196
+ tile_config[0] = 1; // palette 1 (standard tile configuration)
197
+
198
+ nk_u16_t *bytes_per_row = (nk_u16_t *)&tile_config[16];
199
+ nk_u8_t *rows_per_tile = &tile_config[48];
200
+
201
+ for (int tile_idx = 0; tile_idx < 8; tile_idx++) {
202
+ rows_per_tile[tile_idx] = 16; // 16 rows per tile
203
+ bytes_per_row[tile_idx] = 64; // 64 bytes per row (1KB total)
204
+ }
205
+ _tile_loadconfig(tile_config);
206
+ }
207
+
208
/** @brief Compiler-only memory barrier to ensure plain stores are not reordered
 *  past subsequent AMX tile loads. This constrains only the compiler's
 *  instruction scheduling — it emits no fence instruction and provides no
 *  cross-thread ordering. */
#if defined(_MSC_VER)
NK_INTERNAL void nk_compiler_barrier_sapphireamx_(void) { _ReadWriteBarrier(); }
#else
NK_INTERNAL void nk_compiler_barrier_sapphireamx_(void) { __asm__ volatile("" ::: "memory"); }
#endif
214
+
215
+ /* Initialize BF16 output state to zero */
216
+ NK_INTERNAL void nk_dots_bf16_init_sapphireamx_(nk_dots_bf16_state_sapphireamx_t *state) {
217
+ __m512 zero = _mm512_setzero_ps();
218
+ for (nk_size_t row_idx = 0; row_idx < 16; row_idx++) { _mm512_store_ps(state->data[row_idx], zero); }
219
+ }
220
+
221
+ /* Load A tile from row-major source with masking for edge tiles */
222
+ NK_INTERNAL void nk_dots_bf16_load_a_sapphireamx_( //
223
+ nk_dots_bf16_a16x32_sapphireamx_t *a_tile, //
224
+ nk_bf16_t const *src, nk_size_t src_stride_elements, //
225
+ nk_size_t valid_rows, nk_size_t valid_cols) {
226
+
227
+ __mmask32 column_mask = (valid_cols >= 32) ? 0xFFFFFFFF : ((__mmask32)1 << valid_cols) - 1;
228
+ __m512i zero = _mm512_setzero_si512();
229
+
230
+ for (nk_size_t row_idx = 0; row_idx < 16; row_idx++) {
231
+ if (row_idx < valid_rows) {
232
+ __m512i row = _mm512_maskz_loadu_epi16(column_mask, src + row_idx * src_stride_elements);
233
+ _mm512_store_si512((__m512i *)a_tile->data[row_idx], row);
234
+ }
235
+ else { _mm512_store_si512((__m512i *)a_tile->data[row_idx], zero); }
236
+ }
237
+ nk_compiler_barrier_sapphireamx_();
238
+ }
239
+
240
+ /* Store state to output matrix with masking for edge tiles */
241
+ NK_INTERNAL void nk_dots_bf16_store_sapphireamx_( //
242
+ nk_dots_bf16_state_sapphireamx_t const *state, //
243
+ nk_f32_t *dst, nk_size_t dst_stride_elements, //
244
+ nk_size_t valid_rows, nk_size_t valid_cols) {
245
+
246
+ __mmask16 column_mask = (valid_cols >= 16) ? 0xFFFF : ((__mmask16)1 << valid_cols) - 1;
247
+
248
+ for (nk_size_t row_idx = 0; row_idx < valid_rows; row_idx++) {
249
+ __m512 row = _mm512_load_ps(state->data[row_idx]);
250
+ _mm512_mask_storeu_ps(dst + row_idx * dst_stride_elements, column_mask, row);
251
+ }
252
+ }
253
+
254
/* Accumulate 3 A × B tile pairs into state using AMX TDPBF16PS.
 * Register usage in this routine: TMM0 holds the C accumulator, TMM1-3 the A
 * tiles, TMM4-6 the B tiles (TMM7 stays free). Each TDPBF16PS performs one
 * 16 × 16 × 32 block of BF16 MACs, widening into F32.
 * The accumulator round-trips through memory (load → 3 dot-products → store),
 * so callers can chain update calls to sweep the full depth dimension.
 * All tiles are 64 bytes per row, hence the constant stride of 64. */
NK_INTERNAL void nk_dots_bf16_update_sapphireamx_(   //
    nk_dots_bf16_state_sapphireamx_t *state,         //
    nk_dots_bf16_a16x32_sapphireamx_t const *a_tile_0, //
    nk_dots_bf16_a16x32_sapphireamx_t const *a_tile_1, //
    nk_dots_bf16_a16x32_sapphireamx_t const *a_tile_2, //
    nk_dots_bf16_b32x16_sapphireamx_t const *b_tile_0, //
    nk_dots_bf16_b32x16_sapphireamx_t const *b_tile_1, //
    nk_dots_bf16_b32x16_sapphireamx_t const *b_tile_2) {

    // Load all tiles into registers (loads must complete before the dot-products)
    _tile_loadd(0, state->data, 64);    // C accumulator
    _tile_loadd(1, a_tile_0->data, 64); // A0
    _tile_loadd(2, a_tile_1->data, 64); // A1
    _tile_loadd(3, a_tile_2->data, 64); // A2
    _tile_loadd(4, b_tile_0->data, 64); // B0
    _tile_loadd(5, b_tile_1->data, 64); // B1
    _tile_loadd(6, b_tile_2->data, 64); // B2

    // Accumulate: C += A0 × B0 + A1 × B1 + A2 × B2
    _tile_dpbf16ps(0, 1, 4); // C += A0 × B0
    _tile_dpbf16ps(0, 2, 5); // C += A1 × B1
    _tile_dpbf16ps(0, 3, 6); // C += A2 × B2

    // Store result back so the caller-visible state stays current
    _tile_stored(0, state->data, 64);
}
281
+
282
+ /* Initialize INT8 output state to zero */
283
+ NK_INTERNAL void nk_dots_i8_init_sapphireamx_(nk_dots_i8_state_sapphireamx_t *state) {
284
+ __m512i zero = _mm512_setzero_si512();
285
+ for (nk_size_t row_idx = 0; row_idx < 16; row_idx++) { _mm512_store_si512((__m512i *)state->data[row_idx], zero); }
286
+ }
287
+
288
+ /* Load A tile from row-major source with masking for edge tiles */
289
+ NK_INTERNAL void nk_dots_i8_load_a_sapphireamx_( //
290
+ nk_dots_i8_a16x64_sapphireamx_t *a_tile, //
291
+ nk_i8_t const *src, nk_size_t src_stride, //
292
+ nk_size_t valid_rows, nk_size_t valid_cols) {
293
+
294
+ __mmask64 column_mask = (valid_cols >= 64) ? 0xFFFFFFFFFFFFFFFFULL : ((__mmask64)1 << valid_cols) - 1;
295
+ __m512i zero = _mm512_setzero_si512();
296
+
297
+ for (nk_size_t row_idx = 0; row_idx < 16; row_idx++) {
298
+ if (row_idx < valid_rows) {
299
+ __m512i row = _mm512_maskz_loadu_epi8(column_mask, src + row_idx * src_stride);
300
+ _mm512_store_si512((__m512i *)a_tile->data[row_idx], row);
301
+ }
302
+ else { _mm512_store_si512((__m512i *)a_tile->data[row_idx], zero); }
303
+ }
304
+ nk_compiler_barrier_sapphireamx_();
305
+ }
306
+
307
+ /* Store state to output matrix with masking for edge tiles */
308
+ NK_INTERNAL void nk_dots_i8_store_sapphireamx_( //
309
+ nk_dots_i8_state_sapphireamx_t const *state, //
310
+ nk_i32_t *dst, nk_size_t dst_stride_elements, //
311
+ nk_size_t valid_rows, nk_size_t valid_cols) {
312
+
313
+ __mmask16 column_mask = (valid_cols >= 16) ? 0xFFFF : ((__mmask16)1 << valid_cols) - 1;
314
+
315
+ for (nk_size_t row_idx = 0; row_idx < valid_rows; row_idx++) {
316
+ __m512i row = _mm512_load_si512((__m512i const *)state->data[row_idx]);
317
+ _mm512_mask_storeu_epi32(dst + row_idx * dst_stride_elements, column_mask, row);
318
+ }
319
+ }
320
+
321
/* Accumulate 3 A × B tile pairs into state using AMX TDPBSSD (signed × signed).
 * Register usage in this routine: TMM0 holds the C accumulator, TMM1-3 the A
 * tiles, TMM4-6 the B tiles (TMM7 stays free). Each TDPBSSD performs one
 * 16 × 16 × 64 block of INT8 MACs, widening into I32.
 * The accumulator round-trips through memory (load → 3 dot-products → store),
 * so callers can chain update calls to sweep the full depth dimension.
 * All tiles are 64 bytes per row, hence the constant stride of 64. */
NK_INTERNAL void nk_dots_i8_update_sapphireamx_(  //
    nk_dots_i8_state_sapphireamx_t *state,        //
    nk_dots_i8_a16x64_sapphireamx_t const *a_tile_0, //
    nk_dots_i8_a16x64_sapphireamx_t const *a_tile_1, //
    nk_dots_i8_a16x64_sapphireamx_t const *a_tile_2, //
    nk_dots_i8_b64x16_sapphireamx_t const *b_tile_0, //
    nk_dots_i8_b64x16_sapphireamx_t const *b_tile_1, //
    nk_dots_i8_b64x16_sapphireamx_t const *b_tile_2) {

    // Load all tiles into registers (loads must complete before the dot-products)
    _tile_loadd(0, state->data, 64);    // C accumulator
    _tile_loadd(1, a_tile_0->data, 64); // A0
    _tile_loadd(2, a_tile_1->data, 64); // A1
    _tile_loadd(3, a_tile_2->data, 64); // A2
    _tile_loadd(4, b_tile_0->data, 64); // B0
    _tile_loadd(5, b_tile_1->data, 64); // B1
    _tile_loadd(6, b_tile_2->data, 64); // B2

    // Accumulate: C += A0 × B0 + A1 × B1 + A2 × B2
    _tile_dpbssd(0, 1, 4); // C += A0 × B0
    _tile_dpbssd(0, 2, 5); // C += A1 × B1
    _tile_dpbssd(0, 3, 6); // C += A2 × B2

    // Store result back so the caller-visible state stays current
    _tile_stored(0, state->data, 64);
}
348
+
349
+ /* Store BF16 2x2 state to output matrix with masking for edge tiles */
350
+ NK_INTERNAL void nk_dots_bf16_output2x2_sapphireamx_( //
351
+ nk_dots_bf16_state2x2_sapphireamx_t const *state, //
352
+ nk_f32_t *dst, nk_size_t dst_stride_elements, //
353
+ nk_size_t valid_rows, nk_size_t valid_cols) {
354
+
355
+ // Rows 0-15
356
+ nk_size_t const rows_upper = (valid_rows > 16) ? 16 : valid_rows;
357
+ nk_size_t const cols_left = (valid_cols > 16) ? 16 : valid_cols;
358
+ nk_size_t const cols_right = (valid_cols > 16) ? valid_cols - 16 : 0;
359
+
360
+ if (rows_upper > 0 && cols_left > 0)
361
+ nk_dots_bf16_store_sapphireamx_(&state->c[0][0], dst, dst_stride_elements, rows_upper, cols_left);
362
+ if (rows_upper > 0 && cols_right > 0)
363
+ nk_dots_bf16_store_sapphireamx_(&state->c[0][1], dst + 16, dst_stride_elements, rows_upper, cols_right);
364
+
365
+ // Rows 16-31
366
+ if (valid_rows > 16) {
367
+ nk_size_t const rows_lower = valid_rows - 16;
368
+ nk_f32_t *dst_lower = dst + 16 * dst_stride_elements;
369
+ if (cols_left > 0)
370
+ nk_dots_bf16_store_sapphireamx_(&state->c[1][0], dst_lower, dst_stride_elements, rows_lower, cols_left);
371
+ if (cols_right > 0)
372
+ nk_dots_bf16_store_sapphireamx_(&state->c[1][1], dst_lower + 16, dst_stride_elements, rows_lower,
373
+ cols_right);
374
+ }
375
+ }
376
+
377
+ /* Store INT8 2x2 state to output matrix with masking for edge tiles */
378
+ NK_INTERNAL void nk_dots_i8_output2x2_sapphireamx_( //
379
+ nk_dots_i8_state2x2_sapphireamx_t const *state, //
380
+ nk_i32_t *dst, nk_size_t dst_stride_elements, //
381
+ nk_size_t valid_rows, nk_size_t valid_cols) {
382
+
383
+ nk_size_t const rows_upper = (valid_rows > 16) ? 16 : valid_rows;
384
+ nk_size_t const cols_left = (valid_cols > 16) ? 16 : valid_cols;
385
+ nk_size_t const cols_right = (valid_cols > 16) ? valid_cols - 16 : 0;
386
+
387
+ if (rows_upper > 0 && cols_left > 0)
388
+ nk_dots_i8_store_sapphireamx_(&state->c[0][0], dst, dst_stride_elements, rows_upper, cols_left);
389
+ if (rows_upper > 0 && cols_right > 0)
390
+ nk_dots_i8_store_sapphireamx_(&state->c[0][1], dst + 16, dst_stride_elements, rows_upper, cols_right);
391
+
392
+ if (valid_rows > 16) {
393
+ nk_size_t const rows_lower = valid_rows - 16;
394
+ nk_i32_t *dst_lower = dst + 16 * dst_stride_elements;
395
+ if (cols_left > 0)
396
+ nk_dots_i8_store_sapphireamx_(&state->c[1][0], dst_lower, dst_stride_elements, rows_lower, cols_left);
397
+ if (cols_right > 0)
398
+ nk_dots_i8_store_sapphireamx_(&state->c[1][1], dst_lower + 16, dst_stride_elements, rows_lower, cols_right);
399
+ }
400
+ }
401
+
402
+ /* Initialize UINT8 output state to zero */
403
+ NK_INTERNAL void nk_dots_u8_init_sapphireamx_(nk_dots_u8_state_sapphireamx_t *state) {
404
+ nk_dots_i8_init_sapphireamx_((nk_dots_i8_state_sapphireamx_t *)state);
405
+ }
406
+
407
+ /* Load U8 A tile from row-major source with masking for edge tiles */
408
+ NK_INTERNAL void nk_dots_u8_load_a_sapphireamx_( //
409
+ nk_dots_u8_a16x64_sapphireamx_t *a_tile, //
410
+ nk_u8_t const *src, nk_size_t src_stride, //
411
+ nk_size_t valid_rows, nk_size_t valid_cols) {
412
+ nk_dots_i8_load_a_sapphireamx_( //
413
+ (nk_dots_i8_a16x64_sapphireamx_t *)a_tile, //
414
+ (nk_i8_t const *)src, src_stride, valid_rows, valid_cols);
415
+ }
416
+
417
+ /* Store U8 state to output matrix with masking for edge tiles */
418
+ NK_INTERNAL void nk_dots_u8_store_sapphireamx_( //
419
+ nk_dots_u8_state_sapphireamx_t const *state, //
420
+ nk_u32_t *dst, nk_size_t dst_stride_elements, //
421
+ nk_size_t valid_rows, nk_size_t valid_cols) {
422
+ nk_dots_i8_store_sapphireamx_( //
423
+ (nk_dots_i8_state_sapphireamx_t const *)state, //
424
+ (nk_i32_t *)dst, dst_stride_elements, valid_rows, valid_cols);
425
+ }
426
+
427
+ /* Store UINT8 2x2 state to output matrix with masking for edge tiles */
428
+ NK_INTERNAL void nk_dots_u8_output2x2_sapphireamx_( //
429
+ nk_dots_u8_state2x2_sapphireamx_t const *state, //
430
+ nk_u32_t *dst, nk_size_t dst_stride_elements, //
431
+ nk_size_t valid_rows, nk_size_t valid_cols) {
432
+ nk_dots_i8_output2x2_sapphireamx_( //
433
+ (nk_dots_i8_state2x2_sapphireamx_t const *)state, //
434
+ (nk_i32_t *)dst, dst_stride_elements, valid_rows, valid_cols);
435
+ }
436
+
437
/* Pack U8 A transposed into B format.
 *
 * Reinterprets each of the 16 A rows (64 UINT8 = one ZMM) as 16 little-endian
 * 32-bit "quads" (4 consecutive UINT8 that TDPBUUD consumes together), then
 * performs a 16x16 transpose of those 32-bit elements so that each output
 * register holds one depth-group across all 16 columns — the layout AMX
 * expects for the B operand: b_tile->data[depth_group][column][quad]. */
NK_INTERNAL void nk_dots_pack_u8_transposed_sapphireamx_( //
    nk_dots_u8_a16x64_sapphireamx_t const *a_tile,        //
    nk_dots_u8_b64x16_sapphireamx_t *b_tile) {

    // Load all 16 rows - each row is 64 UINT8 = 64 bytes = 1 ZMM
    // Treat as 16 × 32-bit elements per row (each 32-bit = quad of UINT8)
    __m512i row00 = _mm512_load_si512(&a_tile->data[0][0]);
    __m512i row01 = _mm512_load_si512(&a_tile->data[1][0]);
    __m512i row02 = _mm512_load_si512(&a_tile->data[2][0]);
    __m512i row03 = _mm512_load_si512(&a_tile->data[3][0]);
    __m512i row04 = _mm512_load_si512(&a_tile->data[4][0]);
    __m512i row05 = _mm512_load_si512(&a_tile->data[5][0]);
    __m512i row06 = _mm512_load_si512(&a_tile->data[6][0]);
    __m512i row07 = _mm512_load_si512(&a_tile->data[7][0]);
    __m512i row08 = _mm512_load_si512(&a_tile->data[8][0]);
    __m512i row09 = _mm512_load_si512(&a_tile->data[9][0]);
    __m512i row10 = _mm512_load_si512(&a_tile->data[10][0]);
    __m512i row11 = _mm512_load_si512(&a_tile->data[11][0]);
    __m512i row12 = _mm512_load_si512(&a_tile->data[12][0]);
    __m512i row13 = _mm512_load_si512(&a_tile->data[13][0]);
    __m512i row14 = _mm512_load_si512(&a_tile->data[14][0]);
    __m512i row15 = _mm512_load_si512(&a_tile->data[15][0]);

    // 16×16 transpose of 32-bit elements using hierarchical unpacks
    // Stage 1: Unpack adjacent row pairs at 32-bit granularity
    __m512i t01_lo = _mm512_unpacklo_epi32(row00, row01);
    __m512i t01_hi = _mm512_unpackhi_epi32(row00, row01);
    __m512i t23_lo = _mm512_unpacklo_epi32(row02, row03);
    __m512i t23_hi = _mm512_unpackhi_epi32(row02, row03);
    __m512i t45_lo = _mm512_unpacklo_epi32(row04, row05);
    __m512i t45_hi = _mm512_unpackhi_epi32(row04, row05);
    __m512i t67_lo = _mm512_unpacklo_epi32(row06, row07);
    __m512i t67_hi = _mm512_unpackhi_epi32(row06, row07);
    __m512i t89_lo = _mm512_unpacklo_epi32(row08, row09);
    __m512i t89_hi = _mm512_unpackhi_epi32(row08, row09);
    __m512i tab_lo = _mm512_unpacklo_epi32(row10, row11);
    __m512i tab_hi = _mm512_unpackhi_epi32(row10, row11);
    __m512i tcd_lo = _mm512_unpacklo_epi32(row12, row13);
    __m512i tcd_hi = _mm512_unpackhi_epi32(row12, row13);
    __m512i tef_lo = _mm512_unpacklo_epi32(row14, row15);
    __m512i tef_hi = _mm512_unpackhi_epi32(row14, row15);

    // Stage 2: Unpack at 64-bit granularity (merges groups of 4 rows)
    __m512i u0123_ll = _mm512_unpacklo_epi64(t01_lo, t23_lo);
    __m512i u0123_lh = _mm512_unpackhi_epi64(t01_lo, t23_lo);
    __m512i u0123_hl = _mm512_unpacklo_epi64(t01_hi, t23_hi);
    __m512i u0123_hh = _mm512_unpackhi_epi64(t01_hi, t23_hi);
    __m512i u4567_ll = _mm512_unpacklo_epi64(t45_lo, t67_lo);
    __m512i u4567_lh = _mm512_unpackhi_epi64(t45_lo, t67_lo);
    __m512i u4567_hl = _mm512_unpacklo_epi64(t45_hi, t67_hi);
    __m512i u4567_hh = _mm512_unpackhi_epi64(t45_hi, t67_hi);
    __m512i u89ab_ll = _mm512_unpacklo_epi64(t89_lo, tab_lo);
    __m512i u89ab_lh = _mm512_unpackhi_epi64(t89_lo, tab_lo);
    __m512i u89ab_hl = _mm512_unpacklo_epi64(t89_hi, tab_hi);
    __m512i u89ab_hh = _mm512_unpackhi_epi64(t89_hi, tab_hi);
    __m512i ucdef_ll = _mm512_unpacklo_epi64(tcd_lo, tef_lo);
    __m512i ucdef_lh = _mm512_unpackhi_epi64(tcd_lo, tef_lo);
    __m512i ucdef_hl = _mm512_unpacklo_epi64(tcd_hi, tef_hi);
    __m512i ucdef_hh = _mm512_unpackhi_epi64(tcd_hi, tef_hi);

    // Stage 3: Shuffle 128-bit lanes (0x88 = lanes 0,2 of each; 0xDD = lanes 1,3)
    __m512i v0_a = _mm512_shuffle_i32x4(u0123_ll, u4567_ll, 0x88);
    __m512i v0_b = _mm512_shuffle_i32x4(u0123_ll, u4567_ll, 0xDD);
    __m512i v1_a = _mm512_shuffle_i32x4(u0123_lh, u4567_lh, 0x88);
    __m512i v1_b = _mm512_shuffle_i32x4(u0123_lh, u4567_lh, 0xDD);
    __m512i v2_a = _mm512_shuffle_i32x4(u0123_hl, u4567_hl, 0x88);
    __m512i v2_b = _mm512_shuffle_i32x4(u0123_hl, u4567_hl, 0xDD);
    __m512i v3_a = _mm512_shuffle_i32x4(u0123_hh, u4567_hh, 0x88);
    __m512i v3_b = _mm512_shuffle_i32x4(u0123_hh, u4567_hh, 0xDD);
    __m512i v4_a = _mm512_shuffle_i32x4(u89ab_ll, ucdef_ll, 0x88);
    __m512i v4_b = _mm512_shuffle_i32x4(u89ab_ll, ucdef_ll, 0xDD);
    __m512i v5_a = _mm512_shuffle_i32x4(u89ab_lh, ucdef_lh, 0x88);
    __m512i v5_b = _mm512_shuffle_i32x4(u89ab_lh, ucdef_lh, 0xDD);
    __m512i v6_a = _mm512_shuffle_i32x4(u89ab_hl, ucdef_hl, 0x88);
    __m512i v6_b = _mm512_shuffle_i32x4(u89ab_hl, ucdef_hl, 0xDD);
    __m512i v7_a = _mm512_shuffle_i32x4(u89ab_hh, ucdef_hh, 0x88);
    __m512i v7_b = _mm512_shuffle_i32x4(u89ab_hh, ucdef_hh, 0xDD);

    // Stage 4: Final 256-bit shuffle to complete transpose
    __m512i out00 = _mm512_shuffle_i32x4(v0_a, v4_a, 0x88);
    __m512i out01 = _mm512_shuffle_i32x4(v1_a, v5_a, 0x88);
    __m512i out02 = _mm512_shuffle_i32x4(v2_a, v6_a, 0x88);
    __m512i out03 = _mm512_shuffle_i32x4(v3_a, v7_a, 0x88);
    __m512i out04 = _mm512_shuffle_i32x4(v0_a, v4_a, 0xDD);
    __m512i out05 = _mm512_shuffle_i32x4(v1_a, v5_a, 0xDD);
    __m512i out06 = _mm512_shuffle_i32x4(v2_a, v6_a, 0xDD);
    __m512i out07 = _mm512_shuffle_i32x4(v3_a, v7_a, 0xDD);
    __m512i out08 = _mm512_shuffle_i32x4(v0_b, v4_b, 0x88);
    __m512i out09 = _mm512_shuffle_i32x4(v1_b, v5_b, 0x88);
    __m512i out10 = _mm512_shuffle_i32x4(v2_b, v6_b, 0x88);
    __m512i out11 = _mm512_shuffle_i32x4(v3_b, v7_b, 0x88);
    __m512i out12 = _mm512_shuffle_i32x4(v0_b, v4_b, 0xDD);
    __m512i out13 = _mm512_shuffle_i32x4(v1_b, v5_b, 0xDD);
    __m512i out14 = _mm512_shuffle_i32x4(v2_b, v6_b, 0xDD);
    __m512i out15 = _mm512_shuffle_i32x4(v3_b, v7_b, 0xDD);

    // Store transposed results - each output row is one depth_group
    // Output layout: B.data[depth_group][column][quad] = 16 columns × 4 UINT8 = 64 bytes
    // NOTE: rows 4-7 receive out08..out11 and rows 8-11 receive out04..out07;
    // this deliberate swap compensates for the 128-bit lane ordering produced
    // by the stage-3/4 shuffles so the final rows land in depth order.
    _mm512_store_si512(&b_tile->data[0][0][0], out00);
    _mm512_store_si512(&b_tile->data[1][0][0], out01);
    _mm512_store_si512(&b_tile->data[2][0][0], out02);
    _mm512_store_si512(&b_tile->data[3][0][0], out03);
    _mm512_store_si512(&b_tile->data[4][0][0], out08);
    _mm512_store_si512(&b_tile->data[5][0][0], out09);
    _mm512_store_si512(&b_tile->data[6][0][0], out10);
    _mm512_store_si512(&b_tile->data[7][0][0], out11);
    _mm512_store_si512(&b_tile->data[8][0][0], out04);
    _mm512_store_si512(&b_tile->data[9][0][0], out05);
    _mm512_store_si512(&b_tile->data[10][0][0], out06);
    _mm512_store_si512(&b_tile->data[11][0][0], out07);
    _mm512_store_si512(&b_tile->data[12][0][0], out12);
    _mm512_store_si512(&b_tile->data[13][0][0], out13);
    _mm512_store_si512(&b_tile->data[14][0][0], out14);
    _mm512_store_si512(&b_tile->data[15][0][0], out15);

    // Keep the compiler from reordering the vector stores past subsequent
    // AMX tile loads of the same buffer.
    nk_compiler_barrier_sapphireamx_();
}
555
+
556
+ /* Accumulate 3 A x B tile pairs into state using AMX TDPBUUD */
557
+ NK_INTERNAL void nk_dots_u8_update_sapphireamx_( //
558
+ nk_dots_u8_state_sapphireamx_t *state, //
559
+ nk_dots_u8_a16x64_sapphireamx_t const *a_tile_0, //
560
+ nk_dots_u8_a16x64_sapphireamx_t const *a_tile_1, //
561
+ nk_dots_u8_a16x64_sapphireamx_t const *a_tile_2, //
562
+ nk_dots_u8_b64x16_sapphireamx_t const *b_tile_0, //
563
+ nk_dots_u8_b64x16_sapphireamx_t const *b_tile_1, //
564
+ nk_dots_u8_b64x16_sapphireamx_t const *b_tile_2) {
565
+
566
+ // Load all tiles into registers
567
+ _tile_loadd(0, state->data, 64); // C accumulator
568
+ _tile_loadd(1, a_tile_0->data, 64); // A0
569
+ _tile_loadd(2, a_tile_1->data, 64); // A1
570
+ _tile_loadd(3, a_tile_2->data, 64); // A2
571
+ _tile_loadd(4, b_tile_0->data, 64); // B0
572
+ _tile_loadd(5, b_tile_1->data, 64); // B1
573
+ _tile_loadd(6, b_tile_2->data, 64); // B2
574
+
575
+ // Accumulate: C += A0 × B0 + A1 × B1 + A2 × B2
576
+ _tile_dpbuud(0, 1, 4); // C += A0 × B0
577
+ _tile_dpbuud(0, 2, 5); // C += A1 × B1
578
+ _tile_dpbuud(0, 3, 6); // C += A2 × B2
579
+
580
+ // Store result
581
+ _tile_stored(0, state->data, 64);
582
+ }
583
+
584
+ /* Load E4M3 A tile with FP8 to BF16 conversion */
585
+ NK_INTERNAL void nk_dots_e4m3_load_a_sapphireamx_( //
586
+ nk_dots_bf16_a16x32_sapphireamx_t *a_tile, //
587
+ nk_e4m3_t const *src, nk_size_t src_stride, //
588
+ nk_size_t valid_rows, nk_size_t valid_cols) {
589
+
590
+ __mmask32 column_mask = (valid_cols >= 32) ? 0xFFFFFFFF : ((__mmask32)1 << valid_cols) - 1;
591
+ __m512i zero = _mm512_setzero_si512();
592
+
593
+ for (nk_size_t row_idx = 0; row_idx < 16; row_idx++) {
594
+ if (row_idx < valid_rows) {
595
+ // Load 32 E4M3 bytes with masking
596
+ __m256i e4m3_row = _mm256_maskz_loadu_epi8(column_mask, src + row_idx * src_stride);
597
+ // Convert to 32 BF16 values
598
+ __m512i bf16_row = nk_e4m3x32_to_bf16x32_icelake_(e4m3_row);
599
+ _mm512_store_si512((__m512i *)a_tile->data[row_idx], bf16_row);
600
+ }
601
+ else { _mm512_store_si512((__m512i *)a_tile->data[row_idx], zero); }
602
+ }
603
+ nk_compiler_barrier_sapphireamx_();
604
+ }
605
+
606
+ /* Load E5M2 A tile with FP8 to BF16 conversion */
607
+ NK_INTERNAL void nk_dots_e5m2_load_a_sapphireamx_( //
608
+ nk_dots_bf16_a16x32_sapphireamx_t *a_tile, //
609
+ nk_e5m2_t const *src, nk_size_t src_stride, //
610
+ nk_size_t valid_rows, nk_size_t valid_cols) {
611
+
612
+ __mmask32 column_mask = (valid_cols >= 32) ? 0xFFFFFFFF : ((__mmask32)1 << valid_cols) - 1;
613
+ __m512i zero = _mm512_setzero_si512();
614
+
615
+ for (nk_size_t row_idx = 0; row_idx < 16; row_idx++) {
616
+ if (row_idx < valid_rows) {
617
+ __m256i e5m2_row = _mm256_maskz_loadu_epi8(column_mask, src + row_idx * src_stride);
618
+ __m512i bf16_row = nk_e5m2x32_to_bf16x32_icelake_(e5m2_row);
619
+ _mm512_store_si512((__m512i *)a_tile->data[row_idx], bf16_row);
620
+ }
621
+ else { _mm512_store_si512((__m512i *)a_tile->data[row_idx], zero); }
622
+ }
623
+ nk_compiler_barrier_sapphireamx_();
624
+ }
625
+
626
/* Pack A transposed into B format for BF16.
 *
 * Reinterprets each of the 16 A rows (32 BF16 = one ZMM) as 16 little-endian
 * 32-bit elements (each holding a consecutive BF16 pair, the unit TDPBF16PS
 * consumes), then performs a 16x16 transpose of those 32-bit elements so each
 * output register holds one depth-group across all 16 columns — the layout
 * AMX expects for the B operand: b_tile->data[depth_group][column][pair]. */
NK_INTERNAL void nk_dots_pack_bf16_transposed_sapphireamx_( //
    nk_dots_bf16_a16x32_sapphireamx_t const *a_tile,        //
    nk_dots_bf16_b32x16_sapphireamx_t *b_tile) {

    // Load all 16 rows - each row is 32 BF16 = 64 bytes = 1 ZMM
    // Treat as 16 × 32-bit elements per row (each 32-bit = pair of BF16)
    __m512i row00 = _mm512_load_si512(&a_tile->data[0][0]);
    __m512i row01 = _mm512_load_si512(&a_tile->data[1][0]);
    __m512i row02 = _mm512_load_si512(&a_tile->data[2][0]);
    __m512i row03 = _mm512_load_si512(&a_tile->data[3][0]);
    __m512i row04 = _mm512_load_si512(&a_tile->data[4][0]);
    __m512i row05 = _mm512_load_si512(&a_tile->data[5][0]);
    __m512i row06 = _mm512_load_si512(&a_tile->data[6][0]);
    __m512i row07 = _mm512_load_si512(&a_tile->data[7][0]);
    __m512i row08 = _mm512_load_si512(&a_tile->data[8][0]);
    __m512i row09 = _mm512_load_si512(&a_tile->data[9][0]);
    __m512i row10 = _mm512_load_si512(&a_tile->data[10][0]);
    __m512i row11 = _mm512_load_si512(&a_tile->data[11][0]);
    __m512i row12 = _mm512_load_si512(&a_tile->data[12][0]);
    __m512i row13 = _mm512_load_si512(&a_tile->data[13][0]);
    __m512i row14 = _mm512_load_si512(&a_tile->data[14][0]);
    __m512i row15 = _mm512_load_si512(&a_tile->data[15][0]);

    // 16×16 transpose of 32-bit elements using hierarchical unpacks
    // Stage 1: Unpack adjacent row pairs at 32-bit granularity
    __m512i t01_lo = _mm512_unpacklo_epi32(row00, row01);
    __m512i t01_hi = _mm512_unpackhi_epi32(row00, row01);
    __m512i t23_lo = _mm512_unpacklo_epi32(row02, row03);
    __m512i t23_hi = _mm512_unpackhi_epi32(row02, row03);
    __m512i t45_lo = _mm512_unpacklo_epi32(row04, row05);
    __m512i t45_hi = _mm512_unpackhi_epi32(row04, row05);
    __m512i t67_lo = _mm512_unpacklo_epi32(row06, row07);
    __m512i t67_hi = _mm512_unpackhi_epi32(row06, row07);
    __m512i t89_lo = _mm512_unpacklo_epi32(row08, row09);
    __m512i t89_hi = _mm512_unpackhi_epi32(row08, row09);
    __m512i tab_lo = _mm512_unpacklo_epi32(row10, row11);
    __m512i tab_hi = _mm512_unpackhi_epi32(row10, row11);
    __m512i tcd_lo = _mm512_unpacklo_epi32(row12, row13);
    __m512i tcd_hi = _mm512_unpackhi_epi32(row12, row13);
    __m512i tef_lo = _mm512_unpacklo_epi32(row14, row15);
    __m512i tef_hi = _mm512_unpackhi_epi32(row14, row15);

    // Stage 2: Unpack at 64-bit granularity (merges groups of 4 rows)
    __m512i u0123_ll = _mm512_unpacklo_epi64(t01_lo, t23_lo);
    __m512i u0123_lh = _mm512_unpackhi_epi64(t01_lo, t23_lo);
    __m512i u0123_hl = _mm512_unpacklo_epi64(t01_hi, t23_hi);
    __m512i u0123_hh = _mm512_unpackhi_epi64(t01_hi, t23_hi);
    __m512i u4567_ll = _mm512_unpacklo_epi64(t45_lo, t67_lo);
    __m512i u4567_lh = _mm512_unpackhi_epi64(t45_lo, t67_lo);
    __m512i u4567_hl = _mm512_unpacklo_epi64(t45_hi, t67_hi);
    __m512i u4567_hh = _mm512_unpackhi_epi64(t45_hi, t67_hi);
    __m512i u89ab_ll = _mm512_unpacklo_epi64(t89_lo, tab_lo);
    __m512i u89ab_lh = _mm512_unpackhi_epi64(t89_lo, tab_lo);
    __m512i u89ab_hl = _mm512_unpacklo_epi64(t89_hi, tab_hi);
    __m512i u89ab_hh = _mm512_unpackhi_epi64(t89_hi, tab_hi);
    __m512i ucdef_ll = _mm512_unpacklo_epi64(tcd_lo, tef_lo);
    __m512i ucdef_lh = _mm512_unpackhi_epi64(tcd_lo, tef_lo);
    __m512i ucdef_hl = _mm512_unpacklo_epi64(tcd_hi, tef_hi);
    __m512i ucdef_hh = _mm512_unpackhi_epi64(tcd_hi, tef_hi);

    // Stage 3: Shuffle 128-bit lanes using permute2x128 equivalent for 512-bit
    // Use shuffle_i32x4 to move 128-bit chunks
    __m512i v0_a = _mm512_shuffle_i32x4(u0123_ll, u4567_ll, 0x88); // lanes 0,2 from each
    __m512i v0_b = _mm512_shuffle_i32x4(u0123_ll, u4567_ll, 0xDD); // lanes 1,3 from each
    __m512i v1_a = _mm512_shuffle_i32x4(u0123_lh, u4567_lh, 0x88);
    __m512i v1_b = _mm512_shuffle_i32x4(u0123_lh, u4567_lh, 0xDD);
    __m512i v2_a = _mm512_shuffle_i32x4(u0123_hl, u4567_hl, 0x88);
    __m512i v2_b = _mm512_shuffle_i32x4(u0123_hl, u4567_hl, 0xDD);
    __m512i v3_a = _mm512_shuffle_i32x4(u0123_hh, u4567_hh, 0x88);
    __m512i v3_b = _mm512_shuffle_i32x4(u0123_hh, u4567_hh, 0xDD);
    __m512i v4_a = _mm512_shuffle_i32x4(u89ab_ll, ucdef_ll, 0x88);
    __m512i v4_b = _mm512_shuffle_i32x4(u89ab_ll, ucdef_ll, 0xDD);
    __m512i v5_a = _mm512_shuffle_i32x4(u89ab_lh, ucdef_lh, 0x88);
    __m512i v5_b = _mm512_shuffle_i32x4(u89ab_lh, ucdef_lh, 0xDD);
    __m512i v6_a = _mm512_shuffle_i32x4(u89ab_hl, ucdef_hl, 0x88);
    __m512i v6_b = _mm512_shuffle_i32x4(u89ab_hl, ucdef_hl, 0xDD);
    __m512i v7_a = _mm512_shuffle_i32x4(u89ab_hh, ucdef_hh, 0x88);
    __m512i v7_b = _mm512_shuffle_i32x4(u89ab_hh, ucdef_hh, 0xDD);

    // Stage 4: Final 256-bit shuffle to complete transpose
    __m512i out00 = _mm512_shuffle_i32x4(v0_a, v4_a, 0x88);
    __m512i out01 = _mm512_shuffle_i32x4(v1_a, v5_a, 0x88);
    __m512i out02 = _mm512_shuffle_i32x4(v2_a, v6_a, 0x88);
    __m512i out03 = _mm512_shuffle_i32x4(v3_a, v7_a, 0x88);
    __m512i out04 = _mm512_shuffle_i32x4(v0_a, v4_a, 0xDD);
    __m512i out05 = _mm512_shuffle_i32x4(v1_a, v5_a, 0xDD);
    __m512i out06 = _mm512_shuffle_i32x4(v2_a, v6_a, 0xDD);
    __m512i out07 = _mm512_shuffle_i32x4(v3_a, v7_a, 0xDD);
    __m512i out08 = _mm512_shuffle_i32x4(v0_b, v4_b, 0x88);
    __m512i out09 = _mm512_shuffle_i32x4(v1_b, v5_b, 0x88);
    __m512i out10 = _mm512_shuffle_i32x4(v2_b, v6_b, 0x88);
    __m512i out11 = _mm512_shuffle_i32x4(v3_b, v7_b, 0x88);
    __m512i out12 = _mm512_shuffle_i32x4(v0_b, v4_b, 0xDD);
    __m512i out13 = _mm512_shuffle_i32x4(v1_b, v5_b, 0xDD);
    __m512i out14 = _mm512_shuffle_i32x4(v2_b, v6_b, 0xDD);
    __m512i out15 = _mm512_shuffle_i32x4(v3_b, v7_b, 0xDD);

    // Store transposed results - each output row is one depth_group
    // Output layout: B.data[depth_group][column][pair] = 16 columns × 2 BF16 = 64 bytes
    // NOTE: rows 4-7 receive out08..out11 and rows 8-11 receive out04..out07;
    // this deliberate swap compensates for the 128-bit lane ordering produced
    // by the stage-3/4 shuffles so the final rows land in depth order.
    _mm512_store_si512(&b_tile->data[0][0][0], out00);
    _mm512_store_si512(&b_tile->data[1][0][0], out01);
    _mm512_store_si512(&b_tile->data[2][0][0], out02);
    _mm512_store_si512(&b_tile->data[3][0][0], out03);
    _mm512_store_si512(&b_tile->data[4][0][0], out08);
    _mm512_store_si512(&b_tile->data[5][0][0], out09);
    _mm512_store_si512(&b_tile->data[6][0][0], out10);
    _mm512_store_si512(&b_tile->data[7][0][0], out11);
    _mm512_store_si512(&b_tile->data[8][0][0], out04);
    _mm512_store_si512(&b_tile->data[9][0][0], out05);
    _mm512_store_si512(&b_tile->data[10][0][0], out06);
    _mm512_store_si512(&b_tile->data[11][0][0], out07);
    _mm512_store_si512(&b_tile->data[12][0][0], out12);
    _mm512_store_si512(&b_tile->data[13][0][0], out13);
    _mm512_store_si512(&b_tile->data[14][0][0], out14);
    _mm512_store_si512(&b_tile->data[15][0][0], out15);

    // Keep the compiler from reordering the vector stores past subsequent
    // AMX tile loads of the same buffer.
    nk_compiler_barrier_sapphireamx_();
}
745
+
746
/* Pack A transposed into B format for INT8.
 *
 * Reinterprets each of the 16 A rows (64 INT8 = one ZMM) as 16 little-endian
 * 32-bit "quads" (4 consecutive INT8 that TDPBSSD consumes together), then
 * performs a 16x16 transpose of those 32-bit elements so each output register
 * holds one depth-group across all 16 columns — the layout AMX expects for
 * the B operand: b_tile->data[depth_group][column][quad]. */
NK_INTERNAL void nk_dots_pack_i8_transposed_sapphireamx_( //
    nk_dots_i8_a16x64_sapphireamx_t const *a_tile,        //
    nk_dots_i8_b64x16_sapphireamx_t *b_tile) {

    // Load all 16 rows - each row is 64 INT8 = 64 bytes = 1 ZMM
    // Treat as 16 × 32-bit elements per row (each 32-bit = quad of INT8)
    __m512i row00 = _mm512_load_si512(&a_tile->data[0][0]);
    __m512i row01 = _mm512_load_si512(&a_tile->data[1][0]);
    __m512i row02 = _mm512_load_si512(&a_tile->data[2][0]);
    __m512i row03 = _mm512_load_si512(&a_tile->data[3][0]);
    __m512i row04 = _mm512_load_si512(&a_tile->data[4][0]);
    __m512i row05 = _mm512_load_si512(&a_tile->data[5][0]);
    __m512i row06 = _mm512_load_si512(&a_tile->data[6][0]);
    __m512i row07 = _mm512_load_si512(&a_tile->data[7][0]);
    __m512i row08 = _mm512_load_si512(&a_tile->data[8][0]);
    __m512i row09 = _mm512_load_si512(&a_tile->data[9][0]);
    __m512i row10 = _mm512_load_si512(&a_tile->data[10][0]);
    __m512i row11 = _mm512_load_si512(&a_tile->data[11][0]);
    __m512i row12 = _mm512_load_si512(&a_tile->data[12][0]);
    __m512i row13 = _mm512_load_si512(&a_tile->data[13][0]);
    __m512i row14 = _mm512_load_si512(&a_tile->data[14][0]);
    __m512i row15 = _mm512_load_si512(&a_tile->data[15][0]);

    // 16×16 transpose of 32-bit elements using hierarchical unpacks
    // Stage 1: Unpack adjacent row pairs at 32-bit granularity
    __m512i t01_lo = _mm512_unpacklo_epi32(row00, row01);
    __m512i t01_hi = _mm512_unpackhi_epi32(row00, row01);
    __m512i t23_lo = _mm512_unpacklo_epi32(row02, row03);
    __m512i t23_hi = _mm512_unpackhi_epi32(row02, row03);
    __m512i t45_lo = _mm512_unpacklo_epi32(row04, row05);
    __m512i t45_hi = _mm512_unpackhi_epi32(row04, row05);
    __m512i t67_lo = _mm512_unpacklo_epi32(row06, row07);
    __m512i t67_hi = _mm512_unpackhi_epi32(row06, row07);
    __m512i t89_lo = _mm512_unpacklo_epi32(row08, row09);
    __m512i t89_hi = _mm512_unpackhi_epi32(row08, row09);
    __m512i tab_lo = _mm512_unpacklo_epi32(row10, row11);
    __m512i tab_hi = _mm512_unpackhi_epi32(row10, row11);
    __m512i tcd_lo = _mm512_unpacklo_epi32(row12, row13);
    __m512i tcd_hi = _mm512_unpackhi_epi32(row12, row13);
    __m512i tef_lo = _mm512_unpacklo_epi32(row14, row15);
    __m512i tef_hi = _mm512_unpackhi_epi32(row14, row15);

    // Stage 2: Unpack at 64-bit granularity (merges groups of 4 rows)
    __m512i u0123_ll = _mm512_unpacklo_epi64(t01_lo, t23_lo);
    __m512i u0123_lh = _mm512_unpackhi_epi64(t01_lo, t23_lo);
    __m512i u0123_hl = _mm512_unpacklo_epi64(t01_hi, t23_hi);
    __m512i u0123_hh = _mm512_unpackhi_epi64(t01_hi, t23_hi);
    __m512i u4567_ll = _mm512_unpacklo_epi64(t45_lo, t67_lo);
    __m512i u4567_lh = _mm512_unpackhi_epi64(t45_lo, t67_lo);
    __m512i u4567_hl = _mm512_unpacklo_epi64(t45_hi, t67_hi);
    __m512i u4567_hh = _mm512_unpackhi_epi64(t45_hi, t67_hi);
    __m512i u89ab_ll = _mm512_unpacklo_epi64(t89_lo, tab_lo);
    __m512i u89ab_lh = _mm512_unpackhi_epi64(t89_lo, tab_lo);
    __m512i u89ab_hl = _mm512_unpacklo_epi64(t89_hi, tab_hi);
    __m512i u89ab_hh = _mm512_unpackhi_epi64(t89_hi, tab_hi);
    __m512i ucdef_ll = _mm512_unpacklo_epi64(tcd_lo, tef_lo);
    __m512i ucdef_lh = _mm512_unpackhi_epi64(tcd_lo, tef_lo);
    __m512i ucdef_hl = _mm512_unpacklo_epi64(tcd_hi, tef_hi);
    __m512i ucdef_hh = _mm512_unpackhi_epi64(tcd_hi, tef_hi);

    // Stage 3: Shuffle 128-bit lanes (0x88 = lanes 0,2 of each; 0xDD = lanes 1,3)
    __m512i v0_a = _mm512_shuffle_i32x4(u0123_ll, u4567_ll, 0x88);
    __m512i v0_b = _mm512_shuffle_i32x4(u0123_ll, u4567_ll, 0xDD);
    __m512i v1_a = _mm512_shuffle_i32x4(u0123_lh, u4567_lh, 0x88);
    __m512i v1_b = _mm512_shuffle_i32x4(u0123_lh, u4567_lh, 0xDD);
    __m512i v2_a = _mm512_shuffle_i32x4(u0123_hl, u4567_hl, 0x88);
    __m512i v2_b = _mm512_shuffle_i32x4(u0123_hl, u4567_hl, 0xDD);
    __m512i v3_a = _mm512_shuffle_i32x4(u0123_hh, u4567_hh, 0x88);
    __m512i v3_b = _mm512_shuffle_i32x4(u0123_hh, u4567_hh, 0xDD);
    __m512i v4_a = _mm512_shuffle_i32x4(u89ab_ll, ucdef_ll, 0x88);
    __m512i v4_b = _mm512_shuffle_i32x4(u89ab_ll, ucdef_ll, 0xDD);
    __m512i v5_a = _mm512_shuffle_i32x4(u89ab_lh, ucdef_lh, 0x88);
    __m512i v5_b = _mm512_shuffle_i32x4(u89ab_lh, ucdef_lh, 0xDD);
    __m512i v6_a = _mm512_shuffle_i32x4(u89ab_hl, ucdef_hl, 0x88);
    __m512i v6_b = _mm512_shuffle_i32x4(u89ab_hl, ucdef_hl, 0xDD);
    __m512i v7_a = _mm512_shuffle_i32x4(u89ab_hh, ucdef_hh, 0x88);
    __m512i v7_b = _mm512_shuffle_i32x4(u89ab_hh, ucdef_hh, 0xDD);

    // Stage 4: Final 256-bit shuffle to complete transpose
    __m512i out00 = _mm512_shuffle_i32x4(v0_a, v4_a, 0x88);
    __m512i out01 = _mm512_shuffle_i32x4(v1_a, v5_a, 0x88);
    __m512i out02 = _mm512_shuffle_i32x4(v2_a, v6_a, 0x88);
    __m512i out03 = _mm512_shuffle_i32x4(v3_a, v7_a, 0x88);
    __m512i out04 = _mm512_shuffle_i32x4(v0_a, v4_a, 0xDD);
    __m512i out05 = _mm512_shuffle_i32x4(v1_a, v5_a, 0xDD);
    __m512i out06 = _mm512_shuffle_i32x4(v2_a, v6_a, 0xDD);
    __m512i out07 = _mm512_shuffle_i32x4(v3_a, v7_a, 0xDD);
    __m512i out08 = _mm512_shuffle_i32x4(v0_b, v4_b, 0x88);
    __m512i out09 = _mm512_shuffle_i32x4(v1_b, v5_b, 0x88);
    __m512i out10 = _mm512_shuffle_i32x4(v2_b, v6_b, 0x88);
    __m512i out11 = _mm512_shuffle_i32x4(v3_b, v7_b, 0x88);
    __m512i out12 = _mm512_shuffle_i32x4(v0_b, v4_b, 0xDD);
    __m512i out13 = _mm512_shuffle_i32x4(v1_b, v5_b, 0xDD);
    __m512i out14 = _mm512_shuffle_i32x4(v2_b, v6_b, 0xDD);
    __m512i out15 = _mm512_shuffle_i32x4(v3_b, v7_b, 0xDD);

    // Store transposed results - each output row is one depth_group
    // Output layout: B.data[depth_group][column][quad] = 16 columns × 4 INT8 = 64 bytes
    // NOTE: rows 4-7 receive out08..out11 and rows 8-11 receive out04..out07;
    // this deliberate swap compensates for the 128-bit lane ordering produced
    // by the stage-3/4 shuffles so the final rows land in depth order.
    _mm512_store_si512(&b_tile->data[0][0][0], out00);
    _mm512_store_si512(&b_tile->data[1][0][0], out01);
    _mm512_store_si512(&b_tile->data[2][0][0], out02);
    _mm512_store_si512(&b_tile->data[3][0][0], out03);
    _mm512_store_si512(&b_tile->data[4][0][0], out08);
    _mm512_store_si512(&b_tile->data[5][0][0], out09);
    _mm512_store_si512(&b_tile->data[6][0][0], out10);
    _mm512_store_si512(&b_tile->data[7][0][0], out11);
    _mm512_store_si512(&b_tile->data[8][0][0], out04);
    _mm512_store_si512(&b_tile->data[9][0][0], out05);
    _mm512_store_si512(&b_tile->data[10][0][0], out06);
    _mm512_store_si512(&b_tile->data[11][0][0], out07);
    _mm512_store_si512(&b_tile->data[12][0][0], out12);
    _mm512_store_si512(&b_tile->data[13][0][0], out13);
    _mm512_store_si512(&b_tile->data[14][0][0], out14);
    _mm512_store_si512(&b_tile->data[15][0][0], out15);

    // Keep the compiler from reordering the vector stores past subsequent
    // AMX tile loads of the same buffer.
    nk_compiler_barrier_sapphireamx_();
}
864
+
865
+ #pragma region Half Precision Floats
866
+
867
+ NK_PUBLIC nk_size_t nk_dots_packed_size_bf16_sapphireamx(nk_size_t column_count, nk_size_t depth) {
868
+ nk_size_t const tmm_rows = 16;
869
+ nk_size_t const tmm_cols = 32;
870
+ nk_size_t const tile_bytes = 512 * sizeof(nk_bf16_t); // 16 × 32 × 2 = 1KB
871
+
872
+ nk_size_t const full_column_tiles = column_count / tmm_rows;
873
+ nk_size_t const tiles_along_depth = nk_size_divide_round_up_(depth, tmm_cols);
874
+ nk_size_t const column_remainder_count = column_count - full_column_tiles * tmm_rows;
875
+
876
+ // Header (64 bytes aligned)
877
+ nk_size_t size = sizeof(nk_dots_amx_packed_header_t);
878
+
879
+ // All tiles for full column rows (Morton-ordered, pair-interleaved, depth remainder zero-padded)
880
+ size += full_column_tiles * tiles_along_depth * tile_bytes;
881
+
882
+ // Column edge: remaining rows for ALL depth columns, stored row-major
883
+ if (column_remainder_count > 0) size += column_remainder_count * depth * sizeof(nk_bf16_t);
884
+
885
+ // Per-column norms for angular/euclidean distance (4 bytes each: f32 or u32)
886
+ size += column_count * sizeof(nk_f32_t);
887
+
888
+ return size;
889
+ }
890
+
891
/* Pack a row-major BF16 B matrix (column_count rows × depth elements, row
 * stride b_stride in BYTES) into the opaque AMX layout consumed by
 * nk_dots_packed_bf16_sapphireamx. The caller supplies b_packed with at
 * least nk_dots_packed_size_bf16_sapphireamx(column_count, depth) bytes.
 * Layout written: header → linearly-ordered pair-interleaved 16×32 tiles
 * (depth remainder zero-padded) → row-major column-remainder rows → one
 * f32 per column (squared sums via nk_dots_reduce_sumsq_bf16_). */
NK_PUBLIC void nk_dots_pack_bf16_sapphireamx(             //
    nk_bf16_t const *b, nk_size_t column_count, nk_size_t depth, //
    nk_size_t b_stride, void *b_packed) {

    // AMX BF16 tile dimensions: 16 rows × 32 columns (512 BF16 elements = 1KB)
    nk_size_t const tmm_rows = 16;
    nk_size_t const tmm_cols = 32;
    nk_size_t const tile_elements = 512;
    nk_size_t const tile_bytes = tile_elements * sizeof(nk_bf16_t);
    nk_size_t const b_stride_elements = b_stride / sizeof(nk_bf16_t);

    // Compute layout dimensions
    nk_size_t const column_tiles_count = column_count / tmm_rows;
    nk_size_t const depth_tiles_count = nk_size_divide_round_up_(depth, tmm_cols);
    nk_size_t const column_remainder_count = column_count - column_tiles_count * tmm_rows;
    nk_size_t const total_tiles = column_tiles_count * depth_tiles_count;

    // Write header with layout metadata
    nk_dots_amx_packed_header_t *header = (nk_dots_amx_packed_header_t *)b_packed;
    header->full_column_tiles = (nk_u32_t)column_tiles_count;
    header->full_depth_tiles = (nk_u32_t)depth_tiles_count;
    header->column_remainder_count = (nk_u32_t)column_remainder_count;

    // Compute memory region offsets
    nk_size_t const tiles_offset = sizeof(nk_dots_amx_packed_header_t);
    nk_size_t const column_edge_offset = tiles_offset + total_tiles * tile_bytes;
    header->column_edge_offset = (nk_u32_t)column_edge_offset;

    // Pointers to packed data regions
    nk_bf16_t *tiles_ptr = (nk_bf16_t *)((char *)b_packed + tiles_offset);
    nk_bf16_t *column_edge_ptr = (nk_bf16_t *)((char *)b_packed + column_edge_offset);

    // Zero-initialize all tiles (handles depth remainder padding)
    for (nk_size_t idx = 0; idx < total_tiles * tile_elements; idx++) tiles_ptr[idx] = 0;

    // Pack tiles using LINEAR ordering: tile_index = column_tile × depth_tiles_count + depth_tile
    // This provides sequential memory access when streaming along depth dimension,
    // which is critical for cache efficiency in the compute kernel.
    for (nk_size_t column_tile_idx = 0; column_tile_idx < column_tiles_count; column_tile_idx++) {
        for (nk_size_t depth_tile_idx = 0; depth_tile_idx < depth_tiles_count; depth_tile_idx++) {

            // Linear tile index: all depth-tiles for one column-tile are contiguous
            nk_size_t const tile_index = column_tile_idx * depth_tiles_count + depth_tile_idx;
            nk_bf16_t *tile_output = tiles_ptr + tile_index * tile_elements;

            // Source coordinates in original B matrix
            nk_size_t const src_row_start = column_tile_idx * tmm_rows;
            nk_size_t const src_column_start = depth_tile_idx * tmm_cols;
            // Clamp the last depth tile to the matrix edge (the zero-init above
            // covers the padded tail).
            nk_size_t const columns_to_pack = (src_column_start + tmm_cols <= depth) ? tmm_cols
                                                                                     : (depth - src_column_start);

            // Pack with pair-interleaving as required by TDPBF16PS instruction.
            // AMX expects: [col0_row0, col1_row0, col0_row1, col1_row1, col2_row0, col3_row0, ...]
            // Formula: packed_idx = (column / 2) × 32 + row × 2 + (column % 2)
            // (the pair stride of 32 = tmm_rows × 2 BF16 per pair-group)
            for (nk_size_t row_idx = 0; row_idx < tmm_rows; row_idx++) {
                for (nk_size_t column_idx = 0; column_idx < columns_to_pack; column_idx++) {
                    nk_size_t const src_idx = (src_row_start + row_idx) * b_stride_elements + src_column_start +
                                              column_idx;
                    nk_size_t const dst_idx = (column_idx / 2) * 32 + row_idx * 2 + (column_idx % 2);
                    tile_output[dst_idx] = b[src_idx];
                }
            }
        }
    }

    // Pack column-remainder rows in simple row-major format (for AVX-512 fallback)
    if (column_remainder_count > 0) {
        nk_size_t const remainder_start_row = column_tiles_count * tmm_rows;
        for (nk_size_t row_idx = 0; row_idx < column_remainder_count; row_idx++) {
            for (nk_size_t column_idx = 0; column_idx < depth; column_idx++) {
                column_edge_ptr[row_idx * depth + column_idx] =
                    b[(remainder_start_row + row_idx) * b_stride_elements + column_idx];
            }
        }
    }

    // Compute and store per-column norms for angular/euclidean distance.
    // Note: nk_dots_reduce_sumsq_bf16_ suggests these are SQUARED sums, and
    // each value comes from one row of b (one logical output column).
    nk_size_t norms_offset = column_edge_offset +
                             (column_remainder_count > 0 ? column_remainder_count * depth * sizeof(nk_bf16_t) : 0);
    header->norms_byte_offset = (nk_u32_t)norms_offset;
    nk_f32_t *norms = (nk_f32_t *)((char *)b_packed + norms_offset);
    for (nk_size_t col = 0; col < column_count; col++)
        norms[col] = nk_dots_reduce_sumsq_bf16_(b + col * b_stride_elements, depth);
}
975
+
976
/**
 *  @brief  BF16 GEMM against a pre-packed B matrix using Sapphire Rapids AMX tiles.
 *
 *  Computes `c[rows_count × cols_count] += A × packed(B)` accumulating in f32.
 *  Tile-register allocation (fixed for the whole kernel):
 *    TMM0/TMM1 — A tiles (upper/lower 16 rows of a 32-row block),
 *    TMM2/TMM3 — B tiles (left/right 16-column halves of a 32-column block),
 *    TMM4..TMM7 — the four 16×16 f32 accumulators of the 32×32 output block.
 *
 *  @param a              Row-major BF16 A matrix, `a_stride_bytes` apart per row.
 *  @param b_packed       Buffer produced by the matching pack routine: a
 *                        `nk_dots_amx_packed_header_t` followed by linearly
 *                        ordered pair-interleaved tiles and a row-major column edge.
 *  @param c              f32 output, `c_stride_bytes` apart per row.
 *  @param cols_count     Unused: the column geometry is read from the packed header.
 */
NK_PUBLIC void nk_dots_packed_bf16_sapphireamx( //
    nk_bf16_t const *a, void const *b_packed, nk_f32_t *c, //
    nk_size_t rows_count, nk_size_t cols_count, nk_size_t depth, nk_size_t a_stride_bytes, nk_size_t c_stride_bytes) {
    nk_unused_(cols_count);

    // Parse packed B header — all layout metadata was frozen at pack time.
    nk_dots_amx_packed_header_t const *header = (nk_dots_amx_packed_header_t const *)b_packed;
    nk_size_t const column_tiles_count = header->full_column_tiles;
    nk_size_t const depth_tiles_count = header->full_depth_tiles;
    nk_size_t const column_remainder_count = header->column_remainder_count;

    // Packed B data regions: interleaved tiles right after the header,
    // then the row-major "column edge" for the trailing (< 16) columns.
    nk_bf16_t const *b_tiles_base = (nk_bf16_t const *)((char const *)b_packed + sizeof(nk_dots_amx_packed_header_t));
    nk_bf16_t const *col_edge_ptr = (nk_bf16_t const *)((char const *)b_packed + header->column_edge_offset);

    // Stride conversions (byte strides → element strides)
    nk_size_t const a_stride_elements = a_stride_bytes / sizeof(nk_bf16_t);
    nk_size_t const c_stride_elements = c_stride_bytes / sizeof(nk_f32_t);

    // Tile dimensions
    nk_size_t const tile_depth = 32; // depth elements per BF16 tile
    nk_size_t const tile_size = 512; // elements per packed tile (16 rows × 32 depth)
    nk_size_t const full_cols = column_tiles_count * 16; // columns covered by full tiles

    // Block counts (32 × 32 output blocks = 2 × 2 tiles)
    nk_size_t const row_blocks_count = nk_size_divide_round_up_(rows_count, 32);
    nk_size_t const col_blocks_count = column_tiles_count / 2;

    // Nothing packed along depth → nothing to accumulate.
    if (depth_tiles_count == 0) return;

    // Tile staging buffers for A (only used for edge tiles that need masking)
    nk_dots_bf16_a16x32_sapphireamx_t a_tile_upper, a_tile_lower;
    nk_dots_bf16_state2x2_sapphireamx_t c_accum_buffer;

    // Precompute: number of full depth-tiles (no masking needed) and the tail.
    nk_size_t const full_depth_tiles_count = depth / tile_depth;
    nk_size_t const depth_remainder = depth % tile_depth;

    nk_amx_tile_configure_sapphireamx_();

    // Loop order: row_blocks outer, col_blocks inner - maximizes A tile L2 cache reuse.
    // A tiles stay in L2 while we sweep through all col_blocks for a given row_block.
    for (nk_size_t row_block_idx = 0; row_block_idx < row_blocks_count; row_block_idx++) {
        nk_size_t const row_block_start = row_block_idx * 32;
        nk_size_t const valid_rows_count = (row_block_start + 32 <= rows_count) ? 32 : (rows_count - row_block_start);
        nk_size_t const is_full_row_block = (valid_rows_count == 32);

        for (nk_size_t column_block_idx = 0; column_block_idx < col_blocks_count; column_block_idx++) {
            nk_size_t const col_block_start = column_block_idx * 32;
            // Linear tile layout: tile_index = column_tile × depth_tiles_count + depth_tile.
            nk_size_t const b_column_left_base = (column_block_idx * 2) * depth_tiles_count;
            nk_size_t const b_column_right_base = (column_block_idx * 2 + 1) * depth_tiles_count;

            // Zero accumulators (TMM4-7 stay resident across entire depth loop)
            _tile_zero(4);
            _tile_zero(5);
            _tile_zero(6);
            _tile_zero(7);

            // Fast path: full row-block with full depth-tiles → direct A load with 2-deep pipelining
            if (is_full_row_block && full_depth_tiles_count > 0) {
                nk_bf16_t const *a_upper_base = a + row_block_start * a_stride_elements;
                nk_bf16_t const *a_lower_base = a + (row_block_start + 16) * a_stride_elements;

                nk_dots_bf16_b32x16_sapphireamx_t const *b_tile_left =
                    (nk_dots_bf16_b32x16_sapphireamx_t const *)(b_tiles_base + b_column_left_base * tile_size);
                nk_dots_bf16_b32x16_sapphireamx_t const *b_tile_right =
                    (nk_dots_bf16_b32x16_sapphireamx_t const *)(b_tiles_base + b_column_right_base * tile_size);

                // Prologue: load first depth tile
                _tile_loadd(0, a_upper_base, a_stride_bytes);
                _tile_loadd(1, a_lower_base, a_stride_bytes);
                _tile_loadd(2, b_tile_left->data, 64); // packed tiles have a fixed 64-byte row pitch
                _tile_loadd(3, b_tile_right->data, 64);

                // Main loop: 2-deep software pipelining — multiply the resident
                // tiles while the loads for the next depth-slice are in flight.
                for (nk_size_t depth_tile_idx = 0; depth_tile_idx < full_depth_tiles_count - 1; depth_tile_idx++) {
                    nk_size_t const next_depth_offset = (depth_tile_idx + 1) * tile_depth;

                    _tile_dpbf16ps(4, 0, 2);
                    _tile_dpbf16ps(5, 0, 3);
                    _tile_dpbf16ps(6, 1, 2);
                    _tile_dpbf16ps(7, 1, 3);

                    _tile_loadd(0, a_upper_base + next_depth_offset, a_stride_bytes);
                    _tile_loadd(1, a_lower_base + next_depth_offset, a_stride_bytes);
                    b_tile_left = (nk_dots_bf16_b32x16_sapphireamx_t const *)(b_tiles_base + (b_column_left_base +
                                                                                             depth_tile_idx + 1) *
                                                                                                tile_size);
                    b_tile_right = (nk_dots_bf16_b32x16_sapphireamx_t const *)(b_tiles_base + (b_column_right_base +
                                                                                               depth_tile_idx + 1) *
                                                                                                  tile_size);
                    _tile_loadd(2, b_tile_left->data, 64);
                    _tile_loadd(3, b_tile_right->data, 64);
                }

                // Epilogue: final depth tile
                _tile_dpbf16ps(4, 0, 2);
                _tile_dpbf16ps(5, 0, 3);
                _tile_dpbf16ps(6, 1, 2);
                _tile_dpbf16ps(7, 1, 3);

                // Handle partial depth-tile (if any) — A is staged through a
                // zero-padded buffer so the tail reads no out-of-range memory.
                if (depth_remainder > 0) {
                    nk_size_t const depth_offset = full_depth_tiles_count * tile_depth;

                    nk_dots_bf16_load_a_sapphireamx_(&a_tile_upper, a_upper_base + depth_offset, a_stride_elements, 16,
                                                     depth_remainder);
                    nk_dots_bf16_load_a_sapphireamx_(&a_tile_lower, a_lower_base + depth_offset, a_stride_elements, 16,
                                                     depth_remainder);

                    b_tile_left = (nk_dots_bf16_b32x16_sapphireamx_t const *)(b_tiles_base + (b_column_left_base +
                                                                                              full_depth_tiles_count) *
                                                                                                 tile_size);
                    b_tile_right = (nk_dots_bf16_b32x16_sapphireamx_t const *)(b_tiles_base + (b_column_right_base +
                                                                                               full_depth_tiles_count) *
                                                                                                  tile_size);

                    _tile_loadd(0, a_tile_upper.data, 64);
                    _tile_loadd(1, a_tile_lower.data, 64);
                    _tile_loadd(2, b_tile_left->data, 64);
                    _tile_loadd(3, b_tile_right->data, 64);

                    _tile_dpbf16ps(4, 0, 2);
                    _tile_dpbf16ps(5, 0, 3);
                    _tile_dpbf16ps(6, 1, 2);
                    _tile_dpbf16ps(7, 1, 3);
                }
            }
            // Full row-block but only partial depth tile (depth < tile_depth)
            else if (is_full_row_block) {
                nk_bf16_t const *a_upper_base = a + row_block_start * a_stride_elements;
                nk_bf16_t const *a_lower_base = a + (row_block_start + 16) * a_stride_elements;

                // Here depth_remainder == depth, since full_depth_tiles_count == 0.
                nk_dots_bf16_load_a_sapphireamx_(&a_tile_upper, a_upper_base, a_stride_elements, 16, depth_remainder);
                nk_dots_bf16_load_a_sapphireamx_(&a_tile_lower, a_lower_base, a_stride_elements, 16, depth_remainder);

                nk_dots_bf16_b32x16_sapphireamx_t const *b_tile_left =
                    (nk_dots_bf16_b32x16_sapphireamx_t const *)(b_tiles_base + b_column_left_base * tile_size);
                nk_dots_bf16_b32x16_sapphireamx_t const *b_tile_right =
                    (nk_dots_bf16_b32x16_sapphireamx_t const *)(b_tiles_base + b_column_right_base * tile_size);

                _tile_loadd(0, a_tile_upper.data, 64);
                _tile_loadd(1, a_tile_lower.data, 64);
                _tile_loadd(2, b_tile_left->data, 64);
                _tile_loadd(3, b_tile_right->data, 64);

                _tile_dpbf16ps(4, 0, 2);
                _tile_dpbf16ps(5, 0, 3);
                _tile_dpbf16ps(6, 1, 2);
                _tile_dpbf16ps(7, 1, 3);
            }
            // Slow path: edge row-block → buffered load with masking
            else {
                nk_size_t const rows_in_upper_tile = (valid_rows_count > 16) ? 16 : valid_rows_count;
                nk_size_t const rows_in_lower_tile = (valid_rows_count > 16) ? valid_rows_count - 16 : 0;

                for (nk_size_t depth_tile_idx = 0; depth_tile_idx < depth_tiles_count; depth_tile_idx++) {
                    nk_size_t const depth_offset = depth_tile_idx * tile_depth;
                    nk_size_t const valid_depth = (depth_tile_idx < full_depth_tiles_count) ? tile_depth
                                                                                            : depth_remainder;

                    nk_dots_bf16_load_a_sapphireamx_(&a_tile_upper,
                                                     a + row_block_start * a_stride_elements + depth_offset,
                                                     a_stride_elements, rows_in_upper_tile, valid_depth);
                    if (rows_in_lower_tile > 0) {
                        nk_dots_bf16_load_a_sapphireamx_(&a_tile_lower,
                                                         a + (row_block_start + 16) * a_stride_elements + depth_offset,
                                                         a_stride_elements, rows_in_lower_tile, valid_depth);
                    }

                    nk_dots_bf16_b32x16_sapphireamx_t const *b_tile_left =
                        (nk_dots_bf16_b32x16_sapphireamx_t const *)(b_tiles_base +
                                                                    (b_column_left_base + depth_tile_idx) * tile_size);
                    nk_dots_bf16_b32x16_sapphireamx_t const *b_tile_right =
                        (nk_dots_bf16_b32x16_sapphireamx_t const *)(b_tiles_base +
                                                                    (b_column_right_base + depth_tile_idx) * tile_size);

                    // NOTE(review): TMM1 is loaded even when rows_in_lower_tile == 0,
                    // i.e. from a possibly-uninitialized staging buffer; its products
                    // in TMM6/TMM7 are discarded by the masked store below — confirm
                    // intentional.
                    _tile_loadd(0, a_tile_upper.data, 64);
                    _tile_loadd(1, a_tile_lower.data, 64);
                    _tile_loadd(2, b_tile_left->data, 64);
                    _tile_loadd(3, b_tile_right->data, 64);

                    _tile_dpbf16ps(4, 0, 2);
                    _tile_dpbf16ps(5, 0, 3);
                    _tile_dpbf16ps(6, 1, 2);
                    _tile_dpbf16ps(7, 1, 3);
                }
            }

            // Store accumulators to output (once per output block)
            if (is_full_row_block) {
                // All 32 rows valid → store each 16×16 quadrant straight to C.
                nk_f32_t *c_block = c + row_block_start * c_stride_elements + col_block_start;
                _tile_stored(4, c_block, c_stride_bytes);
                _tile_stored(5, c_block + 16, c_stride_bytes);
                _tile_stored(6, (nk_f32_t *)((char *)c_block + 16 * c_stride_bytes), c_stride_bytes);
                _tile_stored(7, (nk_f32_t *)((char *)c_block + 16 * c_stride_bytes) + 16, c_stride_bytes);
            }
            else {
                // Partial row-block → bounce through a buffer, then masked copy-out.
                _tile_stored(4, c_accum_buffer.c[0][0].data, 64);
                _tile_stored(5, c_accum_buffer.c[0][1].data, 64);
                _tile_stored(6, c_accum_buffer.c[1][0].data, 64);
                _tile_stored(7, c_accum_buffer.c[1][1].data, 64);
                nk_dots_bf16_output2x2_sapphireamx_(&c_accum_buffer,
                                                    c + row_block_start * c_stride_elements + col_block_start,
                                                    c_stride_elements, valid_rows_count, 32);
            }
        }
    }

    // Handle odd column-tile (single 16-column tile if column_tiles_count is odd)
    if (column_tiles_count % 2 == 1) {
        nk_size_t const column_tile_idx = column_tiles_count - 1;
        nk_size_t const col_start = column_tile_idx * 16;
        nk_size_t const b_column_base = column_tile_idx * depth_tiles_count;

        for (nk_size_t row_block_idx = 0; row_block_idx < row_blocks_count; row_block_idx++) {
            nk_size_t const row_block_start = row_block_idx * 32;
            nk_size_t const valid_rows_count = (row_block_start + 32 <= rows_count) ? 32
                                                                                    : (rows_count - row_block_start);
            nk_size_t const rows_in_upper_tile = (valid_rows_count > 16) ? 16 : valid_rows_count;
            nk_size_t const rows_in_lower_tile = (valid_rows_count > 16) ? valid_rows_count - 16 : 0;

            nk_dots_bf16_state_sapphireamx_t c_upper_state, c_lower_state;

            // Only two accumulators needed: one 16-column strip, upper+lower rows.
            _tile_zero(4);
            _tile_zero(6);

            for (nk_size_t depth_tile_idx = 0; depth_tile_idx < depth_tiles_count; depth_tile_idx++) {
                nk_size_t const depth_offset = depth_tile_idx * tile_depth;
                nk_size_t const valid_depth = (depth_tile_idx < full_depth_tiles_count) ? tile_depth : depth_remainder;

                nk_dots_bf16_load_a_sapphireamx_(&a_tile_upper, a + row_block_start * a_stride_elements + depth_offset,
                                                 a_stride_elements, rows_in_upper_tile, valid_depth);
                if (rows_in_lower_tile > 0) {
                    nk_dots_bf16_load_a_sapphireamx_(&a_tile_lower,
                                                     a + (row_block_start + 16) * a_stride_elements + depth_offset,
                                                     a_stride_elements, rows_in_lower_tile, valid_depth);
                }

                nk_dots_bf16_b32x16_sapphireamx_t const *b_tile =
                    (nk_dots_bf16_b32x16_sapphireamx_t const *)(b_tiles_base +
                                                                (b_column_base + depth_tile_idx) * tile_size);

                _tile_loadd(0, a_tile_upper.data, 64);
                _tile_loadd(1, a_tile_lower.data, 64);
                _tile_loadd(2, b_tile->data, 64);

                _tile_dpbf16ps(4, 0, 2);
                _tile_dpbf16ps(6, 1, 2);
            }

            _tile_stored(4, c_upper_state.data, 64);
            _tile_stored(6, c_lower_state.data, 64);

            // Masked copy-out of the 16-column strip.
            nk_dots_bf16_store_sapphireamx_(&c_upper_state, c + row_block_start * c_stride_elements + col_start,
                                            c_stride_elements, rows_in_upper_tile, 16);
            if (rows_in_lower_tile > 0) {
                nk_dots_bf16_store_sapphireamx_(&c_lower_state,
                                                c + (row_block_start + 16) * c_stride_elements + col_start,
                                                c_stride_elements, rows_in_lower_tile, 16);
            }
        }
    }

    // Handle column-edge (remaining columns < 16) using AMX with partial tiles.
    // The edge was packed row-major, so it is re-staged and transposed on the fly.
    if (column_remainder_count > 0) {
        for (nk_size_t row_block_idx = 0; row_block_idx < row_blocks_count; row_block_idx++) {
            nk_size_t const row_block_start = row_block_idx * 32;
            nk_size_t const valid_rows_count = (row_block_start + 32 <= rows_count) ? 32
                                                                                    : (rows_count - row_block_start);
            nk_size_t const rows_in_upper_tile = (valid_rows_count > 16) ? 16 : valid_rows_count;
            nk_size_t const rows_in_lower_tile = (valid_rows_count > 16) ? valid_rows_count - 16 : 0;

            nk_dots_bf16_state_sapphireamx_t c_upper_state, c_lower_state;
            nk_dots_bf16_a16x32_sapphireamx_t b_as_a;
            nk_dots_bf16_b32x16_sapphireamx_t b_tile;

            _tile_zero(4);
            _tile_zero(6);

            for (nk_size_t depth_tile_idx = 0; depth_tile_idx < depth_tiles_count; depth_tile_idx++) {
                nk_size_t const depth_offset = depth_tile_idx * tile_depth;
                nk_size_t const valid_depth = (depth_tile_idx < full_depth_tiles_count) ? tile_depth : depth_remainder;

                nk_dots_bf16_load_a_sapphireamx_(&a_tile_upper, a + row_block_start * a_stride_elements + depth_offset,
                                                 a_stride_elements, rows_in_upper_tile, valid_depth);
                if (rows_in_lower_tile > 0) {
                    nk_dots_bf16_load_a_sapphireamx_(&a_tile_lower,
                                                     a + (row_block_start + 16) * a_stride_elements + depth_offset,
                                                     a_stride_elements, rows_in_lower_tile, valid_depth);
                }

                // Load the row-major edge (stride = depth elements) as if it were
                // an A tile, then repack it into the B interleaved layout.
                nk_dots_bf16_load_a_sapphireamx_(&b_as_a, col_edge_ptr + depth_offset, depth, column_remainder_count,
                                                 valid_depth);
                nk_dots_pack_bf16_transposed_sapphireamx_(&b_as_a, &b_tile);

                _tile_loadd(0, a_tile_upper.data, 64);
                _tile_loadd(1, a_tile_lower.data, 64);
                _tile_loadd(2, b_tile.data, 64);

                _tile_dpbf16ps(4, 0, 2);
                _tile_dpbf16ps(6, 1, 2);
            }

            _tile_stored(4, c_upper_state.data, 64);
            _tile_stored(6, c_lower_state.data, 64);

            nk_dots_bf16_store_sapphireamx_(&c_upper_state, c + row_block_start * c_stride_elements + full_cols,
                                            c_stride_elements, rows_in_upper_tile, column_remainder_count);
            if (rows_in_lower_tile > 0) {
                nk_dots_bf16_store_sapphireamx_(&c_lower_state,
                                                c + (row_block_start + 16) * c_stride_elements + full_cols,
                                                c_stride_elements, rows_in_lower_tile, column_remainder_count);
            }
        }
    }

    _tile_release();
}
1295
+
1296
/**
 *  @brief  In-place narrowing of an f32 result matrix to a densely packed BF16 matrix.
 *
 *  Reads `row_count × column_count` f32 values laid out with a byte stride of
 *  `c_stride` per row and rewrites them into the SAME buffer as contiguous
 *  (stride-free) BF16 rows, using AVX-512 BF16 conversions.
 *
 *  Aliasing safety: for every row, the BF16 destination byte offset
 *  (row × column_count × 2) never exceeds the f32 source byte offset
 *  (row × c_stride), and within a row 2-byte writes trail 4-byte reads —
 *  so the forward sweep never clobbers unread input.
 *  Assumes c_stride >= column_count * sizeof(nk_f32_t) — TODO confirm with callers.
 *
 *  @param c             Buffer holding f32 input; receives packed BF16 output.
 *  @param c_stride      Byte stride between consecutive f32 input rows.
 */
NK_PUBLIC void nk_dots_compact_bf16_sapphireamx( //
    void *c, nk_size_t row_count, nk_size_t column_count, nk_size_t c_stride) {

    nk_size_t const c_stride_f32 = c_stride / sizeof(nk_f32_t);
    nk_f32_t const *c_f32 = (nk_f32_t const *)c;
    nk_bf16_t *c_bf16 = (nk_bf16_t *)c;

    for (nk_size_t row_idx = 0; row_idx < row_count; row_idx++) {
        nk_f32_t const *src_row = c_f32 + row_idx * c_stride_f32;
        nk_bf16_t *dst_row = c_bf16 + row_idx * column_count; // packed: no inter-row gap
        nk_size_t column_idx = 0;

        // Process 16 floats at a time using AVX512-BF16 (round-to-nearest-even)
        for (; column_idx + 16 <= column_count; column_idx += 16) {
            __m512 f32_vec = _mm512_loadu_ps(src_row + column_idx);
            __m256bh bf16_vec = _mm512_cvtneps_pbh(f32_vec);
            _mm256_storeu_si256((__m256i *)(dst_row + column_idx), nk_m256i_from_m256bh_(bf16_vec));
        }

        // Handle remaining elements with masked operations
        // (tail length is 1..15, so the shift below never reaches 16 bits)
        if (column_idx < column_count) {
            __mmask16 tail_mask = (__mmask16)((1u << (column_count - column_idx)) - 1);
            __m512 f32_vec = _mm512_maskz_loadu_ps(tail_mask, src_row + column_idx);
            __m256bh bf16_vec = _mm512_cvtneps_pbh(f32_vec);
            _mm256_mask_storeu_epi16(dst_row + column_idx, tail_mask, nk_m256i_from_m256bh_(bf16_vec));
        }
    }
}
1324
+
1325
/**
 *  @brief  Computes a row-slice of the Gram matrix `vectors × vectorsᵀ` in BF16 via AMX.
 *
 *  For rows [row_start, row_start + row_count) (or all rows when row_count == 0)
 *  every 16×16 output tile is accumulated over the depth dimension in groups of
 *  three 32-element depth tiles (96 elements per update call).
 *
 *  Despite the name, the full row-slice is computed — including entries left of
 *  the diagonal; only the diagonal 16×16 blocks exploit symmetry by transposing
 *  the already-loaded A tiles instead of reloading B from memory.
 *
 *  @param stride         Byte stride between consecutive input vectors.
 *  @param result_stride  Byte stride between consecutive f32 output rows.
 */
NK_PUBLIC void nk_dots_symmetric_bf16_sapphireamx( //
    nk_bf16_t const *vectors, nk_size_t n_vectors, nk_size_t depth, //
    nk_size_t stride, nk_f32_t *result, nk_size_t result_stride, //
    nk_size_t row_start, nk_size_t row_count) {

    nk_size_t const stride_elements = stride / sizeof(nk_bf16_t);
    nk_size_t const result_stride_elements = result_stride / sizeof(nk_f32_t);

    // Handle row slicing: compute rows [row_start, row_end); row_count == 0 means "all"
    nk_size_t const row_end = (row_count == 0)
                                  ? n_vectors
                                  : (row_start + row_count < n_vectors ? row_start + row_count : n_vectors);

    // Round depth up to multiple of 96 (3 tiles × 32 elements per update group)
    nk_size_t const depth_tiles = nk_size_divide_round_up_(depth, 32);
    nk_size_t const depth_tile_groups = nk_size_divide_round_up_(depth_tiles, 3);

    nk_dots_bf16_a16x32_sapphireamx_t a_tiles[3];     // row-side tiles, natural layout
    nk_dots_bf16_a16x32_sapphireamx_t b_src_tiles[3]; // column-side staging before repack
    nk_dots_bf16_b32x16_sapphireamx_t b_tiles[3];     // column-side tiles, AMX B layout
    nk_dots_bf16_state_sapphireamx_t state;           // 16×16 f32 accumulator

    nk_amx_tile_configure_sapphireamx_();

    for (nk_size_t row_tile = row_start; row_tile < row_end; row_tile += 16) {
        nk_size_t const valid_rows = (row_tile + 16 <= row_end) ? 16 : (row_end - row_tile);

        for (nk_size_t col_tile = 0; col_tile < n_vectors; col_tile += 16) {
            nk_size_t const valid_cols = (col_tile + 16 <= n_vectors) ? 16 : (n_vectors - col_tile);

            nk_dots_bf16_init_sapphireamx_(&state);

            for (nk_size_t depth_group_idx = 0; depth_group_idx < depth_tile_groups; depth_group_idx++) {
                nk_size_t const depth_base = depth_group_idx * 96;

                for (int tile_idx = 0; tile_idx < 3; tile_idx++) {
                    nk_size_t const depth_start = depth_base + tile_idx * 32;
                    // Clamp to zero past the end of depth; the loaders zero-pad,
                    // so out-of-range tiles contribute nothing.
                    nk_size_t const valid_depth = (depth_start + 32 <= depth)
                                                      ? 32
                                                      : (depth > depth_start ? depth - depth_start : 0);

                    nk_dots_bf16_load_a_sapphireamx_( //
                        &a_tiles[tile_idx], //
                        vectors + row_tile * stride_elements + depth_start, //
                        stride_elements, valid_rows, valid_depth);

                    if (row_tile == col_tile) {
                        // Diagonal block: reuse the A tile as B via transpose-repack.
                        nk_dots_pack_bf16_transposed_sapphireamx_(&a_tiles[tile_idx], &b_tiles[tile_idx]);
                    }
                    else {
                        // Off-diagonal: load the column-side vectors separately.
                        nk_dots_bf16_load_a_sapphireamx_( //
                            &b_src_tiles[tile_idx], //
                            vectors + col_tile * stride_elements + depth_start, //
                            stride_elements, valid_cols, valid_depth);
                        nk_dots_pack_bf16_transposed_sapphireamx_(&b_src_tiles[tile_idx], &b_tiles[tile_idx]);
                    }
                }

                // Accumulate all three depth tiles into the 16×16 state at once.
                nk_dots_bf16_update_sapphireamx_( //
                    &state, &a_tiles[0], &a_tiles[1], &a_tiles[2], &b_tiles[0], &b_tiles[1], &b_tiles[2]);
            }

            // Masked copy-out honoring the partial last row/column tiles.
            nk_dots_bf16_store_sapphireamx_( //
                &state, result + row_tile * result_stride_elements + col_tile, //
                result_stride_elements, valid_rows, valid_cols);
        }
    }
}
1393
+
1394
+ #pragma endregion // Half Precision Floats
1395
+
1396
+ #pragma region Signed Integers
1397
+
1398
+ NK_PUBLIC nk_size_t nk_dots_packed_size_i8_sapphireamx(nk_size_t column_count, nk_size_t depth) {
1399
+ nk_size_t const tmm_rows = 16;
1400
+ nk_size_t const tmm_cols = 64;
1401
+ nk_size_t const tile_bytes = 1024 * sizeof(nk_i8_t); // 16 × 64×1 = 1KB
1402
+
1403
+ nk_size_t const full_column_tiles = column_count / tmm_rows;
1404
+ nk_size_t const tiles_along_depth = nk_size_divide_round_up_(depth, tmm_cols);
1405
+ nk_size_t const column_remainder_count = column_count - full_column_tiles * tmm_rows;
1406
+
1407
+ // Header (64 bytes aligned)
1408
+ nk_size_t size = sizeof(nk_dots_amx_packed_header_t);
1409
+
1410
+ // All tiles for full column rows (Morton-ordered, quad-interleaved, depth remainder zero-padded)
1411
+ size += full_column_tiles * tiles_along_depth * tile_bytes;
1412
+
1413
+ // Column edge: remaining rows for ALL depth columns, stored row-major
1414
+ if (column_remainder_count > 0) size += column_remainder_count * depth * sizeof(nk_i8_t);
1415
+
1416
+ // Per-column norms for angular/euclidean distance (4 bytes each: f32 or u32)
1417
+ size += column_count * sizeof(nk_u32_t);
1418
+
1419
+ return size;
1420
+ }
1421
+
1422
+ NK_PUBLIC void nk_dots_pack_i8_sapphireamx( //
1423
+ nk_i8_t const *b, nk_size_t column_count, nk_size_t depth, //
1424
+ nk_size_t b_stride, void *b_packed) {
1425
+
1426
+ // AMX I8 tile dimensions: 16 rows × 64 columns (1024 I8 elements = 1KB)
1427
+ nk_size_t const tmm_rows = 16;
1428
+ nk_size_t const tmm_cols = 64;
1429
+ nk_size_t const tile_elements = 1024;
1430
+ nk_size_t const tile_bytes = tile_elements * sizeof(nk_i8_t);
1431
+
1432
+ // Compute layout dimensions
1433
+ nk_size_t const column_tiles_count = column_count / tmm_rows;
1434
+ nk_size_t const depth_tiles_count = nk_size_divide_round_up_(depth, tmm_cols);
1435
+ nk_size_t const column_remainder_count = column_count - column_tiles_count * tmm_rows;
1436
+ nk_size_t const total_tiles = column_tiles_count * depth_tiles_count;
1437
+
1438
+ // Write header with layout metadata
1439
+ nk_dots_amx_packed_header_t *header = (nk_dots_amx_packed_header_t *)b_packed;
1440
+ header->full_column_tiles = (nk_u32_t)column_tiles_count;
1441
+ header->full_depth_tiles = (nk_u32_t)depth_tiles_count;
1442
+ header->column_remainder_count = (nk_u32_t)column_remainder_count;
1443
+
1444
+ // Compute memory region offsets
1445
+ nk_size_t const tiles_offset = sizeof(nk_dots_amx_packed_header_t);
1446
+ nk_size_t const column_edge_offset = tiles_offset + total_tiles * tile_bytes;
1447
+ header->column_edge_offset = (nk_u32_t)column_edge_offset;
1448
+
1449
+ // Pointers to packed data regions
1450
+ nk_i8_t *tiles_ptr = (nk_i8_t *)((char *)b_packed + tiles_offset);
1451
+ nk_i8_t *column_edge_ptr = (nk_i8_t *)((char *)b_packed + column_edge_offset);
1452
+
1453
+ // Zero-initialize all tiles (handles depth remainder padding)
1454
+ for (nk_size_t idx = 0; idx < total_tiles * tile_elements; idx++) tiles_ptr[idx] = 0;
1455
+
1456
+ // Pack tiles using LINEAR ordering: tile_index = column_tile × depth_tiles_count + depth_tile
1457
+ // This provides sequential memory access when streaming along depth dimension.
1458
+ for (nk_size_t column_tile_idx = 0; column_tile_idx < column_tiles_count; column_tile_idx++) {
1459
+ for (nk_size_t depth_tile_idx = 0; depth_tile_idx < depth_tiles_count; depth_tile_idx++) {
1460
+
1461
+ // Linear tile index: all depth-tiles for one column-tile are contiguous
1462
+ nk_size_t const tile_index = column_tile_idx * depth_tiles_count + depth_tile_idx;
1463
+ nk_i8_t *tile_output = tiles_ptr + tile_index * tile_elements;
1464
+
1465
+ // Source coordinates in original B matrix
1466
+ nk_size_t const src_row_start = column_tile_idx * tmm_rows;
1467
+ nk_size_t const src_column_start = depth_tile_idx * tmm_cols;
1468
+ nk_size_t const columns_to_pack = (src_column_start + tmm_cols <= depth) ? tmm_cols
1469
+ : (depth - src_column_start);
1470
+
1471
+ // Pack with quad-interleaving as required by TDPBSSD instruction.
1472
+ // AMX expects: [col0_row0, col1_row0, col2_row0, col3_row0, col0_row1, ...]
1473
+ // Formula: packed_idx = (column / 4) × 64 + row × 4 + (column % 4)
1474
+ for (nk_size_t row_idx = 0; row_idx < tmm_rows; row_idx++) {
1475
+ for (nk_size_t column_idx = 0; column_idx < columns_to_pack; column_idx++) {
1476
+ nk_size_t const src_idx = (src_row_start + row_idx) * b_stride + src_column_start + column_idx;
1477
+ nk_size_t const dst_idx = (column_idx / 4) * 64 + row_idx * 4 + (column_idx % 4);
1478
+ tile_output[dst_idx] = b[src_idx];
1479
+ }
1480
+ }
1481
+ }
1482
+ }
1483
+
1484
+ // Pack column-remainder rows in simple row-major format (for AVX-512 fallback)
1485
+ if (column_remainder_count > 0) {
1486
+ nk_size_t const remainder_start_row = column_tiles_count * tmm_rows;
1487
+ for (nk_size_t row_idx = 0; row_idx < column_remainder_count; row_idx++) {
1488
+ for (nk_size_t column_idx = 0; column_idx < depth; column_idx++) {
1489
+ column_edge_ptr[row_idx * depth + column_idx] =
1490
+ b[(remainder_start_row + row_idx) * b_stride + column_idx];
1491
+ }
1492
+ }
1493
+ }
1494
+
1495
+ // Compute and store per-column norms for angular/euclidean distance
1496
+ nk_size_t norms_offset = column_edge_offset +
1497
+ (column_remainder_count > 0 ? column_remainder_count * depth * sizeof(nk_i8_t) : 0);
1498
+ header->norms_byte_offset = (nk_u32_t)norms_offset;
1499
+ nk_u32_t *norms = (nk_u32_t *)((char *)b_packed + norms_offset);
1500
+ for (nk_size_t col = 0; col < column_count; col++) norms[col] = nk_dots_reduce_sumsq_i8_(b + col * b_stride, depth);
1501
+ }
1502
+
1503
+ NK_PUBLIC void nk_dots_packed_i8_sapphireamx( //
1504
+ nk_i8_t const *a, void const *b_packed, nk_i32_t *c, //
1505
+ nk_size_t rows_count, nk_size_t cols_count, nk_size_t depth, nk_size_t a_stride_bytes, nk_size_t c_stride_bytes) {
1506
+ nk_unused_(cols_count);
1507
+
1508
+ // Parse packed B header
1509
+ nk_dots_amx_packed_header_t const *header = (nk_dots_amx_packed_header_t const *)b_packed;
1510
+ nk_size_t const column_tiles_count = header->full_column_tiles;
1511
+ nk_size_t const depth_tiles_count = header->full_depth_tiles;
1512
+ nk_size_t const column_remainder_count = header->column_remainder_count;
1513
+
1514
+ // Packed B data regions
1515
+ nk_i8_t const *b_tiles_base = (nk_i8_t const *)((char const *)b_packed + sizeof(nk_dots_amx_packed_header_t));
1516
+ nk_i8_t const *col_edge_ptr = (nk_i8_t const *)((char const *)b_packed + header->column_edge_offset);
1517
+
1518
+ // Stride conversions
1519
+ nk_size_t const c_stride_elements = c_stride_bytes / sizeof(nk_i32_t);
1520
+
1521
+ // Tile dimensions
1522
+ nk_size_t const tile_depth = 64; // depth elements per INT8 tile
1523
+ nk_size_t const tile_size = 1024; // bytes per packed tile
1524
+ nk_size_t const full_cols = column_tiles_count * 16;
1525
+
1526
+ // Block counts (32 × 32 output blocks = 2 × 2 tiles)
1527
+ nk_size_t const row_blocks_count = nk_size_divide_round_up_(rows_count, 32);
1528
+ nk_size_t const col_blocks_count = column_tiles_count / 2;
1529
+
1530
+ if (depth_tiles_count == 0) return;
1531
+
1532
+ // Tile buffers for A (only used for edge tiles)
1533
+ nk_dots_i8_a16x64_sapphireamx_t a_tile_upper, a_tile_lower;
1534
+ nk_dots_i8_state2x2_sapphireamx_t c_accum_buffer;
1535
+
1536
+ // Precompute: number of full depth-tiles (no masking needed)
1537
+ nk_size_t const full_depth_tiles_count = depth / tile_depth;
1538
+ nk_size_t const depth_remainder = depth % tile_depth;
1539
+
1540
+ nk_amx_tile_configure_sapphireamx_();
1541
+
1542
+ // Process all 32 × 32 row × column blocks (including partial edge blocks)
1543
+ for (nk_size_t row_block_idx = 0; row_block_idx < row_blocks_count; row_block_idx++) {
1544
+ nk_size_t const row_block_start = row_block_idx * 32;
1545
+ nk_size_t const valid_rows_count = (row_block_start + 32 <= rows_count) ? 32 : (rows_count - row_block_start);
1546
+ nk_size_t const is_full_row_block = (valid_rows_count == 32);
1547
+
1548
+ // Process full column-blocks (pairs of 16-column tiles = 32 columns)
1549
+ for (nk_size_t column_block_idx = 0; column_block_idx < col_blocks_count; column_block_idx++) {
1550
+ nk_size_t const col_block_start = column_block_idx * 32;
1551
+
1552
+ // B tile base indices (linear layout: col_tile × depth_tiles_count + depth_tile)
1553
+ nk_size_t const b_column_left_base = (column_block_idx * 2) * depth_tiles_count;
1554
+ nk_size_t const b_column_right_base = (column_block_idx * 2 + 1) * depth_tiles_count;
1555
+
1556
+ // Zero accumulators (TMM4-7 stay resident across entire depth loop)
1557
+ _tile_zero(4); // C[upper, left]
1558
+ _tile_zero(5); // C[upper, right]
1559
+ _tile_zero(6); // C[lower, left]
1560
+ _tile_zero(7); // C[lower, right]
1561
+
1562
+ // Fast path: full row-block with full depth-tiles → direct A load with 2-deep pipelining
1563
+ if (is_full_row_block && full_depth_tiles_count > 0) {
1564
+ // A row pointers for direct load
1565
+ nk_i8_t const *a_upper_base = a + row_block_start * a_stride_bytes;
1566
+ nk_i8_t const *a_lower_base = a + (row_block_start + 16) * a_stride_bytes;
1567
+
1568
+ // B tile pointers
1569
+ nk_dots_i8_b64x16_sapphireamx_t const *b_tile_left =
1570
+ (nk_dots_i8_b64x16_sapphireamx_t const *)(b_tiles_base + b_column_left_base * tile_size);
1571
+ nk_dots_i8_b64x16_sapphireamx_t const *b_tile_right =
1572
+ (nk_dots_i8_b64x16_sapphireamx_t const *)(b_tiles_base + b_column_right_base * tile_size);
1573
+
1574
+ // Prologue: load first depth tile into TMM0-3
1575
+ _tile_loadd(0, a_upper_base, a_stride_bytes);
1576
+ _tile_loadd(1, a_lower_base, a_stride_bytes);
1577
+ _tile_loadd(2, b_tile_left->data, 64);
1578
+ _tile_loadd(3, b_tile_right->data, 64);
1579
+
1580
+ // Main loop: 2-deep software pipelining (compute current while loading next)
1581
+ for (nk_size_t depth_tile_idx = 0; depth_tile_idx < full_depth_tiles_count - 1; depth_tile_idx++) {
1582
+ nk_size_t const next_depth_offset = (depth_tile_idx + 1) * tile_depth;
1583
+
1584
+ _tile_dpbssd(4, 0, 2);
1585
+ _tile_dpbssd(5, 0, 3);
1586
+ _tile_dpbssd(6, 1, 2);
1587
+ _tile_dpbssd(7, 1, 3);
1588
+
1589
+ _tile_loadd(0, a_upper_base + next_depth_offset, a_stride_bytes);
1590
+ _tile_loadd(1, a_lower_base + next_depth_offset, a_stride_bytes);
1591
+ b_tile_left = (nk_dots_i8_b64x16_sapphireamx_t const *)(b_tiles_base +
1592
+ (b_column_left_base + depth_tile_idx + 1) *
1593
+ tile_size);
1594
+ b_tile_right = (nk_dots_i8_b64x16_sapphireamx_t const *)(b_tiles_base + (b_column_right_base +
1595
+ depth_tile_idx + 1) *
1596
+ tile_size);
1597
+ _tile_loadd(2, b_tile_left->data, 64);
1598
+ _tile_loadd(3, b_tile_right->data, 64);
1599
+ }
1600
+
1601
+ // Epilogue: final depth tile (no next to load)
1602
+ _tile_dpbssd(4, 0, 2);
1603
+ _tile_dpbssd(5, 0, 3);
1604
+ _tile_dpbssd(6, 1, 2);
1605
+ _tile_dpbssd(7, 1, 3);
1606
+
1607
+ // Handle partial depth-tile (if any) with buffered load
1608
+ if (depth_remainder > 0) {
1609
+ nk_size_t const depth_offset = full_depth_tiles_count * tile_depth;
1610
+
1611
+ nk_dots_i8_load_a_sapphireamx_(&a_tile_upper, a_upper_base + depth_offset, a_stride_bytes, 16,
1612
+ depth_remainder);
1613
+ nk_dots_i8_load_a_sapphireamx_(&a_tile_lower, a_lower_base + depth_offset, a_stride_bytes, 16,
1614
+ depth_remainder);
1615
+
1616
+ b_tile_left = (nk_dots_i8_b64x16_sapphireamx_t const *)(b_tiles_base + (b_column_left_base +
1617
+ full_depth_tiles_count) *
1618
+ tile_size);
1619
+ b_tile_right = (nk_dots_i8_b64x16_sapphireamx_t const *)(b_tiles_base + (b_column_right_base +
1620
+ full_depth_tiles_count) *
1621
+ tile_size);
1622
+
1623
+ _tile_loadd(0, a_tile_upper.data, 64);
1624
+ _tile_loadd(1, a_tile_lower.data, 64);
1625
+ _tile_loadd(2, b_tile_left->data, 64);
1626
+ _tile_loadd(3, b_tile_right->data, 64);
1627
+
1628
+ _tile_dpbssd(4, 0, 2);
1629
+ _tile_dpbssd(5, 0, 3);
1630
+ _tile_dpbssd(6, 1, 2);
1631
+ _tile_dpbssd(7, 1, 3);
1632
+ }
1633
+ }
1634
+ // Full row-block but only partial depth tile (depth < tile_depth)
1635
+ else if (is_full_row_block) {
1636
+ nk_i8_t const *a_upper_base = a + row_block_start * a_stride_bytes;
1637
+ nk_i8_t const *a_lower_base = a + (row_block_start + 16) * a_stride_bytes;
1638
+
1639
+ nk_dots_i8_load_a_sapphireamx_(&a_tile_upper, a_upper_base, a_stride_bytes, 16, depth_remainder);
1640
+ nk_dots_i8_load_a_sapphireamx_(&a_tile_lower, a_lower_base, a_stride_bytes, 16, depth_remainder);
1641
+
1642
+ nk_dots_i8_b64x16_sapphireamx_t const *b_tile_left =
1643
+ (nk_dots_i8_b64x16_sapphireamx_t const *)(b_tiles_base + b_column_left_base * tile_size);
1644
+ nk_dots_i8_b64x16_sapphireamx_t const *b_tile_right =
1645
+ (nk_dots_i8_b64x16_sapphireamx_t const *)(b_tiles_base + b_column_right_base * tile_size);
1646
+
1647
+ _tile_loadd(0, a_tile_upper.data, 64);
1648
+ _tile_loadd(1, a_tile_lower.data, 64);
1649
+ _tile_loadd(2, b_tile_left->data, 64);
1650
+ _tile_loadd(3, b_tile_right->data, 64);
1651
+
1652
+ _tile_dpbssd(4, 0, 2);
1653
+ _tile_dpbssd(5, 0, 3);
1654
+ _tile_dpbssd(6, 1, 2);
1655
+ _tile_dpbssd(7, 1, 3);
1656
+ }
1657
+ // Slow path: edge row-block → always use buffered load with masking
1658
+ else {
1659
+ nk_size_t const rows_in_upper_tile = (valid_rows_count > 16) ? 16 : valid_rows_count;
1660
+ nk_size_t const rows_in_lower_tile = (valid_rows_count > 16) ? valid_rows_count - 16 : 0;
1661
+
1662
+ for (nk_size_t depth_tile_idx = 0; depth_tile_idx < depth_tiles_count; depth_tile_idx++) {
1663
+ nk_size_t const depth_offset = depth_tile_idx * tile_depth;
1664
+ nk_size_t const valid_depth = (depth_tile_idx < full_depth_tiles_count) ? tile_depth
1665
+ : depth_remainder;
1666
+
1667
+ nk_dots_i8_load_a_sapphireamx_(&a_tile_upper, a + row_block_start * a_stride_bytes + depth_offset,
1668
+ a_stride_bytes, rows_in_upper_tile, valid_depth);
1669
+ if (rows_in_lower_tile > 0) {
1670
+ nk_dots_i8_load_a_sapphireamx_(&a_tile_lower,
1671
+ a + (row_block_start + 16) * a_stride_bytes + depth_offset,
1672
+ a_stride_bytes, rows_in_lower_tile, valid_depth);
1673
+ }
1674
+
1675
+ nk_dots_i8_b64x16_sapphireamx_t const *b_tile_left =
1676
+ (nk_dots_i8_b64x16_sapphireamx_t const *)(b_tiles_base +
1677
+ (b_column_left_base + depth_tile_idx) * tile_size);
1678
+ nk_dots_i8_b64x16_sapphireamx_t const *b_tile_right =
1679
+ (nk_dots_i8_b64x16_sapphireamx_t const *)(b_tiles_base +
1680
+ (b_column_right_base + depth_tile_idx) * tile_size);
1681
+
1682
+ _tile_loadd(0, a_tile_upper.data, 64);
1683
+ _tile_loadd(1, a_tile_lower.data, 64);
1684
+ _tile_loadd(2, b_tile_left->data, 64);
1685
+ _tile_loadd(3, b_tile_right->data, 64);
1686
+
1687
+ _tile_dpbssd(4, 0, 2);
1688
+ _tile_dpbssd(5, 0, 3);
1689
+ _tile_dpbssd(6, 1, 2);
1690
+ _tile_dpbssd(7, 1, 3);
1691
+ }
1692
+ }
1693
+
1694
+ // Store accumulators to output (once per output block, not per depth tile)
1695
+ if (is_full_row_block) {
1696
+ nk_i32_t *c_block = c + row_block_start * c_stride_elements + col_block_start;
1697
+ _tile_stored(4, c_block, c_stride_bytes);
1698
+ _tile_stored(5, c_block + 16, c_stride_bytes);
1699
+ _tile_stored(6, (nk_i32_t *)((char *)c_block + 16 * c_stride_bytes), c_stride_bytes);
1700
+ _tile_stored(7, (nk_i32_t *)((char *)c_block + 16 * c_stride_bytes) + 16, c_stride_bytes);
1701
+ }
1702
+ else {
1703
+ // Slow path: edge row-block needs masked output
1704
+ _tile_stored(4, c_accum_buffer.c[0][0].data, 64);
1705
+ _tile_stored(5, c_accum_buffer.c[0][1].data, 64);
1706
+ _tile_stored(6, c_accum_buffer.c[1][0].data, 64);
1707
+ _tile_stored(7, c_accum_buffer.c[1][1].data, 64);
1708
+ nk_dots_i8_output2x2_sapphireamx_(&c_accum_buffer,
1709
+ c + row_block_start * c_stride_elements + col_block_start,
1710
+ c_stride_elements, valid_rows_count, 32);
1711
+ }
1712
+ }
1713
+
1714
+ // Handle odd column-tile (single 16-column tile if column_tiles_count is odd)
1715
+ if (column_tiles_count % 2 == 1) {
1716
+ nk_size_t const column_tile_idx = column_tiles_count - 1;
1717
+ nk_size_t const col_start = column_tile_idx * 16;
1718
+ nk_size_t const b_column_base = column_tile_idx * depth_tiles_count;
1719
+ nk_size_t const rows_in_upper_tile = (valid_rows_count > 16) ? 16 : valid_rows_count;
1720
+ nk_size_t const rows_in_lower_tile = (valid_rows_count > 16) ? valid_rows_count - 16 : 0;
1721
+
1722
+ // Use 1 × 2 blocking for single column-tile (2 row-tiles × 1 column-tile)
1723
+ nk_dots_i8_state_sapphireamx_t c_upper_state, c_lower_state;
1724
+
1725
+ _tile_zero(4);
1726
+ _tile_zero(6);
1727
+
1728
+ for (nk_size_t depth_tile_idx = 0; depth_tile_idx < depth_tiles_count; depth_tile_idx++) {
1729
+ nk_size_t const depth_offset = depth_tile_idx * tile_depth;
1730
+ nk_size_t const valid_depth = (depth_tile_idx < full_depth_tiles_count) ? tile_depth : depth_remainder;
1731
+
1732
+ nk_dots_i8_load_a_sapphireamx_(&a_tile_upper, a + row_block_start * a_stride_bytes + depth_offset,
1733
+ a_stride_bytes, rows_in_upper_tile, valid_depth);
1734
+ if (rows_in_lower_tile > 0) {
1735
+ nk_dots_i8_load_a_sapphireamx_(&a_tile_lower,
1736
+ a + (row_block_start + 16) * a_stride_bytes + depth_offset,
1737
+ a_stride_bytes, rows_in_lower_tile, valid_depth);
1738
+ }
1739
+
1740
+ nk_dots_i8_b64x16_sapphireamx_t const *b_tile =
1741
+ (nk_dots_i8_b64x16_sapphireamx_t const *)(b_tiles_base +
1742
+ (b_column_base + depth_tile_idx) * tile_size);
1743
+
1744
+ _tile_loadd(0, a_tile_upper.data, 64);
1745
+ _tile_loadd(1, a_tile_lower.data, 64);
1746
+ _tile_loadd(2, b_tile->data, 64);
1747
+
1748
+ _tile_dpbssd(4, 0, 2);
1749
+ _tile_dpbssd(6, 1, 2);
1750
+ }
1751
+
1752
+ _tile_stored(4, c_upper_state.data, 64);
1753
+ _tile_stored(6, c_lower_state.data, 64);
1754
+
1755
+ nk_dots_i8_store_sapphireamx_(&c_upper_state, c + row_block_start * c_stride_elements + col_start,
1756
+ c_stride_elements, rows_in_upper_tile, 16);
1757
+ if (rows_in_lower_tile > 0) {
1758
+ nk_dots_i8_store_sapphireamx_(&c_lower_state,
1759
+ c + (row_block_start + 16) * c_stride_elements + col_start,
1760
+ c_stride_elements, rows_in_lower_tile, 16);
1761
+ }
1762
+ }
1763
+
1764
+ // Handle column-edge (remaining columns < 16) using AMX with partial tiles
1765
+ if (column_remainder_count > 0) {
1766
+ nk_size_t const rows_in_upper_tile = (valid_rows_count > 16) ? 16 : valid_rows_count;
1767
+ nk_size_t const rows_in_lower_tile = (valid_rows_count > 16) ? valid_rows_count - 16 : 0;
1768
+
1769
+ nk_dots_i8_state_sapphireamx_t c_upper_state, c_lower_state;
1770
+ nk_dots_i8_a16x64_sapphireamx_t b_as_a;
1771
+ nk_dots_i8_b64x16_sapphireamx_t b_tile;
1772
+
1773
+ _tile_zero(4);
1774
+ _tile_zero(6);
1775
+
1776
+ for (nk_size_t depth_tile_idx = 0; depth_tile_idx < depth_tiles_count; depth_tile_idx++) {
1777
+ nk_size_t const depth_offset = depth_tile_idx * tile_depth;
1778
+ nk_size_t const valid_depth = (depth_tile_idx < full_depth_tiles_count) ? tile_depth : depth_remainder;
1779
+
1780
+ // Load A tiles
1781
+ nk_dots_i8_load_a_sapphireamx_(&a_tile_upper, a + row_block_start * a_stride_bytes + depth_offset,
1782
+ a_stride_bytes, rows_in_upper_tile, valid_depth);
1783
+ if (rows_in_lower_tile > 0) {
1784
+ nk_dots_i8_load_a_sapphireamx_(&a_tile_lower,
1785
+ a + (row_block_start + 16) * a_stride_bytes + depth_offset,
1786
+ a_stride_bytes, rows_in_lower_tile, valid_depth);
1787
+ }
1788
+
1789
+ // Load B edge data (row-major: b_edge[row × depth + column]) and pack into B tile
1790
+ // Each "row" in edge data corresponds to one output column
1791
+ nk_dots_i8_load_a_sapphireamx_(&b_as_a, col_edge_ptr + depth_offset, depth, column_remainder_count,
1792
+ valid_depth);
1793
+ nk_dots_pack_i8_transposed_sapphireamx_(&b_as_a, &b_tile);
1794
+
1795
+ _tile_loadd(0, a_tile_upper.data, 64);
1796
+ _tile_loadd(1, a_tile_lower.data, 64);
1797
+ _tile_loadd(2, b_tile.data, 64);
1798
+
1799
+ _tile_dpbssd(4, 0, 2);
1800
+ _tile_dpbssd(6, 1, 2);
1801
+ }
1802
+
1803
+ _tile_stored(4, c_upper_state.data, 64);
1804
+ _tile_stored(6, c_lower_state.data, 64);
1805
+
1806
+ nk_dots_i8_store_sapphireamx_(&c_upper_state, c + row_block_start * c_stride_elements + full_cols,
1807
+ c_stride_elements, rows_in_upper_tile, column_remainder_count);
1808
+ if (rows_in_lower_tile > 0) {
1809
+ nk_dots_i8_store_sapphireamx_(&c_lower_state,
1810
+ c + (row_block_start + 16) * c_stride_elements + full_cols,
1811
+ c_stride_elements, rows_in_lower_tile, column_remainder_count);
1812
+ }
1813
+ }
1814
+ }
1815
+
1816
+ _tile_release();
1817
+ }
1818
+
1819
/**
 *  In-place "compaction" of an I32 dot-product matrix into saturated I8.
 *
 *  Each entry c[row][col] is rescaled as round(127 * c / (sqrt(a_norm[row]) * sqrt(b_norm[col])))
 *  and narrowed with signed saturation, i.e. a cosine-style normalization quantized to I8.
 *  The I32 input and the I8 output share the same buffer `c`: the I8 row `row_idx` is written
 *  at `c_i8 + row_idx * column_count`, which (rows being processed in increasing order and I8
 *  being 4x narrower) always lands at or before the I32 data it was derived from.
 *
 *  @param c                In/out buffer: I32 matrix on entry, densely-packed I8 matrix on exit.
 *  @param row_count        Number of rows in the matrix.
 *  @param column_count     Number of columns in the matrix.
 *  @param c_stride         Row stride of the I32 input, in bytes.
 *  @param a_squared_norms  Per-row squared norms (I32), one per row.
 *  @param b_squared_norms  Per-column squared norms (I32), one per column.
 */
NK_PUBLIC void nk_dots_compact_i8_sapphireamx( //
    void *c, nk_size_t row_count, nk_size_t column_count, nk_size_t c_stride, nk_i32_t const *a_squared_norms,
    nk_i32_t const *b_squared_norms) {

    nk_size_t const c_stride_i32 = c_stride / sizeof(nk_i32_t);
    nk_i32_t const *c_i32 = (nk_i32_t const *)c;
    nk_i8_t *c_i8 = (nk_i8_t *)c;

    // Use space after I8 output for precomputed b_rsqrt (I8 output is 4x smaller than I32 input).
    // NOTE(review): this scratch region spans bytes [row_count*column_count,
    // row_count*column_count + 4*column_count). For small row counts that range can overlap
    // I32 input rows that have not been read yet — presumably the caller guarantees either
    // enough rows or extra tail space in `c`; TODO confirm against callers.
    nk_f32_t *b_rsqrt = (nk_f32_t *)(c_i8 + row_count * column_count);

    // Precompute rsqrt of all b_norms using AVX512 (16 at a time)
    __m512 half_vec = _mm512_set1_ps(0.5f);
    __m512 three_halves_vec = _mm512_set1_ps(1.5f);
    nk_size_t column_idx = 0;

    for (; column_idx + 16 <= column_count; column_idx += 16) {
        __m512i b_norms_i32 = _mm512_loadu_si512(b_squared_norms + column_idx);
        __m512 b_norms_f32 = _mm512_cvtepi32_ps(b_norms_i32);
        __m512 rsqrt_vec = _mm512_rsqrt14_ps(b_norms_f32);
        // One Newton-Raphson refinement step: y' = y * (1.5 - 0.5 * x * y * y),
        // lifting the 14-bit `rsqrt14` approximation to near-full F32 precision.
        rsqrt_vec = _mm512_mul_ps(
            rsqrt_vec,
            _mm512_sub_ps(three_halves_vec,
                          _mm512_mul_ps(half_vec, _mm512_mul_ps(b_norms_f32, _mm512_mul_ps(rsqrt_vec, rsqrt_vec)))));
        // Zero out rsqrt where norm was zero (avoids propagating inf into the products below)
        __mmask16 nonzero_mask = _mm512_cmpneq_epi32_mask(b_norms_i32, _mm512_setzero_si512());
        rsqrt_vec = _mm512_maskz_mov_ps(nonzero_mask, rsqrt_vec);
        _mm512_storeu_ps(b_rsqrt + column_idx, rsqrt_vec);
    }

    // Handle remaining b_norms (tail of < 16 columns) with masked operations
    if (column_idx < column_count) {
        __mmask16 tail_mask = (__mmask16)((1u << (column_count - column_idx)) - 1);
        __m512i b_norms_i32 = _mm512_maskz_loadu_epi32(tail_mask, b_squared_norms + column_idx);
        __m512 b_norms_f32 = _mm512_cvtepi32_ps(b_norms_i32);
        __m512 rsqrt_vec = _mm512_rsqrt14_ps(b_norms_f32);
        // Same Newton-Raphson refinement as in the full-width loop above
        rsqrt_vec = _mm512_mul_ps(
            rsqrt_vec,
            _mm512_sub_ps(three_halves_vec,
                          _mm512_mul_ps(half_vec, _mm512_mul_ps(b_norms_f32, _mm512_mul_ps(rsqrt_vec, rsqrt_vec)))));
        __mmask16 nonzero_mask = _mm512_cmpneq_epi32_mask(b_norms_i32, _mm512_setzero_si512());
        rsqrt_vec = _mm512_maskz_mov_ps(nonzero_mask & tail_mask, rsqrt_vec);
        _mm512_mask_storeu_ps(b_rsqrt + column_idx, tail_mask, rsqrt_vec);
    }

    __m512 scale_vec = _mm512_set1_ps(127.0f);

    for (nk_size_t row_idx = 0; row_idx < row_count; row_idx++) {
        nk_i32_t const *src_row = c_i32 + row_idx * c_stride_i32;
        nk_i8_t *dst_row = c_i8 + row_idx * column_count;

        // Compute rsqrt of a_norm for this row, broadcast to vector
        nk_f32_t a_norm_f32 = (nk_f32_t)a_squared_norms[row_idx];
        nk_f32_t a_rsqrt_val = 0.0f; // stays 0 for zero-norm rows, zeroing the whole output row
        if (a_norm_f32 > 0.0f) {
            __m128 a_vec = _mm_set_ss(a_norm_f32);
            __m128 rsqrt_s = _mm_rsqrt_ss(a_vec);
            // Scalar Newton-Raphson step, mirroring the vector refinement above
            rsqrt_s = _mm_mul_ss(
                rsqrt_s, _mm_sub_ss(_mm_set_ss(1.5f),
                                    _mm_mul_ss(_mm_set_ss(0.5f), _mm_mul_ss(a_vec, _mm_mul_ss(rsqrt_s, rsqrt_s)))));
            a_rsqrt_val = _mm_cvtss_f32(rsqrt_s);
        }
        __m512 a_rsqrt_vec = _mm512_set1_ps(a_rsqrt_val);
        // Fold the 127 quantization factor into the per-row scale once
        __m512 row_scale = _mm512_mul_ps(a_rsqrt_vec, scale_vec);

        column_idx = 0;

        // Process 16 elements at a time
        for (; column_idx + 16 <= column_count; column_idx += 16) {
            __m512i c_vals = _mm512_loadu_si512(src_row + column_idx);
            __m512 c_f32 = _mm512_cvtepi32_ps(c_vals);
            __m512 b_rsqrt_vec = _mm512_loadu_ps(b_rsqrt + column_idx);
            __m512 normalized = _mm512_mul_ps(_mm512_mul_ps(c_f32, row_scale), b_rsqrt_vec);
            __m512i result_i32 = _mm512_cvtps_epi32(normalized);
            // Saturating pack I32 → I8 (16 values → 16 bytes in low 128 bits)
            __m128i result_i8 = _mm512_cvtsepi32_epi8(result_i32);
            _mm_storeu_si128((__m128i *)(dst_row + column_idx), result_i8);
        }

        // Handle remaining elements with masked operations
        if (column_idx < column_count) {
            __mmask16 tail_mask = (__mmask16)((1u << (column_count - column_idx)) - 1);
            __m512i c_vals = _mm512_maskz_loadu_epi32(tail_mask, src_row + column_idx);
            __m512 c_f32 = _mm512_cvtepi32_ps(c_vals);
            __m512 b_rsqrt_vec = _mm512_maskz_loadu_ps(tail_mask, b_rsqrt + column_idx);
            __m512 normalized = _mm512_mul_ps(_mm512_mul_ps(c_f32, row_scale), b_rsqrt_vec);
            __m512i result_i32 = _mm512_cvtps_epi32(normalized);
            __m128i result_i8 = _mm512_cvtsepi32_epi8(result_i32);
            _mm_mask_storeu_epi8(dst_row + column_idx, tail_mask, result_i8);
        }
    }
}
1912
+
1913
/**
 *  All-pairs I8 dot products among the rows of `vectors`, using AMX tiles.
 *
 *  Produces `result[i][j] = dot(vectors[i], vectors[j])` for rows i in
 *  [row_start, row_end) against ALL j in [0, n_vectors) — row slicing lets
 *  callers parallelize over row bands while every band still fills full rows.
 *
 *  The depth dimension is consumed in groups of 192 (3 tiles × 64 elements)
 *  so that a single update call can feed three A/B tile pairs at once.
 *  On diagonal tiles (row_tile == col_tile) the already-loaded A tiles are
 *  transposed in place of re-reading B from memory.
 *
 *  @param vectors        Row-major I8 matrix, one vector per row.
 *  @param n_vectors      Number of vectors (rows).
 *  @param depth          Elements per vector.
 *  @param stride         Row stride of `vectors`, in bytes (== elements for I8).
 *  @param result         Output I32 matrix.
 *  @param result_stride  Row stride of `result`, in bytes.
 *  @param row_start      First row of the slice to compute.
 *  @param row_count      Rows in the slice; 0 means "through the last row".
 */
NK_PUBLIC void nk_dots_symmetric_i8_sapphireamx( //
    nk_i8_t const *vectors, nk_size_t n_vectors, nk_size_t depth, //
    nk_size_t stride, nk_i32_t *result, nk_size_t result_stride, //
    nk_size_t row_start, nk_size_t row_count) {

    nk_size_t const result_stride_elements = result_stride / sizeof(nk_i32_t);

    // Handle row slicing: compute rows [row_start, row_end), clamped to n_vectors
    nk_size_t const row_end = (row_count == 0)
                                  ? n_vectors
                                  : (row_start + row_count < n_vectors ? row_start + row_count : n_vectors);

    // Round depth up to multiple of 192 (3 tiles × 64 elements)
    nk_size_t const depth_tiles = nk_size_divide_round_up_(depth, 64);
    nk_size_t const depth_tile_groups = nk_size_divide_round_up_(depth_tiles, 3);

    nk_dots_i8_a16x64_sapphireamx_t a_tiles[3];     // current 16-row × 64-depth A slabs
    nk_dots_i8_a16x64_sapphireamx_t b_src_tiles[3]; // staging for B rows before transposition
    nk_dots_i8_b64x16_sapphireamx_t b_tiles[3];     // transposed/packed B operands
    nk_dots_i8_state_sapphireamx_t state;           // 16×16 I32 accumulator block

    nk_amx_tile_configure_sapphireamx_();

    for (nk_size_t row_tile = row_start; row_tile < row_end; row_tile += 16) {
        nk_size_t const valid_rows = (row_tile + 16 <= row_end) ? 16 : (row_end - row_tile);

        for (nk_size_t col_tile = 0; col_tile < n_vectors; col_tile += 16) {
            nk_size_t const valid_cols = (col_tile + 16 <= n_vectors) ? 16 : (n_vectors - col_tile);

            nk_dots_i8_init_sapphireamx_(&state);

            for (nk_size_t depth_group_idx = 0; depth_group_idx < depth_tile_groups; depth_group_idx++) {
                nk_size_t const depth_base = depth_group_idx * 192;

                for (int tile_idx = 0; tile_idx < 3; tile_idx++) {
                    nk_size_t const depth_start = depth_base + tile_idx * 64;
                    // valid_depth may be 0 for trailing tiles of the last group;
                    // presumably the load helper zero-fills in that case — TODO confirm.
                    nk_size_t const valid_depth = (depth_start + 64 <= depth)
                                                      ? 64
                                                      : (depth > depth_start ? depth - depth_start : 0);

                    nk_dots_i8_load_a_sapphireamx_( //
                        &a_tiles[tile_idx], //
                        vectors + row_tile * stride + depth_start, //
                        stride, valid_rows, valid_depth);

                    if (row_tile == col_tile) {
                        // Diagonal tile: reuse the A rows as B by transposing in registers
                        nk_dots_pack_i8_transposed_sapphireamx_(&a_tiles[tile_idx], &b_tiles[tile_idx]);
                    }
                    else {
                        nk_dots_i8_load_a_sapphireamx_( //
                            &b_src_tiles[tile_idx], //
                            vectors + col_tile * stride + depth_start, //
                            stride, valid_cols, valid_depth);
                        nk_dots_pack_i8_transposed_sapphireamx_(&b_src_tiles[tile_idx], &b_tiles[tile_idx]);
                    }
                }

                // Accumulate all three depth slabs into the 16×16 state at once
                nk_dots_i8_update_sapphireamx_( //
                    &state, &a_tiles[0], &a_tiles[1], &a_tiles[2], &b_tiles[0], &b_tiles[1], &b_tiles[2]);
            }

            // Masked store: only valid_rows × valid_cols of the tile reach `result`
            nk_dots_i8_store_sapphireamx_( //
                &state, result + row_tile * result_stride_elements + col_tile, //
                result_stride_elements, valid_rows, valid_cols);
        }
    }
}
1980
+
1981
+ #pragma endregion // Signed Integers
1982
+
1983
+ #pragma region Unsigned Integers
1984
+
1985
+ NK_PUBLIC nk_size_t nk_dots_packed_size_u8_sapphireamx(nk_size_t column_count, nk_size_t depth) {
1986
+ // Same layout as I8 - just different type interpretation
1987
+ return nk_dots_packed_size_i8_sapphireamx(column_count, depth);
1988
+ }
1989
+
1990
/**
 *  Pre-pack a U8 B-matrix into the AMX tile layout consumed by `nk_dots_packed_u8_sapphireamx`.
 *
 *  Output buffer layout (offsets recorded in the leading header):
 *    1. `nk_dots_amx_packed_header_t` header;
 *    2. full 16-column tiles, stored per column-tile then per depth-tile, each
 *       1024 bytes with the quad-interleaved element order TDPBUUD expects;
 *    3. "column edge": the trailing (< 16) columns copied row-major, unpacked;
 *    4. per-column U32 squared norms for all `column_count` columns.
 *
 *  @param b             Source matrix: one output column per row, `depth` elements each.
 *  @param column_count  Number of B columns (rows of `b`).
 *  @param depth         Elements per column.
 *  @param b_stride      Row stride of `b`, in bytes (== elements for U8).
 *  @param b_packed      Destination buffer, sized via `nk_dots_packed_size_u8_sapphireamx`.
 */
NK_PUBLIC void nk_dots_pack_u8_sapphireamx( //
    nk_u8_t const *b, nk_size_t column_count, nk_size_t depth, //
    nk_size_t b_stride, void *b_packed) {

    nk_size_t const tmm_rows = 16;   // rows per AMX tile register
    nk_size_t const tmm_cols = 64;   // U8 elements per tile row
    nk_size_t const tile_elements = 1024;
    nk_size_t const tile_bytes = tile_elements * sizeof(nk_u8_t);

    nk_size_t const column_tiles_count = column_count / tmm_rows;
    nk_size_t const depth_tiles_count = nk_size_divide_round_up_(depth, tmm_cols);
    nk_size_t const column_remainder_count = column_count - column_tiles_count * tmm_rows;
    nk_size_t const total_tiles = column_tiles_count * depth_tiles_count;

    nk_dots_amx_packed_header_t *header = (nk_dots_amx_packed_header_t *)b_packed;
    header->full_column_tiles = (nk_u32_t)column_tiles_count;
    header->full_depth_tiles = (nk_u32_t)depth_tiles_count;
    header->column_remainder_count = (nk_u32_t)column_remainder_count;

    nk_size_t const tiles_offset = sizeof(nk_dots_amx_packed_header_t);
    nk_size_t const column_edge_offset = tiles_offset + total_tiles * tile_bytes;
    header->column_edge_offset = (nk_u32_t)column_edge_offset;

    nk_u8_t *tiles_ptr = (nk_u8_t *)((char *)b_packed + tiles_offset);
    nk_u8_t *column_edge_ptr = (nk_u8_t *)((char *)b_packed + column_edge_offset);

    // Zero the whole tile region first so partial depth tiles are implicitly padded
    for (nk_size_t idx = 0; idx < total_tiles * tile_elements; idx++) tiles_ptr[idx] = 0;

    for (nk_size_t column_tile_idx = 0; column_tile_idx < column_tiles_count; column_tile_idx++) {
        for (nk_size_t depth_tile_idx = 0; depth_tile_idx < depth_tiles_count; depth_tile_idx++) {

            // Tiles are laid out depth-major within each column tile
            nk_size_t const tile_index = column_tile_idx * depth_tiles_count + depth_tile_idx;
            nk_u8_t *tile_output = tiles_ptr + tile_index * tile_elements;

            nk_size_t const src_row_start = column_tile_idx * tmm_rows;
            nk_size_t const src_column_start = depth_tile_idx * tmm_cols;
            // Clamp the last depth tile to the real depth; the rest stays zero-padded
            nk_size_t const columns_to_pack = (src_column_start + tmm_cols <= depth) ? tmm_cols
                                                                                     : (depth - src_column_start);

            // Pack with quad-interleaving as required by TDPBUUD instruction:
            // destination byte = (depth/4 group) * 64 + source_row * 4 + (depth % 4).
            for (nk_size_t row_idx = 0; row_idx < tmm_rows; row_idx++) {
                for (nk_size_t column_idx = 0; column_idx < columns_to_pack; column_idx++) {
                    nk_size_t const src_idx = (src_row_start + row_idx) * b_stride + src_column_start + column_idx;
                    nk_size_t const dst_idx = (column_idx / 4) * 64 + row_idx * 4 + (column_idx % 4);
                    tile_output[dst_idx] = b[src_idx];
                }
            }
        }
    }

    // Trailing (< 16) columns are kept row-major and re-packed at query time
    if (column_remainder_count > 0) {
        nk_size_t const remainder_start_row = column_tiles_count * tmm_rows;
        for (nk_size_t row_idx = 0; row_idx < column_remainder_count; row_idx++) {
            for (nk_size_t column_idx = 0; column_idx < depth; column_idx++) {
                column_edge_ptr[row_idx * depth + column_idx] =
                    b[(remainder_start_row + row_idx) * b_stride + column_idx];
            }
        }
    }

    // Compute and store per-column norms for angular/euclidean distance.
    // NOTE(review): norms_offset is column_edge_offset plus a byte count that is
    // not necessarily a multiple of 4, so the nk_u32_t `norms` pointer may be
    // unaligned when column_remainder_count * depth % 4 != 0 — verify the x86
    // unaligned-store assumption (or padding in the size function) is intended.
    nk_size_t norms_offset = column_edge_offset +
                             (column_remainder_count > 0 ? column_remainder_count * depth * sizeof(nk_u8_t) : 0);
    header->norms_byte_offset = (nk_u32_t)norms_offset;
    nk_u32_t *norms = (nk_u32_t *)((char *)b_packed + norms_offset);
    for (nk_size_t col = 0; col < column_count; col++) norms[col] = nk_dots_reduce_sumsq_u8_(b + col * b_stride, depth);
}
2057
+
2058
+ NK_PUBLIC void nk_dots_packed_u8_sapphireamx( //
2059
+ nk_u8_t const *a, void const *b_packed, nk_u32_t *c, //
2060
+ nk_size_t rows_count, nk_size_t cols_count, nk_size_t depth, nk_size_t a_stride_bytes, nk_size_t c_stride_bytes) {
2061
+ nk_unused_(cols_count);
2062
+
2063
+ // Parse packed B header
2064
+ nk_dots_amx_packed_header_t const *header = (nk_dots_amx_packed_header_t const *)b_packed;
2065
+ nk_size_t const column_tiles_count = header->full_column_tiles;
2066
+ nk_size_t const depth_tiles_count = header->full_depth_tiles;
2067
+ nk_size_t const column_remainder_count = header->column_remainder_count;
2068
+
2069
+ // Packed B data regions
2070
+ nk_u8_t const *b_tiles_base = (nk_u8_t const *)((char const *)b_packed + sizeof(nk_dots_amx_packed_header_t));
2071
+ nk_u8_t const *col_edge_ptr = (nk_u8_t const *)((char const *)b_packed + header->column_edge_offset);
2072
+
2073
+ // Stride conversions
2074
+ nk_size_t const c_stride_elements = c_stride_bytes / sizeof(nk_u32_t);
2075
+
2076
+ // Tile dimensions
2077
+ nk_size_t const tile_depth = 64; // depth elements per U8 tile
2078
+ nk_size_t const tile_size = 1024; // bytes per packed tile
2079
+ nk_size_t const full_cols = column_tiles_count * 16;
2080
+
2081
+ // Block counts (32 × 32 output blocks = 2 × 2 tiles)
2082
+ nk_size_t const row_blocks_count = nk_size_divide_round_up_(rows_count, 32);
2083
+ nk_size_t const col_blocks_count = column_tiles_count / 2;
2084
+
2085
+ if (depth_tiles_count == 0) return;
2086
+
2087
+ // Tile buffers for A (only used for edge tiles)
2088
+ nk_dots_u8_a16x64_sapphireamx_t a_tile_upper, a_tile_lower;
2089
+ nk_dots_u8_state2x2_sapphireamx_t c_accum_buffer;
2090
+
2091
+ // Precompute: number of full depth-tiles
2092
+ nk_size_t const full_depth_tiles_count = depth / tile_depth;
2093
+ nk_size_t const depth_remainder = depth % tile_depth;
2094
+
2095
+ nk_amx_tile_configure_sapphireamx_();
2096
+
2097
+ // Process all 32 × 32 row × column blocks
2098
+ for (nk_size_t row_block_idx = 0; row_block_idx < row_blocks_count; row_block_idx++) {
2099
+ nk_size_t const row_block_start = row_block_idx * 32;
2100
+ nk_size_t const valid_rows_count = (row_block_start + 32 <= rows_count) ? 32 : (rows_count - row_block_start);
2101
+ nk_size_t const is_full_row_block = (valid_rows_count == 32);
2102
+
2103
+ // Process full column-blocks (pairs of 16-column tiles = 32 columns)
2104
+ for (nk_size_t column_block_idx = 0; column_block_idx < col_blocks_count; column_block_idx++) {
2105
+ nk_size_t const col_block_start = column_block_idx * 32;
2106
+
2107
+ // B tile base indices
2108
+ nk_size_t const b_column_left_base = (column_block_idx * 2) * depth_tiles_count;
2109
+ nk_size_t const b_column_right_base = (column_block_idx * 2 + 1) * depth_tiles_count;
2110
+
2111
+ // Zero accumulators (TMM4-7 stay resident across entire depth loop)
2112
+ _tile_zero(4);
2113
+ _tile_zero(5);
2114
+ _tile_zero(6);
2115
+ _tile_zero(7);
2116
+
2117
+ // Fast path: full row-block with full depth-tiles → direct A load with 2-deep pipelining
2118
+ if (is_full_row_block && full_depth_tiles_count > 0) {
2119
+ nk_u8_t const *a_upper_base = a + row_block_start * a_stride_bytes;
2120
+ nk_u8_t const *a_lower_base = a + (row_block_start + 16) * a_stride_bytes;
2121
+
2122
+ nk_dots_u8_b64x16_sapphireamx_t const *b_tile_left =
2123
+ (nk_dots_u8_b64x16_sapphireamx_t const *)(b_tiles_base + b_column_left_base * tile_size);
2124
+ nk_dots_u8_b64x16_sapphireamx_t const *b_tile_right =
2125
+ (nk_dots_u8_b64x16_sapphireamx_t const *)(b_tiles_base + b_column_right_base * tile_size);
2126
+
2127
+ // Prologue: load first depth tile into TMM0-3
2128
+ _tile_loadd(0, a_upper_base, a_stride_bytes);
2129
+ _tile_loadd(1, a_lower_base, a_stride_bytes);
2130
+ _tile_loadd(2, b_tile_left->data, 64);
2131
+ _tile_loadd(3, b_tile_right->data, 64);
2132
+
2133
+ // Main loop: 2-deep software pipelining (compute current while loading next)
2134
+ for (nk_size_t depth_tile_idx = 0; depth_tile_idx < full_depth_tiles_count - 1; depth_tile_idx++) {
2135
+ nk_size_t const next_depth_offset = (depth_tile_idx + 1) * tile_depth;
2136
+
2137
+ _tile_dpbuud(4, 0, 2);
2138
+ _tile_dpbuud(5, 0, 3);
2139
+ _tile_dpbuud(6, 1, 2);
2140
+ _tile_dpbuud(7, 1, 3);
2141
+
2142
+ _tile_loadd(0, a_upper_base + next_depth_offset, a_stride_bytes);
2143
+ _tile_loadd(1, a_lower_base + next_depth_offset, a_stride_bytes);
2144
+ b_tile_left = (nk_dots_u8_b64x16_sapphireamx_t const *)(b_tiles_base +
2145
+ (b_column_left_base + depth_tile_idx + 1) *
2146
+ tile_size);
2147
+ b_tile_right = (nk_dots_u8_b64x16_sapphireamx_t const *)(b_tiles_base + (b_column_right_base +
2148
+ depth_tile_idx + 1) *
2149
+ tile_size);
2150
+ _tile_loadd(2, b_tile_left->data, 64);
2151
+ _tile_loadd(3, b_tile_right->data, 64);
2152
+ }
2153
+
2154
+ // Epilogue: final depth tile (no next to load)
2155
+ _tile_dpbuud(4, 0, 2);
2156
+ _tile_dpbuud(5, 0, 3);
2157
+ _tile_dpbuud(6, 1, 2);
2158
+ _tile_dpbuud(7, 1, 3);
2159
+
2160
+ // Handle partial depth-tile (if any) with buffered load
2161
+ if (depth_remainder > 0) {
2162
+ nk_size_t const depth_offset = full_depth_tiles_count * tile_depth;
2163
+
2164
+ nk_dots_u8_load_a_sapphireamx_(&a_tile_upper, a_upper_base + depth_offset, a_stride_bytes, 16,
2165
+ depth_remainder);
2166
+ nk_dots_u8_load_a_sapphireamx_(&a_tile_lower, a_lower_base + depth_offset, a_stride_bytes, 16,
2167
+ depth_remainder);
2168
+
2169
+ b_tile_left = (nk_dots_u8_b64x16_sapphireamx_t const *)(b_tiles_base + (b_column_left_base +
2170
+ full_depth_tiles_count) *
2171
+ tile_size);
2172
+ b_tile_right = (nk_dots_u8_b64x16_sapphireamx_t const *)(b_tiles_base + (b_column_right_base +
2173
+ full_depth_tiles_count) *
2174
+ tile_size);
2175
+
2176
+ _tile_loadd(0, a_tile_upper.data, 64);
2177
+ _tile_loadd(1, a_tile_lower.data, 64);
2178
+ _tile_loadd(2, b_tile_left->data, 64);
2179
+ _tile_loadd(3, b_tile_right->data, 64);
2180
+
2181
+ _tile_dpbuud(4, 0, 2);
2182
+ _tile_dpbuud(5, 0, 3);
2183
+ _tile_dpbuud(6, 1, 2);
2184
+ _tile_dpbuud(7, 1, 3);
2185
+ }
2186
+ }
2187
+ // Full row-block but only partial depth tile (depth < tile_depth)
2188
+ else if (is_full_row_block) {
2189
+ nk_u8_t const *a_upper_base = a + row_block_start * a_stride_bytes;
2190
+ nk_u8_t const *a_lower_base = a + (row_block_start + 16) * a_stride_bytes;
2191
+
2192
+ nk_dots_u8_load_a_sapphireamx_(&a_tile_upper, a_upper_base, a_stride_bytes, 16, depth_remainder);
2193
+ nk_dots_u8_load_a_sapphireamx_(&a_tile_lower, a_lower_base, a_stride_bytes, 16, depth_remainder);
2194
+
2195
+ nk_dots_u8_b64x16_sapphireamx_t const *b_tile_left =
2196
+ (nk_dots_u8_b64x16_sapphireamx_t const *)(b_tiles_base + b_column_left_base * tile_size);
2197
+ nk_dots_u8_b64x16_sapphireamx_t const *b_tile_right =
2198
+ (nk_dots_u8_b64x16_sapphireamx_t const *)(b_tiles_base + b_column_right_base * tile_size);
2199
+
2200
+ _tile_loadd(0, a_tile_upper.data, 64);
2201
+ _tile_loadd(1, a_tile_lower.data, 64);
2202
+ _tile_loadd(2, b_tile_left->data, 64);
2203
+ _tile_loadd(3, b_tile_right->data, 64);
2204
+
2205
+ _tile_dpbuud(4, 0, 2);
2206
+ _tile_dpbuud(5, 0, 3);
2207
+ _tile_dpbuud(6, 1, 2);
2208
+ _tile_dpbuud(7, 1, 3);
2209
+ }
2210
+ // Slow path: edge row-block → always use buffered load
2211
+ else {
2212
+ nk_size_t const rows_in_upper_tile = (valid_rows_count > 16) ? 16 : valid_rows_count;
2213
+ nk_size_t const rows_in_lower_tile = (valid_rows_count > 16) ? valid_rows_count - 16 : 0;
2214
+
2215
+ for (nk_size_t depth_tile_idx = 0; depth_tile_idx < depth_tiles_count; depth_tile_idx++) {
2216
+ nk_size_t const depth_offset = depth_tile_idx * tile_depth;
2217
+ nk_size_t const valid_depth = (depth_tile_idx < full_depth_tiles_count) ? tile_depth
2218
+ : depth_remainder;
2219
+
2220
+ nk_dots_u8_load_a_sapphireamx_(&a_tile_upper, a + row_block_start * a_stride_bytes + depth_offset,
2221
+ a_stride_bytes, rows_in_upper_tile, valid_depth);
2222
+ if (rows_in_lower_tile > 0) {
2223
+ nk_dots_u8_load_a_sapphireamx_(&a_tile_lower,
2224
+ a + (row_block_start + 16) * a_stride_bytes + depth_offset,
2225
+ a_stride_bytes, rows_in_lower_tile, valid_depth);
2226
+ }
2227
+
2228
+ nk_dots_u8_b64x16_sapphireamx_t const *b_tile_left =
2229
+ (nk_dots_u8_b64x16_sapphireamx_t const *)(b_tiles_base +
2230
+ (b_column_left_base + depth_tile_idx) * tile_size);
2231
+ nk_dots_u8_b64x16_sapphireamx_t const *b_tile_right =
2232
+ (nk_dots_u8_b64x16_sapphireamx_t const *)(b_tiles_base +
2233
+ (b_column_right_base + depth_tile_idx) * tile_size);
2234
+
2235
+ _tile_loadd(0, a_tile_upper.data, 64);
2236
+ _tile_loadd(1, a_tile_lower.data, 64);
2237
+ _tile_loadd(2, b_tile_left->data, 64);
2238
+ _tile_loadd(3, b_tile_right->data, 64);
2239
+
2240
+ _tile_dpbuud(4, 0, 2);
2241
+ _tile_dpbuud(5, 0, 3);
2242
+ _tile_dpbuud(6, 1, 2);
2243
+ _tile_dpbuud(7, 1, 3);
2244
+ }
2245
+ }
2246
+
2247
+ // Store accumulators to output (once per output block, not per depth tile)
2248
+ if (is_full_row_block) {
2249
+ nk_u32_t *c_block = c + row_block_start * c_stride_elements + col_block_start;
2250
+ _tile_stored(4, c_block, c_stride_bytes);
2251
+ _tile_stored(5, c_block + 16, c_stride_bytes);
2252
+ _tile_stored(6, (nk_u32_t *)((char *)c_block + 16 * c_stride_bytes), c_stride_bytes);
2253
+ _tile_stored(7, (nk_u32_t *)((char *)c_block + 16 * c_stride_bytes) + 16, c_stride_bytes);
2254
+ }
2255
+ else {
2256
+ _tile_stored(4, c_accum_buffer.c[0][0].data, 64);
2257
+ _tile_stored(5, c_accum_buffer.c[0][1].data, 64);
2258
+ _tile_stored(6, c_accum_buffer.c[1][0].data, 64);
2259
+ _tile_stored(7, c_accum_buffer.c[1][1].data, 64);
2260
+ nk_dots_u8_output2x2_sapphireamx_(&c_accum_buffer,
2261
+ c + row_block_start * c_stride_elements + col_block_start,
2262
+ c_stride_elements, valid_rows_count, 32);
2263
+ }
2264
+ }
2265
+
2266
+ // Handle odd column-tile (single 16-column tile if column_tiles_count is odd)
2267
+ if (column_tiles_count % 2 == 1) {
2268
+ nk_size_t const column_tile_idx = column_tiles_count - 1;
2269
+ nk_size_t const col_start = column_tile_idx * 16;
2270
+ nk_size_t const b_column_base = column_tile_idx * depth_tiles_count;
2271
+ nk_size_t const rows_in_upper_tile = (valid_rows_count > 16) ? 16 : valid_rows_count;
2272
+ nk_size_t const rows_in_lower_tile = (valid_rows_count > 16) ? valid_rows_count - 16 : 0;
2273
+
2274
+ nk_dots_u8_state_sapphireamx_t c_upper_state, c_lower_state;
2275
+
2276
+ _tile_zero(4);
2277
+ _tile_zero(6);
2278
+
2279
+ for (nk_size_t depth_tile_idx = 0; depth_tile_idx < depth_tiles_count; depth_tile_idx++) {
2280
+ nk_size_t const depth_offset = depth_tile_idx * tile_depth;
2281
+ nk_size_t const valid_depth = (depth_tile_idx < full_depth_tiles_count) ? tile_depth : depth_remainder;
2282
+
2283
+ nk_dots_u8_load_a_sapphireamx_(&a_tile_upper, a + row_block_start * a_stride_bytes + depth_offset,
2284
+ a_stride_bytes, rows_in_upper_tile, valid_depth);
2285
+ if (rows_in_lower_tile > 0) {
2286
+ nk_dots_u8_load_a_sapphireamx_(&a_tile_lower,
2287
+ a + (row_block_start + 16) * a_stride_bytes + depth_offset,
2288
+ a_stride_bytes, rows_in_lower_tile, valid_depth);
2289
+ }
2290
+
2291
+ nk_dots_u8_b64x16_sapphireamx_t const *b_tile =
2292
+ (nk_dots_u8_b64x16_sapphireamx_t const *)(b_tiles_base +
2293
+ (b_column_base + depth_tile_idx) * tile_size);
2294
+
2295
+ _tile_loadd(0, a_tile_upper.data, 64);
2296
+ _tile_loadd(1, a_tile_lower.data, 64);
2297
+ _tile_loadd(2, b_tile->data, 64);
2298
+
2299
+ _tile_dpbuud(4, 0, 2);
2300
+ _tile_dpbuud(6, 1, 2);
2301
+ }
2302
+
2303
+ _tile_stored(4, c_upper_state.data, 64);
2304
+ _tile_stored(6, c_lower_state.data, 64);
2305
+
2306
+ nk_dots_u8_store_sapphireamx_(&c_upper_state, c + row_block_start * c_stride_elements + col_start,
2307
+ c_stride_elements, rows_in_upper_tile, 16);
2308
+ if (rows_in_lower_tile > 0) {
2309
+ nk_dots_u8_store_sapphireamx_(&c_lower_state,
2310
+ c + (row_block_start + 16) * c_stride_elements + col_start,
2311
+ c_stride_elements, rows_in_lower_tile, 16);
2312
+ }
2313
+ }
2314
+
2315
+ // Handle column-edge (remaining columns < 16) using AMX with partial tiles
2316
+ if (column_remainder_count > 0) {
2317
+ nk_size_t const rows_in_upper_tile = (valid_rows_count > 16) ? 16 : valid_rows_count;
2318
+ nk_size_t const rows_in_lower_tile = (valid_rows_count > 16) ? valid_rows_count - 16 : 0;
2319
+
2320
+ nk_dots_u8_state_sapphireamx_t c_upper_state, c_lower_state;
2321
+ nk_dots_u8_a16x64_sapphireamx_t b_as_a;
2322
+ nk_dots_u8_b64x16_sapphireamx_t b_tile;
2323
+
2324
+ _tile_zero(4);
2325
+ _tile_zero(6);
2326
+
2327
+ for (nk_size_t depth_tile_idx = 0; depth_tile_idx < depth_tiles_count; depth_tile_idx++) {
2328
+ nk_size_t const depth_offset = depth_tile_idx * tile_depth;
2329
+ nk_size_t const valid_depth = (depth_tile_idx < full_depth_tiles_count) ? tile_depth : depth_remainder;
2330
+
2331
+ nk_dots_u8_load_a_sapphireamx_(&a_tile_upper, a + row_block_start * a_stride_bytes + depth_offset,
2332
+ a_stride_bytes, rows_in_upper_tile, valid_depth);
2333
+ if (rows_in_lower_tile > 0) {
2334
+ nk_dots_u8_load_a_sapphireamx_(&a_tile_lower,
2335
+ a + (row_block_start + 16) * a_stride_bytes + depth_offset,
2336
+ a_stride_bytes, rows_in_lower_tile, valid_depth);
2337
+ }
2338
+
2339
+ nk_dots_u8_load_a_sapphireamx_(&b_as_a, col_edge_ptr + depth_offset, depth, column_remainder_count,
2340
+ valid_depth);
2341
+ nk_dots_pack_u8_transposed_sapphireamx_(&b_as_a, &b_tile);
2342
+
2343
+ _tile_loadd(0, a_tile_upper.data, 64);
2344
+ _tile_loadd(1, a_tile_lower.data, 64);
2345
+ _tile_loadd(2, b_tile.data, 64);
2346
+
2347
+ _tile_dpbuud(4, 0, 2);
2348
+ _tile_dpbuud(6, 1, 2);
2349
+ }
2350
+
2351
+ _tile_stored(4, c_upper_state.data, 64);
2352
+ _tile_stored(6, c_lower_state.data, 64);
2353
+
2354
+ nk_dots_u8_store_sapphireamx_(&c_upper_state, c + row_block_start * c_stride_elements + full_cols,
2355
+ c_stride_elements, rows_in_upper_tile, column_remainder_count);
2356
+ if (rows_in_lower_tile > 0) {
2357
+ nk_dots_u8_store_sapphireamx_(&c_lower_state,
2358
+ c + (row_block_start + 16) * c_stride_elements + full_cols,
2359
+ c_stride_elements, rows_in_lower_tile, column_remainder_count);
2360
+ }
2361
+ }
2362
+ }
2363
+
2364
+ _tile_release();
2365
+ }
2366
+
2367
+ NK_PUBLIC void nk_dots_symmetric_u8_sapphireamx( //
2368
+ nk_u8_t const *vectors, nk_size_t n_vectors, nk_size_t depth, //
2369
+ nk_size_t stride, nk_u32_t *result, nk_size_t result_stride, //
2370
+ nk_size_t row_start, nk_size_t row_count) {
2371
+
2372
+ nk_size_t const result_stride_elements = result_stride / sizeof(nk_u32_t);
2373
+
2374
+ // Handle row slicing: compute rows [row_start, row_end)
2375
+ nk_size_t const row_end = (row_count == 0)
2376
+ ? n_vectors
2377
+ : (row_start + row_count < n_vectors ? row_start + row_count : n_vectors);
2378
+
2379
+ // Round depth up to multiple of 192 (3 tiles × 64 elements)
2380
+ nk_size_t const depth_tiles = nk_size_divide_round_up_(depth, 64);
2381
+ nk_size_t const depth_tile_groups = nk_size_divide_round_up_(depth_tiles, 3);
2382
+
2383
+ nk_dots_u8_a16x64_sapphireamx_t a_tiles[3];
2384
+ nk_dots_u8_a16x64_sapphireamx_t b_src_tiles[3];
2385
+ nk_dots_u8_b64x16_sapphireamx_t b_tiles[3];
2386
+ nk_dots_u8_state_sapphireamx_t state;
2387
+
2388
+ nk_amx_tile_configure_sapphireamx_();
2389
+
2390
+ for (nk_size_t row_tile = row_start; row_tile < row_end; row_tile += 16) {
2391
+ nk_size_t const valid_rows = (row_tile + 16 <= row_end) ? 16 : (row_end - row_tile);
2392
+
2393
+ for (nk_size_t col_tile = 0; col_tile < n_vectors; col_tile += 16) {
2394
+ nk_size_t const valid_cols = (col_tile + 16 <= n_vectors) ? 16 : (n_vectors - col_tile);
2395
+
2396
+ nk_dots_u8_init_sapphireamx_(&state);
2397
+
2398
+ for (nk_size_t depth_group_idx = 0; depth_group_idx < depth_tile_groups; depth_group_idx++) {
2399
+ nk_size_t const depth_base = depth_group_idx * 192;
2400
+
2401
+ for (int tile_idx = 0; tile_idx < 3; tile_idx++) {
2402
+ nk_size_t const depth_start = depth_base + tile_idx * 64;
2403
+ nk_size_t const valid_depth = (depth_start + 64 <= depth)
2404
+ ? 64
2405
+ : (depth > depth_start ? depth - depth_start : 0);
2406
+
2407
+ nk_dots_u8_load_a_sapphireamx_( //
2408
+ &a_tiles[tile_idx], //
2409
+ vectors + row_tile * stride + depth_start, //
2410
+ stride, valid_rows, valid_depth);
2411
+
2412
+ if (row_tile == col_tile) {
2413
+ nk_dots_pack_u8_transposed_sapphireamx_(&a_tiles[tile_idx], &b_tiles[tile_idx]);
2414
+ }
2415
+ else {
2416
+ nk_dots_u8_load_a_sapphireamx_( //
2417
+ &b_src_tiles[tile_idx], //
2418
+ vectors + col_tile * stride + depth_start, //
2419
+ stride, valid_cols, valid_depth);
2420
+ nk_dots_pack_u8_transposed_sapphireamx_(&b_src_tiles[tile_idx], &b_tiles[tile_idx]);
2421
+ }
2422
+ }
2423
+
2424
+ nk_dots_u8_update_sapphireamx_( //
2425
+ &state, &a_tiles[0], &a_tiles[1], &a_tiles[2], &b_tiles[0], &b_tiles[1], &b_tiles[2]);
2426
+ }
2427
+
2428
+ nk_dots_u8_store_sapphireamx_( //
2429
+ &state, result + row_tile * result_stride_elements + col_tile, //
2430
+ result_stride_elements, valid_rows, valid_cols);
2431
+ }
2432
+ }
2433
+ }
2434
+
2435
+ #pragma endregion // Unsigned Integers
2436
+
2437
+ #pragma region Quarter Precision E4M3
2438
+
2439
+ NK_PUBLIC nk_size_t nk_dots_packed_size_e4m3_sapphireamx(nk_size_t column_count, nk_size_t depth) {
2440
+ // FP8 uses BF16 tile layout after conversion (same element count: 32 per row)
2441
+ return nk_dots_packed_size_bf16_sapphireamx(column_count, depth);
2442
+ }
2443
+
2444
/**
 *  @brief  Packs the B matrix of E4M3 (FP8) vectors into the AMX-friendly
 *          BF16 tile layout consumed by `nk_dots_packed_e4m3_sapphireamx`.
 *
 *  @param b             Row-major matrix of E4M3 vectors (one vector per row).
 *  @param column_count  Number of vectors (output columns of the later matmul).
 *  @param depth         Number of E4M3 scalars per vector.
 *  @param b_stride      Distance between consecutive rows of `b`, in bytes
 *                       (identical to element count for 1-byte E4M3).
 *  @param b_packed      Destination buffer, sized by
 *                       `nk_dots_packed_size_e4m3_sapphireamx`.
 *
 *  Layout written: a `nk_dots_amx_packed_header_t`, then full 16×32 BF16
 *  tiles (pair-interleaved for `_tile_dpbf16ps`), then the <16 leftover
 *  columns as plain row-major BF16, then per-column sum-of-squares norms.
 */
NK_PUBLIC void nk_dots_pack_e4m3_sapphireamx(              //
    nk_e4m3_t const *b, nk_size_t column_count, nk_size_t depth, //
    nk_size_t b_stride, void *b_packed) {

    nk_size_t const tmm_rows = 16;
    nk_size_t const tmm_cols = 32; // Same depth granularity as BF16
    nk_size_t const tile_elements = 512;
    nk_size_t const tile_bytes = tile_elements * sizeof(nk_bf16_t);

    nk_size_t const column_tiles_count = column_count / tmm_rows;
    nk_size_t const depth_tiles_count = nk_size_divide_round_up_(depth, tmm_cols);
    nk_size_t const column_remainder_count = column_count - column_tiles_count * tmm_rows;
    nk_size_t const total_tiles = column_tiles_count * depth_tiles_count;

    // Header records the tiling geometry so the matmul side needs no recompute.
    nk_dots_amx_packed_header_t *header = (nk_dots_amx_packed_header_t *)b_packed;
    header->full_column_tiles = (nk_u32_t)column_tiles_count;
    header->full_depth_tiles = (nk_u32_t)depth_tiles_count;
    header->column_remainder_count = (nk_u32_t)column_remainder_count;

    nk_size_t const tiles_offset = sizeof(nk_dots_amx_packed_header_t);
    nk_size_t const column_edge_offset = tiles_offset + total_tiles * tile_bytes;
    header->column_edge_offset = (nk_u32_t)column_edge_offset;

    nk_bf16_t *tiles_ptr = (nk_bf16_t *)((char *)b_packed + tiles_offset);
    nk_bf16_t *column_edge_ptr = (nk_bf16_t *)((char *)b_packed + column_edge_offset);

    // Zero the tile region so depth-tail lanes contribute nothing to dot products.
    for (nk_size_t idx = 0; idx < total_tiles * tile_elements; idx++) tiles_ptr[idx] = 0;

    for (nk_size_t column_tile_idx = 0; column_tile_idx < column_tiles_count; column_tile_idx++) {
        for (nk_size_t depth_tile_idx = 0; depth_tile_idx < depth_tiles_count; depth_tile_idx++) {
            nk_size_t const tile_index = column_tile_idx * depth_tiles_count + depth_tile_idx;
            nk_bf16_t *tile_output = tiles_ptr + tile_index * tile_elements;

            nk_size_t const src_row_start = column_tile_idx * tmm_rows;
            nk_size_t const src_column_start = depth_tile_idx * tmm_cols;
            nk_size_t const columns_to_pack = (src_column_start + tmm_cols <= depth) ? tmm_cols
                                                                                     : (depth - src_column_start);

            // Convert E4M3 to BF16 and pack with pair-interleaving
            for (nk_size_t row_idx = 0; row_idx < tmm_rows; row_idx++) {
                nk_size_t src_row = src_row_start + row_idx;
                // Load 32 E4M3 bytes and convert to BF16; the mask shift is
                // safe since this branch has columns_to_pack < 32.
                __mmask32 column_mask = (columns_to_pack >= 32) ? 0xFFFFFFFF : ((__mmask32)1 << columns_to_pack) - 1;
                __m256i e4m3_row = _mm256_maskz_loadu_epi8(column_mask, b + src_row * b_stride + src_column_start);
                __m512i bf16_row = nk_e4m3x32_to_bf16x32_icelake_(e4m3_row);
                // Store with pair-interleaving: _tile_dpbf16ps consumes B as
                // depth-pairs, so element (row, col) lands at
                // (col/2)*32 + row*2 + (col%2) inside the 16x32 tile.
                nk_bf16_t bf16_buf[32];
                _mm512_storeu_si512((__m512i *)bf16_buf, bf16_row);
                for (nk_size_t column_idx = 0; column_idx < columns_to_pack; column_idx++) {
                    nk_size_t const dst_idx = (column_idx / 2) * 32 + row_idx * 2 + (column_idx % 2);
                    tile_output[dst_idx] = bf16_buf[column_idx];
                }
            }
        }
    }

    // Pack column-remainder rows (convert E4M3 to BF16); stored row-major,
    // un-interleaved — the matmul side re-packs them on the fly.
    if (column_remainder_count > 0) {
        nk_size_t const remainder_start_row = column_tiles_count * tmm_rows;
        for (nk_size_t row_idx = 0; row_idx < column_remainder_count; row_idx++) {
            for (nk_size_t column_idx = 0; column_idx < depth; column_idx += 32) {
                nk_size_t columns = (column_idx + 32 <= depth) ? 32 : (depth - column_idx);
                __mmask32 column_mask = (columns >= 32) ? 0xFFFFFFFF : ((__mmask32)1 << columns) - 1;
                __m256i e4m3_chunk = _mm256_maskz_loadu_epi8(
                    column_mask, b + (remainder_start_row + row_idx) * b_stride + column_idx);
                __m512i bf16_chunk = nk_e4m3x32_to_bf16x32_icelake_(e4m3_chunk);
                _mm512_mask_storeu_epi16(column_edge_ptr + row_idx * depth + column_idx, column_mask, bf16_chunk);
            }
        }
    }

    // Compute and store per-column norms for angular/euclidean distance.
    // NOTE(review): when `column_remainder_count * depth` is odd, norms_offset
    // is only 2-byte aligned for these 4-byte `nk_f32_t` stores — presumably
    // fine on x86 and/or padded by the size function; TODO confirm.
    nk_size_t norms_offset = column_edge_offset +
                             (column_remainder_count > 0 ? column_remainder_count * depth * sizeof(nk_bf16_t) : 0);
    header->norms_byte_offset = (nk_u32_t)norms_offset;
    nk_f32_t *norms = (nk_f32_t *)((char *)b_packed + norms_offset);
    for (nk_size_t col = 0; col < column_count; col++)
        norms[col] = nk_dots_reduce_sumsq_e4m3_(b + col * b_stride, depth);
}
2523
+
2524
/**
 *  @brief  Multiplies an E4M3 (FP8) A matrix against a pre-packed B
 *          (see `nk_dots_pack_e4m3_sapphireamx`), writing `f32` dot
 *          products, using Sapphire Rapids AMX BF16 tile instructions.
 *
 *  @param a               Row-major E4M3 matrix, `rows_count × depth`.
 *  @param b_packed        Buffer produced by `nk_dots_pack_e4m3_sapphireamx`.
 *  @param c               Output `f32` matrix, `rows_count × cols_count`.
 *  @param rows_count      Rows of A / rows of C.
 *  @param cols_count      Unused here; geometry comes from the packed header.
 *  @param depth           Scalars per vector.
 *  @param a_stride_bytes  Distance between consecutive rows of A, in bytes.
 *  @param c_stride_bytes  Distance between consecutive rows of C, in bytes.
 *
 *  Strategy: 32×32 output blocks held in TMM4-7 across the whole depth loop;
 *  A rows are converted E4M3→BF16 into staging buffers per 32-deep slice.
 *  Odd 16-column tile and <16-column edge are handled by dedicated tails.
 */
NK_PUBLIC void nk_dots_packed_e4m3_sapphireamx(     //
    nk_e4m3_t const *a, void const *b_packed, nk_f32_t *c, //
    nk_size_t rows_count, nk_size_t cols_count, nk_size_t depth, nk_size_t a_stride_bytes, nk_size_t c_stride_bytes) {
    nk_unused_(cols_count);

    nk_dots_amx_packed_header_t const *header = (nk_dots_amx_packed_header_t const *)b_packed;
    nk_size_t const column_tiles_count = header->full_column_tiles;
    nk_size_t const depth_tiles_count = header->full_depth_tiles;
    nk_size_t const column_remainder_count = header->column_remainder_count;

    // B tiles are already in BF16 format
    nk_bf16_t const *b_tiles_base = (nk_bf16_t const *)((char const *)b_packed + sizeof(nk_dots_amx_packed_header_t));
    nk_bf16_t const *col_edge_ptr = (nk_bf16_t const *)((char const *)b_packed + header->column_edge_offset);

    nk_size_t const c_stride_elements = c_stride_bytes / sizeof(nk_f32_t);
    nk_size_t const tile_depth = 32;  // BF16 scalars consumed per depth tile
    nk_size_t const tile_size = 512;  // BF16 elements per packed 16x32 tile
    nk_size_t const full_cols = column_tiles_count * 16;

    nk_size_t const row_blocks_count = nk_size_divide_round_up_(rows_count, 32);
    nk_size_t const col_blocks_count = column_tiles_count / 2;

    if (depth_tiles_count == 0) return;

    nk_dots_bf16_a16x32_sapphireamx_t a_tile_upper, a_tile_lower;
    nk_dots_bf16_state2x2_sapphireamx_t c_accum_buffer;

    nk_size_t const full_depth_tiles_count = depth / tile_depth;
    nk_size_t const depth_remainder = depth % tile_depth;

    nk_amx_tile_configure_sapphireamx_();

    // Loop order: row_blocks outer, col_blocks inner
    for (nk_size_t row_block_idx = 0; row_block_idx < row_blocks_count; row_block_idx++) {
        nk_size_t const row_block_start = row_block_idx * 32;
        nk_size_t const valid_rows_count = (row_block_start + 32 <= rows_count) ? 32 : (rows_count - row_block_start);
        nk_size_t const is_full_row_block = (valid_rows_count == 32);
        nk_size_t const rows_in_upper_tile = (valid_rows_count > 16) ? 16 : valid_rows_count;
        nk_size_t const rows_in_lower_tile = (valid_rows_count > 16) ? valid_rows_count - 16 : 0;

        for (nk_size_t column_block_idx = 0; column_block_idx < col_blocks_count; column_block_idx++) {
            nk_size_t const col_block_start = column_block_idx * 32;
            nk_size_t const b_column_left_base = (column_block_idx * 2) * depth_tiles_count;
            nk_size_t const b_column_right_base = (column_block_idx * 2 + 1) * depth_tiles_count;

            // Zero accumulators (TMM4-7 stay resident across entire depth loop)
            _tile_zero(4);
            _tile_zero(5);
            _tile_zero(6);
            _tile_zero(7);

            // FP8 always uses buffered load for E4M3 → BF16 conversion
            for (nk_size_t depth_tile_idx = 0; depth_tile_idx < depth_tiles_count; depth_tile_idx++) {
                nk_size_t const depth_offset = depth_tile_idx * tile_depth;
                nk_size_t const valid_depth = (depth_tile_idx < full_depth_tiles_count) ? tile_depth : depth_remainder;

                // Load A with FP8 → BF16 conversion
                nk_dots_e4m3_load_a_sapphireamx_(&a_tile_upper, a + row_block_start * a_stride_bytes + depth_offset,
                                                 a_stride_bytes, rows_in_upper_tile, valid_depth);
                if (rows_in_lower_tile > 0) {
                    nk_dots_e4m3_load_a_sapphireamx_(&a_tile_lower,
                                                     a + (row_block_start + 16) * a_stride_bytes + depth_offset,
                                                     a_stride_bytes, rows_in_lower_tile, valid_depth);
                }
                // NOTE(review): when rows_in_lower_tile == 0, TMM1 below is
                // loaded from an unwritten a_tile_lower; TMM6/7 products are
                // then never stored, so output is unaffected — confirm the
                // staging buffer isn't expected to be zeroed.

                nk_dots_bf16_b32x16_sapphireamx_t const *b_tile_left =
                    (nk_dots_bf16_b32x16_sapphireamx_t const *)(b_tiles_base +
                                                                (b_column_left_base + depth_tile_idx) * tile_size);
                nk_dots_bf16_b32x16_sapphireamx_t const *b_tile_right =
                    (nk_dots_bf16_b32x16_sapphireamx_t const *)(b_tiles_base +
                                                                (b_column_right_base + depth_tile_idx) * tile_size);

                _tile_loadd(0, a_tile_upper.data, 64);
                _tile_loadd(1, a_tile_lower.data, 64);
                _tile_loadd(2, b_tile_left->data, 64);
                _tile_loadd(3, b_tile_right->data, 64);

                // 2x2 product: TMM4..7 = {upper,lower} x {left,right}
                _tile_dpbf16ps(4, 0, 2);
                _tile_dpbf16ps(5, 0, 3);
                _tile_dpbf16ps(6, 1, 2);
                _tile_dpbf16ps(7, 1, 3);
            }

            // Store accumulators to output (once per output block)
            if (is_full_row_block) {
                // Full 32 rows: stream tiles straight into C with its stride.
                nk_f32_t *c_block = c + row_block_start * c_stride_elements + col_block_start;
                _tile_stored(4, c_block, c_stride_bytes);
                _tile_stored(5, c_block + 16, c_stride_bytes);
                _tile_stored(6, (nk_f32_t *)((char *)c_block + 16 * c_stride_bytes), c_stride_bytes);
                _tile_stored(7, (nk_f32_t *)((char *)c_block + 16 * c_stride_bytes) + 16, c_stride_bytes);
            }
            else {
                // Partial rows: bounce through a buffer and copy valid rows only.
                _tile_stored(4, c_accum_buffer.c[0][0].data, 64);
                _tile_stored(5, c_accum_buffer.c[0][1].data, 64);
                _tile_stored(6, c_accum_buffer.c[1][0].data, 64);
                _tile_stored(7, c_accum_buffer.c[1][1].data, 64);
                nk_dots_bf16_output2x2_sapphireamx_(&c_accum_buffer,
                                                    c + row_block_start * c_stride_elements + col_block_start,
                                                    c_stride_elements, valid_rows_count, 32);
            }
        }

        // Handle odd column-tile (single 16-column tile if column_tiles_count is odd)
        if (column_tiles_count % 2 == 1) {
            nk_size_t const column_tile_idx = column_tiles_count - 1;
            nk_size_t const col_start = column_tile_idx * 16;
            nk_size_t const b_column_base = column_tile_idx * depth_tiles_count;

            nk_dots_bf16_state_sapphireamx_t c_upper_state, c_lower_state;
            _tile_zero(4);
            _tile_zero(6);

            for (nk_size_t depth_tile_idx = 0; depth_tile_idx < depth_tiles_count; depth_tile_idx++) {
                nk_size_t const depth_offset = depth_tile_idx * tile_depth;
                nk_size_t const valid_depth = (depth_tile_idx < full_depth_tiles_count) ? tile_depth : depth_remainder;

                nk_dots_e4m3_load_a_sapphireamx_(&a_tile_upper, a + row_block_start * a_stride_bytes + depth_offset,
                                                 a_stride_bytes, rows_in_upper_tile, valid_depth);
                if (rows_in_lower_tile > 0) {
                    nk_dots_e4m3_load_a_sapphireamx_(&a_tile_lower,
                                                     a + (row_block_start + 16) * a_stride_bytes + depth_offset,
                                                     a_stride_bytes, rows_in_lower_tile, valid_depth);
                }

                nk_dots_bf16_b32x16_sapphireamx_t const *b_tile =
                    (nk_dots_bf16_b32x16_sapphireamx_t const *)(b_tiles_base +
                                                                (b_column_base + depth_tile_idx) * tile_size);

                _tile_loadd(0, a_tile_upper.data, 64);
                _tile_loadd(1, a_tile_lower.data, 64);
                _tile_loadd(2, b_tile->data, 64);

                _tile_dpbf16ps(4, 0, 2);
                _tile_dpbf16ps(6, 1, 2);
            }

            _tile_stored(4, c_upper_state.data, 64);
            _tile_stored(6, c_lower_state.data, 64);

            nk_dots_bf16_store_sapphireamx_(&c_upper_state, c + row_block_start * c_stride_elements + col_start,
                                            c_stride_elements, rows_in_upper_tile, 16);
            if (rows_in_lower_tile > 0) {
                nk_dots_bf16_store_sapphireamx_(&c_lower_state,
                                                c + (row_block_start + 16) * c_stride_elements + col_start,
                                                c_stride_elements, rows_in_lower_tile, 16);
            }
        }

        // Handle column-edge (remaining columns < 16) using AMX with partial tiles
        if (column_remainder_count > 0) {
            nk_dots_bf16_state_sapphireamx_t c_upper_state, c_lower_state;
            nk_dots_bf16_a16x32_sapphireamx_t b_as_a;
            nk_dots_bf16_b32x16_sapphireamx_t b_tile;

            _tile_zero(4);
            _tile_zero(6);

            for (nk_size_t depth_tile_idx = 0; depth_tile_idx < depth_tiles_count; depth_tile_idx++) {
                nk_size_t const depth_offset = depth_tile_idx * tile_depth;
                nk_size_t const valid_depth = (depth_tile_idx < full_depth_tiles_count) ? tile_depth : depth_remainder;

                nk_dots_e4m3_load_a_sapphireamx_(&a_tile_upper, a + row_block_start * a_stride_bytes + depth_offset,
                                                 a_stride_bytes, rows_in_upper_tile, valid_depth);
                if (rows_in_lower_tile > 0) {
                    nk_dots_e4m3_load_a_sapphireamx_(&a_tile_lower,
                                                     a + (row_block_start + 16) * a_stride_bytes + depth_offset,
                                                     a_stride_bytes, rows_in_lower_tile, valid_depth);
                }

                // B edge data is already in BF16 format; re-pack it into the
                // pair-interleaved B layout on the fly for this depth slice.
                nk_dots_bf16_load_a_sapphireamx_(&b_as_a, col_edge_ptr + depth_offset, depth, column_remainder_count,
                                                 valid_depth);
                nk_dots_pack_bf16_transposed_sapphireamx_(&b_as_a, &b_tile);

                _tile_loadd(0, a_tile_upper.data, 64);
                _tile_loadd(1, a_tile_lower.data, 64);
                _tile_loadd(2, b_tile.data, 64);

                _tile_dpbf16ps(4, 0, 2);
                _tile_dpbf16ps(6, 1, 2);
            }

            _tile_stored(4, c_upper_state.data, 64);
            _tile_stored(6, c_lower_state.data, 64);

            nk_dots_bf16_store_sapphireamx_(&c_upper_state, c + row_block_start * c_stride_elements + full_cols,
                                            c_stride_elements, rows_in_upper_tile, column_remainder_count);
            if (rows_in_lower_tile > 0) {
                nk_dots_bf16_store_sapphireamx_(&c_lower_state,
                                                c + (row_block_start + 16) * c_stride_elements + full_cols,
                                                c_stride_elements, rows_in_lower_tile, column_remainder_count);
            }
        }
    }

    _tile_release();
}
2721
+
2722
+ #pragma endregion // Quarter Precision E4M3
2723
+
2724
+ #pragma region Quarter Precision E5M2
2725
+
2726
+ NK_PUBLIC nk_size_t nk_dots_packed_size_e5m2_sapphireamx(nk_size_t column_count, nk_size_t depth) {
2727
+ return nk_dots_packed_size_bf16_sapphireamx(column_count, depth);
2728
+ }
2729
+
2730
/**
 *  @brief  Packs the B matrix of E5M2 (FP8) vectors into the AMX-friendly
 *          BF16 tile layout consumed by `nk_dots_packed_e5m2_sapphireamx`.
 *
 *  @param b             Row-major matrix of E5M2 vectors (one vector per row).
 *  @param column_count  Number of vectors (output columns of the later matmul).
 *  @param depth         Number of E5M2 scalars per vector.
 *  @param b_stride      Distance between consecutive rows of `b`, in bytes
 *                       (identical to element count for 1-byte E5M2).
 *  @param b_packed      Destination buffer, sized by
 *                       `nk_dots_packed_size_e5m2_sapphireamx`.
 *
 *  Mirrors `nk_dots_pack_e4m3_sapphireamx`: header, pair-interleaved 16×32
 *  BF16 tiles, row-major BF16 column-edge rows, then per-column norms.
 */
NK_PUBLIC void nk_dots_pack_e5m2_sapphireamx(              //
    nk_e5m2_t const *b, nk_size_t column_count, nk_size_t depth, //
    nk_size_t b_stride, void *b_packed) {

    nk_size_t const tmm_rows = 16;
    nk_size_t const tmm_cols = 32;
    nk_size_t const tile_elements = 512;
    nk_size_t const tile_bytes = tile_elements * sizeof(nk_bf16_t);

    nk_size_t const column_tiles_count = column_count / tmm_rows;
    nk_size_t const depth_tiles_count = nk_size_divide_round_up_(depth, tmm_cols);
    nk_size_t const column_remainder_count = column_count - column_tiles_count * tmm_rows;
    nk_size_t const total_tiles = column_tiles_count * depth_tiles_count;

    // Header records the tiling geometry so the matmul side needs no recompute.
    nk_dots_amx_packed_header_t *header = (nk_dots_amx_packed_header_t *)b_packed;
    header->full_column_tiles = (nk_u32_t)column_tiles_count;
    header->full_depth_tiles = (nk_u32_t)depth_tiles_count;
    header->column_remainder_count = (nk_u32_t)column_remainder_count;

    nk_size_t const tiles_offset = sizeof(nk_dots_amx_packed_header_t);
    nk_size_t const column_edge_offset = tiles_offset + total_tiles * tile_bytes;
    header->column_edge_offset = (nk_u32_t)column_edge_offset;

    nk_bf16_t *tiles_ptr = (nk_bf16_t *)((char *)b_packed + tiles_offset);
    nk_bf16_t *column_edge_ptr = (nk_bf16_t *)((char *)b_packed + column_edge_offset);

    // Zero the tile region so depth-tail lanes contribute nothing to dot products.
    for (nk_size_t idx = 0; idx < total_tiles * tile_elements; idx++) tiles_ptr[idx] = 0;

    for (nk_size_t column_tile_idx = 0; column_tile_idx < column_tiles_count; column_tile_idx++) {
        for (nk_size_t depth_tile_idx = 0; depth_tile_idx < depth_tiles_count; depth_tile_idx++) {
            nk_size_t const tile_index = column_tile_idx * depth_tiles_count + depth_tile_idx;
            nk_bf16_t *tile_output = tiles_ptr + tile_index * tile_elements;

            nk_size_t const src_row_start = column_tile_idx * tmm_rows;
            nk_size_t const src_column_start = depth_tile_idx * tmm_cols;
            nk_size_t const columns_to_pack = (src_column_start + tmm_cols <= depth) ? tmm_cols
                                                                                     : (depth - src_column_start);

            // Convert E5M2 to BF16 and pack with pair-interleaving for
            // `_tile_dpbf16ps`: element (row, col) lands at
            // (col/2)*32 + row*2 + (col%2) inside the 16x32 tile.
            for (nk_size_t row_idx = 0; row_idx < tmm_rows; row_idx++) {
                nk_size_t src_row = src_row_start + row_idx;
                // Mask shift is safe since this branch has columns_to_pack < 32.
                __mmask32 column_mask = (columns_to_pack >= 32) ? 0xFFFFFFFF : ((__mmask32)1 << columns_to_pack) - 1;
                __m256i e5m2_row = _mm256_maskz_loadu_epi8(column_mask, b + src_row * b_stride + src_column_start);
                __m512i bf16_row = nk_e5m2x32_to_bf16x32_icelake_(e5m2_row);
                nk_bf16_t bf16_buf[32];
                _mm512_storeu_si512((__m512i *)bf16_buf, bf16_row);
                for (nk_size_t column_idx = 0; column_idx < columns_to_pack; column_idx++) {
                    nk_size_t const dst_idx = (column_idx / 2) * 32 + row_idx * 2 + (column_idx % 2);
                    tile_output[dst_idx] = bf16_buf[column_idx];
                }
            }
        }
    }

    // Pack column-remainder rows as plain row-major BF16; the matmul side
    // re-packs them into tile layout on the fly.
    if (column_remainder_count > 0) {
        nk_size_t const remainder_start_row = column_tiles_count * tmm_rows;
        for (nk_size_t row_idx = 0; row_idx < column_remainder_count; row_idx++) {
            for (nk_size_t column_idx = 0; column_idx < depth; column_idx += 32) {
                nk_size_t columns = (column_idx + 32 <= depth) ? 32 : (depth - column_idx);
                __mmask32 column_mask = (columns >= 32) ? 0xFFFFFFFF : ((__mmask32)1 << columns) - 1;
                __m256i e5m2_chunk = _mm256_maskz_loadu_epi8(
                    column_mask, b + (remainder_start_row + row_idx) * b_stride + column_idx);
                __m512i bf16_chunk = nk_e5m2x32_to_bf16x32_icelake_(e5m2_chunk);
                _mm512_mask_storeu_epi16(column_edge_ptr + row_idx * depth + column_idx, column_mask, bf16_chunk);
            }
        }
    }

    // Compute and store per-column norms for angular/euclidean distance.
    // NOTE(review): when `column_remainder_count * depth` is odd, norms_offset
    // is only 2-byte aligned for these 4-byte `nk_f32_t` stores — presumably
    // fine on x86 and/or padded by the size function; TODO confirm.
    nk_size_t norms_offset = column_edge_offset +
                             (column_remainder_count > 0 ? column_remainder_count * depth * sizeof(nk_bf16_t) : 0);
    header->norms_byte_offset = (nk_u32_t)norms_offset;
    nk_f32_t *norms = (nk_f32_t *)((char *)b_packed + norms_offset);
    for (nk_size_t col = 0; col < column_count; col++)
        norms[col] = nk_dots_reduce_sumsq_e5m2_(b + col * b_stride, depth);
}
2805
+
2806
/**
 *  @brief  Multiplies an E5M2 (FP8) A matrix against a pre-packed B
 *          (see `nk_dots_pack_e5m2_sapphireamx`), writing `f32` dot
 *          products, using Sapphire Rapids AMX BF16 tile instructions.
 *
 *  @param a               Row-major E5M2 matrix, `rows_count × depth`.
 *  @param b_packed        Buffer produced by `nk_dots_pack_e5m2_sapphireamx`.
 *  @param c               Output `f32` matrix, `rows_count × cols_count`.
 *  @param rows_count      Rows of A / rows of C.
 *  @param cols_count      Unused here; geometry comes from the packed header.
 *  @param depth           Scalars per vector.
 *  @param a_stride_bytes  Distance between consecutive rows of A, in bytes.
 *  @param c_stride_bytes  Distance between consecutive rows of C, in bytes.
 *
 *  Mirrors `nk_dots_packed_e4m3_sapphireamx`: 32×32 output blocks kept in
 *  TMM4-7 across the depth loop, plus odd-column-tile and <16-column tails.
 */
NK_PUBLIC void nk_dots_packed_e5m2_sapphireamx(     //
    nk_e5m2_t const *a, void const *b_packed, nk_f32_t *c, //
    nk_size_t rows_count, nk_size_t cols_count, nk_size_t depth, nk_size_t a_stride_bytes, nk_size_t c_stride_bytes) {
    nk_unused_(cols_count);

    nk_dots_amx_packed_header_t const *header = (nk_dots_amx_packed_header_t const *)b_packed;
    nk_size_t const column_tiles_count = header->full_column_tiles;
    nk_size_t const depth_tiles_count = header->full_depth_tiles;
    nk_size_t const column_remainder_count = header->column_remainder_count;

    // B tiles are already in BF16 format
    nk_bf16_t const *b_tiles_base = (nk_bf16_t const *)((char const *)b_packed + sizeof(nk_dots_amx_packed_header_t));
    nk_bf16_t const *col_edge_ptr = (nk_bf16_t const *)((char const *)b_packed + header->column_edge_offset);

    nk_size_t const c_stride_elements = c_stride_bytes / sizeof(nk_f32_t);
    nk_size_t const tile_depth = 32;  // BF16 scalars consumed per depth tile
    nk_size_t const tile_size = 512;  // BF16 elements per packed 16x32 tile
    nk_size_t const full_cols = column_tiles_count * 16;

    nk_size_t const row_blocks_count = nk_size_divide_round_up_(rows_count, 32);
    nk_size_t const col_blocks_count = column_tiles_count / 2;

    if (depth_tiles_count == 0) return;

    nk_dots_bf16_a16x32_sapphireamx_t a_tile_upper, a_tile_lower;
    nk_dots_bf16_state2x2_sapphireamx_t c_accum_buffer;

    nk_size_t const full_depth_tiles_count = depth / tile_depth;
    nk_size_t const depth_remainder = depth % tile_depth;

    nk_amx_tile_configure_sapphireamx_();

    // Loop order: row_blocks outer, col_blocks inner
    for (nk_size_t row_block_idx = 0; row_block_idx < row_blocks_count; row_block_idx++) {
        nk_size_t const row_block_start = row_block_idx * 32;
        nk_size_t const valid_rows_count = (row_block_start + 32 <= rows_count) ? 32 : (rows_count - row_block_start);
        nk_size_t const is_full_row_block = (valid_rows_count == 32);
        nk_size_t const rows_in_upper_tile = (valid_rows_count > 16) ? 16 : valid_rows_count;
        nk_size_t const rows_in_lower_tile = (valid_rows_count > 16) ? valid_rows_count - 16 : 0;

        for (nk_size_t column_block_idx = 0; column_block_idx < col_blocks_count; column_block_idx++) {
            nk_size_t const col_block_start = column_block_idx * 32;
            nk_size_t const b_column_left_base = (column_block_idx * 2) * depth_tiles_count;
            nk_size_t const b_column_right_base = (column_block_idx * 2 + 1) * depth_tiles_count;

            // Zero accumulators (TMM4-7 stay resident across entire depth loop)
            _tile_zero(4);
            _tile_zero(5);
            _tile_zero(6);
            _tile_zero(7);

            // FP8 always uses buffered load for E5M2 → BF16 conversion
            for (nk_size_t depth_tile_idx = 0; depth_tile_idx < depth_tiles_count; depth_tile_idx++) {
                nk_size_t const depth_offset = depth_tile_idx * tile_depth;
                nk_size_t const valid_depth = (depth_tile_idx < full_depth_tiles_count) ? tile_depth : depth_remainder;

                // Load A with FP8 → BF16 conversion
                nk_dots_e5m2_load_a_sapphireamx_(&a_tile_upper, a + row_block_start * a_stride_bytes + depth_offset,
                                                 a_stride_bytes, rows_in_upper_tile, valid_depth);
                if (rows_in_lower_tile > 0) {
                    nk_dots_e5m2_load_a_sapphireamx_(&a_tile_lower,
                                                     a + (row_block_start + 16) * a_stride_bytes + depth_offset,
                                                     a_stride_bytes, rows_in_lower_tile, valid_depth);
                }
                // NOTE(review): when rows_in_lower_tile == 0, TMM1 below is
                // loaded from an unwritten a_tile_lower; TMM6/7 products are
                // then never stored, so output is unaffected — confirm the
                // staging buffer isn't expected to be zeroed.

                nk_dots_bf16_b32x16_sapphireamx_t const *b_tile_left =
                    (nk_dots_bf16_b32x16_sapphireamx_t const *)(b_tiles_base +
                                                                (b_column_left_base + depth_tile_idx) * tile_size);
                nk_dots_bf16_b32x16_sapphireamx_t const *b_tile_right =
                    (nk_dots_bf16_b32x16_sapphireamx_t const *)(b_tiles_base +
                                                                (b_column_right_base + depth_tile_idx) * tile_size);

                _tile_loadd(0, a_tile_upper.data, 64);
                _tile_loadd(1, a_tile_lower.data, 64);
                _tile_loadd(2, b_tile_left->data, 64);
                _tile_loadd(3, b_tile_right->data, 64);

                // 2x2 product: TMM4..7 = {upper,lower} x {left,right}
                _tile_dpbf16ps(4, 0, 2);
                _tile_dpbf16ps(5, 0, 3);
                _tile_dpbf16ps(6, 1, 2);
                _tile_dpbf16ps(7, 1, 3);
            }

            // Store accumulators to output (once per output block)
            if (is_full_row_block) {
                // Full 32 rows: stream tiles straight into C with its stride.
                nk_f32_t *c_block = c + row_block_start * c_stride_elements + col_block_start;
                _tile_stored(4, c_block, c_stride_bytes);
                _tile_stored(5, c_block + 16, c_stride_bytes);
                _tile_stored(6, (nk_f32_t *)((char *)c_block + 16 * c_stride_bytes), c_stride_bytes);
                _tile_stored(7, (nk_f32_t *)((char *)c_block + 16 * c_stride_bytes) + 16, c_stride_bytes);
            }
            else {
                // Partial rows: bounce through a buffer and copy valid rows only.
                _tile_stored(4, c_accum_buffer.c[0][0].data, 64);
                _tile_stored(5, c_accum_buffer.c[0][1].data, 64);
                _tile_stored(6, c_accum_buffer.c[1][0].data, 64);
                _tile_stored(7, c_accum_buffer.c[1][1].data, 64);
                nk_dots_bf16_output2x2_sapphireamx_(&c_accum_buffer,
                                                    c + row_block_start * c_stride_elements + col_block_start,
                                                    c_stride_elements, valid_rows_count, 32);
            }
        }

        // Handle odd column-tile (single 16-column tile if column_tiles_count is odd)
        if (column_tiles_count % 2 == 1) {
            nk_size_t const column_tile_idx = column_tiles_count - 1;
            nk_size_t const col_start = column_tile_idx * 16;
            nk_size_t const b_column_base = column_tile_idx * depth_tiles_count;

            nk_dots_bf16_state_sapphireamx_t c_upper_state, c_lower_state;
            _tile_zero(4);
            _tile_zero(6);

            for (nk_size_t depth_tile_idx = 0; depth_tile_idx < depth_tiles_count; depth_tile_idx++) {
                nk_size_t const depth_offset = depth_tile_idx * tile_depth;
                nk_size_t const valid_depth = (depth_tile_idx < full_depth_tiles_count) ? tile_depth : depth_remainder;

                nk_dots_e5m2_load_a_sapphireamx_(&a_tile_upper, a + row_block_start * a_stride_bytes + depth_offset,
                                                 a_stride_bytes, rows_in_upper_tile, valid_depth);
                if (rows_in_lower_tile > 0) {
                    nk_dots_e5m2_load_a_sapphireamx_(&a_tile_lower,
                                                     a + (row_block_start + 16) * a_stride_bytes + depth_offset,
                                                     a_stride_bytes, rows_in_lower_tile, valid_depth);
                }

                nk_dots_bf16_b32x16_sapphireamx_t const *b_tile =
                    (nk_dots_bf16_b32x16_sapphireamx_t const *)(b_tiles_base +
                                                                (b_column_base + depth_tile_idx) * tile_size);

                _tile_loadd(0, a_tile_upper.data, 64);
                _tile_loadd(1, a_tile_lower.data, 64);
                _tile_loadd(2, b_tile->data, 64);

                _tile_dpbf16ps(4, 0, 2);
                _tile_dpbf16ps(6, 1, 2);
            }

            _tile_stored(4, c_upper_state.data, 64);
            _tile_stored(6, c_lower_state.data, 64);

            nk_dots_bf16_store_sapphireamx_(&c_upper_state, c + row_block_start * c_stride_elements + col_start,
                                            c_stride_elements, rows_in_upper_tile, 16);
            if (rows_in_lower_tile > 0) {
                nk_dots_bf16_store_sapphireamx_(&c_lower_state,
                                                c + (row_block_start + 16) * c_stride_elements + col_start,
                                                c_stride_elements, rows_in_lower_tile, 16);
            }
        }

        // Handle column-edge (remaining columns < 16) using AMX with partial tiles
        if (column_remainder_count > 0) {
            nk_dots_bf16_state_sapphireamx_t c_upper_state, c_lower_state;
            nk_dots_bf16_a16x32_sapphireamx_t b_as_a;
            nk_dots_bf16_b32x16_sapphireamx_t b_tile;

            _tile_zero(4);
            _tile_zero(6);

            for (nk_size_t depth_tile_idx = 0; depth_tile_idx < depth_tiles_count; depth_tile_idx++) {
                nk_size_t const depth_offset = depth_tile_idx * tile_depth;
                nk_size_t const valid_depth = (depth_tile_idx < full_depth_tiles_count) ? tile_depth : depth_remainder;

                nk_dots_e5m2_load_a_sapphireamx_(&a_tile_upper, a + row_block_start * a_stride_bytes + depth_offset,
                                                 a_stride_bytes, rows_in_upper_tile, valid_depth);
                if (rows_in_lower_tile > 0) {
                    nk_dots_e5m2_load_a_sapphireamx_(&a_tile_lower,
                                                     a + (row_block_start + 16) * a_stride_bytes + depth_offset,
                                                     a_stride_bytes, rows_in_lower_tile, valid_depth);
                }

                // Edge B rows are stored row-major BF16; re-pack them into the
                // pair-interleaved B layout on the fly for this depth slice.
                nk_dots_bf16_load_a_sapphireamx_(&b_as_a, col_edge_ptr + depth_offset, depth, column_remainder_count,
                                                 valid_depth);
                nk_dots_pack_bf16_transposed_sapphireamx_(&b_as_a, &b_tile);

                _tile_loadd(0, a_tile_upper.data, 64);
                _tile_loadd(1, a_tile_lower.data, 64);
                _tile_loadd(2, b_tile.data, 64);

                _tile_dpbf16ps(4, 0, 2);
                _tile_dpbf16ps(6, 1, 2);
            }

            _tile_stored(4, c_upper_state.data, 64);
            _tile_stored(6, c_lower_state.data, 64);

            nk_dots_bf16_store_sapphireamx_(&c_upper_state, c + row_block_start * c_stride_elements + full_cols,
                                            c_stride_elements, rows_in_upper_tile, column_remainder_count);
            if (rows_in_lower_tile > 0) {
                nk_dots_bf16_store_sapphireamx_(&c_lower_state,
                                                c + (row_block_start + 16) * c_stride_elements + full_cols,
                                                c_stride_elements, rows_in_lower_tile, column_remainder_count);
            }
        }
    }

    _tile_release();
}
3001
+
3002
+ NK_PUBLIC void nk_dots_symmetric_e5m2_sapphireamx( //
3003
+ nk_e5m2_t const *vectors, nk_size_t n_vectors, nk_size_t depth, //
3004
+ nk_size_t stride, nk_f32_t *result, nk_size_t result_stride, //
3005
+ nk_size_t row_start, nk_size_t row_count) {
3006
+
3007
+ nk_size_t const result_stride_elements = result_stride / sizeof(nk_f32_t);
3008
+
3009
+ // Handle row slicing: compute rows [row_start, row_end)
3010
+ nk_size_t const row_end = (row_count == 0)
3011
+ ? n_vectors
3012
+ : (row_start + row_count < n_vectors ? row_start + row_count : n_vectors);
3013
+
3014
+ // Round depth up to multiple of 96 (3 tiles × 32 elements)
3015
+ nk_size_t const depth_tiles = nk_size_divide_round_up_(depth, 32);
3016
+ nk_size_t const depth_tile_groups = nk_size_divide_round_up_(depth_tiles, 3);
3017
+
3018
+ nk_dots_bf16_a16x32_sapphireamx_t a_tiles[3];
3019
+ nk_dots_bf16_a16x32_sapphireamx_t b_src_tiles[3];
3020
+ nk_dots_bf16_b32x16_sapphireamx_t b_tiles[3];
3021
+ nk_dots_bf16_state_sapphireamx_t state;
3022
+
3023
+ nk_amx_tile_configure_sapphireamx_();
3024
+
3025
+ for (nk_size_t row_tile = row_start; row_tile < row_end; row_tile += 16) {
3026
+ nk_size_t const valid_rows = (row_tile + 16 <= row_end) ? 16 : (row_end - row_tile);
3027
+
3028
+ for (nk_size_t col_tile = 0; col_tile < n_vectors; col_tile += 16) {
3029
+ nk_size_t const valid_cols = (col_tile + 16 <= n_vectors) ? 16 : (n_vectors - col_tile);
3030
+
3031
+ nk_dots_bf16_init_sapphireamx_(&state);
3032
+
3033
+ for (nk_size_t depth_group_idx = 0; depth_group_idx < depth_tile_groups; depth_group_idx++) {
3034
+ nk_size_t const depth_base = depth_group_idx * 96;
3035
+
3036
+ for (int tile_idx = 0; tile_idx < 3; tile_idx++) {
3037
+ nk_size_t const depth_start = depth_base + tile_idx * 32;
3038
+ nk_size_t const valid_depth = (depth_start + 32 <= depth)
3039
+ ? 32
3040
+ : (depth > depth_start ? depth - depth_start : 0);
3041
+
3042
+ nk_dots_e5m2_load_a_sapphireamx_( //
3043
+ &a_tiles[tile_idx], //
3044
+ vectors + row_tile * stride + depth_start, //
3045
+ stride, valid_rows, valid_depth);
3046
+
3047
+ if (row_tile == col_tile) {
3048
+ nk_dots_pack_bf16_transposed_sapphireamx_(&a_tiles[tile_idx], &b_tiles[tile_idx]);
3049
+ }
3050
+ else {
3051
+ nk_dots_e5m2_load_a_sapphireamx_( //
3052
+ &b_src_tiles[tile_idx], //
3053
+ vectors + col_tile * stride + depth_start, //
3054
+ stride, valid_cols, valid_depth);
3055
+ nk_dots_pack_bf16_transposed_sapphireamx_(&b_src_tiles[tile_idx], &b_tiles[tile_idx]);
3056
+ }
3057
+ }
3058
+
3059
+ nk_dots_bf16_update_sapphireamx_( //
3060
+ &state, &a_tiles[0], &a_tiles[1], &a_tiles[2], &b_tiles[0], &b_tiles[1], &b_tiles[2]);
3061
+ }
3062
+
3063
+ nk_dots_bf16_store_sapphireamx_( //
3064
+ &state, result + row_tile * result_stride_elements + col_tile, //
3065
+ result_stride_elements, valid_rows, valid_cols);
3066
+ }
3067
+ }
3068
+ }
3069
+
3070
+ NK_PUBLIC void nk_dots_symmetric_e4m3_sapphireamx( //
3071
+ nk_e4m3_t const *vectors, nk_size_t n_vectors, nk_size_t depth, //
3072
+ nk_size_t stride, nk_f32_t *result, nk_size_t result_stride, //
3073
+ nk_size_t row_start, nk_size_t row_count) {
3074
+
3075
+ nk_size_t const result_stride_elements = result_stride / sizeof(nk_f32_t);
3076
+
3077
+ // Handle row slicing: compute rows [row_start, row_end)
3078
+ nk_size_t const row_end = (row_count == 0)
3079
+ ? n_vectors
3080
+ : (row_start + row_count < n_vectors ? row_start + row_count : n_vectors);
3081
+
3082
+ // Round depth up to multiple of 96 (3 tiles × 32 elements)
3083
+ nk_size_t const depth_tiles = nk_size_divide_round_up_(depth, 32);
3084
+ nk_size_t const depth_tile_groups = nk_size_divide_round_up_(depth_tiles, 3);
3085
+
3086
+ nk_dots_bf16_a16x32_sapphireamx_t a_tiles[3];
3087
+ nk_dots_bf16_a16x32_sapphireamx_t b_src_tiles[3];
3088
+ nk_dots_bf16_b32x16_sapphireamx_t b_tiles[3];
3089
+ nk_dots_bf16_state_sapphireamx_t state;
3090
+
3091
+ nk_amx_tile_configure_sapphireamx_();
3092
+
3093
+ for (nk_size_t row_tile = row_start; row_tile < row_end; row_tile += 16) {
3094
+ nk_size_t const valid_rows = (row_tile + 16 <= row_end) ? 16 : (row_end - row_tile);
3095
+
3096
+ for (nk_size_t col_tile = 0; col_tile < n_vectors; col_tile += 16) {
3097
+ nk_size_t const valid_cols = (col_tile + 16 <= n_vectors) ? 16 : (n_vectors - col_tile);
3098
+
3099
+ nk_dots_bf16_init_sapphireamx_(&state);
3100
+
3101
+ for (nk_size_t depth_group_idx = 0; depth_group_idx < depth_tile_groups; depth_group_idx++) {
3102
+ nk_size_t const depth_base = depth_group_idx * 96;
3103
+
3104
+ for (int tile_idx = 0; tile_idx < 3; tile_idx++) {
3105
+ nk_size_t const depth_start = depth_base + tile_idx * 32;
3106
+ nk_size_t const valid_depth = (depth_start + 32 <= depth)
3107
+ ? 32
3108
+ : (depth > depth_start ? depth - depth_start : 0);
3109
+
3110
+ nk_dots_e4m3_load_a_sapphireamx_( //
3111
+ &a_tiles[tile_idx], //
3112
+ vectors + row_tile * stride + depth_start, //
3113
+ stride, valid_rows, valid_depth);
3114
+
3115
+ if (row_tile == col_tile) {
3116
+ nk_dots_pack_bf16_transposed_sapphireamx_(&a_tiles[tile_idx], &b_tiles[tile_idx]);
3117
+ }
3118
+ else {
3119
+ nk_dots_e4m3_load_a_sapphireamx_( //
3120
+ &b_src_tiles[tile_idx], //
3121
+ vectors + col_tile * stride + depth_start, //
3122
+ stride, valid_cols, valid_depth);
3123
+ nk_dots_pack_bf16_transposed_sapphireamx_(&b_src_tiles[tile_idx], &b_tiles[tile_idx]);
3124
+ }
3125
+ }
3126
+
3127
+ nk_dots_bf16_update_sapphireamx_( //
3128
+ &state, &a_tiles[0], &a_tiles[1], &a_tiles[2], &b_tiles[0], &b_tiles[1], &b_tiles[2]);
3129
+ }
3130
+
3131
+ nk_dots_bf16_store_sapphireamx_( //
3132
+ &state, result + row_tile * result_stride_elements + col_tile, //
3133
+ result_stride_elements, valid_rows, valid_cols);
3134
+ }
3135
+ }
3136
+ }
3137
+
3138
+ #pragma endregion // Quarter Precision E5M2
3139
+
3140
+ #pragma region Micro Precision E2M3
3141
+
3142
+ /* Load E2M3 A tile with E2M3 to signed I8 conversion via VPERMB LUT.
3143
+ * Each E2M3 byte encodes: bit 5 = sign, bits 4:0 = magnitude (5-bit index).
3144
+ * The LUT maps 5-bit magnitude to value * 16, then sign is applied via conditional negation.
3145
+ * Result is stored in INT8 tile for use with _tile_dpbssd.
3146
+ */
3147
+ NK_INTERNAL void nk_dots_e2m3_load_a_sapphireamx_( //
3148
+ nk_dots_i8_a16x64_sapphireamx_t *a_tile, //
3149
+ nk_e2m3_t const *src, nk_size_t src_stride, //
3150
+ nk_size_t valid_rows, nk_size_t valid_cols) {
3151
+
3152
+ // Build 64-byte LUT for VPERMB: 32 entries replicated to fill both halves.
3153
+ // magnitude → value×16:
3154
+ // e=0 (step 2): {0,2,4,6,8,10,12,14},
3155
+ // e=1 (step 2): {16,18,20,22,24,26,28,30},
3156
+ // e=2 (step 4): {32,36,40,44,48,52,56,60},
3157
+ // e=3 (step 8): {64,72,80,88,96,104,112,120}
3158
+ NK_ALIGN64 static nk_u8_t const lut_bytes[64] = {
3159
+ 0, 2, 4, 6, 8, 10, 12, 14, //
3160
+ 16, 18, 20, 22, 24, 26, 28, 30, //
3161
+ 32, 36, 40, 44, 48, 52, 56, 60, //
3162
+ 64, 72, 80, 88, 96, 104, 112, 120, //
3163
+ 0, 2, 4, 6, 8, 10, 12, 14, //
3164
+ 16, 18, 20, 22, 24, 26, 28, 30, //
3165
+ 32, 36, 40, 44, 48, 52, 56, 60, //
3166
+ 64, 72, 80, 88, 96, 104, 112, 120, //
3167
+ };
3168
+ __m512i magnitude_lut_u8x64 = _mm512_load_si512((__m512i const *)lut_bytes);
3169
+ __m512i sign_mask_u8x64 = _mm512_set1_epi8(0x20);
3170
+ __m512i magnitude_mask_u8x64 = _mm512_set1_epi8(0x1F);
3171
+ __m512i zero_i8x64 = _mm512_setzero_si512();
3172
+
3173
+ __mmask64 column_mask = (valid_cols >= 64) ? 0xFFFFFFFFFFFFFFFFULL : ((__mmask64)1 << valid_cols) - 1;
3174
+
3175
+ for (nk_size_t row = 0; row < 16; row++) {
3176
+ if (row < valid_rows) {
3177
+ __m512i raw_u8x64 = _mm512_maskz_loadu_epi8(column_mask, src + row * src_stride);
3178
+ __m512i magnitude_u8x64 = _mm512_and_si512(raw_u8x64, magnitude_mask_u8x64);
3179
+ __m512i unsigned_value_u8x64 = _mm512_permutexvar_epi8(magnitude_u8x64, magnitude_lut_u8x64);
3180
+ __mmask64 negate_mask = _mm512_test_epi8_mask(raw_u8x64, sign_mask_u8x64);
3181
+ __m512i signed_value_i8x64 = _mm512_mask_sub_epi8(unsigned_value_u8x64, negate_mask, zero_i8x64,
3182
+ unsigned_value_u8x64);
3183
+ _mm512_store_si512(a_tile->data[row], signed_value_i8x64);
3184
+ }
3185
+ else { _mm512_store_si512(a_tile->data[row], zero_i8x64); }
3186
+ }
3187
+ nk_compiler_barrier_sapphireamx_();
3188
+ }
3189
+
3190
+ /* Store E2M3 accumulator: read I32 state, convert to F32, multiply by 1/256, store as F32. */
3191
+ NK_INTERNAL void nk_dots_e2m3_store_sapphireamx_( //
3192
+ nk_dots_i8_state_sapphireamx_t const *state, //
3193
+ nk_f32_t *dst, nk_size_t dst_stride_elements, //
3194
+ nk_size_t valid_rows, nk_size_t valid_cols) {
3195
+
3196
+ __mmask16 column_mask = (valid_cols >= 16) ? 0xFFFF : ((__mmask16)1 << valid_cols) - 1;
3197
+ __m512 scale = _mm512_set1_ps(1.0f / 256.0f);
3198
+
3199
+ for (nk_size_t row = 0; row < valid_rows; row++) {
3200
+ __m512i i32_row = _mm512_load_si512(state->data[row]);
3201
+ __m512 f32_row = _mm512_mul_ps(_mm512_cvtepi32_ps(i32_row), scale);
3202
+ _mm512_mask_storeu_ps(dst + row * dst_stride_elements, column_mask, f32_row);
3203
+ }
3204
+ }
3205
+
3206
+ /* Store E2M3 2x2 accumulator state to F32 output matrix with masking for edge tiles. */
3207
+ NK_INTERNAL void nk_dots_e2m3_output2x2_sapphireamx_( //
3208
+ nk_dots_i8_state2x2_sapphireamx_t const *state, //
3209
+ nk_f32_t *dst, nk_size_t dst_stride_elements, //
3210
+ nk_size_t valid_rows, nk_size_t valid_cols) {
3211
+
3212
+ nk_size_t const rows_upper = (valid_rows > 16) ? 16 : valid_rows;
3213
+ nk_size_t const cols_left = (valid_cols > 16) ? 16 : valid_cols;
3214
+ nk_size_t const cols_right = (valid_cols > 16) ? valid_cols - 16 : 0;
3215
+
3216
+ if (rows_upper > 0 && cols_left > 0)
3217
+ nk_dots_e2m3_store_sapphireamx_(&state->c[0][0], dst, dst_stride_elements, rows_upper, cols_left);
3218
+ if (rows_upper > 0 && cols_right > 0)
3219
+ nk_dots_e2m3_store_sapphireamx_(&state->c[0][1], dst + 16, dst_stride_elements, rows_upper, cols_right);
3220
+
3221
+ if (valid_rows > 16) {
3222
+ nk_size_t const rows_lower = valid_rows - 16;
3223
+ nk_f32_t *dst_lower = dst + 16 * dst_stride_elements;
3224
+ if (cols_left > 0)
3225
+ nk_dots_e2m3_store_sapphireamx_(&state->c[1][0], dst_lower, dst_stride_elements, rows_lower, cols_left);
3226
+ if (cols_right > 0)
3227
+ nk_dots_e2m3_store_sapphireamx_(&state->c[1][1], dst_lower + 16, dst_stride_elements, rows_lower,
3228
+ cols_right);
3229
+ }
3230
+ }
3231
+
3232
/* Packed-buffer size for an E2M3 B matrix of `column_count` x `depth`.
 * E2M3 operands are converted to I8 during packing (one byte per element,
 * same 16x64 tile geometry), so the required size is exactly the I8 one. */
NK_PUBLIC nk_size_t nk_dots_packed_size_e2m3_sapphireamx(nk_size_t column_count, nk_size_t depth) {
    // E2M3 uses INT8 tile layout after conversion (same element count: 64 per row)
    return nk_dots_packed_size_i8_sapphireamx(column_count, depth);
}
3236
+
3237
/* Pack an E2M3 B matrix (column_count x depth, row stride `b_stride` elements)
 * into the AMX-friendly layout consumed by nk_dots_packed_e2m3_sapphireamx:
 *   [header][full 16x64 I8 tiles, quad-interleaved][I8 column-edge rows][F32 norms]
 * Each E2M3 byte is converted to a signed I8 equal to value * 16, so the
 * _tile_dpbssd products carry a fixed 256x scale that the store path divides out. */
NK_PUBLIC void nk_dots_pack_e2m3_sapphireamx(                  //
    nk_e2m3_t const *b, nk_size_t column_count, nk_size_t depth, //
    nk_size_t b_stride, void *b_packed) {

    // AMX I8 tile dimensions: 16 rows x 64 columns (1024 I8 elements = 1KB)
    nk_size_t const tmm_rows = 16;
    nk_size_t const tmm_cols = 64;
    nk_size_t const tile_elements = 1024;
    nk_size_t const tile_bytes = tile_elements * sizeof(nk_i8_t);

    // Only whole 16-column groups become tiles; leftover columns go to the edge area.
    nk_size_t const column_tiles_count = column_count / tmm_rows;
    nk_size_t const depth_tiles_count = nk_size_divide_round_up_(depth, tmm_cols);
    nk_size_t const column_remainder_count = column_count - column_tiles_count * tmm_rows;
    nk_size_t const total_tiles = column_tiles_count * depth_tiles_count;

    // Header records the geometry so the multiply kernel can navigate the buffer.
    nk_dots_amx_packed_header_t *header = (nk_dots_amx_packed_header_t *)b_packed;
    header->full_column_tiles = (nk_u32_t)column_tiles_count;
    header->full_depth_tiles = (nk_u32_t)depth_tiles_count;
    header->column_remainder_count = (nk_u32_t)column_remainder_count;

    nk_size_t const tiles_offset = sizeof(nk_dots_amx_packed_header_t);
    nk_size_t const column_edge_offset = tiles_offset + total_tiles * tile_bytes;
    header->column_edge_offset = (nk_u32_t)column_edge_offset;

    nk_i8_t *tiles_ptr = (nk_i8_t *)((char *)b_packed + tiles_offset);
    nk_i8_t *column_edge_ptr = (nk_i8_t *)((char *)b_packed + column_edge_offset);

    // Zero-initialize all tiles (handles depth remainder padding)
    for (nk_size_t idx = 0; idx < total_tiles * tile_elements; idx++) tiles_ptr[idx] = 0;

    // E2M3 magnitude-to-value LUT (value * 16); indexed by the low 5 bits of the byte.
    static nk_u8_t const lut_magnitude[32] = {
        0,  2,  4,  6,  8,  10, 12, 14, 16, 18, 20,  22,  24,  26,  28,  30,  //
        32, 36, 40, 44, 48, 52, 56, 60, 64, 72, 80,  88,  96,  104, 112, 120, //
    };

    // Pack tiles with E2M3 -> I8 conversion and quad-interleaving
    for (nk_size_t column_tile_idx = 0; column_tile_idx < column_tiles_count; column_tile_idx++) {
        for (nk_size_t depth_tile_idx = 0; depth_tile_idx < depth_tiles_count; depth_tile_idx++) {
            nk_size_t const tile_index = column_tile_idx * depth_tiles_count + depth_tile_idx;
            nk_i8_t *tile_output = tiles_ptr + tile_index * tile_elements;

            nk_size_t const src_row_start = column_tile_idx * tmm_rows;
            nk_size_t const src_column_start = depth_tile_idx * tmm_cols;
            nk_size_t const columns_to_pack = (src_column_start + tmm_cols <= depth) ? tmm_cols
                                                                                     : (depth - src_column_start);

            for (nk_size_t row_idx = 0; row_idx < tmm_rows; row_idx++) {
                for (nk_size_t column_idx = 0; column_idx < columns_to_pack; column_idx++) {
                    nk_size_t const src_idx = (src_row_start + row_idx) * b_stride + src_column_start + column_idx;
                    // Quad-interleave: groups of 4 depth elements per row become
                    // contiguous, matching _tile_dpbssd's expected B layout.
                    nk_size_t const dst_idx = (column_idx / 4) * 64 + row_idx * 4 + (column_idx % 4);
                    nk_u8_t raw = b[src_idx];
                    nk_u8_t magnitude = raw & 0x1F;           // low 5 bits index the LUT
                    nk_i8_t val = (nk_i8_t)lut_magnitude[magnitude];
                    if (raw & 0x20) val = -val;               // bit 5 carries the sign
                    tile_output[dst_idx] = val;
                }
            }
        }
    }

    // Pack column-remainder rows (convert E2M3 to I8); stored row-major, un-tiled,
    // so the multiply kernel can transpose them on the fly.
    if (column_remainder_count > 0) {
        nk_size_t const remainder_start_row = column_tiles_count * tmm_rows;
        for (nk_size_t row_idx = 0; row_idx < column_remainder_count; row_idx++) {
            for (nk_size_t column_idx = 0; column_idx < depth; column_idx++) {
                nk_u8_t raw = b[(remainder_start_row + row_idx) * b_stride + column_idx];
                nk_u8_t magnitude = raw & 0x1F;
                nk_i8_t val = (nk_i8_t)lut_magnitude[magnitude];
                if (raw & 0x20) val = -val;
                column_edge_ptr[row_idx * depth + column_idx] = val;
            }
        }
    }

    // Compute and store per-column norms for angular/euclidean distance
    nk_size_t norms_offset = column_edge_offset +
                             (column_remainder_count > 0 ? column_remainder_count * depth * sizeof(nk_i8_t) : 0);
    header->norms_byte_offset = (nk_u32_t)norms_offset;
    nk_f32_t *norms = (nk_f32_t *)((char *)b_packed + norms_offset);
    for (nk_size_t col = 0; col < column_count; col++)
        norms[col] = nk_dots_reduce_sumsq_e2m3_(b + col * b_stride, depth);
}
3320
+
3321
+ NK_PUBLIC void nk_dots_packed_e2m3_sapphireamx( //
3322
+ nk_e2m3_t const *a, void const *b_packed, nk_f32_t *c, //
3323
+ nk_size_t rows_count, nk_size_t cols_count, nk_size_t depth, nk_size_t a_stride_bytes, nk_size_t c_stride_bytes) {
3324
+ nk_unused_(cols_count);
3325
+
3326
+ nk_dots_amx_packed_header_t const *header = (nk_dots_amx_packed_header_t const *)b_packed;
3327
+ nk_size_t const column_tiles_count = header->full_column_tiles;
3328
+ nk_size_t const depth_tiles_count = header->full_depth_tiles;
3329
+ nk_size_t const column_remainder_count = header->column_remainder_count;
3330
+
3331
+ // B tiles are already in I8 format
3332
+ nk_i8_t const *b_tiles_base = (nk_i8_t const *)((char const *)b_packed + sizeof(nk_dots_amx_packed_header_t));
3333
+ nk_i8_t const *col_edge_ptr = (nk_i8_t const *)((char const *)b_packed + header->column_edge_offset);
3334
+
3335
+ nk_size_t const c_stride_elements = c_stride_bytes / sizeof(nk_f32_t);
3336
+ nk_size_t const tile_depth = 64;
3337
+ nk_size_t const tile_size = 1024;
3338
+ nk_size_t const full_cols = column_tiles_count * 16;
3339
+
3340
+ nk_size_t const row_blocks_count = nk_size_divide_round_up_(rows_count, 32);
3341
+ nk_size_t const col_blocks_count = column_tiles_count / 2;
3342
+
3343
+ if (depth_tiles_count == 0) return;
3344
+
3345
+ nk_dots_i8_a16x64_sapphireamx_t a_tile_upper, a_tile_lower;
3346
+ nk_dots_i8_state2x2_sapphireamx_t c_accum_buffer;
3347
+
3348
+ nk_size_t const full_depth_tiles_count = depth / tile_depth;
3349
+ nk_size_t const depth_remainder = depth % tile_depth;
3350
+
3351
+ nk_amx_tile_configure_sapphireamx_();
3352
+
3353
+ // Loop order: row_blocks outer, col_blocks inner
3354
+ for (nk_size_t row_block_idx = 0; row_block_idx < row_blocks_count; row_block_idx++) {
3355
+ nk_size_t const row_block_start = row_block_idx * 32;
3356
+ nk_size_t const valid_rows_count = (row_block_start + 32 <= rows_count) ? 32 : (rows_count - row_block_start);
3357
+ nk_size_t const is_full_row_block = (valid_rows_count == 32);
3358
+ nk_size_t const rows_in_upper_tile = (valid_rows_count > 16) ? 16 : valid_rows_count;
3359
+ nk_size_t const rows_in_lower_tile = (valid_rows_count > 16) ? valid_rows_count - 16 : 0;
3360
+
3361
+ for (nk_size_t column_block_idx = 0; column_block_idx < col_blocks_count; column_block_idx++) {
3362
+ nk_size_t const col_block_start = column_block_idx * 32;
3363
+ nk_size_t const b_column_left_base = (column_block_idx * 2) * depth_tiles_count;
3364
+ nk_size_t const b_column_right_base = (column_block_idx * 2 + 1) * depth_tiles_count;
3365
+
3366
+ // Zero accumulators (TMM4-7 stay resident across entire depth loop)
3367
+ _tile_zero(4);
3368
+ _tile_zero(5);
3369
+ _tile_zero(6);
3370
+ _tile_zero(7);
3371
+
3372
+ // E2M3 always uses buffered load for E2M3 -> I8 conversion
3373
+ for (nk_size_t depth_tile_idx = 0; depth_tile_idx < depth_tiles_count; depth_tile_idx++) {
3374
+ nk_size_t const depth_offset = depth_tile_idx * tile_depth;
3375
+ nk_size_t const valid_depth = (depth_tile_idx < full_depth_tiles_count) ? tile_depth : depth_remainder;
3376
+
3377
+ // Load A with E2M3 -> I8 conversion
3378
+ nk_dots_e2m3_load_a_sapphireamx_(&a_tile_upper, a + row_block_start * a_stride_bytes + depth_offset,
3379
+ a_stride_bytes, rows_in_upper_tile, valid_depth);
3380
+ if (rows_in_lower_tile > 0) {
3381
+ nk_dots_e2m3_load_a_sapphireamx_(&a_tile_lower,
3382
+ a + (row_block_start + 16) * a_stride_bytes + depth_offset,
3383
+ a_stride_bytes, rows_in_lower_tile, valid_depth);
3384
+ }
3385
+
3386
+ nk_dots_i8_b64x16_sapphireamx_t const *b_tile_left =
3387
+ (nk_dots_i8_b64x16_sapphireamx_t const *)(b_tiles_base +
3388
+ (b_column_left_base + depth_tile_idx) * tile_size);
3389
+ nk_dots_i8_b64x16_sapphireamx_t const *b_tile_right =
3390
+ (nk_dots_i8_b64x16_sapphireamx_t const *)(b_tiles_base +
3391
+ (b_column_right_base + depth_tile_idx) * tile_size);
3392
+
3393
+ _tile_loadd(0, a_tile_upper.data, 64);
3394
+ _tile_loadd(1, a_tile_lower.data, 64);
3395
+ _tile_loadd(2, b_tile_left->data, 64);
3396
+ _tile_loadd(3, b_tile_right->data, 64);
3397
+
3398
+ _tile_dpbssd(4, 0, 2);
3399
+ _tile_dpbssd(5, 0, 3);
3400
+ _tile_dpbssd(6, 1, 2);
3401
+ _tile_dpbssd(7, 1, 3);
3402
+ }
3403
+
3404
+ // Store accumulators to output (once per output block)
3405
+ // Can't directly store I32 tiles to F32 output, must buffer + convert
3406
+ if (is_full_row_block) {
3407
+ nk_f32_t *c_block = c + row_block_start * c_stride_elements + col_block_start;
3408
+ nk_dots_i8_state2x2_sapphireamx_t c_accum_buffer;
3409
+ _tile_stored(4, c_accum_buffer.c[0][0].data, 64);
3410
+ _tile_stored(5, c_accum_buffer.c[0][1].data, 64);
3411
+ _tile_stored(6, c_accum_buffer.c[1][0].data, 64);
3412
+ _tile_stored(7, c_accum_buffer.c[1][1].data, 64);
3413
+ nk_dots_e2m3_output2x2_sapphireamx_(&c_accum_buffer, c_block, c_stride_elements, valid_rows_count, 32);
3414
+ }
3415
+ else {
3416
+ _tile_stored(4, c_accum_buffer.c[0][0].data, 64);
3417
+ _tile_stored(5, c_accum_buffer.c[0][1].data, 64);
3418
+ _tile_stored(6, c_accum_buffer.c[1][0].data, 64);
3419
+ _tile_stored(7, c_accum_buffer.c[1][1].data, 64);
3420
+ nk_dots_e2m3_output2x2_sapphireamx_(&c_accum_buffer,
3421
+ c + row_block_start * c_stride_elements + col_block_start,
3422
+ c_stride_elements, valid_rows_count, 32);
3423
+ }
3424
+ }
3425
+
3426
+ // Handle odd column-tile (single 16-column tile if column_tiles_count is odd)
3427
+ if (column_tiles_count % 2 == 1) {
3428
+ nk_size_t const column_tile_idx = column_tiles_count - 1;
3429
+ nk_size_t const col_start = column_tile_idx * 16;
3430
+ nk_size_t const b_column_base = column_tile_idx * depth_tiles_count;
3431
+
3432
+ nk_dots_i8_state_sapphireamx_t c_upper_state, c_lower_state;
3433
+ _tile_zero(4);
3434
+ _tile_zero(6);
3435
+
3436
+ for (nk_size_t depth_tile_idx = 0; depth_tile_idx < depth_tiles_count; depth_tile_idx++) {
3437
+ nk_size_t const depth_offset = depth_tile_idx * tile_depth;
3438
+ nk_size_t const valid_depth = (depth_tile_idx < full_depth_tiles_count) ? tile_depth : depth_remainder;
3439
+
3440
+ nk_dots_e2m3_load_a_sapphireamx_(&a_tile_upper, a + row_block_start * a_stride_bytes + depth_offset,
3441
+ a_stride_bytes, rows_in_upper_tile, valid_depth);
3442
+ if (rows_in_lower_tile > 0) {
3443
+ nk_dots_e2m3_load_a_sapphireamx_(&a_tile_lower,
3444
+ a + (row_block_start + 16) * a_stride_bytes + depth_offset,
3445
+ a_stride_bytes, rows_in_lower_tile, valid_depth);
3446
+ }
3447
+
3448
+ nk_dots_i8_b64x16_sapphireamx_t const *b_tile =
3449
+ (nk_dots_i8_b64x16_sapphireamx_t const *)(b_tiles_base +
3450
+ (b_column_base + depth_tile_idx) * tile_size);
3451
+
3452
+ _tile_loadd(0, a_tile_upper.data, 64);
3453
+ _tile_loadd(1, a_tile_lower.data, 64);
3454
+ _tile_loadd(2, b_tile->data, 64);
3455
+
3456
+ _tile_dpbssd(4, 0, 2);
3457
+ _tile_dpbssd(6, 1, 2);
3458
+ }
3459
+
3460
+ _tile_stored(4, c_upper_state.data, 64);
3461
+ _tile_stored(6, c_lower_state.data, 64);
3462
+
3463
+ nk_dots_e2m3_store_sapphireamx_(&c_upper_state, c + row_block_start * c_stride_elements + col_start,
3464
+ c_stride_elements, rows_in_upper_tile, 16);
3465
+ if (rows_in_lower_tile > 0) {
3466
+ nk_dots_e2m3_store_sapphireamx_(&c_lower_state,
3467
+ c + (row_block_start + 16) * c_stride_elements + col_start,
3468
+ c_stride_elements, rows_in_lower_tile, 16);
3469
+ }
3470
+ }
3471
+
3472
+ // Handle column-edge (remaining columns < 16) using AMX with partial tiles
3473
+ if (column_remainder_count > 0) {
3474
+ nk_dots_i8_state_sapphireamx_t c_upper_state, c_lower_state;
3475
+ nk_dots_i8_a16x64_sapphireamx_t b_as_a;
3476
+ nk_dots_i8_b64x16_sapphireamx_t b_tile;
3477
+
3478
+ _tile_zero(4);
3479
+ _tile_zero(6);
3480
+
3481
+ for (nk_size_t depth_tile_idx = 0; depth_tile_idx < depth_tiles_count; depth_tile_idx++) {
3482
+ nk_size_t const depth_offset = depth_tile_idx * tile_depth;
3483
+ nk_size_t const valid_depth = (depth_tile_idx < full_depth_tiles_count) ? tile_depth : depth_remainder;
3484
+
3485
+ nk_dots_e2m3_load_a_sapphireamx_(&a_tile_upper, a + row_block_start * a_stride_bytes + depth_offset,
3486
+ a_stride_bytes, rows_in_upper_tile, valid_depth);
3487
+ if (rows_in_lower_tile > 0) {
3488
+ nk_dots_e2m3_load_a_sapphireamx_(&a_tile_lower,
3489
+ a + (row_block_start + 16) * a_stride_bytes + depth_offset,
3490
+ a_stride_bytes, rows_in_lower_tile, valid_depth);
3491
+ }
3492
+
3493
+ // B edge data is already in I8 format
3494
+ nk_dots_i8_load_a_sapphireamx_(&b_as_a, col_edge_ptr + depth_offset, depth, column_remainder_count,
3495
+ valid_depth);
3496
+ nk_dots_pack_i8_transposed_sapphireamx_(&b_as_a, &b_tile);
3497
+
3498
+ _tile_loadd(0, a_tile_upper.data, 64);
3499
+ _tile_loadd(1, a_tile_lower.data, 64);
3500
+ _tile_loadd(2, b_tile.data, 64);
3501
+
3502
+ _tile_dpbssd(4, 0, 2);
3503
+ _tile_dpbssd(6, 1, 2);
3504
+ }
3505
+
3506
+ _tile_stored(4, c_upper_state.data, 64);
3507
+ _tile_stored(6, c_lower_state.data, 64);
3508
+
3509
+ nk_dots_e2m3_store_sapphireamx_(&c_upper_state, c + row_block_start * c_stride_elements + full_cols,
3510
+ c_stride_elements, rows_in_upper_tile, column_remainder_count);
3511
+ if (rows_in_lower_tile > 0) {
3512
+ nk_dots_e2m3_store_sapphireamx_(&c_lower_state,
3513
+ c + (row_block_start + 16) * c_stride_elements + full_cols,
3514
+ c_stride_elements, rows_in_lower_tile, column_remainder_count);
3515
+ }
3516
+ }
3517
+ }
3518
+
3519
+ _tile_release();
3520
+ }
3521
+
3522
+ NK_PUBLIC void nk_dots_symmetric_e2m3_sapphireamx( //
3523
+ nk_e2m3_t const *vectors, nk_size_t n_vectors, nk_size_t depth, //
3524
+ nk_size_t stride, nk_f32_t *result, nk_size_t result_stride, //
3525
+ nk_size_t row_start, nk_size_t row_count) {
3526
+
3527
+ nk_size_t const result_stride_elements = result_stride / sizeof(nk_f32_t);
3528
+
3529
+ // Handle row slicing: compute rows [row_start, row_end)
3530
+ nk_size_t const row_end = (row_count == 0)
3531
+ ? n_vectors
3532
+ : (row_start + row_count < n_vectors ? row_start + row_count : n_vectors);
3533
+
3534
+ // Round depth up to multiple of 192 (3 tiles x 64 elements)
3535
+ nk_size_t const depth_tiles = nk_size_divide_round_up_(depth, 64);
3536
+ nk_size_t const depth_tile_groups = nk_size_divide_round_up_(depth_tiles, 3);
3537
+
3538
+ nk_dots_i8_a16x64_sapphireamx_t a_tiles[3];
3539
+ nk_dots_i8_a16x64_sapphireamx_t b_src_tiles[3];
3540
+ nk_dots_i8_b64x16_sapphireamx_t b_tiles[3];
3541
+ nk_dots_i8_state_sapphireamx_t state;
3542
+
3543
+ nk_amx_tile_configure_sapphireamx_();
3544
+
3545
+ for (nk_size_t row_tile = row_start; row_tile < row_end; row_tile += 16) {
3546
+ nk_size_t const valid_rows = (row_tile + 16 <= row_end) ? 16 : (row_end - row_tile);
3547
+
3548
+ for (nk_size_t col_tile = 0; col_tile < n_vectors; col_tile += 16) {
3549
+ nk_size_t const valid_cols = (col_tile + 16 <= n_vectors) ? 16 : (n_vectors - col_tile);
3550
+
3551
+ nk_dots_i8_init_sapphireamx_(&state);
3552
+
3553
+ for (nk_size_t depth_group_idx = 0; depth_group_idx < depth_tile_groups; depth_group_idx++) {
3554
+ nk_size_t const depth_base = depth_group_idx * 192;
3555
+
3556
+ for (int tile_idx = 0; tile_idx < 3; tile_idx++) {
3557
+ nk_size_t const depth_start = depth_base + tile_idx * 64;
3558
+ nk_size_t const valid_depth = (depth_start + 64 <= depth)
3559
+ ? 64
3560
+ : (depth > depth_start ? depth - depth_start : 0);
3561
+
3562
+ nk_dots_e2m3_load_a_sapphireamx_( //
3563
+ &a_tiles[tile_idx], //
3564
+ vectors + row_tile * stride + depth_start, //
3565
+ stride, valid_rows, valid_depth);
3566
+
3567
+ if (row_tile == col_tile) {
3568
+ nk_dots_pack_i8_transposed_sapphireamx_(&a_tiles[tile_idx], &b_tiles[tile_idx]);
3569
+ }
3570
+ else {
3571
+ nk_dots_e2m3_load_a_sapphireamx_( //
3572
+ &b_src_tiles[tile_idx], //
3573
+ vectors + col_tile * stride + depth_start, //
3574
+ stride, valid_cols, valid_depth);
3575
+ nk_dots_pack_i8_transposed_sapphireamx_(&b_src_tiles[tile_idx], &b_tiles[tile_idx]);
3576
+ }
3577
+ }
3578
+
3579
+ nk_dots_i8_update_sapphireamx_( //
3580
+ &state, &a_tiles[0], &a_tiles[1], &a_tiles[2], &b_tiles[0], &b_tiles[1], &b_tiles[2]);
3581
+ }
3582
+
3583
+ nk_dots_e2m3_store_sapphireamx_( //
3584
+ &state, result + row_tile * result_stride_elements + col_tile, //
3585
+ result_stride_elements, valid_rows, valid_cols);
3586
+ }
3587
+ }
3588
+ }
3589
+
3590
+ #pragma endregion // Micro Precision E2M3
3591
+
3592
+ #pragma region Micro Precision E3M2
3593
+
3594
+ /* Load E3M2 A tile with FP8 to BF16 conversion */
3595
+ NK_INTERNAL void nk_dots_e3m2_load_a_sapphireamx_( //
3596
+ nk_dots_bf16_a16x32_sapphireamx_t *a_tile, //
3597
+ nk_e3m2_t const *src, nk_size_t src_stride, //
3598
+ nk_size_t valid_rows, nk_size_t valid_cols) {
3599
+
3600
+ __mmask32 column_mask = (valid_cols >= 32) ? 0xFFFFFFFF : ((__mmask32)1 << valid_cols) - 1;
3601
+ __m512i zero = _mm512_setzero_si512();
3602
+
3603
+ for (nk_size_t row_idx = 0; row_idx < 16; row_idx++) {
3604
+ if (row_idx < valid_rows) {
3605
+ __m256i e3m2_row = _mm256_maskz_loadu_epi8(column_mask, src + row_idx * src_stride);
3606
+ __m512i bf16_row = nk_e3m2x32_to_bf16x32_icelake_(e3m2_row);
3607
+ _mm512_store_si512((__m512i *)a_tile->data[row_idx], bf16_row);
3608
+ }
3609
+ else { _mm512_store_si512((__m512i *)a_tile->data[row_idx], zero); }
3610
+ }
3611
+ nk_compiler_barrier_sapphireamx_();
3612
+ }
3613
+
3614
/* Packed-buffer size for an E3M2 B matrix of `column_count` x `depth`.
 * E3M2 operands are widened to BF16 during packing, so the required size is
 * exactly the BF16 one. */
NK_PUBLIC nk_size_t nk_dots_packed_size_e3m2_sapphireamx(nk_size_t column_count, nk_size_t depth) {
    return nk_dots_packed_size_bf16_sapphireamx(column_count, depth);
}
3617
+
3618
/**
 *  @brief Packs a row-major E3M2 (fp8) B matrix into the AMX bf16 tiled layout
 *         consumed by `nk_dots_packed_e3m2_sapphireamx`.
 *
 *  Every fp8 value is widened to bf16 up front (via `nk_e3m2x32_to_bf16x32_icelake_`),
 *  so the packed layout matches the bf16 path: a header, then full 16x32 tiles in the
 *  pair-interleaved (VNNI-style) order `_tile_dpbf16ps` expects, then an unpacked
 *  "column edge" region for the trailing < 16 rows of @p b, then per-vector norms.
 *
 *  @param b             Row-major E3M2 matrix: `column_count` rows of `depth` values.
 *  @param column_count  Number of vectors in @p b (columns of the eventual C output).
 *  @param depth         Length of each vector.
 *  @param b_stride      Stride between consecutive rows of @p b, in elements
 *                       (equal to bytes, since sizeof(nk_e3m2_t) == 1).
 *  @param b_packed      Destination buffer of at least
 *                       `nk_dots_packed_size_e3m2_sapphireamx(column_count, depth)` bytes.
 */
NK_PUBLIC void nk_dots_pack_e3m2_sapphireamx( //
    nk_e3m2_t const *b, nk_size_t column_count, nk_size_t depth, //
    nk_size_t b_stride, void *b_packed) {

    // AMX bf16 tile geometry: 16 rows x 32 depth-elements = 512 bf16 values per tile.
    nk_size_t const tmm_rows = 16;
    nk_size_t const tmm_cols = 32;
    nk_size_t const tile_elements = 512;
    nk_size_t const tile_bytes = tile_elements * sizeof(nk_bf16_t);

    nk_size_t const column_tiles_count = column_count / tmm_rows;
    nk_size_t const depth_tiles_count = nk_size_divide_round_up_(depth, tmm_cols);
    nk_size_t const column_remainder_count = column_count - column_tiles_count * tmm_rows;
    nk_size_t const total_tiles = column_tiles_count * depth_tiles_count;

    // Record the tile decomposition so the compute kernel can navigate the
    // buffer without re-deriving it from (column_count, depth).
    nk_dots_amx_packed_header_t *header = (nk_dots_amx_packed_header_t *)b_packed;
    header->full_column_tiles = (nk_u32_t)column_tiles_count;
    header->full_depth_tiles = (nk_u32_t)depth_tiles_count;
    header->column_remainder_count = (nk_u32_t)column_remainder_count;

    nk_size_t const tiles_offset = sizeof(nk_dots_amx_packed_header_t);
    nk_size_t const column_edge_offset = tiles_offset + total_tiles * tile_bytes;
    header->column_edge_offset = (nk_u32_t)column_edge_offset;

    nk_bf16_t *tiles_ptr = (nk_bf16_t *)((char *)b_packed + tiles_offset);
    nk_bf16_t *column_edge_ptr = (nk_bf16_t *)((char *)b_packed + column_edge_offset);

    // Zero all tiles first so positions beyond the depth tail contribute
    // nothing to the dot products.
    for (nk_size_t idx = 0; idx < total_tiles * tile_elements; idx++) tiles_ptr[idx] = 0;

    // Full 16-row tiles: widen each fp8 chunk to bf16, then scatter into the
    // pair-interleaved layout.
    for (nk_size_t column_tile_idx = 0; column_tile_idx < column_tiles_count; column_tile_idx++) {
        for (nk_size_t depth_tile_idx = 0; depth_tile_idx < depth_tiles_count; depth_tile_idx++) {
            nk_size_t const tile_index = column_tile_idx * depth_tiles_count + depth_tile_idx;
            nk_bf16_t *tile_output = tiles_ptr + tile_index * tile_elements;

            nk_size_t const src_row_start = column_tile_idx * tmm_rows;
            nk_size_t const src_column_start = depth_tile_idx * tmm_cols;
            nk_size_t const columns_to_pack = (src_column_start + tmm_cols <= depth) ? tmm_cols
                                                                                     : (depth - src_column_start);

            for (nk_size_t row_idx = 0; row_idx < tmm_rows; row_idx++) {
                nk_size_t src_row = src_row_start + row_idx;
                // Masked load avoids reading past the depth tail of the source row.
                __mmask32 column_mask = (columns_to_pack >= 32) ? 0xFFFFFFFF : ((__mmask32)1 << columns_to_pack) - 1;
                __m256i e3m2_row = _mm256_maskz_loadu_epi8(column_mask, b + src_row * b_stride + src_column_start);
                // Widen 32 e3m2 bytes into 32 bf16 lanes.
                __m512i bf16_row = nk_e3m2x32_to_bf16x32_icelake_(e3m2_row);
                nk_bf16_t bf16_buf[32];
                _mm512_storeu_si512((__m512i *)bf16_buf, bf16_row);
                for (nk_size_t column_idx = 0; column_idx < columns_to_pack; column_idx++) {
                    // VNNI pairing: each depth pair (column_idx / 2) owns a 32-element
                    // stripe (16 rows x 2); within it, this row's pair sits at row_idx * 2.
                    nk_size_t const dst_idx = (column_idx / 2) * 32 + row_idx * 2 + (column_idx % 2);
                    tile_output[dst_idx] = bf16_buf[column_idx];
                }
            }
        }
    }

    // Trailing < 16 vectors are stored un-tiled (row-major bf16); the compute
    // kernel transposes them on the fly.
    if (column_remainder_count > 0) {
        nk_size_t const remainder_start_row = column_tiles_count * tmm_rows;
        for (nk_size_t row_idx = 0; row_idx < column_remainder_count; row_idx++) {
            for (nk_size_t column_idx = 0; column_idx < depth; column_idx += 32) {
                nk_size_t columns = (column_idx + 32 <= depth) ? 32 : (depth - column_idx);
                __mmask32 column_mask = (columns >= 32) ? 0xFFFFFFFF : ((__mmask32)1 << columns) - 1;
                __m256i e3m2_chunk = _mm256_maskz_loadu_epi8(
                    column_mask, b + (remainder_start_row + row_idx) * b_stride + column_idx);
                __m512i bf16_chunk = nk_e3m2x32_to_bf16x32_icelake_(e3m2_chunk);
                _mm512_mask_storeu_epi16(column_edge_ptr + row_idx * depth + column_idx, column_mask, bf16_chunk);
            }
        }
    }

    // Compute and store per-vector squared norms (one f32 per row of `b`),
    // presumably consumed by angular/euclidean distance post-processing —
    // NOTE(review): confirm against the callers of `norms_byte_offset`.
    nk_size_t norms_offset = column_edge_offset +
                             (column_remainder_count > 0 ? column_remainder_count * depth * sizeof(nk_bf16_t) : 0);
    header->norms_byte_offset = (nk_u32_t)norms_offset;
    nk_f32_t *norms = (nk_f32_t *)((char *)b_packed + norms_offset);
    for (nk_size_t col = 0; col < column_count; col++)
        norms[col] = nk_dots_reduce_sumsq_e3m2_(b + col * b_stride, depth);
}
3693
+
3694
/**
 *  @brief Computes C = A x B^T for an E3M2 (fp8) A against a pre-packed B
 *         (see `nk_dots_pack_e3m2_sapphireamx`), using Sapphire Rapids AMX.
 *
 *  A rows are widened from fp8 to bf16 into stack-side tile buffers on every
 *  depth step; B tiles are loaded directly from the packed buffer. Tile-register
 *  roles within a 32x32 output block:
 *    TMM0 = A upper 16 rows, TMM1 = A lower 16 rows,
 *    TMM2 = B left 16 columns, TMM3 = B right 16 columns,
 *    TMM4-7 = the four f32 accumulators, resident across the whole depth loop.
 *
 *  @param a              Row-major E3M2 matrix, `rows_count` x `depth`.
 *  @param b_packed       Buffer produced by `nk_dots_pack_e3m2_sapphireamx`.
 *  @param c              Output f32 matrix, `rows_count` x `cols_count`.
 *  @param rows_count     Rows of @p a (and @p c).
 *  @param cols_count     Unused here; the column count is taken from the packed header.
 *  @param depth          Shared inner dimension.
 *  @param a_stride_bytes Row stride of @p a in bytes (== elements for 1-byte e3m2).
 *  @param c_stride_bytes Row stride of @p c in bytes.
 */
NK_PUBLIC void nk_dots_packed_e3m2_sapphireamx( //
    nk_e3m2_t const *a, void const *b_packed, nk_f32_t *c, //
    nk_size_t rows_count, nk_size_t cols_count, nk_size_t depth, nk_size_t a_stride_bytes, nk_size_t c_stride_bytes) {
    nk_unused_(cols_count);

    // Recover the tile decomposition recorded at pack time.
    nk_dots_amx_packed_header_t const *header = (nk_dots_amx_packed_header_t const *)b_packed;
    nk_size_t const column_tiles_count = header->full_column_tiles;
    nk_size_t const depth_tiles_count = header->full_depth_tiles;
    nk_size_t const column_remainder_count = header->column_remainder_count;

    nk_bf16_t const *b_tiles_base = (nk_bf16_t const *)((char const *)b_packed + sizeof(nk_dots_amx_packed_header_t));
    nk_bf16_t const *col_edge_ptr = (nk_bf16_t const *)((char const *)b_packed + header->column_edge_offset);

    nk_size_t const c_stride_elements = c_stride_bytes / sizeof(nk_f32_t);
    nk_size_t const tile_depth = 32;  // bf16 elements per tile along the depth axis
    nk_size_t const tile_size = 512;  // bf16 elements per packed 16x32 tile
    nk_size_t const full_cols = column_tiles_count * 16;

    nk_size_t const row_blocks_count = nk_size_divide_round_up_(rows_count, 32);
    nk_size_t const col_blocks_count = column_tiles_count / 2;  // paired 16-col tiles -> 32-col blocks

    if (depth_tiles_count == 0) return;

    nk_dots_bf16_a16x32_sapphireamx_t a_tile_upper, a_tile_lower;
    nk_dots_bf16_state2x2_sapphireamx_t c_accum_buffer;

    nk_size_t const full_depth_tiles_count = depth / tile_depth;
    nk_size_t const depth_remainder = depth % tile_depth;

    nk_amx_tile_configure_sapphireamx_();

    // Loop order: row_blocks outer, col_blocks inner
    for (nk_size_t row_block_idx = 0; row_block_idx < row_blocks_count; row_block_idx++) {
        nk_size_t const row_block_start = row_block_idx * 32;
        nk_size_t const valid_rows_count = (row_block_start + 32 <= rows_count) ? 32 : (rows_count - row_block_start);
        nk_size_t const is_full_row_block = (valid_rows_count == 32);
        nk_size_t const rows_in_upper_tile = (valid_rows_count > 16) ? 16 : valid_rows_count;
        nk_size_t const rows_in_lower_tile = (valid_rows_count > 16) ? valid_rows_count - 16 : 0;

        for (nk_size_t column_block_idx = 0; column_block_idx < col_blocks_count; column_block_idx++) {
            nk_size_t const col_block_start = column_block_idx * 32;
            nk_size_t const b_column_left_base = (column_block_idx * 2) * depth_tiles_count;
            nk_size_t const b_column_right_base = (column_block_idx * 2 + 1) * depth_tiles_count;

            // Zero accumulators (TMM4-7 stay resident across entire depth loop)
            _tile_zero(4);
            _tile_zero(5);
            _tile_zero(6);
            _tile_zero(7);

            // FP8 always uses buffered load for E3M2 -> BF16 conversion
            for (nk_size_t depth_tile_idx = 0; depth_tile_idx < depth_tiles_count; depth_tile_idx++) {
                nk_size_t const depth_offset = depth_tile_idx * tile_depth;
                nk_size_t const valid_depth = (depth_tile_idx < full_depth_tiles_count) ? tile_depth : depth_remainder;

                // Load A with FP8 -> BF16 conversion
                nk_dots_e3m2_load_a_sapphireamx_(&a_tile_upper, a + row_block_start * a_stride_bytes + depth_offset,
                                                 a_stride_bytes, rows_in_upper_tile, valid_depth);
                if (rows_in_lower_tile > 0) {
                    nk_dots_e3m2_load_a_sapphireamx_(&a_tile_lower,
                                                     a + (row_block_start + 16) * a_stride_bytes + depth_offset,
                                                     a_stride_bytes, rows_in_lower_tile, valid_depth);
                }

                // B tiles come straight from the packed buffer — already in
                // the pair-interleaved layout `_tile_dpbf16ps` needs.
                nk_dots_bf16_b32x16_sapphireamx_t const *b_tile_left =
                    (nk_dots_bf16_b32x16_sapphireamx_t const *)(b_tiles_base +
                                                                (b_column_left_base + depth_tile_idx) * tile_size);
                nk_dots_bf16_b32x16_sapphireamx_t const *b_tile_right =
                    (nk_dots_bf16_b32x16_sapphireamx_t const *)(b_tiles_base +
                                                                (b_column_right_base + depth_tile_idx) * tile_size);

                _tile_loadd(0, a_tile_upper.data, 64);
                _tile_loadd(1, a_tile_lower.data, 64);
                _tile_loadd(2, b_tile_left->data, 64);
                _tile_loadd(3, b_tile_right->data, 64);

                _tile_dpbf16ps(4, 0, 2);
                _tile_dpbf16ps(5, 0, 3);
                _tile_dpbf16ps(6, 1, 2);
                _tile_dpbf16ps(7, 1, 3);
            }

            // Store accumulators to output (once per output block)
            if (is_full_row_block) {
                // Fast path: all 32 rows valid, store tiles directly into C.
                nk_f32_t *c_block = c + row_block_start * c_stride_elements + col_block_start;
                _tile_stored(4, c_block, c_stride_bytes);
                _tile_stored(5, c_block + 16, c_stride_bytes);
                _tile_stored(6, (nk_f32_t *)((char *)c_block + 16 * c_stride_bytes), c_stride_bytes);
                _tile_stored(7, (nk_f32_t *)((char *)c_block + 16 * c_stride_bytes) + 16, c_stride_bytes);
            }
            else {
                // Partial row block: spill tiles to a buffer, then copy only
                // the valid rows into C.
                _tile_stored(4, c_accum_buffer.c[0][0].data, 64);
                _tile_stored(5, c_accum_buffer.c[0][1].data, 64);
                _tile_stored(6, c_accum_buffer.c[1][0].data, 64);
                _tile_stored(7, c_accum_buffer.c[1][1].data, 64);
                nk_dots_bf16_output2x2_sapphireamx_(&c_accum_buffer,
                                                    c + row_block_start * c_stride_elements + col_block_start,
                                                    c_stride_elements, valid_rows_count, 32);
            }
        }

        // Handle odd column-tile (single 16-column tile if column_tiles_count is odd)
        if (column_tiles_count % 2 == 1) {
            nk_size_t const column_tile_idx = column_tiles_count - 1;
            nk_size_t const col_start = column_tile_idx * 16;
            nk_size_t const b_column_base = column_tile_idx * depth_tiles_count;

            nk_dots_bf16_state_sapphireamx_t c_upper_state, c_lower_state;
            _tile_zero(4);
            _tile_zero(6);

            for (nk_size_t depth_tile_idx = 0; depth_tile_idx < depth_tiles_count; depth_tile_idx++) {
                nk_size_t const depth_offset = depth_tile_idx * tile_depth;
                nk_size_t const valid_depth = (depth_tile_idx < full_depth_tiles_count) ? tile_depth : depth_remainder;

                nk_dots_e3m2_load_a_sapphireamx_(&a_tile_upper, a + row_block_start * a_stride_bytes + depth_offset,
                                                 a_stride_bytes, rows_in_upper_tile, valid_depth);
                if (rows_in_lower_tile > 0) {
                    nk_dots_e3m2_load_a_sapphireamx_(&a_tile_lower,
                                                     a + (row_block_start + 16) * a_stride_bytes + depth_offset,
                                                     a_stride_bytes, rows_in_lower_tile, valid_depth);
                }

                nk_dots_bf16_b32x16_sapphireamx_t const *b_tile =
                    (nk_dots_bf16_b32x16_sapphireamx_t const *)(b_tiles_base +
                                                                (b_column_base + depth_tile_idx) * tile_size);

                _tile_loadd(0, a_tile_upper.data, 64);
                _tile_loadd(1, a_tile_lower.data, 64);
                _tile_loadd(2, b_tile->data, 64);

                _tile_dpbf16ps(4, 0, 2);
                _tile_dpbf16ps(6, 1, 2);
            }

            _tile_stored(4, c_upper_state.data, 64);
            _tile_stored(6, c_lower_state.data, 64);

            nk_dots_bf16_store_sapphireamx_(&c_upper_state, c + row_block_start * c_stride_elements + col_start,
                                            c_stride_elements, rows_in_upper_tile, 16);
            if (rows_in_lower_tile > 0) {
                nk_dots_bf16_store_sapphireamx_(&c_lower_state,
                                                c + (row_block_start + 16) * c_stride_elements + col_start,
                                                c_stride_elements, rows_in_lower_tile, 16);
            }
        }

        // Handle column-edge (remaining columns < 16) using AMX with partial tiles
        if (column_remainder_count > 0) {
            nk_dots_bf16_state_sapphireamx_t c_upper_state, c_lower_state;
            nk_dots_bf16_a16x32_sapphireamx_t b_as_a;
            nk_dots_bf16_b32x16_sapphireamx_t b_tile;

            _tile_zero(4);
            _tile_zero(6);

            for (nk_size_t depth_tile_idx = 0; depth_tile_idx < depth_tiles_count; depth_tile_idx++) {
                nk_size_t const depth_offset = depth_tile_idx * tile_depth;
                nk_size_t const valid_depth = (depth_tile_idx < full_depth_tiles_count) ? tile_depth : depth_remainder;

                nk_dots_e3m2_load_a_sapphireamx_(&a_tile_upper, a + row_block_start * a_stride_bytes + depth_offset,
                                                 a_stride_bytes, rows_in_upper_tile, valid_depth);
                if (rows_in_lower_tile > 0) {
                    nk_dots_e3m2_load_a_sapphireamx_(&a_tile_lower,
                                                     a + (row_block_start + 16) * a_stride_bytes + depth_offset,
                                                     a_stride_bytes, rows_in_lower_tile, valid_depth);
                }

                // The edge region is stored row-major bf16; load it as an
                // A-shaped tile, then transpose into B layout on the fly.
                nk_dots_bf16_load_a_sapphireamx_(&b_as_a, col_edge_ptr + depth_offset, depth, column_remainder_count,
                                                 valid_depth);
                nk_dots_pack_bf16_transposed_sapphireamx_(&b_as_a, &b_tile);

                _tile_loadd(0, a_tile_upper.data, 64);
                _tile_loadd(1, a_tile_lower.data, 64);
                _tile_loadd(2, b_tile.data, 64);

                _tile_dpbf16ps(4, 0, 2);
                _tile_dpbf16ps(6, 1, 2);
            }

            _tile_stored(4, c_upper_state.data, 64);
            _tile_stored(6, c_lower_state.data, 64);

            nk_dots_bf16_store_sapphireamx_(&c_upper_state, c + row_block_start * c_stride_elements + full_cols,
                                            c_stride_elements, rows_in_upper_tile, column_remainder_count);
            if (rows_in_lower_tile > 0) {
                nk_dots_bf16_store_sapphireamx_(&c_lower_state,
                                                c + (row_block_start + 16) * c_stride_elements + full_cols,
                                                c_stride_elements, rows_in_lower_tile, column_remainder_count);
            }
        }
    }

    _tile_release();
}
3889
+
3890
/**
 *  @brief Computes an all-pairs dot-product matrix for a set of E3M2 (fp8)
 *         vectors against themselves, for output rows [row_start, row_start + row_count).
 *
 *  Each 16x16 output tile is produced by streaming the depth in groups of 3
 *  AMX tiles (96 bf16 elements). Both operands are widened from fp8 to bf16
 *  on load; when a tile sits on the diagonal (row_tile == col_tile) the A
 *  tiles are re-used for B via an in-register transpose instead of re-loading.
 *  NOTE(review): despite the "symmetric" name, every (row, col) tile in the
 *  slice is computed — no mirroring of the lower triangle is visible here.
 *
 *  @param vectors       Row-major E3M2 matrix, `n_vectors` x `depth`.
 *  @param n_vectors     Number of vectors (columns of every output row).
 *  @param depth         Length of each vector.
 *  @param stride        Row stride of @p vectors in bytes (== elements for 1-byte e3m2).
 *  @param result        Output f32 matrix, written rows [row_start, row_end).
 *  @param result_stride Row stride of @p result in bytes.
 *  @param row_start     First output row to compute.
 *  @param row_count     Number of rows to compute; 0 means "through the end".
 */
NK_PUBLIC void nk_dots_symmetric_e3m2_sapphireamx( //
    nk_e3m2_t const *vectors, nk_size_t n_vectors, nk_size_t depth, //
    nk_size_t stride, nk_f32_t *result, nk_size_t result_stride, //
    nk_size_t row_start, nk_size_t row_count) {

    nk_size_t const stride_elements = stride; // sizeof(nk_e3m2_t) == 1, so bytes == elements
    nk_size_t const result_stride_elements = result_stride / sizeof(nk_f32_t);

    // Handle row slicing: compute rows [row_start, row_end)
    nk_size_t const row_end = (row_count == 0)
                                  ? n_vectors
                                  : (row_start + row_count < n_vectors ? row_start + row_count : n_vectors);

    // Round depth up to multiple of 96 (3 tiles x 32 bf16 elements)
    nk_size_t const depth_tiles = nk_size_divide_round_up_(depth, 32);
    nk_size_t const depth_tile_groups = nk_size_divide_round_up_(depth_tiles, 3);

    nk_dots_bf16_a16x32_sapphireamx_t a_tiles[3];
    nk_dots_bf16_a16x32_sapphireamx_t b_src_tiles[3];
    nk_dots_bf16_b32x16_sapphireamx_t b_tiles[3];
    nk_dots_bf16_state_sapphireamx_t state;

    nk_amx_tile_configure_sapphireamx_();

    for (nk_size_t row_tile = row_start; row_tile < row_end; row_tile += 16) {
        nk_size_t const valid_rows = (row_tile + 16 <= row_end) ? 16 : (row_end - row_tile);

        for (nk_size_t col_tile = 0; col_tile < n_vectors; col_tile += 16) {
            nk_size_t const valid_cols = (col_tile + 16 <= n_vectors) ? 16 : (n_vectors - col_tile);

            nk_dots_bf16_init_sapphireamx_(&state);

            for (nk_size_t depth_group_idx = 0; depth_group_idx < depth_tile_groups; depth_group_idx++) {
                nk_size_t const depth_base = depth_group_idx * 96;

                for (int tile_idx = 0; tile_idx < 3; tile_idx++) {
                    nk_size_t const depth_start = depth_base + tile_idx * 32;
                    // Clamp the last tile to the true depth; fully-past-the-end
                    // tiles get valid_depth == 0.
                    nk_size_t const valid_depth = (depth_start + 32 <= depth)
                                                      ? 32
                                                      : (depth > depth_start ? depth - depth_start : 0);

                    nk_dots_e3m2_load_a_sapphireamx_( //
                        &a_tiles[tile_idx], //
                        vectors + row_tile * stride_elements + depth_start, //
                        stride_elements, valid_rows, valid_depth);

                    if (row_tile == col_tile) {
                        // Diagonal tile: B is the same data as A — transpose
                        // in place instead of re-loading from memory.
                        nk_dots_pack_bf16_transposed_sapphireamx_(&a_tiles[tile_idx], &b_tiles[tile_idx]);
                    }
                    else {
                        nk_dots_e3m2_load_a_sapphireamx_( //
                            &b_src_tiles[tile_idx], //
                            vectors + col_tile * stride_elements + depth_start, //
                            stride_elements, valid_cols, valid_depth);
                        nk_dots_pack_bf16_transposed_sapphireamx_(&b_src_tiles[tile_idx], &b_tiles[tile_idx]);
                    }
                }

                nk_dots_bf16_update_sapphireamx_( //
                    &state, &a_tiles[0], &a_tiles[1], &a_tiles[2], &b_tiles[0], &b_tiles[1], &b_tiles[2]);
            }

            nk_dots_bf16_store_sapphireamx_( //
                &state, result + row_tile * result_stride_elements + col_tile, //
                result_stride_elements, valid_rows, valid_cols);
        }
    }
}
3958
+
3959
+ #pragma endregion // Micro Precision E3M2
3960
+
3961
+ #if defined(__clang__)
3962
+ #pragma clang attribute pop
3963
+ #elif defined(__GNUC__)
3964
+ #pragma GCC pop_options
3965
+ #endif
3966
+
3967
+ #if defined(__cplusplus)
3968
+ } // extern "C"
3969
+ #endif
3970
+
3971
+ #endif // NK_TARGET_SAPPHIREAMX
3972
+ #endif // NK_TARGET_X86_
3973
+ #endif // NK_DOTS_SAPPHIREAMX_H