numkong 7.0.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (294) hide show
  1. package/LICENSE +201 -0
  2. package/README.md +495 -0
  3. package/binding.gyp +540 -0
  4. package/c/dispatch.h +512 -0
  5. package/c/dispatch_bf16.c +389 -0
  6. package/c/dispatch_bf16c.c +52 -0
  7. package/c/dispatch_e2m3.c +263 -0
  8. package/c/dispatch_e3m2.c +243 -0
  9. package/c/dispatch_e4m3.c +276 -0
  10. package/c/dispatch_e5m2.c +272 -0
  11. package/c/dispatch_f16.c +376 -0
  12. package/c/dispatch_f16c.c +58 -0
  13. package/c/dispatch_f32.c +378 -0
  14. package/c/dispatch_f32c.c +99 -0
  15. package/c/dispatch_f64.c +296 -0
  16. package/c/dispatch_f64c.c +98 -0
  17. package/c/dispatch_i16.c +96 -0
  18. package/c/dispatch_i32.c +89 -0
  19. package/c/dispatch_i4.c +150 -0
  20. package/c/dispatch_i64.c +86 -0
  21. package/c/dispatch_i8.c +289 -0
  22. package/c/dispatch_other.c +330 -0
  23. package/c/dispatch_u1.c +148 -0
  24. package/c/dispatch_u16.c +124 -0
  25. package/c/dispatch_u32.c +118 -0
  26. package/c/dispatch_u4.c +150 -0
  27. package/c/dispatch_u64.c +102 -0
  28. package/c/dispatch_u8.c +303 -0
  29. package/c/numkong.c +950 -0
  30. package/include/README.md +573 -0
  31. package/include/module.modulemap +129 -0
  32. package/include/numkong/attention/sapphireamx.h +1361 -0
  33. package/include/numkong/attention/sme.h +2066 -0
  34. package/include/numkong/attention.h +49 -0
  35. package/include/numkong/capabilities.h +748 -0
  36. package/include/numkong/cast/README.md +262 -0
  37. package/include/numkong/cast/haswell.h +975 -0
  38. package/include/numkong/cast/icelake.h +470 -0
  39. package/include/numkong/cast/neon.h +1192 -0
  40. package/include/numkong/cast/rvv.h +1021 -0
  41. package/include/numkong/cast/sapphire.h +262 -0
  42. package/include/numkong/cast/serial.h +2262 -0
  43. package/include/numkong/cast/skylake.h +856 -0
  44. package/include/numkong/cast/v128relaxed.h +180 -0
  45. package/include/numkong/cast.h +230 -0
  46. package/include/numkong/curved/README.md +223 -0
  47. package/include/numkong/curved/genoa.h +182 -0
  48. package/include/numkong/curved/haswell.h +276 -0
  49. package/include/numkong/curved/neon.h +205 -0
  50. package/include/numkong/curved/neonbfdot.h +212 -0
  51. package/include/numkong/curved/neonhalf.h +212 -0
  52. package/include/numkong/curved/rvv.h +305 -0
  53. package/include/numkong/curved/serial.h +207 -0
  54. package/include/numkong/curved/skylake.h +457 -0
  55. package/include/numkong/curved/smef64.h +506 -0
  56. package/include/numkong/curved.h +517 -0
  57. package/include/numkong/curved.hpp +144 -0
  58. package/include/numkong/dot/README.md +425 -0
  59. package/include/numkong/dot/alder.h +563 -0
  60. package/include/numkong/dot/genoa.h +315 -0
  61. package/include/numkong/dot/haswell.h +1688 -0
  62. package/include/numkong/dot/icelake.h +883 -0
  63. package/include/numkong/dot/neon.h +818 -0
  64. package/include/numkong/dot/neonbfdot.h +244 -0
  65. package/include/numkong/dot/neonfhm.h +360 -0
  66. package/include/numkong/dot/neonhalf.h +198 -0
  67. package/include/numkong/dot/neonsdot.h +508 -0
  68. package/include/numkong/dot/rvv.h +714 -0
  69. package/include/numkong/dot/rvvbb.h +72 -0
  70. package/include/numkong/dot/rvvbf16.h +123 -0
  71. package/include/numkong/dot/rvvhalf.h +129 -0
  72. package/include/numkong/dot/sapphire.h +141 -0
  73. package/include/numkong/dot/serial.h +838 -0
  74. package/include/numkong/dot/sierra.h +405 -0
  75. package/include/numkong/dot/skylake.h +1084 -0
  76. package/include/numkong/dot/sve.h +379 -0
  77. package/include/numkong/dot/svebfdot.h +74 -0
  78. package/include/numkong/dot/svehalf.h +123 -0
  79. package/include/numkong/dot/v128relaxed.h +1258 -0
  80. package/include/numkong/dot.h +1070 -0
  81. package/include/numkong/dot.hpp +94 -0
  82. package/include/numkong/dots/README.md +496 -0
  83. package/include/numkong/dots/alder.h +114 -0
  84. package/include/numkong/dots/genoa.h +94 -0
  85. package/include/numkong/dots/haswell.h +295 -0
  86. package/include/numkong/dots/icelake.h +171 -0
  87. package/include/numkong/dots/neon.h +120 -0
  88. package/include/numkong/dots/neonbfdot.h +58 -0
  89. package/include/numkong/dots/neonfhm.h +94 -0
  90. package/include/numkong/dots/neonhalf.h +57 -0
  91. package/include/numkong/dots/neonsdot.h +108 -0
  92. package/include/numkong/dots/rvv.h +2486 -0
  93. package/include/numkong/dots/sapphireamx.h +3973 -0
  94. package/include/numkong/dots/serial.h +2844 -0
  95. package/include/numkong/dots/sierra.h +97 -0
  96. package/include/numkong/dots/skylake.h +196 -0
  97. package/include/numkong/dots/sme.h +5372 -0
  98. package/include/numkong/dots/smebi32.h +461 -0
  99. package/include/numkong/dots/smef64.h +1318 -0
  100. package/include/numkong/dots/smehalf.h +47 -0
  101. package/include/numkong/dots/v128relaxed.h +294 -0
  102. package/include/numkong/dots.h +2804 -0
  103. package/include/numkong/dots.hpp +639 -0
  104. package/include/numkong/each/README.md +469 -0
  105. package/include/numkong/each/haswell.h +1658 -0
  106. package/include/numkong/each/icelake.h +272 -0
  107. package/include/numkong/each/neon.h +1104 -0
  108. package/include/numkong/each/neonbfdot.h +212 -0
  109. package/include/numkong/each/neonhalf.h +410 -0
  110. package/include/numkong/each/rvv.h +1121 -0
  111. package/include/numkong/each/sapphire.h +477 -0
  112. package/include/numkong/each/serial.h +260 -0
  113. package/include/numkong/each/skylake.h +1562 -0
  114. package/include/numkong/each.h +2146 -0
  115. package/include/numkong/each.hpp +434 -0
  116. package/include/numkong/geospatial/README.md +147 -0
  117. package/include/numkong/geospatial/haswell.h +593 -0
  118. package/include/numkong/geospatial/neon.h +571 -0
  119. package/include/numkong/geospatial/rvv.h +701 -0
  120. package/include/numkong/geospatial/serial.h +309 -0
  121. package/include/numkong/geospatial/skylake.h +577 -0
  122. package/include/numkong/geospatial/v128relaxed.h +613 -0
  123. package/include/numkong/geospatial.h +453 -0
  124. package/include/numkong/geospatial.hpp +235 -0
  125. package/include/numkong/matrix.hpp +336 -0
  126. package/include/numkong/maxsim/README.md +187 -0
  127. package/include/numkong/maxsim/alder.h +511 -0
  128. package/include/numkong/maxsim/genoa.h +115 -0
  129. package/include/numkong/maxsim/haswell.h +553 -0
  130. package/include/numkong/maxsim/icelake.h +480 -0
  131. package/include/numkong/maxsim/neonsdot.h +394 -0
  132. package/include/numkong/maxsim/sapphireamx.h +877 -0
  133. package/include/numkong/maxsim/serial.h +490 -0
  134. package/include/numkong/maxsim/sme.h +929 -0
  135. package/include/numkong/maxsim/v128relaxed.h +280 -0
  136. package/include/numkong/maxsim.h +571 -0
  137. package/include/numkong/maxsim.hpp +133 -0
  138. package/include/numkong/mesh/README.md +227 -0
  139. package/include/numkong/mesh/haswell.h +2235 -0
  140. package/include/numkong/mesh/neon.h +1329 -0
  141. package/include/numkong/mesh/neonbfdot.h +842 -0
  142. package/include/numkong/mesh/neonhalf.h +616 -0
  143. package/include/numkong/mesh/rvv.h +916 -0
  144. package/include/numkong/mesh/serial.h +742 -0
  145. package/include/numkong/mesh/skylake.h +1135 -0
  146. package/include/numkong/mesh/v128relaxed.h +1052 -0
  147. package/include/numkong/mesh.h +652 -0
  148. package/include/numkong/mesh.hpp +762 -0
  149. package/include/numkong/numkong.h +78 -0
  150. package/include/numkong/numkong.hpp +57 -0
  151. package/include/numkong/probability/README.md +173 -0
  152. package/include/numkong/probability/haswell.h +267 -0
  153. package/include/numkong/probability/neon.h +225 -0
  154. package/include/numkong/probability/rvv.h +409 -0
  155. package/include/numkong/probability/serial.h +169 -0
  156. package/include/numkong/probability/skylake.h +324 -0
  157. package/include/numkong/probability.h +383 -0
  158. package/include/numkong/probability.hpp +120 -0
  159. package/include/numkong/random.h +50 -0
  160. package/include/numkong/random.hpp +285 -0
  161. package/include/numkong/reduce/README.md +547 -0
  162. package/include/numkong/reduce/alder.h +632 -0
  163. package/include/numkong/reduce/genoa.h +201 -0
  164. package/include/numkong/reduce/haswell.h +3783 -0
  165. package/include/numkong/reduce/icelake.h +549 -0
  166. package/include/numkong/reduce/neon.h +3841 -0
  167. package/include/numkong/reduce/neonbfdot.h +353 -0
  168. package/include/numkong/reduce/neonfhm.h +665 -0
  169. package/include/numkong/reduce/neonhalf.h +157 -0
  170. package/include/numkong/reduce/neonsdot.h +357 -0
  171. package/include/numkong/reduce/rvv.h +3407 -0
  172. package/include/numkong/reduce/serial.h +757 -0
  173. package/include/numkong/reduce/sierra.h +338 -0
  174. package/include/numkong/reduce/skylake.h +3792 -0
  175. package/include/numkong/reduce/v128relaxed.h +2302 -0
  176. package/include/numkong/reduce.h +1597 -0
  177. package/include/numkong/reduce.hpp +633 -0
  178. package/include/numkong/scalar/README.md +89 -0
  179. package/include/numkong/scalar/haswell.h +113 -0
  180. package/include/numkong/scalar/neon.h +122 -0
  181. package/include/numkong/scalar/neonhalf.h +70 -0
  182. package/include/numkong/scalar/rvv.h +211 -0
  183. package/include/numkong/scalar/sapphire.h +63 -0
  184. package/include/numkong/scalar/serial.h +332 -0
  185. package/include/numkong/scalar/v128relaxed.h +56 -0
  186. package/include/numkong/scalar.h +683 -0
  187. package/include/numkong/set/README.md +179 -0
  188. package/include/numkong/set/haswell.h +334 -0
  189. package/include/numkong/set/icelake.h +485 -0
  190. package/include/numkong/set/neon.h +364 -0
  191. package/include/numkong/set/rvv.h +226 -0
  192. package/include/numkong/set/rvvbb.h +117 -0
  193. package/include/numkong/set/serial.h +174 -0
  194. package/include/numkong/set/sve.h +185 -0
  195. package/include/numkong/set/v128relaxed.h +240 -0
  196. package/include/numkong/set.h +457 -0
  197. package/include/numkong/set.hpp +114 -0
  198. package/include/numkong/sets/README.md +149 -0
  199. package/include/numkong/sets/haswell.h +63 -0
  200. package/include/numkong/sets/icelake.h +66 -0
  201. package/include/numkong/sets/neon.h +61 -0
  202. package/include/numkong/sets/serial.h +43 -0
  203. package/include/numkong/sets/smebi32.h +1099 -0
  204. package/include/numkong/sets/v128relaxed.h +58 -0
  205. package/include/numkong/sets.h +339 -0
  206. package/include/numkong/sparse/README.md +156 -0
  207. package/include/numkong/sparse/icelake.h +463 -0
  208. package/include/numkong/sparse/neon.h +288 -0
  209. package/include/numkong/sparse/serial.h +117 -0
  210. package/include/numkong/sparse/sve2.h +507 -0
  211. package/include/numkong/sparse/turin.h +322 -0
  212. package/include/numkong/sparse.h +363 -0
  213. package/include/numkong/sparse.hpp +113 -0
  214. package/include/numkong/spatial/README.md +435 -0
  215. package/include/numkong/spatial/alder.h +607 -0
  216. package/include/numkong/spatial/genoa.h +290 -0
  217. package/include/numkong/spatial/haswell.h +960 -0
  218. package/include/numkong/spatial/icelake.h +586 -0
  219. package/include/numkong/spatial/neon.h +773 -0
  220. package/include/numkong/spatial/neonbfdot.h +165 -0
  221. package/include/numkong/spatial/neonhalf.h +118 -0
  222. package/include/numkong/spatial/neonsdot.h +261 -0
  223. package/include/numkong/spatial/rvv.h +984 -0
  224. package/include/numkong/spatial/rvvbf16.h +123 -0
  225. package/include/numkong/spatial/rvvhalf.h +117 -0
  226. package/include/numkong/spatial/sapphire.h +343 -0
  227. package/include/numkong/spatial/serial.h +346 -0
  228. package/include/numkong/spatial/sierra.h +323 -0
  229. package/include/numkong/spatial/skylake.h +606 -0
  230. package/include/numkong/spatial/sve.h +224 -0
  231. package/include/numkong/spatial/svebfdot.h +122 -0
  232. package/include/numkong/spatial/svehalf.h +109 -0
  233. package/include/numkong/spatial/v128relaxed.h +717 -0
  234. package/include/numkong/spatial.h +1425 -0
  235. package/include/numkong/spatial.hpp +183 -0
  236. package/include/numkong/spatials/README.md +580 -0
  237. package/include/numkong/spatials/alder.h +94 -0
  238. package/include/numkong/spatials/genoa.h +94 -0
  239. package/include/numkong/spatials/haswell.h +219 -0
  240. package/include/numkong/spatials/icelake.h +113 -0
  241. package/include/numkong/spatials/neon.h +109 -0
  242. package/include/numkong/spatials/neonbfdot.h +60 -0
  243. package/include/numkong/spatials/neonfhm.h +92 -0
  244. package/include/numkong/spatials/neonhalf.h +58 -0
  245. package/include/numkong/spatials/neonsdot.h +109 -0
  246. package/include/numkong/spatials/rvv.h +1960 -0
  247. package/include/numkong/spatials/sapphireamx.h +1149 -0
  248. package/include/numkong/spatials/serial.h +226 -0
  249. package/include/numkong/spatials/sierra.h +96 -0
  250. package/include/numkong/spatials/skylake.h +184 -0
  251. package/include/numkong/spatials/sme.h +1901 -0
  252. package/include/numkong/spatials/smef64.h +465 -0
  253. package/include/numkong/spatials/v128relaxed.h +240 -0
  254. package/include/numkong/spatials.h +3021 -0
  255. package/include/numkong/spatials.hpp +508 -0
  256. package/include/numkong/tensor.hpp +1592 -0
  257. package/include/numkong/trigonometry/README.md +184 -0
  258. package/include/numkong/trigonometry/haswell.h +652 -0
  259. package/include/numkong/trigonometry/neon.h +639 -0
  260. package/include/numkong/trigonometry/rvv.h +699 -0
  261. package/include/numkong/trigonometry/serial.h +703 -0
  262. package/include/numkong/trigonometry/skylake.h +721 -0
  263. package/include/numkong/trigonometry/v128relaxed.h +666 -0
  264. package/include/numkong/trigonometry.h +467 -0
  265. package/include/numkong/trigonometry.hpp +166 -0
  266. package/include/numkong/types.h +1384 -0
  267. package/include/numkong/types.hpp +5603 -0
  268. package/include/numkong/vector.hpp +698 -0
  269. package/javascript/README.md +246 -0
  270. package/javascript/dist/cjs/numkong-wasm.d.ts +166 -0
  271. package/javascript/dist/cjs/numkong-wasm.js +617 -0
  272. package/javascript/dist/cjs/numkong.d.ts +343 -0
  273. package/javascript/dist/cjs/numkong.js +523 -0
  274. package/javascript/dist/cjs/package.json +3 -0
  275. package/javascript/dist/cjs/types.d.ts +284 -0
  276. package/javascript/dist/cjs/types.js +653 -0
  277. package/javascript/dist/esm/numkong-wasm.d.ts +166 -0
  278. package/javascript/dist/esm/numkong-wasm.js +595 -0
  279. package/javascript/dist/esm/numkong.d.ts +343 -0
  280. package/javascript/dist/esm/numkong.js +452 -0
  281. package/javascript/dist/esm/package.json +3 -0
  282. package/javascript/dist/esm/types.d.ts +284 -0
  283. package/javascript/dist/esm/types.js +630 -0
  284. package/javascript/dist-package-cjs.json +3 -0
  285. package/javascript/dist-package-esm.json +3 -0
  286. package/javascript/node-gyp-build.d.ts +1 -0
  287. package/javascript/numkong-wasm.ts +756 -0
  288. package/javascript/numkong.c +689 -0
  289. package/javascript/numkong.ts +575 -0
  290. package/javascript/tsconfig-base.json +39 -0
  291. package/javascript/tsconfig-cjs.json +8 -0
  292. package/javascript/tsconfig-esm.json +8 -0
  293. package/javascript/types.ts +674 -0
  294. package/package.json +87 -0
