numkong 7.0.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/LICENSE +201 -0
- package/README.md +495 -0
- package/binding.gyp +540 -0
- package/c/dispatch.h +512 -0
- package/c/dispatch_bf16.c +389 -0
- package/c/dispatch_bf16c.c +52 -0
- package/c/dispatch_e2m3.c +263 -0
- package/c/dispatch_e3m2.c +243 -0
- package/c/dispatch_e4m3.c +276 -0
- package/c/dispatch_e5m2.c +272 -0
- package/c/dispatch_f16.c +376 -0
- package/c/dispatch_f16c.c +58 -0
- package/c/dispatch_f32.c +378 -0
- package/c/dispatch_f32c.c +99 -0
- package/c/dispatch_f64.c +296 -0
- package/c/dispatch_f64c.c +98 -0
- package/c/dispatch_i16.c +96 -0
- package/c/dispatch_i32.c +89 -0
- package/c/dispatch_i4.c +150 -0
- package/c/dispatch_i64.c +86 -0
- package/c/dispatch_i8.c +289 -0
- package/c/dispatch_other.c +330 -0
- package/c/dispatch_u1.c +148 -0
- package/c/dispatch_u16.c +124 -0
- package/c/dispatch_u32.c +118 -0
- package/c/dispatch_u4.c +150 -0
- package/c/dispatch_u64.c +102 -0
- package/c/dispatch_u8.c +303 -0
- package/c/numkong.c +950 -0
- package/include/README.md +573 -0
- package/include/module.modulemap +129 -0
- package/include/numkong/attention/sapphireamx.h +1361 -0
- package/include/numkong/attention/sme.h +2066 -0
- package/include/numkong/attention.h +49 -0
- package/include/numkong/capabilities.h +748 -0
- package/include/numkong/cast/README.md +262 -0
- package/include/numkong/cast/haswell.h +975 -0
- package/include/numkong/cast/icelake.h +470 -0
- package/include/numkong/cast/neon.h +1192 -0
- package/include/numkong/cast/rvv.h +1021 -0
- package/include/numkong/cast/sapphire.h +262 -0
- package/include/numkong/cast/serial.h +2262 -0
- package/include/numkong/cast/skylake.h +856 -0
- package/include/numkong/cast/v128relaxed.h +180 -0
- package/include/numkong/cast.h +230 -0
- package/include/numkong/curved/README.md +223 -0
- package/include/numkong/curved/genoa.h +182 -0
- package/include/numkong/curved/haswell.h +276 -0
- package/include/numkong/curved/neon.h +205 -0
- package/include/numkong/curved/neonbfdot.h +212 -0
- package/include/numkong/curved/neonhalf.h +212 -0
- package/include/numkong/curved/rvv.h +305 -0
- package/include/numkong/curved/serial.h +207 -0
- package/include/numkong/curved/skylake.h +457 -0
- package/include/numkong/curved/smef64.h +506 -0
- package/include/numkong/curved.h +517 -0
- package/include/numkong/curved.hpp +144 -0
- package/include/numkong/dot/README.md +425 -0
- package/include/numkong/dot/alder.h +563 -0
- package/include/numkong/dot/genoa.h +315 -0
- package/include/numkong/dot/haswell.h +1688 -0
- package/include/numkong/dot/icelake.h +883 -0
- package/include/numkong/dot/neon.h +818 -0
- package/include/numkong/dot/neonbfdot.h +244 -0
- package/include/numkong/dot/neonfhm.h +360 -0
- package/include/numkong/dot/neonhalf.h +198 -0
- package/include/numkong/dot/neonsdot.h +508 -0
- package/include/numkong/dot/rvv.h +714 -0
- package/include/numkong/dot/rvvbb.h +72 -0
- package/include/numkong/dot/rvvbf16.h +123 -0
- package/include/numkong/dot/rvvhalf.h +129 -0
- package/include/numkong/dot/sapphire.h +141 -0
- package/include/numkong/dot/serial.h +838 -0
- package/include/numkong/dot/sierra.h +405 -0
- package/include/numkong/dot/skylake.h +1084 -0
- package/include/numkong/dot/sve.h +379 -0
- package/include/numkong/dot/svebfdot.h +74 -0
- package/include/numkong/dot/svehalf.h +123 -0
- package/include/numkong/dot/v128relaxed.h +1258 -0
- package/include/numkong/dot.h +1070 -0
- package/include/numkong/dot.hpp +94 -0
- package/include/numkong/dots/README.md +496 -0
- package/include/numkong/dots/alder.h +114 -0
- package/include/numkong/dots/genoa.h +94 -0
- package/include/numkong/dots/haswell.h +295 -0
- package/include/numkong/dots/icelake.h +171 -0
- package/include/numkong/dots/neon.h +120 -0
- package/include/numkong/dots/neonbfdot.h +58 -0
- package/include/numkong/dots/neonfhm.h +94 -0
- package/include/numkong/dots/neonhalf.h +57 -0
- package/include/numkong/dots/neonsdot.h +108 -0
- package/include/numkong/dots/rvv.h +2486 -0
- package/include/numkong/dots/sapphireamx.h +3973 -0
- package/include/numkong/dots/serial.h +2844 -0
- package/include/numkong/dots/sierra.h +97 -0
- package/include/numkong/dots/skylake.h +196 -0
- package/include/numkong/dots/sme.h +5372 -0
- package/include/numkong/dots/smebi32.h +461 -0
- package/include/numkong/dots/smef64.h +1318 -0
- package/include/numkong/dots/smehalf.h +47 -0
- package/include/numkong/dots/v128relaxed.h +294 -0
- package/include/numkong/dots.h +2804 -0
- package/include/numkong/dots.hpp +639 -0
- package/include/numkong/each/README.md +469 -0
- package/include/numkong/each/haswell.h +1658 -0
- package/include/numkong/each/icelake.h +272 -0
- package/include/numkong/each/neon.h +1104 -0
- package/include/numkong/each/neonbfdot.h +212 -0
- package/include/numkong/each/neonhalf.h +410 -0
- package/include/numkong/each/rvv.h +1121 -0
- package/include/numkong/each/sapphire.h +477 -0
- package/include/numkong/each/serial.h +260 -0
- package/include/numkong/each/skylake.h +1562 -0
- package/include/numkong/each.h +2146 -0
- package/include/numkong/each.hpp +434 -0
- package/include/numkong/geospatial/README.md +147 -0
- package/include/numkong/geospatial/haswell.h +593 -0
- package/include/numkong/geospatial/neon.h +571 -0
- package/include/numkong/geospatial/rvv.h +701 -0
- package/include/numkong/geospatial/serial.h +309 -0
- package/include/numkong/geospatial/skylake.h +577 -0
- package/include/numkong/geospatial/v128relaxed.h +613 -0
- package/include/numkong/geospatial.h +453 -0
- package/include/numkong/geospatial.hpp +235 -0
- package/include/numkong/matrix.hpp +336 -0
- package/include/numkong/maxsim/README.md +187 -0
- package/include/numkong/maxsim/alder.h +511 -0
- package/include/numkong/maxsim/genoa.h +115 -0
- package/include/numkong/maxsim/haswell.h +553 -0
- package/include/numkong/maxsim/icelake.h +480 -0
- package/include/numkong/maxsim/neonsdot.h +394 -0
- package/include/numkong/maxsim/sapphireamx.h +877 -0
- package/include/numkong/maxsim/serial.h +490 -0
- package/include/numkong/maxsim/sme.h +929 -0
- package/include/numkong/maxsim/v128relaxed.h +280 -0
- package/include/numkong/maxsim.h +571 -0
- package/include/numkong/maxsim.hpp +133 -0
- package/include/numkong/mesh/README.md +227 -0
- package/include/numkong/mesh/haswell.h +2235 -0
- package/include/numkong/mesh/neon.h +1329 -0
- package/include/numkong/mesh/neonbfdot.h +842 -0
- package/include/numkong/mesh/neonhalf.h +616 -0
- package/include/numkong/mesh/rvv.h +916 -0
- package/include/numkong/mesh/serial.h +742 -0
- package/include/numkong/mesh/skylake.h +1135 -0
- package/include/numkong/mesh/v128relaxed.h +1052 -0
- package/include/numkong/mesh.h +652 -0
- package/include/numkong/mesh.hpp +762 -0
- package/include/numkong/numkong.h +78 -0
- package/include/numkong/numkong.hpp +57 -0
- package/include/numkong/probability/README.md +173 -0
- package/include/numkong/probability/haswell.h +267 -0
- package/include/numkong/probability/neon.h +225 -0
- package/include/numkong/probability/rvv.h +409 -0
- package/include/numkong/probability/serial.h +169 -0
- package/include/numkong/probability/skylake.h +324 -0
- package/include/numkong/probability.h +383 -0
- package/include/numkong/probability.hpp +120 -0
- package/include/numkong/random.h +50 -0
- package/include/numkong/random.hpp +285 -0
- package/include/numkong/reduce/README.md +547 -0
- package/include/numkong/reduce/alder.h +632 -0
- package/include/numkong/reduce/genoa.h +201 -0
- package/include/numkong/reduce/haswell.h +3783 -0
- package/include/numkong/reduce/icelake.h +549 -0
- package/include/numkong/reduce/neon.h +3841 -0
- package/include/numkong/reduce/neonbfdot.h +353 -0
- package/include/numkong/reduce/neonfhm.h +665 -0
- package/include/numkong/reduce/neonhalf.h +157 -0
- package/include/numkong/reduce/neonsdot.h +357 -0
- package/include/numkong/reduce/rvv.h +3407 -0
- package/include/numkong/reduce/serial.h +757 -0
- package/include/numkong/reduce/sierra.h +338 -0
- package/include/numkong/reduce/skylake.h +3792 -0
- package/include/numkong/reduce/v128relaxed.h +2302 -0
- package/include/numkong/reduce.h +1597 -0
- package/include/numkong/reduce.hpp +633 -0
- package/include/numkong/scalar/README.md +89 -0
- package/include/numkong/scalar/haswell.h +113 -0
- package/include/numkong/scalar/neon.h +122 -0
- package/include/numkong/scalar/neonhalf.h +70 -0
- package/include/numkong/scalar/rvv.h +211 -0
- package/include/numkong/scalar/sapphire.h +63 -0
- package/include/numkong/scalar/serial.h +332 -0
- package/include/numkong/scalar/v128relaxed.h +56 -0
- package/include/numkong/scalar.h +683 -0
- package/include/numkong/set/README.md +179 -0
- package/include/numkong/set/haswell.h +334 -0
- package/include/numkong/set/icelake.h +485 -0
- package/include/numkong/set/neon.h +364 -0
- package/include/numkong/set/rvv.h +226 -0
- package/include/numkong/set/rvvbb.h +117 -0
- package/include/numkong/set/serial.h +174 -0
- package/include/numkong/set/sve.h +185 -0
- package/include/numkong/set/v128relaxed.h +240 -0
- package/include/numkong/set.h +457 -0
- package/include/numkong/set.hpp +114 -0
- package/include/numkong/sets/README.md +149 -0
- package/include/numkong/sets/haswell.h +63 -0
- package/include/numkong/sets/icelake.h +66 -0
- package/include/numkong/sets/neon.h +61 -0
- package/include/numkong/sets/serial.h +43 -0
- package/include/numkong/sets/smebi32.h +1099 -0
- package/include/numkong/sets/v128relaxed.h +58 -0
- package/include/numkong/sets.h +339 -0
- package/include/numkong/sparse/README.md +156 -0
- package/include/numkong/sparse/icelake.h +463 -0
- package/include/numkong/sparse/neon.h +288 -0
- package/include/numkong/sparse/serial.h +117 -0
- package/include/numkong/sparse/sve2.h +507 -0
- package/include/numkong/sparse/turin.h +322 -0
- package/include/numkong/sparse.h +363 -0
- package/include/numkong/sparse.hpp +113 -0
- package/include/numkong/spatial/README.md +435 -0
- package/include/numkong/spatial/alder.h +607 -0
- package/include/numkong/spatial/genoa.h +290 -0
- package/include/numkong/spatial/haswell.h +960 -0
- package/include/numkong/spatial/icelake.h +586 -0
- package/include/numkong/spatial/neon.h +773 -0
- package/include/numkong/spatial/neonbfdot.h +165 -0
- package/include/numkong/spatial/neonhalf.h +118 -0
- package/include/numkong/spatial/neonsdot.h +261 -0
- package/include/numkong/spatial/rvv.h +984 -0
- package/include/numkong/spatial/rvvbf16.h +123 -0
- package/include/numkong/spatial/rvvhalf.h +117 -0
- package/include/numkong/spatial/sapphire.h +343 -0
- package/include/numkong/spatial/serial.h +346 -0
- package/include/numkong/spatial/sierra.h +323 -0
- package/include/numkong/spatial/skylake.h +606 -0
- package/include/numkong/spatial/sve.h +224 -0
- package/include/numkong/spatial/svebfdot.h +122 -0
- package/include/numkong/spatial/svehalf.h +109 -0
- package/include/numkong/spatial/v128relaxed.h +717 -0
- package/include/numkong/spatial.h +1425 -0
- package/include/numkong/spatial.hpp +183 -0
- package/include/numkong/spatials/README.md +580 -0
- package/include/numkong/spatials/alder.h +94 -0
- package/include/numkong/spatials/genoa.h +94 -0
- package/include/numkong/spatials/haswell.h +219 -0
- package/include/numkong/spatials/icelake.h +113 -0
- package/include/numkong/spatials/neon.h +109 -0
- package/include/numkong/spatials/neonbfdot.h +60 -0
- package/include/numkong/spatials/neonfhm.h +92 -0
- package/include/numkong/spatials/neonhalf.h +58 -0
- package/include/numkong/spatials/neonsdot.h +109 -0
- package/include/numkong/spatials/rvv.h +1960 -0
- package/include/numkong/spatials/sapphireamx.h +1149 -0
- package/include/numkong/spatials/serial.h +226 -0
- package/include/numkong/spatials/sierra.h +96 -0
- package/include/numkong/spatials/skylake.h +184 -0
- package/include/numkong/spatials/sme.h +1901 -0
- package/include/numkong/spatials/smef64.h +465 -0
- package/include/numkong/spatials/v128relaxed.h +240 -0
- package/include/numkong/spatials.h +3021 -0
- package/include/numkong/spatials.hpp +508 -0
- package/include/numkong/tensor.hpp +1592 -0
- package/include/numkong/trigonometry/README.md +184 -0
- package/include/numkong/trigonometry/haswell.h +652 -0
- package/include/numkong/trigonometry/neon.h +639 -0
- package/include/numkong/trigonometry/rvv.h +699 -0
- package/include/numkong/trigonometry/serial.h +703 -0
- package/include/numkong/trigonometry/skylake.h +721 -0
- package/include/numkong/trigonometry/v128relaxed.h +666 -0
- package/include/numkong/trigonometry.h +467 -0
- package/include/numkong/trigonometry.hpp +166 -0
- package/include/numkong/types.h +1384 -0
- package/include/numkong/types.hpp +5603 -0
- package/include/numkong/vector.hpp +698 -0
- package/javascript/README.md +246 -0
- package/javascript/dist/cjs/numkong-wasm.d.ts +166 -0
- package/javascript/dist/cjs/numkong-wasm.js +617 -0
- package/javascript/dist/cjs/numkong.d.ts +343 -0
- package/javascript/dist/cjs/numkong.js +523 -0
- package/javascript/dist/cjs/package.json +3 -0
- package/javascript/dist/cjs/types.d.ts +284 -0
- package/javascript/dist/cjs/types.js +653 -0
- package/javascript/dist/esm/numkong-wasm.d.ts +166 -0
- package/javascript/dist/esm/numkong-wasm.js +595 -0
- package/javascript/dist/esm/numkong.d.ts +343 -0
- package/javascript/dist/esm/numkong.js +452 -0
- package/javascript/dist/esm/package.json +3 -0
- package/javascript/dist/esm/types.d.ts +284 -0
- package/javascript/dist/esm/types.js +630 -0
- package/javascript/dist-package-cjs.json +3 -0
- package/javascript/dist-package-esm.json +3 -0
- package/javascript/node-gyp-build.d.ts +1 -0
- package/javascript/numkong-wasm.ts +756 -0
- package/javascript/numkong.c +689 -0
- package/javascript/numkong.ts +575 -0
- package/javascript/tsconfig-base.json +39 -0
- package/javascript/tsconfig-cjs.json +8 -0
- package/javascript/tsconfig-esm.json +8 -0
- package/javascript/types.ts +674 -0
- package/package.json +87 -0
|
@@ -0,0 +1,3407 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* @brief SIMD-accelerated Reductions for RISC-V.
|
|
3
|
+
* @file include/numkong/reduce/rvv.h
|
|
4
|
+
* @author Ash Vardanian
|
|
5
|
+
* @date February 13, 2026
|
|
6
|
+
*
|
|
7
|
+
* @sa include/numkong/reduce.h
|
|
8
|
+
*/
|
|
9
|
+
#ifndef NK_REDUCE_RVV_H
|
|
10
|
+
#define NK_REDUCE_RVV_H
|
|
11
|
+
|
|
12
|
+
#if NK_TARGET_RISCV_
|
|
13
|
+
#if NK_TARGET_RVV
|
|
14
|
+
|
|
15
|
+
#include "numkong/types.h"
|
|
16
|
+
#include "numkong/cast/rvv.h"
|
|
17
|
+
#include "numkong/reduce/serial.h"
|
|
18
|
+
|
|
19
|
+
#if defined(__clang__)
|
|
20
|
+
#pragma clang attribute push(__attribute__((target("arch=+v"))), apply_to = function)
|
|
21
|
+
#elif defined(__GNUC__)
|
|
22
|
+
#pragma GCC push_options
|
|
23
|
+
#pragma GCC target("arch=+v")
|
|
24
|
+
#endif
|
|
25
|
+
|
|
26
|
+
#if defined(__cplusplus)
|
|
27
|
+
extern "C" {
|
|
28
|
+
#endif
|
|
29
|
+
|
|
30
|
+
/** @brief Saturating horizontal sum of u64m1 via tree fold: O(log vlmax) vector ops.
 *  @param acc_u64m1 Per-lane partial sums to be collapsed into one scalar.
 *  @param vlmax Active vector length used for the fold (RVV vlmax is a power of two).
 *  @return The saturated (clamped at NK_U64_MAX-style ceiling by `vsaddu`) total in lane 0. */
NK_INTERNAL nk_u64_t nk_reduce_vsaddu_u64m1_rvv_(vuint64m1_t acc_u64m1, nk_size_t vlmax) {
    // Fold the upper half of the register onto the lower half, halving the
    // active width each iteration; `vsaddu` saturates so the sum cannot wrap.
    for (nk_size_t half = vlmax >> 1; half > 0; half >>= 1) {
        vuint64m1_t shifted_u64m1 = __riscv_vslidedown_vx_u64m1(acc_u64m1, half, vlmax);
        acc_u64m1 = __riscv_vsaddu_vv_u64m1(acc_u64m1, shifted_u64m1, vlmax);
    }
    // After the fold, element 0 holds the combined total.
    return __riscv_vmv_x_s_u64m1_u64(acc_u64m1);
}
|
|
38
|
+
|
|
39
|
+
/** @brief Saturating horizontal sum of u64m2 via tree fold: O(log vlmax) vector ops.
 *  @param acc_u64m2 Per-lane partial sums (LMUL=2 register group) to collapse.
 *  @param vlmax Active vector length for the fold.
 *  @return The saturated total from lane 0 after the fold. */
NK_INTERNAL nk_u64_t nk_reduce_vsaddu_u64m2_rvv_(vuint64m2_t acc_u64m2, nk_size_t vlmax) {
    // Same halving fold as the m1 variant, operating on an LMUL=2 group;
    // saturating adds keep the running total from wrapping.
    for (nk_size_t half = vlmax >> 1; half > 0; half >>= 1) {
        vuint64m2_t shifted_u64m2 = __riscv_vslidedown_vx_u64m2(acc_u64m2, half, vlmax);
        acc_u64m2 = __riscv_vsaddu_vv_u64m2(acc_u64m2, shifted_u64m2, vlmax);
    }
    return __riscv_vmv_x_s_u64m2_u64(acc_u64m2);
}
|
|
47
|
+
|
|
48
|
+
/** @brief 128-bit horizontal sum of (upper:i64m1, lower:u64m1) via tree fold, then saturate to i64.
 *  Each lane holds a 128-bit partial sum split into an unsigned low half and a
 *  signed high half; the fold adds lanes pairwise while propagating carries
 *  from the low halves into the high halves.
 *  @param sum_lower_u64m1 Low 64 bits of each lane's 128-bit partial sum.
 *  @param sum_upper_i64m1 High (signed) 64 bits of each lane's partial sum.
 *  @param vlmax Active vector length for the fold.
 *  @return The 128-bit total clamped into [NK_I64_MIN, NK_I64_MAX]. */
NK_INTERNAL nk_i64_t nk_reduce_128bit_sum_i64m1_rvv_( //
    vuint64m1_t sum_lower_u64m1, vint64m1_t sum_upper_i64m1, nk_size_t vlmax) {
    for (nk_size_t half = vlmax >> 1; half > 0; half >>= 1) {
        vuint64m1_t shifted_lower_u64m1 = __riscv_vslidedown_vx_u64m1(sum_lower_u64m1, half, vlmax);
        vint64m1_t shifted_upper_i64m1 = __riscv_vslidedown_vx_i64m1(sum_upper_i64m1, half, vlmax);
        vuint64m1_t new_lower_u64m1 = __riscv_vadd_vv_u64m1(sum_lower_u64m1, shifted_lower_u64m1, vlmax);
        // Unsigned overflow detection: the low-half sum wrapped iff it is
        // smaller than one of its addends; that lane carries 1 into the high half.
        vbool64_t carry_b64 = __riscv_vmsltu_vv_u64m1_b64(new_lower_u64m1, sum_lower_u64m1, vlmax);
        vint64m1_t carry_i64m1 = __riscv_vmerge_vxm_i64m1(__riscv_vmv_v_x_i64m1(0, vlmax), 1, carry_b64, vlmax);
        sum_upper_i64m1 = __riscv_vadd_vv_i64m1(sum_upper_i64m1, shifted_upper_i64m1, vlmax);
        sum_upper_i64m1 = __riscv_vadd_vv_i64m1(sum_upper_i64m1, carry_i64m1, vlmax);
        sum_lower_u64m1 = new_lower_u64m1;
    }
    nk_u64_t total_lower = __riscv_vmv_x_s_u64m1_u64(sum_lower_u64m1);
    nk_i64_t total_upper = __riscv_vmv_x_s_i64m1_i64(sum_upper_i64m1);
    nk_i64_t total_lower_signed = (nk_i64_t)total_lower;
    // If the high half is just the sign extension of the low half, the value
    // fits in 64 bits; otherwise clamp toward the sign of the high half.
    if (total_upper == (total_lower_signed >> 63)) return total_lower_signed;
    else if (total_upper >= 0) return NK_I64_MAX;
    else return NK_I64_MIN;
}
|
|
68
|
+
|
|
69
|
+
/** @brief 128-bit horizontal sum of (upper:i64m2, lower:u64m2) via tree fold, then saturate to i64.
 *  LMUL=2 twin of nk_reduce_128bit_sum_i64m1_rvv_: folds lanes pairwise while
 *  carrying unsigned low-half overflow into the signed high halves, then clamps
 *  the final 128-bit total into the i64 range. */
NK_INTERNAL nk_i64_t nk_reduce_128bit_sum_i64m2_rvv_( //
    vuint64m2_t sum_lower_u64m2, vint64m2_t sum_upper_i64m2, nk_size_t vlmax) {
    for (nk_size_t half = vlmax >> 1; half > 0; half >>= 1) {
        vuint64m2_t shifted_lower_u64m2 = __riscv_vslidedown_vx_u64m2(sum_lower_u64m2, half, vlmax);
        vint64m2_t shifted_upper_i64m2 = __riscv_vslidedown_vx_i64m2(sum_upper_i64m2, half, vlmax);
        vuint64m2_t new_lower_u64m2 = __riscv_vadd_vv_u64m2(sum_lower_u64m2, shifted_lower_u64m2, vlmax);
        // Wrapped low-half sums (new < old addend) produce a carry of 1 into
        // the corresponding high lane; b32 is the mask type for u64m2.
        vbool32_t carry_b32 = __riscv_vmsltu_vv_u64m2_b32(new_lower_u64m2, sum_lower_u64m2, vlmax);
        vint64m2_t carry_i64m2 = __riscv_vmerge_vxm_i64m2(__riscv_vmv_v_x_i64m2(0, vlmax), 1, carry_b32, vlmax);
        sum_upper_i64m2 = __riscv_vadd_vv_i64m2(sum_upper_i64m2, shifted_upper_i64m2, vlmax);
        sum_upper_i64m2 = __riscv_vadd_vv_i64m2(sum_upper_i64m2, carry_i64m2, vlmax);
        sum_lower_u64m2 = new_lower_u64m2;
    }
    nk_u64_t total_lower = __riscv_vmv_x_s_u64m2_u64(sum_lower_u64m2);
    nk_i64_t total_upper = __riscv_vmv_x_s_i64m2_i64(sum_upper_i64m2);
    nk_i64_t total_lower_signed = (nk_i64_t)total_lower;
    // High half equal to the low half's sign extension => value fits in i64.
    if (total_upper == (total_lower_signed >> 63)) return total_lower_signed;
    else if (total_upper >= 0) return NK_I64_MAX;
    else return NK_I64_MIN;
}
|
|
89
|
+
|
|
90
|
+
/** @brief Sum and sum-of-squares of a contiguous f32 array, accumulated in f64.
 *  @param data Contiguous input elements.
 *  @param count Number of elements (caller guarantees count > 0 via the dispatcher).
 *  @param sum_ptr Receives the f64 sum of all elements.
 *  @param sumsq_ptr Receives the f64 sum of squared elements. */
NK_INTERNAL void nk_reduce_moments_f32_rvv_contiguous_( //
    nk_f32_t const *data, nk_size_t count,              //
    nk_f64_t *sum_ptr, nk_f64_t *sumsq_ptr) {
    // e32m1 and e64m2 share the same lane count, so one vlmax covers both.
    nk_size_t vlmax = __riscv_vsetvlmax_e64m2();
    vfloat64m2_t sum_f64m2 = __riscv_vfmv_v_f_f64m2(0.0, vlmax);
    vfloat64m2_t sumsq_f64m2 = __riscv_vfmv_v_f_f64m2(0.0, vlmax);
    for (nk_size_t vector_length; count > 0; count -= vector_length, data += vector_length) {
        vector_length = __riscv_vsetvl_e32m1(count);
        vfloat32m1_t data_f32m1 = __riscv_vle32_v_f32m1(data, vector_length);
        // Widening add / widening FMA accumulate in f64; `_tu` keeps tail
        // lanes undisturbed so partial final iterations don't corrupt them.
        sum_f64m2 = __riscv_vfwadd_wv_f64m2_tu(sum_f64m2, sum_f64m2, data_f32m1, vector_length);
        sumsq_f64m2 = __riscv_vfwmacc_vv_f64m2_tu(sumsq_f64m2, data_f32m1, data_f32m1, vector_length);
    }
    // Collapse both accumulators with an unordered (fast) f64 reduction.
    vfloat64m1_t zero = __riscv_vfmv_v_f_f64m1(0.0, 1);
    *sum_ptr = __riscv_vfmv_f_s_f64m1_f64(__riscv_vfredusum_vs_f64m2_f64m1(sum_f64m2, zero, vlmax)),
    *sumsq_ptr = __riscv_vfmv_f_s_f64m1_f64(__riscv_vfredusum_vs_f64m2_f64m1(sumsq_f64m2, zero, vlmax));
}
|
|
106
|
+
|
|
107
|
+
/** @brief Sum and sum-of-squares of a strided f32 array, accumulated in f64.
 *  @param data First element; successive elements are stride_bytes apart.
 *  @param count Number of elements to reduce.
 *  @param stride_bytes Byte distance between consecutive elements (f32-aligned,
 *         guaranteed by the dispatcher).
 *  @param sum_ptr Receives the f64 sum; @param sumsq_ptr receives the f64 sum of squares. */
NK_INTERNAL void nk_reduce_moments_f32_rvv_strided_(           //
    nk_f32_t const *data, nk_size_t count, nk_size_t stride_bytes, //
    nk_f64_t *sum_ptr, nk_f64_t *sumsq_ptr) {
    nk_size_t vlmax = __riscv_vsetvlmax_e64m2();
    vfloat64m2_t sum_f64m2 = __riscv_vfmv_v_f_f64m2(0.0, vlmax);
    vfloat64m2_t sumsq_f64m2 = __riscv_vfmv_v_f_f64m2(0.0, vlmax);
    // Walk the array in byte units so pointer advancement matches the stride.
    unsigned char const *ptr = (unsigned char const *)data;
    for (nk_size_t vector_length; count > 0; count -= vector_length, ptr += vector_length * stride_bytes) {
        vector_length = __riscv_vsetvl_e32m1(count);
        // Strided load: lane i reads ptr + i * stride_bytes.
        vfloat32m1_t data_f32m1 = __riscv_vlse32_v_f32m1((nk_f32_t const *)ptr, (nk_ssize_t)stride_bytes,
                                                         vector_length);
        // Widening accumulation in f64 with tail-undisturbed semantics.
        sum_f64m2 = __riscv_vfwadd_wv_f64m2_tu(sum_f64m2, sum_f64m2, data_f32m1, vector_length);
        sumsq_f64m2 = __riscv_vfwmacc_vv_f64m2_tu(sumsq_f64m2, data_f32m1, data_f32m1, vector_length);
    }
    vfloat64m1_t zero = __riscv_vfmv_v_f_f64m1(0.0, 1);
    *sum_ptr = __riscv_vfmv_f_s_f64m1_f64(__riscv_vfredusum_vs_f64m2_f64m1(sum_f64m2, zero, vlmax)),
    *sumsq_ptr = __riscv_vfmv_f_s_f64m1_f64(__riscv_vfredusum_vs_f64m2_f64m1(sumsq_f64m2, zero, vlmax));
}
|
|
125
|
+
|
|
126
|
+
/** @brief Public entry: f32 sum / sum-of-squares, dispatching on stride.
 *  Empty input yields zeros; a stride that is not a multiple of sizeof(f32)
 *  falls back to the serial kernel (vlse requires element-aligned strides);
 *  unit stride uses the contiguous fast path; anything else uses strided loads. */
NK_PUBLIC void nk_reduce_moments_f32_rvv(                      //
    nk_f32_t const *data, nk_size_t count, nk_size_t stride_bytes, //
    nk_f64_t *sum, nk_f64_t *sumsq) {
    nk_size_t stride_elements = stride_bytes / sizeof(nk_f32_t);
    int aligned = (stride_bytes % sizeof(nk_f32_t) == 0);
    if (count == 0) *sum = 0, *sumsq = 0;
    else if (!aligned) nk_reduce_moments_f32_serial(data, count, stride_bytes, sum, sumsq);
    else if (stride_elements == 1) nk_reduce_moments_f32_rvv_contiguous_(data, count, sum, sumsq);
    else nk_reduce_moments_f32_rvv_strided_(data, count, stride_bytes, sum, sumsq);
}
|
|
136
|
+
|
|
137
|
+
/** @brief Min/max (with first-occurrence indices) over a contiguous f32 array.
 *  Tracks per-lane running min/max plus the global element index where each
 *  was first seen, then reduces lanes and breaks ties by the smallest index.
 *  NOTE(review): NaN inputs never satisfy `vmflt`, so NaNs are skipped; an
 *  all-NaN input hits the sentinel branch below — presumably intended, confirm.
 *  NOTE(review): assumes NK_F32_MIN is the most-negative finite f32 (identity
 *  for max), not the smallest positive value — verify against types.h. */
NK_INTERNAL void nk_reduce_minmax_f32_rvv_contiguous_( //
    nk_f32_t const *data, nk_size_t count,             //
    nk_f32_t *min_value, nk_size_t *min_index,         //
    nk_f32_t *max_value, nk_size_t *max_index) {
    nk_size_t vlmax = __riscv_vsetvlmax_e32m1();
    vfloat32m1_t min = __riscv_vfmv_v_f_f32m1(NK_F32_MAX, vlmax);
    vfloat32m1_t max = __riscv_vfmv_v_f_f32m1(NK_F32_MIN, vlmax);
    // Indices are kept as u64 (LMUL=2 so lane counts match e32m1 data).
    vuint64m2_t min_indices = __riscv_vmv_v_x_u64m2(0, vlmax);
    vuint64m2_t max_indices = __riscv_vmv_v_x_u64m2(0, vlmax);
    nk_size_t offset = 0;
    for (nk_size_t remaining = count, vector_length; remaining > 0;
         remaining -= vector_length, offset += vector_length) {
        vector_length = __riscv_vsetvl_e32m1(remaining);
        vfloat32m1_t data_f32m1 = __riscv_vle32_v_f32m1(data + offset, vector_length);
        // Global position of each lane: vid (0..vl-1) plus the batch offset.
        vuint64m2_t position_u64m2 = __riscv_vadd_vx_u64m2(__riscv_vid_v_u64m2(vector_length), (nk_u64_t)offset,
                                                           vector_length);
        // Strict `<` keeps the FIRST occurrence: later equal values don't replace.
        vbool32_t less_b32 = __riscv_vmflt_vv_f32m1_b32(data_f32m1, min, vector_length);
        min = __riscv_vmerge_vvm_f32m1_tu(min, min, data_f32m1, less_b32, vector_length);
        min_indices = __riscv_vmerge_vvm_u64m2_tu(min_indices, min_indices, position_u64m2, less_b32, vector_length);
        vbool32_t greater_b32 = __riscv_vmflt_vv_f32m1_b32(max, data_f32m1, vector_length);
        max = __riscv_vmerge_vvm_f32m1_tu(max, max, data_f32m1, greater_b32, vector_length);
        max_indices = __riscv_vmerge_vvm_u64m2_tu(max_indices, max_indices, position_u64m2, greater_b32, vector_length);
    }
    // Reduce lane-wise min/max to scalars.
    vfloat32m1_t id_max = __riscv_vfmv_v_f_f32m1(NK_F32_MAX, 1);
    nk_f32_t mn = __riscv_vfmv_f_s_f32m1_f32(__riscv_vfredmin_vs_f32m1_f32m1(min, id_max, vlmax));
    vfloat32m1_t id_min = __riscv_vfmv_v_f_f32m1(NK_F32_MIN, 1);
    nk_f32_t mx = __riscv_vfmv_f_s_f32m1_f32(__riscv_vfredmax_vs_f32m1_f32m1(max, id_min, vlmax));
    // Nothing beat the identities => no comparable element was seen.
    if (mn == NK_F32_MAX && mx == NK_F32_MIN) {
        *min_value = NK_F32_MAX, *min_index = NK_SIZE_MAX, *max_value = NK_F32_MIN, *max_index = NK_SIZE_MAX;
        return;
    }
    // Of all lanes holding the winning value, pick the smallest stored index
    // (non-matching lanes are masked to the NK_U64_MAX sentinel first).
    vbool32_t min_match_b32 = __riscv_vmfeq_vf_f32m1_b32(min, mn, vlmax);
    vuint64m2_t sentinel = __riscv_vmv_v_x_u64m2(NK_U64_MAX, vlmax);
    vuint64m2_t min_cands = __riscv_vmerge_vvm_u64m2(sentinel, min_indices, min_match_b32, vlmax);
    vuint64m1_t id_umax = __riscv_vmv_v_x_u64m1(NK_U64_MAX, 1);
    *min_value = mn,
    *min_index = (nk_size_t)__riscv_vmv_x_s_u64m1_u64(__riscv_vredminu_vs_u64m2_u64m1(min_cands, id_umax, vlmax));
    vbool32_t max_match_b32 = __riscv_vmfeq_vf_f32m1_b32(max, mx, vlmax);
    vuint64m2_t max_cands = __riscv_vmerge_vvm_u64m2(sentinel, max_indices, max_match_b32, vlmax);
    *max_value = mx,
    *max_index = (nk_size_t)__riscv_vmv_x_s_u64m1_u64(__riscv_vredminu_vs_u64m2_u64m1(max_cands, id_umax, vlmax));
}
|
|
179
|
+
|
|
180
|
+
/** @brief Min/max (with first-occurrence indices) over a strided f32 array.
 *  Identical algorithm to the contiguous variant, but loads lanes with a
 *  byte-stride gather (`vlse32`) and advances a byte pointer per batch.
 *  NOTE(review): same NaN-skipping and NK_F32_MIN assumptions as the
 *  contiguous kernel — confirm against types.h. */
NK_INTERNAL void nk_reduce_minmax_f32_rvv_strided_(            //
    nk_f32_t const *data, nk_size_t count, nk_size_t stride_bytes, //
    nk_f32_t *min_value, nk_size_t *min_index,                 //
    nk_f32_t *max_value, nk_size_t *max_index) {
    nk_size_t vlmax = __riscv_vsetvlmax_e32m1();
    vfloat32m1_t min = __riscv_vfmv_v_f_f32m1(NK_F32_MAX, vlmax);
    vfloat32m1_t max = __riscv_vfmv_v_f_f32m1(NK_F32_MIN, vlmax);
    vuint64m2_t min_indices = __riscv_vmv_v_x_u64m2(0, vlmax);
    vuint64m2_t max_indices = __riscv_vmv_v_x_u64m2(0, vlmax);
    unsigned char const *ptr = (unsigned char const *)data;
    nk_size_t offset = 0;
    for (nk_size_t remaining = count, vector_length; remaining > 0;
         remaining -= vector_length, offset += vector_length, ptr += vector_length * stride_bytes) {
        vector_length = __riscv_vsetvl_e32m1(remaining);
        // Strided load: lane i reads ptr + i * stride_bytes.
        vfloat32m1_t data_f32m1 = __riscv_vlse32_v_f32m1((nk_f32_t const *)ptr, (nk_ssize_t)stride_bytes,
                                                         vector_length);
        vuint64m2_t position_u64m2 = __riscv_vadd_vx_u64m2(__riscv_vid_v_u64m2(vector_length), (nk_u64_t)offset,
                                                           vector_length);
        // Strict `<` preserves the first occurrence of ties.
        vbool32_t less_b32 = __riscv_vmflt_vv_f32m1_b32(data_f32m1, min, vector_length);
        min = __riscv_vmerge_vvm_f32m1_tu(min, min, data_f32m1, less_b32, vector_length);
        min_indices = __riscv_vmerge_vvm_u64m2_tu(min_indices, min_indices, position_u64m2, less_b32, vector_length);
        vbool32_t greater_b32 = __riscv_vmflt_vv_f32m1_b32(max, data_f32m1, vector_length);
        max = __riscv_vmerge_vvm_f32m1_tu(max, max, data_f32m1, greater_b32, vector_length);
        max_indices = __riscv_vmerge_vvm_u64m2_tu(max_indices, max_indices, position_u64m2, greater_b32, vector_length);
    }
    vfloat32m1_t id_max = __riscv_vfmv_v_f_f32m1(NK_F32_MAX, 1);
    nk_f32_t mn = __riscv_vfmv_f_s_f32m1_f32(__riscv_vfredmin_vs_f32m1_f32m1(min, id_max, vlmax));
    vfloat32m1_t id_min = __riscv_vfmv_v_f_f32m1(NK_F32_MIN, 1);
    nk_f32_t mx = __riscv_vfmv_f_s_f32m1_f32(__riscv_vfredmax_vs_f32m1_f32m1(max, id_min, vlmax));
    // No element beat the identities: report sentinels.
    if (mn == NK_F32_MAX && mx == NK_F32_MIN) {
        *min_value = NK_F32_MAX, *min_index = NK_SIZE_MAX, *max_value = NK_F32_MIN, *max_index = NK_SIZE_MAX;
        return;
    }
    // Tie-break across lanes by taking the minimum stored index among winners.
    vbool32_t min_match_b32 = __riscv_vmfeq_vf_f32m1_b32(min, mn, vlmax);
    vuint64m2_t sentinel = __riscv_vmv_v_x_u64m2(NK_U64_MAX, vlmax);
    vuint64m2_t min_cands = __riscv_vmerge_vvm_u64m2(sentinel, min_indices, min_match_b32, vlmax);
    vuint64m1_t id_umax = __riscv_vmv_v_x_u64m1(NK_U64_MAX, 1);
    *min_value = mn,
    *min_index = (nk_size_t)__riscv_vmv_x_s_u64m1_u64(__riscv_vredminu_vs_u64m2_u64m1(min_cands, id_umax, vlmax));
    vbool32_t max_match_b32 = __riscv_vmfeq_vf_f32m1_b32(max, mx, vlmax);
    vuint64m2_t max_cands = __riscv_vmerge_vvm_u64m2(sentinel, max_indices, max_match_b32, vlmax);
    *max_value = mx,
    *max_index = (nk_size_t)__riscv_vmv_x_s_u64m1_u64(__riscv_vredminu_vs_u64m2_u64m1(max_cands, id_umax, vlmax));
}
|
|
224
|
+
|
|
225
|
+
/** @brief Public entry: f32 min/max with indices, dispatching on stride.
 *  Empty input reports identity values with NK_SIZE_MAX indices; non-f32-aligned
 *  strides fall back to the serial kernel; unit stride takes the contiguous
 *  path; other aligned strides use the gather-based kernel. */
NK_PUBLIC void nk_reduce_minmax_f32_rvv(                       //
    nk_f32_t const *data, nk_size_t count, nk_size_t stride_bytes, //
    nk_f32_t *min_value, nk_size_t *min_index,                 //
    nk_f32_t *max_value, nk_size_t *max_index) {
    nk_size_t stride_elements = stride_bytes / sizeof(nk_f32_t);
    int aligned = (stride_bytes % sizeof(nk_f32_t) == 0);
    if (count == 0)
        *min_value = NK_F32_MAX, *min_index = NK_SIZE_MAX, *max_value = NK_F32_MIN, *max_index = NK_SIZE_MAX;
    else if (!aligned)
        nk_reduce_minmax_f32_serial(data, count, stride_bytes, min_value, min_index, max_value, max_index);
    else if (stride_elements == 1)
        nk_reduce_minmax_f32_rvv_contiguous_(data, count, min_value, min_index, max_value, max_index);
    else nk_reduce_minmax_f32_rvv_strided_(data, count, stride_bytes, min_value, min_index, max_value, max_index);
}
|
|
239
|
+
|
|
240
|
+
/** @brief Sum and sum-of-squares of a contiguous f64 array.
 *  No widening is needed (input is already f64), so plain adds/FMAs at LMUL=4
 *  replace the widening ops used by the f32 kernels. */
NK_INTERNAL void nk_reduce_moments_f64_rvv_contiguous_( //
    nk_f64_t const *data, nk_size_t count,              //
    nk_f64_t *sum_ptr, nk_f64_t *sumsq_ptr) {
    nk_size_t vlmax = __riscv_vsetvlmax_e64m4();
    vfloat64m4_t sum_f64m4 = __riscv_vfmv_v_f_f64m4(0.0, vlmax);
    vfloat64m4_t sumsq_f64m4 = __riscv_vfmv_v_f_f64m4(0.0, vlmax);
    for (nk_size_t vector_length; count > 0; count -= vector_length, data += vector_length) {
        vector_length = __riscv_vsetvl_e64m4(count);
        vfloat64m4_t data_f64m4 = __riscv_vle64_v_f64m4(data, vector_length);
        // `_tu` keeps tail lanes undisturbed on the short final batch.
        sum_f64m4 = __riscv_vfadd_vv_f64m4_tu(sum_f64m4, sum_f64m4, data_f64m4, vector_length);
        sumsq_f64m4 = __riscv_vfmacc_vv_f64m4_tu(sumsq_f64m4, data_f64m4, data_f64m4, vector_length);
    }
    vfloat64m1_t zero = __riscv_vfmv_v_f_f64m1(0.0, 1);
    *sum_ptr = __riscv_vfmv_f_s_f64m1_f64(__riscv_vfredusum_vs_f64m4_f64m1(sum_f64m4, zero, vlmax)),
    *sumsq_ptr = __riscv_vfmv_f_s_f64m1_f64(__riscv_vfredusum_vs_f64m4_f64m1(sumsq_f64m4, zero, vlmax));
}
|
|
256
|
+
|
|
257
|
+
/** @brief Sum and sum-of-squares of a strided f64 array.
 *  Same accumulation as the contiguous f64 kernel, but each batch is loaded
 *  with a byte-stride gather (`vlse64`) and the byte pointer advances by
 *  vector_length * stride_bytes per iteration. */
NK_INTERNAL void nk_reduce_moments_f64_rvv_strided_(           //
    nk_f64_t const *data, nk_size_t count, nk_size_t stride_bytes, //
    nk_f64_t *sum_ptr, nk_f64_t *sumsq_ptr) {
    nk_size_t vlmax = __riscv_vsetvlmax_e64m4();
    vfloat64m4_t sum_f64m4 = __riscv_vfmv_v_f_f64m4(0.0, vlmax);
    vfloat64m4_t sumsq_f64m4 = __riscv_vfmv_v_f_f64m4(0.0, vlmax);
    unsigned char const *ptr = (unsigned char const *)data;
    for (nk_size_t vector_length; count > 0; count -= vector_length, ptr += vector_length * stride_bytes) {
        vector_length = __riscv_vsetvl_e64m4(count);
        // Strided load: lane i reads ptr + i * stride_bytes.
        vfloat64m4_t data_f64m4 = __riscv_vlse64_v_f64m4((nk_f64_t const *)ptr, (nk_ssize_t)stride_bytes,
                                                         vector_length);
        sum_f64m4 = __riscv_vfadd_vv_f64m4_tu(sum_f64m4, sum_f64m4, data_f64m4, vector_length);
        sumsq_f64m4 = __riscv_vfmacc_vv_f64m4_tu(sumsq_f64m4, data_f64m4, data_f64m4, vector_length);
    }
    vfloat64m1_t zero = __riscv_vfmv_v_f_f64m1(0.0, 1);
    *sum_ptr = __riscv_vfmv_f_s_f64m1_f64(__riscv_vfredusum_vs_f64m4_f64m1(sum_f64m4, zero, vlmax)),
    *sumsq_ptr = __riscv_vfmv_f_s_f64m1_f64(__riscv_vfredusum_vs_f64m4_f64m1(sumsq_f64m4, zero, vlmax));
}
|
|
275
|
+
|
|
276
|
+
/// Public entry point for f64 sum/sum-of-squares: picks the kernel that fits
/// the memory layout (empty, element-misaligned stride, unit stride, strided).
NK_PUBLIC void nk_reduce_moments_f64_rvv(                          //
    nk_f64_t const *data, nk_size_t count, nk_size_t stride_bytes, //
    nk_f64_t *sum, nk_f64_t *sumsq) {
    if (count == 0) {
        *sum = 0;
        *sumsq = 0;
        return;
    }
    // Stride not a whole number of f64 elements: use the scalar fallback.
    if (stride_bytes % sizeof(nk_f64_t) != 0) {
        nk_reduce_moments_f64_serial(data, count, stride_bytes, sum, sumsq);
        return;
    }
    // Unit element stride means the data is contiguous.
    if (stride_bytes / sizeof(nk_f64_t) == 1) {
        nk_reduce_moments_f64_rvv_contiguous_(data, count, sum, sumsq);
        return;
    }
    nk_reduce_moments_f64_rvv_strided_(data, count, stride_bytes, sum, sumsq);
}
|
|
286
|
+
|
|
287
|
+
/// Finds the min and max of `count` contiguous f64 values, plus the smallest
/// index at which each extreme occurs. Strategy: keep a per-lane running
/// min/max and the last index at which each lane improved; afterwards reduce
/// those lanes to scalar extremes, then pick the smallest candidate index
/// among lanes that hold the winning value.
NK_INTERNAL void nk_reduce_minmax_f64_rvv_contiguous_( //
    nk_f64_t const *data, nk_size_t count,             //
    nk_f64_t *min_value, nk_size_t *min_index,         //
    nk_f64_t *max_value, nk_size_t *max_index) {
    nk_size_t vlmax = __riscv_vsetvlmax_e64m1();
    // Identity elements: untouched lanes stay at +-extreme and never win.
    vfloat64m1_t min = __riscv_vfmv_v_f_f64m1(NK_F64_MAX, vlmax);
    vfloat64m1_t max = __riscv_vfmv_v_f_f64m1(NK_F64_MIN, vlmax);
    vuint64m1_t min_indices = __riscv_vmv_v_x_u64m1(0, vlmax);
    vuint64m1_t max_indices = __riscv_vmv_v_x_u64m1(0, vlmax);
    nk_size_t offset = 0;
    for (nk_size_t remaining = count, vector_length; remaining > 0;
         remaining -= vector_length, offset += vector_length) {
        vector_length = __riscv_vsetvl_e64m1(remaining);
        vfloat64m1_t data_f64m1 = __riscv_vle64_v_f64m1(data + offset, vector_length);
        // Absolute element indices for this chunk: vid + running offset.
        vuint64m1_t position_u64m1 = __riscv_vadd_vx_u64m1(__riscv_vid_v_u64m1(vector_length), (nk_u64_t)offset,
                                                           vector_length);
        // Strict `<` keeps the EARLIEST index on ties (later equal values don't replace).
        vbool64_t less_b64 = __riscv_vmflt_vv_f64m1_b64(data_f64m1, min, vector_length);
        min = __riscv_vmerge_vvm_f64m1_tu(min, min, data_f64m1, less_b64, vector_length);
        min_indices = __riscv_vmerge_vvm_u64m1_tu(min_indices, min_indices, position_u64m1, less_b64, vector_length);
        vbool64_t greater_b64 = __riscv_vmflt_vv_f64m1_b64(max, data_f64m1, vector_length);
        max = __riscv_vmerge_vvm_f64m1_tu(max, max, data_f64m1, greater_b64, vector_length);
        max_indices = __riscv_vmerge_vvm_u64m1_tu(max_indices, max_indices, position_u64m1, greater_b64, vector_length);
    }
    // Scalar extremes across the per-lane accumulators.
    vfloat64m1_t id_max = __riscv_vfmv_v_f_f64m1(NK_F64_MAX, 1);
    nk_f64_t mn = __riscv_vfmv_f_s_f64m1_f64(__riscv_vfredmin_vs_f64m1_f64m1(min, id_max, vlmax));
    vfloat64m1_t id_min = __riscv_vfmv_v_f_f64m1(NK_F64_MIN, 1);
    nk_f64_t mx = __riscv_vfmv_f_s_f64m1_f64(__riscv_vfredmax_vs_f64m1_f64m1(max, id_min, vlmax));
    // If no lane ever improved, both extremes are still at their identities
    // (NOTE(review): e.g. all-NaN input never satisfies vmflt) — report sentinels.
    if (mn == NK_F64_MAX && mx == NK_F64_MIN) {
        *min_value = NK_F64_MAX, *min_index = NK_SIZE_MAX, *max_value = NK_F64_MIN, *max_index = NK_SIZE_MAX;
        return;
    }
    // Among lanes holding the winning value, take the smallest stored index;
    // non-matching lanes are masked to NK_U64_MAX so they lose the min-reduction.
    vbool64_t min_match_b64 = __riscv_vmfeq_vf_f64m1_b64(min, mn, vlmax);
    vuint64m1_t sentinel = __riscv_vmv_v_x_u64m1(NK_U64_MAX, vlmax);
    vuint64m1_t min_cands = __riscv_vmerge_vvm_u64m1(sentinel, min_indices, min_match_b64, vlmax);
    vuint64m1_t id_umax = __riscv_vmv_v_x_u64m1(NK_U64_MAX, 1);
    *min_value = mn,
    *min_index = (nk_size_t)__riscv_vmv_x_s_u64m1_u64(__riscv_vredminu_vs_u64m1_u64m1(min_cands, id_umax, vlmax));
    vbool64_t max_match_b64 = __riscv_vmfeq_vf_f64m1_b64(max, mx, vlmax);
    vuint64m1_t max_cands = __riscv_vmerge_vvm_u64m1(sentinel, max_indices, max_match_b64, vlmax);
    *max_value = mx,
    *max_index = (nk_size_t)__riscv_vmv_x_s_u64m1_u64(__riscv_vredminu_vs_u64m1_u64m1(max_cands, id_umax, vlmax));
}
|
|
329
|
+
|
|
330
|
+
/// Strided variant of the f64 min/max+argmin/argmax kernel: identical logic
/// to the contiguous path, but loads elements `stride_bytes` apart via a
/// strided load and advances a byte cursor each chunk.
NK_INTERNAL void nk_reduce_minmax_f64_rvv_strided_(                //
    nk_f64_t const *data, nk_size_t count, nk_size_t stride_bytes, //
    nk_f64_t *min_value, nk_size_t *min_index,                     //
    nk_f64_t *max_value, nk_size_t *max_index) {
    nk_size_t vlmax = __riscv_vsetvlmax_e64m1();
    // Identity elements: untouched lanes never win the final reductions.
    vfloat64m1_t min = __riscv_vfmv_v_f_f64m1(NK_F64_MAX, vlmax);
    vfloat64m1_t max = __riscv_vfmv_v_f_f64m1(NK_F64_MIN, vlmax);
    vuint64m1_t min_indices = __riscv_vmv_v_x_u64m1(0, vlmax);
    vuint64m1_t max_indices = __riscv_vmv_v_x_u64m1(0, vlmax);
    unsigned char const *ptr = (unsigned char const *)data;
    nk_size_t offset = 0;
    for (nk_size_t remaining = count, vector_length; remaining > 0;
         remaining -= vector_length, offset += vector_length, ptr += vector_length * stride_bytes) {
        vector_length = __riscv_vsetvl_e64m1(remaining);
        vfloat64m1_t data_f64m1 = __riscv_vlse64_v_f64m1((nk_f64_t const *)ptr, (nk_ssize_t)stride_bytes,
                                                         vector_length);
        // Absolute element indices for this chunk: vid + running offset.
        vuint64m1_t position_u64m1 = __riscv_vadd_vx_u64m1(__riscv_vid_v_u64m1(vector_length), (nk_u64_t)offset,
                                                           vector_length);
        // Strict `<` keeps the EARLIEST index on ties.
        vbool64_t less_b64 = __riscv_vmflt_vv_f64m1_b64(data_f64m1, min, vector_length);
        min = __riscv_vmerge_vvm_f64m1_tu(min, min, data_f64m1, less_b64, vector_length);
        min_indices = __riscv_vmerge_vvm_u64m1_tu(min_indices, min_indices, position_u64m1, less_b64, vector_length);
        vbool64_t greater_b64 = __riscv_vmflt_vv_f64m1_b64(max, data_f64m1, vector_length);
        max = __riscv_vmerge_vvm_f64m1_tu(max, max, data_f64m1, greater_b64, vector_length);
        max_indices = __riscv_vmerge_vvm_u64m1_tu(max_indices, max_indices, position_u64m1, greater_b64, vector_length);
    }
    // Scalar extremes across the per-lane accumulators.
    vfloat64m1_t id_max = __riscv_vfmv_v_f_f64m1(NK_F64_MAX, 1);
    nk_f64_t mn = __riscv_vfmv_f_s_f64m1_f64(__riscv_vfredmin_vs_f64m1_f64m1(min, id_max, vlmax));
    vfloat64m1_t id_min = __riscv_vfmv_v_f_f64m1(NK_F64_MIN, 1);
    nk_f64_t mx = __riscv_vfmv_f_s_f64m1_f64(__riscv_vfredmax_vs_f64m1_f64m1(max, id_min, vlmax));
    // No lane ever improved (NOTE(review): e.g. all-NaN input) — report sentinels.
    if (mn == NK_F64_MAX && mx == NK_F64_MIN) {
        *min_value = NK_F64_MAX, *min_index = NK_SIZE_MAX, *max_value = NK_F64_MIN, *max_index = NK_SIZE_MAX;
        return;
    }
    // Smallest index among lanes that hold the winning value; losers get NK_U64_MAX.
    vbool64_t min_match_b64 = __riscv_vmfeq_vf_f64m1_b64(min, mn, vlmax);
    vuint64m1_t sentinel = __riscv_vmv_v_x_u64m1(NK_U64_MAX, vlmax);
    vuint64m1_t min_cands = __riscv_vmerge_vvm_u64m1(sentinel, min_indices, min_match_b64, vlmax);
    vuint64m1_t id_umax = __riscv_vmv_v_x_u64m1(NK_U64_MAX, 1);
    *min_value = mn,
    *min_index = (nk_size_t)__riscv_vmv_x_s_u64m1_u64(__riscv_vredminu_vs_u64m1_u64m1(min_cands, id_umax, vlmax));
    vbool64_t max_match_b64 = __riscv_vmfeq_vf_f64m1_b64(max, mx, vlmax);
    vuint64m1_t max_cands = __riscv_vmerge_vvm_u64m1(sentinel, max_indices, max_match_b64, vlmax);
    *max_value = mx,
    *max_index = (nk_size_t)__riscv_vmv_x_s_u64m1_u64(__riscv_vredminu_vs_u64m1_u64m1(max_cands, id_umax, vlmax));
}
|
|
374
|
+
|
|
375
|
+
NK_PUBLIC void nk_reduce_minmax_f64_rvv( //
|
|
376
|
+
nk_f64_t const *data, nk_size_t count, nk_size_t stride_bytes, //
|
|
377
|
+
nk_f64_t *min_value, nk_size_t *min_index, //
|
|
378
|
+
nk_f64_t *max_value, nk_size_t *max_index) {
|
|
379
|
+
nk_size_t stride_elements = stride_bytes / sizeof(nk_f64_t);
|
|
380
|
+
int aligned = (stride_bytes % sizeof(nk_f64_t) == 0);
|
|
381
|
+
if (count == 0)
|
|
382
|
+
*min_value = NK_F64_MAX, *min_index = NK_SIZE_MAX, *max_value = NK_F64_MIN, *max_index = NK_SIZE_MAX;
|
|
383
|
+
else if (!aligned)
|
|
384
|
+
nk_reduce_minmax_f64_serial(data, count, stride_bytes, min_value, min_index, max_value, max_index);
|
|
385
|
+
else if (stride_elements == 1)
|
|
386
|
+
nk_reduce_minmax_f64_rvv_contiguous_(data, count, min_value, min_index, max_value, max_index);
|
|
387
|
+
else nk_reduce_minmax_f64_rvv_strided_(data, count, stride_bytes, min_value, min_index, max_value, max_index);
|
|
388
|
+
}
|
|
389
|
+
|
|
390
|
+
NK_INTERNAL vuint8m1_t nk_fp8m1_to_comparable_u8m1_rvv_(vuint8m1_t raw_u8m1, nk_size_t vector_length) {
    // Map FP8 (e4m3/e5m2) bit patterns to a totally-ordered unsigned key:
    //   sign clear (positive): flip just the sign bit -> [0x80, 0xFF]
    //   sign set (negative):   flip all bits          -> [0x00, 0x7F]
    // Sign lives in bit 7, so "sign set" is equivalent to an unsigned >= 0x80 test.
    vbool8_t sign_set_b8 = __riscv_vmsgeu_vx_u8m1_b8(raw_u8m1, 0x80, vector_length);
    vuint8m1_t positive_key_u8m1 = __riscv_vxor_vx_u8m1(raw_u8m1, 0x80, vector_length);
    vuint8m1_t negative_key_u8m1 = __riscv_vnot_v_u8m1(raw_u8m1, vector_length);
    return __riscv_vmerge_vvm_u8m1(positive_key_u8m1, negative_key_u8m1, sign_set_b8, vector_length);
}
|
|
400
|
+
|
|
401
|
+
NK_INTERNAL vuint8m1_t nk_comparable_to_fp8m1_rvv_(vuint8m1_t comparable_u8m1, nk_size_t vector_length) {
    // Undo nk_fp8m1_to_comparable_u8m1_rvv_: keys < 0x80 came from negatives
    // (restore via bitwise NOT), keys >= 0x80 came from positives (restore
    // via XOR of the sign bit).
    vbool8_t was_negative_b8 = __riscv_vmsltu_vx_u8m1_b8(comparable_u8m1, 0x80, vector_length);
    vuint8m1_t restored_positive_u8m1 = __riscv_vxor_vx_u8m1(comparable_u8m1, 0x80, vector_length);
    vuint8m1_t restored_negative_u8m1 = __riscv_vnot_v_u8m1(comparable_u8m1, vector_length);
    return __riscv_vmerge_vvm_u8m1(restored_positive_u8m1, restored_negative_u8m1, was_negative_b8, vector_length);
}
|
|
408
|
+
|
|
409
|
+
NK_INTERNAL vuint8m1_t nk_fp6m1_to_comparable_u8m1_rvv_(vuint8m1_t raw_u8m1, nk_size_t vector_length) {
    // Map FP6 (e2m3/e3m2) bit patterns (sign in bit 5) to an ordered key:
    //   sign clear (positive): XOR 0x20                      -> [0x20, 0x3F]
    //   sign set (negative):   XOR 0x3F (NOT of low 6 bits)  -> [0x00, 0x1F]
    vuint8m1_t sign_bit_u8m1 = __riscv_vand_vx_u8m1(raw_u8m1, 0x20, vector_length);
    vbool8_t sign_set_b8 = __riscv_vmseq_vx_u8m1_b8(sign_bit_u8m1, 0x20, vector_length);
    vuint8m1_t positive_key_u8m1 = __riscv_vxor_vx_u8m1(raw_u8m1, 0x20, vector_length);
    vuint8m1_t negative_key_u8m1 = __riscv_vxor_vx_u8m1(raw_u8m1, 0x3F, vector_length);
    return __riscv_vmerge_vvm_u8m1(positive_key_u8m1, negative_key_u8m1, sign_set_b8, vector_length);
}
|
|
419
|
+
|
|
420
|
+
NK_INTERNAL vuint8m1_t nk_comparable_to_fp6m1_rvv_(vuint8m1_t comparable_u8m1, nk_size_t vector_length) {
    // Undo nk_fp6m1_to_comparable_u8m1_rvv_: keys < 0x20 came from negatives
    // (XOR 0x3F restores the low 6 bits), keys >= 0x20 came from positives
    // (XOR 0x20 restores the sign bit).
    vbool8_t was_negative_b8 = __riscv_vmsltu_vx_u8m1_b8(comparable_u8m1, 0x20, vector_length);
    vuint8m1_t restored_positive_u8m1 = __riscv_vxor_vx_u8m1(comparable_u8m1, 0x20, vector_length);
    vuint8m1_t restored_negative_u8m1 = __riscv_vxor_vx_u8m1(comparable_u8m1, 0x3F, vector_length);
    return __riscv_vmerge_vvm_u8m1(restored_positive_u8m1, restored_negative_u8m1, was_negative_b8, vector_length);
}
|
|
427
|
+
|
|
428
|
+
/// Accumulates the i64 sum and u64 sum-of-squares of `count` contiguous i8
/// values. Trick: the tail-undisturbed load over a zero destination leaves
/// tail lanes at 0, so the widening chain can run over all vlmax_elements
/// lanes unconditionally — tail lanes contribute only zeros.
NK_INTERNAL void nk_reduce_moments_i8_rvv_contiguous_( //
    nk_i8_t const *data_ptr, nk_size_t count,          //
    nk_i64_t *sum_ptr, nk_u64_t *sumsq_ptr) {
    nk_size_t vlmax = __riscv_vsetvlmax_e64m4();
    nk_size_t vlmax_elements = __riscv_vsetvlmax_e8m1();
    vint64m4_t sum_i64m4 = __riscv_vmv_v_x_i64m4(0, vlmax);
    vuint64m4_t sumsq_u64m4 = __riscv_vmv_v_x_u64m4(0, vlmax);
    vint8m1_t zero_i8m1 = __riscv_vmv_v_x_i8m1(0, vlmax_elements);

    for (nk_size_t vector_length; count > 0; count -= vector_length, data_ptr += vector_length) {
        vector_length = __riscv_vsetvl_e8m1(count);
        // `_tu` with a zero destination: lanes past vector_length are 0.
        vint8m1_t data_i8m1 = __riscv_vle8_v_i8m1_tu(zero_i8m1, data_ptr, vector_length);

        // Widen i8 → i16 → i32 → i64 for sum
        vint16m2_t data_i16m2 = __riscv_vsext_vf2_i16m2(data_i8m1, vlmax_elements);
        vint32m4_t data_i32m4 = __riscv_vsext_vf2_i32m4(data_i16m2, vlmax_elements);
        vint64m8_t data_i64m8 = __riscv_vsext_vf2_i64m8(data_i32m4, vlmax_elements);

        // Accumulate sum (split m8 into two m4)
        sum_i64m4 = __riscv_vadd_vv_i64m4(sum_i64m4, __riscv_vget_v_i64m8_i64m4(data_i64m8, 0), vlmax);
        sum_i64m4 = __riscv_vadd_vv_i64m4(sum_i64m4, __riscv_vget_v_i64m8_i64m4(data_i64m8, 1), vlmax);

        // Sumsq: i8 × i8 → i16 (widening multiply)
        vint16m2_t squares_i16m2 = __riscv_vwmul_vv_i16m2(data_i8m1, data_i8m1, vlmax_elements);
        // Widen i16 → u32 → u64; squares are non-negative (max 128² = 16384
        // fits in i16), so reinterpreting as u16 and zero-extending is exact.
        vuint32m4_t squares_u32m4 = __riscv_vwcvtu_x_x_v_u32m4(__riscv_vreinterpret_v_i16m2_u16m2(squares_i16m2),
                                                               vlmax_elements);
        vuint64m8_t squares_u64m8 = __riscv_vwcvtu_x_x_v_u64m8(squares_u32m4, vlmax_elements);

        sumsq_u64m4 = __riscv_vadd_vv_u64m4(sumsq_u64m4, __riscv_vget_v_u64m8_u64m4(squares_u64m8, 0), vlmax);
        sumsq_u64m4 = __riscv_vadd_vv_u64m4(sumsq_u64m4, __riscv_vget_v_u64m8_u64m4(squares_u64m8, 1), vlmax);
    }

    // Horizontal reduction
    vint64m1_t zero_i64m1 = __riscv_vmv_v_x_i64m1(0, 1);
    *sum_ptr = __riscv_vmv_x_s_i64m1_i64(__riscv_vredsum_vs_i64m4_i64m1(sum_i64m4, zero_i64m1, vlmax));

    vuint64m1_t zero_u64m1 = __riscv_vmv_v_x_u64m1(0, 1);
    *sumsq_ptr = __riscv_vmv_x_s_u64m1_u64(__riscv_vredsum_vs_u64m4_u64m1(sumsq_u64m4, zero_u64m1, vlmax));
}
|
|
468
|
+
|
|
469
|
+
/// Strided variant of the i8 moments kernel: identical widening/accumulation
/// scheme to the contiguous path, but gathers bytes `stride_bytes` apart via
/// a strided tail-undisturbed load (tail lanes remain 0).
NK_INTERNAL void nk_reduce_moments_i8_rvv_strided_(                   //
    nk_i8_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
    nk_i64_t *sum_ptr, nk_u64_t *sumsq_ptr) {
    nk_size_t vlmax = __riscv_vsetvlmax_e64m4();
    nk_size_t vlmax_elements = __riscv_vsetvlmax_e8m1();
    vint64m4_t sum_i64m4 = __riscv_vmv_v_x_i64m4(0, vlmax);
    vuint64m4_t sumsq_u64m4 = __riscv_vmv_v_x_u64m4(0, vlmax);
    vint8m1_t zero_i8m1 = __riscv_vmv_v_x_i8m1(0, vlmax_elements);
    unsigned char const *ptr = (unsigned char const *)data_ptr;

    for (nk_size_t vector_length; count > 0; count -= vector_length, ptr += vector_length * stride_bytes) {
        vector_length = __riscv_vsetvl_e8m1(count);
        // `_tu` with a zero destination: lanes past vector_length are 0.
        vint8m1_t data_i8m1 = __riscv_vlse8_v_i8m1_tu(zero_i8m1, (nk_i8_t const *)ptr, (nk_ssize_t)stride_bytes,
                                                      vector_length);

        // Widen i8 → i16 → i32 → i64 for sum
        vint16m2_t data_i16m2 = __riscv_vsext_vf2_i16m2(data_i8m1, vlmax_elements);
        vint32m4_t data_i32m4 = __riscv_vsext_vf2_i32m4(data_i16m2, vlmax_elements);
        vint64m8_t data_i64m8 = __riscv_vsext_vf2_i64m8(data_i32m4, vlmax_elements);

        // Accumulate sum (split m8 into two m4)
        sum_i64m4 = __riscv_vadd_vv_i64m4(sum_i64m4, __riscv_vget_v_i64m8_i64m4(data_i64m8, 0), vlmax);
        sum_i64m4 = __riscv_vadd_vv_i64m4(sum_i64m4, __riscv_vget_v_i64m8_i64m4(data_i64m8, 1), vlmax);

        // Sumsq: i8 × i8 → i16 (widening multiply)
        vint16m2_t squares_i16m2 = __riscv_vwmul_vv_i16m2(data_i8m1, data_i8m1, vlmax_elements);
        // Widen i16 → u32 → u64; squares are non-negative, so the u16
        // reinterpret + zero-extend is exact.
        vuint32m4_t squares_u32m4 = __riscv_vwcvtu_x_x_v_u32m4(__riscv_vreinterpret_v_i16m2_u16m2(squares_i16m2),
                                                               vlmax_elements);
        vuint64m8_t squares_u64m8 = __riscv_vwcvtu_x_x_v_u64m8(squares_u32m4, vlmax_elements);

        sumsq_u64m4 = __riscv_vadd_vv_u64m4(sumsq_u64m4, __riscv_vget_v_u64m8_u64m4(squares_u64m8, 0), vlmax);
        sumsq_u64m4 = __riscv_vadd_vv_u64m4(sumsq_u64m4, __riscv_vget_v_u64m8_u64m4(squares_u64m8, 1), vlmax);
    }

    // Horizontal reduction
    vint64m1_t zero_i64m1 = __riscv_vmv_v_x_i64m1(0, 1);
    *sum_ptr = __riscv_vmv_x_s_i64m1_i64(__riscv_vredsum_vs_i64m4_i64m1(sum_i64m4, zero_i64m1, vlmax));

    vuint64m1_t zero_u64m1 = __riscv_vmv_v_x_u64m1(0, 1);
    *sumsq_ptr = __riscv_vmv_x_s_u64m1_u64(__riscv_vredsum_vs_u64m4_u64m1(sumsq_u64m4, zero_u64m1, vlmax));
}
|
|
511
|
+
|
|
512
|
+
NK_PUBLIC void nk_reduce_moments_i8_rvv( //
|
|
513
|
+
nk_i8_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
|
|
514
|
+
nk_i64_t *sum_ptr, nk_u64_t *sumsq_ptr) {
|
|
515
|
+
nk_size_t stride_elements = stride_bytes / sizeof(nk_i8_t);
|
|
516
|
+
int aligned = (stride_bytes % sizeof(nk_i8_t) == 0);
|
|
517
|
+
|
|
518
|
+
if (count == 0) *sum_ptr = 0, *sumsq_ptr = 0;
|
|
519
|
+
else if (!aligned) { nk_reduce_moments_i8_serial(data_ptr, count, stride_bytes, sum_ptr, sumsq_ptr); }
|
|
520
|
+
else if (stride_elements == 1) { nk_reduce_moments_i8_rvv_contiguous_(data_ptr, count, sum_ptr, sumsq_ptr); }
|
|
521
|
+
else { nk_reduce_moments_i8_rvv_strided_(data_ptr, count, stride_bytes, sum_ptr, sumsq_ptr); }
|
|
522
|
+
}
|
|
523
|
+
|
|
524
|
+
/// Finds min/max of `count` contiguous i8 values plus the smallest index of
/// each extreme. Per-lane running extremes are tracked in i8m1; indices in
/// u64m8 (8x wider type at 8x LMUL keeps one index lane per data lane).
NK_INTERNAL void nk_reduce_minmax_i8_rvv_contiguous_( //
    nk_i8_t const *data_ptr, nk_size_t count,         //
    nk_i8_t *min_value_ptr, nk_size_t *min_index_ptr, //
    nk_i8_t *max_value_ptr, nk_size_t *max_index_ptr) {
    nk_size_t vlmax = __riscv_vsetvlmax_e8m1();
    // Identity elements so untouched lanes never win the final reductions.
    vint8m1_t min_i8m1 = __riscv_vmv_v_x_i8m1(NK_I8_MAX, vlmax);
    vint8m1_t max_i8m1 = __riscv_vmv_v_x_i8m1(NK_I8_MIN, vlmax);
    vuint64m8_t min_indices_u64m8 = __riscv_vmv_v_x_u64m8(0, vlmax);
    vuint64m8_t max_indices_u64m8 = __riscv_vmv_v_x_u64m8(0, vlmax);

    nk_size_t offset = 0;
    for (nk_size_t vector_length; count > 0;
         count -= vector_length, offset += vector_length, data_ptr += vector_length) {
        vector_length = __riscv_vsetvl_e8m1(count);
        vint8m1_t data_i8m1 = __riscv_vle8_v_i8m1(data_ptr, vector_length);

        // VID-based absolute indices
        vuint64m8_t pos_u64m8 = __riscv_vadd_vx_u64m8(__riscv_vid_v_u64m8(vector_length), (nk_u64_t)offset,
                                                      vector_length);

        // Strict `<` keeps the EARLIEST index on ties; `_tu` preserves tail lanes.
        vbool8_t less_b8 = __riscv_vmslt_vv_i8m1_b8(data_i8m1, min_i8m1, vector_length);
        min_i8m1 = __riscv_vmerge_vvm_i8m1_tu(min_i8m1, min_i8m1, data_i8m1, less_b8, vector_length);
        min_indices_u64m8 = __riscv_vmerge_vvm_u64m8_tu(min_indices_u64m8, min_indices_u64m8, pos_u64m8, less_b8,
                                                        vector_length);

        vbool8_t greater_b8 = __riscv_vmslt_vv_i8m1_b8(max_i8m1, data_i8m1, vector_length);
        max_i8m1 = __riscv_vmerge_vvm_i8m1_tu(max_i8m1, max_i8m1, data_i8m1, greater_b8, vector_length);
        max_indices_u64m8 = __riscv_vmerge_vvm_u64m8_tu(max_indices_u64m8, max_indices_u64m8, pos_u64m8, greater_b8,
                                                        vector_length);
    }

    // Horizontal reduction for min: scalar min, then the smallest stored index
    // among lanes holding it (non-matching lanes masked to NK_U64_MAX).
    vint8m1_t init_max_i8m1 = __riscv_vmv_v_x_i8m1(NK_I8_MAX, 1);
    nk_i8_t min_val = __riscv_vmv_x_s_i8m1_i8(__riscv_vredmin_vs_i8m1_i8m1(min_i8m1, init_max_i8m1, vlmax));
    vbool8_t min_match_b8 = __riscv_vmseq_vx_i8m1_b8(min_i8m1, min_val, vlmax);
    vuint64m8_t sentinel_u64m8 = __riscv_vmv_v_x_u64m8(NK_U64_MAX, vlmax);
    vuint64m8_t min_cands_u64m8 = __riscv_vmerge_vvm_u64m8(sentinel_u64m8, min_indices_u64m8, min_match_b8, vlmax);
    vuint64m1_t init_umax_u64m1 = __riscv_vmv_v_x_u64m1(NK_U64_MAX, 1);
    *min_value_ptr = min_val;
    *min_index_ptr = (nk_size_t)__riscv_vmv_x_s_u64m1_u64(
        __riscv_vredminu_vs_u64m8_u64m1(min_cands_u64m8, init_umax_u64m1, vlmax));

    // Horizontal reduction for max (same candidate-index scheme).
    vint8m1_t init_min_i8m1 = __riscv_vmv_v_x_i8m1(NK_I8_MIN, 1);
    nk_i8_t max_val = __riscv_vmv_x_s_i8m1_i8(__riscv_vredmax_vs_i8m1_i8m1(max_i8m1, init_min_i8m1, vlmax));
    vbool8_t max_match_b8 = __riscv_vmseq_vx_i8m1_b8(max_i8m1, max_val, vlmax);
    vuint64m8_t max_cands_u64m8 = __riscv_vmerge_vvm_u64m8(sentinel_u64m8, max_indices_u64m8, max_match_b8, vlmax);
    *max_value_ptr = max_val;
    *max_index_ptr = (nk_size_t)__riscv_vmv_x_s_u64m1_u64(
        __riscv_vredminu_vs_u64m8_u64m1(max_cands_u64m8, init_umax_u64m1, vlmax));
}
|
|
575
|
+
|
|
576
|
+
/// Strided variant of the i8 min/max+index kernel: identical logic to the
/// contiguous path, but loads elements `stride_bytes` apart and advances a
/// byte cursor each chunk.
NK_INTERNAL void nk_reduce_minmax_i8_rvv_strided_(                    //
    nk_i8_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
    nk_i8_t *min_value_ptr, nk_size_t *min_index_ptr,                 //
    nk_i8_t *max_value_ptr, nk_size_t *max_index_ptr) {
    nk_size_t vlmax = __riscv_vsetvlmax_e8m1();
    // Identity elements so untouched lanes never win the final reductions.
    vint8m1_t min_i8m1 = __riscv_vmv_v_x_i8m1(NK_I8_MAX, vlmax);
    vint8m1_t max_i8m1 = __riscv_vmv_v_x_i8m1(NK_I8_MIN, vlmax);
    vuint64m8_t min_indices_u64m8 = __riscv_vmv_v_x_u64m8(0, vlmax);
    vuint64m8_t max_indices_u64m8 = __riscv_vmv_v_x_u64m8(0, vlmax);
    unsigned char const *ptr = (unsigned char const *)data_ptr;

    nk_size_t offset = 0;
    for (nk_size_t vector_length; count > 0;
         count -= vector_length, offset += vector_length, ptr += vector_length * stride_bytes) {
        vector_length = __riscv_vsetvl_e8m1(count);
        vint8m1_t data_i8m1 = __riscv_vlse8_v_i8m1((nk_i8_t const *)ptr, (nk_ssize_t)stride_bytes, vector_length);

        // VID-based absolute indices
        vuint64m8_t pos_u64m8 = __riscv_vadd_vx_u64m8(__riscv_vid_v_u64m8(vector_length), (nk_u64_t)offset,
                                                      vector_length);

        // Strict `<` keeps the EARLIEST index on ties; `_tu` preserves tail lanes.
        vbool8_t less_b8 = __riscv_vmslt_vv_i8m1_b8(data_i8m1, min_i8m1, vector_length);
        min_i8m1 = __riscv_vmerge_vvm_i8m1_tu(min_i8m1, min_i8m1, data_i8m1, less_b8, vector_length);
        min_indices_u64m8 = __riscv_vmerge_vvm_u64m8_tu(min_indices_u64m8, min_indices_u64m8, pos_u64m8, less_b8,
                                                        vector_length);

        vbool8_t greater_b8 = __riscv_vmslt_vv_i8m1_b8(max_i8m1, data_i8m1, vector_length);
        max_i8m1 = __riscv_vmerge_vvm_i8m1_tu(max_i8m1, max_i8m1, data_i8m1, greater_b8, vector_length);
        max_indices_u64m8 = __riscv_vmerge_vvm_u64m8_tu(max_indices_u64m8, max_indices_u64m8, pos_u64m8, greater_b8,
                                                        vector_length);
    }

    // Horizontal reduction for min: scalar min, then smallest matching index.
    vint8m1_t init_max_i8m1 = __riscv_vmv_v_x_i8m1(NK_I8_MAX, 1);
    nk_i8_t min_val = __riscv_vmv_x_s_i8m1_i8(__riscv_vredmin_vs_i8m1_i8m1(min_i8m1, init_max_i8m1, vlmax));
    vbool8_t min_match_b8 = __riscv_vmseq_vx_i8m1_b8(min_i8m1, min_val, vlmax);
    vuint64m8_t sentinel_u64m8 = __riscv_vmv_v_x_u64m8(NK_U64_MAX, vlmax);
    vuint64m8_t min_cands_u64m8 = __riscv_vmerge_vvm_u64m8(sentinel_u64m8, min_indices_u64m8, min_match_b8, vlmax);
    vuint64m1_t init_umax_u64m1 = __riscv_vmv_v_x_u64m1(NK_U64_MAX, 1);
    *min_value_ptr = min_val;
    *min_index_ptr = (nk_size_t)__riscv_vmv_x_s_u64m1_u64(
        __riscv_vredminu_vs_u64m8_u64m1(min_cands_u64m8, init_umax_u64m1, vlmax));

    // Horizontal reduction for max (same candidate-index scheme).
    vint8m1_t init_min_i8m1 = __riscv_vmv_v_x_i8m1(NK_I8_MIN, 1);
    nk_i8_t max_val = __riscv_vmv_x_s_i8m1_i8(__riscv_vredmax_vs_i8m1_i8m1(max_i8m1, init_min_i8m1, vlmax));
    vbool8_t max_match_b8 = __riscv_vmseq_vx_i8m1_b8(max_i8m1, max_val, vlmax);
    vuint64m8_t max_cands_u64m8 = __riscv_vmerge_vvm_u64m8(sentinel_u64m8, max_indices_u64m8, max_match_b8, vlmax);
    *max_value_ptr = max_val;
    *max_index_ptr = (nk_size_t)__riscv_vmv_x_s_u64m1_u64(
        __riscv_vredminu_vs_u64m8_u64m1(max_cands_u64m8, init_umax_u64m1, vlmax));
}
|
|
628
|
+
|
|
629
|
+
NK_PUBLIC void nk_reduce_minmax_i8_rvv( //
|
|
630
|
+
nk_i8_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
|
|
631
|
+
nk_i8_t *min_value_ptr, nk_size_t *min_index_ptr, //
|
|
632
|
+
nk_i8_t *max_value_ptr, nk_size_t *max_index_ptr) {
|
|
633
|
+
nk_size_t stride_elements = stride_bytes / sizeof(nk_i8_t);
|
|
634
|
+
int aligned = (stride_bytes % sizeof(nk_i8_t) == 0);
|
|
635
|
+
|
|
636
|
+
if (count == 0)
|
|
637
|
+
*min_value_ptr = NK_I8_MAX, *min_index_ptr = NK_SIZE_MAX, *max_value_ptr = NK_I8_MIN,
|
|
638
|
+
*max_index_ptr = NK_SIZE_MAX;
|
|
639
|
+
else if (!aligned)
|
|
640
|
+
nk_reduce_minmax_i8_serial(data_ptr, count, stride_bytes, min_value_ptr, min_index_ptr, max_value_ptr,
|
|
641
|
+
max_index_ptr);
|
|
642
|
+
else if (stride_elements == 1)
|
|
643
|
+
nk_reduce_minmax_i8_rvv_contiguous_(data_ptr, count, min_value_ptr, min_index_ptr, max_value_ptr,
|
|
644
|
+
max_index_ptr);
|
|
645
|
+
else
|
|
646
|
+
nk_reduce_minmax_i8_rvv_strided_(data_ptr, count, stride_bytes, min_value_ptr, min_index_ptr, max_value_ptr,
|
|
647
|
+
max_index_ptr);
|
|
648
|
+
}
|
|
649
|
+
|
|
650
|
+
/// Accumulates the u64 sum and sum-of-squares of `count` contiguous u8
/// values. Same tail-zero trick as the i8 kernel: the tail-undisturbed load
/// over a zero destination leaves tail lanes at 0, so the zero-extension
/// chain can run over all vlmax_elements lanes unconditionally.
NK_INTERNAL void nk_reduce_moments_u8_rvv_contiguous_( //
    nk_u8_t const *data_ptr, nk_size_t count,          //
    nk_u64_t *sum_ptr, nk_u64_t *sumsq_ptr) {
    nk_size_t vlmax = __riscv_vsetvlmax_e64m4();
    nk_size_t vlmax_elements = __riscv_vsetvlmax_e8m1();
    vuint64m4_t sum_u64m4 = __riscv_vmv_v_x_u64m4(0, vlmax);
    vuint64m4_t sumsq_u64m4 = __riscv_vmv_v_x_u64m4(0, vlmax);
    vuint8m1_t zero_u8m1 = __riscv_vmv_v_x_u8m1(0, vlmax_elements);

    for (nk_size_t vector_length; count > 0; count -= vector_length, data_ptr += vector_length) {
        vector_length = __riscv_vsetvl_e8m1(count);
        // `_tu` with a zero destination: lanes past vector_length are 0.
        vuint8m1_t data_u8m1 = __riscv_vle8_v_u8m1_tu(zero_u8m1, data_ptr, vector_length);

        // Widen u8 → u16 → u32 → u64 for sum
        vuint16m2_t data_u16m2 = __riscv_vzext_vf2_u16m2(data_u8m1, vlmax_elements);
        vuint32m4_t data_u32m4 = __riscv_vzext_vf2_u32m4(data_u16m2, vlmax_elements);
        vuint64m8_t data_u64m8 = __riscv_vzext_vf2_u64m8(data_u32m4, vlmax_elements);

        // Accumulate sum (split m8 into two m4)
        sum_u64m4 = __riscv_vadd_vv_u64m4(sum_u64m4, __riscv_vget_v_u64m8_u64m4(data_u64m8, 0), vlmax);
        sum_u64m4 = __riscv_vadd_vv_u64m4(sum_u64m4, __riscv_vget_v_u64m8_u64m4(data_u64m8, 1), vlmax);

        // Sumsq: u8 × u8 → u16 (widening multiply)
        vuint16m2_t squares_u16m2 = __riscv_vwmulu_vv_u16m2(data_u8m1, data_u8m1, vlmax_elements);
        // Widen u16 → u32 → u64
        vuint32m4_t squares_u32m4 = __riscv_vzext_vf2_u32m4(squares_u16m2, vlmax_elements);
        vuint64m8_t squares_u64m8 = __riscv_vzext_vf2_u64m8(squares_u32m4, vlmax_elements);

        sumsq_u64m4 = __riscv_vadd_vv_u64m4(sumsq_u64m4, __riscv_vget_v_u64m8_u64m4(squares_u64m8, 0), vlmax);
        sumsq_u64m4 = __riscv_vadd_vv_u64m4(sumsq_u64m4, __riscv_vget_v_u64m8_u64m4(squares_u64m8, 1), vlmax);
    }

    // Horizontal reduction
    vuint64m1_t zero_u64m1 = __riscv_vmv_v_x_u64m1(0, 1);
    *sum_ptr = __riscv_vmv_x_s_u64m1_u64(__riscv_vredsum_vs_u64m4_u64m1(sum_u64m4, zero_u64m1, vlmax)),
    *sumsq_ptr = __riscv_vmv_x_s_u64m1_u64(__riscv_vredsum_vs_u64m4_u64m1(sumsq_u64m4, zero_u64m1, vlmax));
}
|
|
687
|
+
|
|
688
|
+
/// Strided variant of the u8 moments kernel: identical widening/accumulation
/// scheme to the contiguous path, but gathers bytes `stride_bytes` apart via
/// a strided tail-undisturbed load (tail lanes remain 0).
NK_INTERNAL void nk_reduce_moments_u8_rvv_strided_(                   //
    nk_u8_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
    nk_u64_t *sum_ptr, nk_u64_t *sumsq_ptr) {
    nk_size_t vlmax = __riscv_vsetvlmax_e64m4();
    nk_size_t vlmax_elements = __riscv_vsetvlmax_e8m1();
    vuint64m4_t sum_u64m4 = __riscv_vmv_v_x_u64m4(0, vlmax);
    vuint64m4_t sumsq_u64m4 = __riscv_vmv_v_x_u64m4(0, vlmax);
    vuint8m1_t zero_u8m1 = __riscv_vmv_v_x_u8m1(0, vlmax_elements);
    unsigned char const *ptr = (unsigned char const *)data_ptr;

    for (nk_size_t vector_length; count > 0; count -= vector_length, ptr += vector_length * stride_bytes) {
        vector_length = __riscv_vsetvl_e8m1(count);
        // `_tu` with a zero destination: lanes past vector_length are 0.
        vuint8m1_t data_u8m1 = __riscv_vlse8_v_u8m1_tu(zero_u8m1, (nk_u8_t const *)ptr, (nk_ssize_t)stride_bytes,
                                                       vector_length);

        // Widen u8 → u16 → u32 → u64 for sum
        vuint16m2_t data_u16m2 = __riscv_vzext_vf2_u16m2(data_u8m1, vlmax_elements);
        vuint32m4_t data_u32m4 = __riscv_vzext_vf2_u32m4(data_u16m2, vlmax_elements);
        vuint64m8_t data_u64m8 = __riscv_vzext_vf2_u64m8(data_u32m4, vlmax_elements);

        // Accumulate sum (split m8 into two m4)
        sum_u64m4 = __riscv_vadd_vv_u64m4(sum_u64m4, __riscv_vget_v_u64m8_u64m4(data_u64m8, 0), vlmax);
        sum_u64m4 = __riscv_vadd_vv_u64m4(sum_u64m4, __riscv_vget_v_u64m8_u64m4(data_u64m8, 1), vlmax);

        // Sumsq: u8 × u8 → u16 (widening multiply)
        vuint16m2_t squares_u16m2 = __riscv_vwmulu_vv_u16m2(data_u8m1, data_u8m1, vlmax_elements);
        // Widen u16 → u32 → u64
        vuint32m4_t squares_u32m4 = __riscv_vzext_vf2_u32m4(squares_u16m2, vlmax_elements);
        vuint64m8_t squares_u64m8 = __riscv_vzext_vf2_u64m8(squares_u32m4, vlmax_elements);

        sumsq_u64m4 = __riscv_vadd_vv_u64m4(sumsq_u64m4, __riscv_vget_v_u64m8_u64m4(squares_u64m8, 0), vlmax);
        sumsq_u64m4 = __riscv_vadd_vv_u64m4(sumsq_u64m4, __riscv_vget_v_u64m8_u64m4(squares_u64m8, 1), vlmax);
    }

    // Horizontal reduction
    vuint64m1_t zero_u64m1 = __riscv_vmv_v_x_u64m1(0, 1);
    *sum_ptr = __riscv_vmv_x_s_u64m1_u64(__riscv_vredsum_vs_u64m4_u64m1(sum_u64m4, zero_u64m1, vlmax)),
    *sumsq_ptr = __riscv_vmv_x_s_u64m1_u64(__riscv_vredsum_vs_u64m4_u64m1(sumsq_u64m4, zero_u64m1, vlmax));
}
|
|
727
|
+
|
|
728
|
+
NK_PUBLIC void nk_reduce_moments_u8_rvv( //
|
|
729
|
+
nk_u8_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
|
|
730
|
+
nk_u64_t *sum_ptr, nk_u64_t *sumsq_ptr) {
|
|
731
|
+
nk_size_t stride_elements = stride_bytes / sizeof(nk_u8_t);
|
|
732
|
+
int aligned = (stride_bytes % sizeof(nk_u8_t) == 0);
|
|
733
|
+
|
|
734
|
+
if (count == 0) *sum_ptr = 0, *sumsq_ptr = 0;
|
|
735
|
+
else if (!aligned) { nk_reduce_moments_u8_serial(data_ptr, count, stride_bytes, sum_ptr, sumsq_ptr); }
|
|
736
|
+
else if (stride_elements == 1) { nk_reduce_moments_u8_rvv_contiguous_(data_ptr, count, sum_ptr, sumsq_ptr); }
|
|
737
|
+
else { nk_reduce_moments_u8_rvv_strided_(data_ptr, count, stride_bytes, sum_ptr, sumsq_ptr); }
|
|
738
|
+
}
|
|
739
|
+
|
|
740
|
+
NK_INTERNAL void nk_reduce_minmax_u8_rvv_contiguous_( //
|
|
741
|
+
nk_u8_t const *data_ptr, nk_size_t count, //
|
|
742
|
+
nk_u8_t *min_value_ptr, nk_size_t *min_index_ptr, //
|
|
743
|
+
nk_u8_t *max_value_ptr, nk_size_t *max_index_ptr) {
|
|
744
|
+
nk_size_t vlmax = __riscv_vsetvlmax_e8m1();
|
|
745
|
+
vuint8m1_t min_u8m1 = __riscv_vmv_v_x_u8m1(NK_U8_MAX, vlmax);
|
|
746
|
+
vuint8m1_t max_u8m1 = __riscv_vmv_v_x_u8m1(NK_U8_MIN, vlmax);
|
|
747
|
+
vuint64m8_t min_indices_u64m8 = __riscv_vmv_v_x_u64m8(0, vlmax);
|
|
748
|
+
vuint64m8_t max_indices_u64m8 = __riscv_vmv_v_x_u64m8(0, vlmax);
|
|
749
|
+
|
|
750
|
+
nk_size_t offset = 0;
|
|
751
|
+
for (nk_size_t vector_length; count > 0;
|
|
752
|
+
count -= vector_length, offset += vector_length, data_ptr += vector_length) {
|
|
753
|
+
vector_length = __riscv_vsetvl_e8m1(count);
|
|
754
|
+
vuint8m1_t data_u8m1 = __riscv_vle8_v_u8m1(data_ptr, vector_length);
|
|
755
|
+
|
|
756
|
+
// VID-based absolute indices
|
|
757
|
+
vuint64m8_t pos_u64m8 = __riscv_vadd_vx_u64m8(__riscv_vid_v_u64m8(vector_length), (nk_u64_t)offset,
|
|
758
|
+
vector_length);
|
|
759
|
+
|
|
760
|
+
vbool8_t less_b8 = __riscv_vmsltu_vv_u8m1_b8(data_u8m1, min_u8m1, vector_length);
|
|
761
|
+
min_u8m1 = __riscv_vmerge_vvm_u8m1_tu(min_u8m1, min_u8m1, data_u8m1, less_b8, vector_length);
|
|
762
|
+
min_indices_u64m8 = __riscv_vmerge_vvm_u64m8_tu(min_indices_u64m8, min_indices_u64m8, pos_u64m8, less_b8,
|
|
763
|
+
vector_length);
|
|
764
|
+
|
|
765
|
+
vbool8_t greater_b8 = __riscv_vmsltu_vv_u8m1_b8(max_u8m1, data_u8m1, vector_length);
|
|
766
|
+
max_u8m1 = __riscv_vmerge_vvm_u8m1_tu(max_u8m1, max_u8m1, data_u8m1, greater_b8, vector_length);
|
|
767
|
+
max_indices_u64m8 = __riscv_vmerge_vvm_u64m8_tu(max_indices_u64m8, max_indices_u64m8, pos_u64m8, greater_b8,
|
|
768
|
+
vector_length);
|
|
769
|
+
}
|
|
770
|
+
|
|
771
|
+
// Horizontal reduction for min
|
|
772
|
+
vuint8m1_t init_max_u8m1 = __riscv_vmv_v_x_u8m1(NK_U8_MAX, 1);
|
|
773
|
+
nk_u8_t min_val = __riscv_vmv_x_s_u8m1_u8(__riscv_vredminu_vs_u8m1_u8m1(min_u8m1, init_max_u8m1, vlmax));
|
|
774
|
+
vbool8_t min_match_b8 = __riscv_vmseq_vx_u8m1_b8(min_u8m1, min_val, vlmax);
|
|
775
|
+
vuint64m8_t sentinel_u64m8 = __riscv_vmv_v_x_u64m8(NK_U64_MAX, vlmax);
|
|
776
|
+
vuint64m8_t min_cands_u64m8 = __riscv_vmerge_vvm_u64m8(sentinel_u64m8, min_indices_u64m8, min_match_b8, vlmax);
|
|
777
|
+
vuint64m1_t init_umax_u64m1 = __riscv_vmv_v_x_u64m1(NK_U64_MAX, 1);
|
|
778
|
+
*min_value_ptr = min_val;
|
|
779
|
+
*min_index_ptr = (nk_size_t)__riscv_vmv_x_s_u64m1_u64(
|
|
780
|
+
__riscv_vredminu_vs_u64m8_u64m1(min_cands_u64m8, init_umax_u64m1, vlmax));
|
|
781
|
+
|
|
782
|
+
// Horizontal reduction for max
|
|
783
|
+
vuint8m1_t init_min_u8m1 = __riscv_vmv_v_x_u8m1(NK_U8_MIN, 1);
|
|
784
|
+
nk_u8_t max_val = __riscv_vmv_x_s_u8m1_u8(__riscv_vredmaxu_vs_u8m1_u8m1(max_u8m1, init_min_u8m1, vlmax));
|
|
785
|
+
vbool8_t max_match_b8 = __riscv_vmseq_vx_u8m1_b8(max_u8m1, max_val, vlmax);
|
|
786
|
+
vuint64m8_t max_cands_u64m8 = __riscv_vmerge_vvm_u64m8(sentinel_u64m8, max_indices_u64m8, max_match_b8, vlmax);
|
|
787
|
+
*max_value_ptr = max_val;
|
|
788
|
+
*max_index_ptr = (nk_size_t)__riscv_vmv_x_s_u64m1_u64(
|
|
789
|
+
__riscv_vredminu_vs_u64m8_u64m1(max_cands_u64m8, init_umax_u64m1, vlmax));
|
|
790
|
+
}
|
|
791
|
+
|
|
792
|
+
NK_INTERNAL void nk_reduce_minmax_u8_rvv_strided_( //
|
|
793
|
+
nk_u8_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
|
|
794
|
+
nk_u8_t *min_value_ptr, nk_size_t *min_index_ptr, //
|
|
795
|
+
nk_u8_t *max_value_ptr, nk_size_t *max_index_ptr) {
|
|
796
|
+
nk_size_t vlmax = __riscv_vsetvlmax_e8m1();
|
|
797
|
+
vuint8m1_t min_u8m1 = __riscv_vmv_v_x_u8m1(NK_U8_MAX, vlmax);
|
|
798
|
+
vuint8m1_t max_u8m1 = __riscv_vmv_v_x_u8m1(NK_U8_MIN, vlmax);
|
|
799
|
+
vuint64m8_t min_indices_u64m8 = __riscv_vmv_v_x_u64m8(0, vlmax);
|
|
800
|
+
vuint64m8_t max_indices_u64m8 = __riscv_vmv_v_x_u64m8(0, vlmax);
|
|
801
|
+
unsigned char const *ptr = (unsigned char const *)data_ptr;
|
|
802
|
+
|
|
803
|
+
nk_size_t offset = 0;
|
|
804
|
+
for (nk_size_t vector_length; count > 0;
|
|
805
|
+
count -= vector_length, offset += vector_length, ptr += vector_length * stride_bytes) {
|
|
806
|
+
vector_length = __riscv_vsetvl_e8m1(count);
|
|
807
|
+
vuint8m1_t data_u8m1 = __riscv_vlse8_v_u8m1((nk_u8_t const *)ptr, (nk_ssize_t)stride_bytes, vector_length);
|
|
808
|
+
|
|
809
|
+
// VID-based absolute indices
|
|
810
|
+
vuint64m8_t pos_u64m8 = __riscv_vadd_vx_u64m8(__riscv_vid_v_u64m8(vector_length), (nk_u64_t)offset,
|
|
811
|
+
vector_length);
|
|
812
|
+
|
|
813
|
+
vbool8_t less_b8 = __riscv_vmsltu_vv_u8m1_b8(data_u8m1, min_u8m1, vector_length);
|
|
814
|
+
min_u8m1 = __riscv_vmerge_vvm_u8m1_tu(min_u8m1, min_u8m1, data_u8m1, less_b8, vector_length);
|
|
815
|
+
min_indices_u64m8 = __riscv_vmerge_vvm_u64m8_tu(min_indices_u64m8, min_indices_u64m8, pos_u64m8, less_b8,
|
|
816
|
+
vector_length);
|
|
817
|
+
|
|
818
|
+
vbool8_t greater_b8 = __riscv_vmsltu_vv_u8m1_b8(max_u8m1, data_u8m1, vector_length);
|
|
819
|
+
max_u8m1 = __riscv_vmerge_vvm_u8m1_tu(max_u8m1, max_u8m1, data_u8m1, greater_b8, vector_length);
|
|
820
|
+
max_indices_u64m8 = __riscv_vmerge_vvm_u64m8_tu(max_indices_u64m8, max_indices_u64m8, pos_u64m8, greater_b8,
|
|
821
|
+
vector_length);
|
|
822
|
+
}
|
|
823
|
+
|
|
824
|
+
// Horizontal reduction for min
|
|
825
|
+
vuint8m1_t init_max_u8m1 = __riscv_vmv_v_x_u8m1(NK_U8_MAX, 1);
|
|
826
|
+
nk_u8_t min_val = __riscv_vmv_x_s_u8m1_u8(__riscv_vredminu_vs_u8m1_u8m1(min_u8m1, init_max_u8m1, vlmax));
|
|
827
|
+
vbool8_t min_match_b8 = __riscv_vmseq_vx_u8m1_b8(min_u8m1, min_val, vlmax);
|
|
828
|
+
vuint64m8_t sentinel_u64m8 = __riscv_vmv_v_x_u64m8(NK_U64_MAX, vlmax);
|
|
829
|
+
vuint64m8_t min_cands_u64m8 = __riscv_vmerge_vvm_u64m8(sentinel_u64m8, min_indices_u64m8, min_match_b8, vlmax);
|
|
830
|
+
vuint64m1_t init_umax_u64m1 = __riscv_vmv_v_x_u64m1(NK_U64_MAX, 1);
|
|
831
|
+
*min_value_ptr = min_val;
|
|
832
|
+
*min_index_ptr = (nk_size_t)__riscv_vmv_x_s_u64m1_u64(
|
|
833
|
+
__riscv_vredminu_vs_u64m8_u64m1(min_cands_u64m8, init_umax_u64m1, vlmax));
|
|
834
|
+
|
|
835
|
+
// Horizontal reduction for max
|
|
836
|
+
vuint8m1_t init_min_u8m1 = __riscv_vmv_v_x_u8m1(NK_U8_MIN, 1);
|
|
837
|
+
nk_u8_t max_val = __riscv_vmv_x_s_u8m1_u8(__riscv_vredmaxu_vs_u8m1_u8m1(max_u8m1, init_min_u8m1, vlmax));
|
|
838
|
+
vbool8_t max_match_b8 = __riscv_vmseq_vx_u8m1_b8(max_u8m1, max_val, vlmax);
|
|
839
|
+
vuint64m8_t max_cands_u64m8 = __riscv_vmerge_vvm_u64m8(sentinel_u64m8, max_indices_u64m8, max_match_b8, vlmax);
|
|
840
|
+
*max_value_ptr = max_val;
|
|
841
|
+
*max_index_ptr = (nk_size_t)__riscv_vmv_x_s_u64m1_u64(
|
|
842
|
+
__riscv_vredminu_vs_u64m8_u64m1(max_cands_u64m8, init_umax_u64m1, vlmax));
|
|
843
|
+
}
|
|
844
|
+
|
|
845
|
+
NK_PUBLIC void nk_reduce_minmax_u8_rvv( //
|
|
846
|
+
nk_u8_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
|
|
847
|
+
nk_u8_t *min_value_ptr, nk_size_t *min_index_ptr, //
|
|
848
|
+
nk_u8_t *max_value_ptr, nk_size_t *max_index_ptr) {
|
|
849
|
+
nk_size_t stride_elements = stride_bytes / sizeof(nk_u8_t);
|
|
850
|
+
int aligned = (stride_bytes % sizeof(nk_u8_t) == 0);
|
|
851
|
+
|
|
852
|
+
if (count == 0)
|
|
853
|
+
*min_value_ptr = NK_U8_MAX, *min_index_ptr = NK_SIZE_MAX, *max_value_ptr = NK_U8_MIN,
|
|
854
|
+
*max_index_ptr = NK_SIZE_MAX;
|
|
855
|
+
else if (!aligned)
|
|
856
|
+
nk_reduce_minmax_u8_serial(data_ptr, count, stride_bytes, min_value_ptr, min_index_ptr, max_value_ptr,
|
|
857
|
+
max_index_ptr);
|
|
858
|
+
else if (stride_elements == 1)
|
|
859
|
+
nk_reduce_minmax_u8_rvv_contiguous_(data_ptr, count, min_value_ptr, min_index_ptr, max_value_ptr,
|
|
860
|
+
max_index_ptr);
|
|
861
|
+
else
|
|
862
|
+
nk_reduce_minmax_u8_rvv_strided_(data_ptr, count, stride_bytes, min_value_ptr, min_index_ptr, max_value_ptr,
|
|
863
|
+
max_index_ptr);
|
|
864
|
+
}
|
|
865
|
+
|
|
866
|
+
NK_INTERNAL void nk_reduce_moments_i16_rvv_contiguous_( //
|
|
867
|
+
nk_i16_t const *data_ptr, nk_size_t count, //
|
|
868
|
+
nk_i64_t *sum_ptr, nk_u64_t *sumsq_ptr) {
|
|
869
|
+
nk_size_t vlmax = __riscv_vsetvlmax_e64m4();
|
|
870
|
+
vint64m4_t sum_i64m4 = __riscv_vmv_v_x_i64m4(0, vlmax);
|
|
871
|
+
vuint64m4_t sumsq_u64m4 = __riscv_vmv_v_x_u64m4(0, vlmax);
|
|
872
|
+
|
|
873
|
+
for (nk_size_t vector_length; count > 0; count -= vector_length, data_ptr += vector_length) {
|
|
874
|
+
vector_length = __riscv_vsetvl_e16m1(count);
|
|
875
|
+
vint16m1_t data_i16m1 = __riscv_vle16_v_i16m1(data_ptr, vector_length);
|
|
876
|
+
|
|
877
|
+
// Widen i16 → i32 → i64 for sum
|
|
878
|
+
vint32m2_t data_i32m2 = __riscv_vsext_vf2_i32m2(data_i16m1, vector_length);
|
|
879
|
+
vint64m4_t data_i64m4 = __riscv_vsext_vf2_i64m4(data_i32m2, vector_length);
|
|
880
|
+
sum_i64m4 = __riscv_vadd_vv_i64m4_tu(sum_i64m4, sum_i64m4, data_i64m4, vector_length);
|
|
881
|
+
|
|
882
|
+
// Sumsq: i16 × i16 → i32 (widening multiply)
|
|
883
|
+
vint32m2_t squares_i32m2 = __riscv_vwmul_vv_i32m2(data_i16m1, data_i16m1, vector_length);
|
|
884
|
+
// Widen i32 → u64
|
|
885
|
+
vuint64m4_t squares_u64m4 = __riscv_vwcvtu_x_x_v_u64m4(__riscv_vreinterpret_v_i32m2_u32m2(squares_i32m2),
|
|
886
|
+
vector_length);
|
|
887
|
+
sumsq_u64m4 = __riscv_vadd_vv_u64m4_tu(sumsq_u64m4, sumsq_u64m4, squares_u64m4, vector_length);
|
|
888
|
+
}
|
|
889
|
+
|
|
890
|
+
// Horizontal reduction
|
|
891
|
+
vint64m1_t zero_i64m1 = __riscv_vmv_v_x_i64m1(0, 1);
|
|
892
|
+
*sum_ptr = __riscv_vmv_x_s_i64m1_i64(__riscv_vredsum_vs_i64m4_i64m1(sum_i64m4, zero_i64m1, vlmax));
|
|
893
|
+
|
|
894
|
+
vuint64m1_t zero_u64m1 = __riscv_vmv_v_x_u64m1(0, 1);
|
|
895
|
+
*sumsq_ptr = __riscv_vmv_x_s_u64m1_u64(__riscv_vredsum_vs_u64m4_u64m1(sumsq_u64m4, zero_u64m1, vlmax));
|
|
896
|
+
}
|
|
897
|
+
|
|
898
|
+
NK_INTERNAL void nk_reduce_moments_i16_rvv_strided_( //
|
|
899
|
+
nk_i16_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
|
|
900
|
+
nk_i64_t *sum_ptr, nk_u64_t *sumsq_ptr) {
|
|
901
|
+
nk_size_t vlmax = __riscv_vsetvlmax_e64m4();
|
|
902
|
+
vint64m4_t sum_i64m4 = __riscv_vmv_v_x_i64m4(0, vlmax);
|
|
903
|
+
vuint64m4_t sumsq_u64m4 = __riscv_vmv_v_x_u64m4(0, vlmax);
|
|
904
|
+
unsigned char const *ptr = (unsigned char const *)data_ptr;
|
|
905
|
+
|
|
906
|
+
for (nk_size_t vector_length; count > 0; count -= vector_length, ptr += vector_length * stride_bytes) {
|
|
907
|
+
vector_length = __riscv_vsetvl_e16m1(count);
|
|
908
|
+
vint16m1_t data_i16m1 = __riscv_vlse16_v_i16m1((nk_i16_t const *)ptr, (nk_ssize_t)stride_bytes, vector_length);
|
|
909
|
+
|
|
910
|
+
// Widen i16 → i32 → i64 for sum
|
|
911
|
+
vint32m2_t data_i32m2 = __riscv_vsext_vf2_i32m2(data_i16m1, vector_length);
|
|
912
|
+
vint64m4_t data_i64m4 = __riscv_vsext_vf2_i64m4(data_i32m2, vector_length);
|
|
913
|
+
sum_i64m4 = __riscv_vadd_vv_i64m4_tu(sum_i64m4, sum_i64m4, data_i64m4, vector_length);
|
|
914
|
+
|
|
915
|
+
// Sumsq: i16 × i16 → i32 (widening multiply)
|
|
916
|
+
vint32m2_t squares_i32m2 = __riscv_vwmul_vv_i32m2(data_i16m1, data_i16m1, vector_length);
|
|
917
|
+
// Widen i32 → u64
|
|
918
|
+
vuint64m4_t squares_u64m4 = __riscv_vwcvtu_x_x_v_u64m4(__riscv_vreinterpret_v_i32m2_u32m2(squares_i32m2),
|
|
919
|
+
vector_length);
|
|
920
|
+
sumsq_u64m4 = __riscv_vadd_vv_u64m4_tu(sumsq_u64m4, sumsq_u64m4, squares_u64m4, vector_length);
|
|
921
|
+
}
|
|
922
|
+
|
|
923
|
+
// Horizontal reduction
|
|
924
|
+
vint64m1_t zero_i64m1 = __riscv_vmv_v_x_i64m1(0, 1);
|
|
925
|
+
*sum_ptr = __riscv_vmv_x_s_i64m1_i64(__riscv_vredsum_vs_i64m4_i64m1(sum_i64m4, zero_i64m1, vlmax));
|
|
926
|
+
|
|
927
|
+
vuint64m1_t zero_u64m1 = __riscv_vmv_v_x_u64m1(0, 1);
|
|
928
|
+
*sumsq_ptr = __riscv_vmv_x_s_u64m1_u64(__riscv_vredsum_vs_u64m4_u64m1(sumsq_u64m4, zero_u64m1, vlmax));
|
|
929
|
+
}
|
|
930
|
+
|
|
931
|
+
NK_PUBLIC void nk_reduce_moments_i16_rvv( //
|
|
932
|
+
nk_i16_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
|
|
933
|
+
nk_i64_t *sum_ptr, nk_u64_t *sumsq_ptr) {
|
|
934
|
+
nk_size_t stride_elements = stride_bytes / sizeof(nk_i16_t);
|
|
935
|
+
int aligned = (stride_bytes % sizeof(nk_i16_t) == 0);
|
|
936
|
+
|
|
937
|
+
if (count == 0) *sum_ptr = 0, *sumsq_ptr = 0;
|
|
938
|
+
else if (!aligned) { nk_reduce_moments_i16_serial(data_ptr, count, stride_bytes, sum_ptr, sumsq_ptr); }
|
|
939
|
+
else if (stride_elements == 1) { nk_reduce_moments_i16_rvv_contiguous_(data_ptr, count, sum_ptr, sumsq_ptr); }
|
|
940
|
+
else { nk_reduce_moments_i16_rvv_strided_(data_ptr, count, stride_bytes, sum_ptr, sumsq_ptr); }
|
|
941
|
+
}
|
|
942
|
+
|
|
943
|
+
NK_INTERNAL void nk_reduce_minmax_i16_rvv_contiguous_( //
|
|
944
|
+
nk_i16_t const *data_ptr, nk_size_t count, //
|
|
945
|
+
nk_i16_t *min_value_ptr, nk_size_t *min_index_ptr, //
|
|
946
|
+
nk_i16_t *max_value_ptr, nk_size_t *max_index_ptr) {
|
|
947
|
+
nk_size_t vlmax = __riscv_vsetvlmax_e16m1();
|
|
948
|
+
vint16m1_t min_i16m1 = __riscv_vmv_v_x_i16m1(NK_I16_MAX, vlmax);
|
|
949
|
+
vint16m1_t max_i16m1 = __riscv_vmv_v_x_i16m1(NK_I16_MIN, vlmax);
|
|
950
|
+
vuint64m4_t min_indices_u64m4 = __riscv_vmv_v_x_u64m4(0, vlmax);
|
|
951
|
+
vuint64m4_t max_indices_u64m4 = __riscv_vmv_v_x_u64m4(0, vlmax);
|
|
952
|
+
|
|
953
|
+
nk_size_t offset = 0;
|
|
954
|
+
for (nk_size_t vector_length; count > 0;
|
|
955
|
+
count -= vector_length, offset += vector_length, data_ptr += vector_length) {
|
|
956
|
+
vector_length = __riscv_vsetvl_e16m1(count);
|
|
957
|
+
vint16m1_t data_i16m1 = __riscv_vle16_v_i16m1(data_ptr, vector_length);
|
|
958
|
+
vuint64m4_t pos_u64m4 = __riscv_vadd_vx_u64m4(__riscv_vid_v_u64m4(vector_length), (nk_u64_t)offset,
|
|
959
|
+
vector_length);
|
|
960
|
+
|
|
961
|
+
vbool16_t less_b16 = __riscv_vmslt_vv_i16m1_b16(data_i16m1, min_i16m1, vector_length);
|
|
962
|
+
min_i16m1 = __riscv_vmerge_vvm_i16m1_tu(min_i16m1, min_i16m1, data_i16m1, less_b16, vector_length);
|
|
963
|
+
min_indices_u64m4 = __riscv_vmerge_vvm_u64m4_tu(min_indices_u64m4, min_indices_u64m4, pos_u64m4, less_b16,
|
|
964
|
+
vector_length);
|
|
965
|
+
|
|
966
|
+
vbool16_t greater_b16 = __riscv_vmslt_vv_i16m1_b16(max_i16m1, data_i16m1, vector_length);
|
|
967
|
+
max_i16m1 = __riscv_vmerge_vvm_i16m1_tu(max_i16m1, max_i16m1, data_i16m1, greater_b16, vector_length);
|
|
968
|
+
max_indices_u64m4 = __riscv_vmerge_vvm_u64m4_tu(max_indices_u64m4, max_indices_u64m4, pos_u64m4, greater_b16,
|
|
969
|
+
vector_length);
|
|
970
|
+
}
|
|
971
|
+
|
|
972
|
+
// Horizontal reduction for min
|
|
973
|
+
vint16m1_t init_max_i16m1 = __riscv_vmv_v_x_i16m1(NK_I16_MAX, 1);
|
|
974
|
+
nk_i16_t min_val = __riscv_vmv_x_s_i16m1_i16(__riscv_vredmin_vs_i16m1_i16m1(min_i16m1, init_max_i16m1, vlmax));
|
|
975
|
+
vbool16_t min_match_b16 = __riscv_vmseq_vx_i16m1_b16(min_i16m1, min_val, vlmax);
|
|
976
|
+
vuint64m4_t sentinel_u64m4 = __riscv_vmv_v_x_u64m4(NK_U64_MAX, vlmax);
|
|
977
|
+
vuint64m4_t min_cands_u64m4 = __riscv_vmerge_vvm_u64m4(sentinel_u64m4, min_indices_u64m4, min_match_b16, vlmax);
|
|
978
|
+
vuint64m1_t init_umax_u64m1 = __riscv_vmv_v_x_u64m1(NK_U64_MAX, 1);
|
|
979
|
+
*min_value_ptr = min_val;
|
|
980
|
+
*min_index_ptr = (nk_size_t)__riscv_vmv_x_s_u64m1_u64(
|
|
981
|
+
__riscv_vredminu_vs_u64m4_u64m1(min_cands_u64m4, init_umax_u64m1, vlmax));
|
|
982
|
+
|
|
983
|
+
// Horizontal reduction for max
|
|
984
|
+
vint16m1_t init_min_i16m1 = __riscv_vmv_v_x_i16m1(NK_I16_MIN, 1);
|
|
985
|
+
nk_i16_t max_val = __riscv_vmv_x_s_i16m1_i16(__riscv_vredmax_vs_i16m1_i16m1(max_i16m1, init_min_i16m1, vlmax));
|
|
986
|
+
vbool16_t max_match_b16 = __riscv_vmseq_vx_i16m1_b16(max_i16m1, max_val, vlmax);
|
|
987
|
+
vuint64m4_t max_cands_u64m4 = __riscv_vmerge_vvm_u64m4(sentinel_u64m4, max_indices_u64m4, max_match_b16, vlmax);
|
|
988
|
+
*max_value_ptr = max_val;
|
|
989
|
+
*max_index_ptr = (nk_size_t)__riscv_vmv_x_s_u64m1_u64(
|
|
990
|
+
__riscv_vredminu_vs_u64m4_u64m1(max_cands_u64m4, init_umax_u64m1, vlmax));
|
|
991
|
+
}
|
|
992
|
+
|
|
993
|
+
NK_INTERNAL void nk_reduce_minmax_i16_rvv_strided_( //
|
|
994
|
+
nk_i16_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
|
|
995
|
+
nk_i16_t *min_value_ptr, nk_size_t *min_index_ptr, //
|
|
996
|
+
nk_i16_t *max_value_ptr, nk_size_t *max_index_ptr) {
|
|
997
|
+
nk_size_t vlmax = __riscv_vsetvlmax_e16m1();
|
|
998
|
+
vint16m1_t min_i16m1 = __riscv_vmv_v_x_i16m1(NK_I16_MAX, vlmax);
|
|
999
|
+
vint16m1_t max_i16m1 = __riscv_vmv_v_x_i16m1(NK_I16_MIN, vlmax);
|
|
1000
|
+
vuint64m4_t min_indices_u64m4 = __riscv_vmv_v_x_u64m4(0, vlmax);
|
|
1001
|
+
vuint64m4_t max_indices_u64m4 = __riscv_vmv_v_x_u64m4(0, vlmax);
|
|
1002
|
+
unsigned char const *ptr = (unsigned char const *)data_ptr;
|
|
1003
|
+
|
|
1004
|
+
nk_size_t offset = 0;
|
|
1005
|
+
for (nk_size_t vector_length; count > 0;
|
|
1006
|
+
count -= vector_length, offset += vector_length, ptr += vector_length * stride_bytes) {
|
|
1007
|
+
vector_length = __riscv_vsetvl_e16m1(count);
|
|
1008
|
+
vint16m1_t data_i16m1 = __riscv_vlse16_v_i16m1((nk_i16_t const *)ptr, (nk_ssize_t)stride_bytes, vector_length);
|
|
1009
|
+
vuint64m4_t pos_u64m4 = __riscv_vadd_vx_u64m4(__riscv_vid_v_u64m4(vector_length), (nk_u64_t)offset,
|
|
1010
|
+
vector_length);
|
|
1011
|
+
|
|
1012
|
+
vbool16_t less_b16 = __riscv_vmslt_vv_i16m1_b16(data_i16m1, min_i16m1, vector_length);
|
|
1013
|
+
min_i16m1 = __riscv_vmerge_vvm_i16m1_tu(min_i16m1, min_i16m1, data_i16m1, less_b16, vector_length);
|
|
1014
|
+
min_indices_u64m4 = __riscv_vmerge_vvm_u64m4_tu(min_indices_u64m4, min_indices_u64m4, pos_u64m4, less_b16,
|
|
1015
|
+
vector_length);
|
|
1016
|
+
|
|
1017
|
+
vbool16_t greater_b16 = __riscv_vmslt_vv_i16m1_b16(max_i16m1, data_i16m1, vector_length);
|
|
1018
|
+
max_i16m1 = __riscv_vmerge_vvm_i16m1_tu(max_i16m1, max_i16m1, data_i16m1, greater_b16, vector_length);
|
|
1019
|
+
max_indices_u64m4 = __riscv_vmerge_vvm_u64m4_tu(max_indices_u64m4, max_indices_u64m4, pos_u64m4, greater_b16,
|
|
1020
|
+
vector_length);
|
|
1021
|
+
}
|
|
1022
|
+
|
|
1023
|
+
// Horizontal reduction for min
|
|
1024
|
+
vint16m1_t init_max_i16m1 = __riscv_vmv_v_x_i16m1(NK_I16_MAX, 1);
|
|
1025
|
+
nk_i16_t min_val = __riscv_vmv_x_s_i16m1_i16(__riscv_vredmin_vs_i16m1_i16m1(min_i16m1, init_max_i16m1, vlmax));
|
|
1026
|
+
vbool16_t min_match_b16 = __riscv_vmseq_vx_i16m1_b16(min_i16m1, min_val, vlmax);
|
|
1027
|
+
vuint64m4_t sentinel_u64m4 = __riscv_vmv_v_x_u64m4(NK_U64_MAX, vlmax);
|
|
1028
|
+
vuint64m4_t min_cands_u64m4 = __riscv_vmerge_vvm_u64m4(sentinel_u64m4, min_indices_u64m4, min_match_b16, vlmax);
|
|
1029
|
+
vuint64m1_t init_umax_u64m1 = __riscv_vmv_v_x_u64m1(NK_U64_MAX, 1);
|
|
1030
|
+
*min_value_ptr = min_val;
|
|
1031
|
+
*min_index_ptr = (nk_size_t)__riscv_vmv_x_s_u64m1_u64(
|
|
1032
|
+
__riscv_vredminu_vs_u64m4_u64m1(min_cands_u64m4, init_umax_u64m1, vlmax));
|
|
1033
|
+
|
|
1034
|
+
// Horizontal reduction for max
|
|
1035
|
+
vint16m1_t init_min_i16m1 = __riscv_vmv_v_x_i16m1(NK_I16_MIN, 1);
|
|
1036
|
+
nk_i16_t max_val = __riscv_vmv_x_s_i16m1_i16(__riscv_vredmax_vs_i16m1_i16m1(max_i16m1, init_min_i16m1, vlmax));
|
|
1037
|
+
vbool16_t max_match_b16 = __riscv_vmseq_vx_i16m1_b16(max_i16m1, max_val, vlmax);
|
|
1038
|
+
vuint64m4_t max_cands_u64m4 = __riscv_vmerge_vvm_u64m4(sentinel_u64m4, max_indices_u64m4, max_match_b16, vlmax);
|
|
1039
|
+
*max_value_ptr = max_val;
|
|
1040
|
+
*max_index_ptr = (nk_size_t)__riscv_vmv_x_s_u64m1_u64(
|
|
1041
|
+
__riscv_vredminu_vs_u64m4_u64m1(max_cands_u64m4, init_umax_u64m1, vlmax));
|
|
1042
|
+
}
|
|
1043
|
+
|
|
1044
|
+
NK_PUBLIC void nk_reduce_minmax_i16_rvv( //
|
|
1045
|
+
nk_i16_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
|
|
1046
|
+
nk_i16_t *min_value_ptr, nk_size_t *min_index_ptr, //
|
|
1047
|
+
nk_i16_t *max_value_ptr, nk_size_t *max_index_ptr) {
|
|
1048
|
+
nk_size_t stride_elements = stride_bytes / sizeof(nk_i16_t);
|
|
1049
|
+
int aligned = (stride_bytes % sizeof(nk_i16_t) == 0);
|
|
1050
|
+
|
|
1051
|
+
if (count == 0)
|
|
1052
|
+
*min_value_ptr = NK_I16_MAX, *min_index_ptr = NK_SIZE_MAX, *max_value_ptr = NK_I16_MIN,
|
|
1053
|
+
*max_index_ptr = NK_SIZE_MAX;
|
|
1054
|
+
else if (!aligned)
|
|
1055
|
+
nk_reduce_minmax_i16_serial(data_ptr, count, stride_bytes, min_value_ptr, min_index_ptr, max_value_ptr,
|
|
1056
|
+
max_index_ptr);
|
|
1057
|
+
else if (stride_elements == 1)
|
|
1058
|
+
nk_reduce_minmax_i16_rvv_contiguous_(data_ptr, count, min_value_ptr, min_index_ptr, max_value_ptr,
|
|
1059
|
+
max_index_ptr);
|
|
1060
|
+
else
|
|
1061
|
+
nk_reduce_minmax_i16_rvv_strided_(data_ptr, count, stride_bytes, min_value_ptr, min_index_ptr, max_value_ptr,
|
|
1062
|
+
max_index_ptr);
|
|
1063
|
+
}
|
|
1064
|
+
|
|
1065
|
+
NK_INTERNAL void nk_reduce_moments_u16_rvv_contiguous_( //
|
|
1066
|
+
nk_u16_t const *data_ptr, nk_size_t count, //
|
|
1067
|
+
nk_u64_t *sum_ptr, nk_u64_t *sumsq_ptr) {
|
|
1068
|
+
nk_size_t vlmax = __riscv_vsetvlmax_e64m4();
|
|
1069
|
+
vuint64m4_t sum_u64m4 = __riscv_vmv_v_x_u64m4(0, vlmax);
|
|
1070
|
+
vuint64m4_t sumsq_u64m4 = __riscv_vmv_v_x_u64m4(0, vlmax);
|
|
1071
|
+
|
|
1072
|
+
for (nk_size_t vector_length; count > 0; count -= vector_length, data_ptr += vector_length) {
|
|
1073
|
+
vector_length = __riscv_vsetvl_e16m1(count);
|
|
1074
|
+
vuint16m1_t data_u16m1 = __riscv_vle16_v_u16m1(data_ptr, vector_length);
|
|
1075
|
+
|
|
1076
|
+
// Widen u16 → u32 → u64 for sum
|
|
1077
|
+
vuint32m2_t data_u32m2 = __riscv_vzext_vf2_u32m2(data_u16m1, vector_length);
|
|
1078
|
+
vuint64m4_t data_u64m4 = __riscv_vzext_vf2_u64m4(data_u32m2, vector_length);
|
|
1079
|
+
sum_u64m4 = __riscv_vadd_vv_u64m4_tu(sum_u64m4, sum_u64m4, data_u64m4, vector_length);
|
|
1080
|
+
|
|
1081
|
+
// Sumsq: u16 × u16 → u32 (widening multiply)
|
|
1082
|
+
vuint32m2_t squares_u32m2 = __riscv_vwmulu_vv_u32m2(data_u16m1, data_u16m1, vector_length);
|
|
1083
|
+
// Widen u32 → u64
|
|
1084
|
+
vuint64m4_t squares_u64m4 = __riscv_vzext_vf2_u64m4(squares_u32m2, vector_length);
|
|
1085
|
+
sumsq_u64m4 = __riscv_vadd_vv_u64m4_tu(sumsq_u64m4, sumsq_u64m4, squares_u64m4, vector_length);
|
|
1086
|
+
}
|
|
1087
|
+
|
|
1088
|
+
// Horizontal reduction
|
|
1089
|
+
vuint64m1_t zero_u64m1 = __riscv_vmv_v_x_u64m1(0, 1);
|
|
1090
|
+
*sum_ptr = __riscv_vmv_x_s_u64m1_u64(__riscv_vredsum_vs_u64m4_u64m1(sum_u64m4, zero_u64m1, vlmax)),
|
|
1091
|
+
*sumsq_ptr = __riscv_vmv_x_s_u64m1_u64(__riscv_vredsum_vs_u64m4_u64m1(sumsq_u64m4, zero_u64m1, vlmax));
|
|
1092
|
+
}
|
|
1093
|
+
|
|
1094
|
+
NK_INTERNAL void nk_reduce_moments_u16_rvv_strided_( //
|
|
1095
|
+
nk_u16_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
|
|
1096
|
+
nk_u64_t *sum_ptr, nk_u64_t *sumsq_ptr) {
|
|
1097
|
+
nk_size_t vlmax = __riscv_vsetvlmax_e64m4();
|
|
1098
|
+
vuint64m4_t sum_u64m4 = __riscv_vmv_v_x_u64m4(0, vlmax);
|
|
1099
|
+
vuint64m4_t sumsq_u64m4 = __riscv_vmv_v_x_u64m4(0, vlmax);
|
|
1100
|
+
unsigned char const *ptr = (unsigned char const *)data_ptr;
|
|
1101
|
+
|
|
1102
|
+
for (nk_size_t vector_length; count > 0; count -= vector_length, ptr += vector_length * stride_bytes) {
|
|
1103
|
+
vector_length = __riscv_vsetvl_e16m1(count);
|
|
1104
|
+
vuint16m1_t data_u16m1 = __riscv_vlse16_v_u16m1((nk_u16_t const *)ptr, (nk_ssize_t)stride_bytes, vector_length);
|
|
1105
|
+
|
|
1106
|
+
// Widen u16 → u32 → u64 for sum
|
|
1107
|
+
vuint32m2_t data_u32m2 = __riscv_vzext_vf2_u32m2(data_u16m1, vector_length);
|
|
1108
|
+
vuint64m4_t data_u64m4 = __riscv_vzext_vf2_u64m4(data_u32m2, vector_length);
|
|
1109
|
+
sum_u64m4 = __riscv_vadd_vv_u64m4_tu(sum_u64m4, sum_u64m4, data_u64m4, vector_length);
|
|
1110
|
+
|
|
1111
|
+
// Sumsq: u16 × u16 → u32 (widening multiply)
|
|
1112
|
+
vuint32m2_t squares_u32m2 = __riscv_vwmulu_vv_u32m2(data_u16m1, data_u16m1, vector_length);
|
|
1113
|
+
// Widen u32 → u64
|
|
1114
|
+
vuint64m4_t squares_u64m4 = __riscv_vzext_vf2_u64m4(squares_u32m2, vector_length);
|
|
1115
|
+
sumsq_u64m4 = __riscv_vadd_vv_u64m4_tu(sumsq_u64m4, sumsq_u64m4, squares_u64m4, vector_length);
|
|
1116
|
+
}
|
|
1117
|
+
|
|
1118
|
+
// Horizontal reduction
|
|
1119
|
+
vuint64m1_t zero_u64m1 = __riscv_vmv_v_x_u64m1(0, 1);
|
|
1120
|
+
*sum_ptr = __riscv_vmv_x_s_u64m1_u64(__riscv_vredsum_vs_u64m4_u64m1(sum_u64m4, zero_u64m1, vlmax)),
|
|
1121
|
+
*sumsq_ptr = __riscv_vmv_x_s_u64m1_u64(__riscv_vredsum_vs_u64m4_u64m1(sumsq_u64m4, zero_u64m1, vlmax));
|
|
1122
|
+
}
|
|
1123
|
+
|
|
1124
|
+
NK_PUBLIC void nk_reduce_moments_u16_rvv( //
|
|
1125
|
+
nk_u16_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
|
|
1126
|
+
nk_u64_t *sum_ptr, nk_u64_t *sumsq_ptr) {
|
|
1127
|
+
nk_size_t stride_elements = stride_bytes / sizeof(nk_u16_t);
|
|
1128
|
+
int aligned = (stride_bytes % sizeof(nk_u16_t) == 0);
|
|
1129
|
+
|
|
1130
|
+
if (count == 0) *sum_ptr = 0, *sumsq_ptr = 0;
|
|
1131
|
+
else if (!aligned) { nk_reduce_moments_u16_serial(data_ptr, count, stride_bytes, sum_ptr, sumsq_ptr); }
|
|
1132
|
+
else if (stride_elements == 1) { nk_reduce_moments_u16_rvv_contiguous_(data_ptr, count, sum_ptr, sumsq_ptr); }
|
|
1133
|
+
else { nk_reduce_moments_u16_rvv_strided_(data_ptr, count, stride_bytes, sum_ptr, sumsq_ptr); }
|
|
1134
|
+
}
|
|
1135
|
+
|
|
1136
|
+
NK_INTERNAL void nk_reduce_minmax_u16_rvv_contiguous_( //
|
|
1137
|
+
nk_u16_t const *data_ptr, nk_size_t count, //
|
|
1138
|
+
nk_u16_t *min_value_ptr, nk_size_t *min_index_ptr, //
|
|
1139
|
+
nk_u16_t *max_value_ptr, nk_size_t *max_index_ptr) {
|
|
1140
|
+
nk_size_t vlmax = __riscv_vsetvlmax_e16m1();
|
|
1141
|
+
vuint16m1_t min_u16m1 = __riscv_vmv_v_x_u16m1(NK_U16_MAX, vlmax);
|
|
1142
|
+
vuint16m1_t max_u16m1 = __riscv_vmv_v_x_u16m1(NK_U16_MIN, vlmax);
|
|
1143
|
+
vuint64m4_t min_indices_u64m4 = __riscv_vmv_v_x_u64m4(0, vlmax);
|
|
1144
|
+
vuint64m4_t max_indices_u64m4 = __riscv_vmv_v_x_u64m4(0, vlmax);
|
|
1145
|
+
|
|
1146
|
+
nk_size_t offset = 0;
|
|
1147
|
+
for (nk_size_t vector_length; count > 0;
|
|
1148
|
+
count -= vector_length, offset += vector_length, data_ptr += vector_length) {
|
|
1149
|
+
vector_length = __riscv_vsetvl_e16m1(count);
|
|
1150
|
+
vuint16m1_t data_u16m1 = __riscv_vle16_v_u16m1(data_ptr, vector_length);
|
|
1151
|
+
vuint64m4_t pos_u64m4 = __riscv_vadd_vx_u64m4(__riscv_vid_v_u64m4(vector_length), (nk_u64_t)offset,
|
|
1152
|
+
vector_length);
|
|
1153
|
+
|
|
1154
|
+
vbool16_t less_b16 = __riscv_vmsltu_vv_u16m1_b16(data_u16m1, min_u16m1, vector_length);
|
|
1155
|
+
min_u16m1 = __riscv_vmerge_vvm_u16m1_tu(min_u16m1, min_u16m1, data_u16m1, less_b16, vector_length);
|
|
1156
|
+
min_indices_u64m4 = __riscv_vmerge_vvm_u64m4_tu(min_indices_u64m4, min_indices_u64m4, pos_u64m4, less_b16,
|
|
1157
|
+
vector_length);
|
|
1158
|
+
|
|
1159
|
+
vbool16_t greater_b16 = __riscv_vmsltu_vv_u16m1_b16(max_u16m1, data_u16m1, vector_length);
|
|
1160
|
+
max_u16m1 = __riscv_vmerge_vvm_u16m1_tu(max_u16m1, max_u16m1, data_u16m1, greater_b16, vector_length);
|
|
1161
|
+
max_indices_u64m4 = __riscv_vmerge_vvm_u64m4_tu(max_indices_u64m4, max_indices_u64m4, pos_u64m4, greater_b16,
|
|
1162
|
+
vector_length);
|
|
1163
|
+
}
|
|
1164
|
+
|
|
1165
|
+
// Horizontal reduction for min
|
|
1166
|
+
vuint16m1_t init_max_u16m1 = __riscv_vmv_v_x_u16m1(NK_U16_MAX, 1);
|
|
1167
|
+
nk_u16_t min_val = __riscv_vmv_x_s_u16m1_u16(__riscv_vredminu_vs_u16m1_u16m1(min_u16m1, init_max_u16m1, vlmax));
|
|
1168
|
+
vbool16_t min_match_b16 = __riscv_vmseq_vx_u16m1_b16(min_u16m1, min_val, vlmax);
|
|
1169
|
+
vuint64m4_t sentinel_u64m4 = __riscv_vmv_v_x_u64m4(NK_U64_MAX, vlmax);
|
|
1170
|
+
vuint64m4_t min_cands_u64m4 = __riscv_vmerge_vvm_u64m4(sentinel_u64m4, min_indices_u64m4, min_match_b16, vlmax);
|
|
1171
|
+
vuint64m1_t init_umax_u64m1 = __riscv_vmv_v_x_u64m1(NK_U64_MAX, 1);
|
|
1172
|
+
*min_value_ptr = min_val;
|
|
1173
|
+
*min_index_ptr = (nk_size_t)__riscv_vmv_x_s_u64m1_u64(
|
|
1174
|
+
__riscv_vredminu_vs_u64m4_u64m1(min_cands_u64m4, init_umax_u64m1, vlmax));
|
|
1175
|
+
|
|
1176
|
+
// Horizontal reduction for max
|
|
1177
|
+
vuint16m1_t init_min_u16m1 = __riscv_vmv_v_x_u16m1(NK_U16_MIN, 1);
|
|
1178
|
+
nk_u16_t max_val = __riscv_vmv_x_s_u16m1_u16(__riscv_vredmaxu_vs_u16m1_u16m1(max_u16m1, init_min_u16m1, vlmax));
|
|
1179
|
+
vbool16_t max_match_b16 = __riscv_vmseq_vx_u16m1_b16(max_u16m1, max_val, vlmax);
|
|
1180
|
+
vuint64m4_t max_cands_u64m4 = __riscv_vmerge_vvm_u64m4(sentinel_u64m4, max_indices_u64m4, max_match_b16, vlmax);
|
|
1181
|
+
*max_value_ptr = max_val;
|
|
1182
|
+
*max_index_ptr = (nk_size_t)__riscv_vmv_x_s_u64m1_u64(
|
|
1183
|
+
__riscv_vredminu_vs_u64m4_u64m1(max_cands_u64m4, init_umax_u64m1, vlmax));
|
|
1184
|
+
}
|
|
1185
|
+
|
|
1186
|
+
NK_INTERNAL void nk_reduce_minmax_u16_rvv_strided_( //
|
|
1187
|
+
nk_u16_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
|
|
1188
|
+
nk_u16_t *min_value_ptr, nk_size_t *min_index_ptr, //
|
|
1189
|
+
nk_u16_t *max_value_ptr, nk_size_t *max_index_ptr) {
|
|
1190
|
+
nk_size_t vlmax = __riscv_vsetvlmax_e16m1();
|
|
1191
|
+
vuint16m1_t min_u16m1 = __riscv_vmv_v_x_u16m1(NK_U16_MAX, vlmax);
|
|
1192
|
+
vuint16m1_t max_u16m1 = __riscv_vmv_v_x_u16m1(NK_U16_MIN, vlmax);
|
|
1193
|
+
vuint64m4_t min_indices_u64m4 = __riscv_vmv_v_x_u64m4(0, vlmax);
|
|
1194
|
+
vuint64m4_t max_indices_u64m4 = __riscv_vmv_v_x_u64m4(0, vlmax);
|
|
1195
|
+
unsigned char const *ptr = (unsigned char const *)data_ptr;
|
|
1196
|
+
|
|
1197
|
+
nk_size_t offset = 0;
|
|
1198
|
+
for (nk_size_t vector_length; count > 0;
|
|
1199
|
+
count -= vector_length, offset += vector_length, ptr += vector_length * stride_bytes) {
|
|
1200
|
+
vector_length = __riscv_vsetvl_e16m1(count);
|
|
1201
|
+
vuint16m1_t data_u16m1 = __riscv_vlse16_v_u16m1((nk_u16_t const *)ptr, (nk_ssize_t)stride_bytes, vector_length);
|
|
1202
|
+
vuint64m4_t pos_u64m4 = __riscv_vadd_vx_u64m4(__riscv_vid_v_u64m4(vector_length), (nk_u64_t)offset,
|
|
1203
|
+
vector_length);
|
|
1204
|
+
|
|
1205
|
+
vbool16_t less_b16 = __riscv_vmsltu_vv_u16m1_b16(data_u16m1, min_u16m1, vector_length);
|
|
1206
|
+
min_u16m1 = __riscv_vmerge_vvm_u16m1_tu(min_u16m1, min_u16m1, data_u16m1, less_b16, vector_length);
|
|
1207
|
+
min_indices_u64m4 = __riscv_vmerge_vvm_u64m4_tu(min_indices_u64m4, min_indices_u64m4, pos_u64m4, less_b16,
|
|
1208
|
+
vector_length);
|
|
1209
|
+
|
|
1210
|
+
vbool16_t greater_b16 = __riscv_vmsltu_vv_u16m1_b16(max_u16m1, data_u16m1, vector_length);
|
|
1211
|
+
max_u16m1 = __riscv_vmerge_vvm_u16m1_tu(max_u16m1, max_u16m1, data_u16m1, greater_b16, vector_length);
|
|
1212
|
+
max_indices_u64m4 = __riscv_vmerge_vvm_u64m4_tu(max_indices_u64m4, max_indices_u64m4, pos_u64m4, greater_b16,
|
|
1213
|
+
vector_length);
|
|
1214
|
+
}
|
|
1215
|
+
|
|
1216
|
+
// Horizontal reduction for min
|
|
1217
|
+
vuint16m1_t init_max_u16m1 = __riscv_vmv_v_x_u16m1(NK_U16_MAX, 1);
|
|
1218
|
+
nk_u16_t min_val = __riscv_vmv_x_s_u16m1_u16(__riscv_vredminu_vs_u16m1_u16m1(min_u16m1, init_max_u16m1, vlmax));
|
|
1219
|
+
vbool16_t min_match_b16 = __riscv_vmseq_vx_u16m1_b16(min_u16m1, min_val, vlmax);
|
|
1220
|
+
vuint64m4_t sentinel_u64m4 = __riscv_vmv_v_x_u64m4(NK_U64_MAX, vlmax);
|
|
1221
|
+
vuint64m4_t min_cands_u64m4 = __riscv_vmerge_vvm_u64m4(sentinel_u64m4, min_indices_u64m4, min_match_b16, vlmax);
|
|
1222
|
+
vuint64m1_t init_umax_u64m1 = __riscv_vmv_v_x_u64m1(NK_U64_MAX, 1);
|
|
1223
|
+
*min_value_ptr = min_val;
|
|
1224
|
+
*min_index_ptr = (nk_size_t)__riscv_vmv_x_s_u64m1_u64(
|
|
1225
|
+
__riscv_vredminu_vs_u64m4_u64m1(min_cands_u64m4, init_umax_u64m1, vlmax));
|
|
1226
|
+
|
|
1227
|
+
// Horizontal reduction for max
|
|
1228
|
+
vuint16m1_t init_min_u16m1 = __riscv_vmv_v_x_u16m1(NK_U16_MIN, 1);
|
|
1229
|
+
nk_u16_t max_val = __riscv_vmv_x_s_u16m1_u16(__riscv_vredmaxu_vs_u16m1_u16m1(max_u16m1, init_min_u16m1, vlmax));
|
|
1230
|
+
vbool16_t max_match_b16 = __riscv_vmseq_vx_u16m1_b16(max_u16m1, max_val, vlmax);
|
|
1231
|
+
vuint64m4_t max_cands_u64m4 = __riscv_vmerge_vvm_u64m4(sentinel_u64m4, max_indices_u64m4, max_match_b16, vlmax);
|
|
1232
|
+
*max_value_ptr = max_val;
|
|
1233
|
+
*max_index_ptr = (nk_size_t)__riscv_vmv_x_s_u64m1_u64(
|
|
1234
|
+
__riscv_vredminu_vs_u64m4_u64m1(max_cands_u64m4, init_umax_u64m1, vlmax));
|
|
1235
|
+
}
|
|
1236
|
+
|
|
1237
|
+
NK_PUBLIC void nk_reduce_minmax_u16_rvv( //
|
|
1238
|
+
nk_u16_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
|
|
1239
|
+
nk_u16_t *min_value_ptr, nk_size_t *min_index_ptr, //
|
|
1240
|
+
nk_u16_t *max_value_ptr, nk_size_t *max_index_ptr) {
|
|
1241
|
+
nk_size_t stride_elements = stride_bytes / sizeof(nk_u16_t);
|
|
1242
|
+
int aligned = (stride_bytes % sizeof(nk_u16_t) == 0);
|
|
1243
|
+
|
|
1244
|
+
if (count == 0)
|
|
1245
|
+
*min_value_ptr = NK_U16_MAX, *min_index_ptr = NK_SIZE_MAX, *max_value_ptr = NK_U16_MIN,
|
|
1246
|
+
*max_index_ptr = NK_SIZE_MAX;
|
|
1247
|
+
else if (!aligned)
|
|
1248
|
+
nk_reduce_minmax_u16_serial(data_ptr, count, stride_bytes, min_value_ptr, min_index_ptr, max_value_ptr,
|
|
1249
|
+
max_index_ptr);
|
|
1250
|
+
else if (stride_elements == 1)
|
|
1251
|
+
nk_reduce_minmax_u16_rvv_contiguous_(data_ptr, count, min_value_ptr, min_index_ptr, max_value_ptr,
|
|
1252
|
+
max_index_ptr);
|
|
1253
|
+
else
|
|
1254
|
+
nk_reduce_minmax_u16_rvv_strided_(data_ptr, count, stride_bytes, min_value_ptr, min_index_ptr, max_value_ptr,
|
|
1255
|
+
max_index_ptr);
|
|
1256
|
+
}
|
|
1257
|
+
|
|
1258
|
+
// Contiguous i32 moments kernel: exact signed sum via per-lane 128-bit
// (upper:lower) accumulators, plus an unsigned saturating sum of squares.
NK_INTERNAL void nk_reduce_moments_i32_rvv_contiguous_( //
    nk_i32_t const *data_ptr, nk_size_t count,          //
    nk_i64_t *sum_ptr, nk_u64_t *sumsq_ptr) {
    nk_size_t vlmax = __riscv_vsetvlmax_e64m2();
    // 128-bit per-lane accumulator for sum: (sum_upper, sum_lower)
    vuint64m2_t sum_lower_u64m2 = __riscv_vmv_v_x_u64m2(0, vlmax);
    vint64m2_t sum_upper_i64m2 = __riscv_vmv_v_x_i64m2(0, vlmax);
    vuint64m2_t sumsq_u64m2 = __riscv_vmv_v_x_u64m2(0, vlmax);

    for (nk_size_t vector_length; count > 0; count -= vector_length, data_ptr += vector_length) {
        vector_length = __riscv_vsetvl_e32m1(count);
        vint32m1_t data_i32m1 = __riscv_vle32_v_i32m1(data_ptr, vector_length);

        // Widen i32 → i64
        vint64m2_t data_i64m2 = __riscv_vsext_vf2_i64m2(data_i32m1, vector_length);
        vuint64m2_t data_u64m2 = __riscv_vreinterpret_v_i64m2_u64m2(data_i64m2);

        // 128-bit accumulation: wrapping add on lower half
        // `_tu` (tail-undisturbed) keeps inactive tail lanes intact on the
        // final short iteration so the horizontal reduction stays correct.
        vuint64m2_t sum_before_u64m2 = sum_lower_u64m2;
        sum_lower_u64m2 = __riscv_vadd_vv_u64m2_tu(sum_lower_u64m2, sum_lower_u64m2, data_u64m2, vector_length);

        // Carry: new < old means unsigned overflow occurred
        vbool32_t carry_b32 = __riscv_vmsltu_vv_u64m2_b32(sum_lower_u64m2, sum_before_u64m2, vector_length);
        vint64m2_t carry_i64m2 = __riscv_vmerge_vxm_i64m2(__riscv_vmv_v_x_i64m2(0, vector_length), 1, carry_b32,
                                                          vector_length);
        sum_upper_i64m2 = __riscv_vadd_vv_i64m2_tu(sum_upper_i64m2, sum_upper_i64m2, carry_i64m2, vector_length);

        // Sign extension: -1 for negative, 0 for non-negative
        // (propagates the addend's sign into the upper 64 bits)
        vint64m2_t sign_ext_i64m2 = __riscv_vsra_vx_i64m2(data_i64m2, 63, vector_length);
        sum_upper_i64m2 = __riscv_vadd_vv_i64m2_tu(sum_upper_i64m2, sum_upper_i64m2, sign_ext_i64m2, vector_length);

        // Sumsq: i32 × i32 → i64 (widening multiply, result ≤ 2^62), saturating accumulation
        vint64m2_t squares_i64m2 = __riscv_vwmul_vv_i64m2(data_i32m1, data_i32m1, vector_length);
        sumsq_u64m2 = __riscv_vsaddu_vv_u64m2_tu(sumsq_u64m2, sumsq_u64m2,
                                                 __riscv_vreinterpret_v_i64m2_u64m2(squares_i64m2), vector_length);
    }

    *sum_ptr = nk_reduce_128bit_sum_i64m2_rvv_(sum_lower_u64m2, sum_upper_i64m2, vlmax);
    *sumsq_ptr = nk_reduce_vsaddu_u64m2_rvv_(sumsq_u64m2, vlmax);
}
|
|
1298
|
+
|
|
1299
|
+
// Strided i32 moments kernel: identical accumulation scheme to the contiguous
// variant, but elements are gathered with strided loads (`vlse32`).
NK_INTERNAL void nk_reduce_moments_i32_rvv_strided_(                   //
    nk_i32_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
    nk_i64_t *sum_ptr, nk_u64_t *sumsq_ptr) {
    nk_size_t vlmax = __riscv_vsetvlmax_e64m2();
    // 128-bit per-lane accumulator for sum: (sum_upper, sum_lower)
    vuint64m2_t sum_lower_u64m2 = __riscv_vmv_v_x_u64m2(0, vlmax);
    vint64m2_t sum_upper_i64m2 = __riscv_vmv_v_x_i64m2(0, vlmax);
    vuint64m2_t sumsq_u64m2 = __riscv_vmv_v_x_u64m2(0, vlmax);
    // Byte pointer: advancing by `vector_length * stride_bytes` per step.
    unsigned char const *ptr = (unsigned char const *)data_ptr;

    for (nk_size_t vector_length; count > 0; count -= vector_length, ptr += vector_length * stride_bytes) {
        vector_length = __riscv_vsetvl_e32m1(count);
        vint32m1_t data_i32m1 = __riscv_vlse32_v_i32m1((nk_i32_t const *)ptr, (nk_ssize_t)stride_bytes, vector_length);

        // Widen i32 → i64
        vint64m2_t data_i64m2 = __riscv_vsext_vf2_i64m2(data_i32m1, vector_length);
        vuint64m2_t data_u64m2 = __riscv_vreinterpret_v_i64m2_u64m2(data_i64m2);

        // 128-bit accumulation: wrapping add on lower half
        vuint64m2_t sum_before_u64m2 = sum_lower_u64m2;
        sum_lower_u64m2 = __riscv_vadd_vv_u64m2_tu(sum_lower_u64m2, sum_lower_u64m2, data_u64m2, vector_length);

        // Carry: new < old means unsigned overflow occurred
        vbool32_t carry_b32 = __riscv_vmsltu_vv_u64m2_b32(sum_lower_u64m2, sum_before_u64m2, vector_length);
        vint64m2_t carry_i64m2 = __riscv_vmerge_vxm_i64m2(__riscv_vmv_v_x_i64m2(0, vector_length), 1, carry_b32,
                                                          vector_length);
        sum_upper_i64m2 = __riscv_vadd_vv_i64m2_tu(sum_upper_i64m2, sum_upper_i64m2, carry_i64m2, vector_length);

        // Sign extension: -1 for negative, 0 for non-negative
        vint64m2_t sign_ext_i64m2 = __riscv_vsra_vx_i64m2(data_i64m2, 63, vector_length);
        sum_upper_i64m2 = __riscv_vadd_vv_i64m2_tu(sum_upper_i64m2, sum_upper_i64m2, sign_ext_i64m2, vector_length);

        // Sumsq: i32 × i32 → i64 (widening multiply, result ≤ 2^62), saturating accumulation
        vint64m2_t squares_i64m2 = __riscv_vwmul_vv_i64m2(data_i32m1, data_i32m1, vector_length);
        sumsq_u64m2 = __riscv_vsaddu_vv_u64m2_tu(sumsq_u64m2, sumsq_u64m2,
                                                 __riscv_vreinterpret_v_i64m2_u64m2(squares_i64m2), vector_length);
    }

    *sum_ptr = nk_reduce_128bit_sum_i64m2_rvv_(sum_lower_u64m2, sum_upper_i64m2, vlmax);
    *sumsq_ptr = nk_reduce_vsaddu_u64m2_rvv_(sumsq_u64m2, vlmax);
}
|
|
1340
|
+
|
|
1341
|
+
NK_PUBLIC void nk_reduce_moments_i32_rvv( //
|
|
1342
|
+
nk_i32_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
|
|
1343
|
+
nk_i64_t *sum_ptr, nk_u64_t *sumsq_ptr) {
|
|
1344
|
+
nk_size_t stride_elements = stride_bytes / sizeof(nk_i32_t);
|
|
1345
|
+
int aligned = (stride_bytes % sizeof(nk_i32_t) == 0);
|
|
1346
|
+
|
|
1347
|
+
if (count == 0) *sum_ptr = 0, *sumsq_ptr = 0;
|
|
1348
|
+
else if (!aligned) { nk_reduce_moments_i32_serial(data_ptr, count, stride_bytes, sum_ptr, sumsq_ptr); }
|
|
1349
|
+
else if (stride_elements == 1) { nk_reduce_moments_i32_rvv_contiguous_(data_ptr, count, sum_ptr, sumsq_ptr); }
|
|
1350
|
+
else { nk_reduce_moments_i32_rvv_strided_(data_ptr, count, stride_bytes, sum_ptr, sumsq_ptr); }
|
|
1351
|
+
}
|
|
1352
|
+
|
|
1353
|
+
// Contiguous i32 argmin/argmax kernel: each lane tracks its own running
// min/max and the (first) global element index where it was seen; a final
// horizontal pass reduces lanes to a single value and the smallest index.
NK_INTERNAL void nk_reduce_minmax_i32_rvv_contiguous_( //
    nk_i32_t const *data_ptr, nk_size_t count,         //
    nk_i32_t *min_value_ptr, nk_size_t *min_index_ptr, //
    nk_i32_t *max_value_ptr, nk_size_t *max_index_ptr) {
    nk_size_t vlmax = __riscv_vsetvlmax_e32m1();
    vint32m1_t min_i32m1 = __riscv_vmv_v_x_i32m1(NK_I32_MAX, vlmax);
    vint32m1_t max_i32m1 = __riscv_vmv_v_x_i32m1(NK_I32_MIN, vlmax);
    // Indices are kept at 64 bits (m2 for a 32-bit element m1 group) so
    // counts above 2^32 are representable.
    vuint64m2_t min_indices_u64m2 = __riscv_vmv_v_x_u64m2(0, vlmax);
    vuint64m2_t max_indices_u64m2 = __riscv_vmv_v_x_u64m2(0, vlmax);

    nk_size_t offset = 0;
    for (nk_size_t vector_length; count > 0;
         count -= vector_length, offset += vector_length, data_ptr += vector_length) {
        vector_length = __riscv_vsetvl_e32m1(count);
        vint32m1_t data_i32m1 = __riscv_vle32_v_i32m1(data_ptr, vector_length);
        // Global element index of each active lane: vid + chunk offset.
        vuint64m2_t pos_u64m2 = __riscv_vadd_vx_u64m2(__riscv_vid_v_u64m2(vector_length), (nk_u64_t)offset,
                                                      vector_length);

        // Strict `<` keeps the FIRST occurrence's index on ties.
        vbool32_t less_b32 = __riscv_vmslt_vv_i32m1_b32(data_i32m1, min_i32m1, vector_length);
        min_i32m1 = __riscv_vmerge_vvm_i32m1_tu(min_i32m1, min_i32m1, data_i32m1, less_b32, vector_length);
        min_indices_u64m2 = __riscv_vmerge_vvm_u64m2_tu(min_indices_u64m2, min_indices_u64m2, pos_u64m2, less_b32,
                                                        vector_length);

        vbool32_t greater_b32 = __riscv_vmslt_vv_i32m1_b32(max_i32m1, data_i32m1, vector_length);
        max_i32m1 = __riscv_vmerge_vvm_i32m1_tu(max_i32m1, max_i32m1, data_i32m1, greater_b32, vector_length);
        max_indices_u64m2 = __riscv_vmerge_vvm_u64m2_tu(max_indices_u64m2, max_indices_u64m2, pos_u64m2, greater_b32,
                                                        vector_length);
    }

    // Horizontal reduction for min
    vint32m1_t init_max_i32m1 = __riscv_vmv_v_x_i32m1(NK_I32_MAX, 1);
    nk_i32_t min_val = __riscv_vmv_x_s_i32m1_i32(__riscv_vredmin_vs_i32m1_i32m1(min_i32m1, init_max_i32m1, vlmax));
    vbool32_t min_match_b32 = __riscv_vmseq_vx_i32m1_b32(min_i32m1, min_val, vlmax);
    // Lanes not holding the winning value get a U64_MAX sentinel index, so the
    // unsigned index min-reduction below ignores them.
    vuint64m2_t sentinel_u64m2 = __riscv_vmv_v_x_u64m2(NK_U64_MAX, vlmax);
    vuint64m2_t min_cands_u64m2 = __riscv_vmerge_vvm_u64m2(sentinel_u64m2, min_indices_u64m2, min_match_b32, vlmax);
    vuint64m1_t init_umax_u64m1 = __riscv_vmv_v_x_u64m1(NK_U64_MAX, 1);
    *min_value_ptr = min_val;
    *min_index_ptr = (nk_size_t)__riscv_vmv_x_s_u64m1_u64(
        __riscv_vredminu_vs_u64m2_u64m1(min_cands_u64m2, init_umax_u64m1, vlmax));

    // Horizontal reduction for max
    vint32m1_t init_min_i32m1 = __riscv_vmv_v_x_i32m1(NK_I32_MIN, 1);
    nk_i32_t max_val = __riscv_vmv_x_s_i32m1_i32(__riscv_vredmax_vs_i32m1_i32m1(max_i32m1, init_min_i32m1, vlmax));
    vbool32_t max_match_b32 = __riscv_vmseq_vx_i32m1_b32(max_i32m1, max_val, vlmax);
    vuint64m2_t max_cands_u64m2 = __riscv_vmerge_vvm_u64m2(sentinel_u64m2, max_indices_u64m2, max_match_b32, vlmax);
    *max_value_ptr = max_val;
    *max_index_ptr = (nk_size_t)__riscv_vmv_x_s_u64m1_u64(
        __riscv_vredminu_vs_u64m2_u64m1(max_cands_u64m2, init_umax_u64m1, vlmax));
}
|
|
1402
|
+
|
|
1403
|
+
// Strided i32 argmin/argmax kernel: same lane-wise tracking as the contiguous
// variant, but elements are gathered with strided loads (`vlse32`).
NK_INTERNAL void nk_reduce_minmax_i32_rvv_strided_(                    //
    nk_i32_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
    nk_i32_t *min_value_ptr, nk_size_t *min_index_ptr,                 //
    nk_i32_t *max_value_ptr, nk_size_t *max_index_ptr) {
    nk_size_t vlmax = __riscv_vsetvlmax_e32m1();
    vint32m1_t min_i32m1 = __riscv_vmv_v_x_i32m1(NK_I32_MAX, vlmax);
    vint32m1_t max_i32m1 = __riscv_vmv_v_x_i32m1(NK_I32_MIN, vlmax);
    vuint64m2_t min_indices_u64m2 = __riscv_vmv_v_x_u64m2(0, vlmax);
    vuint64m2_t max_indices_u64m2 = __riscv_vmv_v_x_u64m2(0, vlmax);
    // Byte pointer: advancing by `vector_length * stride_bytes` per step.
    unsigned char const *ptr = (unsigned char const *)data_ptr;

    nk_size_t offset = 0;
    for (nk_size_t vector_length; count > 0;
         count -= vector_length, offset += vector_length, ptr += vector_length * stride_bytes) {
        vector_length = __riscv_vsetvl_e32m1(count);
        vint32m1_t data_i32m1 = __riscv_vlse32_v_i32m1((nk_i32_t const *)ptr, (nk_ssize_t)stride_bytes, vector_length);
        // Global element index of each active lane: vid + chunk offset.
        vuint64m2_t pos_u64m2 = __riscv_vadd_vx_u64m2(__riscv_vid_v_u64m2(vector_length), (nk_u64_t)offset,
                                                      vector_length);

        // Strict `<` keeps the FIRST occurrence's index on ties.
        vbool32_t less_b32 = __riscv_vmslt_vv_i32m1_b32(data_i32m1, min_i32m1, vector_length);
        min_i32m1 = __riscv_vmerge_vvm_i32m1_tu(min_i32m1, min_i32m1, data_i32m1, less_b32, vector_length);
        min_indices_u64m2 = __riscv_vmerge_vvm_u64m2_tu(min_indices_u64m2, min_indices_u64m2, pos_u64m2, less_b32,
                                                        vector_length);

        vbool32_t greater_b32 = __riscv_vmslt_vv_i32m1_b32(max_i32m1, data_i32m1, vector_length);
        max_i32m1 = __riscv_vmerge_vvm_i32m1_tu(max_i32m1, max_i32m1, data_i32m1, greater_b32, vector_length);
        max_indices_u64m2 = __riscv_vmerge_vvm_u64m2_tu(max_indices_u64m2, max_indices_u64m2, pos_u64m2, greater_b32,
                                                        vector_length);
    }

    // Horizontal reduction for min
    vint32m1_t init_max_i32m1 = __riscv_vmv_v_x_i32m1(NK_I32_MAX, 1);
    nk_i32_t min_val = __riscv_vmv_x_s_i32m1_i32(__riscv_vredmin_vs_i32m1_i32m1(min_i32m1, init_max_i32m1, vlmax));
    vbool32_t min_match_b32 = __riscv_vmseq_vx_i32m1_b32(min_i32m1, min_val, vlmax);
    // Non-winning lanes get a U64_MAX sentinel so the index min-reduction
    // ignores them.
    vuint64m2_t sentinel_u64m2 = __riscv_vmv_v_x_u64m2(NK_U64_MAX, vlmax);
    vuint64m2_t min_cands_u64m2 = __riscv_vmerge_vvm_u64m2(sentinel_u64m2, min_indices_u64m2, min_match_b32, vlmax);
    vuint64m1_t init_umax_u64m1 = __riscv_vmv_v_x_u64m1(NK_U64_MAX, 1);
    *min_value_ptr = min_val;
    *min_index_ptr = (nk_size_t)__riscv_vmv_x_s_u64m1_u64(
        __riscv_vredminu_vs_u64m2_u64m1(min_cands_u64m2, init_umax_u64m1, vlmax));

    // Horizontal reduction for max
    vint32m1_t init_min_i32m1 = __riscv_vmv_v_x_i32m1(NK_I32_MIN, 1);
    nk_i32_t max_val = __riscv_vmv_x_s_i32m1_i32(__riscv_vredmax_vs_i32m1_i32m1(max_i32m1, init_min_i32m1, vlmax));
    vbool32_t max_match_b32 = __riscv_vmseq_vx_i32m1_b32(max_i32m1, max_val, vlmax);
    vuint64m2_t max_cands_u64m2 = __riscv_vmerge_vvm_u64m2(sentinel_u64m2, max_indices_u64m2, max_match_b32, vlmax);
    *max_value_ptr = max_val;
    *max_index_ptr = (nk_size_t)__riscv_vmv_x_s_u64m1_u64(
        __riscv_vredminu_vs_u64m2_u64m1(max_cands_u64m2, init_umax_u64m1, vlmax));
}
|
|
1453
|
+
|
|
1454
|
+
NK_PUBLIC void nk_reduce_minmax_i32_rvv( //
|
|
1455
|
+
nk_i32_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
|
|
1456
|
+
nk_i32_t *min_value_ptr, nk_size_t *min_index_ptr, //
|
|
1457
|
+
nk_i32_t *max_value_ptr, nk_size_t *max_index_ptr) {
|
|
1458
|
+
nk_size_t stride_elements = stride_bytes / sizeof(nk_i32_t);
|
|
1459
|
+
int aligned = (stride_bytes % sizeof(nk_i32_t) == 0);
|
|
1460
|
+
|
|
1461
|
+
if (count == 0)
|
|
1462
|
+
*min_value_ptr = NK_I32_MAX, *min_index_ptr = NK_SIZE_MAX, *max_value_ptr = NK_I32_MIN,
|
|
1463
|
+
*max_index_ptr = NK_SIZE_MAX;
|
|
1464
|
+
else if (!aligned)
|
|
1465
|
+
nk_reduce_minmax_i32_serial(data_ptr, count, stride_bytes, min_value_ptr, min_index_ptr, max_value_ptr,
|
|
1466
|
+
max_index_ptr);
|
|
1467
|
+
else if (stride_elements == 1)
|
|
1468
|
+
nk_reduce_minmax_i32_rvv_contiguous_(data_ptr, count, min_value_ptr, min_index_ptr, max_value_ptr,
|
|
1469
|
+
max_index_ptr);
|
|
1470
|
+
else
|
|
1471
|
+
nk_reduce_minmax_i32_rvv_strided_(data_ptr, count, stride_bytes, min_value_ptr, min_index_ptr, max_value_ptr,
|
|
1472
|
+
max_index_ptr);
|
|
1473
|
+
}
|
|
1474
|
+
|
|
1475
|
+
// Contiguous u32 moments kernel: per-lane 64-bit saturating accumulators for
// both the sum and the sum of squares (widening u32×u32 never overflows u64).
NK_INTERNAL void nk_reduce_moments_u32_rvv_contiguous_( //
    nk_u32_t const *data_ptr, nk_size_t count,          //
    nk_u64_t *sum_ptr, nk_u64_t *sumsq_ptr) {
    nk_size_t vlmax = __riscv_vsetvlmax_e64m2();
    vuint64m2_t sum_u64m2 = __riscv_vmv_v_x_u64m2(0, vlmax);
    vuint64m2_t sumsq_u64m2 = __riscv_vmv_v_x_u64m2(0, vlmax);

    for (nk_size_t vector_length; count > 0; count -= vector_length, data_ptr += vector_length) {
        vector_length = __riscv_vsetvl_e32m1(count);
        vuint32m1_t data_u32m1 = __riscv_vle32_v_u32m1(data_ptr, vector_length);

        // Widen u32 → u64 for saturating sum
        // `_tu` keeps tail lanes untouched on the final short iteration.
        vuint64m2_t data_u64m2 = __riscv_vzext_vf2_u64m2(data_u32m1, vector_length);
        sum_u64m2 = __riscv_vsaddu_vv_u64m2_tu(sum_u64m2, sum_u64m2, data_u64m2, vector_length);

        // Sumsq: u32 × u32 → u64 (widening multiply, no overflow), saturating accumulation
        vuint64m2_t squares_u64m2 = __riscv_vwmulu_vv_u64m2(data_u32m1, data_u32m1, vector_length);
        sumsq_u64m2 = __riscv_vsaddu_vv_u64m2_tu(sumsq_u64m2, sumsq_u64m2, squares_u64m2, vector_length);
    }

    *sum_ptr = nk_reduce_vsaddu_u64m2_rvv_(sum_u64m2, vlmax);
    *sumsq_ptr = nk_reduce_vsaddu_u64m2_rvv_(sumsq_u64m2, vlmax);
}
|
|
1498
|
+
|
|
1499
|
+
// Strided u32 moments kernel: same saturating accumulation as the contiguous
// variant, with elements gathered through strided loads (`vlse32`).
NK_INTERNAL void nk_reduce_moments_u32_rvv_strided_(                   //
    nk_u32_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
    nk_u64_t *sum_ptr, nk_u64_t *sumsq_ptr) {
    nk_size_t vlmax = __riscv_vsetvlmax_e64m2();
    vuint64m2_t sum_u64m2 = __riscv_vmv_v_x_u64m2(0, vlmax);
    vuint64m2_t sumsq_u64m2 = __riscv_vmv_v_x_u64m2(0, vlmax);
    // Byte pointer: advancing by `vector_length * stride_bytes` per step.
    unsigned char const *ptr = (unsigned char const *)data_ptr;

    for (nk_size_t vector_length; count > 0; count -= vector_length, ptr += vector_length * stride_bytes) {
        vector_length = __riscv_vsetvl_e32m1(count);
        vuint32m1_t data_u32m1 = __riscv_vlse32_v_u32m1((nk_u32_t const *)ptr, (nk_ssize_t)stride_bytes, vector_length);

        // Widen u32 → u64 for saturating sum
        vuint64m2_t data_u64m2 = __riscv_vzext_vf2_u64m2(data_u32m1, vector_length);
        sum_u64m2 = __riscv_vsaddu_vv_u64m2_tu(sum_u64m2, sum_u64m2, data_u64m2, vector_length);

        // Sumsq: u32 × u32 → u64 (widening multiply, no overflow), saturating accumulation
        vuint64m2_t squares_u64m2 = __riscv_vwmulu_vv_u64m2(data_u32m1, data_u32m1, vector_length);
        sumsq_u64m2 = __riscv_vsaddu_vv_u64m2_tu(sumsq_u64m2, sumsq_u64m2, squares_u64m2, vector_length);
    }

    *sum_ptr = nk_reduce_vsaddu_u64m2_rvv_(sum_u64m2, vlmax);
    *sumsq_ptr = nk_reduce_vsaddu_u64m2_rvv_(sumsq_u64m2, vlmax);
}
|
|
1523
|
+
|
|
1524
|
+
NK_PUBLIC void nk_reduce_moments_u32_rvv( //
|
|
1525
|
+
nk_u32_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
|
|
1526
|
+
nk_u64_t *sum_ptr, nk_u64_t *sumsq_ptr) {
|
|
1527
|
+
nk_size_t stride_elements = stride_bytes / sizeof(nk_u32_t);
|
|
1528
|
+
int aligned = (stride_bytes % sizeof(nk_u32_t) == 0);
|
|
1529
|
+
|
|
1530
|
+
if (count == 0) *sum_ptr = 0, *sumsq_ptr = 0;
|
|
1531
|
+
else if (!aligned) { nk_reduce_moments_u32_serial(data_ptr, count, stride_bytes, sum_ptr, sumsq_ptr); }
|
|
1532
|
+
else if (stride_elements == 1) { nk_reduce_moments_u32_rvv_contiguous_(data_ptr, count, sum_ptr, sumsq_ptr); }
|
|
1533
|
+
else { nk_reduce_moments_u32_rvv_strided_(data_ptr, count, stride_bytes, sum_ptr, sumsq_ptr); }
|
|
1534
|
+
}
|
|
1535
|
+
|
|
1536
|
+
// Contiguous u32 argmin/argmax kernel: lane-wise running min/max with 64-bit
// first-occurrence indices, reduced horizontally at the end.
NK_INTERNAL void nk_reduce_minmax_u32_rvv_contiguous_( //
    nk_u32_t const *data_ptr, nk_size_t count,         //
    nk_u32_t *min_value_ptr, nk_size_t *min_index_ptr, //
    nk_u32_t *max_value_ptr, nk_size_t *max_index_ptr) {
    nk_size_t vlmax = __riscv_vsetvlmax_e32m1();
    vuint32m1_t min_u32m1 = __riscv_vmv_v_x_u32m1(NK_U32_MAX, vlmax);
    vuint32m1_t max_u32m1 = __riscv_vmv_v_x_u32m1(NK_U32_MIN, vlmax);
    vuint64m2_t min_indices_u64m2 = __riscv_vmv_v_x_u64m2(0, vlmax);
    vuint64m2_t max_indices_u64m2 = __riscv_vmv_v_x_u64m2(0, vlmax);

    nk_size_t offset = 0;
    for (nk_size_t vector_length; count > 0;
         count -= vector_length, offset += vector_length, data_ptr += vector_length) {
        vector_length = __riscv_vsetvl_e32m1(count);
        vuint32m1_t data_u32m1 = __riscv_vle32_v_u32m1(data_ptr, vector_length);
        // Global element index of each active lane: vid + chunk offset.
        vuint64m2_t pos_u64m2 = __riscv_vadd_vx_u64m2(__riscv_vid_v_u64m2(vector_length), (nk_u64_t)offset,
                                                      vector_length);

        // Strict `<` keeps the FIRST occurrence's index on ties.
        vbool32_t less_b32 = __riscv_vmsltu_vv_u32m1_b32(data_u32m1, min_u32m1, vector_length);
        min_u32m1 = __riscv_vmerge_vvm_u32m1_tu(min_u32m1, min_u32m1, data_u32m1, less_b32, vector_length);
        min_indices_u64m2 = __riscv_vmerge_vvm_u64m2_tu(min_indices_u64m2, min_indices_u64m2, pos_u64m2, less_b32,
                                                        vector_length);

        vbool32_t greater_b32 = __riscv_vmsltu_vv_u32m1_b32(max_u32m1, data_u32m1, vector_length);
        max_u32m1 = __riscv_vmerge_vvm_u32m1_tu(max_u32m1, max_u32m1, data_u32m1, greater_b32, vector_length);
        max_indices_u64m2 = __riscv_vmerge_vvm_u64m2_tu(max_indices_u64m2, max_indices_u64m2, pos_u64m2, greater_b32,
                                                        vector_length);
    }

    // Horizontal reduction for min
    vuint32m1_t init_max_u32m1 = __riscv_vmv_v_x_u32m1(NK_U32_MAX, 1);
    nk_u32_t min_val = __riscv_vmv_x_s_u32m1_u32(__riscv_vredminu_vs_u32m1_u32m1(min_u32m1, init_max_u32m1, vlmax));
    vbool32_t min_match_b32 = __riscv_vmseq_vx_u32m1_b32(min_u32m1, min_val, vlmax);
    // Non-winning lanes get a U64_MAX sentinel so the index min-reduction
    // ignores them.
    vuint64m2_t sentinel_u64m2 = __riscv_vmv_v_x_u64m2(NK_U64_MAX, vlmax);
    vuint64m2_t min_cands_u64m2 = __riscv_vmerge_vvm_u64m2(sentinel_u64m2, min_indices_u64m2, min_match_b32, vlmax);
    vuint64m1_t init_umax_u64m1 = __riscv_vmv_v_x_u64m1(NK_U64_MAX, 1);
    *min_value_ptr = min_val;
    *min_index_ptr = (nk_size_t)__riscv_vmv_x_s_u64m1_u64(
        __riscv_vredminu_vs_u64m2_u64m1(min_cands_u64m2, init_umax_u64m1, vlmax));

    // Horizontal reduction for max
    vuint32m1_t init_min_u32m1 = __riscv_vmv_v_x_u32m1(NK_U32_MIN, 1);
    nk_u32_t max_val = __riscv_vmv_x_s_u32m1_u32(__riscv_vredmaxu_vs_u32m1_u32m1(max_u32m1, init_min_u32m1, vlmax));
    vbool32_t max_match_b32 = __riscv_vmseq_vx_u32m1_b32(max_u32m1, max_val, vlmax);
    vuint64m2_t max_cands_u64m2 = __riscv_vmerge_vvm_u64m2(sentinel_u64m2, max_indices_u64m2, max_match_b32, vlmax);
    *max_value_ptr = max_val;
    *max_index_ptr = (nk_size_t)__riscv_vmv_x_s_u64m1_u64(
        __riscv_vredminu_vs_u64m2_u64m1(max_cands_u64m2, init_umax_u64m1, vlmax));
}
|
|
1585
|
+
|
|
1586
|
+
// Strided u32 argmin/argmax kernel: same lane-wise tracking as the contiguous
// variant, with elements gathered through strided loads (`vlse32`).
NK_INTERNAL void nk_reduce_minmax_u32_rvv_strided_(                    //
    nk_u32_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
    nk_u32_t *min_value_ptr, nk_size_t *min_index_ptr,                 //
    nk_u32_t *max_value_ptr, nk_size_t *max_index_ptr) {
    nk_size_t vlmax = __riscv_vsetvlmax_e32m1();
    vuint32m1_t min_u32m1 = __riscv_vmv_v_x_u32m1(NK_U32_MAX, vlmax);
    vuint32m1_t max_u32m1 = __riscv_vmv_v_x_u32m1(NK_U32_MIN, vlmax);
    vuint64m2_t min_indices_u64m2 = __riscv_vmv_v_x_u64m2(0, vlmax);
    vuint64m2_t max_indices_u64m2 = __riscv_vmv_v_x_u64m2(0, vlmax);
    // Byte pointer: advancing by `vector_length * stride_bytes` per step.
    unsigned char const *ptr = (unsigned char const *)data_ptr;

    nk_size_t offset = 0;
    for (nk_size_t vector_length; count > 0;
         count -= vector_length, offset += vector_length, ptr += vector_length * stride_bytes) {
        vector_length = __riscv_vsetvl_e32m1(count);
        vuint32m1_t data_u32m1 = __riscv_vlse32_v_u32m1((nk_u32_t const *)ptr, (nk_ssize_t)stride_bytes, vector_length);
        // Global element index of each active lane: vid + chunk offset.
        vuint64m2_t pos_u64m2 = __riscv_vadd_vx_u64m2(__riscv_vid_v_u64m2(vector_length), (nk_u64_t)offset,
                                                      vector_length);

        // Strict `<` keeps the FIRST occurrence's index on ties.
        vbool32_t less_b32 = __riscv_vmsltu_vv_u32m1_b32(data_u32m1, min_u32m1, vector_length);
        min_u32m1 = __riscv_vmerge_vvm_u32m1_tu(min_u32m1, min_u32m1, data_u32m1, less_b32, vector_length);
        min_indices_u64m2 = __riscv_vmerge_vvm_u64m2_tu(min_indices_u64m2, min_indices_u64m2, pos_u64m2, less_b32,
                                                        vector_length);

        vbool32_t greater_b32 = __riscv_vmsltu_vv_u32m1_b32(max_u32m1, data_u32m1, vector_length);
        max_u32m1 = __riscv_vmerge_vvm_u32m1_tu(max_u32m1, max_u32m1, data_u32m1, greater_b32, vector_length);
        max_indices_u64m2 = __riscv_vmerge_vvm_u64m2_tu(max_indices_u64m2, max_indices_u64m2, pos_u64m2, greater_b32,
                                                        vector_length);
    }

    // Horizontal reduction for min
    vuint32m1_t init_max_u32m1 = __riscv_vmv_v_x_u32m1(NK_U32_MAX, 1);
    nk_u32_t min_val = __riscv_vmv_x_s_u32m1_u32(__riscv_vredminu_vs_u32m1_u32m1(min_u32m1, init_max_u32m1, vlmax));
    vbool32_t min_match_b32 = __riscv_vmseq_vx_u32m1_b32(min_u32m1, min_val, vlmax);
    // Non-winning lanes get a U64_MAX sentinel so the index min-reduction
    // ignores them.
    vuint64m2_t sentinel_u64m2 = __riscv_vmv_v_x_u64m2(NK_U64_MAX, vlmax);
    vuint64m2_t min_cands_u64m2 = __riscv_vmerge_vvm_u64m2(sentinel_u64m2, min_indices_u64m2, min_match_b32, vlmax);
    vuint64m1_t init_umax_u64m1 = __riscv_vmv_v_x_u64m1(NK_U64_MAX, 1);
    *min_value_ptr = min_val;
    *min_index_ptr = (nk_size_t)__riscv_vmv_x_s_u64m1_u64(
        __riscv_vredminu_vs_u64m2_u64m1(min_cands_u64m2, init_umax_u64m1, vlmax));

    // Horizontal reduction for max
    vuint32m1_t init_min_u32m1 = __riscv_vmv_v_x_u32m1(NK_U32_MIN, 1);
    nk_u32_t max_val = __riscv_vmv_x_s_u32m1_u32(__riscv_vredmaxu_vs_u32m1_u32m1(max_u32m1, init_min_u32m1, vlmax));
    vbool32_t max_match_b32 = __riscv_vmseq_vx_u32m1_b32(max_u32m1, max_val, vlmax);
    vuint64m2_t max_cands_u64m2 = __riscv_vmerge_vvm_u64m2(sentinel_u64m2, max_indices_u64m2, max_match_b32, vlmax);
    *max_value_ptr = max_val;
    *max_index_ptr = (nk_size_t)__riscv_vmv_x_s_u64m1_u64(
        __riscv_vredminu_vs_u64m2_u64m1(max_cands_u64m2, init_umax_u64m1, vlmax));
}
|
|
1636
|
+
|
|
1637
|
+
NK_PUBLIC void nk_reduce_minmax_u32_rvv( //
|
|
1638
|
+
nk_u32_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
|
|
1639
|
+
nk_u32_t *min_value_ptr, nk_size_t *min_index_ptr, //
|
|
1640
|
+
nk_u32_t *max_value_ptr, nk_size_t *max_index_ptr) {
|
|
1641
|
+
nk_size_t stride_elements = stride_bytes / sizeof(nk_u32_t);
|
|
1642
|
+
int aligned = (stride_bytes % sizeof(nk_u32_t) == 0);
|
|
1643
|
+
|
|
1644
|
+
if (count == 0)
|
|
1645
|
+
*min_value_ptr = NK_U32_MAX, *min_index_ptr = NK_SIZE_MAX, *max_value_ptr = NK_U32_MIN,
|
|
1646
|
+
*max_index_ptr = NK_SIZE_MAX;
|
|
1647
|
+
else if (!aligned)
|
|
1648
|
+
nk_reduce_minmax_u32_serial(data_ptr, count, stride_bytes, min_value_ptr, min_index_ptr, max_value_ptr,
|
|
1649
|
+
max_index_ptr);
|
|
1650
|
+
else if (stride_elements == 1)
|
|
1651
|
+
nk_reduce_minmax_u32_rvv_contiguous_(data_ptr, count, min_value_ptr, min_index_ptr, max_value_ptr,
|
|
1652
|
+
max_index_ptr);
|
|
1653
|
+
else
|
|
1654
|
+
nk_reduce_minmax_u32_rvv_strided_(data_ptr, count, stride_bytes, min_value_ptr, min_index_ptr, max_value_ptr,
|
|
1655
|
+
max_index_ptr);
|
|
1656
|
+
}
|
|
1657
|
+
|
|
1658
|
+
// Contiguous i64 moments kernel: exact signed sum via per-lane 128-bit
// (upper:lower) accumulators; sum of squares computed as |x|² with explicit
// overflow detection, saturating to U64_MAX.
NK_INTERNAL void nk_reduce_moments_i64_rvv_contiguous_( //
    nk_i64_t const *data_ptr, nk_size_t count,          //
    nk_i64_t *sum_ptr, nk_u64_t *sumsq_ptr) {
    nk_size_t vlmax = __riscv_vsetvlmax_e64m1();
    // 128-bit per-lane accumulator for sum: (sum_upper, sum_lower)
    vuint64m1_t sum_lower_u64m1 = __riscv_vmv_v_x_u64m1(0, vlmax);
    vint64m1_t sum_upper_i64m1 = __riscv_vmv_v_x_i64m1(0, vlmax);
    vuint64m1_t sumsq_u64m1 = __riscv_vmv_v_x_u64m1(0, vlmax);

    for (nk_size_t vector_length; count > 0; count -= vector_length, data_ptr += vector_length) {
        vector_length = __riscv_vsetvl_e64m1(count);
        vint64m1_t data_i64m1 = __riscv_vle64_v_i64m1(data_ptr, vector_length);

        // 128-bit sum accumulation: wrapping add on lower half
        // `_tu` keeps tail lanes untouched on the final short iteration.
        vuint64m1_t data_u64m1 = __riscv_vreinterpret_v_i64m1_u64m1(data_i64m1);
        vuint64m1_t sum_before_u64m1 = sum_lower_u64m1;
        sum_lower_u64m1 = __riscv_vadd_vv_u64m1_tu(sum_lower_u64m1, sum_lower_u64m1, data_u64m1, vector_length);

        // Carry: new < old means unsigned overflow occurred
        vbool64_t carry_b64 = __riscv_vmsltu_vv_u64m1_b64(sum_lower_u64m1, sum_before_u64m1, vector_length);
        vint64m1_t carry_i64m1 = __riscv_vmerge_vxm_i64m1(__riscv_vmv_v_x_i64m1(0, vector_length), 1, carry_b64,
                                                          vector_length);
        sum_upper_i64m1 = __riscv_vadd_vv_i64m1_tu(sum_upper_i64m1, sum_upper_i64m1, carry_i64m1, vector_length);

        // Sign extension: -1 for negative, 0 for non-negative
        // (propagates the addend's sign into the upper 64 bits)
        vint64m1_t sign_ext_i64m1 = __riscv_vsra_vx_i64m1(data_i64m1, 63, vector_length);
        sum_upper_i64m1 = __riscv_vadd_vv_i64m1_tu(sum_upper_i64m1, sum_upper_i64m1, sign_ext_i64m1, vector_length);

        // Sumsq: abs(val)² with overflow detection
        // vmax(x, -x) yields |x|; reinterpreted as unsigned this is also
        // correct for INT64_MIN (both operands wrap to the same bit pattern).
        vint64m1_t negated_i64m1 = __riscv_vneg_v_i64m1(data_i64m1, vector_length);
        vint64m1_t absolute_i64m1 = __riscv_vmax_vv_i64m1(data_i64m1, negated_i64m1, vector_length);
        vuint64m1_t absolute_u64m1 = __riscv_vreinterpret_v_i64m1_u64m1(absolute_i64m1);
        // Full 128-bit product via low/high halves; a non-zero high half means
        // the square exceeds 64 bits, so saturate the lane to U64_MAX.
        vuint64m1_t product_low_u64m1 = __riscv_vmul_vv_u64m1(absolute_u64m1, absolute_u64m1, vector_length);
        vuint64m1_t product_high_u64m1 = __riscv_vmulhu_vv_u64m1(absolute_u64m1, absolute_u64m1, vector_length);
        vbool64_t overflow_b64 = __riscv_vmsne_vx_u64m1_b64(product_high_u64m1, 0, vector_length);
        vuint64m1_t squares_u64m1 = __riscv_vmerge_vxm_u64m1(product_low_u64m1, NK_U64_MAX, overflow_b64,
                                                             vector_length);
        sumsq_u64m1 = __riscv_vsaddu_vv_u64m1_tu(sumsq_u64m1, sumsq_u64m1, squares_u64m1, vector_length);
    }

    *sum_ptr = nk_reduce_128bit_sum_i64m1_rvv_(sum_lower_u64m1, sum_upper_i64m1, vlmax);
    *sumsq_ptr = nk_reduce_vsaddu_u64m1_rvv_(sumsq_u64m1, vlmax);
}
|
|
1701
|
+
|
|
1702
|
+
// Strided i64 moments kernel: identical accumulation scheme to the contiguous
// variant, but elements are gathered with strided loads (`vlse64`).
NK_INTERNAL void nk_reduce_moments_i64_rvv_strided_(                   //
    nk_i64_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
    nk_i64_t *sum_ptr, nk_u64_t *sumsq_ptr) {
    nk_size_t vlmax = __riscv_vsetvlmax_e64m1();
    // 128-bit per-lane accumulator for sum: (sum_upper, sum_lower)
    vuint64m1_t sum_lower_u64m1 = __riscv_vmv_v_x_u64m1(0, vlmax);
    vint64m1_t sum_upper_i64m1 = __riscv_vmv_v_x_i64m1(0, vlmax);
    vuint64m1_t sumsq_u64m1 = __riscv_vmv_v_x_u64m1(0, vlmax);
    // Byte pointer: advancing by `vector_length * stride_bytes` per step.
    unsigned char const *ptr = (unsigned char const *)data_ptr;

    for (nk_size_t vector_length; count > 0; count -= vector_length, ptr += vector_length * stride_bytes) {
        vector_length = __riscv_vsetvl_e64m1(count);
        vint64m1_t data_i64m1 = __riscv_vlse64_v_i64m1((nk_i64_t const *)ptr, (nk_ssize_t)stride_bytes, vector_length);

        // 128-bit sum accumulation: wrapping add on lower half
        vuint64m1_t data_u64m1 = __riscv_vreinterpret_v_i64m1_u64m1(data_i64m1);
        vuint64m1_t sum_before_u64m1 = sum_lower_u64m1;
        sum_lower_u64m1 = __riscv_vadd_vv_u64m1_tu(sum_lower_u64m1, sum_lower_u64m1, data_u64m1, vector_length);

        // Carry: new < old means unsigned overflow occurred
        vbool64_t carry_b64 = __riscv_vmsltu_vv_u64m1_b64(sum_lower_u64m1, sum_before_u64m1, vector_length);
        vint64m1_t carry_i64m1 = __riscv_vmerge_vxm_i64m1(__riscv_vmv_v_x_i64m1(0, vector_length), 1, carry_b64,
                                                          vector_length);
        sum_upper_i64m1 = __riscv_vadd_vv_i64m1_tu(sum_upper_i64m1, sum_upper_i64m1, carry_i64m1, vector_length);

        // Sign extension: -1 for negative, 0 for non-negative
        vint64m1_t sign_ext_i64m1 = __riscv_vsra_vx_i64m1(data_i64m1, 63, vector_length);
        sum_upper_i64m1 = __riscv_vadd_vv_i64m1_tu(sum_upper_i64m1, sum_upper_i64m1, sign_ext_i64m1, vector_length);

        // Sumsq: abs(val)² with overflow detection
        vint64m1_t negated_i64m1 = __riscv_vneg_v_i64m1(data_i64m1, vector_length);
        vint64m1_t absolute_i64m1 = __riscv_vmax_vv_i64m1(data_i64m1, negated_i64m1, vector_length);
        vuint64m1_t absolute_u64m1 = __riscv_vreinterpret_v_i64m1_u64m1(absolute_i64m1);
        // Non-zero high half of the 128-bit product means the square exceeds
        // 64 bits: saturate the lane to U64_MAX.
        vuint64m1_t product_low_u64m1 = __riscv_vmul_vv_u64m1(absolute_u64m1, absolute_u64m1, vector_length);
        vuint64m1_t product_high_u64m1 = __riscv_vmulhu_vv_u64m1(absolute_u64m1, absolute_u64m1, vector_length);
        vbool64_t overflow_b64 = __riscv_vmsne_vx_u64m1_b64(product_high_u64m1, 0, vector_length);
        vuint64m1_t squares_u64m1 = __riscv_vmerge_vxm_u64m1(product_low_u64m1, NK_U64_MAX, overflow_b64,
                                                             vector_length);
        sumsq_u64m1 = __riscv_vsaddu_vv_u64m1_tu(sumsq_u64m1, sumsq_u64m1, squares_u64m1, vector_length);
    }

    *sum_ptr = nk_reduce_128bit_sum_i64m1_rvv_(sum_lower_u64m1, sum_upper_i64m1, vlmax);
    *sumsq_ptr = nk_reduce_vsaddu_u64m1_rvv_(sumsq_u64m1, vlmax);
}
|
|
1746
|
+
|
|
1747
|
+
NK_PUBLIC void nk_reduce_moments_i64_rvv( //
|
|
1748
|
+
nk_i64_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
|
|
1749
|
+
nk_i64_t *sum_ptr, nk_u64_t *sumsq_ptr) {
|
|
1750
|
+
nk_size_t stride_elements = stride_bytes / sizeof(nk_i64_t);
|
|
1751
|
+
int aligned = (stride_bytes % sizeof(nk_i64_t) == 0);
|
|
1752
|
+
|
|
1753
|
+
if (count == 0) *sum_ptr = 0, *sumsq_ptr = 0;
|
|
1754
|
+
else if (!aligned) { nk_reduce_moments_i64_serial(data_ptr, count, stride_bytes, sum_ptr, sumsq_ptr); }
|
|
1755
|
+
else if (stride_elements == 1) { nk_reduce_moments_i64_rvv_contiguous_(data_ptr, count, sum_ptr, sumsq_ptr); }
|
|
1756
|
+
else { nk_reduce_moments_i64_rvv_strided_(data_ptr, count, stride_bytes, sum_ptr, sumsq_ptr); }
|
|
1757
|
+
}
|
|
1758
|
+
|
|
1759
|
+
/*  Contiguous i64 min/max with the index of the FIRST occurrence of each extreme.
 *  Per-lane running extremes are kept alongside the global element index at which
 *  each lane's extreme was last improved; strict comparisons (<) ensure earlier
 *  occurrences are preferred within a lane, and the final vredminu over candidate
 *  indices picks the smallest (earliest) index across lanes.
 *  Caller guarantees count > 0 (see the _rvv dispatcher).  */
NK_INTERNAL void nk_reduce_minmax_i64_rvv_contiguous_( //
    nk_i64_t const *data_ptr, nk_size_t count, //
    nk_i64_t *min_value_ptr, nk_size_t *min_index_ptr, //
    nk_i64_t *max_value_ptr, nk_size_t *max_index_ptr) {
    nk_size_t vlmax = __riscv_vsetvlmax_e64m1();
    // Identity elements so untouched lanes never win the reductions.
    vint64m1_t min_i64m1 = __riscv_vmv_v_x_i64m1(NK_I64_MAX, vlmax);
    vint64m1_t max_i64m1 = __riscv_vmv_v_x_i64m1(NK_I64_MIN, vlmax);
    vuint64m1_t min_indices_u64m1 = __riscv_vmv_v_x_u64m1(0, vlmax);
    vuint64m1_t max_indices_u64m1 = __riscv_vmv_v_x_u64m1(0, vlmax);

    nk_size_t offset = 0;
    for (nk_size_t vector_length; count > 0;
         count -= vector_length, offset += vector_length, data_ptr += vector_length) {
        vector_length = __riscv_vsetvl_e64m1(count);
        vint64m1_t data_i64m1 = __riscv_vle64_v_i64m1(data_ptr, vector_length);
        // Global element index of each lane: vid (0..vl-1) plus the running offset.
        vuint64m1_t pos_u64m1 = __riscv_vadd_vx_u64m1(__riscv_vid_v_u64m1(vector_length), (nk_u64_t)offset,
                                                      vector_length);

        // Strict < keeps the earlier index when values tie.
        vbool64_t less_b64 = __riscv_vmslt_vv_i64m1_b64(data_i64m1, min_i64m1, vector_length);
        min_i64m1 = __riscv_vmerge_vvm_i64m1_tu(min_i64m1, min_i64m1, data_i64m1, less_b64, vector_length);
        min_indices_u64m1 = __riscv_vmerge_vvm_u64m1_tu(min_indices_u64m1, min_indices_u64m1, pos_u64m1, less_b64,
                                                        vector_length);

        vbool64_t greater_b64 = __riscv_vmslt_vv_i64m1_b64(max_i64m1, data_i64m1, vector_length);
        max_i64m1 = __riscv_vmerge_vvm_i64m1_tu(max_i64m1, max_i64m1, data_i64m1, greater_b64, vector_length);
        max_indices_u64m1 = __riscv_vmerge_vvm_u64m1_tu(max_indices_u64m1, max_indices_u64m1, pos_u64m1, greater_b64,
                                                        vector_length);
    }

    // Horizontal reduction for min
    vint64m1_t init_max_i64m1 = __riscv_vmv_v_x_i64m1(NK_I64_MAX, 1);
    nk_i64_t min_val = __riscv_vmv_x_s_i64m1_i64(__riscv_vredmin_vs_i64m1_i64m1(min_i64m1, init_max_i64m1, vlmax));
    // Lanes holding the winning value contribute their index; others get the
    // NK_U64_MAX sentinel so vredminu ignores them.
    vbool64_t min_match_b64 = __riscv_vmseq_vx_i64m1_b64(min_i64m1, min_val, vlmax);
    vuint64m1_t sentinel_u64m1 = __riscv_vmv_v_x_u64m1(NK_U64_MAX, vlmax);
    vuint64m1_t min_cands_u64m1 = __riscv_vmerge_vvm_u64m1(sentinel_u64m1, min_indices_u64m1, min_match_b64, vlmax);
    vuint64m1_t init_umax_u64m1 = __riscv_vmv_v_x_u64m1(NK_U64_MAX, 1);
    *min_value_ptr = min_val;
    *min_index_ptr = (nk_size_t)__riscv_vmv_x_s_u64m1_u64(
        __riscv_vredminu_vs_u64m1_u64m1(min_cands_u64m1, init_umax_u64m1, vlmax));

    // Horizontal reduction for max
    vint64m1_t init_min_i64m1 = __riscv_vmv_v_x_i64m1(NK_I64_MIN, 1);
    nk_i64_t max_val = __riscv_vmv_x_s_i64m1_i64(__riscv_vredmax_vs_i64m1_i64m1(max_i64m1, init_min_i64m1, vlmax));
    vbool64_t max_match_b64 = __riscv_vmseq_vx_i64m1_b64(max_i64m1, max_val, vlmax);
    vuint64m1_t max_cands_u64m1 = __riscv_vmerge_vvm_u64m1(sentinel_u64m1, max_indices_u64m1, max_match_b64, vlmax);
    *max_value_ptr = max_val;
    *max_index_ptr = (nk_size_t)__riscv_vmv_x_s_u64m1_u64(
        __riscv_vredminu_vs_u64m1_u64m1(max_cands_u64m1, init_umax_u64m1, vlmax));
}
|
|
1808
|
+
|
|
1809
|
+
/*  Strided i64 min/max with first-occurrence indices; identical algorithm to the
 *  contiguous variant but reads lanes `stride_bytes` apart via vlse64.
 *  Caller guarantees count > 0 and an element-aligned stride (see dispatcher).  */
NK_INTERNAL void nk_reduce_minmax_i64_rvv_strided_( //
    nk_i64_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
    nk_i64_t *min_value_ptr, nk_size_t *min_index_ptr, //
    nk_i64_t *max_value_ptr, nk_size_t *max_index_ptr) {
    nk_size_t vlmax = __riscv_vsetvlmax_e64m1();
    // Identity elements so untouched lanes never win the reductions.
    vint64m1_t min_i64m1 = __riscv_vmv_v_x_i64m1(NK_I64_MAX, vlmax);
    vint64m1_t max_i64m1 = __riscv_vmv_v_x_i64m1(NK_I64_MIN, vlmax);
    vuint64m1_t min_indices_u64m1 = __riscv_vmv_v_x_u64m1(0, vlmax);
    vuint64m1_t max_indices_u64m1 = __riscv_vmv_v_x_u64m1(0, vlmax);
    // Byte pointer so the stride advances without pointer-arithmetic scaling.
    unsigned char const *ptr = (unsigned char const *)data_ptr;

    nk_size_t offset = 0;
    for (nk_size_t vector_length; count > 0;
         count -= vector_length, offset += vector_length, ptr += vector_length * stride_bytes) {
        vector_length = __riscv_vsetvl_e64m1(count);
        vint64m1_t data_i64m1 = __riscv_vlse64_v_i64m1((nk_i64_t const *)ptr, (nk_ssize_t)stride_bytes, vector_length);
        // Global element index of each lane.
        vuint64m1_t pos_u64m1 = __riscv_vadd_vx_u64m1(__riscv_vid_v_u64m1(vector_length), (nk_u64_t)offset,
                                                      vector_length);

        // Strict < keeps the earlier index when values tie.
        vbool64_t less_b64 = __riscv_vmslt_vv_i64m1_b64(data_i64m1, min_i64m1, vector_length);
        min_i64m1 = __riscv_vmerge_vvm_i64m1_tu(min_i64m1, min_i64m1, data_i64m1, less_b64, vector_length);
        min_indices_u64m1 = __riscv_vmerge_vvm_u64m1_tu(min_indices_u64m1, min_indices_u64m1, pos_u64m1, less_b64,
                                                        vector_length);

        vbool64_t greater_b64 = __riscv_vmslt_vv_i64m1_b64(max_i64m1, data_i64m1, vector_length);
        max_i64m1 = __riscv_vmerge_vvm_i64m1_tu(max_i64m1, max_i64m1, data_i64m1, greater_b64, vector_length);
        max_indices_u64m1 = __riscv_vmerge_vvm_u64m1_tu(max_indices_u64m1, max_indices_u64m1, pos_u64m1, greater_b64,
                                                        vector_length);
    }

    // Horizontal reduction for min
    vint64m1_t init_max_i64m1 = __riscv_vmv_v_x_i64m1(NK_I64_MAX, 1);
    nk_i64_t min_val = __riscv_vmv_x_s_i64m1_i64(__riscv_vredmin_vs_i64m1_i64m1(min_i64m1, init_max_i64m1, vlmax));
    // Winning lanes contribute their index; others get the NK_U64_MAX sentinel.
    vbool64_t min_match_b64 = __riscv_vmseq_vx_i64m1_b64(min_i64m1, min_val, vlmax);
    vuint64m1_t sentinel_u64m1 = __riscv_vmv_v_x_u64m1(NK_U64_MAX, vlmax);
    vuint64m1_t min_cands_u64m1 = __riscv_vmerge_vvm_u64m1(sentinel_u64m1, min_indices_u64m1, min_match_b64, vlmax);
    vuint64m1_t init_umax_u64m1 = __riscv_vmv_v_x_u64m1(NK_U64_MAX, 1);
    *min_value_ptr = min_val;
    *min_index_ptr = (nk_size_t)__riscv_vmv_x_s_u64m1_u64(
        __riscv_vredminu_vs_u64m1_u64m1(min_cands_u64m1, init_umax_u64m1, vlmax));

    // Horizontal reduction for max
    vint64m1_t init_min_i64m1 = __riscv_vmv_v_x_i64m1(NK_I64_MIN, 1);
    nk_i64_t max_val = __riscv_vmv_x_s_i64m1_i64(__riscv_vredmax_vs_i64m1_i64m1(max_i64m1, init_min_i64m1, vlmax));
    vbool64_t max_match_b64 = __riscv_vmseq_vx_i64m1_b64(max_i64m1, max_val, vlmax);
    vuint64m1_t max_cands_u64m1 = __riscv_vmerge_vvm_u64m1(sentinel_u64m1, max_indices_u64m1, max_match_b64, vlmax);
    *max_value_ptr = max_val;
    *max_index_ptr = (nk_size_t)__riscv_vmv_x_s_u64m1_u64(
        __riscv_vredminu_vs_u64m1_u64m1(max_cands_u64m1, init_umax_u64m1, vlmax));
}
|
|
1859
|
+
|
|
1860
|
+
NK_PUBLIC void nk_reduce_minmax_i64_rvv( //
|
|
1861
|
+
nk_i64_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
|
|
1862
|
+
nk_i64_t *min_value_ptr, nk_size_t *min_index_ptr, //
|
|
1863
|
+
nk_i64_t *max_value_ptr, nk_size_t *max_index_ptr) {
|
|
1864
|
+
nk_size_t stride_elements = stride_bytes / sizeof(nk_i64_t);
|
|
1865
|
+
int aligned = (stride_bytes % sizeof(nk_i64_t) == 0);
|
|
1866
|
+
|
|
1867
|
+
if (count == 0)
|
|
1868
|
+
*min_value_ptr = NK_I64_MAX, *min_index_ptr = NK_SIZE_MAX, *max_value_ptr = NK_I64_MIN,
|
|
1869
|
+
*max_index_ptr = NK_SIZE_MAX;
|
|
1870
|
+
else if (!aligned)
|
|
1871
|
+
nk_reduce_minmax_i64_serial(data_ptr, count, stride_bytes, min_value_ptr, min_index_ptr, max_value_ptr,
|
|
1872
|
+
max_index_ptr);
|
|
1873
|
+
else if (stride_elements == 1)
|
|
1874
|
+
nk_reduce_minmax_i64_rvv_contiguous_(data_ptr, count, min_value_ptr, min_index_ptr, max_value_ptr,
|
|
1875
|
+
max_index_ptr);
|
|
1876
|
+
else
|
|
1877
|
+
nk_reduce_minmax_i64_rvv_strided_(data_ptr, count, stride_bytes, min_value_ptr, min_index_ptr, max_value_ptr,
|
|
1878
|
+
max_index_ptr);
|
|
1879
|
+
}
|
|
1880
|
+
|
|
1881
|
+
/*  Contiguous u64 sum + sum-of-squares. Both accumulators saturate at NK_U64_MAX
 *  (vsaddu), and a lane whose square overflows 64 bits is clamped to NK_U64_MAX
 *  before accumulation. Caller guarantees count > 0 (see dispatcher).  */
NK_INTERNAL void nk_reduce_moments_u64_rvv_contiguous_( //
    nk_u64_t const *data_ptr, nk_size_t count, //
    nk_u64_t *sum_ptr, nk_u64_t *sumsq_ptr) {
    nk_size_t vlmax = __riscv_vsetvlmax_e64m1();
    vuint64m1_t sum_u64m1 = __riscv_vmv_v_x_u64m1(0, vlmax);
    vuint64m1_t sumsq_u64m1 = __riscv_vmv_v_x_u64m1(0, vlmax);

    for (nk_size_t vector_length; count > 0; count -= vector_length, data_ptr += vector_length) {
        vector_length = __riscv_vsetvl_e64m1(count);
        vuint64m1_t data_u64m1 = __riscv_vle64_v_u64m1(data_ptr, vector_length);

        // Saturating unsigned sum
        sum_u64m1 = __riscv_vsaddu_vv_u64m1_tu(sum_u64m1, sum_u64m1, data_u64m1, vector_length);

        // Sumsq: u64 × u64 with overflow detection via vmul + vmulhu
        // (non-zero high half ⇒ square exceeds 64 bits ⇒ clamp that lane)
        vuint64m1_t product_low_u64m1 = __riscv_vmul_vv_u64m1(data_u64m1, data_u64m1, vector_length);
        vuint64m1_t product_high_u64m1 = __riscv_vmulhu_vv_u64m1(data_u64m1, data_u64m1, vector_length);
        vbool64_t overflow_b64 = __riscv_vmsne_vx_u64m1_b64(product_high_u64m1, 0, vector_length);
        vuint64m1_t squares_u64m1 = __riscv_vmerge_vxm_u64m1(product_low_u64m1, NK_U64_MAX, overflow_b64,
                                                             vector_length);
        sumsq_u64m1 = __riscv_vsaddu_vv_u64m1_tu(sumsq_u64m1, sumsq_u64m1, squares_u64m1, vector_length);
    }

    // Saturating horizontal folds across lanes.
    *sum_ptr = nk_reduce_vsaddu_u64m1_rvv_(sum_u64m1, vlmax);
    *sumsq_ptr = nk_reduce_vsaddu_u64m1_rvv_(sumsq_u64m1, vlmax);
}
|
|
1907
|
+
|
|
1908
|
+
/*  Strided u64 sum + sum-of-squares; same saturating algorithm as the contiguous
 *  variant, reading lanes `stride_bytes` apart via vlse64.
 *  Caller guarantees count > 0 and an element-aligned stride (see dispatcher).  */
NK_INTERNAL void nk_reduce_moments_u64_rvv_strided_( //
    nk_u64_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
    nk_u64_t *sum_ptr, nk_u64_t *sumsq_ptr) {
    nk_size_t vlmax = __riscv_vsetvlmax_e64m1();
    vuint64m1_t sum_u64m1 = __riscv_vmv_v_x_u64m1(0, vlmax);
    vuint64m1_t sumsq_u64m1 = __riscv_vmv_v_x_u64m1(0, vlmax);
    // Byte pointer so the stride advances without pointer-arithmetic scaling.
    unsigned char const *ptr = (unsigned char const *)data_ptr;

    for (nk_size_t vector_length; count > 0; count -= vector_length, ptr += vector_length * stride_bytes) {
        vector_length = __riscv_vsetvl_e64m1(count);
        vuint64m1_t data_u64m1 = __riscv_vlse64_v_u64m1((nk_u64_t const *)ptr, (nk_ssize_t)stride_bytes, vector_length);

        // Saturating unsigned sum
        sum_u64m1 = __riscv_vsaddu_vv_u64m1_tu(sum_u64m1, sum_u64m1, data_u64m1, vector_length);

        // Sumsq: u64 × u64 with overflow detection via vmul + vmulhu
        // (non-zero high half ⇒ square exceeds 64 bits ⇒ clamp that lane)
        vuint64m1_t product_low_u64m1 = __riscv_vmul_vv_u64m1(data_u64m1, data_u64m1, vector_length);
        vuint64m1_t product_high_u64m1 = __riscv_vmulhu_vv_u64m1(data_u64m1, data_u64m1, vector_length);
        vbool64_t overflow_b64 = __riscv_vmsne_vx_u64m1_b64(product_high_u64m1, 0, vector_length);
        vuint64m1_t squares_u64m1 = __riscv_vmerge_vxm_u64m1(product_low_u64m1, NK_U64_MAX, overflow_b64,
                                                             vector_length);
        sumsq_u64m1 = __riscv_vsaddu_vv_u64m1_tu(sumsq_u64m1, sumsq_u64m1, squares_u64m1, vector_length);
    }

    // Saturating horizontal folds across lanes.
    *sum_ptr = nk_reduce_vsaddu_u64m1_rvv_(sum_u64m1, vlmax);
    *sumsq_ptr = nk_reduce_vsaddu_u64m1_rvv_(sumsq_u64m1, vlmax);
}
|
|
1935
|
+
|
|
1936
|
+
/*  Public entry: u64 sum + sum-of-squares reduction, RVV backend.
 *  Routes to the contiguous kernel (stride == one element), the strided kernel
 *  (element-aligned stride), or the serial fallback (misaligned stride).
 *  An empty input yields (0, 0).  */
NK_PUBLIC void nk_reduce_moments_u64_rvv( //
    nk_u64_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
    nk_u64_t *sum_ptr, nk_u64_t *sumsq_ptr) {
    if (count == 0) {
        *sum_ptr = 0;
        *sumsq_ptr = 0;
        return;
    }
    if (stride_bytes % sizeof(nk_u64_t) != 0) {
        // Misaligned stride: strided vector loads need element-granular spacing.
        nk_reduce_moments_u64_serial(data_ptr, count, stride_bytes, sum_ptr, sumsq_ptr);
        return;
    }
    if (stride_bytes == sizeof(nk_u64_t)) {
        nk_reduce_moments_u64_rvv_contiguous_(data_ptr, count, sum_ptr, sumsq_ptr);
        return;
    }
    nk_reduce_moments_u64_rvv_strided_(data_ptr, count, stride_bytes, sum_ptr, sumsq_ptr);
}
|
|
1947
|
+
|
|
1948
|
+
/*  Contiguous u64 min/max with first-occurrence indices; unsigned twin of the
 *  i64 variant (vmsltu / vredminu / vredmaxu instead of the signed forms).
 *  Caller guarantees count > 0 (see dispatcher).  */
NK_INTERNAL void nk_reduce_minmax_u64_rvv_contiguous_( //
    nk_u64_t const *data_ptr, nk_size_t count, //
    nk_u64_t *min_value_ptr, nk_size_t *min_index_ptr, //
    nk_u64_t *max_value_ptr, nk_size_t *max_index_ptr) {
    nk_size_t vlmax = __riscv_vsetvlmax_e64m1();
    // Identity elements so untouched lanes never win the reductions.
    vuint64m1_t min_u64m1 = __riscv_vmv_v_x_u64m1(NK_U64_MAX, vlmax);
    vuint64m1_t max_u64m1 = __riscv_vmv_v_x_u64m1(NK_U64_MIN, vlmax);
    vuint64m1_t min_indices_u64m1 = __riscv_vmv_v_x_u64m1(0, vlmax);
    vuint64m1_t max_indices_u64m1 = __riscv_vmv_v_x_u64m1(0, vlmax);

    nk_size_t offset = 0;
    for (nk_size_t vector_length; count > 0;
         count -= vector_length, offset += vector_length, data_ptr += vector_length) {
        vector_length = __riscv_vsetvl_e64m1(count);
        vuint64m1_t data_u64m1 = __riscv_vle64_v_u64m1(data_ptr, vector_length);
        // Global element index of each lane.
        vuint64m1_t pos_u64m1 = __riscv_vadd_vx_u64m1(__riscv_vid_v_u64m1(vector_length), (nk_u64_t)offset,
                                                      vector_length);

        // Strict < keeps the earlier index when values tie.
        vbool64_t less_b64 = __riscv_vmsltu_vv_u64m1_b64(data_u64m1, min_u64m1, vector_length);
        min_u64m1 = __riscv_vmerge_vvm_u64m1_tu(min_u64m1, min_u64m1, data_u64m1, less_b64, vector_length);
        min_indices_u64m1 = __riscv_vmerge_vvm_u64m1_tu(min_indices_u64m1, min_indices_u64m1, pos_u64m1, less_b64,
                                                        vector_length);

        vbool64_t greater_b64 = __riscv_vmsltu_vv_u64m1_b64(max_u64m1, data_u64m1, vector_length);
        max_u64m1 = __riscv_vmerge_vvm_u64m1_tu(max_u64m1, max_u64m1, data_u64m1, greater_b64, vector_length);
        max_indices_u64m1 = __riscv_vmerge_vvm_u64m1_tu(max_indices_u64m1, max_indices_u64m1, pos_u64m1, greater_b64,
                                                        vector_length);
    }

    // Horizontal reduction for min
    vuint64m1_t init_max_u64m1 = __riscv_vmv_v_x_u64m1(NK_U64_MAX, 1);
    nk_u64_t min_val = __riscv_vmv_x_s_u64m1_u64(__riscv_vredminu_vs_u64m1_u64m1(min_u64m1, init_max_u64m1, vlmax));
    // Winning lanes contribute their index; others get the NK_U64_MAX sentinel.
    vbool64_t min_match_b64 = __riscv_vmseq_vx_u64m1_b64(min_u64m1, min_val, vlmax);
    vuint64m1_t sentinel_u64m1 = __riscv_vmv_v_x_u64m1(NK_U64_MAX, vlmax);
    vuint64m1_t min_cands_u64m1 = __riscv_vmerge_vvm_u64m1(sentinel_u64m1, min_indices_u64m1, min_match_b64, vlmax);
    vuint64m1_t init_umax_u64m1 = __riscv_vmv_v_x_u64m1(NK_U64_MAX, 1);
    *min_value_ptr = min_val;
    *min_index_ptr = (nk_size_t)__riscv_vmv_x_s_u64m1_u64(
        __riscv_vredminu_vs_u64m1_u64m1(min_cands_u64m1, init_umax_u64m1, vlmax));

    // Horizontal reduction for max
    vuint64m1_t init_min_u64m1 = __riscv_vmv_v_x_u64m1(NK_U64_MIN, 1);
    nk_u64_t max_val = __riscv_vmv_x_s_u64m1_u64(__riscv_vredmaxu_vs_u64m1_u64m1(max_u64m1, init_min_u64m1, vlmax));
    vbool64_t max_match_b64 = __riscv_vmseq_vx_u64m1_b64(max_u64m1, max_val, vlmax);
    vuint64m1_t max_cands_u64m1 = __riscv_vmerge_vvm_u64m1(sentinel_u64m1, max_indices_u64m1, max_match_b64, vlmax);
    *max_value_ptr = max_val;
    *max_index_ptr = (nk_size_t)__riscv_vmv_x_s_u64m1_u64(
        __riscv_vredminu_vs_u64m1_u64m1(max_cands_u64m1, init_umax_u64m1, vlmax));
}
|
|
1997
|
+
|
|
1998
|
+
/*  Strided u64 min/max with first-occurrence indices; same algorithm as the
 *  contiguous variant, reading lanes `stride_bytes` apart via vlse64.
 *  Caller guarantees count > 0 and an element-aligned stride (see dispatcher).  */
NK_INTERNAL void nk_reduce_minmax_u64_rvv_strided_( //
    nk_u64_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
    nk_u64_t *min_value_ptr, nk_size_t *min_index_ptr, //
    nk_u64_t *max_value_ptr, nk_size_t *max_index_ptr) {
    nk_size_t vlmax = __riscv_vsetvlmax_e64m1();
    // Identity elements so untouched lanes never win the reductions.
    vuint64m1_t min_u64m1 = __riscv_vmv_v_x_u64m1(NK_U64_MAX, vlmax);
    vuint64m1_t max_u64m1 = __riscv_vmv_v_x_u64m1(NK_U64_MIN, vlmax);
    vuint64m1_t min_indices_u64m1 = __riscv_vmv_v_x_u64m1(0, vlmax);
    vuint64m1_t max_indices_u64m1 = __riscv_vmv_v_x_u64m1(0, vlmax);
    // Byte pointer so the stride advances without pointer-arithmetic scaling.
    unsigned char const *ptr = (unsigned char const *)data_ptr;

    nk_size_t offset = 0;
    for (nk_size_t vector_length; count > 0;
         count -= vector_length, offset += vector_length, ptr += vector_length * stride_bytes) {
        vector_length = __riscv_vsetvl_e64m1(count);
        vuint64m1_t data_u64m1 = __riscv_vlse64_v_u64m1((nk_u64_t const *)ptr, (nk_ssize_t)stride_bytes, vector_length);
        // Global element index of each lane.
        vuint64m1_t pos_u64m1 = __riscv_vadd_vx_u64m1(__riscv_vid_v_u64m1(vector_length), (nk_u64_t)offset,
                                                      vector_length);

        // Strict < keeps the earlier index when values tie.
        vbool64_t less_b64 = __riscv_vmsltu_vv_u64m1_b64(data_u64m1, min_u64m1, vector_length);
        min_u64m1 = __riscv_vmerge_vvm_u64m1_tu(min_u64m1, min_u64m1, data_u64m1, less_b64, vector_length);
        min_indices_u64m1 = __riscv_vmerge_vvm_u64m1_tu(min_indices_u64m1, min_indices_u64m1, pos_u64m1, less_b64,
                                                        vector_length);

        vbool64_t greater_b64 = __riscv_vmsltu_vv_u64m1_b64(max_u64m1, data_u64m1, vector_length);
        max_u64m1 = __riscv_vmerge_vvm_u64m1_tu(max_u64m1, max_u64m1, data_u64m1, greater_b64, vector_length);
        max_indices_u64m1 = __riscv_vmerge_vvm_u64m1_tu(max_indices_u64m1, max_indices_u64m1, pos_u64m1, greater_b64,
                                                        vector_length);
    }

    // Horizontal reduction for min
    vuint64m1_t init_max_u64m1 = __riscv_vmv_v_x_u64m1(NK_U64_MAX, 1);
    nk_u64_t min_val = __riscv_vmv_x_s_u64m1_u64(__riscv_vredminu_vs_u64m1_u64m1(min_u64m1, init_max_u64m1, vlmax));
    // Winning lanes contribute their index; others get the NK_U64_MAX sentinel.
    vbool64_t min_match_b64 = __riscv_vmseq_vx_u64m1_b64(min_u64m1, min_val, vlmax);
    vuint64m1_t sentinel_u64m1 = __riscv_vmv_v_x_u64m1(NK_U64_MAX, vlmax);
    vuint64m1_t min_cands_u64m1 = __riscv_vmerge_vvm_u64m1(sentinel_u64m1, min_indices_u64m1, min_match_b64, vlmax);
    vuint64m1_t init_umax_u64m1 = __riscv_vmv_v_x_u64m1(NK_U64_MAX, 1);
    *min_value_ptr = min_val;
    *min_index_ptr = (nk_size_t)__riscv_vmv_x_s_u64m1_u64(
        __riscv_vredminu_vs_u64m1_u64m1(min_cands_u64m1, init_umax_u64m1, vlmax));

    // Horizontal reduction for max
    vuint64m1_t init_min_u64m1 = __riscv_vmv_v_x_u64m1(NK_U64_MIN, 1);
    nk_u64_t max_val = __riscv_vmv_x_s_u64m1_u64(__riscv_vredmaxu_vs_u64m1_u64m1(max_u64m1, init_min_u64m1, vlmax));
    vbool64_t max_match_b64 = __riscv_vmseq_vx_u64m1_b64(max_u64m1, max_val, vlmax);
    vuint64m1_t max_cands_u64m1 = __riscv_vmerge_vvm_u64m1(sentinel_u64m1, max_indices_u64m1, max_match_b64, vlmax);
    *max_value_ptr = max_val;
    *max_index_ptr = (nk_size_t)__riscv_vmv_x_s_u64m1_u64(
        __riscv_vredminu_vs_u64m1_u64m1(max_cands_u64m1, init_umax_u64m1, vlmax));
}
|
|
2048
|
+
|
|
2049
|
+
NK_PUBLIC void nk_reduce_minmax_u64_rvv( //
|
|
2050
|
+
nk_u64_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
|
|
2051
|
+
nk_u64_t *min_value_ptr, nk_size_t *min_index_ptr, //
|
|
2052
|
+
nk_u64_t *max_value_ptr, nk_size_t *max_index_ptr) {
|
|
2053
|
+
nk_size_t stride_elements = stride_bytes / sizeof(nk_u64_t);
|
|
2054
|
+
int aligned = (stride_bytes % sizeof(nk_u64_t) == 0);
|
|
2055
|
+
|
|
2056
|
+
if (count == 0)
|
|
2057
|
+
*min_value_ptr = NK_U64_MAX, *min_index_ptr = NK_SIZE_MAX, *max_value_ptr = NK_U64_MIN,
|
|
2058
|
+
*max_index_ptr = NK_SIZE_MAX;
|
|
2059
|
+
else if (!aligned)
|
|
2060
|
+
nk_reduce_minmax_u64_serial(data_ptr, count, stride_bytes, min_value_ptr, min_index_ptr, max_value_ptr,
|
|
2061
|
+
max_index_ptr);
|
|
2062
|
+
else if (stride_elements == 1)
|
|
2063
|
+
nk_reduce_minmax_u64_rvv_contiguous_(data_ptr, count, min_value_ptr, min_index_ptr, max_value_ptr,
|
|
2064
|
+
max_index_ptr);
|
|
2065
|
+
else
|
|
2066
|
+
nk_reduce_minmax_u64_rvv_strided_(data_ptr, count, stride_bytes, min_value_ptr, min_index_ptr, max_value_ptr,
|
|
2067
|
+
max_index_ptr);
|
|
2068
|
+
}
|
|
2069
|
+
|
|
2070
|
+
/*  Contiguous bf16 sum + sum-of-squares, accumulated in f64 for accuracy and
 *  reported as f32. Loads raw u16 bit patterns, upconverts bf16 → f32 → f64,
 *  and uses a widening FMA for the squares. Caller guarantees count > 0.
 *  Note: vlmax is taken at e64m4; the loop re-sizes at e16m1, which covers the
 *  same lane count (16b×m1 and 64b×m4 have equal VLMAX ratios).  */
NK_INTERNAL void nk_reduce_moments_bf16_rvv_contiguous_( //
    nk_bf16_t const *data_ptr, nk_size_t count, //
    nk_f32_t *sum_ptr, nk_f32_t *sumsq_ptr) {
    nk_size_t vlmax = __riscv_vsetvlmax_e64m4();
    vfloat64m4_t sum_f64m4 = __riscv_vfmv_v_f_f64m4(0.0, vlmax);
    vfloat64m4_t sumsq_f64m4 = __riscv_vfmv_v_f_f64m4(0.0, vlmax);

    for (nk_size_t vector_length; count > 0; count -= vector_length, data_ptr += vector_length) {
        vector_length = __riscv_vsetvl_e16m1(count);
        vuint16m1_t data_u16m1 = __riscv_vle16_v_u16m1((uint16_t const *)data_ptr, vector_length);

        // Convert bf16 → f32 (m1 → m2)
        vfloat32m2_t data_f32m2 = nk_bf16m1_to_f32m2_rvv_(data_u16m1, vector_length);

        // Widen f32 → f64 (m2 → m4)
        vfloat64m4_t data_f64m4 = __riscv_vfwcvt_f_f_v_f64m4(data_f32m2, vector_length);
        sum_f64m4 = __riscv_vfadd_vv_f64m4_tu(sum_f64m4, sum_f64m4, data_f64m4, vector_length);

        // Sumsq via widening FMA: f32×f32 → f64
        sumsq_f64m4 = __riscv_vfwmacc_vv_f64m4_tu(sumsq_f64m4, data_f32m2, data_f32m2, vector_length);
    }

    // Horizontal reduction
    vfloat64m1_t zero_f64m1 = __riscv_vfmv_v_f_f64m1(0.0, 1);
    *sum_ptr = __riscv_vfmv_f_s_f64m1_f64(__riscv_vfredusum_vs_f64m4_f64m1(sum_f64m4, zero_f64m1, vlmax)),
    *sumsq_ptr = __riscv_vfmv_f_s_f64m1_f64(__riscv_vfredusum_vs_f64m4_f64m1(sumsq_f64m4, zero_f64m1, vlmax));
}
|
|
2097
|
+
|
|
2098
|
+
/*  Strided bf16 sum + sum-of-squares; same f64-accumulation algorithm as the
 *  contiguous variant, reading raw u16 lanes `stride_bytes` apart via vlse16.
 *  Caller guarantees count > 0 and an element-aligned stride (see dispatcher).  */
NK_INTERNAL void nk_reduce_moments_bf16_rvv_strided_( //
    nk_bf16_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
    nk_f32_t *sum_ptr, nk_f32_t *sumsq_ptr) {
    nk_size_t vlmax = __riscv_vsetvlmax_e64m4();
    vfloat64m4_t sum_f64m4 = __riscv_vfmv_v_f_f64m4(0.0, vlmax);
    vfloat64m4_t sumsq_f64m4 = __riscv_vfmv_v_f_f64m4(0.0, vlmax);
    // Byte pointer so the stride advances without pointer-arithmetic scaling.
    unsigned char const *ptr = (unsigned char const *)data_ptr;

    for (nk_size_t vector_length; count > 0; count -= vector_length, ptr += vector_length * stride_bytes) {
        vector_length = __riscv_vsetvl_e16m1(count);
        vuint16m1_t data_u16m1 = __riscv_vlse16_v_u16m1((uint16_t const *)ptr, (nk_ssize_t)stride_bytes, vector_length);

        // Convert bf16 → f32 (m1 → m2)
        vfloat32m2_t data_f32m2 = nk_bf16m1_to_f32m2_rvv_(data_u16m1, vector_length);

        // Widen f32 → f64 (m2 → m4)
        vfloat64m4_t data_f64m4 = __riscv_vfwcvt_f_f_v_f64m4(data_f32m2, vector_length);
        sum_f64m4 = __riscv_vfadd_vv_f64m4_tu(sum_f64m4, sum_f64m4, data_f64m4, vector_length);

        // Sumsq via widening FMA: f32×f32 → f64
        sumsq_f64m4 = __riscv_vfwmacc_vv_f64m4_tu(sumsq_f64m4, data_f32m2, data_f32m2, vector_length);
    }

    // Horizontal reduction
    vfloat64m1_t zero_f64m1 = __riscv_vfmv_v_f_f64m1(0.0, 1);
    *sum_ptr = __riscv_vfmv_f_s_f64m1_f64(__riscv_vfredusum_vs_f64m4_f64m1(sum_f64m4, zero_f64m1, vlmax)),
    *sumsq_ptr = __riscv_vfmv_f_s_f64m1_f64(__riscv_vfredusum_vs_f64m4_f64m1(sumsq_f64m4, zero_f64m1, vlmax));
}
|
|
2126
|
+
|
|
2127
|
+
NK_PUBLIC void nk_reduce_moments_bf16_rvv( //
|
|
2128
|
+
nk_bf16_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
|
|
2129
|
+
nk_f32_t *sum_ptr, nk_f32_t *sumsq_ptr) {
|
|
2130
|
+
nk_size_t stride_elements = stride_bytes / sizeof(nk_bf16_t);
|
|
2131
|
+
int aligned = (stride_bytes % sizeof(nk_bf16_t) == 0);
|
|
2132
|
+
|
|
2133
|
+
if (count == 0) *sum_ptr = 0.0f, *sumsq_ptr = 0.0f;
|
|
2134
|
+
else if (!aligned) nk_reduce_moments_bf16_serial(data_ptr, count, stride_bytes, sum_ptr, sumsq_ptr);
|
|
2135
|
+
else if (stride_elements == 1) nk_reduce_moments_bf16_rvv_contiguous_(data_ptr, count, sum_ptr, sumsq_ptr);
|
|
2136
|
+
else nk_reduce_moments_bf16_rvv_strided_(data_ptr, count, stride_bytes, sum_ptr, sumsq_ptr);
|
|
2137
|
+
}
|
|
2138
|
+
|
|
2139
|
+
/*  Contiguous bf16 min/max with first-occurrence indices.
 *  Values are kept as raw u16 bf16 bit patterns; every comparison is done after
 *  upconverting to f32 (vmflt), so NaN lanes never compare true and are skipped.
 *  Running extremes start at ±inf bf16 patterns; if neither ever improves (e.g.
 *  an all-NaN input), the reductions return the exact seed values and the
 *  function reports the sentinel results (NK_BF16_MAX/MIN, NK_SIZE_MAX).
 *  NOTE(review): if the data's true min is +inf (or true max is -inf), the
 *  running register stays at its ±inf seed while the vfredmin/vfredmax scalar
 *  init (NK_F32_MAX/MIN) caps the reduced value, so the later vmfeq match can
 *  be empty and vfirst returns -1 — confirm this edge against the test suite.
 *  Caller guarantees count > 0 (see dispatcher).  */
NK_INTERNAL void nk_reduce_minmax_bf16_rvv_contiguous_( //
    nk_bf16_t const *data_ptr, nk_size_t count, //
    nk_bf16_t *min_value_ptr, nk_size_t *min_index_ptr, //
    nk_bf16_t *max_value_ptr, nk_size_t *max_index_ptr) {
    nk_size_t vlmax = __riscv_vsetvlmax_e16m1();
    vuint16m1_t min_u16m1 = __riscv_vmv_v_x_u16m1(0x7F80, vlmax); // +inf in bf16
    vuint16m1_t max_u16m1 = __riscv_vmv_v_x_u16m1(0xFF80, vlmax); // -inf in bf16
    // Index vectors are u64 (m4 to match the e16m1 lane count).
    vuint64m4_t min_indices_u64m4 = __riscv_vmv_v_x_u64m4(0, vlmax);
    vuint64m4_t max_indices_u64m4 = __riscv_vmv_v_x_u64m4(0, vlmax);

    nk_size_t offset = 0;
    for (nk_size_t vector_length; count > 0;
         count -= vector_length, offset += vector_length, data_ptr += vector_length) {
        vector_length = __riscv_vsetvl_e16m1(count);
        vuint16m1_t data_u16m1 = __riscv_vle16_v_u16m1((uint16_t const *)data_ptr, vector_length);
        // Global element index of each lane.
        vuint64m4_t pos_u64m4 = __riscv_vadd_vx_u64m4(__riscv_vid_v_u64m4(vector_length), (nk_u64_t)offset,
                                                      vector_length);

        // Convert to f32 for comparison
        vfloat32m2_t data_f32m2 = nk_bf16m1_to_f32m2_rvv_(data_u16m1, vector_length);
        vfloat32m2_t min_f32m2 = nk_bf16m1_to_f32m2_rvv_(min_u16m1, vector_length);
        vfloat32m2_t max_f32m2 = nk_bf16m1_to_f32m2_rvv_(max_u16m1, vector_length);

        // Strict < keeps the earlier index when values tie; NaN lanes compare false.
        vbool16_t less_b16 = __riscv_vmflt_vv_f32m2_b16(data_f32m2, min_f32m2, vector_length);
        min_u16m1 = __riscv_vmerge_vvm_u16m1_tu(min_u16m1, min_u16m1, data_u16m1, less_b16, vector_length);
        min_indices_u64m4 = __riscv_vmerge_vvm_u64m4_tu(min_indices_u64m4, min_indices_u64m4, pos_u64m4, less_b16,
                                                        vector_length);

        vbool16_t greater_b16 = __riscv_vmflt_vv_f32m2_b16(max_f32m2, data_f32m2, vector_length);
        max_u16m1 = __riscv_vmerge_vvm_u16m1_tu(max_u16m1, max_u16m1, data_u16m1, greater_b16, vector_length);
        max_indices_u64m4 = __riscv_vmerge_vvm_u64m4_tu(max_indices_u64m4, max_indices_u64m4, pos_u64m4, greater_b16,
                                                        vector_length);
    }

    // Horizontal reduction
    vfloat32m2_t final_min_f32m2 = nk_bf16m1_to_f32m2_rvv_(min_u16m1, vlmax);
    vfloat32m1_t init_max_f32m1 = __riscv_vfmv_v_f_f32m1(NK_F32_MAX, 1);
    nk_f32_t min_val_f32 = __riscv_vfmv_f_s_f32m1_f32(
        __riscv_vfredmin_vs_f32m2_f32m1(final_min_f32m2, init_max_f32m1, vlmax));
    vfloat32m2_t final_max_f32m2 = nk_bf16m1_to_f32m2_rvv_(max_u16m1, vlmax);
    vfloat32m1_t init_min_f32m1 = __riscv_vfmv_v_f_f32m1(NK_F32_MIN, 1);
    nk_f32_t max_val_f32 = __riscv_vfmv_f_s_f32m1_f32(
        __riscv_vfredmax_vs_f32m2_f32m1(final_max_f32m2, init_min_f32m1, vlmax));
    // Both reductions pinned to their scalar inits ⇒ no lane ever improved
    // (e.g. all-NaN input): report sentinel extremes and bail out.
    if (min_val_f32 == NK_F32_MAX && max_val_f32 == NK_F32_MIN) {
        *min_value_ptr = NK_BF16_MAX, *min_index_ptr = NK_SIZE_MAX, *max_value_ptr = NK_BF16_MIN,
        *max_index_ptr = NK_SIZE_MAX;
        return;
    }

    // Lanes whose f32 value equals the reduced min contribute their index;
    // others get the NK_U64_MAX sentinel so vredminu picks the earliest match.
    vfloat32m2_t converted_f32m2 = nk_bf16m1_to_f32m2_rvv_(min_u16m1, vlmax);
    vbool16_t min_match_b16 = __riscv_vmfeq_vf_f32m2_b16(converted_f32m2, min_val_f32, vlmax);
    vuint64m4_t sentinel_u64m4 = __riscv_vmv_v_x_u64m4(NK_U64_MAX, vlmax);
    vuint64m4_t min_cands_u64m4 = __riscv_vmerge_vvm_u64m4(sentinel_u64m4, min_indices_u64m4, min_match_b16, vlmax);
    vuint64m1_t init_umax_u64m1 = __riscv_vmv_v_x_u64m1(NK_U64_MAX, 1);

    // Recover the exact bf16 bit pattern from the first matching lane
    // (slidedown by vfirst's lane number, then read lane 0).
    nk_u16_t min_raw = __riscv_vmv_x_s_u16m1_u16(
        __riscv_vslidedown_vx_u16m1(min_u16m1, (nk_size_t)__riscv_vfirst_m_b16(min_match_b16, vlmax), vlmax));
    *min_value_ptr = *(nk_bf16_t *)&min_raw;
    *min_index_ptr = (nk_size_t)__riscv_vmv_x_s_u64m1_u64(
        __riscv_vredminu_vs_u64m4_u64m1(min_cands_u64m4, init_umax_u64m1, vlmax));

    vbool16_t max_match_b16 = __riscv_vmfeq_vf_f32m2_b16(nk_bf16m1_to_f32m2_rvv_(max_u16m1, vlmax), max_val_f32, vlmax);
    vuint64m4_t max_cands_u64m4 = __riscv_vmerge_vvm_u64m4(sentinel_u64m4, max_indices_u64m4, max_match_b16, vlmax);

    nk_u16_t max_raw = __riscv_vmv_x_s_u16m1_u16(
        __riscv_vslidedown_vx_u16m1(max_u16m1, (nk_size_t)__riscv_vfirst_m_b16(max_match_b16, vlmax), vlmax));
    *max_value_ptr = *(nk_bf16_t *)&max_raw;
    *max_index_ptr = (nk_size_t)__riscv_vmv_x_s_u64m1_u64(
        __riscv_vredminu_vs_u64m4_u64m1(max_cands_u64m4, init_umax_u64m1, vlmax));
}
|
|
2209
|
+
|
|
2210
|
+
// Strided bf16 min+max reduction with argmin/argmax.
// Values are kept as raw u16 bit patterns and compared after upcasting to f32;
// element positions are tracked in parallel u64 index vectors (one index lane
// per value lane, hence the m4 register group against the m1 value group).
NK_INTERNAL void nk_reduce_minmax_bf16_rvv_strided_( //
    nk_bf16_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
    nk_bf16_t *min_value_ptr, nk_size_t *min_index_ptr, //
    nk_bf16_t *max_value_ptr, nk_size_t *max_index_ptr) {
    nk_size_t vlmax = __riscv_vsetvlmax_e16m1();
    // Per-lane running extrema, seeded with bf16 +/-infinity bit patterns so
    // any finite element wins the first comparison.
    vuint16m1_t min_u16m1 = __riscv_vmv_v_x_u16m1(0x7F80, vlmax); // +inf in bf16
    vuint16m1_t max_u16m1 = __riscv_vmv_v_x_u16m1(0xFF80, vlmax); // -inf in bf16
    vuint64m4_t min_indices_u64m4 = __riscv_vmv_v_x_u64m4(0, vlmax);
    vuint64m4_t max_indices_u64m4 = __riscv_vmv_v_x_u64m4(0, vlmax);
    unsigned char const *ptr = (unsigned char const *)data_ptr;

    nk_size_t offset = 0;
    for (nk_size_t vector_length; count > 0;
         count -= vector_length, offset += vector_length, ptr += vector_length * stride_bytes) {
        vector_length = __riscv_vsetvl_e16m1(count);
        // Strided gather: one 16-bit element every `stride_bytes`.
        vuint16m1_t data_u16m1 = __riscv_vlse16_v_u16m1((uint16_t const *)ptr, (nk_ssize_t)stride_bytes, vector_length);
        // Global element positions for this batch: running offset + lane id.
        vuint64m4_t pos_u64m4 = __riscv_vadd_vx_u64m4(__riscv_vid_v_u64m4(vector_length), (nk_u64_t)offset,
                                                      vector_length);

        // Convert to f32 for comparison
        vfloat32m2_t data_f32m2 = nk_bf16m1_to_f32m2_rvv_(data_u16m1, vector_length);
        vfloat32m2_t min_f32m2 = nk_bf16m1_to_f32m2_rvv_(min_u16m1, vector_length);
        vfloat32m2_t max_f32m2 = nk_bf16m1_to_f32m2_rvv_(max_u16m1, vector_length);

        // Ordered compares (vmflt) are false when either operand is NaN, so
        // NaN elements never become the running min/max. The `_tu`
        // (tail-undisturbed) merges keep lanes past `vector_length` intact,
        // which the final vlmax-wide reduction relies on.
        vbool16_t less_b16 = __riscv_vmflt_vv_f32m2_b16(data_f32m2, min_f32m2, vector_length);
        min_u16m1 = __riscv_vmerge_vvm_u16m1_tu(min_u16m1, min_u16m1, data_u16m1, less_b16, vector_length);
        min_indices_u64m4 = __riscv_vmerge_vvm_u64m4_tu(min_indices_u64m4, min_indices_u64m4, pos_u64m4, less_b16,
                                                        vector_length);

        vbool16_t greater_b16 = __riscv_vmflt_vv_f32m2_b16(max_f32m2, data_f32m2, vector_length);
        max_u16m1 = __riscv_vmerge_vvm_u16m1_tu(max_u16m1, max_u16m1, data_u16m1, greater_b16, vector_length);
        max_indices_u64m4 = __riscv_vmerge_vvm_u64m4_tu(max_indices_u64m4, max_indices_u64m4, pos_u64m4, greater_b16,
                                                        vector_length);
    }

    // Horizontal reduction (same as contiguous)
    vfloat32m2_t final_min_f32m2 = nk_bf16m1_to_f32m2_rvv_(min_u16m1, vlmax);
    vfloat32m1_t init_max_f32m1 = __riscv_vfmv_v_f_f32m1(NK_F32_MAX, 1);
    nk_f32_t min_val_f32 = __riscv_vfmv_f_s_f32m1_f32(
        __riscv_vfredmin_vs_f32m2_f32m1(final_min_f32m2, init_max_f32m1, vlmax));
    vfloat32m2_t final_max_f32m2 = nk_bf16m1_to_f32m2_rvv_(max_u16m1, vlmax);
    vfloat32m1_t init_min_f32m1 = __riscv_vfmv_v_f_f32m1(NK_F32_MIN, 1);
    nk_f32_t max_val_f32 = __riscv_vfmv_f_s_f32m1_f32(
        __riscv_vfredmax_vs_f32m2_f32m1(final_max_f32m2, init_min_f32m1, vlmax));
    // If both reductions still equal their initializers, no element was ever
    // accepted (e.g. every input was NaN): report sentinel values and invalid
    // indices. (NK_F32_MAX/MIN are not representable in bf16, so this cannot
    // collide with real data -- NOTE(review): confirm those macro values.)
    if (min_val_f32 == NK_F32_MAX && max_val_f32 == NK_F32_MIN) {
        *min_value_ptr = NK_BF16_MAX, *min_index_ptr = NK_SIZE_MAX, *max_value_ptr = NK_BF16_MIN,
        *max_index_ptr = NK_SIZE_MAX;
        return;
    }

    // Locate the winning lanes: the first matching lane (vfirst) supplies the
    // raw value; vredminu over masked candidate indices (non-matches forced to
    // NK_U64_MAX) resolves ties toward the smallest element position.
    vfloat32m2_t converted_f32m2 = nk_bf16m1_to_f32m2_rvv_(min_u16m1, vlmax);
    vbool16_t min_match_b16 = __riscv_vmfeq_vf_f32m2_b16(converted_f32m2, min_val_f32, vlmax);
    vuint64m4_t sentinel_u64m4 = __riscv_vmv_v_x_u64m4(NK_U64_MAX, vlmax);
    vuint64m4_t min_cands_u64m4 = __riscv_vmerge_vvm_u64m4(sentinel_u64m4, min_indices_u64m4, min_match_b16, vlmax);
    vuint64m1_t init_umax_u64m1 = __riscv_vmv_v_x_u64m1(NK_U64_MAX, 1);

    nk_u16_t min_raw = __riscv_vmv_x_s_u16m1_u16(
        __riscv_vslidedown_vx_u16m1(min_u16m1, (nk_size_t)__riscv_vfirst_m_b16(min_match_b16, vlmax), vlmax));
    // Reinterpret the raw 16 bits as bf16 (type-pun; assumes nk_bf16_t is a
    // 16-bit type with matching layout -- NOTE(review): confirm).
    *min_value_ptr = *(nk_bf16_t *)&min_raw;
    *min_index_ptr = (nk_size_t)__riscv_vmv_x_s_u64m1_u64(
        __riscv_vredminu_vs_u64m4_u64m1(min_cands_u64m4, init_umax_u64m1, vlmax));

    // Same lane-location dance for the maximum.
    vbool16_t max_match_b16 = __riscv_vmfeq_vf_f32m2_b16(nk_bf16m1_to_f32m2_rvv_(max_u16m1, vlmax), max_val_f32, vlmax);
    vuint64m4_t max_cands_u64m4 = __riscv_vmerge_vvm_u64m4(sentinel_u64m4, max_indices_u64m4, max_match_b16, vlmax);

    nk_u16_t max_raw = __riscv_vmv_x_s_u16m1_u16(
        __riscv_vslidedown_vx_u16m1(max_u16m1, (nk_size_t)__riscv_vfirst_m_b16(max_match_b16, vlmax), vlmax));
    *max_value_ptr = *(nk_bf16_t *)&max_raw;
    *max_index_ptr = (nk_size_t)__riscv_vmv_x_s_u64m1_u64(
        __riscv_vredminu_vs_u64m4_u64m1(max_cands_u64m4, init_umax_u64m1, vlmax));
}
|
|
2281
|
+
|
|
2282
|
+
NK_PUBLIC void nk_reduce_minmax_bf16_rvv( //
|
|
2283
|
+
nk_bf16_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
|
|
2284
|
+
nk_bf16_t *min_value_ptr, nk_size_t *min_index_ptr, //
|
|
2285
|
+
nk_bf16_t *max_value_ptr, nk_size_t *max_index_ptr) {
|
|
2286
|
+
nk_size_t stride_elements = stride_bytes / sizeof(nk_bf16_t);
|
|
2287
|
+
int aligned = (stride_bytes % sizeof(nk_bf16_t) == 0);
|
|
2288
|
+
|
|
2289
|
+
if (count == 0)
|
|
2290
|
+
*min_value_ptr = NK_BF16_MAX, *min_index_ptr = NK_SIZE_MAX, *max_value_ptr = NK_BF16_MIN,
|
|
2291
|
+
*max_index_ptr = NK_SIZE_MAX;
|
|
2292
|
+
else if (!aligned)
|
|
2293
|
+
nk_reduce_minmax_bf16_serial(data_ptr, count, stride_bytes, min_value_ptr, min_index_ptr, max_value_ptr,
|
|
2294
|
+
max_index_ptr);
|
|
2295
|
+
else if (stride_elements == 1)
|
|
2296
|
+
nk_reduce_minmax_bf16_rvv_contiguous_(data_ptr, count, min_value_ptr, min_index_ptr, max_value_ptr,
|
|
2297
|
+
max_index_ptr);
|
|
2298
|
+
else
|
|
2299
|
+
nk_reduce_minmax_bf16_rvv_strided_(data_ptr, count, stride_bytes, min_value_ptr, min_index_ptr, max_value_ptr,
|
|
2300
|
+
max_index_ptr);
|
|
2301
|
+
}
|
|
2302
|
+
|
|
2303
|
+
NK_INTERNAL void nk_reduce_moments_f16_rvv_contiguous_( //
|
|
2304
|
+
nk_f16_t const *data_ptr, nk_size_t count, //
|
|
2305
|
+
nk_f32_t *sum_ptr, nk_f32_t *sumsq_ptr) {
|
|
2306
|
+
nk_size_t vlmax = __riscv_vsetvlmax_e64m4();
|
|
2307
|
+
vfloat64m4_t sum_f64m4 = __riscv_vfmv_v_f_f64m4(0.0, vlmax);
|
|
2308
|
+
vfloat64m4_t sumsq_f64m4 = __riscv_vfmv_v_f_f64m4(0.0, vlmax);
|
|
2309
|
+
|
|
2310
|
+
for (nk_size_t vector_length; count > 0; count -= vector_length, data_ptr += vector_length) {
|
|
2311
|
+
vector_length = __riscv_vsetvl_e16m1(count);
|
|
2312
|
+
vuint16m1_t data_u16m1 = __riscv_vle16_v_u16m1((uint16_t const *)data_ptr, vector_length);
|
|
2313
|
+
|
|
2314
|
+
// Convert f16 → f32 (m1 → m2)
|
|
2315
|
+
vfloat32m2_t data_f32m2 = nk_f16m1_to_f32m2_rvv_(data_u16m1, vector_length);
|
|
2316
|
+
|
|
2317
|
+
// Widen f32 → f64 (m2 → m4)
|
|
2318
|
+
vfloat64m4_t data_f64m4 = __riscv_vfwcvt_f_f_v_f64m4(data_f32m2, vector_length);
|
|
2319
|
+
sum_f64m4 = __riscv_vfadd_vv_f64m4_tu(sum_f64m4, sum_f64m4, data_f64m4, vector_length);
|
|
2320
|
+
|
|
2321
|
+
// Sumsq via widening FMA: f32×f32 → f64
|
|
2322
|
+
sumsq_f64m4 = __riscv_vfwmacc_vv_f64m4_tu(sumsq_f64m4, data_f32m2, data_f32m2, vector_length);
|
|
2323
|
+
}
|
|
2324
|
+
|
|
2325
|
+
// Horizontal reduction
|
|
2326
|
+
vfloat64m1_t zero_f64m1 = __riscv_vfmv_v_f_f64m1(0.0, 1);
|
|
2327
|
+
*sum_ptr = __riscv_vfmv_f_s_f64m1_f64(__riscv_vfredusum_vs_f64m4_f64m1(sum_f64m4, zero_f64m1, vlmax)),
|
|
2328
|
+
*sumsq_ptr = __riscv_vfmv_f_s_f64m1_f64(__riscv_vfredusum_vs_f64m4_f64m1(sumsq_f64m4, zero_f64m1, vlmax));
|
|
2329
|
+
}
|
|
2330
|
+
|
|
2331
|
+
NK_INTERNAL void nk_reduce_moments_f16_rvv_strided_( //
|
|
2332
|
+
nk_f16_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
|
|
2333
|
+
nk_f32_t *sum_ptr, nk_f32_t *sumsq_ptr) {
|
|
2334
|
+
nk_size_t vlmax = __riscv_vsetvlmax_e64m4();
|
|
2335
|
+
vfloat64m4_t sum_f64m4 = __riscv_vfmv_v_f_f64m4(0.0, vlmax);
|
|
2336
|
+
vfloat64m4_t sumsq_f64m4 = __riscv_vfmv_v_f_f64m4(0.0, vlmax);
|
|
2337
|
+
unsigned char const *ptr = (unsigned char const *)data_ptr;
|
|
2338
|
+
|
|
2339
|
+
for (nk_size_t vector_length; count > 0; count -= vector_length, ptr += vector_length * stride_bytes) {
|
|
2340
|
+
vector_length = __riscv_vsetvl_e16m1(count);
|
|
2341
|
+
vuint16m1_t data_u16m1 = __riscv_vlse16_v_u16m1((uint16_t const *)ptr, (nk_ssize_t)stride_bytes, vector_length);
|
|
2342
|
+
|
|
2343
|
+
// Convert f16 → f32 (m1 → m2)
|
|
2344
|
+
vfloat32m2_t data_f32m2 = nk_f16m1_to_f32m2_rvv_(data_u16m1, vector_length);
|
|
2345
|
+
|
|
2346
|
+
// Widen f32 → f64 (m2 → m4)
|
|
2347
|
+
vfloat64m4_t data_f64m4 = __riscv_vfwcvt_f_f_v_f64m4(data_f32m2, vector_length);
|
|
2348
|
+
sum_f64m4 = __riscv_vfadd_vv_f64m4_tu(sum_f64m4, sum_f64m4, data_f64m4, vector_length);
|
|
2349
|
+
|
|
2350
|
+
// Sumsq via widening FMA: f32×f32 → f64
|
|
2351
|
+
sumsq_f64m4 = __riscv_vfwmacc_vv_f64m4_tu(sumsq_f64m4, data_f32m2, data_f32m2, vector_length);
|
|
2352
|
+
}
|
|
2353
|
+
|
|
2354
|
+
// Horizontal reduction
|
|
2355
|
+
vfloat64m1_t zero_f64m1 = __riscv_vfmv_v_f_f64m1(0.0, 1);
|
|
2356
|
+
*sum_ptr = __riscv_vfmv_f_s_f64m1_f64(__riscv_vfredusum_vs_f64m4_f64m1(sum_f64m4, zero_f64m1, vlmax)),
|
|
2357
|
+
*sumsq_ptr = __riscv_vfmv_f_s_f64m1_f64(__riscv_vfredusum_vs_f64m4_f64m1(sumsq_f64m4, zero_f64m1, vlmax));
|
|
2358
|
+
}
|
|
2359
|
+
|
|
2360
|
+
NK_PUBLIC void nk_reduce_moments_f16_rvv( //
|
|
2361
|
+
nk_f16_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
|
|
2362
|
+
nk_f32_t *sum_ptr, nk_f32_t *sumsq_ptr) {
|
|
2363
|
+
nk_size_t stride_elements = stride_bytes / sizeof(nk_f16_t);
|
|
2364
|
+
int aligned = (stride_bytes % sizeof(nk_f16_t) == 0);
|
|
2365
|
+
|
|
2366
|
+
if (count == 0) *sum_ptr = 0.0f, *sumsq_ptr = 0.0f;
|
|
2367
|
+
else if (!aligned) nk_reduce_moments_f16_serial(data_ptr, count, stride_bytes, sum_ptr, sumsq_ptr);
|
|
2368
|
+
else if (stride_elements == 1) nk_reduce_moments_f16_rvv_contiguous_(data_ptr, count, sum_ptr, sumsq_ptr);
|
|
2369
|
+
else nk_reduce_moments_f16_rvv_strided_(data_ptr, count, stride_bytes, sum_ptr, sumsq_ptr);
|
|
2370
|
+
}
|
|
2371
|
+
|
|
2372
|
+
// Contiguous f16 min+max reduction with argmin/argmax.
// Values are kept as raw u16 bit patterns and compared after upcasting to f32;
// element positions are tracked in parallel u64 index vectors.
NK_INTERNAL void nk_reduce_minmax_f16_rvv_contiguous_( //
    nk_f16_t const *data_ptr, nk_size_t count, //
    nk_f16_t *min_value_ptr, nk_size_t *min_index_ptr, //
    nk_f16_t *max_value_ptr, nk_size_t *max_index_ptr) {
    nk_size_t vlmax = __riscv_vsetvlmax_e16m1();
    // Per-lane running extrema, seeded with f16 +/-infinity bit patterns so
    // any finite element wins the first comparison.
    vuint16m1_t min_u16m1 = __riscv_vmv_v_x_u16m1(0x7C00, vlmax); // +inf in f16
    vuint16m1_t max_u16m1 = __riscv_vmv_v_x_u16m1(0xFC00, vlmax); // -inf in f16
    vuint64m4_t min_indices_u64m4 = __riscv_vmv_v_x_u64m4(0, vlmax);
    vuint64m4_t max_indices_u64m4 = __riscv_vmv_v_x_u64m4(0, vlmax);

    nk_size_t offset = 0;
    for (nk_size_t vector_length; count > 0;
         count -= vector_length, offset += vector_length, data_ptr += vector_length) {
        vector_length = __riscv_vsetvl_e16m1(count);
        vuint16m1_t data_u16m1 = __riscv_vle16_v_u16m1((uint16_t const *)data_ptr, vector_length);
        // Global element positions for this batch: running offset + lane id.
        vuint64m4_t pos_u64m4 = __riscv_vadd_vx_u64m4(__riscv_vid_v_u64m4(vector_length), (nk_u64_t)offset,
                                                      vector_length);

        // Convert to f32 for comparison
        vfloat32m2_t data_f32m2 = nk_f16m1_to_f32m2_rvv_(data_u16m1, vector_length);
        vfloat32m2_t min_f32m2 = nk_f16m1_to_f32m2_rvv_(min_u16m1, vector_length);
        vfloat32m2_t max_f32m2 = nk_f16m1_to_f32m2_rvv_(max_u16m1, vector_length);

        // Ordered compares (vmflt) are false when either operand is NaN, so
        // NaN elements never update the running extrema. `_tu` merges keep
        // lanes past `vector_length` intact for the final vlmax reduction.
        vbool16_t less_b16 = __riscv_vmflt_vv_f32m2_b16(data_f32m2, min_f32m2, vector_length);
        min_u16m1 = __riscv_vmerge_vvm_u16m1_tu(min_u16m1, min_u16m1, data_u16m1, less_b16, vector_length);
        min_indices_u64m4 = __riscv_vmerge_vvm_u64m4_tu(min_indices_u64m4, min_indices_u64m4, pos_u64m4, less_b16,
                                                        vector_length);

        vbool16_t greater_b16 = __riscv_vmflt_vv_f32m2_b16(max_f32m2, data_f32m2, vector_length);
        max_u16m1 = __riscv_vmerge_vvm_u16m1_tu(max_u16m1, max_u16m1, data_u16m1, greater_b16, vector_length);
        max_indices_u64m4 = __riscv_vmerge_vvm_u64m4_tu(max_indices_u64m4, max_indices_u64m4, pos_u64m4, greater_b16,
                                                        vector_length);
    }

    // Horizontal reduction
    vfloat32m2_t final_min_f32m2 = nk_f16m1_to_f32m2_rvv_(min_u16m1, vlmax);
    vfloat32m1_t init_max_f32m1 = __riscv_vfmv_v_f_f32m1(NK_F32_MAX, 1);
    nk_f32_t min_val_f32 = __riscv_vfmv_f_s_f32m1_f32(
        __riscv_vfredmin_vs_f32m2_f32m1(final_min_f32m2, init_max_f32m1, vlmax));
    vfloat32m2_t final_max_f32m2 = nk_f16m1_to_f32m2_rvv_(max_u16m1, vlmax);
    vfloat32m1_t init_min_f32m1 = __riscv_vfmv_v_f_f32m1(NK_F32_MIN, 1);
    nk_f32_t max_val_f32 = __riscv_vfmv_f_s_f32m1_f32(
        __riscv_vfredmax_vs_f32m2_f32m1(final_max_f32m2, init_min_f32m1, vlmax));
    // If both reductions still equal their initializers, no element was ever
    // accepted (e.g. every input was NaN): report sentinel outputs.
    if (min_val_f32 == NK_F32_MAX && max_val_f32 == NK_F32_MIN) {
        *min_value_ptr = NK_F16_MAX, *min_index_ptr = NK_SIZE_MAX, *max_value_ptr = NK_F16_MIN,
        *max_index_ptr = NK_SIZE_MAX;
        return;
    }

    // Locate the winning lanes: the first matching lane (vfirst) supplies the
    // raw value; vredminu over masked candidate indices (non-matches forced to
    // NK_U64_MAX) resolves ties toward the smallest element position.
    vfloat32m2_t converted_f32m2 = nk_f16m1_to_f32m2_rvv_(min_u16m1, vlmax);
    vbool16_t min_match_b16 = __riscv_vmfeq_vf_f32m2_b16(converted_f32m2, min_val_f32, vlmax);
    vuint64m4_t sentinel_u64m4 = __riscv_vmv_v_x_u64m4(NK_U64_MAX, vlmax);
    vuint64m4_t min_cands_u64m4 = __riscv_vmerge_vvm_u64m4(sentinel_u64m4, min_indices_u64m4, min_match_b16, vlmax);
    vuint64m1_t init_umax_u64m1 = __riscv_vmv_v_x_u64m1(NK_U64_MAX, 1);

    nk_u16_t min_raw = __riscv_vmv_x_s_u16m1_u16(
        __riscv_vslidedown_vx_u16m1(min_u16m1, (nk_size_t)__riscv_vfirst_m_b16(min_match_b16, vlmax), vlmax));
    // Reinterpret the raw 16 bits as f16 (type-pun; assumes nk_f16_t is a
    // 16-bit type with matching layout -- NOTE(review): confirm).
    *min_value_ptr = *(nk_f16_t *)&min_raw;
    *min_index_ptr = (nk_size_t)__riscv_vmv_x_s_u64m1_u64(
        __riscv_vredminu_vs_u64m4_u64m1(min_cands_u64m4, init_umax_u64m1, vlmax));

    // Same lane-location dance for the maximum.
    vbool16_t max_match_b16 = __riscv_vmfeq_vf_f32m2_b16(nk_f16m1_to_f32m2_rvv_(max_u16m1, vlmax), max_val_f32, vlmax);
    vuint64m4_t max_cands_u64m4 = __riscv_vmerge_vvm_u64m4(sentinel_u64m4, max_indices_u64m4, max_match_b16, vlmax);

    nk_u16_t max_raw = __riscv_vmv_x_s_u16m1_u16(
        __riscv_vslidedown_vx_u16m1(max_u16m1, (nk_size_t)__riscv_vfirst_m_b16(max_match_b16, vlmax), vlmax));
    *max_value_ptr = *(nk_f16_t *)&max_raw;
    *max_index_ptr = (nk_size_t)__riscv_vmv_x_s_u64m1_u64(
        __riscv_vredminu_vs_u64m4_u64m1(max_cands_u64m4, init_umax_u64m1, vlmax));
}
|
|
2442
|
+
|
|
2443
|
+
// Strided f16 min+max reduction with argmin/argmax: same algorithm as the
// contiguous variant, but elements are gathered with a byte-stride load.
NK_INTERNAL void nk_reduce_minmax_f16_rvv_strided_( //
    nk_f16_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
    nk_f16_t *min_value_ptr, nk_size_t *min_index_ptr, //
    nk_f16_t *max_value_ptr, nk_size_t *max_index_ptr) {
    nk_size_t vlmax = __riscv_vsetvlmax_e16m1();
    // Per-lane running extrema, seeded with f16 +/-infinity bit patterns.
    vuint16m1_t min_u16m1 = __riscv_vmv_v_x_u16m1(0x7C00, vlmax); // +inf in f16
    vuint16m1_t max_u16m1 = __riscv_vmv_v_x_u16m1(0xFC00, vlmax); // -inf in f16
    vuint64m4_t min_indices_u64m4 = __riscv_vmv_v_x_u64m4(0, vlmax);
    vuint64m4_t max_indices_u64m4 = __riscv_vmv_v_x_u64m4(0, vlmax);
    unsigned char const *ptr = (unsigned char const *)data_ptr;

    nk_size_t offset = 0;
    for (nk_size_t vector_length; count > 0;
         count -= vector_length, offset += vector_length, ptr += vector_length * stride_bytes) {
        vector_length = __riscv_vsetvl_e16m1(count);
        // Strided gather: one 16-bit element every `stride_bytes`.
        vuint16m1_t data_u16m1 = __riscv_vlse16_v_u16m1((uint16_t const *)ptr, (nk_ssize_t)stride_bytes, vector_length);
        // Global element positions for this batch: running offset + lane id.
        vuint64m4_t pos_u64m4 = __riscv_vadd_vx_u64m4(__riscv_vid_v_u64m4(vector_length), (nk_u64_t)offset,
                                                      vector_length);

        // Convert to f32 for comparison
        vfloat32m2_t data_f32m2 = nk_f16m1_to_f32m2_rvv_(data_u16m1, vector_length);
        vfloat32m2_t min_f32m2 = nk_f16m1_to_f32m2_rvv_(min_u16m1, vector_length);
        vfloat32m2_t max_f32m2 = nk_f16m1_to_f32m2_rvv_(max_u16m1, vector_length);

        // Ordered compares are false for NaN, so NaN elements never win; `_tu`
        // merges keep tail lanes intact for the final vlmax reduction.
        vbool16_t less_b16 = __riscv_vmflt_vv_f32m2_b16(data_f32m2, min_f32m2, vector_length);
        min_u16m1 = __riscv_vmerge_vvm_u16m1_tu(min_u16m1, min_u16m1, data_u16m1, less_b16, vector_length);
        min_indices_u64m4 = __riscv_vmerge_vvm_u64m4_tu(min_indices_u64m4, min_indices_u64m4, pos_u64m4, less_b16,
                                                        vector_length);

        vbool16_t greater_b16 = __riscv_vmflt_vv_f32m2_b16(max_f32m2, data_f32m2, vector_length);
        max_u16m1 = __riscv_vmerge_vvm_u16m1_tu(max_u16m1, max_u16m1, data_u16m1, greater_b16, vector_length);
        max_indices_u64m4 = __riscv_vmerge_vvm_u64m4_tu(max_indices_u64m4, max_indices_u64m4, pos_u64m4, greater_b16,
                                                        vector_length);
    }

    // Horizontal reduction (same as contiguous)
    vfloat32m2_t final_min_f32m2 = nk_f16m1_to_f32m2_rvv_(min_u16m1, vlmax);
    vfloat32m1_t init_max_f32m1 = __riscv_vfmv_v_f_f32m1(NK_F32_MAX, 1);
    nk_f32_t min_val_f32 = __riscv_vfmv_f_s_f32m1_f32(
        __riscv_vfredmin_vs_f32m2_f32m1(final_min_f32m2, init_max_f32m1, vlmax));
    vfloat32m2_t final_max_f32m2 = nk_f16m1_to_f32m2_rvv_(max_u16m1, vlmax);
    vfloat32m1_t init_min_f32m1 = __riscv_vfmv_v_f_f32m1(NK_F32_MIN, 1);
    nk_f32_t max_val_f32 = __riscv_vfmv_f_s_f32m1_f32(
        __riscv_vfredmax_vs_f32m2_f32m1(final_max_f32m2, init_min_f32m1, vlmax));
    // If both reductions still equal their initializers, no element was ever
    // accepted (e.g. every input was NaN): report sentinel outputs.
    if (min_val_f32 == NK_F32_MAX && max_val_f32 == NK_F32_MIN) {
        *min_value_ptr = NK_F16_MAX, *min_index_ptr = NK_SIZE_MAX, *max_value_ptr = NK_F16_MIN,
        *max_index_ptr = NK_SIZE_MAX;
        return;
    }

    // Locate the winning lanes: the first matching lane (vfirst) supplies the
    // raw value; vredminu over masked candidate indices resolves ties toward
    // the smallest element position.
    vfloat32m2_t converted_f32m2 = nk_f16m1_to_f32m2_rvv_(min_u16m1, vlmax);
    vbool16_t min_match_b16 = __riscv_vmfeq_vf_f32m2_b16(converted_f32m2, min_val_f32, vlmax);
    vuint64m4_t sentinel_u64m4 = __riscv_vmv_v_x_u64m4(NK_U64_MAX, vlmax);
    vuint64m4_t min_cands_u64m4 = __riscv_vmerge_vvm_u64m4(sentinel_u64m4, min_indices_u64m4, min_match_b16, vlmax);
    vuint64m1_t init_umax_u64m1 = __riscv_vmv_v_x_u64m1(NK_U64_MAX, 1);

    nk_u16_t min_raw = __riscv_vmv_x_s_u16m1_u16(
        __riscv_vslidedown_vx_u16m1(min_u16m1, (nk_size_t)__riscv_vfirst_m_b16(min_match_b16, vlmax), vlmax));
    // Reinterpret the raw 16 bits as f16 (type-pun -- NOTE(review): assumes
    // nk_f16_t is a 16-bit type with matching layout; confirm).
    *min_value_ptr = *(nk_f16_t *)&min_raw;
    *min_index_ptr = (nk_size_t)__riscv_vmv_x_s_u64m1_u64(
        __riscv_vredminu_vs_u64m4_u64m1(min_cands_u64m4, init_umax_u64m1, vlmax));

    // Same lane-location dance for the maximum.
    vbool16_t max_match_b16 = __riscv_vmfeq_vf_f32m2_b16(nk_f16m1_to_f32m2_rvv_(max_u16m1, vlmax), max_val_f32, vlmax);
    vuint64m4_t max_cands_u64m4 = __riscv_vmerge_vvm_u64m4(sentinel_u64m4, max_indices_u64m4, max_match_b16, vlmax);

    nk_u16_t max_raw = __riscv_vmv_x_s_u16m1_u16(
        __riscv_vslidedown_vx_u16m1(max_u16m1, (nk_size_t)__riscv_vfirst_m_b16(max_match_b16, vlmax), vlmax));
    *max_value_ptr = *(nk_f16_t *)&max_raw;
    *max_index_ptr = (nk_size_t)__riscv_vmv_x_s_u64m1_u64(
        __riscv_vredminu_vs_u64m4_u64m1(max_cands_u64m4, init_umax_u64m1, vlmax));
}
|
|
2514
|
+
|
|
2515
|
+
NK_PUBLIC void nk_reduce_minmax_f16_rvv( //
|
|
2516
|
+
nk_f16_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
|
|
2517
|
+
nk_f16_t *min_value_ptr, nk_size_t *min_index_ptr, //
|
|
2518
|
+
nk_f16_t *max_value_ptr, nk_size_t *max_index_ptr) {
|
|
2519
|
+
nk_size_t stride_elements = stride_bytes / sizeof(nk_f16_t);
|
|
2520
|
+
int aligned = (stride_bytes % sizeof(nk_f16_t) == 0);
|
|
2521
|
+
|
|
2522
|
+
if (count == 0)
|
|
2523
|
+
*min_value_ptr = NK_F16_MAX, *min_index_ptr = NK_SIZE_MAX, *max_value_ptr = NK_F16_MIN,
|
|
2524
|
+
*max_index_ptr = NK_SIZE_MAX;
|
|
2525
|
+
else if (!aligned)
|
|
2526
|
+
nk_reduce_minmax_f16_serial(data_ptr, count, stride_bytes, min_value_ptr, min_index_ptr, max_value_ptr,
|
|
2527
|
+
max_index_ptr);
|
|
2528
|
+
else if (stride_elements == 1)
|
|
2529
|
+
nk_reduce_minmax_f16_rvv_contiguous_(data_ptr, count, min_value_ptr, min_index_ptr, max_value_ptr,
|
|
2530
|
+
max_index_ptr);
|
|
2531
|
+
else
|
|
2532
|
+
nk_reduce_minmax_f16_rvv_strided_(data_ptr, count, stride_bytes, min_value_ptr, min_index_ptr, max_value_ptr,
|
|
2533
|
+
max_index_ptr);
|
|
2534
|
+
}
|
|
2535
|
+
|
|
2536
|
+
NK_INTERNAL void nk_reduce_moments_e4m3_rvv_contiguous_( //
|
|
2537
|
+
nk_e4m3_t const *data_ptr, nk_size_t count, //
|
|
2538
|
+
nk_f32_t *sum_ptr, nk_f32_t *sumsq_ptr) {
|
|
2539
|
+
nk_size_t vlmax = __riscv_vsetvlmax_e32m4();
|
|
2540
|
+
vfloat32m4_t sum_f32m4 = __riscv_vfmv_v_f_f32m4(0.0f, vlmax);
|
|
2541
|
+
vfloat32m4_t sumsq_f32m4 = __riscv_vfmv_v_f_f32m4(0.0f, vlmax);
|
|
2542
|
+
|
|
2543
|
+
for (nk_size_t vector_length; count > 0; count -= vector_length, data_ptr += vector_length) {
|
|
2544
|
+
vector_length = __riscv_vsetvl_e8m1(count);
|
|
2545
|
+
vuint8m1_t data_u8m1 = __riscv_vle8_v_u8m1((uint8_t const *)data_ptr, vector_length);
|
|
2546
|
+
|
|
2547
|
+
// Convert e4m3 → f32 (m1 → m4)
|
|
2548
|
+
vfloat32m4_t data_f32m4 = nk_e4m3m1_to_f32m4_rvv_(data_u8m1, vector_length);
|
|
2549
|
+
|
|
2550
|
+
// Accumulate at f32 precision
|
|
2551
|
+
sum_f32m4 = __riscv_vfadd_vv_f32m4_tu(sum_f32m4, sum_f32m4, data_f32m4, vector_length);
|
|
2552
|
+
sumsq_f32m4 = __riscv_vfmacc_vv_f32m4_tu(sumsq_f32m4, data_f32m4, data_f32m4, vector_length);
|
|
2553
|
+
}
|
|
2554
|
+
|
|
2555
|
+
// Horizontal reduction
|
|
2556
|
+
vfloat32m1_t zero_f32m1 = __riscv_vfmv_v_f_f32m1(0.0f, 1);
|
|
2557
|
+
*sum_ptr = __riscv_vfmv_f_s_f32m1_f32(__riscv_vfredusum_vs_f32m4_f32m1(sum_f32m4, zero_f32m1, vlmax)),
|
|
2558
|
+
*sumsq_ptr = __riscv_vfmv_f_s_f32m1_f32(__riscv_vfredusum_vs_f32m4_f32m1(sumsq_f32m4, zero_f32m1, vlmax));
|
|
2559
|
+
}
|
|
2560
|
+
|
|
2561
|
+
NK_INTERNAL void nk_reduce_moments_e4m3_rvv_strided_( //
|
|
2562
|
+
nk_e4m3_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
|
|
2563
|
+
nk_f32_t *sum_ptr, nk_f32_t *sumsq_ptr) {
|
|
2564
|
+
nk_size_t vlmax = __riscv_vsetvlmax_e32m4();
|
|
2565
|
+
vfloat32m4_t sum_f32m4 = __riscv_vfmv_v_f_f32m4(0.0f, vlmax);
|
|
2566
|
+
vfloat32m4_t sumsq_f32m4 = __riscv_vfmv_v_f_f32m4(0.0f, vlmax);
|
|
2567
|
+
unsigned char const *ptr = (unsigned char const *)data_ptr;
|
|
2568
|
+
|
|
2569
|
+
for (nk_size_t vector_length; count > 0; count -= vector_length, ptr += vector_length * stride_bytes) {
|
|
2570
|
+
vector_length = __riscv_vsetvl_e8m1(count);
|
|
2571
|
+
vuint8m1_t data_u8m1 = __riscv_vlse8_v_u8m1((uint8_t const *)ptr, (nk_ssize_t)stride_bytes, vector_length);
|
|
2572
|
+
|
|
2573
|
+
// Convert e4m3 → f32 (m1 → m4)
|
|
2574
|
+
vfloat32m4_t data_f32m4 = nk_e4m3m1_to_f32m4_rvv_(data_u8m1, vector_length);
|
|
2575
|
+
|
|
2576
|
+
// Accumulate at f32 precision
|
|
2577
|
+
sum_f32m4 = __riscv_vfadd_vv_f32m4_tu(sum_f32m4, sum_f32m4, data_f32m4, vector_length);
|
|
2578
|
+
sumsq_f32m4 = __riscv_vfmacc_vv_f32m4_tu(sumsq_f32m4, data_f32m4, data_f32m4, vector_length);
|
|
2579
|
+
}
|
|
2580
|
+
|
|
2581
|
+
// Horizontal reduction
|
|
2582
|
+
vfloat32m1_t zero_f32m1 = __riscv_vfmv_v_f_f32m1(0.0f, 1);
|
|
2583
|
+
*sum_ptr = __riscv_vfmv_f_s_f32m1_f32(__riscv_vfredusum_vs_f32m4_f32m1(sum_f32m4, zero_f32m1, vlmax)),
|
|
2584
|
+
*sumsq_ptr = __riscv_vfmv_f_s_f32m1_f32(__riscv_vfredusum_vs_f32m4_f32m1(sumsq_f32m4, zero_f32m1, vlmax));
|
|
2585
|
+
}
|
|
2586
|
+
|
|
2587
|
+
NK_PUBLIC void nk_reduce_moments_e4m3_rvv( //
|
|
2588
|
+
nk_e4m3_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
|
|
2589
|
+
nk_f32_t *sum_ptr, nk_f32_t *sumsq_ptr) {
|
|
2590
|
+
nk_size_t stride_elements = stride_bytes / sizeof(nk_e4m3_t);
|
|
2591
|
+
int aligned = (stride_bytes % sizeof(nk_e4m3_t) == 0);
|
|
2592
|
+
|
|
2593
|
+
if (count == 0) *sum_ptr = 0.0f, *sumsq_ptr = 0.0f;
|
|
2594
|
+
else if (!aligned) nk_reduce_moments_e4m3_serial(data_ptr, count, stride_bytes, sum_ptr, sumsq_ptr);
|
|
2595
|
+
else if (stride_elements == 1) nk_reduce_moments_e4m3_rvv_contiguous_(data_ptr, count, sum_ptr, sumsq_ptr);
|
|
2596
|
+
else nk_reduce_moments_e4m3_rvv_strided_(data_ptr, count, stride_bytes, sum_ptr, sumsq_ptr);
|
|
2597
|
+
}
|
|
2598
|
+
|
|
2599
|
+
// Contiguous e4m3 (FP8) min+max reduction with argmin/argmax.
// Instead of upcasting to float, each byte is remapped to a monotonic
// "comparable" u8 encoding (via nk_fp8m1_to_comparable_u8m1_rvv_) so plain
// unsigned integer compares give the right float ordering; NaN encodings are
// neutralized, and results are mapped back at the end.
NK_INTERNAL void nk_reduce_minmax_e4m3_rvv_contiguous_( //
    nk_e4m3_t const *data_ptr, nk_size_t count, //
    nk_e4m3_t *min_value_ptr, nk_size_t *min_index_ptr, //
    nk_e4m3_t *max_value_ptr, nk_size_t *max_index_ptr) {
    nk_size_t vlmax = __riscv_vsetvlmax_e8m1();
    // Running extrema in comparable space, seeded at the encoding's ends so
    // any non-NaN element wins the first comparison.
    vuint8m1_t min_u8m1 = __riscv_vmv_v_x_u8m1(0xFF, vlmax); // Largest comparable
    vuint8m1_t max_u8m1 = __riscv_vmv_v_x_u8m1(0x00, vlmax); // Smallest comparable
    vuint64m8_t min_indices_u64m8 = __riscv_vmv_v_x_u64m8(0, vlmax);
    vuint64m8_t max_indices_u64m8 = __riscv_vmv_v_x_u64m8(0, vlmax);

    nk_size_t offset = 0;
    for (nk_size_t vector_length; count > 0;
         count -= vector_length, offset += vector_length, data_ptr += vector_length) {
        vector_length = __riscv_vsetvl_e8m1(count);
        vuint8m1_t raw_u8m1 = __riscv_vle8_v_u8m1((uint8_t const *)data_ptr, vector_length);

        // Convert to comparable form
        vuint8m1_t comparable_u8m1 = nk_fp8m1_to_comparable_u8m1_rvv_(raw_u8m1, vector_length);
        // Global element positions for this batch: running offset + lane id.
        vuint64m8_t pos_u64m8 = __riscv_vadd_vx_u64m8(__riscv_vid_v_u64m8(vector_length), (nk_u64_t)offset,
                                                      vector_length);

        // Detect E4M3 NaN: comparable == 0x00 (neg NaN) or comparable == 0xFF (pos NaN)
        vbool8_t nan_low_b8 = __riscv_vmseq_vx_u8m1_b8(comparable_u8m1, 0x00, vector_length);
        vbool8_t nan_high_b8 = __riscv_vmseq_vx_u8m1_b8(comparable_u8m1, 0xFF, vector_length);
        vbool8_t is_nan_b8 = __riscv_vmor_mm_b8(nan_low_b8, nan_high_b8, vector_length);
        // NaN lanes are forced to the losing end of each comparison (0xFF for
        // min, 0x00 for max) so they never update the running extrema.
        vuint8m1_t data_min_u8m1 = __riscv_vmerge_vxm_u8m1(comparable_u8m1, 0xFF, is_nan_b8, vector_length);
        vuint8m1_t data_max_u8m1 = __riscv_vmerge_vxm_u8m1(comparable_u8m1, 0x00, is_nan_b8, vector_length);

        // `_tu` (tail-undisturbed) merges keep lanes past `vector_length`
        // intact for the final vlmax-wide reduction.
        vbool8_t less_b8 = __riscv_vmsltu_vv_u8m1_b8(data_min_u8m1, min_u8m1, vector_length);
        min_u8m1 = __riscv_vmerge_vvm_u8m1_tu(min_u8m1, min_u8m1, data_min_u8m1, less_b8, vector_length);
        min_indices_u64m8 = __riscv_vmerge_vvm_u64m8_tu(min_indices_u64m8, min_indices_u64m8, pos_u64m8, less_b8,
                                                        vector_length);

        vbool8_t greater_b8 = __riscv_vmsltu_vv_u8m1_b8(max_u8m1, data_max_u8m1, vector_length);
        max_u8m1 = __riscv_vmerge_vvm_u8m1_tu(max_u8m1, max_u8m1, data_max_u8m1, greater_b8, vector_length);
        max_indices_u64m8 = __riscv_vmerge_vvm_u64m8_tu(max_indices_u64m8, max_indices_u64m8, pos_u64m8, greater_b8,
                                                        vector_length);
    }

    // Horizontal reduction + convert back
    vuint8m1_t init_max_u8m1 = __riscv_vmv_v_x_u8m1(0xFF, 1);
    nk_u8_t min_comparable = __riscv_vmv_x_s_u8m1_u8(__riscv_vredminu_vs_u8m1_u8m1(min_u8m1, init_max_u8m1, vlmax));

    // All-NaN case
    // The minimum never left its 0xFF initializer, so no valid element was
    // seen: report sentinel values and invalid indices.
    if (min_comparable == 0xFF) {
        *min_value_ptr = (nk_e4m3_t)NK_E4M3_MAX, *min_index_ptr = NK_SIZE_MAX;
        *max_value_ptr = (nk_e4m3_t)NK_E4M3_MIN, *max_index_ptr = NK_SIZE_MAX;
        return;
    }

    // Argmin: mask lanes holding the winning comparable value, force the rest
    // to NK_U64_MAX, and take vredminu to resolve ties to the lowest index.
    vbool8_t min_match_b8 = __riscv_vmseq_vx_u8m1_b8(min_u8m1, min_comparable, vlmax);
    vuint64m8_t sentinel_u64m8 = __riscv_vmv_v_x_u64m8(NK_U64_MAX, vlmax);
    vuint64m8_t min_cands_u64m8 = __riscv_vmerge_vvm_u64m8(sentinel_u64m8, min_indices_u64m8, min_match_b8, vlmax);
    vuint64m1_t init_umax_u64m1 = __riscv_vmv_v_x_u64m1(NK_U64_MAX, 1);
    *min_index_ptr = (nk_size_t)__riscv_vmv_x_s_u64m1_u64(
        __riscv_vredminu_vs_u64m8_u64m1(min_cands_u64m8, init_umax_u64m1, vlmax));

    // Map the winning comparable byte back to its raw e4m3 encoding.
    vuint8m1_t min_vec_u8m1 = __riscv_vmv_v_x_u8m1(min_comparable, 1);
    vuint8m1_t min_raw_u8m1 = nk_comparable_to_fp8m1_rvv_(min_vec_u8m1, 1);
    *min_value_ptr = (nk_e4m3_t)__riscv_vmv_x_s_u8m1_u8(min_raw_u8m1);

    // Similar for max
    vuint8m1_t init_min_u8m1 = __riscv_vmv_v_x_u8m1(0x00, 1);
    nk_u8_t max_comparable = __riscv_vmv_x_s_u8m1_u8(__riscv_vredmaxu_vs_u8m1_u8m1(max_u8m1, init_min_u8m1, vlmax));
    vbool8_t max_match_b8 = __riscv_vmseq_vx_u8m1_b8(max_u8m1, max_comparable, vlmax);
    vuint64m8_t max_cands_u64m8 = __riscv_vmerge_vvm_u64m8(sentinel_u64m8, max_indices_u64m8, max_match_b8, vlmax);
    *max_index_ptr = (nk_size_t)__riscv_vmv_x_s_u64m1_u64(
        __riscv_vredminu_vs_u64m8_u64m1(max_cands_u64m8, init_umax_u64m1, vlmax));

    vuint8m1_t max_vec_u8m1 = __riscv_vmv_v_x_u8m1(max_comparable, 1);
    vuint8m1_t max_raw_u8m1 = nk_comparable_to_fp8m1_rvv_(max_vec_u8m1, 1);
    *max_value_ptr = (nk_e4m3_t)__riscv_vmv_x_s_u8m1_u8(max_raw_u8m1);
}
|
|
2672
|
+
|
|
2673
|
+
NK_INTERNAL void nk_reduce_minmax_e4m3_rvv_strided_( //
|
|
2674
|
+
nk_e4m3_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
|
|
2675
|
+
nk_e4m3_t *min_value_ptr, nk_size_t *min_index_ptr, //
|
|
2676
|
+
nk_e4m3_t *max_value_ptr, nk_size_t *max_index_ptr) {
|
|
2677
|
+
nk_size_t vlmax = __riscv_vsetvlmax_e8m1();
|
|
2678
|
+
vuint8m1_t min_u8m1 = __riscv_vmv_v_x_u8m1(0xFF, vlmax);
|
|
2679
|
+
vuint8m1_t max_u8m1 = __riscv_vmv_v_x_u8m1(0x00, vlmax);
|
|
2680
|
+
vuint64m8_t min_indices_u64m8 = __riscv_vmv_v_x_u64m8(0, vlmax);
|
|
2681
|
+
vuint64m8_t max_indices_u64m8 = __riscv_vmv_v_x_u64m8(0, vlmax);
|
|
2682
|
+
unsigned char const *ptr = (unsigned char const *)data_ptr;
|
|
2683
|
+
|
|
2684
|
+
nk_size_t offset = 0;
|
|
2685
|
+
for (nk_size_t vector_length; count > 0;
|
|
2686
|
+
count -= vector_length, offset += vector_length, ptr += vector_length * stride_bytes) {
|
|
2687
|
+
vector_length = __riscv_vsetvl_e8m1(count);
|
|
2688
|
+
vuint8m1_t raw_u8m1 = __riscv_vlse8_v_u8m1((uint8_t const *)ptr, (nk_ssize_t)stride_bytes, vector_length);
|
|
2689
|
+
|
|
2690
|
+
vuint8m1_t comparable_u8m1 = nk_fp8m1_to_comparable_u8m1_rvv_(raw_u8m1, vector_length);
|
|
2691
|
+
vuint64m8_t pos_u64m8 = __riscv_vadd_vx_u64m8(__riscv_vid_v_u64m8(vector_length), (nk_u64_t)offset,
|
|
2692
|
+
vector_length);
|
|
2693
|
+
|
|
2694
|
+
// Detect E4M3 NaN: comparable == 0x00 (neg NaN) or comparable == 0xFF (pos NaN)
|
|
2695
|
+
vbool8_t nan_low_b8 = __riscv_vmseq_vx_u8m1_b8(comparable_u8m1, 0x00, vector_length);
|
|
2696
|
+
vbool8_t nan_high_b8 = __riscv_vmseq_vx_u8m1_b8(comparable_u8m1, 0xFF, vector_length);
|
|
2697
|
+
vbool8_t is_nan_b8 = __riscv_vmor_mm_b8(nan_low_b8, nan_high_b8, vector_length);
|
|
2698
|
+
vuint8m1_t data_min_u8m1 = __riscv_vmerge_vxm_u8m1(comparable_u8m1, 0xFF, is_nan_b8, vector_length);
|
|
2699
|
+
vuint8m1_t data_max_u8m1 = __riscv_vmerge_vxm_u8m1(comparable_u8m1, 0x00, is_nan_b8, vector_length);
|
|
2700
|
+
|
|
2701
|
+
vbool8_t less_b8 = __riscv_vmsltu_vv_u8m1_b8(data_min_u8m1, min_u8m1, vector_length);
|
|
2702
|
+
min_u8m1 = __riscv_vmerge_vvm_u8m1_tu(min_u8m1, min_u8m1, data_min_u8m1, less_b8, vector_length);
|
|
2703
|
+
min_indices_u64m8 = __riscv_vmerge_vvm_u64m8_tu(min_indices_u64m8, min_indices_u64m8, pos_u64m8, less_b8,
|
|
2704
|
+
vector_length);
|
|
2705
|
+
|
|
2706
|
+
vbool8_t greater_b8 = __riscv_vmsltu_vv_u8m1_b8(max_u8m1, data_max_u8m1, vector_length);
|
|
2707
|
+
max_u8m1 = __riscv_vmerge_vvm_u8m1_tu(max_u8m1, max_u8m1, data_max_u8m1, greater_b8, vector_length);
|
|
2708
|
+
max_indices_u64m8 = __riscv_vmerge_vvm_u64m8_tu(max_indices_u64m8, max_indices_u64m8, pos_u64m8, greater_b8,
|
|
2709
|
+
vector_length);
|
|
2710
|
+
}
|
|
2711
|
+
|
|
2712
|
+
// Horizontal reduction (same as contiguous)
|
|
2713
|
+
vuint8m1_t init_max_u8m1 = __riscv_vmv_v_x_u8m1(0xFF, 1);
|
|
2714
|
+
nk_u8_t min_comparable = __riscv_vmv_x_s_u8m1_u8(__riscv_vredminu_vs_u8m1_u8m1(min_u8m1, init_max_u8m1, vlmax));
|
|
2715
|
+
|
|
2716
|
+
// All-NaN case
|
|
2717
|
+
if (min_comparable == 0xFF) {
|
|
2718
|
+
*min_value_ptr = (nk_e4m3_t)NK_E4M3_MAX, *min_index_ptr = NK_SIZE_MAX;
|
|
2719
|
+
*max_value_ptr = (nk_e4m3_t)NK_E4M3_MIN, *max_index_ptr = NK_SIZE_MAX;
|
|
2720
|
+
return;
|
|
2721
|
+
}
|
|
2722
|
+
|
|
2723
|
+
vbool8_t min_match_b8 = __riscv_vmseq_vx_u8m1_b8(min_u8m1, min_comparable, vlmax);
|
|
2724
|
+
vuint64m8_t sentinel_u64m8 = __riscv_vmv_v_x_u64m8(NK_U64_MAX, vlmax);
|
|
2725
|
+
vuint64m8_t min_cands_u64m8 = __riscv_vmerge_vvm_u64m8(sentinel_u64m8, min_indices_u64m8, min_match_b8, vlmax);
|
|
2726
|
+
vuint64m1_t init_umax_u64m1 = __riscv_vmv_v_x_u64m1(NK_U64_MAX, 1);
|
|
2727
|
+
*min_index_ptr = (nk_size_t)__riscv_vmv_x_s_u64m1_u64(
|
|
2728
|
+
__riscv_vredminu_vs_u64m8_u64m1(min_cands_u64m8, init_umax_u64m1, vlmax));
|
|
2729
|
+
|
|
2730
|
+
vuint8m1_t min_vec_u8m1 = __riscv_vmv_v_x_u8m1(min_comparable, 1);
|
|
2731
|
+
vuint8m1_t min_raw_u8m1 = nk_comparable_to_fp8m1_rvv_(min_vec_u8m1, 1);
|
|
2732
|
+
*min_value_ptr = (nk_e4m3_t)__riscv_vmv_x_s_u8m1_u8(min_raw_u8m1);
|
|
2733
|
+
|
|
2734
|
+
vuint8m1_t init_min_u8m1 = __riscv_vmv_v_x_u8m1(0x00, 1);
|
|
2735
|
+
nk_u8_t max_comparable = __riscv_vmv_x_s_u8m1_u8(__riscv_vredmaxu_vs_u8m1_u8m1(max_u8m1, init_min_u8m1, vlmax));
|
|
2736
|
+
vbool8_t max_match_b8 = __riscv_vmseq_vx_u8m1_b8(max_u8m1, max_comparable, vlmax);
|
|
2737
|
+
vuint64m8_t max_cands_u64m8 = __riscv_vmerge_vvm_u64m8(sentinel_u64m8, max_indices_u64m8, max_match_b8, vlmax);
|
|
2738
|
+
*max_index_ptr = (nk_size_t)__riscv_vmv_x_s_u64m1_u64(
|
|
2739
|
+
__riscv_vredminu_vs_u64m8_u64m1(max_cands_u64m8, init_umax_u64m1, vlmax));
|
|
2740
|
+
|
|
2741
|
+
vuint8m1_t max_vec_u8m1 = __riscv_vmv_v_x_u8m1(max_comparable, 1);
|
|
2742
|
+
vuint8m1_t max_raw_u8m1 = nk_comparable_to_fp8m1_rvv_(max_vec_u8m1, 1);
|
|
2743
|
+
*max_value_ptr = (nk_e4m3_t)__riscv_vmv_x_s_u8m1_u8(max_raw_u8m1);
|
|
2744
|
+
}
|
|
2745
|
+
|
|
2746
|
+
NK_PUBLIC void nk_reduce_minmax_e4m3_rvv( //
|
|
2747
|
+
nk_e4m3_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
|
|
2748
|
+
nk_e4m3_t *min_value_ptr, nk_size_t *min_index_ptr, //
|
|
2749
|
+
nk_e4m3_t *max_value_ptr, nk_size_t *max_index_ptr) {
|
|
2750
|
+
nk_size_t stride_elements = stride_bytes / sizeof(nk_e4m3_t);
|
|
2751
|
+
int aligned = (stride_bytes % sizeof(nk_e4m3_t) == 0);
|
|
2752
|
+
|
|
2753
|
+
if (count == 0)
|
|
2754
|
+
*min_value_ptr = NK_E4M3_MAX, *min_index_ptr = NK_SIZE_MAX, *max_value_ptr = NK_E4M3_MIN,
|
|
2755
|
+
*max_index_ptr = NK_SIZE_MAX;
|
|
2756
|
+
else if (!aligned)
|
|
2757
|
+
nk_reduce_minmax_e4m3_serial(data_ptr, count, stride_bytes, min_value_ptr, min_index_ptr, max_value_ptr,
|
|
2758
|
+
max_index_ptr);
|
|
2759
|
+
else if (stride_elements == 1)
|
|
2760
|
+
nk_reduce_minmax_e4m3_rvv_contiguous_(data_ptr, count, min_value_ptr, min_index_ptr, max_value_ptr,
|
|
2761
|
+
max_index_ptr);
|
|
2762
|
+
else
|
|
2763
|
+
nk_reduce_minmax_e4m3_rvv_strided_(data_ptr, count, stride_bytes, min_value_ptr, min_index_ptr, max_value_ptr,
|
|
2764
|
+
max_index_ptr);
|
|
2765
|
+
}
|
|
2766
|
+
|
|
2767
|
+
NK_INTERNAL void nk_reduce_moments_e5m2_rvv_contiguous_( //
|
|
2768
|
+
nk_e5m2_t const *data_ptr, nk_size_t count, //
|
|
2769
|
+
nk_f32_t *sum_ptr, nk_f32_t *sumsq_ptr) {
|
|
2770
|
+
nk_size_t vlmax = __riscv_vsetvlmax_e32m4();
|
|
2771
|
+
vfloat32m4_t sum_f32m4 = __riscv_vfmv_v_f_f32m4(0.0f, vlmax);
|
|
2772
|
+
vfloat32m4_t sumsq_f32m4 = __riscv_vfmv_v_f_f32m4(0.0f, vlmax);
|
|
2773
|
+
|
|
2774
|
+
for (nk_size_t vector_length; count > 0; count -= vector_length, data_ptr += vector_length) {
|
|
2775
|
+
vector_length = __riscv_vsetvl_e8m1(count);
|
|
2776
|
+
vuint8m1_t data_u8m1 = __riscv_vle8_v_u8m1((uint8_t const *)data_ptr, vector_length);
|
|
2777
|
+
|
|
2778
|
+
// Convert e5m2 → f32 (m1 → m4)
|
|
2779
|
+
vfloat32m4_t data_f32m4 = nk_e5m2m1_to_f32m4_rvv_(data_u8m1, vector_length);
|
|
2780
|
+
|
|
2781
|
+
// Accumulate at f32 precision
|
|
2782
|
+
sum_f32m4 = __riscv_vfadd_vv_f32m4_tu(sum_f32m4, sum_f32m4, data_f32m4, vector_length);
|
|
2783
|
+
sumsq_f32m4 = __riscv_vfmacc_vv_f32m4_tu(sumsq_f32m4, data_f32m4, data_f32m4, vector_length);
|
|
2784
|
+
}
|
|
2785
|
+
|
|
2786
|
+
// Horizontal reduction
|
|
2787
|
+
vfloat32m1_t zero_f32m1 = __riscv_vfmv_v_f_f32m1(0.0f, 1);
|
|
2788
|
+
*sum_ptr = __riscv_vfmv_f_s_f32m1_f32(__riscv_vfredusum_vs_f32m4_f32m1(sum_f32m4, zero_f32m1, vlmax)),
|
|
2789
|
+
*sumsq_ptr = __riscv_vfmv_f_s_f32m1_f32(__riscv_vfredusum_vs_f32m4_f32m1(sumsq_f32m4, zero_f32m1, vlmax));
|
|
2790
|
+
}
|
|
2791
|
+
|
|
2792
|
+
NK_INTERNAL void nk_reduce_moments_e5m2_rvv_strided_( //
|
|
2793
|
+
nk_e5m2_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
|
|
2794
|
+
nk_f32_t *sum_ptr, nk_f32_t *sumsq_ptr) {
|
|
2795
|
+
nk_size_t vlmax = __riscv_vsetvlmax_e32m4();
|
|
2796
|
+
vfloat32m4_t sum_f32m4 = __riscv_vfmv_v_f_f32m4(0.0f, vlmax);
|
|
2797
|
+
vfloat32m4_t sumsq_f32m4 = __riscv_vfmv_v_f_f32m4(0.0f, vlmax);
|
|
2798
|
+
unsigned char const *ptr = (unsigned char const *)data_ptr;
|
|
2799
|
+
|
|
2800
|
+
for (nk_size_t vector_length; count > 0; count -= vector_length, ptr += vector_length * stride_bytes) {
|
|
2801
|
+
vector_length = __riscv_vsetvl_e8m1(count);
|
|
2802
|
+
vuint8m1_t data_u8m1 = __riscv_vlse8_v_u8m1((uint8_t const *)ptr, (nk_ssize_t)stride_bytes, vector_length);
|
|
2803
|
+
|
|
2804
|
+
// Convert e5m2 → f32 (m1 → m4)
|
|
2805
|
+
vfloat32m4_t data_f32m4 = nk_e5m2m1_to_f32m4_rvv_(data_u8m1, vector_length);
|
|
2806
|
+
|
|
2807
|
+
// Accumulate at f32 precision
|
|
2808
|
+
sum_f32m4 = __riscv_vfadd_vv_f32m4_tu(sum_f32m4, sum_f32m4, data_f32m4, vector_length);
|
|
2809
|
+
sumsq_f32m4 = __riscv_vfmacc_vv_f32m4_tu(sumsq_f32m4, data_f32m4, data_f32m4, vector_length);
|
|
2810
|
+
}
|
|
2811
|
+
|
|
2812
|
+
// Horizontal reduction
|
|
2813
|
+
vfloat32m1_t zero_f32m1 = __riscv_vfmv_v_f_f32m1(0.0f, 1);
|
|
2814
|
+
*sum_ptr = __riscv_vfmv_f_s_f32m1_f32(__riscv_vfredusum_vs_f32m4_f32m1(sum_f32m4, zero_f32m1, vlmax)),
|
|
2815
|
+
*sumsq_ptr = __riscv_vfmv_f_s_f32m1_f32(__riscv_vfredusum_vs_f32m4_f32m1(sumsq_f32m4, zero_f32m1, vlmax));
|
|
2816
|
+
}
|
|
2817
|
+
|
|
2818
|
+
NK_PUBLIC void nk_reduce_moments_e5m2_rvv( //
|
|
2819
|
+
nk_e5m2_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
|
|
2820
|
+
nk_f32_t *sum_ptr, nk_f32_t *sumsq_ptr) {
|
|
2821
|
+
nk_size_t stride_elements = stride_bytes / sizeof(nk_e5m2_t);
|
|
2822
|
+
int aligned = (stride_bytes % sizeof(nk_e5m2_t) == 0);
|
|
2823
|
+
|
|
2824
|
+
if (count == 0) *sum_ptr = 0.0f, *sumsq_ptr = 0.0f;
|
|
2825
|
+
else if (!aligned) nk_reduce_moments_e5m2_serial(data_ptr, count, stride_bytes, sum_ptr, sumsq_ptr);
|
|
2826
|
+
else if (stride_elements == 1) nk_reduce_moments_e5m2_rvv_contiguous_(data_ptr, count, sum_ptr, sumsq_ptr);
|
|
2827
|
+
else nk_reduce_moments_e5m2_rvv_strided_(data_ptr, count, stride_bytes, sum_ptr, sumsq_ptr);
|
|
2828
|
+
}
|
|
2829
|
+
|
|
2830
|
+
// Min/max reduction over a contiguous e5m2 buffer, also reporting the first
// index at which each extremum occurs. Bytes are remapped through
// nk_fp8m1_to_comparable_u8m1_rvv_ so unsigned byte comparisons can order
// the FP8 values; NaN lanes are neutralized so they never win.
NK_INTERNAL void nk_reduce_minmax_e5m2_rvv_contiguous_( //
    nk_e5m2_t const *data_ptr, nk_size_t count,         //
    nk_e5m2_t *min_value_ptr, nk_size_t *min_index_ptr, //
    nk_e5m2_t *max_value_ptr, nk_size_t *max_index_ptr) {
    nk_size_t vlmax = __riscv_vsetvlmax_e8m1();
    // Per-lane running extrema in comparable space: 0xFF / 0x00 act as the
    // neutral elements for unsigned min / max respectively.
    vuint8m1_t min_u8m1 = __riscv_vmv_v_x_u8m1(0xFF, vlmax);
    vuint8m1_t max_u8m1 = __riscv_vmv_v_x_u8m1(0x00, vlmax);
    vuint64m8_t min_indices_u64m8 = __riscv_vmv_v_x_u64m8(0, vlmax);
    vuint64m8_t max_indices_u64m8 = __riscv_vmv_v_x_u64m8(0, vlmax);

    nk_size_t offset = 0;
    for (nk_size_t vector_length; count > 0;
         count -= vector_length, offset += vector_length, data_ptr += vector_length) {
        vector_length = __riscv_vsetvl_e8m1(count);
        vuint8m1_t raw_u8m1 = __riscv_vle8_v_u8m1((uint8_t const *)data_ptr, vector_length);

        vuint8m1_t comparable_u8m1 = nk_fp8m1_to_comparable_u8m1_rvv_(raw_u8m1, vector_length);
        // Global element positions for this chunk: offset + lane id.
        vuint64m8_t pos_u64m8 = __riscv_vadd_vx_u64m8(__riscv_vid_v_u64m8(vector_length), (nk_u64_t)offset,
                                                      vector_length);

        // Detect E5M2 NaN: comparable <= 0x02 (neg NaN) or comparable >= 0xFD (pos NaN)
        vbool8_t nan_low_b8 = __riscv_vmsleu_vx_u8m1_b8(comparable_u8m1, 0x02, vector_length);
        vbool8_t nan_high_b8 = __riscv_vmsgeu_vx_u8m1_b8(comparable_u8m1, 0xFD, vector_length);
        vbool8_t is_nan_b8 = __riscv_vmor_mm_b8(nan_low_b8, nan_high_b8, vector_length);
        // Replace NaN lanes with the neutral element so they can't become min/max.
        vuint8m1_t data_min_u8m1 = __riscv_vmerge_vxm_u8m1(comparable_u8m1, 0xFF, is_nan_b8, vector_length);
        vuint8m1_t data_max_u8m1 = __riscv_vmerge_vxm_u8m1(comparable_u8m1, 0x00, is_nan_b8, vector_length);

        // Strict less-than keeps the earliest position on ties within a lane;
        // `_tu` merges leave tail lanes of the accumulators undisturbed.
        vbool8_t less_b8 = __riscv_vmsltu_vv_u8m1_b8(data_min_u8m1, min_u8m1, vector_length);
        min_u8m1 = __riscv_vmerge_vvm_u8m1_tu(min_u8m1, min_u8m1, data_min_u8m1, less_b8, vector_length);
        min_indices_u64m8 = __riscv_vmerge_vvm_u64m8_tu(min_indices_u64m8, min_indices_u64m8, pos_u64m8, less_b8,
                                                        vector_length);

        vbool8_t greater_b8 = __riscv_vmsltu_vv_u8m1_b8(max_u8m1, data_max_u8m1, vector_length);
        max_u8m1 = __riscv_vmerge_vvm_u8m1_tu(max_u8m1, max_u8m1, data_max_u8m1, greater_b8, vector_length);
        max_indices_u64m8 = __riscv_vmerge_vvm_u64m8_tu(max_indices_u64m8, max_indices_u64m8, pos_u64m8, greater_b8,
                                                        vector_length);
    }

    // Horizontal reduction + convert back
    vuint8m1_t init_max_u8m1 = __riscv_vmv_v_x_u8m1(0xFF, 1);
    nk_u8_t min_comparable = __riscv_vmv_x_s_u8m1_u8(__riscv_vredminu_vs_u8m1_u8m1(min_u8m1, init_max_u8m1, vlmax));

    // All-NaN case: min never left its 0xFF neutral value, so no usable
    // element was observed; report sentinels and invalid indices.
    if (min_comparable == 0xFF) {
        *min_value_ptr = (nk_e5m2_t)NK_E5M2_MAX, *min_index_ptr = NK_SIZE_MAX;
        *max_value_ptr = (nk_e5m2_t)NK_E5M2_MIN, *max_index_ptr = NK_SIZE_MAX;
        return;
    }

    // Among lanes holding the global min, select the smallest stored index
    // (non-matching lanes are masked to the NK_U64_MAX sentinel first).
    vbool8_t min_match_b8 = __riscv_vmseq_vx_u8m1_b8(min_u8m1, min_comparable, vlmax);
    vuint64m8_t sentinel_u64m8 = __riscv_vmv_v_x_u64m8(NK_U64_MAX, vlmax);
    vuint64m8_t min_cands_u64m8 = __riscv_vmerge_vvm_u64m8(sentinel_u64m8, min_indices_u64m8, min_match_b8, vlmax);
    vuint64m1_t init_umax_u64m1 = __riscv_vmv_v_x_u64m1(NK_U64_MAX, 1);
    *min_index_ptr = (nk_size_t)__riscv_vmv_x_s_u64m1_u64(
        __riscv_vredminu_vs_u64m8_u64m1(min_cands_u64m8, init_umax_u64m1, vlmax));

    // Map the winning comparable byte back to an e5m2 bit pattern.
    vuint8m1_t min_vec_u8m1 = __riscv_vmv_v_x_u8m1(min_comparable, 1);
    vuint8m1_t min_raw_u8m1 = nk_comparable_to_fp8m1_rvv_(min_vec_u8m1, 1);
    *min_value_ptr = (nk_e5m2_t)__riscv_vmv_x_s_u8m1_u8(min_raw_u8m1);

    // Same scheme for the maximum.
    vuint8m1_t init_min_u8m1 = __riscv_vmv_v_x_u8m1(0x00, 1);
    nk_u8_t max_comparable = __riscv_vmv_x_s_u8m1_u8(__riscv_vredmaxu_vs_u8m1_u8m1(max_u8m1, init_min_u8m1, vlmax));
    vbool8_t max_match_b8 = __riscv_vmseq_vx_u8m1_b8(max_u8m1, max_comparable, vlmax);
    vuint64m8_t max_cands_u64m8 = __riscv_vmerge_vvm_u64m8(sentinel_u64m8, max_indices_u64m8, max_match_b8, vlmax);
    *max_index_ptr = (nk_size_t)__riscv_vmv_x_s_u64m1_u64(
        __riscv_vredminu_vs_u64m8_u64m1(max_cands_u64m8, init_umax_u64m1, vlmax));

    vuint8m1_t max_vec_u8m1 = __riscv_vmv_v_x_u8m1(max_comparable, 1);
    vuint8m1_t max_raw_u8m1 = nk_comparable_to_fp8m1_rvv_(max_vec_u8m1, 1);
    *max_value_ptr = (nk_e5m2_t)__riscv_vmv_x_s_u8m1_u8(max_raw_u8m1);
}
|
|
2901
|
+
|
|
2902
|
+
// Strided variant of the e5m2 min/max reduction: identical algorithm to the
// contiguous kernel, but elements are gathered with a strided byte load
// (`vlse8`) every `stride_bytes` and the cursor advances in raw bytes.
NK_INTERNAL void nk_reduce_minmax_e5m2_rvv_strided_(                    //
    nk_e5m2_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
    nk_e5m2_t *min_value_ptr, nk_size_t *min_index_ptr,                 //
    nk_e5m2_t *max_value_ptr, nk_size_t *max_index_ptr) {
    nk_size_t vlmax = __riscv_vsetvlmax_e8m1();
    // Per-lane running extrema in comparable space: 0xFF / 0x00 are the
    // neutral elements for unsigned min / max respectively.
    vuint8m1_t min_u8m1 = __riscv_vmv_v_x_u8m1(0xFF, vlmax);
    vuint8m1_t max_u8m1 = __riscv_vmv_v_x_u8m1(0x00, vlmax);
    vuint64m8_t min_indices_u64m8 = __riscv_vmv_v_x_u64m8(0, vlmax);
    vuint64m8_t max_indices_u64m8 = __riscv_vmv_v_x_u64m8(0, vlmax);
    unsigned char const *ptr = (unsigned char const *)data_ptr;

    nk_size_t offset = 0;
    for (nk_size_t vector_length; count > 0;
         count -= vector_length, offset += vector_length, ptr += vector_length * stride_bytes) {
        vector_length = __riscv_vsetvl_e8m1(count);
        vuint8m1_t raw_u8m1 = __riscv_vlse8_v_u8m1((uint8_t const *)ptr, (nk_ssize_t)stride_bytes, vector_length);

        vuint8m1_t comparable_u8m1 = nk_fp8m1_to_comparable_u8m1_rvv_(raw_u8m1, vector_length);
        // Global element positions for this chunk: offset + lane id.
        vuint64m8_t pos_u64m8 = __riscv_vadd_vx_u64m8(__riscv_vid_v_u64m8(vector_length), (nk_u64_t)offset,
                                                      vector_length);

        // Detect E5M2 NaN: comparable <= 0x02 (neg NaN) or comparable >= 0xFD (pos NaN)
        vbool8_t nan_low_b8 = __riscv_vmsleu_vx_u8m1_b8(comparable_u8m1, 0x02, vector_length);
        vbool8_t nan_high_b8 = __riscv_vmsgeu_vx_u8m1_b8(comparable_u8m1, 0xFD, vector_length);
        vbool8_t is_nan_b8 = __riscv_vmor_mm_b8(nan_low_b8, nan_high_b8, vector_length);
        // Replace NaN lanes with the neutral element so they can't become min/max.
        vuint8m1_t data_min_u8m1 = __riscv_vmerge_vxm_u8m1(comparable_u8m1, 0xFF, is_nan_b8, vector_length);
        vuint8m1_t data_max_u8m1 = __riscv_vmerge_vxm_u8m1(comparable_u8m1, 0x00, is_nan_b8, vector_length);

        // Strict less-than keeps the earliest position on ties within a lane.
        vbool8_t less_b8 = __riscv_vmsltu_vv_u8m1_b8(data_min_u8m1, min_u8m1, vector_length);
        min_u8m1 = __riscv_vmerge_vvm_u8m1_tu(min_u8m1, min_u8m1, data_min_u8m1, less_b8, vector_length);
        min_indices_u64m8 = __riscv_vmerge_vvm_u64m8_tu(min_indices_u64m8, min_indices_u64m8, pos_u64m8, less_b8,
                                                        vector_length);

        vbool8_t greater_b8 = __riscv_vmsltu_vv_u8m1_b8(max_u8m1, data_max_u8m1, vector_length);
        max_u8m1 = __riscv_vmerge_vvm_u8m1_tu(max_u8m1, max_u8m1, data_max_u8m1, greater_b8, vector_length);
        max_indices_u64m8 = __riscv_vmerge_vvm_u64m8_tu(max_indices_u64m8, max_indices_u64m8, pos_u64m8, greater_b8,
                                                        vector_length);
    }

    // Horizontal reduction (same as contiguous)
    vuint8m1_t init_max_u8m1 = __riscv_vmv_v_x_u8m1(0xFF, 1);
    nk_u8_t min_comparable = __riscv_vmv_x_s_u8m1_u8(__riscv_vredminu_vs_u8m1_u8m1(min_u8m1, init_max_u8m1, vlmax));

    // All-NaN case: min never left its 0xFF neutral value.
    if (min_comparable == 0xFF) {
        *min_value_ptr = (nk_e5m2_t)NK_E5M2_MAX, *min_index_ptr = NK_SIZE_MAX;
        *max_value_ptr = (nk_e5m2_t)NK_E5M2_MIN, *max_index_ptr = NK_SIZE_MAX;
        return;
    }

    // Among lanes holding the global min, select the smallest stored index
    // (non-matching lanes are masked to the NK_U64_MAX sentinel first).
    vbool8_t min_match_b8 = __riscv_vmseq_vx_u8m1_b8(min_u8m1, min_comparable, vlmax);
    vuint64m8_t sentinel_u64m8 = __riscv_vmv_v_x_u64m8(NK_U64_MAX, vlmax);
    vuint64m8_t min_cands_u64m8 = __riscv_vmerge_vvm_u64m8(sentinel_u64m8, min_indices_u64m8, min_match_b8, vlmax);
    vuint64m1_t init_umax_u64m1 = __riscv_vmv_v_x_u64m1(NK_U64_MAX, 1);
    *min_index_ptr = (nk_size_t)__riscv_vmv_x_s_u64m1_u64(
        __riscv_vredminu_vs_u64m8_u64m1(min_cands_u64m8, init_umax_u64m1, vlmax));

    // Map the winning comparable byte back to an e5m2 bit pattern.
    vuint8m1_t min_vec_u8m1 = __riscv_vmv_v_x_u8m1(min_comparable, 1);
    vuint8m1_t min_raw_u8m1 = nk_comparable_to_fp8m1_rvv_(min_vec_u8m1, 1);
    *min_value_ptr = (nk_e5m2_t)__riscv_vmv_x_s_u8m1_u8(min_raw_u8m1);

    // Same scheme for the maximum.
    vuint8m1_t init_min_u8m1 = __riscv_vmv_v_x_u8m1(0x00, 1);
    nk_u8_t max_comparable = __riscv_vmv_x_s_u8m1_u8(__riscv_vredmaxu_vs_u8m1_u8m1(max_u8m1, init_min_u8m1, vlmax));
    vbool8_t max_match_b8 = __riscv_vmseq_vx_u8m1_b8(max_u8m1, max_comparable, vlmax);
    vuint64m8_t max_cands_u64m8 = __riscv_vmerge_vvm_u64m8(sentinel_u64m8, max_indices_u64m8, max_match_b8, vlmax);
    *max_index_ptr = (nk_size_t)__riscv_vmv_x_s_u64m1_u64(
        __riscv_vredminu_vs_u64m8_u64m1(max_cands_u64m8, init_umax_u64m1, vlmax));

    vuint8m1_t max_vec_u8m1 = __riscv_vmv_v_x_u8m1(max_comparable, 1);
    vuint8m1_t max_raw_u8m1 = nk_comparable_to_fp8m1_rvv_(max_vec_u8m1, 1);
    *max_value_ptr = (nk_e5m2_t)__riscv_vmv_x_s_u8m1_u8(max_raw_u8m1);
}
|
|
2974
|
+
|
|
2975
|
+
NK_PUBLIC void nk_reduce_minmax_e5m2_rvv( //
|
|
2976
|
+
nk_e5m2_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
|
|
2977
|
+
nk_e5m2_t *min_value_ptr, nk_size_t *min_index_ptr, //
|
|
2978
|
+
nk_e5m2_t *max_value_ptr, nk_size_t *max_index_ptr) {
|
|
2979
|
+
nk_size_t stride_elements = stride_bytes / sizeof(nk_e5m2_t);
|
|
2980
|
+
int aligned = (stride_bytes % sizeof(nk_e5m2_t) == 0);
|
|
2981
|
+
|
|
2982
|
+
if (count == 0)
|
|
2983
|
+
*min_value_ptr = NK_E5M2_MAX, *min_index_ptr = NK_SIZE_MAX, *max_value_ptr = NK_E5M2_MIN,
|
|
2984
|
+
*max_index_ptr = NK_SIZE_MAX;
|
|
2985
|
+
else if (!aligned)
|
|
2986
|
+
nk_reduce_minmax_e5m2_serial(data_ptr, count, stride_bytes, min_value_ptr, min_index_ptr, max_value_ptr,
|
|
2987
|
+
max_index_ptr);
|
|
2988
|
+
else if (stride_elements == 1)
|
|
2989
|
+
nk_reduce_minmax_e5m2_rvv_contiguous_(data_ptr, count, min_value_ptr, min_index_ptr, max_value_ptr,
|
|
2990
|
+
max_index_ptr);
|
|
2991
|
+
else
|
|
2992
|
+
nk_reduce_minmax_e5m2_rvv_strided_(data_ptr, count, stride_bytes, min_value_ptr, min_index_ptr, max_value_ptr,
|
|
2993
|
+
max_index_ptr);
|
|
2994
|
+
}
|
|
2995
|
+
|
|
2996
|
+
NK_INTERNAL void nk_reduce_moments_e2m3_rvv_contiguous_( //
|
|
2997
|
+
nk_e2m3_t const *data_ptr, nk_size_t count, //
|
|
2998
|
+
nk_f32_t *sum_ptr, nk_f32_t *sumsq_ptr) {
|
|
2999
|
+
nk_size_t vlmax = __riscv_vsetvlmax_e32m4();
|
|
3000
|
+
vfloat32m4_t sum_f32m4 = __riscv_vfmv_v_f_f32m4(0.0f, vlmax);
|
|
3001
|
+
vfloat32m4_t sumsq_f32m4 = __riscv_vfmv_v_f_f32m4(0.0f, vlmax);
|
|
3002
|
+
|
|
3003
|
+
for (nk_size_t vector_length; count > 0; count -= vector_length, data_ptr += vector_length) {
|
|
3004
|
+
vector_length = __riscv_vsetvl_e8m1(count);
|
|
3005
|
+
vuint8m1_t data_u8m1 = __riscv_vle8_v_u8m1((uint8_t const *)data_ptr, vector_length);
|
|
3006
|
+
|
|
3007
|
+
// Convert e2m3 → f32 (m1 → m4)
|
|
3008
|
+
vfloat32m4_t data_f32m4 = nk_e2m3m1_to_f32m4_rvv_(data_u8m1, vector_length);
|
|
3009
|
+
|
|
3010
|
+
// Accumulate at f32 precision
|
|
3011
|
+
sum_f32m4 = __riscv_vfadd_vv_f32m4_tu(sum_f32m4, sum_f32m4, data_f32m4, vector_length);
|
|
3012
|
+
sumsq_f32m4 = __riscv_vfmacc_vv_f32m4_tu(sumsq_f32m4, data_f32m4, data_f32m4, vector_length);
|
|
3013
|
+
}
|
|
3014
|
+
|
|
3015
|
+
// Horizontal reduction
|
|
3016
|
+
vfloat32m1_t zero_f32m1 = __riscv_vfmv_v_f_f32m1(0.0f, 1);
|
|
3017
|
+
*sum_ptr = __riscv_vfmv_f_s_f32m1_f32(__riscv_vfredusum_vs_f32m4_f32m1(sum_f32m4, zero_f32m1, vlmax)),
|
|
3018
|
+
*sumsq_ptr = __riscv_vfmv_f_s_f32m1_f32(__riscv_vfredusum_vs_f32m4_f32m1(sumsq_f32m4, zero_f32m1, vlmax));
|
|
3019
|
+
}
|
|
3020
|
+
|
|
3021
|
+
NK_INTERNAL void nk_reduce_moments_e2m3_rvv_strided_( //
|
|
3022
|
+
nk_e2m3_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
|
|
3023
|
+
nk_f32_t *sum_ptr, nk_f32_t *sumsq_ptr) {
|
|
3024
|
+
nk_size_t vlmax = __riscv_vsetvlmax_e32m4();
|
|
3025
|
+
vfloat32m4_t sum_f32m4 = __riscv_vfmv_v_f_f32m4(0.0f, vlmax);
|
|
3026
|
+
vfloat32m4_t sumsq_f32m4 = __riscv_vfmv_v_f_f32m4(0.0f, vlmax);
|
|
3027
|
+
unsigned char const *ptr = (unsigned char const *)data_ptr;
|
|
3028
|
+
|
|
3029
|
+
for (nk_size_t vector_length; count > 0; count -= vector_length, ptr += vector_length * stride_bytes) {
|
|
3030
|
+
vector_length = __riscv_vsetvl_e8m1(count);
|
|
3031
|
+
vuint8m1_t data_u8m1 = __riscv_vlse8_v_u8m1((uint8_t const *)ptr, (nk_ssize_t)stride_bytes, vector_length);
|
|
3032
|
+
|
|
3033
|
+
// Convert e2m3 → f32 (m1 → m4)
|
|
3034
|
+
vfloat32m4_t data_f32m4 = nk_e2m3m1_to_f32m4_rvv_(data_u8m1, vector_length);
|
|
3035
|
+
|
|
3036
|
+
// Accumulate at f32 precision
|
|
3037
|
+
sum_f32m4 = __riscv_vfadd_vv_f32m4_tu(sum_f32m4, sum_f32m4, data_f32m4, vector_length);
|
|
3038
|
+
sumsq_f32m4 = __riscv_vfmacc_vv_f32m4_tu(sumsq_f32m4, data_f32m4, data_f32m4, vector_length);
|
|
3039
|
+
}
|
|
3040
|
+
|
|
3041
|
+
// Horizontal reduction
|
|
3042
|
+
vfloat32m1_t zero_f32m1 = __riscv_vfmv_v_f_f32m1(0.0f, 1);
|
|
3043
|
+
*sum_ptr = __riscv_vfmv_f_s_f32m1_f32(__riscv_vfredusum_vs_f32m4_f32m1(sum_f32m4, zero_f32m1, vlmax)),
|
|
3044
|
+
*sumsq_ptr = __riscv_vfmv_f_s_f32m1_f32(__riscv_vfredusum_vs_f32m4_f32m1(sumsq_f32m4, zero_f32m1, vlmax));
|
|
3045
|
+
}
|
|
3046
|
+
|
|
3047
|
+
NK_PUBLIC void nk_reduce_moments_e2m3_rvv( //
|
|
3048
|
+
nk_e2m3_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
|
|
3049
|
+
nk_f32_t *sum_ptr, nk_f32_t *sumsq_ptr) {
|
|
3050
|
+
nk_size_t stride_elements = stride_bytes / sizeof(nk_e2m3_t);
|
|
3051
|
+
int aligned = (stride_bytes % sizeof(nk_e2m3_t) == 0);
|
|
3052
|
+
|
|
3053
|
+
if (count == 0) *sum_ptr = 0.0f, *sumsq_ptr = 0.0f;
|
|
3054
|
+
else if (!aligned) nk_reduce_moments_e2m3_serial(data_ptr, count, stride_bytes, sum_ptr, sumsq_ptr);
|
|
3055
|
+
else if (stride_elements == 1) nk_reduce_moments_e2m3_rvv_contiguous_(data_ptr, count, sum_ptr, sumsq_ptr);
|
|
3056
|
+
else nk_reduce_moments_e2m3_rvv_strided_(data_ptr, count, stride_bytes, sum_ptr, sumsq_ptr);
|
|
3057
|
+
}
|
|
3058
|
+
|
|
3059
|
+
// Min/max reduction over a contiguous e2m3 (FP6) buffer, also reporting the
// first index of each extremum. Bytes are remapped through
// nk_fp6m1_to_comparable_u8m1_rvv_ so unsigned byte comparisons order the
// values; comparables occupy 0x00..0x3F. NOTE(review): unlike the FP8
// kernels there is no NaN filtering here — presumably E2M3 has no NaN
// encoding; confirm against the format definition.
NK_INTERNAL void nk_reduce_minmax_e2m3_rvv_contiguous_( //
    nk_e2m3_t const *data_ptr, nk_size_t count,         //
    nk_e2m3_t *min_value_ptr, nk_size_t *min_index_ptr, //
    nk_e2m3_t *max_value_ptr, nk_size_t *max_index_ptr) {
    nk_size_t vlmax = __riscv_vsetvlmax_e8m1();
    vuint8m1_t min_u8m1 = __riscv_vmv_v_x_u8m1(0x3F, vlmax); // Largest FP6 comparable
    vuint8m1_t max_u8m1 = __riscv_vmv_v_x_u8m1(0x00, vlmax); // Smallest FP6 comparable
    vuint64m8_t min_indices_u64m8 = __riscv_vmv_v_x_u64m8(0, vlmax);
    vuint64m8_t max_indices_u64m8 = __riscv_vmv_v_x_u64m8(0, vlmax);

    nk_size_t offset = 0;
    for (nk_size_t vector_length; count > 0;
         count -= vector_length, offset += vector_length, data_ptr += vector_length) {
        vector_length = __riscv_vsetvl_e8m1(count);
        vuint8m1_t raw_u8m1 = __riscv_vle8_v_u8m1((uint8_t const *)data_ptr, vector_length);

        // Convert to FP6 comparable form
        vuint8m1_t comparable_u8m1 = nk_fp6m1_to_comparable_u8m1_rvv_(raw_u8m1, vector_length);
        // Global element positions for this chunk: offset + lane id.
        vuint64m8_t pos_u64m8 = __riscv_vadd_vx_u64m8(__riscv_vid_v_u64m8(vector_length), (nk_u64_t)offset,
                                                      vector_length);

        // Strict less-than keeps the earliest position on ties within a lane;
        // `_tu` merges leave tail lanes of the accumulators undisturbed.
        vbool8_t less_b8 = __riscv_vmsltu_vv_u8m1_b8(comparable_u8m1, min_u8m1, vector_length);
        min_u8m1 = __riscv_vmerge_vvm_u8m1_tu(min_u8m1, min_u8m1, comparable_u8m1, less_b8, vector_length);
        min_indices_u64m8 = __riscv_vmerge_vvm_u64m8_tu(min_indices_u64m8, min_indices_u64m8, pos_u64m8, less_b8,
                                                        vector_length);

        vbool8_t greater_b8 = __riscv_vmsltu_vv_u8m1_b8(max_u8m1, comparable_u8m1, vector_length);
        max_u8m1 = __riscv_vmerge_vvm_u8m1_tu(max_u8m1, max_u8m1, comparable_u8m1, greater_b8, vector_length);
        max_indices_u64m8 = __riscv_vmerge_vvm_u64m8_tu(max_indices_u64m8, max_indices_u64m8, pos_u64m8, greater_b8,
                                                        vector_length);
    }

    // Horizontal reduction + convert back
    vuint8m1_t init_max_u8m1 = __riscv_vmv_v_x_u8m1(0x3F, 1);
    nk_u8_t min_comparable = __riscv_vmv_x_s_u8m1_u8(__riscv_vredminu_vs_u8m1_u8m1(min_u8m1, init_max_u8m1, vlmax));
    // Among lanes holding the global min, select the smallest stored index
    // (non-matching lanes are masked to the NK_U64_MAX sentinel first).
    vbool8_t min_match_b8 = __riscv_vmseq_vx_u8m1_b8(min_u8m1, min_comparable, vlmax);
    vuint64m8_t sentinel_u64m8 = __riscv_vmv_v_x_u64m8(NK_U64_MAX, vlmax);
    vuint64m8_t min_cands_u64m8 = __riscv_vmerge_vvm_u64m8(sentinel_u64m8, min_indices_u64m8, min_match_b8, vlmax);
    vuint64m1_t init_umax_u64m1 = __riscv_vmv_v_x_u64m1(NK_U64_MAX, 1);
    *min_index_ptr = (nk_size_t)__riscv_vmv_x_s_u64m1_u64(
        __riscv_vredminu_vs_u64m8_u64m1(min_cands_u64m8, init_umax_u64m1, vlmax));

    // Map the winning comparable byte back to an e2m3 bit pattern.
    vuint8m1_t min_vec_u8m1 = __riscv_vmv_v_x_u8m1(min_comparable, 1);
    vuint8m1_t min_raw_u8m1 = nk_comparable_to_fp6m1_rvv_(min_vec_u8m1, 1);
    *min_value_ptr = (nk_e2m3_t)__riscv_vmv_x_s_u8m1_u8(min_raw_u8m1);

    // Same scheme for the maximum.
    vuint8m1_t init_min_u8m1 = __riscv_vmv_v_x_u8m1(0x00, 1);
    nk_u8_t max_comparable = __riscv_vmv_x_s_u8m1_u8(__riscv_vredmaxu_vs_u8m1_u8m1(max_u8m1, init_min_u8m1, vlmax));
    vbool8_t max_match_b8 = __riscv_vmseq_vx_u8m1_b8(max_u8m1, max_comparable, vlmax);
    vuint64m8_t max_cands_u64m8 = __riscv_vmerge_vvm_u64m8(sentinel_u64m8, max_indices_u64m8, max_match_b8, vlmax);
    *max_index_ptr = (nk_size_t)__riscv_vmv_x_s_u64m1_u64(
        __riscv_vredminu_vs_u64m8_u64m1(max_cands_u64m8, init_umax_u64m1, vlmax));

    vuint8m1_t max_vec_u8m1 = __riscv_vmv_v_x_u8m1(max_comparable, 1);
    vuint8m1_t max_raw_u8m1 = nk_comparable_to_fp6m1_rvv_(max_vec_u8m1, 1);
    *max_value_ptr = (nk_e2m3_t)__riscv_vmv_x_s_u8m1_u8(max_raw_u8m1);
}
|
|
3116
|
+
|
|
3117
|
+
// Strided variant of the e2m3 min/max reduction: identical algorithm to the
// contiguous kernel, but elements are gathered with a strided byte load
// (`vlse8`) every `stride_bytes` and the cursor advances in raw bytes.
NK_INTERNAL void nk_reduce_minmax_e2m3_rvv_strided_(                    //
    nk_e2m3_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
    nk_e2m3_t *min_value_ptr, nk_size_t *min_index_ptr,                 //
    nk_e2m3_t *max_value_ptr, nk_size_t *max_index_ptr) {
    nk_size_t vlmax = __riscv_vsetvlmax_e8m1();
    // Per-lane running extrema: 0x3F / 0x00 are the largest / smallest FP6
    // comparables, the neutral elements for unsigned min / max.
    vuint8m1_t min_u8m1 = __riscv_vmv_v_x_u8m1(0x3F, vlmax);
    vuint8m1_t max_u8m1 = __riscv_vmv_v_x_u8m1(0x00, vlmax);
    vuint64m8_t min_indices_u64m8 = __riscv_vmv_v_x_u64m8(0, vlmax);
    vuint64m8_t max_indices_u64m8 = __riscv_vmv_v_x_u64m8(0, vlmax);
    unsigned char const *ptr = (unsigned char const *)data_ptr;

    nk_size_t offset = 0;
    for (nk_size_t vector_length; count > 0;
         count -= vector_length, offset += vector_length, ptr += vector_length * stride_bytes) {
        vector_length = __riscv_vsetvl_e8m1(count);
        vuint8m1_t raw_u8m1 = __riscv_vlse8_v_u8m1((uint8_t const *)ptr, (nk_ssize_t)stride_bytes, vector_length);

        vuint8m1_t comparable_u8m1 = nk_fp6m1_to_comparable_u8m1_rvv_(raw_u8m1, vector_length);
        // Global element positions for this chunk: offset + lane id.
        vuint64m8_t pos_u64m8 = __riscv_vadd_vx_u64m8(__riscv_vid_v_u64m8(vector_length), (nk_u64_t)offset,
                                                      vector_length);

        // Strict less-than keeps the earliest position on ties within a lane.
        vbool8_t less_b8 = __riscv_vmsltu_vv_u8m1_b8(comparable_u8m1, min_u8m1, vector_length);
        min_u8m1 = __riscv_vmerge_vvm_u8m1_tu(min_u8m1, min_u8m1, comparable_u8m1, less_b8, vector_length);
        min_indices_u64m8 = __riscv_vmerge_vvm_u64m8_tu(min_indices_u64m8, min_indices_u64m8, pos_u64m8, less_b8,
                                                        vector_length);

        vbool8_t greater_b8 = __riscv_vmsltu_vv_u8m1_b8(max_u8m1, comparable_u8m1, vector_length);
        max_u8m1 = __riscv_vmerge_vvm_u8m1_tu(max_u8m1, max_u8m1, comparable_u8m1, greater_b8, vector_length);
        max_indices_u64m8 = __riscv_vmerge_vvm_u64m8_tu(max_indices_u64m8, max_indices_u64m8, pos_u64m8, greater_b8,
                                                        vector_length);
    }

    // Horizontal reduction (same as contiguous)
    vuint8m1_t init_max_u8m1 = __riscv_vmv_v_x_u8m1(0x3F, 1);
    nk_u8_t min_comparable = __riscv_vmv_x_s_u8m1_u8(__riscv_vredminu_vs_u8m1_u8m1(min_u8m1, init_max_u8m1, vlmax));
    // Among lanes holding the global min, select the smallest stored index
    // (non-matching lanes are masked to the NK_U64_MAX sentinel first).
    vbool8_t min_match_b8 = __riscv_vmseq_vx_u8m1_b8(min_u8m1, min_comparable, vlmax);
    vuint64m8_t sentinel_u64m8 = __riscv_vmv_v_x_u64m8(NK_U64_MAX, vlmax);
    vuint64m8_t min_cands_u64m8 = __riscv_vmerge_vvm_u64m8(sentinel_u64m8, min_indices_u64m8, min_match_b8, vlmax);
    vuint64m1_t init_umax_u64m1 = __riscv_vmv_v_x_u64m1(NK_U64_MAX, 1);
    *min_index_ptr = (nk_size_t)__riscv_vmv_x_s_u64m1_u64(
        __riscv_vredminu_vs_u64m8_u64m1(min_cands_u64m8, init_umax_u64m1, vlmax));

    // Map the winning comparable byte back to an e2m3 bit pattern.
    vuint8m1_t min_vec_u8m1 = __riscv_vmv_v_x_u8m1(min_comparable, 1);
    vuint8m1_t min_raw_u8m1 = nk_comparable_to_fp6m1_rvv_(min_vec_u8m1, 1);
    *min_value_ptr = (nk_e2m3_t)__riscv_vmv_x_s_u8m1_u8(min_raw_u8m1);

    // Same scheme for the maximum.
    vuint8m1_t init_min_u8m1 = __riscv_vmv_v_x_u8m1(0x00, 1);
    nk_u8_t max_comparable = __riscv_vmv_x_s_u8m1_u8(__riscv_vredmaxu_vs_u8m1_u8m1(max_u8m1, init_min_u8m1, vlmax));
    vbool8_t max_match_b8 = __riscv_vmseq_vx_u8m1_b8(max_u8m1, max_comparable, vlmax);
    vuint64m8_t max_cands_u64m8 = __riscv_vmerge_vvm_u64m8(sentinel_u64m8, max_indices_u64m8, max_match_b8, vlmax);
    *max_index_ptr = (nk_size_t)__riscv_vmv_x_s_u64m1_u64(
        __riscv_vredminu_vs_u64m8_u64m1(max_cands_u64m8, init_umax_u64m1, vlmax));

    vuint8m1_t max_vec_u8m1 = __riscv_vmv_v_x_u8m1(max_comparable, 1);
    vuint8m1_t max_raw_u8m1 = nk_comparable_to_fp6m1_rvv_(max_vec_u8m1, 1);
    *max_value_ptr = (nk_e2m3_t)__riscv_vmv_x_s_u8m1_u8(max_raw_u8m1);
}
|
|
3174
|
+
|
|
3175
|
+
NK_PUBLIC void nk_reduce_minmax_e2m3_rvv( //
|
|
3176
|
+
nk_e2m3_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
|
|
3177
|
+
nk_e2m3_t *min_value_ptr, nk_size_t *min_index_ptr, //
|
|
3178
|
+
nk_e2m3_t *max_value_ptr, nk_size_t *max_index_ptr) {
|
|
3179
|
+
nk_size_t stride_elements = stride_bytes / sizeof(nk_e2m3_t);
|
|
3180
|
+
int aligned = (stride_bytes % sizeof(nk_e2m3_t) == 0);
|
|
3181
|
+
|
|
3182
|
+
if (count == 0)
|
|
3183
|
+
*min_value_ptr = NK_E2M3_MAX, *min_index_ptr = NK_SIZE_MAX, *max_value_ptr = NK_E2M3_MIN,
|
|
3184
|
+
*max_index_ptr = NK_SIZE_MAX;
|
|
3185
|
+
else if (!aligned)
|
|
3186
|
+
nk_reduce_minmax_e2m3_serial(data_ptr, count, stride_bytes, min_value_ptr, min_index_ptr, max_value_ptr,
|
|
3187
|
+
max_index_ptr);
|
|
3188
|
+
else if (stride_elements == 1)
|
|
3189
|
+
nk_reduce_minmax_e2m3_rvv_contiguous_(data_ptr, count, min_value_ptr, min_index_ptr, max_value_ptr,
|
|
3190
|
+
max_index_ptr);
|
|
3191
|
+
else
|
|
3192
|
+
nk_reduce_minmax_e2m3_rvv_strided_(data_ptr, count, stride_bytes, min_value_ptr, min_index_ptr, max_value_ptr,
|
|
3193
|
+
max_index_ptr);
|
|
3194
|
+
}
|
|
3195
|
+
|
|
3196
|
+
/// Computes sum and sum-of-squares of `count` contiguous e3m2 values,
/// accumulating in f32. Loads bytes at LMUL=1, widens to f32 at LMUL=4,
/// and keeps per-lane accumulators that are folded once at the end.
NK_INTERNAL void nk_reduce_moments_e3m2_rvv_contiguous_( //
    nk_e3m2_t const *data_ptr, nk_size_t count, //
    nk_f32_t *sum_ptr, nk_f32_t *sumsq_ptr) {
    // Accumulators are sized to the full m4 register group and zeroed once.
    nk_size_t vlmax = __riscv_vsetvlmax_e32m4();
    vfloat32m4_t sum_f32m4 = __riscv_vfmv_v_f_f32m4(0.0f, vlmax);
    vfloat32m4_t sumsq_f32m4 = __riscv_vfmv_v_f_f32m4(0.0f, vlmax);

    for (nk_size_t vector_length; count > 0; count -= vector_length, data_ptr += vector_length) {
        vector_length = __riscv_vsetvl_e8m1(count);
        vuint8m1_t data_u8m1 = __riscv_vle8_v_u8m1((uint8_t const *)data_ptr, vector_length);

        // Convert e3m2 → f32 (m1 → m4)
        vfloat32m4_t data_f32m4 = nk_e3m2m1_to_f32m4_rvv_(data_u8m1, vector_length);

        // Accumulate at f32 precision. The `_tu` (tail-undisturbed) forms keep
        // accumulator lanes past `vector_length` intact on the final short pass,
        // so the full-width reduction below stays correct.
        sum_f32m4 = __riscv_vfadd_vv_f32m4_tu(sum_f32m4, sum_f32m4, data_f32m4, vector_length);
        sumsq_f32m4 = __riscv_vfmacc_vv_f32m4_tu(sumsq_f32m4, data_f32m4, data_f32m4, vector_length);
    }

    // Horizontal reduction over all vlmax lanes of both accumulators.
    vfloat32m1_t zero_f32m1 = __riscv_vfmv_v_f_f32m1(0.0f, 1);
    *sum_ptr = __riscv_vfmv_f_s_f32m1_f32(__riscv_vfredusum_vs_f32m4_f32m1(sum_f32m4, zero_f32m1, vlmax)),
    *sumsq_ptr = __riscv_vfmv_f_s_f32m1_f32(__riscv_vfredusum_vs_f32m4_f32m1(sumsq_f32m4, zero_f32m1, vlmax));
}
|
|
3220
|
+
|
|
3221
|
+
/// Strided variant of the e3m2 moments kernel: identical accumulation scheme
/// to the contiguous path, but elements are gathered with a strided byte load
/// (`vlse8`) and the cursor advances by `vector_length * stride_bytes`.
NK_INTERNAL void nk_reduce_moments_e3m2_rvv_strided_( //
    nk_e3m2_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
    nk_f32_t *sum_ptr, nk_f32_t *sumsq_ptr) {
    nk_size_t vlmax = __riscv_vsetvlmax_e32m4();
    vfloat32m4_t sum_f32m4 = __riscv_vfmv_v_f_f32m4(0.0f, vlmax);
    vfloat32m4_t sumsq_f32m4 = __riscv_vfmv_v_f_f32m4(0.0f, vlmax);
    // Byte-granular cursor: strides are expressed in bytes, not elements.
    unsigned char const *ptr = (unsigned char const *)data_ptr;

    for (nk_size_t vector_length; count > 0; count -= vector_length, ptr += vector_length * stride_bytes) {
        vector_length = __riscv_vsetvl_e8m1(count);
        vuint8m1_t data_u8m1 = __riscv_vlse8_v_u8m1((uint8_t const *)ptr, (nk_ssize_t)stride_bytes, vector_length);

        // Convert e3m2 → f32 (m1 → m4)
        vfloat32m4_t data_f32m4 = nk_e3m2m1_to_f32m4_rvv_(data_u8m1, vector_length);

        // Accumulate at f32 precision; `_tu` preserves tail lanes so the
        // full-vlmax reduction below sees zeros in never-touched lanes.
        sum_f32m4 = __riscv_vfadd_vv_f32m4_tu(sum_f32m4, sum_f32m4, data_f32m4, vector_length);
        sumsq_f32m4 = __riscv_vfmacc_vv_f32m4_tu(sumsq_f32m4, data_f32m4, data_f32m4, vector_length);
    }

    // Horizontal reduction
    vfloat32m1_t zero_f32m1 = __riscv_vfmv_v_f_f32m1(0.0f, 1);
    *sum_ptr = __riscv_vfmv_f_s_f32m1_f32(__riscv_vfredusum_vs_f32m4_f32m1(sum_f32m4, zero_f32m1, vlmax)),
    *sumsq_ptr = __riscv_vfmv_f_s_f32m1_f32(__riscv_vfredusum_vs_f32m4_f32m1(sumsq_f32m4, zero_f32m1, vlmax));
}
|
|
3246
|
+
|
|
3247
|
+
NK_PUBLIC void nk_reduce_moments_e3m2_rvv( //
|
|
3248
|
+
nk_e3m2_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
|
|
3249
|
+
nk_f32_t *sum_ptr, nk_f32_t *sumsq_ptr) {
|
|
3250
|
+
nk_size_t stride_elements = stride_bytes / sizeof(nk_e3m2_t);
|
|
3251
|
+
int aligned = (stride_bytes % sizeof(nk_e3m2_t) == 0);
|
|
3252
|
+
|
|
3253
|
+
if (count == 0) *sum_ptr = 0.0f, *sumsq_ptr = 0.0f;
|
|
3254
|
+
else if (!aligned) nk_reduce_moments_e3m2_serial(data_ptr, count, stride_bytes, sum_ptr, sumsq_ptr);
|
|
3255
|
+
else if (stride_elements == 1) nk_reduce_moments_e3m2_rvv_contiguous_(data_ptr, count, sum_ptr, sumsq_ptr);
|
|
3256
|
+
else nk_reduce_moments_e3m2_rvv_strided_(data_ptr, count, stride_bytes, sum_ptr, sumsq_ptr);
|
|
3257
|
+
}
|
|
3258
|
+
|
|
3259
|
+
/// Finds the min and max of `count` contiguous e3m2 values along with the
/// index of each extreme's first occurrence. Raw fp6 bytes are remapped to a
/// monotonic unsigned "comparable" key so plain unsigned compares order them
/// correctly; winners are mapped back to raw encoding at the end.
/// Per-lane running extremes are kept in u8m1, their source indices in u64m8.
NK_INTERNAL void nk_reduce_minmax_e3m2_rvv_contiguous_( //
    nk_e3m2_t const *data_ptr, nk_size_t count, //
    nk_e3m2_t *min_value_ptr, nk_size_t *min_index_ptr, //
    nk_e3m2_t *max_value_ptr, nk_size_t *max_index_ptr) {
    nk_size_t vlmax = __riscv_vsetvlmax_e8m1();
    // 0x3F is the largest 6-bit comparable key, 0x00 the smallest — identity
    // elements for the min and max lanes respectively.
    vuint8m1_t min_u8m1 = __riscv_vmv_v_x_u8m1(0x3F, vlmax);
    vuint8m1_t max_u8m1 = __riscv_vmv_v_x_u8m1(0x00, vlmax);
    vuint64m8_t min_indices_u64m8 = __riscv_vmv_v_x_u64m8(0, vlmax);
    vuint64m8_t max_indices_u64m8 = __riscv_vmv_v_x_u64m8(0, vlmax);

    nk_size_t offset = 0;
    for (nk_size_t vector_length; count > 0;
         count -= vector_length, offset += vector_length, data_ptr += vector_length) {
        vector_length = __riscv_vsetvl_e8m1(count);
        vuint8m1_t raw_u8m1 = __riscv_vle8_v_u8m1((uint8_t const *)data_ptr, vector_length);

        // Remap raw fp6 bytes to order-preserving unsigned keys.
        vuint8m1_t comparable_u8m1 = nk_fp6m1_to_comparable_u8m1_rvv_(raw_u8m1, vector_length);
        // Global element positions for this pass: vid (0..vl-1) + running offset.
        vuint64m8_t pos_u64m8 = __riscv_vadd_vx_u64m8(__riscv_vid_v_u64m8(vector_length), (nk_u64_t)offset,
                                                      vector_length);

        // Strict less-than keeps the EARLIEST index on ties; `_tu` merges leave
        // tail lanes of the accumulators untouched on short final passes.
        vbool8_t less_b8 = __riscv_vmsltu_vv_u8m1_b8(comparable_u8m1, min_u8m1, vector_length);
        min_u8m1 = __riscv_vmerge_vvm_u8m1_tu(min_u8m1, min_u8m1, comparable_u8m1, less_b8, vector_length);
        min_indices_u64m8 = __riscv_vmerge_vvm_u64m8_tu(min_indices_u64m8, min_indices_u64m8, pos_u64m8, less_b8,
                                                        vector_length);

        vbool8_t greater_b8 = __riscv_vmsltu_vv_u8m1_b8(max_u8m1, comparable_u8m1, vector_length);
        max_u8m1 = __riscv_vmerge_vvm_u8m1_tu(max_u8m1, max_u8m1, comparable_u8m1, greater_b8, vector_length);
        max_indices_u64m8 = __riscv_vmerge_vvm_u64m8_tu(max_indices_u64m8, max_indices_u64m8, pos_u64m8, greater_b8,
                                                        vector_length);
    }

    // Horizontal reduction + convert back.
    // Step 1: reduce comparable keys to the scalar minimum.
    vuint8m1_t init_max_u8m1 = __riscv_vmv_v_x_u8m1(0x3F, 1);
    nk_u8_t min_comparable = __riscv_vmv_x_s_u8m1_u8(__riscv_vredminu_vs_u8m1_u8m1(min_u8m1, init_max_u8m1, vlmax));
    // Step 2: among lanes holding that minimum, pick the smallest recorded
    // index; non-matching lanes are masked out with a U64_MAX sentinel.
    vbool8_t min_match_b8 = __riscv_vmseq_vx_u8m1_b8(min_u8m1, min_comparable, vlmax);
    vuint64m8_t sentinel_u64m8 = __riscv_vmv_v_x_u64m8(NK_U64_MAX, vlmax);
    vuint64m8_t min_cands_u64m8 = __riscv_vmerge_vvm_u64m8(sentinel_u64m8, min_indices_u64m8, min_match_b8, vlmax);
    vuint64m1_t init_umax_u64m1 = __riscv_vmv_v_x_u64m1(NK_U64_MAX, 1);
    *min_index_ptr = (nk_size_t)__riscv_vmv_x_s_u64m1_u64(
        __riscv_vredminu_vs_u64m8_u64m1(min_cands_u64m8, init_umax_u64m1, vlmax));

    // Step 3: map the winning comparable key back to the raw fp6 encoding.
    vuint8m1_t min_vec_u8m1 = __riscv_vmv_v_x_u8m1(min_comparable, 1);
    vuint8m1_t min_raw_u8m1 = nk_comparable_to_fp6m1_rvv_(min_vec_u8m1, 1);
    *min_value_ptr = (nk_e3m2_t)__riscv_vmv_x_s_u8m1_u8(min_raw_u8m1);

    // Same three steps for the maximum.
    vuint8m1_t init_min_u8m1 = __riscv_vmv_v_x_u8m1(0x00, 1);
    nk_u8_t max_comparable = __riscv_vmv_x_s_u8m1_u8(__riscv_vredmaxu_vs_u8m1_u8m1(max_u8m1, init_min_u8m1, vlmax));
    vbool8_t max_match_b8 = __riscv_vmseq_vx_u8m1_b8(max_u8m1, max_comparable, vlmax);
    vuint64m8_t max_cands_u64m8 = __riscv_vmerge_vvm_u64m8(sentinel_u64m8, max_indices_u64m8, max_match_b8, vlmax);
    *max_index_ptr = (nk_size_t)__riscv_vmv_x_s_u64m1_u64(
        __riscv_vredminu_vs_u64m8_u64m1(max_cands_u64m8, init_umax_u64m1, vlmax));

    vuint8m1_t max_vec_u8m1 = __riscv_vmv_v_x_u8m1(max_comparable, 1);
    vuint8m1_t max_raw_u8m1 = nk_comparable_to_fp6m1_rvv_(max_vec_u8m1, 1);
    *max_value_ptr = (nk_e3m2_t)__riscv_vmv_x_s_u8m1_u8(max_raw_u8m1);
}
|
|
3315
|
+
|
|
3316
|
+
/// Strided variant of the e3m2 min/max kernel: identical comparable-key
/// scheme to the contiguous path, but loads with a strided byte gather
/// (`vlse8`) and advances a byte cursor by `vector_length * stride_bytes`.
NK_INTERNAL void nk_reduce_minmax_e3m2_rvv_strided_( //
    nk_e3m2_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
    nk_e3m2_t *min_value_ptr, nk_size_t *min_index_ptr, //
    nk_e3m2_t *max_value_ptr, nk_size_t *max_index_ptr) {
    nk_size_t vlmax = __riscv_vsetvlmax_e8m1();
    // Identity elements in comparable space: 0x3F = largest key, 0x00 = smallest.
    vuint8m1_t min_u8m1 = __riscv_vmv_v_x_u8m1(0x3F, vlmax);
    vuint8m1_t max_u8m1 = __riscv_vmv_v_x_u8m1(0x00, vlmax);
    vuint64m8_t min_indices_u64m8 = __riscv_vmv_v_x_u64m8(0, vlmax);
    vuint64m8_t max_indices_u64m8 = __riscv_vmv_v_x_u64m8(0, vlmax);
    unsigned char const *ptr = (unsigned char const *)data_ptr;

    nk_size_t offset = 0;
    for (nk_size_t vector_length; count > 0;
         count -= vector_length, offset += vector_length, ptr += vector_length * stride_bytes) {
        vector_length = __riscv_vsetvl_e8m1(count);
        vuint8m1_t raw_u8m1 = __riscv_vlse8_v_u8m1((uint8_t const *)ptr, (nk_ssize_t)stride_bytes, vector_length);

        // Remap raw fp6 bytes to order-preserving unsigned keys.
        vuint8m1_t comparable_u8m1 = nk_fp6m1_to_comparable_u8m1_rvv_(raw_u8m1, vector_length);
        // Global element positions: vid (0..vl-1) + running element offset.
        vuint64m8_t pos_u64m8 = __riscv_vadd_vx_u64m8(__riscv_vid_v_u64m8(vector_length), (nk_u64_t)offset,
                                                      vector_length);

        // Strict compares keep the earliest index on ties; `_tu` merges keep
        // accumulator tail lanes intact on the final short pass.
        vbool8_t less_b8 = __riscv_vmsltu_vv_u8m1_b8(comparable_u8m1, min_u8m1, vector_length);
        min_u8m1 = __riscv_vmerge_vvm_u8m1_tu(min_u8m1, min_u8m1, comparable_u8m1, less_b8, vector_length);
        min_indices_u64m8 = __riscv_vmerge_vvm_u64m8_tu(min_indices_u64m8, min_indices_u64m8, pos_u64m8, less_b8,
                                                        vector_length);

        vbool8_t greater_b8 = __riscv_vmsltu_vv_u8m1_b8(max_u8m1, comparable_u8m1, vector_length);
        max_u8m1 = __riscv_vmerge_vvm_u8m1_tu(max_u8m1, max_u8m1, comparable_u8m1, greater_b8, vector_length);
        max_indices_u64m8 = __riscv_vmerge_vvm_u64m8_tu(max_indices_u64m8, max_indices_u64m8, pos_u64m8, greater_b8,
                                                        vector_length);
    }

    // Horizontal reduction (same as contiguous): reduce keys to a scalar,
    // select the smallest index among matching lanes, convert key back to raw.
    vuint8m1_t init_max_u8m1 = __riscv_vmv_v_x_u8m1(0x3F, 1);
    nk_u8_t min_comparable = __riscv_vmv_x_s_u8m1_u8(__riscv_vredminu_vs_u8m1_u8m1(min_u8m1, init_max_u8m1, vlmax));
    vbool8_t min_match_b8 = __riscv_vmseq_vx_u8m1_b8(min_u8m1, min_comparable, vlmax);
    vuint64m8_t sentinel_u64m8 = __riscv_vmv_v_x_u64m8(NK_U64_MAX, vlmax);
    vuint64m8_t min_cands_u64m8 = __riscv_vmerge_vvm_u64m8(sentinel_u64m8, min_indices_u64m8, min_match_b8, vlmax);
    vuint64m1_t init_umax_u64m1 = __riscv_vmv_v_x_u64m1(NK_U64_MAX, 1);
    *min_index_ptr = (nk_size_t)__riscv_vmv_x_s_u64m1_u64(
        __riscv_vredminu_vs_u64m8_u64m1(min_cands_u64m8, init_umax_u64m1, vlmax));

    vuint8m1_t min_vec_u8m1 = __riscv_vmv_v_x_u8m1(min_comparable, 1);
    vuint8m1_t min_raw_u8m1 = nk_comparable_to_fp6m1_rvv_(min_vec_u8m1, 1);
    *min_value_ptr = (nk_e3m2_t)__riscv_vmv_x_s_u8m1_u8(min_raw_u8m1);

    vuint8m1_t init_min_u8m1 = __riscv_vmv_v_x_u8m1(0x00, 1);
    nk_u8_t max_comparable = __riscv_vmv_x_s_u8m1_u8(__riscv_vredmaxu_vs_u8m1_u8m1(max_u8m1, init_min_u8m1, vlmax));
    vbool8_t max_match_b8 = __riscv_vmseq_vx_u8m1_b8(max_u8m1, max_comparable, vlmax);
    vuint64m8_t max_cands_u64m8 = __riscv_vmerge_vvm_u64m8(sentinel_u64m8, max_indices_u64m8, max_match_b8, vlmax);
    *max_index_ptr = (nk_size_t)__riscv_vmv_x_s_u64m1_u64(
        __riscv_vredminu_vs_u64m8_u64m1(max_cands_u64m8, init_umax_u64m1, vlmax));

    vuint8m1_t max_vec_u8m1 = __riscv_vmv_v_x_u8m1(max_comparable, 1);
    vuint8m1_t max_raw_u8m1 = nk_comparable_to_fp6m1_rvv_(max_vec_u8m1, 1);
    *max_value_ptr = (nk_e3m2_t)__riscv_vmv_x_s_u8m1_u8(max_raw_u8m1);
}
|
|
3373
|
+
|
|
3374
|
+
NK_PUBLIC void nk_reduce_minmax_e3m2_rvv( //
|
|
3375
|
+
nk_e3m2_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
|
|
3376
|
+
nk_e3m2_t *min_value_ptr, nk_size_t *min_index_ptr, //
|
|
3377
|
+
nk_e3m2_t *max_value_ptr, nk_size_t *max_index_ptr) {
|
|
3378
|
+
nk_size_t stride_elements = stride_bytes / sizeof(nk_e3m2_t);
|
|
3379
|
+
int aligned = (stride_bytes % sizeof(nk_e3m2_t) == 0);
|
|
3380
|
+
|
|
3381
|
+
if (count == 0)
|
|
3382
|
+
*min_value_ptr = NK_E3M2_MAX, *min_index_ptr = NK_SIZE_MAX, *max_value_ptr = NK_E3M2_MIN,
|
|
3383
|
+
*max_index_ptr = NK_SIZE_MAX;
|
|
3384
|
+
else if (!aligned)
|
|
3385
|
+
nk_reduce_minmax_e3m2_serial(data_ptr, count, stride_bytes, min_value_ptr, min_index_ptr, max_value_ptr,
|
|
3386
|
+
max_index_ptr);
|
|
3387
|
+
else if (stride_elements == 1)
|
|
3388
|
+
nk_reduce_minmax_e3m2_rvv_contiguous_(data_ptr, count, min_value_ptr, min_index_ptr, max_value_ptr,
|
|
3389
|
+
max_index_ptr);
|
|
3390
|
+
else
|
|
3391
|
+
nk_reduce_minmax_e3m2_rvv_strided_(data_ptr, count, stride_bytes, min_value_ptr, min_index_ptr, max_value_ptr,
|
|
3392
|
+
max_index_ptr);
|
|
3393
|
+
}
|
|
3394
|
+
|
|
3395
|
+
#if defined(__clang__)
|
|
3396
|
+
#pragma clang attribute pop
|
|
3397
|
+
#elif defined(__GNUC__)
|
|
3398
|
+
#pragma GCC pop_options
|
|
3399
|
+
#endif
|
|
3400
|
+
|
|
3401
|
+
#if defined(__cplusplus)
|
|
3402
|
+
} // extern "C"
|
|
3403
|
+
#endif
|
|
3404
|
+
|
|
3405
|
+
#endif // NK_TARGET_RVV
|
|
3406
|
+
#endif // NK_TARGET_RISCV_
|
|
3407
|
+
#endif // NK_REDUCE_RVV_H
|