numkong 7.0.0 → 7.4.2

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (315)
  1. package/README.md +197 -124
  2. package/binding.gyp +34 -484
  3. package/c/dispatch_bf16.c +59 -1
  4. package/c/dispatch_e2m3.c +41 -8
  5. package/c/dispatch_e3m2.c +49 -8
  6. package/c/dispatch_e4m3.c +51 -9
  7. package/c/dispatch_e5m2.c +45 -1
  8. package/c/dispatch_f16.c +79 -26
  9. package/c/dispatch_f16c.c +5 -5
  10. package/c/dispatch_f32.c +56 -0
  11. package/c/dispatch_f64.c +52 -0
  12. package/c/dispatch_i4.c +3 -0
  13. package/c/dispatch_i8.c +62 -3
  14. package/c/dispatch_other.c +18 -0
  15. package/c/dispatch_u1.c +54 -9
  16. package/c/dispatch_u4.c +3 -0
  17. package/c/dispatch_u8.c +64 -3
  18. package/c/numkong.c +3 -0
  19. package/include/README.md +79 -9
  20. package/include/numkong/attention/sapphireamx.h +278 -276
  21. package/include/numkong/attention/sme.h +983 -977
  22. package/include/numkong/attention.h +1 -1
  23. package/include/numkong/capabilities.h +289 -94
  24. package/include/numkong/cast/README.md +40 -40
  25. package/include/numkong/cast/diamond.h +64 -0
  26. package/include/numkong/cast/haswell.h +42 -194
  27. package/include/numkong/cast/icelake.h +42 -37
  28. package/include/numkong/cast/loongsonasx.h +252 -0
  29. package/include/numkong/cast/neon.h +216 -249
  30. package/include/numkong/cast/powervsx.h +449 -0
  31. package/include/numkong/cast/rvv.h +223 -274
  32. package/include/numkong/cast/sapphire.h +18 -18
  33. package/include/numkong/cast/serial.h +1018 -944
  34. package/include/numkong/cast/skylake.h +82 -23
  35. package/include/numkong/cast/v128relaxed.h +462 -105
  36. package/include/numkong/cast.h +24 -0
  37. package/include/numkong/cast.hpp +44 -0
  38. package/include/numkong/curved/README.md +17 -17
  39. package/include/numkong/curved/neon.h +131 -7
  40. package/include/numkong/curved/neonbfdot.h +6 -7
  41. package/include/numkong/curved/rvv.h +26 -26
  42. package/include/numkong/curved/smef64.h +186 -182
  43. package/include/numkong/curved.h +14 -18
  44. package/include/numkong/dot/README.md +154 -137
  45. package/include/numkong/dot/alder.h +43 -43
  46. package/include/numkong/dot/diamond.h +158 -0
  47. package/include/numkong/dot/genoa.h +4 -30
  48. package/include/numkong/dot/haswell.h +215 -180
  49. package/include/numkong/dot/icelake.h +190 -76
  50. package/include/numkong/dot/loongsonasx.h +671 -0
  51. package/include/numkong/dot/neon.h +124 -73
  52. package/include/numkong/dot/neonbfdot.h +11 -12
  53. package/include/numkong/dot/neonfhm.h +44 -46
  54. package/include/numkong/dot/neonfp8.h +323 -0
  55. package/include/numkong/dot/neonsdot.h +190 -76
  56. package/include/numkong/dot/powervsx.h +752 -0
  57. package/include/numkong/dot/rvv.h +92 -84
  58. package/include/numkong/dot/rvvbf16.h +12 -12
  59. package/include/numkong/dot/rvvhalf.h +12 -12
  60. package/include/numkong/dot/sapphire.h +4 -4
  61. package/include/numkong/dot/serial.h +66 -30
  62. package/include/numkong/dot/sierra.h +31 -31
  63. package/include/numkong/dot/skylake.h +142 -110
  64. package/include/numkong/dot/sve.h +217 -177
  65. package/include/numkong/dot/svebfdot.h +10 -10
  66. package/include/numkong/dot/svehalf.h +85 -41
  67. package/include/numkong/dot/svesdot.h +89 -0
  68. package/include/numkong/dot/v128relaxed.h +124 -89
  69. package/include/numkong/dot.h +114 -48
  70. package/include/numkong/dots/README.md +203 -203
  71. package/include/numkong/dots/alder.h +12 -9
  72. package/include/numkong/dots/diamond.h +86 -0
  73. package/include/numkong/dots/genoa.h +10 -4
  74. package/include/numkong/dots/haswell.h +63 -48
  75. package/include/numkong/dots/icelake.h +27 -18
  76. package/include/numkong/dots/loongsonasx.h +176 -0
  77. package/include/numkong/dots/neon.h +14 -11
  78. package/include/numkong/dots/neonbfdot.h +4 -3
  79. package/include/numkong/dots/neonfhm.h +11 -9
  80. package/include/numkong/dots/neonfp8.h +99 -0
  81. package/include/numkong/dots/neonsdot.h +48 -12
  82. package/include/numkong/dots/powervsx.h +194 -0
  83. package/include/numkong/dots/rvv.h +451 -344
  84. package/include/numkong/dots/sapphireamx.h +1028 -984
  85. package/include/numkong/dots/serial.h +213 -197
  86. package/include/numkong/dots/sierra.h +10 -7
  87. package/include/numkong/dots/skylake.h +47 -36
  88. package/include/numkong/dots/sme.h +2001 -2364
  89. package/include/numkong/dots/smebi32.h +175 -162
  90. package/include/numkong/dots/smef64.h +328 -323
  91. package/include/numkong/dots/v128relaxed.h +64 -41
  92. package/include/numkong/dots.h +573 -293
  93. package/include/numkong/dots.hpp +45 -43
  94. package/include/numkong/each/README.md +133 -137
  95. package/include/numkong/each/haswell.h +6 -6
  96. package/include/numkong/each/icelake.h +7 -7
  97. package/include/numkong/each/neon.h +76 -42
  98. package/include/numkong/each/neonbfdot.h +11 -12
  99. package/include/numkong/each/neonhalf.h +24 -116
  100. package/include/numkong/each/rvv.h +28 -28
  101. package/include/numkong/each/sapphire.h +27 -161
  102. package/include/numkong/each/serial.h +6 -6
  103. package/include/numkong/each/skylake.h +7 -7
  104. package/include/numkong/each/v128relaxed.h +562 -0
  105. package/include/numkong/each.h +148 -62
  106. package/include/numkong/each.hpp +2 -2
  107. package/include/numkong/geospatial/README.md +18 -18
  108. package/include/numkong/geospatial/haswell.h +365 -325
  109. package/include/numkong/geospatial/neon.h +350 -306
  110. package/include/numkong/geospatial/rvv.h +4 -4
  111. package/include/numkong/geospatial/skylake.h +376 -340
  112. package/include/numkong/geospatial/v128relaxed.h +366 -327
  113. package/include/numkong/geospatial.h +17 -17
  114. package/include/numkong/matrix.hpp +4 -4
  115. package/include/numkong/maxsim/README.md +14 -14
  116. package/include/numkong/maxsim/alder.h +6 -6
  117. package/include/numkong/maxsim/genoa.h +4 -4
  118. package/include/numkong/maxsim/haswell.h +6 -6
  119. package/include/numkong/maxsim/icelake.h +18 -18
  120. package/include/numkong/maxsim/neonsdot.h +21 -21
  121. package/include/numkong/maxsim/sapphireamx.h +14 -14
  122. package/include/numkong/maxsim/serial.h +6 -6
  123. package/include/numkong/maxsim/sme.h +221 -196
  124. package/include/numkong/maxsim/v128relaxed.h +6 -6
  125. package/include/numkong/mesh/README.md +62 -56
  126. package/include/numkong/mesh/haswell.h +339 -464
  127. package/include/numkong/mesh/neon.h +1100 -519
  128. package/include/numkong/mesh/neonbfdot.h +36 -68
  129. package/include/numkong/mesh/rvv.h +530 -435
  130. package/include/numkong/mesh/serial.h +75 -91
  131. package/include/numkong/mesh/skylake.h +1627 -302
  132. package/include/numkong/mesh/v128relaxed.h +443 -330
  133. package/include/numkong/mesh.h +63 -49
  134. package/include/numkong/mesh.hpp +4 -4
  135. package/include/numkong/numkong.h +3 -3
  136. package/include/numkong/numkong.hpp +1 -0
  137. package/include/numkong/probability/README.md +23 -19
  138. package/include/numkong/probability/neon.h +82 -52
  139. package/include/numkong/probability/rvv.h +28 -23
  140. package/include/numkong/probability/serial.h +51 -39
  141. package/include/numkong/probability.h +20 -23
  142. package/include/numkong/random.h +1 -1
  143. package/include/numkong/reduce/README.md +143 -138
  144. package/include/numkong/reduce/alder.h +81 -77
  145. package/include/numkong/reduce/haswell.h +222 -220
  146. package/include/numkong/reduce/neon.h +629 -519
  147. package/include/numkong/reduce/neonbfdot.h +7 -218
  148. package/include/numkong/reduce/neonfhm.h +9 -381
  149. package/include/numkong/reduce/neonsdot.h +9 -9
  150. package/include/numkong/reduce/rvv.h +928 -802
  151. package/include/numkong/reduce/serial.h +23 -27
  152. package/include/numkong/reduce/sierra.h +20 -20
  153. package/include/numkong/reduce/skylake.h +326 -324
  154. package/include/numkong/reduce/v128relaxed.h +52 -52
  155. package/include/numkong/reduce.h +4 -23
  156. package/include/numkong/reduce.hpp +156 -11
  157. package/include/numkong/scalar/README.md +6 -6
  158. package/include/numkong/scalar/haswell.h +26 -17
  159. package/include/numkong/scalar/loongsonasx.h +74 -0
  160. package/include/numkong/scalar/neon.h +9 -9
  161. package/include/numkong/scalar/powervsx.h +96 -0
  162. package/include/numkong/scalar/rvv.h +2 -2
  163. package/include/numkong/scalar/sapphire.h +21 -10
  164. package/include/numkong/scalar/serial.h +21 -21
  165. package/include/numkong/scalar.h +13 -0
  166. package/include/numkong/set/README.md +28 -28
  167. package/include/numkong/set/haswell.h +12 -12
  168. package/include/numkong/set/icelake.h +14 -14
  169. package/include/numkong/set/loongsonasx.h +181 -0
  170. package/include/numkong/set/neon.h +17 -18
  171. package/include/numkong/set/powervsx.h +326 -0
  172. package/include/numkong/set/rvv.h +4 -4
  173. package/include/numkong/set/serial.h +6 -6
  174. package/include/numkong/set/sve.h +60 -59
  175. package/include/numkong/set/v128relaxed.h +6 -6
  176. package/include/numkong/set.h +21 -7
  177. package/include/numkong/sets/README.md +26 -26
  178. package/include/numkong/sets/loongsonasx.h +52 -0
  179. package/include/numkong/sets/powervsx.h +65 -0
  180. package/include/numkong/sets/smebi32.h +395 -364
  181. package/include/numkong/sets.h +83 -40
  182. package/include/numkong/sparse/README.md +4 -4
  183. package/include/numkong/sparse/icelake.h +101 -101
  184. package/include/numkong/sparse/serial.h +1 -1
  185. package/include/numkong/sparse/sve2.h +137 -141
  186. package/include/numkong/sparse/turin.h +12 -12
  187. package/include/numkong/sparse.h +10 -10
  188. package/include/numkong/spatial/README.md +230 -226
  189. package/include/numkong/spatial/alder.h +113 -116
  190. package/include/numkong/spatial/diamond.h +240 -0
  191. package/include/numkong/spatial/genoa.h +0 -68
  192. package/include/numkong/spatial/haswell.h +74 -55
  193. package/include/numkong/spatial/icelake.h +539 -58
  194. package/include/numkong/spatial/loongsonasx.h +483 -0
  195. package/include/numkong/spatial/neon.h +125 -52
  196. package/include/numkong/spatial/neonbfdot.h +8 -9
  197. package/include/numkong/spatial/neonfp8.h +258 -0
  198. package/include/numkong/spatial/neonsdot.h +180 -12
  199. package/include/numkong/spatial/powervsx.h +738 -0
  200. package/include/numkong/spatial/rvv.h +146 -139
  201. package/include/numkong/spatial/rvvbf16.h +17 -12
  202. package/include/numkong/spatial/rvvhalf.h +13 -10
  203. package/include/numkong/spatial/serial.h +13 -12
  204. package/include/numkong/spatial/sierra.h +232 -39
  205. package/include/numkong/spatial/skylake.h +73 -74
  206. package/include/numkong/spatial/sve.h +93 -72
  207. package/include/numkong/spatial/svebfdot.h +29 -29
  208. package/include/numkong/spatial/svehalf.h +52 -26
  209. package/include/numkong/spatial/svesdot.h +142 -0
  210. package/include/numkong/spatial/v128relaxed.h +293 -41
  211. package/include/numkong/spatial.h +338 -82
  212. package/include/numkong/spatials/README.md +194 -194
  213. package/include/numkong/spatials/diamond.h +82 -0
  214. package/include/numkong/spatials/haswell.h +2 -2
  215. package/include/numkong/spatials/loongsonasx.h +153 -0
  216. package/include/numkong/spatials/neonfp8.h +111 -0
  217. package/include/numkong/spatials/neonsdot.h +34 -0
  218. package/include/numkong/spatials/powervsx.h +153 -0
  219. package/include/numkong/spatials/rvv.h +259 -243
  220. package/include/numkong/spatials/sapphireamx.h +173 -173
  221. package/include/numkong/spatials/serial.h +2 -2
  222. package/include/numkong/spatials/skylake.h +2 -2
  223. package/include/numkong/spatials/sme.h +590 -605
  224. package/include/numkong/spatials/smef64.h +139 -130
  225. package/include/numkong/spatials/v128relaxed.h +2 -2
  226. package/include/numkong/spatials.h +820 -500
  227. package/include/numkong/spatials.hpp +49 -48
  228. package/include/numkong/tensor.hpp +406 -17
  229. package/include/numkong/trigonometry/README.md +19 -19
  230. package/include/numkong/trigonometry/haswell.h +402 -401
  231. package/include/numkong/trigonometry/neon.h +386 -387
  232. package/include/numkong/trigonometry/rvv.h +52 -51
  233. package/include/numkong/trigonometry/serial.h +13 -13
  234. package/include/numkong/trigonometry/skylake.h +373 -369
  235. package/include/numkong/trigonometry/v128relaxed.h +375 -374
  236. package/include/numkong/trigonometry.h +13 -13
  237. package/include/numkong/trigonometry.hpp +2 -2
  238. package/include/numkong/types.h +287 -49
  239. package/include/numkong/types.hpp +436 -12
  240. package/include/numkong/vector.hpp +82 -14
  241. package/javascript/dist/cjs/numkong-wasm.js +6 -12
  242. package/javascript/dist/cjs/numkong.d.ts +7 -1
  243. package/javascript/dist/cjs/numkong.js +37 -11
  244. package/javascript/dist/cjs/types.d.ts +9 -0
  245. package/javascript/dist/cjs/types.js +96 -0
  246. package/javascript/dist/esm/numkong-browser.d.ts +14 -0
  247. package/javascript/dist/esm/numkong-browser.js +23 -0
  248. package/javascript/dist/esm/numkong-wasm.js +6 -12
  249. package/javascript/dist/esm/numkong.d.ts +7 -1
  250. package/javascript/dist/esm/numkong.js +37 -11
  251. package/javascript/dist/esm/types.d.ts +9 -0
  252. package/javascript/dist/esm/types.js +96 -0
  253. package/javascript/node-gyp-build.d.ts +4 -1
  254. package/javascript/numkong-browser.ts +40 -0
  255. package/javascript/numkong-wasm.ts +7 -13
  256. package/javascript/numkong.c +5 -26
  257. package/javascript/numkong.ts +36 -11
  258. package/javascript/tsconfig-base.json +1 -0
  259. package/javascript/tsconfig-cjs.json +6 -1
  260. package/javascript/types.ts +110 -0
  261. package/numkong.gypi +101 -0
  262. package/package.json +34 -13
  263. package/probes/arm_neon.c +8 -0
  264. package/probes/arm_neon_bfdot.c +9 -0
  265. package/probes/arm_neon_fhm.c +9 -0
  266. package/probes/arm_neon_half.c +8 -0
  267. package/probes/arm_neon_sdot.c +9 -0
  268. package/probes/arm_neonfp8.c +9 -0
  269. package/probes/arm_sme.c +16 -0
  270. package/probes/arm_sme2.c +16 -0
  271. package/probes/arm_sme2p1.c +16 -0
  272. package/probes/arm_sme_bf16.c +16 -0
  273. package/probes/arm_sme_bi32.c +16 -0
  274. package/probes/arm_sme_f64.c +16 -0
  275. package/probes/arm_sme_fa64.c +14 -0
  276. package/probes/arm_sme_half.c +16 -0
  277. package/probes/arm_sme_lut2.c +15 -0
  278. package/probes/arm_sve.c +18 -0
  279. package/probes/arm_sve2.c +20 -0
  280. package/probes/arm_sve2p1.c +18 -0
  281. package/probes/arm_sve_bfdot.c +20 -0
  282. package/probes/arm_sve_half.c +18 -0
  283. package/probes/arm_sve_sdot.c +21 -0
  284. package/probes/loongarch_lasx.c +12 -0
  285. package/probes/power_vsx.c +12 -0
  286. package/probes/probe.js +127 -0
  287. package/probes/riscv_rvv.c +14 -0
  288. package/probes/riscv_rvv_bb.c +15 -0
  289. package/probes/riscv_rvv_bf16.c +17 -0
  290. package/probes/riscv_rvv_half.c +14 -0
  291. package/probes/wasm_v128relaxed.c +11 -0
  292. package/probes/x86_alder.c +17 -0
  293. package/probes/x86_diamond.c +17 -0
  294. package/probes/x86_genoa.c +17 -0
  295. package/probes/x86_graniteamx.c +19 -0
  296. package/probes/x86_haswell.c +11 -0
  297. package/probes/x86_icelake.c +17 -0
  298. package/probes/x86_sapphire.c +16 -0
  299. package/probes/x86_sapphireamx.c +18 -0
  300. package/probes/x86_sierra.c +17 -0
  301. package/probes/x86_skylake.c +15 -0
  302. package/probes/x86_turin.c +17 -0
  303. package/wasm/numkong-emscripten.js +2 -0
  304. package/wasm/numkong.d.ts +14 -0
  305. package/wasm/numkong.js +1124 -0
  306. package/wasm/numkong.wasm +0 -0
  307. package/include/numkong/curved/neonhalf.h +0 -212
  308. package/include/numkong/dot/neonhalf.h +0 -198
  309. package/include/numkong/dots/neonhalf.h +0 -57
  310. package/include/numkong/mesh/neonhalf.h +0 -616
  311. package/include/numkong/reduce/neonhalf.h +0 -157
  312. package/include/numkong/spatial/neonhalf.h +0 -118
  313. package/include/numkong/spatial/sapphire.h +0 -343
  314. package/include/numkong/spatials/neonhalf.h +0 -58
  315. package/javascript/README.md +0 -246