@@ -0,0 +1,2486 @@
1
+ /**
2
+ * @brief SIMD-accelerated Batched Dot Products for RISC-V.
3
+ * @file include/numkong/dots/rvv.h
4
+ * @author Ash Vardanian
5
+ * @date February 6, 2026
6
+ *
7
+ * @sa include/numkong/dots.h
8
+ *
9
+ * Custom RVV-native register-tiled GEMM implementation, analogous to how AMX
10
+ * (dots/sapphireamx.h) and SME (dots/sme.h) each have their own unique implementations
11
+ * independent of the cross-product macros.
12
+ *
13
+ * RVV's variable-length vectors and widening multiply-accumulate (`vfwmacc`) make it
14
+ * fundamentally different from fixed-width SIMD. Key design choices:
15
+ *
16
+ * - f32 GEMM: Uses `vfwmacc_vv_f64m4` for f64 accumulation (vector-vector widened FMA),
17
+ * Process 4 rows per tile (rows_per_tile=4). Narrowed to f32 on store.
18
+ * - f64 GEMM: Uses `vfmul` with Kahan compensation,
19
+ * Process 2 rows per tile (rows_per_tile=2, tighter register budget at LMUL=4).
20
+ * - B packing: Column-panel layout with cache-line padding. Each depth step stores
21
+ * contiguous elements along depth — one `vle32`/`vle64` per vectorized chunk.
22
+ * - Edge handling: RVV's `vsetvl` returns actual VL for partial vectors — no separate
23
+ * edge kernel needed.
24
+ * - Vectorization axis: depth (k dimension). Each inner loop iteration loads a chunk of
25
+ * both A and B along depth, computing element-wise widened FMA.
26
+ *
27
+ * - e2m3 GEMM: Integer arithmetic via LUT (5-bit magnitude → i8 value×16).
28
+ * B is pre-packed as signed i8. A is converted on-the-fly via `vluxei8` gather.
29
+ * Uses `vwmul` (i8→i16) then `vwadd_wv` (i32+=i16) for K-vectorized accumulation.
30
+ * Final result scaled by 1/256. Process 4 rows per tile (rows_per_tile=4).
31
+ * - e3m2 GEMM: Integer arithmetic via LUT (5-bit magnitude → i16 value×16).
32
+ * B is pre-packed as signed i16. A is converted on-the-fly via `vluxei16` gather.
33
+ * Uses `vwmacc` (i16×i16→i32) for K-vectorized widening MAC.
34
+ * Final result scaled by 1/256. Process 2 rows per tile (rows_per_tile=2, wider accumulator elements).
35
+ * - e4m3 GEMM: f32 LUT gather (7-bit magnitude → f32 bit pattern, 128 entries).
36
+ * B is pre-packed as f32. A is converted on-the-fly via `vluxei32` gather with
37
+ * sign injection (bit 7 → bit 31). Uses `vfwmacc_vv_f64m4` for f64 accumulation.
38
+ * Process 2 rows per tile (rows_per_tile=2, u32m2 gather + f64m4 accumulator is register-heavy).
39
+ * - e5m2 GEMM: Same f32 LUT gather approach as e4m3, different LUT contents.
40
+ * E5M2 has 5 exponent bits (wider range, lower precision than e4m3).
41
+ * Process 2 rows per tile (rows_per_tile=2).
42
+ */
43
+ #ifndef NK_DOTS_RVV_H
44
+ #define NK_DOTS_RVV_H
45
+
46
+ #if NK_TARGET_RISCV_
47
+ #if NK_TARGET_RVV
48
+
49
+ #include "numkong/types.h"
50
+ #include "numkong/dots/serial.h"
51
+ #include "numkong/cast/rvv.h" // `nk_bf16m1_to_f32m2_rvv_`
52
+
53
+ #if defined(__clang__)
54
+ #pragma clang attribute push(__attribute__((target("arch=+v"))), apply_to = function)
55
+ #elif defined(__GNUC__)
56
+ #pragma GCC push_options
57
+ #pragma GCC target("arch=+v")
58
+ #endif
59
+
60
+ #if defined(__cplusplus)
61
+ extern "C" {
62
+ #endif
63
+
64
+ /**
65
+ * @brief E2M3 magnitude LUT: 5-bit magnitude → unsigned value×16 (u8).
66
+ * Shared across scalar helper, packed kernel, and symmetric kernel.
67
+ */
68
+ static nk_u8_t const nk_e2m3_magnitude_lut_rvv_[32] = {0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20,
69
+ 22, 24, 26, 28, 30, 32, 36, 40, 44, 48, 52,
70
+ 56, 60, 64, 72, 80, 88, 96, 104, 112, 120};
71
+
72
+ /**
73
+ * @brief E3M2 magnitude LUT: 5-bit magnitude → unsigned value×16 (u16).
74
+ * Shared across scalar helper, packed kernel, and symmetric kernel.
75
+ */
76
+ static nk_u16_t const nk_e3m2_magnitude_lut_rvv_[32] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 10, 12,
77
+ 14, 16, 20, 24, 28, 32, 40, 48, 56, 64, 80,
78
+ 96, 112, 128, 160, 192, 224, 256, 320, 384, 448};
79
+
80
+ #pragma region Single Precision Floats
81
+
82
+ NK_PUBLIC nk_size_t nk_dots_packed_size_f32_rvv(nk_size_t column_count, nk_size_t depth) {
83
+ nk_size_t vector_length = __riscv_vsetvlmax_e32m2();
84
+ nk_size_t depth_padded = nk_size_round_up_to_multiple_(depth, vector_length);
85
+ // Break power-of-2 strides for cache associativity
86
+ nk_size_t stride_bytes = depth_padded * sizeof(nk_f32_t);
87
+ if (stride_bytes > 0 && (stride_bytes & (stride_bytes - 1)) == 0) depth_padded += vector_length;
88
+ return sizeof(nk_cross_packed_buffer_header_t) + column_count * depth_padded * sizeof(nk_f32_t) +
89
+ column_count * sizeof(nk_f64_t); // per-column norms
90
+ }
91
+
92
+ NK_PUBLIC void nk_dots_pack_f32_rvv(nk_f32_t const *b, nk_size_t column_count, nk_size_t depth,
93
+ nk_size_t b_stride_in_bytes, void *b_packed) {
94
+ nk_size_t vector_length = __riscv_vsetvlmax_e32m2();
95
+ nk_size_t depth_padded = nk_size_round_up_to_multiple_(depth, vector_length);
96
+ nk_size_t stride_bytes = depth_padded * sizeof(nk_f32_t);
97
+ if (stride_bytes > 0 && (stride_bytes & (stride_bytes - 1)) == 0) depth_padded += vector_length;
98
+
99
+ nk_cross_packed_buffer_header_t *header = (nk_cross_packed_buffer_header_t *)b_packed;
100
+ header->column_count = (nk_u32_t)column_count;
101
+ header->depth_dimensions = (nk_u32_t)depth;
102
+ header->depth_padded_values = (nk_u32_t)depth_padded;
103
+
104
+ nk_f32_t *packed = (nk_f32_t *)((char *)b_packed + sizeof(nk_cross_packed_buffer_header_t));
105
+ nk_size_t total = column_count * depth_padded;
106
+ for (nk_size_t i = 0; i < total; ++i) packed[i] = 0;
107
+
108
+ for (nk_size_t column = 0; column < column_count; ++column) {
109
+ nk_f32_t const *src = (nk_f32_t const *)((char const *)b + column * b_stride_in_bytes);
110
+ nk_f32_t *dst = packed + column * depth_padded;
111
+ for (nk_size_t k = 0; k < depth; ++k) dst[k] = src[k];
112
+ }
113
+
114
+ // Append per-column norms after packed data
115
+ nk_f64_t *norms = (nk_f64_t *)(packed + total);
116
+ for (nk_size_t column = 0; column < column_count; ++column) {
117
+ nk_f32_t const *src = (nk_f32_t const *)((char const *)b + column * b_stride_in_bytes);
118
+ norms[column] = nk_dots_reduce_sumsq_f32_(src, depth);
119
+ }
120
+ }
121
+
122
+ /**
123
+ * @brief f32 packed GEMM kernel: C += A * B_packed^T with f64 widened accumulation.
124
+ *
125
+ * Vectorizes over the depth dimension (k). For each (row, column) pair:
126
+ * acc_f64 = sum_k f64(a[row][k]) * f64(b_packed[column][k])
127
+ * using `vfwmacc_vv_f64m4` which widens both operands from f32m2 to f64m4.
128
+ *
129
+ * Register tile: process 4 rows per iteration (rows_per_tile=4).
130
+ * Each row loads its own A vector; B vector is shared across rows per depth chunk.
131
+ */
132
+ NK_INTERNAL void nk_dots_packed_f32_rvv_aligned_(nk_f32_t const *a_matrix, void const *b_packed_buffer,
133
+ nk_f64_t *c_matrix, nk_size_t row_count, nk_size_t column_count,
134
+ nk_size_t depth, nk_size_t a_stride_in_bytes,
135
+ nk_size_t c_stride_in_bytes) {
136
+ nk_cross_packed_buffer_header_t const *header = (nk_cross_packed_buffer_header_t const *)b_packed_buffer;
137
+ nk_size_t const depth_padded = header->depth_padded_values;
138
+ nk_f32_t const *packed_data = (nk_f32_t const *)((char const *)b_packed_buffer +
139
+ sizeof(nk_cross_packed_buffer_header_t));
140
+
141
+ // Zero output matrix
142
+ for (nk_size_t i = 0; i < row_count; ++i) {
143
+ nk_f64_t *c_row = (nk_f64_t *)((char *)c_matrix + i * c_stride_in_bytes);
144
+ for (nk_size_t j = 0; j < column_count; ++j) c_row[j] = 0;
145
+ }
146
+
147
+ // mr=4 register tile over rows
148
+ nk_size_t row = 0;
149
+ for (; row + 4 <= row_count; row += 4) {
150
+ nk_f32_t const *a_row_0 = (nk_f32_t const *)((char const *)a_matrix + (row + 0) * a_stride_in_bytes);
151
+ nk_f32_t const *a_row_1 = (nk_f32_t const *)((char const *)a_matrix + (row + 1) * a_stride_in_bytes);
152
+ nk_f32_t const *a_row_2 = (nk_f32_t const *)((char const *)a_matrix + (row + 2) * a_stride_in_bytes);
153
+ nk_f32_t const *a_row_3 = (nk_f32_t const *)((char const *)a_matrix + (row + 3) * a_stride_in_bytes);
154
+ nk_f64_t *c_row_0 = (nk_f64_t *)((char *)c_matrix + (row + 0) * c_stride_in_bytes);
155
+ nk_f64_t *c_row_1 = (nk_f64_t *)((char *)c_matrix + (row + 1) * c_stride_in_bytes);
156
+ nk_f64_t *c_row_2 = (nk_f64_t *)((char *)c_matrix + (row + 2) * c_stride_in_bytes);
157
+ nk_f64_t *c_row_3 = (nk_f64_t *)((char *)c_matrix + (row + 3) * c_stride_in_bytes);
158
+
159
+ for (nk_size_t column = 0; column < column_count; ++column) {
160
+ nk_f32_t const *b_column = packed_data + column * depth_padded;
161
+ nk_size_t vlmax = __riscv_vsetvlmax_e32m2();
162
+ vfloat64m4_t accumulator_0_f64m4 = __riscv_vfmv_v_f_f64m4(0.0, vlmax);
163
+ vfloat64m4_t accumulator_1_f64m4 = __riscv_vfmv_v_f_f64m4(0.0, vlmax);
164
+ vfloat64m4_t accumulator_2_f64m4 = __riscv_vfmv_v_f_f64m4(0.0, vlmax);
165
+ vfloat64m4_t accumulator_3_f64m4 = __riscv_vfmv_v_f_f64m4(0.0, vlmax);
166
+
167
+ nk_size_t remaining = depth;
168
+ nk_size_t k = 0;
169
+ for (nk_size_t vector_length = 0; remaining > 0; remaining -= vector_length, k += vector_length) {
170
+ vector_length = __riscv_vsetvl_e32m2(remaining);
171
+ vfloat32m2_t b_vector_f32m2 = __riscv_vle32_v_f32m2(b_column + k, vector_length);
172
+ vfloat32m2_t a_vector_0_f32m2 = __riscv_vle32_v_f32m2(a_row_0 + k, vector_length);
173
+ vfloat32m2_t a_vector_1_f32m2 = __riscv_vle32_v_f32m2(a_row_1 + k, vector_length);
174
+ vfloat32m2_t a_vector_2_f32m2 = __riscv_vle32_v_f32m2(a_row_2 + k, vector_length);
175
+ vfloat32m2_t a_vector_3_f32m2 = __riscv_vle32_v_f32m2(a_row_3 + k, vector_length);
176
+ accumulator_0_f64m4 = __riscv_vfwmacc_vv_f64m4_tu(accumulator_0_f64m4, a_vector_0_f32m2, b_vector_f32m2,
177
+ vector_length);
178
+ accumulator_1_f64m4 = __riscv_vfwmacc_vv_f64m4_tu(accumulator_1_f64m4, a_vector_1_f32m2, b_vector_f32m2,
179
+ vector_length);
180
+ accumulator_2_f64m4 = __riscv_vfwmacc_vv_f64m4_tu(accumulator_2_f64m4, a_vector_2_f32m2, b_vector_f32m2,
181
+ vector_length);
182
+ accumulator_3_f64m4 = __riscv_vfwmacc_vv_f64m4_tu(accumulator_3_f64m4, a_vector_3_f32m2, b_vector_f32m2,
183
+ vector_length);
184
+ }
185
+
186
+ // Horizontal reduce directly to f64
187
+ vfloat64m1_t zero_f64m1 = __riscv_vfmv_v_f_f64m1(0.0, 1);
188
+ c_row_0[column] = __riscv_vfmv_f_s_f64m1_f64(
189
+ __riscv_vfredusum_vs_f64m4_f64m1(accumulator_0_f64m4, zero_f64m1, vlmax));
190
+ c_row_1[column] = __riscv_vfmv_f_s_f64m1_f64(
191
+ __riscv_vfredusum_vs_f64m4_f64m1(accumulator_1_f64m4, zero_f64m1, vlmax));
192
+ c_row_2[column] = __riscv_vfmv_f_s_f64m1_f64(
193
+ __riscv_vfredusum_vs_f64m4_f64m1(accumulator_2_f64m4, zero_f64m1, vlmax));
194
+ c_row_3[column] = __riscv_vfmv_f_s_f64m1_f64(
195
+ __riscv_vfredusum_vs_f64m4_f64m1(accumulator_3_f64m4, zero_f64m1, vlmax));
196
+ }
197
+ }
198
+ // Remainder rows (mr < 4)
199
+ for (; row < row_count; ++row) {
200
+ nk_f32_t const *a_row = (nk_f32_t const *)((char const *)a_matrix + row * a_stride_in_bytes);
201
+ nk_f64_t *c_row = (nk_f64_t *)((char *)c_matrix + row * c_stride_in_bytes);
202
+ for (nk_size_t column = 0; column < column_count; ++column) {
203
+ nk_f32_t const *b_column = packed_data + column * depth_padded;
204
+ nk_size_t vlmax = __riscv_vsetvlmax_e32m2();
205
+ vfloat64m4_t accumulator_f64m4 = __riscv_vfmv_v_f_f64m4(0.0, vlmax);
206
+ nk_size_t remaining = depth;
207
+ nk_size_t k = 0;
208
+ for (nk_size_t vector_length = 0; remaining > 0; remaining -= vector_length, k += vector_length) {
209
+ vector_length = __riscv_vsetvl_e32m2(remaining);
210
+ vfloat32m2_t b_vector_f32m2 = __riscv_vle32_v_f32m2(b_column + k, vector_length);
211
+ vfloat32m2_t a_vector_f32m2 = __riscv_vle32_v_f32m2(a_row + k, vector_length);
212
+ accumulator_f64m4 = __riscv_vfwmacc_vv_f64m4_tu(accumulator_f64m4, a_vector_f32m2, b_vector_f32m2,
213
+ vector_length);
214
+ }
215
+ vfloat64m1_t zero_f64m1 = __riscv_vfmv_v_f_f64m1(0.0, 1);
216
+ c_row[column] = __riscv_vfmv_f_s_f64m1_f64(
217
+ __riscv_vfredusum_vs_f64m4_f64m1(accumulator_f64m4, zero_f64m1, vlmax));
218
+ }
219
+ }
220
+ }
221
+
222
+ /**
223
+ * @brief Public f32 packed GEMM wrapper matching the declared signature in dots.h.
224
+ *
225
+ * Dispatches to the aligned kernel for all cases — RVV's `vsetvl` handles partial
226
+ * vectors naturally, so no separate edge kernel is needed.
227
+ */
228
+ NK_PUBLIC void nk_dots_packed_f32_rvv(nk_f32_t const *a, void const *b_packed, nk_f64_t *c, nk_size_t m, nk_size_t n,
229
+ nk_size_t k, nk_size_t a_stride, nk_size_t c_stride) {
230
+ nk_dots_packed_f32_rvv_aligned_(a, b_packed, c, m, n, k, a_stride, c_stride);
231
+ }
232
+
233
+ /**
234
+ * @brief Symmetric f32 GEMM: C = A * A^T, upper triangle + mirror.
235
+ *
236
+ * Uses f64 widened accumulation via `vfwmacc_vv_f64m4` for precision.
237
+ * Processes only the rows in [row_start, row_start + row_count) for parallelism.
238
+ */
239
+ NK_PUBLIC void nk_dots_symmetric_f32_rvv(nk_f32_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
240
+ nk_size_t stride, nk_f64_t *result, nk_size_t result_stride,
241
+ nk_size_t row_start, nk_size_t row_count) {
242
+ nk_size_t const stride_elements = stride / sizeof(nk_f32_t);
243
+ nk_size_t const result_stride_elements = result_stride / sizeof(nk_f64_t);
244
+ nk_size_t const row_end = (row_start + row_count < n_vectors) ? (row_start + row_count) : n_vectors;
245
+
246
+ for (nk_size_t i = row_start; i < row_end; ++i) {
247
+ nk_f32_t const *a_i = vectors + i * stride_elements;
248
+ for (nk_size_t j = i; j < n_vectors; ++j) {
249
+ nk_f32_t const *a_j = vectors + j * stride_elements;
250
+ nk_size_t vlmax = __riscv_vsetvlmax_e32m2();
251
+ vfloat64m4_t accumulator_f64m4 = __riscv_vfmv_v_f_f64m4(0.0, vlmax);
252
+ nk_size_t remaining = depth;
253
+ nk_size_t k = 0;
254
+ for (nk_size_t vector_length = 0; remaining > 0; remaining -= vector_length, k += vector_length) {
255
+ vector_length = __riscv_vsetvl_e32m2(remaining);
256
+ vfloat32m2_t a_vector_f32m2 = __riscv_vle32_v_f32m2(a_i + k, vector_length);
257
+ vfloat32m2_t b_vector_f32m2 = __riscv_vle32_v_f32m2(a_j + k, vector_length);
258
+ accumulator_f64m4 = __riscv_vfwmacc_vv_f64m4_tu(accumulator_f64m4, a_vector_f32m2, b_vector_f32m2,
259
+ vector_length);
260
+ }
261
+ vfloat64m1_t zero_f64m1 = __riscv_vfmv_v_f_f64m1(0.0, 1);
262
+ nk_f64_t dot = __riscv_vfmv_f_s_f64m1_f64(
263
+ __riscv_vfredusum_vs_f64m4_f64m1(accumulator_f64m4, zero_f64m1, vlmax));
264
+ result[i * result_stride_elements + j] = dot;
265
+ }
266
+ }
267
+ }
268
+
269
+ #pragma endregion // Single Precision Floats
270
+
271
+ #pragma region Double Precision Floats
272
+
273
+ NK_PUBLIC nk_size_t nk_dots_packed_size_f64_rvv(nk_size_t column_count, nk_size_t depth) {
274
+ nk_size_t vector_length = __riscv_vsetvlmax_e64m4();
275
+ nk_size_t depth_padded = nk_size_round_up_to_multiple_(depth, vector_length);
276
+ nk_size_t stride_bytes = depth_padded * sizeof(nk_f64_t);
277
+ if (stride_bytes > 0 && (stride_bytes & (stride_bytes - 1)) == 0) depth_padded += vector_length;
278
+ return sizeof(nk_cross_packed_buffer_header_t) + column_count * depth_padded * sizeof(nk_f64_t) +
279
+ column_count * sizeof(nk_f64_t); // per-column norms
280
+ }
281
+
282
+ // Packs B (column_count x depth, row-major with byte stride) into the
+ // header + zero-padded column-panel layout consumed by the packed GEMM
+ // kernels, and appends one sum-of-squares norm per column. The caller must
+ // provide a buffer of at least nk_dots_packed_size_f64_rvv() bytes.
+ NK_PUBLIC void nk_dots_pack_f64_rvv(nk_f64_t const *b, nk_size_t column_count, nk_size_t depth,
283
+ nk_size_t b_stride_in_bytes, void *b_packed) {
284
+ // Padding logic must mirror nk_dots_packed_size_f64_rvv exactly.
+ nk_size_t vector_length = __riscv_vsetvlmax_e64m4();
285
+ nk_size_t depth_padded = nk_size_round_up_to_multiple_(depth, vector_length);
286
+ nk_size_t stride_bytes = depth_padded * sizeof(nk_f64_t);
287
+ if (stride_bytes > 0 && (stride_bytes & (stride_bytes - 1)) == 0) depth_padded += vector_length;
288
+
289
+ nk_cross_packed_buffer_header_t *header = (nk_cross_packed_buffer_header_t *)b_packed;
290
+ header->column_count = (nk_u32_t)column_count;
291
+ header->depth_dimensions = (nk_u32_t)depth;
292
+ header->depth_padded_values = (nk_u32_t)depth_padded;
293
+
294
+ nk_f64_t *packed = (nk_f64_t *)((char *)b_packed + sizeof(nk_cross_packed_buffer_header_t));
295
+ nk_size_t total = column_count * depth_padded;
296
+ // Zero the whole panel area first so padding lanes contribute 0 to dots.
+ for (nk_size_t i = 0; i < total; ++i) packed[i] = 0;
297
+
298
+ for (nk_size_t column = 0; column < column_count; ++column) {
299
+ nk_f64_t const *src = (nk_f64_t const *)((char const *)b + column * b_stride_in_bytes);
300
+ nk_f64_t *dst = packed + column * depth_padded;
301
+ for (nk_size_t k = 0; k < depth; ++k) dst[k] = src[k];
302
+ }
303
+
304
+ // Append per-column norms after packed data
305
+ nk_f64_t *norms = (nk_f64_t *)(packed + total);
306
+ for (nk_size_t column = 0; column < column_count; ++column) {
307
+ nk_f64_t const *src = (nk_f64_t const *)((char const *)b + column * b_stride_in_bytes);
308
+ norms[column] = nk_dots_reduce_sumsq_f64_(src, depth);
309
+ }
310
+ }
311
+
312
+ /**
313
+ * @brief f64 packed GEMM kernel: C = A * B_packed^T with Kahan compensation.
314
+ *
315
+ * Vectorizes over depth dimension k using `vfmul`+Kahan (vector-vector multiply).
316
+ * Uses Kahan summation over full depth to maintain precision.
317
+ * Register tile: process 2 rows per iteration (rows_per_tile=2, budget: 32 regs at LMUL=4).
318
+ * Note: C is zeroed before accumulation, so the result overwrites (not adds to) C.
+ */
319
+ NK_INTERNAL void nk_dots_packed_f64_rvv_aligned_(nk_f64_t const *a_matrix, void const *b_packed_buffer,
320
+ nk_f64_t *c_matrix, nk_size_t row_count, nk_size_t column_count,
321
+ nk_size_t depth, nk_size_t a_stride_in_bytes,
322
+ nk_size_t c_stride_in_bytes) {
323
+ nk_cross_packed_buffer_header_t const *header = (nk_cross_packed_buffer_header_t const *)b_packed_buffer;
324
+ nk_size_t const depth_padded = header->depth_padded_values;
325
+ nk_f64_t const *packed_data = (nk_f64_t const *)((char const *)b_packed_buffer +
326
+ sizeof(nk_cross_packed_buffer_header_t));
327
+
328
+ // Zero output matrix
329
+ for (nk_size_t i = 0; i < row_count; ++i) {
330
+ nk_f64_t *c_row = (nk_f64_t *)((char *)c_matrix + i * c_stride_in_bytes);
331
+ for (nk_size_t j = 0; j < column_count; ++j) c_row[j] = 0;
332
+ }
333
+
334
+ // Process 2 rows per tile (rows_per_tile=2, tighter register budget for f64 at LMUL=4)
335
+ nk_size_t row = 0;
336
+ for (; row + 2 <= row_count; row += 2) {
337
+ nk_f64_t const *a_row_0 = (nk_f64_t const *)((char const *)a_matrix + (row + 0) * a_stride_in_bytes);
338
+ nk_f64_t const *a_row_1 = (nk_f64_t const *)((char const *)a_matrix + (row + 1) * a_stride_in_bytes);
339
+ nk_f64_t *c_row_0 = (nk_f64_t *)((char *)c_matrix + (row + 0) * c_stride_in_bytes);
340
+ nk_f64_t *c_row_1 = (nk_f64_t *)((char *)c_matrix + (row + 1) * c_stride_in_bytes);
341
+
342
+ for (nk_size_t column = 0; column < column_count; ++column) {
343
+ nk_f64_t const *b_column = packed_data + column * depth_padded;
344
+ nk_size_t vlmax = __riscv_vsetvlmax_e64m4();
345
+ // Accumulators/compensations start as all-zero vlmax-wide vectors; the
+ // `_tu` (tail-undisturbed) ops below leave lanes beyond the active vl at
+ // zero, so the final vlmax-wide reduction stays correct on short tails.
+ vfloat64m4_t accumulator_0_f64m4 = __riscv_vfmv_v_f_f64m4(0.0, vlmax);
346
+ vfloat64m4_t accumulator_1_f64m4 = __riscv_vfmv_v_f_f64m4(0.0, vlmax);
347
+ vfloat64m4_t compensation_0_f64m4 = __riscv_vfmv_v_f_f64m4(0.0, vlmax);
348
+ vfloat64m4_t compensation_1_f64m4 = __riscv_vfmv_v_f_f64m4(0.0, vlmax);
349
+
350
+ nk_size_t remaining = depth;
351
+ nk_size_t k = 0;
352
+ for (nk_size_t vector_length = 0; remaining > 0; remaining -= vector_length, k += vector_length) {
353
+ vector_length = __riscv_vsetvl_e64m4(remaining);
354
+ vfloat64m4_t b_vector_f64m4 = __riscv_vle64_v_f64m4(b_column + k, vector_length);
355
+ vfloat64m4_t a_vector_0_f64m4 = __riscv_vle64_v_f64m4(a_row_0 + k, vector_length);
356
+ vfloat64m4_t a_vector_1_f64m4 = __riscv_vle64_v_f64m4(a_row_1 + k, vector_length);
357
+
358
+ // Kahan step for row 0: product = a*b; corrected = product - comp; running = acc + corrected; comp =
359
+ // (running - acc) - corrected; acc = running
360
+ vfloat64m4_t product_0_f64m4 = __riscv_vfmul_vv_f64m4(a_vector_0_f64m4, b_vector_f64m4, vector_length);
361
+ vfloat64m4_t corrected_term_0_f64m4 = __riscv_vfsub_vv_f64m4(product_0_f64m4, compensation_0_f64m4,
362
+ vector_length);
363
+ vfloat64m4_t running_sum_0_f64m4 = __riscv_vfadd_vv_f64m4_tu(accumulator_0_f64m4, accumulator_0_f64m4,
364
+ corrected_term_0_f64m4, vector_length);
365
+ compensation_0_f64m4 = __riscv_vfsub_vv_f64m4_tu(
366
+ compensation_0_f64m4,
367
+ __riscv_vfsub_vv_f64m4(running_sum_0_f64m4, accumulator_0_f64m4, vector_length),
368
+ corrected_term_0_f64m4, vector_length);
369
+ accumulator_0_f64m4 = running_sum_0_f64m4;
370
+
371
+ // Kahan step for row 1
372
+ vfloat64m4_t product_1_f64m4 = __riscv_vfmul_vv_f64m4(a_vector_1_f64m4, b_vector_f64m4, vector_length);
373
+ vfloat64m4_t corrected_term_1_f64m4 = __riscv_vfsub_vv_f64m4(product_1_f64m4, compensation_1_f64m4,
374
+ vector_length);
375
+ vfloat64m4_t running_sum_1_f64m4 = __riscv_vfadd_vv_f64m4_tu(accumulator_1_f64m4, accumulator_1_f64m4,
376
+ corrected_term_1_f64m4, vector_length);
377
+ compensation_1_f64m4 = __riscv_vfsub_vv_f64m4_tu(
378
+ compensation_1_f64m4,
379
+ __riscv_vfsub_vv_f64m4(running_sum_1_f64m4, accumulator_1_f64m4, vector_length),
380
+ corrected_term_1_f64m4, vector_length);
381
+ accumulator_1_f64m4 = running_sum_1_f64m4;
382
+ }
383
+
384
+ // Horizontal reduce
385
+ vfloat64m1_t zero_f64m1 = __riscv_vfmv_v_f_f64m1(0.0, 1);
386
+ c_row_0[column] = __riscv_vfmv_f_s_f64m1_f64(
387
+ __riscv_vfredusum_vs_f64m4_f64m1(accumulator_0_f64m4, zero_f64m1, vlmax));
388
+ c_row_1[column] = __riscv_vfmv_f_s_f64m1_f64(
389
+ __riscv_vfredusum_vs_f64m4_f64m1(accumulator_1_f64m4, zero_f64m1, vlmax));
390
+ }
391
+ }
392
+ // Remainder rows
393
+ for (; row < row_count; ++row) {
394
+ nk_f64_t const *a_row = (nk_f64_t const *)((char const *)a_matrix + row * a_stride_in_bytes);
395
+ nk_f64_t *c_row = (nk_f64_t *)((char *)c_matrix + row * c_stride_in_bytes);
396
+ for (nk_size_t column = 0; column < column_count; ++column) {
397
+ nk_f64_t const *b_column = packed_data + column * depth_padded;
398
+ nk_size_t vlmax = __riscv_vsetvlmax_e64m4();
399
+ vfloat64m4_t accumulator_f64m4 = __riscv_vfmv_v_f_f64m4(0.0, vlmax);
400
+ vfloat64m4_t compensation_f64m4 = __riscv_vfmv_v_f_f64m4(0.0, vlmax);
401
+
402
+ nk_size_t remaining = depth;
403
+ nk_size_t k = 0;
404
+ for (nk_size_t vector_length = 0; remaining > 0; remaining -= vector_length, k += vector_length) {
405
+ vector_length = __riscv_vsetvl_e64m4(remaining);
406
+ vfloat64m4_t b_vector_f64m4 = __riscv_vle64_v_f64m4(b_column + k, vector_length);
407
+ vfloat64m4_t a_vector_f64m4 = __riscv_vle64_v_f64m4(a_row + k, vector_length);
408
+
409
+ // Same Kahan recurrence as the 2-row tile above, for a single row.
+ vfloat64m4_t product_f64m4 = __riscv_vfmul_vv_f64m4(a_vector_f64m4, b_vector_f64m4, vector_length);
410
+ vfloat64m4_t corrected_term_f64m4 = __riscv_vfsub_vv_f64m4(product_f64m4, compensation_f64m4,
411
+ vector_length);
412
+ vfloat64m4_t running_sum_f64m4 = __riscv_vfadd_vv_f64m4_tu(accumulator_f64m4, accumulator_f64m4,
413
+ corrected_term_f64m4, vector_length);
414
+ compensation_f64m4 = __riscv_vfsub_vv_f64m4_tu(
415
+ compensation_f64m4, __riscv_vfsub_vv_f64m4(running_sum_f64m4, accumulator_f64m4, vector_length),
416
+ corrected_term_f64m4, vector_length);
417
+ accumulator_f64m4 = running_sum_f64m4;
418
+ }
419
+
420
+ vfloat64m1_t zero_f64m1 = __riscv_vfmv_v_f_f64m1(0.0, 1);
421
+ c_row[column] = __riscv_vfmv_f_s_f64m1_f64(
422
+ __riscv_vfredusum_vs_f64m4_f64m1(accumulator_f64m4, zero_f64m1, vlmax));
423
+ }
424
+ }
425
+ }
426
+
427
+ /**
428
+ * @brief Public f64 packed GEMM wrapper matching the declared signature in dots.h.
429
+ * Delegates directly to the aligned kernel; strides are in bytes.
+ */
430
+ NK_PUBLIC void nk_dots_packed_f64_rvv(nk_f64_t const *a, void const *b_packed, nk_f64_t *c, nk_size_t m, nk_size_t n,
431
+ nk_size_t k, nk_size_t a_stride, nk_size_t c_stride) {
432
+ nk_dots_packed_f64_rvv_aligned_(a, b_packed, c, m, n, k, a_stride, c_stride);
433
+ }
434
+
435
+ /**
436
+ * @brief Symmetric f64 GEMM: C = A * A^T, upper triangle (j >= i) only.
437
+ *
438
+ * Uses Kahan compensation over full depth for precision.
439
+ * Processes only the rows in [row_start, row_start + row_count) for parallelism.
440
+ * NOTE(review): no mirror store is performed here — only result[i][j] for
+ * j >= i is written; presumably the caller mirrors the lower triangle.
+ */
441
+ NK_PUBLIC void nk_dots_symmetric_f64_rvv(nk_f64_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
442
+ nk_size_t stride, nk_f64_t *result, nk_size_t result_stride,
443
+ nk_size_t row_start, nk_size_t row_count) {
444
+ // Strides arrive in bytes; convert to element counts for pointer math.
+ nk_size_t const stride_elements = stride / sizeof(nk_f64_t);
445
+ nk_size_t const result_stride_elements = result_stride / sizeof(nk_f64_t);
446
+ nk_size_t const row_end = (row_start + row_count < n_vectors) ? (row_start + row_count) : n_vectors;
447
+
448
+ for (nk_size_t i = row_start; i < row_end; ++i) {
449
+ nk_f64_t const *a_i = vectors + i * stride_elements;
450
+ for (nk_size_t j = i; j < n_vectors; ++j) {
451
+ nk_f64_t const *a_j = vectors + j * stride_elements;
452
+ nk_size_t vlmax = __riscv_vsetvlmax_e64m4();
453
+ vfloat64m4_t accumulator_f64m4 = __riscv_vfmv_v_f_f64m4(0.0, vlmax);
454
+ vfloat64m4_t compensation_f64m4 = __riscv_vfmv_v_f_f64m4(0.0, vlmax);
455
+
456
+ nk_size_t remaining = depth;
457
+ nk_size_t k = 0;
458
+ for (nk_size_t vector_length = 0; remaining > 0; remaining -= vector_length, k += vector_length) {
459
+ vector_length = __riscv_vsetvl_e64m4(remaining);
460
+ vfloat64m4_t a_vector_f64m4 = __riscv_vle64_v_f64m4(a_i + k, vector_length);
461
+ vfloat64m4_t b_vector_f64m4 = __riscv_vle64_v_f64m4(a_j + k, vector_length);
462
+
463
+ // Kahan-compensated accumulation; `_tu` keeps tail lanes at zero so
+ // the vlmax-wide reduction below is safe.
+ vfloat64m4_t product_f64m4 = __riscv_vfmul_vv_f64m4(a_vector_f64m4, b_vector_f64m4, vector_length);
464
+ vfloat64m4_t corrected_term_f64m4 = __riscv_vfsub_vv_f64m4(product_f64m4, compensation_f64m4,
465
+ vector_length);
466
+ vfloat64m4_t running_sum_f64m4 = __riscv_vfadd_vv_f64m4_tu(accumulator_f64m4, accumulator_f64m4,
467
+ corrected_term_f64m4, vector_length);
468
+ compensation_f64m4 = __riscv_vfsub_vv_f64m4_tu(
469
+ compensation_f64m4, __riscv_vfsub_vv_f64m4(running_sum_f64m4, accumulator_f64m4, vector_length),
470
+ corrected_term_f64m4, vector_length);
471
+ accumulator_f64m4 = running_sum_f64m4;
472
+ }
473
+
474
+ vfloat64m1_t zero_f64m1 = __riscv_vfmv_v_f_f64m1(0.0, 1);
475
+ nk_f64_t dot = __riscv_vfmv_f_s_f64m1_f64(
476
+ __riscv_vfredusum_vs_f64m4_f64m1(accumulator_f64m4, zero_f64m1, vlmax));
477
+ result[i * result_stride_elements + j] = dot;
478
+ }
479
+ }
480
+ }
481
+
482
+ #pragma endregion // Double Precision Floats
483
+
484
+ #pragma region Micro Precision E2M3
485
+
486
+ /**
487
+ * @brief Scalar conversion helper: e2m3 byte → signed i8 (value × 16).
488
+ *
489
+ * Extracts 5-bit magnitude, looks up in LUT, applies sign from bit 5.
490
+ * Every e2m3 value × 16 is an exact integer in [-120, +120], fitting in i8.
491
+ */
492
+ NK_INTERNAL nk_i8_t nk_e2m3_to_i8_rvv_(nk_u8_t raw) {
493
+ nk_u8_t magnitude = raw & 0x1Fu; // low 5 bits index the magnitude LUT
+ nk_i8_t val = (nk_i8_t)nk_e2m3_magnitude_lut_rvv_[magnitude];
494
+ return (raw & 0x20u) ? (nk_i8_t)(-val) : val; // bit 5 is the sign bit
495
+ }
497
+
498
+ // Byte size of the e2m3 packed buffer: header + i8 column panels (one signed
+ // i8 per e2m3 element) + one f32 norm per column. Padding logic matches
+ // nk_dots_pack_e2m3_rvv.
+ NK_PUBLIC nk_size_t nk_dots_packed_size_e2m3_rvv(nk_size_t column_count, nk_size_t depth) {
499
+ nk_size_t vector_length = __riscv_vsetvlmax_e8m1();
500
+ nk_size_t depth_padded = nk_size_round_up_to_multiple_(depth, vector_length);
501
+ nk_size_t stride_bytes = depth_padded * sizeof(nk_i8_t);
502
+ // Extra vector of padding for power-of-two strides — presumably to avoid
+ // cache-set aliasing; NOTE(review): confirm intent.
+ if (stride_bytes > 0 && (stride_bytes & (stride_bytes - 1)) == 0) depth_padded += vector_length;
503
+ return sizeof(nk_cross_packed_buffer_header_t) + column_count * depth_padded * sizeof(nk_i8_t) +
504
+ column_count * sizeof(nk_f32_t); // per-column norms
505
+ }
506
+
507
+ /**
508
+ * @brief Pack B matrix from e2m3 to signed i8 (value × 16) for integer dot product.
509
+ *
510
+ * Each e2m3 byte is converted to a signed i8 via scalar LUT lookup.
511
+ * Padding values are zeroed. Column-panel layout with depth-contiguous storage.
512
+ */
513
+ NK_PUBLIC void nk_dots_pack_e2m3_rvv(nk_e2m3_t const *b, nk_size_t column_count, nk_size_t depth,
514
+ nk_size_t b_stride_in_bytes, void *b_packed) {
515
+ // Padding must mirror nk_dots_packed_size_e2m3_rvv exactly.
+ nk_size_t vector_length = __riscv_vsetvlmax_e8m1();
516
+ nk_size_t depth_padded = nk_size_round_up_to_multiple_(depth, vector_length);
517
+ nk_size_t stride_bytes = depth_padded * sizeof(nk_i8_t);
518
+ if (stride_bytes > 0 && (stride_bytes & (stride_bytes - 1)) == 0) depth_padded += vector_length;
519
+
520
+ nk_cross_packed_buffer_header_t *header = (nk_cross_packed_buffer_header_t *)b_packed;
521
+ header->column_count = (nk_u32_t)column_count;
522
+ header->depth_dimensions = (nk_u32_t)depth;
523
+ header->depth_padded_values = (nk_u32_t)depth_padded;
524
+
525
+ nk_i8_t *packed = (nk_i8_t *)((char *)b_packed + sizeof(nk_cross_packed_buffer_header_t));
526
+ nk_size_t total = column_count * depth_padded;
527
+ // Zero everything first so padding lanes contribute 0 to the dot products.
+ for (nk_size_t i = 0; i < total; ++i) packed[i] = 0;
528
+
529
+ for (nk_size_t column = 0; column < column_count; ++column) {
530
+ nk_u8_t const *src = (nk_u8_t const *)((char const *)b + column * b_stride_in_bytes);
531
+ nk_i8_t *dst = packed + column * depth_padded;
532
+ for (nk_size_t k = 0; k < depth; ++k) dst[k] = nk_e2m3_to_i8_rvv_(src[k]);
533
+ }
534
+
535
+ // Append per-column norms after packed data
536
+ nk_f32_t *norms = (nk_f32_t *)(packed + total);
537
+ for (nk_size_t column = 0; column < column_count; ++column) {
538
+ nk_e2m3_t const *src = (nk_e2m3_t const *)((char const *)b + column * b_stride_in_bytes);
539
+ norms[column] = nk_dots_reduce_sumsq_e2m3_(src, depth);
540
+ }
541
+ }
542
+
543
+ /**
544
+ * @brief e2m3 packed GEMM kernel: C = A * B_packed^T with integer i8 LUT arithmetic.
545
+ *
546
+ * Vectorizes over the depth dimension (k). For each (row, column) pair:
547
+ * - Load raw e2m3 bytes from A, extract magnitude via `vluxei8` gather LUT
548
+ * - Apply sign from bit 5 via masked negate to produce signed i8 A values
549
+ * - Load pre-packed signed i8 values from B
550
+ * - Widening multiply i8×i8 → i16, then widen-accumulate i32 += i16
551
+ * - Final result = i32_sum / 256.0f
552
+ *
553
+ * Register tile: process 4 rows per iteration (rows_per_tile=4).
554
+ * The LUT gather on A magnitudes uses `vluxei8_v_u8m1` (byte-indexed byte gather).
555
+ * Note: C is zeroed before accumulation, so the result overwrites (not adds to) C.
+ */
556
+ NK_INTERNAL void nk_dots_packed_e2m3_rvv_aligned_(nk_e2m3_t const *a_matrix, void const *b_packed_buffer,
557
+ nk_f32_t *c_matrix, nk_size_t row_count, nk_size_t column_count,
558
+ nk_size_t depth, nk_size_t a_stride_in_bytes,
559
+ nk_size_t c_stride_in_bytes) {
560
+ // Both A and B carry a ×16 scale from the LUT, hence the 1/256 correction.
+ nk_f32_t const lut_scale_reciprocal = 1.0f / 256.0f;
561
+
562
+ nk_cross_packed_buffer_header_t const *header = (nk_cross_packed_buffer_header_t const *)b_packed_buffer;
563
+ nk_size_t const depth_padded = header->depth_padded_values;
564
+ nk_i8_t const *packed_data = (nk_i8_t const *)((char const *)b_packed_buffer +
565
+ sizeof(nk_cross_packed_buffer_header_t));
566
+
567
+ // Zero output matrix
568
+ for (nk_size_t i = 0; i < row_count; ++i) {
569
+ nk_f32_t *c_row = (nk_f32_t *)((char *)c_matrix + i * c_stride_in_bytes);
570
+ for (nk_size_t j = 0; j < column_count; ++j) c_row[j] = 0;
571
+ }
572
+
573
+ // mr=4 register tile over rows
574
+ nk_size_t row = 0;
575
+ for (; row + 4 <= row_count; row += 4) {
576
+ nk_u8_t const *a_row_0 = (nk_u8_t const *)((char const *)a_matrix + (row + 0) * a_stride_in_bytes);
577
+ nk_u8_t const *a_row_1 = (nk_u8_t const *)((char const *)a_matrix + (row + 1) * a_stride_in_bytes);
578
+ nk_u8_t const *a_row_2 = (nk_u8_t const *)((char const *)a_matrix + (row + 2) * a_stride_in_bytes);
579
+ nk_u8_t const *a_row_3 = (nk_u8_t const *)((char const *)a_matrix + (row + 3) * a_stride_in_bytes);
580
+ nk_f32_t *c_row_0 = (nk_f32_t *)((char *)c_matrix + (row + 0) * c_stride_in_bytes);
581
+ nk_f32_t *c_row_1 = (nk_f32_t *)((char *)c_matrix + (row + 1) * c_stride_in_bytes);
582
+ nk_f32_t *c_row_2 = (nk_f32_t *)((char *)c_matrix + (row + 2) * c_stride_in_bytes);
583
+ nk_f32_t *c_row_3 = (nk_f32_t *)((char *)c_matrix + (row + 3) * c_stride_in_bytes);
584
+
585
+ for (nk_size_t column = 0; column < column_count; ++column) {
586
+ nk_i8_t const *b_column = packed_data + column * depth_padded;
587
+ // e32m4 vlmax equals e8m1 vlmax (VLEN/8 lanes), so the i32 accumulators
+ // cover every lane touched by the e8m1-vl loop below.
+ nk_size_t vlmax = __riscv_vsetvlmax_e32m4();
588
+ vint32m4_t accumulator_0_i32m4 = __riscv_vmv_v_x_i32m4(0, vlmax);
589
+ vint32m4_t accumulator_1_i32m4 = __riscv_vmv_v_x_i32m4(0, vlmax);
590
+ vint32m4_t accumulator_2_i32m4 = __riscv_vmv_v_x_i32m4(0, vlmax);
591
+ vint32m4_t accumulator_3_i32m4 = __riscv_vmv_v_x_i32m4(0, vlmax);
592
+
593
+ nk_size_t remaining = depth;
594
+ nk_size_t k = 0;
595
+ for (nk_size_t vector_length = 0; remaining > 0; remaining -= vector_length, k += vector_length) {
596
+ vector_length = __riscv_vsetvl_e8m1(remaining);
597
+
598
+ // Load pre-packed i8 B values
599
+ vint8m1_t b_vector_i8m1 = __riscv_vle8_v_i8m1(b_column + k, vector_length);
600
+
601
+ // Load raw e2m3 bytes from each A row and convert via LUT
602
+ vuint8m1_t raw0_u8m1 = __riscv_vle8_v_u8m1(a_row_0 + k, vector_length);
603
+ vuint8m1_t raw1_u8m1 = __riscv_vle8_v_u8m1(a_row_1 + k, vector_length);
604
+ vuint8m1_t raw2_u8m1 = __riscv_vle8_v_u8m1(a_row_2 + k, vector_length);
605
+ vuint8m1_t raw3_u8m1 = __riscv_vle8_v_u8m1(a_row_3 + k, vector_length);
606
+
607
+ // Extract magnitudes and gather from LUT
608
+ vuint8m1_t mag0_u8m1 = __riscv_vand_vx_u8m1(raw0_u8m1, 0x1F, vector_length);
609
+ vuint8m1_t mag1_u8m1 = __riscv_vand_vx_u8m1(raw1_u8m1, 0x1F, vector_length);
610
+ vuint8m1_t mag2_u8m1 = __riscv_vand_vx_u8m1(raw2_u8m1, 0x1F, vector_length);
611
+ vuint8m1_t mag3_u8m1 = __riscv_vand_vx_u8m1(raw3_u8m1, 0x1F, vector_length);
612
+ vuint8m1_t uval0_u8m1 = __riscv_vluxei8_v_u8m1(nk_e2m3_magnitude_lut_rvv_, mag0_u8m1, vector_length);
613
+ vuint8m1_t uval1_u8m1 = __riscv_vluxei8_v_u8m1(nk_e2m3_magnitude_lut_rvv_, mag1_u8m1, vector_length);
614
+ vuint8m1_t uval2_u8m1 = __riscv_vluxei8_v_u8m1(nk_e2m3_magnitude_lut_rvv_, mag2_u8m1, vector_length);
615
+ vuint8m1_t uval3_u8m1 = __riscv_vluxei8_v_u8m1(nk_e2m3_magnitude_lut_rvv_, mag3_u8m1, vector_length);
616
+
617
+ // Apply sign to A: negate where bit 5 is set.
618
+ // B is already signed from packing, so A sign completes the product sign.
619
+ vint8m1_t a_vector_0_i8m1 = __riscv_vreinterpret_v_u8m1_i8m1(uval0_u8m1);
620
+ vbool8_t negated_0_b8 = __riscv_vmsne_vx_u8m1_b8(__riscv_vand_vx_u8m1(raw0_u8m1, 0x20, vector_length),
621
+ 0, vector_length);
622
+ a_vector_0_i8m1 = __riscv_vneg_v_i8m1_mu(negated_0_b8, a_vector_0_i8m1, a_vector_0_i8m1, vector_length);
623
+
624
+ vint8m1_t a_vector_1_i8m1 = __riscv_vreinterpret_v_u8m1_i8m1(uval1_u8m1);
625
+ vbool8_t negated_1_b8 = __riscv_vmsne_vx_u8m1_b8(__riscv_vand_vx_u8m1(raw1_u8m1, 0x20, vector_length),
626
+ 0, vector_length);
627
+ a_vector_1_i8m1 = __riscv_vneg_v_i8m1_mu(negated_1_b8, a_vector_1_i8m1, a_vector_1_i8m1, vector_length);
628
+
629
+ vint8m1_t a_vector_2_i8m1 = __riscv_vreinterpret_v_u8m1_i8m1(uval2_u8m1);
630
+ vbool8_t negated_2_b8 = __riscv_vmsne_vx_u8m1_b8(__riscv_vand_vx_u8m1(raw2_u8m1, 0x20, vector_length),
631
+ 0, vector_length);
632
+ a_vector_2_i8m1 = __riscv_vneg_v_i8m1_mu(negated_2_b8, a_vector_2_i8m1, a_vector_2_i8m1, vector_length);
633
+
634
+ vint8m1_t a_vector_3_i8m1 = __riscv_vreinterpret_v_u8m1_i8m1(uval3_u8m1);
635
+ vbool8_t negated_3_b8 = __riscv_vmsne_vx_u8m1_b8(__riscv_vand_vx_u8m1(raw3_u8m1, 0x20, vector_length),
636
+ 0, vector_length);
637
+ a_vector_3_i8m1 = __riscv_vneg_v_i8m1_mu(negated_3_b8, a_vector_3_i8m1, a_vector_3_i8m1, vector_length);
638
+
639
+ // Widening multiply: i8×i8 → i16, then accumulate: i32 += i16
640
+ vint16m2_t product_0_i16m2 = __riscv_vwmul_vv_i16m2(a_vector_0_i8m1, b_vector_i8m1, vector_length);
641
+ vint16m2_t product_1_i16m2 = __riscv_vwmul_vv_i16m2(a_vector_1_i8m1, b_vector_i8m1, vector_length);
642
+ vint16m2_t product_2_i16m2 = __riscv_vwmul_vv_i16m2(a_vector_2_i8m1, b_vector_i8m1, vector_length);
643
+ vint16m2_t product_3_i16m2 = __riscv_vwmul_vv_i16m2(a_vector_3_i8m1, b_vector_i8m1, vector_length);
644
+ accumulator_0_i32m4 = __riscv_vwadd_wv_i32m4_tu(accumulator_0_i32m4, accumulator_0_i32m4,
645
+ product_0_i16m2, vector_length);
646
+ accumulator_1_i32m4 = __riscv_vwadd_wv_i32m4_tu(accumulator_1_i32m4, accumulator_1_i32m4,
647
+ product_1_i16m2, vector_length);
648
+ accumulator_2_i32m4 = __riscv_vwadd_wv_i32m4_tu(accumulator_2_i32m4, accumulator_2_i32m4,
649
+ product_2_i16m2, vector_length);
650
+ accumulator_3_i32m4 = __riscv_vwadd_wv_i32m4_tu(accumulator_3_i32m4, accumulator_3_i32m4,
651
+ product_3_i16m2, vector_length);
652
+ }
653
+
654
+ // Horizontal reduce and convert to f32 with scaling
655
+ vint32m1_t zero_i32m1 = __riscv_vmv_v_x_i32m1(0, 1);
656
+ c_row_0[column] = (nk_f32_t)__riscv_vmv_x_s_i32m1_i32(
657
+ __riscv_vredsum_vs_i32m4_i32m1(accumulator_0_i32m4, zero_i32m1, vlmax)) *
658
+ lut_scale_reciprocal;
659
+ c_row_1[column] = (nk_f32_t)__riscv_vmv_x_s_i32m1_i32(
660
+ __riscv_vredsum_vs_i32m4_i32m1(accumulator_1_i32m4, zero_i32m1, vlmax)) *
661
+ lut_scale_reciprocal;
662
+ c_row_2[column] = (nk_f32_t)__riscv_vmv_x_s_i32m1_i32(
663
+ __riscv_vredsum_vs_i32m4_i32m1(accumulator_2_i32m4, zero_i32m1, vlmax)) *
664
+ lut_scale_reciprocal;
665
+ c_row_3[column] = (nk_f32_t)__riscv_vmv_x_s_i32m1_i32(
666
+ __riscv_vredsum_vs_i32m4_i32m1(accumulator_3_i32m4, zero_i32m1, vlmax)) *
667
+ lut_scale_reciprocal;
668
+ }
669
+ }
670
+ // Remainder rows (mr < 4)
671
+ for (; row < row_count; ++row) {
672
+ nk_u8_t const *a_row = (nk_u8_t const *)((char const *)a_matrix + row * a_stride_in_bytes);
673
+ nk_f32_t *c_row = (nk_f32_t *)((char *)c_matrix + row * c_stride_in_bytes);
674
+ for (nk_size_t column = 0; column < column_count; ++column) {
675
+ nk_i8_t const *b_column = packed_data + column * depth_padded;
676
+ nk_size_t vlmax = __riscv_vsetvlmax_e32m4();
677
+ vint32m4_t accumulator_i32m4 = __riscv_vmv_v_x_i32m4(0, vlmax);
678
+ nk_size_t remaining = depth;
679
+ nk_size_t k = 0;
680
+ for (nk_size_t vector_length = 0; remaining > 0; remaining -= vector_length, k += vector_length) {
681
+ vector_length = __riscv_vsetvl_e8m1(remaining);
682
+ vint8m1_t b_vector_i8m1 = __riscv_vle8_v_i8m1(b_column + k, vector_length);
683
+ vuint8m1_t raw_a_u8m1 = __riscv_vle8_v_u8m1(a_row + k, vector_length);
684
+ vuint8m1_t mag_a_u8m1 = __riscv_vand_vx_u8m1(raw_a_u8m1, 0x1F, vector_length);
685
+ vuint8m1_t uval_a_u8m1 = __riscv_vluxei8_v_u8m1(nk_e2m3_magnitude_lut_rvv_, mag_a_u8m1, vector_length);
686
+ vint8m1_t a_vector_i8m1 = __riscv_vreinterpret_v_u8m1_i8m1(uval_a_u8m1);
687
+ vbool8_t negated_a_b8 = __riscv_vmsne_vx_u8m1_b8(__riscv_vand_vx_u8m1(raw_a_u8m1, 0x20, vector_length),
688
+ 0, vector_length);
689
+ a_vector_i8m1 = __riscv_vneg_v_i8m1_mu(negated_a_b8, a_vector_i8m1, a_vector_i8m1, vector_length);
690
+ vint16m2_t product_i16m2 = __riscv_vwmul_vv_i16m2(a_vector_i8m1, b_vector_i8m1, vector_length);
691
+ accumulator_i32m4 = __riscv_vwadd_wv_i32m4_tu(accumulator_i32m4, accumulator_i32m4, product_i16m2,
692
+ vector_length);
693
+ }
694
+ vint32m1_t zero_i32m1 = __riscv_vmv_v_x_i32m1(0, 1);
695
+ c_row[column] = (nk_f32_t)__riscv_vmv_x_s_i32m1_i32(
696
+ __riscv_vredsum_vs_i32m4_i32m1(accumulator_i32m4, zero_i32m1, vlmax)) *
697
+ lut_scale_reciprocal;
698
+ }
699
+ }
700
+ }
701
+
702
+ /**
703
+ * @brief Public e2m3 packed GEMM wrapper matching the declared signature in dots.h.
704
+ * Delegates directly to the aligned kernel; strides are in bytes.
+ */
705
+ NK_PUBLIC void nk_dots_packed_e2m3_rvv(nk_e2m3_t const *a, void const *b_packed, nk_f32_t *c, nk_size_t m, nk_size_t n,
706
+ nk_size_t k, nk_size_t a_stride, nk_size_t c_stride) {
707
+ nk_dots_packed_e2m3_rvv_aligned_(a, b_packed, c, m, n, k, a_stride, c_stride);
708
+ }
709
+
710
+ /**
711
+ * @brief Symmetric e2m3 GEMM: C = A * A^T, upper triangle (j >= i) only.
712
+ *
713
+ * Uses integer i8 LUT arithmetic with i32 accumulation, scaled by 1/256.
714
+ * Processes only the rows in [row_start, row_start + row_count) for parallelism.
715
+ * NOTE(review): no mirror store is performed here — only result[i][j] for
+ * j >= i is written; presumably the caller mirrors the lower triangle.
+ */
716
+ NK_PUBLIC void nk_dots_symmetric_e2m3_rvv(nk_e2m3_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
717
+ nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
718
+ nk_size_t row_start, nk_size_t row_count) {
719
+ nk_f32_t const lut_scale_reciprocal = 1.0f / 256.0f;
720
+
721
+ nk_size_t const result_stride_elements = result_stride / sizeof(nk_f32_t);
722
+ nk_size_t const row_end = (row_start + row_count < n_vectors) ? (row_start + row_count) : n_vectors;
723
+
724
+ for (nk_size_t i = row_start; i < row_end; ++i) {
725
+ // e2m3 elements are single bytes, so the byte stride doubles as the
+ // element stride here.
+ nk_u8_t const *a_i = (nk_u8_t const *)vectors + i * stride;
726
+ for (nk_size_t j = i; j < n_vectors; ++j) {
727
+ nk_u8_t const *a_j = (nk_u8_t const *)vectors + j * stride;
728
+ nk_size_t vlmax = __riscv_vsetvlmax_e32m4();
729
+ vint32m4_t accumulator_i32m4 = __riscv_vmv_v_x_i32m4(0, vlmax);
730
+ nk_size_t remaining = depth;
731
+ nk_size_t k = 0;
732
+ for (nk_size_t vector_length = 0; remaining > 0; remaining -= vector_length, k += vector_length) {
733
+ vector_length = __riscv_vsetvl_e8m1(remaining);
734
+ vuint8m1_t raw_i_u8m1 = __riscv_vle8_v_u8m1(a_i + k, vector_length);
735
+ vuint8m1_t raw_j_u8m1 = __riscv_vle8_v_u8m1(a_j + k, vector_length);
736
+
737
+ // Extract magnitudes and gather from LUT
738
+ vuint8m1_t mag_i_u8m1 = __riscv_vand_vx_u8m1(raw_i_u8m1, 0x1F, vector_length);
739
+ vuint8m1_t mag_j_u8m1 = __riscv_vand_vx_u8m1(raw_j_u8m1, 0x1F, vector_length);
740
+ vuint8m1_t uval_i_u8m1 = __riscv_vluxei8_v_u8m1(nk_e2m3_magnitude_lut_rvv_, mag_i_u8m1, vector_length);
741
+ vuint8m1_t uval_j_u8m1 = __riscv_vluxei8_v_u8m1(nk_e2m3_magnitude_lut_rvv_, mag_j_u8m1, vector_length);
742
+
743
+ // Combined sign: XOR sign bits → conditional negate on B side
744
+ vuint8m1_t sign_xor_u8m1 = __riscv_vand_vx_u8m1(
745
+ __riscv_vxor_vv_u8m1(raw_i_u8m1, raw_j_u8m1, vector_length), 0x20, vector_length);
746
+ vbool8_t negate_b8 = __riscv_vmsne_vx_u8m1_b8(sign_xor_u8m1, 0, vector_length);
747
+ vint8m1_t val_i_i8m1 = __riscv_vreinterpret_v_u8m1_i8m1(uval_i_u8m1);
748
+ vint8m1_t val_j_i8m1 = __riscv_vreinterpret_v_u8m1_i8m1(uval_j_u8m1);
749
+ val_j_i8m1 = __riscv_vneg_v_i8m1_mu(negate_b8, val_j_i8m1, val_j_i8m1, vector_length);
750
+
751
+ // Widening multiply: i8×i8 → i16, then accumulate: i32 += i16
752
+ vint16m2_t product_i16m2 = __riscv_vwmul_vv_i16m2(val_i_i8m1, val_j_i8m1, vector_length);
753
+ accumulator_i32m4 = __riscv_vwadd_wv_i32m4_tu(accumulator_i32m4, accumulator_i32m4, product_i16m2,
754
+ vector_length);
755
+ }
756
+ vint32m1_t zero_i32m1 = __riscv_vmv_v_x_i32m1(0, 1);
757
+ nk_f32_t dot = (nk_f32_t)__riscv_vmv_x_s_i32m1_i32(
758
+ __riscv_vredsum_vs_i32m4_i32m1(accumulator_i32m4, zero_i32m1, vlmax)) *
759
+ lut_scale_reciprocal;
760
+ result[i * result_stride_elements + j] = dot;
761
+ }
762
+ }
763
+ }
764
+
765
+ #pragma endregion // Micro Precision E2M3
766
+
767
+ #pragma region Micro Precision E3M2
768
+
769
+ /**
770
+ * @brief Scalar conversion helper: e3m2 byte → signed i16 (value × 16).
771
+ *
772
+ * Extracts 5-bit magnitude, looks up in LUT, applies sign from bit 5.
773
+ * Every e3m2 value × 16 is an exact integer in [-448, +448], requiring i16.
774
+ */
775
+ NK_INTERNAL nk_i16_t nk_e3m2_to_i16_rvv_(nk_u8_t raw) {
776
+ nk_u8_t magnitude = raw & 0x1Fu; // low 5 bits index the magnitude LUT
+ nk_i16_t val = (nk_i16_t)nk_e3m2_magnitude_lut_rvv_[magnitude];
777
+ return (raw & 0x20u) ? (nk_i16_t)(-val) : val; // bit 5 is the sign bit
778
+ }
780
+
781
+ // Byte size of the e3m2 packed buffer: header + i16 column panels (one signed
+ // i16 per e3m2 element) + one f32 norm per column. Padding logic matches
+ // nk_dots_pack_e3m2_rvv.
+ NK_PUBLIC nk_size_t nk_dots_packed_size_e3m2_rvv(nk_size_t column_count, nk_size_t depth) {
782
+ nk_size_t vector_length = __riscv_vsetvlmax_e16m2();
783
+ nk_size_t depth_padded = nk_size_round_up_to_multiple_(depth, vector_length);
784
+ nk_size_t stride_bytes = depth_padded * sizeof(nk_i16_t);
785
+ // Extra vector of padding for power-of-two strides — presumably to avoid
+ // cache-set aliasing; NOTE(review): confirm intent.
+ if (stride_bytes > 0 && (stride_bytes & (stride_bytes - 1)) == 0) depth_padded += vector_length;
786
+ return sizeof(nk_cross_packed_buffer_header_t) + column_count * depth_padded * sizeof(nk_i16_t) +
787
+ column_count * sizeof(nk_f32_t); // per-column norms
788
+ }
789
+
790
+ /**
791
+ * @brief Pack B matrix from e3m2 to signed i16 (value × 16) for integer dot product.
792
+ *
793
+ * Each e3m2 byte is converted to a signed i16 via scalar LUT lookup.
794
+ * Padding values are zeroed. Column-panel layout with depth-contiguous storage.
795
+ */
796
+ NK_PUBLIC void nk_dots_pack_e3m2_rvv(nk_e3m2_t const *b, nk_size_t column_count, nk_size_t depth,
797
+ nk_size_t b_stride_in_bytes, void *b_packed) {
798
+ // Padding must mirror nk_dots_packed_size_e3m2_rvv exactly.
+ nk_size_t vector_length = __riscv_vsetvlmax_e16m2();
799
+ nk_size_t depth_padded = nk_size_round_up_to_multiple_(depth, vector_length);
800
+ nk_size_t stride_bytes = depth_padded * sizeof(nk_i16_t);
801
+ if (stride_bytes > 0 && (stride_bytes & (stride_bytes - 1)) == 0) depth_padded += vector_length;
802
+
803
+ nk_cross_packed_buffer_header_t *header = (nk_cross_packed_buffer_header_t *)b_packed;
804
+ header->column_count = (nk_u32_t)column_count;
805
+ header->depth_dimensions = (nk_u32_t)depth;
806
+ header->depth_padded_values = (nk_u32_t)depth_padded;
807
+
808
+ nk_i16_t *packed = (nk_i16_t *)((char *)b_packed + sizeof(nk_cross_packed_buffer_header_t));
809
+ nk_size_t total = column_count * depth_padded;
810
+ // Zero everything first so padding lanes contribute 0 to the dot products.
+ for (nk_size_t i = 0; i < total; ++i) packed[i] = 0;
811
+
812
+ for (nk_size_t column = 0; column < column_count; ++column) {
813
+ nk_u8_t const *src = (nk_u8_t const *)((char const *)b + column * b_stride_in_bytes);
814
+ nk_i16_t *dst = packed + column * depth_padded;
815
+ for (nk_size_t k = 0; k < depth; ++k) dst[k] = nk_e3m2_to_i16_rvv_(src[k]);
816
+ }
817
+
818
+ // Append per-column norms after packed data
819
+ nk_f32_t *norms = (nk_f32_t *)(packed + total);
820
+ for (nk_size_t column = 0; column < column_count; ++column) {
821
+ nk_e3m2_t const *src = (nk_e3m2_t const *)((char const *)b + column * b_stride_in_bytes);
822
+ norms[column] = nk_dots_reduce_sumsq_e3m2_(src, depth);
823
+ }
824
+ }
825
+
826
+ /**
827
+ * @brief e3m2 packed GEMM kernel: C += A * B_packed^T with integer i16 LUT arithmetic.
828
+ *
829
+ * Vectorizes over the depth dimension (k). For each (row, column) pair:
830
+ * - Load raw e3m2 bytes from A, convert to signed i16 via `vluxei16` gather LUT
831
+ * - Load pre-packed i16 values from B
832
+ * - Widening multiply-accumulate: i16×i16 → i32 via `vwmacc`
833
+ * - Final result = i32_sum / 256.0f
834
+ *
835
+ * Register tile: process 2 rows per iteration (rows_per_tile=2, wider i16/i32 elements reduce VL).
836
+ * The LUT gather on A magnitudes uses `vluxei16_v_u16m2` (16-bit indexed 16-bit gather).
837
+ */
838
+ NK_INTERNAL void nk_dots_packed_e3m2_rvv_aligned_(nk_e3m2_t const *a_matrix, void const *b_packed_buffer,
839
+ nk_f32_t *c_matrix, nk_size_t row_count, nk_size_t column_count,
840
+ nk_size_t depth, nk_size_t a_stride_in_bytes,
841
+ nk_size_t c_stride_in_bytes) {
842
+ nk_f32_t const lut_scale_reciprocal = 1.0f / 256.0f;
843
+
844
+ nk_cross_packed_buffer_header_t const *header = (nk_cross_packed_buffer_header_t const *)b_packed_buffer;
845
+ nk_size_t const depth_padded = header->depth_padded_values;
846
+ nk_i16_t const *packed_data = (nk_i16_t const *)((char const *)b_packed_buffer +
847
+ sizeof(nk_cross_packed_buffer_header_t));
848
+
849
+ // Zero output matrix
850
+ for (nk_size_t i = 0; i < row_count; ++i) {
851
+ nk_f32_t *c_row = (nk_f32_t *)((char *)c_matrix + i * c_stride_in_bytes);
852
+ for (nk_size_t j = 0; j < column_count; ++j) c_row[j] = 0;
853
+ }
854
+
855
+ // mr=2 register tile (i16 at LMUL=2 and i32 at LMUL=4 leaves fewer spare registers)
856
+ nk_size_t row = 0;
857
+ for (; row + 2 <= row_count; row += 2) {
858
+ nk_u8_t const *a_row_0 = (nk_u8_t const *)((char const *)a_matrix + (row + 0) * a_stride_in_bytes);
859
+ nk_u8_t const *a_row_1 = (nk_u8_t const *)((char const *)a_matrix + (row + 1) * a_stride_in_bytes);
860
+ nk_f32_t *c_row_0 = (nk_f32_t *)((char *)c_matrix + (row + 0) * c_stride_in_bytes);
861
+ nk_f32_t *c_row_1 = (nk_f32_t *)((char *)c_matrix + (row + 1) * c_stride_in_bytes);
862
+
863
+ for (nk_size_t column = 0; column < column_count; ++column) {
864
+ nk_i16_t const *b_column = packed_data + column * depth_padded;
865
+ nk_size_t vlmax = __riscv_vsetvlmax_e32m4();
866
+ vint32m4_t accumulator_0_i32m4 = __riscv_vmv_v_x_i32m4(0, vlmax);
867
+ vint32m4_t accumulator_1_i32m4 = __riscv_vmv_v_x_i32m4(0, vlmax);
868
+
869
+ nk_size_t remaining = depth;
870
+ nk_size_t k = 0;
871
+ for (nk_size_t vector_length = 0; remaining > 0; remaining -= vector_length, k += vector_length) {
872
+ vector_length = __riscv_vsetvl_e16m2(remaining);
873
+
874
+ // Load pre-packed i16 B values
875
+ vint16m2_t b_vector_i16m2 = __riscv_vle16_v_i16m2(b_column + k, vector_length);
876
+
877
+ // Load raw e3m2 bytes from each A row
878
+ vuint8m1_t raw0_u8m1 = __riscv_vle8_v_u8m1(a_row_0 + k, vector_length);
879
+ vuint8m1_t raw1_u8m1 = __riscv_vle8_v_u8m1(a_row_1 + k, vector_length);
880
+
881
+ // Extract magnitudes, zero-extend to u16, compute byte offsets for i16 LUT gather
882
+ vuint8m1_t mag0_u8m1 = __riscv_vand_vx_u8m1(raw0_u8m1, 0x1F, vector_length);
883
+ vuint8m1_t mag1_u8m1 = __riscv_vand_vx_u8m1(raw1_u8m1, 0x1F, vector_length);
884
+ vuint16m2_t idx0_u16m2 = __riscv_vzext_vf2_u16m2(mag0_u8m1, vector_length);
885
+ vuint16m2_t idx1_u16m2 = __riscv_vzext_vf2_u16m2(mag1_u8m1, vector_length);
886
+ vuint16m2_t off0_u16m2 = __riscv_vsll_vx_u16m2(idx0_u16m2, 1,
887
+ vector_length); // byte offsets = index × 2
888
+ vuint16m2_t off1_u16m2 = __riscv_vsll_vx_u16m2(idx1_u16m2, 1, vector_length);
889
+
890
+ // Gather unsigned magnitudes from i16 LUT
891
+ vuint16m2_t uval0_u16m2 = __riscv_vluxei16_v_u16m2(nk_e3m2_magnitude_lut_rvv_, off0_u16m2,
892
+ vector_length);
893
+ vuint16m2_t uval1_u16m2 = __riscv_vluxei16_v_u16m2(nk_e3m2_magnitude_lut_rvv_, off1_u16m2,
894
+ vector_length);
895
+
896
+ // Apply sign: negate where bit 5 is set
897
+ vuint8m1_t sign0_u8m1 = __riscv_vand_vx_u8m1(raw0_u8m1, 0x20, vector_length);
898
+ vuint8m1_t sign1_u8m1 = __riscv_vand_vx_u8m1(raw1_u8m1, 0x20, vector_length);
899
+ vbool8_t negated_0_b8 = __riscv_vmsne_vx_u8m1_b8(sign0_u8m1, 0, vector_length);
900
+ vbool8_t negated_1_b8 = __riscv_vmsne_vx_u8m1_b8(sign1_u8m1, 0, vector_length);
901
+
902
+ vint16m2_t a_vector_0_i16m2 = __riscv_vreinterpret_v_u16m2_i16m2(uval0_u16m2);
903
+ a_vector_0_i16m2 = __riscv_vneg_v_i16m2_mu(negated_0_b8, a_vector_0_i16m2, a_vector_0_i16m2,
904
+ vector_length);
905
+ vint16m2_t a_vector_1_i16m2 = __riscv_vreinterpret_v_u16m2_i16m2(uval1_u16m2);
906
+ a_vector_1_i16m2 = __riscv_vneg_v_i16m2_mu(negated_1_b8, a_vector_1_i16m2, a_vector_1_i16m2,
907
+ vector_length);
908
+
909
+ // Widening multiply-accumulate: i16×i16 → i32
910
+ accumulator_0_i32m4 = __riscv_vwmacc_vv_i32m4_tu(accumulator_0_i32m4, a_vector_0_i16m2, b_vector_i16m2,
911
+ vector_length);
912
+ accumulator_1_i32m4 = __riscv_vwmacc_vv_i32m4_tu(accumulator_1_i32m4, a_vector_1_i16m2, b_vector_i16m2,
913
+ vector_length);
914
+ }
915
+
916
+ // Horizontal reduce and convert to f32 with scaling
917
+ vint32m1_t zero_i32m1 = __riscv_vmv_v_x_i32m1(0, 1);
918
+ c_row_0[column] = (nk_f32_t)__riscv_vmv_x_s_i32m1_i32(
919
+ __riscv_vredsum_vs_i32m4_i32m1(accumulator_0_i32m4, zero_i32m1, vlmax)) *
920
+ lut_scale_reciprocal;
921
+ c_row_1[column] = (nk_f32_t)__riscv_vmv_x_s_i32m1_i32(
922
+ __riscv_vredsum_vs_i32m4_i32m1(accumulator_1_i32m4, zero_i32m1, vlmax)) *
923
+ lut_scale_reciprocal;
924
+ }
925
+ }
926
+ // Remainder rows
927
+ for (; row < row_count; ++row) {
928
+ nk_u8_t const *a_row = (nk_u8_t const *)((char const *)a_matrix + row * a_stride_in_bytes);
929
+ nk_f32_t *c_row = (nk_f32_t *)((char *)c_matrix + row * c_stride_in_bytes);
930
+ for (nk_size_t column = 0; column < column_count; ++column) {
931
+ nk_i16_t const *b_column = packed_data + column * depth_padded;
932
+ nk_size_t vlmax = __riscv_vsetvlmax_e32m4();
933
+ vint32m4_t accumulator_i32m4 = __riscv_vmv_v_x_i32m4(0, vlmax);
934
+ nk_size_t remaining = depth;
935
+ nk_size_t k = 0;
936
+ for (nk_size_t vector_length = 0; remaining > 0; remaining -= vector_length, k += vector_length) {
937
+ vector_length = __riscv_vsetvl_e16m2(remaining);
938
+ vint16m2_t b_vector_i16m2 = __riscv_vle16_v_i16m2(b_column + k, vector_length);
939
+ vuint8m1_t raw_a_u8m1 = __riscv_vle8_v_u8m1(a_row + k, vector_length);
940
+ vuint8m1_t mag_a_u8m1 = __riscv_vand_vx_u8m1(raw_a_u8m1, 0x1F, vector_length);
941
+ vuint16m2_t idx_a_u16m2 = __riscv_vzext_vf2_u16m2(mag_a_u8m1, vector_length);
942
+ vuint16m2_t off_a_u16m2 = __riscv_vsll_vx_u16m2(idx_a_u16m2, 1, vector_length);
943
+ vuint16m2_t uval_a_u16m2 = __riscv_vluxei16_v_u16m2(nk_e3m2_magnitude_lut_rvv_, off_a_u16m2,
944
+ vector_length);
945
+ vint16m2_t a_vector_i16m2 = __riscv_vreinterpret_v_u16m2_i16m2(uval_a_u16m2);
946
+ vbool8_t negated_a_b8 = __riscv_vmsne_vx_u8m1_b8(__riscv_vand_vx_u8m1(raw_a_u8m1, 0x20, vector_length),
947
+ 0, vector_length);
948
+ a_vector_i16m2 = __riscv_vneg_v_i16m2_mu(negated_a_b8, a_vector_i16m2, a_vector_i16m2, vector_length);
949
+ accumulator_i32m4 = __riscv_vwmacc_vv_i32m4_tu(accumulator_i32m4, a_vector_i16m2, b_vector_i16m2,
950
+ vector_length);
951
+ }
952
+ vint32m1_t zero_i32m1 = __riscv_vmv_v_x_i32m1(0, 1);
953
+ c_row[column] = (nk_f32_t)__riscv_vmv_x_s_i32m1_i32(
954
+ __riscv_vredsum_vs_i32m4_i32m1(accumulator_i32m4, zero_i32m1, vlmax)) *
955
+ lut_scale_reciprocal;
956
+ }
957
+ }
958
+ }
959
+
960
+ /**
961
+ * @brief Public e3m2 packed GEMM wrapper matching the declared signature in dots.h.
962
+ */
963
+ NK_PUBLIC void nk_dots_packed_e3m2_rvv(nk_e3m2_t const *a, void const *b_packed, nk_f32_t *c, nk_size_t m, nk_size_t n,
964
+ nk_size_t k, nk_size_t a_stride, nk_size_t c_stride) {
965
+ nk_dots_packed_e3m2_rvv_aligned_(a, b_packed, c, m, n, k, a_stride, c_stride);
966
+ }
967
+
968
/**
 * @brief Symmetric e3m2 GEMM: result[i][j] = vectors[i] · vectors[j].
 *
 * Uses integer i16 LUT arithmetic with i32 widening MAC, scaled back by 1/256
 * (LUT entries appear to be fixed-point magnitudes × 256 — confirm against the
 * `nk_e3m2_magnitude_lut_rvv_` definition). Each e3m2 byte's low 5 bits index
 * the LUT; bit 5 is the sign, applied with a masked vector negate.
 *
 * Processes only the rows in [row_start, row_start + row_count) for parallelism.
 * NOTE(review): only the upper triangle (j >= i) of `result` is written here;
 * mirroring to the lower triangle is not done in this function — confirm that
 * callers perform it if they need the full matrix.
 *
 * `stride` and `result_stride` are in bytes; `result_stride` must be a
 * multiple of sizeof(nk_f32_t).
 */
