numkong 7.0.0 → 7.4.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +239 -122
- package/binding.gyp +25 -491
- package/c/dispatch_bf16.c +59 -1
- package/c/dispatch_e2m3.c +41 -8
- package/c/dispatch_e3m2.c +49 -8
- package/c/dispatch_e4m3.c +51 -9
- package/c/dispatch_e5m2.c +45 -1
- package/c/dispatch_f16.c +79 -26
- package/c/dispatch_f16c.c +5 -5
- package/c/dispatch_f32.c +56 -0
- package/c/dispatch_f64.c +52 -0
- package/c/dispatch_i4.c +3 -0
- package/c/dispatch_i8.c +62 -3
- package/c/dispatch_other.c +18 -0
- package/c/dispatch_u1.c +54 -9
- package/c/dispatch_u4.c +3 -0
- package/c/dispatch_u8.c +64 -3
- package/c/numkong.c +3 -0
- package/include/README.md +79 -9
- package/include/numkong/attention/sapphireamx.h +278 -276
- package/include/numkong/attention/sme.h +983 -977
- package/include/numkong/attention.h +1 -1
- package/include/numkong/capabilities.h +289 -94
- package/include/numkong/cast/README.md +40 -40
- package/include/numkong/cast/diamond.h +64 -0
- package/include/numkong/cast/haswell.h +42 -194
- package/include/numkong/cast/icelake.h +42 -37
- package/include/numkong/cast/loongsonasx.h +252 -0
- package/include/numkong/cast/neon.h +216 -249
- package/include/numkong/cast/powervsx.h +449 -0
- package/include/numkong/cast/rvv.h +223 -274
- package/include/numkong/cast/sapphire.h +18 -18
- package/include/numkong/cast/serial.h +1018 -944
- package/include/numkong/cast/skylake.h +82 -23
- package/include/numkong/cast/v128relaxed.h +462 -105
- package/include/numkong/cast.h +24 -0
- package/include/numkong/cast.hpp +44 -0
- package/include/numkong/curved/README.md +17 -17
- package/include/numkong/curved/neon.h +131 -7
- package/include/numkong/curved/neonbfdot.h +6 -7
- package/include/numkong/curved/rvv.h +26 -26
- package/include/numkong/curved/smef64.h +186 -182
- package/include/numkong/curved.h +14 -18
- package/include/numkong/dot/README.md +154 -137
- package/include/numkong/dot/alder.h +43 -43
- package/include/numkong/dot/diamond.h +158 -0
- package/include/numkong/dot/genoa.h +4 -30
- package/include/numkong/dot/haswell.h +215 -180
- package/include/numkong/dot/icelake.h +190 -76
- package/include/numkong/dot/loongsonasx.h +671 -0
- package/include/numkong/dot/neon.h +124 -73
- package/include/numkong/dot/neonbfdot.h +11 -12
- package/include/numkong/dot/neonfhm.h +44 -46
- package/include/numkong/dot/neonfp8.h +323 -0
- package/include/numkong/dot/neonsdot.h +190 -76
- package/include/numkong/dot/powervsx.h +752 -0
- package/include/numkong/dot/rvv.h +92 -84
- package/include/numkong/dot/rvvbf16.h +12 -12
- package/include/numkong/dot/rvvhalf.h +12 -12
- package/include/numkong/dot/sapphire.h +4 -4
- package/include/numkong/dot/serial.h +66 -30
- package/include/numkong/dot/sierra.h +31 -31
- package/include/numkong/dot/skylake.h +142 -110
- package/include/numkong/dot/sve.h +217 -177
- package/include/numkong/dot/svebfdot.h +10 -10
- package/include/numkong/dot/svehalf.h +85 -41
- package/include/numkong/dot/svesdot.h +89 -0
- package/include/numkong/dot/v128relaxed.h +124 -89
- package/include/numkong/dot.h +114 -48
- package/include/numkong/dots/README.md +203 -203
- package/include/numkong/dots/alder.h +12 -9
- package/include/numkong/dots/diamond.h +86 -0
- package/include/numkong/dots/genoa.h +10 -4
- package/include/numkong/dots/haswell.h +63 -48
- package/include/numkong/dots/icelake.h +27 -18
- package/include/numkong/dots/loongsonasx.h +176 -0
- package/include/numkong/dots/neon.h +14 -11
- package/include/numkong/dots/neonbfdot.h +4 -3
- package/include/numkong/dots/neonfhm.h +11 -9
- package/include/numkong/dots/neonfp8.h +99 -0
- package/include/numkong/dots/neonsdot.h +48 -12
- package/include/numkong/dots/powervsx.h +194 -0
- package/include/numkong/dots/rvv.h +451 -344
- package/include/numkong/dots/sapphireamx.h +1028 -984
- package/include/numkong/dots/serial.h +213 -197
- package/include/numkong/dots/sierra.h +10 -7
- package/include/numkong/dots/skylake.h +47 -36
- package/include/numkong/dots/sme.h +2001 -2364
- package/include/numkong/dots/smebi32.h +175 -162
- package/include/numkong/dots/smef64.h +328 -323
- package/include/numkong/dots/v128relaxed.h +64 -41
- package/include/numkong/dots.h +573 -293
- package/include/numkong/dots.hpp +45 -43
- package/include/numkong/each/README.md +133 -137
- package/include/numkong/each/haswell.h +6 -6
- package/include/numkong/each/icelake.h +7 -7
- package/include/numkong/each/neon.h +76 -42
- package/include/numkong/each/neonbfdot.h +11 -12
- package/include/numkong/each/neonhalf.h +24 -116
- package/include/numkong/each/rvv.h +28 -28
- package/include/numkong/each/sapphire.h +27 -161
- package/include/numkong/each/serial.h +6 -6
- package/include/numkong/each/skylake.h +7 -7
- package/include/numkong/each/v128relaxed.h +562 -0
- package/include/numkong/each.h +148 -62
- package/include/numkong/each.hpp +2 -2
- package/include/numkong/geospatial/README.md +18 -18
- package/include/numkong/geospatial/haswell.h +365 -325
- package/include/numkong/geospatial/neon.h +350 -306
- package/include/numkong/geospatial/rvv.h +4 -4
- package/include/numkong/geospatial/skylake.h +376 -340
- package/include/numkong/geospatial/v128relaxed.h +366 -327
- package/include/numkong/geospatial.h +17 -17
- package/include/numkong/matrix.hpp +4 -4
- package/include/numkong/maxsim/README.md +14 -14
- package/include/numkong/maxsim/alder.h +6 -6
- package/include/numkong/maxsim/genoa.h +4 -4
- package/include/numkong/maxsim/haswell.h +6 -6
- package/include/numkong/maxsim/icelake.h +18 -18
- package/include/numkong/maxsim/neonsdot.h +21 -21
- package/include/numkong/maxsim/sapphireamx.h +14 -14
- package/include/numkong/maxsim/serial.h +6 -6
- package/include/numkong/maxsim/sme.h +221 -196
- package/include/numkong/maxsim/v128relaxed.h +6 -6
- package/include/numkong/mesh/README.md +62 -56
- package/include/numkong/mesh/haswell.h +339 -464
- package/include/numkong/mesh/neon.h +1100 -519
- package/include/numkong/mesh/neonbfdot.h +36 -68
- package/include/numkong/mesh/rvv.h +530 -435
- package/include/numkong/mesh/serial.h +75 -91
- package/include/numkong/mesh/skylake.h +1627 -302
- package/include/numkong/mesh/v128relaxed.h +443 -330
- package/include/numkong/mesh.h +63 -49
- package/include/numkong/mesh.hpp +4 -4
- package/include/numkong/numkong.h +3 -3
- package/include/numkong/numkong.hpp +1 -0
- package/include/numkong/probability/README.md +23 -19
- package/include/numkong/probability/neon.h +82 -52
- package/include/numkong/probability/rvv.h +28 -23
- package/include/numkong/probability/serial.h +51 -39
- package/include/numkong/probability.h +20 -23
- package/include/numkong/random.h +1 -1
- package/include/numkong/reduce/README.md +143 -138
- package/include/numkong/reduce/alder.h +81 -77
- package/include/numkong/reduce/haswell.h +222 -220
- package/include/numkong/reduce/neon.h +629 -519
- package/include/numkong/reduce/neonbfdot.h +7 -218
- package/include/numkong/reduce/neonfhm.h +9 -381
- package/include/numkong/reduce/neonsdot.h +9 -9
- package/include/numkong/reduce/rvv.h +928 -802
- package/include/numkong/reduce/serial.h +23 -27
- package/include/numkong/reduce/sierra.h +20 -20
- package/include/numkong/reduce/skylake.h +326 -324
- package/include/numkong/reduce/v128relaxed.h +52 -52
- package/include/numkong/reduce.h +4 -23
- package/include/numkong/reduce.hpp +156 -11
- package/include/numkong/scalar/README.md +6 -6
- package/include/numkong/scalar/haswell.h +26 -17
- package/include/numkong/scalar/loongsonasx.h +74 -0
- package/include/numkong/scalar/neon.h +9 -9
- package/include/numkong/scalar/powervsx.h +96 -0
- package/include/numkong/scalar/rvv.h +2 -2
- package/include/numkong/scalar/sapphire.h +21 -10
- package/include/numkong/scalar/serial.h +21 -21
- package/include/numkong/scalar.h +13 -0
- package/include/numkong/set/README.md +28 -28
- package/include/numkong/set/haswell.h +12 -12
- package/include/numkong/set/icelake.h +14 -14
- package/include/numkong/set/loongsonasx.h +181 -0
- package/include/numkong/set/neon.h +17 -18
- package/include/numkong/set/powervsx.h +326 -0
- package/include/numkong/set/rvv.h +4 -4
- package/include/numkong/set/serial.h +6 -6
- package/include/numkong/set/sve.h +60 -59
- package/include/numkong/set/v128relaxed.h +6 -6
- package/include/numkong/set.h +21 -7
- package/include/numkong/sets/README.md +26 -26
- package/include/numkong/sets/loongsonasx.h +52 -0
- package/include/numkong/sets/powervsx.h +65 -0
- package/include/numkong/sets/smebi32.h +395 -364
- package/include/numkong/sets.h +83 -40
- package/include/numkong/sparse/README.md +4 -4
- package/include/numkong/sparse/icelake.h +101 -101
- package/include/numkong/sparse/serial.h +1 -1
- package/include/numkong/sparse/sve2.h +137 -141
- package/include/numkong/sparse/turin.h +12 -12
- package/include/numkong/sparse.h +10 -10
- package/include/numkong/spatial/README.md +230 -226
- package/include/numkong/spatial/alder.h +113 -116
- package/include/numkong/spatial/diamond.h +240 -0
- package/include/numkong/spatial/genoa.h +0 -68
- package/include/numkong/spatial/haswell.h +74 -55
- package/include/numkong/spatial/icelake.h +539 -58
- package/include/numkong/spatial/loongsonasx.h +483 -0
- package/include/numkong/spatial/neon.h +125 -52
- package/include/numkong/spatial/neonbfdot.h +8 -9
- package/include/numkong/spatial/neonfp8.h +258 -0
- package/include/numkong/spatial/neonsdot.h +180 -12
- package/include/numkong/spatial/powervsx.h +738 -0
- package/include/numkong/spatial/rvv.h +146 -139
- package/include/numkong/spatial/rvvbf16.h +17 -12
- package/include/numkong/spatial/rvvhalf.h +13 -10
- package/include/numkong/spatial/serial.h +13 -12
- package/include/numkong/spatial/sierra.h +232 -39
- package/include/numkong/spatial/skylake.h +73 -74
- package/include/numkong/spatial/sve.h +93 -72
- package/include/numkong/spatial/svebfdot.h +29 -29
- package/include/numkong/spatial/svehalf.h +52 -26
- package/include/numkong/spatial/svesdot.h +142 -0
- package/include/numkong/spatial/v128relaxed.h +293 -41
- package/include/numkong/spatial.h +338 -82
- package/include/numkong/spatials/README.md +194 -194
- package/include/numkong/spatials/diamond.h +82 -0
- package/include/numkong/spatials/haswell.h +2 -2
- package/include/numkong/spatials/loongsonasx.h +153 -0
- package/include/numkong/spatials/neonfp8.h +111 -0
- package/include/numkong/spatials/neonsdot.h +34 -0
- package/include/numkong/spatials/powervsx.h +153 -0
- package/include/numkong/spatials/rvv.h +259 -243
- package/include/numkong/spatials/sapphireamx.h +173 -173
- package/include/numkong/spatials/serial.h +2 -2
- package/include/numkong/spatials/skylake.h +2 -2
- package/include/numkong/spatials/sme.h +590 -605
- package/include/numkong/spatials/smef64.h +139 -130
- package/include/numkong/spatials/v128relaxed.h +2 -2
- package/include/numkong/spatials.h +820 -500
- package/include/numkong/spatials.hpp +49 -48
- package/include/numkong/tensor.hpp +406 -17
- package/include/numkong/trigonometry/README.md +19 -19
- package/include/numkong/trigonometry/haswell.h +402 -401
- package/include/numkong/trigonometry/neon.h +386 -387
- package/include/numkong/trigonometry/rvv.h +52 -51
- package/include/numkong/trigonometry/serial.h +13 -13
- package/include/numkong/trigonometry/skylake.h +373 -369
- package/include/numkong/trigonometry/v128relaxed.h +375 -374
- package/include/numkong/trigonometry.h +13 -13
- package/include/numkong/trigonometry.hpp +2 -2
- package/include/numkong/types.h +287 -49
- package/include/numkong/types.hpp +436 -12
- package/include/numkong/vector.hpp +82 -14
- package/javascript/dist/cjs/numkong-wasm.js +6 -12
- package/javascript/dist/cjs/numkong.d.ts +7 -1
- package/javascript/dist/cjs/numkong.js +37 -11
- package/javascript/dist/cjs/types.d.ts +9 -0
- package/javascript/dist/cjs/types.js +96 -0
- package/javascript/dist/esm/numkong-browser.d.ts +14 -0
- package/javascript/dist/esm/numkong-browser.js +23 -0
- package/javascript/dist/esm/numkong-wasm.js +6 -12
- package/javascript/dist/esm/numkong.d.ts +7 -1
- package/javascript/dist/esm/numkong.js +37 -11
- package/javascript/dist/esm/types.d.ts +9 -0
- package/javascript/dist/esm/types.js +96 -0
- package/javascript/node-gyp-build.d.ts +4 -1
- package/javascript/numkong-browser.ts +40 -0
- package/javascript/numkong-wasm.ts +7 -13
- package/javascript/numkong.c +5 -26
- package/javascript/numkong.ts +36 -11
- package/javascript/tsconfig-base.json +1 -0
- package/javascript/tsconfig-cjs.json +6 -1
- package/javascript/types.ts +110 -0
- package/numkong.gypi +101 -0
- package/package.json +34 -13
- package/probes/arm_neon.c +8 -0
- package/probes/arm_neon_bfdot.c +9 -0
- package/probes/arm_neon_fhm.c +9 -0
- package/probes/arm_neon_half.c +8 -0
- package/probes/arm_neon_sdot.c +9 -0
- package/probes/arm_neonfp8.c +9 -0
- package/probes/arm_sme.c +16 -0
- package/probes/arm_sme2.c +16 -0
- package/probes/arm_sme2p1.c +16 -0
- package/probes/arm_sme_bf16.c +16 -0
- package/probes/arm_sme_bi32.c +16 -0
- package/probes/arm_sme_f64.c +16 -0
- package/probes/arm_sme_fa64.c +14 -0
- package/probes/arm_sme_half.c +16 -0
- package/probes/arm_sme_lut2.c +15 -0
- package/probes/arm_sve.c +18 -0
- package/probes/arm_sve2.c +20 -0
- package/probes/arm_sve2p1.c +18 -0
- package/probes/arm_sve_bfdot.c +20 -0
- package/probes/arm_sve_half.c +18 -0
- package/probes/arm_sve_sdot.c +21 -0
- package/probes/loongarch_lasx.c +12 -0
- package/probes/power_vsx.c +12 -0
- package/probes/probe.js +127 -0
- package/probes/riscv_rvv.c +14 -0
- package/probes/riscv_rvv_bb.c +15 -0
- package/probes/riscv_rvv_bf16.c +17 -0
- package/probes/riscv_rvv_half.c +14 -0
- package/probes/wasm_v128relaxed.c +11 -0
- package/probes/x86_alder.c +17 -0
- package/probes/x86_diamond.c +17 -0
- package/probes/x86_genoa.c +17 -0
- package/probes/x86_graniteamx.c +19 -0
- package/probes/x86_haswell.c +11 -0
- package/probes/x86_icelake.c +17 -0
- package/probes/x86_sapphire.c +16 -0
- package/probes/x86_sapphireamx.c +18 -0
- package/probes/x86_sierra.c +17 -0
- package/probes/x86_skylake.c +15 -0
- package/probes/x86_turin.c +17 -0
- package/wasm/numkong-emscripten.js +2 -0
- package/wasm/numkong.d.ts +14 -0
- package/wasm/numkong.js +1124 -0
- package/wasm/numkong.wasm +0 -0
- package/include/numkong/curved/neonhalf.h +0 -212
- package/include/numkong/dot/neonhalf.h +0 -198
- package/include/numkong/dots/neonhalf.h +0 -57
- package/include/numkong/mesh/neonhalf.h +0 -616
- package/include/numkong/reduce/neonhalf.h +0 -157
- package/include/numkong/spatial/neonhalf.h +0 -118
- package/include/numkong/spatial/sapphire.h +0 -343
- package/include/numkong/spatials/neonhalf.h +0 -58
- package/javascript/README.md +0 -246
|
@@ -27,72 +27,74 @@
|
|
|
27
27
|
extern "C" {
|
|
28
28
|
#endif
|
|
29
29
|
|
|
30
|
-
/** @brief Saturating horizontal sum of u64m1 via tree fold: O(log
|
|
31
|
-
NK_INTERNAL nk_u64_t nk_reduce_vsaddu_u64m1_rvv_(vuint64m1_t acc_u64m1, nk_size_t
|
|
32
|
-
for (nk_size_t half =
|
|
33
|
-
vuint64m1_t shifted_u64m1 = __riscv_vslidedown_vx_u64m1(acc_u64m1, half,
|
|
34
|
-
acc_u64m1 = __riscv_vsaddu_vv_u64m1(acc_u64m1, shifted_u64m1,
|
|
30
|
+
/** @brief Saturating horizontal sum of u64m1 via tree fold: O(log vector_length) vector ops. */
|
|
31
|
+
NK_INTERNAL nk_u64_t nk_reduce_vsaddu_u64m1_rvv_(vuint64m1_t acc_u64m1, nk_size_t vector_length) {
|
|
32
|
+
for (nk_size_t half = vector_length >> 1; half > 0; half >>= 1) {
|
|
33
|
+
vuint64m1_t shifted_u64m1 = __riscv_vslidedown_vx_u64m1(acc_u64m1, half, vector_length);
|
|
34
|
+
acc_u64m1 = __riscv_vsaddu_vv_u64m1(acc_u64m1, shifted_u64m1, vector_length);
|
|
35
35
|
}
|
|
36
36
|
return __riscv_vmv_x_s_u64m1_u64(acc_u64m1);
|
|
37
37
|
}
|
|
38
38
|
|
|
39
|
-
/** @brief Saturating horizontal sum of u64m2 via tree fold: O(log
|
|
40
|
-
NK_INTERNAL nk_u64_t nk_reduce_vsaddu_u64m2_rvv_(vuint64m2_t acc_u64m2, nk_size_t
|
|
41
|
-
for (nk_size_t half =
|
|
42
|
-
vuint64m2_t shifted_u64m2 = __riscv_vslidedown_vx_u64m2(acc_u64m2, half,
|
|
43
|
-
acc_u64m2 = __riscv_vsaddu_vv_u64m2(acc_u64m2, shifted_u64m2,
|
|
39
|
+
/** @brief Saturating horizontal sum of u64m2 via tree fold: O(log vector_length) vector ops. */
|
|
40
|
+
NK_INTERNAL nk_u64_t nk_reduce_vsaddu_u64m2_rvv_(vuint64m2_t acc_u64m2, nk_size_t vector_length) {
|
|
41
|
+
for (nk_size_t half = vector_length >> 1; half > 0; half >>= 1) {
|
|
42
|
+
vuint64m2_t shifted_u64m2 = __riscv_vslidedown_vx_u64m2(acc_u64m2, half, vector_length);
|
|
43
|
+
acc_u64m2 = __riscv_vsaddu_vv_u64m2(acc_u64m2, shifted_u64m2, vector_length);
|
|
44
44
|
}
|
|
45
45
|
return __riscv_vmv_x_s_u64m2_u64(acc_u64m2);
|
|
46
46
|
}
|
|
47
47
|
|
|
48
48
|
/** @brief 128-bit horizontal sum of (upper:i64m1, lower:u64m1) via tree fold, then saturate to i64. */
|
|
49
49
|
NK_INTERNAL nk_i64_t nk_reduce_128bit_sum_i64m1_rvv_( //
|
|
50
|
-
vuint64m1_t
|
|
51
|
-
for (nk_size_t half =
|
|
52
|
-
vuint64m1_t
|
|
53
|
-
vint64m1_t
|
|
54
|
-
vuint64m1_t
|
|
55
|
-
vbool64_t carry_b64 = __riscv_vmsltu_vv_u64m1_b64(
|
|
56
|
-
vint64m1_t carry_i64m1 = __riscv_vmerge_vxm_i64m1(__riscv_vmv_v_x_i64m1(0,
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
|
|
63
|
-
nk_i64_t
|
|
64
|
-
|
|
65
|
-
|
|
50
|
+
vuint64m1_t sum_low_u64m1, vint64m1_t sum_high_i64m1, nk_size_t vector_length) {
|
|
51
|
+
for (nk_size_t half = vector_length >> 1; half > 0; half >>= 1) {
|
|
52
|
+
vuint64m1_t shifted_low_u64m1 = __riscv_vslidedown_vx_u64m1(sum_low_u64m1, half, vector_length);
|
|
53
|
+
vint64m1_t shifted_high_i64m1 = __riscv_vslidedown_vx_i64m1(sum_high_i64m1, half, vector_length);
|
|
54
|
+
vuint64m1_t new_low_u64m1 = __riscv_vadd_vv_u64m1(sum_low_u64m1, shifted_low_u64m1, vector_length);
|
|
55
|
+
vbool64_t carry_b64 = __riscv_vmsltu_vv_u64m1_b64(new_low_u64m1, sum_low_u64m1, vector_length);
|
|
56
|
+
vint64m1_t carry_i64m1 = __riscv_vmerge_vxm_i64m1(__riscv_vmv_v_x_i64m1(0, vector_length), 1, carry_b64,
|
|
57
|
+
vector_length);
|
|
58
|
+
sum_high_i64m1 = __riscv_vadd_vv_i64m1(sum_high_i64m1, shifted_high_i64m1, vector_length);
|
|
59
|
+
sum_high_i64m1 = __riscv_vadd_vv_i64m1(sum_high_i64m1, carry_i64m1, vector_length);
|
|
60
|
+
sum_low_u64m1 = new_low_u64m1;
|
|
61
|
+
}
|
|
62
|
+
nk_u64_t total_low = __riscv_vmv_x_s_u64m1_u64(sum_low_u64m1);
|
|
63
|
+
nk_i64_t total_high = __riscv_vmv_x_s_i64m1_i64(sum_high_i64m1);
|
|
64
|
+
nk_i64_t total_low_signed = (nk_i64_t)total_low;
|
|
65
|
+
if (total_high == (total_low_signed >> 63)) return total_low_signed;
|
|
66
|
+
else if (total_high >= 0) return NK_I64_MAX;
|
|
66
67
|
else return NK_I64_MIN;
|
|
67
68
|
}
|
|
68
69
|
|
|
69
70
|
/** @brief 128-bit horizontal sum of (upper:i64m2, lower:u64m2) via tree fold, then saturate to i64. */
|
|
70
71
|
NK_INTERNAL nk_i64_t nk_reduce_128bit_sum_i64m2_rvv_( //
|
|
71
|
-
vuint64m2_t
|
|
72
|
-
for (nk_size_t half =
|
|
73
|
-
vuint64m2_t
|
|
74
|
-
vint64m2_t
|
|
75
|
-
vuint64m2_t
|
|
76
|
-
vbool32_t carry_b32 = __riscv_vmsltu_vv_u64m2_b32(
|
|
77
|
-
vint64m2_t carry_i64m2 = __riscv_vmerge_vxm_i64m2(__riscv_vmv_v_x_i64m2(0,
|
|
78
|
-
|
|
79
|
-
|
|
80
|
-
|
|
81
|
-
|
|
82
|
-
|
|
83
|
-
|
|
84
|
-
nk_i64_t
|
|
85
|
-
|
|
86
|
-
|
|
72
|
+
vuint64m2_t sum_low_u64m2, vint64m2_t sum_high_i64m2, nk_size_t vector_length) {
|
|
73
|
+
for (nk_size_t half = vector_length >> 1; half > 0; half >>= 1) {
|
|
74
|
+
vuint64m2_t shifted_low_u64m2 = __riscv_vslidedown_vx_u64m2(sum_low_u64m2, half, vector_length);
|
|
75
|
+
vint64m2_t shifted_high_i64m2 = __riscv_vslidedown_vx_i64m2(sum_high_i64m2, half, vector_length);
|
|
76
|
+
vuint64m2_t new_low_u64m2 = __riscv_vadd_vv_u64m2(sum_low_u64m2, shifted_low_u64m2, vector_length);
|
|
77
|
+
vbool32_t carry_b32 = __riscv_vmsltu_vv_u64m2_b32(new_low_u64m2, sum_low_u64m2, vector_length);
|
|
78
|
+
vint64m2_t carry_i64m2 = __riscv_vmerge_vxm_i64m2(__riscv_vmv_v_x_i64m2(0, vector_length), 1, carry_b32,
|
|
79
|
+
vector_length);
|
|
80
|
+
sum_high_i64m2 = __riscv_vadd_vv_i64m2(sum_high_i64m2, shifted_high_i64m2, vector_length);
|
|
81
|
+
sum_high_i64m2 = __riscv_vadd_vv_i64m2(sum_high_i64m2, carry_i64m2, vector_length);
|
|
82
|
+
sum_low_u64m2 = new_low_u64m2;
|
|
83
|
+
}
|
|
84
|
+
nk_u64_t total_low = __riscv_vmv_x_s_u64m2_u64(sum_low_u64m2);
|
|
85
|
+
nk_i64_t total_high = __riscv_vmv_x_s_i64m2_i64(sum_high_i64m2);
|
|
86
|
+
nk_i64_t total_low_signed = (nk_i64_t)total_low;
|
|
87
|
+
if (total_high == (total_low_signed >> 63)) return total_low_signed;
|
|
88
|
+
else if (total_high >= 0) return NK_I64_MAX;
|
|
87
89
|
else return NK_I64_MIN;
|
|
88
90
|
}
|
|
89
91
|
|
|
90
92
|
NK_INTERNAL void nk_reduce_moments_f32_rvv_contiguous_( //
|
|
91
93
|
nk_f32_t const *data, nk_size_t count, //
|
|
92
94
|
nk_f64_t *sum_ptr, nk_f64_t *sumsq_ptr) {
|
|
93
|
-
nk_size_t
|
|
94
|
-
vfloat64m2_t sum_f64m2 = __riscv_vfmv_v_f_f64m2(0.0,
|
|
95
|
-
vfloat64m2_t sumsq_f64m2 = __riscv_vfmv_v_f_f64m2(0.0,
|
|
95
|
+
nk_size_t max_vector_length = __riscv_vsetvlmax_e64m2();
|
|
96
|
+
vfloat64m2_t sum_f64m2 = __riscv_vfmv_v_f_f64m2(0.0, max_vector_length);
|
|
97
|
+
vfloat64m2_t sumsq_f64m2 = __riscv_vfmv_v_f_f64m2(0.0, max_vector_length);
|
|
96
98
|
for (nk_size_t vector_length; count > 0; count -= vector_length, data += vector_length) {
|
|
97
99
|
vector_length = __riscv_vsetvl_e32m1(count);
|
|
98
100
|
vfloat32m1_t data_f32m1 = __riscv_vle32_v_f32m1(data, vector_length);
|
|
@@ -100,16 +102,16 @@ NK_INTERNAL void nk_reduce_moments_f32_rvv_contiguous_( //
|
|
|
100
102
|
sumsq_f64m2 = __riscv_vfwmacc_vv_f64m2_tu(sumsq_f64m2, data_f32m1, data_f32m1, vector_length);
|
|
101
103
|
}
|
|
102
104
|
vfloat64m1_t zero = __riscv_vfmv_v_f_f64m1(0.0, 1);
|
|
103
|
-
*sum_ptr = __riscv_vfmv_f_s_f64m1_f64(__riscv_vfredusum_vs_f64m2_f64m1(sum_f64m2, zero,
|
|
104
|
-
*sumsq_ptr = __riscv_vfmv_f_s_f64m1_f64(__riscv_vfredusum_vs_f64m2_f64m1(sumsq_f64m2, zero,
|
|
105
|
+
*sum_ptr = __riscv_vfmv_f_s_f64m1_f64(__riscv_vfredusum_vs_f64m2_f64m1(sum_f64m2, zero, max_vector_length)),
|
|
106
|
+
*sumsq_ptr = __riscv_vfmv_f_s_f64m1_f64(__riscv_vfredusum_vs_f64m2_f64m1(sumsq_f64m2, zero, max_vector_length));
|
|
105
107
|
}
|
|
106
108
|
|
|
107
109
|
NK_INTERNAL void nk_reduce_moments_f32_rvv_strided_( //
|
|
108
110
|
nk_f32_t const *data, nk_size_t count, nk_size_t stride_bytes, //
|
|
109
111
|
nk_f64_t *sum_ptr, nk_f64_t *sumsq_ptr) {
|
|
110
|
-
nk_size_t
|
|
111
|
-
vfloat64m2_t sum_f64m2 = __riscv_vfmv_v_f_f64m2(0.0,
|
|
112
|
-
vfloat64m2_t sumsq_f64m2 = __riscv_vfmv_v_f_f64m2(0.0,
|
|
112
|
+
nk_size_t max_vector_length = __riscv_vsetvlmax_e64m2();
|
|
113
|
+
vfloat64m2_t sum_f64m2 = __riscv_vfmv_v_f_f64m2(0.0, max_vector_length);
|
|
114
|
+
vfloat64m2_t sumsq_f64m2 = __riscv_vfmv_v_f_f64m2(0.0, max_vector_length);
|
|
113
115
|
unsigned char const *ptr = (unsigned char const *)data;
|
|
114
116
|
for (nk_size_t vector_length; count > 0; count -= vector_length, ptr += vector_length * stride_bytes) {
|
|
115
117
|
vector_length = __riscv_vsetvl_e32m1(count);
|
|
@@ -119,8 +121,8 @@ NK_INTERNAL void nk_reduce_moments_f32_rvv_strided_( //
|
|
|
119
121
|
sumsq_f64m2 = __riscv_vfwmacc_vv_f64m2_tu(sumsq_f64m2, data_f32m1, data_f32m1, vector_length);
|
|
120
122
|
}
|
|
121
123
|
vfloat64m1_t zero = __riscv_vfmv_v_f_f64m1(0.0, 1);
|
|
122
|
-
*sum_ptr = __riscv_vfmv_f_s_f64m1_f64(__riscv_vfredusum_vs_f64m2_f64m1(sum_f64m2, zero,
|
|
123
|
-
*sumsq_ptr = __riscv_vfmv_f_s_f64m1_f64(__riscv_vfredusum_vs_f64m2_f64m1(sumsq_f64m2, zero,
|
|
124
|
+
*sum_ptr = __riscv_vfmv_f_s_f64m1_f64(__riscv_vfredusum_vs_f64m2_f64m1(sum_f64m2, zero, max_vector_length)),
|
|
125
|
+
*sumsq_ptr = __riscv_vfmv_f_s_f64m1_f64(__riscv_vfredusum_vs_f64m2_f64m1(sumsq_f64m2, zero, max_vector_length));
|
|
124
126
|
}
|
|
125
127
|
|
|
126
128
|
NK_PUBLIC void nk_reduce_moments_f32_rvv( //
|
|
@@ -138,88 +140,92 @@ NK_INTERNAL void nk_reduce_minmax_f32_rvv_contiguous_( //
|
|
|
138
140
|
nk_f32_t const *data, nk_size_t count, //
|
|
139
141
|
nk_f32_t *min_value, nk_size_t *min_index, //
|
|
140
142
|
nk_f32_t *max_value, nk_size_t *max_index) {
|
|
141
|
-
nk_size_t
|
|
142
|
-
vfloat32m1_t min = __riscv_vfmv_v_f_f32m1(NK_F32_MAX,
|
|
143
|
-
vfloat32m1_t max = __riscv_vfmv_v_f_f32m1(NK_F32_MIN,
|
|
144
|
-
vuint64m2_t min_indices = __riscv_vmv_v_x_u64m2(0,
|
|
145
|
-
vuint64m2_t max_indices = __riscv_vmv_v_x_u64m2(0,
|
|
143
|
+
nk_size_t max_vector_length = __riscv_vsetvlmax_e32m1();
|
|
144
|
+
vfloat32m1_t min = __riscv_vfmv_v_f_f32m1(NK_F32_MAX, max_vector_length);
|
|
145
|
+
vfloat32m1_t max = __riscv_vfmv_v_f_f32m1(NK_F32_MIN, max_vector_length);
|
|
146
|
+
vuint64m2_t min_indices = __riscv_vmv_v_x_u64m2(0, max_vector_length);
|
|
147
|
+
vuint64m2_t max_indices = __riscv_vmv_v_x_u64m2(0, max_vector_length);
|
|
146
148
|
nk_size_t offset = 0;
|
|
147
|
-
for (nk_size_t remaining = count,
|
|
148
|
-
remaining -=
|
|
149
|
-
|
|
150
|
-
vfloat32m1_t data_f32m1 = __riscv_vle32_v_f32m1(data + offset,
|
|
151
|
-
vuint64m2_t position_u64m2 = __riscv_vadd_vx_u64m2(__riscv_vid_v_u64m2(
|
|
152
|
-
|
|
153
|
-
vbool32_t less_b32 = __riscv_vmflt_vv_f32m1_b32(data_f32m1, min,
|
|
154
|
-
min = __riscv_vmerge_vvm_f32m1_tu(min, min, data_f32m1, less_b32,
|
|
155
|
-
min_indices = __riscv_vmerge_vvm_u64m2_tu(min_indices, min_indices, position_u64m2, less_b32,
|
|
156
|
-
|
|
157
|
-
|
|
158
|
-
|
|
149
|
+
for (nk_size_t remaining = count, max_vector_length; remaining > 0;
|
|
150
|
+
remaining -= max_vector_length, offset += max_vector_length) {
|
|
151
|
+
max_vector_length = __riscv_vsetvl_e32m1(remaining);
|
|
152
|
+
vfloat32m1_t data_f32m1 = __riscv_vle32_v_f32m1(data + offset, max_vector_length);
|
|
153
|
+
vuint64m2_t position_u64m2 = __riscv_vadd_vx_u64m2(__riscv_vid_v_u64m2(max_vector_length), (nk_u64_t)offset,
|
|
154
|
+
max_vector_length);
|
|
155
|
+
vbool32_t less_b32 = __riscv_vmflt_vv_f32m1_b32(data_f32m1, min, max_vector_length);
|
|
156
|
+
min = __riscv_vmerge_vvm_f32m1_tu(min, min, data_f32m1, less_b32, max_vector_length);
|
|
157
|
+
min_indices = __riscv_vmerge_vvm_u64m2_tu(min_indices, min_indices, position_u64m2, less_b32,
|
|
158
|
+
max_vector_length);
|
|
159
|
+
vbool32_t greater_b32 = __riscv_vmflt_vv_f32m1_b32(max, data_f32m1, max_vector_length);
|
|
160
|
+
max = __riscv_vmerge_vvm_f32m1_tu(max, max, data_f32m1, greater_b32, max_vector_length);
|
|
161
|
+
max_indices = __riscv_vmerge_vvm_u64m2_tu(max_indices, max_indices, position_u64m2, greater_b32,
|
|
162
|
+
max_vector_length);
|
|
159
163
|
}
|
|
160
164
|
vfloat32m1_t id_max = __riscv_vfmv_v_f_f32m1(NK_F32_MAX, 1);
|
|
161
|
-
nk_f32_t mn = __riscv_vfmv_f_s_f32m1_f32(__riscv_vfredmin_vs_f32m1_f32m1(min, id_max,
|
|
165
|
+
nk_f32_t mn = __riscv_vfmv_f_s_f32m1_f32(__riscv_vfredmin_vs_f32m1_f32m1(min, id_max, max_vector_length));
|
|
162
166
|
vfloat32m1_t id_min = __riscv_vfmv_v_f_f32m1(NK_F32_MIN, 1);
|
|
163
|
-
nk_f32_t mx = __riscv_vfmv_f_s_f32m1_f32(__riscv_vfredmax_vs_f32m1_f32m1(max, id_min,
|
|
167
|
+
nk_f32_t mx = __riscv_vfmv_f_s_f32m1_f32(__riscv_vfredmax_vs_f32m1_f32m1(max, id_min, max_vector_length));
|
|
164
168
|
if (mn == NK_F32_MAX && mx == NK_F32_MIN) {
|
|
165
169
|
*min_value = NK_F32_MAX, *min_index = NK_SIZE_MAX, *max_value = NK_F32_MIN, *max_index = NK_SIZE_MAX;
|
|
166
170
|
return;
|
|
167
171
|
}
|
|
168
|
-
vbool32_t min_match_b32 = __riscv_vmfeq_vf_f32m1_b32(min, mn,
|
|
169
|
-
vuint64m2_t sentinel = __riscv_vmv_v_x_u64m2(NK_U64_MAX,
|
|
170
|
-
vuint64m2_t min_cands = __riscv_vmerge_vvm_u64m2(sentinel, min_indices, min_match_b32,
|
|
172
|
+
vbool32_t min_match_b32 = __riscv_vmfeq_vf_f32m1_b32(min, mn, max_vector_length);
|
|
173
|
+
vuint64m2_t sentinel = __riscv_vmv_v_x_u64m2(NK_U64_MAX, max_vector_length);
|
|
174
|
+
vuint64m2_t min_cands = __riscv_vmerge_vvm_u64m2(sentinel, min_indices, min_match_b32, max_vector_length);
|
|
171
175
|
vuint64m1_t id_umax = __riscv_vmv_v_x_u64m1(NK_U64_MAX, 1);
|
|
172
|
-
*min_value = mn,
|
|
173
|
-
|
|
174
|
-
vbool32_t max_match_b32 = __riscv_vmfeq_vf_f32m1_b32(max, mx,
|
|
175
|
-
vuint64m2_t max_cands = __riscv_vmerge_vvm_u64m2(sentinel, max_indices, max_match_b32,
|
|
176
|
-
*max_value = mx,
|
|
177
|
-
|
|
176
|
+
*min_value = mn, *min_index = (nk_size_t)__riscv_vmv_x_s_u64m1_u64(
|
|
177
|
+
__riscv_vredminu_vs_u64m2_u64m1(min_cands, id_umax, max_vector_length));
|
|
178
|
+
vbool32_t max_match_b32 = __riscv_vmfeq_vf_f32m1_b32(max, mx, max_vector_length);
|
|
179
|
+
vuint64m2_t max_cands = __riscv_vmerge_vvm_u64m2(sentinel, max_indices, max_match_b32, max_vector_length);
|
|
180
|
+
*max_value = mx, *max_index = (nk_size_t)__riscv_vmv_x_s_u64m1_u64(
|
|
181
|
+
__riscv_vredminu_vs_u64m2_u64m1(max_cands, id_umax, max_vector_length));
|
|
178
182
|
}
|
|
179
183
|
|
|
180
184
|
NK_INTERNAL void nk_reduce_minmax_f32_rvv_strided_( //
|
|
181
185
|
nk_f32_t const *data, nk_size_t count, nk_size_t stride_bytes, //
|
|
182
186
|
nk_f32_t *min_value, nk_size_t *min_index, //
|
|
183
187
|
nk_f32_t *max_value, nk_size_t *max_index) {
|
|
184
|
-
nk_size_t
|
|
185
|
-
vfloat32m1_t min = __riscv_vfmv_v_f_f32m1(NK_F32_MAX,
|
|
186
|
-
vfloat32m1_t max = __riscv_vfmv_v_f_f32m1(NK_F32_MIN,
|
|
187
|
-
vuint64m2_t min_indices = __riscv_vmv_v_x_u64m2(0,
|
|
188
|
-
vuint64m2_t max_indices = __riscv_vmv_v_x_u64m2(0,
|
|
188
|
+
nk_size_t max_vector_length = __riscv_vsetvlmax_e32m1();
|
|
189
|
+
vfloat32m1_t min = __riscv_vfmv_v_f_f32m1(NK_F32_MAX, max_vector_length);
|
|
190
|
+
vfloat32m1_t max = __riscv_vfmv_v_f_f32m1(NK_F32_MIN, max_vector_length);
|
|
191
|
+
vuint64m2_t min_indices = __riscv_vmv_v_x_u64m2(0, max_vector_length);
|
|
192
|
+
vuint64m2_t max_indices = __riscv_vmv_v_x_u64m2(0, max_vector_length);
|
|
189
193
|
unsigned char const *ptr = (unsigned char const *)data;
|
|
190
194
|
nk_size_t offset = 0;
|
|
191
|
-
for (nk_size_t remaining = count,
|
|
192
|
-
remaining -=
|
|
193
|
-
|
|
195
|
+
for (nk_size_t remaining = count, max_vector_length; remaining > 0;
|
|
196
|
+
remaining -= max_vector_length, offset += max_vector_length, ptr += max_vector_length * stride_bytes) {
|
|
197
|
+
max_vector_length = __riscv_vsetvl_e32m1(remaining);
|
|
194
198
|
vfloat32m1_t data_f32m1 = __riscv_vlse32_v_f32m1((nk_f32_t const *)ptr, (nk_ssize_t)stride_bytes,
|
|
195
|
-
|
|
196
|
-
vuint64m2_t position_u64m2 = __riscv_vadd_vx_u64m2(__riscv_vid_v_u64m2(
|
|
197
|
-
|
|
198
|
-
vbool32_t less_b32 = __riscv_vmflt_vv_f32m1_b32(data_f32m1, min,
|
|
199
|
-
min = __riscv_vmerge_vvm_f32m1_tu(min, min, data_f32m1, less_b32,
|
|
200
|
-
min_indices = __riscv_vmerge_vvm_u64m2_tu(min_indices, min_indices, position_u64m2, less_b32,
|
|
201
|
-
|
|
202
|
-
|
|
203
|
-
|
|
199
|
+
max_vector_length);
|
|
200
|
+
vuint64m2_t position_u64m2 = __riscv_vadd_vx_u64m2(__riscv_vid_v_u64m2(max_vector_length), (nk_u64_t)offset,
|
|
201
|
+
max_vector_length);
|
|
202
|
+
vbool32_t less_b32 = __riscv_vmflt_vv_f32m1_b32(data_f32m1, min, max_vector_length);
|
|
203
|
+
min = __riscv_vmerge_vvm_f32m1_tu(min, min, data_f32m1, less_b32, max_vector_length);
|
|
204
|
+
min_indices = __riscv_vmerge_vvm_u64m2_tu(min_indices, min_indices, position_u64m2, less_b32,
|
|
205
|
+
max_vector_length);
|
|
206
|
+
vbool32_t greater_b32 = __riscv_vmflt_vv_f32m1_b32(max, data_f32m1, max_vector_length);
|
|
207
|
+
max = __riscv_vmerge_vvm_f32m1_tu(max, max, data_f32m1, greater_b32, max_vector_length);
|
|
208
|
+
max_indices = __riscv_vmerge_vvm_u64m2_tu(max_indices, max_indices, position_u64m2, greater_b32,
|
|
209
|
+
max_vector_length);
|
|
204
210
|
}
|
|
205
211
|
vfloat32m1_t id_max = __riscv_vfmv_v_f_f32m1(NK_F32_MAX, 1);
|
|
206
|
-
nk_f32_t mn = __riscv_vfmv_f_s_f32m1_f32(__riscv_vfredmin_vs_f32m1_f32m1(min, id_max,
|
|
212
|
+
nk_f32_t mn = __riscv_vfmv_f_s_f32m1_f32(__riscv_vfredmin_vs_f32m1_f32m1(min, id_max, max_vector_length));
|
|
207
213
|
vfloat32m1_t id_min = __riscv_vfmv_v_f_f32m1(NK_F32_MIN, 1);
|
|
208
|
-
nk_f32_t mx = __riscv_vfmv_f_s_f32m1_f32(__riscv_vfredmax_vs_f32m1_f32m1(max, id_min,
|
|
214
|
+
nk_f32_t mx = __riscv_vfmv_f_s_f32m1_f32(__riscv_vfredmax_vs_f32m1_f32m1(max, id_min, max_vector_length));
|
|
209
215
|
if (mn == NK_F32_MAX && mx == NK_F32_MIN) {
|
|
210
216
|
*min_value = NK_F32_MAX, *min_index = NK_SIZE_MAX, *max_value = NK_F32_MIN, *max_index = NK_SIZE_MAX;
|
|
211
217
|
return;
|
|
212
218
|
}
|
|
213
|
-
vbool32_t min_match_b32 = __riscv_vmfeq_vf_f32m1_b32(min, mn,
|
|
214
|
-
vuint64m2_t sentinel = __riscv_vmv_v_x_u64m2(NK_U64_MAX,
|
|
215
|
-
vuint64m2_t min_cands = __riscv_vmerge_vvm_u64m2(sentinel, min_indices, min_match_b32,
|
|
219
|
+
vbool32_t min_match_b32 = __riscv_vmfeq_vf_f32m1_b32(min, mn, max_vector_length);
|
|
220
|
+
vuint64m2_t sentinel = __riscv_vmv_v_x_u64m2(NK_U64_MAX, max_vector_length);
|
|
221
|
+
vuint64m2_t min_cands = __riscv_vmerge_vvm_u64m2(sentinel, min_indices, min_match_b32, max_vector_length);
|
|
216
222
|
vuint64m1_t id_umax = __riscv_vmv_v_x_u64m1(NK_U64_MAX, 1);
|
|
217
|
-
*min_value = mn,
|
|
218
|
-
|
|
219
|
-
vbool32_t max_match_b32 = __riscv_vmfeq_vf_f32m1_b32(max, mx,
|
|
220
|
-
vuint64m2_t max_cands = __riscv_vmerge_vvm_u64m2(sentinel, max_indices, max_match_b32,
|
|
221
|
-
*max_value = mx,
|
|
222
|
-
|
|
223
|
+
*min_value = mn, *min_index = (nk_size_t)__riscv_vmv_x_s_u64m1_u64(
|
|
224
|
+
__riscv_vredminu_vs_u64m2_u64m1(min_cands, id_umax, max_vector_length));
|
|
225
|
+
vbool32_t max_match_b32 = __riscv_vmfeq_vf_f32m1_b32(max, mx, max_vector_length);
|
|
226
|
+
vuint64m2_t max_cands = __riscv_vmerge_vvm_u64m2(sentinel, max_indices, max_match_b32, max_vector_length);
|
|
227
|
+
*max_value = mx, *max_index = (nk_size_t)__riscv_vmv_x_s_u64m1_u64(
|
|
228
|
+
__riscv_vredminu_vs_u64m2_u64m1(max_cands, id_umax, max_vector_length));
|
|
223
229
|
}
|
|
224
230
|
|
|
225
231
|
NK_PUBLIC void nk_reduce_minmax_f32_rvv( //
|
|
@@ -240,9 +246,9 @@ NK_PUBLIC void nk_reduce_minmax_f32_rvv( //
|
|
|
240
246
|
NK_INTERNAL void nk_reduce_moments_f64_rvv_contiguous_( //
|
|
241
247
|
nk_f64_t const *data, nk_size_t count, //
|
|
242
248
|
nk_f64_t *sum_ptr, nk_f64_t *sumsq_ptr) {
|
|
243
|
-
nk_size_t
|
|
244
|
-
vfloat64m4_t sum_f64m4 = __riscv_vfmv_v_f_f64m4(0.0,
|
|
245
|
-
vfloat64m4_t sumsq_f64m4 = __riscv_vfmv_v_f_f64m4(0.0,
|
|
249
|
+
nk_size_t max_vector_length = __riscv_vsetvlmax_e64m4();
|
|
250
|
+
vfloat64m4_t sum_f64m4 = __riscv_vfmv_v_f_f64m4(0.0, max_vector_length);
|
|
251
|
+
vfloat64m4_t sumsq_f64m4 = __riscv_vfmv_v_f_f64m4(0.0, max_vector_length);
|
|
246
252
|
for (nk_size_t vector_length; count > 0; count -= vector_length, data += vector_length) {
|
|
247
253
|
vector_length = __riscv_vsetvl_e64m4(count);
|
|
248
254
|
vfloat64m4_t data_f64m4 = __riscv_vle64_v_f64m4(data, vector_length);
|
|
@@ -250,16 +256,16 @@ NK_INTERNAL void nk_reduce_moments_f64_rvv_contiguous_( //
|
|
|
250
256
|
sumsq_f64m4 = __riscv_vfmacc_vv_f64m4_tu(sumsq_f64m4, data_f64m4, data_f64m4, vector_length);
|
|
251
257
|
}
|
|
252
258
|
vfloat64m1_t zero = __riscv_vfmv_v_f_f64m1(0.0, 1);
|
|
253
|
-
*sum_ptr = __riscv_vfmv_f_s_f64m1_f64(__riscv_vfredusum_vs_f64m4_f64m1(sum_f64m4, zero,
|
|
254
|
-
*sumsq_ptr = __riscv_vfmv_f_s_f64m1_f64(__riscv_vfredusum_vs_f64m4_f64m1(sumsq_f64m4, zero,
|
|
259
|
+
*sum_ptr = __riscv_vfmv_f_s_f64m1_f64(__riscv_vfredusum_vs_f64m4_f64m1(sum_f64m4, zero, max_vector_length)),
|
|
260
|
+
*sumsq_ptr = __riscv_vfmv_f_s_f64m1_f64(__riscv_vfredusum_vs_f64m4_f64m1(sumsq_f64m4, zero, max_vector_length));
|
|
255
261
|
}
|
|
256
262
|
|
|
257
263
|
NK_INTERNAL void nk_reduce_moments_f64_rvv_strided_( //
|
|
258
264
|
nk_f64_t const *data, nk_size_t count, nk_size_t stride_bytes, //
|
|
259
265
|
nk_f64_t *sum_ptr, nk_f64_t *sumsq_ptr) {
|
|
260
|
-
nk_size_t
|
|
261
|
-
vfloat64m4_t sum_f64m4 = __riscv_vfmv_v_f_f64m4(0.0,
|
|
262
|
-
vfloat64m4_t sumsq_f64m4 = __riscv_vfmv_v_f_f64m4(0.0,
|
|
266
|
+
nk_size_t max_vector_length = __riscv_vsetvlmax_e64m4();
|
|
267
|
+
vfloat64m4_t sum_f64m4 = __riscv_vfmv_v_f_f64m4(0.0, max_vector_length);
|
|
268
|
+
vfloat64m4_t sumsq_f64m4 = __riscv_vfmv_v_f_f64m4(0.0, max_vector_length);
|
|
263
269
|
unsigned char const *ptr = (unsigned char const *)data;
|
|
264
270
|
for (nk_size_t vector_length; count > 0; count -= vector_length, ptr += vector_length * stride_bytes) {
|
|
265
271
|
vector_length = __riscv_vsetvl_e64m4(count);
|
|
@@ -269,8 +275,8 @@ NK_INTERNAL void nk_reduce_moments_f64_rvv_strided_( //
|
|
|
269
275
|
sumsq_f64m4 = __riscv_vfmacc_vv_f64m4_tu(sumsq_f64m4, data_f64m4, data_f64m4, vector_length);
|
|
270
276
|
}
|
|
271
277
|
vfloat64m1_t zero = __riscv_vfmv_v_f_f64m1(0.0, 1);
|
|
272
|
-
*sum_ptr = __riscv_vfmv_f_s_f64m1_f64(__riscv_vfredusum_vs_f64m4_f64m1(sum_f64m4, zero,
|
|
273
|
-
*sumsq_ptr = __riscv_vfmv_f_s_f64m1_f64(__riscv_vfredusum_vs_f64m4_f64m1(sumsq_f64m4, zero,
|
|
278
|
+
*sum_ptr = __riscv_vfmv_f_s_f64m1_f64(__riscv_vfredusum_vs_f64m4_f64m1(sum_f64m4, zero, max_vector_length)),
|
|
279
|
+
*sumsq_ptr = __riscv_vfmv_f_s_f64m1_f64(__riscv_vfredusum_vs_f64m4_f64m1(sumsq_f64m4, zero, max_vector_length));
|
|
274
280
|
}
|
|
275
281
|
|
|
276
282
|
NK_PUBLIC void nk_reduce_moments_f64_rvv( //
|
|
@@ -288,88 +294,92 @@ NK_INTERNAL void nk_reduce_minmax_f64_rvv_contiguous_( //
|
|
|
288
294
|
nk_f64_t const *data, nk_size_t count, //
|
|
289
295
|
nk_f64_t *min_value, nk_size_t *min_index, //
|
|
290
296
|
nk_f64_t *max_value, nk_size_t *max_index) {
|
|
291
|
-
nk_size_t
|
|
292
|
-
vfloat64m1_t min = __riscv_vfmv_v_f_f64m1(NK_F64_MAX,
|
|
293
|
-
vfloat64m1_t max = __riscv_vfmv_v_f_f64m1(NK_F64_MIN,
|
|
294
|
-
vuint64m1_t min_indices = __riscv_vmv_v_x_u64m1(0,
|
|
295
|
-
vuint64m1_t max_indices = __riscv_vmv_v_x_u64m1(0,
|
|
297
|
+
nk_size_t max_vector_length = __riscv_vsetvlmax_e64m1();
|
|
298
|
+
vfloat64m1_t min = __riscv_vfmv_v_f_f64m1(NK_F64_MAX, max_vector_length);
|
|
299
|
+
vfloat64m1_t max = __riscv_vfmv_v_f_f64m1(NK_F64_MIN, max_vector_length);
|
|
300
|
+
vuint64m1_t min_indices = __riscv_vmv_v_x_u64m1(0, max_vector_length);
|
|
301
|
+
vuint64m1_t max_indices = __riscv_vmv_v_x_u64m1(0, max_vector_length);
|
|
296
302
|
nk_size_t offset = 0;
|
|
297
|
-
for (nk_size_t remaining = count,
|
|
298
|
-
remaining -=
|
|
299
|
-
|
|
300
|
-
vfloat64m1_t data_f64m1 = __riscv_vle64_v_f64m1(data + offset,
|
|
301
|
-
vuint64m1_t position_u64m1 = __riscv_vadd_vx_u64m1(__riscv_vid_v_u64m1(
|
|
302
|
-
|
|
303
|
-
vbool64_t less_b64 = __riscv_vmflt_vv_f64m1_b64(data_f64m1, min,
|
|
304
|
-
min = __riscv_vmerge_vvm_f64m1_tu(min, min, data_f64m1, less_b64,
|
|
305
|
-
min_indices = __riscv_vmerge_vvm_u64m1_tu(min_indices, min_indices, position_u64m1, less_b64,
|
|
306
|
-
|
|
307
|
-
|
|
308
|
-
|
|
303
|
+
for (nk_size_t remaining = count, max_vector_length; remaining > 0;
|
|
304
|
+
remaining -= max_vector_length, offset += max_vector_length) {
|
|
305
|
+
max_vector_length = __riscv_vsetvl_e64m1(remaining);
|
|
306
|
+
vfloat64m1_t data_f64m1 = __riscv_vle64_v_f64m1(data + offset, max_vector_length);
|
|
307
|
+
vuint64m1_t position_u64m1 = __riscv_vadd_vx_u64m1(__riscv_vid_v_u64m1(max_vector_length), (nk_u64_t)offset,
|
|
308
|
+
max_vector_length);
|
|
309
|
+
vbool64_t less_b64 = __riscv_vmflt_vv_f64m1_b64(data_f64m1, min, max_vector_length);
|
|
310
|
+
min = __riscv_vmerge_vvm_f64m1_tu(min, min, data_f64m1, less_b64, max_vector_length);
|
|
311
|
+
min_indices = __riscv_vmerge_vvm_u64m1_tu(min_indices, min_indices, position_u64m1, less_b64,
|
|
312
|
+
max_vector_length);
|
|
313
|
+
vbool64_t greater_b64 = __riscv_vmflt_vv_f64m1_b64(max, data_f64m1, max_vector_length);
|
|
314
|
+
max = __riscv_vmerge_vvm_f64m1_tu(max, max, data_f64m1, greater_b64, max_vector_length);
|
|
315
|
+
max_indices = __riscv_vmerge_vvm_u64m1_tu(max_indices, max_indices, position_u64m1, greater_b64,
|
|
316
|
+
max_vector_length);
|
|
309
317
|
}
|
|
310
318
|
vfloat64m1_t id_max = __riscv_vfmv_v_f_f64m1(NK_F64_MAX, 1);
|
|
311
|
-
nk_f64_t mn = __riscv_vfmv_f_s_f64m1_f64(__riscv_vfredmin_vs_f64m1_f64m1(min, id_max,
|
|
319
|
+
nk_f64_t mn = __riscv_vfmv_f_s_f64m1_f64(__riscv_vfredmin_vs_f64m1_f64m1(min, id_max, max_vector_length));
|
|
312
320
|
vfloat64m1_t id_min = __riscv_vfmv_v_f_f64m1(NK_F64_MIN, 1);
|
|
313
|
-
nk_f64_t mx = __riscv_vfmv_f_s_f64m1_f64(__riscv_vfredmax_vs_f64m1_f64m1(max, id_min,
|
|
321
|
+
nk_f64_t mx = __riscv_vfmv_f_s_f64m1_f64(__riscv_vfredmax_vs_f64m1_f64m1(max, id_min, max_vector_length));
|
|
314
322
|
if (mn == NK_F64_MAX && mx == NK_F64_MIN) {
|
|
315
323
|
*min_value = NK_F64_MAX, *min_index = NK_SIZE_MAX, *max_value = NK_F64_MIN, *max_index = NK_SIZE_MAX;
|
|
316
324
|
return;
|
|
317
325
|
}
|
|
318
|
-
vbool64_t min_match_b64 = __riscv_vmfeq_vf_f64m1_b64(min, mn,
|
|
319
|
-
vuint64m1_t sentinel = __riscv_vmv_v_x_u64m1(NK_U64_MAX,
|
|
320
|
-
vuint64m1_t min_cands = __riscv_vmerge_vvm_u64m1(sentinel, min_indices, min_match_b64,
|
|
326
|
+
vbool64_t min_match_b64 = __riscv_vmfeq_vf_f64m1_b64(min, mn, max_vector_length);
|
|
327
|
+
vuint64m1_t sentinel = __riscv_vmv_v_x_u64m1(NK_U64_MAX, max_vector_length);
|
|
328
|
+
vuint64m1_t min_cands = __riscv_vmerge_vvm_u64m1(sentinel, min_indices, min_match_b64, max_vector_length);
|
|
321
329
|
vuint64m1_t id_umax = __riscv_vmv_v_x_u64m1(NK_U64_MAX, 1);
|
|
322
|
-
*min_value = mn,
|
|
323
|
-
|
|
324
|
-
vbool64_t max_match_b64 = __riscv_vmfeq_vf_f64m1_b64(max, mx,
|
|
325
|
-
vuint64m1_t max_cands = __riscv_vmerge_vvm_u64m1(sentinel, max_indices, max_match_b64,
|
|
326
|
-
*max_value = mx,
|
|
327
|
-
|
|
330
|
+
*min_value = mn, *min_index = (nk_size_t)__riscv_vmv_x_s_u64m1_u64(
|
|
331
|
+
__riscv_vredminu_vs_u64m1_u64m1(min_cands, id_umax, max_vector_length));
|
|
332
|
+
vbool64_t max_match_b64 = __riscv_vmfeq_vf_f64m1_b64(max, mx, max_vector_length);
|
|
333
|
+
vuint64m1_t max_cands = __riscv_vmerge_vvm_u64m1(sentinel, max_indices, max_match_b64, max_vector_length);
|
|
334
|
+
*max_value = mx, *max_index = (nk_size_t)__riscv_vmv_x_s_u64m1_u64(
|
|
335
|
+
__riscv_vredminu_vs_u64m1_u64m1(max_cands, id_umax, max_vector_length));
|
|
328
336
|
}
|
|
329
337
|
|
|
330
338
|
NK_INTERNAL void nk_reduce_minmax_f64_rvv_strided_( //
|
|
331
339
|
nk_f64_t const *data, nk_size_t count, nk_size_t stride_bytes, //
|
|
332
340
|
nk_f64_t *min_value, nk_size_t *min_index, //
|
|
333
341
|
nk_f64_t *max_value, nk_size_t *max_index) {
|
|
334
|
-
nk_size_t
|
|
335
|
-
vfloat64m1_t min = __riscv_vfmv_v_f_f64m1(NK_F64_MAX,
|
|
336
|
-
vfloat64m1_t max = __riscv_vfmv_v_f_f64m1(NK_F64_MIN,
|
|
337
|
-
vuint64m1_t min_indices = __riscv_vmv_v_x_u64m1(0,
|
|
338
|
-
vuint64m1_t max_indices = __riscv_vmv_v_x_u64m1(0,
|
|
342
|
+
nk_size_t max_vector_length = __riscv_vsetvlmax_e64m1();
|
|
343
|
+
vfloat64m1_t min = __riscv_vfmv_v_f_f64m1(NK_F64_MAX, max_vector_length);
|
|
344
|
+
vfloat64m1_t max = __riscv_vfmv_v_f_f64m1(NK_F64_MIN, max_vector_length);
|
|
345
|
+
vuint64m1_t min_indices = __riscv_vmv_v_x_u64m1(0, max_vector_length);
|
|
346
|
+
vuint64m1_t max_indices = __riscv_vmv_v_x_u64m1(0, max_vector_length);
|
|
339
347
|
unsigned char const *ptr = (unsigned char const *)data;
|
|
340
348
|
nk_size_t offset = 0;
|
|
341
|
-
for (nk_size_t remaining = count,
|
|
342
|
-
remaining -=
|
|
343
|
-
|
|
349
|
+
for (nk_size_t remaining = count, max_vector_length; remaining > 0;
|
|
350
|
+
remaining -= max_vector_length, offset += max_vector_length, ptr += max_vector_length * stride_bytes) {
|
|
351
|
+
max_vector_length = __riscv_vsetvl_e64m1(remaining);
|
|
344
352
|
vfloat64m1_t data_f64m1 = __riscv_vlse64_v_f64m1((nk_f64_t const *)ptr, (nk_ssize_t)stride_bytes,
|
|
345
|
-
|
|
346
|
-
vuint64m1_t position_u64m1 = __riscv_vadd_vx_u64m1(__riscv_vid_v_u64m1(
|
|
347
|
-
|
|
348
|
-
vbool64_t less_b64 = __riscv_vmflt_vv_f64m1_b64(data_f64m1, min,
|
|
349
|
-
min = __riscv_vmerge_vvm_f64m1_tu(min, min, data_f64m1, less_b64,
|
|
350
|
-
min_indices = __riscv_vmerge_vvm_u64m1_tu(min_indices, min_indices, position_u64m1, less_b64,
|
|
351
|
-
|
|
352
|
-
|
|
353
|
-
|
|
353
|
+
max_vector_length);
|
|
354
|
+
vuint64m1_t position_u64m1 = __riscv_vadd_vx_u64m1(__riscv_vid_v_u64m1(max_vector_length), (nk_u64_t)offset,
|
|
355
|
+
max_vector_length);
|
|
356
|
+
vbool64_t less_b64 = __riscv_vmflt_vv_f64m1_b64(data_f64m1, min, max_vector_length);
|
|
357
|
+
min = __riscv_vmerge_vvm_f64m1_tu(min, min, data_f64m1, less_b64, max_vector_length);
|
|
358
|
+
min_indices = __riscv_vmerge_vvm_u64m1_tu(min_indices, min_indices, position_u64m1, less_b64,
|
|
359
|
+
max_vector_length);
|
|
360
|
+
vbool64_t greater_b64 = __riscv_vmflt_vv_f64m1_b64(max, data_f64m1, max_vector_length);
|
|
361
|
+
max = __riscv_vmerge_vvm_f64m1_tu(max, max, data_f64m1, greater_b64, max_vector_length);
|
|
362
|
+
max_indices = __riscv_vmerge_vvm_u64m1_tu(max_indices, max_indices, position_u64m1, greater_b64,
|
|
363
|
+
max_vector_length);
|
|
354
364
|
}
|
|
355
365
|
vfloat64m1_t id_max = __riscv_vfmv_v_f_f64m1(NK_F64_MAX, 1);
|
|
356
|
-
nk_f64_t mn = __riscv_vfmv_f_s_f64m1_f64(__riscv_vfredmin_vs_f64m1_f64m1(min, id_max,
|
|
366
|
+
nk_f64_t mn = __riscv_vfmv_f_s_f64m1_f64(__riscv_vfredmin_vs_f64m1_f64m1(min, id_max, max_vector_length));
|
|
357
367
|
vfloat64m1_t id_min = __riscv_vfmv_v_f_f64m1(NK_F64_MIN, 1);
|
|
358
|
-
nk_f64_t mx = __riscv_vfmv_f_s_f64m1_f64(__riscv_vfredmax_vs_f64m1_f64m1(max, id_min,
|
|
368
|
+
nk_f64_t mx = __riscv_vfmv_f_s_f64m1_f64(__riscv_vfredmax_vs_f64m1_f64m1(max, id_min, max_vector_length));
|
|
359
369
|
if (mn == NK_F64_MAX && mx == NK_F64_MIN) {
|
|
360
370
|
*min_value = NK_F64_MAX, *min_index = NK_SIZE_MAX, *max_value = NK_F64_MIN, *max_index = NK_SIZE_MAX;
|
|
361
371
|
return;
|
|
362
372
|
}
|
|
363
|
-
vbool64_t min_match_b64 = __riscv_vmfeq_vf_f64m1_b64(min, mn,
|
|
364
|
-
vuint64m1_t sentinel = __riscv_vmv_v_x_u64m1(NK_U64_MAX,
|
|
365
|
-
vuint64m1_t min_cands = __riscv_vmerge_vvm_u64m1(sentinel, min_indices, min_match_b64,
|
|
373
|
+
vbool64_t min_match_b64 = __riscv_vmfeq_vf_f64m1_b64(min, mn, max_vector_length);
|
|
374
|
+
vuint64m1_t sentinel = __riscv_vmv_v_x_u64m1(NK_U64_MAX, max_vector_length);
|
|
375
|
+
vuint64m1_t min_cands = __riscv_vmerge_vvm_u64m1(sentinel, min_indices, min_match_b64, max_vector_length);
|
|
366
376
|
vuint64m1_t id_umax = __riscv_vmv_v_x_u64m1(NK_U64_MAX, 1);
|
|
367
|
-
*min_value = mn,
|
|
368
|
-
|
|
369
|
-
vbool64_t max_match_b64 = __riscv_vmfeq_vf_f64m1_b64(max, mx,
|
|
370
|
-
vuint64m1_t max_cands = __riscv_vmerge_vvm_u64m1(sentinel, max_indices, max_match_b64,
|
|
371
|
-
*max_value = mx,
|
|
372
|
-
|
|
377
|
+
*min_value = mn, *min_index = (nk_size_t)__riscv_vmv_x_s_u64m1_u64(
|
|
378
|
+
__riscv_vredminu_vs_u64m1_u64m1(min_cands, id_umax, max_vector_length));
|
|
379
|
+
vbool64_t max_match_b64 = __riscv_vmfeq_vf_f64m1_b64(max, mx, max_vector_length);
|
|
380
|
+
vuint64m1_t max_cands = __riscv_vmerge_vvm_u64m1(sentinel, max_indices, max_match_b64, max_vector_length);
|
|
381
|
+
*max_value = mx, *max_index = (nk_size_t)__riscv_vmv_x_s_u64m1_u64(
|
|
382
|
+
__riscv_vredminu_vs_u64m1_u64m1(max_cands, id_umax, max_vector_length));
|
|
373
383
|
}
|
|
374
384
|
|
|
375
385
|
NK_PUBLIC void nk_reduce_minmax_f64_rvv( //
|
|
@@ -428,10 +438,10 @@ NK_INTERNAL vuint8m1_t nk_comparable_to_fp6m1_rvv_(vuint8m1_t comparable_u8m1, n
|
|
|
428
438
|
NK_INTERNAL void nk_reduce_moments_i8_rvv_contiguous_( //
|
|
429
439
|
nk_i8_t const *data_ptr, nk_size_t count, //
|
|
430
440
|
nk_i64_t *sum_ptr, nk_u64_t *sumsq_ptr) {
|
|
431
|
-
nk_size_t
|
|
441
|
+
nk_size_t max_vector_length = __riscv_vsetvlmax_e64m4();
|
|
432
442
|
nk_size_t vlmax_elements = __riscv_vsetvlmax_e8m1();
|
|
433
|
-
vint64m4_t sum_i64m4 = __riscv_vmv_v_x_i64m4(0,
|
|
434
|
-
vuint64m4_t sumsq_u64m4 = __riscv_vmv_v_x_u64m4(0,
|
|
443
|
+
vint64m4_t sum_i64m4 = __riscv_vmv_v_x_i64m4(0, max_vector_length);
|
|
444
|
+
vuint64m4_t sumsq_u64m4 = __riscv_vmv_v_x_u64m4(0, max_vector_length);
|
|
435
445
|
vint8m1_t zero_i8m1 = __riscv_vmv_v_x_i8m1(0, vlmax_elements);
|
|
436
446
|
|
|
437
447
|
for (nk_size_t vector_length; count > 0; count -= vector_length, data_ptr += vector_length) {
|
|
@@ -444,8 +454,8 @@ NK_INTERNAL void nk_reduce_moments_i8_rvv_contiguous_( //
|
|
|
444
454
|
vint64m8_t data_i64m8 = __riscv_vsext_vf2_i64m8(data_i32m4, vlmax_elements);
|
|
445
455
|
|
|
446
456
|
// Accumulate sum (split m8 into two m4)
|
|
447
|
-
sum_i64m4 = __riscv_vadd_vv_i64m4(sum_i64m4, __riscv_vget_v_i64m8_i64m4(data_i64m8, 0),
|
|
448
|
-
sum_i64m4 = __riscv_vadd_vv_i64m4(sum_i64m4, __riscv_vget_v_i64m8_i64m4(data_i64m8, 1),
|
|
457
|
+
sum_i64m4 = __riscv_vadd_vv_i64m4(sum_i64m4, __riscv_vget_v_i64m8_i64m4(data_i64m8, 0), vector_length);
|
|
458
|
+
sum_i64m4 = __riscv_vadd_vv_i64m4(sum_i64m4, __riscv_vget_v_i64m8_i64m4(data_i64m8, 1), vector_length);
|
|
449
459
|
|
|
450
460
|
// Sumsq: i8 × i8 → i16 (widening multiply)
|
|
451
461
|
vint16m2_t squares_i16m2 = __riscv_vwmul_vv_i16m2(data_i8m1, data_i8m1, vlmax_elements);
|
|
@@ -454,25 +464,25 @@ NK_INTERNAL void nk_reduce_moments_i8_rvv_contiguous_( //
|
|
|
454
464
|
vlmax_elements);
|
|
455
465
|
vuint64m8_t squares_u64m8 = __riscv_vwcvtu_x_x_v_u64m8(squares_u32m4, vlmax_elements);
|
|
456
466
|
|
|
457
|
-
sumsq_u64m4 = __riscv_vadd_vv_u64m4(sumsq_u64m4, __riscv_vget_v_u64m8_u64m4(squares_u64m8, 0),
|
|
458
|
-
sumsq_u64m4 = __riscv_vadd_vv_u64m4(sumsq_u64m4, __riscv_vget_v_u64m8_u64m4(squares_u64m8, 1),
|
|
467
|
+
sumsq_u64m4 = __riscv_vadd_vv_u64m4(sumsq_u64m4, __riscv_vget_v_u64m8_u64m4(squares_u64m8, 0), vector_length);
|
|
468
|
+
sumsq_u64m4 = __riscv_vadd_vv_u64m4(sumsq_u64m4, __riscv_vget_v_u64m8_u64m4(squares_u64m8, 1), vector_length);
|
|
459
469
|
}
|
|
460
470
|
|
|
461
471
|
// Horizontal reduction
|
|
462
472
|
vint64m1_t zero_i64m1 = __riscv_vmv_v_x_i64m1(0, 1);
|
|
463
|
-
*sum_ptr = __riscv_vmv_x_s_i64m1_i64(__riscv_vredsum_vs_i64m4_i64m1(sum_i64m4, zero_i64m1,
|
|
473
|
+
*sum_ptr = __riscv_vmv_x_s_i64m1_i64(__riscv_vredsum_vs_i64m4_i64m1(sum_i64m4, zero_i64m1, max_vector_length));
|
|
464
474
|
|
|
465
475
|
vuint64m1_t zero_u64m1 = __riscv_vmv_v_x_u64m1(0, 1);
|
|
466
|
-
*sumsq_ptr = __riscv_vmv_x_s_u64m1_u64(__riscv_vredsum_vs_u64m4_u64m1(sumsq_u64m4, zero_u64m1,
|
|
476
|
+
*sumsq_ptr = __riscv_vmv_x_s_u64m1_u64(__riscv_vredsum_vs_u64m4_u64m1(sumsq_u64m4, zero_u64m1, max_vector_length));
|
|
467
477
|
}
|
|
468
478
|
|
|
469
479
|
NK_INTERNAL void nk_reduce_moments_i8_rvv_strided_( //
|
|
470
480
|
nk_i8_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
|
|
471
481
|
nk_i64_t *sum_ptr, nk_u64_t *sumsq_ptr) {
|
|
472
|
-
nk_size_t
|
|
482
|
+
nk_size_t max_vector_length = __riscv_vsetvlmax_e64m4();
|
|
473
483
|
nk_size_t vlmax_elements = __riscv_vsetvlmax_e8m1();
|
|
474
|
-
vint64m4_t sum_i64m4 = __riscv_vmv_v_x_i64m4(0,
|
|
475
|
-
vuint64m4_t sumsq_u64m4 = __riscv_vmv_v_x_u64m4(0,
|
|
484
|
+
vint64m4_t sum_i64m4 = __riscv_vmv_v_x_i64m4(0, max_vector_length);
|
|
485
|
+
vuint64m4_t sumsq_u64m4 = __riscv_vmv_v_x_u64m4(0, max_vector_length);
|
|
476
486
|
vint8m1_t zero_i8m1 = __riscv_vmv_v_x_i8m1(0, vlmax_elements);
|
|
477
487
|
unsigned char const *ptr = (unsigned char const *)data_ptr;
|
|
478
488
|
|
|
@@ -487,8 +497,8 @@ NK_INTERNAL void nk_reduce_moments_i8_rvv_strided_( //
|
|
|
487
497
|
vint64m8_t data_i64m8 = __riscv_vsext_vf2_i64m8(data_i32m4, vlmax_elements);
|
|
488
498
|
|
|
489
499
|
// Accumulate sum (split m8 into two m4)
|
|
490
|
-
sum_i64m4 = __riscv_vadd_vv_i64m4(sum_i64m4, __riscv_vget_v_i64m8_i64m4(data_i64m8, 0),
|
|
491
|
-
sum_i64m4 = __riscv_vadd_vv_i64m4(sum_i64m4, __riscv_vget_v_i64m8_i64m4(data_i64m8, 1),
|
|
500
|
+
sum_i64m4 = __riscv_vadd_vv_i64m4(sum_i64m4, __riscv_vget_v_i64m8_i64m4(data_i64m8, 0), vector_length);
|
|
501
|
+
sum_i64m4 = __riscv_vadd_vv_i64m4(sum_i64m4, __riscv_vget_v_i64m8_i64m4(data_i64m8, 1), vector_length);
|
|
492
502
|
|
|
493
503
|
// Sumsq: i8 × i8 → i16 (widening multiply)
|
|
494
504
|
vint16m2_t squares_i16m2 = __riscv_vwmul_vv_i16m2(data_i8m1, data_i8m1, vlmax_elements);
|
|
@@ -497,16 +507,16 @@ NK_INTERNAL void nk_reduce_moments_i8_rvv_strided_( //
|
|
|
497
507
|
vlmax_elements);
|
|
498
508
|
vuint64m8_t squares_u64m8 = __riscv_vwcvtu_x_x_v_u64m8(squares_u32m4, vlmax_elements);
|
|
499
509
|
|
|
500
|
-
sumsq_u64m4 = __riscv_vadd_vv_u64m4(sumsq_u64m4, __riscv_vget_v_u64m8_u64m4(squares_u64m8, 0),
|
|
501
|
-
sumsq_u64m4 = __riscv_vadd_vv_u64m4(sumsq_u64m4, __riscv_vget_v_u64m8_u64m4(squares_u64m8, 1),
|
|
510
|
+
sumsq_u64m4 = __riscv_vadd_vv_u64m4(sumsq_u64m4, __riscv_vget_v_u64m8_u64m4(squares_u64m8, 0), vector_length);
|
|
511
|
+
sumsq_u64m4 = __riscv_vadd_vv_u64m4(sumsq_u64m4, __riscv_vget_v_u64m8_u64m4(squares_u64m8, 1), vector_length);
|
|
502
512
|
}
|
|
503
513
|
|
|
504
514
|
// Horizontal reduction
|
|
505
515
|
vint64m1_t zero_i64m1 = __riscv_vmv_v_x_i64m1(0, 1);
|
|
506
|
-
*sum_ptr = __riscv_vmv_x_s_i64m1_i64(__riscv_vredsum_vs_i64m4_i64m1(sum_i64m4, zero_i64m1,
|
|
516
|
+
*sum_ptr = __riscv_vmv_x_s_i64m1_i64(__riscv_vredsum_vs_i64m4_i64m1(sum_i64m4, zero_i64m1, max_vector_length));
|
|
507
517
|
|
|
508
518
|
vuint64m1_t zero_u64m1 = __riscv_vmv_v_x_u64m1(0, 1);
|
|
509
|
-
*sumsq_ptr = __riscv_vmv_x_s_u64m1_u64(__riscv_vredsum_vs_u64m4_u64m1(sumsq_u64m4, zero_u64m1,
|
|
519
|
+
*sumsq_ptr = __riscv_vmv_x_s_u64m1_u64(__riscv_vredsum_vs_u64m4_u64m1(sumsq_u64m4, zero_u64m1, max_vector_length));
|
|
510
520
|
}
|
|
511
521
|
|
|
512
522
|
NK_PUBLIC void nk_reduce_moments_i8_rvv( //
|
|
@@ -525,11 +535,11 @@ NK_INTERNAL void nk_reduce_minmax_i8_rvv_contiguous_( //
|
|
|
525
535
|
nk_i8_t const *data_ptr, nk_size_t count, //
|
|
526
536
|
nk_i8_t *min_value_ptr, nk_size_t *min_index_ptr, //
|
|
527
537
|
nk_i8_t *max_value_ptr, nk_size_t *max_index_ptr) {
|
|
528
|
-
nk_size_t
|
|
529
|
-
vint8m1_t min_i8m1 = __riscv_vmv_v_x_i8m1(NK_I8_MAX,
|
|
530
|
-
vint8m1_t max_i8m1 = __riscv_vmv_v_x_i8m1(NK_I8_MIN,
|
|
531
|
-
vuint64m8_t min_indices_u64m8 = __riscv_vmv_v_x_u64m8(0,
|
|
532
|
-
vuint64m8_t max_indices_u64m8 = __riscv_vmv_v_x_u64m8(0,
|
|
538
|
+
nk_size_t max_vector_length = __riscv_vsetvlmax_e8m1();
|
|
539
|
+
vint8m1_t min_i8m1 = __riscv_vmv_v_x_i8m1(NK_I8_MAX, max_vector_length);
|
|
540
|
+
vint8m1_t max_i8m1 = __riscv_vmv_v_x_i8m1(NK_I8_MIN, max_vector_length);
|
|
541
|
+
vuint64m8_t min_indices_u64m8 = __riscv_vmv_v_x_u64m8(0, max_vector_length);
|
|
542
|
+
vuint64m8_t max_indices_u64m8 = __riscv_vmv_v_x_u64m8(0, max_vector_length);
|
|
533
543
|
|
|
534
544
|
nk_size_t offset = 0;
|
|
535
545
|
for (nk_size_t vector_length; count > 0;
|
|
@@ -554,34 +564,36 @@ NK_INTERNAL void nk_reduce_minmax_i8_rvv_contiguous_( //
|
|
|
554
564
|
|
|
555
565
|
// Horizontal reduction for min
|
|
556
566
|
vint8m1_t init_max_i8m1 = __riscv_vmv_v_x_i8m1(NK_I8_MAX, 1);
|
|
557
|
-
nk_i8_t min_val = __riscv_vmv_x_s_i8m1_i8(__riscv_vredmin_vs_i8m1_i8m1(min_i8m1, init_max_i8m1,
|
|
558
|
-
vbool8_t min_match_b8 = __riscv_vmseq_vx_i8m1_b8(min_i8m1, min_val,
|
|
559
|
-
vuint64m8_t sentinel_u64m8 = __riscv_vmv_v_x_u64m8(NK_U64_MAX,
|
|
560
|
-
vuint64m8_t min_cands_u64m8 = __riscv_vmerge_vvm_u64m8(sentinel_u64m8, min_indices_u64m8, min_match_b8,
|
|
567
|
+
nk_i8_t min_val = __riscv_vmv_x_s_i8m1_i8(__riscv_vredmin_vs_i8m1_i8m1(min_i8m1, init_max_i8m1, max_vector_length));
|
|
568
|
+
vbool8_t min_match_b8 = __riscv_vmseq_vx_i8m1_b8(min_i8m1, min_val, max_vector_length);
|
|
569
|
+
vuint64m8_t sentinel_u64m8 = __riscv_vmv_v_x_u64m8(NK_U64_MAX, max_vector_length);
|
|
570
|
+
vuint64m8_t min_cands_u64m8 = __riscv_vmerge_vvm_u64m8(sentinel_u64m8, min_indices_u64m8, min_match_b8,
|
|
571
|
+
max_vector_length);
|
|
561
572
|
vuint64m1_t init_umax_u64m1 = __riscv_vmv_v_x_u64m1(NK_U64_MAX, 1);
|
|
562
573
|
*min_value_ptr = min_val;
|
|
563
574
|
*min_index_ptr = (nk_size_t)__riscv_vmv_x_s_u64m1_u64(
|
|
564
|
-
__riscv_vredminu_vs_u64m8_u64m1(min_cands_u64m8, init_umax_u64m1,
|
|
575
|
+
__riscv_vredminu_vs_u64m8_u64m1(min_cands_u64m8, init_umax_u64m1, max_vector_length));
|
|
565
576
|
|
|
566
577
|
// Horizontal reduction for max
|
|
567
578
|
vint8m1_t init_min_i8m1 = __riscv_vmv_v_x_i8m1(NK_I8_MIN, 1);
|
|
568
|
-
nk_i8_t max_val = __riscv_vmv_x_s_i8m1_i8(__riscv_vredmax_vs_i8m1_i8m1(max_i8m1, init_min_i8m1,
|
|
569
|
-
vbool8_t max_match_b8 = __riscv_vmseq_vx_i8m1_b8(max_i8m1, max_val,
|
|
570
|
-
vuint64m8_t max_cands_u64m8 = __riscv_vmerge_vvm_u64m8(sentinel_u64m8, max_indices_u64m8, max_match_b8,
|
|
579
|
+
nk_i8_t max_val = __riscv_vmv_x_s_i8m1_i8(__riscv_vredmax_vs_i8m1_i8m1(max_i8m1, init_min_i8m1, max_vector_length));
|
|
580
|
+
vbool8_t max_match_b8 = __riscv_vmseq_vx_i8m1_b8(max_i8m1, max_val, max_vector_length);
|
|
581
|
+
vuint64m8_t max_cands_u64m8 = __riscv_vmerge_vvm_u64m8(sentinel_u64m8, max_indices_u64m8, max_match_b8,
|
|
582
|
+
max_vector_length);
|
|
571
583
|
*max_value_ptr = max_val;
|
|
572
584
|
*max_index_ptr = (nk_size_t)__riscv_vmv_x_s_u64m1_u64(
|
|
573
|
-
__riscv_vredminu_vs_u64m8_u64m1(max_cands_u64m8, init_umax_u64m1,
|
|
585
|
+
__riscv_vredminu_vs_u64m8_u64m1(max_cands_u64m8, init_umax_u64m1, max_vector_length));
|
|
574
586
|
}
|
|
575
587
|
|
|
576
588
|
NK_INTERNAL void nk_reduce_minmax_i8_rvv_strided_( //
|
|
577
589
|
nk_i8_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
|
|
578
590
|
nk_i8_t *min_value_ptr, nk_size_t *min_index_ptr, //
|
|
579
591
|
nk_i8_t *max_value_ptr, nk_size_t *max_index_ptr) {
|
|
580
|
-
nk_size_t
|
|
581
|
-
vint8m1_t min_i8m1 = __riscv_vmv_v_x_i8m1(NK_I8_MAX,
|
|
582
|
-
vint8m1_t max_i8m1 = __riscv_vmv_v_x_i8m1(NK_I8_MIN,
|
|
583
|
-
vuint64m8_t min_indices_u64m8 = __riscv_vmv_v_x_u64m8(0,
|
|
584
|
-
vuint64m8_t max_indices_u64m8 = __riscv_vmv_v_x_u64m8(0,
|
|
592
|
+
nk_size_t max_vector_length = __riscv_vsetvlmax_e8m1();
|
|
593
|
+
vint8m1_t min_i8m1 = __riscv_vmv_v_x_i8m1(NK_I8_MAX, max_vector_length);
|
|
594
|
+
vint8m1_t max_i8m1 = __riscv_vmv_v_x_i8m1(NK_I8_MIN, max_vector_length);
|
|
595
|
+
vuint64m8_t min_indices_u64m8 = __riscv_vmv_v_x_u64m8(0, max_vector_length);
|
|
596
|
+
vuint64m8_t max_indices_u64m8 = __riscv_vmv_v_x_u64m8(0, max_vector_length);
|
|
585
597
|
unsigned char const *ptr = (unsigned char const *)data_ptr;
|
|
586
598
|
|
|
587
599
|
nk_size_t offset = 0;
|
|
@@ -607,23 +619,25 @@ NK_INTERNAL void nk_reduce_minmax_i8_rvv_strided_( //
|
|
|
607
619
|
|
|
608
620
|
// Horizontal reduction for min
|
|
609
621
|
vint8m1_t init_max_i8m1 = __riscv_vmv_v_x_i8m1(NK_I8_MAX, 1);
|
|
610
|
-
nk_i8_t min_val = __riscv_vmv_x_s_i8m1_i8(__riscv_vredmin_vs_i8m1_i8m1(min_i8m1, init_max_i8m1,
|
|
611
|
-
vbool8_t min_match_b8 = __riscv_vmseq_vx_i8m1_b8(min_i8m1, min_val,
|
|
612
|
-
vuint64m8_t sentinel_u64m8 = __riscv_vmv_v_x_u64m8(NK_U64_MAX,
|
|
613
|
-
vuint64m8_t min_cands_u64m8 = __riscv_vmerge_vvm_u64m8(sentinel_u64m8, min_indices_u64m8, min_match_b8,
|
|
622
|
+
nk_i8_t min_val = __riscv_vmv_x_s_i8m1_i8(__riscv_vredmin_vs_i8m1_i8m1(min_i8m1, init_max_i8m1, max_vector_length));
|
|
623
|
+
vbool8_t min_match_b8 = __riscv_vmseq_vx_i8m1_b8(min_i8m1, min_val, max_vector_length);
|
|
624
|
+
vuint64m8_t sentinel_u64m8 = __riscv_vmv_v_x_u64m8(NK_U64_MAX, max_vector_length);
|
|
625
|
+
vuint64m8_t min_cands_u64m8 = __riscv_vmerge_vvm_u64m8(sentinel_u64m8, min_indices_u64m8, min_match_b8,
|
|
626
|
+
max_vector_length);
|
|
614
627
|
vuint64m1_t init_umax_u64m1 = __riscv_vmv_v_x_u64m1(NK_U64_MAX, 1);
|
|
615
628
|
*min_value_ptr = min_val;
|
|
616
629
|
*min_index_ptr = (nk_size_t)__riscv_vmv_x_s_u64m1_u64(
|
|
617
|
-
__riscv_vredminu_vs_u64m8_u64m1(min_cands_u64m8, init_umax_u64m1,
|
|
630
|
+
__riscv_vredminu_vs_u64m8_u64m1(min_cands_u64m8, init_umax_u64m1, max_vector_length));
|
|
618
631
|
|
|
619
632
|
// Horizontal reduction for max
|
|
620
633
|
vint8m1_t init_min_i8m1 = __riscv_vmv_v_x_i8m1(NK_I8_MIN, 1);
|
|
621
|
-
nk_i8_t max_val = __riscv_vmv_x_s_i8m1_i8(__riscv_vredmax_vs_i8m1_i8m1(max_i8m1, init_min_i8m1,
|
|
622
|
-
vbool8_t max_match_b8 = __riscv_vmseq_vx_i8m1_b8(max_i8m1, max_val,
|
|
623
|
-
vuint64m8_t max_cands_u64m8 = __riscv_vmerge_vvm_u64m8(sentinel_u64m8, max_indices_u64m8, max_match_b8,
|
|
634
|
+
nk_i8_t max_val = __riscv_vmv_x_s_i8m1_i8(__riscv_vredmax_vs_i8m1_i8m1(max_i8m1, init_min_i8m1, max_vector_length));
|
|
635
|
+
vbool8_t max_match_b8 = __riscv_vmseq_vx_i8m1_b8(max_i8m1, max_val, max_vector_length);
|
|
636
|
+
vuint64m8_t max_cands_u64m8 = __riscv_vmerge_vvm_u64m8(sentinel_u64m8, max_indices_u64m8, max_match_b8,
|
|
637
|
+
max_vector_length);
|
|
624
638
|
*max_value_ptr = max_val;
|
|
625
639
|
*max_index_ptr = (nk_size_t)__riscv_vmv_x_s_u64m1_u64(
|
|
626
|
-
__riscv_vredminu_vs_u64m8_u64m1(max_cands_u64m8, init_umax_u64m1,
|
|
640
|
+
__riscv_vredminu_vs_u64m8_u64m1(max_cands_u64m8, init_umax_u64m1, max_vector_length));
|
|
627
641
|
}
|
|
628
642
|
|
|
629
643
|
NK_PUBLIC void nk_reduce_minmax_i8_rvv( //
|
|
@@ -650,10 +664,10 @@ NK_PUBLIC void nk_reduce_minmax_i8_rvv( //
|
|
|
650
664
|
NK_INTERNAL void nk_reduce_moments_u8_rvv_contiguous_( //
|
|
651
665
|
nk_u8_t const *data_ptr, nk_size_t count, //
|
|
652
666
|
nk_u64_t *sum_ptr, nk_u64_t *sumsq_ptr) {
|
|
653
|
-
nk_size_t
|
|
667
|
+
nk_size_t max_vector_length = __riscv_vsetvlmax_e64m4();
|
|
654
668
|
nk_size_t vlmax_elements = __riscv_vsetvlmax_e8m1();
|
|
655
|
-
vuint64m4_t sum_u64m4 = __riscv_vmv_v_x_u64m4(0,
|
|
656
|
-
vuint64m4_t sumsq_u64m4 = __riscv_vmv_v_x_u64m4(0,
|
|
669
|
+
vuint64m4_t sum_u64m4 = __riscv_vmv_v_x_u64m4(0, max_vector_length);
|
|
670
|
+
vuint64m4_t sumsq_u64m4 = __riscv_vmv_v_x_u64m4(0, max_vector_length);
|
|
657
671
|
vuint8m1_t zero_u8m1 = __riscv_vmv_v_x_u8m1(0, vlmax_elements);
|
|
658
672
|
|
|
659
673
|
for (nk_size_t vector_length; count > 0; count -= vector_length, data_ptr += vector_length) {
|
|
@@ -666,8 +680,8 @@ NK_INTERNAL void nk_reduce_moments_u8_rvv_contiguous_( //
|
|
|
666
680
|
vuint64m8_t data_u64m8 = __riscv_vzext_vf2_u64m8(data_u32m4, vlmax_elements);
|
|
667
681
|
|
|
668
682
|
// Accumulate sum (split m8 into two m4)
|
|
669
|
-
sum_u64m4 = __riscv_vadd_vv_u64m4(sum_u64m4, __riscv_vget_v_u64m8_u64m4(data_u64m8, 0),
|
|
670
|
-
sum_u64m4 = __riscv_vadd_vv_u64m4(sum_u64m4, __riscv_vget_v_u64m8_u64m4(data_u64m8, 1),
|
|
683
|
+
sum_u64m4 = __riscv_vadd_vv_u64m4(sum_u64m4, __riscv_vget_v_u64m8_u64m4(data_u64m8, 0), vector_length);
|
|
684
|
+
sum_u64m4 = __riscv_vadd_vv_u64m4(sum_u64m4, __riscv_vget_v_u64m8_u64m4(data_u64m8, 1), vector_length);
|
|
671
685
|
|
|
672
686
|
// Sumsq: u8 × u8 → u16 (widening multiply)
|
|
673
687
|
vuint16m2_t squares_u16m2 = __riscv_vwmulu_vv_u16m2(data_u8m1, data_u8m1, vlmax_elements);
|
|
@@ -675,23 +689,23 @@ NK_INTERNAL void nk_reduce_moments_u8_rvv_contiguous_( //
|
|
|
675
689
|
vuint32m4_t squares_u32m4 = __riscv_vzext_vf2_u32m4(squares_u16m2, vlmax_elements);
|
|
676
690
|
vuint64m8_t squares_u64m8 = __riscv_vzext_vf2_u64m8(squares_u32m4, vlmax_elements);
|
|
677
691
|
|
|
678
|
-
sumsq_u64m4 = __riscv_vadd_vv_u64m4(sumsq_u64m4, __riscv_vget_v_u64m8_u64m4(squares_u64m8, 0),
|
|
679
|
-
sumsq_u64m4 = __riscv_vadd_vv_u64m4(sumsq_u64m4, __riscv_vget_v_u64m8_u64m4(squares_u64m8, 1),
|
|
692
|
+
sumsq_u64m4 = __riscv_vadd_vv_u64m4(sumsq_u64m4, __riscv_vget_v_u64m8_u64m4(squares_u64m8, 0), vector_length);
|
|
693
|
+
sumsq_u64m4 = __riscv_vadd_vv_u64m4(sumsq_u64m4, __riscv_vget_v_u64m8_u64m4(squares_u64m8, 1), vector_length);
|
|
680
694
|
}
|
|
681
695
|
|
|
682
696
|
// Horizontal reduction
|
|
683
697
|
vuint64m1_t zero_u64m1 = __riscv_vmv_v_x_u64m1(0, 1);
|
|
684
|
-
*sum_ptr = __riscv_vmv_x_s_u64m1_u64(__riscv_vredsum_vs_u64m4_u64m1(sum_u64m4, zero_u64m1,
|
|
685
|
-
*sumsq_ptr = __riscv_vmv_x_s_u64m1_u64(__riscv_vredsum_vs_u64m4_u64m1(sumsq_u64m4, zero_u64m1,
|
|
698
|
+
*sum_ptr = __riscv_vmv_x_s_u64m1_u64(__riscv_vredsum_vs_u64m4_u64m1(sum_u64m4, zero_u64m1, max_vector_length)),
|
|
699
|
+
*sumsq_ptr = __riscv_vmv_x_s_u64m1_u64(__riscv_vredsum_vs_u64m4_u64m1(sumsq_u64m4, zero_u64m1, max_vector_length));
|
|
686
700
|
}
|
|
687
701
|
|
|
688
702
|
NK_INTERNAL void nk_reduce_moments_u8_rvv_strided_( //
|
|
689
703
|
nk_u8_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
|
|
690
704
|
nk_u64_t *sum_ptr, nk_u64_t *sumsq_ptr) {
|
|
691
|
-
nk_size_t
|
|
705
|
+
nk_size_t max_vector_length = __riscv_vsetvlmax_e64m4();
|
|
692
706
|
nk_size_t vlmax_elements = __riscv_vsetvlmax_e8m1();
|
|
693
|
-
vuint64m4_t sum_u64m4 = __riscv_vmv_v_x_u64m4(0,
|
|
694
|
-
vuint64m4_t sumsq_u64m4 = __riscv_vmv_v_x_u64m4(0,
|
|
707
|
+
vuint64m4_t sum_u64m4 = __riscv_vmv_v_x_u64m4(0, max_vector_length);
|
|
708
|
+
vuint64m4_t sumsq_u64m4 = __riscv_vmv_v_x_u64m4(0, max_vector_length);
|
|
695
709
|
vuint8m1_t zero_u8m1 = __riscv_vmv_v_x_u8m1(0, vlmax_elements);
|
|
696
710
|
unsigned char const *ptr = (unsigned char const *)data_ptr;
|
|
697
711
|
|
|
@@ -706,8 +720,8 @@ NK_INTERNAL void nk_reduce_moments_u8_rvv_strided_( //
|
|
|
706
720
|
vuint64m8_t data_u64m8 = __riscv_vzext_vf2_u64m8(data_u32m4, vlmax_elements);
|
|
707
721
|
|
|
708
722
|
// Accumulate sum (split m8 into two m4)
|
|
709
|
-
sum_u64m4 = __riscv_vadd_vv_u64m4(sum_u64m4, __riscv_vget_v_u64m8_u64m4(data_u64m8, 0),
|
|
710
|
-
sum_u64m4 = __riscv_vadd_vv_u64m4(sum_u64m4, __riscv_vget_v_u64m8_u64m4(data_u64m8, 1),
|
|
723
|
+
sum_u64m4 = __riscv_vadd_vv_u64m4(sum_u64m4, __riscv_vget_v_u64m8_u64m4(data_u64m8, 0), vector_length);
|
|
724
|
+
sum_u64m4 = __riscv_vadd_vv_u64m4(sum_u64m4, __riscv_vget_v_u64m8_u64m4(data_u64m8, 1), vector_length);
|
|
711
725
|
|
|
712
726
|
// Sumsq: u8 × u8 → u16 (widening multiply)
|
|
713
727
|
vuint16m2_t squares_u16m2 = __riscv_vwmulu_vv_u16m2(data_u8m1, data_u8m1, vlmax_elements);
|
|
@@ -715,14 +729,14 @@ NK_INTERNAL void nk_reduce_moments_u8_rvv_strided_( //
|
|
|
715
729
|
vuint32m4_t squares_u32m4 = __riscv_vzext_vf2_u32m4(squares_u16m2, vlmax_elements);
|
|
716
730
|
vuint64m8_t squares_u64m8 = __riscv_vzext_vf2_u64m8(squares_u32m4, vlmax_elements);
|
|
717
731
|
|
|
718
|
-
sumsq_u64m4 = __riscv_vadd_vv_u64m4(sumsq_u64m4, __riscv_vget_v_u64m8_u64m4(squares_u64m8, 0),
|
|
719
|
-
sumsq_u64m4 = __riscv_vadd_vv_u64m4(sumsq_u64m4, __riscv_vget_v_u64m8_u64m4(squares_u64m8, 1),
|
|
732
|
+
sumsq_u64m4 = __riscv_vadd_vv_u64m4(sumsq_u64m4, __riscv_vget_v_u64m8_u64m4(squares_u64m8, 0), vector_length);
|
|
733
|
+
sumsq_u64m4 = __riscv_vadd_vv_u64m4(sumsq_u64m4, __riscv_vget_v_u64m8_u64m4(squares_u64m8, 1), vector_length);
|
|
720
734
|
}
|
|
721
735
|
|
|
722
736
|
// Horizontal reduction
|
|
723
737
|
vuint64m1_t zero_u64m1 = __riscv_vmv_v_x_u64m1(0, 1);
|
|
724
|
-
*sum_ptr = __riscv_vmv_x_s_u64m1_u64(__riscv_vredsum_vs_u64m4_u64m1(sum_u64m4, zero_u64m1,
|
|
725
|
-
*sumsq_ptr = __riscv_vmv_x_s_u64m1_u64(__riscv_vredsum_vs_u64m4_u64m1(sumsq_u64m4, zero_u64m1,
|
|
738
|
+
*sum_ptr = __riscv_vmv_x_s_u64m1_u64(__riscv_vredsum_vs_u64m4_u64m1(sum_u64m4, zero_u64m1, max_vector_length)),
|
|
739
|
+
*sumsq_ptr = __riscv_vmv_x_s_u64m1_u64(__riscv_vredsum_vs_u64m4_u64m1(sumsq_u64m4, zero_u64m1, max_vector_length));
|
|
726
740
|
}
|
|
727
741
|
|
|
728
742
|
NK_PUBLIC void nk_reduce_moments_u8_rvv( //
|
|
@@ -741,11 +755,11 @@ NK_INTERNAL void nk_reduce_minmax_u8_rvv_contiguous_( //
|
|
|
741
755
|
nk_u8_t const *data_ptr, nk_size_t count, //
|
|
742
756
|
nk_u8_t *min_value_ptr, nk_size_t *min_index_ptr, //
|
|
743
757
|
nk_u8_t *max_value_ptr, nk_size_t *max_index_ptr) {
|
|
744
|
-
nk_size_t
|
|
745
|
-
vuint8m1_t min_u8m1 = __riscv_vmv_v_x_u8m1(NK_U8_MAX,
|
|
746
|
-
vuint8m1_t max_u8m1 = __riscv_vmv_v_x_u8m1(NK_U8_MIN,
|
|
747
|
-
vuint64m8_t min_indices_u64m8 = __riscv_vmv_v_x_u64m8(0,
|
|
748
|
-
vuint64m8_t max_indices_u64m8 = __riscv_vmv_v_x_u64m8(0,
|
|
758
|
+
nk_size_t max_vector_length = __riscv_vsetvlmax_e8m1();
|
|
759
|
+
vuint8m1_t min_u8m1 = __riscv_vmv_v_x_u8m1(NK_U8_MAX, max_vector_length);
|
|
760
|
+
vuint8m1_t max_u8m1 = __riscv_vmv_v_x_u8m1(NK_U8_MIN, max_vector_length);
|
|
761
|
+
vuint64m8_t min_indices_u64m8 = __riscv_vmv_v_x_u64m8(0, max_vector_length);
|
|
762
|
+
vuint64m8_t max_indices_u64m8 = __riscv_vmv_v_x_u64m8(0, max_vector_length);
|
|
749
763
|
|
|
750
764
|
nk_size_t offset = 0;
|
|
751
765
|
for (nk_size_t vector_length; count > 0;
|
|
@@ -770,34 +784,38 @@ NK_INTERNAL void nk_reduce_minmax_u8_rvv_contiguous_( //
|
|
|
770
784
|
|
|
771
785
|
// Horizontal reduction for min
|
|
772
786
|
vuint8m1_t init_max_u8m1 = __riscv_vmv_v_x_u8m1(NK_U8_MAX, 1);
|
|
773
|
-
nk_u8_t min_val = __riscv_vmv_x_s_u8m1_u8(
|
|
774
|
-
|
|
775
|
-
|
|
776
|
-
vuint64m8_t
|
|
787
|
+
nk_u8_t min_val = __riscv_vmv_x_s_u8m1_u8(
|
|
788
|
+
__riscv_vredminu_vs_u8m1_u8m1(min_u8m1, init_max_u8m1, max_vector_length));
|
|
789
|
+
vbool8_t min_match_b8 = __riscv_vmseq_vx_u8m1_b8(min_u8m1, min_val, max_vector_length);
|
|
790
|
+
vuint64m8_t sentinel_u64m8 = __riscv_vmv_v_x_u64m8(NK_U64_MAX, max_vector_length);
|
|
791
|
+
vuint64m8_t min_cands_u64m8 = __riscv_vmerge_vvm_u64m8(sentinel_u64m8, min_indices_u64m8, min_match_b8,
|
|
792
|
+
max_vector_length);
|
|
777
793
|
vuint64m1_t init_umax_u64m1 = __riscv_vmv_v_x_u64m1(NK_U64_MAX, 1);
|
|
778
794
|
*min_value_ptr = min_val;
|
|
779
795
|
*min_index_ptr = (nk_size_t)__riscv_vmv_x_s_u64m1_u64(
|
|
780
|
-
__riscv_vredminu_vs_u64m8_u64m1(min_cands_u64m8, init_umax_u64m1,
|
|
796
|
+
__riscv_vredminu_vs_u64m8_u64m1(min_cands_u64m8, init_umax_u64m1, max_vector_length));
|
|
781
797
|
|
|
782
798
|
// Horizontal reduction for max
|
|
783
799
|
vuint8m1_t init_min_u8m1 = __riscv_vmv_v_x_u8m1(NK_U8_MIN, 1);
|
|
784
|
-
nk_u8_t max_val = __riscv_vmv_x_s_u8m1_u8(
|
|
785
|
-
|
|
786
|
-
|
|
800
|
+
nk_u8_t max_val = __riscv_vmv_x_s_u8m1_u8(
|
|
801
|
+
__riscv_vredmaxu_vs_u8m1_u8m1(max_u8m1, init_min_u8m1, max_vector_length));
|
|
802
|
+
vbool8_t max_match_b8 = __riscv_vmseq_vx_u8m1_b8(max_u8m1, max_val, max_vector_length);
|
|
803
|
+
vuint64m8_t max_cands_u64m8 = __riscv_vmerge_vvm_u64m8(sentinel_u64m8, max_indices_u64m8, max_match_b8,
|
|
804
|
+
max_vector_length);
|
|
787
805
|
*max_value_ptr = max_val;
|
|
788
806
|
*max_index_ptr = (nk_size_t)__riscv_vmv_x_s_u64m1_u64(
|
|
789
|
-
__riscv_vredminu_vs_u64m8_u64m1(max_cands_u64m8, init_umax_u64m1,
|
|
807
|
+
__riscv_vredminu_vs_u64m8_u64m1(max_cands_u64m8, init_umax_u64m1, max_vector_length));
|
|
790
808
|
}
|
|
791
809
|
|
|
792
810
|
NK_INTERNAL void nk_reduce_minmax_u8_rvv_strided_( //
|
|
793
811
|
nk_u8_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
|
|
794
812
|
nk_u8_t *min_value_ptr, nk_size_t *min_index_ptr, //
|
|
795
813
|
nk_u8_t *max_value_ptr, nk_size_t *max_index_ptr) {
|
|
796
|
-
nk_size_t
|
|
797
|
-
vuint8m1_t min_u8m1 = __riscv_vmv_v_x_u8m1(NK_U8_MAX,
|
|
798
|
-
vuint8m1_t max_u8m1 = __riscv_vmv_v_x_u8m1(NK_U8_MIN,
|
|
799
|
-
vuint64m8_t min_indices_u64m8 = __riscv_vmv_v_x_u64m8(0,
|
|
800
|
-
vuint64m8_t max_indices_u64m8 = __riscv_vmv_v_x_u64m8(0,
|
|
814
|
+
nk_size_t max_vector_length = __riscv_vsetvlmax_e8m1();
|
|
815
|
+
vuint8m1_t min_u8m1 = __riscv_vmv_v_x_u8m1(NK_U8_MAX, max_vector_length);
|
|
816
|
+
vuint8m1_t max_u8m1 = __riscv_vmv_v_x_u8m1(NK_U8_MIN, max_vector_length);
|
|
817
|
+
vuint64m8_t min_indices_u64m8 = __riscv_vmv_v_x_u64m8(0, max_vector_length);
|
|
818
|
+
vuint64m8_t max_indices_u64m8 = __riscv_vmv_v_x_u64m8(0, max_vector_length);
|
|
801
819
|
unsigned char const *ptr = (unsigned char const *)data_ptr;
|
|
802
820
|
|
|
803
821
|
nk_size_t offset = 0;
|
|
@@ -823,23 +841,27 @@ NK_INTERNAL void nk_reduce_minmax_u8_rvv_strided_( //
|
|
|
823
841
|
|
|
824
842
|
// Horizontal reduction for min
|
|
825
843
|
vuint8m1_t init_max_u8m1 = __riscv_vmv_v_x_u8m1(NK_U8_MAX, 1);
|
|
826
|
-
nk_u8_t min_val = __riscv_vmv_x_s_u8m1_u8(
|
|
827
|
-
|
|
828
|
-
|
|
829
|
-
vuint64m8_t
|
|
844
|
+
nk_u8_t min_val = __riscv_vmv_x_s_u8m1_u8(
|
|
845
|
+
__riscv_vredminu_vs_u8m1_u8m1(min_u8m1, init_max_u8m1, max_vector_length));
|
|
846
|
+
vbool8_t min_match_b8 = __riscv_vmseq_vx_u8m1_b8(min_u8m1, min_val, max_vector_length);
|
|
847
|
+
vuint64m8_t sentinel_u64m8 = __riscv_vmv_v_x_u64m8(NK_U64_MAX, max_vector_length);
|
|
848
|
+
vuint64m8_t min_cands_u64m8 = __riscv_vmerge_vvm_u64m8(sentinel_u64m8, min_indices_u64m8, min_match_b8,
|
|
849
|
+
max_vector_length);
|
|
830
850
|
vuint64m1_t init_umax_u64m1 = __riscv_vmv_v_x_u64m1(NK_U64_MAX, 1);
|
|
831
851
|
*min_value_ptr = min_val;
|
|
832
852
|
*min_index_ptr = (nk_size_t)__riscv_vmv_x_s_u64m1_u64(
|
|
833
|
-
__riscv_vredminu_vs_u64m8_u64m1(min_cands_u64m8, init_umax_u64m1,
|
|
853
|
+
__riscv_vredminu_vs_u64m8_u64m1(min_cands_u64m8, init_umax_u64m1, max_vector_length));
|
|
834
854
|
|
|
835
855
|
// Horizontal reduction for max
|
|
836
856
|
vuint8m1_t init_min_u8m1 = __riscv_vmv_v_x_u8m1(NK_U8_MIN, 1);
|
|
837
|
-
nk_u8_t max_val = __riscv_vmv_x_s_u8m1_u8(
|
|
838
|
-
|
|
839
|
-
|
|
857
|
+
nk_u8_t max_val = __riscv_vmv_x_s_u8m1_u8(
|
|
858
|
+
__riscv_vredmaxu_vs_u8m1_u8m1(max_u8m1, init_min_u8m1, max_vector_length));
|
|
859
|
+
vbool8_t max_match_b8 = __riscv_vmseq_vx_u8m1_b8(max_u8m1, max_val, max_vector_length);
|
|
860
|
+
vuint64m8_t max_cands_u64m8 = __riscv_vmerge_vvm_u64m8(sentinel_u64m8, max_indices_u64m8, max_match_b8,
|
|
861
|
+
max_vector_length);
|
|
840
862
|
*max_value_ptr = max_val;
|
|
841
863
|
*max_index_ptr = (nk_size_t)__riscv_vmv_x_s_u64m1_u64(
|
|
842
|
-
__riscv_vredminu_vs_u64m8_u64m1(max_cands_u64m8, init_umax_u64m1,
|
|
864
|
+
__riscv_vredminu_vs_u64m8_u64m1(max_cands_u64m8, init_umax_u64m1, max_vector_length));
|
|
843
865
|
}
|
|
844
866
|
|
|
845
867
|
NK_PUBLIC void nk_reduce_minmax_u8_rvv( //
|
|
@@ -866,9 +888,9 @@ NK_PUBLIC void nk_reduce_minmax_u8_rvv( //
|
|
|
866
888
|
NK_INTERNAL void nk_reduce_moments_i16_rvv_contiguous_( //
|
|
867
889
|
nk_i16_t const *data_ptr, nk_size_t count, //
|
|
868
890
|
nk_i64_t *sum_ptr, nk_u64_t *sumsq_ptr) {
|
|
869
|
-
nk_size_t
|
|
870
|
-
vint64m4_t sum_i64m4 = __riscv_vmv_v_x_i64m4(0,
|
|
871
|
-
vuint64m4_t sumsq_u64m4 = __riscv_vmv_v_x_u64m4(0,
|
|
891
|
+
nk_size_t max_vector_length = __riscv_vsetvlmax_e64m4();
|
|
892
|
+
vint64m4_t sum_i64m4 = __riscv_vmv_v_x_i64m4(0, max_vector_length);
|
|
893
|
+
vuint64m4_t sumsq_u64m4 = __riscv_vmv_v_x_u64m4(0, max_vector_length);
|
|
872
894
|
|
|
873
895
|
for (nk_size_t vector_length; count > 0; count -= vector_length, data_ptr += vector_length) {
|
|
874
896
|
vector_length = __riscv_vsetvl_e16m1(count);
|
|
@@ -889,18 +911,18 @@ NK_INTERNAL void nk_reduce_moments_i16_rvv_contiguous_( //
|
|
|
889
911
|
|
|
890
912
|
// Horizontal reduction
|
|
891
913
|
vint64m1_t zero_i64m1 = __riscv_vmv_v_x_i64m1(0, 1);
|
|
892
|
-
*sum_ptr = __riscv_vmv_x_s_i64m1_i64(__riscv_vredsum_vs_i64m4_i64m1(sum_i64m4, zero_i64m1,
|
|
914
|
+
*sum_ptr = __riscv_vmv_x_s_i64m1_i64(__riscv_vredsum_vs_i64m4_i64m1(sum_i64m4, zero_i64m1, max_vector_length));
|
|
893
915
|
|
|
894
916
|
vuint64m1_t zero_u64m1 = __riscv_vmv_v_x_u64m1(0, 1);
|
|
895
|
-
*sumsq_ptr = __riscv_vmv_x_s_u64m1_u64(__riscv_vredsum_vs_u64m4_u64m1(sumsq_u64m4, zero_u64m1,
|
|
917
|
+
*sumsq_ptr = __riscv_vmv_x_s_u64m1_u64(__riscv_vredsum_vs_u64m4_u64m1(sumsq_u64m4, zero_u64m1, max_vector_length));
|
|
896
918
|
}
|
|
897
919
|
|
|
898
920
|
NK_INTERNAL void nk_reduce_moments_i16_rvv_strided_( //
|
|
899
921
|
nk_i16_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
|
|
900
922
|
nk_i64_t *sum_ptr, nk_u64_t *sumsq_ptr) {
|
|
901
|
-
nk_size_t
|
|
902
|
-
vint64m4_t sum_i64m4 = __riscv_vmv_v_x_i64m4(0,
|
|
903
|
-
vuint64m4_t sumsq_u64m4 = __riscv_vmv_v_x_u64m4(0,
|
|
923
|
+
nk_size_t max_vector_length = __riscv_vsetvlmax_e64m4();
|
|
924
|
+
vint64m4_t sum_i64m4 = __riscv_vmv_v_x_i64m4(0, max_vector_length);
|
|
925
|
+
vuint64m4_t sumsq_u64m4 = __riscv_vmv_v_x_u64m4(0, max_vector_length);
|
|
904
926
|
unsigned char const *ptr = (unsigned char const *)data_ptr;
|
|
905
927
|
|
|
906
928
|
for (nk_size_t vector_length; count > 0; count -= vector_length, ptr += vector_length * stride_bytes) {
|
|
@@ -922,10 +944,10 @@ NK_INTERNAL void nk_reduce_moments_i16_rvv_strided_( //
|
|
|
922
944
|
|
|
923
945
|
// Horizontal reduction
|
|
924
946
|
vint64m1_t zero_i64m1 = __riscv_vmv_v_x_i64m1(0, 1);
|
|
925
|
-
*sum_ptr = __riscv_vmv_x_s_i64m1_i64(__riscv_vredsum_vs_i64m4_i64m1(sum_i64m4, zero_i64m1,
|
|
947
|
+
*sum_ptr = __riscv_vmv_x_s_i64m1_i64(__riscv_vredsum_vs_i64m4_i64m1(sum_i64m4, zero_i64m1, max_vector_length));
|
|
926
948
|
|
|
927
949
|
vuint64m1_t zero_u64m1 = __riscv_vmv_v_x_u64m1(0, 1);
|
|
928
|
-
*sumsq_ptr = __riscv_vmv_x_s_u64m1_u64(__riscv_vredsum_vs_u64m4_u64m1(sumsq_u64m4, zero_u64m1,
|
|
950
|
+
*sumsq_ptr = __riscv_vmv_x_s_u64m1_u64(__riscv_vredsum_vs_u64m4_u64m1(sumsq_u64m4, zero_u64m1, max_vector_length));
|
|
929
951
|
}
|
|
930
952
|
|
|
931
953
|
NK_PUBLIC void nk_reduce_moments_i16_rvv( //
|
|
@@ -944,11 +966,11 @@ NK_INTERNAL void nk_reduce_minmax_i16_rvv_contiguous_( //
|
|
|
944
966
|
nk_i16_t const *data_ptr, nk_size_t count, //
|
|
945
967
|
nk_i16_t *min_value_ptr, nk_size_t *min_index_ptr, //
|
|
946
968
|
nk_i16_t *max_value_ptr, nk_size_t *max_index_ptr) {
|
|
947
|
-
nk_size_t
|
|
948
|
-
vint16m1_t min_i16m1 = __riscv_vmv_v_x_i16m1(NK_I16_MAX,
|
|
949
|
-
vint16m1_t max_i16m1 = __riscv_vmv_v_x_i16m1(NK_I16_MIN,
|
|
950
|
-
vuint64m4_t min_indices_u64m4 = __riscv_vmv_v_x_u64m4(0,
|
|
951
|
-
vuint64m4_t max_indices_u64m4 = __riscv_vmv_v_x_u64m4(0,
|
|
969
|
+
nk_size_t max_vector_length = __riscv_vsetvlmax_e16m1();
|
|
970
|
+
vint16m1_t min_i16m1 = __riscv_vmv_v_x_i16m1(NK_I16_MAX, max_vector_length);
|
|
971
|
+
vint16m1_t max_i16m1 = __riscv_vmv_v_x_i16m1(NK_I16_MIN, max_vector_length);
|
|
972
|
+
vuint64m4_t min_indices_u64m4 = __riscv_vmv_v_x_u64m4(0, max_vector_length);
|
|
973
|
+
vuint64m4_t max_indices_u64m4 = __riscv_vmv_v_x_u64m4(0, max_vector_length);
|
|
952
974
|
|
|
953
975
|
nk_size_t offset = 0;
|
|
954
976
|
for (nk_size_t vector_length; count > 0;
|
|
@@ -971,34 +993,38 @@ NK_INTERNAL void nk_reduce_minmax_i16_rvv_contiguous_( //
|
|
|
971
993
|
|
|
972
994
|
// Horizontal reduction for min
|
|
973
995
|
vint16m1_t init_max_i16m1 = __riscv_vmv_v_x_i16m1(NK_I16_MAX, 1);
|
|
974
|
-
nk_i16_t min_val = __riscv_vmv_x_s_i16m1_i16(
|
|
975
|
-
|
|
976
|
-
|
|
977
|
-
vuint64m4_t
|
|
996
|
+
nk_i16_t min_val = __riscv_vmv_x_s_i16m1_i16(
|
|
997
|
+
__riscv_vredmin_vs_i16m1_i16m1(min_i16m1, init_max_i16m1, max_vector_length));
|
|
998
|
+
vbool16_t min_match_b16 = __riscv_vmseq_vx_i16m1_b16(min_i16m1, min_val, max_vector_length);
|
|
999
|
+
vuint64m4_t sentinel_u64m4 = __riscv_vmv_v_x_u64m4(NK_U64_MAX, max_vector_length);
|
|
1000
|
+
vuint64m4_t min_cands_u64m4 = __riscv_vmerge_vvm_u64m4(sentinel_u64m4, min_indices_u64m4, min_match_b16,
|
|
1001
|
+
max_vector_length);
|
|
978
1002
|
vuint64m1_t init_umax_u64m1 = __riscv_vmv_v_x_u64m1(NK_U64_MAX, 1);
|
|
979
1003
|
*min_value_ptr = min_val;
|
|
980
1004
|
*min_index_ptr = (nk_size_t)__riscv_vmv_x_s_u64m1_u64(
|
|
981
|
-
__riscv_vredminu_vs_u64m4_u64m1(min_cands_u64m4, init_umax_u64m1,
|
|
1005
|
+
__riscv_vredminu_vs_u64m4_u64m1(min_cands_u64m4, init_umax_u64m1, max_vector_length));
|
|
982
1006
|
|
|
983
1007
|
// Horizontal reduction for max
|
|
984
1008
|
vint16m1_t init_min_i16m1 = __riscv_vmv_v_x_i16m1(NK_I16_MIN, 1);
|
|
985
|
-
nk_i16_t max_val = __riscv_vmv_x_s_i16m1_i16(
|
|
986
|
-
|
|
987
|
-
|
|
1009
|
+
nk_i16_t max_val = __riscv_vmv_x_s_i16m1_i16(
|
|
1010
|
+
__riscv_vredmax_vs_i16m1_i16m1(max_i16m1, init_min_i16m1, max_vector_length));
|
|
1011
|
+
vbool16_t max_match_b16 = __riscv_vmseq_vx_i16m1_b16(max_i16m1, max_val, max_vector_length);
|
|
1012
|
+
vuint64m4_t max_cands_u64m4 = __riscv_vmerge_vvm_u64m4(sentinel_u64m4, max_indices_u64m4, max_match_b16,
|
|
1013
|
+
max_vector_length);
|
|
988
1014
|
*max_value_ptr = max_val;
|
|
989
1015
|
*max_index_ptr = (nk_size_t)__riscv_vmv_x_s_u64m1_u64(
|
|
990
|
-
__riscv_vredminu_vs_u64m4_u64m1(max_cands_u64m4, init_umax_u64m1,
|
|
1016
|
+
__riscv_vredminu_vs_u64m4_u64m1(max_cands_u64m4, init_umax_u64m1, max_vector_length));
|
|
991
1017
|
}
|
|
992
1018
|
|
|
993
1019
|
NK_INTERNAL void nk_reduce_minmax_i16_rvv_strided_( //
|
|
994
1020
|
nk_i16_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
|
|
995
1021
|
nk_i16_t *min_value_ptr, nk_size_t *min_index_ptr, //
|
|
996
1022
|
nk_i16_t *max_value_ptr, nk_size_t *max_index_ptr) {
|
|
997
|
-
nk_size_t
|
|
998
|
-
vint16m1_t min_i16m1 = __riscv_vmv_v_x_i16m1(NK_I16_MAX,
|
|
999
|
-
vint16m1_t max_i16m1 = __riscv_vmv_v_x_i16m1(NK_I16_MIN,
|
|
1000
|
-
vuint64m4_t min_indices_u64m4 = __riscv_vmv_v_x_u64m4(0,
|
|
1001
|
-
vuint64m4_t max_indices_u64m4 = __riscv_vmv_v_x_u64m4(0,
|
|
1023
|
+
nk_size_t max_vector_length = __riscv_vsetvlmax_e16m1();
|
|
1024
|
+
vint16m1_t min_i16m1 = __riscv_vmv_v_x_i16m1(NK_I16_MAX, max_vector_length);
|
|
1025
|
+
vint16m1_t max_i16m1 = __riscv_vmv_v_x_i16m1(NK_I16_MIN, max_vector_length);
|
|
1026
|
+
vuint64m4_t min_indices_u64m4 = __riscv_vmv_v_x_u64m4(0, max_vector_length);
|
|
1027
|
+
vuint64m4_t max_indices_u64m4 = __riscv_vmv_v_x_u64m4(0, max_vector_length);
|
|
1002
1028
|
unsigned char const *ptr = (unsigned char const *)data_ptr;
|
|
1003
1029
|
|
|
1004
1030
|
nk_size_t offset = 0;
|
|
@@ -1022,23 +1048,27 @@ NK_INTERNAL void nk_reduce_minmax_i16_rvv_strided_( //
|
|
|
1022
1048
|
|
|
1023
1049
|
// Horizontal reduction for min
|
|
1024
1050
|
vint16m1_t init_max_i16m1 = __riscv_vmv_v_x_i16m1(NK_I16_MAX, 1);
|
|
1025
|
-
nk_i16_t min_val = __riscv_vmv_x_s_i16m1_i16(
|
|
1026
|
-
|
|
1027
|
-
|
|
1028
|
-
vuint64m4_t
|
|
1051
|
+
nk_i16_t min_val = __riscv_vmv_x_s_i16m1_i16(
|
|
1052
|
+
__riscv_vredmin_vs_i16m1_i16m1(min_i16m1, init_max_i16m1, max_vector_length));
|
|
1053
|
+
vbool16_t min_match_b16 = __riscv_vmseq_vx_i16m1_b16(min_i16m1, min_val, max_vector_length);
|
|
1054
|
+
vuint64m4_t sentinel_u64m4 = __riscv_vmv_v_x_u64m4(NK_U64_MAX, max_vector_length);
|
|
1055
|
+
vuint64m4_t min_cands_u64m4 = __riscv_vmerge_vvm_u64m4(sentinel_u64m4, min_indices_u64m4, min_match_b16,
|
|
1056
|
+
max_vector_length);
|
|
1029
1057
|
vuint64m1_t init_umax_u64m1 = __riscv_vmv_v_x_u64m1(NK_U64_MAX, 1);
|
|
1030
1058
|
*min_value_ptr = min_val;
|
|
1031
1059
|
*min_index_ptr = (nk_size_t)__riscv_vmv_x_s_u64m1_u64(
|
|
1032
|
-
__riscv_vredminu_vs_u64m4_u64m1(min_cands_u64m4, init_umax_u64m1,
|
|
1060
|
+
__riscv_vredminu_vs_u64m4_u64m1(min_cands_u64m4, init_umax_u64m1, max_vector_length));
|
|
1033
1061
|
|
|
1034
1062
|
// Horizontal reduction for max
|
|
1035
1063
|
vint16m1_t init_min_i16m1 = __riscv_vmv_v_x_i16m1(NK_I16_MIN, 1);
|
|
1036
|
-
nk_i16_t max_val = __riscv_vmv_x_s_i16m1_i16(
|
|
1037
|
-
|
|
1038
|
-
|
|
1064
|
+
nk_i16_t max_val = __riscv_vmv_x_s_i16m1_i16(
|
|
1065
|
+
__riscv_vredmax_vs_i16m1_i16m1(max_i16m1, init_min_i16m1, max_vector_length));
|
|
1066
|
+
vbool16_t max_match_b16 = __riscv_vmseq_vx_i16m1_b16(max_i16m1, max_val, max_vector_length);
|
|
1067
|
+
vuint64m4_t max_cands_u64m4 = __riscv_vmerge_vvm_u64m4(sentinel_u64m4, max_indices_u64m4, max_match_b16,
|
|
1068
|
+
max_vector_length);
|
|
1039
1069
|
*max_value_ptr = max_val;
|
|
1040
1070
|
*max_index_ptr = (nk_size_t)__riscv_vmv_x_s_u64m1_u64(
|
|
1041
|
-
__riscv_vredminu_vs_u64m4_u64m1(max_cands_u64m4, init_umax_u64m1,
|
|
1071
|
+
__riscv_vredminu_vs_u64m4_u64m1(max_cands_u64m4, init_umax_u64m1, max_vector_length));
|
|
1042
1072
|
}
|
|
1043
1073
|
|
|
1044
1074
|
NK_PUBLIC void nk_reduce_minmax_i16_rvv( //
|
|
@@ -1065,9 +1095,9 @@ NK_PUBLIC void nk_reduce_minmax_i16_rvv( //
|
|
|
1065
1095
|
NK_INTERNAL void nk_reduce_moments_u16_rvv_contiguous_( //
|
|
1066
1096
|
nk_u16_t const *data_ptr, nk_size_t count, //
|
|
1067
1097
|
nk_u64_t *sum_ptr, nk_u64_t *sumsq_ptr) {
|
|
1068
|
-
nk_size_t
|
|
1069
|
-
vuint64m4_t sum_u64m4 = __riscv_vmv_v_x_u64m4(0,
|
|
1070
|
-
vuint64m4_t sumsq_u64m4 = __riscv_vmv_v_x_u64m4(0,
|
|
1098
|
+
nk_size_t max_vector_length = __riscv_vsetvlmax_e64m4();
|
|
1099
|
+
vuint64m4_t sum_u64m4 = __riscv_vmv_v_x_u64m4(0, max_vector_length);
|
|
1100
|
+
vuint64m4_t sumsq_u64m4 = __riscv_vmv_v_x_u64m4(0, max_vector_length);
|
|
1071
1101
|
|
|
1072
1102
|
for (nk_size_t vector_length; count > 0; count -= vector_length, data_ptr += vector_length) {
|
|
1073
1103
|
vector_length = __riscv_vsetvl_e16m1(count);
|
|
@@ -1087,16 +1117,16 @@ NK_INTERNAL void nk_reduce_moments_u16_rvv_contiguous_( //
|
|
|
1087
1117
|
|
|
1088
1118
|
// Horizontal reduction
|
|
1089
1119
|
vuint64m1_t zero_u64m1 = __riscv_vmv_v_x_u64m1(0, 1);
|
|
1090
|
-
*sum_ptr = __riscv_vmv_x_s_u64m1_u64(__riscv_vredsum_vs_u64m4_u64m1(sum_u64m4, zero_u64m1,
|
|
1091
|
-
*sumsq_ptr = __riscv_vmv_x_s_u64m1_u64(__riscv_vredsum_vs_u64m4_u64m1(sumsq_u64m4, zero_u64m1,
|
|
1120
|
+
*sum_ptr = __riscv_vmv_x_s_u64m1_u64(__riscv_vredsum_vs_u64m4_u64m1(sum_u64m4, zero_u64m1, max_vector_length)),
|
|
1121
|
+
*sumsq_ptr = __riscv_vmv_x_s_u64m1_u64(__riscv_vredsum_vs_u64m4_u64m1(sumsq_u64m4, zero_u64m1, max_vector_length));
|
|
1092
1122
|
}
|
|
1093
1123
|
|
|
1094
1124
|
NK_INTERNAL void nk_reduce_moments_u16_rvv_strided_( //
|
|
1095
1125
|
nk_u16_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
|
|
1096
1126
|
nk_u64_t *sum_ptr, nk_u64_t *sumsq_ptr) {
|
|
1097
|
-
nk_size_t
|
|
1098
|
-
vuint64m4_t sum_u64m4 = __riscv_vmv_v_x_u64m4(0,
|
|
1099
|
-
vuint64m4_t sumsq_u64m4 = __riscv_vmv_v_x_u64m4(0,
|
|
1127
|
+
nk_size_t max_vector_length = __riscv_vsetvlmax_e64m4();
|
|
1128
|
+
vuint64m4_t sum_u64m4 = __riscv_vmv_v_x_u64m4(0, max_vector_length);
|
|
1129
|
+
vuint64m4_t sumsq_u64m4 = __riscv_vmv_v_x_u64m4(0, max_vector_length);
|
|
1100
1130
|
unsigned char const *ptr = (unsigned char const *)data_ptr;
|
|
1101
1131
|
|
|
1102
1132
|
for (nk_size_t vector_length; count > 0; count -= vector_length, ptr += vector_length * stride_bytes) {
|
|
@@ -1117,8 +1147,8 @@ NK_INTERNAL void nk_reduce_moments_u16_rvv_strided_( //
|
|
|
1117
1147
|
|
|
1118
1148
|
// Horizontal reduction
|
|
1119
1149
|
vuint64m1_t zero_u64m1 = __riscv_vmv_v_x_u64m1(0, 1);
|
|
1120
|
-
*sum_ptr = __riscv_vmv_x_s_u64m1_u64(__riscv_vredsum_vs_u64m4_u64m1(sum_u64m4, zero_u64m1,
|
|
1121
|
-
*sumsq_ptr = __riscv_vmv_x_s_u64m1_u64(__riscv_vredsum_vs_u64m4_u64m1(sumsq_u64m4, zero_u64m1,
|
|
1150
|
+
*sum_ptr = __riscv_vmv_x_s_u64m1_u64(__riscv_vredsum_vs_u64m4_u64m1(sum_u64m4, zero_u64m1, max_vector_length)),
|
|
1151
|
+
*sumsq_ptr = __riscv_vmv_x_s_u64m1_u64(__riscv_vredsum_vs_u64m4_u64m1(sumsq_u64m4, zero_u64m1, max_vector_length));
|
|
1122
1152
|
}
|
|
1123
1153
|
|
|
1124
1154
|
NK_PUBLIC void nk_reduce_moments_u16_rvv( //
|
|
@@ -1137,11 +1167,11 @@ NK_INTERNAL void nk_reduce_minmax_u16_rvv_contiguous_( //
|
|
|
1137
1167
|
nk_u16_t const *data_ptr, nk_size_t count, //
|
|
1138
1168
|
nk_u16_t *min_value_ptr, nk_size_t *min_index_ptr, //
|
|
1139
1169
|
nk_u16_t *max_value_ptr, nk_size_t *max_index_ptr) {
|
|
1140
|
-
nk_size_t
|
|
1141
|
-
vuint16m1_t min_u16m1 = __riscv_vmv_v_x_u16m1(NK_U16_MAX,
|
|
1142
|
-
vuint16m1_t max_u16m1 = __riscv_vmv_v_x_u16m1(NK_U16_MIN,
|
|
1143
|
-
vuint64m4_t min_indices_u64m4 = __riscv_vmv_v_x_u64m4(0,
|
|
1144
|
-
vuint64m4_t max_indices_u64m4 = __riscv_vmv_v_x_u64m4(0,
|
|
1170
|
+
nk_size_t max_vector_length = __riscv_vsetvlmax_e16m1();
|
|
1171
|
+
vuint16m1_t min_u16m1 = __riscv_vmv_v_x_u16m1(NK_U16_MAX, max_vector_length);
|
|
1172
|
+
vuint16m1_t max_u16m1 = __riscv_vmv_v_x_u16m1(NK_U16_MIN, max_vector_length);
|
|
1173
|
+
vuint64m4_t min_indices_u64m4 = __riscv_vmv_v_x_u64m4(0, max_vector_length);
|
|
1174
|
+
vuint64m4_t max_indices_u64m4 = __riscv_vmv_v_x_u64m4(0, max_vector_length);
|
|
1145
1175
|
|
|
1146
1176
|
nk_size_t offset = 0;
|
|
1147
1177
|
for (nk_size_t vector_length; count > 0;
|
|
@@ -1164,34 +1194,38 @@ NK_INTERNAL void nk_reduce_minmax_u16_rvv_contiguous_( //
|
|
|
1164
1194
|
|
|
1165
1195
|
// Horizontal reduction for min
|
|
1166
1196
|
vuint16m1_t init_max_u16m1 = __riscv_vmv_v_x_u16m1(NK_U16_MAX, 1);
|
|
1167
|
-
nk_u16_t min_val = __riscv_vmv_x_s_u16m1_u16(
|
|
1168
|
-
|
|
1169
|
-
|
|
1170
|
-
vuint64m4_t
|
|
1197
|
+
nk_u16_t min_val = __riscv_vmv_x_s_u16m1_u16(
|
|
1198
|
+
__riscv_vredminu_vs_u16m1_u16m1(min_u16m1, init_max_u16m1, max_vector_length));
|
|
1199
|
+
vbool16_t min_match_b16 = __riscv_vmseq_vx_u16m1_b16(min_u16m1, min_val, max_vector_length);
|
|
1200
|
+
vuint64m4_t sentinel_u64m4 = __riscv_vmv_v_x_u64m4(NK_U64_MAX, max_vector_length);
|
|
1201
|
+
vuint64m4_t min_cands_u64m4 = __riscv_vmerge_vvm_u64m4(sentinel_u64m4, min_indices_u64m4, min_match_b16,
|
|
1202
|
+
max_vector_length);
|
|
1171
1203
|
vuint64m1_t init_umax_u64m1 = __riscv_vmv_v_x_u64m1(NK_U64_MAX, 1);
|
|
1172
1204
|
*min_value_ptr = min_val;
|
|
1173
1205
|
*min_index_ptr = (nk_size_t)__riscv_vmv_x_s_u64m1_u64(
|
|
1174
|
-
__riscv_vredminu_vs_u64m4_u64m1(min_cands_u64m4, init_umax_u64m1,
|
|
1206
|
+
__riscv_vredminu_vs_u64m4_u64m1(min_cands_u64m4, init_umax_u64m1, max_vector_length));
|
|
1175
1207
|
|
|
1176
1208
|
// Horizontal reduction for max
|
|
1177
1209
|
vuint16m1_t init_min_u16m1 = __riscv_vmv_v_x_u16m1(NK_U16_MIN, 1);
|
|
1178
|
-
nk_u16_t max_val = __riscv_vmv_x_s_u16m1_u16(
|
|
1179
|
-
|
|
1180
|
-
|
|
1210
|
+
nk_u16_t max_val = __riscv_vmv_x_s_u16m1_u16(
|
|
1211
|
+
__riscv_vredmaxu_vs_u16m1_u16m1(max_u16m1, init_min_u16m1, max_vector_length));
|
|
1212
|
+
vbool16_t max_match_b16 = __riscv_vmseq_vx_u16m1_b16(max_u16m1, max_val, max_vector_length);
|
|
1213
|
+
vuint64m4_t max_cands_u64m4 = __riscv_vmerge_vvm_u64m4(sentinel_u64m4, max_indices_u64m4, max_match_b16,
|
|
1214
|
+
max_vector_length);
|
|
1181
1215
|
*max_value_ptr = max_val;
|
|
1182
1216
|
*max_index_ptr = (nk_size_t)__riscv_vmv_x_s_u64m1_u64(
|
|
1183
|
-
__riscv_vredminu_vs_u64m4_u64m1(max_cands_u64m4, init_umax_u64m1,
|
|
1217
|
+
__riscv_vredminu_vs_u64m4_u64m1(max_cands_u64m4, init_umax_u64m1, max_vector_length));
|
|
1184
1218
|
}
|
|
1185
1219
|
|
|
1186
1220
|
NK_INTERNAL void nk_reduce_minmax_u16_rvv_strided_( //
|
|
1187
1221
|
nk_u16_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
|
|
1188
1222
|
nk_u16_t *min_value_ptr, nk_size_t *min_index_ptr, //
|
|
1189
1223
|
nk_u16_t *max_value_ptr, nk_size_t *max_index_ptr) {
|
|
1190
|
-
nk_size_t
|
|
1191
|
-
vuint16m1_t min_u16m1 = __riscv_vmv_v_x_u16m1(NK_U16_MAX,
|
|
1192
|
-
vuint16m1_t max_u16m1 = __riscv_vmv_v_x_u16m1(NK_U16_MIN,
|
|
1193
|
-
vuint64m4_t min_indices_u64m4 = __riscv_vmv_v_x_u64m4(0,
|
|
1194
|
-
vuint64m4_t max_indices_u64m4 = __riscv_vmv_v_x_u64m4(0,
|
|
1224
|
+
nk_size_t max_vector_length = __riscv_vsetvlmax_e16m1();
|
|
1225
|
+
vuint16m1_t min_u16m1 = __riscv_vmv_v_x_u16m1(NK_U16_MAX, max_vector_length);
|
|
1226
|
+
vuint16m1_t max_u16m1 = __riscv_vmv_v_x_u16m1(NK_U16_MIN, max_vector_length);
|
|
1227
|
+
vuint64m4_t min_indices_u64m4 = __riscv_vmv_v_x_u64m4(0, max_vector_length);
|
|
1228
|
+
vuint64m4_t max_indices_u64m4 = __riscv_vmv_v_x_u64m4(0, max_vector_length);
|
|
1195
1229
|
unsigned char const *ptr = (unsigned char const *)data_ptr;
|
|
1196
1230
|
|
|
1197
1231
|
nk_size_t offset = 0;
|
|
@@ -1215,23 +1249,27 @@ NK_INTERNAL void nk_reduce_minmax_u16_rvv_strided_( //
|
|
|
1215
1249
|
|
|
1216
1250
|
// Horizontal reduction for min
|
|
1217
1251
|
vuint16m1_t init_max_u16m1 = __riscv_vmv_v_x_u16m1(NK_U16_MAX, 1);
|
|
1218
|
-
nk_u16_t min_val = __riscv_vmv_x_s_u16m1_u16(
|
|
1219
|
-
|
|
1220
|
-
|
|
1221
|
-
vuint64m4_t
|
|
1252
|
+
nk_u16_t min_val = __riscv_vmv_x_s_u16m1_u16(
|
|
1253
|
+
__riscv_vredminu_vs_u16m1_u16m1(min_u16m1, init_max_u16m1, max_vector_length));
|
|
1254
|
+
vbool16_t min_match_b16 = __riscv_vmseq_vx_u16m1_b16(min_u16m1, min_val, max_vector_length);
|
|
1255
|
+
vuint64m4_t sentinel_u64m4 = __riscv_vmv_v_x_u64m4(NK_U64_MAX, max_vector_length);
|
|
1256
|
+
vuint64m4_t min_cands_u64m4 = __riscv_vmerge_vvm_u64m4(sentinel_u64m4, min_indices_u64m4, min_match_b16,
|
|
1257
|
+
max_vector_length);
|
|
1222
1258
|
vuint64m1_t init_umax_u64m1 = __riscv_vmv_v_x_u64m1(NK_U64_MAX, 1);
|
|
1223
1259
|
*min_value_ptr = min_val;
|
|
1224
1260
|
*min_index_ptr = (nk_size_t)__riscv_vmv_x_s_u64m1_u64(
|
|
1225
|
-
__riscv_vredminu_vs_u64m4_u64m1(min_cands_u64m4, init_umax_u64m1,
|
|
1261
|
+
__riscv_vredminu_vs_u64m4_u64m1(min_cands_u64m4, init_umax_u64m1, max_vector_length));
|
|
1226
1262
|
|
|
1227
1263
|
// Horizontal reduction for max
|
|
1228
1264
|
vuint16m1_t init_min_u16m1 = __riscv_vmv_v_x_u16m1(NK_U16_MIN, 1);
|
|
1229
|
-
nk_u16_t max_val = __riscv_vmv_x_s_u16m1_u16(
|
|
1230
|
-
|
|
1231
|
-
|
|
1265
|
+
nk_u16_t max_val = __riscv_vmv_x_s_u16m1_u16(
|
|
1266
|
+
__riscv_vredmaxu_vs_u16m1_u16m1(max_u16m1, init_min_u16m1, max_vector_length));
|
|
1267
|
+
vbool16_t max_match_b16 = __riscv_vmseq_vx_u16m1_b16(max_u16m1, max_val, max_vector_length);
|
|
1268
|
+
vuint64m4_t max_cands_u64m4 = __riscv_vmerge_vvm_u64m4(sentinel_u64m4, max_indices_u64m4, max_match_b16,
|
|
1269
|
+
max_vector_length);
|
|
1232
1270
|
*max_value_ptr = max_val;
|
|
1233
1271
|
*max_index_ptr = (nk_size_t)__riscv_vmv_x_s_u64m1_u64(
|
|
1234
|
-
__riscv_vredminu_vs_u64m4_u64m1(max_cands_u64m4, init_umax_u64m1,
|
|
1272
|
+
__riscv_vredminu_vs_u64m4_u64m1(max_cands_u64m4, init_umax_u64m1, max_vector_length));
|
|
1235
1273
|
}
|
|
1236
1274
|
|
|
1237
1275
|
NK_PUBLIC void nk_reduce_minmax_u16_rvv( //
|
|
@@ -1258,11 +1296,11 @@ NK_PUBLIC void nk_reduce_minmax_u16_rvv( //
|
|
|
1258
1296
|
NK_INTERNAL void nk_reduce_moments_i32_rvv_contiguous_( //
|
|
1259
1297
|
nk_i32_t const *data_ptr, nk_size_t count, //
|
|
1260
1298
|
nk_i64_t *sum_ptr, nk_u64_t *sumsq_ptr) {
|
|
1261
|
-
nk_size_t
|
|
1262
|
-
// 128-bit per-lane accumulator for sum: (
|
|
1263
|
-
vuint64m2_t
|
|
1264
|
-
vint64m2_t
|
|
1265
|
-
vuint64m2_t sumsq_u64m2 = __riscv_vmv_v_x_u64m2(0,
|
|
1299
|
+
nk_size_t max_vector_length = __riscv_vsetvlmax_e64m2();
|
|
1300
|
+
// 128-bit per-lane accumulator for sum: (sum_high, sum_low)
|
|
1301
|
+
vuint64m2_t sum_low_u64m2 = __riscv_vmv_v_x_u64m2(0, max_vector_length);
|
|
1302
|
+
vint64m2_t sum_high_i64m2 = __riscv_vmv_v_x_i64m2(0, max_vector_length);
|
|
1303
|
+
vuint64m2_t sumsq_u64m2 = __riscv_vmv_v_x_u64m2(0, max_vector_length);
|
|
1266
1304
|
|
|
1267
1305
|
for (nk_size_t vector_length; count > 0; count -= vector_length, data_ptr += vector_length) {
|
|
1268
1306
|
vector_length = __riscv_vsetvl_e32m1(count);
|
|
@@ -1273,18 +1311,18 @@ NK_INTERNAL void nk_reduce_moments_i32_rvv_contiguous_( //
|
|
|
1273
1311
|
vuint64m2_t data_u64m2 = __riscv_vreinterpret_v_i64m2_u64m2(data_i64m2);
|
|
1274
1312
|
|
|
1275
1313
|
// 128-bit accumulation: wrapping add on lower half
|
|
1276
|
-
vuint64m2_t sum_before_u64m2 =
|
|
1277
|
-
|
|
1314
|
+
vuint64m2_t sum_before_u64m2 = sum_low_u64m2;
|
|
1315
|
+
sum_low_u64m2 = __riscv_vadd_vv_u64m2_tu(sum_low_u64m2, sum_low_u64m2, data_u64m2, vector_length);
|
|
1278
1316
|
|
|
1279
1317
|
// Carry: new < old means unsigned overflow occurred
|
|
1280
|
-
vbool32_t carry_b32 = __riscv_vmsltu_vv_u64m2_b32(
|
|
1318
|
+
vbool32_t carry_b32 = __riscv_vmsltu_vv_u64m2_b32(sum_low_u64m2, sum_before_u64m2, vector_length);
|
|
1281
1319
|
vint64m2_t carry_i64m2 = __riscv_vmerge_vxm_i64m2(__riscv_vmv_v_x_i64m2(0, vector_length), 1, carry_b32,
|
|
1282
1320
|
vector_length);
|
|
1283
|
-
|
|
1321
|
+
sum_high_i64m2 = __riscv_vadd_vv_i64m2_tu(sum_high_i64m2, sum_high_i64m2, carry_i64m2, vector_length);
|
|
1284
1322
|
|
|
1285
1323
|
// Sign extension: -1 for negative, 0 for non-negative
|
|
1286
1324
|
vint64m2_t sign_ext_i64m2 = __riscv_vsra_vx_i64m2(data_i64m2, 63, vector_length);
|
|
1287
|
-
|
|
1325
|
+
sum_high_i64m2 = __riscv_vadd_vv_i64m2_tu(sum_high_i64m2, sum_high_i64m2, sign_ext_i64m2, vector_length);
|
|
1288
1326
|
|
|
1289
1327
|
// Sumsq: i32 × i32 → i64 (widening multiply, result ≤ 2^62), saturating accumulation
|
|
1290
1328
|
vint64m2_t squares_i64m2 = __riscv_vwmul_vv_i64m2(data_i32m1, data_i32m1, vector_length);
|
|
@@ -1292,18 +1330,18 @@ NK_INTERNAL void nk_reduce_moments_i32_rvv_contiguous_( //
|
|
|
1292
1330
|
__riscv_vreinterpret_v_i64m2_u64m2(squares_i64m2), vector_length);
|
|
1293
1331
|
}
|
|
1294
1332
|
|
|
1295
|
-
*sum_ptr = nk_reduce_128bit_sum_i64m2_rvv_(
|
|
1296
|
-
*sumsq_ptr = nk_reduce_vsaddu_u64m2_rvv_(sumsq_u64m2,
|
|
1333
|
+
*sum_ptr = nk_reduce_128bit_sum_i64m2_rvv_(sum_low_u64m2, sum_high_i64m2, max_vector_length);
|
|
1334
|
+
*sumsq_ptr = nk_reduce_vsaddu_u64m2_rvv_(sumsq_u64m2, max_vector_length);
|
|
1297
1335
|
}
|
|
1298
1336
|
|
|
1299
1337
|
NK_INTERNAL void nk_reduce_moments_i32_rvv_strided_( //
|
|
1300
1338
|
nk_i32_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
|
|
1301
1339
|
nk_i64_t *sum_ptr, nk_u64_t *sumsq_ptr) {
|
|
1302
|
-
nk_size_t
|
|
1303
|
-
// 128-bit per-lane accumulator for sum: (
|
|
1304
|
-
vuint64m2_t
|
|
1305
|
-
vint64m2_t
|
|
1306
|
-
vuint64m2_t sumsq_u64m2 = __riscv_vmv_v_x_u64m2(0,
|
|
1340
|
+
nk_size_t max_vector_length = __riscv_vsetvlmax_e64m2();
|
|
1341
|
+
// 128-bit per-lane accumulator for sum: (sum_high, sum_low)
|
|
1342
|
+
vuint64m2_t sum_low_u64m2 = __riscv_vmv_v_x_u64m2(0, max_vector_length);
|
|
1343
|
+
vint64m2_t sum_high_i64m2 = __riscv_vmv_v_x_i64m2(0, max_vector_length);
|
|
1344
|
+
vuint64m2_t sumsq_u64m2 = __riscv_vmv_v_x_u64m2(0, max_vector_length);
|
|
1307
1345
|
unsigned char const *ptr = (unsigned char const *)data_ptr;
|
|
1308
1346
|
|
|
1309
1347
|
for (nk_size_t vector_length; count > 0; count -= vector_length, ptr += vector_length * stride_bytes) {
|
|
@@ -1315,18 +1353,18 @@ NK_INTERNAL void nk_reduce_moments_i32_rvv_strided_( //
|
|
|
1315
1353
|
vuint64m2_t data_u64m2 = __riscv_vreinterpret_v_i64m2_u64m2(data_i64m2);
|
|
1316
1354
|
|
|
1317
1355
|
// 128-bit accumulation: wrapping add on lower half
|
|
1318
|
-
vuint64m2_t sum_before_u64m2 =
|
|
1319
|
-
|
|
1356
|
+
vuint64m2_t sum_before_u64m2 = sum_low_u64m2;
|
|
1357
|
+
sum_low_u64m2 = __riscv_vadd_vv_u64m2_tu(sum_low_u64m2, sum_low_u64m2, data_u64m2, vector_length);
|
|
1320
1358
|
|
|
1321
1359
|
// Carry: new < old means unsigned overflow occurred
|
|
1322
|
-
vbool32_t carry_b32 = __riscv_vmsltu_vv_u64m2_b32(
|
|
1360
|
+
vbool32_t carry_b32 = __riscv_vmsltu_vv_u64m2_b32(sum_low_u64m2, sum_before_u64m2, vector_length);
|
|
1323
1361
|
vint64m2_t carry_i64m2 = __riscv_vmerge_vxm_i64m2(__riscv_vmv_v_x_i64m2(0, vector_length), 1, carry_b32,
|
|
1324
1362
|
vector_length);
|
|
1325
|
-
|
|
1363
|
+
sum_high_i64m2 = __riscv_vadd_vv_i64m2_tu(sum_high_i64m2, sum_high_i64m2, carry_i64m2, vector_length);
|
|
1326
1364
|
|
|
1327
1365
|
// Sign extension: -1 for negative, 0 for non-negative
|
|
1328
1366
|
vint64m2_t sign_ext_i64m2 = __riscv_vsra_vx_i64m2(data_i64m2, 63, vector_length);
|
|
1329
|
-
|
|
1367
|
+
sum_high_i64m2 = __riscv_vadd_vv_i64m2_tu(sum_high_i64m2, sum_high_i64m2, sign_ext_i64m2, vector_length);
|
|
1330
1368
|
|
|
1331
1369
|
// Sumsq: i32 × i32 → i64 (widening multiply, result ≤ 2^62), saturating accumulation
|
|
1332
1370
|
vint64m2_t squares_i64m2 = __riscv_vwmul_vv_i64m2(data_i32m1, data_i32m1, vector_length);
|
|
@@ -1334,8 +1372,8 @@ NK_INTERNAL void nk_reduce_moments_i32_rvv_strided_( //
|
|
|
1334
1372
|
__riscv_vreinterpret_v_i64m2_u64m2(squares_i64m2), vector_length);
|
|
1335
1373
|
}
|
|
1336
1374
|
|
|
1337
|
-
*sum_ptr = nk_reduce_128bit_sum_i64m2_rvv_(
|
|
1338
|
-
*sumsq_ptr = nk_reduce_vsaddu_u64m2_rvv_(sumsq_u64m2,
|
|
1375
|
+
*sum_ptr = nk_reduce_128bit_sum_i64m2_rvv_(sum_low_u64m2, sum_high_i64m2, max_vector_length);
|
|
1376
|
+
*sumsq_ptr = nk_reduce_vsaddu_u64m2_rvv_(sumsq_u64m2, max_vector_length);
|
|
1339
1377
|
}
|
|
1340
1378
|
|
|
1341
1379
|
NK_PUBLIC void nk_reduce_moments_i32_rvv( //
|
|
@@ -1354,11 +1392,11 @@ NK_INTERNAL void nk_reduce_minmax_i32_rvv_contiguous_( //
|
|
|
1354
1392
|
nk_i32_t const *data_ptr, nk_size_t count, //
|
|
1355
1393
|
nk_i32_t *min_value_ptr, nk_size_t *min_index_ptr, //
|
|
1356
1394
|
nk_i32_t *max_value_ptr, nk_size_t *max_index_ptr) {
|
|
1357
|
-
nk_size_t
|
|
1358
|
-
vint32m1_t min_i32m1 = __riscv_vmv_v_x_i32m1(NK_I32_MAX,
|
|
1359
|
-
vint32m1_t max_i32m1 = __riscv_vmv_v_x_i32m1(NK_I32_MIN,
|
|
1360
|
-
vuint64m2_t min_indices_u64m2 = __riscv_vmv_v_x_u64m2(0,
|
|
1361
|
-
vuint64m2_t max_indices_u64m2 = __riscv_vmv_v_x_u64m2(0,
|
|
1395
|
+
nk_size_t max_vector_length = __riscv_vsetvlmax_e32m1();
|
|
1396
|
+
vint32m1_t min_i32m1 = __riscv_vmv_v_x_i32m1(NK_I32_MAX, max_vector_length);
|
|
1397
|
+
vint32m1_t max_i32m1 = __riscv_vmv_v_x_i32m1(NK_I32_MIN, max_vector_length);
|
|
1398
|
+
vuint64m2_t min_indices_u64m2 = __riscv_vmv_v_x_u64m2(0, max_vector_length);
|
|
1399
|
+
vuint64m2_t max_indices_u64m2 = __riscv_vmv_v_x_u64m2(0, max_vector_length);
|
|
1362
1400
|
|
|
1363
1401
|
nk_size_t offset = 0;
|
|
1364
1402
|
for (nk_size_t vector_length; count > 0;
|
|
@@ -1381,34 +1419,38 @@ NK_INTERNAL void nk_reduce_minmax_i32_rvv_contiguous_( //
|
|
|
1381
1419
|
|
|
1382
1420
|
// Horizontal reduction for min
|
|
1383
1421
|
vint32m1_t init_max_i32m1 = __riscv_vmv_v_x_i32m1(NK_I32_MAX, 1);
|
|
1384
|
-
nk_i32_t min_val = __riscv_vmv_x_s_i32m1_i32(
|
|
1385
|
-
|
|
1386
|
-
|
|
1387
|
-
vuint64m2_t
|
|
1422
|
+
nk_i32_t min_val = __riscv_vmv_x_s_i32m1_i32(
|
|
1423
|
+
__riscv_vredmin_vs_i32m1_i32m1(min_i32m1, init_max_i32m1, max_vector_length));
|
|
1424
|
+
vbool32_t min_match_b32 = __riscv_vmseq_vx_i32m1_b32(min_i32m1, min_val, max_vector_length);
|
|
1425
|
+
vuint64m2_t sentinel_u64m2 = __riscv_vmv_v_x_u64m2(NK_U64_MAX, max_vector_length);
|
|
1426
|
+
vuint64m2_t min_cands_u64m2 = __riscv_vmerge_vvm_u64m2(sentinel_u64m2, min_indices_u64m2, min_match_b32,
|
|
1427
|
+
max_vector_length);
|
|
1388
1428
|
vuint64m1_t init_umax_u64m1 = __riscv_vmv_v_x_u64m1(NK_U64_MAX, 1);
|
|
1389
1429
|
*min_value_ptr = min_val;
|
|
1390
1430
|
*min_index_ptr = (nk_size_t)__riscv_vmv_x_s_u64m1_u64(
|
|
1391
|
-
__riscv_vredminu_vs_u64m2_u64m1(min_cands_u64m2, init_umax_u64m1,
|
|
1431
|
+
__riscv_vredminu_vs_u64m2_u64m1(min_cands_u64m2, init_umax_u64m1, max_vector_length));
|
|
1392
1432
|
|
|
1393
1433
|
// Horizontal reduction for max
|
|
1394
1434
|
vint32m1_t init_min_i32m1 = __riscv_vmv_v_x_i32m1(NK_I32_MIN, 1);
|
|
1395
|
-
nk_i32_t max_val = __riscv_vmv_x_s_i32m1_i32(
|
|
1396
|
-
|
|
1397
|
-
|
|
1435
|
+
nk_i32_t max_val = __riscv_vmv_x_s_i32m1_i32(
|
|
1436
|
+
__riscv_vredmax_vs_i32m1_i32m1(max_i32m1, init_min_i32m1, max_vector_length));
|
|
1437
|
+
vbool32_t max_match_b32 = __riscv_vmseq_vx_i32m1_b32(max_i32m1, max_val, max_vector_length);
|
|
1438
|
+
vuint64m2_t max_cands_u64m2 = __riscv_vmerge_vvm_u64m2(sentinel_u64m2, max_indices_u64m2, max_match_b32,
|
|
1439
|
+
max_vector_length);
|
|
1398
1440
|
*max_value_ptr = max_val;
|
|
1399
1441
|
*max_index_ptr = (nk_size_t)__riscv_vmv_x_s_u64m1_u64(
|
|
1400
|
-
__riscv_vredminu_vs_u64m2_u64m1(max_cands_u64m2, init_umax_u64m1,
|
|
1442
|
+
__riscv_vredminu_vs_u64m2_u64m1(max_cands_u64m2, init_umax_u64m1, max_vector_length));
|
|
1401
1443
|
}
|
|
1402
1444
|
|
|
1403
1445
|
NK_INTERNAL void nk_reduce_minmax_i32_rvv_strided_( //
|
|
1404
1446
|
nk_i32_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
|
|
1405
1447
|
nk_i32_t *min_value_ptr, nk_size_t *min_index_ptr, //
|
|
1406
1448
|
nk_i32_t *max_value_ptr, nk_size_t *max_index_ptr) {
|
|
1407
|
-
nk_size_t
|
|
1408
|
-
vint32m1_t min_i32m1 = __riscv_vmv_v_x_i32m1(NK_I32_MAX,
|
|
1409
|
-
vint32m1_t max_i32m1 = __riscv_vmv_v_x_i32m1(NK_I32_MIN,
|
|
1410
|
-
vuint64m2_t min_indices_u64m2 = __riscv_vmv_v_x_u64m2(0,
|
|
1411
|
-
vuint64m2_t max_indices_u64m2 = __riscv_vmv_v_x_u64m2(0,
|
|
1449
|
+
nk_size_t max_vector_length = __riscv_vsetvlmax_e32m1();
|
|
1450
|
+
vint32m1_t min_i32m1 = __riscv_vmv_v_x_i32m1(NK_I32_MAX, max_vector_length);
|
|
1451
|
+
vint32m1_t max_i32m1 = __riscv_vmv_v_x_i32m1(NK_I32_MIN, max_vector_length);
|
|
1452
|
+
vuint64m2_t min_indices_u64m2 = __riscv_vmv_v_x_u64m2(0, max_vector_length);
|
|
1453
|
+
vuint64m2_t max_indices_u64m2 = __riscv_vmv_v_x_u64m2(0, max_vector_length);
|
|
1412
1454
|
unsigned char const *ptr = (unsigned char const *)data_ptr;
|
|
1413
1455
|
|
|
1414
1456
|
nk_size_t offset = 0;
|
|
@@ -1432,23 +1474,27 @@ NK_INTERNAL void nk_reduce_minmax_i32_rvv_strided_( //
|
|
|
1432
1474
|
|
|
1433
1475
|
// Horizontal reduction for min
|
|
1434
1476
|
vint32m1_t init_max_i32m1 = __riscv_vmv_v_x_i32m1(NK_I32_MAX, 1);
|
|
1435
|
-
nk_i32_t min_val = __riscv_vmv_x_s_i32m1_i32(
|
|
1436
|
-
|
|
1437
|
-
|
|
1438
|
-
vuint64m2_t
|
|
1477
|
+
nk_i32_t min_val = __riscv_vmv_x_s_i32m1_i32(
|
|
1478
|
+
__riscv_vredmin_vs_i32m1_i32m1(min_i32m1, init_max_i32m1, max_vector_length));
|
|
1479
|
+
vbool32_t min_match_b32 = __riscv_vmseq_vx_i32m1_b32(min_i32m1, min_val, max_vector_length);
|
|
1480
|
+
vuint64m2_t sentinel_u64m2 = __riscv_vmv_v_x_u64m2(NK_U64_MAX, max_vector_length);
|
|
1481
|
+
vuint64m2_t min_cands_u64m2 = __riscv_vmerge_vvm_u64m2(sentinel_u64m2, min_indices_u64m2, min_match_b32,
|
|
1482
|
+
max_vector_length);
|
|
1439
1483
|
vuint64m1_t init_umax_u64m1 = __riscv_vmv_v_x_u64m1(NK_U64_MAX, 1);
|
|
1440
1484
|
*min_value_ptr = min_val;
|
|
1441
1485
|
*min_index_ptr = (nk_size_t)__riscv_vmv_x_s_u64m1_u64(
|
|
1442
|
-
__riscv_vredminu_vs_u64m2_u64m1(min_cands_u64m2, init_umax_u64m1,
|
|
1486
|
+
__riscv_vredminu_vs_u64m2_u64m1(min_cands_u64m2, init_umax_u64m1, max_vector_length));
|
|
1443
1487
|
|
|
1444
1488
|
// Horizontal reduction for max
|
|
1445
1489
|
vint32m1_t init_min_i32m1 = __riscv_vmv_v_x_i32m1(NK_I32_MIN, 1);
|
|
1446
|
-
nk_i32_t max_val = __riscv_vmv_x_s_i32m1_i32(
|
|
1447
|
-
|
|
1448
|
-
|
|
1490
|
+
nk_i32_t max_val = __riscv_vmv_x_s_i32m1_i32(
|
|
1491
|
+
__riscv_vredmax_vs_i32m1_i32m1(max_i32m1, init_min_i32m1, max_vector_length));
|
|
1492
|
+
vbool32_t max_match_b32 = __riscv_vmseq_vx_i32m1_b32(max_i32m1, max_val, max_vector_length);
|
|
1493
|
+
vuint64m2_t max_cands_u64m2 = __riscv_vmerge_vvm_u64m2(sentinel_u64m2, max_indices_u64m2, max_match_b32,
|
|
1494
|
+
max_vector_length);
|
|
1449
1495
|
*max_value_ptr = max_val;
|
|
1450
1496
|
*max_index_ptr = (nk_size_t)__riscv_vmv_x_s_u64m1_u64(
|
|
1451
|
-
__riscv_vredminu_vs_u64m2_u64m1(max_cands_u64m2, init_umax_u64m1,
|
|
1497
|
+
__riscv_vredminu_vs_u64m2_u64m1(max_cands_u64m2, init_umax_u64m1, max_vector_length));
|
|
1452
1498
|
}
|
|
1453
1499
|
|
|
1454
1500
|
NK_PUBLIC void nk_reduce_minmax_i32_rvv( //
|
|
@@ -1475,9 +1521,9 @@ NK_PUBLIC void nk_reduce_minmax_i32_rvv( //
|
|
|
1475
1521
|
NK_INTERNAL void nk_reduce_moments_u32_rvv_contiguous_( //
|
|
1476
1522
|
nk_u32_t const *data_ptr, nk_size_t count, //
|
|
1477
1523
|
nk_u64_t *sum_ptr, nk_u64_t *sumsq_ptr) {
|
|
1478
|
-
nk_size_t
|
|
1479
|
-
vuint64m2_t sum_u64m2 = __riscv_vmv_v_x_u64m2(0,
|
|
1480
|
-
vuint64m2_t sumsq_u64m2 = __riscv_vmv_v_x_u64m2(0,
|
|
1524
|
+
nk_size_t max_vector_length = __riscv_vsetvlmax_e64m2();
|
|
1525
|
+
vuint64m2_t sum_u64m2 = __riscv_vmv_v_x_u64m2(0, max_vector_length);
|
|
1526
|
+
vuint64m2_t sumsq_u64m2 = __riscv_vmv_v_x_u64m2(0, max_vector_length);
|
|
1481
1527
|
|
|
1482
1528
|
for (nk_size_t vector_length; count > 0; count -= vector_length, data_ptr += vector_length) {
|
|
1483
1529
|
vector_length = __riscv_vsetvl_e32m1(count);
|
|
@@ -1492,16 +1538,16 @@ NK_INTERNAL void nk_reduce_moments_u32_rvv_contiguous_( //
|
|
|
1492
1538
|
sumsq_u64m2 = __riscv_vsaddu_vv_u64m2_tu(sumsq_u64m2, sumsq_u64m2, squares_u64m2, vector_length);
|
|
1493
1539
|
}
|
|
1494
1540
|
|
|
1495
|
-
*sum_ptr = nk_reduce_vsaddu_u64m2_rvv_(sum_u64m2,
|
|
1496
|
-
*sumsq_ptr = nk_reduce_vsaddu_u64m2_rvv_(sumsq_u64m2,
|
|
1541
|
+
*sum_ptr = nk_reduce_vsaddu_u64m2_rvv_(sum_u64m2, max_vector_length);
|
|
1542
|
+
*sumsq_ptr = nk_reduce_vsaddu_u64m2_rvv_(sumsq_u64m2, max_vector_length);
|
|
1497
1543
|
}
|
|
1498
1544
|
|
|
1499
1545
|
NK_INTERNAL void nk_reduce_moments_u32_rvv_strided_( //
|
|
1500
1546
|
nk_u32_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
|
|
1501
1547
|
nk_u64_t *sum_ptr, nk_u64_t *sumsq_ptr) {
|
|
1502
|
-
nk_size_t
|
|
1503
|
-
vuint64m2_t sum_u64m2 = __riscv_vmv_v_x_u64m2(0,
|
|
1504
|
-
vuint64m2_t sumsq_u64m2 = __riscv_vmv_v_x_u64m2(0,
|
|
1548
|
+
nk_size_t max_vector_length = __riscv_vsetvlmax_e64m2();
|
|
1549
|
+
vuint64m2_t sum_u64m2 = __riscv_vmv_v_x_u64m2(0, max_vector_length);
|
|
1550
|
+
vuint64m2_t sumsq_u64m2 = __riscv_vmv_v_x_u64m2(0, max_vector_length);
|
|
1505
1551
|
unsigned char const *ptr = (unsigned char const *)data_ptr;
|
|
1506
1552
|
|
|
1507
1553
|
for (nk_size_t vector_length; count > 0; count -= vector_length, ptr += vector_length * stride_bytes) {
|
|
@@ -1517,8 +1563,8 @@ NK_INTERNAL void nk_reduce_moments_u32_rvv_strided_( //
|
|
|
1517
1563
|
sumsq_u64m2 = __riscv_vsaddu_vv_u64m2_tu(sumsq_u64m2, sumsq_u64m2, squares_u64m2, vector_length);
|
|
1518
1564
|
}
|
|
1519
1565
|
|
|
1520
|
-
*sum_ptr = nk_reduce_vsaddu_u64m2_rvv_(sum_u64m2,
|
|
1521
|
-
*sumsq_ptr = nk_reduce_vsaddu_u64m2_rvv_(sumsq_u64m2,
|
|
1566
|
+
*sum_ptr = nk_reduce_vsaddu_u64m2_rvv_(sum_u64m2, max_vector_length);
|
|
1567
|
+
*sumsq_ptr = nk_reduce_vsaddu_u64m2_rvv_(sumsq_u64m2, max_vector_length);
|
|
1522
1568
|
}
|
|
1523
1569
|
|
|
1524
1570
|
NK_PUBLIC void nk_reduce_moments_u32_rvv( //
|
|
@@ -1537,11 +1583,11 @@ NK_INTERNAL void nk_reduce_minmax_u32_rvv_contiguous_( //
|
|
|
1537
1583
|
nk_u32_t const *data_ptr, nk_size_t count, //
|
|
1538
1584
|
nk_u32_t *min_value_ptr, nk_size_t *min_index_ptr, //
|
|
1539
1585
|
nk_u32_t *max_value_ptr, nk_size_t *max_index_ptr) {
|
|
1540
|
-
nk_size_t
|
|
1541
|
-
vuint32m1_t min_u32m1 = __riscv_vmv_v_x_u32m1(NK_U32_MAX,
|
|
1542
|
-
vuint32m1_t max_u32m1 = __riscv_vmv_v_x_u32m1(NK_U32_MIN,
|
|
1543
|
-
vuint64m2_t min_indices_u64m2 = __riscv_vmv_v_x_u64m2(0,
|
|
1544
|
-
vuint64m2_t max_indices_u64m2 = __riscv_vmv_v_x_u64m2(0,
|
|
1586
|
+
nk_size_t max_vector_length = __riscv_vsetvlmax_e32m1();
|
|
1587
|
+
vuint32m1_t min_u32m1 = __riscv_vmv_v_x_u32m1(NK_U32_MAX, max_vector_length);
|
|
1588
|
+
vuint32m1_t max_u32m1 = __riscv_vmv_v_x_u32m1(NK_U32_MIN, max_vector_length);
|
|
1589
|
+
vuint64m2_t min_indices_u64m2 = __riscv_vmv_v_x_u64m2(0, max_vector_length);
|
|
1590
|
+
vuint64m2_t max_indices_u64m2 = __riscv_vmv_v_x_u64m2(0, max_vector_length);
|
|
1545
1591
|
|
|
1546
1592
|
nk_size_t offset = 0;
|
|
1547
1593
|
for (nk_size_t vector_length; count > 0;
|
|
@@ -1564,34 +1610,38 @@ NK_INTERNAL void nk_reduce_minmax_u32_rvv_contiguous_( //
|
|
|
1564
1610
|
|
|
1565
1611
|
// Horizontal reduction for min
|
|
1566
1612
|
vuint32m1_t init_max_u32m1 = __riscv_vmv_v_x_u32m1(NK_U32_MAX, 1);
|
|
1567
|
-
nk_u32_t min_val = __riscv_vmv_x_s_u32m1_u32(
|
|
1568
|
-
|
|
1569
|
-
|
|
1570
|
-
vuint64m2_t
|
|
1613
|
+
nk_u32_t min_val = __riscv_vmv_x_s_u32m1_u32(
|
|
1614
|
+
__riscv_vredminu_vs_u32m1_u32m1(min_u32m1, init_max_u32m1, max_vector_length));
|
|
1615
|
+
vbool32_t min_match_b32 = __riscv_vmseq_vx_u32m1_b32(min_u32m1, min_val, max_vector_length);
|
|
1616
|
+
vuint64m2_t sentinel_u64m2 = __riscv_vmv_v_x_u64m2(NK_U64_MAX, max_vector_length);
|
|
1617
|
+
vuint64m2_t min_cands_u64m2 = __riscv_vmerge_vvm_u64m2(sentinel_u64m2, min_indices_u64m2, min_match_b32,
|
|
1618
|
+
max_vector_length);
|
|
1571
1619
|
vuint64m1_t init_umax_u64m1 = __riscv_vmv_v_x_u64m1(NK_U64_MAX, 1);
|
|
1572
1620
|
*min_value_ptr = min_val;
|
|
1573
1621
|
*min_index_ptr = (nk_size_t)__riscv_vmv_x_s_u64m1_u64(
|
|
1574
|
-
__riscv_vredminu_vs_u64m2_u64m1(min_cands_u64m2, init_umax_u64m1,
|
|
1622
|
+
__riscv_vredminu_vs_u64m2_u64m1(min_cands_u64m2, init_umax_u64m1, max_vector_length));
|
|
1575
1623
|
|
|
1576
1624
|
// Horizontal reduction for max
|
|
1577
1625
|
vuint32m1_t init_min_u32m1 = __riscv_vmv_v_x_u32m1(NK_U32_MIN, 1);
|
|
1578
|
-
nk_u32_t max_val = __riscv_vmv_x_s_u32m1_u32(
|
|
1579
|
-
|
|
1580
|
-
|
|
1626
|
+
nk_u32_t max_val = __riscv_vmv_x_s_u32m1_u32(
|
|
1627
|
+
__riscv_vredmaxu_vs_u32m1_u32m1(max_u32m1, init_min_u32m1, max_vector_length));
|
|
1628
|
+
vbool32_t max_match_b32 = __riscv_vmseq_vx_u32m1_b32(max_u32m1, max_val, max_vector_length);
|
|
1629
|
+
vuint64m2_t max_cands_u64m2 = __riscv_vmerge_vvm_u64m2(sentinel_u64m2, max_indices_u64m2, max_match_b32,
|
|
1630
|
+
max_vector_length);
|
|
1581
1631
|
*max_value_ptr = max_val;
|
|
1582
1632
|
*max_index_ptr = (nk_size_t)__riscv_vmv_x_s_u64m1_u64(
|
|
1583
|
-
__riscv_vredminu_vs_u64m2_u64m1(max_cands_u64m2, init_umax_u64m1,
|
|
1633
|
+
__riscv_vredminu_vs_u64m2_u64m1(max_cands_u64m2, init_umax_u64m1, max_vector_length));
|
|
1584
1634
|
}
|
|
1585
1635
|
|
|
1586
1636
|
NK_INTERNAL void nk_reduce_minmax_u32_rvv_strided_( //
|
|
1587
1637
|
nk_u32_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
|
|
1588
1638
|
nk_u32_t *min_value_ptr, nk_size_t *min_index_ptr, //
|
|
1589
1639
|
nk_u32_t *max_value_ptr, nk_size_t *max_index_ptr) {
|
|
1590
|
-
nk_size_t
|
|
1591
|
-
vuint32m1_t min_u32m1 = __riscv_vmv_v_x_u32m1(NK_U32_MAX,
|
|
1592
|
-
vuint32m1_t max_u32m1 = __riscv_vmv_v_x_u32m1(NK_U32_MIN,
|
|
1593
|
-
vuint64m2_t min_indices_u64m2 = __riscv_vmv_v_x_u64m2(0,
|
|
1594
|
-
vuint64m2_t max_indices_u64m2 = __riscv_vmv_v_x_u64m2(0,
|
|
1640
|
+
nk_size_t max_vector_length = __riscv_vsetvlmax_e32m1();
|
|
1641
|
+
vuint32m1_t min_u32m1 = __riscv_vmv_v_x_u32m1(NK_U32_MAX, max_vector_length);
|
|
1642
|
+
vuint32m1_t max_u32m1 = __riscv_vmv_v_x_u32m1(NK_U32_MIN, max_vector_length);
|
|
1643
|
+
vuint64m2_t min_indices_u64m2 = __riscv_vmv_v_x_u64m2(0, max_vector_length);
|
|
1644
|
+
vuint64m2_t max_indices_u64m2 = __riscv_vmv_v_x_u64m2(0, max_vector_length);
|
|
1595
1645
|
unsigned char const *ptr = (unsigned char const *)data_ptr;
|
|
1596
1646
|
|
|
1597
1647
|
nk_size_t offset = 0;
|
|
@@ -1615,23 +1665,27 @@ NK_INTERNAL void nk_reduce_minmax_u32_rvv_strided_( //
|
|
|
1615
1665
|
|
|
1616
1666
|
// Horizontal reduction for min
|
|
1617
1667
|
vuint32m1_t init_max_u32m1 = __riscv_vmv_v_x_u32m1(NK_U32_MAX, 1);
|
|
1618
|
-
nk_u32_t min_val = __riscv_vmv_x_s_u32m1_u32(
|
|
1619
|
-
|
|
1620
|
-
|
|
1621
|
-
vuint64m2_t
|
|
1668
|
+
nk_u32_t min_val = __riscv_vmv_x_s_u32m1_u32(
|
|
1669
|
+
__riscv_vredminu_vs_u32m1_u32m1(min_u32m1, init_max_u32m1, max_vector_length));
|
|
1670
|
+
vbool32_t min_match_b32 = __riscv_vmseq_vx_u32m1_b32(min_u32m1, min_val, max_vector_length);
|
|
1671
|
+
vuint64m2_t sentinel_u64m2 = __riscv_vmv_v_x_u64m2(NK_U64_MAX, max_vector_length);
|
|
1672
|
+
vuint64m2_t min_cands_u64m2 = __riscv_vmerge_vvm_u64m2(sentinel_u64m2, min_indices_u64m2, min_match_b32,
|
|
1673
|
+
max_vector_length);
|
|
1622
1674
|
vuint64m1_t init_umax_u64m1 = __riscv_vmv_v_x_u64m1(NK_U64_MAX, 1);
|
|
1623
1675
|
*min_value_ptr = min_val;
|
|
1624
1676
|
*min_index_ptr = (nk_size_t)__riscv_vmv_x_s_u64m1_u64(
|
|
1625
|
-
__riscv_vredminu_vs_u64m2_u64m1(min_cands_u64m2, init_umax_u64m1,
|
|
1677
|
+
__riscv_vredminu_vs_u64m2_u64m1(min_cands_u64m2, init_umax_u64m1, max_vector_length));
|
|
1626
1678
|
|
|
1627
1679
|
// Horizontal reduction for max
|
|
1628
1680
|
vuint32m1_t init_min_u32m1 = __riscv_vmv_v_x_u32m1(NK_U32_MIN, 1);
|
|
1629
|
-
nk_u32_t max_val = __riscv_vmv_x_s_u32m1_u32(
|
|
1630
|
-
|
|
1631
|
-
|
|
1681
|
+
nk_u32_t max_val = __riscv_vmv_x_s_u32m1_u32(
|
|
1682
|
+
__riscv_vredmaxu_vs_u32m1_u32m1(max_u32m1, init_min_u32m1, max_vector_length));
|
|
1683
|
+
vbool32_t max_match_b32 = __riscv_vmseq_vx_u32m1_b32(max_u32m1, max_val, max_vector_length);
|
|
1684
|
+
vuint64m2_t max_cands_u64m2 = __riscv_vmerge_vvm_u64m2(sentinel_u64m2, max_indices_u64m2, max_match_b32,
|
|
1685
|
+
max_vector_length);
|
|
1632
1686
|
*max_value_ptr = max_val;
|
|
1633
1687
|
*max_index_ptr = (nk_size_t)__riscv_vmv_x_s_u64m1_u64(
|
|
1634
|
-
__riscv_vredminu_vs_u64m2_u64m1(max_cands_u64m2, init_umax_u64m1,
|
|
1688
|
+
__riscv_vredminu_vs_u64m2_u64m1(max_cands_u64m2, init_umax_u64m1, max_vector_length));
|
|
1635
1689
|
}
|
|
1636
1690
|
|
|
1637
1691
|
NK_PUBLIC void nk_reduce_minmax_u32_rvv( //
|
|
@@ -1658,11 +1712,11 @@ NK_PUBLIC void nk_reduce_minmax_u32_rvv( //
|
|
|
1658
1712
|
NK_INTERNAL void nk_reduce_moments_i64_rvv_contiguous_( //
|
|
1659
1713
|
nk_i64_t const *data_ptr, nk_size_t count, //
|
|
1660
1714
|
nk_i64_t *sum_ptr, nk_u64_t *sumsq_ptr) {
|
|
1661
|
-
nk_size_t
|
|
1662
|
-
// 128-bit per-lane accumulator for sum: (
|
|
1663
|
-
vuint64m1_t
|
|
1664
|
-
vint64m1_t
|
|
1665
|
-
vuint64m1_t sumsq_u64m1 = __riscv_vmv_v_x_u64m1(0,
|
|
1715
|
+
nk_size_t max_vector_length = __riscv_vsetvlmax_e64m1();
|
|
1716
|
+
// 128-bit per-lane accumulator for sum: (sum_high, sum_low)
|
|
1717
|
+
vuint64m1_t sum_low_u64m1 = __riscv_vmv_v_x_u64m1(0, max_vector_length);
|
|
1718
|
+
vint64m1_t sum_high_i64m1 = __riscv_vmv_v_x_i64m1(0, max_vector_length);
|
|
1719
|
+
vuint64m1_t sumsq_u64m1 = __riscv_vmv_v_x_u64m1(0, max_vector_length);
|
|
1666
1720
|
|
|
1667
1721
|
for (nk_size_t vector_length; count > 0; count -= vector_length, data_ptr += vector_length) {
|
|
1668
1722
|
vector_length = __riscv_vsetvl_e64m1(count);
|
|
@@ -1670,18 +1724,18 @@ NK_INTERNAL void nk_reduce_moments_i64_rvv_contiguous_( //
|
|
|
1670
1724
|
|
|
1671
1725
|
// 128-bit sum accumulation: wrapping add on lower half
|
|
1672
1726
|
vuint64m1_t data_u64m1 = __riscv_vreinterpret_v_i64m1_u64m1(data_i64m1);
|
|
1673
|
-
vuint64m1_t sum_before_u64m1 =
|
|
1674
|
-
|
|
1727
|
+
vuint64m1_t sum_before_u64m1 = sum_low_u64m1;
|
|
1728
|
+
sum_low_u64m1 = __riscv_vadd_vv_u64m1_tu(sum_low_u64m1, sum_low_u64m1, data_u64m1, vector_length);
|
|
1675
1729
|
|
|
1676
1730
|
// Carry: new < old means unsigned overflow occurred
|
|
1677
|
-
vbool64_t carry_b64 = __riscv_vmsltu_vv_u64m1_b64(
|
|
1731
|
+
vbool64_t carry_b64 = __riscv_vmsltu_vv_u64m1_b64(sum_low_u64m1, sum_before_u64m1, vector_length);
|
|
1678
1732
|
vint64m1_t carry_i64m1 = __riscv_vmerge_vxm_i64m1(__riscv_vmv_v_x_i64m1(0, vector_length), 1, carry_b64,
|
|
1679
1733
|
vector_length);
|
|
1680
|
-
|
|
1734
|
+
sum_high_i64m1 = __riscv_vadd_vv_i64m1_tu(sum_high_i64m1, sum_high_i64m1, carry_i64m1, vector_length);
|
|
1681
1735
|
|
|
1682
1736
|
// Sign extension: -1 for negative, 0 for non-negative
|
|
1683
1737
|
vint64m1_t sign_ext_i64m1 = __riscv_vsra_vx_i64m1(data_i64m1, 63, vector_length);
|
|
1684
|
-
|
|
1738
|
+
sum_high_i64m1 = __riscv_vadd_vv_i64m1_tu(sum_high_i64m1, sum_high_i64m1, sign_ext_i64m1, vector_length);
|
|
1685
1739
|
|
|
1686
1740
|
// Sumsq: abs(val)² with overflow detection
|
|
1687
1741
|
vint64m1_t negated_i64m1 = __riscv_vneg_v_i64m1(data_i64m1, vector_length);
|
|
@@ -1695,18 +1749,18 @@ NK_INTERNAL void nk_reduce_moments_i64_rvv_contiguous_( //
|
|
|
1695
1749
|
sumsq_u64m1 = __riscv_vsaddu_vv_u64m1_tu(sumsq_u64m1, sumsq_u64m1, squares_u64m1, vector_length);
|
|
1696
1750
|
}
|
|
1697
1751
|
|
|
1698
|
-
*sum_ptr = nk_reduce_128bit_sum_i64m1_rvv_(
|
|
1699
|
-
*sumsq_ptr = nk_reduce_vsaddu_u64m1_rvv_(sumsq_u64m1,
|
|
1752
|
+
*sum_ptr = nk_reduce_128bit_sum_i64m1_rvv_(sum_low_u64m1, sum_high_i64m1, max_vector_length);
|
|
1753
|
+
*sumsq_ptr = nk_reduce_vsaddu_u64m1_rvv_(sumsq_u64m1, max_vector_length);
|
|
1700
1754
|
}
|
|
1701
1755
|
|
|
1702
1756
|
NK_INTERNAL void nk_reduce_moments_i64_rvv_strided_( //
|
|
1703
1757
|
nk_i64_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
|
|
1704
1758
|
nk_i64_t *sum_ptr, nk_u64_t *sumsq_ptr) {
|
|
1705
|
-
nk_size_t
|
|
1706
|
-
// 128-bit per-lane accumulator for sum: (
|
|
1707
|
-
vuint64m1_t
|
|
1708
|
-
vint64m1_t
|
|
1709
|
-
vuint64m1_t sumsq_u64m1 = __riscv_vmv_v_x_u64m1(0,
|
|
1759
|
+
nk_size_t max_vector_length = __riscv_vsetvlmax_e64m1();
|
|
1760
|
+
// 128-bit per-lane accumulator for sum: (sum_high, sum_low)
|
|
1761
|
+
vuint64m1_t sum_low_u64m1 = __riscv_vmv_v_x_u64m1(0, max_vector_length);
|
|
1762
|
+
vint64m1_t sum_high_i64m1 = __riscv_vmv_v_x_i64m1(0, max_vector_length);
|
|
1763
|
+
vuint64m1_t sumsq_u64m1 = __riscv_vmv_v_x_u64m1(0, max_vector_length);
|
|
1710
1764
|
unsigned char const *ptr = (unsigned char const *)data_ptr;
|
|
1711
1765
|
|
|
1712
1766
|
for (nk_size_t vector_length; count > 0; count -= vector_length, ptr += vector_length * stride_bytes) {
|
|
@@ -1715,18 +1769,18 @@ NK_INTERNAL void nk_reduce_moments_i64_rvv_strided_( //
|
|
|
1715
1769
|
|
|
1716
1770
|
// 128-bit sum accumulation: wrapping add on lower half
|
|
1717
1771
|
vuint64m1_t data_u64m1 = __riscv_vreinterpret_v_i64m1_u64m1(data_i64m1);
|
|
1718
|
-
vuint64m1_t sum_before_u64m1 =
|
|
1719
|
-
|
|
1772
|
+
vuint64m1_t sum_before_u64m1 = sum_low_u64m1;
|
|
1773
|
+
sum_low_u64m1 = __riscv_vadd_vv_u64m1_tu(sum_low_u64m1, sum_low_u64m1, data_u64m1, vector_length);
|
|
1720
1774
|
|
|
1721
1775
|
// Carry: new < old means unsigned overflow occurred
|
|
1722
|
-
vbool64_t carry_b64 = __riscv_vmsltu_vv_u64m1_b64(
|
|
1776
|
+
vbool64_t carry_b64 = __riscv_vmsltu_vv_u64m1_b64(sum_low_u64m1, sum_before_u64m1, vector_length);
|
|
1723
1777
|
vint64m1_t carry_i64m1 = __riscv_vmerge_vxm_i64m1(__riscv_vmv_v_x_i64m1(0, vector_length), 1, carry_b64,
|
|
1724
1778
|
vector_length);
|
|
1725
|
-
|
|
1779
|
+
sum_high_i64m1 = __riscv_vadd_vv_i64m1_tu(sum_high_i64m1, sum_high_i64m1, carry_i64m1, vector_length);
|
|
1726
1780
|
|
|
1727
1781
|
// Sign extension: -1 for negative, 0 for non-negative
|
|
1728
1782
|
vint64m1_t sign_ext_i64m1 = __riscv_vsra_vx_i64m1(data_i64m1, 63, vector_length);
|
|
1729
|
-
|
|
1783
|
+
sum_high_i64m1 = __riscv_vadd_vv_i64m1_tu(sum_high_i64m1, sum_high_i64m1, sign_ext_i64m1, vector_length);
|
|
1730
1784
|
|
|
1731
1785
|
// Sumsq: abs(val)² with overflow detection
|
|
1732
1786
|
vint64m1_t negated_i64m1 = __riscv_vneg_v_i64m1(data_i64m1, vector_length);
|
|
@@ -1740,8 +1794,8 @@ NK_INTERNAL void nk_reduce_moments_i64_rvv_strided_( //
|
|
|
1740
1794
|
sumsq_u64m1 = __riscv_vsaddu_vv_u64m1_tu(sumsq_u64m1, sumsq_u64m1, squares_u64m1, vector_length);
|
|
1741
1795
|
}
|
|
1742
1796
|
|
|
1743
|
-
*sum_ptr = nk_reduce_128bit_sum_i64m1_rvv_(
|
|
1744
|
-
*sumsq_ptr = nk_reduce_vsaddu_u64m1_rvv_(sumsq_u64m1,
|
|
1797
|
+
*sum_ptr = nk_reduce_128bit_sum_i64m1_rvv_(sum_low_u64m1, sum_high_i64m1, max_vector_length);
|
|
1798
|
+
*sumsq_ptr = nk_reduce_vsaddu_u64m1_rvv_(sumsq_u64m1, max_vector_length);
|
|
1745
1799
|
}
|
|
1746
1800
|
|
|
1747
1801
|
NK_PUBLIC void nk_reduce_moments_i64_rvv( //
|
|
@@ -1760,11 +1814,11 @@ NK_INTERNAL void nk_reduce_minmax_i64_rvv_contiguous_( //
|
|
|
1760
1814
|
nk_i64_t const *data_ptr, nk_size_t count, //
|
|
1761
1815
|
nk_i64_t *min_value_ptr, nk_size_t *min_index_ptr, //
|
|
1762
1816
|
nk_i64_t *max_value_ptr, nk_size_t *max_index_ptr) {
|
|
1763
|
-
nk_size_t
|
|
1764
|
-
vint64m1_t min_i64m1 = __riscv_vmv_v_x_i64m1(NK_I64_MAX,
|
|
1765
|
-
vint64m1_t max_i64m1 = __riscv_vmv_v_x_i64m1(NK_I64_MIN,
|
|
1766
|
-
vuint64m1_t min_indices_u64m1 = __riscv_vmv_v_x_u64m1(0,
|
|
1767
|
-
vuint64m1_t max_indices_u64m1 = __riscv_vmv_v_x_u64m1(0,
|
|
1817
|
+
nk_size_t max_vector_length = __riscv_vsetvlmax_e64m1();
|
|
1818
|
+
vint64m1_t min_i64m1 = __riscv_vmv_v_x_i64m1(NK_I64_MAX, max_vector_length);
|
|
1819
|
+
vint64m1_t max_i64m1 = __riscv_vmv_v_x_i64m1(NK_I64_MIN, max_vector_length);
|
|
1820
|
+
vuint64m1_t min_indices_u64m1 = __riscv_vmv_v_x_u64m1(0, max_vector_length);
|
|
1821
|
+
vuint64m1_t max_indices_u64m1 = __riscv_vmv_v_x_u64m1(0, max_vector_length);
|
|
1768
1822
|
|
|
1769
1823
|
nk_size_t offset = 0;
|
|
1770
1824
|
for (nk_size_t vector_length; count > 0;
|
|
@@ -1787,34 +1841,38 @@ NK_INTERNAL void nk_reduce_minmax_i64_rvv_contiguous_( //
|
|
|
1787
1841
|
|
|
1788
1842
|
// Horizontal reduction for min
|
|
1789
1843
|
vint64m1_t init_max_i64m1 = __riscv_vmv_v_x_i64m1(NK_I64_MAX, 1);
|
|
1790
|
-
nk_i64_t min_val = __riscv_vmv_x_s_i64m1_i64(
|
|
1791
|
-
|
|
1792
|
-
|
|
1793
|
-
vuint64m1_t
|
|
1844
|
+
nk_i64_t min_val = __riscv_vmv_x_s_i64m1_i64(
|
|
1845
|
+
__riscv_vredmin_vs_i64m1_i64m1(min_i64m1, init_max_i64m1, max_vector_length));
|
|
1846
|
+
vbool64_t min_match_b64 = __riscv_vmseq_vx_i64m1_b64(min_i64m1, min_val, max_vector_length);
|
|
1847
|
+
vuint64m1_t sentinel_u64m1 = __riscv_vmv_v_x_u64m1(NK_U64_MAX, max_vector_length);
|
|
1848
|
+
vuint64m1_t min_cands_u64m1 = __riscv_vmerge_vvm_u64m1(sentinel_u64m1, min_indices_u64m1, min_match_b64,
|
|
1849
|
+
max_vector_length);
|
|
1794
1850
|
vuint64m1_t init_umax_u64m1 = __riscv_vmv_v_x_u64m1(NK_U64_MAX, 1);
|
|
1795
1851
|
*min_value_ptr = min_val;
|
|
1796
1852
|
*min_index_ptr = (nk_size_t)__riscv_vmv_x_s_u64m1_u64(
|
|
1797
|
-
__riscv_vredminu_vs_u64m1_u64m1(min_cands_u64m1, init_umax_u64m1,
|
|
1853
|
+
__riscv_vredminu_vs_u64m1_u64m1(min_cands_u64m1, init_umax_u64m1, max_vector_length));
|
|
1798
1854
|
|
|
1799
1855
|
// Horizontal reduction for max
|
|
1800
1856
|
vint64m1_t init_min_i64m1 = __riscv_vmv_v_x_i64m1(NK_I64_MIN, 1);
|
|
1801
|
-
nk_i64_t max_val = __riscv_vmv_x_s_i64m1_i64(
|
|
1802
|
-
|
|
1803
|
-
|
|
1857
|
+
nk_i64_t max_val = __riscv_vmv_x_s_i64m1_i64(
|
|
1858
|
+
__riscv_vredmax_vs_i64m1_i64m1(max_i64m1, init_min_i64m1, max_vector_length));
|
|
1859
|
+
vbool64_t max_match_b64 = __riscv_vmseq_vx_i64m1_b64(max_i64m1, max_val, max_vector_length);
|
|
1860
|
+
vuint64m1_t max_cands_u64m1 = __riscv_vmerge_vvm_u64m1(sentinel_u64m1, max_indices_u64m1, max_match_b64,
|
|
1861
|
+
max_vector_length);
|
|
1804
1862
|
*max_value_ptr = max_val;
|
|
1805
1863
|
*max_index_ptr = (nk_size_t)__riscv_vmv_x_s_u64m1_u64(
|
|
1806
|
-
__riscv_vredminu_vs_u64m1_u64m1(max_cands_u64m1, init_umax_u64m1,
|
|
1864
|
+
__riscv_vredminu_vs_u64m1_u64m1(max_cands_u64m1, init_umax_u64m1, max_vector_length));
|
|
1807
1865
|
}
|
|
1808
1866
|
|
|
1809
1867
|
NK_INTERNAL void nk_reduce_minmax_i64_rvv_strided_( //
|
|
1810
1868
|
nk_i64_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
|
|
1811
1869
|
nk_i64_t *min_value_ptr, nk_size_t *min_index_ptr, //
|
|
1812
1870
|
nk_i64_t *max_value_ptr, nk_size_t *max_index_ptr) {
|
|
1813
|
-
nk_size_t
|
|
1814
|
-
vint64m1_t min_i64m1 = __riscv_vmv_v_x_i64m1(NK_I64_MAX,
|
|
1815
|
-
vint64m1_t max_i64m1 = __riscv_vmv_v_x_i64m1(NK_I64_MIN,
|
|
1816
|
-
vuint64m1_t min_indices_u64m1 = __riscv_vmv_v_x_u64m1(0,
|
|
1817
|
-
vuint64m1_t max_indices_u64m1 = __riscv_vmv_v_x_u64m1(0,
|
|
1871
|
+
nk_size_t max_vector_length = __riscv_vsetvlmax_e64m1();
|
|
1872
|
+
vint64m1_t min_i64m1 = __riscv_vmv_v_x_i64m1(NK_I64_MAX, max_vector_length);
|
|
1873
|
+
vint64m1_t max_i64m1 = __riscv_vmv_v_x_i64m1(NK_I64_MIN, max_vector_length);
|
|
1874
|
+
vuint64m1_t min_indices_u64m1 = __riscv_vmv_v_x_u64m1(0, max_vector_length);
|
|
1875
|
+
vuint64m1_t max_indices_u64m1 = __riscv_vmv_v_x_u64m1(0, max_vector_length);
|
|
1818
1876
|
unsigned char const *ptr = (unsigned char const *)data_ptr;
|
|
1819
1877
|
|
|
1820
1878
|
nk_size_t offset = 0;
|
|
@@ -1838,23 +1896,27 @@ NK_INTERNAL void nk_reduce_minmax_i64_rvv_strided_( //
|
|
|
1838
1896
|
|
|
1839
1897
|
// Horizontal reduction for min
|
|
1840
1898
|
vint64m1_t init_max_i64m1 = __riscv_vmv_v_x_i64m1(NK_I64_MAX, 1);
|
|
1841
|
-
nk_i64_t min_val = __riscv_vmv_x_s_i64m1_i64(
|
|
1842
|
-
|
|
1843
|
-
|
|
1844
|
-
vuint64m1_t
|
|
1899
|
+
nk_i64_t min_val = __riscv_vmv_x_s_i64m1_i64(
|
|
1900
|
+
__riscv_vredmin_vs_i64m1_i64m1(min_i64m1, init_max_i64m1, max_vector_length));
|
|
1901
|
+
vbool64_t min_match_b64 = __riscv_vmseq_vx_i64m1_b64(min_i64m1, min_val, max_vector_length);
|
|
1902
|
+
vuint64m1_t sentinel_u64m1 = __riscv_vmv_v_x_u64m1(NK_U64_MAX, max_vector_length);
|
|
1903
|
+
vuint64m1_t min_cands_u64m1 = __riscv_vmerge_vvm_u64m1(sentinel_u64m1, min_indices_u64m1, min_match_b64,
|
|
1904
|
+
max_vector_length);
|
|
1845
1905
|
vuint64m1_t init_umax_u64m1 = __riscv_vmv_v_x_u64m1(NK_U64_MAX, 1);
|
|
1846
1906
|
*min_value_ptr = min_val;
|
|
1847
1907
|
*min_index_ptr = (nk_size_t)__riscv_vmv_x_s_u64m1_u64(
|
|
1848
|
-
__riscv_vredminu_vs_u64m1_u64m1(min_cands_u64m1, init_umax_u64m1,
|
|
1908
|
+
__riscv_vredminu_vs_u64m1_u64m1(min_cands_u64m1, init_umax_u64m1, max_vector_length));
|
|
1849
1909
|
|
|
1850
1910
|
// Horizontal reduction for max
|
|
1851
1911
|
vint64m1_t init_min_i64m1 = __riscv_vmv_v_x_i64m1(NK_I64_MIN, 1);
|
|
1852
|
-
nk_i64_t max_val = __riscv_vmv_x_s_i64m1_i64(
|
|
1853
|
-
|
|
1854
|
-
|
|
1912
|
+
nk_i64_t max_val = __riscv_vmv_x_s_i64m1_i64(
|
|
1913
|
+
__riscv_vredmax_vs_i64m1_i64m1(max_i64m1, init_min_i64m1, max_vector_length));
|
|
1914
|
+
vbool64_t max_match_b64 = __riscv_vmseq_vx_i64m1_b64(max_i64m1, max_val, max_vector_length);
|
|
1915
|
+
vuint64m1_t max_cands_u64m1 = __riscv_vmerge_vvm_u64m1(sentinel_u64m1, max_indices_u64m1, max_match_b64,
|
|
1916
|
+
max_vector_length);
|
|
1855
1917
|
*max_value_ptr = max_val;
|
|
1856
1918
|
*max_index_ptr = (nk_size_t)__riscv_vmv_x_s_u64m1_u64(
|
|
1857
|
-
__riscv_vredminu_vs_u64m1_u64m1(max_cands_u64m1, init_umax_u64m1,
|
|
1919
|
+
__riscv_vredminu_vs_u64m1_u64m1(max_cands_u64m1, init_umax_u64m1, max_vector_length));
|
|
1858
1920
|
}
|
|
1859
1921
|
|
|
1860
1922
|
NK_PUBLIC void nk_reduce_minmax_i64_rvv( //
|
|
@@ -1881,9 +1943,9 @@ NK_PUBLIC void nk_reduce_minmax_i64_rvv( //
|
|
|
1881
1943
|
NK_INTERNAL void nk_reduce_moments_u64_rvv_contiguous_( //
|
|
1882
1944
|
nk_u64_t const *data_ptr, nk_size_t count, //
|
|
1883
1945
|
nk_u64_t *sum_ptr, nk_u64_t *sumsq_ptr) {
|
|
1884
|
-
nk_size_t
|
|
1885
|
-
vuint64m1_t sum_u64m1 = __riscv_vmv_v_x_u64m1(0,
|
|
1886
|
-
vuint64m1_t sumsq_u64m1 = __riscv_vmv_v_x_u64m1(0,
|
|
1946
|
+
nk_size_t max_vector_length = __riscv_vsetvlmax_e64m1();
|
|
1947
|
+
vuint64m1_t sum_u64m1 = __riscv_vmv_v_x_u64m1(0, max_vector_length);
|
|
1948
|
+
vuint64m1_t sumsq_u64m1 = __riscv_vmv_v_x_u64m1(0, max_vector_length);
|
|
1887
1949
|
|
|
1888
1950
|
for (nk_size_t vector_length; count > 0; count -= vector_length, data_ptr += vector_length) {
|
|
1889
1951
|
vector_length = __riscv_vsetvl_e64m1(count);
|
|
@@ -1901,16 +1963,16 @@ NK_INTERNAL void nk_reduce_moments_u64_rvv_contiguous_( //
|
|
|
1901
1963
|
sumsq_u64m1 = __riscv_vsaddu_vv_u64m1_tu(sumsq_u64m1, sumsq_u64m1, squares_u64m1, vector_length);
|
|
1902
1964
|
}
|
|
1903
1965
|
|
|
1904
|
-
*sum_ptr = nk_reduce_vsaddu_u64m1_rvv_(sum_u64m1,
|
|
1905
|
-
*sumsq_ptr = nk_reduce_vsaddu_u64m1_rvv_(sumsq_u64m1,
|
|
1966
|
+
*sum_ptr = nk_reduce_vsaddu_u64m1_rvv_(sum_u64m1, max_vector_length);
|
|
1967
|
+
*sumsq_ptr = nk_reduce_vsaddu_u64m1_rvv_(sumsq_u64m1, max_vector_length);
|
|
1906
1968
|
}
|
|
1907
1969
|
|
|
1908
1970
|
NK_INTERNAL void nk_reduce_moments_u64_rvv_strided_( //
|
|
1909
1971
|
nk_u64_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
|
|
1910
1972
|
nk_u64_t *sum_ptr, nk_u64_t *sumsq_ptr) {
|
|
1911
|
-
nk_size_t
|
|
1912
|
-
vuint64m1_t sum_u64m1 = __riscv_vmv_v_x_u64m1(0,
|
|
1913
|
-
vuint64m1_t sumsq_u64m1 = __riscv_vmv_v_x_u64m1(0,
|
|
1973
|
+
nk_size_t max_vector_length = __riscv_vsetvlmax_e64m1();
|
|
1974
|
+
vuint64m1_t sum_u64m1 = __riscv_vmv_v_x_u64m1(0, max_vector_length);
|
|
1975
|
+
vuint64m1_t sumsq_u64m1 = __riscv_vmv_v_x_u64m1(0, max_vector_length);
|
|
1914
1976
|
unsigned char const *ptr = (unsigned char const *)data_ptr;
|
|
1915
1977
|
|
|
1916
1978
|
for (nk_size_t vector_length; count > 0; count -= vector_length, ptr += vector_length * stride_bytes) {
|
|
@@ -1929,8 +1991,8 @@ NK_INTERNAL void nk_reduce_moments_u64_rvv_strided_( //
|
|
|
1929
1991
|
sumsq_u64m1 = __riscv_vsaddu_vv_u64m1_tu(sumsq_u64m1, sumsq_u64m1, squares_u64m1, vector_length);
|
|
1930
1992
|
}
|
|
1931
1993
|
|
|
1932
|
-
*sum_ptr = nk_reduce_vsaddu_u64m1_rvv_(sum_u64m1,
|
|
1933
|
-
*sumsq_ptr = nk_reduce_vsaddu_u64m1_rvv_(sumsq_u64m1,
|
|
1994
|
+
*sum_ptr = nk_reduce_vsaddu_u64m1_rvv_(sum_u64m1, max_vector_length);
|
|
1995
|
+
*sumsq_ptr = nk_reduce_vsaddu_u64m1_rvv_(sumsq_u64m1, max_vector_length);
|
|
1934
1996
|
}
|
|
1935
1997
|
|
|
1936
1998
|
NK_PUBLIC void nk_reduce_moments_u64_rvv( //
|
|
@@ -1949,11 +2011,11 @@ NK_INTERNAL void nk_reduce_minmax_u64_rvv_contiguous_( //
|
|
|
1949
2011
|
nk_u64_t const *data_ptr, nk_size_t count, //
|
|
1950
2012
|
nk_u64_t *min_value_ptr, nk_size_t *min_index_ptr, //
|
|
1951
2013
|
nk_u64_t *max_value_ptr, nk_size_t *max_index_ptr) {
|
|
1952
|
-
nk_size_t
|
|
1953
|
-
vuint64m1_t min_u64m1 = __riscv_vmv_v_x_u64m1(NK_U64_MAX,
|
|
1954
|
-
vuint64m1_t max_u64m1 = __riscv_vmv_v_x_u64m1(NK_U64_MIN,
|
|
1955
|
-
vuint64m1_t min_indices_u64m1 = __riscv_vmv_v_x_u64m1(0,
|
|
1956
|
-
vuint64m1_t max_indices_u64m1 = __riscv_vmv_v_x_u64m1(0,
|
|
2014
|
+
nk_size_t max_vector_length = __riscv_vsetvlmax_e64m1();
|
|
2015
|
+
vuint64m1_t min_u64m1 = __riscv_vmv_v_x_u64m1(NK_U64_MAX, max_vector_length);
|
|
2016
|
+
vuint64m1_t max_u64m1 = __riscv_vmv_v_x_u64m1(NK_U64_MIN, max_vector_length);
|
|
2017
|
+
vuint64m1_t min_indices_u64m1 = __riscv_vmv_v_x_u64m1(0, max_vector_length);
|
|
2018
|
+
vuint64m1_t max_indices_u64m1 = __riscv_vmv_v_x_u64m1(0, max_vector_length);
|
|
1957
2019
|
|
|
1958
2020
|
nk_size_t offset = 0;
|
|
1959
2021
|
for (nk_size_t vector_length; count > 0;
|
|
@@ -1976,34 +2038,38 @@ NK_INTERNAL void nk_reduce_minmax_u64_rvv_contiguous_( //
|
|
|
1976
2038
|
|
|
1977
2039
|
// Horizontal reduction for min
|
|
1978
2040
|
vuint64m1_t init_max_u64m1 = __riscv_vmv_v_x_u64m1(NK_U64_MAX, 1);
|
|
1979
|
-
nk_u64_t min_val = __riscv_vmv_x_s_u64m1_u64(
|
|
1980
|
-
|
|
1981
|
-
|
|
1982
|
-
vuint64m1_t
|
|
2041
|
+
nk_u64_t min_val = __riscv_vmv_x_s_u64m1_u64(
|
|
2042
|
+
__riscv_vredminu_vs_u64m1_u64m1(min_u64m1, init_max_u64m1, max_vector_length));
|
|
2043
|
+
vbool64_t min_match_b64 = __riscv_vmseq_vx_u64m1_b64(min_u64m1, min_val, max_vector_length);
|
|
2044
|
+
vuint64m1_t sentinel_u64m1 = __riscv_vmv_v_x_u64m1(NK_U64_MAX, max_vector_length);
|
|
2045
|
+
vuint64m1_t min_cands_u64m1 = __riscv_vmerge_vvm_u64m1(sentinel_u64m1, min_indices_u64m1, min_match_b64,
|
|
2046
|
+
max_vector_length);
|
|
1983
2047
|
vuint64m1_t init_umax_u64m1 = __riscv_vmv_v_x_u64m1(NK_U64_MAX, 1);
|
|
1984
2048
|
*min_value_ptr = min_val;
|
|
1985
2049
|
*min_index_ptr = (nk_size_t)__riscv_vmv_x_s_u64m1_u64(
|
|
1986
|
-
__riscv_vredminu_vs_u64m1_u64m1(min_cands_u64m1, init_umax_u64m1,
|
|
2050
|
+
__riscv_vredminu_vs_u64m1_u64m1(min_cands_u64m1, init_umax_u64m1, max_vector_length));
|
|
1987
2051
|
|
|
1988
2052
|
// Horizontal reduction for max
|
|
1989
2053
|
vuint64m1_t init_min_u64m1 = __riscv_vmv_v_x_u64m1(NK_U64_MIN, 1);
|
|
1990
|
-
nk_u64_t max_val = __riscv_vmv_x_s_u64m1_u64(
|
|
1991
|
-
|
|
1992
|
-
|
|
2054
|
+
nk_u64_t max_val = __riscv_vmv_x_s_u64m1_u64(
|
|
2055
|
+
__riscv_vredmaxu_vs_u64m1_u64m1(max_u64m1, init_min_u64m1, max_vector_length));
|
|
2056
|
+
vbool64_t max_match_b64 = __riscv_vmseq_vx_u64m1_b64(max_u64m1, max_val, max_vector_length);
|
|
2057
|
+
vuint64m1_t max_cands_u64m1 = __riscv_vmerge_vvm_u64m1(sentinel_u64m1, max_indices_u64m1, max_match_b64,
|
|
2058
|
+
max_vector_length);
|
|
1993
2059
|
*max_value_ptr = max_val;
|
|
1994
2060
|
*max_index_ptr = (nk_size_t)__riscv_vmv_x_s_u64m1_u64(
|
|
1995
|
-
__riscv_vredminu_vs_u64m1_u64m1(max_cands_u64m1, init_umax_u64m1,
|
|
2061
|
+
__riscv_vredminu_vs_u64m1_u64m1(max_cands_u64m1, init_umax_u64m1, max_vector_length));
|
|
1996
2062
|
}
|
|
1997
2063
|
|
|
1998
2064
|
NK_INTERNAL void nk_reduce_minmax_u64_rvv_strided_( //
|
|
1999
2065
|
nk_u64_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
|
|
2000
2066
|
nk_u64_t *min_value_ptr, nk_size_t *min_index_ptr, //
|
|
2001
2067
|
nk_u64_t *max_value_ptr, nk_size_t *max_index_ptr) {
|
|
2002
|
-
nk_size_t
|
|
2003
|
-
vuint64m1_t min_u64m1 = __riscv_vmv_v_x_u64m1(NK_U64_MAX,
|
|
2004
|
-
vuint64m1_t max_u64m1 = __riscv_vmv_v_x_u64m1(NK_U64_MIN,
|
|
2005
|
-
vuint64m1_t min_indices_u64m1 = __riscv_vmv_v_x_u64m1(0,
|
|
2006
|
-
vuint64m1_t max_indices_u64m1 = __riscv_vmv_v_x_u64m1(0,
|
|
2068
|
+
nk_size_t max_vector_length = __riscv_vsetvlmax_e64m1();
|
|
2069
|
+
vuint64m1_t min_u64m1 = __riscv_vmv_v_x_u64m1(NK_U64_MAX, max_vector_length);
|
|
2070
|
+
vuint64m1_t max_u64m1 = __riscv_vmv_v_x_u64m1(NK_U64_MIN, max_vector_length);
|
|
2071
|
+
vuint64m1_t min_indices_u64m1 = __riscv_vmv_v_x_u64m1(0, max_vector_length);
|
|
2072
|
+
vuint64m1_t max_indices_u64m1 = __riscv_vmv_v_x_u64m1(0, max_vector_length);
|
|
2007
2073
|
unsigned char const *ptr = (unsigned char const *)data_ptr;
|
|
2008
2074
|
|
|
2009
2075
|
nk_size_t offset = 0;
|
|
@@ -2027,23 +2093,27 @@ NK_INTERNAL void nk_reduce_minmax_u64_rvv_strided_( //
|
|
|
2027
2093
|
|
|
2028
2094
|
// Horizontal reduction for min
|
|
2029
2095
|
vuint64m1_t init_max_u64m1 = __riscv_vmv_v_x_u64m1(NK_U64_MAX, 1);
|
|
2030
|
-
nk_u64_t min_val = __riscv_vmv_x_s_u64m1_u64(
|
|
2031
|
-
|
|
2032
|
-
|
|
2033
|
-
vuint64m1_t
|
|
2096
|
+
nk_u64_t min_val = __riscv_vmv_x_s_u64m1_u64(
|
|
2097
|
+
__riscv_vredminu_vs_u64m1_u64m1(min_u64m1, init_max_u64m1, max_vector_length));
|
|
2098
|
+
vbool64_t min_match_b64 = __riscv_vmseq_vx_u64m1_b64(min_u64m1, min_val, max_vector_length);
|
|
2099
|
+
vuint64m1_t sentinel_u64m1 = __riscv_vmv_v_x_u64m1(NK_U64_MAX, max_vector_length);
|
|
2100
|
+
vuint64m1_t min_cands_u64m1 = __riscv_vmerge_vvm_u64m1(sentinel_u64m1, min_indices_u64m1, min_match_b64,
|
|
2101
|
+
max_vector_length);
|
|
2034
2102
|
vuint64m1_t init_umax_u64m1 = __riscv_vmv_v_x_u64m1(NK_U64_MAX, 1);
|
|
2035
2103
|
*min_value_ptr = min_val;
|
|
2036
2104
|
*min_index_ptr = (nk_size_t)__riscv_vmv_x_s_u64m1_u64(
|
|
2037
|
-
__riscv_vredminu_vs_u64m1_u64m1(min_cands_u64m1, init_umax_u64m1,
|
|
2105
|
+
__riscv_vredminu_vs_u64m1_u64m1(min_cands_u64m1, init_umax_u64m1, max_vector_length));
|
|
2038
2106
|
|
|
2039
2107
|
// Horizontal reduction for max
|
|
2040
2108
|
vuint64m1_t init_min_u64m1 = __riscv_vmv_v_x_u64m1(NK_U64_MIN, 1);
|
|
2041
|
-
nk_u64_t max_val = __riscv_vmv_x_s_u64m1_u64(
|
|
2042
|
-
|
|
2043
|
-
|
|
2109
|
+
nk_u64_t max_val = __riscv_vmv_x_s_u64m1_u64(
|
|
2110
|
+
__riscv_vredmaxu_vs_u64m1_u64m1(max_u64m1, init_min_u64m1, max_vector_length));
|
|
2111
|
+
vbool64_t max_match_b64 = __riscv_vmseq_vx_u64m1_b64(max_u64m1, max_val, max_vector_length);
|
|
2112
|
+
vuint64m1_t max_cands_u64m1 = __riscv_vmerge_vvm_u64m1(sentinel_u64m1, max_indices_u64m1, max_match_b64,
|
|
2113
|
+
max_vector_length);
|
|
2044
2114
|
*max_value_ptr = max_val;
|
|
2045
2115
|
*max_index_ptr = (nk_size_t)__riscv_vmv_x_s_u64m1_u64(
|
|
2046
|
-
__riscv_vredminu_vs_u64m1_u64m1(max_cands_u64m1, init_umax_u64m1,
|
|
2116
|
+
__riscv_vredminu_vs_u64m1_u64m1(max_cands_u64m1, init_umax_u64m1, max_vector_length));
|
|
2047
2117
|
}
|
|
2048
2118
|
|
|
2049
2119
|
NK_PUBLIC void nk_reduce_minmax_u64_rvv( //
|
|
@@ -2070,13 +2140,13 @@ NK_PUBLIC void nk_reduce_minmax_u64_rvv( //
|
|
|
2070
2140
|
NK_INTERNAL void nk_reduce_moments_bf16_rvv_contiguous_( //
|
|
2071
2141
|
nk_bf16_t const *data_ptr, nk_size_t count, //
|
|
2072
2142
|
nk_f32_t *sum_ptr, nk_f32_t *sumsq_ptr) {
|
|
2073
|
-
nk_size_t
|
|
2074
|
-
vfloat64m4_t sum_f64m4 = __riscv_vfmv_v_f_f64m4(0.0,
|
|
2075
|
-
vfloat64m4_t sumsq_f64m4 = __riscv_vfmv_v_f_f64m4(0.0,
|
|
2143
|
+
nk_size_t max_vector_length = __riscv_vsetvlmax_e64m4();
|
|
2144
|
+
vfloat64m4_t sum_f64m4 = __riscv_vfmv_v_f_f64m4(0.0, max_vector_length);
|
|
2145
|
+
vfloat64m4_t sumsq_f64m4 = __riscv_vfmv_v_f_f64m4(0.0, max_vector_length);
|
|
2076
2146
|
|
|
2077
2147
|
for (nk_size_t vector_length; count > 0; count -= vector_length, data_ptr += vector_length) {
|
|
2078
2148
|
vector_length = __riscv_vsetvl_e16m1(count);
|
|
2079
|
-
vuint16m1_t data_u16m1 = __riscv_vle16_v_u16m1((
|
|
2149
|
+
vuint16m1_t data_u16m1 = __riscv_vle16_v_u16m1((nk_u16_t const *)data_ptr, vector_length);
|
|
2080
2150
|
|
|
2081
2151
|
// Convert bf16 → f32 (m1 → m2)
|
|
2082
2152
|
vfloat32m2_t data_f32m2 = nk_bf16m1_to_f32m2_rvv_(data_u16m1, vector_length);
|
|
@@ -2091,21 +2161,22 @@ NK_INTERNAL void nk_reduce_moments_bf16_rvv_contiguous_( //
|
|
|
2091
2161
|
|
|
2092
2162
|
// Horizontal reduction
|
|
2093
2163
|
vfloat64m1_t zero_f64m1 = __riscv_vfmv_v_f_f64m1(0.0, 1);
|
|
2094
|
-
*sum_ptr = __riscv_vfmv_f_s_f64m1_f64(__riscv_vfredusum_vs_f64m4_f64m1(sum_f64m4, zero_f64m1,
|
|
2095
|
-
*sumsq_ptr = __riscv_vfmv_f_s_f64m1_f64(
|
|
2164
|
+
*sum_ptr = __riscv_vfmv_f_s_f64m1_f64(__riscv_vfredusum_vs_f64m4_f64m1(sum_f64m4, zero_f64m1, max_vector_length)),
|
|
2165
|
+
*sumsq_ptr = __riscv_vfmv_f_s_f64m1_f64(
|
|
2166
|
+
__riscv_vfredusum_vs_f64m4_f64m1(sumsq_f64m4, zero_f64m1, max_vector_length));
|
|
2096
2167
|
}
|
|
2097
2168
|
|
|
2098
2169
|
NK_INTERNAL void nk_reduce_moments_bf16_rvv_strided_( //
|
|
2099
2170
|
nk_bf16_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
|
|
2100
2171
|
nk_f32_t *sum_ptr, nk_f32_t *sumsq_ptr) {
|
|
2101
|
-
nk_size_t
|
|
2102
|
-
vfloat64m4_t sum_f64m4 = __riscv_vfmv_v_f_f64m4(0.0,
|
|
2103
|
-
vfloat64m4_t sumsq_f64m4 = __riscv_vfmv_v_f_f64m4(0.0,
|
|
2172
|
+
nk_size_t max_vector_length = __riscv_vsetvlmax_e64m4();
|
|
2173
|
+
vfloat64m4_t sum_f64m4 = __riscv_vfmv_v_f_f64m4(0.0, max_vector_length);
|
|
2174
|
+
vfloat64m4_t sumsq_f64m4 = __riscv_vfmv_v_f_f64m4(0.0, max_vector_length);
|
|
2104
2175
|
unsigned char const *ptr = (unsigned char const *)data_ptr;
|
|
2105
2176
|
|
|
2106
2177
|
for (nk_size_t vector_length; count > 0; count -= vector_length, ptr += vector_length * stride_bytes) {
|
|
2107
2178
|
vector_length = __riscv_vsetvl_e16m1(count);
|
|
2108
|
-
vuint16m1_t data_u16m1 = __riscv_vlse16_v_u16m1((
|
|
2179
|
+
vuint16m1_t data_u16m1 = __riscv_vlse16_v_u16m1((nk_u16_t const *)ptr, (nk_ssize_t)stride_bytes, vector_length);
|
|
2109
2180
|
|
|
2110
2181
|
// Convert bf16 → f32 (m1 → m2)
|
|
2111
2182
|
vfloat32m2_t data_f32m2 = nk_bf16m1_to_f32m2_rvv_(data_u16m1, vector_length);
|
|
@@ -2120,8 +2191,9 @@ NK_INTERNAL void nk_reduce_moments_bf16_rvv_strided_( //
|
|
|
2120
2191
|
|
|
2121
2192
|
// Horizontal reduction
|
|
2122
2193
|
vfloat64m1_t zero_f64m1 = __riscv_vfmv_v_f_f64m1(0.0, 1);
|
|
2123
|
-
*sum_ptr = __riscv_vfmv_f_s_f64m1_f64(__riscv_vfredusum_vs_f64m4_f64m1(sum_f64m4, zero_f64m1,
|
|
2124
|
-
*sumsq_ptr = __riscv_vfmv_f_s_f64m1_f64(
|
|
2194
|
+
*sum_ptr = __riscv_vfmv_f_s_f64m1_f64(__riscv_vfredusum_vs_f64m4_f64m1(sum_f64m4, zero_f64m1, max_vector_length)),
|
|
2195
|
+
*sumsq_ptr = __riscv_vfmv_f_s_f64m1_f64(
|
|
2196
|
+
__riscv_vfredusum_vs_f64m4_f64m1(sumsq_f64m4, zero_f64m1, max_vector_length));
|
|
2125
2197
|
}
|
|
2126
2198
|
|
|
2127
2199
|
NK_PUBLIC void nk_reduce_moments_bf16_rvv( //
|
|
@@ -2140,17 +2212,17 @@ NK_INTERNAL void nk_reduce_minmax_bf16_rvv_contiguous_( //
|
|
|
2140
2212
|
nk_bf16_t const *data_ptr, nk_size_t count, //
|
|
2141
2213
|
nk_bf16_t *min_value_ptr, nk_size_t *min_index_ptr, //
|
|
2142
2214
|
nk_bf16_t *max_value_ptr, nk_size_t *max_index_ptr) {
|
|
2143
|
-
nk_size_t
|
|
2144
|
-
vuint16m1_t min_u16m1 = __riscv_vmv_v_x_u16m1(0x7F80,
|
|
2145
|
-
vuint16m1_t max_u16m1 = __riscv_vmv_v_x_u16m1(0xFF80,
|
|
2146
|
-
vuint64m4_t min_indices_u64m4 = __riscv_vmv_v_x_u64m4(0,
|
|
2147
|
-
vuint64m4_t max_indices_u64m4 = __riscv_vmv_v_x_u64m4(0,
|
|
2215
|
+
nk_size_t max_vector_length = __riscv_vsetvlmax_e16m1();
|
|
2216
|
+
vuint16m1_t min_u16m1 = __riscv_vmv_v_x_u16m1(0x7F80, max_vector_length); // +inf in bf16
|
|
2217
|
+
vuint16m1_t max_u16m1 = __riscv_vmv_v_x_u16m1(0xFF80, max_vector_length); // -inf in bf16
|
|
2218
|
+
vuint64m4_t min_indices_u64m4 = __riscv_vmv_v_x_u64m4(0, max_vector_length);
|
|
2219
|
+
vuint64m4_t max_indices_u64m4 = __riscv_vmv_v_x_u64m4(0, max_vector_length);
|
|
2148
2220
|
|
|
2149
2221
|
nk_size_t offset = 0;
|
|
2150
2222
|
for (nk_size_t vector_length; count > 0;
|
|
2151
2223
|
count -= vector_length, offset += vector_length, data_ptr += vector_length) {
|
|
2152
2224
|
vector_length = __riscv_vsetvl_e16m1(count);
|
|
2153
|
-
vuint16m1_t data_u16m1 = __riscv_vle16_v_u16m1((
|
|
2225
|
+
vuint16m1_t data_u16m1 = __riscv_vle16_v_u16m1((nk_u16_t const *)data_ptr, vector_length);
|
|
2154
2226
|
vuint64m4_t pos_u64m4 = __riscv_vadd_vx_u64m4(__riscv_vid_v_u64m4(vector_length), (nk_u64_t)offset,
|
|
2155
2227
|
vector_length);
|
|
2156
2228
|
|
|
@@ -2171,58 +2243,61 @@ NK_INTERNAL void nk_reduce_minmax_bf16_rvv_contiguous_( //
|
|
|
2171
2243
|
}
|
|
2172
2244
|
|
|
2173
2245
|
// Horizontal reduction
|
|
2174
|
-
vfloat32m2_t final_min_f32m2 = nk_bf16m1_to_f32m2_rvv_(min_u16m1,
|
|
2246
|
+
vfloat32m2_t final_min_f32m2 = nk_bf16m1_to_f32m2_rvv_(min_u16m1, max_vector_length);
|
|
2175
2247
|
vfloat32m1_t init_max_f32m1 = __riscv_vfmv_v_f_f32m1(NK_F32_MAX, 1);
|
|
2176
2248
|
nk_f32_t min_val_f32 = __riscv_vfmv_f_s_f32m1_f32(
|
|
2177
|
-
__riscv_vfredmin_vs_f32m2_f32m1(final_min_f32m2, init_max_f32m1,
|
|
2178
|
-
vfloat32m2_t final_max_f32m2 = nk_bf16m1_to_f32m2_rvv_(max_u16m1,
|
|
2249
|
+
__riscv_vfredmin_vs_f32m2_f32m1(final_min_f32m2, init_max_f32m1, max_vector_length));
|
|
2250
|
+
vfloat32m2_t final_max_f32m2 = nk_bf16m1_to_f32m2_rvv_(max_u16m1, max_vector_length);
|
|
2179
2251
|
vfloat32m1_t init_min_f32m1 = __riscv_vfmv_v_f_f32m1(NK_F32_MIN, 1);
|
|
2180
2252
|
nk_f32_t max_val_f32 = __riscv_vfmv_f_s_f32m1_f32(
|
|
2181
|
-
__riscv_vfredmax_vs_f32m2_f32m1(final_max_f32m2, init_min_f32m1,
|
|
2253
|
+
__riscv_vfredmax_vs_f32m2_f32m1(final_max_f32m2, init_min_f32m1, max_vector_length));
|
|
2182
2254
|
if (min_val_f32 == NK_F32_MAX && max_val_f32 == NK_F32_MIN) {
|
|
2183
2255
|
*min_value_ptr = NK_BF16_MAX, *min_index_ptr = NK_SIZE_MAX, *max_value_ptr = NK_BF16_MIN,
|
|
2184
2256
|
*max_index_ptr = NK_SIZE_MAX;
|
|
2185
2257
|
return;
|
|
2186
2258
|
}
|
|
2187
2259
|
|
|
2188
|
-
vfloat32m2_t converted_f32m2 = nk_bf16m1_to_f32m2_rvv_(min_u16m1,
|
|
2189
|
-
vbool16_t min_match_b16 = __riscv_vmfeq_vf_f32m2_b16(converted_f32m2, min_val_f32,
|
|
2190
|
-
vuint64m4_t sentinel_u64m4 = __riscv_vmv_v_x_u64m4(NK_U64_MAX,
|
|
2191
|
-
vuint64m4_t min_cands_u64m4 = __riscv_vmerge_vvm_u64m4(sentinel_u64m4, min_indices_u64m4, min_match_b16,
|
|
2260
|
+
vfloat32m2_t converted_f32m2 = nk_bf16m1_to_f32m2_rvv_(min_u16m1, max_vector_length);
|
|
2261
|
+
vbool16_t min_match_b16 = __riscv_vmfeq_vf_f32m2_b16(converted_f32m2, min_val_f32, max_vector_length);
|
|
2262
|
+
vuint64m4_t sentinel_u64m4 = __riscv_vmv_v_x_u64m4(NK_U64_MAX, max_vector_length);
|
|
2263
|
+
vuint64m4_t min_cands_u64m4 = __riscv_vmerge_vvm_u64m4(sentinel_u64m4, min_indices_u64m4, min_match_b16,
|
|
2264
|
+
max_vector_length);
|
|
2192
2265
|
vuint64m1_t init_umax_u64m1 = __riscv_vmv_v_x_u64m1(NK_U64_MAX, 1);
|
|
2193
2266
|
|
|
2194
|
-
nk_u16_t min_raw = __riscv_vmv_x_s_u16m1_u16(
|
|
2195
|
-
|
|
2267
|
+
nk_u16_t min_raw = __riscv_vmv_x_s_u16m1_u16(__riscv_vslidedown_vx_u16m1(
|
|
2268
|
+
min_u16m1, (nk_size_t)__riscv_vfirst_m_b16(min_match_b16, max_vector_length), max_vector_length));
|
|
2196
2269
|
*min_value_ptr = *(nk_bf16_t *)&min_raw;
|
|
2197
2270
|
*min_index_ptr = (nk_size_t)__riscv_vmv_x_s_u64m1_u64(
|
|
2198
|
-
__riscv_vredminu_vs_u64m4_u64m1(min_cands_u64m4, init_umax_u64m1,
|
|
2271
|
+
__riscv_vredminu_vs_u64m4_u64m1(min_cands_u64m4, init_umax_u64m1, max_vector_length));
|
|
2199
2272
|
|
|
2200
|
-
vbool16_t max_match_b16 = __riscv_vmfeq_vf_f32m2_b16(nk_bf16m1_to_f32m2_rvv_(max_u16m1,
|
|
2201
|
-
|
|
2273
|
+
vbool16_t max_match_b16 = __riscv_vmfeq_vf_f32m2_b16(nk_bf16m1_to_f32m2_rvv_(max_u16m1, max_vector_length),
|
|
2274
|
+
max_val_f32, max_vector_length);
|
|
2275
|
+
vuint64m4_t max_cands_u64m4 = __riscv_vmerge_vvm_u64m4(sentinel_u64m4, max_indices_u64m4, max_match_b16,
|
|
2276
|
+
max_vector_length);
|
|
2202
2277
|
|
|
2203
|
-
nk_u16_t max_raw = __riscv_vmv_x_s_u16m1_u16(
|
|
2204
|
-
|
|
2278
|
+
nk_u16_t max_raw = __riscv_vmv_x_s_u16m1_u16(__riscv_vslidedown_vx_u16m1(
|
|
2279
|
+
max_u16m1, (nk_size_t)__riscv_vfirst_m_b16(max_match_b16, max_vector_length), max_vector_length));
|
|
2205
2280
|
*max_value_ptr = *(nk_bf16_t *)&max_raw;
|
|
2206
2281
|
*max_index_ptr = (nk_size_t)__riscv_vmv_x_s_u64m1_u64(
|
|
2207
|
-
__riscv_vredminu_vs_u64m4_u64m1(max_cands_u64m4, init_umax_u64m1,
|
|
2282
|
+
__riscv_vredminu_vs_u64m4_u64m1(max_cands_u64m4, init_umax_u64m1, max_vector_length));
|
|
2208
2283
|
}
|
|
2209
2284
|
|
|
2210
2285
|
NK_INTERNAL void nk_reduce_minmax_bf16_rvv_strided_( //
|
|
2211
2286
|
nk_bf16_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
|
|
2212
2287
|
nk_bf16_t *min_value_ptr, nk_size_t *min_index_ptr, //
|
|
2213
2288
|
nk_bf16_t *max_value_ptr, nk_size_t *max_index_ptr) {
|
|
2214
|
-
nk_size_t
|
|
2215
|
-
vuint16m1_t min_u16m1 = __riscv_vmv_v_x_u16m1(0x7F80,
|
|
2216
|
-
vuint16m1_t max_u16m1 = __riscv_vmv_v_x_u16m1(0xFF80,
|
|
2217
|
-
vuint64m4_t min_indices_u64m4 = __riscv_vmv_v_x_u64m4(0,
|
|
2218
|
-
vuint64m4_t max_indices_u64m4 = __riscv_vmv_v_x_u64m4(0,
|
|
2289
|
+
nk_size_t max_vector_length = __riscv_vsetvlmax_e16m1();
|
|
2290
|
+
vuint16m1_t min_u16m1 = __riscv_vmv_v_x_u16m1(0x7F80, max_vector_length); // +inf in bf16
|
|
2291
|
+
vuint16m1_t max_u16m1 = __riscv_vmv_v_x_u16m1(0xFF80, max_vector_length); // -inf in bf16
|
|
2292
|
+
vuint64m4_t min_indices_u64m4 = __riscv_vmv_v_x_u64m4(0, max_vector_length);
|
|
2293
|
+
vuint64m4_t max_indices_u64m4 = __riscv_vmv_v_x_u64m4(0, max_vector_length);
|
|
2219
2294
|
unsigned char const *ptr = (unsigned char const *)data_ptr;
|
|
2220
2295
|
|
|
2221
2296
|
nk_size_t offset = 0;
|
|
2222
2297
|
for (nk_size_t vector_length; count > 0;
|
|
2223
2298
|
count -= vector_length, offset += vector_length, ptr += vector_length * stride_bytes) {
|
|
2224
2299
|
vector_length = __riscv_vsetvl_e16m1(count);
|
|
2225
|
-
vuint16m1_t data_u16m1 = __riscv_vlse16_v_u16m1((
|
|
2300
|
+
vuint16m1_t data_u16m1 = __riscv_vlse16_v_u16m1((nk_u16_t const *)ptr, (nk_ssize_t)stride_bytes, vector_length);
|
|
2226
2301
|
vuint64m4_t pos_u64m4 = __riscv_vadd_vx_u64m4(__riscv_vid_v_u64m4(vector_length), (nk_u64_t)offset,
|
|
2227
2302
|
vector_length);
|
|
2228
2303
|
|
|
@@ -2243,40 +2318,43 @@ NK_INTERNAL void nk_reduce_minmax_bf16_rvv_strided_( //
|
|
|
2243
2318
|
}
|
|
2244
2319
|
|
|
2245
2320
|
// Horizontal reduction (same as contiguous)
|
|
2246
|
-
vfloat32m2_t final_min_f32m2 = nk_bf16m1_to_f32m2_rvv_(min_u16m1,
|
|
2321
|
+
vfloat32m2_t final_min_f32m2 = nk_bf16m1_to_f32m2_rvv_(min_u16m1, max_vector_length);
|
|
2247
2322
|
vfloat32m1_t init_max_f32m1 = __riscv_vfmv_v_f_f32m1(NK_F32_MAX, 1);
|
|
2248
2323
|
nk_f32_t min_val_f32 = __riscv_vfmv_f_s_f32m1_f32(
|
|
2249
|
-
__riscv_vfredmin_vs_f32m2_f32m1(final_min_f32m2, init_max_f32m1,
|
|
2250
|
-
vfloat32m2_t final_max_f32m2 = nk_bf16m1_to_f32m2_rvv_(max_u16m1,
|
|
2324
|
+
__riscv_vfredmin_vs_f32m2_f32m1(final_min_f32m2, init_max_f32m1, max_vector_length));
|
|
2325
|
+
vfloat32m2_t final_max_f32m2 = nk_bf16m1_to_f32m2_rvv_(max_u16m1, max_vector_length);
|
|
2251
2326
|
vfloat32m1_t init_min_f32m1 = __riscv_vfmv_v_f_f32m1(NK_F32_MIN, 1);
|
|
2252
2327
|
nk_f32_t max_val_f32 = __riscv_vfmv_f_s_f32m1_f32(
|
|
2253
|
-
__riscv_vfredmax_vs_f32m2_f32m1(final_max_f32m2, init_min_f32m1,
|
|
2328
|
+
__riscv_vfredmax_vs_f32m2_f32m1(final_max_f32m2, init_min_f32m1, max_vector_length));
|
|
2254
2329
|
if (min_val_f32 == NK_F32_MAX && max_val_f32 == NK_F32_MIN) {
|
|
2255
2330
|
*min_value_ptr = NK_BF16_MAX, *min_index_ptr = NK_SIZE_MAX, *max_value_ptr = NK_BF16_MIN,
|
|
2256
2331
|
*max_index_ptr = NK_SIZE_MAX;
|
|
2257
2332
|
return;
|
|
2258
2333
|
}
|
|
2259
2334
|
|
|
2260
|
-
vfloat32m2_t converted_f32m2 = nk_bf16m1_to_f32m2_rvv_(min_u16m1,
|
|
2261
|
-
vbool16_t min_match_b16 = __riscv_vmfeq_vf_f32m2_b16(converted_f32m2, min_val_f32,
|
|
2262
|
-
vuint64m4_t sentinel_u64m4 = __riscv_vmv_v_x_u64m4(NK_U64_MAX,
|
|
2263
|
-
vuint64m4_t min_cands_u64m4 = __riscv_vmerge_vvm_u64m4(sentinel_u64m4, min_indices_u64m4, min_match_b16,
|
|
2335
|
+
vfloat32m2_t converted_f32m2 = nk_bf16m1_to_f32m2_rvv_(min_u16m1, max_vector_length);
|
|
2336
|
+
vbool16_t min_match_b16 = __riscv_vmfeq_vf_f32m2_b16(converted_f32m2, min_val_f32, max_vector_length);
|
|
2337
|
+
vuint64m4_t sentinel_u64m4 = __riscv_vmv_v_x_u64m4(NK_U64_MAX, max_vector_length);
|
|
2338
|
+
vuint64m4_t min_cands_u64m4 = __riscv_vmerge_vvm_u64m4(sentinel_u64m4, min_indices_u64m4, min_match_b16,
|
|
2339
|
+
max_vector_length);
|
|
2264
2340
|
vuint64m1_t init_umax_u64m1 = __riscv_vmv_v_x_u64m1(NK_U64_MAX, 1);
|
|
2265
2341
|
|
|
2266
|
-
nk_u16_t min_raw = __riscv_vmv_x_s_u16m1_u16(
|
|
2267
|
-
|
|
2342
|
+
nk_u16_t min_raw = __riscv_vmv_x_s_u16m1_u16(__riscv_vslidedown_vx_u16m1(
|
|
2343
|
+
min_u16m1, (nk_size_t)__riscv_vfirst_m_b16(min_match_b16, max_vector_length), max_vector_length));
|
|
2268
2344
|
*min_value_ptr = *(nk_bf16_t *)&min_raw;
|
|
2269
2345
|
*min_index_ptr = (nk_size_t)__riscv_vmv_x_s_u64m1_u64(
|
|
2270
|
-
__riscv_vredminu_vs_u64m4_u64m1(min_cands_u64m4, init_umax_u64m1,
|
|
2346
|
+
__riscv_vredminu_vs_u64m4_u64m1(min_cands_u64m4, init_umax_u64m1, max_vector_length));
|
|
2271
2347
|
|
|
2272
|
-
vbool16_t max_match_b16 = __riscv_vmfeq_vf_f32m2_b16(nk_bf16m1_to_f32m2_rvv_(max_u16m1,
|
|
2273
|
-
|
|
2348
|
+
vbool16_t max_match_b16 = __riscv_vmfeq_vf_f32m2_b16(nk_bf16m1_to_f32m2_rvv_(max_u16m1, max_vector_length),
|
|
2349
|
+
max_val_f32, max_vector_length);
|
|
2350
|
+
vuint64m4_t max_cands_u64m4 = __riscv_vmerge_vvm_u64m4(sentinel_u64m4, max_indices_u64m4, max_match_b16,
|
|
2351
|
+
max_vector_length);
|
|
2274
2352
|
|
|
2275
|
-
nk_u16_t max_raw = __riscv_vmv_x_s_u16m1_u16(
|
|
2276
|
-
|
|
2353
|
+
nk_u16_t max_raw = __riscv_vmv_x_s_u16m1_u16(__riscv_vslidedown_vx_u16m1(
|
|
2354
|
+
max_u16m1, (nk_size_t)__riscv_vfirst_m_b16(max_match_b16, max_vector_length), max_vector_length));
|
|
2277
2355
|
*max_value_ptr = *(nk_bf16_t *)&max_raw;
|
|
2278
2356
|
*max_index_ptr = (nk_size_t)__riscv_vmv_x_s_u64m1_u64(
|
|
2279
|
-
__riscv_vredminu_vs_u64m4_u64m1(max_cands_u64m4, init_umax_u64m1,
|
|
2357
|
+
__riscv_vredminu_vs_u64m4_u64m1(max_cands_u64m4, init_umax_u64m1, max_vector_length));
|
|
2280
2358
|
}
|
|
2281
2359
|
|
|
2282
2360
|
NK_PUBLIC void nk_reduce_minmax_bf16_rvv( //
|
|
@@ -2303,13 +2381,13 @@ NK_PUBLIC void nk_reduce_minmax_bf16_rvv( //
|
|
|
2303
2381
|
NK_INTERNAL void nk_reduce_moments_f16_rvv_contiguous_( //
|
|
2304
2382
|
nk_f16_t const *data_ptr, nk_size_t count, //
|
|
2305
2383
|
nk_f32_t *sum_ptr, nk_f32_t *sumsq_ptr) {
|
|
2306
|
-
nk_size_t
|
|
2307
|
-
vfloat64m4_t sum_f64m4 = __riscv_vfmv_v_f_f64m4(0.0,
|
|
2308
|
-
vfloat64m4_t sumsq_f64m4 = __riscv_vfmv_v_f_f64m4(0.0,
|
|
2384
|
+
nk_size_t max_vector_length = __riscv_vsetvlmax_e64m4();
|
|
2385
|
+
vfloat64m4_t sum_f64m4 = __riscv_vfmv_v_f_f64m4(0.0, max_vector_length);
|
|
2386
|
+
vfloat64m4_t sumsq_f64m4 = __riscv_vfmv_v_f_f64m4(0.0, max_vector_length);
|
|
2309
2387
|
|
|
2310
2388
|
for (nk_size_t vector_length; count > 0; count -= vector_length, data_ptr += vector_length) {
|
|
2311
2389
|
vector_length = __riscv_vsetvl_e16m1(count);
|
|
2312
|
-
vuint16m1_t data_u16m1 = __riscv_vle16_v_u16m1((
|
|
2390
|
+
vuint16m1_t data_u16m1 = __riscv_vle16_v_u16m1((nk_u16_t const *)data_ptr, vector_length);
|
|
2313
2391
|
|
|
2314
2392
|
// Convert f16 → f32 (m1 → m2)
|
|
2315
2393
|
vfloat32m2_t data_f32m2 = nk_f16m1_to_f32m2_rvv_(data_u16m1, vector_length);
|
|
@@ -2324,21 +2402,22 @@ NK_INTERNAL void nk_reduce_moments_f16_rvv_contiguous_( //
|
|
|
2324
2402
|
|
|
2325
2403
|
// Horizontal reduction
|
|
2326
2404
|
vfloat64m1_t zero_f64m1 = __riscv_vfmv_v_f_f64m1(0.0, 1);
|
|
2327
|
-
*sum_ptr = __riscv_vfmv_f_s_f64m1_f64(__riscv_vfredusum_vs_f64m4_f64m1(sum_f64m4, zero_f64m1,
|
|
2328
|
-
*sumsq_ptr = __riscv_vfmv_f_s_f64m1_f64(
|
|
2405
|
+
*sum_ptr = __riscv_vfmv_f_s_f64m1_f64(__riscv_vfredusum_vs_f64m4_f64m1(sum_f64m4, zero_f64m1, max_vector_length)),
|
|
2406
|
+
*sumsq_ptr = __riscv_vfmv_f_s_f64m1_f64(
|
|
2407
|
+
__riscv_vfredusum_vs_f64m4_f64m1(sumsq_f64m4, zero_f64m1, max_vector_length));
|
|
2329
2408
|
}
|
|
2330
2409
|
|
|
2331
2410
|
NK_INTERNAL void nk_reduce_moments_f16_rvv_strided_( //
|
|
2332
2411
|
nk_f16_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
|
|
2333
2412
|
nk_f32_t *sum_ptr, nk_f32_t *sumsq_ptr) {
|
|
2334
|
-
nk_size_t
|
|
2335
|
-
vfloat64m4_t sum_f64m4 = __riscv_vfmv_v_f_f64m4(0.0,
|
|
2336
|
-
vfloat64m4_t sumsq_f64m4 = __riscv_vfmv_v_f_f64m4(0.0,
|
|
2413
|
+
nk_size_t max_vector_length = __riscv_vsetvlmax_e64m4();
|
|
2414
|
+
vfloat64m4_t sum_f64m4 = __riscv_vfmv_v_f_f64m4(0.0, max_vector_length);
|
|
2415
|
+
vfloat64m4_t sumsq_f64m4 = __riscv_vfmv_v_f_f64m4(0.0, max_vector_length);
|
|
2337
2416
|
unsigned char const *ptr = (unsigned char const *)data_ptr;
|
|
2338
2417
|
|
|
2339
2418
|
for (nk_size_t vector_length; count > 0; count -= vector_length, ptr += vector_length * stride_bytes) {
|
|
2340
2419
|
vector_length = __riscv_vsetvl_e16m1(count);
|
|
2341
|
-
vuint16m1_t data_u16m1 = __riscv_vlse16_v_u16m1((
|
|
2420
|
+
vuint16m1_t data_u16m1 = __riscv_vlse16_v_u16m1((nk_u16_t const *)ptr, (nk_ssize_t)stride_bytes, vector_length);
|
|
2342
2421
|
|
|
2343
2422
|
// Convert f16 → f32 (m1 → m2)
|
|
2344
2423
|
vfloat32m2_t data_f32m2 = nk_f16m1_to_f32m2_rvv_(data_u16m1, vector_length);
|
|
@@ -2353,8 +2432,9 @@ NK_INTERNAL void nk_reduce_moments_f16_rvv_strided_( //
|
|
|
2353
2432
|
|
|
2354
2433
|
// Horizontal reduction
|
|
2355
2434
|
vfloat64m1_t zero_f64m1 = __riscv_vfmv_v_f_f64m1(0.0, 1);
|
|
2356
|
-
*sum_ptr = __riscv_vfmv_f_s_f64m1_f64(__riscv_vfredusum_vs_f64m4_f64m1(sum_f64m4, zero_f64m1,
|
|
2357
|
-
*sumsq_ptr = __riscv_vfmv_f_s_f64m1_f64(
|
|
2435
|
+
*sum_ptr = __riscv_vfmv_f_s_f64m1_f64(__riscv_vfredusum_vs_f64m4_f64m1(sum_f64m4, zero_f64m1, max_vector_length)),
|
|
2436
|
+
*sumsq_ptr = __riscv_vfmv_f_s_f64m1_f64(
|
|
2437
|
+
__riscv_vfredusum_vs_f64m4_f64m1(sumsq_f64m4, zero_f64m1, max_vector_length));
|
|
2358
2438
|
}
|
|
2359
2439
|
|
|
2360
2440
|
NK_PUBLIC void nk_reduce_moments_f16_rvv( //
|
|
@@ -2373,17 +2453,17 @@ NK_INTERNAL void nk_reduce_minmax_f16_rvv_contiguous_( //
|
|
|
2373
2453
|
nk_f16_t const *data_ptr, nk_size_t count, //
|
|
2374
2454
|
nk_f16_t *min_value_ptr, nk_size_t *min_index_ptr, //
|
|
2375
2455
|
nk_f16_t *max_value_ptr, nk_size_t *max_index_ptr) {
|
|
2376
|
-
nk_size_t
|
|
2377
|
-
vuint16m1_t min_u16m1 = __riscv_vmv_v_x_u16m1(0x7C00,
|
|
2378
|
-
vuint16m1_t max_u16m1 = __riscv_vmv_v_x_u16m1(0xFC00,
|
|
2379
|
-
vuint64m4_t min_indices_u64m4 = __riscv_vmv_v_x_u64m4(0,
|
|
2380
|
-
vuint64m4_t max_indices_u64m4 = __riscv_vmv_v_x_u64m4(0,
|
|
2456
|
+
nk_size_t max_vector_length = __riscv_vsetvlmax_e16m1();
|
|
2457
|
+
vuint16m1_t min_u16m1 = __riscv_vmv_v_x_u16m1(0x7C00, max_vector_length); // +inf in f16
|
|
2458
|
+
vuint16m1_t max_u16m1 = __riscv_vmv_v_x_u16m1(0xFC00, max_vector_length); // -inf in f16
|
|
2459
|
+
vuint64m4_t min_indices_u64m4 = __riscv_vmv_v_x_u64m4(0, max_vector_length);
|
|
2460
|
+
vuint64m4_t max_indices_u64m4 = __riscv_vmv_v_x_u64m4(0, max_vector_length);
|
|
2381
2461
|
|
|
2382
2462
|
nk_size_t offset = 0;
|
|
2383
2463
|
for (nk_size_t vector_length; count > 0;
|
|
2384
2464
|
count -= vector_length, offset += vector_length, data_ptr += vector_length) {
|
|
2385
2465
|
vector_length = __riscv_vsetvl_e16m1(count);
|
|
2386
|
-
vuint16m1_t data_u16m1 = __riscv_vle16_v_u16m1((
|
|
2466
|
+
vuint16m1_t data_u16m1 = __riscv_vle16_v_u16m1((nk_u16_t const *)data_ptr, vector_length);
|
|
2387
2467
|
vuint64m4_t pos_u64m4 = __riscv_vadd_vx_u64m4(__riscv_vid_v_u64m4(vector_length), (nk_u64_t)offset,
|
|
2388
2468
|
vector_length);
|
|
2389
2469
|
|
|
@@ -2404,58 +2484,61 @@ NK_INTERNAL void nk_reduce_minmax_f16_rvv_contiguous_( //
|
|
|
2404
2484
|
}
|
|
2405
2485
|
|
|
2406
2486
|
// Horizontal reduction
|
|
2407
|
-
vfloat32m2_t final_min_f32m2 = nk_f16m1_to_f32m2_rvv_(min_u16m1,
|
|
2487
|
+
vfloat32m2_t final_min_f32m2 = nk_f16m1_to_f32m2_rvv_(min_u16m1, max_vector_length);
|
|
2408
2488
|
vfloat32m1_t init_max_f32m1 = __riscv_vfmv_v_f_f32m1(NK_F32_MAX, 1);
|
|
2409
2489
|
nk_f32_t min_val_f32 = __riscv_vfmv_f_s_f32m1_f32(
|
|
2410
|
-
__riscv_vfredmin_vs_f32m2_f32m1(final_min_f32m2, init_max_f32m1,
|
|
2411
|
-
vfloat32m2_t final_max_f32m2 = nk_f16m1_to_f32m2_rvv_(max_u16m1,
|
|
2490
|
+
__riscv_vfredmin_vs_f32m2_f32m1(final_min_f32m2, init_max_f32m1, max_vector_length));
|
|
2491
|
+
vfloat32m2_t final_max_f32m2 = nk_f16m1_to_f32m2_rvv_(max_u16m1, max_vector_length);
|
|
2412
2492
|
vfloat32m1_t init_min_f32m1 = __riscv_vfmv_v_f_f32m1(NK_F32_MIN, 1);
|
|
2413
2493
|
nk_f32_t max_val_f32 = __riscv_vfmv_f_s_f32m1_f32(
|
|
2414
|
-
__riscv_vfredmax_vs_f32m2_f32m1(final_max_f32m2, init_min_f32m1,
|
|
2494
|
+
__riscv_vfredmax_vs_f32m2_f32m1(final_max_f32m2, init_min_f32m1, max_vector_length));
|
|
2415
2495
|
if (min_val_f32 == NK_F32_MAX && max_val_f32 == NK_F32_MIN) {
|
|
2416
2496
|
*min_value_ptr = NK_F16_MAX, *min_index_ptr = NK_SIZE_MAX, *max_value_ptr = NK_F16_MIN,
|
|
2417
2497
|
*max_index_ptr = NK_SIZE_MAX;
|
|
2418
2498
|
return;
|
|
2419
2499
|
}
|
|
2420
2500
|
|
|
2421
|
-
vfloat32m2_t converted_f32m2 = nk_f16m1_to_f32m2_rvv_(min_u16m1,
|
|
2422
|
-
vbool16_t min_match_b16 = __riscv_vmfeq_vf_f32m2_b16(converted_f32m2, min_val_f32,
|
|
2423
|
-
vuint64m4_t sentinel_u64m4 = __riscv_vmv_v_x_u64m4(NK_U64_MAX,
|
|
2424
|
-
vuint64m4_t min_cands_u64m4 = __riscv_vmerge_vvm_u64m4(sentinel_u64m4, min_indices_u64m4, min_match_b16,
|
|
2501
|
+
vfloat32m2_t converted_f32m2 = nk_f16m1_to_f32m2_rvv_(min_u16m1, max_vector_length);
|
|
2502
|
+
vbool16_t min_match_b16 = __riscv_vmfeq_vf_f32m2_b16(converted_f32m2, min_val_f32, max_vector_length);
|
|
2503
|
+
vuint64m4_t sentinel_u64m4 = __riscv_vmv_v_x_u64m4(NK_U64_MAX, max_vector_length);
|
|
2504
|
+
vuint64m4_t min_cands_u64m4 = __riscv_vmerge_vvm_u64m4(sentinel_u64m4, min_indices_u64m4, min_match_b16,
|
|
2505
|
+
max_vector_length);
|
|
2425
2506
|
vuint64m1_t init_umax_u64m1 = __riscv_vmv_v_x_u64m1(NK_U64_MAX, 1);
|
|
2426
2507
|
|
|
2427
|
-
nk_u16_t min_raw = __riscv_vmv_x_s_u16m1_u16(
|
|
2428
|
-
|
|
2508
|
+
nk_u16_t min_raw = __riscv_vmv_x_s_u16m1_u16(__riscv_vslidedown_vx_u16m1(
|
|
2509
|
+
min_u16m1, (nk_size_t)__riscv_vfirst_m_b16(min_match_b16, max_vector_length), max_vector_length));
|
|
2429
2510
|
*min_value_ptr = *(nk_f16_t *)&min_raw;
|
|
2430
2511
|
*min_index_ptr = (nk_size_t)__riscv_vmv_x_s_u64m1_u64(
|
|
2431
|
-
__riscv_vredminu_vs_u64m4_u64m1(min_cands_u64m4, init_umax_u64m1,
|
|
2512
|
+
__riscv_vredminu_vs_u64m4_u64m1(min_cands_u64m4, init_umax_u64m1, max_vector_length));
|
|
2432
2513
|
|
|
2433
|
-
vbool16_t max_match_b16 = __riscv_vmfeq_vf_f32m2_b16(nk_f16m1_to_f32m2_rvv_(max_u16m1,
|
|
2434
|
-
|
|
2514
|
+
vbool16_t max_match_b16 = __riscv_vmfeq_vf_f32m2_b16(nk_f16m1_to_f32m2_rvv_(max_u16m1, max_vector_length),
|
|
2515
|
+
max_val_f32, max_vector_length);
|
|
2516
|
+
vuint64m4_t max_cands_u64m4 = __riscv_vmerge_vvm_u64m4(sentinel_u64m4, max_indices_u64m4, max_match_b16,
|
|
2517
|
+
max_vector_length);
|
|
2435
2518
|
|
|
2436
|
-
nk_u16_t max_raw = __riscv_vmv_x_s_u16m1_u16(
|
|
2437
|
-
|
|
2519
|
+
nk_u16_t max_raw = __riscv_vmv_x_s_u16m1_u16(__riscv_vslidedown_vx_u16m1(
|
|
2520
|
+
max_u16m1, (nk_size_t)__riscv_vfirst_m_b16(max_match_b16, max_vector_length), max_vector_length));
|
|
2438
2521
|
*max_value_ptr = *(nk_f16_t *)&max_raw;
|
|
2439
2522
|
*max_index_ptr = (nk_size_t)__riscv_vmv_x_s_u64m1_u64(
|
|
2440
|
-
__riscv_vredminu_vs_u64m4_u64m1(max_cands_u64m4, init_umax_u64m1,
|
|
2523
|
+
__riscv_vredminu_vs_u64m4_u64m1(max_cands_u64m4, init_umax_u64m1, max_vector_length));
|
|
2441
2524
|
}
|
|
2442
2525
|
|
|
2443
2526
|
NK_INTERNAL void nk_reduce_minmax_f16_rvv_strided_( //
|
|
2444
2527
|
nk_f16_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
|
|
2445
2528
|
nk_f16_t *min_value_ptr, nk_size_t *min_index_ptr, //
|
|
2446
2529
|
nk_f16_t *max_value_ptr, nk_size_t *max_index_ptr) {
|
|
2447
|
-
nk_size_t
|
|
2448
|
-
vuint16m1_t min_u16m1 = __riscv_vmv_v_x_u16m1(0x7C00,
|
|
2449
|
-
vuint16m1_t max_u16m1 = __riscv_vmv_v_x_u16m1(0xFC00,
|
|
2450
|
-
vuint64m4_t min_indices_u64m4 = __riscv_vmv_v_x_u64m4(0,
|
|
2451
|
-
vuint64m4_t max_indices_u64m4 = __riscv_vmv_v_x_u64m4(0,
|
|
2530
|
+
nk_size_t max_vector_length = __riscv_vsetvlmax_e16m1();
|
|
2531
|
+
vuint16m1_t min_u16m1 = __riscv_vmv_v_x_u16m1(0x7C00, max_vector_length); // +inf in f16
|
|
2532
|
+
vuint16m1_t max_u16m1 = __riscv_vmv_v_x_u16m1(0xFC00, max_vector_length); // -inf in f16
|
|
2533
|
+
vuint64m4_t min_indices_u64m4 = __riscv_vmv_v_x_u64m4(0, max_vector_length);
|
|
2534
|
+
vuint64m4_t max_indices_u64m4 = __riscv_vmv_v_x_u64m4(0, max_vector_length);
|
|
2452
2535
|
unsigned char const *ptr = (unsigned char const *)data_ptr;
|
|
2453
2536
|
|
|
2454
2537
|
nk_size_t offset = 0;
|
|
2455
2538
|
for (nk_size_t vector_length; count > 0;
|
|
2456
2539
|
count -= vector_length, offset += vector_length, ptr += vector_length * stride_bytes) {
|
|
2457
2540
|
vector_length = __riscv_vsetvl_e16m1(count);
|
|
2458
|
-
vuint16m1_t data_u16m1 = __riscv_vlse16_v_u16m1((
|
|
2541
|
+
vuint16m1_t data_u16m1 = __riscv_vlse16_v_u16m1((nk_u16_t const *)ptr, (nk_ssize_t)stride_bytes, vector_length);
|
|
2459
2542
|
vuint64m4_t pos_u64m4 = __riscv_vadd_vx_u64m4(__riscv_vid_v_u64m4(vector_length), (nk_u64_t)offset,
|
|
2460
2543
|
vector_length);
|
|
2461
2544
|
|
|
@@ -2476,40 +2559,43 @@ NK_INTERNAL void nk_reduce_minmax_f16_rvv_strided_( //
|
|
|
2476
2559
|
}
|
|
2477
2560
|
|
|
2478
2561
|
// Horizontal reduction (same as contiguous)
|
|
2479
|
-
vfloat32m2_t final_min_f32m2 = nk_f16m1_to_f32m2_rvv_(min_u16m1,
|
|
2562
|
+
vfloat32m2_t final_min_f32m2 = nk_f16m1_to_f32m2_rvv_(min_u16m1, max_vector_length);
|
|
2480
2563
|
vfloat32m1_t init_max_f32m1 = __riscv_vfmv_v_f_f32m1(NK_F32_MAX, 1);
|
|
2481
2564
|
nk_f32_t min_val_f32 = __riscv_vfmv_f_s_f32m1_f32(
|
|
2482
|
-
__riscv_vfredmin_vs_f32m2_f32m1(final_min_f32m2, init_max_f32m1,
|
|
2483
|
-
vfloat32m2_t final_max_f32m2 = nk_f16m1_to_f32m2_rvv_(max_u16m1,
|
|
2565
|
+
__riscv_vfredmin_vs_f32m2_f32m1(final_min_f32m2, init_max_f32m1, max_vector_length));
|
|
2566
|
+
vfloat32m2_t final_max_f32m2 = nk_f16m1_to_f32m2_rvv_(max_u16m1, max_vector_length);
|
|
2484
2567
|
vfloat32m1_t init_min_f32m1 = __riscv_vfmv_v_f_f32m1(NK_F32_MIN, 1);
|
|
2485
2568
|
nk_f32_t max_val_f32 = __riscv_vfmv_f_s_f32m1_f32(
|
|
2486
|
-
__riscv_vfredmax_vs_f32m2_f32m1(final_max_f32m2, init_min_f32m1,
|
|
2569
|
+
__riscv_vfredmax_vs_f32m2_f32m1(final_max_f32m2, init_min_f32m1, max_vector_length));
|
|
2487
2570
|
if (min_val_f32 == NK_F32_MAX && max_val_f32 == NK_F32_MIN) {
|
|
2488
2571
|
*min_value_ptr = NK_F16_MAX, *min_index_ptr = NK_SIZE_MAX, *max_value_ptr = NK_F16_MIN,
|
|
2489
2572
|
*max_index_ptr = NK_SIZE_MAX;
|
|
2490
2573
|
return;
|
|
2491
2574
|
}
|
|
2492
2575
|
|
|
2493
|
-
vfloat32m2_t converted_f32m2 = nk_f16m1_to_f32m2_rvv_(min_u16m1,
|
|
2494
|
-
vbool16_t min_match_b16 = __riscv_vmfeq_vf_f32m2_b16(converted_f32m2, min_val_f32,
|
|
2495
|
-
vuint64m4_t sentinel_u64m4 = __riscv_vmv_v_x_u64m4(NK_U64_MAX,
|
|
2496
|
-
vuint64m4_t min_cands_u64m4 = __riscv_vmerge_vvm_u64m4(sentinel_u64m4, min_indices_u64m4, min_match_b16,
|
|
2576
|
+
vfloat32m2_t converted_f32m2 = nk_f16m1_to_f32m2_rvv_(min_u16m1, max_vector_length);
|
|
2577
|
+
vbool16_t min_match_b16 = __riscv_vmfeq_vf_f32m2_b16(converted_f32m2, min_val_f32, max_vector_length);
|
|
2578
|
+
vuint64m4_t sentinel_u64m4 = __riscv_vmv_v_x_u64m4(NK_U64_MAX, max_vector_length);
|
|
2579
|
+
vuint64m4_t min_cands_u64m4 = __riscv_vmerge_vvm_u64m4(sentinel_u64m4, min_indices_u64m4, min_match_b16,
|
|
2580
|
+
max_vector_length);
|
|
2497
2581
|
vuint64m1_t init_umax_u64m1 = __riscv_vmv_v_x_u64m1(NK_U64_MAX, 1);
|
|
2498
2582
|
|
|
2499
|
-
nk_u16_t min_raw = __riscv_vmv_x_s_u16m1_u16(
|
|
2500
|
-
|
|
2583
|
+
nk_u16_t min_raw = __riscv_vmv_x_s_u16m1_u16(__riscv_vslidedown_vx_u16m1(
|
|
2584
|
+
min_u16m1, (nk_size_t)__riscv_vfirst_m_b16(min_match_b16, max_vector_length), max_vector_length));
|
|
2501
2585
|
*min_value_ptr = *(nk_f16_t *)&min_raw;
|
|
2502
2586
|
*min_index_ptr = (nk_size_t)__riscv_vmv_x_s_u64m1_u64(
|
|
2503
|
-
__riscv_vredminu_vs_u64m4_u64m1(min_cands_u64m4, init_umax_u64m1,
|
|
2587
|
+
__riscv_vredminu_vs_u64m4_u64m1(min_cands_u64m4, init_umax_u64m1, max_vector_length));
|
|
2504
2588
|
|
|
2505
|
-
vbool16_t max_match_b16 = __riscv_vmfeq_vf_f32m2_b16(nk_f16m1_to_f32m2_rvv_(max_u16m1,
|
|
2506
|
-
|
|
2589
|
+
vbool16_t max_match_b16 = __riscv_vmfeq_vf_f32m2_b16(nk_f16m1_to_f32m2_rvv_(max_u16m1, max_vector_length),
|
|
2590
|
+
max_val_f32, max_vector_length);
|
|
2591
|
+
vuint64m4_t max_cands_u64m4 = __riscv_vmerge_vvm_u64m4(sentinel_u64m4, max_indices_u64m4, max_match_b16,
|
|
2592
|
+
max_vector_length);
|
|
2507
2593
|
|
|
2508
|
-
nk_u16_t max_raw = __riscv_vmv_x_s_u16m1_u16(
|
|
2509
|
-
|
|
2594
|
+
nk_u16_t max_raw = __riscv_vmv_x_s_u16m1_u16(__riscv_vslidedown_vx_u16m1(
|
|
2595
|
+
max_u16m1, (nk_size_t)__riscv_vfirst_m_b16(max_match_b16, max_vector_length), max_vector_length));
|
|
2510
2596
|
*max_value_ptr = *(nk_f16_t *)&max_raw;
|
|
2511
2597
|
*max_index_ptr = (nk_size_t)__riscv_vmv_x_s_u64m1_u64(
|
|
2512
|
-
__riscv_vredminu_vs_u64m4_u64m1(max_cands_u64m4, init_umax_u64m1,
|
|
2598
|
+
__riscv_vredminu_vs_u64m4_u64m1(max_cands_u64m4, init_umax_u64m1, max_vector_length));
|
|
2513
2599
|
}
|
|
2514
2600
|
|
|
2515
2601
|
NK_PUBLIC void nk_reduce_minmax_f16_rvv( //
|
|
@@ -2536,13 +2622,13 @@ NK_PUBLIC void nk_reduce_minmax_f16_rvv( //
|
|
|
2536
2622
|
NK_INTERNAL void nk_reduce_moments_e4m3_rvv_contiguous_( //
|
|
2537
2623
|
nk_e4m3_t const *data_ptr, nk_size_t count, //
|
|
2538
2624
|
nk_f32_t *sum_ptr, nk_f32_t *sumsq_ptr) {
|
|
2539
|
-
nk_size_t
|
|
2540
|
-
vfloat32m4_t sum_f32m4 = __riscv_vfmv_v_f_f32m4(0.0f,
|
|
2541
|
-
vfloat32m4_t sumsq_f32m4 = __riscv_vfmv_v_f_f32m4(0.0f,
|
|
2625
|
+
nk_size_t max_vector_length = __riscv_vsetvlmax_e32m4();
|
|
2626
|
+
vfloat32m4_t sum_f32m4 = __riscv_vfmv_v_f_f32m4(0.0f, max_vector_length);
|
|
2627
|
+
vfloat32m4_t sumsq_f32m4 = __riscv_vfmv_v_f_f32m4(0.0f, max_vector_length);
|
|
2542
2628
|
|
|
2543
2629
|
for (nk_size_t vector_length; count > 0; count -= vector_length, data_ptr += vector_length) {
|
|
2544
2630
|
vector_length = __riscv_vsetvl_e8m1(count);
|
|
2545
|
-
vuint8m1_t data_u8m1 = __riscv_vle8_v_u8m1((
|
|
2631
|
+
vuint8m1_t data_u8m1 = __riscv_vle8_v_u8m1((nk_u8_t const *)data_ptr, vector_length);
|
|
2546
2632
|
|
|
2547
2633
|
// Convert e4m3 → f32 (m1 → m4)
|
|
2548
2634
|
vfloat32m4_t data_f32m4 = nk_e4m3m1_to_f32m4_rvv_(data_u8m1, vector_length);
|
|
@@ -2554,21 +2640,22 @@ NK_INTERNAL void nk_reduce_moments_e4m3_rvv_contiguous_( //
|
|
|
2554
2640
|
|
|
2555
2641
|
// Horizontal reduction
|
|
2556
2642
|
vfloat32m1_t zero_f32m1 = __riscv_vfmv_v_f_f32m1(0.0f, 1);
|
|
2557
|
-
*sum_ptr = __riscv_vfmv_f_s_f32m1_f32(__riscv_vfredusum_vs_f32m4_f32m1(sum_f32m4, zero_f32m1,
|
|
2558
|
-
*sumsq_ptr = __riscv_vfmv_f_s_f32m1_f32(
|
|
2643
|
+
*sum_ptr = __riscv_vfmv_f_s_f32m1_f32(__riscv_vfredusum_vs_f32m4_f32m1(sum_f32m4, zero_f32m1, max_vector_length)),
|
|
2644
|
+
*sumsq_ptr = __riscv_vfmv_f_s_f32m1_f32(
|
|
2645
|
+
__riscv_vfredusum_vs_f32m4_f32m1(sumsq_f32m4, zero_f32m1, max_vector_length));
|
|
2559
2646
|
}
|
|
2560
2647
|
|
|
2561
2648
|
NK_INTERNAL void nk_reduce_moments_e4m3_rvv_strided_( //
|
|
2562
2649
|
nk_e4m3_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
|
|
2563
2650
|
nk_f32_t *sum_ptr, nk_f32_t *sumsq_ptr) {
|
|
2564
|
-
nk_size_t
|
|
2565
|
-
vfloat32m4_t sum_f32m4 = __riscv_vfmv_v_f_f32m4(0.0f,
|
|
2566
|
-
vfloat32m4_t sumsq_f32m4 = __riscv_vfmv_v_f_f32m4(0.0f,
|
|
2651
|
+
nk_size_t max_vector_length = __riscv_vsetvlmax_e32m4();
|
|
2652
|
+
vfloat32m4_t sum_f32m4 = __riscv_vfmv_v_f_f32m4(0.0f, max_vector_length);
|
|
2653
|
+
vfloat32m4_t sumsq_f32m4 = __riscv_vfmv_v_f_f32m4(0.0f, max_vector_length);
|
|
2567
2654
|
unsigned char const *ptr = (unsigned char const *)data_ptr;
|
|
2568
2655
|
|
|
2569
2656
|
for (nk_size_t vector_length; count > 0; count -= vector_length, ptr += vector_length * stride_bytes) {
|
|
2570
2657
|
vector_length = __riscv_vsetvl_e8m1(count);
|
|
2571
|
-
vuint8m1_t data_u8m1 = __riscv_vlse8_v_u8m1((
|
|
2658
|
+
vuint8m1_t data_u8m1 = __riscv_vlse8_v_u8m1((nk_u8_t const *)ptr, (nk_ssize_t)stride_bytes, vector_length);
|
|
2572
2659
|
|
|
2573
2660
|
// Convert e4m3 → f32 (m1 → m4)
|
|
2574
2661
|
vfloat32m4_t data_f32m4 = nk_e4m3m1_to_f32m4_rvv_(data_u8m1, vector_length);
|
|
@@ -2580,8 +2667,9 @@ NK_INTERNAL void nk_reduce_moments_e4m3_rvv_strided_( //
|
|
|
2580
2667
|
|
|
2581
2668
|
// Horizontal reduction
|
|
2582
2669
|
vfloat32m1_t zero_f32m1 = __riscv_vfmv_v_f_f32m1(0.0f, 1);
|
|
2583
|
-
*sum_ptr = __riscv_vfmv_f_s_f32m1_f32(__riscv_vfredusum_vs_f32m4_f32m1(sum_f32m4, zero_f32m1,
|
|
2584
|
-
*sumsq_ptr = __riscv_vfmv_f_s_f32m1_f32(
|
|
2670
|
+
*sum_ptr = __riscv_vfmv_f_s_f32m1_f32(__riscv_vfredusum_vs_f32m4_f32m1(sum_f32m4, zero_f32m1, max_vector_length)),
|
|
2671
|
+
*sumsq_ptr = __riscv_vfmv_f_s_f32m1_f32(
|
|
2672
|
+
__riscv_vfredusum_vs_f32m4_f32m1(sumsq_f32m4, zero_f32m1, max_vector_length));
|
|
2585
2673
|
}
|
|
2586
2674
|
|
|
2587
2675
|
NK_PUBLIC void nk_reduce_moments_e4m3_rvv( //
|
|
@@ -2600,17 +2688,17 @@ NK_INTERNAL void nk_reduce_minmax_e4m3_rvv_contiguous_( //
|
|
|
2600
2688
|
nk_e4m3_t const *data_ptr, nk_size_t count, //
|
|
2601
2689
|
nk_e4m3_t *min_value_ptr, nk_size_t *min_index_ptr, //
|
|
2602
2690
|
nk_e4m3_t *max_value_ptr, nk_size_t *max_index_ptr) {
|
|
2603
|
-
nk_size_t
|
|
2604
|
-
vuint8m1_t min_u8m1 = __riscv_vmv_v_x_u8m1(0xFF,
|
|
2605
|
-
vuint8m1_t max_u8m1 = __riscv_vmv_v_x_u8m1(0x00,
|
|
2606
|
-
vuint64m8_t min_indices_u64m8 = __riscv_vmv_v_x_u64m8(0,
|
|
2607
|
-
vuint64m8_t max_indices_u64m8 = __riscv_vmv_v_x_u64m8(0,
|
|
2691
|
+
nk_size_t max_vector_length = __riscv_vsetvlmax_e8m1();
|
|
2692
|
+
vuint8m1_t min_u8m1 = __riscv_vmv_v_x_u8m1(0xFF, max_vector_length); // Largest comparable
|
|
2693
|
+
vuint8m1_t max_u8m1 = __riscv_vmv_v_x_u8m1(0x00, max_vector_length); // Smallest comparable
|
|
2694
|
+
vuint64m8_t min_indices_u64m8 = __riscv_vmv_v_x_u64m8(0, max_vector_length);
|
|
2695
|
+
vuint64m8_t max_indices_u64m8 = __riscv_vmv_v_x_u64m8(0, max_vector_length);
|
|
2608
2696
|
|
|
2609
2697
|
nk_size_t offset = 0;
|
|
2610
2698
|
for (nk_size_t vector_length; count > 0;
|
|
2611
2699
|
count -= vector_length, offset += vector_length, data_ptr += vector_length) {
|
|
2612
2700
|
vector_length = __riscv_vsetvl_e8m1(count);
|
|
2613
|
-
vuint8m1_t raw_u8m1 = __riscv_vle8_v_u8m1((
|
|
2701
|
+
vuint8m1_t raw_u8m1 = __riscv_vle8_v_u8m1((nk_u8_t const *)data_ptr, vector_length);
|
|
2614
2702
|
|
|
2615
2703
|
// Convert to comparable form
|
|
2616
2704
|
vuint8m1_t comparable_u8m1 = nk_fp8m1_to_comparable_u8m1_rvv_(raw_u8m1, vector_length);
|
|
@@ -2637,7 +2725,8 @@ NK_INTERNAL void nk_reduce_minmax_e4m3_rvv_contiguous_( //
|
|
|
2637
2725
|
|
|
2638
2726
|
// Horizontal reduction + convert back
|
|
2639
2727
|
vuint8m1_t init_max_u8m1 = __riscv_vmv_v_x_u8m1(0xFF, 1);
|
|
2640
|
-
nk_u8_t min_comparable = __riscv_vmv_x_s_u8m1_u8(
|
|
2728
|
+
nk_u8_t min_comparable = __riscv_vmv_x_s_u8m1_u8(
|
|
2729
|
+
__riscv_vredminu_vs_u8m1_u8m1(min_u8m1, init_max_u8m1, max_vector_length));
|
|
2641
2730
|
|
|
2642
2731
|
// All-NaN case
|
|
2643
2732
|
if (min_comparable == 0xFF) {
|
|
@@ -2646,12 +2735,13 @@ NK_INTERNAL void nk_reduce_minmax_e4m3_rvv_contiguous_( //
|
|
|
2646
2735
|
return;
|
|
2647
2736
|
}
|
|
2648
2737
|
|
|
2649
|
-
vbool8_t min_match_b8 = __riscv_vmseq_vx_u8m1_b8(min_u8m1, min_comparable,
|
|
2650
|
-
vuint64m8_t sentinel_u64m8 = __riscv_vmv_v_x_u64m8(NK_U64_MAX,
|
|
2651
|
-
vuint64m8_t min_cands_u64m8 = __riscv_vmerge_vvm_u64m8(sentinel_u64m8, min_indices_u64m8, min_match_b8,
|
|
2738
|
+
vbool8_t min_match_b8 = __riscv_vmseq_vx_u8m1_b8(min_u8m1, min_comparable, max_vector_length);
|
|
2739
|
+
vuint64m8_t sentinel_u64m8 = __riscv_vmv_v_x_u64m8(NK_U64_MAX, max_vector_length);
|
|
2740
|
+
vuint64m8_t min_cands_u64m8 = __riscv_vmerge_vvm_u64m8(sentinel_u64m8, min_indices_u64m8, min_match_b8,
|
|
2741
|
+
max_vector_length);
|
|
2652
2742
|
vuint64m1_t init_umax_u64m1 = __riscv_vmv_v_x_u64m1(NK_U64_MAX, 1);
|
|
2653
2743
|
*min_index_ptr = (nk_size_t)__riscv_vmv_x_s_u64m1_u64(
|
|
2654
|
-
__riscv_vredminu_vs_u64m8_u64m1(min_cands_u64m8, init_umax_u64m1,
|
|
2744
|
+
__riscv_vredminu_vs_u64m8_u64m1(min_cands_u64m8, init_umax_u64m1, max_vector_length));
|
|
2655
2745
|
|
|
2656
2746
|
vuint8m1_t min_vec_u8m1 = __riscv_vmv_v_x_u8m1(min_comparable, 1);
|
|
2657
2747
|
vuint8m1_t min_raw_u8m1 = nk_comparable_to_fp8m1_rvv_(min_vec_u8m1, 1);
|
|
@@ -2659,11 +2749,13 @@ NK_INTERNAL void nk_reduce_minmax_e4m3_rvv_contiguous_( //
|
|
|
2659
2749
|
|
|
2660
2750
|
// Similar for max
|
|
2661
2751
|
vuint8m1_t init_min_u8m1 = __riscv_vmv_v_x_u8m1(0x00, 1);
|
|
2662
|
-
nk_u8_t max_comparable = __riscv_vmv_x_s_u8m1_u8(
|
|
2663
|
-
|
|
2664
|
-
|
|
2752
|
+
nk_u8_t max_comparable = __riscv_vmv_x_s_u8m1_u8(
|
|
2753
|
+
__riscv_vredmaxu_vs_u8m1_u8m1(max_u8m1, init_min_u8m1, max_vector_length));
|
|
2754
|
+
vbool8_t max_match_b8 = __riscv_vmseq_vx_u8m1_b8(max_u8m1, max_comparable, max_vector_length);
|
|
2755
|
+
vuint64m8_t max_cands_u64m8 = __riscv_vmerge_vvm_u64m8(sentinel_u64m8, max_indices_u64m8, max_match_b8,
|
|
2756
|
+
max_vector_length);
|
|
2665
2757
|
*max_index_ptr = (nk_size_t)__riscv_vmv_x_s_u64m1_u64(
|
|
2666
|
-
__riscv_vredminu_vs_u64m8_u64m1(max_cands_u64m8, init_umax_u64m1,
|
|
2758
|
+
__riscv_vredminu_vs_u64m8_u64m1(max_cands_u64m8, init_umax_u64m1, max_vector_length));
|
|
2667
2759
|
|
|
2668
2760
|
vuint8m1_t max_vec_u8m1 = __riscv_vmv_v_x_u8m1(max_comparable, 1);
|
|
2669
2761
|
vuint8m1_t max_raw_u8m1 = nk_comparable_to_fp8m1_rvv_(max_vec_u8m1, 1);
|
|
@@ -2674,18 +2766,18 @@ NK_INTERNAL void nk_reduce_minmax_e4m3_rvv_strided_( //
|
|
|
2674
2766
|
nk_e4m3_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
|
|
2675
2767
|
nk_e4m3_t *min_value_ptr, nk_size_t *min_index_ptr, //
|
|
2676
2768
|
nk_e4m3_t *max_value_ptr, nk_size_t *max_index_ptr) {
|
|
2677
|
-
nk_size_t
|
|
2678
|
-
vuint8m1_t min_u8m1 = __riscv_vmv_v_x_u8m1(0xFF,
|
|
2679
|
-
vuint8m1_t max_u8m1 = __riscv_vmv_v_x_u8m1(0x00,
|
|
2680
|
-
vuint64m8_t min_indices_u64m8 = __riscv_vmv_v_x_u64m8(0,
|
|
2681
|
-
vuint64m8_t max_indices_u64m8 = __riscv_vmv_v_x_u64m8(0,
|
|
2769
|
+
nk_size_t max_vector_length = __riscv_vsetvlmax_e8m1();
|
|
2770
|
+
vuint8m1_t min_u8m1 = __riscv_vmv_v_x_u8m1(0xFF, max_vector_length);
|
|
2771
|
+
vuint8m1_t max_u8m1 = __riscv_vmv_v_x_u8m1(0x00, max_vector_length);
|
|
2772
|
+
vuint64m8_t min_indices_u64m8 = __riscv_vmv_v_x_u64m8(0, max_vector_length);
|
|
2773
|
+
vuint64m8_t max_indices_u64m8 = __riscv_vmv_v_x_u64m8(0, max_vector_length);
|
|
2682
2774
|
unsigned char const *ptr = (unsigned char const *)data_ptr;
|
|
2683
2775
|
|
|
2684
2776
|
nk_size_t offset = 0;
|
|
2685
2777
|
for (nk_size_t vector_length; count > 0;
|
|
2686
2778
|
count -= vector_length, offset += vector_length, ptr += vector_length * stride_bytes) {
|
|
2687
2779
|
vector_length = __riscv_vsetvl_e8m1(count);
|
|
2688
|
-
vuint8m1_t raw_u8m1 = __riscv_vlse8_v_u8m1((
|
|
2780
|
+
vuint8m1_t raw_u8m1 = __riscv_vlse8_v_u8m1((nk_u8_t const *)ptr, (nk_ssize_t)stride_bytes, vector_length);
|
|
2689
2781
|
|
|
2690
2782
|
vuint8m1_t comparable_u8m1 = nk_fp8m1_to_comparable_u8m1_rvv_(raw_u8m1, vector_length);
|
|
2691
2783
|
vuint64m8_t pos_u64m8 = __riscv_vadd_vx_u64m8(__riscv_vid_v_u64m8(vector_length), (nk_u64_t)offset,
|
|
@@ -2711,7 +2803,8 @@ NK_INTERNAL void nk_reduce_minmax_e4m3_rvv_strided_( //
|
|
|
2711
2803
|
|
|
2712
2804
|
// Horizontal reduction (same as contiguous)
|
|
2713
2805
|
vuint8m1_t init_max_u8m1 = __riscv_vmv_v_x_u8m1(0xFF, 1);
|
|
2714
|
-
nk_u8_t min_comparable = __riscv_vmv_x_s_u8m1_u8(
|
|
2806
|
+
nk_u8_t min_comparable = __riscv_vmv_x_s_u8m1_u8(
|
|
2807
|
+
__riscv_vredminu_vs_u8m1_u8m1(min_u8m1, init_max_u8m1, max_vector_length));
|
|
2715
2808
|
|
|
2716
2809
|
// All-NaN case
|
|
2717
2810
|
if (min_comparable == 0xFF) {
|
|
@@ -2720,23 +2813,26 @@ NK_INTERNAL void nk_reduce_minmax_e4m3_rvv_strided_( //
|
|
|
2720
2813
|
return;
|
|
2721
2814
|
}
|
|
2722
2815
|
|
|
2723
|
-
vbool8_t min_match_b8 = __riscv_vmseq_vx_u8m1_b8(min_u8m1, min_comparable,
|
|
2724
|
-
vuint64m8_t sentinel_u64m8 = __riscv_vmv_v_x_u64m8(NK_U64_MAX,
|
|
2725
|
-
vuint64m8_t min_cands_u64m8 = __riscv_vmerge_vvm_u64m8(sentinel_u64m8, min_indices_u64m8, min_match_b8,
|
|
2816
|
+
vbool8_t min_match_b8 = __riscv_vmseq_vx_u8m1_b8(min_u8m1, min_comparable, max_vector_length);
|
|
2817
|
+
vuint64m8_t sentinel_u64m8 = __riscv_vmv_v_x_u64m8(NK_U64_MAX, max_vector_length);
|
|
2818
|
+
vuint64m8_t min_cands_u64m8 = __riscv_vmerge_vvm_u64m8(sentinel_u64m8, min_indices_u64m8, min_match_b8,
|
|
2819
|
+
max_vector_length);
|
|
2726
2820
|
vuint64m1_t init_umax_u64m1 = __riscv_vmv_v_x_u64m1(NK_U64_MAX, 1);
|
|
2727
2821
|
*min_index_ptr = (nk_size_t)__riscv_vmv_x_s_u64m1_u64(
|
|
2728
|
-
__riscv_vredminu_vs_u64m8_u64m1(min_cands_u64m8, init_umax_u64m1,
|
|
2822
|
+
__riscv_vredminu_vs_u64m8_u64m1(min_cands_u64m8, init_umax_u64m1, max_vector_length));
|
|
2729
2823
|
|
|
2730
2824
|
vuint8m1_t min_vec_u8m1 = __riscv_vmv_v_x_u8m1(min_comparable, 1);
|
|
2731
2825
|
vuint8m1_t min_raw_u8m1 = nk_comparable_to_fp8m1_rvv_(min_vec_u8m1, 1);
|
|
2732
2826
|
*min_value_ptr = (nk_e4m3_t)__riscv_vmv_x_s_u8m1_u8(min_raw_u8m1);
|
|
2733
2827
|
|
|
2734
2828
|
vuint8m1_t init_min_u8m1 = __riscv_vmv_v_x_u8m1(0x00, 1);
|
|
2735
|
-
nk_u8_t max_comparable = __riscv_vmv_x_s_u8m1_u8(
|
|
2736
|
-
|
|
2737
|
-
|
|
2829
|
+
nk_u8_t max_comparable = __riscv_vmv_x_s_u8m1_u8(
|
|
2830
|
+
__riscv_vredmaxu_vs_u8m1_u8m1(max_u8m1, init_min_u8m1, max_vector_length));
|
|
2831
|
+
vbool8_t max_match_b8 = __riscv_vmseq_vx_u8m1_b8(max_u8m1, max_comparable, max_vector_length);
|
|
2832
|
+
vuint64m8_t max_cands_u64m8 = __riscv_vmerge_vvm_u64m8(sentinel_u64m8, max_indices_u64m8, max_match_b8,
|
|
2833
|
+
max_vector_length);
|
|
2738
2834
|
*max_index_ptr = (nk_size_t)__riscv_vmv_x_s_u64m1_u64(
|
|
2739
|
-
__riscv_vredminu_vs_u64m8_u64m1(max_cands_u64m8, init_umax_u64m1,
|
|
2835
|
+
__riscv_vredminu_vs_u64m8_u64m1(max_cands_u64m8, init_umax_u64m1, max_vector_length));
|
|
2740
2836
|
|
|
2741
2837
|
vuint8m1_t max_vec_u8m1 = __riscv_vmv_v_x_u8m1(max_comparable, 1);
|
|
2742
2838
|
vuint8m1_t max_raw_u8m1 = nk_comparable_to_fp8m1_rvv_(max_vec_u8m1, 1);
|
|
@@ -2767,13 +2863,13 @@ NK_PUBLIC void nk_reduce_minmax_e4m3_rvv( //
|
|
|
2767
2863
|
NK_INTERNAL void nk_reduce_moments_e5m2_rvv_contiguous_( //
|
|
2768
2864
|
nk_e5m2_t const *data_ptr, nk_size_t count, //
|
|
2769
2865
|
nk_f32_t *sum_ptr, nk_f32_t *sumsq_ptr) {
|
|
2770
|
-
nk_size_t
|
|
2771
|
-
vfloat32m4_t sum_f32m4 = __riscv_vfmv_v_f_f32m4(0.0f,
|
|
2772
|
-
vfloat32m4_t sumsq_f32m4 = __riscv_vfmv_v_f_f32m4(0.0f,
|
|
2866
|
+
nk_size_t max_vector_length = __riscv_vsetvlmax_e32m4();
|
|
2867
|
+
vfloat32m4_t sum_f32m4 = __riscv_vfmv_v_f_f32m4(0.0f, max_vector_length);
|
|
2868
|
+
vfloat32m4_t sumsq_f32m4 = __riscv_vfmv_v_f_f32m4(0.0f, max_vector_length);
|
|
2773
2869
|
|
|
2774
2870
|
for (nk_size_t vector_length; count > 0; count -= vector_length, data_ptr += vector_length) {
|
|
2775
2871
|
vector_length = __riscv_vsetvl_e8m1(count);
|
|
2776
|
-
vuint8m1_t data_u8m1 = __riscv_vle8_v_u8m1((
|
|
2872
|
+
vuint8m1_t data_u8m1 = __riscv_vle8_v_u8m1((nk_u8_t const *)data_ptr, vector_length);
|
|
2777
2873
|
|
|
2778
2874
|
// Convert e5m2 → f32 (m1 → m4)
|
|
2779
2875
|
vfloat32m4_t data_f32m4 = nk_e5m2m1_to_f32m4_rvv_(data_u8m1, vector_length);
|
|
@@ -2785,21 +2881,22 @@ NK_INTERNAL void nk_reduce_moments_e5m2_rvv_contiguous_( //
|
|
|
2785
2881
|
|
|
2786
2882
|
// Horizontal reduction
|
|
2787
2883
|
vfloat32m1_t zero_f32m1 = __riscv_vfmv_v_f_f32m1(0.0f, 1);
|
|
2788
|
-
*sum_ptr = __riscv_vfmv_f_s_f32m1_f32(__riscv_vfredusum_vs_f32m4_f32m1(sum_f32m4, zero_f32m1,
|
|
2789
|
-
*sumsq_ptr = __riscv_vfmv_f_s_f32m1_f32(
|
|
2884
|
+
*sum_ptr = __riscv_vfmv_f_s_f32m1_f32(__riscv_vfredusum_vs_f32m4_f32m1(sum_f32m4, zero_f32m1, max_vector_length)),
|
|
2885
|
+
*sumsq_ptr = __riscv_vfmv_f_s_f32m1_f32(
|
|
2886
|
+
__riscv_vfredusum_vs_f32m4_f32m1(sumsq_f32m4, zero_f32m1, max_vector_length));
|
|
2790
2887
|
}
|
|
2791
2888
|
|
|
2792
2889
|
NK_INTERNAL void nk_reduce_moments_e5m2_rvv_strided_( //
|
|
2793
2890
|
nk_e5m2_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
|
|
2794
2891
|
nk_f32_t *sum_ptr, nk_f32_t *sumsq_ptr) {
|
|
2795
|
-
nk_size_t
|
|
2796
|
-
vfloat32m4_t sum_f32m4 = __riscv_vfmv_v_f_f32m4(0.0f,
|
|
2797
|
-
vfloat32m4_t sumsq_f32m4 = __riscv_vfmv_v_f_f32m4(0.0f,
|
|
2892
|
+
nk_size_t max_vector_length = __riscv_vsetvlmax_e32m4();
|
|
2893
|
+
vfloat32m4_t sum_f32m4 = __riscv_vfmv_v_f_f32m4(0.0f, max_vector_length);
|
|
2894
|
+
vfloat32m4_t sumsq_f32m4 = __riscv_vfmv_v_f_f32m4(0.0f, max_vector_length);
|
|
2798
2895
|
unsigned char const *ptr = (unsigned char const *)data_ptr;
|
|
2799
2896
|
|
|
2800
2897
|
for (nk_size_t vector_length; count > 0; count -= vector_length, ptr += vector_length * stride_bytes) {
|
|
2801
2898
|
vector_length = __riscv_vsetvl_e8m1(count);
|
|
2802
|
-
vuint8m1_t data_u8m1 = __riscv_vlse8_v_u8m1((
|
|
2899
|
+
vuint8m1_t data_u8m1 = __riscv_vlse8_v_u8m1((nk_u8_t const *)ptr, (nk_ssize_t)stride_bytes, vector_length);
|
|
2803
2900
|
|
|
2804
2901
|
// Convert e5m2 → f32 (m1 → m4)
|
|
2805
2902
|
vfloat32m4_t data_f32m4 = nk_e5m2m1_to_f32m4_rvv_(data_u8m1, vector_length);
|
|
@@ -2811,8 +2908,9 @@ NK_INTERNAL void nk_reduce_moments_e5m2_rvv_strided_( //
|
|
|
2811
2908
|
|
|
2812
2909
|
// Horizontal reduction
|
|
2813
2910
|
vfloat32m1_t zero_f32m1 = __riscv_vfmv_v_f_f32m1(0.0f, 1);
|
|
2814
|
-
*sum_ptr = __riscv_vfmv_f_s_f32m1_f32(__riscv_vfredusum_vs_f32m4_f32m1(sum_f32m4, zero_f32m1,
|
|
2815
|
-
*sumsq_ptr = __riscv_vfmv_f_s_f32m1_f32(
|
|
2911
|
+
*sum_ptr = __riscv_vfmv_f_s_f32m1_f32(__riscv_vfredusum_vs_f32m4_f32m1(sum_f32m4, zero_f32m1, max_vector_length)),
|
|
2912
|
+
*sumsq_ptr = __riscv_vfmv_f_s_f32m1_f32(
|
|
2913
|
+
__riscv_vfredusum_vs_f32m4_f32m1(sumsq_f32m4, zero_f32m1, max_vector_length));
|
|
2816
2914
|
}
|
|
2817
2915
|
|
|
2818
2916
|
NK_PUBLIC void nk_reduce_moments_e5m2_rvv( //
|
|
@@ -2831,17 +2929,17 @@ NK_INTERNAL void nk_reduce_minmax_e5m2_rvv_contiguous_( //
|
|
|
2831
2929
|
nk_e5m2_t const *data_ptr, nk_size_t count, //
|
|
2832
2930
|
nk_e5m2_t *min_value_ptr, nk_size_t *min_index_ptr, //
|
|
2833
2931
|
nk_e5m2_t *max_value_ptr, nk_size_t *max_index_ptr) {
|
|
2834
|
-
nk_size_t
|
|
2835
|
-
vuint8m1_t min_u8m1 = __riscv_vmv_v_x_u8m1(0xFF,
|
|
2836
|
-
vuint8m1_t max_u8m1 = __riscv_vmv_v_x_u8m1(0x00,
|
|
2837
|
-
vuint64m8_t min_indices_u64m8 = __riscv_vmv_v_x_u64m8(0,
|
|
2838
|
-
vuint64m8_t max_indices_u64m8 = __riscv_vmv_v_x_u64m8(0,
|
|
2932
|
+
nk_size_t max_vector_length = __riscv_vsetvlmax_e8m1();
|
|
2933
|
+
vuint8m1_t min_u8m1 = __riscv_vmv_v_x_u8m1(0xFF, max_vector_length);
|
|
2934
|
+
vuint8m1_t max_u8m1 = __riscv_vmv_v_x_u8m1(0x00, max_vector_length);
|
|
2935
|
+
vuint64m8_t min_indices_u64m8 = __riscv_vmv_v_x_u64m8(0, max_vector_length);
|
|
2936
|
+
vuint64m8_t max_indices_u64m8 = __riscv_vmv_v_x_u64m8(0, max_vector_length);
|
|
2839
2937
|
|
|
2840
2938
|
nk_size_t offset = 0;
|
|
2841
2939
|
for (nk_size_t vector_length; count > 0;
|
|
2842
2940
|
count -= vector_length, offset += vector_length, data_ptr += vector_length) {
|
|
2843
2941
|
vector_length = __riscv_vsetvl_e8m1(count);
|
|
2844
|
-
vuint8m1_t raw_u8m1 = __riscv_vle8_v_u8m1((
|
|
2942
|
+
vuint8m1_t raw_u8m1 = __riscv_vle8_v_u8m1((nk_u8_t const *)data_ptr, vector_length);
|
|
2845
2943
|
|
|
2846
2944
|
vuint8m1_t comparable_u8m1 = nk_fp8m1_to_comparable_u8m1_rvv_(raw_u8m1, vector_length);
|
|
2847
2945
|
vuint64m8_t pos_u64m8 = __riscv_vadd_vx_u64m8(__riscv_vid_v_u64m8(vector_length), (nk_u64_t)offset,
|
|
@@ -2867,7 +2965,8 @@ NK_INTERNAL void nk_reduce_minmax_e5m2_rvv_contiguous_( //
|
|
|
2867
2965
|
|
|
2868
2966
|
// Horizontal reduction + convert back
|
|
2869
2967
|
vuint8m1_t init_max_u8m1 = __riscv_vmv_v_x_u8m1(0xFF, 1);
|
|
2870
|
-
nk_u8_t min_comparable = __riscv_vmv_x_s_u8m1_u8(
|
|
2968
|
+
nk_u8_t min_comparable = __riscv_vmv_x_s_u8m1_u8(
|
|
2969
|
+
__riscv_vredminu_vs_u8m1_u8m1(min_u8m1, init_max_u8m1, max_vector_length));
|
|
2871
2970
|
|
|
2872
2971
|
// All-NaN case
|
|
2873
2972
|
if (min_comparable == 0xFF) {
|
|
@@ -2876,23 +2975,26 @@ NK_INTERNAL void nk_reduce_minmax_e5m2_rvv_contiguous_( //
|
|
|
2876
2975
|
return;
|
|
2877
2976
|
}
|
|
2878
2977
|
|
|
2879
|
-
vbool8_t min_match_b8 = __riscv_vmseq_vx_u8m1_b8(min_u8m1, min_comparable,
|
|
2880
|
-
vuint64m8_t sentinel_u64m8 = __riscv_vmv_v_x_u64m8(NK_U64_MAX,
|
|
2881
|
-
vuint64m8_t min_cands_u64m8 = __riscv_vmerge_vvm_u64m8(sentinel_u64m8, min_indices_u64m8, min_match_b8,
|
|
2978
|
+
vbool8_t min_match_b8 = __riscv_vmseq_vx_u8m1_b8(min_u8m1, min_comparable, max_vector_length);
|
|
2979
|
+
vuint64m8_t sentinel_u64m8 = __riscv_vmv_v_x_u64m8(NK_U64_MAX, max_vector_length);
|
|
2980
|
+
vuint64m8_t min_cands_u64m8 = __riscv_vmerge_vvm_u64m8(sentinel_u64m8, min_indices_u64m8, min_match_b8,
|
|
2981
|
+
max_vector_length);
|
|
2882
2982
|
vuint64m1_t init_umax_u64m1 = __riscv_vmv_v_x_u64m1(NK_U64_MAX, 1);
|
|
2883
2983
|
*min_index_ptr = (nk_size_t)__riscv_vmv_x_s_u64m1_u64(
|
|
2884
|
-
__riscv_vredminu_vs_u64m8_u64m1(min_cands_u64m8, init_umax_u64m1,
|
|
2984
|
+
__riscv_vredminu_vs_u64m8_u64m1(min_cands_u64m8, init_umax_u64m1, max_vector_length));
|
|
2885
2985
|
|
|
2886
2986
|
vuint8m1_t min_vec_u8m1 = __riscv_vmv_v_x_u8m1(min_comparable, 1);
|
|
2887
2987
|
vuint8m1_t min_raw_u8m1 = nk_comparable_to_fp8m1_rvv_(min_vec_u8m1, 1);
|
|
2888
2988
|
*min_value_ptr = (nk_e5m2_t)__riscv_vmv_x_s_u8m1_u8(min_raw_u8m1);
|
|
2889
2989
|
|
|
2890
2990
|
vuint8m1_t init_min_u8m1 = __riscv_vmv_v_x_u8m1(0x00, 1);
|
|
2891
|
-
nk_u8_t max_comparable = __riscv_vmv_x_s_u8m1_u8(
|
|
2892
|
-
|
|
2893
|
-
|
|
2991
|
+
nk_u8_t max_comparable = __riscv_vmv_x_s_u8m1_u8(
|
|
2992
|
+
__riscv_vredmaxu_vs_u8m1_u8m1(max_u8m1, init_min_u8m1, max_vector_length));
|
|
2993
|
+
vbool8_t max_match_b8 = __riscv_vmseq_vx_u8m1_b8(max_u8m1, max_comparable, max_vector_length);
|
|
2994
|
+
vuint64m8_t max_cands_u64m8 = __riscv_vmerge_vvm_u64m8(sentinel_u64m8, max_indices_u64m8, max_match_b8,
|
|
2995
|
+
max_vector_length);
|
|
2894
2996
|
*max_index_ptr = (nk_size_t)__riscv_vmv_x_s_u64m1_u64(
|
|
2895
|
-
__riscv_vredminu_vs_u64m8_u64m1(max_cands_u64m8, init_umax_u64m1,
|
|
2997
|
+
__riscv_vredminu_vs_u64m8_u64m1(max_cands_u64m8, init_umax_u64m1, max_vector_length));
|
|
2896
2998
|
|
|
2897
2999
|
vuint8m1_t max_vec_u8m1 = __riscv_vmv_v_x_u8m1(max_comparable, 1);
|
|
2898
3000
|
vuint8m1_t max_raw_u8m1 = nk_comparable_to_fp8m1_rvv_(max_vec_u8m1, 1);
|
|
@@ -2903,18 +3005,18 @@ NK_INTERNAL void nk_reduce_minmax_e5m2_rvv_strided_( //
|
|
|
2903
3005
|
nk_e5m2_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
|
|
2904
3006
|
nk_e5m2_t *min_value_ptr, nk_size_t *min_index_ptr, //
|
|
2905
3007
|
nk_e5m2_t *max_value_ptr, nk_size_t *max_index_ptr) {
|
|
2906
|
-
nk_size_t
|
|
2907
|
-
vuint8m1_t min_u8m1 = __riscv_vmv_v_x_u8m1(0xFF,
|
|
2908
|
-
vuint8m1_t max_u8m1 = __riscv_vmv_v_x_u8m1(0x00,
|
|
2909
|
-
vuint64m8_t min_indices_u64m8 = __riscv_vmv_v_x_u64m8(0,
|
|
2910
|
-
vuint64m8_t max_indices_u64m8 = __riscv_vmv_v_x_u64m8(0,
|
|
3008
|
+
nk_size_t max_vector_length = __riscv_vsetvlmax_e8m1();
|
|
3009
|
+
vuint8m1_t min_u8m1 = __riscv_vmv_v_x_u8m1(0xFF, max_vector_length);
|
|
3010
|
+
vuint8m1_t max_u8m1 = __riscv_vmv_v_x_u8m1(0x00, max_vector_length);
|
|
3011
|
+
vuint64m8_t min_indices_u64m8 = __riscv_vmv_v_x_u64m8(0, max_vector_length);
|
|
3012
|
+
vuint64m8_t max_indices_u64m8 = __riscv_vmv_v_x_u64m8(0, max_vector_length);
|
|
2911
3013
|
unsigned char const *ptr = (unsigned char const *)data_ptr;
|
|
2912
3014
|
|
|
2913
3015
|
nk_size_t offset = 0;
|
|
2914
3016
|
for (nk_size_t vector_length; count > 0;
|
|
2915
3017
|
count -= vector_length, offset += vector_length, ptr += vector_length * stride_bytes) {
|
|
2916
3018
|
vector_length = __riscv_vsetvl_e8m1(count);
|
|
2917
|
-
vuint8m1_t raw_u8m1 = __riscv_vlse8_v_u8m1((
|
|
3019
|
+
vuint8m1_t raw_u8m1 = __riscv_vlse8_v_u8m1((nk_u8_t const *)ptr, (nk_ssize_t)stride_bytes, vector_length);
|
|
2918
3020
|
|
|
2919
3021
|
vuint8m1_t comparable_u8m1 = nk_fp8m1_to_comparable_u8m1_rvv_(raw_u8m1, vector_length);
|
|
2920
3022
|
vuint64m8_t pos_u64m8 = __riscv_vadd_vx_u64m8(__riscv_vid_v_u64m8(vector_length), (nk_u64_t)offset,
|
|
@@ -2940,7 +3042,8 @@ NK_INTERNAL void nk_reduce_minmax_e5m2_rvv_strided_( //
|
|
|
2940
3042
|
|
|
2941
3043
|
// Horizontal reduction (same as contiguous)
|
|
2942
3044
|
vuint8m1_t init_max_u8m1 = __riscv_vmv_v_x_u8m1(0xFF, 1);
|
|
2943
|
-
nk_u8_t min_comparable = __riscv_vmv_x_s_u8m1_u8(
|
|
3045
|
+
nk_u8_t min_comparable = __riscv_vmv_x_s_u8m1_u8(
|
|
3046
|
+
__riscv_vredminu_vs_u8m1_u8m1(min_u8m1, init_max_u8m1, max_vector_length));
|
|
2944
3047
|
|
|
2945
3048
|
// All-NaN case
|
|
2946
3049
|
if (min_comparable == 0xFF) {
|
|
@@ -2949,23 +3052,26 @@ NK_INTERNAL void nk_reduce_minmax_e5m2_rvv_strided_( //
|
|
|
2949
3052
|
return;
|
|
2950
3053
|
}
|
|
2951
3054
|
|
|
2952
|
-
vbool8_t min_match_b8 = __riscv_vmseq_vx_u8m1_b8(min_u8m1, min_comparable,
|
|
2953
|
-
vuint64m8_t sentinel_u64m8 = __riscv_vmv_v_x_u64m8(NK_U64_MAX,
|
|
2954
|
-
vuint64m8_t min_cands_u64m8 = __riscv_vmerge_vvm_u64m8(sentinel_u64m8, min_indices_u64m8, min_match_b8,
|
|
3055
|
+
vbool8_t min_match_b8 = __riscv_vmseq_vx_u8m1_b8(min_u8m1, min_comparable, max_vector_length);
|
|
3056
|
+
vuint64m8_t sentinel_u64m8 = __riscv_vmv_v_x_u64m8(NK_U64_MAX, max_vector_length);
|
|
3057
|
+
vuint64m8_t min_cands_u64m8 = __riscv_vmerge_vvm_u64m8(sentinel_u64m8, min_indices_u64m8, min_match_b8,
|
|
3058
|
+
max_vector_length);
|
|
2955
3059
|
vuint64m1_t init_umax_u64m1 = __riscv_vmv_v_x_u64m1(NK_U64_MAX, 1);
|
|
2956
3060
|
*min_index_ptr = (nk_size_t)__riscv_vmv_x_s_u64m1_u64(
|
|
2957
|
-
__riscv_vredminu_vs_u64m8_u64m1(min_cands_u64m8, init_umax_u64m1,
|
|
3061
|
+
__riscv_vredminu_vs_u64m8_u64m1(min_cands_u64m8, init_umax_u64m1, max_vector_length));
|
|
2958
3062
|
|
|
2959
3063
|
vuint8m1_t min_vec_u8m1 = __riscv_vmv_v_x_u8m1(min_comparable, 1);
|
|
2960
3064
|
vuint8m1_t min_raw_u8m1 = nk_comparable_to_fp8m1_rvv_(min_vec_u8m1, 1);
|
|
2961
3065
|
*min_value_ptr = (nk_e5m2_t)__riscv_vmv_x_s_u8m1_u8(min_raw_u8m1);
|
|
2962
3066
|
|
|
2963
3067
|
vuint8m1_t init_min_u8m1 = __riscv_vmv_v_x_u8m1(0x00, 1);
|
|
2964
|
-
nk_u8_t max_comparable = __riscv_vmv_x_s_u8m1_u8(
|
|
2965
|
-
|
|
2966
|
-
|
|
3068
|
+
nk_u8_t max_comparable = __riscv_vmv_x_s_u8m1_u8(
|
|
3069
|
+
__riscv_vredmaxu_vs_u8m1_u8m1(max_u8m1, init_min_u8m1, max_vector_length));
|
|
3070
|
+
vbool8_t max_match_b8 = __riscv_vmseq_vx_u8m1_b8(max_u8m1, max_comparable, max_vector_length);
|
|
3071
|
+
vuint64m8_t max_cands_u64m8 = __riscv_vmerge_vvm_u64m8(sentinel_u64m8, max_indices_u64m8, max_match_b8,
|
|
3072
|
+
max_vector_length);
|
|
2967
3073
|
*max_index_ptr = (nk_size_t)__riscv_vmv_x_s_u64m1_u64(
|
|
2968
|
-
__riscv_vredminu_vs_u64m8_u64m1(max_cands_u64m8, init_umax_u64m1,
|
|
3074
|
+
__riscv_vredminu_vs_u64m8_u64m1(max_cands_u64m8, init_umax_u64m1, max_vector_length));
|
|
2969
3075
|
|
|
2970
3076
|
vuint8m1_t max_vec_u8m1 = __riscv_vmv_v_x_u8m1(max_comparable, 1);
|
|
2971
3077
|
vuint8m1_t max_raw_u8m1 = nk_comparable_to_fp8m1_rvv_(max_vec_u8m1, 1);
|
|
@@ -2996,13 +3102,13 @@ NK_PUBLIC void nk_reduce_minmax_e5m2_rvv( //
|
|
|
2996
3102
|
NK_INTERNAL void nk_reduce_moments_e2m3_rvv_contiguous_( //
|
|
2997
3103
|
nk_e2m3_t const *data_ptr, nk_size_t count, //
|
|
2998
3104
|
nk_f32_t *sum_ptr, nk_f32_t *sumsq_ptr) {
|
|
2999
|
-
nk_size_t
|
|
3000
|
-
vfloat32m4_t sum_f32m4 = __riscv_vfmv_v_f_f32m4(0.0f,
|
|
3001
|
-
vfloat32m4_t sumsq_f32m4 = __riscv_vfmv_v_f_f32m4(0.0f,
|
|
3105
|
+
nk_size_t max_vector_length = __riscv_vsetvlmax_e32m4();
|
|
3106
|
+
vfloat32m4_t sum_f32m4 = __riscv_vfmv_v_f_f32m4(0.0f, max_vector_length);
|
|
3107
|
+
vfloat32m4_t sumsq_f32m4 = __riscv_vfmv_v_f_f32m4(0.0f, max_vector_length);
|
|
3002
3108
|
|
|
3003
3109
|
for (nk_size_t vector_length; count > 0; count -= vector_length, data_ptr += vector_length) {
|
|
3004
3110
|
vector_length = __riscv_vsetvl_e8m1(count);
|
|
3005
|
-
vuint8m1_t data_u8m1 = __riscv_vle8_v_u8m1((
|
|
3111
|
+
vuint8m1_t data_u8m1 = __riscv_vle8_v_u8m1((nk_u8_t const *)data_ptr, vector_length);
|
|
3006
3112
|
|
|
3007
3113
|
// Convert e2m3 → f32 (m1 → m4)
|
|
3008
3114
|
vfloat32m4_t data_f32m4 = nk_e2m3m1_to_f32m4_rvv_(data_u8m1, vector_length);
|
|
@@ -3014,21 +3120,22 @@ NK_INTERNAL void nk_reduce_moments_e2m3_rvv_contiguous_( //
|
|
|
3014
3120
|
|
|
3015
3121
|
// Horizontal reduction
|
|
3016
3122
|
vfloat32m1_t zero_f32m1 = __riscv_vfmv_v_f_f32m1(0.0f, 1);
|
|
3017
|
-
*sum_ptr = __riscv_vfmv_f_s_f32m1_f32(__riscv_vfredusum_vs_f32m4_f32m1(sum_f32m4, zero_f32m1,
|
|
3018
|
-
*sumsq_ptr = __riscv_vfmv_f_s_f32m1_f32(
|
|
3123
|
+
*sum_ptr = __riscv_vfmv_f_s_f32m1_f32(__riscv_vfredusum_vs_f32m4_f32m1(sum_f32m4, zero_f32m1, max_vector_length)),
|
|
3124
|
+
*sumsq_ptr = __riscv_vfmv_f_s_f32m1_f32(
|
|
3125
|
+
__riscv_vfredusum_vs_f32m4_f32m1(sumsq_f32m4, zero_f32m1, max_vector_length));
|
|
3019
3126
|
}
|
|
3020
3127
|
|
|
3021
3128
|
NK_INTERNAL void nk_reduce_moments_e2m3_rvv_strided_( //
|
|
3022
3129
|
nk_e2m3_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
|
|
3023
3130
|
nk_f32_t *sum_ptr, nk_f32_t *sumsq_ptr) {
|
|
3024
|
-
nk_size_t
|
|
3025
|
-
vfloat32m4_t sum_f32m4 = __riscv_vfmv_v_f_f32m4(0.0f,
|
|
3026
|
-
vfloat32m4_t sumsq_f32m4 = __riscv_vfmv_v_f_f32m4(0.0f,
|
|
3131
|
+
nk_size_t max_vector_length = __riscv_vsetvlmax_e32m4();
|
|
3132
|
+
vfloat32m4_t sum_f32m4 = __riscv_vfmv_v_f_f32m4(0.0f, max_vector_length);
|
|
3133
|
+
vfloat32m4_t sumsq_f32m4 = __riscv_vfmv_v_f_f32m4(0.0f, max_vector_length);
|
|
3027
3134
|
unsigned char const *ptr = (unsigned char const *)data_ptr;
|
|
3028
3135
|
|
|
3029
3136
|
for (nk_size_t vector_length; count > 0; count -= vector_length, ptr += vector_length * stride_bytes) {
|
|
3030
3137
|
vector_length = __riscv_vsetvl_e8m1(count);
|
|
3031
|
-
vuint8m1_t data_u8m1 = __riscv_vlse8_v_u8m1((
|
|
3138
|
+
vuint8m1_t data_u8m1 = __riscv_vlse8_v_u8m1((nk_u8_t const *)ptr, (nk_ssize_t)stride_bytes, vector_length);
|
|
3032
3139
|
|
|
3033
3140
|
// Convert e2m3 → f32 (m1 → m4)
|
|
3034
3141
|
vfloat32m4_t data_f32m4 = nk_e2m3m1_to_f32m4_rvv_(data_u8m1, vector_length);
|
|
@@ -3040,8 +3147,9 @@ NK_INTERNAL void nk_reduce_moments_e2m3_rvv_strided_( //
|
|
|
3040
3147
|
|
|
3041
3148
|
// Horizontal reduction
|
|
3042
3149
|
vfloat32m1_t zero_f32m1 = __riscv_vfmv_v_f_f32m1(0.0f, 1);
|
|
3043
|
-
*sum_ptr = __riscv_vfmv_f_s_f32m1_f32(__riscv_vfredusum_vs_f32m4_f32m1(sum_f32m4, zero_f32m1,
|
|
3044
|
-
*sumsq_ptr = __riscv_vfmv_f_s_f32m1_f32(
|
|
3150
|
+
*sum_ptr = __riscv_vfmv_f_s_f32m1_f32(__riscv_vfredusum_vs_f32m4_f32m1(sum_f32m4, zero_f32m1, max_vector_length)),
|
|
3151
|
+
*sumsq_ptr = __riscv_vfmv_f_s_f32m1_f32(
|
|
3152
|
+
__riscv_vfredusum_vs_f32m4_f32m1(sumsq_f32m4, zero_f32m1, max_vector_length));
|
|
3045
3153
|
}
|
|
3046
3154
|
|
|
3047
3155
|
NK_PUBLIC void nk_reduce_moments_e2m3_rvv( //
|
|
@@ -3060,17 +3168,17 @@ NK_INTERNAL void nk_reduce_minmax_e2m3_rvv_contiguous_( //
|
|
|
3060
3168
|
nk_e2m3_t const *data_ptr, nk_size_t count, //
|
|
3061
3169
|
nk_e2m3_t *min_value_ptr, nk_size_t *min_index_ptr, //
|
|
3062
3170
|
nk_e2m3_t *max_value_ptr, nk_size_t *max_index_ptr) {
|
|
3063
|
-
nk_size_t
|
|
3064
|
-
vuint8m1_t min_u8m1 = __riscv_vmv_v_x_u8m1(0x3F,
|
|
3065
|
-
vuint8m1_t max_u8m1 = __riscv_vmv_v_x_u8m1(0x00,
|
|
3066
|
-
vuint64m8_t min_indices_u64m8 = __riscv_vmv_v_x_u64m8(0,
|
|
3067
|
-
vuint64m8_t max_indices_u64m8 = __riscv_vmv_v_x_u64m8(0,
|
|
3171
|
+
nk_size_t max_vector_length = __riscv_vsetvlmax_e8m1();
|
|
3172
|
+
vuint8m1_t min_u8m1 = __riscv_vmv_v_x_u8m1(0x3F, max_vector_length); // Largest FP6 comparable
|
|
3173
|
+
vuint8m1_t max_u8m1 = __riscv_vmv_v_x_u8m1(0x00, max_vector_length); // Smallest FP6 comparable
|
|
3174
|
+
vuint64m8_t min_indices_u64m8 = __riscv_vmv_v_x_u64m8(0, max_vector_length);
|
|
3175
|
+
vuint64m8_t max_indices_u64m8 = __riscv_vmv_v_x_u64m8(0, max_vector_length);
|
|
3068
3176
|
|
|
3069
3177
|
nk_size_t offset = 0;
|
|
3070
3178
|
for (nk_size_t vector_length; count > 0;
|
|
3071
3179
|
count -= vector_length, offset += vector_length, data_ptr += vector_length) {
|
|
3072
3180
|
vector_length = __riscv_vsetvl_e8m1(count);
|
|
3073
|
-
vuint8m1_t raw_u8m1 = __riscv_vle8_v_u8m1((
|
|
3181
|
+
vuint8m1_t raw_u8m1 = __riscv_vle8_v_u8m1((nk_u8_t const *)data_ptr, vector_length);
|
|
3074
3182
|
|
|
3075
3183
|
// Convert to FP6 comparable form
|
|
3076
3184
|
vuint8m1_t comparable_u8m1 = nk_fp6m1_to_comparable_u8m1_rvv_(raw_u8m1, vector_length);
|
|
@@ -3090,24 +3198,28 @@ NK_INTERNAL void nk_reduce_minmax_e2m3_rvv_contiguous_( //
|
|
|
3090
3198
|
|
|
3091
3199
|
// Horizontal reduction + convert back
|
|
3092
3200
|
vuint8m1_t init_max_u8m1 = __riscv_vmv_v_x_u8m1(0x3F, 1);
|
|
3093
|
-
nk_u8_t min_comparable = __riscv_vmv_x_s_u8m1_u8(
|
|
3094
|
-
|
|
3095
|
-
|
|
3096
|
-
vuint64m8_t
|
|
3201
|
+
nk_u8_t min_comparable = __riscv_vmv_x_s_u8m1_u8(
|
|
3202
|
+
__riscv_vredminu_vs_u8m1_u8m1(min_u8m1, init_max_u8m1, max_vector_length));
|
|
3203
|
+
vbool8_t min_match_b8 = __riscv_vmseq_vx_u8m1_b8(min_u8m1, min_comparable, max_vector_length);
|
|
3204
|
+
vuint64m8_t sentinel_u64m8 = __riscv_vmv_v_x_u64m8(NK_U64_MAX, max_vector_length);
|
|
3205
|
+
vuint64m8_t min_cands_u64m8 = __riscv_vmerge_vvm_u64m8(sentinel_u64m8, min_indices_u64m8, min_match_b8,
|
|
3206
|
+
max_vector_length);
|
|
3097
3207
|
vuint64m1_t init_umax_u64m1 = __riscv_vmv_v_x_u64m1(NK_U64_MAX, 1);
|
|
3098
3208
|
*min_index_ptr = (nk_size_t)__riscv_vmv_x_s_u64m1_u64(
|
|
3099
|
-
__riscv_vredminu_vs_u64m8_u64m1(min_cands_u64m8, init_umax_u64m1,
|
|
3209
|
+
__riscv_vredminu_vs_u64m8_u64m1(min_cands_u64m8, init_umax_u64m1, max_vector_length));
|
|
3100
3210
|
|
|
3101
3211
|
vuint8m1_t min_vec_u8m1 = __riscv_vmv_v_x_u8m1(min_comparable, 1);
|
|
3102
3212
|
vuint8m1_t min_raw_u8m1 = nk_comparable_to_fp6m1_rvv_(min_vec_u8m1, 1);
|
|
3103
3213
|
*min_value_ptr = (nk_e2m3_t)__riscv_vmv_x_s_u8m1_u8(min_raw_u8m1);
|
|
3104
3214
|
|
|
3105
3215
|
vuint8m1_t init_min_u8m1 = __riscv_vmv_v_x_u8m1(0x00, 1);
|
|
3106
|
-
nk_u8_t max_comparable = __riscv_vmv_x_s_u8m1_u8(
|
|
3107
|
-
|
|
3108
|
-
|
|
3216
|
+
nk_u8_t max_comparable = __riscv_vmv_x_s_u8m1_u8(
|
|
3217
|
+
__riscv_vredmaxu_vs_u8m1_u8m1(max_u8m1, init_min_u8m1, max_vector_length));
|
|
3218
|
+
vbool8_t max_match_b8 = __riscv_vmseq_vx_u8m1_b8(max_u8m1, max_comparable, max_vector_length);
|
|
3219
|
+
vuint64m8_t max_cands_u64m8 = __riscv_vmerge_vvm_u64m8(sentinel_u64m8, max_indices_u64m8, max_match_b8,
|
|
3220
|
+
max_vector_length);
|
|
3109
3221
|
*max_index_ptr = (nk_size_t)__riscv_vmv_x_s_u64m1_u64(
|
|
3110
|
-
__riscv_vredminu_vs_u64m8_u64m1(max_cands_u64m8, init_umax_u64m1,
|
|
3222
|
+
__riscv_vredminu_vs_u64m8_u64m1(max_cands_u64m8, init_umax_u64m1, max_vector_length));
|
|
3111
3223
|
|
|
3112
3224
|
vuint8m1_t max_vec_u8m1 = __riscv_vmv_v_x_u8m1(max_comparable, 1);
|
|
3113
3225
|
vuint8m1_t max_raw_u8m1 = nk_comparable_to_fp6m1_rvv_(max_vec_u8m1, 1);
|
|
@@ -3118,18 +3230,18 @@ NK_INTERNAL void nk_reduce_minmax_e2m3_rvv_strided_( //
|
|
|
3118
3230
|
nk_e2m3_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
|
|
3119
3231
|
nk_e2m3_t *min_value_ptr, nk_size_t *min_index_ptr, //
|
|
3120
3232
|
nk_e2m3_t *max_value_ptr, nk_size_t *max_index_ptr) {
|
|
3121
|
-
nk_size_t
|
|
3122
|
-
vuint8m1_t min_u8m1 = __riscv_vmv_v_x_u8m1(0x3F,
|
|
3123
|
-
vuint8m1_t max_u8m1 = __riscv_vmv_v_x_u8m1(0x00,
|
|
3124
|
-
vuint64m8_t min_indices_u64m8 = __riscv_vmv_v_x_u64m8(0,
|
|
3125
|
-
vuint64m8_t max_indices_u64m8 = __riscv_vmv_v_x_u64m8(0,
|
|
3233
|
+
nk_size_t max_vector_length = __riscv_vsetvlmax_e8m1();
|
|
3234
|
+
vuint8m1_t min_u8m1 = __riscv_vmv_v_x_u8m1(0x3F, max_vector_length);
|
|
3235
|
+
vuint8m1_t max_u8m1 = __riscv_vmv_v_x_u8m1(0x00, max_vector_length);
|
|
3236
|
+
vuint64m8_t min_indices_u64m8 = __riscv_vmv_v_x_u64m8(0, max_vector_length);
|
|
3237
|
+
vuint64m8_t max_indices_u64m8 = __riscv_vmv_v_x_u64m8(0, max_vector_length);
|
|
3126
3238
|
unsigned char const *ptr = (unsigned char const *)data_ptr;
|
|
3127
3239
|
|
|
3128
3240
|
nk_size_t offset = 0;
|
|
3129
3241
|
for (nk_size_t vector_length; count > 0;
|
|
3130
3242
|
count -= vector_length, offset += vector_length, ptr += vector_length * stride_bytes) {
|
|
3131
3243
|
vector_length = __riscv_vsetvl_e8m1(count);
|
|
3132
|
-
vuint8m1_t raw_u8m1 = __riscv_vlse8_v_u8m1((
|
|
3244
|
+
vuint8m1_t raw_u8m1 = __riscv_vlse8_v_u8m1((nk_u8_t const *)ptr, (nk_ssize_t)stride_bytes, vector_length);
|
|
3133
3245
|
|
|
3134
3246
|
vuint8m1_t comparable_u8m1 = nk_fp6m1_to_comparable_u8m1_rvv_(raw_u8m1, vector_length);
|
|
3135
3247
|
vuint64m8_t pos_u64m8 = __riscv_vadd_vx_u64m8(__riscv_vid_v_u64m8(vector_length), (nk_u64_t)offset,
|
|
@@ -3148,24 +3260,28 @@ NK_INTERNAL void nk_reduce_minmax_e2m3_rvv_strided_( //
|
|
|
3148
3260
|
|
|
3149
3261
|
// Horizontal reduction (same as contiguous)
|
|
3150
3262
|
vuint8m1_t init_max_u8m1 = __riscv_vmv_v_x_u8m1(0x3F, 1);
|
|
3151
|
-
nk_u8_t min_comparable = __riscv_vmv_x_s_u8m1_u8(
|
|
3152
|
-
|
|
3153
|
-
|
|
3154
|
-
vuint64m8_t
|
|
3263
|
+
nk_u8_t min_comparable = __riscv_vmv_x_s_u8m1_u8(
|
|
3264
|
+
__riscv_vredminu_vs_u8m1_u8m1(min_u8m1, init_max_u8m1, max_vector_length));
|
|
3265
|
+
vbool8_t min_match_b8 = __riscv_vmseq_vx_u8m1_b8(min_u8m1, min_comparable, max_vector_length);
|
|
3266
|
+
vuint64m8_t sentinel_u64m8 = __riscv_vmv_v_x_u64m8(NK_U64_MAX, max_vector_length);
|
|
3267
|
+
vuint64m8_t min_cands_u64m8 = __riscv_vmerge_vvm_u64m8(sentinel_u64m8, min_indices_u64m8, min_match_b8,
|
|
3268
|
+
max_vector_length);
|
|
3155
3269
|
vuint64m1_t init_umax_u64m1 = __riscv_vmv_v_x_u64m1(NK_U64_MAX, 1);
|
|
3156
3270
|
*min_index_ptr = (nk_size_t)__riscv_vmv_x_s_u64m1_u64(
|
|
3157
|
-
__riscv_vredminu_vs_u64m8_u64m1(min_cands_u64m8, init_umax_u64m1,
|
|
3271
|
+
__riscv_vredminu_vs_u64m8_u64m1(min_cands_u64m8, init_umax_u64m1, max_vector_length));
|
|
3158
3272
|
|
|
3159
3273
|
vuint8m1_t min_vec_u8m1 = __riscv_vmv_v_x_u8m1(min_comparable, 1);
|
|
3160
3274
|
vuint8m1_t min_raw_u8m1 = nk_comparable_to_fp6m1_rvv_(min_vec_u8m1, 1);
|
|
3161
3275
|
*min_value_ptr = (nk_e2m3_t)__riscv_vmv_x_s_u8m1_u8(min_raw_u8m1);
|
|
3162
3276
|
|
|
3163
3277
|
vuint8m1_t init_min_u8m1 = __riscv_vmv_v_x_u8m1(0x00, 1);
|
|
3164
|
-
nk_u8_t max_comparable = __riscv_vmv_x_s_u8m1_u8(
|
|
3165
|
-
|
|
3166
|
-
|
|
3278
|
+
nk_u8_t max_comparable = __riscv_vmv_x_s_u8m1_u8(
|
|
3279
|
+
__riscv_vredmaxu_vs_u8m1_u8m1(max_u8m1, init_min_u8m1, max_vector_length));
|
|
3280
|
+
vbool8_t max_match_b8 = __riscv_vmseq_vx_u8m1_b8(max_u8m1, max_comparable, max_vector_length);
|
|
3281
|
+
vuint64m8_t max_cands_u64m8 = __riscv_vmerge_vvm_u64m8(sentinel_u64m8, max_indices_u64m8, max_match_b8,
|
|
3282
|
+
max_vector_length);
|
|
3167
3283
|
*max_index_ptr = (nk_size_t)__riscv_vmv_x_s_u64m1_u64(
|
|
3168
|
-
__riscv_vredminu_vs_u64m8_u64m1(max_cands_u64m8, init_umax_u64m1,
|
|
3284
|
+
__riscv_vredminu_vs_u64m8_u64m1(max_cands_u64m8, init_umax_u64m1, max_vector_length));
|
|
3169
3285
|
|
|
3170
3286
|
vuint8m1_t max_vec_u8m1 = __riscv_vmv_v_x_u8m1(max_comparable, 1);
|
|
3171
3287
|
vuint8m1_t max_raw_u8m1 = nk_comparable_to_fp6m1_rvv_(max_vec_u8m1, 1);
|
|
@@ -3196,13 +3312,13 @@ NK_PUBLIC void nk_reduce_minmax_e2m3_rvv( //
|
|
|
3196
3312
|
NK_INTERNAL void nk_reduce_moments_e3m2_rvv_contiguous_( //
|
|
3197
3313
|
nk_e3m2_t const *data_ptr, nk_size_t count, //
|
|
3198
3314
|
nk_f32_t *sum_ptr, nk_f32_t *sumsq_ptr) {
|
|
3199
|
-
nk_size_t
|
|
3200
|
-
vfloat32m4_t sum_f32m4 = __riscv_vfmv_v_f_f32m4(0.0f,
|
|
3201
|
-
vfloat32m4_t sumsq_f32m4 = __riscv_vfmv_v_f_f32m4(0.0f,
|
|
3315
|
+
nk_size_t max_vector_length = __riscv_vsetvlmax_e32m4();
|
|
3316
|
+
vfloat32m4_t sum_f32m4 = __riscv_vfmv_v_f_f32m4(0.0f, max_vector_length);
|
|
3317
|
+
vfloat32m4_t sumsq_f32m4 = __riscv_vfmv_v_f_f32m4(0.0f, max_vector_length);
|
|
3202
3318
|
|
|
3203
3319
|
for (nk_size_t vector_length; count > 0; count -= vector_length, data_ptr += vector_length) {
|
|
3204
3320
|
vector_length = __riscv_vsetvl_e8m1(count);
|
|
3205
|
-
vuint8m1_t data_u8m1 = __riscv_vle8_v_u8m1((
|
|
3321
|
+
vuint8m1_t data_u8m1 = __riscv_vle8_v_u8m1((nk_u8_t const *)data_ptr, vector_length);
|
|
3206
3322
|
|
|
3207
3323
|
// Convert e3m2 → f32 (m1 → m4)
|
|
3208
3324
|
vfloat32m4_t data_f32m4 = nk_e3m2m1_to_f32m4_rvv_(data_u8m1, vector_length);
|
|
@@ -3214,21 +3330,22 @@ NK_INTERNAL void nk_reduce_moments_e3m2_rvv_contiguous_( //
|
|
|
3214
3330
|
|
|
3215
3331
|
// Horizontal reduction
|
|
3216
3332
|
vfloat32m1_t zero_f32m1 = __riscv_vfmv_v_f_f32m1(0.0f, 1);
|
|
3217
|
-
*sum_ptr = __riscv_vfmv_f_s_f32m1_f32(__riscv_vfredusum_vs_f32m4_f32m1(sum_f32m4, zero_f32m1,
|
|
3218
|
-
*sumsq_ptr = __riscv_vfmv_f_s_f32m1_f32(
|
|
3333
|
+
*sum_ptr = __riscv_vfmv_f_s_f32m1_f32(__riscv_vfredusum_vs_f32m4_f32m1(sum_f32m4, zero_f32m1, max_vector_length)),
|
|
3334
|
+
*sumsq_ptr = __riscv_vfmv_f_s_f32m1_f32(
|
|
3335
|
+
__riscv_vfredusum_vs_f32m4_f32m1(sumsq_f32m4, zero_f32m1, max_vector_length));
|
|
3219
3336
|
}
|
|
3220
3337
|
|
|
3221
3338
|
NK_INTERNAL void nk_reduce_moments_e3m2_rvv_strided_( //
|
|
3222
3339
|
nk_e3m2_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
|
|
3223
3340
|
nk_f32_t *sum_ptr, nk_f32_t *sumsq_ptr) {
|
|
3224
|
-
nk_size_t
|
|
3225
|
-
vfloat32m4_t sum_f32m4 = __riscv_vfmv_v_f_f32m4(0.0f,
|
|
3226
|
-
vfloat32m4_t sumsq_f32m4 = __riscv_vfmv_v_f_f32m4(0.0f,
|
|
3341
|
+
nk_size_t max_vector_length = __riscv_vsetvlmax_e32m4();
|
|
3342
|
+
vfloat32m4_t sum_f32m4 = __riscv_vfmv_v_f_f32m4(0.0f, max_vector_length);
|
|
3343
|
+
vfloat32m4_t sumsq_f32m4 = __riscv_vfmv_v_f_f32m4(0.0f, max_vector_length);
|
|
3227
3344
|
unsigned char const *ptr = (unsigned char const *)data_ptr;
|
|
3228
3345
|
|
|
3229
3346
|
for (nk_size_t vector_length; count > 0; count -= vector_length, ptr += vector_length * stride_bytes) {
|
|
3230
3347
|
vector_length = __riscv_vsetvl_e8m1(count);
|
|
3231
|
-
vuint8m1_t data_u8m1 = __riscv_vlse8_v_u8m1((
|
|
3348
|
+
vuint8m1_t data_u8m1 = __riscv_vlse8_v_u8m1((nk_u8_t const *)ptr, (nk_ssize_t)stride_bytes, vector_length);
|
|
3232
3349
|
|
|
3233
3350
|
// Convert e3m2 → f32 (m1 → m4)
|
|
3234
3351
|
vfloat32m4_t data_f32m4 = nk_e3m2m1_to_f32m4_rvv_(data_u8m1, vector_length);
|
|
@@ -3240,8 +3357,9 @@ NK_INTERNAL void nk_reduce_moments_e3m2_rvv_strided_( //
|
|
|
3240
3357
|
|
|
3241
3358
|
// Horizontal reduction
|
|
3242
3359
|
vfloat32m1_t zero_f32m1 = __riscv_vfmv_v_f_f32m1(0.0f, 1);
|
|
3243
|
-
*sum_ptr = __riscv_vfmv_f_s_f32m1_f32(__riscv_vfredusum_vs_f32m4_f32m1(sum_f32m4, zero_f32m1,
|
|
3244
|
-
*sumsq_ptr = __riscv_vfmv_f_s_f32m1_f32(
|
|
3360
|
+
*sum_ptr = __riscv_vfmv_f_s_f32m1_f32(__riscv_vfredusum_vs_f32m4_f32m1(sum_f32m4, zero_f32m1, max_vector_length)),
|
|
3361
|
+
*sumsq_ptr = __riscv_vfmv_f_s_f32m1_f32(
|
|
3362
|
+
__riscv_vfredusum_vs_f32m4_f32m1(sumsq_f32m4, zero_f32m1, max_vector_length));
|
|
3245
3363
|
}
|
|
3246
3364
|
|
|
3247
3365
|
NK_PUBLIC void nk_reduce_moments_e3m2_rvv( //
|
|
@@ -3260,17 +3378,17 @@ NK_INTERNAL void nk_reduce_minmax_e3m2_rvv_contiguous_( //
|
|
|
3260
3378
|
nk_e3m2_t const *data_ptr, nk_size_t count, //
|
|
3261
3379
|
nk_e3m2_t *min_value_ptr, nk_size_t *min_index_ptr, //
|
|
3262
3380
|
nk_e3m2_t *max_value_ptr, nk_size_t *max_index_ptr) {
|
|
3263
|
-
nk_size_t
|
|
3264
|
-
vuint8m1_t min_u8m1 = __riscv_vmv_v_x_u8m1(0x3F,
|
|
3265
|
-
vuint8m1_t max_u8m1 = __riscv_vmv_v_x_u8m1(0x00,
|
|
3266
|
-
vuint64m8_t min_indices_u64m8 = __riscv_vmv_v_x_u64m8(0,
|
|
3267
|
-
vuint64m8_t max_indices_u64m8 = __riscv_vmv_v_x_u64m8(0,
|
|
3381
|
+
nk_size_t max_vector_length = __riscv_vsetvlmax_e8m1();
|
|
3382
|
+
vuint8m1_t min_u8m1 = __riscv_vmv_v_x_u8m1(0x3F, max_vector_length);
|
|
3383
|
+
vuint8m1_t max_u8m1 = __riscv_vmv_v_x_u8m1(0x00, max_vector_length);
|
|
3384
|
+
vuint64m8_t min_indices_u64m8 = __riscv_vmv_v_x_u64m8(0, max_vector_length);
|
|
3385
|
+
vuint64m8_t max_indices_u64m8 = __riscv_vmv_v_x_u64m8(0, max_vector_length);
|
|
3268
3386
|
|
|
3269
3387
|
nk_size_t offset = 0;
|
|
3270
3388
|
for (nk_size_t vector_length; count > 0;
|
|
3271
3389
|
count -= vector_length, offset += vector_length, data_ptr += vector_length) {
|
|
3272
3390
|
vector_length = __riscv_vsetvl_e8m1(count);
|
|
3273
|
-
vuint8m1_t raw_u8m1 = __riscv_vle8_v_u8m1((
|
|
3391
|
+
vuint8m1_t raw_u8m1 = __riscv_vle8_v_u8m1((nk_u8_t const *)data_ptr, vector_length);
|
|
3274
3392
|
|
|
3275
3393
|
vuint8m1_t comparable_u8m1 = nk_fp6m1_to_comparable_u8m1_rvv_(raw_u8m1, vector_length);
|
|
3276
3394
|
vuint64m8_t pos_u64m8 = __riscv_vadd_vx_u64m8(__riscv_vid_v_u64m8(vector_length), (nk_u64_t)offset,
|
|
@@ -3289,24 +3407,28 @@ NK_INTERNAL void nk_reduce_minmax_e3m2_rvv_contiguous_( //
|
|
|
3289
3407
|
|
|
3290
3408
|
// Horizontal reduction + convert back
|
|
3291
3409
|
vuint8m1_t init_max_u8m1 = __riscv_vmv_v_x_u8m1(0x3F, 1);
|
|
3292
|
-
nk_u8_t min_comparable = __riscv_vmv_x_s_u8m1_u8(
|
|
3293
|
-
|
|
3294
|
-
|
|
3295
|
-
vuint64m8_t
|
|
3410
|
+
nk_u8_t min_comparable = __riscv_vmv_x_s_u8m1_u8(
|
|
3411
|
+
__riscv_vredminu_vs_u8m1_u8m1(min_u8m1, init_max_u8m1, max_vector_length));
|
|
3412
|
+
vbool8_t min_match_b8 = __riscv_vmseq_vx_u8m1_b8(min_u8m1, min_comparable, max_vector_length);
|
|
3413
|
+
vuint64m8_t sentinel_u64m8 = __riscv_vmv_v_x_u64m8(NK_U64_MAX, max_vector_length);
|
|
3414
|
+
vuint64m8_t min_cands_u64m8 = __riscv_vmerge_vvm_u64m8(sentinel_u64m8, min_indices_u64m8, min_match_b8,
|
|
3415
|
+
max_vector_length);
|
|
3296
3416
|
vuint64m1_t init_umax_u64m1 = __riscv_vmv_v_x_u64m1(NK_U64_MAX, 1);
|
|
3297
3417
|
*min_index_ptr = (nk_size_t)__riscv_vmv_x_s_u64m1_u64(
|
|
3298
|
-
__riscv_vredminu_vs_u64m8_u64m1(min_cands_u64m8, init_umax_u64m1,
|
|
3418
|
+
__riscv_vredminu_vs_u64m8_u64m1(min_cands_u64m8, init_umax_u64m1, max_vector_length));
|
|
3299
3419
|
|
|
3300
3420
|
vuint8m1_t min_vec_u8m1 = __riscv_vmv_v_x_u8m1(min_comparable, 1);
|
|
3301
3421
|
vuint8m1_t min_raw_u8m1 = nk_comparable_to_fp6m1_rvv_(min_vec_u8m1, 1);
|
|
3302
3422
|
*min_value_ptr = (nk_e3m2_t)__riscv_vmv_x_s_u8m1_u8(min_raw_u8m1);
|
|
3303
3423
|
|
|
3304
3424
|
vuint8m1_t init_min_u8m1 = __riscv_vmv_v_x_u8m1(0x00, 1);
|
|
3305
|
-
nk_u8_t max_comparable = __riscv_vmv_x_s_u8m1_u8(
|
|
3306
|
-
|
|
3307
|
-
|
|
3425
|
+
nk_u8_t max_comparable = __riscv_vmv_x_s_u8m1_u8(
|
|
3426
|
+
__riscv_vredmaxu_vs_u8m1_u8m1(max_u8m1, init_min_u8m1, max_vector_length));
|
|
3427
|
+
vbool8_t max_match_b8 = __riscv_vmseq_vx_u8m1_b8(max_u8m1, max_comparable, max_vector_length);
|
|
3428
|
+
vuint64m8_t max_cands_u64m8 = __riscv_vmerge_vvm_u64m8(sentinel_u64m8, max_indices_u64m8, max_match_b8,
|
|
3429
|
+
max_vector_length);
|
|
3308
3430
|
*max_index_ptr = (nk_size_t)__riscv_vmv_x_s_u64m1_u64(
|
|
3309
|
-
__riscv_vredminu_vs_u64m8_u64m1(max_cands_u64m8, init_umax_u64m1,
|
|
3431
|
+
__riscv_vredminu_vs_u64m8_u64m1(max_cands_u64m8, init_umax_u64m1, max_vector_length));
|
|
3310
3432
|
|
|
3311
3433
|
vuint8m1_t max_vec_u8m1 = __riscv_vmv_v_x_u8m1(max_comparable, 1);
|
|
3312
3434
|
vuint8m1_t max_raw_u8m1 = nk_comparable_to_fp6m1_rvv_(max_vec_u8m1, 1);
|
|
@@ -3317,18 +3439,18 @@ NK_INTERNAL void nk_reduce_minmax_e3m2_rvv_strided_( //
|
|
|
3317
3439
|
nk_e3m2_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
|
|
3318
3440
|
nk_e3m2_t *min_value_ptr, nk_size_t *min_index_ptr, //
|
|
3319
3441
|
nk_e3m2_t *max_value_ptr, nk_size_t *max_index_ptr) {
|
|
3320
|
-
nk_size_t
|
|
3321
|
-
vuint8m1_t min_u8m1 = __riscv_vmv_v_x_u8m1(0x3F,
|
|
3322
|
-
vuint8m1_t max_u8m1 = __riscv_vmv_v_x_u8m1(0x00,
|
|
3323
|
-
vuint64m8_t min_indices_u64m8 = __riscv_vmv_v_x_u64m8(0,
|
|
3324
|
-
vuint64m8_t max_indices_u64m8 = __riscv_vmv_v_x_u64m8(0,
|
|
3442
|
+
nk_size_t max_vector_length = __riscv_vsetvlmax_e8m1();
|
|
3443
|
+
vuint8m1_t min_u8m1 = __riscv_vmv_v_x_u8m1(0x3F, max_vector_length);
|
|
3444
|
+
vuint8m1_t max_u8m1 = __riscv_vmv_v_x_u8m1(0x00, max_vector_length);
|
|
3445
|
+
vuint64m8_t min_indices_u64m8 = __riscv_vmv_v_x_u64m8(0, max_vector_length);
|
|
3446
|
+
vuint64m8_t max_indices_u64m8 = __riscv_vmv_v_x_u64m8(0, max_vector_length);
|
|
3325
3447
|
unsigned char const *ptr = (unsigned char const *)data_ptr;
|
|
3326
3448
|
|
|
3327
3449
|
nk_size_t offset = 0;
|
|
3328
3450
|
for (nk_size_t vector_length; count > 0;
|
|
3329
3451
|
count -= vector_length, offset += vector_length, ptr += vector_length * stride_bytes) {
|
|
3330
3452
|
vector_length = __riscv_vsetvl_e8m1(count);
|
|
3331
|
-
vuint8m1_t raw_u8m1 = __riscv_vlse8_v_u8m1((
|
|
3453
|
+
vuint8m1_t raw_u8m1 = __riscv_vlse8_v_u8m1((nk_u8_t const *)ptr, (nk_ssize_t)stride_bytes, vector_length);
|
|
3332
3454
|
|
|
3333
3455
|
vuint8m1_t comparable_u8m1 = nk_fp6m1_to_comparable_u8m1_rvv_(raw_u8m1, vector_length);
|
|
3334
3456
|
vuint64m8_t pos_u64m8 = __riscv_vadd_vx_u64m8(__riscv_vid_v_u64m8(vector_length), (nk_u64_t)offset,
|
|
@@ -3347,24 +3469,28 @@ NK_INTERNAL void nk_reduce_minmax_e3m2_rvv_strided_( //
|
|
|
3347
3469
|
|
|
3348
3470
|
// Horizontal reduction (same as contiguous)
|
|
3349
3471
|
vuint8m1_t init_max_u8m1 = __riscv_vmv_v_x_u8m1(0x3F, 1);
|
|
3350
|
-
nk_u8_t min_comparable = __riscv_vmv_x_s_u8m1_u8(
|
|
3351
|
-
|
|
3352
|
-
|
|
3353
|
-
vuint64m8_t
|
|
3472
|
+
nk_u8_t min_comparable = __riscv_vmv_x_s_u8m1_u8(
|
|
3473
|
+
__riscv_vredminu_vs_u8m1_u8m1(min_u8m1, init_max_u8m1, max_vector_length));
|
|
3474
|
+
vbool8_t min_match_b8 = __riscv_vmseq_vx_u8m1_b8(min_u8m1, min_comparable, max_vector_length);
|
|
3475
|
+
vuint64m8_t sentinel_u64m8 = __riscv_vmv_v_x_u64m8(NK_U64_MAX, max_vector_length);
|
|
3476
|
+
vuint64m8_t min_cands_u64m8 = __riscv_vmerge_vvm_u64m8(sentinel_u64m8, min_indices_u64m8, min_match_b8,
|
|
3477
|
+
max_vector_length);
|
|
3354
3478
|
vuint64m1_t init_umax_u64m1 = __riscv_vmv_v_x_u64m1(NK_U64_MAX, 1);
|
|
3355
3479
|
*min_index_ptr = (nk_size_t)__riscv_vmv_x_s_u64m1_u64(
|
|
3356
|
-
__riscv_vredminu_vs_u64m8_u64m1(min_cands_u64m8, init_umax_u64m1,
|
|
3480
|
+
__riscv_vredminu_vs_u64m8_u64m1(min_cands_u64m8, init_umax_u64m1, max_vector_length));
|
|
3357
3481
|
|
|
3358
3482
|
vuint8m1_t min_vec_u8m1 = __riscv_vmv_v_x_u8m1(min_comparable, 1);
|
|
3359
3483
|
vuint8m1_t min_raw_u8m1 = nk_comparable_to_fp6m1_rvv_(min_vec_u8m1, 1);
|
|
3360
3484
|
*min_value_ptr = (nk_e3m2_t)__riscv_vmv_x_s_u8m1_u8(min_raw_u8m1);
|
|
3361
3485
|
|
|
3362
3486
|
vuint8m1_t init_min_u8m1 = __riscv_vmv_v_x_u8m1(0x00, 1);
|
|
3363
|
-
nk_u8_t max_comparable = __riscv_vmv_x_s_u8m1_u8(
|
|
3364
|
-
|
|
3365
|
-
|
|
3487
|
+
nk_u8_t max_comparable = __riscv_vmv_x_s_u8m1_u8(
|
|
3488
|
+
__riscv_vredmaxu_vs_u8m1_u8m1(max_u8m1, init_min_u8m1, max_vector_length));
|
|
3489
|
+
vbool8_t max_match_b8 = __riscv_vmseq_vx_u8m1_b8(max_u8m1, max_comparable, max_vector_length);
|
|
3490
|
+
vuint64m8_t max_cands_u64m8 = __riscv_vmerge_vvm_u64m8(sentinel_u64m8, max_indices_u64m8, max_match_b8,
|
|
3491
|
+
max_vector_length);
|
|
3366
3492
|
*max_index_ptr = (nk_size_t)__riscv_vmv_x_s_u64m1_u64(
|
|
3367
|
-
__riscv_vredminu_vs_u64m8_u64m1(max_cands_u64m8, init_umax_u64m1,
|
|
3493
|
+
__riscv_vredminu_vs_u64m8_u64m1(max_cands_u64m8, init_umax_u64m1, max_vector_length));
|
|
3368
3494
|
|
|
3369
3495
|
vuint8m1_t max_vec_u8m1 = __riscv_vmv_v_x_u8m1(max_comparable, 1);
|
|
3370
3496
|
vuint8m1_t max_raw_u8m1 = nk_comparable_to_fp6m1_rvv_(max_vec_u8m1, 1);
|