@@ -79,6 +79,13 @@
79
79
  #include "numkong/spatial/serial.h" // `nk_f32_sqrt_serial`
80
80
  #include "numkong/reduce.h" // `nk_reduce_moments_*`
81
81
 
82
+ /* GCC's -Wstringop-overflow produces false positives on the padded accumulator arrays
83
+ * in nk_define_cross_symmetric_ macro expansions (accumulators[4][7] with runtime indexing). */
84
+ #if defined(__GNUC__) && !defined(__clang__)
85
+ #pragma GCC diagnostic push
86
+ #pragma GCC diagnostic ignored "-Wstringop-overflow"
87
+ #endif
88
+
82
89
  #if defined(__cplusplus)
83
90
  extern "C" {
84
91
  #endif
@@ -264,82 +271,59 @@ NK_INTERNAL nk_i32_t nk_dots_reduce_sum_i4_(nk_i4x2_t const *data, nk_size_t cou
264
271
  }
265
272
 
266
273
  /**
267
- * @brief Generates function to pack and optionally convert B matrix for efficient GEMM inner loops.
268
- *
269
- * Packing serves two performance-critical purposes:
270
- *
271
- * 1. Type conversion (input_type → intermediate_type): For mixed-precision GEMM, convert B values
272
- * once during packing rather than repeatedly in tight inner loops. Example: F16 → F32 conversion
273
- * happens once per value instead of once per (row of A × value of B) access. This amortizes
274
- * conversion cost across all rows of A.
274
+ * @brief Generates pack function using SIMD load/store helpers.
275
275
  *
276
- * 2. Cache optimization: Pad depth to break power-of-2 byte strides that cause cache associativity
277
- * conflicts. Example: depth = 8192, F32 stride = 32,768 bytes (power-of-2) maps to same cache sets,
278
- * causing conflict misses. Padding to 8200 → stride = 32,800 bytes (non-power-of-2) distributes
279
- * accesses across more cache sets.
276
+ * Packs the B matrix into padded row-major layout with optional type conversion,
277
+ * using vectorized load/store for the bulk copy and a small scalar tail for padding.
280
278
  *
281
- * Input layout: B[column_count, depth] stored row-major with b_stride_in_bytes between rows
282
- * Output layout: B_packed[column_count, depth_padded] - simple column-major, no grouping
283
- * Addressing: B_packed[j, k] = packed_data[j × depth_padded + k]
284
- *
285
- * Depth padding: Round up to `depth_simd_dimensions` multiple, then add `depth_simd_dimensions`
286
- * if stride is power-of-2. Zero-initializes entire buffer before copying to handle padding safely.
287
- *
288
- * @param api_name Operation name (hammings, dots)
289
- * @param input_type_name Original type's name of B matrix values (i4, f16, bf16, e4m3, e5m2, f32, etc.)
290
- * @param isa_suffix Platform Instruct Set Architecture suffix (serial, haswell, icelake, etc.)
291
- * @param input_type Original type of B matrix values (i4x2, f16, bf16, e4m3, e5m2, f32, etc.)
292
- * @param intermediate_type Internal storage type in packed buffer (often bf16 or f32 for mixed precision)
293
- * @param convert_value_fn Element conversion function: void fn(input_type const*, intermediate_type*)
294
- * @param norm_value_type Type of per-column norm values (f32, f64, u32) appended after packed data
295
- * @param compute_norm_fn Function: norm_value_type fn(input_value_type const*, nk_size_t count)
296
- * @param depth_simd_dimensions SIMD vector width in values for depth padding alignment
297
- * @param dimensions_per_value Number of logical dimensions in a single value of input_type.
279
+ * @param vec_type SIMD vector type (nk_b512_vec_t, nk_b256_vec_t, nk_b128_vec_t)
280
+ * @param load_fn Full load: void fn(void const*, vec_type*)
281
+ * @param partial_load_fn Masked/partial load: void fn(void const*, vec_type*, nk_size_t)
282
+ * @param store_fn Full store: void fn(vec_type const*, void*)
283
+ * @param partial_store_fn Masked/partial store: void fn(vec_type const*, void*, nk_size_t)
284
+ * @param simd_width Elements per SIMD load/store operation
298
285
  */
299
- #define nk_define_cross_pack_(api_name, input_type_name, isa_suffix, input_value_type, packed_value_type, \
300
- convert_value_fn, norm_value_type, compute_norm_fn, depth_simd_dimensions, \
301
- dimensions_per_value) \
286
+ #define nk_define_cross_pack_(api_name, input_type_name, isa_suffix, input_value_type, packed_value_type, vec_type, \
287
+ load_fn, partial_load_fn, store_fn, partial_store_fn, simd_width, norm_value_type, \
288
+ compute_norm_fn, depth_simd_dimensions, dimensions_per_value) \
302
289
  NK_PUBLIC void nk_##api_name##_pack_##input_type_name##_##isa_suffix( \