NK_PUBLIC void nk_dots_symmetric_e3m2_rvv(nk_e3m2_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
                                          nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
                                          nk_size_t row_start, nk_size_t row_count) {
    // Undo the LUT's fixed-point scaling once per completed dot product.
    nk_f32_t const lut_scale_reciprocal = 1.0f / 256.0f;

    nk_size_t const result_stride_elements = result_stride / sizeof(nk_f32_t);
    // Clamp the requested row window so callers may over-request the final chunk.
    nk_size_t const row_end = (row_start + row_count < n_vectors) ? (row_start + row_count) : n_vectors;

    for (nk_size_t i = row_start; i < row_end; ++i) {
        nk_u8_t const *a_i = (nk_u8_t const *)vectors + i * stride;
        for (nk_size_t j = i; j < n_vectors; ++j) {
            nk_u8_t const *a_j = (nk_u8_t const *)vectors + j * stride;
            nk_size_t vlmax = __riscv_vsetvlmax_e32m4();
            vint32m4_t accumulator_i32m4 = __riscv_vmv_v_x_i32m4(0, vlmax);
            nk_size_t remaining = depth;
            nk_size_t k = 0;
            // Strip-mine over depth; `vsetvl` shrinks the final partial vector.
            for (nk_size_t vector_length = 0; remaining > 0; remaining -= vector_length, k += vector_length) {
                vector_length = __riscv_vsetvl_e16m2(remaining);
                vuint8m1_t raw_i_u8m1 = __riscv_vle8_v_u8m1(a_i + k, vector_length);
                vuint8m1_t raw_j_u8m1 = __riscv_vle8_v_u8m1(a_j + k, vector_length);

                // Extract magnitudes (low 5 bits), zero-extend to u16, compute byte offsets
                vuint8m1_t mag_i_u8m1 = __riscv_vand_vx_u8m1(raw_i_u8m1, 0x1F, vector_length);
                vuint8m1_t mag_j_u8m1 = __riscv_vand_vx_u8m1(raw_j_u8m1, 0x1F, vector_length);
                vuint16m2_t idx_i_u16m2 = __riscv_vzext_vf2_u16m2(mag_i_u8m1, vector_length);
                vuint16m2_t idx_j_u16m2 = __riscv_vzext_vf2_u16m2(mag_j_u8m1, vector_length);
                // Indexed loads take byte offsets, so scale the i16 indices by 2.
                vuint16m2_t off_i_u16m2 = __riscv_vsll_vx_u16m2(idx_i_u16m2, 1, vector_length);
                vuint16m2_t off_j_u16m2 = __riscv_vsll_vx_u16m2(idx_j_u16m2, 1, vector_length);

                // Gather unsigned magnitudes from the i16 LUT
                vuint16m2_t uval_i_u16m2 = __riscv_vluxei16_v_u16m2(nk_e3m2_magnitude_lut_rvv_, off_i_u16m2,
                                                                    vector_length);
                vuint16m2_t uval_j_u16m2 = __riscv_vluxei16_v_u16m2(nk_e3m2_magnitude_lut_rvv_, off_j_u16m2,
                                                                    vector_length);

                // Apply individual signs: bit 5 set means the lane is negative
                vuint8m1_t sign_i_u8m1 = __riscv_vand_vx_u8m1(raw_i_u8m1, 0x20, vector_length);
                vuint8m1_t sign_j_u8m1 = __riscv_vand_vx_u8m1(raw_j_u8m1, 0x20, vector_length);
                vbool8_t negated_i_b8 = __riscv_vmsne_vx_u8m1_b8(sign_i_u8m1, 0, vector_length);
                vbool8_t negated_j_b8 = __riscv_vmsne_vx_u8m1_b8(sign_j_u8m1, 0, vector_length);

                vint16m2_t val_i_i16m2 = __riscv_vreinterpret_v_u16m2_i16m2(uval_i_u16m2);
                val_i_i16m2 = __riscv_vneg_v_i16m2_mu(negated_i_b8, val_i_i16m2, val_i_i16m2, vector_length);
                vint16m2_t val_j_i16m2 = __riscv_vreinterpret_v_u16m2_i16m2(uval_j_u16m2);
                val_j_i16m2 = __riscv_vneg_v_i16m2_mu(negated_j_b8, val_j_i16m2, val_j_i16m2, vector_length);

                // Widening multiply-accumulate: i16×i16 → i32. The tail-undisturbed
                // (_tu) policy preserves accumulator lanes past vector_length.
                accumulator_i32m4 = __riscv_vwmacc_vv_i32m4_tu(accumulator_i32m4, val_i_i16m2, val_j_i16m2,
                                                               vector_length);
            }
            // Horizontal reduce across vlmax lanes, then convert and rescale.
            vint32m1_t zero_i32m1 = __riscv_vmv_v_x_i32m1(0, 1);
            nk_f32_t dot = (nk_f32_t)__riscv_vmv_x_s_i32m1_i32(
                               __riscv_vredsum_vs_i32m4_i32m1(accumulator_i32m4, zero_i32m1, vlmax)) *
                           lut_scale_reciprocal;
            result[i * result_stride_elements + j] = dot;
        }
    }
}
1032
+
1033
+ #pragma endregion // Micro Precision E3M2
1034
+
1035
+ #pragma region Brain Float 16
1036
+
1037
+ /**
1038
+ * @brief Compute the packed buffer size for bf16 GEMM (B stored as f32).
1039
+ *
1040
+ * VL is determined by `__riscv_vsetvlmax_e32m2()` since B is stored as f32.
1041
+ * Layout: column-panel with depth-contiguous f32 values, cache-line padding.
1042
+ */
1043
+ NK_PUBLIC nk_size_t nk_dots_packed_size_bf16_rvv(nk_size_t column_count, nk_size_t depth) {
1044
+ nk_size_t vector_length = __riscv_vsetvlmax_e32m2();
1045
+ nk_size_t depth_padded = nk_size_round_up_to_multiple_(depth, vector_length);
1046
+ // Break power-of-2 strides for cache associativity
1047
+ nk_size_t stride_bytes = depth_padded * sizeof(nk_f32_t);
1048
+ if (stride_bytes > 0 && (stride_bytes & (stride_bytes - 1)) == 0) depth_padded += vector_length;
1049
+ return sizeof(nk_cross_packed_buffer_header_t) + column_count * depth_padded * sizeof(nk_f32_t) +
1050
+ column_count * sizeof(nk_f32_t); // per-column norms
1051
+ }
1052
+
1053
+ /**
1054
+ * @brief Pack B matrix from bf16 to f32 for widened dot product.
1055
+ *
1056
+ * Each bf16 value is converted to f32 via bit shift (bf16 is the upper 16 bits of f32).
1057
+ * Padding values are zeroed. Column-panel layout with depth-contiguous storage.
1058
+ */
1059
+ NK_PUBLIC void nk_dots_pack_bf16_rvv(nk_bf16_t const *b, nk_size_t column_count, nk_size_t depth,
1060
+ nk_size_t b_stride_in_bytes, void *b_packed) {
1061
+ nk_size_t vector_length = __riscv_vsetvlmax_e32m2();
1062
+ nk_size_t depth_padded = nk_size_round_up_to_multiple_(depth, vector_length);
1063
+ nk_size_t stride_bytes = depth_padded * sizeof(nk_f32_t);
1064
+ if (stride_bytes > 0 && (stride_bytes & (stride_bytes - 1)) == 0) depth_padded += vector_length;
1065
+
1066
+ nk_cross_packed_buffer_header_t *header = (nk_cross_packed_buffer_header_t *)b_packed;
1067
+ header->column_count = (nk_u32_t)column_count;
1068
+ header->depth_dimensions = (nk_u32_t)depth;
1069
+ header->depth_padded_values = (nk_u32_t)depth_padded;
1070
+
1071
+ nk_f32_t *packed = (nk_f32_t *)((char *)b_packed + sizeof(nk_cross_packed_buffer_header_t));
1072
+ nk_size_t total = column_count * depth_padded;
1073
+ for (nk_size_t i = 0; i < total; ++i) packed[i] = 0;
1074
+
1075
+ for (nk_size_t column = 0; column < column_count; ++column) {
1076
+ nk_u16_t const *src = (nk_u16_t const *)((char const *)b + column * b_stride_in_bytes);
1077
+ nk_f32_t *dst = packed + column * depth_padded;
1078
+ for (nk_size_t k = 0; k < depth; ++k) {
1079
+ union {
1080
+ nk_u32_t u;
1081
+ nk_f32_t f;
1082
+ } conv;
1083
+ conv.u = (nk_u32_t)src[k] << 16;
1084
+ dst[k] = conv.f;
1085
+ }
1086
+ }
1087
+
1088
+ // Append per-column norms after packed data
1089
+ nk_f32_t *norms = (nk_f32_t *)(packed + total);
1090
+ for (nk_size_t column = 0; column < column_count; ++column) {
1091
+ nk_bf16_t const *src = (nk_bf16_t const *)((char const *)b + column * b_stride_in_bytes);
1092
+ norms[column] = nk_dots_reduce_sumsq_bf16_(src, depth);
1093
+ }
1094
+ }
1095
+
1096
/**
 * @brief bf16 packed GEMM kernel: C = A · B_packedᵀ with f64 widened accumulation.
 *
 * The output matrix is zeroed up front, so this computes (not accumulates into)
 * C. Vectorizes over the depth dimension (k). For each (row, column) pair:
 * - Load A as u16m1 and convert to f32m2 via `nk_bf16m1_to_f32m2_rvv_`
 * - Load B as f32m2 directly (pre-packed by `nk_dots_pack_bf16_rvv`)
 * - Accumulate via `vfwmacc_vv_f64m4`, which widens both f32 operands to f64
 * - Horizontal reduce and narrow to f32 on store
 *
 * Register tile: 4 rows share each B column load per depth step (mr=4), with a
 * single-row loop covering the remainder. All strides are in bytes.
 */
