numkong 7.0.0 → 7.4.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (315) hide show
  1. package/README.md +239 -122
  2. package/binding.gyp +25 -491
  3. package/c/dispatch_bf16.c +59 -1
  4. package/c/dispatch_e2m3.c +41 -8
  5. package/c/dispatch_e3m2.c +49 -8
  6. package/c/dispatch_e4m3.c +51 -9
  7. package/c/dispatch_e5m2.c +45 -1
  8. package/c/dispatch_f16.c +79 -26
  9. package/c/dispatch_f16c.c +5 -5
  10. package/c/dispatch_f32.c +56 -0
  11. package/c/dispatch_f64.c +52 -0
  12. package/c/dispatch_i4.c +3 -0
  13. package/c/dispatch_i8.c +62 -3
  14. package/c/dispatch_other.c +18 -0
  15. package/c/dispatch_u1.c +54 -9
  16. package/c/dispatch_u4.c +3 -0
  17. package/c/dispatch_u8.c +64 -3
  18. package/c/numkong.c +3 -0
  19. package/include/README.md +79 -9
  20. package/include/numkong/attention/sapphireamx.h +278 -276
  21. package/include/numkong/attention/sme.h +983 -977
  22. package/include/numkong/attention.h +1 -1
  23. package/include/numkong/capabilities.h +289 -94
  24. package/include/numkong/cast/README.md +40 -40
  25. package/include/numkong/cast/diamond.h +64 -0
  26. package/include/numkong/cast/haswell.h +42 -194
  27. package/include/numkong/cast/icelake.h +42 -37
  28. package/include/numkong/cast/loongsonasx.h +252 -0
  29. package/include/numkong/cast/neon.h +216 -249
  30. package/include/numkong/cast/powervsx.h +449 -0
  31. package/include/numkong/cast/rvv.h +223 -274
  32. package/include/numkong/cast/sapphire.h +18 -18
  33. package/include/numkong/cast/serial.h +1018 -944
  34. package/include/numkong/cast/skylake.h +82 -23
  35. package/include/numkong/cast/v128relaxed.h +462 -105
  36. package/include/numkong/cast.h +24 -0
  37. package/include/numkong/cast.hpp +44 -0
  38. package/include/numkong/curved/README.md +17 -17
  39. package/include/numkong/curved/neon.h +131 -7
  40. package/include/numkong/curved/neonbfdot.h +6 -7
  41. package/include/numkong/curved/rvv.h +26 -26
  42. package/include/numkong/curved/smef64.h +186 -182
  43. package/include/numkong/curved.h +14 -18
  44. package/include/numkong/dot/README.md +154 -137
  45. package/include/numkong/dot/alder.h +43 -43
  46. package/include/numkong/dot/diamond.h +158 -0
  47. package/include/numkong/dot/genoa.h +4 -30
  48. package/include/numkong/dot/haswell.h +215 -180
  49. package/include/numkong/dot/icelake.h +190 -76
  50. package/include/numkong/dot/loongsonasx.h +671 -0
  51. package/include/numkong/dot/neon.h +124 -73
  52. package/include/numkong/dot/neonbfdot.h +11 -12
  53. package/include/numkong/dot/neonfhm.h +44 -46
  54. package/include/numkong/dot/neonfp8.h +323 -0
  55. package/include/numkong/dot/neonsdot.h +190 -76
  56. package/include/numkong/dot/powervsx.h +752 -0
  57. package/include/numkong/dot/rvv.h +92 -84
  58. package/include/numkong/dot/rvvbf16.h +12 -12
  59. package/include/numkong/dot/rvvhalf.h +12 -12
  60. package/include/numkong/dot/sapphire.h +4 -4
  61. package/include/numkong/dot/serial.h +66 -30
  62. package/include/numkong/dot/sierra.h +31 -31
  63. package/include/numkong/dot/skylake.h +142 -110
  64. package/include/numkong/dot/sve.h +217 -177
  65. package/include/numkong/dot/svebfdot.h +10 -10
  66. package/include/numkong/dot/svehalf.h +85 -41
  67. package/include/numkong/dot/svesdot.h +89 -0
  68. package/include/numkong/dot/v128relaxed.h +124 -89
  69. package/include/numkong/dot.h +114 -48
  70. package/include/numkong/dots/README.md +203 -203
  71. package/include/numkong/dots/alder.h +12 -9
  72. package/include/numkong/dots/diamond.h +86 -0
  73. package/include/numkong/dots/genoa.h +10 -4
  74. package/include/numkong/dots/haswell.h +63 -48
  75. package/include/numkong/dots/icelake.h +27 -18
  76. package/include/numkong/dots/loongsonasx.h +176 -0
  77. package/include/numkong/dots/neon.h +14 -11
  78. package/include/numkong/dots/neonbfdot.h +4 -3
  79. package/include/numkong/dots/neonfhm.h +11 -9
  80. package/include/numkong/dots/neonfp8.h +99 -0
  81. package/include/numkong/dots/neonsdot.h +48 -12
  82. package/include/numkong/dots/powervsx.h +194 -0
  83. package/include/numkong/dots/rvv.h +451 -344
  84. package/include/numkong/dots/sapphireamx.h +1028 -984
  85. package/include/numkong/dots/serial.h +213 -197
  86. package/include/numkong/dots/sierra.h +10 -7
  87. package/include/numkong/dots/skylake.h +47 -36
  88. package/include/numkong/dots/sme.h +2001 -2364
  89. package/include/numkong/dots/smebi32.h +175 -162
  90. package/include/numkong/dots/smef64.h +328 -323
  91. package/include/numkong/dots/v128relaxed.h +64 -41
  92. package/include/numkong/dots.h +573 -293
  93. package/include/numkong/dots.hpp +45 -43
  94. package/include/numkong/each/README.md +133 -137
  95. package/include/numkong/each/haswell.h +6 -6
  96. package/include/numkong/each/icelake.h +7 -7
  97. package/include/numkong/each/neon.h +76 -42
  98. package/include/numkong/each/neonbfdot.h +11 -12
  99. package/include/numkong/each/neonhalf.h +24 -116
  100. package/include/numkong/each/rvv.h +28 -28
  101. package/include/numkong/each/sapphire.h +27 -161
  102. package/include/numkong/each/serial.h +6 -6
  103. package/include/numkong/each/skylake.h +7 -7
  104. package/include/numkong/each/v128relaxed.h +562 -0
  105. package/include/numkong/each.h +148 -62
  106. package/include/numkong/each.hpp +2 -2
  107. package/include/numkong/geospatial/README.md +18 -18
  108. package/include/numkong/geospatial/haswell.h +365 -325
  109. package/include/numkong/geospatial/neon.h +350 -306
  110. package/include/numkong/geospatial/rvv.h +4 -4
  111. package/include/numkong/geospatial/skylake.h +376 -340
  112. package/include/numkong/geospatial/v128relaxed.h +366 -327
  113. package/include/numkong/geospatial.h +17 -17
  114. package/include/numkong/matrix.hpp +4 -4
  115. package/include/numkong/maxsim/README.md +14 -14
  116. package/include/numkong/maxsim/alder.h +6 -6
  117. package/include/numkong/maxsim/genoa.h +4 -4
  118. package/include/numkong/maxsim/haswell.h +6 -6
  119. package/include/numkong/maxsim/icelake.h +18 -18
  120. package/include/numkong/maxsim/neonsdot.h +21 -21
  121. package/include/numkong/maxsim/sapphireamx.h +14 -14
  122. package/include/numkong/maxsim/serial.h +6 -6
  123. package/include/numkong/maxsim/sme.h +221 -196
  124. package/include/numkong/maxsim/v128relaxed.h +6 -6
  125. package/include/numkong/mesh/README.md +62 -56
  126. package/include/numkong/mesh/haswell.h +339 -464
  127. package/include/numkong/mesh/neon.h +1100 -519
  128. package/include/numkong/mesh/neonbfdot.h +36 -68
  129. package/include/numkong/mesh/rvv.h +530 -435
  130. package/include/numkong/mesh/serial.h +75 -91
  131. package/include/numkong/mesh/skylake.h +1627 -302
  132. package/include/numkong/mesh/v128relaxed.h +443 -330
  133. package/include/numkong/mesh.h +63 -49
  134. package/include/numkong/mesh.hpp +4 -4
  135. package/include/numkong/numkong.h +3 -3
  136. package/include/numkong/numkong.hpp +1 -0
  137. package/include/numkong/probability/README.md +23 -19
  138. package/include/numkong/probability/neon.h +82 -52
  139. package/include/numkong/probability/rvv.h +28 -23
  140. package/include/numkong/probability/serial.h +51 -39
  141. package/include/numkong/probability.h +20 -23
  142. package/include/numkong/random.h +1 -1
  143. package/include/numkong/reduce/README.md +143 -138
  144. package/include/numkong/reduce/alder.h +81 -77
  145. package/include/numkong/reduce/haswell.h +222 -220
  146. package/include/numkong/reduce/neon.h +629 -519
  147. package/include/numkong/reduce/neonbfdot.h +7 -218
  148. package/include/numkong/reduce/neonfhm.h +9 -381
  149. package/include/numkong/reduce/neonsdot.h +9 -9
  150. package/include/numkong/reduce/rvv.h +928 -802
  151. package/include/numkong/reduce/serial.h +23 -27
  152. package/include/numkong/reduce/sierra.h +20 -20
  153. package/include/numkong/reduce/skylake.h +326 -324
  154. package/include/numkong/reduce/v128relaxed.h +52 -52
  155. package/include/numkong/reduce.h +4 -23
  156. package/include/numkong/reduce.hpp +156 -11
  157. package/include/numkong/scalar/README.md +6 -6
  158. package/include/numkong/scalar/haswell.h +26 -17
  159. package/include/numkong/scalar/loongsonasx.h +74 -0
  160. package/include/numkong/scalar/neon.h +9 -9
  161. package/include/numkong/scalar/powervsx.h +96 -0
  162. package/include/numkong/scalar/rvv.h +2 -2
  163. package/include/numkong/scalar/sapphire.h +21 -10
  164. package/include/numkong/scalar/serial.h +21 -21
  165. package/include/numkong/scalar.h +13 -0
  166. package/include/numkong/set/README.md +28 -28
  167. package/include/numkong/set/haswell.h +12 -12
  168. package/include/numkong/set/icelake.h +14 -14
  169. package/include/numkong/set/loongsonasx.h +181 -0
  170. package/include/numkong/set/neon.h +17 -18
  171. package/include/numkong/set/powervsx.h +326 -0
  172. package/include/numkong/set/rvv.h +4 -4
  173. package/include/numkong/set/serial.h +6 -6
  174. package/include/numkong/set/sve.h +60 -59
  175. package/include/numkong/set/v128relaxed.h +6 -6
  176. package/include/numkong/set.h +21 -7
  177. package/include/numkong/sets/README.md +26 -26
  178. package/include/numkong/sets/loongsonasx.h +52 -0
  179. package/include/numkong/sets/powervsx.h +65 -0
  180. package/include/numkong/sets/smebi32.h +395 -364
  181. package/include/numkong/sets.h +83 -40
  182. package/include/numkong/sparse/README.md +4 -4
  183. package/include/numkong/sparse/icelake.h +101 -101
  184. package/include/numkong/sparse/serial.h +1 -1
  185. package/include/numkong/sparse/sve2.h +137 -141
  186. package/include/numkong/sparse/turin.h +12 -12
  187. package/include/numkong/sparse.h +10 -10
  188. package/include/numkong/spatial/README.md +230 -226
  189. package/include/numkong/spatial/alder.h +113 -116
  190. package/include/numkong/spatial/diamond.h +240 -0
  191. package/include/numkong/spatial/genoa.h +0 -68
  192. package/include/numkong/spatial/haswell.h +74 -55
  193. package/include/numkong/spatial/icelake.h +539 -58
  194. package/include/numkong/spatial/loongsonasx.h +483 -0
  195. package/include/numkong/spatial/neon.h +125 -52
  196. package/include/numkong/spatial/neonbfdot.h +8 -9
  197. package/include/numkong/spatial/neonfp8.h +258 -0
  198. package/include/numkong/spatial/neonsdot.h +180 -12
  199. package/include/numkong/spatial/powervsx.h +738 -0
  200. package/include/numkong/spatial/rvv.h +146 -139
  201. package/include/numkong/spatial/rvvbf16.h +17 -12
  202. package/include/numkong/spatial/rvvhalf.h +13 -10
  203. package/include/numkong/spatial/serial.h +13 -12
  204. package/include/numkong/spatial/sierra.h +232 -39
  205. package/include/numkong/spatial/skylake.h +73 -74
  206. package/include/numkong/spatial/sve.h +93 -72
  207. package/include/numkong/spatial/svebfdot.h +29 -29
  208. package/include/numkong/spatial/svehalf.h +52 -26
  209. package/include/numkong/spatial/svesdot.h +142 -0
  210. package/include/numkong/spatial/v128relaxed.h +293 -41
  211. package/include/numkong/spatial.h +338 -82
  212. package/include/numkong/spatials/README.md +194 -194
  213. package/include/numkong/spatials/diamond.h +82 -0
  214. package/include/numkong/spatials/haswell.h +2 -2
  215. package/include/numkong/spatials/loongsonasx.h +153 -0
  216. package/include/numkong/spatials/neonfp8.h +111 -0
  217. package/include/numkong/spatials/neonsdot.h +34 -0
  218. package/include/numkong/spatials/powervsx.h +153 -0
  219. package/include/numkong/spatials/rvv.h +259 -243
  220. package/include/numkong/spatials/sapphireamx.h +173 -173
  221. package/include/numkong/spatials/serial.h +2 -2
  222. package/include/numkong/spatials/skylake.h +2 -2
  223. package/include/numkong/spatials/sme.h +590 -605
  224. package/include/numkong/spatials/smef64.h +139 -130
  225. package/include/numkong/spatials/v128relaxed.h +2 -2
  226. package/include/numkong/spatials.h +820 -500
  227. package/include/numkong/spatials.hpp +49 -48
  228. package/include/numkong/tensor.hpp +406 -17
  229. package/include/numkong/trigonometry/README.md +19 -19
  230. package/include/numkong/trigonometry/haswell.h +402 -401
  231. package/include/numkong/trigonometry/neon.h +386 -387
  232. package/include/numkong/trigonometry/rvv.h +52 -51
  233. package/include/numkong/trigonometry/serial.h +13 -13
  234. package/include/numkong/trigonometry/skylake.h +373 -369
  235. package/include/numkong/trigonometry/v128relaxed.h +375 -374
  236. package/include/numkong/trigonometry.h +13 -13
  237. package/include/numkong/trigonometry.hpp +2 -2
  238. package/include/numkong/types.h +287 -49
  239. package/include/numkong/types.hpp +436 -12
  240. package/include/numkong/vector.hpp +82 -14
  241. package/javascript/dist/cjs/numkong-wasm.js +6 -12
  242. package/javascript/dist/cjs/numkong.d.ts +7 -1
  243. package/javascript/dist/cjs/numkong.js +37 -11
  244. package/javascript/dist/cjs/types.d.ts +9 -0
  245. package/javascript/dist/cjs/types.js +96 -0
  246. package/javascript/dist/esm/numkong-browser.d.ts +14 -0
  247. package/javascript/dist/esm/numkong-browser.js +23 -0
  248. package/javascript/dist/esm/numkong-wasm.js +6 -12
  249. package/javascript/dist/esm/numkong.d.ts +7 -1
  250. package/javascript/dist/esm/numkong.js +37 -11
  251. package/javascript/dist/esm/types.d.ts +9 -0
  252. package/javascript/dist/esm/types.js +96 -0
  253. package/javascript/node-gyp-build.d.ts +4 -1
  254. package/javascript/numkong-browser.ts +40 -0
  255. package/javascript/numkong-wasm.ts +7 -13
  256. package/javascript/numkong.c +5 -26
  257. package/javascript/numkong.ts +36 -11
  258. package/javascript/tsconfig-base.json +1 -0
  259. package/javascript/tsconfig-cjs.json +6 -1
  260. package/javascript/types.ts +110 -0
  261. package/numkong.gypi +101 -0
  262. package/package.json +34 -13
  263. package/probes/arm_neon.c +8 -0
  264. package/probes/arm_neon_bfdot.c +9 -0
  265. package/probes/arm_neon_fhm.c +9 -0
  266. package/probes/arm_neon_half.c +8 -0
  267. package/probes/arm_neon_sdot.c +9 -0
  268. package/probes/arm_neonfp8.c +9 -0
  269. package/probes/arm_sme.c +16 -0
  270. package/probes/arm_sme2.c +16 -0
  271. package/probes/arm_sme2p1.c +16 -0
  272. package/probes/arm_sme_bf16.c +16 -0
  273. package/probes/arm_sme_bi32.c +16 -0
  274. package/probes/arm_sme_f64.c +16 -0
  275. package/probes/arm_sme_fa64.c +14 -0
  276. package/probes/arm_sme_half.c +16 -0
  277. package/probes/arm_sme_lut2.c +15 -0
  278. package/probes/arm_sve.c +18 -0
  279. package/probes/arm_sve2.c +20 -0
  280. package/probes/arm_sve2p1.c +18 -0
  281. package/probes/arm_sve_bfdot.c +20 -0
  282. package/probes/arm_sve_half.c +18 -0
  283. package/probes/arm_sve_sdot.c +21 -0
  284. package/probes/loongarch_lasx.c +12 -0
  285. package/probes/power_vsx.c +12 -0
  286. package/probes/probe.js +127 -0
  287. package/probes/riscv_rvv.c +14 -0
  288. package/probes/riscv_rvv_bb.c +15 -0
  289. package/probes/riscv_rvv_bf16.c +17 -0
  290. package/probes/riscv_rvv_half.c +14 -0
  291. package/probes/wasm_v128relaxed.c +11 -0
  292. package/probes/x86_alder.c +17 -0
  293. package/probes/x86_diamond.c +17 -0
  294. package/probes/x86_genoa.c +17 -0
  295. package/probes/x86_graniteamx.c +19 -0
  296. package/probes/x86_haswell.c +11 -0
  297. package/probes/x86_icelake.c +17 -0
  298. package/probes/x86_sapphire.c +16 -0
  299. package/probes/x86_sapphireamx.c +18 -0
  300. package/probes/x86_sierra.c +17 -0
  301. package/probes/x86_skylake.c +15 -0
  302. package/probes/x86_turin.c +17 -0
  303. package/wasm/numkong-emscripten.js +2 -0
  304. package/wasm/numkong.d.ts +14 -0
  305. package/wasm/numkong.js +1124 -0
  306. package/wasm/numkong.wasm +0 -0
  307. package/include/numkong/curved/neonhalf.h +0 -212
  308. package/include/numkong/dot/neonhalf.h +0 -198
  309. package/include/numkong/dots/neonhalf.h +0 -57
  310. package/include/numkong/mesh/neonhalf.h +0 -616
  311. package/include/numkong/reduce/neonhalf.h +0 -157
  312. package/include/numkong/spatial/neonhalf.h +0 -118
  313. package/include/numkong/spatial/sapphire.h +0 -343
  314. package/include/numkong/spatials/neonhalf.h +0 -58
  315. package/javascript/README.md +0 -246
@@ -77,24 +77,24 @@ static nk_u16_t const nk_e3m2_magnitude_lut_rvv_[32] = {0, 1, 2, 3, 4,
77
77
  14, 16, 20, 24, 28, 32, 40, 48, 56, 64, 80,
78
78
  96, 112, 128, 160, 192, 224, 256, 320, 384, 448};
79
79
 
80
- #pragma region Single Precision Floats
80
+ #pragma region F32 Floats
81
81
 
82
82
  NK_PUBLIC nk_size_t nk_dots_packed_size_f32_rvv(nk_size_t column_count, nk_size_t depth) {
83
- nk_size_t vector_length = __riscv_vsetvlmax_e32m2();
84
- nk_size_t depth_padded = nk_size_round_up_to_multiple_(depth, vector_length);
83
+ nk_size_t max_vector_length = __riscv_vsetvlmax_e32m2();
84
+ nk_size_t depth_padded = nk_size_round_up_to_multiple_(depth, max_vector_length);
85
85
  // Break power-of-2 strides for cache associativity
86
86
  nk_size_t stride_bytes = depth_padded * sizeof(nk_f32_t);
87
- if (stride_bytes > 0 && (stride_bytes & (stride_bytes - 1)) == 0) depth_padded += vector_length;
87
+ if (stride_bytes > 0 && (stride_bytes & (stride_bytes - 1)) == 0) depth_padded += max_vector_length;
88
88
  return sizeof(nk_cross_packed_buffer_header_t) + column_count * depth_padded * sizeof(nk_f32_t) +
89
89
  column_count * sizeof(nk_f64_t); // per-column norms
90
90
  }
91
91
 