303
290
  nk_##input_value_type##_t const *b, nk_size_t column_count, nk_size_t depth, nk_size_t b_stride_in_bytes, \
304
291
  void *b_packed) { \
305
- /* Use identical padding calculation as pack_size */ \
306
292
  nk_size_t depth_dimensions_padded = nk_size_round_up_to_multiple_(depth, depth_simd_dimensions); \
307
293
  nk_size_t depth_values_padded = nk_size_divide_round_up_(depth_dimensions_padded, dimensions_per_value); \
308
- \
309
- /* Power-of-2 breaking (same as pack_size) */ \
310
294
  nk_size_t const stride_bytes = depth_values_padded * sizeof(nk_##packed_value_type##_t); \
311
- if ((stride_bytes & (stride_bytes - 1)) == 0 && stride_bytes > 0) { \
295
+ if ((stride_bytes & (stride_bytes - 1)) == 0 && stride_bytes > 0) \
312
296
  depth_values_padded += nk_size_divide_round_up_(depth_simd_dimensions, dimensions_per_value); \
313
- } \
314
- \
315
- /* Calculate input depth in values */ \
316
297
  nk_size_t const depth_in_values = nk_size_divide_round_up_(depth, dimensions_per_value); \
317
298
  \
318
- /* Store dimensions in header */ \
319
299
  nk_cross_packed_buffer_header_t *header = (nk_cross_packed_buffer_header_t *)b_packed; \
320
300
  header->column_count = (nk_u32_t)column_count; \
321
- header->depth_dimensions = (nk_u32_t)depth; /* depth in dimensions (nibbles for i4/u4) */ \
322
- header->depth_padded_values = (nk_u32_t)depth_values_padded; /* padded depth in VALUES (bytes for i4/u4) */ \
301
+ header->depth_dimensions = (nk_u32_t)depth; \
302
+ header->depth_padded_values = (nk_u32_t)depth_values_padded; \
323
303
  \
324
304
  nk_##packed_value_type##_t *packed = (nk_##packed_value_type##_t *)((char *)b_packed + \
325
305
  sizeof(nk_cross_packed_buffer_header_t)); \
306
+ nk_size_t const full_chunks = depth_in_values / (simd_width); \
307
+ nk_size_t const remainder = depth_in_values % (simd_width); \
326
308
  \
327
- /* Zero entire buffer for depth padding */ \
328
- nk_size_t const total_values = column_count * depth_values_padded; \
329
- for (nk_size_t i = 0; i < total_values; ++i) packed[i] = 0; \
330
- \
331
- /* Copy/convert B[column_count, depth] to packed[column_count, depth_padded] - simple column-major */ \
332
309
  for (nk_size_t column_index = 0; column_index < column_count; ++column_index) { \
333
- nk_##packed_value_type##_t *destination_row = packed + column_index * depth_values_padded; \
334
310
  nk_##input_value_type##_t const *source_row = \
335
311
  (nk_##input_value_type##_t const *)((char const *)b + column_index * b_stride_in_bytes); \
336
- for (nk_size_t depth_index = 0; depth_index < depth_in_values; ++depth_index) { \
337
- convert_value_fn(&source_row[depth_index], &destination_row[depth_index]); \
312
+ nk_##packed_value_type##_t *destination_row = packed + column_index * depth_values_padded; \
313
+ for (nk_size_t chunk = 0; chunk < full_chunks; ++chunk) { \
314
+ vec_type vec; \
315
+ load_fn(source_row + chunk * (simd_width), &vec); \
316
+ store_fn(&vec, destination_row + chunk * (simd_width)); \
338
317
  } \
339
- /* Padding values already zeroed above */ \
318
+ if (remainder > 0) { \
319
+ vec_type vec; \
320
+ partial_load_fn(source_row + full_chunks * (simd_width), &vec, remainder); \
321
+ partial_store_fn(&vec, destination_row + full_chunks * (simd_width), remainder); \
322
+ } \
323
+ for (nk_size_t pad = depth_in_values; pad < depth_values_padded; ++pad) destination_row[pad] = 0; \
340
324
  } \
341
325
  \
342
- /* Append per-column norms after packed data */ \
326
+ nk_size_t const total_values = column_count * depth_values_padded; \
343
327
  nk_##norm_value_type##_t *norms = (nk_##norm_value_type##_t *)(packed + total_values); \
344
328
  for (nk_size_t column_index = 0; column_index < column_count; ++column_index) { \
345
329
  nk_##input_value_type##_t const *source_row = \
@@ -372,42 +356,51 @@ NK_INTERNAL nk_i32_t nk_dots_reduce_sum_i4_(nk_i4x2_t const *data, nk_size_t cou
372
356
  }
373
357
 
374
358
  /**
375
- * @brief Generates function to pack B matrix with BOTH norms and column sums for compensated GEMM.
376
- *
377
- * Like nk_define_cross_pack_ but uses compute_moments_fn(data, count, &sum, &norm) to compute
378
- * both sum and norm in a single pass, storing both after the packed data.
379
- * Layout: [ Header ] [ Packed data ] [ Norms ] [ Column sums ]
359
+ * @brief Like nk_define_cross_pack_ but stores both per-column norms AND column sums.
360
+ * Layout: [ Header 64B ] [ Packed data ] [ Norms (norm_type) ] [ Column sums (sum_type) ]
380
361
  */
381
362
  #define nk_define_cross_compensated_pack_(api_name, input_type_name, isa_suffix, input_value_type, packed_value_type, \
382
- convert_value_fn, sum_value_type, norm_value_type, compute_moments_fn, \
383
- depth_simd_dimensions, dimensions_per_value) \
363
+ vec_type, load_fn, partial_load_fn, store_fn, partial_store_fn, simd_width, \
364
+ sum_value_type, norm_value_type, compute_moments_fn, depth_simd_dimensions, \
365
+ dimensions_per_value) \
384
366
  NK_PUBLIC void nk_##api_name##_pack_##input_type_name##_##isa_suffix( \
385
367
  nk_##input_value_type##_t const *b, nk_size_t column_count, nk_size_t depth, nk_size_t b_stride_in_bytes, \
386
368
  void *b_packed) { \
387
369
  nk_size_t depth_dimensions_padded = nk_size_round_up_to_multiple_(depth, depth_simd_dimensions); \
388
370
  nk_size_t depth_values_padded = nk_size_divide_round_up_(depth_dimensions_padded, dimensions_per_value); \
389
371
  nk_size_t const stride_bytes = depth_values_padded * sizeof(nk_##packed_value_type##_t); \
390
- if ((stride_bytes & (stride_bytes - 1)) == 0 && stride_bytes > 0) { \
372
+ if ((stride_bytes & (stride_bytes - 1)) == 0 && stride_bytes > 0) \
391
373
  depth_values_padded += nk_size_divide_round_up_(depth_simd_dimensions, dimensions_per_value); \
392
- } \
393
374
  nk_size_t const depth_in_values = nk_size_divide_round_up_(depth, dimensions_per_value); \
375
+ \
394
376
  nk_cross_packed_buffer_header_t *header = (nk_cross_packed_buffer_header_t *)b_packed; \
395
377
  header->column_count = (nk_u32_t)column_count; \
396
378
  header->depth_dimensions = (nk_u32_t)depth; \
397
379
  header->depth_padded_values = (nk_u32_t)depth_values_padded; \
380
+ \
398
381
  nk_##packed_value_type##_t *packed = (nk_##packed_value_type##_t *)((char *)b_packed + \
399
382
  sizeof(nk_cross_packed_buffer_header_t)); \
400
- nk_size_t const total_values = column_count * depth_values_padded; \
401
- for (nk_size_t i = 0; i < total_values; ++i) packed[i] = 0; \
383
+ nk_size_t const full_chunks = depth_in_values / (simd_width); \
384
+ nk_size_t const remainder = depth_in_values % (simd_width); \
385
+ \
402
386
  for (nk_size_t column_index = 0; column_index < column_count; ++column_index) { \
403
- nk_##packed_value_type##_t *destination_row = packed + column_index * depth_values_padded; \
404
387
  nk_##input_value_type##_t const *source_row = \
405
388
  (nk_##input_value_type##_t const *)((char const *)b + column_index * b_stride_in_bytes); \
406
- for (nk_size_t depth_index = 0; depth_index < depth_in_values; ++depth_index) { \
407
- convert_value_fn(&source_row[depth_index], &destination_row[depth_index]); \
389
+ nk_##packed_value_type##_t *destination_row = packed + column_index * depth_values_padded; \
390
+ for (nk_size_t chunk = 0; chunk < full_chunks; ++chunk) { \
391
+ vec_type vec; \
392
+ load_fn(source_row + chunk * (simd_width), &vec); \
393
+ store_fn(&vec, destination_row + chunk * (simd_width)); \
394
+ } \
395
+ if (remainder > 0) { \
396
+ vec_type vec; \
397
+ partial_load_fn(source_row + full_chunks * (simd_width), &vec, remainder); \
398
+ partial_store_fn(&vec, destination_row + full_chunks * (simd_width), remainder); \
408
399
  } \
400
+ for (nk_size_t pad = depth_in_values; pad < depth_values_padded; ++pad) destination_row[pad] = 0; \
409
401
  } \
410
- /* Norms first (same offset as non-compensated pack), then column sums */ \
402
+ \
403
+ nk_size_t const total_values = column_count * depth_values_padded; \
411
404
  nk_##norm_value_type##_t *norms = (nk_##norm_value_type##_t *)(packed + total_values); \
412
405
  nk_##sum_value_type##_t *col_sums = (nk_##sum_value_type##_t *)(norms + column_count); \
413
406
  for (nk_size_t column_index = 0; column_index < column_count; ++column_index) { \
@@ -1246,9 +1239,9 @@ NK_INTERNAL nk_i32_t nk_dots_reduce_sum_i4_(nk_i4x2_t const *data, nk_size_t cou
1246
1239
  nk_##packed_value_type##_t const *bp5 = packed_data + (tc + 5) * depth_padded; \
1247
1240
  nk_##packed_value_type##_t const *bp6 = packed_data + (tc + 6) * depth_padded; \
1248
1241
  nk_##packed_value_type##_t const *bp7 = packed_data + (tc + 7) * depth_padded; \
1249
- result_vec_type b_sum_lo, b_sum_hi; \
1250
- load_sum_fn(b_sums + tc, &b_sum_lo); \
1251
- load_sum_fn(b_sums + tc + 4, &b_sum_hi); \
1242
+ result_vec_type b_sum_low, b_sum_high; \
1243
+ load_sum_fn(b_sums + tc, &b_sum_low); \
1244
+ load_sum_fn(b_sums + tc + 4, &b_sum_high); \
1252
1245
  for (nk_size_t ri = rb2; ri < re2; ++ri) { \
1253
1246
  state_type s0, s1, s2, s3, s4, s5, s6, s7; \
1254
1247
  init_accumulator_fn(&s0), init_accumulator_fn(&s1), init_accumulator_fn(&s2), \
@@ -1277,9 +1270,9 @@ NK_INTERNAL nk_i32_t nk_dots_reduce_sum_i4_(nk_i4x2_t const *data, nk_size_t cou
1277
1270
  result_vec_type rv; \
1278
1271
  nk_##result_value_type##_t *c_row = (nk_##result_value_type##_t *)((char *)c_matrix + \
1279
1272
  ri * c_stride_in_bytes); \
1280
- compensated_finalize_fn(&s0, &s1, &s2, &s3, depth, a_sum_val, b_sum_lo, &rv); \
1273
+ compensated_finalize_fn(&s0, &s1, &s2, &s3, depth, a_sum_val, b_sum_low, &rv); \
1281
1274
  store_fn(&rv, c_row + tc); \
1282
- compensated_finalize_fn(&s4, &s5, &s6, &s7, depth, a_sum_val, b_sum_hi, &rv); \
1275
+ compensated_finalize_fn(&s4, &s5, &s6, &s7, depth, a_sum_val, b_sum_high, &rv); \
1283
1276
  store_fn(&rv, c_row + tc + 4); \
1284
1277
  } \
1285
1278
  } \
@@ -1893,8 +1886,9 @@ NK_INTERNAL nk_i32_t nk_dots_reduce_sum_i4_(nk_i4x2_t const *data, nk_size_t cou
1893
1886
  } \
1894
1887
  } \
1895
1888
  NK_PUBLIC void nk_##api_name##_symmetric_##input_type_name##_##isa_suffix( \
1896
- nk_##input_value_type##_t const *vectors, nk_size_t n_vectors, nk_size_t depth, nk_size_t stride, \
1897
- nk_##result_value_type##_t *result, nk_size_t result_stride, nk_size_t row_start, nk_size_t row_count) { \
1889
+ nk_##input_value_type##_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_in_bytes, \
1890
+ nk_##result_value_type##_t *result, nk_size_t result_stride_in_bytes, nk_size_t row_start, \
1891
+ nk_size_t row_count) { \
1898
1892
  nk_size_t const macro_tile_size = 32; \
1899
1893
  nk_size_t const row_block_size = 128; /* L2 cache blocking */ \
1900
1894
  nk_size_t const column_block_size = 2048; /* L3 cache blocking */ \
@@ -1904,13 +1898,13 @@ NK_INTERNAL nk_i32_t nk_dots_reduce_sum_i4_(nk_i4x2_t const *data, nk_size_t cou
1904
1898
  nk_size_t const remainder_depth = depth_in_values - aligned_depth; \
1905
1899
  nk_size_t const remainder_dimensions = depth - depth_dimensions_aligned; \
1906
1900
  nk_size_t const depth_step = nk_size_divide_round_up_(depth_simd_dimensions, dimensions_per_value); \
1907
- nk_size_t const result_stride_values = result_stride / sizeof(nk_##result_value_type##_t); \
1908
- nk_size_t const row_end = (row_start + row_count < n_vectors) ? (row_start + row_count) : n_vectors; \
1901
+ nk_size_t const result_stride_values = result_stride_in_bytes / sizeof(nk_##result_value_type##_t); \
1902
+ nk_size_t const row_end = (row_start + row_count < vectors_count) ? (row_start + row_count) : vectors_count; \
1909
1903
  \
1910
1904
  /* Process upper triangle with L3/L2/L1 blocking (column blocks → row blocks → 32×32 macro-tiles) */ \
1911
- for (nk_size_t j_block = 0; j_block < n_vectors; j_block += column_block_size) { \
1912
- nk_size_t j_block_end = (j_block + column_block_size < n_vectors) ? j_block + column_block_size \
1913
- : n_vectors; \
1905
+ for (nk_size_t j_block = 0; j_block < vectors_count; j_block += column_block_size) { \
1906
+ nk_size_t j_block_end = (j_block + column_block_size < vectors_count) ? j_block + column_block_size \
1907
+ : vectors_count; \
1914
1908
  \
1915
1909
  for (nk_size_t i_block = row_start; i_block < row_end; i_block += row_block_size) { \
1916
1910
  nk_size_t i_block_end = (i_block + row_block_size < row_end) ? i_block + row_block_size : row_end; \
@@ -1933,7 +1927,7 @@ NK_INTERNAL nk_i32_t nk_dots_reduce_sum_i4_(nk_i4x2_t const *data, nk_size_t cou
1933
1927
  nk_##input_value_type##_t const *vec_ptrs_j[32]; \
1934
1928
  for (nk_size_t k = 0; k < macro_i_size; k++) \
1935
1929
  vec_ptrs_i[k] = (nk_##input_value_type##_t const *)((char const *)vectors + \
1936
- (i_macro + k) * stride); \
1930
+ (i_macro + k) * stride_in_bytes); \
1937
1931
  for (nk_size_t k = macro_i_size; k < 32; k++) vec_ptrs_i[k] = vec_ptrs_i[0]; \
1938
1932
  \
1939
1933
  if (i_macro == j_macro && macro_i_size == macro_j_size) { \
@@ -1947,7 +1941,7 @@ NK_INTERNAL nk_i32_t nk_dots_reduce_sum_i4_(nk_i4x2_t const *data, nk_size_t cou
1947
1941
  /* Off-diagonal macro-tile */ \
1948
1942
  for (nk_size_t k = 0; k < macro_j_size; k++) \
1949
1943
  vec_ptrs_j[k] = (nk_##input_value_type##_t const *)((char const *)vectors + \
1950
- (j_macro + k) * stride); \
1944
+ (j_macro + k) * stride_in_bytes); \
1951
1945
  for (nk_size_t k = macro_j_size; k < 32; k++) vec_ptrs_j[k] = vec_ptrs_j[0]; \
1952
1946
  nk_##api_name##_symmetric_offdiagonal_##input_type_name##_##isa_suffix##_( \
1953
1947
  vec_ptrs_i, vec_ptrs_j, i_macro, j_macro, macro_i_size, macro_j_size, aligned_depth, \
@@ -2365,28 +2359,29 @@ NK_INTERNAL nk_i32_t nk_dots_reduce_sum_i4_(nk_i4x2_t const *data, nk_size_t cou
2365
2359
  } \
2366
2360
  } \
2367
2361
  NK_PUBLIC void nk_##api_name##_symmetric_##input_type_name##_##isa_suffix( \
2368
- nk_##input_value_type##_t const *vectors, nk_size_t n_vectors, nk_size_t depth, nk_size_t stride, \
2369
- nk_##result_value_type##_t *result, nk_size_t result_stride, nk_size_t row_start, nk_size_t row_count) { \
2362
+ nk_##input_value_type##_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_in_bytes, \
2363
+ nk_##result_value_type##_t *result, nk_size_t result_stride_in_bytes, nk_size_t row_start, \
2364
+ nk_size_t row_count) { \
2370
2365
  nk_size_t const macro_tile_size = 32; \
2371
2366
  nk_size_t const finalizer_batch_size = 4; \
2372
2367
  nk_size_t const row_block_size = 128; /* L2 cache blocking */ \
2373
2368
  nk_size_t const column_block_size = 2048; /* L3 cache blocking */ \
2374
2369
  \
2375
2370
  /* Stride and depth calculations */ \
2376
- nk_size_t const vectors_stride_values = stride / sizeof(nk_##input_value_type##_t); \
2377
- nk_size_t const result_stride_values = result_stride / sizeof(nk_##result_value_type##_t); \
2371
+ nk_size_t const vectors_stride_values = stride_in_bytes / sizeof(nk_##input_value_type##_t); \
2372
+ nk_size_t const result_stride_values = result_stride_in_bytes / sizeof(nk_##result_value_type##_t); \
2378
2373
  nk_size_t const depth_dimensions_aligned = (depth / depth_simd_dimensions) * depth_simd_dimensions; \
2379
2374
  nk_size_t const aligned_depth = nk_size_divide_round_up_(depth_dimensions_aligned, dimensions_per_value); \
2380
2375
  nk_size_t const depth_in_values = nk_size_divide_round_up_(depth, dimensions_per_value); \
2381
2376
  nk_size_t const remainder_depth = depth_in_values - aligned_depth; \
2382
2377
  nk_size_t const remainder_dimensions = depth - depth_dimensions_aligned; \
2383
2378
  nk_size_t const depth_step_values = nk_size_divide_round_up_(depth_simd_dimensions, dimensions_per_value); \
2384
- nk_size_t const row_end = (row_start + row_count < n_vectors) ? (row_start + row_count) : n_vectors; \
2379
+ nk_size_t const row_end = (row_start + row_count < vectors_count) ? (row_start + row_count) : vectors_count; \
2385
2380
  \
2386
2381
  /* Process upper triangle with L3/L2/L1 blocking (column blocks → row blocks → 32×32 macro-tiles) */ \
2387
- for (nk_size_t j_block = 0; j_block < n_vectors; j_block += column_block_size) { \
2388
- nk_size_t j_block_end = (j_block + column_block_size < n_vectors) ? j_block + column_block_size \
2389
- : n_vectors; \
2382
+ for (nk_size_t j_block = 0; j_block < vectors_count; j_block += column_block_size) { \
2383
+ nk_size_t j_block_end = (j_block + column_block_size < vectors_count) ? j_block + column_block_size \
2384
+ : vectors_count; \
2390
2385
  \
2391
2386
  for (nk_size_t i_block = row_start; i_block < row_end; i_block += row_block_size) { \
2392
2387
  nk_size_t i_block_end = (i_block + row_block_size < row_end) ? i_block + row_block_size : row_end; \
@@ -2451,9 +2446,9 @@ NK_INTERNAL nk_i32_t nk_dots_reduce_sum_i4_(nk_i4x2_t const *data, nk_size_t cou
2451
2446
  /* F64 GEMM: depth_simd_dimensions=2 (2 f64s = 16 bytes) */
2452
2447
  nk_define_cross_pack_size_(dots, f64, serial, f64, f64, /*norm_value_type=*/f64, /*depth_simd_dimensions=*/2,
2453
2448
  /*dimensions_per_value=*/1)
2454
- nk_define_cross_pack_(dots, f64, serial, f64, f64, nk_assign_from_to_, /*norm_value_type=*/f64,
2455
- nk_dots_reduce_sumsq_f64_,
2456
- /*depth_simd_dimensions=*/2, /*dimensions_per_value=*/1)
2449
+ nk_define_cross_pack_(dots, f64, serial, f64, f64, nk_b128_vec_t, nk_load_b128_serial_, nk_partial_load_b64x2_serial_,
2450
+ nk_store_b128_serial_, nk_partial_store_b64x2_serial_, /*simd_width=*/2, /*norm_value_type=*/f64,
2451
+ nk_dots_reduce_sumsq_f64_, /*depth_simd_dimensions=*/2, /*dimensions_per_value=*/1)
2457
2452
  nk_define_cross_symmetric_(dots, f64, serial, f64, f64, nk_b128_vec_t, nk_dot_f64x2_state_serial_t, nk_b256_vec_t,
2458
2453
  nk_dot_f64x2_init_serial, nk_load_b128_serial_, nk_partial_load_b64x2_serial_,
2459
2454
  nk_dot_f64x2_update_serial, nk_dot_f64x2_finalize_serial, nk_store_b256_serial_,
@@ -2468,9 +2463,9 @@ nk_define_cross_packed_(dots, f64, serial, f64, f64, f64, nk_b128_vec_t, nk_dot_
2468
2463
  /* F32 GEMM: depth_simd_dimensions=4 (4 f32s = 16 bytes) */
2469
2464
  nk_define_cross_pack_size_(dots, f32, serial, f32, f32, /*norm_value_type=*/f64, /*depth_simd_dimensions=*/4,
2470
2465
  /*dimensions_per_value=*/1)
2471
- nk_define_cross_pack_(dots, f32, serial, f32, f32, nk_assign_from_to_, /*norm_value_type=*/f64,
2472
- nk_dots_reduce_sumsq_f32_,
2473
- /*depth_simd_dimensions=*/4, /*dimensions_per_value=*/1)
2466
+ nk_define_cross_pack_(dots, f32, serial, f32, f32, nk_b128_vec_t, nk_load_b128_serial_, nk_partial_load_b32x4_serial_,
2467
+ nk_store_b128_serial_, nk_partial_store_b32x4_serial_, /*simd_width=*/4, /*norm_value_type=*/f64,
2468
+ nk_dots_reduce_sumsq_f32_, /*depth_simd_dimensions=*/4, /*dimensions_per_value=*/1)
2474
2469
  nk_define_cross_symmetric_(dots, f32, serial, f32, f64, nk_b128_vec_t, nk_dot_f32x4_state_serial_t, nk_b256_vec_t,
2475
2470
  nk_dot_f32x4_init_serial, nk_load_b128_serial_, nk_partial_load_b32x4_serial_,
2476
2471
  nk_dot_f32x4_update_serial, nk_dot_f32x4_finalize_serial, nk_store_b256_serial_,
@@ -2482,28 +2477,31 @@ nk_define_cross_packed_(dots, f32, serial, f32, f32, f64, nk_b128_vec_t, nk_dot_
2482
2477
  nk_dot_f32x4_finalize_serial, nk_store_b256_serial_, nk_partial_store_b64x4_serial_,
2483
2478
  /*depth_simd_dimensions=*/4, /*dimensions_per_value=*/1)
2484
2479
 
2485
- /* F16 GEMM: depth_simd_dimensions=8 (8 f16s = 16 bytes), F32 accumulator */
2486
- nk_define_cross_pack_size_(dots, f16, serial, f16, f32, /*norm_value_type=*/f32, /*depth_simd_dimensions=*/8,
2480
+ /* F16 packed GEMM: pre-upcast B to f32 and process 4 logical dimensions per 128-bit step. */
2481
+ nk_define_cross_pack_size_(dots, f16, serial, f16, f32, /*norm_value_type=*/f32, /*depth_simd_dimensions=*/4,
2487
2482
  /*dimensions_per_value=*/1)
2488
- nk_define_cross_pack_(dots, f16, serial, f16, f16, nk_assign_from_to_, /*norm_value_type=*/f32,
2489
- nk_dots_reduce_sumsq_f16_,
2490
- /*depth_simd_dimensions=*/8, /*dimensions_per_value=*/1)
2483
+ nk_define_cross_pack_(dots, f16, serial, f16, f32, nk_b128_vec_t, nk_load_f16x4_to_f32x4_serial_,
2484
+ nk_partial_load_f16x4_to_f32x4_serial_, nk_store_b128_serial_, nk_partial_store_b32x4_serial_,
2485
+ /*simd_width=*/4, /*norm_value_type=*/f32, nk_dots_reduce_sumsq_f16_,
2486
+ /*depth_simd_dimensions=*/4, /*dimensions_per_value=*/1)
2491
2487
  nk_define_cross_symmetric_(dots, f16, serial, f16, f32, nk_b128_vec_t, nk_dot_f16x8_state_serial_t, nk_b128_vec_t,
2492
2488
  nk_dot_f16x8_init_serial, nk_load_b128_serial_, nk_partial_load_b16x8_serial_,
2493
2489
  nk_dot_f16x8_update_serial, nk_dot_f16x8_finalize_serial, nk_store_b128_serial_,
2494
2490
  nk_partial_store_b32x4_serial_,
2495
2491
  /*depth_simd_dimensions=*/8, /*dimensions_per_value=*/1)
2496
- nk_define_cross_packed_(dots, f16, serial, f16, f16, f32, nk_b128_vec_t, nk_dot_f16x8_state_serial_t, nk_b128_vec_t,
2497
- nk_dot_f16x8_init_serial, nk_load_b128_serial_, nk_partial_load_b16x8_serial_,
2498
- nk_load_b128_serial_, nk_partial_load_b16x8_serial_, nk_dot_f16x8_update_serial,
2499
- nk_dot_f16x8_finalize_serial, nk_store_b128_serial_, nk_partial_store_b32x4_serial_,
2500
- /*depth_simd_dimensions=*/8, /*dimensions_per_value=*/1)
2492
+ nk_define_cross_packed_(dots, f16, serial, f16, f32, f32, nk_b128_vec_t, nk_dot_through_f32x4_state_serial_t,
2493
+ nk_b128_vec_t, nk_dot_through_f32x4_init_serial, nk_load_f16x4_to_f32x4_serial_,
2494
+ nk_partial_load_f16x4_to_f32x4_serial_, nk_load_b128_serial_, nk_partial_load_b32x4_serial_,
2495
+ nk_dot_through_f32x4_update_serial, nk_dot_through_f32x4_finalize_serial, nk_store_b128_serial_,
2496
+ nk_partial_store_b32x4_serial_,
2497
+ /*depth_simd_dimensions=*/4, /*dimensions_per_value=*/1)
2501
2498
 
2502
2499
  /* BF16 GEMM: depth_simd_dimensions=8 (8 bf16s = 16 bytes), F32 accumulator */
2503
- nk_define_cross_pack_size_(dots, bf16, serial, bf16, f32, /*norm_value_type=*/f32, /*depth_simd_dimensions=*/8,
2500
+ nk_define_cross_pack_size_(dots, bf16, serial, bf16, bf16, /*norm_value_type=*/f32, /*depth_simd_dimensions=*/8,
2504
2501
  /*dimensions_per_value=*/1)
2505
- nk_define_cross_pack_(dots, bf16, serial, bf16, bf16, nk_assign_from_to_, /*norm_value_type=*/f32,
2506
- nk_dots_reduce_sumsq_bf16_,
2502
+ nk_define_cross_pack_(dots, bf16, serial, bf16, bf16, nk_b128_vec_t, nk_load_b128_serial_,
2503
+ nk_partial_load_b16x8_serial_, nk_store_b128_serial_, nk_partial_store_b16x8_serial_,
2504
+ /*simd_width=*/8, /*norm_value_type=*/f32, nk_dots_reduce_sumsq_bf16_,
2507
2505
  /*depth_simd_dimensions=*/8, /*dimensions_per_value=*/1)
2508
2506
  nk_define_cross_symmetric_(dots, bf16, serial, bf16, f32, nk_b128_vec_t, nk_dot_bf16x8_state_serial_t, nk_b128_vec_t,
2509
2507
  nk_dot_bf16x8_init_serial, nk_load_b128_serial_, nk_partial_load_b16x8_serial_,
@@ -2519,8 +2517,10 @@ nk_define_cross_packed_(dots, bf16, serial, bf16, bf16, f32, nk_b128_vec_t, nk_d
2519
2517
  /* I8 GEMM: depth_simd_dimensions=16 (16 i8s = 16 bytes), I32 accumulator */
2520
2518
  nk_define_cross_pack_size_(dots, i8, serial, i8, i8, /*norm_value_type=*/u32, /*depth_simd_dimensions=*/16,
2521
2519
  /*dimensions_per_value=*/1)
2522
- nk_define_cross_pack_(dots, i8, serial, i8, i8, nk_assign_from_to_, /*norm_value_type=*/u32, nk_dots_reduce_sumsq_i8_,
2523
- /*depth_simd_dimensions=*/16, /*dimensions_per_value=*/1)
2520
+ nk_define_cross_pack_(dots, i8, serial, i8, i8, nk_b128_vec_t, nk_load_b128_serial_, nk_partial_load_b8x16_serial_,
2521
+ nk_store_b128_serial_, nk_partial_store_b8x16_serial_, /*simd_width=*/16,
2522
+ /*norm_value_type=*/u32, nk_dots_reduce_sumsq_i8_, /*depth_simd_dimensions=*/16,
2523
+ /*dimensions_per_value=*/1)
2524
2524
  nk_define_cross_symmetric_(dots, i8, serial, i8, i32, nk_b128_vec_t, nk_dot_i8x16_state_serial_t, nk_b128_vec_t,
2525
2525
  nk_dot_i8x16_init_serial, nk_load_b128_serial_, nk_partial_load_b8x16_serial_,
2526
2526
  nk_dot_i8x16_update_serial, nk_dot_i8x16_finalize_serial, nk_store_b128_serial_,
@@ -2535,8 +2535,10 @@ nk_define_cross_packed_(dots, i8, serial, i8, i8, i32, nk_b128_vec_t, nk_dot_i8x
2535
2535
  /* U8 GEMM: depth_simd_dimensions=16 (16 u8s = 16 bytes), U32 accumulator */
2536
2536
  nk_define_cross_pack_size_(dots, u8, serial, u8, u8, /*norm_value_type=*/u32, /*depth_simd_dimensions=*/16,
2537
2537
  /*dimensions_per_value=*/1)
2538
- nk_define_cross_pack_(dots, u8, serial, u8, u8, nk_assign_from_to_, /*norm_value_type=*/u32, nk_dots_reduce_sumsq_u8_,
2539
- /*depth_simd_dimensions=*/16, /*dimensions_per_value=*/1)
2538
+ nk_define_cross_pack_(dots, u8, serial, u8, u8, nk_b128_vec_t, nk_load_b128_serial_, nk_partial_load_b8x16_serial_,
2539
+ nk_store_b128_serial_, nk_partial_store_b8x16_serial_, /*simd_width=*/16,
2540
+ /*norm_value_type=*/u32, nk_dots_reduce_sumsq_u8_, /*depth_simd_dimensions=*/16,
2541
+ /*dimensions_per_value=*/1)
2540
2542
  nk_define_cross_symmetric_(dots, u8, serial, u8, u32, nk_b128_vec_t, nk_dot_u8x16_state_serial_t, nk_b128_vec_t,
2541
2543
  nk_dot_u8x16_init_serial, nk_load_b128_serial_, nk_partial_load_b8x16_serial_,
2542
2544
  nk_dot_u8x16_update_serial, nk_dot_u8x16_finalize_serial, nk_store_b128_serial_,
@@ -2551,8 +2553,9 @@ nk_define_cross_packed_(dots, u8, serial, u8, u8, u32, nk_b128_vec_t, nk_dot_u8x
2551
2553
  /* E4M3 GEMM: depth_simd_dimensions=16 (16 e4m3s = 16 bytes), F32 accumulator */
2552
2554
  nk_define_cross_pack_size_(dots, e4m3, serial, e4m3, e4m3, /*norm_value_type=*/f32, /*depth_simd_dimensions=*/16,
2553
2555
  /*dimensions_per_value=*/1)
2554
- nk_define_cross_pack_(dots, e4m3, serial, e4m3, e4m3, nk_assign_from_to_, /*norm_value_type=*/f32,
2555
- nk_dots_reduce_sumsq_e4m3_,
2556
+ nk_define_cross_pack_(dots, e4m3, serial, e4m3, e4m3, nk_b128_vec_t, nk_load_b128_serial_,
2557
+ nk_partial_load_b8x16_serial_, nk_store_b128_serial_, nk_partial_store_b8x16_serial_,
2558
+ /*simd_width=*/16, /*norm_value_type=*/f32, nk_dots_reduce_sumsq_e4m3_,
2556
2559
  /*depth_simd_dimensions=*/16, /*dimensions_per_value=*/1)
2557
2560
  nk_define_cross_symmetric_(dots, e4m3, serial, e4m3, f32, nk_b128_vec_t, nk_dot_e4m3x16_state_serial_t, nk_b128_vec_t,
2558
2561
  nk_dot_e4m3x16_init_serial, nk_load_b128_serial_, nk_partial_load_b8x16_serial_,
@@ -2568,8 +2571,9 @@ nk_define_cross_packed_(dots, e4m3, serial, e4m3, e4m3, f32, nk_b128_vec_t, nk_d
2568
2571
  /* E5M2 GEMM: depth_simd_dimensions=16 (16 e5m2s = 16 bytes), F32 accumulator */
2569
2572
  nk_define_cross_pack_size_(dots, e5m2, serial, e5m2, e5m2, /*norm_value_type=*/f32, /*depth_simd_dimensions=*/16,
2570
2573
  /*dimensions_per_value=*/1)
2571
- nk_define_cross_pack_(dots, e5m2, serial, e5m2, e5m2, nk_assign_from_to_, /*norm_value_type=*/f32,
2572
- nk_dots_reduce_sumsq_e5m2_,
2574
+ nk_define_cross_pack_(dots, e5m2, serial, e5m2, e5m2, nk_b128_vec_t, nk_load_b128_serial_,
2575
+ nk_partial_load_b8x16_serial_, nk_store_b128_serial_, nk_partial_store_b8x16_serial_,
2576
+ /*simd_width=*/16, /*norm_value_type=*/f32, nk_dots_reduce_sumsq_e5m2_,
2573
2577
  /*depth_simd_dimensions=*/16, /*dimensions_per_value=*/1)
2574
2578
  nk_define_cross_symmetric_(dots, e5m2, serial, e5m2, f32, nk_b128_vec_t, nk_dot_e5m2x16_state_serial_t, nk_b128_vec_t,
2575
2579
  nk_dot_e5m2x16_init_serial, nk_load_b128_serial_, nk_partial_load_b8x16_serial_,
@@ -2585,8 +2589,9 @@ nk_define_cross_packed_(dots, e5m2, serial, e5m2, e5m2, f32, nk_b128_vec_t, nk_d
2585
2589
  /* E2M3 GEMM: depth_simd_dimensions=16 (16 e2m3s = 16 bytes), F32 accumulator */
2586
2590
  nk_define_cross_pack_size_(dots, e2m3, serial, e2m3, e2m3, /*norm_value_type=*/f32, /*depth_simd_dimensions=*/16,
2587
2591
  /*dimensions_per_value=*/1)
2588
- nk_define_cross_pack_(dots, e2m3, serial, e2m3, e2m3, nk_assign_from_to_, /*norm_value_type=*/f32,
2589
- nk_dots_reduce_sumsq_e2m3_,
2592
+ nk_define_cross_pack_(dots, e2m3, serial, e2m3, e2m3, nk_b128_vec_t, nk_load_b128_serial_,
2593
+ nk_partial_load_b8x16_serial_, nk_store_b128_serial_, nk_partial_store_b8x16_serial_,
2594
+ /*simd_width=*/16, /*norm_value_type=*/f32, nk_dots_reduce_sumsq_e2m3_,
2590
2595
  /*depth_simd_dimensions=*/16, /*dimensions_per_value=*/1)
2591
2596
  nk_define_cross_symmetric_(dots, e2m3, serial, e2m3, f32, nk_b128_vec_t, nk_dot_e2m3x16_state_serial_t, nk_b128_vec_t,
2592
2597
  nk_dot_e2m3x16_init_serial, nk_load_b128_serial_, nk_partial_load_b8x16_serial_,
@@ -2602,8 +2607,9 @@ nk_define_cross_packed_(dots, e2m3, serial, e2m3, e2m3, f32, nk_b128_vec_t, nk_d
2602
2607
  /* E3M2 GEMM: depth_simd_dimensions=16 (16 e3m2s = 16 bytes), F32 accumulator */
2603
2608
  nk_define_cross_pack_size_(dots, e3m2, serial, e3m2, e3m2, /*norm_value_type=*/f32, /*depth_simd_dimensions=*/16,
2604
2609
  /*dimensions_per_value=*/1)
2605
- nk_define_cross_pack_(dots, e3m2, serial, e3m2, e3m2, nk_assign_from_to_, /*norm_value_type=*/f32,
2606
- nk_dots_reduce_sumsq_e3m2_,
2610
+ nk_define_cross_pack_(dots, e3m2, serial, e3m2, e3m2, nk_b128_vec_t, nk_load_b128_serial_,
2611
+ nk_partial_load_b8x16_serial_, nk_store_b128_serial_, nk_partial_store_b8x16_serial_,
2612
+ /*simd_width=*/16, /*norm_value_type=*/f32, nk_dots_reduce_sumsq_e3m2_,
2607
2613
  /*depth_simd_dimensions=*/16, /*dimensions_per_value=*/1)
2608
2614
  nk_define_cross_symmetric_(dots, e3m2, serial, e3m2, f32, nk_b128_vec_t, nk_dot_e3m2x16_state_serial_t, nk_b128_vec_t,
2609
2615
  nk_dot_e3m2x16_init_serial, nk_load_b128_serial_, nk_partial_load_b8x16_serial_,
@@ -2619,9 +2625,10 @@ nk_define_cross_packed_(dots, e3m2, serial, e3m2, e3m2, f32, nk_b128_vec_t, nk_d
2619
2625
  /* U4 GEMM: u4x2 for both A and B */
2620
2626
  nk_define_cross_pack_size_(dots, u4, serial, u4x2, u4x2, /*norm_value_type=*/u32, /*depth_simd_dimensions=*/16,
2621
2627
  /*dimensions_per_value=*/2)
2622
- nk_define_cross_pack_(dots, u4, serial, u4x2, u4x2, nk_assign_from_to_, /*norm_value_type=*/u32,
2623
- nk_dots_reduce_sumsq_u4_,
2624
- /*depth_simd_dimensions=*/16, /*dimensions_per_value=*/2)
2628
+ nk_define_cross_pack_(dots, u4, serial, u4x2, u4x2, nk_b128_vec_t, nk_load_b128_serial_, nk_partial_load_b8x16_serial_,
2629
+ nk_store_b128_serial_, nk_partial_store_b8x16_serial_, /*simd_width=*/16,
2630
+ /*norm_value_type=*/u32, nk_dots_reduce_sumsq_u4_, /*depth_simd_dimensions=*/16,
2631
+ /*dimensions_per_value=*/2)
2625
2632
  nk_define_cross_symmetric_(dots, u4, serial, u4x2, u32, nk_b64_vec_t, nk_dot_u4x16_state_serial_t, nk_b128_vec_t,
2626
2633
  nk_dot_u4x16_init_serial, nk_load_b64_serial_, nk_partial_load_b4x16_serial_,
2627
2634
  nk_dot_u4x16_update_serial, nk_dot_u4x16_finalize_serial, nk_store_b128_serial_,
@@ -2636,9 +2643,10 @@ nk_define_cross_packed_(dots, u4, serial, u4x2, u4x2, u32, nk_b64_vec_t, nk_dot_
2636
2643
  /* I4 GEMM: i4x2 for both A and B */
2637
2644
  nk_define_cross_pack_size_(dots, i4, serial, i4x2, i4x2, /*norm_value_type=*/u32, /*depth_simd_dimensions=*/16,
2638
2645
  /*dimensions_per_value=*/2)
2639
- nk_define_cross_pack_(dots, i4, serial, i4x2, i4x2, nk_assign_from_to_, /*norm_value_type=*/u32,
2640
- nk_dots_reduce_sumsq_i4_,
2641
- /*depth_simd_dimensions=*/16, /*dimensions_per_value=*/2)
2646
+ nk_define_cross_pack_(dots, i4, serial, i4x2, i4x2, nk_b128_vec_t, nk_load_b128_serial_, nk_partial_load_b8x16_serial_,
2647
+ nk_store_b128_serial_, nk_partial_store_b8x16_serial_, /*simd_width=*/16,
2648
+ /*norm_value_type=*/u32, nk_dots_reduce_sumsq_i4_, /*depth_simd_dimensions=*/16,
2649
+ /*dimensions_per_value=*/2)
2642
2650
  nk_define_cross_symmetric_(dots, i4, serial, i4x2, i32, nk_b64_vec_t, nk_dot_i4x16_state_serial_t, nk_b128_vec_t,
2643
2651
  nk_dot_i4x16_init_serial, nk_load_b64_serial_, nk_partial_load_b4x16_serial_,
2644
2652
  nk_dot_i4x16_update_serial, nk_dot_i4x16_finalize_serial, nk_store_b128_serial_,
@@ -2653,8 +2661,10 @@ nk_define_cross_packed_(dots, i4, serial, i4x2, i4x2, i32, nk_b64_vec_t, nk_dot_
2653
2661
  /* U1 GEMM: u1x8 for both A and B */
2654
2662
  nk_define_cross_pack_size_(dots, u1, serial, u1x8, u1x8, /*norm_value_type=*/u32, /*depth_simd_dimensions=*/128,
2655
2663
  /*dimensions_per_value=*/8)
2656
- nk_define_cross_pack_(dots, u1, serial, u1x8, u1x8, nk_assign_from_to_, /*norm_value_type=*/u32, nk_dots_reduce_sum_u1_,
2657
- /*depth_simd_dimensions=*/128, /*dimensions_per_value=*/8)
2664
+ nk_define_cross_pack_(dots, u1, serial, u1x8, u1x8, nk_b128_vec_t, nk_load_b128_serial_, nk_partial_load_b8x16_serial_,
2665
+ nk_store_b128_serial_, nk_partial_store_b8x16_serial_, /*simd_width=*/16,
2666
+ /*norm_value_type=*/u32, nk_dots_reduce_sum_u1_, /*depth_simd_dimensions=*/128,
2667
+ /*dimensions_per_value=*/8)
2658
2668
  nk_define_cross_symmetric_(dots, u1, serial, u1x8, u32, nk_b128_vec_t, nk_dot_u1x128_state_serial_t, nk_b128_vec_t,
2659
2669
  nk_dot_u1x128_init_serial, nk_load_b128_serial_, nk_partial_load_b1x128_serial_,
2660
2670
  nk_dot_u1x128_update_serial, nk_dot_u1x128_finalize_serial, nk_store_b128_serial_,
@@ -2673,7 +2683,7 @@ nk_define_cross_packed_(dots, u1, serial, u1x8, u1x8, u32, nk_b128_vec_t, nk_dot
2673
2683
  #endif
2674
2684
 
2675
2685
  /* BF16 compact: truncate F32 → BF16 in-place.
2676
- * Reads F32 matrix with c_stride_in_bytes, writes BF16 tightly packed (stride = column_count × sizeof(bf16)).
2686
+ * Reads F32 matrix with c_stride_in_bytes, writes BF16 tightly packed (stride_in_bytes = column_count × sizeof(bf16)).
2677
2687
  */
2678
2688
  NK_PUBLIC void nk_dots_compact_bf16_serial(void *c, nk_size_t row_count, nk_size_t column_count,
2679
2689
  nk_size_t c_stride_in_bytes) {
@@ -2767,78 +2777,84 @@ NK_PUBLIC void nk_dots_compact_i8_serial(void *c, nk_size_t row_count, nk_size_t
2767
2777
  } \
2768
2778
  }
2769
2779
 
2770
- #define nk_define_cross_normalized_symmetric_(metric_name, input_type_name, isa_suffix, input_value_type, \
2771
- dot_result_type, norm_value_type, final_result_type, vec_type, \
2772
- dots_symmetric_fn, from_dot_fn, compute_norm_fn, load_fn, \
2773
- partial_load_fn, store_fn, partial_store_fn, dimensions_per_value) \
2774
- NK_PUBLIC void nk_##metric_name##s_symmetric_##input_type_name##_##isa_suffix( \
2775
- nk_##input_value_type##_t const *vectors, nk_size_t n_vectors, nk_size_t depth, nk_size_t stride, \
2776
- nk_##final_result_type##_t *result, nk_size_t result_stride, nk_size_t row_start, nk_size_t row_count) { \
2777
- \
2778
- dots_symmetric_fn(vectors, n_vectors, depth, stride, (nk_##dot_result_type##_t *)result, result_stride, \
2779
- row_start, row_count); \
2780
- \
2781
- /* Phase 1 — cache row norms in the result diagonal (O(row_count) calls) */ \
2782
- for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) { \
2783
- nk_##input_value_type##_t const *row_vector = (nk_##input_value_type##_t const *)((char const *)vectors + \
2784
- row_index * stride); \
2785
- nk_##norm_value_type##_t *row_diag = (nk_##norm_value_type##_t *)((char *)result + \
2786
- row_index * result_stride); \
2787
- row_diag[row_index] = compute_norm_fn(row_vector, depth); \
2788
- } \
2789
- \
2790
- /* Phase 2 — column-first post-processing with 256-element norm cache */ \
2791
- nk_##norm_value_type##_t column_norms[256]; \
2792
- for (nk_size_t column_chunk_start = 0; column_chunk_start < n_vectors; column_chunk_start += 256) { \
2793
- nk_size_t column_chunk_end = column_chunk_start + 256 < n_vectors ? column_chunk_start + 256 : n_vectors; \
2794
- \
2795
- /* Pre-compute norms for this column chunk — each column visited exactly once */ \
2796
- for (nk_size_t col = column_chunk_start; col < column_chunk_end; ++col) { \
2797
- nk_##input_value_type##_t const *column_vector = \
2798
- (nk_##input_value_type##_t const *)((char const *)vectors + col * stride); \
2799
- column_norms[col - column_chunk_start] = compute_norm_fn(column_vector, depth); \
2800
- } \
2801
- \
2802
- /* Sweep assigned rows against this column chunk */ \
2803
- for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) { \
2804
- nk_size_t j_start = row_index + 1 > column_chunk_start ? row_index + 1 : column_chunk_start; \
2805
- if (j_start >= column_chunk_end) continue; \
2806
- char *row_ptr = (char *)result + row_index * result_stride; \
2807
- nk_##norm_value_type##_t sumsq_i = ((nk_##norm_value_type##_t *)row_ptr)[row_index]; \
2808
- nk_##dot_result_type##_t *r_dots = (nk_##dot_result_type##_t *)row_ptr; \
2809
- nk_##final_result_type##_t *r_out = (nk_##final_result_type##_t *)row_ptr; \
2810
- \
2811
- /* 4-wide vectorized loop */ \
2812
- nk_size_t j = j_start; \
2813
- for (; j + 4 <= column_chunk_end; j += 4) { \
2814
- vec_type target_norms_vec; \
2815
- load_fn(&column_norms[j - column_chunk_start], &target_norms_vec); \
2816
- vec_type dots_vec, results_vec; \
2817
- load_fn(r_dots + j, &dots_vec); \
2818
- from_dot_fn(dots_vec, sumsq_i, target_norms_vec, &results_vec); \
2819
- store_fn(&results_vec, r_out + j); \
2820
- } \
2821
- /* Remainder */ \
2822
- if (j < column_chunk_end) { \
2823
- vec_type dots_vec = {0}, norms_vec = {0}, results_vec; \
2824
- partial_load_fn(r_dots + j, &dots_vec, column_chunk_end - j); \
2825
- partial_load_fn(&column_norms[j - column_chunk_start], &norms_vec, column_chunk_end - j); \
2826
- from_dot_fn(dots_vec, sumsq_i, norms_vec, &results_vec); \
2827
- partial_store_fn(&results_vec, r_out + j, column_chunk_end - j); \
2828
- } \
2829
- } \
2830
- } \
2831
- \
2832
- /* Phase 3 — zero diagonals */ \
2833
- for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) { \
2834
- nk_##final_result_type##_t *r_out = (nk_##final_result_type##_t *)((char *)result + \
2835
- row_index * result_stride); \
2836
- r_out[row_index] = 0; \
2837
- } \
2780
+ #define nk_define_cross_normalized_symmetric_(metric_name, input_type_name, isa_suffix, input_value_type, \
2781
+ dot_result_type, norm_value_type, final_result_type, vec_type, \
2782
+ dots_symmetric_fn, from_dot_fn, compute_norm_fn, load_fn, \
2783
+ partial_load_fn, store_fn, partial_store_fn, dimensions_per_value) \
2784
+ NK_PUBLIC void nk_##metric_name##s_symmetric_##input_type_name##_##isa_suffix( \
2785
+ nk_##input_value_type##_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_in_bytes, \
2786
+ nk_##final_result_type##_t *result, nk_size_t result_stride_in_bytes, nk_size_t row_start, \
2787
+ nk_size_t row_count) { \
2788
+ /* Generated kernel: run dots_symmetric_fn into result (viewed as dot_result_type), then rewrite entries right of the diagonal in place via from_dot_fn using per-vector norms from compute_norm_fn. dimensions_per_value is not referenced in this body. */ \
2789
+ dots_symmetric_fn(vectors, vectors_count, depth, stride_in_bytes, (nk_##dot_result_type##_t *)result, \
2790
+ result_stride_in_bytes, row_start, row_count); \
2791
+ /* NOTE(review): the row loops below run to row_start + row_count without clamping to vectors_count; confirm callers guarantee that range is in bounds. */ \
2792
+ /* Phase 1: cache each assigned row's norm on the result diagonal (O(row_count) compute_norm_fn calls). NOTE(review): the diagonal slot is indexed as norm_value_type here and as final_result_type in Phase 3; presumably these types have equal size, confirm. */ \
2793
+ for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) { \
2794
+ nk_##input_value_type##_t const *row_vector = \
2795
+ (nk_##input_value_type##_t const *)((char const *)vectors + row_index * stride_in_bytes); \
2796
+ nk_##norm_value_type##_t *row_diag = (nk_##norm_value_type##_t *)((char *)result + \
2797
+ row_index * result_stride_in_bytes); \
2798
+ row_diag[row_index] = compute_norm_fn(row_vector, depth); \
2799
+ } \
2800
+ \
2801
+ /* Phase 2 — column-first post-processing with 256-element norm cache */ \
2802
+ nk_##norm_value_type##_t column_norms[256]; \
2803
+ for (nk_size_t column_chunk_start = 0; column_chunk_start < vectors_count; column_chunk_start += 256) { \
2804
+ nk_size_t column_chunk_end = column_chunk_start + 256 < vectors_count ? column_chunk_start + 256 \
2805
+ : vectors_count; \
2806
+ \
2807
+ /* Pre-compute norms for this column chunk — each column visited exactly once */ \
2808
+ for (nk_size_t col = column_chunk_start; col < column_chunk_end; ++col) { \
2809
+ nk_##input_value_type##_t const *column_vector = \
2810
+ (nk_##input_value_type##_t const *)((char const *)vectors + col * stride_in_bytes); \
2811
+ column_norms[col - column_chunk_start] = compute_norm_fn(column_vector, depth); \
2812
+ } \
2813
+ \
2814
+ /* Sweep assigned rows against this column chunk, touching only columns j > row_index that fall inside the chunk */ \
2815
+ for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) { \
2816
+ nk_size_t j_start = row_index + 1 > column_chunk_start ? row_index + 1 : column_chunk_start; \
2817
+ if (j_start >= column_chunk_end) continue; \
2818
+ char *row_ptr = (char *)result + row_index * result_stride_in_bytes; \
2819
+ nk_##norm_value_type##_t sumsq_i = ((nk_##norm_value_type##_t *)row_ptr)[row_index]; \
2820
+ nk_##dot_result_type##_t *r_dots = (nk_##dot_result_type##_t *)row_ptr; \
2821
+ nk_##final_result_type##_t *r_out = (nk_##final_result_type##_t *)row_ptr; \
2822
+ /* sumsq_i is the value Phase 1 cached on this row's diagonal; r_dots and r_out alias the same row, so results overwrite the dot products they were derived from */ \
2823
+ /* 4-wide vectorized loop */ \
2824
+ nk_size_t j = j_start; \
2825
+ for (; j + 4 <= column_chunk_end; j += 4) { \
2826
+ vec_type target_norms_vec; \
2827
+ load_fn(&column_norms[j - column_chunk_start], &target_norms_vec); \
2828
+ vec_type dots_vec, results_vec; \
2829
+ load_fn(r_dots + j, &dots_vec); \
2830
+ from_dot_fn(dots_vec, sumsq_i, target_norms_vec, &results_vec); \
2831
+ store_fn(&results_vec, r_out + j); \
2832
+ } \
2833
+ /* Remainder: fewer than 4 columns left in the chunk; inputs start zeroed and only column_chunk_end - j elements are loaded and stored */ \
2834
+ if (j < column_chunk_end) { \
2835
+ vec_type dots_vec = {0}, norms_vec = {0}, results_vec; \
2836
+ partial_load_fn(r_dots + j, &dots_vec, column_chunk_end - j); \
2837
+ partial_load_fn(&column_norms[j - column_chunk_start], &norms_vec, column_chunk_end - j); \
2838
+ from_dot_fn(dots_vec, sumsq_i, norms_vec, &results_vec); \
2839
+ partial_store_fn(&results_vec, r_out + j, column_chunk_end - j); \
2840
+ } \
2841
+ } \
2842
+ } \
2843
+ \
2844
+ /* Phase 3: zero the diagonal of every assigned row, discarding the norms cached there in Phase 1 */ \
2845
+ for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) { \
2846
+ nk_##final_result_type##_t *r_out = (nk_##final_result_type##_t *)((char *)result + \
2847
+ row_index * result_stride_in_bytes); \
2848
+ r_out[row_index] = 0; \
2849
+ } \
2838
2850
  }
2839
2851
 
2840
2852
  #if defined(__cplusplus)
2841
2853
  } // extern "C"
2842
2854
  #endif
2843
2855
 
2856
+ #if defined(__GNUC__) && !defined(__clang__)
2857
+ #pragma GCC diagnostic pop
2858
+ #endif
2859
+
2844
2860
  #endif // NK_DOTS_SERIAL_H