NK_INTERNAL void nk_dots_packed_bf16_rvv_aligned_(nk_bf16_t const *a_matrix, void const *b_packed_buffer,
                                                  nk_f32_t *c_matrix, nk_size_t row_count, nk_size_t column_count,
                                                  nk_size_t depth, nk_size_t a_stride_in_bytes,
                                                  nk_size_t c_stride_in_bytes) {
    // The packed buffer starts with a header describing the panel layout.
    nk_cross_packed_buffer_header_t const *header = (nk_cross_packed_buffer_header_t const *)b_packed_buffer;
    nk_size_t const depth_padded = header->depth_padded_values;
    nk_f32_t const *packed_data = (nk_f32_t const *)((char const *)b_packed_buffer +
                                                     sizeof(nk_cross_packed_buffer_header_t));

    // Zero output matrix
    for (nk_size_t i = 0; i < row_count; ++i) {
        nk_f32_t *c_row = (nk_f32_t *)((char *)c_matrix + i * c_stride_in_bytes);
        for (nk_size_t j = 0; j < column_count; ++j) c_row[j] = 0;
    }

    // mr=4 register tile over rows
    nk_size_t row = 0;
    for (; row + 4 <= row_count; row += 4) {
        nk_u16_t const *a_row_0 = (nk_u16_t const *)((char const *)a_matrix + (row + 0) * a_stride_in_bytes);
        nk_u16_t const *a_row_1 = (nk_u16_t const *)((char const *)a_matrix + (row + 1) * a_stride_in_bytes);
        nk_u16_t const *a_row_2 = (nk_u16_t const *)((char const *)a_matrix + (row + 2) * a_stride_in_bytes);
        nk_u16_t const *a_row_3 = (nk_u16_t const *)((char const *)a_matrix + (row + 3) * a_stride_in_bytes);
        nk_f32_t *c_row_0 = (nk_f32_t *)((char *)c_matrix + (row + 0) * c_stride_in_bytes);
        nk_f32_t *c_row_1 = (nk_f32_t *)((char *)c_matrix + (row + 1) * c_stride_in_bytes);
        nk_f32_t *c_row_2 = (nk_f32_t *)((char *)c_matrix + (row + 2) * c_stride_in_bytes);
        nk_f32_t *c_row_3 = (nk_f32_t *)((char *)c_matrix + (row + 3) * c_stride_in_bytes);

        for (nk_size_t column = 0; column < column_count; ++column) {
            nk_f32_t const *b_column = packed_data + column * depth_padded;
            nk_size_t vlmax = __riscv_vsetvlmax_e32m2();
            vfloat64m4_t accumulator_0_f64m4 = __riscv_vfmv_v_f_f64m4(0.0, vlmax);
            vfloat64m4_t accumulator_1_f64m4 = __riscv_vfmv_v_f_f64m4(0.0, vlmax);
            vfloat64m4_t accumulator_2_f64m4 = __riscv_vfmv_v_f_f64m4(0.0, vlmax);
            vfloat64m4_t accumulator_3_f64m4 = __riscv_vfmv_v_f_f64m4(0.0, vlmax);

            nk_size_t remaining = depth;
            nk_size_t k = 0;
            // Strip-mine over depth; `vsetvl` shrinks the final partial vector.
            for (nk_size_t vector_length = 0; remaining > 0; remaining -= vector_length, k += vector_length) {
                vector_length = __riscv_vsetvl_e32m2(remaining);
                // One B load serves all four row accumulators.
                vfloat32m2_t b_vector_f32m2 = __riscv_vle32_v_f32m2(b_column + k, vector_length);
                // Load A as u16m1 and convert to f32m2
                vuint16m1_t a_raw_0_u16m1 = __riscv_vle16_v_u16m1(a_row_0 + k, vector_length);
                vfloat32m2_t a_vector_0_f32m2 = nk_bf16m1_to_f32m2_rvv_(a_raw_0_u16m1, vector_length);
                vuint16m1_t a_raw_1_u16m1 = __riscv_vle16_v_u16m1(a_row_1 + k, vector_length);
                vfloat32m2_t a_vector_1_f32m2 = nk_bf16m1_to_f32m2_rvv_(a_raw_1_u16m1, vector_length);
                vuint16m1_t a_raw_2_u16m1 = __riscv_vle16_v_u16m1(a_row_2 + k, vector_length);
                vfloat32m2_t a_vector_2_f32m2 = nk_bf16m1_to_f32m2_rvv_(a_raw_2_u16m1, vector_length);
                vuint16m1_t a_raw_3_u16m1 = __riscv_vle16_v_u16m1(a_row_3 + k, vector_length);
                vfloat32m2_t a_vector_3_f32m2 = nk_bf16m1_to_f32m2_rvv_(a_raw_3_u16m1, vector_length);
                // Widening MAC; the tail-undisturbed (_tu) policy keeps lanes
                // beyond vector_length intact across partial final iterations.
                accumulator_0_f64m4 = __riscv_vfwmacc_vv_f64m4_tu(accumulator_0_f64m4, a_vector_0_f32m2, b_vector_f32m2,
                                                                  vector_length);
                accumulator_1_f64m4 = __riscv_vfwmacc_vv_f64m4_tu(accumulator_1_f64m4, a_vector_1_f32m2, b_vector_f32m2,
                                                                  vector_length);
                accumulator_2_f64m4 = __riscv_vfwmacc_vv_f64m4_tu(accumulator_2_f64m4, a_vector_2_f32m2, b_vector_f32m2,
                                                                  vector_length);
                accumulator_3_f64m4 = __riscv_vfwmacc_vv_f64m4_tu(accumulator_3_f64m4, a_vector_3_f32m2, b_vector_f32m2,
                                                                  vector_length);
            }

            // Horizontal reduce and narrow to f32
            vfloat64m1_t zero_f64m1 = __riscv_vfmv_v_f_f64m1(0.0, 1);
            c_row_0[column] = (nk_f32_t)__riscv_vfmv_f_s_f64m1_f64(
                __riscv_vfredusum_vs_f64m4_f64m1(accumulator_0_f64m4, zero_f64m1, vlmax));
            c_row_1[column] = (nk_f32_t)__riscv_vfmv_f_s_f64m1_f64(
                __riscv_vfredusum_vs_f64m4_f64m1(accumulator_1_f64m4, zero_f64m1, vlmax));
            c_row_2[column] = (nk_f32_t)__riscv_vfmv_f_s_f64m1_f64(
                __riscv_vfredusum_vs_f64m4_f64m1(accumulator_2_f64m4, zero_f64m1, vlmax));
            c_row_3[column] = (nk_f32_t)__riscv_vfmv_f_s_f64m1_f64(
                __riscv_vfredusum_vs_f64m4_f64m1(accumulator_3_f64m4, zero_f64m1, vlmax));
        }
    }
    // Remainder rows (mr < 4)
    for (; row < row_count; ++row) {
        nk_u16_t const *a_row = (nk_u16_t const *)((char const *)a_matrix + row * a_stride_in_bytes);
        nk_f32_t *c_row = (nk_f32_t *)((char *)c_matrix + row * c_stride_in_bytes);
        for (nk_size_t column = 0; column < column_count; ++column) {
            nk_f32_t const *b_column = packed_data + column * depth_padded;
            nk_size_t vlmax = __riscv_vsetvlmax_e32m2();
            vfloat64m4_t accumulator_f64m4 = __riscv_vfmv_v_f_f64m4(0.0, vlmax);
            nk_size_t remaining = depth;
            nk_size_t k = 0;
            for (nk_size_t vector_length = 0; remaining > 0; remaining -= vector_length, k += vector_length) {
                vector_length = __riscv_vsetvl_e32m2(remaining);
                vfloat32m2_t b_vector_f32m2 = __riscv_vle32_v_f32m2(b_column + k, vector_length);
                vuint16m1_t a_raw_u16m1 = __riscv_vle16_v_u16m1(a_row + k, vector_length);
                vfloat32m2_t a_vector_f32m2 = nk_bf16m1_to_f32m2_rvv_(a_raw_u16m1, vector_length);
                accumulator_f64m4 = __riscv_vfwmacc_vv_f64m4_tu(accumulator_f64m4, a_vector_f32m2, b_vector_f32m2,
                                                                vector_length);
            }
            vfloat64m1_t zero_f64m1 = __riscv_vfmv_v_f_f64m1(0.0, 1);
            c_row[column] = (nk_f32_t)__riscv_vfmv_f_s_f64m1_f64(
                __riscv_vfredusum_vs_f64m4_f64m1(accumulator_f64m4, zero_f64m1, vlmax));
        }
    }
}
1202
+
1203
+ /**
1204
+ * @brief Public bf16 packed GEMM wrapper matching the declared signature in dots.h.
1205
+ *
1206
+ * Dispatches to the aligned kernel for all cases — RVV's `vsetvl` handles partial
1207
+ * vectors naturally, so no separate edge kernel is needed.
1208
+ */
1209
+ NK_PUBLIC void nk_dots_packed_bf16_rvv(nk_bf16_t const *a, void const *b_packed, nk_f32_t *c, nk_size_t m, nk_size_t n,
1210
+ nk_size_t k, nk_size_t a_stride, nk_size_t c_stride) {
1211
+ nk_dots_packed_bf16_rvv_aligned_(a, b_packed, c, m, n, k, a_stride, c_stride);
1212
+ }
1213
+
1214
/**
 * @brief Symmetric bf16 GEMM: result[i][j] = vectors[i] · vectors[j].
 *
 * Uses f64 widened accumulation via `vfwmacc_vv_f64m4` for precision.
 * Both operands are bf16, loaded as u16 and converted to f32 via
 * `nk_bf16m1_to_f32m2_rvv_`. `stride` and `result_stride` are in bytes;
 * `result_stride` must be a multiple of sizeof(nk_f32_t).
 *
 * Processes only the rows in [row_start, row_start + row_count) for parallelism.
 * NOTE(review): only the upper triangle (j >= i) of `result` is written here;
 * mirroring to the lower triangle is not done in this function — confirm that
 * callers perform it if they need the full matrix.
 */
