numkong 7.0.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (294) hide show
  1. package/LICENSE +201 -0
  2. package/README.md +495 -0
  3. package/binding.gyp +540 -0
  4. package/c/dispatch.h +512 -0
  5. package/c/dispatch_bf16.c +389 -0
  6. package/c/dispatch_bf16c.c +52 -0
  7. package/c/dispatch_e2m3.c +263 -0
  8. package/c/dispatch_e3m2.c +243 -0
  9. package/c/dispatch_e4m3.c +276 -0
  10. package/c/dispatch_e5m2.c +272 -0
  11. package/c/dispatch_f16.c +376 -0
  12. package/c/dispatch_f16c.c +58 -0
  13. package/c/dispatch_f32.c +378 -0
  14. package/c/dispatch_f32c.c +99 -0
  15. package/c/dispatch_f64.c +296 -0
  16. package/c/dispatch_f64c.c +98 -0
  17. package/c/dispatch_i16.c +96 -0
  18. package/c/dispatch_i32.c +89 -0
  19. package/c/dispatch_i4.c +150 -0
  20. package/c/dispatch_i64.c +86 -0
  21. package/c/dispatch_i8.c +289 -0
  22. package/c/dispatch_other.c +330 -0
  23. package/c/dispatch_u1.c +148 -0
  24. package/c/dispatch_u16.c +124 -0
  25. package/c/dispatch_u32.c +118 -0
  26. package/c/dispatch_u4.c +150 -0
  27. package/c/dispatch_u64.c +102 -0
  28. package/c/dispatch_u8.c +303 -0
  29. package/c/numkong.c +950 -0
  30. package/include/README.md +573 -0
  31. package/include/module.modulemap +129 -0
  32. package/include/numkong/attention/sapphireamx.h +1361 -0
  33. package/include/numkong/attention/sme.h +2066 -0
  34. package/include/numkong/attention.h +49 -0
  35. package/include/numkong/capabilities.h +748 -0
  36. package/include/numkong/cast/README.md +262 -0
  37. package/include/numkong/cast/haswell.h +975 -0
  38. package/include/numkong/cast/icelake.h +470 -0
  39. package/include/numkong/cast/neon.h +1192 -0
  40. package/include/numkong/cast/rvv.h +1021 -0
  41. package/include/numkong/cast/sapphire.h +262 -0
  42. package/include/numkong/cast/serial.h +2262 -0
  43. package/include/numkong/cast/skylake.h +856 -0
  44. package/include/numkong/cast/v128relaxed.h +180 -0
  45. package/include/numkong/cast.h +230 -0
  46. package/include/numkong/curved/README.md +223 -0
  47. package/include/numkong/curved/genoa.h +182 -0
  48. package/include/numkong/curved/haswell.h +276 -0
  49. package/include/numkong/curved/neon.h +205 -0
  50. package/include/numkong/curved/neonbfdot.h +212 -0
  51. package/include/numkong/curved/neonhalf.h +212 -0
  52. package/include/numkong/curved/rvv.h +305 -0
  53. package/include/numkong/curved/serial.h +207 -0
  54. package/include/numkong/curved/skylake.h +457 -0
  55. package/include/numkong/curved/smef64.h +506 -0
  56. package/include/numkong/curved.h +517 -0
  57. package/include/numkong/curved.hpp +144 -0
  58. package/include/numkong/dot/README.md +425 -0
  59. package/include/numkong/dot/alder.h +563 -0
  60. package/include/numkong/dot/genoa.h +315 -0
  61. package/include/numkong/dot/haswell.h +1688 -0
  62. package/include/numkong/dot/icelake.h +883 -0
  63. package/include/numkong/dot/neon.h +818 -0
  64. package/include/numkong/dot/neonbfdot.h +244 -0
  65. package/include/numkong/dot/neonfhm.h +360 -0
  66. package/include/numkong/dot/neonhalf.h +198 -0
  67. package/include/numkong/dot/neonsdot.h +508 -0
  68. package/include/numkong/dot/rvv.h +714 -0
  69. package/include/numkong/dot/rvvbb.h +72 -0
  70. package/include/numkong/dot/rvvbf16.h +123 -0
  71. package/include/numkong/dot/rvvhalf.h +129 -0
  72. package/include/numkong/dot/sapphire.h +141 -0
  73. package/include/numkong/dot/serial.h +838 -0
  74. package/include/numkong/dot/sierra.h +405 -0
  75. package/include/numkong/dot/skylake.h +1084 -0
  76. package/include/numkong/dot/sve.h +379 -0
  77. package/include/numkong/dot/svebfdot.h +74 -0
  78. package/include/numkong/dot/svehalf.h +123 -0
  79. package/include/numkong/dot/v128relaxed.h +1258 -0
  80. package/include/numkong/dot.h +1070 -0
  81. package/include/numkong/dot.hpp +94 -0
  82. package/include/numkong/dots/README.md +496 -0
  83. package/include/numkong/dots/alder.h +114 -0
  84. package/include/numkong/dots/genoa.h +94 -0
  85. package/include/numkong/dots/haswell.h +295 -0
  86. package/include/numkong/dots/icelake.h +171 -0
  87. package/include/numkong/dots/neon.h +120 -0
  88. package/include/numkong/dots/neonbfdot.h +58 -0
  89. package/include/numkong/dots/neonfhm.h +94 -0
  90. package/include/numkong/dots/neonhalf.h +57 -0
  91. package/include/numkong/dots/neonsdot.h +108 -0
  92. package/include/numkong/dots/rvv.h +2486 -0
  93. package/include/numkong/dots/sapphireamx.h +3973 -0
  94. package/include/numkong/dots/serial.h +2844 -0
  95. package/include/numkong/dots/sierra.h +97 -0
  96. package/include/numkong/dots/skylake.h +196 -0
  97. package/include/numkong/dots/sme.h +5372 -0
  98. package/include/numkong/dots/smebi32.h +461 -0
  99. package/include/numkong/dots/smef64.h +1318 -0
  100. package/include/numkong/dots/smehalf.h +47 -0
  101. package/include/numkong/dots/v128relaxed.h +294 -0
  102. package/include/numkong/dots.h +2804 -0
  103. package/include/numkong/dots.hpp +639 -0
  104. package/include/numkong/each/README.md +469 -0
  105. package/include/numkong/each/haswell.h +1658 -0
  106. package/include/numkong/each/icelake.h +272 -0
  107. package/include/numkong/each/neon.h +1104 -0
  108. package/include/numkong/each/neonbfdot.h +212 -0
  109. package/include/numkong/each/neonhalf.h +410 -0
  110. package/include/numkong/each/rvv.h +1121 -0
  111. package/include/numkong/each/sapphire.h +477 -0
  112. package/include/numkong/each/serial.h +260 -0
  113. package/include/numkong/each/skylake.h +1562 -0
  114. package/include/numkong/each.h +2146 -0
  115. package/include/numkong/each.hpp +434 -0
  116. package/include/numkong/geospatial/README.md +147 -0
  117. package/include/numkong/geospatial/haswell.h +593 -0
  118. package/include/numkong/geospatial/neon.h +571 -0
  119. package/include/numkong/geospatial/rvv.h +701 -0
  120. package/include/numkong/geospatial/serial.h +309 -0
  121. package/include/numkong/geospatial/skylake.h +577 -0
  122. package/include/numkong/geospatial/v128relaxed.h +613 -0
  123. package/include/numkong/geospatial.h +453 -0
  124. package/include/numkong/geospatial.hpp +235 -0
  125. package/include/numkong/matrix.hpp +336 -0
  126. package/include/numkong/maxsim/README.md +187 -0
  127. package/include/numkong/maxsim/alder.h +511 -0
  128. package/include/numkong/maxsim/genoa.h +115 -0
  129. package/include/numkong/maxsim/haswell.h +553 -0
  130. package/include/numkong/maxsim/icelake.h +480 -0
  131. package/include/numkong/maxsim/neonsdot.h +394 -0
  132. package/include/numkong/maxsim/sapphireamx.h +877 -0
  133. package/include/numkong/maxsim/serial.h +490 -0
  134. package/include/numkong/maxsim/sme.h +929 -0
  135. package/include/numkong/maxsim/v128relaxed.h +280 -0
  136. package/include/numkong/maxsim.h +571 -0
  137. package/include/numkong/maxsim.hpp +133 -0
  138. package/include/numkong/mesh/README.md +227 -0
  139. package/include/numkong/mesh/haswell.h +2235 -0
  140. package/include/numkong/mesh/neon.h +1329 -0
  141. package/include/numkong/mesh/neonbfdot.h +842 -0
  142. package/include/numkong/mesh/neonhalf.h +616 -0
  143. package/include/numkong/mesh/rvv.h +916 -0
  144. package/include/numkong/mesh/serial.h +742 -0
  145. package/include/numkong/mesh/skylake.h +1135 -0
  146. package/include/numkong/mesh/v128relaxed.h +1052 -0
  147. package/include/numkong/mesh.h +652 -0
  148. package/include/numkong/mesh.hpp +762 -0
  149. package/include/numkong/numkong.h +78 -0
  150. package/include/numkong/numkong.hpp +57 -0
  151. package/include/numkong/probability/README.md +173 -0
  152. package/include/numkong/probability/haswell.h +267 -0
  153. package/include/numkong/probability/neon.h +225 -0
  154. package/include/numkong/probability/rvv.h +409 -0
  155. package/include/numkong/probability/serial.h +169 -0
  156. package/include/numkong/probability/skylake.h +324 -0
  157. package/include/numkong/probability.h +383 -0
  158. package/include/numkong/probability.hpp +120 -0
  159. package/include/numkong/random.h +50 -0
  160. package/include/numkong/random.hpp +285 -0
  161. package/include/numkong/reduce/README.md +547 -0
  162. package/include/numkong/reduce/alder.h +632 -0
  163. package/include/numkong/reduce/genoa.h +201 -0
  164. package/include/numkong/reduce/haswell.h +3783 -0
  165. package/include/numkong/reduce/icelake.h +549 -0
  166. package/include/numkong/reduce/neon.h +3841 -0
  167. package/include/numkong/reduce/neonbfdot.h +353 -0
  168. package/include/numkong/reduce/neonfhm.h +665 -0
  169. package/include/numkong/reduce/neonhalf.h +157 -0
  170. package/include/numkong/reduce/neonsdot.h +357 -0
  171. package/include/numkong/reduce/rvv.h +3407 -0
  172. package/include/numkong/reduce/serial.h +757 -0
  173. package/include/numkong/reduce/sierra.h +338 -0
  174. package/include/numkong/reduce/skylake.h +3792 -0
  175. package/include/numkong/reduce/v128relaxed.h +2302 -0
  176. package/include/numkong/reduce.h +1597 -0
  177. package/include/numkong/reduce.hpp +633 -0
  178. package/include/numkong/scalar/README.md +89 -0
  179. package/include/numkong/scalar/haswell.h +113 -0
  180. package/include/numkong/scalar/neon.h +122 -0
  181. package/include/numkong/scalar/neonhalf.h +70 -0
  182. package/include/numkong/scalar/rvv.h +211 -0
  183. package/include/numkong/scalar/sapphire.h +63 -0
  184. package/include/numkong/scalar/serial.h +332 -0
  185. package/include/numkong/scalar/v128relaxed.h +56 -0
  186. package/include/numkong/scalar.h +683 -0
  187. package/include/numkong/set/README.md +179 -0
  188. package/include/numkong/set/haswell.h +334 -0
  189. package/include/numkong/set/icelake.h +485 -0
  190. package/include/numkong/set/neon.h +364 -0
  191. package/include/numkong/set/rvv.h +226 -0
  192. package/include/numkong/set/rvvbb.h +117 -0
  193. package/include/numkong/set/serial.h +174 -0
  194. package/include/numkong/set/sve.h +185 -0
  195. package/include/numkong/set/v128relaxed.h +240 -0
  196. package/include/numkong/set.h +457 -0
  197. package/include/numkong/set.hpp +114 -0
  198. package/include/numkong/sets/README.md +149 -0
  199. package/include/numkong/sets/haswell.h +63 -0
  200. package/include/numkong/sets/icelake.h +66 -0
  201. package/include/numkong/sets/neon.h +61 -0
  202. package/include/numkong/sets/serial.h +43 -0
  203. package/include/numkong/sets/smebi32.h +1099 -0
  204. package/include/numkong/sets/v128relaxed.h +58 -0
  205. package/include/numkong/sets.h +339 -0
  206. package/include/numkong/sparse/README.md +156 -0
  207. package/include/numkong/sparse/icelake.h +463 -0
  208. package/include/numkong/sparse/neon.h +288 -0
  209. package/include/numkong/sparse/serial.h +117 -0
  210. package/include/numkong/sparse/sve2.h +507 -0
  211. package/include/numkong/sparse/turin.h +322 -0
  212. package/include/numkong/sparse.h +363 -0
  213. package/include/numkong/sparse.hpp +113 -0
  214. package/include/numkong/spatial/README.md +435 -0
  215. package/include/numkong/spatial/alder.h +607 -0
  216. package/include/numkong/spatial/genoa.h +290 -0
  217. package/include/numkong/spatial/haswell.h +960 -0
  218. package/include/numkong/spatial/icelake.h +586 -0
  219. package/include/numkong/spatial/neon.h +773 -0
  220. package/include/numkong/spatial/neonbfdot.h +165 -0
  221. package/include/numkong/spatial/neonhalf.h +118 -0
  222. package/include/numkong/spatial/neonsdot.h +261 -0
  223. package/include/numkong/spatial/rvv.h +984 -0
  224. package/include/numkong/spatial/rvvbf16.h +123 -0
  225. package/include/numkong/spatial/rvvhalf.h +117 -0
  226. package/include/numkong/spatial/sapphire.h +343 -0
  227. package/include/numkong/spatial/serial.h +346 -0
  228. package/include/numkong/spatial/sierra.h +323 -0
  229. package/include/numkong/spatial/skylake.h +606 -0
  230. package/include/numkong/spatial/sve.h +224 -0
  231. package/include/numkong/spatial/svebfdot.h +122 -0
  232. package/include/numkong/spatial/svehalf.h +109 -0
  233. package/include/numkong/spatial/v128relaxed.h +717 -0
  234. package/include/numkong/spatial.h +1425 -0
  235. package/include/numkong/spatial.hpp +183 -0
  236. package/include/numkong/spatials/README.md +580 -0
  237. package/include/numkong/spatials/alder.h +94 -0
  238. package/include/numkong/spatials/genoa.h +94 -0
  239. package/include/numkong/spatials/haswell.h +219 -0
  240. package/include/numkong/spatials/icelake.h +113 -0
  241. package/include/numkong/spatials/neon.h +109 -0
  242. package/include/numkong/spatials/neonbfdot.h +60 -0
  243. package/include/numkong/spatials/neonfhm.h +92 -0
  244. package/include/numkong/spatials/neonhalf.h +58 -0
  245. package/include/numkong/spatials/neonsdot.h +109 -0
  246. package/include/numkong/spatials/rvv.h +1960 -0
  247. package/include/numkong/spatials/sapphireamx.h +1149 -0
  248. package/include/numkong/spatials/serial.h +226 -0
  249. package/include/numkong/spatials/sierra.h +96 -0
  250. package/include/numkong/spatials/skylake.h +184 -0
  251. package/include/numkong/spatials/sme.h +1901 -0
  252. package/include/numkong/spatials/smef64.h +465 -0
  253. package/include/numkong/spatials/v128relaxed.h +240 -0
  254. package/include/numkong/spatials.h +3021 -0
  255. package/include/numkong/spatials.hpp +508 -0
  256. package/include/numkong/tensor.hpp +1592 -0
  257. package/include/numkong/trigonometry/README.md +184 -0
  258. package/include/numkong/trigonometry/haswell.h +652 -0
  259. package/include/numkong/trigonometry/neon.h +639 -0
  260. package/include/numkong/trigonometry/rvv.h +699 -0
  261. package/include/numkong/trigonometry/serial.h +703 -0
  262. package/include/numkong/trigonometry/skylake.h +721 -0
  263. package/include/numkong/trigonometry/v128relaxed.h +666 -0
  264. package/include/numkong/trigonometry.h +467 -0
  265. package/include/numkong/trigonometry.hpp +166 -0
  266. package/include/numkong/types.h +1384 -0
  267. package/include/numkong/types.hpp +5603 -0
  268. package/include/numkong/vector.hpp +698 -0
  269. package/javascript/README.md +246 -0
  270. package/javascript/dist/cjs/numkong-wasm.d.ts +166 -0
  271. package/javascript/dist/cjs/numkong-wasm.js +617 -0
  272. package/javascript/dist/cjs/numkong.d.ts +343 -0
  273. package/javascript/dist/cjs/numkong.js +523 -0
  274. package/javascript/dist/cjs/package.json +3 -0
  275. package/javascript/dist/cjs/types.d.ts +284 -0
  276. package/javascript/dist/cjs/types.js +653 -0
  277. package/javascript/dist/esm/numkong-wasm.d.ts +166 -0
  278. package/javascript/dist/esm/numkong-wasm.js +595 -0
  279. package/javascript/dist/esm/numkong.d.ts +343 -0
  280. package/javascript/dist/esm/numkong.js +452 -0
  281. package/javascript/dist/esm/package.json +3 -0
  282. package/javascript/dist/esm/types.d.ts +284 -0
  283. package/javascript/dist/esm/types.js +630 -0
  284. package/javascript/dist-package-cjs.json +3 -0
  285. package/javascript/dist-package-esm.json +3 -0
  286. package/javascript/node-gyp-build.d.ts +1 -0
  287. package/javascript/numkong-wasm.ts +756 -0
  288. package/javascript/numkong.c +689 -0
  289. package/javascript/numkong.ts +575 -0
  290. package/javascript/tsconfig-base.json +39 -0
  291. package/javascript/tsconfig-cjs.json +8 -0
  292. package/javascript/tsconfig-esm.json +8 -0
  293. package/javascript/types.ts +674 -0
  294. package/package.json +87 -0
