numkong 7.0.0 → 7.4.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +239 -122
- package/binding.gyp +25 -491
- package/c/dispatch_bf16.c +59 -1
- package/c/dispatch_e2m3.c +41 -8
- package/c/dispatch_e3m2.c +49 -8
- package/c/dispatch_e4m3.c +51 -9
- package/c/dispatch_e5m2.c +45 -1
- package/c/dispatch_f16.c +79 -26
- package/c/dispatch_f16c.c +5 -5
- package/c/dispatch_f32.c +56 -0
- package/c/dispatch_f64.c +52 -0
- package/c/dispatch_i4.c +3 -0
- package/c/dispatch_i8.c +62 -3
- package/c/dispatch_other.c +18 -0
- package/c/dispatch_u1.c +54 -9
- package/c/dispatch_u4.c +3 -0
- package/c/dispatch_u8.c +64 -3
- package/c/numkong.c +3 -0
- package/include/README.md +79 -9
- package/include/numkong/attention/sapphireamx.h +278 -276
- package/include/numkong/attention/sme.h +983 -977
- package/include/numkong/attention.h +1 -1
- package/include/numkong/capabilities.h +289 -94
- package/include/numkong/cast/README.md +40 -40
- package/include/numkong/cast/diamond.h +64 -0
- package/include/numkong/cast/haswell.h +42 -194
- package/include/numkong/cast/icelake.h +42 -37
- package/include/numkong/cast/loongsonasx.h +252 -0
- package/include/numkong/cast/neon.h +216 -249
- package/include/numkong/cast/powervsx.h +449 -0
- package/include/numkong/cast/rvv.h +223 -274
- package/include/numkong/cast/sapphire.h +18 -18
- package/include/numkong/cast/serial.h +1018 -944
- package/include/numkong/cast/skylake.h +82 -23
- package/include/numkong/cast/v128relaxed.h +462 -105
- package/include/numkong/cast.h +24 -0
- package/include/numkong/cast.hpp +44 -0
- package/include/numkong/curved/README.md +17 -17
- package/include/numkong/curved/neon.h +131 -7
- package/include/numkong/curved/neonbfdot.h +6 -7
- package/include/numkong/curved/rvv.h +26 -26
- package/include/numkong/curved/smef64.h +186 -182
- package/include/numkong/curved.h +14 -18
- package/include/numkong/dot/README.md +154 -137
- package/include/numkong/dot/alder.h +43 -43
- package/include/numkong/dot/diamond.h +158 -0
- package/include/numkong/dot/genoa.h +4 -30
- package/include/numkong/dot/haswell.h +215 -180
- package/include/numkong/dot/icelake.h +190 -76
- package/include/numkong/dot/loongsonasx.h +671 -0
- package/include/numkong/dot/neon.h +124 -73
- package/include/numkong/dot/neonbfdot.h +11 -12
- package/include/numkong/dot/neonfhm.h +44 -46
- package/include/numkong/dot/neonfp8.h +323 -0
- package/include/numkong/dot/neonsdot.h +190 -76
- package/include/numkong/dot/powervsx.h +752 -0
- package/include/numkong/dot/rvv.h +92 -84
- package/include/numkong/dot/rvvbf16.h +12 -12
- package/include/numkong/dot/rvvhalf.h +12 -12
- package/include/numkong/dot/sapphire.h +4 -4
- package/include/numkong/dot/serial.h +66 -30
- package/include/numkong/dot/sierra.h +31 -31
- package/include/numkong/dot/skylake.h +142 -110
- package/include/numkong/dot/sve.h +217 -177
- package/include/numkong/dot/svebfdot.h +10 -10
- package/include/numkong/dot/svehalf.h +85 -41
- package/include/numkong/dot/svesdot.h +89 -0
- package/include/numkong/dot/v128relaxed.h +124 -89
- package/include/numkong/dot.h +114 -48
- package/include/numkong/dots/README.md +203 -203
- package/include/numkong/dots/alder.h +12 -9
- package/include/numkong/dots/diamond.h +86 -0
- package/include/numkong/dots/genoa.h +10 -4
- package/include/numkong/dots/haswell.h +63 -48
- package/include/numkong/dots/icelake.h +27 -18
- package/include/numkong/dots/loongsonasx.h +176 -0
- package/include/numkong/dots/neon.h +14 -11
- package/include/numkong/dots/neonbfdot.h +4 -3
- package/include/numkong/dots/neonfhm.h +11 -9
- package/include/numkong/dots/neonfp8.h +99 -0
- package/include/numkong/dots/neonsdot.h +48 -12
- package/include/numkong/dots/powervsx.h +194 -0
- package/include/numkong/dots/rvv.h +451 -344
- package/include/numkong/dots/sapphireamx.h +1028 -984
- package/include/numkong/dots/serial.h +213 -197
- package/include/numkong/dots/sierra.h +10 -7
- package/include/numkong/dots/skylake.h +47 -36
- package/include/numkong/dots/sme.h +2001 -2364
- package/include/numkong/dots/smebi32.h +175 -162
- package/include/numkong/dots/smef64.h +328 -323
- package/include/numkong/dots/v128relaxed.h +64 -41
- package/include/numkong/dots.h +573 -293
- package/include/numkong/dots.hpp +45 -43
- package/include/numkong/each/README.md +133 -137
- package/include/numkong/each/haswell.h +6 -6
- package/include/numkong/each/icelake.h +7 -7
- package/include/numkong/each/neon.h +76 -42
- package/include/numkong/each/neonbfdot.h +11 -12
- package/include/numkong/each/neonhalf.h +24 -116
- package/include/numkong/each/rvv.h +28 -28
- package/include/numkong/each/sapphire.h +27 -161
- package/include/numkong/each/serial.h +6 -6
- package/include/numkong/each/skylake.h +7 -7
- package/include/numkong/each/v128relaxed.h +562 -0
- package/include/numkong/each.h +148 -62
- package/include/numkong/each.hpp +2 -2
- package/include/numkong/geospatial/README.md +18 -18
- package/include/numkong/geospatial/haswell.h +365 -325
- package/include/numkong/geospatial/neon.h +350 -306
- package/include/numkong/geospatial/rvv.h +4 -4
- package/include/numkong/geospatial/skylake.h +376 -340
- package/include/numkong/geospatial/v128relaxed.h +366 -327
- package/include/numkong/geospatial.h +17 -17
- package/include/numkong/matrix.hpp +4 -4
- package/include/numkong/maxsim/README.md +14 -14
- package/include/numkong/maxsim/alder.h +6 -6
- package/include/numkong/maxsim/genoa.h +4 -4
- package/include/numkong/maxsim/haswell.h +6 -6
- package/include/numkong/maxsim/icelake.h +18 -18
- package/include/numkong/maxsim/neonsdot.h +21 -21
- package/include/numkong/maxsim/sapphireamx.h +14 -14
- package/include/numkong/maxsim/serial.h +6 -6
- package/include/numkong/maxsim/sme.h +221 -196
- package/include/numkong/maxsim/v128relaxed.h +6 -6
- package/include/numkong/mesh/README.md +62 -56
- package/include/numkong/mesh/haswell.h +339 -464
- package/include/numkong/mesh/neon.h +1100 -519
- package/include/numkong/mesh/neonbfdot.h +36 -68
- package/include/numkong/mesh/rvv.h +530 -435
- package/include/numkong/mesh/serial.h +75 -91
- package/include/numkong/mesh/skylake.h +1627 -302
- package/include/numkong/mesh/v128relaxed.h +443 -330
- package/include/numkong/mesh.h +63 -49
- package/include/numkong/mesh.hpp +4 -4
- package/include/numkong/numkong.h +3 -3
- package/include/numkong/numkong.hpp +1 -0
- package/include/numkong/probability/README.md +23 -19
- package/include/numkong/probability/neon.h +82 -52
- package/include/numkong/probability/rvv.h +28 -23
- package/include/numkong/probability/serial.h +51 -39
- package/include/numkong/probability.h +20 -23
- package/include/numkong/random.h +1 -1
- package/include/numkong/reduce/README.md +143 -138
- package/include/numkong/reduce/alder.h +81 -77
- package/include/numkong/reduce/haswell.h +222 -220
- package/include/numkong/reduce/neon.h +629 -519
- package/include/numkong/reduce/neonbfdot.h +7 -218
- package/include/numkong/reduce/neonfhm.h +9 -381
- package/include/numkong/reduce/neonsdot.h +9 -9
- package/include/numkong/reduce/rvv.h +928 -802
- package/include/numkong/reduce/serial.h +23 -27
- package/include/numkong/reduce/sierra.h +20 -20
- package/include/numkong/reduce/skylake.h +326 -324
- package/include/numkong/reduce/v128relaxed.h +52 -52
- package/include/numkong/reduce.h +4 -23
- package/include/numkong/reduce.hpp +156 -11
- package/include/numkong/scalar/README.md +6 -6
- package/include/numkong/scalar/haswell.h +26 -17
- package/include/numkong/scalar/loongsonasx.h +74 -0
- package/include/numkong/scalar/neon.h +9 -9
- package/include/numkong/scalar/powervsx.h +96 -0
- package/include/numkong/scalar/rvv.h +2 -2
- package/include/numkong/scalar/sapphire.h +21 -10
- package/include/numkong/scalar/serial.h +21 -21
- package/include/numkong/scalar.h +13 -0
- package/include/numkong/set/README.md +28 -28
- package/include/numkong/set/haswell.h +12 -12
- package/include/numkong/set/icelake.h +14 -14
- package/include/numkong/set/loongsonasx.h +181 -0
- package/include/numkong/set/neon.h +17 -18
- package/include/numkong/set/powervsx.h +326 -0
- package/include/numkong/set/rvv.h +4 -4
- package/include/numkong/set/serial.h +6 -6
- package/include/numkong/set/sve.h +60 -59
- package/include/numkong/set/v128relaxed.h +6 -6
- package/include/numkong/set.h +21 -7
- package/include/numkong/sets/README.md +26 -26
- package/include/numkong/sets/loongsonasx.h +52 -0
- package/include/numkong/sets/powervsx.h +65 -0
- package/include/numkong/sets/smebi32.h +395 -364
- package/include/numkong/sets.h +83 -40
- package/include/numkong/sparse/README.md +4 -4
- package/include/numkong/sparse/icelake.h +101 -101
- package/include/numkong/sparse/serial.h +1 -1
- package/include/numkong/sparse/sve2.h +137 -141
- package/include/numkong/sparse/turin.h +12 -12
- package/include/numkong/sparse.h +10 -10
- package/include/numkong/spatial/README.md +230 -226
- package/include/numkong/spatial/alder.h +113 -116
- package/include/numkong/spatial/diamond.h +240 -0
- package/include/numkong/spatial/genoa.h +0 -68
- package/include/numkong/spatial/haswell.h +74 -55
- package/include/numkong/spatial/icelake.h +539 -58
- package/include/numkong/spatial/loongsonasx.h +483 -0
- package/include/numkong/spatial/neon.h +125 -52
- package/include/numkong/spatial/neonbfdot.h +8 -9
- package/include/numkong/spatial/neonfp8.h +258 -0
- package/include/numkong/spatial/neonsdot.h +180 -12
- package/include/numkong/spatial/powervsx.h +738 -0
- package/include/numkong/spatial/rvv.h +146 -139
- package/include/numkong/spatial/rvvbf16.h +17 -12
- package/include/numkong/spatial/rvvhalf.h +13 -10
- package/include/numkong/spatial/serial.h +13 -12
- package/include/numkong/spatial/sierra.h +232 -39
- package/include/numkong/spatial/skylake.h +73 -74
- package/include/numkong/spatial/sve.h +93 -72
- package/include/numkong/spatial/svebfdot.h +29 -29
- package/include/numkong/spatial/svehalf.h +52 -26
- package/include/numkong/spatial/svesdot.h +142 -0
- package/include/numkong/spatial/v128relaxed.h +293 -41
- package/include/numkong/spatial.h +338 -82
- package/include/numkong/spatials/README.md +194 -194
- package/include/numkong/spatials/diamond.h +82 -0
- package/include/numkong/spatials/haswell.h +2 -2
- package/include/numkong/spatials/loongsonasx.h +153 -0
- package/include/numkong/spatials/neonfp8.h +111 -0
- package/include/numkong/spatials/neonsdot.h +34 -0
- package/include/numkong/spatials/powervsx.h +153 -0
- package/include/numkong/spatials/rvv.h +259 -243
- package/include/numkong/spatials/sapphireamx.h +173 -173
- package/include/numkong/spatials/serial.h +2 -2
- package/include/numkong/spatials/skylake.h +2 -2
- package/include/numkong/spatials/sme.h +590 -605
- package/include/numkong/spatials/smef64.h +139 -130
- package/include/numkong/spatials/v128relaxed.h +2 -2
- package/include/numkong/spatials.h +820 -500
- package/include/numkong/spatials.hpp +49 -48
- package/include/numkong/tensor.hpp +406 -17
- package/include/numkong/trigonometry/README.md +19 -19
- package/include/numkong/trigonometry/haswell.h +402 -401
- package/include/numkong/trigonometry/neon.h +386 -387
- package/include/numkong/trigonometry/rvv.h +52 -51
- package/include/numkong/trigonometry/serial.h +13 -13
- package/include/numkong/trigonometry/skylake.h +373 -369
- package/include/numkong/trigonometry/v128relaxed.h +375 -374
- package/include/numkong/trigonometry.h +13 -13
- package/include/numkong/trigonometry.hpp +2 -2
- package/include/numkong/types.h +287 -49
- package/include/numkong/types.hpp +436 -12
- package/include/numkong/vector.hpp +82 -14
- package/javascript/dist/cjs/numkong-wasm.js +6 -12
- package/javascript/dist/cjs/numkong.d.ts +7 -1
- package/javascript/dist/cjs/numkong.js +37 -11
- package/javascript/dist/cjs/types.d.ts +9 -0
- package/javascript/dist/cjs/types.js +96 -0
- package/javascript/dist/esm/numkong-browser.d.ts +14 -0
- package/javascript/dist/esm/numkong-browser.js +23 -0
- package/javascript/dist/esm/numkong-wasm.js +6 -12
- package/javascript/dist/esm/numkong.d.ts +7 -1
- package/javascript/dist/esm/numkong.js +37 -11
- package/javascript/dist/esm/types.d.ts +9 -0
- package/javascript/dist/esm/types.js +96 -0
- package/javascript/node-gyp-build.d.ts +4 -1
- package/javascript/numkong-browser.ts +40 -0
- package/javascript/numkong-wasm.ts +7 -13
- package/javascript/numkong.c +5 -26
- package/javascript/numkong.ts +36 -11
- package/javascript/tsconfig-base.json +1 -0
- package/javascript/tsconfig-cjs.json +6 -1
- package/javascript/types.ts +110 -0
- package/numkong.gypi +101 -0
- package/package.json +34 -13
- package/probes/arm_neon.c +8 -0
- package/probes/arm_neon_bfdot.c +9 -0
- package/probes/arm_neon_fhm.c +9 -0
- package/probes/arm_neon_half.c +8 -0
- package/probes/arm_neon_sdot.c +9 -0
- package/probes/arm_neonfp8.c +9 -0
- package/probes/arm_sme.c +16 -0
- package/probes/arm_sme2.c +16 -0
- package/probes/arm_sme2p1.c +16 -0
- package/probes/arm_sme_bf16.c +16 -0
- package/probes/arm_sme_bi32.c +16 -0
- package/probes/arm_sme_f64.c +16 -0
- package/probes/arm_sme_fa64.c +14 -0
- package/probes/arm_sme_half.c +16 -0
- package/probes/arm_sme_lut2.c +15 -0
- package/probes/arm_sve.c +18 -0
- package/probes/arm_sve2.c +20 -0
- package/probes/arm_sve2p1.c +18 -0
- package/probes/arm_sve_bfdot.c +20 -0
- package/probes/arm_sve_half.c +18 -0
- package/probes/arm_sve_sdot.c +21 -0
- package/probes/loongarch_lasx.c +12 -0
- package/probes/power_vsx.c +12 -0
- package/probes/probe.js +127 -0
- package/probes/riscv_rvv.c +14 -0
- package/probes/riscv_rvv_bb.c +15 -0
- package/probes/riscv_rvv_bf16.c +17 -0
- package/probes/riscv_rvv_half.c +14 -0
- package/probes/wasm_v128relaxed.c +11 -0
- package/probes/x86_alder.c +17 -0
- package/probes/x86_diamond.c +17 -0
- package/probes/x86_genoa.c +17 -0
- package/probes/x86_graniteamx.c +19 -0
- package/probes/x86_haswell.c +11 -0
- package/probes/x86_icelake.c +17 -0
- package/probes/x86_sapphire.c +16 -0
- package/probes/x86_sapphireamx.c +18 -0
- package/probes/x86_sierra.c +17 -0
- package/probes/x86_skylake.c +15 -0
- package/probes/x86_turin.c +17 -0
- package/wasm/numkong-emscripten.js +2 -0
- package/wasm/numkong.d.ts +14 -0
- package/wasm/numkong.js +1124 -0
- package/wasm/numkong.wasm +0 -0
- package/include/numkong/curved/neonhalf.h +0 -212
- package/include/numkong/dot/neonhalf.h +0 -198
- package/include/numkong/dots/neonhalf.h +0 -57
- package/include/numkong/mesh/neonhalf.h +0 -616
- package/include/numkong/reduce/neonhalf.h +0 -157
- package/include/numkong/spatial/neonhalf.h +0 -118
- package/include/numkong/spatial/sapphire.h +0 -343
- package/include/numkong/spatials/neonhalf.h +0 -58
- package/javascript/README.md +0 -246
|
@@ -77,24 +77,24 @@ static nk_u16_t const nk_e3m2_magnitude_lut_rvv_[32] = {0, 1, 2, 3, 4,
|
|
|
77
77
|
14, 16, 20, 24, 28, 32, 40, 48, 56, 64, 80,
|
|
78
78
|
96, 112, 128, 160, 192, 224, 256, 320, 384, 448};
|
|
79
79
|
|
|
80
|
-
#pragma region
|
|
80
|
+
#pragma region F32 Floats
|
|
81
81
|
|
|
82
82
|
NK_PUBLIC nk_size_t nk_dots_packed_size_f32_rvv(nk_size_t column_count, nk_size_t depth) {
|
|
83
|
-
nk_size_t
|
|
84
|
-
nk_size_t depth_padded = nk_size_round_up_to_multiple_(depth,
|
|
83
|
+
nk_size_t max_vector_length = __riscv_vsetvlmax_e32m2();
|
|
84
|
+
nk_size_t depth_padded = nk_size_round_up_to_multiple_(depth, max_vector_length);
|
|
85
85
|
// Break power-of-2 strides for cache associativity
|
|
86
86
|
nk_size_t stride_bytes = depth_padded * sizeof(nk_f32_t);
|
|
87
|
-
if (stride_bytes > 0 && (stride_bytes & (stride_bytes - 1)) == 0) depth_padded +=
|
|
87
|
+
if (stride_bytes > 0 && (stride_bytes & (stride_bytes - 1)) == 0) depth_padded += max_vector_length;
|
|
88
88
|
return sizeof(nk_cross_packed_buffer_header_t) + column_count * depth_padded * sizeof(nk_f32_t) +
|
|
89
89
|
column_count * sizeof(nk_f64_t); // per-column norms
|
|
90
90
|
}
|
|
91
91
|
|
|
92
92
|
NK_PUBLIC void nk_dots_pack_f32_rvv(nk_f32_t const *b, nk_size_t column_count, nk_size_t depth,
|
|
93
93
|
nk_size_t b_stride_in_bytes, void *b_packed) {
|
|
94
|
-
nk_size_t
|
|
95
|
-
nk_size_t depth_padded = nk_size_round_up_to_multiple_(depth,
|
|
94
|
+
nk_size_t max_vector_length = __riscv_vsetvlmax_e32m2();
|
|
95
|
+
nk_size_t depth_padded = nk_size_round_up_to_multiple_(depth, max_vector_length);
|
|
96
96
|
nk_size_t stride_bytes = depth_padded * sizeof(nk_f32_t);
|
|
97
|
-
if (stride_bytes > 0 && (stride_bytes & (stride_bytes - 1)) == 0) depth_padded +=
|
|
97
|
+
if (stride_bytes > 0 && (stride_bytes & (stride_bytes - 1)) == 0) depth_padded += max_vector_length;
|
|
98
98
|
|
|
99
99
|
nk_cross_packed_buffer_header_t *header = (nk_cross_packed_buffer_header_t *)b_packed;
|
|
100
100
|
header->column_count = (nk_u32_t)column_count;
|
|
@@ -103,12 +103,24 @@ NK_PUBLIC void nk_dots_pack_f32_rvv(nk_f32_t const *b, nk_size_t column_count, n
|
|
|
103
103
|
|
|
104
104
|
nk_f32_t *packed = (nk_f32_t *)((char *)b_packed + sizeof(nk_cross_packed_buffer_header_t));
|
|
105
105
|
nk_size_t total = column_count * depth_padded;
|
|
106
|
-
|
|
106
|
+
{
|
|
107
|
+
nk_u8_t *zero_ptr = (nk_u8_t *)packed;
|
|
108
|
+
nk_size_t total_bytes = total * sizeof(nk_f32_t);
|
|
109
|
+
for (nk_size_t i = 0; i < total_bytes;) {
|
|
110
|
+
nk_size_t vector_length = __riscv_vsetvl_e8m8(total_bytes - i);
|
|
111
|
+
__riscv_vse8_v_u8m8(zero_ptr + i, __riscv_vmv_v_x_u8m8(0, vector_length), vector_length);
|
|
112
|
+
i += vector_length;
|
|
113
|
+
}
|
|
114
|
+
}
|
|
107
115
|
|
|
108
116
|
for (nk_size_t column = 0; column < column_count; ++column) {
|
|
109
117
|
nk_f32_t const *src = (nk_f32_t const *)((char const *)b + column * b_stride_in_bytes);
|
|
110
118
|
nk_f32_t *dst = packed + column * depth_padded;
|
|
111
|
-
for (nk_size_t k = 0; k < depth;
|
|
119
|
+
for (nk_size_t k = 0; k < depth;) {
|
|
120
|
+
nk_size_t vector_length = __riscv_vsetvl_e32m8(depth - k);
|
|
121
|
+
__riscv_vse32_v_f32m8(dst + k, __riscv_vle32_v_f32m8(src + k, vector_length), vector_length);
|
|
122
|
+
k += vector_length;
|
|
123
|
+
}
|
|
112
124
|
}
|
|
113
125
|
|
|
114
126
|
// Append per-column norms after packed data
|
|
@@ -158,11 +170,11 @@ NK_INTERNAL void nk_dots_packed_f32_rvv_aligned_(nk_f32_t const *a_matrix, void
|
|
|
158
170
|
|
|
159
171
|
for (nk_size_t column = 0; column < column_count; ++column) {
|
|
160
172
|
nk_f32_t const *b_column = packed_data + column * depth_padded;
|
|
161
|
-
nk_size_t
|
|
162
|
-
vfloat64m4_t accumulator_0_f64m4 = __riscv_vfmv_v_f_f64m4(0.0,
|
|
163
|
-
vfloat64m4_t accumulator_1_f64m4 = __riscv_vfmv_v_f_f64m4(0.0,
|
|
164
|
-
vfloat64m4_t accumulator_2_f64m4 = __riscv_vfmv_v_f_f64m4(0.0,
|
|
165
|
-
vfloat64m4_t accumulator_3_f64m4 = __riscv_vfmv_v_f_f64m4(0.0,
|
|
173
|
+
nk_size_t max_vector_length = __riscv_vsetvlmax_e32m2();
|
|
174
|
+
vfloat64m4_t accumulator_0_f64m4 = __riscv_vfmv_v_f_f64m4(0.0, max_vector_length);
|
|
175
|
+
vfloat64m4_t accumulator_1_f64m4 = __riscv_vfmv_v_f_f64m4(0.0, max_vector_length);
|
|
176
|
+
vfloat64m4_t accumulator_2_f64m4 = __riscv_vfmv_v_f_f64m4(0.0, max_vector_length);
|
|
177
|
+
vfloat64m4_t accumulator_3_f64m4 = __riscv_vfmv_v_f_f64m4(0.0, max_vector_length);
|
|
166
178
|
|
|
167
179
|
nk_size_t remaining = depth;
|
|
168
180
|
nk_size_t k = 0;
|
|
@@ -186,13 +198,13 @@ NK_INTERNAL void nk_dots_packed_f32_rvv_aligned_(nk_f32_t const *a_matrix, void
|
|
|
186
198
|
// Horizontal reduce directly to f64
|
|
187
199
|
vfloat64m1_t zero_f64m1 = __riscv_vfmv_v_f_f64m1(0.0, 1);
|
|
188
200
|
c_row_0[column] = __riscv_vfmv_f_s_f64m1_f64(
|
|
189
|
-
__riscv_vfredusum_vs_f64m4_f64m1(accumulator_0_f64m4, zero_f64m1,
|
|
201
|
+
__riscv_vfredusum_vs_f64m4_f64m1(accumulator_0_f64m4, zero_f64m1, max_vector_length));
|
|
190
202
|
c_row_1[column] = __riscv_vfmv_f_s_f64m1_f64(
|
|
191
|
-
__riscv_vfredusum_vs_f64m4_f64m1(accumulator_1_f64m4, zero_f64m1,
|
|
203
|
+
__riscv_vfredusum_vs_f64m4_f64m1(accumulator_1_f64m4, zero_f64m1, max_vector_length));
|
|
192
204
|
c_row_2[column] = __riscv_vfmv_f_s_f64m1_f64(
|
|
193
|
-
__riscv_vfredusum_vs_f64m4_f64m1(accumulator_2_f64m4, zero_f64m1,
|
|
205
|
+
__riscv_vfredusum_vs_f64m4_f64m1(accumulator_2_f64m4, zero_f64m1, max_vector_length));
|
|
194
206
|
c_row_3[column] = __riscv_vfmv_f_s_f64m1_f64(
|
|
195
|
-
__riscv_vfredusum_vs_f64m4_f64m1(accumulator_3_f64m4, zero_f64m1,
|
|
207
|
+
__riscv_vfredusum_vs_f64m4_f64m1(accumulator_3_f64m4, zero_f64m1, max_vector_length));
|
|
196
208
|
}
|
|
197
209
|
}
|
|
198
210
|
// Remainder rows (mr < 4)
|
|
@@ -201,8 +213,8 @@ NK_INTERNAL void nk_dots_packed_f32_rvv_aligned_(nk_f32_t const *a_matrix, void
|
|
|
201
213
|
nk_f64_t *c_row = (nk_f64_t *)((char *)c_matrix + row * c_stride_in_bytes);
|
|
202
214
|
for (nk_size_t column = 0; column < column_count; ++column) {
|
|
203
215
|
nk_f32_t const *b_column = packed_data + column * depth_padded;
|
|
204
|
-
nk_size_t
|
|
205
|
-
vfloat64m4_t accumulator_f64m4 = __riscv_vfmv_v_f_f64m4(0.0,
|
|
216
|
+
nk_size_t max_vector_length = __riscv_vsetvlmax_e32m2();
|
|
217
|
+
vfloat64m4_t accumulator_f64m4 = __riscv_vfmv_v_f_f64m4(0.0, max_vector_length);
|
|
206
218
|
nk_size_t remaining = depth;
|
|
207
219
|
nk_size_t k = 0;
|
|
208
220
|
for (nk_size_t vector_length = 0; remaining > 0; remaining -= vector_length, k += vector_length) {
|
|
@@ -214,7 +226,7 @@ NK_INTERNAL void nk_dots_packed_f32_rvv_aligned_(nk_f32_t const *a_matrix, void
|
|
|
214
226
|
}
|
|
215
227
|
vfloat64m1_t zero_f64m1 = __riscv_vfmv_v_f_f64m1(0.0, 1);
|
|
216
228
|
c_row[column] = __riscv_vfmv_f_s_f64m1_f64(
|
|
217
|
-
__riscv_vfredusum_vs_f64m4_f64m1(accumulator_f64m4, zero_f64m1,
|
|
229
|
+
__riscv_vfredusum_vs_f64m4_f64m1(accumulator_f64m4, zero_f64m1, max_vector_length));
|
|
218
230
|
}
|
|
219
231
|
}
|
|
220
232
|
}
|
|
@@ -225,9 +237,10 @@ NK_INTERNAL void nk_dots_packed_f32_rvv_aligned_(nk_f32_t const *a_matrix, void
|
|
|
225
237
|
* Dispatches to the aligned kernel for all cases — RVV's `vsetvl` handles partial
|
|
226
238
|
* vectors naturally, so no separate edge kernel is needed.
|
|
227
239
|
*/
|
|
228
|
-
NK_PUBLIC void nk_dots_packed_f32_rvv(nk_f32_t const *a, void const *b_packed, nk_f64_t *c, nk_size_t
|
|
229
|
-
nk_size_t
|
|
230
|
-
|
|
240
|
+
NK_PUBLIC void nk_dots_packed_f32_rvv(nk_f32_t const *a, void const *b_packed, nk_f64_t *c, nk_size_t rows,
|
|
241
|
+
nk_size_t columns, nk_size_t depth, nk_size_t a_stride_in_bytes,
|
|
242
|
+
nk_size_t c_stride_in_bytes) {
|
|
243
|
+
nk_dots_packed_f32_rvv_aligned_(a, b_packed, c, rows, columns, depth, a_stride_in_bytes, c_stride_in_bytes);
|
|
231
244
|
}
|
|
232
245
|
|
|
233
246
|
/**
|
|
@@ -236,19 +249,19 @@ NK_PUBLIC void nk_dots_packed_f32_rvv(nk_f32_t const *a, void const *b_packed, n
|
|
|
236
249
|
* Uses f64 widened accumulation via `vfwmacc_vv_f64m4` for precision.
|
|
237
250
|
* Processes only the rows in [row_start, row_start + row_count) for parallelism.
|
|
238
251
|
*/
|
|
239
|
-
NK_PUBLIC void nk_dots_symmetric_f32_rvv(nk_f32_t const *vectors, nk_size_t
|
|
240
|
-
nk_size_t
|
|
252
|
+
NK_PUBLIC void nk_dots_symmetric_f32_rvv(nk_f32_t const *vectors, nk_size_t vectors_count, nk_size_t depth,
|
|
253
|
+
nk_size_t stride_in_bytes, nk_f64_t *result, nk_size_t result_stride_in_bytes,
|
|
241
254
|
nk_size_t row_start, nk_size_t row_count) {
|
|
242
|
-
nk_size_t const stride_elements =
|
|
243
|
-
nk_size_t const result_stride_elements =
|
|
244
|
-
nk_size_t const row_end = (row_start + row_count <
|
|
255
|
+
nk_size_t const stride_elements = stride_in_bytes / sizeof(nk_f32_t);
|
|
256
|
+
nk_size_t const result_stride_elements = result_stride_in_bytes / sizeof(nk_f64_t);
|
|
257
|
+
nk_size_t const row_end = (row_start + row_count < vectors_count) ? (row_start + row_count) : vectors_count;
|
|
245
258
|
|
|
246
259
|
for (nk_size_t i = row_start; i < row_end; ++i) {
|
|
247
260
|
nk_f32_t const *a_i = vectors + i * stride_elements;
|
|
248
|
-
for (nk_size_t j = i; j <
|
|
261
|
+
for (nk_size_t j = i; j < vectors_count; ++j) {
|
|
249
262
|
nk_f32_t const *a_j = vectors + j * stride_elements;
|
|
250
|
-
nk_size_t
|
|
251
|
-
vfloat64m4_t accumulator_f64m4 = __riscv_vfmv_v_f_f64m4(0.0,
|
|
263
|
+
nk_size_t max_vector_length = __riscv_vsetvlmax_e32m2();
|
|
264
|
+
vfloat64m4_t accumulator_f64m4 = __riscv_vfmv_v_f_f64m4(0.0, max_vector_length);
|
|
252
265
|
nk_size_t remaining = depth;
|
|
253
266
|
nk_size_t k = 0;
|
|
254
267
|
for (nk_size_t vector_length = 0; remaining > 0; remaining -= vector_length, k += vector_length) {
|
|
@@ -260,31 +273,31 @@ NK_PUBLIC void nk_dots_symmetric_f32_rvv(nk_f32_t const *vectors, nk_size_t n_ve
|
|
|
260
273
|
}
|
|
261
274
|
vfloat64m1_t zero_f64m1 = __riscv_vfmv_v_f_f64m1(0.0, 1);
|
|
262
275
|
nk_f64_t dot = __riscv_vfmv_f_s_f64m1_f64(
|
|
263
|
-
__riscv_vfredusum_vs_f64m4_f64m1(accumulator_f64m4, zero_f64m1,
|
|
276
|
+
__riscv_vfredusum_vs_f64m4_f64m1(accumulator_f64m4, zero_f64m1, max_vector_length));
|
|
264
277
|
result[i * result_stride_elements + j] = dot;
|
|
265
278
|
}
|
|
266
279
|
}
|
|
267
280
|
}
|
|
268
281
|
|
|
269
|
-
#pragma endregion
|
|
282
|
+
#pragma endregion F32 Floats
|
|
270
283
|
|
|
271
|
-
#pragma region
|
|
284
|
+
#pragma region F64 Floats
|
|
272
285
|
|
|
273
286
|
NK_PUBLIC nk_size_t nk_dots_packed_size_f64_rvv(nk_size_t column_count, nk_size_t depth) {
|
|
274
|
-
nk_size_t
|
|
275
|
-
nk_size_t depth_padded = nk_size_round_up_to_multiple_(depth,
|
|
287
|
+
nk_size_t max_vector_length = __riscv_vsetvlmax_e64m4();
|
|
288
|
+
nk_size_t depth_padded = nk_size_round_up_to_multiple_(depth, max_vector_length);
|
|
276
289
|
nk_size_t stride_bytes = depth_padded * sizeof(nk_f64_t);
|
|
277
|
-
if (stride_bytes > 0 && (stride_bytes & (stride_bytes - 1)) == 0) depth_padded +=
|
|
290
|
+
if (stride_bytes > 0 && (stride_bytes & (stride_bytes - 1)) == 0) depth_padded += max_vector_length;
|
|
278
291
|
return sizeof(nk_cross_packed_buffer_header_t) + column_count * depth_padded * sizeof(nk_f64_t) +
|
|
279
292
|
column_count * sizeof(nk_f64_t); // per-column norms
|
|
280
293
|
}
|
|
281
294
|
|
|
282
295
|
NK_PUBLIC void nk_dots_pack_f64_rvv(nk_f64_t const *b, nk_size_t column_count, nk_size_t depth,
|
|
283
296
|
nk_size_t b_stride_in_bytes, void *b_packed) {
|
|
284
|
-
nk_size_t
|
|
285
|
-
nk_size_t depth_padded = nk_size_round_up_to_multiple_(depth,
|
|
297
|
+
nk_size_t max_vector_length = __riscv_vsetvlmax_e64m4();
|
|
298
|
+
nk_size_t depth_padded = nk_size_round_up_to_multiple_(depth, max_vector_length);
|
|
286
299
|
nk_size_t stride_bytes = depth_padded * sizeof(nk_f64_t);
|
|
287
|
-
if (stride_bytes > 0 && (stride_bytes & (stride_bytes - 1)) == 0) depth_padded +=
|
|
300
|
+
if (stride_bytes > 0 && (stride_bytes & (stride_bytes - 1)) == 0) depth_padded += max_vector_length;
|
|
288
301
|
|
|
289
302
|
nk_cross_packed_buffer_header_t *header = (nk_cross_packed_buffer_header_t *)b_packed;
|
|
290
303
|
header->column_count = (nk_u32_t)column_count;
|
|
@@ -293,12 +306,24 @@ NK_PUBLIC void nk_dots_pack_f64_rvv(nk_f64_t const *b, nk_size_t column_count, n
|
|
|
293
306
|
|
|
294
307
|
nk_f64_t *packed = (nk_f64_t *)((char *)b_packed + sizeof(nk_cross_packed_buffer_header_t));
|
|
295
308
|
nk_size_t total = column_count * depth_padded;
|
|
296
|
-
|
|
309
|
+
{
|
|
310
|
+
nk_u8_t *zero_ptr = (nk_u8_t *)packed;
|
|
311
|
+
nk_size_t total_bytes = total * sizeof(nk_f64_t);
|
|
312
|
+
for (nk_size_t i = 0; i < total_bytes;) {
|
|
313
|
+
nk_size_t vector_length = __riscv_vsetvl_e8m8(total_bytes - i);
|
|
314
|
+
__riscv_vse8_v_u8m8(zero_ptr + i, __riscv_vmv_v_x_u8m8(0, vector_length), vector_length);
|
|
315
|
+
i += vector_length;
|
|
316
|
+
}
|
|
317
|
+
}
|
|
297
318
|
|
|
298
319
|
for (nk_size_t column = 0; column < column_count; ++column) {
|
|
299
320
|
nk_f64_t const *src = (nk_f64_t const *)((char const *)b + column * b_stride_in_bytes);
|
|
300
321
|
nk_f64_t *dst = packed + column * depth_padded;
|
|
301
|
-
for (nk_size_t k = 0; k < depth;
|
|
322
|
+
for (nk_size_t k = 0; k < depth;) {
|
|
323
|
+
nk_size_t vector_length = __riscv_vsetvl_e64m8(depth - k);
|
|
324
|
+
__riscv_vse64_v_f64m8(dst + k, __riscv_vle64_v_f64m8(src + k, vector_length), vector_length);
|
|
325
|
+
k += vector_length;
|
|
326
|
+
}
|
|
302
327
|
}
|
|
303
328
|
|
|
304
329
|
// Append per-column norms after packed data
|
|
@@ -341,11 +366,11 @@ NK_INTERNAL void nk_dots_packed_f64_rvv_aligned_(nk_f64_t const *a_matrix, void
|
|
|
341
366
|
|
|
342
367
|
for (nk_size_t column = 0; column < column_count; ++column) {
|
|
343
368
|
nk_f64_t const *b_column = packed_data + column * depth_padded;
|
|
344
|
-
nk_size_t
|
|
345
|
-
vfloat64m4_t accumulator_0_f64m4 = __riscv_vfmv_v_f_f64m4(0.0,
|
|
346
|
-
vfloat64m4_t accumulator_1_f64m4 = __riscv_vfmv_v_f_f64m4(0.0,
|
|
347
|
-
vfloat64m4_t compensation_0_f64m4 = __riscv_vfmv_v_f_f64m4(0.0,
|
|
348
|
-
vfloat64m4_t compensation_1_f64m4 = __riscv_vfmv_v_f_f64m4(0.0,
|
|
369
|
+
nk_size_t max_vector_length = __riscv_vsetvlmax_e64m4();
|
|
370
|
+
vfloat64m4_t accumulator_0_f64m4 = __riscv_vfmv_v_f_f64m4(0.0, max_vector_length);
|
|
371
|
+
vfloat64m4_t accumulator_1_f64m4 = __riscv_vfmv_v_f_f64m4(0.0, max_vector_length);
|
|
372
|
+
vfloat64m4_t compensation_0_f64m4 = __riscv_vfmv_v_f_f64m4(0.0, max_vector_length);
|
|
373
|
+
vfloat64m4_t compensation_1_f64m4 = __riscv_vfmv_v_f_f64m4(0.0, max_vector_length);
|
|
349
374
|
|
|
350
375
|
nk_size_t remaining = depth;
|
|
351
376
|
nk_size_t k = 0;
|
|
@@ -384,9 +409,9 @@ NK_INTERNAL void nk_dots_packed_f64_rvv_aligned_(nk_f64_t const *a_matrix, void
|
|
|
384
409
|
// Horizontal reduce
|
|
385
410
|
vfloat64m1_t zero_f64m1 = __riscv_vfmv_v_f_f64m1(0.0, 1);
|
|
386
411
|
c_row_0[column] = __riscv_vfmv_f_s_f64m1_f64(
|
|
387
|
-
__riscv_vfredusum_vs_f64m4_f64m1(accumulator_0_f64m4, zero_f64m1,
|
|
412
|
+
__riscv_vfredusum_vs_f64m4_f64m1(accumulator_0_f64m4, zero_f64m1, max_vector_length));
|
|
388
413
|
c_row_1[column] = __riscv_vfmv_f_s_f64m1_f64(
|
|
389
|
-
__riscv_vfredusum_vs_f64m4_f64m1(accumulator_1_f64m4, zero_f64m1,
|
|
414
|
+
__riscv_vfredusum_vs_f64m4_f64m1(accumulator_1_f64m4, zero_f64m1, max_vector_length));
|
|
390
415
|
}
|
|
391
416
|
}
|
|
392
417
|
// Remainder rows
|
|
@@ -395,9 +420,9 @@ NK_INTERNAL void nk_dots_packed_f64_rvv_aligned_(nk_f64_t const *a_matrix, void
|
|
|
395
420
|
nk_f64_t *c_row = (nk_f64_t *)((char *)c_matrix + row * c_stride_in_bytes);
|
|
396
421
|
for (nk_size_t column = 0; column < column_count; ++column) {
|
|
397
422
|
nk_f64_t const *b_column = packed_data + column * depth_padded;
|
|
398
|
-
nk_size_t
|
|
399
|
-
vfloat64m4_t accumulator_f64m4 = __riscv_vfmv_v_f_f64m4(0.0,
|
|
400
|
-
vfloat64m4_t compensation_f64m4 = __riscv_vfmv_v_f_f64m4(0.0,
|
|
423
|
+
nk_size_t max_vector_length = __riscv_vsetvlmax_e64m4();
|
|
424
|
+
vfloat64m4_t accumulator_f64m4 = __riscv_vfmv_v_f_f64m4(0.0, max_vector_length);
|
|
425
|
+
vfloat64m4_t compensation_f64m4 = __riscv_vfmv_v_f_f64m4(0.0, max_vector_length);
|
|
401
426
|
|
|
402
427
|
nk_size_t remaining = depth;
|
|
403
428
|
nk_size_t k = 0;
|
|
@@ -419,7 +444,7 @@ NK_INTERNAL void nk_dots_packed_f64_rvv_aligned_(nk_f64_t const *a_matrix, void
|
|
|
419
444
|
|
|
420
445
|
vfloat64m1_t zero_f64m1 = __riscv_vfmv_v_f_f64m1(0.0, 1);
|
|
421
446
|
c_row[column] = __riscv_vfmv_f_s_f64m1_f64(
|
|
422
|
-
__riscv_vfredusum_vs_f64m4_f64m1(accumulator_f64m4, zero_f64m1,
|
|
447
|
+
__riscv_vfredusum_vs_f64m4_f64m1(accumulator_f64m4, zero_f64m1, max_vector_length));
|
|
423
448
|
}
|
|
424
449
|
}
|
|
425
450
|
}
|
|
@@ -427,9 +452,10 @@ NK_INTERNAL void nk_dots_packed_f64_rvv_aligned_(nk_f64_t const *a_matrix, void
|
|
|
427
452
|
/**
|
|
428
453
|
* @brief Public f64 packed GEMM wrapper matching the declared signature in dots.h.
|
|
429
454
|
*/
|
|
430
|
-
NK_PUBLIC void nk_dots_packed_f64_rvv(nk_f64_t const *a, void const *b_packed, nk_f64_t *c, nk_size_t
|
|
431
|
-
nk_size_t
|
|
432
|
-
|
|
455
|
+
NK_PUBLIC void nk_dots_packed_f64_rvv(nk_f64_t const *a, void const *b_packed, nk_f64_t *c, nk_size_t rows,
|
|
456
|
+
nk_size_t columns, nk_size_t depth, nk_size_t a_stride_in_bytes,
|
|
457
|
+
nk_size_t c_stride_in_bytes) {
|
|
458
|
+
nk_dots_packed_f64_rvv_aligned_(a, b_packed, c, rows, columns, depth, a_stride_in_bytes, c_stride_in_bytes);
|
|
433
459
|
}
|
|
434
460
|
|
|
435
461
|
/**
|
|
@@ -438,20 +464,20 @@ NK_PUBLIC void nk_dots_packed_f64_rvv(nk_f64_t const *a, void const *b_packed, n
|
|
|
438
464
|
* Uses Kahan compensation over full depth for precision.
|
|
439
465
|
* Processes only the rows in [row_start, row_start + row_count) for parallelism.
|
|
440
466
|
*/
|
|
441
|
-
NK_PUBLIC void nk_dots_symmetric_f64_rvv(nk_f64_t const *vectors, nk_size_t
|
|
442
|
-
nk_size_t
|
|
467
|
+
NK_PUBLIC void nk_dots_symmetric_f64_rvv(nk_f64_t const *vectors, nk_size_t vectors_count, nk_size_t depth,
|
|
468
|
+
nk_size_t stride_in_bytes, nk_f64_t *result, nk_size_t result_stride_in_bytes,
|
|
443
469
|
nk_size_t row_start, nk_size_t row_count) {
|
|
444
|
-
nk_size_t const stride_elements =
|
|
445
|
-
nk_size_t const result_stride_elements =
|
|
446
|
-
nk_size_t const row_end = (row_start + row_count <
|
|
470
|
+
nk_size_t const stride_elements = stride_in_bytes / sizeof(nk_f64_t);
|
|
471
|
+
nk_size_t const result_stride_elements = result_stride_in_bytes / sizeof(nk_f64_t);
|
|
472
|
+
nk_size_t const row_end = (row_start + row_count < vectors_count) ? (row_start + row_count) : vectors_count;
|
|
447
473
|
|
|
448
474
|
for (nk_size_t i = row_start; i < row_end; ++i) {
|
|
449
475
|
nk_f64_t const *a_i = vectors + i * stride_elements;
|
|
450
|
-
for (nk_size_t j = i; j <
|
|
476
|
+
for (nk_size_t j = i; j < vectors_count; ++j) {
|
|
451
477
|
nk_f64_t const *a_j = vectors + j * stride_elements;
|
|
452
|
-
nk_size_t
|
|
453
|
-
vfloat64m4_t accumulator_f64m4 = __riscv_vfmv_v_f_f64m4(0.0,
|
|
454
|
-
vfloat64m4_t compensation_f64m4 = __riscv_vfmv_v_f_f64m4(0.0,
|
|
478
|
+
nk_size_t max_vector_length = __riscv_vsetvlmax_e64m4();
|
|
479
|
+
vfloat64m4_t accumulator_f64m4 = __riscv_vfmv_v_f_f64m4(0.0, max_vector_length);
|
|
480
|
+
vfloat64m4_t compensation_f64m4 = __riscv_vfmv_v_f_f64m4(0.0, max_vector_length);
|
|
455
481
|
|
|
456
482
|
nk_size_t remaining = depth;
|
|
457
483
|
nk_size_t k = 0;
|
|
@@ -473,15 +499,15 @@ NK_PUBLIC void nk_dots_symmetric_f64_rvv(nk_f64_t const *vectors, nk_size_t n_ve
|
|
|
473
499
|
|
|
474
500
|
vfloat64m1_t zero_f64m1 = __riscv_vfmv_v_f_f64m1(0.0, 1);
|
|
475
501
|
nk_f64_t dot = __riscv_vfmv_f_s_f64m1_f64(
|
|
476
|
-
__riscv_vfredusum_vs_f64m4_f64m1(accumulator_f64m4, zero_f64m1,
|
|
502
|
+
__riscv_vfredusum_vs_f64m4_f64m1(accumulator_f64m4, zero_f64m1, max_vector_length));
|
|
477
503
|
result[i * result_stride_elements + j] = dot;
|
|
478
504
|
}
|
|
479
505
|
}
|
|
480
506
|
}
|
|
481
507
|
|
|
482
|
-
#pragma endregion
|
|
508
|
+
#pragma endregion F64 Floats
|
|
483
509
|
|
|
484
|
-
#pragma region
|
|
510
|
+
#pragma region E2M3 Floats
|
|
485
511
|
|
|
486
512
|
/**
|
|
487
513
|
* @brief Scalar conversion helper: e2m3 byte → signed i8 (value × 16).
|
|
@@ -496,10 +522,10 @@ NK_INTERNAL nk_i8_t nk_e2m3_to_i8_rvv_(nk_u8_t raw) {
|
|
|
496
522
|
}
|
|
497
523
|
|
|
498
524
|
NK_PUBLIC nk_size_t nk_dots_packed_size_e2m3_rvv(nk_size_t column_count, nk_size_t depth) {
|
|
499
|
-
nk_size_t
|
|
500
|
-
nk_size_t depth_padded = nk_size_round_up_to_multiple_(depth,
|
|
525
|
+
nk_size_t max_vector_length = __riscv_vsetvlmax_e8m1();
|
|
526
|
+
nk_size_t depth_padded = nk_size_round_up_to_multiple_(depth, max_vector_length);
|
|
501
527
|
nk_size_t stride_bytes = depth_padded * sizeof(nk_i8_t);
|
|
502
|
-
if (stride_bytes > 0 && (stride_bytes & (stride_bytes - 1)) == 0) depth_padded +=
|
|
528
|
+
if (stride_bytes > 0 && (stride_bytes & (stride_bytes - 1)) == 0) depth_padded += max_vector_length;
|
|
503
529
|
return sizeof(nk_cross_packed_buffer_header_t) + column_count * depth_padded * sizeof(nk_i8_t) +
|
|
504
530
|
column_count * sizeof(nk_f32_t); // per-column norms
|
|
505
531
|
}
|
|
@@ -512,10 +538,10 @@ NK_PUBLIC nk_size_t nk_dots_packed_size_e2m3_rvv(nk_size_t column_count, nk_size
|
|
|
512
538
|
*/
|
|
513
539
|
NK_PUBLIC void nk_dots_pack_e2m3_rvv(nk_e2m3_t const *b, nk_size_t column_count, nk_size_t depth,
|
|
514
540
|
nk_size_t b_stride_in_bytes, void *b_packed) {
|
|
515
|
-
nk_size_t
|
|
516
|
-
nk_size_t depth_padded = nk_size_round_up_to_multiple_(depth,
|
|
541
|
+
nk_size_t max_vector_length = __riscv_vsetvlmax_e8m1();
|
|
542
|
+
nk_size_t depth_padded = nk_size_round_up_to_multiple_(depth, max_vector_length);
|
|
517
543
|
nk_size_t stride_bytes = depth_padded * sizeof(nk_i8_t);
|
|
518
|
-
if (stride_bytes > 0 && (stride_bytes & (stride_bytes - 1)) == 0) depth_padded +=
|
|
544
|
+
if (stride_bytes > 0 && (stride_bytes & (stride_bytes - 1)) == 0) depth_padded += max_vector_length;
|
|
519
545
|
|
|
520
546
|
nk_cross_packed_buffer_header_t *header = (nk_cross_packed_buffer_header_t *)b_packed;
|
|
521
547
|
header->column_count = (nk_u32_t)column_count;
|
|
@@ -524,7 +550,15 @@ NK_PUBLIC void nk_dots_pack_e2m3_rvv(nk_e2m3_t const *b, nk_size_t column_count,
|
|
|
524
550
|
|
|
525
551
|
nk_i8_t *packed = (nk_i8_t *)((char *)b_packed + sizeof(nk_cross_packed_buffer_header_t));
|
|
526
552
|
nk_size_t total = column_count * depth_padded;
|
|
527
|
-
|
|
553
|
+
{
|
|
554
|
+
nk_u8_t *zero_ptr = (nk_u8_t *)packed;
|
|
555
|
+
nk_size_t total_bytes = total * sizeof(nk_i8_t);
|
|
556
|
+
for (nk_size_t i = 0; i < total_bytes;) {
|
|
557
|
+
nk_size_t vector_length = __riscv_vsetvl_e8m8(total_bytes - i);
|
|
558
|
+
__riscv_vse8_v_u8m8(zero_ptr + i, __riscv_vmv_v_x_u8m8(0, vector_length), vector_length);
|
|
559
|
+
i += vector_length;
|
|
560
|
+
}
|
|
561
|
+
}
|
|
528
562
|
|
|
529
563
|
for (nk_size_t column = 0; column < column_count; ++column) {
|
|
530
564
|
nk_u8_t const *src = (nk_u8_t const *)((char const *)b + column * b_stride_in_bytes);
|
|
@@ -584,11 +618,11 @@ NK_INTERNAL void nk_dots_packed_e2m3_rvv_aligned_(nk_e2m3_t const *a_matrix, voi
|
|
|
584
618
|
|
|
585
619
|
for (nk_size_t column = 0; column < column_count; ++column) {
|
|
586
620
|
nk_i8_t const *b_column = packed_data + column * depth_padded;
|
|
587
|
-
nk_size_t
|
|
588
|
-
vint32m4_t accumulator_0_i32m4 = __riscv_vmv_v_x_i32m4(0,
|
|
589
|
-
vint32m4_t accumulator_1_i32m4 = __riscv_vmv_v_x_i32m4(0,
|
|
590
|
-
vint32m4_t accumulator_2_i32m4 = __riscv_vmv_v_x_i32m4(0,
|
|
591
|
-
vint32m4_t accumulator_3_i32m4 = __riscv_vmv_v_x_i32m4(0,
|
|
621
|
+
nk_size_t max_vector_length = __riscv_vsetvlmax_e32m4();
|
|
622
|
+
vint32m4_t accumulator_0_i32m4 = __riscv_vmv_v_x_i32m4(0, max_vector_length);
|
|
623
|
+
vint32m4_t accumulator_1_i32m4 = __riscv_vmv_v_x_i32m4(0, max_vector_length);
|
|
624
|
+
vint32m4_t accumulator_2_i32m4 = __riscv_vmv_v_x_i32m4(0, max_vector_length);
|
|
625
|
+
vint32m4_t accumulator_3_i32m4 = __riscv_vmv_v_x_i32m4(0, max_vector_length);
|
|
592
626
|
|
|
593
627
|
nk_size_t remaining = depth;
|
|
594
628
|
nk_size_t k = 0;
|
|
@@ -654,16 +688,16 @@ NK_INTERNAL void nk_dots_packed_e2m3_rvv_aligned_(nk_e2m3_t const *a_matrix, voi
|
|
|
654
688
|
// Horizontal reduce and convert to f32 with scaling
|
|
655
689
|
vint32m1_t zero_i32m1 = __riscv_vmv_v_x_i32m1(0, 1);
|
|
656
690
|
c_row_0[column] = (nk_f32_t)__riscv_vmv_x_s_i32m1_i32(
|
|
657
|
-
__riscv_vredsum_vs_i32m4_i32m1(accumulator_0_i32m4, zero_i32m1,
|
|
691
|
+
__riscv_vredsum_vs_i32m4_i32m1(accumulator_0_i32m4, zero_i32m1, max_vector_length)) *
|
|
658
692
|
lut_scale_reciprocal;
|
|
659
693
|
c_row_1[column] = (nk_f32_t)__riscv_vmv_x_s_i32m1_i32(
|
|
660
|
-
__riscv_vredsum_vs_i32m4_i32m1(accumulator_1_i32m4, zero_i32m1,
|
|
694
|
+
__riscv_vredsum_vs_i32m4_i32m1(accumulator_1_i32m4, zero_i32m1, max_vector_length)) *
|
|
661
695
|
lut_scale_reciprocal;
|
|
662
696
|
c_row_2[column] = (nk_f32_t)__riscv_vmv_x_s_i32m1_i32(
|
|
663
|
-
__riscv_vredsum_vs_i32m4_i32m1(accumulator_2_i32m4, zero_i32m1,
|
|
697
|
+
__riscv_vredsum_vs_i32m4_i32m1(accumulator_2_i32m4, zero_i32m1, max_vector_length)) *
|
|
664
698
|
lut_scale_reciprocal;
|
|
665
699
|
c_row_3[column] = (nk_f32_t)__riscv_vmv_x_s_i32m1_i32(
|
|
666
|
-
__riscv_vredsum_vs_i32m4_i32m1(accumulator_3_i32m4, zero_i32m1,
|
|
700
|
+
__riscv_vredsum_vs_i32m4_i32m1(accumulator_3_i32m4, zero_i32m1, max_vector_length)) *
|
|
667
701
|
lut_scale_reciprocal;
|
|
668
702
|
}
|
|
669
703
|
}
|
|
@@ -673,8 +707,8 @@ NK_INTERNAL void nk_dots_packed_e2m3_rvv_aligned_(nk_e2m3_t const *a_matrix, voi
|
|
|
673
707
|
nk_f32_t *c_row = (nk_f32_t *)((char *)c_matrix + row * c_stride_in_bytes);
|
|
674
708
|
for (nk_size_t column = 0; column < column_count; ++column) {
|
|
675
709
|
nk_i8_t const *b_column = packed_data + column * depth_padded;
|
|
676
|
-
nk_size_t
|
|
677
|
-
vint32m4_t accumulator_i32m4 = __riscv_vmv_v_x_i32m4(0,
|
|
710
|
+
nk_size_t max_vector_length = __riscv_vsetvlmax_e32m4();
|
|
711
|
+
vint32m4_t accumulator_i32m4 = __riscv_vmv_v_x_i32m4(0, max_vector_length);
|
|
678
712
|
nk_size_t remaining = depth;
|
|
679
713
|
nk_size_t k = 0;
|
|
680
714
|
for (nk_size_t vector_length = 0; remaining > 0; remaining -= vector_length, k += vector_length) {
|
|
@@ -693,7 +727,7 @@ NK_INTERNAL void nk_dots_packed_e2m3_rvv_aligned_(nk_e2m3_t const *a_matrix, voi
|
|
|
693
727
|
}
|
|
694
728
|
vint32m1_t zero_i32m1 = __riscv_vmv_v_x_i32m1(0, 1);
|
|
695
729
|
c_row[column] = (nk_f32_t)__riscv_vmv_x_s_i32m1_i32(
|
|
696
|
-
__riscv_vredsum_vs_i32m4_i32m1(accumulator_i32m4, zero_i32m1,
|
|
730
|
+
__riscv_vredsum_vs_i32m4_i32m1(accumulator_i32m4, zero_i32m1, max_vector_length)) *
|
|
697
731
|
lut_scale_reciprocal;
|
|
698
732
|
}
|
|
699
733
|
}
|
|
@@ -702,9 +736,10 @@ NK_INTERNAL void nk_dots_packed_e2m3_rvv_aligned_(nk_e2m3_t const *a_matrix, voi
|
|
|
702
736
|
/**
|
|
703
737
|
* @brief Public e2m3 packed GEMM wrapper matching the declared signature in dots.h.
|
|
704
738
|
*/
|
|
705
|
-
NK_PUBLIC void nk_dots_packed_e2m3_rvv(nk_e2m3_t const *a, void const *b_packed, nk_f32_t *c, nk_size_t
|
|
706
|
-
nk_size_t
|
|
707
|
-
|
|
739
|
+
NK_PUBLIC void nk_dots_packed_e2m3_rvv(nk_e2m3_t const *a, void const *b_packed, nk_f32_t *c, nk_size_t rows,
|
|
740
|
+
nk_size_t columns, nk_size_t depth, nk_size_t a_stride_in_bytes,
|
|
741
|
+
nk_size_t c_stride_in_bytes) {
|
|
742
|
+
nk_dots_packed_e2m3_rvv_aligned_(a, b_packed, c, rows, columns, depth, a_stride_in_bytes, c_stride_in_bytes);
|
|
708
743
|
}
|
|
709
744
|
|
|
710
745
|
/**
|
|
@@ -713,20 +748,20 @@ NK_PUBLIC void nk_dots_packed_e2m3_rvv(nk_e2m3_t const *a, void const *b_packed,
|
|
|
713
748
|
* Uses integer i8 LUT arithmetic with i32 accumulation, scaled by 1/256.
|
|
714
749
|
* Processes only the rows in [row_start, row_start + row_count) for parallelism.
|
|
715
750
|
*/
|
|
716
|
-
NK_PUBLIC void nk_dots_symmetric_e2m3_rvv(nk_e2m3_t const *vectors, nk_size_t
|
|
717
|
-
nk_size_t
|
|
751
|
+
NK_PUBLIC void nk_dots_symmetric_e2m3_rvv(nk_e2m3_t const *vectors, nk_size_t vectors_count, nk_size_t depth,
|
|
752
|
+
nk_size_t stride_in_bytes, nk_f32_t *result, nk_size_t result_stride_in_bytes,
|
|
718
753
|
nk_size_t row_start, nk_size_t row_count) {
|
|
719
754
|
nk_f32_t const lut_scale_reciprocal = 1.0f / 256.0f;
|
|
720
755
|
|
|
721
|
-
nk_size_t const result_stride_elements =
|
|
722
|
-
nk_size_t const row_end = (row_start + row_count <
|
|
756
|
+
nk_size_t const result_stride_elements = result_stride_in_bytes / sizeof(nk_f32_t);
|
|
757
|
+
nk_size_t const row_end = (row_start + row_count < vectors_count) ? (row_start + row_count) : vectors_count;
|
|
723
758
|
|
|
724
759
|
for (nk_size_t i = row_start; i < row_end; ++i) {
|
|
725
|
-
nk_u8_t const *a_i = (nk_u8_t const *)vectors + i *
|
|
726
|
-
for (nk_size_t j = i; j <
|
|
727
|
-
nk_u8_t const *a_j = (nk_u8_t const *)vectors + j *
|
|
728
|
-
nk_size_t
|
|
729
|
-
vint32m4_t accumulator_i32m4 = __riscv_vmv_v_x_i32m4(0,
|
|
760
|
+
nk_u8_t const *a_i = (nk_u8_t const *)vectors + i * stride_in_bytes;
|
|
761
|
+
for (nk_size_t j = i; j < vectors_count; ++j) {
|
|
762
|
+
nk_u8_t const *a_j = (nk_u8_t const *)vectors + j * stride_in_bytes;
|
|
763
|
+
nk_size_t max_vector_length = __riscv_vsetvlmax_e32m4();
|
|
764
|
+
vint32m4_t accumulator_i32m4 = __riscv_vmv_v_x_i32m4(0, max_vector_length);
|
|
730
765
|
nk_size_t remaining = depth;
|
|
731
766
|
nk_size_t k = 0;
|
|
732
767
|
for (nk_size_t vector_length = 0; remaining > 0; remaining -= vector_length, k += vector_length) {
|
|
@@ -755,16 +790,16 @@ NK_PUBLIC void nk_dots_symmetric_e2m3_rvv(nk_e2m3_t const *vectors, nk_size_t n_
|
|
|
755
790
|
}
|
|
756
791
|
vint32m1_t zero_i32m1 = __riscv_vmv_v_x_i32m1(0, 1);
|
|
757
792
|
nk_f32_t dot = (nk_f32_t)__riscv_vmv_x_s_i32m1_i32(
|
|
758
|
-
__riscv_vredsum_vs_i32m4_i32m1(accumulator_i32m4, zero_i32m1,
|
|
793
|
+
__riscv_vredsum_vs_i32m4_i32m1(accumulator_i32m4, zero_i32m1, max_vector_length)) *
|
|
759
794
|
lut_scale_reciprocal;
|
|
760
795
|
result[i * result_stride_elements + j] = dot;
|
|
761
796
|
}
|
|
762
797
|
}
|
|
763
798
|
}
|
|
764
799
|
|
|
765
|
-
#pragma endregion
|
|
800
|
+
#pragma endregion E2M3 Floats
|
|
766
801
|
|
|
767
|
-
#pragma region
|
|
802
|
+
#pragma region E3M2 Floats
|
|
768
803
|
|
|
769
804
|
/**
|
|
770
805
|
* @brief Scalar conversion helper: e3m2 byte → signed i16 (value × 16).
|
|
@@ -779,10 +814,10 @@ NK_INTERNAL nk_i16_t nk_e3m2_to_i16_rvv_(nk_u8_t raw) {
|
|
|
779
814
|
}
|
|
780
815
|
|
|
781
816
|
NK_PUBLIC nk_size_t nk_dots_packed_size_e3m2_rvv(nk_size_t column_count, nk_size_t depth) {
|
|
782
|
-
nk_size_t
|
|
783
|
-
nk_size_t depth_padded = nk_size_round_up_to_multiple_(depth,
|
|
817
|
+
nk_size_t max_vector_length = __riscv_vsetvlmax_e16m2();
|
|
818
|
+
nk_size_t depth_padded = nk_size_round_up_to_multiple_(depth, max_vector_length);
|
|
784
819
|
nk_size_t stride_bytes = depth_padded * sizeof(nk_i16_t);
|
|
785
|
-
if (stride_bytes > 0 && (stride_bytes & (stride_bytes - 1)) == 0) depth_padded +=
|
|
820
|
+
if (stride_bytes > 0 && (stride_bytes & (stride_bytes - 1)) == 0) depth_padded += max_vector_length;
|
|
786
821
|
return sizeof(nk_cross_packed_buffer_header_t) + column_count * depth_padded * sizeof(nk_i16_t) +
|
|
787
822
|
column_count * sizeof(nk_f32_t); // per-column norms
|
|
788
823
|
}
|
|
@@ -795,10 +830,10 @@ NK_PUBLIC nk_size_t nk_dots_packed_size_e3m2_rvv(nk_size_t column_count, nk_size
|
|
|
795
830
|
*/
|
|
796
831
|
NK_PUBLIC void nk_dots_pack_e3m2_rvv(nk_e3m2_t const *b, nk_size_t column_count, nk_size_t depth,
|
|
797
832
|
nk_size_t b_stride_in_bytes, void *b_packed) {
|
|
798
|
-
nk_size_t
|
|
799
|
-
nk_size_t depth_padded = nk_size_round_up_to_multiple_(depth,
|
|
833
|
+
nk_size_t max_vector_length = __riscv_vsetvlmax_e16m2();
|
|
834
|
+
nk_size_t depth_padded = nk_size_round_up_to_multiple_(depth, max_vector_length);
|
|
800
835
|
nk_size_t stride_bytes = depth_padded * sizeof(nk_i16_t);
|
|
801
|
-
if (stride_bytes > 0 && (stride_bytes & (stride_bytes - 1)) == 0) depth_padded +=
|
|
836
|
+
if (stride_bytes > 0 && (stride_bytes & (stride_bytes - 1)) == 0) depth_padded += max_vector_length;
|
|
802
837
|
|
|
803
838
|
nk_cross_packed_buffer_header_t *header = (nk_cross_packed_buffer_header_t *)b_packed;
|
|
804
839
|
header->column_count = (nk_u32_t)column_count;
|
|
@@ -807,7 +842,15 @@ NK_PUBLIC void nk_dots_pack_e3m2_rvv(nk_e3m2_t const *b, nk_size_t column_count,
|
|
|
807
842
|
|
|
808
843
|
nk_i16_t *packed = (nk_i16_t *)((char *)b_packed + sizeof(nk_cross_packed_buffer_header_t));
|
|
809
844
|
nk_size_t total = column_count * depth_padded;
|
|
810
|
-
|
|
845
|
+
{
|
|
846
|
+
nk_u8_t *zero_ptr = (nk_u8_t *)packed;
|
|
847
|
+
nk_size_t total_bytes = total * sizeof(nk_i16_t);
|
|
848
|
+
for (nk_size_t i = 0; i < total_bytes;) {
|
|
849
|
+
nk_size_t vector_length = __riscv_vsetvl_e8m8(total_bytes - i);
|
|
850
|
+
__riscv_vse8_v_u8m8(zero_ptr + i, __riscv_vmv_v_x_u8m8(0, vector_length), vector_length);
|
|
851
|
+
i += vector_length;
|
|
852
|
+
}
|
|
853
|
+
}
|
|
811
854
|
|
|
812
855
|
for (nk_size_t column = 0; column < column_count; ++column) {
|
|
813
856
|
nk_u8_t const *src = (nk_u8_t const *)((char const *)b + column * b_stride_in_bytes);
|
|
@@ -862,9 +905,9 @@ NK_INTERNAL void nk_dots_packed_e3m2_rvv_aligned_(nk_e3m2_t const *a_matrix, voi
|
|
|
862
905
|
|
|
863
906
|
for (nk_size_t column = 0; column < column_count; ++column) {
|
|
864
907
|
nk_i16_t const *b_column = packed_data + column * depth_padded;
|
|
865
|
-
nk_size_t
|
|
866
|
-
vint32m4_t accumulator_0_i32m4 = __riscv_vmv_v_x_i32m4(0,
|
|
867
|
-
vint32m4_t accumulator_1_i32m4 = __riscv_vmv_v_x_i32m4(0,
|
|
908
|
+
nk_size_t max_vector_length = __riscv_vsetvlmax_e32m4();
|
|
909
|
+
vint32m4_t accumulator_0_i32m4 = __riscv_vmv_v_x_i32m4(0, max_vector_length);
|
|
910
|
+
vint32m4_t accumulator_1_i32m4 = __riscv_vmv_v_x_i32m4(0, max_vector_length);
|
|
868
911
|
|
|
869
912
|
nk_size_t remaining = depth;
|
|
870
913
|
nk_size_t k = 0;
|
|
@@ -916,10 +959,10 @@ NK_INTERNAL void nk_dots_packed_e3m2_rvv_aligned_(nk_e3m2_t const *a_matrix, voi
|
|
|
916
959
|
// Horizontal reduce and convert to f32 with scaling
|
|
917
960
|
vint32m1_t zero_i32m1 = __riscv_vmv_v_x_i32m1(0, 1);
|
|
918
961
|
c_row_0[column] = (nk_f32_t)__riscv_vmv_x_s_i32m1_i32(
|
|
919
|
-
__riscv_vredsum_vs_i32m4_i32m1(accumulator_0_i32m4, zero_i32m1,
|
|
962
|
+
__riscv_vredsum_vs_i32m4_i32m1(accumulator_0_i32m4, zero_i32m1, max_vector_length)) *
|
|
920
963
|
lut_scale_reciprocal;
|
|
921
964
|
c_row_1[column] = (nk_f32_t)__riscv_vmv_x_s_i32m1_i32(
|
|
922
|
-
__riscv_vredsum_vs_i32m4_i32m1(accumulator_1_i32m4, zero_i32m1,
|
|
965
|
+
__riscv_vredsum_vs_i32m4_i32m1(accumulator_1_i32m4, zero_i32m1, max_vector_length)) *
|
|
923
966
|
lut_scale_reciprocal;
|
|
924
967
|
}
|
|
925
968
|
}
|
|
@@ -929,8 +972,8 @@ NK_INTERNAL void nk_dots_packed_e3m2_rvv_aligned_(nk_e3m2_t const *a_matrix, voi
|
|
|
929
972
|
nk_f32_t *c_row = (nk_f32_t *)((char *)c_matrix + row * c_stride_in_bytes);
|
|
930
973
|
for (nk_size_t column = 0; column < column_count; ++column) {
|
|
931
974
|
nk_i16_t const *b_column = packed_data + column * depth_padded;
|
|
932
|
-
nk_size_t
|
|
933
|
-
vint32m4_t accumulator_i32m4 = __riscv_vmv_v_x_i32m4(0,
|
|
975
|
+
nk_size_t max_vector_length = __riscv_vsetvlmax_e32m4();
|
|
976
|
+
vint32m4_t accumulator_i32m4 = __riscv_vmv_v_x_i32m4(0, max_vector_length);
|
|
934
977
|
nk_size_t remaining = depth;
|
|
935
978
|
nk_size_t k = 0;
|
|
936
979
|
for (nk_size_t vector_length = 0; remaining > 0; remaining -= vector_length, k += vector_length) {
|
|
@@ -951,7 +994,7 @@ NK_INTERNAL void nk_dots_packed_e3m2_rvv_aligned_(nk_e3m2_t const *a_matrix, voi
|
|
|
951
994
|
}
|
|
952
995
|
vint32m1_t zero_i32m1 = __riscv_vmv_v_x_i32m1(0, 1);
|
|
953
996
|
c_row[column] = (nk_f32_t)__riscv_vmv_x_s_i32m1_i32(
|
|
954
|
-
__riscv_vredsum_vs_i32m4_i32m1(accumulator_i32m4, zero_i32m1,
|
|
997
|
+
__riscv_vredsum_vs_i32m4_i32m1(accumulator_i32m4, zero_i32m1, max_vector_length)) *
|
|
955
998
|
lut_scale_reciprocal;
|
|
956
999
|
}
|
|
957
1000
|
}
|
|
@@ -960,9 +1003,10 @@ NK_INTERNAL void nk_dots_packed_e3m2_rvv_aligned_(nk_e3m2_t const *a_matrix, voi
|
|
|
960
1003
|
/**
|
|
961
1004
|
* @brief Public e3m2 packed GEMM wrapper matching the declared signature in dots.h.
|
|
962
1005
|
*/
|
|
963
|
-
NK_PUBLIC void nk_dots_packed_e3m2_rvv(nk_e3m2_t const *a, void const *b_packed, nk_f32_t *c, nk_size_t
|
|
964
|
-
nk_size_t
|
|
965
|
-
|
|
1006
|
+
NK_PUBLIC void nk_dots_packed_e3m2_rvv(nk_e3m2_t const *a, void const *b_packed, nk_f32_t *c, nk_size_t rows,
|
|
1007
|
+
nk_size_t columns, nk_size_t depth, nk_size_t a_stride_in_bytes,
|
|
1008
|
+
nk_size_t c_stride_in_bytes) {
|
|
1009
|
+
nk_dots_packed_e3m2_rvv_aligned_(a, b_packed, c, rows, columns, depth, a_stride_in_bytes, c_stride_in_bytes);
|
|
966
1010
|
}
|
|
967
1011
|
|
|
968
1012
|
/**
|
|
@@ -971,20 +1015,20 @@ NK_PUBLIC void nk_dots_packed_e3m2_rvv(nk_e3m2_t const *a, void const *b_packed,
|
|
|
971
1015
|
* Uses integer i16 LUT arithmetic with i32 widening MAC, scaled by 1/256.
|
|
972
1016
|
* Processes only the rows in [row_start, row_start + row_count) for parallelism.
|
|
973
1017
|
*/
|
|
974
|
-
NK_PUBLIC void nk_dots_symmetric_e3m2_rvv(nk_e3m2_t const *vectors, nk_size_t
|
|
975
|
-
nk_size_t
|
|
1018
|
+
NK_PUBLIC void nk_dots_symmetric_e3m2_rvv(nk_e3m2_t const *vectors, nk_size_t vectors_count, nk_size_t depth,
|
|
1019
|
+
nk_size_t stride_in_bytes, nk_f32_t *result, nk_size_t result_stride_in_bytes,
|
|
976
1020
|
nk_size_t row_start, nk_size_t row_count) {
|
|
977
1021
|
nk_f32_t const lut_scale_reciprocal = 1.0f / 256.0f;
|
|
978
1022
|
|
|
979
|
-
nk_size_t const result_stride_elements =
|
|
980
|
-
nk_size_t const row_end = (row_start + row_count <
|
|
1023
|
+
nk_size_t const result_stride_elements = result_stride_in_bytes / sizeof(nk_f32_t);
|
|
1024
|
+
nk_size_t const row_end = (row_start + row_count < vectors_count) ? (row_start + row_count) : vectors_count;
|
|
981
1025
|
|
|
982
1026
|
for (nk_size_t i = row_start; i < row_end; ++i) {
|
|
983
|
-
nk_u8_t const *a_i = (nk_u8_t const *)vectors + i *
|
|
984
|
-
for (nk_size_t j = i; j <
|
|
985
|
-
nk_u8_t const *a_j = (nk_u8_t const *)vectors + j *
|
|
986
|
-
nk_size_t
|
|
987
|
-
vint32m4_t accumulator_i32m4 = __riscv_vmv_v_x_i32m4(0,
|
|
1027
|
+
nk_u8_t const *a_i = (nk_u8_t const *)vectors + i * stride_in_bytes;
|
|
1028
|
+
for (nk_size_t j = i; j < vectors_count; ++j) {
|
|
1029
|
+
nk_u8_t const *a_j = (nk_u8_t const *)vectors + j * stride_in_bytes;
|
|
1030
|
+
nk_size_t max_vector_length = __riscv_vsetvlmax_e32m4();
|
|
1031
|
+
vint32m4_t accumulator_i32m4 = __riscv_vmv_v_x_i32m4(0, max_vector_length);
|
|
988
1032
|
nk_size_t remaining = depth;
|
|
989
1033
|
nk_size_t k = 0;
|
|
990
1034
|
for (nk_size_t vector_length = 0; remaining > 0; remaining -= vector_length, k += vector_length) {
|
|
@@ -1023,16 +1067,16 @@ NK_PUBLIC void nk_dots_symmetric_e3m2_rvv(nk_e3m2_t const *vectors, nk_size_t n_
|
|
|
1023
1067
|
}
|
|
1024
1068
|
vint32m1_t zero_i32m1 = __riscv_vmv_v_x_i32m1(0, 1);
|
|
1025
1069
|
nk_f32_t dot = (nk_f32_t)__riscv_vmv_x_s_i32m1_i32(
|
|
1026
|
-
__riscv_vredsum_vs_i32m4_i32m1(accumulator_i32m4, zero_i32m1,
|
|
1070
|
+
__riscv_vredsum_vs_i32m4_i32m1(accumulator_i32m4, zero_i32m1, max_vector_length)) *
|
|
1027
1071
|
lut_scale_reciprocal;
|
|
1028
1072
|
result[i * result_stride_elements + j] = dot;
|
|
1029
1073
|
}
|
|
1030
1074
|
}
|
|
1031
1075
|
}
|
|
1032
1076
|
|
|
1033
|
-
#pragma endregion
|
|
1077
|
+
#pragma endregion E3M2 Floats
|
|
1034
1078
|
|
|
1035
|
-
#pragma region
|
|
1079
|
+
#pragma region BF16 Floats
|
|
1036
1080
|
|
|
1037
1081
|
/**
|
|
1038
1082
|
* @brief Compute the packed buffer size for bf16 GEMM (B stored as f32).
|
|
@@ -1041,11 +1085,11 @@ NK_PUBLIC void nk_dots_symmetric_e3m2_rvv(nk_e3m2_t const *vectors, nk_size_t n_
|
|
|
1041
1085
|
* Layout: column-panel with depth-contiguous f32 values, cache-line padding.
|
|
1042
1086
|
*/
|
|
1043
1087
|
NK_PUBLIC nk_size_t nk_dots_packed_size_bf16_rvv(nk_size_t column_count, nk_size_t depth) {
|
|
1044
|
-
nk_size_t
|
|
1045
|
-
nk_size_t depth_padded = nk_size_round_up_to_multiple_(depth,
|
|
1088
|
+
nk_size_t max_vector_length = __riscv_vsetvlmax_e32m2();
|
|
1089
|
+
nk_size_t depth_padded = nk_size_round_up_to_multiple_(depth, max_vector_length);
|
|
1046
1090
|
// Break power-of-2 strides for cache associativity
|
|
1047
1091
|
nk_size_t stride_bytes = depth_padded * sizeof(nk_f32_t);
|
|
1048
|
-
if (stride_bytes > 0 && (stride_bytes & (stride_bytes - 1)) == 0) depth_padded +=
|
|
1092
|
+
if (stride_bytes > 0 && (stride_bytes & (stride_bytes - 1)) == 0) depth_padded += max_vector_length;
|
|
1049
1093
|
return sizeof(nk_cross_packed_buffer_header_t) + column_count * depth_padded * sizeof(nk_f32_t) +
|
|
1050
1094
|
column_count * sizeof(nk_f32_t); // per-column norms
|
|
1051
1095
|
}
|
|
@@ -1058,10 +1102,10 @@ NK_PUBLIC nk_size_t nk_dots_packed_size_bf16_rvv(nk_size_t column_count, nk_size
|
|
|
1058
1102
|
*/
|
|
1059
1103
|
NK_PUBLIC void nk_dots_pack_bf16_rvv(nk_bf16_t const *b, nk_size_t column_count, nk_size_t depth,
|
|
1060
1104
|
nk_size_t b_stride_in_bytes, void *b_packed) {
|
|
1061
|
-
nk_size_t
|
|
1062
|
-
nk_size_t depth_padded = nk_size_round_up_to_multiple_(depth,
|
|
1105
|
+
nk_size_t max_vector_length = __riscv_vsetvlmax_e32m2();
|
|
1106
|
+
nk_size_t depth_padded = nk_size_round_up_to_multiple_(depth, max_vector_length);
|
|
1063
1107
|
nk_size_t stride_bytes = depth_padded * sizeof(nk_f32_t);
|
|
1064
|
-
if (stride_bytes > 0 && (stride_bytes & (stride_bytes - 1)) == 0) depth_padded +=
|
|
1108
|
+
if (stride_bytes > 0 && (stride_bytes & (stride_bytes - 1)) == 0) depth_padded += max_vector_length;
|
|
1065
1109
|
|
|
1066
1110
|
nk_cross_packed_buffer_header_t *header = (nk_cross_packed_buffer_header_t *)b_packed;
|
|
1067
1111
|
header->column_count = (nk_u32_t)column_count;
|
|
@@ -1070,7 +1114,15 @@ NK_PUBLIC void nk_dots_pack_bf16_rvv(nk_bf16_t const *b, nk_size_t column_count,
|
|
|
1070
1114
|
|
|
1071
1115
|
nk_f32_t *packed = (nk_f32_t *)((char *)b_packed + sizeof(nk_cross_packed_buffer_header_t));
|
|
1072
1116
|
nk_size_t total = column_count * depth_padded;
|
|
1073
|
-
|
|
1117
|
+
{
|
|
1118
|
+
nk_u8_t *zero_ptr = (nk_u8_t *)packed;
|
|
1119
|
+
nk_size_t total_bytes = total * sizeof(nk_f32_t);
|
|
1120
|
+
for (nk_size_t i = 0; i < total_bytes;) {
|
|
1121
|
+
nk_size_t vector_length = __riscv_vsetvl_e8m8(total_bytes - i);
|
|
1122
|
+
__riscv_vse8_v_u8m8(zero_ptr + i, __riscv_vmv_v_x_u8m8(0, vector_length), vector_length);
|
|
1123
|
+
i += vector_length;
|
|
1124
|
+
}
|
|
1125
|
+
}
|
|
1074
1126
|
|
|
1075
1127
|
for (nk_size_t column = 0; column < column_count; ++column) {
|
|
1076
1128
|
nk_u16_t const *src = (nk_u16_t const *)((char const *)b + column * b_stride_in_bytes);
|
|
@@ -1133,11 +1185,11 @@ NK_INTERNAL void nk_dots_packed_bf16_rvv_aligned_(nk_bf16_t const *a_matrix, voi
|
|
|
1133
1185
|
|
|
1134
1186
|
for (nk_size_t column = 0; column < column_count; ++column) {
|
|
1135
1187
|
nk_f32_t const *b_column = packed_data + column * depth_padded;
|
|
1136
|
-
nk_size_t
|
|
1137
|
-
vfloat64m4_t accumulator_0_f64m4 = __riscv_vfmv_v_f_f64m4(0.0,
|
|
1138
|
-
vfloat64m4_t accumulator_1_f64m4 = __riscv_vfmv_v_f_f64m4(0.0,
|
|
1139
|
-
vfloat64m4_t accumulator_2_f64m4 = __riscv_vfmv_v_f_f64m4(0.0,
|
|
1140
|
-
vfloat64m4_t accumulator_3_f64m4 = __riscv_vfmv_v_f_f64m4(0.0,
|
|
1188
|
+
nk_size_t max_vector_length = __riscv_vsetvlmax_e32m2();
|
|
1189
|
+
vfloat64m4_t accumulator_0_f64m4 = __riscv_vfmv_v_f_f64m4(0.0, max_vector_length);
|
|
1190
|
+
vfloat64m4_t accumulator_1_f64m4 = __riscv_vfmv_v_f_f64m4(0.0, max_vector_length);
|
|
1191
|
+
vfloat64m4_t accumulator_2_f64m4 = __riscv_vfmv_v_f_f64m4(0.0, max_vector_length);
|
|
1192
|
+
vfloat64m4_t accumulator_3_f64m4 = __riscv_vfmv_v_f_f64m4(0.0, max_vector_length);
|
|
1141
1193
|
|
|
1142
1194
|
nk_size_t remaining = depth;
|
|
1143
1195
|
nk_size_t k = 0;
|
|
@@ -1166,13 +1218,13 @@ NK_INTERNAL void nk_dots_packed_bf16_rvv_aligned_(nk_bf16_t const *a_matrix, voi
|
|
|
1166
1218
|
// Horizontal reduce and narrow to f32
|
|
1167
1219
|
vfloat64m1_t zero_f64m1 = __riscv_vfmv_v_f_f64m1(0.0, 1);
|
|
1168
1220
|
c_row_0[column] = (nk_f32_t)__riscv_vfmv_f_s_f64m1_f64(
|
|
1169
|
-
__riscv_vfredusum_vs_f64m4_f64m1(accumulator_0_f64m4, zero_f64m1,
|
|
1221
|
+
__riscv_vfredusum_vs_f64m4_f64m1(accumulator_0_f64m4, zero_f64m1, max_vector_length));
|
|
1170
1222
|
c_row_1[column] = (nk_f32_t)__riscv_vfmv_f_s_f64m1_f64(
|
|
1171
|
-
__riscv_vfredusum_vs_f64m4_f64m1(accumulator_1_f64m4, zero_f64m1,
|
|
1223
|
+
__riscv_vfredusum_vs_f64m4_f64m1(accumulator_1_f64m4, zero_f64m1, max_vector_length));
|
|
1172
1224
|
c_row_2[column] = (nk_f32_t)__riscv_vfmv_f_s_f64m1_f64(
|
|
1173
|
-
__riscv_vfredusum_vs_f64m4_f64m1(accumulator_2_f64m4, zero_f64m1,
|
|
1225
|
+
__riscv_vfredusum_vs_f64m4_f64m1(accumulator_2_f64m4, zero_f64m1, max_vector_length));
|
|
1174
1226
|
c_row_3[column] = (nk_f32_t)__riscv_vfmv_f_s_f64m1_f64(
|
|
1175
|
-
__riscv_vfredusum_vs_f64m4_f64m1(accumulator_3_f64m4, zero_f64m1,
|
|
1227
|
+
__riscv_vfredusum_vs_f64m4_f64m1(accumulator_3_f64m4, zero_f64m1, max_vector_length));
|
|
1176
1228
|
}
|
|
1177
1229
|
}
|
|
1178
1230
|
// Remainder rows (mr < 4)
|
|
@@ -1181,8 +1233,8 @@ NK_INTERNAL void nk_dots_packed_bf16_rvv_aligned_(nk_bf16_t const *a_matrix, voi
|
|
|
1181
1233
|
nk_f32_t *c_row = (nk_f32_t *)((char *)c_matrix + row * c_stride_in_bytes);
|
|
1182
1234
|
for (nk_size_t column = 0; column < column_count; ++column) {
|
|
1183
1235
|
nk_f32_t const *b_column = packed_data + column * depth_padded;
|
|
1184
|
-
nk_size_t
|
|
1185
|
-
vfloat64m4_t accumulator_f64m4 = __riscv_vfmv_v_f_f64m4(0.0,
|
|
1236
|
+
nk_size_t max_vector_length = __riscv_vsetvlmax_e32m2();
|
|
1237
|
+
vfloat64m4_t accumulator_f64m4 = __riscv_vfmv_v_f_f64m4(0.0, max_vector_length);
|
|
1186
1238
|
nk_size_t remaining = depth;
|
|
1187
1239
|
nk_size_t k = 0;
|
|
1188
1240
|
for (nk_size_t vector_length = 0; remaining > 0; remaining -= vector_length, k += vector_length) {
|
|
@@ -1195,7 +1247,7 @@ NK_INTERNAL void nk_dots_packed_bf16_rvv_aligned_(nk_bf16_t const *a_matrix, voi
|
|
|
1195
1247
|
}
|
|
1196
1248
|
vfloat64m1_t zero_f64m1 = __riscv_vfmv_v_f_f64m1(0.0, 1);
|
|
1197
1249
|
c_row[column] = (nk_f32_t)__riscv_vfmv_f_s_f64m1_f64(
|
|
1198
|
-
__riscv_vfredusum_vs_f64m4_f64m1(accumulator_f64m4, zero_f64m1,
|
|
1250
|
+
__riscv_vfredusum_vs_f64m4_f64m1(accumulator_f64m4, zero_f64m1, max_vector_length));
|
|
1199
1251
|
}
|
|
1200
1252
|
}
|
|
1201
1253
|
}
|
|
@@ -1206,9 +1258,10 @@ NK_INTERNAL void nk_dots_packed_bf16_rvv_aligned_(nk_bf16_t const *a_matrix, voi
|
|
|
1206
1258
|
* Dispatches to the aligned kernel for all cases — RVV's `vsetvl` handles partial
|
|
1207
1259
|
* vectors naturally, so no separate edge kernel is needed.
|
|
1208
1260
|
*/
|
|
1209
|
-
NK_PUBLIC void nk_dots_packed_bf16_rvv(nk_bf16_t const *a, void const *b_packed, nk_f32_t *c, nk_size_t
|
|
1210
|
-
nk_size_t
|
|
1211
|
-
|
|
1261
|
+
NK_PUBLIC void nk_dots_packed_bf16_rvv(nk_bf16_t const *a, void const *b_packed, nk_f32_t *c, nk_size_t rows,
|
|
1262
|
+
nk_size_t columns, nk_size_t depth, nk_size_t a_stride_in_bytes,
|
|
1263
|
+
nk_size_t c_stride_in_bytes) {
|
|
1264
|
+
nk_dots_packed_bf16_rvv_aligned_(a, b_packed, c, rows, columns, depth, a_stride_in_bytes, c_stride_in_bytes);
|
|
1212
1265
|
}
|
|
1213
1266
|
|
|
1214
1267
|
/**
|
|
@@ -1219,18 +1272,18 @@ NK_PUBLIC void nk_dots_packed_bf16_rvv(nk_bf16_t const *a, void const *b_packed,
|
|
|
1219
1272
|
* Stride is in bytes.
|
|
1220
1273
|
* Processes only the rows in [row_start, row_start + row_count) for parallelism.
|
|
1221
1274
|
*/
|
|
1222
|
-
NK_PUBLIC void nk_dots_symmetric_bf16_rvv(nk_bf16_t const *vectors, nk_size_t
|
|
1223
|
-
nk_size_t
|
|
1275
|
+
NK_PUBLIC void nk_dots_symmetric_bf16_rvv(nk_bf16_t const *vectors, nk_size_t vectors_count, nk_size_t depth,
|
|
1276
|
+
nk_size_t stride_in_bytes, nk_f32_t *result, nk_size_t result_stride_in_bytes,
|
|
1224
1277
|
nk_size_t row_start, nk_size_t row_count) {
|
|
1225
|
-
nk_size_t const result_stride_elements =
|
|
1226
|
-
nk_size_t const row_end = (row_start + row_count <
|
|
1278
|
+
nk_size_t const result_stride_elements = result_stride_in_bytes / sizeof(nk_f32_t);
|
|
1279
|
+
nk_size_t const row_end = (row_start + row_count < vectors_count) ? (row_start + row_count) : vectors_count;
|
|
1227
1280
|
|
|
1228
1281
|
for (nk_size_t i = row_start; i < row_end; ++i) {
|
|
1229
|
-
nk_u16_t const *a_i = (nk_u16_t const *)((char const *)vectors + i *
|
|
1230
|
-
for (nk_size_t j = i; j <
|
|
1231
|
-
nk_u16_t const *a_j = (nk_u16_t const *)((char const *)vectors + j *
|
|
1232
|
-
nk_size_t
|
|
1233
|
-
vfloat64m4_t accumulator_f64m4 = __riscv_vfmv_v_f_f64m4(0.0,
|
|
1282
|
+
nk_u16_t const *a_i = (nk_u16_t const *)((char const *)vectors + i * stride_in_bytes);
|
|
1283
|
+
for (nk_size_t j = i; j < vectors_count; ++j) {
|
|
1284
|
+
nk_u16_t const *a_j = (nk_u16_t const *)((char const *)vectors + j * stride_in_bytes);
|
|
1285
|
+
nk_size_t max_vector_length = __riscv_vsetvlmax_e32m2();
|
|
1286
|
+
vfloat64m4_t accumulator_f64m4 = __riscv_vfmv_v_f_f64m4(0.0, max_vector_length);
|
|
1234
1287
|
nk_size_t remaining = depth;
|
|
1235
1288
|
nk_size_t k = 0;
|
|
1236
1289
|
for (nk_size_t vector_length = 0; remaining > 0; remaining -= vector_length, k += vector_length) {
|
|
@@ -1244,15 +1297,15 @@ NK_PUBLIC void nk_dots_symmetric_bf16_rvv(nk_bf16_t const *vectors, nk_size_t n_
|
|
|
1244
1297
|
}
|
|
1245
1298
|
vfloat64m1_t zero_f64m1 = __riscv_vfmv_v_f_f64m1(0.0, 1);
|
|
1246
1299
|
nk_f32_t dot = (nk_f32_t)__riscv_vfmv_f_s_f64m1_f64(
|
|
1247
|
-
__riscv_vfredusum_vs_f64m4_f64m1(accumulator_f64m4, zero_f64m1,
|
|
1300
|
+
__riscv_vfredusum_vs_f64m4_f64m1(accumulator_f64m4, zero_f64m1, max_vector_length));
|
|
1248
1301
|
result[i * result_stride_elements + j] = dot;
|
|
1249
1302
|
}
|
|
1250
1303
|
}
|
|
1251
1304
|
}
|
|
1252
1305
|
|
|
1253
|
-
#pragma endregion
|
|
1306
|
+
#pragma endregion BF16 Floats
|
|
1254
1307
|
|
|
1255
|
-
#pragma region
|
|
1308
|
+
#pragma region F16 Floats
|
|
1256
1309
|
|
|
1257
1310
|
/**
|
|
1258
1311
|
* @brief Compute the packed buffer size for f16 GEMM (B stored as f32).
|
|
@@ -1261,11 +1314,11 @@ NK_PUBLIC void nk_dots_symmetric_bf16_rvv(nk_bf16_t const *vectors, nk_size_t n_
|
|
|
1261
1314
|
* Layout: column-panel with depth-contiguous f32 values, cache-line padding.
|
|
1262
1315
|
*/
|
|
1263
1316
|
NK_PUBLIC nk_size_t nk_dots_packed_size_f16_rvv(nk_size_t column_count, nk_size_t depth) {
|
|
1264
|
-
nk_size_t
|
|
1265
|
-
nk_size_t depth_padded = nk_size_round_up_to_multiple_(depth,
|
|
1317
|
+
nk_size_t max_vector_length = __riscv_vsetvlmax_e32m2();
|
|
1318
|
+
nk_size_t depth_padded = nk_size_round_up_to_multiple_(depth, max_vector_length);
|
|
1266
1319
|
// Break power-of-2 strides for cache associativity
|
|
1267
1320
|
nk_size_t stride_bytes = depth_padded * sizeof(nk_f32_t);
|
|
1268
|
-
if (stride_bytes > 0 && (stride_bytes & (stride_bytes - 1)) == 0) depth_padded +=
|
|
1321
|
+
if (stride_bytes > 0 && (stride_bytes & (stride_bytes - 1)) == 0) depth_padded += max_vector_length;
|
|
1269
1322
|
return sizeof(nk_cross_packed_buffer_header_t) + column_count * depth_padded * sizeof(nk_f32_t) +
|
|
1270
1323
|
column_count * sizeof(nk_f32_t); // per-column norms
|
|
1271
1324
|
}
|
|
@@ -1278,10 +1331,10 @@ NK_PUBLIC nk_size_t nk_dots_packed_size_f16_rvv(nk_size_t column_count, nk_size_
|
|
|
1278
1331
|
*/
|
|
1279
1332
|
NK_PUBLIC void nk_dots_pack_f16_rvv(nk_f16_t const *b, nk_size_t column_count, nk_size_t depth,
|
|
1280
1333
|
nk_size_t b_stride_in_bytes, void *b_packed) {
|
|
1281
|
-
nk_size_t
|
|
1282
|
-
nk_size_t depth_padded = nk_size_round_up_to_multiple_(depth,
|
|
1334
|
+
nk_size_t max_vector_length = __riscv_vsetvlmax_e32m2();
|
|
1335
|
+
nk_size_t depth_padded = nk_size_round_up_to_multiple_(depth, max_vector_length);
|
|
1283
1336
|
nk_size_t stride_bytes = depth_padded * sizeof(nk_f32_t);
|
|
1284
|
-
if (stride_bytes > 0 && (stride_bytes & (stride_bytes - 1)) == 0) depth_padded +=
|
|
1337
|
+
if (stride_bytes > 0 && (stride_bytes & (stride_bytes - 1)) == 0) depth_padded += max_vector_length;
|
|
1285
1338
|
|
|
1286
1339
|
nk_cross_packed_buffer_header_t *header = (nk_cross_packed_buffer_header_t *)b_packed;
|
|
1287
1340
|
header->column_count = (nk_u32_t)column_count;
|
|
@@ -1290,7 +1343,15 @@ NK_PUBLIC void nk_dots_pack_f16_rvv(nk_f16_t const *b, nk_size_t column_count, n
|
|
|
1290
1343
|
|
|
1291
1344
|
nk_f32_t *packed = (nk_f32_t *)((char *)b_packed + sizeof(nk_cross_packed_buffer_header_t));
|
|
1292
1345
|
nk_size_t total = column_count * depth_padded;
|
|
1293
|
-
|
|
1346
|
+
{
|
|
1347
|
+
nk_u8_t *zero_ptr = (nk_u8_t *)packed;
|
|
1348
|
+
nk_size_t total_bytes = total * sizeof(nk_f32_t);
|
|
1349
|
+
for (nk_size_t i = 0; i < total_bytes;) {
|
|
1350
|
+
nk_size_t vector_length = __riscv_vsetvl_e8m8(total_bytes - i);
|
|
1351
|
+
__riscv_vse8_v_u8m8(zero_ptr + i, __riscv_vmv_v_x_u8m8(0, vector_length), vector_length);
|
|
1352
|
+
i += vector_length;
|
|
1353
|
+
}
|
|
1354
|
+
}
|
|
1294
1355
|
|
|
1295
1356
|
for (nk_size_t column = 0; column < column_count; ++column) {
|
|
1296
1357
|
nk_f16_t const *src = (nk_f16_t const *)((char const *)b + column * b_stride_in_bytes);
|
|
@@ -1346,11 +1407,11 @@ NK_INTERNAL void nk_dots_packed_f16_rvv_aligned_(nk_f16_t const *a_matrix, void
|
|
|
1346
1407
|
|
|
1347
1408
|
for (nk_size_t column = 0; column < column_count; ++column) {
|
|
1348
1409
|
nk_f32_t const *b_column = packed_data + column * depth_padded;
|
|
1349
|
-
nk_size_t
|
|
1350
|
-
vfloat64m4_t accumulator_0_f64m4 = __riscv_vfmv_v_f_f64m4(0.0,
|
|
1351
|
-
vfloat64m4_t accumulator_1_f64m4 = __riscv_vfmv_v_f_f64m4(0.0,
|
|
1352
|
-
vfloat64m4_t accumulator_2_f64m4 = __riscv_vfmv_v_f_f64m4(0.0,
|
|
1353
|
-
vfloat64m4_t accumulator_3_f64m4 = __riscv_vfmv_v_f_f64m4(0.0,
|
|
1410
|
+
nk_size_t max_vector_length = __riscv_vsetvlmax_e32m2();
|
|
1411
|
+
vfloat64m4_t accumulator_0_f64m4 = __riscv_vfmv_v_f_f64m4(0.0, max_vector_length);
|
|
1412
|
+
vfloat64m4_t accumulator_1_f64m4 = __riscv_vfmv_v_f_f64m4(0.0, max_vector_length);
|
|
1413
|
+
vfloat64m4_t accumulator_2_f64m4 = __riscv_vfmv_v_f_f64m4(0.0, max_vector_length);
|
|
1414
|
+
vfloat64m4_t accumulator_3_f64m4 = __riscv_vfmv_v_f_f64m4(0.0, max_vector_length);
|
|
1354
1415
|
|
|
1355
1416
|
nk_size_t remaining = depth;
|
|
1356
1417
|
nk_size_t k = 0;
|
|
@@ -1379,13 +1440,13 @@ NK_INTERNAL void nk_dots_packed_f16_rvv_aligned_(nk_f16_t const *a_matrix, void
|
|
|
1379
1440
|
// Horizontal reduce and narrow to f32
|
|
1380
1441
|
vfloat64m1_t zero_f64m1 = __riscv_vfmv_v_f_f64m1(0.0, 1);
|
|
1381
1442
|
c_row_0[column] = (nk_f32_t)__riscv_vfmv_f_s_f64m1_f64(
|
|
1382
|
-
__riscv_vfredusum_vs_f64m4_f64m1(accumulator_0_f64m4, zero_f64m1,
|
|
1443
|
+
__riscv_vfredusum_vs_f64m4_f64m1(accumulator_0_f64m4, zero_f64m1, max_vector_length));
|
|
1383
1444
|
c_row_1[column] = (nk_f32_t)__riscv_vfmv_f_s_f64m1_f64(
|
|
1384
|
-
__riscv_vfredusum_vs_f64m4_f64m1(accumulator_1_f64m4, zero_f64m1,
|
|
1445
|
+
__riscv_vfredusum_vs_f64m4_f64m1(accumulator_1_f64m4, zero_f64m1, max_vector_length));
|
|
1385
1446
|
c_row_2[column] = (nk_f32_t)__riscv_vfmv_f_s_f64m1_f64(
|
|
1386
|
-
__riscv_vfredusum_vs_f64m4_f64m1(accumulator_2_f64m4, zero_f64m1,
|
|
1447
|
+
__riscv_vfredusum_vs_f64m4_f64m1(accumulator_2_f64m4, zero_f64m1, max_vector_length));
|
|
1387
1448
|
c_row_3[column] = (nk_f32_t)__riscv_vfmv_f_s_f64m1_f64(
|
|
1388
|
-
__riscv_vfredusum_vs_f64m4_f64m1(accumulator_3_f64m4, zero_f64m1,
|
|
1449
|
+
__riscv_vfredusum_vs_f64m4_f64m1(accumulator_3_f64m4, zero_f64m1, max_vector_length));
|
|
1389
1450
|
}
|
|
1390
1451
|
}
|
|
1391
1452
|
// Remainder rows (mr < 4)
|
|
@@ -1394,8 +1455,8 @@ NK_INTERNAL void nk_dots_packed_f16_rvv_aligned_(nk_f16_t const *a_matrix, void
|
|
|
1394
1455
|
nk_f32_t *c_row = (nk_f32_t *)((char *)c_matrix + row * c_stride_in_bytes);
|
|
1395
1456
|
for (nk_size_t column = 0; column < column_count; ++column) {
|
|
1396
1457
|
nk_f32_t const *b_column = packed_data + column * depth_padded;
|
|
1397
|
-
nk_size_t
|
|
1398
|
-
vfloat64m4_t accumulator_f64m4 = __riscv_vfmv_v_f_f64m4(0.0,
|
|
1458
|
+
nk_size_t max_vector_length = __riscv_vsetvlmax_e32m2();
|
|
1459
|
+
vfloat64m4_t accumulator_f64m4 = __riscv_vfmv_v_f_f64m4(0.0, max_vector_length);
|
|
1399
1460
|
nk_size_t remaining = depth;
|
|
1400
1461
|
nk_size_t k = 0;
|
|
1401
1462
|
for (nk_size_t vector_length = 0; remaining > 0; remaining -= vector_length, k += vector_length) {
|
|
@@ -1408,7 +1469,7 @@ NK_INTERNAL void nk_dots_packed_f16_rvv_aligned_(nk_f16_t const *a_matrix, void
|
|
|
1408
1469
|
}
|
|
1409
1470
|
vfloat64m1_t zero_f64m1 = __riscv_vfmv_v_f_f64m1(0.0, 1);
|
|
1410
1471
|
c_row[column] = (nk_f32_t)__riscv_vfmv_f_s_f64m1_f64(
|
|
1411
|
-
__riscv_vfredusum_vs_f64m4_f64m1(accumulator_f64m4, zero_f64m1,
|
|
1472
|
+
__riscv_vfredusum_vs_f64m4_f64m1(accumulator_f64m4, zero_f64m1, max_vector_length));
|
|
1412
1473
|
}
|
|
1413
1474
|
}
|
|
1414
1475
|
}
|
|
@@ -1419,9 +1480,10 @@ NK_INTERNAL void nk_dots_packed_f16_rvv_aligned_(nk_f16_t const *a_matrix, void
|
|
|
1419
1480
|
* Dispatches to the aligned kernel for all cases — RVV's `vsetvl` handles partial
|
|
1420
1481
|
* vectors naturally, so no separate edge kernel is needed.
|
|
1421
1482
|
*/
|
|
1422
|
-
NK_PUBLIC void nk_dots_packed_f16_rvv(nk_f16_t const *a, void const *b_packed, nk_f32_t *c, nk_size_t
|
|
1423
|
-
nk_size_t
|
|
1424
|
-
|
|
1483
|
+
NK_PUBLIC void nk_dots_packed_f16_rvv(nk_f16_t const *a, void const *b_packed, nk_f32_t *c, nk_size_t rows,
|
|
1484
|
+
nk_size_t columns, nk_size_t depth, nk_size_t a_stride_in_bytes,
|
|
1485
|
+
nk_size_t c_stride_in_bytes) {
|
|
1486
|
+
nk_dots_packed_f16_rvv_aligned_(a, b_packed, c, rows, columns, depth, a_stride_in_bytes, c_stride_in_bytes);
|
|
1425
1487
|
}
|
|
1426
1488
|
|
|
1427
1489
|
/**
|
|
@@ -1432,18 +1494,18 @@ NK_PUBLIC void nk_dots_packed_f16_rvv(nk_f16_t const *a, void const *b_packed, n
|
|
|
1432
1494
|
* Stride is in bytes.
|
|
1433
1495
|
* Processes only the rows in [row_start, row_start + row_count) for parallelism.
|
|
1434
1496
|
*/
|
|
1435
|
-
NK_PUBLIC void nk_dots_symmetric_f16_rvv(nk_f16_t const *vectors, nk_size_t
|
|
1436
|
-
nk_size_t
|
|
1497
|
+
NK_PUBLIC void nk_dots_symmetric_f16_rvv(nk_f16_t const *vectors, nk_size_t vectors_count, nk_size_t depth,
|
|
1498
|
+
nk_size_t stride_in_bytes, nk_f32_t *result, nk_size_t result_stride_in_bytes,
|
|
1437
1499
|
nk_size_t row_start, nk_size_t row_count) {
|
|
1438
|
-
nk_size_t const result_stride_elements =
|
|
1439
|
-
nk_size_t const row_end = (row_start + row_count <
|
|
1500
|
+
nk_size_t const result_stride_elements = result_stride_in_bytes / sizeof(nk_f32_t);
|
|
1501
|
+
nk_size_t const row_end = (row_start + row_count < vectors_count) ? (row_start + row_count) : vectors_count;
|
|
1440
1502
|
|
|
1441
1503
|
for (nk_size_t i = row_start; i < row_end; ++i) {
|
|
1442
|
-
nk_u16_t const *a_i = (nk_u16_t const *)((char const *)vectors + i *
|
|
1443
|
-
for (nk_size_t j = i; j <
|
|
1444
|
-
nk_u16_t const *a_j = (nk_u16_t const *)((char const *)vectors + j *
|
|
1445
|
-
nk_size_t
|
|
1446
|
-
vfloat64m4_t accumulator_f64m4 = __riscv_vfmv_v_f_f64m4(0.0,
|
|
1504
|
+
nk_u16_t const *a_i = (nk_u16_t const *)((char const *)vectors + i * stride_in_bytes);
|
|
1505
|
+
for (nk_size_t j = i; j < vectors_count; ++j) {
|
|
1506
|
+
nk_u16_t const *a_j = (nk_u16_t const *)((char const *)vectors + j * stride_in_bytes);
|
|
1507
|
+
nk_size_t max_vector_length = __riscv_vsetvlmax_e32m2();
|
|
1508
|
+
vfloat64m4_t accumulator_f64m4 = __riscv_vfmv_v_f_f64m4(0.0, max_vector_length);
|
|
1447
1509
|
nk_size_t remaining = depth;
|
|
1448
1510
|
nk_size_t k = 0;
|
|
1449
1511
|
for (nk_size_t vector_length = 0; remaining > 0; remaining -= vector_length, k += vector_length) {
|
|
@@ -1457,15 +1519,15 @@ NK_PUBLIC void nk_dots_symmetric_f16_rvv(nk_f16_t const *vectors, nk_size_t n_ve
|
|
|
1457
1519
|
}
|
|
1458
1520
|
vfloat64m1_t zero_f64m1 = __riscv_vfmv_v_f_f64m1(0.0, 1);
|
|
1459
1521
|
nk_f32_t dot = (nk_f32_t)__riscv_vfmv_f_s_f64m1_f64(
|
|
1460
|
-
__riscv_vfredusum_vs_f64m4_f64m1(accumulator_f64m4, zero_f64m1,
|
|
1522
|
+
__riscv_vfredusum_vs_f64m4_f64m1(accumulator_f64m4, zero_f64m1, max_vector_length));
|
|
1461
1523
|
result[i * result_stride_elements + j] = dot;
|
|
1462
1524
|
}
|
|
1463
1525
|
}
|
|
1464
1526
|
}
|
|
1465
1527
|
|
|
1466
|
-
#pragma endregion
|
|
1528
|
+
#pragma endregion F16 Floats
|
|
1467
1529
|
|
|
1468
|
-
#pragma region
|
|
1530
|
+
#pragma region I8 Integers
|
|
1469
1531
|
|
|
1470
1532
|
/**
|
|
1471
1533
|
* @brief Compute the packed buffer size for i8 GEMM (B stored as i8).
|
|
@@ -1474,11 +1536,11 @@ NK_PUBLIC void nk_dots_symmetric_f16_rvv(nk_f16_t const *vectors, nk_size_t n_ve
|
|
|
1474
1536
|
* Layout: column-panel with depth-contiguous i8 values, cache-line padding.
|
|
1475
1537
|
*/
|
|
1476
1538
|
NK_PUBLIC nk_size_t nk_dots_packed_size_i8_rvv(nk_size_t column_count, nk_size_t depth) {
|
|
1477
|
-
nk_size_t
|
|
1478
|
-
nk_size_t depth_padded = nk_size_round_up_to_multiple_(depth,
|
|
1539
|
+
nk_size_t max_vector_length = __riscv_vsetvlmax_e8m1();
|
|
1540
|
+
nk_size_t depth_padded = nk_size_round_up_to_multiple_(depth, max_vector_length);
|
|
1479
1541
|
// Break power-of-2 strides for cache associativity
|
|
1480
1542
|
nk_size_t stride_bytes = depth_padded * sizeof(nk_i8_t);
|
|
1481
|
-
if (stride_bytes > 0 && (stride_bytes & (stride_bytes - 1)) == 0) depth_padded +=
|
|
1543
|
+
if (stride_bytes > 0 && (stride_bytes & (stride_bytes - 1)) == 0) depth_padded += max_vector_length;
|
|
1482
1544
|
return sizeof(nk_cross_packed_buffer_header_t) + column_count * depth_padded * sizeof(nk_i8_t) +
|
|
1483
1545
|
column_count * sizeof(nk_u32_t); // per-column norms
|
|
1484
1546
|
}
|
|
@@ -1491,10 +1553,10 @@ NK_PUBLIC nk_size_t nk_dots_packed_size_i8_rvv(nk_size_t column_count, nk_size_t
|
|
|
1491
1553
|
*/
|
|
1492
1554
|
NK_PUBLIC void nk_dots_pack_i8_rvv(nk_i8_t const *b, nk_size_t column_count, nk_size_t depth,
|
|
1493
1555
|
nk_size_t b_stride_in_bytes, void *b_packed) {
|
|
1494
|
-
nk_size_t
|
|
1495
|
-
nk_size_t depth_padded = nk_size_round_up_to_multiple_(depth,
|
|
1556
|
+
nk_size_t max_vector_length = __riscv_vsetvlmax_e8m1();
|
|
1557
|
+
nk_size_t depth_padded = nk_size_round_up_to_multiple_(depth, max_vector_length);
|
|
1496
1558
|
nk_size_t stride_bytes = depth_padded * sizeof(nk_i8_t);
|
|
1497
|
-
if (stride_bytes > 0 && (stride_bytes & (stride_bytes - 1)) == 0) depth_padded +=
|
|
1559
|
+
if (stride_bytes > 0 && (stride_bytes & (stride_bytes - 1)) == 0) depth_padded += max_vector_length;
|
|
1498
1560
|
|
|
1499
1561
|
nk_cross_packed_buffer_header_t *header = (nk_cross_packed_buffer_header_t *)b_packed;
|
|
1500
1562
|
header->column_count = (nk_u32_t)column_count;
|
|
@@ -1503,12 +1565,25 @@ NK_PUBLIC void nk_dots_pack_i8_rvv(nk_i8_t const *b, nk_size_t column_count, nk_
|
|
|
1503
1565
|
|
|
1504
1566
|
nk_i8_t *packed = (nk_i8_t *)((char *)b_packed + sizeof(nk_cross_packed_buffer_header_t));
|
|
1505
1567
|
nk_size_t total = column_count * depth_padded;
|
|
1506
|
-
|
|
1568
|
+
{
|
|
1569
|
+
nk_u8_t *zero_ptr = (nk_u8_t *)packed;
|
|
1570
|
+
nk_size_t total_bytes = total * sizeof(nk_i8_t);
|
|
1571
|
+
for (nk_size_t i = 0; i < total_bytes;) {
|
|
1572
|
+
nk_size_t vector_length = __riscv_vsetvl_e8m8(total_bytes - i);
|
|
1573
|
+
__riscv_vse8_v_u8m8(zero_ptr + i, __riscv_vmv_v_x_u8m8(0, vector_length), vector_length);
|
|
1574
|
+
i += vector_length;
|
|
1575
|
+
}
|
|
1576
|
+
}
|
|
1507
1577
|
|
|
1508
1578
|
for (nk_size_t column = 0; column < column_count; ++column) {
|
|
1509
1579
|
nk_i8_t const *src = (nk_i8_t const *)((char const *)b + column * b_stride_in_bytes);
|
|
1510
1580
|
nk_i8_t *dst = packed + column * depth_padded;
|
|
1511
|
-
for (nk_size_t k = 0; k < depth;
|
|
1581
|
+
for (nk_size_t k = 0; k < depth;) {
|
|
1582
|
+
nk_size_t vector_length = __riscv_vsetvl_e8m8(depth - k);
|
|
1583
|
+
__riscv_vse8_v_u8m8((nk_u8_t *)(dst + k), __riscv_vle8_v_u8m8((nk_u8_t const *)(src + k), vector_length),
|
|
1584
|
+
vector_length);
|
|
1585
|
+
k += vector_length;
|
|
1586
|
+
}
|
|
1512
1587
|
}
|
|
1513
1588
|
|
|
1514
1589
|
// Append per-column norms after packed data
|
|
@@ -1524,7 +1599,7 @@ NK_PUBLIC void nk_dots_pack_i8_rvv(nk_i8_t const *b, nk_size_t column_count, nk_
|
|
|
1524
1599
|
*
|
|
1525
1600
|
* Vectorizes over the depth dimension (k). For each (row, column) pair:
|
|
1526
1601
|
* - Load i8 values from A and pre-packed i8 values from B
|
|
1527
|
-
* - Widening multiply: i8
|
|
1602
|
+
* - Widening multiply: i8 × i8 → i16 via `vwmul`
|
|
1528
1603
|
* - Widen-accumulate: i32 += i16 via `vwadd_wv`
|
|
1529
1604
|
* - Horizontal reduce via `vredsum`
|
|
1530
1605
|
*
|
|
@@ -1560,11 +1635,11 @@ NK_INTERNAL void nk_dots_packed_i8_rvv_aligned_(nk_i8_t const *a_matrix, void co
|
|
|
1560
1635
|
|
|
1561
1636
|
for (nk_size_t column = 0; column < column_count; ++column) {
|
|
1562
1637
|
nk_i8_t const *b_column = packed_data + column * depth_padded;
|
|
1563
|
-
nk_size_t
|
|
1564
|
-
vint32m4_t accumulator_0_i32m4 = __riscv_vmv_v_x_i32m4(0,
|
|
1565
|
-
vint32m4_t accumulator_1_i32m4 = __riscv_vmv_v_x_i32m4(0,
|
|
1566
|
-
vint32m4_t accumulator_2_i32m4 = __riscv_vmv_v_x_i32m4(0,
|
|
1567
|
-
vint32m4_t accumulator_3_i32m4 = __riscv_vmv_v_x_i32m4(0,
|
|
1638
|
+
nk_size_t max_vector_length = __riscv_vsetvlmax_e32m4();
|
|
1639
|
+
vint32m4_t accumulator_0_i32m4 = __riscv_vmv_v_x_i32m4(0, max_vector_length);
|
|
1640
|
+
vint32m4_t accumulator_1_i32m4 = __riscv_vmv_v_x_i32m4(0, max_vector_length);
|
|
1641
|
+
vint32m4_t accumulator_2_i32m4 = __riscv_vmv_v_x_i32m4(0, max_vector_length);
|
|
1642
|
+
vint32m4_t accumulator_3_i32m4 = __riscv_vmv_v_x_i32m4(0, max_vector_length);
|
|
1568
1643
|
|
|
1569
1644
|
nk_size_t remaining = depth;
|
|
1570
1645
|
nk_size_t k = 0;
|
|
@@ -1592,13 +1667,13 @@ NK_INTERNAL void nk_dots_packed_i8_rvv_aligned_(nk_i8_t const *a_matrix, void co
|
|
|
1592
1667
|
// Horizontal reduce
|
|
1593
1668
|
vint32m1_t zero_i32m1 = __riscv_vmv_v_x_i32m1(0, 1);
|
|
1594
1669
|
c_row_0[column] = (nk_i32_t)__riscv_vmv_x_s_i32m1_i32(
|
|
1595
|
-
__riscv_vredsum_vs_i32m4_i32m1(accumulator_0_i32m4, zero_i32m1,
|
|
1670
|
+
__riscv_vredsum_vs_i32m4_i32m1(accumulator_0_i32m4, zero_i32m1, max_vector_length));
|
|
1596
1671
|
c_row_1[column] = (nk_i32_t)__riscv_vmv_x_s_i32m1_i32(
|
|
1597
|
-
__riscv_vredsum_vs_i32m4_i32m1(accumulator_1_i32m4, zero_i32m1,
|
|
1672
|
+
__riscv_vredsum_vs_i32m4_i32m1(accumulator_1_i32m4, zero_i32m1, max_vector_length));
|
|
1598
1673
|
c_row_2[column] = (nk_i32_t)__riscv_vmv_x_s_i32m1_i32(
|
|
1599
|
-
__riscv_vredsum_vs_i32m4_i32m1(accumulator_2_i32m4, zero_i32m1,
|
|
1674
|
+
__riscv_vredsum_vs_i32m4_i32m1(accumulator_2_i32m4, zero_i32m1, max_vector_length));
|
|
1600
1675
|
c_row_3[column] = (nk_i32_t)__riscv_vmv_x_s_i32m1_i32(
|
|
1601
|
-
__riscv_vredsum_vs_i32m4_i32m1(accumulator_3_i32m4, zero_i32m1,
|
|
1676
|
+
__riscv_vredsum_vs_i32m4_i32m1(accumulator_3_i32m4, zero_i32m1, max_vector_length));
|
|
1602
1677
|
}
|
|
1603
1678
|
}
|
|
1604
1679
|
// Remainder rows (mr < 4)
|
|
@@ -1607,8 +1682,8 @@ NK_INTERNAL void nk_dots_packed_i8_rvv_aligned_(nk_i8_t const *a_matrix, void co
|
|
|
1607
1682
|
nk_i32_t *c_row = (nk_i32_t *)((char *)c_matrix + row * c_stride_in_bytes);
|
|
1608
1683
|
for (nk_size_t column = 0; column < column_count; ++column) {
|
|
1609
1684
|
nk_i8_t const *b_column = packed_data + column * depth_padded;
|
|
1610
|
-
nk_size_t
|
|
1611
|
-
vint32m4_t accumulator_i32m4 = __riscv_vmv_v_x_i32m4(0,
|
|
1685
|
+
nk_size_t max_vector_length = __riscv_vsetvlmax_e32m4();
|
|
1686
|
+
vint32m4_t accumulator_i32m4 = __riscv_vmv_v_x_i32m4(0, max_vector_length);
|
|
1612
1687
|
nk_size_t remaining = depth;
|
|
1613
1688
|
nk_size_t k = 0;
|
|
1614
1689
|
for (nk_size_t vector_length = 0; remaining > 0; remaining -= vector_length, k += vector_length) {
|
|
@@ -1621,7 +1696,7 @@ NK_INTERNAL void nk_dots_packed_i8_rvv_aligned_(nk_i8_t const *a_matrix, void co
|
|
|
1621
1696
|
}
|
|
1622
1697
|
vint32m1_t zero_i32m1 = __riscv_vmv_v_x_i32m1(0, 1);
|
|
1623
1698
|
c_row[column] = (nk_i32_t)__riscv_vmv_x_s_i32m1_i32(
|
|
1624
|
-
__riscv_vredsum_vs_i32m4_i32m1(accumulator_i32m4, zero_i32m1,
|
|
1699
|
+
__riscv_vredsum_vs_i32m4_i32m1(accumulator_i32m4, zero_i32m1, max_vector_length));
|
|
1625
1700
|
}
|
|
1626
1701
|
}
|
|
1627
1702
|
}
|
|
@@ -1632,31 +1707,32 @@ NK_INTERNAL void nk_dots_packed_i8_rvv_aligned_(nk_i8_t const *a_matrix, void co
|
|
|
1632
1707
|
* Dispatches to the aligned kernel for all cases — RVV's `vsetvl` handles partial
|
|
1633
1708
|
* vectors naturally, so no separate edge kernel is needed.
|
|
1634
1709
|
*/
|
|
1635
|
-
NK_PUBLIC void nk_dots_packed_i8_rvv(nk_i8_t const *a, void const *b_packed, nk_i32_t *c, nk_size_t
|
|
1636
|
-
nk_size_t
|
|
1637
|
-
|
|
1710
|
+
NK_PUBLIC void nk_dots_packed_i8_rvv(nk_i8_t const *a, void const *b_packed, nk_i32_t *c, nk_size_t rows,
|
|
1711
|
+
nk_size_t columns, nk_size_t depth, nk_size_t a_stride_in_bytes,
|
|
1712
|
+
nk_size_t c_stride_in_bytes) {
|
|
1713
|
+
nk_dots_packed_i8_rvv_aligned_(a, b_packed, c, rows, columns, depth, a_stride_in_bytes, c_stride_in_bytes);
|
|
1638
1714
|
}
|
|
1639
1715
|
|
|
1640
1716
|
/**
|
|
1641
1717
|
* @brief Symmetric i8 GEMM: C = A * A^T, upper triangle + mirror.
|
|
1642
1718
|
*
|
|
1643
1719
|
* Uses integer i8 arithmetic with i32 accumulation.
|
|
1644
|
-
* Both inputs are i8, widened via i8
|
|
1720
|
+
* Both inputs are i8, widened via i8 × i8 → i16 → i32 accumulation.
|
|
1645
1721
|
* Stride is in bytes.
|
|
1646
1722
|
* Processes only the rows in [row_start, row_start + row_count) for parallelism.
|
|
1647
1723
|
*/
|
|
1648
|
-
NK_PUBLIC void nk_dots_symmetric_i8_rvv(nk_i8_t const *vectors, nk_size_t
|
|
1649
|
-
nk_i32_t *result, nk_size_t
|
|
1650
|
-
nk_size_t row_count) {
|
|
1651
|
-
nk_size_t const result_stride_elements =
|
|
1652
|
-
nk_size_t const row_end = (row_start + row_count <
|
|
1724
|
+
NK_PUBLIC void nk_dots_symmetric_i8_rvv(nk_i8_t const *vectors, nk_size_t vectors_count, nk_size_t depth,
|
|
1725
|
+
nk_size_t stride_in_bytes, nk_i32_t *result, nk_size_t result_stride_in_bytes,
|
|
1726
|
+
nk_size_t row_start, nk_size_t row_count) {
|
|
1727
|
+
nk_size_t const result_stride_elements = result_stride_in_bytes / sizeof(nk_i32_t);
|
|
1728
|
+
nk_size_t const row_end = (row_start + row_count < vectors_count) ? (row_start + row_count) : vectors_count;
|
|
1653
1729
|
|
|
1654
1730
|
for (nk_size_t i = row_start; i < row_end; ++i) {
|
|
1655
|
-
nk_i8_t const *a_i = (nk_i8_t const *)((char const *)vectors + i *
|
|
1656
|
-
for (nk_size_t j = i; j <
|
|
1657
|
-
nk_i8_t const *a_j = (nk_i8_t const *)((char const *)vectors + j *
|
|
1658
|
-
nk_size_t
|
|
1659
|
-
vint32m4_t accumulator_i32m4 = __riscv_vmv_v_x_i32m4(0,
|
|
1731
|
+
nk_i8_t const *a_i = (nk_i8_t const *)((char const *)vectors + i * stride_in_bytes);
|
|
1732
|
+
for (nk_size_t j = i; j < vectors_count; ++j) {
|
|
1733
|
+
nk_i8_t const *a_j = (nk_i8_t const *)((char const *)vectors + j * stride_in_bytes);
|
|
1734
|
+
nk_size_t max_vector_length = __riscv_vsetvlmax_e32m4();
|
|
1735
|
+
vint32m4_t accumulator_i32m4 = __riscv_vmv_v_x_i32m4(0, max_vector_length);
|
|
1660
1736
|
nk_size_t remaining = depth;
|
|
1661
1737
|
nk_size_t k = 0;
|
|
1662
1738
|
for (nk_size_t vector_length = 0; remaining > 0; remaining -= vector_length, k += vector_length) {
|
|
@@ -1669,15 +1745,15 @@ NK_PUBLIC void nk_dots_symmetric_i8_rvv(nk_i8_t const *vectors, nk_size_t n_vect
|
|
|
1669
1745
|
}
|
|
1670
1746
|
vint32m1_t zero_i32m1 = __riscv_vmv_v_x_i32m1(0, 1);
|
|
1671
1747
|
nk_i32_t dot = (nk_i32_t)__riscv_vmv_x_s_i32m1_i32(
|
|
1672
|
-
__riscv_vredsum_vs_i32m4_i32m1(accumulator_i32m4, zero_i32m1,
|
|
1748
|
+
__riscv_vredsum_vs_i32m4_i32m1(accumulator_i32m4, zero_i32m1, max_vector_length));
|
|
1673
1749
|
result[i * result_stride_elements + j] = dot;
|
|
1674
1750
|
}
|
|
1675
1751
|
}
|
|
1676
1752
|
}
|
|
1677
1753
|
|
|
1678
|
-
#pragma endregion
|
|
1754
|
+
#pragma endregion I8 Integers
|
|
1679
1755
|
|
|
1680
|
-
#pragma region
|
|
1756
|
+
#pragma region U8 Integers
|
|
1681
1757
|
|
|
1682
1758
|
/**
|
|
1683
1759
|
* @brief Compute the packed buffer size for u8 GEMM (B stored as u8).
|
|
@@ -1686,11 +1762,11 @@ NK_PUBLIC void nk_dots_symmetric_i8_rvv(nk_i8_t const *vectors, nk_size_t n_vect
|
|
|
1686
1762
|
* Layout: column-panel with depth-contiguous u8 values, cache-line padding.
|
|
1687
1763
|
*/
|
|
1688
1764
|
NK_PUBLIC nk_size_t nk_dots_packed_size_u8_rvv(nk_size_t column_count, nk_size_t depth) {
|
|
1689
|
-
nk_size_t
|
|
1690
|
-
nk_size_t depth_padded = nk_size_round_up_to_multiple_(depth,
|
|
1765
|
+
nk_size_t max_vector_length = __riscv_vsetvlmax_e8m1();
|
|
1766
|
+
nk_size_t depth_padded = nk_size_round_up_to_multiple_(depth, max_vector_length);
|
|
1691
1767
|
// Break power-of-2 strides for cache associativity
|
|
1692
1768
|
nk_size_t stride_bytes = depth_padded * sizeof(nk_u8_t);
|
|
1693
|
-
if (stride_bytes > 0 && (stride_bytes & (stride_bytes - 1)) == 0) depth_padded +=
|
|
1769
|
+
if (stride_bytes > 0 && (stride_bytes & (stride_bytes - 1)) == 0) depth_padded += max_vector_length;
|
|
1694
1770
|
return sizeof(nk_cross_packed_buffer_header_t) + column_count * depth_padded * sizeof(nk_u8_t) +
|
|
1695
1771
|
column_count * sizeof(nk_u32_t); // per-column norms
|
|
1696
1772
|
}
|
|
@@ -1703,10 +1779,10 @@ NK_PUBLIC nk_size_t nk_dots_packed_size_u8_rvv(nk_size_t column_count, nk_size_t
|
|
|
1703
1779
|
*/
|
|
1704
1780
|
NK_PUBLIC void nk_dots_pack_u8_rvv(nk_u8_t const *b, nk_size_t column_count, nk_size_t depth,
|
|
1705
1781
|
nk_size_t b_stride_in_bytes, void *b_packed) {
|
|
1706
|
-
nk_size_t
|
|
1707
|
-
nk_size_t depth_padded = nk_size_round_up_to_multiple_(depth,
|
|
1782
|
+
nk_size_t max_vector_length = __riscv_vsetvlmax_e8m1();
|
|
1783
|
+
nk_size_t depth_padded = nk_size_round_up_to_multiple_(depth, max_vector_length);
|
|
1708
1784
|
nk_size_t stride_bytes = depth_padded * sizeof(nk_u8_t);
|
|
1709
|
-
if (stride_bytes > 0 && (stride_bytes & (stride_bytes - 1)) == 0) depth_padded +=
|
|
1785
|
+
if (stride_bytes > 0 && (stride_bytes & (stride_bytes - 1)) == 0) depth_padded += max_vector_length;
|
|
1710
1786
|
|
|
1711
1787
|
nk_cross_packed_buffer_header_t *header = (nk_cross_packed_buffer_header_t *)b_packed;
|
|
1712
1788
|
header->column_count = (nk_u32_t)column_count;
|
|
@@ -1715,12 +1791,24 @@ NK_PUBLIC void nk_dots_pack_u8_rvv(nk_u8_t const *b, nk_size_t column_count, nk_
|
|
|
1715
1791
|
|
|
1716
1792
|
nk_u8_t *packed = (nk_u8_t *)((char *)b_packed + sizeof(nk_cross_packed_buffer_header_t));
|
|
1717
1793
|
nk_size_t total = column_count * depth_padded;
|
|
1718
|
-
|
|
1794
|
+
{
|
|
1795
|
+
nk_u8_t *zero_ptr = (nk_u8_t *)packed;
|
|
1796
|
+
nk_size_t total_bytes = total * sizeof(nk_u8_t);
|
|
1797
|
+
for (nk_size_t i = 0; i < total_bytes;) {
|
|
1798
|
+
nk_size_t vector_length = __riscv_vsetvl_e8m8(total_bytes - i);
|
|
1799
|
+
__riscv_vse8_v_u8m8(zero_ptr + i, __riscv_vmv_v_x_u8m8(0, vector_length), vector_length);
|
|
1800
|
+
i += vector_length;
|
|
1801
|
+
}
|
|
1802
|
+
}
|
|
1719
1803
|
|
|
1720
1804
|
for (nk_size_t column = 0; column < column_count; ++column) {
|
|
1721
1805
|
nk_u8_t const *src = (nk_u8_t const *)((char const *)b + column * b_stride_in_bytes);
|
|
1722
1806
|
nk_u8_t *dst = packed + column * depth_padded;
|
|
1723
|
-
for (nk_size_t k = 0; k < depth;
|
|
1807
|
+
for (nk_size_t k = 0; k < depth;) {
|
|
1808
|
+
nk_size_t vector_length = __riscv_vsetvl_e8m8(depth - k);
|
|
1809
|
+
__riscv_vse8_v_u8m8(dst + k, __riscv_vle8_v_u8m8(src + k, vector_length), vector_length);
|
|
1810
|
+
k += vector_length;
|
|
1811
|
+
}
|
|
1724
1812
|
}
|
|
1725
1813
|
|
|
1726
1814
|
// Append per-column norms after packed data
|
|
@@ -1736,7 +1824,7 @@ NK_PUBLIC void nk_dots_pack_u8_rvv(nk_u8_t const *b, nk_size_t column_count, nk_
|
|
|
1736
1824
|
*
|
|
1737
1825
|
* Vectorizes over the depth dimension (k). For each (row, column) pair:
|
|
1738
1826
|
* - Load u8 values from A and pre-packed u8 values from B
|
|
1739
|
-
* - Widening multiply: u8
|
|
1827
|
+
* - Widening multiply: u8 × u8 → u16 via `vwmulu`
|
|
1740
1828
|
* - Widen-accumulate: u32 += u16 via `vwaddu_wv`
|
|
1741
1829
|
* - Horizontal reduce via `vredsum`
|
|
1742
1830
|
*
|
|
@@ -1772,11 +1860,11 @@ NK_INTERNAL void nk_dots_packed_u8_rvv_aligned_(nk_u8_t const *a_matrix, void co
|
|
|
1772
1860
|
|
|
1773
1861
|
for (nk_size_t column = 0; column < column_count; ++column) {
|
|
1774
1862
|
nk_u8_t const *b_column = packed_data + column * depth_padded;
|
|
1775
|
-
nk_size_t
|
|
1776
|
-
vuint32m4_t accumulator_0_u32m4 = __riscv_vmv_v_x_u32m4(0,
|
|
1777
|
-
vuint32m4_t accumulator_1_u32m4 = __riscv_vmv_v_x_u32m4(0,
|
|
1778
|
-
vuint32m4_t accumulator_2_u32m4 = __riscv_vmv_v_x_u32m4(0,
|
|
1779
|
-
vuint32m4_t accumulator_3_u32m4 = __riscv_vmv_v_x_u32m4(0,
|
|
1863
|
+
nk_size_t max_vector_length = __riscv_vsetvlmax_e32m4();
|
|
1864
|
+
vuint32m4_t accumulator_0_u32m4 = __riscv_vmv_v_x_u32m4(0, max_vector_length);
|
|
1865
|
+
vuint32m4_t accumulator_1_u32m4 = __riscv_vmv_v_x_u32m4(0, max_vector_length);
|
|
1866
|
+
vuint32m4_t accumulator_2_u32m4 = __riscv_vmv_v_x_u32m4(0, max_vector_length);
|
|
1867
|
+
vuint32m4_t accumulator_3_u32m4 = __riscv_vmv_v_x_u32m4(0, max_vector_length);
|
|
1780
1868
|
|
|
1781
1869
|
nk_size_t remaining = depth;
|
|
1782
1870
|
nk_size_t k = 0;
|
|
@@ -1804,13 +1892,13 @@ NK_INTERNAL void nk_dots_packed_u8_rvv_aligned_(nk_u8_t const *a_matrix, void co
|
|
|
1804
1892
|
// Horizontal reduce
|
|
1805
1893
|
vuint32m1_t zero_u32m1 = __riscv_vmv_v_x_u32m1(0, 1);
|
|
1806
1894
|
c_row_0[column] = (nk_u32_t)__riscv_vmv_x_s_u32m1_u32(
|
|
1807
|
-
__riscv_vredsum_vs_u32m4_u32m1(accumulator_0_u32m4, zero_u32m1,
|
|
1895
|
+
__riscv_vredsum_vs_u32m4_u32m1(accumulator_0_u32m4, zero_u32m1, max_vector_length));
|
|
1808
1896
|
c_row_1[column] = (nk_u32_t)__riscv_vmv_x_s_u32m1_u32(
|
|
1809
|
-
__riscv_vredsum_vs_u32m4_u32m1(accumulator_1_u32m4, zero_u32m1,
|
|
1897
|
+
__riscv_vredsum_vs_u32m4_u32m1(accumulator_1_u32m4, zero_u32m1, max_vector_length));
|
|
1810
1898
|
c_row_2[column] = (nk_u32_t)__riscv_vmv_x_s_u32m1_u32(
|
|
1811
|
-
__riscv_vredsum_vs_u32m4_u32m1(accumulator_2_u32m4, zero_u32m1,
|
|
1899
|
+
__riscv_vredsum_vs_u32m4_u32m1(accumulator_2_u32m4, zero_u32m1, max_vector_length));
|
|
1812
1900
|
c_row_3[column] = (nk_u32_t)__riscv_vmv_x_s_u32m1_u32(
|
|
1813
|
-
__riscv_vredsum_vs_u32m4_u32m1(accumulator_3_u32m4, zero_u32m1,
|
|
1901
|
+
__riscv_vredsum_vs_u32m4_u32m1(accumulator_3_u32m4, zero_u32m1, max_vector_length));
|
|
1814
1902
|
}
|
|
1815
1903
|
}
|
|
1816
1904
|
// Remainder rows (mr < 4)
|
|
@@ -1819,8 +1907,8 @@ NK_INTERNAL void nk_dots_packed_u8_rvv_aligned_(nk_u8_t const *a_matrix, void co
|
|
|
1819
1907
|
nk_u32_t *c_row = (nk_u32_t *)((char *)c_matrix + row * c_stride_in_bytes);
|
|
1820
1908
|
for (nk_size_t column = 0; column < column_count; ++column) {
|
|
1821
1909
|
nk_u8_t const *b_column = packed_data + column * depth_padded;
|
|
1822
|
-
nk_size_t
|
|
1823
|
-
vuint32m4_t accumulator_u32m4 = __riscv_vmv_v_x_u32m4(0,
|
|
1910
|
+
nk_size_t max_vector_length = __riscv_vsetvlmax_e32m4();
|
|
1911
|
+
vuint32m4_t accumulator_u32m4 = __riscv_vmv_v_x_u32m4(0, max_vector_length);
|
|
1824
1912
|
nk_size_t remaining = depth;
|
|
1825
1913
|
nk_size_t k = 0;
|
|
1826
1914
|
for (nk_size_t vector_length = 0; remaining > 0; remaining -= vector_length, k += vector_length) {
|
|
@@ -1833,7 +1921,7 @@ NK_INTERNAL void nk_dots_packed_u8_rvv_aligned_(nk_u8_t const *a_matrix, void co
|
|
|
1833
1921
|
}
|
|
1834
1922
|
vuint32m1_t zero_u32m1 = __riscv_vmv_v_x_u32m1(0, 1);
|
|
1835
1923
|
c_row[column] = (nk_u32_t)__riscv_vmv_x_s_u32m1_u32(
|
|
1836
|
-
__riscv_vredsum_vs_u32m4_u32m1(accumulator_u32m4, zero_u32m1,
|
|
1924
|
+
__riscv_vredsum_vs_u32m4_u32m1(accumulator_u32m4, zero_u32m1, max_vector_length));
|
|
1837
1925
|
}
|
|
1838
1926
|
}
|
|
1839
1927
|
}
|
|
@@ -1844,31 +1932,32 @@ NK_INTERNAL void nk_dots_packed_u8_rvv_aligned_(nk_u8_t const *a_matrix, void co
|
|
|
1844
1932
|
* Dispatches to the aligned kernel for all cases — RVV's `vsetvl` handles partial
|
|
1845
1933
|
* vectors naturally, so no separate edge kernel is needed.
|
|
1846
1934
|
*/
|
|
1847
|
-
NK_PUBLIC void nk_dots_packed_u8_rvv(nk_u8_t const *a, void const *b_packed, nk_u32_t *c, nk_size_t
|
|
1848
|
-
nk_size_t
|
|
1849
|
-
|
|
1935
|
+
NK_PUBLIC void nk_dots_packed_u8_rvv(nk_u8_t const *a, void const *b_packed, nk_u32_t *c, nk_size_t rows,
|
|
1936
|
+
nk_size_t columns, nk_size_t depth, nk_size_t a_stride_in_bytes,
|
|
1937
|
+
nk_size_t c_stride_in_bytes) {
|
|
1938
|
+
nk_dots_packed_u8_rvv_aligned_(a, b_packed, c, rows, columns, depth, a_stride_in_bytes, c_stride_in_bytes);
|
|
1850
1939
|
}
|
|
1851
1940
|
|
|
1852
1941
|
/**
|
|
1853
1942
|
* @brief Symmetric u8 GEMM: C = A * A^T, upper triangle + mirror.
|
|
1854
1943
|
*
|
|
1855
1944
|
* Uses unsigned integer u8 arithmetic with u32 accumulation.
|
|
1856
|
-
* Both inputs are u8, widened via u8
|
|
1945
|
+
* Both inputs are u8, widened via u8 × u8 → u16 → u32 accumulation.
|
|
1857
1946
|
* Stride is in bytes.
|
|
1858
1947
|
* Processes only the rows in [row_start, row_start + row_count) for parallelism.
|
|
1859
1948
|
*/
|
|
1860
|
-
NK_PUBLIC void nk_dots_symmetric_u8_rvv(nk_u8_t const *vectors, nk_size_t
|
|
1861
|
-
nk_u32_t *result, nk_size_t
|
|
1862
|
-
nk_size_t row_count) {
|
|
1863
|
-
nk_size_t const result_stride_elements =
|
|
1864
|
-
nk_size_t const row_end = (row_start + row_count <
|
|
1949
|
+
NK_PUBLIC void nk_dots_symmetric_u8_rvv(nk_u8_t const *vectors, nk_size_t vectors_count, nk_size_t depth,
|
|
1950
|
+
nk_size_t stride_in_bytes, nk_u32_t *result, nk_size_t result_stride_in_bytes,
|
|
1951
|
+
nk_size_t row_start, nk_size_t row_count) {
|
|
1952
|
+
nk_size_t const result_stride_elements = result_stride_in_bytes / sizeof(nk_u32_t);
|
|
1953
|
+
nk_size_t const row_end = (row_start + row_count < vectors_count) ? (row_start + row_count) : vectors_count;
|
|
1865
1954
|
|
|
1866
1955
|
for (nk_size_t i = row_start; i < row_end; ++i) {
|
|
1867
|
-
nk_u8_t const *a_i = (nk_u8_t const *)((char const *)vectors + i *
|
|
1868
|
-
for (nk_size_t j = i; j <
|
|
1869
|
-
nk_u8_t const *a_j = (nk_u8_t const *)((char const *)vectors + j *
|
|
1870
|
-
nk_size_t
|
|
1871
|
-
vuint32m4_t accumulator_u32m4 = __riscv_vmv_v_x_u32m4(0,
|
|
1956
|
+
nk_u8_t const *a_i = (nk_u8_t const *)((char const *)vectors + i * stride_in_bytes);
|
|
1957
|
+
for (nk_size_t j = i; j < vectors_count; ++j) {
|
|
1958
|
+
nk_u8_t const *a_j = (nk_u8_t const *)((char const *)vectors + j * stride_in_bytes);
|
|
1959
|
+
nk_size_t max_vector_length = __riscv_vsetvlmax_e32m4();
|
|
1960
|
+
vuint32m4_t accumulator_u32m4 = __riscv_vmv_v_x_u32m4(0, max_vector_length);
|
|
1872
1961
|
nk_size_t remaining = depth;
|
|
1873
1962
|
nk_size_t k = 0;
|
|
1874
1963
|
for (nk_size_t vector_length = 0; remaining > 0; remaining -= vector_length, k += vector_length) {
|
|
@@ -1881,18 +1970,18 @@ NK_PUBLIC void nk_dots_symmetric_u8_rvv(nk_u8_t const *vectors, nk_size_t n_vect
|
|
|
1881
1970
|
}
|
|
1882
1971
|
vuint32m1_t zero_u32m1 = __riscv_vmv_v_x_u32m1(0, 1);
|
|
1883
1972
|
nk_u32_t dot = (nk_u32_t)__riscv_vmv_x_s_u32m1_u32(
|
|
1884
|
-
__riscv_vredsum_vs_u32m4_u32m1(accumulator_u32m4, zero_u32m1,
|
|
1973
|
+
__riscv_vredsum_vs_u32m4_u32m1(accumulator_u32m4, zero_u32m1, max_vector_length));
|
|
1885
1974
|
result[i * result_stride_elements + j] = dot;
|
|
1886
1975
|
}
|
|
1887
1976
|
}
|
|
1888
1977
|
}
|
|
1889
1978
|
|
|
1890
|
-
#pragma endregion
|
|
1979
|
+
#pragma endregion U8 Integers
|
|
1891
1980
|
|
|
1892
|
-
#pragma region
|
|
1981
|
+
#pragma region E4M3 Floats
|
|
1893
1982
|
|
|
1894
1983
|
/**
|
|
1895
|
-
* @brief E4M3 magnitude LUT: 7-bit magnitude
|
|
1984
|
+
* @brief E4M3 magnitude LUT: 7-bit magnitude → f32 bit pattern (u32).
|
|
1896
1985
|
* nk_e4m3_magnitude_lut_rvv_[i] = float_to_bits(e4m3_to_f32(i)) for i=0..127.
|
|
1897
1986
|
* E4M3FN: 4 exponent bits (bias=7), 3 mantissa bits, no infinity,
|
|
1898
1987
|
* NaN = magnitude 0x7F only.
|
|
@@ -1933,10 +2022,10 @@ static nk_u32_t const nk_e4m3_magnitude_lut_rvv_[128] = {
|
|
|
1933
2022
|
};
|
|
1934
2023
|
|
|
1935
2024
|
NK_PUBLIC nk_size_t nk_dots_packed_size_e4m3_rvv(nk_size_t column_count, nk_size_t depth) {
|
|
1936
|
-
nk_size_t
|
|
1937
|
-
nk_size_t depth_padded = nk_size_round_up_to_multiple_(depth,
|
|
2025
|
+
nk_size_t max_vector_length = __riscv_vsetvlmax_e32m2();
|
|
2026
|
+
nk_size_t depth_padded = nk_size_round_up_to_multiple_(depth, max_vector_length);
|
|
1938
2027
|
nk_size_t stride_bytes = depth_padded * sizeof(nk_f32_t);
|
|
1939
|
-
if (stride_bytes > 0 && (stride_bytes & (stride_bytes - 1)) == 0) depth_padded +=
|
|
2028
|
+
if (stride_bytes > 0 && (stride_bytes & (stride_bytes - 1)) == 0) depth_padded += max_vector_length;
|
|
1940
2029
|
return sizeof(nk_cross_packed_buffer_header_t) + column_count * depth_padded * sizeof(nk_f32_t) +
|
|
1941
2030
|
column_count * sizeof(nk_f32_t); // per-column norms
|
|
1942
2031
|
}
|
|
@@ -1949,10 +2038,10 @@ NK_PUBLIC nk_size_t nk_dots_packed_size_e4m3_rvv(nk_size_t column_count, nk_size
|
|
|
1949
2038
|
*/
|
|
1950
2039
|
NK_PUBLIC void nk_dots_pack_e4m3_rvv(nk_e4m3_t const *b, nk_size_t column_count, nk_size_t depth,
|
|
1951
2040
|
nk_size_t b_stride_in_bytes, void *b_packed) {
|
|
1952
|
-
nk_size_t
|
|
1953
|
-
nk_size_t depth_padded = nk_size_round_up_to_multiple_(depth,
|
|
2041
|
+
nk_size_t max_vector_length = __riscv_vsetvlmax_e32m2();
|
|
2042
|
+
nk_size_t depth_padded = nk_size_round_up_to_multiple_(depth, max_vector_length);
|
|
1954
2043
|
nk_size_t stride_bytes = depth_padded * sizeof(nk_f32_t);
|
|
1955
|
-
if (stride_bytes > 0 && (stride_bytes & (stride_bytes - 1)) == 0) depth_padded +=
|
|
2044
|
+
if (stride_bytes > 0 && (stride_bytes & (stride_bytes - 1)) == 0) depth_padded += max_vector_length;
|
|
1956
2045
|
|
|
1957
2046
|
nk_cross_packed_buffer_header_t *header = (nk_cross_packed_buffer_header_t *)b_packed;
|
|
1958
2047
|
header->column_count = (nk_u32_t)column_count;
|
|
@@ -1961,7 +2050,15 @@ NK_PUBLIC void nk_dots_pack_e4m3_rvv(nk_e4m3_t const *b, nk_size_t column_count,
|
|
|
1961
2050
|
|
|
1962
2051
|
nk_f32_t *packed = (nk_f32_t *)((char *)b_packed + sizeof(nk_cross_packed_buffer_header_t));
|
|
1963
2052
|
nk_size_t total = column_count * depth_padded;
|
|
1964
|
-
|
|
2053
|
+
{
|
|
2054
|
+
nk_u8_t *zero_ptr = (nk_u8_t *)packed;
|
|
2055
|
+
nk_size_t total_bytes = total * sizeof(nk_f32_t);
|
|
2056
|
+
for (nk_size_t i = 0; i < total_bytes;) {
|
|
2057
|
+
nk_size_t vector_length = __riscv_vsetvl_e8m8(total_bytes - i);
|
|
2058
|
+
__riscv_vse8_v_u8m8(zero_ptr + i, __riscv_vmv_v_x_u8m8(0, vector_length), vector_length);
|
|
2059
|
+
i += vector_length;
|
|
2060
|
+
}
|
|
2061
|
+
}
|
|
1965
2062
|
|
|
1966
2063
|
for (nk_size_t column = 0; column < column_count; ++column) {
|
|
1967
2064
|
nk_e4m3_t const *src = (nk_e4m3_t const *)((char const *)b + column * b_stride_in_bytes);
|
|
@@ -1985,7 +2082,7 @@ NK_PUBLIC void nk_dots_pack_e4m3_rvv(nk_e4m3_t const *b, nk_size_t column_count,
|
|
|
1985
2082
|
* - Load raw e4m3 bytes from A, convert on-the-fly via 128-entry f32 LUT gather:
|
|
1986
2083
|
* extract 7-bit magnitude, zero-extend to u32, compute byte offsets (x4),
|
|
1987
2084
|
* gather f32 bit patterns, inject sign bit from bit 7 (<<24), reinterpret as f32
|
|
1988
|
-
* - Widening FMA: f32xf32
|
|
2085
|
+
* - Widening FMA: f32xf32 → f64 via `vfwmacc_vv_f64m4`
|
|
1989
2086
|
*
|
|
1990
2087
|
* Register tile: process 2 rows per iteration (rows_per_tile=2, u32m2 gather + f64m4 accumulator is register-heavy).
|
|
1991
2088
|
*/
|
|
@@ -2014,9 +2111,9 @@ NK_INTERNAL void nk_dots_packed_e4m3_rvv_aligned_(nk_e4m3_t const *a_matrix, voi
|
|
|
2014
2111
|
|
|
2015
2112
|
for (nk_size_t column = 0; column < column_count; ++column) {
|
|
2016
2113
|
nk_f32_t const *b_column = packed_data + column * depth_padded;
|
|
2017
|
-
nk_size_t
|
|
2018
|
-
vfloat64m4_t accumulator_0_f64m4 = __riscv_vfmv_v_f_f64m4(0.0,
|
|
2019
|
-
vfloat64m4_t accumulator_1_f64m4 = __riscv_vfmv_v_f_f64m4(0.0,
|
|
2114
|
+
nk_size_t max_vector_length = __riscv_vsetvlmax_e32m2();
|
|
2115
|
+
vfloat64m4_t accumulator_0_f64m4 = __riscv_vfmv_v_f_f64m4(0.0, max_vector_length);
|
|
2116
|
+
vfloat64m4_t accumulator_1_f64m4 = __riscv_vfmv_v_f_f64m4(0.0, max_vector_length);
|
|
2020
2117
|
|
|
2021
2118
|
nk_size_t remaining = depth;
|
|
2022
2119
|
nk_size_t k = 0;
|
|
@@ -2059,7 +2156,7 @@ NK_INTERNAL void nk_dots_packed_e4m3_rvv_aligned_(nk_e4m3_t const *a_matrix, voi
|
|
|
2059
2156
|
vfloat32m2_t a_vector_1_f32m2 = __riscv_vreinterpret_v_u32m2_f32m2(
|
|
2060
2157
|
__riscv_vor_vv_u32m2(bits1_u32m2, sign1_u32m2, vector_length));
|
|
2061
2158
|
|
|
2062
|
-
// Widening FMA: f32xf32
|
|
2159
|
+
// Widening FMA: f32xf32 → f64
|
|
2063
2160
|
accumulator_0_f64m4 = __riscv_vfwmacc_vv_f64m4_tu(accumulator_0_f64m4, a_vector_0_f32m2, b_vector_f32m2,
|
|
2064
2161
|
vector_length);
|
|
2065
2162
|
accumulator_1_f64m4 = __riscv_vfwmacc_vv_f64m4_tu(accumulator_1_f64m4, a_vector_1_f32m2, b_vector_f32m2,
|
|
@@ -2069,9 +2166,9 @@ NK_INTERNAL void nk_dots_packed_e4m3_rvv_aligned_(nk_e4m3_t const *a_matrix, voi
|
|
|
2069
2166
|
// Horizontal reduce and narrow to f32
|
|
2070
2167
|
vfloat64m1_t zero_f64m1 = __riscv_vfmv_v_f_f64m1(0.0, 1);
|
|
2071
2168
|
c_row_0[column] = (nk_f32_t)__riscv_vfmv_f_s_f64m1_f64(
|
|
2072
|
-
__riscv_vfredusum_vs_f64m4_f64m1(accumulator_0_f64m4, zero_f64m1,
|
|
2169
|
+
__riscv_vfredusum_vs_f64m4_f64m1(accumulator_0_f64m4, zero_f64m1, max_vector_length));
|
|
2073
2170
|
c_row_1[column] = (nk_f32_t)__riscv_vfmv_f_s_f64m1_f64(
|
|
2074
|
-
__riscv_vfredusum_vs_f64m4_f64m1(accumulator_1_f64m4, zero_f64m1,
|
|
2171
|
+
__riscv_vfredusum_vs_f64m4_f64m1(accumulator_1_f64m4, zero_f64m1, max_vector_length));
|
|
2075
2172
|
}
|
|
2076
2173
|
}
|
|
2077
2174
|
// Remainder rows
|
|
@@ -2080,8 +2177,8 @@ NK_INTERNAL void nk_dots_packed_e4m3_rvv_aligned_(nk_e4m3_t const *a_matrix, voi
|
|
|
2080
2177
|
nk_f32_t *c_row = (nk_f32_t *)((char *)c_matrix + row * c_stride_in_bytes);
|
|
2081
2178
|
for (nk_size_t column = 0; column < column_count; ++column) {
|
|
2082
2179
|
nk_f32_t const *b_column = packed_data + column * depth_padded;
|
|
2083
|
-
nk_size_t
|
|
2084
|
-
vfloat64m4_t accumulator_f64m4 = __riscv_vfmv_v_f_f64m4(0.0,
|
|
2180
|
+
nk_size_t max_vector_length = __riscv_vsetvlmax_e32m2();
|
|
2181
|
+
vfloat64m4_t accumulator_f64m4 = __riscv_vfmv_v_f_f64m4(0.0, max_vector_length);
|
|
2085
2182
|
nk_size_t remaining = depth;
|
|
2086
2183
|
nk_size_t k = 0;
|
|
2087
2184
|
for (nk_size_t vector_length = 0; remaining > 0; remaining -= vector_length, k += vector_length) {
|
|
@@ -2103,7 +2200,7 @@ NK_INTERNAL void nk_dots_packed_e4m3_rvv_aligned_(nk_e4m3_t const *a_matrix, voi
|
|
|
2103
2200
|
}
|
|
2104
2201
|
vfloat64m1_t zero_f64m1 = __riscv_vfmv_v_f_f64m1(0.0, 1);
|
|
2105
2202
|
c_row[column] = (nk_f32_t)__riscv_vfmv_f_s_f64m1_f64(
|
|
2106
|
-
__riscv_vfredusum_vs_f64m4_f64m1(accumulator_f64m4, zero_f64m1,
|
|
2203
|
+
__riscv_vfredusum_vs_f64m4_f64m1(accumulator_f64m4, zero_f64m1, max_vector_length));
|
|
2107
2204
|
}
|
|
2108
2205
|
}
|
|
2109
2206
|
}
|
|
@@ -2111,9 +2208,10 @@ NK_INTERNAL void nk_dots_packed_e4m3_rvv_aligned_(nk_e4m3_t const *a_matrix, voi
|
|
|
2111
2208
|
/**
|
|
2112
2209
|
* @brief Public e4m3 packed GEMM wrapper matching the declared signature in dots.h.
|
|
2113
2210
|
*/
|
|
2114
|
-
NK_PUBLIC void nk_dots_packed_e4m3_rvv(nk_e4m3_t const *a, void const *b_packed, nk_f32_t *c, nk_size_t
|
|
2115
|
-
nk_size_t
|
|
2116
|
-
|
|
2211
|
+
NK_PUBLIC void nk_dots_packed_e4m3_rvv(nk_e4m3_t const *a, void const *b_packed, nk_f32_t *c, nk_size_t rows,
|
|
2212
|
+
nk_size_t columns, nk_size_t depth, nk_size_t a_stride_in_bytes,
|
|
2213
|
+
nk_size_t c_stride_in_bytes) {
|
|
2214
|
+
nk_dots_packed_e4m3_rvv_aligned_(a, b_packed, c, rows, columns, depth, a_stride_in_bytes, c_stride_in_bytes);
|
|
2117
2215
|
}
|
|
2118
2216
|
|
|
2119
2217
|
/**
|
|
@@ -2123,18 +2221,18 @@ NK_PUBLIC void nk_dots_packed_e4m3_rvv(nk_e4m3_t const *a, void const *b_packed,
|
|
|
2123
2221
|
* Both operands are converted from e4m3 on-the-fly via magnitude LUT.
|
|
2124
2222
|
* Processes only the rows in [row_start, row_start + row_count) for parallelism.
|
|
2125
2223
|
*/
|
|
2126
|
-
NK_PUBLIC void nk_dots_symmetric_e4m3_rvv(nk_e4m3_t const *vectors, nk_size_t
|
|
2127
|
-
nk_size_t
|
|
2224
|
+
NK_PUBLIC void nk_dots_symmetric_e4m3_rvv(nk_e4m3_t const *vectors, nk_size_t vectors_count, nk_size_t depth,
|
|
2225
|
+
nk_size_t stride_in_bytes, nk_f32_t *result, nk_size_t result_stride_in_bytes,
|
|
2128
2226
|
nk_size_t row_start, nk_size_t row_count) {
|
|
2129
|
-
nk_size_t const result_stride_elements =
|
|
2130
|
-
nk_size_t const row_end = (row_start + row_count <
|
|
2227
|
+
nk_size_t const result_stride_elements = result_stride_in_bytes / sizeof(nk_f32_t);
|
|
2228
|
+
nk_size_t const row_end = (row_start + row_count < vectors_count) ? (row_start + row_count) : vectors_count;
|
|
2131
2229
|
|
|
2132
2230
|
for (nk_size_t i = row_start; i < row_end; ++i) {
|
|
2133
|
-
nk_u8_t const *a_i = (nk_u8_t const *)vectors + i *
|
|
2134
|
-
for (nk_size_t j = i; j <
|
|
2135
|
-
nk_u8_t const *a_j = (nk_u8_t const *)vectors + j *
|
|
2136
|
-
nk_size_t
|
|
2137
|
-
vfloat64m4_t accumulator_f64m4 = __riscv_vfmv_v_f_f64m4(0.0,
|
|
2231
|
+
nk_u8_t const *a_i = (nk_u8_t const *)vectors + i * stride_in_bytes;
|
|
2232
|
+
for (nk_size_t j = i; j < vectors_count; ++j) {
|
|
2233
|
+
nk_u8_t const *a_j = (nk_u8_t const *)vectors + j * stride_in_bytes;
|
|
2234
|
+
nk_size_t max_vector_length = __riscv_vsetvlmax_e32m2();
|
|
2235
|
+
vfloat64m4_t accumulator_f64m4 = __riscv_vfmv_v_f_f64m4(0.0, max_vector_length);
|
|
2138
2236
|
nk_size_t remaining = depth;
|
|
2139
2237
|
nk_size_t k = 0;
|
|
2140
2238
|
for (nk_size_t vector_length = 0; remaining > 0; remaining -= vector_length, k += vector_length) {
|
|
@@ -2166,24 +2264,24 @@ NK_PUBLIC void nk_dots_symmetric_e4m3_rvv(nk_e4m3_t const *vectors, nk_size_t n_
|
|
|
2166
2264
|
vfloat32m2_t val_j_f32m2 = __riscv_vreinterpret_v_u32m2_f32m2(
|
|
2167
2265
|
__riscv_vor_vv_u32m2(bits_j_u32m2, sign_j_u32m2, vector_length));
|
|
2168
2266
|
|
|
2169
|
-
// Widening FMA: f32xf32
|
|
2267
|
+
// Widening FMA: f32xf32 → f64
|
|
2170
2268
|
accumulator_f64m4 = __riscv_vfwmacc_vv_f64m4_tu(accumulator_f64m4, val_i_f32m2, val_j_f32m2,
|
|
2171
2269
|
vector_length);
|
|
2172
2270
|
}
|
|
2173
2271
|
vfloat64m1_t zero_f64m1 = __riscv_vfmv_v_f_f64m1(0.0, 1);
|
|
2174
2272
|
nk_f32_t dot = (nk_f32_t)__riscv_vfmv_f_s_f64m1_f64(
|
|
2175
|
-
__riscv_vfredusum_vs_f64m4_f64m1(accumulator_f64m4, zero_f64m1,
|
|
2273
|
+
__riscv_vfredusum_vs_f64m4_f64m1(accumulator_f64m4, zero_f64m1, max_vector_length));
|
|
2176
2274
|
result[i * result_stride_elements + j] = dot;
|
|
2177
2275
|
}
|
|
2178
2276
|
}
|
|
2179
2277
|
}
|
|
2180
2278
|
|
|
2181
|
-
#pragma endregion
|
|
2279
|
+
#pragma endregion E4M3 Floats
|
|
2182
2280
|
|
|
2183
|
-
#pragma region
|
|
2281
|
+
#pragma region E5M2 Floats
|
|
2184
2282
|
|
|
2185
2283
|
/**
|
|
2186
|
-
* @brief E5M2 magnitude LUT: 7-bit magnitude
|
|
2284
|
+
* @brief E5M2 magnitude LUT: 7-bit magnitude → f32 bit pattern (u32).
|
|
2187
2285
|
* nk_e5m2_magnitude_lut_rvv_[i] = float_to_bits(e5m2_to_f32(i)) for i=0..127.
|
|
2188
2286
|
* E5M2: 5 exponent bits (bias=15), 2 mantissa bits, has infinity (0x7C) and
|
|
2189
2287
|
* NaN (magnitudes 0x7D..0x7F).
|
|
@@ -2224,10 +2322,10 @@ static nk_u32_t const nk_e5m2_magnitude_lut_rvv_[128] = {
|
|
|
2224
2322
|
};
|
|
2225
2323
|
|
|
2226
2324
|
NK_PUBLIC nk_size_t nk_dots_packed_size_e5m2_rvv(nk_size_t column_count, nk_size_t depth) {
|
|
2227
|
-
nk_size_t
|
|
2228
|
-
nk_size_t depth_padded = nk_size_round_up_to_multiple_(depth,
|
|
2325
|
+
nk_size_t max_vector_length = __riscv_vsetvlmax_e32m2();
|
|
2326
|
+
nk_size_t depth_padded = nk_size_round_up_to_multiple_(depth, max_vector_length);
|
|
2229
2327
|
nk_size_t stride_bytes = depth_padded * sizeof(nk_f32_t);
|
|
2230
|
-
if (stride_bytes > 0 && (stride_bytes & (stride_bytes - 1)) == 0) depth_padded +=
|
|
2328
|
+
if (stride_bytes > 0 && (stride_bytes & (stride_bytes - 1)) == 0) depth_padded += max_vector_length;
|
|
2231
2329
|
return sizeof(nk_cross_packed_buffer_header_t) + column_count * depth_padded * sizeof(nk_f32_t) +
|
|
2232
2330
|
column_count * sizeof(nk_f32_t); // per-column norms
|
|
2233
2331
|
}
|
|
@@ -2240,10 +2338,10 @@ NK_PUBLIC nk_size_t nk_dots_packed_size_e5m2_rvv(nk_size_t column_count, nk_size
|
|
|
2240
2338
|
*/
|
|
2241
2339
|
NK_PUBLIC void nk_dots_pack_e5m2_rvv(nk_e5m2_t const *b, nk_size_t column_count, nk_size_t depth,
|
|
2242
2340
|
nk_size_t b_stride_in_bytes, void *b_packed) {
|
|
2243
|
-
nk_size_t
|
|
2244
|
-
nk_size_t depth_padded = nk_size_round_up_to_multiple_(depth,
|
|
2341
|
+
nk_size_t max_vector_length = __riscv_vsetvlmax_e32m2();
|
|
2342
|
+
nk_size_t depth_padded = nk_size_round_up_to_multiple_(depth, max_vector_length);
|
|
2245
2343
|
nk_size_t stride_bytes = depth_padded * sizeof(nk_f32_t);
|
|
2246
|
-
if (stride_bytes > 0 && (stride_bytes & (stride_bytes - 1)) == 0) depth_padded +=
|
|
2344
|
+
if (stride_bytes > 0 && (stride_bytes & (stride_bytes - 1)) == 0) depth_padded += max_vector_length;
|
|
2247
2345
|
|
|
2248
2346
|
nk_cross_packed_buffer_header_t *header = (nk_cross_packed_buffer_header_t *)b_packed;
|
|
2249
2347
|
header->column_count = (nk_u32_t)column_count;
|
|
@@ -2252,7 +2350,15 @@ NK_PUBLIC void nk_dots_pack_e5m2_rvv(nk_e5m2_t const *b, nk_size_t column_count,
|
|
|
2252
2350
|
|
|
2253
2351
|
nk_f32_t *packed = (nk_f32_t *)((char *)b_packed + sizeof(nk_cross_packed_buffer_header_t));
|
|
2254
2352
|
nk_size_t total = column_count * depth_padded;
|
|
2255
|
-
|
|
2353
|
+
{
|
|
2354
|
+
nk_u8_t *zero_ptr = (nk_u8_t *)packed;
|
|
2355
|
+
nk_size_t total_bytes = total * sizeof(nk_f32_t);
|
|
2356
|
+
for (nk_size_t i = 0; i < total_bytes;) {
|
|
2357
|
+
nk_size_t vector_length = __riscv_vsetvl_e8m8(total_bytes - i);
|
|
2358
|
+
__riscv_vse8_v_u8m8(zero_ptr + i, __riscv_vmv_v_x_u8m8(0, vector_length), vector_length);
|
|
2359
|
+
i += vector_length;
|
|
2360
|
+
}
|
|
2361
|
+
}
|
|
2256
2362
|
|
|
2257
2363
|
for (nk_size_t column = 0; column < column_count; ++column) {
|
|
2258
2364
|
nk_e5m2_t const *src = (nk_e5m2_t const *)((char const *)b + column * b_stride_in_bytes);
|
|
@@ -2276,7 +2382,7 @@ NK_PUBLIC void nk_dots_pack_e5m2_rvv(nk_e5m2_t const *b, nk_size_t column_count,
|
|
|
2276
2382
|
* - Load raw e5m2 bytes from A, convert on-the-fly via 128-entry f32 LUT gather:
|
|
2277
2383
|
* extract 7-bit magnitude, zero-extend to u32, compute byte offsets (x4),
|
|
2278
2384
|
* gather f32 bit patterns, inject sign bit from bit 7 (<<24), reinterpret as f32
|
|
2279
|
-
* - Widening FMA: f32xf32
|
|
2385
|
+
* - Widening FMA: f32xf32 → f64 via `vfwmacc_vv_f64m4`
|
|
2280
2386
|
*
|
|
2281
2387
|
* Register tile: process 2 rows per iteration (rows_per_tile=2, u32m2 gather + f64m4 accumulator is register-heavy).
|
|
2282
2388
|
*/
|
|
@@ -2305,9 +2411,9 @@ NK_INTERNAL void nk_dots_packed_e5m2_rvv_aligned_(nk_e5m2_t const *a_matrix, voi
|
|
|
2305
2411
|
|
|
2306
2412
|
for (nk_size_t column = 0; column < column_count; ++column) {
|
|
2307
2413
|
nk_f32_t const *b_column = packed_data + column * depth_padded;
|
|
2308
|
-
nk_size_t
|
|
2309
|
-
vfloat64m4_t accumulator_0_f64m4 = __riscv_vfmv_v_f_f64m4(0.0,
|
|
2310
|
-
vfloat64m4_t accumulator_1_f64m4 = __riscv_vfmv_v_f_f64m4(0.0,
|
|
2414
|
+
nk_size_t max_vector_length = __riscv_vsetvlmax_e32m2();
|
|
2415
|
+
vfloat64m4_t accumulator_0_f64m4 = __riscv_vfmv_v_f_f64m4(0.0, max_vector_length);
|
|
2416
|
+
vfloat64m4_t accumulator_1_f64m4 = __riscv_vfmv_v_f_f64m4(0.0, max_vector_length);
|
|
2311
2417
|
|
|
2312
2418
|
nk_size_t remaining = depth;
|
|
2313
2419
|
nk_size_t k = 0;
|
|
@@ -2350,7 +2456,7 @@ NK_INTERNAL void nk_dots_packed_e5m2_rvv_aligned_(nk_e5m2_t const *a_matrix, voi
|
|
|
2350
2456
|
vfloat32m2_t a_vector_1_f32m2 = __riscv_vreinterpret_v_u32m2_f32m2(
|
|
2351
2457
|
__riscv_vor_vv_u32m2(bits1_u32m2, sign1_u32m2, vector_length));
|
|
2352
2458
|
|
|
2353
|
-
// Widening FMA: f32xf32
|
|
2459
|
+
// Widening FMA: f32xf32 → f64
|
|
2354
2460
|
accumulator_0_f64m4 = __riscv_vfwmacc_vv_f64m4_tu(accumulator_0_f64m4, a_vector_0_f32m2, b_vector_f32m2,
|
|
2355
2461
|
vector_length);
|
|
2356
2462
|
accumulator_1_f64m4 = __riscv_vfwmacc_vv_f64m4_tu(accumulator_1_f64m4, a_vector_1_f32m2, b_vector_f32m2,
|
|
@@ -2360,9 +2466,9 @@ NK_INTERNAL void nk_dots_packed_e5m2_rvv_aligned_(nk_e5m2_t const *a_matrix, voi
|
|
|
2360
2466
|
// Horizontal reduce and narrow to f32
|
|
2361
2467
|
vfloat64m1_t zero_f64m1 = __riscv_vfmv_v_f_f64m1(0.0, 1);
|
|
2362
2468
|
c_row_0[column] = (nk_f32_t)__riscv_vfmv_f_s_f64m1_f64(
|
|
2363
|
-
__riscv_vfredusum_vs_f64m4_f64m1(accumulator_0_f64m4, zero_f64m1,
|
|
2469
|
+
__riscv_vfredusum_vs_f64m4_f64m1(accumulator_0_f64m4, zero_f64m1, max_vector_length));
|
|
2364
2470
|
c_row_1[column] = (nk_f32_t)__riscv_vfmv_f_s_f64m1_f64(
|
|
2365
|
-
__riscv_vfredusum_vs_f64m4_f64m1(accumulator_1_f64m4, zero_f64m1,
|
|
2471
|
+
__riscv_vfredusum_vs_f64m4_f64m1(accumulator_1_f64m4, zero_f64m1, max_vector_length));
|
|
2366
2472
|
}
|
|
2367
2473
|
}
|
|
2368
2474
|
// Remainder rows
|
|
@@ -2371,8 +2477,8 @@ NK_INTERNAL void nk_dots_packed_e5m2_rvv_aligned_(nk_e5m2_t const *a_matrix, voi
|
|
|
2371
2477
|
nk_f32_t *c_row = (nk_f32_t *)((char *)c_matrix + row * c_stride_in_bytes);
|
|
2372
2478
|
for (nk_size_t column = 0; column < column_count; ++column) {
|
|
2373
2479
|
nk_f32_t const *b_column = packed_data + column * depth_padded;
|
|
2374
|
-
nk_size_t
|
|
2375
|
-
vfloat64m4_t accumulator_f64m4 = __riscv_vfmv_v_f_f64m4(0.0,
|
|
2480
|
+
nk_size_t max_vector_length = __riscv_vsetvlmax_e32m2();
|
|
2481
|
+
vfloat64m4_t accumulator_f64m4 = __riscv_vfmv_v_f_f64m4(0.0, max_vector_length);
|
|
2376
2482
|
nk_size_t remaining = depth;
|
|
2377
2483
|
nk_size_t k = 0;
|
|
2378
2484
|
for (nk_size_t vector_length = 0; remaining > 0; remaining -= vector_length, k += vector_length) {
|
|
@@ -2394,7 +2500,7 @@ NK_INTERNAL void nk_dots_packed_e5m2_rvv_aligned_(nk_e5m2_t const *a_matrix, voi
|
|
|
2394
2500
|
}
|
|
2395
2501
|
vfloat64m1_t zero_f64m1 = __riscv_vfmv_v_f_f64m1(0.0, 1);
|
|
2396
2502
|
c_row[column] = (nk_f32_t)__riscv_vfmv_f_s_f64m1_f64(
|
|
2397
|
-
__riscv_vfredusum_vs_f64m4_f64m1(accumulator_f64m4, zero_f64m1,
|
|
2503
|
+
__riscv_vfredusum_vs_f64m4_f64m1(accumulator_f64m4, zero_f64m1, max_vector_length));
|
|
2398
2504
|
}
|
|
2399
2505
|
}
|
|
2400
2506
|
}
|
|
@@ -2402,9 +2508,10 @@ NK_INTERNAL void nk_dots_packed_e5m2_rvv_aligned_(nk_e5m2_t const *a_matrix, voi
|
|
|
2402
2508
|
/**
|
|
2403
2509
|
* @brief Public e5m2 packed GEMM wrapper matching the declared signature in dots.h.
|
|
2404
2510
|
*/
|
|
2405
|
-
NK_PUBLIC void nk_dots_packed_e5m2_rvv(nk_e5m2_t const *a, void const *b_packed, nk_f32_t *c, nk_size_t
|
|
2406
|
-
nk_size_t
|
|
2407
|
-
|
|
2511
|
+
NK_PUBLIC void nk_dots_packed_e5m2_rvv(nk_e5m2_t const *a, void const *b_packed, nk_f32_t *c, nk_size_t rows,
|
|
2512
|
+
nk_size_t columns, nk_size_t depth, nk_size_t a_stride_in_bytes,
|
|
2513
|
+
nk_size_t c_stride_in_bytes) {
|
|
2514
|
+
nk_dots_packed_e5m2_rvv_aligned_(a, b_packed, c, rows, columns, depth, a_stride_in_bytes, c_stride_in_bytes);
|
|
2408
2515
|
}
|
|
2409
2516
|
|
|
2410
2517
|
/**
|
|
@@ -2414,18 +2521,18 @@ NK_PUBLIC void nk_dots_packed_e5m2_rvv(nk_e5m2_t const *a, void const *b_packed,
|
|
|
2414
2521
|
* Both operands are converted from e5m2 on-the-fly via magnitude LUT.
|
|
2415
2522
|
* Processes only the rows in [row_start, row_start + row_count) for parallelism.
|
|
2416
2523
|
*/
|
|
2417
|
-
NK_PUBLIC void nk_dots_symmetric_e5m2_rvv(nk_e5m2_t const *vectors, nk_size_t
|
|
2418
|
-
nk_size_t
|
|
2524
|
+
NK_PUBLIC void nk_dots_symmetric_e5m2_rvv(nk_e5m2_t const *vectors, nk_size_t vectors_count, nk_size_t depth,
|
|
2525
|
+
nk_size_t stride_in_bytes, nk_f32_t *result, nk_size_t result_stride_in_bytes,
|
|
2419
2526
|
nk_size_t row_start, nk_size_t row_count) {
|
|
2420
|
-
nk_size_t const result_stride_elements =
|
|
2421
|
-
nk_size_t const row_end = (row_start + row_count <
|
|
2527
|
+
nk_size_t const result_stride_elements = result_stride_in_bytes / sizeof(nk_f32_t);
|
|
2528
|
+
nk_size_t const row_end = (row_start + row_count < vectors_count) ? (row_start + row_count) : vectors_count;
|
|
2422
2529
|
|
|
2423
2530
|
for (nk_size_t i = row_start; i < row_end; ++i) {
|
|
2424
|
-
nk_u8_t const *a_i = (nk_u8_t const *)vectors + i *
|
|
2425
|
-
for (nk_size_t j = i; j <
|
|
2426
|
-
nk_u8_t const *a_j = (nk_u8_t const *)vectors + j *
|
|
2427
|
-
nk_size_t
|
|
2428
|
-
vfloat64m4_t accumulator_f64m4 = __riscv_vfmv_v_f_f64m4(0.0,
|
|
2531
|
+
nk_u8_t const *a_i = (nk_u8_t const *)vectors + i * stride_in_bytes;
|
|
2532
|
+
for (nk_size_t j = i; j < vectors_count; ++j) {
|
|
2533
|
+
nk_u8_t const *a_j = (nk_u8_t const *)vectors + j * stride_in_bytes;
|
|
2534
|
+
nk_size_t max_vector_length = __riscv_vsetvlmax_e32m2();
|
|
2535
|
+
vfloat64m4_t accumulator_f64m4 = __riscv_vfmv_v_f_f64m4(0.0, max_vector_length);
|
|
2429
2536
|
nk_size_t remaining = depth;
|
|
2430
2537
|
nk_size_t k = 0;
|
|
2431
2538
|
for (nk_size_t vector_length = 0; remaining > 0; remaining -= vector_length, k += vector_length) {
|
|
@@ -2457,19 +2564,19 @@ NK_PUBLIC void nk_dots_symmetric_e5m2_rvv(nk_e5m2_t const *vectors, nk_size_t n_
|
|
|
2457
2564
|
vfloat32m2_t val_j_f32m2 = __riscv_vreinterpret_v_u32m2_f32m2(
|
|
2458
2565
|
__riscv_vor_vv_u32m2(bits_j_u32m2, sign_j_u32m2, vector_length));
|
|
2459
2566
|
|
|
2460
|
-
// Widening FMA: f32xf32
|
|
2567
|
+
// Widening FMA: f32xf32 → f64
|
|
2461
2568
|
accumulator_f64m4 = __riscv_vfwmacc_vv_f64m4_tu(accumulator_f64m4, val_i_f32m2, val_j_f32m2,
|
|
2462
2569
|
vector_length);
|
|
2463
2570
|
}
|
|
2464
2571
|
vfloat64m1_t zero_f64m1 = __riscv_vfmv_v_f_f64m1(0.0, 1);
|
|
2465
2572
|
nk_f32_t dot = (nk_f32_t)__riscv_vfmv_f_s_f64m1_f64(
|
|
2466
|
-
__riscv_vfredusum_vs_f64m4_f64m1(accumulator_f64m4, zero_f64m1,
|
|
2573
|
+
__riscv_vfredusum_vs_f64m4_f64m1(accumulator_f64m4, zero_f64m1, max_vector_length));
|
|
2467
2574
|
result[i * result_stride_elements + j] = dot;
|
|
2468
2575
|
}
|
|
2469
2576
|
}
|
|
2470
2577
|
}
|
|
2471
2578
|
|
|
2472
|
-
#pragma endregion
|
|
2579
|
+
#pragma endregion E5M2 Floats
|
|
2473
2580
|
|
|
2474
2581
|
#if defined(__cplusplus)
|
|
2475
2582
|
} // extern "C"
|