NK_PUBLIC void nk_dots_symmetric_bf16_rvv(nk_bf16_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
                                          nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
                                          nk_size_t row_start, nk_size_t row_count) {
    nk_size_t const result_stride_elements = result_stride / sizeof(nk_f32_t);
    // Clamp the requested row window so callers may over-request the final chunk.
    nk_size_t const row_end = (row_start + row_count < n_vectors) ? (row_start + row_count) : n_vectors;

    for (nk_size_t i = row_start; i < row_end; ++i) {
        nk_u16_t const *a_i = (nk_u16_t const *)((char const *)vectors + i * stride);
        for (nk_size_t j = i; j < n_vectors; ++j) {
            nk_u16_t const *a_j = (nk_u16_t const *)((char const *)vectors + j * stride);
            nk_size_t vlmax = __riscv_vsetvlmax_e32m2();
            vfloat64m4_t accumulator_f64m4 = __riscv_vfmv_v_f_f64m4(0.0, vlmax);
            nk_size_t remaining = depth;
            nk_size_t k = 0;
            // Strip-mine over depth; `vsetvl` shrinks the final partial vector.
            for (nk_size_t vector_length = 0; remaining > 0; remaining -= vector_length, k += vector_length) {
                vector_length = __riscv_vsetvl_e32m2(remaining);
                vuint16m1_t a_raw_u16m1 = __riscv_vle16_v_u16m1(a_i + k, vector_length);
                vfloat32m2_t a_vector_f32m2 = nk_bf16m1_to_f32m2_rvv_(a_raw_u16m1, vector_length);
                vuint16m1_t b_raw_u16m1 = __riscv_vle16_v_u16m1(a_j + k, vector_length);
                vfloat32m2_t b_vector_f32m2 = nk_bf16m1_to_f32m2_rvv_(b_raw_u16m1, vector_length);
                // Widening MAC; _tu keeps accumulator lanes past vector_length intact.
                accumulator_f64m4 = __riscv_vfwmacc_vv_f64m4_tu(accumulator_f64m4, a_vector_f32m2, b_vector_f32m2,
                                                                vector_length);
            }
            // Horizontal f64 reduction, narrowed to f32 on store.
            vfloat64m1_t zero_f64m1 = __riscv_vfmv_v_f_f64m1(0.0, 1);
            nk_f32_t dot = (nk_f32_t)__riscv_vfmv_f_s_f64m1_f64(
                __riscv_vfredusum_vs_f64m4_f64m1(accumulator_f64m4, zero_f64m1, vlmax));
            result[i * result_stride_elements + j] = dot;
        }
    }
}
1252
+
1253
+ #pragma endregion // Brain Float 16
1254
+
1255
+ #pragma region Half Precision Floats
1256
+
1257
+ /**
1258
+ * @brief Compute the packed buffer size for f16 GEMM (B stored as f32).
1259
+ *
1260
+ * VL is determined by `__riscv_vsetvlmax_e32m2()` since B is stored as f32.
1261
+ * Layout: column-panel with depth-contiguous f32 values, cache-line padding.
1262
+ */
1263
+ NK_PUBLIC nk_size_t nk_dots_packed_size_f16_rvv(nk_size_t column_count, nk_size_t depth) {
1264
+ nk_size_t vector_length = __riscv_vsetvlmax_e32m2();
1265
+ nk_size_t depth_padded = nk_size_round_up_to_multiple_(depth, vector_length);
1266
+ // Break power-of-2 strides for cache associativity
1267
+ nk_size_t stride_bytes = depth_padded * sizeof(nk_f32_t);
1268
+ if (stride_bytes > 0 && (stride_bytes & (stride_bytes - 1)) == 0) depth_padded += vector_length;
1269
+ return sizeof(nk_cross_packed_buffer_header_t) + column_count * depth_padded * sizeof(nk_f32_t) +
1270
+ column_count * sizeof(nk_f32_t); // per-column norms
1271
+ }
1272
+
1273
+ /**
1274
+ * @brief Pack B matrix from f16 to f32 for widened dot product.
1275
+ *
1276
+ * Each f16 value is converted to f32 via `nk_f16_to_f32_serial`.
1277
+ * Padding values are zeroed. Column-panel layout with depth-contiguous storage.
1278
+ */
1279
+ NK_PUBLIC void nk_dots_pack_f16_rvv(nk_f16_t const *b, nk_size_t column_count, nk_size_t depth,
1280
+ nk_size_t b_stride_in_bytes, void *b_packed) {
1281
+ nk_size_t vector_length = __riscv_vsetvlmax_e32m2();
1282
+ nk_size_t depth_padded = nk_size_round_up_to_multiple_(depth, vector_length);
1283
+ nk_size_t stride_bytes = depth_padded * sizeof(nk_f32_t);
1284
+ if (stride_bytes > 0 && (stride_bytes & (stride_bytes - 1)) == 0) depth_padded += vector_length;
1285
+
1286
+ nk_cross_packed_buffer_header_t *header = (nk_cross_packed_buffer_header_t *)b_packed;
1287
+ header->column_count = (nk_u32_t)column_count;
1288
+ header->depth_dimensions = (nk_u32_t)depth;
1289
+ header->depth_padded_values = (nk_u32_t)depth_padded;
1290
+
1291
+ nk_f32_t *packed = (nk_f32_t *)((char *)b_packed + sizeof(nk_cross_packed_buffer_header_t));
1292
+ nk_size_t total = column_count * depth_padded;
1293
+ for (nk_size_t i = 0; i < total; ++i) packed[i] = 0;
1294
+
1295
+ for (nk_size_t column = 0; column < column_count; ++column) {
1296
+ nk_f16_t const *src = (nk_f16_t const *)((char const *)b + column * b_stride_in_bytes);
1297
+ nk_f32_t *dst = packed + column * depth_padded;
1298
+ for (nk_size_t k = 0; k < depth; ++k) nk_f16_to_f32_serial(&src[k], &dst[k]);
1299
+ }
1300
+
1301
+ // Append per-column norms after packed data
1302
+ nk_f32_t *norms = (nk_f32_t *)(packed + total);
1303
+ for (nk_size_t column = 0; column < column_count; ++column) {
1304
+ nk_f16_t const *src = (nk_f16_t const *)((char const *)b + column * b_stride_in_bytes);
1305
+ norms[column] = nk_dots_reduce_sumsq_f16_(src, depth);
1306
+ }
1307
+ }
1308
+
1309
/**
 * @brief f16 packed GEMM kernel: C = A · B_packedᵀ with f64 widened accumulation.
 *
 * The output matrix is zeroed up front, so this computes (not accumulates into)
 * C. Vectorizes over the depth dimension (k). For each (row, column) pair:
 * - Load A as u16m1 and convert to f32m2 via `nk_f16m1_to_f32m2_rvv_`
 * - Load B as f32m2 directly (pre-packed by `nk_dots_pack_f16_rvv`)
 * - Accumulate via `vfwmacc_vv_f64m4`, which widens both f32 operands to f64
 * - Horizontal reduce and narrow to f32 on store
 *
 * Register tile: 4 rows share each B column load per depth step (mr=4), with a
 * single-row loop covering the remainder. All strides are in bytes.
 */
NK_INTERNAL void nk_dots_packed_f16_rvv_aligned_(nk_f16_t const *a_matrix, void const *b_packed_buffer,
                                                 nk_f32_t *c_matrix, nk_size_t row_count, nk_size_t column_count,
                                                 nk_size_t depth, nk_size_t a_stride_in_bytes,
                                                 nk_size_t c_stride_in_bytes) {
    // The packed buffer starts with a header describing the panel layout.
    nk_cross_packed_buffer_header_t const *header = (nk_cross_packed_buffer_header_t const *)b_packed_buffer;
    nk_size_t const depth_padded = header->depth_padded_values;
    nk_f32_t const *packed_data = (nk_f32_t const *)((char const *)b_packed_buffer +
                                                     sizeof(nk_cross_packed_buffer_header_t));

    // Zero output matrix
    for (nk_size_t i = 0; i < row_count; ++i) {
        nk_f32_t *c_row = (nk_f32_t *)((char *)c_matrix + i * c_stride_in_bytes);
        for (nk_size_t j = 0; j < column_count; ++j) c_row[j] = 0;
    }

    // mr=4 register tile over rows
    nk_size_t row = 0;
    for (; row + 4 <= row_count; row += 4) {
        nk_u16_t const *a_row_0 = (nk_u16_t const *)((char const *)a_matrix + (row + 0) * a_stride_in_bytes);
        nk_u16_t const *a_row_1 = (nk_u16_t const *)((char const *)a_matrix + (row + 1) * a_stride_in_bytes);
        nk_u16_t const *a_row_2 = (nk_u16_t const *)((char const *)a_matrix + (row + 2) * a_stride_in_bytes);
        nk_u16_t const *a_row_3 = (nk_u16_t const *)((char const *)a_matrix + (row + 3) * a_stride_in_bytes);
        nk_f32_t *c_row_0 = (nk_f32_t *)((char *)c_matrix + (row + 0) * c_stride_in_bytes);
        nk_f32_t *c_row_1 = (nk_f32_t *)((char *)c_matrix + (row + 1) * c_stride_in_bytes);
        nk_f32_t *c_row_2 = (nk_f32_t *)((char *)c_matrix + (row + 2) * c_stride_in_bytes);
        nk_f32_t *c_row_3 = (nk_f32_t *)((char *)c_matrix + (row + 3) * c_stride_in_bytes);

        for (nk_size_t column = 0; column < column_count; ++column) {
            nk_f32_t const *b_column = packed_data + column * depth_padded;
            nk_size_t vlmax = __riscv_vsetvlmax_e32m2();
            vfloat64m4_t accumulator_0_f64m4 = __riscv_vfmv_v_f_f64m4(0.0, vlmax);
            vfloat64m4_t accumulator_1_f64m4 = __riscv_vfmv_v_f_f64m4(0.0, vlmax);
            vfloat64m4_t accumulator_2_f64m4 = __riscv_vfmv_v_f_f64m4(0.0, vlmax);
            vfloat64m4_t accumulator_3_f64m4 = __riscv_vfmv_v_f_f64m4(0.0, vlmax);

            nk_size_t remaining = depth;
            nk_size_t k = 0;
            // Strip-mine over depth; `vsetvl` shrinks the final partial vector.
            for (nk_size_t vector_length = 0; remaining > 0; remaining -= vector_length, k += vector_length) {
                vector_length = __riscv_vsetvl_e32m2(remaining);
                // One B load serves all four row accumulators.
                vfloat32m2_t b_vector_f32m2 = __riscv_vle32_v_f32m2(b_column + k, vector_length);
                // Load A as u16m1 and convert to f32m2
                vuint16m1_t a_raw_0_u16m1 = __riscv_vle16_v_u16m1(a_row_0 + k, vector_length);
                vfloat32m2_t a_vector_0_f32m2 = nk_f16m1_to_f32m2_rvv_(a_raw_0_u16m1, vector_length);
                vuint16m1_t a_raw_1_u16m1 = __riscv_vle16_v_u16m1(a_row_1 + k, vector_length);
                vfloat32m2_t a_vector_1_f32m2 = nk_f16m1_to_f32m2_rvv_(a_raw_1_u16m1, vector_length);
                vuint16m1_t a_raw_2_u16m1 = __riscv_vle16_v_u16m1(a_row_2 + k, vector_length);
                vfloat32m2_t a_vector_2_f32m2 = nk_f16m1_to_f32m2_rvv_(a_raw_2_u16m1, vector_length);
                vuint16m1_t a_raw_3_u16m1 = __riscv_vle16_v_u16m1(a_row_3 + k, vector_length);
                vfloat32m2_t a_vector_3_f32m2 = nk_f16m1_to_f32m2_rvv_(a_raw_3_u16m1, vector_length);
                // Widening MAC; the tail-undisturbed (_tu) policy keeps lanes
                // beyond vector_length intact across partial final iterations.
                accumulator_0_f64m4 = __riscv_vfwmacc_vv_f64m4_tu(accumulator_0_f64m4, a_vector_0_f32m2, b_vector_f32m2,
                                                                  vector_length);
                accumulator_1_f64m4 = __riscv_vfwmacc_vv_f64m4_tu(accumulator_1_f64m4, a_vector_1_f32m2, b_vector_f32m2,
                                                                  vector_length);
                accumulator_2_f64m4 = __riscv_vfwmacc_vv_f64m4_tu(accumulator_2_f64m4, a_vector_2_f32m2, b_vector_f32m2,
                                                                  vector_length);
                accumulator_3_f64m4 = __riscv_vfwmacc_vv_f64m4_tu(accumulator_3_f64m4, a_vector_3_f32m2, b_vector_f32m2,
                                                                  vector_length);
            }

            // Horizontal reduce and narrow to f32
            vfloat64m1_t zero_f64m1 = __riscv_vfmv_v_f_f64m1(0.0, 1);
            c_row_0[column] = (nk_f32_t)__riscv_vfmv_f_s_f64m1_f64(
                __riscv_vfredusum_vs_f64m4_f64m1(accumulator_0_f64m4, zero_f64m1, vlmax));
            c_row_1[column] = (nk_f32_t)__riscv_vfmv_f_s_f64m1_f64(
                __riscv_vfredusum_vs_f64m4_f64m1(accumulator_1_f64m4, zero_f64m1, vlmax));
            c_row_2[column] = (nk_f32_t)__riscv_vfmv_f_s_f64m1_f64(
                __riscv_vfredusum_vs_f64m4_f64m1(accumulator_2_f64m4, zero_f64m1, vlmax));
            c_row_3[column] = (nk_f32_t)__riscv_vfmv_f_s_f64m1_f64(
                __riscv_vfredusum_vs_f64m4_f64m1(accumulator_3_f64m4, zero_f64m1, vlmax));
        }
    }
    // Remainder rows (mr < 4)
    for (; row < row_count; ++row) {
        nk_u16_t const *a_row = (nk_u16_t const *)((char const *)a_matrix + row * a_stride_in_bytes);
        nk_f32_t *c_row = (nk_f32_t *)((char *)c_matrix + row * c_stride_in_bytes);
        for (nk_size_t column = 0; column < column_count; ++column) {
            nk_f32_t const *b_column = packed_data + column * depth_padded;
            nk_size_t vlmax = __riscv_vsetvlmax_e32m2();
            vfloat64m4_t accumulator_f64m4 = __riscv_vfmv_v_f_f64m4(0.0, vlmax);
            nk_size_t remaining = depth;
            nk_size_t k = 0;
            for (nk_size_t vector_length = 0; remaining > 0; remaining -= vector_length, k += vector_length) {
                vector_length = __riscv_vsetvl_e32m2(remaining);
                vfloat32m2_t b_vector_f32m2 = __riscv_vle32_v_f32m2(b_column + k, vector_length);
                vuint16m1_t a_raw_u16m1 = __riscv_vle16_v_u16m1(a_row + k, vector_length);
                vfloat32m2_t a_vector_f32m2 = nk_f16m1_to_f32m2_rvv_(a_raw_u16m1, vector_length);
                accumulator_f64m4 = __riscv_vfwmacc_vv_f64m4_tu(accumulator_f64m4, a_vector_f32m2, b_vector_f32m2,
                                                                vector_length);
            }
            vfloat64m1_t zero_f64m1 = __riscv_vfmv_v_f_f64m1(0.0, 1);
            c_row[column] = (nk_f32_t)__riscv_vfmv_f_s_f64m1_f64(
                __riscv_vfredusum_vs_f64m4_f64m1(accumulator_f64m4, zero_f64m1, vlmax));
        }
    }
}
1415
+
1416
+ /**
1417
+ * @brief Public f16 packed GEMM wrapper matching the declared signature in dots.h.
1418
+ *
1419
+ * Dispatches to the aligned kernel for all cases — RVV's `vsetvl` handles partial
1420
+ * vectors naturally, so no separate edge kernel is needed.
1421
+ */
1422
+ NK_PUBLIC void nk_dots_packed_f16_rvv(nk_f16_t const *a, void const *b_packed, nk_f32_t *c, nk_size_t m, nk_size_t n,
1423
+ nk_size_t k, nk_size_t a_stride, nk_size_t c_stride) {
1424
+ nk_dots_packed_f16_rvv_aligned_(a, b_packed, c, m, n, k, a_stride, c_stride);
1425
+ }
1426
+
1427
/**
 * @brief Symmetric f16 GEMM: result[i][j] = vectors[i] · vectors[j].
 *
 * Uses f64 widened accumulation via `vfwmacc_vv_f64m4` for precision.
 * Both operands are f16, loaded as u16 and converted to f32 via
 * `nk_f16m1_to_f32m2_rvv_`. `stride` and `result_stride` are in bytes;
 * `result_stride` must be a multiple of sizeof(nk_f32_t).
 *
 * Processes only the rows in [row_start, row_start + row_count) for parallelism.
 * NOTE(review): only the upper triangle (j >= i) of `result` is written here;
 * mirroring to the lower triangle is not done in this function — confirm that
 * callers perform it if they need the full matrix.
 */