92
92
  NK_PUBLIC void nk_dots_pack_f32_rvv(nk_f32_t const *b, nk_size_t column_count, nk_size_t depth,
93
93
  nk_size_t b_stride_in_bytes, void *b_packed) {
94
- nk_size_t vector_length = __riscv_vsetvlmax_e32m2();
95
- nk_size_t depth_padded = nk_size_round_up_to_multiple_(depth, vector_length);
94
+ nk_size_t max_vector_length = __riscv_vsetvlmax_e32m2();
95
+ nk_size_t depth_padded = nk_size_round_up_to_multiple_(depth, max_vector_length);
96
96
  nk_size_t stride_bytes = depth_padded * sizeof(nk_f32_t);
97
- if (stride_bytes > 0 && (stride_bytes & (stride_bytes - 1)) == 0) depth_padded += vector_length;
97
+ if (stride_bytes > 0 && (stride_bytes & (stride_bytes - 1)) == 0) depth_padded += max_vector_length;
98
98
 
99
99
  nk_cross_packed_buffer_header_t *header = (nk_cross_packed_buffer_header_t *)b_packed;
100
100
  header->column_count = (nk_u32_t)column_count;
@@ -103,12 +103,24 @@ NK_PUBLIC void nk_dots_pack_f32_rvv(nk_f32_t const *b, nk_size_t column_count, n
103
103
 
104
104
  nk_f32_t *packed = (nk_f32_t *)((char *)b_packed + sizeof(nk_cross_packed_buffer_header_t));
105
105
  nk_size_t total = column_count * depth_padded;
106
- for (nk_size_t i = 0; i < total; ++i) packed[i] = 0;
106
+ {
107
+ nk_u8_t *zero_ptr = (nk_u8_t *)packed;
108
+ nk_size_t total_bytes = total * sizeof(nk_f32_t);
109
+ for (nk_size_t i = 0; i < total_bytes;) {
110
+ nk_size_t vector_length = __riscv_vsetvl_e8m8(total_bytes - i);
111
+ __riscv_vse8_v_u8m8(zero_ptr + i, __riscv_vmv_v_x_u8m8(0, vector_length), vector_length);
112
+ i += vector_length;
113
+ }
114
+ }
107
115
 
108
116
  for (nk_size_t column = 0; column < column_count; ++column) {
109
117
  nk_f32_t const *src = (nk_f32_t const *)((char const *)b + column * b_stride_in_bytes);
110
118
  nk_f32_t *dst = packed + column * depth_padded;
111
- for (nk_size_t k = 0; k < depth; ++k) dst[k] = src[k];
119
+ for (nk_size_t k = 0; k < depth;) {
120
+ nk_size_t vector_length = __riscv_vsetvl_e32m8(depth - k);
121
+ __riscv_vse32_v_f32m8(dst + k, __riscv_vle32_v_f32m8(src + k, vector_length), vector_length);
122
+ k += vector_length;
123
+ }
112
124
  }
113
125
 
114
126
  // Append per-column norms after packed data
@@ -158,11 +170,11 @@ NK_INTERNAL void nk_dots_packed_f32_rvv_aligned_(nk_f32_t const *a_matrix, void
158
170
 
159
171
  for (nk_size_t column = 0; column < column_count; ++column) {
160
172
  nk_f32_t const *b_column = packed_data + column * depth_padded;
161
- nk_size_t vlmax = __riscv_vsetvlmax_e32m2();
162
- vfloat64m4_t accumulator_0_f64m4 = __riscv_vfmv_v_f_f64m4(0.0, vlmax);
163
- vfloat64m4_t accumulator_1_f64m4 = __riscv_vfmv_v_f_f64m4(0.0, vlmax);
164
- vfloat64m4_t accumulator_2_f64m4 = __riscv_vfmv_v_f_f64m4(0.0, vlmax);
165
- vfloat64m4_t accumulator_3_f64m4 = __riscv_vfmv_v_f_f64m4(0.0, vlmax);
173
+ nk_size_t max_vector_length = __riscv_vsetvlmax_e32m2();
174
+ vfloat64m4_t accumulator_0_f64m4 = __riscv_vfmv_v_f_f64m4(0.0, max_vector_length);
175
+ vfloat64m4_t accumulator_1_f64m4 = __riscv_vfmv_v_f_f64m4(0.0, max_vector_length);
176
+ vfloat64m4_t accumulator_2_f64m4 = __riscv_vfmv_v_f_f64m4(0.0, max_vector_length);
177
+ vfloat64m4_t accumulator_3_f64m4 = __riscv_vfmv_v_f_f64m4(0.0, max_vector_length);
166
178
 
167
179
  nk_size_t remaining = depth;
168
180
  nk_size_t k = 0;
@@ -186,13 +198,13 @@ NK_INTERNAL void nk_dots_packed_f32_rvv_aligned_(nk_f32_t const *a_matrix, void
186
198
  // Horizontal reduce directly to f64
187
199
  vfloat64m1_t zero_f64m1 = __riscv_vfmv_v_f_f64m1(0.0, 1);
188
200
  c_row_0[column] = __riscv_vfmv_f_s_f64m1_f64(
189
- __riscv_vfredusum_vs_f64m4_f64m1(accumulator_0_f64m4, zero_f64m1, vlmax));
201
+ __riscv_vfredusum_vs_f64m4_f64m1(accumulator_0_f64m4, zero_f64m1, max_vector_length));
190
202
  c_row_1[column] = __riscv_vfmv_f_s_f64m1_f64(
191
- __riscv_vfredusum_vs_f64m4_f64m1(accumulator_1_f64m4, zero_f64m1, vlmax));
203
+ __riscv_vfredusum_vs_f64m4_f64m1(accumulator_1_f64m4, zero_f64m1, max_vector_length));
192
204
  c_row_2[column] = __riscv_vfmv_f_s_f64m1_f64(
193
- __riscv_vfredusum_vs_f64m4_f64m1(accumulator_2_f64m4, zero_f64m1, vlmax));
205
+ __riscv_vfredusum_vs_f64m4_f64m1(accumulator_2_f64m4, zero_f64m1, max_vector_length));
194
206
  c_row_3[column] = __riscv_vfmv_f_s_f64m1_f64(
195
- __riscv_vfredusum_vs_f64m4_f64m1(accumulator_3_f64m4, zero_f64m1, vlmax));
207
+ __riscv_vfredusum_vs_f64m4_f64m1(accumulator_3_f64m4, zero_f64m1, max_vector_length));
196
208
  }
197
209
  }
198
210
  // Remainder rows (mr < 4)
@@ -201,8 +213,8 @@ NK_INTERNAL void nk_dots_packed_f32_rvv_aligned_(nk_f32_t const *a_matrix, void
201
213
  nk_f64_t *c_row = (nk_f64_t *)((char *)c_matrix + row * c_stride_in_bytes);
202
214
  for (nk_size_t column = 0; column < column_count; ++column) {
203
215
  nk_f32_t const *b_column = packed_data + column * depth_padded;
204
- nk_size_t vlmax = __riscv_vsetvlmax_e32m2();
205
- vfloat64m4_t accumulator_f64m4 = __riscv_vfmv_v_f_f64m4(0.0, vlmax);
216
+ nk_size_t max_vector_length = __riscv_vsetvlmax_e32m2();
217
+ vfloat64m4_t accumulator_f64m4 = __riscv_vfmv_v_f_f64m4(0.0, max_vector_length);
206
218
  nk_size_t remaining = depth;
207
219
  nk_size_t k = 0;
208
220
  for (nk_size_t vector_length = 0; remaining > 0; remaining -= vector_length, k += vector_length) {
@@ -214,7 +226,7 @@ NK_INTERNAL void nk_dots_packed_f32_rvv_aligned_(nk_f32_t const *a_matrix, void
214
226
  }
215
227
  vfloat64m1_t zero_f64m1 = __riscv_vfmv_v_f_f64m1(0.0, 1);
216
228
  c_row[column] = __riscv_vfmv_f_s_f64m1_f64(
217
- __riscv_vfredusum_vs_f64m4_f64m1(accumulator_f64m4, zero_f64m1, vlmax));
229
+ __riscv_vfredusum_vs_f64m4_f64m1(accumulator_f64m4, zero_f64m1, max_vector_length));
218
230
  }
219
231
  }
220
232
  }
