numkong 7.0.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (294)
  1. package/LICENSE +201 -0
  2. package/README.md +495 -0
  3. package/binding.gyp +540 -0
  4. package/c/dispatch.h +512 -0
  5. package/c/dispatch_bf16.c +389 -0
  6. package/c/dispatch_bf16c.c +52 -0
  7. package/c/dispatch_e2m3.c +263 -0
  8. package/c/dispatch_e3m2.c +243 -0
  9. package/c/dispatch_e4m3.c +276 -0
  10. package/c/dispatch_e5m2.c +272 -0
  11. package/c/dispatch_f16.c +376 -0
  12. package/c/dispatch_f16c.c +58 -0
  13. package/c/dispatch_f32.c +378 -0
  14. package/c/dispatch_f32c.c +99 -0
  15. package/c/dispatch_f64.c +296 -0
  16. package/c/dispatch_f64c.c +98 -0
  17. package/c/dispatch_i16.c +96 -0
  18. package/c/dispatch_i32.c +89 -0
  19. package/c/dispatch_i4.c +150 -0
  20. package/c/dispatch_i64.c +86 -0
  21. package/c/dispatch_i8.c +289 -0
  22. package/c/dispatch_other.c +330 -0
  23. package/c/dispatch_u1.c +148 -0
  24. package/c/dispatch_u16.c +124 -0
  25. package/c/dispatch_u32.c +118 -0
  26. package/c/dispatch_u4.c +150 -0
  27. package/c/dispatch_u64.c +102 -0
  28. package/c/dispatch_u8.c +303 -0
  29. package/c/numkong.c +950 -0
  30. package/include/README.md +573 -0
  31. package/include/module.modulemap +129 -0
  32. package/include/numkong/attention/sapphireamx.h +1361 -0
  33. package/include/numkong/attention/sme.h +2066 -0
  34. package/include/numkong/attention.h +49 -0
  35. package/include/numkong/capabilities.h +748 -0
  36. package/include/numkong/cast/README.md +262 -0
  37. package/include/numkong/cast/haswell.h +975 -0
  38. package/include/numkong/cast/icelake.h +470 -0
  39. package/include/numkong/cast/neon.h +1192 -0
  40. package/include/numkong/cast/rvv.h +1021 -0
  41. package/include/numkong/cast/sapphire.h +262 -0
  42. package/include/numkong/cast/serial.h +2262 -0
  43. package/include/numkong/cast/skylake.h +856 -0
  44. package/include/numkong/cast/v128relaxed.h +180 -0
  45. package/include/numkong/cast.h +230 -0
  46. package/include/numkong/curved/README.md +223 -0
  47. package/include/numkong/curved/genoa.h +182 -0
  48. package/include/numkong/curved/haswell.h +276 -0
  49. package/include/numkong/curved/neon.h +205 -0
  50. package/include/numkong/curved/neonbfdot.h +212 -0
  51. package/include/numkong/curved/neonhalf.h +212 -0
  52. package/include/numkong/curved/rvv.h +305 -0
  53. package/include/numkong/curved/serial.h +207 -0
  54. package/include/numkong/curved/skylake.h +457 -0
  55. package/include/numkong/curved/smef64.h +506 -0
  56. package/include/numkong/curved.h +517 -0
  57. package/include/numkong/curved.hpp +144 -0
  58. package/include/numkong/dot/README.md +425 -0
  59. package/include/numkong/dot/alder.h +563 -0
  60. package/include/numkong/dot/genoa.h +315 -0
  61. package/include/numkong/dot/haswell.h +1688 -0
  62. package/include/numkong/dot/icelake.h +883 -0
  63. package/include/numkong/dot/neon.h +818 -0
  64. package/include/numkong/dot/neonbfdot.h +244 -0
  65. package/include/numkong/dot/neonfhm.h +360 -0
  66. package/include/numkong/dot/neonhalf.h +198 -0
  67. package/include/numkong/dot/neonsdot.h +508 -0
  68. package/include/numkong/dot/rvv.h +714 -0
  69. package/include/numkong/dot/rvvbb.h +72 -0
  70. package/include/numkong/dot/rvvbf16.h +123 -0
  71. package/include/numkong/dot/rvvhalf.h +129 -0
  72. package/include/numkong/dot/sapphire.h +141 -0
  73. package/include/numkong/dot/serial.h +838 -0
  74. package/include/numkong/dot/sierra.h +405 -0
  75. package/include/numkong/dot/skylake.h +1084 -0
  76. package/include/numkong/dot/sve.h +379 -0
  77. package/include/numkong/dot/svebfdot.h +74 -0
  78. package/include/numkong/dot/svehalf.h +123 -0
  79. package/include/numkong/dot/v128relaxed.h +1258 -0
  80. package/include/numkong/dot.h +1070 -0
  81. package/include/numkong/dot.hpp +94 -0
  82. package/include/numkong/dots/README.md +496 -0
  83. package/include/numkong/dots/alder.h +114 -0
  84. package/include/numkong/dots/genoa.h +94 -0
  85. package/include/numkong/dots/haswell.h +295 -0
  86. package/include/numkong/dots/icelake.h +171 -0
  87. package/include/numkong/dots/neon.h +120 -0
  88. package/include/numkong/dots/neonbfdot.h +58 -0
  89. package/include/numkong/dots/neonfhm.h +94 -0
  90. package/include/numkong/dots/neonhalf.h +57 -0
  91. package/include/numkong/dots/neonsdot.h +108 -0
  92. package/include/numkong/dots/rvv.h +2486 -0
  93. package/include/numkong/dots/sapphireamx.h +3973 -0
  94. package/include/numkong/dots/serial.h +2844 -0
  95. package/include/numkong/dots/sierra.h +97 -0
  96. package/include/numkong/dots/skylake.h +196 -0
  97. package/include/numkong/dots/sme.h +5372 -0
  98. package/include/numkong/dots/smebi32.h +461 -0
  99. package/include/numkong/dots/smef64.h +1318 -0
  100. package/include/numkong/dots/smehalf.h +47 -0
  101. package/include/numkong/dots/v128relaxed.h +294 -0
  102. package/include/numkong/dots.h +2804 -0
  103. package/include/numkong/dots.hpp +639 -0
  104. package/include/numkong/each/README.md +469 -0
  105. package/include/numkong/each/haswell.h +1658 -0
  106. package/include/numkong/each/icelake.h +272 -0
  107. package/include/numkong/each/neon.h +1104 -0
  108. package/include/numkong/each/neonbfdot.h +212 -0
  109. package/include/numkong/each/neonhalf.h +410 -0
  110. package/include/numkong/each/rvv.h +1121 -0
  111. package/include/numkong/each/sapphire.h +477 -0
  112. package/include/numkong/each/serial.h +260 -0
  113. package/include/numkong/each/skylake.h +1562 -0
  114. package/include/numkong/each.h +2146 -0
  115. package/include/numkong/each.hpp +434 -0
  116. package/include/numkong/geospatial/README.md +147 -0
  117. package/include/numkong/geospatial/haswell.h +593 -0
  118. package/include/numkong/geospatial/neon.h +571 -0
  119. package/include/numkong/geospatial/rvv.h +701 -0
  120. package/include/numkong/geospatial/serial.h +309 -0
  121. package/include/numkong/geospatial/skylake.h +577 -0
  122. package/include/numkong/geospatial/v128relaxed.h +613 -0
  123. package/include/numkong/geospatial.h +453 -0
  124. package/include/numkong/geospatial.hpp +235 -0
  125. package/include/numkong/matrix.hpp +336 -0
  126. package/include/numkong/maxsim/README.md +187 -0
  127. package/include/numkong/maxsim/alder.h +511 -0
  128. package/include/numkong/maxsim/genoa.h +115 -0
  129. package/include/numkong/maxsim/haswell.h +553 -0
  130. package/include/numkong/maxsim/icelake.h +480 -0
  131. package/include/numkong/maxsim/neonsdot.h +394 -0
  132. package/include/numkong/maxsim/sapphireamx.h +877 -0
  133. package/include/numkong/maxsim/serial.h +490 -0
  134. package/include/numkong/maxsim/sme.h +929 -0
  135. package/include/numkong/maxsim/v128relaxed.h +280 -0
  136. package/include/numkong/maxsim.h +571 -0
  137. package/include/numkong/maxsim.hpp +133 -0
  138. package/include/numkong/mesh/README.md +227 -0
  139. package/include/numkong/mesh/haswell.h +2235 -0
  140. package/include/numkong/mesh/neon.h +1329 -0
  141. package/include/numkong/mesh/neonbfdot.h +842 -0
  142. package/include/numkong/mesh/neonhalf.h +616 -0
  143. package/include/numkong/mesh/rvv.h +916 -0
  144. package/include/numkong/mesh/serial.h +742 -0
  145. package/include/numkong/mesh/skylake.h +1135 -0
  146. package/include/numkong/mesh/v128relaxed.h +1052 -0
  147. package/include/numkong/mesh.h +652 -0
  148. package/include/numkong/mesh.hpp +762 -0
  149. package/include/numkong/numkong.h +78 -0
  150. package/include/numkong/numkong.hpp +57 -0
  151. package/include/numkong/probability/README.md +173 -0
  152. package/include/numkong/probability/haswell.h +267 -0
  153. package/include/numkong/probability/neon.h +225 -0
  154. package/include/numkong/probability/rvv.h +409 -0
  155. package/include/numkong/probability/serial.h +169 -0
  156. package/include/numkong/probability/skylake.h +324 -0
  157. package/include/numkong/probability.h +383 -0
  158. package/include/numkong/probability.hpp +120 -0
  159. package/include/numkong/random.h +50 -0
  160. package/include/numkong/random.hpp +285 -0
  161. package/include/numkong/reduce/README.md +547 -0
  162. package/include/numkong/reduce/alder.h +632 -0
  163. package/include/numkong/reduce/genoa.h +201 -0
  164. package/include/numkong/reduce/haswell.h +3783 -0
  165. package/include/numkong/reduce/icelake.h +549 -0
  166. package/include/numkong/reduce/neon.h +3841 -0
  167. package/include/numkong/reduce/neonbfdot.h +353 -0
  168. package/include/numkong/reduce/neonfhm.h +665 -0
  169. package/include/numkong/reduce/neonhalf.h +157 -0
  170. package/include/numkong/reduce/neonsdot.h +357 -0
  171. package/include/numkong/reduce/rvv.h +3407 -0
  172. package/include/numkong/reduce/serial.h +757 -0
  173. package/include/numkong/reduce/sierra.h +338 -0
  174. package/include/numkong/reduce/skylake.h +3792 -0
  175. package/include/numkong/reduce/v128relaxed.h +2302 -0
  176. package/include/numkong/reduce.h +1597 -0
  177. package/include/numkong/reduce.hpp +633 -0
  178. package/include/numkong/scalar/README.md +89 -0
  179. package/include/numkong/scalar/haswell.h +113 -0
  180. package/include/numkong/scalar/neon.h +122 -0
  181. package/include/numkong/scalar/neonhalf.h +70 -0
  182. package/include/numkong/scalar/rvv.h +211 -0
  183. package/include/numkong/scalar/sapphire.h +63 -0
  184. package/include/numkong/scalar/serial.h +332 -0
  185. package/include/numkong/scalar/v128relaxed.h +56 -0
  186. package/include/numkong/scalar.h +683 -0
  187. package/include/numkong/set/README.md +179 -0
  188. package/include/numkong/set/haswell.h +334 -0
  189. package/include/numkong/set/icelake.h +485 -0
  190. package/include/numkong/set/neon.h +364 -0
  191. package/include/numkong/set/rvv.h +226 -0
  192. package/include/numkong/set/rvvbb.h +117 -0
  193. package/include/numkong/set/serial.h +174 -0
  194. package/include/numkong/set/sve.h +185 -0
  195. package/include/numkong/set/v128relaxed.h +240 -0
  196. package/include/numkong/set.h +457 -0
  197. package/include/numkong/set.hpp +114 -0
  198. package/include/numkong/sets/README.md +149 -0
  199. package/include/numkong/sets/haswell.h +63 -0
  200. package/include/numkong/sets/icelake.h +66 -0
  201. package/include/numkong/sets/neon.h +61 -0
  202. package/include/numkong/sets/serial.h +43 -0
  203. package/include/numkong/sets/smebi32.h +1099 -0
  204. package/include/numkong/sets/v128relaxed.h +58 -0
  205. package/include/numkong/sets.h +339 -0
  206. package/include/numkong/sparse/README.md +156 -0
  207. package/include/numkong/sparse/icelake.h +463 -0
  208. package/include/numkong/sparse/neon.h +288 -0
  209. package/include/numkong/sparse/serial.h +117 -0
  210. package/include/numkong/sparse/sve2.h +507 -0
  211. package/include/numkong/sparse/turin.h +322 -0
  212. package/include/numkong/sparse.h +363 -0
  213. package/include/numkong/sparse.hpp +113 -0
  214. package/include/numkong/spatial/README.md +435 -0
  215. package/include/numkong/spatial/alder.h +607 -0
  216. package/include/numkong/spatial/genoa.h +290 -0
  217. package/include/numkong/spatial/haswell.h +960 -0
  218. package/include/numkong/spatial/icelake.h +586 -0
  219. package/include/numkong/spatial/neon.h +773 -0
  220. package/include/numkong/spatial/neonbfdot.h +165 -0
  221. package/include/numkong/spatial/neonhalf.h +118 -0
  222. package/include/numkong/spatial/neonsdot.h +261 -0
  223. package/include/numkong/spatial/rvv.h +984 -0
  224. package/include/numkong/spatial/rvvbf16.h +123 -0
  225. package/include/numkong/spatial/rvvhalf.h +117 -0
  226. package/include/numkong/spatial/sapphire.h +343 -0
  227. package/include/numkong/spatial/serial.h +346 -0
  228. package/include/numkong/spatial/sierra.h +323 -0
  229. package/include/numkong/spatial/skylake.h +606 -0
  230. package/include/numkong/spatial/sve.h +224 -0
  231. package/include/numkong/spatial/svebfdot.h +122 -0
  232. package/include/numkong/spatial/svehalf.h +109 -0
  233. package/include/numkong/spatial/v128relaxed.h +717 -0
  234. package/include/numkong/spatial.h +1425 -0
  235. package/include/numkong/spatial.hpp +183 -0
  236. package/include/numkong/spatials/README.md +580 -0
  237. package/include/numkong/spatials/alder.h +94 -0
  238. package/include/numkong/spatials/genoa.h +94 -0
  239. package/include/numkong/spatials/haswell.h +219 -0
  240. package/include/numkong/spatials/icelake.h +113 -0
  241. package/include/numkong/spatials/neon.h +109 -0
  242. package/include/numkong/spatials/neonbfdot.h +60 -0
  243. package/include/numkong/spatials/neonfhm.h +92 -0
  244. package/include/numkong/spatials/neonhalf.h +58 -0
  245. package/include/numkong/spatials/neonsdot.h +109 -0
  246. package/include/numkong/spatials/rvv.h +1960 -0
  247. package/include/numkong/spatials/sapphireamx.h +1149 -0
  248. package/include/numkong/spatials/serial.h +226 -0
  249. package/include/numkong/spatials/sierra.h +96 -0
  250. package/include/numkong/spatials/skylake.h +184 -0
  251. package/include/numkong/spatials/sme.h +1901 -0
  252. package/include/numkong/spatials/smef64.h +465 -0
  253. package/include/numkong/spatials/v128relaxed.h +240 -0
  254. package/include/numkong/spatials.h +3021 -0
  255. package/include/numkong/spatials.hpp +508 -0
  256. package/include/numkong/tensor.hpp +1592 -0
  257. package/include/numkong/trigonometry/README.md +184 -0
  258. package/include/numkong/trigonometry/haswell.h +652 -0
  259. package/include/numkong/trigonometry/neon.h +639 -0
  260. package/include/numkong/trigonometry/rvv.h +699 -0
  261. package/include/numkong/trigonometry/serial.h +703 -0
  262. package/include/numkong/trigonometry/skylake.h +721 -0
  263. package/include/numkong/trigonometry/v128relaxed.h +666 -0
  264. package/include/numkong/trigonometry.h +467 -0
  265. package/include/numkong/trigonometry.hpp +166 -0
  266. package/include/numkong/types.h +1384 -0
  267. package/include/numkong/types.hpp +5603 -0
  268. package/include/numkong/vector.hpp +698 -0
  269. package/javascript/README.md +246 -0
  270. package/javascript/dist/cjs/numkong-wasm.d.ts +166 -0
  271. package/javascript/dist/cjs/numkong-wasm.js +617 -0
  272. package/javascript/dist/cjs/numkong.d.ts +343 -0
  273. package/javascript/dist/cjs/numkong.js +523 -0
  274. package/javascript/dist/cjs/package.json +3 -0
  275. package/javascript/dist/cjs/types.d.ts +284 -0
  276. package/javascript/dist/cjs/types.js +653 -0
  277. package/javascript/dist/esm/numkong-wasm.d.ts +166 -0
  278. package/javascript/dist/esm/numkong-wasm.js +595 -0
  279. package/javascript/dist/esm/numkong.d.ts +343 -0
  280. package/javascript/dist/esm/numkong.js +452 -0
  281. package/javascript/dist/esm/package.json +3 -0
  282. package/javascript/dist/esm/types.d.ts +284 -0
  283. package/javascript/dist/esm/types.js +630 -0
  284. package/javascript/dist-package-cjs.json +3 -0
  285. package/javascript/dist-package-esm.json +3 -0
  286. package/javascript/node-gyp-build.d.ts +1 -0
  287. package/javascript/numkong-wasm.ts +756 -0
  288. package/javascript/numkong.c +689 -0
  289. package/javascript/numkong.ts +575 -0
  290. package/javascript/tsconfig-base.json +39 -0
  291. package/javascript/tsconfig-cjs.json +8 -0
  292. package/javascript/tsconfig-esm.json +8 -0
  293. package/javascript/types.ts +674 -0
  294. package/package.json +87 -0
package/include/numkong/reduce/rvv.h (new file)
@@ -0,0 +1,3407 @@
/**
 * @brief SIMD-accelerated Reductions for RISC-V.
 * @file include/numkong/reduce/rvv.h
 * @author Ash Vardanian
 * @date February 13, 2026
 *
 * @sa include/numkong/reduce.h
 */
#ifndef NK_REDUCE_RVV_H
#define NK_REDUCE_RVV_H

#if NK_TARGET_RISCV_
#if NK_TARGET_RVV

#include "numkong/types.h"
#include "numkong/cast/rvv.h"
#include "numkong/reduce/serial.h"

#if defined(__clang__)
#pragma clang attribute push(__attribute__((target("arch=+v"))), apply_to = function)
#elif defined(__GNUC__)
#pragma GCC push_options
#pragma GCC target("arch=+v")
#endif

#if defined(__cplusplus)
extern "C" {
#endif

/** @brief Saturating horizontal sum of u64m1 via tree fold: O(log vlmax) vector ops. */
NK_INTERNAL nk_u64_t nk_reduce_vsaddu_u64m1_rvv_(vuint64m1_t acc_u64m1, nk_size_t vlmax) {
    for (nk_size_t half = vlmax >> 1; half > 0; half >>= 1) {
        vuint64m1_t shifted_u64m1 = __riscv_vslidedown_vx_u64m1(acc_u64m1, half, vlmax);
        acc_u64m1 = __riscv_vsaddu_vv_u64m1(acc_u64m1, shifted_u64m1, vlmax);
    }
    return __riscv_vmv_x_s_u64m1_u64(acc_u64m1);
}

/** @brief Saturating horizontal sum of u64m2 via tree fold: O(log vlmax) vector ops. */
NK_INTERNAL nk_u64_t nk_reduce_vsaddu_u64m2_rvv_(vuint64m2_t acc_u64m2, nk_size_t vlmax) {
    for (nk_size_t half = vlmax >> 1; half > 0; half >>= 1) {
        vuint64m2_t shifted_u64m2 = __riscv_vslidedown_vx_u64m2(acc_u64m2, half, vlmax);
        acc_u64m2 = __riscv_vsaddu_vv_u64m2(acc_u64m2, shifted_u64m2, vlmax);
    }
    return __riscv_vmv_x_s_u64m2_u64(acc_u64m2);
}

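/* [Editorial sketch] The tree fold above halves the active region each step: lanes
 * [half, 2*half) are slid down onto lanes [0, half) and added with saturation, so after
 * log2(vlmax) steps lane 0 holds the saturated total. A scalar equivalent over a plain
 * array, assuming a power-of-two lane count (illustration only, not part of the numkong
 * sources):
 *
 *     uint64_t tree_fold_sat(uint64_t *lanes, size_t vlmax) {
 *         for (size_t half = vlmax >> 1; half > 0; half >>= 1)
 *             for (size_t i = 0; i < half; ++i) {
 *                 uint64_t sum = lanes[i] + lanes[i + half];
 *                 lanes[i] = sum < lanes[i] ? UINT64_MAX : sum; // saturate on wrap-around
 *             }
 *         return lanes[0];
 *     }
 */
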
/** @brief 128-bit horizontal sum of (upper:i64m1, lower:u64m1) via tree fold, then saturate to i64. */
NK_INTERNAL nk_i64_t nk_reduce_128bit_sum_i64m1_rvv_( //
    vuint64m1_t sum_lower_u64m1, vint64m1_t sum_upper_i64m1, nk_size_t vlmax) {
    for (nk_size_t half = vlmax >> 1; half > 0; half >>= 1) {
        vuint64m1_t shifted_lower_u64m1 = __riscv_vslidedown_vx_u64m1(sum_lower_u64m1, half, vlmax);
        vint64m1_t shifted_upper_i64m1 = __riscv_vslidedown_vx_i64m1(sum_upper_i64m1, half, vlmax);
        vuint64m1_t new_lower_u64m1 = __riscv_vadd_vv_u64m1(sum_lower_u64m1, shifted_lower_u64m1, vlmax);
        vbool64_t carry_b64 = __riscv_vmsltu_vv_u64m1_b64(new_lower_u64m1, sum_lower_u64m1, vlmax);
        vint64m1_t carry_i64m1 = __riscv_vmerge_vxm_i64m1(__riscv_vmv_v_x_i64m1(0, vlmax), 1, carry_b64, vlmax);
        sum_upper_i64m1 = __riscv_vadd_vv_i64m1(sum_upper_i64m1, shifted_upper_i64m1, vlmax);
        sum_upper_i64m1 = __riscv_vadd_vv_i64m1(sum_upper_i64m1, carry_i64m1, vlmax);
        sum_lower_u64m1 = new_lower_u64m1;
    }
    nk_u64_t total_lower = __riscv_vmv_x_s_u64m1_u64(sum_lower_u64m1);
    nk_i64_t total_upper = __riscv_vmv_x_s_i64m1_i64(sum_upper_i64m1);
    nk_i64_t total_lower_signed = (nk_i64_t)total_lower;
    if (total_upper == (total_lower_signed >> 63)) return total_lower_signed;
    else if (total_upper >= 0) return NK_I64_MAX;
    else return NK_I64_MIN;
}

/** @brief 128-bit horizontal sum of (upper:i64m2, lower:u64m2) via tree fold, then saturate to i64. */
NK_INTERNAL nk_i64_t nk_reduce_128bit_sum_i64m2_rvv_( //
    vuint64m2_t sum_lower_u64m2, vint64m2_t sum_upper_i64m2, nk_size_t vlmax) {
    for (nk_size_t half = vlmax >> 1; half > 0; half >>= 1) {
        vuint64m2_t shifted_lower_u64m2 = __riscv_vslidedown_vx_u64m2(sum_lower_u64m2, half, vlmax);
        vint64m2_t shifted_upper_i64m2 = __riscv_vslidedown_vx_i64m2(sum_upper_i64m2, half, vlmax);
        vuint64m2_t new_lower_u64m2 = __riscv_vadd_vv_u64m2(sum_lower_u64m2, shifted_lower_u64m2, vlmax);
        vbool32_t carry_b32 = __riscv_vmsltu_vv_u64m2_b32(new_lower_u64m2, sum_lower_u64m2, vlmax);
        vint64m2_t carry_i64m2 = __riscv_vmerge_vxm_i64m2(__riscv_vmv_v_x_i64m2(0, vlmax), 1, carry_b32, vlmax);
        sum_upper_i64m2 = __riscv_vadd_vv_i64m2(sum_upper_i64m2, shifted_upper_i64m2, vlmax);
        sum_upper_i64m2 = __riscv_vadd_vv_i64m2(sum_upper_i64m2, carry_i64m2, vlmax);
        sum_lower_u64m2 = new_lower_u64m2;
    }
    nk_u64_t total_lower = __riscv_vmv_x_s_u64m2_u64(sum_lower_u64m2);
    nk_i64_t total_upper = __riscv_vmv_x_s_i64m2_i64(sum_upper_i64m2);
    nk_i64_t total_lower_signed = (nk_i64_t)total_lower;
    if (total_upper == (total_lower_signed >> 63)) return total_lower_signed;
    else if (total_upper >= 0) return NK_I64_MAX;
    else return NK_I64_MIN;
}

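/* [Editorial sketch] Each lane above carries a 128-bit partial sum as an (upper, lower)
 * pair; the carry out of the lower 64 bits is detected with the unsigned wrap-around test
 * new_lower < old_lower, and the folded (upper, lower) result is finally clamped to the
 * i64 range. A scalar equivalent of one lane-wise add (illustration only, not part of the
 * numkong sources):
 *
 *     void add128(uint64_t *lower, int64_t *upper, uint64_t add_lower, int64_t add_upper) {
 *         uint64_t new_lower = *lower + add_lower;
 *         *upper += add_upper + (new_lower < *lower); // wrap-around implies a carry
 *         *lower = new_lower;
 *     }
 */
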
NK_INTERNAL void nk_reduce_moments_f32_rvv_contiguous_( //
    nk_f32_t const *data, nk_size_t count,              //
    nk_f64_t *sum_ptr, nk_f64_t *sumsq_ptr) {
    nk_size_t vlmax = __riscv_vsetvlmax_e64m2();
    vfloat64m2_t sum_f64m2 = __riscv_vfmv_v_f_f64m2(0.0, vlmax);
    vfloat64m2_t sumsq_f64m2 = __riscv_vfmv_v_f_f64m2(0.0, vlmax);
    for (nk_size_t vector_length; count > 0; count -= vector_length, data += vector_length) {
        vector_length = __riscv_vsetvl_e32m1(count);
        vfloat32m1_t data_f32m1 = __riscv_vle32_v_f32m1(data, vector_length);
        sum_f64m2 = __riscv_vfwadd_wv_f64m2_tu(sum_f64m2, sum_f64m2, data_f32m1, vector_length);
        sumsq_f64m2 = __riscv_vfwmacc_vv_f64m2_tu(sumsq_f64m2, data_f32m1, data_f32m1, vector_length);
    }
    vfloat64m1_t zero = __riscv_vfmv_v_f_f64m1(0.0, 1);
    *sum_ptr = __riscv_vfmv_f_s_f64m1_f64(__riscv_vfredusum_vs_f64m2_f64m1(sum_f64m2, zero, vlmax)),
    *sumsq_ptr = __riscv_vfmv_f_s_f64m1_f64(__riscv_vfredusum_vs_f64m2_f64m1(sumsq_f64m2, zero, vlmax));
}

NK_INTERNAL void nk_reduce_moments_f32_rvv_strided_(               //
    nk_f32_t const *data, nk_size_t count, nk_size_t stride_bytes, //
    nk_f64_t *sum_ptr, nk_f64_t *sumsq_ptr) {
    nk_size_t vlmax = __riscv_vsetvlmax_e64m2();
    vfloat64m2_t sum_f64m2 = __riscv_vfmv_v_f_f64m2(0.0, vlmax);
    vfloat64m2_t sumsq_f64m2 = __riscv_vfmv_v_f_f64m2(0.0, vlmax);
    unsigned char const *ptr = (unsigned char const *)data;
    for (nk_size_t vector_length; count > 0; count -= vector_length, ptr += vector_length * stride_bytes) {
        vector_length = __riscv_vsetvl_e32m1(count);
        vfloat32m1_t data_f32m1 = __riscv_vlse32_v_f32m1((nk_f32_t const *)ptr, (nk_ssize_t)stride_bytes,
                                                         vector_length);
        sum_f64m2 = __riscv_vfwadd_wv_f64m2_tu(sum_f64m2, sum_f64m2, data_f32m1, vector_length);
        sumsq_f64m2 = __riscv_vfwmacc_vv_f64m2_tu(sumsq_f64m2, data_f32m1, data_f32m1, vector_length);
    }
    vfloat64m1_t zero = __riscv_vfmv_v_f_f64m1(0.0, 1);
    *sum_ptr = __riscv_vfmv_f_s_f64m1_f64(__riscv_vfredusum_vs_f64m2_f64m1(sum_f64m2, zero, vlmax)),
    *sumsq_ptr = __riscv_vfmv_f_s_f64m1_f64(__riscv_vfredusum_vs_f64m2_f64m1(sumsq_f64m2, zero, vlmax));
}

NK_PUBLIC void nk_reduce_moments_f32_rvv(                          //
    nk_f32_t const *data, nk_size_t count, nk_size_t stride_bytes, //
    nk_f64_t *sum, nk_f64_t *sumsq) {
    nk_size_t stride_elements = stride_bytes / sizeof(nk_f32_t);
    int aligned = (stride_bytes % sizeof(nk_f32_t) == 0);
    if (count == 0) *sum = 0, *sumsq = 0;
    else if (!aligned) nk_reduce_moments_f32_serial(data, count, stride_bytes, sum, sumsq);
    else if (stride_elements == 1) nk_reduce_moments_f32_rvv_contiguous_(data, count, sum, sumsq);
    else nk_reduce_moments_f32_rvv_strided_(data, count, stride_bytes, sum, sumsq);
}

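/* [Editorial note] The nk_reduce_moments_* entry points return the raw first and second
 * moments (sum and sum of squares); mean and variance follow directly from them. A minimal
 * usage sketch for a contiguous f32 buffer (illustration only, not part of the numkong
 * sources):
 *
 *     nk_f64_t sum, sumsq;
 *     nk_reduce_moments_f32_rvv(values, count, sizeof(nk_f32_t), &sum, &sumsq);
 *     nk_f64_t mean = sum / (nk_f64_t)count;
 *     nk_f64_t variance = sumsq / (nk_f64_t)count - mean * mean;
 */
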
NK_INTERNAL void nk_reduce_minmax_f32_rvv_contiguous_( //
    nk_f32_t const *data, nk_size_t count,             //
    nk_f32_t *min_value, nk_size_t *min_index,         //
    nk_f32_t *max_value, nk_size_t *max_index) {
    nk_size_t vlmax = __riscv_vsetvlmax_e32m1();
    vfloat32m1_t min = __riscv_vfmv_v_f_f32m1(NK_F32_MAX, vlmax);
    vfloat32m1_t max = __riscv_vfmv_v_f_f32m1(NK_F32_MIN, vlmax);
    vuint64m2_t min_indices = __riscv_vmv_v_x_u64m2(0, vlmax);
    vuint64m2_t max_indices = __riscv_vmv_v_x_u64m2(0, vlmax);
    nk_size_t offset = 0;
    for (nk_size_t remaining = count, vector_length; remaining > 0;
         remaining -= vector_length, offset += vector_length) {
        vector_length = __riscv_vsetvl_e32m1(remaining);
        vfloat32m1_t data_f32m1 = __riscv_vle32_v_f32m1(data + offset, vector_length);
        vuint64m2_t position_u64m2 = __riscv_vadd_vx_u64m2(__riscv_vid_v_u64m2(vector_length), (nk_u64_t)offset,
                                                           vector_length);
        vbool32_t less_b32 = __riscv_vmflt_vv_f32m1_b32(data_f32m1, min, vector_length);
        min = __riscv_vmerge_vvm_f32m1_tu(min, min, data_f32m1, less_b32, vector_length);
        min_indices = __riscv_vmerge_vvm_u64m2_tu(min_indices, min_indices, position_u64m2, less_b32, vector_length);
        vbool32_t greater_b32 = __riscv_vmflt_vv_f32m1_b32(max, data_f32m1, vector_length);
        max = __riscv_vmerge_vvm_f32m1_tu(max, max, data_f32m1, greater_b32, vector_length);
        max_indices = __riscv_vmerge_vvm_u64m2_tu(max_indices, max_indices, position_u64m2, greater_b32, vector_length);
    }
    vfloat32m1_t id_max = __riscv_vfmv_v_f_f32m1(NK_F32_MAX, 1);
    nk_f32_t mn = __riscv_vfmv_f_s_f32m1_f32(__riscv_vfredmin_vs_f32m1_f32m1(min, id_max, vlmax));
    vfloat32m1_t id_min = __riscv_vfmv_v_f_f32m1(NK_F32_MIN, 1);
    nk_f32_t mx = __riscv_vfmv_f_s_f32m1_f32(__riscv_vfredmax_vs_f32m1_f32m1(max, id_min, vlmax));
    if (mn == NK_F32_MAX && mx == NK_F32_MIN) {
        *min_value = NK_F32_MAX, *min_index = NK_SIZE_MAX, *max_value = NK_F32_MIN, *max_index = NK_SIZE_MAX;
        return;
    }
    vbool32_t min_match_b32 = __riscv_vmfeq_vf_f32m1_b32(min, mn, vlmax);
    vuint64m2_t sentinel = __riscv_vmv_v_x_u64m2(NK_U64_MAX, vlmax);
    vuint64m2_t min_cands = __riscv_vmerge_vvm_u64m2(sentinel, min_indices, min_match_b32, vlmax);
    vuint64m1_t id_umax = __riscv_vmv_v_x_u64m1(NK_U64_MAX, 1);
    *min_value = mn,
    *min_index = (nk_size_t)__riscv_vmv_x_s_u64m1_u64(__riscv_vredminu_vs_u64m2_u64m1(min_cands, id_umax, vlmax));
    vbool32_t max_match_b32 = __riscv_vmfeq_vf_f32m1_b32(max, mx, vlmax);
    vuint64m2_t max_cands = __riscv_vmerge_vvm_u64m2(sentinel, max_indices, max_match_b32, vlmax);
    *max_value = mx,
    *max_index = (nk_size_t)__riscv_vmv_x_s_u64m1_u64(__riscv_vredminu_vs_u64m2_u64m1(max_cands, id_umax, vlmax));
}

NK_INTERNAL void nk_reduce_minmax_f32_rvv_strided_(                 //
    nk_f32_t const *data, nk_size_t count, nk_size_t stride_bytes, //
    nk_f32_t *min_value, nk_size_t *min_index,                     //
    nk_f32_t *max_value, nk_size_t *max_index) {
    nk_size_t vlmax = __riscv_vsetvlmax_e32m1();
    vfloat32m1_t min = __riscv_vfmv_v_f_f32m1(NK_F32_MAX, vlmax);
    vfloat32m1_t max = __riscv_vfmv_v_f_f32m1(NK_F32_MIN, vlmax);
    vuint64m2_t min_indices = __riscv_vmv_v_x_u64m2(0, vlmax);
    vuint64m2_t max_indices = __riscv_vmv_v_x_u64m2(0, vlmax);
    unsigned char const *ptr = (unsigned char const *)data;
    nk_size_t offset = 0;
    for (nk_size_t remaining = count, vector_length; remaining > 0;
         remaining -= vector_length, offset += vector_length, ptr += vector_length * stride_bytes) {
        vector_length = __riscv_vsetvl_e32m1(remaining);
        vfloat32m1_t data_f32m1 = __riscv_vlse32_v_f32m1((nk_f32_t const *)ptr, (nk_ssize_t)stride_bytes,
                                                         vector_length);
        vuint64m2_t position_u64m2 = __riscv_vadd_vx_u64m2(__riscv_vid_v_u64m2(vector_length), (nk_u64_t)offset,
                                                           vector_length);
        vbool32_t less_b32 = __riscv_vmflt_vv_f32m1_b32(data_f32m1, min, vector_length);
        min = __riscv_vmerge_vvm_f32m1_tu(min, min, data_f32m1, less_b32, vector_length);
        min_indices = __riscv_vmerge_vvm_u64m2_tu(min_indices, min_indices, position_u64m2, less_b32, vector_length);
        vbool32_t greater_b32 = __riscv_vmflt_vv_f32m1_b32(max, data_f32m1, vector_length);
        max = __riscv_vmerge_vvm_f32m1_tu(max, max, data_f32m1, greater_b32, vector_length);
        max_indices = __riscv_vmerge_vvm_u64m2_tu(max_indices, max_indices, position_u64m2, greater_b32, vector_length);
    }
    vfloat32m1_t id_max = __riscv_vfmv_v_f_f32m1(NK_F32_MAX, 1);
    nk_f32_t mn = __riscv_vfmv_f_s_f32m1_f32(__riscv_vfredmin_vs_f32m1_f32m1(min, id_max, vlmax));
    vfloat32m1_t id_min = __riscv_vfmv_v_f_f32m1(NK_F32_MIN, 1);
    nk_f32_t mx = __riscv_vfmv_f_s_f32m1_f32(__riscv_vfredmax_vs_f32m1_f32m1(max, id_min, vlmax));
    if (mn == NK_F32_MAX && mx == NK_F32_MIN) {
        *min_value = NK_F32_MAX, *min_index = NK_SIZE_MAX, *max_value = NK_F32_MIN, *max_index = NK_SIZE_MAX;
        return;
    }
    vbool32_t min_match_b32 = __riscv_vmfeq_vf_f32m1_b32(min, mn, vlmax);
    vuint64m2_t sentinel = __riscv_vmv_v_x_u64m2(NK_U64_MAX, vlmax);
    vuint64m2_t min_cands = __riscv_vmerge_vvm_u64m2(sentinel, min_indices, min_match_b32, vlmax);
    vuint64m1_t id_umax = __riscv_vmv_v_x_u64m1(NK_U64_MAX, 1);
    *min_value = mn,
    *min_index = (nk_size_t)__riscv_vmv_x_s_u64m1_u64(__riscv_vredminu_vs_u64m2_u64m1(min_cands, id_umax, vlmax));
    vbool32_t max_match_b32 = __riscv_vmfeq_vf_f32m1_b32(max, mx, vlmax);
    vuint64m2_t max_cands = __riscv_vmerge_vvm_u64m2(sentinel, max_indices, max_match_b32, vlmax);
    *max_value = mx,
    *max_index = (nk_size_t)__riscv_vmv_x_s_u64m1_u64(__riscv_vredminu_vs_u64m2_u64m1(max_cands, id_umax, vlmax));
}

NK_PUBLIC void nk_reduce_minmax_f32_rvv(                           //
    nk_f32_t const *data, nk_size_t count, nk_size_t stride_bytes, //
    nk_f32_t *min_value, nk_size_t *min_index,                     //
    nk_f32_t *max_value, nk_size_t *max_index) {
    nk_size_t stride_elements = stride_bytes / sizeof(nk_f32_t);
    int aligned = (stride_bytes % sizeof(nk_f32_t) == 0);
    if (count == 0)
        *min_value = NK_F32_MAX, *min_index = NK_SIZE_MAX, *max_value = NK_F32_MIN, *max_index = NK_SIZE_MAX;
    else if (!aligned)
        nk_reduce_minmax_f32_serial(data, count, stride_bytes, min_value, min_index, max_value, max_index);
    else if (stride_elements == 1)
        nk_reduce_minmax_f32_rvv_contiguous_(data, count, min_value, min_index, max_value, max_index);
    else nk_reduce_minmax_f32_rvv_strided_(data, count, stride_bytes, min_value, min_index, max_value, max_index);
}

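/* [Editorial note] The minmax kernels above keep a per-lane running minimum/maximum and, in
 * parallel, the absolute position of each lane's current extremum (a vid.v lane index plus
 * the block offset, committed through the same comparison mask as the value). After the
 * loop, vfredmin/vfredmax extract the scalar extremum, every lane holding that value is
 * turned into an index candidate (non-matching lanes get a NK_U64_MAX sentinel), and
 * vredminu picks the smallest candidate, so ties resolve to the earliest occurrence in the
 * input. */
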
NK_INTERNAL void nk_reduce_moments_f64_rvv_contiguous_( //
    nk_f64_t const *data, nk_size_t count,              //
    nk_f64_t *sum_ptr, nk_f64_t *sumsq_ptr) {
    nk_size_t vlmax = __riscv_vsetvlmax_e64m4();
    vfloat64m4_t sum_f64m4 = __riscv_vfmv_v_f_f64m4(0.0, vlmax);
    vfloat64m4_t sumsq_f64m4 = __riscv_vfmv_v_f_f64m4(0.0, vlmax);
    for (nk_size_t vector_length; count > 0; count -= vector_length, data += vector_length) {
        vector_length = __riscv_vsetvl_e64m4(count);
        vfloat64m4_t data_f64m4 = __riscv_vle64_v_f64m4(data, vector_length);
        sum_f64m4 = __riscv_vfadd_vv_f64m4_tu(sum_f64m4, sum_f64m4, data_f64m4, vector_length);
        sumsq_f64m4 = __riscv_vfmacc_vv_f64m4_tu(sumsq_f64m4, data_f64m4, data_f64m4, vector_length);
    }
    vfloat64m1_t zero = __riscv_vfmv_v_f_f64m1(0.0, 1);
    *sum_ptr = __riscv_vfmv_f_s_f64m1_f64(__riscv_vfredusum_vs_f64m4_f64m1(sum_f64m4, zero, vlmax)),
    *sumsq_ptr = __riscv_vfmv_f_s_f64m1_f64(__riscv_vfredusum_vs_f64m4_f64m1(sumsq_f64m4, zero, vlmax));
}

NK_INTERNAL void nk_reduce_moments_f64_rvv_strided_(               //
    nk_f64_t const *data, nk_size_t count, nk_size_t stride_bytes, //
    nk_f64_t *sum_ptr, nk_f64_t *sumsq_ptr) {
    nk_size_t vlmax = __riscv_vsetvlmax_e64m4();
    vfloat64m4_t sum_f64m4 = __riscv_vfmv_v_f_f64m4(0.0, vlmax);
    vfloat64m4_t sumsq_f64m4 = __riscv_vfmv_v_f_f64m4(0.0, vlmax);
    unsigned char const *ptr = (unsigned char const *)data;
    for (nk_size_t vector_length; count > 0; count -= vector_length, ptr += vector_length * stride_bytes) {
        vector_length = __riscv_vsetvl_e64m4(count);
        vfloat64m4_t data_f64m4 = __riscv_vlse64_v_f64m4((nk_f64_t const *)ptr, (nk_ssize_t)stride_bytes,
                                                         vector_length);
        sum_f64m4 = __riscv_vfadd_vv_f64m4_tu(sum_f64m4, sum_f64m4, data_f64m4, vector_length);
        sumsq_f64m4 = __riscv_vfmacc_vv_f64m4_tu(sumsq_f64m4, data_f64m4, data_f64m4, vector_length);
    }
    vfloat64m1_t zero = __riscv_vfmv_v_f_f64m1(0.0, 1);
    *sum_ptr = __riscv_vfmv_f_s_f64m1_f64(__riscv_vfredusum_vs_f64m4_f64m1(sum_f64m4, zero, vlmax)),
    *sumsq_ptr = __riscv_vfmv_f_s_f64m1_f64(__riscv_vfredusum_vs_f64m4_f64m1(sumsq_f64m4, zero, vlmax));
}

NK_PUBLIC void nk_reduce_moments_f64_rvv(                          //
    nk_f64_t const *data, nk_size_t count, nk_size_t stride_bytes, //
    nk_f64_t *sum, nk_f64_t *sumsq) {
    nk_size_t stride_elements = stride_bytes / sizeof(nk_f64_t);
    int aligned = (stride_bytes % sizeof(nk_f64_t) == 0);
    if (count == 0) *sum = 0, *sumsq = 0;
    else if (!aligned) nk_reduce_moments_f64_serial(data, count, stride_bytes, sum, sumsq);
    else if (stride_elements == 1) nk_reduce_moments_f64_rvv_contiguous_(data, count, sum, sumsq);
    else nk_reduce_moments_f64_rvv_strided_(data, count, stride_bytes, sum, sumsq);
}

NK_INTERNAL void nk_reduce_minmax_f64_rvv_contiguous_( //
    nk_f64_t const *data, nk_size_t count,             //
    nk_f64_t *min_value, nk_size_t *min_index,         //
    nk_f64_t *max_value, nk_size_t *max_index) {
    nk_size_t vlmax = __riscv_vsetvlmax_e64m1();
    vfloat64m1_t min = __riscv_vfmv_v_f_f64m1(NK_F64_MAX, vlmax);
    vfloat64m1_t max = __riscv_vfmv_v_f_f64m1(NK_F64_MIN, vlmax);
    vuint64m1_t min_indices = __riscv_vmv_v_x_u64m1(0, vlmax);
    vuint64m1_t max_indices = __riscv_vmv_v_x_u64m1(0, vlmax);
    nk_size_t offset = 0;
    for (nk_size_t remaining = count, vector_length; remaining > 0;
         remaining -= vector_length, offset += vector_length) {
        vector_length = __riscv_vsetvl_e64m1(remaining);
        vfloat64m1_t data_f64m1 = __riscv_vle64_v_f64m1(data + offset, vector_length);
        vuint64m1_t position_u64m1 = __riscv_vadd_vx_u64m1(__riscv_vid_v_u64m1(vector_length), (nk_u64_t)offset,
                                                           vector_length);
        vbool64_t less_b64 = __riscv_vmflt_vv_f64m1_b64(data_f64m1, min, vector_length);
        min = __riscv_vmerge_vvm_f64m1_tu(min, min, data_f64m1, less_b64, vector_length);
        min_indices = __riscv_vmerge_vvm_u64m1_tu(min_indices, min_indices, position_u64m1, less_b64, vector_length);
        vbool64_t greater_b64 = __riscv_vmflt_vv_f64m1_b64(max, data_f64m1, vector_length);
        max = __riscv_vmerge_vvm_f64m1_tu(max, max, data_f64m1, greater_b64, vector_length);
        max_indices = __riscv_vmerge_vvm_u64m1_tu(max_indices, max_indices, position_u64m1, greater_b64, vector_length);
    }
    vfloat64m1_t id_max = __riscv_vfmv_v_f_f64m1(NK_F64_MAX, 1);
    nk_f64_t mn = __riscv_vfmv_f_s_f64m1_f64(__riscv_vfredmin_vs_f64m1_f64m1(min, id_max, vlmax));
    vfloat64m1_t id_min = __riscv_vfmv_v_f_f64m1(NK_F64_MIN, 1);
    nk_f64_t mx = __riscv_vfmv_f_s_f64m1_f64(__riscv_vfredmax_vs_f64m1_f64m1(max, id_min, vlmax));
    if (mn == NK_F64_MAX && mx == NK_F64_MIN) {
        *min_value = NK_F64_MAX, *min_index = NK_SIZE_MAX, *max_value = NK_F64_MIN, *max_index = NK_SIZE_MAX;
        return;
    }
    vbool64_t min_match_b64 = __riscv_vmfeq_vf_f64m1_b64(min, mn, vlmax);
    vuint64m1_t sentinel = __riscv_vmv_v_x_u64m1(NK_U64_MAX, vlmax);
    vuint64m1_t min_cands = __riscv_vmerge_vvm_u64m1(sentinel, min_indices, min_match_b64, vlmax);
    vuint64m1_t id_umax = __riscv_vmv_v_x_u64m1(NK_U64_MAX, 1);
    *min_value = mn,
    *min_index = (nk_size_t)__riscv_vmv_x_s_u64m1_u64(__riscv_vredminu_vs_u64m1_u64m1(min_cands, id_umax, vlmax));
    vbool64_t max_match_b64 = __riscv_vmfeq_vf_f64m1_b64(max, mx, vlmax);
    vuint64m1_t max_cands = __riscv_vmerge_vvm_u64m1(sentinel, max_indices, max_match_b64, vlmax);
    *max_value = mx,
    *max_index = (nk_size_t)__riscv_vmv_x_s_u64m1_u64(__riscv_vredminu_vs_u64m1_u64m1(max_cands, id_umax, vlmax));
}

NK_INTERNAL void nk_reduce_minmax_f64_rvv_strided_(                //
    nk_f64_t const *data, nk_size_t count, nk_size_t stride_bytes, //
    nk_f64_t *min_value, nk_size_t *min_index,                     //
    nk_f64_t *max_value, nk_size_t *max_index) {
    nk_size_t vlmax = __riscv_vsetvlmax_e64m1();
    vfloat64m1_t min = __riscv_vfmv_v_f_f64m1(NK_F64_MAX, vlmax);
    vfloat64m1_t max = __riscv_vfmv_v_f_f64m1(NK_F64_MIN, vlmax);
    vuint64m1_t min_indices = __riscv_vmv_v_x_u64m1(0, vlmax);
    vuint64m1_t max_indices = __riscv_vmv_v_x_u64m1(0, vlmax);
    unsigned char const *ptr = (unsigned char const *)data;
    nk_size_t offset = 0;
    for (nk_size_t remaining = count, vector_length; remaining > 0;
         remaining -= vector_length, offset += vector_length, ptr += vector_length * stride_bytes) {
        vector_length = __riscv_vsetvl_e64m1(remaining);
        vfloat64m1_t data_f64m1 = __riscv_vlse64_v_f64m1((nk_f64_t const *)ptr, (nk_ssize_t)stride_bytes,
                                                         vector_length);
        vuint64m1_t position_u64m1 = __riscv_vadd_vx_u64m1(__riscv_vid_v_u64m1(vector_length), (nk_u64_t)offset,
                                                           vector_length);
        vbool64_t less_b64 = __riscv_vmflt_vv_f64m1_b64(data_f64m1, min, vector_length);
        min = __riscv_vmerge_vvm_f64m1_tu(min, min, data_f64m1, less_b64, vector_length);
        min_indices = __riscv_vmerge_vvm_u64m1_tu(min_indices, min_indices, position_u64m1, less_b64, vector_length);
        vbool64_t greater_b64 = __riscv_vmflt_vv_f64m1_b64(max, data_f64m1, vector_length);
        max = __riscv_vmerge_vvm_f64m1_tu(max, max, data_f64m1, greater_b64, vector_length);
        max_indices = __riscv_vmerge_vvm_u64m1_tu(max_indices, max_indices, position_u64m1, greater_b64, vector_length);
    }
    vfloat64m1_t id_max = __riscv_vfmv_v_f_f64m1(NK_F64_MAX, 1);
    nk_f64_t mn = __riscv_vfmv_f_s_f64m1_f64(__riscv_vfredmin_vs_f64m1_f64m1(min, id_max, vlmax));
    vfloat64m1_t id_min = __riscv_vfmv_v_f_f64m1(NK_F64_MIN, 1);
    nk_f64_t mx = __riscv_vfmv_f_s_f64m1_f64(__riscv_vfredmax_vs_f64m1_f64m1(max, id_min, vlmax));
    if (mn == NK_F64_MAX && mx == NK_F64_MIN) {
        *min_value = NK_F64_MAX, *min_index = NK_SIZE_MAX, *max_value = NK_F64_MIN, *max_index = NK_SIZE_MAX;
        return;
    }
    vbool64_t min_match_b64 = __riscv_vmfeq_vf_f64m1_b64(min, mn, vlmax);
    vuint64m1_t sentinel = __riscv_vmv_v_x_u64m1(NK_U64_MAX, vlmax);
    vuint64m1_t min_cands = __riscv_vmerge_vvm_u64m1(sentinel, min_indices, min_match_b64, vlmax);
    vuint64m1_t id_umax = __riscv_vmv_v_x_u64m1(NK_U64_MAX, 1);
    *min_value = mn,
    *min_index = (nk_size_t)__riscv_vmv_x_s_u64m1_u64(__riscv_vredminu_vs_u64m1_u64m1(min_cands, id_umax, vlmax));
    vbool64_t max_match_b64 = __riscv_vmfeq_vf_f64m1_b64(max, mx, vlmax);
    vuint64m1_t max_cands = __riscv_vmerge_vvm_u64m1(sentinel, max_indices, max_match_b64, vlmax);
    *max_value = mx,
    *max_index = (nk_size_t)__riscv_vmv_x_s_u64m1_u64(__riscv_vredminu_vs_u64m1_u64m1(max_cands, id_umax, vlmax));
}

NK_PUBLIC void nk_reduce_minmax_f64_rvv(                           //
    nk_f64_t const *data, nk_size_t count, nk_size_t stride_bytes, //
    nk_f64_t *min_value, nk_size_t *min_index,                     //
    nk_f64_t *max_value, nk_size_t *max_index) {
    nk_size_t stride_elements = stride_bytes / sizeof(nk_f64_t);
    int aligned = (stride_bytes % sizeof(nk_f64_t) == 0);
    if (count == 0)
        *min_value = NK_F64_MAX, *min_index = NK_SIZE_MAX, *max_value = NK_F64_MIN, *max_index = NK_SIZE_MAX;
    else if (!aligned)
        nk_reduce_minmax_f64_serial(data, count, stride_bytes, min_value, min_index, max_value, max_index);
    else if (stride_elements == 1)
        nk_reduce_minmax_f64_rvv_contiguous_(data, count, min_value, min_index, max_value, max_index);
    else nk_reduce_minmax_f64_rvv_strided_(data, count, stride_bytes, min_value, min_index, max_value, max_index);
}

NK_INTERNAL vuint8m1_t nk_fp8m1_to_comparable_u8m1_rvv_(vuint8m1_t raw_u8m1, nk_size_t vector_length) {
    // Convert FP8 (e4m3/e5m2) to comparable unsigned form (sign bit 7)
    // Positive (sign=0): XOR 0x80 → [0x80, 0xFF]
    // Negative (sign=1): Bitwise NOT → [0x00, 0x7F]
    vbool8_t is_negative_b8 = __riscv_vmsne_vx_u8m1_b8(__riscv_vand_vx_u8m1(raw_u8m1, 0x80, vector_length), 0,
                                                       vector_length);
    vuint8m1_t flip_positive_u8m1 = __riscv_vxor_vx_u8m1(raw_u8m1, 0x80, vector_length);
    vuint8m1_t flip_negative_u8m1 = __riscv_vnot_v_u8m1(raw_u8m1, vector_length);
    return __riscv_vmerge_vvm_u8m1(flip_positive_u8m1, flip_negative_u8m1, is_negative_b8, vector_length);
}

NK_INTERNAL vuint8m1_t nk_comparable_to_fp8m1_rvv_(vuint8m1_t comparable_u8m1, nk_size_t vector_length) {
    // Reverse: if >= 0x80 (was positive), XOR; else NOT
    vbool8_t was_positive_b8 = __riscv_vmsgeu_vx_u8m1_b8(comparable_u8m1, 0x80, vector_length);
    vuint8m1_t from_positive_u8m1 = __riscv_vxor_vx_u8m1(comparable_u8m1, 0x80, vector_length);
    vuint8m1_t from_negative_u8m1 = __riscv_vnot_v_u8m1(comparable_u8m1, vector_length);
    return __riscv_vmerge_vvm_u8m1(from_negative_u8m1, from_positive_u8m1, was_positive_b8, vector_length);
}

NK_INTERNAL vuint8m1_t nk_fp6m1_to_comparable_u8m1_rvv_(vuint8m1_t raw_u8m1, nk_size_t vector_length) {
    // Convert FP6 (e2m3/e3m2) to comparable unsigned form (sign bit 5)
    // Positive (sign=0): XOR 0x20 → [0x20, 0x3F]
    // Negative (sign=1): XOR 0x3F (NOT lower 6 bits) → [0x00, 0x1F]
    vbool8_t is_negative_b8 = __riscv_vmsne_vx_u8m1_b8(__riscv_vand_vx_u8m1(raw_u8m1, 0x20, vector_length), 0,
                                                       vector_length);
    vuint8m1_t flip_positive_u8m1 = __riscv_vxor_vx_u8m1(raw_u8m1, 0x20, vector_length);
    vuint8m1_t flip_negative_u8m1 = __riscv_vxor_vx_u8m1(raw_u8m1, 0x3F, vector_length);
    return __riscv_vmerge_vvm_u8m1(flip_positive_u8m1, flip_negative_u8m1, is_negative_b8, vector_length);
}

NK_INTERNAL vuint8m1_t nk_comparable_to_fp6m1_rvv_(vuint8m1_t comparable_u8m1, nk_size_t vector_length) {
    // Reverse: if >= 0x20 (was positive), XOR 0x20; else XOR 0x3F (NOT lower 6 bits)
    vbool8_t was_positive_b8 = __riscv_vmsgeu_vx_u8m1_b8(comparable_u8m1, 0x20, vector_length);
    vuint8m1_t from_positive_u8m1 = __riscv_vxor_vx_u8m1(comparable_u8m1, 0x20, vector_length);
    vuint8m1_t from_negative_u8m1 = __riscv_vxor_vx_u8m1(comparable_u8m1, 0x3F, vector_length);
    return __riscv_vmerge_vvm_u8m1(from_negative_u8m1, from_positive_u8m1, was_positive_b8, vector_length);
}

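/* [Editorial sketch] The four helpers above build a monotonic mapping from sign-magnitude
 * FP8/FP6 bit patterns to plain unsigned integers, so that integer comparisons and min/max
 * order the values numerically: positives get the sign bit set (placing them above all
 * negatives), negatives are bitwise-inverted (so more-negative maps to smaller). A scalar
 * equivalent for the 8-bit case (illustration only, not part of the numkong sources):
 *
 *     uint8_t fp8_to_comparable(uint8_t raw) {
 *         return (raw & 0x80) ? (uint8_t)~raw : (uint8_t)(raw ^ 0x80);
 *     }
 *     uint8_t comparable_to_fp8(uint8_t key) {
 *         return (key >= 0x80) ? (uint8_t)(key ^ 0x80) : (uint8_t)~key;
 *     }
 */
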
NK_INTERNAL void nk_reduce_moments_i8_rvv_contiguous_( //
    nk_i8_t const *data_ptr, nk_size_t count,          //
    nk_i64_t *sum_ptr, nk_u64_t *sumsq_ptr) {
    nk_size_t vlmax = __riscv_vsetvlmax_e64m4();
    nk_size_t vlmax_elements = __riscv_vsetvlmax_e8m1();
    vint64m4_t sum_i64m4 = __riscv_vmv_v_x_i64m4(0, vlmax);
    vuint64m4_t sumsq_u64m4 = __riscv_vmv_v_x_u64m4(0, vlmax);
    vint8m1_t zero_i8m1 = __riscv_vmv_v_x_i8m1(0, vlmax_elements);

    for (nk_size_t vector_length; count > 0; count -= vector_length, data_ptr += vector_length) {
        vector_length = __riscv_vsetvl_e8m1(count);
        vint8m1_t data_i8m1 = __riscv_vle8_v_i8m1_tu(zero_i8m1, data_ptr, vector_length);

        // Widen i8 → i16 → i32 → i64 for sum
        vint16m2_t data_i16m2 = __riscv_vsext_vf2_i16m2(data_i8m1, vlmax_elements);
        vint32m4_t data_i32m4 = __riscv_vsext_vf2_i32m4(data_i16m2, vlmax_elements);
        vint64m8_t data_i64m8 = __riscv_vsext_vf2_i64m8(data_i32m4, vlmax_elements);

        // Accumulate sum (split m8 into two m4)
        sum_i64m4 = __riscv_vadd_vv_i64m4(sum_i64m4, __riscv_vget_v_i64m8_i64m4(data_i64m8, 0), vlmax);
        sum_i64m4 = __riscv_vadd_vv_i64m4(sum_i64m4, __riscv_vget_v_i64m8_i64m4(data_i64m8, 1), vlmax);

        // Sumsq: i8 × i8 → i16 (widening multiply)
        vint16m2_t squares_i16m2 = __riscv_vwmul_vv_i16m2(data_i8m1, data_i8m1, vlmax_elements);
        // Widen i16 → u32 → u64
        vuint32m4_t squares_u32m4 = __riscv_vwcvtu_x_x_v_u32m4(__riscv_vreinterpret_v_i16m2_u16m2(squares_i16m2),
                                                               vlmax_elements);
        vuint64m8_t squares_u64m8 = __riscv_vwcvtu_x_x_v_u64m8(squares_u32m4, vlmax_elements);

        sumsq_u64m4 = __riscv_vadd_vv_u64m4(sumsq_u64m4, __riscv_vget_v_u64m8_u64m4(squares_u64m8, 0), vlmax);
        sumsq_u64m4 = __riscv_vadd_vv_u64m4(sumsq_u64m4, __riscv_vget_v_u64m8_u64m4(squares_u64m8, 1), vlmax);
    }

    // Horizontal reduction
    vint64m1_t zero_i64m1 = __riscv_vmv_v_x_i64m1(0, 1);
    *sum_ptr = __riscv_vmv_x_s_i64m1_i64(__riscv_vredsum_vs_i64m4_i64m1(sum_i64m4, zero_i64m1, vlmax));

    vuint64m1_t zero_u64m1 = __riscv_vmv_v_x_u64m1(0, 1);
    *sumsq_ptr = __riscv_vmv_x_s_u64m1_u64(__riscv_vredsum_vs_u64m4_u64m1(sumsq_u64m4, zero_u64m1, vlmax));
}

NK_INTERNAL void nk_reduce_moments_i8_rvv_strided_(                  //
    nk_i8_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
    nk_i64_t *sum_ptr, nk_u64_t *sumsq_ptr) {
    nk_size_t vlmax = __riscv_vsetvlmax_e64m4();
    nk_size_t vlmax_elements = __riscv_vsetvlmax_e8m1();
    vint64m4_t sum_i64m4 = __riscv_vmv_v_x_i64m4(0, vlmax);
    vuint64m4_t sumsq_u64m4 = __riscv_vmv_v_x_u64m4(0, vlmax);
    vint8m1_t zero_i8m1 = __riscv_vmv_v_x_i8m1(0, vlmax_elements);
    unsigned char const *ptr = (unsigned char const *)data_ptr;

    for (nk_size_t vector_length; count > 0; count -= vector_length, ptr += vector_length * stride_bytes) {
        vector_length = __riscv_vsetvl_e8m1(count);
        vint8m1_t data_i8m1 = __riscv_vlse8_v_i8m1_tu(zero_i8m1, (nk_i8_t const *)ptr, (nk_ssize_t)stride_bytes,
                                                      vector_length);

        // Widen i8 → i16 → i32 → i64 for sum
        vint16m2_t data_i16m2 = __riscv_vsext_vf2_i16m2(data_i8m1, vlmax_elements);
        vint32m4_t data_i32m4 = __riscv_vsext_vf2_i32m4(data_i16m2, vlmax_elements);
        vint64m8_t data_i64m8 = __riscv_vsext_vf2_i64m8(data_i32m4, vlmax_elements);

        // Accumulate sum (split m8 into two m4)
        sum_i64m4 = __riscv_vadd_vv_i64m4(sum_i64m4, __riscv_vget_v_i64m8_i64m4(data_i64m8, 0), vlmax);
        sum_i64m4 = __riscv_vadd_vv_i64m4(sum_i64m4, __riscv_vget_v_i64m8_i64m4(data_i64m8, 1), vlmax);

        // Sumsq: i8 × i8 → i16 (widening multiply)
        vint16m2_t squares_i16m2 = __riscv_vwmul_vv_i16m2(data_i8m1, data_i8m1, vlmax_elements);
        // Widen i16 → u32 → u64
        vuint32m4_t squares_u32m4 = __riscv_vwcvtu_x_x_v_u32m4(__riscv_vreinterpret_v_i16m2_u16m2(squares_i16m2),
                                                               vlmax_elements);
        vuint64m8_t squares_u64m8 = __riscv_vwcvtu_x_x_v_u64m8(squares_u32m4, vlmax_elements);

        sumsq_u64m4 = __riscv_vadd_vv_u64m4(sumsq_u64m4, __riscv_vget_v_u64m8_u64m4(squares_u64m8, 0), vlmax);
        sumsq_u64m4 = __riscv_vadd_vv_u64m4(sumsq_u64m4, __riscv_vget_v_u64m8_u64m4(squares_u64m8, 1), vlmax);
    }

    // Horizontal reduction
    vint64m1_t zero_i64m1 = __riscv_vmv_v_x_i64m1(0, 1);
    *sum_ptr = __riscv_vmv_x_s_i64m1_i64(__riscv_vredsum_vs_i64m4_i64m1(sum_i64m4, zero_i64m1, vlmax));

    vuint64m1_t zero_u64m1 = __riscv_vmv_v_x_u64m1(0, 1);
    *sumsq_ptr = __riscv_vmv_x_s_u64m1_u64(__riscv_vredsum_vs_u64m4_u64m1(sumsq_u64m4, zero_u64m1, vlmax));
}

NK_PUBLIC void nk_reduce_moments_i8_rvv(                             //
    nk_i8_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
    nk_i64_t *sum_ptr, nk_u64_t *sumsq_ptr) {
    nk_size_t stride_elements = stride_bytes / sizeof(nk_i8_t);
    int aligned = (stride_bytes % sizeof(nk_i8_t) == 0);

    if (count == 0) *sum_ptr = 0, *sumsq_ptr = 0;
    else if (!aligned) { nk_reduce_moments_i8_serial(data_ptr, count, stride_bytes, sum_ptr, sumsq_ptr); }
    else if (stride_elements == 1) { nk_reduce_moments_i8_rvv_contiguous_(data_ptr, count, sum_ptr, sumsq_ptr); }
    else { nk_reduce_moments_i8_rvv_strided_(data_ptr, count, stride_bytes, sum_ptr, sumsq_ptr); }
}

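/* [Editorial note] The i8 kernels widen each lane to 64 bits before accumulating, so the
 * per-lane accumulators cannot overflow in practice: |x| <= 128 bounds the sum by
 * count * 2^7 and the sum of squares by count * 2^14, both far below the 64-bit range for
 * any addressable count. The strided variant differs from the contiguous one only in using
 * the vlse8 strided load with a byte stride. */
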
524
+ NK_INTERNAL void nk_reduce_minmax_i8_rvv_contiguous_( //
525
+ nk_i8_t const *data_ptr, nk_size_t count, //
526
+ nk_i8_t *min_value_ptr, nk_size_t *min_index_ptr, //
527
+ nk_i8_t *max_value_ptr, nk_size_t *max_index_ptr) {
528
+ nk_size_t vlmax = __riscv_vsetvlmax_e8m1();
529
+ vint8m1_t min_i8m1 = __riscv_vmv_v_x_i8m1(NK_I8_MAX, vlmax);
530
+ vint8m1_t max_i8m1 = __riscv_vmv_v_x_i8m1(NK_I8_MIN, vlmax);
531
+ vuint64m8_t min_indices_u64m8 = __riscv_vmv_v_x_u64m8(0, vlmax);
532
+ vuint64m8_t max_indices_u64m8 = __riscv_vmv_v_x_u64m8(0, vlmax);
533
+
534
+ nk_size_t offset = 0;
535
+ for (nk_size_t vector_length; count > 0;
536
+ count -= vector_length, offset += vector_length, data_ptr += vector_length) {
537
+ vector_length = __riscv_vsetvl_e8m1(count);
538
+ vint8m1_t data_i8m1 = __riscv_vle8_v_i8m1(data_ptr, vector_length);
539
+
540
+ // VID-based absolute indices
541
+ vuint64m8_t pos_u64m8 = __riscv_vadd_vx_u64m8(__riscv_vid_v_u64m8(vector_length), (nk_u64_t)offset,
542
+ vector_length);
543
+
544
+ vbool8_t less_b8 = __riscv_vmslt_vv_i8m1_b8(data_i8m1, min_i8m1, vector_length);
545
+ min_i8m1 = __riscv_vmerge_vvm_i8m1_tu(min_i8m1, min_i8m1, data_i8m1, less_b8, vector_length);
546
+ min_indices_u64m8 = __riscv_vmerge_vvm_u64m8_tu(min_indices_u64m8, min_indices_u64m8, pos_u64m8, less_b8,
547
+ vector_length);
548
+
549
+ vbool8_t greater_b8 = __riscv_vmslt_vv_i8m1_b8(max_i8m1, data_i8m1, vector_length);
550
+ max_i8m1 = __riscv_vmerge_vvm_i8m1_tu(max_i8m1, max_i8m1, data_i8m1, greater_b8, vector_length);
551
+ max_indices_u64m8 = __riscv_vmerge_vvm_u64m8_tu(max_indices_u64m8, max_indices_u64m8, pos_u64m8, greater_b8,
552
+ vector_length);
553
+ }
554
+
555
+ // Horizontal reduction for min
556
+ vint8m1_t init_max_i8m1 = __riscv_vmv_v_x_i8m1(NK_I8_MAX, 1);
557
+ nk_i8_t min_val = __riscv_vmv_x_s_i8m1_i8(__riscv_vredmin_vs_i8m1_i8m1(min_i8m1, init_max_i8m1, vlmax));
558
+ vbool8_t min_match_b8 = __riscv_vmseq_vx_i8m1_b8(min_i8m1, min_val, vlmax);
559
+ vuint64m8_t sentinel_u64m8 = __riscv_vmv_v_x_u64m8(NK_U64_MAX, vlmax);
560
+ vuint64m8_t min_cands_u64m8 = __riscv_vmerge_vvm_u64m8(sentinel_u64m8, min_indices_u64m8, min_match_b8, vlmax);
561
+ vuint64m1_t init_umax_u64m1 = __riscv_vmv_v_x_u64m1(NK_U64_MAX, 1);
562
+ *min_value_ptr = min_val;
563
+ *min_index_ptr = (nk_size_t)__riscv_vmv_x_s_u64m1_u64(
564
+ __riscv_vredminu_vs_u64m8_u64m1(min_cands_u64m8, init_umax_u64m1, vlmax));
565
+
566
+ // Horizontal reduction for max
567
+ vint8m1_t init_min_i8m1 = __riscv_vmv_v_x_i8m1(NK_I8_MIN, 1);
568
+ nk_i8_t max_val = __riscv_vmv_x_s_i8m1_i8(__riscv_vredmax_vs_i8m1_i8m1(max_i8m1, init_min_i8m1, vlmax));
569
+ vbool8_t max_match_b8 = __riscv_vmseq_vx_i8m1_b8(max_i8m1, max_val, vlmax);
570
+ vuint64m8_t max_cands_u64m8 = __riscv_vmerge_vvm_u64m8(sentinel_u64m8, max_indices_u64m8, max_match_b8, vlmax);
571
+ *max_value_ptr = max_val;
572
+ *max_index_ptr = (nk_size_t)__riscv_vmv_x_s_u64m1_u64(
573
+ __riscv_vredminu_vs_u64m8_u64m1(max_cands_u64m8, init_umax_u64m1, vlmax));
574
+ }
575
+
576
+ NK_INTERNAL void nk_reduce_minmax_i8_rvv_strided_( //
577
+ nk_i8_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
578
+ nk_i8_t *min_value_ptr, nk_size_t *min_index_ptr, //
579
+ nk_i8_t *max_value_ptr, nk_size_t *max_index_ptr) {
580
+ nk_size_t vlmax = __riscv_vsetvlmax_e8m1();
581
+ vint8m1_t min_i8m1 = __riscv_vmv_v_x_i8m1(NK_I8_MAX, vlmax);
582
+ vint8m1_t max_i8m1 = __riscv_vmv_v_x_i8m1(NK_I8_MIN, vlmax);
583
+ vuint64m8_t min_indices_u64m8 = __riscv_vmv_v_x_u64m8(0, vlmax);
584
+ vuint64m8_t max_indices_u64m8 = __riscv_vmv_v_x_u64m8(0, vlmax);
585
+ unsigned char const *ptr = (unsigned char const *)data_ptr;
586
+
587
+ nk_size_t offset = 0;
588
+ for (nk_size_t vector_length; count > 0;
589
+ count -= vector_length, offset += vector_length, ptr += vector_length * stride_bytes) {
590
+ vector_length = __riscv_vsetvl_e8m1(count);
591
+ vint8m1_t data_i8m1 = __riscv_vlse8_v_i8m1((nk_i8_t const *)ptr, (nk_ssize_t)stride_bytes, vector_length);
592
+
593
+ // VID-based absolute indices
594
+ vuint64m8_t pos_u64m8 = __riscv_vadd_vx_u64m8(__riscv_vid_v_u64m8(vector_length), (nk_u64_t)offset,
595
+ vector_length);
596
+
597
+ vbool8_t less_b8 = __riscv_vmslt_vv_i8m1_b8(data_i8m1, min_i8m1, vector_length);
598
+ min_i8m1 = __riscv_vmerge_vvm_i8m1_tu(min_i8m1, min_i8m1, data_i8m1, less_b8, vector_length);
599
+ min_indices_u64m8 = __riscv_vmerge_vvm_u64m8_tu(min_indices_u64m8, min_indices_u64m8, pos_u64m8, less_b8,
600
+ vector_length);
601
+
602
+ vbool8_t greater_b8 = __riscv_vmslt_vv_i8m1_b8(max_i8m1, data_i8m1, vector_length);
603
+ max_i8m1 = __riscv_vmerge_vvm_i8m1_tu(max_i8m1, max_i8m1, data_i8m1, greater_b8, vector_length);
604
+ max_indices_u64m8 = __riscv_vmerge_vvm_u64m8_tu(max_indices_u64m8, max_indices_u64m8, pos_u64m8, greater_b8,
605
+ vector_length);
606
+ }
607
+
608
+ // Horizontal reduction for min
609
+ vint8m1_t init_max_i8m1 = __riscv_vmv_v_x_i8m1(NK_I8_MAX, 1);
610
+ nk_i8_t min_val = __riscv_vmv_x_s_i8m1_i8(__riscv_vredmin_vs_i8m1_i8m1(min_i8m1, init_max_i8m1, vlmax));
611
+ vbool8_t min_match_b8 = __riscv_vmseq_vx_i8m1_b8(min_i8m1, min_val, vlmax);
612
+ vuint64m8_t sentinel_u64m8 = __riscv_vmv_v_x_u64m8(NK_U64_MAX, vlmax);
613
+ vuint64m8_t min_cands_u64m8 = __riscv_vmerge_vvm_u64m8(sentinel_u64m8, min_indices_u64m8, min_match_b8, vlmax);
614
+ vuint64m1_t init_umax_u64m1 = __riscv_vmv_v_x_u64m1(NK_U64_MAX, 1);
615
+ *min_value_ptr = min_val;
616
+ *min_index_ptr = (nk_size_t)__riscv_vmv_x_s_u64m1_u64(
617
+ __riscv_vredminu_vs_u64m8_u64m1(min_cands_u64m8, init_umax_u64m1, vlmax));
618
+
619
+ // Horizontal reduction for max
620
+ vint8m1_t init_min_i8m1 = __riscv_vmv_v_x_i8m1(NK_I8_MIN, 1);
621
+ nk_i8_t max_val = __riscv_vmv_x_s_i8m1_i8(__riscv_vredmax_vs_i8m1_i8m1(max_i8m1, init_min_i8m1, vlmax));
622
+ vbool8_t max_match_b8 = __riscv_vmseq_vx_i8m1_b8(max_i8m1, max_val, vlmax);
623
+ vuint64m8_t max_cands_u64m8 = __riscv_vmerge_vvm_u64m8(sentinel_u64m8, max_indices_u64m8, max_match_b8, vlmax);
624
+ *max_value_ptr = max_val;
625
+ *max_index_ptr = (nk_size_t)__riscv_vmv_x_s_u64m1_u64(
626
+ __riscv_vredminu_vs_u64m8_u64m1(max_cands_u64m8, init_umax_u64m1, vlmax));
627
+ }
628
+
629
+ NK_PUBLIC void nk_reduce_minmax_i8_rvv( //
630
+ nk_i8_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
631
+ nk_i8_t *min_value_ptr, nk_size_t *min_index_ptr, //
632
+ nk_i8_t *max_value_ptr, nk_size_t *max_index_ptr) {
633
+ nk_size_t stride_elements = stride_bytes / sizeof(nk_i8_t);
634
+ int aligned = (stride_bytes % sizeof(nk_i8_t) == 0);
635
+
636
+ if (count == 0)
637
+ *min_value_ptr = NK_I8_MAX, *min_index_ptr = NK_SIZE_MAX, *max_value_ptr = NK_I8_MIN,
638
+ *max_index_ptr = NK_SIZE_MAX;
639
+ else if (!aligned)
640
+ nk_reduce_minmax_i8_serial(data_ptr, count, stride_bytes, min_value_ptr, min_index_ptr, max_value_ptr,
641
+ max_index_ptr);
642
+ else if (stride_elements == 1)
643
+ nk_reduce_minmax_i8_rvv_contiguous_(data_ptr, count, min_value_ptr, min_index_ptr, max_value_ptr,
644
+ max_index_ptr);
645
+ else
646
+ nk_reduce_minmax_i8_rvv_strided_(data_ptr, count, stride_bytes, min_value_ptr, min_index_ptr, max_value_ptr,
647
+ max_index_ptr);
648
+ }
649
+
650
+ NK_INTERNAL void nk_reduce_moments_u8_rvv_contiguous_( //
651
+ nk_u8_t const *data_ptr, nk_size_t count, //
652
+ nk_u64_t *sum_ptr, nk_u64_t *sumsq_ptr) {
653
+ nk_size_t vlmax = __riscv_vsetvlmax_e64m4();
654
+ nk_size_t vlmax_elements = __riscv_vsetvlmax_e8m1();
655
+ vuint64m4_t sum_u64m4 = __riscv_vmv_v_x_u64m4(0, vlmax);
656
+ vuint64m4_t sumsq_u64m4 = __riscv_vmv_v_x_u64m4(0, vlmax);
657
+ vuint8m1_t zero_u8m1 = __riscv_vmv_v_x_u8m1(0, vlmax_elements);
658
+
659
+ for (nk_size_t vector_length; count > 0; count -= vector_length, data_ptr += vector_length) {
660
+ vector_length = __riscv_vsetvl_e8m1(count);
661
+ vuint8m1_t data_u8m1 = __riscv_vle8_v_u8m1_tu(zero_u8m1, data_ptr, vector_length);
662
+
663
+ // Widen u8 → u16 → u32 → u64 for sum
664
+ vuint16m2_t data_u16m2 = __riscv_vzext_vf2_u16m2(data_u8m1, vlmax_elements);
665
+ vuint32m4_t data_u32m4 = __riscv_vzext_vf2_u32m4(data_u16m2, vlmax_elements);
666
+ vuint64m8_t data_u64m8 = __riscv_vzext_vf2_u64m8(data_u32m4, vlmax_elements);
667
+
668
+ // Accumulate sum (split m8 into two m4)
669
+ sum_u64m4 = __riscv_vadd_vv_u64m4(sum_u64m4, __riscv_vget_v_u64m8_u64m4(data_u64m8, 0), vlmax);
670
+ sum_u64m4 = __riscv_vadd_vv_u64m4(sum_u64m4, __riscv_vget_v_u64m8_u64m4(data_u64m8, 1), vlmax);
671
+
672
+ // Sumsq: u8 × u8 → u16 (widening multiply)
673
+ vuint16m2_t squares_u16m2 = __riscv_vwmulu_vv_u16m2(data_u8m1, data_u8m1, vlmax_elements);
674
+ // Widen u16 → u32 → u64
675
+ vuint32m4_t squares_u32m4 = __riscv_vzext_vf2_u32m4(squares_u16m2, vlmax_elements);
676
+ vuint64m8_t squares_u64m8 = __riscv_vzext_vf2_u64m8(squares_u32m4, vlmax_elements);
677
+
678
+ sumsq_u64m4 = __riscv_vadd_vv_u64m4(sumsq_u64m4, __riscv_vget_v_u64m8_u64m4(squares_u64m8, 0), vlmax);
679
+ sumsq_u64m4 = __riscv_vadd_vv_u64m4(sumsq_u64m4, __riscv_vget_v_u64m8_u64m4(squares_u64m8, 1), vlmax);
680
+ }
681
+
682
+ // Horizontal reduction
683
+ vuint64m1_t zero_u64m1 = __riscv_vmv_v_x_u64m1(0, 1);
684
+ *sum_ptr = __riscv_vmv_x_s_u64m1_u64(__riscv_vredsum_vs_u64m4_u64m1(sum_u64m4, zero_u64m1, vlmax)),
685
+ *sumsq_ptr = __riscv_vmv_x_s_u64m1_u64(__riscv_vredsum_vs_u64m4_u64m1(sumsq_u64m4, zero_u64m1, vlmax));
686
+ }
687
+
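Two details in the contiguous kernel above are easy to miss. The load uses the tail-undisturbed `_tu` form with `zero_u8m1` as the destination, so lanes past `vector_length` stay zero; that lets every widening and accumulation step run at the fixed `vlmax_elements`/`vlmax` lengths without a separate tail loop, because the zeroed lanes contribute nothing. What the kernel accumulates is just the two raw moments; a scalar restatement for reference:

    nk_u64_t sum = 0, sumsq = 0;
    for (nk_size_t i = 0; i != count; ++i) {
        sum += data_ptr[i];                              // first raw moment
        sumsq += (nk_u64_t)data_ptr[i] * data_ptr[i];    // second raw moment
    }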
688
+ NK_INTERNAL void nk_reduce_moments_u8_rvv_strided_( //
689
+ nk_u8_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
690
+ nk_u64_t *sum_ptr, nk_u64_t *sumsq_ptr) {
691
+ nk_size_t vlmax = __riscv_vsetvlmax_e64m4();
692
+ nk_size_t vlmax_elements = __riscv_vsetvlmax_e8m1();
693
+ vuint64m4_t sum_u64m4 = __riscv_vmv_v_x_u64m4(0, vlmax);
694
+ vuint64m4_t sumsq_u64m4 = __riscv_vmv_v_x_u64m4(0, vlmax);
695
+ vuint8m1_t zero_u8m1 = __riscv_vmv_v_x_u8m1(0, vlmax_elements);
696
+ unsigned char const *ptr = (unsigned char const *)data_ptr;
697
+
698
+ for (nk_size_t vector_length; count > 0; count -= vector_length, ptr += vector_length * stride_bytes) {
699
+ vector_length = __riscv_vsetvl_e8m1(count);
700
+ vuint8m1_t data_u8m1 = __riscv_vlse8_v_u8m1_tu(zero_u8m1, (nk_u8_t const *)ptr, (nk_ssize_t)stride_bytes,
701
+ vector_length);
702
+
703
+ // Widen u8 → u16 → u32 → u64 for sum
704
+ vuint16m2_t data_u16m2 = __riscv_vzext_vf2_u16m2(data_u8m1, vlmax_elements);
705
+ vuint32m4_t data_u32m4 = __riscv_vzext_vf2_u32m4(data_u16m2, vlmax_elements);
706
+ vuint64m8_t data_u64m8 = __riscv_vzext_vf2_u64m8(data_u32m4, vlmax_elements);
707
+
708
+ // Accumulate sum (split m8 into two m4)
709
+ sum_u64m4 = __riscv_vadd_vv_u64m4(sum_u64m4, __riscv_vget_v_u64m8_u64m4(data_u64m8, 0), vlmax);
710
+ sum_u64m4 = __riscv_vadd_vv_u64m4(sum_u64m4, __riscv_vget_v_u64m8_u64m4(data_u64m8, 1), vlmax);
711
+
712
+ // Sumsq: u8 × u8 → u16 (widening multiply)
713
+ vuint16m2_t squares_u16m2 = __riscv_vwmulu_vv_u16m2(data_u8m1, data_u8m1, vlmax_elements);
714
+ // Widen u16 → u32 → u64
715
+ vuint32m4_t squares_u32m4 = __riscv_vzext_vf2_u32m4(squares_u16m2, vlmax_elements);
716
+ vuint64m8_t squares_u64m8 = __riscv_vzext_vf2_u64m8(squares_u32m4, vlmax_elements);
717
+
718
+ sumsq_u64m4 = __riscv_vadd_vv_u64m4(sumsq_u64m4, __riscv_vget_v_u64m8_u64m4(squares_u64m8, 0), vlmax);
719
+ sumsq_u64m4 = __riscv_vadd_vv_u64m4(sumsq_u64m4, __riscv_vget_v_u64m8_u64m4(squares_u64m8, 1), vlmax);
720
+ }
721
+
722
+ // Horizontal reduction
723
+ vuint64m1_t zero_u64m1 = __riscv_vmv_v_x_u64m1(0, 1);
724
+ *sum_ptr = __riscv_vmv_x_s_u64m1_u64(__riscv_vredsum_vs_u64m4_u64m1(sum_u64m4, zero_u64m1, vlmax)),
725
+ *sumsq_ptr = __riscv_vmv_x_s_u64m1_u64(__riscv_vredsum_vs_u64m4_u64m1(sumsq_u64m4, zero_u64m1, vlmax));
726
+ }
727
+
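The strided variant differs only in how elements are fetched: `vlse8` takes the stride in bytes directly, and the cursor advances by `vector_length * stride_bytes` bytes per iteration. That is what makes per-column reductions over row-major data a single call; a sketch, where `matrix`, `rows`, `cols`, and `c` are hypothetical caller-side names:

    nk_u64_t col_sum, col_sumsq;
    nk_reduce_moments_u8_rvv(matrix + c,                  /* first element of column c               */
                             rows,                        /* one element per row                     */
                             cols * sizeof(nk_u8_t),      /* byte distance between consecutive rows */
                             &col_sum, &col_sumsq);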
728
+ NK_PUBLIC void nk_reduce_moments_u8_rvv( //
729
+ nk_u8_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
730
+ nk_u64_t *sum_ptr, nk_u64_t *sumsq_ptr) {
731
+ nk_size_t stride_elements = stride_bytes / sizeof(nk_u8_t);
732
+ int aligned = (stride_bytes % sizeof(nk_u8_t) == 0);
733
+
734
+ if (count == 0) *sum_ptr = 0, *sumsq_ptr = 0;
735
+ else if (!aligned) { nk_reduce_moments_u8_serial(data_ptr, count, stride_bytes, sum_ptr, sumsq_ptr); }
736
+ else if (stride_elements == 1) { nk_reduce_moments_u8_rvv_contiguous_(data_ptr, count, sum_ptr, sumsq_ptr); }
737
+ else { nk_reduce_moments_u8_rvv_strided_(data_ptr, count, stride_bytes, sum_ptr, sumsq_ptr); }
738
+ }
739
+
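The moments entry points deliberately return the raw sums rather than finished statistics, so callers choose their own precision for the division. A sketch of deriving the mean and the population variance from the two outputs, assuming `count > 0`:

    nk_u64_t sum, sumsq;
    nk_reduce_moments_u8_rvv(data, count, sizeof(nk_u8_t), &sum, &sumsq);
    double mean = (double)sum / (double)count;
    double variance = (double)sumsq / (double)count - mean * mean;   // E[x^2] - E[x]^2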
740
+ NK_INTERNAL void nk_reduce_minmax_u8_rvv_contiguous_( //
741
+ nk_u8_t const *data_ptr, nk_size_t count, //
742
+ nk_u8_t *min_value_ptr, nk_size_t *min_index_ptr, //
743
+ nk_u8_t *max_value_ptr, nk_size_t *max_index_ptr) {
744
+ nk_size_t vlmax = __riscv_vsetvlmax_e8m1();
745
+ vuint8m1_t min_u8m1 = __riscv_vmv_v_x_u8m1(NK_U8_MAX, vlmax);
746
+ vuint8m1_t max_u8m1 = __riscv_vmv_v_x_u8m1(NK_U8_MIN, vlmax);
747
+ vuint64m8_t min_indices_u64m8 = __riscv_vmv_v_x_u64m8(0, vlmax);
748
+ vuint64m8_t max_indices_u64m8 = __riscv_vmv_v_x_u64m8(0, vlmax);
749
+
750
+ nk_size_t offset = 0;
751
+ for (nk_size_t vector_length; count > 0;
752
+ count -= vector_length, offset += vector_length, data_ptr += vector_length) {
753
+ vector_length = __riscv_vsetvl_e8m1(count);
754
+ vuint8m1_t data_u8m1 = __riscv_vle8_v_u8m1(data_ptr, vector_length);
755
+
756
+ // VID-based absolute indices
757
+ vuint64m8_t pos_u64m8 = __riscv_vadd_vx_u64m8(__riscv_vid_v_u64m8(vector_length), (nk_u64_t)offset,
758
+ vector_length);
759
+
760
+ vbool8_t less_b8 = __riscv_vmsltu_vv_u8m1_b8(data_u8m1, min_u8m1, vector_length);
761
+ min_u8m1 = __riscv_vmerge_vvm_u8m1_tu(min_u8m1, min_u8m1, data_u8m1, less_b8, vector_length);
762
+ min_indices_u64m8 = __riscv_vmerge_vvm_u64m8_tu(min_indices_u64m8, min_indices_u64m8, pos_u64m8, less_b8,
763
+ vector_length);
764
+
765
+ vbool8_t greater_b8 = __riscv_vmsltu_vv_u8m1_b8(max_u8m1, data_u8m1, vector_length);
766
+ max_u8m1 = __riscv_vmerge_vvm_u8m1_tu(max_u8m1, max_u8m1, data_u8m1, greater_b8, vector_length);
767
+ max_indices_u64m8 = __riscv_vmerge_vvm_u64m8_tu(max_indices_u64m8, max_indices_u64m8, pos_u64m8, greater_b8,
768
+ vector_length);
769
+ }
770
+
771
+ // Horizontal reduction for min
772
+ vuint8m1_t init_max_u8m1 = __riscv_vmv_v_x_u8m1(NK_U8_MAX, 1);
773
+ nk_u8_t min_val = __riscv_vmv_x_s_u8m1_u8(__riscv_vredminu_vs_u8m1_u8m1(min_u8m1, init_max_u8m1, vlmax));
774
+ vbool8_t min_match_b8 = __riscv_vmseq_vx_u8m1_b8(min_u8m1, min_val, vlmax);
775
+ vuint64m8_t sentinel_u64m8 = __riscv_vmv_v_x_u64m8(NK_U64_MAX, vlmax);
776
+ vuint64m8_t min_cands_u64m8 = __riscv_vmerge_vvm_u64m8(sentinel_u64m8, min_indices_u64m8, min_match_b8, vlmax);
777
+ vuint64m1_t init_umax_u64m1 = __riscv_vmv_v_x_u64m1(NK_U64_MAX, 1);
778
+ *min_value_ptr = min_val;
779
+ *min_index_ptr = (nk_size_t)__riscv_vmv_x_s_u64m1_u64(
780
+ __riscv_vredminu_vs_u64m8_u64m1(min_cands_u64m8, init_umax_u64m1, vlmax));
781
+
782
+ // Horizontal reduction for max
783
+ vuint8m1_t init_min_u8m1 = __riscv_vmv_v_x_u8m1(NK_U8_MIN, 1);
784
+ nk_u8_t max_val = __riscv_vmv_x_s_u8m1_u8(__riscv_vredmaxu_vs_u8m1_u8m1(max_u8m1, init_min_u8m1, vlmax));
785
+ vbool8_t max_match_b8 = __riscv_vmseq_vx_u8m1_b8(max_u8m1, max_val, vlmax);
786
+ vuint64m8_t max_cands_u64m8 = __riscv_vmerge_vvm_u64m8(sentinel_u64m8, max_indices_u64m8, max_match_b8, vlmax);
787
+ *max_value_ptr = max_val;
788
+ *max_index_ptr = (nk_size_t)__riscv_vmv_x_s_u64m1_u64(
789
+ __riscv_vredminu_vs_u64m8_u64m1(max_cands_u64m8, init_umax_u64m1, vlmax));
790
+ }
791
+
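The argmin/argmax scheme above runs in two phases. The loop keeps a per-lane running extremum plus the absolute position at which that lane last improved (strict comparisons, so earlier positions survive ties within a lane). The epilogue then reduces the values, marks every lane that holds the global extremum, merges all other lanes to the NK_U64_MAX sentinel, and takes an unsigned minimum over the surviving indices, so ties across lanes also resolve to the lowest index. The selection rule is therefore equivalent to this scalar loop (shown for the minimum; the maximum is symmetric):

    nk_u8_t best = NK_U8_MAX;
    nk_size_t best_at = NK_SIZE_MAX;
    for (nk_size_t i = 0; i != count; ++i)
        if (data_ptr[i] < best) { best = data_ptr[i]; best_at = i; }   // strict <: first occurrence wins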
792
+ NK_INTERNAL void nk_reduce_minmax_u8_rvv_strided_( //
793
+ nk_u8_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
794
+ nk_u8_t *min_value_ptr, nk_size_t *min_index_ptr, //
795
+ nk_u8_t *max_value_ptr, nk_size_t *max_index_ptr) {
796
+ nk_size_t vlmax = __riscv_vsetvlmax_e8m1();
797
+ vuint8m1_t min_u8m1 = __riscv_vmv_v_x_u8m1(NK_U8_MAX, vlmax);
798
+ vuint8m1_t max_u8m1 = __riscv_vmv_v_x_u8m1(NK_U8_MIN, vlmax);
799
+ vuint64m8_t min_indices_u64m8 = __riscv_vmv_v_x_u64m8(0, vlmax);
800
+ vuint64m8_t max_indices_u64m8 = __riscv_vmv_v_x_u64m8(0, vlmax);
801
+ unsigned char const *ptr = (unsigned char const *)data_ptr;
802
+
803
+ nk_size_t offset = 0;
804
+ for (nk_size_t vector_length; count > 0;
805
+ count -= vector_length, offset += vector_length, ptr += vector_length * stride_bytes) {
806
+ vector_length = __riscv_vsetvl_e8m1(count);
807
+ vuint8m1_t data_u8m1 = __riscv_vlse8_v_u8m1((nk_u8_t const *)ptr, (nk_ssize_t)stride_bytes, vector_length);
808
+
809
+ // VID-based absolute indices
810
+ vuint64m8_t pos_u64m8 = __riscv_vadd_vx_u64m8(__riscv_vid_v_u64m8(vector_length), (nk_u64_t)offset,
811
+ vector_length);
812
+
813
+ vbool8_t less_b8 = __riscv_vmsltu_vv_u8m1_b8(data_u8m1, min_u8m1, vector_length);
814
+ min_u8m1 = __riscv_vmerge_vvm_u8m1_tu(min_u8m1, min_u8m1, data_u8m1, less_b8, vector_length);
815
+ min_indices_u64m8 = __riscv_vmerge_vvm_u64m8_tu(min_indices_u64m8, min_indices_u64m8, pos_u64m8, less_b8,
816
+ vector_length);
817
+
818
+ vbool8_t greater_b8 = __riscv_vmsltu_vv_u8m1_b8(max_u8m1, data_u8m1, vector_length);
819
+ max_u8m1 = __riscv_vmerge_vvm_u8m1_tu(max_u8m1, max_u8m1, data_u8m1, greater_b8, vector_length);
820
+ max_indices_u64m8 = __riscv_vmerge_vvm_u64m8_tu(max_indices_u64m8, max_indices_u64m8, pos_u64m8, greater_b8,
821
+ vector_length);
822
+ }
823
+
824
+ // Horizontal reduction for min
825
+ vuint8m1_t init_max_u8m1 = __riscv_vmv_v_x_u8m1(NK_U8_MAX, 1);
826
+ nk_u8_t min_val = __riscv_vmv_x_s_u8m1_u8(__riscv_vredminu_vs_u8m1_u8m1(min_u8m1, init_max_u8m1, vlmax));
827
+ vbool8_t min_match_b8 = __riscv_vmseq_vx_u8m1_b8(min_u8m1, min_val, vlmax);
828
+ vuint64m8_t sentinel_u64m8 = __riscv_vmv_v_x_u64m8(NK_U64_MAX, vlmax);
829
+ vuint64m8_t min_cands_u64m8 = __riscv_vmerge_vvm_u64m8(sentinel_u64m8, min_indices_u64m8, min_match_b8, vlmax);
830
+ vuint64m1_t init_umax_u64m1 = __riscv_vmv_v_x_u64m1(NK_U64_MAX, 1);
831
+ *min_value_ptr = min_val;
832
+ *min_index_ptr = (nk_size_t)__riscv_vmv_x_s_u64m1_u64(
833
+ __riscv_vredminu_vs_u64m8_u64m1(min_cands_u64m8, init_umax_u64m1, vlmax));
834
+
835
+ // Horizontal reduction for max
836
+ vuint8m1_t init_min_u8m1 = __riscv_vmv_v_x_u8m1(NK_U8_MIN, 1);
837
+ nk_u8_t max_val = __riscv_vmv_x_s_u8m1_u8(__riscv_vredmaxu_vs_u8m1_u8m1(max_u8m1, init_min_u8m1, vlmax));
838
+ vbool8_t max_match_b8 = __riscv_vmseq_vx_u8m1_b8(max_u8m1, max_val, vlmax);
839
+ vuint64m8_t max_cands_u64m8 = __riscv_vmerge_vvm_u64m8(sentinel_u64m8, max_indices_u64m8, max_match_b8, vlmax);
840
+ *max_value_ptr = max_val;
841
+ *max_index_ptr = (nk_size_t)__riscv_vmv_x_s_u64m1_u64(
842
+ __riscv_vredminu_vs_u64m8_u64m1(max_cands_u64m8, init_umax_u64m1, vlmax));
843
+ }
844
+
845
+ NK_PUBLIC void nk_reduce_minmax_u8_rvv( //
846
+ nk_u8_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
847
+ nk_u8_t *min_value_ptr, nk_size_t *min_index_ptr, //
848
+ nk_u8_t *max_value_ptr, nk_size_t *max_index_ptr) {
849
+ nk_size_t stride_elements = stride_bytes / sizeof(nk_u8_t);
850
+ int aligned = (stride_bytes % sizeof(nk_u8_t) == 0);
851
+
852
+ if (count == 0)
853
+ *min_value_ptr = NK_U8_MAX, *min_index_ptr = NK_SIZE_MAX, *max_value_ptr = NK_U8_MIN,
854
+ *max_index_ptr = NK_SIZE_MAX;
855
+ else if (!aligned)
856
+ nk_reduce_minmax_u8_serial(data_ptr, count, stride_bytes, min_value_ptr, min_index_ptr, max_value_ptr,
857
+ max_index_ptr);
858
+ else if (stride_elements == 1)
859
+ nk_reduce_minmax_u8_rvv_contiguous_(data_ptr, count, min_value_ptr, min_index_ptr, max_value_ptr,
860
+ max_index_ptr);
861
+ else
862
+ nk_reduce_minmax_u8_rvv_strided_(data_ptr, count, stride_bytes, min_value_ptr, min_index_ptr, max_value_ptr,
863
+ max_index_ptr);
864
+ }
865
+
866
+ NK_INTERNAL void nk_reduce_moments_i16_rvv_contiguous_( //
867
+ nk_i16_t const *data_ptr, nk_size_t count, //
868
+ nk_i64_t *sum_ptr, nk_u64_t *sumsq_ptr) {
869
+ nk_size_t vlmax = __riscv_vsetvlmax_e64m4();
870
+ vint64m4_t sum_i64m4 = __riscv_vmv_v_x_i64m4(0, vlmax);
871
+ vuint64m4_t sumsq_u64m4 = __riscv_vmv_v_x_u64m4(0, vlmax);
872
+
873
+ for (nk_size_t vector_length; count > 0; count -= vector_length, data_ptr += vector_length) {
874
+ vector_length = __riscv_vsetvl_e16m1(count);
875
+ vint16m1_t data_i16m1 = __riscv_vle16_v_i16m1(data_ptr, vector_length);
876
+
877
+ // Widen i16 → i32 → i64 for sum
878
+ vint32m2_t data_i32m2 = __riscv_vsext_vf2_i32m2(data_i16m1, vector_length);
879
+ vint64m4_t data_i64m4 = __riscv_vsext_vf2_i64m4(data_i32m2, vector_length);
880
+ sum_i64m4 = __riscv_vadd_vv_i64m4_tu(sum_i64m4, sum_i64m4, data_i64m4, vector_length);
881
+
882
+ // Sumsq: i16 × i16 → i32 (widening multiply)
883
+ vint32m2_t squares_i32m2 = __riscv_vwmul_vv_i32m2(data_i16m1, data_i16m1, vector_length);
884
+ // Squares are non-negative, so reinterpret i32 → u32 and zero-extend u32 → u64

885
+ vuint64m4_t squares_u64m4 = __riscv_vwcvtu_x_x_v_u64m4(__riscv_vreinterpret_v_i32m2_u32m2(squares_i32m2),
886
+ vector_length);
887
+ sumsq_u64m4 = __riscv_vadd_vv_u64m4_tu(sumsq_u64m4, sumsq_u64m4, squares_u64m4, vector_length);
888
+ }
889
+
890
+ // Horizontal reduction
891
+ vint64m1_t zero_i64m1 = __riscv_vmv_v_x_i64m1(0, 1);
892
+ *sum_ptr = __riscv_vmv_x_s_i64m1_i64(__riscv_vredsum_vs_i64m4_i64m1(sum_i64m4, zero_i64m1, vlmax));
893
+
894
+ vuint64m1_t zero_u64m1 = __riscv_vmv_v_x_u64m1(0, 1);
895
+ *sumsq_ptr = __riscv_vmv_x_s_u64m1_u64(__riscv_vredsum_vs_u64m4_u64m1(sumsq_u64m4, zero_u64m1, vlmax));
896
+ }
897
+
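Unlike the 32-bit kernels further down, the i16 path can keep plain wrapping adds for both accumulators because the value ranges make overflow unreachable at any realistic input size; a worked bound:

    //  |x|   <= 2^15            for any nk_i16_t x
    //  x*x   <= 2^30            so each square fits easily in a u64 lane
    //  sumsq <= count * 2^30    cannot wrap u64 until count reaches 2^34 elements
    //  |sum| <= count * 2^15    cannot leave i64 until count reaches 2^48 elements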
898
+ NK_INTERNAL void nk_reduce_moments_i16_rvv_strided_( //
899
+ nk_i16_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
900
+ nk_i64_t *sum_ptr, nk_u64_t *sumsq_ptr) {
901
+ nk_size_t vlmax = __riscv_vsetvlmax_e64m4();
902
+ vint64m4_t sum_i64m4 = __riscv_vmv_v_x_i64m4(0, vlmax);
903
+ vuint64m4_t sumsq_u64m4 = __riscv_vmv_v_x_u64m4(0, vlmax);
904
+ unsigned char const *ptr = (unsigned char const *)data_ptr;
905
+
906
+ for (nk_size_t vector_length; count > 0; count -= vector_length, ptr += vector_length * stride_bytes) {
907
+ vector_length = __riscv_vsetvl_e16m1(count);
908
+ vint16m1_t data_i16m1 = __riscv_vlse16_v_i16m1((nk_i16_t const *)ptr, (nk_ssize_t)stride_bytes, vector_length);
909
+
910
+ // Widen i16 → i32 → i64 for sum
911
+ vint32m2_t data_i32m2 = __riscv_vsext_vf2_i32m2(data_i16m1, vector_length);
912
+ vint64m4_t data_i64m4 = __riscv_vsext_vf2_i64m4(data_i32m2, vector_length);
913
+ sum_i64m4 = __riscv_vadd_vv_i64m4_tu(sum_i64m4, sum_i64m4, data_i64m4, vector_length);
914
+
915
+ // Sumsq: i16 × i16 → i32 (widening multiply)
916
+ vint32m2_t squares_i32m2 = __riscv_vwmul_vv_i32m2(data_i16m1, data_i16m1, vector_length);
917
+ // Squares are non-negative, so reinterpret i32 → u32 and zero-extend u32 → u64
918
+ vuint64m4_t squares_u64m4 = __riscv_vwcvtu_x_x_v_u64m4(__riscv_vreinterpret_v_i32m2_u32m2(squares_i32m2),
919
+ vector_length);
920
+ sumsq_u64m4 = __riscv_vadd_vv_u64m4_tu(sumsq_u64m4, sumsq_u64m4, squares_u64m4, vector_length);
921
+ }
922
+
923
+ // Horizontal reduction
924
+ vint64m1_t zero_i64m1 = __riscv_vmv_v_x_i64m1(0, 1);
925
+ *sum_ptr = __riscv_vmv_x_s_i64m1_i64(__riscv_vredsum_vs_i64m4_i64m1(sum_i64m4, zero_i64m1, vlmax));
926
+
927
+ vuint64m1_t zero_u64m1 = __riscv_vmv_v_x_u64m1(0, 1);
928
+ *sumsq_ptr = __riscv_vmv_x_s_u64m1_u64(__riscv_vredsum_vs_u64m4_u64m1(sumsq_u64m4, zero_u64m1, vlmax));
929
+ }
930
+
931
+ NK_PUBLIC void nk_reduce_moments_i16_rvv( //
932
+ nk_i16_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
933
+ nk_i64_t *sum_ptr, nk_u64_t *sumsq_ptr) {
934
+ nk_size_t stride_elements = stride_bytes / sizeof(nk_i16_t);
935
+ int aligned = (stride_bytes % sizeof(nk_i16_t) == 0);
936
+
937
+ if (count == 0) *sum_ptr = 0, *sumsq_ptr = 0;
938
+ else if (!aligned) { nk_reduce_moments_i16_serial(data_ptr, count, stride_bytes, sum_ptr, sumsq_ptr); }
939
+ else if (stride_elements == 1) { nk_reduce_moments_i16_rvv_contiguous_(data_ptr, count, sum_ptr, sumsq_ptr); }
940
+ else { nk_reduce_moments_i16_rvv_strided_(data_ptr, count, stride_bytes, sum_ptr, sumsq_ptr); }
941
+ }
942
+
943
+ NK_INTERNAL void nk_reduce_minmax_i16_rvv_contiguous_( //
944
+ nk_i16_t const *data_ptr, nk_size_t count, //
945
+ nk_i16_t *min_value_ptr, nk_size_t *min_index_ptr, //
946
+ nk_i16_t *max_value_ptr, nk_size_t *max_index_ptr) {
947
+ nk_size_t vlmax = __riscv_vsetvlmax_e16m1();
948
+ vint16m1_t min_i16m1 = __riscv_vmv_v_x_i16m1(NK_I16_MAX, vlmax);
949
+ vint16m1_t max_i16m1 = __riscv_vmv_v_x_i16m1(NK_I16_MIN, vlmax);
950
+ vuint64m4_t min_indices_u64m4 = __riscv_vmv_v_x_u64m4(0, vlmax);
951
+ vuint64m4_t max_indices_u64m4 = __riscv_vmv_v_x_u64m4(0, vlmax);
952
+
953
+ nk_size_t offset = 0;
954
+ for (nk_size_t vector_length; count > 0;
955
+ count -= vector_length, offset += vector_length, data_ptr += vector_length) {
956
+ vector_length = __riscv_vsetvl_e16m1(count);
957
+ vint16m1_t data_i16m1 = __riscv_vle16_v_i16m1(data_ptr, vector_length);
958
+ vuint64m4_t pos_u64m4 = __riscv_vadd_vx_u64m4(__riscv_vid_v_u64m4(vector_length), (nk_u64_t)offset,
959
+ vector_length);
960
+
961
+ vbool16_t less_b16 = __riscv_vmslt_vv_i16m1_b16(data_i16m1, min_i16m1, vector_length);
962
+ min_i16m1 = __riscv_vmerge_vvm_i16m1_tu(min_i16m1, min_i16m1, data_i16m1, less_b16, vector_length);
963
+ min_indices_u64m4 = __riscv_vmerge_vvm_u64m4_tu(min_indices_u64m4, min_indices_u64m4, pos_u64m4, less_b16,
964
+ vector_length);
965
+
966
+ vbool16_t greater_b16 = __riscv_vmslt_vv_i16m1_b16(max_i16m1, data_i16m1, vector_length);
967
+ max_i16m1 = __riscv_vmerge_vvm_i16m1_tu(max_i16m1, max_i16m1, data_i16m1, greater_b16, vector_length);
968
+ max_indices_u64m4 = __riscv_vmerge_vvm_u64m4_tu(max_indices_u64m4, max_indices_u64m4, pos_u64m4, greater_b16,
969
+ vector_length);
970
+ }
971
+
972
+ // Horizontal reduction for min
973
+ vint16m1_t init_max_i16m1 = __riscv_vmv_v_x_i16m1(NK_I16_MAX, 1);
974
+ nk_i16_t min_val = __riscv_vmv_x_s_i16m1_i16(__riscv_vredmin_vs_i16m1_i16m1(min_i16m1, init_max_i16m1, vlmax));
975
+ vbool16_t min_match_b16 = __riscv_vmseq_vx_i16m1_b16(min_i16m1, min_val, vlmax);
976
+ vuint64m4_t sentinel_u64m4 = __riscv_vmv_v_x_u64m4(NK_U64_MAX, vlmax);
977
+ vuint64m4_t min_cands_u64m4 = __riscv_vmerge_vvm_u64m4(sentinel_u64m4, min_indices_u64m4, min_match_b16, vlmax);
978
+ vuint64m1_t init_umax_u64m1 = __riscv_vmv_v_x_u64m1(NK_U64_MAX, 1);
979
+ *min_value_ptr = min_val;
980
+ *min_index_ptr = (nk_size_t)__riscv_vmv_x_s_u64m1_u64(
981
+ __riscv_vredminu_vs_u64m4_u64m1(min_cands_u64m4, init_umax_u64m1, vlmax));
982
+
983
+ // Horizontal reduction for max
984
+ vint16m1_t init_min_i16m1 = __riscv_vmv_v_x_i16m1(NK_I16_MIN, 1);
985
+ nk_i16_t max_val = __riscv_vmv_x_s_i16m1_i16(__riscv_vredmax_vs_i16m1_i16m1(max_i16m1, init_min_i16m1, vlmax));
986
+ vbool16_t max_match_b16 = __riscv_vmseq_vx_i16m1_b16(max_i16m1, max_val, vlmax);
987
+ vuint64m4_t max_cands_u64m4 = __riscv_vmerge_vvm_u64m4(sentinel_u64m4, max_indices_u64m4, max_match_b16, vlmax);
988
+ *max_value_ptr = max_val;
989
+ *max_index_ptr = (nk_size_t)__riscv_vmv_x_s_u64m1_u64(
990
+ __riscv_vredminu_vs_u64m4_u64m1(max_cands_u64m4, init_umax_u64m1, vlmax));
991
+ }
992
+
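Across all the minmax kernels the index accumulators are 64-bit, so their LMUL scales with the element width to keep the lane counts of the data group and the index group equal. The pairing used in this file, for reference:

    //  data group        u64 index group   lanes per group (relative to VLEN)
    //  e8,  LMUL = m1    u64, LMUL = m8    VLEN / 8
    //  e16, LMUL = m1    u64, LMUL = m4    VLEN / 16
    //  e32, LMUL = m1    u64, LMUL = m2    VLEN / 32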
993
+ NK_INTERNAL void nk_reduce_minmax_i16_rvv_strided_( //
994
+ nk_i16_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
995
+ nk_i16_t *min_value_ptr, nk_size_t *min_index_ptr, //
996
+ nk_i16_t *max_value_ptr, nk_size_t *max_index_ptr) {
997
+ nk_size_t vlmax = __riscv_vsetvlmax_e16m1();
998
+ vint16m1_t min_i16m1 = __riscv_vmv_v_x_i16m1(NK_I16_MAX, vlmax);
999
+ vint16m1_t max_i16m1 = __riscv_vmv_v_x_i16m1(NK_I16_MIN, vlmax);
1000
+ vuint64m4_t min_indices_u64m4 = __riscv_vmv_v_x_u64m4(0, vlmax);
1001
+ vuint64m4_t max_indices_u64m4 = __riscv_vmv_v_x_u64m4(0, vlmax);
1002
+ unsigned char const *ptr = (unsigned char const *)data_ptr;
1003
+
1004
+ nk_size_t offset = 0;
1005
+ for (nk_size_t vector_length; count > 0;
1006
+ count -= vector_length, offset += vector_length, ptr += vector_length * stride_bytes) {
1007
+ vector_length = __riscv_vsetvl_e16m1(count);
1008
+ vint16m1_t data_i16m1 = __riscv_vlse16_v_i16m1((nk_i16_t const *)ptr, (nk_ssize_t)stride_bytes, vector_length);
1009
+ vuint64m4_t pos_u64m4 = __riscv_vadd_vx_u64m4(__riscv_vid_v_u64m4(vector_length), (nk_u64_t)offset,
1010
+ vector_length);
1011
+
1012
+ vbool16_t less_b16 = __riscv_vmslt_vv_i16m1_b16(data_i16m1, min_i16m1, vector_length);
1013
+ min_i16m1 = __riscv_vmerge_vvm_i16m1_tu(min_i16m1, min_i16m1, data_i16m1, less_b16, vector_length);
1014
+ min_indices_u64m4 = __riscv_vmerge_vvm_u64m4_tu(min_indices_u64m4, min_indices_u64m4, pos_u64m4, less_b16,
1015
+ vector_length);
1016
+
1017
+ vbool16_t greater_b16 = __riscv_vmslt_vv_i16m1_b16(max_i16m1, data_i16m1, vector_length);
1018
+ max_i16m1 = __riscv_vmerge_vvm_i16m1_tu(max_i16m1, max_i16m1, data_i16m1, greater_b16, vector_length);
1019
+ max_indices_u64m4 = __riscv_vmerge_vvm_u64m4_tu(max_indices_u64m4, max_indices_u64m4, pos_u64m4, greater_b16,
1020
+ vector_length);
1021
+ }
1022
+
1023
+ // Horizontal reduction for min
1024
+ vint16m1_t init_max_i16m1 = __riscv_vmv_v_x_i16m1(NK_I16_MAX, 1);
1025
+ nk_i16_t min_val = __riscv_vmv_x_s_i16m1_i16(__riscv_vredmin_vs_i16m1_i16m1(min_i16m1, init_max_i16m1, vlmax));
1026
+ vbool16_t min_match_b16 = __riscv_vmseq_vx_i16m1_b16(min_i16m1, min_val, vlmax);
1027
+ vuint64m4_t sentinel_u64m4 = __riscv_vmv_v_x_u64m4(NK_U64_MAX, vlmax);
1028
+ vuint64m4_t min_cands_u64m4 = __riscv_vmerge_vvm_u64m4(sentinel_u64m4, min_indices_u64m4, min_match_b16, vlmax);
1029
+ vuint64m1_t init_umax_u64m1 = __riscv_vmv_v_x_u64m1(NK_U64_MAX, 1);
1030
+ *min_value_ptr = min_val;
1031
+ *min_index_ptr = (nk_size_t)__riscv_vmv_x_s_u64m1_u64(
1032
+ __riscv_vredminu_vs_u64m4_u64m1(min_cands_u64m4, init_umax_u64m1, vlmax));
1033
+
1034
+ // Horizontal reduction for max
1035
+ vint16m1_t init_min_i16m1 = __riscv_vmv_v_x_i16m1(NK_I16_MIN, 1);
1036
+ nk_i16_t max_val = __riscv_vmv_x_s_i16m1_i16(__riscv_vredmax_vs_i16m1_i16m1(max_i16m1, init_min_i16m1, vlmax));
1037
+ vbool16_t max_match_b16 = __riscv_vmseq_vx_i16m1_b16(max_i16m1, max_val, vlmax);
1038
+ vuint64m4_t max_cands_u64m4 = __riscv_vmerge_vvm_u64m4(sentinel_u64m4, max_indices_u64m4, max_match_b16, vlmax);
1039
+ *max_value_ptr = max_val;
1040
+ *max_index_ptr = (nk_size_t)__riscv_vmv_x_s_u64m1_u64(
1041
+ __riscv_vredminu_vs_u64m4_u64m1(max_cands_u64m4, init_umax_u64m1, vlmax));
1042
+ }
1043
+
1044
+ NK_PUBLIC void nk_reduce_minmax_i16_rvv( //
1045
+ nk_i16_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
1046
+ nk_i16_t *min_value_ptr, nk_size_t *min_index_ptr, //
1047
+ nk_i16_t *max_value_ptr, nk_size_t *max_index_ptr) {
1048
+ nk_size_t stride_elements = stride_bytes / sizeof(nk_i16_t);
1049
+ int aligned = (stride_bytes % sizeof(nk_i16_t) == 0);
1050
+
1051
+ if (count == 0)
1052
+ *min_value_ptr = NK_I16_MAX, *min_index_ptr = NK_SIZE_MAX, *max_value_ptr = NK_I16_MIN,
1053
+ *max_index_ptr = NK_SIZE_MAX;
1054
+ else if (!aligned)
1055
+ nk_reduce_minmax_i16_serial(data_ptr, count, stride_bytes, min_value_ptr, min_index_ptr, max_value_ptr,
1056
+ max_index_ptr);
1057
+ else if (stride_elements == 1)
1058
+ nk_reduce_minmax_i16_rvv_contiguous_(data_ptr, count, min_value_ptr, min_index_ptr, max_value_ptr,
1059
+ max_index_ptr);
1060
+ else
1061
+ nk_reduce_minmax_i16_rvv_strided_(data_ptr, count, stride_bytes, min_value_ptr, min_index_ptr, max_value_ptr,
1062
+ max_index_ptr);
1063
+ }
1064
+
1065
+ NK_INTERNAL void nk_reduce_moments_u16_rvv_contiguous_( //
1066
+ nk_u16_t const *data_ptr, nk_size_t count, //
1067
+ nk_u64_t *sum_ptr, nk_u64_t *sumsq_ptr) {
1068
+ nk_size_t vlmax = __riscv_vsetvlmax_e64m4();
1069
+ vuint64m4_t sum_u64m4 = __riscv_vmv_v_x_u64m4(0, vlmax);
1070
+ vuint64m4_t sumsq_u64m4 = __riscv_vmv_v_x_u64m4(0, vlmax);
1071
+
1072
+ for (nk_size_t vector_length; count > 0; count -= vector_length, data_ptr += vector_length) {
1073
+ vector_length = __riscv_vsetvl_e16m1(count);
1074
+ vuint16m1_t data_u16m1 = __riscv_vle16_v_u16m1(data_ptr, vector_length);
1075
+
1076
+ // Widen u16 → u32 → u64 for sum
1077
+ vuint32m2_t data_u32m2 = __riscv_vzext_vf2_u32m2(data_u16m1, vector_length);
1078
+ vuint64m4_t data_u64m4 = __riscv_vzext_vf2_u64m4(data_u32m2, vector_length);
1079
+ sum_u64m4 = __riscv_vadd_vv_u64m4_tu(sum_u64m4, sum_u64m4, data_u64m4, vector_length);
1080
+
1081
+ // Sumsq: u16 × u16 → u32 (widening multiply)
1082
+ vuint32m2_t squares_u32m2 = __riscv_vwmulu_vv_u32m2(data_u16m1, data_u16m1, vector_length);
1083
+ // Widen u32 → u64
1084
+ vuint64m4_t squares_u64m4 = __riscv_vzext_vf2_u64m4(squares_u32m2, vector_length);
1085
+ sumsq_u64m4 = __riscv_vadd_vv_u64m4_tu(sumsq_u64m4, sumsq_u64m4, squares_u64m4, vector_length);
1086
+ }
1087
+
1088
+ // Horizontal reduction
1089
+ vuint64m1_t zero_u64m1 = __riscv_vmv_v_x_u64m1(0, 1);
1090
+ *sum_ptr = __riscv_vmv_x_s_u64m1_u64(__riscv_vredsum_vs_u64m4_u64m1(sum_u64m4, zero_u64m1, vlmax)),
1091
+ *sumsq_ptr = __riscv_vmv_x_s_u64m1_u64(__riscv_vredsum_vs_u64m4_u64m1(sumsq_u64m4, zero_u64m1, vlmax));
1092
+ }
1093
+
1094
+ NK_INTERNAL void nk_reduce_moments_u16_rvv_strided_( //
1095
+ nk_u16_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
1096
+ nk_u64_t *sum_ptr, nk_u64_t *sumsq_ptr) {
1097
+ nk_size_t vlmax = __riscv_vsetvlmax_e64m4();
1098
+ vuint64m4_t sum_u64m4 = __riscv_vmv_v_x_u64m4(0, vlmax);
1099
+ vuint64m4_t sumsq_u64m4 = __riscv_vmv_v_x_u64m4(0, vlmax);
1100
+ unsigned char const *ptr = (unsigned char const *)data_ptr;
1101
+
1102
+ for (nk_size_t vector_length; count > 0; count -= vector_length, ptr += vector_length * stride_bytes) {
1103
+ vector_length = __riscv_vsetvl_e16m1(count);
1104
+ vuint16m1_t data_u16m1 = __riscv_vlse16_v_u16m1((nk_u16_t const *)ptr, (nk_ssize_t)stride_bytes, vector_length);
1105
+
1106
+ // Widen u16 → u32 → u64 for sum
1107
+ vuint32m2_t data_u32m2 = __riscv_vzext_vf2_u32m2(data_u16m1, vector_length);
1108
+ vuint64m4_t data_u64m4 = __riscv_vzext_vf2_u64m4(data_u32m2, vector_length);
1109
+ sum_u64m4 = __riscv_vadd_vv_u64m4_tu(sum_u64m4, sum_u64m4, data_u64m4, vector_length);
1110
+
1111
+ // Sumsq: u16 × u16 → u32 (widening multiply)
1112
+ vuint32m2_t squares_u32m2 = __riscv_vwmulu_vv_u32m2(data_u16m1, data_u16m1, vector_length);
1113
+ // Widen u32 → u64
1114
+ vuint64m4_t squares_u64m4 = __riscv_vzext_vf2_u64m4(squares_u32m2, vector_length);
1115
+ sumsq_u64m4 = __riscv_vadd_vv_u64m4_tu(sumsq_u64m4, sumsq_u64m4, squares_u64m4, vector_length);
1116
+ }
1117
+
1118
+ // Horizontal reduction
1119
+ vuint64m1_t zero_u64m1 = __riscv_vmv_v_x_u64m1(0, 1);
1120
+ *sum_ptr = __riscv_vmv_x_s_u64m1_u64(__riscv_vredsum_vs_u64m4_u64m1(sum_u64m4, zero_u64m1, vlmax)),
1121
+ *sumsq_ptr = __riscv_vmv_x_s_u64m1_u64(__riscv_vredsum_vs_u64m4_u64m1(sumsq_u64m4, zero_u64m1, vlmax));
1122
+ }
1123
+
1124
+ NK_PUBLIC void nk_reduce_moments_u16_rvv( //
1125
+ nk_u16_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
1126
+ nk_u64_t *sum_ptr, nk_u64_t *sumsq_ptr) {
1127
+ nk_size_t stride_elements = stride_bytes / sizeof(nk_u16_t);
1128
+ int aligned = (stride_bytes % sizeof(nk_u16_t) == 0);
1129
+
1130
+ if (count == 0) *sum_ptr = 0, *sumsq_ptr = 0;
1131
+ else if (!aligned) { nk_reduce_moments_u16_serial(data_ptr, count, stride_bytes, sum_ptr, sumsq_ptr); }
1132
+ else if (stride_elements == 1) { nk_reduce_moments_u16_rvv_contiguous_(data_ptr, count, sum_ptr, sumsq_ptr); }
1133
+ else { nk_reduce_moments_u16_rvv_strided_(data_ptr, count, stride_bytes, sum_ptr, sumsq_ptr); }
1134
+ }
1135
+
1136
+ NK_INTERNAL void nk_reduce_minmax_u16_rvv_contiguous_( //
1137
+ nk_u16_t const *data_ptr, nk_size_t count, //
1138
+ nk_u16_t *min_value_ptr, nk_size_t *min_index_ptr, //
1139
+ nk_u16_t *max_value_ptr, nk_size_t *max_index_ptr) {
1140
+ nk_size_t vlmax = __riscv_vsetvlmax_e16m1();
1141
+ vuint16m1_t min_u16m1 = __riscv_vmv_v_x_u16m1(NK_U16_MAX, vlmax);
1142
+ vuint16m1_t max_u16m1 = __riscv_vmv_v_x_u16m1(NK_U16_MIN, vlmax);
1143
+ vuint64m4_t min_indices_u64m4 = __riscv_vmv_v_x_u64m4(0, vlmax);
1144
+ vuint64m4_t max_indices_u64m4 = __riscv_vmv_v_x_u64m4(0, vlmax);
1145
+
1146
+ nk_size_t offset = 0;
1147
+ for (nk_size_t vector_length; count > 0;
1148
+ count -= vector_length, offset += vector_length, data_ptr += vector_length) {
1149
+ vector_length = __riscv_vsetvl_e16m1(count);
1150
+ vuint16m1_t data_u16m1 = __riscv_vle16_v_u16m1(data_ptr, vector_length);
1151
+ vuint64m4_t pos_u64m4 = __riscv_vadd_vx_u64m4(__riscv_vid_v_u64m4(vector_length), (nk_u64_t)offset,
1152
+ vector_length);
1153
+
1154
+ vbool16_t less_b16 = __riscv_vmsltu_vv_u16m1_b16(data_u16m1, min_u16m1, vector_length);
1155
+ min_u16m1 = __riscv_vmerge_vvm_u16m1_tu(min_u16m1, min_u16m1, data_u16m1, less_b16, vector_length);
1156
+ min_indices_u64m4 = __riscv_vmerge_vvm_u64m4_tu(min_indices_u64m4, min_indices_u64m4, pos_u64m4, less_b16,
1157
+ vector_length);
1158
+
1159
+ vbool16_t greater_b16 = __riscv_vmsltu_vv_u16m1_b16(max_u16m1, data_u16m1, vector_length);
1160
+ max_u16m1 = __riscv_vmerge_vvm_u16m1_tu(max_u16m1, max_u16m1, data_u16m1, greater_b16, vector_length);
1161
+ max_indices_u64m4 = __riscv_vmerge_vvm_u64m4_tu(max_indices_u64m4, max_indices_u64m4, pos_u64m4, greater_b16,
1162
+ vector_length);
1163
+ }
1164
+
1165
+ // Horizontal reduction for min
1166
+ vuint16m1_t init_max_u16m1 = __riscv_vmv_v_x_u16m1(NK_U16_MAX, 1);
1167
+ nk_u16_t min_val = __riscv_vmv_x_s_u16m1_u16(__riscv_vredminu_vs_u16m1_u16m1(min_u16m1, init_max_u16m1, vlmax));
1168
+ vbool16_t min_match_b16 = __riscv_vmseq_vx_u16m1_b16(min_u16m1, min_val, vlmax);
1169
+ vuint64m4_t sentinel_u64m4 = __riscv_vmv_v_x_u64m4(NK_U64_MAX, vlmax);
1170
+ vuint64m4_t min_cands_u64m4 = __riscv_vmerge_vvm_u64m4(sentinel_u64m4, min_indices_u64m4, min_match_b16, vlmax);
1171
+ vuint64m1_t init_umax_u64m1 = __riscv_vmv_v_x_u64m1(NK_U64_MAX, 1);
1172
+ *min_value_ptr = min_val;
1173
+ *min_index_ptr = (nk_size_t)__riscv_vmv_x_s_u64m1_u64(
1174
+ __riscv_vredminu_vs_u64m4_u64m1(min_cands_u64m4, init_umax_u64m1, vlmax));
1175
+
1176
+ // Horizontal reduction for max
1177
+ vuint16m1_t init_min_u16m1 = __riscv_vmv_v_x_u16m1(NK_U16_MIN, 1);
1178
+ nk_u16_t max_val = __riscv_vmv_x_s_u16m1_u16(__riscv_vredmaxu_vs_u16m1_u16m1(max_u16m1, init_min_u16m1, vlmax));
1179
+ vbool16_t max_match_b16 = __riscv_vmseq_vx_u16m1_b16(max_u16m1, max_val, vlmax);
1180
+ vuint64m4_t max_cands_u64m4 = __riscv_vmerge_vvm_u64m4(sentinel_u64m4, max_indices_u64m4, max_match_b16, vlmax);
1181
+ *max_value_ptr = max_val;
1182
+ *max_index_ptr = (nk_size_t)__riscv_vmv_x_s_u64m1_u64(
1183
+ __riscv_vredminu_vs_u64m4_u64m1(max_cands_u64m4, init_umax_u64m1, vlmax));
1184
+ }
1185
+
1186
+ NK_INTERNAL void nk_reduce_minmax_u16_rvv_strided_( //
1187
+ nk_u16_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
1188
+ nk_u16_t *min_value_ptr, nk_size_t *min_index_ptr, //
1189
+ nk_u16_t *max_value_ptr, nk_size_t *max_index_ptr) {
1190
+ nk_size_t vlmax = __riscv_vsetvlmax_e16m1();
1191
+ vuint16m1_t min_u16m1 = __riscv_vmv_v_x_u16m1(NK_U16_MAX, vlmax);
1192
+ vuint16m1_t max_u16m1 = __riscv_vmv_v_x_u16m1(NK_U16_MIN, vlmax);
1193
+ vuint64m4_t min_indices_u64m4 = __riscv_vmv_v_x_u64m4(0, vlmax);
1194
+ vuint64m4_t max_indices_u64m4 = __riscv_vmv_v_x_u64m4(0, vlmax);
1195
+ unsigned char const *ptr = (unsigned char const *)data_ptr;
1196
+
1197
+ nk_size_t offset = 0;
1198
+ for (nk_size_t vector_length; count > 0;
1199
+ count -= vector_length, offset += vector_length, ptr += vector_length * stride_bytes) {
1200
+ vector_length = __riscv_vsetvl_e16m1(count);
1201
+ vuint16m1_t data_u16m1 = __riscv_vlse16_v_u16m1((nk_u16_t const *)ptr, (nk_ssize_t)stride_bytes, vector_length);
1202
+ vuint64m4_t pos_u64m4 = __riscv_vadd_vx_u64m4(__riscv_vid_v_u64m4(vector_length), (nk_u64_t)offset,
1203
+ vector_length);
1204
+
1205
+ vbool16_t less_b16 = __riscv_vmsltu_vv_u16m1_b16(data_u16m1, min_u16m1, vector_length);
1206
+ min_u16m1 = __riscv_vmerge_vvm_u16m1_tu(min_u16m1, min_u16m1, data_u16m1, less_b16, vector_length);
1207
+ min_indices_u64m4 = __riscv_vmerge_vvm_u64m4_tu(min_indices_u64m4, min_indices_u64m4, pos_u64m4, less_b16,
1208
+ vector_length);
1209
+
1210
+ vbool16_t greater_b16 = __riscv_vmsltu_vv_u16m1_b16(max_u16m1, data_u16m1, vector_length);
1211
+ max_u16m1 = __riscv_vmerge_vvm_u16m1_tu(max_u16m1, max_u16m1, data_u16m1, greater_b16, vector_length);
1212
+ max_indices_u64m4 = __riscv_vmerge_vvm_u64m4_tu(max_indices_u64m4, max_indices_u64m4, pos_u64m4, greater_b16,
1213
+ vector_length);
1214
+ }
1215
+
1216
+ // Horizontal reduction for min
1217
+ vuint16m1_t init_max_u16m1 = __riscv_vmv_v_x_u16m1(NK_U16_MAX, 1);
1218
+ nk_u16_t min_val = __riscv_vmv_x_s_u16m1_u16(__riscv_vredminu_vs_u16m1_u16m1(min_u16m1, init_max_u16m1, vlmax));
1219
+ vbool16_t min_match_b16 = __riscv_vmseq_vx_u16m1_b16(min_u16m1, min_val, vlmax);
1220
+ vuint64m4_t sentinel_u64m4 = __riscv_vmv_v_x_u64m4(NK_U64_MAX, vlmax);
1221
+ vuint64m4_t min_cands_u64m4 = __riscv_vmerge_vvm_u64m4(sentinel_u64m4, min_indices_u64m4, min_match_b16, vlmax);
1222
+ vuint64m1_t init_umax_u64m1 = __riscv_vmv_v_x_u64m1(NK_U64_MAX, 1);
1223
+ *min_value_ptr = min_val;
1224
+ *min_index_ptr = (nk_size_t)__riscv_vmv_x_s_u64m1_u64(
1225
+ __riscv_vredminu_vs_u64m4_u64m1(min_cands_u64m4, init_umax_u64m1, vlmax));
1226
+
1227
+ // Horizontal reduction for max
1228
+ vuint16m1_t init_min_u16m1 = __riscv_vmv_v_x_u16m1(NK_U16_MIN, 1);
1229
+ nk_u16_t max_val = __riscv_vmv_x_s_u16m1_u16(__riscv_vredmaxu_vs_u16m1_u16m1(max_u16m1, init_min_u16m1, vlmax));
1230
+ vbool16_t max_match_b16 = __riscv_vmseq_vx_u16m1_b16(max_u16m1, max_val, vlmax);
1231
+ vuint64m4_t max_cands_u64m4 = __riscv_vmerge_vvm_u64m4(sentinel_u64m4, max_indices_u64m4, max_match_b16, vlmax);
1232
+ *max_value_ptr = max_val;
1233
+ *max_index_ptr = (nk_size_t)__riscv_vmv_x_s_u64m1_u64(
1234
+ __riscv_vredminu_vs_u64m4_u64m1(max_cands_u64m4, init_umax_u64m1, vlmax));
1235
+ }
1236
+
1237
+ NK_PUBLIC void nk_reduce_minmax_u16_rvv( //
1238
+ nk_u16_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
1239
+ nk_u16_t *min_value_ptr, nk_size_t *min_index_ptr, //
1240
+ nk_u16_t *max_value_ptr, nk_size_t *max_index_ptr) {
1241
+ nk_size_t stride_elements = stride_bytes / sizeof(nk_u16_t);
1242
+ int aligned = (stride_bytes % sizeof(nk_u16_t) == 0);
1243
+
1244
+ if (count == 0)
1245
+ *min_value_ptr = NK_U16_MAX, *min_index_ptr = NK_SIZE_MAX, *max_value_ptr = NK_U16_MIN,
1246
+ *max_index_ptr = NK_SIZE_MAX;
1247
+ else if (!aligned)
1248
+ nk_reduce_minmax_u16_serial(data_ptr, count, stride_bytes, min_value_ptr, min_index_ptr, max_value_ptr,
1249
+ max_index_ptr);
1250
+ else if (stride_elements == 1)
1251
+ nk_reduce_minmax_u16_rvv_contiguous_(data_ptr, count, min_value_ptr, min_index_ptr, max_value_ptr,
1252
+ max_index_ptr);
1253
+ else
1254
+ nk_reduce_minmax_u16_rvv_strided_(data_ptr, count, stride_bytes, min_value_ptr, min_index_ptr, max_value_ptr,
1255
+ max_index_ptr);
1256
+ }
1257
+
1258
+ NK_INTERNAL void nk_reduce_moments_i32_rvv_contiguous_( //
1259
+ nk_i32_t const *data_ptr, nk_size_t count, //
1260
+ nk_i64_t *sum_ptr, nk_u64_t *sumsq_ptr) {
1261
+ nk_size_t vlmax = __riscv_vsetvlmax_e64m2();
1262
+ // 128-bit per-lane accumulator for sum: (sum_upper, sum_lower)
1263
+ vuint64m2_t sum_lower_u64m2 = __riscv_vmv_v_x_u64m2(0, vlmax);
1264
+ vint64m2_t sum_upper_i64m2 = __riscv_vmv_v_x_i64m2(0, vlmax);
1265
+ vuint64m2_t sumsq_u64m2 = __riscv_vmv_v_x_u64m2(0, vlmax);
1266
+
1267
+ for (nk_size_t vector_length; count > 0; count -= vector_length, data_ptr += vector_length) {
1268
+ vector_length = __riscv_vsetvl_e32m1(count);
1269
+ vint32m1_t data_i32m1 = __riscv_vle32_v_i32m1(data_ptr, vector_length);
1270
+
1271
+ // Widen i32 → i64
1272
+ vint64m2_t data_i64m2 = __riscv_vsext_vf2_i64m2(data_i32m1, vector_length);
1273
+ vuint64m2_t data_u64m2 = __riscv_vreinterpret_v_i64m2_u64m2(data_i64m2);
1274
+
1275
+ // 128-bit accumulation: wrapping add on lower half
1276
+ vuint64m2_t sum_before_u64m2 = sum_lower_u64m2;
1277
+ sum_lower_u64m2 = __riscv_vadd_vv_u64m2_tu(sum_lower_u64m2, sum_lower_u64m2, data_u64m2, vector_length);
1278
+
1279
+ // Carry: new < old means unsigned overflow occurred
1280
+ vbool32_t carry_b32 = __riscv_vmsltu_vv_u64m2_b32(sum_lower_u64m2, sum_before_u64m2, vector_length);
1281
+ vint64m2_t carry_i64m2 = __riscv_vmerge_vxm_i64m2(__riscv_vmv_v_x_i64m2(0, vector_length), 1, carry_b32,
1282
+ vector_length);
1283
+ sum_upper_i64m2 = __riscv_vadd_vv_i64m2_tu(sum_upper_i64m2, sum_upper_i64m2, carry_i64m2, vector_length);
1284
+
1285
+ // Sign extension: -1 for negative, 0 for non-negative
1286
+ vint64m2_t sign_ext_i64m2 = __riscv_vsra_vx_i64m2(data_i64m2, 63, vector_length);
1287
+ sum_upper_i64m2 = __riscv_vadd_vv_i64m2_tu(sum_upper_i64m2, sum_upper_i64m2, sign_ext_i64m2, vector_length);
1288
+
1289
+ // Sumsq: i32 × i32 → i64 (widening multiply, result ≤ 2^62), saturating accumulation
1290
+ vint64m2_t squares_i64m2 = __riscv_vwmul_vv_i64m2(data_i32m1, data_i32m1, vector_length);
1291
+ sumsq_u64m2 = __riscv_vsaddu_vv_u64m2_tu(sumsq_u64m2, sumsq_u64m2,
1292
+ __riscv_vreinterpret_v_i64m2_u64m2(squares_i64m2), vector_length);
1293
+ }
1294
+
1295
+ *sum_ptr = nk_reduce_128bit_sum_i64m2_rvv_(sum_lower_u64m2, sum_upper_i64m2, vlmax);
1296
+ *sumsq_ptr = nk_reduce_vsaddu_u64m2_rvv_(sumsq_u64m2, vlmax);
1297
+ }
1298
+
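The i32 kernel above treats its two accumulators differently. Squares can reach 2^62, so a handful of worst-case additions would already wrap a u64 lane, and the kernel settles for saturating accumulation there. The sum, by contrast, is kept exactly in a per-lane 128-bit accumulator split into an unsigned low half and a signed high half: the low half uses a wrapping add, a carry is detected when the new low half compares below the old one, and each addend's sign is folded into the high half so negative values extend correctly. A scalar sketch of the same bookkeeping (the `nk_reduce_128bit_sum_i64m2_rvv_` helper referenced above then folds the per-lane pairs into the final 64-bit result; its definition is outside this excerpt):

    nk_u64_t lower = 0;          // low 64 bits of the running sum
    nk_i64_t upper = 0;          // high 64 bits (signed)
    for (nk_size_t i = 0; i != count; ++i) {
        nk_i64_t x = (nk_i64_t)data_ptr[i];
        nk_u64_t before = lower;
        lower += (nk_u64_t)x;                    // wrapping add of the low half
        upper += (nk_i64_t)(lower < before);     // carry out of the low 64 bits
        upper += (x < 0) ? -1 : 0;               // sign-extend the addend into the high half
    }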
1299
+ NK_INTERNAL void nk_reduce_moments_i32_rvv_strided_( //
1300
+ nk_i32_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
1301
+ nk_i64_t *sum_ptr, nk_u64_t *sumsq_ptr) {
1302
+ nk_size_t vlmax = __riscv_vsetvlmax_e64m2();
1303
+ // 128-bit per-lane accumulator for sum: (sum_upper, sum_lower)
1304
+ vuint64m2_t sum_lower_u64m2 = __riscv_vmv_v_x_u64m2(0, vlmax);
1305
+ vint64m2_t sum_upper_i64m2 = __riscv_vmv_v_x_i64m2(0, vlmax);
1306
+ vuint64m2_t sumsq_u64m2 = __riscv_vmv_v_x_u64m2(0, vlmax);
1307
+ unsigned char const *ptr = (unsigned char const *)data_ptr;
1308
+
1309
+ for (nk_size_t vector_length; count > 0; count -= vector_length, ptr += vector_length * stride_bytes) {
1310
+ vector_length = __riscv_vsetvl_e32m1(count);
1311
+ vint32m1_t data_i32m1 = __riscv_vlse32_v_i32m1((nk_i32_t const *)ptr, (nk_ssize_t)stride_bytes, vector_length);
1312
+
1313
+ // Widen i32 → i64
1314
+ vint64m2_t data_i64m2 = __riscv_vsext_vf2_i64m2(data_i32m1, vector_length);
1315
+ vuint64m2_t data_u64m2 = __riscv_vreinterpret_v_i64m2_u64m2(data_i64m2);
1316
+
1317
+ // 128-bit accumulation: wrapping add on lower half
1318
+ vuint64m2_t sum_before_u64m2 = sum_lower_u64m2;
1319
+ sum_lower_u64m2 = __riscv_vadd_vv_u64m2_tu(sum_lower_u64m2, sum_lower_u64m2, data_u64m2, vector_length);
1320
+
1321
+ // Carry: new < old means unsigned overflow occurred
1322
+ vbool32_t carry_b32 = __riscv_vmsltu_vv_u64m2_b32(sum_lower_u64m2, sum_before_u64m2, vector_length);
1323
+ vint64m2_t carry_i64m2 = __riscv_vmerge_vxm_i64m2(__riscv_vmv_v_x_i64m2(0, vector_length), 1, carry_b32,
1324
+ vector_length);
1325
+ sum_upper_i64m2 = __riscv_vadd_vv_i64m2_tu(sum_upper_i64m2, sum_upper_i64m2, carry_i64m2, vector_length);
1326
+
1327
+ // Sign extension: -1 for negative, 0 for non-negative
1328
+ vint64m2_t sign_ext_i64m2 = __riscv_vsra_vx_i64m2(data_i64m2, 63, vector_length);
1329
+ sum_upper_i64m2 = __riscv_vadd_vv_i64m2_tu(sum_upper_i64m2, sum_upper_i64m2, sign_ext_i64m2, vector_length);
1330
+
1331
+ // Sumsq: i32 × i32 → i64 (widening multiply, result ≤ 2^62), saturating accumulation
1332
+ vint64m2_t squares_i64m2 = __riscv_vwmul_vv_i64m2(data_i32m1, data_i32m1, vector_length);
1333
+ sumsq_u64m2 = __riscv_vsaddu_vv_u64m2_tu(sumsq_u64m2, sumsq_u64m2,
1334
+ __riscv_vreinterpret_v_i64m2_u64m2(squares_i64m2), vector_length);
1335
+ }
1336
+
1337
+ *sum_ptr = nk_reduce_128bit_sum_i64m2_rvv_(sum_lower_u64m2, sum_upper_i64m2, vlmax);
1338
+ *sumsq_ptr = nk_reduce_vsaddu_u64m2_rvv_(sumsq_u64m2, vlmax);
1339
+ }
1340
+
1341
+ NK_PUBLIC void nk_reduce_moments_i32_rvv( //
1342
+ nk_i32_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
1343
+ nk_i64_t *sum_ptr, nk_u64_t *sumsq_ptr) {
1344
+ nk_size_t stride_elements = stride_bytes / sizeof(nk_i32_t);
1345
+ int aligned = (stride_bytes % sizeof(nk_i32_t) == 0);
1346
+
1347
+ if (count == 0) *sum_ptr = 0, *sumsq_ptr = 0;
1348
+ else if (!aligned) { nk_reduce_moments_i32_serial(data_ptr, count, stride_bytes, sum_ptr, sumsq_ptr); }
1349
+ else if (stride_elements == 1) { nk_reduce_moments_i32_rvv_contiguous_(data_ptr, count, sum_ptr, sumsq_ptr); }
1350
+ else { nk_reduce_moments_i32_rvv_strided_(data_ptr, count, stride_bytes, sum_ptr, sumsq_ptr); }
1351
+ }
1352
+
1353
+ NK_INTERNAL void nk_reduce_minmax_i32_rvv_contiguous_( //
1354
+ nk_i32_t const *data_ptr, nk_size_t count, //
1355
+ nk_i32_t *min_value_ptr, nk_size_t *min_index_ptr, //
1356
+ nk_i32_t *max_value_ptr, nk_size_t *max_index_ptr) {
1357
+ nk_size_t vlmax = __riscv_vsetvlmax_e32m1();
1358
+ vint32m1_t min_i32m1 = __riscv_vmv_v_x_i32m1(NK_I32_MAX, vlmax);
1359
+ vint32m1_t max_i32m1 = __riscv_vmv_v_x_i32m1(NK_I32_MIN, vlmax);
1360
+ vuint64m2_t min_indices_u64m2 = __riscv_vmv_v_x_u64m2(0, vlmax);
1361
+ vuint64m2_t max_indices_u64m2 = __riscv_vmv_v_x_u64m2(0, vlmax);
1362
+
1363
+ nk_size_t offset = 0;
1364
+ for (nk_size_t vector_length; count > 0;
1365
+ count -= vector_length, offset += vector_length, data_ptr += vector_length) {
1366
+ vector_length = __riscv_vsetvl_e32m1(count);
1367
+ vint32m1_t data_i32m1 = __riscv_vle32_v_i32m1(data_ptr, vector_length);
1368
+ vuint64m2_t pos_u64m2 = __riscv_vadd_vx_u64m2(__riscv_vid_v_u64m2(vector_length), (nk_u64_t)offset,
1369
+ vector_length);
1370
+
1371
+ vbool32_t less_b32 = __riscv_vmslt_vv_i32m1_b32(data_i32m1, min_i32m1, vector_length);
1372
+ min_i32m1 = __riscv_vmerge_vvm_i32m1_tu(min_i32m1, min_i32m1, data_i32m1, less_b32, vector_length);
1373
+ min_indices_u64m2 = __riscv_vmerge_vvm_u64m2_tu(min_indices_u64m2, min_indices_u64m2, pos_u64m2, less_b32,
1374
+ vector_length);
1375
+
1376
+ vbool32_t greater_b32 = __riscv_vmslt_vv_i32m1_b32(max_i32m1, data_i32m1, vector_length);
1377
+ max_i32m1 = __riscv_vmerge_vvm_i32m1_tu(max_i32m1, max_i32m1, data_i32m1, greater_b32, vector_length);
1378
+ max_indices_u64m2 = __riscv_vmerge_vvm_u64m2_tu(max_indices_u64m2, max_indices_u64m2, pos_u64m2, greater_b32,
1379
+ vector_length);
1380
+ }
1381
+
1382
+ // Horizontal reduction for min
1383
+ vint32m1_t init_max_i32m1 = __riscv_vmv_v_x_i32m1(NK_I32_MAX, 1);
1384
+ nk_i32_t min_val = __riscv_vmv_x_s_i32m1_i32(__riscv_vredmin_vs_i32m1_i32m1(min_i32m1, init_max_i32m1, vlmax));
1385
+ vbool32_t min_match_b32 = __riscv_vmseq_vx_i32m1_b32(min_i32m1, min_val, vlmax);
1386
+ vuint64m2_t sentinel_u64m2 = __riscv_vmv_v_x_u64m2(NK_U64_MAX, vlmax);
1387
+ vuint64m2_t min_cands_u64m2 = __riscv_vmerge_vvm_u64m2(sentinel_u64m2, min_indices_u64m2, min_match_b32, vlmax);
1388
+ vuint64m1_t init_umax_u64m1 = __riscv_vmv_v_x_u64m1(NK_U64_MAX, 1);
1389
+ *min_value_ptr = min_val;
1390
+ *min_index_ptr = (nk_size_t)__riscv_vmv_x_s_u64m1_u64(
1391
+ __riscv_vredminu_vs_u64m2_u64m1(min_cands_u64m2, init_umax_u64m1, vlmax));
1392
+
1393
+ // Horizontal reduction for max
1394
+ vint32m1_t init_min_i32m1 = __riscv_vmv_v_x_i32m1(NK_I32_MIN, 1);
1395
+ nk_i32_t max_val = __riscv_vmv_x_s_i32m1_i32(__riscv_vredmax_vs_i32m1_i32m1(max_i32m1, init_min_i32m1, vlmax));
1396
+ vbool32_t max_match_b32 = __riscv_vmseq_vx_i32m1_b32(max_i32m1, max_val, vlmax);
1397
+ vuint64m2_t max_cands_u64m2 = __riscv_vmerge_vvm_u64m2(sentinel_u64m2, max_indices_u64m2, max_match_b32, vlmax);
1398
+ *max_value_ptr = max_val;
1399
+ *max_index_ptr = (nk_size_t)__riscv_vmv_x_s_u64m1_u64(
1400
+ __riscv_vredminu_vs_u64m2_u64m1(max_cands_u64m2, init_umax_u64m1, vlmax));
1401
+ }
1402
+
1403
+ NK_INTERNAL void nk_reduce_minmax_i32_rvv_strided_( //
1404
+ nk_i32_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
1405
+ nk_i32_t *min_value_ptr, nk_size_t *min_index_ptr, //
1406
+ nk_i32_t *max_value_ptr, nk_size_t *max_index_ptr) {
1407
+ nk_size_t vlmax = __riscv_vsetvlmax_e32m1();
1408
+ vint32m1_t min_i32m1 = __riscv_vmv_v_x_i32m1(NK_I32_MAX, vlmax);
1409
+ vint32m1_t max_i32m1 = __riscv_vmv_v_x_i32m1(NK_I32_MIN, vlmax);
1410
+ vuint64m2_t min_indices_u64m2 = __riscv_vmv_v_x_u64m2(0, vlmax);
1411
+ vuint64m2_t max_indices_u64m2 = __riscv_vmv_v_x_u64m2(0, vlmax);
1412
+ unsigned char const *ptr = (unsigned char const *)data_ptr;
1413
+
1414
+ nk_size_t offset = 0;
1415
+ for (nk_size_t vector_length; count > 0;
1416
+ count -= vector_length, offset += vector_length, ptr += vector_length * stride_bytes) {
1417
+ vector_length = __riscv_vsetvl_e32m1(count);
1418
+ vint32m1_t data_i32m1 = __riscv_vlse32_v_i32m1((nk_i32_t const *)ptr, (nk_ssize_t)stride_bytes, vector_length);
1419
+ vuint64m2_t pos_u64m2 = __riscv_vadd_vx_u64m2(__riscv_vid_v_u64m2(vector_length), (nk_u64_t)offset,
1420
+ vector_length);
1421
+
1422
+ vbool32_t less_b32 = __riscv_vmslt_vv_i32m1_b32(data_i32m1, min_i32m1, vector_length);
1423
+ min_i32m1 = __riscv_vmerge_vvm_i32m1_tu(min_i32m1, min_i32m1, data_i32m1, less_b32, vector_length);
1424
+ min_indices_u64m2 = __riscv_vmerge_vvm_u64m2_tu(min_indices_u64m2, min_indices_u64m2, pos_u64m2, less_b32,
1425
+ vector_length);
1426
+
1427
+ vbool32_t greater_b32 = __riscv_vmslt_vv_i32m1_b32(max_i32m1, data_i32m1, vector_length);
1428
+ max_i32m1 = __riscv_vmerge_vvm_i32m1_tu(max_i32m1, max_i32m1, data_i32m1, greater_b32, vector_length);
1429
+ max_indices_u64m2 = __riscv_vmerge_vvm_u64m2_tu(max_indices_u64m2, max_indices_u64m2, pos_u64m2, greater_b32,
1430
+ vector_length);
1431
+ }
1432
+
1433
+ // Horizontal reduction for min
1434
+ vint32m1_t init_max_i32m1 = __riscv_vmv_v_x_i32m1(NK_I32_MAX, 1);
1435
+ nk_i32_t min_val = __riscv_vmv_x_s_i32m1_i32(__riscv_vredmin_vs_i32m1_i32m1(min_i32m1, init_max_i32m1, vlmax));
1436
+ vbool32_t min_match_b32 = __riscv_vmseq_vx_i32m1_b32(min_i32m1, min_val, vlmax);
1437
+ vuint64m2_t sentinel_u64m2 = __riscv_vmv_v_x_u64m2(NK_U64_MAX, vlmax);
1438
+ vuint64m2_t min_cands_u64m2 = __riscv_vmerge_vvm_u64m2(sentinel_u64m2, min_indices_u64m2, min_match_b32, vlmax);
1439
+ vuint64m1_t init_umax_u64m1 = __riscv_vmv_v_x_u64m1(NK_U64_MAX, 1);
1440
+ *min_value_ptr = min_val;
1441
+ *min_index_ptr = (nk_size_t)__riscv_vmv_x_s_u64m1_u64(
1442
+ __riscv_vredminu_vs_u64m2_u64m1(min_cands_u64m2, init_umax_u64m1, vlmax));
1443
+
1444
+ // Horizontal reduction for max
1445
+ vint32m1_t init_min_i32m1 = __riscv_vmv_v_x_i32m1(NK_I32_MIN, 1);
1446
+ nk_i32_t max_val = __riscv_vmv_x_s_i32m1_i32(__riscv_vredmax_vs_i32m1_i32m1(max_i32m1, init_min_i32m1, vlmax));
1447
+ vbool32_t max_match_b32 = __riscv_vmseq_vx_i32m1_b32(max_i32m1, max_val, vlmax);
1448
+ vuint64m2_t max_cands_u64m2 = __riscv_vmerge_vvm_u64m2(sentinel_u64m2, max_indices_u64m2, max_match_b32, vlmax);
1449
+ *max_value_ptr = max_val;
1450
+ *max_index_ptr = (nk_size_t)__riscv_vmv_x_s_u64m1_u64(
1451
+ __riscv_vredminu_vs_u64m2_u64m1(max_cands_u64m2, init_umax_u64m1, vlmax));
1452
+ }
1453
+
1454
+ NK_PUBLIC void nk_reduce_minmax_i32_rvv( //
1455
+ nk_i32_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
1456
+ nk_i32_t *min_value_ptr, nk_size_t *min_index_ptr, //
1457
+ nk_i32_t *max_value_ptr, nk_size_t *max_index_ptr) {
1458
+ nk_size_t stride_elements = stride_bytes / sizeof(nk_i32_t);
1459
+ int aligned = (stride_bytes % sizeof(nk_i32_t) == 0);
1460
+
1461
+ if (count == 0)
1462
+ *min_value_ptr = NK_I32_MAX, *min_index_ptr = NK_SIZE_MAX, *max_value_ptr = NK_I32_MIN,
1463
+ *max_index_ptr = NK_SIZE_MAX;
1464
+ else if (!aligned)
1465
+ nk_reduce_minmax_i32_serial(data_ptr, count, stride_bytes, min_value_ptr, min_index_ptr, max_value_ptr,
1466
+ max_index_ptr);
1467
+ else if (stride_elements == 1)
1468
+ nk_reduce_minmax_i32_rvv_contiguous_(data_ptr, count, min_value_ptr, min_index_ptr, max_value_ptr,
1469
+ max_index_ptr);
1470
+ else
1471
+ nk_reduce_minmax_i32_rvv_strided_(data_ptr, count, stride_bytes, min_value_ptr, min_index_ptr, max_value_ptr,
1472
+ max_index_ptr);
1473
+ }
1474
+
1475
+ NK_INTERNAL void nk_reduce_moments_u32_rvv_contiguous_( //
1476
+ nk_u32_t const *data_ptr, nk_size_t count, //
1477
+ nk_u64_t *sum_ptr, nk_u64_t *sumsq_ptr) {
1478
+ nk_size_t vlmax = __riscv_vsetvlmax_e64m2();
1479
+ vuint64m2_t sum_u64m2 = __riscv_vmv_v_x_u64m2(0, vlmax);
1480
+ vuint64m2_t sumsq_u64m2 = __riscv_vmv_v_x_u64m2(0, vlmax);
1481
+
1482
+ for (nk_size_t vector_length; count > 0; count -= vector_length, data_ptr += vector_length) {
1483
+ vector_length = __riscv_vsetvl_e32m1(count);
1484
+ vuint32m1_t data_u32m1 = __riscv_vle32_v_u32m1(data_ptr, vector_length);
1485
+
1486
+ // Widen u32 → u64 for saturating sum
1487
+ vuint64m2_t data_u64m2 = __riscv_vzext_vf2_u64m2(data_u32m1, vector_length);
1488
+ sum_u64m2 = __riscv_vsaddu_vv_u64m2_tu(sum_u64m2, sum_u64m2, data_u64m2, vector_length);
1489
+
1490
+ // Sumsq: u32 × u32 → u64 (widening multiply, no overflow), saturating accumulation
1491
+ vuint64m2_t squares_u64m2 = __riscv_vwmulu_vv_u64m2(data_u32m1, data_u32m1, vector_length);
1492
+ sumsq_u64m2 = __riscv_vsaddu_vv_u64m2_tu(sumsq_u64m2, sumsq_u64m2, squares_u64m2, vector_length);
1493
+ }
1494
+
1495
+ *sum_ptr = nk_reduce_vsaddu_u64m2_rvv_(sum_u64m2, vlmax);
1496
+ *sumsq_ptr = nk_reduce_vsaddu_u64m2_rvv_(sumsq_u64m2, vlmax);
1497
+ }
1498
+
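The u32 kernel has no sign to track, so it skips the 128-bit machinery and simply saturates: `vsaddu` clamps at NK_U64_MAX instead of wrapping, and the `nk_reduce_vsaddu_u64m2_rvv_` fold referenced above presumably continues that policy in the horizontal step (its definition is outside this excerpt). The scalar analogue of that add, shown here only to make the clamping behaviour explicit; the helper name is illustrative, not part of the API:

    static nk_u64_t add_saturating_u64(nk_u64_t a, nk_u64_t b) {
        nk_u64_t r = a + b;
        return r < a ? NK_U64_MAX : r;   // pin to the maximum instead of wrapping
    }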
1499
+ NK_INTERNAL void nk_reduce_moments_u32_rvv_strided_( //
1500
+ nk_u32_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
1501
+ nk_u64_t *sum_ptr, nk_u64_t *sumsq_ptr) {
1502
+ nk_size_t vlmax = __riscv_vsetvlmax_e64m2();
1503
+ vuint64m2_t sum_u64m2 = __riscv_vmv_v_x_u64m2(0, vlmax);
1504
+ vuint64m2_t sumsq_u64m2 = __riscv_vmv_v_x_u64m2(0, vlmax);
1505
+ unsigned char const *ptr = (unsigned char const *)data_ptr;
1506
+
1507
+ for (nk_size_t vector_length; count > 0; count -= vector_length, ptr += vector_length * stride_bytes) {
1508
+ vector_length = __riscv_vsetvl_e32m1(count);
1509
+ vuint32m1_t data_u32m1 = __riscv_vlse32_v_u32m1((nk_u32_t const *)ptr, (nk_ssize_t)stride_bytes, vector_length);
1510
+
1511
+ // Widen u32 → u64 for saturating sum
1512
+ vuint64m2_t data_u64m2 = __riscv_vzext_vf2_u64m2(data_u32m1, vector_length);
1513
+ sum_u64m2 = __riscv_vsaddu_vv_u64m2_tu(sum_u64m2, sum_u64m2, data_u64m2, vector_length);
1514
+
1515
+ // Sumsq: u32 × u32 → u64 (widening multiply, no overflow), saturating accumulation
1516
+ vuint64m2_t squares_u64m2 = __riscv_vwmulu_vv_u64m2(data_u32m1, data_u32m1, vector_length);
1517
+ sumsq_u64m2 = __riscv_vsaddu_vv_u64m2_tu(sumsq_u64m2, sumsq_u64m2, squares_u64m2, vector_length);
1518
+ }
1519
+
1520
+ *sum_ptr = nk_reduce_vsaddu_u64m2_rvv_(sum_u64m2, vlmax);
1521
+ *sumsq_ptr = nk_reduce_vsaddu_u64m2_rvv_(sumsq_u64m2, vlmax);
1522
+ }
1523
+
1524
+ NK_PUBLIC void nk_reduce_moments_u32_rvv( //
1525
+ nk_u32_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
1526
+ nk_u64_t *sum_ptr, nk_u64_t *sumsq_ptr) {
1527
+ nk_size_t stride_elements = stride_bytes / sizeof(nk_u32_t);
1528
+ int aligned = (stride_bytes % sizeof(nk_u32_t) == 0);
1529
+
1530
+ if (count == 0) *sum_ptr = 0, *sumsq_ptr = 0;
1531
+ else if (!aligned) { nk_reduce_moments_u32_serial(data_ptr, count, stride_bytes, sum_ptr, sumsq_ptr); }
1532
+ else if (stride_elements == 1) { nk_reduce_moments_u32_rvv_contiguous_(data_ptr, count, sum_ptr, sumsq_ptr); }
1533
+ else { nk_reduce_moments_u32_rvv_strided_(data_ptr, count, stride_bytes, sum_ptr, sumsq_ptr); }
1534
+ }
1535
+
1536
+ NK_INTERNAL void nk_reduce_minmax_u32_rvv_contiguous_( //
1537
+ nk_u32_t const *data_ptr, nk_size_t count, //
1538
+ nk_u32_t *min_value_ptr, nk_size_t *min_index_ptr, //
1539
+ nk_u32_t *max_value_ptr, nk_size_t *max_index_ptr) {
1540
+ nk_size_t vlmax = __riscv_vsetvlmax_e32m1();
1541
+ vuint32m1_t min_u32m1 = __riscv_vmv_v_x_u32m1(NK_U32_MAX, vlmax);
1542
+ vuint32m1_t max_u32m1 = __riscv_vmv_v_x_u32m1(NK_U32_MIN, vlmax);
1543
+ vuint64m2_t min_indices_u64m2 = __riscv_vmv_v_x_u64m2(0, vlmax);
1544
+ vuint64m2_t max_indices_u64m2 = __riscv_vmv_v_x_u64m2(0, vlmax);
1545
+
1546
+ nk_size_t offset = 0;
1547
+ for (nk_size_t vector_length; count > 0;
1548
+ count -= vector_length, offset += vector_length, data_ptr += vector_length) {
1549
+ vector_length = __riscv_vsetvl_e32m1(count);
1550
+ vuint32m1_t data_u32m1 = __riscv_vle32_v_u32m1(data_ptr, vector_length);
1551
+ vuint64m2_t pos_u64m2 = __riscv_vadd_vx_u64m2(__riscv_vid_v_u64m2(vector_length), (nk_u64_t)offset,
1552
+ vector_length);
1553
+
1554
+ vbool32_t less_b32 = __riscv_vmsltu_vv_u32m1_b32(data_u32m1, min_u32m1, vector_length);
1555
+ min_u32m1 = __riscv_vmerge_vvm_u32m1_tu(min_u32m1, min_u32m1, data_u32m1, less_b32, vector_length);
1556
+ min_indices_u64m2 = __riscv_vmerge_vvm_u64m2_tu(min_indices_u64m2, min_indices_u64m2, pos_u64m2, less_b32,
1557
+ vector_length);
1558
+
1559
+ vbool32_t greater_b32 = __riscv_vmsltu_vv_u32m1_b32(max_u32m1, data_u32m1, vector_length);
1560
+ max_u32m1 = __riscv_vmerge_vvm_u32m1_tu(max_u32m1, max_u32m1, data_u32m1, greater_b32, vector_length);
1561
+ max_indices_u64m2 = __riscv_vmerge_vvm_u64m2_tu(max_indices_u64m2, max_indices_u64m2, pos_u64m2, greater_b32,
1562
+ vector_length);
1563
+ }
1564
+
1565
+ // Horizontal reduction for min
1566
+ vuint32m1_t init_max_u32m1 = __riscv_vmv_v_x_u32m1(NK_U32_MAX, 1);
1567
+ nk_u32_t min_val = __riscv_vmv_x_s_u32m1_u32(__riscv_vredminu_vs_u32m1_u32m1(min_u32m1, init_max_u32m1, vlmax));
1568
+ vbool32_t min_match_b32 = __riscv_vmseq_vx_u32m1_b32(min_u32m1, min_val, vlmax);
1569
+ vuint64m2_t sentinel_u64m2 = __riscv_vmv_v_x_u64m2(NK_U64_MAX, vlmax);
1570
+ vuint64m2_t min_cands_u64m2 = __riscv_vmerge_vvm_u64m2(sentinel_u64m2, min_indices_u64m2, min_match_b32, vlmax);
1571
+ vuint64m1_t init_umax_u64m1 = __riscv_vmv_v_x_u64m1(NK_U64_MAX, 1);
1572
+ *min_value_ptr = min_val;
1573
+ *min_index_ptr = (nk_size_t)__riscv_vmv_x_s_u64m1_u64(
1574
+ __riscv_vredminu_vs_u64m2_u64m1(min_cands_u64m2, init_umax_u64m1, vlmax));
1575
+
1576
+ // Horizontal reduction for max
1577
+ vuint32m1_t init_min_u32m1 = __riscv_vmv_v_x_u32m1(NK_U32_MIN, 1);
1578
+ nk_u32_t max_val = __riscv_vmv_x_s_u32m1_u32(__riscv_vredmaxu_vs_u32m1_u32m1(max_u32m1, init_min_u32m1, vlmax));
1579
+ vbool32_t max_match_b32 = __riscv_vmseq_vx_u32m1_b32(max_u32m1, max_val, vlmax);
1580
+ vuint64m2_t max_cands_u64m2 = __riscv_vmerge_vvm_u64m2(sentinel_u64m2, max_indices_u64m2, max_match_b32, vlmax);
1581
+ *max_value_ptr = max_val;
1582
+ *max_index_ptr = (nk_size_t)__riscv_vmv_x_s_u64m1_u64(
1583
+ __riscv_vredminu_vs_u64m2_u64m1(max_cands_u64m2, init_umax_u64m1, vlmax));
1584
+ }
1585
+
1586
+ NK_INTERNAL void nk_reduce_minmax_u32_rvv_strided_( //
1587
+ nk_u32_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
1588
+ nk_u32_t *min_value_ptr, nk_size_t *min_index_ptr, //
1589
+ nk_u32_t *max_value_ptr, nk_size_t *max_index_ptr) {
1590
+ nk_size_t vlmax = __riscv_vsetvlmax_e32m1();
1591
+ vuint32m1_t min_u32m1 = __riscv_vmv_v_x_u32m1(NK_U32_MAX, vlmax);
1592
+ vuint32m1_t max_u32m1 = __riscv_vmv_v_x_u32m1(NK_U32_MIN, vlmax);
1593
+ vuint64m2_t min_indices_u64m2 = __riscv_vmv_v_x_u64m2(0, vlmax);
1594
+ vuint64m2_t max_indices_u64m2 = __riscv_vmv_v_x_u64m2(0, vlmax);
1595
+ unsigned char const *ptr = (unsigned char const *)data_ptr;
1596
+
1597
+ nk_size_t offset = 0;
1598
+ for (nk_size_t vector_length; count > 0;
1599
+ count -= vector_length, offset += vector_length, ptr += vector_length * stride_bytes) {
1600
+ vector_length = __riscv_vsetvl_e32m1(count);
1601
+ vuint32m1_t data_u32m1 = __riscv_vlse32_v_u32m1((nk_u32_t const *)ptr, (nk_ssize_t)stride_bytes, vector_length);
1602
+ vuint64m2_t pos_u64m2 = __riscv_vadd_vx_u64m2(__riscv_vid_v_u64m2(vector_length), (nk_u64_t)offset,
1603
+ vector_length);
1604
+
1605
+ vbool32_t less_b32 = __riscv_vmsltu_vv_u32m1_b32(data_u32m1, min_u32m1, vector_length);
1606
+ min_u32m1 = __riscv_vmerge_vvm_u32m1_tu(min_u32m1, min_u32m1, data_u32m1, less_b32, vector_length);
1607
+ min_indices_u64m2 = __riscv_vmerge_vvm_u64m2_tu(min_indices_u64m2, min_indices_u64m2, pos_u64m2, less_b32,
1608
+ vector_length);
1609
+
1610
+ vbool32_t greater_b32 = __riscv_vmsltu_vv_u32m1_b32(max_u32m1, data_u32m1, vector_length);
1611
+ max_u32m1 = __riscv_vmerge_vvm_u32m1_tu(max_u32m1, max_u32m1, data_u32m1, greater_b32, vector_length);
1612
+ max_indices_u64m2 = __riscv_vmerge_vvm_u64m2_tu(max_indices_u64m2, max_indices_u64m2, pos_u64m2, greater_b32,
1613
+ vector_length);
1614
+ }
1615
+
1616
+ // Horizontal reduction for min
1617
+ vuint32m1_t init_max_u32m1 = __riscv_vmv_v_x_u32m1(NK_U32_MAX, 1);
1618
+ nk_u32_t min_val = __riscv_vmv_x_s_u32m1_u32(__riscv_vredminu_vs_u32m1_u32m1(min_u32m1, init_max_u32m1, vlmax));
1619
+ vbool32_t min_match_b32 = __riscv_vmseq_vx_u32m1_b32(min_u32m1, min_val, vlmax);
1620
+ vuint64m2_t sentinel_u64m2 = __riscv_vmv_v_x_u64m2(NK_U64_MAX, vlmax);
1621
+ vuint64m2_t min_cands_u64m2 = __riscv_vmerge_vvm_u64m2(sentinel_u64m2, min_indices_u64m2, min_match_b32, vlmax);
1622
+ vuint64m1_t init_umax_u64m1 = __riscv_vmv_v_x_u64m1(NK_U64_MAX, 1);
1623
+ *min_value_ptr = min_val;
1624
+ *min_index_ptr = (nk_size_t)__riscv_vmv_x_s_u64m1_u64(
1625
+ __riscv_vredminu_vs_u64m2_u64m1(min_cands_u64m2, init_umax_u64m1, vlmax));
1626
+
1627
+ // Horizontal reduction for max
1628
+ vuint32m1_t init_min_u32m1 = __riscv_vmv_v_x_u32m1(NK_U32_MIN, 1);
1629
+ nk_u32_t max_val = __riscv_vmv_x_s_u32m1_u32(__riscv_vredmaxu_vs_u32m1_u32m1(max_u32m1, init_min_u32m1, vlmax));
1630
+ vbool32_t max_match_b32 = __riscv_vmseq_vx_u32m1_b32(max_u32m1, max_val, vlmax);
1631
+ vuint64m2_t max_cands_u64m2 = __riscv_vmerge_vvm_u64m2(sentinel_u64m2, max_indices_u64m2, max_match_b32, vlmax);
1632
+ *max_value_ptr = max_val;
1633
+ *max_index_ptr = (nk_size_t)__riscv_vmv_x_s_u64m1_u64(
1634
+ __riscv_vredminu_vs_u64m2_u64m1(max_cands_u64m2, init_umax_u64m1, vlmax));
1635
+ }
1636
+
1637
+ NK_PUBLIC void nk_reduce_minmax_u32_rvv( //
1638
+ nk_u32_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
1639
+ nk_u32_t *min_value_ptr, nk_size_t *min_index_ptr, //
1640
+ nk_u32_t *max_value_ptr, nk_size_t *max_index_ptr) {
1641
+ nk_size_t stride_elements = stride_bytes / sizeof(nk_u32_t);
1642
+ int aligned = (stride_bytes % sizeof(nk_u32_t) == 0);
1643
+
1644
+ if (count == 0)
1645
+ *min_value_ptr = NK_U32_MAX, *min_index_ptr = NK_SIZE_MAX, *max_value_ptr = NK_U32_MIN,
1646
+ *max_index_ptr = NK_SIZE_MAX;
1647
+ else if (!aligned)
1648
+ nk_reduce_minmax_u32_serial(data_ptr, count, stride_bytes, min_value_ptr, min_index_ptr, max_value_ptr,
1649
+ max_index_ptr);
1650
+ else if (stride_elements == 1)
1651
+ nk_reduce_minmax_u32_rvv_contiguous_(data_ptr, count, min_value_ptr, min_index_ptr, max_value_ptr,
1652
+ max_index_ptr);
1653
+ else
1654
+ nk_reduce_minmax_u32_rvv_strided_(data_ptr, count, stride_bytes, min_value_ptr, min_index_ptr, max_value_ptr,
1655
+ max_index_ptr);
1656
+ }
1657
+
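The three u32 min-max kernels above share one argmin/argmax convention: each lane tracks the position where its running extremum last improved, non-matching lanes are overwritten with an NK_U64_MAX sentinel, and the final vredminu therefore reports the earliest index holding the winning value. A minimal scalar sketch of that contract (the helper name is illustrative, not part of the library):

static void u32_minmax_reference_(nk_u32_t const *data, nk_size_t count,
                                  nk_u32_t *min_v, nk_size_t *min_i,
                                  nk_u32_t *max_v, nk_size_t *max_i) {
    nk_u32_t lo = NK_U32_MAX, hi = NK_U32_MIN;
    nk_size_t lo_at = NK_SIZE_MAX, hi_at = NK_SIZE_MAX;
    for (nk_size_t i = 0; i < count; ++i) {
        if (data[i] < lo) lo = data[i], lo_at = i; // strict '<': ties keep the earlier index
        if (data[i] > hi) hi = data[i], hi_at = i; // strict '>': ties keep the earlier index
    }
    *min_v = lo, *min_i = lo_at, *max_v = hi, *max_i = hi_at;
}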
1658
+ NK_INTERNAL void nk_reduce_moments_i64_rvv_contiguous_( //
1659
+ nk_i64_t const *data_ptr, nk_size_t count, //
1660
+ nk_i64_t *sum_ptr, nk_u64_t *sumsq_ptr) {
1661
+ nk_size_t vlmax = __riscv_vsetvlmax_e64m1();
1662
+ // 128-bit per-lane accumulator for sum: (sum_upper, sum_lower)
1663
+ vuint64m1_t sum_lower_u64m1 = __riscv_vmv_v_x_u64m1(0, vlmax);
1664
+ vint64m1_t sum_upper_i64m1 = __riscv_vmv_v_x_i64m1(0, vlmax);
1665
+ vuint64m1_t sumsq_u64m1 = __riscv_vmv_v_x_u64m1(0, vlmax);
1666
+
1667
+ for (nk_size_t vector_length; count > 0; count -= vector_length, data_ptr += vector_length) {
1668
+ vector_length = __riscv_vsetvl_e64m1(count);
1669
+ vint64m1_t data_i64m1 = __riscv_vle64_v_i64m1(data_ptr, vector_length);
1670
+
1671
+ // 128-bit sum accumulation: wrapping add on lower half
1672
+ vuint64m1_t data_u64m1 = __riscv_vreinterpret_v_i64m1_u64m1(data_i64m1);
1673
+ vuint64m1_t sum_before_u64m1 = sum_lower_u64m1;
1674
+ sum_lower_u64m1 = __riscv_vadd_vv_u64m1_tu(sum_lower_u64m1, sum_lower_u64m1, data_u64m1, vector_length);
1675
+
1676
+ // Carry: new < old means unsigned overflow occurred
1677
+ vbool64_t carry_b64 = __riscv_vmsltu_vv_u64m1_b64(sum_lower_u64m1, sum_before_u64m1, vector_length);
1678
+ vint64m1_t carry_i64m1 = __riscv_vmerge_vxm_i64m1(__riscv_vmv_v_x_i64m1(0, vector_length), 1, carry_b64,
1679
+ vector_length);
1680
+ sum_upper_i64m1 = __riscv_vadd_vv_i64m1_tu(sum_upper_i64m1, sum_upper_i64m1, carry_i64m1, vector_length);
1681
+
1682
+ // Sign extension into the upper word: add -1 for negative values, 0 for non-negative
1683
+ vint64m1_t sign_ext_i64m1 = __riscv_vsra_vx_i64m1(data_i64m1, 63, vector_length);
1684
+ sum_upper_i64m1 = __riscv_vadd_vv_i64m1_tu(sum_upper_i64m1, sum_upper_i64m1, sign_ext_i64m1, vector_length);
1685
+
1686
+ // Sumsq: abs(val)² with overflow detection
1687
+ vint64m1_t negated_i64m1 = __riscv_vneg_v_i64m1(data_i64m1, vector_length);
1688
+ vint64m1_t absolute_i64m1 = __riscv_vmax_vv_i64m1(data_i64m1, negated_i64m1, vector_length);
1689
+ vuint64m1_t absolute_u64m1 = __riscv_vreinterpret_v_i64m1_u64m1(absolute_i64m1);
1690
+ vuint64m1_t product_low_u64m1 = __riscv_vmul_vv_u64m1(absolute_u64m1, absolute_u64m1, vector_length);
1691
+ vuint64m1_t product_high_u64m1 = __riscv_vmulhu_vv_u64m1(absolute_u64m1, absolute_u64m1, vector_length);
1692
+ vbool64_t overflow_b64 = __riscv_vmsne_vx_u64m1_b64(product_high_u64m1, 0, vector_length);
1693
+ vuint64m1_t squares_u64m1 = __riscv_vmerge_vxm_u64m1(product_low_u64m1, NK_U64_MAX, overflow_b64,
1694
+ vector_length);
1695
+ sumsq_u64m1 = __riscv_vsaddu_vv_u64m1_tu(sumsq_u64m1, sumsq_u64m1, squares_u64m1, vector_length);
1696
+ }
1697
+
1698
+ *sum_ptr = nk_reduce_128bit_sum_i64m1_rvv_(sum_lower_u64m1, sum_upper_i64m1, vlmax);
1699
+ *sumsq_ptr = nk_reduce_vsaddu_u64m1_rvv_(sumsq_u64m1, vlmax);
1700
+ }
1701
+
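Per element, the loop above maintains a 128-bit signed sum split across an unsigned lower word and a signed upper word; the carry is recovered by comparing the wrapped lower sum against its previous value, and the operand's own sign extension is added to the upper word. A hedged scalar restatement of that per-element step (the folding helper nk_reduce_128bit_sum_i64m1_rvv_ is defined elsewhere and not reproduced here):

static void accumulate_i64_as_128bit_(nk_u64_t *sum_lower, nk_i64_t *sum_upper, nk_i64_t value) {
    nk_u64_t before = *sum_lower;
    *sum_lower += (nk_u64_t)value;               // wrapping add of the low 64 bits
    if (*sum_lower < before) *sum_upper += 1;    // carry out of the low word
    *sum_upper += (value < 0) ? -1 : 0;          // sign extension, mirroring the vsra step
}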
1702
+ NK_INTERNAL void nk_reduce_moments_i64_rvv_strided_( //
1703
+ nk_i64_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
1704
+ nk_i64_t *sum_ptr, nk_u64_t *sumsq_ptr) {
1705
+ nk_size_t vlmax = __riscv_vsetvlmax_e64m1();
1706
+ // 128-bit per-lane accumulator for sum: (sum_upper, sum_lower)
1707
+ vuint64m1_t sum_lower_u64m1 = __riscv_vmv_v_x_u64m1(0, vlmax);
1708
+ vint64m1_t sum_upper_i64m1 = __riscv_vmv_v_x_i64m1(0, vlmax);
1709
+ vuint64m1_t sumsq_u64m1 = __riscv_vmv_v_x_u64m1(0, vlmax);
1710
+ unsigned char const *ptr = (unsigned char const *)data_ptr;
1711
+
1712
+ for (nk_size_t vector_length; count > 0; count -= vector_length, ptr += vector_length * stride_bytes) {
1713
+ vector_length = __riscv_vsetvl_e64m1(count);
1714
+ vint64m1_t data_i64m1 = __riscv_vlse64_v_i64m1((nk_i64_t const *)ptr, (nk_ssize_t)stride_bytes, vector_length);
1715
+
1716
+ // 128-bit sum accumulation: wrapping add on lower half
1717
+ vuint64m1_t data_u64m1 = __riscv_vreinterpret_v_i64m1_u64m1(data_i64m1);
1718
+ vuint64m1_t sum_before_u64m1 = sum_lower_u64m1;
1719
+ sum_lower_u64m1 = __riscv_vadd_vv_u64m1_tu(sum_lower_u64m1, sum_lower_u64m1, data_u64m1, vector_length);
1720
+
1721
+ // Carry: new < old means unsigned overflow occurred
1722
+ vbool64_t carry_b64 = __riscv_vmsltu_vv_u64m1_b64(sum_lower_u64m1, sum_before_u64m1, vector_length);
1723
+ vint64m1_t carry_i64m1 = __riscv_vmerge_vxm_i64m1(__riscv_vmv_v_x_i64m1(0, vector_length), 1, carry_b64,
1724
+ vector_length);
1725
+ sum_upper_i64m1 = __riscv_vadd_vv_i64m1_tu(sum_upper_i64m1, sum_upper_i64m1, carry_i64m1, vector_length);
1726
+
1727
+ // Sign extension into the upper word: add -1 for negative values, 0 for non-negative
1728
+ vint64m1_t sign_ext_i64m1 = __riscv_vsra_vx_i64m1(data_i64m1, 63, vector_length);
1729
+ sum_upper_i64m1 = __riscv_vadd_vv_i64m1_tu(sum_upper_i64m1, sum_upper_i64m1, sign_ext_i64m1, vector_length);
1730
+
1731
+ // Sumsq: abs(val)² with overflow detection
1732
+ vint64m1_t negated_i64m1 = __riscv_vneg_v_i64m1(data_i64m1, vector_length);
1733
+ vint64m1_t absolute_i64m1 = __riscv_vmax_vv_i64m1(data_i64m1, negated_i64m1, vector_length);
1734
+ vuint64m1_t absolute_u64m1 = __riscv_vreinterpret_v_i64m1_u64m1(absolute_i64m1);
1735
+ vuint64m1_t product_low_u64m1 = __riscv_vmul_vv_u64m1(absolute_u64m1, absolute_u64m1, vector_length);
1736
+ vuint64m1_t product_high_u64m1 = __riscv_vmulhu_vv_u64m1(absolute_u64m1, absolute_u64m1, vector_length);
1737
+ vbool64_t overflow_b64 = __riscv_vmsne_vx_u64m1_b64(product_high_u64m1, 0, vector_length);
1738
+ vuint64m1_t squares_u64m1 = __riscv_vmerge_vxm_u64m1(product_low_u64m1, NK_U64_MAX, overflow_b64,
1739
+ vector_length);
1740
+ sumsq_u64m1 = __riscv_vsaddu_vv_u64m1_tu(sumsq_u64m1, sumsq_u64m1, squares_u64m1, vector_length);
1741
+ }
1742
+
1743
+ *sum_ptr = nk_reduce_128bit_sum_i64m1_rvv_(sum_lower_u64m1, sum_upper_i64m1, vlmax);
1744
+ *sumsq_ptr = nk_reduce_vsaddu_u64m1_rvv_(sumsq_u64m1, vlmax);
1745
+ }
1746
+
1747
+ NK_PUBLIC void nk_reduce_moments_i64_rvv( //
1748
+ nk_i64_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
1749
+ nk_i64_t *sum_ptr, nk_u64_t *sumsq_ptr) {
1750
+ nk_size_t stride_elements = stride_bytes / sizeof(nk_i64_t);
1751
+ int aligned = (stride_bytes % sizeof(nk_i64_t) == 0);
1752
+
1753
+ if (count == 0) *sum_ptr = 0, *sumsq_ptr = 0;
1754
+ else if (!aligned) { nk_reduce_moments_i64_serial(data_ptr, count, stride_bytes, sum_ptr, sumsq_ptr); }
1755
+ else if (stride_elements == 1) { nk_reduce_moments_i64_rvv_contiguous_(data_ptr, count, sum_ptr, sumsq_ptr); }
1756
+ else { nk_reduce_moments_i64_rvv_strided_(data_ptr, count, stride_bytes, sum_ptr, sumsq_ptr); }
1757
+ }
1758
+
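Call-site sketch for the public dispatcher above (values are illustrative): count is the element count and stride_bytes the byte distance between consecutive elements, so a dense array simply passes the element size.

nk_i64_t samples[5] = {3, -7, 2, 9, -1};
nk_i64_t sum;
nk_u64_t sumsq;
nk_reduce_moments_i64_rvv(samples, 5, sizeof(nk_i64_t), &sum, &sumsq);
// sum == 6, sumsq == 9 + 49 + 4 + 81 + 1 == 144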
1759
+ NK_INTERNAL void nk_reduce_minmax_i64_rvv_contiguous_( //
1760
+ nk_i64_t const *data_ptr, nk_size_t count, //
1761
+ nk_i64_t *min_value_ptr, nk_size_t *min_index_ptr, //
1762
+ nk_i64_t *max_value_ptr, nk_size_t *max_index_ptr) {
1763
+ nk_size_t vlmax = __riscv_vsetvlmax_e64m1();
1764
+ vint64m1_t min_i64m1 = __riscv_vmv_v_x_i64m1(NK_I64_MAX, vlmax);
1765
+ vint64m1_t max_i64m1 = __riscv_vmv_v_x_i64m1(NK_I64_MIN, vlmax);
1766
+ vuint64m1_t min_indices_u64m1 = __riscv_vmv_v_x_u64m1(0, vlmax);
1767
+ vuint64m1_t max_indices_u64m1 = __riscv_vmv_v_x_u64m1(0, vlmax);
1768
+
1769
+ nk_size_t offset = 0;
1770
+ for (nk_size_t vector_length; count > 0;
1771
+ count -= vector_length, offset += vector_length, data_ptr += vector_length) {
1772
+ vector_length = __riscv_vsetvl_e64m1(count);
1773
+ vint64m1_t data_i64m1 = __riscv_vle64_v_i64m1(data_ptr, vector_length);
1774
+ vuint64m1_t pos_u64m1 = __riscv_vadd_vx_u64m1(__riscv_vid_v_u64m1(vector_length), (nk_u64_t)offset,
1775
+ vector_length);
1776
+
1777
+ vbool64_t less_b64 = __riscv_vmslt_vv_i64m1_b64(data_i64m1, min_i64m1, vector_length);
1778
+ min_i64m1 = __riscv_vmerge_vvm_i64m1_tu(min_i64m1, min_i64m1, data_i64m1, less_b64, vector_length);
1779
+ min_indices_u64m1 = __riscv_vmerge_vvm_u64m1_tu(min_indices_u64m1, min_indices_u64m1, pos_u64m1, less_b64,
1780
+ vector_length);
1781
+
1782
+ vbool64_t greater_b64 = __riscv_vmslt_vv_i64m1_b64(max_i64m1, data_i64m1, vector_length);
1783
+ max_i64m1 = __riscv_vmerge_vvm_i64m1_tu(max_i64m1, max_i64m1, data_i64m1, greater_b64, vector_length);
1784
+ max_indices_u64m1 = __riscv_vmerge_vvm_u64m1_tu(max_indices_u64m1, max_indices_u64m1, pos_u64m1, greater_b64,
1785
+ vector_length);
1786
+ }
1787
+
1788
+ // Horizontal reduction for min
1789
+ vint64m1_t init_max_i64m1 = __riscv_vmv_v_x_i64m1(NK_I64_MAX, 1);
1790
+ nk_i64_t min_val = __riscv_vmv_x_s_i64m1_i64(__riscv_vredmin_vs_i64m1_i64m1(min_i64m1, init_max_i64m1, vlmax));
1791
+ vbool64_t min_match_b64 = __riscv_vmseq_vx_i64m1_b64(min_i64m1, min_val, vlmax);
1792
+ vuint64m1_t sentinel_u64m1 = __riscv_vmv_v_x_u64m1(NK_U64_MAX, vlmax);
1793
+ vuint64m1_t min_cands_u64m1 = __riscv_vmerge_vvm_u64m1(sentinel_u64m1, min_indices_u64m1, min_match_b64, vlmax);
1794
+ vuint64m1_t init_umax_u64m1 = __riscv_vmv_v_x_u64m1(NK_U64_MAX, 1);
1795
+ *min_value_ptr = min_val;
1796
+ *min_index_ptr = (nk_size_t)__riscv_vmv_x_s_u64m1_u64(
1797
+ __riscv_vredminu_vs_u64m1_u64m1(min_cands_u64m1, init_umax_u64m1, vlmax));
1798
+
1799
+ // Horizontal reduction for max
1800
+ vint64m1_t init_min_i64m1 = __riscv_vmv_v_x_i64m1(NK_I64_MIN, 1);
1801
+ nk_i64_t max_val = __riscv_vmv_x_s_i64m1_i64(__riscv_vredmax_vs_i64m1_i64m1(max_i64m1, init_min_i64m1, vlmax));
1802
+ vbool64_t max_match_b64 = __riscv_vmseq_vx_i64m1_b64(max_i64m1, max_val, vlmax);
1803
+ vuint64m1_t max_cands_u64m1 = __riscv_vmerge_vvm_u64m1(sentinel_u64m1, max_indices_u64m1, max_match_b64, vlmax);
1804
+ *max_value_ptr = max_val;
1805
+ *max_index_ptr = (nk_size_t)__riscv_vmv_x_s_u64m1_u64(
1806
+ __riscv_vredminu_vs_u64m1_u64m1(max_cands_u64m1, init_umax_u64m1, vlmax));
1807
+ }
1808
+
1809
+ NK_INTERNAL void nk_reduce_minmax_i64_rvv_strided_( //
1810
+ nk_i64_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
1811
+ nk_i64_t *min_value_ptr, nk_size_t *min_index_ptr, //
1812
+ nk_i64_t *max_value_ptr, nk_size_t *max_index_ptr) {
1813
+ nk_size_t vlmax = __riscv_vsetvlmax_e64m1();
1814
+ vint64m1_t min_i64m1 = __riscv_vmv_v_x_i64m1(NK_I64_MAX, vlmax);
1815
+ vint64m1_t max_i64m1 = __riscv_vmv_v_x_i64m1(NK_I64_MIN, vlmax);
1816
+ vuint64m1_t min_indices_u64m1 = __riscv_vmv_v_x_u64m1(0, vlmax);
1817
+ vuint64m1_t max_indices_u64m1 = __riscv_vmv_v_x_u64m1(0, vlmax);
1818
+ unsigned char const *ptr = (unsigned char const *)data_ptr;
1819
+
1820
+ nk_size_t offset = 0;
1821
+ for (nk_size_t vector_length; count > 0;
1822
+ count -= vector_length, offset += vector_length, ptr += vector_length * stride_bytes) {
1823
+ vector_length = __riscv_vsetvl_e64m1(count);
1824
+ vint64m1_t data_i64m1 = __riscv_vlse64_v_i64m1((nk_i64_t const *)ptr, (nk_ssize_t)stride_bytes, vector_length);
1825
+ vuint64m1_t pos_u64m1 = __riscv_vadd_vx_u64m1(__riscv_vid_v_u64m1(vector_length), (nk_u64_t)offset,
1826
+ vector_length);
1827
+
1828
+ vbool64_t less_b64 = __riscv_vmslt_vv_i64m1_b64(data_i64m1, min_i64m1, vector_length);
1829
+ min_i64m1 = __riscv_vmerge_vvm_i64m1_tu(min_i64m1, min_i64m1, data_i64m1, less_b64, vector_length);
1830
+ min_indices_u64m1 = __riscv_vmerge_vvm_u64m1_tu(min_indices_u64m1, min_indices_u64m1, pos_u64m1, less_b64,
1831
+ vector_length);
1832
+
1833
+ vbool64_t greater_b64 = __riscv_vmslt_vv_i64m1_b64(max_i64m1, data_i64m1, vector_length);
1834
+ max_i64m1 = __riscv_vmerge_vvm_i64m1_tu(max_i64m1, max_i64m1, data_i64m1, greater_b64, vector_length);
1835
+ max_indices_u64m1 = __riscv_vmerge_vvm_u64m1_tu(max_indices_u64m1, max_indices_u64m1, pos_u64m1, greater_b64,
1836
+ vector_length);
1837
+ }
1838
+
1839
+ // Horizontal reduction for min
1840
+ vint64m1_t init_max_i64m1 = __riscv_vmv_v_x_i64m1(NK_I64_MAX, 1);
1841
+ nk_i64_t min_val = __riscv_vmv_x_s_i64m1_i64(__riscv_vredmin_vs_i64m1_i64m1(min_i64m1, init_max_i64m1, vlmax));
1842
+ vbool64_t min_match_b64 = __riscv_vmseq_vx_i64m1_b64(min_i64m1, min_val, vlmax);
1843
+ vuint64m1_t sentinel_u64m1 = __riscv_vmv_v_x_u64m1(NK_U64_MAX, vlmax);
1844
+ vuint64m1_t min_cands_u64m1 = __riscv_vmerge_vvm_u64m1(sentinel_u64m1, min_indices_u64m1, min_match_b64, vlmax);
1845
+ vuint64m1_t init_umax_u64m1 = __riscv_vmv_v_x_u64m1(NK_U64_MAX, 1);
1846
+ *min_value_ptr = min_val;
1847
+ *min_index_ptr = (nk_size_t)__riscv_vmv_x_s_u64m1_u64(
1848
+ __riscv_vredminu_vs_u64m1_u64m1(min_cands_u64m1, init_umax_u64m1, vlmax));
1849
+
1850
+ // Horizontal reduction for max
1851
+ vint64m1_t init_min_i64m1 = __riscv_vmv_v_x_i64m1(NK_I64_MIN, 1);
1852
+ nk_i64_t max_val = __riscv_vmv_x_s_i64m1_i64(__riscv_vredmax_vs_i64m1_i64m1(max_i64m1, init_min_i64m1, vlmax));
1853
+ vbool64_t max_match_b64 = __riscv_vmseq_vx_i64m1_b64(max_i64m1, max_val, vlmax);
1854
+ vuint64m1_t max_cands_u64m1 = __riscv_vmerge_vvm_u64m1(sentinel_u64m1, max_indices_u64m1, max_match_b64, vlmax);
1855
+ *max_value_ptr = max_val;
1856
+ *max_index_ptr = (nk_size_t)__riscv_vmv_x_s_u64m1_u64(
1857
+ __riscv_vredminu_vs_u64m1_u64m1(max_cands_u64m1, init_umax_u64m1, vlmax));
1858
+ }
1859
+
1860
+ NK_PUBLIC void nk_reduce_minmax_i64_rvv( //
1861
+ nk_i64_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
1862
+ nk_i64_t *min_value_ptr, nk_size_t *min_index_ptr, //
1863
+ nk_i64_t *max_value_ptr, nk_size_t *max_index_ptr) {
1864
+ nk_size_t stride_elements = stride_bytes / sizeof(nk_i64_t);
1865
+ int aligned = (stride_bytes % sizeof(nk_i64_t) == 0);
1866
+
1867
+ if (count == 0)
1868
+ *min_value_ptr = NK_I64_MAX, *min_index_ptr = NK_SIZE_MAX, *max_value_ptr = NK_I64_MIN,
1869
+ *max_index_ptr = NK_SIZE_MAX;
1870
+ else if (!aligned)
1871
+ nk_reduce_minmax_i64_serial(data_ptr, count, stride_bytes, min_value_ptr, min_index_ptr, max_value_ptr,
1872
+ max_index_ptr);
1873
+ else if (stride_elements == 1)
1874
+ nk_reduce_minmax_i64_rvv_contiguous_(data_ptr, count, min_value_ptr, min_index_ptr, max_value_ptr,
1875
+ max_index_ptr);
1876
+ else
1877
+ nk_reduce_minmax_i64_rvv_strided_(data_ptr, count, stride_bytes, min_value_ptr, min_index_ptr, max_value_ptr,
1878
+ max_index_ptr);
1879
+ }
1880
+
1881
+ NK_INTERNAL void nk_reduce_moments_u64_rvv_contiguous_( //
1882
+ nk_u64_t const *data_ptr, nk_size_t count, //
1883
+ nk_u64_t *sum_ptr, nk_u64_t *sumsq_ptr) {
1884
+ nk_size_t vlmax = __riscv_vsetvlmax_e64m1();
1885
+ vuint64m1_t sum_u64m1 = __riscv_vmv_v_x_u64m1(0, vlmax);
1886
+ vuint64m1_t sumsq_u64m1 = __riscv_vmv_v_x_u64m1(0, vlmax);
1887
+
1888
+ for (nk_size_t vector_length; count > 0; count -= vector_length, data_ptr += vector_length) {
1889
+ vector_length = __riscv_vsetvl_e64m1(count);
1890
+ vuint64m1_t data_u64m1 = __riscv_vle64_v_u64m1(data_ptr, vector_length);
1891
+
1892
+ // Saturating unsigned sum
1893
+ sum_u64m1 = __riscv_vsaddu_vv_u64m1_tu(sum_u64m1, sum_u64m1, data_u64m1, vector_length);
1894
+
1895
+ // Sumsq: u64 × u64 with overflow detection via vmul + vmulhu
1896
+ vuint64m1_t product_low_u64m1 = __riscv_vmul_vv_u64m1(data_u64m1, data_u64m1, vector_length);
1897
+ vuint64m1_t product_high_u64m1 = __riscv_vmulhu_vv_u64m1(data_u64m1, data_u64m1, vector_length);
1898
+ vbool64_t overflow_b64 = __riscv_vmsne_vx_u64m1_b64(product_high_u64m1, 0, vector_length);
1899
+ vuint64m1_t squares_u64m1 = __riscv_vmerge_vxm_u64m1(product_low_u64m1, NK_U64_MAX, overflow_b64,
1900
+ vector_length);
1901
+ sumsq_u64m1 = __riscv_vsaddu_vv_u64m1_tu(sumsq_u64m1, sumsq_u64m1, squares_u64m1, vector_length);
1902
+ }
1903
+
1904
+ *sum_ptr = nk_reduce_vsaddu_u64m1_rvv_(sum_u64m1, vlmax);
1905
+ *sumsq_ptr = nk_reduce_vsaddu_u64m1_rvv_(sumsq_u64m1, vlmax);
1906
+ }
1907
+
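The vmul/vmulhu pair above forms the full 128-bit square and clamps it to NK_U64_MAX whenever the high half is non-zero. A hedged scalar equivalent, using a 128-bit integer type purely for illustration:

static nk_u64_t saturated_square_u64_(nk_u64_t value) {
    unsigned __int128 product = (unsigned __int128)value * value;
    nk_u64_t high = (nk_u64_t)(product >> 64);   // what vmulhu yields
    nk_u64_t low = (nk_u64_t)product;            // what vmul yields
    return high != 0 ? NK_U64_MAX : low;         // clamp on overflow
}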
1908
+ NK_INTERNAL void nk_reduce_moments_u64_rvv_strided_( //
1909
+ nk_u64_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
1910
+ nk_u64_t *sum_ptr, nk_u64_t *sumsq_ptr) {
1911
+ nk_size_t vlmax = __riscv_vsetvlmax_e64m1();
1912
+ vuint64m1_t sum_u64m1 = __riscv_vmv_v_x_u64m1(0, vlmax);
1913
+ vuint64m1_t sumsq_u64m1 = __riscv_vmv_v_x_u64m1(0, vlmax);
1914
+ unsigned char const *ptr = (unsigned char const *)data_ptr;
1915
+
1916
+ for (nk_size_t vector_length; count > 0; count -= vector_length, ptr += vector_length * stride_bytes) {
1917
+ vector_length = __riscv_vsetvl_e64m1(count);
1918
+ vuint64m1_t data_u64m1 = __riscv_vlse64_v_u64m1((nk_u64_t const *)ptr, (nk_ssize_t)stride_bytes, vector_length);
1919
+
1920
+ // Saturating unsigned sum
1921
+ sum_u64m1 = __riscv_vsaddu_vv_u64m1_tu(sum_u64m1, sum_u64m1, data_u64m1, vector_length);
1922
+
1923
+ // Sumsq: u64 × u64 with overflow detection via vmul + vmulhu
1924
+ vuint64m1_t product_low_u64m1 = __riscv_vmul_vv_u64m1(data_u64m1, data_u64m1, vector_length);
1925
+ vuint64m1_t product_high_u64m1 = __riscv_vmulhu_vv_u64m1(data_u64m1, data_u64m1, vector_length);
1926
+ vbool64_t overflow_b64 = __riscv_vmsne_vx_u64m1_b64(product_high_u64m1, 0, vector_length);
1927
+ vuint64m1_t squares_u64m1 = __riscv_vmerge_vxm_u64m1(product_low_u64m1, NK_U64_MAX, overflow_b64,
1928
+ vector_length);
1929
+ sumsq_u64m1 = __riscv_vsaddu_vv_u64m1_tu(sumsq_u64m1, sumsq_u64m1, squares_u64m1, vector_length);
1930
+ }
1931
+
1932
+ *sum_ptr = nk_reduce_vsaddu_u64m1_rvv_(sum_u64m1, vlmax);
1933
+ *sumsq_ptr = nk_reduce_vsaddu_u64m1_rvv_(sumsq_u64m1, vlmax);
1934
+ }
1935
+
1936
+ NK_PUBLIC void nk_reduce_moments_u64_rvv( //
1937
+ nk_u64_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
1938
+ nk_u64_t *sum_ptr, nk_u64_t *sumsq_ptr) {
1939
+ nk_size_t stride_elements = stride_bytes / sizeof(nk_u64_t);
1940
+ int aligned = (stride_bytes % sizeof(nk_u64_t) == 0);
1941
+
1942
+ if (count == 0) *sum_ptr = 0, *sumsq_ptr = 0;
1943
+ else if (!aligned) { nk_reduce_moments_u64_serial(data_ptr, count, stride_bytes, sum_ptr, sumsq_ptr); }
1944
+ else if (stride_elements == 1) { nk_reduce_moments_u64_rvv_contiguous_(data_ptr, count, sum_ptr, sumsq_ptr); }
1945
+ else { nk_reduce_moments_u64_rvv_strided_(data_ptr, count, stride_bytes, sum_ptr, sumsq_ptr); }
1946
+ }
1947
+
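Unlike the i64 path, which widens the running sum to 128 bits, the u64 kernels above simply saturate with vsaddu, so both sum and sumsq degrade to NK_U64_MAX instead of wrapping. A hedged scalar equivalent of that accumulation step:

static nk_u64_t saturating_add_u64_(nk_u64_t a, nk_u64_t b) {
    nk_u64_t sum = a + b;               // unsigned add wraps on overflow
    return sum < a ? NK_U64_MAX : sum;  // detect the wrap and clamp instead
}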
1948
+ NK_INTERNAL void nk_reduce_minmax_u64_rvv_contiguous_( //
1949
+ nk_u64_t const *data_ptr, nk_size_t count, //
1950
+ nk_u64_t *min_value_ptr, nk_size_t *min_index_ptr, //
1951
+ nk_u64_t *max_value_ptr, nk_size_t *max_index_ptr) {
1952
+ nk_size_t vlmax = __riscv_vsetvlmax_e64m1();
1953
+ vuint64m1_t min_u64m1 = __riscv_vmv_v_x_u64m1(NK_U64_MAX, vlmax);
1954
+ vuint64m1_t max_u64m1 = __riscv_vmv_v_x_u64m1(NK_U64_MIN, vlmax);
1955
+ vuint64m1_t min_indices_u64m1 = __riscv_vmv_v_x_u64m1(0, vlmax);
1956
+ vuint64m1_t max_indices_u64m1 = __riscv_vmv_v_x_u64m1(0, vlmax);
1957
+
1958
+ nk_size_t offset = 0;
1959
+ for (nk_size_t vector_length; count > 0;
1960
+ count -= vector_length, offset += vector_length, data_ptr += vector_length) {
1961
+ vector_length = __riscv_vsetvl_e64m1(count);
1962
+ vuint64m1_t data_u64m1 = __riscv_vle64_v_u64m1(data_ptr, vector_length);
1963
+ vuint64m1_t pos_u64m1 = __riscv_vadd_vx_u64m1(__riscv_vid_v_u64m1(vector_length), (nk_u64_t)offset,
1964
+ vector_length);
1965
+
1966
+ vbool64_t less_b64 = __riscv_vmsltu_vv_u64m1_b64(data_u64m1, min_u64m1, vector_length);
1967
+ min_u64m1 = __riscv_vmerge_vvm_u64m1_tu(min_u64m1, min_u64m1, data_u64m1, less_b64, vector_length);
1968
+ min_indices_u64m1 = __riscv_vmerge_vvm_u64m1_tu(min_indices_u64m1, min_indices_u64m1, pos_u64m1, less_b64,
1969
+ vector_length);
1970
+
1971
+ vbool64_t greater_b64 = __riscv_vmsltu_vv_u64m1_b64(max_u64m1, data_u64m1, vector_length);
1972
+ max_u64m1 = __riscv_vmerge_vvm_u64m1_tu(max_u64m1, max_u64m1, data_u64m1, greater_b64, vector_length);
1973
+ max_indices_u64m1 = __riscv_vmerge_vvm_u64m1_tu(max_indices_u64m1, max_indices_u64m1, pos_u64m1, greater_b64,
1974
+ vector_length);
1975
+ }
1976
+
1977
+ // Horizontal reduction for min
1978
+ vuint64m1_t init_max_u64m1 = __riscv_vmv_v_x_u64m1(NK_U64_MAX, 1);
1979
+ nk_u64_t min_val = __riscv_vmv_x_s_u64m1_u64(__riscv_vredminu_vs_u64m1_u64m1(min_u64m1, init_max_u64m1, vlmax));
1980
+ vbool64_t min_match_b64 = __riscv_vmseq_vx_u64m1_b64(min_u64m1, min_val, vlmax);
1981
+ vuint64m1_t sentinel_u64m1 = __riscv_vmv_v_x_u64m1(NK_U64_MAX, vlmax);
1982
+ vuint64m1_t min_cands_u64m1 = __riscv_vmerge_vvm_u64m1(sentinel_u64m1, min_indices_u64m1, min_match_b64, vlmax);
1983
+ vuint64m1_t init_umax_u64m1 = __riscv_vmv_v_x_u64m1(NK_U64_MAX, 1);
1984
+ *min_value_ptr = min_val;
1985
+ *min_index_ptr = (nk_size_t)__riscv_vmv_x_s_u64m1_u64(
1986
+ __riscv_vredminu_vs_u64m1_u64m1(min_cands_u64m1, init_umax_u64m1, vlmax));
1987
+
1988
+ // Horizontal reduction for max
1989
+ vuint64m1_t init_min_u64m1 = __riscv_vmv_v_x_u64m1(NK_U64_MIN, 1);
1990
+ nk_u64_t max_val = __riscv_vmv_x_s_u64m1_u64(__riscv_vredmaxu_vs_u64m1_u64m1(max_u64m1, init_min_u64m1, vlmax));
1991
+ vbool64_t max_match_b64 = __riscv_vmseq_vx_u64m1_b64(max_u64m1, max_val, vlmax);
1992
+ vuint64m1_t max_cands_u64m1 = __riscv_vmerge_vvm_u64m1(sentinel_u64m1, max_indices_u64m1, max_match_b64, vlmax);
1993
+ *max_value_ptr = max_val;
1994
+ *max_index_ptr = (nk_size_t)__riscv_vmv_x_s_u64m1_u64(
1995
+ __riscv_vredminu_vs_u64m1_u64m1(max_cands_u64m1, init_umax_u64m1, vlmax));
1996
+ }
1997
+
1998
+ NK_INTERNAL void nk_reduce_minmax_u64_rvv_strided_( //
1999
+ nk_u64_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
2000
+ nk_u64_t *min_value_ptr, nk_size_t *min_index_ptr, //
2001
+ nk_u64_t *max_value_ptr, nk_size_t *max_index_ptr) {
2002
+ nk_size_t vlmax = __riscv_vsetvlmax_e64m1();
2003
+ vuint64m1_t min_u64m1 = __riscv_vmv_v_x_u64m1(NK_U64_MAX, vlmax);
2004
+ vuint64m1_t max_u64m1 = __riscv_vmv_v_x_u64m1(NK_U64_MIN, vlmax);
2005
+ vuint64m1_t min_indices_u64m1 = __riscv_vmv_v_x_u64m1(0, vlmax);
2006
+ vuint64m1_t max_indices_u64m1 = __riscv_vmv_v_x_u64m1(0, vlmax);
2007
+ unsigned char const *ptr = (unsigned char const *)data_ptr;
2008
+
2009
+ nk_size_t offset = 0;
2010
+ for (nk_size_t vector_length; count > 0;
2011
+ count -= vector_length, offset += vector_length, ptr += vector_length * stride_bytes) {
2012
+ vector_length = __riscv_vsetvl_e64m1(count);
2013
+ vuint64m1_t data_u64m1 = __riscv_vlse64_v_u64m1((nk_u64_t const *)ptr, (nk_ssize_t)stride_bytes, vector_length);
2014
+ vuint64m1_t pos_u64m1 = __riscv_vadd_vx_u64m1(__riscv_vid_v_u64m1(vector_length), (nk_u64_t)offset,
2015
+ vector_length);
2016
+
2017
+ vbool64_t less_b64 = __riscv_vmsltu_vv_u64m1_b64(data_u64m1, min_u64m1, vector_length);
2018
+ min_u64m1 = __riscv_vmerge_vvm_u64m1_tu(min_u64m1, min_u64m1, data_u64m1, less_b64, vector_length);
2019
+ min_indices_u64m1 = __riscv_vmerge_vvm_u64m1_tu(min_indices_u64m1, min_indices_u64m1, pos_u64m1, less_b64,
2020
+ vector_length);
2021
+
2022
+ vbool64_t greater_b64 = __riscv_vmsltu_vv_u64m1_b64(max_u64m1, data_u64m1, vector_length);
2023
+ max_u64m1 = __riscv_vmerge_vvm_u64m1_tu(max_u64m1, max_u64m1, data_u64m1, greater_b64, vector_length);
2024
+ max_indices_u64m1 = __riscv_vmerge_vvm_u64m1_tu(max_indices_u64m1, max_indices_u64m1, pos_u64m1, greater_b64,
2025
+ vector_length);
2026
+ }
2027
+
2028
+ // Horizontal reduction for min
2029
+ vuint64m1_t init_max_u64m1 = __riscv_vmv_v_x_u64m1(NK_U64_MAX, 1);
2030
+ nk_u64_t min_val = __riscv_vmv_x_s_u64m1_u64(__riscv_vredminu_vs_u64m1_u64m1(min_u64m1, init_max_u64m1, vlmax));
2031
+ vbool64_t min_match_b64 = __riscv_vmseq_vx_u64m1_b64(min_u64m1, min_val, vlmax);
2032
+ vuint64m1_t sentinel_u64m1 = __riscv_vmv_v_x_u64m1(NK_U64_MAX, vlmax);
2033
+ vuint64m1_t min_cands_u64m1 = __riscv_vmerge_vvm_u64m1(sentinel_u64m1, min_indices_u64m1, min_match_b64, vlmax);
2034
+ vuint64m1_t init_umax_u64m1 = __riscv_vmv_v_x_u64m1(NK_U64_MAX, 1);
2035
+ *min_value_ptr = min_val;
2036
+ *min_index_ptr = (nk_size_t)__riscv_vmv_x_s_u64m1_u64(
2037
+ __riscv_vredminu_vs_u64m1_u64m1(min_cands_u64m1, init_umax_u64m1, vlmax));
2038
+
2039
+ // Horizontal reduction for max
2040
+ vuint64m1_t init_min_u64m1 = __riscv_vmv_v_x_u64m1(NK_U64_MIN, 1);
2041
+ nk_u64_t max_val = __riscv_vmv_x_s_u64m1_u64(__riscv_vredmaxu_vs_u64m1_u64m1(max_u64m1, init_min_u64m1, vlmax));
2042
+ vbool64_t max_match_b64 = __riscv_vmseq_vx_u64m1_b64(max_u64m1, max_val, vlmax);
2043
+ vuint64m1_t max_cands_u64m1 = __riscv_vmerge_vvm_u64m1(sentinel_u64m1, max_indices_u64m1, max_match_b64, vlmax);
2044
+ *max_value_ptr = max_val;
2045
+ *max_index_ptr = (nk_size_t)__riscv_vmv_x_s_u64m1_u64(
2046
+ __riscv_vredminu_vs_u64m1_u64m1(max_cands_u64m1, init_umax_u64m1, vlmax));
2047
+ }
2048
+
2049
+ NK_PUBLIC void nk_reduce_minmax_u64_rvv( //
2050
+ nk_u64_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
2051
+ nk_u64_t *min_value_ptr, nk_size_t *min_index_ptr, //
2052
+ nk_u64_t *max_value_ptr, nk_size_t *max_index_ptr) {
2053
+ nk_size_t stride_elements = stride_bytes / sizeof(nk_u64_t);
2054
+ int aligned = (stride_bytes % sizeof(nk_u64_t) == 0);
2055
+
2056
+ if (count == 0)
2057
+ *min_value_ptr = NK_U64_MAX, *min_index_ptr = NK_SIZE_MAX, *max_value_ptr = NK_U64_MIN,
2058
+ *max_index_ptr = NK_SIZE_MAX;
2059
+ else if (!aligned)
2060
+ nk_reduce_minmax_u64_serial(data_ptr, count, stride_bytes, min_value_ptr, min_index_ptr, max_value_ptr,
2061
+ max_index_ptr);
2062
+ else if (stride_elements == 1)
2063
+ nk_reduce_minmax_u64_rvv_contiguous_(data_ptr, count, min_value_ptr, min_index_ptr, max_value_ptr,
2064
+ max_index_ptr);
2065
+ else
2066
+ nk_reduce_minmax_u64_rvv_strided_(data_ptr, count, stride_bytes, min_value_ptr, min_index_ptr, max_value_ptr,
2067
+ max_index_ptr);
2068
+ }
2069
+
2070
+ NK_INTERNAL void nk_reduce_moments_bf16_rvv_contiguous_( //
2071
+ nk_bf16_t const *data_ptr, nk_size_t count, //
2072
+ nk_f32_t *sum_ptr, nk_f32_t *sumsq_ptr) {
2073
+ nk_size_t vlmax = __riscv_vsetvlmax_e64m4();
2074
+ vfloat64m4_t sum_f64m4 = __riscv_vfmv_v_f_f64m4(0.0, vlmax);
2075
+ vfloat64m4_t sumsq_f64m4 = __riscv_vfmv_v_f_f64m4(0.0, vlmax);
2076
+
2077
+ for (nk_size_t vector_length; count > 0; count -= vector_length, data_ptr += vector_length) {
2078
+ vector_length = __riscv_vsetvl_e16m1(count);
2079
+ vuint16m1_t data_u16m1 = __riscv_vle16_v_u16m1((uint16_t const *)data_ptr, vector_length);
2080
+
2081
+ // Convert bf16 → f32 (m1 → m2)
2082
+ vfloat32m2_t data_f32m2 = nk_bf16m1_to_f32m2_rvv_(data_u16m1, vector_length);
2083
+
2084
+ // Widen f32 → f64 (m2 → m4)
2085
+ vfloat64m4_t data_f64m4 = __riscv_vfwcvt_f_f_v_f64m4(data_f32m2, vector_length);
2086
+ sum_f64m4 = __riscv_vfadd_vv_f64m4_tu(sum_f64m4, sum_f64m4, data_f64m4, vector_length);
2087
+
2088
+ // Sumsq via widening FMA: f32×f32 → f64
2089
+ sumsq_f64m4 = __riscv_vfwmacc_vv_f64m4_tu(sumsq_f64m4, data_f32m2, data_f32m2, vector_length);
2090
+ }
2091
+
2092
+ // Horizontal reduction
2093
+ vfloat64m1_t zero_f64m1 = __riscv_vfmv_v_f_f64m1(0.0, 1);
2094
+ *sum_ptr = __riscv_vfmv_f_s_f64m1_f64(__riscv_vfredusum_vs_f64m4_f64m1(sum_f64m4, zero_f64m1, vlmax)),
2095
+ *sumsq_ptr = __riscv_vfmv_f_s_f64m1_f64(__riscv_vfredusum_vs_f64m4_f64m1(sumsq_f64m4, zero_f64m1, vlmax));
2096
+ }
2097
+
2098
+ NK_INTERNAL void nk_reduce_moments_bf16_rvv_strided_( //
2099
+ nk_bf16_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
2100
+ nk_f32_t *sum_ptr, nk_f32_t *sumsq_ptr) {
2101
+ nk_size_t vlmax = __riscv_vsetvlmax_e64m4();
2102
+ vfloat64m4_t sum_f64m4 = __riscv_vfmv_v_f_f64m4(0.0, vlmax);
2103
+ vfloat64m4_t sumsq_f64m4 = __riscv_vfmv_v_f_f64m4(0.0, vlmax);
2104
+ unsigned char const *ptr = (unsigned char const *)data_ptr;
2105
+
2106
+ for (nk_size_t vector_length; count > 0; count -= vector_length, ptr += vector_length * stride_bytes) {
2107
+ vector_length = __riscv_vsetvl_e16m1(count);
2108
+ vuint16m1_t data_u16m1 = __riscv_vlse16_v_u16m1((uint16_t const *)ptr, (nk_ssize_t)stride_bytes, vector_length);
2109
+
2110
+ // Convert bf16 → f32 (m1 → m2)
2111
+ vfloat32m2_t data_f32m2 = nk_bf16m1_to_f32m2_rvv_(data_u16m1, vector_length);
2112
+
2113
+ // Widen f32 → f64 (m2 → m4)
2114
+ vfloat64m4_t data_f64m4 = __riscv_vfwcvt_f_f_v_f64m4(data_f32m2, vector_length);
2115
+ sum_f64m4 = __riscv_vfadd_vv_f64m4_tu(sum_f64m4, sum_f64m4, data_f64m4, vector_length);
2116
+
2117
+ // Sumsq via widening FMA: f32×f32 → f64
2118
+ sumsq_f64m4 = __riscv_vfwmacc_vv_f64m4_tu(sumsq_f64m4, data_f32m2, data_f32m2, vector_length);
2119
+ }
2120
+
2121
+ // Horizontal reduction
2122
+ vfloat64m1_t zero_f64m1 = __riscv_vfmv_v_f_f64m1(0.0, 1);
2123
+ *sum_ptr = __riscv_vfmv_f_s_f64m1_f64(__riscv_vfredusum_vs_f64m4_f64m1(sum_f64m4, zero_f64m1, vlmax)),
2124
+ *sumsq_ptr = __riscv_vfmv_f_s_f64m1_f64(__riscv_vfredusum_vs_f64m4_f64m1(sumsq_f64m4, zero_f64m1, vlmax));
2125
+ }
2126
+
2127
+ NK_PUBLIC void nk_reduce_moments_bf16_rvv( //
2128
+ nk_bf16_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
2129
+ nk_f32_t *sum_ptr, nk_f32_t *sumsq_ptr) {
2130
+ nk_size_t stride_elements = stride_bytes / sizeof(nk_bf16_t);
2131
+ int aligned = (stride_bytes % sizeof(nk_bf16_t) == 0);
2132
+
2133
+ if (count == 0) *sum_ptr = 0.0f, *sumsq_ptr = 0.0f;
2134
+ else if (!aligned) nk_reduce_moments_bf16_serial(data_ptr, count, stride_bytes, sum_ptr, sumsq_ptr);
2135
+ else if (stride_elements == 1) nk_reduce_moments_bf16_rvv_contiguous_(data_ptr, count, sum_ptr, sumsq_ptr);
2136
+ else nk_reduce_moments_bf16_rvv_strided_(data_ptr, count, stride_bytes, sum_ptr, sumsq_ptr);
2137
+ }
2138
+
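nk_bf16m1_to_f32m2_rvv_ is provided elsewhere in this header; the kernels above only rely on widening bf16 to f32 exactly and then accumulating in f64. Assuming bf16 stores the upper 16 bits of an IEEE-754 binary32, a scalar sketch of that widening is:

static nk_f32_t bf16_bits_to_f32_(nk_u16_t bits) {
    union { nk_u32_t u; nk_f32_t f; } pun;
    pun.u = (nk_u32_t)bits << 16;   // bf16 payload occupies the high half of binary32
    return pun.f;
}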
2139
+ NK_INTERNAL void nk_reduce_minmax_bf16_rvv_contiguous_( //
2140
+ nk_bf16_t const *data_ptr, nk_size_t count, //
2141
+ nk_bf16_t *min_value_ptr, nk_size_t *min_index_ptr, //
2142
+ nk_bf16_t *max_value_ptr, nk_size_t *max_index_ptr) {
2143
+ nk_size_t vlmax = __riscv_vsetvlmax_e16m1();
2144
+ vuint16m1_t min_u16m1 = __riscv_vmv_v_x_u16m1(0x7F80, vlmax); // +inf in bf16
2145
+ vuint16m1_t max_u16m1 = __riscv_vmv_v_x_u16m1(0xFF80, vlmax); // -inf in bf16
2146
+ vuint64m4_t min_indices_u64m4 = __riscv_vmv_v_x_u64m4(0, vlmax);
2147
+ vuint64m4_t max_indices_u64m4 = __riscv_vmv_v_x_u64m4(0, vlmax);
2148
+
2149
+ nk_size_t offset = 0;
2150
+ for (nk_size_t vector_length; count > 0;
2151
+ count -= vector_length, offset += vector_length, data_ptr += vector_length) {
2152
+ vector_length = __riscv_vsetvl_e16m1(count);
2153
+ vuint16m1_t data_u16m1 = __riscv_vle16_v_u16m1((uint16_t const *)data_ptr, vector_length);
2154
+ vuint64m4_t pos_u64m4 = __riscv_vadd_vx_u64m4(__riscv_vid_v_u64m4(vector_length), (nk_u64_t)offset,
2155
+ vector_length);
2156
+
2157
+ // Convert to f32 for comparison
2158
+ vfloat32m2_t data_f32m2 = nk_bf16m1_to_f32m2_rvv_(data_u16m1, vector_length);
2159
+ vfloat32m2_t min_f32m2 = nk_bf16m1_to_f32m2_rvv_(min_u16m1, vector_length);
2160
+ vfloat32m2_t max_f32m2 = nk_bf16m1_to_f32m2_rvv_(max_u16m1, vector_length);
2161
+
2162
+ vbool16_t less_b16 = __riscv_vmflt_vv_f32m2_b16(data_f32m2, min_f32m2, vector_length);
2163
+ min_u16m1 = __riscv_vmerge_vvm_u16m1_tu(min_u16m1, min_u16m1, data_u16m1, less_b16, vector_length);
2164
+ min_indices_u64m4 = __riscv_vmerge_vvm_u64m4_tu(min_indices_u64m4, min_indices_u64m4, pos_u64m4, less_b16,
2165
+ vector_length);
2166
+
2167
+ vbool16_t greater_b16 = __riscv_vmflt_vv_f32m2_b16(max_f32m2, data_f32m2, vector_length);
2168
+ max_u16m1 = __riscv_vmerge_vvm_u16m1_tu(max_u16m1, max_u16m1, data_u16m1, greater_b16, vector_length);
2169
+ max_indices_u64m4 = __riscv_vmerge_vvm_u64m4_tu(max_indices_u64m4, max_indices_u64m4, pos_u64m4, greater_b16,
2170
+ vector_length);
2171
+ }
2172
+
2173
+ // Horizontal reduction
2174
+ vfloat32m2_t final_min_f32m2 = nk_bf16m1_to_f32m2_rvv_(min_u16m1, vlmax);
2175
+ vfloat32m1_t init_max_f32m1 = __riscv_vfmv_v_f_f32m1(NK_F32_MAX, 1);
2176
+ nk_f32_t min_val_f32 = __riscv_vfmv_f_s_f32m1_f32(
2177
+ __riscv_vfredmin_vs_f32m2_f32m1(final_min_f32m2, init_max_f32m1, vlmax));
2178
+ vfloat32m2_t final_max_f32m2 = nk_bf16m1_to_f32m2_rvv_(max_u16m1, vlmax);
2179
+ vfloat32m1_t init_min_f32m1 = __riscv_vfmv_v_f_f32m1(NK_F32_MIN, 1);
2180
+ nk_f32_t max_val_f32 = __riscv_vfmv_f_s_f32m1_f32(
2181
+ __riscv_vfredmax_vs_f32m2_f32m1(final_max_f32m2, init_min_f32m1, vlmax));
2182
+ if (min_val_f32 == NK_F32_MAX && max_val_f32 == NK_F32_MIN) {
2183
+ *min_value_ptr = NK_BF16_MAX, *min_index_ptr = NK_SIZE_MAX, *max_value_ptr = NK_BF16_MIN,
2184
+ *max_index_ptr = NK_SIZE_MAX;
2185
+ return;
2186
+ }
2187
+
2188
+ vfloat32m2_t converted_f32m2 = nk_bf16m1_to_f32m2_rvv_(min_u16m1, vlmax);
2189
+ vbool16_t min_match_b16 = __riscv_vmfeq_vf_f32m2_b16(converted_f32m2, min_val_f32, vlmax);
2190
+ vuint64m4_t sentinel_u64m4 = __riscv_vmv_v_x_u64m4(NK_U64_MAX, vlmax);
2191
+ vuint64m4_t min_cands_u64m4 = __riscv_vmerge_vvm_u64m4(sentinel_u64m4, min_indices_u64m4, min_match_b16, vlmax);
2192
+ vuint64m1_t init_umax_u64m1 = __riscv_vmv_v_x_u64m1(NK_U64_MAX, 1);
2193
+
2194
+ nk_u16_t min_raw = __riscv_vmv_x_s_u16m1_u16(
2195
+ __riscv_vslidedown_vx_u16m1(min_u16m1, (nk_size_t)__riscv_vfirst_m_b16(min_match_b16, vlmax), vlmax));
2196
+ *min_value_ptr = *(nk_bf16_t *)&min_raw;
2197
+ *min_index_ptr = (nk_size_t)__riscv_vmv_x_s_u64m1_u64(
2198
+ __riscv_vredminu_vs_u64m4_u64m1(min_cands_u64m4, init_umax_u64m1, vlmax));
2199
+
2200
+ vbool16_t max_match_b16 = __riscv_vmfeq_vf_f32m2_b16(nk_bf16m1_to_f32m2_rvv_(max_u16m1, vlmax), max_val_f32, vlmax);
2201
+ vuint64m4_t max_cands_u64m4 = __riscv_vmerge_vvm_u64m4(sentinel_u64m4, max_indices_u64m4, max_match_b16, vlmax);
2202
+
2203
+ nk_u16_t max_raw = __riscv_vmv_x_s_u16m1_u16(
2204
+ __riscv_vslidedown_vx_u16m1(max_u16m1, (nk_size_t)__riscv_vfirst_m_b16(max_match_b16, vlmax), vlmax));
2205
+ *max_value_ptr = *(nk_bf16_t *)&max_raw;
2206
+ *max_index_ptr = (nk_size_t)__riscv_vmv_x_s_u64m1_u64(
2207
+ __riscv_vredminu_vs_u64m4_u64m1(max_cands_u64m4, init_umax_u64m1, vlmax));
2208
+ }
2209
+
2210
+ NK_INTERNAL void nk_reduce_minmax_bf16_rvv_strided_( //
2211
+ nk_bf16_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
2212
+ nk_bf16_t *min_value_ptr, nk_size_t *min_index_ptr, //
2213
+ nk_bf16_t *max_value_ptr, nk_size_t *max_index_ptr) {
2214
+ nk_size_t vlmax = __riscv_vsetvlmax_e16m1();
2215
+ vuint16m1_t min_u16m1 = __riscv_vmv_v_x_u16m1(0x7F80, vlmax); // +inf in bf16
2216
+ vuint16m1_t max_u16m1 = __riscv_vmv_v_x_u16m1(0xFF80, vlmax); // -inf in bf16
2217
+ vuint64m4_t min_indices_u64m4 = __riscv_vmv_v_x_u64m4(0, vlmax);
2218
+ vuint64m4_t max_indices_u64m4 = __riscv_vmv_v_x_u64m4(0, vlmax);
2219
+ unsigned char const *ptr = (unsigned char const *)data_ptr;
2220
+
2221
+ nk_size_t offset = 0;
2222
+ for (nk_size_t vector_length; count > 0;
2223
+ count -= vector_length, offset += vector_length, ptr += vector_length * stride_bytes) {
2224
+ vector_length = __riscv_vsetvl_e16m1(count);
2225
+ vuint16m1_t data_u16m1 = __riscv_vlse16_v_u16m1((uint16_t const *)ptr, (nk_ssize_t)stride_bytes, vector_length);
2226
+ vuint64m4_t pos_u64m4 = __riscv_vadd_vx_u64m4(__riscv_vid_v_u64m4(vector_length), (nk_u64_t)offset,
2227
+ vector_length);
2228
+
2229
+ // Convert to f32 for comparison
2230
+ vfloat32m2_t data_f32m2 = nk_bf16m1_to_f32m2_rvv_(data_u16m1, vector_length);
2231
+ vfloat32m2_t min_f32m2 = nk_bf16m1_to_f32m2_rvv_(min_u16m1, vector_length);
2232
+ vfloat32m2_t max_f32m2 = nk_bf16m1_to_f32m2_rvv_(max_u16m1, vector_length);
2233
+
2234
+ vbool16_t less_b16 = __riscv_vmflt_vv_f32m2_b16(data_f32m2, min_f32m2, vector_length);
2235
+ min_u16m1 = __riscv_vmerge_vvm_u16m1_tu(min_u16m1, min_u16m1, data_u16m1, less_b16, vector_length);
2236
+ min_indices_u64m4 = __riscv_vmerge_vvm_u64m4_tu(min_indices_u64m4, min_indices_u64m4, pos_u64m4, less_b16,
2237
+ vector_length);
2238
+
2239
+ vbool16_t greater_b16 = __riscv_vmflt_vv_f32m2_b16(max_f32m2, data_f32m2, vector_length);
2240
+ max_u16m1 = __riscv_vmerge_vvm_u16m1_tu(max_u16m1, max_u16m1, data_u16m1, greater_b16, vector_length);
2241
+ max_indices_u64m4 = __riscv_vmerge_vvm_u64m4_tu(max_indices_u64m4, max_indices_u64m4, pos_u64m4, greater_b16,
2242
+ vector_length);
2243
+ }
2244
+
2245
+ // Horizontal reduction (same as contiguous)
2246
+ vfloat32m2_t final_min_f32m2 = nk_bf16m1_to_f32m2_rvv_(min_u16m1, vlmax);
2247
+ vfloat32m1_t init_max_f32m1 = __riscv_vfmv_v_f_f32m1(NK_F32_MAX, 1);
2248
+ nk_f32_t min_val_f32 = __riscv_vfmv_f_s_f32m1_f32(
2249
+ __riscv_vfredmin_vs_f32m2_f32m1(final_min_f32m2, init_max_f32m1, vlmax));
2250
+ vfloat32m2_t final_max_f32m2 = nk_bf16m1_to_f32m2_rvv_(max_u16m1, vlmax);
2251
+ vfloat32m1_t init_min_f32m1 = __riscv_vfmv_v_f_f32m1(NK_F32_MIN, 1);
2252
+ nk_f32_t max_val_f32 = __riscv_vfmv_f_s_f32m1_f32(
2253
+ __riscv_vfredmax_vs_f32m2_f32m1(final_max_f32m2, init_min_f32m1, vlmax));
2254
+ if (min_val_f32 == NK_F32_MAX && max_val_f32 == NK_F32_MIN) {
2255
+ *min_value_ptr = NK_BF16_MAX, *min_index_ptr = NK_SIZE_MAX, *max_value_ptr = NK_BF16_MIN,
2256
+ *max_index_ptr = NK_SIZE_MAX;
2257
+ return;
2258
+ }
2259
+
2260
+ vfloat32m2_t converted_f32m2 = nk_bf16m1_to_f32m2_rvv_(min_u16m1, vlmax);
2261
+ vbool16_t min_match_b16 = __riscv_vmfeq_vf_f32m2_b16(converted_f32m2, min_val_f32, vlmax);
2262
+ vuint64m4_t sentinel_u64m4 = __riscv_vmv_v_x_u64m4(NK_U64_MAX, vlmax);
2263
+ vuint64m4_t min_cands_u64m4 = __riscv_vmerge_vvm_u64m4(sentinel_u64m4, min_indices_u64m4, min_match_b16, vlmax);
2264
+ vuint64m1_t init_umax_u64m1 = __riscv_vmv_v_x_u64m1(NK_U64_MAX, 1);
2265
+
2266
+ nk_u16_t min_raw = __riscv_vmv_x_s_u16m1_u16(
2267
+ __riscv_vslidedown_vx_u16m1(min_u16m1, (nk_size_t)__riscv_vfirst_m_b16(min_match_b16, vlmax), vlmax));
2268
+ *min_value_ptr = *(nk_bf16_t *)&min_raw;
2269
+ *min_index_ptr = (nk_size_t)__riscv_vmv_x_s_u64m1_u64(
2270
+ __riscv_vredminu_vs_u64m4_u64m1(min_cands_u64m4, init_umax_u64m1, vlmax));
2271
+
2272
+ vbool16_t max_match_b16 = __riscv_vmfeq_vf_f32m2_b16(nk_bf16m1_to_f32m2_rvv_(max_u16m1, vlmax), max_val_f32, vlmax);
2273
+ vuint64m4_t max_cands_u64m4 = __riscv_vmerge_vvm_u64m4(sentinel_u64m4, max_indices_u64m4, max_match_b16, vlmax);
2274
+
2275
+ nk_u16_t max_raw = __riscv_vmv_x_s_u16m1_u16(
2276
+ __riscv_vslidedown_vx_u16m1(max_u16m1, (nk_size_t)__riscv_vfirst_m_b16(max_match_b16, vlmax), vlmax));
2277
+ *max_value_ptr = *(nk_bf16_t *)&max_raw;
2278
+ *max_index_ptr = (nk_size_t)__riscv_vmv_x_s_u64m1_u64(
2279
+ __riscv_vredminu_vs_u64m4_u64m1(max_cands_u64m4, init_umax_u64m1, vlmax));
2280
+ }
2281
+
2282
+ NK_PUBLIC void nk_reduce_minmax_bf16_rvv( //
2283
+ nk_bf16_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
2284
+ nk_bf16_t *min_value_ptr, nk_size_t *min_index_ptr, //
2285
+ nk_bf16_t *max_value_ptr, nk_size_t *max_index_ptr) {
2286
+ nk_size_t stride_elements = stride_bytes / sizeof(nk_bf16_t);
2287
+ int aligned = (stride_bytes % sizeof(nk_bf16_t) == 0);
2288
+
2289
+ if (count == 0)
2290
+ *min_value_ptr = NK_BF16_MAX, *min_index_ptr = NK_SIZE_MAX, *max_value_ptr = NK_BF16_MIN,
2291
+ *max_index_ptr = NK_SIZE_MAX;
2292
+ else if (!aligned)
2293
+ nk_reduce_minmax_bf16_serial(data_ptr, count, stride_bytes, min_value_ptr, min_index_ptr, max_value_ptr,
2294
+ max_index_ptr);
2295
+ else if (stride_elements == 1)
2296
+ nk_reduce_minmax_bf16_rvv_contiguous_(data_ptr, count, min_value_ptr, min_index_ptr, max_value_ptr,
2297
+ max_index_ptr);
2298
+ else
2299
+ nk_reduce_minmax_bf16_rvv_strided_(data_ptr, count, stride_bytes, min_value_ptr, min_index_ptr, max_value_ptr,
2300
+ max_index_ptr);
2301
+ }
2302
+
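Two details of the bf16 (and later f16) min-max kernels are worth spelling out: the accumulators are seeded with the ±infinity bit patterns (0x7F80/0xFF80 here, 0x7C00/0xFC00 for f16), so NaN-only input leaves them untouched and trips the all-NaN early return; and the winning raw payload is recovered by locating the first matching lane with vfirst and sliding it down to element 0. A hedged scalar picture of that extraction:

static nk_u16_t first_matching_payload_(nk_u16_t const *raw_lanes, nk_f32_t const *widened_lanes,
                                        nk_size_t lane_count, nk_f32_t reduced_value) {
    for (nk_size_t lane = 0; lane < lane_count; ++lane)
        if (widened_lanes[lane] == reduced_value)   // vmfeq + vfirst
            return raw_lanes[lane];                 // vslidedown to lane 0, then vmv_x_s
    return raw_lanes[0];                            // unreachable once the reduction found a lane
}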
2303
+ NK_INTERNAL void nk_reduce_moments_f16_rvv_contiguous_( //
2304
+ nk_f16_t const *data_ptr, nk_size_t count, //
2305
+ nk_f32_t *sum_ptr, nk_f32_t *sumsq_ptr) {
2306
+ nk_size_t vlmax = __riscv_vsetvlmax_e64m4();
2307
+ vfloat64m4_t sum_f64m4 = __riscv_vfmv_v_f_f64m4(0.0, vlmax);
2308
+ vfloat64m4_t sumsq_f64m4 = __riscv_vfmv_v_f_f64m4(0.0, vlmax);
2309
+
2310
+ for (nk_size_t vector_length; count > 0; count -= vector_length, data_ptr += vector_length) {
2311
+ vector_length = __riscv_vsetvl_e16m1(count);
2312
+ vuint16m1_t data_u16m1 = __riscv_vle16_v_u16m1((uint16_t const *)data_ptr, vector_length);
2313
+
2314
+ // Convert f16 → f32 (m1 → m2)
2315
+ vfloat32m2_t data_f32m2 = nk_f16m1_to_f32m2_rvv_(data_u16m1, vector_length);
2316
+
2317
+ // Widen f32 → f64 (m2 → m4)
2318
+ vfloat64m4_t data_f64m4 = __riscv_vfwcvt_f_f_v_f64m4(data_f32m2, vector_length);
2319
+ sum_f64m4 = __riscv_vfadd_vv_f64m4_tu(sum_f64m4, sum_f64m4, data_f64m4, vector_length);
2320
+
2321
+ // Sumsq via widening FMA: f32×f32 → f64
2322
+ sumsq_f64m4 = __riscv_vfwmacc_vv_f64m4_tu(sumsq_f64m4, data_f32m2, data_f32m2, vector_length);
2323
+ }
2324
+
2325
+ // Horizontal reduction
2326
+ vfloat64m1_t zero_f64m1 = __riscv_vfmv_v_f_f64m1(0.0, 1);
2327
+ *sum_ptr = __riscv_vfmv_f_s_f64m1_f64(__riscv_vfredusum_vs_f64m4_f64m1(sum_f64m4, zero_f64m1, vlmax)),
2328
+ *sumsq_ptr = __riscv_vfmv_f_s_f64m1_f64(__riscv_vfredusum_vs_f64m4_f64m1(sumsq_f64m4, zero_f64m1, vlmax));
2329
+ }
2330
+
2331
+ NK_INTERNAL void nk_reduce_moments_f16_rvv_strided_( //
2332
+ nk_f16_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
2333
+ nk_f32_t *sum_ptr, nk_f32_t *sumsq_ptr) {
2334
+ nk_size_t vlmax = __riscv_vsetvlmax_e64m4();
2335
+ vfloat64m4_t sum_f64m4 = __riscv_vfmv_v_f_f64m4(0.0, vlmax);
2336
+ vfloat64m4_t sumsq_f64m4 = __riscv_vfmv_v_f_f64m4(0.0, vlmax);
2337
+ unsigned char const *ptr = (unsigned char const *)data_ptr;
2338
+
2339
+ for (nk_size_t vector_length; count > 0; count -= vector_length, ptr += vector_length * stride_bytes) {
2340
+ vector_length = __riscv_vsetvl_e16m1(count);
2341
+ vuint16m1_t data_u16m1 = __riscv_vlse16_v_u16m1((uint16_t const *)ptr, (nk_ssize_t)stride_bytes, vector_length);
2342
+
2343
+ // Convert f16 → f32 (m1 → m2)
2344
+ vfloat32m2_t data_f32m2 = nk_f16m1_to_f32m2_rvv_(data_u16m1, vector_length);
2345
+
2346
+ // Widen f32 → f64 (m2 → m4)
2347
+ vfloat64m4_t data_f64m4 = __riscv_vfwcvt_f_f_v_f64m4(data_f32m2, vector_length);
2348
+ sum_f64m4 = __riscv_vfadd_vv_f64m4_tu(sum_f64m4, sum_f64m4, data_f64m4, vector_length);
2349
+
2350
+ // Sumsq via widening FMA: f32×f32 → f64
2351
+ sumsq_f64m4 = __riscv_vfwmacc_vv_f64m4_tu(sumsq_f64m4, data_f32m2, data_f32m2, vector_length);
2352
+ }
2353
+
2354
+ // Horizontal reduction
2355
+ vfloat64m1_t zero_f64m1 = __riscv_vfmv_v_f_f64m1(0.0, 1);
2356
+ *sum_ptr = __riscv_vfmv_f_s_f64m1_f64(__riscv_vfredusum_vs_f64m4_f64m1(sum_f64m4, zero_f64m1, vlmax)),
2357
+ *sumsq_ptr = __riscv_vfmv_f_s_f64m1_f64(__riscv_vfredusum_vs_f64m4_f64m1(sumsq_f64m4, zero_f64m1, vlmax));
2358
+ }
2359
+
2360
+ NK_PUBLIC void nk_reduce_moments_f16_rvv( //
2361
+ nk_f16_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
2362
+ nk_f32_t *sum_ptr, nk_f32_t *sumsq_ptr) {
2363
+ nk_size_t stride_elements = stride_bytes / sizeof(nk_f16_t);
2364
+ int aligned = (stride_bytes % sizeof(nk_f16_t) == 0);
2365
+
2366
+ if (count == 0) *sum_ptr = 0.0f, *sumsq_ptr = 0.0f;
2367
+ else if (!aligned) nk_reduce_moments_f16_serial(data_ptr, count, stride_bytes, sum_ptr, sumsq_ptr);
2368
+ else if (stride_elements == 1) nk_reduce_moments_f16_rvv_contiguous_(data_ptr, count, sum_ptr, sumsq_ptr);
2369
+ else nk_reduce_moments_f16_rvv_strided_(data_ptr, count, stride_bytes, sum_ptr, sumsq_ptr);
2370
+ }
2371
+
2372
+ NK_INTERNAL void nk_reduce_minmax_f16_rvv_contiguous_( //
2373
+ nk_f16_t const *data_ptr, nk_size_t count, //
2374
+ nk_f16_t *min_value_ptr, nk_size_t *min_index_ptr, //
2375
+ nk_f16_t *max_value_ptr, nk_size_t *max_index_ptr) {
2376
+ nk_size_t vlmax = __riscv_vsetvlmax_e16m1();
2377
+ vuint16m1_t min_u16m1 = __riscv_vmv_v_x_u16m1(0x7C00, vlmax); // +inf in f16
2378
+ vuint16m1_t max_u16m1 = __riscv_vmv_v_x_u16m1(0xFC00, vlmax); // -inf in f16
2379
+ vuint64m4_t min_indices_u64m4 = __riscv_vmv_v_x_u64m4(0, vlmax);
2380
+ vuint64m4_t max_indices_u64m4 = __riscv_vmv_v_x_u64m4(0, vlmax);
2381
+
2382
+ nk_size_t offset = 0;
2383
+ for (nk_size_t vector_length; count > 0;
2384
+ count -= vector_length, offset += vector_length, data_ptr += vector_length) {
2385
+ vector_length = __riscv_vsetvl_e16m1(count);
2386
+ vuint16m1_t data_u16m1 = __riscv_vle16_v_u16m1((uint16_t const *)data_ptr, vector_length);
2387
+ vuint64m4_t pos_u64m4 = __riscv_vadd_vx_u64m4(__riscv_vid_v_u64m4(vector_length), (nk_u64_t)offset,
2388
+ vector_length);
2389
+
2390
+ // Convert to f32 for comparison
2391
+ vfloat32m2_t data_f32m2 = nk_f16m1_to_f32m2_rvv_(data_u16m1, vector_length);
2392
+ vfloat32m2_t min_f32m2 = nk_f16m1_to_f32m2_rvv_(min_u16m1, vector_length);
2393
+ vfloat32m2_t max_f32m2 = nk_f16m1_to_f32m2_rvv_(max_u16m1, vector_length);
2394
+
2395
+ vbool16_t less_b16 = __riscv_vmflt_vv_f32m2_b16(data_f32m2, min_f32m2, vector_length);
2396
+ min_u16m1 = __riscv_vmerge_vvm_u16m1_tu(min_u16m1, min_u16m1, data_u16m1, less_b16, vector_length);
2397
+ min_indices_u64m4 = __riscv_vmerge_vvm_u64m4_tu(min_indices_u64m4, min_indices_u64m4, pos_u64m4, less_b16,
2398
+ vector_length);
2399
+
2400
+ vbool16_t greater_b16 = __riscv_vmflt_vv_f32m2_b16(max_f32m2, data_f32m2, vector_length);
2401
+ max_u16m1 = __riscv_vmerge_vvm_u16m1_tu(max_u16m1, max_u16m1, data_u16m1, greater_b16, vector_length);
2402
+ max_indices_u64m4 = __riscv_vmerge_vvm_u64m4_tu(max_indices_u64m4, max_indices_u64m4, pos_u64m4, greater_b16,
2403
+ vector_length);
2404
+ }
2405
+
2406
+ // Horizontal reduction
2407
+ vfloat32m2_t final_min_f32m2 = nk_f16m1_to_f32m2_rvv_(min_u16m1, vlmax);
2408
+ vfloat32m1_t init_max_f32m1 = __riscv_vfmv_v_f_f32m1(NK_F32_MAX, 1);
2409
+ nk_f32_t min_val_f32 = __riscv_vfmv_f_s_f32m1_f32(
2410
+ __riscv_vfredmin_vs_f32m2_f32m1(final_min_f32m2, init_max_f32m1, vlmax));
2411
+ vfloat32m2_t final_max_f32m2 = nk_f16m1_to_f32m2_rvv_(max_u16m1, vlmax);
2412
+ vfloat32m1_t init_min_f32m1 = __riscv_vfmv_v_f_f32m1(NK_F32_MIN, 1);
2413
+ nk_f32_t max_val_f32 = __riscv_vfmv_f_s_f32m1_f32(
2414
+ __riscv_vfredmax_vs_f32m2_f32m1(final_max_f32m2, init_min_f32m1, vlmax));
2415
+ if (min_val_f32 == NK_F32_MAX && max_val_f32 == NK_F32_MIN) {
2416
+ *min_value_ptr = NK_F16_MAX, *min_index_ptr = NK_SIZE_MAX, *max_value_ptr = NK_F16_MIN,
2417
+ *max_index_ptr = NK_SIZE_MAX;
2418
+ return;
2419
+ }
2420
+
2421
+ vfloat32m2_t converted_f32m2 = nk_f16m1_to_f32m2_rvv_(min_u16m1, vlmax);
2422
+ vbool16_t min_match_b16 = __riscv_vmfeq_vf_f32m2_b16(converted_f32m2, min_val_f32, vlmax);
2423
+ vuint64m4_t sentinel_u64m4 = __riscv_vmv_v_x_u64m4(NK_U64_MAX, vlmax);
2424
+ vuint64m4_t min_cands_u64m4 = __riscv_vmerge_vvm_u64m4(sentinel_u64m4, min_indices_u64m4, min_match_b16, vlmax);
2425
+ vuint64m1_t init_umax_u64m1 = __riscv_vmv_v_x_u64m1(NK_U64_MAX, 1);
2426
+
2427
+ nk_u16_t min_raw = __riscv_vmv_x_s_u16m1_u16(
2428
+ __riscv_vslidedown_vx_u16m1(min_u16m1, (nk_size_t)__riscv_vfirst_m_b16(min_match_b16, vlmax), vlmax));
2429
+ *min_value_ptr = *(nk_f16_t *)&min_raw;
2430
+ *min_index_ptr = (nk_size_t)__riscv_vmv_x_s_u64m1_u64(
2431
+ __riscv_vredminu_vs_u64m4_u64m1(min_cands_u64m4, init_umax_u64m1, vlmax));
2432
+
2433
+ vbool16_t max_match_b16 = __riscv_vmfeq_vf_f32m2_b16(nk_f16m1_to_f32m2_rvv_(max_u16m1, vlmax), max_val_f32, vlmax);
2434
+ vuint64m4_t max_cands_u64m4 = __riscv_vmerge_vvm_u64m4(sentinel_u64m4, max_indices_u64m4, max_match_b16, vlmax);
2435
+
2436
+ nk_u16_t max_raw = __riscv_vmv_x_s_u16m1_u16(
2437
+ __riscv_vslidedown_vx_u16m1(max_u16m1, (nk_size_t)__riscv_vfirst_m_b16(max_match_b16, vlmax), vlmax));
2438
+ *max_value_ptr = *(nk_f16_t *)&max_raw;
2439
+ *max_index_ptr = (nk_size_t)__riscv_vmv_x_s_u64m1_u64(
2440
+ __riscv_vredminu_vs_u64m4_u64m1(max_cands_u64m4, init_umax_u64m1, vlmax));
2441
+ }
2442
+
2443
+ NK_INTERNAL void nk_reduce_minmax_f16_rvv_strided_( //
2444
+ nk_f16_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
2445
+ nk_f16_t *min_value_ptr, nk_size_t *min_index_ptr, //
2446
+ nk_f16_t *max_value_ptr, nk_size_t *max_index_ptr) {
2447
+ nk_size_t vlmax = __riscv_vsetvlmax_e16m1();
2448
+ vuint16m1_t min_u16m1 = __riscv_vmv_v_x_u16m1(0x7C00, vlmax); // +inf in f16
2449
+ vuint16m1_t max_u16m1 = __riscv_vmv_v_x_u16m1(0xFC00, vlmax); // -inf in f16
2450
+ vuint64m4_t min_indices_u64m4 = __riscv_vmv_v_x_u64m4(0, vlmax);
2451
+ vuint64m4_t max_indices_u64m4 = __riscv_vmv_v_x_u64m4(0, vlmax);
2452
+ unsigned char const *ptr = (unsigned char const *)data_ptr;
2453
+
2454
+ nk_size_t offset = 0;
2455
+ for (nk_size_t vector_length; count > 0;
2456
+ count -= vector_length, offset += vector_length, ptr += vector_length * stride_bytes) {
2457
+ vector_length = __riscv_vsetvl_e16m1(count);
2458
+ vuint16m1_t data_u16m1 = __riscv_vlse16_v_u16m1((uint16_t const *)ptr, (nk_ssize_t)stride_bytes, vector_length);
2459
+ vuint64m4_t pos_u64m4 = __riscv_vadd_vx_u64m4(__riscv_vid_v_u64m4(vector_length), (nk_u64_t)offset,
2460
+ vector_length);
2461
+
2462
+ // Convert to f32 for comparison
2463
+ vfloat32m2_t data_f32m2 = nk_f16m1_to_f32m2_rvv_(data_u16m1, vector_length);
2464
+ vfloat32m2_t min_f32m2 = nk_f16m1_to_f32m2_rvv_(min_u16m1, vector_length);
2465
+ vfloat32m2_t max_f32m2 = nk_f16m1_to_f32m2_rvv_(max_u16m1, vector_length);
2466
+
2467
+ vbool16_t less_b16 = __riscv_vmflt_vv_f32m2_b16(data_f32m2, min_f32m2, vector_length);
2468
+ min_u16m1 = __riscv_vmerge_vvm_u16m1_tu(min_u16m1, min_u16m1, data_u16m1, less_b16, vector_length);
2469
+ min_indices_u64m4 = __riscv_vmerge_vvm_u64m4_tu(min_indices_u64m4, min_indices_u64m4, pos_u64m4, less_b16,
2470
+ vector_length);
2471
+
2472
+ vbool16_t greater_b16 = __riscv_vmflt_vv_f32m2_b16(max_f32m2, data_f32m2, vector_length);
2473
+ max_u16m1 = __riscv_vmerge_vvm_u16m1_tu(max_u16m1, max_u16m1, data_u16m1, greater_b16, vector_length);
2474
+ max_indices_u64m4 = __riscv_vmerge_vvm_u64m4_tu(max_indices_u64m4, max_indices_u64m4, pos_u64m4, greater_b16,
2475
+ vector_length);
2476
+ }
2477
+
2478
+ // Horizontal reduction (same as contiguous)
2479
+ vfloat32m2_t final_min_f32m2 = nk_f16m1_to_f32m2_rvv_(min_u16m1, vlmax);
2480
+ vfloat32m1_t init_max_f32m1 = __riscv_vfmv_v_f_f32m1(NK_F32_MAX, 1);
2481
+ nk_f32_t min_val_f32 = __riscv_vfmv_f_s_f32m1_f32(
2482
+ __riscv_vfredmin_vs_f32m2_f32m1(final_min_f32m2, init_max_f32m1, vlmax));
2483
+ vfloat32m2_t final_max_f32m2 = nk_f16m1_to_f32m2_rvv_(max_u16m1, vlmax);
2484
+ vfloat32m1_t init_min_f32m1 = __riscv_vfmv_v_f_f32m1(NK_F32_MIN, 1);
2485
+ nk_f32_t max_val_f32 = __riscv_vfmv_f_s_f32m1_f32(
2486
+ __riscv_vfredmax_vs_f32m2_f32m1(final_max_f32m2, init_min_f32m1, vlmax));
2487
+ if (min_val_f32 == NK_F32_MAX && max_val_f32 == NK_F32_MIN) {
2488
+ *min_value_ptr = NK_F16_MAX, *min_index_ptr = NK_SIZE_MAX, *max_value_ptr = NK_F16_MIN,
2489
+ *max_index_ptr = NK_SIZE_MAX;
2490
+ return;
2491
+ }
2492
+
2493
+ vfloat32m2_t converted_f32m2 = nk_f16m1_to_f32m2_rvv_(min_u16m1, vlmax);
2494
+ vbool16_t min_match_b16 = __riscv_vmfeq_vf_f32m2_b16(converted_f32m2, min_val_f32, vlmax);
2495
+ vuint64m4_t sentinel_u64m4 = __riscv_vmv_v_x_u64m4(NK_U64_MAX, vlmax);
2496
+ vuint64m4_t min_cands_u64m4 = __riscv_vmerge_vvm_u64m4(sentinel_u64m4, min_indices_u64m4, min_match_b16, vlmax);
2497
+ vuint64m1_t init_umax_u64m1 = __riscv_vmv_v_x_u64m1(NK_U64_MAX, 1);
2498
+
2499
+ nk_u16_t min_raw = __riscv_vmv_x_s_u16m1_u16(
2500
+ __riscv_vslidedown_vx_u16m1(min_u16m1, (nk_size_t)__riscv_vfirst_m_b16(min_match_b16, vlmax), vlmax));
2501
+ *min_value_ptr = *(nk_f16_t *)&min_raw;
2502
+ *min_index_ptr = (nk_size_t)__riscv_vmv_x_s_u64m1_u64(
2503
+ __riscv_vredminu_vs_u64m4_u64m1(min_cands_u64m4, init_umax_u64m1, vlmax));
2504
+
2505
+ vbool16_t max_match_b16 = __riscv_vmfeq_vf_f32m2_b16(nk_f16m1_to_f32m2_rvv_(max_u16m1, vlmax), max_val_f32, vlmax);
2506
+ vuint64m4_t max_cands_u64m4 = __riscv_vmerge_vvm_u64m4(sentinel_u64m4, max_indices_u64m4, max_match_b16, vlmax);
2507
+
2508
+ nk_u16_t max_raw = __riscv_vmv_x_s_u16m1_u16(
2509
+ __riscv_vslidedown_vx_u16m1(max_u16m1, (nk_size_t)__riscv_vfirst_m_b16(max_match_b16, vlmax), vlmax));
2510
+ *max_value_ptr = *(nk_f16_t *)&max_raw;
2511
+ *max_index_ptr = (nk_size_t)__riscv_vmv_x_s_u64m1_u64(
2512
+ __riscv_vredminu_vs_u64m4_u64m1(max_cands_u64m4, init_umax_u64m1, vlmax));
2513
+ }
2514
+
2515
+ NK_PUBLIC void nk_reduce_minmax_f16_rvv( //
2516
+ nk_f16_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
2517
+ nk_f16_t *min_value_ptr, nk_size_t *min_index_ptr, //
2518
+ nk_f16_t *max_value_ptr, nk_size_t *max_index_ptr) {
2519
+ nk_size_t stride_elements = stride_bytes / sizeof(nk_f16_t);
2520
+ int aligned = (stride_bytes % sizeof(nk_f16_t) == 0);
2521
+
2522
+ if (count == 0)
2523
+ *min_value_ptr = NK_F16_MAX, *min_index_ptr = NK_SIZE_MAX, *max_value_ptr = NK_F16_MIN,
2524
+ *max_index_ptr = NK_SIZE_MAX;
2525
+ else if (!aligned)
2526
+ nk_reduce_minmax_f16_serial(data_ptr, count, stride_bytes, min_value_ptr, min_index_ptr, max_value_ptr,
2527
+ max_index_ptr);
2528
+ else if (stride_elements == 1)
2529
+ nk_reduce_minmax_f16_rvv_contiguous_(data_ptr, count, min_value_ptr, min_index_ptr, max_value_ptr,
2530
+ max_index_ptr);
2531
+ else
2532
+ nk_reduce_minmax_f16_rvv_strided_(data_ptr, count, stride_bytes, min_value_ptr, min_index_ptr, max_value_ptr,
2533
+ max_index_ptr);
2534
+ }
2535
+
2536
+ NK_INTERNAL void nk_reduce_moments_e4m3_rvv_contiguous_( //
2537
+ nk_e4m3_t const *data_ptr, nk_size_t count, //
2538
+ nk_f32_t *sum_ptr, nk_f32_t *sumsq_ptr) {
2539
+ nk_size_t vlmax = __riscv_vsetvlmax_e32m4();
2540
+ vfloat32m4_t sum_f32m4 = __riscv_vfmv_v_f_f32m4(0.0f, vlmax);
2541
+ vfloat32m4_t sumsq_f32m4 = __riscv_vfmv_v_f_f32m4(0.0f, vlmax);
2542
+
2543
+ for (nk_size_t vector_length; count > 0; count -= vector_length, data_ptr += vector_length) {
2544
+ vector_length = __riscv_vsetvl_e8m1(count);
2545
+ vuint8m1_t data_u8m1 = __riscv_vle8_v_u8m1((uint8_t const *)data_ptr, vector_length);
2546
+
2547
+ // Convert e4m3 → f32 (m1 → m4)
2548
+ vfloat32m4_t data_f32m4 = nk_e4m3m1_to_f32m4_rvv_(data_u8m1, vector_length);
2549
+
2550
+ // Accumulate at f32 precision
2551
+ sum_f32m4 = __riscv_vfadd_vv_f32m4_tu(sum_f32m4, sum_f32m4, data_f32m4, vector_length);
2552
+ sumsq_f32m4 = __riscv_vfmacc_vv_f32m4_tu(sumsq_f32m4, data_f32m4, data_f32m4, vector_length);
2553
+ }
2554
+
2555
+ // Horizontal reduction
2556
+ vfloat32m1_t zero_f32m1 = __riscv_vfmv_v_f_f32m1(0.0f, 1);
2557
+ *sum_ptr = __riscv_vfmv_f_s_f32m1_f32(__riscv_vfredusum_vs_f32m4_f32m1(sum_f32m4, zero_f32m1, vlmax)),
2558
+ *sumsq_ptr = __riscv_vfmv_f_s_f32m1_f32(__riscv_vfredusum_vs_f32m4_f32m1(sumsq_f32m4, zero_f32m1, vlmax));
2559
+ }
2560
+
2561
+ NK_INTERNAL void nk_reduce_moments_e4m3_rvv_strided_( //
2562
+ nk_e4m3_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
2563
+ nk_f32_t *sum_ptr, nk_f32_t *sumsq_ptr) {
2564
+ nk_size_t vlmax = __riscv_vsetvlmax_e32m4();
2565
+ vfloat32m4_t sum_f32m4 = __riscv_vfmv_v_f_f32m4(0.0f, vlmax);
2566
+ vfloat32m4_t sumsq_f32m4 = __riscv_vfmv_v_f_f32m4(0.0f, vlmax);
2567
+ unsigned char const *ptr = (unsigned char const *)data_ptr;
2568
+
2569
+ for (nk_size_t vector_length; count > 0; count -= vector_length, ptr += vector_length * stride_bytes) {
2570
+ vector_length = __riscv_vsetvl_e8m1(count);
2571
+ vuint8m1_t data_u8m1 = __riscv_vlse8_v_u8m1((uint8_t const *)ptr, (nk_ssize_t)stride_bytes, vector_length);
2572
+
2573
+ // Convert e4m3 → f32 (m1 → m4)
2574
+ vfloat32m4_t data_f32m4 = nk_e4m3m1_to_f32m4_rvv_(data_u8m1, vector_length);
2575
+
2576
+ // Accumulate at f32 precision
2577
+ sum_f32m4 = __riscv_vfadd_vv_f32m4_tu(sum_f32m4, sum_f32m4, data_f32m4, vector_length);
2578
+ sumsq_f32m4 = __riscv_vfmacc_vv_f32m4_tu(sumsq_f32m4, data_f32m4, data_f32m4, vector_length);
2579
+ }
2580
+
2581
+ // Horizontal reduction
2582
+ vfloat32m1_t zero_f32m1 = __riscv_vfmv_v_f_f32m1(0.0f, 1);
2583
+ *sum_ptr = __riscv_vfmv_f_s_f32m1_f32(__riscv_vfredusum_vs_f32m4_f32m1(sum_f32m4, zero_f32m1, vlmax)),
2584
+ *sumsq_ptr = __riscv_vfmv_f_s_f32m1_f32(__riscv_vfredusum_vs_f32m4_f32m1(sumsq_f32m4, zero_f32m1, vlmax));
2585
+ }
2586
+
2587
+ NK_PUBLIC void nk_reduce_moments_e4m3_rvv( //
2588
+ nk_e4m3_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
2589
+ nk_f32_t *sum_ptr, nk_f32_t *sumsq_ptr) {
2590
+ nk_size_t stride_elements = stride_bytes / sizeof(nk_e4m3_t);
2591
+ int aligned = (stride_bytes % sizeof(nk_e4m3_t) == 0);
2592
+
2593
+ if (count == 0) *sum_ptr = 0.0f, *sumsq_ptr = 0.0f;
2594
+ else if (!aligned) nk_reduce_moments_e4m3_serial(data_ptr, count, stride_bytes, sum_ptr, sumsq_ptr);
2595
+ else if (stride_elements == 1) nk_reduce_moments_e4m3_rvv_contiguous_(data_ptr, count, sum_ptr, sumsq_ptr);
2596
+ else nk_reduce_moments_e4m3_rvv_strided_(data_ptr, count, stride_bytes, sum_ptr, sumsq_ptr);
2597
+ }
2598
+
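The nk_reduce_minmax_e4m3_rvv_contiguous_ kernel below compares FP8 values as raw bytes after nk_fp8m1_to_comparable_u8m1_rvv_ remaps them; that helper is not part of this hunk, so the sketch here only shows the conventional sign-magnitude-to-monotonic remap such a routine would typically apply (an assumption, not the library's definition). It is at least consistent with the kernel's NaN check, since the E4M3 NaN encodings 0x7F and 0xFF land on 0xFF and 0x00 respectively.

static nk_u8_t fp8_to_comparable_sketch_(nk_u8_t raw) {
    // Negative encodings are bit-inverted, positive ones get the sign bit set,
    // so plain unsigned byte order then matches numeric order.
    return (raw & 0x80) ? (nk_u8_t)~raw : (nk_u8_t)(raw | 0x80);
}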
2599
+ NK_INTERNAL void nk_reduce_minmax_e4m3_rvv_contiguous_( //
2600
+ nk_e4m3_t const *data_ptr, nk_size_t count, //
2601
+ nk_e4m3_t *min_value_ptr, nk_size_t *min_index_ptr, //
2602
+ nk_e4m3_t *max_value_ptr, nk_size_t *max_index_ptr) {
2603
+ nk_size_t vlmax = __riscv_vsetvlmax_e8m1();
2604
+ vuint8m1_t min_u8m1 = __riscv_vmv_v_x_u8m1(0xFF, vlmax); // Largest comparable
2605
+ vuint8m1_t max_u8m1 = __riscv_vmv_v_x_u8m1(0x00, vlmax); // Smallest comparable
2606
+ vuint64m8_t min_indices_u64m8 = __riscv_vmv_v_x_u64m8(0, vlmax);
2607
+ vuint64m8_t max_indices_u64m8 = __riscv_vmv_v_x_u64m8(0, vlmax);
2608
+
2609
+ nk_size_t offset = 0;
2610
+ for (nk_size_t vector_length; count > 0;
2611
+ count -= vector_length, offset += vector_length, data_ptr += vector_length) {
2612
+ vector_length = __riscv_vsetvl_e8m1(count);
2613
+ vuint8m1_t raw_u8m1 = __riscv_vle8_v_u8m1((uint8_t const *)data_ptr, vector_length);
2614
+
2615
+ // Convert to comparable form
2616
+ vuint8m1_t comparable_u8m1 = nk_fp8m1_to_comparable_u8m1_rvv_(raw_u8m1, vector_length);
2617
+ vuint64m8_t pos_u64m8 = __riscv_vadd_vx_u64m8(__riscv_vid_v_u64m8(vector_length), (nk_u64_t)offset,
2618
+ vector_length);
2619
+
2620
+ // Detect E4M3 NaN: comparable == 0x00 (neg NaN) or comparable == 0xFF (pos NaN)
2621
+ vbool8_t nan_low_b8 = __riscv_vmseq_vx_u8m1_b8(comparable_u8m1, 0x00, vector_length);
2622
+ vbool8_t nan_high_b8 = __riscv_vmseq_vx_u8m1_b8(comparable_u8m1, 0xFF, vector_length);
2623
+ vbool8_t is_nan_b8 = __riscv_vmor_mm_b8(nan_low_b8, nan_high_b8, vector_length);
2624
+ vuint8m1_t data_min_u8m1 = __riscv_vmerge_vxm_u8m1(comparable_u8m1, 0xFF, is_nan_b8, vector_length);
2625
+ vuint8m1_t data_max_u8m1 = __riscv_vmerge_vxm_u8m1(comparable_u8m1, 0x00, is_nan_b8, vector_length);
2626
+
2627
+ vbool8_t less_b8 = __riscv_vmsltu_vv_u8m1_b8(data_min_u8m1, min_u8m1, vector_length);
2628
+ min_u8m1 = __riscv_vmerge_vvm_u8m1_tu(min_u8m1, min_u8m1, data_min_u8m1, less_b8, vector_length);
2629
+ min_indices_u64m8 = __riscv_vmerge_vvm_u64m8_tu(min_indices_u64m8, min_indices_u64m8, pos_u64m8, less_b8,
2630
+ vector_length);
2631
+
2632
+ vbool8_t greater_b8 = __riscv_vmsltu_vv_u8m1_b8(max_u8m1, data_max_u8m1, vector_length);
2633
+ max_u8m1 = __riscv_vmerge_vvm_u8m1_tu(max_u8m1, max_u8m1, data_max_u8m1, greater_b8, vector_length);
2634
+ max_indices_u64m8 = __riscv_vmerge_vvm_u64m8_tu(max_indices_u64m8, max_indices_u64m8, pos_u64m8, greater_b8,
2635
+ vector_length);
2636
+ }
2637
+
2638
+ // Horizontal reduction + convert back
2639
+ vuint8m1_t init_max_u8m1 = __riscv_vmv_v_x_u8m1(0xFF, 1);
2640
+ nk_u8_t min_comparable = __riscv_vmv_x_s_u8m1_u8(__riscv_vredminu_vs_u8m1_u8m1(min_u8m1, init_max_u8m1, vlmax));
2641
+
2642
+ // All-NaN case: the min lanes never left the 0xFF sentinel, so every element was NaN
2643
+ if (min_comparable == 0xFF) {
2644
+ *min_value_ptr = (nk_e4m3_t)NK_E4M3_MAX, *min_index_ptr = NK_SIZE_MAX;
2645
+ *max_value_ptr = (nk_e4m3_t)NK_E4M3_MIN, *max_index_ptr = NK_SIZE_MAX;
2646
+ return;
2647
+ }
2648
+
2649
+ vbool8_t min_match_b8 = __riscv_vmseq_vx_u8m1_b8(min_u8m1, min_comparable, vlmax);
2650
+ vuint64m8_t sentinel_u64m8 = __riscv_vmv_v_x_u64m8(NK_U64_MAX, vlmax);
2651
+ vuint64m8_t min_cands_u64m8 = __riscv_vmerge_vvm_u64m8(sentinel_u64m8, min_indices_u64m8, min_match_b8, vlmax);
2652
+ vuint64m1_t init_umax_u64m1 = __riscv_vmv_v_x_u64m1(NK_U64_MAX, 1);
2653
+ *min_index_ptr = (nk_size_t)__riscv_vmv_x_s_u64m1_u64(
2654
+ __riscv_vredminu_vs_u64m8_u64m1(min_cands_u64m8, init_umax_u64m1, vlmax));
2655
+
2656
+ vuint8m1_t min_vec_u8m1 = __riscv_vmv_v_x_u8m1(min_comparable, 1);
2657
+ vuint8m1_t min_raw_u8m1 = nk_comparable_to_fp8m1_rvv_(min_vec_u8m1, 1);
2658
+ *min_value_ptr = (nk_e4m3_t)__riscv_vmv_x_s_u8m1_u8(min_raw_u8m1);
2659
+
2660
+ // Repeat for the max: reduce the keys, pick the first matching index, convert back
2661
+ vuint8m1_t init_min_u8m1 = __riscv_vmv_v_x_u8m1(0x00, 1);
2662
+ nk_u8_t max_comparable = __riscv_vmv_x_s_u8m1_u8(__riscv_vredmaxu_vs_u8m1_u8m1(max_u8m1, init_min_u8m1, vlmax));
2663
+ vbool8_t max_match_b8 = __riscv_vmseq_vx_u8m1_b8(max_u8m1, max_comparable, vlmax);
2664
+ vuint64m8_t max_cands_u64m8 = __riscv_vmerge_vvm_u64m8(sentinel_u64m8, max_indices_u64m8, max_match_b8, vlmax);
2665
+ *max_index_ptr = (nk_size_t)__riscv_vmv_x_s_u64m1_u64(
2666
+ __riscv_vredminu_vs_u64m8_u64m1(max_cands_u64m8, init_umax_u64m1, vlmax));
2667
+
2668
+ vuint8m1_t max_vec_u8m1 = __riscv_vmv_v_x_u8m1(max_comparable, 1);
2669
+ vuint8m1_t max_raw_u8m1 = nk_comparable_to_fp8m1_rvv_(max_vec_u8m1, 1);
2670
+ *max_value_ptr = (nk_e4m3_t)__riscv_vmv_x_s_u8m1_u8(max_raw_u8m1);
2671
+ }
2672
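The `nk_fp8m1_to_comparable_u8m1_rvv_` and `nk_comparable_to_fp8m1_rvv_` helpers are defined elsewhere in the library; the scalar sketch below shows the standard trick they most likely vectorize, flipping all bits of negative patterns and setting the sign bit of non-negative ones, so that sign-magnitude byte patterns order correctly under unsigned comparison. It is an illustration of the technique, not a copy of the library's code.

```c
/* Illustrative scalar mapping between 8-bit sign-magnitude float patterns
 * and monotonically ordered unsigned keys. */
static inline uint8_t fp8_to_comparable(uint8_t raw) {
    return (raw & 0x80) ? (uint8_t)~raw          /* negatives: more negative -> smaller key */
                        : (uint8_t)(raw | 0x80); /* non-negatives: sort above all negatives */
}
static inline uint8_t comparable_to_fp8(uint8_t key) {
    return (key & 0x80) ? (uint8_t)(key & 0x7F) : (uint8_t)~key;
}
```

Under such a mapping the E4M3 NaN patterns 0x7F and 0xFF land on the extreme keys 0xFF and 0x00, which is exactly what the NaN-masking step above tests for.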
+
2673
+ NK_INTERNAL void nk_reduce_minmax_e4m3_rvv_strided_( //
2674
+ nk_e4m3_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
2675
+ nk_e4m3_t *min_value_ptr, nk_size_t *min_index_ptr, //
2676
+ nk_e4m3_t *max_value_ptr, nk_size_t *max_index_ptr) {
2677
+ nk_size_t vlmax = __riscv_vsetvlmax_e8m1();
2678
+ vuint8m1_t min_u8m1 = __riscv_vmv_v_x_u8m1(0xFF, vlmax);
2679
+ vuint8m1_t max_u8m1 = __riscv_vmv_v_x_u8m1(0x00, vlmax);
2680
+ vuint64m8_t min_indices_u64m8 = __riscv_vmv_v_x_u64m8(0, vlmax);
2681
+ vuint64m8_t max_indices_u64m8 = __riscv_vmv_v_x_u64m8(0, vlmax);
2682
+ unsigned char const *ptr = (unsigned char const *)data_ptr;
2683
+
2684
+ nk_size_t offset = 0;
2685
+ for (nk_size_t vector_length; count > 0;
2686
+ count -= vector_length, offset += vector_length, ptr += vector_length * stride_bytes) {
2687
+ vector_length = __riscv_vsetvl_e8m1(count);
2688
+ vuint8m1_t raw_u8m1 = __riscv_vlse8_v_u8m1((uint8_t const *)ptr, (nk_ssize_t)stride_bytes, vector_length);
2689
+
2690
+ vuint8m1_t comparable_u8m1 = nk_fp8m1_to_comparable_u8m1_rvv_(raw_u8m1, vector_length);
2691
+ vuint64m8_t pos_u64m8 = __riscv_vadd_vx_u64m8(__riscv_vid_v_u64m8(vector_length), (nk_u64_t)offset,
2692
+ vector_length);
2693
+
2694
+ // Detect E4M3 NaN: comparable == 0x00 (neg NaN) or comparable == 0xFF (pos NaN)
2695
+ vbool8_t nan_low_b8 = __riscv_vmseq_vx_u8m1_b8(comparable_u8m1, 0x00, vector_length);
2696
+ vbool8_t nan_high_b8 = __riscv_vmseq_vx_u8m1_b8(comparable_u8m1, 0xFF, vector_length);
2697
+ vbool8_t is_nan_b8 = __riscv_vmor_mm_b8(nan_low_b8, nan_high_b8, vector_length);
2698
+ vuint8m1_t data_min_u8m1 = __riscv_vmerge_vxm_u8m1(comparable_u8m1, 0xFF, is_nan_b8, vector_length);
2699
+ vuint8m1_t data_max_u8m1 = __riscv_vmerge_vxm_u8m1(comparable_u8m1, 0x00, is_nan_b8, vector_length);
2700
+
2701
+ vbool8_t less_b8 = __riscv_vmsltu_vv_u8m1_b8(data_min_u8m1, min_u8m1, vector_length);
2702
+ min_u8m1 = __riscv_vmerge_vvm_u8m1_tu(min_u8m1, min_u8m1, data_min_u8m1, less_b8, vector_length);
2703
+ min_indices_u64m8 = __riscv_vmerge_vvm_u64m8_tu(min_indices_u64m8, min_indices_u64m8, pos_u64m8, less_b8,
2704
+ vector_length);
2705
+
2706
+ vbool8_t greater_b8 = __riscv_vmsltu_vv_u8m1_b8(max_u8m1, data_max_u8m1, vector_length);
2707
+ max_u8m1 = __riscv_vmerge_vvm_u8m1_tu(max_u8m1, max_u8m1, data_max_u8m1, greater_b8, vector_length);
2708
+ max_indices_u64m8 = __riscv_vmerge_vvm_u64m8_tu(max_indices_u64m8, max_indices_u64m8, pos_u64m8, greater_b8,
2709
+ vector_length);
2710
+ }
2711
+
2712
+ // Horizontal reduction (same as contiguous)
2713
+ vuint8m1_t init_max_u8m1 = __riscv_vmv_v_x_u8m1(0xFF, 1);
2714
+ nk_u8_t min_comparable = __riscv_vmv_x_s_u8m1_u8(__riscv_vredminu_vs_u8m1_u8m1(min_u8m1, init_max_u8m1, vlmax));
2715
+
2716
+ // All-NaN case
2717
+ if (min_comparable == 0xFF) {
2718
+ *min_value_ptr = (nk_e4m3_t)NK_E4M3_MAX, *min_index_ptr = NK_SIZE_MAX;
2719
+ *max_value_ptr = (nk_e4m3_t)NK_E4M3_MIN, *max_index_ptr = NK_SIZE_MAX;
2720
+ return;
2721
+ }
2722
+
2723
+ vbool8_t min_match_b8 = __riscv_vmseq_vx_u8m1_b8(min_u8m1, min_comparable, vlmax);
2724
+ vuint64m8_t sentinel_u64m8 = __riscv_vmv_v_x_u64m8(NK_U64_MAX, vlmax);
2725
+ vuint64m8_t min_cands_u64m8 = __riscv_vmerge_vvm_u64m8(sentinel_u64m8, min_indices_u64m8, min_match_b8, vlmax);
2726
+ vuint64m1_t init_umax_u64m1 = __riscv_vmv_v_x_u64m1(NK_U64_MAX, 1);
2727
+ *min_index_ptr = (nk_size_t)__riscv_vmv_x_s_u64m1_u64(
2728
+ __riscv_vredminu_vs_u64m8_u64m1(min_cands_u64m8, init_umax_u64m1, vlmax));
2729
+
2730
+ vuint8m1_t min_vec_u8m1 = __riscv_vmv_v_x_u8m1(min_comparable, 1);
2731
+ vuint8m1_t min_raw_u8m1 = nk_comparable_to_fp8m1_rvv_(min_vec_u8m1, 1);
2732
+ *min_value_ptr = (nk_e4m3_t)__riscv_vmv_x_s_u8m1_u8(min_raw_u8m1);
2733
+
2734
+ vuint8m1_t init_min_u8m1 = __riscv_vmv_v_x_u8m1(0x00, 1);
2735
+ nk_u8_t max_comparable = __riscv_vmv_x_s_u8m1_u8(__riscv_vredmaxu_vs_u8m1_u8m1(max_u8m1, init_min_u8m1, vlmax));
2736
+ vbool8_t max_match_b8 = __riscv_vmseq_vx_u8m1_b8(max_u8m1, max_comparable, vlmax);
2737
+ vuint64m8_t max_cands_u64m8 = __riscv_vmerge_vvm_u64m8(sentinel_u64m8, max_indices_u64m8, max_match_b8, vlmax);
2738
+ *max_index_ptr = (nk_size_t)__riscv_vmv_x_s_u64m1_u64(
2739
+ __riscv_vredminu_vs_u64m8_u64m1(max_cands_u64m8, init_umax_u64m1, vlmax));
2740
+
2741
+ vuint8m1_t max_vec_u8m1 = __riscv_vmv_v_x_u8m1(max_comparable, 1);
2742
+ vuint8m1_t max_raw_u8m1 = nk_comparable_to_fp8m1_rvv_(max_vec_u8m1, 1);
2743
+ *max_value_ptr = (nk_e4m3_t)__riscv_vmv_x_s_u8m1_u8(max_raw_u8m1);
2744
+ }
2745
+
2746
+ NK_PUBLIC void nk_reduce_minmax_e4m3_rvv( //
2747
+ nk_e4m3_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
2748
+ nk_e4m3_t *min_value_ptr, nk_size_t *min_index_ptr, //
2749
+ nk_e4m3_t *max_value_ptr, nk_size_t *max_index_ptr) {
2750
+ nk_size_t stride_elements = stride_bytes / sizeof(nk_e4m3_t);
2751
+ int aligned = (stride_bytes % sizeof(nk_e4m3_t) == 0);
2752
+
2753
+ if (count == 0)
2754
+ *min_value_ptr = NK_E4M3_MAX, *min_index_ptr = NK_SIZE_MAX, *max_value_ptr = NK_E4M3_MIN,
2755
+ *max_index_ptr = NK_SIZE_MAX;
2756
+ else if (!aligned)
2757
+ nk_reduce_minmax_e4m3_serial(data_ptr, count, stride_bytes, min_value_ptr, min_index_ptr, max_value_ptr,
2758
+ max_index_ptr);
2759
+ else if (stride_elements == 1)
2760
+ nk_reduce_minmax_e4m3_rvv_contiguous_(data_ptr, count, min_value_ptr, min_index_ptr, max_value_ptr,
2761
+ max_index_ptr);
2762
+ else
2763
+ nk_reduce_minmax_e4m3_rvv_strided_(data_ptr, count, stride_bytes, min_value_ptr, min_index_ptr, max_value_ptr,
2764
+ max_index_ptr);
2765
+ }
2766
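The public entry point takes the stride in bytes, so scanning one column of a row-major matrix means passing the row pitch directly. A hypothetical usage sketch:

```c
/* Extrema of column `col` in a rows-by-cols row-major e4m3 matrix (illustrative). */
nk_e4m3_t min_value, max_value;
nk_size_t min_index, max_index;
nk_reduce_minmax_e4m3_rvv(matrix + col, rows, cols * sizeof(nk_e4m3_t), //
                          &min_value, &min_index, &max_value, &max_index);
/* Indices are positions within the traversed column; both come back as
 * NK_SIZE_MAX when the column is empty or every element is NaN. */
```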
+
2767
+ NK_INTERNAL void nk_reduce_moments_e5m2_rvv_contiguous_( //
2768
+ nk_e5m2_t const *data_ptr, nk_size_t count, //
2769
+ nk_f32_t *sum_ptr, nk_f32_t *sumsq_ptr) {
2770
+ nk_size_t vlmax = __riscv_vsetvlmax_e32m4();
2771
+ vfloat32m4_t sum_f32m4 = __riscv_vfmv_v_f_f32m4(0.0f, vlmax);
2772
+ vfloat32m4_t sumsq_f32m4 = __riscv_vfmv_v_f_f32m4(0.0f, vlmax);
2773
+
2774
+ for (nk_size_t vector_length; count > 0; count -= vector_length, data_ptr += vector_length) {
2775
+ vector_length = __riscv_vsetvl_e8m1(count);
2776
+ vuint8m1_t data_u8m1 = __riscv_vle8_v_u8m1((uint8_t const *)data_ptr, vector_length);
2777
+
2778
+ // Convert e5m2 → f32 (m1 → m4)
2779
+ vfloat32m4_t data_f32m4 = nk_e5m2m1_to_f32m4_rvv_(data_u8m1, vector_length);
2780
+
2781
+ // Accumulate at f32 precision
2782
+ sum_f32m4 = __riscv_vfadd_vv_f32m4_tu(sum_f32m4, sum_f32m4, data_f32m4, vector_length);
2783
+ sumsq_f32m4 = __riscv_vfmacc_vv_f32m4_tu(sumsq_f32m4, data_f32m4, data_f32m4, vector_length);
2784
+ }
2785
+
2786
+ // Horizontal reduction
2787
+ vfloat32m1_t zero_f32m1 = __riscv_vfmv_v_f_f32m1(0.0f, 1);
2788
+ *sum_ptr = __riscv_vfmv_f_s_f32m1_f32(__riscv_vfredusum_vs_f32m4_f32m1(sum_f32m4, zero_f32m1, vlmax)),
2789
+ *sumsq_ptr = __riscv_vfmv_f_s_f32m1_f32(__riscv_vfredusum_vs_f32m4_f32m1(sumsq_f32m4, zero_f32m1, vlmax));
2790
+ }
2791
+
2792
+ NK_INTERNAL void nk_reduce_moments_e5m2_rvv_strided_( //
2793
+ nk_e5m2_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
2794
+ nk_f32_t *sum_ptr, nk_f32_t *sumsq_ptr) {
2795
+ nk_size_t vlmax = __riscv_vsetvlmax_e32m4();
2796
+ vfloat32m4_t sum_f32m4 = __riscv_vfmv_v_f_f32m4(0.0f, vlmax);
2797
+ vfloat32m4_t sumsq_f32m4 = __riscv_vfmv_v_f_f32m4(0.0f, vlmax);
2798
+ unsigned char const *ptr = (unsigned char const *)data_ptr;
2799
+
2800
+ for (nk_size_t vector_length; count > 0; count -= vector_length, ptr += vector_length * stride_bytes) {
2801
+ vector_length = __riscv_vsetvl_e8m1(count);
2802
+ vuint8m1_t data_u8m1 = __riscv_vlse8_v_u8m1((uint8_t const *)ptr, (nk_ssize_t)stride_bytes, vector_length);
2803
+
2804
+ // Convert e5m2 → f32 (m1 → m4)
2805
+ vfloat32m4_t data_f32m4 = nk_e5m2m1_to_f32m4_rvv_(data_u8m1, vector_length);
2806
+
2807
+ // Accumulate at f32 precision
2808
+ sum_f32m4 = __riscv_vfadd_vv_f32m4_tu(sum_f32m4, sum_f32m4, data_f32m4, vector_length);
2809
+ sumsq_f32m4 = __riscv_vfmacc_vv_f32m4_tu(sumsq_f32m4, data_f32m4, data_f32m4, vector_length);
2810
+ }
2811
+
2812
+ // Horizontal reduction
2813
+ vfloat32m1_t zero_f32m1 = __riscv_vfmv_v_f_f32m1(0.0f, 1);
2814
+ *sum_ptr = __riscv_vfmv_f_s_f32m1_f32(__riscv_vfredusum_vs_f32m4_f32m1(sum_f32m4, zero_f32m1, vlmax)),
2815
+ *sumsq_ptr = __riscv_vfmv_f_s_f32m1_f32(__riscv_vfredusum_vs_f32m4_f32m1(sumsq_f32m4, zero_f32m1, vlmax));
2816
+ }
2817
+
2818
+ NK_PUBLIC void nk_reduce_moments_e5m2_rvv( //
2819
+ nk_e5m2_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
2820
+ nk_f32_t *sum_ptr, nk_f32_t *sumsq_ptr) {
2821
+ nk_size_t stride_elements = stride_bytes / sizeof(nk_e5m2_t);
2822
+ int aligned = (stride_bytes % sizeof(nk_e5m2_t) == 0);
2823
+
2824
+ if (count == 0) *sum_ptr = 0.0f, *sumsq_ptr = 0.0f;
2825
+ else if (!aligned) nk_reduce_moments_e5m2_serial(data_ptr, count, stride_bytes, sum_ptr, sumsq_ptr);
2826
+ else if (stride_elements == 1) nk_reduce_moments_e5m2_rvv_contiguous_(data_ptr, count, sum_ptr, sumsq_ptr);
2827
+ else nk_reduce_moments_e5m2_rvv_strided_(data_ptr, count, stride_bytes, sum_ptr, sumsq_ptr);
2828
+ }
2829
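Accumulating directly in f32 leaves plenty of headroom for these 8-bit inputs: the largest finite E5M2 magnitude is 57344, so each squared term is at most about 3.3e9, and a single f32 lane (max roughly 3.4e38) would need on the order of 1e29 such worst-case terms before the sum of squares could overflow; spreading the partial sums across vector lanes shortens each accumulator further. A back-of-the-envelope check, plain arithmetic rather than library code:

```c
double max_e5m2 = 57344.0;                           /* largest finite E5M2 value */
double per_term = max_e5m2 * max_e5m2;               /* ~3.3e9 per squared element */
double terms_to_overflow = 3.4028235e38 / per_term;  /* ~1e29 elements per lane */
```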
+
2830
+ NK_INTERNAL void nk_reduce_minmax_e5m2_rvv_contiguous_( //
2831
+ nk_e5m2_t const *data_ptr, nk_size_t count, //
2832
+ nk_e5m2_t *min_value_ptr, nk_size_t *min_index_ptr, //
2833
+ nk_e5m2_t *max_value_ptr, nk_size_t *max_index_ptr) {
2834
+ nk_size_t vlmax = __riscv_vsetvlmax_e8m1();
2835
+ vuint8m1_t min_u8m1 = __riscv_vmv_v_x_u8m1(0xFF, vlmax);
2836
+ vuint8m1_t max_u8m1 = __riscv_vmv_v_x_u8m1(0x00, vlmax);
2837
+ vuint64m8_t min_indices_u64m8 = __riscv_vmv_v_x_u64m8(0, vlmax);
2838
+ vuint64m8_t max_indices_u64m8 = __riscv_vmv_v_x_u64m8(0, vlmax);
2839
+
2840
+ nk_size_t offset = 0;
2841
+ for (nk_size_t vector_length; count > 0;
2842
+ count -= vector_length, offset += vector_length, data_ptr += vector_length) {
2843
+ vector_length = __riscv_vsetvl_e8m1(count);
2844
+ vuint8m1_t raw_u8m1 = __riscv_vle8_v_u8m1((uint8_t const *)data_ptr, vector_length);
2845
+
2846
+ vuint8m1_t comparable_u8m1 = nk_fp8m1_to_comparable_u8m1_rvv_(raw_u8m1, vector_length);
2847
+ vuint64m8_t pos_u64m8 = __riscv_vadd_vx_u64m8(__riscv_vid_v_u64m8(vector_length), (nk_u64_t)offset,
2848
+ vector_length);
2849
+
2850
+ // Detect E5M2 NaN: comparable <= 0x02 (neg NaN) or comparable >= 0xFD (pos NaN)
2851
+ vbool8_t nan_low_b8 = __riscv_vmsleu_vx_u8m1_b8(comparable_u8m1, 0x02, vector_length);
2852
+ vbool8_t nan_high_b8 = __riscv_vmsgeu_vx_u8m1_b8(comparable_u8m1, 0xFD, vector_length);
2853
+ vbool8_t is_nan_b8 = __riscv_vmor_mm_b8(nan_low_b8, nan_high_b8, vector_length);
2854
+ vuint8m1_t data_min_u8m1 = __riscv_vmerge_vxm_u8m1(comparable_u8m1, 0xFF, is_nan_b8, vector_length);
2855
+ vuint8m1_t data_max_u8m1 = __riscv_vmerge_vxm_u8m1(comparable_u8m1, 0x00, is_nan_b8, vector_length);
2856
+
2857
+ vbool8_t less_b8 = __riscv_vmsltu_vv_u8m1_b8(data_min_u8m1, min_u8m1, vector_length);
2858
+ min_u8m1 = __riscv_vmerge_vvm_u8m1_tu(min_u8m1, min_u8m1, data_min_u8m1, less_b8, vector_length);
2859
+ min_indices_u64m8 = __riscv_vmerge_vvm_u64m8_tu(min_indices_u64m8, min_indices_u64m8, pos_u64m8, less_b8,
2860
+ vector_length);
2861
+
2862
+ vbool8_t greater_b8 = __riscv_vmsltu_vv_u8m1_b8(max_u8m1, data_max_u8m1, vector_length);
2863
+ max_u8m1 = __riscv_vmerge_vvm_u8m1_tu(max_u8m1, max_u8m1, data_max_u8m1, greater_b8, vector_length);
2864
+ max_indices_u64m8 = __riscv_vmerge_vvm_u64m8_tu(max_indices_u64m8, max_indices_u64m8, pos_u64m8, greater_b8,
2865
+ vector_length);
2866
+ }
2867
+
2868
+ // Horizontal reduction + convert back
2869
+ vuint8m1_t init_max_u8m1 = __riscv_vmv_v_x_u8m1(0xFF, 1);
2870
+ nk_u8_t min_comparable = __riscv_vmv_x_s_u8m1_u8(__riscv_vredminu_vs_u8m1_u8m1(min_u8m1, init_max_u8m1, vlmax));
2871
+
2872
+ // All-NaN case
2873
+ if (min_comparable == 0xFF) {
2874
+ *min_value_ptr = (nk_e5m2_t)NK_E5M2_MAX, *min_index_ptr = NK_SIZE_MAX;
2875
+ *max_value_ptr = (nk_e5m2_t)NK_E5M2_MIN, *max_index_ptr = NK_SIZE_MAX;
2876
+ return;
2877
+ }
2878
+
2879
+ vbool8_t min_match_b8 = __riscv_vmseq_vx_u8m1_b8(min_u8m1, min_comparable, vlmax);
2880
+ vuint64m8_t sentinel_u64m8 = __riscv_vmv_v_x_u64m8(NK_U64_MAX, vlmax);
2881
+ vuint64m8_t min_cands_u64m8 = __riscv_vmerge_vvm_u64m8(sentinel_u64m8, min_indices_u64m8, min_match_b8, vlmax);
2882
+ vuint64m1_t init_umax_u64m1 = __riscv_vmv_v_x_u64m1(NK_U64_MAX, 1);
2883
+ *min_index_ptr = (nk_size_t)__riscv_vmv_x_s_u64m1_u64(
2884
+ __riscv_vredminu_vs_u64m8_u64m1(min_cands_u64m8, init_umax_u64m1, vlmax));
2885
+
2886
+ vuint8m1_t min_vec_u8m1 = __riscv_vmv_v_x_u8m1(min_comparable, 1);
2887
+ vuint8m1_t min_raw_u8m1 = nk_comparable_to_fp8m1_rvv_(min_vec_u8m1, 1);
2888
+ *min_value_ptr = (nk_e5m2_t)__riscv_vmv_x_s_u8m1_u8(min_raw_u8m1);
2889
+
2890
+ vuint8m1_t init_min_u8m1 = __riscv_vmv_v_x_u8m1(0x00, 1);
2891
+ nk_u8_t max_comparable = __riscv_vmv_x_s_u8m1_u8(__riscv_vredmaxu_vs_u8m1_u8m1(max_u8m1, init_min_u8m1, vlmax));
2892
+ vbool8_t max_match_b8 = __riscv_vmseq_vx_u8m1_b8(max_u8m1, max_comparable, vlmax);
2893
+ vuint64m8_t max_cands_u64m8 = __riscv_vmerge_vvm_u64m8(sentinel_u64m8, max_indices_u64m8, max_match_b8, vlmax);
2894
+ *max_index_ptr = (nk_size_t)__riscv_vmv_x_s_u64m1_u64(
2895
+ __riscv_vredminu_vs_u64m8_u64m1(max_cands_u64m8, init_umax_u64m1, vlmax));
2896
+
2897
+ vuint8m1_t max_vec_u8m1 = __riscv_vmv_v_x_u8m1(max_comparable, 1);
2898
+ vuint8m1_t max_raw_u8m1 = nk_comparable_to_fp8m1_rvv_(max_vec_u8m1, 1);
2899
+ *max_value_ptr = (nk_e5m2_t)__riscv_vmv_x_s_u8m1_u8(max_raw_u8m1);
2900
+ }
2901
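The NaN test is the only place this kernel differs from its E4M3 counterpart. E4M3 here is the FN variant with a single NaN pattern per sign (0x7F / 0xFF), while E5M2 keeps the IEEE convention of an all-ones exponent with any non-zero mantissa, giving three NaN patterns per sign (0x7D-0x7F and 0xFD-0xFF); under the comparable mapping those land on the three lowest and three highest keys, hence the range checks `<= 0x02` / `>= 0xFD` above. Illustrative scalar classifiers for the two bit layouts (not library code):

```c
static inline int e4m3_is_nan(uint8_t raw) { /* E4M3FN: exponent and mantissa all ones */
    return (raw & 0x7F) == 0x7F;
}
static inline int e5m2_is_nan(uint8_t raw) { /* E5M2: exponent all ones, mantissa non-zero */
    return (raw & 0x7C) == 0x7C && (raw & 0x03) != 0;
}
```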
+
2902
+ NK_INTERNAL void nk_reduce_minmax_e5m2_rvv_strided_( //
2903
+ nk_e5m2_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
2904
+ nk_e5m2_t *min_value_ptr, nk_size_t *min_index_ptr, //
2905
+ nk_e5m2_t *max_value_ptr, nk_size_t *max_index_ptr) {
2906
+ nk_size_t vlmax = __riscv_vsetvlmax_e8m1();
2907
+ vuint8m1_t min_u8m1 = __riscv_vmv_v_x_u8m1(0xFF, vlmax);
2908
+ vuint8m1_t max_u8m1 = __riscv_vmv_v_x_u8m1(0x00, vlmax);
2909
+ vuint64m8_t min_indices_u64m8 = __riscv_vmv_v_x_u64m8(0, vlmax);
2910
+ vuint64m8_t max_indices_u64m8 = __riscv_vmv_v_x_u64m8(0, vlmax);
2911
+ unsigned char const *ptr = (unsigned char const *)data_ptr;
2912
+
2913
+ nk_size_t offset = 0;
2914
+ for (nk_size_t vector_length; count > 0;
2915
+ count -= vector_length, offset += vector_length, ptr += vector_length * stride_bytes) {
2916
+ vector_length = __riscv_vsetvl_e8m1(count);
2917
+ vuint8m1_t raw_u8m1 = __riscv_vlse8_v_u8m1((uint8_t const *)ptr, (nk_ssize_t)stride_bytes, vector_length);
2918
+
2919
+ vuint8m1_t comparable_u8m1 = nk_fp8m1_to_comparable_u8m1_rvv_(raw_u8m1, vector_length);
2920
+ vuint64m8_t pos_u64m8 = __riscv_vadd_vx_u64m8(__riscv_vid_v_u64m8(vector_length), (nk_u64_t)offset,
2921
+ vector_length);
2922
+
2923
+ // Detect E5M2 NaN: comparable <= 0x02 (neg NaN) or comparable >= 0xFD (pos NaN)
2924
+ vbool8_t nan_low_b8 = __riscv_vmsleu_vx_u8m1_b8(comparable_u8m1, 0x02, vector_length);
2925
+ vbool8_t nan_high_b8 = __riscv_vmsgeu_vx_u8m1_b8(comparable_u8m1, 0xFD, vector_length);
2926
+ vbool8_t is_nan_b8 = __riscv_vmor_mm_b8(nan_low_b8, nan_high_b8, vector_length);
2927
+ vuint8m1_t data_min_u8m1 = __riscv_vmerge_vxm_u8m1(comparable_u8m1, 0xFF, is_nan_b8, vector_length);
2928
+ vuint8m1_t data_max_u8m1 = __riscv_vmerge_vxm_u8m1(comparable_u8m1, 0x00, is_nan_b8, vector_length);
2929
+
2930
+ vbool8_t less_b8 = __riscv_vmsltu_vv_u8m1_b8(data_min_u8m1, min_u8m1, vector_length);
2931
+ min_u8m1 = __riscv_vmerge_vvm_u8m1_tu(min_u8m1, min_u8m1, data_min_u8m1, less_b8, vector_length);
2932
+ min_indices_u64m8 = __riscv_vmerge_vvm_u64m8_tu(min_indices_u64m8, min_indices_u64m8, pos_u64m8, less_b8,
2933
+ vector_length);
2934
+
2935
+ vbool8_t greater_b8 = __riscv_vmsltu_vv_u8m1_b8(max_u8m1, data_max_u8m1, vector_length);
2936
+ max_u8m1 = __riscv_vmerge_vvm_u8m1_tu(max_u8m1, max_u8m1, data_max_u8m1, greater_b8, vector_length);
2937
+ max_indices_u64m8 = __riscv_vmerge_vvm_u64m8_tu(max_indices_u64m8, max_indices_u64m8, pos_u64m8, greater_b8,
2938
+ vector_length);
2939
+ }
2940
+
2941
+ // Horizontal reduction (same as contiguous)
2942
+ vuint8m1_t init_max_u8m1 = __riscv_vmv_v_x_u8m1(0xFF, 1);
2943
+ nk_u8_t min_comparable = __riscv_vmv_x_s_u8m1_u8(__riscv_vredminu_vs_u8m1_u8m1(min_u8m1, init_max_u8m1, vlmax));
2944
+
2945
+ // All-NaN case
2946
+ if (min_comparable == 0xFF) {
2947
+ *min_value_ptr = (nk_e5m2_t)NK_E5M2_MAX, *min_index_ptr = NK_SIZE_MAX;
2948
+ *max_value_ptr = (nk_e5m2_t)NK_E5M2_MIN, *max_index_ptr = NK_SIZE_MAX;
2949
+ return;
2950
+ }
2951
+
2952
+ vbool8_t min_match_b8 = __riscv_vmseq_vx_u8m1_b8(min_u8m1, min_comparable, vlmax);
2953
+ vuint64m8_t sentinel_u64m8 = __riscv_vmv_v_x_u64m8(NK_U64_MAX, vlmax);
2954
+ vuint64m8_t min_cands_u64m8 = __riscv_vmerge_vvm_u64m8(sentinel_u64m8, min_indices_u64m8, min_match_b8, vlmax);
2955
+ vuint64m1_t init_umax_u64m1 = __riscv_vmv_v_x_u64m1(NK_U64_MAX, 1);
2956
+ *min_index_ptr = (nk_size_t)__riscv_vmv_x_s_u64m1_u64(
2957
+ __riscv_vredminu_vs_u64m8_u64m1(min_cands_u64m8, init_umax_u64m1, vlmax));
2958
+
2959
+ vuint8m1_t min_vec_u8m1 = __riscv_vmv_v_x_u8m1(min_comparable, 1);
2960
+ vuint8m1_t min_raw_u8m1 = nk_comparable_to_fp8m1_rvv_(min_vec_u8m1, 1);
2961
+ *min_value_ptr = (nk_e5m2_t)__riscv_vmv_x_s_u8m1_u8(min_raw_u8m1);
2962
+
2963
+ vuint8m1_t init_min_u8m1 = __riscv_vmv_v_x_u8m1(0x00, 1);
2964
+ nk_u8_t max_comparable = __riscv_vmv_x_s_u8m1_u8(__riscv_vredmaxu_vs_u8m1_u8m1(max_u8m1, init_min_u8m1, vlmax));
2965
+ vbool8_t max_match_b8 = __riscv_vmseq_vx_u8m1_b8(max_u8m1, max_comparable, vlmax);
2966
+ vuint64m8_t max_cands_u64m8 = __riscv_vmerge_vvm_u64m8(sentinel_u64m8, max_indices_u64m8, max_match_b8, vlmax);
2967
+ *max_index_ptr = (nk_size_t)__riscv_vmv_x_s_u64m1_u64(
2968
+ __riscv_vredminu_vs_u64m8_u64m1(max_cands_u64m8, init_umax_u64m1, vlmax));
2969
+
2970
+ vuint8m1_t max_vec_u8m1 = __riscv_vmv_v_x_u8m1(max_comparable, 1);
2971
+ vuint8m1_t max_raw_u8m1 = nk_comparable_to_fp8m1_rvv_(max_vec_u8m1, 1);
2972
+ *max_value_ptr = (nk_e5m2_t)__riscv_vmv_x_s_u8m1_u8(max_raw_u8m1);
2973
+ }
2974
+
2975
+ NK_PUBLIC void nk_reduce_minmax_e5m2_rvv( //
2976
+ nk_e5m2_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
2977
+ nk_e5m2_t *min_value_ptr, nk_size_t *min_index_ptr, //
2978
+ nk_e5m2_t *max_value_ptr, nk_size_t *max_index_ptr) {
2979
+ nk_size_t stride_elements = stride_bytes / sizeof(nk_e5m2_t);
2980
+ int aligned = (stride_bytes % sizeof(nk_e5m2_t) == 0);
2981
+
2982
+ if (count == 0)
2983
+ *min_value_ptr = NK_E5M2_MAX, *min_index_ptr = NK_SIZE_MAX, *max_value_ptr = NK_E5M2_MIN,
2984
+ *max_index_ptr = NK_SIZE_MAX;
2985
+ else if (!aligned)
2986
+ nk_reduce_minmax_e5m2_serial(data_ptr, count, stride_bytes, min_value_ptr, min_index_ptr, max_value_ptr,
2987
+ max_index_ptr);
2988
+ else if (stride_elements == 1)
2989
+ nk_reduce_minmax_e5m2_rvv_contiguous_(data_ptr, count, min_value_ptr, min_index_ptr, max_value_ptr,
2990
+ max_index_ptr);
2991
+ else
2992
+ nk_reduce_minmax_e5m2_rvv_strided_(data_ptr, count, stride_bytes, min_value_ptr, min_index_ptr, max_value_ptr,
2993
+ max_index_ptr);
2994
+ }
2995
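The index-selection step is the same in every min/max kernel above: lanes whose running value does not equal the reduced extreme are overwritten with an NK_U64_MAX sentinel, and an unsigned-min reduction then picks the smallest surviving index. Because the in-loop comparisons are strict, each lane records the first position at which its extreme was reached, so ties appear to resolve to the first occurrence in traversal order. A scalar picture of that step, with hypothetical names:

```c
nk_size_t best_index = NK_SIZE_MAX;
for (nk_size_t lane = 0; lane < lane_count; ++lane)        /* lane_count plays the role of vlmax */
    if (lane_min_key[lane] == global_min_key && lane_min_index[lane] < best_index)
        best_index = lane_min_index[lane];                  /* smallest matching index wins */
```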
+
2996
+ NK_INTERNAL void nk_reduce_moments_e2m3_rvv_contiguous_( //
2997
+ nk_e2m3_t const *data_ptr, nk_size_t count, //
2998
+ nk_f32_t *sum_ptr, nk_f32_t *sumsq_ptr) {
2999
+ nk_size_t vlmax = __riscv_vsetvlmax_e32m4();
3000
+ vfloat32m4_t sum_f32m4 = __riscv_vfmv_v_f_f32m4(0.0f, vlmax);
3001
+ vfloat32m4_t sumsq_f32m4 = __riscv_vfmv_v_f_f32m4(0.0f, vlmax);
3002
+
3003
+ for (nk_size_t vector_length; count > 0; count -= vector_length, data_ptr += vector_length) {
3004
+ vector_length = __riscv_vsetvl_e8m1(count);
3005
+ vuint8m1_t data_u8m1 = __riscv_vle8_v_u8m1((uint8_t const *)data_ptr, vector_length);
3006
+
3007
+ // Convert e2m3 → f32 (m1 → m4)
3008
+ vfloat32m4_t data_f32m4 = nk_e2m3m1_to_f32m4_rvv_(data_u8m1, vector_length);
3009
+
3010
+ // Accumulate at f32 precision
3011
+ sum_f32m4 = __riscv_vfadd_vv_f32m4_tu(sum_f32m4, sum_f32m4, data_f32m4, vector_length);
3012
+ sumsq_f32m4 = __riscv_vfmacc_vv_f32m4_tu(sumsq_f32m4, data_f32m4, data_f32m4, vector_length);
3013
+ }
3014
+
3015
+ // Horizontal reduction
3016
+ vfloat32m1_t zero_f32m1 = __riscv_vfmv_v_f_f32m1(0.0f, 1);
3017
+ *sum_ptr = __riscv_vfmv_f_s_f32m1_f32(__riscv_vfredusum_vs_f32m4_f32m1(sum_f32m4, zero_f32m1, vlmax)),
3018
+ *sumsq_ptr = __riscv_vfmv_f_s_f32m1_f32(__riscv_vfredusum_vs_f32m4_f32m1(sumsq_f32m4, zero_f32m1, vlmax));
3019
+ }
3020
+
3021
+ NK_INTERNAL void nk_reduce_moments_e2m3_rvv_strided_( //
3022
+ nk_e2m3_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
3023
+ nk_f32_t *sum_ptr, nk_f32_t *sumsq_ptr) {
3024
+ nk_size_t vlmax = __riscv_vsetvlmax_e32m4();
3025
+ vfloat32m4_t sum_f32m4 = __riscv_vfmv_v_f_f32m4(0.0f, vlmax);
3026
+ vfloat32m4_t sumsq_f32m4 = __riscv_vfmv_v_f_f32m4(0.0f, vlmax);
3027
+ unsigned char const *ptr = (unsigned char const *)data_ptr;
3028
+
3029
+ for (nk_size_t vector_length; count > 0; count -= vector_length, ptr += vector_length * stride_bytes) {
3030
+ vector_length = __riscv_vsetvl_e8m1(count);
3031
+ vuint8m1_t data_u8m1 = __riscv_vlse8_v_u8m1((uint8_t const *)ptr, (nk_ssize_t)stride_bytes, vector_length);
3032
+
3033
+ // Convert e2m3 → f32 (m1 → m4)
3034
+ vfloat32m4_t data_f32m4 = nk_e2m3m1_to_f32m4_rvv_(data_u8m1, vector_length);
3035
+
3036
+ // Accumulate at f32 precision
3037
+ sum_f32m4 = __riscv_vfadd_vv_f32m4_tu(sum_f32m4, sum_f32m4, data_f32m4, vector_length);
3038
+ sumsq_f32m4 = __riscv_vfmacc_vv_f32m4_tu(sumsq_f32m4, data_f32m4, data_f32m4, vector_length);
3039
+ }
3040
+
3041
+ // Horizontal reduction
3042
+ vfloat32m1_t zero_f32m1 = __riscv_vfmv_v_f_f32m1(0.0f, 1);
3043
+ *sum_ptr = __riscv_vfmv_f_s_f32m1_f32(__riscv_vfredusum_vs_f32m4_f32m1(sum_f32m4, zero_f32m1, vlmax)),
3044
+ *sumsq_ptr = __riscv_vfmv_f_s_f32m1_f32(__riscv_vfredusum_vs_f32m4_f32m1(sumsq_f32m4, zero_f32m1, vlmax));
3045
+ }
3046
+
3047
+ NK_PUBLIC void nk_reduce_moments_e2m3_rvv( //
3048
+ nk_e2m3_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
3049
+ nk_f32_t *sum_ptr, nk_f32_t *sumsq_ptr) {
3050
+ nk_size_t stride_elements = stride_bytes / sizeof(nk_e2m3_t);
3051
+ int aligned = (stride_bytes % sizeof(nk_e2m3_t) == 0);
3052
+
3053
+ if (count == 0) *sum_ptr = 0.0f, *sumsq_ptr = 0.0f;
3054
+ else if (!aligned) nk_reduce_moments_e2m3_serial(data_ptr, count, stride_bytes, sum_ptr, sumsq_ptr);
3055
+ else if (stride_elements == 1) nk_reduce_moments_e2m3_rvv_contiguous_(data_ptr, count, sum_ptr, sumsq_ptr);
3056
+ else nk_reduce_moments_e2m3_rvv_strided_(data_ptr, count, stride_bytes, sum_ptr, sumsq_ptr);
3057
+ }
3058
+
3059
+ NK_INTERNAL void nk_reduce_minmax_e2m3_rvv_contiguous_( //
3060
+ nk_e2m3_t const *data_ptr, nk_size_t count, //
3061
+ nk_e2m3_t *min_value_ptr, nk_size_t *min_index_ptr, //
3062
+ nk_e2m3_t *max_value_ptr, nk_size_t *max_index_ptr) {
3063
+ nk_size_t vlmax = __riscv_vsetvlmax_e8m1();
3064
+ vuint8m1_t min_u8m1 = __riscv_vmv_v_x_u8m1(0x3F, vlmax); // Largest FP6 comparable
3065
+ vuint8m1_t max_u8m1 = __riscv_vmv_v_x_u8m1(0x00, vlmax); // Smallest FP6 comparable
3066
+ vuint64m8_t min_indices_u64m8 = __riscv_vmv_v_x_u64m8(0, vlmax);
3067
+ vuint64m8_t max_indices_u64m8 = __riscv_vmv_v_x_u64m8(0, vlmax);
3068
+
3069
+ nk_size_t offset = 0;
3070
+ for (nk_size_t vector_length; count > 0;
3071
+ count -= vector_length, offset += vector_length, data_ptr += vector_length) {
3072
+ vector_length = __riscv_vsetvl_e8m1(count);
3073
+ vuint8m1_t raw_u8m1 = __riscv_vle8_v_u8m1((uint8_t const *)data_ptr, vector_length);
3074
+
3075
+ // Convert to FP6 comparable form
3076
+ vuint8m1_t comparable_u8m1 = nk_fp6m1_to_comparable_u8m1_rvv_(raw_u8m1, vector_length);
3077
+ vuint64m8_t pos_u64m8 = __riscv_vadd_vx_u64m8(__riscv_vid_v_u64m8(vector_length), (nk_u64_t)offset,
3078
+ vector_length);
3079
+
3080
+ vbool8_t less_b8 = __riscv_vmsltu_vv_u8m1_b8(comparable_u8m1, min_u8m1, vector_length);
3081
+ min_u8m1 = __riscv_vmerge_vvm_u8m1_tu(min_u8m1, min_u8m1, comparable_u8m1, less_b8, vector_length);
3082
+ min_indices_u64m8 = __riscv_vmerge_vvm_u64m8_tu(min_indices_u64m8, min_indices_u64m8, pos_u64m8, less_b8,
3083
+ vector_length);
3084
+
3085
+ vbool8_t greater_b8 = __riscv_vmsltu_vv_u8m1_b8(max_u8m1, comparable_u8m1, vector_length);
3086
+ max_u8m1 = __riscv_vmerge_vvm_u8m1_tu(max_u8m1, max_u8m1, comparable_u8m1, greater_b8, vector_length);
3087
+ max_indices_u64m8 = __riscv_vmerge_vvm_u64m8_tu(max_indices_u64m8, max_indices_u64m8, pos_u64m8, greater_b8,
3088
+ vector_length);
3089
+ }
3090
+
3091
+ // Horizontal reduction + convert back
3092
+ vuint8m1_t init_max_u8m1 = __riscv_vmv_v_x_u8m1(0x3F, 1);
3093
+ nk_u8_t min_comparable = __riscv_vmv_x_s_u8m1_u8(__riscv_vredminu_vs_u8m1_u8m1(min_u8m1, init_max_u8m1, vlmax));
3094
+ vbool8_t min_match_b8 = __riscv_vmseq_vx_u8m1_b8(min_u8m1, min_comparable, vlmax);
3095
+ vuint64m8_t sentinel_u64m8 = __riscv_vmv_v_x_u64m8(NK_U64_MAX, vlmax);
3096
+ vuint64m8_t min_cands_u64m8 = __riscv_vmerge_vvm_u64m8(sentinel_u64m8, min_indices_u64m8, min_match_b8, vlmax);
3097
+ vuint64m1_t init_umax_u64m1 = __riscv_vmv_v_x_u64m1(NK_U64_MAX, 1);
3098
+ *min_index_ptr = (nk_size_t)__riscv_vmv_x_s_u64m1_u64(
3099
+ __riscv_vredminu_vs_u64m8_u64m1(min_cands_u64m8, init_umax_u64m1, vlmax));
3100
+
3101
+ vuint8m1_t min_vec_u8m1 = __riscv_vmv_v_x_u8m1(min_comparable, 1);
3102
+ vuint8m1_t min_raw_u8m1 = nk_comparable_to_fp6m1_rvv_(min_vec_u8m1, 1);
3103
+ *min_value_ptr = (nk_e2m3_t)__riscv_vmv_x_s_u8m1_u8(min_raw_u8m1);
3104
+
3105
+ vuint8m1_t init_min_u8m1 = __riscv_vmv_v_x_u8m1(0x00, 1);
3106
+ nk_u8_t max_comparable = __riscv_vmv_x_s_u8m1_u8(__riscv_vredmaxu_vs_u8m1_u8m1(max_u8m1, init_min_u8m1, vlmax));
3107
+ vbool8_t max_match_b8 = __riscv_vmseq_vx_u8m1_b8(max_u8m1, max_comparable, vlmax);
3108
+ vuint64m8_t max_cands_u64m8 = __riscv_vmerge_vvm_u64m8(sentinel_u64m8, max_indices_u64m8, max_match_b8, vlmax);
3109
+ *max_index_ptr = (nk_size_t)__riscv_vmv_x_s_u64m1_u64(
3110
+ __riscv_vredminu_vs_u64m8_u64m1(max_cands_u64m8, init_umax_u64m1, vlmax));
3111
+
3112
+ vuint8m1_t max_vec_u8m1 = __riscv_vmv_v_x_u8m1(max_comparable, 1);
3113
+ vuint8m1_t max_raw_u8m1 = nk_comparable_to_fp6m1_rvv_(max_vec_u8m1, 1);
3114
+ *max_value_ptr = (nk_e2m3_t)__riscv_vmv_x_s_u8m1_u8(max_raw_u8m1);
3115
+ }
3116
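For the 6-bit formats the structure matches the FP8 kernels minus the NaN masking: the FP6 formats encode no NaNs or infinities, so every comparable key in [0x00, 0x3F] is a real value and 0x3F can serve directly as the min-search sentinel. A scalar sketch of a 6-bit comparable mapping in the same spirit as the FP8 one, illustrative rather than the library's `nk_fp6m1_to_comparable_u8m1_rvv_`:

```c
static inline uint8_t fp6_to_comparable(uint8_t raw) {
    raw &= 0x3F;                                   /* values live in the low 6 bits */
    return (raw & 0x20) ? (uint8_t)(~raw & 0x3F)   /* negatives: invert the magnitude */
                        : (uint8_t)(raw | 0x20);   /* non-negatives: above all negatives */
}
```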
+
3117
+ NK_INTERNAL void nk_reduce_minmax_e2m3_rvv_strided_( //
3118
+ nk_e2m3_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
3119
+ nk_e2m3_t *min_value_ptr, nk_size_t *min_index_ptr, //
3120
+ nk_e2m3_t *max_value_ptr, nk_size_t *max_index_ptr) {
3121
+ nk_size_t vlmax = __riscv_vsetvlmax_e8m1();
3122
+ vuint8m1_t min_u8m1 = __riscv_vmv_v_x_u8m1(0x3F, vlmax);
3123
+ vuint8m1_t max_u8m1 = __riscv_vmv_v_x_u8m1(0x00, vlmax);
3124
+ vuint64m8_t min_indices_u64m8 = __riscv_vmv_v_x_u64m8(0, vlmax);
3125
+ vuint64m8_t max_indices_u64m8 = __riscv_vmv_v_x_u64m8(0, vlmax);
3126
+ unsigned char const *ptr = (unsigned char const *)data_ptr;
3127
+
3128
+ nk_size_t offset = 0;
3129
+ for (nk_size_t vector_length; count > 0;
3130
+ count -= vector_length, offset += vector_length, ptr += vector_length * stride_bytes) {
3131
+ vector_length = __riscv_vsetvl_e8m1(count);
3132
+ vuint8m1_t raw_u8m1 = __riscv_vlse8_v_u8m1((uint8_t const *)ptr, (nk_ssize_t)stride_bytes, vector_length);
3133
+
3134
+ vuint8m1_t comparable_u8m1 = nk_fp6m1_to_comparable_u8m1_rvv_(raw_u8m1, vector_length);
3135
+ vuint64m8_t pos_u64m8 = __riscv_vadd_vx_u64m8(__riscv_vid_v_u64m8(vector_length), (nk_u64_t)offset,
3136
+ vector_length);
3137
+
3138
+ vbool8_t less_b8 = __riscv_vmsltu_vv_u8m1_b8(comparable_u8m1, min_u8m1, vector_length);
3139
+ min_u8m1 = __riscv_vmerge_vvm_u8m1_tu(min_u8m1, min_u8m1, comparable_u8m1, less_b8, vector_length);
3140
+ min_indices_u64m8 = __riscv_vmerge_vvm_u64m8_tu(min_indices_u64m8, min_indices_u64m8, pos_u64m8, less_b8,
3141
+ vector_length);
3142
+
3143
+ vbool8_t greater_b8 = __riscv_vmsltu_vv_u8m1_b8(max_u8m1, comparable_u8m1, vector_length);
3144
+ max_u8m1 = __riscv_vmerge_vvm_u8m1_tu(max_u8m1, max_u8m1, comparable_u8m1, greater_b8, vector_length);
3145
+ max_indices_u64m8 = __riscv_vmerge_vvm_u64m8_tu(max_indices_u64m8, max_indices_u64m8, pos_u64m8, greater_b8,
3146
+ vector_length);
3147
+ }
3148
+
3149
+ // Horizontal reduction (same as contiguous)
3150
+ vuint8m1_t init_max_u8m1 = __riscv_vmv_v_x_u8m1(0x3F, 1);
3151
+ nk_u8_t min_comparable = __riscv_vmv_x_s_u8m1_u8(__riscv_vredminu_vs_u8m1_u8m1(min_u8m1, init_max_u8m1, vlmax));
3152
+ vbool8_t min_match_b8 = __riscv_vmseq_vx_u8m1_b8(min_u8m1, min_comparable, vlmax);
3153
+ vuint64m8_t sentinel_u64m8 = __riscv_vmv_v_x_u64m8(NK_U64_MAX, vlmax);
3154
+ vuint64m8_t min_cands_u64m8 = __riscv_vmerge_vvm_u64m8(sentinel_u64m8, min_indices_u64m8, min_match_b8, vlmax);
3155
+ vuint64m1_t init_umax_u64m1 = __riscv_vmv_v_x_u64m1(NK_U64_MAX, 1);
3156
+ *min_index_ptr = (nk_size_t)__riscv_vmv_x_s_u64m1_u64(
3157
+ __riscv_vredminu_vs_u64m8_u64m1(min_cands_u64m8, init_umax_u64m1, vlmax));
3158
+
3159
+ vuint8m1_t min_vec_u8m1 = __riscv_vmv_v_x_u8m1(min_comparable, 1);
3160
+ vuint8m1_t min_raw_u8m1 = nk_comparable_to_fp6m1_rvv_(min_vec_u8m1, 1);
3161
+ *min_value_ptr = (nk_e2m3_t)__riscv_vmv_x_s_u8m1_u8(min_raw_u8m1);
3162
+
3163
+ vuint8m1_t init_min_u8m1 = __riscv_vmv_v_x_u8m1(0x00, 1);
3164
+ nk_u8_t max_comparable = __riscv_vmv_x_s_u8m1_u8(__riscv_vredmaxu_vs_u8m1_u8m1(max_u8m1, init_min_u8m1, vlmax));
3165
+ vbool8_t max_match_b8 = __riscv_vmseq_vx_u8m1_b8(max_u8m1, max_comparable, vlmax);
3166
+ vuint64m8_t max_cands_u64m8 = __riscv_vmerge_vvm_u64m8(sentinel_u64m8, max_indices_u64m8, max_match_b8, vlmax);
3167
+ *max_index_ptr = (nk_size_t)__riscv_vmv_x_s_u64m1_u64(
3168
+ __riscv_vredminu_vs_u64m8_u64m1(max_cands_u64m8, init_umax_u64m1, vlmax));
3169
+
3170
+ vuint8m1_t max_vec_u8m1 = __riscv_vmv_v_x_u8m1(max_comparable, 1);
3171
+ vuint8m1_t max_raw_u8m1 = nk_comparable_to_fp6m1_rvv_(max_vec_u8m1, 1);
3172
+ *max_value_ptr = (nk_e2m3_t)__riscv_vmv_x_s_u8m1_u8(max_raw_u8m1);
3173
+ }
3174
+
3175
+ NK_PUBLIC void nk_reduce_minmax_e2m3_rvv( //
3176
+ nk_e2m3_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
3177
+ nk_e2m3_t *min_value_ptr, nk_size_t *min_index_ptr, //
3178
+ nk_e2m3_t *max_value_ptr, nk_size_t *max_index_ptr) {
3179
+ nk_size_t stride_elements = stride_bytes / sizeof(nk_e2m3_t);
3180
+ int aligned = (stride_bytes % sizeof(nk_e2m3_t) == 0);
3181
+
3182
+ if (count == 0)
3183
+ *min_value_ptr = NK_E2M3_MAX, *min_index_ptr = NK_SIZE_MAX, *max_value_ptr = NK_E2M3_MIN,
3184
+ *max_index_ptr = NK_SIZE_MAX;
3185
+ else if (!aligned)
3186
+ nk_reduce_minmax_e2m3_serial(data_ptr, count, stride_bytes, min_value_ptr, min_index_ptr, max_value_ptr,
3187
+ max_index_ptr);
3188
+ else if (stride_elements == 1)
3189
+ nk_reduce_minmax_e2m3_rvv_contiguous_(data_ptr, count, min_value_ptr, min_index_ptr, max_value_ptr,
3190
+ max_index_ptr);
3191
+ else
3192
+ nk_reduce_minmax_e2m3_rvv_strided_(data_ptr, count, stride_bytes, min_value_ptr, min_index_ptr, max_value_ptr,
3193
+ max_index_ptr);
3194
+ }
3195
+
3196
+ NK_INTERNAL void nk_reduce_moments_e3m2_rvv_contiguous_( //
3197
+ nk_e3m2_t const *data_ptr, nk_size_t count, //
3198
+ nk_f32_t *sum_ptr, nk_f32_t *sumsq_ptr) {
3199
+ nk_size_t vlmax = __riscv_vsetvlmax_e32m4();
3200
+ vfloat32m4_t sum_f32m4 = __riscv_vfmv_v_f_f32m4(0.0f, vlmax);
3201
+ vfloat32m4_t sumsq_f32m4 = __riscv_vfmv_v_f_f32m4(0.0f, vlmax);
3202
+
3203
+ for (nk_size_t vector_length; count > 0; count -= vector_length, data_ptr += vector_length) {
3204
+ vector_length = __riscv_vsetvl_e8m1(count);
3205
+ vuint8m1_t data_u8m1 = __riscv_vle8_v_u8m1((uint8_t const *)data_ptr, vector_length);
3206
+
3207
+ // Convert e3m2 → f32 (m1 → m4)
3208
+ vfloat32m4_t data_f32m4 = nk_e3m2m1_to_f32m4_rvv_(data_u8m1, vector_length);
3209
+
3210
+ // Accumulate at f32 precision
3211
+ sum_f32m4 = __riscv_vfadd_vv_f32m4_tu(sum_f32m4, sum_f32m4, data_f32m4, vector_length);
3212
+ sumsq_f32m4 = __riscv_vfmacc_vv_f32m4_tu(sumsq_f32m4, data_f32m4, data_f32m4, vector_length);
3213
+ }
3214
+
3215
+ // Horizontal reduction
3216
+ vfloat32m1_t zero_f32m1 = __riscv_vfmv_v_f_f32m1(0.0f, 1);
3217
+ *sum_ptr = __riscv_vfmv_f_s_f32m1_f32(__riscv_vfredusum_vs_f32m4_f32m1(sum_f32m4, zero_f32m1, vlmax)),
3218
+ *sumsq_ptr = __riscv_vfmv_f_s_f32m1_f32(__riscv_vfredusum_vs_f32m4_f32m1(sumsq_f32m4, zero_f32m1, vlmax));
3219
+ }
3220
+
3221
+ NK_INTERNAL void nk_reduce_moments_e3m2_rvv_strided_( //
3222
+ nk_e3m2_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
3223
+ nk_f32_t *sum_ptr, nk_f32_t *sumsq_ptr) {
3224
+ nk_size_t vlmax = __riscv_vsetvlmax_e32m4();
3225
+ vfloat32m4_t sum_f32m4 = __riscv_vfmv_v_f_f32m4(0.0f, vlmax);
3226
+ vfloat32m4_t sumsq_f32m4 = __riscv_vfmv_v_f_f32m4(0.0f, vlmax);
3227
+ unsigned char const *ptr = (unsigned char const *)data_ptr;
3228
+
3229
+ for (nk_size_t vector_length; count > 0; count -= vector_length, ptr += vector_length * stride_bytes) {
3230
+ vector_length = __riscv_vsetvl_e8m1(count);
3231
+ vuint8m1_t data_u8m1 = __riscv_vlse8_v_u8m1((uint8_t const *)ptr, (nk_ssize_t)stride_bytes, vector_length);
3232
+
3233
+ // Convert e3m2 → f32 (m1 → m4)
3234
+ vfloat32m4_t data_f32m4 = nk_e3m2m1_to_f32m4_rvv_(data_u8m1, vector_length);
3235
+
3236
+ // Accumulate at f32 precision
3237
+ sum_f32m4 = __riscv_vfadd_vv_f32m4_tu(sum_f32m4, sum_f32m4, data_f32m4, vector_length);
3238
+ sumsq_f32m4 = __riscv_vfmacc_vv_f32m4_tu(sumsq_f32m4, data_f32m4, data_f32m4, vector_length);
3239
+ }
3240
+
3241
+ // Horizontal reduction
3242
+ vfloat32m1_t zero_f32m1 = __riscv_vfmv_v_f_f32m1(0.0f, 1);
3243
+ *sum_ptr = __riscv_vfmv_f_s_f32m1_f32(__riscv_vfredusum_vs_f32m4_f32m1(sum_f32m4, zero_f32m1, vlmax)),
3244
+ *sumsq_ptr = __riscv_vfmv_f_s_f32m1_f32(__riscv_vfredusum_vs_f32m4_f32m1(sumsq_f32m4, zero_f32m1, vlmax));
3245
+ }
3246
+
3247
+ NK_PUBLIC void nk_reduce_moments_e3m2_rvv( //
3248
+ nk_e3m2_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
3249
+ nk_f32_t *sum_ptr, nk_f32_t *sumsq_ptr) {
3250
+ nk_size_t stride_elements = stride_bytes / sizeof(nk_e3m2_t);
3251
+ int aligned = (stride_bytes % sizeof(nk_e3m2_t) == 0);
3252
+
3253
+ if (count == 0) *sum_ptr = 0.0f, *sumsq_ptr = 0.0f;
3254
+ else if (!aligned) nk_reduce_moments_e3m2_serial(data_ptr, count, stride_bytes, sum_ptr, sumsq_ptr);
3255
+ else if (stride_elements == 1) nk_reduce_moments_e3m2_rvv_contiguous_(data_ptr, count, sum_ptr, sumsq_ptr);
3256
+ else nk_reduce_moments_e3m2_rvv_strided_(data_ptr, count, stride_bytes, sum_ptr, sumsq_ptr);
3257
+ }
3258
+
3259
+ NK_INTERNAL void nk_reduce_minmax_e3m2_rvv_contiguous_( //
3260
+ nk_e3m2_t const *data_ptr, nk_size_t count, //
3261
+ nk_e3m2_t *min_value_ptr, nk_size_t *min_index_ptr, //
3262
+ nk_e3m2_t *max_value_ptr, nk_size_t *max_index_ptr) {
3263
+ nk_size_t vlmax = __riscv_vsetvlmax_e8m1();
3264
+ vuint8m1_t min_u8m1 = __riscv_vmv_v_x_u8m1(0x3F, vlmax);
3265
+ vuint8m1_t max_u8m1 = __riscv_vmv_v_x_u8m1(0x00, vlmax);
3266
+ vuint64m8_t min_indices_u64m8 = __riscv_vmv_v_x_u64m8(0, vlmax);
3267
+ vuint64m8_t max_indices_u64m8 = __riscv_vmv_v_x_u64m8(0, vlmax);
3268
+
3269
+ nk_size_t offset = 0;
3270
+ for (nk_size_t vector_length; count > 0;
3271
+ count -= vector_length, offset += vector_length, data_ptr += vector_length) {
3272
+ vector_length = __riscv_vsetvl_e8m1(count);
3273
+ vuint8m1_t raw_u8m1 = __riscv_vle8_v_u8m1((uint8_t const *)data_ptr, vector_length);
3274
+
3275
+ vuint8m1_t comparable_u8m1 = nk_fp6m1_to_comparable_u8m1_rvv_(raw_u8m1, vector_length);
3276
+ vuint64m8_t pos_u64m8 = __riscv_vadd_vx_u64m8(__riscv_vid_v_u64m8(vector_length), (nk_u64_t)offset,
3277
+ vector_length);
3278
+
3279
+ vbool8_t less_b8 = __riscv_vmsltu_vv_u8m1_b8(comparable_u8m1, min_u8m1, vector_length);
3280
+ min_u8m1 = __riscv_vmerge_vvm_u8m1_tu(min_u8m1, min_u8m1, comparable_u8m1, less_b8, vector_length);
3281
+ min_indices_u64m8 = __riscv_vmerge_vvm_u64m8_tu(min_indices_u64m8, min_indices_u64m8, pos_u64m8, less_b8,
3282
+ vector_length);
3283
+
3284
+ vbool8_t greater_b8 = __riscv_vmsltu_vv_u8m1_b8(max_u8m1, comparable_u8m1, vector_length);
3285
+ max_u8m1 = __riscv_vmerge_vvm_u8m1_tu(max_u8m1, max_u8m1, comparable_u8m1, greater_b8, vector_length);
3286
+ max_indices_u64m8 = __riscv_vmerge_vvm_u64m8_tu(max_indices_u64m8, max_indices_u64m8, pos_u64m8, greater_b8,
3287
+ vector_length);
3288
+ }
3289
+
3290
+ // Horizontal reduction + convert back
3291
+ vuint8m1_t init_max_u8m1 = __riscv_vmv_v_x_u8m1(0x3F, 1);
3292
+ nk_u8_t min_comparable = __riscv_vmv_x_s_u8m1_u8(__riscv_vredminu_vs_u8m1_u8m1(min_u8m1, init_max_u8m1, vlmax));
3293
+ vbool8_t min_match_b8 = __riscv_vmseq_vx_u8m1_b8(min_u8m1, min_comparable, vlmax);
3294
+ vuint64m8_t sentinel_u64m8 = __riscv_vmv_v_x_u64m8(NK_U64_MAX, vlmax);
3295
+ vuint64m8_t min_cands_u64m8 = __riscv_vmerge_vvm_u64m8(sentinel_u64m8, min_indices_u64m8, min_match_b8, vlmax);
3296
+ vuint64m1_t init_umax_u64m1 = __riscv_vmv_v_x_u64m1(NK_U64_MAX, 1);
3297
+ *min_index_ptr = (nk_size_t)__riscv_vmv_x_s_u64m1_u64(
3298
+ __riscv_vredminu_vs_u64m8_u64m1(min_cands_u64m8, init_umax_u64m1, vlmax));
3299
+
3300
+ vuint8m1_t min_vec_u8m1 = __riscv_vmv_v_x_u8m1(min_comparable, 1);
3301
+ vuint8m1_t min_raw_u8m1 = nk_comparable_to_fp6m1_rvv_(min_vec_u8m1, 1);
3302
+ *min_value_ptr = (nk_e3m2_t)__riscv_vmv_x_s_u8m1_u8(min_raw_u8m1);
3303
+
3304
+ vuint8m1_t init_min_u8m1 = __riscv_vmv_v_x_u8m1(0x00, 1);
3305
+ nk_u8_t max_comparable = __riscv_vmv_x_s_u8m1_u8(__riscv_vredmaxu_vs_u8m1_u8m1(max_u8m1, init_min_u8m1, vlmax));
3306
+ vbool8_t max_match_b8 = __riscv_vmseq_vx_u8m1_b8(max_u8m1, max_comparable, vlmax);
3307
+ vuint64m8_t max_cands_u64m8 = __riscv_vmerge_vvm_u64m8(sentinel_u64m8, max_indices_u64m8, max_match_b8, vlmax);
3308
+ *max_index_ptr = (nk_size_t)__riscv_vmv_x_s_u64m1_u64(
3309
+ __riscv_vredminu_vs_u64m8_u64m1(max_cands_u64m8, init_umax_u64m1, vlmax));
3310
+
3311
+ vuint8m1_t max_vec_u8m1 = __riscv_vmv_v_x_u8m1(max_comparable, 1);
3312
+ vuint8m1_t max_raw_u8m1 = nk_comparable_to_fp6m1_rvv_(max_vec_u8m1, 1);
3313
+ *max_value_ptr = (nk_e3m2_t)__riscv_vmv_x_s_u8m1_u8(max_raw_u8m1);
3314
+ }
3315
+
3316
+ NK_INTERNAL void nk_reduce_minmax_e3m2_rvv_strided_( //
3317
+ nk_e3m2_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
3318
+ nk_e3m2_t *min_value_ptr, nk_size_t *min_index_ptr, //
3319
+ nk_e3m2_t *max_value_ptr, nk_size_t *max_index_ptr) {
3320
+ nk_size_t vlmax = __riscv_vsetvlmax_e8m1();
3321
+ vuint8m1_t min_u8m1 = __riscv_vmv_v_x_u8m1(0x3F, vlmax);
3322
+ vuint8m1_t max_u8m1 = __riscv_vmv_v_x_u8m1(0x00, vlmax);
3323
+ vuint64m8_t min_indices_u64m8 = __riscv_vmv_v_x_u64m8(0, vlmax);
3324
+ vuint64m8_t max_indices_u64m8 = __riscv_vmv_v_x_u64m8(0, vlmax);
3325
+ unsigned char const *ptr = (unsigned char const *)data_ptr;
3326
+
3327
+ nk_size_t offset = 0;
3328
+ for (nk_size_t vector_length; count > 0;
3329
+ count -= vector_length, offset += vector_length, ptr += vector_length * stride_bytes) {
3330
+ vector_length = __riscv_vsetvl_e8m1(count);
3331
+ vuint8m1_t raw_u8m1 = __riscv_vlse8_v_u8m1((uint8_t const *)ptr, (nk_ssize_t)stride_bytes, vector_length);
3332
+
3333
+ vuint8m1_t comparable_u8m1 = nk_fp6m1_to_comparable_u8m1_rvv_(raw_u8m1, vector_length);
3334
+ vuint64m8_t pos_u64m8 = __riscv_vadd_vx_u64m8(__riscv_vid_v_u64m8(vector_length), (nk_u64_t)offset,
3335
+ vector_length);
3336
+
3337
+ vbool8_t less_b8 = __riscv_vmsltu_vv_u8m1_b8(comparable_u8m1, min_u8m1, vector_length);
3338
+ min_u8m1 = __riscv_vmerge_vvm_u8m1_tu(min_u8m1, min_u8m1, comparable_u8m1, less_b8, vector_length);
3339
+ min_indices_u64m8 = __riscv_vmerge_vvm_u64m8_tu(min_indices_u64m8, min_indices_u64m8, pos_u64m8, less_b8,
3340
+ vector_length);
3341
+
3342
+ vbool8_t greater_b8 = __riscv_vmsltu_vv_u8m1_b8(max_u8m1, comparable_u8m1, vector_length);
3343
+ max_u8m1 = __riscv_vmerge_vvm_u8m1_tu(max_u8m1, max_u8m1, comparable_u8m1, greater_b8, vector_length);
3344
+ max_indices_u64m8 = __riscv_vmerge_vvm_u64m8_tu(max_indices_u64m8, max_indices_u64m8, pos_u64m8, greater_b8,
3345
+ vector_length);
3346
+ }
3347
+
3348
+ // Horizontal reduction (same as contiguous)
3349
+ vuint8m1_t init_max_u8m1 = __riscv_vmv_v_x_u8m1(0x3F, 1);
3350
+ nk_u8_t min_comparable = __riscv_vmv_x_s_u8m1_u8(__riscv_vredminu_vs_u8m1_u8m1(min_u8m1, init_max_u8m1, vlmax));
3351
+ vbool8_t min_match_b8 = __riscv_vmseq_vx_u8m1_b8(min_u8m1, min_comparable, vlmax);
3352
+ vuint64m8_t sentinel_u64m8 = __riscv_vmv_v_x_u64m8(NK_U64_MAX, vlmax);
3353
+ vuint64m8_t min_cands_u64m8 = __riscv_vmerge_vvm_u64m8(sentinel_u64m8, min_indices_u64m8, min_match_b8, vlmax);
3354
+ vuint64m1_t init_umax_u64m1 = __riscv_vmv_v_x_u64m1(NK_U64_MAX, 1);
3355
+ *min_index_ptr = (nk_size_t)__riscv_vmv_x_s_u64m1_u64(
3356
+ __riscv_vredminu_vs_u64m8_u64m1(min_cands_u64m8, init_umax_u64m1, vlmax));
3357
+
3358
+ vuint8m1_t min_vec_u8m1 = __riscv_vmv_v_x_u8m1(min_comparable, 1);
3359
+ vuint8m1_t min_raw_u8m1 = nk_comparable_to_fp6m1_rvv_(min_vec_u8m1, 1);
3360
+ *min_value_ptr = (nk_e3m2_t)__riscv_vmv_x_s_u8m1_u8(min_raw_u8m1);
3361
+
3362
+ vuint8m1_t init_min_u8m1 = __riscv_vmv_v_x_u8m1(0x00, 1);
3363
+ nk_u8_t max_comparable = __riscv_vmv_x_s_u8m1_u8(__riscv_vredmaxu_vs_u8m1_u8m1(max_u8m1, init_min_u8m1, vlmax));
3364
+ vbool8_t max_match_b8 = __riscv_vmseq_vx_u8m1_b8(max_u8m1, max_comparable, vlmax);
3365
+ vuint64m8_t max_cands_u64m8 = __riscv_vmerge_vvm_u64m8(sentinel_u64m8, max_indices_u64m8, max_match_b8, vlmax);
3366
+ *max_index_ptr = (nk_size_t)__riscv_vmv_x_s_u64m1_u64(
3367
+ __riscv_vredminu_vs_u64m8_u64m1(max_cands_u64m8, init_umax_u64m1, vlmax));
3368
+
3369
+ vuint8m1_t max_vec_u8m1 = __riscv_vmv_v_x_u8m1(max_comparable, 1);
3370
+ vuint8m1_t max_raw_u8m1 = nk_comparable_to_fp6m1_rvv_(max_vec_u8m1, 1);
3371
+ *max_value_ptr = (nk_e3m2_t)__riscv_vmv_x_s_u8m1_u8(max_raw_u8m1);
3372
+ }
3373
+
3374
+ NK_PUBLIC void nk_reduce_minmax_e3m2_rvv( //
3375
+ nk_e3m2_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
3376
+ nk_e3m2_t *min_value_ptr, nk_size_t *min_index_ptr, //
3377
+ nk_e3m2_t *max_value_ptr, nk_size_t *max_index_ptr) {
3378
+ nk_size_t stride_elements = stride_bytes / sizeof(nk_e3m2_t);
3379
+ int aligned = (stride_bytes % sizeof(nk_e3m2_t) == 0);
3380
+
3381
+ if (count == 0)
3382
+ *min_value_ptr = NK_E3M2_MAX, *min_index_ptr = NK_SIZE_MAX, *max_value_ptr = NK_E3M2_MIN,
3383
+ *max_index_ptr = NK_SIZE_MAX;
3384
+ else if (!aligned)
3385
+ nk_reduce_minmax_e3m2_serial(data_ptr, count, stride_bytes, min_value_ptr, min_index_ptr, max_value_ptr,
3386
+ max_index_ptr);
3387
+ else if (stride_elements == 1)
3388
+ nk_reduce_minmax_e3m2_rvv_contiguous_(data_ptr, count, min_value_ptr, min_index_ptr, max_value_ptr,
3389
+ max_index_ptr);
3390
+ else
3391
+ nk_reduce_minmax_e3m2_rvv_strided_(data_ptr, count, stride_bytes, min_value_ptr, min_index_ptr, max_value_ptr,
3392
+ max_index_ptr);
3393
+ }
3394
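Both FP6 variants share the same cast and comparable helpers because they differ only in how the six value bits are split between exponent and mantissa. With the usual OCP MX conventions, E2M3 (two exponent bits, bias 1, three mantissa bits) tops out at 1.875 * 2^2 = 7.5, while E3M2 (three exponent bits, bias 3, two mantissa bits) reaches 1.75 * 2^4 = 28; neither format reserves encodings for infinities or NaNs, which is why the FP8 kernels' NaN masking has no counterpart in the four FP6 kernels above.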
+
3395
+ #if defined(__clang__)
3396
+ #pragma clang attribute pop
3397
+ #elif defined(__GNUC__)
3398
+ #pragma GCC pop_options
3399
+ #endif
3400
+
3401
+ #if defined(__cplusplus)
3402
+ } // extern "C"
3403
+ #endif
3404
+
3405
+ #endif // NK_TARGET_RVV
3406
+ #endif // NK_TARGET_RISCV_
3407
+ #endif // NK_REDUCE_RVV_H
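A caller-side sketch of how these kernels might be selected at build time, mirroring the header's own guards; how NK_TARGET_RVV is actually defined is not shown here, so treat the condition as a placeholder for "the RVV kernels were compiled in".

```c
#if NK_TARGET_RVV
    nk_reduce_moments_e4m3_rvv(data_ptr, count, stride_bytes, &sum, &sumsq);
#else
    nk_reduce_moments_e4m3_serial(data_ptr, count, stride_bytes, &sum, &sumsq);
#endif
```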