NK_PUBLIC void nk_dots_symmetric_f16_rvv(nk_f16_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
                                         nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
                                         nk_size_t row_start, nk_size_t row_count) {
    nk_size_t const result_stride_elements = result_stride / sizeof(nk_f32_t);
    // Clamp the requested row window so callers may over-request the final chunk.
    nk_size_t const row_end = (row_start + row_count < n_vectors) ? (row_start + row_count) : n_vectors;

    for (nk_size_t i = row_start; i < row_end; ++i) {
        nk_u16_t const *a_i = (nk_u16_t const *)((char const *)vectors + i * stride);
        for (nk_size_t j = i; j < n_vectors; ++j) {
            nk_u16_t const *a_j = (nk_u16_t const *)((char const *)vectors + j * stride);
            nk_size_t vlmax = __riscv_vsetvlmax_e32m2();
            vfloat64m4_t accumulator_f64m4 = __riscv_vfmv_v_f_f64m4(0.0, vlmax);
            nk_size_t remaining = depth;
            nk_size_t k = 0;
            // Strip-mine over depth; `vsetvl` shrinks the final partial vector.
            for (nk_size_t vector_length = 0; remaining > 0; remaining -= vector_length, k += vector_length) {
                vector_length = __riscv_vsetvl_e32m2(remaining);
                vuint16m1_t a_raw_u16m1 = __riscv_vle16_v_u16m1(a_i + k, vector_length);
                vfloat32m2_t a_vector_f32m2 = nk_f16m1_to_f32m2_rvv_(a_raw_u16m1, vector_length);
                vuint16m1_t b_raw_u16m1 = __riscv_vle16_v_u16m1(a_j + k, vector_length);
                vfloat32m2_t b_vector_f32m2 = nk_f16m1_to_f32m2_rvv_(b_raw_u16m1, vector_length);
                // Widening MAC; _tu keeps accumulator lanes past vector_length intact.
                accumulator_f64m4 = __riscv_vfwmacc_vv_f64m4_tu(accumulator_f64m4, a_vector_f32m2, b_vector_f32m2,
                                                                vector_length);
            }
            // Horizontal f64 reduction, narrowed to f32 on store.
            vfloat64m1_t zero_f64m1 = __riscv_vfmv_v_f_f64m1(0.0, 1);
            nk_f32_t dot = (nk_f32_t)__riscv_vfmv_f_s_f64m1_f64(
                __riscv_vfredusum_vs_f64m4_f64m1(accumulator_f64m4, zero_f64m1, vlmax));
            result[i * result_stride_elements + j] = dot;
        }
    }
}
1465
+
1466
+ #pragma endregion // Half Precision Floats
1467
+
1468
+ #pragma region Signed 8-bit Integers
1469
+
1470
+ /**
1471
+ * @brief Compute the packed buffer size for i8 GEMM (B stored as i8).
1472
+ *
1473
+ * VL is determined by `__riscv_vsetvlmax_e8m1()` since B is stored as i8.
1474
+ * Layout: column-panel with depth-contiguous i8 values, cache-line padding.
1475
+ */
1476
+ NK_PUBLIC nk_size_t nk_dots_packed_size_i8_rvv(nk_size_t column_count, nk_size_t depth) {
1477
+ nk_size_t vector_length = __riscv_vsetvlmax_e8m1();
1478
+ nk_size_t depth_padded = nk_size_round_up_to_multiple_(depth, vector_length);
1479
+ // Break power-of-2 strides for cache associativity
1480
+ nk_size_t stride_bytes = depth_padded * sizeof(nk_i8_t);
1481
+ if (stride_bytes > 0 && (stride_bytes & (stride_bytes - 1)) == 0) depth_padded += vector_length;
1482
+ return sizeof(nk_cross_packed_buffer_header_t) + column_count * depth_padded * sizeof(nk_i8_t) +
1483
+ column_count * sizeof(nk_u32_t); // per-column norms
1484
+ }
1485
+
1486
+ /**
1487
+ * @brief Pack B matrix from i8 to i8 (direct copy) for integer dot product.
1488
+ *
1489
+ * No conversion needed — values are copied directly.
1490
+ * Padding values are zeroed. Column-panel layout with depth-contiguous storage.
1491
+ */
1492
+ NK_PUBLIC void nk_dots_pack_i8_rvv(nk_i8_t const *b, nk_size_t column_count, nk_size_t depth,
1493
+ nk_size_t b_stride_in_bytes, void *b_packed) {
1494
+ nk_size_t vector_length = __riscv_vsetvlmax_e8m1();
1495
+ nk_size_t depth_padded = nk_size_round_up_to_multiple_(depth, vector_length);
1496
+ nk_size_t stride_bytes = depth_padded * sizeof(nk_i8_t);
1497
+ if (stride_bytes > 0 && (stride_bytes & (stride_bytes - 1)) == 0) depth_padded += vector_length;
1498
+
1499
+ nk_cross_packed_buffer_header_t *header = (nk_cross_packed_buffer_header_t *)b_packed;
1500
+ header->column_count = (nk_u32_t)column_count;
1501
+ header->depth_dimensions = (nk_u32_t)depth;
1502
+ header->depth_padded_values = (nk_u32_t)depth_padded;
1503
+
1504
+ nk_i8_t *packed = (nk_i8_t *)((char *)b_packed + sizeof(nk_cross_packed_buffer_header_t));
1505
+ nk_size_t total = column_count * depth_padded;
1506
+ for (nk_size_t i = 0; i < total; ++i) packed[i] = 0;
1507
+
1508
+ for (nk_size_t column = 0; column < column_count; ++column) {
1509
+ nk_i8_t const *src = (nk_i8_t const *)((char const *)b + column * b_stride_in_bytes);
1510
+ nk_i8_t *dst = packed + column * depth_padded;
1511
+ for (nk_size_t k = 0; k < depth; ++k) dst[k] = src[k];
1512
+ }
1513
+
1514
+ // Append per-column norms after packed data
1515
+ nk_u32_t *norms = (nk_u32_t *)(packed + total);
1516
+ for (nk_size_t column = 0; column < column_count; ++column) {
1517
+ nk_i8_t const *src = (nk_i8_t const *)((char const *)b + column * b_stride_in_bytes);
1518
+ norms[column] = nk_dots_reduce_sumsq_i8_(src, depth);
1519
+ }
1520
+ }
1521
+
1522
+ /**
1523
+ * @brief i8 packed GEMM kernel: C += A * B_packed^T with i32 accumulation.
1524
+ *
1525
+ * Vectorizes over the depth dimension (k). For each (row, column) pair:
1526
+ * - Load i8 values from A and pre-packed i8 values from B
1527
+ * - Widening multiply: i8 x i8 -> i16 via `vwmul`
1528
+ * - Widen-accumulate: i32 += i16 via `vwadd_wv`
1529
+ * - Horizontal reduce via `vredsum`
1530
+ *
1531
+ * Register tile: process 4 rows per iteration (rows_per_tile=4).
1532
+ * Output is nk_i32_t (integer result, no scaling).
1533
+ */
1534
+ NK_INTERNAL void nk_dots_packed_i8_rvv_aligned_(nk_i8_t const *a_matrix, void const *b_packed_buffer,
1535
+ nk_i32_t *c_matrix, nk_size_t row_count, nk_size_t column_count,
1536
+ nk_size_t depth, nk_size_t a_stride_in_bytes,
1537
+ nk_size_t c_stride_in_bytes) {
1538
+ nk_cross_packed_buffer_header_t const *header = (nk_cross_packed_buffer_header_t const *)b_packed_buffer;
1539
+ nk_size_t const depth_padded = header->depth_padded_values;
1540
+ nk_i8_t const *packed_data = (nk_i8_t const *)((char const *)b_packed_buffer +
1541
+ sizeof(nk_cross_packed_buffer_header_t));
1542
+
1543
+ // Zero output matrix
1544
+ for (nk_size_t i = 0; i < row_count; ++i) {
1545
+ nk_i32_t *c_row = (nk_i32_t *)((char *)c_matrix + i * c_stride_in_bytes);
1546
+ for (nk_size_t j = 0; j < column_count; ++j) c_row[j] = 0;
1547
+ }
1548
+
1549
+ // mr=4 register tile over rows
1550
+ nk_size_t row = 0;
1551
+ for (; row + 4 <= row_count; row += 4) {
1552
+ nk_i8_t const *a_row_0 = (nk_i8_t const *)((char const *)a_matrix + (row + 0) * a_stride_in_bytes);
1553
+ nk_i8_t const *a_row_1 = (nk_i8_t const *)((char const *)a_matrix + (row + 1) * a_stride_in_bytes);
1554
+ nk_i8_t const *a_row_2 = (nk_i8_t const *)((char const *)a_matrix + (row + 2) * a_stride_in_bytes);
1555
+ nk_i8_t const *a_row_3 = (nk_i8_t const *)((char const *)a_matrix + (row + 3) * a_stride_in_bytes);
1556
+ nk_i32_t *c_row_0 = (nk_i32_t *)((char *)c_matrix + (row + 0) * c_stride_in_bytes);
1557
+ nk_i32_t *c_row_1 = (nk_i32_t *)((char *)c_matrix + (row + 1) * c_stride_in_bytes);
1558
+ nk_i32_t *c_row_2 = (nk_i32_t *)((char *)c_matrix + (row + 2) * c_stride_in_bytes);
1559
+ nk_i32_t *c_row_3 = (nk_i32_t *)((char *)c_matrix + (row + 3) * c_stride_in_bytes);
1560
+
1561
+ for (nk_size_t column = 0; column < column_count; ++column) {
1562
+ nk_i8_t const *b_column = packed_data + column * depth_padded;
1563
+ nk_size_t vlmax = __riscv_vsetvlmax_e32m4();
1564
+ vint32m4_t accumulator_0_i32m4 = __riscv_vmv_v_x_i32m4(0, vlmax);
1565
+ vint32m4_t accumulator_1_i32m4 = __riscv_vmv_v_x_i32m4(0, vlmax);
1566
+ vint32m4_t accumulator_2_i32m4 = __riscv_vmv_v_x_i32m4(0, vlmax);
1567
+ vint32m4_t accumulator_3_i32m4 = __riscv_vmv_v_x_i32m4(0, vlmax);
1568
+
1569
+ nk_size_t remaining = depth;
1570
+ nk_size_t k = 0;
1571
+ for (nk_size_t vector_length = 0; remaining > 0; remaining -= vector_length, k += vector_length) {
1572
+ vector_length = __riscv_vsetvl_e8m1(remaining);
1573
+ vint8m1_t b_vector_i8m1 = __riscv_vle8_v_i8m1(b_column + k, vector_length);
1574
+ vint8m1_t a_vector_0_i8m1 = __riscv_vle8_v_i8m1(a_row_0 + k, vector_length);
1575
+ vint8m1_t a_vector_1_i8m1 = __riscv_vle8_v_i8m1(a_row_1 + k, vector_length);
1576
+ vint8m1_t a_vector_2_i8m1 = __riscv_vle8_v_i8m1(a_row_2 + k, vector_length);
1577
+ vint8m1_t a_vector_3_i8m1 = __riscv_vle8_v_i8m1(a_row_3 + k, vector_length);
1578
+ vint16m2_t product_0_i16m2 = __riscv_vwmul_vv_i16m2(a_vector_0_i8m1, b_vector_i8m1, vector_length);
1579
+ vint16m2_t product_1_i16m2 = __riscv_vwmul_vv_i16m2(a_vector_1_i8m1, b_vector_i8m1, vector_length);
1580
+ vint16m2_t product_2_i16m2 = __riscv_vwmul_vv_i16m2(a_vector_2_i8m1, b_vector_i8m1, vector_length);
1581
+ vint16m2_t product_3_i16m2 = __riscv_vwmul_vv_i16m2(a_vector_3_i8m1, b_vector_i8m1, vector_length);
1582
+ accumulator_0_i32m4 = __riscv_vwadd_wv_i32m4_tu(accumulator_0_i32m4, accumulator_0_i32m4,
1583
+ product_0_i16m2, vector_length);
1584
+ accumulator_1_i32m4 = __riscv_vwadd_wv_i32m4_tu(accumulator_1_i32m4, accumulator_1_i32m4,
1585
+ product_1_i16m2, vector_length);
1586
+ accumulator_2_i32m4 = __riscv_vwadd_wv_i32m4_tu(accumulator_2_i32m4, accumulator_2_i32m4,
1587
+ product_2_i16m2, vector_length);
1588
+ accumulator_3_i32m4 = __riscv_vwadd_wv_i32m4_tu(accumulator_3_i32m4, accumulator_3_i32m4,
1589
+ product_3_i16m2, vector_length);
1590
+ }
1591
+
1592
+ // Horizontal reduce
1593
+ vint32m1_t zero_i32m1 = __riscv_vmv_v_x_i32m1(0, 1);
1594
+ c_row_0[column] = (nk_i32_t)__riscv_vmv_x_s_i32m1_i32(
1595
+ __riscv_vredsum_vs_i32m4_i32m1(accumulator_0_i32m4, zero_i32m1, vlmax));
1596
+ c_row_1[column] = (nk_i32_t)__riscv_vmv_x_s_i32m1_i32(
1597
+ __riscv_vredsum_vs_i32m4_i32m1(accumulator_1_i32m4, zero_i32m1, vlmax));
1598
+ c_row_2[column] = (nk_i32_t)__riscv_vmv_x_s_i32m1_i32(
1599
+ __riscv_vredsum_vs_i32m4_i32m1(accumulator_2_i32m4, zero_i32m1, vlmax));
1600
+ c_row_3[column] = (nk_i32_t)__riscv_vmv_x_s_i32m1_i32(
1601
+ __riscv_vredsum_vs_i32m4_i32m1(accumulator_3_i32m4, zero_i32m1, vlmax));
1602
+ }
1603
+ }
1604
+ // Remainder rows (mr < 4)
1605
+ for (; row < row_count; ++row) {
1606
+ nk_i8_t const *a_row = (nk_i8_t const *)((char const *)a_matrix + row * a_stride_in_bytes);
1607
+ nk_i32_t *c_row = (nk_i32_t *)((char *)c_matrix + row * c_stride_in_bytes);
1608
+ for (nk_size_t column = 0; column < column_count; ++column) {
1609
+ nk_i8_t const *b_column = packed_data + column * depth_padded;
1610
+ nk_size_t vlmax = __riscv_vsetvlmax_e32m4();
1611
+ vint32m4_t accumulator_i32m4 = __riscv_vmv_v_x_i32m4(0, vlmax);
1612
+ nk_size_t remaining = depth;
1613
+ nk_size_t k = 0;
1614
+ for (nk_size_t vector_length = 0; remaining > 0; remaining -= vector_length, k += vector_length) {
1615
+ vector_length = __riscv_vsetvl_e8m1(remaining);
1616
+ vint8m1_t b_vector_i8m1 = __riscv_vle8_v_i8m1(b_column + k, vector_length);
1617
+ vint8m1_t a_vector_i8m1 = __riscv_vle8_v_i8m1(a_row + k, vector_length);
1618
+ vint16m2_t product_i16m2 = __riscv_vwmul_vv_i16m2(a_vector_i8m1, b_vector_i8m1, vector_length);
1619
+ accumulator_i32m4 = __riscv_vwadd_wv_i32m4_tu(accumulator_i32m4, accumulator_i32m4, product_i16m2,
1620
+ vector_length);
1621
+ }
1622
+ vint32m1_t zero_i32m1 = __riscv_vmv_v_x_i32m1(0, 1);
1623
+ c_row[column] = (nk_i32_t)__riscv_vmv_x_s_i32m1_i32(
1624
+ __riscv_vredsum_vs_i32m4_i32m1(accumulator_i32m4, zero_i32m1, vlmax));
1625
+ }
1626
+ }
1627
+ }
1628
+
1629
+ /**
1630
+ * @brief Public i8 packed GEMM wrapper matching the declared signature in dots.h.
1631
+ *
1632
+ * Dispatches to the aligned kernel for all cases — RVV's `vsetvl` handles partial
1633
+ * vectors naturally, so no separate edge kernel is needed.
1634
+ */
1635
+ NK_PUBLIC void nk_dots_packed_i8_rvv(nk_i8_t const *a, void const *b_packed, nk_i32_t *c, nk_size_t m, nk_size_t n,
1636
+ nk_size_t k, nk_size_t a_stride, nk_size_t c_stride) {
1637
+ nk_dots_packed_i8_rvv_aligned_(a, b_packed, c, m, n, k, a_stride, c_stride);
1638
+ }
1639
+
1640
+ /**
1641
+ * @brief Symmetric i8 GEMM: C = A * A^T, upper triangle + mirror.
1642
+ *
1643
+ * Uses integer i8 arithmetic with i32 accumulation.
1644
+ * Both inputs are i8, widened via i8 x i8 -> i16 -> i32 accumulation.
1645
+ * Stride is in bytes.
1646
+ * Processes only the rows in [row_start, row_start + row_count) for parallelism.
1647
+ */
1648
+ NK_PUBLIC void nk_dots_symmetric_i8_rvv(nk_i8_t const *vectors, nk_size_t n_vectors, nk_size_t depth, nk_size_t stride,
1649
+ nk_i32_t *result, nk_size_t result_stride, nk_size_t row_start,
1650
+ nk_size_t row_count) {
1651
+ nk_size_t const result_stride_elements = result_stride / sizeof(nk_i32_t);
1652
+ nk_size_t const row_end = (row_start + row_count < n_vectors) ? (row_start + row_count) : n_vectors;
1653
+
1654
+ for (nk_size_t i = row_start; i < row_end; ++i) {
1655
+ nk_i8_t const *a_i = (nk_i8_t const *)((char const *)vectors + i * stride);
1656
+ for (nk_size_t j = i; j < n_vectors; ++j) {
1657
+ nk_i8_t const *a_j = (nk_i8_t const *)((char const *)vectors + j * stride);
1658
+ nk_size_t vlmax = __riscv_vsetvlmax_e32m4();
1659
+ vint32m4_t accumulator_i32m4 = __riscv_vmv_v_x_i32m4(0, vlmax);
1660
+ nk_size_t remaining = depth;
1661
+ nk_size_t k = 0;
1662
+ for (nk_size_t vector_length = 0; remaining > 0; remaining -= vector_length, k += vector_length) {
1663
+ vector_length = __riscv_vsetvl_e8m1(remaining);
1664
+ vint8m1_t a_vector_i8m1 = __riscv_vle8_v_i8m1(a_i + k, vector_length);
1665
+ vint8m1_t b_vector_i8m1 = __riscv_vle8_v_i8m1(a_j + k, vector_length);
1666
+ vint16m2_t product_i16m2 = __riscv_vwmul_vv_i16m2(a_vector_i8m1, b_vector_i8m1, vector_length);
1667
+ accumulator_i32m4 = __riscv_vwadd_wv_i32m4_tu(accumulator_i32m4, accumulator_i32m4, product_i16m2,
1668
+ vector_length);
1669
+ }
1670
+ vint32m1_t zero_i32m1 = __riscv_vmv_v_x_i32m1(0, 1);
1671
+ nk_i32_t dot = (nk_i32_t)__riscv_vmv_x_s_i32m1_i32(
1672
+ __riscv_vredsum_vs_i32m4_i32m1(accumulator_i32m4, zero_i32m1, vlmax));
1673
+ result[i * result_stride_elements + j] = dot;
1674
+ }
1675
+ }
1676
+ }
1677
+
1678
+ #pragma endregion // Signed 8-bit Integers
1679
+
1680
+ #pragma region Unsigned 8-bit Integers
1681
+
1682
+ /**
1683
+ * @brief Compute the packed buffer size for u8 GEMM (B stored as u8).
1684
+ *
1685
+ * VL is determined by `__riscv_vsetvlmax_e8m1()` since B is stored as u8.
1686
+ * Layout: column-panel with depth-contiguous u8 values, cache-line padding.
1687
+ */
1688
+ NK_PUBLIC nk_size_t nk_dots_packed_size_u8_rvv(nk_size_t column_count, nk_size_t depth) {
1689
+ nk_size_t vector_length = __riscv_vsetvlmax_e8m1();
1690
+ nk_size_t depth_padded = nk_size_round_up_to_multiple_(depth, vector_length);
1691
+ // Break power-of-2 strides for cache associativity
1692
+ nk_size_t stride_bytes = depth_padded * sizeof(nk_u8_t);
1693
+ if (stride_bytes > 0 && (stride_bytes & (stride_bytes - 1)) == 0) depth_padded += vector_length;
1694
+ return sizeof(nk_cross_packed_buffer_header_t) + column_count * depth_padded * sizeof(nk_u8_t) +
1695
+ column_count * sizeof(nk_u32_t); // per-column norms
1696
+ }
1697
+
1698
+ /**
1699
+ * @brief Pack B matrix from u8 to u8 (direct copy) for integer dot product.
1700
+ *
1701
+ * No conversion needed — values are copied directly.
1702
+ * Padding values are zeroed. Column-panel layout with depth-contiguous storage.
1703
+ */
1704
+ NK_PUBLIC void nk_dots_pack_u8_rvv(nk_u8_t const *b, nk_size_t column_count, nk_size_t depth,
1705
+ nk_size_t b_stride_in_bytes, void *b_packed) {
1706
+ nk_size_t vector_length = __riscv_vsetvlmax_e8m1();
1707
+ nk_size_t depth_padded = nk_size_round_up_to_multiple_(depth, vector_length);
1708
+ nk_size_t stride_bytes = depth_padded * sizeof(nk_u8_t);
1709
+ if (stride_bytes > 0 && (stride_bytes & (stride_bytes - 1)) == 0) depth_padded += vector_length;
1710
+
1711
+ nk_cross_packed_buffer_header_t *header = (nk_cross_packed_buffer_header_t *)b_packed;
1712
+ header->column_count = (nk_u32_t)column_count;
1713
+ header->depth_dimensions = (nk_u32_t)depth;
1714
+ header->depth_padded_values = (nk_u32_t)depth_padded;
1715
+
1716
+ nk_u8_t *packed = (nk_u8_t *)((char *)b_packed + sizeof(nk_cross_packed_buffer_header_t));
1717
+ nk_size_t total = column_count * depth_padded;
1718
+ for (nk_size_t i = 0; i < total; ++i) packed[i] = 0;
1719
+
1720
+ for (nk_size_t column = 0; column < column_count; ++column) {
1721
+ nk_u8_t const *src = (nk_u8_t const *)((char const *)b + column * b_stride_in_bytes);
1722
+ nk_u8_t *dst = packed + column * depth_padded;
1723
+ for (nk_size_t k = 0; k < depth; ++k) dst[k] = src[k];
1724
+ }
1725
+
1726
+ // Append per-column norms after packed data
1727
+ nk_u32_t *norms = (nk_u32_t *)(packed + total);
1728
+ for (nk_size_t column = 0; column < column_count; ++column) {
1729
+ nk_u8_t const *src = (nk_u8_t const *)((char const *)b + column * b_stride_in_bytes);
1730
+ norms[column] = nk_dots_reduce_sumsq_u8_(src, depth);
1731
+ }
1732
+ }
1733
+
1734
+ /**
1735
+ * @brief u8 packed GEMM kernel: C += A * B_packed^T with u32 accumulation.
1736
+ *
1737
+ * Vectorizes over the depth dimension (k). For each (row, column) pair:
1738
+ * - Load u8 values from A and pre-packed u8 values from B
1739
+ * - Widening multiply: u8 x u8 -> u16 via `vwmulu`
1740
+ * - Widen-accumulate: u32 += u16 via `vwaddu_wv`
1741
+ * - Horizontal reduce via `vredsum`
1742
+ *
1743
+ * Register tile: process 4 rows per iteration (rows_per_tile=4).
1744
+ * Output is nk_u32_t (unsigned integer result, no scaling).
1745
+ */
1746
+ NK_INTERNAL void nk_dots_packed_u8_rvv_aligned_(nk_u8_t const *a_matrix, void const *b_packed_buffer,
1747
+ nk_u32_t *c_matrix, nk_size_t row_count, nk_size_t column_count,
1748
+ nk_size_t depth, nk_size_t a_stride_in_bytes,
1749
+ nk_size_t c_stride_in_bytes) {
1750
+ nk_cross_packed_buffer_header_t const *header = (nk_cross_packed_buffer_header_t const *)b_packed_buffer;
1751
+ nk_size_t const depth_padded = header->depth_padded_values;
1752
+ nk_u8_t const *packed_data = (nk_u8_t const *)((char const *)b_packed_buffer +
1753
+ sizeof(nk_cross_packed_buffer_header_t));
1754
+
1755
+ // Zero output matrix
1756
+ for (nk_size_t i = 0; i < row_count; ++i) {
1757
+ nk_u32_t *c_row = (nk_u32_t *)((char *)c_matrix + i * c_stride_in_bytes);
1758
+ for (nk_size_t j = 0; j < column_count; ++j) c_row[j] = 0;
1759
+ }
1760
+
1761
+ // mr=4 register tile over rows
1762
+ nk_size_t row = 0;
1763
+ for (; row + 4 <= row_count; row += 4) {
1764
+ nk_u8_t const *a_row_0 = (nk_u8_t const *)((char const *)a_matrix + (row + 0) * a_stride_in_bytes);
1765
+ nk_u8_t const *a_row_1 = (nk_u8_t const *)((char const *)a_matrix + (row + 1) * a_stride_in_bytes);
1766
+ nk_u8_t const *a_row_2 = (nk_u8_t const *)((char const *)a_matrix + (row + 2) * a_stride_in_bytes);
1767
+ nk_u8_t const *a_row_3 = (nk_u8_t const *)((char const *)a_matrix + (row + 3) * a_stride_in_bytes);
1768
+ nk_u32_t *c_row_0 = (nk_u32_t *)((char *)c_matrix + (row + 0) * c_stride_in_bytes);
1769
+ nk_u32_t *c_row_1 = (nk_u32_t *)((char *)c_matrix + (row + 1) * c_stride_in_bytes);
1770
+ nk_u32_t *c_row_2 = (nk_u32_t *)((char *)c_matrix + (row + 2) * c_stride_in_bytes);
1771
+ nk_u32_t *c_row_3 = (nk_u32_t *)((char *)c_matrix + (row + 3) * c_stride_in_bytes);
1772
+
1773
+ for (nk_size_t column = 0; column < column_count; ++column) {
1774
+ nk_u8_t const *b_column = packed_data + column * depth_padded;
1775
+ nk_size_t vlmax = __riscv_vsetvlmax_e32m4();
1776
+ vuint32m4_t accumulator_0_u32m4 = __riscv_vmv_v_x_u32m4(0, vlmax);
1777
+ vuint32m4_t accumulator_1_u32m4 = __riscv_vmv_v_x_u32m4(0, vlmax);
1778
+ vuint32m4_t accumulator_2_u32m4 = __riscv_vmv_v_x_u32m4(0, vlmax);
1779
+ vuint32m4_t accumulator_3_u32m4 = __riscv_vmv_v_x_u32m4(0, vlmax);
1780
+
1781
+ nk_size_t remaining = depth;
1782
+ nk_size_t k = 0;
1783
+ for (nk_size_t vector_length = 0; remaining > 0; remaining -= vector_length, k += vector_length) {
1784
+ vector_length = __riscv_vsetvl_e8m1(remaining);
1785
+ vuint8m1_t b_vector_u8m1 = __riscv_vle8_v_u8m1(b_column + k, vector_length);
1786
+ vuint8m1_t a_vector_0_u8m1 = __riscv_vle8_v_u8m1(a_row_0 + k, vector_length);
1787
+ vuint8m1_t a_vector_1_u8m1 = __riscv_vle8_v_u8m1(a_row_1 + k, vector_length);
1788
+ vuint8m1_t a_vector_2_u8m1 = __riscv_vle8_v_u8m1(a_row_2 + k, vector_length);
1789
+ vuint8m1_t a_vector_3_u8m1 = __riscv_vle8_v_u8m1(a_row_3 + k, vector_length);
1790
+ vuint16m2_t product_0_u16m2 = __riscv_vwmulu_vv_u16m2(a_vector_0_u8m1, b_vector_u8m1, vector_length);
1791
+ vuint16m2_t product_1_u16m2 = __riscv_vwmulu_vv_u16m2(a_vector_1_u8m1, b_vector_u8m1, vector_length);
1792
+ vuint16m2_t product_2_u16m2 = __riscv_vwmulu_vv_u16m2(a_vector_2_u8m1, b_vector_u8m1, vector_length);
1793
+ vuint16m2_t product_3_u16m2 = __riscv_vwmulu_vv_u16m2(a_vector_3_u8m1, b_vector_u8m1, vector_length);
1794
+ accumulator_0_u32m4 = __riscv_vwaddu_wv_u32m4_tu(accumulator_0_u32m4, accumulator_0_u32m4,
1795
+ product_0_u16m2, vector_length);
1796
+ accumulator_1_u32m4 = __riscv_vwaddu_wv_u32m4_tu(accumulator_1_u32m4, accumulator_1_u32m4,
1797
+ product_1_u16m2, vector_length);
1798
+ accumulator_2_u32m4 = __riscv_vwaddu_wv_u32m4_tu(accumulator_2_u32m4, accumulator_2_u32m4,
1799
+ product_2_u16m2, vector_length);
1800
+ accumulator_3_u32m4 = __riscv_vwaddu_wv_u32m4_tu(accumulator_3_u32m4, accumulator_3_u32m4,
1801
+ product_3_u16m2, vector_length);
1802
+ }
1803
+
1804
+ // Horizontal reduce
1805
+ vuint32m1_t zero_u32m1 = __riscv_vmv_v_x_u32m1(0, 1);
1806
+ c_row_0[column] = (nk_u32_t)__riscv_vmv_x_s_u32m1_u32(
1807
+ __riscv_vredsum_vs_u32m4_u32m1(accumulator_0_u32m4, zero_u32m1, vlmax));
1808
+ c_row_1[column] = (nk_u32_t)__riscv_vmv_x_s_u32m1_u32(
1809
+ __riscv_vredsum_vs_u32m4_u32m1(accumulator_1_u32m4, zero_u32m1, vlmax));
1810
+ c_row_2[column] = (nk_u32_t)__riscv_vmv_x_s_u32m1_u32(
1811
+ __riscv_vredsum_vs_u32m4_u32m1(accumulator_2_u32m4, zero_u32m1, vlmax));
1812
+ c_row_3[column] = (nk_u32_t)__riscv_vmv_x_s_u32m1_u32(
1813
+ __riscv_vredsum_vs_u32m4_u32m1(accumulator_3_u32m4, zero_u32m1, vlmax));
1814
+ }
1815
+ }
1816
+ // Remainder rows (mr < 4)
1817
+ for (; row < row_count; ++row) {
1818
+ nk_u8_t const *a_row = (nk_u8_t const *)((char const *)a_matrix + row * a_stride_in_bytes);
1819
+ nk_u32_t *c_row = (nk_u32_t *)((char *)c_matrix + row * c_stride_in_bytes);
1820
+ for (nk_size_t column = 0; column < column_count; ++column) {
1821
+ nk_u8_t const *b_column = packed_data + column * depth_padded;
1822
+ nk_size_t vlmax = __riscv_vsetvlmax_e32m4();
1823
+ vuint32m4_t accumulator_u32m4 = __riscv_vmv_v_x_u32m4(0, vlmax);
1824
+ nk_size_t remaining = depth;
1825
+ nk_size_t k = 0;
1826
+ for (nk_size_t vector_length = 0; remaining > 0; remaining -= vector_length, k += vector_length) {
1827
+ vector_length = __riscv_vsetvl_e8m1(remaining);
1828
+ vuint8m1_t b_vector_u8m1 = __riscv_vle8_v_u8m1(b_column + k, vector_length);
1829
+ vuint8m1_t a_vector_u8m1 = __riscv_vle8_v_u8m1(a_row + k, vector_length);
1830
+ vuint16m2_t product_u16m2 = __riscv_vwmulu_vv_u16m2(a_vector_u8m1, b_vector_u8m1, vector_length);
1831
+ accumulator_u32m4 = __riscv_vwaddu_wv_u32m4_tu(accumulator_u32m4, accumulator_u32m4, product_u16m2,
1832
+ vector_length);
1833
+ }
1834
+ vuint32m1_t zero_u32m1 = __riscv_vmv_v_x_u32m1(0, 1);
1835
+ c_row[column] = (nk_u32_t)__riscv_vmv_x_s_u32m1_u32(
1836
+ __riscv_vredsum_vs_u32m4_u32m1(accumulator_u32m4, zero_u32m1, vlmax));
1837
+ }
1838
+ }
1839
+ }
1840
+
1841
+ /**
1842
+ * @brief Public u8 packed GEMM wrapper matching the declared signature in dots.h.
1843
+ *
1844
+ * Dispatches to the aligned kernel for all cases — RVV's `vsetvl` handles partial
1845
+ * vectors naturally, so no separate edge kernel is needed.
1846
+ */
1847
+ NK_PUBLIC void nk_dots_packed_u8_rvv(nk_u8_t const *a, void const *b_packed, nk_u32_t *c, nk_size_t m, nk_size_t n,
1848
+ nk_size_t k, nk_size_t a_stride, nk_size_t c_stride) {
1849
+ nk_dots_packed_u8_rvv_aligned_(a, b_packed, c, m, n, k, a_stride, c_stride);
1850
+ }
1851
+
1852
+ /**
1853
+ * @brief Symmetric u8 GEMM: C = A * A^T, upper triangle + mirror.
1854
+ *
1855
+ * Uses unsigned integer u8 arithmetic with u32 accumulation.
1856
+ * Both inputs are u8, widened via u8 x u8 -> u16 -> u32 accumulation.
1857
+ * Stride is in bytes.
1858
+ * Processes only the rows in [row_start, row_start + row_count) for parallelism.
1859
+ */
1860
+ NK_PUBLIC void nk_dots_symmetric_u8_rvv(nk_u8_t const *vectors, nk_size_t n_vectors, nk_size_t depth, nk_size_t stride,
1861
+ nk_u32_t *result, nk_size_t result_stride, nk_size_t row_start,
1862
+ nk_size_t row_count) {
1863
+ nk_size_t const result_stride_elements = result_stride / sizeof(nk_u32_t);
1864
+ nk_size_t const row_end = (row_start + row_count < n_vectors) ? (row_start + row_count) : n_vectors;
1865
+
1866
+ for (nk_size_t i = row_start; i < row_end; ++i) {
1867
+ nk_u8_t const *a_i = (nk_u8_t const *)((char const *)vectors + i * stride);
1868
+ for (nk_size_t j = i; j < n_vectors; ++j) {
1869
+ nk_u8_t const *a_j = (nk_u8_t const *)((char const *)vectors + j * stride);
1870
+ nk_size_t vlmax = __riscv_vsetvlmax_e32m4();
1871
+ vuint32m4_t accumulator_u32m4 = __riscv_vmv_v_x_u32m4(0, vlmax);
1872
+ nk_size_t remaining = depth;
1873
+ nk_size_t k = 0;
1874
+ for (nk_size_t vector_length = 0; remaining > 0; remaining -= vector_length, k += vector_length) {
1875
+ vector_length = __riscv_vsetvl_e8m1(remaining);
1876
+ vuint8m1_t a_vector_u8m1 = __riscv_vle8_v_u8m1(a_i + k, vector_length);
1877
+ vuint8m1_t b_vector_u8m1 = __riscv_vle8_v_u8m1(a_j + k, vector_length);
1878
+ vuint16m2_t product_u16m2 = __riscv_vwmulu_vv_u16m2(a_vector_u8m1, b_vector_u8m1, vector_length);
1879
+ accumulator_u32m4 = __riscv_vwaddu_wv_u32m4_tu(accumulator_u32m4, accumulator_u32m4, product_u16m2,
1880
+ vector_length);
1881
+ }
1882
+ vuint32m1_t zero_u32m1 = __riscv_vmv_v_x_u32m1(0, 1);
1883
+ nk_u32_t dot = (nk_u32_t)__riscv_vmv_x_s_u32m1_u32(
1884
+ __riscv_vredsum_vs_u32m4_u32m1(accumulator_u32m4, zero_u32m1, vlmax));
1885
+ result[i * result_stride_elements + j] = dot;
1886
+ }
1887
+ }
1888
+ }
1889
+
1890
+ #pragma endregion // Unsigned 8-bit Integers
1891
+
1892
+ #pragma region Quarter Precision E4M3
1893
+
1894
+ /**
1895
+ * @brief E4M3 magnitude LUT: 7-bit magnitude -> f32 bit pattern (u32).
1896
+ * nk_e4m3_magnitude_lut_rvv_[i] = float_to_bits(e4m3_to_f32(i)) for i=0..127.
1897
+ * E4M3FN: 4 exponent bits (bias=7), 3 mantissa bits, no infinity,
1898
+ * NaN = magnitude 0x7F only.
1899
+ */
1900
+ static nk_u32_t const nk_e4m3_magnitude_lut_rvv_[128] = {
1901
+ 0x00000000u, 0x3B000000u, 0x3B800000u, 0x3BC00000u,
1902
+ 0x3C000000u, 0x3C200000u, 0x3C400000u, 0x3C600000u, /* [ 0.. 7] */
1903
+ 0x3C800000u, 0x3C900000u, 0x3CA00000u, 0x3CB00000u,
1904
+ 0x3CC00000u, 0x3CD00000u, 0x3CE00000u, 0x3CF00000u, /* [ 8.. 15] */
1905
+ 0x3D000000u, 0x3D100000u, 0x3D200000u, 0x3D300000u,
1906
+ 0x3D400000u, 0x3D500000u, 0x3D600000u, 0x3D700000u, /* [ 16.. 23] */
1907
+ 0x3D800000u, 0x3D900000u, 0x3DA00000u, 0x3DB00000u,
1908
+ 0x3DC00000u, 0x3DD00000u, 0x3DE00000u, 0x3DF00000u, /* [ 24.. 31] */
1909
+ 0x3E000000u, 0x3E100000u, 0x3E200000u, 0x3E300000u,
1910
+ 0x3E400000u, 0x3E500000u, 0x3E600000u, 0x3E700000u, /* [ 32.. 39] */
1911
+ 0x3E800000u, 0x3E900000u, 0x3EA00000u, 0x3EB00000u,
1912
+ 0x3EC00000u, 0x3ED00000u, 0x3EE00000u, 0x3EF00000u, /* [ 40.. 47] */
1913
+ 0x3F000000u, 0x3F100000u, 0x3F200000u, 0x3F300000u,
1914
+ 0x3F400000u, 0x3F500000u, 0x3F600000u, 0x3F700000u, /* [ 48.. 55] */
1915
+ 0x3F800000u, 0x3F900000u, 0x3FA00000u, 0x3FB00000u,
1916
+ 0x3FC00000u, 0x3FD00000u, 0x3FE00000u, 0x3FF00000u, /* [ 56.. 63] */
1917
+ 0x40000000u, 0x40100000u, 0x40200000u, 0x40300000u,
1918
+ 0x40400000u, 0x40500000u, 0x40600000u, 0x40700000u, /* [ 64.. 71] */
1919
+ 0x40800000u, 0x40900000u, 0x40A00000u, 0x40B00000u,
1920
+ 0x40C00000u, 0x40D00000u, 0x40E00000u, 0x40F00000u, /* [ 72.. 79] */
1921
+ 0x41000000u, 0x41100000u, 0x41200000u, 0x41300000u,
1922
+ 0x41400000u, 0x41500000u, 0x41600000u, 0x41700000u, /* [ 80.. 87] */
1923
+ 0x41800000u, 0x41900000u, 0x41A00000u, 0x41B00000u,
1924
+ 0x41C00000u, 0x41D00000u, 0x41E00000u, 0x41F00000u, /* [ 88.. 95] */
1925
+ 0x42000000u, 0x42100000u, 0x42200000u, 0x42300000u,
1926
+ 0x42400000u, 0x42500000u, 0x42600000u, 0x42700000u, /* [ 96..103] */
1927
+ 0x42800000u, 0x42900000u, 0x42A00000u, 0x42B00000u,
1928
+ 0x42C00000u, 0x42D00000u, 0x42E00000u, 0x42F00000u, /* [104..111] */
1929
+ 0x43000000u, 0x43100000u, 0x43200000u, 0x43300000u,
1930
+ 0x43400000u, 0x43500000u, 0x43600000u, 0x43700000u, /* [112..119] */
1931
+ 0x43800000u, 0x43900000u, 0x43A00000u, 0x43B00000u,
1932
+ 0x43C00000u, 0x43D00000u, 0x43E00000u, 0x7FC00000u /* [120..127] */
1933
+ };
1934
+
1935
+ NK_PUBLIC nk_size_t nk_dots_packed_size_e4m3_rvv(nk_size_t column_count, nk_size_t depth) {
1936
+ nk_size_t vector_length = __riscv_vsetvlmax_e32m2();
1937
+ nk_size_t depth_padded = nk_size_round_up_to_multiple_(depth, vector_length);
1938
+ nk_size_t stride_bytes = depth_padded * sizeof(nk_f32_t);
1939
+ if (stride_bytes > 0 && (stride_bytes & (stride_bytes - 1)) == 0) depth_padded += vector_length;
1940
+ return sizeof(nk_cross_packed_buffer_header_t) + column_count * depth_padded * sizeof(nk_f32_t) +
1941
+ column_count * sizeof(nk_f32_t); // per-column norms
1942
+ }
1943
+
1944
+ /**
1945
+ * @brief Pack B matrix from e4m3 to f32 for floating-point dot product.
1946
+ *
1947
+ * Each e4m3 byte is converted to f32 via `nk_e4m3_to_f32_serial`.
1948
+ * Padding values are zeroed. Column-panel layout with depth-contiguous storage.
1949
+ */
1950
+ NK_PUBLIC void nk_dots_pack_e4m3_rvv(nk_e4m3_t const *b, nk_size_t column_count, nk_size_t depth,
1951
+ nk_size_t b_stride_in_bytes, void *b_packed) {
1952
+ nk_size_t vector_length = __riscv_vsetvlmax_e32m2();
1953
+ nk_size_t depth_padded = nk_size_round_up_to_multiple_(depth, vector_length);
1954
+ nk_size_t stride_bytes = depth_padded * sizeof(nk_f32_t);
1955
+ if (stride_bytes > 0 && (stride_bytes & (stride_bytes - 1)) == 0) depth_padded += vector_length;
1956
+
1957
+ nk_cross_packed_buffer_header_t *header = (nk_cross_packed_buffer_header_t *)b_packed;
1958
+ header->column_count = (nk_u32_t)column_count;
1959
+ header->depth_dimensions = (nk_u32_t)depth;
1960
+ header->depth_padded_values = (nk_u32_t)depth_padded;
1961
+
1962
+ nk_f32_t *packed = (nk_f32_t *)((char *)b_packed + sizeof(nk_cross_packed_buffer_header_t));
1963
+ nk_size_t total = column_count * depth_padded;
1964
+ for (nk_size_t i = 0; i < total; ++i) packed[i] = 0;
1965
+
1966
+ for (nk_size_t column = 0; column < column_count; ++column) {
1967
+ nk_e4m3_t const *src = (nk_e4m3_t const *)((char const *)b + column * b_stride_in_bytes);
1968
+ nk_f32_t *dst = packed + column * depth_padded;
1969
+ for (nk_size_t k = 0; k < depth; ++k) nk_e4m3_to_f32_serial(&src[k], &dst[k]);
1970
+ }
1971
+
1972
+ // Append per-column norms after packed data
1973
+ nk_f32_t *norms = (nk_f32_t *)(packed + total);
1974
+ for (nk_size_t column = 0; column < column_count; ++column) {
1975
+ nk_e4m3_t const *src = (nk_e4m3_t const *)((char const *)b + column * b_stride_in_bytes);
1976
+ norms[column] = nk_dots_reduce_sumsq_e4m3_(src, depth);
1977
+ }
1978
+ }
1979
+
1980
/**
 * @brief e4m3 packed GEMM kernel: C = A * B_packed^T with f64 widened accumulation.
 *
 * Note: the output matrix is zeroed before the tile loops, so this is an
 * overwrite (C =), not an accumulate (C +=).
 *
 * Vectorizes over the depth dimension (k). For each (row, column) pair:
 * - Load pre-packed f32 values from B
 * - Load raw e4m3 bytes from A, convert on-the-fly via 128-entry f32 LUT gather:
 *   extract 7-bit magnitude, zero-extend to u32, compute byte offsets (x4),
 *   gather f32 bit patterns, inject sign bit from bit 7 (<<24), reinterpret as f32
 * - Widening FMA: f32xf32 -> f64 via `vfwmacc_vv_f64m4`
 *
 * Register tile: process 2 rows per iteration (rows_per_tile=2, u32m2 gather + f64m4 accumulator is register-heavy).
 *
 * @param a_matrix          Row-major e4m3 A matrix (row_count x depth), raw bytes.
 * @param b_packed_buffer   Buffer with nk_cross_packed_buffer_header_t followed by
 *                          column-major f32 panels of depth_padded_values each.
 * @param c_matrix          f32 output (row_count x column_count), overwritten.
 * @param a_stride_in_bytes Byte stride between consecutive A rows.
 * @param c_stride_in_bytes Byte stride between consecutive C rows.
 */