@@ -225,9 +237,10 @@ NK_INTERNAL void nk_dots_packed_f32_rvv_aligned_(nk_f32_t const *a_matrix, void
225
237
  * Dispatches to the aligned kernel for all cases — RVV's `vsetvl` handles partial
226
238
  * vectors naturally, so no separate edge kernel is needed.
227
239
  */
228
- NK_PUBLIC void nk_dots_packed_f32_rvv(nk_f32_t const *a, void const *b_packed, nk_f64_t *c, nk_size_t m, nk_size_t n,
229
- nk_size_t k, nk_size_t a_stride, nk_size_t c_stride) {
230
- nk_dots_packed_f32_rvv_aligned_(a, b_packed, c, m, n, k, a_stride, c_stride);
240
+ NK_PUBLIC void nk_dots_packed_f32_rvv(nk_f32_t const *a, void const *b_packed, nk_f64_t *c, nk_size_t rows,
241
+ nk_size_t columns, nk_size_t depth, nk_size_t a_stride_in_bytes,
242
+ nk_size_t c_stride_in_bytes) {
243
+ nk_dots_packed_f32_rvv_aligned_(a, b_packed, c, rows, columns, depth, a_stride_in_bytes, c_stride_in_bytes);
231
244
  }
232
245
 
233
246
  /**
@@ -236,19 +249,19 @@ NK_PUBLIC void nk_dots_packed_f32_rvv(nk_f32_t const *a, void const *b_packed, n
236
249
  * Uses f64 widened accumulation via `vfwmacc_vv_f64m4` for precision.
237
250
  * Processes only the rows in [row_start, row_start + row_count) for parallelism.
238
251
  */
239
- NK_PUBLIC void nk_dots_symmetric_f32_rvv(nk_f32_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
240
- nk_size_t stride, nk_f64_t *result, nk_size_t result_stride,
252
+ NK_PUBLIC void nk_dots_symmetric_f32_rvv(nk_f32_t const *vectors, nk_size_t vectors_count, nk_size_t depth,
253
+ nk_size_t stride_in_bytes, nk_f64_t *result, nk_size_t result_stride_in_bytes,
241
254
  nk_size_t row_start, nk_size_t row_count) {
242
- nk_size_t const stride_elements = stride / sizeof(nk_f32_t);
243
- nk_size_t const result_stride_elements = result_stride / sizeof(nk_f64_t);
244
- nk_size_t const row_end = (row_start + row_count < n_vectors) ? (row_start + row_count) : n_vectors;
255
+ nk_size_t const stride_elements = stride_in_bytes / sizeof(nk_f32_t);
256
+ nk_size_t const result_stride_elements = result_stride_in_bytes / sizeof(nk_f64_t);
257
+ nk_size_t const row_end = (row_start + row_count < vectors_count) ? (row_start + row_count) : vectors_count;
245
258
 
246
259
  for (nk_size_t i = row_start; i < row_end; ++i) {
247
260
  nk_f32_t const *a_i = vectors + i * stride_elements;
248
- for (nk_size_t j = i; j < n_vectors; ++j) {
261
+ for (nk_size_t j = i; j < vectors_count; ++j) {
249
262
  nk_f32_t const *a_j = vectors + j * stride_elements;
250
- nk_size_t vlmax = __riscv_vsetvlmax_e32m2();
251
- vfloat64m4_t accumulator_f64m4 = __riscv_vfmv_v_f_f64m4(0.0, vlmax);
263
+ nk_size_t max_vector_length = __riscv_vsetvlmax_e32m2();
264
+ vfloat64m4_t accumulator_f64m4 = __riscv_vfmv_v_f_f64m4(0.0, max_vector_length);
252
265
  nk_size_t remaining = depth;
253
266
  nk_size_t k = 0;
254
267
  for (nk_size_t vector_length = 0; remaining > 0; remaining -= vector_length, k += vector_length) {
@@ -260,31 +273,31 @@ NK_PUBLIC void nk_dots_symmetric_f32_rvv(nk_f32_t const *vectors, nk_size_t n_ve
260
273
  }
261
274
  vfloat64m1_t zero_f64m1 = __riscv_vfmv_v_f_f64m1(0.0, 1);
262
275
  nk_f64_t dot = __riscv_vfmv_f_s_f64m1_f64(
263
- __riscv_vfredusum_vs_f64m4_f64m1(accumulator_f64m4, zero_f64m1, vlmax));
276
+ __riscv_vfredusum_vs_f64m4_f64m1(accumulator_f64m4, zero_f64m1, max_vector_length));
264
277
  result[i * result_stride_elements + j] = dot;
265
278
  }
266
279
  }
267
280
  }
268
281
 
269
- #pragma endregion // Single Precision Floats
282
+ #pragma endregion F32 Floats
270
283
 
271
- #pragma region Double Precision Floats
284
+ #pragma region F64 Floats
272
285
 
273
286
  NK_PUBLIC nk_size_t nk_dots_packed_size_f64_rvv(nk_size_t column_count, nk_size_t depth) {
274
- nk_size_t vector_length = __riscv_vsetvlmax_e64m4();
275
- nk_size_t depth_padded = nk_size_round_up_to_multiple_(depth, vector_length);
287
+ nk_size_t max_vector_length = __riscv_vsetvlmax_e64m4();
288
+ nk_size_t depth_padded = nk_size_round_up_to_multiple_(depth, max_vector_length);
276
289
  nk_size_t stride_bytes = depth_padded * sizeof(nk_f64_t);
277
- if (stride_bytes > 0 && (stride_bytes & (stride_bytes - 1)) == 0) depth_padded += vector_length;
290
+ if (stride_bytes > 0 && (stride_bytes & (stride_bytes - 1)) == 0) depth_padded += max_vector_length;
278
291
  return sizeof(nk_cross_packed_buffer_header_t) + column_count * depth_padded * sizeof(nk_f64_t) +
279
292
  column_count * sizeof(nk_f64_t); // per-column norms
280
293
  }
281
294
 
282
295
  NK_PUBLIC void nk_dots_pack_f64_rvv(nk_f64_t const *b, nk_size_t column_count, nk_size_t depth,
283
296
  nk_size_t b_stride_in_bytes, void *b_packed) {
284
- nk_size_t vector_length = __riscv_vsetvlmax_e64m4();
285
- nk_size_t depth_padded = nk_size_round_up_to_multiple_(depth, vector_length);
297
+ nk_size_t max_vector_length = __riscv_vsetvlmax_e64m4();
298
+ nk_size_t depth_padded = nk_size_round_up_to_multiple_(depth, max_vector_length);
286
299
  nk_size_t stride_bytes = depth_padded * sizeof(nk_f64_t);
287
- if (stride_bytes > 0 && (stride_bytes & (stride_bytes - 1)) == 0) depth_padded += vector_length;
300
+ if (stride_bytes > 0 && (stride_bytes & (stride_bytes - 1)) == 0) depth_padded += max_vector_length;
288
301
 
289
302
  nk_cross_packed_buffer_header_t *header = (nk_cross_packed_buffer_header_t *)b_packed;
290
303
  header->column_count = (nk_u32_t)column_count;
@@ -293,12 +306,24 @@ NK_PUBLIC void nk_dots_pack_f64_rvv(nk_f64_t const *b, nk_size_t column_count, n
293
306
 
294
307
  nk_f64_t *packed = (nk_f64_t *)((char *)b_packed + sizeof(nk_cross_packed_buffer_header_t));
295
308
  nk_size_t total = column_count * depth_padded;
296
- for (nk_size_t i = 0; i < total; ++i) packed[i] = 0;
309
+ {
310
+ nk_u8_t *zero_ptr = (nk_u8_t *)packed;
311
+ nk_size_t total_bytes = total * sizeof(nk_f64_t);
312
+ for (nk_size_t i = 0; i < total_bytes;) {
313
+ nk_size_t vector_length = __riscv_vsetvl_e8m8(total_bytes - i);
314
+ __riscv_vse8_v_u8m8(zero_ptr + i, __riscv_vmv_v_x_u8m8(0, vector_length), vector_length);
315
+ i += vector_length;
316
+ }
317
+ }
297
318
 
298
319
  for (nk_size_t column = 0; column < column_count; ++column) {
299
320
  nk_f64_t const *src = (nk_f64_t const *)((char const *)b + column * b_stride_in_bytes);
300
321
  nk_f64_t *dst = packed + column * depth_padded;
301
- for (nk_size_t k = 0; k < depth; ++k) dst[k] = src[k];
322
+ for (nk_size_t k = 0; k < depth;) {
323
+ nk_size_t vector_length = __riscv_vsetvl_e64m8(depth - k);
324
+ __riscv_vse64_v_f64m8(dst + k, __riscv_vle64_v_f64m8(src + k, vector_length), vector_length);
325
+ k += vector_length;
326
+ }
302
327
  }
303
328
 
304
329
  // Append per-column norms after packed data
@@ -341,11 +366,11 @@ NK_INTERNAL void nk_dots_packed_f64_rvv_aligned_(nk_f64_t const *a_matrix, void
341
366
 
342
367
  for (nk_size_t column = 0; column < column_count; ++column) {
343
368
  nk_f64_t const *b_column = packed_data + column * depth_padded;
344
- nk_size_t vlmax = __riscv_vsetvlmax_e64m4();
345
- vfloat64m4_t accumulator_0_f64m4 = __riscv_vfmv_v_f_f64m4(0.0, vlmax);
346
- vfloat64m4_t accumulator_1_f64m4 = __riscv_vfmv_v_f_f64m4(0.0, vlmax);
347
- vfloat64m4_t compensation_0_f64m4 = __riscv_vfmv_v_f_f64m4(0.0, vlmax);
348
- vfloat64m4_t compensation_1_f64m4 = __riscv_vfmv_v_f_f64m4(0.0, vlmax);
369
+ nk_size_t max_vector_length = __riscv_vsetvlmax_e64m4();
370
+ vfloat64m4_t accumulator_0_f64m4 = __riscv_vfmv_v_f_f64m4(0.0, max_vector_length);
371
+ vfloat64m4_t accumulator_1_f64m4 = __riscv_vfmv_v_f_f64m4(0.0, max_vector_length);
372
+ vfloat64m4_t compensation_0_f64m4 = __riscv_vfmv_v_f_f64m4(0.0, max_vector_length);
373
+ vfloat64m4_t compensation_1_f64m4 = __riscv_vfmv_v_f_f64m4(0.0, max_vector_length);
349
374
 
350
375
  nk_size_t remaining = depth;
351
376
  nk_size_t k = 0;
@@ -384,9 +409,9 @@ NK_INTERNAL void nk_dots_packed_f64_rvv_aligned_(nk_f64_t const *a_matrix, void
384
409
  // Horizontal reduce
385
410
  vfloat64m1_t zero_f64m1 = __riscv_vfmv_v_f_f64m1(0.0, 1);
386
411
  c_row_0[column] = __riscv_vfmv_f_s_f64m1_f64(
387
- __riscv_vfredusum_vs_f64m4_f64m1(accumulator_0_f64m4, zero_f64m1, vlmax));
412
+ __riscv_vfredusum_vs_f64m4_f64m1(accumulator_0_f64m4, zero_f64m1, max_vector_length));
388
413
  c_row_1[column] = __riscv_vfmv_f_s_f64m1_f64(
389
- __riscv_vfredusum_vs_f64m4_f64m1(accumulator_1_f64m4, zero_f64m1, vlmax));
414
+ __riscv_vfredusum_vs_f64m4_f64m1(accumulator_1_f64m4, zero_f64m1, max_vector_length));
390
415
  }
391
416
  }
392
417
  // Remainder rows
@@ -395,9 +420,9 @@ NK_INTERNAL void nk_dots_packed_f64_rvv_aligned_(nk_f64_t const *a_matrix, void
395
420
  nk_f64_t *c_row = (nk_f64_t *)((char *)c_matrix + row * c_stride_in_bytes);
396
421
  for (nk_size_t column = 0; column < column_count; ++column) {
397
422
  nk_f64_t const *b_column = packed_data + column * depth_padded;
398
- nk_size_t vlmax = __riscv_vsetvlmax_e64m4();
399
- vfloat64m4_t accumulator_f64m4 = __riscv_vfmv_v_f_f64m4(0.0, vlmax);
400
- vfloat64m4_t compensation_f64m4 = __riscv_vfmv_v_f_f64m4(0.0, vlmax);
423
+ nk_size_t max_vector_length = __riscv_vsetvlmax_e64m4();
424
+ vfloat64m4_t accumulator_f64m4 = __riscv_vfmv_v_f_f64m4(0.0, max_vector_length);
425
+ vfloat64m4_t compensation_f64m4 = __riscv_vfmv_v_f_f64m4(0.0, max_vector_length);
401
426
 
402
427
  nk_size_t remaining = depth;
403
428
  nk_size_t k = 0;
@@ -419,7 +444,7 @@ NK_INTERNAL void nk_dots_packed_f64_rvv_aligned_(nk_f64_t const *a_matrix, void
419
444
 
420
445
  vfloat64m1_t zero_f64m1 = __riscv_vfmv_v_f_f64m1(0.0, 1);
421
446
  c_row[column] = __riscv_vfmv_f_s_f64m1_f64(
422
- __riscv_vfredusum_vs_f64m4_f64m1(accumulator_f64m4, zero_f64m1, vlmax));
447
+ __riscv_vfredusum_vs_f64m4_f64m1(accumulator_f64m4, zero_f64m1, max_vector_length));
423
448
  }
424
449
  }
425
450
  }
@@ -427,9 +452,10 @@ NK_INTERNAL void nk_dots_packed_f64_rvv_aligned_(nk_f64_t const *a_matrix, void
427
452
  /**
428
453
  * @brief Public f64 packed GEMM wrapper matching the declared signature in dots.h.
429
454
  */
430
- NK_PUBLIC void nk_dots_packed_f64_rvv(nk_f64_t const *a, void const *b_packed, nk_f64_t *c, nk_size_t m, nk_size_t n,
431
- nk_size_t k, nk_size_t a_stride, nk_size_t c_stride) {
432
- nk_dots_packed_f64_rvv_aligned_(a, b_packed, c, m, n, k, a_stride, c_stride);
455
+ NK_PUBLIC void nk_dots_packed_f64_rvv(nk_f64_t const *a, void const *b_packed, nk_f64_t *c, nk_size_t rows,
456
+ nk_size_t columns, nk_size_t depth, nk_size_t a_stride_in_bytes,
457
+ nk_size_t c_stride_in_bytes) {
458
+ nk_dots_packed_f64_rvv_aligned_(a, b_packed, c, rows, columns, depth, a_stride_in_bytes, c_stride_in_bytes);
433
459
  }
434
460
 
435
461
  /**
@@ -438,20 +464,20 @@ NK_PUBLIC void nk_dots_packed_f64_rvv(nk_f64_t const *a, void const *b_packed, n
438
464
  * Uses Kahan compensation over full depth for precision.
439
465
  * Processes only the rows in [row_start, row_start + row_count) for parallelism.
440
466
  */
441
- NK_PUBLIC void nk_dots_symmetric_f64_rvv(nk_f64_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
442
- nk_size_t stride, nk_f64_t *result, nk_size_t result_stride,
467
+ NK_PUBLIC void nk_dots_symmetric_f64_rvv(nk_f64_t const *vectors, nk_size_t vectors_count, nk_size_t depth,
468
+ nk_size_t stride_in_bytes, nk_f64_t *result, nk_size_t result_stride_in_bytes,
443
469
  nk_size_t row_start, nk_size_t row_count) {
444
- nk_size_t const stride_elements = stride / sizeof(nk_f64_t);
445
- nk_size_t const result_stride_elements = result_stride / sizeof(nk_f64_t);
446
- nk_size_t const row_end = (row_start + row_count < n_vectors) ? (row_start + row_count) : n_vectors;
470
+ nk_size_t const stride_elements = stride_in_bytes / sizeof(nk_f64_t);
471
+ nk_size_t const result_stride_elements = result_stride_in_bytes / sizeof(nk_f64_t);
472
+ nk_size_t const row_end = (row_start + row_count < vectors_count) ? (row_start + row_count) : vectors_count;
447
473
 
448
474
  for (nk_size_t i = row_start; i < row_end; ++i) {
449
475
  nk_f64_t const *a_i = vectors + i * stride_elements;
450
- for (nk_size_t j = i; j < n_vectors; ++j) {
476
+ for (nk_size_t j = i; j < vectors_count; ++j) {
451
477
  nk_f64_t const *a_j = vectors + j * stride_elements;
452
- nk_size_t vlmax = __riscv_vsetvlmax_e64m4();
453
- vfloat64m4_t accumulator_f64m4 = __riscv_vfmv_v_f_f64m4(0.0, vlmax);
454
- vfloat64m4_t compensation_f64m4 = __riscv_vfmv_v_f_f64m4(0.0, vlmax);
478
+ nk_size_t max_vector_length = __riscv_vsetvlmax_e64m4();
479
+ vfloat64m4_t accumulator_f64m4 = __riscv_vfmv_v_f_f64m4(0.0, max_vector_length);
480
+ vfloat64m4_t compensation_f64m4 = __riscv_vfmv_v_f_f64m4(0.0, max_vector_length);
455
481
 
456
482
  nk_size_t remaining = depth;
457
483
  nk_size_t k = 0;
@@ -473,15 +499,15 @@ NK_PUBLIC void nk_dots_symmetric_f64_rvv(nk_f64_t const *vectors, nk_size_t n_ve
473
499
 
474
500
  vfloat64m1_t zero_f64m1 = __riscv_vfmv_v_f_f64m1(0.0, 1);
475
501
  nk_f64_t dot = __riscv_vfmv_f_s_f64m1_f64(
476
- __riscv_vfredusum_vs_f64m4_f64m1(accumulator_f64m4, zero_f64m1, vlmax));
502
+ __riscv_vfredusum_vs_f64m4_f64m1(accumulator_f64m4, zero_f64m1, max_vector_length));
477
503
  result[i * result_stride_elements + j] = dot;
478
504
  }
479
505
  }
480
506
  }
481
507
 
482
- #pragma endregion // Double Precision Floats
508
+ #pragma endregion F64 Floats
483
509
 
484
- #pragma region Micro Precision E2M3
510
+ #pragma region E2M3 Floats
485
511
 
486
512
  /**
487
513
  * @brief Scalar conversion helper: e2m3 byte → signed i8 (value × 16).
@@ -496,10 +522,10 @@ NK_INTERNAL nk_i8_t nk_e2m3_to_i8_rvv_(nk_u8_t raw) {
496
522
  }
497
523
 
498
524
  NK_PUBLIC nk_size_t nk_dots_packed_size_e2m3_rvv(nk_size_t column_count, nk_size_t depth) {
499
- nk_size_t vector_length = __riscv_vsetvlmax_e8m1();
500
- nk_size_t depth_padded = nk_size_round_up_to_multiple_(depth, vector_length);
525
+ nk_size_t max_vector_length = __riscv_vsetvlmax_e8m1();
526
+ nk_size_t depth_padded = nk_size_round_up_to_multiple_(depth, max_vector_length);
501
527
  nk_size_t stride_bytes = depth_padded * sizeof(nk_i8_t);
502
- if (stride_bytes > 0 && (stride_bytes & (stride_bytes - 1)) == 0) depth_padded += vector_length;
528
+ if (stride_bytes > 0 && (stride_bytes & (stride_bytes - 1)) == 0) depth_padded += max_vector_length;
503
529
  return sizeof(nk_cross_packed_buffer_header_t) + column_count * depth_padded * sizeof(nk_i8_t) +
504
530
  column_count * sizeof(nk_f32_t); // per-column norms
505
531
  }
@@ -512,10 +538,10 @@ NK_PUBLIC nk_size_t nk_dots_packed_size_e2m3_rvv(nk_size_t column_count, nk_size
512
538
  */
513
539
  NK_PUBLIC void nk_dots_pack_e2m3_rvv(nk_e2m3_t const *b, nk_size_t column_count, nk_size_t depth,
514
540
  nk_size_t b_stride_in_bytes, void *b_packed) {
515
- nk_size_t vector_length = __riscv_vsetvlmax_e8m1();
516
- nk_size_t depth_padded = nk_size_round_up_to_multiple_(depth, vector_length);
541
+ nk_size_t max_vector_length = __riscv_vsetvlmax_e8m1();
542
+ nk_size_t depth_padded = nk_size_round_up_to_multiple_(depth, max_vector_length);
517
543
  nk_size_t stride_bytes = depth_padded * sizeof(nk_i8_t);
518
- if (stride_bytes > 0 && (stride_bytes & (stride_bytes - 1)) == 0) depth_padded += vector_length;
544
+ if (stride_bytes > 0 && (stride_bytes & (stride_bytes - 1)) == 0) depth_padded += max_vector_length;
519
545
 
520
546
  nk_cross_packed_buffer_header_t *header = (nk_cross_packed_buffer_header_t *)b_packed;
521
547
  header->column_count = (nk_u32_t)column_count;
@@ -524,7 +550,15 @@ NK_PUBLIC void nk_dots_pack_e2m3_rvv(nk_e2m3_t const *b, nk_size_t column_count,
524
550
 
525
551
  nk_i8_t *packed = (nk_i8_t *)((char *)b_packed + sizeof(nk_cross_packed_buffer_header_t));
526
552
  nk_size_t total = column_count * depth_padded;
527
- for (nk_size_t i = 0; i < total; ++i) packed[i] = 0;
553
+ {
554
+ nk_u8_t *zero_ptr = (nk_u8_t *)packed;
555
+ nk_size_t total_bytes = total * sizeof(nk_i8_t);
556
+ for (nk_size_t i = 0; i < total_bytes;) {
557
+ nk_size_t vector_length = __riscv_vsetvl_e8m8(total_bytes - i);
558
+ __riscv_vse8_v_u8m8(zero_ptr + i, __riscv_vmv_v_x_u8m8(0, vector_length), vector_length);
559
+ i += vector_length;
560
+ }
561
+ }
528
562
 
529
563
  for (nk_size_t column = 0; column < column_count; ++column) {
530
564
  nk_u8_t const *src = (nk_u8_t const *)((char const *)b + column * b_stride_in_bytes);
@@ -584,11 +618,11 @@ NK_INTERNAL void nk_dots_packed_e2m3_rvv_aligned_(nk_e2m3_t const *a_matrix, voi
584
618
 
585
619
  for (nk_size_t column = 0; column < column_count; ++column) {
586
620
  nk_i8_t const *b_column = packed_data + column * depth_padded;
587
- nk_size_t vlmax = __riscv_vsetvlmax_e32m4();
588
- vint32m4_t accumulator_0_i32m4 = __riscv_vmv_v_x_i32m4(0, vlmax);
589
- vint32m4_t accumulator_1_i32m4 = __riscv_vmv_v_x_i32m4(0, vlmax);
590
- vint32m4_t accumulator_2_i32m4 = __riscv_vmv_v_x_i32m4(0, vlmax);
591
- vint32m4_t accumulator_3_i32m4 = __riscv_vmv_v_x_i32m4(0, vlmax);
621
+ nk_size_t max_vector_length = __riscv_vsetvlmax_e32m4();
622
+ vint32m4_t accumulator_0_i32m4 = __riscv_vmv_v_x_i32m4(0, max_vector_length);
623
+ vint32m4_t accumulator_1_i32m4 = __riscv_vmv_v_x_i32m4(0, max_vector_length);
624
+ vint32m4_t accumulator_2_i32m4 = __riscv_vmv_v_x_i32m4(0, max_vector_length);
625
+ vint32m4_t accumulator_3_i32m4 = __riscv_vmv_v_x_i32m4(0, max_vector_length);
592
626
 
593
627
  nk_size_t remaining = depth;
594
628
  nk_size_t k = 0;
@@ -654,16 +688,16 @@ NK_INTERNAL void nk_dots_packed_e2m3_rvv_aligned_(nk_e2m3_t const *a_matrix, voi
654
688
  // Horizontal reduce and convert to f32 with scaling
655
689
  vint32m1_t zero_i32m1 = __riscv_vmv_v_x_i32m1(0, 1);
656
690
  c_row_0[column] = (nk_f32_t)__riscv_vmv_x_s_i32m1_i32(
657
- __riscv_vredsum_vs_i32m4_i32m1(accumulator_0_i32m4, zero_i32m1, vlmax)) *
691
+ __riscv_vredsum_vs_i32m4_i32m1(accumulator_0_i32m4, zero_i32m1, max_vector_length)) *
658
692
  lut_scale_reciprocal;
659
693
  c_row_1[column] = (nk_f32_t)__riscv_vmv_x_s_i32m1_i32(
660
- __riscv_vredsum_vs_i32m4_i32m1(accumulator_1_i32m4, zero_i32m1, vlmax)) *
694
+ __riscv_vredsum_vs_i32m4_i32m1(accumulator_1_i32m4, zero_i32m1, max_vector_length)) *
661
695
  lut_scale_reciprocal;
662
696
  c_row_2[column] = (nk_f32_t)__riscv_vmv_x_s_i32m1_i32(
663
- __riscv_vredsum_vs_i32m4_i32m1(accumulator_2_i32m4, zero_i32m1, vlmax)) *
697
+ __riscv_vredsum_vs_i32m4_i32m1(accumulator_2_i32m4, zero_i32m1, max_vector_length)) *
664
698
  lut_scale_reciprocal;
665
699
  c_row_3[column] = (nk_f32_t)__riscv_vmv_x_s_i32m1_i32(
666
- __riscv_vredsum_vs_i32m4_i32m1(accumulator_3_i32m4, zero_i32m1, vlmax)) *
700
+ __riscv_vredsum_vs_i32m4_i32m1(accumulator_3_i32m4, zero_i32m1, max_vector_length)) *
667
701
  lut_scale_reciprocal;
668
702
  }
669
703
  }
@@ -673,8 +707,8 @@ NK_INTERNAL void nk_dots_packed_e2m3_rvv_aligned_(nk_e2m3_t const *a_matrix, voi
673
707
  nk_f32_t *c_row = (nk_f32_t *)((char *)c_matrix + row * c_stride_in_bytes);
674
708
  for (nk_size_t column = 0; column < column_count; ++column) {
675
709
  nk_i8_t const *b_column = packed_data + column * depth_padded;
676
- nk_size_t vlmax = __riscv_vsetvlmax_e32m4();
677
- vint32m4_t accumulator_i32m4 = __riscv_vmv_v_x_i32m4(0, vlmax);
710
+ nk_size_t max_vector_length = __riscv_vsetvlmax_e32m4();
711
+ vint32m4_t accumulator_i32m4 = __riscv_vmv_v_x_i32m4(0, max_vector_length);
678
712
  nk_size_t remaining = depth;
679
713
  nk_size_t k = 0;
680
714
  for (nk_size_t vector_length = 0; remaining > 0; remaining -= vector_length, k += vector_length) {
@@ -693,7 +727,7 @@ NK_INTERNAL void nk_dots_packed_e2m3_rvv_aligned_(nk_e2m3_t const *a_matrix, voi
693
727
  }
694
728
  vint32m1_t zero_i32m1 = __riscv_vmv_v_x_i32m1(0, 1);
695
729
  c_row[column] = (nk_f32_t)__riscv_vmv_x_s_i32m1_i32(
696
- __riscv_vredsum_vs_i32m4_i32m1(accumulator_i32m4, zero_i32m1, vlmax)) *
730
+ __riscv_vredsum_vs_i32m4_i32m1(accumulator_i32m4, zero_i32m1, max_vector_length)) *
697
731
  lut_scale_reciprocal;
698
732
  }
699
733
  }
@@ -702,9 +736,10 @@ NK_INTERNAL void nk_dots_packed_e2m3_rvv_aligned_(nk_e2m3_t const *a_matrix, voi
702
736
  /**
703
737
  * @brief Public e2m3 packed GEMM wrapper matching the declared signature in dots.h.
704
738
  */
705
- NK_PUBLIC void nk_dots_packed_e2m3_rvv(nk_e2m3_t const *a, void const *b_packed, nk_f32_t *c, nk_size_t m, nk_size_t n,
706
- nk_size_t k, nk_size_t a_stride, nk_size_t c_stride) {
707
- nk_dots_packed_e2m3_rvv_aligned_(a, b_packed, c, m, n, k, a_stride, c_stride);
739
+ NK_PUBLIC void nk_dots_packed_e2m3_rvv(nk_e2m3_t const *a, void const *b_packed, nk_f32_t *c, nk_size_t rows,
740
+ nk_size_t columns, nk_size_t depth, nk_size_t a_stride_in_bytes,
741
+ nk_size_t c_stride_in_bytes) {
742
+ nk_dots_packed_e2m3_rvv_aligned_(a, b_packed, c, rows, columns, depth, a_stride_in_bytes, c_stride_in_bytes);
708
743
  }
709
744
 
710
745
  /**
@@ -713,20 +748,20 @@ NK_PUBLIC void nk_dots_packed_e2m3_rvv(nk_e2m3_t const *a, void const *b_packed,
713
748
  * Uses integer i8 LUT arithmetic with i32 accumulation, scaled by 1/256.
714
749
  * Processes only the rows in [row_start, row_start + row_count) for parallelism.
715
750
  */
716
- NK_PUBLIC void nk_dots_symmetric_e2m3_rvv(nk_e2m3_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
717
- nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
751
+ NK_PUBLIC void nk_dots_symmetric_e2m3_rvv(nk_e2m3_t const *vectors, nk_size_t vectors_count, nk_size_t depth,
752
+ nk_size_t stride_in_bytes, nk_f32_t *result, nk_size_t result_stride_in_bytes,
718
753
  nk_size_t row_start, nk_size_t row_count) {
719
754
  nk_f32_t const lut_scale_reciprocal = 1.0f / 256.0f;
720
755
 
721
- nk_size_t const result_stride_elements = result_stride / sizeof(nk_f32_t);
722
- nk_size_t const row_end = (row_start + row_count < n_vectors) ? (row_start + row_count) : n_vectors;
756
+ nk_size_t const result_stride_elements = result_stride_in_bytes / sizeof(nk_f32_t);
757
+ nk_size_t const row_end = (row_start + row_count < vectors_count) ? (row_start + row_count) : vectors_count;
723
758
 
724
759
  for (nk_size_t i = row_start; i < row_end; ++i) {
725
- nk_u8_t const *a_i = (nk_u8_t const *)vectors + i * stride;
726
- for (nk_size_t j = i; j < n_vectors; ++j) {
727
- nk_u8_t const *a_j = (nk_u8_t const *)vectors + j * stride;
728
- nk_size_t vlmax = __riscv_vsetvlmax_e32m4();
729
- vint32m4_t accumulator_i32m4 = __riscv_vmv_v_x_i32m4(0, vlmax);
760
+ nk_u8_t const *a_i = (nk_u8_t const *)vectors + i * stride_in_bytes;
761
+ for (nk_size_t j = i; j < vectors_count; ++j) {
762
+ nk_u8_t const *a_j = (nk_u8_t const *)vectors + j * stride_in_bytes;
763
+ nk_size_t max_vector_length = __riscv_vsetvlmax_e32m4();
764
+ vint32m4_t accumulator_i32m4 = __riscv_vmv_v_x_i32m4(0, max_vector_length);
730
765
  nk_size_t remaining = depth;
731
766
  nk_size_t k = 0;
732
767
  for (nk_size_t vector_length = 0; remaining > 0; remaining -= vector_length, k += vector_length) {
@@ -755,16 +790,16 @@ NK_PUBLIC void nk_dots_symmetric_e2m3_rvv(nk_e2m3_t const *vectors, nk_size_t n_
755
790
  }
756
791
  vint32m1_t zero_i32m1 = __riscv_vmv_v_x_i32m1(0, 1);
757
792
  nk_f32_t dot = (nk_f32_t)__riscv_vmv_x_s_i32m1_i32(
758
- __riscv_vredsum_vs_i32m4_i32m1(accumulator_i32m4, zero_i32m1, vlmax)) *
793
+ __riscv_vredsum_vs_i32m4_i32m1(accumulator_i32m4, zero_i32m1, max_vector_length)) *
759
794
  lut_scale_reciprocal;
760
795
  result[i * result_stride_elements + j] = dot;
761
796
  }
762
797
  }
763
798
  }
764
799
 
765
- #pragma endregion // Micro Precision E2M3
800
+ #pragma endregion E2M3 Floats
766
801
 
767
- #pragma region Micro Precision E3M2
802
+ #pragma region E3M2 Floats
768
803
 
769
804
  /**
770
805
  * @brief Scalar conversion helper: e3m2 byte → signed i16 (value × 16).
@@ -779,10 +814,10 @@ NK_INTERNAL nk_i16_t nk_e3m2_to_i16_rvv_(nk_u8_t raw) {
779
814
  }
780
815
 
781
816
  NK_PUBLIC nk_size_t nk_dots_packed_size_e3m2_rvv(nk_size_t column_count, nk_size_t depth) {
782
- nk_size_t vector_length = __riscv_vsetvlmax_e16m2();
783
- nk_size_t depth_padded = nk_size_round_up_to_multiple_(depth, vector_length);
817
+ nk_size_t max_vector_length = __riscv_vsetvlmax_e16m2();
818
+ nk_size_t depth_padded = nk_size_round_up_to_multiple_(depth, max_vector_length);
784
819
  nk_size_t stride_bytes = depth_padded * sizeof(nk_i16_t);
785
- if (stride_bytes > 0 && (stride_bytes & (stride_bytes - 1)) == 0) depth_padded += vector_length;
820
+ if (stride_bytes > 0 && (stride_bytes & (stride_bytes - 1)) == 0) depth_padded += max_vector_length;
786
821
  return sizeof(nk_cross_packed_buffer_header_t) + column_count * depth_padded * sizeof(nk_i16_t) +
787
822
  column_count * sizeof(nk_f32_t); // per-column norms
788
823
  }
@@ -795,10 +830,10 @@ NK_PUBLIC nk_size_t nk_dots_packed_size_e3m2_rvv(nk_size_t column_count, nk_size
795
830
  */
796
831
  NK_PUBLIC void nk_dots_pack_e3m2_rvv(nk_e3m2_t const *b, nk_size_t column_count, nk_size_t depth,
797
832
  nk_size_t b_stride_in_bytes, void *b_packed) {
798
- nk_size_t vector_length = __riscv_vsetvlmax_e16m2();
799
- nk_size_t depth_padded = nk_size_round_up_to_multiple_(depth, vector_length);
833
+ nk_size_t max_vector_length = __riscv_vsetvlmax_e16m2();
834
+ nk_size_t depth_padded = nk_size_round_up_to_multiple_(depth, max_vector_length);
800
835
  nk_size_t stride_bytes = depth_padded * sizeof(nk_i16_t);
801
- if (stride_bytes > 0 && (stride_bytes & (stride_bytes - 1)) == 0) depth_padded += vector_length;
836
+ if (stride_bytes > 0 && (stride_bytes & (stride_bytes - 1)) == 0) depth_padded += max_vector_length;
802
837
 
803
838
  nk_cross_packed_buffer_header_t *header = (nk_cross_packed_buffer_header_t *)b_packed;
804
839
  header->column_count = (nk_u32_t)column_count;
@@ -807,7 +842,15 @@ NK_PUBLIC void nk_dots_pack_e3m2_rvv(nk_e3m2_t const *b, nk_size_t column_count,
807
842
 
808
843
  nk_i16_t *packed = (nk_i16_t *)((char *)b_packed + sizeof(nk_cross_packed_buffer_header_t));
809
844
  nk_size_t total = column_count * depth_padded;
810
- for (nk_size_t i = 0; i < total; ++i) packed[i] = 0;
845
+ {
846
+ nk_u8_t *zero_ptr = (nk_u8_t *)packed;
847
+ nk_size_t total_bytes = total * sizeof(nk_i16_t);
848
+ for (nk_size_t i = 0; i < total_bytes;) {
849
+ nk_size_t vector_length = __riscv_vsetvl_e8m8(total_bytes - i);
850
+ __riscv_vse8_v_u8m8(zero_ptr + i, __riscv_vmv_v_x_u8m8(0, vector_length), vector_length);
851
+ i += vector_length;
852
+ }
853
+ }
811
854
 
812
855
  for (nk_size_t column = 0; column < column_count; ++column) {
813
856
  nk_u8_t const *src = (nk_u8_t const *)((char const *)b + column * b_stride_in_bytes);
@@ -862,9 +905,9 @@ NK_INTERNAL void nk_dots_packed_e3m2_rvv_aligned_(nk_e3m2_t const *a_matrix, voi
862
905
 
863
906
  for (nk_size_t column = 0; column < column_count; ++column) {
864
907
  nk_i16_t const *b_column = packed_data + column * depth_padded;
865
- nk_size_t vlmax = __riscv_vsetvlmax_e32m4();
866
- vint32m4_t accumulator_0_i32m4 = __riscv_vmv_v_x_i32m4(0, vlmax);
867
- vint32m4_t accumulator_1_i32m4 = __riscv_vmv_v_x_i32m4(0, vlmax);
908
+ nk_size_t max_vector_length = __riscv_vsetvlmax_e32m4();
909
+ vint32m4_t accumulator_0_i32m4 = __riscv_vmv_v_x_i32m4(0, max_vector_length);
910
+ vint32m4_t accumulator_1_i32m4 = __riscv_vmv_v_x_i32m4(0, max_vector_length);
868
911
 
869
912
  nk_size_t remaining = depth;
870
913
  nk_size_t k = 0;
@@ -916,10 +959,10 @@ NK_INTERNAL void nk_dots_packed_e3m2_rvv_aligned_(nk_e3m2_t const *a_matrix, voi
916
959
  // Horizontal reduce and convert to f32 with scaling
917
960
  vint32m1_t zero_i32m1 = __riscv_vmv_v_x_i32m1(0, 1);
918
961
  c_row_0[column] = (nk_f32_t)__riscv_vmv_x_s_i32m1_i32(
919
- __riscv_vredsum_vs_i32m4_i32m1(accumulator_0_i32m4, zero_i32m1, vlmax)) *
962
+ __riscv_vredsum_vs_i32m4_i32m1(accumulator_0_i32m4, zero_i32m1, max_vector_length)) *
920
963
  lut_scale_reciprocal;
921
964
  c_row_1[column] = (nk_f32_t)__riscv_vmv_x_s_i32m1_i32(
922
- __riscv_vredsum_vs_i32m4_i32m1(accumulator_1_i32m4, zero_i32m1, vlmax)) *
965
+ __riscv_vredsum_vs_i32m4_i32m1(accumulator_1_i32m4, zero_i32m1, max_vector_length)) *
923
966
  lut_scale_reciprocal;
924
967
  }
925
968
  }
@@ -929,8 +972,8 @@ NK_INTERNAL void nk_dots_packed_e3m2_rvv_aligned_(nk_e3m2_t const *a_matrix, voi
929
972
  nk_f32_t *c_row = (nk_f32_t *)((char *)c_matrix + row * c_stride_in_bytes);
930
973
  for (nk_size_t column = 0; column < column_count; ++column) {
931
974
  nk_i16_t const *b_column = packed_data + column * depth_padded;
932
- nk_size_t vlmax = __riscv_vsetvlmax_e32m4();
933
- vint32m4_t accumulator_i32m4 = __riscv_vmv_v_x_i32m4(0, vlmax);
975
+ nk_size_t max_vector_length = __riscv_vsetvlmax_e32m4();
976
+ vint32m4_t accumulator_i32m4 = __riscv_vmv_v_x_i32m4(0, max_vector_length);
934
977
  nk_size_t remaining = depth;
935
978
  nk_size_t k = 0;
936
979
  for (nk_size_t vector_length = 0; remaining > 0; remaining -= vector_length, k += vector_length) {
@@ -951,7 +994,7 @@ NK_INTERNAL void nk_dots_packed_e3m2_rvv_aligned_(nk_e3m2_t const *a_matrix, voi
951
994
  }
952
995
  vint32m1_t zero_i32m1 = __riscv_vmv_v_x_i32m1(0, 1);
953
996
  c_row[column] = (nk_f32_t)__riscv_vmv_x_s_i32m1_i32(
954
- __riscv_vredsum_vs_i32m4_i32m1(accumulator_i32m4, zero_i32m1, vlmax)) *
997
+ __riscv_vredsum_vs_i32m4_i32m1(accumulator_i32m4, zero_i32m1, max_vector_length)) *
955
998
  lut_scale_reciprocal;
956
999
  }
957
1000
  }
@@ -960,9 +1003,10 @@ NK_INTERNAL void nk_dots_packed_e3m2_rvv_aligned_(nk_e3m2_t const *a_matrix, voi
960
1003
  /**
961
1004
  * @brief Public e3m2 packed GEMM wrapper matching the declared signature in dots.h.
962
1005
  */
963
- NK_PUBLIC void nk_dots_packed_e3m2_rvv(nk_e3m2_t const *a, void const *b_packed, nk_f32_t *c, nk_size_t m, nk_size_t n,
964
- nk_size_t k, nk_size_t a_stride, nk_size_t c_stride) {
965
- nk_dots_packed_e3m2_rvv_aligned_(a, b_packed, c, m, n, k, a_stride, c_stride);
1006
+ NK_PUBLIC void nk_dots_packed_e3m2_rvv(nk_e3m2_t const *a, void const *b_packed, nk_f32_t *c, nk_size_t rows,
1007
+ nk_size_t columns, nk_size_t depth, nk_size_t a_stride_in_bytes,
1008
+ nk_size_t c_stride_in_bytes) {
1009
+ nk_dots_packed_e3m2_rvv_aligned_(a, b_packed, c, rows, columns, depth, a_stride_in_bytes, c_stride_in_bytes);
966
1010
  }
967
1011
 
968
1012
  /**
@@ -971,20 +1015,20 @@ NK_PUBLIC void nk_dots_packed_e3m2_rvv(nk_e3m2_t const *a, void const *b_packed,
971
1015
  * Uses integer i16 LUT arithmetic with i32 widening MAC, scaled by 1/256.
972
1016
  * Processes only the rows in [row_start, row_start + row_count) for parallelism.
973
1017
  */
974
- NK_PUBLIC void nk_dots_symmetric_e3m2_rvv(nk_e3m2_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
975
- nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
1018
+ NK_PUBLIC void nk_dots_symmetric_e3m2_rvv(nk_e3m2_t const *vectors, nk_size_t vectors_count, nk_size_t depth,
1019
+ nk_size_t stride_in_bytes, nk_f32_t *result, nk_size_t result_stride_in_bytes,
976
1020
  nk_size_t row_start, nk_size_t row_count) {
977
1021
  nk_f32_t const lut_scale_reciprocal = 1.0f / 256.0f;
978
1022
 
979
- nk_size_t const result_stride_elements = result_stride / sizeof(nk_f32_t);
980
- nk_size_t const row_end = (row_start + row_count < n_vectors) ? (row_start + row_count) : n_vectors;
1023
+ nk_size_t const result_stride_elements = result_stride_in_bytes / sizeof(nk_f32_t);
1024
+ nk_size_t const row_end = (row_start + row_count < vectors_count) ? (row_start + row_count) : vectors_count;
981
1025
 
982
1026
  for (nk_size_t i = row_start; i < row_end; ++i) {
983
- nk_u8_t const *a_i = (nk_u8_t const *)vectors + i * stride;
984
- for (nk_size_t j = i; j < n_vectors; ++j) {
985
- nk_u8_t const *a_j = (nk_u8_t const *)vectors + j * stride;
986
- nk_size_t vlmax = __riscv_vsetvlmax_e32m4();
987
- vint32m4_t accumulator_i32m4 = __riscv_vmv_v_x_i32m4(0, vlmax);
1027
+ nk_u8_t const *a_i = (nk_u8_t const *)vectors + i * stride_in_bytes;
1028
+ for (nk_size_t j = i; j < vectors_count; ++j) {
1029
+ nk_u8_t const *a_j = (nk_u8_t const *)vectors + j * stride_in_bytes;
1030
+ nk_size_t max_vector_length = __riscv_vsetvlmax_e32m4();
1031
+ vint32m4_t accumulator_i32m4 = __riscv_vmv_v_x_i32m4(0, max_vector_length);
988
1032
  nk_size_t remaining = depth;
989
1033
  nk_size_t k = 0;
990
1034
  for (nk_size_t vector_length = 0; remaining > 0; remaining -= vector_length, k += vector_length) {
@@ -1023,16 +1067,16 @@ NK_PUBLIC void nk_dots_symmetric_e3m2_rvv(nk_e3m2_t const *vectors, nk_size_t n_
1023
1067
  }
1024
1068
  vint32m1_t zero_i32m1 = __riscv_vmv_v_x_i32m1(0, 1);
1025
1069
  nk_f32_t dot = (nk_f32_t)__riscv_vmv_x_s_i32m1_i32(
1026
- __riscv_vredsum_vs_i32m4_i32m1(accumulator_i32m4, zero_i32m1, vlmax)) *
1070
+ __riscv_vredsum_vs_i32m4_i32m1(accumulator_i32m4, zero_i32m1, max_vector_length)) *
1027
1071
  lut_scale_reciprocal;
1028
1072
  result[i * result_stride_elements + j] = dot;
1029
1073
  }
1030
1074
  }
1031
1075
  }
1032
1076
 
1033
- #pragma endregion // Micro Precision E3M2
1077
+ #pragma endregion E3M2 Floats
1034
1078
 
1035
- #pragma region Brain Float 16
1079
+ #pragma region BF16 Floats
1036
1080
 
1037
1081
  /**
1038
1082
  * @brief Compute the packed buffer size for bf16 GEMM (B stored as f32).
@@ -1041,11 +1085,11 @@ NK_PUBLIC void nk_dots_symmetric_e3m2_rvv(nk_e3m2_t const *vectors, nk_size_t n_
1041
1085
  * Layout: column-panel with depth-contiguous f32 values, cache-line padding.
1042
1086
  */
1043
1087
  NK_PUBLIC nk_size_t nk_dots_packed_size_bf16_rvv(nk_size_t column_count, nk_size_t depth) {
1044
- nk_size_t vector_length = __riscv_vsetvlmax_e32m2();
1045
- nk_size_t depth_padded = nk_size_round_up_to_multiple_(depth, vector_length);
1088
+ nk_size_t max_vector_length = __riscv_vsetvlmax_e32m2();
1089
+ nk_size_t depth_padded = nk_size_round_up_to_multiple_(depth, max_vector_length);
1046
1090
  // Break power-of-2 strides for cache associativity
1047
1091
  nk_size_t stride_bytes = depth_padded * sizeof(nk_f32_t);
1048
- if (stride_bytes > 0 && (stride_bytes & (stride_bytes - 1)) == 0) depth_padded += vector_length;
1092
+ if (stride_bytes > 0 && (stride_bytes & (stride_bytes - 1)) == 0) depth_padded += max_vector_length;
1049
1093
  return sizeof(nk_cross_packed_buffer_header_t) + column_count * depth_padded * sizeof(nk_f32_t) +
1050
1094
  column_count * sizeof(nk_f32_t); // per-column norms
1051
1095
  }
@@ -1058,10 +1102,10 @@ NK_PUBLIC nk_size_t nk_dots_packed_size_bf16_rvv(nk_size_t column_count, nk_size
1058
1102
  */
1059
1103
  NK_PUBLIC void nk_dots_pack_bf16_rvv(nk_bf16_t const *b, nk_size_t column_count, nk_size_t depth,
1060
1104
  nk_size_t b_stride_in_bytes, void *b_packed) {
1061
- nk_size_t vector_length = __riscv_vsetvlmax_e32m2();
1062
- nk_size_t depth_padded = nk_size_round_up_to_multiple_(depth, vector_length);
1105
+ nk_size_t max_vector_length = __riscv_vsetvlmax_e32m2();
1106
+ nk_size_t depth_padded = nk_size_round_up_to_multiple_(depth, max_vector_length);
1063
1107
  nk_size_t stride_bytes = depth_padded * sizeof(nk_f32_t);
1064
- if (stride_bytes > 0 && (stride_bytes & (stride_bytes - 1)) == 0) depth_padded += vector_length;
1108
+ if (stride_bytes > 0 && (stride_bytes & (stride_bytes - 1)) == 0) depth_padded += max_vector_length;
1065
1109
 
1066
1110
  nk_cross_packed_buffer_header_t *header = (nk_cross_packed_buffer_header_t *)b_packed;
1067
1111
  header->column_count = (nk_u32_t)column_count;
@@ -1070,7 +1114,15 @@ NK_PUBLIC void nk_dots_pack_bf16_rvv(nk_bf16_t const *b, nk_size_t column_count,
1070
1114
 
1071
1115
  nk_f32_t *packed = (nk_f32_t *)((char *)b_packed + sizeof(nk_cross_packed_buffer_header_t));
1072
1116
  nk_size_t total = column_count * depth_padded;
1073
- for (nk_size_t i = 0; i < total; ++i) packed[i] = 0;
1117
+ {
1118
+ nk_u8_t *zero_ptr = (nk_u8_t *)packed;
1119
+ nk_size_t total_bytes = total * sizeof(nk_f32_t);
1120
+ for (nk_size_t i = 0; i < total_bytes;) {
1121
+ nk_size_t vector_length = __riscv_vsetvl_e8m8(total_bytes - i);
1122
+ __riscv_vse8_v_u8m8(zero_ptr + i, __riscv_vmv_v_x_u8m8(0, vector_length), vector_length);
1123
+ i += vector_length;
1124
+ }
1125
+ }
1074
1126
 
1075
1127
  for (nk_size_t column = 0; column < column_count; ++column) {
1076
1128
  nk_u16_t const *src = (nk_u16_t const *)((char const *)b + column * b_stride_in_bytes);
@@ -1133,11 +1185,11 @@ NK_INTERNAL void nk_dots_packed_bf16_rvv_aligned_(nk_bf16_t const *a_matrix, voi
1133
1185
 
1134
1186
  for (nk_size_t column = 0; column < column_count; ++column) {
1135
1187
  nk_f32_t const *b_column = packed_data + column * depth_padded;
1136
- nk_size_t vlmax = __riscv_vsetvlmax_e32m2();
1137
- vfloat64m4_t accumulator_0_f64m4 = __riscv_vfmv_v_f_f64m4(0.0, vlmax);
1138
- vfloat64m4_t accumulator_1_f64m4 = __riscv_vfmv_v_f_f64m4(0.0, vlmax);
1139
- vfloat64m4_t accumulator_2_f64m4 = __riscv_vfmv_v_f_f64m4(0.0, vlmax);
1140
- vfloat64m4_t accumulator_3_f64m4 = __riscv_vfmv_v_f_f64m4(0.0, vlmax);
1188
+ nk_size_t max_vector_length = __riscv_vsetvlmax_e32m2();
1189
+ vfloat64m4_t accumulator_0_f64m4 = __riscv_vfmv_v_f_f64m4(0.0, max_vector_length);
1190
+ vfloat64m4_t accumulator_1_f64m4 = __riscv_vfmv_v_f_f64m4(0.0, max_vector_length);
1191
+ vfloat64m4_t accumulator_2_f64m4 = __riscv_vfmv_v_f_f64m4(0.0, max_vector_length);
1192
+ vfloat64m4_t accumulator_3_f64m4 = __riscv_vfmv_v_f_f64m4(0.0, max_vector_length);
1141
1193
 
1142
1194
  nk_size_t remaining = depth;
1143
1195
  nk_size_t k = 0;
@@ -1166,13 +1218,13 @@ NK_INTERNAL void nk_dots_packed_bf16_rvv_aligned_(nk_bf16_t const *a_matrix, voi
1166
1218
  // Horizontal reduce and narrow to f32
1167
1219
  vfloat64m1_t zero_f64m1 = __riscv_vfmv_v_f_f64m1(0.0, 1);
1168
1220
  c_row_0[column] = (nk_f32_t)__riscv_vfmv_f_s_f64m1_f64(
1169
- __riscv_vfredusum_vs_f64m4_f64m1(accumulator_0_f64m4, zero_f64m1, vlmax));
1221
+ __riscv_vfredusum_vs_f64m4_f64m1(accumulator_0_f64m4, zero_f64m1, max_vector_length));
1170
1222
  c_row_1[column] = (nk_f32_t)__riscv_vfmv_f_s_f64m1_f64(
1171
- __riscv_vfredusum_vs_f64m4_f64m1(accumulator_1_f64m4, zero_f64m1, vlmax));
1223
+ __riscv_vfredusum_vs_f64m4_f64m1(accumulator_1_f64m4, zero_f64m1, max_vector_length));
1172
1224
  c_row_2[column] = (nk_f32_t)__riscv_vfmv_f_s_f64m1_f64(
1173
- __riscv_vfredusum_vs_f64m4_f64m1(accumulator_2_f64m4, zero_f64m1, vlmax));
1225
+ __riscv_vfredusum_vs_f64m4_f64m1(accumulator_2_f64m4, zero_f64m1, max_vector_length));
1174
1226
  c_row_3[column] = (nk_f32_t)__riscv_vfmv_f_s_f64m1_f64(
1175
- __riscv_vfredusum_vs_f64m4_f64m1(accumulator_3_f64m4, zero_f64m1, vlmax));
1227
+ __riscv_vfredusum_vs_f64m4_f64m1(accumulator_3_f64m4, zero_f64m1, max_vector_length));
1176
1228
  }
1177
1229
  }
1178
1230
  // Remainder rows (mr < 4)
@@ -1181,8 +1233,8 @@ NK_INTERNAL void nk_dots_packed_bf16_rvv_aligned_(nk_bf16_t const *a_matrix, voi
1181
1233
  nk_f32_t *c_row = (nk_f32_t *)((char *)c_matrix + row * c_stride_in_bytes);
1182
1234
  for (nk_size_t column = 0; column < column_count; ++column) {
1183
1235
  nk_f32_t const *b_column = packed_data + column * depth_padded;
1184
- nk_size_t vlmax = __riscv_vsetvlmax_e32m2();
1185
- vfloat64m4_t accumulator_f64m4 = __riscv_vfmv_v_f_f64m4(0.0, vlmax);
1236
+ nk_size_t max_vector_length = __riscv_vsetvlmax_e32m2();
1237
+ vfloat64m4_t accumulator_f64m4 = __riscv_vfmv_v_f_f64m4(0.0, max_vector_length);
1186
1238
  nk_size_t remaining = depth;
1187
1239
  nk_size_t k = 0;
1188
1240
  for (nk_size_t vector_length = 0; remaining > 0; remaining -= vector_length, k += vector_length) {
@@ -1195,7 +1247,7 @@ NK_INTERNAL void nk_dots_packed_bf16_rvv_aligned_(nk_bf16_t const *a_matrix, voi
1195
1247
  }
1196
1248
  vfloat64m1_t zero_f64m1 = __riscv_vfmv_v_f_f64m1(0.0, 1);
1197
1249
  c_row[column] = (nk_f32_t)__riscv_vfmv_f_s_f64m1_f64(
1198
- __riscv_vfredusum_vs_f64m4_f64m1(accumulator_f64m4, zero_f64m1, vlmax));
1250
+ __riscv_vfredusum_vs_f64m4_f64m1(accumulator_f64m4, zero_f64m1, max_vector_length));
1199
1251
  }
1200
1252
  }
1201
1253
  }
@@ -1206,9 +1258,10 @@ NK_INTERNAL void nk_dots_packed_bf16_rvv_aligned_(nk_bf16_t const *a_matrix, voi
1206
1258
  * Dispatches to the aligned kernel for all cases — RVV's `vsetvl` handles partial
1207
1259
  * vectors naturally, so no separate edge kernel is needed.
1208
1260
  */
1209
- NK_PUBLIC void nk_dots_packed_bf16_rvv(nk_bf16_t const *a, void const *b_packed, nk_f32_t *c, nk_size_t m, nk_size_t n,
1210
- nk_size_t k, nk_size_t a_stride, nk_size_t c_stride) {
1211
- nk_dots_packed_bf16_rvv_aligned_(a, b_packed, c, m, n, k, a_stride, c_stride);
1261
+ NK_PUBLIC void nk_dots_packed_bf16_rvv(nk_bf16_t const *a, void const *b_packed, nk_f32_t *c, nk_size_t rows,
1262
+ nk_size_t columns, nk_size_t depth, nk_size_t a_stride_in_bytes,
1263
+ nk_size_t c_stride_in_bytes) {
1264
+ nk_dots_packed_bf16_rvv_aligned_(a, b_packed, c, rows, columns, depth, a_stride_in_bytes, c_stride_in_bytes);
1212
1265
  }
1213
1266
 
1214
1267
  /**
@@ -1219,18 +1272,18 @@ NK_PUBLIC void nk_dots_packed_bf16_rvv(nk_bf16_t const *a, void const *b_packed,
1219
1272
  * Stride is in bytes.
1220
1273
  * Processes only the rows in [row_start, row_start + row_count) for parallelism.
1221
1274
  */
1222
- NK_PUBLIC void nk_dots_symmetric_bf16_rvv(nk_bf16_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
1223
- nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
1275
+ NK_PUBLIC void nk_dots_symmetric_bf16_rvv(nk_bf16_t const *vectors, nk_size_t vectors_count, nk_size_t depth,
1276
+ nk_size_t stride_in_bytes, nk_f32_t *result, nk_size_t result_stride_in_bytes,
1224
1277
  nk_size_t row_start, nk_size_t row_count) {
1225
- nk_size_t const result_stride_elements = result_stride / sizeof(nk_f32_t);
1226
- nk_size_t const row_end = (row_start + row_count < n_vectors) ? (row_start + row_count) : n_vectors;
1278
+ nk_size_t const result_stride_elements = result_stride_in_bytes / sizeof(nk_f32_t);
1279
+ nk_size_t const row_end = (row_start + row_count < vectors_count) ? (row_start + row_count) : vectors_count;
1227
1280
 
1228
1281
  for (nk_size_t i = row_start; i < row_end; ++i) {
1229
- nk_u16_t const *a_i = (nk_u16_t const *)((char const *)vectors + i * stride);
1230
- for (nk_size_t j = i; j < n_vectors; ++j) {
1231
- nk_u16_t const *a_j = (nk_u16_t const *)((char const *)vectors + j * stride);
1232
- nk_size_t vlmax = __riscv_vsetvlmax_e32m2();
1233
- vfloat64m4_t accumulator_f64m4 = __riscv_vfmv_v_f_f64m4(0.0, vlmax);
1282
+ nk_u16_t const *a_i = (nk_u16_t const *)((char const *)vectors + i * stride_in_bytes);
1283
+ for (nk_size_t j = i; j < vectors_count; ++j) {
1284
+ nk_u16_t const *a_j = (nk_u16_t const *)((char const *)vectors + j * stride_in_bytes);
1285
+ nk_size_t max_vector_length = __riscv_vsetvlmax_e32m2();
1286
+ vfloat64m4_t accumulator_f64m4 = __riscv_vfmv_v_f_f64m4(0.0, max_vector_length);
1234
1287
  nk_size_t remaining = depth;
1235
1288
  nk_size_t k = 0;
1236
1289
  for (nk_size_t vector_length = 0; remaining > 0; remaining -= vector_length, k += vector_length) {
@@ -1244,15 +1297,15 @@ NK_PUBLIC void nk_dots_symmetric_bf16_rvv(nk_bf16_t const *vectors, nk_size_t n_
1244
1297
  }
1245
1298
  vfloat64m1_t zero_f64m1 = __riscv_vfmv_v_f_f64m1(0.0, 1);
1246
1299
  nk_f32_t dot = (nk_f32_t)__riscv_vfmv_f_s_f64m1_f64(
1247
- __riscv_vfredusum_vs_f64m4_f64m1(accumulator_f64m4, zero_f64m1, vlmax));
1300
+ __riscv_vfredusum_vs_f64m4_f64m1(accumulator_f64m4, zero_f64m1, max_vector_length));
1248
1301
  result[i * result_stride_elements + j] = dot;
1249
1302
  }
1250
1303
  }
1251
1304
  }
1252
1305
 
1253
- #pragma endregion // Brain Float 16
1306
+ #pragma endregion BF16 Floats
1254
1307
 
1255
- #pragma region Half Precision Floats
1308
+ #pragma region F16 Floats
1256
1309
 
1257
1310
  /**
1258
1311
  * @brief Compute the packed buffer size for f16 GEMM (B stored as f32).
@@ -1261,11 +1314,11 @@ NK_PUBLIC void nk_dots_symmetric_bf16_rvv(nk_bf16_t const *vectors, nk_size_t n_
1261
1314
  * Layout: column-panel with depth-contiguous f32 values, cache-line padding.
1262
1315
  */
1263
1316
  NK_PUBLIC nk_size_t nk_dots_packed_size_f16_rvv(nk_size_t column_count, nk_size_t depth) {
1264
- nk_size_t vector_length = __riscv_vsetvlmax_e32m2();
1265
- nk_size_t depth_padded = nk_size_round_up_to_multiple_(depth, vector_length);
1317
+ nk_size_t max_vector_length = __riscv_vsetvlmax_e32m2();
1318
+ nk_size_t depth_padded = nk_size_round_up_to_multiple_(depth, max_vector_length);
1266
1319
  // Break power-of-2 strides for cache associativity
1267
1320
  nk_size_t stride_bytes = depth_padded * sizeof(nk_f32_t);
1268
- if (stride_bytes > 0 && (stride_bytes & (stride_bytes - 1)) == 0) depth_padded += vector_length;
1321
+ if (stride_bytes > 0 && (stride_bytes & (stride_bytes - 1)) == 0) depth_padded += max_vector_length;
1269
1322
  return sizeof(nk_cross_packed_buffer_header_t) + column_count * depth_padded * sizeof(nk_f32_t) +
1270
1323
  column_count * sizeof(nk_f32_t); // per-column norms
1271
1324
  }
@@ -1278,10 +1331,10 @@ NK_PUBLIC nk_size_t nk_dots_packed_size_f16_rvv(nk_size_t column_count, nk_size_
1278
1331
  */
1279
1332
  NK_PUBLIC void nk_dots_pack_f16_rvv(nk_f16_t const *b, nk_size_t column_count, nk_size_t depth,
1280
1333
  nk_size_t b_stride_in_bytes, void *b_packed) {
1281
- nk_size_t vector_length = __riscv_vsetvlmax_e32m2();
1282
- nk_size_t depth_padded = nk_size_round_up_to_multiple_(depth, vector_length);
1334
+ nk_size_t max_vector_length = __riscv_vsetvlmax_e32m2();
1335
+ nk_size_t depth_padded = nk_size_round_up_to_multiple_(depth, max_vector_length);
1283
1336
  nk_size_t stride_bytes = depth_padded * sizeof(nk_f32_t);
1284
- if (stride_bytes > 0 && (stride_bytes & (stride_bytes - 1)) == 0) depth_padded += vector_length;
1337
+ if (stride_bytes > 0 && (stride_bytes & (stride_bytes - 1)) == 0) depth_padded += max_vector_length;
1285
1338
 
1286
1339
  nk_cross_packed_buffer_header_t *header = (nk_cross_packed_buffer_header_t *)b_packed;
1287
1340
  header->column_count = (nk_u32_t)column_count;
@@ -1290,7 +1343,15 @@ NK_PUBLIC void nk_dots_pack_f16_rvv(nk_f16_t const *b, nk_size_t column_count, n
1290
1343
 
1291
1344
  nk_f32_t *packed = (nk_f32_t *)((char *)b_packed + sizeof(nk_cross_packed_buffer_header_t));
1292
1345
  nk_size_t total = column_count * depth_padded;
1293
- for (nk_size_t i = 0; i < total; ++i) packed[i] = 0;
1346
+ {
1347
+ nk_u8_t *zero_ptr = (nk_u8_t *)packed;
1348
+ nk_size_t total_bytes = total * sizeof(nk_f32_t);
1349
+ for (nk_size_t i = 0; i < total_bytes;) {
1350
+ nk_size_t vector_length = __riscv_vsetvl_e8m8(total_bytes - i);
1351
+ __riscv_vse8_v_u8m8(zero_ptr + i, __riscv_vmv_v_x_u8m8(0, vector_length), vector_length);
1352
+ i += vector_length;
1353
+ }
1354
+ }
1294
1355
 
1295
1356
  for (nk_size_t column = 0; column < column_count; ++column) {
1296
1357
  nk_f16_t const *src = (nk_f16_t const *)((char const *)b + column * b_stride_in_bytes);
@@ -1346,11 +1407,11 @@ NK_INTERNAL void nk_dots_packed_f16_rvv_aligned_(nk_f16_t const *a_matrix, void
1346
1407
 
1347
1408
  for (nk_size_t column = 0; column < column_count; ++column) {
1348
1409
  nk_f32_t const *b_column = packed_data + column * depth_padded;
1349
- nk_size_t vlmax = __riscv_vsetvlmax_e32m2();
1350
- vfloat64m4_t accumulator_0_f64m4 = __riscv_vfmv_v_f_f64m4(0.0, vlmax);
1351
- vfloat64m4_t accumulator_1_f64m4 = __riscv_vfmv_v_f_f64m4(0.0, vlmax);
1352
- vfloat64m4_t accumulator_2_f64m4 = __riscv_vfmv_v_f_f64m4(0.0, vlmax);
1353
- vfloat64m4_t accumulator_3_f64m4 = __riscv_vfmv_v_f_f64m4(0.0, vlmax);
1410
+ nk_size_t max_vector_length = __riscv_vsetvlmax_e32m2();
1411
+ vfloat64m4_t accumulator_0_f64m4 = __riscv_vfmv_v_f_f64m4(0.0, max_vector_length);
1412
+ vfloat64m4_t accumulator_1_f64m4 = __riscv_vfmv_v_f_f64m4(0.0, max_vector_length);
1413
+ vfloat64m4_t accumulator_2_f64m4 = __riscv_vfmv_v_f_f64m4(0.0, max_vector_length);
1414
+ vfloat64m4_t accumulator_3_f64m4 = __riscv_vfmv_v_f_f64m4(0.0, max_vector_length);
1354
1415
 
1355
1416
  nk_size_t remaining = depth;
1356
1417
  nk_size_t k = 0;
@@ -1379,13 +1440,13 @@ NK_INTERNAL void nk_dots_packed_f16_rvv_aligned_(nk_f16_t const *a_matrix, void
1379
1440
  // Horizontal reduce and narrow to f32
1380
1441
  vfloat64m1_t zero_f64m1 = __riscv_vfmv_v_f_f64m1(0.0, 1);
1381
1442
  c_row_0[column] = (nk_f32_t)__riscv_vfmv_f_s_f64m1_f64(
1382
- __riscv_vfredusum_vs_f64m4_f64m1(accumulator_0_f64m4, zero_f64m1, vlmax));
1443
+ __riscv_vfredusum_vs_f64m4_f64m1(accumulator_0_f64m4, zero_f64m1, max_vector_length));
1383
1444
  c_row_1[column] = (nk_f32_t)__riscv_vfmv_f_s_f64m1_f64(
1384
- __riscv_vfredusum_vs_f64m4_f64m1(accumulator_1_f64m4, zero_f64m1, vlmax));
1445
+ __riscv_vfredusum_vs_f64m4_f64m1(accumulator_1_f64m4, zero_f64m1, max_vector_length));
1385
1446
  c_row_2[column] = (nk_f32_t)__riscv_vfmv_f_s_f64m1_f64(
1386
- __riscv_vfredusum_vs_f64m4_f64m1(accumulator_2_f64m4, zero_f64m1, vlmax));
1447
+ __riscv_vfredusum_vs_f64m4_f64m1(accumulator_2_f64m4, zero_f64m1, max_vector_length));
1387
1448
  c_row_3[column] = (nk_f32_t)__riscv_vfmv_f_s_f64m1_f64(
1388
- __riscv_vfredusum_vs_f64m4_f64m1(accumulator_3_f64m4, zero_f64m1, vlmax));
1449
+ __riscv_vfredusum_vs_f64m4_f64m1(accumulator_3_f64m4, zero_f64m1, max_vector_length));
1389
1450
  }
1390
1451
  }
1391
1452
  // Remainder rows (mr < 4)
@@ -1394,8 +1455,8 @@ NK_INTERNAL void nk_dots_packed_f16_rvv_aligned_(nk_f16_t const *a_matrix, void
1394
1455
  nk_f32_t *c_row = (nk_f32_t *)((char *)c_matrix + row * c_stride_in_bytes);
1395
1456
  for (nk_size_t column = 0; column < column_count; ++column) {
1396
1457
  nk_f32_t const *b_column = packed_data + column * depth_padded;
1397
- nk_size_t vlmax = __riscv_vsetvlmax_e32m2();
1398
- vfloat64m4_t accumulator_f64m4 = __riscv_vfmv_v_f_f64m4(0.0, vlmax);
1458
+ nk_size_t max_vector_length = __riscv_vsetvlmax_e32m2();
1459
+ vfloat64m4_t accumulator_f64m4 = __riscv_vfmv_v_f_f64m4(0.0, max_vector_length);
1399
1460
  nk_size_t remaining = depth;
1400
1461
  nk_size_t k = 0;
1401
1462
  for (nk_size_t vector_length = 0; remaining > 0; remaining -= vector_length, k += vector_length) {
@@ -1408,7 +1469,7 @@ NK_INTERNAL void nk_dots_packed_f16_rvv_aligned_(nk_f16_t const *a_matrix, void
1408
1469
  }
1409
1470
  vfloat64m1_t zero_f64m1 = __riscv_vfmv_v_f_f64m1(0.0, 1);
1410
1471
  c_row[column] = (nk_f32_t)__riscv_vfmv_f_s_f64m1_f64(
1411
- __riscv_vfredusum_vs_f64m4_f64m1(accumulator_f64m4, zero_f64m1, vlmax));
1472
+ __riscv_vfredusum_vs_f64m4_f64m1(accumulator_f64m4, zero_f64m1, max_vector_length));
1412
1473
  }
1413
1474
  }
1414
1475
  }
@@ -1419,9 +1480,10 @@ NK_INTERNAL void nk_dots_packed_f16_rvv_aligned_(nk_f16_t const *a_matrix, void
1419
1480
  * Dispatches to the aligned kernel for all cases — RVV's `vsetvl` handles partial
1420
1481
  * vectors naturally, so no separate edge kernel is needed.
1421
1482
  */
1422
- NK_PUBLIC void nk_dots_packed_f16_rvv(nk_f16_t const *a, void const *b_packed, nk_f32_t *c, nk_size_t m, nk_size_t n,
1423
- nk_size_t k, nk_size_t a_stride, nk_size_t c_stride) {
1424
- nk_dots_packed_f16_rvv_aligned_(a, b_packed, c, m, n, k, a_stride, c_stride);
1483
+ NK_PUBLIC void nk_dots_packed_f16_rvv(nk_f16_t const *a, void const *b_packed, nk_f32_t *c, nk_size_t rows,
1484
+ nk_size_t columns, nk_size_t depth, nk_size_t a_stride_in_bytes,
1485
+ nk_size_t c_stride_in_bytes) {
1486
+ nk_dots_packed_f16_rvv_aligned_(a, b_packed, c, rows, columns, depth, a_stride_in_bytes, c_stride_in_bytes);
1425
1487
  }
1426
1488
 
1427
1489
  /**
@@ -1432,18 +1494,18 @@ NK_PUBLIC void nk_dots_packed_f16_rvv(nk_f16_t const *a, void const *b_packed, n
1432
1494
  * Stride is in bytes.
1433
1495
  * Processes only the rows in [row_start, row_start + row_count) for parallelism.
1434
1496
  */
1435
- NK_PUBLIC void nk_dots_symmetric_f16_rvv(nk_f16_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
1436
- nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
1497
+ NK_PUBLIC void nk_dots_symmetric_f16_rvv(nk_f16_t const *vectors, nk_size_t vectors_count, nk_size_t depth,
1498
+ nk_size_t stride_in_bytes, nk_f32_t *result, nk_size_t result_stride_in_bytes,
1437
1499
  nk_size_t row_start, nk_size_t row_count) {
1438
- nk_size_t const result_stride_elements = result_stride / sizeof(nk_f32_t);
1439
- nk_size_t const row_end = (row_start + row_count < n_vectors) ? (row_start + row_count) : n_vectors;
1500
+ nk_size_t const result_stride_elements = result_stride_in_bytes / sizeof(nk_f32_t);
1501
+ nk_size_t const row_end = (row_start + row_count < vectors_count) ? (row_start + row_count) : vectors_count;
1440
1502
 
1441
1503
  for (nk_size_t i = row_start; i < row_end; ++i) {
1442
- nk_u16_t const *a_i = (nk_u16_t const *)((char const *)vectors + i * stride);
1443
- for (nk_size_t j = i; j < n_vectors; ++j) {
1444
- nk_u16_t const *a_j = (nk_u16_t const *)((char const *)vectors + j * stride);
1445
- nk_size_t vlmax = __riscv_vsetvlmax_e32m2();
1446
- vfloat64m4_t accumulator_f64m4 = __riscv_vfmv_v_f_f64m4(0.0, vlmax);
1504
+ nk_u16_t const *a_i = (nk_u16_t const *)((char const *)vectors + i * stride_in_bytes);
1505
+ for (nk_size_t j = i; j < vectors_count; ++j) {
1506
+ nk_u16_t const *a_j = (nk_u16_t const *)((char const *)vectors + j * stride_in_bytes);
1507
+ nk_size_t max_vector_length = __riscv_vsetvlmax_e32m2();
1508
+ vfloat64m4_t accumulator_f64m4 = __riscv_vfmv_v_f_f64m4(0.0, max_vector_length);
1447
1509
  nk_size_t remaining = depth;
1448
1510
  nk_size_t k = 0;
1449
1511
  for (nk_size_t vector_length = 0; remaining > 0; remaining -= vector_length, k += vector_length) {
@@ -1457,15 +1519,15 @@ NK_PUBLIC void nk_dots_symmetric_f16_rvv(nk_f16_t const *vectors, nk_size_t n_ve
1457
1519
  }
1458
1520
  vfloat64m1_t zero_f64m1 = __riscv_vfmv_v_f_f64m1(0.0, 1);
1459
1521
  nk_f32_t dot = (nk_f32_t)__riscv_vfmv_f_s_f64m1_f64(
1460
- __riscv_vfredusum_vs_f64m4_f64m1(accumulator_f64m4, zero_f64m1, vlmax));
1522
+ __riscv_vfredusum_vs_f64m4_f64m1(accumulator_f64m4, zero_f64m1, max_vector_length));
1461
1523
  result[i * result_stride_elements + j] = dot;
1462
1524
  }
1463
1525
  }
1464
1526
  }
1465
1527
 
1466
- #pragma endregion // Half Precision Floats
1528
+ #pragma endregion F16 Floats
1467
1529
 
1468
- #pragma region Signed 8-bit Integers
1530
+ #pragma region I8 Integers
1469
1531
 
1470
1532
  /**
1471
1533
  * @brief Compute the packed buffer size for i8 GEMM (B stored as i8).
@@ -1474,11 +1536,11 @@ NK_PUBLIC void nk_dots_symmetric_f16_rvv(nk_f16_t const *vectors, nk_size_t n_ve
1474
1536
  * Layout: column-panel with depth-contiguous i8 values, cache-line padding.
1475
1537
  */
1476
1538
  NK_PUBLIC nk_size_t nk_dots_packed_size_i8_rvv(nk_size_t column_count, nk_size_t depth) {
1477
- nk_size_t vector_length = __riscv_vsetvlmax_e8m1();
1478
- nk_size_t depth_padded = nk_size_round_up_to_multiple_(depth, vector_length);
1539
+ nk_size_t max_vector_length = __riscv_vsetvlmax_e8m1();
1540
+ nk_size_t depth_padded = nk_size_round_up_to_multiple_(depth, max_vector_length);
1479
1541
  // Break power-of-2 strides for cache associativity
1480
1542
  nk_size_t stride_bytes = depth_padded * sizeof(nk_i8_t);
1481
- if (stride_bytes > 0 && (stride_bytes & (stride_bytes - 1)) == 0) depth_padded += vector_length;
1543
+ if (stride_bytes > 0 && (stride_bytes & (stride_bytes - 1)) == 0) depth_padded += max_vector_length;
1482
1544
  return sizeof(nk_cross_packed_buffer_header_t) + column_count * depth_padded * sizeof(nk_i8_t) +
1483
1545
  column_count * sizeof(nk_u32_t); // per-column norms
1484
1546
  }
@@ -1491,10 +1553,10 @@ NK_PUBLIC nk_size_t nk_dots_packed_size_i8_rvv(nk_size_t column_count, nk_size_t
1491
1553
  */
1492
1554
  NK_PUBLIC void nk_dots_pack_i8_rvv(nk_i8_t const *b, nk_size_t column_count, nk_size_t depth,
1493
1555
  nk_size_t b_stride_in_bytes, void *b_packed) {
1494
- nk_size_t vector_length = __riscv_vsetvlmax_e8m1();
1495
- nk_size_t depth_padded = nk_size_round_up_to_multiple_(depth, vector_length);
1556
+ nk_size_t max_vector_length = __riscv_vsetvlmax_e8m1();
1557
+ nk_size_t depth_padded = nk_size_round_up_to_multiple_(depth, max_vector_length);
1496
1558
  nk_size_t stride_bytes = depth_padded * sizeof(nk_i8_t);
1497
- if (stride_bytes > 0 && (stride_bytes & (stride_bytes - 1)) == 0) depth_padded += vector_length;
1559
+ if (stride_bytes > 0 && (stride_bytes & (stride_bytes - 1)) == 0) depth_padded += max_vector_length;
1498
1560
 
1499
1561
  nk_cross_packed_buffer_header_t *header = (nk_cross_packed_buffer_header_t *)b_packed;
1500
1562
  header->column_count = (nk_u32_t)column_count;
@@ -1503,12 +1565,25 @@ NK_PUBLIC void nk_dots_pack_i8_rvv(nk_i8_t const *b, nk_size_t column_count, nk_
1503
1565
 
1504
1566
  nk_i8_t *packed = (nk_i8_t *)((char *)b_packed + sizeof(nk_cross_packed_buffer_header_t));
1505
1567
  nk_size_t total = column_count * depth_padded;
1506
- for (nk_size_t i = 0; i < total; ++i) packed[i] = 0;
1568
+ {
1569
+ nk_u8_t *zero_ptr = (nk_u8_t *)packed;
1570
+ nk_size_t total_bytes = total * sizeof(nk_i8_t);
1571
+ for (nk_size_t i = 0; i < total_bytes;) {
1572
+ nk_size_t vector_length = __riscv_vsetvl_e8m8(total_bytes - i);
1573
+ __riscv_vse8_v_u8m8(zero_ptr + i, __riscv_vmv_v_x_u8m8(0, vector_length), vector_length);
1574
+ i += vector_length;
1575
+ }
1576
+ }
1507
1577
 
1508
1578
  for (nk_size_t column = 0; column < column_count; ++column) {
1509
1579
  nk_i8_t const *src = (nk_i8_t const *)((char const *)b + column * b_stride_in_bytes);
1510
1580
  nk_i8_t *dst = packed + column * depth_padded;
1511
- for (nk_size_t k = 0; k < depth; ++k) dst[k] = src[k];
1581
+ for (nk_size_t k = 0; k < depth;) {
1582
+ nk_size_t vector_length = __riscv_vsetvl_e8m8(depth - k);
1583
+ __riscv_vse8_v_u8m8((nk_u8_t *)(dst + k), __riscv_vle8_v_u8m8((nk_u8_t const *)(src + k), vector_length),
1584
+ vector_length);
1585
+ k += vector_length;
1586
+ }
1512
1587
  }
1513
1588
 
1514
1589
  // Append per-column norms after packed data
@@ -1524,7 +1599,7 @@ NK_PUBLIC void nk_dots_pack_i8_rvv(nk_i8_t const *b, nk_size_t column_count, nk_
1524
1599
  *
1525
1600
  * Vectorizes over the depth dimension (k). For each (row, column) pair:
1526
1601
  * - Load i8 values from A and pre-packed i8 values from B
1527
- * - Widening multiply: i8 x i8 -> i16 via `vwmul`
1602
+ * - Widening multiply: i8 × i8 i16 via `vwmul`
1528
1603
  * - Widen-accumulate: i32 += i16 via `vwadd_wv`
1529
1604
  * - Horizontal reduce via `vredsum`
1530
1605
  *
@@ -1560,11 +1635,11 @@ NK_INTERNAL void nk_dots_packed_i8_rvv_aligned_(nk_i8_t const *a_matrix, void co
1560
1635
 
1561
1636
  for (nk_size_t column = 0; column < column_count; ++column) {
1562
1637
  nk_i8_t const *b_column = packed_data + column * depth_padded;
1563
- nk_size_t vlmax = __riscv_vsetvlmax_e32m4();
1564
- vint32m4_t accumulator_0_i32m4 = __riscv_vmv_v_x_i32m4(0, vlmax);
1565
- vint32m4_t accumulator_1_i32m4 = __riscv_vmv_v_x_i32m4(0, vlmax);
1566
- vint32m4_t accumulator_2_i32m4 = __riscv_vmv_v_x_i32m4(0, vlmax);
1567
- vint32m4_t accumulator_3_i32m4 = __riscv_vmv_v_x_i32m4(0, vlmax);
1638
+ nk_size_t max_vector_length = __riscv_vsetvlmax_e32m4();
1639
+ vint32m4_t accumulator_0_i32m4 = __riscv_vmv_v_x_i32m4(0, max_vector_length);
1640
+ vint32m4_t accumulator_1_i32m4 = __riscv_vmv_v_x_i32m4(0, max_vector_length);
1641
+ vint32m4_t accumulator_2_i32m4 = __riscv_vmv_v_x_i32m4(0, max_vector_length);
1642
+ vint32m4_t accumulator_3_i32m4 = __riscv_vmv_v_x_i32m4(0, max_vector_length);
1568
1643
 
1569
1644
  nk_size_t remaining = depth;
1570
1645
  nk_size_t k = 0;
@@ -1592,13 +1667,13 @@ NK_INTERNAL void nk_dots_packed_i8_rvv_aligned_(nk_i8_t const *a_matrix, void co
1592
1667
  // Horizontal reduce
1593
1668
  vint32m1_t zero_i32m1 = __riscv_vmv_v_x_i32m1(0, 1);
1594
1669
  c_row_0[column] = (nk_i32_t)__riscv_vmv_x_s_i32m1_i32(
1595
- __riscv_vredsum_vs_i32m4_i32m1(accumulator_0_i32m4, zero_i32m1, vlmax));
1670
+ __riscv_vredsum_vs_i32m4_i32m1(accumulator_0_i32m4, zero_i32m1, max_vector_length));
1596
1671
  c_row_1[column] = (nk_i32_t)__riscv_vmv_x_s_i32m1_i32(
1597
- __riscv_vredsum_vs_i32m4_i32m1(accumulator_1_i32m4, zero_i32m1, vlmax));
1672
+ __riscv_vredsum_vs_i32m4_i32m1(accumulator_1_i32m4, zero_i32m1, max_vector_length));
1598
1673
  c_row_2[column] = (nk_i32_t)__riscv_vmv_x_s_i32m1_i32(
1599
- __riscv_vredsum_vs_i32m4_i32m1(accumulator_2_i32m4, zero_i32m1, vlmax));
1674
+ __riscv_vredsum_vs_i32m4_i32m1(accumulator_2_i32m4, zero_i32m1, max_vector_length));
1600
1675
  c_row_3[column] = (nk_i32_t)__riscv_vmv_x_s_i32m1_i32(
1601
- __riscv_vredsum_vs_i32m4_i32m1(accumulator_3_i32m4, zero_i32m1, vlmax));
1676
+ __riscv_vredsum_vs_i32m4_i32m1(accumulator_3_i32m4, zero_i32m1, max_vector_length));
1602
1677
  }
1603
1678
  }
1604
1679
  // Remainder rows (mr < 4)
@@ -1607,8 +1682,8 @@ NK_INTERNAL void nk_dots_packed_i8_rvv_aligned_(nk_i8_t const *a_matrix, void co
1607
1682
  nk_i32_t *c_row = (nk_i32_t *)((char *)c_matrix + row * c_stride_in_bytes);
1608
1683
  for (nk_size_t column = 0; column < column_count; ++column) {
1609
1684
  nk_i8_t const *b_column = packed_data + column * depth_padded;
1610
- nk_size_t vlmax = __riscv_vsetvlmax_e32m4();
1611
- vint32m4_t accumulator_i32m4 = __riscv_vmv_v_x_i32m4(0, vlmax);
1685
+ nk_size_t max_vector_length = __riscv_vsetvlmax_e32m4();
1686
+ vint32m4_t accumulator_i32m4 = __riscv_vmv_v_x_i32m4(0, max_vector_length);
1612
1687
  nk_size_t remaining = depth;
1613
1688
  nk_size_t k = 0;
1614
1689
  for (nk_size_t vector_length = 0; remaining > 0; remaining -= vector_length, k += vector_length) {
@@ -1621,7 +1696,7 @@ NK_INTERNAL void nk_dots_packed_i8_rvv_aligned_(nk_i8_t const *a_matrix, void co
1621
1696
  }
1622
1697
  vint32m1_t zero_i32m1 = __riscv_vmv_v_x_i32m1(0, 1);
1623
1698
  c_row[column] = (nk_i32_t)__riscv_vmv_x_s_i32m1_i32(
1624
- __riscv_vredsum_vs_i32m4_i32m1(accumulator_i32m4, zero_i32m1, vlmax));
1699
+ __riscv_vredsum_vs_i32m4_i32m1(accumulator_i32m4, zero_i32m1, max_vector_length));
1625
1700
  }
1626
1701
  }
1627
1702
  }
@@ -1632,31 +1707,32 @@ NK_INTERNAL void nk_dots_packed_i8_rvv_aligned_(nk_i8_t const *a_matrix, void co
1632
1707
  * Dispatches to the aligned kernel for all cases — RVV's `vsetvl` handles partial
1633
1708
  * vectors naturally, so no separate edge kernel is needed.
1634
1709
  */
1635
- NK_PUBLIC void nk_dots_packed_i8_rvv(nk_i8_t const *a, void const *b_packed, nk_i32_t *c, nk_size_t m, nk_size_t n,
1636
- nk_size_t k, nk_size_t a_stride, nk_size_t c_stride) {
1637
- nk_dots_packed_i8_rvv_aligned_(a, b_packed, c, m, n, k, a_stride, c_stride);
1710
+ NK_PUBLIC void nk_dots_packed_i8_rvv(nk_i8_t const *a, void const *b_packed, nk_i32_t *c, nk_size_t rows,
1711
+ nk_size_t columns, nk_size_t depth, nk_size_t a_stride_in_bytes,
1712
+ nk_size_t c_stride_in_bytes) {
1713
+ nk_dots_packed_i8_rvv_aligned_(a, b_packed, c, rows, columns, depth, a_stride_in_bytes, c_stride_in_bytes);
1638
1714
  }
1639
1715
 
1640
1716
  /**
1641
1717
  * @brief Symmetric i8 GEMM: C = A * A^T, upper triangle + mirror.
1642
1718
  *
1643
1719
  * Uses integer i8 arithmetic with i32 accumulation.
1644
- * Both inputs are i8, widened via i8 x i8 -> i16 -> i32 accumulation.
1720
+ * Both inputs are i8, widened via i8 × i8 i16 i32 accumulation.
1645
1721
  * Stride is in bytes.
1646
1722
  * Processes only the rows in [row_start, row_start + row_count) for parallelism.
1647
1723
  */
1648
- NK_PUBLIC void nk_dots_symmetric_i8_rvv(nk_i8_t const *vectors, nk_size_t n_vectors, nk_size_t depth, nk_size_t stride,
1649
- nk_i32_t *result, nk_size_t result_stride, nk_size_t row_start,
1650
- nk_size_t row_count) {
1651
- nk_size_t const result_stride_elements = result_stride / sizeof(nk_i32_t);
1652
- nk_size_t const row_end = (row_start + row_count < n_vectors) ? (row_start + row_count) : n_vectors;
1724
+ NK_PUBLIC void nk_dots_symmetric_i8_rvv(nk_i8_t const *vectors, nk_size_t vectors_count, nk_size_t depth,
1725
+ nk_size_t stride_in_bytes, nk_i32_t *result, nk_size_t result_stride_in_bytes,
1726
+ nk_size_t row_start, nk_size_t row_count) {
1727
+ nk_size_t const result_stride_elements = result_stride_in_bytes / sizeof(nk_i32_t);
1728
+ nk_size_t const row_end = (row_start + row_count < vectors_count) ? (row_start + row_count) : vectors_count;
1653
1729
 
1654
1730
  for (nk_size_t i = row_start; i < row_end; ++i) {
1655
- nk_i8_t const *a_i = (nk_i8_t const *)((char const *)vectors + i * stride);
1656
- for (nk_size_t j = i; j < n_vectors; ++j) {
1657
- nk_i8_t const *a_j = (nk_i8_t const *)((char const *)vectors + j * stride);
1658
- nk_size_t vlmax = __riscv_vsetvlmax_e32m4();
1659
- vint32m4_t accumulator_i32m4 = __riscv_vmv_v_x_i32m4(0, vlmax);
1731
+ nk_i8_t const *a_i = (nk_i8_t const *)((char const *)vectors + i * stride_in_bytes);
1732
+ for (nk_size_t j = i; j < vectors_count; ++j) {
1733
+ nk_i8_t const *a_j = (nk_i8_t const *)((char const *)vectors + j * stride_in_bytes);
1734
+ nk_size_t max_vector_length = __riscv_vsetvlmax_e32m4();
1735
+ vint32m4_t accumulator_i32m4 = __riscv_vmv_v_x_i32m4(0, max_vector_length);
1660
1736
  nk_size_t remaining = depth;
1661
1737
  nk_size_t k = 0;
1662
1738
  for (nk_size_t vector_length = 0; remaining > 0; remaining -= vector_length, k += vector_length) {
@@ -1669,15 +1745,15 @@ NK_PUBLIC void nk_dots_symmetric_i8_rvv(nk_i8_t const *vectors, nk_size_t n_vect
1669
1745
  }
1670
1746
  vint32m1_t zero_i32m1 = __riscv_vmv_v_x_i32m1(0, 1);
1671
1747
  nk_i32_t dot = (nk_i32_t)__riscv_vmv_x_s_i32m1_i32(
1672
- __riscv_vredsum_vs_i32m4_i32m1(accumulator_i32m4, zero_i32m1, vlmax));
1748
+ __riscv_vredsum_vs_i32m4_i32m1(accumulator_i32m4, zero_i32m1, max_vector_length));
1673
1749
  result[i * result_stride_elements + j] = dot;
1674
1750
  }
1675
1751
  }
1676
1752
  }
1677
1753
 
1678
- #pragma endregion // Signed 8-bit Integers
1754
+ #pragma endregion I8 Integers
1679
1755
 
1680
- #pragma region Unsigned 8-bit Integers
1756
+ #pragma region U8 Integers
1681
1757
 
1682
1758
  /**
1683
1759
  * @brief Compute the packed buffer size for u8 GEMM (B stored as u8).
@@ -1686,11 +1762,11 @@ NK_PUBLIC void nk_dots_symmetric_i8_rvv(nk_i8_t const *vectors, nk_size_t n_vect
1686
1762
  * Layout: column-panel with depth-contiguous u8 values, cache-line padding.
1687
1763
  */
1688
1764
  NK_PUBLIC nk_size_t nk_dots_packed_size_u8_rvv(nk_size_t column_count, nk_size_t depth) {
1689
- nk_size_t vector_length = __riscv_vsetvlmax_e8m1();
1690
- nk_size_t depth_padded = nk_size_round_up_to_multiple_(depth, vector_length);
1765
+ nk_size_t max_vector_length = __riscv_vsetvlmax_e8m1();
1766
+ nk_size_t depth_padded = nk_size_round_up_to_multiple_(depth, max_vector_length);
1691
1767
  // Break power-of-2 strides for cache associativity
1692
1768
  nk_size_t stride_bytes = depth_padded * sizeof(nk_u8_t);
1693
- if (stride_bytes > 0 && (stride_bytes & (stride_bytes - 1)) == 0) depth_padded += vector_length;
1769
+ if (stride_bytes > 0 && (stride_bytes & (stride_bytes - 1)) == 0) depth_padded += max_vector_length;
1694
1770
  return sizeof(nk_cross_packed_buffer_header_t) + column_count * depth_padded * sizeof(nk_u8_t) +
1695
1771
  column_count * sizeof(nk_u32_t); // per-column norms
1696
1772
  }
@@ -1703,10 +1779,10 @@ NK_PUBLIC nk_size_t nk_dots_packed_size_u8_rvv(nk_size_t column_count, nk_size_t
1703
1779
  */
1704
1780
  NK_PUBLIC void nk_dots_pack_u8_rvv(nk_u8_t const *b, nk_size_t column_count, nk_size_t depth,
1705
1781
  nk_size_t b_stride_in_bytes, void *b_packed) {
1706
- nk_size_t vector_length = __riscv_vsetvlmax_e8m1();
1707
- nk_size_t depth_padded = nk_size_round_up_to_multiple_(depth, vector_length);
1782
+ nk_size_t max_vector_length = __riscv_vsetvlmax_e8m1();
1783
+ nk_size_t depth_padded = nk_size_round_up_to_multiple_(depth, max_vector_length);
1708
1784
  nk_size_t stride_bytes = depth_padded * sizeof(nk_u8_t);
1709
- if (stride_bytes > 0 && (stride_bytes & (stride_bytes - 1)) == 0) depth_padded += vector_length;
1785
+ if (stride_bytes > 0 && (stride_bytes & (stride_bytes - 1)) == 0) depth_padded += max_vector_length;
1710
1786
 
1711
1787
  nk_cross_packed_buffer_header_t *header = (nk_cross_packed_buffer_header_t *)b_packed;
1712
1788
  header->column_count = (nk_u32_t)column_count;
@@ -1715,12 +1791,24 @@ NK_PUBLIC void nk_dots_pack_u8_rvv(nk_u8_t const *b, nk_size_t column_count, nk_
1715
1791
 
1716
1792
  nk_u8_t *packed = (nk_u8_t *)((char *)b_packed + sizeof(nk_cross_packed_buffer_header_t));
1717
1793
  nk_size_t total = column_count * depth_padded;
1718
- for (nk_size_t i = 0; i < total; ++i) packed[i] = 0;
1794
+ {
1795
+ nk_u8_t *zero_ptr = (nk_u8_t *)packed;
1796
+ nk_size_t total_bytes = total * sizeof(nk_u8_t);
1797
+ for (nk_size_t i = 0; i < total_bytes;) {
1798
+ nk_size_t vector_length = __riscv_vsetvl_e8m8(total_bytes - i);
1799
+ __riscv_vse8_v_u8m8(zero_ptr + i, __riscv_vmv_v_x_u8m8(0, vector_length), vector_length);
1800
+ i += vector_length;
1801
+ }
1802
+ }
1719
1803
 
1720
1804
  for (nk_size_t column = 0; column < column_count; ++column) {
1721
1805
  nk_u8_t const *src = (nk_u8_t const *)((char const *)b + column * b_stride_in_bytes);
1722
1806
  nk_u8_t *dst = packed + column * depth_padded;
1723
- for (nk_size_t k = 0; k < depth; ++k) dst[k] = src[k];
1807
+ for (nk_size_t k = 0; k < depth;) {
1808
+ nk_size_t vector_length = __riscv_vsetvl_e8m8(depth - k);
1809
+ __riscv_vse8_v_u8m8(dst + k, __riscv_vle8_v_u8m8(src + k, vector_length), vector_length);
1810
+ k += vector_length;
1811
+ }
1724
1812
  }
1725
1813
 
1726
1814
  // Append per-column norms after packed data
@@ -1736,7 +1824,7 @@ NK_PUBLIC void nk_dots_pack_u8_rvv(nk_u8_t const *b, nk_size_t column_count, nk_
1736
1824
  *
1737
1825
  * Vectorizes over the depth dimension (k). For each (row, column) pair:
1738
1826
  * - Load u8 values from A and pre-packed u8 values from B
1739
- * - Widening multiply: u8 x u8 -> u16 via `vwmulu`
1827
+ * - Widening multiply: u8 × u8 u16 via `vwmulu`
1740
1828
  * - Widen-accumulate: u32 += u16 via `vwaddu_wv`
1741
1829
  * - Horizontal reduce via `vredsum`
1742
1830
  *
@@ -1772,11 +1860,11 @@ NK_INTERNAL void nk_dots_packed_u8_rvv_aligned_(nk_u8_t const *a_matrix, void co
1772
1860
 
1773
1861
  for (nk_size_t column = 0; column < column_count; ++column) {
1774
1862
  nk_u8_t const *b_column = packed_data + column * depth_padded;
1775
- nk_size_t vlmax = __riscv_vsetvlmax_e32m4();
1776
- vuint32m4_t accumulator_0_u32m4 = __riscv_vmv_v_x_u32m4(0, vlmax);
1777
- vuint32m4_t accumulator_1_u32m4 = __riscv_vmv_v_x_u32m4(0, vlmax);
1778
- vuint32m4_t accumulator_2_u32m4 = __riscv_vmv_v_x_u32m4(0, vlmax);
1779
- vuint32m4_t accumulator_3_u32m4 = __riscv_vmv_v_x_u32m4(0, vlmax);
1863
+ nk_size_t max_vector_length = __riscv_vsetvlmax_e32m4();
1864
+ vuint32m4_t accumulator_0_u32m4 = __riscv_vmv_v_x_u32m4(0, max_vector_length);
1865
+ vuint32m4_t accumulator_1_u32m4 = __riscv_vmv_v_x_u32m4(0, max_vector_length);
1866
+ vuint32m4_t accumulator_2_u32m4 = __riscv_vmv_v_x_u32m4(0, max_vector_length);
1867
+ vuint32m4_t accumulator_3_u32m4 = __riscv_vmv_v_x_u32m4(0, max_vector_length);
1780
1868
 
1781
1869
  nk_size_t remaining = depth;
1782
1870
  nk_size_t k = 0;
@@ -1804,13 +1892,13 @@ NK_INTERNAL void nk_dots_packed_u8_rvv_aligned_(nk_u8_t const *a_matrix, void co
1804
1892
  // Horizontal reduce
1805
1893
  vuint32m1_t zero_u32m1 = __riscv_vmv_v_x_u32m1(0, 1);
1806
1894
  c_row_0[column] = (nk_u32_t)__riscv_vmv_x_s_u32m1_u32(
1807
- __riscv_vredsum_vs_u32m4_u32m1(accumulator_0_u32m4, zero_u32m1, vlmax));
1895
+ __riscv_vredsum_vs_u32m4_u32m1(accumulator_0_u32m4, zero_u32m1, max_vector_length));
1808
1896
  c_row_1[column] = (nk_u32_t)__riscv_vmv_x_s_u32m1_u32(
1809
- __riscv_vredsum_vs_u32m4_u32m1(accumulator_1_u32m4, zero_u32m1, vlmax));
1897
+ __riscv_vredsum_vs_u32m4_u32m1(accumulator_1_u32m4, zero_u32m1, max_vector_length));
1810
1898
  c_row_2[column] = (nk_u32_t)__riscv_vmv_x_s_u32m1_u32(
1811
- __riscv_vredsum_vs_u32m4_u32m1(accumulator_2_u32m4, zero_u32m1, vlmax));
1899
+ __riscv_vredsum_vs_u32m4_u32m1(accumulator_2_u32m4, zero_u32m1, max_vector_length));
1812
1900
  c_row_3[column] = (nk_u32_t)__riscv_vmv_x_s_u32m1_u32(
1813
- __riscv_vredsum_vs_u32m4_u32m1(accumulator_3_u32m4, zero_u32m1, vlmax));
1901
+ __riscv_vredsum_vs_u32m4_u32m1(accumulator_3_u32m4, zero_u32m1, max_vector_length));
1814
1902
  }
1815
1903
  }
1816
1904
  // Remainder rows (mr < 4)
@@ -1819,8 +1907,8 @@ NK_INTERNAL void nk_dots_packed_u8_rvv_aligned_(nk_u8_t const *a_matrix, void co
1819
1907
  nk_u32_t *c_row = (nk_u32_t *)((char *)c_matrix + row * c_stride_in_bytes);
1820
1908
  for (nk_size_t column = 0; column < column_count; ++column) {
1821
1909
  nk_u8_t const *b_column = packed_data + column * depth_padded;
1822
- nk_size_t vlmax = __riscv_vsetvlmax_e32m4();
1823
- vuint32m4_t accumulator_u32m4 = __riscv_vmv_v_x_u32m4(0, vlmax);
1910
+ nk_size_t max_vector_length = __riscv_vsetvlmax_e32m4();
1911
+ vuint32m4_t accumulator_u32m4 = __riscv_vmv_v_x_u32m4(0, max_vector_length);
1824
1912
  nk_size_t remaining = depth;
1825
1913
  nk_size_t k = 0;
1826
1914
  for (nk_size_t vector_length = 0; remaining > 0; remaining -= vector_length, k += vector_length) {
@@ -1833,7 +1921,7 @@ NK_INTERNAL void nk_dots_packed_u8_rvv_aligned_(nk_u8_t const *a_matrix, void co
1833
1921
  }
1834
1922
  vuint32m1_t zero_u32m1 = __riscv_vmv_v_x_u32m1(0, 1);
1835
1923
  c_row[column] = (nk_u32_t)__riscv_vmv_x_s_u32m1_u32(
1836
- __riscv_vredsum_vs_u32m4_u32m1(accumulator_u32m4, zero_u32m1, vlmax));
1924
+ __riscv_vredsum_vs_u32m4_u32m1(accumulator_u32m4, zero_u32m1, max_vector_length));
1837
1925
  }
1838
1926
  }
1839
1927
  }
@@ -1844,31 +1932,32 @@ NK_INTERNAL void nk_dots_packed_u8_rvv_aligned_(nk_u8_t const *a_matrix, void co
1844
1932
  * Dispatches to the aligned kernel for all cases — RVV's `vsetvl` handles partial
1845
1933
  * vectors naturally, so no separate edge kernel is needed.
1846
1934
  */
1847
- NK_PUBLIC void nk_dots_packed_u8_rvv(nk_u8_t const *a, void const *b_packed, nk_u32_t *c, nk_size_t m, nk_size_t n,
1848
- nk_size_t k, nk_size_t a_stride, nk_size_t c_stride) {
1849
- nk_dots_packed_u8_rvv_aligned_(a, b_packed, c, m, n, k, a_stride, c_stride);
1935
+ NK_PUBLIC void nk_dots_packed_u8_rvv(nk_u8_t const *a, void const *b_packed, nk_u32_t *c, nk_size_t rows,
1936
+ nk_size_t columns, nk_size_t depth, nk_size_t a_stride_in_bytes,
1937
+ nk_size_t c_stride_in_bytes) {
1938
+ nk_dots_packed_u8_rvv_aligned_(a, b_packed, c, rows, columns, depth, a_stride_in_bytes, c_stride_in_bytes);
1850
1939
  }
1851
1940
 
1852
1941
  /**
1853
1942
  * @brief Symmetric u8 GEMM: C = A * A^T, upper triangle + mirror.
1854
1943
  *
1855
1944
  * Uses unsigned integer u8 arithmetic with u32 accumulation.
1856
- * Both inputs are u8, widened via u8 x u8 -> u16 -> u32 accumulation.
1945
+ * Both inputs are u8, widened via u8 × u8 u16 u32 accumulation.
1857
1946
  * Stride is in bytes.
1858
1947
  * Processes only the rows in [row_start, row_start + row_count) for parallelism.
1859
1948
  */
1860
- NK_PUBLIC void nk_dots_symmetric_u8_rvv(nk_u8_t const *vectors, nk_size_t n_vectors, nk_size_t depth, nk_size_t stride,
1861
- nk_u32_t *result, nk_size_t result_stride, nk_size_t row_start,
1862
- nk_size_t row_count) {
1863
- nk_size_t const result_stride_elements = result_stride / sizeof(nk_u32_t);
1864
- nk_size_t const row_end = (row_start + row_count < n_vectors) ? (row_start + row_count) : n_vectors;
1949
+ NK_PUBLIC void nk_dots_symmetric_u8_rvv(nk_u8_t const *vectors, nk_size_t vectors_count, nk_size_t depth,
1950
+ nk_size_t stride_in_bytes, nk_u32_t *result, nk_size_t result_stride_in_bytes,
1951
+ nk_size_t row_start, nk_size_t row_count) {
1952
+ nk_size_t const result_stride_elements = result_stride_in_bytes / sizeof(nk_u32_t);
1953
+ nk_size_t const row_end = (row_start + row_count < vectors_count) ? (row_start + row_count) : vectors_count;
1865
1954
 
1866
1955
  for (nk_size_t i = row_start; i < row_end; ++i) {
1867
- nk_u8_t const *a_i = (nk_u8_t const *)((char const *)vectors + i * stride);
1868
- for (nk_size_t j = i; j < n_vectors; ++j) {
1869
- nk_u8_t const *a_j = (nk_u8_t const *)((char const *)vectors + j * stride);
1870
- nk_size_t vlmax = __riscv_vsetvlmax_e32m4();
1871
- vuint32m4_t accumulator_u32m4 = __riscv_vmv_v_x_u32m4(0, vlmax);
1956
+ nk_u8_t const *a_i = (nk_u8_t const *)((char const *)vectors + i * stride_in_bytes);
1957
+ for (nk_size_t j = i; j < vectors_count; ++j) {
1958
+ nk_u8_t const *a_j = (nk_u8_t const *)((char const *)vectors + j * stride_in_bytes);
1959
+ nk_size_t max_vector_length = __riscv_vsetvlmax_e32m4();
1960
+ vuint32m4_t accumulator_u32m4 = __riscv_vmv_v_x_u32m4(0, max_vector_length);
1872
1961
  nk_size_t remaining = depth;
1873
1962
  nk_size_t k = 0;
1874
1963
  for (nk_size_t vector_length = 0; remaining > 0; remaining -= vector_length, k += vector_length) {
@@ -1881,18 +1970,18 @@ NK_PUBLIC void nk_dots_symmetric_u8_rvv(nk_u8_t const *vectors, nk_size_t n_vect
1881
1970
  }
1882
1971
  vuint32m1_t zero_u32m1 = __riscv_vmv_v_x_u32m1(0, 1);
1883
1972
  nk_u32_t dot = (nk_u32_t)__riscv_vmv_x_s_u32m1_u32(
1884
- __riscv_vredsum_vs_u32m4_u32m1(accumulator_u32m4, zero_u32m1, vlmax));
1973
+ __riscv_vredsum_vs_u32m4_u32m1(accumulator_u32m4, zero_u32m1, max_vector_length));
1885
1974
  result[i * result_stride_elements + j] = dot;
1886
1975
  }
1887
1976
  }
1888
1977
  }
1889
1978
 
1890
- #pragma endregion // Unsigned 8-bit Integers
1979
+ #pragma endregion U8 Integers
1891
1980
 
1892
- #pragma region Quarter Precision E4M3
1981
+ #pragma region E4M3 Floats
1893
1982
 
1894
1983
  /**
1895
- * @brief E4M3 magnitude LUT: 7-bit magnitude -> f32 bit pattern (u32).
1984
+ * @brief E4M3 magnitude LUT: 7-bit magnitude f32 bit pattern (u32).
1896
1985
  * nk_e4m3_magnitude_lut_rvv_[i] = float_to_bits(e4m3_to_f32(i)) for i=0..127.
1897
1986
  * E4M3FN: 4 exponent bits (bias=7), 3 mantissa bits, no infinity,
1898
1987
  * NaN = magnitude 0x7F only.
@@ -1933,10 +2022,10 @@ static nk_u32_t const nk_e4m3_magnitude_lut_rvv_[128] = {
1933
2022
  };
1934
2023
 
1935
2024
  NK_PUBLIC nk_size_t nk_dots_packed_size_e4m3_rvv(nk_size_t column_count, nk_size_t depth) {
1936
- nk_size_t vector_length = __riscv_vsetvlmax_e32m2();
1937
- nk_size_t depth_padded = nk_size_round_up_to_multiple_(depth, vector_length);
2025
+ nk_size_t max_vector_length = __riscv_vsetvlmax_e32m2();
2026
+ nk_size_t depth_padded = nk_size_round_up_to_multiple_(depth, max_vector_length);
1938
2027
  nk_size_t stride_bytes = depth_padded * sizeof(nk_f32_t);
1939
- if (stride_bytes > 0 && (stride_bytes & (stride_bytes - 1)) == 0) depth_padded += vector_length;
2028
+ if (stride_bytes > 0 && (stride_bytes & (stride_bytes - 1)) == 0) depth_padded += max_vector_length;
1940
2029
  return sizeof(nk_cross_packed_buffer_header_t) + column_count * depth_padded * sizeof(nk_f32_t) +
1941
2030
  column_count * sizeof(nk_f32_t); // per-column norms
1942
2031
  }
@@ -1949,10 +2038,10 @@ NK_PUBLIC nk_size_t nk_dots_packed_size_e4m3_rvv(nk_size_t column_count, nk_size
1949
2038
  */
1950
2039
  NK_PUBLIC void nk_dots_pack_e4m3_rvv(nk_e4m3_t const *b, nk_size_t column_count, nk_size_t depth,
1951
2040
  nk_size_t b_stride_in_bytes, void *b_packed) {
1952
- nk_size_t vector_length = __riscv_vsetvlmax_e32m2();
1953
- nk_size_t depth_padded = nk_size_round_up_to_multiple_(depth, vector_length);
2041
+ nk_size_t max_vector_length = __riscv_vsetvlmax_e32m2();
2042
+ nk_size_t depth_padded = nk_size_round_up_to_multiple_(depth, max_vector_length);
1954
2043
  nk_size_t stride_bytes = depth_padded * sizeof(nk_f32_t);
1955
- if (stride_bytes > 0 && (stride_bytes & (stride_bytes - 1)) == 0) depth_padded += vector_length;
2044
+ if (stride_bytes > 0 && (stride_bytes & (stride_bytes - 1)) == 0) depth_padded += max_vector_length;
1956
2045
 
1957
2046
  nk_cross_packed_buffer_header_t *header = (nk_cross_packed_buffer_header_t *)b_packed;
1958
2047
  header->column_count = (nk_u32_t)column_count;
@@ -1961,7 +2050,15 @@ NK_PUBLIC void nk_dots_pack_e4m3_rvv(nk_e4m3_t const *b, nk_size_t column_count,
1961
2050
 
1962
2051
  nk_f32_t *packed = (nk_f32_t *)((char *)b_packed + sizeof(nk_cross_packed_buffer_header_t));
1963
2052
  nk_size_t total = column_count * depth_padded;
1964
- for (nk_size_t i = 0; i < total; ++i) packed[i] = 0;
2053
+ {
2054
+ nk_u8_t *zero_ptr = (nk_u8_t *)packed;
2055
+ nk_size_t total_bytes = total * sizeof(nk_f32_t);
2056
+ for (nk_size_t i = 0; i < total_bytes;) {
2057
+ nk_size_t vector_length = __riscv_vsetvl_e8m8(total_bytes - i);
2058
+ __riscv_vse8_v_u8m8(zero_ptr + i, __riscv_vmv_v_x_u8m8(0, vector_length), vector_length);
2059
+ i += vector_length;
2060
+ }
2061
+ }
1965
2062
 
1966
2063
  for (nk_size_t column = 0; column < column_count; ++column) {
1967
2064
  nk_e4m3_t const *src = (nk_e4m3_t const *)((char const *)b + column * b_stride_in_bytes);
@@ -1985,7 +2082,7 @@ NK_PUBLIC void nk_dots_pack_e4m3_rvv(nk_e4m3_t const *b, nk_size_t column_count,
1985
2082
  * - Load raw e4m3 bytes from A, convert on-the-fly via 128-entry f32 LUT gather:
1986
2083
  * extract 7-bit magnitude, zero-extend to u32, compute byte offsets (x4),
1987
2084
  * gather f32 bit patterns, inject sign bit from bit 7 (<<24), reinterpret as f32
1988
- * - Widening FMA: f32xf32 -> f64 via `vfwmacc_vv_f64m4`
2085
+ * - Widening FMA: f32xf32 f64 via `vfwmacc_vv_f64m4`
1989
2086
  *
1990
2087
  * Register tile: process 2 rows per iteration (rows_per_tile=2, u32m2 gather + f64m4 accumulator is register-heavy).
1991
2088
  */
@@ -2014,9 +2111,9 @@ NK_INTERNAL void nk_dots_packed_e4m3_rvv_aligned_(nk_e4m3_t const *a_matrix, voi
2014
2111
 
2015
2112
  for (nk_size_t column = 0; column < column_count; ++column) {
2016
2113
  nk_f32_t const *b_column = packed_data + column * depth_padded;
2017
- nk_size_t vlmax = __riscv_vsetvlmax_e32m2();
2018
- vfloat64m4_t accumulator_0_f64m4 = __riscv_vfmv_v_f_f64m4(0.0, vlmax);
2019
- vfloat64m4_t accumulator_1_f64m4 = __riscv_vfmv_v_f_f64m4(0.0, vlmax);
2114
+ nk_size_t max_vector_length = __riscv_vsetvlmax_e32m2();
2115
+ vfloat64m4_t accumulator_0_f64m4 = __riscv_vfmv_v_f_f64m4(0.0, max_vector_length);
2116
+ vfloat64m4_t accumulator_1_f64m4 = __riscv_vfmv_v_f_f64m4(0.0, max_vector_length);
2020
2117
 
2021
2118
  nk_size_t remaining = depth;
2022
2119
  nk_size_t k = 0;
@@ -2059,7 +2156,7 @@ NK_INTERNAL void nk_dots_packed_e4m3_rvv_aligned_(nk_e4m3_t const *a_matrix, voi
2059
2156
  vfloat32m2_t a_vector_1_f32m2 = __riscv_vreinterpret_v_u32m2_f32m2(
2060
2157
  __riscv_vor_vv_u32m2(bits1_u32m2, sign1_u32m2, vector_length));
2061
2158
 
2062
- // Widening FMA: f32xf32 -> f64
2159
+ // Widening FMA: f32xf32 f64
2063
2160
  accumulator_0_f64m4 = __riscv_vfwmacc_vv_f64m4_tu(accumulator_0_f64m4, a_vector_0_f32m2, b_vector_f32m2,
2064
2161
  vector_length);
2065
2162
  accumulator_1_f64m4 = __riscv_vfwmacc_vv_f64m4_tu(accumulator_1_f64m4, a_vector_1_f32m2, b_vector_f32m2,
@@ -2069,9 +2166,9 @@ NK_INTERNAL void nk_dots_packed_e4m3_rvv_aligned_(nk_e4m3_t const *a_matrix, voi
2069
2166
  // Horizontal reduce and narrow to f32
2070
2167
  vfloat64m1_t zero_f64m1 = __riscv_vfmv_v_f_f64m1(0.0, 1);
2071
2168
  c_row_0[column] = (nk_f32_t)__riscv_vfmv_f_s_f64m1_f64(
2072
- __riscv_vfredusum_vs_f64m4_f64m1(accumulator_0_f64m4, zero_f64m1, vlmax));
2169
+ __riscv_vfredusum_vs_f64m4_f64m1(accumulator_0_f64m4, zero_f64m1, max_vector_length));
2073
2170
  c_row_1[column] = (nk_f32_t)__riscv_vfmv_f_s_f64m1_f64(
2074
- __riscv_vfredusum_vs_f64m4_f64m1(accumulator_1_f64m4, zero_f64m1, vlmax));
2171
+ __riscv_vfredusum_vs_f64m4_f64m1(accumulator_1_f64m4, zero_f64m1, max_vector_length));
2075
2172
  }
2076
2173
  }
2077
2174
  // Remainder rows
@@ -2080,8 +2177,8 @@ NK_INTERNAL void nk_dots_packed_e4m3_rvv_aligned_(nk_e4m3_t const *a_matrix, voi
2080
2177
  nk_f32_t *c_row = (nk_f32_t *)((char *)c_matrix + row * c_stride_in_bytes);
2081
2178
  for (nk_size_t column = 0; column < column_count; ++column) {
2082
2179
  nk_f32_t const *b_column = packed_data + column * depth_padded;
2083
- nk_size_t vlmax = __riscv_vsetvlmax_e32m2();
2084
- vfloat64m4_t accumulator_f64m4 = __riscv_vfmv_v_f_f64m4(0.0, vlmax);
2180
+ nk_size_t max_vector_length = __riscv_vsetvlmax_e32m2();
2181
+ vfloat64m4_t accumulator_f64m4 = __riscv_vfmv_v_f_f64m4(0.0, max_vector_length);
2085
2182
  nk_size_t remaining = depth;
2086
2183
  nk_size_t k = 0;
2087
2184
  for (nk_size_t vector_length = 0; remaining > 0; remaining -= vector_length, k += vector_length) {
@@ -2103,7 +2200,7 @@ NK_INTERNAL void nk_dots_packed_e4m3_rvv_aligned_(nk_e4m3_t const *a_matrix, voi
2103
2200
  }
2104
2201
  vfloat64m1_t zero_f64m1 = __riscv_vfmv_v_f_f64m1(0.0, 1);
2105
2202
  c_row[column] = (nk_f32_t)__riscv_vfmv_f_s_f64m1_f64(
2106
- __riscv_vfredusum_vs_f64m4_f64m1(accumulator_f64m4, zero_f64m1, vlmax));
2203
+ __riscv_vfredusum_vs_f64m4_f64m1(accumulator_f64m4, zero_f64m1, max_vector_length));
2107
2204
  }
2108
2205
  }
2109
2206
  }
@@ -2111,9 +2208,10 @@ NK_INTERNAL void nk_dots_packed_e4m3_rvv_aligned_(nk_e4m3_t const *a_matrix, voi
2111
2208
  /**
2112
2209
  * @brief Public e4m3 packed GEMM wrapper matching the declared signature in dots.h.
2113
2210
  */
2114
- NK_PUBLIC void nk_dots_packed_e4m3_rvv(nk_e4m3_t const *a, void const *b_packed, nk_f32_t *c, nk_size_t m, nk_size_t n,
2115
- nk_size_t k, nk_size_t a_stride, nk_size_t c_stride) {
2116
- nk_dots_packed_e4m3_rvv_aligned_(a, b_packed, c, m, n, k, a_stride, c_stride);
2211
+ NK_PUBLIC void nk_dots_packed_e4m3_rvv(nk_e4m3_t const *a, void const *b_packed, nk_f32_t *c, nk_size_t rows,
2212
+ nk_size_t columns, nk_size_t depth, nk_size_t a_stride_in_bytes,
2213
+ nk_size_t c_stride_in_bytes) {
2214
+ nk_dots_packed_e4m3_rvv_aligned_(a, b_packed, c, rows, columns, depth, a_stride_in_bytes, c_stride_in_bytes);
2117
2215
  }
2118
2216
 
2119
2217
  /**
@@ -2123,18 +2221,18 @@ NK_PUBLIC void nk_dots_packed_e4m3_rvv(nk_e4m3_t const *a, void const *b_packed,
2123
2221
  * Both operands are converted from e4m3 on-the-fly via magnitude LUT.
2124
2222
  * Processes only the rows in [row_start, row_start + row_count) for parallelism.
2125
2223
  */
2126
- NK_PUBLIC void nk_dots_symmetric_e4m3_rvv(nk_e4m3_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
2127
- nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
2224
+ NK_PUBLIC void nk_dots_symmetric_e4m3_rvv(nk_e4m3_t const *vectors, nk_size_t vectors_count, nk_size_t depth,
2225
+ nk_size_t stride_in_bytes, nk_f32_t *result, nk_size_t result_stride_in_bytes,
2128
2226
  nk_size_t row_start, nk_size_t row_count) {
2129
- nk_size_t const result_stride_elements = result_stride / sizeof(nk_f32_t);
2130
- nk_size_t const row_end = (row_start + row_count < n_vectors) ? (row_start + row_count) : n_vectors;
2227
+ nk_size_t const result_stride_elements = result_stride_in_bytes / sizeof(nk_f32_t);
2228
+ nk_size_t const row_end = (row_start + row_count < vectors_count) ? (row_start + row_count) : vectors_count;
2131
2229
 
2132
2230
  for (nk_size_t i = row_start; i < row_end; ++i) {
2133
- nk_u8_t const *a_i = (nk_u8_t const *)vectors + i * stride;
2134
- for (nk_size_t j = i; j < n_vectors; ++j) {
2135
- nk_u8_t const *a_j = (nk_u8_t const *)vectors + j * stride;
2136
- nk_size_t vlmax = __riscv_vsetvlmax_e32m2();
2137
- vfloat64m4_t accumulator_f64m4 = __riscv_vfmv_v_f_f64m4(0.0, vlmax);
2231
+ nk_u8_t const *a_i = (nk_u8_t const *)vectors + i * stride_in_bytes;
2232
+ for (nk_size_t j = i; j < vectors_count; ++j) {
2233
+ nk_u8_t const *a_j = (nk_u8_t const *)vectors + j * stride_in_bytes;
2234
+ nk_size_t max_vector_length = __riscv_vsetvlmax_e32m2();
2235
+ vfloat64m4_t accumulator_f64m4 = __riscv_vfmv_v_f_f64m4(0.0, max_vector_length);
2138
2236
  nk_size_t remaining = depth;
2139
2237
  nk_size_t k = 0;
2140
2238
  for (nk_size_t vector_length = 0; remaining > 0; remaining -= vector_length, k += vector_length) {
@@ -2166,24 +2264,24 @@ NK_PUBLIC void nk_dots_symmetric_e4m3_rvv(nk_e4m3_t const *vectors, nk_size_t n_
2166
2264
  vfloat32m2_t val_j_f32m2 = __riscv_vreinterpret_v_u32m2_f32m2(
2167
2265
  __riscv_vor_vv_u32m2(bits_j_u32m2, sign_j_u32m2, vector_length));
2168
2266
 
2169
- // Widening FMA: f32xf32 -> f64
2267
+ // Widening FMA: f32xf32 f64
2170
2268
  accumulator_f64m4 = __riscv_vfwmacc_vv_f64m4_tu(accumulator_f64m4, val_i_f32m2, val_j_f32m2,
2171
2269
  vector_length);
2172
2270
  }
2173
2271
  vfloat64m1_t zero_f64m1 = __riscv_vfmv_v_f_f64m1(0.0, 1);
2174
2272
  nk_f32_t dot = (nk_f32_t)__riscv_vfmv_f_s_f64m1_f64(
2175
- __riscv_vfredusum_vs_f64m4_f64m1(accumulator_f64m4, zero_f64m1, vlmax));
2273
+ __riscv_vfredusum_vs_f64m4_f64m1(accumulator_f64m4, zero_f64m1, max_vector_length));
2176
2274
  result[i * result_stride_elements + j] = dot;
2177
2275
  }
2178
2276
  }
2179
2277
  }
2180
2278
 
2181
- #pragma endregion // Quarter Precision E4M3
2279
+ #pragma endregion E4M3 Floats
2182
2280
 
2183
- #pragma region Quarter Precision E5M2
2281
+ #pragma region E5M2 Floats
2184
2282
 
2185
2283
  /**
2186
- * @brief E5M2 magnitude LUT: 7-bit magnitude -> f32 bit pattern (u32).
2284
+ * @brief E5M2 magnitude LUT: 7-bit magnitude f32 bit pattern (u32).
2187
2285
  * nk_e5m2_magnitude_lut_rvv_[i] = float_to_bits(e5m2_to_f32(i)) for i=0..127.
2188
2286
  * E5M2: 5 exponent bits (bias=15), 2 mantissa bits, has infinity (0x7C) and
2189
2287
  * NaN (magnitudes 0x7D..0x7F).
@@ -2224,10 +2322,10 @@ static nk_u32_t const nk_e5m2_magnitude_lut_rvv_[128] = {
2224
2322
  };
2225
2323
 
2226
2324
  NK_PUBLIC nk_size_t nk_dots_packed_size_e5m2_rvv(nk_size_t column_count, nk_size_t depth) {
2227
- nk_size_t vector_length = __riscv_vsetvlmax_e32m2();
2228
- nk_size_t depth_padded = nk_size_round_up_to_multiple_(depth, vector_length);
2325
+ nk_size_t max_vector_length = __riscv_vsetvlmax_e32m2();
2326
+ nk_size_t depth_padded = nk_size_round_up_to_multiple_(depth, max_vector_length);
2229
2327
  nk_size_t stride_bytes = depth_padded * sizeof(nk_f32_t);
2230
- if (stride_bytes > 0 && (stride_bytes & (stride_bytes - 1)) == 0) depth_padded += vector_length;
2328
+ if (stride_bytes > 0 && (stride_bytes & (stride_bytes - 1)) == 0) depth_padded += max_vector_length;
2231
2329
  return sizeof(nk_cross_packed_buffer_header_t) + column_count * depth_padded * sizeof(nk_f32_t) +
2232
2330
  column_count * sizeof(nk_f32_t); // per-column norms
2233
2331
  }
@@ -2240,10 +2338,10 @@ NK_PUBLIC nk_size_t nk_dots_packed_size_e5m2_rvv(nk_size_t column_count, nk_size
2240
2338
  */
2241
2339
  NK_PUBLIC void nk_dots_pack_e5m2_rvv(nk_e5m2_t const *b, nk_size_t column_count, nk_size_t depth,
2242
2340
  nk_size_t b_stride_in_bytes, void *b_packed) {
2243
- nk_size_t vector_length = __riscv_vsetvlmax_e32m2();
2244
- nk_size_t depth_padded = nk_size_round_up_to_multiple_(depth, vector_length);
2341
+ nk_size_t max_vector_length = __riscv_vsetvlmax_e32m2();
2342
+ nk_size_t depth_padded = nk_size_round_up_to_multiple_(depth, max_vector_length);
2245
2343
  nk_size_t stride_bytes = depth_padded * sizeof(nk_f32_t);
2246
- if (stride_bytes > 0 && (stride_bytes & (stride_bytes - 1)) == 0) depth_padded += vector_length;
2344
+ if (stride_bytes > 0 && (stride_bytes & (stride_bytes - 1)) == 0) depth_padded += max_vector_length;
2247
2345
 
2248
2346
  nk_cross_packed_buffer_header_t *header = (nk_cross_packed_buffer_header_t *)b_packed;
2249
2347
  header->column_count = (nk_u32_t)column_count;
@@ -2252,7 +2350,15 @@ NK_PUBLIC void nk_dots_pack_e5m2_rvv(nk_e5m2_t const *b, nk_size_t column_count,
2252
2350
 
2253
2351
  nk_f32_t *packed = (nk_f32_t *)((char *)b_packed + sizeof(nk_cross_packed_buffer_header_t));
2254
2352
  nk_size_t total = column_count * depth_padded;
2255
- for (nk_size_t i = 0; i < total; ++i) packed[i] = 0;
2353
+ {
2354
+ nk_u8_t *zero_ptr = (nk_u8_t *)packed;
2355
+ nk_size_t total_bytes = total * sizeof(nk_f32_t);
2356
+ for (nk_size_t i = 0; i < total_bytes;) {
2357
+ nk_size_t vector_length = __riscv_vsetvl_e8m8(total_bytes - i);
2358
+ __riscv_vse8_v_u8m8(zero_ptr + i, __riscv_vmv_v_x_u8m8(0, vector_length), vector_length);
2359
+ i += vector_length;
2360
+ }
2361
+ }
2256
2362
 
2257
2363
  for (nk_size_t column = 0; column < column_count; ++column) {
2258
2364
  nk_e5m2_t const *src = (nk_e5m2_t const *)((char const *)b + column * b_stride_in_bytes);
@@ -2276,7 +2382,7 @@ NK_PUBLIC void nk_dots_pack_e5m2_rvv(nk_e5m2_t const *b, nk_size_t column_count,
2276
2382
  * - Load raw e5m2 bytes from A, convert on-the-fly via 128-entry f32 LUT gather:
2277
2383
  * extract 7-bit magnitude, zero-extend to u32, compute byte offsets (x4),
2278
2384
  * gather f32 bit patterns, inject sign bit from bit 7 (<<24), reinterpret as f32
2279
- * - Widening FMA: f32xf32 -> f64 via `vfwmacc_vv_f64m4`
2385
+ * - Widening FMA: f32xf32 f64 via `vfwmacc_vv_f64m4`
2280
2386
  *
2281
2387
  * Register tile: process 2 rows per iteration (rows_per_tile=2, u32m2 gather + f64m4 accumulator is register-heavy).
2282
2388
  */
@@ -2305,9 +2411,9 @@ NK_INTERNAL void nk_dots_packed_e5m2_rvv_aligned_(nk_e5m2_t const *a_matrix, voi
2305
2411
 
2306
2412
  for (nk_size_t column = 0; column < column_count; ++column) {
2307
2413
  nk_f32_t const *b_column = packed_data + column * depth_padded;
2308
- nk_size_t vlmax = __riscv_vsetvlmax_e32m2();
2309
- vfloat64m4_t accumulator_0_f64m4 = __riscv_vfmv_v_f_f64m4(0.0, vlmax);
2310
- vfloat64m4_t accumulator_1_f64m4 = __riscv_vfmv_v_f_f64m4(0.0, vlmax);
2414
+ nk_size_t max_vector_length = __riscv_vsetvlmax_e32m2();
2415
+ vfloat64m4_t accumulator_0_f64m4 = __riscv_vfmv_v_f_f64m4(0.0, max_vector_length);
2416
+ vfloat64m4_t accumulator_1_f64m4 = __riscv_vfmv_v_f_f64m4(0.0, max_vector_length);
2311
2417
 
2312
2418
  nk_size_t remaining = depth;
2313
2419
  nk_size_t k = 0;
@@ -2350,7 +2456,7 @@ NK_INTERNAL void nk_dots_packed_e5m2_rvv_aligned_(nk_e5m2_t const *a_matrix, voi
2350
2456
  vfloat32m2_t a_vector_1_f32m2 = __riscv_vreinterpret_v_u32m2_f32m2(
2351
2457
  __riscv_vor_vv_u32m2(bits1_u32m2, sign1_u32m2, vector_length));
2352
2458
 
2353
- // Widening FMA: f32xf32 -> f64
2459
+ // Widening FMA: f32xf32 f64
2354
2460
  accumulator_0_f64m4 = __riscv_vfwmacc_vv_f64m4_tu(accumulator_0_f64m4, a_vector_0_f32m2, b_vector_f32m2,
2355
2461
  vector_length);
2356
2462
  accumulator_1_f64m4 = __riscv_vfwmacc_vv_f64m4_tu(accumulator_1_f64m4, a_vector_1_f32m2, b_vector_f32m2,
@@ -2360,9 +2466,9 @@ NK_INTERNAL void nk_dots_packed_e5m2_rvv_aligned_(nk_e5m2_t const *a_matrix, voi
2360
2466
  // Horizontal reduce and narrow to f32
2361
2467
  vfloat64m1_t zero_f64m1 = __riscv_vfmv_v_f_f64m1(0.0, 1);
2362
2468
  c_row_0[column] = (nk_f32_t)__riscv_vfmv_f_s_f64m1_f64(
2363
- __riscv_vfredusum_vs_f64m4_f64m1(accumulator_0_f64m4, zero_f64m1, vlmax));
2469
+ __riscv_vfredusum_vs_f64m4_f64m1(accumulator_0_f64m4, zero_f64m1, max_vector_length));
2364
2470
  c_row_1[column] = (nk_f32_t)__riscv_vfmv_f_s_f64m1_f64(
2365
- __riscv_vfredusum_vs_f64m4_f64m1(accumulator_1_f64m4, zero_f64m1, vlmax));
2471
+ __riscv_vfredusum_vs_f64m4_f64m1(accumulator_1_f64m4, zero_f64m1, max_vector_length));
2366
2472
  }
2367
2473
  }
2368
2474
  // Remainder rows
@@ -2371,8 +2477,8 @@ NK_INTERNAL void nk_dots_packed_e5m2_rvv_aligned_(nk_e5m2_t const *a_matrix, voi
2371
2477
  nk_f32_t *c_row = (nk_f32_t *)((char *)c_matrix + row * c_stride_in_bytes);
2372
2478
  for (nk_size_t column = 0; column < column_count; ++column) {
2373
2479
  nk_f32_t const *b_column = packed_data + column * depth_padded;
2374
- nk_size_t vlmax = __riscv_vsetvlmax_e32m2();
2375
- vfloat64m4_t accumulator_f64m4 = __riscv_vfmv_v_f_f64m4(0.0, vlmax);
2480
+ nk_size_t max_vector_length = __riscv_vsetvlmax_e32m2();
2481
+ vfloat64m4_t accumulator_f64m4 = __riscv_vfmv_v_f_f64m4(0.0, max_vector_length);
2376
2482
  nk_size_t remaining = depth;
2377
2483
  nk_size_t k = 0;
2378
2484
  for (nk_size_t vector_length = 0; remaining > 0; remaining -= vector_length, k += vector_length) {
@@ -2394,7 +2500,7 @@ NK_INTERNAL void nk_dots_packed_e5m2_rvv_aligned_(nk_e5m2_t const *a_matrix, voi
2394
2500
  }
2395
2501
  vfloat64m1_t zero_f64m1 = __riscv_vfmv_v_f_f64m1(0.0, 1);
2396
2502
  c_row[column] = (nk_f32_t)__riscv_vfmv_f_s_f64m1_f64(
2397
- __riscv_vfredusum_vs_f64m4_f64m1(accumulator_f64m4, zero_f64m1, vlmax));
2503
+ __riscv_vfredusum_vs_f64m4_f64m1(accumulator_f64m4, zero_f64m1, max_vector_length));
2398
2504
  }
2399
2505
  }
2400
2506
  }
@@ -2402,9 +2508,10 @@ NK_INTERNAL void nk_dots_packed_e5m2_rvv_aligned_(nk_e5m2_t const *a_matrix, voi
2402
2508
  /**
2403
2509
  * @brief Public e5m2 packed GEMM wrapper matching the declared signature in dots.h.
2404
2510
  */
2405
- NK_PUBLIC void nk_dots_packed_e5m2_rvv(nk_e5m2_t const *a, void const *b_packed, nk_f32_t *c, nk_size_t m, nk_size_t n,
2406
- nk_size_t k, nk_size_t a_stride, nk_size_t c_stride) {
2407
- nk_dots_packed_e5m2_rvv_aligned_(a, b_packed, c, m, n, k, a_stride, c_stride);
2511
+ NK_PUBLIC void nk_dots_packed_e5m2_rvv(nk_e5m2_t const *a, void const *b_packed, nk_f32_t *c, nk_size_t rows,
2512
+ nk_size_t columns, nk_size_t depth, nk_size_t a_stride_in_bytes,
2513
+ nk_size_t c_stride_in_bytes) {
2514
+ nk_dots_packed_e5m2_rvv_aligned_(a, b_packed, c, rows, columns, depth, a_stride_in_bytes, c_stride_in_bytes);
2408
2515
  }
2409
2516
 
2410
2517
  /**
@@ -2414,18 +2521,18 @@ NK_PUBLIC void nk_dots_packed_e5m2_rvv(nk_e5m2_t const *a, void const *b_packed,
2414
2521
  * Both operands are converted from e5m2 on-the-fly via magnitude LUT.
2415
2522
  * Processes only the rows in [row_start, row_start + row_count) for parallelism.
2416
2523
  */
2417
- NK_PUBLIC void nk_dots_symmetric_e5m2_rvv(nk_e5m2_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
2418
- nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
2524
+ NK_PUBLIC void nk_dots_symmetric_e5m2_rvv(nk_e5m2_t const *vectors, nk_size_t vectors_count, nk_size_t depth,
2525
+ nk_size_t stride_in_bytes, nk_f32_t *result, nk_size_t result_stride_in_bytes,
2419
2526
  nk_size_t row_start, nk_size_t row_count) {
2420
- nk_size_t const result_stride_elements = result_stride / sizeof(nk_f32_t);
2421
- nk_size_t const row_end = (row_start + row_count < n_vectors) ? (row_start + row_count) : n_vectors;
2527
+ nk_size_t const result_stride_elements = result_stride_in_bytes / sizeof(nk_f32_t);
2528
+ nk_size_t const row_end = (row_start + row_count < vectors_count) ? (row_start + row_count) : vectors_count;
2422
2529
 
2423
2530
  for (nk_size_t i = row_start; i < row_end; ++i) {
2424
- nk_u8_t const *a_i = (nk_u8_t const *)vectors + i * stride;
2425
- for (nk_size_t j = i; j < n_vectors; ++j) {
2426
- nk_u8_t const *a_j = (nk_u8_t const *)vectors + j * stride;
2427
- nk_size_t vlmax = __riscv_vsetvlmax_e32m2();
2428
- vfloat64m4_t accumulator_f64m4 = __riscv_vfmv_v_f_f64m4(0.0, vlmax);
2531
+ nk_u8_t const *a_i = (nk_u8_t const *)vectors + i * stride_in_bytes;
2532
+ for (nk_size_t j = i; j < vectors_count; ++j) {
2533
+ nk_u8_t const *a_j = (nk_u8_t const *)vectors + j * stride_in_bytes;
2534
+ nk_size_t max_vector_length = __riscv_vsetvlmax_e32m2();
2535
+ vfloat64m4_t accumulator_f64m4 = __riscv_vfmv_v_f_f64m4(0.0, max_vector_length);
2429
2536
  nk_size_t remaining = depth;
2430
2537
  nk_size_t k = 0;
2431
2538
  for (nk_size_t vector_length = 0; remaining > 0; remaining -= vector_length, k += vector_length) {
@@ -2457,19 +2564,19 @@ NK_PUBLIC void nk_dots_symmetric_e5m2_rvv(nk_e5m2_t const *vectors, nk_size_t n_
2457
2564
  vfloat32m2_t val_j_f32m2 = __riscv_vreinterpret_v_u32m2_f32m2(
2458
2565
  __riscv_vor_vv_u32m2(bits_j_u32m2, sign_j_u32m2, vector_length));
2459
2566
 
2460
- // Widening FMA: f32xf32 -> f64
2567
+ // Widening FMA: f32xf32 f64
2461
2568
  accumulator_f64m4 = __riscv_vfwmacc_vv_f64m4_tu(accumulator_f64m4, val_i_f32m2, val_j_f32m2,
2462
2569
  vector_length);
2463
2570
  }
2464
2571
  vfloat64m1_t zero_f64m1 = __riscv_vfmv_v_f_f64m1(0.0, 1);
2465
2572
  nk_f32_t dot = (nk_f32_t)__riscv_vfmv_f_s_f64m1_f64(
2466
- __riscv_vfredusum_vs_f64m4_f64m1(accumulator_f64m4, zero_f64m1, vlmax));
2573
+ __riscv_vfredusum_vs_f64m4_f64m1(accumulator_f64m4, zero_f64m1, max_vector_length));
2467
2574
  result[i * result_stride_elements + j] = dot;
2468
2575
  }
2469
2576
  }
2470
2577
  }
2471
2578
 
2472
- #pragma endregion // Quarter Precision E5M2
2579
+ #pragma endregion E5M2 Floats
2473
2580
 
2474
2581
  #if defined(__cplusplus)
2475
2582
  } // extern "C"