@@ -0,0 +1,714 @@
1
+ /**
2
+ * @brief SIMD-accelerated Dot Products for RISC-V.
3
+ * @file include/numkong/dot/rvv.h
4
+ * @author Ash Vardanian
5
+ * @date January 5, 2026
6
+ *
7
+ * @sa include/numkong/dot.h
8
+ *
9
+ * SpacemiT K1 and similar chips implement base RVV 1.0 without half-precision extensions.
10
+ * RVV uses vector length agnostic programming where:
11
+ * - `vsetvl_e*m*(n)` sets VL = min(n, VLMAX) and returns actual VL
12
+ * - Loads/stores with VL automatically handle partial vectors (tail elements)
13
+ * - No explicit masking needed for simple reductions
14
+ *
15
+ * This file contains base RVV 1.0 operations (i8, u8, f32, f64).
16
+ * For f16 (Zvfh) see rvvhalf.h, for bf16 (Zvfbfwma) see rvvbf16.h.
17
+ *
18
+ * Widening operations:
19
+ * - i8 ⨯ i8 → i16 via vwmul, then i16 reduction → i32 via vwredsum
20
+ * - f32 ⨯ f32 → f64 via vfwmul (for precision, like Skylake)
21
+ */
22
+ #ifndef NK_DOT_RVV_H
23
+ #define NK_DOT_RVV_H
24
+
25
+ #if NK_TARGET_RISCV_
26
+ #if NK_TARGET_RVV
27
+
28
+ #include "numkong/types.h"
29
+ #include "numkong/cast/rvv.h" // `nk_e4m3m1_to_f32m4_rvv_`
30
+ #include "numkong/set/rvv.h" // `nk_popcount_u8m4_rvv_`
31
+
32
+ #if defined(__clang__)
33
+ #pragma clang attribute push(__attribute__((target("arch=+v"))), apply_to = function)
34
+ #elif defined(__GNUC__)
35
+ #pragma GCC push_options
36
+ #pragma GCC target("arch=+v")
37
+ #endif
38
+
39
+ #if defined(__cplusplus)
40
+ extern "C" {
41
+ #endif
42
+
43
/** @brief Compensated horizontal sum of RVV f64m1 lanes via TwoSum tree reduction.
 *
 *  Inputs are a vector of partial sums and a vector of accumulated rounding errors
 *  (compensations), one pair per lane. The result is the compensated scalar total.
 *
 *  Uses vslidedown to extract the upper half at each tree level (same pattern as
 *  nk_reduce_vsaddu_u64m1_rvv_ in reduce/rvv.h). Tail lanes beyond vlmax are zero
 *  from the initial vfmv_v_f, so they are harmless in the reduction.
 *
 *  NOTE: the order of the floating-point operations below implements Knuth's TwoSum
 *  exactly; do not re-associate or "simplify" these expressions — the error terms
 *  would algebraically cancel but are numerically meaningful.
 */
NK_INTERNAL nk_f64_t nk_dot_stable_sum_f64m1_rvv_(vfloat64m1_t sum_f64m1, vfloat64m1_t compensation_f64m1) {
    nk_size_t vlmax = __riscv_vsetvlmax_e64m1();
    // Stage 0: TwoSum merge of sum + compensation.
    // tentative_sum = sum + compensation; accumulated_error recovers the exact
    // per-lane rounding error of that addition (Knuth TwoSum, 6 flops per lane).
    vfloat64m1_t tentative_sum_f64m1 = __riscv_vfadd_vv_f64m1(sum_f64m1, compensation_f64m1, vlmax);
    vfloat64m1_t virtual_addend_f64m1 = __riscv_vfsub_vv_f64m1(tentative_sum_f64m1, sum_f64m1, vlmax);
    vfloat64m1_t accumulated_error_f64m1 = __riscv_vfadd_vv_f64m1(
        __riscv_vfsub_vv_f64m1(sum_f64m1, __riscv_vfsub_vv_f64m1(tentative_sum_f64m1, virtual_addend_f64m1, vlmax),
                               vlmax),
        __riscv_vfsub_vv_f64m1(compensation_f64m1, virtual_addend_f64m1, vlmax), vlmax);
    // Tree reduction: TwoSum halving at each level.
    // Each pass folds the upper `half` lanes into the lower ones, carrying both the
    // running sums and the running error terms, plus the fresh rounding error of the fold.
    for (nk_size_t half = vlmax / 2; half > 0; half >>= 1) {
        vfloat64m1_t upper_sum_f64m1 = __riscv_vslidedown_vx_f64m1(tentative_sum_f64m1, half, vlmax);
        vfloat64m1_t upper_error_f64m1 = __riscv_vslidedown_vx_f64m1(accumulated_error_f64m1, half, vlmax);
        vfloat64m1_t halved_tentative_sum_f64m1 = __riscv_vfadd_vv_f64m1(tentative_sum_f64m1, upper_sum_f64m1, vlmax);
        vfloat64m1_t halved_virtual_addend_f64m1 = __riscv_vfsub_vv_f64m1(halved_tentative_sum_f64m1,
                                                                          tentative_sum_f64m1, vlmax);
        // TwoSum error of (tentative_sum + upper_sum), lane-wise.
        vfloat64m1_t rounding_error_f64m1 = __riscv_vfadd_vv_f64m1(
            __riscv_vfsub_vv_f64m1(
                tentative_sum_f64m1,
                __riscv_vfsub_vv_f64m1(halved_tentative_sum_f64m1, halved_virtual_addend_f64m1, vlmax), vlmax),
            __riscv_vfsub_vv_f64m1(upper_sum_f64m1, halved_virtual_addend_f64m1, vlmax), vlmax);
        tentative_sum_f64m1 = halved_tentative_sum_f64m1;
        // Errors are summed plainly: they are already tiny relative to the sums.
        accumulated_error_f64m1 = __riscv_vfadd_vv_f64m1(
            __riscv_vfadd_vv_f64m1(accumulated_error_f64m1, upper_error_f64m1, vlmax), rounding_error_f64m1, vlmax);
    }
    // Lane 0 now holds the full tree result; add the compensation last.
    return __riscv_vfmv_f_s_f64m1_f64(tentative_sum_f64m1) + __riscv_vfmv_f_s_f64m1_f64(accumulated_error_f64m1);
}
76
+
77
+ NK_PUBLIC void nk_dot_i8_rvv(nk_i8_t const *a_scalars, nk_i8_t const *b_scalars, nk_size_t count_scalars,
78
+ nk_i32_t *result) {
79
+ nk_size_t vlmax = __riscv_vsetvlmax_e32m4();
80
+ vint32m4_t sum_i32m4 = __riscv_vmv_v_x_i32m4(0, vlmax);
81
+ for (nk_size_t vector_length; count_scalars > 0;
82
+ count_scalars -= vector_length, a_scalars += vector_length, b_scalars += vector_length) {
83
+ vector_length = __riscv_vsetvl_e8m1(count_scalars);
84
+ vint8m1_t a_i8m1 = __riscv_vle8_v_i8m1(a_scalars, vector_length);
85
+ vint8m1_t b_i8m1 = __riscv_vle8_v_i8m1(b_scalars, vector_length);
86
+ // Widening multiply: i8 ⨯ i8 → i16
87
+ vint16m2_t ab_i16m2 = __riscv_vwmul_vv_i16m2(a_i8m1, b_i8m1, vector_length);
88
+ // Per-lane widening accumulate: i32 += i16
89
+ sum_i32m4 = __riscv_vwadd_wv_i32m4_tu(sum_i32m4, sum_i32m4, ab_i16m2, vector_length);
90
+ }
91
+ // Single horizontal reduction at the end
92
+ vint32m1_t zero_i32m1 = __riscv_vmv_v_x_i32m1(0, vlmax);
93
+ *result = __riscv_vmv_x_s_i32m1_i32(__riscv_vredsum_vs_i32m4_i32m1(sum_i32m4, zero_i32m1, vlmax));
94
+ }
95
+
96
+ NK_PUBLIC void nk_dot_u8_rvv(nk_u8_t const *a_scalars, nk_u8_t const *b_scalars, nk_size_t count_scalars,
97
+ nk_u32_t *result) {
98
+ nk_size_t vlmax = __riscv_vsetvlmax_e32m4();
99
+ vuint32m4_t sum_u32m4 = __riscv_vmv_v_x_u32m4(0, vlmax);
100
+ for (nk_size_t vector_length; count_scalars > 0;
101
+ count_scalars -= vector_length, a_scalars += vector_length, b_scalars += vector_length) {
102
+ vector_length = __riscv_vsetvl_e8m1(count_scalars);
103
+ vuint8m1_t a_u8m1 = __riscv_vle8_v_u8m1(a_scalars, vector_length);
104
+ vuint8m1_t b_u8m1 = __riscv_vle8_v_u8m1(b_scalars, vector_length);
105
+ // Widening multiply: u8 ⨯ u8 → u16
106
+ vuint16m2_t ab_u16m2 = __riscv_vwmulu_vv_u16m2(a_u8m1, b_u8m1, vector_length);
107
+ // Per-lane widening accumulate: u32 += u16
108
+ sum_u32m4 = __riscv_vwaddu_wv_u32m4_tu(sum_u32m4, sum_u32m4, ab_u16m2, vector_length);
109
+ }
110
+ // Single horizontal reduction at the end
111
+ vuint32m1_t zero_u32m1 = __riscv_vmv_v_x_u32m1(0, vlmax);
112
+ *result = __riscv_vmv_x_s_u32m1_u32(__riscv_vredsum_vs_u32m4_u32m1(sum_u32m4, zero_u32m1, vlmax));
113
+ }
114
+
115
+ NK_PUBLIC void nk_dot_f32_rvv(nk_f32_t const *a_scalars, nk_f32_t const *b_scalars, nk_size_t count_scalars,
116
+ nk_f64_t *result) {
117
+ nk_size_t vlmax = __riscv_vsetvlmax_e64m2();
118
+ vfloat64m2_t sum_f64m2 = __riscv_vfmv_v_f_f64m2(0.0, vlmax);
119
+ for (nk_size_t vector_length; count_scalars > 0;
120
+ count_scalars -= vector_length, a_scalars += vector_length, b_scalars += vector_length) {
121
+ vector_length = __riscv_vsetvl_e32m1(count_scalars);
122
+ vfloat32m1_t a_f32m1 = __riscv_vle32_v_f32m1(a_scalars, vector_length);
123
+ vfloat32m1_t b_f32m1 = __riscv_vle32_v_f32m1(b_scalars, vector_length);
124
+ // Widening FMA: f64 += f32 ⨯ f32, per-lane accumulation
125
+ sum_f64m2 = __riscv_vfwmacc_vv_f64m2_tu(sum_f64m2, a_f32m1, b_f32m1, vector_length);
126
+ }
127
+ // Single horizontal reduction at the end
128
+ vfloat64m1_t zero_f64m1 = __riscv_vfmv_v_f_f64m1(0.0, vlmax);
129
+ *result = __riscv_vfmv_f_s_f64m1_f64(__riscv_vfredusum_vs_f64m2_f64m1(sum_f64m2, zero_f64m1, vlmax));
130
+ }
131
+
132
/** @brief Dot product of two f64 vectors with compensated (Dot2) accumulation.
 *
 *  Implements the Ogita-Rump-Oishi Dot2 algorithm lane-wise: each a·b product is
 *  split into (product, product_error) via TwoProd, then folded into the running
 *  sum via TwoSum, with all error terms collected in a separate compensation vector.
 *  The exact operation order below is numerically load-bearing — do not re-associate.
 */
NK_PUBLIC void nk_dot_f64_rvv(nk_f64_t const *a_scalars, nk_f64_t const *b_scalars, nk_size_t count_scalars,
                              nk_f64_t *result) {
    // Dot2 (Ogita-Rump-Oishi) compensated accumulation via TwoProd + TwoSum
    nk_size_t vlmax = __riscv_vsetvlmax_e64m1();
    vfloat64m1_t sum_f64m1 = __riscv_vfmv_v_f_f64m1(0.0, vlmax);
    vfloat64m1_t compensation_f64m1 = __riscv_vfmv_v_f_f64m1(0.0, vlmax);
    for (nk_size_t vector_length; count_scalars > 0;
         count_scalars -= vector_length, a_scalars += vector_length, b_scalars += vector_length) {
        vector_length = __riscv_vsetvl_e64m1(count_scalars);
        vfloat64m1_t a_f64m1 = __riscv_vle64_v_f64m1(a_scalars, vector_length);
        vfloat64m1_t b_f64m1 = __riscv_vle64_v_f64m1(b_scalars, vector_length);
        // TwoProd: product = a*b, product_error = fma(a,b,-product) — the FMA recovers
        // the exact rounding error of the multiply in a single instruction.
        vfloat64m1_t product_f64m1 = __riscv_vfmul_vv_f64m1(a_f64m1, b_f64m1, vector_length);
        vfloat64m1_t product_error_f64m1 = __riscv_vfmsac_vv_f64m1(product_f64m1, a_f64m1, b_f64m1, vector_length);
        // TwoSum: tentative_sum = sum + product; sum_error recovers the exact
        // rounding error of that addition (Knuth's 6-flop branch-free TwoSum).
        vfloat64m1_t tentative_sum_f64m1 = __riscv_vfadd_vv_f64m1(sum_f64m1, product_f64m1, vector_length);
        vfloat64m1_t virtual_addend_f64m1 = __riscv_vfsub_vv_f64m1(tentative_sum_f64m1, sum_f64m1, vector_length);
        vfloat64m1_t sum_error_f64m1 = __riscv_vfadd_vv_f64m1(
            __riscv_vfsub_vv_f64m1(sum_f64m1,
                                   __riscv_vfsub_vv_f64m1(tentative_sum_f64m1, virtual_addend_f64m1, vector_length),
                                   vector_length),
            __riscv_vfsub_vv_f64m1(product_f64m1, virtual_addend_f64m1, vector_length), vector_length);
        // Tail-undisturbed updates: preserve zero tails across partial iterations.
        // The zero-offset vslideup is a tail-undisturbed "move" — lanes past
        // vector_length keep their previous values instead of becoming agnostic.
        sum_f64m1 = __riscv_vslideup_vx_f64m1_tu(sum_f64m1, tentative_sum_f64m1, 0, vector_length);
        vfloat64m1_t total_error_f64m1 = __riscv_vfadd_vv_f64m1(sum_error_f64m1, product_error_f64m1, vector_length);
        compensation_f64m1 = __riscv_vfadd_vv_f64m1_tu(compensation_f64m1, compensation_f64m1, total_error_f64m1,
                                                       vector_length);
    }
    // Compensated horizontal reduction of (sum, compensation) lane pairs
    *result = nk_dot_stable_sum_f64m1_rvv_(sum_f64m1, compensation_f64m1);
}
163
+
164
+ NK_PUBLIC void nk_dot_f16_rvv(nk_f16_t const *a_scalars, nk_f16_t const *b_scalars, nk_size_t count_scalars,
165
+ nk_f32_t *result) {
166
+ nk_size_t vlmax = __riscv_vsetvlmax_e32m2();
167
+ vfloat32m2_t sum_f32m2 = __riscv_vfmv_v_f_f32m2(0.0f, vlmax);
168
+ for (nk_size_t vector_length; count_scalars > 0;
169
+ count_scalars -= vector_length, a_scalars += vector_length, b_scalars += vector_length) {
170
+ vector_length = __riscv_vsetvl_e16m1(count_scalars);
171
+
172
+ // Load f16 as u16 bits and convert to f32 via helper
173
+ vuint16m1_t a_u16m1 = __riscv_vle16_v_u16m1((nk_u16_t const *)a_scalars, vector_length);
174
+ vuint16m1_t b_u16m1 = __riscv_vle16_v_u16m1((nk_u16_t const *)b_scalars, vector_length);
175
+ vfloat32m2_t a_f32m2 = nk_f16m1_to_f32m2_rvv_(a_u16m1, vector_length);
176
+ vfloat32m2_t b_f32m2 = nk_f16m1_to_f32m2_rvv_(b_u16m1, vector_length);
177
+
178
+ // Per-lane FMA accumulation
179
+ sum_f32m2 = __riscv_vfmacc_vv_f32m2_tu(sum_f32m2, a_f32m2, b_f32m2, vector_length);
180
+ }
181
+ // Single horizontal reduction at the end
182
+ vfloat32m1_t zero_f32m1 = __riscv_vfmv_v_f_f32m1(0.0f, vlmax);
183
+ *result = __riscv_vfmv_f_s_f32m1_f32(__riscv_vfredusum_vs_f32m2_f32m1(sum_f32m2, zero_f32m1, vlmax));
184
+ }
185
+
186
+ NK_PUBLIC void nk_dot_bf16_rvv(nk_bf16_t const *a_scalars, nk_bf16_t const *b_scalars, nk_size_t count_scalars,
187
+ nk_f32_t *result) {
188
+ nk_size_t vlmax = __riscv_vsetvlmax_e32m2();
189
+ vfloat32m2_t sum_f32m2 = __riscv_vfmv_v_f_f32m2(0.0f, vlmax);
190
+ for (nk_size_t vector_length; count_scalars > 0;
191
+ count_scalars -= vector_length, a_scalars += vector_length, b_scalars += vector_length) {
192
+ vector_length = __riscv_vsetvl_e16m1(count_scalars);
193
+
194
+ // Load bf16 as u16 and convert to f32 via helper
195
+ vuint16m1_t a_u16m1 = __riscv_vle16_v_u16m1((nk_u16_t const *)a_scalars, vector_length);
196
+ vuint16m1_t b_u16m1 = __riscv_vle16_v_u16m1((nk_u16_t const *)b_scalars, vector_length);
197
+ vfloat32m2_t a_f32m2 = nk_bf16m1_to_f32m2_rvv_(a_u16m1, vector_length);
198
+ vfloat32m2_t b_f32m2 = nk_bf16m1_to_f32m2_rvv_(b_u16m1, vector_length);
199
+
200
+ // Per-lane FMA accumulation
201
+ sum_f32m2 = __riscv_vfmacc_vv_f32m2_tu(sum_f32m2, a_f32m2, b_f32m2, vector_length);
202
+ }
203
+ // Single horizontal reduction at the end
204
+ vfloat32m1_t zero_f32m1 = __riscv_vfmv_v_f_f32m1(0.0f, vlmax);
205
+ *result = __riscv_vfmv_f_s_f32m1_f32(__riscv_vfredusum_vs_f32m2_f32m1(sum_f32m2, zero_f32m1, vlmax));
206
+ }
207
+
208
+ NK_PUBLIC void nk_dot_e4m3_rvv(nk_e4m3_t const *a_scalars, nk_e4m3_t const *b_scalars, nk_size_t count_scalars,
209
+ nk_f32_t *result) {
210
+ nk_size_t vlmax = __riscv_vsetvlmax_e32m4();
211
+ vfloat32m4_t sum_f32m4 = __riscv_vfmv_v_f_f32m4(0.0f, vlmax);
212
+ for (nk_size_t vector_length; count_scalars > 0;
213
+ count_scalars -= vector_length, a_scalars += vector_length, b_scalars += vector_length) {
214
+ vector_length = __riscv_vsetvl_e8m1(count_scalars);
215
+
216
+ // Load e4m3 as u8 and convert to f32 via helper
217
+ vuint8m1_t a_u8m1 = __riscv_vle8_v_u8m1((nk_u8_t const *)a_scalars, vector_length);
218
+ vuint8m1_t b_u8m1 = __riscv_vle8_v_u8m1((nk_u8_t const *)b_scalars, vector_length);
219
+ vfloat32m4_t a_f32m4 = nk_e4m3m1_to_f32m4_rvv_(a_u8m1, vector_length);
220
+ vfloat32m4_t b_f32m4 = nk_e4m3m1_to_f32m4_rvv_(b_u8m1, vector_length);
221
+
222
+ // Per-lane FMA accumulation
223
+ sum_f32m4 = __riscv_vfmacc_vv_f32m4_tu(sum_f32m4, a_f32m4, b_f32m4, vector_length);
224
+ }
225
+ // Single horizontal reduction at the end
226
+ vfloat32m1_t zero_f32m1 = __riscv_vfmv_v_f_f32m1(0.0f, vlmax);
227
+ *result = __riscv_vfmv_f_s_f32m1_f32(__riscv_vfredusum_vs_f32m4_f32m1(sum_f32m4, zero_f32m1, vlmax));
228
+ }
229
+
230
+ NK_PUBLIC void nk_dot_e5m2_rvv(nk_e5m2_t const *a_scalars, nk_e5m2_t const *b_scalars, nk_size_t count_scalars,
231
+ nk_f32_t *result) {
232
+ nk_size_t vlmax = __riscv_vsetvlmax_e32m4();
233
+ vfloat32m4_t sum_f32m4 = __riscv_vfmv_v_f_f32m4(0.0f, vlmax);
234
+ for (nk_size_t vector_length; count_scalars > 0;
235
+ count_scalars -= vector_length, a_scalars += vector_length, b_scalars += vector_length) {
236
+ vector_length = __riscv_vsetvl_e8m1(count_scalars);
237
+
238
+ // Load e5m2 as u8 and convert to f32 via helper
239
+ vuint8m1_t a_u8m1 = __riscv_vle8_v_u8m1((nk_u8_t const *)a_scalars, vector_length);
240
+ vuint8m1_t b_u8m1 = __riscv_vle8_v_u8m1((nk_u8_t const *)b_scalars, vector_length);
241
+ vfloat32m4_t a_f32m4 = nk_e5m2m1_to_f32m4_rvv_(a_u8m1, vector_length);
242
+ vfloat32m4_t b_f32m4 = nk_e5m2m1_to_f32m4_rvv_(b_u8m1, vector_length);
243
+
244
+ // Per-lane FMA accumulation
245
+ sum_f32m4 = __riscv_vfmacc_vv_f32m4_tu(sum_f32m4, a_f32m4, b_f32m4, vector_length);
246
+ }
247
+ // Single horizontal reduction at the end
248
+ vfloat32m1_t zero_f32m1 = __riscv_vfmv_v_f_f32m1(0.0f, vlmax);
249
+ *result = __riscv_vfmv_f_s_f32m1_f32(__riscv_vfredusum_vs_f32m4_f32m1(sum_f32m4, zero_f32m1, vlmax));
250
+ }
251
+
252
/** @brief Dot product of two e2m3 (6-bit minifloat) vectors, computed exactly in integers.
 *
 *  Every e2m3 value scaled by 16 is an exact integer in [-120, +120], so the whole
 *  dot product can be done in i8/i16/i32 integer arithmetic with zero rounding error;
 *  the only float operation is the final division by 256 (= 16 × 16), which is exact.
 */
NK_PUBLIC void nk_dot_e2m3_rvv(nk_e2m3_t const *a_scalars, nk_e2m3_t const *b_scalars, nk_size_t count_scalars,
                               nk_f32_t *result) {
    // Integer dot product for e2m3 using byte gather LUT + widening multiply.
    // Every e2m3 value × 16 is an exact integer in [-120, +120].
    // Result = i32_dot / 256.0f (exact, no rounding error).
    // LUT maps the 5-bit magnitude field (exponent + mantissa) to |value| × 16.
    static nk_u8_t const lut_magnitude[32] = {0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30,
                                              32, 36, 40, 44, 48, 52, 56, 60, 64, 72, 80, 88, 96, 104, 112, 120};

    nk_size_t vlmax = __riscv_vsetvlmax_e32m4();
    vint32m4_t sum_i32m4 = __riscv_vmv_v_x_i32m4(0, vlmax);
    for (nk_size_t vector_length; count_scalars > 0;
         count_scalars -= vector_length, a_scalars += vector_length, b_scalars += vector_length) {
        vector_length = __riscv_vsetvl_e8m1(count_scalars);
        vuint8m1_t a_e2m3_u8m1 = __riscv_vle8_v_u8m1((nk_u8_t const *)a_scalars, vector_length);
        vuint8m1_t b_e2m3_u8m1 = __riscv_vle8_v_u8m1((nk_u8_t const *)b_scalars, vector_length);

        // Magnitude extraction (low 5 bits) + byte gather LUT → unsigned |value| × 16
        vuint8m1_t a_magnitude_u8m1 = __riscv_vand_vx_u8m1(a_e2m3_u8m1, 0x1F, vector_length);
        vuint8m1_t b_magnitude_u8m1 = __riscv_vand_vx_u8m1(b_e2m3_u8m1, 0x1F, vector_length);
        vuint8m1_t a_unsigned_u8m1 = __riscv_vluxei8_v_u8m1(lut_magnitude, a_magnitude_u8m1, vector_length);
        vuint8m1_t b_unsigned_u8m1 = __riscv_vluxei8_v_u8m1(lut_magnitude, b_magnitude_u8m1, vector_length);

        // Combined sign (XOR of the two sign bits, bit 0x20) + conditional negate.
        // Only `b` is negated — the product sign is carried entirely by one operand,
        // so `a` can stay non-negative.
        vuint8m1_t sign_combined_u8m1 = __riscv_vand_vx_u8m1(
            __riscv_vxor_vv_u8m1(a_e2m3_u8m1, b_e2m3_u8m1, vector_length), 0x20, vector_length);
        vbool8_t negate_mask_b8 = __riscv_vmsne_vx_u8m1_b8(sign_combined_u8m1, 0, vector_length);
        vint8m1_t b_positive_i8m1 = __riscv_vreinterpret_v_u8m1_i8m1(b_unsigned_u8m1);
        vint8m1_t b_negated_i8m1 = __riscv_vneg_v_i8m1(b_positive_i8m1, vector_length);
        vint8m1_t b_signed_i8m1 = __riscv_vmerge_vvm_i8m1(b_positive_i8m1, b_negated_i8m1, negate_mask_b8,
                                                          vector_length);

        // Widening multiply: i8×i8 → i16 (|products| ≤ 120² = 14400 fits i16),
        // then tail-undisturbed accumulate: i32 += i16
        vint8m1_t a_signed_i8m1 = __riscv_vreinterpret_v_u8m1_i8m1(a_unsigned_u8m1);
        vint16m2_t products_i16m2 = __riscv_vwmul_vv_i16m2(a_signed_i8m1, b_signed_i8m1, vector_length);
        sum_i32m4 = __riscv_vwadd_wv_i32m4_tu(sum_i32m4, sum_i32m4, products_i16m2, vector_length);
    }
    vint32m1_t zero_i32m1 = __riscv_vmv_v_x_i32m1(0, vlmax);
    nk_i32_t sum = __riscv_vmv_x_s_i32m1_i32(__riscv_vredsum_vs_i32m4_i32m1(sum_i32m4, zero_i32m1, vlmax));
    // Undo the ×16 scaling on each operand: divide by 256 (exact in f32)
    *result = (nk_f32_t)sum / 256.0f;
}
292
+
293
+ NK_PUBLIC void nk_dot_e3m2_rvv(nk_e3m2_t const *a_scalars, nk_e3m2_t const *b_scalars, nk_size_t count_scalars,
294
+ nk_f32_t *result) {
295
+ // Integer dot product for e3m2 using i16 gather LUT + widening multiply.
296
+ // Every e3m2 value × 16 is an exact integer, but magnitudes reach 448, requiring i16.
297
+ // Result = i32_dot / 256.0f (exact, no rounding error).
298
+ static nk_u16_t const lut_magnitude[32] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 10, 12, 14, 16, 20, 24, 28,
299
+ 32, 40, 48, 56, 64, 80, 96, 112, 128, 160, 192, 224, 256, 320, 384, 448};
300
+
301
+ nk_size_t vlmax = __riscv_vsetvlmax_e32m4();
302
+ vint32m4_t sum_i32m4 = __riscv_vmv_v_x_i32m4(0, vlmax);
303
+ for (nk_size_t vector_length; count_scalars > 0;
304
+ count_scalars -= vector_length, a_scalars += vector_length, b_scalars += vector_length) {
305
+ vector_length = __riscv_vsetvl_e8m1(count_scalars);
306
+ vuint8m1_t a_e3m2_u8m1 = __riscv_vle8_v_u8m1((nk_u8_t const *)a_scalars, vector_length);
307
+ vuint8m1_t b_e3m2_u8m1 = __riscv_vle8_v_u8m1((nk_u8_t const *)b_scalars, vector_length);
308
+
309
+ // Magnitude extraction: lower 5 bits as u16 byte offsets for gather
310
+ vuint8m1_t a_mag_u8m1 = __riscv_vand_vx_u8m1(a_e3m2_u8m1, 0x1F, vector_length);
311
+ vuint8m1_t b_mag_u8m1 = __riscv_vand_vx_u8m1(b_e3m2_u8m1, 0x1F, vector_length);
312
+ vuint16m2_t a_idx_u16m2 = __riscv_vzext_vf2_u16m2(a_mag_u8m1, vector_length);
313
+ vuint16m2_t b_idx_u16m2 = __riscv_vzext_vf2_u16m2(b_mag_u8m1, vector_length);
314
+
315
+ // Gather from i16 LUT: byte offsets = index × 2
316
+ vuint16m2_t a_byte_offsets_u16m2 = __riscv_vsll_vx_u16m2(a_idx_u16m2, 1, vector_length);
317
+ vuint16m2_t b_byte_offsets_u16m2 = __riscv_vsll_vx_u16m2(b_idx_u16m2, 1, vector_length);
318
+ vuint16m2_t a_unsigned_u16m2 = __riscv_vluxei16_v_u16m2(lut_magnitude, a_byte_offsets_u16m2, vector_length);
319
+ vuint16m2_t b_unsigned_u16m2 = __riscv_vluxei16_v_u16m2(lut_magnitude, b_byte_offsets_u16m2, vector_length);
320
+
321
+ // Extract sign bits and apply conditional negate
322
+ vuint8m1_t a_sign_u8m1 = __riscv_vand_vx_u8m1(a_e3m2_u8m1, 0x20, vector_length);
323
+ vuint8m1_t b_sign_u8m1 = __riscv_vand_vx_u8m1(b_e3m2_u8m1, 0x20, vector_length);
324
+ vbool8_t a_negate_b8 = __riscv_vmsne_vx_u8m1_b8(a_sign_u8m1, 0, vector_length);
325
+ vbool8_t b_negate_b8 = __riscv_vmsne_vx_u8m1_b8(b_sign_u8m1, 0, vector_length);
326
+
327
+ vint16m2_t a_signed_i16m2 = __riscv_vreinterpret_v_u16m2_i16m2(a_unsigned_u16m2);
328
+ a_signed_i16m2 = __riscv_vneg_v_i16m2_mu(a_negate_b8, a_signed_i16m2, a_signed_i16m2, vector_length);
329
+
330
+ vint16m2_t b_signed_i16m2 = __riscv_vreinterpret_v_u16m2_i16m2(b_unsigned_u16m2);
331
+ b_signed_i16m2 = __riscv_vneg_v_i16m2_mu(b_negate_b8, b_signed_i16m2, b_signed_i16m2, vector_length);
332
+
333
+ // Widening multiply-accumulate: i16×i16 → i32
334
+ sum_i32m4 = __riscv_vwmacc_vv_i32m4_tu(sum_i32m4, a_signed_i16m2, b_signed_i16m2, vector_length);
335
+ }
336
+ vint32m1_t zero_i32m1 = __riscv_vmv_v_x_i32m1(0, vlmax);
337
+ nk_i32_t sum = __riscv_vmv_x_s_i32m1_i32(__riscv_vredsum_vs_i32m4_i32m1(sum_i32m4, zero_i32m1, vlmax));
338
+ *result = (nk_f32_t)sum / 256.0f;
339
+ }
340
+
341
+ NK_PUBLIC void nk_dot_i4_rvv(nk_i4x2_t const *a_scalars, nk_i4x2_t const *b_scalars, nk_size_t count_dimensions,
342
+ nk_i32_t *result) {
343
+ // count_dimensions = number of 4-bit values, not bytes
344
+ count_dimensions = nk_size_round_up_to_multiple_(count_dimensions, 2);
345
+ nk_size_t n_full_bytes = count_dimensions / 2;
346
+
347
+ nk_size_t vlmax = __riscv_vsetvlmax_e32m4();
348
+ vint32m4_t sum_i32m4 = __riscv_vmv_v_x_i32m4(0, vlmax);
349
+ for (nk_size_t vector_length; n_full_bytes > 0;
350
+ n_full_bytes -= vector_length, a_scalars += vector_length, b_scalars += vector_length) {
351
+ vector_length = __riscv_vsetvl_e8m1(n_full_bytes);
352
+
353
+ vuint8m1_t a_packed_u8m1 = __riscv_vle8_v_u8m1((nk_u8_t const *)a_scalars, vector_length);
354
+ vuint8m1_t b_packed_u8m1 = __riscv_vle8_v_u8m1((nk_u8_t const *)b_scalars, vector_length);
355
+
356
+ vuint8m1_t a_high_u8m1 = __riscv_vsrl_vx_u8m1(a_packed_u8m1, 4, vector_length);
357
+ vuint8m1_t b_high_u8m1 = __riscv_vsrl_vx_u8m1(b_packed_u8m1, 4, vector_length);
358
+ vuint8m1_t a_low_u8m1 = __riscv_vand_vx_u8m1(a_packed_u8m1, 0x0F, vector_length);
359
+ vuint8m1_t b_low_u8m1 = __riscv_vand_vx_u8m1(b_packed_u8m1, 0x0F, vector_length);
360
+
361
+ // Sign extend 4-bit to 8-bit: (x ^ 8) - 8
362
+ vint8m1_t a_high_i8m1 = __riscv_vsub_vx_i8m1(
363
+ __riscv_vxor_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(a_high_u8m1), 8, vector_length), 8, vector_length);
364
+ vint8m1_t b_high_i8m1 = __riscv_vsub_vx_i8m1(
365
+ __riscv_vxor_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(b_high_u8m1), 8, vector_length), 8, vector_length);
366
+ vint8m1_t a_low_i8m1 = __riscv_vsub_vx_i8m1(
367
+ __riscv_vxor_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(a_low_u8m1), 8, vector_length), 8, vector_length);
368
+ vint8m1_t b_low_i8m1 = __riscv_vsub_vx_i8m1(
369
+ __riscv_vxor_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(b_low_u8m1), 8, vector_length), 8, vector_length);
370
+
371
+ // Widening multiply: i8 ⨯ i8 → i16
372
+ vint16m2_t ab_high_i16m2 = __riscv_vwmul_vv_i16m2(a_high_i8m1, b_high_i8m1, vector_length);
373
+ vint16m2_t ab_low_i16m2 = __riscv_vwmul_vv_i16m2(a_low_i8m1, b_low_i8m1, vector_length);
374
+
375
+ // Per-lane widening accumulate: i32 += i16
376
+ sum_i32m4 = __riscv_vwadd_wv_i32m4_tu(sum_i32m4, sum_i32m4, ab_high_i16m2, vector_length);
377
+ sum_i32m4 = __riscv_vwadd_wv_i32m4_tu(sum_i32m4, sum_i32m4, ab_low_i16m2, vector_length);
378
+ }
379
+ // Single horizontal reduction at the end
380
+ vint32m1_t zero_i32m1 = __riscv_vmv_v_x_i32m1(0, vlmax);
381
+ *result = __riscv_vmv_x_s_i32m1_i32(__riscv_vredsum_vs_i32m4_i32m1(sum_i32m4, zero_i32m1, vlmax));
382
+ }
383
+
384
+ NK_PUBLIC void nk_dot_u4_rvv(nk_u4x2_t const *a_scalars, nk_u4x2_t const *b_scalars, nk_size_t count_dimensions,
385
+ nk_u32_t *result) {
386
+ // count_dimensions = number of 4-bit values, not bytes
387
+ count_dimensions = nk_size_round_up_to_multiple_(count_dimensions, 2);
388
+ nk_size_t n_full_bytes = count_dimensions / 2;
389
+
390
+ nk_size_t vlmax = __riscv_vsetvlmax_e32m4();
391
+ vuint32m4_t sum_u32m4 = __riscv_vmv_v_x_u32m4(0, vlmax);
392
+ for (nk_size_t vector_length; n_full_bytes > 0;
393
+ n_full_bytes -= vector_length, a_scalars += vector_length, b_scalars += vector_length) {
394
+ vector_length = __riscv_vsetvl_e8m1(n_full_bytes);
395
+
396
+ vuint8m1_t a_packed_u8m1 = __riscv_vle8_v_u8m1((nk_u8_t const *)a_scalars, vector_length);
397
+ vuint8m1_t b_packed_u8m1 = __riscv_vle8_v_u8m1((nk_u8_t const *)b_scalars, vector_length);
398
+
399
+ vuint8m1_t a_high_u8m1 = __riscv_vsrl_vx_u8m1(a_packed_u8m1, 4, vector_length);
400
+ vuint8m1_t b_high_u8m1 = __riscv_vsrl_vx_u8m1(b_packed_u8m1, 4, vector_length);
401
+ vuint8m1_t a_low_u8m1 = __riscv_vand_vx_u8m1(a_packed_u8m1, 0x0F, vector_length);
402
+ vuint8m1_t b_low_u8m1 = __riscv_vand_vx_u8m1(b_packed_u8m1, 0x0F, vector_length);
403
+
404
+ // Widening multiply: u8 ⨯ u8 → u16
405
+ vuint16m2_t ab_high_u16m2 = __riscv_vwmulu_vv_u16m2(a_high_u8m1, b_high_u8m1, vector_length);
406
+ vuint16m2_t ab_low_u16m2 = __riscv_vwmulu_vv_u16m2(a_low_u8m1, b_low_u8m1, vector_length);
407
+
408
+ // Per-lane widening accumulate: u32 += u16
409
+ sum_u32m4 = __riscv_vwaddu_wv_u32m4_tu(sum_u32m4, sum_u32m4, ab_high_u16m2, vector_length);
410
+ sum_u32m4 = __riscv_vwaddu_wv_u32m4_tu(sum_u32m4, sum_u32m4, ab_low_u16m2, vector_length);
411
+ }
412
+ // Single horizontal reduction at the end
413
+ vuint32m1_t zero_u32m1 = __riscv_vmv_v_x_u32m1(0, vlmax);
414
+ *result = __riscv_vmv_x_s_u32m1_u32(__riscv_vredsum_vs_u32m4_u32m1(sum_u32m4, zero_u32m1, vlmax));
415
+ }
416
+
417
+ NK_PUBLIC void nk_dot_u1_rvv(nk_u1x8_t const *a, nk_u1x8_t const *b, nk_size_t n_bits, nk_u32_t *result) {
418
+ nk_size_t count_bytes = nk_size_divide_round_up_(n_bits, NK_BITS_PER_BYTE);
419
+
420
+ vuint32m1_t sum_u32m1 = __riscv_vmv_v_x_u32m1(0, 1);
421
+
422
+ nk_size_t i = 0;
423
+ for (nk_size_t vector_length; i + 1 <= count_bytes; i += vector_length) {
424
+ vector_length = __riscv_vsetvl_e8m4(count_bytes - i);
425
+
426
+ // Load and AND to find shared bits (dot product of binary vectors)
427
+ vuint8m4_t a_u8m4 = __riscv_vle8_v_u8m4(a + i, vector_length);
428
+ vuint8m4_t b_u8m4 = __riscv_vle8_v_u8m4(b + i, vector_length);
429
+ vuint8m4_t and_u8m4 = __riscv_vand_vv_u8m4(a_u8m4, b_u8m4, vector_length);
430
+
431
+ // Popcount each byte using arithmetic SWAR
432
+ vuint8m4_t popcount_u8m4 = nk_popcount_u8m4_rvv_(and_u8m4, vector_length);
433
+
434
+ // Widen to u16 and accumulate via widening reduction sum
435
+ vuint16m8_t popcount_u16m8 = __riscv_vwaddu_vx_u16m8(popcount_u8m4, 0, vector_length);
436
+ sum_u32m1 = __riscv_vwredsumu_vs_u16m8_u32m1(popcount_u16m8, sum_u32m1, vector_length);
437
+ }
438
+
439
+ *result = __riscv_vmv_x_s_u32m1_u32(sum_u32m1);
440
+ }
441
+
442
+ NK_PUBLIC void nk_dot_f32c_rvv(nk_f32c_t const *a_pairs, nk_f32c_t const *b_pairs, nk_size_t count_pairs,
443
+ nk_f64c_t *results) {
444
+ nk_f32_t const *a_f32 = (nk_f32_t const *)a_pairs;
445
+ nk_f32_t const *b_f32 = (nk_f32_t const *)b_pairs;
446
+ nk_size_t vlmax = __riscv_vsetvlmax_e64m2();
447
+ vfloat64m2_t sum_real_f64m2 = __riscv_vfmv_v_f_f64m2(0.0, vlmax);
448
+ vfloat64m2_t sum_imag_f64m2 = __riscv_vfmv_v_f_f64m2(0.0, vlmax);
449
+ for (nk_size_t vector_length; count_pairs > 0;
450
+ count_pairs -= vector_length, a_f32 += vector_length * 2, b_f32 += vector_length * 2) {
451
+ vector_length = __riscv_vsetvl_e32m1(count_pairs);
452
+ vfloat32m1x2_t a_f32m1x2 = __riscv_vlseg2e32_v_f32m1x2(a_f32, vector_length);
453
+ vfloat32m1x2_t b_f32m1x2 = __riscv_vlseg2e32_v_f32m1x2(b_f32, vector_length);
454
+ vfloat32m1_t a_real_f32m1 = __riscv_vget_v_f32m1x2_f32m1(a_f32m1x2, 0);
455
+ vfloat32m1_t a_imag_f32m1 = __riscv_vget_v_f32m1x2_f32m1(a_f32m1x2, 1);
456
+ vfloat32m1_t b_real_f32m1 = __riscv_vget_v_f32m1x2_f32m1(b_f32m1x2, 0);
457
+ vfloat32m1_t b_imag_f32m1 = __riscv_vget_v_f32m1x2_f32m1(b_f32m1x2, 1);
458
+ // real += a_real * b_real - a_imag * b_imag
459
+ sum_real_f64m2 = __riscv_vfwmacc_vv_f64m2_tu(sum_real_f64m2, a_real_f32m1, b_real_f32m1, vector_length);
460
+ sum_real_f64m2 = __riscv_vfwnmsac_vv_f64m2_tu(sum_real_f64m2, a_imag_f32m1, b_imag_f32m1, vector_length);
461
+ // imag += a_real * b_imag + a_imag * b_real
462
+ sum_imag_f64m2 = __riscv_vfwmacc_vv_f64m2_tu(sum_imag_f64m2, a_real_f32m1, b_imag_f32m1, vector_length);
463
+ sum_imag_f64m2 = __riscv_vfwmacc_vv_f64m2_tu(sum_imag_f64m2, a_imag_f32m1, b_real_f32m1, vector_length);
464
+ }
465
+ vfloat64m1_t zero_f64m1 = __riscv_vfmv_v_f_f64m1(0.0, vlmax);
466
+ results->real = __riscv_vfmv_f_s_f64m1_f64(__riscv_vfredusum_vs_f64m2_f64m1(sum_real_f64m2, zero_f64m1, vlmax));
467
+ results->imag = __riscv_vfmv_f_s_f64m1_f64(__riscv_vfredusum_vs_f64m2_f64m1(sum_imag_f64m2, zero_f64m1, vlmax));
468
+ }
469
+
470
+ NK_PUBLIC void nk_vdot_f32c_rvv(nk_f32c_t const *a_pairs, nk_f32c_t const *b_pairs, nk_size_t count_pairs,
471
+ nk_f64c_t *results) {
472
+ nk_f32_t const *a_f32 = (nk_f32_t const *)a_pairs;
473
+ nk_f32_t const *b_f32 = (nk_f32_t const *)b_pairs;
474
+ nk_size_t vlmax = __riscv_vsetvlmax_e64m2();
475
+ vfloat64m2_t sum_real_f64m2 = __riscv_vfmv_v_f_f64m2(0.0, vlmax);
476
+ vfloat64m2_t sum_imag_f64m2 = __riscv_vfmv_v_f_f64m2(0.0, vlmax);
477
+ for (nk_size_t vector_length; count_pairs > 0;
478
+ count_pairs -= vector_length, a_f32 += vector_length * 2, b_f32 += vector_length * 2) {
479
+ vector_length = __riscv_vsetvl_e32m1(count_pairs);
480
+ vfloat32m1x2_t a_f32m1x2 = __riscv_vlseg2e32_v_f32m1x2(a_f32, vector_length);
481
+ vfloat32m1x2_t b_f32m1x2 = __riscv_vlseg2e32_v_f32m1x2(b_f32, vector_length);
482
+ vfloat32m1_t a_real_f32m1 = __riscv_vget_v_f32m1x2_f32m1(a_f32m1x2, 0);
483
+ vfloat32m1_t a_imag_f32m1 = __riscv_vget_v_f32m1x2_f32m1(a_f32m1x2, 1);
484
+ vfloat32m1_t b_real_f32m1 = __riscv_vget_v_f32m1x2_f32m1(b_f32m1x2, 0);
485
+ vfloat32m1_t b_imag_f32m1 = __riscv_vget_v_f32m1x2_f32m1(b_f32m1x2, 1);
486
+ // Conjugate dot: real += a_real * b_real + a_imag * b_imag
487
+ sum_real_f64m2 = __riscv_vfwmacc_vv_f64m2_tu(sum_real_f64m2, a_real_f32m1, b_real_f32m1, vector_length);
488
+ sum_real_f64m2 = __riscv_vfwmacc_vv_f64m2_tu(sum_real_f64m2, a_imag_f32m1, b_imag_f32m1, vector_length);
489
+ // Conjugate dot: imag += a_real * b_imag - a_imag * b_real
490
+ sum_imag_f64m2 = __riscv_vfwmacc_vv_f64m2_tu(sum_imag_f64m2, a_real_f32m1, b_imag_f32m1, vector_length);
491
+ sum_imag_f64m2 = __riscv_vfwnmsac_vv_f64m2_tu(sum_imag_f64m2, a_imag_f32m1, b_real_f32m1, vector_length);
492
+ }
493
+ vfloat64m1_t zero_f64m1 = __riscv_vfmv_v_f_f64m1(0.0, vlmax);
494
+ results->real = __riscv_vfmv_f_s_f64m1_f64(__riscv_vfredusum_vs_f64m2_f64m1(sum_real_f64m2, zero_f64m1, vlmax));
495
+ results->imag = __riscv_vfmv_f_s_f64m1_f64(__riscv_vfredusum_vs_f64m2_f64m1(sum_imag_f64m2, zero_f64m1, vlmax));
496
+ }
497
+
498
NK_PUBLIC void nk_dot_f64c_rvv(nk_f64c_t const *a_pairs, nk_f64c_t const *b_pairs, nk_size_t count_pairs,
                               nk_f64c_t *results) {
    // Dot2 (Ogita-Rump-Oishi) compensated complex dot product
    //
    // Each of the four real-valued partial sums (ar·br, ai·bi, ar·bi, ai·br) is
    // accumulated with error-free transformations: TwoProd recovers the exact
    // rounding error of every multiplication (via FMA), and Knuth's TwoSum recovers
    // the exact rounding error of every addition. The errors are collected in
    // per-lane compensation registers and folded back in at the final reduction,
    // giving roughly twice the working precision of a naive f64 dot product.
    nk_f64_t const *a_f64 = (nk_f64_t const *)a_pairs;
    nk_f64_t const *b_f64 = (nk_f64_t const *)b_pairs;
    nk_size_t vlmax = __riscv_vsetvlmax_e64m1();
    vfloat64m1_t sum_real_f64m1 = __riscv_vfmv_v_f_f64m1(0.0, vlmax);
    vfloat64m1_t comp_real_f64m1 = __riscv_vfmv_v_f_f64m1(0.0, vlmax); // rounding-error carry, real part
    vfloat64m1_t sum_imag_f64m1 = __riscv_vfmv_v_f_f64m1(0.0, vlmax);
    vfloat64m1_t comp_imag_f64m1 = __riscv_vfmv_v_f_f64m1(0.0, vlmax); // rounding-error carry, imaginary part
    for (nk_size_t vector_length; count_pairs > 0;
         count_pairs -= vector_length, a_f64 += vector_length * 2, b_f64 += vector_length * 2) {
        vector_length = __riscv_vsetvl_e64m1(count_pairs);
        // Segmented loads deinterleave the (real, imag) pairs into two registers each.
        vfloat64m1x2_t a_f64m1x2 = __riscv_vlseg2e64_v_f64m1x2(a_f64, vector_length);
        vfloat64m1x2_t b_f64m1x2 = __riscv_vlseg2e64_v_f64m1x2(b_f64, vector_length);
        vfloat64m1_t a_real_f64m1 = __riscv_vget_v_f64m1x2_f64m1(a_f64m1x2, 0);
        vfloat64m1_t a_imag_f64m1 = __riscv_vget_v_f64m1x2_f64m1(a_f64m1x2, 1);
        vfloat64m1_t b_real_f64m1 = __riscv_vget_v_f64m1x2_f64m1(b_f64m1x2, 0);
        vfloat64m1_t b_imag_f64m1 = __riscv_vget_v_f64m1x2_f64m1(b_f64m1x2, 1);
        // TwoProd+TwoSum: sum_real += a_real * b_real
        {
            vfloat64m1_t product_f64m1 = __riscv_vfmul_vv_f64m1(a_real_f64m1, b_real_f64m1, vector_length);
            // TwoProd: vfmsac computes fma(a, b, -product), i.e. the exact multiply error.
            vfloat64m1_t product_error_f64m1 = __riscv_vfmsac_vv_f64m1(product_f64m1, a_real_f64m1, b_real_f64m1,
                                                                       vector_length);
            vfloat64m1_t tentative_sum_f64m1 = __riscv_vfadd_vv_f64m1(sum_real_f64m1, product_f64m1, vector_length);
            vfloat64m1_t virtual_addend_f64m1 = __riscv_vfsub_vv_f64m1(tentative_sum_f64m1, sum_real_f64m1,
                                                                       vector_length);
            // Knuth TwoSum: (sum - (tentative - virtual)) + (product - virtual) = exact addition error.
            vfloat64m1_t sum_error_f64m1 = __riscv_vfadd_vv_f64m1(
                __riscv_vfsub_vv_f64m1(sum_real_f64m1,
                                       __riscv_vfsub_vv_f64m1(tentative_sum_f64m1, virtual_addend_f64m1, vector_length),
                                       vector_length),
                __riscv_vfsub_vv_f64m1(product_f64m1, virtual_addend_f64m1, vector_length), vector_length);
            // vslideup by 0 with tail-undisturbed = copy the first vector_length lanes, keep the tail.
            sum_real_f64m1 = __riscv_vslideup_vx_f64m1_tu(sum_real_f64m1, tentative_sum_f64m1, 0, vector_length);
            vfloat64m1_t total_error_f64m1 = __riscv_vfadd_vv_f64m1(sum_error_f64m1, product_error_f64m1,
                                                                    vector_length);
            comp_real_f64m1 = __riscv_vfadd_vv_f64m1_tu(comp_real_f64m1, comp_real_f64m1, total_error_f64m1,
                                                        vector_length);
        }
        // TwoProd+TwoSum: sum_real -= a_imag * b_imag
        {
            vfloat64m1_t product_f64m1 = __riscv_vfmul_vv_f64m1(a_imag_f64m1, b_imag_f64m1, vector_length);
            vfloat64m1_t product_error_f64m1 = __riscv_vfmsac_vv_f64m1(product_f64m1, a_imag_f64m1, b_imag_f64m1,
                                                                       vector_length);
            // Subtraction is handled by negating the product AND its error term (both exact).
            vfloat64m1_t neg_product_f64m1 = __riscv_vfneg_v_f64m1(product_f64m1, vector_length);
            vfloat64m1_t neg_product_error_f64m1 = __riscv_vfneg_v_f64m1(product_error_f64m1, vector_length);
            vfloat64m1_t tentative_sum_f64m1 = __riscv_vfadd_vv_f64m1(sum_real_f64m1, neg_product_f64m1, vector_length);
            vfloat64m1_t virtual_addend_f64m1 = __riscv_vfsub_vv_f64m1(tentative_sum_f64m1, sum_real_f64m1,
                                                                       vector_length);
            vfloat64m1_t sum_error_f64m1 = __riscv_vfadd_vv_f64m1(
                __riscv_vfsub_vv_f64m1(sum_real_f64m1,
                                       __riscv_vfsub_vv_f64m1(tentative_sum_f64m1, virtual_addend_f64m1, vector_length),
                                       vector_length),
                __riscv_vfsub_vv_f64m1(neg_product_f64m1, virtual_addend_f64m1, vector_length), vector_length);
            sum_real_f64m1 = __riscv_vslideup_vx_f64m1_tu(sum_real_f64m1, tentative_sum_f64m1, 0, vector_length);
            vfloat64m1_t total_error_f64m1 = __riscv_vfadd_vv_f64m1(sum_error_f64m1, neg_product_error_f64m1,
                                                                    vector_length);
            comp_real_f64m1 = __riscv_vfadd_vv_f64m1_tu(comp_real_f64m1, comp_real_f64m1, total_error_f64m1,
                                                        vector_length);
        }
        // TwoProd+TwoSum: sum_imag += a_real * b_imag
        {
            vfloat64m1_t product_f64m1 = __riscv_vfmul_vv_f64m1(a_real_f64m1, b_imag_f64m1, vector_length);
            vfloat64m1_t product_error_f64m1 = __riscv_vfmsac_vv_f64m1(product_f64m1, a_real_f64m1, b_imag_f64m1,
                                                                       vector_length);
            vfloat64m1_t tentative_sum_f64m1 = __riscv_vfadd_vv_f64m1(sum_imag_f64m1, product_f64m1, vector_length);
            vfloat64m1_t virtual_addend_f64m1 = __riscv_vfsub_vv_f64m1(tentative_sum_f64m1, sum_imag_f64m1,
                                                                       vector_length);
            vfloat64m1_t sum_error_f64m1 = __riscv_vfadd_vv_f64m1(
                __riscv_vfsub_vv_f64m1(sum_imag_f64m1,
                                       __riscv_vfsub_vv_f64m1(tentative_sum_f64m1, virtual_addend_f64m1, vector_length),
                                       vector_length),
                __riscv_vfsub_vv_f64m1(product_f64m1, virtual_addend_f64m1, vector_length), vector_length);
            sum_imag_f64m1 = __riscv_vslideup_vx_f64m1_tu(sum_imag_f64m1, tentative_sum_f64m1, 0, vector_length);
            vfloat64m1_t total_error_f64m1 = __riscv_vfadd_vv_f64m1(sum_error_f64m1, product_error_f64m1,
                                                                    vector_length);
            comp_imag_f64m1 = __riscv_vfadd_vv_f64m1_tu(comp_imag_f64m1, comp_imag_f64m1, total_error_f64m1,
                                                        vector_length);
        }
        // TwoProd+TwoSum: sum_imag += a_imag * b_real
        {
            vfloat64m1_t product_f64m1 = __riscv_vfmul_vv_f64m1(a_imag_f64m1, b_real_f64m1, vector_length);
            vfloat64m1_t product_error_f64m1 = __riscv_vfmsac_vv_f64m1(product_f64m1, a_imag_f64m1, b_real_f64m1,
                                                                       vector_length);
            vfloat64m1_t tentative_sum_f64m1 = __riscv_vfadd_vv_f64m1(sum_imag_f64m1, product_f64m1, vector_length);
            vfloat64m1_t virtual_addend_f64m1 = __riscv_vfsub_vv_f64m1(tentative_sum_f64m1, sum_imag_f64m1,
                                                                       vector_length);
            vfloat64m1_t sum_error_f64m1 = __riscv_vfadd_vv_f64m1(
                __riscv_vfsub_vv_f64m1(sum_imag_f64m1,
                                       __riscv_vfsub_vv_f64m1(tentative_sum_f64m1, virtual_addend_f64m1, vector_length),
                                       vector_length),
                __riscv_vfsub_vv_f64m1(product_f64m1, virtual_addend_f64m1, vector_length), vector_length);
            sum_imag_f64m1 = __riscv_vslideup_vx_f64m1_tu(sum_imag_f64m1, tentative_sum_f64m1, 0, vector_length);
            vfloat64m1_t total_error_f64m1 = __riscv_vfadd_vv_f64m1(sum_error_f64m1, product_error_f64m1,
                                                                    vector_length);
            comp_imag_f64m1 = __riscv_vfadd_vv_f64m1_tu(comp_imag_f64m1, comp_imag_f64m1, total_error_f64m1,
                                                        vector_length);
        }
    }
    // Final reduction folds each compensation register back into its sum.
    results->real = nk_dot_stable_sum_f64m1_rvv_(sum_real_f64m1, comp_real_f64m1);
    results->imag = nk_dot_stable_sum_f64m1_rvv_(sum_imag_f64m1, comp_imag_f64m1);
}
599
+
600
NK_PUBLIC void nk_vdot_f64c_rvv(nk_f64c_t const *a_pairs, nk_f64c_t const *b_pairs, nk_size_t count_pairs,
                                nk_f64c_t *results) {
    // Dot2 (Ogita-Rump-Oishi) compensated conjugate complex dot product
    //
    // Computes Σ conj(a[i])·b[i]. Identical machinery to nk_dot_f64c_rvv — TwoProd
    // (FMA-based exact multiply error) plus Knuth TwoSum (exact addition error),
    // with per-lane compensation registers — but with the conjugate sign pattern:
    // real = ar·br + ai·bi, imag = ar·bi - ai·br.
    nk_f64_t const *a_f64 = (nk_f64_t const *)a_pairs;
    nk_f64_t const *b_f64 = (nk_f64_t const *)b_pairs;
    nk_size_t vlmax = __riscv_vsetvlmax_e64m1();
    vfloat64m1_t sum_real_f64m1 = __riscv_vfmv_v_f_f64m1(0.0, vlmax);
    vfloat64m1_t comp_real_f64m1 = __riscv_vfmv_v_f_f64m1(0.0, vlmax); // rounding-error carry, real part
    vfloat64m1_t sum_imag_f64m1 = __riscv_vfmv_v_f_f64m1(0.0, vlmax);
    vfloat64m1_t comp_imag_f64m1 = __riscv_vfmv_v_f_f64m1(0.0, vlmax); // rounding-error carry, imaginary part
    for (nk_size_t vector_length; count_pairs > 0;
         count_pairs -= vector_length, a_f64 += vector_length * 2, b_f64 += vector_length * 2) {
        vector_length = __riscv_vsetvl_e64m1(count_pairs);
        // Segmented loads deinterleave the (real, imag) pairs into two registers each.
        vfloat64m1x2_t a_f64m1x2 = __riscv_vlseg2e64_v_f64m1x2(a_f64, vector_length);
        vfloat64m1x2_t b_f64m1x2 = __riscv_vlseg2e64_v_f64m1x2(b_f64, vector_length);
        vfloat64m1_t a_real_f64m1 = __riscv_vget_v_f64m1x2_f64m1(a_f64m1x2, 0);
        vfloat64m1_t a_imag_f64m1 = __riscv_vget_v_f64m1x2_f64m1(a_f64m1x2, 1);
        vfloat64m1_t b_real_f64m1 = __riscv_vget_v_f64m1x2_f64m1(b_f64m1x2, 0);
        vfloat64m1_t b_imag_f64m1 = __riscv_vget_v_f64m1x2_f64m1(b_f64m1x2, 1);
        // TwoProd+TwoSum: sum_real += a_real * b_real
        {
            vfloat64m1_t product_f64m1 = __riscv_vfmul_vv_f64m1(a_real_f64m1, b_real_f64m1, vector_length);
            // TwoProd: vfmsac computes fma(a, b, -product), i.e. the exact multiply error.
            vfloat64m1_t product_error_f64m1 = __riscv_vfmsac_vv_f64m1(product_f64m1, a_real_f64m1, b_real_f64m1,
                                                                       vector_length);
            vfloat64m1_t tentative_sum_f64m1 = __riscv_vfadd_vv_f64m1(sum_real_f64m1, product_f64m1, vector_length);
            vfloat64m1_t virtual_addend_f64m1 = __riscv_vfsub_vv_f64m1(tentative_sum_f64m1, sum_real_f64m1,
                                                                       vector_length);
            // Knuth TwoSum: (sum - (tentative - virtual)) + (product - virtual) = exact addition error.
            vfloat64m1_t sum_error_f64m1 = __riscv_vfadd_vv_f64m1(
                __riscv_vfsub_vv_f64m1(sum_real_f64m1,
                                       __riscv_vfsub_vv_f64m1(tentative_sum_f64m1, virtual_addend_f64m1, vector_length),
                                       vector_length),
                __riscv_vfsub_vv_f64m1(product_f64m1, virtual_addend_f64m1, vector_length), vector_length);
            // vslideup by 0 with tail-undisturbed = copy the first vector_length lanes, keep the tail.
            sum_real_f64m1 = __riscv_vslideup_vx_f64m1_tu(sum_real_f64m1, tentative_sum_f64m1, 0, vector_length);
            vfloat64m1_t total_error_f64m1 = __riscv_vfadd_vv_f64m1(sum_error_f64m1, product_error_f64m1,
                                                                    vector_length);
            comp_real_f64m1 = __riscv_vfadd_vv_f64m1_tu(comp_real_f64m1, comp_real_f64m1, total_error_f64m1,
                                                        vector_length);
        }
        // TwoProd+TwoSum: sum_real += a_imag * b_imag (conjugate: + instead of -)
        {
            vfloat64m1_t product_f64m1 = __riscv_vfmul_vv_f64m1(a_imag_f64m1, b_imag_f64m1, vector_length);
            vfloat64m1_t product_error_f64m1 = __riscv_vfmsac_vv_f64m1(product_f64m1, a_imag_f64m1, b_imag_f64m1,
                                                                       vector_length);
            vfloat64m1_t tentative_sum_f64m1 = __riscv_vfadd_vv_f64m1(sum_real_f64m1, product_f64m1, vector_length);
            vfloat64m1_t virtual_addend_f64m1 = __riscv_vfsub_vv_f64m1(tentative_sum_f64m1, sum_real_f64m1,
                                                                       vector_length);
            vfloat64m1_t sum_error_f64m1 = __riscv_vfadd_vv_f64m1(
                __riscv_vfsub_vv_f64m1(sum_real_f64m1,
                                       __riscv_vfsub_vv_f64m1(tentative_sum_f64m1, virtual_addend_f64m1, vector_length),
                                       vector_length),
                __riscv_vfsub_vv_f64m1(product_f64m1, virtual_addend_f64m1, vector_length), vector_length);
            sum_real_f64m1 = __riscv_vslideup_vx_f64m1_tu(sum_real_f64m1, tentative_sum_f64m1, 0, vector_length);
            vfloat64m1_t total_error_f64m1 = __riscv_vfadd_vv_f64m1(sum_error_f64m1, product_error_f64m1,
                                                                    vector_length);
            comp_real_f64m1 = __riscv_vfadd_vv_f64m1_tu(comp_real_f64m1, comp_real_f64m1, total_error_f64m1,
                                                        vector_length);
        }
        // TwoProd+TwoSum: sum_imag += a_real * b_imag
        {
            vfloat64m1_t product_f64m1 = __riscv_vfmul_vv_f64m1(a_real_f64m1, b_imag_f64m1, vector_length);
            vfloat64m1_t product_error_f64m1 = __riscv_vfmsac_vv_f64m1(product_f64m1, a_real_f64m1, b_imag_f64m1,
                                                                       vector_length);
            vfloat64m1_t tentative_sum_f64m1 = __riscv_vfadd_vv_f64m1(sum_imag_f64m1, product_f64m1, vector_length);
            vfloat64m1_t virtual_addend_f64m1 = __riscv_vfsub_vv_f64m1(tentative_sum_f64m1, sum_imag_f64m1,
                                                                       vector_length);
            vfloat64m1_t sum_error_f64m1 = __riscv_vfadd_vv_f64m1(
                __riscv_vfsub_vv_f64m1(sum_imag_f64m1,
                                       __riscv_vfsub_vv_f64m1(tentative_sum_f64m1, virtual_addend_f64m1, vector_length),
                                       vector_length),
                __riscv_vfsub_vv_f64m1(product_f64m1, virtual_addend_f64m1, vector_length), vector_length);
            sum_imag_f64m1 = __riscv_vslideup_vx_f64m1_tu(sum_imag_f64m1, tentative_sum_f64m1, 0, vector_length);
            vfloat64m1_t total_error_f64m1 = __riscv_vfadd_vv_f64m1(sum_error_f64m1, product_error_f64m1,
                                                                    vector_length);
            comp_imag_f64m1 = __riscv_vfadd_vv_f64m1_tu(comp_imag_f64m1, comp_imag_f64m1, total_error_f64m1,
                                                        vector_length);
        }
        // TwoProd+TwoSum: sum_imag -= a_imag * b_real (conjugate: - instead of +)
        {
            vfloat64m1_t product_f64m1 = __riscv_vfmul_vv_f64m1(a_imag_f64m1, b_real_f64m1, vector_length);
            vfloat64m1_t product_error_f64m1 = __riscv_vfmsac_vv_f64m1(product_f64m1, a_imag_f64m1, b_real_f64m1,
                                                                       vector_length);
            // Subtraction is handled by negating the product AND its error term (both exact).
            vfloat64m1_t neg_product_f64m1 = __riscv_vfneg_v_f64m1(product_f64m1, vector_length);
            vfloat64m1_t neg_product_error_f64m1 = __riscv_vfneg_v_f64m1(product_error_f64m1, vector_length);
            vfloat64m1_t tentative_sum_f64m1 = __riscv_vfadd_vv_f64m1(sum_imag_f64m1, neg_product_f64m1, vector_length);
            vfloat64m1_t virtual_addend_f64m1 = __riscv_vfsub_vv_f64m1(tentative_sum_f64m1, sum_imag_f64m1,
                                                                       vector_length);
            vfloat64m1_t sum_error_f64m1 = __riscv_vfadd_vv_f64m1(
                __riscv_vfsub_vv_f64m1(sum_imag_f64m1,
                                       __riscv_vfsub_vv_f64m1(tentative_sum_f64m1, virtual_addend_f64m1, vector_length),
                                       vector_length),
                __riscv_vfsub_vv_f64m1(neg_product_f64m1, virtual_addend_f64m1, vector_length), vector_length);
            sum_imag_f64m1 = __riscv_vslideup_vx_f64m1_tu(sum_imag_f64m1, tentative_sum_f64m1, 0, vector_length);
            vfloat64m1_t total_error_f64m1 = __riscv_vfadd_vv_f64m1(sum_error_f64m1, neg_product_error_f64m1,
                                                                    vector_length);
            comp_imag_f64m1 = __riscv_vfadd_vv_f64m1_tu(comp_imag_f64m1, comp_imag_f64m1, total_error_f64m1,
                                                        vector_length);
        }
    }
    // Final reduction folds each compensation register back into its sum.
    results->real = nk_dot_stable_sum_f64m1_rvv_(sum_real_f64m1, comp_real_f64m1);
    results->imag = nk_dot_stable_sum_f64m1_rvv_(sum_imag_f64m1, comp_imag_f64m1);
}
701
+
702
+ #if defined(__cplusplus)
703
+ } // extern "C"
704
+ #endif
705
+
706
+ #if defined(__clang__)
707
+ #pragma clang attribute pop
708
+ #elif defined(__GNUC__)
709
+ #pragma GCC pop_options
710
+ #endif
711
+
712
+ #endif // NK_TARGET_RVV
713
+ #endif // NK_TARGET_RISCV_
714
+ #endif // NK_DOT_RVV_H