NK_INTERNAL void nk_dots_packed_e4m3_rvv_aligned_(nk_e4m3_t const *a_matrix, void const *b_packed_buffer,
                                                  nk_f32_t *c_matrix, nk_size_t row_count, nk_size_t column_count,
                                                  nk_size_t depth, nk_size_t a_stride_in_bytes,
                                                  nk_size_t c_stride_in_bytes) {
    nk_cross_packed_buffer_header_t const *header = (nk_cross_packed_buffer_header_t const *)b_packed_buffer;
    nk_size_t const depth_padded = header->depth_padded_values;
    nk_f32_t const *packed_data = (nk_f32_t const *)((char const *)b_packed_buffer +
                                                     sizeof(nk_cross_packed_buffer_header_t));

    // Zero output matrix
    for (nk_size_t i = 0; i < row_count; ++i) {
        nk_f32_t *c_row = (nk_f32_t *)((char *)c_matrix + i * c_stride_in_bytes);
        for (nk_size_t j = 0; j < column_count; ++j) c_row[j] = 0;
    }

    // mr=2 register tile over rows
    nk_size_t row = 0;
    for (; row + 2 <= row_count; row += 2) {
        nk_u8_t const *a_row_0 = (nk_u8_t const *)((char const *)a_matrix + (row + 0) * a_stride_in_bytes);
        nk_u8_t const *a_row_1 = (nk_u8_t const *)((char const *)a_matrix + (row + 1) * a_stride_in_bytes);
        nk_f32_t *c_row_0 = (nk_f32_t *)((char *)c_matrix + (row + 0) * c_stride_in_bytes);
        nk_f32_t *c_row_1 = (nk_f32_t *)((char *)c_matrix + (row + 1) * c_stride_in_bytes);

        for (nk_size_t column = 0; column < column_count; ++column) {
            nk_f32_t const *b_column = packed_data + column * depth_padded;
            nk_size_t vlmax = __riscv_vsetvlmax_e32m2();
            vfloat64m4_t accumulator_0_f64m4 = __riscv_vfmv_v_f_f64m4(0.0, vlmax);
            vfloat64m4_t accumulator_1_f64m4 = __riscv_vfmv_v_f_f64m4(0.0, vlmax);

            // Strip-mined loop over depth; `_tu` FMAs keep tail accumulator
            // lanes undisturbed when the last stripe is shorter than vlmax.
            nk_size_t remaining = depth;
            nk_size_t k = 0;
            for (nk_size_t vector_length = 0; remaining > 0; remaining -= vector_length, k += vector_length) {
                vector_length = __riscv_vsetvl_e32m2(remaining);

                // Load pre-packed f32 B values
                vfloat32m2_t b_vector_f32m2 = __riscv_vle32_v_f32m2(b_column + k, vector_length);

                // Load raw e4m3 bytes from each A row
                vuint8mf2_t raw0_u8mf2 = __riscv_vle8_v_u8mf2(a_row_0 + k, vector_length);
                vuint8mf2_t raw1_u8mf2 = __riscv_vle8_v_u8mf2(a_row_1 + k, vector_length);

                // Extract 7-bit magnitudes, zero-extend to u32, compute byte offsets for f32 LUT
                vuint8mf2_t mag0_u8mf2 = __riscv_vand_vx_u8mf2(raw0_u8mf2, 0x7F, vector_length);
                vuint8mf2_t mag1_u8mf2 = __riscv_vand_vx_u8mf2(raw1_u8mf2, 0x7F, vector_length);
                vuint32m2_t idx0_u32m2 = __riscv_vzext_vf4_u32m2(mag0_u8mf2, vector_length);
                vuint32m2_t idx1_u32m2 = __riscv_vzext_vf4_u32m2(mag1_u8mf2, vector_length);
                vuint32m2_t off0_u32m2 = __riscv_vsll_vx_u32m2(idx0_u32m2, 2,
                                                               vector_length); // byte offsets = index * 4
                vuint32m2_t off1_u32m2 = __riscv_vsll_vx_u32m2(idx1_u32m2, 2, vector_length);

                // Gather f32 bit patterns from magnitude LUT (vluxei32 takes byte offsets)
                vuint32m2_t bits0_u32m2 = __riscv_vluxei32_v_u32m2(nk_e4m3_magnitude_lut_rvv_, off0_u32m2,
                                                                   vector_length);
                vuint32m2_t bits1_u32m2 = __riscv_vluxei32_v_u32m2(nk_e4m3_magnitude_lut_rvv_, off1_u32m2,
                                                                   vector_length);

                // Extract sign bit 7, shift to f32 sign position (bit 31): zext puts 0x80 in
                // the low byte, then << 24 moves it to 0x80000000.
                vuint8mf2_t sign0_u8mf2 = __riscv_vand_vx_u8mf2(raw0_u8mf2, 0x80, vector_length);
                vuint8mf2_t sign1_u8mf2 = __riscv_vand_vx_u8mf2(raw1_u8mf2, 0x80, vector_length);
                vuint32m2_t sign0_u32m2 = __riscv_vsll_vx_u32m2(__riscv_vzext_vf4_u32m2(sign0_u8mf2, vector_length), 24,
                                                                vector_length);
                vuint32m2_t sign1_u32m2 = __riscv_vsll_vx_u32m2(__riscv_vzext_vf4_u32m2(sign1_u8mf2, vector_length), 24,
                                                                vector_length);

                // Apply sign and reinterpret as f32
                vfloat32m2_t a_vector_0_f32m2 = __riscv_vreinterpret_v_u32m2_f32m2(
                    __riscv_vor_vv_u32m2(bits0_u32m2, sign0_u32m2, vector_length));
                vfloat32m2_t a_vector_1_f32m2 = __riscv_vreinterpret_v_u32m2_f32m2(
                    __riscv_vor_vv_u32m2(bits1_u32m2, sign1_u32m2, vector_length));

                // Widening FMA: f32xf32 -> f64
                accumulator_0_f64m4 = __riscv_vfwmacc_vv_f64m4_tu(accumulator_0_f64m4, a_vector_0_f32m2, b_vector_f32m2,
                                                                  vector_length);
                accumulator_1_f64m4 = __riscv_vfwmacc_vv_f64m4_tu(accumulator_1_f64m4, a_vector_1_f32m2, b_vector_f32m2,
                                                                  vector_length);
            }

            // Horizontal reduce and narrow to f32. VLMAX(e32,m2) == VLMAX(e64,m4)
            // (same SEW/LMUL ratio), so reducing over `vlmax` covers every
            // accumulator element.
            vfloat64m1_t zero_f64m1 = __riscv_vfmv_v_f_f64m1(0.0, 1);
            c_row_0[column] = (nk_f32_t)__riscv_vfmv_f_s_f64m1_f64(
                __riscv_vfredusum_vs_f64m4_f64m1(accumulator_0_f64m4, zero_f64m1, vlmax));
            c_row_1[column] = (nk_f32_t)__riscv_vfmv_f_s_f64m1_f64(
                __riscv_vfredusum_vs_f64m4_f64m1(accumulator_1_f64m4, zero_f64m1, vlmax));
        }
    }
    // Remainder rows (same conversion pipeline, single accumulator)
    for (; row < row_count; ++row) {
        nk_u8_t const *a_row = (nk_u8_t const *)((char const *)a_matrix + row * a_stride_in_bytes);
        nk_f32_t *c_row = (nk_f32_t *)((char *)c_matrix + row * c_stride_in_bytes);
        for (nk_size_t column = 0; column < column_count; ++column) {
            nk_f32_t const *b_column = packed_data + column * depth_padded;
            nk_size_t vlmax = __riscv_vsetvlmax_e32m2();
            vfloat64m4_t accumulator_f64m4 = __riscv_vfmv_v_f_f64m4(0.0, vlmax);
            nk_size_t remaining = depth;
            nk_size_t k = 0;
            for (nk_size_t vector_length = 0; remaining > 0; remaining -= vector_length, k += vector_length) {
                vector_length = __riscv_vsetvl_e32m2(remaining);
                vfloat32m2_t b_vector_f32m2 = __riscv_vle32_v_f32m2(b_column + k, vector_length);
                vuint8mf2_t raw_a_u8mf2 = __riscv_vle8_v_u8mf2(a_row + k, vector_length);
                vuint8mf2_t mag_a_u8mf2 = __riscv_vand_vx_u8mf2(raw_a_u8mf2, 0x7F, vector_length);
                vuint32m2_t idx_a_u32m2 = __riscv_vzext_vf4_u32m2(mag_a_u8mf2, vector_length);
                vuint32m2_t off_a_u32m2 = __riscv_vsll_vx_u32m2(idx_a_u32m2, 2, vector_length);
                vuint32m2_t bits_a_u32m2 = __riscv_vluxei32_v_u32m2(nk_e4m3_magnitude_lut_rvv_, off_a_u32m2,
                                                                    vector_length);
                vuint8mf2_t sign_a_u8mf2 = __riscv_vand_vx_u8mf2(raw_a_u8mf2, 0x80, vector_length);
                vuint32m2_t sign_a_u32m2 = __riscv_vsll_vx_u32m2(__riscv_vzext_vf4_u32m2(sign_a_u8mf2, vector_length),
                                                                 24, vector_length);
                vfloat32m2_t a_vector_f32m2 = __riscv_vreinterpret_v_u32m2_f32m2(
                    __riscv_vor_vv_u32m2(bits_a_u32m2, sign_a_u32m2, vector_length));
                accumulator_f64m4 = __riscv_vfwmacc_vv_f64m4_tu(accumulator_f64m4, a_vector_f32m2, b_vector_f32m2,
                                                                vector_length);
            }
            vfloat64m1_t zero_f64m1 = __riscv_vfmv_v_f_f64m1(0.0, 1);
            c_row[column] = (nk_f32_t)__riscv_vfmv_f_s_f64m1_f64(
                __riscv_vfredusum_vs_f64m4_f64m1(accumulator_f64m4, zero_f64m1, vlmax));
        }
    }
}
2110
+
2111
/**
 * @brief Public e4m3 packed GEMM wrapper matching the declared signature in dots.h.
 *
 * Thin forwarder: all work happens in nk_dots_packed_e4m3_rvv_aligned_.
 * @param a         Row-major e4m3 A matrix (m x k), row stride `a_stride` in bytes.
 * @param b_packed  Packed B buffer (header + f32 column panels + per-column norms).
 * @param c         f32 output matrix (m x n), row stride `c_stride` in bytes; overwritten.
 */
NK_PUBLIC void nk_dots_packed_e4m3_rvv(nk_e4m3_t const *a, void const *b_packed, nk_f32_t *c, nk_size_t m, nk_size_t n,
                                       nk_size_t k, nk_size_t a_stride, nk_size_t c_stride) {
    nk_dots_packed_e4m3_rvv_aligned_(a, b_packed, c, m, n, k, a_stride, c_stride);
}
2118
+
2119
/**
 * @brief Symmetric e4m3 GEMM: dot products A[i] . A[j] for the upper triangle (j >= i).
 *
 * Uses f32 LUT gather with f64 widened accumulation for precision.
 * Both operands are converted from e4m3 on-the-fly via magnitude LUT.
 * Processes only the rows in [row_start, row_start + row_count) for parallelism.
 *
 * NOTE(review): only result[i][j] with j >= i is written here; the lower
 * triangle is NOT mirrored by this function — presumably the dispatch layer
 * mirrors the result, TODO confirm against callers.
 *
 * @param stride        Byte stride between consecutive input vectors.
 * @param result_stride Byte stride between consecutive result rows (converted
 *                      to f32 elements internally).
 */
NK_PUBLIC void nk_dots_symmetric_e4m3_rvv(nk_e4m3_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
                                          nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
                                          nk_size_t row_start, nk_size_t row_count) {
    nk_size_t const result_stride_elements = result_stride / sizeof(nk_f32_t);
    // Clamp the row window so row_start + row_count past the end is safe.
    nk_size_t const row_end = (row_start + row_count < n_vectors) ? (row_start + row_count) : n_vectors;

    for (nk_size_t i = row_start; i < row_end; ++i) {
        nk_u8_t const *a_i = (nk_u8_t const *)vectors + i * stride;
        for (nk_size_t j = i; j < n_vectors; ++j) {
            nk_u8_t const *a_j = (nk_u8_t const *)vectors + j * stride;
            nk_size_t vlmax = __riscv_vsetvlmax_e32m2();
            vfloat64m4_t accumulator_f64m4 = __riscv_vfmv_v_f_f64m4(0.0, vlmax);
            // Strip-mined loop over depth; `_tu` FMA keeps tail accumulator
            // lanes undisturbed on the short last stripe.
            nk_size_t remaining = depth;
            nk_size_t k = 0;
            for (nk_size_t vector_length = 0; remaining > 0; remaining -= vector_length, k += vector_length) {
                vector_length = __riscv_vsetvl_e32m2(remaining);
                vuint8mf2_t raw_i_u8mf2 = __riscv_vle8_v_u8mf2(a_i + k, vector_length);
                vuint8mf2_t raw_j_u8mf2 = __riscv_vle8_v_u8mf2(a_j + k, vector_length);

                // Convert i-vector via LUT gather: 7-bit magnitude indexes the f32
                // bit-pattern table (byte offsets = index * 4), sign bit re-injected at bit 31.
                vuint8mf2_t mag_i_u8mf2 = __riscv_vand_vx_u8mf2(raw_i_u8mf2, 0x7F, vector_length);
                vuint32m2_t idx_i_u32m2 = __riscv_vzext_vf4_u32m2(mag_i_u8mf2, vector_length);
                vuint32m2_t off_i_u32m2 = __riscv_vsll_vx_u32m2(idx_i_u32m2, 2, vector_length);
                vuint32m2_t bits_i_u32m2 = __riscv_vluxei32_v_u32m2(nk_e4m3_magnitude_lut_rvv_, off_i_u32m2,
                                                                    vector_length);
                vuint8mf2_t sign_i_u8mf2 = __riscv_vand_vx_u8mf2(raw_i_u8mf2, 0x80, vector_length);
                vuint32m2_t sign_i_u32m2 = __riscv_vsll_vx_u32m2(__riscv_vzext_vf4_u32m2(sign_i_u8mf2, vector_length),
                                                                 24, vector_length);
                vfloat32m2_t val_i_f32m2 = __riscv_vreinterpret_v_u32m2_f32m2(
                    __riscv_vor_vv_u32m2(bits_i_u32m2, sign_i_u32m2, vector_length));

                // Convert j-vector via LUT gather (same pipeline)
                vuint8mf2_t mag_j_u8mf2 = __riscv_vand_vx_u8mf2(raw_j_u8mf2, 0x7F, vector_length);
                vuint32m2_t idx_j_u32m2 = __riscv_vzext_vf4_u32m2(mag_j_u8mf2, vector_length);
                vuint32m2_t off_j_u32m2 = __riscv_vsll_vx_u32m2(idx_j_u32m2, 2, vector_length);
                vuint32m2_t bits_j_u32m2 = __riscv_vluxei32_v_u32m2(nk_e4m3_magnitude_lut_rvv_, off_j_u32m2,
                                                                    vector_length);
                vuint8mf2_t sign_j_u8mf2 = __riscv_vand_vx_u8mf2(raw_j_u8mf2, 0x80, vector_length);
                vuint32m2_t sign_j_u32m2 = __riscv_vsll_vx_u32m2(__riscv_vzext_vf4_u32m2(sign_j_u8mf2, vector_length),
                                                                 24, vector_length);
                vfloat32m2_t val_j_f32m2 = __riscv_vreinterpret_v_u32m2_f32m2(
                    __riscv_vor_vv_u32m2(bits_j_u32m2, sign_j_u32m2, vector_length));

                // Widening FMA: f32xf32 -> f64
                accumulator_f64m4 = __riscv_vfwmacc_vv_f64m4_tu(accumulator_f64m4, val_i_f32m2, val_j_f32m2,
                                                                vector_length);
            }
            // VLMAX(e32,m2) == VLMAX(e64,m4), so the reduction covers all lanes.
            vfloat64m1_t zero_f64m1 = __riscv_vfmv_v_f_f64m1(0.0, 1);
            nk_f32_t dot = (nk_f32_t)__riscv_vfmv_f_s_f64m1_f64(
                __riscv_vfredusum_vs_f64m4_f64m1(accumulator_f64m4, zero_f64m1, vlmax));
            result[i * result_stride_elements + j] = dot;
        }
    }
}
2180
+
2181
+ #pragma endregion // Quarter Precision E4M3
2182
+
2183
+ #pragma region Quarter Precision E5M2
2184
+
2185
/**
 * @brief E5M2 magnitude LUT: 7-bit magnitude -> f32 bit pattern (u32).
 * nk_e5m2_magnitude_lut_rvv_[i] = float_to_bits(e5m2_to_f32(i)) for i=0..127.
 * E5M2: 5 exponent bits (bias=15), 2 mantissa bits, has infinity (0x7C) and
 * NaN (magnitudes 0x7D..0x7F).
 *
 * Layout: [0x00] zero; [0x01..0x03] subnormals (m * 2^-16); [0x04..0x7B]
 * normals from 2^-14 up to 1.75 * 2^15 = 57344; [0x7C] +Inf (0x7F800000);
 * [0x7D..0x7F] quiet NaN (0x7FC00000). The sign bit is ORed in separately by
 * the gather kernels, so only magnitudes are stored here.
 */
static nk_u32_t const nk_e5m2_magnitude_lut_rvv_[128] = {
    0x00000000u, 0x37800000u, 0x38000000u, 0x38400000u,
    0x38800000u, 0x38A00000u, 0x38C00000u, 0x38E00000u, /* [  0..  7] */
    0x39000000u, 0x39200000u, 0x39400000u, 0x39600000u,
    0x39800000u, 0x39A00000u, 0x39C00000u, 0x39E00000u, /* [  8.. 15] */
    0x3A000000u, 0x3A200000u, 0x3A400000u, 0x3A600000u,
    0x3A800000u, 0x3AA00000u, 0x3AC00000u, 0x3AE00000u, /* [ 16.. 23] */
    0x3B000000u, 0x3B200000u, 0x3B400000u, 0x3B600000u,
    0x3B800000u, 0x3BA00000u, 0x3BC00000u, 0x3BE00000u, /* [ 24.. 31] */
    0x3C000000u, 0x3C200000u, 0x3C400000u, 0x3C600000u,
    0x3C800000u, 0x3CA00000u, 0x3CC00000u, 0x3CE00000u, /* [ 32.. 39] */
    0x3D000000u, 0x3D200000u, 0x3D400000u, 0x3D600000u,
    0x3D800000u, 0x3DA00000u, 0x3DC00000u, 0x3DE00000u, /* [ 40.. 47] */
    0x3E000000u, 0x3E200000u, 0x3E400000u, 0x3E600000u,
    0x3E800000u, 0x3EA00000u, 0x3EC00000u, 0x3EE00000u, /* [ 48.. 55] */
    0x3F000000u, 0x3F200000u, 0x3F400000u, 0x3F600000u,
    0x3F800000u, 0x3FA00000u, 0x3FC00000u, 0x3FE00000u, /* [ 56.. 63] */
    0x40000000u, 0x40200000u, 0x40400000u, 0x40600000u,
    0x40800000u, 0x40A00000u, 0x40C00000u, 0x40E00000u, /* [ 64.. 71] */
    0x41000000u, 0x41200000u, 0x41400000u, 0x41600000u,
    0x41800000u, 0x41A00000u, 0x41C00000u, 0x41E00000u, /* [ 72.. 79] */
    0x42000000u, 0x42200000u, 0x42400000u, 0x42600000u,
    0x42800000u, 0x42A00000u, 0x42C00000u, 0x42E00000u, /* [ 80.. 87] */
    0x43000000u, 0x43200000u, 0x43400000u, 0x43600000u,
    0x43800000u, 0x43A00000u, 0x43C00000u, 0x43E00000u, /* [ 88.. 95] */
    0x44000000u, 0x44200000u, 0x44400000u, 0x44600000u,
    0x44800000u, 0x44A00000u, 0x44C00000u, 0x44E00000u, /* [ 96..103] */
    0x45000000u, 0x45200000u, 0x45400000u, 0x45600000u,
    0x45800000u, 0x45A00000u, 0x45C00000u, 0x45E00000u, /* [104..111] */
    0x46000000u, 0x46200000u, 0x46400000u, 0x46600000u,
    0x46800000u, 0x46A00000u, 0x46C00000u, 0x46E00000u, /* [112..119] */
    0x47000000u, 0x47200000u, 0x47400000u, 0x47600000u,
    0x7F800000u, 0x7FC00000u, 0x7FC00000u, 0x7FC00000u /* [120..127] */
};
2225
+
2226
+ NK_PUBLIC nk_size_t nk_dots_packed_size_e5m2_rvv(nk_size_t column_count, nk_size_t depth) {
2227
+ nk_size_t vector_length = __riscv_vsetvlmax_e32m2();
2228
+ nk_size_t depth_padded = nk_size_round_up_to_multiple_(depth, vector_length);
2229
+ nk_size_t stride_bytes = depth_padded * sizeof(nk_f32_t);
2230
+ if (stride_bytes > 0 && (stride_bytes & (stride_bytes - 1)) == 0) depth_padded += vector_length;
2231
+ return sizeof(nk_cross_packed_buffer_header_t) + column_count * depth_padded * sizeof(nk_f32_t) +
2232
+ column_count * sizeof(nk_f32_t); // per-column norms
2233
+ }
2234
+
2235
+ /**
2236
+ * @brief Pack B matrix from e5m2 to f32 for floating-point dot product.
2237
+ *
2238
+ * Each e5m2 byte is converted to f32 via `nk_e5m2_to_f32_serial`.
2239
+ * Padding values are zeroed. Column-panel layout with depth-contiguous storage.
2240
+ */
2241
+ NK_PUBLIC void nk_dots_pack_e5m2_rvv(nk_e5m2_t const *b, nk_size_t column_count, nk_size_t depth,
2242
+ nk_size_t b_stride_in_bytes, void *b_packed) {
2243
+ nk_size_t vector_length = __riscv_vsetvlmax_e32m2();
2244
+ nk_size_t depth_padded = nk_size_round_up_to_multiple_(depth, vector_length);
2245
+ nk_size_t stride_bytes = depth_padded * sizeof(nk_f32_t);
2246
+ if (stride_bytes > 0 && (stride_bytes & (stride_bytes - 1)) == 0) depth_padded += vector_length;
2247
+
2248
+ nk_cross_packed_buffer_header_t *header = (nk_cross_packed_buffer_header_t *)b_packed;
2249
+ header->column_count = (nk_u32_t)column_count;
2250
+ header->depth_dimensions = (nk_u32_t)depth;
2251
+ header->depth_padded_values = (nk_u32_t)depth_padded;
2252
+
2253
+ nk_f32_t *packed = (nk_f32_t *)((char *)b_packed + sizeof(nk_cross_packed_buffer_header_t));
2254
+ nk_size_t total = column_count * depth_padded;
2255
+ for (nk_size_t i = 0; i < total; ++i) packed[i] = 0;
2256
+
2257
+ for (nk_size_t column = 0; column < column_count; ++column) {
2258
+ nk_e5m2_t const *src = (nk_e5m2_t const *)((char const *)b + column * b_stride_in_bytes);
2259
+ nk_f32_t *dst = packed + column * depth_padded;
2260
+ for (nk_size_t k = 0; k < depth; ++k) nk_e5m2_to_f32_serial(&src[k], &dst[k]);
2261
+ }
2262
+
2263
+ // Append per-column norms after packed data
2264
+ nk_f32_t *norms = (nk_f32_t *)(packed + total);
2265
+ for (nk_size_t column = 0; column < column_count; ++column) {
2266
+ nk_e5m2_t const *src = (nk_e5m2_t const *)((char const *)b + column * b_stride_in_bytes);
2267
+ norms[column] = nk_dots_reduce_sumsq_e5m2_(src, depth);
2268
+ }
2269
+ }
2270
+
2271
/**
 * @brief e5m2 packed GEMM kernel: C = A * B_packed^T with f64 widened accumulation.
 *
 * Note: the output matrix is zeroed before the tile loops, so this is an
 * overwrite (C =), not an accumulate (C +=).
 *
 * Vectorizes over the depth dimension (k). For each (row, column) pair:
 * - Load pre-packed f32 values from B
 * - Load raw e5m2 bytes from A, convert on-the-fly via 128-entry f32 LUT gather:
 *   extract 7-bit magnitude, zero-extend to u32, compute byte offsets (x4),
 *   gather f32 bit patterns, inject sign bit from bit 7 (<<24), reinterpret as f32
 * - Widening FMA: f32xf32 -> f64 via `vfwmacc_vv_f64m4`
 *
 * Register tile: process 2 rows per iteration (rows_per_tile=2, u32m2 gather + f64m4 accumulator is register-heavy).
 *
 * @param a_matrix          Row-major e5m2 A matrix (row_count x depth), raw bytes.
 * @param b_packed_buffer   Buffer produced by nk_dots_pack_e5m2_rvv.
 * @param c_matrix          f32 output (row_count x column_count), overwritten.
 * @param a_stride_in_bytes Byte stride between consecutive A rows.
 * @param c_stride_in_bytes Byte stride between consecutive C rows.
 */
NK_INTERNAL void nk_dots_packed_e5m2_rvv_aligned_(nk_e5m2_t const *a_matrix, void const *b_packed_buffer,
                                                  nk_f32_t *c_matrix, nk_size_t row_count, nk_size_t column_count,
                                                  nk_size_t depth, nk_size_t a_stride_in_bytes,
                                                  nk_size_t c_stride_in_bytes) {
    nk_cross_packed_buffer_header_t const *header = (nk_cross_packed_buffer_header_t const *)b_packed_buffer;
    nk_size_t const depth_padded = header->depth_padded_values;
    nk_f32_t const *packed_data = (nk_f32_t const *)((char const *)b_packed_buffer +
                                                     sizeof(nk_cross_packed_buffer_header_t));

    // Zero output matrix
    for (nk_size_t i = 0; i < row_count; ++i) {
        nk_f32_t *c_row = (nk_f32_t *)((char *)c_matrix + i * c_stride_in_bytes);
        for (nk_size_t j = 0; j < column_count; ++j) c_row[j] = 0;
    }

    // mr=2 register tile over rows
    nk_size_t row = 0;
    for (; row + 2 <= row_count; row += 2) {
        nk_u8_t const *a_row_0 = (nk_u8_t const *)((char const *)a_matrix + (row + 0) * a_stride_in_bytes);
        nk_u8_t const *a_row_1 = (nk_u8_t const *)((char const *)a_matrix + (row + 1) * a_stride_in_bytes);
        nk_f32_t *c_row_0 = (nk_f32_t *)((char *)c_matrix + (row + 0) * c_stride_in_bytes);
        nk_f32_t *c_row_1 = (nk_f32_t *)((char *)c_matrix + (row + 1) * c_stride_in_bytes);

        for (nk_size_t column = 0; column < column_count; ++column) {
            nk_f32_t const *b_column = packed_data + column * depth_padded;
            nk_size_t vlmax = __riscv_vsetvlmax_e32m2();
            vfloat64m4_t accumulator_0_f64m4 = __riscv_vfmv_v_f_f64m4(0.0, vlmax);
            vfloat64m4_t accumulator_1_f64m4 = __riscv_vfmv_v_f_f64m4(0.0, vlmax);

            // Strip-mined loop over depth; `_tu` FMAs keep tail accumulator
            // lanes undisturbed when the last stripe is shorter than vlmax.
            nk_size_t remaining = depth;
            nk_size_t k = 0;
            for (nk_size_t vector_length = 0; remaining > 0; remaining -= vector_length, k += vector_length) {
                vector_length = __riscv_vsetvl_e32m2(remaining);

                // Load pre-packed f32 B values
                vfloat32m2_t b_vector_f32m2 = __riscv_vle32_v_f32m2(b_column + k, vector_length);

                // Load raw e5m2 bytes from each A row
                vuint8mf2_t raw0_u8mf2 = __riscv_vle8_v_u8mf2(a_row_0 + k, vector_length);
                vuint8mf2_t raw1_u8mf2 = __riscv_vle8_v_u8mf2(a_row_1 + k, vector_length);

                // Extract 7-bit magnitudes, zero-extend to u32, compute byte offsets for f32 LUT
                vuint8mf2_t mag0_u8mf2 = __riscv_vand_vx_u8mf2(raw0_u8mf2, 0x7F, vector_length);
                vuint8mf2_t mag1_u8mf2 = __riscv_vand_vx_u8mf2(raw1_u8mf2, 0x7F, vector_length);
                vuint32m2_t idx0_u32m2 = __riscv_vzext_vf4_u32m2(mag0_u8mf2, vector_length);
                vuint32m2_t idx1_u32m2 = __riscv_vzext_vf4_u32m2(mag1_u8mf2, vector_length);
                vuint32m2_t off0_u32m2 = __riscv_vsll_vx_u32m2(idx0_u32m2, 2,
                                                               vector_length); // byte offsets = index * 4
                vuint32m2_t off1_u32m2 = __riscv_vsll_vx_u32m2(idx1_u32m2, 2, vector_length);

                // Gather f32 bit patterns from magnitude LUT (vluxei32 takes byte offsets)
                vuint32m2_t bits0_u32m2 = __riscv_vluxei32_v_u32m2(nk_e5m2_magnitude_lut_rvv_, off0_u32m2,
                                                                   vector_length);
                vuint32m2_t bits1_u32m2 = __riscv_vluxei32_v_u32m2(nk_e5m2_magnitude_lut_rvv_, off1_u32m2,
                                                                   vector_length);

                // Extract sign bit 7, shift to f32 sign position (bit 31): zext puts 0x80 in
                // the low byte, then << 24 moves it to 0x80000000.
                vuint8mf2_t sign0_u8mf2 = __riscv_vand_vx_u8mf2(raw0_u8mf2, 0x80, vector_length);
                vuint8mf2_t sign1_u8mf2 = __riscv_vand_vx_u8mf2(raw1_u8mf2, 0x80, vector_length);
                vuint32m2_t sign0_u32m2 = __riscv_vsll_vx_u32m2(__riscv_vzext_vf4_u32m2(sign0_u8mf2, vector_length), 24,
                                                                vector_length);
                vuint32m2_t sign1_u32m2 = __riscv_vsll_vx_u32m2(__riscv_vzext_vf4_u32m2(sign1_u8mf2, vector_length), 24,
                                                                vector_length);

                // Apply sign and reinterpret as f32
                vfloat32m2_t a_vector_0_f32m2 = __riscv_vreinterpret_v_u32m2_f32m2(
                    __riscv_vor_vv_u32m2(bits0_u32m2, sign0_u32m2, vector_length));
                vfloat32m2_t a_vector_1_f32m2 = __riscv_vreinterpret_v_u32m2_f32m2(
                    __riscv_vor_vv_u32m2(bits1_u32m2, sign1_u32m2, vector_length));

                // Widening FMA: f32xf32 -> f64
                accumulator_0_f64m4 = __riscv_vfwmacc_vv_f64m4_tu(accumulator_0_f64m4, a_vector_0_f32m2, b_vector_f32m2,
                                                                  vector_length);
                accumulator_1_f64m4 = __riscv_vfwmacc_vv_f64m4_tu(accumulator_1_f64m4, a_vector_1_f32m2, b_vector_f32m2,
                                                                  vector_length);
            }

            // Horizontal reduce and narrow to f32. VLMAX(e32,m2) == VLMAX(e64,m4)
            // (same SEW/LMUL ratio), so reducing over `vlmax` covers every lane.
            vfloat64m1_t zero_f64m1 = __riscv_vfmv_v_f_f64m1(0.0, 1);
            c_row_0[column] = (nk_f32_t)__riscv_vfmv_f_s_f64m1_f64(
                __riscv_vfredusum_vs_f64m4_f64m1(accumulator_0_f64m4, zero_f64m1, vlmax));
            c_row_1[column] = (nk_f32_t)__riscv_vfmv_f_s_f64m1_f64(
                __riscv_vfredusum_vs_f64m4_f64m1(accumulator_1_f64m4, zero_f64m1, vlmax));
        }
    }
    // Remainder rows (same conversion pipeline, single accumulator)
    for (; row < row_count; ++row) {
        nk_u8_t const *a_row = (nk_u8_t const *)((char const *)a_matrix + row * a_stride_in_bytes);
        nk_f32_t *c_row = (nk_f32_t *)((char *)c_matrix + row * c_stride_in_bytes);
        for (nk_size_t column = 0; column < column_count; ++column) {
            nk_f32_t const *b_column = packed_data + column * depth_padded;
            nk_size_t vlmax = __riscv_vsetvlmax_e32m2();
            vfloat64m4_t accumulator_f64m4 = __riscv_vfmv_v_f_f64m4(0.0, vlmax);
            nk_size_t remaining = depth;
            nk_size_t k = 0;
            for (nk_size_t vector_length = 0; remaining > 0; remaining -= vector_length, k += vector_length) {
                vector_length = __riscv_vsetvl_e32m2(remaining);
                vfloat32m2_t b_vector_f32m2 = __riscv_vle32_v_f32m2(b_column + k, vector_length);
                vuint8mf2_t raw_a_u8mf2 = __riscv_vle8_v_u8mf2(a_row + k, vector_length);
                vuint8mf2_t mag_a_u8mf2 = __riscv_vand_vx_u8mf2(raw_a_u8mf2, 0x7F, vector_length);
                vuint32m2_t idx_a_u32m2 = __riscv_vzext_vf4_u32m2(mag_a_u8mf2, vector_length);
                vuint32m2_t off_a_u32m2 = __riscv_vsll_vx_u32m2(idx_a_u32m2, 2, vector_length);
                vuint32m2_t bits_a_u32m2 = __riscv_vluxei32_v_u32m2(nk_e5m2_magnitude_lut_rvv_, off_a_u32m2,
                                                                    vector_length);
                vuint8mf2_t sign_a_u8mf2 = __riscv_vand_vx_u8mf2(raw_a_u8mf2, 0x80, vector_length);
                vuint32m2_t sign_a_u32m2 = __riscv_vsll_vx_u32m2(__riscv_vzext_vf4_u32m2(sign_a_u8mf2, vector_length),
                                                                 24, vector_length);
                vfloat32m2_t a_vector_f32m2 = __riscv_vreinterpret_v_u32m2_f32m2(
                    __riscv_vor_vv_u32m2(bits_a_u32m2, sign_a_u32m2, vector_length));
                accumulator_f64m4 = __riscv_vfwmacc_vv_f64m4_tu(accumulator_f64m4, a_vector_f32m2, b_vector_f32m2,
                                                                vector_length);
            }
            vfloat64m1_t zero_f64m1 = __riscv_vfmv_v_f_f64m1(0.0, 1);
            c_row[column] = (nk_f32_t)__riscv_vfmv_f_s_f64m1_f64(
                __riscv_vfredusum_vs_f64m4_f64m1(accumulator_f64m4, zero_f64m1, vlmax));
        }
    }
}
2401
+
2402
/**
 * @brief Public e5m2 packed GEMM wrapper matching the declared signature in dots.h.
 *
 * Thin forwarder: all work happens in nk_dots_packed_e5m2_rvv_aligned_.
 * @param a         Row-major e5m2 A matrix (m x k), row stride `a_stride` in bytes.
 * @param b_packed  Packed B buffer (header + f32 column panels + per-column norms).
 * @param c         f32 output matrix (m x n), row stride `c_stride` in bytes; overwritten.
 */
NK_PUBLIC void nk_dots_packed_e5m2_rvv(nk_e5m2_t const *a, void const *b_packed, nk_f32_t *c, nk_size_t m, nk_size_t n,
                                       nk_size_t k, nk_size_t a_stride, nk_size_t c_stride) {
    nk_dots_packed_e5m2_rvv_aligned_(a, b_packed, c, m, n, k, a_stride, c_stride);
}
2409
+
2410
/**
 * @brief Symmetric e5m2 GEMM: dot products A[i] . A[j] for the upper triangle (j >= i).
 *
 * Uses f32 LUT gather with f64 widened accumulation for precision.
 * Both operands are converted from e5m2 on-the-fly via magnitude LUT.
 * Processes only the rows in [row_start, row_start + row_count) for parallelism.
 *
 * NOTE(review): only result[i][j] with j >= i is written here; the lower
 * triangle is NOT mirrored by this function — presumably the dispatch layer
 * mirrors the result, TODO confirm against callers.
 *
 * @param stride        Byte stride between consecutive input vectors.
 * @param result_stride Byte stride between consecutive result rows (converted
 *                      to f32 elements internally).
 */
NK_PUBLIC void nk_dots_symmetric_e5m2_rvv(nk_e5m2_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
                                          nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
                                          nk_size_t row_start, nk_size_t row_count) {
    nk_size_t const result_stride_elements = result_stride / sizeof(nk_f32_t);
    // Clamp the row window so row_start + row_count past the end is safe.
    nk_size_t const row_end = (row_start + row_count < n_vectors) ? (row_start + row_count) : n_vectors;

    for (nk_size_t i = row_start; i < row_end; ++i) {
        nk_u8_t const *a_i = (nk_u8_t const *)vectors + i * stride;
        for (nk_size_t j = i; j < n_vectors; ++j) {
            nk_u8_t const *a_j = (nk_u8_t const *)vectors + j * stride;
            nk_size_t vlmax = __riscv_vsetvlmax_e32m2();
            vfloat64m4_t accumulator_f64m4 = __riscv_vfmv_v_f_f64m4(0.0, vlmax);
            // Strip-mined loop over depth; `_tu` FMA keeps tail accumulator
            // lanes undisturbed on the short last stripe.
            nk_size_t remaining = depth;
            nk_size_t k = 0;
            for (nk_size_t vector_length = 0; remaining > 0; remaining -= vector_length, k += vector_length) {
                vector_length = __riscv_vsetvl_e32m2(remaining);
                vuint8mf2_t raw_i_u8mf2 = __riscv_vle8_v_u8mf2(a_i + k, vector_length);
                vuint8mf2_t raw_j_u8mf2 = __riscv_vle8_v_u8mf2(a_j + k, vector_length);

                // Convert i-vector via LUT gather: 7-bit magnitude indexes the f32
                // bit-pattern table (byte offsets = index * 4), sign bit re-injected at bit 31.
                vuint8mf2_t mag_i_u8mf2 = __riscv_vand_vx_u8mf2(raw_i_u8mf2, 0x7F, vector_length);
                vuint32m2_t idx_i_u32m2 = __riscv_vzext_vf4_u32m2(mag_i_u8mf2, vector_length);
                vuint32m2_t off_i_u32m2 = __riscv_vsll_vx_u32m2(idx_i_u32m2, 2, vector_length);
                vuint32m2_t bits_i_u32m2 = __riscv_vluxei32_v_u32m2(nk_e5m2_magnitude_lut_rvv_, off_i_u32m2,
                                                                    vector_length);
                vuint8mf2_t sign_i_u8mf2 = __riscv_vand_vx_u8mf2(raw_i_u8mf2, 0x80, vector_length);
                vuint32m2_t sign_i_u32m2 = __riscv_vsll_vx_u32m2(__riscv_vzext_vf4_u32m2(sign_i_u8mf2, vector_length),
                                                                 24, vector_length);
                vfloat32m2_t val_i_f32m2 = __riscv_vreinterpret_v_u32m2_f32m2(
                    __riscv_vor_vv_u32m2(bits_i_u32m2, sign_i_u32m2, vector_length));

                // Convert j-vector via LUT gather (same pipeline)
                vuint8mf2_t mag_j_u8mf2 = __riscv_vand_vx_u8mf2(raw_j_u8mf2, 0x7F, vector_length);
                vuint32m2_t idx_j_u32m2 = __riscv_vzext_vf4_u32m2(mag_j_u8mf2, vector_length);
                vuint32m2_t off_j_u32m2 = __riscv_vsll_vx_u32m2(idx_j_u32m2, 2, vector_length);
                vuint32m2_t bits_j_u32m2 = __riscv_vluxei32_v_u32m2(nk_e5m2_magnitude_lut_rvv_, off_j_u32m2,
                                                                    vector_length);
                vuint8mf2_t sign_j_u8mf2 = __riscv_vand_vx_u8mf2(raw_j_u8mf2, 0x80, vector_length);
                vuint32m2_t sign_j_u32m2 = __riscv_vsll_vx_u32m2(__riscv_vzext_vf4_u32m2(sign_j_u8mf2, vector_length),
                                                                 24, vector_length);
                vfloat32m2_t val_j_f32m2 = __riscv_vreinterpret_v_u32m2_f32m2(
                    __riscv_vor_vv_u32m2(bits_j_u32m2, sign_j_u32m2, vector_length));

                // Widening FMA: f32xf32 -> f64
                accumulator_f64m4 = __riscv_vfwmacc_vv_f64m4_tu(accumulator_f64m4, val_i_f32m2, val_j_f32m2,
                                                                vector_length);
            }
            // VLMAX(e32,m2) == VLMAX(e64,m4), so the reduction covers all lanes.
            vfloat64m1_t zero_f64m1 = __riscv_vfmv_v_f_f64m1(0.0, 1);
            nk_f32_t dot = (nk_f32_t)__riscv_vfmv_f_s_f64m1_f64(
                __riscv_vfredusum_vs_f64m4_f64m1(accumulator_f64m4, zero_f64m1, vlmax));
            result[i * result_stride_elements + j] = dot;
        }
    }
}
2471
+
2472
+ #pragma endregion // Quarter Precision E5M2
2473
+
2474
+ #if defined(__cplusplus)
2475
+ } // extern "C"
2476
+ #endif
2477
+
2478
+ #if defined(__clang__)
2479
+ #pragma clang attribute pop
2480
+ #elif defined(__GNUC__)
2481
+ #pragma GCC pop_options
2482
+ #endif
2483
+
2484
+ #endif // NK_TARGET_RVV
2485
+ #endif // NK_TARGET_RISCV_
2486
+ #endif // NK_DOTS_RVV_H